├── .gitignore
├── src
│   ├── .gitignore
│   ├── dune-project
│   ├── .merlin
│   ├── dune
│   ├── test.py
│   ├── tensor.mli
│   └── tensor.ml
├── tests
│   └── testmult.ml
├── .github
│   └── ISSUE_TEMPLATE
│       ├── feature_request.md
│       └── bug_report.md
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | *.cmi
2 |
--------------------------------------------------------------------------------
/src/.gitignore:
--------------------------------------------------------------------------------
1 | _build/*
2 | _build
3 |
--------------------------------------------------------------------------------
/src/dune-project:
--------------------------------------------------------------------------------
1 | (lang dune 1.9)
2 |
--------------------------------------------------------------------------------
/src/.merlin:
--------------------------------------------------------------------------------
1 | EXCLUDE_QUERY_DIR
2 | B _build/default/.zeta.objs/byte
3 | S .
4 | FLG -open Zeta -w @a-4-29-40-41-42-44-45-48-58-59-60-40 -strict-sequence -strict-formats -short-paths -keep-locs
5 |
--------------------------------------------------------------------------------
/src/dune:
--------------------------------------------------------------------------------
1 | (library
2 |  (name zeta))
3 |
4 | ;; The library will be composed of all the modules in the same directory.
5 | ;; Outside of the library, module Tensor will be accessible as Zeta.Tensor,
6 | ;; unless you write an explicit zeta.ml file.
7 |
8 | ;; This library can be made available as an opam package if you replace
9 | ;; (name zeta) by (public_name zeta) and write a zeta.opam file.
10 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | # Scratch test against PyTorch: which tensors retain .grad after backward()
2 | import torch as t
3 | a = t.FloatTensor([1])
4 | a.requires_grad = True
5 | a.retain_grad()
6 | b = a * 2
7 | b.retain_grad()
8 | c = a * 3
9 | c.retain_grad()
10 | d = a * 4
11 | e = b * (c + d)
12 | #print(a.is_leaf)
13 | #print(e.is_leaf)
14 | print(e.grad_fn)
15 | e.backward(retain_graph=True)
16 | print(a.grad)
17 | print(b.grad)
18 | print(c.grad)
19 | print(d.grad)  # None: d was never told to retain its gradient
20 | b.retain_grad()
21 |
--------------------------------------------------------------------------------
/tests/testmult.ml:
--------------------------------------------------------------------------------
1 | open Zeta.Tensor
2 |
3 | let a = new_float [| 2;1;3 |] 1.0
4 | let () = set a [| 0;0;0 |] 1.0
5 | let () = set a [| 0;0;1 |] 2.0
6 | let () = set a [| 0;0;2 |] 3.0
7 | let () = set a [| 1;0;0 |] 4.0
8 | let () = set a [| 1;0;1 |] 5.0
9 | let () = set a [| 1;0;2 |] 6.0
10 | let b = new_float [| 1;4;1 |] 2.0
11 | let () = set b [| 0;0;0 |] 7.0
12 | let () = set b [| 0;1;0 |] 8.0
13 | let () = set b [| 0;2;0 |] 9.0
14 | let () = set b [| 0;3;0 |] 10.0
15 |
16 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 |  - OS: [e.g. iOS]
28 |  - Browser [e.g. chrome, safari]
29 |  - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 |  - Device: [e.g. iPhone6]
33 |  - OS: [e.g. iOS8.1]
34 |  - Browser [e.g. stock browser, safari]
35 |  - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/src/tensor.mli:
--------------------------------------------------------------------------------
1 | type op =
2 |   | IntOp : (int -> int) -> op
3 |   | BoolOp : (bool -> bool) -> op
4 |   | FloatOp : (float -> float) -> op
5 | type predicate =
6 |   | IntP : (int -> bool) -> predicate
7 |   | BoolP : (bool -> bool) -> predicate
8 |   | FloatP : (float -> bool) -> predicate
9 | type 'a tensordata =
10 |   | IntScalar : int ref -> int tensordata
11 |   | FloatScalar : float ref -> float tensordata
12 |   | BoolScalar : bool ref -> bool tensordata
13 |   | IntTensor : int tensordata array -> int tensordata
14 |   | FloatTensor : float tensordata array -> float tensordata
15 |   | BoolTensor : bool tensordata array -> bool tensordata
16 | type shape = int array
17 | type index_selection = Range of (int * int) | Index of int | All
18 | and slice = index_selection array
19 | type index = int array
20 | type 'a grad_fn = End | Fn of ('a tensor * op) array
21 | and 'a gradient = Retain of bool | Grad of (bool * 'a tensordata)
22 | and 'a parent = 'a tensor array
23 | and 'a node = LeafNoGrad | LeafGrad of 'a gradient | Node of ('a parent * 'a gradient)
24 | and 'a directed_acyclic_graph = Null | Graph of ('a grad_fn * 'a node)
25 | and 'a tensor = (shape * 'a tensordata * 'a directed_acyclic_graph) ref
26 | exception TypeMismatch of string
27 | exception TensorInvariantViolated
28 | exception ShapeMismatch of string
29 | exception IndexError of string
30 | exception ZeroDimension
31 | exception AutogradError of string
32 | val is_leaf : 'a tensor -> bool
33 | val requires_grad : 'a tensor -> bool -> unit
34 | val retain_grad : 'a tensor -> bool -> unit
35 | val detach : 'a tensor -> 'a tensor
36 | val copy : ('a * 'b tensordata * 'c) ref -> ('a * 'b tensordata * 'c) ref
37 | val slice : slice -> 'a tensor -> 'a tensor
38 | val new_bool : shape -> bool -> bool tensor
39 | val new_int : shape -> int -> int tensor
40 | val new_float : shape -> float -> float tensor
41 | val reduce : predicate -> (bool * bool -> bool) -> bool -> 'a tensor -> bool
42 | val all : predicate -> 'a tensor -> bool
43 | val any : predicate -> 'a tensor -> bool
44 | val elem_apply : op -> 'a tensor -> (shape * 'a tensordata * 'a directed_acyclic_graph) ref
45 | val sigmoid : 'a tensor -> (shape * 'a tensordata * 'a directed_acyclic_graph) ref
46 | val abs : 'a tensor -> (shape * 'a tensordata * 'a directed_acyclic_graph) ref
47 | val new_t : shape -> 'a tensor -> bool -> (shape * 'a tensordata * 'a directed_acyclic_graph) ref
48 | val set : 'a tensor -> int array -> 'a -> unit
49 | val get : 'a tensor -> int array -> 'a
50 | val broadcast : (int array * 'a tensordata * 'b) ref -> int array -> bool -> (int array * 'a tensordata * 'b) ref
51 | val (#*) : 'a tensor -> 'a tensor -> 'a tensor
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # zeta (work in progress)
2 |
3 | functional neural networks in ocaml
4 |
5 | zeta is not (yet) usable. I made the repository public with the intention of inviting the community to contribute to this project. Please see below for the motivations and the to-do list. Discussions are welcome, but please bear in mind that this is early-stage work.
6 |
7 | ## Using zeta
8 | **Requirements**: `ocaml` and `dune` for building
9 | zeta is still at an early stage of development. The API is undocumented and subject to change.
10 | ```
11 | git clone https://github.com/liaopeiyuan/zeta
12 | cd zeta/src && dune build zeta.a
13 | ```
14 | The static library, `zeta.a`, will be located under `_build/default`.
15 |
16 | ## Why?
17 | functions are values :)
18 |
19 | But, in all seriousness, I think the combination of functional programming and deep learning can produce interesting results: while learning the material I noticed several dualities between the two.
20 |
21 | ## Features of zeta
22 |
23 | **1. Pedagogical**
24 |
25 | zeta does not particularly aim for performance, though I will make sure that reasonable demos are runnable. The source code of zeta is designed to be succinct and easy to read, so that by reading it users can get more out of the library than merely applying it in their daily research. I will later add more documentation, and possibly tutorials, for this library.
26 |
27 | **2. Functional**
28 |
29 | One of the most annoying error messages I've encountered in PyTorch looks something like this:
30 |
31 | ```
32 | RuntimeError: expected Double tensor (got Float tensor)
33 | ```
34 |
35 | zeta aims to move errors like this from runtime to compile time by adopting a functional programming paradigm in OCaml.
36 |
37 | **3. Dynamic Computation Graphs**
38 |
39 | zeta provides interfaces similar to those of PyTorch, where users can create a computation graph on the fly (see the sketch after feature 4).
40 |
41 | **4. Imperative**
42 |
43 | The implementation of zeta's core module, Tensor, is inherently imperative. This helps create a more efficient representation of a computation graph, and therefore of a neural network.
44 |
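45 | Putting features 2-4 together, here is a sketch of the intended usage, based on the functions currently declared in `src/tensor.mli` (`new_float`, `requires_grad`, `sigmoid`, and the element-wise product `#*`); the exact API is still subject to change:
46 |
47 | ```
48 | open Zeta.Tensor
49 |
50 | (* a 2x2 float tensor, marked as a leaf that requires gradient *)
51 | let x = new_float [| 2; 2 |] 0.5
52 | let () = requires_grad x true
53 |
54 | (* each operation extends the computation graph on the fly *)
55 | let y = sigmoid x (* element-wise sigmoid, recorded as a node *)
56 | let z = y #* y    (* element-wise multiplication *)
57 | ```
58 |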
59 | **5. ADTs/GADTs (Algebraic Data Types / Generalized Algebraic Data Types)**
60 |
61 | One of the main contributions of zeta is to abstract neural network and tensor operations into a number of ADTs/GADTs, in the process capturing some of the basic behaviors that deep learning algorithms exhibit. For example, a tensor can be recursively defined as a GADT:
62 |
63 | ```
64 | type 'a tensordata =
65 |   | IntScalar : int ref -> int tensordata
66 |   | FloatScalar : float ref -> float tensordata
67 |   | BoolScalar : bool ref -> bool tensordata
68 |   | IntTensor : int tensordata array -> int tensordata
69 |   | FloatTensor : float tensordata array -> float tensordata
70 |   | BoolTensor : bool tensordata array -> bool tensordata
71 | ```
72 |
73 | This rules out the construction of ill-typed tensors at the type level. For instance, PyTorch performs implicit casting in this example:
74 |
75 | ```
76 | >>> b = torch.FloatTensor([False])
77 | >>> b
78 | tensor([0.])
79 | ```
80 |
81 | But the equivalent expression does not type check in zeta:
82 |
83 | ```
84 | let a = FloatTensor [| BoolScalar (ref false) |];;
85 | Error: This expression has type bool tensordata
86 |        but an expression was expected of type float tensordata
87 |        Type bool is not compatible with type float
88 | ```
89 |
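90 | Conversely, a well-typed construction is accepted, and the element type is tracked statically. A sketch of a toplevel session with `Zeta.Tensor` open (the printed form may differ slightly):
91 |
92 | ```
93 | # let a = FloatTensor [| FloatScalar (ref 0.); FloatScalar (ref 1.) |];;
94 | val a : float tensordata =
95 |   FloatTensor [|FloatScalar {contents = 0.}; FloatScalar {contents = 1.}|]
96 | ```
97 |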
98 | ## To-Do List
99 |
100 | - Tensor viewing, ~~slicing~~, reshaping, concatenating
101 | - Tensor dot product and tensor product
102 | - Matrix multiplication
103 | - Convolutions (1d, 2d, 3d)
104 | - Autograd mechanisms: hooks, backward, etc.
105 | - Neural Network (G)ADT abstractions
106 | - Optimizer (G)ADT abstractions
107 | - Data loader and training abstractions
108 | - Pseudo-parallel implementation of Sequence
109 | - (Actually) parallel implementation of Sequence (possibly via Async)
110 |
111 | ## Looking for contributors
112 |
113 | zeta is still a work in progress! I'm actively looking for collaborators/contributors, and currently my plans include
114 |
115 | - Maintaining documentation and a tutorial page to explain the source code of zeta and to encourage educational use of the library
116 | - CI and unit tests
117 | - Items in the to-do list
118 |
119 | Or, if you are simply interested in functional programming/machine learning/DSL/type theory or whatever, you can always send me an email via (peiyuanl [at] andrew [dot] cmu [dot] edu) :) Make sure to include "zeta" in the title!
120 |
121 | ## References/Inspirations
122 |
123 | This project is inspired and informed by the following works:
124 |
125 | https://arxiv.org/abs/1912.01703
126 |
127 | http://www.cs.cmu.edu/afs/cs/academic/class/15210-f14/www/docs/sig/SEQUENCE.html
128 |
129 | https://arxiv.org/pdf/1411.0583.pdf
130 |
131 | https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers
132 |
133 | https://github.com/ThoughtWorksInc/DeepDarkFantasy
134 |
135 | https://en.wikibooks.org/wiki/Haskell/GADT
136 |
--------------------------------------------------------------------------------
/src/tensor.ml:
--------------------------------------------------------------------------------
1 |
2 | type op =
3 |   | IntOp : (int -> int) -> op
4 |   | BoolOp : (bool -> bool) -> op
5 |   | FloatOp : (float -> float) -> op
6 |
7 | type predicate =
8 |   | IntP : (int -> bool) -> predicate
9 |   | BoolP : (bool -> bool) -> predicate
10 |   | FloatP : (float -> bool) -> predicate
11 |
12 | type 'a tensordata =
13 |   | IntScalar : int ref -> int tensordata
14 |   | FloatScalar : float ref -> float tensordata
15 |   | BoolScalar : bool ref -> bool tensordata
16 |   | IntTensor : int tensordata array -> int tensordata
17 |   | FloatTensor : float tensordata array -> float tensordata
18 |   | BoolTensor : bool tensordata array -> bool tensordata
19 |
20 | type shape = int array
21 | type index_selection = Range of (int * int) | Index of int | All
22 | and slice = index_selection array
23 | type index = int array
24 |
25 | (* Autograd layer *)
26 | (* gradient : Retain for freshly constructed graphs and for tensors that don't require gradient,
27 |    Grad for graphs that have been back-propagated;
28 |    the bool encodes the retain_grad option
29 | *)
30 | (* Directed acyclic graph (DAG) of tensor operations:
31 |    Null for "isolated" tensors that do not belong to any graph;
32 |    Graph for tensors that are nodes in the graph -
33 |    their gradient, the nodes they connect to (with the associated operations), and the nodes that connect to them
34 | *)
35 | type 'a grad_fn = End | Fn of ('a tensor * op) array
36 | and 'a gradient = Retain of bool | Grad of (bool * 'a tensordata)
37 | and 'a parent = 'a tensor array
38 | and 'a node = LeafNoGrad | LeafGrad of 'a gradient | Node of ('a parent * 'a gradient)
39 | and 'a directed_acyclic_graph = Null | Graph of ('a grad_fn * 'a node)
40 | (* shape describes the dimensions of the data *)
41 | and 'a tensor = (shape * 'a tensordata * 'a directed_acyclic_graph) ref
42 |
43 | exception TypeMismatch of string
44 | (* This exists to satisfy the exhaustiveness checker; in theory it is never raised *)
45 | exception TensorInvariantViolated
46 | exception ShapeMismatch of string
47 | exception IndexError of string
48 | exception ZeroDimension
49 | exception AutogradError of string
50 |
51 | let is_leaf (t : 'a tensor) : bool = let (_, _, dag) = !t in
52 |   match dag with
53 |   | Null -> true
54 |   | Graph (_, LeafNoGrad) -> true
55 |   | Graph (_, LeafGrad _) -> true
56 |   | Graph (_, Node _) -> false
57 |
58 | let requires_grad (t : 'a tensor) (b : bool) : unit = let (shape, data, dag) = !t in
59 |   match (b, dag) with
60 |   | (true, Null) -> t := (shape, data, Graph (End, LeafGrad (Retain false)))
61 |   | (false, Null) -> Printf.printf "Warning : isolated leaf tensors do not require gradient. \n"
62 |   | (_, Graph (_, Node _)) -> raise (AutogradError "you can only change the requires_grad flag of leaf tensors.")
63 |   | (true, Graph (_, LeafGrad _)) -> Printf.printf "Warning : tensor already requires gradient. \n"
\n" 64 | | (false, Graph (a, LeafGrad _)) -> t := (shape, data, Graph (a, LeafNoGrad)) 65 | | (true, Graph (a, LeafNoGrad)) -> t := (shape, data, Graph (a, LeafGrad (Retain false))) 66 | | (false, Graph (_, LeafNoGrad)) -> Printf.printf "Warning : leaf tensors does not require gradient already. \n" 67 | 68 | let retain_grad (t : 'a tensor) (b : bool) : unit = let (shape, data, dag) = !t in 69 | match dag with 70 | | Null -> raise (AutogradError "tensor does not require gradient.") 71 | | Graph (a, Node (p, Retain _)) -> t := (shape, data, Graph (a, Node (p, Retain b))) 72 | | Graph (a, Node (p, Grad (_, d))) -> t := (shape, data, Graph (a, Node (p, Grad (b, d))) ) 73 | | Graph (_, LeafNoGrad ) -> raise (AutogradError "leaf tensor does not require gradient.") 74 | | Graph (a, LeafGrad (Retain _) ) -> t := (shape, data, Graph (a, LeafGrad (Retain b) ) ) 75 | | Graph (a, LeafGrad (Grad (_,d)) ) -> t := (shape, data, Graph (a, LeafGrad (Grad (b,d)) ) ) 76 | 77 | let detach (t : 'a tensor) : 'a tensor = let (shape, data, _) = !t in ref (shape, data, Null) 78 | 79 | let _check_valid_shape shape = 80 | let len = Array.length shape in 81 | if (Array.fold_left (fun x y -> x || y) false (Array.init len (fun i -> (Array.get shape i)<0)) ) then raise (IndexError "Negative size along one of the dimensions") 82 | else if (Array.fold_left (fun x y -> x || y) false (Array.init len (fun i -> (Array.get shape i)=0)) ) 83 | then (Printf.printf "Warning : one of the dimensions is zero. \n"; raise ZeroDimension) 84 | else () 85 | 86 | let rec _copy : 'a. 'a tensordata -> bool -> 'a tensordata = 87 | fun (type el) (e : el tensordata) (b : bool) : el tensordata -> 88 | if b then match e with 89 | | IntScalar r -> IntScalar (ref (!r)) 90 | | FloatScalar r -> FloatScalar (ref (!r)) 91 | | BoolScalar r -> BoolScalar (ref (!r)) 92 | | BoolTensor r -> BoolTensor (Array.map (fun i -> _copy i b) r) 93 | | FloatTensor r -> FloatTensor (Array.map (fun i -> _copy i b) r) 94 | | IntTensor r -> IntTensor (Array.map (fun i -> _copy i b) r) 95 | else e 96 | 97 | let copy t = let (shape, data, dag) = !t in 98 | ref (shape, _copy data true, dag) 99 | 100 | (* TODO : DAG connection *) 101 | let slice (s : slice) (t : 'a tensor) : 'a tensor = 102 | let (shape, data, _) = !t in 103 | let shape_l = Array.to_list shape in 104 | let l = Array.to_list s in 105 | 106 | let rec slice' : 'a. 
107 |     fun (type el) (s : index_selection list) (l : int list) (d : el tensordata) : el tensordata ->
108 |     match (s, l, d) with
109 |     | ([], [], _) -> d
110 |     | (All::xs, _::ys, IntTensor a) -> IntTensor (Array.map (fun i -> slice' xs ys i) a)
111 |     | (All::xs, _::ys, BoolTensor a) -> BoolTensor (Array.map (fun i -> slice' xs ys i) a)
112 |     | (All::xs, _::ys, FloatTensor a) -> FloatTensor (Array.map (fun i -> slice' xs ys i) a)
113 |     | ((Index i)::xs, y::ys, IntTensor a) ->
114 |       if (i < 0) || (i >= y) then
115 |         let r = "Expected index between 0 and " ^ string_of_int (y-1) ^ "; got " ^ string_of_int i
116 |         in raise (IndexError r)
117 |       else slice' xs ys (Array.get a i)
118 |     | ((Index i)::xs, y::ys, FloatTensor a) ->
119 |       if (i < 0) || (i >= y) then
120 |         let r = "Expected index between 0 and " ^ string_of_int (y-1) ^ "; got " ^ string_of_int i
121 |         in raise (IndexError r)
122 |       else slice' xs ys (Array.get a i)
123 |     | ((Index i)::xs, y::ys, BoolTensor a) ->
124 |       if (i < 0) || (i >= y) then
125 |         let r = "Expected index between 0 and " ^ string_of_int (y-1) ^ "; got " ^ string_of_int i
126 |         in raise (IndexError r)
127 |       else slice' xs ys (Array.get a i)
128 |     (* Range (i1, i2) selects indices i1 .. i2-1; the upper bound is exclusive, so i2 may equal the dimension *)
129 |     | ((Range (i1, i2))::xs, y::ys, IntTensor a) ->
130 |       if i1 >= i2 then raise (IndexError "invalid range") else
131 |       if (i1 < 0) || (i1 >= y) || (i2 > y) then
132 |         let r = "Expected a range within 0 and " ^ string_of_int y ^ "; got " ^ string_of_int i1 ^ "; " ^ string_of_int i2
133 |         in raise (IndexError r)
134 |       else IntTensor (Array.map (fun i -> slice' xs ys i) (Array.sub a i1 (i2-i1)))
135 |     | ((Range (i1, i2))::xs, y::ys, FloatTensor a) ->
136 |       if i1 >= i2 then raise (IndexError "invalid range") else
137 |       if (i1 < 0) || (i1 >= y) || (i2 > y) then
138 |         let r = "Expected a range within 0 and " ^ string_of_int y ^ "; got " ^ string_of_int i1 ^ "; " ^ string_of_int i2
139 |         in raise (IndexError r)
140 |       else FloatTensor (Array.map (fun i -> slice' xs ys i) (Array.sub a i1 (i2-i1)))
141 |     | ((Range (i1, i2))::xs, y::ys, BoolTensor a) ->
142 |       if i1 >= i2 then raise (IndexError "invalid range") else
143 |       if (i1 < 0) || (i1 >= y) || (i2 > y) then
144 |         let r = "Expected a range within 0 and " ^ string_of_int y ^ "; got " ^ string_of_int i1 ^ "; " ^ string_of_int i2
145 |         in raise (IndexError r)
146 |       else BoolTensor (Array.map (fun i -> slice' xs ys i) (Array.sub a i1 (i2-i1)))
147 |     | _ -> raise (ShapeMismatch "slice must have the same number of dimensions as the shape") in
148 |
149 |   let rec new_shape l s r = match (l, s) with
150 |     | ([], []) -> r
151 |     | (x::xs, y::ys) -> (match x with
152 |       | All -> new_shape xs ys (y::r)
153 |       | Index _ -> new_shape xs ys r
154 |       | Range (a, b) -> new_shape xs ys ((b-a)::r))
155 |     | _ -> raise TensorInvariantViolated in
156 |
157 |   ref ((new_shape l shape_l []) |> List.rev |> Array.of_list, _copy (slice' l shape_l data) true, Null)
158 |
159 | let rec _new_bool (s : int list) v b = match s with
160 |   | [] -> _copy v b
161 |   | [e] -> BoolTensor (Array.init e (fun _ -> _copy v b))
162 |   | e::s' -> BoolTensor (Array.init e (fun _ -> _new_bool s' v b))
163 |
164 | let rec _new_int (s : int list) v b = match s with
165 |   | [] -> _copy v b
166 |   | [e] -> IntTensor (Array.init e (fun _ -> _copy v b))
167 |   | e::s' -> IntTensor (Array.init e (fun _ -> _new_int s' v b))
168 |
169 | let rec _new_float (s : int list) v b = match s with
170 |   | [] -> _copy v b
171 |   | [e] -> FloatTensor (Array.init e (fun _ -> _copy v b))
172 |   | e::s' -> FloatTensor (Array.init e (fun _ -> _new_float s' v b))
173 |
174 | let new_bool (s : shape) v =
175 |   let s' = Array.to_list s in
176 |   let v' = BoolScalar (ref v) in
177 |   try (_check_valid_shape s; (ref (s, _new_bool s' v' true, Null) : bool tensor))
178 |   with ZeroDimension -> ref (s, BoolTensor [||], Null)
179 |
180 | let new_int (s : shape) v =
181 |   let s' = Array.to_list s in
182 |   let v' = IntScalar (ref v) in
183 |   try (_check_valid_shape s; (ref (s, _new_int s' v' true, Null) : int tensor))
184 |   with ZeroDimension -> ref (s, IntTensor [||], Null)
185 |
186 | let new_float (s : shape) v =
187 |   let s' = Array.to_list s in
188 |   let v' = FloatScalar (ref v) in
189 |   try (_check_valid_shape s; (ref (s, _new_float s' v' true, Null) : float tensor))
190 |   with ZeroDimension -> ref (s, FloatTensor [||], Null)
191 |
192 | let rec _reduce : 'a. predicate -> (bool * bool -> bool) -> bool -> 'a tensordata -> bool =
193 |   fun (type el) f g v (t : el tensordata) : bool ->
194 |   match (f, t) with
195 |   | (BoolP f', BoolScalar e) -> f' (!e)
196 |   | (BoolP _, BoolTensor e) -> Array.fold_left (fun b p -> g (b, _reduce f g v p)) v e
197 |   | (IntP f', IntScalar e) -> f' (!e)
198 |   | (IntP _, IntTensor e) -> Array.fold_left (fun b p -> g (b, _reduce f g v p)) v e
199 |   | (FloatP f', FloatScalar e) -> f' (!e)
200 |   | (FloatP _, FloatTensor e) -> Array.fold_left (fun b p -> g (b, _reduce f g v p)) v e
201 |   | (_, _) -> raise (TypeMismatch "A predicate can only be applied to a tensor of the same element type")
202 |
203 | let reduce f g v (t : 'a tensor) = let (_, data, _) = !t in _reduce f g v data
204 | let all f (t : 'a tensor) = let (_, data, _) = !t in _reduce f (fun (x, y) -> x && y) true data
205 | let any f (t : 'a tensor) = let (_, data, _) = !t in _reduce f (fun (x, y) -> x || y) false data
206 |
207 | let rec _elem_apply : 'a. op -> 'a tensordata -> 'a tensordata =
208 |   fun (type el) (f : op) (t : el tensordata) : el tensordata ->
209 |   match (f, t) with
210 |   | (BoolOp f', BoolScalar e) -> BoolScalar (ref (f' (!e)))
211 |   | (BoolOp _, BoolTensor e) -> BoolTensor (Array.map (fun i -> _elem_apply f i) e)
212 |   | (IntOp f', IntScalar e) -> IntScalar (ref (f' (!e)))
213 |   | (IntOp _, IntTensor e) -> IntTensor (Array.map (fun i -> _elem_apply f i) e)
214 |   | (FloatOp f', FloatScalar e) -> FloatScalar (ref (f' (!e)))
215 |   | (FloatOp _, FloatTensor e) -> FloatTensor (Array.map (fun i -> _elem_apply f i) e)
216 |   | (_, _) -> raise (TypeMismatch "An op can only be applied to a tensor of the same element type")
217 |
218 | let elem_apply f (t : 'a tensor) = let (shape, data, dag) = !t in
219 |   let newd = _elem_apply f data in
220 |   match dag with
221 |   | Null -> ref (shape, newd, Null)
222 |   | Graph (End, x) -> let newt = ref (shape, newd, Graph (End, Node ([|t|], Retain false))) in
223 |     ( Printf.printf "Warning : you can only back-propagate with float tensors. \n" ;
224 |       t := (shape, data, Graph (Fn [|(newt, f)|], x)) ;
225 |       newt
226 |     )
227 |   | Graph (Fn l, x) -> let newt = ref (shape, newd, Graph (End, Node ([|t|], Retain false))) in
228 |     ( Printf.printf "Warning : you can only back-propagate with float tensors. \n" ;
229 |       t := (shape, data, Graph (Fn (Array.append [|(newt, f)|] l), x)) ;
230 |       newt
231 |     )
232 |
233 | let sigmoid (t : 'a tensor) = elem_apply (FloatOp (fun x -> Float.exp x /. (Float.exp x +. 1.0))) t
234 |
235 | let _abs (type el) (t : el tensordata) : el tensordata =
236 |   let absf v = if v > 0.0 then v else v *. (-1.0) in
237 |   let absi v = if v > 0 then v else v * (-1) in
238 |   let absb b = b (* viewing bools as 0/1, |false| = false and |true| = true *) in
239 |   match t with
240 |   | BoolScalar e -> _elem_apply (BoolOp absb) (BoolScalar e)
241 |   | BoolTensor e -> _elem_apply (BoolOp absb) (BoolTensor e)
242 |   | IntScalar e -> _elem_apply (IntOp absi) (IntScalar e)
243 |   | IntTensor e -> _elem_apply (IntOp absi) (IntTensor e)
244 |   | FloatScalar e -> _elem_apply (FloatOp absf) (FloatScalar e)
245 |   | FloatTensor e -> _elem_apply (FloatOp absf) (FloatTensor e)
246 |
247 | let abs (t : 'a tensor) =
248 |   let (shape, data, dag) = !t in
249 |   let newd = _abs data in
250 |   match dag with
251 |   | Null -> ref (shape, newd, Null)
252 |   | Graph (End, x) -> let newt = ref (shape, newd, Graph (End, Node ([|t|], Retain false))) in
253 |     ( Printf.printf "Warning : you can only back-propagate with float tensors. \n" ;
254 |       t := (shape, data, Graph (Fn [|(newt, FloatOp (fun x -> x))|], x)) ; (* TODO : record the true derivative of abs *)
255 |       newt
256 |     )
257 |   | Graph (Fn l, x) -> let newt = ref (shape, newd, Graph (End, Node ([|t|], Retain false))) in
258 |     ( Printf.printf "Warning : you can only back-propagate with float tensors. \n" ;
259 |       t := (shape, data, Graph (Fn (Array.append [|(newt, FloatOp (fun x -> x))|] l), x)) ;
260 |       newt
261 |     )
262 |
263 | let _new_t (type el) (s : int list) (v : el tensordata) b : el tensordata =
264 |   let f t b = if b then ref (!t) else t in
265 |   match (s, v) with
266 |   | ([], IntScalar t) -> IntScalar (f t b)
267 |   | (e::s', IntScalar t) -> _new_int (e::s') (IntScalar t) b
268 |   | ([], IntTensor t) -> _copy (IntTensor t) b
269 |   | (e::s', IntTensor t) -> _new_int (e::s') (IntTensor t) b
270 |   | ([], FloatTensor t) -> _copy (FloatTensor t) b
271 |   | (e::s', FloatTensor t) -> _new_float (e::s') (FloatTensor t) b
272 |   | ([], FloatScalar t) -> FloatScalar (f t b)
273 |   | (e::s', FloatScalar t) -> _new_float (e::s') (FloatScalar t) b
274 |   | ([], BoolScalar t) -> BoolScalar (f t b)
275 |   | (e::s', BoolScalar t) -> _new_bool (e::s') (BoolScalar t) b
276 |   | ([], BoolTensor t) -> _copy (BoolTensor t) b
277 |   | (e::s', BoolTensor t) -> _new_bool (e::s') (BoolTensor t) b
278 |
279 | let new_t (type el) (s : shape) (t : el tensor) b =
280 |   let s' = Array.to_list s in
281 |   let (shape, data, dag) = !t in
282 |   let news = Array.of_list (List.append s' (Array.to_list shape)) in
283 |   let newt =
284 |     try (_check_valid_shape s; (news, _new_t s' data b, dag))
285 |     with ZeroDimension -> match data with
286 |       | IntScalar _ -> (s, IntTensor [||], dag)
287 |       | IntTensor _ -> (s, IntTensor [||], dag)
288 |       | FloatScalar _ -> (s, FloatTensor [||], dag)
289 |       | FloatTensor _ -> (s, FloatTensor [||], dag)
290 |       | BoolScalar _ -> (s, BoolTensor [||], dag)
291 |       | BoolTensor _ -> (s, BoolTensor [||], dag)
292 |   in
293 |   if b then ref newt else (t := newt; t)
294 |
295 | let rec _getset : 'a. 'a tensordata -> int list -> ('a ref -> 'b) -> 'b =
296 |   fun (type el) (t : el tensordata) idx (f : el ref -> 'a) ->
297 |   match (t, idx) with
298 |   | (FloatScalar r, []) -> f r
299 |   | (FloatTensor r, e::s') -> _getset (Array.get r e) s' f
300 |   | (IntScalar r, []) -> f r
301 |   | (IntTensor r, e::s') -> _getset (Array.get r e) s' f
302 |   | (BoolScalar r, []) -> f r
303 |   | (BoolTensor r, e::s') -> _getset (Array.get r e) s' f
304 |   | _ -> raise TensorInvariantViolated
305 |
306 | let _check_valid_idx (_, shape, idx) =
307 |   let len1 = Array.length shape in
308 |   let len2 = Array.length idx in
309 |   if len1 <> len2 then raise (IndexError ("Expected index of length " ^ string_of_int len1 ^ "; got " ^ string_of_int len2))
310 |   else if Array.fold_left (fun x y -> x || y) false (Array.init len2 (fun i -> (Array.get idx i) < 0)) then raise (IndexError "Negative indexing not supported")
311 |   else if not (Array.fold_left (fun x y -> x && y) true (Array.init len1 (fun i -> (Array.get idx i) < (Array.get shape i))))
312 |   then raise (IndexError "Array index out of bound")
313 |   else ()
314 |
315 | let set (t : 'a tensor) idx e = let (shape, data, _) = !t in
316 |   (_check_valid_idx (data, shape, idx); _getset data (Array.to_list idx) (fun x -> x := e))
317 |
318 | let get (t : 'a tensor) idx = let (shape, data, _) = !t in
319 |   (_check_valid_idx (data, shape, idx); _getset data (Array.to_list idx) (fun x -> !x))
320 |
321 | (* dangerous : skips bounds checking *)
322 | let _set t idx e = _getset t (Array.to_list idx) (fun x -> x := e)
323 |
324 | let _check_broadcastable s d =
325 |   let (source, destination) = (List.rev (Array.to_list s), List.rev (Array.to_list d)) in
326 |   let rec _check_broadcastable' source destination =
327 |     match (source, destination) with
328 |     | ([], _) -> (destination, [])
329 |     | (_ :: _, []) -> raise (ShapeMismatch "source array has more dimensions than the desired shape")
330 |     | (s :: s', d :: d') ->
331 |       if s <> d && s <> 1 then raise (ShapeMismatch "one of the trailing dimensions doesn't agree")
332 |       else let (lead, trail) = _check_broadcastable' s' d' in (lead, d::trail) in
333 |   let (s', d') = _check_broadcastable' source destination in
334 |   (List.rev s', List.rev d')
335 |
336 | let rec _map : 'a. 'a tensordata -> int list -> int list -> bool -> 'a tensordata =
337 |   fun (type el) (t : el tensordata) source target copy : el tensordata ->
338 |   match (t, source, target) with
339 |   | (IntScalar r, [], []) -> if copy then IntScalar (ref (!r)) else IntScalar r
340 |   | (IntTensor r, e::e', d::d') ->
341 |     if e = d then
342 |       IntTensor (Array.map (fun i -> _map i e' d' copy) r)
343 |     else
344 |       IntTensor (Array.init d (fun _ -> _map (Array.get r 0) e' d' copy))
345 |   | (IntTensor r, [], []) -> IntTensor r
346 |   | (FloatScalar r, [], []) -> if copy then FloatScalar (ref (!r)) else FloatScalar r
347 |   | (FloatTensor r, e::e', d::d') ->
348 |     if e = d then
349 |       FloatTensor (Array.map (fun i -> _map i e' d' copy) r)
350 |     else
351 |       FloatTensor (Array.init d (fun _ -> _map (Array.get r 0) e' d' copy))
352 |   | (FloatTensor r, [], []) -> FloatTensor r
353 |   | (BoolScalar r, [], []) -> if copy then BoolScalar (ref (!r)) else BoolScalar r
354 |   | (BoolTensor r, e::e', d::d') ->
355 |     if e = d then
356 |       BoolTensor (Array.map (fun i -> _map i e' d' copy) r)
357 |     else
358 |       BoolTensor (Array.init d (fun _ -> _map (Array.get r 0) e' d' copy))
359 |   | (BoolTensor r, [], []) -> BoolTensor r
360 |   | _ -> raise TensorInvariantViolated
361 |
362 |
363 | let _broadcast (type el) (t : el tensordata)
364 |     (source : int list) (lead : int list)
365 |     (trail : int list) (copy : bool) : el tensordata =
366 |   let f t b = if b then ref (!t) else t in
367 |   match t with
368 |   | FloatTensor r -> _new_float lead (_map (FloatTensor r) source trail copy) copy
369 |   | BoolTensor r -> _new_bool lead (_map (BoolTensor r) source trail copy) copy
370 |   | IntTensor r -> _new_int lead (_map (IntTensor r) source trail copy) copy
371 |   | IntScalar r -> _new_int lead (IntScalar (f r copy)) copy
372 |   | FloatScalar r -> _new_float lead (FloatScalar (f r copy)) copy
373 |   | BoolScalar r -> _new_bool lead (BoolScalar (f r copy)) copy
374 |
375 | let broadcast t destination copy = let (source, data, dag) = !t in
376 |   let (lead_dim, trail_dim) = _check_broadcastable source destination in
377 |   let newdata = _broadcast data (Array.to_list source) lead_dim trail_dim copy in
378 |   let news = Array.of_list (lead_dim @ trail_dim) in
379 |   ref (news, newdata, dag)
380 |
381 | let rec _elem_mul : 'a. 'a tensordata -> 'a tensordata -> 'a tensordata =
382 |   fun (type el) (t1 : el tensordata) (t2 : el tensordata) : el tensordata ->
383 |   match (t1, t2) with
384 |   | (BoolScalar s, BoolScalar s') -> BoolScalar (ref (!s && !s'))
385 |   | (BoolScalar s, BoolTensor t) -> if !s then _copy (BoolTensor t) true else _elem_apply (BoolOp (fun _ -> false)) (BoolTensor t)
386 |   | (BoolTensor t, BoolScalar s) -> if !s then _copy (BoolTensor t) true else _elem_apply (BoolOp (fun _ -> false)) (BoolTensor t)
387 |   | (IntScalar s, IntScalar s') -> IntScalar (ref (!s * !s'))
388 |   | (IntScalar s, IntTensor t) -> _elem_apply (IntOp (fun i -> !s * i)) (IntTensor t)
389 |   | (IntTensor t, IntScalar s) -> _elem_apply (IntOp (fun i -> !s * i)) (IntTensor t)
390 |   | (FloatScalar s, FloatScalar s') -> FloatScalar (ref (!s *. !s'))
391 |   | (FloatScalar s, FloatTensor t) -> _elem_apply (FloatOp (fun i -> !s *. i)) (FloatTensor t)
392 |   | (FloatTensor t, FloatScalar s) -> _elem_apply (FloatOp (fun i -> !s *. i)) (FloatTensor t)
393 |   | (FloatTensor t, FloatTensor t') -> FloatTensor (Array.mapi (fun i e -> _elem_mul (Array.get t i) e) t')
394 |   | (IntTensor t, IntTensor t') -> IntTensor (Array.mapi (fun i e -> _elem_mul (Array.get t i) e) t')
395 |   | (BoolTensor t, BoolTensor t') -> BoolTensor (Array.mapi (fun i e -> _elem_mul (Array.get t i) e) t')
396 |
397 | (* element-wise multiplication with broadcasting; TODO : DAG connection for the result *)
398 | let (#*) (t1 : 'a tensor) (t2 : 'a tensor) : 'a tensor =
399 |   let ((s1, d1, dag1), (s2, d2, dag2)) = (!t1, !t2) in
400 |   let max_dim s1 s2 =
401 |     let (l1, l2) = (List.rev (Array.to_list s1), List.rev (Array.to_list s2)) in
402 |     let rec max_dim' l1 l2 =
403 |       match (l1, l2) with
404 |       | ([], []) -> []
405 |       | (x::xs, []) -> x::xs
406 |       | ([], x::xs) -> x::xs
407 |       | (x::xs, y::ys) -> (max x y)::(max_dim' xs ys) in
408 |     List.rev (max_dim' l1 l2) in
409 |   let news = Array.of_list (max_dim s1 s2) in
410 |   match (Array.length s1, Array.length s2, s1 = s2) with
411 |   | (0, _, _) -> ref (s2, _elem_mul d1 d2, dag2)
412 |   | (_, 0, _) -> ref (s1, _elem_mul d1 d2, dag1)
413 |   | (_, _, true) -> ref (s1, _elem_mul d1 d2, dag1)
414 |   | (_, _, _) ->
415 |     let ((_, t1', dag1), (_, t2', _)) =
416 |       (!(broadcast t1 news true), !(broadcast t2 news true)) in
417 |     ref (news, _elem_mul t1' t2', dag1)
418 |
--------------------------------------------------------------------------------