├── README.md ├── Setup.hs ├── test └── Spec.hs ├── demo ├── mnist-ghcjs.js_o └── mnist-ghcjs.hs ├── .gitignore ├── src └── Backprop │ ├── Learn.hs │ └── Learn │ ├── Train.hs │ ├── Model │ ├── Parameter.hs │ ├── Stochastic.hs │ ├── Neural.hs │ ├── State.hs │ └── Neural │ │ └── LSTM.hs │ ├── Loss.hs │ ├── Run.hs │ ├── Test.hs │ └── Initialize.hs ├── old2 └── src │ ├── Backprop │ ├── Learn.hs │ └── Learn │ │ ├── Train.hs │ │ ├── Run.hs │ │ ├── Initialize.hs │ │ ├── Model │ │ ├── Parameter.hs │ │ ├── Neural.hs │ │ ├── Stochastic.hs │ │ └── Class.hs │ │ ├── Loss.hs │ │ └── Test.hs │ └── Data │ └── Type │ ├── NonEmpty.hs │ └── Mayb.hs ├── TODO.txt ├── old ├── src │ ├── Learn │ │ ├── Neural.hs │ │ └── Neural │ │ │ ├── Test.hs │ │ │ ├── Layer │ │ │ ├── Identity.hs │ │ │ ├── Applying.hs │ │ │ ├── FullyConnected.hs │ │ │ └── Compose.hs │ │ │ ├── Loss.hs │ │ │ └── Network │ │ │ └── Dropout.hs │ ├── Data │ │ └── Type │ │ │ └── Util.hs │ └── Numeric │ │ ├── BLAS │ │ ├── FVector.hs │ │ └── NVector.hs │ │ └── BLAS.hs └── app │ ├── MNIST.hs │ ├── Letter2Vec.hs │ └── Language.hs ├── LICENSE ├── stack.yaml ├── app ├── mnist.hs ├── word2vec.hs └── char-rnn.hs ├── Build.hs ├── package.yaml ├── stack.yaml.lock ├── stack-ghcjs.yaml └── .travis.yml /README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /test/Spec.hs: -------------------------------------------------------------------------------- 1 | main :: IO () 2 | main = putStrLn "Test suite not yet implemented" 3 | -------------------------------------------------------------------------------- /demo/mnist-ghcjs.js_o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstksg/backprop-learn/HEAD/demo/mnist-ghcjs.js_o -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.stack-work 2 | /data 3 | 4 | /TAGS 5 | /tags 6 | /out 7 | 8 | /.build 9 | /.shake 10 | /demo-exe 11 | /demo-js 12 | 13 | /backprop-learn.cabal 14 | 15 | /testout 16 | 17 | *.dump-hi 18 | -------------------------------------------------------------------------------- /demo/mnist-ghcjs.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | 3 | import Reflex.Dom 4 | 5 | main :: IO () 6 | main = mainWidget $ el "div" $ do 7 | t <- textInput def 8 | dynText $ _textInput_value t 9 | -------------------------------------------------------------------------------- /src/Backprop/Learn.hs: -------------------------------------------------------------------------------- 1 | 2 | module Backprop.Learn ( 3 | module L 4 | ) where 5 | 6 | import Backprop.Learn.Initialize as L 7 | import Backprop.Learn.Loss as L 8 | import Backprop.Learn.Model as L 9 | import Backprop.Learn.Run as L 10 | import Backprop.Learn.Test as L 11 | import Backprop.Learn.Train as L 12 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn.hs: -------------------------------------------------------------------------------- 1 | 2 | module Backprop.Learn ( 3 | module L 4 | ) 
where 5 | 6 | import Backprop.Learn.Initialize as L 7 | import Backprop.Learn.Loss as L 8 | import Backprop.Learn.Model as L 9 | import Backprop.Learn.Run as L 10 | import Backprop.Learn.Test as L 11 | import Backprop.Learn.Train as L 12 | -------------------------------------------------------------------------------- /TODO.txt: -------------------------------------------------------------------------------- 1 | 2 | * Reset if diverge 3 | * Feedback loops 4 | * Get rid of initial state completely and just require initialization to zero 5 | for destate 6 | * trace state or internal activations once unrolled? (actually maybe easy) 7 | * KL-divergence autoencoders 8 | * Variational autoencoders 9 | * Address regularization 10 | * GRU LSTMs 11 | * Binary 12 | * Dynamically grow or shrink networks 13 | * Reinforcement learning: loss function takes no "target" 14 | * Elman 15 | 16 | 17 | * Get rid of initparam and initstate 18 | -------------------------------------------------------------------------------- /old/src/Learn/Neural.hs: -------------------------------------------------------------------------------- 1 | 2 | module Learn.Neural ( 3 | module N 4 | ) where 5 | 6 | import Learn.Neural.Layer as N 7 | import Learn.Neural.Layer.Applying as N 8 | import Learn.Neural.Layer.FullyConnected as N 9 | import Learn.Neural.Layer.Identity as N 10 | import Learn.Neural.Layer.Mapping as N 11 | import Learn.Neural.Loss as N 12 | import Learn.Neural.Network as N 13 | import Learn.Neural.Network.Dropout as N 14 | import Learn.Neural.Test as N 15 | import Learn.Neural.Train as N 16 | import Numeric.BLAS as N 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Justin Le (c) 2017 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of Justin Le nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------- /old/src/Learn/Neural/Test.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE LambdaCase #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | 8 | module Learn.Neural.Test ( 9 | TestFunc 10 | , maxTest 11 | , rmseTest 12 | , squaredErrorTest 13 | , crossEntropyTest 14 | , testNet 15 | , testNetList 16 | ) where 17 | 18 | import Data.Proxy 19 | import GHC.TypeLits 20 | import Learn.Neural.Layer 21 | import Learn.Neural.Network 22 | import Numeric.BLAS 23 | import qualified Control.Foldl as F 24 | 25 | type TestFunc o = forall b. (BLAS b, Num (b o)) => b o -> b o -> Double 26 | 27 | maxTest :: KnownNat n => TestFunc '[n + 1] 28 | maxTest x y | iamax x == iamax y = 1 29 | | otherwise = 0 30 | 31 | rmseTest :: forall n. KnownNat n => TestFunc '[n] 32 | rmseTest x y = sqrt $ realToFrac (e `dot` e) / fromIntegral (natVal (Proxy @n)) 33 | where 34 | e = axpy (-1) x y 35 | 36 | squaredErrorTest :: KnownNat n => TestFunc '[n] 37 | squaredErrorTest x y = realToFrac $ e `dot` e 38 | where 39 | e = axpy (-1) x y 40 | 41 | crossEntropyTest :: KnownNat n => TestFunc '[n] 42 | crossEntropyTest r t = negate $ realToFrac (tmap log r `dot` t) 43 | 44 | testNet 45 | :: (BLAS b, Num (b i), Num (b o)) 46 | => TestFunc o 47 | -> SomeNet 'FeedForward b i o 48 | -> b i 49 | -> b o 50 | -> Double 51 | testNet tf = \case 52 | SomeNet _ n -> \x t -> 53 | let y = runNetPure n x 54 | in tf y t 55 | 56 | testNetList 57 | :: (BLAS b, Num (b i), Num (b o)) 58 | => TestFunc o 59 | -> SomeNet 'FeedForward b i o 60 | -> [(b i, b o)] 61 | -> Double 62 | testNetList tf n = F.fold F.mean . 
fmap (uncurry (testNet tf n)) 63 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Layer/Identity.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE KindSignatures #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE TypeFamilies #-} 5 | {-# LANGUAGE TypeInType #-} 6 | 7 | 8 | module Learn.Neural.Layer.Identity ( 9 | Ident 10 | ) where 11 | 12 | import Data.Kind 13 | import Learn.Neural.Layer 14 | import Numeric.BLAS 15 | import Numeric.Backprop 16 | 17 | data Ident :: Type 18 | 19 | instance Num (CParam Ident b i i) where 20 | _ + _ = IdP 21 | _ * _ = IdP 22 | _ - _ = IdP 23 | negate _ = IdP 24 | abs _ = IdP 25 | signum _ = IdP 26 | fromInteger _ = IdP 27 | 28 | instance Fractional (CParam Ident b i i) where 29 | _ / _ = IdP 30 | recip _ = IdP 31 | fromRational _ = IdP 32 | 33 | instance Floating (CParam Ident b i i) where 34 | sqrt _ = IdP 35 | 36 | instance Num (CState Ident b i i) where 37 | _ + _ = IdS 38 | _ * _ = IdS 39 | _ - _ = IdS 40 | negate _ = IdS 41 | abs _ = IdS 42 | signum _ = IdS 43 | fromInteger _ = IdS 44 | 45 | instance Fractional (CState Ident b i i) where 46 | _ / _ = IdS 47 | recip _ = IdS 48 | fromRational _ = IdS 49 | 50 | instance Floating (CState Ident b i i) where 51 | sqrt _ = IdS 52 | 53 | instance BLAS b => Component Ident b i i where 54 | data CParam Ident b i i = IdP 55 | data CState Ident b i i = IdS 56 | data CConf Ident b i i = IdC 57 | 58 | componentOp = componentOpDefault 59 | 60 | initParam _ _ _ _ = return IdP 61 | initState _ _ _ _ = return IdS 62 | defConf = IdC 63 | 64 | instance BLAS b => ComponentFF Ident b i i where 65 | componentOpFF = bpOp . withInps $ \(x :< _ :< Ø) -> 66 | return . only $ x 67 | 68 | instance BLAS b => ComponentLayer r Ident b i i where 69 | componentRunMode = RMIsFF 70 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # http://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: nightly-2020-01-30 19 | 20 | # User packages to be built. 21 | # Various formats can be used as shown in the example below. 22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. 
a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - '.' 40 | 41 | # Dependency packages to be pulled from upstream that are not in the resolver 42 | # (e.g., acme-missiles-0.3) 43 | extra-deps: 44 | - github: mstksg/opto 45 | commit: def9e41adc1123385037f92ead38c92cb6454dd2 46 | - backprop-0.2.6.3 47 | - decidable-0.3.0.0 48 | - functor-products-0.1.1.0 49 | - hmatrix-vector-sized-0.1.2.0 50 | - hmatrix-backprop-0.1.3.0 51 | - list-witnesses-0.1.3.2 52 | - typelits-witnesses-0.4.0.0 53 | - vinyl-0.12.0 54 | - dependent-sum-0.6.2.0 55 | - constraints-extras-0.3.0.2 56 | 57 | # Override default flag values for local packages and extra-deps 58 | flags: {} 59 | 60 | # Extra package databases containing global packages 61 | extra-package-dbs: [] 62 | 63 | ghc-options: 64 | "$locals": -ddump-to-file -ddump-hi 65 | 66 | build: 67 | haddock-arguments: 68 | haddock-args: 69 | - --optghc=-fdefer-type-errors 70 | 71 | # Control whether we use the GHC we find on the path 72 | # system-ghc: true 73 | # 74 | # Require a specific version of stack, using version ranges 75 | # require-stack-version: -any # Default 76 | # require-stack-version: ">=1.3" 77 | # 78 | # Override the architecture used by stack, especially useful on Windows 79 | # arch: i386 80 | # arch: x86_64 81 | # 82 | # Extra directories used by stack for building 83 | # extra-include-dirs: [/path/to/dir] 84 | # extra-lib-dirs: [/path/to/dir] 85 | # 86 | # Allow a newer minor version of GHC than the snapshot specifies 87 | # compiler-check: newer-minor 88 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Train.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE PartialTypeSignatures #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-} 10 | 11 | module Backprop.Learn.Train ( 12 | -- * Gradients 13 | gradModelLoss 14 | , gradModelStochLoss 15 | -- * Opto 16 | , Grad 17 | , modelGrad 18 | , modelGradStoch 19 | ) where 20 | 21 | import Backprop.Learn.Loss 22 | import Backprop.Learn.Model 23 | import Control.Monad.Primitive 24 | import Control.Monad.ST 25 | import Data.Word 26 | import Numeric.Backprop 27 | import Numeric.Opto.Core 28 | import qualified Data.Vector.Unboxed as VU 29 | import qualified System.Random.MWC as MWC 30 | 31 | -- | Gradient of model with respect to loss function and target 32 | gradModelLoss 33 | :: Backprop p 34 | => Loss b 35 | -> Regularizer p 36 | -> Model ('Just p) 'Nothing a b 37 | -> p 38 | -> a 39 | -> b 40 | -> p 41 | gradModelLoss loss reg f p x y = gradBP (\p' -> 42 | loss y (runLearnStateless f (PJust p') (constVar x)) + reg p' 43 | ) p 44 | 45 | -- | Stochastic gradient of model with respect to loss function and target 46 | gradModelStochLoss 47 | :: (Backprop p, PrimMonad m) 48 | => Loss b 49 | -> Regularizer p 50 | -> Model ('Just p) 'Nothing a b 51 | -> MWC.Gen (PrimState m) 52 | -> p 53 | -> a 54 | -> b 55 | -> m p 56 | gradModelStochLoss loss reg f g p x y = do 57 | seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2 58 | pure $ gradBP (\p' -> runST $ do 59 | g' <- MWC.initialize seed 60 | lo <- loss y <$> runLearnStochStateless f g' (PJust p') (constVar 
x) 61 | pure (lo + reg p') 62 | ) p 63 | 64 | -- | Using a model's deterministic prediction function (with a given loss 65 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and 66 | -- "Numeric.Opto.Run". 67 | modelGrad 68 | :: (Applicative m, Backprop p) 69 | => Loss b 70 | -> Regularizer p 71 | -> Model ('Just p) 'Nothing a b 72 | -> Grad m (a, b) p 73 | modelGrad loss reg f = pureGrad $ \(x,y) p -> gradModelLoss loss reg f p x y 74 | 75 | -- | Using a model's stochastic prediction function (with a given loss 76 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and 77 | -- "Numeric.Opto.Run". 78 | modelGradStoch 79 | :: (PrimMonad m, Backprop p) 80 | => Loss b 81 | -> Regularizer p 82 | -> Model ('Just p) 'Nothing a b 83 | -> MWC.Gen (PrimState m) 84 | -> Grad m (a, b) p 85 | modelGradStoch loss reg f g = \(x,y) p -> 86 | gradModelStochLoss loss reg f g p x y 87 | -------------------------------------------------------------------------------- /app/mnist.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE PartialTypeSignatures #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# LANGUAGE TypeSynonymInstances #-} 10 | {-# LANGUAGE UndecidableInstances #-} 11 | {-# OPTIONS_GHC -fno-warn-orphans #-} 12 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} 13 | 14 | import Backprop.Learn 15 | import Control.Monad.Trans.Maybe 16 | import Data.Bitraversable 17 | import Data.Default 18 | import Data.IDX 19 | import Data.Traversable 20 | import Data.Tuple 21 | import Numeric.LinearAlgebra.Static.Backprop 22 | import Numeric.Opto 23 | import Numeric.Opto.Run.Simple 24 | import System.Environment 25 | import System.FilePath 26 | import Text.Printf 27 | import qualified Data.Vector.Generic as VG 28 | import qualified Numeric.LinearAlgebra as HM 29 | import qualified Numeric.LinearAlgebra.Static as H 30 | import qualified System.Random.MWC as MWC 31 | 32 | mnistNet :: Model _ _ (R 784) (R 10) 33 | mnistNet = fca @300 softMax 34 | <~ dropout 0.25 35 | <~ fca @784 logistic 36 | 37 | main :: IO () 38 | main = MWC.withSystemRandom $ \g -> do 39 | datadir:_ <- getArgs 40 | Just train <- loadMNIST (datadir </> "train-images-idx3-ubyte") 41 | (datadir </> "train-labels-idx1-ubyte") 42 | Just test <- loadMNIST (datadir </> "t10k-images-idx3-ubyte") 43 | (datadir </> "t10k-labels-idx1-ubyte") 44 | putStrLn "Loaded data." 45 | net0 <- initParamNormal mnistNet 0.2 g 46 | 47 | let so = def { soTestSet = Just test 48 | , soEvaluate = runTest 49 | , soSkipSamps = 2500 50 | } 51 | runTest chnk net = printf "Error: %.2f%%" ((1 - score) * 100) 52 | where 53 | score = testModelAll maxIxTest mnistNet (TJust net) chnk 54 | 55 | simpleRunner so train SOSingle def net0 56 | (adam def $ modelGradStoch crossEntropy noReg mnistNet g) 57 | g 58 | 59 | loadMNIST 60 | :: FilePath 61 | -> FilePath 62 | -> IO (Maybe [(R 784, R 10)]) 63 | loadMNIST fpI fpL = runMaybeT $ do 64 | i <- MaybeT $ decodeIDXFile fpI 65 | l <- MaybeT $ decodeIDXLabelsFile fpL 66 | d <- MaybeT . return $ labeledIntData l i 67 | MaybeT . return $ for d (bitraverse mkImage mkLabel . swap) 68 | where 69 | mkImage = H.create . VG.convert .
VG.map (\i -> fromIntegral i / 255) 70 | mkLabel n = H.create $ HM.build 10 (\i -> if round i == n then 1 else 0) 71 | 72 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Loss.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE PolyKinds #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | 9 | module Learn.Neural.Loss ( 10 | LossFunction 11 | , crossEntropy 12 | , squaredError 13 | , sumLoss 14 | , sumLossDecay 15 | , zipLoss 16 | ) where 17 | 18 | import Data.Type.Util 19 | import Data.Type.Vector 20 | import GHC.TypeLits 21 | import Numeric.BLAS 22 | import Numeric.Backprop 23 | import Numeric.Backprop.Op 24 | import Type.Class.Witness 25 | import qualified Data.Type.Nat as TCN 26 | 27 | -- type LossFunction s = forall b q. (BLAS b, Num (b s)) => b s -> OpB q '[ b s ] '[ Scalar b ] 28 | type LossFunction as b = forall s. Tuple as -> OpB s as '[ b ] 29 | 30 | crossEntropy 31 | :: forall b n. (BLAS b, KnownNat n, Num (b '[n])) 32 | => LossFunction '[ b '[n] ] (Scalar b) 33 | crossEntropy (I targ :< Ø) = bpOp . withInps $ \(r :< Ø) -> do 34 | logR <- tmapOp log ~$ (r :< Ø) 35 | res <- negate <$> (dotOp ~$ (logR :< t :< Ø)) 36 | only <$> bindVar res 37 | where 38 | t = constVar targ 39 | 40 | squaredError 41 | :: (BLAS b, KnownNat n, Num (b '[n])) 42 | => LossFunction '[ b '[n] ] (Scalar b) 43 | squaredError (I targ :< Ø) = bpOp . withInps $ \(r :< Ø) -> do 44 | err <- bindVar $ r - t 45 | only <$> (dotOp ~$ (err :< err :< Ø)) 46 | where 47 | t = constVar targ 48 | 49 | sumLoss 50 | :: forall n a b. (Num a, Num b) 51 | => LossFunction '[ a ] b 52 | -> TCN.Nat n 53 | -> LossFunction (Replicate n a) b 54 | sumLoss l = \case 55 | TCN.Z_ -> \case Ø -> op0 (only_ 0) 56 | TCN.S_ n -> \case 57 | I x :< xs -> (replLen @_ @a n //) $ 58 | (replWit @_ @Num @a n Wit //) $ 59 | bpOp . withInps $ \(y :< ys) -> do 60 | z <- l (only_ x) ~$ (y :< Ø) 61 | zs <- sumLoss l n xs ~$ ys 62 | return . only $ z + zs 63 | 64 | sumLossDecay 65 | :: forall n a b. (Num a, Num b) 66 | => LossFunction '[ a ] b 67 | -> TCN.Nat n 68 | -> b 69 | -> LossFunction (Replicate n a) b 70 | sumLossDecay l n λ = zipLoss l (genDecay 1 n) 71 | where 72 | genDecay :: b -> TCN.Nat m -> Vec m b 73 | genDecay b = \case 74 | TCN.Z_ -> ØV 75 | TCN.S_ m -> case genDecay b m of 76 | ØV -> b :+ ØV 77 | I c :* cs -> (c * λ) :+ c :+ cs 78 | 79 | 80 | zipLoss 81 | :: forall n a b. (Num a, Num b) 82 | => LossFunction '[ a ] b 83 | -> Vec n b 84 | -> LossFunction (Replicate n a) b 85 | zipLoss l = \case 86 | ØV -> \case Ø -> op0 (only 0) 87 | I α :* αs -> 88 | let αn = vecLenNat αs 89 | in \case 90 | I x :< xs -> (replLen @_ @a αn //) $ 91 | (replWit @_ @Num @a αn Wit //) $ 92 | bpOp . withInps $ \(y :< ys) -> do 93 | z <- l (only_ x) ~$ (y :< Ø) 94 | zs <- zipLoss l αs xs ~$ ys 95 | return . 
only $ (z * constVar α) + zs 96 | -------------------------------------------------------------------------------- /Build.hs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env stack 2 | -- stack --install-ghc runghc --package shake 3 | 4 | import Development.Shake 5 | import Development.Shake.FilePath 6 | import System.Directory 7 | 8 | opts = shakeOptions { shakeFiles = ".shake" 9 | , shakeVersion = "1.0" 10 | , shakeVerbosity = Normal 11 | , shakeThreads = 0 12 | } 13 | 14 | main :: IO () 15 | main = getDirectoryFilesIO "demo" ["/*.lhs", "/*.hs"] >>= \allDemos -> 16 | getDirectoryFilesIO "src" ["//*.hs"] >>= \allSrc -> 17 | shakeArgs opts $ do 18 | 19 | want ["all"] 20 | 21 | "all" ~> 22 | -- need ["haddocks", "gentags", "install", "exe", "js"] 23 | need ["haddocks", "gentags", "install", "js"] 24 | 25 | "exe" ~> 26 | need (map (\f -> "demo-exe" </> dropExtension f) allDemos) 27 | 28 | "js" ~> 29 | need (map (\f -> "demo-js" </> dropExtension f) allDemos) 30 | 31 | "haddocks" ~> do 32 | need (("src" </>) <$> allSrc) 33 | cmd "jle-git-haddocks" 34 | 35 | "install" ~> do 36 | need $ ("src" </>) <$> allSrc 37 | cmd "stack install" 38 | 39 | "install-ghcjs" ~> do 40 | need $ ("src" </>) <$> allSrc 41 | cmd "stack install --stack-yaml stack-ghcjs.yaml" 42 | 43 | "gentags" ~> 44 | need ["tags", "TAGS"] 45 | 46 | "demo-exe/*" %> \f -> do 47 | need ["install"] 48 | [src] <- getDirectoryFiles "demo" $ (takeFileName f <.>) <$> ["hs","lhs"] 49 | liftIO $ do 50 | createDirectoryIfMissing True "demo-exe" 51 | createDirectoryIfMissing True ".build" 52 | removeFilesAfter "demo" ["/*.o"] 53 | cmd "stack exec" "--package backprop-learn" 54 | "--" 55 | "ghc" 56 | ("demo" </> src) 57 | "-o" f 58 | "-hidir" ".build" 59 | "-threaded" 60 | "-rtsopts" 61 | "-with-rtsopts=-N" 62 | "-Wall" 63 | "-O2" 64 | 65 | "demo-js/*" %> \f -> do 66 | need ["install-ghcjs"] 67 | [src] <- getDirectoryFiles "demo" $ (takeFileName f <.>) <$> ["hs","lhs"] 68 | liftIO $ do 69 | createDirectoryIfMissing True "demo-js" 70 | createDirectoryIfMissing True ".build" 71 | removeFilesAfter "demo" ["/*.o"] 72 | cmd "stack exec" "--package backprop-learn" 73 | "--package reflex-dom" 74 | "--stack-yaml stack-ghcjs.yaml" 75 | "--" 76 | "ghcjs" 77 | ("demo" </> src) 78 | "-o" f 79 | "-hidir" ".build" 80 | "-threaded" 81 | "-rtsopts" 82 | "-with-rtsopts=-N" 83 | "-Wall" 84 | "-O2" 85 | 86 | ["tags","TAGS"] &%> \_ -> do 87 | need (("src" </>) <$> allSrc) 88 | cmd "hasktags" "src/" 89 | 90 | "clean" ~> do 91 | unit $ cmd "stack clean" 92 | removeFilesAfter ".shake" ["//*"] 93 | removeFilesAfter ".build" ["//*"] 94 | removeFilesAfter "demo-exe" ["//*"] 95 | 96 |
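-- A usage sketch (an assumption based on the shebang line above, not part
-- of the original script): the targets defined here are meant to be run
-- directly through stack's script interpreter, e.g.
--
-- > ./Build.hs          -- builds the default "all" target
-- > ./Build.hs exe      -- builds just the demo executables
-- > ./Build.hs clean    -- removes the .shake, .build, and demo-exe artifacts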
-------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: backprop-learn 2 | version: 0.1.0.0 3 | github: mstksg/backprop-learn 4 | license: BSD3 5 | author: Justin Le 6 | maintainer: justin@jle.im 7 | copyright: (c) Justin Le 2018 8 | tested-with: GHC >= 8.2 9 | 10 | extra-source-files: 11 | - README.md 12 | 13 | # Metadata used when publishing your package 14 | synopsis: Combinators and useful tools for ANNs using the backprop library 15 | category: Math 16 | 17 | # To avoid duplicated efforts in documentation and dealing with the 18 | # complications of embedding Haddock markup inside cabal files, it is 19 | # common to point users to the README.md file. 20 | description: See README.md 21 | 22 | ghc-options: 23 | - -Wall 24 | - -Wcompat 25 | - -Wincomplete-record-updates 26 | - -Wredundant-constraints 27 | # - -O0 28 | 29 | dependencies: 30 | - base >=4.7 && <5 31 | - ghc-typelits-extra 32 | - ghc-typelits-knownnat 33 | - ghc-typelits-natnormalise 34 | - hmatrix 35 | - hmatrix-backprop >= 0.1.3 36 | - mwc-random 37 | - opto 38 | 39 | library: 40 | source-dirs: src 41 | dependencies: 42 | - backprop >= 0.2.6.3 43 | - binary 44 | - bytestring 45 | - conduit 46 | - containers 47 | - deepseq 48 | - finite-typelits 49 | - foldl >= 1.4 50 | - functor-products 51 | - hmatrix-vector-sized >= 0.1.2 52 | - list-witnesses >= 0.1.2 53 | - microlens 54 | - microlens-th 55 | - one-liner 56 | - one-liner-instances 57 | - primitive 58 | - profunctors 59 | - singletons 60 | - statistics 61 | - transformers 62 | - typelits-witnesses 63 | - vector >= 0.12.0.2 64 | - vector-sized 65 | - vinyl 66 | 67 | _exec: &exec 68 | source-dirs: app 69 | ghc-options: 70 | - -threaded 71 | - -rtsopts 72 | - -with-rtsopts=-N 73 | - -O2 74 | 75 | executables: 76 | backprop-learn-mnist: 77 | <<: *exec 78 | main: mnist.hs 79 | dependencies: 80 | - backprop-learn 81 | - data-default 82 | - filepath 83 | - mnist-idx 84 | - transformers 85 | - vector >= 0.12.0.2 86 | backprop-learn-series: 87 | <<: *exec 88 | main: series.hs 89 | dependencies: 90 | - backprop >= 0.2.6.3 91 | - backprop-learn 92 | - conduit 93 | - data-default 94 | - deepseq 95 | - hmatrix-backprop >= 0.1.3 96 | - optparse-applicative 97 | - primitive 98 | - singletons 99 | 100 | - statistics 101 | - time 102 | - transformers 103 | - typelits-witnesses 104 | - vector-sized 105 | backprop-learn-char-rnn: 106 | <<: *exec 107 | main: char-rnn.hs 108 | dependencies: 109 | - backprop-learn 110 | - conduit 111 | - containers 112 | - data-default 113 | - deepseq 114 | - hmatrix-vector-sized 115 | - text 116 | - time 117 | - transformers 118 | - vector-sized 119 | backprop-learn-word2vec: 120 | <<: *exec 121 | main: word2vec.hs 122 | dependencies: 123 | - backprop-learn 124 | - conduit 125 | - containers 126 | - data-default 127 | - deepseq 128 | - hmatrix 129 | - text 130 | - time 131 | - transformers 132 | - vector >= 0.12.0.2 133 | - vector-sized 134 | 135 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Layer/Applying.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# LANGUAGE TypeInType #-} 10 | 11 | module Learn.Neural.Layer.Applying ( 12 | Applying 13 | , TensorOp(..) 14 | , CommonOp(..) 15 | , SoftMax 16 | ) where 17 | 18 | import Data.Kind 19 | import Data.Reflection 20 | import Data.Singletons 21 | import GHC.TypeLits 22 | import Learn.Neural.Layer 23 | import Numeric.BLAS 24 | import Numeric.Backprop 25 | 26 | data Applying :: k -> Type 27 | 28 | newtype TensorOp :: [Nat] -> [Nat] -> Type where 29 | TF :: { getTensorOp 30 | :: forall b s.
(BLAS b, Num (b i), Num (b o)) => OpB s '[ b i ] '[ b o ] 31 | } 32 | -> TensorOp i o 33 | 34 | instance Num (CParam (Applying s) b i o) where 35 | _ + _ = AppP 36 | _ * _ = AppP 37 | _ - _ = AppP 38 | negate _ = AppP 39 | abs _ = AppP 40 | signum _ = AppP 41 | fromInteger _ = AppP 42 | 43 | instance Fractional (CParam (Applying s) b i o) where 44 | _ / _ = AppP 45 | recip _ = AppP 46 | fromRational _ = AppP 47 | 48 | instance Floating (CParam (Applying s) b i o) where 49 | sqrt _ = AppP 50 | 51 | instance Num (CState (Applying s) b i o) where 52 | _ + _ = AppS 53 | _ * _ = AppS 54 | _ - _ = AppS 55 | negate _ = AppS 56 | abs _ = AppS 57 | signum _ = AppS 58 | fromInteger _ = AppS 59 | 60 | instance Fractional (CState (Applying s) b i o) where 61 | _ / _ = AppS 62 | recip _ = AppS 63 | fromRational _ = AppS 64 | 65 | instance Floating (CState (Applying s) b i o) where 66 | sqrt _ = AppS 67 | 68 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => Component (Applying s) b i o where 69 | data CParam (Applying s) b i o = AppP 70 | data CState (Applying s) b i o = AppS 71 | data CConf (Applying s) b i o = AppC 72 | 73 | componentOp = componentOpDefault 74 | 75 | initParam _ _ _ _ = return AppP 76 | initState _ _ _ _ = return AppS 77 | defConf = AppC 78 | 79 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => ComponentFF (Applying s) b i o where 80 | componentOpFF = bpOp . withInps $ \(x :< _ :< Ø) -> do 81 | y <- getTensorOp to ~$ (x :< Ø) 82 | return . only $ y 83 | where 84 | to :: TensorOp i o 85 | to = reflect (Proxy @s) 86 | 87 | instance (BLAS b, Reifies s (TensorOp i o), SingI i, SingI o) => ComponentLayer r (Applying s) b i o where 88 | componentRunMode = RMIsFF 89 | 90 | data CommonOp :: Type where 91 | TO_Softmax :: [Nat] -> CommonOp 92 | 93 | instance SingI i => Reifies ('TO_Softmax i) (TensorOp i i) where 94 | reflect _ = TF $ bpOp . withInps $ \(x :< Ø) -> do 95 | expX <- tmapOp exp ~$ (x :< Ø) 96 | totX <- tsumOp ~$ (expX :< Ø) 97 | sm <- scaleOp ~$ (1/totX :< expX :< Ø) 98 | return $ only sm 99 | 100 | type SoftMax i = Applying ('TO_Softmax i) 101 | -------------------------------------------------------------------------------- /old/src/Data/Type/Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE KindSignatures #-} 5 | {-# LANGUAGE LambdaCase #-} 6 | {-# LANGUAGE PolyKinds #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilyDependencies #-} 11 | {-# LANGUAGE TypeOperators #-} 12 | {-# OPTIONS_GHC -fno-warn-orphans #-} 13 | 14 | module Data.Type.Util ( 15 | MaybeToList 16 | , replWit 17 | , replLen 18 | , prodToVec' 19 | , vecToProd 20 | , vtraverse 21 | , vecLenNat 22 | , zipP 23 | , last' 24 | , takeVec 25 | , unzipVec 26 | ) where 27 | 28 | import Data.Finite 29 | import Data.Type.Combinator 30 | import Data.Type.Conjunction 31 | import Data.Type.Index 32 | import Data.Type.Length 33 | import Data.Type.Nat 34 | import Data.Type.Product hiding (last') 35 | import Data.Type.Vector 36 | import Numeric.Backprop.Op (Replicate) 37 | import Type.Class.Higher 38 | import Type.Class.Witness 39 | import Type.Family.Nat 40 | 41 | type family MaybeToList (a :: Maybe k) = (b :: [k]) | b -> a where 42 | MaybeToList ('Just a ) = '[a] 43 | MaybeToList 'Nothing = '[] 44 | 45 | replWit 46 | :: forall n c a. 
() 47 | => Nat n 48 | -> Wit (c a) 49 | -> Wit (Every c (Replicate n a)) 50 | replWit = \case 51 | Z_ -> (Wit \\) 52 | S_ n -> \case 53 | w@Wit -> Wit \\ replWit n w 54 | 55 | replLen 56 | :: forall n a. () 57 | => Nat n 58 | -> Length (Replicate n a) 59 | replLen = \case 60 | Z_ -> LZ 61 | S_ n -> LS (replLen @_ @a n) 62 | 63 | prodToVec' 64 | :: Nat n 65 | -> Prod f (Replicate n a) 66 | -> VecT n f a 67 | prodToVec' = \case 68 | Z_ -> \case 69 | Ø -> ØV 70 | S_ n -> \case 71 | x :< xs -> 72 | x :* prodToVec' n xs 73 | 74 | vecToProd 75 | :: VecT n f a 76 | -> Prod f (Replicate n a) 77 | vecToProd = \case 78 | ØV -> Ø 79 | x :* xs -> x :< vecToProd xs 80 | 81 | vtraverse 82 | :: Applicative h 83 | => (f a -> h (g b)) 84 | -> VecT n f a 85 | -> h (VecT n g b) 86 | vtraverse f = \case 87 | ØV -> pure ØV 88 | x :* xs -> (:*) <$> f x <*> vtraverse f xs 89 | 90 | zipP 91 | :: Prod f as 92 | -> Prod g as 93 | -> Prod (f :&: g) as 94 | zipP = \case 95 | Ø -> \case 96 | Ø -> Ø 97 | x :< xs -> \case 98 | y:< ys -> (x :&: y) :< zipP xs ys 99 | 100 | vecLenNat 101 | :: VecT n f a 102 | -> Nat n 103 | vecLenNat = \case 104 | ØV -> Z_ 105 | _ :* xs -> S_ (vecLenNat xs) 106 | 107 | instance Eq1 Finite 108 | 109 | last' 110 | :: VecT ('S n) f a 111 | -> f a 112 | last' = \case 113 | x :* ØV -> x 114 | _ :* xs@(_ :* _) -> last' xs 115 | 116 | takeVec 117 | :: Nat n 118 | -> [a] 119 | -> Maybe (Vec n a) 120 | takeVec = \case 121 | Z_ -> \_ -> Just ØV 122 | S_ n -> \case 123 | [] -> Nothing 124 | x:xs -> (x :+) <$> takeVec n xs 125 | 126 | unzipVec 127 | :: Vec n (a, b) 128 | -> (Vec n a, Vec n b) 129 | unzipVec = \case 130 | ØV -> (ØV, ØV) 131 | I (x,y) :* xsys -> 132 | let (xs, ys) = unzipVec xsys 133 | in (I x :* xs, I y :* ys) 134 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Train.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE PartialTypeSignatures #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} 10 | 11 | module Backprop.Learn.Train ( 12 | -- * Gradients 13 | gradLearnLoss 14 | , gradLearnStochLoss 15 | -- * Opto 16 | , Grad 17 | , learnGrad 18 | , learnGradStoch 19 | ) where 20 | 21 | import Backprop.Learn.Loss 22 | import Backprop.Learn.Model 23 | import Control.Monad.Primitive 24 | import Control.Monad.ST 25 | import Control.Monad.Sample 26 | import Data.Word 27 | import Numeric.Backprop 28 | import Numeric.Opto.Core 29 | import qualified Data.Vector.Unboxed as VU 30 | import qualified System.Random.MWC as MWC 31 | 32 | -- | Gradient of model with respect to loss function and target 33 | gradLearnLoss 34 | :: ( Learn a b l 35 | , NoState l 36 | , LParamMaybe l ~ 'Just (LParam l) 37 | , Backprop (LParam l) 38 | ) 39 | => Loss b 40 | -> Regularizer (LParam l) 41 | -> l 42 | -> LParam l 43 | -> a 44 | -> b 45 | -> LParam l 46 | gradLearnLoss loss reg l p x y = gradBP (\p' -> 47 | loss y (runLearnStateless l (J_ p') (constVar x)) + reg p' 48 | ) p 49 | 50 | -- | Stochastic gradient of model with respect to loss function and target 51 | gradLearnStochLoss 52 | :: ( Learn a b l 53 | , NoState l 54 | , LParamMaybe l ~ 'Just (LParam l) 55 | , Backprop (LParam l) 56 | , PrimMonad m 57 | ) 58 | => Loss b 59 | -> Regularizer 
(LParam l) 60 | -> l 61 | -> MWC.Gen (PrimState m) 62 | -> LParam l 63 | -> a 64 | -> b 65 | -> m (LParam l) 66 | gradLearnStochLoss loss reg l g p x y = do 67 | seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2 68 | pure $ gradBP (\p' -> runST $ do 69 | g' <- MWC.initialize seed 70 | lo <- loss y <$> runLearnStochStateless l g' (J_ p') (constVar x) 71 | pure (lo + reg p') 72 | ) p 73 | 74 | -- | Using a model's deterministic prediction function (with a given loss 75 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and 76 | -- "Numeric.Opto.Run". 77 | learnGrad 78 | :: ( MonadSample (a, b) m 79 | , Learn a b l 80 | , NoState l 81 | , LParamMaybe l ~ 'Just (LParam l) 82 | , Backprop (LParam l) 83 | ) 84 | => Loss b 85 | -> Regularizer (LParam l) 86 | -> l 87 | -> Grad m (LParam l) 88 | learnGrad loss reg l = pureSampling $ \(x,y) p -> gradLearnLoss loss reg l p x y 89 | 90 | -- | Using a model's stochastic prediction function (with a given loss 91 | -- function), generate a 'Grad' compatible with "Numeric.Opto" and 92 | -- "Numeric.Opto.Run". 93 | learnGradStoch 94 | :: ( MonadSample (a, b) m 95 | , PrimMonad m 96 | , Learn a b l 97 | , NoState l 98 | , LParamMaybe l ~ 'Just (LParam l) 99 | , Backprop (LParam l) 100 | ) 101 | => Loss b 102 | -> Regularizer (LParam l) 103 | -> l 104 | -> MWC.Gen (PrimState m) 105 | -> Grad m (LParam l) 106 | learnGradStoch loss reg l g = sampling $ \(x,y) p -> 107 | gradLearnStochLoss loss reg l g p x y 108 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Model/Parameter.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveDataTypeable #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE TypeFamilies #-} 12 | {-# LANGUAGE TypeInType #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | 15 | module Backprop.Learn.Model.Parameter ( 16 | deParam, deParamD 17 | , reParam, reParamD 18 | , dummyParam 19 | ) where 20 | 21 | import Backprop.Learn.Model.Types 22 | import Control.Monad.Primitive 23 | import Numeric.Backprop 24 | import qualified System.Random.MWC as MWC 25 | 26 | -- | Fix a part of a parameter as constant, preventing backpropagation 27 | -- through it and not training it. 28 | -- 29 | -- Treats a @pq@ parameter as essentially a @(p, q)@, witnessed through the 30 | -- split and join functions. 31 | -- 32 | -- Takes the fixed value of @q@, as well as a stochastic mode version with 33 | -- fixed distribution. 34 | deParam 35 | :: forall p q pq s a b. (Backprop p, Backprop q) 36 | => (pq -> (p, q)) -- ^ split 37 | -> (p -> q -> pq) -- ^ join 38 | -> q -- ^ fixed param 39 | -> (forall m. (PrimMonad m) => MWC.Gen (PrimState m) -> m q) -- ^ fixed stoch param 40 | -> Model ('Just pq) s a b 41 | -> Model ('Just p ) s a b 42 | deParam spl joi q qStoch = reParam (PJust . r . fromPJust) 43 | (\g -> fmap PJust . rStoch g . fromPJust) 44 | where 45 | r :: Reifies z W => BVar z p -> BVar z pq 46 | r p = isoVar2 joi spl p (auto q) 47 | rStoch 48 | :: (PrimMonad m, Reifies z W) 49 | => MWC.Gen (PrimState m) 50 | -> BVar z p 51 | -> m (BVar z pq) 52 | rStoch g p = isoVar2 joi spl p . auto <$> qStoch g
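-- A minimal usage sketch (hypothetical, not from the original module):
-- when the composite parameter @pq@ is literally a pair, the split and
-- join witnesses are just 'id' and '(,)', so fixing the second component
-- of a model's parameter with 'deParamD' (defined below) reduces to
--
-- > fixSecond
-- >     :: (Backprop p, Backprop q)
-- >     => q
-- >     -> Model ('Just (p, q)) s a b
-- >     -> Model ('Just p     ) s a b
-- > fixSecond = deParamD id (,)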
53 | 54 | -- | 'deParam', but with no special stochastic mode version. 55 | deParamD 56 | :: (Backprop p, Backprop q) 57 | => (pq -> (p, q)) -- ^ split 58 | -> (p -> q -> pq) -- ^ join 59 | -> q -- ^ fixed param 60 | -> Model ('Just pq) s a b 61 | -> Model ('Just p ) s a b 62 | deParamD spl joi q = deParam spl joi q (const (pure q)) 63 | 64 | -- | Pre-applies a function to a parameter before a model sees it. 65 | -- Essentially something like 'lmap' for parameters. 66 | -- 67 | -- Takes a deterministic function and also a stochastic function for 68 | -- stochastic mode. 69 | reParam 70 | :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p) 71 | -> (forall m z. (PrimMonad m, Reifies z W) => MWC.Gen (PrimState m) -> PMaybe (BVar z) q -> m (PMaybe (BVar z) p)) 72 | -> Model p s a b 73 | -> Model q s a b 74 | reParam r rStoch f = Model 75 | { runLearn = runLearn f . r 76 | , runLearnStoch = \g p x s -> do 77 | q <- rStoch g p 78 | runLearnStoch f g q x s 79 | } 80 | 81 | -- | 'reParam', but with no special stochastic mode function. 82 | reParamD 83 | :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p) 84 | -> Model p s a b 85 | -> Model q s a b 86 | reParamD r = reParam r (\_ -> pure . r) 87 | 88 | -- | Give an unparameterized model a "dummy" parameter. Useful for usage 89 | -- with combinators like 'Control.Category..' that require all input 90 | -- models to share a common parameterization. 91 | dummyParam 92 | :: Model 'Nothing s a b 93 | -> Model p s a b 94 | dummyParam = reParamD (const PNothing) 95 | -------------------------------------------------------------------------------- /old/app/MNIST.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE TypeOperators #-} 4 | 5 | import Control.DeepSeq 6 | import Control.Exception 7 | import Control.Monad 8 | import Control.Monad.IO.Class 9 | import Control.Monad.Trans.Maybe 10 | import Control.Monad.Trans.State 11 | import Data.Bifunctor 12 | import Data.Bitraversable 13 | import Data.Default 14 | import Data.Finite 15 | import Data.IDX 16 | import Data.List.Split 17 | import Data.Time.Clock 18 | import Data.Traversable 19 | import Data.Tuple 20 | import Data.Type.Product 21 | import Learn.Neural 22 | import Numeric.BLAS.HMatrix 23 | import Numeric.LinearAlgebra.Static 24 | import Text.Printf 25 | import qualified Data.Vector as V 26 | import qualified Data.Vector.Generic as VG 27 | import qualified Data.Vector.Unboxed as VU 28 | import qualified System.Random.MWC as MWC 29 | import qualified System.Random.MWC.Distributions as MWC 30 | 31 | loadMNIST 32 | :: FilePath 33 | -> FilePath 34 | -> IO (Maybe [(HM '[784], HM '[10])]) 35 | loadMNIST fpI fpL = runMaybeT $ do 36 | i <- MaybeT $ decodeIDXFile fpI 37 | l <- MaybeT $ decodeIDXLabelsFile fpL 38 | d <- MaybeT . return $ labeledIntData l i 39 | r <- MaybeT . return $ for d (bitraverse mkImage mkLabel . swap) 40 | liftIO . evaluate $ force r 41 | where 42 | mkImage :: VU.Vector Int -> Maybe (HM '[784]) 43 | mkImage = fmap HM . create . VG.convert . VG.map (\i -> fromIntegral i / 255) 44 | mkLabel :: Int -> Maybe (HM '[10]) 45 | mkLabel = fmap (oneHot . only) . packFinite .
fromIntegral 46 | 47 | main :: IO () 48 | main = MWC.withSystemRandom $ \g -> do 49 | Just train <- loadMNIST "data/train-images-idx3-ubyte" "data/train-labels-idx1-ubyte" 50 | Just test <- loadMNIST "data/t10k-images-idx3-ubyte" "data/t10k-labels-idx1-ubyte" 51 | putStrLn "Loaded data." 52 | net0 :: Network 'FeedForward HM ( '[784] :~ FullyConnected ) 53 | '[ '[300] :~ LogitMap 54 | , '[300] :~ FullyConnected 55 | , '[100] :~ LogitMap 56 | , '[100] :~ FullyConnected 57 | , '[10 ] :~ SoftMax '[10] 58 | ] 59 | '[10] <- initDefNet g 60 | let dout = alongNet net0 $ Nothing 61 | :&% Just 0.2 62 | :&% Nothing 63 | :&% Just 0.2 64 | :&% Nothing 65 | :&% DOExt 66 | flip evalStateT net0 . forM_ [1..] $ \e -> do 67 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList train) g 68 | liftIO $ printf "[Epoch %d]\n" (e :: Int) 69 | 70 | forM_ ([1..] `zip` chunksOf batch train') $ \(b, chnk) -> StateT $ \n0 -> do 71 | printf "(Batch %d)\n" (b :: Int) 72 | 73 | t0 <- getCurrentTime 74 | -- n' <- evaluate $ optimizeList_ (bimap only_ only_ <$> chnk) n0 75 | -- -- (sgdOptimizer rate netOpPure crossEntropy) 76 | -- (adamOptimizer def netOpPure crossEntropy) 77 | n' <- optimizeListM_ (bimap only_ only_ <$> chnk) n0 78 | (adamOptimizerM def (netOpDOPure dout g) crossEntropy) 79 | t1 <- getCurrentTime 80 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0)) 81 | 82 | let trainScore = testNetList maxTest (someNet n') chnk 83 | testScore = testNetList maxTest (someNet n') test 84 | printf "Training error: %.2f%%\n" ((1 - trainScore) * 100) 85 | printf "Validation error: %.2f%%\n" ((1 - testScore ) * 100) 86 | 87 | return ((), n') 88 | where 89 | -- rate :: Double 90 | -- rate = 0.02 91 | batch :: Int 92 | batch = 2500 93 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 
3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: 7 | - completed: 8 | size: 26685 9 | url: https://github.com/mstksg/opto/archive/def9e41adc1123385037f92ead38c92cb6454dd2.tar.gz 10 | cabal-file: 11 | size: 2442 12 | sha256: 163eeb672b5afc9f7147275d786060a80308c31e57a84d94ebbb62c74c059ee8 13 | name: opto 14 | version: 0.1.0.0 15 | sha256: 627e8bf47074ae57b1ffa843810177c2df72fe3993de59aa2f8c68c6722b67f8 16 | pantry-tree: 17 | size: 1215 18 | sha256: f3dedce3dc9e88ec9fe99f0b69b8893fb2047e8c780d901e5832cfd2e1f52aae 19 | original: 20 | url: https://github.com/mstksg/opto/archive/def9e41adc1123385037f92ead38c92cb6454dd2.tar.gz 21 | - completed: 22 | hackage: backprop-0.2.6.3@sha256:a38430a6b4539cebd3c41d4fad9c48c2654fad6f6455d3b119882e7bf4804cfc,2805 23 | pantry-tree: 24 | size: 2123 25 | sha256: 433f277d8d4a0ad480fe2f0e8f71bd1f2cc71604b1fd11a15e8c86c81424648f 26 | original: 27 | hackage: backprop-0.2.6.3 28 | - completed: 29 | hackage: decidable-0.3.0.0@sha256:34857003b57139a047c9ab7944c313c227d9db702a8dcefa1478966257099423,1774 30 | pantry-tree: 31 | size: 764 32 | sha256: cc9b297fe8d4b4606583d24ff17a310ce227db3edd9a3d9c35df90582431ff68 33 | original: 34 | hackage: decidable-0.3.0.0 35 | - completed: 36 | hackage: functor-products-0.1.1.0@sha256:2bea36b6106b5756be6b81b3a5bfe7b41db1cf45fb63c19a1f04b572ba90fd0c,1456 37 | pantry-tree: 38 | size: 408 39 | sha256: 6c7d58498a2c23338baa8275a51e9099739812b1bd36126f887e5cdf57cce45b 40 | original: 41 | hackage: functor-products-0.1.1.0 42 | - completed: 43 | hackage: hmatrix-vector-sized-0.1.2.0@sha256:5b13ea0d371c54293e8a2159d3cea1f1466b8908331e1a1956c26b8732ca67a3,1881 44 | pantry-tree: 45 | size: 402 46 | sha256: 38bbeddbc873627d95d789a55d12104e8fe376f80cf9ae3ef3ca4ca282e5b177 47 | original: 48 | hackage: hmatrix-vector-sized-0.1.2.0 49 | - completed: 50 | hackage: hmatrix-backprop-0.1.3.0@sha256:eff786a68b34d44df9bd461c5890d0d876684f34e5fd4b6078a67263d21279f1,2294 51 | pantry-tree: 52 | size: 455 53 | sha256: 30892afe2120ad983d39bbda594a9d18b4d44572ba9d032e271455ba1b00091d 54 | original: 55 | hackage: hmatrix-backprop-0.1.3.0 56 | - completed: 57 | hackage: list-witnesses-0.1.3.2@sha256:a72d5c2ca295b89313842c405e8ccd972153d0476dba00387515312931b87a3f,1670 58 | pantry-tree: 59 | size: 398 60 | sha256: 5c1c96db8ce32341b5693ae7c110c6e51e09d9aa98baaf6e76f0c5b965db2c4c 61 | original: 62 | hackage: list-witnesses-0.1.3.2 63 | - completed: 64 | hackage: typelits-witnesses-0.4.0.0@sha256:1d7092ba98fdc33f4b413e04144eb3ead7b105f74b2998e3c74a8a0feee685a9,1985 65 | pantry-tree: 66 | size: 403 67 | sha256: 2ee741f6bb4dba710e6449da335fdcf8940adb767798b29fdb8ae2606d22e0cb 68 | original: 69 | hackage: typelits-witnesses-0.4.0.0 70 | - completed: 71 | hackage: vinyl-0.12.0@sha256:6136e2608c2c4be0c112944fb0f5a6a0df56b50adec12eb1b7240258abfcf9b1,3790 72 | pantry-tree: 73 | size: 1857 74 | sha256: aeb9e0e1a3bbe2b1f048a096430d240964a31c6936a1da89f4b32e931eba9d69 75 | original: 76 | hackage: vinyl-0.12.0 77 | - completed: 78 | hackage: dependent-sum-0.6.2.0@sha256:bff37c85b38e768b942f9d81c2465b63a96076f1ba006e35612aa357770807b6,1856 79 | pantry-tree: 80 | size: 474 81 | sha256: ad3fbed5104f9ee9c8082c9dcc8ade847674e3053572533e8d26ad1a866f1107 82 | original: 83 | hackage: dependent-sum-0.6.2.0 84 | - completed: 85 | hackage: constraints-extras-0.3.0.2@sha256:bf6884be65958e9188ae3c9e5547abfd6d201df021bff8a4704c2c4fe1e1ae5b,1784 86 | pantry-tree: 87 | size: 594 88 | sha256: 
b0bcc96d375ee11b1972a2e9e8e42039b3f420b0e1c46e9c70652470445a6505 89 | original: 90 | hackage: constraints-extras-0.3.0.2 91 | snapshots: 92 | - completed: 93 | size: 471743 94 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/nightly/2020/1/30.yaml 95 | sha256: d55e2ae57e1af8641591b271a0315405ea34690c4ee50b4eaf7445ee1780ada2 96 | original: nightly-2020-01-30 97 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Loss.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# OPTIONS_GHC -Wno-redundant-constraints #-} 8 | 9 | module Backprop.Learn.Loss ( 10 | -- * Loss functions 11 | Loss 12 | , crossEntropy, crossEntropy1 13 | , squaredError, absError, totalSquaredError, squaredErrorV 14 | -- , totalCov 15 | -- ** Manipulate loss functions 16 | , scaleLoss 17 | , sumLoss 18 | , sumLossDecay 19 | , lastLoss 20 | , zipLoss 21 | , t2Loss 22 | -- * Regularization 23 | , Regularizer 24 | , l2Reg 25 | , l1Reg 26 | , noReg 27 | -- ** Manipulate regularizers 28 | , addReg 29 | , scaleReg 30 | ) where 31 | 32 | import Backprop.Learn.Regularize 33 | import Control.Applicative 34 | import Data.Coerce 35 | import Data.Finite 36 | import Data.Type.Tuple 37 | import GHC.TypeNats 38 | import Lens.Micro 39 | import Numeric.Backprop 40 | import Numeric.LinearAlgebra.Static.Backprop 41 | import qualified Data.Vector.Sized as SV 42 | import qualified Prelude.Backprop as B 43 | 44 | type Loss a = forall s. Reifies s W => a -> BVar s a -> BVar s Double 45 | 46 | crossEntropy :: KnownNat n => Loss (R n) 47 | crossEntropy targ res = -(log res <.> auto targ) 48 | 49 | crossEntropy1 :: Loss Double 50 | crossEntropy1 targ res = -(log res * auto targ + log (1 - res) * auto (1 - targ)) 51 | 52 | squaredErrorV :: KnownNat n => Loss (R n) 53 | squaredErrorV targ res = e <.> e 54 | where 55 | e = res - auto targ 56 | 57 | totalSquaredError 58 | :: (Backprop (t Double), Num (t Double), Foldable t, Functor t) 59 | => Loss (t Double) 60 | totalSquaredError targ res = B.sum (e * e) 61 | where 62 | e = auto targ - res 63 | 64 | squaredError :: Loss Double 65 | squaredError targ res = (res - auto targ) ** 2 66 | 67 | absError :: Loss Double 68 | absError targ res = abs (res - auto targ) 69 | 70 | -- -- | Sum of covariances between matching components. Not sure if anyone 71 | -- -- uses this. 72 | -- totalCov :: (Backprop (t Double), Foldable t, Functor t) => Loss (t Double) 73 | -- totalCov targ res = -(xy / fromIntegral n - (x * y) / fromIntegral (n * n)) 74 | -- where 75 | -- x = auto $ sum targ 76 | -- y = B.sum res 77 | -- xy = B.sum (auto targ * res) 78 | -- n = length targ 79 | 80 | -- klDivergence :: Loss Double 81 | -- klDivergence = 82 | 83 | sumLoss 84 | :: (Traversable t, Applicative t, Backprop a) 85 | => Loss a 86 | -> Loss (t a) 87 | sumLoss l targ = sum . liftA2 l targ . sequenceVar 88 | 89 | zipLoss 90 | :: (Traversable t, Applicative t, Backprop a) 91 | => t Double 92 | -> Loss a 93 | -> Loss (t a) 94 | zipLoss xs l targ = sum 95 | . liftA3 (\α t -> (* auto α) . l t) xs targ 96 | . sequenceVar 97 | 98 | sumLossDecay 99 | :: forall n a. 
(KnownNat n, Backprop a) 100 | => Double 101 | -> Loss a 102 | -> Loss (SV.Vector n a) 103 | sumLossDecay β = zipLoss $ SV.generate (\i -> β ** (fromIntegral i - n)) 104 | where 105 | n = fromIntegral $ maxBound @(Finite n) 106 | 107 | lastLoss 108 | :: forall n a. (KnownNat (n + 1), Backprop a) 109 | => Loss a 110 | -> Loss (SV.Vector (n + 1) a) 111 | lastLoss l targ = l (SV.last targ) 112 | . viewVar (coerced . SV.ix @(n + 1) maxBound) 113 | . B.coerce @_ @(ABP (SV.Vector (n + 1)) a) 114 | 115 | coerced :: Coercible a b => Lens' a b 116 | coerced f x = coerce <$> f (coerce x) 117 | {-# INLINE coerced #-} 118 | 119 | -- | Scale the result of a loss function. 120 | scaleLoss :: Double -> Loss a -> Loss a 121 | scaleLoss β l x = (* auto β) . l x 122 | 123 | -- | Lift and sum a loss function over the components of a ':#'. 124 | t2Loss 125 | :: (Backprop a, Backprop b) 126 | => Loss a -- ^ loss on first component 127 | -> Loss b -- ^ loss on second component 128 | -> Loss (a :# b) 129 | t2Loss f g (xT :# yT) (xR :## yR) = f xT xR + g yT yR 130 | 131 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Model/Stochastic.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveDataTypeable #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeInType #-} 12 | {-# LANGUAGE UndecidableInstances #-} 13 | {-# LANGUAGE ViewPatterns #-} 14 | 15 | module Backprop.Learn.Model.Stochastic ( 16 | dropout 17 | , rreLU 18 | , injectNoise, applyNoise 19 | , injectNoiseR, applyNoiseR 20 | ) where 21 | 22 | import Backprop.Learn.Model.Function 23 | import Backprop.Learn.Model.Types 24 | import Control.Monad.Primitive 25 | import Data.Bool 26 | import GHC.TypeNats 27 | import Numeric.Backprop 28 | import Numeric.LinearAlgebra.Static.Backprop 29 | import Numeric.LinearAlgebra.Static.Vector 30 | import qualified Data.Vector.Storable.Sized as SVS 31 | import qualified Statistics.Distribution as Stat 32 | import qualified System.Random.MWC as MWC 33 | import qualified System.Random.MWC.Distributions as MWC 34 | 35 | -- | Dropout layer. Parameterized by dropout percentage (should be between 36 | -- 0 and 1). 37 | -- 38 | -- 0 corresponds to no dropout, 1 corresponds to complete dropout of all 39 | -- nodes every time. 40 | dropout 41 | :: KnownNat n 42 | => Double 43 | -> Model 'Nothing 'Nothing (R n) (R n) 44 | dropout r = Func 45 | { runFunc = (auto (realToFrac (1 - r)) *) 46 | , runFuncStoch = \g x -> do 47 | (x *) . auto . vecR <$> SVS.replicateM (mask g) 48 | } 49 | where 50 | mask :: PrimMonad m => MWC.Gen (PrimState m) -> m Double 51 | mask = fmap (bool 1 0) . MWC.bernoulli r 52 | 53 | -- | Random leaky rectified linear unit 54 | rreLU 55 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 56 | => d 57 | -> Model 'Nothing 'Nothing (R n) (R n) 58 | rreLU d = Func 59 | { runFunc = vmap' (preLU v) 60 | , runFuncStoch = \g x -> do 61 | α <- vecR <$> SVS.replicateM (Stat.genContVar d g) 62 | pure (zipWithVector preLU (constVar α) x) 63 | } 64 | where 65 | v :: BVar s Double 66 | v = auto (Stat.mean d) 67 |
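-- A minimal usage sketch (hypothetical size, not from the original
-- module): because 'dropout' has no parameter and no state, it slots
-- between any two layers of matching width; in non-stochastic mode it
-- simply scales its input by @1 - r@:
--
-- > hiddenDrop :: Model 'Nothing 'Nothing (R 100) (R 100)
-- > hiddenDrop = dropout 0.25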
68 | -- | Inject random noise. Usually used between neural network layers, or 69 | -- at the very beginning to pre-process input. 70 | -- 71 | -- In non-stochastic mode, this adds the mean of the distribution. 72 | injectNoise 73 | :: (Stat.ContGen d, Stat.Mean d, Fractional a) 74 | => d 75 | -> Model 'Nothing 'Nothing a a 76 | injectNoise d = Func 77 | { runFunc = (realToFrac (Stat.mean d) +) 78 | , runFuncStoch = \g x -> do 79 | e <- Stat.genContVar d g 80 | pure (realToFrac e + x) 81 | } 82 | 83 | -- | 'injectNoise' lifted to 'R' 84 | injectNoiseR 85 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 86 | => d 87 | -> Model 'Nothing 'Nothing (R n) (R n) 88 | injectNoiseR d = Func 89 | { runFunc = (realToFrac (Stat.mean d) +) 90 | , runFuncStoch = \g x -> do 91 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g) 92 | pure (constVar e + x) 93 | } 94 | 95 | -- | Multiply by random noise. Can be used to implement dropout-like 96 | -- behavior. 97 | -- 98 | -- In non-stochastic mode, this scales by the mean of the distribution. 99 | applyNoise 100 | :: (Stat.ContGen d, Stat.Mean d, Fractional a) 101 | => d 102 | -> Model 'Nothing 'Nothing a a 103 | applyNoise d = Func 104 | { runFunc = (realToFrac (Stat.mean d) *) 105 | , runFuncStoch = \g x -> do 106 | e <- Stat.genContVar d g 107 | pure (realToFrac e * x) 108 | } 109 | 110 | -- | 'applyNoise' lifted to 'R' 111 | applyNoiseR 112 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 113 | => d 114 | -> Model 'Nothing 'Nothing (R n) (R n) 115 | applyNoiseR d = Func 116 | { runFunc = (realToFrac (Stat.mean d) *) 117 | , runFuncStoch = \g x -> do 118 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g) 119 | pure (constVar e * x) 120 | } 121 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Run.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE RankNTypes #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE TupleSections #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | 8 | module Backprop.Learn.Run ( 9 | consecutives 10 | , consecutivesN 11 | , leadings 12 | , conduitLearn, conduitLearnStoch 13 | -- * Encoding and decoding for learning 14 | , oneHot', oneHot, oneHotR 15 | , SVG.maxIndex, maxIndexR 16 | ) where 17 | 18 | import Backprop.Learn.Model 19 | import Control.Monad 20 | import Control.Monad.Primitive 21 | import Control.Monad.Trans.Class 22 | import Control.Monad.Trans.Maybe 23 | import Data.Bool 24 | import Data.Conduit 25 | import Data.Finite 26 | import Data.Foldable 27 | import Data.Proxy 28 | import GHC.TypeNats 29 | import Numeric.LinearAlgebra.Static 30 | import Numeric.LinearAlgebra.Static.Vector 31 | import qualified Data.Conduit.Combinators as C 32 | import qualified Data.Sequence as Seq 33 | import qualified Data.Vector.Generic as VG 34 | import qualified Data.Vector.Generic.Sized as SVG 35 | import qualified System.Random.MWC as MWC 36 | 37 | consecutives :: Monad m => ConduitT i (i, i) m () 38 | consecutives = void . runMaybeT $ do 39 | x <- MaybeT await 40 | go x 41 | where 42 | go x = do 43 | y <- MaybeT await 44 | lift $ yield (x, y) 45 | go y 46 | 47 | consecutivesN 48 | :: forall v n i m.
(KnownNat n, VG.Vector v i, Monad m) 49 | => ConduitT i (SVG.Vector v n i, SVG.Vector v n i) m () 50 | consecutivesN = conseq (fromIntegral n) .| C.concatMap process 51 | where 52 | n = natVal (Proxy @n) 53 | process (xs, ys, _) = (,) <$> SVG.fromList (toList xs) 54 | <*> SVG.fromList (toList ys) 55 | 56 | leadings 57 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m) 58 | => ConduitT i (SVG.Vector v n i, i) m () 59 | leadings = conseq (fromIntegral n) .| C.concatMap process 60 | where 61 | n = natVal (Proxy @n) 62 | process (xs, _, y) = (, y) <$> SVG.fromList (toList xs) 63 | 64 | conseq 65 | :: forall i m. Monad m 66 | => Int 67 | -> ConduitT i (Seq.Seq i, Seq.Seq i, i) m () 68 | conseq n = void . runMaybeT $ do 69 | xs <- Seq.replicateM n $ MaybeT await 70 | go xs 71 | where 72 | go xs = do 73 | _ Seq.:<| xs' <- pure xs 74 | y <- MaybeT await 75 | let ys = xs' Seq.:|> y 76 | lift $ yield (xs, ys, y) 77 | go ys 78 | 79 | conduitLearn 80 | :: (Learn a b l, Backprop b, MaybeC Backprop (LStateMaybe l), Monad m) 81 | => l 82 | -> LParam_ I l 83 | -> LState_ I l 84 | -> ConduitT a b m (LState_ I l) 85 | conduitLearn l p = go 86 | where 87 | go s = do 88 | mx <- await 89 | case mx of 90 | Nothing -> return s 91 | Just x -> do 92 | let (y, s') = runLearn_ l p x s 93 | yield y 94 | go s' 95 | 96 | conduitLearnStoch 97 | :: (Learn a b l, Backprop b, MaybeC Backprop (LStateMaybe l), PrimMonad m) 98 | => l 99 | -> MWC.Gen (PrimState m) 100 | -> LParam_ I l 101 | -> LState_ I l 102 | -> ConduitT a b m (LState_ I l) 103 | conduitLearnStoch l g p = go 104 | where 105 | go s = do 106 | mx <- await 107 | case mx of 108 | Nothing -> return s 109 | Just x -> do 110 | (y, s') <- lift $ runLearnStoch_ l g p x s 111 | yield y 112 | go s' 113 | 114 | -- | What module should this be in? 115 | oneHot' 116 | :: (VG.Vector v a, KnownNat n) 117 | => a -- ^ not hot 118 | -> a -- ^ hot 119 | -> Finite n 120 | -> SVG.Vector v n a 121 | oneHot' nothot hot i = SVG.generate (bool nothot hot . (== i)) 122 | 123 | oneHot 124 | :: (VG.Vector v a, KnownNat n, Num a) 125 | => Finite n 126 | -> SVG.Vector v n a 127 | oneHot = oneHot' 0 1 128 | 129 | oneHotR :: KnownNat n => Finite n -> R n 130 | oneHotR = vecR . oneHot 131 | 132 | -- | Could be in /hmatrix/. 133 | maxIndexR :: KnownNat n => R (n + 1) -> Finite (n + 1) 134 | maxIndexR = SVG.maxIndex . 
rVec 135 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Layer/FullyConnected.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | {-# LANGUAGE ExistentialQuantification #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE InstanceSigs #-} 5 | {-# LANGUAGE LambdaCase #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PolyKinds #-} 8 | {-# LANGUAGE StandaloneDeriving #-} 9 | {-# LANGUAGE TypeFamilies #-} 10 | {-# LANGUAGE TypeInType #-} 11 | 12 | module Learn.Neural.Layer.FullyConnected ( 13 | FullyConnected 14 | ) where 15 | 16 | import Data.Kind 17 | import Data.Singletons.Prelude 18 | import Data.Singletons.TypeLits 19 | import GHC.Generics (Generic) 20 | import GHC.Generics.Numeric 21 | import Learn.Neural.Layer 22 | import Numeric.BLAS 23 | import Numeric.Backprop 24 | import Statistics.Distribution 25 | import Statistics.Distribution.Normal 26 | import qualified Generics.SOP as SOP 27 | 28 | data FullyConnected :: Type 29 | 30 | instance (Num (b '[o,i]), Num (b '[o])) => Num (CParam FullyConnected b '[i] '[o]) where 31 | FCP w1 b1 + FCP w2 b2 = FCP (w1 + w2) (b1 + b2) 32 | FCP w1 b1 - FCP w2 b2 = FCP (w1 - w2) (b1 - b2) 33 | FCP w1 b1 * FCP w2 b2 = FCP (w1 * w2) (b1 * b2) 34 | negate (FCP w b) = FCP (negate w) (negate b) 35 | signum (FCP w b) = FCP (signum w) (signum b) 36 | abs (FCP w b) = FCP (abs w) (abs b) 37 | fromInteger x = FCP (fromInteger x) (fromInteger x) 38 | 39 | instance (Fractional (b '[o,i]), Fractional (b '[o])) => Fractional (CParam FullyConnected b '[i] '[o]) where 40 | FCP w1 b1 / FCP w2 b2 = FCP (w1 / w2) (b1 / b2) 41 | recip (FCP w b) = FCP (recip w) (recip b) 42 | fromRational x = FCP (fromRational x) (fromRational x) 43 | 44 | instance (Floating (b '[o,i]), Floating (b '[o])) => Floating (CParam FullyConnected b '[i] '[o]) where 45 | sqrt (FCP w b) = FCP (sqrt w) (sqrt b) 46 | 47 | instance Num (CState FullyConnected b '[i] '[o]) where 48 | _ + _ = FCS 49 | _ * _ = FCS 50 | _ - _ = FCS 51 | negate _ = FCS 52 | abs _ = FCS 53 | signum _ = FCS 54 | fromInteger _ = FCS 55 | 56 | instance Fractional (CState FullyConnected b '[i] '[o]) where 57 | _ / _ = FCS 58 | recip _ = FCS 59 | fromRational _ = FCS 60 | 61 | instance Floating (CState FullyConnected b '[i] '[o]) where 62 | sqrt _ = FCS 63 | 64 | 65 | 66 | 67 | 68 | deriving instance Generic (CParam FullyConnected b '[i] '[o]) 69 | instance SOP.Generic (CParam FullyConnected b '[i] '[o]) 70 | 71 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o])) 72 | => Component FullyConnected b '[i] '[o] where 73 | data CParam FullyConnected b '[i] '[o] = 74 | FCP { _fcWeights :: !(b '[o,i]) 75 | , _fcBiases :: !(b '[o]) 76 | } 77 | data CState FullyConnected b '[i] '[o] = FCS 78 | type CConstr FullyConnected b '[i] '[o] = Num (b '[o,i]) 79 | data CConf FullyConnected b '[i] '[o] = forall d. ContGen d => FCC d 80 | 81 | componentOp = componentOpDefault 82 | 83 | initParam = \case 84 | i `SCons` SNil -> \case 85 | so@(o `SCons` SNil) -> \(FCC d) g -> do 86 | w <- genA (o `SCons` (i `SCons` SNil)) $ \_ -> 87 | realToFrac <$> genContVar d g 88 | b <- genA so $ \_ -> 89 | realToFrac <$> genContVar d g 90 | return $ FCP w b 91 | _ -> error "inaccessible." 92 | _ -> error "inaccessible." 
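    -- For reference, the forward pass this layer ultimately computes (see
    -- 'componentOpFF' below) is the affine map
    --
    -- @
    -- y = w x + b
    -- @
    --
    -- where @w@ is the @'[o,i]@ weight tensor '_fcWeights' and @b@ is the
    -- @'[o]@ bias '_fcBiases': a matrix-vector product via 'matVecOp',
    -- followed by adding the bias with '(+.)'.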
93 | 94 | initState _ _ _ _ = return FCS 95 | 96 | defConf = FCC (normalDistr 0 0.01) 97 | 98 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o])) 99 | => ComponentFF FullyConnected b '[i] '[o] where 100 | componentOpFF = bpOp . withInps $ \(x :< p :< Ø) -> do 101 | w :< b :< Ø <- gTuple #<~ p 102 | y <- matVecOp ~$ (w :< x :< Ø) 103 | z <- (+.) ~$ (y :< b :< Ø) 104 | return . only $ z 105 | 106 | instance (BLAS b, KnownNat i, KnownNat o, Floating (b '[o,i]), Floating (b '[o])) 107 | => ComponentLayer r FullyConnected b '[i] '[o] where 108 | componentRunMode = RMIsFF 109 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Model/Neural.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-# LANGUAGE PatternSynonyms #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeOperators #-} 12 | {-# LANGUAGE UndecidableInstances #-} 13 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} 15 | 16 | module Backprop.Learn.Model.Neural ( 17 | -- * Feed-forward 18 | FCp, fc, fca 19 | , fcWeights, fcBias 20 | -- * Recurrent 21 | , fcr, fcra 22 | , FCRp, fcrBias, fcrInputWeights, fcrStateWeights 23 | ) where 24 | 25 | 26 | import Backprop.Learn.Model.Combinator 27 | import Backprop.Learn.Model.Regression 28 | import Backprop.Learn.Model.State 29 | import Backprop.Learn.Model.Types 30 | import Data.Tuple 31 | import GHC.TypeNats 32 | import Lens.Micro 33 | import Numeric.Backprop 34 | import Numeric.LinearAlgebra.Static.Backprop 35 | import qualified Numeric.LinearAlgebra.Static as H 36 | 37 | -- | Parameters for fully connected feed-forward layer with bias. 38 | type FCp = LRp 39 | 40 | fcWeights :: Lens (FCp i o) (FCp i' o) (L o i) (L o i') 41 | fcWeights = lrBeta 42 | 43 | fcBias :: forall i o. Lens' (FCp i o) (R o) 44 | fcBias = lrAlpha 45 | 46 | -- | Fully connected feed-forward layer with bias. Parameterized by its 47 | -- initialization distribution. 48 | -- 49 | -- Note that this has no activation function; to use as a model with an 50 | -- activation function, post-compose it with one using '<~'; see 'fca' 51 | -- for a convenient combinator that does exactly this. 52 | -- 53 | -- Without any activation function, this is essentially a multivariate 54 | -- linear regression. 55 | -- 56 | -- With the logistic function as an activation function, this is 57 | -- essentially multivariate logistic regression. (See 'logReg') 58 | fc :: (KnownNat i, KnownNat o) 59 | => Model ('Just (FCp i o)) 'Nothing (R i) (R o) 60 | fc = linReg 61 | 62 | -- | Convenient synonym for an 'fc' post-composed with a simple 63 | -- parameterless activation function. 64 | fca :: (KnownNat i, KnownNat o) 65 | => (forall z. Reifies z W => BVar z (R o) -> BVar z (R o)) 66 | -> Model ('Just (FCp i o)) 'Nothing (R i) (R o) 67 | fca f = funcD f <~ linReg 68 | 69 | -- | Fully connected recurrent layer with bias. 70 | fcr :: (KnownNat i, KnownNat o, KnownNat s) 71 | => (forall z.
Reifies z W => BVar z (R o) -> BVar z (R s)) -- ^ store 72 | -> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o) 73 | fcr s = recurrent H.split (H.#) s fc 74 | 75 | -- | Convenient synonym for an 'fcr' post-composed with a simple 76 | -- parameterless activation function. 77 | fcra 78 | :: (KnownNat i, KnownNat o, KnownNat s) 79 | => (forall z. Reifies z W => BVar z (R o) -> BVar z (R o)) 80 | -> (forall z. Reifies z W => BVar z (R o) -> BVar z (R s)) -- ^ store 81 | -> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o) 82 | fcra f s = funcD f <~ recurrent H.split (H.#) s fc 83 | 84 | -- | Parameter for fully connected recurrent layer. 85 | type FCRp s i o = FCp (i + s) o 86 | 87 | lensIso :: (s -> (a, x)) -> ((b, x) -> t) -> Lens s t a b 88 | lensIso f g h x = g <$> _1 h (f x) 89 | 90 | fcrInputWeights 91 | :: (KnownNat s, KnownNat i, KnownNat i', KnownNat o) 92 | => Lens (FCRp s i o) (FCRp s i' o) (L o i) (L o i') 93 | fcrInputWeights = fcWeights 94 | . lensIso H.splitCols (uncurry (H.|||)) 95 | 96 | fcrStateWeights 97 | :: (KnownNat s, KnownNat s', KnownNat i, KnownNat o) 98 | => Lens (FCRp s i o) (FCRp s' i o) (L o s) (L o s') 99 | fcrStateWeights = fcWeights 100 | . lensIso (swap . H.splitCols) (uncurry (H.|||) . swap) 101 | 102 | fcrBias :: forall s i o. Lens' (FCRp s i o) (R o) 103 | fcrBias = fcBias @(i + s) @o 104 | -------------------------------------------------------------------------------- /stack-ghcjs.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # http://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # resolver: ghcjs-0.1.0_ghc-7.10.2 15 | # resolver: 16 | # name: custom-snapshot 17 | # location: "./custom-snapshot.yaml" 18 | resolver: lts-7.22 19 | 20 | # User packages to be built. 21 | # Various formats can be used as shown in the example below. 22 | # 23 | # packages: 24 | # - some-directory 25 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 26 | # - location: 27 | # git: https://github.com/commercialhaskell/stack.git 28 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 29 | # - location: https://github.com/commercialhaskell/stack/commit/e7b331f14bcffb8367cd58fbfc8b40ec7642100a 30 | # extra-dep: true 31 | # subdirs: 32 | # - auto-update 33 | # - wai 34 | # 35 | # A package marked 'extra-dep: true' will only be built if demanded by a 36 | # non-dependency (i.e. a user package), and its test suites and benchmarks 37 | # will not be run. This is useful for tweaking upstream packages. 38 | packages: 39 | - '.' 
40 | - location: 41 | git: https://github.com/mstksg/backprop.git 42 | commit: 5c5fa4301cdae6f0b9c9607c18df01de32d288b7 43 | extra-dep: true 44 | - location: 45 | git: https://github.com/mstksg/type-combinators-singletons.git 46 | commit: 8bd05539987692ee745f0d78b83e5d72c7af21c4 47 | extra-dep: true 48 | - location: 49 | git: https://github.com/mstksg/generics-lift.git 50 | commit: c8a127d522f47efbf4cce55c9276a1d19cd6a211 51 | extra-dep: true 52 | - location: 53 | git: https://github.com/ghcjs/ghcjs-base.git 54 | commit: dd7034ef8582ea8a175a71a988393a9d1ee86d6f 55 | extra-dep: true 56 | - location: 57 | git: https://github.com/reflex-frp/reflex.git 58 | commit: 2fe0f566f8d6b6eceb178a85516643390111bb83 59 | extra-dep: true 60 | - location: 61 | git: https://github.com/reflex-frp/reflex-dom.git 62 | commit: 706ab47df9729bdc5c4ac3f4d8dfd4661d9f6e1a 63 | subdirs: 64 | - reflex-dom 65 | - reflex-dom-core 66 | extra-dep: true 67 | 68 | # Dependency packages to be pulled from upstream that are not in the resolver 69 | # (e.g., acme-missiles-0.3) 70 | extra-deps: 71 | - hmatrix-0.18.0.0 72 | - type-combinators-0.2.4.3 73 | - mnist-idx-0.1.2.6 74 | - finite-typelits-0.1.2.0 75 | - ghcjs-dom-0.8.0.0 76 | - ghcjs-dom-jsffi-0.8.0.0 77 | - jsaddle-0.8.3.2 78 | - prim-uniq-0.1.0.1 79 | - ref-tf-0.4.0.1 80 | - zenc-0.1.1 81 | # - ghcjs-base-0.2.0.0 82 | # - ghcjs-dom-0.2.4.0 83 | # - ghcjs-dom-jsffi-0.8.0.0 84 | # - ref-tf-0.4.0.1 85 | # - reflex-0.4.0 86 | # - reflex-dom-0.4 87 | 88 | # Override default flag values for local packages and extra-deps 89 | flags: {} 90 | 91 | # Extra package databases containing global packages 92 | extra-package-dbs: [] 93 | 94 | # # compiler: ghcjs-0.2.1.9007019_ghc-8.0.1 95 | compiler: ghcjs-0.2.1_ghc-8.0.1 96 | compiler-check: match-exact 97 | 98 | setup-info: 99 | ghcjs: 100 | source: 101 | ghcjs-0.2.1_ghc-8.0.1: 102 | url: /home/justin/projects/haskell/ghcjs/ghcjs/.stack-work/dist/x86_64-linux/Cabal-1.24.0.0/ghcjs-0.2.1.tar.gz 103 | # sha1: f2398e082a83a3aca381bc4457f50397af7fedfb 104 | # # sha1: 74ae6d67442221bcc15b43fc0e72784418f0bc53 105 | # # url: http://ghcjs.tolysz.org/ghc-8.0-2017-02-05-lts-7.19-9007019.tar.gz 106 | # # sha1: d2cfc25f9cda32a25a87d9af68891b2186ee52f9 107 | 108 | # Control whether we use the GHC we find on the path 109 | system-ghc: false 110 | # 111 | # Require a specific version of stack, using version ranges 112 | # require-stack-version: -any # Default 113 | # require-stack-version: ">=1.3" 114 | # 115 | # Override the architecture used by stack, especially useful on Windows 116 | # arch: i386 117 | # arch: x86_64 118 | # 119 | # Extra directories used by stack for building 120 | # extra-include-dirs: [/path/to/dir] 121 | # extra-lib-dirs: [/path/to/dir] 122 | # 123 | # Allow a newer minor version of GHC than the snapshot specifies 124 | # compiler-check: newer-minor 125 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Run.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE RankNTypes #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE TupleSections #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 8 | 9 | module Backprop.Learn.Run ( 10 | consecutives 11 | , consecutivesN 12 | , leadings 13 | , conduitModel, conduitModelStoch 14 | -- * Encoding and decoding for learning 15 | , oneHot', oneHot, oneHotR 16 | , 
SVG.maxIndex, maxIndexR 17 | ) where 18 | 19 | import Backprop.Learn.Model 20 | import Control.Monad 21 | import Control.Monad.Primitive 22 | import Control.Monad.Trans.Class 23 | import Control.Monad.Trans.Maybe 24 | import Data.Bool 25 | import Data.Conduit 26 | import Data.Finite 27 | import Data.Foldable 28 | import Data.Proxy 29 | import Data.Type.Functor.Product 30 | import GHC.TypeNats 31 | import Numeric.LinearAlgebra.Static 32 | import Numeric.LinearAlgebra.Static.Vector 33 | import qualified Data.Conduit.Combinators as C 34 | import qualified Data.Sequence as Seq 35 | import qualified Data.Vector.Generic as VG 36 | import qualified Data.Vector.Generic.Sized as SVG 37 | import qualified System.Random.MWC as MWC 38 | 39 | consecutives :: Monad m => ConduitT i (i, i) m () 40 | consecutives = void . runMaybeT $ do 41 | x <- MaybeT await 42 | go x 43 | where 44 | go x = do 45 | y <- MaybeT await 46 | lift $ yield (x, y) 47 | go y 48 | 49 | consecutivesN 50 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m) 51 | => ConduitT i (SVG.Vector v n i, SVG.Vector v n i) m () 52 | consecutivesN = conseq (fromIntegral n) .| C.concatMap process 53 | where 54 | n = natVal (Proxy @n) 55 | process (xs, ys, _) = (,) <$> SVG.fromList (toList xs) 56 | <*> SVG.fromList (toList ys) 57 | 58 | leadings 59 | :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m) 60 | => ConduitT i (SVG.Vector v n i, i) m () 61 | leadings = conseq (fromIntegral n) .| C.concatMap process 62 | where 63 | n = natVal (Proxy @n) 64 | process (xs, _, y) = (, y) <$> SVG.fromList (toList xs) 65 | 66 | conseq 67 | :: forall i m. Monad m 68 | => Int 69 | -> ConduitT i (Seq.Seq i, Seq.Seq i, i) m () 70 | conseq n = void . runMaybeT $ do 71 | xs <- Seq.replicateM n $ MaybeT await 72 | go xs 73 | where 74 | go xs = do 75 | _ Seq.:<| xs' <- pure xs 76 | y <- MaybeT await 77 | let ys = xs' Seq.:|> y 78 | lift $ yield (xs, ys, y) 79 | go ys 80 | 81 | conduitModel 82 | :: (Backprop b, AllConstrainedProd Backprop s, Monad m) 83 | => Model p s a b 84 | -> TMaybe p 85 | -> TMaybe s 86 | -> ConduitT a b m (TMaybe s) 87 | conduitModel f p = go 88 | where 89 | go s = do 90 | mx <- await 91 | case mx of 92 | Nothing -> return s 93 | Just x -> do 94 | let (y, s') = runModel f p x s 95 | yield y 96 | go s' 97 | 98 | conduitModelStoch 99 | :: (Backprop b, AllConstrainedProd Backprop s, PrimMonad m) 100 | => Model p s a b 101 | -> MWC.Gen (PrimState m) 102 | -> TMaybe p 103 | -> TMaybe s 104 | -> ConduitT a b m (TMaybe s) 105 | conduitModelStoch f g p = go 106 | where 107 | go s = do 108 | mx <- await 109 | case mx of 110 | Nothing -> return s 111 | Just x -> do 112 | (y, s') <- lift $ runModelStoch f g p x s 113 | yield y 114 | go s' 115 | 116 | -- | What module should this be in? 117 | oneHot' 118 | :: (VG.Vector v a, KnownNat n) 119 | => a -- ^ not hot 120 | -> a -- ^ hot 121 | -> Finite n 122 | -> SVG.Vector v n a 123 | oneHot' nothot hot i = SVG.generate (bool nothot hot . (== i)) 124 | 125 | oneHot 126 | :: (VG.Vector v a, KnownNat n, Num a) 127 | => Finite n 128 | -> SVG.Vector v n a 129 | oneHot = oneHot' 0 1 130 | 131 | oneHotR :: KnownNat n => Finite n -> R n 132 | oneHotR = vecR . oneHot 133 | 134 | -- | Could be in /hmatrix/. 135 | maxIndexR :: KnownNat n => R (n + 1) -> Finite (n + 1) 136 | maxIndexR = SVG.maxIndex . 
rVec 137 | -------------------------------------------------------------------------------- /old/app/Letter2Vec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE LambdaCase #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeApplications #-} 7 | {-# LANGUAGE TypeOperators #-} 8 | {-# LANGUAGE ViewPatterns #-} 9 | {-# OPTIONS_GHC -fno-warn-orphans #-} 10 | 11 | 12 | import Control.Applicative 13 | import Control.DeepSeq 14 | import Control.Exception 15 | import Control.Monad.IO.Class 16 | import Control.Monad.Trans.State 17 | import Data.Bifunctor 18 | import Data.Char 19 | import Data.Default 20 | import Data.Finite 21 | import Data.Foldable 22 | import Data.List 23 | import Data.List.Split 24 | import Data.Maybe 25 | import Data.Ord 26 | import Data.Time.Clock 27 | import Data.Type.Product hiding (toList, head') 28 | import Learn.Neural 29 | import Numeric.BLAS.HMatrix 30 | import Text.Printf hiding (toChar, fromChar) 31 | import Type.Class.Known 32 | import qualified Data.Vector as V 33 | import qualified System.Random.MWC as MWC 34 | import qualified System.Random.MWC.Distributions as MWC 35 | 36 | 37 | type ASCII = Finite 128 38 | 39 | fromChar :: Char -> Maybe ASCII 40 | fromChar = packFinite . fromIntegral . ord 41 | 42 | toChar :: ASCII -> Char 43 | toChar = chr . fromIntegral 44 | 45 | charOneHot :: Tensor t => Char -> Maybe (t '[128]) 46 | charOneHot = fmap (oneHot . only) . fromChar 47 | 48 | oneHotChar :: BLAS t => t '[128] -> Char 49 | oneHotChar = toChar . iamax 50 | 51 | charRank :: Tensor t => t '[128] -> [Char] 52 | charRank = map fst . sortBy (flip (comparing snd)) . zip ['\0'..] . toList . textract 53 | 54 | main :: IO () 55 | main = MWC.withSystemRandom $ \g -> do 56 | holmes <- evaluate . force . mapMaybe (charOneHot @HM) 57 | =<< readFile "data/holmes.txt" 58 | putStrLn "Loaded data" 59 | let slices :: [(HM '[128], HM '[128])] 60 | slices = concat . getZipList $ do 61 | skips <- traverse (ZipList . flip drop holmes) [0..2] 62 | pure (case skips of 63 | [l1,l2,l3] -> 64 | [(l2,l1),(l2,l3)] 65 | _ -> [] 66 | ) 67 | slices' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList slices) g 68 | let (test,train) = splitAt (length slices `div` 50) slices' 69 | net0 :: Network 'FeedForward HM 70 | ( '[128] :~ FullyConnected ) 71 | '[ '[32 ] :~ LogitMap 72 | , '[32 ] :~ FullyConnected 73 | , '[3 ] :~ LogitMap 74 | , '[3 ] :~ FullyConnected 75 | , '[32 ] :~ LogitMap 76 | , '[32 ] :~ FullyConnected 77 | , '[128] :~ SoftMax '[128] 78 | ] 79 | '[128] <- initDefNet g 80 | flip evalStateT net0 . forM_ [1..] $ \e -> do 81 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList train) g 82 | liftIO $ printf "[Epoch %d]\n" (e :: Int) 83 | let chunkUp = chunksOf batch train' 84 | numChunks = length chunkUp 85 | 86 | forM_ ([1..] 
`zip` chunkUp) $ \(b, chnk) -> StateT $ \n0 -> do 87 | printf "(Epoch %d, Batch %d / %d)\n" e (b :: Int) numChunks 88 | 89 | t0 <- getCurrentTime 90 | n' <- evaluate $ optimizeListBatches_ (bimap only_ only_ <$> chnk) n0 91 | (batching (adamOptimizer def netOpPure crossEntropy)) 92 | 25 93 | t1 <- getCurrentTime 94 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0)) 95 | 96 | let encoder :: Network 'FeedForward HM ( '[128] :~ FullyConnected ) 97 | '[ '[32] :~ LogitMap, '[32] :~ FullyConnected, '[3] :~ LogitMap ] 98 | '[3] 99 | encoder = takeNet known n' 100 | 101 | forM_ [0..127] $ \(c :: ASCII) -> do 102 | let enc :: HM '[3] 103 | enc = runNetPure encoder (oneHot (c :< Ø)) 104 | [x,y,z] = toList $ textract enc 105 | printf "%s\t%.4f\t%.4f\t%.4f\n" (show (toChar c)) x y z 106 | 107 | let trainScore = testNetList maxTest (someNet n') chnk 108 | testScore = testNetList maxTest (someNet n') test 109 | printf "Training Score: %.2f%%\n" ((1 - trainScore) * 100) 110 | printf "Validation Score: %.2f%%\n" ((1 - testScore ) * 100) 111 | 112 | return ((), n') 113 | where 114 | batch :: Int 115 | batch = 10000 116 | 117 | 118 | -------------------------------------------------------------------------------- /app/word2vec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE PartialTypeSignatures #-} 4 | {-# LANGUAGE PolyKinds #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TupleSections #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} 10 | 11 | import Backprop.Learn 12 | import Control.DeepSeq 13 | import Control.Exception 14 | import Control.Monad 15 | import Control.Monad.IO.Class 16 | import Control.Monad.Trans.Class 17 | import Control.Monad.Trans.State 18 | import Data.Char 19 | import Data.Conduit 20 | import Data.Default 21 | import Data.List 22 | import Data.Proxy 23 | import Data.Time 24 | import Data.Type.Tuple 25 | import GHC.TypeNats 26 | import Numeric.LinearAlgebra.Static.Backprop 27 | import Numeric.Opto 28 | import System.Environment 29 | import Text.Printf 30 | import qualified Conduit as C 31 | import qualified Data.Conduit.Combinators as C 32 | import qualified Data.Set as S 33 | import qualified Data.Text as T 34 | import qualified Data.Vector.Sized as SV 35 | import qualified Data.Vector.Storable as VS 36 | import qualified Numeric.LinearAlgebra.Static as H 37 | import qualified System.Random.MWC as MWC 38 | 39 | encoder :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R n) (R e) 40 | encoder = fca logistic 41 | 42 | decoder :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R e) (R n) 43 | decoder = fca softMax 44 | 45 | word2vec :: forall n e. (KnownNat n, KnownNat e) => Model _ _ (R n) (R n) 46 | word2vec = decoder @n @e <~ encoder @n @e 47 | 48 | oneHotWord 49 | :: KnownNat n 50 | => S.Set String 51 | -> String 52 | -> Maybe (R n) 53 | oneHotWord ws = fmap (oneHotR . fromIntegral) . (`S.lookupIndex` ws) 54 | 55 | makeCBOW :: KnownNat w => SV.Vector w (R n) -> (R n, R n) 56 | makeCBOW v = (SV.index v mid, SV.sum (v SV.// [(mid, 0)])) 57 | where 58 | mid = maxBound `div` 2 59 | 60 | main :: IO () 61 | main = MWC.withSystemRandom @IO $ \g -> do 62 | sourceFile:logFile:testFile:_ <- getArgs 63 | wordSet <- S.fromList . 
tokenize <$> readFile sourceFile 64 | SomeNat (Proxy :: Proxy n) <- pure $ someNatVal (fromIntegral (S.size wordSet)) 65 | 66 | printf "%d unique words found.\n" (natVal (Proxy @n)) 67 | 68 | let model = word2vec @n @100 69 | enc = encoder @n @100 70 | p0 <- initParamNormal model 0.2 g 71 | 72 | 73 | let report n b = do 74 | liftIO $ printf "(Batch %d)\n" (b :: Int) 75 | t0 <- liftIO getCurrentTime 76 | C.drop (n - 1) 77 | mp <- mapM (liftIO . evaluate . force) =<< await 78 | t1 <- liftIO getCurrentTime 79 | case mp of 80 | Nothing -> liftIO $ putStrLn "Done!" 81 | Just p@(_ :# pEnc) -> do 82 | chnk <- lift . state $ (,[]) 83 | liftIO $ do 84 | printf "Trained on %d points in %s.\n" 85 | (length chnk) 86 | (show (t1 `diffUTCTime` t0)) 87 | let trainScore = testModelAll maxIxTest model (TJust p) chnk 88 | printf "Training error: %.3f%%\n" ((1 - trainScore) * 100) 89 | 90 | testWords <- tokenize <$> readFile testFile 91 | let tests = flip map testWords $ \w -> 92 | let v = maybe 0 (runModelStateless enc (TJust pEnc)) 93 | $ oneHotWord wordSet w 94 | in intercalate "," $ map (printf "%0.4f") (VS.toList (H.extract v)) 95 | 96 | writeFile logFile $ unlines tests 97 | 98 | report n (b + 1) 99 | 100 | 101 | C.runResourceT . flip evalStateT [] 102 | . runConduit 103 | $ forever ( C.sourceFile sourceFile 104 | .| C.decodeUtf8 105 | .| C.concatMap (tokenize . T.unpack) 106 | .| C.concatMap (oneHotWord wordSet) 107 | .| C.slidingWindow 5 108 | .| C.concatMap (SV.fromList @5) 109 | .| C.map makeCBOW 110 | ) 111 | .| skipSampling 0.02 g 112 | .| C.iterM (modify . (:)) 113 | .| optoConduit def p0 114 | (adam def (modelGradStoch crossEntropy noReg model g)) 115 | .| report 500 0 116 | .| C.sinkNull 117 | 118 | tokenize :: String -> [String] 119 | tokenize = words . map (whitePunc . toLower) 120 | where 121 | whitePunc c 122 | | isPunctuation c = ' ' 123 | | otherwise = c 124 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Initialize.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE DefaultSignatures #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE FlexibleInstances #-} 7 | {-# LANGUAGE GADTs #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE TypeApplications #-} 11 | {-# LANGUAGE TypeFamilies #-} 12 | {-# LANGUAGE TypeOperators #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 15 | 16 | module Backprop.Learn.Initialize ( 17 | Initialize(..) 18 | , gInitialize 19 | , initializeNormal 20 | , initializeSingle 21 | ) where 22 | 23 | import Control.Monad.Primitive 24 | import Data.Complex 25 | import Data.Type.Length 26 | import Data.Type.NonEmpty 27 | import Data.Type.Tuple 28 | import GHC.TypeNats 29 | import Generics.OneLiner 30 | import Numeric.LinearAlgebra.Static.Vector 31 | import Statistics.Distribution 32 | import Statistics.Distribution.Normal 33 | import Type.Class.Known 34 | import Type.Family.List 35 | import qualified Data.Vector.Generic as VG 36 | import qualified Data.Vector.Generic.Sized as SVG 37 | import qualified Numeric.LinearAlgebra.Static as H 38 | import qualified System.Random.MWC as MWC 39 | 40 | -- | Class for types that are basically a bunch of 'Double's, which can be 41 | -- initialized with a given identical and independent distribution. 
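--
-- For example, a minimal sketch (assuming only an MWC generator @g@ and
-- the instances defined below):
--
-- @
-- p <- 'initialize' (normalDistr 0 0.01) g
-- @
--
-- draws every component of @p@ i.i.d. from a normal distribution with
-- mean 0 and standard deviation 0.01; this is exactly what
-- @'initializeNormal' 0.01 g@ does.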
42 | class Initialize p where 43 | initialize 44 | :: (ContGen d, PrimMonad m) 45 | => d 46 | -> MWC.Gen (PrimState m) 47 | -> m p 48 | 49 | default initialize 50 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m) 51 | => d 52 | -> MWC.Gen (PrimState m) 53 | -> m p 54 | initialize = gInitialize 55 | 56 | -- | 'initialize' for any instance of 'Generic'. 57 | gInitialize 58 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m) 59 | => d 60 | -> MWC.Gen (PrimState m) 61 | -> m p 62 | gInitialize d g = createA' @Initialize (initialize d g) 63 | 64 | initializeNormal 65 | :: (Initialize p, PrimMonad m) 66 | => Double -- ^ standard deviation 67 | -> MWC.Gen (PrimState m) 68 | -> m p 69 | initializeNormal = initialize . normalDistr 0 70 | 71 | -- | 'initialize' definition if @p@ is a single number. 72 | initializeSingle 73 | :: (ContGen d, PrimMonad m, Fractional p) 74 | => d 75 | -> MWC.Gen (PrimState m) 76 | -> m p 77 | initializeSingle d = fmap realToFrac . genContVar d 78 | 79 | instance Initialize Double where 80 | initialize = initializeSingle 81 | instance Initialize Float where 82 | initialize = initializeSingle 83 | 84 | -- | Initializes real and imaginary components identically 85 | instance Initialize a => Initialize (Complex a) where 86 | 87 | instance Initialize T0 88 | instance (Initialize a, Initialize b) => Initialize (T2 a b) 89 | instance (Initialize a, Initialize b, Initialize c) => Initialize (T3 a b c) 90 | 91 | instance (ListC (Initialize <$> as), Known Length as) => Initialize (T as) where 92 | initialize d g = constTA @Initialize (initialize d g) known 93 | 94 | instance (Initialize a, ListC (Initialize <$> as), Known Length as) => Initialize (NETup (a ':| as)) where 95 | initialize d g = NET <$> initialize d g 96 | <*> initialize d g 97 | 98 | instance Initialize () 99 | instance (Initialize a, Initialize b) => Initialize (a, b) 100 | instance (Initialize a, Initialize b, Initialize c) => Initialize (a, b, c) 101 | 102 | instance (VG.Vector v a, KnownNat n, Initialize a) => Initialize (SVG.Vector v n a) where 103 | initialize d = SVG.replicateM . initialize d 104 | 105 | instance KnownNat n => Initialize (H.R n) where 106 | initialize d = fmap vecR . initialize d 107 | instance KnownNat n => Initialize (H.C n) where 108 | initialize d = fmap vecC . initialize d 109 | 110 | instance (KnownNat n, KnownNat m) => Initialize (H.L n m) where 111 | initialize d = fmap vecL . initialize d 112 | instance (KnownNat n, KnownNat m) => Initialize (H.M n m) where 113 | initialize d = fmap vecM . initialize d 114 | 115 | 116 | constTA 117 | :: forall c as f. (ListC (c <$> as), Applicative f) 118 | => (forall a. c a => f a) 119 | -> Length as 120 | -> f (T as) 121 | constTA x = go 122 | where 123 | go :: forall bs. 
ListC (c <$> bs) => Length bs -> f (T bs) 124 | go LZ = pure TNil 125 | go (LS l) = (:&) <$> x <*> go l 126 | 127 | -------------------------------------------------------------------------------- /old2/src/Data/Type/NonEmpty.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE KindSignatures #-} 5 | {-# LANGUAGE LambdaCase #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeFamilyDependencies #-} 12 | {-# LANGUAGE TypeInType #-} 13 | {-# LANGUAGE TypeOperators #-} 14 | {-# LANGUAGE UndecidableInstances #-} 15 | {-# LANGUAGE ViewPatterns #-} 16 | 17 | module Data.Type.NonEmpty ( 18 | NETup(.., NETT), ToNonEmpty 19 | , netHead, netTail 20 | , unNet 21 | , netT 22 | , NonEmpty(..) 23 | ) where 24 | 25 | import Control.DeepSeq 26 | import Data.Kind 27 | import Data.List.NonEmpty (NonEmpty(..)) 28 | import Data.Type.Length 29 | import Data.Type.Tuple 30 | import Lens.Micro 31 | import Numeric.Backprop (Backprop(..)) 32 | import Numeric.Opto.Ref 33 | import Numeric.Opto.Update 34 | import Type.Class.Known 35 | import Type.Family.List 36 | import qualified Data.Binary as Bi 37 | 38 | data NETup :: NonEmpty Type -> Type where 39 | NET :: !a -> !(T as) -> NETup (a ':| as) 40 | 41 | pattern NETT :: T (a ': as) -> NETup (a ':| as) 42 | pattern NETT { netT } <- ((\case NET x xs -> x :& xs) -> (!netT)) 43 | where 44 | NETT (!(x :& xs)) = NET x xs 45 | {-# COMPLETE NETT #-} 46 | 47 | instance (NFData a, ListC (NFData <$> as)) => NFData (NETup (a ':| as)) where 48 | rnf (NETT xs) = rnf xs 49 | 50 | instance (Num a, ListC (Num <$> as), Known Length as) => Num (NETup (a ':| as)) where 51 | NETT xs + NETT ys = NETT (xs + ys) 52 | NETT xs - NETT ys = NETT (xs - ys) 53 | NETT xs * NETT ys = NETT (xs * ys) 54 | negate (NETT xs) = NETT (negate xs) 55 | abs (NETT xs) = NETT (abs xs) 56 | signum (NETT xs) = NETT (signum xs) 57 | fromInteger = NETT . fromInteger 58 | 59 | instance (Fractional a, ListC (Num <$> as), ListC (Fractional <$> as), Known Length as) 60 | => Fractional (NETup (a ':| as)) where 61 | NETT xs / NETT ys = NETT (xs / ys) 62 | recip (NETT xs) = NETT (recip xs) 63 | fromRational = NETT . fromRational 64 | 65 | instance (Floating a, ListC (Num <$> as), ListC (Fractional <$> as), ListC (Floating <$> as), Known Length as) 66 | => Floating (NETup (a ':| as)) where 67 | pi = NETT pi 68 | sqrt (NETT xs) = NETT (sqrt xs) 69 | exp (NETT xs) = NETT (exp xs) 70 | log (NETT xs) = NETT (log xs) 71 | sin (NETT xs) = NETT (sin xs) 72 | cos (NETT xs) = NETT (cos xs) 73 | tan (NETT xs) = NETT (tan xs) 74 | asin (NETT xs) = NETT (asin xs) 75 | acos (NETT xs) = NETT (acos xs) 76 | atan (NETT xs) = NETT (atan xs) 77 | sinh (NETT xs) = NETT (sinh xs) 78 | cosh (NETT xs) = NETT (cosh xs) 79 | tanh (NETT xs) = NETT (tanh xs) 80 | asinh (NETT xs) = NETT (asinh xs) 81 | acosh (NETT xs) = NETT (acosh xs) 82 | atanh (NETT xs) = NETT (atanh xs) 83 | 84 | instance (Additive a, Additive (T as)) 85 | => Additive (NETup (a ':| as)) where 86 | NET x xs .+. NET y ys = NET (x .+. y) (xs .+.
ys) 87 | addZero = NET addZero addZero 88 | 89 | instance (Scaling c a, Scaling c (T as)) => Scaling c (NETup (a ':| as)) where 90 | c .* NET x xs = NET (c .* x) (c .* xs) 91 | scaleOne = scaleOne @c @a 92 | 93 | instance (Metric c a, Metric c (T as), Ord c, Floating c) => Metric c (NETup (a ':| as)) where 94 | NET x xs <.> NET y ys = (x <.> y) + (xs <.> ys) 95 | norm_inf (NET x xs) = max (norm_inf x) (norm_inf xs) 96 | norm_0 (NET x xs) = norm_0 x + norm_0 xs 97 | norm_1 (NET x xs) = norm_1 x + norm_1 xs 98 | quadrance (NET x xs) = quadrance x + quadrance xs 99 | 100 | instance (Additive a, Additive (T as), Ref m (NETup (a ':| as)) v) 101 | => AdditiveInPlace m v (NETup (a ':| as)) 102 | instance (Scaling s a, Scaling s (T as), Ref m (NETup (a ':| as)) v) 103 | => ScalingInPlace m v s (NETup (a ':| as)) 104 | 105 | instance (Backprop a, ListC (Backprop <$> as)) => Backprop (NETup (a ':| as)) where 106 | zero (NET x xs) = NET (zero x) (zero xs) 107 | add (NET x xs) (NET y ys) = NET (add x y) (add xs ys) 108 | one (NET x xs) = NET (one x) (one xs) 109 | 110 | instance (Bi.Binary a, ListC (Bi.Binary <$> as), Known Length as) => Bi.Binary (NETup (a ':| as)) where 111 | get = NET <$> Bi.get 112 | <*> Bi.get 113 | put (NET x xs) = Bi.put x *> Bi.put xs 114 | 115 | netHead :: Lens (NETup (a ':| as)) (NETup (b ':| as)) a b 116 | netHead f (NET x xs) = (`NET` xs) <$> f x 117 | 118 | netTail :: Lens (NETup (a ':| as)) (NETup (a ':| bs)) (T as) (T bs) 119 | netTail f (NET x xs) = NET x <$> f xs 120 | 121 | unNet :: NETup (a ':| as) -> (a, T as) 122 | unNet (NET x xs) = (x, xs) 123 | 124 | type family ToNonEmpty (l :: [k]) = (m :: Maybe (NonEmpty k)) | m -> l where 125 | ToNonEmpty '[] = 'Nothing 126 | ToNonEmpty (a ': as) = 'Just (a ':| as) 127 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Model/Parameter.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveDataTypeable #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE TypeFamilies #-} 12 | {-# LANGUAGE TypeInType #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | 15 | module Backprop.Learn.Model.Parameter ( 16 | DeParam(..) 17 | , dpDeterm 18 | , DeParamAt(..) 19 | , dpaDeterm 20 | , ReParam 21 | , rpDeterm 22 | ) where 23 | 24 | import Backprop.Learn.Model.Class 25 | import Control.Monad.Primitive 26 | import Data.Kind 27 | import Data.Typeable 28 | import Numeric.Backprop 29 | import qualified System.Random.MWC as MWC 30 | 31 | -- | Convert a model with trainable parameters into a model without any 32 | -- trainable parameters. 33 | -- 34 | -- The parameters are instead fixed (or stochastic, with a fixed 35 | -- distribution), and the model is treated as an untrainable function. 36 | -- 37 | -- 'DeParam' is essentially 'DeParamAt', with 'id' as the lens. 38 | data DeParam :: Type -> Type -> Type where 39 | DP :: { _dpParam :: p 40 | , _dpParamStoch :: forall m. PrimMonad m => MWC.Gen (PrimState m) -> m p 41 | , _dpLearn :: l 42 | } 43 | -> DeParam p l 44 | deriving (Typeable) 45 | 46 | -- | Create a 'DeParam' from a deterministic, non-stochastic parameter.
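--
-- A sketch of typical use, where @myModel@ and its fixed parameter @p0@
-- are hypothetical placeholders:
--
-- @
-- 'dpDeterm' p0 myModel
-- @
--
-- behaves like @myModel@ with its parameter frozen at @p0@: the parameter
-- is fed to the underlying model as a constant (via 'constVar'), so it is
-- never backpropagated or trained.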
47 | dpDeterm :: p -> l -> DeParam p l 48 | dpDeterm p = DP p (const (pure p)) 49 | 50 | instance (Learn a b l, LParamMaybe l ~ 'Just p) => Learn a b (DeParam p l) where 51 | type LParamMaybe (DeParam p l) = 'Nothing 52 | type LStateMaybe (DeParam p l) = LStateMaybe l 53 | 54 | runLearn DP{..} _ = runLearn _dpLearn (J_ (constVar _dpParam)) 55 | runLearnStoch DP{..} g _ x s = do 56 | p <- constVar <$> _dpParamStoch g 57 | runLearnStoch _dpLearn g (J_ p) x s 58 | 59 | -- | Wrapping a model with @'DeParamAt' pq p q@ says that the model's 60 | -- parameter @pq@ can be split into @p@ and @q@, and fixes @q@ to a given 61 | -- fixed (or stochastic with a fixed distribution) value. The model now 62 | -- effectively has parameter @p@ only, and the @q@ part will not be 63 | -- backpropagated. 64 | -- 65 | -- 'DeParam' is essentially 'DeParamAt' where @p@ is '()' or 'T0'. 66 | data DeParamAt :: Type -> Type -> Type -> Type -> Type where 67 | DPA :: { _dpaSplit :: pq -> (p, q) 68 | , _dpaJoin :: p -> q -> pq 69 | , _dpaParam :: q 70 | , _dpaParamStoch :: forall m. PrimMonad m => MWC.Gen (PrimState m) -> m q 71 | , _dpaLearn :: l 72 | } 73 | -> DeParamAt pq p q l 74 | deriving (Typeable) 75 | 76 | -- | Create a 'DeParamAt' from a deterministic, non-stochastic fixed value 77 | -- as a part of the parameter. 78 | dpaDeterm :: (pq -> (p, q)) -> (p -> q -> pq) -> q -> l -> DeParamAt pq p q l 79 | dpaDeterm s j q = DPA s j q (const (pure q)) 80 | 81 | instance (Learn a b l, LParamMaybe l ~ 'Just pq, Backprop pq, Backprop p, Backprop q) 82 | => Learn a b (DeParamAt pq p q l) where 83 | type LParamMaybe (DeParamAt pq p q l) = 'Just p 84 | type LStateMaybe (DeParamAt pq p q l) = LStateMaybe l 85 | 86 | runLearn DPA{..} (J_ p) = runLearn _dpaLearn (J_ p') 87 | where 88 | p' = isoVar2 _dpaJoin _dpaSplit p (constVar _dpaParam) 89 | 90 | runLearnStoch DPA{..} g (J_ p) x s = do 91 | q <- _dpaParamStoch g 92 | let p' = isoVar2 _dpaJoin _dpaSplit p (constVar q) 93 | runLearnStoch _dpaLearn g (J_ p') x s 94 | 95 | 96 | -- | Pre-apply a function to the parameter before the original model sees 97 | -- it. A @'ReParam' p q@ turns a model taking @p@ into a model taking @q@. 98 | -- 99 | -- Note that a @'ReParam' p ''Nothing'@ is essentially the same as 100 | -- a @'DeParam' p@, and one could implement @'DeParamAt' p q@ in terms of 101 | -- @'ReParam' p (''Just' q)@. 102 | data ReParam :: Type -> Maybe Type -> Type -> Type where 103 | RP :: { _rpFrom :: forall s. Reifies s W => Mayb (BVar s) q -> BVar s p 104 | , _rpFromStoch :: forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> Mayb (BVar s) q -> m (BVar s p) 105 | , _rpLearn :: l 106 | } 107 | -> ReParam p q l 108 | deriving (Typeable) 109 | 110 | instance (Learn a b l, LParamMaybe l ~ 'Just p) => Learn a b (ReParam p q l) where 111 | type LParamMaybe (ReParam p q l) = q 112 | type LStateMaybe (ReParam p q l) = LStateMaybe l 113 | 114 | runLearn RP{..} q = runLearn _rpLearn (J_ (_rpFrom q)) 115 | runLearnStoch RP{..} g q x s = do 116 | p <- _rpFromStoch g q 117 | runLearnStoch _rpLearn g (J_ p) x s 118 | 119 | -- | Create a 'ReParam' from a deterministic, non-stochastic 120 | -- transformation function. 121 | rpDeterm 122 | :: (forall s. Reifies s W => Mayb (BVar s) q -> BVar s p) 123 | -> l 124 | -> ReParam p q l 125 | rpDeterm f = RP f (const (pure .
f)) 126 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Model/Neural.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | {-# LANGUAGE PatternSynonyms #-} 6 | {-# LANGUAGE RankNTypes #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# LANGUAGE UndecidableInstances #-} 10 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 11 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} 12 | 13 | module Backprop.Learn.Model.Neural ( 14 | -- * Fully connected 15 | -- ** Feed-forward 16 | FC, pattern FC, FCp, fcBias, fcWeights 17 | -- *** With activation function 18 | , FCA, pattern FCA, _fcaActivation 19 | -- ** Recurrent 20 | , FCR, pattern FCR, FCRp, fcrBias, fcrInputWeights, fcrStateWeights 21 | -- *** With activation function 22 | , FCRA, pattern FCRA, _fcraStore, _fcraActivation 23 | ) where 24 | 25 | 26 | import Backprop.Learn.Model.Combinator 27 | import Backprop.Learn.Model.Regression 28 | import Backprop.Learn.Model.State 29 | import Data.Tuple 30 | import GHC.TypeNats 31 | import Lens.Micro 32 | import Numeric.Backprop 33 | import Numeric.LinearAlgebra.Static.Backprop 34 | import qualified Numeric.LinearAlgebra.Static as H 35 | 36 | -- | Fully connected feed-forward layer with bias. Parameterized by its 37 | -- initialization distribution. 38 | -- 39 | -- Note that this has no activation function; to use as a model with 40 | -- activation function, chain it with an activation function using 'RMap', 41 | -- ':.~', etc.; see 'FCA' for a convenient type synonym and constructor. 42 | -- 43 | -- Without any activation function, this is essentially a multivariate 44 | -- linear regression. 45 | -- 46 | -- With the logistic function as an activation function, this is 47 | -- essentially multivariate logistic regression. (See 'Logistic') 48 | type FC i o = LinReg i o 49 | 50 | pattern FC :: FC i o 51 | pattern FC = LinReg 52 | 53 | -- | Convenient synonym for an 'FC' post-composed with a simple 54 | -- parameterless activation function. 55 | type FCA i o = RMap (R o) (R o) (FC i o) 56 | 57 | -- | Construct an 'FCA' using a generating function and activation 58 | -- function. 59 | -- 60 | -- Some common ones include 'logistic' and @'vmap' 'reLU'@. 61 | pattern FCA 62 | :: (forall s. Reifies s W => BVar s (R o) -> BVar s (R o)) -- ^ '_fcaActivation' 63 | -> FCA i o 64 | pattern FCA { _fcaActivation } = RM _fcaActivation FC 65 | 66 | type FCp = LRp 67 | 68 | fcWeights :: Lens (FCp i o) (FCp i' o) (L o i) (L o i') 69 | fcWeights = lrBeta 70 | 71 | fcBias :: Lens' (FCp i o) (R o) 72 | fcBias = lrAlpha 73 | 74 | -- | Fully connected recurrent layer with bias. 75 | -- 76 | -- Parameterized by its initialization distributions, and also the function 77 | -- to compute the new state from previous input. 78 | -- 79 | -- @ 80 | -- instance 'Learn' ('R' i) (R o) ('FCR' h i o) where 81 | -- type 'LParamMaybe' (FCR h i o) = ''Just' ('FCRp' h i o) 82 | -- type 'LStateMaybe' (FCR h i o) = 'Just (R h) 83 | -- @ 84 | type FCR h i o = Recurrent (R (i + h)) (R i) (R h) (R o) (FC (i + h) o) 85 | 86 | -- | Construct an 'FCR' 87 | pattern FCR 88 | :: (KnownNat h, KnownNat i) 89 | => (forall s. 
Reifies s W => BVar s (R o) -> BVar s (R h)) -- ^ '_fcrStore' 90 | -> FCR h i o 91 | pattern FCR { _fcrStore } <- 92 | Rec { _recLoop = _fcrStore 93 | , _recLearn = FC 94 | } 95 | where 96 | FCR s = Rec { _recSplit = H.split 97 | , _recJoin = (H.#) 98 | , _recLoop = s 99 | , _recLearn = FC 100 | } 101 | {-# COMPLETE FCR #-} 102 | 103 | -- | Convenient synonym for an 'FCR' post-composed with a simple 104 | -- parameterless activation function. 105 | type FCRA h i o = RMap (R o) (R o) (FCR h i o) 106 | 107 | -- | Construct an 'FCRA' using a generating function and activation 108 | -- function. 109 | -- 110 | -- Some common ones include 'logistic' and @'vmap' 'reLU'@. 111 | pattern FCRA 112 | :: (KnownNat h, KnownNat i) 113 | => (forall s. Reifies s W => BVar s (R o) -> BVar s (R h)) -- ^ '_fcraStore' 114 | -> (forall s. Reifies s W => BVar s (R o) -> BVar s (R o)) -- ^ '_fcraActivation' 115 | -> FCRA h i o 116 | pattern FCRA { _fcraStore, _fcraActivation } 117 | = RM _fcraActivation (FCR _fcraStore) 118 | 119 | type FCRp h i o = FCp (i + h) o 120 | 121 | lensIso :: (s -> (a, x)) -> ((b, x) -> t) -> Lens s t a b 122 | lensIso f g h x = g <$> _1 h (f x) 123 | 124 | fcrInputWeights 125 | :: (KnownNat h, KnownNat i, KnownNat i', KnownNat o) 126 | => Lens (FCRp h i o) (FCRp h i' o) (L o i) (L o i') 127 | fcrInputWeights = fcWeights . lensIso H.splitCols (uncurry (H.|||)) 128 | 129 | fcrStateWeights 130 | :: (KnownNat h, KnownNat h', KnownNat i, KnownNat o) 131 | => Lens (FCRp h i o) (FCRp h' i o) (L o h) (L o h') 132 | fcrStateWeights = fcWeights . lensIso (swap . H.splitCols) (uncurry (H.|||) . swap) 133 | 134 | fcrBias :: Lens' (FCRp h i o) (R o) 135 | fcrBias = fcBias 136 | 137 | -------------------------------------------------------------------------------- /app/char-rnn.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE PartialTypeSignatures #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TupleSections #-} 8 | {-# LANGUAGE TypeApplications #-} 9 | {-# LANGUAGE TypeOperators #-} 10 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} 11 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 12 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} 13 | 14 | import Backprop.Learn 15 | import Control.DeepSeq 16 | import Control.Exception 17 | import Control.Monad 18 | import Control.Monad.IO.Class 19 | import Control.Monad.Trans.Class 20 | import Control.Monad.Trans.State 21 | import Data.Char 22 | import Data.Conduit 23 | import Data.Default 24 | import Data.Foldable 25 | import Data.Proxy 26 | import Data.Time 27 | import Data.Type.Equality 28 | import Data.Type.Tuple 29 | import GHC.TypeNats 30 | import Numeric.LinearAlgebra.Static.Backprop 31 | import Numeric.LinearAlgebra.Static.Vector 32 | import Numeric.Opto 33 | import System.Environment 34 | import Text.Printf 35 | import qualified Conduit as C 36 | import qualified Data.Conduit.Combinators as C 37 | import qualified Data.Set as S 38 | import qualified Data.Text as T 39 | import qualified Data.Vector.Sized as SV 40 | import qualified Data.Vector.Storable.Sized as SVS 41 | import qualified System.Random.MWC as MWC 42 | import qualified System.Random.MWC.Distributions as MWC 43 | 44 | -- | TODO: replace with 'LModel' 45 | charRNN 46 | :: forall n h1 h2.
(KnownNat n, KnownNat h1, KnownNat h2) 47 | => LModel _ _ (R n) (R n) 48 | charRNN = fca softMax 49 | #: dropout @h2 0.25 50 | #: lstm 51 | #: dropout @h1 0.25 52 | #: lstm 53 | #: nilLM 54 | 55 | oneHotChar 56 | :: KnownNat n 57 | => S.Set Char 58 | -> Char 59 | -> R n 60 | oneHotChar cs = oneHotR . fromIntegral . (`S.findIndex` cs) 61 | 62 | main :: IO () 63 | main = MWC.withSystemRandom @IO $ \g -> do 64 | sourceFile:_ <- getArgs 65 | charMap <- S.fromList <$> readFile sourceFile 66 | SomeNat (Proxy :: Proxy n) <- pure $ someNatVal (fromIntegral (length charMap)) 67 | SomeNat (Proxy :: Proxy n') <- pure $ someNatVal (fromIntegral (length charMap - 1)) 68 | Just Refl <- pure $ sameNat (Proxy @(n' + 1)) (Proxy @n) 69 | 70 | printf "%d characters found.\n" (natVal (Proxy @n)) 71 | 72 | let model0 = charRNN @n @100 @50 73 | model = trainState . unrollFinal @(SV.Vector 15) $ model0 74 | 75 | p0 <- initParamNormal model 0.2 g 76 | 77 | let report n b = do 78 | liftIO $ printf "(Batch %d)\n" (b :: Int) 79 | t0 <- liftIO getCurrentTime 80 | C.drop (n - 1) 81 | mp <- mapM (liftIO . evaluate . force) =<< await 82 | t1 <- liftIO getCurrentTime 83 | case mp of 84 | Nothing -> liftIO $ putStrLn "Done!" 85 | Just p@(p' :# s') -> do 86 | chnk <- lift . state $ (,[]) 87 | liftIO $ do 88 | printf "Trained on %d points in %s.\n" 89 | (length chnk) 90 | (show (t1 `diffUTCTime` t0)) 91 | let trainScore = testModelAll maxIxTest model (TJust p) chnk 92 | printf "Training error: %.3f%%\n" ((1 - trainScore) * 100) 93 | 94 | forM_ (take 15 chnk) $ \(x,y) -> do 95 | let primed = primeModel model0 (TJust p') x (TJust s') 96 | testOut <- fmap reverse . flip execStateT [] $ 97 | iterateModelM ( fmap (oneHotR . fromIntegral) 98 | . (>>= \r -> r <$ modify (r:)) -- trace 99 | . (`MWC.categorical` g) 100 | . SVS.fromSized 101 | . rVec 102 | ) 103 | 100 model0 (TJust p') y primed 104 | printf "%s|%s\n" 105 | (sanitize . (`S.elemAt` charMap) . fromIntegral . maxIndexR <$> (toList x ++ [y])) 106 | (sanitize . (`S.elemAt` charMap) <$> testOut) 107 | report n (b + 1) 108 | 109 | C.runResourceT . flip evalStateT [] 110 | . runConduit 111 | $ forever ( C.sourceFile sourceFile 112 | .| C.decodeUtf8 113 | .| C.concatMap T.unpack 114 | .| C.map (oneHotChar charMap) 115 | .| leadings 116 | ) 117 | .| skipSampling 0.02 g 118 | .| C.iterM (modify . 
(:)) 119 | .| optoConduit 120 | def 121 | p0 122 | (adam def (modelGradStoch crossEntropy noReg model g)) 123 | .| report 2500 0 124 | .| C.sinkNull 125 | 126 | 127 | sanitize :: Char -> Char 128 | sanitize c | isPrint c = c 129 | | otherwise = '#' 130 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Loss.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | 8 | module Backprop.Learn.Loss ( 9 | -- * Loss functions 10 | Loss 11 | , crossEntropy 12 | , squaredError, absError, totalSquaredError, squaredErrorV 13 | -- , totalCov 14 | -- ** Manipulate loss functions 15 | , scaleLoss 16 | , sumLoss 17 | , sumLossDecay 18 | , lastLoss 19 | , zipLoss 20 | , t2Loss 21 | , t3Loss 22 | -- * Regularization 23 | , Regularizer 24 | , l2Reg 25 | , l1Reg 26 | , noReg 27 | -- ** Manipulate regularizers 28 | , addReg 29 | , scaleReg 30 | ) where 31 | 32 | import Control.Applicative 33 | import Data.Finite 34 | import Data.Type.Tuple hiding (T2(..), T3(..)) 35 | import GHC.TypeNats 36 | import Numeric.Backprop 37 | import Numeric.LinearAlgebra.Static.Backprop 38 | import Numeric.Opto.Update hiding ((<.>)) 39 | import qualified Data.Type.Tuple as T 40 | import qualified Data.Vector.Sized as SV 41 | import qualified Prelude.Backprop as B 42 | 43 | type Loss a = forall s. Reifies s W => a -> BVar s a -> BVar s Double 44 | 45 | crossEntropy :: KnownNat n => Loss (R n) 46 | crossEntropy targ res = -(log res <.> constVar targ) 47 | 48 | squaredErrorV :: KnownNat n => Loss (R n) 49 | squaredErrorV targ res = e <.> e 50 | where 51 | e = res - constVar targ 52 | 53 | totalSquaredError 54 | :: (Backprop (t Double), Num (t Double), Foldable t, Functor t) 55 | => Loss (t Double) 56 | totalSquaredError targ res = B.sum (e * e) 57 | where 58 | e = constVar targ - res 59 | 60 | squaredError :: Loss Double 61 | squaredError targ res = (res - constVar targ) ** 2 62 | 63 | absError :: Loss Double 64 | absError targ res = abs (res - constVar targ) 65 | 66 | -- -- | Sum of covariances between matching components. Not sure if anyone 67 | -- -- uses this. 68 | -- totalCov :: (Backprop (t Double), Foldable t, Functor t) => Loss (t Double) 69 | -- totalCov targ res = -(xy / fromIntegral n - (x * y) / fromIntegral (n * n)) 70 | -- where 71 | -- x = constVar $ sum targ 72 | -- y = B.sum res 73 | -- xy = B.sum (constVar targ * res) 74 | -- n = length targ 75 | 76 | -- klDivergence :: Loss Double 77 | -- klDivergence = 78 | 79 | sumLoss 80 | :: (Traversable t, Applicative t, Backprop a) 81 | => Loss a 82 | -> Loss (t a) 83 | sumLoss l targ = sum . liftA2 l targ . sequenceVar 84 | 85 | zipLoss 86 | :: (Traversable t, Applicative t, Backprop a) 87 | => t Double 88 | -> Loss a 89 | -> Loss (t a) 90 | zipLoss xs l targ = sum 91 | . liftA3 (\α t -> (* constVar α) . l t) xs targ 92 | . sequenceVar 93 | 94 | sumLossDecay 95 | :: forall n a. (KnownNat n, Backprop a) 96 | => Double 97 | -> Loss a 98 | -> Loss (SV.Vector n a) 99 | sumLossDecay β = zipLoss $ SV.generate (\i -> β ** (fromIntegral i - n)) 100 | where 101 | n = fromIntegral $ maxBound @(Finite n) 102 | 103 | lastLoss 104 | :: (KnownNat (n + 1), Backprop a) 105 | => Loss a 106 | -> Loss (SV.Vector (n + 1) a) 107 | lastLoss l targ = l (SV.last targ) . 
viewVar (SV.ix maxBound) 108 | 109 | -- | Scale the result of a loss function. 110 | scaleLoss :: Double -> Loss a -> Loss a 111 | scaleLoss β l x = (* constVar β) . l x 112 | 113 | -- | Lift and sum a loss function over the components of a 'T.T2'. 114 | t2Loss 115 | :: (Backprop a, Backprop b) 116 | => Loss a -- ^ loss on first component 117 | -> Loss b -- ^ loss on second component 118 | -> Loss (T.T2 a b) 119 | t2Loss f g (T.T2 xT yT) (T2B xR yR) = f xT xR + g yT yR 120 | 121 | -- | Lift and sum a loss function over the components of a 'T.T3'. 122 | t3Loss 123 | :: (Backprop a, Backprop b, Backprop c) 124 | => Loss a -- ^ loss on first component 125 | -> Loss b -- ^ loss on second component 126 | -> Loss c -- ^ loss on third component 127 | -> Loss (T.T3 a b c) 128 | t3Loss f g h (T.T3 xT yT zT) xRyRzR 129 | = f xT (xRyRzR ^^. t3_1) 130 | + g yT (xRyRzR ^^. t3_2) 131 | + h zT (xRyRzR ^^. t3_3) 132 | 133 | -- | A regularizer on parameters 134 | type Regularizer p = forall s. Reifies s W => BVar s p -> BVar s Double 135 | 136 | -- | L2 regularization 137 | -- 138 | -- \[ 139 | -- \sum_w \frac{1}{2} w^2 140 | -- \] 141 | l2Reg 142 | :: (Metric Double p, Backprop p) 143 | => Double -- ^ scaling factor (often 0.5) 144 | -> Regularizer p 145 | l2Reg λ = liftOp1 . op1 $ \x -> 146 | ( λ * quadrance x / 2, (.* x) . (* λ)) 147 | 148 | -- | L1 regularization 149 | -- 150 | -- \[ 151 | -- \sum_w \lvert w \rvert 152 | -- \] 153 | l1Reg 154 | :: (Num p, Metric Double p, Backprop p) 155 | => Double -- ^ scaling factor (often 0.5) 156 | -> Regularizer p 157 | l1Reg λ = liftOp1 . op1 $ \x -> 158 | ( λ * norm_1 x, (.* signum x) . (* λ) 159 | ) 160 | 161 | -- | No regularization 162 | noReg :: Regularizer p 163 | noReg _ = constVar 0 164 | 165 | -- | Add together two regularizers 166 | addReg :: Regularizer p -> Regularizer p -> Regularizer p 167 | addReg = liftA2 (+) 168 | 169 | -- | Scale a regularizer's influence 170 | scaleReg :: Double -> Regularizer p -> Regularizer p 171 | scaleReg λ reg = (* constVar λ) . reg 172 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Use new container infrastructure to enable caching 2 | sudo: false 3 | 4 | # Choose a lightweight base image; we provide our own build tools. 5 | language: c 6 | 7 | # Caching so the next build will be fast too. 8 | cache: 9 | directories: 10 | - $HOME/.ghc 11 | - $HOME/.cabal 12 | - $HOME/.stack 13 | 14 | # The different configurations we want to test. We have BUILD=cabal which uses 15 | # cabal-install, and BUILD=stack which uses Stack. More documentation on each 16 | # of those below. 17 | # 18 | # We set the compiler values here to tell Travis to use a different 19 | # cache file per set of arguments. 20 | # 21 | # If you need to have different apt packages for each combination in the 22 | # matrix, you can use a line such as: 23 | # addons: {apt: {packages: [libfcgi-dev,libgmp-dev]}} 24 | matrix: 25 | include: 26 | # We grab the appropriate GHC and cabal-install versions from hvr's PPA. See: 27 | # https://github.com/hvr/multi-ghc-travis 28 | - env: BUILD=cabal GHCVER=8.0.2 CABALVER=1.24 HAPPYVER=1.19.5 ALEXVER=3.1.7 29 | compiler: ": #GHC 8.0.2" 30 | addons: {apt: {packages: [cabal-install-1.24,ghc-8.0.2,happy-1.19.5,alex-3.1.7,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 31 | 32 | # Build with the newest GHC and cabal-install. This is an accepted failure, 33 | # see below. 
34 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 35 | compiler: ": #GHC HEAD" 36 | addons: {apt: {packages: [cabal-install-head,ghc-head,happy-1.19.5,alex-3.1.7,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 37 | 38 | # The Stack builds. We can pass in arbitrary Stack arguments via the ARGS 39 | # variable, such as using --stack-yaml to point to a different file. 40 | - env: BUILD=stack ARGS="" 41 | compiler: ": #stack default" 42 | addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 43 | 44 | - env: BUILD=stack ARGS="--resolver lts-8" 45 | compiler: ": #stack 8.0.2" 46 | addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} 47 | 48 | # Nightly builds are allowed to fail 49 | - env: BUILD=stack ARGS="--resolver nightly" 50 | compiler: ": #stack nightly" 51 | addons: {apt: {packages: [libgmp,libgmp-dev,libblas-dev,liblapack-dev]}} 52 | 53 | # Build on OS X in addition to Linux 54 | - env: BUILD=stack ARGS="" 55 | compiler: ": #stack default osx" 56 | os: osx 57 | 58 | - env: BUILD=stack ARGS="--resolver lts-8" 59 | compiler: ": #stack 8.0.2 osx" 60 | os: osx 61 | 62 | - env: BUILD=stack ARGS="--resolver nightly" 63 | compiler: ": #stack nightly osx" 64 | os: osx 65 | 66 | allow_failures: 67 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 68 | - env: BUILD=stack ARGS="--resolver nightly" 69 | 70 | before_install: 71 | # Using compiler above sets CC to an invalid value, so unset it 72 | - unset CC 73 | 74 | # We want to always allow newer versions of packages when building on GHC HEAD 75 | - CABALARGS="" 76 | - if [ "x$GHCVER" = "xhead" ]; then CABALARGS=--allow-newer; fi 77 | 78 | # Download and unpack the stack executable 79 | - export PATH=/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$HOME/.local/bin:/opt/alex/$ALEXVER/bin:/opt/happy/$HAPPYVER/bin:$HOME/.cabal/bin:$PATH 80 | - mkdir -p ~/.local/bin 81 | - | 82 | if [ `uname` = "Darwin" ] 83 | then 84 | travis_retry curl --insecure -L https://www.stackage.org/stack/osx-x86_64 | tar xz --strip-components=1 --include '*/stack' -C ~/.local/bin 85 | else 86 | travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' 87 | fi 88 | 89 | # Use the more reliable S3 mirror of Hackage 90 | mkdir -p $HOME/.cabal 91 | echo 'remote-repo: hackage.haskell.org:http://hackage.fpcomplete.com/' > $HOME/.cabal/config 92 | echo 'remote-repo-cache: $HOME/.cabal/packages' >> $HOME/.cabal/config 93 | 94 | if [ "$CABALVER" != "1.16" ] 95 | then 96 | echo 'jobs: $ncpus' >> $HOME/.cabal/config 97 | fi 98 | 99 | # Get the list of packages from the stack.yaml file 100 | - PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') 101 | 102 | install: 103 | - echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" 104 | - if [ -f configure.ac ]; then autoreconf -i; fi 105 | - | 106 | set -ex 107 | case "$BUILD" in 108 | stack) 109 | stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies 110 | ;; 111 | cabal) 112 | cabal --version 113 | travis_retry cabal update 114 | cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 115 | ;; 116 | esac 117 | set +ex 118 | 119 | script: 120 | - | 121 | set -ex 122 | case "$BUILD" in 123 | stack) 124 | stack --no-terminal $ARGS test --bench 
--no-run-benchmarks --haddock --no-haddock-deps 125 | ;; 126 | cabal) 127 | cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 128 | 129 | ORIGDIR=$(pwd) 130 | for dir in $PACKAGES 131 | do 132 | cd $dir 133 | cabal check || [ "$CABALVER" == "1.16" ] 134 | cabal sdist 135 | SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && \ 136 | (cd dist && cabal install --force-reinstalls "$SRC_TGZ") 137 | cd $ORIGDIR 138 | # TODO: temporary kludge until we can somehow only have cabal check 139 | # non extra-dep packages 140 | break 141 | done 142 | ;; 143 | esac 144 | set +ex 145 | 146 | 147 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Layer/Compose.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE InstanceSigs #-} 6 | {-# LANGUAGE KindSignatures #-} 7 | {-# LANGUAGE LambdaCase #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeInType #-} 12 | {-# LANGUAGE TypeOperators #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | 15 | module Learn.Neural.Layer.Compose ( 16 | ) where 17 | 18 | import Data.Kind 19 | import Data.Singletons 20 | import GHC.Generics (Generic) 21 | import GHC.Generics.Numeric 22 | import Learn.Neural.Layer 23 | import Numeric.BLAS 24 | import Numeric.Backprop 25 | import Numeric.Backprop.Iso 26 | import qualified Generics.SOP as SOP 27 | 28 | data Comp :: Type -> Type -> k -> Type 29 | 30 | instance ( BLAS b 31 | , Component l b i h 32 | , Component j b h o 33 | ) 34 | => Component (Comp l j h) b i o where 35 | data CParam (Comp l j h) b i o = CP { _cp1 :: !(CParam l b i h) 36 | , _cp2 :: !(CParam j b h o) 37 | } 38 | deriving Generic 39 | data CState (Comp l j h) b i o = CS { _cs1 :: !(CState l b i h) 40 | , _cs2 :: !(CState j b h o) 41 | } 42 | deriving Generic 43 | type CConstr (Comp l j h) b i o = 44 | ( CConstr l b i h 45 | , CConstr j b h o 46 | , Num (b h) 47 | , SingI h 48 | ) 49 | data CConf (Comp l j h) b i o = CC { _cc1 :: !(CConf l b i h) 50 | , _cc2 :: !(CConf j b h o) 51 | } 52 | 53 | componentOp = bpOp . withInps $ \(x :< p :< s :< Ø) -> do 54 | p1 :< p2 :< Ø <- cpIso #<~ p 55 | s1 :< s2 :< Ø <- csIso #<~ s 56 | y :< s1' :< Ø <- componentOp ~$$ (x :< p1 :< s1 :< Ø) 57 | z :< s2' :< Ø <- componentOp ~$$ (y :< p2 :< s2 :< Ø) 58 | s' :< Ø <- isoVar (from csIso . 
tup1) (s1' :< s2' :< Ø) 59 | return $ z :< s' :< Ø 60 | 61 | initParam si so (CC c1 c2) g = 62 | CP <$> initParam si sh c1 g 63 | <*> initParam sh so c2 g 64 | where 65 | sh :: Sing h 66 | sh = sing 67 | 68 | initState si so (CC c1 c2) g = 69 | CS <$> initState si sh c1 g 70 | <*> initState sh so c2 g 71 | where 72 | sh :: Sing h 73 | sh = sing 74 | 75 | defConf = CC defConf defConf 76 | 77 | 78 | cpIso :: Iso' (CParam (Comp l j h) b i o) (Tuple '[CParam l b i h, CParam j b h o]) 79 | cpIso = iso (\case CP c1 c2 -> c1 ::< c2 ::< Ø) (\case I c1 :< I c2 :< Ø -> CP c1 c2) 80 | 81 | csIso :: Iso' (CState (Comp l j h) b i o) (Tuple '[CState l b i h, CState j b h o]) 82 | csIso = iso (\case CS s1 s2 -> s1 ::< s2 ::< Ø) (\case I s1 :< I s2 :< Ø -> CS s1 s2) 83 | 84 | instance (SOP.Generic (CState l b i h), SOP.Generic (CState j b h o)) 85 | => SOP.Generic (CState (Comp l j h) b i o) 86 | instance (SOP.Generic (CParam l b i h), SOP.Generic (CParam j b h o)) 87 | => SOP.Generic (CParam (Comp l j h) b i o) 88 | 89 | instance (Num (CParam l b i h), Num (CParam j b h o)) 90 | => Num (CParam (Comp l j h) b i o) where 91 | (+) = genericPlus 92 | (-) = genericMinus 93 | (*) = genericTimes 94 | negate = genericNegate 95 | abs = genericAbs 96 | signum = genericSignum 97 | fromInteger = genericFromInteger 98 | 99 | instance (Fractional (CParam l b i h), Fractional (CParam j b h o)) 100 | => Fractional (CParam (Comp l j h) b i o) where 101 | (/) = genericDivide 102 | recip = genericRecip 103 | fromRational = genericFromRational 104 | 105 | instance (Floating (CParam l b i h), Floating (CParam j b h o)) 106 | => Floating (CParam (Comp l j h) b i o) where 107 | pi = genericPi 108 | exp = genericExp 109 | (**) = genericPower 110 | log = genericLog 111 | logBase = genericLogBase 112 | sin = genericSin 113 | cos = genericCos 114 | tan = genericTan 115 | asin = genericAsin 116 | acos = genericAcos 117 | atan = genericAtan 118 | sinh = genericSinh 119 | cosh = genericCosh 120 | tanh = genericTanh 121 | asinh = genericAsinh 122 | acosh = genericAcosh 123 | atanh = genericAtanh 124 | 125 | instance (Num (CState l b i h), Num (CState j b h o)) 126 | => Num (CState (Comp l j h) b i o) where 127 | (+) = genericPlus 128 | (-) = genericMinus 129 | (*) = genericTimes 130 | negate = genericNegate 131 | abs = genericAbs 132 | signum = genericSignum 133 | fromInteger = genericFromInteger 134 | 135 | instance (Fractional (CState l b i h), Fractional (CState j b h o)) 136 | => Fractional (CState (Comp l j h) b i o) where 137 | (/) = genericDivide 138 | recip = genericRecip 139 | fromRational = genericFromRational 140 | 141 | instance (Floating (CState l b i h), Floating (CState j b h o)) 142 | => Floating (CState (Comp l j h) b i o) where 143 | pi = genericPi 144 | exp = genericExp 145 | (**) = genericPower 146 | log = genericLog 147 | logBase = genericLogBase 148 | sin = genericSin 149 | cos = genericCos 150 | tan = genericTan 151 | asin = genericAsin 152 | acos = genericAcos 153 | atan = genericAtan 154 | sinh = genericSinh 155 | cosh = genericCosh 156 | tanh = genericTanh 157 | asinh = genericAsinh 158 | acosh = genericAcosh 159 | atanh = genericAtanh 160 | -------------------------------------------------------------------------------- /old/src/Numeric/BLAS/FVector.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE InstanceSigs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# 
LANGUAGE LambdaCase #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeOperators #-} 12 | 13 | module Numeric.BLAS.FVector ( 14 | FV(..) 15 | ) where 16 | 17 | import Data.Finite 18 | import Data.Finite.Internal 19 | import Data.Kind 20 | import Data.List 21 | import Data.Maybe 22 | import Data.Singletons.Prelude 23 | import Data.Singletons.TypeLits 24 | import Data.Type.Combinator 25 | import Data.Type.Product 26 | import Data.Type.Vector 27 | import GHC.Generics (Generic) 28 | import GHC.TypeLits 29 | import Lens.Micro 30 | import Numeric.BLAS 31 | import Numeric.Tensor 32 | import qualified Data.Vector.Sized as SV 33 | import qualified Data.Vector.Unboxed as VU 34 | 35 | newtype FV :: [Nat] -> Type where 36 | FV :: { getFV :: VU.Vector Double 37 | } 38 | -> FV b 39 | deriving (Generic, Show) 40 | 41 | instance Tensor FV where 42 | type Scalar FV = Double 43 | 44 | genA s f = fmap (FV . VU.fromList) . traverse f $ range s 45 | gen s f = FV . VU.fromList . fmap f $ range s 46 | tkonst s = FV . VU.replicate (fromIntegral (product (fromSing s))) 47 | tsum = VU.sum . getFV 48 | tmap f = FV . VU.map f . getFV 49 | tzip f (FV xs) (FV ys) = FV (VU.zipWith f xs ys) 50 | tzipN 51 | :: forall s n. SingI s 52 | => (Vec n Double -> Double) 53 | -> VecT n FV s 54 | -> FV s 55 | tzipN f xs = FV $ VU.generate (fromIntegral len) $ \i -> 56 | (f (vmap (I . (VU.! i)) xs')) 57 | where 58 | len = product (fromSing (sing @_ @s)) 59 | xs' = vmap getFV xs 60 | 61 | 62 | tslice p0 = FV . go sing p0 . getFV 63 | where 64 | go :: Sing ns -> ProdMap Slice ns ms -> VU.Vector Double -> VU.Vector Double 65 | go = \case 66 | SNil -> \case 67 | PMZ -> id 68 | SNat `SCons` ss -> \case 69 | PMS (Slice sL sC _) pms -> 70 | let -- some wasted work here in re-computing the product, 71 | -- but premature optimization blah blah 72 | innerSize = fromIntegral (product (fromSing ss)) 73 | dropper = innerSize * fromIntegral (fromSing sL) 74 | taker = innerSize * fromIntegral (fromSing sC) 75 | in over (chunks innerSize) (go ss pms) 76 | . VU.slice dropper taker 77 | 78 | tindex i (FV xs) = xs VU.! fromIntegral (getFinite (reIndex i)) 79 | 80 | treshape _ (FV xs) = FV xs 81 | tload _ = FV . VU.convert . SV.fromSized 82 | textract 83 | :: forall s. SingI s 84 | => FV s 85 | -> SV.Vector (Product s) Double 86 | textract = withKnownNat (sProduct (sing @_ @s)) $ 87 | fromJust . SV.toSized . VU.convert . getFV 88 | 89 | instance BLAS FV where 90 | 91 | iamax 92 | :: forall n. KnownNat n 93 | => FV '[n + 1] 94 | -> Finite (n + 1) 95 | iamax = withKnownNat (SNat @n %:+ SNat @1) $ 96 | Finite . fromIntegral . VU.maxIndex . VU.map abs . getFV 97 | 98 | gemv 99 | :: forall m n. (KnownNat m, KnownNat n) 100 | => Double 101 | -> FV '[m, n] 102 | -> FV '[n] 103 | -> Maybe (Double, FV '[m]) 104 | -> FV '[m] 105 | gemv α (FV a) (FV x) b = 106 | FV 107 | . maybe id (\(β, FV ys) -> VU.zipWith (\y z -> β * y + z) ys) b 108 | . over (chunkDown innerSize) (\r -> α * VU.sum (VU.zipWith (*) r x)) 109 | $ a 110 | where 111 | innerSize = fromIntegral $ natVal (Proxy @n) 112 | 113 | ger α (FV xs) (FV ys) b = 114 | FV 115 | . maybe id (\(FV zs) -> VU.zipWith (+) zs) b 116 | . VU.concatMap (\x -> VU.map ((α * x) *) ys) 117 | $ xs 118 | 119 | gemm 120 | :: forall m o n. 
(KnownNat m, KnownNat o, KnownNat n) 121 | => Double 122 | -> FV '[m, o] 123 | -> FV '[o, n] 124 | -> Maybe (Double, FV '[m, n]) 125 | -> FV '[m, n] 126 | gemm α (FV as) bs cs = 127 | FV 128 | . maybe id (uncurry f) cs 129 | . over (chunks innerSize) muller 130 | $ as 131 | where 132 | innerSize = fromIntegral $ natVal (Proxy @o) 133 | muller r = 134 | over (chunkDown innerSize) (\c -> α * VU.sum (VU.zipWith (*) r c)) 135 | . getFV 136 | $ transp bs 137 | f β (FV cs') = VU.zipWith (\c b -> β * c + b) cs' 138 | 139 | range :: Sing ns -> [Prod Finite ns] 140 | range = \case 141 | SNil -> [Ø] 142 | SNat `SCons` ss -> (:<) <$> finites <*> range ss 143 | 144 | reIndex 145 | :: SingI ns 146 | => Prod Finite ns 147 | -> Finite (Product ns) 148 | reIndex = Finite . fst . unsafeReIndex sing 149 | 150 | unsafeReIndex 151 | :: Sing ns 152 | -> Prod Finite ns 153 | -> (Integer, Integer) 154 | unsafeReIndex = \case 155 | SNil -> \case 156 | Ø -> (0, 1) 157 | SNat `SCons` ss -> \case 158 | (i :: Finite n) :< is -> 159 | let (j, jSize) = unsafeReIndex ss is 160 | iSize = jSize * (fromSing (SNat @n)) 161 | in (j + jSize * getFinite i, iSize) 162 | 163 | chunks 164 | :: (VU.Unbox a, VU.Unbox b) 165 | => Int 166 | -> Traversal (VU.Vector a) (VU.Vector b) (VU.Vector a) (VU.Vector b) 167 | chunks n f = fmap VU.concat . traverse f . unfoldr u 168 | where 169 | u xs | VU.length xs >= n = Just (VU.splitAt n xs) 170 | | otherwise = Nothing 171 | 172 | chunkDown 173 | :: (VU.Unbox a, VU.Unbox b) 174 | => Int 175 | -> Traversal (VU.Vector a) (VU.Vector b) (VU.Vector a) b 176 | chunkDown n f = fmap VU.fromList . traverse f . unfoldr u 177 | where 178 | u xs | VU.length xs >= n = Just (VU.splitAt n xs) 179 | | otherwise = Nothing 180 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Test.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE PartialTypeSignatures #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeApplications #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} 9 | 10 | module Backprop.Learn.Test ( 11 | -- * Tests 12 | Test 13 | , maxIxTest, rmseTest 14 | , squaredErrorTest, absErrorTest, totalSquaredErrorTest, squaredErrorTestV 15 | , crossEntropyTest 16 | -- ** Manipulate tests 17 | , lossTest, lmapTest 18 | -- * Run tests 19 | , testLearn, testLearnStoch, testLearnAll, testLearnStochAll 20 | -- ** Correlation tests 21 | , testLearnCov, testLearnCorr 22 | , testLearnStochCov, testLearnStochCorr 23 | ) where 24 | 25 | import Backprop.Learn.Loss 26 | import Backprop.Learn.Model 27 | import Control.Monad.Primitive 28 | import Data.Bifunctor 29 | import Data.Bitraversable 30 | import Data.Function 31 | import Data.Profunctor 32 | import Data.Proxy 33 | import GHC.TypeNats 34 | import Numeric.Backprop 35 | import qualified Control.Foldl as L 36 | import qualified Numeric.LinearAlgebra as HU 37 | import qualified Numeric.LinearAlgebra.Static as H 38 | import qualified System.Random.MWC as MWC 39 | 40 | -- TODO: support non-double results? 
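-- An illustrative aside (not part of the original module): a 'Test'
-- compares a target output against a model output and returns a score,
-- so classifier accuracy over a test set is just the mean of
-- 'maxIxTest' scores.  The helper below is hypothetical, written only
-- to show the intended reading of the type defined next:
--
-- > accuracy :: KnownNat n => [(H.R n, H.R n)] -> Double
-- > accuracy samples =
-- >     sum [ maxIxTest target result | (target, result) <- samples ]
-- >       / fromIntegral (length samples)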
41 | 42 | type Test o = o -> o -> Double 43 | 44 | -- | Create a 'Test' from a 'Loss' 45 | lossTest :: Loss a -> Test a 46 | lossTest l x = evalBP (l x) 47 | 48 | maxIxTest :: KnownNat n => Test (H.R n) 49 | maxIxTest x y 50 | | match x y = 1 51 | | otherwise = 0 52 | where 53 | match = (==) `on` (HU.maxIndex . H.extract) 54 | 55 | rmseTest :: forall n. KnownNat n => Test (H.R n) 56 | rmseTest x y = H.norm_2 (x - y) / sqrt (fromIntegral (natVal (Proxy @n))) 57 | 58 | squaredErrorTest :: Real a => Test a 59 | squaredErrorTest x y = e * e 60 | where 61 | e = realToFrac (x - y) 62 | 63 | absErrorTest :: Real a => Test a 64 | absErrorTest x y = realToFrac . abs $ x - y 65 | 66 | totalSquaredErrorTest :: (Applicative t, Foldable t, Real a) => Test (t a) 67 | totalSquaredErrorTest x y = realToFrac (sum e) 68 | where 69 | e = do 70 | x' <- x 71 | y' <- y 72 | pure ((x' - y') ^ (2 :: Int)) 73 | 74 | squaredErrorTestV :: KnownNat n => Test (H.R n) 75 | squaredErrorTestV x y = e `H.dot` e 76 | where 77 | e = x - y 78 | 79 | crossEntropyTest :: KnownNat n => Test (H.R n) 80 | crossEntropyTest targ res = -(log res H.<.> targ) 81 | 82 | lmapTest 83 | :: (a -> b) 84 | -> Test b 85 | -> Test a 86 | lmapTest f t x y = t (f x) (f y) 87 | 88 | testLearn 89 | :: (Learn a b l, NoState l) 90 | => Test b 91 | -> l 92 | -> LParam_ I l 93 | -> a 94 | -> b 95 | -> Double 96 | testLearn t l mp x y = t y $ runLearnStateless_ l mp x 97 | 98 | testLearnStoch 99 | :: (Learn a b l, NoState l, PrimMonad m) 100 | => Test b 101 | -> l 102 | -> MWC.Gen (PrimState m) 103 | -> LParam_ I l 104 | -> a 105 | -> b 106 | -> m Double 107 | testLearnStoch t l g mp x y = t y <$> runLearnStochStateless_ l g mp x 108 | 109 | cov :: Fractional a => L.Fold (a, a) a 110 | cov = do 111 | x <- lmap fst L.sum 112 | y <- lmap snd L.sum 113 | xy <- lmap (uncurry (*)) L.sum 114 | n <- fromIntegral <$> L.length 115 | pure (xy / n - (x * y) / n / n) 116 | 117 | corr :: Floating a => L.Fold (a, a) a 118 | corr = do 119 | x <- lmap fst L.sum 120 | x2 <- lmap ((**2) . fst) L.sum 121 | y <- lmap snd L.sum 122 | y2 <- lmap ((**2) . snd) L.sum 123 | xy <- lmap (uncurry (*)) L.sum 124 | n <- fromIntegral <$> L.length 125 | pure $ (xy / n - (x * y) / n / n) 126 | / sqrt ( x2 / n - (x / n)**2 ) 127 | / sqrt ( y2 / n - (y / n)**2 ) 128 | 129 | testLearnCov 130 | :: (Learn a b l, NoState l, Foldable t, Fractional b) 131 | => l 132 | -> LParam_ I l 133 | -> t (a, b) 134 | -> b 135 | testLearnCov l p = L.fold $ (lmap . first) (runLearnStateless_ l p) cov 136 | 137 | testLearnCorr 138 | :: (Learn a b l, NoState l, Foldable t, Floating b) 139 | => l 140 | -> LParam_ I l 141 | -> t (a, b) 142 | -> b 143 | testLearnCorr l p = L.fold $ (lmap . 
first) (runLearnStateless_ l p) corr 144 | 145 | testLearnAll 146 | :: (Learn a b l, NoState l, Foldable t) 147 | => Test b 148 | -> l 149 | -> LParam_ I l 150 | -> t (a, b) 151 | -> Double 152 | testLearnAll t l p = L.fold $ lmap (uncurry (testLearn t l p)) L.mean 153 | 154 | -- newtype M m a = M { getM :: m a } 155 | -- instance (Semigroup a, Applicative m) => Semigroup (M m a) where 156 | -- M x <> M y = M $ liftA2 (<>) x y 157 | -- instance (Monoid a, Applicative m) => Monoid (M m a) where 158 | -- mappend = (<>) 159 | -- mempty = M (pure mempty) 160 | 161 | testLearnStochAll 162 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t) 163 | => Test b 164 | -> l 165 | -> MWC.Gen (PrimState m) 166 | -> LParam_ I l 167 | -> t (a, b) 168 | -> m Double 169 | testLearnStochAll t l g p = L.foldM $ L.premapM (uncurry (testLearnStoch t l g p)) 170 | (L.generalize L.mean) 171 | 172 | testLearnStochCov 173 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t, Fractional b) 174 | => l 175 | -> MWC.Gen (PrimState m) 176 | -> LParam_ I l 177 | -> t (a, b) 178 | -> m b 179 | testLearnStochCov l g p = L.foldM $ (L.premapM . flip bitraverse pure) 180 | (runLearnStochStateless_ l g p) 181 | (L.generalize cov) 182 | 183 | testLearnStochCorr 184 | :: (Learn a b l, NoState l, PrimMonad m, Foldable t, Floating b) 185 | => l 186 | -> MWC.Gen (PrimState m) 187 | -> LParam_ I l 188 | -> t (a, b) 189 | -> m b 190 | testLearnStochCorr l g p = L.foldM $ (L.premapM . flip bitraverse pure) 191 | (runLearnStochStateless_ l g p) 192 | (L.generalize corr) 193 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Model/State.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | {-# LANGUAGE PatternSynonyms #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeInType #-} 12 | {-# LANGUAGE TypeOperators #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | 15 | module Backprop.Learn.Model.State ( 16 | -- * To and from statelessness 17 | trainState, deState, deStateD, zeroState, dummyState 18 | -- * Manipulate model states 19 | , unroll, unrollFinal, recurrent 20 | ) where 21 | 22 | import Backprop.Learn.Model.Types 23 | import Control.Monad.Primitive 24 | import Control.Monad.Trans.State 25 | import Data.Bifunctor 26 | import Data.Foldable 27 | import Data.Type.Functor.Product 28 | import Data.Type.Tuple 29 | import Numeric.Backprop 30 | import qualified System.Random.MWC as MWC 31 | 32 | -- | Make a model stateless by converting the state to a trained parameter, 33 | -- and dropping the modified state from the result. 34 | -- 35 | -- One of the ways to make a model stateless for training purposes. Useful 36 | -- when used after 'unroll'. See 'deState', as well. 37 | -- 38 | -- Its parameters are: 39 | -- 40 | -- * If the input has no parameters, just the initial state. 41 | -- * If the input has a parameter, a ':#' of that parameter and initial state. 42 | trainState 43 | :: forall p s a b. 44 | ( PureProd Maybe p 45 | , PureProd Maybe s 46 | , AllConstrainedProd Backprop p 47 | , AllConstrainedProd Backprop s 48 | ) 49 | => Model p s a b 50 | -> Model (p :#? s) 'Nothing a b 51 | trainState = withModelFunc $ \f (p :#? s) x n_ -> 52 | (second .
const) n_ <$> f p x s 53 | 54 | -- | Make a model stateless by pre-applying a fixed state (or a stochastic 55 | -- one with fixed distribution) and dropping the modified state from the 56 | -- result. 57 | -- 58 | -- One of the ways to make a model stateless for training purposes. Useful 59 | -- when used after 'unroll'. See 'trainState', as well. 60 | deState 61 | :: s 62 | -> (forall m. PrimMonad m => MWC.Gen (PrimState m) -> m s) 63 | -> Model p ('Just s) a b 64 | -> Model p 'Nothing a b 65 | deState s sStoch f = Model 66 | { runLearn = \p x n_ -> 67 | (second . const) n_ $ runLearn f p x (PJust (auto s)) 68 | , runLearnStoch = \g p x n_ -> do 69 | s' <- sStoch g 70 | (second . const) n_ <$> runLearnStoch f g p x (PJust (auto s')) 71 | } 72 | 73 | -- | 'deState', except the state is always the same even in stochastic 74 | -- mode. 75 | deStateD 76 | :: s 77 | -> Model p ('Just s) a b 78 | -> Model p 'Nothing a b 79 | deStateD s = deState s (const (pure s)) 80 | 81 | -- | 'deState' with a constant state of 0. 82 | zeroState 83 | :: Num s 84 | => Model p ('Just s) a b 85 | -> Model p 'Nothing a b 86 | zeroState = deStateD 0 87 | 88 | -- | Unroll a (usually) stateful model into one taking a vector of 89 | -- sequential inputs. 90 | -- 91 | -- Basically applies the model to every item of input and returns all of 92 | -- the results, while propagating the state between every step. 93 | -- 94 | -- Useful when used before 'trainState' or 'deState'. See 95 | -- 'unrollTrainState' and 'unrollDeState'. 96 | -- 97 | -- Compare to 'feedbackTrace', which, instead of receiving a vector of 98 | -- sequential inputs, receives a single input and uses its output as the 99 | -- next input. 100 | unroll 101 | :: (Traversable t, Backprop a, Backprop b) 102 | => Model p s a b 103 | -> Model p s (t a) (t b) 104 | unroll = withModelFunc $ \f p xs s -> 105 | (fmap . first) collectVar 106 | . flip runStateT s 107 | . traverse (StateT . f p) 108 | . sequenceVar 109 | $ xs 110 | 111 | -- | Version of 'unroll' that only keeps the "final" result, dropping all 112 | -- of the intermediate results. 113 | -- 114 | -- Turns a stateful model into one that runs the model repeatedly on 115 | -- multiple inputs sequentially and outputs the final result after seeing 116 | -- all items. 117 | -- 118 | -- Note that this will be partial if given an empty sequence. 119 | unrollFinal 120 | :: (Traversable t, Backprop a) 121 | => Model p s a b 122 | -> Model p s (t a) b 123 | unrollFinal = withModelFunc $ \f p xs s0 -> 124 | foldlM (\(_, s) x -> f p x s) 125 | (undefined, s0) 126 | (sequenceVar xs) 127 | 128 | -- | Fix a part of a parameter of a model to be (a function of) the 129 | -- /previous/ output of the model itself. 130 | -- 131 | -- Essentially, turns a \( X \times Y \rightarrow Z \) model into a /stateful/ 132 | -- \( X \rightarrow Z \), where the Y is given by a function of the 133 | -- /previous output/ of the model. 134 | -- 135 | -- Essentially makes a model "recurrent": it receives its previous output 136 | -- as input. 137 | -- 138 | -- See 'fcr' for an application. 139 | recurrent 140 | :: forall p s ab a b c. 141 | -- ( KnownMayb s 142 | ( AllConstrainedProd Backprop s 143 | , PureProd Maybe s 144 | , Backprop a 145 | , Backprop b 146 | ) 147 | => (ab -> (a, b)) -- ^ split 148 | -> (a -> b -> ab) -- ^ join 149 | -> BFunc c b -- ^ store state 150 | -> Model p s ab c 151 | -> Model p (s :#? 'Just b) a c 152 | recurrent spl joi sto = withModelFunc $ \f p x (s :#?
y) -> do 153 | (z, s') <- f p (isoVar2 joi spl x (fromPJust y)) s 154 | pure (z, s' :#? PJust (sto z)) 155 | 156 | -- | Give a stateless model a "dummy" state. For now, useful for use 157 | -- with combinators like 'deState' that require state. However, 'deState' 158 | -- could also be made more lenient (to accept non-stateful models) in the 159 | -- future. 160 | -- 161 | -- Also useful with combinators like 'Control.Category..' from 162 | -- "Control.Category" that require all input models to share a common state. 163 | dummyState 164 | :: forall s p a b. () 165 | => Model p 'Nothing a b 166 | -> Model p s a b 167 | dummyState = withModelFunc $ \f p x s -> 168 | (second . const) s <$> f p x PNothing 169 | -------------------------------------------------------------------------------- /old/src/Numeric/BLAS/NVector.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 6 | {-# LANGUAGE InstanceSigs #-} 7 | {-# LANGUAGE KindSignatures #-} 8 | {-# LANGUAGE LambdaCase #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE StandaloneDeriving #-} 11 | {-# LANGUAGE TypeApplications #-} 12 | {-# LANGUAGE TypeFamilyDependencies #-} 13 | {-# LANGUAGE TypeInType #-} 14 | {-# LANGUAGE TypeOperators #-} 15 | {-# LANGUAGE UndecidableInstances #-} 16 | 17 | module Numeric.BLAS.NVector ( 18 | NV(..) 19 | , NV' 20 | ) where 21 | 22 | import Control.Applicative 23 | import Control.Monad 24 | import Control.Monad.Trans.State 25 | import Data.Finite.Internal 26 | import Data.Kind 27 | import Data.Maybe 28 | import Data.Monoid (Endo(..)) 29 | import Data.Singletons 30 | import Data.Singletons.Prelude 31 | import Data.Singletons.TypeLits 32 | import Data.Type.Product 33 | import GHC.Generics (Generic) 34 | import GHC.TypeLits 35 | import Numeric.BLAS 36 | import Numeric.Tensor 37 | import qualified Data.Vector as UV 38 | import qualified Data.Vector.Sized as V 39 | import qualified Data.Vector.Storable as UVS 40 | import qualified Data.Vector.Storable.Sized as VS 41 | 42 | 43 | type family NV' (s :: [Nat]) = (h :: Type) | h -> s where 44 | NV' '[] = Double 45 | NV' (n ': ns) = V.Vector n (NV' ns) 46 | 47 | newtype NV :: [Nat] -> Type where 48 | NV :: { getNV :: NV' b } 49 | -> NV b 50 | deriving (Generic) 51 | 52 | deriving instance (Show (NV' a)) => Show (NV a) 53 | 54 | genNV :: Sing ns -> (Prod Finite ns -> Double) -> NV' ns 55 | genNV = \case 56 | SNil -> \f -> f Ø 57 | SNat `SCons` ss -> \f -> V.generate_ $ \i -> 58 | genNV ss (f . (i :<)) 59 | 60 | genNVA 61 | :: Applicative f 62 | => Sing ns 63 | -> (Prod Finite ns -> f Double) 64 | -> f (NV' ns) 65 | genNVA = \case 66 | SNil -> \f -> f Ø 67 | SNat `SCons` ss -> \f -> sequenceA . V.generate_ $ \i -> 68 | genNVA ss (f . (i :<)) 69 | 70 | sumNV 71 | :: Sing ns 72 | -> NV' ns 73 | -> Double 74 | sumNV = \case 75 | SNil -> id 76 | _ `SCons` ss -> V.sum .
fmap (sumNV ss) 77 | 78 | mapNV 79 | :: Sing ns 80 | -> (Double -> Double) 81 | -> NV' ns 82 | -> NV' ns 83 | mapNV = \case 84 | SNil -> id 85 | _ `SCons` ss -> \f -> fmap (mapNV ss f) 86 | 87 | zipNV 88 | :: Sing ns 89 | -> (Double -> Double -> Double) 90 | -> NV' ns 91 | -> NV' ns 92 | -> NV' ns 93 | zipNV = \case 94 | SNil -> id 95 | _ `SCons` ss -> \f -> V.zipWith (zipNV ss f) 96 | 97 | indexNV 98 | :: Sing ns 99 | -> Prod Finite ns 100 | -> NV' ns 101 | -> Double 102 | indexNV = \case 103 | SNil -> \case 104 | Ø -> id 105 | SNat `SCons` ss -> \case 106 | i :< is -> indexNV ss is . flip V.index i 107 | 108 | loadNV 109 | :: Sing ns 110 | -> V.Vector (Product ns) Double 111 | -> NV' ns 112 | loadNV = \case 113 | SNil -> V.head 114 | sn@SNat `SCons` ss -> case sProduct ss of 115 | sp@SNat -> fromJust 116 | . V.fromList 117 | . evalState (replicateM (fromInteger (fromSing sn)) ( 118 | loadNV ss . fromJust . V.toSized 119 | <$> state (UV.splitAt (fromInteger (fromSing sp))) 120 | )) 121 | . V.fromSized 122 | 123 | nvElems 124 | :: Sing ns 125 | -> NV' ns 126 | -> [Double] 127 | nvElems s n = appEndo (go s n) [] 128 | where 129 | go :: Sing ms -> NV' ms -> Endo [Double] 130 | go = \case 131 | SNil -> \x -> Endo (x:) 132 | _ `SCons` ss -> foldMap (go ss) 133 | 134 | sliceNV 135 | :: ProdMap Slice ns ms 136 | -> NV' ns 137 | -> NV' ms 138 | sliceNV = \case 139 | PMZ -> id 140 | PMS (Slice sL sC@SNat _) pms -> 141 | let l = fromIntegral $ fromSing sL 142 | c = fromIntegral $ fromSing sC 143 | in fmap (sliceNV pms) 144 | . fromJust . V.toSized 145 | . UV.take c 146 | . UV.drop l 147 | . V.fromSized 148 | 149 | instance Tensor NV where 150 | type Scalar NV = Double 151 | 152 | gen s = NV . genNV s 153 | genA s = fmap NV . genNVA s 154 | tsum = sumNV sing . getNV 155 | tmap f = NV . mapNV sing f . getNV 156 | tzip f xs ys = NV $ zipNV sing f (getNV xs) (getNV ys) 157 | 158 | tindex i = indexNV sing i . getNV 159 | 160 | tload s = NV . loadNV s 161 | textract = withKnownNat (sProduct ss) $ 162 | fromJust . V.fromList . nvElems ss . getNV 163 | where 164 | ss = sing 165 | 166 | tslice p = NV . sliceNV p . getNV 167 | 168 | instance BLAS NV where 169 | transp = NV . sequenceA . getNV 170 | scal α = NV . fmap (α *) . getNV 171 | axpy α (NV xs) (NV ys) = NV $ liftA2 (\x y -> α * x + y) xs ys 172 | dot (NV xs) (NV ys) = V.sum $ V.zipWith (*) xs ys 173 | norm2 = V.sum . fmap (**2) . getNV 174 | asum = V.sum . fmap abs . getNV 175 | 176 | iamax 177 | :: forall n. KnownNat n 178 | => NV '[n + 1] 179 | -> Finite (n + 1) 180 | iamax = withKnownNat (SNat @n %:+ SNat @1) $ 181 | Finite . fromIntegral . UV.maxIndex . fmap abs . V.fromSized . getNV 182 | 183 | gemv α (NV a) (NV xs) b = maybe id (uncurry axpy) b 184 | . NV 185 | . fmap (V.sum . V.zipWith (\x -> (* (x * α))) xs) 186 | $ a 187 | 188 | ger α (NV xs) (NV ys) a = NV . addA $ fmap (\x -> fmap (* (x * α)) ys) xs 189 | where 190 | addA = case a of 191 | Nothing -> id 192 | Just (NV a') -> (V.zipWith . V.zipWith) (+) a' 193 | 194 | gemm α (NV ass) (NV bss) c = NV . addC $ 195 | fmap (sumVs . V.zipWith (\bs a -> fmap (* (α * a)) bs) bss) ass 196 | where 197 | sumVs = V.foldl' (V.zipWith (+)) (V.generate (\_ -> 0)) 198 | addC = case c of 199 | Nothing -> id 200 | Just (β, NV css) -> (V.zipWith . 
V.zipWith) (\c' -> (+ (β * c'))) css 201 | 202 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Test.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE PartialTypeSignatures #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TypeApplications #-} 8 | {-# LANGUAGE TypeFamilies #-} 9 | {-# OPTIONS_GHC -Wno-partial-type-signatures #-} 10 | 11 | module Backprop.Learn.Test ( 12 | -- * Tests 13 | Test 14 | , maxIxTest, rmseTest 15 | , squaredErrorTest, absErrorTest, totalSquaredErrorTest, squaredErrorTestV 16 | , crossEntropyTest, crossEntropyTest1 17 | , boolTest 18 | -- ** Manipulate tests 19 | , lossTest, lmapTest 20 | -- * Run tests 21 | , testModel, testModelStoch, testModelAll, testModelStochAll 22 | -- ** Correlation tests 23 | , testModelCov, testModelCorr 24 | , testModelStochCov, testModelStochCorr 25 | ) where 26 | 27 | import Backprop.Learn.Loss 28 | import Backprop.Learn.Model 29 | import Control.Monad.Primitive 30 | import Data.Bifunctor 31 | import Data.Bitraversable 32 | import Data.Function 33 | import Data.Profunctor 34 | import Data.Proxy 35 | import GHC.TypeNats 36 | import Numeric.Backprop 37 | import qualified Control.Foldl as L 38 | import qualified Numeric.LinearAlgebra as HU 39 | import qualified Numeric.LinearAlgebra.Static as H 40 | import qualified System.Random.MWC as MWC 41 | 42 | -- TODO: support non-double results? 43 | 44 | type Test o = o -> o -> Double 45 | 46 | -- | Create a 'Test' from a 'Loss' 47 | lossTest :: Loss a -> Test a 48 | lossTest l x = evalBP (l x) 49 | 50 | boolTest :: forall a. RealFrac a => Test a 51 | boolTest x y 52 | | ri x == ri y = 1 53 | | otherwise = 0 54 | where 55 | ri :: a -> Int 56 | ri = round 57 | 58 | maxIxTest :: KnownNat n => Test (H.R n) 59 | maxIxTest x y 60 | | match x y = 1 61 | | otherwise = 0 62 | where 63 | match = (==) `on` (HU.maxIndex . H.extract) 64 | 65 | rmseTest :: forall n. KnownNat n => Test (H.R n) 66 | rmseTest x y = H.norm_2 (x - y) / sqrt (fromIntegral (natVal (Proxy @n))) 67 | 68 | squaredErrorTest :: Real a => Test a 69 | squaredErrorTest x y = e * e 70 | where 71 | e = realToFrac (x - y) 72 | 73 | absErrorTest :: Real a => Test a 74 | absErrorTest x y = realToFrac . 
abs $ x - y 75 | 76 | totalSquaredErrorTest :: (Applicative t, Foldable t, Real a) => Test (t a) 77 | totalSquaredErrorTest x y = realToFrac (sum e) 78 | where 79 | e = do 80 | x' <- x 81 | y' <- y 82 | pure ((x' - y') ^ (2 :: Int)) 83 | 84 | squaredErrorTestV :: KnownNat n => Test (H.R n) 85 | squaredErrorTestV x y = e `H.dot` e 86 | where 87 | e = x - y 88 | 89 | crossEntropyTest :: KnownNat n => Test (H.R n) 90 | crossEntropyTest targ res = -(log res H.<.> targ) 91 | 92 | crossEntropyTest1 :: Test Double 93 | crossEntropyTest1 targ res = -(log res * targ + log (1 - res) * (1 - targ)) 94 | 95 | lmapTest 96 | :: (a -> b) 97 | -> Test b 98 | -> Test a 99 | lmapTest f t x y = t (f x) (f y) 100 | 101 | testModel 102 | :: Test b 103 | -> Model p 'Nothing a b 104 | -> TMaybe p 105 | -> a 106 | -> b 107 | -> Double 108 | testModel t f mp x y = t y $ runModelStateless f mp x 109 | 110 | testModelStoch 111 | :: PrimMonad m 112 | => Test b 113 | -> Model p 'Nothing a b 114 | -> MWC.Gen (PrimState m) 115 | -> TMaybe p 116 | -> a 117 | -> b 118 | -> m Double 119 | testModelStoch t f g mp x y = t y <$> runModelStochStateless f g mp x 120 | 121 | cov :: Fractional a => L.Fold (a, a) a 122 | cov = do 123 | x <- lmap fst L.sum 124 | y <- lmap snd L.sum 125 | xy <- lmap (uncurry (*)) L.sum 126 | n <- fromIntegral <$> L.length 127 | pure (xy / n - (x * y) / n / n) 128 | 129 | corr :: Floating a => L.Fold (a, a) a 130 | corr = do 131 | x <- lmap fst L.sum 132 | x2 <- lmap ((**2) . fst) L.sum 133 | y <- lmap snd L.sum 134 | y2 <- lmap ((**2) . snd) L.sum 135 | xy <- lmap (uncurry (*)) L.sum 136 | n <- fromIntegral <$> L.length 137 | pure $ (xy / n - (x * y) / n / n) 138 | / sqrt ( x2 / n - (x / n)**2 ) 139 | / sqrt ( y2 / n - (y / n)**2 ) 140 | 141 | testModelCov 142 | :: (Foldable t, Fractional b) 143 | => Model p 'Nothing a b 144 | -> TMaybe p 145 | -> t (a, b) 146 | -> b 147 | testModelCov f p = L.fold $ (lmap . first) (runModelStateless f p) cov 148 | 149 | testModelCorr 150 | :: (Foldable t, Floating b) 151 | => Model p 'Nothing a b 152 | -> TMaybe p 153 | -> t (a, b) 154 | -> b 155 | testModelCorr f p = L.fold $ (lmap . first) (runModelStateless f p) corr 156 | 157 | testModelAll 158 | :: Foldable t 159 | => Test b 160 | -> Model p 'Nothing a b 161 | -> TMaybe p 162 | -> t (a, b) 163 | -> Double 164 | testModelAll t f p = L.fold $ lmap (uncurry (testModel t f p)) L.mean 165 | 166 | -- newtype M m a = M { getM :: m a } 167 | -- instance (Semigroup a, Applicative m) => Semigroup (M m a) where 168 | -- M x <> M y = M $ liftA2 (<>) x y 169 | -- instance (Monoid a, Applicative m) => Monoid (M m a) where 170 | -- mappend = (<>) 171 | -- mempty = M (pure mempty) 172 | 173 | testModelStochAll 174 | :: (Foldable t, PrimMonad m) 175 | => Test b 176 | -> Model p 'Nothing a b 177 | -> MWC.Gen (PrimState m) 178 | -> TMaybe p 179 | -> t (a, b) 180 | -> m Double 181 | testModelStochAll t f g p = L.foldM $ L.premapM (uncurry (testModelStoch t f g p)) 182 | (L.generalize L.mean) 183 | 184 | testModelStochCov 185 | :: (Foldable t, PrimMonad m, Fractional b) 186 | => Model p 'Nothing a b 187 | -> MWC.Gen (PrimState m) 188 | -> TMaybe p 189 | -> t (a, b) 190 | -> m b 191 | testModelStochCov f g p = L.foldM $ (L.premapM . 
flip bitraverse pure) 192 | (runModelStochStateless f g p) 193 | (L.generalize cov) 194 | 195 | testModelStochCorr 196 | :: (Foldable t, PrimMonad m, Floating b) 197 | => Model p 'Nothing a b 198 | -> MWC.Gen (PrimState m) 199 | -> TMaybe p 200 | -> t (a, b) 201 | -> m b 202 | testModelStochCorr f g p = L.foldM $ (L.premapM . flip bitraverse pure) 203 | (runModelStochStateless f g p) 204 | (L.generalize corr) 205 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Model/Stochastic.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveDataTypeable #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE MultiParamTypeClasses #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeInType #-} 12 | {-# LANGUAGE UndecidableInstances #-} 13 | {-# LANGUAGE ViewPatterns #-} 14 | 15 | module Backprop.Learn.Model.Stochastic ( 16 | DO(..) 17 | , StochFunc(..) 18 | , FixedStochFunc, pattern FSF, _fsfRunDeterm, _fsfRunStoch 19 | , rreLU 20 | , injectNoise, applyNoise 21 | , injectNoiseR, applyNoiseR 22 | ) where 23 | 24 | import Backprop.Learn.Model.Class 25 | import Backprop.Learn.Model.Function 26 | import Control.Monad.Primitive 27 | import Data.Bool 28 | import Data.Kind 29 | import Data.Typeable 30 | import GHC.TypeNats 31 | import Numeric.Backprop 32 | import Numeric.LinearAlgebra.Static.Backprop 33 | import Numeric.LinearAlgebra.Static.Vector 34 | import qualified Data.Vector.Storable.Sized as SVS 35 | import qualified Statistics.Distribution as Stat 36 | import qualified System.Random.MWC as MWC 37 | import qualified System.Random.MWC.Distributions as MWC 38 | 39 | -- | Dropout layer. Parameterized by dropout percentage (should be between 40 | -- 0 and 1). 41 | -- 42 | -- 0 corresponds to no dropout, 1 corresponds to complete dropout of all 43 | -- nodes every time. 44 | newtype DO (n :: Nat) = DO { _doRate :: Double } 45 | deriving (Typeable) 46 | 47 | instance KnownNat n => Learn (R n) (R n) (DO n) where 48 | runLearn (DO r) _ = stateless (constVar (realToFrac (1-r)) *) 49 | runLearnStoch (DO r) g _ = statelessM $ \x -> 50 | (x *) . constVar . vecR <$> SVS.replicateM (mask g) 51 | where 52 | mask = fmap (bool 1 0) . MWC.bernoulli r 53 | 54 | -- | Represents a random-valued function, with a possible trainable 55 | -- parameter. 56 | -- 57 | -- Requires both a "deterministic" and a "stochastic" mode. The 58 | -- deterministic mode ideally should approximate some mean of the 59 | -- stochastic mode. 60 | data StochFunc :: Maybe Type -> Type -> Type -> Type where 61 | SF :: { _sfRunDeterm :: forall s. Reifies s W => Mayb (BVar s) p -> BVar s a -> BVar s b 62 | , _sfRunStoch 63 | :: forall m s. (PrimMonad m, Reifies s W) 64 | => MWC.Gen (PrimState m) 65 | -> Mayb (BVar s) p 66 | -> BVar s a 67 | -> m (BVar s b) 68 | } 69 | -> StochFunc p a b 70 | deriving (Typeable) 71 | 72 | instance Learn a b (StochFunc p a b) where 73 | type LParamMaybe (StochFunc p a b) = p 74 | type LStateMaybe (StochFunc p a b) = 'Nothing 75 | 76 | runLearn SF{..} = stateless . _sfRunDeterm 77 | runLearnStoch SF{..} g = statelessM . _sfRunStoch g 78 | 79 | -- | Convenient alias for a 'StochFunc' (random-valued function with both 80 | -- deterministic and stochastic modes) with no trained parameters. 
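-- A usage sketch (illustrative only, not part of the original module):
-- a 'FixedStochFunc' pairs a deterministic pass with a sampling pass,
-- where the deterministic pass should approximate the stochastic mean.
-- Here, uniform additive jitter on [-0.1, 0.1] has mean 0, so the
-- deterministic pass is the identity.  'noisyId' is a hypothetical name
-- used only for this example, built with the 'FSF' pattern defined below:
--
-- > noisyId :: FixedStochFunc Double Double
-- > noisyId = FSF { _fsfRunDeterm = id
-- >               , _fsfRunStoch  = \g x -> do
-- >                     e <- MWC.uniformR (-0.1, 0.1) g
-- >                     pure (x + realToFrac (e :: Double))
-- >               }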
81 | type FixedStochFunc = StochFunc 'Nothing 82 | 83 | -- | Construct a 'FixedStochFunc' 84 | pattern FSF :: (forall s. Reifies s W => BVar s a -> BVar s b) 85 | -> (forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> BVar s a -> m (BVar s b)) 86 | -> FixedStochFunc a b 87 | pattern FSF { _fsfRunDeterm, _fsfRunStoch } <- (getFSF->(getWD->_fsfRunDeterm,getWS->_fsfRunStoch)) 88 | where 89 | FSF d s = SF { _sfRunDeterm = const d 90 | , _sfRunStoch = const . s 91 | } 92 | {-# COMPLETE FSF #-} 93 | 94 | newtype WrapDeterm a b = WD { getWD :: forall s. Reifies s W => BVar s a -> BVar s b } 95 | newtype WrapStoch a b = WS { getWS :: forall m s. (PrimMonad m, Reifies s W) => MWC.Gen (PrimState m) -> BVar s a -> m (BVar s b) } 96 | 97 | getFSF :: FixedStochFunc a b -> (WrapDeterm a b, WrapStoch a b) 98 | getFSF SF{..} = ( WD (_sfRunDeterm N_) 99 | , WS (`_sfRunStoch` N_) 100 | ) 101 | 102 | -- | Random leaky rectified linear unit 103 | rreLU 104 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 105 | => d 106 | -> FixedStochFunc (R n) (R n) 107 | rreLU d = FSF { _fsfRunDeterm = vmap' (preLU v) 108 | , _fsfRunStoch = \g x -> do 109 | α <- vecR <$> SVS.replicateM (Stat.genContVar d g) 110 | pure (zipWithVector preLU (constVar α) x) 111 | } 112 | where 113 | v :: BVar s Double 114 | v = constVar (Stat.mean d) 115 | 116 | -- | Inject random noise. Usually used between neural network layers, or 117 | -- at the very beginning to pre-process input. 118 | -- 119 | -- In non-stochastic mode, this adds the mean of the distribution. 120 | injectNoise 121 | :: (Stat.ContGen d, Stat.Mean d, Fractional a) 122 | => d 123 | -> FixedStochFunc a a 124 | injectNoise d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) +) 125 | , _fsfRunStoch = \g x -> do 126 | e <- Stat.genContVar d g 127 | pure (realToFrac e + x) 128 | } 129 | 130 | 131 | -- | 'injectNoise' lifted to 'R' 132 | injectNoiseR 133 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 134 | => d 135 | -> FixedStochFunc (R n) (R n) 136 | injectNoiseR d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) +) 137 | , _fsfRunStoch = \g x -> do 138 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g) 139 | pure (constVar e + x) 140 | } 141 | 142 | -- | Multiply by random noise. Can be used to implement dropout-like 143 | -- behavior. 144 | -- 145 | -- In non-stochastic mode, this scales by the mean of the distribution.
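-- A hypothetical usage sketch (not part of the original module):
-- multiplicative noise drawn uniformly from [0, 2] has mean 1, so the
-- deterministic pass of the resulting function is the identity, while
-- the stochastic pass jitters each value.  'jitter' and the extra
-- import are assumptions made only for this example:
--
-- > import qualified Statistics.Distribution.Uniform as Uniform
-- >
-- > jitter :: Fractional a => FixedStochFunc a a
-- > jitter = applyNoise (Uniform.uniformDistr 0 2)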
146 | applyNoise 147 | :: (Stat.ContGen d, Stat.Mean d, Fractional a) 148 | => d 149 | -> FixedStochFunc a a 150 | applyNoise d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) *) 151 | , _fsfRunStoch = \g x -> do 152 | e <- Stat.genContVar d g 153 | pure (realToFrac e * x) 154 | } 155 | 156 | -- | 'applyNoise' lifted to 'R' 157 | applyNoiseR 158 | :: (Stat.ContGen d, Stat.Mean d, KnownNat n) 159 | => d 160 | -> FixedStochFunc (R n) (R n) 161 | applyNoiseR d = FSF { _fsfRunDeterm = (realToFrac (Stat.mean d) *) 162 | , _fsfRunStoch = \g x -> do 163 | e <- vecR <$> SVS.replicateM (Stat.genContVar d g) 164 | pure (constVar e * x) 165 | } 166 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Initialize.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE DefaultSignatures #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE FlexibleInstances #-} 7 | {-# LANGUAGE GADTs #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE TypeApplications #-} 11 | {-# LANGUAGE TypeFamilies #-} 12 | {-# LANGUAGE TypeOperators #-} 13 | {-# LANGUAGE UndecidableInstances #-} 14 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 15 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} 16 | 17 | module Backprop.Learn.Initialize ( 18 | Initialize(..) 19 | , gInitialize 20 | , initializeNormal 21 | , initializeSingle 22 | -- * Reshape 23 | , reshapeR 24 | , reshapeLRows 25 | , reshapeLCols 26 | ) where 27 | 28 | import Control.Monad.Primitive 29 | import Data.Complex 30 | import Data.Proxy 31 | import Data.Type.Equality 32 | import Data.Type.Tuple 33 | import Data.Vinyl 34 | import GHC.TypeLits.Compare 35 | import GHC.TypeNats 36 | import Generics.OneLiner 37 | import Numeric.LinearAlgebra.Static.Vector 38 | import Statistics.Distribution 39 | import Statistics.Distribution.Normal 40 | import qualified Data.Vector.Generic as VG 41 | import qualified Data.Vector.Generic.Sized as SVG 42 | import qualified Numeric.LinearAlgebra.Static as H 43 | import qualified System.Random.MWC as MWC 44 | 45 | -- | Class for types that are basically a bunch of 'Double's, which can be 46 | -- initialized with independent, identically distributed draws from a given 47 | -- distribution. 47 | class Initialize p where 48 | initialize 49 | :: (ContGen d, PrimMonad m) 50 | => d 51 | -> MWC.Gen (PrimState m) 52 | -> m p 53 | 54 | default initialize 55 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m) 56 | => d 57 | -> MWC.Gen (PrimState m) 58 | -> m p 59 | initialize = gInitialize 60 | 61 | -- | 'initialize' for any instance of 'Generic'. 62 | gInitialize 63 | :: (ADTRecord p, Constraints p Initialize, ContGen d, PrimMonad m) 64 | => d 65 | -> MWC.Gen (PrimState m) 66 | -> m p 67 | gInitialize d g = createA' @Initialize (initialize d g) 68 | 69 | -- | Helper over 'initialize' for a Gaussian distribution centered around 70 | -- zero. 71 | initializeNormal 72 | :: (Initialize p, PrimMonad m) 73 | => Double -- ^ standard deviation 74 | -> MWC.Gen (PrimState m) 75 | -> m p 76 | initializeNormal = initialize . normalDistr 0 77 | 78 | -- | 'initialize' definition if @p@ is a single number. 79 | initializeSingle 80 | :: (ContGen d, PrimMonad m, Fractional p) 81 | => d 82 | -> MWC.Gen (PrimState m) 83 | -> m p 84 | initializeSingle d = fmap realToFrac .
genContVar d 85 | 86 | instance Initialize Double where 87 | initialize = initializeSingle 88 | instance Initialize Float where 89 | initialize = initializeSingle 90 | 91 | -- | Initializes real and imaginary components identically 92 | instance Initialize a => Initialize (Complex a) where 93 | 94 | instance Initialize T0 95 | instance Initialize a => Initialize (TF a) 96 | instance (Initialize a, Initialize b) => Initialize (a :# b) 97 | 98 | instance RPureConstrained Initialize as => Initialize (T as) where 99 | initialize d g = rtraverse (fmap TF) 100 | $ rpureConstrained @Initialize (initialize d g) 101 | 102 | -- instance (Initialize a, ListC (Initialize <$> as), Known Length as) => Initialize (NETup (a ':| as)) where 103 | -- initialize d g = NET <$> initialize d g 104 | -- <*> initialize d g 105 | 106 | instance Initialize () 107 | instance (Initialize a, Initialize b) => Initialize (a, b) 108 | instance (Initialize a, Initialize b, Initialize c) => Initialize (a, b, c) 109 | instance (Initialize a, Initialize b, Initialize c, Initialize d) => Initialize (a, b, c, d) 110 | instance (Initialize a, Initialize b, Initialize c, Initialize d, Initialize e) => Initialize (a, b, c, d, e) 111 | 112 | instance (VG.Vector v a, KnownNat n, Initialize a) => Initialize (SVG.Vector v n a) where 113 | initialize d = SVG.replicateM . initialize d 114 | 115 | instance KnownNat n => Initialize (H.R n) where 116 | initialize d = fmap vecR . initialize d 117 | instance KnownNat n => Initialize (H.C n) where 118 | initialize d = fmap vecC . initialize d 119 | 120 | instance (KnownNat n, KnownNat m) => Initialize (H.L n m) where 121 | initialize d = fmap vecL . initialize d 122 | instance (KnownNat n, KnownNat m) => Initialize (H.M n m) where 123 | initialize d = fmap vecM . initialize d 124 | 125 | -- | Reshape a vector to have a different number of items. If the vector is 126 | -- grown, new weights are initialized according to the given distribution. 127 | reshapeR 128 | :: forall i j d m. (ContGen d, PrimMonad m, KnownNat i, KnownNat j) 129 | => d 130 | -> MWC.Gen (PrimState m) 131 | -> H.R i 132 | -> m (H.R j) 133 | reshapeR d g x = case Proxy @j %<=? Proxy @i of 134 | LE Refl -> pure . vecR . SVG.take @_ @j @(i - j) . rVec $ x 135 | NLE Refl Refl -> (x H.#) <$> initialize @(H.R (j - i)) d g 136 | 137 | -- | Reshape a matrix to have a different number of rows. If the matrix 138 | -- is grown, new weights are initialized according to the given 139 | -- distribution. 140 | reshapeLRows 141 | :: forall i j n d m. (ContGen d, PrimMonad m, KnownNat n, KnownNat i, KnownNat j) 142 | => d 143 | -> MWC.Gen (PrimState m) 144 | -> H.L i n 145 | -> m (H.L j n) 146 | reshapeLRows d g x = case Proxy @j %<=? Proxy @i of 147 | LE Refl -> pure . rowsL . SVG.take @_ @j @(i - j) . lRows $ x 148 | NLE Refl Refl -> (x H.===) <$> initialize @(H.L (j - i) n) d g 149 | 150 | -- | Reshape a matrix to have a different number of columns. If the matrix 151 | -- is grown, new weights are initialized according to the given 152 | -- distribution. 153 | reshapeLCols 154 | :: forall i j n d m. (ContGen d, PrimMonad m, KnownNat n, KnownNat i, KnownNat j) 155 | => d 156 | -> MWC.Gen (PrimState m) 157 | -> H.L n i 158 | -> m (H.L n j) 159 | reshapeLCols d g x = case Proxy @j %<=? Proxy @i of 160 | LE Refl -> pure . colsL . SVG.take @_ @j @(i - j) .
lCols $ x 161 | NLE Refl Refl -> (x H.|||) <$> initialize @(H.L n (j - i)) d g 162 | -------------------------------------------------------------------------------- /old2/src/Data/Type/Mayb.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE KindSignatures #-} 7 | {-# LANGUAGE LambdaCase #-} 8 | {-# LANGUAGE MultiParamTypeClasses #-} 9 | {-# LANGUAGE PatternSynonyms #-} 10 | {-# LANGUAGE RankNTypes #-} 11 | {-# LANGUAGE ScopedTypeVariables #-} 12 | {-# LANGUAGE StandaloneDeriving #-} 13 | {-# LANGUAGE TypeApplications #-} 14 | {-# LANGUAGE TypeFamilies #-} 15 | {-# LANGUAGE TypeFamilyDependencies #-} 16 | {-# LANGUAGE TypeInType #-} 17 | {-# LANGUAGE TypeOperators #-} 18 | {-# LANGUAGE UndecidableInstances #-} 19 | 20 | module Data.Type.Mayb ( 21 | MaybeC, MaybeToList, ListToMaybe 22 | , Mayb(.., J_I), fromJ_, maybToList, listToMayb 23 | , P(..), KnownMayb, knownMayb 24 | , zipMayb 25 | , zipMayb3 26 | , FromJust 27 | , MaybeWit(..), type (<$>) 28 | , TupMaybe, splitTupMaybe, tupMaybe 29 | , BoolMayb, boolMayb 30 | ) where 31 | 32 | import Data.Bifunctor 33 | import Data.Kind 34 | import Data.Type.Boolean 35 | import Data.Type.Combinator 36 | import Data.Type.Product 37 | import Data.Type.Tuple 38 | import Type.Class.Higher 39 | import Type.Class.Known 40 | import Type.Class.Witness 41 | import Type.Family.Maybe (type (<$>)) 42 | import qualified GHC.TypeLits as TL 43 | 44 | type family MaybeC (c :: k -> Constraint) (m :: Maybe k) :: Constraint where 45 | MaybeC c ('Just a) = c a 46 | MaybeC c 'Nothing = () 47 | 48 | type family MaybeToList (m :: Maybe k) = (l :: [k]) | l -> m where 49 | MaybeToList 'Nothing = '[] 50 | MaybeToList ('Just a) = '[a] 51 | 52 | -- type family ConsMaybe (ml :: (Maybe k, [k])) = (l :: (Bool, [k])) | l -> ml where 53 | -- ConsMaybe '( 'Nothing, as) = '( 'False, as ) 54 | -- ConsMaybe '( 'Just a , as) = '( 'True , a ': as ) 55 | 56 | maybToList 57 | :: Mayb f m 58 | -> Prod f (MaybeToList m) 59 | maybToList N_ = Ø 60 | maybToList (J_ x) = x :< Ø 61 | 62 | type family ListToMaybe (l :: [k]) :: Maybe k where 63 | ListToMaybe '[] = 'Nothing 64 | ListToMaybe (a ': as) = 'Just a 65 | 66 | listToMayb 67 | :: Prod f as 68 | -> Mayb f (ListToMaybe as) 69 | listToMayb Ø = N_ 70 | listToMayb (x :< _) = J_ x 71 | 72 | class MaybeWit (c :: k -> Constraint) (m :: Maybe k) where 73 | maybeWit :: Mayb (Wit1 c) m 74 | 75 | instance (MaybeC c m, Known (Mayb P) m) => MaybeWit c m where 76 | maybeWit = case known @_ @(Mayb P) @m of 77 | J_ _ -> J_ Wit1 78 | N_ -> N_ 79 | 80 | type KnownMayb = Known (Mayb P) 81 | 82 | knownMayb :: KnownMayb p => Mayb P p 83 | knownMayb = known 84 | 85 | data Mayb :: (k -> Type) -> Maybe k -> Type where 86 | N_ :: Mayb f 'Nothing 87 | J_ :: !(f a) -> Mayb f ('Just a) 88 | 89 | deriving instance MaybeC Show (f <$> m) => Show (Mayb f m) 90 | 91 | data P :: k -> Type where 92 | P :: P a 93 | 94 | fromJ_ :: Mayb f ('Just a) -> f a 95 | fromJ_ (J_ x) = x 96 | 97 | pattern J_I :: a -> Mayb I ('Just a) 98 | pattern J_I x = J_ (I x) 99 | 100 | instance Known P k where 101 | known = P 102 | 103 | instance Known (Mayb f) 'Nothing where 104 | known = N_ 105 | 106 | instance Known f a => Known (Mayb f) ('Just a) where 107 | type KnownC (Mayb f) ('Just a) = Known f a 108 | known = J_ known 109 | 110 | instance Functor1 Mayb where 111 | 
map1 f (J_ x) = J_ (f x) 112 | map1 _ N_ = N_ 113 | 114 | zipMayb 115 | :: (forall a. f a -> g a -> h a) 116 | -> Mayb f m 117 | -> Mayb g m 118 | -> Mayb h m 119 | zipMayb f (J_ x) (J_ y) = J_ (f x y) 120 | zipMayb _ N_ N_ = N_ 121 | 122 | zipMayb3 123 | :: (forall a. f a -> g a -> h a -> i a) 124 | -> Mayb f m 125 | -> Mayb g m 126 | -> Mayb h m 127 | -> Mayb i m 128 | zipMayb3 f (J_ x) (J_ y) (J_ z) = J_ (f x y z) 129 | zipMayb3 _ N_ N_ N_ = N_ 130 | 131 | m2 :: forall c f m. MaybeC c (f <$> m) 132 | => (forall a. c (f a) => f a -> f a -> f a) 133 | -> Mayb f m 134 | -> Mayb f m 135 | -> Mayb f m 136 | m2 f (J_ x) (J_ y) = J_ (f x y) 137 | m2 _ N_ N_ = N_ 138 | 139 | m1 :: forall c f m. MaybeC c (f <$> m) 140 | => (forall a. c (f a) => f a -> f a) 141 | -> Mayb f m 142 | -> Mayb f m 143 | m1 f (J_ x) = J_ (f x) 144 | m1 _ N_ = N_ 145 | 146 | m0 :: forall c f m. (MaybeC c (f <$> m)) 147 | => (forall a. c (f a) => f a) 148 | -> Mayb P m 149 | -> Mayb f m 150 | m0 x (J_ _) = J_ x 151 | m0 _ N_ = N_ 152 | 153 | instance (Known (Mayb P) m, MaybeC Num (f <$> m)) => Num (Mayb f m) where 154 | (+) = m2 @Num (+) 155 | (-) = m2 @Num (-) 156 | (*) = m2 @Num (*) 157 | negate = m1 @Num negate 158 | abs = m1 @Num abs 159 | signum = m1 @Num signum 160 | fromInteger x = m0 @Num (fromInteger x) known 161 | 162 | instance (Known (Mayb P) m, MaybeC Num (f <$> m), MaybeC Fractional (f <$> m)) 163 | => Fractional (Mayb f m) where 164 | (/) = m2 @Fractional (/) 165 | recip = m1 @Fractional recip 166 | fromRational x = m0 @Fractional (fromRational x) known 167 | 168 | type family FromJust (d :: TL.ErrorMessage) (m :: Maybe k) :: k where 169 | FromJust e ('Just a) = a 170 | FromJust e 'Nothing = TL.TypeError e 171 | 172 | type family TupMaybe (a :: Maybe Type) (b :: Maybe Type) :: Maybe Type where 173 | TupMaybe 'Nothing 'Nothing = 'Nothing 174 | TupMaybe 'Nothing ('Just b) = 'Just b 175 | TupMaybe ('Just a) 'Nothing = 'Just a 176 | TupMaybe ('Just a) ('Just b) = 'Just (T2 a b) 177 | 178 | tupMaybe 179 | :: forall f a b. () 180 | => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f a' -> f b' -> f (T2 a' b')) 181 | -> Mayb f a 182 | -> Mayb f b 183 | -> Mayb f (TupMaybe a b) 184 | tupMaybe f = \case 185 | N_ -> \case 186 | N_ -> N_ 187 | J_ y -> J_ y 188 | J_ x -> \case 189 | N_ -> J_ x 190 | J_ y -> J_ (f x y) 191 | 192 | splitTupMaybe 193 | :: forall f a b. (KnownMayb a, KnownMayb b) 194 | => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f (T2 a' b') -> (f a', f b')) 195 | -> Mayb f (TupMaybe a b) 196 | -> (Mayb f a, Mayb f b) 197 | splitTupMaybe f = case knownMayb @a of 198 | N_ -> case knownMayb @b of 199 | N_ -> \case 200 | N_ -> (N_, N_) 201 | J_ _ -> \case 202 | J_ y -> (N_, J_ y) 203 | J_ _ -> case knownMayb @b of 204 | N_ -> \case 205 | J_ x -> (J_ x, N_) 206 | J_ _ -> \case 207 | J_ xy -> bimap J_ J_ . 
f $ xy 208 | 209 | type family BoolMayb (b :: Bool) = (m :: Maybe ()) | m -> b where 210 | BoolMayb 'False = 'Nothing 211 | BoolMayb 'True = 'Just '() 212 | 213 | boolMayb :: Boolean b -> Mayb P (BoolMayb b) 214 | boolMayb False_ = N_ 215 | boolMayb True_ = J_ P 216 | -------------------------------------------------------------------------------- /old2/src/Backprop/Learn/Model/Class.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE FlexibleInstances #-} 6 | {-# LANGUAGE FunctionalDependencies #-} 7 | {-# LANGUAGE GADTs #-} 8 | {-# LANGUAGE KindSignatures #-} 9 | {-# LANGUAGE MultiParamTypeClasses #-} 10 | {-# LANGUAGE PatternSynonyms #-} 11 | {-# LANGUAGE ScopedTypeVariables #-} 12 | {-# LANGUAGE TupleSections #-} 13 | {-# LANGUAGE TypeFamilies #-} 14 | {-# LANGUAGE TypeFamilyDependencies #-} 15 | {-# LANGUAGE TypeInType #-} 16 | {-# LANGUAGE TypeOperators #-} 17 | {-# LANGUAGE UndecidableInstances #-} 18 | {-# LANGUAGE UndecidableSuperClasses #-} 19 | 20 | module Backprop.Learn.Model.Class ( 21 | Learn(..) 22 | , LParam, LState, LParams, LStates, NoParam, NoState 23 | , LParam_, LState_ 24 | , stateless, statelessM 25 | , runLearnStateless 26 | , runLearnStochStateless 27 | , Mayb(..), fromJ_, MaybeC, KnownMayb, knownMayb, I(..) 28 | , SomeLearn(..) 29 | ) where 30 | 31 | import Backprop.Learn.Initialize 32 | import Control.DeepSeq 33 | import Control.Monad.Primitive 34 | import Data.Kind 35 | import Data.Type.Mayb 36 | import Data.Typeable 37 | import Numeric.Backprop 38 | import Numeric.Opto.Update 39 | import Type.Family.List (type (++)) 40 | import qualified GHC.TypeLits as TL 41 | import qualified System.Random.MWC as MWC 42 | 43 | -- | The trainable parameter type of a model. Will be a compile-time error 44 | -- if the model has no trainable parameters. 45 | type LParam l = FromJust 46 | ('TL.ShowType l 'TL.:<>: 'TL.Text " has no trainable parameters") 47 | (LParamMaybe l) 48 | 49 | -- | The state type of a model. Will be a compile-time error if the model 50 | -- has no state. 51 | type LState l = FromJust 52 | ('TL.ShowType l 'TL.:<>: 'TL.Text " has no state") 53 | (LStateMaybe l) 54 | 55 | -- | Constraint specifying that a given model has no trainable parameters. 56 | type NoParam l = LParamMaybe l ~ 'Nothing 57 | 58 | -- | Constraint specifying that a given model has no state. 59 | type NoState l = LStateMaybe l ~ 'Nothing 60 | 61 | -- | Is 'N_' if @l@ has no trainable parameters; otherwise is 'J_' 62 | -- with @f p@, for trainable parameter type @p@. 63 | type LParam_ f l = Mayb f (LParamMaybe l) 64 | 65 | -- | Is 'N_' if @l@ has no state; otherwise is 'J_' with @f 66 | -- s@, for state type @s@. 67 | type LState_ f l = Mayb f (LStateMaybe l) 68 | 69 | -- | List of parameters of 'Learn' instances 70 | type family LParams (ls :: [Type]) :: [Type] where 71 | LParams '[] = '[] 72 | LParams (l ': ls) = MaybeToList (LParamMaybe l) ++ LParams ls 73 | 74 | -- | List of states of 'Learn' instances 75 | type family LStates (ls :: [Type]) :: [Type] where 76 | LStates '[] = '[] 77 | LStates (l ': ls) = MaybeToList (LStateMaybe l) ++ LStates ls 78 | 79 | -- | Class for models that can be trained using gradient descent 80 | -- 81 | -- An instance @l@ of @'Learn' a b@ is parameterized by @p@, takes @a@ as 82 | -- input, and returns @b@ as output. 
@l@ can be thought of as a value 83 | -- containing the /hyperparameters/ of the model. 84 | class Learn a b l | l -> a b where 85 | 86 | -- | The trainable parameters of model @l@. 87 | -- 88 | -- By default, is ''Nothing'. To give a type for learned parameters @p@, 89 | -- use the type @''Just' p@. 90 | type LParamMaybe l :: Maybe Type 91 | 92 | -- | The type of the state of model @l@. Used for things like 93 | -- recurrent neural networks. 94 | -- 95 | -- By default, is ''Nothing'. To give a type for state @s@, use the 96 | -- type @''Just' s@. 97 | -- 98 | -- Most models will not use state; training algorithms will only work 99 | -- if 'LStateMaybe' is ''Nothing'. However, models that use state can 100 | -- be converted to stateless models using 'Unroll'; this can be done 101 | -- before training. 102 | type LStateMaybe l :: Maybe Type 103 | 104 | type LParamMaybe l = 'Nothing 105 | type LStateMaybe l = 'Nothing 106 | 107 | -- | Run the model itself, deterministically. 108 | -- 109 | -- If your model has no state, you can define this conveniently using 110 | -- 'stateless'. 111 | runLearn 112 | :: Reifies s W 113 | => l 114 | -> LParam_ (BVar s) l 115 | -> BVar s a 116 | -> LState_ (BVar s) l 117 | -> (BVar s b, LState_ (BVar s) l) 118 | 119 | -- | Run a model in stochastic mode. 120 | -- 121 | -- If the model is inherently non-stochastic, a default implementation is 122 | -- given in terms of 'runLearn'. 123 | -- 124 | -- If your model has no state, you can define this conveniently using 125 | -- 'statelessM'. 126 | runLearnStoch 127 | :: (Reifies s W, PrimMonad m) 128 | => l 129 | -> MWC.Gen (PrimState m) 130 | -> LParam_ (BVar s) l 131 | -> BVar s a 132 | -> LState_ (BVar s) l 133 | -> m (BVar s b, LState_ (BVar s) l) 134 | runLearnStoch l _ p x s = pure (runLearn l p x s) 135 | 136 | -- | Useful for defining 'runLearn' if your model has no state. 137 | stateless 138 | :: (a -> b) 139 | -> (a -> s -> (b, s)) 140 | stateless f x = (f x,) 141 | 142 | -- | Useful for defining 'runLearnStoch' if your model has no state. 143 | statelessM 144 | :: Functor m 145 | => (a -> m b) 146 | -> (a -> s -> m (b, s)) 147 | statelessM f x s = (, s) <$> f x 148 | 149 | runLearnStateless 150 | :: (Learn a b l, Reifies s W, NoState l) 151 | => l 152 | -> LParam_ (BVar s) l 153 | -> BVar s a 154 | -> BVar s b 155 | runLearnStateless l p = fst . flip (runLearn l p) N_ 156 | 157 | runLearnStochStateless 158 | :: (Learn a b l, Reifies s W, NoState l, PrimMonad m) 159 | => l 160 | -> MWC.Gen (PrimState m) 161 | -> LParam_ (BVar s) l 162 | -> BVar s a 163 | -> m (BVar s b) 164 | runLearnStochStateless l g p = fmap fst . flip (runLearnStoch l g p) N_ 165 | 166 | -- | Existential wrapper for a learnable model, representing a trainable 167 | -- function from @a@ to @b@. 
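--
-- For example, a concrete model can be hidden behind the existential
-- wrapper (a hypothetical sketch; @myModel@ stands in for any model value
-- whose type satisfies the constraints listed below, and @R@ is a statically
-- sized vector type):
--
-- > someModel :: SomeLearn (R 10) (R 5)
-- > someModel = SL myModel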
168 | data SomeLearn :: Type -> Type -> Type where 169 | SL :: ( Learn a b l 170 | , Typeable l 171 | , KnownMayb (LParamMaybe l) 172 | , KnownMayb (LStateMaybe l) 173 | , MaybeC Floating (LParamMaybe l) 174 | , MaybeC Floating (LStateMaybe l) 175 | , MaybeC (Metric Double) (LParamMaybe l) 176 | , MaybeC (Metric Double) (LStateMaybe l) 177 | , MaybeC NFData (LParamMaybe l) 178 | , MaybeC NFData (LStateMaybe l) 179 | , MaybeC Initialize (LParamMaybe l) 180 | , MaybeC Initialize (LStateMaybe l) 181 | ) 182 | => l 183 | -> SomeLearn a b 184 | 185 | -------------------------------------------------------------------------------- /old/app/Language.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE ViewPatterns #-} 8 | {-# OPTIONS_GHC -fno-warn-orphans #-} 9 | 10 | -- import Data.Type.Combinator 11 | import Control.DeepSeq 12 | import Control.Exception 13 | import Control.Monad.IO.Class 14 | import Control.Monad.Primitive 15 | import Control.Monad.Trans.State 16 | import Data.Bifunctor 17 | import Data.Char 18 | import Data.Default 19 | import Data.Finite 20 | import Data.Foldable 21 | import Data.List 22 | import Data.List.Split 23 | import Data.Maybe 24 | import Data.Ord 25 | import Data.Time.Clock 26 | import Data.Type.Product hiding (toList, head') 27 | import Data.Type.Vector 28 | import GHC.TypeLits hiding (type (+)) 29 | import Learn.Neural 30 | import Learn.Neural.Layer.Recurrent.FullyConnected 31 | import Learn.Neural.Layer.Recurrent.LSTM 32 | import Numeric.BLAS.HMatrix 33 | import Numeric.Backprop.Op hiding (head') 34 | import Text.Printf hiding (toChar, fromChar) 35 | import Type.Class.Known 36 | import Type.Family.Nat 37 | import qualified Data.List.NonEmpty as NE 38 | import qualified Data.Type.Nat as TCN 39 | import qualified Data.Vector as V 40 | import qualified Data.Vector.Sized as VSi 41 | import qualified System.Random.MWC as MWC 42 | import qualified System.Random.MWC.Distributions as MWC 43 | 44 | type ASCII = Finite 128 45 | 46 | fromChar :: Char -> Maybe ASCII 47 | fromChar = packFinite . fromIntegral . ord 48 | 49 | toChar :: ASCII -> Char 50 | toChar = chr . fromIntegral 51 | 52 | charOneHot :: Tensor t => Char -> Maybe (t '[128]) 53 | charOneHot = fmap (oneHot . only) . fromChar 54 | 55 | oneHotChar :: BLAS t => t '[128] -> Char 56 | oneHotChar = toChar . iamax 57 | 58 | charRank :: Tensor t => t '[128] -> [Char] 59 | charRank = map fst . sortBy (flip (comparing snd)) . zip ['\0'..] . toList . textract 60 | 61 | main :: IO () 62 | main = MWC.withSystemRandom $ \g -> do 63 | holmes <- evaluate . force . mapMaybe (charOneHot @HM) 64 | =<< readFile "data/holmes.txt" 65 | putStrLn "Loaded data" 66 | -- let slices_ :: [(Vec N4 (HM '[128]), HM '[128])] 67 | -- slices_ = slidingPartsLast known . asFeedback $ holmes 68 | let slices_ :: [(Vec N8 (HM '[128]), Vec N8 (HM '[128]))] 69 | slices_ = slidingPartsSplit known . asFeedback $ holmes 70 | slices <- evaluate . 
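{- force the full list of training slices up front, so the timings below measure training work only -}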
force $ slices_ 71 | putStrLn "Processed data" 72 | let opt0 = batching $ adamOptimizer def (netOpRecurrent_ known) 73 | (sumLossDecay crossEntropy known α) 74 | -- net0 :: Network 'Recurrent HM ( '[128] :~ FullyConnectedR' 'MF_Logit ) 75 | -- '[ '[64 ] :~ ReLUMap 76 | -- , '[64 ] :~ FullyConnectedR' 'MF_Logit 77 | -- , '[32 ] :~ ReLUMap 78 | -- , '[32 ] :~ FullyConnected 79 | -- , '[128] :~ SoftMax '[128] 80 | -- ] 81 | -- '[128] <- initDefNet g 82 | net0 :: Network 'Recurrent HM ( '[128] :~ LSTM ) 83 | '[ '[96 ] :~ ReLUMap 84 | , '[96 ] :~ LSTM 85 | , '[64 ] :~ ReLUMap 86 | , '[64 ] :~ FullyConnected 87 | , '[128] :~ SoftMax '[128] 88 | ] 89 | '[128] <- initDefNet g 90 | flip evalStateT (net0, opt0) . forM_ [1..] $ \e -> do 91 | train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList slices) g 92 | liftIO $ printf "[Epoch %d]\n" (e :: Int) 93 | 94 | let chunkUp = chunksOf batch train' 95 | numChunks = length chunkUp 96 | forM_ ([1..] `zip` chunkUp) $ \(b, chnk) -> StateT $ \(n0, o0) -> do 97 | printf "(Epoch %d, Batch %d / %d)\n" (e :: Int) (b :: Int) numChunks 98 | 99 | t0 <- getCurrentTime 100 | -- n' <- evaluate $ optimizeList_ (bimap vecToProd only_ <$> chnk) n0 101 | (n', o') <- evaluate 102 | $ optimizeListBatches (bimap vecToProd vecToProd <$> chnk) n0 o0 25 103 | t1 <- getCurrentTime 104 | printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0)) 105 | 106 | forM_ (map (bimap toList (last . toList)) . take 3 $ chnk) $ \(lastChnk, x0) -> do 107 | let (ys, primed) = runNetRecurrent n' lastChnk 108 | next :: HM '[128] -> IO ((Char, HM '[128]), HM '[128]) 109 | next x = do 110 | pick <- pickNext 4 x g 111 | return ((toChar pick, x), oneHot (only pick)) 112 | test <- toList . fst 113 | <$> runNetFeedbackM (known @_ @_ @(N10 + N6)) next primed x0 114 | 115 | forM_ (zip (lastChnk ++ [x0]) (ys ++ [snd (head test)])) $ \(c,y) -> 116 | printf "|%c\t=> %s\t(%.4f)\n" 117 | (censor (oneHotChar c)) 118 | (take 25 (censor <$> charRank y)) 119 | (amax y) 120 | forM_ (zip test (drop 1 test)) $ \((t,_),(_,p)) -> 121 | printf " %c\t=> %s\t(%.4f)\n" 122 | (censor t) 123 | (take 25 (censor <$> charRank p)) 124 | (amax p) 125 | putStrLn "---" 126 | 127 | let (test, _) = runNetRecurrentLast n' . vecNonEmpty . fst . head $ chnk 128 | let n'' | isNaN (amax test) = n0 129 | | otherwise = n' 130 | return ((), (n'', o')) 131 | where 132 | batch :: Int 133 | batch = 1000 134 | α :: Double 135 | α = 2/3 136 | 137 | pickNext 138 | :: (PrimMonad m, BLAS t, KnownNat n) 139 | => Double 140 | -> t '[n] 141 | -> MWC.Gen (PrimState m) 142 | -> m (Finite n) 143 | pickNext α x g 144 | = fmap fromIntegral 145 | . flip MWC.categorical g 146 | . fmap ((** α) . realToFrac) 147 | . VSi.fromSized 148 | . 
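{- extract the raw vector of sampling weights from the tensor -}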
textract 149 | $ x 150 | 151 | 152 | censor :: Char -> Char 153 | censor c 154 | | isPrint c = c 155 | | otherwise = '░' 156 | 157 | instance NFData a => NFData (I a) where 158 | rnf = \case 159 | I x -> rnf x 160 | 161 | instance NFData (f a) => NFData (VecT n f a) where 162 | rnf = \case 163 | ØV -> () 164 | x :* xs -> x `deepseq` rnf xs 165 | 166 | vecToProd 167 | :: VecT n f a 168 | -> Prod f (Replicate n a) 169 | vecToProd = \case 170 | ØV -> Ø 171 | x :* xs -> x :< vecToProd xs 172 | 173 | vecNonEmpty 174 | :: Vec ('S n) a 175 | -> NE.NonEmpty a 176 | vecNonEmpty = \case 177 | I x :* xs -> x NE.:| toList xs 178 | -------------------------------------------------------------------------------- /old/src/Numeric/BLAS.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveFunctor #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE FunctionalDependencies #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE InstanceSigs #-} 7 | {-# LANGUAGE KindSignatures #-} 8 | {-# LANGUAGE LambdaCase #-} 9 | {-# LANGUAGE MultiParamTypeClasses #-} 10 | {-# LANGUAGE PolyKinds #-} 11 | {-# LANGUAGE ScopedTypeVariables #-} 12 | {-# LANGUAGE StandaloneDeriving #-} 13 | {-# LANGUAGE TemplateHaskell #-} 14 | {-# LANGUAGE TypeApplications #-} 15 | {-# LANGUAGE TypeFamilies #-} 16 | {-# LANGUAGE TypeInType #-} 17 | {-# LANGUAGE TypeOperators #-} 18 | {-# LANGUAGE UndecidableInstances #-} 19 | 20 | module Numeric.BLAS ( 21 | BLAS(..) 22 | , matVec 23 | , vecMat 24 | , matMat 25 | , outer 26 | , diag 27 | , eye 28 | , amax 29 | , concretize 30 | , matVecOp 31 | , dotOp 32 | , asumOp 33 | , module Numeric.Tensor 34 | ) where 35 | 36 | import Data.Finite 37 | import Data.Finite.Internal 38 | import Data.Foldable hiding (asum) 39 | import Data.Kind 40 | import Data.Maybe 41 | import Data.Ord 42 | import Data.Singletons 43 | import Data.Singletons.Prelude.Num 44 | import Data.Singletons.TypeLits 45 | import GHC.TypeLits 46 | import Numeric.Backprop.Op 47 | import Numeric.Tensor 48 | import qualified Data.Vector.Sized as V 49 | 50 | class Tensor b => BLAS (b :: [Nat] -> Type) where 51 | 52 | transp 53 | :: (KnownNat m, KnownNat n) 54 | => b '[m, n] 55 | -> b '[n, m] 56 | transp x = gen sing $ \case 57 | n :< m :< Ø -> tindex (m :< n :< Ø) x 58 | 59 | -- Level 1 60 | scal 61 | :: KnownNat n 62 | => Scalar b -- ^ α 63 | -> b '[n] -- ^ x 64 | -> b '[n] -- ^ α x 65 | scal α = tmap (α *) 66 | 67 | axpy 68 | :: KnownNat n 69 | => Scalar b -- ^ α 70 | -> b '[n] -- ^ x 71 | -> b '[n] -- ^ y 72 | -> b '[n] -- ^ α x + y 73 | axpy α = tzip (\x y -> α * x + y) 74 | 75 | dot :: KnownNat n 76 | => b '[n] -- ^ x 77 | -> b '[n] -- ^ y 78 | -> Scalar b -- ^ x' y 79 | dot x y = tsum (tzip (*) x y) 80 | 81 | norm2 82 | :: KnownNat n 83 | => b '[n] -- ^ x 84 | -> Scalar b -- ^ ||x|| 85 | norm2 = sqrt . tsum . tmap (** 2) 86 | 87 | asum 88 | :: KnownNat n 89 | => b '[n] -- ^ x 90 | -> Scalar b -- ^ sum_i |x_i| 91 | asum = tsum . tmap abs 92 | 93 | iamax 94 | :: forall n. KnownNat n 95 | => b '[n + 1] -- ^ x 96 | -> Finite (n + 1) -- ^ argmax_i |x_i| 97 | iamax = withKnownNat (SNat @n %:+ SNat @1) $ 98 | Finite . fromIntegral . V.maxIndex . textract . tmap abs 99 | 100 | -- Level 2 101 | gemv 102 | :: (KnownNat m, KnownNat n) 103 | => Scalar b -- ^ α 104 | -> b '[m, n] -- ^ A 105 | -> b '[n] -- ^ x 106 | -> Maybe (Scalar b, b '[m]) -- ^ β, y 107 | -> b '[m] -- ^ α A x + β y 108 | gemv α a x b = maybe id (uncurry axpy) b . 
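{- each entry i of the result is α times the dot product of x with row i of A; any β y term is then added via axpy -}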
gen sing $ \case 109 | i :< Ø -> α * dot x (treshape sing (tslice (SliceSingle i `PMS` SliceAll `PMS` PMZ) a)) 110 | 111 | ger :: (KnownNat m, KnownNat n) 112 | => Scalar b -- ^ α 113 | -> b '[m] -- ^ x 114 | -> b '[n] -- ^ y 115 | -> Maybe (b '[m, n]) -- ^ A 116 | -> b '[m, n] -- ^ α x y' + A 117 | ger α x y b = maybe id (tzip (+)) b . gen sing $ \case 118 | i :< j :< Ø -> α * tindex (i :< Ø) x * tindex (j :< Ø) y 119 | 120 | syr :: KnownNat n 121 | => Scalar b -- ^ α 122 | -> b '[n] -- ^ x 123 | -> Maybe (b '[n, n]) -- ^ A 124 | -> b '[n, n] -- ^ α x x' + A 125 | syr α x a = ger α x x a 126 | 127 | -- Level 3 128 | gemm 129 | :: (KnownNat m, KnownNat o, KnownNat n) 130 | => Scalar b -- ^ α 131 | -> b '[m, o] -- ^ A 132 | -> b '[o, n] -- ^ B 133 | -> Maybe (Scalar b, b '[m, n]) -- ^ β, C 134 | -> b '[m, n] -- ^ α A B + β C 135 | gemm α a b c = maybe id (uncurry f) c . gen sing $ \case 136 | i :< j :< Ø -> 137 | α * dot (treshape sing (tslice (SliceSingle i `PMS` SliceAll `PMS` PMZ) a)) 138 | (treshape sing (tslice (SliceAll `PMS` SliceSingle j `PMS` PMZ) b)) 139 | where 140 | f β = tzip (\d r -> β * d + r) 141 | 142 | syrk 143 | :: (KnownNat m, KnownNat n) 144 | => Scalar b -- ^ α 145 | -> b '[m, n] -- ^ A 146 | -> Maybe (Scalar b, b '[m, m]) -- ^ β, C 147 | -> b '[m, m] -- ^ α A A' + β C 148 | syrk α a c = gemm α a (transp a) c 149 | 150 | {-# MINIMAL #-} 151 | 152 | matVec 153 | :: (KnownNat m, KnownNat n, BLAS b) 154 | => b '[m, n] 155 | -> b '[n] 156 | -> b '[m] 157 | matVec a x = gemv 1 a x Nothing 158 | 159 | vecMat 160 | :: (KnownNat m, KnownNat n, BLAS b) 161 | => b '[m] 162 | -> b '[m, n] 163 | -> b '[n] 164 | vecMat x a = gemv 1 (transp a) x Nothing 165 | 166 | matMat 167 | :: (KnownNat m, KnownNat o, KnownNat n, BLAS b) 168 | => b '[m, o] 169 | -> b '[o, n] 170 | -> b '[m, n] 171 | matMat a b = gemm 1 a b Nothing 172 | 173 | outer 174 | :: (KnownNat m, KnownNat n, BLAS b) 175 | => b '[m] 176 | -> b '[n] 177 | -> b '[m, n] 178 | outer x y = ger 1 x y Nothing 179 | 180 | diag 181 | :: (KnownNat n, Tensor b) 182 | => b '[n] 183 | -> b '[n, n] 184 | diag x = gen sing $ \case 185 | i :< j :< Ø 186 | | i `equals` j -> tindex (i :< Ø) x 187 | | otherwise -> 0 188 | 189 | eye 190 | :: (KnownNat n, Tensor b) 191 | => b '[n, n] 192 | eye = gen sing $ \case 193 | i :< j :< Ø 194 | | i `equals` j -> 1 195 | | otherwise -> 0 196 | 197 | amax 198 | :: forall b n. (BLAS b, KnownNat n) 199 | => b '[n + 1] 200 | -> Scalar b 201 | amax = do 202 | i <- only . iamax 203 | withKnownNat (SNat @n %:+ SNat @1) $ 204 | tindex i 205 | 206 | concretize :: forall b n. (BLAS b, KnownNat n) => b '[n + 1] -> b '[n + 1] 207 | concretize = withKnownNat (SNat @n %:+ SNat @1) $ 208 | oneHot . only . iamax 209 | 210 | matVecOp 211 | :: (KnownNat m, KnownNat n, BLAS b) 212 | => Op '[ b '[m, n], b '[n] ] '[ b '[m] ] 213 | matVecOp = op2' $ \a x -> 214 | ( only_ (matVec a x) 215 | , (\g -> (outer g x, vecMat g a)) 216 | . fromMaybe (tkonst sing 1) 217 | . head' 218 | ) 219 | 220 | dotOp 221 | :: forall b n. (KnownNat n, BLAS b) 222 | => Op '[ b '[n], b '[n] ] '[ Scalar b ] 223 | dotOp = op2' $ \x y -> 224 | ( only_ (dot x y) 225 | , \case Nothing :< Ø -> (y , x ) 226 | Just g :< Ø -> (scal g y, scal g x) 227 | ) 228 | 229 | asumOp 230 | :: forall b n. 
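{- sum of absolute values; its gradient is signum x, scaled by any incoming gradient -}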
(KnownNat n, BLAS b, Num (b '[n])) 231 | => Op '[ b '[n] ] '[ Scalar b ] 232 | asumOp = op1' $ \x -> 233 | ( only_ (asum x) 234 | , \case Nothing :< Ø -> signum x 235 | Just g :< Ø -> scal g (signum x) 236 | ) 237 | -------------------------------------------------------------------------------- /old/src/Learn/Neural/Network/Dropout.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE FlexibleInstances #-} 4 | {-# LANGUAGE GADTs #-} 5 | {-# LANGUAGE KindSignatures #-} 6 | {-# LANGUAGE LambdaCase #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# LANGUAGE Strict #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeInType #-} 11 | {-# LANGUAGE TypeOperators #-} 12 | {-# LANGUAGE UndecidableInstances #-} 13 | 14 | module Learn.Neural.Network.Dropout ( 15 | Dropout(..) 16 | , NetworkDO(..) 17 | , netOpDO 18 | , netOpDOPure 19 | , konstDO, konstDO', mapDO, mapDO' 20 | , alongNet 21 | ) where 22 | 23 | 24 | import Control.Monad.Primitive 25 | import Data.Bool 26 | import Data.Kind 27 | import Data.Singletons 28 | import Data.Traversable 29 | import Data.Type.Combinator 30 | import Data.Type.Product 31 | import GHC.TypeLits 32 | import Learn.Neural.Layer 33 | import Learn.Neural.Network 34 | import Numeric.BLAS 35 | import Numeric.Backprop 36 | import Numeric.Backprop.Op 37 | import System.Random.MWC 38 | import System.Random.MWC.Distributions 39 | import Type.Class.Known 40 | 41 | data NetworkDO :: RunMode -> ([Nat] -> Type) -> LChain -> [LChain] -> [Nat] -> Type where 42 | NDO :: { ndoDropout :: !(Dropout r b i hs o) 43 | , ndoNetwork :: !(Network r b i hs o) 44 | } 45 | -> NetworkDO r b i hs o 46 | 47 | -- | For something like 48 | -- 49 | -- @ 50 | -- l1 ':&' l2 :& 'NetExt' l3 51 | -- @ 52 | -- 53 | -- you could have 54 | -- 55 | -- @ 56 | -- 'Just' 0.2 ':&%' 'Just' 0.3 :&% DOExt 57 | -- @ 58 | -- 59 | -- This would represent a 20% dropout after @l1@ and a 30% dropout after 60 | -- @l2@. By conscious design decision, 'DOExt' does not take any dropout 61 | -- rate, and it is therefore impossible to have dropout after the final 62 | -- layer before the output. It is also not possible to have dropout before 63 | -- the first layer, after the input. 64 | data Dropout :: RunMode -> ([Nat] -> Type) -> LChain -> [LChain] -> [Nat] -> Type where 65 | DOExt 66 | :: Dropout r b (i :~ c) '[] o 67 | (:&%) 68 | :: (Num (b h), SingI h) 69 | => !(Maybe Double) 70 | -> !(Dropout r b (h :~ d) hs o) 71 | -> Dropout r b (i :~ c) ((h :~ d) ': hs) o 72 | 73 | infixr 4 :&% 74 | 75 | alongNet 76 | :: n r b i hs o 77 | -> Dropout r b i hs o 78 | -> Dropout r b i hs o 79 | alongNet _ d = d 80 | 81 | konstDO 82 | :: forall r b i hs o. () 83 | => NetStruct r b i hs o 84 | -> Maybe Double 85 | -> Dropout r b i hs o 86 | konstDO s0 x = go s0 87 | where 88 | go :: NetStruct r b j js o -> Dropout r b j js o 89 | go = \case 90 | NSExt -> DOExt 91 | NSInt s -> x :&% go s 92 | 93 | konstDO' 94 | :: forall r b i hs o. Known (NetStruct r b i hs) o 95 | => Maybe Double 96 | -> Dropout r b i hs o 97 | konstDO' = konstDO known 98 | 99 | mapDO 100 | :: forall r b i hs o. () 101 | => (Double -> Double) 102 | -> Dropout r b i hs o 103 | -> Dropout r b i hs o 104 | mapDO f = go 105 | where 106 | go :: Dropout r b j js o -> Dropout r b j js o 107 | go = \case 108 | DOExt -> DOExt 109 | x :&% d -> fmap f x :&% go d 110 | 111 | mapDO' 112 | :: forall r b i hs o. 
() 113 | => (Maybe Double -> Maybe Double) 114 | -> Dropout r b i hs o 115 | -> Dropout r b i hs o 116 | mapDO' f = go 117 | where 118 | go :: Dropout r b j js o -> Dropout r b j js o 119 | go = \case 120 | DOExt -> DOExt 121 | x :&% d -> f x :&% go d 122 | 123 | netOpDO 124 | :: forall m b i c hs o r. (BLAS b, Num (b i), Num (b o), PrimMonad m, SingI o) 125 | => Dropout r b (i :~ c) hs o 126 | -> Gen (PrimState m) 127 | -> m (OpBS '[ Network r b (i :~ c) hs o, b i ] '[ Network r b (i :~ c) hs o, b o ]) 128 | netOpDO = \case 129 | DOExt -> \_ -> return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of 130 | NetExt (l :: Layer r c b i o) -> do 131 | (I l' :< I y :< Ø, gF) <- runOpM' (layerOp @r @c @i @o @b) (l ::< x ::< Ø) 132 | let gF' = fmap (\case I dL :< I dX :< Ø -> NetExt dL ::< dX ::< Ø) 133 | . gF 134 | . (\case Just (NetExt dL) :< dY :< Ø -> 135 | Just dL :< dY :< Ø 136 | Nothing :< dY :< Ø -> 137 | Nothing :< dY :< Ø 138 | ) 139 | return (NetExt l' ::< y ::< Ø, gF') 140 | r :&% (d :: Dropout r b (h :~ d) js o) -> \g -> do 141 | mask <- forM r $ \r' -> 142 | genA @b (sing @_ @h) $ \_ -> bool (1 / (1 - realToFrac r')) 0 <$> bernoulli r' g 143 | no :: OpBS '[ Network r b (h :~ d) js o, b h ] '[ Network r b (h :~ d) js o, b o ] 144 | <- netOpDO @m @b @h @d @js @o @r d g 145 | return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of 146 | (l :: Layer r c b i h) :& (n2 :: Network r b (h ':~ d) js o) -> do 147 | (I l' :< I y :< Ø, gF ) <- runOpM' (layerOp @r @c @i @h @b) (l ::< x ::< Ø) 148 | (I n2' :< I z :< Ø, gF') <- runOpM' (runOpBS no ) (n2 ::< y ::< Ø) 149 | let gF'' = \case Just (dL :& dN) :< dZ :< Ø -> do 150 | I dN2 :< I dY :< Ø <- gF' (Just dN :< dZ :< Ø) 151 | let dY' = maybe dY (tzip (*) dY) mask 152 | I dL0 :< I dX :< Ø <- gF (Just dL :< Just dY' :< Ø) 153 | return $ (dL0 :& dN2) ::< dX ::< Ø 154 | Nothing :< dZ :< Ø -> do 155 | I dN2 :< I dY :< Ø <- gF' (Nothing :< dZ :< Ø) 156 | let dY' = maybe dY (tzip (*) dY) mask 157 | I dL0 :< I dX :< Ø <- gF (Nothing :< Just dY' :< Ø) 158 | return $ (dL0 :& dN2) ::< dX ::< Ø 159 | return ((l' :& n2') ::< z ::< Ø, gF'') 160 | 161 | netOpDOPure 162 | :: forall m b i c hs o. (BLAS b, Num (b i), Num (b o), PrimMonad m, SingI o) 163 | => Dropout 'FeedForward b (i :~ c) hs o 164 | -> Gen (PrimState m) 165 | -> m (OpBS '[ Network 'FeedForward b (i :~ c) hs o, b i ] '[ b o ]) 166 | netOpDOPure = \case 167 | DOExt -> \_ -> return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of 168 | NetExt (l :: Layer 'FeedForward c b i o) -> do 169 | (I y :< Ø, gF) <- runOpM' (layerOpPure @c @i @o @b) (l ::< x ::< Ø) 170 | let gF' = fmap (\case I dL :< I dX :< Ø -> NetExt dL ::< dX ::< Ø) 171 | . 
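{- rewrap the underlying layer's gradient in NetExt -}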
gF 172 | return (y ::< Ø, gF') 173 | r :&% (d :: Dropout 'FeedForward b (h :~ d) js o) -> \g -> do 174 | mask <- forM r $ \r' -> 175 | genA @b (sing @_ @h) $ \_ -> bool (1 / (1 - realToFrac r')) 0 <$> bernoulli r' g 176 | no :: OpBS '[ Network 'FeedForward b (h :~ d) js o, b h ] '[ b o ] 177 | <- netOpDOPure @m @b @h @d @js @o d g 178 | return $ OpBS $ OpM $ \(I n :< I x :< Ø) -> case n of 179 | (l :: Layer 'FeedForward c b i h) :& (n2 :: Network 'FeedForward b (h ':~ d) js o) -> do 180 | (I y :< Ø, gF ) <- runOpM' (layerOpPure @c @i @h @b) (l ::< x ::< Ø) 181 | (I z :< Ø, gF') <- runOpM' (runOpBS no ) (n2 ::< y ::< Ø) 182 | let gF'' dZ = do 183 | I dN2 :< I dY :< Ø <- gF' dZ 184 | let dY' = maybe dY (tzip (*) dY) mask 185 | I dL0 :< I dX :< Ø <- gF (Just dY' :< Ø) 186 | return $ (dL0 :& dN2) ::< dX ::< Ø 187 | return (z ::< Ø, gF'') 188 | -------------------------------------------------------------------------------- /src/Backprop/Learn/Model/Neural/LSTM.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveDataTypeable #-} 4 | {-# LANGUAGE DeriveGeneric #-} 5 | {-# LANGUAGE DerivingVia #-} 6 | {-# LANGUAGE FlexibleInstances #-} 7 | {-# LANGUAGE GADTs #-} 8 | {-# LANGUAGE KindSignatures #-} 9 | {-# LANGUAGE MultiParamTypeClasses #-} 10 | {-# LANGUAGE PatternSynonyms #-} 11 | {-# LANGUAGE RankNTypes #-} 12 | {-# LANGUAGE RecordWildCards #-} 13 | {-# LANGUAGE ScopedTypeVariables #-} 14 | {-# LANGUAGE StandaloneDeriving #-} 15 | {-# LANGUAGE TemplateHaskell #-} 16 | {-# LANGUAGE TypeApplications #-} 17 | {-# LANGUAGE TypeFamilies #-} 18 | {-# LANGUAGE TypeInType #-} 19 | {-# LANGUAGE TypeOperators #-} 20 | {-# LANGUAGE UndecidableInstances #-} 21 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 22 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} 23 | 24 | module Backprop.Learn.Model.Neural.LSTM ( 25 | -- * LSTM 26 | lstm 27 | , LSTMp(..), lstmForget, lstmInput, lstmUpdate, lstmOutput 28 | , reshapeLSTMpInput 29 | , reshapeLSTMpOutput 30 | , lstm' 31 | -- * GRU 32 | , gru 33 | , GRUp(..), gruMemory, gruUpdate, gruOutput 34 | , gru' 35 | ) where 36 | 37 | import Backprop.Learn.Initialize 38 | import Backprop.Learn.Model.Function 39 | import Backprop.Learn.Model.Neural 40 | import Backprop.Learn.Model.Regression 41 | import Backprop.Learn.Model.State 42 | import Backprop.Learn.Model.Types 43 | import Control.DeepSeq 44 | import Control.Monad 45 | import Control.Monad.Primitive 46 | import Data.Type.Tuple 47 | import Data.Typeable 48 | import GHC.Generics (Generic) 49 | import GHC.TypeNats 50 | import Lens.Micro 51 | import Lens.Micro.TH 52 | import Numeric.Backprop 53 | import Numeric.LinearAlgebra.Static.Backprop 54 | import Numeric.OneLiner 55 | import Numeric.Opto.Ref 56 | import Numeric.Opto.Update 57 | import Statistics.Distribution 58 | import qualified Data.Binary as Bi 59 | import qualified Numeric.LinearAlgebra.Static as H 60 | import qualified System.Random.MWC as MWC 61 | 62 | -- TODO: allow parameterize internal activation function? 
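--
-- The gate equations implemented by 'lstm'' below, following the standard
-- LSTM formulation (writing [x, h] for the input vector, which carries the
-- previous output h alongside the new input x, and σ for 'logistic'):
--
--   forget = σ(W_f [x, h] + b_f)
--   input  = σ(W_i [x, h] + b_i)
--   update = tanh(W_u [x, h] + b_u)
--   s'     = forget * s + input * update
--   output = σ(W_o [x, h] + b_o)
--   h'     = output * tanh s'
--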
63 | -- TODO: Peepholes 64 | 65 | -- | 'LSTM' layer parameters 66 | data LSTMp (i :: Nat) (o :: Nat) = 67 | LSTMp { _lstmForget :: !(FCp (i + o) o) 68 | , _lstmInput :: !(FCp (i + o) o) 69 | , _lstmUpdate :: !(FCp (i + o) o) 70 | , _lstmOutput :: !(FCp (i + o) o) 71 | } 72 | deriving stock (Generic, Typeable, Show) 73 | deriving anyclass (NFData, Linear Double, Metric Double, Bi.Binary, Regularize, Backprop) 74 | 75 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Num (LSTMp i o) 76 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Fractional (LSTMp i o) 77 | deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Floating (LSTMp i o) 78 | 79 | makeLenses ''LSTMp 80 | 81 | instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (LSTMp i o) where 82 | type Ref m (LSTMp i o) = GRef m (LSTMp i o) 83 | thawRef = gThawRef 84 | freezeRef = gFreezeRef 85 | copyRef = gCopyRef 86 | instance (PrimMonad m, KnownNat i, KnownNat o) => LinearInPlace m Double (LSTMp i o) 87 | 88 | instance (PrimMonad m, KnownNat i, KnownNat o) => Learnable m (LSTMp i o) 89 | 90 | -- | Version of 'lstm' that takes the "previous output" as a part of the 91 | -- input vector, keeping only the cell state as internal state. 92 | lstm' 93 | :: (KnownNat i, KnownNat o) 94 | => Model ('Just (LSTMp i o)) ('Just (R o)) (R (i + o)) (R o) 95 | lstm' = modelD $ \(PJust p) x (PJust s) -> 96 | let forget = logistic $ runLRp (p ^^. lstmForget) x 97 | input = logistic $ runLRp (p ^^. lstmInput ) x 98 | update = tanh $ runLRp (p ^^. lstmUpdate) x 99 | s' = forget * s + input * update 100 | o = logistic $ runLRp (p ^^. lstmOutput) x 101 | h = o * tanh s' 102 | in (h, PJust s') 103 | 104 | -- | Long short-term memory layer 105 | -- 106 | -- 107 | -- 108 | lstm 109 | :: (KnownNat i, KnownNat o) 110 | => Model ('Just (LSTMp i o)) ('Just (R o :# R o)) (R i) (R o) 111 | lstm = recurrent H.split (H.#) id lstm' 112 | 113 | reshapeLSTMpInput 114 | :: (ContGen d, PrimMonad m, KnownNat i, KnownNat i', KnownNat o) 115 | => d 116 | -> MWC.Gen (PrimState m) 117 | -> LSTMp i o 118 | -> m (LSTMp i' o) 119 | reshapeLSTMpInput d g (LSTMp forget input update output) = 120 | LSTMp <$> reshaper forget 121 | <*> reshaper input 122 | <*> reshaper update 123 | <*> reshaper output 124 | where 125 | reshaper = reshapeLRpInput d g 126 | 127 | reshapeLSTMpOutput 128 | :: (ContGen d, PrimMonad m, KnownNat i, KnownNat o, KnownNat o') 129 | => d 130 | -> MWC.Gen (PrimState m) 131 | -> LSTMp i o 132 | -> m (LSTMp i o') 133 | reshapeLSTMpOutput d g (LSTMp forget input update output) = 134 | LSTMp <$> reshaper forget 135 | <*> reshaper input 136 | <*> reshaper update 137 | <*> reshaper output 138 | where 139 | reshaper = reshapeLRpInput d g 140 | <=< reshapeLRpOutput d g 141 | 142 | -- | Forget biases initialized to 1 143 | instance (KnownNat i, KnownNat o) => Initialize (LSTMp i o) where 144 | initialize d g = LSTMp <$> set (mapped . 
fcBias) 1 (initialize d g) 145 | <*> initialize d g 146 | <*> initialize d g 147 | <*> initialize d g 148 | 149 | -- | 'GRU' layer parameters 150 | data GRUp (i :: Nat) (o :: Nat) = 151 | GRUp { _gruMemory :: !(FCp (i + o) o) 152 | , _gruUpdate :: !(FCp (i + o) o) 153 | , _gruOutput :: !(FCp (i + o) o) 154 | } 155 | deriving stock (Generic, Typeable, Show) 156 | deriving anyclass (NFData, Linear Double, Metric Double, Bi.Binary, Initialize, Regularize, Backprop) 157 | 158 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Num (GRUp i o) 159 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Fractional (GRUp i o) 160 | deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Floating (GRUp i o) 161 | 162 | makeLenses ''GRUp 163 | 164 | instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (GRUp i o) where 165 | type Ref m (GRUp i o) = GRef m (GRUp i o) 166 | thawRef = gThawRef 167 | freezeRef = gFreezeRef 168 | copyRef = gCopyRef 169 | instance (KnownNat i, KnownNat o, Mutable m (GRUp i o)) => LinearInPlace m Double (GRUp i o) 170 | 171 | instance (KnownNat i, KnownNat o, PrimMonad m) => Learnable m (GRUp i o) 172 | 173 | -- | Stateless version of 'gru' that takes the "previous output" as a part 174 | -- of the input vector. 175 | gru' 176 | :: forall i o. (KnownNat i, KnownNat o) 177 | => Model ('Just (GRUp i o)) 'Nothing (R (i + o)) (R o) 178 | gru' = modelStatelessD $ \(PJust p) x -> 179 | let z = logistic $ runLRp (p ^^. gruMemory) x 180 | r = logistic $ runLRp (p ^^. gruUpdate) x 181 | r' = 1 # r 182 | h' = tanh $ runLRp (p ^^. gruOutput) (r' * x) 183 | in (1 - z) * snd (split @i x) + z * h' 184 | 185 | -- | Gated Recurrent Unit 186 | -- 187 | -- 188 | -- 189 | gru :: (KnownNat i, KnownNat o) 190 | => Model ('Just (GRUp i o)) ('Just (R o)) (R i) (R o) 191 | gru = recurrent H.split (H.#) id gru' 192 | --------------------------------------------------------------------------------