├── .gitignore ├── .styx ├── pretty-compact.nix ├── shell.nix └── typedflow.nix ├── LICENSE ├── Makefile ├── README.org ├── TypedFlow.hs ├── TypedFlow ├── Abstract.hs ├── Broadcast.hs ├── Haskell.hs ├── Layers.hs ├── Layers │ ├── Core.hs │ ├── RNN.hs │ └── RNN │ │ ├── Attention.hs │ │ ├── Base.hs │ │ └── Cells.hs ├── Learn.hs ├── Memo.hs ├── Memo2.hs ├── Models │ ├── Topic.hs │ └── Transformer.hs ├── Python.hs ├── TF.hs ├── Types.hs └── Types │ └── Proofs.hs ├── cabal.project ├── docs ├── HOT.org ├── Talk.org ├── cards.jpg └── imperiallegion.jpg ├── examples ├── agreement │ └── Aggr.hs ├── mnist │ ├── MNIST.hs │ ├── Makefile │ ├── main.py │ └── mnist_model.py └── seq2seq │ ├── GenTr.hs │ ├── Makefile │ ├── Seq2Seq.hs │ ├── main.py │ └── shell.nix ├── styx.yaml ├── typedflow.cabal └── typedflow_rts.py /.gitignore: -------------------------------------------------------------------------------- 1 | .styx 2 | *~ 3 | dist 4 | dist-* 5 | cabal-dev 6 | *.o 7 | *.hi 8 | *.chi 9 | *.chs.h 10 | *.dyn_o 11 | *.dyn_hi 12 | .hpc 13 | .hsenv 14 | .cabal-sandbox/ 15 | cabal.sandbox.config 16 | *.prof 17 | *.aux 18 | *.hp 19 | *.eventlog 20 | .stack-work/ 21 | cabal.project.local 22 | .HTF/ 23 | /examples/seq2seq/s2s.py 24 | /examples/seq2seq/synthtrees.txt 25 | MNIST_data 26 | __pycache__ 27 | /examples/seq2seq/GenTr 28 | /.tramp_history 29 | -------------------------------------------------------------------------------- /.styx/pretty-compact.nix: -------------------------------------------------------------------------------- 1 | { mkDerivation, aeson, base, base-compat, bytestring, containers 2 | , criterion, deepseq, fetchgit, pretty, stdenv, text 3 | , unordered-containers, wl-pprint 4 | }: 5 | mkDerivation { 6 | pname = "pretty-compact"; 7 | version = "3.0"; 8 | src = fetchgit { 9 | url = "git@github.com:jyp/prettiest.git"; 10 | sha256 = "0m8bjpc1pwzfkdzq7fgji81yffwn91ywybvmnazmy2b47rg24wjf"; 11 | rev = "a36f4ea19eed4ece78f7c939a1bc73a3393386a2"; 12 | }; 13 | libraryHaskellDepends = [ base base-compat containers ]; 14 | benchmarkHaskellDepends = [ 15 | aeson base base-compat bytestring criterion deepseq pretty text 16 | unordered-containers wl-pprint 17 | ]; 18 | description = "Pretty-printing library"; 19 | license = "GPL"; 20 | } 21 | -------------------------------------------------------------------------------- /.styx/shell.nix: -------------------------------------------------------------------------------- 1 | { nixpkgs ? import {} 2 | 3 | }: 4 | let nixpkgs_source = 5 | fetchTarball "https://github.com/NixOS/nixpkgs/archive/nixos-21.05.tar.gz"; 6 | nixpkgs' = (import nixpkgs_source){}; 7 | in with nixpkgs'.pkgs; 8 | let hp = haskellPackages.override{ 9 | overrides = self: super: { 10 | typedflow = self.callPackage ./typedflow.nix {}; 11 | };}; 12 | getHaskellDeps = ps: path: 13 | let f = import path; 14 | gatherDeps = { buildDepends ? [], libraryHaskellDepends ? [], executableHaskellDepends ? [], libraryToolDepends ? [], executableToolDepends ? 
[], ...}: 15 | buildDepends ++ libraryHaskellDepends ++ executableHaskellDepends ++ libraryToolDepends ++ executableToolDepends; 16 | x = f (builtins.intersectAttrs (builtins.functionArgs f) 17 | (ps // 18 | nixpkgs'.pkgs) # can also depend on non-haskell packages 19 | // {lib = lib; mkDerivation = gatherDeps;}); 20 | in x; 21 | ghc = hp.ghcWithPackages (ps: with ps; lib.lists.subtractLists 22 | [typedflow] 23 | ([ cabal-install 24 | QuickCheck hscolour 25 | ] ++ getHaskellDeps ps ./typedflow.nix)); 26 | in 27 | pkgs.stdenv.mkDerivation { 28 | name = "my-haskell-env-0"; 29 | buildInputs = [ glibcLocales ghc ]; 30 | shellHook = '' 31 | export LANG=en_US.UTF-8 32 | eval $(egrep ^export ${ghc}/bin/ghc) 33 | ''; 34 | } 35 | -------------------------------------------------------------------------------- /.styx/typedflow.nix: -------------------------------------------------------------------------------- 1 | { mkDerivation, base, containers, ghc-typelits-knownnat, lib, mtl 2 | , prettyprinter 3 | }: 4 | mkDerivation { 5 | pname = "typedflow"; 6 | version = "0.9"; 7 | src = /home/jyp/repo/gu/TypedFlow; 8 | libraryHaskellDepends = [ 9 | base containers ghc-typelits-knownnat mtl prettyprinter 10 | ]; 11 | description = "Typed frontend to TensorFlow and higher-order deep learning"; 12 | license = lib.licenses.lgpl3Only; 13 | } 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 
43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 
115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | viewdoc: dist/doc/html/typedflow/index.html 3 | xdg-open $< 4 | 5 | dist/doc/html/typedflow/index.html: 6 | styx cabal -- haddock --hyperlink-source 7 | styx cabal -- hscolour 8 | 9 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | #+TITLE: TypedFlow 2 | 3 | TypedFlow is a typed, higher-order frontend to [[http://www.tensorflow.org][TensorFlow]] and a 4 | high-level library for deep-learning. 
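For a flavour of the interface, here is a small sketch of a two-layer
perceptron whose dimensions are tracked by the type checker. (This is
an illustration only, not one of the bundled examples: the =mlp= name
and the concrete sizes are invented, and it assumes the =Float32=
element-type synonym from =TypedFlow.Types=.)

#+BEGIN_SRC haskell
-- Hypothetical sketch: a two-layer perceptron built from the dense
-- layers of TypedFlow.Layers.Core. The sizes 784, 128 and 10 appear
-- in the types, so mismatched compositions are rejected at compile time.
mlp :: (DenseP Float32 784 128, DenseP Float32 128 10)
    -> Tensor '[784] Float32 -> Tensor '[10] Float32
mlp (layer1, layer2) x = layer2 # tanh (layer1 # x)
#+END_SRC

Because the parameters are plain values, sharing a layer between two
call sites amounts to passing the same =DenseP= record twice.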
5 | 6 | The main design principles are: 7 | 8 | - To make the parameters of layers explicit. This choice makes 9 | sharing of parameters explicit and makes it possible to implement "layers" as 10 | pure functions. 11 | 12 | - To provide types that are as precise as possible. Functions are explicit 13 | about the shapes and elements of the tensors that they manipulate 14 | (although they are often polymorphic in shapes and elements). 15 | 16 | - To let combinators be as transparent as possible. If an NN layer 17 | is a simple tensor transformation, it is exposed as such. 18 | 19 | 20 | In this version, the interface to TensorFlow is implemented via Python code 21 | generation and a suitable runtime system. 22 | 23 | ** Documentation 24 | 25 | The compiled documentation can be found on [[https://hackage.haskell.org/package/typedflow][hackage]]. 26 | 27 | ** Examples 28 | 29 | TypedFlow comes with two examples of neural networks: 30 | 31 | - An adaptation of the [[examples/mnist][MNIST tensorflow tutorial]] 32 | - A simple [[examples/seq2seq][sequence to sequence model]] which 33 | attempts to learn to translate pre-order into post-order. 34 | 35 | The examples can be run like so: 36 | 37 | #+BEGIN_SRC shell 38 | nix-env -iA nixpkgs.haskellPackages.styx 39 | nix-env -iA nixpkgs.cabal2nix 40 | styx configure 41 | cd examples/seq2seq 42 | make 43 | #+END_SRC 44 | 45 | -------------------------------------------------------------------------------- /TypedFlow.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow 3 | Description : Higher-Order Typed Binding to TensorFlow and Deep Learning Library 4 | Copyright : (c) Jean-Philippe Bernardy, 2017 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | 9 | This module re-exports all functions. 10 | -} 11 | 12 | module TypedFlow 13 | (module TypedFlow.Types 14 | ,module TypedFlow.TF 15 | ,module TypedFlow.Layers 16 | ,module TypedFlow.Learn 17 | ,module GHC.TypeLits) where 18 | 19 | import TypedFlow.TF 20 | import TypedFlow.Types 21 | import TypedFlow.Layers 22 | import TypedFlow.Learn 23 | import GHC.TypeLits 24 | 25 | -------------------------------------------------------------------------------- /TypedFlow/Haskell.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow.Haskell 3 | Description : Generation of the computation graph using tensorflow-haskell.
4 | Copyright : (c) Jean-Philippe Bernardy, 2017 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | 9 | -} 10 | 11 | {-# LANGUAGE AllowAmbiguousTypes #-} 12 | {-# LANGUAGE ConstraintKinds #-} 13 | {-# LANGUAGE DataKinds #-} 14 | {-# LANGUAGE DeriveFoldable #-} 15 | {-# LANGUAGE DeriveFunctor #-} 16 | {-# LANGUAGE DeriveTraversable #-} 17 | {-# LANGUAGE FlexibleContexts #-} 18 | {-# LANGUAGE FlexibleInstances #-} 19 | {-# LANGUAGE GADTs #-} 20 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 21 | {-# LANGUAGE LambdaCase #-} 22 | {-# LANGUAGE MagicHash #-} 23 | {-# LANGUAGE MultiParamTypeClasses #-} 24 | {-# LANGUAGE OverloadedStrings #-} 25 | {-# LANGUAGE PatternSynonyms #-} 26 | {-# LANGUAGE RankNTypes #-} 27 | {-# LANGUAGE RecordWildCards #-} 28 | {-# LANGUAGE ScopedTypeVariables #-} 29 | {-# LANGUAGE StandaloneDeriving #-} 30 | {-# LANGUAGE TypeApplications #-} 31 | {-# LANGUAGE TypeFamilies #-} 32 | {-# LANGUAGE TypeInType #-} 33 | {-# LANGUAGE TypeOperators #-} 34 | {-# LANGUAGE UndecidableInstances #-} 35 | {-# LANGUAGE UndecidableSuperClasses #-} 36 | {-# LANGUAGE UnicodeSyntax #-} 37 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 38 | 39 | module TypedFlow.Haskell where 40 | 41 | import Data.Type.Equality 42 | import Data.List (genericReplicate) 43 | import GHC.TypeLits 44 | import Control.Monad.State 45 | import TypedFlow.Types 46 | import TypedFlow.Types.Proofs 47 | import TypedFlow.Abstract (newId, permToFun, unopInputShape) 48 | import TypedFlow.Memo 49 | import System.Mem.StableName 50 | import System.IO.Unsafe 51 | 52 | import qualified Data.Int as Backend 53 | 54 | import qualified TensorFlow.Core as Backend 55 | import qualified TensorFlow.GenOps.Core as BackCore 56 | import qualified TensorFlow.Minimize as Backend 57 | import qualified TensorFlow.Ops as Backend 58 | import qualified TensorFlow.NN as Backend 59 | -- import qualified TensorFlow.Variable as Backend 60 | import qualified TensorFlow.Tensor 61 | 62 | import qualified Data.IntMap as IM 63 | import Data.IntMap (IntMap) 64 | 65 | type BackendShape = BackendTensor ('Typ 'Int 'B32) 66 | type BackendTensor t = Backend.Tensor Backend.Build (HaskType t) 67 | type BackendVariable t = Backend.Tensor Backend.Ref (HaskType t) 68 | type BackendTensorType t = Backend.TensorType (HaskType t) 69 | 70 | shapeFromType :: ∀ (s :: Shape). KnownShape s => BackendShape 71 | shapeFromType = shapeVector (typeSShape @s) 72 | 73 | -- | Show a shape, but "None" is replaced by "-1" 74 | shapeVector :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> BackendShape 75 | shapeVector s = shapeFromList (shapeToList'' s) 76 | 77 | permToTensor :: SShape s -> Permutation s t -> Backend.Tensor Backend.Build Backend.Int32 78 | permToTensor s p = Backend.vector (map (fromInteger . permToFun p) [0.. sListLength s]) 79 | 80 | shapeFromList :: [Integer] -> BackendShape 81 | shapeFromList = Backend.vector . map convertNone 82 | 83 | showShapeLen :: ∀ (s::Shape). KnownLen s => Backend.Int32 84 | showShapeLen = fromIntegral (listTypeLen @ s) 85 | 86 | convertNone :: Num a => Integer -> a 87 | convertNone n = (if n == 514229 then (-1) else fromIntegral n) 88 | 89 | -- runWithFeeds 90 | 91 | data BT (s :: Shape) (t :: Typ) where 92 | BT :: forall s t. (BackendTensor t) -> BT s t 93 | 94 | data HState = HState {genVars :: IntMap Var 95 | ,genPureTable :: SNMap22 Shape Typ T BT 96 | -- alternative: use tensorRefFromName and make this closer to the python backed. 
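-- genVars: the backend variables and placeholders created so far, keyed by the integer carried in their 'Ref'.
-- genPureTable: memo table for 'interpretPure', keyed by 'StableName', so shared subexpressions are translated to the backend graph only once.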
97 | } 98 | 99 | type BM a = Backend.BuildT (StateT HState (State GState)) a 100 | 101 | data Var = forall s t v. TensorFlow.Tensor.TensorKind v => Var (SShape s) (STyp t) (Backend.Tensor v (HaskType t)) 102 | 103 | initializedVariable :: forall s a. KnownShape s => KnownTyp a => T s a -> BM (Ref s a) 104 | initializedVariable initVal = do 105 | BT i <- interpretPure initVal 106 | x <- lift (lift newId) 107 | v <- backendTensor (typeSTyp @a) $ Backend.initializedVariable i 108 | let var = (Var (typeSShape @s) (typeSTyp @a) v) 109 | lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..}) 110 | return (Ref (fromIntegral x) typeSShape typeSTyp ) 111 | 112 | placeholder :: forall s a. SShape s -> STyp a -> BM (Ref s a) 113 | placeholder s t = do 114 | x <- lift (lift newId) 115 | ph <- backendTensor t $ Backend.placeholder (Backend.Shape (map convertNone $ shapeToList' s)) 116 | let var = (Var s t ph) 117 | lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..}) 118 | return (Ref (fromIntegral x) s t ) 119 | 120 | interpGen :: Gen a -> BM a 121 | interpGen (GPReturn x) = return x 122 | interpGen (GPVariable _trainable _name initVal) = initializedVariable initVal 123 | interpGen (GPPlaceholder s t _name) = placeholder s t 124 | interpGen (GPModify _ _) = error "GPModify: TODO" 125 | interpGen (GPState f) = lift (lift (state f)) 126 | interpGen (GPBind a b) = do x <- interpGen a 127 | interpGen (b x) 128 | 129 | listProxyLen :: forall proxy s. KnownLen s => proxy s -> Integer 130 | listProxyLen _ = listTypeLen @s 131 | 132 | -- genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC 133 | -- genDistr d sh s1 = case d of 134 | -- TruncatedNormalD stddev -> funcall "tf.truncated_normal" 135 | -- [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)] 136 | -- UniformD low high -> funcall "tf.random_uniform" [showSShape (sh .+. s1) 137 | -- ,named "minval" (float low) 138 | -- ,named "maxval" (float high) 139 | -- ,named "dtype" (showTyp @t)] 140 | -- OrthogonalD -> 141 | -- funcall' (funcall "tf.orthogonal_initializer" [named "dtype" (showTyp @t)]) [named "shape" (showSShape (sh .+. s1))] 142 | 143 | 144 | knownNumeric :: forall t k. KnownNumeric t => (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k 145 | knownNumeric = knownNumeric' (typeSTyp @t) 146 | 147 | knownNumeric' :: forall t k. KnownNumeric t => STyp t -> (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k 148 | knownNumeric' (STyp tk tb Refl) k = case tk of 149 | SFloat -> case tb of 150 | SB32 -> k 151 | SB64 -> k 152 | SBool -> error "TFNumeric bug" 153 | SInt -> case tb of 154 | SB32 -> k 155 | SB64 -> error "missing in tensorflow: int64 is not supported in matmul T_T" 156 | 157 | knownFloatingB :: forall t k. (KnownTyp t, TypKind t ~ 'Float) => (Backend.OneOf '[Float, Double] (HaskType t) => k) -> k 158 | knownFloatingB k = case bitsVal @(TypBits t) of 159 | SB32 -> k 160 | SB64 -> k 161 | 162 | knownInt :: forall t k. 
(KnownTyp t, TypKind t ~ 'Int) => (Backend.OneOf '[Backend.Int32, Backend.Int64] (HaskType t) => k) -> k 163 | knownInt k = case bitsVal @(TypBits t) of 164 | SB32 -> k 165 | SB64 -> k 166 | 167 | backendTensor :: STyp t -> (Backend.TensorType (HaskType t) => k) -> k 168 | backendTensor (STyp SFloat SB32 Refl) k = k 169 | backendTensor (STyp SInt SB64 Refl) k = k 170 | backendTensor (STyp SBool _ Refl) k = k 171 | backendTensor (STyp SFloat SB64 Refl) k = k 172 | backendTensor (STyp SInt SB32 Refl) k = k 173 | 174 | backendTensor' :: forall t k proxy. KnownTyp t => proxy t -> (Backend.TensorType (HaskType t) => k) -> k 175 | backendTensor' _ = backendTensor (typeSTyp @t) 176 | 177 | 178 | runUnOp :: forall s s1 t s2 u. KnownTyp u => KnownTyp t => BackendTensorType u => SShape s -> UnOp s1 t s2 u -> BT (s++s1) t -> BT (s++s2) u 179 | runUnOp sL op (BT x) = backendTensor (typeSTyp @t) $ case op of 180 | SliceOp _ sR lo hi -> BT $ BackCore.slice x 181 | (shapeFromList (replicate (sListLen sL) 0 ++ [lo] ++ replicate (sListLen sR) 0)) 182 | (shapeFromList (shapeToList' sL ++ [hi-lo] ++ (shapeToList' sR))) 183 | Axis1Op aop -> case aop of 184 | (ArgMax _ _) -> knownNumeric @t $ knownInt @u $ BT $ BackCore.argMax x (Backend.scalar sLLen) 185 | (OneHot _) -> knownNumeric @u $ knownInt @t $ BT $ Backend.oneHot x (Backend.scalar sLLen) (Backend.scalar 1) (Backend.scalar 0) 186 | ReduceOp _ _sR rop -> knownNumeric @t $ case rop of 187 | Max -> BT $ BackCore.max x redindices 188 | Min -> BT $ BackCore.min x redindices 189 | Sum -> BT $ Backend.sum x redindices 190 | Mean -> BT $ Backend.mean x redindices 191 | where redindices = (Backend.vector [fromIntegral (sListLen sL) :: Backend.Int32 ]) 192 | StopGradient -> BT $ BackCore.stopGradient x 193 | Cast -> BT $ Backend.cast x 194 | (Num1Op numop) -> knownNumeric @t $ case numop of 195 | Square -> BT (Backend.mul x x) 196 | Negate -> BT (Backend.neg x) 197 | Sign -> BT (Backend.sign x) 198 | Abs -> BT (Backend.abs x) 199 | FloorMod -> BT (Backend.floorMod x) 200 | Float1Op flop -> knownFloatingB @t $ knownFloating @(TypBits u) $ knownFloatingB @u $ case flop of 201 | Tanh -> BT (BackCore.tanh x) 202 | Sin -> BT (BackCore.sin x) 203 | Exp -> BT (BackCore.exp x) 204 | Sigmoid -> BT (BackCore.sigmoid x) 205 | Relu -> BT (BackCore.relu x) 206 | Floor -> BT (BackCore.floor x) 207 | Round -> BT (BackCore.round x) 208 | Cos -> BT (BackCore.cos x) 209 | Log -> BT (BackCore.log x) 210 | Asin -> BT (BackCore.asin x) 211 | Acos -> BT (BackCore.acos x) 212 | Sinh -> BT (BackCore.sinh x) 213 | Cosh -> BT (BackCore.cosh x) 214 | Asinh -> BT (BackCore.asinh x) 215 | Acosh -> BT (BackCore.acosh x) 216 | Atan -> BT (BackCore.atan x) 217 | Atanh -> BT (BackCore.atanh x) 218 | Sqrt -> BT (BackCore.sqrt x) 219 | HardSigmoid -> error "Haskell: no hard sigmoid defined yet" 220 | ClipByValue lo hi -> BT $ BackCore.clipByValue x (Backend.scalar $ realToFrac lo) (Backend.scalar $ realToFrac hi) 221 | Diag _ -> BT $ BackCore.batchMatrixDiag x 222 | where sLLen = fromIntegral (sListLen sL) :: Backend.Int32 223 | 224 | interpretPure :: forall s t. 
KnownTyp t => KnownShape s => T s t -> BM (BT s t) 225 | interpretPure x = do 226 | let sn = unsafePerformIO $ makeStableName x 227 | mv <- snMap22Lookup sn <$> lift (gets genPureTable) 228 | case mv of 229 | Just v -> return v 230 | Nothing -> do 231 | e <- interpretPure' (\s x' -> knownSShape s $ interpretPure x') typeSShape x 232 | lift $ modify (\g -> g {genPureTable = (snMap22Insert (KV sn e)) (genPureTable g)}) 233 | return e 234 | 235 | interpNilOp :: forall s t. Backend.TensorType (HaskType t) => NilOp s t -> BM (BT s t) 236 | interpNilOp = \case 237 | Constant c -> return $ BT $ Backend.scalar c 238 | Range n@Sat -> knownNumeric @t $ return $ 239 | let start,limit,delta :: HaskType t 240 | start = 0 241 | limit = fromIntegral $ natVal n 242 | delta = 1 243 | in BT $ Backend.range (Backend.scalar start) (Backend.scalar limit) (Backend.scalar delta) 244 | Variable (Ref r sr tr) -> do 245 | tbl <- lift (gets genVars) 246 | case IM.lookup r tbl of 247 | Just (Var sx tx x) -> case (testEq sx sr, testEq tx tr) of 248 | (Just Refl, Just Refl) -> return (BT (Backend.expr x)) 249 | _ -> error "panic: variable does not have the expected type" 250 | _ -> error "panic: variable not found" 251 | 252 | interpretPure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> BM (BT s' t')) -> SShape s -> T s t -> BM (BT s t) 253 | interpretPure' rec sR = knownSShape sR $ backendTensor (typeSTyp @t) $ \case 254 | Unbroadcast{} -> error "broadcasting operation did not complete!" 255 | DirectBroadcast s0 s1 s2 s3 x -> do 256 | BT recx <- rec (s0 .+. s2) x 257 | let expandedShape = shapeFromList 258 | (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 259 | ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]) 260 | targetShape = shapeFromList sR 261 | return $ BT $ BackCore.broadcastTo (Backend.reshape recx expandedShape) targetShape 262 | -- Noise noiseId s0 s1 x -> do 263 | -- return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId) 264 | T op -> interpNilOp op 265 | Where c x y -> do 266 | BT rc <- rec typeSShape c 267 | BT rx <- rec typeSShape x 268 | BT ry <- rec typeSShape y 269 | return $ BT $ BackCore.select rc rx ry 270 | UnOp operation s0 x -> do 271 | recx <- rec (s0 .+. unopInputShape operation) x 272 | return (runUnOp s0 operation recx) 273 | MatMul s0 a b c x y -> do 274 | BT recx <- rec (s0 .+. a :* b :* Unit) x 275 | BT recy <- rec (s0 .+. b :* c :* Unit) y 276 | return $ knownNumeric @t $ BT $ BackCore.batchMatMul recx recy 277 | BinOp operation s0 s1 t s2 u x y -> knownSShape s0 $ knownSShape s1 $ knownSShape s2 $ knownProduct' s0 $ do 278 | BT recx <- rec (s0 .+. s1) x 279 | BT recy <- rec (s0 .+. 
s2) y 280 | let reshx = backendTensor t $ Backend.reshape recx (shapeVector (satProd s0 :* s1)) 281 | reshy = backendTensor u $ Backend.reshape recy (shapeVector (satProd s0 :* s2)) 282 | return $ case operation of 283 | Simple2Op sop -> case sop of 284 | Add -> knownNumeric @t $ BT $ Backend.add recx recy 285 | Divide -> knownNumeric @t $ BT $ BackCore.div recx recy 286 | Equal -> backendTensor u $ BT $ Backend.equal recx recy 287 | Subtract -> knownNumeric @t $ BT $ Backend.sub recx recy 288 | Multiply -> knownNumeric @t $ BT $ Backend.mul recx recy 289 | Minimum -> knownNumeric @t $ BT $ BackCore.minimum recx recy 290 | Maximum -> knownNumeric @t $ BT $ BackCore.maximum recx recy 291 | LessThan -> knownNumeric' u $ BT $ BackCore.less recx recy 292 | -- WTF moment: the arguments do not seem to be in the same order in python as in haskell 293 | -- python: https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits 294 | -- haskell: https://tensorflow.github.io/haskell/haddock/tensorflow-core-ops-0.2.0.0/TensorFlow-GenOps-Core.html#v:sparseSoftmaxCrossEntropyWithLogits 295 | SparseSoftmaxCrossEntropyWithLogits -> case t of 296 | STyp SInt SB32 Refl -> knownFloatingB @t $ BT $ fst $ BackCore.sparseSoftmaxCrossEntropyWithLogits reshy reshx 297 | SoftmaxCrossEntropyWithLogits -> knownFloatingB @t $ BT $ fst $ BackCore.softmaxCrossEntropyWithLogits reshy reshx 298 | -- SigmoidCrossEntropyWithLogits -> knownFloatingB @t $ BT $ Backend.sigmoidCrossEntropyWithLogits recy recx -- type is not as general as necessary 299 | ReshapeFrom s t -> do 300 | BT rt <- rec s t 301 | return $ BT $ BackCore.reshape rt (shapeVector sR) 302 | Concat s0 s1 xs -> do 303 | let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> BM [BackendTensor t] 304 | go _ _ Unit = return [] 305 | go s0' s1' (Catable n y :* ys) = do 306 | BT y' <- rec (s0' .+. n :* s1') y 307 | (y' :) <$> go s0' s1' ys 308 | rxs <- go s0 s1 xs 309 | return $ BT $ Backend.concat (Backend.scalar (fromIntegral (sListLength s0))) rxs 310 | Transpose s p x -> do 311 | BT rx <- rec s x 312 | return $ BT $ Backend.transpose rx (permToTensor s p) 313 | -- Gather indexShape s0 m s1 x ix -> do 314 | -- rx <- rec (s0 .+. ((:*) m s1)) x 315 | -- rix <- rec indexShape ix 316 | -- return (func "tf.gather" [rx, rix] []) 317 | -- GatherND containerShape elementShape indexShape x ix -> do 318 | -- rx <- rec (containerShape .+. elementShape) x 319 | -- rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix 320 | -- return (func "tf.gather_nd" [rx, rix] []) 321 | Convolution bs inChans outChans filterShape s0 x filters -> do 322 | BT recx <- rec (bs :* (s0 *: inChans)) x 323 | BT recFilters <- rec (filterShape .+. inChans :* outChans :* Unit) filters 324 | case filterShape of 325 | _width :* _height :* Unit -> 326 | return $ BT $ knownFloatingB @t $ BackCore.conv2D recx recFilters 327 | _ -> error "TypedFlow Haskell backend: convolution on an unsupported number of dims" 328 | -- Pool bs window typ numChans outSpatial x -> do 329 | -- rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x 330 | -- return (func "tf.nn.pool" 331 | -- [rx, showSShape window, typ', text (show ("SAME" :: String))] 332 | -- [("strides", showSShape window)]) 333 | -- where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String) 334 | -- -- where rec :: forall s' t'. 
KnownTyp t' => SShape s' -> T s' t' -> DOC 335 | -- -- rec = generatePure' 336 | 337 | -------------------------------------------------------------------------------- /TypedFlow/Layers.hs: -------------------------------------------------------------------------------- 1 | 2 | module TypedFlow.Layers 3 | (module TypedFlow.Layers.Core 4 | ,module TypedFlow.Layers.RNN 5 | ) where 6 | 7 | import TypedFlow.Layers.Core 8 | import TypedFlow.Layers.RNN 9 | 10 | -------------------------------------------------------------------------------- /TypedFlow/Layers/Core.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow.Layers.Core 3 | Description : Core layers and combinators. 4 | Copyright : (c) Jean-Philippe Bernardy, 2017 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | -} 9 | {-# LANGUAGE CPP #-} 10 | #if __GLASGOW_HASKELL__ >= 806 11 | {-# LANGUAGE NoStarIsType #-} 12 | #endif 13 | {-# LANGUAGE FlexibleInstances #-} 14 | {-# LANGUAGE MultiParamTypeClasses #-} 15 | {-# LANGUAGE RankNTypes #-} 16 | {-# LANGUAGE ConstraintKinds #-} 17 | {-# LANGUAGE ViewPatterns #-} 18 | {-# LANGUAGE TypeInType #-} 19 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 20 | {-# LANGUAGE AllowAmbiguousTypes #-} 21 | {-# LANGUAGE DataKinds #-} 22 | {-# LANGUAGE DeriveFoldable #-} 23 | {-# LANGUAGE DeriveFunctor #-} 24 | {-# LANGUAGE DeriveTraversable #-} 25 | {-# LANGUAGE FlexibleContexts #-} 26 | {-# LANGUAGE GADTs #-} 27 | {-# LANGUAGE MagicHash #-} 28 | {-# LANGUAGE RankNTypes #-} 29 | {-# LANGUAGE ScopedTypeVariables #-} 30 | {-# LANGUAGE StandaloneDeriving #-} 31 | {-# LANGUAGE TypeApplications #-} 32 | {-# LANGUAGE TypeFamilies #-} 33 | {-# LANGUAGE TypeOperators #-} 34 | {-# LANGUAGE UndecidableInstances #-} 35 | {-# LANGUAGE UnicodeSyntax #-} 36 | {-# LANGUAGE PatternSynonyms #-} 37 | 38 | module TypedFlow.Layers.Core 39 | ( 40 | -- * Dense 41 | DenseP(..), dense, (#), 42 | -- * Dropout 43 | DropProb(..), mkMask, mkDropout, mkDropouts, 44 | -- * Embedding 45 | EmbeddingP(..), embedding, 46 | -- * Convolutional 47 | ConvP(..), conv, conv', {-convValid,-} maxPool1D, maxPool2D, 48 | glu 49 | ) 50 | 51 | where 52 | import Prelude hiding (RealFrac(..)) 53 | import GHC.TypeLits 54 | import TypedFlow.TF 55 | import TypedFlow.Types 56 | import TypedFlow.Types.Proofs 57 | import TypedFlow.Abstract 58 | import Control.Monad.State (gets) 59 | import Data.Monoid ((<>)) 60 | --------------------- 61 | -- Linear functions 62 | 63 | 64 | -- | A dense layer is a linear function form a to b: a transformation matrix and a bias. 
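-- Apply it with 'dense' or the '#' operator: @DenseP w b # v == (w ∙ v) + b@.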
65 | data DenseP t a b = DenseP {denseWeights :: Tensor '[a,b] t 66 | ,denseBiases :: Tensor '[b] t} 67 | 68 | ----------------------- 69 | -- Feed-forward layers 70 | 71 | -- | Parameters for the embedding layers 72 | newtype EmbeddingP numObjects embeddingSize t = EmbeddingP (Tensor '[numObjects, embeddingSize] t) 73 | 74 | instance (KnownNat numObjects, KnownTyp b, KnownNat embeddingSize) => KnownTensors (EmbeddingP numObjects embeddingSize b) where 75 | travTensor f s (EmbeddingP p) = EmbeddingP <$> travTensor f s p 76 | 77 | instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Float b)) where 78 | defaultInitializer = EmbeddingP <$> (noise $ UniformD (-0.05) 0.05) 79 | 80 | instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Cmplx b)) where 81 | defaultInitializer = EmbeddingP <$> (mkComplex <$> (noise $ UniformD (-0.05) 0.05) <*> (noise $ UniformD (-0.05) 0.05)) 82 | 83 | -- | embedding layer 84 | embedding :: ∀ embeddingSize numObjects t. KnownNat embeddingSize => KnownNat numObjects => 85 | EmbeddingP numObjects embeddingSize t -> Tensor '[] Int32 -> Tensor '[embeddingSize] t 86 | embedding (EmbeddingP param) input = gather param input 87 | 88 | 89 | 90 | instance (KnownNat a, KnownNat b, KnownTyp t) => KnownTensors (DenseP t a b) where 91 | travTensor f s (DenseP x y) = DenseP <$> travTensor f (s<>"_w") x <*> travTensor f (s<>"_bias") y 92 | 93 | instance (KnownNat n, KnownNat m, KnownFloat b) => ParamWithDefault (DenseP b n m) where 94 | defaultInitializer = DenseP <$> glorotUniform <*> (noise $ TruncatedNormalD 0.1) 95 | 96 | -- | Dense layer (Apply a linear function) 97 | (#), dense :: ∀m n t. KnownNat n => KnownNat m => KnownNumeric t => DenseP t n m -> Tensor '[n] t -> Tensor '[m] t 98 | (DenseP weightMatrix bias) # v = (weightMatrix ∙ v) + bias 99 | 100 | dense = (#) 101 | 102 | -- | A drop probability. (This type is used to make sure one does not 103 | -- confuse keep probability and drop probability) 104 | data DropProb = DropProb Float 105 | 106 | -- | Generate a dropout function. The mask applied by the returned 107 | -- function will be constant for any given call to mkDropout. See 108 | -- 'noise' for the sampling behaviour. 109 | mkDropout :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t -> Tensor s t) 110 | mkDropout d = (⊙) <$> mkMask d 111 | 112 | -- | Generate a 0-1 mask with given probability, suitable for dropout, 113 | -- or all ones if not in training phase. See 'noise' for the sampling 114 | -- behaviour. 115 | mkMask :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t) 116 | mkMask (DropProb dropProb) = do 117 | let keepProb = 1 - dropProb 118 | let isTraining = genTrainingPlaceholder 119 | r <- noise $ UniformD keepProb (1 + keepProb) 120 | return $ if_ isTraining 121 | (floor r ⊘ constant (knownAlgebraic @t $ realToFrac keepProb)) 122 | ones 123 | 124 | newtype EndoTensor t s = EndoTensor (Tensor s t -> Tensor s t) 125 | 126 | -- | Generate a dropout function for an heterogeneous tensor vector. 127 | mkDropouts :: KnownFloat t => KnownLen shapes => All KnownShape shapes => DropProb -> Gen (HTV t shapes -> HTV t shapes) 128 | mkDropouts d = appEndoTensor <$> mkDropouts' typeSList where 129 | mkDropouts' :: forall shapes t. 
KnownFloat t => All KnownShape shapes => 130 | SList shapes -> Gen (NP (EndoTensor t) shapes) 131 | mkDropouts' Unit = return Unit 132 | mkDropouts' (_ :* rest) = do 133 | x <- mkDropout d 134 | xs <- mkDropouts' rest 135 | return (EndoTensor x :* xs) 136 | 137 | appEndoTensor :: NP (EndoTensor t) s -> HTV t s -> HTV t s 138 | appEndoTensor Unit Unit = Unit 139 | appEndoTensor (EndoTensor f :* fs) (F x :* xs) = F (f x) :* appEndoTensor fs xs 140 | 141 | 142 | ------------------------ 143 | -- Convolutional layers 144 | 145 | data ConvP t outChannels inChannels filterSpatialShape 146 | = ConvP (T (filterSpatialShape ++ '[inChannels,outChannels]) t) 147 | (T '[outChannels] t) 148 | 149 | instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownFloat t) => 150 | ParamWithDefault (ConvP t outChannels inChannels filterSpatialShape) where 151 | defaultInitializer = prodHomo @filterSpatialShape @'[inChannels, outChannels] #> 152 | prodAssoc @(Product filterSpatialShape) @inChannels @outChannels #> 153 | knownAppend @filterSpatialShape @'[inChannels,outChannels] ?> 154 | knownProduct @filterSpatialShape ?> 155 | ConvP <$> (reshape <$> i) <*> pure (knownAlgebraic @t (constant 0.1)) 156 | where i :: Gen (T '[Product filterSpatialShape*inChannels,outChannels] t) 157 | i = knownProduct @filterSpatialShape ?> glorotUniform 158 | 159 | instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownAlgebraic t) => 160 | KnownTensors (ConvP t outChannels inChannels filterSpatialShape) where 161 | travTensor f s (ConvP x y) = knownAppend @filterSpatialShape @'[inChannels,outChannels] ?> 162 | (ConvP <$> travTensor f (s<>"_filters") x <*> travTensor f (s <> "_biases") y) 163 | 164 | -- | Size-preserving convolution layer 165 | conv' :: forall s outChannels filterSpatialShape inChannels t. 166 | KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t 167 | => Length filterSpatialShape <= 3 168 | => Length filterSpatialShape ~ Length s 169 | => ConvP t outChannels inChannels filterSpatialShape 170 | -> T (s ++ '[inChannels]) t 171 | -> T (s ++ '[outChannels]) t 172 | conv' (ConvP filters bias) input = mapTT @s (+bias) (convolution @outChannels @filterSpatialShape @inChannels @s input filters) 173 | 174 | 175 | 176 | conv :: forall outChannels filterSpatialShape inChannels s t. 177 | KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t 178 | => Length filterSpatialShape <= 3 179 | => (Length filterSpatialShape + 1) ~ Length s -- The ranks must match, but not necessarily the dimensions 180 | => (Last s ~ outChannels) 181 | => ConvP t outChannels inChannels filterSpatialShape 182 | -> T (Init s ++ '[inChannels]) t 183 | -> T s t 184 | conv = initLast' @s #> 185 | incrPos @(Length filterSpatialShape) #> 186 | lengthInit (typeSList @s) #> 187 | incrCong @(Length filterSpatialShape) @(Length (Init s)) #> 188 | knownInit @s ?> 189 | conv' @(Init s) 190 | 191 | 192 | -- -- | Convolution layers with no padding (applying the filter only on 193 | -- -- positions where the input is fully defined, aka "VALID" in 194 | -- -- tensorflow.) 195 | -- convValid :: forall outChannels filterSpatialShape inChannels s t. 
196 | -- ((1 + Length filterSpatialShape) ~ Length s, 197 | -- Length filterSpatialShape <= 3, 198 | -- KnownLen filterSpatialShape) -- the last dim of s is the batch size 199 | -- => ConvP t outChannels inChannels filterSpatialShape -- ^ Parameters 200 | -- -> T ('[inChannels] ++ AddSpatialDims s filterSpatialShape) ('Typ 'Float t) -- ^ input 201 | -- -> (T ('[outChannels] ++ s) ('Typ 'Float t)) 202 | -- convValid (ConvP filters bias) input = convolutionValid input filters + bias 203 | 204 | -- | Gated Linear Unit 205 | -- See: Language Modeling with Gated Convolutional Networks 206 | -- https://arxiv.org/pdf/1612.08083.pdf 207 | glu :: forall n t. KnownBits t => KnownNat n => T '[n+n] ('Typ 'Float t) -> T '[n] ('Typ 'Float t) 208 | glu x = plusMono @n @n #> knownPlus @n @n ?> 209 | let gate, h :: T '[n] ('Typ 'Float t) 210 | gate = slice0 @0 @n x 211 | h = termCancelation @n @n #> slice0 @n @(n+n) x 212 | in sigmoid gate ⊙ h 213 | -------------------------------------------------------------------------------- /TypedFlow/Layers/RNN.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow.Layers.RNN 3 | Description : RNN cells, layers and combinators. 4 | Copyright : (c) Jean-Philippe Bernardy, 2017 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | -} 9 | 10 | 11 | module TypedFlow.Layers.RNN ( 12 | module TypedFlow.Layers.RNN.Base, 13 | module TypedFlow.Layers.RNN.Cells, 14 | module TypedFlow.Layers.RNN.Attention) where 15 | 16 | import TypedFlow.Layers.RNN.Base 17 | import TypedFlow.Layers.RNN.Cells 18 | import TypedFlow.Layers.RNN.Attention 19 | -------------------------------------------------------------------------------- /TypedFlow/Layers/RNN/Attention.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow.Layers.RNN.Attention 3 | Description : Attention combinators to be used with RNN cells 4 | Copyright : (c) Jean-Philippe Bernardy, 2018 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | -} 9 | 10 | {-# LANGUAGE FlexibleInstances #-} 11 | {-# LANGUAGE MultiParamTypeClasses #-} 12 | {-# LANGUAGE RankNTypes #-} 13 | {-# LANGUAGE ConstraintKinds #-} 14 | {-# LANGUAGE ViewPatterns #-} 15 | {-# LANGUAGE TypeInType #-} 16 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 17 | {-# LANGUAGE AllowAmbiguousTypes #-} 18 | {-# LANGUAGE DataKinds #-} 19 | {-# LANGUAGE DeriveFoldable #-} 20 | {-# LANGUAGE DeriveFunctor #-} 21 | {-# LANGUAGE DeriveTraversable #-} 22 | {-# LANGUAGE FlexibleContexts #-} 23 | {-# LANGUAGE GADTs #-} 24 | {-# LANGUAGE MagicHash #-} 25 | {-# LANGUAGE RankNTypes #-} 26 | {-# LANGUAGE ScopedTypeVariables #-} 27 | {-# LANGUAGE StandaloneDeriving #-} 28 | {-# LANGUAGE TypeApplications #-} 29 | {-# LANGUAGE TypeFamilies #-} 30 | {-# LANGUAGE TypeOperators #-} 31 | {-# LANGUAGE UndecidableInstances #-} 32 | {-# LANGUAGE UnicodeSyntax #-} 33 | {-# LANGUAGE PatternSynonyms #-} 34 | 35 | module TypedFlow.Layers.RNN.Attention ( 36 | -- * Attention mechanisms 37 | -- ** Scoring functions 38 | AttentionScoring, 39 | multiplicativeScoring, 40 | AdditiveScoringP(..), additiveScoring, 41 | -- ** Attention functions 42 | AttentionFunction, 43 | uniformAttn, 44 | luongAttention, 45 | -- ** Attention combinators 46 | attentiveWithFeedback 47 | ) where 48 | 49 | import Prelude hiding (RealFrac(..)) 50 | import GHC.TypeLits 51 | import TypedFlow.TF 52 | import 
TypedFlow.Types 53 | import TypedFlow.Types.Proofs (appRUnit,(#>)) 54 | import TypedFlow.Layers.RNN.Base 55 | 56 | -- | An attention scoring function. This function should produce a 57 | -- score (between 0 and 1). 58 | type AttentionScoring t keySize valueSize = 59 | Tensor '[keySize] t -> Tensor '[valueSize] t -> Tensor '[] t 60 | 61 | -- | A function which attends to an external input. Typically a 62 | -- function of this type is a closure which has the attended input in 63 | -- its environment. This environment is interpreted as an associative 64 | -- memory form key to value. 65 | type AttentionFunction t keySize valueSize = 66 | T '[keySize] t -> T '[valueSize] t 67 | 68 | -- | @attnExample1 θ h st@ combines each element of the vector h with 69 | -- s, and applies a dense layer with parameters θ. The "winning" 70 | -- element of h (using softmax) is returned. 71 | uniformAttn :: ∀ valueSize m keySize t. KnownNat valueSize => KnownNat m => KnownFloat t 72 | => AttentionScoring t keySize valueSize -- ^ scoring function 73 | -> T '[] Int32 -- ^ length of the input 74 | -> T '[m,valueSize] t -- ^ input (what we're attending to) 75 | -> AttentionFunction t keySize valueSize 76 | uniformAttn score len hs key = c 77 | where xx,α :: T '[m] t 78 | xx = mapT (score key) hs 79 | α = softmax0 (mask ⊙ xx) 80 | c :: T '[valueSize] t 81 | c = hs ∙ α 82 | mask = cast (sequenceMask @m len) -- mask according to length 83 | 84 | -- | Add some attention to an RnnCell, and feed the attention vector to 85 | -- the next iteration in the rnn. (This follows the diagram at 86 | -- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism 87 | -- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a). 88 | attentiveWithFeedback ::forall attSize cellSize inputSize w ss. KnownNat inputSize => KnownNat attSize => KnownLen ss => 89 | KnownTyp w => 90 | AttentionFunction w cellSize attSize -> 91 | RnnCell w ss (T '[inputSize+attSize] w) (T '[cellSize] w) -> 92 | RnnCell w ('[attSize] ': ss) (T '[inputSize ] w) (T '[attSize] w) 93 | attentiveWithFeedback attn cell = appRUnit @ss #> withFeedback (cell .-. timeDistribute attn) 94 | 95 | 96 | -- -- | LSTM for an attention model. The result of attention is fed to the next step. 97 | -- attentiveLstm :: forall attSize n x bs t. KnownNat bs => 98 | -- AttentionFunction t bs n attSize -> 99 | -- LSTMP t n (x+attSize) -> 100 | -- RnnCell t '[ '[attSize,bs], '[n,bs], '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[attSize,bs] (Flt t)) 101 | -- attentiveLstm att w = attentiveWithFeedback att (lstm w) 102 | 103 | 104 | -- | Luong attention function (following 105 | -- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism 106 | -- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a). 107 | -- Essentially a dense layer with tanh activation, on top of uniform attention. 108 | luongAttention :: ∀ attnSize d m e w. KnownNat e => KnownNat d => KnownNat attnSize => KnownNat m => KnownFloat w 109 | => Tensor '[d+e,attnSize] w -- ^ weights for the dense layer 110 | -> AttentionScoring w e d -- ^ scoring function 111 | -> Tensor '[] Int32 -- ^ length of the input 112 | -> T '[m,d] w -- ^ inputs 113 | -> AttentionFunction w e attnSize 114 | luongAttention w scoring lens hs_ ht = 115 | let ct = uniformAttn scoring lens hs_ ht 116 | in (tanh (w ∙ (concat0 ct ht))) 117 | 118 | -- | Multiplicative scoring function 119 | multiplicativeScoring :: forall valueSize keySize t. 
120 | KnownFloat t => KnownNat valueSize => KnownNat keySize 121 | => T [keySize,valueSize] t -- ^ weights 122 | -> AttentionScoring t keySize valueSize 123 | multiplicativeScoring w dt h = ir · h 124 | where ir :: T '[valueSize] t 125 | ir = w ∙ dt 126 | 127 | 128 | data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP 129 | (Tensor '[1,sz] t) 130 | (Tensor '[keySize, sz] t) 131 | (Tensor '[valueSize, sz] t) 132 | 133 | instance (KnownNat n, KnownNat k, KnownNat v, KnownTyp t) => KnownTensors (AdditiveScoringP k v n t) where 134 | travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>"_v") x <*> travTensor f (s<>"_w1") y <*> travTensor f (s<>"_w2") z 135 | instance (KnownNat n, KnownNat k, KnownNat v, KnownFloat t) => ParamWithDefault (AdditiveScoringP k v n t) where 136 | defaultInitializer = AdditiveScoringP <$> glorotUniform <*> glorotUniform <*> glorotUniform 137 | 138 | -- | An additive scoring function. See https://arxiv.org/pdf/1412.7449.pdf 139 | additiveScoring :: forall sz keySize valueSize t. KnownNat valueSize => KnownNat sz => KnownNat keySize => KnownFloat t => 140 | AdditiveScoringP sz keySize valueSize t -> AttentionScoring t valueSize keySize 141 | additiveScoring (AdditiveScoringP v w1 w2) dt h = r'' 142 | where w1h :: Tensor '[sz] t 143 | w1h = w1 ∙ h 144 | w2dt = w2 ∙ dt 145 | z' :: Tensor '[sz] t 146 | z' = tanh (w1h + w2dt) 147 | r'' = z' · squeeze0 v 148 | 149 | -------------------------------------------------------------------------------- /TypedFlow/Layers/RNN/Base.hs: -------------------------------------------------------------------------------- 1 | {-| 2 | Module : TypedFlow.Layers.RNN.Base 3 | Description : RNN cells, layers and combinators. 4 | Copyright : (c) Jean-Philippe Bernardy, 2017 5 | License : LGPL-3 6 | Maintainer : jean-philippe.bernardy@gu.se 7 | Stability : experimental 8 | -} 9 | 10 | {-# LANGUAGE FlexibleInstances #-} 11 | {-# LANGUAGE MultiParamTypeClasses #-} 12 | {-# LANGUAGE RankNTypes #-} 13 | {-# LANGUAGE ConstraintKinds #-} 14 | {-# LANGUAGE ViewPatterns #-} 15 | {-# LANGUAGE TypeInType #-} 16 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 17 | {-# LANGUAGE AllowAmbiguousTypes #-} 18 | {-# LANGUAGE DataKinds #-} 19 | {-# LANGUAGE DeriveFoldable #-} 20 | {-# LANGUAGE DeriveFunctor #-} 21 | {-# LANGUAGE DeriveTraversable #-} 22 | {-# LANGUAGE FlexibleContexts #-} 23 | {-# LANGUAGE GADTs #-} 24 | {-# LANGUAGE MagicHash #-} 25 | {-# LANGUAGE RankNTypes #-} 26 | {-# LANGUAGE ScopedTypeVariables #-} 27 | {-# LANGUAGE StandaloneDeriving #-} 28 | {-# LANGUAGE TypeApplications #-} 29 | {-# LANGUAGE TypeFamilies #-} 30 | {-# LANGUAGE TypeOperators #-} 31 | {-# LANGUAGE UndecidableInstances #-} 32 | {-# LANGUAGE UnicodeSyntax #-} 33 | {-# LANGUAGE PatternSynonyms #-} 34 | 35 | module TypedFlow.Layers.RNN.Base ( 36 | -- * Cell Combinators 37 | RnnCell, 38 | simpleRnn, 39 | runCell, mkCell, 40 | stackRnnCells, (.-.), 41 | bothRnnCells, (.|.), 42 | withBypass, withFeedback, 43 | onStates, 44 | -- * Rnn Combinators 45 | Rnn, 46 | runRnn, 47 | stackRnns, (.--.), 48 | bothRnns,(.++.), 49 | -- * RNN unfolding functions 50 | timeDistribute, 51 | iterateCell, 52 | iterateCellBackward, 53 | iterateWithCull, 54 | -- * Monad-like interface for cell construction 55 | Component(..), bindC, returnC, 56 | -- rnnBackwardsWithCull, 57 | ) 58 | 59 | where 60 | import Prelude hiding (tanh,Num(..),Floating(..),floor) 61 | import GHC.TypeLits 62 | import TypedFlow.TF 63 | import TypedFlow.Types 64 | import 
TypedFlow.Types.Proofs 65 | -- import Data.Type.Equality 66 | -- import Data.Kind (Type,Constraint) 67 | 68 | -- | The RNN Component generalized monad. This can be used to build 69 | -- RNNs cells which do not follow the simple and usual "stacking" 70 | -- pattern. This is not a simple monad, because the indexing over 71 | -- states is non-uniform; see 'BindC'. 72 | newtype Component t (states::[Shape]) a 73 | = C {runC :: HTV t states -> (HTV t states , a)} 74 | -- Note: states are tensors only, because we need to index into them 75 | -- in the time dimension in iterateWithCull 76 | 77 | instance Functor (Component t states) where 78 | fmap = mapC 79 | 80 | mapC :: (a -> b) -> Component t s a -> Component t s b 81 | mapC f c = C $ \s -> 82 | let (s',x) = runC c s 83 | in (s', f x) 84 | 85 | -- | Unit of the Component monad. 86 | returnC :: a -> Component t '[] a 87 | returnC x = C $ \Unit -> (Unit,x) 88 | 89 | -- | Bind operation for Components. States are accumulated. 90 | bindC :: forall t s0 s1 a b. KnownLen s1 91 | => Component t s0 a -> (a -> Component t s1 b) -> Component t (s1++s0) b 92 | bindC f g = C $ \(hsplit @s1 -> (s1,s0)) -> 93 | let (s0',x) = runC f s0 94 | (s1',y) = runC (g x) s1 95 | in (happ s1' s0',y) 96 | 97 | -- | A cell (one time-step) in an rnn. @state@ is the state propagated through time. 98 | type RnnCell t states input output = input -> Component t states output 99 | 100 | -- | An rnn. @n@ is the length of the time sequence. @state@ is the state propagated through time. 101 | type Rnn n b state input output = RnnCell b state (V n input) (V n output) 102 | 103 | -- | Run a cell 104 | runCell :: RnnCell t states input output -> (HTV t states,input) -> (HTV t states, output) 105 | runCell cell = uncurry (flip (runC . cell)) 106 | 107 | -- | Run an RNN, using a tensor as input. @n@ is the length of the time sequence. 108 | runRnn :: (KnownNat n,KnownShape s0, KnownShape s1, KnownTyp t1) 109 | => Rnn n t2 states (T s1 t1) (T s0 t0) 110 | -> (HTV t2 states, Tensor (n ': s1) t1) 111 | -> (HTV t2 states, Tensor (n ': s0) t0) 112 | runRnn l (s,x) = 113 | let x' = unstack0 x 114 | (s',y) = runCell l (s,x') 115 | in (s',stack0 y) 116 | 117 | -- | Run an RNN composed of a single RNN cell. 118 | simpleRnn :: KnownTyp t1 => KnownShape s1 => KnownShape s0 => KnownNat n 119 | => RnnCell t2 states (T s1 t1) (T s0 t0) 120 | -> (HTV t2 states, Tensor (n : s1) t1) 121 | -> (HTV t2 states, Tensor (n : s0) t0) 122 | simpleRnn = runRnn . iterateCell 123 | 124 | -- | Construct a cell from an arbitrary stateful function 125 | mkCell :: ((HTV t states,input) -> (HTV t states, output)) -> RnnCell t states input output 126 | mkCell cell = C . flip (curry cell) 127 | 128 | ---------------------- 129 | -- Lifting functions 130 | 131 | -- | Convert a pure function (feed-forward layer) to an RNN cell by 132 | -- ignoring the RNN state. 133 | timeDistribute :: (a -> b) -> RnnCell t '[] a b 134 | timeDistribute = constantOverSteps 135 | 136 | -- | Convert a pure function (feed-forward layer) to an RNN cell by 137 | -- ignoring the RNN state. 138 | constantOverSteps :: (a -> b) -> RnnCell t '[] a b 139 | constantOverSteps stateLess a = returnC (stateLess a) 140 | 141 | -------------------------------------- 142 | -- Combinators 143 | 144 | -- | Compose two rnn layers. This is useful for example to combine 145 | -- forward and backward layers. 146 | (.--.),stackRnns :: forall s1 s2 a b c n bits. 
KnownLen s2 147 | => Rnn n bits s1 a b -> Rnn n bits s2 b c -> Rnn n bits (s2 ++ s1) a c 148 | stackRnns = stackRnnCells 149 | 150 | infixr .--. 151 | (.--.) = stackRnns 152 | 153 | -- | Compose two rnn layers in parallel. 154 | bothRnns,(.++.) :: forall s1 s2 a b c n bits t. 155 | KnownTyp t => KnownLen s1 => KnownLen s2 => KnownNat n 156 | => KnownNat b => KnownNat c 157 | => Rnn n bits s1 a (T '[b] t) -> Rnn n bits s2 a (T '[c] t) -> Rnn n bits (s2 ++ s1) a (T ('[b+c]) t) 158 | bothRnns f g x = 159 | f x `bindC` \y -> 160 | g x `bindC` \z -> 161 | returnC (concat0 <$> y <*> z) 162 | 163 | infixr .++. 164 | (.++.) = bothRnns 165 | 166 | -- | Apply a function on the cell state(s) before running the cell itself. 167 | onStates :: (HTV t xs -> HTV t xs) -> RnnCell t xs a b -> RnnCell t xs a b 168 | onStates f cell x = C $ \h -> do 169 | runC (cell x) (f h) 170 | 171 | -- | Stack two RNN cells (LHS is run first) 172 | stackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s1 173 | => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s1 ++ s0) a c 174 | stackRnnCells l1 l2 x = l1 x `bindC` l2 175 | (.-.) = stackRnnCells 176 | 177 | 178 | -- | Compose two rnn cells in parallel. 179 | bothRnnCells, (.|.) :: forall s0 s1 a b c t bits. KnownLen s0 => KnownLen s1 180 | => KnownBits bits 181 | => KnownNat b => KnownNat c 182 | => RnnCell t s0 a (T '[b] (Flt bits)) 183 | -> RnnCell t s1 a (T '[c] (Flt bits)) 184 | -> RnnCell t (s1 ++ s0) a (T '[b+c] (Flt bits)) 185 | bothRnnCells l1 l2 x = 186 | l1 x `bindC` \y -> 187 | l2 x `bindC` \z -> 188 | returnC (concat0 y z) 189 | 190 | (.|.) = bothRnnCells 191 | 192 | 193 | -- | Run the cell, and forward the input to the output, by 194 | -- concatenation with the output of the cell. This bypass is sometimes 195 | -- called a 'highway' in the literature. 196 | withBypass :: forall x y t b s0. KnownNat x => KnownNat y => KnownLen s0 197 | => KnownTyp t 198 | => RnnCell b s0 (T '[x] t) (T '[y] t) -> RnnCell b s0 (T '[x] t) (T '[x+y] t) 199 | withBypass cell x = appRUnit @s0 #> 200 | cell x `bindC` \y -> 201 | returnC (concat0 x y) 202 | 203 | -- | Run the cell, and feeds its output as input to the next time-step 204 | withFeedback :: forall outputSize inputSize (w :: Typ) ss. 205 | KnownTyp w => KnownNat outputSize => KnownNat inputSize => 206 | RnnCell w ss (T '[inputSize+outputSize] w) (T '[outputSize] w) -> 207 | RnnCell w ('[outputSize] ': ss) (T '[inputSize ] w) (T '[outputSize] w) 208 | withFeedback cell x = C $ \(F prevoutputnVector :* s) -> 209 | let (s',y) = runC (cell (concat0 x prevoutputnVector)) s 210 | in (F y :* s',y) 211 | 212 | --------------------------------------------------------- 213 | -- RNN unfolding 214 | 215 | -- | Build a RNN by repeating a cell @n@ times. 216 | iterateCell :: ∀ n state input output b. 217 | (KnownNat n) => 218 | RnnCell b state input output -> Rnn n b state input output 219 | iterateCell c x = C $ \s -> chainForward (\(t,y) -> runC (c y) t) (s,x) 220 | 221 | -- | Build a RNN by repeating a cell @n@ times. However the state is 222 | -- propagated in the right-to-left direction (decreasing indices in 223 | -- the time dimension of the input and output tensors) 224 | iterateCellBackward :: ∀ n state input output b. 225 | (KnownNat n) => 226 | RnnCell b state input output -> Rnn n b state input output 227 | iterateCellBackward c x = C $ \s -> chainBackward (\(t,y) -> runC (c y) t) (s,x) 228 | 229 | -- | RNN helper 230 | chainForward :: ∀ state a b n. 
((state , a) -> (state , b)) → (state , V n a) -> (state , V n b) 231 | chainForward _ (s0 , VUnit) = (s0 , VUnit) 232 | chainForward f (s0 , x :** xs) = 233 | let (s1,x') = f (s0 , x) 234 | (sFin,xs') = chainForward f (s1 , xs) 235 | in (sFin,(x':**xs')) 236 | 237 | -- | RNN helper 238 | chainBackward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b) 239 | chainBackward _ (s0 , VUnit) = (s0 , VUnit) 240 | chainBackward f (s0 , (x:**xs)) = 241 | let (s1,xs') = chainBackward f (s0,xs) 242 | (sFin, x') = f (s1,x) 243 | in (sFin,(x':**xs')) 244 | 245 | 246 | -- | RNN helper 247 | chainForwardWithState :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (V n b, V n state) 248 | chainForwardWithState _ (_s0 , VUnit) = (VUnit, VUnit) 249 | chainForwardWithState f (s0 , (x:**xs)) = 250 | let (s1,x') = f (s0 , x) 251 | (xs',ss) = chainForwardWithState f (s1 , xs) 252 | in ((x':**xs'), (s1:**ss) ) 253 | 254 | -- -- | RNN helper 255 | -- chainBackwardWithState :: 256 | -- ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b, V n state) 257 | -- chainBackwardWithState _ (s0 , VUnit) = return (s0 , VUnit, VUnit) 258 | -- chainBackwardWithState f (s0 , (x:**xs)) = do 259 | -- (s1,xs',ss') <- chainBackwardWithState f (s0,xs) 260 | -- (sFin, x') <- f (s1,x) 261 | -- return (sFin,(x':**xs'),(sFin:**ss')) 262 | 263 | -- | RNN helper 264 | transposeV :: forall n xs t. All KnownShape xs => KnownNat n => 265 | SList xs -> V n (HTV t xs) -> HTV t (Ap (FMap (Cons n)) xs) 266 | transposeV Unit _ = Unit 267 | transposeV (_ :* n) xxs = F ys' :* yys' 268 | where (ys,yys) = help @(Tail xs) xxs 269 | ys' = stack0 ys 270 | yys' = transposeV n yys 271 | help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys)) 272 | help (xs) = ((fmap (fromF . hhead) xs),(fmap htail xs)) 273 | 274 | -- | @(gatherFinalStates dynLen states)[i] = states[dynLen[i]-1]@ 275 | gatherFinalStates :: KnownShape x => KnownNat n => T '[] Int32 -> T (n ': x) t -> T x t 276 | gatherFinalStates dynLen states = gather states (dynLen ⊝ constant 1) 277 | 278 | gathers :: forall n xs t. All KnownShape xs => KnownNat n => 279 | SList xs -> T '[] Int32 -> HTV t (Ap (FMap (Cons n)) xs) -> HTV t xs 280 | gathers Unit _ Unit = Unit 281 | gathers (_ :* n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs 282 | 283 | -- | @rnnWithCull dynLen@ constructs an RNN as normal, but returns the 284 | -- state after step @dynLen@ only. 285 | iterateWithCull :: forall n x y ls b. 286 | KnownLen ls => KnownNat n => All KnownShape ls => 287 | T '[] Int32 -- ^ dynamic length 288 | -> RnnCell b ls x y -> Rnn n b ls x y 289 | iterateWithCull dynLen cell xs = C $ \s0 -> 290 | let (us,ss) = chainForwardWithState (uncurry (flip (runC . cell))) (s0,xs) 291 | sss = transposeV @n (typeSList @ls) ss 292 | in (gathers @n (typeSList @ls) dynLen sss,us) 293 | 294 | -- -- | Like @rnnWithCull@, but states are threaded backwards. 295 | -- rnnBackwardsWithCull :: forall n bs x y ls b. 
296 | -- KnownLen ls => KnownNat n => All KnownLen ls => All (LastEqual bs) ls => 297 | -- T '[bs] Int32 -> RnnCell b ls x y -> RNN n b ls x y 298 | -- rnnBackwardsWithCull dynLen cell (s0, t) = do 299 | -- (us,ss) <- chainBackwardWithState cell (s0,xs) 300 | -- let sss = transposeV @n (shapeSList @ls) ss 301 | -- return (gathers @n (shapeSList @ls) (n - dynLen) sss,us) 302 | -------------------------------------------------------------------------------- /TypedFlow/Layers/RNN/Cells.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE UndecidableInstances #-} 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE ViewPatterns #-} 4 | {-# LANGUAGE AllowAmbiguousTypes #-} 5 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 6 | {-# LANGUAGE DataKinds #-} 7 | {-# LANGUAGE RankNTypes #-} 8 | {-# LANGUAGE ScopedTypeVariables #-} 9 | {-# LANGUAGE TypeApplications #-} 10 | {-# LANGUAGE TypeFamilies #-} 11 | {-# LANGUAGE TypeOperators #-} 12 | {-# LANGUAGE UnicodeSyntax #-} 13 | {-| 14 | Module : TypedFlow.Layers.RNN.Cells 15 | Description : RNN cells 16 | Copyright : (c) Jean-Philippe Bernardy, 2017 17 | License : LGPL-3 18 | Maintainer : jean-philippe.bernardy@gu.se 19 | Stability : experimental 20 | -} 21 | 22 | 23 | module TypedFlow.Layers.RNN.Cells ( 24 | -- * RNN Cells 25 | cellInitializerBit, 26 | LSTMP(..), 27 | lstm, 28 | GRUP(..), 29 | gru, 30 | StackP(..), 31 | stackRU, 32 | ) where 33 | 34 | import TypedFlow.Layers.RNN.Base 35 | import TypedFlow.TF 36 | import TypedFlow.Types 37 | import TypedFlow.Types.Proofs 38 | import GHC.TypeLits 39 | import TypedFlow.Layers.Core (DenseP(..),(#)) 40 | import Prelude hiding (RealFrac(..)) 41 | 42 | -------------------------------------- 43 | -- Cells 44 | 45 | -- | Standard RNN gate initializer. (The recurrent kernel is 46 | -- orthogonal to avoid divergence; the input kernel is glorot) 47 | cellInitializerBit :: ∀ n x t. (KnownNat n, KnownNat x, KnownFloat t) => Gen (DenseP t (n + x) n) 48 | cellInitializerBit = DenseP <$> (concat0 <$> recurrentInitializer <*> kernelInitializer) <*> biasInitializer 49 | where recurrentInitializer :: Gen (Tensor '[n, n] t) 50 | recurrentInitializer = noise $ OrthogonalD 51 | kernelInitializer :: Gen (Tensor '[x, n] t) 52 | kernelInitializer = glorotUniform 53 | biasInitializer = pure zeros 54 | 55 | -- | Parameter for an LSTM 56 | data LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) 57 | 58 | instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (LSTMP t n x) where 59 | travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>"_f") x <*> travTensor f (s<>"_i") y <*> travTensor f (s<>"_c") z <*> travTensor f (s<>"_o") w 60 | instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (LSTMP t n x) where 61 | defaultInitializer = LSTMP <$> forgetInit <*> cellInitializerBit <*> cellInitializerBit <*> cellInitializerBit 62 | where forgetInit = DenseP <$> (denseWeights <$> cellInitializerBit) <*> pure ones 63 | 64 | -- | Standard LSTM 65 | lstm :: ∀ n x t. 
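-- With state (h_{t-1}, c_{t-1}), input x_t, and [h,x] denoting concatenation,
-- the code below computes the standard LSTM update (each weight is a
-- 'DenseP', i.e. an affine map including a bias):
--
-- >  f  = σ(W_f [h_{t-1}, x_t])        -- forget gate
-- >  i  = σ(W_i [h_{t-1}, x_t])        -- input gate
-- >  c~ = tanh(W_c [h_{t-1}, x_t])     -- candidate cell state
-- >  o  = σ(W_o [h_{t-1}, x_t])        -- output gate
-- >  c_t = f ⊙ c_{t-1} + i ⊙ c~
-- >  h_t = o ⊙ tanh c_t
--
-- and returns the new state (h_t, c_t) together with the output h_t.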
KnownNat x => KnownNat n => KnownFloat t 66 | => LSTMP t n x -> RnnCell t '[ '[n], '[n]] (Tensor '[x] t) (Tensor '[n] t) 67 | lstm (LSTMP wf wi wc wo) input = C $ \(VecPair ht1 ct1) -> 68 | let f = sigmoid (wf # hx) 69 | hx = (concat0 ht1 input) 70 | i = sigmoid (wi # hx) 71 | cTilda = tanh (wc # hx) 72 | o = sigmoid (wo # hx) 73 | c = ((f ⊙ ct1) + (i ⊙ cTilda)) 74 | h = (o ⊙ tanh c) 75 | in (VecPair h c, h) 76 | 77 | -- | Parameter for a GRU 78 | data GRUP t n x = GRUP (T [n+x,n] t) (T [n+x,n] t) (T [n+x,n] t) 79 | 80 | instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (GRUP t n x) where 81 | travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>"_z") x <*> travTensor f (s<>"_r") y <*> travTensor f (s<>"_w") z 82 | instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (GRUP t n x) where 83 | defaultInitializer = GRUP <$> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) 84 | 85 | 86 | 87 | -- | Standard GRU cell 88 | gru :: ∀ n x t. KnownNat x => (KnownNat n, KnownFloat t) => GRUP t n x -> 89 | RnnCell t '[ '[n] ] (Tensor '[x] t) (Tensor '[n] t) 90 | gru (GRUP wz wr w) xt = C $ \(VecSing ht1) -> 91 | let hx = (concat0 ht1 xt) 92 | zt = sigmoid (wz ∙ hx) 93 | rt = sigmoid (wr ∙ hx) 94 | hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt)) 95 | ht = ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda) 96 | in (VecSing ht, ht) 97 | 98 | 99 | data StackP w n = StackP (DenseP w (n + n) 3) 100 | 101 | defStackP :: KnownNat n => KnownFloat w => Gen (StackP w n) 102 | defStackP = StackP <$> defaultInitializer 103 | -- (DenseP glorotUniform (stack0 (V [zeros, constant (-1), zeros]) )) -- demote popping a bit 104 | 105 | instance (KnownNat n, KnownTyp w) => KnownTensors (StackP w n) where 106 | travTensor f s (StackP d) = StackP <$> travTensor f s d 107 | 108 | instance (KnownNat n, KnownFloat w) => (ParamWithDefault (StackP w n)) where 109 | defaultInitializer = defStackP 110 | 111 | -- | A stack recurrent unit. The input has two purposes: 1. it is 112 | -- saved in a stack. 2. it controls (a dense layer which gives) the 113 | -- operation to apply on the stack. The first type argument is the 114 | -- depth of the stack. 115 | stackRU :: ∀k n bs w. 
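-- The three stacked candidates in @stTilda@ below correspond to the three
-- stack operations: keep the stack unchanged, pop (shift every element one
-- slot towards the top, padding with zeros at the bottom), and push the
-- current input on top (dropping the bottom element).  The dense layer
-- produces a softmax over these three actions, the new stack is the
-- corresponding convex combination, and the output is the (new) top of stack.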
KnownNat k => KnownNat n => (KnownNat bs) => (KnownFloat w) => StackP w n -> 116 | RnnCell w '[ '[k+1,n]] (Tensor '[n] w) (Tensor '[n] w) 117 | stackRU (StackP w) input = C $ \(VecSing st1) -> 118 | succPos @k #> 119 | plusMono @k @1 #> 120 | plusComm @k @1 #> 121 | termCancelation @k @1 #> 122 | let ct1 = nth0' @0 st1 123 | hx = concat0 ct1 input 124 | action :: T '[3] w 125 | action = softmax0 (w # hx) 126 | tl :: T '[k,n] w 127 | tl = slice0 @1 @(k+1) st1 128 | it :: T '[k,n] w 129 | it = slice0 @0 @k st1 130 | stTilda :: T '[3,k+1,n] w 131 | stTilda = stack0 (st1 :** (tl `concat0` zeros) :** (expandDim0 input `concat0` it) :** VUnit) 132 | st :: T '[k+1,n] w 133 | st = inflate2 (flatten12 stTilda ∙ action) 134 | ct = nth0' @0 st 135 | in (VecSing st, ct) 136 | 137 | -------------------------------------------------------------------------------- /TypedFlow/Learn.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE LambdaCase #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE PatternSynonyms #-} 4 | {-| 5 | Module : TypedFlow.Learn 6 | Description : Loss functions and optimization strategies 7 | Copyright : (c) Jean-Philippe Bernardy, 2017 8 | License : LGPL-3 9 | Maintainer : jean-philippe.bernardy@gu.se 10 | Stability : experimental 11 | -} 12 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 13 | {-# LANGUAGE AllowAmbiguousTypes #-} 14 | {-# LANGUAGE ApplicativeDo #-} 15 | {-# LANGUAGE ConstraintKinds #-} 16 | {-# LANGUAGE DataKinds #-} 17 | {-# LANGUAGE DeriveFoldable #-} 18 | {-# LANGUAGE DeriveFunctor #-} 19 | {-# LANGUAGE DeriveTraversable #-} 20 | {-# LANGUAGE FlexibleContexts #-} 21 | {-# LANGUAGE GADTs #-} 22 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 23 | {-# LANGUAGE InstanceSigs #-} 24 | {-# LANGUAGE MagicHash #-} 25 | {-# LANGUAGE RankNTypes #-} 26 | {-# LANGUAGE RecordWildCards #-} 27 | {-# LANGUAGE ScopedTypeVariables #-} 28 | {-# LANGUAGE StandaloneDeriving #-} 29 | {-# LANGUAGE TupleSections #-} 30 | {-# LANGUAGE TypeApplications #-} 31 | {-# LANGUAGE TypeFamilies #-} 32 | {-# LANGUAGE TypeInType #-} 33 | {-# LANGUAGE TypeOperators #-} 34 | {-# LANGUAGE UndecidableInstances #-} 35 | {-# LANGUAGE UndecidableSuperClasses #-} 36 | {-# LANGUAGE UnicodeSyntax #-} 37 | 38 | module TypedFlow.Learn 39 | (-- losses: 40 | sparseCategorical, binary, timedCategorical, categoricalDistribution,sparseCategoricalDensePredictions, 41 | -- types 42 | Options(..), defaultOptions, 43 | Function(..),Model,ModelOutput, 44 | PreparedFunction(..), PreparedModel(..), 45 | -- other 46 | simpleModel, modelFunction, probeFunction, 47 | addRegularizer, 48 | prepare, 49 | -- utils 50 | placeholderName, 51 | ) where 52 | 53 | import Data.Proxy 54 | import TypedFlow.Types 55 | import TypedFlow.Types.Proofs (knownAppend, (?>), ) 56 | import TypedFlow.Broadcast (doBroadcast, mapPlaceHolders, ConsSh,doBroadcastSingle) 57 | import TypedFlow.Abstract (doExtractVars) 58 | import TypedFlow.TF 59 | import Prelude hiding (RealFrac(..)) 60 | import GHC.TypeLits 61 | 62 | -- | Triple of values that are always output in a model: prediction, loss and accuracy. 63 | -- @t@ is the type of the prediction. 64 | -- @s@ is the shape of the loss and accuracy 65 | type ModelOutput t predictionShape s 66 | = Placeholders '[ '("loss",s,Float32) -- loss associated with the prediction 67 | , '("accuracy",s,Float32) -- is the prediction correct? 
68 | , '("y_",s++predictionShape,t) -- prediction (which can contain prediction-shaped info) 69 | ] 70 | 71 | pattern ModelOutput :: T (s++predictionShape) t -> T s Float32 -> T s Float32 -> ModelOutput t predictionShape s 72 | pattern ModelOutput y loss accur = PHT loss :* PHT accur :* PHT y :* Unit 73 | 74 | -- | A standard modelling function: (input value, gold value) ↦ (prediction, accuracy, loss). 75 | -- input is the shape of the input. 76 | -- output is the shape of the output (one element per individual loss and accuracy) 77 | -- p is the shape of each output element. 78 | -- g is the shape of each gold output --- often equal to p. 79 | type Model input tIn g p output tOut 80 | = T input tIn -> T (g++output) tOut -> ModelOutput tOut p output 81 | 82 | -- | First type argument is the number of classes. @categorical 83 | -- logits gold@ return (prediction, accuraccy, loss) 84 | 85 | sparseCategorical :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[] '[] '[] Int32 86 | sparseCategorical logits y = 87 | let y_ = argmax0 logits 88 | modelCorrect = cast (equal y_ y) 89 | modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits 90 | in ModelOutput y_ modelLoss modelCorrect 91 | 92 | -- | First type argument is the number of classes. @categorical 93 | -- logits gold@ return (prediction, accuracy, loss) 94 | sparseCategoricalDensePredictions :: forall nCat. KnownNat nCat 95 | => Tensor '[nCat] Float32 96 | -> Tensor '[] Int32 97 | -> ModelOutput Float32 '[nCat] '[] 98 | sparseCategoricalDensePredictions logits y = 99 | let y_ :: T '[nCat] Float32 100 | y_ = softmax0 logits 101 | modelCorrect = cast (equal (argmax0 logits) y) 102 | modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits 103 | in ModelOutput y_ modelLoss modelCorrect 104 | 105 | 106 | -- | First type argument is the number of classes. 107 | -- @categoricalDistribution logits gold@ return (prediction, 108 | -- accuraccy, loss) accuracy is reported as predicting the same class 109 | -- as the input 'winning' class. 110 | categoricalDistribution :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[nCat] '[nCat] '[] Float32 111 | categoricalDistribution logits y = 112 | ModelOutput (softmax0 logits) 113 | (softmaxCrossEntropyWithLogits y logits) 114 | (cast (equal (argmax0 @'B32 logits) (argmax0 y))) 115 | 116 | 117 | -- | @timedCategorical targetWeights logits y@ 118 | -- 119 | -- targetWeights: a zero-one matrix of the same size as 120 | -- decoder_outputs. It is intended to mask padding positions outside 121 | -- of the target sequence lengths with values 0. 122 | -- 123 | -- Note that the accuracy is computed by multiplying the accuracies at 124 | -- individual time steps with the targetWeights. 125 | 126 | timedCategorical :: forall len nCat bits. 
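-- Concretely, writing w for targetWeights and ŷ for the argmax prediction,
-- the code below computes
--
-- >  accuracy = Σ_i [ŷ_i == y_i] · w_i  /  Σ_i w_i
-- >  loss     = Σ_i crossEntropy(logits_i, y_i) · w_i  /  Σ_i w_i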
KnownNat nCat => KnownNat len => KnownBits bits => 127 | Tensor '[len] (Flt bits) -> Tensor '[len,nCat] (Flt bits) -> Tensor '[len] Int32 -> ModelOutput (Flt bits) '[len,nCat] '[] 128 | timedCategorical targetWeights logits y = 129 | let y_ :: Tensor '[len] Int32 130 | y_ = argmax1 logits 131 | modelY = softmax1 logits 132 | -- correct prediction for each position 133 | correctPrediction :: Tensor '[len] TFBool 134 | correctPrediction = equal y_ y 135 | -- total number of correct predictions 136 | correctPredictionWeighted :: Tensor '[] (Flt bits) 137 | correctPredictionWeighted = reduceSumAll (cast @(Flt bits) correctPrediction ⊙ targetWeights) 138 | weightSum = reduceSumAll targetWeights 139 | modelCorrect :: Tensor '[] Float32 140 | modelCorrect = cast (correctPredictionWeighted / weightSum) 141 | crossEntropies = zipWithT sparseSoftmaxCrossEntropyWithLogits y logits 142 | modelLoss = cast @Float32 (reduceSumAll (crossEntropies ⊙ targetWeights) / weightSum) 143 | in ModelOutput modelY modelLoss modelCorrect 144 | 145 | -- | Model with @n@ binary outputs. 146 | binary :: KnownNat n => Model '[n] Float32 '[] '[] '[n] Int32 147 | binary logits y = 148 | let y_ = cast @Int32 (round sigy_) 149 | sigy_ = sigmoid logits 150 | in ModelOutput (y_) 151 | (sigmoidCrossEntropyWithLogits (cast @Float32 y) logits) 152 | (cast (equal y_ y)) 153 | 154 | -- | Model compiler options 155 | data Options = Options {maxGradientNorm :: Maybe Prelude.Float -- ^ apply gradient clipping 156 | } 157 | 158 | -- | default model compiler options 159 | defaultOptions :: Options 160 | defaultOptions = Options {maxGradientNorm = Nothing} 161 | 162 | type family Concatenate xs where 163 | Concatenate (x ': xs) = x ++ Concatenate xs 164 | Concatenate '[] = '[] 165 | 166 | genPlaceholders :: All KnownPlaceholder shapesAndTypes => SList shapesAndTypes -> Placeholders shapesAndTypes 167 | genPlaceholders Unit = Unit 168 | genPlaceholders (ph :* names) = PHT (T (ExternalVar (Ref (placeholderName ph) typeSShape typeSTyp))) :* genPlaceholders names 169 | 170 | placeholderName :: forall (ph :: PH) p. KnownPlaceholder ph => p ph -> String 171 | placeholderName proxy = refName (placeHolderRef proxy) 172 | 173 | simpleModel :: forall p sx tx sy ty sy_ ty_. 174 | (KnownShape sy_, KnownShape p, KnownShape sx, KnownTyp ty_, KnownShape sy, KnownTyp tx, KnownTyp ty) 175 | => (Tensor sx tx -> Tensor sy ty -> ModelOutput ty_ p sy_) 176 | -> Function 177 | simpleModel f = knownAppend @sy_ @p ?> modelFunction "runModel" f' 178 | where f' :: Placeholders '[ '("x",sx,tx), '("y",sy,ty)] -> ModelOutput ty_ p sy_ 179 | f' (PHT x :* PHT y :* Unit) = f x y 180 | 181 | 182 | -- | Add a term to the loss. This function is intendend to add 183 | -- regularizers, ie. losses that do not depend on the predicted 184 | -- output, but rather on the structure of a parameter. 185 | addRegularizer :: Scalar Float32 -> Gen () 186 | addRegularizer r = GPState $ \GState{..} -> ((),GState{genRegularizers=r:genRegularizers,..}) 187 | 188 | 189 | 190 | knownBatchModel :: forall n ps. KnownNat n => NP (Sat KnownPlaceholder) ps -> NP (Sat KnownPlaceholder) (Ap (FMap (ConsSh n)) ps) 191 | knownBatchModel Unit = Unit 192 | knownBatchModel (Comp Dict :* xs) = Sat :* knownBatchModel @n xs 193 | 194 | -- | take the mean of loss/accur over the batch, etc. and add regulariser to loss 195 | consolidate :: forall s rest. 
KnownShape s 196 | => Scalar Float32 197 | -> Placeholders ( '("loss",s ,Float32) ': '("accuracy",s ,Float32) ': rest) 198 | -> Placeholders ( '("loss",'[],Float32) ': '("accuracy",'[],Float32) ': rest) 199 | consolidate extraLoss (PHT loss :* PHT accur :* rest) = (PHT (reduceMeanAll loss + extraLoss) :* PHT (reduceMeanAll accur) :* rest) 200 | 201 | class (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps 202 | instance (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps 203 | 204 | data PreparedFunction = PreparedFunction {pfName :: String, 205 | pfBatched :: Bool, 206 | pfInputs, pfOutputs :: SomeSuch KnownPHS Placeholders} 207 | data PreparedModel = PreparedModel {pmBatchSize :: Integer, 208 | pmParams :: [VarInfo], 209 | pmFunctions :: [PreparedFunction] 210 | } 211 | 212 | -- | Prepare compilation of a model by: 213 | -- extracting and exposing parameters 214 | -- batching the model 215 | -- exposing placeholders 216 | -- consolidating loss and accuracy 217 | -- adding regularizers to the loss 218 | prepare :: forall bs. (KnownNat bs) 219 | => Gen [Function] 220 | -> PreparedModel 221 | prepare fGen = 222 | PreparedModel 223 | {pmBatchSize = natVal (Proxy @bs) 224 | ,pmParams = [VarInfo{varInitial=fmap doBroadcastSingle varInitial,..} | VarInfo{..} <- filter varTrainable vars] 225 | ,pmFunctions = flip map fs $ \case 226 | ModelFn nm st1 st2 f -> 227 | knownAll (knownBatchModel @bs st1) $ 228 | knownAll (knownBatchModel @bs st2) $ 229 | knownAll st1 $ 230 | knownAll st2 $ 231 | let placeHolders = genPlaceholders typeSList 232 | u = -777 -- magic unique identifier for the batch dimension 233 | in PreparedFunction nm 234 | True 235 | (SomeSuch placeHolders) 236 | (SomeSuch $ doBroadcast (consolidate {-@(bs ': s) @(BPH bs st2)-} regular (mapPlaceHolders @bs u True f placeHolders))) 237 | ProbeFn nm st1 st2 f -> 238 | knownAll st1 $ 239 | knownAll st2 $ 240 | let placeHolders = genPlaceholders typeSList 241 | in PreparedFunction nm False (SomeSuch placeHolders) (SomeSuch (doBroadcast (f placeHolders))) 242 | } 243 | where (fs,finalState,vars) = doExtractVars fGen 244 | regular = sum (genRegularizers finalState) 245 | 246 | data Function where 247 | ModelFn :: (KnownShape s, KnownLen st1, KnownLen st2) 248 | => String 249 | -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 250 | -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function 251 | ProbeFn :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) 252 | => String 253 | -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 254 | -> (Placeholders st1 -> Placeholders st2) -> Function 255 | 256 | modelFunction :: (KnownShape s, KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) 257 | => String 258 | -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function 259 | modelFunction nm f = ModelFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f 260 | 261 | 262 | probeFunction :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) 263 | => String 264 | -> (Placeholders st1 -> Placeholders st2) -> Function 265 | probeFunction nm f = ProbeFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f 266 | 267 | 268 | -------------------------------------------------------------------------------- /TypedFlow/Memo.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE 
TypeInType #-} 2 | {-# LANGUAGE PolyKinds #-} 3 | {-# LANGUAGE KindSignatures #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE GADTs #-} 7 | module TypedFlow.Memo where 8 | 9 | import qualified Data.IntMap as I 10 | import qualified Data.Map.Strict as M 11 | import System.Mem.StableName 12 | import Data.IORef 13 | import System.IO.Unsafe 14 | import Unsafe.Coerce 15 | import Data.Kind (Type) 16 | type SNMap k v = I.IntMap [(StableName k,v)] 17 | 18 | snMapLookup :: StableName k -> SNMap k v -> Maybe v 19 | snMapLookup sn m = do 20 | x <- I.lookup (hashStableName sn) m 21 | lookup sn x 22 | 23 | snMapInsert :: StableName k -> v -> SNMap k v -> SNMap k v 24 | snMapInsert sn res = I.insertWith (++) (hashStableName sn) [(sn,res)] 25 | 26 | memo :: (a -> b) -> a -> b 27 | memo f = unsafePerformIO ( 28 | do { tref <- newIORef (I.empty) 29 | ; return (applyStable f tref) 30 | }) 31 | 32 | applyStable :: (a -> b) -> IORef (SNMap a b) -> a -> b 33 | applyStable f tbl arg = unsafePerformIO ( 34 | do { sn <- makeStableName arg 35 | ; lkp <- snMapLookup sn <$> readIORef tbl 36 | ; case lkp of 37 | Just result -> return result 38 | Nothing -> 39 | do { let res = f arg 40 | ; modifyIORef tbl (snMapInsert sn res) 41 | ; return res 42 | }}) 43 | 44 | memoOrd :: Ord a => (a -> b) -> a -> b 45 | memoOrd f = unsafePerformIO ( 46 | do { tref <- newIORef (M.empty) 47 | ; return (applyStableOrd f tref) 48 | }) 49 | 50 | applyStableOrd :: Ord a => (a -> b) -> IORef (M.Map a b) -> a -> b 51 | applyStableOrd f tbl arg = unsafePerformIO ( 52 | do { lkp <- M.lookup arg <$> readIORef tbl 53 | ; case lkp of 54 | Just result -> return result 55 | Nothing -> 56 | do { let res = f arg 57 | ; modifyIORef tbl (M.insert arg res) 58 | ; return res 59 | }}) 60 | 61 | 62 | data Some2 k1 k2 (f :: k1 -> k2 -> Type) where 63 | Some2 :: forall k1 k2 f a b. StableName (f a b) -> Some2 k1 k2 f 64 | 65 | instance Eq (Some2 k1 k2 f) where 66 | Some2 sn1 == Some2 sn2 = eqStableName sn1 sn2 67 | 68 | type SSNMap2 k1 k2 (f :: k1 -> k2 -> Type) v = I.IntMap [(Some2 k1 k2 f,v)] 69 | 70 | makeSn2 :: f a b -> Some2 k1 k2 f 71 | makeSn2 = Some2 . unsafePerformIO . makeStableName 72 | 73 | snMapLookup2 :: Some2 k1 k2 f -> SSNMap2 k1 k2 f v -> Maybe v 74 | snMapLookup2 (Some2 sn) m = do 75 | x <- I.lookup (hashStableName sn) m 76 | lookup (Some2 sn) x 77 | 78 | snMapInsert2 :: Some2 k1 k2 f -> v -> SSNMap2 k1 k2 f v -> SSNMap2 k1 k2 f v 79 | snMapInsert2 (Some2 sn) res = I.insertWith (++) (hashStableName sn) [(Some2 sn,res)] 80 | 81 | data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) where 82 | KV :: forall k1 k2 f v a b. StableName (f a b) -> v a b -> KV k1 k2 f v 83 | 84 | type SNMap22 k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = I.IntMap [KV k1 k2 f v] 85 | 86 | snMap22Lookup :: StableName (f a b) -> SNMap22 k1 k2 f v -> Maybe (v a b) 87 | snMap22Lookup sn m = do 88 | x <- I.lookup (hashStableName sn) m 89 | lkKV sn x 90 | 91 | lkKV :: StableName (f a b) -> [KV k1 k2 f v] -> Maybe (v a b) 92 | lkKV _ [] = Nothing 93 | lkKV sn (KV sn' v:kvs) | eqStableName sn sn' = Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' 94 | | otherwise = lkKV sn kvs 95 | 96 | snMap22Insert :: KV k1 k2 f v -> SNMap22 k1 k2 f v -> SNMap22 k1 k2 f v 97 | snMap22Insert (KV sn res) = I.insertWith (++) (hashStableName sn) [KV sn res] 98 | 99 | 100 | -- | The type of a memo table for functions of a. 101 | type Memo a = forall r. 
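-- Both 'memo' (keyed by stable names) and 'memoOrd' (keyed by 'Ord') fit this
-- type, so they can be combined with 'memo2'/'memo3' below.  For example, a
-- hypothetical two-argument function @f@ whose first argument has an 'Ord'
-- instance could be memoised as
--
-- > f' = memo2 memoOrd memo f
--
-- caching on the first argument by ordering and on the second by stable name.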
(a -> r) -> (a -> r) 102 | 103 | -- | Memoize a two argument function (just apply the table directly for 104 | -- single argument functions). 105 | memo2 :: Memo a -> Memo b -> (a -> b -> r) -> (a -> b -> r) 106 | memo2 a b = a . (b .) 107 | 108 | -- | Memoize a three argument function. 109 | memo3 :: Memo a -> Memo b -> Memo c -> (a -> b -> c -> r) -> (a -> b -> c -> r) 110 | memo3 a b c = a . (memo2 b c .) 111 | -------------------------------------------------------------------------------- /TypedFlow/Memo2.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 2 | {-# LANGUAGE TypeOperators #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE PolyKinds #-} 6 | {-# LANGUAGE KindSignatures #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE GADTs #-} 10 | 11 | module TypedFlow.Memo2 where 12 | 13 | import Data.Kind (Type) 14 | import qualified Data.Map.Strict as M 15 | import System.Mem.StableName 16 | -- import Data.IORef 17 | -- import System.IO.Unsafe 18 | import Unsafe.Coerce 19 | import qualified Data.IntMap as I 20 | import Data.Type.Equality 21 | import Control.Monad.IO.Class 22 | import Data.IORef 23 | import TypedFlow.Types.Proofs (SingEq(..)) 24 | import Data.List (intercalate) 25 | 26 | data Map0 k (m :: Type -> Type) f v = forall . Map0 { 27 | m0Key :: f -> IO k, 28 | m0Empty :: m v, 29 | m0lk :: k -> m v -> Maybe v, 30 | m0upd :: k -> (Maybe v -> v) -> m v -> m v, 31 | m0fmap :: forall u w. (u -> w) -> m u -> m w, 32 | m0showKey :: k -> String, 33 | m0showTbl :: (v -> String) -> (m v -> String) 34 | } 35 | 36 | 37 | data Map1 (k :: k1 -> Type) (m :: (k1 -> Type) -> Type) (f :: k1 -> Type) (v :: k1 -> Type) = Map1 { 38 | m1Key :: forall x. f x -> IO (k x), 39 | m1Empty :: m v, 40 | m1lk :: forall x. k x -> m v -> Maybe (v x), 41 | m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> m v -> m v, 42 | m1showKey :: forall x . k x -> String, 43 | m1showTbl :: (forall x . v x -> String) -> (m v -> String) 44 | } 45 | 46 | data Map2 (k :: k1 -> k2 -> Type) (m :: (k1 -> k2 -> Type) -> Type) (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = Map2 { 47 | m2Key :: forall x y. f x y -> IO (k x y), 48 | m2Empty :: m v, 49 | m2lk :: forall x y. k x y -> m v -> Maybe (v x y), 50 | m2upd :: forall x y. k x y -> (Maybe (v x y) -> (v x y)) -> m v -> m v, 51 | -- m2fmap :: forall u w. (forall x y. u x y -> w x y) -> m u -> m w, 52 | m2showKey :: forall x y. k x y -> String, 53 | m2showTbl :: (forall x y. v x y -> String) -> (m v -> String) 54 | } 55 | 56 | data Map3 (k :: k1 -> k2 -> k3 -> Type) (m :: (k1 -> k2 -> k3 -> Type) -> Type) (f :: k1 -> k2 -> k3 -> Type) (v :: k1 -> k2 -> k3 -> Type) = Map3 { 57 | m3Key :: forall x y z. f x y z -> IO (k x y z), 58 | m3Empty :: m v, 59 | m3lk :: forall x y z. k x y z -> m v -> Maybe (v x y z), 60 | m3upd :: forall x y z. k x y z -> (Maybe (v x y z) -> (v x y z)) -> m v -> m v, 61 | m3showKey :: forall x y z. k x y z -> String, 62 | m3showTbl :: (forall x y z. v x y z -> String) -> (m v -> String) 63 | } 64 | 65 | newtype Id x = Id x 66 | 67 | ordMap :: forall k b. (Ord k, Show k) => Map0 k (M.Map k) k b 68 | ordMap = Map0 {..} where 69 | m0Key = return 70 | m0Empty = mempty 71 | m0lk k = M.lookup k 72 | m0upd k f m = M.alter (Just . 
f) k m 73 | m0fmap = fmap 74 | m0showKey = show 75 | m0showTbl sh m = intercalate ";" [(show k) <> "↦" <> (sh v) | (k,v) <- M.assocs m] 76 | 77 | data Single1 f g where 78 | None1 :: Single1 f g 79 | Single1 :: f a -> g a -> Single1 f g 80 | 81 | verifMap1 :: forall k v. SingEq k => Map1 k (Single1 k) k v 82 | verifMap1 = Map1 {..} where 83 | m1Key = return 84 | m1Empty = None1 85 | m1lk :: k a -> Single1 k b -> Maybe (b a) 86 | m1lk k = \case 87 | None1 -> Nothing 88 | Single1 k' v -> case testEq k k' of 89 | Just Refl -> Just v 90 | Nothing -> error "verifMap1: mismatching keys! (1)" 91 | m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> Single1 k v -> Single1 k v 92 | m1upd k f None1 = Single1 k (f Nothing) 93 | m1upd k f (Single1 k' v) = case testEq k k' of 94 | Just Refl -> Single1 k (f (Just v)) 95 | Nothing -> error "verifMap1: mismatching keys! (2)" 96 | m1showKey _ = "#" 97 | m1showTbl :: (forall x . v x -> String) -> (Single1 k v -> String) 98 | m1showTbl _ None1 = "·" 99 | m1showTbl h (Single1 _ v) = "!" <> (h v) 100 | 101 | 102 | testStable :: StableName a -> StableName b -> Maybe (a :~: b) 103 | testStable sn sn' | eqStableName sn sn' = Just (unsafeCoerce Refl) 104 | | otherwise = Nothing 105 | 106 | snMap2 :: forall f v. Map2 (SN2 f) (SNMap22 f) f v 107 | snMap2 = Map2 {..} where 108 | m2showTbl :: (forall x y. v x y -> String) -> (SNMap22 f v -> String) 109 | m2showTbl h (SNMap22 m) = intercalate "," [ m2showKey k <> "↦" <> h v | e <- I.elems m, KV k v <- e ] 110 | m2showKey (SN2 sn) = show (hashStableName sn) 111 | m2Key obj = SN2 <$> makeStableName obj 112 | m2Empty = mempty 113 | m2lk = snMap22Lookup 114 | m2upd :: SN2 f x y -> (Maybe (v x y) -> (v x y)) -> SNMap22 f v -> SNMap22 f v 115 | m2upd (SN2 sn) f (SNMap22 m) = SNMap22 $ 116 | I.alter (\case Nothing -> Just [KV (SN2 sn) (f Nothing)] 117 | Just p -> Just (updKV (SN2 sn) f p)) 118 | (hashStableName sn) 119 | m 120 | 121 | updKV :: SN2 f' x y -> (Maybe (v' x y) -> (v' x y)) -> [KV k1 k2 (SN2 f') v'] -> [KV k1 k2 (SN2 f') v'] 122 | updKV (SN2 sn) f [] = [KV (SN2 sn) (f Nothing)] 123 | updKV (SN2 sn) f (v@(KV (SN2 sn') x):xs) = case testStable sn sn' of 124 | Just Refl -> KV (SN2 sn') (f (Just x)):xs 125 | Nothing -> v : updKV (SN2 sn) f xs 126 | 127 | -- m2fmap :: forall u w. (forall x y. u x y -> w x y) -> SNMap22 f u -> SNMap22 f w 128 | -- m2fmap h (SNMap22 t) = SNMap22 (fmap (fmap (\(KV k v) -> KV k (h v))) t) 129 | 130 | snMap22Lookup :: forall a b f' v'. SN2 f' a b -> SNMap22 f' v' -> Maybe (v' a b) 131 | snMap22Lookup (SN2 sn) (SNMap22 m) = do 132 | x <- I.lookup (hashStableName sn) m 133 | lkKV sn x 134 | 135 | lkKV :: forall k1 k2 f' v' a b . StableName (f' a b) -> [KV k1 k2 (SN2 f') v'] -> Maybe (v' a b) 136 | lkKV _ [] = Nothing 137 | lkKV sn (KV (SN2 sn') v:kvs) = case testStable sn sn' of 138 | Just Refl -> Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' 139 | Nothing -> lkKV sn kvs 140 | 141 | 142 | data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) where 143 | KV :: forall k1 k2 f v a b. 
f a b -> v a b -> KV k1 k2 f v 144 | 145 | newtype SNMap22 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = SNMap22 (I.IntMap [KV k1 k2 (SN2 f) v]) deriving (Monoid, Semigroup) 146 | 147 | newtype SN2 (f :: k1 -> k2 -> Type) a b = SN2 (StableName (f a b)) 148 | 149 | data (:.:) (m1 :: k2 -> Type) (m2 :: k1 -> k2) (h :: k1) = Comp (m1 (m2 h)) 150 | 151 | 152 | data Sig02 f g x y where 153 | Ex02 :: f -> g x y -> Sig02 f g x y 154 | 155 | data Sig03 f g x y z where 156 | Ex03 :: f -> g x y z -> Sig03 f g x y z 157 | 158 | data Sig12 f g x y z where 159 | Ex12 :: f x -> g y z -> Sig12 f g x y z 160 | 161 | data Sig22 f g x y where 162 | Ex22 :: f x y -> g x y -> Sig22 f g x y 163 | 164 | data P33 f g x y z where 165 | T33 :: f x y z -> g x y z -> P33 f g x y z 166 | 167 | 168 | 169 | containing00 :: (forall v. Map0 k1 m1 f v) -> Map0 k2 m2 g h -> Map0 (k1,k2) (m1 :.: m2) (f,g) h 170 | containing00 f g = Map0 171 | { 172 | m0Key = (\(a,b) -> (,) <$> m0Key f a <*> m0Key g b), 173 | m0Empty = Comp (m0Empty f), 174 | m0lk = \(k1,k2) (Comp t) -> do t' <- m0lk f k1 t; m0lk g k2 t', 175 | m0upd = \(k1,k2) h (Comp t) -> Comp $ m0upd f k1 (m0upd g k2 h . \case Just tb -> tb; Nothing -> (m0Empty g)) t, 176 | m0fmap = \h (Comp t) -> Comp $ m0fmap f (m0fmap g h) t, 177 | m0showKey = \(k1,k0) -> m0showKey f k1 <> "," <> m0showKey g k0, 178 | m0showTbl = \h (Comp t) -> m0showTbl f (m0showTbl g h) t 179 | } 180 | 181 | containing02 :: (forall v. Map0 k1 m1 f v) -> Map2 k2 m2 g h -> Map2 (Sig02 k1 k2) (m1 :.: m2) (Sig02 f g) h 182 | containing02 f g = Map2 183 | { 184 | m2Key = (\(Ex02 a b) -> Ex02 <$> m0Key f a <*> m2Key g b), 185 | m2Empty = Comp (m0Empty f), 186 | m2lk = \(Ex02 k1 k2) (Comp t) -> do t' <- m0lk f k1 t; m2lk g k2 t', 187 | m2upd = \(Ex02 k1 k2) h (Comp t) -> Comp $ m0upd f k1 (m2upd g k2 h . \case Just tb -> tb; Nothing -> (m2Empty g)) t, 188 | -- m2fmap = \h (Comp t) -> Comp $ m0fmap f (m2fmap g h) t, 189 | m2showKey = \(Ex02 k1 k2) -> m0showKey f k1 <> "," <> m2showKey g k2, 190 | m2showTbl = \h (Comp t) -> m0showTbl f (m2showTbl g h) t 191 | } 192 | 193 | containing03 :: (forall v. Map0 k1 m1 f v) -> Map3 k2 m2 g h -> Map3 (Sig03 k1 k2) (m1 :.: m2) (Sig03 f g) h 194 | containing03 f g = Map3 195 | { 196 | m3Key = (\(Ex03 a b) -> Ex03 <$> m0Key f a <*> m3Key g b), 197 | m3Empty = Comp (m0Empty f), 198 | m3lk = \(Ex03 k1 k3) (Comp t) -> do t' <- m0lk f k1 t; m3lk g k3 t', 199 | m3upd = \(Ex03 k1 k3) h (Comp t) -> Comp $ m0upd f k1 (m3upd g k3 h . \case Just tb -> tb; Nothing -> (m3Empty g)) t, 200 | m3showKey = \(Ex03 k1 k2) -> m0showKey f k1 <> "," <> m3showKey g k2 201 | , 202 | m3showTbl = \h (Comp t) -> m0showTbl f (m3showTbl g h) t 203 | } 204 | 205 | newtype Lam' (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) (a :: k1) = Lam' {fromLam' :: (m2 (h a))} 206 | data M12 (m1 :: (k1 -> Type) -> Type) (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) = M12 (m1 (Lam' m2 h)) 207 | 208 | containing12 :: (forall v. Map1 k1 m1 f v) -> (forall k4. Map2 k2 m2 g (h k4)) -> Map3 (Sig12 k1 k2) (M12 m1 m2) (Sig12 f g) h 209 | containing12 f g = Map3 210 | { 211 | m3Key = (\(Ex12 a b) -> Ex12 <$> m1Key f a <*> m2Key g b), 212 | m3Empty = M12 (m1Empty f), 213 | m3lk = \(Ex12 k1 k2) (M12 t) -> do Lam' t' <- m1lk f k1 t; m2lk g k2 t', 214 | m3upd = \(Ex12 k1 k2) h (M12 t) -> M12 $ m1upd f k1 (Lam' . m2upd g k2 h . (\case Just tb -> tb; Nothing -> m2Empty g) . 
fmap fromLam') t, 215 | m3showKey = \(Ex12 k1 k2) -> m1showKey f k1 <> ">" <> m2showKey g k2, 216 | m3showTbl = \h (M12 t) -> m1showTbl f (m2showTbl g h . fromLam') t 217 | } 218 | 219 | 220 | 221 | data F2m m g h = F2m (forall x y. g x y -> m (h x y)) 222 | data F2m' m g f h = F2m' (forall x y. g x y -> f x y -> m (h x y)) 223 | 224 | data F3m m g h = F3m (forall x y z. g x y z -> m (h x y z)) 225 | data F3m' m g f h = F3m' (forall x y z. g x y z -> f x y z -> m (h x y z)) 226 | 227 | memo2 :: forall g h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> n (h x y)) -> forall x y. g x y -> n (h x y)) -> n (F2m n g h) 228 | memo2 Map2{..} f = do 229 | tblRef <- liftIO $ newIORef m2Empty 230 | let finished :: forall x y. g x y -> n (h x y) 231 | finished arg = do 232 | tbl <- liftIO $ readIORef tblRef 233 | key <- liftIO $ m2Key arg 234 | case m2lk key tbl of 235 | Just result -> do 236 | -- liftIO $ putStrLn "memo2: hit" 237 | return result 238 | Nothing -> do 239 | -- liftIO $ putStrLn "memo2: miss" 240 | res <- f finished arg 241 | liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res) 242 | return res 243 | return (F2m finished) 244 | 245 | memo2' :: forall g f h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> f x y -> n (h x y)) -> forall x y. g x y -> f x y -> n (h x y)) -> n (F2m' n g f h) 246 | memo2' Map2{..} f = do 247 | tblRef <- liftIO $ newIORef m2Empty 248 | let finished :: forall x y. g x y -> f x y -> n (h x y) 249 | finished arg extra = do 250 | tbl <- liftIO $ readIORef tblRef 251 | key <- liftIO $ m2Key arg 252 | case m2lk key tbl of 253 | Just result -> do 254 | -- liftIO $ putStrLn "memo2': hit" 255 | return result 256 | Nothing -> do 257 | -- liftIO $ putStrLn ("memo2: miss " <> m2showKey key) -- <> " from " <> m3showTbl (const ".") tbl 258 | res <- f finished arg extra 259 | liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res) 260 | return res 261 | return (F2m' finished) 262 | 263 | memo3' :: forall g f h k m n. MonadIO n => Map3 k m g h -> ((forall x y z. g x y z -> f x y z -> n (h x y z)) -> forall x y z. g x y z -> f x y z -> n (h x y z)) -> n (F3m' n g f h) 264 | memo3' Map3{..} f = do 265 | tblRef <- liftIO $ newIORef m3Empty 266 | let finished :: forall x y z. 
g x y z -> f x y z -> n (h x y z) 267 | finished arg extra = do 268 | tbl <- liftIO $ readIORef tblRef 269 | key <- liftIO $ m3Key arg 270 | case m3lk key tbl of 271 | Just result -> do 272 | -- liftIO $ putStrLn "memo3: hit" 273 | return result 274 | Nothing -> do 275 | -- liftIO $ putStrLn ("memo3: miss " <> m3showKey key) -- <> " from " <> m3showTbl (const ".") tbl 276 | res <- f finished arg extra 277 | liftIO $ modifyIORef tblRef (m3upd key $ \_ -> res) 278 | return res 279 | return (F3m' finished) 280 | -------------------------------------------------------------------------------- /TypedFlow/Models/Topic.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE MultiParamTypeClasses #-} 3 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 4 | {-# LANGUAGE ConstraintKinds #-} 5 | {-# LANGUAGE RecordWildCards #-} 6 | {-# LANGUAGE AllowAmbiguousTypes #-} 7 | {-# LANGUAGE DataKinds #-} 8 | {-# LANGUAGE DeriveFoldable #-} 9 | {-# LANGUAGE DeriveFunctor #-} 10 | {-# LANGUAGE DeriveTraversable #-} 11 | {-# LANGUAGE FlexibleContexts #-} 12 | {-# LANGUAGE GADTs #-} 13 | {-# LANGUAGE MagicHash #-} 14 | {-# LANGUAGE RankNTypes #-} 15 | {-# LANGUAGE ScopedTypeVariables #-} 16 | {-# LANGUAGE StandaloneDeriving #-} 17 | {-# LANGUAGE TypeApplications #-} 18 | {-# LANGUAGE TypeFamilies #-} 19 | {-# LANGUAGE TypeInType #-} 20 | {-# LANGUAGE TypeOperators #-} 21 | {-# LANGUAGE UndecidableInstances #-} 22 | {-# LANGUAGE UnicodeSyntax #-} 23 | {-| 24 | Module : TypedFlow.Models.Topic 25 | Description : Topic models 26 | Copyright : (c) Jean-Philippe Bernardy, 2017 27 | License : LGPL-3 28 | Maintainer : jean-philippe.bernardy@gu.se 29 | Stability : experimental 30 | -} 31 | 32 | 33 | module TypedFlow.Models.Topic where 34 | import Prelude hiding (RealFrac(..)) 35 | import TypedFlow.TF 36 | import TypedFlow.Layers 37 | import TypedFlow.Types 38 | import TypedFlow.Types.Proofs ((?>), knownSum') 39 | import TypedFlow.Learn 40 | import GHC.TypeLits 41 | import Data.Monoid ((<>)) 42 | import Data.Proxy 43 | 44 | -- | A convolutional document summary function. Described in 45 | -- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn. 46 | tdlmDocsummary :: forall 47 | (vocSize :: Nat) -- number of words 48 | (e :: Nat) -- size of the embedding 49 | (a :: Nat) -- number of features of the document vector summary 50 | (n :: Nat) -- length of the document 51 | (filterSize :: Nat) -- size of the convolution filter 52 | (t :: NBits) -- size of floats 53 | . KnownNat vocSize => KnownNat filterSize => KnownNat e => KnownNat a => KnownNat n => KnownBits t 54 | => (EmbeddingP vocSize e (Flt t)) 55 | -> (ConvP (Flt t) a e '[filterSize]) 56 | -> DropProb 57 | -> Gen (T '[n] Int32 -> T '[a] (Flt t)) 58 | tdlmDocsummary embs filters dropProb = do 59 | drpEmb <- mkDropout dropProb 60 | return $ \document -> 61 | let embeddedDoc :: Tensor [n,e] (Flt t) 62 | embeddedDoc = mapT (drpEmb . embedding @e @vocSize embs) document 63 | in reduceMax axis0 (conv' @'[n] filters embeddedDoc) 64 | 65 | tdlmDocsummary' :: forall 66 | (vocSize :: Nat) -- number of words 67 | (e :: Nat) -- size of the embedding 68 | (n :: Nat) -- length of the document 69 | -- (a :: Nat) -- number of features of the document vector summary 70 | -- (filterSize :: Nat) -- size of the convolution filter 71 | spec 72 | (t :: NBits) -- size of floats 73 | proxy 74 | . 
KnownNat vocSize => KnownNat (Ap Frst' spec) => KnownNat e => KnownNat (Ap Scnd' spec) => KnownNat n => KnownBits t 75 | => proxy spec 76 | -> (EmbeddingP vocSize e (Flt t)) 77 | -> (ConvP (Flt t) (Ap Scnd' spec) e '[(Ap Frst' spec)]) 78 | -> DropProb 79 | -> Gen (T '[n] Int32 -> T '[Ap Scnd' spec] (Flt t)) 80 | tdlmDocsummary' _proxy = tdlmDocsummary 81 | 82 | scnds :: SList xs -> SList (Ap (FMap Scnd') xs) 83 | scnds Unit = Unit 84 | scnds (_ :* xs) = Proxy :* scnds xs 85 | -- hmap _ Unit = Unit 86 | -- hmap f (x :* xs) = f x :* hmap f xs 87 | 88 | mkTdlmDocsummary :: forall 89 | (vocSize :: Nat) -- number of words 90 | (e :: Nat) -- size of the embedding 91 | (spec :: [(Nat,Nat)]) -- (size of the convolution filter,number of features) 92 | (n :: Nat) -- length of the document 93 | (t :: NBits) -- size of floats 94 | . KnownNat vocSize => KnownNat e => KnownNat n => KnownBits t 95 | => All KnownNat (Ap (FMap Scnd') spec) 96 | => All KnownNat (Ap (FMap Frst') spec) 97 | => SList spec 98 | -> DropProb 99 | -> Gen (T '[n] Int32 -> T '[Sum (Ap (FMap Scnd') spec)] (Flt t)) 100 | mkTdlmDocsummary xs0 dropProb = case xs0 of 101 | Unit -> return (\_ -> zeros) 102 | (proxy :* xs) -> knownSum' (scnds xs) ?> 103 | do embs <- parameterDefault ("embs_topic_" ++ show (sListLength xs)) 104 | filters <- parameterDefault ("filters_topic_" ++ show (sListLength xs)) 105 | f <- tdlmDocsummary' @vocSize @e proxy embs filters dropProb 106 | fs <- mkTdlmDocsummary @vocSize @e xs dropProb 107 | return $ \input -> concat0 (f input) (fs input) 108 | 109 | -- | Parameter for topics. This is effectively map from document 110 | -- features (a) to topic representations (vectors of size b) via k 111 | -- topic distributions. 112 | data TopicP t a k b = TopicP {topicDistributions :: (T '[a,k] (Flt t)) -- ^ a linear map from documents features (a) to topic distributions (k) 113 | ,topicRepresentations :: (T '[k,b] (Flt t)) -- ^ a linear map from topic distributions (k) to topic representations (b) 114 | } 115 | 116 | instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => KnownTensors (TopicP t a k b) where 117 | travTensor f s (TopicP x y) = TopicP <$> travTensor f (s<>"_A") x <*> travTensor f (s<>"_B") y 118 | instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => ParamWithDefault (TopicP t a k b) where 119 | defaultInitializer = TopicP <$> glorotUniform <*> glorotUniform 120 | 121 | -- | A topic modeler. Described 'Topically Driven Neural Language 122 | -- Model' by Lau, Baldwin and Cohn. Returns a function converting raw 123 | -- representations (eg. document summaries) to topic representations. 124 | -- This representation can be used as input to a dense layer to 125 | -- predict a word, or as input to an LSTM (initial state) to predict 126 | -- sentences. 127 | mkTdlmTopic :: forall 128 | (kk :: Nat) -- number of topics 129 | (a :: Nat) -- document vector summary size 130 | (b :: Nat) -- topic representation size 131 | (t :: NBits) -- size of floats 132 | . 
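-- The regularizer added below is @separationConstant · max_{i,j} (C_ij − I_ij)²@,
-- where C is the correlation matrix between the L2-normalised topic vectors;
-- since C_ii ≈ 1, this penalises the largest correlation between two distinct
-- topics and pushes them towards orthogonality.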
KnownNat kk => KnownNat a => KnownNat b => KnownBits t 133 | => Float -> TopicP t a kk b -> Gen (T '[a] (Flt t) -> (Tensor '[b] (Flt t), Tensor '[kk] (Flt t))) 134 | mkTdlmTopic separationConstant (TopicP topicInput topicOutput) = do 135 | drpS <- mkDropout (DropProb 0.1) 136 | let topicNormalized :: T '[kk,b] (Flt t) 137 | topicNormalized = mapT normalize topicOutput 138 | -- matrix of correlation between the topics 139 | topicCorrelation :: T '[kk,kk] (Flt t) 140 | topicCorrelation = matmul topicNormalized (transpose01 topicNormalized) 141 | -- max correlation between two distinct topics 142 | topicOverlap = reduceMaxAll (square (topicCorrelation ⊝ eye)) 143 | addRegularizer (constant separationConstant ⊙ cast topicOverlap) -- regularizer which ensures that topics are disjoint 144 | 145 | return (\d -> let p :: T '[kk] (Flt t) 146 | p = softmax0 (topicInput ∙ d) -- attention distribution (among the topics) 147 | in (drpS (topicOutput ∙ p), p)) 148 | 149 | 150 | 151 | -- | Gating unit which can be used to mix a RNN hidden state with an 152 | -- external information source (eg. topic representation). Described 153 | -- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn; 154 | -- formula (3) 155 | tdlmGatingUnit :: KnownNat n => KnownFloat t => KnownNat m => (GRUP t m n) -> T '[n] t -> T '[m] t -> (T '[m] t) 156 | tdlmGatingUnit (GRUP wz wr w) s h = 157 | let x = concat0 h s 158 | z = sigmoid (wz ∙ x) 159 | r = sigmoid (wr ∙ x) 160 | hTilda = tanh (w ∙ (concat0 (r ⊙ h) s)) 161 | in ((ones ⊝ z) ⊙ h + z ⊙ hTilda) 162 | -------------------------------------------------------------------------------- /TypedFlow/Models/Transformer.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE PartialTypeSignatures #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE RecordWildCards #-} 7 | {-# LANGUAGE AllowAmbiguousTypes #-} 8 | {-# LANGUAGE DataKinds #-} 9 | {-# LANGUAGE DeriveFoldable #-} 10 | {-# LANGUAGE DeriveFunctor #-} 11 | {-# LANGUAGE DeriveTraversable #-} 12 | {-# LANGUAGE FlexibleContexts #-} 13 | {-# LANGUAGE GADTs #-} 14 | {-# LANGUAGE MagicHash #-} 15 | {-# LANGUAGE RankNTypes #-} 16 | {-# LANGUAGE ScopedTypeVariables #-} 17 | {-# LANGUAGE StandaloneDeriving #-} 18 | {-# LANGUAGE TypeApplications #-} 19 | {-# LANGUAGE TypeFamilies #-} 20 | {-# LANGUAGE TypeInType #-} 21 | {-# LANGUAGE TypeOperators #-} 22 | {-# LANGUAGE UndecidableInstances #-} 23 | {-# LANGUAGE UnicodeSyntax #-} 24 | {-# LANGUAGE NoStarIsType #-} 25 | {-| 26 | Module : TypedFlow.Models.Transformer 27 | Description : Topic models 28 | Copyright : (c) Jean-Philippe Bernardy, 2020 29 | License : LGPL-3 30 | Maintainer : jean-philippe.bernardy@gu.se 31 | Stability : experimental 32 | -} 33 | 34 | 35 | module TypedFlow.Models.Transformer where 36 | import Prelude hiding (RealFrac(..)) 37 | import TypedFlow.TF 38 | import TypedFlow.Abstract 39 | import TypedFlow.Layers 40 | import TypedFlow.Types 41 | import TypedFlow.Types.Proofs ((?>), knownSum') 42 | import GHC.TypeLits 43 | 44 | -- Convention for type variables: 45 | -- h = number of heads 46 | -- e = embedding size 47 | -- n = sequence length 48 | 49 | average :: forall e. KnownNat e => T '[e] Float32 -> Scalar Float32 50 | average = reduceMeanAll 51 | 52 | -- | Normalise a vector. But add a small epsilon to avoid division by zero 53 | normalizer :: forall e. 
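-- Concretely:  normalizer x = (x − mean x) ⊘ (std x + ε)  with ε = 0.001,
-- i.e. a layer-normalisation style rescaling (without a learned gain or bias).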
KnownNat e => T '[e] Float32 -> T '[e] Float32 54 | normalizer x = mapT (⊘ (sigma + epsilon)) xmu -- so the norm of result is almost 1 55 | where mu = average x 56 | xmu = mapT (⊝ mu) x -- so the average of xmu is 0 57 | sigma = sqrt (average (square xmu)) -- the norm of xmu. 58 | epsilon = 0.001 -- ? 59 | 60 | -- Informally: 61 | -- mapT f x = vector y such that y_i = f (x_i) -- (the first axis) 62 | 63 | dimAsFloat :: forall e. KnownNat e => Float 64 | dimAsFloat = fromIntegral (knownNatVal (natSat @e)) 65 | 66 | -- | dot product attention on one key (k) 67 | dotAttention1 :: forall e n. KnownNat e => KnownNat n 68 | => T '[e,n] Float32 -> T '[n,e] Float32 -> T '[e] Float32 -> T '[e] Float32 69 | dotAttention1 q v k = v ∙ softmax0 (mapT (⊘ normFactor) (q ∙ k)) 70 | where normFactor = constant (sqrt (dimAsFloat @e)) 71 | 72 | -- | dot product attention for every position 73 | dotAttention :: forall n e. KnownNat n => KnownNat e 74 | => T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 75 | dotAttention v k q = mapT (dotAttention1 (transpose01 q) v) k 76 | 77 | -- | h copies of a dense layer (the same for every copy). 78 | multiheadLinearEncoder :: forall h e. KnownNat e => KnownNat h => 79 | String -> Gen (T '[e] Float32 -> T '[h,e] Float32) 80 | multiheadLinearEncoder name = do 81 | wv <- parameterDefault ("w_" ++ name) 82 | return $ \x -> reshape (wv # x) 83 | 84 | multiheadSelfAttentionModule 85 | :: forall h n e. KnownNat n => KnownNat h => KnownNat e 86 | => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) 87 | multiheadSelfAttentionModule nm = do 88 | ev <- multiheadLinearEncoder @h ("v" ++ nm) 89 | eq <- multiheadLinearEncoder @h ("q" ++ nm) 90 | ek <- multiheadLinearEncoder @h ("k" ++ nm) 91 | w1 <- parameterDefault ("w1" ++ nm) 92 | -- w2 <- parameterDefault ("w2" ++ nm) 93 | return $ \x -> 94 | let v = transpose01 (mapT ev x) 95 | q = transpose01 (mapT eq x) 96 | k = transpose01 (mapT ek x) 97 | r :: T '[n,h,e] Float32 98 | r = transpose01 (zipWith3T dotAttention q k v) 99 | r' = mapT (dense @e w1 . reshape @'[h * e]) r 100 | in mapT ({-dense w2 . -}normalizer) (r' + x) 101 | -- x + mapT normalizer r' 102 | 103 | multiheadSelfAttentionModuleDecoder 104 | :: forall h n e. KnownNat n => KnownNat h => KnownNat e 105 | => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32) 106 | multiheadSelfAttentionModuleDecoder nm = do 107 | ev <- multiheadLinearEncoder @h ("v" ++ nm) 108 | eq <- multiheadLinearEncoder @h ("q" ++ nm) 109 | ek <- multiheadLinearEncoder @h ("k" ++ nm) 110 | w1 <- parameterDefault ("w1" ++ nm) 111 | -- w2 <- parameterDefault ("w2" ++ nm) 112 | return $ \x -- comes from decoder 113 | y -- comes from encoder 114 | -> 115 | let k = transpose01 (mapT ek y) 116 | v = transpose01 (mapT ev x) 117 | q = transpose01 (mapT eq y) 118 | r :: T '[n,h,e] Float32 119 | r = transpose01 (zipWith3T dotAttention q k v) 120 | r' = mapT (dense @e w1 . reshape @'[h * e]) r 121 | in mapT ({-dense w2 . -}normalizer) (r' + x) 122 | -- x + mapT normalizer r' 123 | 124 | 125 | feedForwardModule :: forall e. KnownNat e 126 | => String -> Gen (T '[e] Float32 -> T '[e] Float32) 127 | feedForwardModule nm = do 128 | w1 :: DenseP Float32 e e <- parameterDefault (nm ++ "w1") 129 | w2 <- parameterDefault (nm ++ "w2") 130 | return $ \x -> normalizer (x + (w2 # relu (w1 # x))) 131 | 132 | encoderModule :: forall h n e. 
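-- One encoder block; as defined below it is the composition
--
-- >  mapT feedForward . selfAttention . (+ positionalTensor) . dropout
--
-- i.e. dropout, add the positional tensor, multi-head self-attention, then a
-- position-wise feed-forward layer.  The residual connections and the
-- normalisation are performed inside the attention and feed-forward modules
-- defined above.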
KnownNat n => KnownNat h => KnownNat e => DropProb 133 | -> String -> T '[n,e] Float32 -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) 134 | encoderModule dropProb nm positionalTensor = do 135 | drp <- mkDropout dropProb 136 | selfAtt <- multiheadSelfAttentionModule @h (nm ++ "mh") 137 | ff <- feedForwardModule (nm ++ "ff") 138 | return (mapT ff . selfAtt . (+ positionalTensor) . drp) 139 | 140 | positionalModuleSinCos :: forall n e. KnownNat e => KnownNat n => T '[n,e] Float32 141 | positionalModuleSinCos = sin (transpose01 (broadcastT pos) * (broadcastT omega) + broadcastT phase) 142 | where pos = (cast (range @n @'B32)) :: T '[n] Float32 143 | phase = cast ((range @e @'B32) `floorMod` constant 2) * (constant pi/2) :: T '[e] Float32 144 | omega = constant (log 10000) * exp (constant (-2.0 / dimAsFloat @e) * cast (range @e @'B32)) 145 | -- Note I'm not dividing the frequence by 2 because integer 146 | -- division isn't implemented. Should not have any consequence. 147 | 148 | positionalModuleLearned :: KnownNat e => KnownNat n => Gen (T '[n,e] Float32) 149 | positionalModuleLearned = do 150 | e <- parameterDefault "positional" 151 | return $ let EmbeddingP x = e in x 152 | 153 | encoderStack :: forall h n e. KnownNat h => KnownNat n => KnownNat e 154 | => DropProb -> Int -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) 155 | encoderStack dropProb n = do 156 | p <- positionalModuleLearned 157 | encoders <- mapM (\i -> encoderModule @h dropProb ("enc" ++ show i) p) [1..n] 158 | return (foldr (.) id encoders) -- n-ary function composition 159 | -------------------------------------------------------------------------------- /TypedFlow/Python.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ViewPatterns #-} 2 | {-| 3 | Module : TypedFlow.Python 4 | Description : Python-generation Functions 5 | Copyright : (c) Jean-Philippe Bernardy, 2017 6 | License : LGPL-3 7 | Maintainer : jean-philippe.bernardy@gu.se 8 | Stability : experimental 9 | 10 | -} 11 | 12 | {-# LANGUAGE AllowAmbiguousTypes #-} 13 | {-# LANGUAGE ConstraintKinds #-} 14 | {-# LANGUAGE DataKinds #-} 15 | {-# LANGUAGE DeriveFoldable #-} 16 | {-# LANGUAGE DeriveFunctor #-} 17 | {-# LANGUAGE DeriveTraversable #-} 18 | {-# LANGUAGE FlexibleContexts #-} 19 | {-# LANGUAGE FlexibleInstances #-} 20 | {-# LANGUAGE GADTs #-} 21 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 22 | {-# LANGUAGE LambdaCase #-} 23 | {-# LANGUAGE MagicHash #-} 24 | {-# LANGUAGE MultiParamTypeClasses #-} 25 | {-# LANGUAGE OverloadedStrings #-} 26 | {-# LANGUAGE PatternSynonyms #-} 27 | {-# LANGUAGE RankNTypes #-} 28 | {-# LANGUAGE RecordWildCards #-} 29 | {-# LANGUAGE ScopedTypeVariables #-} 30 | {-# LANGUAGE StandaloneDeriving #-} 31 | {-# LANGUAGE TypeApplications #-} 32 | {-# LANGUAGE TypeFamilies #-} 33 | {-# LANGUAGE TypeInType #-} 34 | {-# LANGUAGE TypeOperators #-} 35 | {-# LANGUAGE UndecidableInstances #-} 36 | {-# LANGUAGE UndecidableSuperClasses #-} 37 | {-# LANGUAGE UnicodeSyntax #-} 38 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 39 | 40 | module TypedFlow.Python (compile, compileGen, generateFile) where 41 | 42 | import Data.Char (toLower) 43 | import Data.Proxy 44 | import Data.List (genericReplicate, ) 45 | import GHC.TypeLits 46 | import Control.Monad.State 47 | import TypedFlow.Types 48 | import TypedFlow.Broadcast (permToFun,unopInputShape) 49 | import TypedFlow.Types.Proofs 50 | import TypedFlow.Memo 51 | import Prettyprinter as PP 52 | import Prettyprinter.Render.String as PP 53 | 
import qualified Data.Map as M 54 | import TypedFlow.Learn 55 | import qualified Data.Sequence as S 56 | import Data.Sequence (Seq, (|>), ) 57 | import Data.Foldable 58 | 59 | first :: (t -> a) -> (t, b) -> (a, b) 60 | first f (x,y) = (f x,y) 61 | 62 | paramShape' :: VarInfo -> [Integer] 63 | paramShape' (VarInfo {varRef = Ref _ s _}) = shapeToList' s 64 | 65 | paramDType :: VarInfo -> Typ 66 | paramDType (VarInfo {varRef = Ref _ _ t}) = sTypTyp t 67 | 68 | paramName :: VarInfo -> String 69 | paramName (VarInfo {varRef = Ref {..}}) = refName 70 | 71 | 72 | generateFile :: String -> Python [VarInfo] -> IO () 73 | generateFile fname g = do 74 | putStrLn ("Parameters (total " ++ show (sum [product (paramShape' p) | p <- params]) ++ "):") 75 | forM_ params printParam 76 | writeFile fname output 77 | where (output,params) = generate g 78 | printParam p = putStrLn (paramName p ++ ": " ++ "T " ++ renderSimple (showShape' (paramShape' p)) ++ " " ++ showT (paramDType p)) 79 | 80 | named :: String -> DOC -> DOC 81 | named fname x = text (fname <> "=") <> x 82 | 83 | text :: String -> DOC 84 | text = pretty 85 | 86 | genFun :: forall b. String -> [DOC] -> Python b -> Python b 87 | genFun name args body = do 88 | gen (text "def " <> text name <> align (tuple args) <> text ":") 89 | withDOC (\b -> " " <> align b) body 90 | 91 | 92 | showTyp :: forall t. KnownTyp t => DOC 93 | showTyp = text (showT (typVal @t)) 94 | 95 | showSTyp :: forall t. STyp t -> DOC 96 | showSTyp t = knownTyp t $ showTyp @t 97 | 98 | showT :: Typ -> [Char] 99 | showT (Typ Bool _) = "tf.bool" 100 | showT (Typ Cmplx B32) = "tf.complex64" 101 | showT (Typ Cmplx B64) = "tf.complex128" 102 | showT (Typ k l) = "tf." ++ map toLower (show k) ++ drop 1 (show l) 103 | 104 | showShape' :: [Integer] -> DOC 105 | showShape' s = list (map (showDim' "None") s) 106 | 107 | showShape :: ∀ (s :: Shape). All KnownNat s => SList s -> DOC 108 | showShape s = showShape' (shapeToList'' s) 109 | 110 | showSShape :: ∀ (s :: Shape). SShape s -> DOC 111 | showSShape s = showShape' (shapeToList' s) 112 | 113 | showShapeType :: ∀ (s :: Shape). KnownShape s => DOC 114 | showShapeType = showSShape (typeSShape @s) 115 | 116 | -- | Show a shape, but "None" is replaced by "-1" 117 | showShapeMinus :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> DOC 118 | showShapeMinus s = list (map (showDim' "-1") (shapeToList'' s)) 119 | 120 | showShapeLen :: ∀ (s::Shape). KnownLen s => DOC 121 | showShapeLen = (text . show) (listTypeLen @ s) 122 | 123 | showDim' :: String -> Integer -> DOC 124 | showDim' none n = text (if n == 514229 then none else show n) 125 | 126 | showDimM :: forall n. KnownNat n => DOC 127 | showDimM = showDim' "-1" (natVal (Proxy @ n)) 128 | 129 | showDim :: forall n. KnownNat n => DOC 130 | showDim = showDim' "None" (natVal (Proxy @ n)) 131 | 132 | showDimS :: forall n. Sat KnownNat n -> DOC 133 | showDimS Sat = showDim @n 134 | 135 | gen :: DOC -> Python () 136 | gen s = modify $ \PyState{..} -> PyState {genText=genText |> s,..} 137 | 138 | setGen :: Seq DOC -> Python () 139 | setGen d = modify $ \PyState{..} -> PyState {genText=d,..} 140 | 141 | (<--) :: Ref Int s t -> UntypedExpression -> Python () 142 | x <-- y = gen (pyVarRepr x <> text "=" <> y) 143 | 144 | 145 | renderSimple :: Doc ann -> String 146 | renderSimple = renderString . layoutPretty (LayoutOptions Unbounded) 147 | 148 | -- | save an intermediate result to a variable and save it to 149 | -- genAssignTable for future re-use. 150 | cache :: forall s t. 
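-- The rendered python expression itself is used as the key in genAssignTable,
-- so two syntactically identical expressions are assigned to a single
-- generated variable: a simple common-subexpression elimination on the
-- emitted code.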
KnownTyp t => KnownShape s => DOC -> Python DOC 151 | cache x = do 152 | let x' = renderSimple x 153 | mcache <- M.lookup x' <$> gets genAssignTable 154 | case mcache of 155 | Just y -> do 156 | -- comment ("cache hit: " <> text x') 157 | return y 158 | Nothing -> do 159 | -- comment ("cache miss") 160 | v <- newPyVar @s @t 161 | comment ("shape: " <> (showShapeType @s)) 162 | v <-- x 163 | modify $ (\g -> g {genAssignTable = M.insert x' (pyVarRepr v) (genAssignTable g)}) 164 | return (pyVarRepr v) 165 | 166 | newPyVar' :: forall s t. SShape s -> STyp t -> Python (Ref Int s t) 167 | newPyVar' s t = knownSShape s ?> (knownTyp t $ newPyVar @s @t) 168 | 169 | newId :: Python Integer 170 | newId = do 171 | n <- gets genId 172 | modify $ \PyState{..} -> PyState {genId=genId+1,..} 173 | return n 174 | 175 | newPyVar :: forall s t. KnownShape s => KnownTyp t => Python (Ref Int s t) 176 | newPyVar = do 177 | n <- newId 178 | return $ Ref (fromIntegral n) typeSShape typeSTyp 179 | 180 | pyVarInfoRepr :: VarInfo -> DOC 181 | pyVarInfoRepr i = text (varName i) 182 | 183 | pyVarRepr :: Ref Int s t -> DOC 184 | pyVarRepr (Ref n _ _) = text ("var" <> show n) 185 | 186 | tuple :: [DOC] -> DOC 187 | tuple = parens . align . sep . punctuate comma 188 | dict :: [(String,DOC)] -> DOC 189 | dict xs = braces $ align $ sep $ punctuate comma [text (show k) <> ":" <> v | (k,v) <- xs] 190 | 191 | funcall :: String -> [DOC] -> DOC 192 | funcall = funcall' . text 193 | 194 | funcall' :: DOC -> [DOC] -> DOC 195 | funcall' f args = f <> tuple args 196 | 197 | comment :: DOC -> Python () 198 | comment c = gen ("#" <> c) 199 | 200 | func :: String -> [DOC] -> [(String,DOC)] -> DOC 201 | func fname positional namedArgs = funcall fname (positional ++ map (uncurry named) namedArgs ) 202 | 203 | withDOC :: forall a. (DOC -> DOC) -> Python a -> Python a 204 | withDOC f g = do 205 | before <- gets genText 206 | setGen mempty 207 | x <- g 208 | after <- gets genText 209 | setGen (before |> f (vcat $ toList after)) 210 | return x 211 | 212 | generate :: Python [VarInfo] -> (String,[VarInfo]) 213 | generate s = (renderString (layoutPretty (LayoutOptions (AvailablePerLine 92 1)) (vcat $ toList genText)), 214 | genPyVars) 215 | where (genPyVars,PyState{..}) = runState s initPyState 216 | initPyState = PyState {genPureTable = mempty 217 | ,genAssignTable = mempty 218 | ,genText = mempty 219 | ,genId = 10000} 220 | 221 | generatePure :: forall s t. KnownTyp t => KnownShape s => T s t -> Python DOC 222 | generatePure x = do 223 | let sn = makeSn2 x 224 | mv <- snMapLookup2 sn <$> gets genPureTable 225 | case mv of 226 | Just v -> do 227 | -- comment ("gp hit:" <> v) 228 | return v 229 | Nothing -> do 230 | -- comment ("gp miss") 231 | e <- generatePure' (\s x' -> knownSShape s ?> generatePure x') typeSShape x 232 | v <- cache @s @t e 233 | modify (\g -> g {genPureTable = (snMapInsert2 sn v) (genPureTable g)}) 234 | return v 235 | 236 | genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC 237 | genDistr d sh s1 = case d of 238 | TruncatedNormalD stddev -> funcall "tf.random.truncated_normal" 239 | [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)] 240 | UniformD low high -> funcall "tf.random.uniform" [showSShape (sh .+. 
s1) 241 | ,named "minval" (float low) 242 | ,named "maxval" (float high) 243 | ,named "dtype" (showTyp @t)] 244 | OrthogonalD -> 245 | funcall' (funcall "tf.keras.initializers.orthogonal" []) [named "dtype" (showTyp @t), named "shape" (showSShape (sh .+. s1))] 246 | 247 | generatePure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> Python DOC) -> SShape s -> T s t -> Python DOC 248 | generatePure' rec sR = knownSShape sR ?> \case 249 | Unbroadcast{} -> error "broadcasting operation did not complete (Unbroadcast)!" 250 | BroadcastT _ _ _ sh x -> --- error "broadcasting operation did not complete (BroadcastT)!" 251 | do 252 | -- debug help 253 | rx <- rec sh x 254 | return (funcall "ERROR:BroadcastT" [rx]) 255 | MapT {} -> error "broadcasting operation did not complete (mapT)!" 256 | ZipT {} -> error "broadcasting operation did not complete (ZipT)!" 257 | Zip3T {} -> error "broadcasting operation did not complete (Zip3T)!" 258 | If c x y -> do 259 | rc <- rec typeSShape c 260 | rx <- rec typeSShape x 261 | ry <- rec typeSShape y 262 | return (func "tf.cond" [rc] [("true_fn", lambda0 rx) ,("false_fn", lambda0 ry)]) 263 | where lambda0 z = text "lambda: " <> z 264 | -- if broadcast_to is broken: https://github.com/tensorflow/tensorflow/issues/21901 265 | -- DirectBroadcast s0 s1 s2 s3 x -> do 266 | -- recx <- rec (s0 .+. s2) x 267 | -- let expanded = func "tf.reshape" [recx,list (map (showDim' "-1") 268 | -- (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 269 | -- ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] [] 270 | -- return (funcall "tf.add" [expanded, func "tf.zeros" [showSShape sR] [("dtype", showTyp @t)]]) 271 | DirectBroadcast s0 s1 s2 s3 x -> do 272 | recx <- rec (s0 .+. s2) x 273 | let expanded = func "tf.reshape" [recx,list (map (showDim' "-1") 274 | (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 275 | ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] [] 276 | return (funcall "tf.broadcast_to" [expanded, showSShape sR]) 277 | Noise noiseId s0 s1 x -> do 278 | return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId) 279 | T op -> return $ case op of 280 | ExternalVar (Ref v _ _) -> text v 281 | Variable v -> pyVarRepr v 282 | (Constant c) -> funcall "tf.constant" [prettyT @t c, named "shape" (showSShape sR), named "dtype" (showTyp @t)] 283 | (Range n@Sat) -> (func "tf.range" [] [("start",integer 0), 284 | ("limit",integer (natVal n)), 285 | ("dtype",showTyp @t)]) 286 | Where c x y -> do 287 | rc <- rec typeSShape c 288 | rx <- rec typeSShape x 289 | ry <- rec typeSShape y 290 | return (funcall "tf.where" [rc, rx, ry]) 291 | UnOp operation s0 x -> do 292 | recx <- rec (s0 .+. 
unopInputShape operation) x 293 | return $ case operation of 294 | Diag _ -> funcall "tf.matrix_diag" [recx] 295 | Cast -> funcall "tf.cast" [recx,showTyp @t] 296 | StopGradient -> funcall "tf.stop_gradient" [recx] 297 | ExpM _ -> funcall "tf.linalg.expm" [recx] 298 | ZeroTriangle _ side k -> funcall ("tf.experemental.numpy.tri" ++ case side of Upper -> "u"; Lower -> "l") [recx, integer k] 299 | Conjugate -> funcall "tf.math.conj" [recx] 300 | RealPart -> funcall "tf.math.real" [recx] 301 | Axis1Op _ (SliceOp _ _ lo hi) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer lo <> text ":" <> integer hi]) 302 | Axis1Op _ (AccessOp _ idx) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer idx]) 303 | Axis1Op _ op' -> 304 | let (op,args) = case op' of 305 | SliceOp {} -> error "Python: panic: sliceop is special" 306 | AccessOp {} -> error "Python: panic: accessop is special" 307 | ReverseT _ -> ("tf.reverse",[]) 308 | OneHot depth -> ("tf.one_hot",[("dtype",showTyp @t), ("depth", showDimS depth)]) 309 | ArgMax{} -> ("tf.argmax",[("output_type",showTyp @t)]) 310 | ReduceOp _ r -> ("tf.reduce_" ++ rop, []) 311 | where rop = case r of 312 | Max -> "max" 313 | Min -> "min" 314 | Sum -> "sum" 315 | Mean -> "mean" 316 | axisName = if op == "tf.nn.softmax" then "dim" else "axis" -- use dim before TF 1.5 317 | useAxisList = case op' of ReverseT _ -> True; _ -> False 318 | in func op [recx] ((axisName,(if useAxisList then (list . (:[])) else id) (integer (sListLength s0))):args) 319 | Float1Op op' -> funcall op (recx:args) 320 | where (op,args) = case op' of 321 | HardSigmoid -> ("tf.keras.backend.hard_sigmoid",[]) 322 | Relu -> ("tf.nn.relu",[]) 323 | ClipByValue lo hi -> ("tf.clip_by_value",[float lo,float hi]) 324 | _ -> ("tf." ++ map toLower (show op'), []) 325 | Num1Op op' -> funcall op (recx:args) 326 | where (op,args) = case op' of 327 | Negate -> ("tf.negative",[]) 328 | _ -> ("tf." ++ map toLower (show op'), []) 329 | MatMul s0 a b c x y -> do 330 | recx <- rec (s0 .+. (:*) a ((:*) b Unit)) x 331 | recy <- rec (s0 .+. (:*) b ((:*) c Unit)) y 332 | return (funcall "tf.matmul" [recx, recy]) 333 | BinOp operation s0 s1 _ s2 _ x y -> do 334 | recx <- rec (s0 .+. s1) x 335 | recy <- rec (s0 .+. s2) y 336 | return $ case operation of 337 | Simple2Op sop -> let pop = case sop of 338 | MkComplex -> "tf.complex" 339 | Add -> "tf.add" 340 | Divide -> "tf.divide" 341 | IntegerDiv -> "tf.math.floordiv" 342 | Equal -> "tf.equal" 343 | Subtract -> "tf.subtract" 344 | Multiply -> "tf.multiply" 345 | Minimum -> "tf.minimum" 346 | Maximum -> "tf.maximum" 347 | Comparision op -> "tf.math." ++ case op of 348 | Less -> "less" 349 | Greater -> "greater" 350 | LessOrEqual -> "less_equal" 351 | GreaterOrEqual -> "greater_equal" 352 | Logic op -> "tf.math.logical_" ++ case op of 353 | And -> "and" 354 | Or -> "or" 355 | FloorMod -> "tf.math.floorMod" 356 | in funcall pop [recx,recy] 357 | SigmoidCrossEntropyWithLogits -> func "tf.nn.sigmoid_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] 358 | SparseSoftmaxCrossEntropyWithLogits -> func "tf.nn.sparse_softmax_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] 359 | SoftmaxCrossEntropyWithLogits -> func "tf.nn.softmax_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] -- FIXME: use _v2 for TF 1.5 360 | ReshapeFrom s t -> do 361 | rt <- rec s t 362 | return (funcall "tf.reshape" [rt, showShapeMinus sR]) 363 | Concat s0 s1 xs -> do 364 | let go :: forall s0 s1 ns. 
SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> Python [DOC] 365 | go _ _ Unit = return [] 366 | go s0' s1' (Catable n y :* ys) = (:) <$> rec (s0' .+. n :* s1') y <*> go s0' s1' ys 367 | rxs <- go s0 s1 xs 368 | return (funcall "tf.concat" [list rxs, text "axis=" <> integer (sListLength s0)]) 369 | Transpose s p x -> do 370 | rx <- rec s x 371 | comment ("transpose: p = " <> text (show p) <> "; " <> text (show s)) 372 | return (func "tf.transpose" [rx] [("perm",list (map (integer . permToFun p) [0.. sListLength s-1]))]) 373 | Gather indexShape s0 m s1 x ix -> do 374 | rx <- rec (s0 .+. ((:*) m s1)) x 375 | rix <- rec (s0 .+. indexShape) ix 376 | return (func "tf.gather" [named "params" rx, named "indices" rix, named "batch_dims" (integer (sListLength s0)), named "axis" (integer (sListLength s0))] []) 377 | GatherND containerShape elementShape indexShape x ix -> do 378 | rx <- rec (containerShape .+. elementShape) x 379 | rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix 380 | return (func "tf.gather_nd" [rx, rix] []) 381 | Convolution bs inChans outChans filterShape s0 x filters -> do 382 | recx <- rec ((:*) bs (s0 *: inChans)) x 383 | recFilters <- rec (filterShape .+. ((:*) inChans ((:*) outChans Unit))) filters 384 | return (func "tf.nn.convolution" [recx, recFilters] [("padding",text (show ("SAME"::String))),("data_format", text (show dataFormat))]) 385 | where dataFormat = case sListLength filterShape of 386 | 1 -> ("NWC" :: String) 387 | 2 -> "NHWC" 388 | 3 -> "NDHWC" 389 | _ -> error "convolution: more than 3 spatial dimensions are not supported!" 390 | Pool bs window typ numChans outSpatial x -> do 391 | rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x 392 | return (func "tf.nn.pool" 393 | [rx, showSShape window, typ'] 394 | [("strides", showSShape window), 395 | ("padding",text (show ("SAME" :: String)))]) 396 | where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String) 397 | Softmax _ _ x -> do 398 | rx <- rec typeSShape x 399 | return $ func "tf.nn.softmax" [rx] [("axis","1")] 400 | -- _ -> error "Python compiler: case not covered" 401 | type Python a = State PyState a 402 | 403 | generateParameters :: [VarInfo] -> Python [DOC] 404 | generateParameters genVars = do 405 | -- generate variables 406 | forM genVars $ \v -> case v of 407 | VarInfo {..} -> case varRef of 408 | Ref refId shap typ -> do 409 | ii <- case varInitial of 410 | Nothing -> return [] 411 | Just iii -> do 412 | iiii <- case knownSShape shap of 413 | Sat -> knownTyp typ $ generatePure iii 414 | return [named "initial_value" iiii] 415 | var <- newPyVar' shap typ 416 | var <-- funcall "tf.Variable" ([named "name" (string refId), named "trainable" (bool varTrainable)] ++ ii) 417 | return (pyVarRepr var) 418 | 419 | -- | Clip a gradient 420 | clipByGlobalNorm :: Float -> UntypedExpression -> UntypedExpression 421 | clipByGlobalNorm maxNorm x = funcall "tf.clip_by_global_norm" [x,float maxNorm] <> brackets (int 0) 422 | -- clip_by_global_norm returns a couple (clipped grads, global_norm) 423 | 424 | -- | Gradient of wrt. given parameters. 
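-- That is, the gradient of the first (untyped) expression with respect to
-- the variables passed as the second argument; it simply renders as a call
-- to @tf.gradients@ in the generated Python, so e.g. @grad loss params@
-- becomes the Python expression @tf.gradients(loss, params)@.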
425 | grad :: UntypedExpression -> UntypedExpression -> UntypedExpression 426 | grad y vars = funcall "tf.gradients" [y, vars] 427 | 428 | 429 | fnToPython ::[VarInfo] -> PreparedFunction -> Python () 430 | fnToPython params PreparedFunction{pfInputs = SomeSuch placeHolders, 431 | pfOutputs = SomeSuch returned,..} = do 432 | -- we can't re-use intermediate computations from initialisers or other functions: 433 | modify $ \PyState {..} -> PyState {genPureTable = mempty, genAssignTable = M.empty,..} 434 | gen (text "@tf.function") 435 | genFun (pfName <> "_fn") (text "training_placeholder": 436 | map pyVarInfoRepr params ++ 437 | hMapToList @KnownPlaceholder (text . placeholderName) placeHolders) $ 438 | do returns <- hfor @KnownPlaceholder returned $ \ph@(PHT x) -> do 439 | r <- generatePure x 440 | return (placeholderName ph,r) 441 | gen (text "return " <> dict returns) 442 | return () 443 | gen (text pfName <> " = " <> 444 | dict [ 445 | ("function",text pfName <> "_fn"), 446 | ("batched",bool pfBatched), 447 | ("placeholders",dict (hMapToList @KnownPlaceholder 448 | (\ph -> case placeHolderRef ph of 449 | Ref nm shape typ -> 450 | (nm, dict [("shape",showSShape shape), ("dtype",showSTyp typ)])) 451 | placeHolders))]) 452 | return () 453 | 454 | toPython :: PreparedModel -> Python () 455 | toPython PreparedModel {..} = do 456 | gen (text "import tensorflow as tf") 457 | -- Static stuff: construct and initialise parameters, list placeholders, etc. 458 | genFun "mkModel" [] $ do 459 | vs <- generateParameters pmParams 460 | gen (text "return " <> 461 | dict [("batch_size", integer pmBatchSize) 462 | ,("parameters",list vs) 463 | ,("paramsdict",dict [(varName p, v) | (p,v) <- zip pmParams vs])]) 464 | -- Loss/Accur/Predict function 465 | forM_ pmFunctions (fnToPython pmParams) 466 | return () 467 | 468 | -- | Batchify and compile a model with simple input to output mapping. 469 | compile :: forall batchSize sx tx sy ty sy_ ty_ p 470 | . (KnownNat batchSize, KnownShape sx, KnownTyp tx, KnownShape sy, KnownTyp ty, KnownShape sy_, KnownTyp ty_, KnownShape p, KnownLen p) 471 | => Options 472 | -> Gen (Tensor sx tx -> Tensor sy ty -> ModelOutput ty_ p sy_) 473 | -> Python [VarInfo] 474 | compile options fGen = knownSShape (typeSShape @sy_ .+. typeSShape @p) ?> compileGen @batchSize options (sequenceA [simpleModel @p <$> fGen]) 475 | 476 | -- | Batchify and compile a model with generic input to output mapping and states 477 | compileGen :: forall bs. (KnownNat bs) 478 | => Options 479 | -> Gen [Function] 480 | -> Python [VarInfo] 481 | compileGen options model = toPython pm >> return pmParams 482 | where pm@PreparedModel{..} = prepare @bs model 483 | 484 | 485 | 486 | prettyT :: forall t. KnownTyp t => HaskType t -> DOC 487 | prettyT = case kindVal @(TypKind t) of 488 | SInt -> case bitsVal @(TypBits t) of 489 | SB32 -> int . fromIntegral 490 | SB64 -> int . 
fromIntegral 491 | SBool -> bool 492 | SFloat -> case bitsVal @(TypBits t) of 493 | SB32 -> float 494 | SB64 -> double 495 | 496 | 497 | 498 | data PyState = PyState {genId :: Integer 499 | ,genText :: S.Seq DOC 500 | ,genPureTable :: SSNMap2 Shape Typ T DOC 501 | -- ^ Table mapping pointers to their 502 | -- interpretations, so that sharing in the data 503 | -- structures can be exploited when generating 504 | ,genAssignTable :: M.Map String DOC 505 | -- ^ Table mapping expressions to variables, so 506 | -- that lost sharing can be recovered 507 | -- genPeeks :: [(String,UntypedExpression)] 508 | } 509 | 510 | type UntypedExpression = DOC 511 | type DOC = Doc () 512 | 513 | double :: Double -> DOC 514 | double = pretty 515 | float :: Float -> DOC 516 | float = pretty 517 | integer :: Integer -> DOC 518 | integer = pretty 519 | int :: Int -> DOC 520 | int = pretty 521 | bool :: Bool -> DOC 522 | bool = pretty 523 | string :: String -> DOC 524 | string = dquotes . text 525 | -------------------------------------------------------------------------------- /TypedFlow/TF.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE InstanceSigs #-} 2 | {-| 3 | Module : TypedFlow.TF 4 | Description : Binding to tensorflow functions 5 | Copyright : (c) Jean-Philippe Bernardy, 2017 6 | License : LGPL-3 7 | Maintainer : jean-philippe.bernardy@gu.se 8 | Stability : experimental 9 | 10 | This module provides direct access to the most commonly used 11 | TensorFlow functions. Higher-level functions are not defined here. 12 | -} 13 | 14 | {-# LANGUAGE LambdaCase #-} 15 | {-# LANGUAGE FlexibleInstances #-} 16 | {-# LANGUAGE MultiParamTypeClasses #-} 17 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 18 | {-# LANGUAGE ConstraintKinds #-} 19 | {-# LANGUAGE RecordWildCards #-} 20 | {-# LANGUAGE AllowAmbiguousTypes #-} 21 | {-# LANGUAGE DataKinds #-} 22 | {-# LANGUAGE DeriveFoldable #-} 23 | {-# LANGUAGE DeriveFunctor #-} 24 | {-# LANGUAGE DeriveTraversable #-} 25 | {-# LANGUAGE FlexibleContexts #-} 26 | {-# LANGUAGE GADTs #-} 27 | {-# LANGUAGE MagicHash #-} 28 | {-# LANGUAGE RankNTypes #-} 29 | {-# LANGUAGE ScopedTypeVariables #-} 30 | {-# LANGUAGE StandaloneDeriving #-} 31 | {-# LANGUAGE TypeApplications #-} 32 | {-# LANGUAGE TypeFamilies #-} 33 | {-# LANGUAGE TypeInType #-} 34 | {-# LANGUAGE TypeOperators #-} 35 | {-# LANGUAGE UndecidableInstances #-} 36 | {-# LANGUAGE UnicodeSyntax #-} 37 | {-# LANGUAGE ApplicativeDo #-} 38 | {-# LANGUAGE NoStarIsType #-} 39 | 40 | module TypedFlow.TF ( 41 | -- * Variables, Parameters 42 | -- ** Parameters 43 | parameter', 44 | parameter, 45 | parameterDefault, 46 | ParamWithDefault(..), 47 | -- getParameters, 48 | -- ** Persistent variables 49 | persistent, 50 | modifyPersistent, 51 | -- ** Placeholders and outputs 52 | -- placeholder, 53 | -- peekAt, 54 | -- peekAtMany, 55 | -- * Operations 56 | -- ** Constants 57 | zeros, 58 | ones, 59 | eye, 60 | constant, 61 | -- ** indexwise unary operators 62 | round, sigmoid, relu, floor, square, 63 | -- ** Indexwise binary operators 64 | addN, (⊕), (⊝), (⊙), (⊘), equal, 65 | minT, maxT, 66 | -- ** Products 67 | (∙), (·), matmul, 68 | -- ** Reducers 69 | reduceMeanAll, reduceSumAll, reduceMinAll, reduceMaxAll, 70 | reduceSum, reduceMean, reduceMin, reduceMax, 71 | -- argmax, 72 | argmax0, argmax1, 73 | softmax0, softmax1, 74 | -- ** Gradients 75 | -- grad, 76 | -- clipByGlobalNorm, 77 | clipByValue, 78 | -- ** Indexing 79 | last0, nth0, nth0', lookupT, lookupManyT, gather, 
range, reverseT, 80 | -- ** Split and concatenate 81 | slice, slice0, slice1, 82 | litStack0, 83 | stack0, unstack0, 84 | stack1, 85 | concatT, concat0, concat1, 86 | consT0, snocT0, 87 | headT0, tailT0, initT0, 88 | -- ** Reshaping 89 | expandDim, 90 | expandDim0, squeeze0, 91 | expandDim1, 92 | flatten2, flatten3, flatten12, flattenN2, 93 | inflate2, inflate3, inflate12, 94 | reshape, flattenAll, inflateAll, 95 | -- ** Transposition 96 | transposeN, transposeN', transpose01, transposeN01, 97 | -- ** Sequences 98 | sequenceMask, 99 | -- ** Convolutions 100 | convolution, 101 | -- ** Misc 102 | norm, normalize, 103 | stopGradient, 104 | cast, 105 | oneHot0, oneHot1, 106 | -- ** complex numbers 107 | expm, conjugate, realPart, 108 | -- ** Triangular and band Matrices 109 | tril, triu, fillTriangular, fillUpperTriangular, 110 | -- ** Testing conditions 111 | if_, where_, lessThan, 112 | -- * Contrib 113 | -- ** Mapping 114 | mapT, zipWithT, zipWith3T, 115 | mapTT, zipWithTT, 116 | -- ** Losses 117 | sigmoidCrossEntropyWithLogits, 118 | softmaxCrossEntropyWithLogits, 119 | sparseSoftmaxCrossEntropyWithLogits, 120 | -- ** Initializers 121 | noise, 122 | Distribution(..), 123 | varianceScaling, glorotUniform, 124 | 125 | -- ** Heterogeneous vectors 126 | repeatT, 127 | 128 | -- ** Heterogeneous heterogeneous vectors 129 | repeatHT 130 | ) where 131 | 132 | import Prelude hiding (RealFrac(..)) 133 | import GHC.TypeLits 134 | import Data.Proxy 135 | import TypedFlow.Types 136 | import TypedFlow.Types.Proofs 137 | import TypedFlow.Abstract 138 | import TypedFlow.Broadcast 139 | 140 | -- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector. 141 | repeatT :: forall (ss :: [Shape]) t. All KnownShape ss => KnownLen ss => 142 | (forall s. KnownShape s => T s t) -> HTV t ss 143 | repeatT f = zs (typeSList @ss) 144 | where zs :: forall (s :: [Shape]). All KnownShape s => SList s -> HTV t s 145 | zs Unit = Unit 146 | zs (_ :* n) = F f :* zs n 147 | 148 | -- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector. 149 | repeatHT :: forall ss. All KnownPair ss => KnownLen ss => 150 | (forall s t. KnownShape s => KnownTyp t => T s t) -> HHTV ss 151 | repeatHT f = zs (typeSList @ss) 152 | where zs :: forall s. All KnownPair s => SList s -> HHTV s 153 | zs Unit = Unit 154 | zs (_ :* n) = Uncurry f :* zs n 155 | 156 | -- | Declare a parameter to optimize. 157 | parameter' :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => String -> T shape t -> Gen (T shape t) 158 | parameter' = persistent True 159 | 160 | -- | Create a parameter. 161 | parameter :: forall p. KnownTensors p => String -> Gen p -> Gen p 162 | parameter s p = travTensor parameter' s =<< p 163 | 164 | -- | Declare variable which persists between calls to session.run. 165 | persistent :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> T shape t -> Gen (T shape t) 166 | persistent trainable name initial = do 167 | T . ExternalVar <$> GPVariable trainable name (Just initial) 168 | 169 | 170 | -- | Modify a mutable tensor. Attention: for the assignment to happen, 171 | -- the resulting tensor must be evaluated! 172 | modifyPersistent :: (KnownShape s,KnownTyp t) => T s t -> T s t -> Gen (T s t) 173 | modifyPersistent (T (Variable v)) x = GPModify v x -- FIXME: pattern matching here is poor style. 
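-- A small usage sketch (illustrative only; neither 'exampleBias' nor the
-- parameter name "example_bias" exist in the library): 'parameter'' declares
-- a trainable tensor with an explicit initialiser, whereas 'parameterDefault'
-- (further below) derives the initialiser from the parameter's type.
-- Parameters declared either way are reported by 'generateFile' and
-- initialised in the generated @mkModel@.
exampleBias :: forall n. KnownNat n => Gen (T '[n] Float32)
exampleBias = parameter' "example_bias" zeros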
174 | 175 | -- type family AddSpatialDims xs ys where 176 | -- AddSpatialDims '[x] '[] = '[x] 177 | -- AddSpatialDims (x ': xs) (y ': ys) = (x+(y-1)) ': AddSpatialDims xs ys 178 | 179 | -- -- | Convolution operation with no padding (applying the filter only on positions where the input is fully defined) 180 | -- convolutionValid :: forall outputChannels filterSpatialShape inChannels s t. 181 | -- KnownLen filterSpatialShape 182 | -- => Length filterSpatialShape <= 3 183 | -- => ((1 + Length filterSpatialShape) ~ Length s) -- the last dim of s is the batch size 184 | -- => T (inChannels ': AddSpatialDims s filterSpatialShape) t -- ^ input tensor (batched) 185 | -- -> T ('[outputChannels,inChannels] ++ filterSpatialShape) t -- ^ filters 186 | -- -> T (outputChannels ': s) t 187 | -- convolutionValid = untypedConvolution "VALID" 188 | 189 | -- poolNC :: forall dim s inputSpatialShape channels batchSize t. 190 | -- (inputSpatialShape ~ Take dim s, '[batchSize] ~ Drop dim s) => 191 | -- T ('[channels] ++ s) t -> 192 | -- Vec dim -> String -> String -> 193 | -- T ('[channels] ++ s) t 194 | -- poolNC (T input) windowShape poolingType padding = 195 | -- T (funcall "tf.nn.pool" [input,list (map float (vecToList windowShape)),text poolingType,text padding,named "data_format" (text "NWC")]) 196 | 197 | -- Difficulty: relate windowSize, inputSpatialShape, outputSpatialShape 198 | 199 | 200 | 201 | 202 | --------------------------- 203 | -- Contrib 204 | data VarianceScaleMode = VSFanIn | VSFanOut | VSAvg 205 | data Distrib = NormalDistr | UniformDistr 206 | 207 | -- | Random tensor with variance scaling according to deeplearning lore. 208 | varianceScaling :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownFloat t) => 209 | Float -> VarianceScaleMode -> Distrib -> Gen (Tensor '[inDim,outDim] t) 210 | varianceScaling factor mode distr = noise $ case distr of 211 | UniformDistr -> UniformD (-limit) limit 212 | NormalDistr -> TruncatedNormalD limit 213 | where 214 | fan_in = fromIntegral (natVal (Proxy @inDim)) 215 | fan_out = fromIntegral (natVal (Proxy @outDim)) 216 | n = max 1 $ case mode of 217 | VSFanIn -> fan_in 218 | VSFanOut -> fan_out 219 | VSAvg -> (fan_in + fan_out) / 2 220 | limit = sqrt ((case distr of NormalDistr -> 1.3; UniformDistr -> 3) * factor / n) 221 | 222 | 223 | glorotUniform :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) => Gen (Tensor '[outDim,inDim] ('Typ 'Float t)) 224 | glorotUniform = varianceScaling 1 VSAvg UniformDistr 225 | 226 | -- | 'cons' an element and an array (in the first dimension) 227 | consT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T s t -> T (n ': s) t -> T (n+1 ': s) t 228 | consT0 x xs = plusComm @1 @n #> concat0 (expandDim0 x) xs 229 | 230 | -- | 'snoc' an element and an array (in the first dimension) 231 | snocT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => KnownLen s => T (n ': s) t -> T s t -> T (n+1 ': s) t 232 | snocT0 xs x = concat0 xs (expandDim0 x) 233 | 234 | headT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (s) t 235 | headT0 xs = nth0 0 xs 236 | 237 | tailT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (n ': s) t 238 | tailT0 xs = incrPos @n #> -- 0 < n+1 239 | plusMinusAssoc @n @1 @1 #> -- (n+1) - 1 = -- n+ (1 - 1) 240 | slice0 @1 @(n+1) xs 241 | 242 | initT0 :: forall n s t. 
KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (n ': s) t 243 | initT0 xs = plusMono @n @1 #> -- n <= n+1 244 | slice0 @0 @n xs 245 | 246 | ---------------- 247 | -- Helpers 248 | 249 | -- | Product of a matrix of weights with a vector. 250 | (∙) :: (KnownNumeric t, KnownNat cols, KnownNat rows, KnownTyp t) => Tensor '[cols, rows] t -> Tensor '[cols] t -> Tensor '[rows] t 251 | m ∙ v = squeeze0 (matmul (expandDim0 v) m) 252 | infixl 7 ∙ 253 | 254 | -- | Dot product between two vectors. 255 | (·) :: ∀ n t. (KnownNumeric t, KnownNat n) => 256 | Tensor '[n] t -> Tensor '[n] t -> Tensor '[] t 257 | x · y = reduceSum0 (x ⊙ y) 258 | infixl 7 · 259 | 260 | -- | 2-Norm of a vector 261 | norm :: KnownBits t => KnownNat n 262 | => T '[n] (Flt t) -> Scalar (Flt t) 263 | norm = frobNorm 264 | 265 | -- | 2-Norm of a tensor 266 | frobNorm :: KnownShape s => KnownBits t => T s (Flt t) -> Scalar (Flt t) 267 | frobNorm = sqrt . reduceSumAll . square 268 | 269 | normalize :: (KnownNat n, KnownBits t) => 270 | T '[n] (Flt t) -> T '[n] (Flt t) 271 | normalize v = mapT (/ (norm v + epsilon)) v 272 | where epsilon = 1.0e-8 273 | 274 | fillTriangular :: forall n l t. 275 | (KnownNat n, KnownNat l, KnownNumeric t, (((l+l)-n) ~ (n*n)), n <= l) 276 | => Tensor '[l] t -> Tensor '[n,n] t 277 | fillTriangular x = plusMinusAssoc @l @l @n #> tril 0 (inflate2 (concat0 x rr)) 278 | where rr :: Tensor '[l - n] t 279 | rr = subIneq @l @n #> slice0 @0 @(l-n) (reverseT x) 280 | 281 | 282 | -- @lookupManyT def indices array@ lokup indices in array, returning def if the index is -1 283 | lookupManyT :: forall s n t. KnownNat n => KnownShape s => (KnownNumeric t) => Scalar t -> T s Int32 -> T '[n] t -> T s t 284 | lookupManyT def indices array = 285 | appRUnit @s #> mapTT @s (\idx -> where_ (equal idx (-1)) def (lookupT idx array)) indices 286 | 287 | 288 | -- | A flexible upper-triangular matrix function: fill the upper triangle with l elements. 289 | fillUpperTriangular :: forall n l t. KnownNumeric t => KnownNat n => KnownNat l => T '[l] t -> T '[n,n] t 290 | fillUpperTriangular x = 291 | zipWithTT @'[n,n] 292 | (\i j -> let idx :: Scalar Int32 293 | idx = ((i * (2 * n - i - 3)) `floorDiv` 2 + j - 1) 294 | 295 | -- The index to lookup in the input array. It is computed from the formula: 296 | -- Output[i,j] = (j-i-1) + ∑_k^(i-1) (n-k) 297 | -- 298 | -- The term j-i-1 is the distance from the upper diagonal. 299 | -- The sum is the number of elements in the previous rows 300 | 301 | in where_ (((j - i) `greaterThan` 0) `logicAnd` (idx `lessThan` l)) 302 | (lookupT idx x) 303 | zeros) 304 | range0 305 | range1 where 306 | 307 | n, l :: Scalar Int32 308 | n = constant (fromIntegral (natVal (Proxy @n))) 309 | l = constant (fromIntegral (natVal (Proxy @l))) 310 | 311 | -- "j" index 312 | range1 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w) 313 | range1 = broadcastT range 314 | 315 | -- "i" index 316 | range0 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w) 317 | range0 = transpose01 range1 318 | 319 | 320 | ------------------------- 321 | -- Generic parameters 322 | 323 | -- | Create a parameter and initialize it with a suitable default for its type. Control the exact initializer using 'parameter'. 324 | parameterDefault :: forall p. 
ParamWithDefault p => String -> Gen p 325 | parameterDefault name = parameter name defaultInitializer 326 | 327 | 328 | -- flattenHTV :: KnownTyp t => All KnownShape xs => HTV t xs -> Tensor '[Sum (Ap (FMap CProduct) xs)] t 329 | -- flattenHTV Unit = zeros 330 | -- flattenHTV (F x :* xs) = concat0 (flattenAll x) (flattenHTV xs) 331 | 332 | -- class CProduct (xs :: [Nat]) 333 | -- instance Fun CProduct where type Ap CProduct xs = Product xs 334 | 335 | -- inflateHTV :: ∀ xs s t. (All KnownShape xs, KnownLen s, KnownLen xs) => 336 | -- Tensor '[Sum (Ap (FMap CProduct) xs)] t -> Gen (HTV t xs) 337 | -- inflateHTV (T x) = do 338 | -- v <- newVar 339 | -- gen (v <> text " = " <> funcall "tf.split" [x, showShape' (prodshape @xs shapeSList), text "axis=0"]) 340 | -- return (mkArr @xs 0 shapeSList v) 341 | -- where mkArr :: forall zs. All KnownShape zs => Int -> SList zs -> DOC -> HTV t zs 342 | -- mkArr _ LZ _ = Unit 343 | -- mkArr i (LS _ n) v = F (unsafeReshape (T (v <> brackets (int i)) )):* mkArr (succ i) n v 344 | -- prodshape :: forall zs. All KnownShape zs => SList zs -> [Integer] 345 | -- prodshape LZ = [] 346 | -- prodshape (LS xx xs) = product (shapeToList' (shapeSListProxy xx)) : prodshape xs 347 | 348 | 349 | -- -- | Gradient of wrt. given parameters. 350 | -- grad' :: KnownLen xs => T s Float32 -> HHTV xs -> Gen (HHTV xs) 351 | -- grad' (T y) vars = do 352 | -- v <- newVar 353 | -- v <-- funcall "tf.gradients" [y, list (htoList (hmap (\(Uncurry (T x)) -> K x) vars)) ] 354 | -- return (mkArr 0 shapeSList v) 355 | -- where mkArr :: forall xs. Int -> SList xs -> DOC -> HHTV xs 356 | -- mkArr _ LZ _ = Unit 357 | -- mkArr i (LS _ n) v = Uncurry (T (v <> brackets (int i))) :* mkArr (succ i) n v 358 | -------------------------------------------------------------------------------- /TypedFlow/Types/Proofs.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE DeriveFoldable #-} 5 | {-# LANGUAGE DeriveFunctor #-} 6 | {-# LANGUAGE DeriveTraversable #-} 7 | {-# LANGUAGE FlexibleContexts #-} 8 | {-# LANGUAGE FlexibleInstances #-} 9 | {-# LANGUAGE GADTs #-} 10 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 11 | {-# LANGUAGE InstanceSigs #-} 12 | {-# LANGUAGE MagicHash #-} 13 | {-# LANGUAGE MultiParamTypeClasses #-} 14 | {-# LANGUAGE OverloadedStrings #-} 15 | {-# LANGUAGE PatternSynonyms #-} 16 | {-# LANGUAGE RankNTypes #-} 17 | {-# LANGUAGE RecordWildCards #-} 18 | {-# LANGUAGE ScopedTypeVariables #-} 19 | {-# LANGUAGE StandaloneDeriving #-} 20 | {-# LANGUAGE TypeApplications #-} 21 | {-# LANGUAGE TypeFamilies #-} 22 | {-# LANGUAGE TypeInType #-} 23 | {-# LANGUAGE TypeOperators #-} 24 | {-# LANGUAGE UndecidableInstances #-} 25 | {-# LANGUAGE UndecidableSuperClasses #-} 26 | {-# LANGUAGE UnicodeSyntax #-} 27 | {-# LANGUAGE CPP #-} 28 | #if __GLASGOW_HASKELL__ >= 806 29 | {-# LANGUAGE NoStarIsType #-} 30 | #endif 31 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 32 | 33 | module TypedFlow.Types.Proofs where 34 | 35 | 36 | import Prelude hiding (RealFrac(..)) 37 | import GHC.TypeLits 38 | import Data.Proxy 39 | import TypedFlow.Types hiding (T) 40 | import Data.Type.Equality 41 | import Unsafe.Coerce 42 | import Data.Kind (Type) 43 | class SingEq s where 44 | testEq :: forall a b. s a -> s b -> Maybe (a :~: b) 45 | 46 | instance SingEq (Sat KnownNat) where 47 | testEq :: forall n m. 
Sat KnownNat n -> Sat KnownNat m -> Maybe (n :~: m) 48 | testEq = testNatEqual 49 | 50 | natValS :: forall m. Sat KnownNat m -> Integer 51 | natValS Sat = natVal (Proxy @m) 52 | 53 | testNatEqual :: Sat KnownNat m -> Sat KnownNat n -> Maybe (m :~: n) 54 | testNatEqual m n = if natValS m == natValS n then Just (unsafeCoerce Refl) else Nothing 55 | 56 | instance SingEq f => SingEq (NP f) where 57 | testEq Unit Unit = Just Refl 58 | testEq (x :* xs) (y :* ys) = case (testEq x y, testEq xs ys) of 59 | (Just Refl, Just Refl) -> Just Refl 60 | _ -> Nothing 61 | testEq _ _ = Nothing 62 | 63 | instance SingEq SKind where 64 | testEq SBool SBool = Just Refl 65 | testEq SInt SInt = Just Refl 66 | testEq SFloat SFloat = Just Refl 67 | testEq _ _ = Nothing 68 | 69 | instance SingEq SNBits where 70 | testEq SB32 SB32 = Just Refl 71 | testEq SB64 SB64 = Just Refl 72 | testEq _ _ = Nothing 73 | 74 | instance SingEq STyp where 75 | testEq (STyp k b Refl) (STyp k' b' Refl) = case (testEq k k', testEq b b') of 76 | (Just Refl, Just Refl) -> Just Refl 77 | _ -> Nothing 78 | 79 | -- | Use a reified equality relation 80 | (#>) :: (a :~: b) -> ((a ~ b) => k) -> k 81 | Refl #> k = k 82 | infixr 0 #> 83 | 84 | -- | Use a reified arbitrary predicate 85 | (?>) :: Sat constraint a -> (constraint a => k) -> k 86 | Sat ?> k = k 87 | infixr 0 ?> 88 | 89 | -- | Use a reified arbitrary constraint 90 | (??>) :: Dict constraint -> (constraint => k) -> k 91 | Dict ??> k = k 92 | infixr 0 ??> 93 | 94 | productS :: forall s. SShape s -> Sat KnownNat (Product s) 95 | productS s = knownSShape s ?> 96 | knownProduct @s ?> 97 | Sat 98 | 99 | plusComm :: forall x y. (x + y) :~: (y + x) 100 | plusComm = unsafeCoerce Refl 101 | 102 | plusCommS :: forall x y px py. px x -> py y -> (x + y) :~: (y + x) 103 | plusCommS _ _ = plusComm @x @y 104 | 105 | plusAssoc :: forall x y z. (x + y) + z :~: x + (y + z) 106 | plusAssoc = unsafeCoerce Refl 107 | 108 | plusAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x + y) + z) :~: (x + (y + z)) 109 | plusAssocS _ _ _ = plusAssoc @x @y @z 110 | 111 | prodAssoc :: forall x y z. (x * y) * z :~: x * (y * z) 112 | prodAssoc = unsafeCoerce Refl 113 | 114 | prodAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x * y) * z) :~: (x * (y * z)) 115 | prodAssocS _ _ _ = prodAssoc @x @y @z 116 | 117 | prodCommS :: forall x y px py. px x -> py y -> (x * y) :~: (y * x) 118 | prodCommS _ _ = unsafeCoerce Refl 119 | 120 | termCancelation :: forall a b. (a + b) - b :~: a 121 | termCancelation = plusMinusAssoc @a @b @b #> cancelation @b #> Refl 122 | 123 | plusMinusAssoc :: forall x y z. (x + y) - z :~: x + (y - z) 124 | plusMinusAssoc = unsafeCoerce Refl 125 | 126 | cancelation :: (a - a) :~: 0 127 | cancelation = unsafeCoerce Refl 128 | 129 | plusMono :: forall a b. (a <=? (a+b)) :~: 'True 130 | plusMono = unsafeCoerce Refl 131 | 132 | succPos :: (1 <=? 1+j) :~: 'True 133 | -- CmpNat 0 (1 + n) :~: 'LT 134 | succPos = unsafeCoerce Refl 135 | 136 | succPosProx2 :: forall n proxy a. proxy n a -> (0 :<: (1+n)) 137 | succPosProx2 _ = succPos @n 138 | 139 | prodHomo :: forall x y. Product (x ++ y) :~: Product x * Product y 140 | prodHomo = unsafeCoerce Refl 141 | 142 | prodHomoS :: forall x y px py. px x -> py y -> ((Product (x ++ y) :~: (Product x * Product y))) 143 | prodHomoS _ _ = prodHomo @x @y 144 | 145 | knownProduct' :: forall s f. 
All KnownNat s => NP f s -> Sat KnownNat (Product s) 146 | knownProduct' Unit = Sat 147 | knownProduct' (_ :* n) = knownProduct' n ?> Sat 148 | 149 | knownProduct :: forall s. KnownShape s => Sat KnownNat (Product s) 150 | knownProduct = knownProduct' @s typeSList 151 | 152 | knownSumS :: forall s. NP (Sat KnownNat) s -> Sat KnownNat (Sum s) 153 | knownSumS Unit = Sat 154 | knownSumS (Sat :* n) = knownSumS n ?> Sat 155 | 156 | knownSum' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Sum s) 157 | knownSum' proxies = knownSumS (allKnown' proxies) 158 | 159 | knownSum :: forall s. KnownShape s => Sat KnownNat (Sum s) 160 | knownSum = knownSum' @s typeSList 161 | 162 | knownPlus :: forall m n. KnownNat m => KnownNat n => Sat KnownNat (m + n) 163 | knownPlus = Sat 164 | 165 | takeDrop :: forall s n. (PeanoNat n <= Length s) => (Take n s ++ Drop n s) :~: s 166 | takeDrop = unsafeCoerce Refl 167 | 168 | lengthHomo :: forall x y. Length (x ++ y) :~: Length x + Length y 169 | lengthHomo = unsafeCoerce Refl 170 | 171 | lengthHomoS :: forall x y proxyx proxyy. proxyx x -> proxyy y -> ((Length (x ++ y) :~: (Length x + Length y))) 172 | lengthHomoS _ _ = lengthHomo @x @y 173 | 174 | lengthInit :: forall s. (0 < Length s) => SList s -> ((Length (Init s) + 1) :~: Length s) 175 | lengthInit x = lengthHomo @(Init s) @'[Last s] #> initLast x #> Refl 176 | 177 | type a :<=: b = ((a <=? b):~: 'True) 178 | type i :<: j = (i+1) :<=: j 179 | 180 | incrPos :: forall x. 1 :<=: (x+1) 181 | incrPos = unsafeCoerce Refl 182 | 183 | 184 | subIneq :: forall x k. (x - k) :<=: x 185 | subIneq = unsafeCoerce Refl 186 | 187 | incrCong :: forall x y. ((x+1) ~ (y+1)) => x :~: y 188 | incrCong = unsafeCoerce Refl 189 | 190 | initLast :: forall s. {-(0 < Length s) => FIXME -} SList s -> ((Init s ++ '[Last s]) :~: s) 191 | initLast Unit = error "initLast': does not hold on empty lists" 192 | initLast ((:*) _ Unit) = Refl 193 | initLast ((:*) _ ((:*) y ys)) = initLast ((:*) y ys) #> Refl 194 | 195 | initLast' :: forall s. {-(0 < Length s) => FIXME -} KnownShape s => ((Init s ++ '[Last s]) :~: s) 196 | initLast' = initLast (typeSList @s) 197 | 198 | appRUnit :: forall s. (s ++ '[]) :~: s 199 | appRUnit = unsafeCoerce Refl 200 | 201 | appAssoc :: ((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs)) 202 | appAssoc = unsafeCoerce Refl 203 | 204 | appAssocS :: forall xs ys zs proxy1 proxy2 proxy3. 205 | proxy1 xs -> proxy2 ys -> proxy3 zs -> (((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs))) 206 | appAssocS _ _ _ = appAssoc @xs @ys @zs 207 | 208 | 209 | knownLast' :: All KnownNat s => SList s -> (KnownNat (Last s) => k) -> k 210 | knownLast' Unit _ = error "knownLast: does not hold on empty lists" 211 | knownLast' ((:*) _ Unit) k = k 212 | knownLast' ((:*) _ ((:*) y xs)) k = knownLast' ((:*) y xs) k 213 | 214 | knownLast :: forall s k. KnownShape s => (KnownNat (Last s) => k) -> k 215 | knownLast = knownLast' @s typeSList 216 | 217 | knownInit' :: All KnownNat s => SList s -> Sat KnownShape (Init s) 218 | knownInit' Unit = error "knownLast: does not hold on empty lists" 219 | knownInit' ((:*) _ Unit) = Sat 220 | knownInit' ((:*) _ ((:*) y xs)) = knownInit' ((:*) y xs) ?> Sat 221 | 222 | knownInit :: forall s. KnownShape s => Sat KnownShape (Init s) 223 | knownInit = knownInit' @s typeSList 224 | 225 | knownTail' :: forall x s k. All KnownNat s => SList (x ': s) -> (KnownShape s => k) -> k 226 | knownTail' ((:*) _ Unit) k = k 227 | knownTail' ((:*) _ ((:*) y xs)) k = knownTail' ((:*) y xs) k 228 | 229 | knownTail :: forall s x xs k. 
(s ~ (x ': xs), KnownShape s) => (KnownShape xs => k) -> k 230 | knownTail = knownTail' @x @xs typeSList 231 | 232 | knownAppendS :: forall s t pt. (All KnownNat s, KnownShape t) => SList s -> pt t -> Sat KnownShape (s ++ t) 233 | knownAppendS Unit _t = Sat 234 | knownAppendS ((:*) _ n) t = knownAppendS n t ?> Sat 235 | 236 | knownAppend :: forall s t. (KnownShape s, KnownShape t) => Sat KnownShape (s ++ t) 237 | knownAppend = knownAppendS (typeSList @s) (Proxy @t) 238 | 239 | 240 | -- knownFmap' :: forall f xs. SList xs -> SList (Ap (FMap f) xs) 241 | -- knownFmap' Unit = Unit 242 | -- knownFmap' ((:*) x n) = (:*) Proxy (knownFmap' @f n) 243 | 244 | knownSList :: NP proxy xs -> Sat KnownLen xs 245 | knownSList Unit = Sat 246 | knownSList (_ :* n) = knownSList n ?> Sat 247 | 248 | knownSShape :: SShape xs -> Sat KnownShape xs 249 | knownSShape Unit = Sat 250 | knownSShape ((:*) Sat s) = knownSShape s ?> Sat 251 | 252 | data DimExpr (a :: Nat) (x :: Nat) (b :: Nat) where 253 | ANat :: Sat KnownNat x -> DimExpr a x (a * x) 254 | (:*:) :: DimExpr a x b -> DimExpr b y c -> DimExpr a (x*y) c 255 | 256 | knownOutputDim :: forall a x b. Sat KnownNat a -> DimExpr a x b -> Sat KnownNat b 257 | knownOutputDim a (ANat x) = satMul a x 258 | knownOutputDim a (x :*: y) = knownOutputDim (knownOutputDim a x) y 259 | 260 | dimSat :: DimExpr a x b -> Sat KnownNat x 261 | dimSat (ANat s) = s 262 | dimSat (x :*: y) = dimSat x `satMul` dimSat y 263 | 264 | normDim :: forall ws xs ys. DimExpr ws xs ys -> (ws * xs) :~: ys 265 | normDim (ANat _) = Refl 266 | normDim (a :*:b) = normDim a #> 267 | normDim b #> 268 | prodAssocS (Proxy @ws) (dimSat a) (dimSat b) #> 269 | Refl 270 | 271 | data ShapeExpr (a :: Nat) (x :: Shape) (b :: Nat) where 272 | Single :: DimExpr a x b -> ShapeExpr a '[x] b 273 | AShape :: SShape x -> ShapeExpr a x (a * Product x) 274 | (:++:) :: ShapeExpr a x b -> ShapeExpr b y c -> ShapeExpr a (x++y) c 275 | 276 | infixr 5 :++: 277 | infixr 5 *:! 278 | infixr 5 !:* 279 | 280 | (!:*) :: DimExpr a x b -> ShapeExpr b xs c -> ShapeExpr a (x ': xs) c 281 | x !:* xs = Single x :++: xs 282 | 283 | (*:!) :: ShapeExpr a xs b -> DimExpr b x c -> ShapeExpr a (xs ++ '[x]) c 284 | xs *:! x = xs :++: Single x 285 | 286 | exprSShape :: forall a x b. ShapeExpr a x b -> SShape x 287 | exprSShape (AShape s) = s 288 | exprSShape (Single x) = dimSat x ?> typeSShape 289 | exprSShape (x :++: y) = exprSShape x .+. exprSShape y 290 | 291 | normShape :: forall ws xs ys. ShapeExpr ws xs ys -> (ws * Product xs) :~: ys 292 | normShape (Single x) = normDim x 293 | normShape (AShape _) = Refl 294 | normShape (l :++: r) = normShape l #> 295 | normShape r #> 296 | prodHomoS (exprSShape l) (exprSShape r) #> 297 | prodAssocS (Proxy @ws) (productS (exprSShape l)) (productS (exprSShape r)) #> 298 | Refl 299 | -- r :: normShape b y ys ----> (b * y) ~ ys (1) 300 | -- l :: normShape ws x b ----> (ws * x) ~ b (2) 301 | -- subst (2) in (1): ((ws * x) * y) ~ ys 302 | -- assoc: (ws * (x * y)) ~ ys 303 | 304 | decideProductEq1 :: forall xs zs. ShapeExpr 1 xs zs -> Product xs :~: zs 305 | decideProductEq1 a = case normShape a of Refl -> Refl 306 | 307 | type ShapeX = ShapeExpr 1 308 | 309 | decideProductEq :: ShapeExpr 1 xs zs -> ShapeExpr 1 ys zs -> Product xs :~: Product ys 310 | decideProductEq l r = case decideProductEq1 l of 311 | Refl -> case decideProductEq1 r of 312 | Refl -> Refl 313 | 314 | 315 | unsafePositive :: (1 <=? n) :~: 'True 316 | unsafePositive = unsafeCoerce Refl 317 | 318 | sucPred :: ((1 <=? 
n) ~ 'True) => (n - 1) + 1 :~: n 319 | sucPred = unsafeCoerce Refl 320 | 321 | -- data ORDEQ p a b where 322 | -- LT, GT :: ORDEQ a b 323 | -- EQ :: p -> ORDEQ a a 324 | 325 | data NatExpr n where 326 | NEVar :: Int -> NatExpr n 327 | (::+) :: NatExpr m -> NatExpr n -> NatExpr (m+n) 328 | (::*) :: NatExpr m -> NatExpr n -> NatExpr (m*n) 329 | 330 | data NatSum n where 331 | NSZero :: NatSum 0 332 | NSAdd :: NatProd m -> NatSum n -> NatSum (m+n) 333 | 334 | data NatProd n where 335 | NPUnit :: Sat KnownNat k -> NatProd k 336 | NPTimes :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n) 337 | 338 | sortProd :: NatProd n -> NatProd n 339 | sortProd (NPUnit k) = NPUnit k 340 | sortProd (NPTimes x xId y) = insertProd x xId (sortProd y) 341 | where insertProd :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n) 342 | insertProd x xId rest = case rest of 343 | (NPUnit k) -> NPTimes x xId rest 344 | (NPTimes y yId ys) -> if xId <= yId 345 | then NPTimes x xId rest 346 | else prodAssocS x y ys #> 347 | prodCommS x y #> 348 | prodAssocS y x ys #> 349 | NPTimes y yId (insertProd x xId ys) 350 | 351 | sortSum :: NatSum n -> NatSum n 352 | sortSum NSZero = NSZero 353 | sortSum (NSAdd x y) = insertTerm (sortProd x) (sortSum y) 354 | where insertTerm :: NatProd m -> NatSum n -> NatSum (m+n) 355 | insertTerm p rest = case rest of 356 | NSZero -> NSAdd p NSZero 357 | NSAdd q qs -> case compareTerms p q of 358 | Right p' -> plusAssocS p q qs #> NSAdd p' qs 359 | Left False -> NSAdd p rest 360 | Left True -> plusAssocS p q qs #> 361 | plusCommS p q #> 362 | plusAssocS q p qs #> 363 | NSAdd q (insertTerm p qs) 364 | 365 | 366 | compareTerms :: NatProd n -> NatProd m -> Either Bool (NatProd (n+m)) 367 | compareTerms (NPUnit Sat) (NPUnit Sat) = Right (NPUnit Sat) 368 | compareTerms (NPUnit _) (NPTimes _ _ _) = Left False 369 | compareTerms (NPTimes _ _ _) (NPUnit _) = Left True 370 | compareTerms (NPTimes x xId xs) (NPTimes y yId ys) = 371 | case testEq x y of 372 | Nothing -> Left (natValS x <= natValS y) 373 | Just Refl -> case compareTerms xs ys of 374 | Left x -> Left x 375 | Right p -> distrLS x xs ys #> Right (NPTimes x xId p) 376 | 377 | 378 | 379 | distrLS :: forall a b c px py pz. px a -> py b -> pz c -> a * (b + c) :~: (a * b + a * c) 380 | distrLS = unsafeCoerce Refl 381 | 382 | distrRS :: forall a b c px py pz. px a -> py b -> pz c -> (a + b) * c :~: ((a * c) + (b * c)) 383 | distrRS = unsafeCoerce Refl 384 | 385 | expandProd :: NatSum m -> NatSum n -> NatSum (m*n) 386 | expandProd NSZero _ = NSZero 387 | expandProd _ NSZero = NSZero 388 | expandProd (NSAdd a b) c = distrRS a b c #> expandSum (expandP' a c) (expandProd b c) 389 | 390 | expandP' :: NatProd m -> NatSum n -> NatSum (m*n) 391 | expandP' _ NSZero = NSZero 392 | expandP' p (NSAdd q a) = distrLS p q a #> NSAdd (expandPP p q) (expandP' p a) 393 | 394 | expandPP :: NatProd m -> NatProd n -> NatProd (m*n) 395 | expandPP (NPUnit k) (x) = expandKP k x 396 | expandPP (NPTimes x xId y) z = prodAssocS x y z #> NPTimes x xId (expandPP y z) 397 | 398 | expandKP :: Sat KnownNat k -> NatProd n -> NatProd (k*n) 399 | expandKP Sat (NPUnit Sat) = NPUnit Sat 400 | expandKP k@Sat (NPTimes x xId y) 401 | = prodAssocS k x y #> 402 | prodCommS k x #> 403 | prodAssocS x k y #> 404 | NPTimes x xId (expandKP k y) 405 | 406 | expandSum :: NatSum m -> NatSum n -> NatSum (m+n) 407 | expandSum NSZero x = x 408 | expandSum (NSAdd x y) z = plusAssocS x y z #> NSAdd x (expandSum y z) 409 | 410 | 411 | natRec :: forall (n :: Nat) (p :: Nat -> Type). 
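-- Induction over a type-level Nat: given a value of @p 0@ and a step from
-- @p m@ to @p (m+1)@, build a @p n@ for any statically known @n@.  The
-- recursion inspects 'natVal' at runtime and uses 'unsafePositive' and
-- 'sucPred' to convince the type checker that @n-1@ is well-defined.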
KnownNat n => p 0 -> (forall (m :: Nat). p m -> p (m+1)) -> p n 412 | natRec z s = case natVal (Proxy @n) of 413 | 0 -> unsafeCoerce z 414 | _ -> case unsafePositive @n of 415 | Refl -> case sucPred @n of 416 | Refl -> s @(n-1) (natRec @(n-1) @p z s) 417 | 418 | 419 | data CountRes n where 420 | CountRes :: Integer -> V n Integer -> CountRes n 421 | 422 | vcount :: forall n. KnownNat n => V n Integer 423 | vcount = 424 | case natRec @n (CountRes (natVal (Proxy @n)-1) VUnit) (\(CountRes m xs) -> 425 | plusCommS (Proxy @1) (F xs) 426 | #> CountRes (m-1) (m :** xs)) of 427 | CountRes _ x -> x 428 | 429 | data V n a where 430 | VUnit :: V 0 a 431 | (:**) :: a -> V n a -> V (1+n) a 432 | infixr 5 :** 433 | 434 | deriving instance (Functor (V n)) 435 | 436 | instance KnownNat n => Applicative (V n) where 437 | pure x = fmap (const x) (vcount @n) 438 | VUnit <*> VUnit = VUnit 439 | (f :** fs) <*> (a :** as) = succPosProx2 fs #> (f a :** (fs <*> unsafeCoerce as)) 440 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: 2 | ./typedflow.cabal 3 | -------------------------------------------------------------------------------- /docs/HOT.org: -------------------------------------------------------------------------------- 1 | #+TITLE: TypedFlow: The HOT parts 2 | #+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg 3 | 4 | TensorFlow™ is an open source software library for numerical 5 | computation using data flow graphs. Nodes in the graph represent 6 | mathematical operations, while the graph edges represent the 7 | multidimensional data arrays (tensors) communicated between them. 8 | TensorFlow graphs can be efficiently evaluated on GPUs and is a 9 | popular choice to implement deep learning applications. 10 | 11 | TypedFlow is a higher-order and typed (HOT) frontend to tensorflow, 12 | written in (Glasgow) Haskell. 13 | 14 | In this talk I will: 15 | - briefly explain what TensorFlow is and how it applies to deep learning 16 | - recall the advantages of a HOT approach 17 | - expose some example programs written using TypedFlow 18 | - demonstrate how tensorflow functions can be given precise types, 19 | using GHC extensions 20 | - discuss some the difficulties of doing so 21 | 22 | * Machine learning in 45 seconds: 23 | 24 | - a vector of training inputs X :: [A] 25 | - a model f : (Θ × A) → ℝ⁺ 26 | 27 | Task: 28 | 29 | Given X, find θ such that f(θ,x) < ε, if 30 | x is considered similar to points in X. 31 | 32 | Commentary: Every point in X lie on a manyfold. We want to find what 33 | this manyfold is. (Interpolation problem.) 34 | 35 | * "Deep" learning in 45 seconds 36 | 37 | - "Deep" ≡ f is "complex" 38 | - So we must use a brute force method to compute θ: stochastic 39 | gradient descent (or variants thereof). 40 | 41 | - Typically, compute the gradient of f wrt. θ using AD. 42 | 43 | * Tensorflow 44 | 45 | - A (meta) programming language to define f. AD is builtin. (there is fineprint) 46 | 47 | - Restricted control flow (mostly, tensor-generalisations of +, *, -, 48 | /, ^) 49 | 50 | - Typically programmed using python (standard in scientific computing) 51 | 52 | - "Strongly typed". 53 | - but: "brodcasting" 54 | - but: running the metaprogram can be quite slow ~1 minute (so type 55 | errors can happen after 1 minute of loading the model --- and 56 | programs can do other things before ...) 57 | - but: types are typically not written as such. 
Given any two 58 | functions, weather they compose (and what the composition does) is 59 | a mistery unless one examines their code. 60 | 61 | * First-order culture 62 | 63 | https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py 64 | (search "class LSTM") 65 | 66 | * TypedFlow 67 | 68 | - An typed, higher-order frontend to tensorflow (basic tensor operations) 69 | - A library to construct neural networks 70 | - Generates python (yikes) 71 | 72 | * Typing tensors 73 | 74 | - tanh 75 | - matmul 76 | - concatT 77 | - repeatT 78 | - tile 79 | - convolution 80 | 81 | * Heterogeneous tensors 82 | 83 | type HTV 84 | 85 | * Example Higher-order stuff 86 | 87 | - mapT 88 | - rnn 89 | - withBypass 90 | - Attention model 91 | 92 | * Complete examples 93 | 94 | - mnist 95 | - seq2seq 96 | 97 | * GHC woes 98 | 99 | - see transposeV 100 | 101 | * Summary 102 | 103 | - Some NN building blocks are naturally higher-order. Taking an 104 | example (and simplifying) a recurrent neural network turns a tensor 105 | function into a function between lists (vectorslists) of tensors. 106 | 107 | - Functional programming is ideally suited to program complicated 108 | applications from building blocks. 109 | 110 | Example: an "Attention-model" is a thing where every step in a RNN adds 111 | a computation which depends on an external input. We can compose 112 | usual RNN cells with attention models in several ways. The state of 113 | the art is to reprogram all combinations by hand. 114 | 115 | - Typed APIs. 116 | 117 | Types can be used to check the tensor dimensions. Types catch a lof 118 | of errors, but they can also be used to *guide* the programming. 119 | 120 | Types are pretty much a necessity in the presence of HO 121 | functions. 122 | 123 | - TypedFlow is typically much closer to mathematical notation than 124 | python. Programs are short to write and easier to read. Standard 125 | building blocks can be swapped for custom versions quite easily. 126 | 127 | Examples 128 | - rnn stacking using "residual connections" instead of just 129 | stacking. 130 | - make it easy to share parameters between different components 131 | (example: if we do a style translation we may want to share the 132 | embedding layers between encoder and decoders parts) 133 | 134 | - Long game: integrate cutting edge ideas as they arrive with moderate 135 | effort. 136 | 137 | * FAQ 138 | - Why not Agda, Idris? A long term plan is to bypass python, so we'd 139 | want a "real" programming language for the programming bits that go 140 | around the TF program. 141 | -------------------------------------------------------------------------------- /docs/Talk.org: -------------------------------------------------------------------------------- 1 | #+TITLE: TypedFlow: A library for higher-order typed deep learning 2 | #+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg 3 | 4 | TensorFlow is a library for numerical computation, with specific 5 | features for machine-learning such as gradient computation. It is 6 | perhaps the most popular backend for deep learning applications. 7 | TypedFlow a higher-order and typed (HOT) frontend to TensorFlow 8 | written in Haskell, and a library of neural-network layers and 9 | combinators. 10 | 11 | 12 | In this talk I will: 13 | 14 | - briefly recall what TensorFlow is and how it applies to deep 15 | learning 16 | - discuss the advantages of a HOT approach vs. 
plain TensorFlow 17 | - expose two use-cases: the standard MNIST example and a 18 | sequence-to-sequence network with attention model. 19 | 20 | 21 | Ideas: transparency, explainability 22 | 23 | 24 | 25 | * Machine learning in 45 seconds: 26 | 27 | - a vector of training inputs X::[A] 28 | - a model f : (Θ × A) → ℝ⁺ 29 | 30 | Task: 31 | 32 | Given X, find θ such that f(θ,x) < ε, if 33 | x is considered similar to points in X, and > ε otherwise. 34 | 35 | Commentary: Every point in X lie on a manifold. We want to find what 36 | this manyfold is. (Interpolation problem.) 37 | 38 | * "Deep" learning in 45 seconds 39 | 40 | - "Deep" ≡ f is "complicated" 41 | - So we must use a brute force method to compute θ: stochastic 42 | gradient descent (or variants thereof). 43 | 44 | - Typically, compute the gradient of f wrt. θ using AD. 45 | 46 | * Tensorflow 47 | 48 | - A (meta) programming language to define f. AD is builtin. (there is 49 | fineprint) 50 | 51 | - Restricted control flow (mostly, tensor-generalisations of +, *, -, 52 | /, ^, tanh, ...) 53 | 54 | - Typically programmed using python (standard in scientific computing) 55 | 56 | - "Strongly typed" 57 | - but: no abstraction over dimensions 58 | - but: "brodcasting" 59 | - but: running the metaprogram can be quite slow ~1 minute (so type 60 | errors can happen after 1 minute of loading the model --- and 61 | programs can do other things before ...) 62 | - but: types are typically not written as such. Given any two 63 | functions, weather they compose (and what the composition does) is 64 | a mistery unless one examines their code. 65 | 66 | - "map" has a surprising semantics (see below) 67 | 68 | * What is TypedFlow? 69 | 70 | - An typed, higher-order frontend to tensorflow 71 | (basic tensor operations) 72 | - A library to construct neural networks 73 | - Generates python 74 | 75 | * Why TypedFlow? 76 | 77 | Functional programming is ideally suited to program complicated 78 | applications from building blocks. 79 | 80 | - Notation 81 | - Types 82 | - HO 83 | 84 | * Deep Learning: The state of the art 85 | 86 | [[file:cards.jpg]] 87 | 88 | (Actually this has become worse!) 89 | * Notation 90 | 91 | Haskell is typically much closer to mathematical notation than 92 | python. Programs are short to write and easier to read. 93 | 94 | file:../TypedFlow/TF.hs::/⊕.*::/ 95 | 96 | * Why Types? 97 | 98 | Types can be used to check the tensor dimensions. 99 | 100 | - Types catch a lof of errors 101 | - but they can also be used to *guide* the programming. "Type holes" 102 | (see MNIST example) 103 | 104 | Types are pretty much a necessity to take advantage of HO functions. 105 | 106 | #+BEGIN_QUOTE 107 | Together with the absence of side effects, rich type systems enable to 108 | construct complex programs with a high degree of confidence: 109 | 110 | - types precisely abstract the intention of the programmer for each function, 111 | without any hidden side effect, and 112 | - provided that they match the contracts imposed by types, functions 113 | can be freely combined, using lazy evaluation and higher-order 114 | facilities, without risk of pernicious interference. 
115 | #+END_QUOTE 116 | 117 | * Python, aka The Culture of First Order 118 | 119 | [[file:imperiallegion.jpg]] 120 | 121 | https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py 122 | (search "class LSTM") 123 | 124 | * Example 1: LSTM 125 | 126 | file:../TypedFlow/Layers/RNN.hs::/^lstm.*::/ 127 | 128 | * Example 2: Attention 129 | 130 | Example: an "Attention-model" is a model where every step in a RNN 131 | adds a computation which depends on an external input. We can compose 132 | usual RNN cells with attention models in several ways. The state of 133 | the art is to reprogram such combinations by hand. 134 | 135 | file:../TypedFlow/Layers/RNN.hs::/^attentiveWithFeedback.*::/ 136 | 137 | * Mapping tensors 138 | 139 | - Tensorflow's ~map~ spawns processes. This is (usually) quite a bad 140 | idea --- tensor operations are parallelized anyway (but not on 141 | several GPUs... the purpose of ~map~ apparently). 142 | 143 | - Most (but not all!) operations have so-called "broadcast semantics"; 144 | they can be (implicitly!) raised to tensors of higher dimensions. 145 | 146 | - file:../TypedFlow/Abstract.hs::/^protoBroadcast.*::/ 147 | 148 | - Note "gather" goes to "gather_nd" 149 | - Certain convolutions can't be broadcasted at all 😿 150 | 151 | * Pretending that tensor operations are functional 152 | 153 | - They are EXCEPT that sharing is lost 154 | - Use the old trick of observable sharing. (Memoizing, etc.) 155 | 156 | * Long game 157 | 158 | - Integrate cutting edge DL ideas as they arrive with moderate effort. 159 | 160 | * MNIST 161 | 162 | file:../examples/mnist/MNIST.hs 163 | 164 | * Seq2Seq 165 | 166 | file:../examples/seq2seq/Seq2Seq.hs 167 | -------------------------------------------------------------------------------- /docs/cards.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GU-CLASP/TypedFlow/3a8fa230d413279c12c70b1eb77b3e0cbd833f4a/docs/cards.jpg -------------------------------------------------------------------------------- /docs/imperiallegion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GU-CLASP/TypedFlow/3a8fa230d413279c12c70b1eb77b3e0cbd833f4a/docs/imperiallegion.jpg -------------------------------------------------------------------------------- /examples/agreement/Aggr.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE ViewPatterns #-} 3 | {-# LANGUAGE AllowAmbiguousTypes #-} 4 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 5 | {-# LANGUAGE DataKinds #-} 6 | {-# LANGUAGE RankNTypes #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | {-# LANGUAGE TypeApplications #-} 9 | {-# LANGUAGE TypeFamilies #-} 10 | {-# LANGUAGE TypeOperators #-} 11 | {-# LANGUAGE UnicodeSyntax #-} 12 | 13 | 14 | import TypedFlow 15 | import TypedFlow.Python 16 | import qualified GHC.Int as GHC 17 | 18 | onFST :: (Tensor s1 t -> Tensor s t) -> HTV t '[s1, s'] -> HTV t '[s, s'] 19 | onFST f (VecPair h c) = (VecPair (f h) c) 20 | 21 | mkLSTM :: ∀ n x. KnownNat x => KnownNat n => 22 | String -> DropProb -> Gen (RnnCell Float32 '[ '[n], '[n]] (Tensor '[x] Float32) (Tensor '[n] Float32)) 23 | mkLSTM pName dropProb = do 24 | params <- parameterDefault pName 25 | drp1 <- mkDropout dropProb 26 | rdrp1 <- mkDropout dropProb 27 | return (timeDistribute drp1 .-. 
onStates (onFST rdrp1) (lstm params)) 28 | 29 | model :: forall (vocSize::Nat) (len::Nat). KnownNat len => KnownNat vocSize => 30 | Gen (T '[len] Int32 -> T '[len] Int32 -> ModelOutput Float32 '[len,vocSize] '[]) 31 | model = do 32 | embs <- parameterDefault "embs" 33 | let dropProb = DropProb 0.10 34 | lstm1 <- mkLSTM @160 "w1" dropProb 35 | drp <- mkDropout dropProb 36 | w <- parameterDefault "dense" 37 | return $ \input gold -> do 38 | let masks = constant 1 ⊝ cast @Float32 (equal (constant padding) input) 39 | (_sFi,predictions) = 40 | simpleRnn (timeDistribute (embedding @12 @vocSize embs) .-. 41 | lstm1 .-. 42 | timeDistribute drp .-. 43 | timeDistribute (dense w)) 44 | (repeatT zeros, input) 45 | in timedCategorical masks predictions gold 46 | 47 | padding :: GHC.Int32 48 | padding = 10 49 | 50 | main :: IO () 51 | main = do 52 | generateFile "aggr.py" (compile @512 defaultOptions (model @12 @21)) 53 | putStrLn "done!" 54 | 55 | -- >>> main 56 | -- Parameters (total 134300): 57 | -- dense_bias: T [12] tf.float32 58 | -- dense_w: T [160,12] tf.float32 59 | -- w1_o_b: T [160] tf.float32 60 | -- w1_o_w: T [172,160] tf.float32 61 | -- w1_c_b: T [160] tf.float32 62 | -- w1_c_w: T [172,160] tf.float32 63 | -- w1_i_b: T [160] tf.float32 64 | -- w1_i_w: T [172,160] tf.float32 65 | -- w1_f_b: T [160] tf.float32 66 | -- w1_f_w: T [172,160] tf.float32 67 | -- embs: T [12,12] tf.float32 68 | -- y: T [512,21] tf.int32 69 | -- x: T [512,21] tf.int32 70 | -- done! 71 | 72 | 73 | (|>) :: ∀ a b. a -> b -> (a, b) 74 | (|>) = (,) 75 | infixr |> 76 | 77 | 78 | -- Local Variables: 79 | -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") 80 | -- End: 81 | 82 | -------------------------------------------------------------------------------- /examples/mnist/MNIST.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 2 | {-# LANGUAGE ApplicativeDo #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeApplications #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# LANGUAGE UnicodeSyntax #-} 10 | {-# LANGUAGE NoStarIsType #-} 11 | module MNIST where 12 | 13 | import TypedFlow 14 | import TypedFlow.Python 15 | 16 | atShape :: forall s t. T s t -> T s t 17 | atShape x = x 18 | 19 | mnist :: Gen (Model '[784] Float32 '[10] '[10] '[] Float32) 20 | mnist = do 21 | filters1 <- parameterDefault "f1" 22 | filters2 <- parameterDefault "f2" 23 | w1 <- parameterDefault "w1" 24 | w2 <- parameterDefault "w2" 25 | return $ \input gold -> 26 | let nn = dense @10 w2 . 27 | relu . dense @1024 w1 . 28 | reshape @'[7 * 7 * 64] . 29 | maxPool2D @2 @2 . 30 | relu . conv @64 @'[5,5] filters2 . 31 | maxPool2D @2 @2 . 32 | atShape @'[28,28,32] . 33 | relu . conv @32 @'[5,5] filters1 . 34 | reshape @'[28,28,1] 35 | logits = nn input 36 | 37 | in categoricalDistribution logits gold 38 | 39 | main :: IO () 40 | main = do 41 | generateFile "mnist_model.py" (compile @100 defaultOptions mnist) 42 | putStrLn "done!" 43 | 44 | -- >>> main 45 | -- Parameters (total 3274634): 46 | -- f1_filters: T [5, 5, 1, 32] tf.float32 47 | -- f1_biases: T [32] tf.float32 48 | -- f2_filters: T [5, 5, 32, 64] tf.float32 49 | -- f2_biases: T [64] tf.float32 50 | -- w1_w: T [3136, 1024] tf.float32 51 | -- w1_bias: T [1024] tf.float32 52 | -- w2_w: T [1024, 10] tf.float32 53 | -- w2_bias: T [10] tf.float32 54 | -- done! 
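-- Note (added commentary, not part of the original file): the atShape helper
-- defined above pins the shape of an intermediate tensor purely at the type
-- level, so an annotation like `atShape @'[28,28,32]` both documents the
-- pipeline and makes a shape mistake in a neighbouring layer surface at that
-- point. Replacing a layer with a typed hole (`_`) makes GHC report the
-- tensor shape expected there -- the "types can guide the programming"
-- workflow referred to in docs/Talk.org.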
55 | 56 | 57 | -- Local Variables: 58 | -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") 59 | -- End: 60 | 61 | -------------------------------------------------------------------------------- /examples/mnist/Makefile: -------------------------------------------------------------------------------- 1 | test: mnist_model.py main.py 2 | nix-shell ../seq2seq/shell.nix --run "python main.py" 3 | 4 | mnist_model.py: MNIST.hs 5 | nix-shell ../../.styx/shell.nix --run "ghci -i../.. MNIST.hs -e main" 6 | 7 | -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') # so we can see the rts. 3 | 4 | import typedflow_rts as tyf 5 | import tensorflow as tf 6 | import numpy as np 7 | from mnist_model import mkModel,runModel 8 | import os 9 | 10 | # comment out if you don't have CUDA 11 | tyf.cuda_use_one_free_device() 12 | 13 | optimizer = tf.keras.optimizers.Adam(1e-4) 14 | 15 | # import tfds.image_classification.MNIST as mnist # need to package tfds 16 | 17 | def train_generator(batch_size): 18 | for _ in range(1000): 19 | # (x,y) = mnist.batch(100) 20 | yield {"x":np.zeros((100,784), dtype=np.float32), # FIXME 21 | "y":np.zeros((100,10), dtype=np.float32) 22 | } 23 | 24 | 25 | model = mkModel() 26 | 27 | tyf.train(optimizer,model,runModel,train_generator) 28 | -------------------------------------------------------------------------------- /examples/mnist/mnist_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | def mkModel(): 3 | #shape: [25, 32] 4 | var10000=tf.random.uniform([25, 32], 5 | minval=-0.32444283, 6 | maxval=0.32444283, 7 | dtype=tf.float32) # 0 8 | #shape: [5, 5, 1, 32] 9 | var10001=tf.reshape(var10000, [5, 5, 1, 32]) 10 | var10002=tf.Variable(name="f1_filters", trainable=True, initial_value=var10001) 11 | #shape: [] 12 | var10003=tf.constant(0.1, shape=[], dtype=tf.float32) 13 | #shape: [32] 14 | var10004=ERROR:BroadcastT(var10003) 15 | #shape: [32] 16 | var10005=tf.reshape(var10004, [32]) 17 | var10006=tf.Variable(name="f1_biases", trainable=True, initial_value=var10005) 18 | #shape: [800, 64] 19 | var10007=tf.random.uniform([800, 64], 20 | minval=-8.3333336e-2, 21 | maxval=8.3333336e-2, 22 | dtype=tf.float32) # 1 23 | #shape: [5, 5, 32, 64] 24 | var10008=tf.reshape(var10007, [5, 5, 32, 64]) 25 | var10009=tf.Variable(name="f2_filters", trainable=True, initial_value=var10008) 26 | #shape: [64] 27 | var10010=tf.reshape(var10004, [64]) 28 | var10011=tf.Variable(name="f2_biases", trainable=True, initial_value=var10010) 29 | #shape: [3136, 1024] 30 | var10012=tf.random.uniform([3136, 1024], 31 | minval=-3.7977725e-2, 32 | maxval=3.7977725e-2, 33 | dtype=tf.float32) # 2 34 | var10013=tf.Variable(name="w1_w", trainable=True, initial_value=var10012) 35 | #shape: [1024] 36 | var10014=tf.random.truncated_normal([1024], stddev=0.1, dtype=tf.float32) # 3 37 | var10015=tf.Variable(name="w1_bias", trainable=True, initial_value=var10014) 38 | #shape: [1024, 10] 39 | var10016=tf.random.uniform([1024, 10], 40 | minval=-7.61755e-2, 41 | maxval=7.61755e-2, 42 | dtype=tf.float32) # 4 43 | var10017=tf.Variable(name="w2_w", trainable=True, initial_value=var10016) 44 | #shape: [10] 45 | var10018=tf.random.truncated_normal([10], stddev=0.1, dtype=tf.float32) # 5 46 | var10019=tf.Variable(name="w2_bias", trainable=True, 
initial_value=var10018) 47 | return {"batch_size":100, 48 | "parameters":[ var10002 49 | , var10006 50 | , var10009 51 | , var10011 52 | , var10013 53 | , var10015 54 | , var10017 55 | , var10019 ], 56 | "paramsdict":{"f1_filters":var10002, 57 | "f1_biases":var10006, 58 | "f2_filters":var10009, 59 | "f2_biases":var10011, 60 | "w1_w":var10013, 61 | "w1_bias":var10015, 62 | "w2_w":var10017, 63 | "w2_bias":var10019}} 64 | @tf.function 65 | def runModel_fn(training_placeholder, 66 | f1_filters, 67 | f1_biases, 68 | f2_filters, 69 | f2_biases, 70 | w1_w, 71 | w1_bias, 72 | w2_w, 73 | w2_bias, 74 | x, 75 | y): 76 | #shape: [100, 10] 77 | var10020=y 78 | #shape: [100, 784] 79 | var10021=x 80 | #shape: [100, 28, 28, 1] 81 | var10022=tf.reshape(var10021, [100, 28, 28, 1]) 82 | #shape: [5, 5, 1, 32] 83 | var10023=f1_filters 84 | #shape: [100, 28, 28, 32] 85 | var10024=tf.nn.convolution(var10022, var10023, padding="SAME", data_format="NHWC") 86 | #shape: [100, 784, 32] 87 | var10025=tf.reshape(var10024, [100, 784, 32]) 88 | #shape: [32] 89 | var10026=f1_biases 90 | #shape: [784, 32] 91 | var10027=tf.broadcast_to(tf.reshape(var10026, [1, 32]), [784, 32]) 92 | #shape: [100, 784, 32] 93 | var10028=tf.broadcast_to(tf.reshape(var10027, [1, 784, 32]), [100, 784, 32]) 94 | #shape: [100, 784, 32] 95 | var10029=tf.add(var10025, var10028) 96 | #shape: [100, 28, 28, 32] 97 | var10030=tf.reshape(var10029, [100, 28, 28, 32]) 98 | #shape: [100, 28, 28, 32] 99 | var10031=tf.nn.relu(var10030) 100 | #shape: [100, 28, 28, 32] 101 | var10032=tf.reshape(var10031, [100, 28, 28, 32]) 102 | #shape: [100, 14, 14, 32] 103 | var10033=tf.nn.pool(var10032, [2, 2], "MAX", strides=[2, 2], padding="SAME") 104 | #shape: [100, 14, 14, 32] 105 | var10034=tf.reshape(var10033, [100, 14, 14, 32]) 106 | #shape: [5, 5, 32, 64] 107 | var10035=f2_filters 108 | #shape: [100, 14, 14, 64] 109 | var10036=tf.nn.convolution(var10034, var10035, padding="SAME", data_format="NHWC") 110 | #shape: [100, 196, 64] 111 | var10037=tf.reshape(var10036, [100, 196, 64]) 112 | #shape: [64] 113 | var10038=f2_biases 114 | #shape: [196, 64] 115 | var10039=tf.broadcast_to(tf.reshape(var10038, [1, 64]), [196, 64]) 116 | #shape: [100, 196, 64] 117 | var10040=tf.broadcast_to(tf.reshape(var10039, [1, 196, 64]), [100, 196, 64]) 118 | #shape: [100, 196, 64] 119 | var10041=tf.add(var10037, var10040) 120 | #shape: [100, 14, 14, 64] 121 | var10042=tf.reshape(var10041, [100, 14, 14, 64]) 122 | #shape: [100, 14, 14, 64] 123 | var10043=tf.nn.relu(var10042) 124 | #shape: [100, 14, 14, 64] 125 | var10044=tf.reshape(var10043, [100, 14, 14, 64]) 126 | #shape: [100, 7, 7, 64] 127 | var10045=tf.nn.pool(var10044, [2, 2], "MAX", strides=[2, 2], padding="SAME") 128 | #shape: [100, 3136] 129 | var10046=tf.reshape(var10045, [100, 3136]) 130 | #shape: [3136, 1024] 131 | var10047=w1_w 132 | #shape: [100, 1024] 133 | var10048=tf.matmul(var10046, var10047) 134 | #shape: [100, 1024] 135 | var10049=tf.reshape(var10048, [100, 1024]) 136 | #shape: [1024] 137 | var10050=w1_bias 138 | #shape: [100, 1024] 139 | var10051=tf.broadcast_to(tf.reshape(var10050, [1, 1024]), [100, 1024]) 140 | #shape: [100, 1024] 141 | var10052=tf.add(var10049, var10051) 142 | #shape: [100, 1024] 143 | var10053=tf.nn.relu(var10052) 144 | #shape: [100, 1024] 145 | var10054=tf.reshape(var10053, [100, 1024]) 146 | #shape: [1024, 10] 147 | var10055=w2_w 148 | #shape: [100, 10] 149 | var10056=tf.matmul(var10054, var10055) 150 | #shape: [100, 10] 151 | var10057=tf.reshape(var10056, [100, 10]) 152 | #shape: [10] 153 | 
var10058=w2_bias 154 | #shape: [100, 10] 155 | var10059=tf.broadcast_to(tf.reshape(var10058, [1, 10]), [100, 10]) 156 | #shape: [100, 10] 157 | var10060=tf.add(var10057, var10059) 158 | #shape: [100] 159 | var10061=tf.nn.softmax_cross_entropy_with_logits(labels=var10020, logits=var10060) 160 | #shape: [100] 161 | var10062=tf.reshape(var10061, [100]) 162 | #shape: [] 163 | var10063=tf.reduce_mean(var10062, axis=0) 164 | #shape: [] 165 | var10064=tf.constant(0.0, shape=[], dtype=tf.float32) 166 | #shape: [1] 167 | var10065=tf.broadcast_to(tf.reshape(var10064, [1]), [1]) 168 | #shape: [] 169 | var10066=tf.reshape(var10065, []) 170 | #shape: [] 171 | var10067=tf.add(var10063, var10066) 172 | #shape: [100] 173 | var10068=tf.argmax(var10060, axis=1, output_type=tf.int32) 174 | #shape: [100] 175 | var10069=tf.argmax(var10020, axis=1, output_type=tf.int32) 176 | #shape: [100] 177 | var10070=tf.equal(var10068, var10069) 178 | #shape: [100] 179 | var10071=tf.cast(var10070, tf.float32) 180 | #shape: [100] 181 | var10072=tf.reshape(var10071, [100]) 182 | #shape: [] 183 | var10073=tf.reduce_mean(var10072, axis=0) 184 | #shape: [100, 10] 185 | var10074=tf.reshape(var10060, [100, 10]) 186 | #shape: [100, 10] 187 | var10075=tf.nn.softmax(var10074, axis=1) 188 | #shape: [100, 10] 189 | var10076=tf.reshape(var10075, [100, 10]) 190 | return {"loss":var10067, "accuracy":var10073, "y_":var10076} 191 | runModel = {"function":runModel_fn, 192 | "batched":True, 193 | "placeholders":{"x":{"shape":[100, 784], "dtype":tf.float32}, 194 | "y":{"shape":[100, 10], "dtype":tf.float32}}} -------------------------------------------------------------------------------- /examples/seq2seq/GenTr.hs: -------------------------------------------------------------------------------- 1 | import Control.Applicative 2 | import Test.QuickCheck.Gen 3 | import Data.List 4 | import Data.Array 5 | data Abs a = Bin a (Abs a) (Abs a) | Leaf a deriving Show 6 | 7 | type Method a = a -> [a] -> [a] -> [a] 8 | 9 | parens :: String -> String 10 | parens xs = "(" ++ xs ++ ")" 11 | 12 | preorder :: Char -> [Char] -> [Char] -> String 13 | preorder x l r = (x : l ++ r) 14 | postorder :: Char -> [Char] -> [Char] -> String 15 | postorder x l r = (l ++ r ++ [x]) 16 | reversePO :: Char -> [Char] -> [Char] -> String 17 | reversePO x l r = (x : r ++ l) 18 | 19 | linearize _ (Leaf x) = [x] 20 | linearize m (Bin x l r) = parens (m x (lin l) (lin r)) 21 | where lin = linearize m 22 | 23 | mkMethods :: Eq a => [(a->Bool,Method a)] -> Method a 24 | mkMethods ms x = case find (\(p,_) -> p x) ms of 25 | Just (_,m) -> m x 26 | Nothing -> error "no applicable linearization method" 27 | 28 | linPO :: Abs Char -> [Char] 29 | linPO = linearize (mkMethods [(const True,preorder)]) 30 | 31 | lin1 :: Abs Char -> [Char] 32 | lin1 = linearize (mkMethods [(\x -> x < '3',reversePO),(const True,preorder)]) 33 | 34 | 35 | ex :: Abs Char 36 | ex = Bin 'a' (Bin '1' (Leaf 'b') (Leaf 'c')) (Leaf 'd') 37 | 38 | guard :: Alternative f => Bool -> f a -> f a 39 | guard True x = x 40 | guard False _ = empty 41 | 42 | 43 | arb :: Gen (Abs Char) 44 | arb = sized $ \n -> do 45 | oneof (take (max 1 n) [(Leaf <$> elements ['a'..'e']) 46 | ,resize (n-1) (Bin <$> elements ['0'..'4'] <*> arb <*> arb)]) 47 | 48 | arbOkSize :: Gen (Abs Char) 49 | arbOkSize = do 50 | x <- resize 6 arb 51 | let xx = linPO x 52 | if (length xx > 2 && length xx < 22) 53 | then return x 54 | else arbOkSize 55 | 56 | mySample :: Int -> IO [Abs Char] 57 | mySample n = generate (sequence $ replicate n arbOkSize) 58 | 59 
| showEx :: Abs Char -> String 60 | showEx x = linPO x ++ "\t" ++ lin1 x 61 | 62 | 63 | test :: IO () 64 | test = mapM_ putStrLn . map showEx =<< mySample 10 65 | 66 | 67 | main :: IO () 68 | main = writeFile "synthtrees.txt" . unlines . map showEx =<< mySample 100000 69 | -------------------------------------------------------------------------------- /examples/seq2seq/Makefile: -------------------------------------------------------------------------------- 1 | test: s2s.py synthtrees.txt main.py 2 | nix-shell --run "python main.py" 3 | 4 | s2s.py: Seq2Seq.hs 5 | nix-shell ../../.styx/shell.nix --run "ghci -i../.. Seq2Seq.hs -e main" 6 | 7 | synthtrees.txt: GenTr.hs 8 | nix-shell ../../.styx/shell.nix --run "ghc --make GenTr" 9 | ./GenTr 10 | -------------------------------------------------------------------------------- /examples/seq2seq/Seq2Seq.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} 3 | {-# LANGUAGE DataKinds #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE TypeApplications #-} 7 | {-# LANGUAGE TypeFamilies #-} 8 | {-# LANGUAGE TypeOperators #-} 9 | {-# LANGUAGE UnicodeSyntax #-} 10 | 11 | module Main where 12 | 13 | import TypedFlow 14 | import TypedFlow.Python 15 | 16 | mkLSTM :: ∀ n x w. 17 | KnownNat x => KnownNat n => KnownBits w 18 | => String 19 | -> Gen (RnnCell w '[ '[n], '[n]] (Tensor '[x] (Flt w)) 20 | (Tensor '[n] (Flt w))) 21 | mkLSTM pName = do 22 | params <- parameterDefault pName 23 | drp1 <- mkDropout (DropProb 0.05) 24 | rdrp1 <- mkDropouts (DropProb 0.05) 25 | return (timeDistribute drp1 .-. onStates rdrp1 (lstm params)) 26 | 27 | encoder :: forall (lstmSize :: Nat) (vocSize :: Nat) (n :: Nat) w. 28 | KnownNat lstmSize => KnownNat vocSize 29 | => (KnownNat n) => KnownBits w 30 | => String 31 | -> Gen 32 | ( 33 | T '[] Int32 -- length 34 | -> Tensor '[n] Int32 -> 35 | ((HTV (Flt w) '[ '[lstmSize], '[lstmSize] ], Tensor '[n, lstmSize] (Flt w)))) 36 | encoder prefix = do 37 | embs <- parameterDefault (prefix++"embs") 38 | lstm1 <- mkLSTM (prefix++"lstm1") 39 | return $ \len input -> 40 | runRnn 41 | (iterateWithCull len (timeDistribute (embedding @vocSize @vocSize embs) .-. lstm1)) 42 | (repeatT zeros, input) 43 | 44 | decoder :: forall (lstmSize :: Nat) (n :: Nat) (outVocabSize :: Nat) (d::Nat) w. 45 | KnownNat lstmSize => KnownNat d => (KnownNat outVocabSize, KnownNat n) => KnownBits w => 46 | String 47 | -> Gen ( 48 | T '[] Int32 -- ^ length 49 | -> T '[n, d] (Flt w) -- todo: consider a larger size for the output string 50 | -> HTV (Flt w) '[ '[lstmSize], '[lstmSize] ] 51 | -> Tensor '[n] Int32 52 | -> Tensor '[n, outVocabSize] (Flt w)) 53 | decoder prefix = do 54 | -- note: for an intra-language translation the embeddings can be shared easily. 55 | projs <- parameterDefault (prefix++"proj") 56 | lstm1 <- mkLSTM (prefix++"lstm1") 57 | embs <- parameterDefault "embs" 58 | w1 <- parameter' (prefix++"att1") =<< glorotUniform 59 | return $ \ lens hs thoughtVectors targetInput -> 60 | let attn = uniformAttn (multiplicativeScoring w1) lens hs -- NOTE: attention on the left-part of the input. 61 | (_sFinal,outFinal) = simpleRnn 62 | ((timeDistribute (embedding @outVocabSize @outVocabSize embs) 63 | .-. 64 | attentiveWithFeedback attn lstm1 65 | .-. 
66 | timeDistribute (dense projs))) 67 | ((F zeros :* thoughtVectors), targetInput) 68 | in outFinal 69 | 70 | 71 | seq2seq :: forall (vocSize :: Nat) (n :: Nat). 72 | KnownNat vocSize => (KnownNat n) 73 | => Gen (Placeholders 74 | '[ '("tgt_weights", '[n], Float32), 75 | '("src_in", '[n], Int32), 76 | '("src_len", '[], Int32), 77 | '("tgt_in", '[n], Int32), 78 | '("tgt_out", '[n], Int32)] -> 79 | ModelOutput Float32 '[n, vocSize] '[]) 80 | seq2seq = do 81 | enc <- encoder @256 @vocSize "enc" 82 | dec <- decoder "dec" 83 | return $ \(PHT masks :* PHT input :* PHT inputLen :* PHT tgtIn :* PHT tgtOut :* Unit) -> 84 | let (VecPair t1 t2,h) = enc inputLen input 85 | y_ = dec inputLen h (VecPair t1 t2) tgtIn 86 | in timedCategorical masks y_ tgtOut 87 | 88 | 89 | 90 | 91 | main :: IO () 92 | main = generateFile "s2s.py" (compileGen @256 93 | defaultOptions {maxGradientNorm = Just 5} 94 | (stateless <$> seq2seq @15 @22)) 95 | 96 | -- >>> main 97 | -- Parameters (total 889041): 98 | -- decatt1: T [256,256] tf.float32 99 | -- embs: T [15,15] tf.float32 100 | -- declstm1_o_bias: T [256] tf.float32 101 | -- declstm1_o_w: T [527,256] tf.float32 102 | -- declstm1_c_bias: T [256] tf.float32 103 | -- declstm1_c_w: T [527,256] tf.float32 104 | -- declstm1_i_bias: T [256] tf.float32 105 | -- declstm1_i_w: T [527,256] tf.float32 106 | -- declstm1_f_bias: T [256] tf.float32 107 | -- declstm1_f_w: T [527,256] tf.float32 108 | -- decproj_bias: T [15] tf.float32 109 | -- decproj_w: T [256,15] tf.float32 110 | -- enclstm1_o_bias: T [256] tf.float32 111 | -- enclstm1_o_w: T [271,256] tf.float32 112 | -- enclstm1_c_bias: T [256] tf.float32 113 | -- enclstm1_c_w: T [271,256] tf.float32 114 | -- enclstm1_i_bias: T [256] tf.float32 115 | -- enclstm1_i_w: T [271,256] tf.float32 116 | -- enclstm1_f_bias: T [256] tf.float32 117 | -- enclstm1_f_w: T [271,256] tf.float32 118 | -- encembs: T [15,15] tf.float32 119 | 120 | -- Local Variables: 121 | -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") 122 | -- End: 123 | 124 | -------------------------------------------------------------------------------- /examples/seq2seq/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') # so we can see the rts. 
3 | 4 | import typedflow_rts as tyf 5 | import tensorflow as tf 6 | import numpy as np 7 | from s2s import mkModel 8 | import os 9 | import math 10 | import random 11 | 12 | 13 | # comment out if you don't have CUDA 14 | tyf.cuda_use_one_free_device() 15 | 16 | chars = sorted(list("()01234abcde^$ ")) 17 | 18 | print('total chars:', len(chars)) 19 | char_indices = dict((c, i) for i, c in enumerate(chars)) 20 | indices_char = dict((i, c) for i, c in enumerate(chars)) 21 | 22 | MAXLEN = 22 23 | def pad(ws): return (ws + ' '*(MAXLEN - len(ws))) 24 | 25 | def encode(s): 26 | # print ("proun", s) 27 | return np.array([char_indices[c] for c in s]) 28 | 29 | def decode(s): return "".join([indices_char[c] for c in list(s)]) 30 | 31 | def pad_right(sentence): return (MAXLEN - len(sentence)) * " " + sentence 32 | def pad_left(sentence): return sentence + (MAXLEN - len(sentence)) * " " 33 | 34 | def source_input_conversion(s): 35 | return encode(pad_left(s)) 36 | 37 | def target_input_conversion(sentence): 38 | return encode(pad_left("^"+sentence)) 39 | 40 | def target_output_conversion(sentence): 41 | return encode(pad_left(sentence+"$")) 42 | 43 | def sentence_target_weights(sentence): 44 | l = len(sentence) 45 | w = (l + 1) * [1] + (MAXLEN - (l + 1)) * [0] 46 | return np.array(w) 47 | 48 | def map(f,l): 49 | return [f(x) for x in l] 50 | 51 | def make_examples(l): 52 | (l1,l2) = zip(*l) 53 | return {"src_in":map(source_input_conversion,l1), 54 | "src_len":map(len,l1), 55 | "tgt_in":map(target_input_conversion,l2), 56 | "tgt_out":map(target_output_conversion,l2), 57 | "tgt_weights":map(sentence_target_weights,l2)} 58 | 59 | def s2s_generator(src_len,src_in,tgt_in,tgt_out,tgt_weights): 60 | def gen(bs): 61 | for i in range(0, bs*(len(src_in)//bs), bs): 62 | # print ({"src_len":src_len[i:i+bs], "src_in":src_in[i:i+bs], "tgt_in":tgt_in[i:i+bs], "tgt_out":tgt_out[i:i+bs], "tgt_weights":tgt_weights[i:i+bs]}) 63 | yield {"src_len":src_len[i:i+bs], 64 | "src_in":src_in[i:i+bs], 65 | "tgt_in":tgt_in[i:i+bs], 66 | "tgt_out":tgt_out[i:i+bs], 67 | "tgt_weights":tgt_weights[i:i+bs]} 68 | return gen 69 | 70 | 71 | def my_sample(l,n): 72 | return list(random.sample(l,min(n,len(l)))) 73 | 74 | print("Reading sentences...") 75 | all_sentences = [l.strip().split("\t") for l in open("synthtrees.txt").readlines()] 76 | 77 | val_set = make_examples(all_sentences[:2000]) 78 | train_set = make_examples(all_sentences[2000:]) 79 | 80 | print("Loading model") 81 | model = mkModel(tf.train.AdamOptimizer()) 82 | sess = tf.Session() 83 | saver = tf.train.Saver() 84 | 85 | def printer(x): 86 | (p,y,h) = x 87 | print("Prob", p, decode(y),h) 88 | 89 | 90 | def translate(s): 91 | r = tyf.beam_translate(sess,model,14, 92 | source_input_conversion(s), 93 | len(s), 94 | char_indices["^"], char_indices["$"], 95 | printer) 96 | for x in r: printer(x) 97 | 98 | def translate_cb(values): 99 | if values["epoch"] % 10 == 0: 100 | save_path = saver.save(sess, "model.ckpt") 101 | translate("(1(3cb)b)") 102 | print ("Desired:", "(1b(3cb))") 103 | return False 104 | 105 | tyf.initialize_params(sess,model) 106 | train_stats = tyf.train(sess, 107 | model, 108 | s2s_generator(**train_set), 109 | valid_generator = s2s_generator(**val_set), 110 | epochs=5000, 111 | callbacks=[tyf.StopWhenAccurate(.01), translate_cb]) 112 | 113 | translate("(1(3cb)b)") 114 | translate("(1(2c(3e(4(1cb)b)))c)") 115 | -------------------------------------------------------------------------------- /examples/seq2seq/shell.nix: 
-------------------------------------------------------------------------------- 1 | { bootstrap ? import {} }: 2 | let nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/nixos-20.03.tar.gz; 3 | # nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/4cf0b6ba5d5ab5eb20a88449e0612f4dad8e4c29.tar.gz; 4 | # nixpkgs_source = bootstrap.fetchFromGitHub { # for safety of checking the hash 5 | # owner = "jyp"; 6 | # repo = "nixpkgs"; 7 | # rev = "6b911c2d99ad116fca338fc26de86b8859079322"; 8 | # sha256 = "1bhwjkynya653mvpc4wwqks6kxnc06gyw6sbpwp8dbyr444ms4bd"; 9 | # }; 10 | # nixpkgs_source = ~/repo/nixpkgs; 11 | 12 | in with (import nixpkgs_source {}).pkgs; 13 | let py = (pkgs.python37.withPackages (ps: [ps.tensorflow-bin_2 ps.nltk])); 14 | 15 | in pkgs.stdenv.mkDerivation { 16 | name = "my-env-0"; 17 | buildInputs = [ py ]; 18 | } 19 | 20 | -------------------------------------------------------------------------------- /styx.yaml: -------------------------------------------------------------------------------- 1 | local-packages: 2 | typedflow: 3 | location: . 4 | 5 | nix-deps: 6 | - QuickCheck 7 | - hscolour 8 | 9 | # non-haskell-deps: 10 | # - glibcLocales 11 | 12 | nixpkgs: 13 | # commit: 80812af9e46167e3104038f2af6de251f90823a8 14 | # sha256: 0b718zkn5lhy71pyp0klbz7w872zck0ljqfk17f0b56k3rlvp1sy 15 | url: https://github.com/NixOS/nixpkgs/archive/nixos-21.05.tar.gz 16 | -------------------------------------------------------------------------------- /typedflow.cabal: -------------------------------------------------------------------------------- 1 | name: typedflow 2 | version: 0.9 3 | category: Deep Learning 4 | synopsis: Typed frontend to TensorFlow and higher-order deep learning 5 | description: TypedFlow is a typed, higher-order frontend to TensorFlow and a high-level library for deep-learning. 6 | . 7 | The main design principles are: 8 | . 9 | - To make the parameters of layers explicit. This choice makes sharing of parameters explicit and allows to implement "layers" as pure functions. 10 | . 11 | - To provide as precise as possible types. Functions are explicit about the shapes and elements of the tensors that they manipulate (they are often polymorphic in shapes and elements though.) 12 | . 13 | - To let combinators be as transparent as possible. If a NN layers is a simple tensor transformation it will be exposed as such. 
14 | license: LGPL-3 15 | license-file: LICENSE 16 | author: Jean-Philippe Bernardy 17 | maintainer: jean-philippe.bernardy@gu.se 18 | Cabal-Version: >= 1.12 19 | build-type: Simple 20 | source-repository head 21 | type: git 22 | location: git@github.com:GU-CLASP/TypedFlow.git 23 | 24 | library 25 | default-language: Haskell2010 26 | build-depends: 27 | base==4.*, 28 | ghc-typelits-knownnat, 29 | prettyprinter, 30 | mtl, 31 | containers 32 | -- ,tensorflow-opgen, tensorflow, tensorflow-core-ops, tensorflow-ops 33 | 34 | exposed-modules: 35 | TypedFlow, 36 | TypedFlow.Layers, 37 | TypedFlow.Layers.Core, 38 | TypedFlow.Layers.RNN, 39 | TypedFlow.Layers.RNN.Base, 40 | TypedFlow.Layers.RNN.Cells, 41 | TypedFlow.Layers.RNN.Attention, 42 | TypedFlow.Learn, 43 | TypedFlow.Models.Topic, 44 | TypedFlow.Models.Transformer, 45 | TypedFlow.Python, 46 | TypedFlow.TF, 47 | TypedFlow.Types, 48 | TypedFlow.Types.Proofs 49 | 50 | other-modules: 51 | TypedFlow.Memo 52 | TypedFlow.Memo2 53 | TypedFlow.Abstract 54 | TypedFlow.Broadcast 55 | -------------------------------------------------------------------------------- /typedflow_rts.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys 4 | from time import time 5 | import os 6 | import random 7 | 8 | ############################################################### 9 | # Devices 10 | ############################################################### 11 | 12 | 13 | def cuda_use_device(n): 14 | """Attempt to use a given CUDA device by setting the appropriate environment variables""" 15 | os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID" 16 | if os.environ.get("CUDA_VISIBLE_DEVICES") is None: 17 | os.environ["CUDA_VISIBLE_DEVICES"] = str(n) 18 | 19 | def find_free_cuda_device(): 20 | currentGPU = -1 21 | gpuMemory=dict() 22 | gpuUtil=dict() 23 | for line in os.popen("nvidia-smi -q"): 24 | fields = list(map(lambda x: x.strip(), line.split(":"))) 25 | k = fields[0] 26 | if k == "Minor Number": 27 | currentGPU += 1 28 | gpuMemory[currentGPU] = 0 29 | elif k == "Used GPU Memory": 30 | gpuMemory[currentGPU] = int(fields[1][:-4]) # last characters are " MiB" 31 | elif k == "Gpu": 32 | gpuUtil[currentGPU] = fields[1] # last characters are " %" 33 | minUse = min(gpuMemory.values()) 34 | freeGpus = [g for g in gpuMemory.keys() if gpuMemory[g] == minUse] 35 | if freeGpus == []: 36 | print("No free GPU could be found.") 37 | assert False 38 | else: 39 | result = random.choice(freeGpus) 40 | print ("Found device",result,"currently used at",gpuUtil[result],"and with",gpuMemory[result],"MB taken.") 41 | return result 42 | 43 | def cuda_use_one_free_device(): 44 | """Attempt to use a free CUDA device by setting the appropriate environment variables""" 45 | cuda_use_device(find_free_cuda_device()) 46 | 47 | ############################################################### 48 | # Generators 49 | ############################################################### 50 | 51 | def bilist_generator(l): 52 | """ 53 | Given a pair of x and y (each being a list or a np array) and a 54 | batch size, return a generator function which will yield the input 55 | in bs-sized chunks. Attention: if the size of the input is not 56 | divisible by bs, then the remainer will not be fed. Consider 57 | shuffling the input. 
58 | """ 59 | (l0,l1) = l 60 | def gen(bs): 61 | if len(l0) == 0: 62 | return 63 | for i in range(0, bs*(len(l0)//bs), bs): 64 | yield {"x":l0[i:i+bs],"y":l1[i:i+bs]} 65 | return gen 66 | 67 | 68 | def bilist_generator_transposed(model,l): 69 | ''' 70 | Given a pair of l=(x,y) (both x,y being a list or a np array) and a 71 | batch size, return a generator function which will yield the input 72 | in bs*maxlen-sized chunks. This generator is intended to be used for 73 | stateful language models. That is, batch sequencing corresponds to 74 | ''' 75 | (batch_size,maxlen) = model["x"].shape 76 | (xs,ys) = l 77 | num_items = len(xs) // (batch_size*maxlen) 78 | x = np.zeros(shape=(num_items,batch_size,maxlen)) 79 | y = np.zeros(shape=(num_items,batch_size,maxlen)) 80 | for i in range(num_items): 81 | for j in range(batch_size): 82 | for k in range(maxlen): 83 | x[i][j][k] = xs[k+j*(num_items*maxlen)+i*maxlen] 84 | y[i][j][k] = ys[k+j*(num_items*maxlen)+i*maxlen] 85 | def gen(_bs): 86 | nonlocal num_items, x, y 87 | for i in range(num_items): 88 | yield {"x":x[i],"y":y[i]} 89 | return gen 90 | 91 | def dict_generator (xs): 92 | k0 = next (iter (xs.keys())) # at least one key is needed 93 | total_len = len(xs[k0]) 94 | 95 | def gen(bs): 96 | for i in range(0, bs*(total_len//bs), bs): 97 | yield dict((k,xs[k][i:i+bs]) for k in xs) 98 | 99 | return gen 100 | 101 | 102 | def initialize_params (session,model): 103 | '''Initialize the learnable parameters of the model''' 104 | # it'd be nice to do: 105 | 106 | # session.run(tf.variables_initializer(model["params"])) 107 | 108 | # However this does not initialize the optimizer's variables. So, 109 | # instead we do: 110 | 111 | session.run(tf.local_variables_initializer()) 112 | session.run(tf.global_variables_initializer()) 113 | 114 | def train (optimizer, model_static, model_fn, 115 | train_generator=bilist_generator(([],[])), 116 | valid_generator=bilist_generator(([],[])), 117 | epochs=100, 118 | callbacks=[], 119 | extraVectors=[]): 120 | ''' 121 | Train the given model. 122 | 123 | train_generator: training data 124 | 125 | valid_generator: validation data 126 | 127 | epochs: number of epochs 128 | 129 | callbacks: list of callbacks. 130 | Each callback receives an epoch entry (see below). If it returns False then the training is aborted. 131 | 132 | extraVectors: list of extra vectors to pass to session.run when training. 133 | 134 | modelPrefix: in case of a multitask/multimodel, give the prefix of the model to use. 135 | 136 | This function returns a list of epoch entries. Each entry is a dictionary with: 137 | - "epoch": current epoch 138 | - "val" and "train": dictionaries with 139 | - "loss", "accuracy", "error_rate", time", "start_time", "end_time" 140 | ''' 141 | batch_size = model_static["batch_size"] 142 | train_vars = model_static["parameters"] 143 | placeholders_info = model_fn["placeholders"] 144 | stats = [] 145 | def halfEpoch(isTraining): 146 | totalAccur = 0 147 | totalLoss = 0 148 | n = 0 149 | print ("Training" if isTraining else "Validation", end="") 150 | start_time = time() 151 | for inputs in train_generator(batch_size) if isTraining else valid_generator(batch_size): 152 | cast_inputs = dict((k,tf.cast(inputs[k], placeholders_info[k]["dtype"])) for k in placeholders_info) 153 | # the above forces inputs to be tensors. 
(It's convenient to pass just lists here) 154 | print(".",end="") 155 | sys.stdout.flush() 156 | with tf.GradientTape() as tape: 157 | results = model_fn["function"](tf.constant(isTraining, shape=[]), **{**(model_static["paramsdict"]), **cast_inputs}) 158 | loss = results["loss"] 159 | accur = results["accuracy"] 160 | if isTraining: 161 | grads = tape.gradient(loss, train_vars) 162 | optimizer.apply_gradients(zip(grads, train_vars)) 163 | n+=1 164 | totalLoss += loss 165 | totalAccur += accur 166 | end_time = time() 167 | totalAccur = totalAccur.numpy() 168 | totalLoss = totalLoss.numpy() 169 | if n > 0: 170 | avgLoss = totalLoss / float(n) 171 | avgAccur = totalAccur / float(n) 172 | print(".") 173 | print ("Time=%.1f" % (end_time - start_time), "loss=%g" % avgLoss, "accuracy=%.3f" % avgAccur) 174 | return {"loss":avgLoss,"accuracy":avgAccur,"time":(end_time - start_time),"error_rate":1-avgAccur,"start_time":start_time} 175 | else: 176 | print ("No data") 177 | return {"loss":0,"accur":0,"time":0,"error_rate":0,"start_time":0} 178 | 179 | for e in range(epochs): 180 | print ("Epoch {0}/{1}".format(e, epochs)) 181 | tr = halfEpoch(True) 182 | va = halfEpoch(False) 183 | epoch_stats = {"train":tr, "val":va, "epoch":e} 184 | stats.append(epoch_stats) 185 | if any(c(epoch_stats) for c in callbacks): 186 | break 187 | return stats 188 | 189 | def StopWhenValidationGetsWorse(patience = 1): 190 | '''Return a callback which stops training if validation loss gets worse.''' 191 | bestLoss = 10000000000 192 | p = patience 193 | def callback(values): 194 | nonlocal bestLoss, p, patience 195 | newLoss = values["val"]["loss"] 196 | if newLoss > bestLoss: 197 | p -= 1 198 | else: 199 | bestLoss = newLoss 200 | p = patience 201 | if p <= 0: 202 | return True 203 | return False 204 | return callback 205 | 206 | def StopWhenAccurate(phase="val",error_rate = .01): 207 | '''Return a callback which stops training if error rate drops below 1%''' 208 | def callback(values): 209 | nonlocal error_rate 210 | return values[phase]["error_rate"] < error_rate 211 | return callback 212 | 213 | def Every(n,f): 214 | '''Return a callback which calls its argument every n epochs''' 215 | def callback(values): 216 | nonlocal n,f 217 | if values["epoch"] % n == (n-1): 218 | return f(values) 219 | else: 220 | return False 221 | return callback 222 | 223 | def Save(sess,saver,ckptfile): 224 | def callback(values): 225 | nonlocal sess,saver 226 | print("Saving to",ckptfile) 227 | saver.save(sess, ckptfile) 228 | return False 229 | return callback 230 | 231 | ################################################################################################ 232 | # Prediction and evaluation 233 | 234 | 235 | def evaluate (model_static, model_fn, xs, result="y_"): 236 | '''Evaluate the model for given input and result. 
237 | Input is given as a dictionary of lists to pass to session.run''' 238 | phs = model_fn["placeholders"] 239 | if phs: 240 | k0 = next (iter (phs.keys())) # 1st placeholder 241 | total_len = len(xs[k0]) # total length 242 | else: 243 | total_len = 1 244 | zeros = dict((k,tf.zeros(phs[k]["shape"][1:], # remove the batch size 245 | dtype=phs[k]["dtype"])) for k in phs.keys()) 246 | results = [] 247 | if model_fn["batched"]: 248 | def run(): 249 | bs = model_static["batch_size"] 250 | for i in range(0, bs*(-(-total_len//bs)), bs): 251 | print(".",end="") 252 | chunks = dict() 253 | for k in phs: 254 | chunks[k] = xs[k][i:i+bs] 255 | if i + bs > total_len: 256 | # dealing with an incomplete last chunk 257 | origLen = total_len - i 258 | for k in chunks: 259 | chunks[k] = list(chunks[k]) + [zeros[k]] * (bs - origLen) # pad the last chunk 260 | else: 261 | origLen = bs 262 | chunks = {k: tf.cast(v,dtype=phs[k]["dtype"]) for (k,v) in chunks.items()} 263 | results = model_fn["function"](tf.constant(False, shape=[]), **{**(model_static["paramsdict"]), **chunks}) 264 | yield results[result][:origLen] 265 | return np.concatenate(list(run())) 266 | else: 267 | def run(): 268 | for i in range(total_len): 269 | inputs = {k: tf.cast(xs[k][i], dtype=phs[k]["dtype"]) for k in phs} 270 | results = model_fn["function"](tf.constant(False, shape=[]), **{**(model_static["paramsdict"]), **inputs}) 271 | yield results[result] 272 | return list(run()) 273 | 274 | 275 | predict = evaluate 276 | 277 | def beam_translate(session, model, k, x, xlen, start_symbol, stop_symbol, debug=None): 278 | '''Beam translation of ONE input sentence.''' 279 | (_,out_len,voc_size) = model["y_"].shape 280 | xs = np.array ([x] * k) # The input is always the same 281 | xs_len = np.array ([xlen]*k) # it is VERY important to get the length right 282 | ys = [[start_symbol]] # start with a single thing; otherwise the minimum will be repeated k times 283 | probs = [1] 284 | results = [] 285 | hist = [[]] 286 | def pad(z): 287 | return np.array(z + [0] * (out_len - len(z))) 288 | for i in range(out_len-1): 289 | print ("beam search at:", i) 290 | inputs = {"src_len":xs_len[:len(ys)], "src_in":xs[:len(ys)], "tgt_in":np.array([pad(y) for y in ys])} 291 | y_s = predict(session,model,inputs) 292 | all_words = sorted([(y_s[j][i][w] * probs[j], ys[j] + [w], hist[j] + [y_s[j][i][w]]) 293 | for j in range(len(y_s)) 294 | for w in range(voc_size)]) 295 | best = all_words[-k:] 296 | if debug is not None: 297 | for x in best: debug(x) 298 | results += [(p,y,h) for (p,y,h) in best if y[i+1] == stop_symbol] 299 | continued = [(p,y,h) for (p,y,h) in best if y[i+1] != stop_symbol] 300 | if len(continued) == 0: break 301 | (probs,ys,hist) = zip(*continued) 302 | return sorted(results) 303 | 304 | ###################################################### 305 | # Saving and loading 306 | 307 | def save(model_static, file): 308 | numpy_tensors = {k:v.numpy() for (k,v) in model_static["paramsdict"].items()} 309 | print("Saving parameters: ", model_static["paramsdict"].keys()) 310 | np.savez(file,**numpy_tensors) 311 | print("Done") 312 | 313 | 314 | def load(model_static, file): 315 | print("Loading parameters") 316 | numpy_tensors = np.load(file) 317 | print("Loaded parameters: ", list(numpy_tensors.keys())) 318 | for k,v in model_static["paramsdict"].items(): 319 | v.assign(numpy_tensors[k]) 320 | print("Done") 321 | --------------------------------------------------------------------------------
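A usage sketch (added for illustration; not a file from this repository): how the typedflow_rts helpers above fit together for inference, assuming the generated examples/mnist artifacts (mnist_model.py) are importable and that the parameters were trained with tyf.train as in examples/mnist/main.py. The file name "mnist_params.npz" is only an example.

import numpy as np
import typedflow_rts as tyf
from mnist_model import mkModel, runModel

model = mkModel()  # static part: tf.Variables, batch size, parameter dict
# ... parameters trained with tyf.train(optimizer, model, runModel, gen) as in examples/mnist/main.py ...

tyf.save(model, "mnist_params.npz")   # writes the parameter dict with np.savez
tyf.load(model, "mnist_params.npz")   # assigns the saved arrays back to the variables

# Forward pass: evaluate chunks the inputs by batch size and returns the "y_"
# output (the softmax over digits). The generated model also declares a "y"
# placeholder, so dummy labels must be supplied even for prediction.
batch = {"x": np.zeros((100, 784), dtype=np.float32),
         "y": np.zeros((100, 10), dtype=np.float32)}
probs = tyf.evaluate(model, runModel, batch, result="y_")
print(probs.shape)  # (100, 10)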