├── README.md
├── Setup.hs
├── test
│   └── Spec.hs
├── demo
│   ├── mnist-ghcjs.js_o
│   └── mnist-ghcjs.hs
├── .gitignore
├── src
│   └── Backprop
│       ├── Learn.hs
│       └── Learn
│           ├── Train.hs
│           ├── Model
│           │   ├── Parameter.hs
│           │   ├── Stochastic.hs
│           │   ├── Neural.hs
│           │   ├── State.hs
│           │   └── Neural
│           │       └── LSTM.hs
│           ├── Loss.hs
│           ├── Run.hs
│           ├── Test.hs
│           └── Initialize.hs
├── old2
│   └── src
│       ├── Backprop
│       │   ├── Learn.hs
│       │   └── Learn
│       │       ├── Train.hs
│       │       ├── Run.hs
│       │       ├── Initialize.hs
│       │       ├── Model
│       │       │   ├── Parameter.hs
│       │       │   ├── Neural.hs
│       │       │   ├── Stochastic.hs
│       │       │   └── Class.hs
│       │       ├── Loss.hs
│       │       └── Test.hs
│       └── Data
│           └── Type
│               ├── NonEmpty.hs
│               └── Mayb.hs
├── TODO.txt
├── old
│   ├── src
│   │   ├── Learn
│   │   │   ├── Neural.hs
│   │   │   └── Neural
│   │   │       ├── Test.hs
│   │   │       ├── Layer
│   │   │       │   ├── Identity.hs
│   │   │       │   ├── Applying.hs
│   │   │       │   ├── FullyConnected.hs
│   │   │       │   └── Compose.hs
│   │   │       ├── Loss.hs
│   │   │       └── Network
│   │   │           └── Dropout.hs
│   │   ├── Data
│   │   │   └── Type
│   │   │       └── Util.hs
│   │   └── Numeric
│   │       ├── BLAS
│   │       │   ├── FVector.hs
│   │       │   └── NVector.hs
│   │       └── BLAS.hs
│   └── app
│       ├── MNIST.hs
│       ├── Letter2Vec.hs
│       └── Language.hs
├── LICENSE
├── stack.yaml
├── app
│   ├── mnist.hs
│   ├── word2vec.hs
│   └── char-rnn.hs
├── Build.hs
├── package.yaml
├── stack.yaml.lock
├── stack-ghcjs.yaml
└── .travis.yml
/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Setup.hs:
--------------------------------------------------------------------------------
1 | import Distribution.Simple
2 | main = defaultMain
3 |
--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
1 | main :: IO ()
2 | main = putStrLn "Test suite not yet implemented"
3 |
--------------------------------------------------------------------------------
/demo/mnist-ghcjs.js_o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mstksg/backprop-learn/HEAD/demo/mnist-ghcjs.js_o
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.stack-work
2 | /data
3 |
4 | /TAGS
5 | /tags
6 | /out
7 |
8 | /.build
9 | /.shake
10 | /demo-exe
11 | /demo-js
12 |
13 | /backprop-learn.cabal
14 |
15 | /testout
16 |
17 | *.dump-hi
18 |
--------------------------------------------------------------------------------
/demo/mnist-ghcjs.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE OverloadedStrings #-}
2 |
3 | import Reflex.Dom
4 |
5 | main :: IO ()
6 | main = mainWidget $ el "div" $ do
7 | t <- textInput def
8 | dynText $ _textInput_value t
9 |
--------------------------------------------------------------------------------
/src/Backprop/Learn.hs:
--------------------------------------------------------------------------------
1 |
2 | module Backprop.Learn (
3 | module L
4 | ) where
5 |
6 | import Backprop.Learn.Initialize as L
7 | import Backprop.Learn.Loss as L
8 | import Backprop.Learn.Model as L
9 | import Backprop.Learn.Run as L
10 | import Backprop.Learn.Test as L
11 | import Backprop.Learn.Train as L
12 |
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn.hs:
--------------------------------------------------------------------------------
1 |
2 | module Backprop.Learn (
3 | module L
4 | ) where
5 |
6 | import Backprop.Learn.Initialize as L
7 | import Backprop.Learn.Loss as L
8 | import Backprop.Learn.Model as L
9 | import Backprop.Learn.Run as L
10 | import Backprop.Learn.Test as L
11 | import Backprop.Learn.Train as L
12 |
--------------------------------------------------------------------------------
/TODO.txt:
--------------------------------------------------------------------------------
1 |
2 | * Reset if diverge
3 | * Feedback loops
 4 | * Get rid of initial state completely and just require initialization to zero
5 | for destate
6 | * trace state or internal activations once unrolled? (actually maybe easy)
 7 | * KL-divergence autoencoders
8 | * Variational autoencoders
9 | * Address regularization
10 | * GRU LSTMs
11 | * Binary
12 | * Dynamically grow or shrink networks
13 | * Reinforcement learning: loss function takes no "target"
14 | * Elman
15 |
16 |
17 | * Get rid of initparam and initstate
18 |
--------------------------------------------------------------------------------
/old/src/Learn/Neural.hs:
--------------------------------------------------------------------------------
1 |
2 | module Learn.Neural (
3 | module N
4 | ) where
5 |
6 | import Learn.Neural.Layer as N
7 | import Learn.Neural.Layer.Applying as N
8 | import Learn.Neural.Layer.FullyConnected as N
9 | import Learn.Neural.Layer.Identity as N
10 | import Learn.Neural.Layer.Mapping as N
11 | import Learn.Neural.Loss as N
12 | import Learn.Neural.Network as N
13 | import Learn.Neural.Network.Dropout as N
14 | import Learn.Neural.Test as N
15 | import Learn.Neural.Train as N
16 | import Numeric.BLAS as N
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Justin Le (c) 2017
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above
12 | copyright notice, this list of conditions and the following
13 | disclaimer in the documentation and/or other materials provided
14 | with the distribution.
15 |
16 | * Neither the name of Justin Le nor the names of other
17 | contributors may be used to endorse or promote products derived
18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/old/src/Learn/Neural/Test.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE LambdaCase #-}
3 | {-# LANGUAGE RankNTypes #-}
4 | {-# LANGUAGE ScopedTypeVariables #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 |
8 | module Learn.Neural.Test (
9 | TestFunc
10 | , maxTest
11 | , rmseTest
12 | , squaredErrorTest
13 | , crossEntropyTest
14 | , testNet
15 | , testNetList
16 | ) where
17 |
18 | import Data.Proxy
19 | import GHC.TypeLits
20 | import Learn.Neural.Layer
21 | import Learn.Neural.Network
22 | import Numeric.BLAS
23 | import qualified Control.Foldl as F
24 |
25 | type TestFunc o = forall b. (BLAS b, Num (b o)) => b o -> b o -> Double
26 |
27 | maxTest :: KnownNat n => TestFunc '[n + 1]
28 | maxTest x y | iamax x == iamax y = 1
29 | | otherwise = 0
30 |
31 | rmseTest :: forall n. KnownNat n => TestFunc '[n]
32 | rmseTest x y = sqrt $ realToFrac (e `dot` e) / fromIntegral (natVal (Proxy @n))
33 | where
34 | e = axpy (-1) x y
35 |
36 | squaredErrorTest :: KnownNat n => TestFunc '[n]
37 | squaredErrorTest x y = realToFrac $ e `dot` e
38 | where
39 | e = axpy (-1) x y
40 |
41 | crossEntropyTest :: KnownNat n => TestFunc '[n]
42 | crossEntropyTest r t = negate $ realToFrac (tmap log r `dot` t)
43 |
44 | testNet
45 | :: (BLAS b, Num (b i), Num (b o))
46 | => TestFunc o
47 | -> SomeNet 'FeedForward b i o
48 | -> b i
49 | -> b o
50 | -> Double
51 | testNet tf = \case
52 | SomeNet _ n -> \x t ->
53 | let y = runNetPure n x
54 | in tf y t
55 |
56 | testNetList
57 | :: (BLAS b, Num (b i), Num (b o))
58 | => TestFunc o
59 | -> SomeNet 'FeedForward b i o
60 | -> [(b i, b o)]
61 | -> Double
62 | testNetList tf n = F.fold F.mean . fmap (uncurry (testNet tf n))
63 |
--------------------------------------------------------------------------------
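For orientation, a minimal sketch (not a file in this repository) of how the pieces above combine: 'maxTest' scores a sample as 1 when the index of the maximum predicted component matches that of the target and 0 otherwise, and 'testNetList' averages the per-sample scores, so together they give classification accuracy as a fraction in [0, 1].

    {-# LANGUAGE DataKinds     #-}
    {-# LANGUAGE TypeOperators #-}

    import GHC.TypeLits
    import Learn.Neural

    -- Classification accuracy of a feed-forward network over a list of
    -- (input, one-hot target) pairs.
    accuracy
        :: (BLAS b, KnownNat n, Num (b i), Num (b '[n + 1]))
        => SomeNet 'FeedForward b i '[n + 1]
        -> [(b i, b '[n + 1])]
        -> Double
    accuracy = testNetList maxTest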
/old/src/Learn/Neural/Layer/Identity.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances #-}
2 | {-# LANGUAGE KindSignatures #-}
3 | {-# LANGUAGE MultiParamTypeClasses #-}
4 | {-# LANGUAGE TypeFamilies #-}
5 | {-# LANGUAGE TypeInType #-}
6 |
7 |
8 | module Learn.Neural.Layer.Identity (
9 | Ident
10 | ) where
11 |
12 | import Data.Kind
13 | import Learn.Neural.Layer
14 | import Numeric.BLAS
15 | import Numeric.Backprop
16 |
17 | data Ident :: Type
18 |
19 | instance Num (CParam Ident b i i) where
20 | _ + _ = IdP
21 | _ * _ = IdP
22 | _ - _ = IdP
23 | negate _ = IdP
24 | abs _ = IdP
25 | signum _ = IdP
26 | fromInteger _ = IdP
27 |
28 | instance Fractional (CParam Ident b i i) where
29 | _ / _ = IdP
30 | recip _ = IdP
31 | fromRational _ = IdP
32 |
33 | instance Floating (CParam Ident b i i) where
34 | sqrt _ = IdP
35 |
36 | instance Num (CState Ident b i i) where
37 | _ + _ = IdS
38 | _ * _ = IdS
39 | _ - _ = IdS
40 | negate _ = IdS
41 | abs _ = IdS
42 | signum _ = IdS
43 | fromInteger _ = IdS
44 |
45 | instance Fractional (CState Ident b i i) where
46 | _ / _ = IdS
47 | recip _ = IdS
48 | fromRational _ = IdS
49 |
50 | instance Floating (CState Ident b i i) where
51 | sqrt _ = IdS
52 |
53 | instance BLAS b => Component Ident b i i where
54 | data CParam Ident b i i = IdP
55 | data CState Ident b i i = IdS
56 | data CConf Ident b i i = IdC
57 |
58 | componentOp = componentOpDefault
59 |
60 | initParam _ _ _ _ = return IdP
61 | initState _ _ _ _ = return IdS
62 | defConf = IdC
63 |
64 | instance BLAS b => ComponentFF Ident b i i where
65 | componentOpFF = bpOp . withInps $ \(x :< _ :< Ø) ->
66 | return . only $ x
67 |
68 | instance BLAS b => ComponentLayer r Ident b i i where
69 | componentRunMode = RMIsFF
70 |
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
1 | # This file was automatically generated by 'stack init'
2 | #
3 | # Some commonly used options have been documented as comments in this file.
4 | # For advanced use and comprehensive documentation of the format, please see:
5 | # http://docs.haskellstack.org/en/stable/yaml_configuration/
6 |
7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version.
8 | # A snapshot resolver dictates the compiler version and the set of packages
9 | # to be used for project dependencies. For example:
10 | #
11 | # resolver: lts-3.5
12 | # resolver: nightly-2015-09-21
13 | # resolver: ghc-7.10.2
14 | # resolver: ghcjs-0.1.0_ghc-7.10.2
15 | # resolver:
16 | # name: custom-snapshot
17 | # location: "./custom-snapshot.yaml"
18 | resolver: nightly-2020-01-30
19 |
20 | # User packages to be built.
21 | # Various formats can be used as shown in the example below.
22 | #
23 | # packages:
24 | # - some-directory
25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz
26 | # - location:
27 | # git: https://github.com/commercialhaskell/stack.git
28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a
29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a
30 | # extra-dep: true
31 | # subdirs:
32 | # - auto-update
33 | # - wai
34 | #
35 | # A package marked 'extra-dep: true' will only be built if demanded by a
36 | # non-dependency (i.e. a user package), and its test suites and benchmarks
37 | # will not be run. This is useful for tweaking upstream packages.
38 | packages:
39 | - '.'
40 |
41 | # Dependency packages to be pulled from upstream that are not in the resolver
42 | # (e.g., acme-missiles-0.3)
43 | extra-deps:
44 | - github: mstksg/opto
45 | commit: def9e41adc1123385037f92ead38c92cb6454dd2
46 | - backprop-0.2.6.3
47 | - decidable-0.3.0.0
48 | - functor-products-0.1.1.0
49 | - hmatrix-vector-sized-0.1.2.0
50 | - hmatrix-backprop-0.1.3.0
51 | - list-witnesses-0.1.3.2
52 | - typelits-witnesses-0.4.0.0
53 | - vinyl-0.12.0
54 | - dependent-sum-0.6.2.0
55 | - constraints-extras-0.3.0.2
56 |
57 | # Override default flag values for local packages and extra-deps
58 | flags: {}
59 |
60 | # Extra package databases containing global packages
61 | extra-package-dbs: []
62 |
63 | ghc-options:
64 | "$locals": -ddump-to-file -ddump-hi
65 |
66 | build:
67 | haddock-arguments:
68 | haddock-args:
69 | - --optghc=-fdefer-type-errors
70 |
71 | # Control whether we use the GHC we find on the path
72 | # system-ghc: true
73 | #
74 | # Require a specific version of stack, using version ranges
75 | # require-stack-version: -any # Default
76 | # require-stack-version: ">=1.3"
77 | #
78 | # Override the architecture used by stack, especially useful on Windows
79 | # arch: i386
80 | # arch: x86_64
81 | #
82 | # Extra directories used by stack for building
83 | # extra-include-dirs: [/path/to/dir]
84 | # extra-lib-dirs: [/path/to/dir]
85 | #
86 | # Allow a newer minor version of GHC than the snapshot specifies
87 | # compiler-check: newer-minor
88 |
--------------------------------------------------------------------------------
/src/Backprop/Learn/Train.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE MultiParamTypeClasses #-}
4 | {-# LANGUAGE PartialTypeSignatures #-}
5 | {-# LANGUAGE RankNTypes #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeFamilies #-}
9 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-}
10 |
11 | module Backprop.Learn.Train (
12 | -- * Gradients
13 | gradModelLoss
14 | , gradModelStochLoss
15 | -- * Opto
16 | , Grad
17 | , modelGrad
18 | , modelGradStoch
19 | ) where
20 |
21 | import Backprop.Learn.Loss
22 | import Backprop.Learn.Model
23 | import Control.Monad.Primitive
24 | import Control.Monad.ST
25 | import Data.Word
26 | import Numeric.Backprop
27 | import Numeric.Opto.Core
28 | import qualified Data.Vector.Unboxed as VU
29 | import qualified System.Random.MWC as MWC
30 |
31 | -- | Gradient of model with respect to loss function and target
32 | gradModelLoss
33 | :: Backprop p
34 | => Loss b
35 | -> Regularizer p
36 | -> Model ('Just p) 'Nothing a b
37 | -> p
38 | -> a
39 | -> b
40 | -> p
41 | gradModelLoss loss reg f p x y = gradBP (\p' ->
42 | loss y (runLearnStateless f (PJust p') (constVar x)) + reg p'
43 | ) p
44 |
45 | -- | Stochastic gradient of model with respect to loss function and target
46 | gradModelStochLoss
47 | :: (Backprop p, PrimMonad m)
48 | => Loss b
49 | -> Regularizer p
50 | -> Model ('Just p) 'Nothing a b
51 | -> MWC.Gen (PrimState m)
52 | -> p
53 | -> a
54 | -> b
55 | -> m p
56 | gradModelStochLoss loss reg f g p x y = do
57 | seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2
58 | pure $ gradBP (\p' -> runST $ do
59 | g' <- MWC.initialize seed
60 | lo <- loss y <$> runLearnStochStateless f g' (PJust p') (constVar x)
61 | pure (lo + reg p')
62 | ) p
63 |
64 | -- | Using a model's deterministic prediction function (with a given loss
65 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and
66 | -- "Numeric.Opto.Run".
67 | modelGrad
68 | :: (Applicative m, Backprop p)
69 | => Loss b
70 | -> Regularizer p
71 | -> Model ('Just p) 'Nothing a b
72 | -> Grad m (a, b) p
73 | modelGrad loss reg f = pureGrad $ \(x,y) p -> gradModelLoss loss reg f p x y
74 |
75 | -- | Using a model's stochastic prediction function (with a given loss
76 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and
77 | -- "Numeric.Opto.Run".
78 | modelGradStoch
79 | :: (PrimMonad m, Backprop p)
80 | => Loss b
81 | -> Regularizer p
82 | -> Model ('Just p) 'Nothing a b
83 | -> MWC.Gen (PrimState m)
84 | -> Grad m (a, b) p
85 | modelGradStoch loss reg f g = \(x,y) p ->
86 | gradModelStochLoss loss reg f g p x y
87 |
--------------------------------------------------------------------------------
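For a sense of how 'gradModelLoss' can be used on its own, outside the 'Grad'/"Numeric.Opto" pipeline, here is a minimal sketch (not part of the package; it additionally assumes a 'Fractional' instance on the parameter for the scaling step, which holds for e.g. 'R n'):

    {-# LANGUAGE DataKinds #-}

    import Backprop.Learn
    import GHC.TypeNats (KnownNat)
    import Numeric.Backprop (Backprop)
    import Numeric.LinearAlgebra.Static.Backprop (R)

    -- One step of plain gradient descent on a single (input, target)
    -- sample, with a fixed learning rate of 0.01 and squared-error loss.
    sgdStep
        :: (Backprop p, Fractional p, KnownNat n)
        => Model ('Just p) 'Nothing a (R n)
        -> p
        -> (a, R n)
        -> p
    sgdStep f p (x, y) = p - 0.01 * gradModelLoss squaredErrorV noReg f p x y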
/app/mnist.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE MultiParamTypeClasses #-}
5 | {-# LANGUAGE PartialTypeSignatures #-}
6 | {-# LANGUAGE TupleSections #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeOperators #-}
9 | {-# LANGUAGE TypeSynonymInstances #-}
10 | {-# LANGUAGE UndecidableInstances #-}
11 | {-# OPTIONS_GHC -fno-warn-orphans #-}
12 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
13 |
14 | import Backprop.Learn
15 | import Control.Monad.Trans.Maybe
16 | import Data.Bitraversable
17 | import Data.Default
18 | import Data.IDX
19 | import Data.Traversable
20 | import Data.Tuple
21 | import Numeric.LinearAlgebra.Static.Backprop
22 | import Numeric.Opto
23 | import Numeric.Opto.Run.Simple
24 | import System.Environment
25 | import System.FilePath
26 | import Text.Printf
27 | import qualified Data.Vector.Generic as VG
28 | import qualified Numeric.LinearAlgebra as HM
29 | import qualified Numeric.LinearAlgebra.Static as H
30 | import qualified System.Random.MWC as MWC
31 |
32 | mnistNet :: Model _ _ (R 784) (R 10)
33 | mnistNet = fca @10  softMax
34 |         <~ dropout 0.25
35 |         <~ fca @300 logistic
36 |
37 | main :: IO ()
38 | main = MWC.withSystemRandom $ \g -> do
39 | datadir:_ <- getArgs
40 |     Just train <- loadMNIST (datadir </> "train-images-idx3-ubyte")
41 |                             (datadir </> "train-labels-idx1-ubyte")
42 |     Just test  <- loadMNIST (datadir </> "t10k-images-idx3-ubyte")
43 |                             (datadir </> "t10k-labels-idx1-ubyte")
44 | putStrLn "Loaded data."
45 | net0 <- initParamNormal mnistNet 0.2 g
46 |
47 | let so = def { soTestSet = Just test
48 | , soEvaluate = runTest
49 | , soSkipSamps = 2500
50 | }
51 | runTest chnk net = printf "Error: %.2f%%" ((1 - score) * 100)
52 | where
53 | score = testModelAll maxIxTest mnistNet (TJust net) chnk
54 |
55 | simpleRunner so train SOSingle def net0
56 | (adam def $ modelGradStoch crossEntropy noReg mnistNet g)
57 | g
58 |
59 | loadMNIST
60 | :: FilePath
61 | -> FilePath
62 | -> IO (Maybe [(R 784, R 10)])
63 | loadMNIST fpI fpL = runMaybeT $ do
64 | i <- MaybeT $ decodeIDXFile fpI
65 | l <- MaybeT $ decodeIDXLabelsFile fpL
66 | d <- MaybeT . return $ labeledIntData l i
67 | MaybeT . return $ for d (bitraverse mkImage mkLabel . swap)
68 | where
69 | mkImage = H.create . VG.convert . VG.map (\i -> fromIntegral i / 255)
70 | mkLabel n = H.create $ HM.build 10 (\i -> if round i == n then 1 else 0)
71 |
72 |
--------------------------------------------------------------------------------
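A note on the data preparation above: 'mkImage' rescales each byte-valued pixel from [0, 255] into [0, 1] by dividing by 255, and 'mkLabel' builds a one-hot 'R 10' vector with a 1 at the digit's index, which is the encoding consumed by 'crossEntropy' and 'maxIxTest' in 'main'.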
/old/src/Learn/Neural/Loss.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE GADTs #-}
3 | {-# LANGUAGE LambdaCase #-}
4 | {-# LANGUAGE PolyKinds #-}
5 | {-# LANGUAGE RankNTypes #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TypeApplications #-}
8 |
9 | module Learn.Neural.Loss (
10 | LossFunction
11 | , crossEntropy
12 | , squaredError
13 | , sumLoss
14 | , sumLossDecay
15 | , zipLoss
16 | ) where
17 |
18 | import Data.Type.Util
19 | import Data.Type.Vector
20 | import GHC.TypeLits
21 | import Numeric.BLAS
22 | import Numeric.Backprop
23 | import Numeric.Backprop.Op
24 | import Type.Class.Witness
25 | import qualified Data.Type.Nat as TCN
26 |
27 | -- type LossFunction s = forall b q. (BLAS b, Num (b s)) => b s -> OpB q '[ b s ] '[ Scalar b ]
28 | type LossFunction as b = forall s. Tuple as -> OpB s as '[ b ]
29 |
30 | crossEntropy
31 | :: forall b n. (BLAS b, KnownNat n, Num (b '[n]))
32 | => LossFunction '[ b '[n] ] (Scalar b)
33 | crossEntropy (I targ :< Ø) = bpOp . withInps $ \(r :< Ø) -> do
34 | logR <- tmapOp log ~$ (r :< Ø)
35 | res <- negate <$> (dotOp ~$ (logR :< t :< Ø))
36 | only <$> bindVar res
37 | where
38 | t = constVar targ
39 |
40 | squaredError
41 | :: (BLAS b, KnownNat n, Num (b '[n]))
42 | => LossFunction '[ b '[n] ] (Scalar b)
43 | squaredError (I targ :< Ø) = bpOp . withInps $ \(r :< Ø) -> do
44 | err <- bindVar $ r - t
45 | only <$> (dotOp ~$ (err :< err :< Ø))
46 | where
47 | t = constVar targ
48 |
49 | sumLoss
50 | :: forall n a b. (Num a, Num b)
51 | => LossFunction '[ a ] b
52 | -> TCN.Nat n
53 | -> LossFunction (Replicate n a) b
54 | sumLoss l = \case
55 | TCN.Z_ -> \case Ø -> op0 (only_ 0)
56 | TCN.S_ n -> \case
57 | I x :< xs -> (replLen @_ @a n //) $
58 | (replWit @_ @Num @a n Wit //) $
59 | bpOp . withInps $ \(y :< ys) -> do
60 | z <- l (only_ x) ~$ (y :< Ø)
61 | zs <- sumLoss l n xs ~$ ys
62 | return . only $ z + zs
63 |
64 | sumLossDecay
65 | :: forall n a b. (Num a, Num b)
66 | => LossFunction '[ a ] b
67 | -> TCN.Nat n
68 | -> b
69 | -> LossFunction (Replicate n a) b
70 | sumLossDecay l n λ = zipLoss l (genDecay 1 n)
71 | where
72 | genDecay :: b -> TCN.Nat m -> Vec m b
73 | genDecay b = \case
74 | TCN.Z_ -> ØV
75 | TCN.S_ m -> case genDecay b m of
76 | ØV -> b :+ ØV
77 | I c :* cs -> (c * λ) :+ c :+ cs
78 |
79 |
80 | zipLoss
81 | :: forall n a b. (Num a, Num b)
82 | => LossFunction '[ a ] b
83 | -> Vec n b
84 | -> LossFunction (Replicate n a) b
85 | zipLoss l = \case
86 | ØV -> \case Ø -> op0 (only 0)
87 | I α :* αs ->
88 | let αn = vecLenNat αs
89 | in \case
90 | I x :< xs -> (replLen @_ @a αn //) $
91 | (replWit @_ @Num @a αn Wit //) $
92 | bpOp . withInps $ \(y :< ys) -> do
93 | z <- l (only_ x) ~$ (y :< Ø)
94 | zs <- zipLoss l αs xs ~$ ys
95 | return . only $ (z * constVar α) + zs
96 |
--------------------------------------------------------------------------------
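A quick worked note on 'sumLossDecay' above: 'genDecay' builds geometrically decaying weights ending at 1, so unrolling three steps produces the weights λ², λ, 1. 'zipLoss' then scales the loss at each position by its weight before summing, so the first position is scaled by λ² and the last by 1.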
/Build.hs:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env stack
2 | -- stack --install-ghc runghc --package shake
3 |
4 | import Development.Shake
5 | import Development.Shake.FilePath
6 | import System.Directory
7 |
8 | opts = shakeOptions { shakeFiles = ".shake"
9 | , shakeVersion = "1.0"
10 | , shakeVerbosity = Normal
11 | , shakeThreads = 0
12 | }
13 |
14 | main :: IO ()
15 | main = getDirectoryFilesIO "demo" ["/*.lhs", "/*.hs"] >>= \allDemos ->
16 | getDirectoryFilesIO "src" ["//*.hs"] >>= \allSrc ->
17 | shakeArgs opts $ do
18 |
19 | want ["all"]
20 |
21 | "all" ~>
22 | -- need ["haddocks", "gentags", "install", "exe", "js"]
23 | need ["haddocks", "gentags", "install", "js"]
24 |
25 | "exe" ~>
26 |       need (map (\f -> "demo-exe" </> dropExtension f) allDemos)
27 |
28 | "js" ~>
29 |       need (map (\f -> "demo-js" </> dropExtension f) allDemos)
30 |
31 | "haddocks" ~> do
32 |       need (("src" </>) <$> allSrc)
33 | cmd "jle-git-haddocks"
34 |
35 | "install" ~> do
36 |       need $ ("src" </>) <$> allSrc
37 | cmd "stack install"
38 |
39 | "install-ghcjs" ~> do
40 |       need $ ("src" </>) <$> allSrc
41 | cmd "stack install --stack-yaml stack-ghcjs.yaml"
42 |
43 | "gentags" ~>
44 | need ["tags", "TAGS"]
45 |
46 | "demo-exe/*" %> \f -> do
47 | need ["install"]
48 | [src] <- getDirectoryFiles "demo" $ (takeFileName f <.>) <$> ["hs","lhs"]
49 | liftIO $ do
50 | createDirectoryIfMissing True "demo-exe"
51 | createDirectoryIfMissing True ".build"
52 | removeFilesAfter "demo" ["/*.o"]
53 | cmd "stack exec" "--package backprop-learn"
54 | "--"
55 | "ghc"
56 |           ("demo" </> src)
57 | "-o" f
58 | "-hidir" ".build"
59 | "-threaded"
60 | "-rtsopts"
61 | "-with-rtsopts=-N"
62 | "-Wall"
63 | "-O2"
64 |
65 | "demo-js/*" %> \f -> do
66 | need ["install-ghcjs"]
67 | [src] <- getDirectoryFiles "demo" $ (takeFileName f <.>) <$> ["hs","lhs"]
68 | liftIO $ do
69 | createDirectoryIfMissing True "demo-js"
70 | createDirectoryIfMissing True ".build"
71 | removeFilesAfter "demo" ["/*.o"]
72 | cmd "stack exec" "--package backprop-learn"
73 | "--package reflex-dom"
74 | "--stack-yaml stack-ghcjs.yaml"
75 | "--"
76 | "ghcjs"
77 |           ("demo" </> src)
78 | "-o" f
79 | "-hidir" ".build"
80 | "-threaded"
81 | "-rtsopts"
82 | "-with-rtsopts=-N"
83 | "-Wall"
84 | "-O2"
85 |
86 | ["tags","TAGS"] &%> \_ -> do
87 |       need (("src" </>) <$> allSrc)
88 | cmd "hasktags" "src/"
89 |
90 | "clean" ~> do
91 | unit $ cmd "stack clean"
92 | removeFilesAfter ".shake" ["//*"]
93 | removeFilesAfter ".build" ["//*"]
94 | removeFilesAfter "demo-exe" ["//*"]
95 |
96 |
--------------------------------------------------------------------------------
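Given the '#!/usr/bin/env stack' shebang and the 'want ["all"]' default above, running ./Build.hs with no arguments builds the "all" target (haddocks, tags, install, and the GHCJS demos), while ./Build.hs clean invokes the "clean" rule to clear out build artifacts.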
/package.yaml:
--------------------------------------------------------------------------------
1 | name: backprop-learn
2 | version: 0.1.0.0
3 | github: mstksg/backprop-learn
4 | license: BSD3
5 | author: Justin Le
6 | maintainer: justin@jle.im
7 | copyright: (c) Justin Le 2018
8 | tested-with: GHC >= 8.2
9 |
10 | extra-source-files:
11 | - README.md
12 |
13 | # Metadata used when publishing your package
14 | synopsis: Combinators and useful tools for ANNs using the backprop library
15 | category: Math
16 |
17 | # To avoid duplicated efforts in documentation and dealing with the
18 | # complications of embedding Haddock markup inside cabal files, it is
19 | # common to point users to the README.md file.
20 | description: See README.md
21 |
22 | ghc-options:
23 | - -Wall
24 | - -Wcompat
25 | - -Wincomplete-record-updates
26 | - -Wredundant-constraints
27 | # - -O0
28 |
29 | dependencies:
30 | - base >=4.7 && <5
31 | - ghc-typelits-extra
32 | - ghc-typelits-knownnat
33 | - ghc-typelits-natnormalise
34 | - hmatrix
35 | - hmatrix-backprop >= 0.1.3
36 | - mwc-random
37 | - opto
38 |
39 | library:
40 | source-dirs: src
41 | dependencies:
42 | - backprop >= 0.2.6.3
43 | - binary
44 | - bytestring
45 | - conduit
46 | - containers
47 | - deepseq
48 | - finite-typelits
49 | - foldl >= 1.4
50 | - functor-products
51 | - hmatrix-vector-sized >= 0.1.2
52 | - list-witnesses >= 0.1.2
53 | - microlens
54 | - microlens-th
55 | - one-liner
56 | - one-liner-instances
57 | - primitive
58 | - profunctors
59 | - singletons
60 | - statistics
61 | - transformers
62 | - typelits-witnesses
63 | - vector >= 0.12.0.2
64 | - vector-sized
65 | - vinyl
66 |
67 | _exec: &exec
68 | source-dirs: app
69 | ghc-options:
70 | - -threaded
71 | - -rtsopts
72 | - -with-rtsopts=-N
73 | - -O2
74 |
75 | executables:
76 | backprop-learn-mnist:
77 | <<: *exec
78 | main: mnist.hs
79 | dependencies:
80 | - backprop-learn
81 | - data-default
82 | - filepath
83 | - mnist-idx
84 | - transformers
85 | - vector >= 0.12.0.2
86 | backprop-learn-series:
87 | <<: *exec
88 | main: series.hs
89 | dependencies:
90 | - backprop >= 0.2.6.3
91 | - backprop-learn
92 | - conduit
93 | - data-default
94 | - deepseq
95 | - hmatrix-backprop >= 0.1.3
96 | - optparse-applicative
97 | - primitive
98 |     - singletons
100 | - statistics
101 | - time
102 | - transformers
103 | - typelits-witnesses
104 | - vector-sized
105 | backprop-learn-char-rnn:
106 | <<: *exec
107 | main: char-rnn.hs
108 | dependencies:
109 | - backprop-learn
110 | - conduit
111 | - containers
112 | - data-default
113 | - deepseq
114 | - hmatrix-vector-sized
115 | - text
116 | - time
117 | - transformers
118 | - vector-sized
119 | backprop-learn-word2vec:
120 | <<: *exec
121 | main: word2vec.hs
122 | dependencies:
123 | - backprop-learn
124 | - conduit
125 | - containers
126 | - data-default
127 | - deepseq
128 | - hmatrix
129 | - text
130 | - time
131 | - transformers
132 | - vector >= 0.12.0.2
133 | - vector-sized
134 |
135 |
--------------------------------------------------------------------------------
/old/src/Learn/Neural/Layer/Applying.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleContexts #-}
2 | {-# LANGUAGE FlexibleInstances #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE MultiParamTypeClasses #-}
5 | {-# LANGUAGE RankNTypes #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeFamilies #-}
9 | {-# LANGUAGE TypeInType #-}
10 |
11 | module Learn.Neural.Layer.Applying (
12 | Applying
13 | , TensorOp(..)
14 | , CommonOp(..)
15 | , SoftMax
16 | ) where
17 |
18 | import Data.Kind
19 | import Data.Reflection
20 | import Data.Singletons
21 | import GHC.TypeLits
22 | import Learn.Neural.Layer
23 | import Numeric.BLAS
24 | import Numeric.Backprop
25 |
26 | data Applying :: k -> Type
27 |
28 | newtype TensorOp :: [Nat] -> [Nat] -> Type where
29 | TF :: { getTensorOp
30 | :: forall b s. (BLAS b, Num (b i), Num (b o)) => OpB s '[ b i ] '[ b o ]
31 | }
32 | -> TensorOp i o
33 |
34 | instance Num (CParam (Applying s) b i o) where
35 | _ + _ = AppP
36 | _ * _ = AppP
37 | _ - _ = AppP
38 | negate _ = AppP
39 | abs _ = AppP
40 | signum _ = AppP
41 | fromInteger _ = AppP
42 |
43 | instance Fractional (CParam (Applying s) b i o) where
44 | _ / _ = AppP
45 | recip _ = AppP
46 | fromRational _ = AppP
47 |
48 | instance Floating (CParam (Applying s) b i o) where
49 | sqrt _ = AppP
50 |
51 | instance Num (CState (Applying s) b i o) where
52 | _ + _ = AppS
53 | _ * _ = AppS
54 | _ - _ = AppS
55 | negate _ = AppS
56 | abs _ = AppS
57 | signum _ = AppS
58 | fromInteger _ = AppS
59 |
60 | instance Fractional (CState (Applying s) b i o) where
61 | _ / _ = AppS
62 | recip _ = AppS
63 | fromRational _ = AppS
64 |
65 | instance Floating (CState (Applying s) b i o) where
66 | sqrt _ = AppS
67 |
68 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => Component (Applying s) b i o where
69 | data CParam (Applying s) b i o = AppP
70 | data CState (Applying s) b i o = AppS
71 | data CConf (Applying s) b i o = AppC
72 |
73 | componentOp = componentOpDefault
74 |
75 | initParam _ _ _ _ = return AppP
76 | initState _ _ _ _ = return AppS
77 | defConf = AppC
78 |
79 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => ComponentFF (Applying s) b i o where
80 | componentOpFF = bpOp . withInps $ \(x :< _ :< Ø) -> do
81 | y <- getTensorOp to ~$ (x :< Ø)
82 | return . only $ y
83 | where
84 | to :: TensorOp i o
85 | to = reflect (Proxy @s)
86 |
87 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => ComponentLayer r (Applying s) b i o where
88 | componentRunMode = RMIsFF
89 |
90 | data CommonOp :: Type where
91 | TO_Softmax :: [Nat] -> CommonOp
92 |
93 | instance SingI i => Reifies ('TO_Softmax i) (TensorOp i i) where
94 | reflect _ = TF $ bpOp . withInps $ \(x :< Ø) -> do
95 | expX <- tmapOp exp ~$ (x :< Ø)
96 | totX <- tsumOp ~$ (expX :< Ø)
97 | sm <- scaleOp ~$ (1/totX :< expX :< Ø)
98 | return $ only sm
99 |
100 | type SoftMax i = Applying ('TO_Softmax i)
101 |
--------------------------------------------------------------------------------
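A note on the 'TO_Softmax' instance above: it computes the standard softmax, softmax(x)_i = exp(x_i) / Σ_j exp(x_j), by exponentiating componentwise ('tmapOp exp'), summing the result ('tsumOp'), and scaling by the reciprocal of the sum ('scaleOp').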
/old/src/Data/Type/Util.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE KindSignatures #-}
5 | {-# LANGUAGE LambdaCase #-}
6 | {-# LANGUAGE PolyKinds #-}
7 | {-# LANGUAGE RankNTypes #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeFamilyDependencies #-}
11 | {-# LANGUAGE TypeOperators #-}
12 | {-# OPTIONS_GHC -fno-warn-orphans #-}
13 |
14 | module Data.Type.Util (
15 | MaybeToList
16 | , replWit
17 | , replLen
18 | , prodToVec'
19 | , vecToProd
20 | , vtraverse
21 | , vecLenNat
22 | , zipP
23 | , last'
24 | , takeVec
25 | , unzipVec
26 | ) where
27 |
28 | import Data.Finite
29 | import Data.Type.Combinator
30 | import Data.Type.Conjunction
31 | import Data.Type.Index
32 | import Data.Type.Length
33 | import Data.Type.Nat
34 | import Data.Type.Product hiding (last')
35 | import Data.Type.Vector
36 | import Numeric.Backprop.Op (Replicate)
37 | import Type.Class.Higher
38 | import Type.Class.Witness
39 | import Type.Family.Nat
40 |
41 | type family MaybeToList (a :: Maybe k) = (b :: [k]) | b -> a where
42 | MaybeToList ('Just a ) = '[a]
43 | MaybeToList 'Nothing = '[]
44 |
45 | replWit
46 | :: forall n c a. ()
47 | => Nat n
48 | -> Wit (c a)
49 | -> Wit (Every c (Replicate n a))
50 | replWit = \case
51 | Z_ -> (Wit \\)
52 | S_ n -> \case
53 | w@Wit -> Wit \\ replWit n w
54 |
55 | replLen
56 | :: forall n a. ()
57 | => Nat n
58 | -> Length (Replicate n a)
59 | replLen = \case
60 | Z_ -> LZ
61 | S_ n -> LS (replLen @_ @a n)
62 |
63 | prodToVec'
64 | :: Nat n
65 | -> Prod f (Replicate n a)
66 | -> VecT n f a
67 | prodToVec' = \case
68 | Z_ -> \case
69 | Ø -> ØV
70 | S_ n -> \case
71 | x :< xs ->
72 | x :* prodToVec' n xs
73 |
74 | vecToProd
75 | :: VecT n f a
76 | -> Prod f (Replicate n a)
77 | vecToProd = \case
78 | ØV -> Ø
79 | x :* xs -> x :< vecToProd xs
80 |
81 | vtraverse
82 | :: Applicative h
83 | => (f a -> h (g b))
84 | -> VecT n f a
85 | -> h (VecT n g b)
86 | vtraverse f = \case
87 | ØV -> pure ØV
88 | x :* xs -> (:*) <$> f x <*> vtraverse f xs
89 |
90 | zipP
91 | :: Prod f as
92 | -> Prod g as
93 | -> Prod (f :&: g) as
94 | zipP = \case
95 | Ø -> \case
96 | Ø -> Ø
97 | x :< xs -> \case
98 |       y :< ys -> (x :&: y) :< zipP xs ys
99 |
100 | vecLenNat
101 | :: VecT n f a
102 | -> Nat n
103 | vecLenNat = \case
104 | ØV -> Z_
105 | _ :* xs -> S_ (vecLenNat xs)
106 |
107 | instance Eq1 Finite
108 |
109 | last'
110 | :: VecT ('S n) f a
111 | -> f a
112 | last' = \case
113 | x :* ØV -> x
114 | _ :* xs@(_ :* _) -> last' xs
115 |
116 | takeVec
117 | :: Nat n
118 | -> [a]
119 | -> Maybe (Vec n a)
120 | takeVec = \case
121 | Z_ -> \_ -> Just ØV
122 | S_ n -> \case
123 | [] -> Nothing
124 | x:xs -> (x :+) <$> takeVec n xs
125 |
126 | unzipVec
127 | :: Vec n (a, b)
128 | -> (Vec n a, Vec n b)
129 | unzipVec = \case
130 | ØV -> (ØV, ØV)
131 | I (x,y) :* xsys ->
132 | let (xs, ys) = unzipVec xsys
133 | in (I x :* xs, I y :* ys)
134 |
--------------------------------------------------------------------------------
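A quick example of 'takeVec' above, evaluated by hand from the definition:

    takeVec (S_ (S_ Z_)) [1, 2, 3]   -- Just (1 :+ 2 :+ ØV)
    takeVec (S_ (S_ Z_)) [1]         -- Nothing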
/old2/src/Backprop/Learn/Train.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE MultiParamTypeClasses #-}
4 | {-# LANGUAGE PartialTypeSignatures #-}
5 | {-# LANGUAGE RankNTypes #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeFamilies #-}
9 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
10 |
11 | module Backprop.Learn.Train (
12 | -- * Gradients
13 | gradLearnLoss
14 | , gradLearnStochLoss
15 | -- * Opto
16 | , Grad
17 | , learnGrad
18 | , learnGradStoch
19 | ) where
20 |
21 | import Backprop.Learn.Loss
22 | import Backprop.Learn.Model
23 | import Control.Monad.Primitive
24 | import Control.Monad.ST
25 | import Control.Monad.Sample
26 | import Data.Word
27 | import Numeric.Backprop
28 | import Numeric.Opto.Core
29 | import qualified Data.Vector.Unboxed as VU
30 | import qualified System.Random.MWC as MWC
31 |
32 | -- | Gradient of model with respect to loss function and target
33 | gradLearnLoss
34 | :: ( Learn a b l
35 | , NoState l
36 | , LParamMaybe l ~ 'Just (LParam l)
37 | , Backprop (LParam l)
38 | )
39 | => Loss b
40 | -> Regularizer (LParam l)
41 | -> l
42 | -> LParam l
43 | -> a
44 | -> b
45 | -> LParam l
46 | gradLearnLoss loss reg l p x y = gradBP (\p' ->
47 | loss y (runLearnStateless l (J_ p') (constVar x)) + reg p'
48 | ) p
49 |
50 | -- | Stochastic gradient of model with respect to loss function and target
51 | gradLearnStochLoss
52 | :: ( Learn a b l
53 | , NoState l
54 | , LParamMaybe l ~ 'Just (LParam l)
55 | , Backprop (LParam l)
56 | , PrimMonad m
57 | )
58 | => Loss b
59 | -> Regularizer (LParam l)
60 | -> l
61 | -> MWC.Gen (PrimState m)
62 | -> LParam l
63 | -> a
64 | -> b
65 | -> m (LParam l)
66 | gradLearnStochLoss loss reg l g p x y = do
67 | seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2
68 | pure $ gradBP (\p' -> runST $ do
69 | g' <- MWC.initialize seed
70 | lo <- loss y <$> runLearnStochStateless l g' (J_ p') (constVar x)
71 | pure (lo + reg p')
72 | ) p
73 |
74 | -- | Using a model's deterministic prediction function (with a given loss
75 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and
76 | -- "Numeric.Opto.Run".
77 | learnGrad
78 | :: ( MonadSample (a, b) m
79 | , Learn a b l
80 | , NoState l
81 | , LParamMaybe l ~ 'Just (LParam l)
82 | , Backprop (LParam l)
83 | )
84 | => Loss b
85 | -> Regularizer (LParam l)
86 | -> l
87 | -> Grad m (LParam l)
88 | learnGrad loss reg l = pureSampling $ \(x,y) p -> gradLearnLoss loss reg l p x y
89 |
90 | -- | Using a model's stochastic prediction function (with a given loss
91 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and
92 | -- "Numeric.Opto.Run".
93 | learnGradStoch
94 | :: ( MonadSample (a, b) m
95 | , PrimMonad m
96 | , Learn a b l
97 | , NoState l
98 | , LParamMaybe l ~ 'Just (LParam l)
99 | , Backprop (LParam l)
100 | )
101 | => Loss b
102 | -> Regularizer (LParam l)
103 | -> l
104 | -> MWC.Gen (PrimState m)
105 | -> Grad m (LParam l)
106 | learnGradStoch loss reg l g = sampling $ \(x,y) p ->
107 | gradLearnStochLoss loss reg l g p x y
108 |
--------------------------------------------------------------------------------
/src/Backprop/Learn/Model/Parameter.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DeriveDataTypeable #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PatternSynonyms #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE RecordWildCards #-}
10 | {-# LANGUAGE ScopedTypeVariables #-}
11 | {-# LANGUAGE TypeFamilies #-}
12 | {-# LANGUAGE TypeInType #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 |
15 | module Backprop.Learn.Model.Parameter (
16 | deParam, deParamD
17 | , reParam, reParamD
18 | , dummyParam
19 | ) where
20 |
21 | import Backprop.Learn.Model.Types
22 | import Control.Monad.Primitive
23 | import Numeric.Backprop
24 | import qualified System.Random.MWC as MWC
25 |
26 | -- | Fix a part of a parameter as constant, preventing backpropagation
27 | -- through it and not training it.
28 | --
29 | -- Treats a @pq@ parameter as essentially a @(p, q)@, witnessed through the
30 | -- split and join functions.
31 | --
32 | -- Takes the fixed value of @q@, as well as a stochastic-mode version with
33 | -- a fixed distribution.
34 | deParam
35 | :: forall p q pq s a b. (Backprop p, Backprop q)
36 | => (pq -> (p, q)) -- ^ split
37 | -> (p -> q -> pq) -- ^ join
38 | -> q -- ^ fixed param
39 | -> (forall m. (PrimMonad m) => MWC.Gen (PrimState m) -> m q) -- ^ fixed stoch param
40 | -> Model ('Just pq) s a b
41 | -> Model ('Just p ) s a b
42 | deParam spl joi q qStoch = reParam (PJust . r . fromPJust)
43 | (\g -> fmap PJust . rStoch g . fromPJust)
44 | where
45 | r :: Reifies z W => BVar z p -> BVar z pq
46 | r p = isoVar2 joi spl p (auto q)
47 | rStoch
48 | :: (PrimMonad m, Reifies z W)
49 | => MWC.Gen (PrimState m)
50 | -> BVar z p
51 | -> m (BVar z pq)
52 | rStoch g p = isoVar2 joi spl p . auto <$> qStoch g
53 |
54 | -- | 'deParam', but with no special stochastic mode version.
55 | deParamD
56 | :: (Backprop p, Backprop q)
57 | => (pq -> (p, q)) -- ^ split
58 | -> (p -> q -> pq) -- ^ join
59 | -> q -- ^ fixed param
60 | -> Model ('Just pq) s a b
61 | -> Model ('Just p ) s a b
62 | deParamD spl joi q = deParam spl joi q (const (pure q))
63 |
64 | -- | Pre-applies a function to a parameter before a model sees it.
65 | -- Essentially something like 'lmap' for parameters.
66 | --
67 | -- Takes a deterministic function and also a stochastic function for
68 | -- stochastic mode.
69 | reParam
70 | :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
71 | -> (forall m z. (PrimMonad m, Reifies z W) => MWC.Gen (PrimState m) -> PMaybe (BVar z) q -> m (PMaybe (BVar z) p))
72 | -> Model p s a b
73 | -> Model q s a b
74 | reParam r rStoch f = Model
75 | { runLearn = runLearn f . r
76 | , runLearnStoch = \g p x s -> do
77 | q <- rStoch g p
78 | runLearnStoch f g q x s
79 | }
80 |
81 | -- | 'reParam', but with no special stochastic mode function.
82 | reParamD
83 | :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
84 | -> Model p s a b
85 | -> Model q s a b
86 | reParamD r = reParam r (\_ -> pure . r)
87 |
88 | -- | Give an unparameterized model a "dummy" parameter. Useful for usage
89 | -- with combinators like 'Control.Category..' that require all input
90 | -- models to share a common parameterization.
91 | dummyParam
92 | :: Model 'Nothing s a b
93 | -> Model p s a b
94 | dummyParam = reParamD (const PNothing)
95 |
--------------------------------------------------------------------------------
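As a usage illustration of 'deParamD' (a sketch, not part of the module): when the full parameter is literally a pair, the split and join witnesses are just 'id' and '(,)', and the second component can be frozen at a fixed value so that only the first is trained.

    {-# LANGUAGE DataKinds #-}

    import Backprop.Learn.Model.Parameter
    import Backprop.Learn.Model.Types
    import Numeric.Backprop (Backprop)

    -- Freeze the second component of a pair-shaped parameter at a fixed
    -- value, leaving a model trained only over the first component.
    freezeSnd
        :: (Backprop p, Backprop q)
        => q
        -> Model ('Just (p, q)) s a b
        -> Model ('Just p) s a b
    freezeSnd = deParamD id (,)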
/old/app/MNIST.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE ScopedTypeVariables #-}
3 | {-# LANGUAGE TypeOperators #-}
4 |
5 | import Control.DeepSeq
6 | import Control.Exception
7 | import Control.Monad
8 | import Control.Monad.IO.Class
9 | import Control.Monad.Trans.Maybe
10 | import Control.Monad.Trans.State
11 | import Data.Bifunctor
12 | import Data.Bitraversable
13 | import Data.Default
14 | import Data.Finite
15 | import Data.IDX
16 | import Data.List.Split
17 | import Data.Time.Clock
18 | import Data.Traversable
19 | import Data.Tuple
20 | import Data.Type.Product
21 | import Learn.Neural
22 | import Numeric.BLAS.HMatrix
23 | import Numeric.LinearAlgebra.Static
24 | import Text.Printf
25 | import qualified Data.Vector as V
26 | import qualified Data.Vector.Generic as VG
27 | import qualified Data.Vector.Unboxed as VU
28 | import qualified System.Random.MWC as MWC
29 | import qualified System.Random.MWC.Distributions as MWC
30 |
31 | loadMNIST
32 | :: FilePath
33 | -> FilePath
34 | -> IO (Maybe [(HM '[784], HM '[10])])
35 | loadMNIST fpI fpL = runMaybeT $ do
36 | i <- MaybeT $ decodeIDXFile fpI
37 | l <- MaybeT $ decodeIDXLabelsFile fpL
38 | d <- MaybeT . return $ labeledIntData l i
39 | r <- MaybeT . return $ for d (bitraverse mkImage mkLabel . swap)
40 | liftIO . evaluate $ force r
41 | where
42 | mkImage :: VU.Vector Int -> Maybe (HM '[784])
43 | mkImage = fmap HM . create . VG.convert . VG.map (\i -> fromIntegral i / 255)
44 | mkLabel :: Int -> Maybe (HM '[10])
45 | mkLabel = fmap (oneHot . only) . packFinite . fromIntegral
46 |
47 | main :: IO ()
48 | main = MWC.withSystemRandom $ \g -> do
49 | Just train <- loadMNIST "data/train-images-idx3-ubyte" "data/train-labels-idx1-ubyte"
50 | Just test <- loadMNIST "data/t10k-images-idx3-ubyte" "data/t10k-labels-idx1-ubyte"
51 | putStrLn "Loaded data."
52 | net0 :: Network 'FeedForward HM ( '[784] :~ FullyConnected )
53 | '[ '[300] :~ LogitMap
54 | , '[300] :~ FullyConnected
55 | , '[100] :~ LogitMap
56 | , '[100] :~ FullyConnected
57 | , '[10 ] :~ SoftMax '[10]
58 | ]
59 | '[10] <- initDefNet g
60 | let dout = alongNet net0 $ Nothing
61 | :&% Just 0.2
62 | :&% Nothing
63 | :&% Just 0.2
64 | :&% Nothing
65 | :&% DOExt
66 | flip evalStateT net0 . forM_ [1..] $ \e -> do
67 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList train) g
68 | liftIO $ printf "[Epoch %d]\n" (e :: Int)
69 |
70 | forM_ ([1..] `zip` chunksOf batch train') $ \(b, chnk) -> StateT $ \n0 -> do
71 | printf "(Batch %d)\n" (b :: Int)
72 |
73 | t0 <- getCurrentTime
74 | -- n' <- evaluate $ optimizeList_ (bimap only_ only_ <$> chnk) n0
75 | -- -- (sgdOptimizer rate netOpPure crossEntropy)
76 | -- (adamOptimizer def netOpPure crossEntropy)
77 | n' <- optimizeListM_ (bimap only_ only_ <$> chnk) n0
78 | (adamOptimizerM def (netOpDOPure dout g) crossEntropy)
79 | t1 <- getCurrentTime
80 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0))
81 |
82 | let trainScore = testNetList maxTest (someNet n') chnk
83 | testScore = testNetList maxTest (someNet n') test
84 | printf "Training error: %.2f%%\n" ((1 - trainScore) * 100)
85 | printf "Validation error: %.2f%%\n" ((1 - testScore ) * 100)
86 |
87 | return ((), n')
88 | where
89 | -- rate :: Double
90 | -- rate = 0.02
91 | batch :: Int
92 | batch = 2500
93 |
--------------------------------------------------------------------------------
/stack.yaml.lock:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by Stack.
2 | # You should not edit this file by hand.
3 | # For more information, please see the documentation at:
4 | # https://docs.haskellstack.org/en/stable/lock_files
5 |
6 | packages:
7 | - completed:
8 | size: 26685
9 | url: https://github.com/mstksg/opto/archive/def9e41adc1123385037f92ead38c92cb6454dd2.tar.gz
10 | cabal-file:
11 | size: 2442
12 | sha256: 163eeb672b5afc9f7147275d786060a80308c31e57a84d94ebbb62c74c059ee8
13 | name: opto
14 | version: 0.1.0.0
15 | sha256: 627e8bf47074ae57b1ffa843810177c2df72fe3993de59aa2f8c68c6722b67f8
16 | pantry-tree:
17 | size: 1215
18 | sha256: f3dedce3dc9e88ec9fe99f0b69b8893fb2047e8c780d901e5832cfd2e1f52aae
19 | original:
20 | url: https://github.com/mstksg/opto/archive/def9e41adc1123385037f92ead38c92cb6454dd2.tar.gz
21 | - completed:
22 | hackage: backprop-0.2.6.3@sha256:a38430a6b4539cebd3c41d4fad9c48c2654fad6f6455d3b119882e7bf4804cfc,2805
23 | pantry-tree:
24 | size: 2123
25 | sha256: 433f277d8d4a0ad480fe2f0e8f71bd1f2cc71604b1fd11a15e8c86c81424648f
26 | original:
27 | hackage: backprop-0.2.6.3
28 | - completed:
29 | hackage: decidable-0.3.0.0@sha256:34857003b57139a047c9ab7944c313c227d9db702a8dcefa1478966257099423,1774
30 | pantry-tree:
31 | size: 764
32 | sha256: cc9b297fe8d4b4606583d24ff17a310ce227db3edd9a3d9c35df90582431ff68
33 | original:
34 | hackage: decidable-0.3.0.0
35 | - completed:
36 | hackage: functor-products-0.1.1.0@sha256:2bea36b6106b5756be6b81b3a5bfe7b41db1cf45fb63c19a1f04b572ba90fd0c,1456
37 | pantry-tree:
38 | size: 408
39 | sha256: 6c7d58498a2c23338baa8275a51e9099739812b1bd36126f887e5cdf57cce45b
40 | original:
41 | hackage: functor-products-0.1.1.0
42 | - completed:
43 | hackage: hmatrix-vector-sized-0.1.2.0@sha256:5b13ea0d371c54293e8a2159d3cea1f1466b8908331e1a1956c26b8732ca67a3,1881
44 | pantry-tree:
45 | size: 402
46 | sha256: 38bbeddbc873627d95d789a55d12104e8fe376f80cf9ae3ef3ca4ca282e5b177
47 | original:
48 | hackage: hmatrix-vector-sized-0.1.2.0
49 | - completed:
50 | hackage: hmatrix-backprop-0.1.3.0@sha256:eff786a68b34d44df9bd461c5890d0d876684f34e5fd4b6078a67263d21279f1,2294
51 | pantry-tree:
52 | size: 455
53 | sha256: 30892afe2120ad983d39bbda594a9d18b4d44572ba9d032e271455ba1b00091d
54 | original:
55 | hackage: hmatrix-backprop-0.1.3.0
56 | - completed:
57 | hackage: list-witnesses-0.1.3.2@sha256:a72d5c2ca295b89313842c405e8ccd972153d0476dba00387515312931b87a3f,1670
58 | pantry-tree:
59 | size: 398
60 | sha256: 5c1c96db8ce32341b5693ae7c110c6e51e09d9aa98baaf6e76f0c5b965db2c4c
61 | original:
62 | hackage: list-witnesses-0.1.3.2
63 | - completed:
64 | hackage: typelits-witnesses-0.4.0.0@sha256:1d7092ba98fdc33f4b413e04144eb3ead7b105f74b2998e3c74a8a0feee685a9,1985
65 | pantry-tree:
66 | size: 403
67 | sha256: 2ee741f6bb4dba710e6449da335fdcf8940adb767798b29fdb8ae2606d22e0cb
68 | original:
69 | hackage: typelits-witnesses-0.4.0.0
70 | - completed:
71 | hackage: vinyl-0.12.0@sha256:6136e2608c2c4be0c112944fb0f5a6a0df56b50adec12eb1b7240258abfcf9b1,3790
72 | pantry-tree:
73 | size: 1857
74 | sha256: aeb9e0e1a3bbe2b1f048a096430d240964a31c6936a1da89f4b32e931eba9d69
75 | original:
76 | hackage: vinyl-0.12.0
77 | - completed:
78 | hackage: dependent-sum-0.6.2.0@sha256:bff37c85b38e768b942f9d81c2465b63a96076f1ba006e35612aa357770807b6,1856
79 | pantry-tree:
80 | size: 474
81 | sha256: ad3fbed5104f9ee9c8082c9dcc8ade847674e3053572533e8d26ad1a866f1107
82 | original:
83 | hackage: dependent-sum-0.6.2.0
84 | - completed:
85 | hackage: constraints-extras-0.3.0.2@sha256:bf6884be65958e9188ae3c9e5547abfd6d201df021bff8a4704c2c4fe1e1ae5b,1784
86 | pantry-tree:
87 | size: 594
88 | sha256: b0bcc96d375ee11b1972a2e9e8e42039b3f420b0e1c46e9c70652470445a6505
89 | original:
90 | hackage: constraints-extras-0.3.0.2
91 | snapshots:
92 | - completed:
93 | size: 471743
94 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/nightly/2020/1/30.yaml
95 | sha256: d55e2ae57e1af8641591b271a0315405ea34690c4ee50b4eaf7445ee1780ada2
96 | original: nightly-2020-01-30
97 |
--------------------------------------------------------------------------------
/src/Backprop/Learn/Loss.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE RankNTypes #-}
4 | {-# LANGUAGE ScopedTypeVariables #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 | {-# OPTIONS_GHC -Wno-redundant-constraints #-}
8 |
9 | module Backprop.Learn.Loss (
10 | -- * Loss functions
11 | Loss
12 | , crossEntropy, crossEntropy1
13 | , squaredError, absError, totalSquaredError, squaredErrorV
14 | -- , totalCov
15 | -- ** Manipulate loss functions
16 | , scaleLoss
17 | , sumLoss
18 | , sumLossDecay
19 | , lastLoss
20 | , zipLoss
21 | , t2Loss
22 | -- * Regularization
23 | , Regularizer
24 | , l2Reg
25 | , l1Reg
26 | , noReg
27 | -- ** Manipulate regularizers
28 | , addReg
29 | , scaleReg
30 | ) where
31 |
32 | import Backprop.Learn.Regularize
33 | import Control.Applicative
34 | import Data.Coerce
35 | import Data.Finite
36 | import Data.Type.Tuple
37 | import GHC.TypeNats
38 | import Lens.Micro
39 | import Numeric.Backprop
40 | import Numeric.LinearAlgebra.Static.Backprop
41 | import qualified Data.Vector.Sized as SV
42 | import qualified Prelude.Backprop as B
43 |
44 | type Loss a = forall s. Reifies s W => a -> BVar s a -> BVar s Double
45 |
46 | crossEntropy :: KnownNat n => Loss (R n)
47 | crossEntropy targ res = -(log res <.> auto targ)
48 |
49 | crossEntropy1 :: Loss Double
50 | crossEntropy1 targ res = -(log res * auto targ + log (1 - res) * auto (1 - targ))
51 |
52 | squaredErrorV :: KnownNat n => Loss (R n)
53 | squaredErrorV targ res = e <.> e
54 | where
55 | e = res - auto targ
56 |
57 | totalSquaredError
58 | :: (Backprop (t Double), Num (t Double), Foldable t, Functor t)
59 | => Loss (t Double)
60 | totalSquaredError targ res = B.sum (e * e)
61 | where
62 | e = auto targ - res
63 |
64 | squaredError :: Loss Double
65 | squaredError targ res = (res - auto targ) ** 2
66 |
67 | absError :: Loss Double
68 | absError targ res = abs (res - auto targ)
69 |
70 | -- -- | Sum of covariances between matching components. Not sure if anyone
71 | -- -- uses this.
72 | -- totalCov :: (Backprop (t Double), Foldable t, Functor t) => Loss (t Double)
73 | -- totalCov targ res = -(xy / fromIntegral n - (x * y) / fromIntegral (n * n))
74 | -- where
75 | -- x = auto $ sum targ
76 | -- y = B.sum res
77 | -- xy = B.sum (auto targ * res)
78 | -- n = length targ
79 |
80 | -- klDivergence :: Loss Double
81 | -- klDivergence =
82 |
83 | sumLoss
84 | :: (Traversable t, Applicative t, Backprop a)
85 | => Loss a
86 | -> Loss (t a)
87 | sumLoss l targ = sum . liftA2 l targ . sequenceVar
88 |
89 | zipLoss
90 | :: (Traversable t, Applicative t, Backprop a)
91 | => t Double
92 | -> Loss a
93 | -> Loss (t a)
94 | zipLoss xs l targ = sum
95 | . liftA3 (\α t -> (* auto α) . l t) xs targ
96 | . sequenceVar
97 |
98 | sumLossDecay
99 | :: forall n a. (KnownNat n, Backprop a)
100 | => Double
101 | -> Loss a
102 | -> Loss (SV.Vector n a)
103 | sumLossDecay β = zipLoss $ SV.generate (\i -> β ** (fromIntegral i - n))
104 | where
105 | n = fromIntegral $ maxBound @(Finite n)
106 |
107 | lastLoss
108 | :: forall n a. (KnownNat (n + 1), Backprop a)
109 | => Loss a
110 | -> Loss (SV.Vector (n + 1) a)
111 | lastLoss l targ = l (SV.last targ)
112 | . viewVar (coerced . SV.ix @(n + 1) maxBound)
113 | . B.coerce @_ @(ABP (SV.Vector (n + 1)) a)
114 |
115 | coerced :: Coercible a b => Lens' a b
116 | coerced f x = coerce <$> f (coerce x)
117 | {-# INLINE coerced #-}
118 |
119 | -- | Scale the result of a loss function.
120 | scaleLoss :: Double -> Loss a -> Loss a
121 | scaleLoss β l x = (* auto β) . l x
122 |
123 | -- | Lift and sum a loss function over the components of a ':&'.
124 | t2Loss
125 | :: (Backprop a, Backprop b)
126 | => Loss a -- ^ loss on first component
127 | -> Loss b -- ^ loss on second component
128 | -> Loss (a :# b)
129 | t2Loss f g (xT :# yT) (xR :## yR) = f xT xR + g yT yR
130 |
131 |
--------------------------------------------------------------------------------
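As a worked illustration of the 'Loss' type above (a sketch, not part of the module): a 'Loss', once applied to its target, is just a backprop-differentiable function of the prediction, so 'backprop' can evaluate a loss together with its gradient at a concrete point.

    import Backprop.Learn.Loss (squaredError)
    import Numeric.Backprop (backprop)

    -- Loss value and its gradient with respect to the prediction.
    lossAndGrad :: Double -> Double -> (Double, Double)
    lossAndGrad targ r = backprop (squaredError targ) r

    -- lossAndGrad 3 5 == (4.0, 4.0): the loss is (5 - 3)^2 = 4, and its
    -- derivative in the prediction is 2 * (5 - 3) = 4.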
/src/Backprop/Learn/Model/Stochastic.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DeriveDataTypeable #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PatternSynonyms #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE RecordWildCards #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeInType #-}
12 | {-# LANGUAGE UndecidableInstances #-}
13 | {-# LANGUAGE ViewPatterns #-}
14 |
15 | module Backprop.Learn.Model.Stochastic (
16 | dropout
17 | , rreLU
18 | , injectNoise, applyNoise
19 | , injectNoiseR, applyNoiseR
20 | ) where
21 |
22 | import Backprop.Learn.Model.Function
23 | import Backprop.Learn.Model.Types
24 | import Control.Monad.Primitive
25 | import Data.Bool
26 | import GHC.TypeNats
27 | import Numeric.Backprop
28 | import Numeric.LinearAlgebra.Static.Backprop
29 | import Numeric.LinearAlgebra.Static.Vector
30 | import qualified Data.Vector.Storable.Sized as SVS
31 | import qualified Statistics.Distribution as Stat
32 | import qualified System.Random.MWC as MWC
33 | import qualified System.Random.MWC.Distributions as MWC
34 |
35 | -- | Dropout layer. Parameterized by dropout percentage (should be between
36 | -- 0 and 1).
37 | --
38 | -- 0 corresponds to no dropout, 1 corresponds to complete dropout of all
39 | -- nodes every time.
40 | dropout
41 | :: KnownNat n
42 | => Double
43 | -> Model 'Nothing 'Nothing (R n) (R n)
44 | dropout r = Func
45 | { runFunc = (auto (realToFrac (1 - r)) *)
46 | , runFuncStoch = \g x -> do
47 | (x *) . auto . vecR <$> SVS.replicateM (mask g)
48 | }
49 | where
50 | mask :: PrimMonad m => MWC.Gen (PrimState m) -> m Double
51 | mask = fmap (bool 1 0) . MWC.bernoulli r
52 |
53 | -- | Random leaky rectified linear unit
54 | rreLU
55 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
56 | => d
57 | -> Model 'Nothing 'Nothing (R n) (R n)
58 | rreLU d = Func
59 | { runFunc = vmap' (preLU v)
60 | , runFuncStoch = \g x -> do
61 | α <- vecR <$> SVS.replicateM (Stat.genContVar d g)
62 | pure (zipWithVector preLU (constVar α) x)
63 | }
64 | where
65 | v :: BVar s Double
66 | v = auto (Stat.mean d)
67 |
68 | -- | Inject random noise. Usually used between neural network layers, or
69 | -- at the very beginning to pre-process input.
70 | --
71 | -- In non-stochastic mode, this adds the mean of the distribution.
72 | injectNoise
73 | :: (Stat.ContGen d, Stat.Mean d, Fractional a)
74 | => d
75 | -> Model 'Nothing 'Nothing a a
76 | injectNoise d = Func
77 | { runFunc = (realToFrac (Stat.mean d) +)
78 | , runFuncStoch = \g x -> do
79 | e <- Stat.genContVar d g
80 | pure (realToFrac e + x)
81 | }
82 |
83 | -- | 'injectNoise' lifted to 'R'
84 | injectNoiseR
85 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
86 | => d
87 | -> Model 'Nothing 'Nothing (R n) (R n)
88 | injectNoiseR d = Func
89 | { runFunc = (realToFrac (Stat.mean d) +)
90 | , runFuncStoch = \g x -> do
91 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g)
92 | pure (constVar e + x)
93 | }
94 |
95 | -- | Multiply by random noise. Can be used to implement dropout-like
96 | -- behavior.
97 | --
98 | -- In non-stochastic mode, this scales by the mean of the distribution.
99 | applyNoise
100 | :: (Stat.ContGen d, Stat.Mean d, Fractional a)
101 | => d
102 | -> Model 'Nothing 'Nothing a a
103 | applyNoise d = Func
104 | { runFunc = (realToFrac (Stat.mean d) *)
105 | , runFuncStoch = \g x -> do
106 | e <- Stat.genContVar d g
107 | pure (realToFrac e * x)
108 | }
109 |
110 | -- | 'applyNoise' lifted to 'R'
111 | applyNoiseR
112 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
113 | => d
114 | -> Model 'Nothing 'Nothing (R n) (R n)
115 | applyNoiseR d = Func
116 | { runFunc = (realToFrac (Stat.mean d) *)
117 | , runFuncStoch = \g x -> do
118 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g)
119 | pure (constVar e * x)
120 | }
121 |
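
Usage sketch: every definition above has a deterministic interpretation
('runFunc') and a stochastic one ('runFuncStoch').  A minimal way to see the
difference is to run the same 'dropout' model both ways.  This assumes the
'runModel'/'runModelStoch' runners re-exported from "Backprop.Learn" and the
'TMaybe' constructor 'TNothing' from "Data.Type.Tuple" (its sibling 'TJust'
is used the same way in the apps below):

    {-# LANGUAGE DataKinds        #-}
    {-# LANGUAGE TypeApplications #-}

    import Backprop.Learn
    import Numeric.LinearAlgebra.Static (R)
    import qualified System.Random.MWC as MWC

    main :: IO ()
    main = MWC.withSystemRandom @IO $ \g -> do
        let drp = dropout 0.5 :: Model 'Nothing 'Nothing (R 5) (R 5)
            x   = 2 :: R 5
            -- deterministic mode: every component scaled by (1 - 0.5)
            (yDet, _) = runModel drp TNothing x TNothing
        -- stochastic mode: each component zeroed with probability 0.5
        (ySto, _) <- runModelStoch drp g TNothing x TNothing
        print (yDet, ySto)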
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Run.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE RankNTypes #-}
3 | {-# LANGUAGE ScopedTypeVariables #-}
4 | {-# LANGUAGE TupleSections #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 |
8 | module Backprop.Learn.Run (
9 | consecutives
10 | , consecutivesN
11 | , leadings
12 | , conduitLearn, conduitLearnStoch
13 | -- * Encoding and decoding for learning
14 | , oneHot', oneHot, oneHotR
15 | , SVG.maxIndex, maxIndexR
16 | ) where
17 |
18 | import Backprop.Learn.Model
19 | import Control.Monad
20 | import Control.Monad.Primitive
21 | import Control.Monad.Trans.Class
22 | import Control.Monad.Trans.Maybe
23 | import Data.Bool
24 | import Data.Conduit
25 | import Data.Finite
26 | import Data.Foldable
27 | import Data.Proxy
28 | import GHC.TypeNats
29 | import Numeric.LinearAlgebra.Static
30 | import Numeric.LinearAlgebra.Static.Vector
31 | import qualified Data.Conduit.Combinators as C
32 | import qualified Data.Sequence as Seq
33 | import qualified Data.Vector.Generic as VG
34 | import qualified Data.Vector.Generic.Sized as SVG
35 | import qualified System.Random.MWC as MWC
36 |
37 | consecutives :: Monad m => ConduitT i (i, i) m ()
38 | consecutives = void . runMaybeT $ do
39 | x <- MaybeT await
40 | go x
41 | where
42 | go x = do
43 | y <- MaybeT await
44 | lift $ yield (x, y)
45 | go y
46 |
47 | consecutivesN
48 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
49 | => ConduitT i (SVG.Vector v n i, SVG.Vector v n i) m ()
50 | consecutivesN = conseq (fromIntegral n) .| C.concatMap process
51 | where
52 | n = natVal (Proxy @n)
53 | process (xs, ys, _) = (,) <$> SVG.fromList (toList xs)
54 | <*> SVG.fromList (toList ys)
55 |
56 | leadings
57 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
58 | => ConduitT i (SVG.Vector v n i, i) m ()
59 | leadings = conseq (fromIntegral n) .| C.concatMap process
60 | where
61 | n = natVal (Proxy @n)
62 | process (xs, _, y) = (, y) <$> SVG.fromList (toList xs)
63 |
64 | conseq
65 | :: forall i m. Monad m
66 | => Int
67 | -> ConduitT i (Seq.Seq i, Seq.Seq i, i) m ()
68 | conseq n = void . runMaybeT $ do
69 | xs <- Seq.replicateM n $ MaybeT await
70 | go xs
71 | where
72 | go xs = do
73 | _ Seq.:<| xs' <- pure xs
74 | y <- MaybeT await
75 | let ys = xs' Seq.:|> y
76 | lift $ yield (xs, ys, y)
77 | go ys
78 |
79 | conduitLearn
80 | :: (Learn a b l, Backprop b, MaybeC Backprop (LStateMaybe l), Monad m)
81 | => l
82 | -> LParam_ I l
83 | -> LState_ I l
84 | -> ConduitT a b m (LState_ I l)
85 | conduitLearn l p = go
86 | where
87 | go s = do
88 | mx <- await
89 | case mx of
90 | Nothing -> return s
91 | Just x -> do
92 | let (y, s') = runLearn_ l p x s
93 | yield y
94 | go s'
95 |
96 | conduitLearnStoch
97 | :: (Learn a b l, Backprop b, MaybeC Backprop (LStateMaybe l), PrimMonad m)
98 | => l
99 | -> MWC.Gen (PrimState m)
100 | -> LParam_ I l
101 | -> LState_ I l
102 | -> ConduitT a b m (LState_ I l)
103 | conduitLearnStoch l g p = go
104 | where
105 | go s = do
106 | mx <- await
107 | case mx of
108 | Nothing -> return s
109 | Just x -> do
110 | (y, s') <- lift $ runLearnStoch_ l g p x s
111 | yield y
112 | go s'
113 |
114 | -- | What module should this be in?
115 | oneHot'
116 | :: (VG.Vector v a, KnownNat n)
117 | => a -- ^ not hot
118 | -> a -- ^ hot
119 | -> Finite n
120 | -> SVG.Vector v n a
121 | oneHot' nothot hot i = SVG.generate (bool nothot hot . (== i))
122 |
123 | oneHot
124 | :: (VG.Vector v a, KnownNat n, Num a)
125 | => Finite n
126 | -> SVG.Vector v n a
127 | oneHot = oneHot' 0 1
128 |
129 | oneHotR :: KnownNat n => Finite n -> R n
130 | oneHotR = vecR . oneHot
131 |
132 | -- | Could be in /hmatrix/.
133 | maxIndexR :: KnownNat n => R (n + 1) -> Finite (n + 1)
134 | maxIndexR = SVG.maxIndex . rVec
135 |
--------------------------------------------------------------------------------
/old/src/Learn/Neural/Layer/FullyConnected.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DeriveGeneric #-}
2 | {-# LANGUAGE ExistentialQuantification #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE InstanceSigs #-}
5 | {-# LANGUAGE LambdaCase #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PolyKinds #-}
8 | {-# LANGUAGE StandaloneDeriving #-}
9 | {-# LANGUAGE TypeFamilies #-}
10 | {-# LANGUAGE TypeInType #-}
11 |
12 | module Learn.Neural.Layer.FullyConnected (
13 | FullyConnected
14 | ) where
15 |
16 | import Data.Kind
17 | import Data.Singletons.Prelude
18 | import Data.Singletons.TypeLits
19 | import GHC.Generics (Generic)
20 | import GHC.Generics.Numeric
21 | import Learn.Neural.Layer
22 | import Numeric.BLAS
23 | import Numeric.Backprop
24 | import Statistics.Distribution
25 | import Statistics.Distribution.Normal
26 | import qualified Generics.SOP as SOP
27 |
28 | data FullyConnected :: Type
29 |
30 | instance (Num (b '[o,i]), Num (b '[o])) => Num (CParam FullyConnected b '[i] '[o]) where
31 | FCP w1 b1 + FCP w2 b2 = FCP (w1 + w2) (b1 + b2)
32 | FCP w1 b1 - FCP w2 b2 = FCP (w1 - w2) (b1 - b2)
33 | FCP w1 b1 * FCP w2 b2 = FCP (w1 * w2) (b1 * b2)
34 | negate (FCP w b) = FCP (negate w) (negate b)
35 | signum (FCP w b) = FCP (signum w) (signum b)
36 | abs (FCP w b) = FCP (abs w) (abs b)
37 | fromInteger x = FCP (fromInteger x) (fromInteger x)
38 |
39 | instance (Fractional (b '[o,i]), Fractional (b '[o])) => Fractional (CParam FullyConnected b '[i] '[o]) where
40 | FCP w1 b1 / FCP w2 b2 = FCP (w1 / w2) (b1 / b2)
41 | recip (FCP w b) = FCP (recip w) (recip b)
42 | fromRational x = FCP (fromRational x) (fromRational x)
43 |
44 | instance (Floating (b '[o,i]), Floating (b '[o])) => Floating (CParam FullyConnected b '[i] '[o]) where
45 | sqrt (FCP w b) = FCP (sqrt w) (sqrt b)
46 |
47 | instance Num (CState FullyConnected b '[i] '[o]) where
48 | _ + _ = FCS
49 | _ * _ = FCS
50 | _ - _ = FCS
51 | negate _ = FCS
52 | abs _ = FCS
53 | signum _ = FCS
54 | fromInteger _ = FCS
55 |
56 | instance Fractional (CState FullyConnected b '[i] '[o]) where
57 | _ / _ = FCS
58 | recip _ = FCS
59 | fromRational _ = FCS
60 |
61 | instance Floating (CState FullyConnected b '[i] '[o]) where
62 | sqrt _ = FCS
63 |
64 |
65 |
66 |
67 |
68 | deriving instance Generic (CParam FullyConnected b '[i] '[o])
69 | instance SOP.Generic (CParam FullyConnected b '[i] '[o])
70 |
71 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o]))
72 | => Component FullyConnected b '[i] '[o] where
73 | data CParam FullyConnected b '[i] '[o] =
74 | FCP { _fcWeights :: !(b '[o,i])
75 | , _fcBiases :: !(b '[o])
76 | }
77 | data CState FullyConnected b '[i] '[o] = FCS
78 | type CConstr FullyConnected b '[i] '[o] = Num (b '[o,i])
79 | data CConf FullyConnected b '[i] '[o] = forall d. ContGen d => FCC d
80 |
81 | componentOp = componentOpDefault
82 |
83 | initParam = \case
84 | i `SCons` SNil -> \case
85 | so@(o `SCons` SNil) -> \(FCC d) g -> do
86 | w <- genA (o `SCons` (i `SCons` SNil)) $ \_ ->
87 | realToFrac <$> genContVar d g
88 | b <- genA so $ \_ ->
89 | realToFrac <$> genContVar d g
90 | return $ FCP w b
91 | _ -> error "inaccessible."
92 | _ -> error "inaccessible."
93 |
94 | initState _ _ _ _ = return FCS
95 |
96 | defConf = FCC (normalDistr 0 0.01)
97 |
98 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o]))
99 | => ComponentFF FullyConnected b '[i] '[o] where
100 | componentOpFF = bpOp . withInps $ \(x :< p :< Ø) -> do
101 | w :< b :< Ø <- gTuple #<~ p
102 | y <- matVecOp ~$ (w :< x :< Ø)
103 | z <- (+.) ~$ (y :< b :< Ø)
104 | return . only $ z
105 |
106 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o]))
107 | => ComponentLayer r FullyConnected b '[i] '[o] where
108 | componentRunMode = RMIsFF
109 |
--------------------------------------------------------------------------------
/src/Backprop/Learn/Model/Neural.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE FlexibleInstances #-}
5 | {-# LANGUAGE MultiParamTypeClasses #-}
6 | {-# LANGUAGE PatternSynonyms #-}
7 | {-# LANGUAGE RankNTypes #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeOperators #-}
12 | {-# LANGUAGE UndecidableInstances #-}
13 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
15 |
16 | module Backprop.Learn.Model.Neural (
17 | -- * Feed-forward
18 | FCp, fc, fca
19 | , fcWeights, fcBias
20 | -- * Recurrent
21 | , fcr, fcra
22 | , FCRp, fcrBias, fcrInputWeights, fcrStateWeights
23 | ) where
24 |
25 |
26 | import Backprop.Learn.Model.Combinator
27 | import Backprop.Learn.Model.Regression
28 | import Backprop.Learn.Model.State
29 | import Backprop.Learn.Model.Types
30 | import Data.Tuple
31 | import GHC.TypeNats
32 | import Lens.Micro
33 | import Numeric.Backprop
34 | import Numeric.LinearAlgebra.Static.Backprop
35 | import qualified Numeric.LinearAlgebra.Static as H
36 |
37 | -- | Parameters for fully connected feed-forward layer with bias.
38 | type FCp = LRp
39 |
40 | fcWeights :: Lens (FCp i o) (FCp i' o) (L o i) (L o i')
41 | fcWeights = lrBeta
42 |
43 | fcBias :: forall i o. Lens' (FCp i o) (R o)
44 | fcBias = lrAlpha
45 |
46 | -- | Fully connected feed-forward layer with bias. Parameterized by its
47 | -- initialization distribution.
48 | --
49 | -- Note that this has no activation function; to use as a model with
50 | -- activation function, chain it with an activation function using 'RMap',
51 | -- ':.~', etc.; see 'FCA' for a convenient type synonym and constructor.
52 | --
53 | -- Without any activation function, this is essentially a multivariate
54 | -- linear regression.
55 | --
56 | -- With the logistic function as an activation function, this is
57 | -- essentially multivariate logistic regression. (See 'logReg')
58 | fc :: (KnownNat i, KnownNat o)
59 | => Model ('Just (FCp i o)) 'Nothing (R i) (R o)
60 | fc = linReg
61 |
62 | -- | Convenient synonym for an 'fc' post-composed with a simple
63 | -- parameterless activation function.
64 | fca :: (KnownNat i, KnownNat o)
65 | => (forall z. Reifies z W => BVar z (R o) -> BVar z (R o))
66 | -> Model ('Just (FCp i o)) 'Nothing (R i) (R o)
67 | fca f = funcD f <~ linReg
68 |
69 | -- | Fully connected recurrent layer with bias.
70 | fcr :: (KnownNat i, KnownNat o, KnownNat s)
71 | => (forall z. Reifies z W => BVar z (R o) -> BVar z (R s)) -- ^ store
72 | -> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o)
73 | fcr s = recurrent H.split (H.#) s fc
74 |
75 | -- | Convenient synonym for an 'fcr' post-composed with a simple
76 | -- parameterless activation function.
77 | fcra
78 | :: (KnownNat i, KnownNat o, KnownNat s)
79 | => (forall z. Reifies z W => BVar z (R o) -> BVar z (R o))
80 | -> (forall z. Reifies z W => BVar z (R o) -> BVar z (R s)) -- ^ store
81 | -> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o)
82 | fcra f s = funcD f <~ recurrent H.split (H.#) s fc
83 |
84 | -- | Parameter for fully connected recurrent layer.
85 | type FCRp s i o = FCp (i + s) o
86 |
87 | lensIso :: (s -> (a, x)) -> ((b, x) -> t) -> Lens s t a b
88 | lensIso f g h x = g <$> _1 h (f x)
89 |
90 | fcrInputWeights
91 | :: (KnownNat s, KnownNat i, KnownNat i', KnownNat o)
92 | => Lens (FCRp s i o) (FCRp s i' o) (L o i) (L o i')
93 | fcrInputWeights = fcWeights
94 | . lensIso H.splitCols (uncurry (H.|||))
95 |
96 | fcrStateWeights
97 | :: (KnownNat s, KnownNat s', KnownNat i, KnownNat o)
98 | => Lens (FCRp s i o) (FCRp s' i o) (L o s) (L o s')
99 | fcrStateWeights = fcWeights
100 | . lensIso (swap . H.splitCols) (uncurry (H.|||) . swap)
101 |
102 | fcrBias :: forall s i o. Lens' (FCRp s i o) (R o)
103 | fcrBias = fcBias @(i + s) @o
104 |
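
Following the notes above, 'fc' gains an activation function by composition,
and two 'fca' layers chain into a small feed-forward network.  A sketch in
the style of /app/word2vec.hs/ below (layer sizes arbitrary; the partial
type signatures let GHC infer the combined parameter type):

    {-# LANGUAGE DataKinds             #-}
    {-# LANGUAGE PartialTypeSignatures #-}
    {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}

    import Backprop.Learn
    import Numeric.LinearAlgebra.Static.Backprop (R)

    -- 784 -> 300 (logistic) -> 10 (softMax)
    twoLayer :: Model _ _ (R 784) (R 10)
    twoLayer = (fca softMax  :: Model _ _ (R 300) (R 10))
            <~ (fca logistic :: Model _ _ (R 784) (R 300))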
--------------------------------------------------------------------------------
/stack-ghcjs.yaml:
--------------------------------------------------------------------------------
1 | # This file was automatically generated by 'stack init'
2 | #
3 | # Some commonly used options have been documented as comments in this file.
4 | # For advanced use and comprehensive documentation of the format, please see:
5 | # http://docs.haskellstack.org/en/stable/yaml_configuration/
6 |
7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version.
8 | # A snapshot resolver dictates the compiler version and the set of packages
9 | # to be used for project dependencies. For example:
10 | #
11 | # resolver: lts-3.5
12 | # resolver: nightly-2015-09-21
13 | # resolver: ghc-7.10.2
14 | # resolver: ghcjs-0.1.0_ghc-7.10.2
15 | # resolver:
16 | # name: custom-snapshot
17 | # location: "./custom-snapshot.yaml"
18 | resolver: lts-7.22
19 |
20 | # User packages to be built.
21 | # Various formats can be used as shown in the example below.
22 | #
23 | # packages:
24 | # - some-directory
25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz
26 | # - location:
27 | # git: https://github.com/commercialhaskell/stack.git
28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a
29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a
30 | # extra-dep: true
31 | # subdirs:
32 | # - auto-update
33 | # - wai
34 | #
35 | # A package marked 'extra-dep: true' will only be built if demanded by a
36 | # non-dependency (i.e. a user package), and its test suites and benchmarks
37 | # will not be run. This is useful for tweaking upstream packages.
38 | packages:
39 | - '.'
40 | - location:
41 | git: https://github.com/mstksg/backprop.git
42 | commit: 5c5fa4301cdae6f0b9c9607c18df01de32d288b7
43 | extra-dep: true
44 | - location:
45 | git: https://github.com/mstksg/type-combinators-singletons.git
46 | commit: 8bd05539987692ee745f0d78b83e5d72c7af21c4
47 | extra-dep: true
48 | - location:
49 | git: https://github.com/mstksg/generics-lift.git
50 | commit: c8a127d522f47efbf4cce55c9276a1d19cd6a211
51 | extra-dep: true
52 | - location:
53 | git: https://github.com/ghcjs/ghcjs-base.git
54 | commit: dd7034ef8582ea8a175a71a988393a9d1ee86d6f
55 | extra-dep: true
56 | - location:
57 | git: https://github.com/reflex-frp/reflex.git
58 | commit: 2fe0f566f8d6b6eceb178a85516643390111bb83
59 | extra-dep: true
60 | - location:
61 | git: https://github.com/reflex-frp/reflex-dom.git
62 | commit: 706ab47df9729bdc5c4ac3f4d8dfd4661d9f6e1a
63 | subdirs:
64 | - reflex-dom
65 | - reflex-dom-core
66 | extra-dep: true
67 |
68 | # Dependency packages to be pulled from upstream that are not in the resolver
69 | # (e.g., acme-missiles-0.3)
70 | extra-deps:
71 | - hmatrix-0.18.0.0
72 | - type-combinators-0.2.4.3
73 | - mnist-idx-0.1.2.6
74 | - finite-typelits-0.1.2.0
75 | - ghcjs-dom-0.8.0.0
76 | - ghcjs-dom-jsffi-0.8.0.0
77 | - jsaddle-0.8.3.2
78 | - prim-uniq-0.1.0.1
79 | - ref-tf-0.4.0.1
80 | - zenc-0.1.1
81 | # - ghcjs-base-0.2.0.0
82 | # - ghcjs-dom-0.2.4.0
83 | # - ghcjs-dom-jsffi-0.8.0.0
84 | # - ref-tf-0.4.0.1
85 | # - reflex-0.4.0
86 | # - reflex-dom-0.4
87 |
88 | # Override default flag values for local packages and extra-deps
89 | flags: {}
90 |
91 | # Extra package databases containing global packages
92 | extra-package-dbs: []
93 |
94 | # # compiler: ghcjs-0.2.1.9007019_ghc-8.0.1
95 | compiler: ghcjs-0.2.1_ghc-8.0.1
96 | compiler-check: match-exact
97 |
98 | setup-info:
99 | ghcjs:
100 | source:
101 | ghcjs-0.2.1_ghc-8.0.1:
102 | url: /home/justin/projects/haskell/ghcjs/ghcjs/.stack-work/dist/x86_64-linux/Cabal-1.24.0.0/ghcjs-0.2.1.tar.gz
103 | # sha1: f2398e082a83a3aca381bc4457f50397af7fedfb
104 | # # sha1: 74ae6d67442221bcc15b43fc0e72784418f0bc53
105 | # # url: http://ghcjs.tolysz.org/ghc-8.0-2017-02-05-lts-7.19-9007019.tar.gz
106 | # # sha1: d2cfc25f9cda32a25a87d9af68891b2186ee52f9
107 |
108 | # Control whether we use the GHC we find on the path
109 | system-ghc: false
110 | #
111 | # Require a specific version of stack, using version ranges
112 | # require-stack-version: -any # Default
113 | # require-stack-version: ">=1.3"
114 | #
115 | # Override the architecture used by stack, especially useful on Windows
116 | # arch: i386
117 | # arch: x86_64
118 | #
119 | # Extra directories used by stack for building
120 | # extra-include-dirs: [/path/to/dir]
121 | # extra-lib-dirs: [/path/to/dir]
122 | #
123 | # Allow a newer minor version of GHC than the snapshot specifies
124 | # compiler-check: newer-minor
125 |
--------------------------------------------------------------------------------
/src/Backprop/Learn/Run.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE RankNTypes #-}
3 | {-# LANGUAGE ScopedTypeVariables #-}
4 | {-# LANGUAGE TupleSections #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
8 |
9 | module Backprop.Learn.Run (
10 | consecutives
11 | , consecutivesN
12 | , leadings
13 | , conduitModel, conduitModelStoch
14 | -- * Encoding and decoding for learning
15 | , oneHot', oneHot, oneHotR
16 | , SVG.maxIndex, maxIndexR
17 | ) where
18 |
19 | import Backprop.Learn.Model
20 | import Control.Monad
21 | import Control.Monad.Primitive
22 | import Control.Monad.Trans.Class
23 | import Control.Monad.Trans.Maybe
24 | import Data.Bool
25 | import Data.Conduit
26 | import Data.Finite
27 | import Data.Foldable
28 | import Data.Proxy
29 | import Data.Type.Functor.Product
30 | import GHC.TypeNats
31 | import Numeric.LinearAlgebra.Static
32 | import Numeric.LinearAlgebra.Static.Vector
33 | import qualified Data.Conduit.Combinators as C
34 | import qualified Data.Sequence as Seq
35 | import qualified Data.Vector.Generic as VG
36 | import qualified Data.Vector.Generic.Sized as SVG
37 | import qualified System.Random.MWC as MWC
38 |
39 | consecutives :: Monad m => ConduitT i (i, i) m ()
40 | consecutives = void . runMaybeT $ do
41 | x <- MaybeT await
42 | go x
43 | where
44 | go x = do
45 | y <- MaybeT await
46 | lift $ yield (x, y)
47 | go y
48 |
49 | consecutivesN
50 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
51 | => ConduitT i (SVG.Vector v n i, SVG.Vector v n i) m ()
52 | consecutivesN = conseq (fromIntegral n) .| C.concatMap process
53 | where
54 | n = natVal (Proxy @n)
55 | process (xs, ys, _) = (,) <$> SVG.fromList (toList xs)
56 | <*> SVG.fromList (toList ys)
57 |
58 | leadings
59 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
60 | => ConduitT i (SVG.Vector v n i, i) m ()
61 | leadings = conseq (fromIntegral n) .| C.concatMap process
62 | where
63 | n = natVal (Proxy @n)
64 | process (xs, _, y) = (, y) <$> SVG.fromList (toList xs)
65 |
66 | conseq
67 | :: forall i m. Monad m
68 | => Int
69 | -> ConduitT i (Seq.Seq i, Seq.Seq i, i) m ()
70 | conseq n = void . runMaybeT $ do
71 | xs <- Seq.replicateM n $ MaybeT await
72 | go xs
73 | where
74 | go xs = do
75 | _ Seq.:<| xs' <- pure xs
76 | y <- MaybeT await
77 | let ys = xs' Seq.:|> y
78 | lift $ yield (xs, ys, y)
79 | go ys
80 |
81 | conduitModel
82 | :: (Backprop b, AllConstrainedProd Backprop s, Monad m)
83 | => Model p s a b
84 | -> TMaybe p
85 | -> TMaybe s
86 | -> ConduitT a b m (TMaybe s)
87 | conduitModel f p = go
88 | where
89 | go s = do
90 | mx <- await
91 | case mx of
92 | Nothing -> return s
93 | Just x -> do
94 | let (y, s') = runModel f p x s
95 | yield y
96 | go s'
97 |
98 | conduitModelStoch
99 | :: (Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
100 | => Model p s a b
101 | -> MWC.Gen (PrimState m)
102 | -> TMaybe p
103 | -> TMaybe s
104 | -> ConduitT a b m (TMaybe s)
105 | conduitModelStoch f g p = go
106 | where
107 | go s = do
108 | mx <- await
109 | case mx of
110 | Nothing -> return s
111 | Just x -> do
112 | (y, s') <- lift $ runModelStoch f g p x s
113 | yield y
114 | go s'
115 |
116 | -- | What module should this be in?
117 | oneHot'
118 | :: (VG.Vector v a, KnownNat n)
119 | => a -- ^ not hot
120 | -> a -- ^ hot
121 | -> Finite n
122 | -> SVG.Vector v n a
123 | oneHot' nothot hot i = SVG.generate (bool nothot hot . (== i))
124 |
125 | oneHot
126 | :: (VG.Vector v a, KnownNat n, Num a)
127 | => Finite n
128 | -> SVG.Vector v n a
129 | oneHot = oneHot' 0 1
130 |
131 | oneHotR :: KnownNat n => Finite n -> R n
132 | oneHotR = vecR . oneHot
133 |
134 | -- | Could be in /hmatrix/.
135 | maxIndexR :: KnownNat n => R (n + 1) -> Finite (n + 1)
136 | maxIndexR = SVG.maxIndex . rVec
137 |
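
A concrete view of 'leadings': it pairs each sliding window with the element
that follows it, which is exactly the shape needed for next-step prediction
(cf. /app/char-rnn.hs/).  A small pure-conduit check, using boxed vectors
for simplicity:

    {-# LANGUAGE DataKinds        #-}
    {-# LANGUAGE TypeApplications #-}

    import Backprop.Learn.Run (leadings)
    import Data.Conduit
    import qualified Data.Conduit.Combinators as C
    import qualified Data.Vector as V
    import qualified Data.Vector.Generic.Sized as SVG

    windows :: [(SVG.Vector V.Vector 3 Int, Int)]
    windows = runConduitPure $
        C.yieldMany [1..6 :: Int] .| leadings @V.Vector @3 .| C.sinkList
    -- three windows: [1,2,3] -> 4, [2,3,4] -> 5, [3,4,5] -> 6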
--------------------------------------------------------------------------------
/old/app/Letter2Vec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ApplicativeDo #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE LambdaCase #-}
5 | {-# LANGUAGE ScopedTypeVariables #-}
6 | {-# LANGUAGE TypeApplications #-}
7 | {-# LANGUAGE TypeOperators #-}
8 | {-# LANGUAGE ViewPatterns #-}
9 | {-# OPTIONS_GHC -fno-warn-orphans #-}
10 |
11 |
12 | import Control.Applicative
13 | import Control.DeepSeq
14 | import Control.Exception
15 | import Control.Monad.IO.Class
16 | import Control.Monad.Trans.State
17 | import Data.Bifunctor
18 | import Data.Char
19 | import Data.Default
20 | import Data.Finite
21 | import Data.Foldable
22 | import Data.List
23 | import Data.List.Split
24 | import Data.Maybe
25 | import Data.Ord
26 | import Data.Time.Clock
27 | import Data.Type.Product hiding (toList, head')
28 | import Learn.Neural
29 | import Numeric.BLAS.HMatrix
30 | import Text.Printf hiding (toChar, fromChar)
31 | import Type.Class.Known
32 | import qualified Data.Vector as V
33 | import qualified System.Random.MWC as MWC
34 | import qualified System.Random.MWC.Distributions as MWC
35 |
36 |
37 | type ASCII = Finite 128
38 |
39 | fromChar :: Char -> Maybe ASCII
40 | fromChar = packFinite . fromIntegral . ord
41 |
42 | toChar :: ASCII -> Char
43 | toChar = chr . fromIntegral
44 |
45 | charOneHot :: Tensor t => Char -> Maybe (t '[128])
46 | charOneHot = fmap (oneHot . only) . fromChar
47 |
48 | oneHotChar :: BLAS t => t '[128] -> Char
49 | oneHotChar = toChar . iamax
50 |
51 | charRank :: Tensor t => t '[128] -> [Char]
52 | charRank = map fst . sortBy (flip (comparing snd)) . zip ['\0'..] . toList . textract
53 |
54 | main :: IO ()
55 | main = MWC.withSystemRandom $ \g -> do
56 | holmes <- evaluate . force . mapMaybe (charOneHot @HM)
57 | =<< readFile "data/holmes.txt"
58 | putStrLn "Loaded data"
59 | let slices :: [(HM '[128], HM '[128])]
60 | slices = concat . getZipList $ do
61 | skips <- traverse (ZipList . flip drop holmes) [0..2]
62 | pure (case skips of
63 | [l1,l2,l3] ->
64 | [(l2,l1),(l2,l3)]
65 | _ -> []
66 | )
67 | slices' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList slices) g
68 | let (test,train) = splitAt (length slices `div` 50) slices'
69 | net0 :: Network 'FeedForward HM
70 | ( '[128] :~ FullyConnected )
71 | '[ '[32 ] :~ LogitMap
72 | , '[32 ] :~ FullyConnected
73 | , '[3 ] :~ LogitMap
74 | , '[3 ] :~ FullyConnected
75 | , '[32 ] :~ LogitMap
76 | , '[32 ] :~ FullyConnected
77 | , '[128] :~ SoftMax '[128]
78 | ]
79 | '[128] <- initDefNet g
80 | flip evalStateT net0 . forM_ [1..] $ \e -> do
81 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList train) g
82 | liftIO $ printf "[Epoch %d]\n" (e :: Int)
83 | let chunkUp = chunksOf batch train'
84 | numChunks = length chunkUp
85 |
86 | forM_ ([1..] `zip` chunkUp) $ \(b, chnk) -> StateT $ \n0 -> do
87 | printf "(Epoch %d, Batch %d / %d)\n" e (b :: Int) numChunks
88 |
89 | t0 <- getCurrentTime
90 | n' <- evaluate $ optimizeListBatches_ (bimap only_ only_ <$> chnk) n0
91 | (batching (adamOptimizer def netOpPure crossEntropy))
92 | 25
93 | t1 <- getCurrentTime
94 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0))
95 |
96 | let encoder :: Network 'FeedForward HM ( '[128] :~ FullyConnected )
97 | '[ '[32] :~ LogitMap, '[32] :~ FullyConnected, '[3] :~ LogitMap ]
98 | '[3]
99 | encoder = takeNet known n'
100 |
101 | forM_ [0..127] $ \(c :: ASCII) -> do
102 | let enc :: HM '[3]
103 | enc = runNetPure encoder (oneHot (c :< Ø))
104 | [x,y,z] = toList $ textract enc
105 | printf "%s\t%.4f\t%.4f\t%.4f\n" (show (toChar c)) x y z
106 |
107 | let trainScore = testNetList maxTest (someNet n') chnk
108 | testScore = testNetList maxTest (someNet n') test
109 | printf "Training Score: %.2f%%\n" ((1 - trainScore) * 100)
110 | printf "Validation Score: %.2f%%\n" ((1 - testScore ) * 100)
111 |
112 | return ((), n')
113 | where
114 | batch :: Int
115 | batch = 10000
116 |
117 |
118 |
--------------------------------------------------------------------------------
/app/word2vec.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE PartialTypeSignatures #-}
4 | {-# LANGUAGE PolyKinds #-}
5 | {-# LANGUAGE ScopedTypeVariables #-}
6 | {-# LANGUAGE TupleSections #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeOperators #-}
9 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
10 |
11 | import Backprop.Learn
12 | import Control.DeepSeq
13 | import Control.Exception
14 | import Control.Monad
15 | import Control.Monad.IO.Class
16 | import Control.Monad.Trans.Class
17 | import Control.Monad.Trans.State
18 | import Data.Char
19 | import Data.Conduit
20 | import Data.Default
21 | import Data.List
22 | import Data.Proxy
23 | import Data.Time
24 | import Data.Type.Tuple
25 | import GHC.TypeNats
26 | import Numeric.LinearAlgebra.Static.Backprop
27 | import Numeric.Opto
28 | import System.Environment
29 | import Text.Printf
30 | import qualified Conduit as C
31 | import qualified Data.Conduit.Combinators as C
32 | import qualified Data.Set as S
33 | import qualified Data.Text as T
34 | import qualified Data.Vector.Sized as SV
35 | import qualified Data.Vector.Storable as VS
36 | import qualified Numeric.LinearAlgebra.Static as H
37 | import qualified System.Random.MWC as MWC
38 |
39 | encoder :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R n) (R e)
40 | encoder = fca logistic
41 |
42 | decoder :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R e) (R n)
43 | decoder = fca softMax
44 |
45 | word2vec :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R n) (R n)
46 | word2vec = decoder @n @e <~ encoder @n @e
47 |
48 | oneHotWord
49 | :: KnownNat n
50 | => S.Set String
51 | -> String
52 | -> Maybe (R n)
53 | oneHotWord ws = fmap (oneHotR . fromIntegral) . (`S.lookupIndex` ws)
54 |
55 | makeCBOW :: KnownNat w => SV.Vector w (R n) -> (R n, R n)
56 | makeCBOW v = (SV.index v mid, SV.sum (v SV.// [(mid, 0)]))
57 | where
58 | mid = maxBound `div` 2
59 |
60 | main :: IO ()
61 | main = MWC.withSystemRandom @IO $ \g -> do
62 | sourceFile:logFile:testFile:_ <- getArgs
63 | wordSet <- S.fromList . tokenize <$> readFile sourceFile
64 | SomeNat (Proxy :: Proxy n) <- pure $ someNatVal (fromIntegral (S.size wordSet))
65 |
66 | printf "%d unique words found.\n" (natVal (Proxy @n))
67 |
68 | let model = word2vec @n @100
69 | enc = encoder @n @100
70 | p0 <- initParamNormal model 0.2 g
71 |
72 |
73 | let report n b = do
74 | liftIO $ printf "(Batch %d)\n" (b :: Int)
75 | t0 <- liftIO getCurrentTime
76 | C.drop (n - 1)
77 | mp <- mapM (liftIO . evaluate . force) =<< await
78 | t1 <- liftIO getCurrentTime
79 | case mp of
80 | Nothing -> liftIO $ putStrLn "Done!"
81 | Just p@(_ :# pEnc) -> do
82 | chnk <- lift . state $ (,[])
83 | liftIO $ do
84 | printf "Trained on %d points in %s.\n"
85 | (length chnk)
86 | (show (t1 `diffUTCTime` t0))
87 | let trainScore = testModelAll maxIxTest model (TJust p) chnk
88 | printf "Training error: %.3f%%\n" ((1 - trainScore) * 100)
89 |
90 | testWords <- tokenize <$> readFile testFile
91 | let tests = flip map testWords $ \w ->
92 | let v = maybe 0 (runModelStateless enc (TJust pEnc))
93 | $ oneHotWord wordSet w
94 | in intercalate "," $ map (printf "%0.4f") (VS.toList (H.extract v))
95 |
96 | writeFile logFile $ unlines tests
97 |
98 | report n (b + 1)
99 |
100 |
101 | C.runResourceT . flip evalStateT []
102 | . runConduit
103 | $ forever ( C.sourceFile sourceFile
104 | .| C.decodeUtf8
105 | .| C.concatMap (tokenize . T.unpack)
106 | .| C.concatMap (oneHotWord wordSet)
107 | .| C.slidingWindow 5
108 | .| C.concatMap (SV.fromList @5)
109 | .| C.map makeCBOW
110 | )
111 | .| skipSampling 0.02 g
112 | .| C.iterM (modify . (:))
113 | .| optoConduit def p0
114 | (adam def (modelGradStoch crossEntropy noReg model g))
115 | .| report 500 0
116 | .| C.sinkNull
117 |
118 | tokenize :: String -> [String]
119 | tokenize = words . map (whitePunc . toLower)
120 | where
121 | whitePunc c
122 | | isPunctuation c = ' '
123 | | otherwise = c
124 |
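
A note on 'makeCBOW' above: with a window of 5, @mid = maxBound `div` 2@ is
index 2, so the center one-hot vector becomes the prediction target and the
four surrounding vectors are summed into the bag-of-words input.  As a
hypothetical check (one-hot vectors @a@ through @e@ assumed):

    makeCBOW (SV.fromTuple (a, b, c, d, e)) == (c, a + b + d + e)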
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Initialize.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE ConstraintKinds #-}
3 | {-# LANGUAGE DataKinds #-}
4 | {-# LANGUAGE DefaultSignatures #-}
5 | {-# LANGUAGE FlexibleContexts #-}
6 | {-# LANGUAGE FlexibleInstances #-}
7 | {-# LANGUAGE GADTs #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE ScopedTypeVariables #-}
10 | {-# LANGUAGE TypeApplications #-}
11 | {-# LANGUAGE TypeFamilies #-}
12 | {-# LANGUAGE TypeOperators #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
15 |
16 | module Backprop.Learn.Initialize (
17 | Initialize(..)
18 | , gInitialize
19 | , initializeNormal
20 | , initializeSingle
21 | ) where
22 |
23 | import Control.Monad.Primitive
24 | import Data.Complex
25 | import Data.Type.Length
26 | import Data.Type.NonEmpty
27 | import Data.Type.Tuple
28 | import GHC.TypeNats
29 | import Generics.OneLiner
30 | import Numeric.LinearAlgebra.Static.Vector
31 | import Statistics.Distribution
32 | import Statistics.Distribution.Normal
33 | import Type.Class.Known
34 | import Type.Family.List
35 | import qualified Data.Vector.Generic as VG
36 | import qualified Data.Vector.Generic.Sized as SVG
37 | import qualified Numeric.LinearAlgebra.Static as H
38 | import qualified System.Random.MWC as MWC
39 |
40 | -- | Class for types that are basically a bunch of 'Double's, which can be
41 | -- initialized with i.i.d. samples drawn from a given distribution.
42 | class Initialize p where
43 | initialize
44 | :: (ContGen d, PrimMonad m)
45 | => d
46 | -> MWC.Gen (PrimState m)
47 | -> m p
48 |
49 | default initialize
50 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m)
51 | => d
52 | -> MWC.Gen (PrimState m)
53 | -> m p
54 | initialize = gInitialize
55 |
56 | -- | 'initialize' for any product type with an 'ADTRecord' instance (via "Generics.OneLiner").
57 | gInitialize
58 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m)
59 | => d
60 | -> MWC.Gen (PrimState m)
61 | -> m p
62 | gInitialize d g = createA' @Initialize (initialize d g)
63 |
64 | initializeNormal
65 | :: (Initialize p, PrimMonad m)
66 | => Double -- ^ standard deviation
67 | -> MWC.Gen (PrimState m)
68 | -> m p
69 | initializeNormal = initialize . normalDistr 0
70 |
71 | -- | 'initialize' definition if @p@ is a single number.
72 | initializeSingle
73 | :: (ContGen d, PrimMonad m, Fractional p)
74 | => d
75 | -> MWC.Gen (PrimState m)
76 | -> m p
77 | initializeSingle d = fmap realToFrac . genContVar d
78 |
79 | instance Initialize Double where
80 | initialize = initializeSingle
81 | instance Initialize Float where
82 | initialize = initializeSingle
83 |
84 | -- | Initializes real and imaginary components independently, with the same distribution
85 | instance Initialize a => Initialize (Complex a) where
86 |
87 | instance Initialize T0
88 | instance (Initialize a, Initialize b) => Initialize (T2 a b)
89 | instance (Initialize a, Initialize b, Initialize c) => Initialize (T3 a b c)
90 |
91 | instance (ListC (Initialize <$> as), Known Length as) => Initialize (T as) where
92 | initialize d g = constTA @Initialize (initialize d g) known
93 |
94 | instance (Initialize a, ListC (Initialize <$> as), Known Length as) => Initialize (NETup (a ':| as)) where
95 | initialize d g = NET <$> initialize d g
96 | <*> initialize d g
97 |
98 | instance Initialize ()
99 | instance (Initialize a, Initialize b) => Initialize (a, b)
100 | instance (Initialize a, Initialize b, Initialize c) => Initialize (a, b, c)
101 |
102 | instance (VG.Vector v a, KnownNat n, Initialize a) => Initialize (SVG.Vector v n a) where
103 | initialize d = SVG.replicateM . initialize d
104 |
105 | instance KnownNat n => Initialize (H.R n) where
106 | initialize d = fmap vecR . initialize d
107 | instance KnownNat n => Initialize (H.C n) where
108 | initialize d = fmap vecC . initialize d
109 |
110 | instance (KnownNat n, KnownNat m) => Initialize (H.L n m) where
111 | initialize d = fmap vecL . initialize d
112 | instance (KnownNat n, KnownNat m) => Initialize (H.M n m) where
113 | initialize d = fmap vecM . initialize d
114 |
115 |
116 | constTA
117 | :: forall c as f. (ListC (c <$> as), Applicative f)
118 | => (forall a. c a => f a)
119 | -> Length as
120 | -> f (T as)
121 | constTA x = go
122 | where
123 | go :: forall bs. ListC (c <$> bs) => Length bs -> f (T bs)
124 | go LZ = pure TNil
125 | go (LS l) = (:&) <$> x <*> go l
126 |
127 |
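
A minimal sketch of using the class directly, drawing from the /statistics/
and /mwc-random/ packages (everything below is defined in this module or in
those libraries):

    {-# LANGUAGE DataKinds        #-}
    {-# LANGUAGE TypeApplications #-}

    import Backprop.Learn.Initialize
    import Statistics.Distribution.Uniform (uniformDistr)
    import qualified Numeric.LinearAlgebra.Static as H
    import qualified System.Random.MWC as MWC

    main :: IO ()
    main = MWC.withSystemRandom @IO $ \g -> do
        -- a 3x4 matrix with entries ~ N(0, 0.01^2)
        m <- initializeNormal 0.01 g :: IO (H.L 3 4)
        -- a length-4 vector with entries ~ Uniform(-1, 1)
        v <- initialize (uniformDistr (-1) 1) g :: IO (H.R 4)
        print (m, v)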
--------------------------------------------------------------------------------
/old2/src/Data/Type/NonEmpty.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE BangPatterns #-}
2 | {-# LANGUAGE FlexibleInstances #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE KindSignatures #-}
5 | {-# LANGUAGE LambdaCase #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PatternSynonyms #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeFamilyDependencies #-}
12 | {-# LANGUAGE TypeInType #-}
13 | {-# LANGUAGE TypeOperators #-}
14 | {-# LANGUAGE UndecidableInstances #-}
15 | {-# LANGUAGE ViewPatterns #-}
16 |
17 | module Data.Type.NonEmpty (
18 | NETup(.., NETT), ToNonEmpty
19 | , netHead, netTail
20 | , unNet
21 | , netT
22 | , NonEmpty(..)
23 | ) where
24 |
25 | import Control.DeepSeq
26 | import Data.Kind
27 | import Data.List.NonEmpty (NonEmpty(..))
28 | import Data.Type.Length
29 | import Data.Type.Tuple
30 | import Lens.Micro
31 | import Numeric.Backprop (Backprop(..))
32 | import Numeric.Opto.Ref
33 | import Numeric.Opto.Update
34 | import Type.Class.Known
35 | import Type.Family.List
36 | import qualified Data.Binary as Bi
37 |
38 | data NETup :: NonEmpty Type -> Type where
39 | NET :: !a -> !(T as) -> NETup (a ':| as)
40 |
41 | pattern NETT :: T (a ': as) -> NETup (a ':| as)
42 | pattern NETT { netT } <- ((\case NET x xs -> x :& xs) -> (!netT))
43 | where
44 | NETT (!(x :& xs)) = NET x xs
45 | {-# COMPLETE NETT #-}
46 |
47 | instance (NFData a, ListC (NFData <$> as)) => NFData (NETup (a ':| as)) where
48 | rnf (NETT xs) = rnf xs
49 |
50 | instance (Num a, ListC (Num <$> as), Known Length as) => Num (NETup (a ':| as)) where
51 | NETT xs + NETT ys = NETT (xs + ys)
52 | NETT xs - NETT ys = NETT (xs - ys)
53 | NETT xs * NETT ys = NETT (xs * ys)
54 | negate (NETT xs) = NETT (negate xs)
55 | abs (NETT xs) = NETT (abs xs)
56 | signum (NETT xs) = NETT (signum xs)
57 | fromInteger = NETT . fromInteger
58 |
59 | instance (Fractional a, ListC (Num <$> as), ListC (Fractional <$> as), Known Length as)
60 | => Fractional (NETup (a ':| as)) where
61 | NETT xs / NETT ys = NETT (xs / ys)
62 | recip (NETT xs) = NETT (recip xs)
63 | fromRational = NETT . fromRational
64 |
65 | instance (Floating a, ListC (Num <$> as), ListC (Fractional <$> as), ListC (Floating <$> as), Known Length as)
66 | => Floating (NETup (a ':| as)) where
67 | pi = NETT pi
68 | sqrt (NETT xs) = NETT (sqrt xs)
69 | exp (NETT xs) = NETT (exp xs)
70 | log (NETT xs) = NETT (log xs)
71 | sin (NETT xs) = NETT (sin xs)
72 | cos (NETT xs) = NETT (cos xs)
73 | tan (NETT xs) = NETT (tan xs)
74 | asin (NETT xs) = NETT (asin xs)
75 | acos (NETT xs) = NETT (acos xs)
76 | atan (NETT xs) = NETT (atan xs)
77 | sinh (NETT xs) = NETT (sinh xs)
78 | cosh (NETT xs) = NETT (cosh xs)
79 | tanh (NETT xs) = NETT (tanh xs)
80 | asinh (NETT xs) = NETT (asinh xs)
81 | acosh (NETT xs) = NETT (acosh xs)
82 | atanh (NETT xs) = NETT (atanh xs)
83 |
84 | instance (Additive a, Additive (T as))
85 | => Additive (NETup (a ':| as)) where
86 | NET x xs .+. NET y ys = NET (x .+. y) (xs .+. ys)
87 | addZero = NET addZero addZero
88 |
89 | instance (Scaling c a, Scaling c (T as)) => Scaling c (NETup (a ':| as)) where
90 | c .* NET x xs = NET (c .* x) (c .* xs)
91 | scaleOne = scaleOne @c @a
92 |
93 | instance (Metric c a, Metric c (T as), Ord c, Floating c) => Metric c (NETup (a ':| as)) where
94 | NET x xs <.> NET y ys = (x <.> y) + (xs <.> ys)
95 | norm_inf (NET x xs) = max (norm_inf x) (norm_inf xs)
96 | norm_0 (NET x xs) = norm_0 x + norm_0 xs
97 | norm_1 (NET x xs) = norm_1 x + norm_1 xs
98 | quadrance (NET x xs) = quadrance x + quadrance xs
99 |
100 | instance (Additive a, Additive (T as), Ref m (NETup (a ':| as)) v)
101 | => AdditiveInPlace m v (NETup (a ':| as))
102 | instance (Scaling s a, Scaling s (T as), Ref m (NETup (a ':| as)) v)
103 | => ScalingInPlace m v s (NETup (a ':| as))
104 |
105 | instance (Backprop a, ListC (Backprop <$> as)) => Backprop (NETup (a ':| as)) where
106 | zero (NET x xs) = NET (zero x) (zero xs)
107 | add (NET x xs) (NET y ys) = NET (add x y) (add xs ys)
108 | one (NET x xs) = NET (one x) (one xs)
109 |
110 | instance (Bi.Binary a, ListC (Bi.Binary <$> as), Known Length as) => Bi.Binary (NETup (a ':| as)) where
111 | get = NET <$> Bi.get
112 | <*> Bi.get
113 | put (NET x xs) = Bi.put x *> Bi.put xs
114 |
115 | netHead :: Lens (NETup (a ':| as)) (NETup (b ':| as)) a b
116 | netHead f (NET x xs) = (`NET` xs) <$> f x
117 |
118 | netTail :: Lens (NETup (a ':| as)) (NETup (a ':| bs)) (T as) (T bs)
119 | netTail f (NET x xs) = NET x <$> f xs
120 |
121 | unNet :: NETup (a ':| as) -> (a, T as)
122 | unNet (NET x xs) = (x, xs)
123 |
124 | type family ToNonEmpty (l :: [k]) = (m :: Maybe (NonEmpty k)) | m -> l where
125 | ToNonEmpty '[] = 'Nothing
126 | ToNonEmpty (a ': as) = 'Just (a ':| as)
127 |
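
Two worked reductions of 'ToNonEmpty' (the injectivity annotation @m -> l@
is what lets GHC also run the mapping backwards):

    -- ToNonEmpty '[Int, Bool]  =  'Just (Int ':| '[Bool])
    -- ToNonEmpty '[]           =  'Nothing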
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Model/Parameter.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DeriveDataTypeable #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PatternSynonyms #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE RecordWildCards #-}
10 | {-# LANGUAGE ScopedTypeVariables #-}
11 | {-# LANGUAGE TypeFamilies #-}
12 | {-# LANGUAGE TypeInType #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 |
15 | module Backprop.Learn.Model.Parameter (
16 | DeParam(..)
17 | , dpDeterm
18 | , DeParamAt(..)
19 | , dpaDeterm
20 | , ReParam
21 | , rpDeterm
22 | ) where
23 |
24 | import Backprop.Learn.Model.Class
25 | import Control.Monad.Primitive
26 | import Data.Kind
27 | import Data.Typeable
28 | import Numeric.Backprop
29 | import qualified System.Random.MWC as MWC
30 |
31 | -- | Convert a model with trainable parameters into a model without any
32 | -- trainable parameters.
33 | --
34 | -- The parameters are instead fixed (or stochastic, with a fixed
35 | -- distribution), and the model is treated as an untrainable function.
36 | --
37 | -- 'DeParam' is essentially 'DeParamAt' with the entire parameter fixed.
38 | data DeParam :: Type -> Type -> Type where
39 | DP :: { _dpParam :: p
40 | , _dpParamStoch :: forall m. PrimMonad m => MWC.Gen (PrimState m) -> m p
41 | , _dpLearn :: l
42 | }
43 | -> DeParam p l
44 | deriving (Typeable)
45 |
46 | -- | Create a 'DeParam' from a deterministic, non-stochastic parameter.
47 | dpDeterm :: p -> l -> DeParam p l
48 | dpDeterm p = DP p (const (pure p))
49 |
50 | instance (Learn a b l, LParamMaybe l ~ 'Just p) => Learn a b (DeParam p l) where
51 | type LParamMaybe (DeParam p l) = 'Nothing
52 | type LStateMaybe (DeParam p l) = LStateMaybe l
53 |
54 | runLearn DP{..} _ = runLearn _dpLearn (J_ (constVar _dpParam))
55 | runLearnStoch DP{..} g _ x s = do
56 | p <- constVar <$> _dpParamStoch g
57 | runLearnStoch _dpLearn g (J_ p) x s
58 |
59 | -- | Wrapping a model with @'DeParamAt' pq p q@ says that the model's
60 | -- parameter @pq@ can be split into @p@ and @q@, and fixes @q@ to a given
61 | -- value (or a stochastic value with a fixed distribution).  The model now
62 | -- effectively has parameter @p@ only, and the @q@ part will not be
63 | -- backpropagated.
64 | --
65 | -- 'DeParam' is essentially 'DeParamAt' where @p@ is '()' or 'T0'.
66 | data DeParamAt :: Type -> Type -> Type -> Type -> Type where
67 | DPA :: { _dpaSplit :: pq -> (p, q)
68 | , _dpaJoin :: p -> q -> pq
69 | , _dpaParam :: q
70 | , _dpaParamStoch :: forall m. PrimMonad m => MWC.Gen (PrimState m) -> m q
71 | , _dpaLearn :: l
72 | }
73 | -> DeParamAt pq p q l
74 | deriving (Typeable)
75 |
76 | -- | Create a 'DeParamAt' from a deterministic, non-stochastic fixed value
77 | -- as a part of the parameter.
78 | dpaDeterm :: (pq -> (p, q)) -> (p -> q -> pq) -> q -> l -> DeParamAt pq p q l
79 | dpaDeterm s j q = DPA s j q (const (pure q))
80 |
81 | instance (Learn a b l, LParamMaybe l ~ 'Just pq, Backprop pq, Backprop p, Backprop q)
82 | => Learn a b (DeParamAt pq p q l) where
83 | type LParamMaybe (DeParamAt pq p q l) = 'Just p
84 | type LStateMaybe (DeParamAt pq p q l) = LStateMaybe l
85 |
86 | runLearn DPA{..} (J_ p) = runLearn _dpaLearn (J_ p')
87 | where
88 | p' = isoVar2 _dpaJoin _dpaSplit p (constVar _dpaParam)
89 |
90 | runLearnStoch DPA{..} g (J_ p) x s = do
91 | q <- _dpaParamStoch g
92 | let p' = isoVar2 _dpaJoin _dpaSplit p (constVar q)
93 | runLearnStoch _dpaLearn g (J_ p') x s
94 |
95 |
96 | -- | Pre-apply a function to the parameter before the original model sees
97 | -- it. A @'ReParam' p q@ turns a model taking @p@ into a model taking @q@.
98 | --
99 | -- Note that a @'ReParam' p ''Nothing'@ is essentially the same as
100 | -- a @'DeParam' p@, and one could implement @'DeParamAt' p q@ in terms of
101 | -- @'ReParam' p (''Just' q)@.
102 | data ReParam :: Type -> Maybe Type -> Type -> Type where
103 | RP :: { _rpFrom :: forall s. Reifies s W => Mayb (BVar s) q -> BVar s p
104 | , _rpFromStoch :: forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> Mayb (BVar s) q -> m (BVar s p)
105 | , _rpLearn :: l
106 | }
107 | -> ReParam p q l
108 | deriving (Typeable)
109 |
110 | instance (Learn a b l, LParamMaybe l ~ 'Just p) => Learn a b (ReParam p q l) where
111 | type LParamMaybe (ReParam p q l) = q
112 | type LStateMaybe (ReParam p q l) = LStateMaybe l
113 |
114 | runLearn RP{..} q = runLearn _rpLearn (J_ (_rpFrom q))
115 | runLearnStoch RP{..} g q x s = do
116 | p <- _rpFromStoch g q
117 | runLearnStoch _rpLearn g (J_ p) x s
118 |
119 | -- | Create a 'ReParam' from a deterministic, non-stochastic
120 | -- transformation function.
121 | rpDeterm
122 | :: (forall s. Reifies s W => Mayb (BVar s) q -> BVar s p)
123 | -> l
124 | -> ReParam p q l
125 | rpDeterm f = RP f (const (pure . f))
126 |
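
A sketch of the intended workflow for 'DeParam': once a parameter has been
trained, fix it so the model runs as a plain, untrainable function.
(Hypothetical trained parameter; 'FC' and 'FCp' come from this tree's
"Backprop.Learn.Model.Neural" below:)

    {-# LANGUAGE DataKinds       #-}
    {-# LANGUAGE PatternSynonyms #-}

    import Backprop.Learn.Model.Parameter
    import Backprop.Learn.Model.Neural (FC, FCp, pattern FC)

    -- given a parameter obtained from training (hypothetical):
    frozen :: FCp 10 2 -> DeParam (FCp 10 2) (FC 10 2)
    frozen trainedP = dpDeterm trainedP FC
    -- 'frozen trainedP' has @'LParamMaybe' ~ ''Nothing'@: nothing left to train.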
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Model/Neural.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE MultiParamTypeClasses #-}
5 | {-# LANGUAGE PatternSynonyms #-}
6 | {-# LANGUAGE RankNTypes #-}
7 | {-# LANGUAGE TypeFamilies #-}
8 | {-# LANGUAGE TypeOperators #-}
9 | {-# LANGUAGE UndecidableInstances #-}
10 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
11 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
12 |
13 | module Backprop.Learn.Model.Neural (
14 | -- * Fully connected
15 | -- ** Feed-forward
16 | FC, pattern FC, FCp, fcBias, fcWeights
17 | -- *** With activation function
18 | , FCA, pattern FCA, _fcaActivation
19 | -- ** Recurrent
20 | , FCR, pattern FCR, FCRp, fcrBias, fcrInputWeights, fcrStateWeights
21 | -- *** With activation function
22 | , FCRA, pattern FCRA, _fcraStore, _fcraActivation
23 | ) where
24 |
25 |
26 | import Backprop.Learn.Model.Combinator
27 | import Backprop.Learn.Model.Regression
28 | import Backprop.Learn.Model.State
29 | import Data.Tuple
30 | import GHC.TypeNats
31 | import Lens.Micro
32 | import Numeric.Backprop
33 | import Numeric.LinearAlgebra.Static.Backprop
34 | import qualified Numeric.LinearAlgebra.Static as H
35 |
36 | -- | Fully connected feed-forward layer with bias. Parameterized by its
37 | -- initialization distribution.
38 | --
39 | -- Note that this has no activation function; to use as a model with
40 | -- activation function, chain it with an activation function using 'RMap',
41 | -- ':.~', etc.; see 'FCA' for a convenient type synonym and constructor.
42 | --
43 | -- Without any activation function, this is essentially a multivariate
44 | -- linear regression.
45 | --
46 | -- With the logistic function as an activation function, this is
47 | -- essentially multivariate logistic regression. (See 'Logistic')
48 | type FC i o = LinReg i o
49 |
50 | pattern FC :: FC i o
51 | pattern FC = LinReg
52 |
53 | -- | Convenient synonym for an 'FC' post-composed with a simple
54 | -- parameterless activation function.
55 | type FCA i o = RMap (R o) (R o) (FC i o)
56 |
57 | -- | Construct an 'FCA' from a simple parameterless activation
58 | -- function.
59 | --
60 | -- Some common ones include 'logistic' and @'vmap' 'reLU'@.
61 | pattern FCA
62 | :: (forall s. Reifies s W => BVar s (R o) -> BVar s (R o)) -- ^ '_fcaActivation'
63 | -> FCA i o
64 | pattern FCA { _fcaActivation } = RM _fcaActivation FC
65 |
66 | type FCp = LRp
67 |
68 | fcWeights :: Lens (FCp i o) (FCp i' o) (L o i) (L o i')
69 | fcWeights = lrBeta
70 |
71 | fcBias :: Lens' (FCp i o) (R o)
72 | fcBias = lrAlpha
73 |
74 | -- | Fully connected recurrent layer with bias.
75 | --
76 | -- Parameterized by its initialization distributions, and also the function
77 | -- to compute the new state from the previous output.
78 | --
79 | -- @
80 | -- instance 'Learn' ('R' i) (R o) ('FCR' h i o) where
81 | -- type 'LParamMaybe' (FCR h i o) = ''Just' ('FCRp' h i o)
82 | -- type 'LStateMaybe' (FCR h i o) = ''Just' (R h)
83 | -- @
84 | type FCR h i o = Recurrent (R (i + h)) (R i) (R h) (R o) (FC (i + h) o)
85 |
86 | -- | Construct an 'FCR'
87 | pattern FCR
88 | :: (KnownNat h, KnownNat i)
89 | => (forall s. Reifies s W => BVar s (R o) -> BVar s (R h)) -- ^ '_fcrStore'
90 | -> FCR h i o
91 | pattern FCR { _fcrStore } <-
92 | Rec { _recLoop = _fcrStore
93 | , _recLearn = FC
94 | }
95 | where
96 | FCR s = Rec { _recSplit = H.split
97 | , _recJoin = (H.#)
98 | , _recLoop = s
99 | , _recLearn = FC
100 | }
101 | {-# COMPLETE FCR #-}
102 |
103 | -- | Convenient synonym for an 'FCR' post-composed with a simple
104 | -- parameterless activation function.
105 | type FCRA h i o = RMap (R o) (R o) (FCR h i o)
106 |
107 | -- | Construct an 'FCRA' from a state-store function and an activation
108 | -- function.
109 | --
110 | -- Some common ones include 'logistic' and @'vmap' 'reLU'@.
111 | pattern FCRA
112 | :: (KnownNat h, KnownNat i)
113 | => (forall s. Reifies s W => BVar s (R o) -> BVar s (R h)) -- ^ '_fcraStore'
114 | -> (forall s. Reifies s W => BVar s (R o) -> BVar s (R o)) -- ^ '_fcraActivation'
115 | -> FCRA h i o
116 | pattern FCRA { _fcraStore, _fcraActivation }
117 | = RM _fcraActivation (FCR _fcraStore)
118 |
119 | type FCRp h i o = FCp (i + h) o
120 |
121 | lensIso :: (s -> (a, x)) -> ((b, x) -> t) -> Lens s t a b
122 | lensIso f g h x = g <$> _1 h (f x)
123 |
124 | fcrInputWeights
125 | :: (KnownNat h, KnownNat i, KnownNat i', KnownNat o)
126 | => Lens (FCRp h i o) (FCRp h i' o) (L o i) (L o i')
127 | fcrInputWeights = fcWeights . lensIso H.splitCols (uncurry (H.|||))
128 |
129 | fcrStateWeights
130 | :: (KnownNat h, KnownNat h', KnownNat i, KnownNat o)
131 | => Lens (FCRp h i o) (FCRp h' i o) (L o h) (L o h')
132 | fcrStateWeights = fcWeights . lensIso (swap . H.splitCols) (uncurry (H.|||) . swap)
133 |
134 | fcrBias :: Lens' (FCRp h i o) (R o)
135 | fcrBias = fcBias
136 |
137 |
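
A sketch of the pattern constructors in action.  Note that the store
function maps @R o@ to @R h@, so a plain 'logistic' only fits when @h@
equals @o@ (sizes hypothetical; 'logistic' assumed re-exported from
"Backprop.Learn"):

    {-# LANGUAGE DataKinds #-}

    import Backprop.Learn

    layer :: FCRA 10 20 10
    layer = FCRA { _fcraStore      = logistic
                 , _fcraActivation = logistic
                 }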
--------------------------------------------------------------------------------
/app/char-rnn.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE PartialTypeSignatures #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TupleSections #-}
8 | {-# LANGUAGE TypeApplications #-}
9 | {-# LANGUAGE TypeOperators #-}
10 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
11 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
12 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
13 |
14 | import Backprop.Learn
15 | import Control.DeepSeq
16 | import Control.Exception
17 | import Control.Monad
18 | import Control.Monad.IO.Class
19 | import Control.Monad.Trans.Class
20 | import Control.Monad.Trans.State
21 | import Data.Char
22 | import Data.Conduit
23 | import Data.Default
24 | import Data.Foldable
25 | import Data.Proxy
26 | import Data.Time
27 | import Data.Type.Equality
28 | import Data.Type.Tuple
29 | import GHC.TypeNats
30 | import Numeric.LinearAlgebra.Static.Backprop
31 | import Numeric.LinearAlgebra.Static.Vector
32 | import Numeric.Opto
33 | import System.Environment
34 | import Text.Printf
35 | import qualified Conduit as C
36 | import qualified Data.Conduit.Combinators as C
37 | import qualified Data.Set as S
38 | import qualified Data.Text as T
39 | import qualified Data.Vector.Sized as SV
40 | import qualified Data.Vector.Storable.Sized as SVS
41 | import qualified System.Random.MWC as MWC
42 | import qualified System.Random.MWC.Distributions as MWC
43 |
44 | -- | TODO: replace with 'LModel'
45 | charRNN
46 | :: forall n h1 h2. (KnownNat n, KnownNat h1, KnownNat h2)
47 | => LModel _ _ (R n) (R n)
48 | charRNN = fca softMax
49 | #: dropout @h2 0.25
50 | #: lstm
51 | #: dropout @h1 0.25
52 | #: lstm
53 | #: nilLM
54 |
55 | oneHotChar
56 | :: KnownNat n
57 | => S.Set Char
58 | -> Char
59 | -> R n
60 | oneHotChar cs = oneHotR . fromIntegral . (`S.findIndex` cs)
61 |
62 | main :: IO ()
63 | main = MWC.withSystemRandom @IO $ \g -> do
64 | sourceFile:_ <- getArgs
65 | charMap <- S.fromList <$> readFile sourceFile
66 | SomeNat (Proxy :: Proxy n) <- pure $ someNatVal (fromIntegral (length charMap))
67 | SomeNat (Proxy :: Proxy n') <- pure $ someNatVal (fromIntegral (length charMap - 1))
68 | Just Refl <- pure $ sameNat (Proxy @(n' + 1)) (Proxy @n)
69 |
70 | printf "%d characters found.\n" (natVal (Proxy @n))
71 |
72 | let model0 = charRNN @n @100 @50
73 | model = trainState . unrollFinal @(SV.Vector 15) $ model0
74 |
75 | p0 <- initParamNormal model 0.2 g
76 |
77 | let report n b = do
78 | liftIO $ printf "(Batch %d)\n" (b :: Int)
79 | t0 <- liftIO getCurrentTime
80 | C.drop (n - 1)
81 | mp <- mapM (liftIO . evaluate . force) =<< await
82 | t1 <- liftIO getCurrentTime
83 | case mp of
84 | Nothing -> liftIO $ putStrLn "Done!"
85 | Just p@(p' :# s') -> do
86 | chnk <- lift . state $ (,[])
87 | liftIO $ do
88 | printf "Trained on %d points in %s.\n"
89 | (length chnk)
90 | (show (t1 `diffUTCTime` t0))
91 | let trainScore = testModelAll maxIxTest model (TJust p) chnk
92 | printf "Training error: %.3f%%\n" ((1 - trainScore) * 100)
93 |
94 | forM_ (take 15 chnk) $ \(x,y) -> do
95 | let primed = primeModel model0 (TJust p') x (TJust s')
96 | testOut <- fmap reverse . flip execStateT [] $
97 | iterateModelM ( fmap (oneHotR . fromIntegral)
98 | . (>>= \r -> r <$ modify (r:)) -- trace
99 | . (`MWC.categorical` g)
100 | . SVS.fromSized
101 | . rVec
102 | )
103 | 100 model0 (TJust p') y primed
104 | printf "%s|%s\n"
105 | (sanitize . (`S.elemAt` charMap) . fromIntegral . maxIndexR <$> (toList x ++ [y]))
106 | (sanitize . (`S.elemAt` charMap) <$> testOut)
107 | report n (b + 1)
108 |
109 | C.runResourceT . flip evalStateT []
110 | . runConduit
111 | $ forever ( C.sourceFile sourceFile
112 | .| C.decodeUtf8
113 | .| C.concatMap T.unpack
114 | .| C.map (oneHotChar charMap)
115 | .| leadings
116 | )
117 | .| skipSampling 0.02 g
118 | .| C.iterM (modify . (:))
119 | .| optoConduit
120 | def
121 | p0
122 | (adam def (modelGradStoch crossEntropy noReg model g))
123 | .| report 2500 0
124 | .| C.sinkNull
125 |
126 |
127 | sanitize :: Char -> Char
128 | sanitize c | isPrint c = c
129 | | otherwise = '#'
130 |
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Loss.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE RankNTypes #-}
4 | {-# LANGUAGE ScopedTypeVariables #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 |
8 | module Backprop.Learn.Loss (
9 | -- * Loss functions
10 | Loss
11 | , crossEntropy
12 | , squaredError, absError, totalSquaredError, squaredErrorV
13 | -- , totalCov
14 | -- ** Manipulate loss functions
15 | , scaleLoss
16 | , sumLoss
17 | , sumLossDecay
18 | , lastLoss
19 | , zipLoss
20 | , t2Loss
21 | , t3Loss
22 | -- * Regularization
23 | , Regularizer
24 | , l2Reg
25 | , l1Reg
26 | , noReg
27 | -- ** Manipulate regularizers
28 | , addReg
29 | , scaleReg
30 | ) where
31 |
32 | import Control.Applicative
33 | import Data.Finite
34 | import Data.Type.Tuple hiding (T2(..), T3(..))
35 | import GHC.TypeNats
36 | import Numeric.Backprop
37 | import Numeric.LinearAlgebra.Static.Backprop
38 | import Numeric.Opto.Update hiding ((<.>))
39 | import qualified Data.Type.Tuple as T
40 | import qualified Data.Vector.Sized as SV
41 | import qualified Prelude.Backprop as B
42 |
43 | type Loss a = forall s. Reifies s W => a -> BVar s a -> BVar s Double
44 |
45 | crossEntropy :: KnownNat n => Loss (R n)
46 | crossEntropy targ res = -(log res <.> constVar targ)
47 |
48 | squaredErrorV :: KnownNat n => Loss (R n)
49 | squaredErrorV targ res = e <.> e
50 | where
51 | e = res - constVar targ
52 |
53 | totalSquaredError
54 | :: (Backprop (t Double), Num (t Double), Foldable t, Functor t)
55 | => Loss (t Double)
56 | totalSquaredError targ res = B.sum (e * e)
57 | where
58 | e = constVar targ - res
59 |
60 | squaredError :: Loss Double
61 | squaredError targ res = (res - constVar targ) ** 2
62 |
63 | absError :: Loss Double
64 | absError targ res = abs (res - constVar targ)
65 |
66 | -- -- | Sum of covariances between matching components. Not sure if anyone
67 | -- -- uses this.
68 | -- totalCov :: (Backprop (t Double), Foldable t, Functor t) => Loss (t Double)
69 | -- totalCov targ res = -(xy / fromIntegral n - (x * y) / fromIntegral (n * n))
70 | -- where
71 | -- x = constVar $ sum targ
72 | -- y = B.sum res
73 | -- xy = B.sum (constVar targ * res)
74 | -- n = length targ
75 |
76 | -- klDivergence :: Loss Double
77 | -- klDivergence =
78 |
79 | sumLoss
80 | :: (Traversable t, Applicative t, Backprop a)
81 | => Loss a
82 | -> Loss (t a)
83 | sumLoss l targ = sum . liftA2 l targ . sequenceVar
84 |
85 | zipLoss
86 | :: (Traversable t, Applicative t, Backprop a)
87 | => t Double
88 | -> Loss a
89 | -> Loss (t a)
90 | zipLoss xs l targ = sum
91 | . liftA3 (\α t -> (* constVar α) . l t) xs targ
92 | . sequenceVar
93 |
94 | sumLossDecay
95 | :: forall n a. (KnownNat n, Backprop a)
96 | => Double
97 | -> Loss a
98 | -> Loss (SV.Vector n a)
99 | sumLossDecay β = zipLoss $ SV.generate (\i -> β ** (n - fromIntegral i))
100 | where
101 | n = fromIntegral $ maxBound @(Finite n)
102 |
103 | lastLoss
104 | :: (KnownNat (n + 1), Backprop a)
105 | => Loss a
106 | -> Loss (SV.Vector (n + 1) a)
107 | lastLoss l targ = l (SV.last targ) . viewVar (SV.ix maxBound)
108 |
109 | -- | Scale the result of a loss function.
110 | scaleLoss :: Double -> Loss a -> Loss a
111 | scaleLoss β l x = (* constVar β) . l x
112 |
113 | -- | Lift and sum a loss function over the components of a 'T.T2'.
114 | t2Loss
115 | :: (Backprop a, Backprop b)
116 | => Loss a -- ^ loss on first component
117 | -> Loss b -- ^ loss on second component
118 | -> Loss (T.T2 a b)
119 | t2Loss f g (T.T2 xT yT) (T2B xR yR) = f xT xR + g yT yR
120 |
121 | -- | Lift and sum three loss functions over the components of a 'T.T3'.
122 | t3Loss
123 | :: (Backprop a, Backprop b, Backprop c)
124 | => Loss a -- ^ loss on first component
125 | -> Loss b -- ^ loss on second component
126 | -> Loss c -- ^ loss on third component
127 | -> Loss (T.T3 a b c)
128 | t3Loss f g h (T.T3 xT yT zT) xRyRzR
129 | = f xT (xRyRzR ^^. t3_1)
130 | + g yT (xRyRzR ^^. t3_2)
131 | + h zT (xRyRzR ^^. t3_3)
132 |
133 | -- | A regularizer on parameters
134 | type Regularizer p = forall s. Reifies s W => BVar s p -> BVar s Double
135 |
136 | -- | L2 regularization
137 | --
138 | -- \[
139 | -- \lambda \sum_w \frac{1}{2} w^2
140 | -- \]
141 | l2Reg
142 | :: (Metric Double p, Backprop p)
143 | => Double -- ^ scaling factor (often 0.5)
144 | -> Regularizer p
145 | l2Reg λ = liftOp1 . op1 $ \x ->
146 | ( λ * quadrance x / 2, (.* x) . (* λ))
147 |
148 | -- | L1 regularization
149 | --
150 | -- \[
151 | -- \lambda \sum_w \lvert w \rvert
152 | -- \]
153 | l1Reg
154 | :: (Num p, Metric Double p, Backprop p)
155 | => Double -- ^ scaling factor
156 | -> Regularizer p
157 | l1Reg λ = liftOp1 . op1 $ \x ->
158 | ( λ * norm_1 x, (.* signum x) . (* λ)
159 | )
160 |
161 | -- | No regularization
162 | noReg :: Regularizer p
163 | noReg _ = constVar 0
164 |
165 | -- | Add together two regularizers
166 | addReg :: Regularizer p -> Regularizer p -> Regularizer p
167 | addReg = liftA2 (+)
168 |
169 | -- | Scale a regularizer's influence
170 | scaleReg :: Double -> Regularizer p -> Regularizer p
171 | scaleReg λ reg = (* constVar λ) . reg
172 |
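173 | -- A minimal usage sketch (the vectors below are hypothetical values):
174 | -- a 'Loss' is an ordinary backprop-able function of a target and
175 | -- a result, so it can be evaluated and differentiated directly with
176 | -- 'evalBP' and 'gradBP':
177 | --
178 | -- >>> import qualified Numeric.LinearAlgebra.Static as H
179 | -- >>> let targ = H.vec3 0 1 0          -- hypothetical target
180 | -- >>> let res  = H.vec3 0.2 0.7 0.1    -- hypothetical prediction
181 | -- >>> evalBP (crossEntropy targ) res   -- the loss itself
182 | -- >>> gradBP (crossEntropy targ) res   -- its gradient w.r.t. res
183 | --
184 | -- Regularizers compose the same way: @addReg (l2Reg 1) (scaleReg 0.1 (l1Reg 1))@
185 | -- is an elastic-net-style 'Regularizer'.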
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | # Use new container infrastructure to enable caching
2 | sudo: false
3 |
4 | # Choose a lightweight base image; we provide our own build tools.
5 | language: c
6 |
7 | # Caching so the next build will be fast too.
8 | cache:
9 | directories:
10 | - $HOME/.ghc
11 | - $HOME/.cabal
12 | - $HOME/.stack
13 |
14 | # The different configurations we want to test. We have BUILD=cabal which uses
15 | # cabal-install, and BUILD=stack which uses Stack. More documentation on each
16 | # of those below.
17 | #
18 | # We set the compiler values here to tell Travis to use a different
19 | # cache file per set of arguments.
20 | #
21 | # If you need to have different apt packages for each combination in the
22 | # matrix, you can use a line such as:
23 | # addons: {apt: {packages: [libfcgi-dev,libgmp-dev]}}
24 | matrix:
25 | include:
26 | # We grab the appropriate GHC and cabal-install versions from hvr's PPA. See:
27 | # https://github.com/hvr/multi-ghc-travis
28 | - env: BUILD=cabal GHCVER=8.0.2 CABALVER=1.24 HAPPYVER=1.19.5 ALEXVER=3.1.7
29 | compiler: ": #GHC 8.0.2"
30 | addons: {apt: {packages: [cabal-install-1.24,ghc-8.0.2,happy-1.19.5,alex-3.1.7,libblas-dev,liblapack-dev], sources: [hvr-ghc]}}
31 |
32 | # Build with the newest GHC and cabal-install. This is an accepted failure,
33 | # see below.
34 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7
35 | compiler: ": #GHC HEAD"
36 | addons: {apt: {packages: [cabal-install-head,ghc-head,happy-1.19.5,alex-3.1.7,libblas-dev,liblapack-dev], sources: [hvr-ghc]}}
37 |
38 | # The Stack builds. We can pass in arbitrary Stack arguments via the ARGS
39 | # variable, such as using --stack-yaml to point to a different file.
40 | - env: BUILD=stack ARGS=""
41 | compiler: ": #stack default"
42 | addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}}
43 |
44 | - env: BUILD=stack ARGS="--resolver lts-8"
45 | compiler: ": #stack 8.0.2"
46 | addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}}
47 |
48 | # Nightly builds are allowed to fail
49 | - env: BUILD=stack ARGS="--resolver nightly"
50 | compiler: ": #stack nightly"
51 | addons: {apt: {packages: [libgmp,libgmp-dev,libblas-dev,liblapack-dev]}}
52 |
53 | # Build on OS X in addition to Linux
54 | - env: BUILD=stack ARGS=""
55 | compiler: ": #stack default osx"
56 | os: osx
57 |
58 | - env: BUILD=stack ARGS="--resolver lts-8"
59 | compiler: ": #stack 8.0.2 osx"
60 | os: osx
61 |
62 | - env: BUILD=stack ARGS="--resolver nightly"
63 | compiler: ": #stack nightly osx"
64 | os: osx
65 |
66 | allow_failures:
67 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7
68 | - env: BUILD=stack ARGS="--resolver nightly"
69 |
70 | before_install:
71 | # Using compiler above sets CC to an invalid value, so unset it
72 | - unset CC
73 |
74 | # We want to always allow newer versions of packages when building on GHC HEAD
75 | - CABALARGS=""
76 | - if [ "x$GHCVER" = "xhead" ]; then CABALARGS=--allow-newer; fi
77 |
78 | # Download and unpack the stack executable
79 | - export PATH=/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$HOME/.local/bin:/opt/alex/$ALEXVER/bin:/opt/happy/$HAPPYVER/bin:$HOME/.cabal/bin:$PATH
80 | - mkdir -p ~/.local/bin
81 | - |
82 | if [ `uname` = "Darwin" ]
83 | then
84 | travis_retry curl --insecure -L https://www.stackage.org/stack/osx-x86_64 | tar xz --strip-components=1 --include '*/stack' -C ~/.local/bin
85 | else
86 | travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack'
87 | fi
88 |
89 | # Use the more reliable S3 mirror of Hackage
90 | mkdir -p $HOME/.cabal
91 | echo 'remote-repo: hackage.haskell.org:http://hackage.fpcomplete.com/' > $HOME/.cabal/config
92 | echo 'remote-repo-cache: $HOME/.cabal/packages' >> $HOME/.cabal/config
93 |
94 | if [ "$CABALVER" != "1.16" ]
95 | then
96 | echo 'jobs: $ncpus' >> $HOME/.cabal/config
97 | fi
98 |
99 | # Get the list of packages from the stack.yaml file
100 | - PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@')
101 |
102 | install:
103 | - echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]"
104 | - if [ -f configure.ac ]; then autoreconf -i; fi
105 | - |
106 | set -ex
107 | case "$BUILD" in
108 | stack)
109 | stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies
110 | ;;
111 | cabal)
112 | cabal --version
113 | travis_retry cabal update
114 | cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES
115 | ;;
116 | esac
117 | set +ex
118 |
119 | script:
120 | - |
121 | set -ex
122 | case "$BUILD" in
123 | stack)
124 | stack --no-terminal $ARGS test --bench --no-run-benchmarks --haddock --no-haddock-deps
125 | ;;
126 | cabal)
127 | cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES
128 |
129 | ORIGDIR=$(pwd)
130 | for dir in $PACKAGES
131 | do
132 | cd $dir
133 | cabal check || [ "$CABALVER" == "1.16" ]
134 | cabal sdist
135 | SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && \
136 | (cd dist && cabal install --force-reinstalls "$SRC_TGZ")
137 | cd $ORIGDIR
138 | # TODO: temporary kludge until we can somehow only have cabal check
139 | # non extra-dep packages
140 | break
141 | done
142 | ;;
143 | esac
144 | set +ex
145 |
146 |
147 |
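148 | # To reproduce the default Stack CI job locally (a sketch; assumes a
149 | # working `stack` on the PATH and the BLAS/LAPACK dev libraries installed):
150 | #
151 | #   stack --no-terminal test --bench --no-run-benchmarks --haddock --no-haddock-deps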
--------------------------------------------------------------------------------
/old/src/Learn/Neural/Layer/Compose.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE DeriveGeneric #-}
4 | {-# LANGUAGE FlexibleInstances #-}
5 | {-# LANGUAGE InstanceSigs #-}
6 | {-# LANGUAGE KindSignatures #-}
7 | {-# LANGUAGE LambdaCase #-}
8 | {-# LANGUAGE MultiParamTypeClasses #-}
9 | {-# LANGUAGE ScopedTypeVariables #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeInType #-}
12 | {-# LANGUAGE TypeOperators #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 |
15 | module Learn.Neural.Layer.Compose (
16 | ) where
17 |
18 | import Data.Kind
19 | import Data.Singletons
20 | import GHC.Generics (Generic)
21 | import GHC.Generics.Numeric
22 | import Learn.Neural.Layer
23 | import Numeric.BLAS
24 | import Numeric.Backprop
25 | import Numeric.Backprop.Iso
26 | import qualified Generics.SOP as SOP
27 |
28 | data Comp :: Type -> Type -> k -> Type
29 |
30 | instance ( BLAS b
31 | , Component l b i h
32 | , Component j b h o
33 | )
34 | => Component (Comp l j h) b i o where
35 | data CParam (Comp l j h) b i o = CP { _cp1 :: !(CParam l b i h)
36 | , _cp2 :: !(CParam j b h o)
37 | }
38 | deriving Generic
39 | data CState (Comp l j h) b i o = CS { _cs1 :: !(CState l b i h)
40 | , _cs2 :: !(CState j b h o)
41 | }
42 | deriving Generic
43 | type CConstr (Comp l j h) b i o =
44 | ( CConstr l b i h
45 | , CConstr j b h o
46 | , Num (b h)
47 | , SingI h
48 | )
49 | data CConf (Comp l j h) b i o = CC { _cc1 :: !(CConf l b i h)
50 | , _cc2 :: !(CConf j b h o)
51 | }
52 |
53 | componentOp = bpOp . withInps $ \(x :< p :< s :< Ø) -> do
54 | p1 :< p2 :< Ø <- cpIso #<~ p
55 | s1 :< s2 :< Ø <- csIso #<~ s
56 | y :< s1' :< Ø <- componentOp ~$$ (x :< p1 :< s1 :< Ø)
57 | z :< s2' :< Ø <- componentOp ~$$ (y :< p2 :< s2 :< Ø)
58 | s' :< Ø <- isoVar (from csIso . tup1) (s1' :< s2' :< Ø)
59 | return $ z :< s' :< Ø
60 |
61 | initParam si so (CC c1 c2) g =
62 | CP <$> initParam si sh c1 g
63 | <*> initParam sh so c2 g
64 | where
65 | sh :: Sing h
66 | sh = sing
67 |
68 | initState si so (CC c1 c2) g =
69 | CS <$> initState si sh c1 g
70 | <*> initState sh so c2 g
71 | where
72 | sh :: Sing h
73 | sh = sing
74 |
75 | defConf = CC defConf defConf
76 |
77 |
78 | cpIso :: Iso' (CParam (Comp l j h) b i o) (Tuple '[CParam l b i h, CParam j b h o])
79 | cpIso = iso (\case CP c1 c2 -> c1 ::< c2 ::< Ø) (\case I c1 :< I c2 :< Ø -> CP c1 c2)
80 |
81 | csIso :: Iso' (CState (Comp l j h) b i o) (Tuple '[CState l b i h, CState j b h o])
82 | csIso = iso (\case CS s1 s2 -> s1 ::< s2 ::< Ø) (\case I s1 :< I s2 :< Ø -> CS s1 s2)
83 |
84 | instance (SOP.Generic (CState l b i h), SOP.Generic (CState j b h o))
85 | => SOP.Generic (CState (Comp l j h) b i o)
86 | instance (SOP.Generic (CParam l b i h), SOP.Generic (CParam j b h o))
87 | => SOP.Generic (CParam (Comp l j h) b i o)
88 |
89 | instance (Num (CParam l b i h), Num (CParam j b h o))
90 | => Num (CParam (Comp l j h) b i o) where
91 | (+) = genericPlus
92 | (-) = genericMinus
93 | (*) = genericTimes
94 | negate = genericNegate
95 | abs = genericAbs
96 | signum = genericSignum
97 | fromInteger = genericFromInteger
98 |
99 | instance (Fractional (CParam l b i h), Fractional (CParam j b h o))
100 | => Fractional (CParam (Comp l j h) b i o) where
101 | (/) = genericDivide
102 | recip = genericRecip
103 | fromRational = genericFromRational
104 |
105 | instance (Floating (CParam l b i h), Floating (CParam j b h o))
106 | => Floating (CParam (Comp l j h) b i o) where
107 | pi = genericPi
108 | exp = genericExp
109 | (**) = genericPower
110 | log = genericLog
111 | logBase = genericLogBase
112 | sin = genericSin
113 | cos = genericCos
114 | tan = genericTan
115 | asin = genericAsin
116 | acos = genericAcos
117 | atan = genericAtan
118 | sinh = genericSinh
119 | cosh = genericCosh
120 | tanh = genericTanh
121 | asinh = genericAsinh
122 | acosh = genericAcosh
123 | atanh = genericAtanh
124 |
125 | instance (Num (CState l b i h), Num (CState j b h o))
126 | => Num (CState (Comp l j h) b i o) where
127 | (+) = genericPlus
128 | (-) = genericMinus
129 | (*) = genericTimes
130 | negate = genericNegate
131 | abs = genericAbs
132 | signum = genericSignum
133 | fromInteger = genericFromInteger
134 |
135 | instance (Fractional (CState l b i h), Fractional (CState j b h o))
136 | => Fractional (CState (Comp l j h) b i o) where
137 | (/) = genericDivide
138 | recip = genericRecip
139 | fromRational = genericFromRational
140 |
141 | instance (Floating (CState l b i h), Floating (CState j b h o))
142 | => Floating (CState (Comp l j h) b i o) where
143 | pi = genericPi
144 | exp = genericExp
145 | (**) = genericPower
146 | log = genericLog
147 | logBase = genericLogBase
148 | sin = genericSin
149 | cos = genericCos
150 | tan = genericTan
151 | asin = genericAsin
152 | acos = genericAcos
153 | atan = genericAtan
154 | sinh = genericSinh
155 | cosh = genericCosh
156 | tanh = genericTanh
157 | asinh = genericAsinh
158 | acosh = genericAcosh
159 | atanh = genericAtanh
160 |
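161 | -- A composition sketch (the layer names are hypothetical): @Comp l j h@
162 | -- chains a component @l@ going from @i@ to a hidden shape @h@ with a
163 | -- component @j@ going from @h@ to @o@; its parameter and state are just
164 | -- pairs of the two components' parameters and states, as witnessed by
165 | -- 'cpIso' and 'csIso'.  For example:
166 | --
167 | --   type TwoLayer = Comp L1 L2 '[300]   -- hypothetical layers L1 and L2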
--------------------------------------------------------------------------------
/old/src/Numeric/BLAS/FVector.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE DeriveGeneric #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE InstanceSigs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE LambdaCase #-}
7 | {-# LANGUAGE RankNTypes #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeOperators #-}
12 |
13 | module Numeric.BLAS.FVector (
14 | FV(..)
15 | ) where
16 |
17 | import Data.Finite
18 | import Data.Finite.Internal
19 | import Data.Kind
20 | import Data.List
21 | import Data.Maybe
22 | import Data.Singletons.Prelude
23 | import Data.Singletons.TypeLits
24 | import Data.Type.Combinator
25 | import Data.Type.Product
26 | import Data.Type.Vector
27 | import GHC.Generics (Generic)
28 | import GHC.TypeLits
29 | import Lens.Micro
30 | import Numeric.BLAS
31 | import Numeric.Tensor
32 | import qualified Data.Vector.Sized as SV
33 | import qualified Data.Vector.Unboxed as VU
34 |
35 | newtype FV :: [Nat] -> Type where
36 | FV :: { getFV :: VU.Vector Double
37 | }
38 | -> FV b
39 | deriving (Generic, Show)
40 |
41 | instance Tensor FV where
42 | type Scalar FV = Double
43 |
44 | genA s f = fmap (FV . VU.fromList) . traverse f $ range s
45 | gen s f = FV . VU.fromList . fmap f $ range s
46 | tkonst s = FV . VU.replicate (fromIntegral (product (fromSing s)))
47 | tsum = VU.sum . getFV
48 | tmap f = FV . VU.map f . getFV
49 | tzip f (FV xs) (FV ys) = FV (VU.zipWith f xs ys)
50 | tzipN
51 | :: forall s n. SingI s
52 | => (Vec n Double -> Double)
53 | -> VecT n FV s
54 | -> FV s
55 | tzipN f xs = FV $ VU.generate (fromIntegral len) $ \i ->
56 | (f (vmap (I . (VU.! i)) xs'))
57 | where
58 | len = product (fromSing (sing @_ @s))
59 | xs' = vmap getFV xs
60 |
61 |
62 | tslice p0 = FV . go sing p0 . getFV
63 | where
64 | go :: Sing ns -> ProdMap Slice ns ms -> VU.Vector Double -> VU.Vector Double
65 | go = \case
66 | SNil -> \case
67 | PMZ -> id
68 | SNat `SCons` ss -> \case
69 | PMS (Slice sL sC _) pms ->
70 | let -- some wasted work here in re-computing the product,
71 | -- but premature optimization blah blah
72 | innerSize = fromIntegral (product (fromSing ss))
73 | dropper = innerSize * fromIntegral (fromSing sL)
74 | taker = innerSize * fromIntegral (fromSing sC)
75 | in over (chunks innerSize) (go ss pms)
76 | . VU.slice dropper taker
77 |
78 | tindex i (FV xs) = xs VU.! fromIntegral (getFinite (reIndex i))
79 |
80 | treshape _ (FV xs) = FV xs
81 | tload _ = FV . VU.convert . SV.fromSized
82 | textract
83 | :: forall s. SingI s
84 | => FV s
85 | -> SV.Vector (Product s) Double
86 | textract = withKnownNat (sProduct (sing @_ @s)) $
87 | fromJust . SV.toSized . VU.convert . getFV
88 |
89 | instance BLAS FV where
90 |
91 | iamax
92 | :: forall n. KnownNat n
93 | => FV '[n + 1]
94 | -> Finite (n + 1)
95 | iamax = withKnownNat (SNat @n %:+ SNat @1) $
96 | Finite . fromIntegral . VU.maxIndex . VU.map abs . getFV
97 |
98 | gemv
99 | :: forall m n. (KnownNat m, KnownNat n)
100 | => Double
101 | -> FV '[m, n]
102 | -> FV '[n]
103 | -> Maybe (Double, FV '[m])
104 | -> FV '[m]
105 | gemv α (FV a) (FV x) b =
106 | FV
107 | . maybe id (\(β, FV ys) -> VU.zipWith (\y z -> β * y + z) ys) b
108 | . over (chunkDown innerSize) (\r -> α * VU.sum (VU.zipWith (*) r x))
109 | $ a
110 | where
111 | innerSize = fromIntegral $ natVal (Proxy @n)
112 |
113 | ger α (FV xs) (FV ys) b =
114 | FV
115 | . maybe id (\(FV zs) -> VU.zipWith (+) zs) b
116 | . VU.concatMap (\x -> VU.map ((α * x) *) ys)
117 | $ xs
118 |
119 | gemm
120 | :: forall m o n. (KnownNat m, KnownNat o, KnownNat n)
121 | => Double
122 | -> FV '[m, o]
123 | -> FV '[o, n]
124 | -> Maybe (Double, FV '[m, n])
125 | -> FV '[m, n]
126 | gemm α (FV as) bs cs =
127 | FV
128 | . maybe id (uncurry f) cs
129 | . over (chunks innerSize) muller
130 | $ as
131 | where
132 | innerSize = fromIntegral $ natVal (Proxy @o)
133 | muller r =
134 | over (chunkDown innerSize) (\c -> α * VU.sum (VU.zipWith (*) r c))
135 | . getFV
136 | $ transp bs
137 | f β (FV cs') = VU.zipWith (\c b -> β * c + b) cs'
138 |
139 | range :: Sing ns -> [Prod Finite ns]
140 | range = \case
141 | SNil -> [Ø]
142 | SNat `SCons` ss -> (:<) <$> finites <*> range ss
143 |
144 | reIndex
145 | :: SingI ns
146 | => Prod Finite ns
147 | -> Finite (Product ns)
148 | reIndex = Finite . fst . unsafeReIndex sing
149 |
150 | unsafeReIndex
151 | :: Sing ns
152 | -> Prod Finite ns
153 | -> (Integer, Integer)
154 | unsafeReIndex = \case
155 | SNil -> \case
156 | Ø -> (0, 1)
157 | SNat `SCons` ss -> \case
158 | (i :: Finite n) :< is ->
159 | let (j, jSize) = unsafeReIndex ss is
160 | iSize = jSize * (fromSing (SNat @n))
161 | in (j + jSize * getFinite i, iSize)
162 |
163 | chunks
164 | :: (VU.Unbox a, VU.Unbox b)
165 | => Int
166 | -> Traversal (VU.Vector a) (VU.Vector b) (VU.Vector a) (VU.Vector b)
167 | chunks n f = fmap VU.concat . traverse f . unfoldr u
168 | where
169 | u xs | VU.length xs >= n = Just (VU.splitAt n xs)
170 | | otherwise = Nothing
171 |
172 | chunkDown
173 | :: (VU.Unbox a, VU.Unbox b)
174 | => Int
175 | -> Traversal (VU.Vector a) (VU.Vector b) (VU.Vector a) b
176 | chunkDown n f = fmap VU.fromList . traverse f . unfoldr u
177 | where
178 | u xs | VU.length xs >= n = Just (VU.splitAt n xs)
179 | | otherwise = Nothing
180 |
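181 | -- A worked indexing example (shape chosen for illustration): for
182 | -- a tensor of shape @'[2, 3]@, 'unsafeReIndex' computes the row-major
183 | -- offset, so the index (1, 2) resolves to @1 * 3 + 2 = 5@ with total
184 | -- size @2 * 3 = 6@.  'reIndex' discards the size and wraps the offset
185 | -- back up in a 'Finite'.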
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Test.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ApplicativeDo #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE PartialTypeSignatures #-}
4 | {-# LANGUAGE RankNTypes #-}
5 | {-# LANGUAGE ScopedTypeVariables #-}
6 | {-# LANGUAGE TypeApplications #-}
7 | {-# LANGUAGE TypeFamilies #-}
8 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
9 |
10 | module Backprop.Learn.Test (
11 | -- * Tests
12 | Test
13 | , maxIxTest, rmseTest
14 | , squaredErrorTest, absErrorTest, totalSquaredErrorTest, squaredErrorTestV
15 | , crossEntropyTest
16 | -- ** Manipulate tests
17 | , lossTest, lmapTest
18 | -- * Run tests
19 | , testLearn, testLearnStoch, testLearnAll, testLearnStochAll
20 | -- ** Correlation tests
21 | , testLearnCov, testLearnCorr
22 | , testLearnStochCov, testLearnStochCorr
23 | ) where
24 |
25 | import Backprop.Learn.Loss
26 | import Backprop.Learn.Model
27 | import Control.Monad.Primitive
28 | import Data.Bifunctor
29 | import Data.Bitraversable
30 | import Data.Function
31 | import Data.Profunctor
32 | import Data.Proxy
33 | import GHC.TypeNats
34 | import Numeric.Backprop
35 | import qualified Control.Foldl as L
36 | import qualified Numeric.LinearAlgebra as HU
37 | import qualified Numeric.LinearAlgebra.Static as H
38 | import qualified System.Random.MWC as MWC
39 |
40 | -- TODO: support non-double results?
41 |
42 | type Test o = o -> o -> Double
43 |
44 | -- | Create a 'Test' from a 'Loss'
45 | lossTest :: Loss a -> Test a
46 | lossTest l x = evalBP (l x)
47 |
48 | maxIxTest :: KnownNat n => Test (H.R n)
49 | maxIxTest x y
50 | | match x y = 1
51 | | otherwise = 0
52 | where
53 | match = (==) `on` (HU.maxIndex . H.extract)
54 |
55 | rmseTest :: forall n. KnownNat n => Test (H.R n)
56 | rmseTest x y = H.norm_2 (x - y) / sqrt (fromIntegral (natVal (Proxy @n)))
57 |
58 | squaredErrorTest :: Real a => Test a
59 | squaredErrorTest x y = e * e
60 | where
61 | e = realToFrac (x - y)
62 |
63 | absErrorTest :: Real a => Test a
64 | absErrorTest x y = realToFrac . abs $ x - y
65 |
66 | totalSquaredErrorTest :: (Applicative t, Foldable t, Real a) => Test (t a)
67 | totalSquaredErrorTest x y = realToFrac (sum e)
68 | where
69 | e = do
70 | x' <- x
71 | y' <- y
72 | pure ((x' - y') ^ (2 :: Int))
73 |
74 | squaredErrorTestV :: KnownNat n => Test (H.R n)
75 | squaredErrorTestV x y = e `H.dot` e
76 | where
77 | e = x - y
78 |
79 | crossEntropyTest :: KnownNat n => Test (H.R n)
80 | crossEntropyTest targ res = -(log res H.<.> targ)
81 |
82 | lmapTest
83 | :: (a -> b)
84 | -> Test b
85 | -> Test a
86 | lmapTest f t x y = t (f x) (f y)
87 |
88 | testLearn
89 | :: (Learn a b l, NoState l)
90 | => Test b
91 | -> l
92 | -> LParam_ I l
93 | -> a
94 | -> b
95 | -> Double
96 | testLearn t l mp x y = t y $ runLearnStateless_ l mp x
97 |
98 | testLearnStoch
99 | :: (Learn a b l, NoState l, PrimMonad m)
100 | => Test b
101 | -> l
102 | -> MWC.Gen (PrimState m)
103 | -> LParam_ I l
104 | -> a
105 | -> b
106 | -> m Double
107 | testLearnStoch t l g mp x y = t y <$> runLearnStochStateless_ l g mp x
108 |
109 | cov :: Fractional a => L.Fold (a, a) a
110 | cov = do
111 | x <- lmap fst L.sum
112 | y <- lmap snd L.sum
113 | xy <- lmap (uncurry (*)) L.sum
114 | n <- fromIntegral <$> L.length
115 | pure (xy / n - (x * y) / n / n)
116 |
117 | corr :: Floating a => L.Fold (a, a) a
118 | corr = do
119 | x <- lmap fst L.sum
120 | x2 <- lmap ((**2) . fst) L.sum
121 | y <- lmap snd L.sum
122 | y2 <- lmap ((**2) . snd) L.sum
123 | xy <- lmap (uncurry (*)) L.sum
124 | n <- fromIntegral <$> L.length
125 | pure $ (xy / n - (x * y) / n / n)
126 | / sqrt ( x2 / n - (x / n)**2 )
127 | / sqrt ( y2 / n - (y / n)**2 )
128 |
129 | testLearnCov
130 | :: (Learn a b l, NoState l, Foldable t, Fractional b)
131 | => l
132 | -> LParam_ I l
133 | -> t (a, b)
134 | -> b
135 | testLearnCov l p = L.fold $ (lmap . first) (runLearnStateless_ l p) cov
136 |
137 | testLearnCorr
138 | :: (Learn a b l, NoState l, Foldable t, Floating b)
139 | => l
140 | -> LParam_ I l
141 | -> t (a, b)
142 | -> b
143 | testLearnCorr l p = L.fold $ (lmap . first) (runLearnStateless_ l p) corr
144 |
145 | testLearnAll
146 | :: (Learn a b l, NoState l, Foldable t)
147 | => Test b
148 | -> l
149 | -> LParam_ I l
150 | -> t (a, b)
151 | -> Double
152 | testLearnAll t l p = L.fold $ lmap (uncurry (testLearn t l p)) L.mean
153 |
154 | -- newtype M m a = M { getM :: m a }
155 | -- instance (Semigroup a, Applicative m) => Semigroup (M m a) where
156 | -- M x <> M y = M $ liftA2 (<>) x y
157 | -- instance (Monoid a, Applicative m) => Monoid (M m a) where
158 | -- mappend = (<>)
159 | -- mempty = M (pure mempty)
160 |
161 | testLearnStochAll
162 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t)
163 | => Test b
164 | -> l
165 | -> MWC.Gen (PrimState m)
166 | -> LParam_ I l
167 | -> t (a, b)
168 | -> m Double
169 | testLearnStochAll t l g p = L.foldM $ L.premapM (uncurry (testLearnStoch t l g p))
170 | (L.generalize L.mean)
171 |
172 | testLearnStochCov
173 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t, Fractional b)
174 | => l
175 | -> MWC.Gen (PrimState m)
176 | -> LParam_ I l
177 | -> t (a, b)
178 | -> m b
179 | testLearnStochCov l g p = L.foldM $ (L.premapM . flip bitraverse pure)
180 | (runLearnStochStateless_ l g p)
181 | (L.generalize cov)
182 |
183 | testLearnStochCorr
184 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t, Floating b)
185 | => l
186 | -> MWC.Gen (PrimState m)
187 | -> LParam_ I l
188 | -> t (a, b)
189 | -> m b
190 | testLearnStochCorr l g p = L.foldM $ (L.premapM . flip bitraverse pure)
191 | (runLearnStochStateless_ l g p)
192 | (L.generalize corr)
193 |
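194 | -- The two folds above are the usual single-pass moment estimators,
195 | --
196 | --   cov  = E[xy] - E[x] E[y]
197 | --   corr = cov / (stddev x * stddev y)
198 | --
199 | -- computed in one traversal via the 'Applicative' instance of 'L.Fold'.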
--------------------------------------------------------------------------------
/src/Backprop/Learn/Model/State.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ApplicativeDo #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE MultiParamTypeClasses #-}
6 | {-# LANGUAGE PatternSynonyms #-}
7 | {-# LANGUAGE RankNTypes #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeInType #-}
12 | {-# LANGUAGE TypeOperators #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 |
15 | module Backprop.Learn.Model.State (
16 | -- * To and from statelessness
17 | trainState, deState, deStateD, zeroState, dummyState
18 | -- * Manipulate model states
19 | , unroll, unrollFinal, recurrent
20 | ) where
21 |
22 | import Backprop.Learn.Model.Types
23 | import Control.Monad.Primitive
24 | import Control.Monad.Trans.State
25 | import Data.Bifunctor
26 | import Data.Foldable
27 | import Data.Type.Functor.Product
28 | import Data.Type.Tuple
29 | import Numeric.Backprop
30 | import qualified System.Random.MWC as MWC
31 |
32 | -- | Make a model stateless by converting the state to a trained parameter,
33 | -- and dropping the modified state from the result.
34 | --
35 | -- One of the ways to make a model stateless for training purposes. Useful
36 | -- when used after 'unroll'. See 'deState', as well.
37 | --
38 | -- Its parameters are:
39 | --
40 | -- * If the input has no parameters, just the initial state.
41 | -- * If the input has a parameter, a ':#' of that parameter and initial state.
42 | trainState
43 | :: forall p s a b.
44 | ( PureProd Maybe p
45 | , PureProd Maybe s
46 | , AllConstrainedProd Backprop p
47 | , AllConstrainedProd Backprop s
48 | )
49 | => Model p s a b
50 | -> Model (p :#? s) 'Nothing a b
51 | trainState = withModelFunc $ \f (p :#? s) x n_ ->
52 | (second . const) n_ <$> f p x s
53 |
54 | -- | Make a model stateless by pre-applying a fixed state (or a stochastic
56 | -- one with a fixed distribution) and dropping the modified state from the
56 | -- result.
57 | --
58 | -- One of the ways to make a model stateless for training purposes. Useful
59 | -- when used after 'unroll'. See 'trainState', as well.
60 | deState
61 | :: s
62 | -> (forall m. PrimMonad m => MWC.Gen (PrimState m) -> m s)
63 | -> Model p ('Just s) a b
64 | -> Model p 'Nothing a b
65 | deState s sStoch f = Model
66 | { runLearn = \p x n_ ->
67 | (second . const) n_ $ runLearn f p x (PJust (auto s))
68 | , runLearnStoch = \g p x n_ -> do
69 | s' <- sStoch g
70 | (second . const) n_ <$> runLearnStoch f g p x (PJust (auto s'))
71 | }
72 |
73 | -- | 'deState', except the state is always the same even in stochastic
74 | -- mode.
75 | deStateD
76 | :: s
77 | -> Model p ('Just s) a b
78 | -> Model p 'Nothing a b
79 | deStateD s = deState s (const (pure s))
80 |
81 | -- | 'deState' with a constant state of 0.
82 | zeroState
83 | :: Num s
84 | => Model p ('Just s) a b
85 | -> Model p 'Nothing a b
86 | zeroState = deStateD 0
87 |
88 | -- | Unroll a (usually) stateful model into one taking a vector of
89 | -- sequential inputs.
90 | --
91 | -- Basically applies the model to every item of input and returns all of
92 | -- the results, but propagating the state between every step.
93 | --
94 | -- Useful when used before 'trainState' or 'deState'. See
95 | -- 'unrollTrainState' and 'unrollDeState'.
96 | --
97 | -- Compare to 'feedbackTrace', which, instead of receiving a vector of
98 | -- sequential inputs, receives a single input and uses its output as the
99 | -- next input.
100 | unroll
101 | :: (Traversable t, Backprop a, Backprop b)
102 | => Model p s a b
103 | -> Model p s (t a) (t b)
104 | unroll = withModelFunc $ \f p xs s ->
105 | (fmap . first) collectVar
106 | . flip runStateT s
107 | . traverse (StateT . f p)
108 | . sequenceVar
109 | $ xs
110 |
111 | -- | Version of 'unroll' that only keeps the "final" result, dropping all
112 | -- of the intermediate results.
113 | --
114 | -- Turns a stateful model into one that runs the model repeatedly on
115 | -- multiple inputs sequentially and outputs the final result after seeing
116 | -- all items.
117 | --
118 | -- Note: this will be partial if given an empty sequence.
119 | unrollFinal
120 | :: (Traversable t, Backprop a)
121 | => Model p s a b
122 | -> Model p s (t a) b
123 | unrollFinal = withModelFunc $ \f p xs s0 ->
124 | foldlM (\(_, s) x -> f p x s)
125 | (undefined, s0)
126 | (sequenceVar xs)
127 |
128 | -- | Fix a part of the input of a model to be (a function of) the
129 | -- /previous/ output of the model itself.
130 | --
131 | -- Essentially, turns a \( X \times Y \rightarrow Z \) into a /stateful/
132 | -- \( X \rightarrow Z \), where the Y is given by a function of the
133 | -- /previous output/ of the model.
134 | --
135 | -- Essentially makes a model "recurrent": it receives its previous output
136 | -- as input.
137 | --
138 | -- See 'fcr' for an application.
139 | recurrent
140 | :: forall p s ab a b c.
141 | -- ( KnownMayb s
142 | ( AllConstrainedProd Backprop s
143 | , PureProd Maybe s
144 | , Backprop a
145 | , Backprop b
146 | )
147 | => (ab -> (a, b)) -- ^ split
148 | -> (a -> b -> ab) -- ^ join
149 | -> BFunc c b -- ^ store state
150 | -> Model p s ab c
151 | -> Model p (s :#? 'Just b) a c
152 | recurrent spl joi sto = withModelFunc $ \f p x (s :#? y) -> do
153 | (z, s') <- f p (isoVar2 joi spl x (fromPJust y)) s
154 | pure (z, s' :#? PJust (sto z))
155 |
156 | -- | Give a stateless model a "dummy" state. For now, useful for using
157 | -- with combinators like 'deState' that require state. However, 'deState'
158 | -- could also be made more lenient (to accept non-stateful models) in the
159 | -- future.
160 | --
161 | -- Also useful with combinators like 'Control.Category..' from
162 | -- "Control.Category" that require all input models to share a common state.
163 | dummyState
164 | :: forall s p a b. ()
165 | => Model p 'Nothing a b
166 | -> Model p s a b
167 | dummyState = withModelFunc $ \f p x s ->
168 | (second . const) s <$> f p x PNothing
169 |
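170 | -- A minimal composition sketch (the model is hypothetical): to train
171 | -- a stateful sequence model on whole sequences, unroll it and then fix
172 | -- its initial state to zero:
173 | --
174 | --   trainable = zeroState (unroll myRNN)
175 | --     -- myRNN     :: Model p ('Just s) a b     (hypothetical)
176 | --     -- trainable :: Model p 'Nothing (t a) (t b)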
--------------------------------------------------------------------------------
/old/src/Numeric/BLAS/NVector.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE DeriveGeneric #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE GeneralizedNewtypeDeriving #-}
6 | {-# LANGUAGE InstanceSigs #-}
7 | {-# LANGUAGE KindSignatures #-}
8 | {-# LANGUAGE LambdaCase #-}
9 | {-# LANGUAGE ScopedTypeVariables #-}
10 | {-# LANGUAGE StandaloneDeriving #-}
11 | {-# LANGUAGE TypeApplications #-}
12 | {-# LANGUAGE TypeFamilyDependencies #-}
13 | {-# LANGUAGE TypeInType #-}
14 | {-# LANGUAGE TypeOperators #-}
15 | {-# LANGUAGE UndecidableInstances #-}
16 |
17 | module Numeric.BLAS.NVector (
18 | NV(..)
19 | , NV'
20 | ) where
21 |
22 | import Control.Applicative
23 | import Control.Monad
24 | import Control.Monad.Trans.State
25 | import Data.Finite.Internal
26 | import Data.Kind
27 | import Data.Maybe
28 | import Data.Monoid (Endo(..))
29 | import Data.Singletons
30 | import Data.Singletons.Prelude
31 | import Data.Singletons.TypeLits
32 | import Data.Type.Product
33 | import GHC.Generics (Generic)
34 | import GHC.TypeLits
35 | import Numeric.BLAS
36 | import Numeric.Tensor
37 | import qualified Data.Vector as UV
38 | import qualified Data.Vector.Sized as V
39 | import qualified Data.Vector.Storable as UVS
40 | import qualified Data.Vector.Storable.Sized as VS
41 |
42 |
43 | type family NV' (s :: [Nat]) = (h :: Type) | h -> s where
44 | NV' '[] = Double
45 | NV' (n ': ns) = V.Vector n (NV' ns)
46 |
47 | newtype NV :: [Nat] -> Type where
48 | NV :: { getNV :: NV' b }
49 | -> NV b
50 | deriving (Generic)
51 |
52 | deriving instance (Show (NV' a)) => Show (NV a)
53 |
54 | genNV :: Sing ns -> (Prod Finite ns -> Double) -> NV' ns
55 | genNV = \case
56 | SNil -> \f -> f Ø
57 | SNat `SCons` ss -> \f -> V.generate_ $ \i ->
58 | genNV ss (f . (i :<))
59 |
60 | genNVA
61 | :: Applicative f
62 | => Sing ns
63 | -> (Prod Finite ns -> f Double)
64 | -> f (NV' ns)
65 | genNVA = \case
66 | SNil -> \f -> f Ø
67 | SNat `SCons` ss -> \f -> sequenceA . V.generate_ $ \i ->
68 | genNVA ss (f . (i :<))
69 |
70 | sumNV
71 | :: Sing ns
72 | -> NV' ns
73 | -> Double
74 | sumNV = \case
75 | SNil -> id
76 | _ `SCons` ss -> V.sum . fmap (sumNV ss)
77 |
78 | mapNV
79 | :: Sing ns
80 | -> (Double -> Double)
81 | -> NV' ns
82 | -> NV' ns
83 | mapNV = \case
84 | SNil -> id
85 | _ `SCons` ss -> \f -> fmap (mapNV ss f)
86 |
87 | zipNV
88 | :: Sing ns
89 | -> (Double -> Double -> Double)
90 | -> NV' ns
91 | -> NV' ns
92 | -> NV' ns
93 | zipNV = \case
94 | SNil -> id
95 | _ `SCons` ss -> \f -> V.zipWith (zipNV ss f)
96 |
97 | indexNV
98 | :: Sing ns
99 | -> Prod Finite ns
100 | -> NV' ns
101 | -> Double
102 | indexNV = \case
103 | SNil -> \case
104 | Ø -> id
105 | SNat `SCons` ss -> \case
106 | i :< is -> indexNV ss is . flip V.index i
107 |
108 | loadNV
109 | :: Sing ns
110 | -> V.Vector (Product ns) Double
111 | -> NV' ns
112 | loadNV = \case
113 | SNil -> V.head
114 | sn@SNat `SCons` ss -> case sProduct ss of
115 | sp@SNat -> fromJust
116 | . V.fromList
117 | . evalState (replicateM (fromInteger (fromSing sn)) (
118 | loadNV ss . fromJust . V.toSized
119 | <$> state (UV.splitAt (fromInteger (fromSing sp)))
120 | ))
121 | . V.fromSized
122 |
123 | nvElems
124 | :: Sing ns
125 | -> NV' ns
126 | -> [Double]
127 | nvElems s n = appEndo (go s n) []
128 | where
129 | go :: Sing ms -> NV' ms -> Endo [Double]
130 | go = \case
131 | SNil -> \x -> Endo (x:)
132 | _ `SCons` ss -> foldMap (go ss)
133 |
134 | sliceNV
135 | :: ProdMap Slice ns ms
136 | -> NV' ns
137 | -> NV' ms
138 | sliceNV = \case
139 | PMZ -> id
140 | PMS (Slice sL sC@SNat _) pms ->
141 | let l = fromIntegral $ fromSing sL
142 | c = fromIntegral $ fromSing sC
143 | in fmap (sliceNV pms)
144 | . fromJust . V.toSized
145 | . UV.take c
146 | . UV.drop l
147 | . V.fromSized
148 |
149 | instance Tensor NV where
150 | type Scalar NV = Double
151 |
152 | gen s = NV . genNV s
153 | genA s = fmap NV . genNVA s
154 | tsum = sumNV sing . getNV
155 | tmap f = NV . mapNV sing f . getNV
156 | tzip f xs ys = NV $ zipNV sing f (getNV xs) (getNV ys)
157 |
158 | tindex i = indexNV sing i . getNV
159 |
160 | tload s = NV . loadNV s
161 | textract = withKnownNat (sProduct ss) $
162 | fromJust . V.fromList . nvElems ss . getNV
163 | where
164 | ss = sing
165 |
166 | tslice p = NV . sliceNV p . getNV
167 |
168 | instance BLAS NV where
169 | transp = NV . sequenceA . getNV
170 | scal α = NV . fmap (α *) . getNV
171 | axpy α (NV xs) (NV ys) = NV $ liftA2 (\x y -> α * x + y) xs ys
172 | dot (NV xs) (NV ys) = V.sum $ V.zipWith (*) xs ys
173 | norm2 = V.sum . fmap (**2) . getNV
174 | asum = V.sum . fmap abs . getNV
175 |
176 | iamax
177 | :: forall n. KnownNat n
178 | => NV '[n + 1]
179 | -> Finite (n + 1)
180 | iamax = withKnownNat (SNat @n %:+ SNat @1) $
181 | Finite . fromIntegral . UV.maxIndex . fmap abs . V.fromSized . getNV
182 |
183 | gemv α (NV a) (NV xs) b = maybe id (uncurry axpy) b
184 | . NV
185 | . fmap (V.sum . V.zipWith (\x -> (* (x * α))) xs)
186 | $ a
187 |
188 | ger α (NV xs) (NV ys) a = NV . addA $ fmap (\x -> fmap (* (x * α)) ys) xs
189 | where
190 | addA = case a of
191 | Nothing -> id
192 | Just (NV a') -> (V.zipWith . V.zipWith) (+) a'
193 |
194 | gemm α (NV ass) (NV bss) c = NV . addC $
195 | fmap (sumVs . V.zipWith (\bs a -> fmap (* (α * a)) bs) bss) ass
196 | where
197 | sumVs = V.foldl' (V.zipWith (+)) (V.generate (\_ -> 0))
198 | addC = case c of
199 | Nothing -> id
200 | Just (β, NV css) -> (V.zipWith . V.zipWith) (\c' -> (+ (β * c'))) css
201 |
202 |
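203 | -- A shape-resolution sketch: 'NV'' unfolds a shape into nested sized
204 | -- vectors, e.g.
205 | --
206 | --   NV' '[2, 3] ~ V.Vector 2 (V.Vector 3 Double)
207 | --
208 | -- so an @NV '[2, 3]@ is a 2-vector of 3-vectors of 'Double', and the
209 | -- BLAS operations above work by plain nested 'V.zipWith's.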
--------------------------------------------------------------------------------
/src/Backprop/Learn/Test.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE ApplicativeDo #-}
2 | {-# LANGUAGE DataKinds #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE PartialTypeSignatures #-}
5 | {-# LANGUAGE RankNTypes #-}
6 | {-# LANGUAGE ScopedTypeVariables #-}
7 | {-# LANGUAGE TypeApplications #-}
8 | {-# LANGUAGE TypeFamilies #-}
9 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-}
10 |
11 | module Backprop.Learn.Test (
12 | -- * Tests
13 | Test
14 | , maxIxTest, rmseTest
15 | , squaredErrorTest, absErrorTest, totalSquaredErrorTest, squaredErrorTestV
16 | , crossEntropyTest, crossEntropyTest1
17 | , boolTest
18 | -- ** Manipulate tests
19 | , lossTest, lmapTest
20 | -- * Run tests
21 | , testModel, testModelStoch, testModelAll, testModelStochAll
22 | -- ** Correlation tests
23 | , testModelCov, testModelCorr
24 | , testModelStochCov, testModelStochCorr
25 | ) where
26 |
27 | import Backprop.Learn.Loss
28 | import Backprop.Learn.Model
29 | import Control.Monad.Primitive
30 | import Data.Bifunctor
31 | import Data.Bitraversable
32 | import Data.Function
33 | import Data.Profunctor
34 | import Data.Proxy
35 | import GHC.TypeNats
36 | import Numeric.Backprop
37 | import qualified Control.Foldl as L
38 | import qualified Numeric.LinearAlgebra as HU
39 | import qualified Numeric.LinearAlgebra.Static as H
40 | import qualified System.Random.MWC as MWC
41 |
42 | -- TODO: support non-double results?
43 |
44 | type Test o = o -> o -> Double
45 |
46 | -- | Create a 'Test' from a 'Loss'
47 | lossTest :: Loss a -> Test a
48 | lossTest l x = evalBP (l x)
49 |
50 | boolTest :: forall a. RealFrac a => Test a
51 | boolTest x y
52 | | ri x == ri y = 1
53 | | otherwise = 0
54 | where
55 | ri :: a -> Int
56 | ri = round
57 |
58 | maxIxTest :: KnownNat n => Test (H.R n)
59 | maxIxTest x y
60 | | match x y = 1
61 | | otherwise = 0
62 | where
63 | match = (==) `on` (HU.maxIndex . H.extract)
64 |
65 | rmseTest :: forall n. KnownNat n => Test (H.R n)
66 | rmseTest x y = H.norm_2 (x - y) / sqrt (fromIntegral (natVal (Proxy @n)))
67 |
68 | squaredErrorTest :: Real a => Test a
69 | squaredErrorTest x y = e * e
70 | where
71 | e = realToFrac (x - y)
72 |
73 | absErrorTest :: Real a => Test a
74 | absErrorTest x y = realToFrac . abs $ x - y
75 |
76 | totalSquaredErrorTest :: (Applicative t, Foldable t, Real a) => Test (t a)
77 | totalSquaredErrorTest x y = realToFrac (sum e)
78 | where
79 | e = do
80 | x' <- x
81 | y' <- y
82 | pure ((x' - y') ^ (2 :: Int))
83 |
84 | squaredErrorTestV :: KnownNat n => Test (H.R n)
85 | squaredErrorTestV x y = e `H.dot` e
86 | where
87 | e = x - y
88 |
89 | crossEntropyTest :: KnownNat n => Test (H.R n)
90 | crossEntropyTest targ res = -(log res H.<.> targ)
91 |
92 | crossEntropyTest1 :: Test Double
93 | crossEntropyTest1 targ res = -(log res * targ + log (1 - res) * (1 - targ))
94 |
95 | lmapTest
96 | :: (a -> b)
97 | -> Test b
98 | -> Test a
99 | lmapTest f t x y = t (f x) (f y)
100 |
101 | testModel
102 | :: Test b
103 | -> Model p 'Nothing a b
104 | -> TMaybe p
105 | -> a
106 | -> b
107 | -> Double
108 | testModel t f mp x y = t y $ runModelStateless f mp x
109 |
110 | testModelStoch
111 | :: PrimMonad m
112 | => Test b
113 | -> Model p 'Nothing a b
114 | -> MWC.Gen (PrimState m)
115 | -> TMaybe p
116 | -> a
117 | -> b
118 | -> m Double
119 | testModelStoch t f g mp x y = t y <$> runModelStochStateless f g mp x
120 |
121 | cov :: Fractional a => L.Fold (a, a) a
122 | cov = do
123 | x <- lmap fst L.sum
124 | y <- lmap snd L.sum
125 | xy <- lmap (uncurry (*)) L.sum
126 | n <- fromIntegral <$> L.length
127 | pure (xy / n - (x * y) / n / n)
128 |
129 | corr :: Floating a => L.Fold (a, a) a
130 | corr = do
131 | x <- lmap fst L.sum
132 | x2 <- lmap ((**2) . fst) L.sum
133 | y <- lmap snd L.sum
134 | y2 <- lmap ((**2) . snd) L.sum
135 | xy <- lmap (uncurry (*)) L.sum
136 | n <- fromIntegral <$> L.length
137 | pure $ (xy / n - (x * y) / n / n)
138 | / sqrt ( x2 / n - (x / n)**2 )
139 | / sqrt ( y2 / n - (y / n)**2 )
140 |
141 | testModelCov
142 | :: (Foldable t, Fractional b)
143 | => Model p 'Nothing a b
144 | -> TMaybe p
145 | -> t (a, b)
146 | -> b
147 | testModelCov f p = L.fold $ (lmap . first) (runModelStateless f p) cov
148 |
149 | testModelCorr
150 | :: (Foldable t, Floating b)
151 | => Model p 'Nothing a b
152 | -> TMaybe p
153 | -> t (a, b)
154 | -> b
155 | testModelCorr f p = L.fold $ (lmap . first) (runModelStateless f p) corr
156 |
157 | testModelAll
158 | :: Foldable t
159 | => Test b
160 | -> Model p 'Nothing a b
161 | -> TMaybe p
162 | -> t (a, b)
163 | -> Double
164 | testModelAll t f p = L.fold $ lmap (uncurry (testModel t f p)) L.mean
165 |
166 | -- newtype M m a = M { getM :: m a }
167 | -- instance (Semigroup a, Applicative m) => Semigroup (M m a) where
168 | -- M x <> M y = M $ liftA2 (<>) x y
169 | -- instance (Monoid a, Applicative m) => Monoid (M m a) where
170 | -- mappend = (<>)
171 | -- mempty = M (pure mempty)
172 |
173 | testModelStochAll
174 | :: (Foldable t, PrimMonad m)
175 | => Test b
176 | -> Model p 'Nothing a b
177 | -> MWC.Gen (PrimState m)
178 | -> TMaybe p
179 | -> t (a, b)
180 | -> m Double
181 | testModelStochAll t f g p = L.foldM $ L.premapM (uncurry (testModelStoch t f g p))
182 | (L.generalize L.mean)
183 |
184 | testModelStochCov
185 | :: (Foldable t, PrimMonad m, Fractional b)
186 | => Model p 'Nothing a b
187 | -> MWC.Gen (PrimState m)
188 | -> TMaybe p
189 | -> t (a, b)
190 | -> m b
191 | testModelStochCov f g p = L.foldM $ (L.premapM . flip bitraverse pure)
192 | (runModelStochStateless f g p)
193 | (L.generalize cov)
194 |
195 | testModelStochCorr
196 | :: (Foldable t, PrimMonad m, Floating b)
197 | => Model p 'Nothing a b
198 | -> MWC.Gen (PrimState m)
199 | -> TMaybe p
200 | -> t (a, b)
201 | -> m b
202 | testModelStochCorr f g p = L.foldM $ (L.premapM . flip bitraverse pure)
203 | (runModelStochStateless f g p)
204 | (L.generalize corr)
205 |
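206 | -- A minimal usage sketch (model, parameters, and data are hypothetical
207 | -- values): score a classifier over a whole test set by its mean argmax
208 | -- accuracy:
209 | --
210 | -- >>> testModelAll maxIxTest myModel trainedParams testSet
211 | --
212 | -- where @myModel :: Model p 'Nothing a (H.R 10)@,
213 | -- @trainedParams :: TMaybe p@, and @testSet :: [(a, H.R 10)]@.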
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Model/Stochastic.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DeriveDataTypeable #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | {-# LANGUAGE PatternSynonyms #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE RecordWildCards #-}
10 | {-# LANGUAGE TypeFamilies #-}
11 | {-# LANGUAGE TypeInType #-}
12 | {-# LANGUAGE UndecidableInstances #-}
13 | {-# LANGUAGE ViewPatterns #-}
14 |
15 | module Backprop.Learn.Model.Stochastic (
16 | DO(..)
17 | , StochFunc(..)
18 | , FixedStochFunc, pattern FSF, _fsfRunDeterm, _fsfRunStoch
19 | , rreLU
20 | , injectNoise, applyNoise
21 | , injectNoiseR, applyNoiseR
22 | ) where
23 |
24 | import Backprop.Learn.Model.Class
25 | import Backprop.Learn.Model.Function
26 | import Control.Monad.Primitive
27 | import Data.Bool
28 | import Data.Kind
29 | import Data.Typeable
30 | import GHC.TypeNats
31 | import Numeric.Backprop
32 | import Numeric.LinearAlgebra.Static.Backprop
33 | import Numeric.LinearAlgebra.Static.Vector
34 | import qualified Data.Vector.Storable.Sized as SVS
35 | import qualified Statistics.Distribution as Stat
36 | import qualified System.Random.MWC as MWC
37 | import qualified System.Random.MWC.Distributions as MWC
38 |
39 | -- | Dropout layer. Parameterized by dropout percentage (should be between
40 | -- 0 and 1).
41 | --
42 | -- 0 corresponds to no dropout, 1 corresponds to complete dropout of all
43 | -- nodes every time.
44 | newtype DO (n :: Nat) = DO { _doRate :: Double }
45 | deriving (Typeable)
46 |
47 | instance KnownNat n => Learn (R n) (R n) (DO n) where
48 | runLearn (DO r) _ = stateless (constVar (realToFrac (1-r)) *)
49 | runLearnStoch (DO r) g _ = statelessM $ \x ->
50 | (x *) . constVar . vecR <$> SVS.replicateM (mask g)
51 | where
52 | mask = fmap (bool 1 0) . MWC.bernoulli r
53 |
54 | -- | Represents a random-valued function, with an optional trainable
55 | -- parameter.
56 | --
57 | -- Requires both a "deterministic" and a "stochastic" mode. The
58 | -- deterministic mode ideally should approximate some mean of the
59 | -- stochastic mode.
60 | data StochFunc :: Maybe Type -> Type -> Type -> Type where
61 | SF :: { _sfRunDeterm :: forall s. Reifies s W => Mayb (BVar s) p -> BVar s a -> BVar s b
62 | , _sfRunStoch
63 | :: forall m s. (PrimMonad m, Reifies s W)
64 | => MWC.Gen (PrimState m)
65 | -> Mayb (BVar s) p
66 | -> BVar s a
67 | -> m (BVar s b)
68 | }
69 | -> StochFunc p a b
70 | deriving (Typeable)
71 |
72 | instance Learn a b (StochFunc p a b) where
73 | type LParamMaybe (StochFunc p a b) = p
74 | type LStateMaybe (StochFunc p a b) = 'Nothing
75 |
76 | runLearn SF{..} = stateless . _sfRunDeterm
77 | runLearnStoch SF{..} g = statelessM . _sfRunStoch g
78 |
79 | -- | Convenient alias for a 'StochFunc' (random-valued function with both
80 | -- deterministic and stochastic modes) with no trained parameters.
81 | type FixedStochFunc = StochFunc 'Nothing
82 |
83 | -- | Construct a 'FixedStochFunc'
84 | pattern FSF :: (forall s. Reifies s W => BVar s a -> BVar s b)
85 | -> (forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> BVar s a -> m (BVar s b))
86 | -> FixedStochFunc a b
87 | pattern FSF { _fsfRunDeterm, _fsfRunStoch } <- (getFSF->(getWD->_fsfRunDeterm,getWS->_fsfRunStoch))
88 | where
89 | FSF d s = SF { _sfRunDeterm = const d
90 | , _sfRunStoch = const . s
91 | }
92 | {-# COMPLETE FSF #-}
93 |
94 | newtype WrapDeterm a b = WD { getWD :: forall s. Reifies s W => BVar s a -> BVar s b }
95 | newtype WrapStoch a b = WS { getWS :: forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> BVar s a -> m (BVar s b) }
96 |
97 | getFSF :: FixedStochFunc a b -> (WrapDeterm a b, WrapStoch a b)
98 | getFSF SF{..} = ( WD (_sfRunDeterm N_)
99 | , WS (`_sfRunStoch` N_)
100 | )
101 |
102 | -- | Random leaky rectified linear unit
103 | rreLU
104 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
105 | => d
106 | -> FixedStochFunc (R n) (R n)
107 | rreLU d = FSF { _fsfRunDeterm = vmap' (preLU v)
108 | , _fsfRunStoch = \g x -> do
109 | α <- vecR <$> SVS.replicateM (Stat.genContVar d g)
110 | pure (zipWithVector preLU (constVar α) x)
111 | }
112 | where
113 | v :: BVar s Double
114 | v = constVar (Stat.mean d)
115 |
116 | -- | Inject random noise. Usually used between neural network layers, or
117 | -- at the very beginning to pre-process input.
118 | --
119 | -- In non-stochastic mode, this adds the mean of the distribution.
120 | injectNoise
121 | :: (Stat.ContGen d, Stat.Mean d, Fractional a)
122 | => d
123 | -> FixedStochFunc a a
124 | injectNoise d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) +)
125 | , _fsfRunStoch = \g x -> do
126 | e <- Stat.genContVar d g
127 | pure (realToFrac e + x)
128 | }
129 |
130 |
131 | -- | 'injectNoise' lifted to 'R'
132 | injectNoiseR
133 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
134 | => d
135 | -> FixedStochFunc (R n) (R n)
136 | injectNoiseR d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) +)
137 | , _fsfRunStoch = \g x -> do
138 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g)
139 | pure (constVar e + x)
140 | }
141 |
142 | -- | Multiply by random noise. Can be used to implement dropout-like
143 | -- behavior.
144 | --
145 | -- In non-stochastic mode, this scales by the mean of the distribution.
146 | applyNoise
147 | :: (Stat.ContGen d, Stat.Mean d, Fractional a)
148 | => d
149 | -> FixedStochFunc a a
150 | applyNoise d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) *)
151 | , _fsfRunStoch = \g x -> do
152 | e <- Stat.genContVar d g
153 | pure (realToFrac e * x)
154 | }
155 |
156 | -- | 'applyNoise' lifted to 'R'
157 | applyNoiseR
158 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
159 | => d
160 | -> FixedStochFunc (R n) (R n)
161 | applyNoiseR d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) *)
162 | , _fsfRunStoch = \g x -> do
163 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g)
164 | pure (constVar e * x)
165 | }
166 |
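167 | -- A minimal usage sketch: 'injectNoiseR' with a normal distribution
168 | -- from the statistics package adds zero-mean gaussian noise in
169 | -- stochastic mode, and (because the mean is zero) nothing at all in
170 | -- deterministic mode:
171 | --
172 | -- >>> import Statistics.Distribution.Normal (normalDistr)
173 | -- >>> let noisy = injectNoiseR (normalDistr 0 0.1) :: FixedStochFunc (R 100) (R 100)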
--------------------------------------------------------------------------------
/src/Backprop/Learn/Initialize.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE ConstraintKinds #-}
3 | {-# LANGUAGE DataKinds #-}
4 | {-# LANGUAGE DefaultSignatures #-}
5 | {-# LANGUAGE FlexibleContexts #-}
6 | {-# LANGUAGE FlexibleInstances #-}
7 | {-# LANGUAGE GADTs #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE ScopedTypeVariables #-}
10 | {-# LANGUAGE TypeApplications #-}
11 | {-# LANGUAGE TypeFamilies #-}
12 | {-# LANGUAGE TypeOperators #-}
13 | {-# LANGUAGE UndecidableInstances #-}
14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
15 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
16 |
17 | module Backprop.Learn.Initialize (
18 | Initialize(..)
19 | , gInitialize
20 | , initializeNormal
21 | , initializeSingle
22 | -- * Reshape
23 | , reshapeR
24 | , reshapeLRows
25 | , reshapeLCols
26 | ) where
27 |
28 | import Control.Monad.Primitive
29 | import Data.Complex
30 | import Data.Proxy
31 | import Data.Type.Equality
32 | import Data.Type.Tuple
33 | import Data.Vinyl
34 | import GHC.TypeLits.Compare
35 | import GHC.TypeNats
36 | import Generics.OneLiner
37 | import Numeric.LinearAlgebra.Static.Vector
38 | import Statistics.Distribution
39 | import Statistics.Distribution.Normal
40 | import qualified Data.Vector.Generic as VG
41 | import qualified Data.Vector.Generic.Sized as SVG
42 | import qualified Numeric.LinearAlgebra.Static as H
43 | import qualified System.Random.MWC as MWC
44 |
45 | -- | Class for types that are basically a bunch of 'Double's, which can be
46 | -- initialized with i.i.d. draws from a given distribution.
47 | class Initialize p where
48 | initialize
49 | :: (ContGen d, PrimMonad m)
50 | => d
51 | -> MWC.Gen (PrimState m)
52 | -> m p
53 |
54 | default initialize
55 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m)
56 | => d
57 | -> MWC.Gen (PrimState m)
58 | -> m p
59 | initialize = gInitialize
60 |
61 | -- | 'initialize' for any 'ADTRecord' instance, via "Generics.OneLiner".
62 | gInitialize
63 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m)
64 | => d
65 | -> MWC.Gen (PrimState m)
66 | -> m p
67 | gInitialize d g = createA' @Initialize (initialize d g)
68 |
69 | -- | Helper over 'initialize' for a Gaussian distribution centered around
70 | -- zero.
71 | initializeNormal
72 | :: (Initialize p, PrimMonad m)
73 | => Double -- ^ standard deviation
74 | -> MWC.Gen (PrimState m)
75 | -> m p
76 | initializeNormal = initialize . normalDistr 0
77 |
78 | -- | 'initialize' definition if @p@ is a single number.
79 | initializeSingle
80 | :: (ContGen d, PrimMonad m, Fractional p)
81 | => d
82 | -> MWC.Gen (PrimState m)
83 | -> m p
84 | initializeSingle d = fmap realToFrac . genContVar d
85 |
86 | instance Initialize Double where
87 | initialize = initializeSingle
88 | instance Initialize Float where
89 | initialize = initializeSingle
90 |
91 | -- | Initializes real and imaginary components identically
92 | instance Initialize a => Initialize (Complex a) where
93 |
94 | instance Initialize T0
95 | instance Initialize a => Initialize (TF a)
96 | instance (Initialize a, Initialize b) => Initialize (a :# b)
97 |
98 | instance RPureConstrained Initialize as => Initialize (T as) where
99 | initialize d g = rtraverse (fmap TF)
100 | $ rpureConstrained @Initialize (initialize d g)
101 |
102 | -- instance (Initialize a, ListC (Initialize <$> as), Known Length as) => Initialize (NETup (a ':| as)) where
103 | -- initialize d g = NET <$> initialize d g
104 | -- <*> initialize d g
105 |
106 | instance Initialize ()
107 | instance (Initialize a, Initialize b) => Initialize (a, b)
108 | instance (Initialize a, Initialize b, Initialize c) => Initialize (a, b, c)
109 | instance (Initialize a, Initialize b, Initialize c, Initialize d) => Initialize (a, b, c, d)
110 | instance (Initialize a, Initialize b, Initialize c, Initialize d, Initialize e) => Initialize (a, b, c, d, e)
111 |
112 | instance (VG.Vector v a, KnownNat n, Initialize a) => Initialize (SVG.Vector v n a) where
113 | initialize d = SVG.replicateM . initialize d
114 |
115 | instance KnownNat n => Initialize (H.R n) where
116 | initialize d = fmap vecR . initialize d
117 | instance KnownNat n => Initialize (H.C n) where
118 | initialize d = fmap vecC . initialize d
119 |
120 | instance (KnownNat n, KnownNat m) => Initialize (H.L n m) where
121 | initialize d = fmap vecL . initialize d
122 | instance (KnownNat n, KnownNat m) => Initialize (H.M n m) where
123 | initialize d = fmap vecM . initialize d
124 |
125 | -- | Reshape a vector to have a different number of items. If the vector is
126 | -- grown, new weights are initialized according to the given distribution.
127 | reshapeR
128 | :: forall i j d m. (ContGen d, PrimMonad m, KnownNat i, KnownNat j)
129 | => d
130 | -> MWC.Gen (PrimState m)
131 | -> H.R i
132 | -> m (H.R j)
133 | reshapeR d g x = case Proxy @j %<=? Proxy @i of
134 | LE Refl -> pure . vecR . SVG.take @_ @j @(i - j) . rVec $ x
135 | NLE Refl Refl -> (x H.#) <$> initialize @(H.R (j - i)) d g
136 |
137 | -- | Reshape a matrix to have a different number of rows. If the matrix
138 | -- is grown, new weights are initialized according to the given
139 | -- distribution.
140 | reshapeLRows
141 | :: forall i j n d m. (ContGen d, PrimMonad m, KnownNat n, KnownNat i, KnownNat j)
142 | => d
143 | -> MWC.Gen (PrimState m)
144 | -> H.L i n
145 | -> m (H.L j n)
146 | reshapeLRows d g x = case Proxy @j %<=? Proxy @i of
147 | LE Refl -> pure . rowsL . SVG.take @_ @j @(i - j) . lRows $ x
148 | NLE Refl Refl -> (x H.===) <$> initialize @(H.L (j - i) n) d g
149 |
150 | -- | Reshape a matrix to have a different number of columns. If the matrix
151 | -- is grown, new weights are initialized according to the given
152 | -- distribution.
153 | reshapeLCols
154 | :: forall i j n d m. (ContGen d, PrimMonad m, KnownNat n, KnownNat i, KnownNat j)
155 | => d
156 | -> MWC.Gen (PrimState m)
157 | -> H.L n i
158 | -> m (H.L n j)
159 | reshapeLCols d g x = case Proxy @j %<=? Proxy @i of
160 | LE Refl -> pure . colsL . SVG.take @_ @j @(i - j) . lCols $ x
161 | NLE Refl Refl -> (x H.|||) <$> initialize @(H.L n (j - i)) d g
162 |
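163 | -- A minimal usage sketch: any 'Initialize' instance is drawn the same
164 | -- way; here, a (hypothetical) 784-to-10 weight matrix initialized from
165 | -- a zero-centered gaussian with standard deviation 0.1:
166 | --
167 | -- >>> g <- MWC.createSystemRandom
168 | -- >>> w <- initializeNormal 0.1 g :: IO (H.L 10 784)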
--------------------------------------------------------------------------------
/old2/src/Data/Type/Mayb.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE ConstraintKinds #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE FlexibleInstances #-}
5 | {-# LANGUAGE GADTs #-}
6 | {-# LANGUAGE KindSignatures #-}
7 | {-# LANGUAGE LambdaCase #-}
8 | {-# LANGUAGE MultiParamTypeClasses #-}
9 | {-# LANGUAGE PatternSynonyms #-}
10 | {-# LANGUAGE RankNTypes #-}
11 | {-# LANGUAGE ScopedTypeVariables #-}
12 | {-# LANGUAGE StandaloneDeriving #-}
13 | {-# LANGUAGE TypeApplications #-}
14 | {-# LANGUAGE TypeFamilies #-}
15 | {-# LANGUAGE TypeFamilyDependencies #-}
16 | {-# LANGUAGE TypeInType #-}
17 | {-# LANGUAGE TypeOperators #-}
18 | {-# LANGUAGE UndecidableInstances #-}
19 |
20 | module Data.Type.Mayb (
21 | MaybeC, MaybeToList, ListToMaybe
22 | , Mayb(.., J_I), fromJ_, maybToList, listToMayb
23 | , P(..), KnownMayb, knownMayb
24 | , zipMayb
25 | , zipMayb3
26 | , FromJust
27 | , MaybeWit(..), type (<$>)
28 | , TupMaybe, splitTupMaybe, tupMaybe
29 | , BoolMayb, boolMayb
30 | ) where
31 |
32 | import Data.Bifunctor
33 | import Data.Kind
34 | import Data.Type.Boolean
35 | import Data.Type.Combinator
36 | import Data.Type.Product
37 | import Data.Type.Tuple
38 | import Type.Class.Higher
39 | import Type.Class.Known
40 | import Type.Class.Witness
41 | import Type.Family.Maybe (type (<$>))
42 | import qualified GHC.TypeLits as TL
43 |
44 | type family MaybeC (c :: k -> Constraint) (m :: Maybe k) :: Constraint where
45 | MaybeC c ('Just a) = c a
46 | MaybeC c 'Nothing = ()
47 |
48 | type family MaybeToList (m :: Maybe k) = (l :: [k]) | l -> m where
49 | MaybeToList 'Nothing = '[]
50 | MaybeToList ('Just a) = '[a]
51 |
52 | -- type family ConsMaybe (ml :: (Maybe k, [k])) = (l :: (Bool, [k])) | l -> ml where
53 | -- ConsMaybe '( 'Nothing, as) = '( 'False, as )
54 | -- ConsMaybe '( 'Just a , as) = '( 'True , a ': as )
55 |
56 | maybToList
57 | :: Mayb f m
58 | -> Prod f (MaybeToList m)
59 | maybToList N_ = Ø
60 | maybToList (J_ x) = x :< Ø
61 |
62 | type family ListToMaybe (l :: [k]) :: Maybe k where
63 | ListToMaybe '[] = 'Nothing
64 | ListToMaybe (a ': as) = 'Just a
65 |
66 | listToMayb
67 | :: Prod f as
68 | -> Mayb f (ListToMaybe as)
69 | listToMayb Ø = N_
70 | listToMayb (x :< _) = J_ x
71 |
72 | class MaybeWit (c :: k -> Constraint) (m :: Maybe k) where
73 | maybeWit :: Mayb (Wit1 c) m
74 |
75 | instance (MaybeC c m, Known (Mayb P) m) => MaybeWit c m where
76 | maybeWit = case known @_ @(Mayb P) @m of
77 | J_ _ -> J_ Wit1
78 | N_ -> N_
79 |
80 | type KnownMayb = Known (Mayb P)
81 |
82 | knownMayb :: KnownMayb p => Mayb P p
83 | knownMayb = known
84 |
85 | data Mayb :: (k -> Type) -> Maybe k -> Type where
86 | N_ :: Mayb f 'Nothing
87 | J_ :: !(f a) -> Mayb f ('Just a)
88 |
89 | deriving instance MaybeC Show (f <$> m) => Show (Mayb f m)
90 |
91 | data P :: k -> Type where
92 | P :: P a
93 |
94 | fromJ_ :: Mayb f ('Just a) -> f a
95 | fromJ_ (J_ x) = x
96 |
97 | pattern J_I :: a -> Mayb I ('Just a)
98 | pattern J_I x = J_ (I x)
99 |
100 | instance Known P k where
101 | known = P
102 |
103 | instance Known (Mayb f) 'Nothing where
104 | known = N_
105 |
106 | instance Known f a => Known (Mayb f) ('Just a) where
107 | type KnownC (Mayb f) ('Just a) = Known f a
108 | known = J_ known
109 |
110 | instance Functor1 Mayb where
111 | map1 f (J_ x) = J_ (f x)
112 | map1 _ N_ = N_
113 |
114 | zipMayb
115 | :: (forall a. f a -> g a -> h a)
116 | -> Mayb f m
117 | -> Mayb g m
118 | -> Mayb h m
119 | zipMayb f (J_ x) (J_ y) = J_ (f x y)
120 | zipMayb _ N_ N_ = N_
121 |
122 | zipMayb3
123 | :: (forall a. f a -> g a -> h a -> i a)
124 | -> Mayb f m
125 | -> Mayb g m
126 | -> Mayb h m
127 | -> Mayb i m
128 | zipMayb3 f (J_ x) (J_ y) (J_ z) = J_ (f x y z)
129 | zipMayb3 _ N_ N_ N_ = N_
130 |
131 | m2 :: forall c f m. MaybeC c (f <$> m)
132 | => (forall a. c (f a) => f a -> f a -> f a)
133 | -> Mayb f m
134 | -> Mayb f m
135 | -> Mayb f m
136 | m2 f (J_ x) (J_ y) = J_ (f x y)
137 | m2 _ N_ N_ = N_
138 |
139 | m1 :: forall c f m. MaybeC c (f <$> m)
140 | => (forall a. c (f a) => f a -> f a)
141 | -> Mayb f m
142 | -> Mayb f m
143 | m1 f (J_ x) = J_ (f x)
144 | m1 _ N_ = N_
145 |
146 | m0 :: forall c f m. (MaybeC c (f <$> m))
147 | => (forall a. c (f a) => f a)
148 | -> Mayb P m
149 | -> Mayb f m
150 | m0 x (J_ _) = J_ x
151 | m0 _ N_ = N_
152 |
153 | instance (Known (Mayb P) m, MaybeC Num (f <$> m)) => Num (Mayb f m) where
154 | (+) = m2 @Num (+)
155 | (-) = m2 @Num (-)
156 | (*) = m2 @Num (*)
157 | negate = m1 @Num negate
158 | abs = m1 @Num abs
159 | signum = m1 @Num signum
160 | fromInteger x = m0 @Num (fromInteger x) known
161 |
162 | instance (Known (Mayb P) m, MaybeC Num (f <$> m), MaybeC Fractional (f <$> m))
163 | => Fractional (Mayb f m) where
164 | (/) = m2 @Fractional (/)
165 | recip = m1 @Fractional recip
166 | fromRational x = m0 @Fractional (fromRational x) known
167 |
168 | type family FromJust (d :: TL.ErrorMessage) (m :: Maybe k) :: k where
169 | FromJust e ('Just a) = a
170 | FromJust e 'Nothing = TL.TypeError e
171 |
172 | type family TupMaybe (a :: Maybe Type) (b :: Maybe Type) :: Maybe Type where
173 | TupMaybe 'Nothing 'Nothing = 'Nothing
174 | TupMaybe 'Nothing ('Just b) = 'Just b
175 | TupMaybe ('Just a) 'Nothing = 'Just a
176 | TupMaybe ('Just a) ('Just b) = 'Just (T2 a b)
177 |
178 | tupMaybe
179 | :: forall f a b. ()
180 | => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f a' -> f b' -> f (T2 a' b'))
181 | -> Mayb f a
182 | -> Mayb f b
183 | -> Mayb f (TupMaybe a b)
184 | tupMaybe f = \case
185 | N_ -> \case
186 | N_ -> N_
187 | J_ y -> J_ y
188 | J_ x -> \case
189 | N_ -> J_ x
190 | J_ y -> J_ (f x y)
191 |
192 | splitTupMaybe
193 | :: forall f a b. (KnownMayb a, KnownMayb b)
194 | => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f (T2 a' b') -> (f a', f b'))
195 | -> Mayb f (TupMaybe a b)
196 | -> (Mayb f a, Mayb f b)
197 | splitTupMaybe f = case knownMayb @a of
198 | N_ -> case knownMayb @b of
199 | N_ -> \case
200 | N_ -> (N_, N_)
201 | J_ _ -> \case
202 | J_ y -> (N_, J_ y)
203 | J_ _ -> case knownMayb @b of
204 | N_ -> \case
205 | J_ x -> (J_ x, N_)
206 | J_ _ -> \case
207 | J_ xy -> bimap J_ J_ . f $ xy
208 |
209 | type family BoolMayb (b :: Bool) = (m :: Maybe ()) | m -> b where
210 | BoolMayb 'False = 'Nothing
211 | BoolMayb 'True = 'Just '()
212 |
213 | boolMayb :: Boolean b -> Mayb P (BoolMayb b)
214 | boolMayb False_ = N_
215 | boolMayb True_ = J_ P
216 |
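217 | -- An editor's sketch, not part of the original module: 'Mayb' is to
218 | -- 'Maybe' what 'Prod' is to lists, holding either no value or exactly
219 | -- one value under @f@.  The names below are hypothetical examples.
220 | --
221 | -- > exJust :: Mayb I ('Just Int)
222 | -- > exJust = J_I 3
223 | -- >
224 | -- > exNothing :: Mayb I 'Nothing
225 | -- > exNothing = N_
226 | -- >
227 | -- > -- MaybeToList ('Just Int) reduces to '[Int]:
228 | -- > exProd :: Prod I '[Int]
229 | -- > exProd = maybToList exJust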
--------------------------------------------------------------------------------
/old2/src/Backprop/Learn/Model/Class.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE AllowAmbiguousTypes #-}
2 | {-# LANGUAGE ConstraintKinds #-}
3 | {-# LANGUAGE DataKinds #-}
4 | {-# LANGUAGE FlexibleContexts #-}
5 | {-# LANGUAGE FlexibleInstances #-}
6 | {-# LANGUAGE FunctionalDependencies #-}
7 | {-# LANGUAGE GADTs #-}
8 | {-# LANGUAGE KindSignatures #-}
9 | {-# LANGUAGE MultiParamTypeClasses #-}
10 | {-# LANGUAGE PatternSynonyms #-}
11 | {-# LANGUAGE ScopedTypeVariables #-}
12 | {-# LANGUAGE TupleSections #-}
13 | {-# LANGUAGE TypeFamilies #-}
14 | {-# LANGUAGE TypeFamilyDependencies #-}
15 | {-# LANGUAGE TypeInType #-}
16 | {-# LANGUAGE TypeOperators #-}
17 | {-# LANGUAGE UndecidableInstances #-}
18 | {-# LANGUAGE UndecidableSuperClasses #-}
19 |
20 | module Backprop.Learn.Model.Class (
21 | Learn(..)
22 | , LParam, LState, LParams, LStates, NoParam, NoState
23 | , LParam_, LState_
24 | , stateless, statelessM
25 | , runLearnStateless
26 | , runLearnStochStateless
27 | , Mayb(..), fromJ_, MaybeC, KnownMayb, knownMayb, I(..)
28 | , SomeLearn(..)
29 | ) where
30 |
31 | import Backprop.Learn.Initialize
32 | import Control.DeepSeq
33 | import Control.Monad.Primitive
34 | import Data.Kind
35 | import Data.Type.Mayb
36 | import Data.Typeable
37 | import Numeric.Backprop
38 | import Numeric.Opto.Update
39 | import Type.Family.List (type (++))
40 | import qualified GHC.TypeLits as TL
41 | import qualified System.Random.MWC as MWC
42 |
43 | -- | The trainable parameter type of a model. Will be a compile-time error
44 | -- if the model has no trainable parameters.
45 | type LParam l = FromJust
46 | ('TL.ShowType l 'TL.:<>: 'TL.Text " has no trainable parameters")
47 | (LParamMaybe l)
48 |
49 | -- | The state type of a model. Will be a compile-time error if the model
50 | -- has no state.
51 | type LState l = FromJust
52 |     ('TL.ShowType l 'TL.:<>: 'TL.Text " has no state")
53 | (LStateMaybe l)
54 |
55 | -- | Constraint specifying that a given model has no trainable parameters.
56 | type NoParam l = LParamMaybe l ~ 'Nothing
57 |
58 | -- | Constraint specifying that a given model has no state.
59 | type NoState l = LStateMaybe l ~ 'Nothing
60 |
61 | -- | Is 'N_' if @l@ has no trainable parameters; otherwise is 'J_'
62 | -- with @f p@, for trainable parameter type @p@.
63 | type LParam_ f l = Mayb f (LParamMaybe l)
64 |
65 | -- | Is 'N_' if @l@ has no state; otherwise is 'J_' with @f
66 | -- s@, for state type @s@.
67 | type LState_ f l = Mayb f (LStateMaybe l)
68 |
69 | -- | List of parameters of 'Learn' instances
70 | type family LParams (ls :: [Type]) :: [Type] where
71 | LParams '[] = '[]
72 | LParams (l ': ls) = MaybeToList (LParamMaybe l) ++ LParams ls
73 |
74 | -- | List of states of 'Learn' instances
75 | type family LStates (ls :: [Type]) :: [Type] where
76 | LStates '[] = '[]
77 | LStates (l ': ls) = MaybeToList (LStateMaybe l) ++ LStates ls
78 |
79 | -- | Class for models that can be trained using gradient descent
80 | --
81 | -- An instance @l@ of @'Learn' a b@ is parameterized by @p@, takes @a@ as
82 | -- input, and returns @b@ as output. @l@ can be thought of as a value
83 | -- containing the /hyperparameters/ of the model.
84 | class Learn a b l | l -> a b where
85 |
86 | -- | The trainable parameters of model @l@.
87 | --
88 | -- By default, is ''Nothing'. To give a type for learned parameters @p@,
89 |     -- use the type @''Just' p@.
90 | type LParamMaybe l :: Maybe Type
91 |
92 | -- | The type of the state of model @l@. Used for things like
93 | -- recurrent neural networks.
94 | --
95 | -- By default, is ''Nothing'. To give a type for state @s@, use the
96 | -- type @''Just' s@.
97 | --
98 |     -- Most models will not use state, and training algorithms will
99 |     -- only work if 'LStateMaybe' is ''Nothing'. However, models that
100 |     -- use state can be converted to stateless models using 'Unroll';
101 |     -- this can be done before training.
102 | type LStateMaybe l :: Maybe Type
103 |
104 | type LParamMaybe l = 'Nothing
105 | type LStateMaybe l = 'Nothing
106 |
107 | -- | Run the model itself, deterministically.
108 | --
109 | -- If your model has no state, you can define this conveniently using
110 | -- 'stateless'.
111 | runLearn
112 | :: Reifies s W
113 | => l
114 | -> LParam_ (BVar s) l
115 | -> BVar s a
116 | -> LState_ (BVar s) l
117 | -> (BVar s b, LState_ (BVar s) l)
118 |
119 | -- | Run a model in stochastic mode.
120 | --
121 |     -- If the model is inherently non-stochastic, a default implementation is
122 | -- given in terms of 'runLearn'.
123 | --
124 | -- If your model has no state, you can define this conveniently using
125 |     -- 'statelessM'.
126 | runLearnStoch
127 | :: (Reifies s W, PrimMonad m)
128 | => l
129 | -> MWC.Gen (PrimState m)
130 | -> LParam_ (BVar s) l
131 | -> BVar s a
132 | -> LState_ (BVar s) l
133 | -> m (BVar s b, LState_ (BVar s) l)
134 | runLearnStoch l _ p x s = pure (runLearn l p x s)
135 |
136 | -- | Useful for defining 'runLearn' if your model has no state.
137 | stateless
138 | :: (a -> b)
139 | -> (a -> s -> (b, s))
140 | stateless f x = (f x,)
141 |
142 | -- | Useful for defining 'runLearnStoch' if your model has no state.
143 | statelessM
144 | :: Functor m
145 | => (a -> m b)
146 | -> (a -> s -> m (b, s))
147 | statelessM f x s = (, s) <$> f x
148 |
149 | runLearnStateless
150 | :: (Learn a b l, Reifies s W, NoState l)
151 | => l
152 | -> LParam_ (BVar s) l
153 | -> BVar s a
154 | -> BVar s b
155 | runLearnStateless l p = fst . flip (runLearn l p) N_
156 |
157 | runLearnStochStateless
158 | :: (Learn a b l, Reifies s W, NoState l, PrimMonad m)
159 | => l
160 | -> MWC.Gen (PrimState m)
161 | -> LParam_ (BVar s) l
162 | -> BVar s a
163 | -> m (BVar s b)
164 | runLearnStochStateless l g p = fmap fst . flip (runLearnStoch l g p) N_
165 |
166 | -- | Existential wrapper for a learnable model, representing a trainable
167 | -- function from @a@ to @b@.
168 | data SomeLearn :: Type -> Type -> Type where
169 | SL :: ( Learn a b l
170 | , Typeable l
171 | , KnownMayb (LParamMaybe l)
172 | , KnownMayb (LStateMaybe l)
173 | , MaybeC Floating (LParamMaybe l)
174 | , MaybeC Floating (LStateMaybe l)
175 | , MaybeC (Metric Double) (LParamMaybe l)
176 | , MaybeC (Metric Double) (LStateMaybe l)
177 | , MaybeC NFData (LParamMaybe l)
178 | , MaybeC NFData (LStateMaybe l)
179 | , MaybeC Initialize (LParamMaybe l)
180 | , MaybeC Initialize (LStateMaybe l)
181 | )
182 | => l
183 | -> SomeLearn a b
184 |
185 |
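186 | -- An editor's sketch, not part of the original module: the simplest
187 | -- possible instance, a hypothetical model with no parameters and no
188 | -- state that doubles its input.  Both associated types stay at their
189 | -- ''Nothing' defaults, so only 'runLearn' is needed.
190 | --
191 | -- > data Doubler = Doubler
192 | -- >
193 | -- > instance Learn Double Double Doubler where
194 | -- >     runLearn Doubler N_ = stateless (* 2)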
--------------------------------------------------------------------------------
/old/app/Language.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE GADTs #-}
3 | {-# LANGUAGE LambdaCase #-}
4 | {-# LANGUAGE ScopedTypeVariables #-}
5 | {-# LANGUAGE TypeApplications #-}
6 | {-# LANGUAGE TypeOperators #-}
7 | {-# LANGUAGE ViewPatterns #-}
8 | {-# OPTIONS_GHC -fno-warn-orphans #-}
9 |
10 | -- import Data.Type.Combinator
11 | import Control.DeepSeq
12 | import Control.Exception
13 | import Control.Monad.IO.Class
14 | import Control.Monad.Primitive
15 | import Control.Monad.Trans.State
16 | import Data.Bifunctor
17 | import Data.Char
18 | import Data.Default
19 | import Data.Finite
20 | import Data.Foldable
21 | import Data.List
22 | import Data.List.Split
23 | import Data.Maybe
24 | import Data.Ord
25 | import Data.Time.Clock
26 | import Data.Type.Product hiding (toList, head')
27 | import Data.Type.Vector
28 | import GHC.TypeLits hiding (type (+))
29 | import Learn.Neural
30 | import Learn.Neural.Layer.Recurrent.FullyConnected
31 | import Learn.Neural.Layer.Recurrent.LSTM
32 | import Numeric.BLAS.HMatrix
33 | import Numeric.Backprop.Op hiding (head')
34 | import Text.Printf hiding (toChar, fromChar)
35 | import Type.Class.Known
36 | import Type.Family.Nat
37 | import qualified Data.List.NonEmpty as NE
38 | import qualified Data.Type.Nat as TCN
39 | import qualified Data.Vector as V
40 | import qualified Data.Vector.Sized as VSi
41 | import qualified System.Random.MWC as MWC
42 | import qualified System.Random.MWC.Distributions as MWC
43 |
44 | type ASCII = Finite 128
45 |
46 | fromChar :: Char -> Maybe ASCII
47 | fromChar = packFinite . fromIntegral . ord
48 |
49 | toChar :: ASCII -> Char
50 | toChar = chr . fromIntegral
51 |
52 | charOneHot :: Tensor t => Char -> Maybe (t '[128])
53 | charOneHot = fmap (oneHot . only) . fromChar
54 |
55 | oneHotChar :: BLAS t => t '[128] -> Char
56 | oneHotChar = toChar . iamax
57 |
58 | charRank :: Tensor t => t '[128] -> [Char]
59 | charRank = map fst . sortBy (flip (comparing snd)) . zip ['\0'..] . toList . textract
60 |
61 | main :: IO ()
62 | main = MWC.withSystemRandom $ \g -> do
63 | holmes <- evaluate . force . mapMaybe (charOneHot @HM)
64 | =<< readFile "data/holmes.txt"
65 | putStrLn "Loaded data"
66 | -- let slices_ :: [(Vec N4 (HM '[128]), HM '[128])]
67 | -- slices_ = slidingPartsLast known . asFeedback $ holmes
68 | let slices_ :: [(Vec N8 (HM '[128]), Vec N8 (HM '[128]))]
69 | slices_ = slidingPartsSplit known . asFeedback $ holmes
70 | slices <- evaluate . force $ slices_
71 | putStrLn "Processed data"
72 | let opt0 = batching $ adamOptimizer def (netOpRecurrent_ known)
73 | (sumLossDecay crossEntropy known α)
74 | -- net0 :: Network 'Recurrent HM ( '[128] :~ FullyConnectedR' 'MF_Logit )
75 | -- '[ '[64 ] :~ ReLUMap
76 | -- , '[64 ] :~ FullyConnectedR' 'MF_Logit
77 | -- , '[32 ] :~ ReLUMap
78 | -- , '[32 ] :~ FullyConnected
79 | -- , '[128] :~ SoftMax '[128]
80 | -- ]
81 | -- '[128] <- initDefNet g
82 | net0 :: Network 'Recurrent HM ( '[128] :~ LSTM )
83 | '[ '[96 ] :~ ReLUMap
84 | , '[96 ] :~ LSTM
85 | , '[64 ] :~ ReLUMap
86 | , '[64 ] :~ FullyConnected
87 | , '[128] :~ SoftMax '[128]
88 | ]
89 | '[128] <- initDefNet g
90 | flip evalStateT (net0, opt0) . forM_ [1..] $ \e -> do
91 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList slices) g
92 | liftIO $ printf "[Epoch %d]\n" (e :: Int)
93 |
94 | let chunkUp = chunksOf batch train'
95 | numChunks = length chunkUp
96 | forM_ ([1..] `zip` chunkUp) $ \(b, chnk) -> StateT $ \(n0, o0) -> do
97 | printf "(Epoch %d, Batch %d / %d)\n" (e :: Int) (b :: Int) numChunks
98 |
99 | t0 <- getCurrentTime
100 | -- n' <- evaluate $ optimizeList_ (bimap vecToProd only_ <$> chnk) n0
101 | (n', o') <- evaluate
102 | $ optimizeListBatches (bimap vecToProd vecToProd <$> chnk) n0 o0 25
103 | t1 <- getCurrentTime
104 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0))
105 |
106 | forM_ (map (bimap toList (last . toList)) . take 3 $ chnk) $ \(lastChnk, x0) -> do
107 | let (ys, primed) = runNetRecurrent n' lastChnk
108 | next :: HM '[128] -> IO ((Char, HM '[128]), HM '[128])
109 | next x = do
110 | pick <- pickNext 4 x g
111 | return ((toChar pick, x), oneHot (only pick))
112 | test <- toList . fst
113 | <$> runNetFeedbackM (known @_ @_ @(N10 + N6)) next primed x0
114 |
115 | forM_ (zip (lastChnk ++ [x0]) (ys ++ [snd (head test)])) $ \(c,y) ->
116 | printf "|%c\t=> %s\t(%.4f)\n"
117 | (censor (oneHotChar c))
118 | (take 25 (censor <$> charRank y))
119 | (amax y)
120 | forM_ (zip test (drop 1 test)) $ \((t,_),(_,p)) ->
121 | printf " %c\t=> %s\t(%.4f)\n"
122 | (censor t)
123 | (take 25 (censor <$> charRank p))
124 | (amax p)
125 | putStrLn "---"
126 |
127 | let (test, _) = runNetRecurrentLast n' . vecNonEmpty . fst . head $ chnk
128 | let n'' | isNaN (amax test) = n0
129 | | otherwise = n'
130 | return ((), (n'', o'))
131 | where
132 | batch :: Int
133 | batch = 1000
134 | α :: Double
135 | α = 2/3
136 |
137 | pickNext
138 | :: (PrimMonad m, BLAS t, KnownNat n)
139 | => Double
140 | -> t '[n]
141 | -> MWC.Gen (PrimState m)
142 | -> m (Finite n)
143 | pickNext α x g
144 | = fmap fromIntegral
145 | . flip MWC.categorical g
146 | . fmap ((** α) . realToFrac)
147 | . VSi.fromSized
148 | . textract
149 | $ x
150 |
151 |
152 | censor :: Char -> Char
153 | censor c
154 | | isPrint c = c
155 | | otherwise = '░'
156 |
157 | instance NFData a => NFData (I a) where
158 | rnf = \case
159 | I x -> rnf x
160 |
161 | instance NFData (f a) => NFData (VecT n f a) where
162 | rnf = \case
163 | ØV -> ()
164 | x :* xs -> x `deepseq` rnf xs
165 |
166 | vecToProd
167 | :: VecT n f a
168 | -> Prod f (Replicate n a)
169 | vecToProd = \case
170 | ØV -> Ø
171 | x :* xs -> x :< vecToProd xs
172 |
173 | vecNonEmpty
174 | :: Vec ('S n) a
175 | -> NE.NonEmpty a
176 | vecNonEmpty = \case
177 | I x :* xs -> x NE.:| toList xs
178 |
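179 | -- Editor's note, not part of the original file: 'pickNext' samples the
180 | -- next character from the network output x with probabilities
181 | --
182 | -- >   P(i) ∝ (x_i) ** α
183 | --
184 | -- so α acts as an inverse temperature: α = 1 samples the network's
185 | -- distribution unchanged, the α = 4 used in 'main' sharpens it towards
186 | -- the most likely characters, and α → ∞ approaches greedy decoding
187 | -- ('oneHotChar' above).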
--------------------------------------------------------------------------------
/old/src/Numeric/BLAS.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE DeriveFunctor #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE FunctionalDependencies #-}
5 | {-# LANGUAGE GADTs #-}
6 | {-# LANGUAGE InstanceSigs #-}
7 | {-# LANGUAGE KindSignatures #-}
8 | {-# LANGUAGE LambdaCase #-}
9 | {-# LANGUAGE MultiParamTypeClasses #-}
10 | {-# LANGUAGE PolyKinds #-}
11 | {-# LANGUAGE ScopedTypeVariables #-}
12 | {-# LANGUAGE StandaloneDeriving #-}
13 | {-# LANGUAGE TemplateHaskell #-}
14 | {-# LANGUAGE TypeApplications #-}
15 | {-# LANGUAGE TypeFamilies #-}
16 | {-# LANGUAGE TypeInType #-}
17 | {-# LANGUAGE TypeOperators #-}
18 | {-# LANGUAGE UndecidableInstances #-}
19 |
20 | module Numeric.BLAS (
21 | BLAS(..)
22 | , matVec
23 | , vecMat
24 | , matMat
25 | , outer
26 | , diag
27 | , eye
28 | , amax
29 | , concretize
30 | , matVecOp
31 | , dotOp
32 | , asumOp
33 | , module Numeric.Tensor
34 | ) where
35 |
36 | import Data.Finite
37 | import Data.Finite.Internal
38 | import Data.Foldable hiding (asum)
39 | import Data.Kind
40 | import Data.Maybe
41 | import Data.Ord
42 | import Data.Singletons
43 | import Data.Singletons.Prelude.Num
44 | import Data.Singletons.TypeLits
45 | import GHC.TypeLits
46 | import Numeric.Backprop.Op
47 | import Numeric.Tensor
48 | import qualified Data.Vector.Sized as V
49 |
50 | class Tensor b => BLAS (b :: [Nat] -> Type) where
51 |
52 | transp
53 | :: (KnownNat m, KnownNat n)
54 | => b '[m, n]
55 | -> b '[n, m]
56 | transp x = gen sing $ \case
57 | n :< m :< Ø -> tindex (m :< n :< Ø) x
58 |
59 | -- Level 1
60 | scal
61 | :: KnownNat n
62 | => Scalar b -- ^ α
63 | -> b '[n] -- ^ x
64 | -> b '[n] -- ^ α x
65 | scal α = tmap (α *)
66 |
67 | axpy
68 | :: KnownNat n
69 | => Scalar b -- ^ α
70 | -> b '[n] -- ^ x
71 | -> b '[n] -- ^ y
72 | -> b '[n] -- ^ α x + y
73 | axpy α = tzip (\x y -> α * x + y)
74 |
75 | dot :: KnownNat n
76 | => b '[n] -- ^ x
77 | -> b '[n] -- ^ y
78 | -> Scalar b -- ^ x' y
79 | dot x y = tsum (tzip (*) x y)
80 |
81 | norm2
82 | :: KnownNat n
83 | => b '[n] -- ^ x
84 |         -> Scalar b     -- ^ ||x||^2
85 | norm2 = tsum . tmap (** 2)
86 |
87 | asum
88 | :: KnownNat n
89 | => b '[n] -- ^ x
90 | -> Scalar b -- ^ sum_i |x_i|
91 | asum = tsum . tmap abs
92 |
93 | iamax
94 | :: forall n. KnownNat n
95 | => b '[n + 1] -- ^ x
96 | -> Finite (n + 1) -- ^ argmax_i |x_i|
97 | iamax = withKnownNat (SNat @n %:+ SNat @1) $
98 | Finite . fromIntegral . V.maxIndex . textract . tmap abs
99 |
100 | -- Level 2
101 | gemv
102 | :: (KnownNat m, KnownNat n)
103 | => Scalar b -- ^ α
104 | -> b '[m, n] -- ^ A
105 | -> b '[n] -- ^ x
106 | -> Maybe (Scalar b, b '[m]) -- ^ β, y
107 | -> b '[m] -- ^ α A x + β y
108 | gemv α a x b = maybe id (uncurry axpy) b . gen sing $ \case
109 | i :< Ø -> α * dot x (treshape sing (tslice (SliceSingle i `PMS` SliceAll `PMS` PMZ) a))
110 |
111 | ger :: (KnownNat m, KnownNat n)
112 | => Scalar b -- ^ α
113 | -> b '[m] -- ^ x
114 | -> b '[n] -- ^ y
115 | -> Maybe (b '[m, n]) -- ^ A
116 | -> b '[m, n] -- ^ α x y' + A
117 | ger α x y b = maybe id (tzip (+)) b . gen sing $ \case
118 | i :< j :< Ø -> α * tindex (i :< Ø) x * tindex (j :< Ø) y
119 |
120 | syr :: KnownNat n
121 | => Scalar b -- ^ α
122 | -> b '[n] -- ^ x
123 | -> Maybe (b '[n, n]) -- ^ A
124 |         -> b '[n, n]    -- ^ α x x' + A
125 | syr α x a = ger α x x a
126 |
127 | -- Level 3
128 | gemm
129 | :: (KnownNat m, KnownNat o, KnownNat n)
130 | => Scalar b -- ^ α
131 | -> b '[m, o] -- ^ A
132 | -> b '[o, n] -- ^ B
133 | -> Maybe (Scalar b, b '[m, n]) -- ^ β, C
134 | -> b '[m, n] -- ^ α A B + β C
135 | gemm α a b c = maybe id (uncurry f) c . gen sing $ \case
136 | i :< j :< Ø ->
137 | α * dot (treshape sing (tslice (SliceSingle i `PMS` SliceAll `PMS` PMZ) a))
138 | (treshape sing (tslice (SliceAll `PMS` SliceSingle j `PMS` PMZ) b))
139 | where
140 | f β = tzip (\d r -> β * d + r)
141 |
142 | syrk
143 | :: (KnownNat m, KnownNat n)
144 | => Scalar b -- ^ α
145 | -> b '[m, n] -- ^ A
146 | -> Maybe (Scalar b, b '[m, m]) -- ^ β, C
147 | -> b '[m, m] -- ^ α A A' + β C
148 | syrk α a c = gemm α a (transp a) c
149 |
150 | {-# MINIMAL #-}
151 |
152 | matVec
153 | :: (KnownNat m, KnownNat n, BLAS b)
154 | => b '[m, n]
155 | -> b '[n]
156 | -> b '[m]
157 | matVec a x = gemv 1 a x Nothing
158 |
159 | vecMat
160 | :: (KnownNat m, KnownNat n, BLAS b)
161 | => b '[m]
162 | -> b '[m, n]
163 | -> b '[n]
164 | vecMat x a = gemv 1 (transp a) x Nothing
165 |
166 | matMat
167 | :: (KnownNat m, KnownNat o, KnownNat n, BLAS b)
168 | => b '[m, o]
169 | -> b '[o, n]
170 | -> b '[m, n]
171 | matMat a b = gemm 1 a b Nothing
172 |
173 | outer
174 | :: (KnownNat m, KnownNat n, BLAS b)
175 | => b '[m]
176 | -> b '[n]
177 | -> b '[m, n]
178 | outer x y = ger 1 x y Nothing
179 |
180 | diag
181 | :: (KnownNat n, Tensor b)
182 | => b '[n]
183 | -> b '[n, n]
184 | diag x = gen sing $ \case
185 | i :< j :< Ø
186 | | i `equals` j -> tindex (i :< Ø) x
187 | | otherwise -> 0
188 |
189 | eye
190 | :: (KnownNat n, Tensor b)
191 | => b '[n, n]
192 | eye = gen sing $ \case
193 | i :< j :< Ø
194 | | i `equals` j -> 1
195 | | otherwise -> 0
196 |
197 | amax
198 | :: forall b n. (BLAS b, KnownNat n)
199 | => b '[n + 1]
200 | -> Scalar b
201 | amax = do
202 | i <- only . iamax
203 | withKnownNat (SNat @n %:+ SNat @1) $
204 | tindex i
205 |
206 | concretize :: forall b n. (BLAS b, KnownNat n) => b '[n + 1] -> b '[n + 1]
207 | concretize = withKnownNat (SNat @n %:+ SNat @1) $
208 | oneHot . only . iamax
209 |
210 | matVecOp
211 | :: (KnownNat m, KnownNat n, BLAS b)
212 | => Op '[ b '[m, n], b '[n] ] '[ b '[m] ]
213 | matVecOp = op2' $ \a x ->
214 | ( only_ (matVec a x)
215 | , (\g -> (outer g x, vecMat g a))
216 | . fromMaybe (tkonst sing 1)
217 | . head'
218 | )
219 |
220 | dotOp
221 | :: forall b n. (KnownNat n, BLAS b)
222 | => Op '[ b '[n], b '[n] ] '[ Scalar b ]
223 | dotOp = op2' $ \x y ->
224 | ( only_ (dot x y)
225 | , \case Nothing :< Ø -> (y , x )
226 | Just g :< Ø -> (scal g y, scal g x)
227 | )
228 |
229 | asumOp
230 | :: forall b n. (KnownNat n, BLAS b, Num (b '[n]))
231 | => Op '[ b '[n] ] '[ Scalar b ]
232 | asumOp = op1' $ \x ->
233 | ( only_ (asum x)
234 | , \case Nothing :< Ø -> signum x
235 | Just g :< Ø -> scal g (signum x)
236 | )
237 |
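238 | -- An editor's sketch, not part of the original module: the 'MINIMAL'
239 | -- pragma is empty, so any 'Tensor' can be made a 'BLAS' instance with
240 | -- an empty instance body, and the wrappers above then work out of the
241 | -- box.  With @T@ a hypothetical such instance:
242 | --
243 | -- > -- y = A x and C = x y', all via the generic defaults
244 | -- > example :: T '[3, 4] -> T '[4] -> (T '[3], T '[4, 3])
245 | -- > example a x = let y = matVec a x
246 | -- >               in  (y, outer x y)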
--------------------------------------------------------------------------------
/old/src/Learn/Neural/Network/Dropout.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE FlexibleInstances #-}
4 | {-# LANGUAGE GADTs #-}
5 | {-# LANGUAGE KindSignatures #-}
6 | {-# LANGUAGE LambdaCase #-}
7 | {-# LANGUAGE ScopedTypeVariables #-}
8 | {-# LANGUAGE Strict #-}
9 | {-# LANGUAGE TypeApplications #-}
10 | {-# LANGUAGE TypeInType #-}
11 | {-# LANGUAGE TypeOperators #-}
12 | {-# LANGUAGE UndecidableInstances #-}
13 |
14 | module Learn.Neural.Network.Dropout (
15 | Dropout(..)
16 | , NetworkDO(..)
17 | , netOpDO
18 | , netOpDOPure
19 | , konstDO, konstDO', mapDO, mapDO'
20 | , alongNet
21 | ) where
22 |
23 |
24 | import Control.Monad.Primitive
25 | import Data.Bool
26 | import Data.Kind
27 | import Data.Singletons
28 | import Data.Traversable
29 | import Data.Type.Combinator
30 | import Data.Type.Product
31 | import GHC.TypeLits
32 | import Learn.Neural.Layer
33 | import Learn.Neural.Network
34 | import Numeric.BLAS
35 | import Numeric.Backprop
36 | import Numeric.Backprop.Op
37 | import System.Random.MWC
38 | import System.Random.MWC.Distributions
39 | import Type.Class.Known
40 |
41 | data NetworkDO :: RunMode -> ([Nat] -> Type) -> LChain -> [LChain] -> [Nat] -> Type where
42 | NDO :: { ndoDropout :: !(Dropout r b i hs o)
43 | , ndoNetwork :: !(Network r b i hs o)
44 | }
45 | -> NetworkDO r b i hs o
46 |
47 | -- | For something like
48 | --
49 | -- @
50 | -- l1 ':&' l2 :& 'NetExt' l3
51 | -- @
52 | --
53 | -- you could have
54 | --
55 | -- @
56 | -- 0.2 ':&%' 0.3 :&% DOExt
57 | -- @
58 | --
59 | -- This would represent a 20% dropout after @l1@ and a 30% dropout after
60 | -- @l2@. By conscious design decision, 'DOExt' does not take any dropout
61 | -- rate, and it is therefore impossible to have dropout after the final
62 | -- layer before the output. It is also not possible to have dropout before
63 | -- the first layer, after the input.
64 | data Dropout :: RunMode -> ([Nat] -> Type) -> LChain -> [LChain] -> [Nat] -> Type where
65 | DOExt
66 | :: Dropout r b (i :~ c) '[] o
67 | (:&%)
68 | :: (Num (b h), SingI h)
69 | => !(Maybe Double)
70 | -> !(Dropout r b (h :~ d) hs o)
71 | -> Dropout r b (i :~ c) ((h :~ d) ': hs) o
72 |
73 | infixr 4 :&%
74 |
75 | alongNet
76 | :: n r b i hs o
77 | -> Dropout r b i hs o
78 | -> Dropout r b i hs o
79 | alongNet _ d = d
80 |
81 | konstDO
82 | :: forall r b i hs o. ()
83 | => NetStruct r b i hs o
84 | -> Maybe Double
85 | -> Dropout r b i hs o
86 | konstDO s0 x = go s0
87 | where
88 | go :: NetStruct r b j js o -> Dropout r b j js o
89 | go = \case
90 | NSExt -> DOExt
91 | NSInt s -> x :&% go s
92 |
93 | konstDO'
94 | :: forall r b i hs o. Known (NetStruct r b i hs) o
95 | => Maybe Double
96 | -> Dropout r b i hs o
97 | konstDO' = konstDO known
98 |
99 | mapDO
100 | :: forall r b i hs o. ()
101 | => (Double -> Double)
102 | -> Dropout r b i hs o
103 | -> Dropout r b i hs o
104 | mapDO f = go
105 | where
106 | go :: Dropout r b j js o -> Dropout r b j js o
107 | go = \case
108 | DOExt -> DOExt
109 | x :&% d -> fmap f x :&% go d
110 |
111 | mapDO'
112 | :: forall r b i hs o. ()
113 | => (Maybe Double -> Maybe Double)
114 | -> Dropout r b i hs o
115 | -> Dropout r b i hs o
116 | mapDO' f = go
117 | where
118 | go :: Dropout r b j js o -> Dropout r b j js o
119 | go = \case
120 | DOExt -> DOExt
121 | x :&% d -> f x :&% go d
122 |
123 | netOpDO
124 | :: forall m b i c hs o r. (BLAS b, Num (b i), Num (b o), PrimMonad m, SingI o)
125 | => Dropout r b (i :~ c) hs o
126 | -> Gen (PrimState m)
127 | -> m (OpBS '[ Network r b (i :~ c) hs o, b i ] '[ Network r b (i :~ c) hs o, b o ])
128 | netOpDO = \case
129 | DOExt -> \_ -> return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of
130 | NetExt (l :: Layer r c b i o) -> do
131 | (I l' :< I y :< Ø, gF) <- runOpM' (layerOp @r @c @i @o @b) (l ::< x ::< Ø)
132 | let gF' = fmap (\case I dL :< I dX :< Ø -> NetExt dL ::< dX ::< Ø)
133 | . gF
134 | . (\case Just (NetExt dL) :< dY :< Ø ->
135 | Just dL :< dY :< Ø
136 | Nothing :< dY :< Ø ->
137 | Nothing :< dY :< Ø
138 | )
139 | return (NetExt l' ::< y ::< Ø, gF')
140 | r :&% (d :: Dropout r b (h :~ d) js o) -> \g -> do
141 | mask <- forM r $ \r' ->
142 | genA @b (sing @_ @h) $ \_ -> bool (1 / (1 - realToFrac r')) 0 <$> bernoulli r' g
143 | no :: OpBS '[ Network r b (h :~ d) js o, b h ] '[ Network r b (h :~ d) js o, b o ]
144 | <- netOpDO @m @b @h @d @js @o @r d g
145 | return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of
146 | (l :: Layer r c b i h) :& (n2 :: Network r b (h ':~ d) js o) -> do
147 | (I l' :< I y :< Ø, gF ) <- runOpM' (layerOp @r @c @i @h @b) (l ::< x ::< Ø)
148 | (I n2' :< I z :< Ø, gF') <- runOpM' (runOpBS no ) (n2 ::< y ::< Ø)
149 | let gF'' = \case Just (dL :& dN) :< dZ :< Ø -> do
150 | I dN2 :< I dY :< Ø <- gF' (Just dN :< dZ :< Ø)
151 | let dY' = maybe dY (tzip (*) dY) mask
152 | I dL0 :< I dX :< Ø <- gF (Just dL :< Just dY' :< Ø)
153 | return $ (dL0 :& dN2) ::< dX ::< Ø
154 | Nothing :< dZ :< Ø -> do
155 | I dN2 :< I dY :< Ø <- gF' (Nothing :< dZ :< Ø)
156 | let dY' = maybe dY (tzip (*) dY) mask
157 | I dL0 :< I dX :< Ø <- gF (Nothing :< Just dY' :< Ø)
158 | return $ (dL0 :& dN2) ::< dX ::< Ø
159 | return ((l' :& n2') ::< z ::< Ø, gF'')
160 |
161 | netOpDOPure
162 | :: forall m b i c hs o. (BLAS b, Num (b i), Num (b o), PrimMonad m, SingI o)
163 | => Dropout 'FeedForward b (i :~ c) hs o
164 | -> Gen (PrimState m)
165 | -> m (OpBS '[ Network 'FeedForward b (i :~ c) hs o, b i ] '[ b o ])
166 | netOpDOPure = \case
167 | DOExt -> \_ -> return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of
168 | NetExt (l :: Layer 'FeedForward c b i o) -> do
169 | (I y :< Ø, gF) <- runOpM' (layerOpPure @c @i @o @b) (l ::< x ::< Ø)
170 | let gF' = fmap (\case I dL :< I dX :< Ø -> NetExt dL ::< dX ::< Ø)
171 | . gF
172 | return (y ::< Ø, gF')
173 | r :&% (d :: Dropout 'FeedForward b (h :~ d) js o) -> \g -> do
174 | mask <- forM r $ \r' ->
175 | genA @b (sing @_ @h) $ \_ -> bool (1 / (1 - realToFrac r')) 0 <$> bernoulli r' g
176 | no :: OpBS '[ Network 'FeedForward b (h :~ d) js o, b h ] '[ b o ]
177 | <- netOpDOPure @m @b @h @d @js @o d g
178 | return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of
179 | (l :: Layer 'FeedForward c b i h) :& (n2 :: Network 'FeedForward b (h ':~ d) js o) -> do
180 | (I y :< Ø, gF ) <- runOpM' (layerOpPure @c @i @h @b) (l ::< x ::< Ø)
181 | (I z :< Ø, gF') <- runOpM' (runOpBS no ) (n2 ::< y ::< Ø)
182 | let gF'' dZ = do
183 | I dN2 :< I dY :< Ø <- gF' dZ
184 | let dY' = maybe dY (tzip (*) dY) mask
185 | I dL0 :< I dX :< Ø <- gF (Just dY' :< Ø)
186 | return $ (dL0 :& dN2) ::< dX ::< Ø
187 | return (z ::< Ø, gF'')
188 |
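189 | -- An editor's sketch, not part of the original module: attach a
190 | -- uniform 50% dropout after every internal layer of an existing
191 | -- network.  'alongNet' only pins the type indices of the 'Dropout'
192 | -- to those of the network; 'withDropout' is hypothetical.
193 | --
194 | -- > withDropout
195 | -- >     :: Known (NetStruct r b i hs) o
196 | -- >     => Network r b i hs o
197 | -- >     -> NetworkDO r b i hs o
198 | -- > withDropout n = NDO (alongNet n (konstDO' (Just 0.5))) n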
--------------------------------------------------------------------------------
/src/Backprop/Learn/Model/Neural/LSTM.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE DeriveAnyClass #-}
3 | {-# LANGUAGE DeriveDataTypeable #-}
4 | {-# LANGUAGE DeriveGeneric #-}
5 | {-# LANGUAGE DerivingVia #-}
6 | {-# LANGUAGE FlexibleInstances #-}
7 | {-# LANGUAGE GADTs #-}
8 | {-# LANGUAGE KindSignatures #-}
9 | {-# LANGUAGE MultiParamTypeClasses #-}
10 | {-# LANGUAGE PatternSynonyms #-}
11 | {-# LANGUAGE RankNTypes #-}
12 | {-# LANGUAGE RecordWildCards #-}
13 | {-# LANGUAGE ScopedTypeVariables #-}
14 | {-# LANGUAGE StandaloneDeriving #-}
15 | {-# LANGUAGE TemplateHaskell #-}
16 | {-# LANGUAGE TypeApplications #-}
17 | {-# LANGUAGE TypeFamilies #-}
18 | {-# LANGUAGE TypeInType #-}
19 | {-# LANGUAGE TypeOperators #-}
20 | {-# LANGUAGE UndecidableInstances #-}
21 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
22 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
23 |
24 | module Backprop.Learn.Model.Neural.LSTM (
25 | -- * LSTM
26 | lstm
27 | , LSTMp(..), lstmForget, lstmInput, lstmUpdate, lstmOutput
28 | , reshapeLSTMpInput
29 | , reshapeLSTMpOutput
30 | , lstm'
31 | -- * GRU
32 | , gru
33 | , GRUp(..), gruMemory, gruUpdate, gruOutput
34 | , gru'
35 | ) where
36 |
37 | import Backprop.Learn.Initialize
38 | import Backprop.Learn.Model.Function
39 | import Backprop.Learn.Model.Neural
40 | import Backprop.Learn.Model.Regression
41 | import Backprop.Learn.Model.State
42 | import Backprop.Learn.Model.Types
43 | import Control.DeepSeq
44 | import Control.Monad
45 | import Control.Monad.Primitive
46 | import Data.Type.Tuple
47 | import Data.Typeable
48 | import GHC.Generics (Generic)
49 | import GHC.TypeNats
50 | import Lens.Micro
51 | import Lens.Micro.TH
52 | import Numeric.Backprop
53 | import Numeric.LinearAlgebra.Static.Backprop
54 | import Numeric.OneLiner
55 | import Numeric.Opto.Ref
56 | import Numeric.Opto.Update
57 | import Statistics.Distribution
58 | import qualified Data.Binary as Bi
59 | import qualified Numeric.LinearAlgebra.Static as H
60 | import qualified System.Random.MWC as MWC
61 |
62 | -- TODO: allow parameterizing the internal activation function?
63 | -- TODO: Peepholes
64 |
65 | -- | 'LSTM' layer parameters
66 | data LSTMp (i :: Nat) (o :: Nat) =
67 | LSTMp { _lstmForget :: !(FCp (i + o) o)
68 | , _lstmInput :: !(FCp (i + o) o)
69 | , _lstmUpdate :: !(FCp (i + o) o)
70 | , _lstmOutput :: !(FCp (i + o) o)
71 | }
72 | deriving stock (Generic, Typeable, Show)
73 | deriving anyclass (NFData, Linear Double, Metric Double, Bi.Binary, Regularize, Backprop)
74 |
75 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Num (LSTMp i o)
76 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Fractional (LSTMp i o)
77 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Floating (LSTMp i o)
78 |
79 | makeLenses ''LSTMp
80 |
81 | instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (LSTMp i o) where
82 | type Ref m (LSTMp i o) = GRef m (LSTMp i o)
83 | thawRef = gThawRef
84 | freezeRef = gFreezeRef
85 | copyRef = gCopyRef
86 | instance (PrimMonad m, KnownNat i, KnownNat o) => LinearInPlace m Double (LSTMp i o)
87 |
88 | instance (PrimMonad m, KnownNat i, KnownNat o) => Learnable m (LSTMp i o)
89 |
90 | -- | Stateless version of 'lstm' that takes the "previous output" as a part
91 | -- of the input vector.
92 | lstm'
93 | :: (KnownNat i, KnownNat o)
94 | => Model ('Just (LSTMp i o)) ('Just (R o)) (R (i + o)) (R o)
95 | lstm' = modelD $ \(PJust p) x (PJust s) ->
96 | let forget = logistic $ runLRp (p ^^. lstmForget) x
97 | input = logistic $ runLRp (p ^^. lstmInput ) x
98 | update = tanh $ runLRp (p ^^. lstmUpdate) x
99 | s' = forget * s + input * update
100 | o = logistic $ runLRp (p ^^. lstmOutput) x
101 | h = o * tanh s'
102 | in (h, PJust s')
103 |
104 | -- | Long short-term memory layer
105 | --
106 | --
107 | --
108 | lstm
109 | :: (KnownNat i, KnownNat o)
110 | => Model ('Just (LSTMp i o)) ('Just (R o :# R o)) (R i) (R o)
111 | lstm = recurrent H.split (H.#) id lstm'
112 |
113 | reshapeLSTMpInput
114 | :: (ContGen d, PrimMonad m, KnownNat i, KnownNat i', KnownNat o)
115 | => d
116 | -> MWC.Gen (PrimState m)
117 | -> LSTMp i o
118 | -> m (LSTMp i' o)
119 | reshapeLSTMpInput d g (LSTMp forget input update output) =
120 | LSTMp <$> reshaper forget
121 | <*> reshaper input
122 | <*> reshaper update
123 | <*> reshaper output
124 | where
125 | reshaper = reshapeLRpInput d g
126 |
127 | reshapeLSTMpOutput
128 | :: (ContGen d, PrimMonad m, KnownNat i, KnownNat o, KnownNat o')
129 | => d
130 | -> MWC.Gen (PrimState m)
131 | -> LSTMp i o
132 | -> m (LSTMp i o')
133 | reshapeLSTMpOutput d g (LSTMp forget input update output) =
134 | LSTMp <$> reshaper forget
135 | <*> reshaper input
136 | <*> reshaper update
137 | <*> reshaper output
138 | where
139 | reshaper = reshapeLRpInput d g
140 | <=< reshapeLRpOutput d g
141 |
142 | -- | Forget biases initialized to 1
143 | instance (KnownNat i, KnownNat o) => Initialize (LSTMp i o) where
144 | initialize d g = LSTMp <$> set (mapped . fcBias) 1 (initialize d g)
145 | <*> initialize d g
146 | <*> initialize d g
147 | <*> initialize d g
148 |
149 | -- | 'GRU' layer parameters
150 | data GRUp (i :: Nat) (o :: Nat) =
151 | GRUp { _gruMemory :: !(FCp (i + o) o)
152 | , _gruUpdate :: !(FCp (i + o) o)
153 | , _gruOutput :: !(FCp (i + o) o)
154 | }
155 | deriving stock (Generic, Typeable, Show)
156 | deriving anyclass (NFData, Linear Double, Metric Double, Bi.Binary, Initialize, Regularize, Backprop)
157 |
158 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Num (GRUp i o)
159 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Fractional (GRUp i o)
160 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Floating (GRUp i o)
161 |
162 | makeLenses ''GRUp
163 |
164 | instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (GRUp i o) where
165 | type Ref m (GRUp i o) = GRef m (GRUp i o)
166 | thawRef = gThawRef
167 | freezeRef = gFreezeRef
168 | copyRef = gCopyRef
169 | instance (KnownNat i, KnownNat o, Mutable m (GRUp i o)) => LinearInPlace m Double (GRUp i o)
170 |
171 | instance (KnownNat i, KnownNat o, PrimMonad m) => Learnable m (GRUp i o)
172 |
173 | -- | Stateless version of 'gru' that takes the "previous output" as a part
174 | -- of the input vector.
175 | gru'
176 | :: forall i o. (KnownNat i, KnownNat o)
177 | => Model ('Just (GRUp i o)) 'Nothing (R (i + o)) (R o)
178 | gru' = modelStatelessD $ \(PJust p) x ->
179 | let z = logistic $ runLRp (p ^^. gruMemory) x
180 | r = logistic $ runLRp (p ^^. gruUpdate) x
181 | r' = 1 # r
182 | h' = tanh $ runLRp (p ^^. gruOutput) (r' * x)
183 | in (1 - z) * snd (split @i x) + z * h'
184 |
185 | -- | Gated Recurrent Unit
186 | --
187 | --
188 | --
189 | gru :: (KnownNat i, KnownNat o)
190 | => Model ('Just (GRUp i o)) ('Just (R o)) (R i) (R o)
191 | gru = recurrent H.split (H.#) id gru'
192 |
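193 | -- Editor's note, not part of the original module: writing x for the
194 | -- concatenated input-and-previous-output vector fed to 'gru'', and s
195 | -- for the previous output (the last @o@ entries of x), the definition
196 | -- above computes, with σ = 'logistic':
197 | --
198 | -- >   z  = σ(W_z x + b_z)                 -- gruMemory
199 | -- >   r  = σ(W_r x + b_r)                 -- gruUpdate
200 | -- >   h' = tanh(W_h ((1 # r) * x) + b_h)  -- gruOutput; r gates only
201 | -- >                                       --   the state half of x
202 | -- >   h  = (1 - z) * s + z * h'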
--------------------------------------------------------------------------------