├── .gitignore ├── DeepLearning.ipkg ├── ExLens ├── BiTambara.hs ├── ExLens.hs ├── NNet.hs └── Perceptron.hs ├── Haskell ├── NNet.hs ├── PLens.hs └── Perceptron.hs ├── Idris ├── HVect.idr ├── Main.idr ├── NNet.idr ├── PLens.idr └── Perceptron.idr ├── LICENSE ├── PreLens ├── Main.hs ├── NNet.hs ├── ParaLens.hs ├── Params.hs ├── Perceptron.hs ├── PreLens.hs ├── Tambara.hs └── TriLens.hs ├── README.md └── src ├── HVect.idr ├── Main.idr ├── NNet.idr ├── PLens.idr └── Perceptron.idr /.gitignore: -------------------------------------------------------------------------------- 1 | # Idris 2 2 | *.ttc 3 | *.ttm 4 | 5 | # Idris 1 6 | *.ibc 7 | *.o 8 | src/build/exec/_tmpchez 9 | src/build/exec/_tmpchez_app/_tmpchez.ss 10 | src/build/exec/_tmpchez_app/libidris2_support.dylib 11 | -------------------------------------------------------------------------------- /DeepLearning.ipkg: -------------------------------------------------------------------------------- 1 | package DeepLearning 2 | version = 0.1.0 3 | authors = "Bartosz Milewski" 4 | main = Main 5 | executable = "Deep" 6 | sourcedir = "src" 7 | -------------------------------------------------------------------------------- /ExLens/BiTambara.hs: -------------------------------------------------------------------------------- 1 | {-# language ScopedTypeVariables #-} 2 | module BiTambara where 3 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 4 | 5 | -- The monoidal plumbing 6 | 7 | lunit_1 :: q -> ((), q) 8 | lunit_1 q = ((), q) 9 | lunit :: ((), q) -> q 10 | lunit ((), q) = q 11 | runit_1 :: q -> (q, ()) 12 | runit_1 q = (q, ()) 13 | runit :: (q, ()) -> q 14 | runit (q, ()) = q 15 | 16 | assoc :: ((a, b), c) -> (a, (b, c)) 17 | assoc ((a, b), c) = (a, (b, c)) 18 | assoc_1 :: (a, (b, c)) -> ((a, b), c) 19 | assoc_1 (a, (b, c))= ((a, b), c) 20 | 21 | -- This is the same existential lens but with rearranged wires 22 | -- outputs first, then parameters, then inputs 23 | data ExLens a da p dp s ds = 24 | forall m . ExLens ((p, s) -> (m, a)) 25 | ((m, da) -> (dp, ds)) 26 | 27 | -- Accessors: forward and backward pass 28 | fwd :: ExLens a da p dp s ds -> (p, s) -> a 29 | fwd (ExLens f b) = snd . f 30 | bwd :: ExLens a da p dp s ds -> (p, s, da) -> (dp, ds) 31 | bwd (ExLens f b) (p, s, da)= b (fst (f (p, s)), da) 32 | 33 | -- Profunctor representation of lens 34 | -- We need parameterized profunctors 35 | 36 | class BiProfunctor p where 37 | dimap :: (a' -> a) -> (b -> b') -> p q q' a b -> p q q' a' b' 38 | dimap' :: (r -> q) -> (q' -> r') -> p q q' a b -> p r r' a b 39 | 40 | -- A BiTambara module has the same alpha structure as regular Tambara 41 | -- Plus an additional beta structure going in the opposite direction 42 | -- which lets us switch the parameters 43 | class BiProfunctor p => BiTambara p where 44 | alpha :: p q q' a da -> p q q' (m, a) (m, da) 45 | beta :: p r r' (q, s) (q', ds) -> p (r, q) (r', q') s ds 46 | 47 | -- BiTambara lens generalizes a profunctor lens 48 | -- (forall p. Tambara p => p a b -> p s t) 49 | -- It's polymorphic both in BiTambara modules and in parameters r r' 50 | type BiLens q q' s ds a da = 51 | forall p. BiTambara p => forall r r'. 
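-- q, q' are this lens's own parameters; r, r' are residual parameter slots
-- owned by the surrounding context, which is what makes these lenses compose: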
p r r' a da -> p (r, q) (r', q') s ds 52 | 53 | -- This is the identity parametric lens wrt composition 54 | identityLens :: ExLens a da () () a da 55 | identityLens = ExLens id id 56 | 57 | -- An existential lens is an example of a BiProfunctor 58 | 59 | instance BiProfunctor (ExLens a da) where 60 | dimap :: (s' -> s) -> (ds -> ds') -> ExLens a da q q' s ds -> ExLens a da q q' s' ds' 61 | dimap f g (ExLens fw bw) = ExLens fw' bw' 62 | where fw' (q, s') = fw (q, f s') 63 | bw' (m, da) = second g (bw (m, da)) 64 | dimap' :: (r -> q) -> (q' -> r') -> ExLens a da q q' s ds -> ExLens a da r r' s ds 65 | dimap' f g (ExLens fw bw) = ExLens fw' bw' 66 | where 67 | --fw' :: (r, s) -> (m, a) 68 | fw' (r, s) = fw (f r, s) 69 | --bw' :: (m, da) -> (r', ds) 70 | bw' (m, da) = first g $ bw (m, da) 71 | 72 | -- An existential lens is an example of a BiTambara module 73 | 74 | instance BiTambara (ExLens a da) where 75 | alpha :: ExLens a da q q' s ds -> ExLens a da q q' (n, s) (n, ds) 76 | alpha (ExLens fw bw) = ExLens fw' bw' 77 | where fw' (q, (n, s)) = first (n,) $ fw (q, s) -- use (n, m) as residue 78 | bw' ((n, m), da) = second (n,) (bw (m, da)) 79 | 80 | beta :: ExLens a da r r' (q, s) (q', ds) -> ExLens a da (r, q) (r', q') s ds 81 | beta (ExLens fw bw) = ExLens fw' bw' 82 | -- fw :: (r, (q, s)) -> (m, a) 83 | -- bw :: (m, da) -> (r', (q', ds)) 84 | -- fw' :: ((r, q), s) -> (m, a) 85 | -- bw' :: (m, da) -> ((r', q'), ds) 86 | where fw' ((r, q), s) = fw (r, (q, s)) 87 | bw' (m, da) = let (r', (q', ds)) = bw (m, da) 88 | in ((r', q'), ds) 89 | 90 | -- The category of Parametric lenses is equivalent to 91 | -- the category of BiTambara lenses 92 | 93 | -- Conversion from bi-Tambara to ExLens uses the Yoneda trick 94 | -- since identityLens is a bi-Tambara module, we can 95 | -- apply a BiLens to it 96 | 97 | -- p q' q a da -> p (q', q) (q, q') s ds 98 | fromTamb :: BiLens q q' s ds a da -> ExLens a da q q' s ds 99 | fromTamb pab_pst = dimap' lunit_1 lunit $ pab_pst identityLens 100 | 101 | -- Conversion from ExLens to BiLens 102 | toTamb :: ExLens a da q q' s ds -> BiLens q q' s ds a da 103 | -- p r r' a da -> p (r, q) (r', q') s ds 104 | -- p r r' (m, a) (m, da) 105 | -- p r r' (q, s) (q', ds) 106 | -- p (r, q) (r', q') s ds 107 | toTamb (ExLens fw bw) = beta . dimap fw bw . alpha 108 | 109 | -- Serial composition of BiLenses is just a (re-associated) function composition 110 | 111 | comp :: BiLens q q' s ds a da -> BiLens r r' v dv s ds -> BiLens (q, r) (q', r') v dv a da 112 | -- p z z' a da -> p (z, q) (z', q') s ds 113 | -- p (z, q) (z', q') s ds -> p ((z, q), r) ((z', q'), r') v dv 114 | -- reassoc 115 | -- p z z' a da -> p (z, (q, r)) (z', (q', r')) v dv 116 | comp l1 l2 = dimap' assoc_1 assoc . l2 . 
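-- l1, the lens nearer the output a, rewrites the profunctor value first, then l2;
-- dimap' then reassociates ((z, q), r) into (z, (q, r)), as in affineT = comp biasT linearT below: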
l1 117 | 118 | -- Parallel product of BiLenses 119 | 120 | -- Show that a BiTambara of products is a BiTambara in both sides of the product 121 | 122 | -- Rearrange the wires for Haskell to see q' r' a' b' in one block 123 | data PRight p q r a b q' r' a' b' = PRight { unPRight :: p (q, q') (r, r') (a, a') (b, b') } 124 | 125 | -- It's a BiProfunctor in these variables 126 | instance (BiProfunctor p) => BiProfunctor (PRight p q r a b) where 127 | dimap f g (PRight p) = PRight $ dimap (second f) (second g) p 128 | dimap' f g (PRight p) = PRight $ dimap' (second f) (second g) p 129 | -- dimap' :: (r -> q) -> (q' -> r') -> p q q' a b -> p r r' a b 130 | 131 | -- It's a BiTambara in these variables 132 | instance (BiTambara p) => BiTambara (PRight p q r a b) where 133 | -- p (q, q') (r, r') (a, a') (b, b') -alpha-> 134 | -- p (q, q') (r, r') (m, (a, a')) (m, (b, b')) -dimap-> 135 | -- p (q, q') (r, r') (a, (m, a')) (b, (m, b')) 136 | alpha (PRight p) = PRight $ dimap (\(a, (m, a')) -> (m, (a, a'))) 137 | (\(m, (b, b')) -> (b, (m, b'))) $ alpha p 138 | -- p (q, q') (r, r') (a, (q1, a')) (b, (q1', b')) - dimap-> 139 | -- p (q, q') (r, r') (q1, (a, a')) (q1', (b, b')) - beta -> 140 | -- p ((q, q'), q1) ((r, r'), q1') (a, a') (b, b') - dimap' -> 141 | -- p (q, (q', q1)) (r, (r', q1')) (a, a') (b, b') 142 | beta (PRight p) = PRight $ 143 | dimap' (\(q, (q', q1)) -> ((q, q'), q1)) (\((r, r'), q1') -> (r, (r', q1'))) $ 144 | beta $ 145 | dimap (\(q1, (a, a')) -> (a, (q1, a'))) (\(b, (q1', b')) -> (q1', (b, b'))) p 146 | 147 | -- And the same for left sides of products: q r a b 148 | 149 | data PLeft p q' r' a' b' q r a b = PLeft { unPLeft :: p (q, q') (r, r') (a, a') (b, b') } 150 | 151 | instance (BiProfunctor p) => BiProfunctor (PLeft p q r a b) where 152 | dimap f g (PLeft p) = PLeft $ dimap (first f) (first g) p 153 | dimap' f g (PLeft p) = PLeft $ dimap' (first f) (first g) p 154 | 155 | instance (BiTambara p) => BiTambara (PLeft p q r a b) where 156 | alpha (PLeft p) = PLeft $ dimap (\((m, a), a') -> (m, (a, a'))) 157 | (\(m, (b, b')) -> ((m, b), b')) $ alpha p 158 | beta (PLeft p) = PLeft $ 159 | dimap' (\((q, q1), q') -> ((q, q'), q1)) (\((r, r'), q1') -> ((r, q1'), r')) $ 160 | beta $ 161 | dimap (\(q1, (a, a')) -> ((q1, a), a')) (\((q1', b), b') -> (q1', (b, b'))) p 162 | 163 | -- Parallel composition of bi-Tambara representations 164 | 165 | prodLens :: BiLens q dq s ds a da -> 166 | BiLens q' dq' s' ds' a' da' -> 167 | BiLens (q, q') (dq, dq') (s, s') (ds, ds') (a, a') (da, da') 168 | -- l1 :: p1 r1 r1' a da -> p1 (r1, q) (r1', dq) s ds 169 | -- l2 :: p2 r2 r2' a' da' -> p2 (r2, q') (r2', dq') s' ds' 170 | -- l3 :: p r r' (a, a') (da, da') -> p (r, (q, q')) (r', (dq, dq')) (s, s') (ds, ds') 171 | prodLens l1 l2 = 172 | dimap' assoc_1 assoc . 173 | dimap' (second lunit_1) (second lunit) . 174 | unPRight . l2 . PRight . unPLeft . l1 . PLeft . 
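-- reading right to left: runit_1 first pads the residual parameter r to (r, ()),
-- then l1 runs inside the PLeft wrapper and l2 inside the PRight wrapper,
-- and the two dimap's above tidy the leftover units and reassociate: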
175 | dimap' runit runit_1 176 | 177 | {- -- Or more tediously: 178 | prodLens l1 l2 p0 = 179 | let p1 = dimap' runit runit_1 p0 180 | pl1 = PLeft p1 181 | (PLeft pl2) = l1 pl1 182 | pr1 = PRight pl2 183 | (PRight pr2) = l2 pr1 184 | pr3 = dimap' (second lunit_1) (second lunit) pr2 185 | in dimap' assoc_1 assoc pr3 186 | -} 187 | -- p0 :: (r, ()) (r', ()) (a, a') (da, da') ~ PLeft () () a' da' r r' a da 188 | -- l1 :: PLeft () () a' da' (r, q) (r', dq) s ds 189 | -- PRight (r, q) (r', dq) s ds () () a' da' 190 | -- l2 :: PRight (r, q) (r', dq) s ds ((), q') ((), dq') s' ds' 191 | -- ((r, q), q') ((r', dq), dq') (s, s') (ds, ds') 192 | -- data PRight q r a b q' r' a' b' = (q, q') (r, r') (a, a') (b, b') 193 | -- data PLeft q' r' a' b' q r a b = (q, q') (r, r') (a, a') (b, b') 194 | 195 | 196 | -- Testing 197 | 198 | type D = Double 199 | -- Ideally, a counted vector 200 | type V = [D] 201 | 202 | -- Simple linear lens, scalar product of parameters and inputs 203 | linearL :: ExLens D D V V V V 204 | linearL = ExLens fw bw 205 | where 206 | fw :: (V, V) -> ((V, V), D) 207 | -- a = Sum p * s 208 | fw (p, s) = ((s, p), sum $ zipWith (*) p s) 209 | -- da/dp = s, da/ds = p 210 | bw :: ((V, V), D) -> (V, V) 211 | bw ((s, p), da) = (fmap (da *) s -- da/dp 212 | ,fmap (da *) p) -- da/ds 213 | 214 | -- Add bias to input 215 | biasL :: ExLens D D D D D D 216 | biasL = ExLens fw bw 217 | where 218 | fw :: (D, D) -> ((), D) 219 | fw (p, s) = ((), p + s) 220 | -- da/dp = 1, da/ds = 1 221 | bw :: ((), D) -> (D, D) 222 | bw (_, da) = (da, da) 223 | 224 | -- Convert both to BiLens 225 | -- p V D D -> p V V V 226 | linearT :: BiLens V V V V D D 227 | linearT = toTamb linearL 228 | -- p D D D -> p D D D 229 | biasT :: BiLens D D D D D D 230 | biasT = toTamb biasL 231 | 232 | -- Compose two BiLenses 233 | -- comp :: BiLens q q' s ds a da -> BiLens r r' v dv s ds -> BiLens (q, r) (q', r') v dv a da 234 | affineT :: BiLens (D, V) (D, V) V V D D 235 | affineT = comp biasT linearT 236 | 237 | -- Turn the composition back to ExLens 238 | -- fromTamb :: BiLens q q' s ds a da -> ExLens a da q q' s ds 239 | affine :: ExLens D D (D, V) (D, V) V V 240 | affine = fromTamb affineT 241 | 242 | testTamb :: IO () 243 | testTamb = do 244 | putStrLn "forward" 245 | print $ fwd affine ((0.1, [-1, 1]), [2, 30]) 246 | putStrLn "backward" 247 | -- (Para [1.3, -1.4] 0.1, [21, 33], 1) 248 | print $ bwd affine ((0.1, [1.3, -1.4]), [21, 33], 1) 249 | -------------------------------------------------------------------------------- /ExLens/ExLens.hs: -------------------------------------------------------------------------------- 1 | module ExLens where 2 | 3 | data PLens a da p dp s ds = 4 | PLens { fwd' :: (p, s) -> a 5 | , bwd' :: (p, s, da) -> (dp, ds) 6 | } 7 | 8 | -- Existential parametric lens 9 | 10 | data ExLens a da p dp s ds = 11 | forall m . ExLens ((p, s) -> (m, a)) 12 | ((m, da) -> (dp, ds)) 13 | 14 | -- For convenience, a lens with empty (unit) parameter 15 | data Lens s ds a da = 16 | forall m . 
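-- the existential residue m caches whatever the forward pass must hand to the backward pass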
Lens (s -> (m, a)) 17 | ((m, da) -> ds) 18 | 19 | -- Accessors 20 | 21 | fwd :: ExLens a da p dp s ds -> (p, s) -> a 22 | fwd (ExLens f g) (p, s) = snd $ f (p, s) 23 | 24 | bwd :: ExLens a da p dp s ds -> (p, s, da) -> (dp, ds) 25 | bwd (ExLens f g) (p, s, da) = g (fst (f (p, s)), da) 26 | 27 | fwd0 :: Lens s ds a da -> s -> a 28 | fwd0 (Lens f g) s = snd $ f s 29 | 30 | bwd0 :: Lens s ds a da -> (s, da) -> ds 31 | bwd0 (Lens f g) (s, da) = g (fst (f s), da) 32 | 33 | -- Serial composition 34 | 35 | compose :: 36 | ExLens a da p dp s ds -> ExLens b db q dq a da -> 37 | ExLens b db (p, q) (dp, dq) s ds 38 | compose (ExLens f1 g1) (ExLens f2 g2) = ExLens f3 g3 39 | where 40 | f3 ((p, q), s) = 41 | let (m, a) = f1 (p, s) 42 | (n, b) = f2 (q, a) 43 | in ((m, n), b) 44 | g3 ((m, n), db) = 45 | let (dq, da) = g2 (n, db) 46 | (dp, ds) = g1 (m, da) 47 | in ((dp, dq), ds) 48 | 49 | -- Convenient special cases 50 | 51 | composeR :: 52 | ExLens a da p dp s ds -> Lens a da b db -> 53 | ExLens b db p dp s ds 54 | composeR (ExLens f1 g1) (Lens f2 g2) = ExLens f3 g3 55 | where 56 | f3 (p, s) = 57 | let (m, a) = f1 (p, s) 58 | (n, b) = f2 a 59 | in ((m, n), b) 60 | g3 ((m, n), db) = 61 | let da = g2 (n, db) 62 | (dp, ds) = g1 (m, da) 63 | in (dp, ds) 64 | 65 | composeL :: 66 | Lens s ds a da -> ExLens b db q dq a da -> 67 | ExLens b db q dq s ds 68 | composeL (Lens f1 g1) (ExLens f2 g2) = ExLens f3 g3 69 | where 70 | f3 (q, s) = 71 | let (m, a) = f1 s 72 | (n, b) = f2 (q, a) 73 | in ((m, n), b) 74 | g3 ((m, n), db) = 75 | let (dq, da) = g2 (n, db) 76 | ds = g1 (m, da) 77 | in (dq, ds) 78 | 79 | -- Parallel composition 80 | 81 | -- A pair of lenses in parallel 82 | prodLens :: 83 | ExLens a da p dp s ds -> ExLens a' da' p' dp' s' ds' -> 84 | ExLens (a, a') (da, da') (p, p') (dp, dp') (s, s') (ds, ds') 85 | prodLens (ExLens f1 g1) (ExLens f2 g2) = ExLens f3 g3 86 | where 87 | f3 ((p, p'), (s, s')) = ((m, m'), (a, a')) 88 | where (m, a) = f1 (p, s) 89 | (m', a') = f2 (p', s') 90 | g3 ((m, m'), (da, da')) = ((dp, dp'), (ds, ds')) 91 | where 92 | (dp, ds) = g1 (m, da) 93 | (dp', ds') = g2 (m', da') 94 | 95 | -- Vector lens, combines n identical lenses in parallel 96 | vecLens :: 97 | Int -> ExLens a da p dp s ds -> ExLens [a] [da] [p] [dp] [s] [ds] 98 | vecLens 0 _ = ExLens (const ([], [])) (const ([], [])) 99 | vecLens n lns = consLens lns (vecLens (n - 1) lns) 100 | 101 | branch :: Monoid s => Int -> Lens s s [s] [s] 102 | branch n = Lens (\s -> ((), replicate n s)) 103 | (\(_, ss) -> mconcat ss) -- pointwise <+> 104 | 105 | -- A cons function combines a lens with a (parallel) list of lenses 106 | consLens :: 107 | ExLens a da p dp s ds -> ExLens [a] [da] [p] [dp] [s] [ds] -> 108 | ExLens [a] [da] [p] [dp] [s] [ds] 109 | consLens (ExLens f g) (ExLens fs gs) = ExLens fv gv 110 | where 111 | fv (p : ps, s : ss) = ((m, ms), a : as) 112 | where (m, a) = f (p, s) 113 | (ms, as) = fs (ps, ss) 114 | gv ((m, ms), da : das) = (dp : dps, ds : dss) 115 | where (dp, ds) = g (m, da) 116 | (dps, dss) = gs (ms, das) 117 | 118 | -- Helper functions for wiring networks 119 | 120 | -- xs = [1, 2, 3, 4, 5, 6] 121 | -- vw = [[1, 2, 3], [4, 5, 6]] m = 3 n = 2 122 | rechunk :: Int -> Int -> [a] -> [[a]] 123 | rechunk m 0 xs = [] 124 | rechunk m n xs = take m xs : rechunk m (n - 1) (drop m xs) 125 | 126 | -- Lens (Vect n (Vect m s)) (Vect (n * m) s) 127 | -- Here the existential parameter m is just (Int, Int) 128 | flatten :: Lens [[s]] [[ds]] [s] [ds] 129 | flatten = Lens f g 130 | where 131 | f sss = ((length (head sss), length 
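-- the residue (m, n) records row width and row count, so the backward pass below can rechunk the flat gradient: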
sss), concat sss) 132 | -- (Vect n (Vect m s), Vect (n * m) s) -> (Vect n (Vect m s)) 133 | g ((m, n), ds) = rechunk m n ds 134 | 135 | -- This is for training neural networks. Instead of running batches 136 | -- of training data in series, we can do it in parallel and accumulate 137 | -- the parameters for the next batch. 138 | 139 | -- A batch of lenses in parallel, sharing the same parameters 140 | -- Back propagation combines the parameters 141 | batchN :: (Monoid dp) => 142 | Int -> ExLens a da p dp s ds -> ExLens [a] [da] p dp [s] [ds] 143 | batchN n (ExLens f g) = ExLens fv gv 144 | where 145 | fv (p, ss) = unzip $ fmap f $ zip (replicate n p) ss 146 | gv (ms, das) = (mconcat dps, dss) 147 | where -- g :: (m, da) -> (dp, ds) 148 | (dps, dss) = unzip $ fmap g $ zip ms das 149 | 150 | -- Rearrange vectors of parameters 151 | 152 | consParas :: ExLens a da (p, [p]) (p, [p]) s ds -> ExLens a da [p] [p] s ds 153 | consParas (ExLens f g) = ExLens f' g' 154 | where 155 | f' (p : ps, s) = f ((p, ps), s) 156 | g' (m, da) = 157 | let ((dp, dps), ds) = g (m, da) 158 | in (dp : dps, ds) 159 | 160 | singleParas :: ExLens a da p dp s ds -> ExLens a da [p] [dp] s ds 161 | singleParas (ExLens f g) = ExLens f' g' 162 | where 163 | f' ([p], s) = f (p, s) 164 | g' (m, da) = 165 | let (dp, ds) = g (m, da) 166 | in ([dp], ds) 167 | 168 | 169 | 170 | test1 :: IO () 171 | test1 = do 172 | let sss = [[1, 2, 3], [4, 5, 6]] 173 | let ss = [10, 11, 12, 13, 14, 15, 16] 174 | putStrLn "flatten forward" 175 | print $ fwd0 flatten sss 176 | putStrLn "flatten backward" 177 | print $ bwd0 flatten (sss, ss) 178 | putStrLn "" -------------------------------------------------------------------------------- /ExLens/NNet.hs: -------------------------------------------------------------------------------- 1 | module NNet where 2 | import ExLens 3 | 4 | -- Use existential lenses to create more complex neural networks 5 | 6 | type D = Double 7 | -- Ideally, a counted vector 8 | type V = [D] 9 | 10 | -- Parameters for a single neuron 11 | data Para = Para 12 | { weight :: V 13 | , bias :: D 14 | } deriving Show 15 | 16 | -- Parameters for a layer of neurons 17 | type ParaBlock = [Para] 18 | 19 | -- Additive monoid 20 | 21 | instance Semigroup D where 22 | (<>) = (+) 23 | 24 | instance Monoid D where 25 | mempty = 0.0 26 | 27 | instance Semigroup Para where 28 | (<>) :: Para -> Para -> Para 29 | p1 <> p2 = Para (zipWith (+) (weight p1) (weight p2)) (bias p1 + bias p2) 30 | 31 | instance Monoid Para where 32 | mempty :: Para 33 | mempty = Para (repeat 0.0) 0.0 34 | 35 | -- Parameters form a vector space, we need to scale them and add them 36 | 37 | class Monoid v => VSpace v where 38 | scale :: D -> v -> v 39 | 40 | instance VSpace D where 41 | scale :: D -> D -> D 42 | scale a x = a * x 43 | 44 | instance VSpace a => VSpace [a] where 45 | scale :: VSpace a => D -> [a] -> [a] 46 | scale a = fmap (scale a) 47 | 48 | instance VSpace Para where 49 | scale :: D -> Para -> Para 50 | scale a p = Para (scale a (weight p)) (scale a (bias p)) 51 | 52 | -- A simple linear lens: a scalar product of parameters and inputs 53 | 54 | linearL :: ExLens D D V V V V 55 | linearL = ExLens fw bw 56 | where 57 | fw :: (V, V) -> ((V, V), D) 58 | -- a = Sum p * s 59 | fw (p, s) = ((s, p), sum $ zipWith (*) p s) 60 | -- da/dp = s, da/ds = p 61 | bw :: ((V, V), D) -> (V, V) 62 | bw ((s, p), da) = (scale da s -- da/dp 63 | ,scale da p) -- da/ds 64 | 65 | biasL :: ExLens D D D D D D 66 | biasL = ExLens fw bw 67 | where 68 | fw :: (D, D) -> ((), D) 69 | 
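-- unit residue: the backward pass of bias needs nothing from the forward pass,
-- e.g. fw (0.1, 2.0) == ((), 2.1) and bw ((), 1) == (1, 1)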
fw (p, s) = ((), p + s) 70 | -- da/dp = 1, da/ds = 1 71 | bw :: ((), D) -> (D, D) 72 | bw (_, da) = (da, da) 73 | 74 | -- Non-linear activation lens using tanh 75 | activ :: Lens D D D D 76 | activ = Lens fw bw 77 | where 78 | -- a = tanh s 79 | fw s = (s, tanh s) 80 | -- da/ds = 1 - (tanh s)^2 81 | bw (s, da) = da * (1 - (tanh s)^2) -- da * da/ds 82 | 83 | -- Neuron as a composite of linear, bias, and activation 84 | neuron0 :: ExLens D D (V, D) (V, D) V V 85 | neuron0 = composeR (compose linearL biasL) activ 86 | 87 | -- Affine parametric lens 88 | -- (really a composition of linear and bias, but they are always used in combination) 89 | 90 | affine :: Int -> ExLens D D Para Para V V 91 | affine m = ExLens fw bw 92 | where 93 | fw :: (Para, V) -> ((V, V), D) 94 | -- a = b + w * s 95 | fw (p, s) = ((w, s), foldl (+) (bias p) (zipWith (*) w s)) 96 | where w = weight p 97 | bw :: ((V, V), D) -> (Para, V) 98 | bw ((w, s), da) = ( Para (scale da s) da -- (da/dw, da/db) 99 | , scale da w) -- da/ds 100 | 101 | -- Neuron with m inputs and one output with tanh activation 102 | neuron :: Int -> ExLens D D Para Para V V 103 | neuron m = composeR (affine m) activ 104 | 105 | -- Initialize parameters for an affine lens from an infinite stream 106 | initPara :: Int -> [D] -> (Para, [D]) 107 | initPara m stm = (Para w b, stm'') 108 | where 109 | (w, stm') = splitAt m stm 110 | ([b], stm'') = splitAt 1 stm' 111 | 112 | -- A layer of nOut identical neurons, each with mIn inputs 113 | layer :: Int -> Int -> ExLens V V [Para] [Para] V V 114 | layer mIn nOut = composeL (branch nOut) (vecLens nOut (neuron mIn)) 115 | 116 | -- Initialize a block of nOut parameters, each for a neuron with mIn inputs 117 | initParaBlock :: Int -> Int -> [D] -> ([Para], [D]) 118 | initParaBlock mIn nOut stm = unfoldl nOut (initPara mIn) stm 119 | 120 | 121 | -- The loss lens, compares results with ground truth 122 | loss :: V -> Lens V V D D 123 | loss gTruth = Lens fw bw 124 | where 125 | fw :: V -> (V, D) 126 | fw s = (s, delta s gTruth) 127 | bw :: (V, D) -> V 128 | -- da/ds = s - g 129 | bw (s, da) = map (da *) (zipWith (-) s gTruth) 130 | -- 1/2 Sum (s - g)^2 131 | delta s g = 0.5 * sum (map (^2) (zipWith (-) s g)) 132 | 133 | -- Helper function 134 | 135 | unfoldl :: Int -> (s -> (a, s)) -> s -> ([a], s) 136 | unfoldl 0 f s = ([], s) 137 | unfoldl n f s = (x : xs, s'') 138 | where 139 | (x, s') = f s 140 | (xs, s'') = unfoldl (n-1) f s' 141 | 142 | test2 :: IO () 143 | test2 = do 144 | let s = [0, 0.1 .. 
] 145 | let (p, s') = initPara 2 s 146 | print p 147 | print $ fst $ unfoldl 3 (initPara 2) s' 148 | 149 | test3 :: IO () 150 | test3 = do 151 | putStrLn "Compare different implementations of neurons" 152 | let s = [1, 0.5, 0, 0] 153 | let (p, s') = initPara 3 s 154 | let nrn = neuron 3 155 | let ins = [-1, 0, 1] 156 | putStrLn "Forward neurons" 157 | print $ fwd nrn (p, ins) 158 | putStrLn "" 159 | print $ fwd neuron0 ((weight p, bias p), ins) 160 | putStrLn "Backward neurons" 161 | print $ bwd nrn (p, ins, 1) 162 | putStrLn "" 163 | print $ bwd neuron0 ((weight p, bias p), ins, 1) 164 | 165 | test4 :: IO () 166 | test4 = do 167 | putStrLn "Test backward passes" 168 | let p = Para [0.5, -0.5] 0.5 169 | let in1 = [1, 0] 170 | let in2 = [0, 1] 171 | let nrn = neuron 2 172 | print $ fwd nrn (p, in1) 173 | let (dp, ds) = bwd nrn (p, in1, 1) 174 | print dp 175 | print ds 176 | print $ fwd nrn (p, in2) 177 | let (dp, ds) = bwd nrn (p, in2, 1) 178 | print dp 179 | print ds 180 | 181 | test5 :: IO () 182 | test5 = do 183 | putStrLn "forward" 184 | print $ fwd (affine 2) (Para [-1, 1] 0.1, [2, 30]) 185 | putStrLn $ show $ (-2) + 30 + 0.1 186 | putStrLn "backward" 187 | print $ bwd (affine 2) (Para [1.3, -1.4] 0.1, [21, 33], 1) 188 | -- y = q1 * x1 + q2 * x2 + d 189 | -- dy/dq = (x1, x2), dy/dd = 1, dy/dx = (q1, q2) 190 | putStrLn $ show $ (Para [21, 33] 1, [1.3, -1.4]) 191 | -------------------------------------------------------------------------------- /ExLens/Perceptron.hs: -------------------------------------------------------------------------------- 1 | module Perceptron where 2 | import ExLens 3 | import NNet 4 | 5 | -- Multi-layer perceptron 6 | -- The first layer contains neurons with mIn inputs each 7 | -- The list [Int] specifies the number of neurons in each layer (starting with the first layer) 8 | -- Each neuron has one output 9 | makeMlp :: Int -> [Int] -> ExLens V V [[Para]] [[Para]] V V 10 | makeMlp mIn [nOut] = singleParas $ layer mIn nOut 11 | makeMlp mIn (n1 : n2 : ns) = consParas $ compose ly mlp 12 | where ly = layer mIn n1 13 | mlp = makeMlp n1 (n2 : ns) 14 | 15 | -- Initialize parameters for an MLP 16 | initParaMlp :: Int -> [Int] -> [D] -> ([[Para]], [D]) 17 | initParaMlp mIn [nOut] stm = 18 | let (pb, stm') = initParaBlock mIn nOut stm 19 | in ([pb], stm') 20 | initParaMlp mIn (n1 : n2 : ns) stm = 21 | let (pb, stm') = initParaBlock mIn n1 stm 22 | (pbs, stm'') = initParaMlp n1 (n2 : ns) stm' 23 | in (pb : pbs, stm'') 24 | -------------------------------------------------------------------------------- /Haskell/NNet.hs: -------------------------------------------------------------------------------- 1 | {-# language ScopedTypeVariables #-} 2 | module NNet where 3 | import PLens 4 | 5 | type D = Double 6 | -- Ideally, a counted vector 7 | type V = [D] 8 | 9 | data Para = Para 10 | { weight :: V 11 | , bias :: D 12 | } deriving Show 13 | 14 | type ParaBlock = [Para] 15 | 16 | -- Additive monoid 17 | 18 | instance Semigroup D where 19 | (<>) = (+) 20 | 21 | instance Monoid D where 22 | mempty = 0.0 23 | 24 | instance Semigroup Para where 25 | (<>) :: Para -> Para -> Para 26 | p1 <> p2 = Para (zipWith (+) (weight p1) (weight p2)) (bias p1 + bias p2) 27 | 28 | instance Monoid Para where 29 | mempty :: Para 30 | mempty = Para (repeat 0.0) 0.0 31 | 32 | -- Vector space 33 | class Monoid v => VSpace v where 34 | scale :: D -> v -> v 35 | 36 | instance VSpace D where 37 | scale :: D -> D -> D 38 | scale a x = a * x 39 | 40 | instance VSpace a => VSpace [a] where 41 | scale :: 
VSpace a => D -> [a] -> [a] 42 | scale s = fmap (scale s) 43 | 44 | instance VSpace Para where 45 | scale :: D -> Para -> Para 46 | scale a p = Para (scale a (weight p)) (scale a (bias p)) 47 | 48 | linearL :: PLens V V V V D D 49 | linearL = PLens fw bw 50 | where 51 | fw :: (V, V) -> D 52 | -- a = Sum p * s 53 | fw (p, s) = sum $ zipWith (*) p s 54 | -- da/dp = s, da/ds = p 55 | bw :: (V, V, D) -> (V, V) 56 | bw (p, s, da) = (scale da s -- da/dp 57 | ,scale da p) -- da/ds 58 | 59 | biasL :: PLens D D D D D D 60 | biasL = PLens fw bw 61 | where 62 | fw :: (D, D) -> D 63 | fw (p, s) = p + s 64 | -- da/dp = 1, da/ds = 1 65 | bw :: (D, D, D) -> (D, D) 66 | bw (p, s, da) = (da, da) 67 | 68 | -- Non-linear activation lens using tanh 69 | activ :: Lens D D D D 70 | activ = Lens fw bw 71 | where 72 | -- a = tanh s 73 | fw = tanh 74 | -- da/ds = 1 - (tanh s)^2 75 | bw = (\(s, a) -> a * (1 - (tanh s)^2)) -- a * da/ds 76 | 77 | -- Neuron as a composite of linear, bias, and activation 78 | neuron0 :: PLens (V, D) (V, D) V V D D 79 | neuron0 = composeR (compose linearL biasL) activ 80 | 81 | -- Affine parametric lens 82 | -- (really a composition of linear and bias, but they are always used in combination) 83 | 84 | affine :: Int -> PLens Para Para V V D D 85 | affine m = PLens fw bw 86 | where 87 | fw :: (Para, V) -> D 88 | -- a = b + w * s 89 | fw (p, s) = foldl (+) (bias p) (zipWith (*) (weight p) s) 90 | bw :: (Para, V, D) -> (Para, V) 91 | bw (p, s, da) = ( Para (map (da*) s) da -- (da/dw, da/db) 92 | , map (da*) (weight p)) -- da/ds 93 | 94 | -- Neuron with m inputs and one output with tanh activation 95 | neuron :: Int -> PLens Para Para V V D D 96 | neuron m = composeR (affine m) activ 97 | 98 | -- Initialize parameters for an affine lens from an infinite stream 99 | initPara :: Int -> [D] -> (Para, [D]) 100 | initPara m stm = (Para w b, stm'') 101 | where 102 | (w, stm') = splitAt m stm 103 | ([b], stm'') = splitAt 1 stm' 104 | 105 | 106 | layer :: Int -> Int -> PLens [Para] [Para] V V V V 107 | layer mIn nOut = composeL (branch nOut) (vecLens nOut (neuron mIn)) 108 | 109 | -- Initialize a block of nOut parameters, each for a neuron with mIn inputs 110 | initParaBlock :: Int -> Int -> [D] -> ([Para], [D]) 111 | initParaBlock mIn nOut stm = unfoldl nOut (initPara mIn) stm 112 | 113 | 114 | 115 | -- The loss lens, compares results with ground truth 116 | loss :: V -> Lens V V D D 117 | loss gTruth = Lens fw bw 118 | where 119 | fw :: V -> D 120 | fw s = delta s gTruth 121 | bw :: (V, D) -> V 122 | -- da/ds = s - g 123 | bw (s, da) = map (da *) (zipWith (-) s gTruth) 124 | -- 1/2 Sum (s - g)^2 125 | delta s g = 0.5 * sum (map (^2) (zipWith (-) s g)) 126 | 127 | -- Helper function 128 | 129 | unfoldl :: Int -> (s -> (a, s)) -> s -> ([a], s) 130 | unfoldl 0 f s = ([], s) 131 | unfoldl n f s = (x : xs, s'') 132 | where 133 | (x, s') = f s 134 | (xs, s'') = unfoldl (n-1) f s' 135 | 136 | test2 :: IO () 137 | test2 = do 138 | let s = [0, 0.1 .. 
] 139 | let (p, s') = initPara 2 s 140 | print p 141 | print $ fst $ unfoldl 3 (initPara 2) s' 142 | 143 | test3 :: IO () 144 | test3 = do 145 | putStrLn "Compare different implementation of neurons" 146 | let s = [1, 0.5, 0, 0] 147 | let (p, s') = initPara 3 s 148 | let nrn = neuron 3 149 | let ins = [-1, 0, 1] 150 | putStrLn "Forward neurons" 151 | print $ fwd nrn (p, ins) 152 | putStrLn "" 153 | print $ fwd neuron0 ((weight p, bias p), ins) 154 | putStrLn "Backward neurons" 155 | print $ bwd nrn (p, ins, 1) 156 | putStrLn "" 157 | print $ bwd neuron0 ((weight p, bias p), ins, 1) 158 | 159 | test4 :: IO () 160 | test4 = do 161 | putStrLn "Test backward passes" 162 | let p = Para [0.5, -0.5] 0.5 163 | let in1 = [1, 0] 164 | let in2 = [0, 1] 165 | let nrn = neuron 2 166 | print $ fwd nrn (p, in1) 167 | let (dp, ds) = bwd nrn (p, in1, 1) 168 | print dp 169 | print ds 170 | print $ fwd nrn (p, in2) 171 | let (dp, ds) = bwd nrn (p, in2, 1) 172 | print dp 173 | print ds 174 | -------------------------------------------------------------------------------- /Haskell/PLens.hs: -------------------------------------------------------------------------------- 1 | {-# language ScopedTypeVariables #-} 2 | module PLens where 3 | 4 | data PLens p dp s ds a da = PLens 5 | { fwd :: (p, s) -> a 6 | , bwd :: (p, s, da) -> (dp, ds) 7 | } 8 | 9 | data Lens s ds a da = Lens 10 | { fwd0 :: s -> a 11 | , bwd0 :: (s, da) -> ds 12 | } 13 | 14 | compose :: forall p dp q dq s ds a da c dc. 15 | PLens p dp s ds a da -> PLens q dq a da c dc -> 16 | PLens (p, q) (dp, dq) s ds c dc 17 | compose pl ql = PLens fw bw 18 | where 19 | fw :: ((p, q), s) -> c 20 | fw ((p, q), s) = fwd ql $ (q, fwd pl (p, s)) 21 | bw :: ((p, q), s, dc) -> ((dp, dq), ds) 22 | bw ((p, q), s, dc) = ((dp, dq), ds) 23 | where 24 | a = fwd pl (p, s) 25 | (dq, da) = bwd ql (q, a, dc) 26 | (dp, ds) = bwd pl (p, s, da) 27 | 28 | composeR :: forall p dp s ds a da c dc. 29 | PLens p dp s ds a da -> Lens a da c dc -> 30 | PLens p dp s ds c dc 31 | composeR pl l0 = PLens fw bw 32 | where 33 | fw :: (p, s) -> c 34 | fw (p, s) = fwd0 l0 $ fwd pl (p, s) 35 | bw :: (p, s, dc) -> (dp, ds) 36 | bw (p, s, dc) = (dp, ds) 37 | where 38 | a = fwd pl (p, s) 39 | da = bwd0 l0 (a, dc) 40 | (dp, ds) = bwd pl (p, s, da) 41 | 42 | composeL :: forall s ds a da q dq c dc. 43 | Lens s ds a da -> PLens q dq a da c dc -> 44 | PLens q dq s ds c dc 45 | composeL l0 ql = PLens fw bw 46 | where 47 | fw :: (q, s) -> c 48 | fw (q, s) = fwd ql $ (q, fwd0 l0 s) 49 | bw :: (q, s, dc) -> (dq, ds) 50 | bw (q, s, dc) = (dq, ds) 51 | where 52 | a = fwd0 l0 s 53 | (dq, da) = bwd ql (q, a, dc) 54 | ds = bwd0 l0 (s, da) 55 | 56 | prodLens :: forall p dp s ds a da p' dp' s' ds' a' da'. 57 | PLens p dp s ds a da -> PLens p' dp' s' ds' a' da' -> 58 | PLens (p, p') (dp, dp') (s, s') (ds, ds') (a, a') (da, da') 59 | prodLens pl pl' = PLens fwdProd bwdProd 60 | where 61 | fwdProd :: ((p, p'), (s, s')) -> (a, a') 62 | fwdProd ((p, p'), (s, s')) = (a, a') 63 | where a = fwd pl (p, s) 64 | a' = fwd pl' (p', s') 65 | bwdProd :: ((p, p'), (s, s'), (da, da')) -> ((dp, dp'), (ds, ds')) 66 | bwdProd ((p, p'), (s, s'), (da, da')) = ((dp, dp'), (ds, ds')) 67 | where 68 | (dp, ds) = bwd pl (p, s, da) 69 | (dp', ds') = bwd pl' (p', s', da') 70 | 71 | -- Unit wrt product 72 | unitLens :: PLens () () () () () () 73 | unitLens = PLens (const ()) (const ((), ())) 74 | 75 | -- n lenses in parallel 76 | vecLens :: forall p dp s ds a da. 
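-- n parallel copies of one lens, peeling a copy off at a time with prodLens: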
77 | Int -> PLens p dp s ds a da -> PLens [p] [dp] [s] [ds] [a] [da] 78 | vecLens 0 pl = PLens (const []) (const ([], [])) 79 | vecLens n pl = PLens fw bw 80 | where 81 | plrec = vecLens (n - 1) pl 82 | prod = prodLens pl plrec 83 | fw :: ([p], [s]) -> [a] 84 | fw (p : ps, s : ss) = a : as 85 | where 86 | (a, as) = fwd prod ((p, ps), (s, ss)) 87 | bw :: ([p], [s], [da]) -> ([dp], [ds]) 88 | bw (p : ps, s : ss, da : das) = (dp : dps, ds : dss) 89 | where 90 | ((dp, dps), (ds, dss)) = bwd prod ((p, ps), (s, ss), (da, das)) 91 | 92 | branch :: Monoid s => Int -> Lens s s [s] [s] 93 | branch n = Lens (replicate n) (\(_, ss) -> mconcat ss) -- pointwise <+> 94 | 95 | -- xs = [1, 2, 3, 4, 5, 6] 96 | -- vw = [[1, 2, 3], [4, 5, 6]] m = 3 n = 2 97 | rechunk :: Int -> Int -> [a] -> [[a]] 98 | rechunk m 0 xs = [] 99 | rechunk m n xs = take m xs : rechunk m (n - 1) (drop m xs) 100 | 101 | -- Lens (Vect n (Vect m s)) (Vect (n * m) s) 102 | flatten :: Lens [[s]] [[ds]] [s] [ds] 103 | flatten = Lens fw bw 104 | where 105 | fw = concat 106 | -- (Vect n (Vect m s), Vect (n * m) s) -> (Vect n (Vect m s)) 107 | bw (sss, ds) = rechunk (length (head sss)) (length sss) ds 108 | 109 | batchN :: forall p dp s ds a da. 110 | Int -> ([dp] -> dp) -> PLens p dp s ds a da -> PLens p dp [s] [ds] [a] [da] 111 | batchN n fold lns = PLens fw bw 112 | where 113 | fw :: (p, [s]) -> [a] 114 | fw (px, ss) = fmap (fwd lns) $ zip (replicate n px) ss 115 | bw :: (p, [s], [da]) -> (dp, [ds]) 116 | bw (px, ss, das) = (fold dps', dss') 117 | where 118 | (dps', dss') = unzip $ fmap (bwd lns) $ zip3 (replicate n px) ss das 119 | 120 | toSingleton :: Lens s ds [s] [ds] 121 | toSingleton = Lens (\s -> [s]) (\(_, [ds]) -> ds) 122 | 123 | test1 :: IO () 124 | test1 = do 125 | let sss = [[1, 2, 3], [4, 5, 6]] 126 | let ss = [10, 11, 12, 13, 14, 15, 16] 127 | putStrLn "flatten forward" 128 | print $ fwd0 flatten sss 129 | putStrLn "flatten backward" 130 | print $ bwd0 flatten (sss, ss) 131 | putStrLn "" -------------------------------------------------------------------------------- /Haskell/Perceptron.hs: -------------------------------------------------------------------------------- 1 | {-# language ScopedTypeVariables #-} 2 | module Perceptron where 3 | import PLens 4 | import NNet 5 | 6 | makeMlp :: Int -> [Int] -> PLens [[Para]] [[Para]] V V V V 7 | makeMlp mIn [nOut] = PLens fw bw 8 | where 9 | ly :: PLens [Para] [Para] V V V V 10 | ly = layer mIn nOut 11 | fw ([ps], s) = fwd ly (ps, s) 12 | bw ([ps], s, da) = 13 | let (dp, ds) = bwd ly (ps, s, da) 14 | in ([dp], ds) 15 | makeMlp mIn (n1 : n2 : ns) = PLens fw bw 16 | where 17 | ly :: PLens [Para] [Para] V V V V 18 | ly = layer mIn n1 19 | mlp = makeMlp n1 (n2 : ns) 20 | lns :: PLens ([Para], [[Para]]) ([Para], [[Para]]) V V V V 21 | lns = compose ly mlp 22 | fw (p1 : ps, s) = fwd lns ((p1, ps), s) 23 | bw (p1 : ps, s, da) = 24 | let ((p1', ps'), ds) = bwd lns ((p1, ps), s, da) 25 | in (p1' : ps', ds) 26 | 27 | -- Initialize parameters for an MLP 28 | initParaMlp :: Int -> [Int] -> [D] -> ([[Para]], [D]) 29 | initParaMlp mIn [nOut] stm = 30 | let (pb, stm') = initParaBlock mIn nOut stm 31 | in ([pb], stm') 32 | initParaMlp mIn (n1 : n2 : ns) stm = 33 | let (pb, stm') = initParaBlock mIn n1 stm 34 | (pbs, stm'') = initParaMlp n1 (n2 : ns) stm' 35 | in (pb : pbs, stm'') 36 | -------------------------------------------------------------------------------- /Idris/HVect.idr: 
-------------------------------------------------------------------------------- 1 | module HVect 2 | import Data.Vect 3 | --import Data.Vect.Quantifiers 4 | 5 | -- Heterogeneous vector 6 | public export 7 | data HVect : Vect n Type -> Type where 8 | Nil : HVect Nil 9 | (::) : h -> HVect t -> HVect (h :: t) 10 | 11 | export 12 | Show (HVect []) where 13 | show Nil = "\n" 14 | 15 | export 16 | (Show t, Show (HVect ts)) => Show (HVect (t :: ts)) where 17 | show (x :: xs) = show x ++ " :: " ++ show xs 18 | 19 | export 20 | Semigroup (HVect []) where 21 | [] <+> [] = [] 22 | 23 | export 24 | (Semigroup t, Semigroup (HVect ts)) => 25 | Semigroup (HVect (t :: ts)) where 26 | (a :: as) <+> (b :: bs) = (a <+> b) :: (as <+> bs) 27 | 28 | export 29 | Monoid (HVect []) where 30 | neutral = [] 31 | 32 | export 33 | (Monoid t, Monoid (HVect ts)) => Monoid (HVect (t :: ts)) where 34 | neutral = neutral :: neutral 35 | 36 | 37 | public export 38 | interface (Semigroup v, Monoid v) => VSpace v where 39 | scale : Double -> v -> v 40 | 41 | export 42 | VSpace (HVect []) where 43 | scale a Nil = Nil 44 | 45 | export 46 | (VSpace t, VSpace (HVect ts)) => VSpace (HVect (t :: ts)) where 47 | scale a (v :: vs) = scale a v :: scale a vs 48 | 49 | 50 | -- Replicate a vector of types 51 | -- map (Vect k) ts 52 | public export 53 | 0 ReplTypes : {l : Nat} -> (k : Nat) -> (ts : Vect l Type) -> Vect l Type 54 | ReplTypes k [] = [] 55 | ReplTypes k (t' :: ts') = Vect k t' :: ReplTypes k ts' 56 | 57 | -- Concatenate vectors of heterogeneous monoid types 58 | export 59 | concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} -> 60 | HVect (ReplTypes k ts) -> HVect ts 61 | concatH {l = 0} {ts = []} Nil = Nil 62 | concatH {ts = t' :: ts'} {isMono = (_ :: pfs)} (v :: vs) = concat v :: concatH {isMono = pfs} vs 63 | 64 | export 65 | emptyVTypes : {l : Nat} -> (ts : Vect l Type) -> HVect (ReplTypes 0 ts) 66 | emptyVTypes [] = Nil 67 | emptyVTypes (t' :: ts') = [] :: emptyVTypes ts' 68 | 69 | -- Generalization of zipWith (::) 70 | export 71 | zipCons : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 72 | HVect ts -> HVect (ReplTypes k ts) -> HVect (ReplTypes (S k) ts) 73 | zipCons [] [] = [] 74 | zipCons (t' :: ts') (vs :: vss) = (t' :: vs) :: zipCons ts' vss 75 | 76 | -- Transpose a vector whose entries are heterogeneous vectors 77 | export 78 | transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 79 | Vect k (HVect ts) -> HVect (ReplTypes k ts) 80 | transposeH {k=0} {ts} [] = emptyVTypes ts 81 | transposeH (h :: hs) = zipCons h (transposeH hs) 82 | 83 | export 84 | unfoldl : (n : Nat) -> (s -> (a, s)) -> s -> (Vect n a, s) 85 | unfoldl Z f s = (Nil, s) 86 | unfoldl (S k) f s = 87 | let (x, s') = f s 88 | (xs, s'') = unfoldl k f s' 89 | in (x :: xs, s'') 90 | 91 | export 92 | takes : (n : Nat) -> Stream a -> (Vect n a, Stream a) 93 | takes Z s = (Nil, s) 94 | takes (S k) (x :: xs) = 95 | let (v, s') = takes k xs 96 | in (x :: v, s') 97 | 98 | export 99 | take1 : Stream a -> (a, Stream a) 100 | take1 (x :: xs) = (x, xs) -------------------------------------------------------------------------------- /Idris/Main.idr: -------------------------------------------------------------------------------- 1 | module Main 2 | import Data.Bits 3 | import Data.Vect 4 | import HVect 5 | import PLens 6 | import NNet 7 | import Perceptron 8 | 9 | -- Simple random number generator 10 | 11 | random : Int32 -> Stream Int32 12 | random seed = let seed' = 1664525 * seed + 1013904333 13 | in (seed' 
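-- linear congruential step; the shift discards the low-order bits, whose periods are shortest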
`shiftR` 2) :: random seed' 14 | 15 | -- Stream of pseudo-random doubles [-1, 1] 16 | rands : Stream Double 17 | rands = map normalize (random 42) 18 | where 19 | normalize : Int32 -> Double 20 | normalize n = (fromInteger (cast n)) / fromInteger(2147483647) 21 | 22 | 23 | run : {l : Nat} -> (mIn : Nat) -> (ns : Vect (S l) Nat) -> Vect mIn Double -> Vect (Vect.last ns) Double 24 | run mIn ns v = 25 | let mlp : PLens (HVect (ParaChain mIn ns)) (V mIn) (V (last ns)) := makeMLP mIn ns 26 | paras : HVect (ParaChain mIn ns) := fst (initParaChain mIn ns rands) 27 | in mlp.fwd (paras, v) 28 | 29 | 30 | 31 | main : IO () 32 | main = do 33 | let ns = [4,4,1] 34 | let inpt = [1, 2, 3] 35 | printLn $ run {l = 2} 3 ns inpt 36 | -------------------------------------------------------------------------------- /Idris/NNet.idr: -------------------------------------------------------------------------------- 1 | module NNet 2 | import Data.Vect 3 | import HVect 4 | import PLens 5 | 6 | -- Vector of Double 7 | public export 8 | 0 V : (n : Nat) -> Type 9 | V n = Vect n Double 10 | 11 | -- Parameters for an affine lens of 12 | -- mIn inputs, one output 13 | public export 14 | record Para (mIn : Nat) where 15 | constructor MkPara 16 | weight : Vect mIn Double 17 | bias : Double 18 | 19 | -- ParaBlock mIn nOut, vector of nOut parameters, each mIn wide 20 | -- (nOut * mIn inputs, nOut outputs) 21 | public export 22 | ParaBlock : (mIn : Nat) -> (nOut : Nat) -> Type 23 | ParaBlock mIn nOut = Vect nOut (Para mIn) 24 | 25 | ------------- 26 | -- Interfaces 27 | ------------- 28 | export 29 | {mIn : Nat} -> Show (Para mIn) where 30 | show pa = "weight: " ++ show (weight pa) ++ " bias: " ++ show (bias pa) ++ "\n" 31 | 32 | -- Semigroup 33 | 34 | Semigroup Double where 35 | x <+> y = x + y 36 | 37 | export 38 | Semigroup (Para m) where 39 | (MkPara w b) <+> (MkPara w' b') = MkPara (zipWith (+) w w') (b + b') 40 | 41 | -- Monoid 42 | Monoid Double where 43 | neutral = 0.0 44 | 45 | {m : Nat} -> Monoid (Para m) where 46 | neutral = MkPara (replicate m 0.0) 0.0 47 | 48 | -- Vector Space 49 | export 50 | {m : Nat} -> VSpace (Para m) where 51 | scale a (MkPara w b) = MkPara (map (a *) w) (a * b) 52 | 53 | export 54 | {mIn : Nat} -> {nOut : Nat} -> VSpace (ParaBlock mIn nOut) where 55 | scale a v = map (scale a) v 56 | 57 | ------------------------------------- 58 | ------- Vector parametric lenses ---- 59 | ------------------------------------- 60 | 61 | -- activation lens using tanh (no parameters) 62 | 63 | activ : Lens Double Double 64 | activ = MkLens (\s => tanh s) 65 | (\(s, a) => a * (1 - (tanh s)*(tanh s))) -- a * da/ds 66 | 67 | -- Affine parametric lens 68 | -- (really a composition of linear and bias, but they are always used in combination) 69 | 70 | affine : (mIn : Nat) -> PLens (Para mIn) (V mIn) Double 71 | affine nOut = MkPLens fwd' bwd' 72 | where 73 | fwd' : (Para mIn, V mIn) -> Double 74 | fwd' (p, s) = foldl (+) (bias p) (zipWith (*) (weight p) s) -- a = b + w * s 75 | bwd' : (Para mIn, V mIn, Double) -> (Para mIn, V mIn) 76 | bwd' (p, s, a) = ( MkPara (map (a *) s) a -- (da/dw, da/db) 77 | , map (a *) (weight p)) -- da/ds 78 | 79 | -- Initialize parameters for an affine lens 80 | initPara : (mIn : Nat) -> Stream Double -> (Para mIn, Stream Double) 81 | initPara mIn s = 82 | let (v, s') = takes mIn s 83 | (x, s'') = take1 s' 84 | in (MkPara v x, s'') 85 | 86 | 87 | -- Neuron with mIn inputs and one output 88 | 89 | -- affine : PLens (Para mIn) (V mIn) Double 90 | -- activ : Lens Double Double 91 | -- 
composite : PLens (Para mIn) (V mIn) Double 92 | 93 | export 94 | neuron : (mIn : Nat) -> PLens (Para mIn) (V mIn) Double 95 | neuron mIn = composeR (affine mIn) activ 96 | 97 | 98 | -- A layer of neurons 99 | 100 | -- n neurons with m inputs each 101 | -- 1 2 .. n 102 | -- | | | 103 | -- m m m 104 | -- \ / \ / 105 | -- m 106 | -- ParaBlock m n = Vect n (Para m) 107 | -- neuron m : PLens (Vect n (Para 1)) (V m) (V 1) 108 | -- vecLens n (neuron m): PLens (Vect n (Vect n (Para 1))) (Vect n (V m)) (Vect n (V 1)) 109 | -- branch n : Lens (V m) (Vect n (V m)) 110 | -- s a 111 | -- composeL : Lens s a -> PLens p a b -> PLens p s b 112 | export 113 | layer : (nOut : Nat) -> (mIn : Nat) -> PLens (Vect nOut (Para mIn)) (V mIn) (V nOut) 114 | layer nOut mIn = composeL (branch nOut) (vecLens nOut (neuron mIn)) 115 | 116 | -- Initialize parameters for a layer of n neurons, each with m inputs 117 | -- ParaBlock mIn nOut, vector of nOut parameters, each mIn wide 118 | export 119 | initParaBlock : (mIn : Nat) -> (nOut : Nat) -> Stream Double -> 120 | (ParaBlock mIn nOut, Stream Double) 121 | initParaBlock mIn nOut s = unfoldl nOut (initPara mIn) s 122 | 123 | -- mean square error 0.5 * Sum (si - gi)^2 124 | -- derivative: d/dsi = (si - gi) 125 | delta : V n -> V n -> Double 126 | delta s g = 0.5 * (sum $ map (\x => x * x) (zipWith (-) s g)) 127 | 128 | -- Sum of squares loss lens 129 | export 130 | loss : V n -> Lens (V n) Double 131 | loss gtruth = MkLens (\s => delta s gtruth) 132 | (\(s, a) => backLoss gtruth s a) 133 | where 134 | backLoss : V n -> V n -> Double -> V n 135 | backLoss g s a = map ( a *) (zipWith (-) s g) 136 | -------------------------------------------------------------------------------- /Idris/PLens.idr: -------------------------------------------------------------------------------- 1 | module PLens 2 | import Data.Vect 3 | import HVect 4 | 5 | ---------- 6 | -- Parametric lens 7 | 8 | -- record PLens p p' s s' a a' 9 | -- fwd : (p, s) -> a 10 | -- lens1.bwd : (p, s, a') -> (p', s') 11 | 12 | public export 13 | record PLens p s a where 14 | constructor MkPLens 15 | fwd : (p, s) -> a 16 | bwd : (p, s, a) -> (p, s) 17 | 18 | -- Special case of parametric lens with p = () 19 | -- Simplifies composition 20 | 21 | public export 22 | record Lens s a where 23 | constructor MkLens 24 | fwd0 : s -> a 25 | bwd0 : (s, a) -> s 26 | 27 | -- Composition of parametric lenses 28 | export 29 | compose : PLens p s a -> PLens q a b -> 30 | PLens (p, q) s b 31 | -- lens1.fwd : (p, s) -> a 32 | -- lens1.bwd : (p, s, a) -> (p, s) 33 | -- lens2.fwd : (q, a) -> b 34 | -- lens2.bwd : (q, a, b) -> (q, a) 35 | compose lens1 lens2 = MkPLens fwd' bwd' 36 | where 37 | fwd' : ((p, q), s) -> b 38 | fwd' ((p, q), s) = lens2.fwd (q, lens1.fwd (p, s)) 39 | bwd' : ((p, q), s, b) -> ((p, q), s) 40 | bwd' ((p, q), s, b) = 41 | let (q', a') = lens2.bwd (q, lens1.fwd (p, s), b) 42 | (p', s') = lens1.bwd (p, s, a') 43 | in ((p', q'), s') 44 | 45 | -- Helpers for composing a parametric lens with a non-parametric one 46 | export 47 | composeR : PLens p s a -> Lens a b -> 48 | PLens p s b 49 | composeR lens1 lens2 = MkPLens fwd' bwd' 50 | where 51 | fwd' : (p, s) -> b 52 | fwd' (p, s) = lens2.fwd0 (lens1.fwd (p, s)) 53 | bwd' : (p, s, b) -> (p, s) 54 | bwd' (p, s, b) = 55 | let a' = lens2.bwd0 (lens1.fwd (p, s), b) 56 | (p', s') = lens1.bwd (p, s, a') 57 | in (p', s') 58 | export 59 | composeL : Lens s a -> PLens p a b -> 60 | PLens p s b 61 | -- lens1.fwd : s -> a 62 | -- lens1.bwd : (s, a) -> s 63 | -- lens2.fwd : (p, a) 
-> b 64 | -- lens2.bwd : (p, a, b) -> (p, a) 65 | composeL lens1 lens2 = MkPLens fwd' bwd' 66 | where 67 | fwd' : (p, s) -> b 68 | fwd' (p, s) = lens2.fwd (p, lens1.fwd0 s) 69 | bwd' : (p, s, b) -> (p, s) 70 | bwd' (p, s, b) = 71 | let (p', a') = lens2.bwd (p, lens1.fwd0 s, b) 72 | s' = lens1.bwd0 (s, a') 73 | in (p', s') 74 | 75 | -- Product of parametric lenses, 76 | prodLens : 77 | PLens p s a -> 78 | PLens p' s' a' -> 79 | PLens (p, p') (s, s') (a, a') 80 | -- lens1.fwd : (p, s) -> a 81 | -- lens1.bwd : (p, s, a) -> (p, s) 82 | prodLens lens1 lens2 = 83 | MkPLens fwdProd bwdProd 84 | where 85 | fwdProd : ((p, p'), (s, s')) -> (a, a') 86 | fwdProd ((p, p'), (s, s')) = (lens1.fwd (p, s), lens2.fwd (p', s')) 87 | 88 | bwdProd : ((p, p'), (s, s'), (a, a')) -> ((p, p'), (s, s')) 89 | bwdProd ((p, p'), (s, s'), (a, a')) = 90 | let (q, t) = lens1.bwd (p, s, a) 91 | (q', t') = lens2.bwd (p', s', a') 92 | in ((q, q'), (t, t')) 93 | 94 | -- duplicate a lens in parallel n+1 times 95 | export 96 | vecLens : (n : Nat) -> PLens p s a -> PLens (Vect n p) (Vect n s) (Vect n a) 97 | vecLens Z _ = MkPLens (\(Nil, Nil) => Nil) (\(Nil, Nil, Nil) => (Nil, Nil)) 98 | vecLens (S n) lns = MkPLens fwd' bwd' 99 | where 100 | lnsN : PLens (Vect n p) (Vect n s) (Vect n a) 101 | lnsN = vecLens n lns 102 | fwd' : (Vect (S n) p, Vect (S n) s) -> Vect (S n) a 103 | fwd' (p :: ps, s :: ss) = lns.fwd (p, s) :: lnsN.fwd (ps, ss) 104 | bwd' : (Vect (S n) p, Vect (S n) s, Vect (S n) a) -> (Vect (S n) p, Vect (S n) s) 105 | bwd' (p :: ps, s :: ss, a :: as) = 106 | let (p', s') = lns.bwd (p, s, a) 107 | (ps', ss') = lnsN.bwd (ps, ss, as) 108 | in (p' :: ps', s' :: ss') 109 | 110 | -- A branching combinator 111 | export 112 | branch : Monoid s => (n : Nat) -> Lens s (Vect n s) 113 | branch n = MkLens (replicate n) (\(_, ss) => concat ss) -- pointwise <+> 114 | 115 | -- Batch n lenses in parallel sharing the same parameters 116 | -- input and output are n-tupled, parameters are collected 117 | batch : Monoid p => 118 | (n : Nat) -> 119 | PLens p s a -> 120 | PLens p (Vect n s) (Vect n a) 121 | batch n lns = 122 | MkPLens fwdB bwdB 123 | where 124 | fwdB : (p, Vect n s) -> Vect n a 125 | fwdB (p, ss) = map lns.fwd (zip (replicate n p) ss) 126 | bwdB : (p, Vect n s, Vect n a) -> (p, Vect n s) 127 | bwdB (p, ss, as) = 128 | let (ps', ss') = unzip $ map lns.bwd $ zip3 (replicate n p) ss as 129 | in (concat ps', ss') 130 | 131 | 132 | 133 | -- xs = [1, 2, 3, 4, 5, 6] 134 | -- vw = [[1, 2, 3], [4, 5, 6]] m=3 n=2 135 | export 136 | rechunk : (m : Nat) -> (n : Nat) -> Vect (n * m) a -> Vect n (Vect m a) 137 | rechunk m Z xs = [] 138 | rechunk m (S k) xs = take m xs :: rechunk m k (drop m xs) 139 | 140 | -- A connector lens, flattens a mxn array 141 | export 142 | flatten : {m : Nat} -> {n : Nat} -> Lens (Vect n (Vect m s)) (Vect (n * m) s) 143 | flatten = MkLens fwd' bwd' 144 | where 145 | fwd' : Vect n (Vect m s) -> Vect (n*m) s 146 | fwd' vs = concat vs 147 | bwd' : {m : Nat} -> {n : Nat} -> (Vect n (Vect m s), Vect (n*m) s) -> (Vect n (Vect m s)) 148 | bwd' {m} {n} (vs, w) = (rechunk m n w) 149 | 150 | -- Batches n identical neural networks sharing the same parameters 151 | -- Use it for batch training 152 | export 153 | batchN : (n : Nat) -> 154 | (Vect n p -> p) -> 155 | PLens p s a -> 156 | PLens p (Vect n s) (Vect n a) 157 | batchN n collectP lns = 158 | MkPLens fwdB bckB 159 | where 160 | fwdB : (p, Vect n s) -> Vect n a 161 | fwdB (p, ss) = map lns.fwd (zip (replicate n p) ss) 162 | bckB : (p, Vect n s, Vect n a) -> (p, Vect 
n s) 163 | bckB (p, ss, as) = 164 | let (ps', ss') = unzip $ map lns.bwd $ zip3 (replicate n p) ss as 165 | in (collectP ps', ss') 166 | 167 | -- Produce a singleton vector 168 | export 169 | single : Lens s (Vect 1 s) 170 | single = MkLens fwd bwd 171 | where 172 | fwd : s -> Vect 1 s 173 | fwd s = [s] 174 | bwd : (s, Vect 1 s) -> s 175 | bwd (_, [s]) = s 176 | -------------------------------------------------------------------------------- /Idris/Perceptron.idr: -------------------------------------------------------------------------------- 1 | module Perceptron 2 | import Data.Vect 3 | import HVect 4 | import PLens 5 | import NNet 6 | 7 | -- The architecture is specified by number of inputs mIn and a list of l+1 layers ns 8 | -- Where l is zero or more hidden layers 9 | -- mIn -> [mIn, n1] -> [n1, n2] -> ... [n l, n (l+1)] 10 | {- 11 | public export 12 | data Layout : (mIn : Nat) -> (layers : Vect (S l) Nat) -> Type where 13 | MkLayout : (mIn : Nat) -> (l : Nat) -> (layers : Vect (S l) Nat) -> Layout mIn layers 14 | 15 | export 16 | inN : Layout i ls -> Nat 17 | inN (MkLayout mIn l _) = mIn 18 | 19 | export 20 | outN : Layout i ls -> Nat 21 | outN (MkLayout _ l layers) = last layers 22 | -} 23 | 24 | -- Chain of parameter blocks 25 | -- Parameters for multi-layer perceptron with mIn inputs 26 | -- A chain of types: 27 | -- ParaBlock mIn n1 :: ParaBlock n1 n2 :: ParaBlock n2 n3 ... 28 | public export 29 | ParaChain : {l : Nat} -> (mIn : Nat) -> (ns : Vect l Nat) -> Vect l Type 30 | ParaChain mIn [] = Nil 31 | ParaChain mIn (n :: ns') = ParaBlock mIn n :: ParaChain n ns' 32 | 33 | -- Chain of vectors of parameter blocks (for batches of perceptrons) 34 | 0 VParaChain : {l : Nat} -> (k : Nat) -> (mIn : Nat) -> (ns : Vect l Nat) -> Vect l Type 35 | VParaChain k mIn ns = ReplTypes k (ParaChain mIn ns) 36 | 37 | -- Proof that every type in ParaChain is a Monoid 38 | isMonoChain : {l : Nat} -> (mIn : Nat) -> (ns : Vect l Nat) -> HVect (map Monoid (ParaChain mIn ns)) 39 | isMonoChain mIn [] = Nil 40 | isMonoChain mIn (n :: ns') = (%search) :: isMonoChain n ns' 41 | 42 | -- 43 | -- concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} -> 44 | -- HVect (ReplTypes k ts) -> HVect ts 45 | 46 | -- in (ParaChain mIn ns), all types are Monoid 47 | collectH : {l : Nat} -> {k : Nat} -> {mIn : Nat} -> {ns : Vect l Nat} -> 48 | HVect (VParaChain k mIn ns) -> HVect (ParaChain mIn ns) 49 | collectH hv = concatH {isMono = isMonoChain mIn ns} hv 50 | 51 | -- transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 52 | -- Vect k (HVect ts) -> HVect (ReplTypes k ts) 53 | 54 | export 55 | collectParas : {k : Nat} -> {mIn : Nat} -> {l : Nat} -> {ns : Vect l Nat} -> 56 | Vect k (HVect (ParaChain mIn ns)) -> HVect (ParaChain mIn ns) 57 | collectParas = collectH . transposeH 58 | 59 | -- Multi layer perceptron with m inputs and l+1 layers 60 | -- neuron count in each layer is given by (Vect l Nat) 61 | 62 | -- 1 2 .. n2 [n2] 63 | -- n1 n1 n1 64 | -- |/ \|/ \| 65 | -- 1 2 .. n1 [n1] <-n1- [P[m], P[m] .. 
P[m]] 66 | -- m m m 67 | -- \ / \ / 68 | -- m 69 | 70 | export 71 | makeMLP : {l : Nat} -> (mIn : Nat) -> (ns : Vect (S l) Nat) -> 72 | PLens (HVect (ParaChain mIn ns)) (V mIn) (V (last ns)) 73 | makeMLP mIn ([nOut]) = MkPLens fwd' bwd' 74 | where 75 | lr : PLens (ParaBlock mIn nOut) (V mIn) (V nOut) 76 | lr = layer nOut mIn 77 | 78 | fwd' : (HVect (ParaChain mIn [nOut]), V mIn) -> V (nOut) 79 | fwd' ([p], v) = lr.fwd (p, v) 80 | bwd' : (HVect (ParaChain mIn [nOut]), V mIn, V nOut) -> (HVect (ParaChain mIn [nOut]), V mIn) 81 | bwd' ([p], v, w) = let (p', v') = lr.bwd (p, v, w) 82 | in ([p'], v') 83 | makeMLP mIn (n1 :: n2 :: ns) = MkPLens fwd' bwd' 84 | where 85 | -- m -> [m, n1] -> [n1, n2] -> ... [n l, n (l+1)] 86 | -- Layout for the recursive part 87 | mlp' : PLens (HVect (ParaChain n1 (n2 :: ns))) (V n1) (V (last (n2 :: ns))) 88 | mlp' = makeMLP n1 (n2 :: ns) --<< recurse 89 | -- compose with the bottom layer 90 | mlpComp : PLens (ParaBlock mIn n1, HVect (ParaChain n1 (n2 :: ns))) 91 | (V mIn) 92 | (V (last (n2 :: ns))) 93 | mlpComp = compose (layer n1 mIn) mlp' 94 | fwd' : (HVect (ParaChain mIn (n1 :: n2 :: ns)), V mIn) -> V (last (n1 :: n2 :: ns)) 95 | fwd' (p1 :: ps, vm) = mlpComp.fwd ((p1, ps), vm) 96 | bwd' : (HVect (ParaChain mIn (n1 :: n2 :: ns)), V mIn, V (last (n1 :: n2 :: ns))) -> 97 | (HVect (ParaChain mIn (n1 :: n2 :: ns)), V mIn) 98 | bwd' (pmn1 :: pmns, s, a) = 99 | let ((pmn1', pmns'), s') = mlpComp.bwd ((pmn1, pmns), s, a) 100 | in (pmn1' :: pmns', s') 101 | 102 | -- Initialize parameters for an MLP 103 | 104 | export 105 | initParaChain : {l : Nat} -> (mIn : Nat) -> (ns : Vect (S l) Nat) -> 106 | Stream Double -> (HVect (ParaChain mIn ns), Stream Double) 107 | initParaChain mIn ([n]) s = 108 | let (pb, s') = initParaBlock mIn n s 109 | in ([pb], s') 110 | initParaChain mIn (n1 :: n2 :: ns) s = 111 | let (pb, s') = initParaBlock mIn n1 s 112 | (pbs, s'') = initParaChain n1 (n2 :: ns) s' 113 | in (pb :: pbs, s'') 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Bartosz Milewski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PreLens/Main.hs: -------------------------------------------------------------------------------- 1 | module Main where 2 | import Tambara 3 | import TriLens 4 | import NNet 5 | import Perceptron 6 | import Params 7 | import Data.Int (Int32) 8 | import Data.Bits (shiftR) 9 | import Data.List 10 | import Control.Monad 11 | 12 | -- Home-made random number generator, in case no Random library available 13 | 14 | random :: Int32 -> [Int32] 15 | random seed = let seed' = 1664525 * seed + 1013904333 16 | in (seed' `shiftR` 2) : random seed' 17 | 18 | rands :: [Double] 19 | rands = map normalize (random 42) 20 | where 21 | normalize :: Int32 -> Double 22 | normalize n = fromIntegral n / 2147483647 23 | 24 | testLayer :: IO () 25 | testLayer = do 26 | let (paras, _) = initParaBlock 3 2 rands 27 | print $ fwd lyr (paras, [2, 3, -1]) 28 | putStrLn "Backward:" 29 | let (dp, ds) = bwd lyr (paras, [2, 3, -1], [1, 1]) 30 | print dp 31 | where 32 | lyr :: TriLens V V [((V, V), D)] [((V, V), D)] [Para] [Para] V V 33 | lyr = layer 3 2 -- 3 in, 2 out 34 | 35 | -- Gradient descent 36 | testLearning :: TriLens V V [[((V, V), D)]] [[((V, V), D)]] [[Para]] [[Para]] V V -> 37 | Double -> [V] -> [V] -> [[Para]] -> IO [[Para]] 38 | testLearning mlp rate xs ys para = do 39 | let ((_, dp), _) = bwd batchLoss ((ys, para), xs, 1) 40 | -- Update parameters 41 | let para1 = para <+> scale (-rate) dp 42 | putStrLn "\nForward pass: " 43 | print $ fwd batch (para1, xs) 44 | putStrLn "Forward pass error: " 45 | print $ fwd batchLoss ((ys, para1), xs) 46 | return para1 47 | where 48 | batch :: TriLens [V] [V] [[[((V, V), D)]]] [[[((V, V), D)]]] [[Para]] [[Para]] [V] [V] 49 | batch = batchN 4 mlp 50 | batchLoss :: TriLens D D 51 | ([[[((V, V), D)]]], ([V], [V])) ([[[((V, V), D)]]], ([V], [V])) 52 | ([V], [[Para]]) ([V], [[Para]]) 53 | [V] [V] 54 | batchLoss = triCompose batch lossT 55 | 56 | iterateM :: Monad m => Int -> (a -> m a) -> a -> m [a] 57 | iterateM 0 _ _ = return [] 58 | iterateM n f x = do 59 | x' <- f x 60 | (x':) `fmap` iterateM (n-1) f x' 61 | 62 | -- Taken from Andrej Karpathy's micrograd tutorial 63 | main :: IO () 64 | main = do 65 | let xs = [[2, 3, -1] 66 | ,[3, -1, 0.5] 67 | ,[0.5, 1, 1] 68 | ,[1, 1, -1]] 69 | let ys = fmap singleton [1, -1, -1, 1] 70 | let rate = 0.5 71 | let (para, _) = initParaMlp 3 [4, 4, 1] rands 72 | paras <- iterateM 10 (testLearning mlp rate xs ys) para 73 | return () 74 | where 75 | mlp :: TriLens V V [[((V, V), D)]] [[((V, V), D)]] [[Para]] [[Para]] V V 76 | mlp = makeMlp 3 [4, 4, 1] 77 | -------------------------------------------------------------------------------- /PreLens/NNet.hs: -------------------------------------------------------------------------------- 1 | module NNet where 2 | import PreLens 3 | import Tambara 4 | import TriLens 5 | import Params 6 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 7 | 8 | -- Use existential lenses to create more complex neural networks 9 | 10 | -- Simple linear lens, scalar product of parameters and inputs 11 | linearL :: Int -> PreLens D D (V, V) (V, V) V V V V 12 | linearL n = PreLens fw bw 13 | where 14 | fw :: (V, V) -> ((V, V), D) 15 | -- a = Sum p * s 16 | fw (p, s) = ((s, p), sumN n $ zipWith (*) p s) 17 | -- da/dp = s, da/ds = p 18 | bw :: ((V, V), D) -> (V, V) 19 | bw ((s, p), da) = (fmap (da *) s -- da * da/dp 20 | ,fmap (da *) p) -- da * da/ds 21 | 22 | -- Add bias to input 23 | biasL :: PreLens D D () () D D D D 24 | biasL = 
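-- unit residue again: adding the bias saves nothing for the backward pass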
PreLens fw bw 25 | where 26 | fw :: (D, D) -> ((), D) 27 | fw (p, s) = ((), p + s) 28 | -- da/dp = 1, da/ds = 1 29 | bw :: ((), D) -> (D, D) 30 | bw (_, da) = (da, da) 31 | 32 | -- Non-linear activation lens using tanh 33 | activL :: PreLens D D D D () () D D 34 | activL = PreLens fw bw 35 | where 36 | -- a = tanh s 37 | fw (_, s) = (s, tanh s) 38 | -- da/ds = 1 - (tanh s)^2 39 | bw (s, da) = ((), da * (1 - (tanh s)^2)) -- da * da/ds 40 | 41 | neuronL :: Int -> PreLens D D ((V, V), D) ((V, V), D) Para Para V V 42 | neuronL mIn = PreLens f' b' 43 | where 44 | PreLens f b = preCompose (preCompose (linearL mIn) biasL) activL 45 | f' :: (Para, V) -> (((V, V), D), D) 46 | f' (Para bi wt, s) = let (((vv, ()), d), a) = f (((), (bi, wt)), s) 47 | in ((vv, d), a) 48 | b' :: (((V, V), D), D) -> (Para, V) 49 | b' ((vv, d), da) = let (((), (d', w')), ds) = b (((vv, ()), d), da) 50 | in (Para d' w', ds) 51 | 52 | -- Convert to TriLens 53 | -- m1 p1 D -> ((V, V), m1) (p1, (V, V)) V 54 | linearT :: Int -> TriLens D D (V, V) (V, V) V V V V 55 | linearT n = toTamb (linearL n) 56 | 57 | -- m1 p1 D -> ((), m1) (p1, ()) D 58 | biasT :: TriLens D D () () D D D D 59 | biasT = toTamb biasL 60 | 61 | affineT :: Int -> TriLens D D (V, V) (V, V) (D, V) (D, V) V V 62 | affineT n = 63 | dimapM (first runit) (first unRunit) . 64 | triCompose (linearT n) biasT 65 | 66 | activT :: TriLens D D D D () () D D 67 | activT = toTamb activL 68 | 69 | neuronT :: Int -> TriLens D D ((V, V), D) ((V, V), D) Para Para V V 70 | -- m1 p1 D -> (((V, V), D), m1) (p1, Para) V 71 | neuronT mIn = 72 | dimapP (second (unLunit . unPara)) (second (mkPara . lunit)) . 73 | triCompose (dimapM (first runit) (first unRunit) . 74 | triCompose (linearT mIn) biasT) activT 75 | 76 | -- Initialize parameters for an affine lens from an infinite stream 77 | initPara :: Int -> [D] -> (Para, [D]) 78 | initPara m stm = (Para b w, stm'') 79 | where 80 | (w, stm') = splitAt m stm 81 | ([b], stm'') = splitAt 1 stm' 82 | 83 | 84 | -- A layer of nOut identical neurons, each with mIn inputs 85 | -- V [((V, V), D)] [Para] V 86 | layer :: Int -> Int -> TriLens V V [((V, V), D)] [((V, V), D)] [Para] [Para] V V 87 | layer mIn nOut = 88 | dimapP (second unRunit) (second runit) . 89 | dimapM (first lunit) (first unLunit) . 
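-- Wiring note: branch nOut fans the input vector out into nOut copies,
-- and vecLens nOut (neuronT mIn) runs nOut independent neurons in parallel,
-- collecting their residues in [((V, V), D)] and their parameters in [Para].
-- For instance, layer 3 2 maps a V of length 3 to a V of length 2, using
-- a list of 2 Para blocks (one per neuron, each with 3 weights and a bias).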
90 | triCompose (branch nOut) (vecLens nOut (neuronT mIn)) -- m1 p1 V -> (m1, ((), [((V, V), D)])) (([Para], ()), p1) V 91 | 92 | 93 | -- Initialize a block of nOut parameters, each for a neuron with mIn inputs 94 | initParaBlock :: Int -> Int -> [D] -> ([Para], [D]) 95 | initParaBlock mIn nOut stm = unfoldl nOut (initPara mIn) stm 96 | 97 | -- Helper function 98 | 99 | unfoldl :: Int -> (s -> (a, s)) -> s -> ([a], s) 100 | unfoldl 0 f s = ([], s) 101 | unfoldl n f s = (x : xs, s'') 102 | where 103 | (x, s') = f s 104 | (xs, s'') = unfoldl (n-1) f s' 105 | 106 | 107 | -- The loss lens compares a single result with the ground truth 108 | loss1L :: PreLens D D (V, V) (V, V) V V V V 109 | loss1L = PreLens fw bw 110 | where 111 | fw :: (V, V) -> ((V, V), D) 112 | fw (gTruth, s) = ((gTruth, s), sqDist s gTruth) 113 | bw :: ((V, V), D) -> (V, V) 114 | bw ((gTruth, s), da) = (fmap negate delta', delta') 115 | where delta' = map (da *) (s `minus` gTruth) 116 | -- da/ds = s - g 117 | -- da/dg = g - s 118 | 119 | minus :: Num c => [c] -> [c] -> [c] 120 | minus = zipWith (-) 121 | 122 | -- 1/2 Sum (s - g)^2 123 | sqDist :: Fractional a => [a] -> [a] -> a 124 | sqDist x y = 0.5 * sum (map (^2) (zipWith (-) x y)) 125 | 126 | loss1T :: TriLens D D (V, V) (V, V) V V V V 127 | loss1T = toTamb loss1L 128 | 129 | -- The batch loss lens compares lists of results with the ground truth 130 | lossL :: PreLens D D ([V], [V]) ([V], [V]) [V] [V] [V] [V] 131 | lossL = PreLens fw bw 132 | where 133 | fw :: ([V], [V]) -> (([V], [V]), D) 134 | fw (gTruth, s) = ((gTruth, s), sqDist (concat s) (concat gTruth)) 135 | bw :: (([V], [V]), D) -> ([V], [V]) 136 | bw ((gTruth, s), da) = (fmap (fmap negate) delta', delta') 137 | -- da/ds = s - g 138 | -- da/dg = g - s 139 | where 140 | delta' :: [V] 141 | delta' = fmap (fmap (da *)) (zipWith minus s gTruth) 142 | 143 | lossT :: TriLens D D ([V], [V]) ([V], [V]) [V] [V] [V] [V] 144 | lossT = toTamb lossL 145 | -------------------- 146 | 147 | 148 | -- fwd :: TriLens a da m dm p dp s ds -> (p, s) -> a 149 | -- fwd l = let PreLens f b = fromTamb l 150 | -- in snd . f 151 | -- bwd :: TriLens a da m dm p dp s ds -> (p, s, da) -> (dp, ds) 152 | -- bwd l (p, s, da) = 153 | -- let (PreLens f b) = fromTamb l 154 | -- in b (fst (f (p, s)), da) 155 | 156 | -- affineT :: Int -> TriLens D D (V, V) (V, V) (D, V) (D, V) V V 157 | testTriTamb :: IO () 158 | testTriTamb = do 159 | putStrLn "forward" 160 | print $ fwd (affineT 2) ((0.01, [-0.1, 0.1]), [2, 30]) 161 | putStrLn "backward" 162 | print $ bwd (affineT 2) ((0.1, [1.3, -1.4]), [0.21, 0.33], 1) 163 | 164 | putStrLn "forward neuron" 165 | print $ fwd (neuronT 2) (Para 0.01 [-0.1, 0.1], [2, 30]) 166 | putStrLn "backward neuron" 167 | print $ bwd (neuronT 2) (Para 0.1 [1.3, -1.4], [0.21, 0.33], 1) 168 | 169 | test2 :: IO () 170 | test2 = do 171 | let s = [0, 0.1 .. 
] 172 | let (p, s') = initPara 2 s 173 | let (p', s'') = initPara 2 s' 174 | putStrLn $ "p = " ++ show p ++ "\np' = " ++ show p' ++ "\np <+> (scale (-0.1) p') = " ++ show (p <+> (scale (-0.1) p')) ++ "\n" 175 | print p 176 | print $ fst $ unfoldl 3 (initPara 2) s' 177 | 178 | nrn3 :: TriLens D D ((V, V), D) ((V, V), D) Para Para V V 179 | nrn3 = neuronT 3 180 | 181 | test3 :: IO () 182 | test3 = do 183 | putStrLn "Compare different implementations of neurons" 184 | let s = [1, 0.5, 0, 0] 185 | let (p, s') = initPara 3 s 186 | let ins = [-1, 0, 1] 187 | putStrLn "Forward neurons" 188 | print $ fwd nrn3 (p, ins) 189 | putStrLn "" 190 | let neuron0 = ExLens (neuronL 3) 191 | print $ fwd' neuron0 (p, ins) 192 | putStrLn "Backward neurons" 193 | print $ bwd nrn3 (p, ins, 1) 194 | putStrLn "" 195 | print $ bwd' neuron0 (p, ins, 1) 196 | 197 | 198 | nrn2 :: TriLens D D ((V, V), D) ((V, V), D) Para Para V V 199 | nrn2 = neuronT 2 200 | 201 | test4 :: IO () 202 | test4 = do 203 | putStrLn "Test backward passes" 204 | let p = Para 0.5 [0.5, -0.5] 205 | let in1 = [1, 0] 206 | let in2 = [0, 1] 207 | print $ fwd nrn2 (p, in1) 208 | let (dp, ds) = bwd nrn2 (p, in1, 1) 209 | print dp 210 | print ds 211 | print $ fwd nrn2 (p, in2) 212 | let (dp, ds) = bwd nrn2 (p, in2, 1) 213 | print dp 214 | print ds 215 | 216 | 217 | test5 :: IO () 218 | test5 = do 219 | putStrLn "forward" 220 | print $ fwd (affineT 2) ((0.1, [-1, 1]), [2, 30]) 221 | putStrLn $ show $ (-2) + 30 + 0.1 222 | putStrLn "backward" 223 | print $ bwd (affineT 2) ((0.1, [1.3, -1.4]), [21, 33], 1) 224 | -- y = q1 * x1 + q2 * x2 + d 225 | -- dy/dq = (x1, x2), dy/dd = 1, dy/dx = (q1, q2) 226 | putStrLn $ show $ (Para 1 [21, 33], [1.3, -1.4]) 227 | -------------------------------------------------------------------------------- /PreLens/ParaLens.hs: -------------------------------------------------------------------------------- 1 | module ParaLens where 2 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 3 | 4 | -- Parametric lenses and BiTambara modules 5 | 6 | -- Parametric lens, get/set or forward/backward representation 7 | data PLens a da p dp s ds = 8 | PLens { fwd' :: (p, s) -> a 9 | , bwd' :: (p, s, da) -> (dp, ds) 10 | } 11 | 12 | -- Existential representation of parametric lens 13 | data ExLens a da p dp s ds = 14 | forall m . 
ExLens ((p, s) -> (m, a)) 15 | ((m, da) -> (dp, ds)) 16 | 17 | -- Accessors 18 | fwd :: ExLens a da p dp s ds -> (p, s) -> a 19 | fwd (ExLens f g) (p, s) = snd $ f (p, s) 20 | 21 | bwd :: ExLens a da p dp s ds -> (p, s, da) -> (dp, ds) 22 | bwd (ExLens f g) (p, s, da) = g (fst (f (p, s)), da) 23 | 24 | -- Serial composition 25 | compose :: 26 | ExLens a da p dp s ds -> ExLens b db q dq a da -> 27 | ExLens b db (p, q) (dp, dq) s ds 28 | compose (ExLens f1 g1) (ExLens f2 g2) = ExLens f3 g3 29 | where 30 | f3 ((p, q), s) = 31 | let (m, a) = f1 (p, s) 32 | (n, b) = f2 (q, a) 33 | in ((m, n), b) 34 | g3 ((m, n), db) = 35 | let (dq, da) = g2 (n, db) 36 | (dp, ds) = g1 (m, da) 37 | in ((dp, dq), ds) 38 | 39 | identityLens :: ExLens a da () () a da 40 | identityLens = ExLens id id 41 | 42 | -- Parallel composition 43 | 44 | -- A pair of lenses in parallel 45 | prodLens :: 46 | ExLens a da p dp s ds -> ExLens a' da' p' dp' s' ds' -> 47 | ExLens (a, a') (da, da') (p, p') (dp, dp') (s, s') (ds, ds') 48 | prodLens (ExLens f1 g1) (ExLens f2 g2) = ExLens f3 g3 49 | where 50 | f3 ((p, p'), (s, s')) = ((m, m'), (a, a')) 51 | where (m, a) = f1 (p, s) 52 | (m', a') = f2 (p', s') 53 | g3 ((m, m'), (da, da')) = ((dp, dp'), (ds, ds')) 54 | where 55 | (dp, ds) = g1 (m, da) 56 | (dp', ds') = g2 (m', da') 57 | 58 | -- van Laarhoven representation of parametric lens 59 | -- Not very useful, as it doesn't compose nicely 60 | type VanL a da p dp s ds = forall f. Functor f => 61 | (a -> f da) -> (p, s) -> f (dp, ds) 62 | 63 | type ParaLens a da p dp s ds = (p, s) -> (da -> (dp, ds), a) 64 | 65 | toVLL :: ParaLens a da p dp s ds -> VanL a da p dp s ds 66 | toVLL para f = fmap (uncurry ($)) . strength . second f . para 67 | 68 | fromVLL :: VanL a da p dp s ds -> ParaLens a da p dp s ds 69 | fromVLL vll = unF . vll (curry MkF id) 70 | 71 | newtype F a da x = MkF { unF :: (da -> x, a) } 72 | deriving Functor 73 | 74 | -- Profunctor representation 75 | 76 | -- As a reminder, this is the vanilla Tambara module 77 | class Profunctor p where 78 | dimapVanilla :: (a' -> a) -> (b -> b') -> p a b -> p a' b' 79 | 80 | class Profunctor p => Tambara p where 81 | alphaVanilla :: forall a da m. p a da -> p (m, a) (m, da) 82 | -- 83 | 84 | class BiProfunctor p where 85 | dimap :: (a' -> a) -> (b -> b') -> p q q' a b -> p q q' a' b' 86 | dimap' :: (r -> q) -> (q' -> r') -> p q q' a b -> p r r' a b 87 | 88 | -- Parametric version of Tambara module 89 | class BiProfunctor p => BiTambara p where 90 | alpha :: p q q' a da -> p q q' (m, a) (m, da) 91 | beta :: p r r' (q, s) (q', ds) -> p (r, q) (r', q') s ds 92 | 93 | -- Profunctor representation of a parametric lens 94 | type BiLens q q' s ds a da = 95 | forall p. BiTambara p => forall r r'. 
p r r' a da -> p (r, q) (r', q') s ds 96 | 97 | -- An existential lens is an example of a BiTambara module 98 | instance BiProfunctor (ExLens a da) where 99 | dimap :: (s' -> s) -> (ds -> ds') -> ExLens a da q q' s ds -> ExLens a da q q' s' ds' 100 | dimap f g (ExLens fw bw) = ExLens fw' bw' 101 | where fw' (q, s') = fw (q, f s') 102 | bw' (m, da) = second g (bw (m, da)) 103 | dimap' :: (r -> q) -> (q' -> r') -> ExLens a da q q' s ds -> ExLens a da r r' s ds 104 | dimap' f g (ExLens fw bw) = ExLens fw' bw' 105 | where 106 | fw' (r, s) = fw (f r, s) 107 | bw' (m, da) = first g $ bw (m, da) 108 | 109 | instance BiTambara (ExLens a da) where 110 | alpha :: ExLens a da q q' s ds -> ExLens a da q q' (n, s) (n, ds) 111 | alpha (ExLens fw bw) = ExLens fw' bw' 112 | where fw' (q, (n, s)) = first (n,) $ fw (q, s) -- use (n, m) as residue 113 | bw' ((n, m), da) = second (n,) (bw (m, da)) 114 | 115 | beta :: ExLens a da r r' (q, s) (q', ds) -> ExLens a da (r, q) (r', q') s ds 116 | beta (ExLens fw bw) = ExLens fw' bw' 117 | where fw' ((r, q), s) = fw (r, (q, s)) 118 | bw' (m, da) = let (r', (q', ds)) = bw (m, da) 119 | in ((r', q'), ds) 120 | 121 | fromTamb :: BiLens q q' s ds a da -> ExLens a da q q' s ds 122 | fromTamb pab_pst = dimap' unLunit lunit $ pab_pst identityLens 123 | 124 | lunit :: ((), a) -> a 125 | lunit ((), a) = a 126 | unLunit :: a -> ((), a) 127 | unLunit a = ((), a) 128 | 129 | -- Conversion from ExLens to BiLens 130 | toTamb :: ExLens a da q q' s ds -> BiLens q q' s ds a da 131 | toTamb (ExLens fw bw) = beta . dimap fw bw . alpha 132 | 133 | strength :: Functor f => (a, f b) -> f (a, b) 134 | strength (a, fb) = fmap (a,) fb 135 | -------------------------------------------------------------------------------- /PreLens/Params.hs: -------------------------------------------------------------------------------- 1 | module Params where 2 | type D = Double 3 | -- Ideally, a counted vector 4 | type V = [D] 5 | 6 | -- Parameters for a single neuron 7 | data Para = Para 8 | { bias :: D 9 | , weight :: V 10 | } deriving Show 11 | 12 | mkPara :: (D, V) -> Para 13 | mkPara (b, v) = Para b v 14 | 15 | unPara :: Para -> (D, V) 16 | unPara p = (bias p, weight p) 17 | 18 | -- Parameters for a layer of neurons 19 | type ParaBlock = [Para] 20 | 21 | -- Parameters form a vector space; we need to scale and add them 22 | 23 | class VSpace v where 24 | (<+>) :: v -> v -> v 25 | scale :: D -> v -> v 26 | vzero :: v 27 | 28 | instance VSpace D where 29 | (<+>) :: D -> D -> D 30 | (<+>) = (+) 31 | scale :: D -> D -> D 32 | scale a x = a * x 33 | vzero :: D 34 | vzero = 0.0 35 | 36 | instance VSpace a => VSpace [a] where 37 | (<+>) :: VSpace a => [a] -> [a] -> [a] 38 | (<+>) = zipWith (<+>) 39 | scale :: VSpace a => D -> [a] -> [a] 40 | scale a = fmap (scale a) 41 | vzero :: VSpace a => [a] 42 | vzero = repeat vzero 43 | 44 | instance VSpace Para where 45 | (<+>) :: Para -> Para -> Para 46 | p1 <+> p2 = Para (bias p1 + bias p2) (zipWith (+) (weight p1) (weight p2)) 47 | scale :: D -> Para -> Para 48 | scale a p = Para (scale a (bias p)) (scale a (weight p)) 49 | vzero :: Para 50 | vzero = Para 0.0 (repeat 0.0) 51 | 52 | sumN :: Int -> V -> D 53 | sumN 0 _ = 0 54 | sumN n [] = error $ "sumN " ++ show n 55 | sumN n (a : as) = a + sumN (n - 1) as 56 | 57 | accumulate :: VSpace v => [v] -> v 58 | accumulate = foldr (<+>) vzero -------------------------------------------------------------------------------- /PreLens/Perceptron.hs: 
-------------------------------------------------------------------------------- 1 | module Perceptron where 2 | import PreLens 3 | import Tambara 4 | import TriLens 5 | import Params 6 | import NNet 7 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 8 | import Data.List 9 | 10 | -- Multi-layer perceptron 11 | -- The first layer contains neurons with mIn inputs each 12 | -- The list [Int] specifies the number of neurons in each layer (starting with the first layer) 13 | -- Each neuron has one output 14 | makeMlp :: Int -> [Int] -> TriLens V V [[((V, V), D)]] [[((V, V), D)]] [[Para]] [[Para]] V V 15 | -- layer : V [((V, V), D)] [Para] V 16 | makeMlp mIn [nOut] = 17 | dimapM (first singleton) (first head) . 18 | dimapP (second head) (second singleton) . 19 | layer mIn nOut 20 | makeMlp mIn (n1 : n2 : ns) = 21 | dimapM (first cons) (first unCons) . 22 | dimapP (second (sym . unCons)) (second (cons . sym)) . 23 | triCompose (layer mIn n1) (makeMlp n1 (n2 : ns)) 24 | 25 | -- Initialize parameters for an MLP 26 | initParaMlp :: Int -> [Int] -> [D] -> ([[Para]], [D]) 27 | initParaMlp mIn [nOut] stm = 28 | let (pb, stm') = initParaBlock mIn nOut stm 29 | in ([pb], stm') 30 | initParaMlp mIn (n1 : n2 : ns) stm = 31 | let (pb, stm') = initParaBlock mIn n1 stm 32 | (pbs, stm'') = initParaMlp n1 (n2 : ns) stm' 33 | in (pb : pbs, stm'') 34 | -------------------------------------------------------------------------------- /PreLens/PreLens.hs: -------------------------------------------------------------------------------- 1 | module PreLens where 2 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 3 | 4 | -- Pre-lens, uses 4 monoidal actions parameterized by m dm and p dp 5 | -- Pre-lens category has objects <s, ds>, <a, da>, etc. 6 | -- Pre-lenses are morphisms from <s, ds> to <a, da> 7 | 8 | data PreLens a da m dm p dp s ds = 9 | PreLens ((p, s) -> (m, a)) 10 | ((dm, da) -> (dp, ds)) 11 | 12 | -- Pre-lenses are composable 13 | preCompose :: 14 | PreLens a' da' m dm p dp s ds -> 15 | PreLens a da n dn q dq a' da' -> 16 | PreLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds 17 | preCompose (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3 18 | where 19 | f3 = unAssoc . second f2 . assoc . first sym . unAssoc . second f1 . assoc 20 | g3 = unAssoc . second g1 . assoc . first sym . unAssoc . second g2 . assoc 21 | {- or more verbose: 22 | f3 ((q, p), s) = 23 | let (m, a) = f1 (p, s) 24 | (n, b) = f2 (q, a) 25 | in ((m, n), b) 26 | g3 ((dm, dn), db) = 27 | let (dq, da) = g2 (dn, db) 28 | (dp, ds) = g1 (dm, da) 29 | in ((dq, dp), ds) 30 | -} 31 | 32 | idPreLens :: PreLens a da () () () () a da 33 | idPreLens = PreLens id id 34 | 35 | 36 | -- Parallel composition 37 | 38 | -- A pair of lenses in parallel 39 | prodLens :: 40 | PreLens a da m dm p dp s ds -> PreLens a' da' m' dm' p' dp' s' ds' -> 41 | PreLens (a, a') (da, da') (m, m') (dm, dm') (p, p') (dp, dp') (s, s') (ds, ds') 42 | prodLens (PreLens f1 g1) (PreLens f2 g2) = PreLens f3 g3 43 | where 44 | f3 ((p, p'), (s, s')) = ((m, m'), (a, a')) 45 | where (m, a) = f1 (p, s) 46 | (m', a') = f2 (p', s') 47 | g3 ((dm, dm'), (da, da')) = ((dp, dp'), (ds, ds')) 48 | where 49 | (dp, ds) = g1 (dm, da) 50 | (dp', ds') = g2 (dm', da') 51 | 52 | -- An existential lens is a trace over m of a PreLens 53 | -- The tracing can be done after all the compositions 54 | 55 | data ExLens a da p dp s ds = forall m. ExLens (PreLens a da m m p dp s ds) 56 | 57 | -- Extractors for an existential lens 58 | fwd' :: ExLens a da p dp s ds -> (p, s) -> a 59 | fwd' (ExLens (PreLens f b)) = snd . 
f 60 | bwd' :: ExLens a da p dp s ds -> (p, s, da) -> (dp, ds) 61 | bwd' (ExLens (PreLens f b)) (p, s, da) = b (fst (f (p, s)), da) 62 | 63 | -- Composition of existential lenses follows 64 | -- the composition of pre-lenses 65 | composeL :: 66 | ExLens a da p dp s ds -> ExLens b db q dq a da -> 67 | ExLens b db (q, p) (dq, dp) s ds 68 | composeL (ExLens pl) (ExLens pl') = ExLens $ preCompose pl pl' 69 | 70 | -- Combining the two extractors into one function 71 | 72 | type ParaLens a da p dp s ds = (p, s) -> (da -> (dp, ds), a) 73 | 74 | -- Monoidal category structure maps 75 | lunit :: ((), a) -> a 76 | lunit ((), a) = a 77 | unLunit :: a -> ((), a) 78 | unLunit a = ((), a) 79 | runit :: (a, ()) -> a 80 | runit (a, ()) = a 81 | unRunit :: a -> (a, ()) 82 | unRunit a = (a, ()) 83 | 84 | assoc :: ((a, b), c) -> (a, (b, c)) 85 | assoc ((a, b), c) = (a, (b, c)) 86 | unAssoc :: (a, (b, c)) -> ((a, b), c) 87 | unAssoc (a, (b, c)) = ((a, b), c) 88 | 89 | -- Symmetric monoidal structure maps 90 | 91 | sym :: (a, b) -> (b, a) 92 | sym (a, b) = (b, a) 93 | 94 | skipRight :: (x, (b, c)) -> (b, (x, c)) 95 | skipRight (x, (b, c)) = (b, (x, c)) 96 | 97 | skipLeft :: ((a, b), x) -> ((a, x), b) 98 | skipLeft ((a, b), x) = ((a, x), b) 99 | -------------------------------------------------------------------------------- /PreLens/Tambara.hs: -------------------------------------------------------------------------------- 1 | module Tambara where 2 | import PreLens 3 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 4 | 5 | -- A profunctor in three pairs of arguments (Notice: the polarities of m dm are flipped) 6 | class TriProFunctor t where 7 | dimapS :: (s' -> s) -> (ds -> ds') -> t m dm p dp s ds -> t m dm p dp s' ds' 8 | dimapP :: (p' -> p) -> (dp -> dp') -> t m dm p dp s ds -> t m dm p' dp' s ds 9 | dimapM :: (m -> m') -> (dm' -> dm) -> t m dm p dp s ds -> t m' dm' p dp s ds 10 | 11 | -- PreLens is a profunctor in three pairs of arguments 12 | instance TriProFunctor (PreLens a da) where 13 | dimapS f g (PreLens fw bw) = PreLens fw' bw' 14 | where fw' (p, s') = fw (p, f s') 15 | bw' (dm, da) = second g $ bw (dm, da) 16 | dimapP f g (PreLens fw bw) = PreLens fw' bw' 17 | where fw' (p', s) = fw (f p', s) 18 | bw' (dm, da) = first g $ bw (dm, da) 19 | dimapM f g (PreLens fw bw) = PreLens fw' bw' 20 | where fw' (p, s) = first f $ fw (p, s) 21 | bw' (dm', da) = bw (g dm', da) 22 | 23 | -- A generalization of Tambara modules with three pairs of arguments 24 | class TriProFunctor t => Trimbara t where 25 | -- shorthand: alpha :: m p s -> (m1, m) p (m1, s) 26 | alpha :: t m dm p dp s ds -> t (m1, m) (dm1, dm) p dp (m1, s) (dm1, ds) 27 | -- shorthand: beta :: m p (p1, s) -> m (p, p1) s 28 | beta :: t m dm p dp (p1, s) (dp1, ds) -> t m dm (p, p1) (dp, dp1) s ds 29 | 30 | -- PreLens is an example of such a Tambara module 31 | instance Trimbara (PreLens a da) where 32 | -- fw :: (p, s) -> (m, a) 33 | -- bw :: (dm, da) -> (dp, ds) 34 | alpha :: PreLens a da m dm p dp s ds -> PreLens a da (n, m) (dn, dm) p dp (n, s) (dn, ds) 35 | alpha (PreLens fw bw) = PreLens fw' bw' 36 | where 37 | --fw' :: (p, (n, s)) -> ((n, m), a) 38 | fw' (p, (n, s)) = let (m, a) = fw (p, s) 39 | in ((n, m), a) 40 | --bw' :: ((dn, dm), da) -> (dp, (dn, ds)) 41 | bw' ((dn, dm), da) = let (dp, ds) = bw (dm, da) 42 | in (dp, (dn, ds)) 43 | 44 | beta :: forall m dm p dp s ds a da r dr . 
45 | PreLens a da m dm p dp (r, s) (dr, ds) -> PreLens a da m dm (p, r) (dp, dr) s ds 46 | -- fw :: (p, (r, s)) -> (m, a) 47 | -- bw :: (dm, da) -> (dp, (dr, ds)) 48 | beta (PreLens fw bw) = PreLens fw' bw' 49 | where 50 | fw' :: ((p, r), s) -> (m, a) 51 | fw' ((p, r), s) = let (m, a) = fw (p, (r, s)) 52 | in (m, a) 53 | bw' :: (dm, da) -> ((dp, dr), ds) 54 | bw' (dm, da) = let (dp, (dr, ds)) = bw (dm, da) 55 | in ((dp, dr), ds) 56 | -------------------------------------------------------------------------------- /PreLens/TriLens.hs: -------------------------------------------------------------------------------- 1 | module TriLens where 2 | import PreLens 3 | import Tambara 4 | import Params 5 | import Data.Bifunctor ( Bifunctor(second, first, bimap) ) 6 | 7 | ---------- 8 | -- A function polymorphic in Trimbara modules, of the shape below, is equivalent to a PreLens 9 | -- shorthand: m1 p1 a -> (m, m1) (p1, p) s 10 | ---------- 11 | type TriLens a da m dm p dp s ds = 12 | forall t. Trimbara t => forall p1 dp1 m1 dm1. 13 | t m1 dm1 p1 dp1 a da -> t (m, m1) (dm, dm1) (p1, p) (dp1, dp) s ds 14 | 15 | -- n r a -> (m, n) (r, p) s 16 | -- () () a -> (m, ()) ((), p) s 17 | fromTamb :: forall a da m dm p dp s ds . 18 | TriLens a da m dm p dp s ds -> PreLens a da m dm p dp s ds 19 | fromTamb pab_pst = dimapM runit unRunit $ 20 | dimapP unLunit lunit $ 21 | pab_pst idPreLens 22 | 23 | toTamb :: PreLens a da m dm p dp s ds -> TriLens a da m dm p dp s ds 24 | -- want :: m1 p1 a -> (m, m1) (p1, p) s 25 | -- alpha :: m1 p1 a -> (m, m1) p1 (m, a) 26 | -- dimapS fw bw :: -> (m, m1) p1 (p, s) 27 | -- beta :: -> (m, m1) (p1, p) s 28 | toTamb (PreLens fw bw) = beta . dimapS fw bw . alpha 29 | 30 | -- triCompose :: b m p s -> a n q b -> a (m, n) (q, p) s 31 | triCompose :: 32 | TriLens b db m dm p dp s ds -> 33 | TriLens a da n dn q dq b db -> 34 | TriLens a da (m, n) (dm, dn) (q, p) (dq, dp) s ds 35 | -- lba :: m1 p1 a -> (n, m1) (p1, q) b 36 | -- las :: (n, m1) (p1, q) b -> (m, (n, m1)) ((p1, q), p) s 37 | -- lbs :: m1 p1 a -> ((m, n), m1) (p1, (q, p)) s 38 | -- dimapP :: (p' -> p) -> (dp -> dp') -> m p s -> m p' s 39 | -- dimapM :: (m -> m') -> (dm' -> dm) -> m p s -> m' p s 40 | -- las . lba :: m1 p1 a -> (m, (n, m1)) ((p1, q), p) s 41 | triCompose las lba = dimapP unAssoc assoc . 42 | dimapM unAssoc assoc . 43 | las . lba 44 | 45 | -- Extractors for a triple Tambara lens 46 | fwd :: TriLens a da m m p dp s ds -> (p, s) -> a 47 | fwd l = fwd' (ExLens (fromTamb l)) 48 | bwd :: TriLens a da m m p dp s ds -> (p, s, da) -> (dp, ds) 49 | bwd l = bwd' (ExLens (fromTamb l)) 50 | 51 | 52 | -- Parallel product of TriLenses 53 | 54 | -- Show that a TriTambara of products is a TriTambara in both sides of the product 55 | 56 | -- Rearrange the wires for Haskell 57 | newtype PRight t m dm p dp s ds m' dm' p' dp' s' ds' = PRight { 58 | unPRight :: t (m, m') (dm, dm') (p, p') (dp, dp') (s, s') (ds, ds') } 59 | 60 | -- It's a TriProfunctor in these variables 61 | instance (TriProFunctor t) => TriProFunctor (PRight t m dm p dp s ds) where 62 | dimapS f g (PRight t) = PRight $ dimapS (second f) (second g) t 63 | dimapP f g (PRight t) = PRight $ dimapP (second f) (second g) t 64 | dimapM f g (PRight t) = PRight $ dimapM (second f) (second g) t 65 | 66 | -- It's a TriTambara in these variables 67 | instance (Trimbara t) => Trimbara (PRight t m dm p dp s ds) where 68 | -- alpha :: m p s -> (m1, m) p (m1, s) 69 | -- need :: (m, m') (p, p') (s, s') -> 70 | -- (m, (m1, m')) (p, p') (s, (m1, s')) 71 | alpha = PRight . 
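-- Step note: the unwrapped alpha introduces the fresh residue m1 at the
-- front; skipRight then rebrackets residues (m1, (m, m')) -> (m, (m1, m')),
-- and, on sources, maps the new (s, (m1, s')) back to (m1, (s, s')).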
72 | dimapS skipRight skipRight . 73 | dimapM skipRight skipRight . 74 | alpha . -- (m1, (m, m')) (p, p') (m1, (s, s')) 75 | unPRight -- (m, m') (p, p') (s, s') 76 | 77 | -- beta :: m p (p1, s) -> m (p, p1) s 78 | -- need :: (m, m') (p, p') (s, (p1, s')) -> (m, m') (p, (p', p1)) (s, s') 79 | beta = PRight . 80 | dimapP unAssoc assoc . 81 | beta . -- (m, m') ((p, p'), p1) (s, s') 82 | dimapS skipRight skipRight . -- (m, m') (p, p') (p1, (s, s')) 83 | unPRight -- (m, m') (p, p') (s, (p1, s')) 84 | 85 | newtype PLeft t m' dm' p' dp' s' ds' m dm p dp s ds = PLeft { 86 | unPLeft :: t (m, m') (dm, dm') (p, p') (dp, dp') (s, s') (ds, ds') } 87 | 88 | -- It's a TriProfunctor in these variables 89 | instance (TriProFunctor t) => TriProFunctor (PLeft t m dm p dp s ds) where 90 | dimapS f g (PLeft t) = PLeft $ dimapS (first f) (first g) t 91 | dimapP f g (PLeft t) = PLeft $ dimapP (first f) (first g) t 92 | dimapM f g (PLeft t) = PLeft $ dimapM (first f) (first g) t 93 | 94 | -- It's a TriTambara in these variables 95 | instance (Trimbara t) => Trimbara (PLeft t m dm p dp s ds) where 96 | -- alpha :: m p s -> (m1, m) p (m1, s) 97 | -- need :: (m, m') (p, p') (s, s') -> 98 | -- ((m1, m), m') (p, p') ((m1, s), s') 99 | alpha = PLeft . 100 | dimapS assoc unAssoc . 101 | dimapM unAssoc assoc . 102 | alpha . -- (m1, (m, m')) (p, p') (m1, (s, s')) 103 | unPLeft -- (m, m') (p, p') (s, s') 104 | -- beta :: m p (p1, s) -> m (p, p1) s 105 | -- need :: (m, m') (p, p') ((p1, s), s') -> (m, m') ((p, p1), p') (s, s') 106 | beta = PLeft . 107 | dimapP skipLeft skipLeft . 108 | beta . -- (m, m') ((p, p'), p1) (s, s') 109 | dimapS unAssoc assoc . -- (m, m') (p, p') (p1, (s, s')) 110 | unPLeft -- (m, m') (p, p') ((p1, s), s') 111 | 112 | prodLensT :: TriLens a da m dm p dp s ds -> 113 | TriLens a' da' m' dm' p' dp' s' ds' -> 114 | TriLens (a, a') (da, da') (m, m') (dm, dm') (p, p') (dp, dp') (s, s') (ds, ds') 115 | -- l1 :: m1 p1 a -> (m, m1) (p1, p) s 116 | -- l2 :: m1' p1' a' -> (m', m1') (p1', p') s' 117 | -- l3 :: m1 p1 (a, a') -> ((m, m'), m1) (p1, (p, p')) (s, s') 118 | prodLensT l1 l2 = 119 | dimapP unAssoc assoc . -- ((m, m'), m1) (p1, (p, p')) (s, s') 120 | dimapM unAssoc assoc . -- 121 | dimapP (second unLunit) (second lunit) . -- (m, (m', m1)) ((p1, p), p') (s, s') 122 | dimapM (first runit) (first unRunit) . -- 123 | unPRight . l2 . PRight . -- ((m, ()), (m', m1)) ((p1, p), ((), p')) (s, s') 124 | unPLeft . l1 . PLeft . -- ((m, ()), m1) ((p1, p), ()) (s, a') 125 | dimapP runit unRunit . -- ((), m1) (p1, ()) (a, a') 126 | dimapM unLunit lunit -- ((), m1) p1 (a, a') 127 | 128 | -- Create a vector of n identical lenses in parallel 129 | 130 | vecLens :: Int -> TriLens a da m dm p dp s ds -> 131 | TriLens [a] [da] [m] [dm] [p] [dp] [s] [ds] 132 | -- m1 p1 [a] -> ([m], m1) (p1, [p]) [s] 133 | vecLens 0 _ = nilLens 134 | vecLens n l = consLens l (vecLens (n - 1) l) 135 | 136 | nilLens :: TriLens [a] [da] [m] [dm] [p] [dp] [s] [ds] 137 | -- m1 p1 [a] -> ([m], m1) (p1, [p]) [s] 138 | nilLens = dimapM ([], ) snd . 139 | dimapP fst (, []) . 140 | dimapS (const []) (const []) 141 | 142 | consLens :: TriLens a da m dm p dp s ds -> 143 | TriLens [a] [da] [m] [dm] [p] [dp] [s] [ds] -> 144 | TriLens [a] [da] [m] [dm] [p] [dp] [s] [ds] 145 | -- l1 :: m1 p1 a -> (m, m1) (p1, p) s 146 | -- l2 :: m2 p2 [a] -> ([m], m2) (p2, [p]) [s] 147 | -- l3 :: m3 p3 [a] -> ([m], m3) (p3, [p]) [s] 148 | consLens l1 l2 = 149 | dimapP (second unCons) (second cons) . 150 | dimapM (first cons) (first unCons) . 151 | dimapS unCons cons . 
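-- unCons/cons mediate between a list and its head/tail pair, so that
-- prodLensT below can apply l1 to the head and l2 to the tail;
-- e.g. unCons [p1, p2, p3] = (p1, [p2, p3]), and cons undoes it.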
152 | prodLensT l1 l2 . -- m3 p3 (a, [a]) -> ((m, [m]), m3) (p3, (p, [p])) (s, [s]) 153 | dimapS cons unCons 154 | 155 | cons :: (a, [a]) -> [a] 156 | cons (a, as) = a : as 157 | unCons :: [a] -> (a, [a]) 158 | unCons (a : as) = (a, as) 159 | 160 | 161 | -- This is for training neural networks. Instead of running training 162 | -- data through the network in series, we can run a whole batch in parallel 163 | -- and accumulate the parameter gradients for the next step. 164 | 165 | -- A batch of lenses in parallel, sharing the same parameters 166 | -- Back propagation combines the parameter gradients 167 | batchN :: (VSpace dp) => 168 | Int -> TriLens a da m dm p dp s ds -> TriLens [a] [da] [m] [dm] p dp [s] [ds] 169 | -- l :: m1 p1 a -> (m, m1) (p1, p) s 170 | -- vec :: m1 p1 [a] -> ([m], m1) (p1, [p]) [s] 171 | -- out :: m1 p1 [a] -> ([m], m1) (p1, p) [s] 172 | batchN n l = 173 | dimapP (second (replicate n)) (second accumulate) . vecLens n l 174 | 175 | -- A splitter combinator 176 | -- The simplest example of a combinator for connecting layers 177 | 178 | branch :: Monoid s => Int -> TriLens [s] [s] () () () () s s 179 | -- m1 p1 [s] -> ((), m1) (p1, ()) s 180 | branch n = 181 | dimapM unLunit lunit . 182 | dimapP runit unRunit . 183 | dimapS (replicate n) mconcat 184 | 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepLearning 2 | Neural networks using lenses in Haskell and Idris 3 | -------------------------------------------------------------------------------- /src/HVect.idr: -------------------------------------------------------------------------------- 1 | module HVect 2 | import Data.Vect 3 | --import Data.Vect.Quantifiers 4 | 5 | -- Heterogeneous vector 6 | public export 7 | data HVect : Vect n Type -> Type where 8 | Nil : HVect Nil 9 | (::) : h -> HVect t -> HVect (h :: t) 10 | 11 | export 12 | Show (HVect []) where 13 | show Nil = "\n" 14 | 15 | export 16 | (Show t, Show (HVect ts)) => Show (HVect (t :: ts)) where 17 | show (x :: xs) = show x ++ " :: " ++ show xs 18 | 19 | export 20 | Semigroup (HVect []) where 21 | [] <+> [] = [] 22 | 23 | export 24 | (Semigroup t, Semigroup (HVect ts)) => 25 | Semigroup (HVect (t :: ts)) where 26 | (a :: as) <+> (b :: bs) = (a <+> b) :: (as <+> bs) 27 | 28 | export 29 | Monoid (HVect []) where 30 | neutral = [] 31 | 32 | export 33 | (Monoid t, Monoid (HVect ts)) => Monoid (HVect (t :: ts)) where 34 | neutral = neutral :: neutral 35 | 36 | 37 | public export 38 | interface (Semigroup v, Monoid v) => VSpace v where 39 | scale : Double -> v -> v 40 | 41 | export 42 | VSpace (HVect []) where 43 | scale a Nil = Nil 44 | 45 | export 46 | (VSpace t, VSpace (HVect ts)) => VSpace (HVect (t :: ts)) where 47 | scale a (v :: vs) = scale a v :: scale a vs 48 | 49 | 50 | -- Replicate a vector of types 51 | -- map (Vect k) ts 52 | public export 53 | 0 ReplTypes : {l : Nat} -> (k : Nat) -> (ts : Vect l Type) -> Vect l Type 54 | ReplTypes k [] = [] 55 | ReplTypes k (t' :: ts') = Vect k t' :: ReplTypes k ts' 56 | 57 | -- Concatenate vectors of heterogeneous monoid types 58 | export 59 | concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} -> 60 | HVect (ReplTypes k ts) -> HVect ts 61 | concatH {l = 0} {ts = []} Nil = Nil 62 | concatH {ts = t' :: ts'} {isMono = (_ :: pfs)} (v :: vs) = concat v :: concatH {isMono = pfs} vs 63 | 64 | export 65 | emptyVTypes : {l : Nat} -> (ts : Vect l Type) -> HVect (ReplTypes 0 ts) 66 | 
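-- Builds the HVect of empty vectors, one Vect 0 t for each type t in ts;
-- e.g. for two types a and b it produces [[], []].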
emptyVTypes [] = Nil 67 | emptyVTypes (t' :: ts') = [] :: emptyVTypes ts' 68 | 69 | -- Generalization of zipWith (::) 70 | export 71 | zipCons : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 72 | HVect ts -> HVect (ReplTypes k ts) -> HVect (ReplTypes (S k) ts) 73 | zipCons [] [] = [] 74 | zipCons (t' :: ts') (vs :: vss) = (t' :: vs) :: zipCons ts' vss 75 | 76 | -- Transpose a vector whose entries are heterogeneous vectors 77 | export 78 | transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 79 | Vect k (HVect ts) -> HVect (ReplTypes k ts) 80 | transposeH {k=0} {ts} [] = emptyVTypes ts 81 | transposeH (h :: hs) = zipCons h (transposeH hs) 82 | 83 | export 84 | unfoldl : (n : Nat) -> (s -> (a, s)) -> s -> (Vect n a, s) 85 | unfoldl Z f s = (Nil, s) 86 | unfoldl (S k) f s = 87 | let (x, s') = f s 88 | (xs, s'') = unfoldl k f s' 89 | in (x :: xs, s'') 90 | 91 | export 92 | takes : (n : Nat) -> Stream a -> (Vect n a, Stream a) 93 | takes Z s = (Nil, s) 94 | takes (S k) (x :: xs) = 95 | let (v, s') = takes k xs 96 | in (x :: v, s') 97 | 98 | export 99 | take1 : Stream a -> (a, Stream a) 100 | take1 (x :: xs) = (x, xs) -------------------------------------------------------------------------------- /src/Main.idr: -------------------------------------------------------------------------------- 1 | module Main 2 | import Data.Bits 3 | import Data.Vect 4 | import HVect 5 | import PLens 6 | import NNet 7 | import Perceptron 8 | 9 | -- Simple random number generator 10 | 11 | random : Int32 -> Stream Int32 12 | random seed = let seed' = 1664525 * seed + 1013904333 13 | in (seed' `shiftR` 2) :: random seed' 14 | 15 | -- Stream of pseudo-random doubles [-1, 1] 16 | rands : Stream Double 17 | rands = map normalize (random 42) 18 | where 19 | normalize : Int32 -> Double 20 | normalize n = (fromInteger (cast n)) / fromInteger(2147483647) 21 | 22 | 23 | run : (Layout i o) -> Vect i Double -> Vect (Vect.last o) Double 24 | run ly@(MkLayout ins l layers) v = 25 | let mlp : PLens (HVect (ParaChain ly)) (V ins) (V (last layers)) := makeMLP ly 26 | paras : HVect (ParaChain ly) := fst (initParaChain ly rands) 27 | in mlp.fwd (paras, v) 28 | 29 | main : IO () 30 | main = do 31 | let ly = MkLayout 3 2 [4,4,1] 32 | let inpt = [1, 2, 3] 33 | printLn $ run ly inpt -------------------------------------------------------------------------------- /src/NNet.idr: -------------------------------------------------------------------------------- 1 | module NNet 2 | import Data.Vect 3 | import HVect 4 | import PLens 5 | 6 | -- Vector of Double 7 | public export 8 | 0 V : (n : Nat) -> Type 9 | V n = Vect n Double 10 | 11 | -- Parameters for an affine lens of 12 | -- mIn inputs, one output 13 | public export 14 | record Para (mIn : Nat) where 15 | constructor MkPara 16 | weight : Vect mIn Double 17 | bias : Double 18 | 19 | -- ParaBlock mIn nOut, vector of nOut parameters, each mIn wide 20 | -- (nOut * mIn inputs, nOut outputs) 21 | public export 22 | ParaBlock : (mIn : Nat) -> (nOut : Nat) -> Type 23 | ParaBlock mIn nOut = Vect nOut (Para mIn) 24 | 25 | ------------- 26 | -- Interfaces 27 | ------------- 28 | export 29 | {mIn : Nat} -> Show (Para mIn) where 30 | show pa = "weight: " ++ show (weight pa) ++ " bias: " ++ show (bias pa) ++ "\n" 31 | 32 | -- Semigroup 33 | 34 | Semigroup Double where 35 | x <+> y = x + y 36 | 37 | export 38 | Semigroup (Para m) where 39 | (MkPara w b) <+> (MkPara w' b') = MkPara (zipWith (+) w w') (b + b') 40 | 41 | -- Monoid 42 | Monoid Double where 43 | neutral = 0.0 44 | 45 | {m : Nat} 
-> Monoid (Para m) where 46 | neutral = MkPara (replicate m 0.0) 0.0 47 | 48 | -- Vector Space 49 | export 50 | {m : Nat} -> VSpace (Para m) where 51 | scale a (MkPara w b) = MkPara (map (a *) w) (a * b) 52 | 53 | export 54 | {mIn : Nat} -> {nOut : Nat} -> VSpace (ParaBlock mIn nOut) where 55 | scale a v = map (scale a) v 56 | 57 | ------------------------------------- 58 | ------- Vector parametric lenses ---- 59 | ------------------------------------- 60 | 61 | -- activation lens using tanh (no parameters) 62 | 63 | activ : Lens Double Double 64 | activ = MkLens (\s => tanh s) 65 | (\(s, a) => a * (1 - (tanh s)*(tanh s))) -- a * da/ds 66 | 67 | -- Affine parametric lens 68 | -- (really a composition of linear and bias, but they are always used in combination) 69 | 70 | affine : (mIn : Nat) -> PLens (Para mIn) (V mIn) Double 71 | affine mIn = MkPLens fwd' bwd' 72 | where 73 | fwd' : (Para mIn, V mIn) -> Double 74 | fwd' (p, s) = foldl (+) (bias p) (zipWith (*) (weight p) s) -- a = b + w * s 75 | bwd' : (Para mIn, V mIn, Double) -> (Para mIn, V mIn) 76 | bwd' (p, s, a) = ( MkPara (map (a *) s) a -- (da/dw, da/db) 77 | , map (a *) (weight p)) -- da/ds 78 | 79 | -- Initialize parameters for an affine lens 80 | initPara : (mIn : Nat) -> Stream Double -> (Para mIn, Stream Double) 81 | initPara mIn s = 82 | let (v, s') = takes mIn s 83 | (x, s'') = take1 s' 84 | in (MkPara v x, s'') 85 | 86 | 87 | -- Neuron with mIn inputs and one output 88 | 89 | -- affine : PLens (Para mIn) (V mIn) Double 90 | -- activ : Lens Double Double 91 | -- composite : PLens (Para mIn) (V mIn) Double 92 | 93 | export 94 | neuron : (mIn : Nat) -> PLens (Para mIn) (V mIn) Double 95 | neuron mIn = composeR (affine mIn) activ 96 | 97 | 98 | -- A layer of neurons 99 | 100 | -- n neurons with m inputs each 101 | -- 1 2 .. 
n 102 | -- | | | 103 | -- m m m 104 | -- \ / \ / 105 | -- m 106 | -- ParaBlock m n = Vect n (Para m) 107 | -- neuron m : PLens (Para m) (V m) Double 108 | -- vecLens n (neuron m) : PLens (Vect n (Para m)) (Vect n (V m)) (Vect n Double) 109 | -- branch n : Lens (V m) (Vect n (V m)) 110 | -- s a 111 | -- composeL : Lens s a -> PLens p a b -> PLens p s b 112 | export 113 | layer : (nOut : Nat) -> (mIn : Nat) -> PLens (Vect nOut (Para mIn)) (V mIn) (V nOut) 114 | layer nOut mIn = composeL (branch nOut) (vecLens nOut (neuron mIn)) 115 | 116 | -- Initialize parameters for a layer of n neurons, each with m inputs 117 | -- ParaBlock mIn nOut, vector of nOut parameters, each mIn wide 118 | export 119 | initParaBlock : (mIn : Nat) -> (nOut : Nat) -> Stream Double -> 120 | (ParaBlock mIn nOut, Stream Double) 121 | initParaBlock mIn nOut s = unfoldl nOut (initPara mIn) s 122 | 123 | -- mean square error 0.5 * Sum (si - gi)^2 124 | -- derivative: d/dsi = (si - gi) 125 | delta : V n -> V n -> Double 126 | delta s g = 0.5 * (sum $ map (\x => x * x) (zipWith (-) s g)) 127 | 128 | -- Sum of squares loss lens 129 | export 130 | loss : V n -> Lens (V n) Double 131 | loss gtruth = MkLens (\s => delta s gtruth) 132 | (\(s, a) => backLoss gtruth s a) 133 | where 134 | backLoss : V n -> V n -> Double -> V n 135 | backLoss g s a = map (a *) (zipWith (-) s g) 136 | -------------------------------------------------------------------------------- /src/PLens.idr: -------------------------------------------------------------------------------- 1 | module PLens 2 | import Data.Vect 3 | import HVect 4 | 5 | ---------- 6 | -- Parametric lens 7 | 8 | -- record PLens p p' s s' a a' 9 | -- fwd : (p, s) -> a 10 | -- bwd : (p, s, a') -> (p', s') 11 | 12 | public export 13 | record PLens p s a where 14 | constructor MkPLens 15 | fwd : (p, s) -> a 16 | bwd : (p, s, a) -> (p, s) 17 | 18 | -- Special case of parametric lens with p = () 19 | -- Simplifies composition 20 | 21 | public export 22 | record Lens s a where 23 | constructor MkLens 24 | fwd0 : s -> a 25 | bwd0 : (s, a) -> s 26 | 27 | -- Composition of parametric lenses 28 | export 29 | compose : PLens p s a -> PLens q a b -> 30 | PLens (p, q) s b 31 | -- lens1.fwd : (p, s) -> a 32 | -- lens1.bwd : (p, s, a) -> (p, s) 33 | -- lens2.fwd : (q, a) -> b 34 | -- lens2.bwd : (q, a, b) -> (q, a) 35 | compose lens1 lens2 = MkPLens fwd' bwd' 36 | where 37 | fwd' : ((p, q), s) -> b 38 | fwd' ((p, q), s) = lens2.fwd (q, lens1.fwd (p, s)) 39 | bwd' : ((p, q), s, b) -> ((p, q), s) 40 | bwd' ((p, q), s, b) = 41 | let (q', a') = lens2.bwd (q, lens1.fwd (p, s), b) 42 | (p', s') = lens1.bwd (p, s, a') 43 | in ((p', q'), s') 44 | 45 | -- Helpers for composing a parametric lens with a non-parametric one 46 | export 47 | composeR : PLens p s a -> Lens a b -> 48 | PLens p s b 49 | composeR lens1 lens2 = MkPLens fwd' bwd' 50 | where 51 | fwd' : (p, s) -> b 52 | fwd' (p, s) = lens2.fwd0 (lens1.fwd (p, s)) 53 | bwd' : (p, s, b) -> (p, s) 54 | bwd' (p, s, b) = 55 | let a' = lens2.bwd0 (lens1.fwd (p, s), b) 56 | (p', s') = lens1.bwd (p, s, a') 57 | in (p', s') 58 | export 59 | composeL : Lens s a -> PLens p a b -> 60 | PLens p s b 61 | -- lens1.fwd : s -> a 62 | -- lens1.bwd : (s, a) -> s 63 | -- lens2.fwd : (p, a) -> b 64 | -- lens2.bwd : (p, a, b) -> (p, a) 65 | composeL lens1 lens2 = MkPLens fwd' bwd' 66 | where 67 | fwd' : (p, s) -> b 68 | fwd' (p, s) = lens2.fwd (p, lens1.fwd0 s) 69 | bwd' : (p, s, b) -> (p, s) 70 | bwd' (p, s, b) = 71 | let (p', a') = lens2.bwd (p, 
lens1.fwd0 s, b) 72 | s' = lens1.bwd0 (s, a') 73 | in (p', s') 74 | 75 | -- Product of parametric lenses 76 | prodLens : 77 | PLens p s a -> 78 | PLens p' s' a' -> 79 | PLens (p, p') (s, s') (a, a') 80 | -- lens1.fwd : (p, s) -> a 81 | -- lens1.bwd : (p, s, a) -> (p, s) 82 | prodLens lens1 lens2 = 83 | MkPLens fwdProd bwdProd 84 | where 85 | fwdProd : ((p, p'), (s, s')) -> (a, a') 86 | fwdProd ((p, p'), (s, s')) = (lens1.fwd (p, s), lens2.fwd (p', s')) 87 | 88 | bwdProd : ((p, p'), (s, s'), (a, a')) -> ((p, p'), (s, s')) 89 | bwdProd ((p, p'), (s, s'), (a, a')) = 90 | let (q, t) = lens1.bwd (p, s, a) 91 | (q', t') = lens2.bwd (p', s', a') 92 | in ((q, q'), (t, t')) 93 | 94 | -- duplicate a lens in parallel n times 95 | export 96 | vecLens : (n : Nat) -> PLens p s a -> PLens (Vect n p) (Vect n s) (Vect n a) 97 | vecLens Z _ = MkPLens (\(Nil, Nil) => Nil) (\(Nil, Nil, Nil) => (Nil, Nil)) 98 | vecLens (S n) lns = MkPLens fwd' bwd' 99 | where 100 | lnsN : PLens (Vect n p) (Vect n s) (Vect n a) 101 | lnsN = vecLens n lns 102 | fwd' : (Vect (S n) p, Vect (S n) s) -> Vect (S n) a 103 | fwd' (p :: ps, s :: ss) = lns.fwd (p, s) :: lnsN.fwd (ps, ss) 104 | bwd' : (Vect (S n) p, Vect (S n) s, Vect (S n) a) -> (Vect (S n) p, Vect (S n) s) 105 | bwd' (p :: ps, s :: ss, a :: as) = 106 | let (p', s') = lns.bwd (p, s, a) 107 | (ps', ss') = lnsN.bwd (ps, ss, as) 108 | in (p' :: ps', s' :: ss') 109 | 110 | -- A branching combinator 111 | export 112 | branch : Monoid s => (n : Nat) -> Lens s (Vect n s) 113 | branch n = MkLens (replicate n) (\(_, ss) => concat ss) -- pointwise <+> 114 | 115 | -- Batch n lenses in parallel sharing the same parameters 116 | -- input and output are n-tupled, parameters are collected 117 | batch : Monoid p => 118 | (n : Nat) -> 119 | PLens p s a -> 120 | PLens p (Vect n s) (Vect n a) 121 | batch n lns = 122 | MkPLens fwdB bwdB 123 | where 124 | fwdB : (p, Vect n s) -> Vect n a 125 | fwdB (p, ss) = map lns.fwd (zip (replicate n p) ss) 126 | bwdB : (p, Vect n s, Vect n a) -> (p, Vect n s) 127 | bwdB (p, ss, as) = 128 | let (ps', ss') = unzip $ map lns.bwd $ zip3 (replicate n p) ss as 129 | in (concat ps', ss') 130 | 131 | 132 | 133 | -- xs = [1, 2, 3, 4, 5, 6] 134 | -- vw = [[1, 2, 3], [4, 5, 6]] m=3 n=2 135 | export 136 | rechunk : (m : Nat) -> (n : Nat) -> Vect (n * m) a -> Vect n (Vect m a) 137 | rechunk m Z xs = [] 138 | rechunk m (S k) xs = take m xs :: rechunk m k (drop m xs) 139 | 140 | -- A connector lens that flattens an n x m array 141 | export 142 | flatten : {m : Nat} -> {n : Nat} -> Lens (Vect n (Vect m s)) (Vect (n * m) s) 143 | flatten = MkLens fwd' bwd' 144 | where 145 | fwd' : Vect n (Vect m s) -> Vect (n*m) s 146 | fwd' vs = concat vs 147 | bwd' : {m : Nat} -> {n : Nat} -> (Vect n (Vect m s), Vect (n*m) s) -> (Vect n (Vect m s)) 148 | bwd' {m} {n} (vs, w) = (rechunk m n w) 149 | 150 | -- Batches n identical neural networks sharing the same parameters 151 | -- Use it for batch training 152 | export 153 | batchN : (n : Nat) -> 154 | (Vect n p -> p) -> 155 | PLens p s a -> 156 | PLens p (Vect n s) (Vect n a) 157 | batchN n collectP lns = 158 | MkPLens fwdB bckB 159 | where 160 | fwdB : (p, Vect n s) -> Vect n a 161 | fwdB (p, ss) = map lns.fwd (zip (replicate n p) ss) 162 | bckB : (p, Vect n s, Vect n a) -> (p, Vect n s) 163 | bckB (p, ss, as) = 164 | let (ps', ss') = unzip $ map lns.bwd $ zip3 (replicate n p) ss as 165 | in (collectP ps', ss') 166 | 167 | -- Produce a singleton vector 168 | export 169 | single : Lens s (Vect 1 s) 170 | single = MkLens fwd bwd 171 | where 
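-- fwd wraps a value in a one-element vector; bwd projects it back out,
-- discarding the original source.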
172 | fwd : s -> Vect 1 s 173 | fwd s = [s] 174 | bwd : (s, Vect 1 s) -> s 175 | bwd (_, [s]) = s 176 | -------------------------------------------------------------------------------- /src/Perceptron.idr: -------------------------------------------------------------------------------- 1 | module Perceptron 2 | import Data.Vect 3 | import HVect 4 | import PLens 5 | import NNet 6 | 7 | -- The architecture is specified by the number of inputs mIn and a list of l+1 layer sizes ns 8 | -- where l is the number of hidden layers (zero or more) 9 | -- mIn -> [mIn, n1] -> [n1, n2] -> ... [n l, n (l+1)] 10 | 11 | public export 12 | data Layout : (mIn : Nat) -> (layers : Vect (S l) Nat) -> Type where 13 | MkLayout : (mIn : Nat) -> (l : Nat) -> (layers : Vect (S l) Nat) -> Layout mIn layers 14 | 15 | export 16 | inN : Layout i ls -> Nat 17 | inN (MkLayout mIn l _) = mIn 18 | 19 | export 20 | outN : Layout i ls -> Nat 21 | outN (MkLayout _ l layers) = last layers 22 | 23 | -- Chain of parameter blocks 24 | -- Parameters for multi-layer perceptron with mIn inputs 25 | -- A chain of types: 26 | -- ParaBlock mIn n1 :: ParaBlock n1 n2 :: ParaBlock n2 n3 ... 27 | public export 28 | ParaChain : Layout mIn ns -> Vect (length ns) Type 29 | ParaChain ly with (ly) 30 | _ | (MkLayout m Z [n]) = [ParaBlock m n] 31 | _ | (MkLayout m (S l) (n1 :: n2 :: ns')) = ParaBlock m n1 :: ParaChain (MkLayout n1 l (n2 :: ns')) 32 | 33 | 34 | -- Chain of vectors of parameter blocks (for batches of perceptrons) 35 | 0 VParaChain : (k : Nat) -> Layout mIn ns -> Vect (Vect.length ns) Type 36 | VParaChain k ly = ReplTypes k (ParaChain ly) 37 | 38 | -- Proof that every type in ParaChain is a Monoid 39 | isMonoChain : (ly : Layout mIn ns) -> HVect (map Monoid (ParaChain ly)) 40 | isMonoChain ly with (ly) 41 | _ | (MkLayout m Z [n]) = [%search] 42 | _ | (MkLayout m (S l) (n1 :: n2 :: ns')) = (%search) :: isMonoChain (MkLayout n1 l (n2 :: ns')) 43 | 44 | 45 | -- Proof that every type in ParaChain is a Show 46 | isShowChain : (ly : Layout mIn ns) -> HVect (map Show (ParaChain ly)) 47 | isShowChain ly with (ly) 48 | _ | (MkLayout m Z [n]) = [%search] 49 | _ | (MkLayout m (S l) (n1 :: n2 :: ns')) = (%search) :: isShowChain (MkLayout n1 l (n2 :: ns')) 50 | 51 | 52 | -- Proof that every type in ParaChain is a VSpace 53 | isVSpaceChain : (ly : Layout mIn ns) -> HVect (map VSpace (ParaChain ly)) 54 | isVSpaceChain ly with (ly) 55 | _ | (MkLayout m Z [n]) = [%search] 56 | _ | (MkLayout m (S l) (n1 :: n2 :: ns')) = (%search) :: isVSpaceChain (MkLayout n1 l (n2 :: ns')) 57 | 58 | 59 | -- concatH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> {isMono : HVect (map Monoid ts)} -> 60 | -- HVect (ReplTypes k ts) -> HVect ts 61 | 62 | -- in (ParaChain mIn ns), all types are Monoid 63 | collectH : {k : Nat} -> {mIn : Nat} -> {layers : Vect (S l) Nat} -> (ly : Layout mIn layers) -> 64 | HVect (VParaChain k ly) -> HVect (ParaChain ly) 65 | collectH ly@(MkLayout mIn l layers) hv = 66 | concatH {ts = ParaChain ly} {isMono = isMonoChain ly} hv 67 | 68 | -- transposeH : {k : Nat} -> {l : Nat} -> {ts : Vect l Type} -> 69 | -- Vect k (HVect ts) -> HVect (ReplTypes k ts) 70 | 71 | export 72 | collectParas : {k : Nat} -> {mIn : Nat} -> {l : Nat} -> {layers : Vect (S l) Nat} -> 73 | (ly : Layout mIn layers) -> 74 | Vect k (HVect (ParaChain ly)) -> HVect (ParaChain ly) 75 | collectParas ly@(MkLayout mIn l layers) v = collectH ly (transposeH v) 76 | 77 | export 78 | showParas : (ly : Layout mIn layers) -> HVect (ParaChain ly) -> String 79 | showParas (MkLayout mIn Z [n]) v = show v 80 | 
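-- Recursive case: show the parameter block of the first layer, then
-- recurse on the layout of the remaining layers.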
showParas (MkLayout mIn (S l) (n1 :: n2 :: ns)) (p :: ps) = show p ++ ", " ++ 81 | showParas (MkLayout n1 l (n2 :: ns)) ps 82 | 83 | public export 84 | {ly : Layout mIn layers} -> Show (HVect (ParaChain ly)) where 85 | show {ly} ps = showParas ly ps 86 | 87 | -- Multi-layer perceptron with m inputs and l+1 layers 88 | -- neuron count in each layer is given by (Vect (S l) Nat) 89 | 90 | -- 1 2 .. n2 [n2] 91 | -- n1 n1 n1 92 | -- |/ \|/ \| 93 | -- 1 2 .. n1 [n1] <-n1- [P[m], P[m] .. P[m]] 94 | -- m m m 95 | -- \ / \ / 96 | -- m 97 | 98 | export 99 | makeMLP : (ly : Layout mIn layers) -> 100 | PLens (HVect (ParaChain ly)) (V mIn) (V (last layers)) 101 | makeMLP (MkLayout mIn Z [nOut]) = MkPLens fwd' bwd' 102 | where 103 | lr : PLens (ParaBlock mIn nOut) (V mIn) (V nOut) 104 | lr = layer nOut mIn 105 | -- new layout with one layer 106 | Ly : Layout mIn [nOut] -- must be capitalized or the magic won't happen 107 | Ly = MkLayout mIn Z [nOut] 108 | 109 | fwd' : (HVect (ParaChain Ly), V mIn) -> V (nOut) 110 | fwd' ([p], v) = lr.fwd (p, v) 111 | bwd' : (HVect (ParaChain Ly), V mIn, V nOut) -> (HVect (ParaChain Ly), V mIn) 112 | bwd' ([p], v, w) = let (p', v') = lr.bwd (p, v, w) 113 | in ([p'], v') 114 | makeMLP (MkLayout mIn (S l) (n1 :: n2 :: ns)) = MkPLens fwd' bwd' 115 | where 116 | -- m -> [m, n1] -> [n1, n2] -> ... [n l, n (l+1)] 117 | -- Layout for the recursive part 118 | Ly : Layout n1 (n2 :: ns) 119 | Ly = MkLayout n1 l (n2 :: ns) 120 | mlp' : PLens (HVect (ParaChain Ly)) (V (inN Ly)) (V (outN Ly)) 121 | mlp' = makeMLP Ly --<< recurse 122 | -- compose with the bottom layer 123 | mlpComp : PLens (ParaBlock mIn n1, HVect (ParaChain Ly)) 124 | (V mIn) 125 | (V (last (n2 :: ns))) 126 | mlpComp = compose (layer n1 mIn) mlp' 127 | -- New layout for the composite 128 | Ly' : Layout mIn (n1 :: n2 :: ns) 129 | Ly' = MkLayout mIn (S l) (n1 :: n2 :: ns) 130 | 131 | fwd' : (HVect (ParaChain Ly'), V mIn) -> V (last (n1 :: n2 :: ns)) 132 | fwd' (p1 :: ps, vm) = mlpComp.fwd ((p1, ps), vm) 133 | bwd' : (HVect (ParaChain Ly'), V mIn, V (last (n1 :: n2 :: ns))) -> 134 | (HVect (ParaChain Ly'), V mIn) 135 | bwd' (pmn1 :: pmns, s, a) = 136 | let ((pmn1', pmns'), s') = mlpComp.bwd ((pmn1, pmns), s, a) 137 | in (pmn1' :: pmns', s') 138 | 139 | 140 | -- Initialize parameters for an MLP 141 | 142 | export 143 | initParaChain : (ly : Layout mIn layers) -> 144 | Stream Double -> (HVect (ParaChain ly), Stream Double) 145 | initParaChain (MkLayout m Z ([n])) s = 146 | let (pb, s') = initParaBlock m n s 147 | in ([pb], s') 148 | initParaChain (MkLayout m (S l) (n1 :: n2 :: ns)) s = 149 | let (pb, s') = initParaBlock m n1 s 150 | (pbs, s'') = initParaChain (MkLayout n1 l (n2 :: ns)) s' 151 | in (pb :: pbs, s'') 152 | 153 | --------------------------------------------------------------------------------