├── .gitignore ├── .travis.yml ├── .vscode ├── launch.json └── tasks.json ├── Elements.fs ├── ElementsTests.fs ├── FourierMotzkin.fs ├── LICENSE ├── Program.fs ├── README.md ├── TensorAlgDiff.fsproj └── elemdiff.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | obj/ 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: csharp 2 | mono: none 3 | dotnet: 2.1.105 4 | script: 5 | - unset DOTNET_CLI_TELEMETRY_OPTOUT 6 | - dotnet build -c Release 7 | - dotnet test -c Release 8 | - dotnet run -c Release 9 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": ".NET Core Launch (console)", 9 | "type": "coreclr", 10 | "request": "launch", 11 | "preLaunchTask": "build", 12 | "program": "${workspaceRoot}/bin/Debug/netcoreapp2.0/TensorAlgDiff.dll", 13 | "args": [], 14 | "cwd": "${workspaceRoot}", 15 | "stopAtEntry": true, 16 | "console": "internalConsole" 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "build", 8 | "command": "dotnet build", 9 | "type": "shell", 10 | "group": "build", 11 | "presentation": { 12 | "reveal": "silent" 13 | }, 14 | "problemMatcher": "$msCompile" 15 | }, 16 | { 17 | "label": "tests", 18 | "command": "dotnet test -v:n", 19 | "type": "shell", 20 | "problemMatcher": "$msCompile" 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /Elements.fs: -------------------------------------------------------------------------------- 1 | namespace Elements 2 | 3 | open System 4 | open Tensor 5 | open Tensor.Algorithm 6 | 7 | 8 | /// element expression 9 | module Elements = 10 | 11 | /// An expression for an index as a linear combination. 
12 | [] 13 | type IdxExpr = 14 | IdxExpr of Map 15 | with 16 | static member zero = 17 | IdxExpr Map.empty 18 | static member one = 19 | IdxExpr.factor "1" Rat.One 20 | static member named name = 21 | IdxExpr.factor name Rat.One 22 | static member constant value = 23 | value * IdxExpr.one 24 | static member factor dim value = 25 | IdxExpr (Map [dim, value]) 26 | static member (~-) (IdxExpr af) = 27 | af |> Map.map (fun ai av -> -av) |> IdxExpr 28 | static member (+) (IdxExpr af, IdxExpr bf) = 29 | let f = bf |> Map.fold (fun f i bv -> match f |> Map.tryFind i with 30 | | Some v -> f |> Map.add i (v+bv) 31 | | None -> f |> Map.add i bv) af 32 | IdxExpr f 33 | static member (-) (a: IdxExpr, b: IdxExpr) = 34 | a + (-b) 35 | static member (*) (f: Rat, IdxExpr bf) = 36 | bf |> Map.map (fun bi bv -> f * bv) |> IdxExpr 37 | static member (/) (IdxExpr af, f: Rat) = 38 | af |> Map.map (fun ai av -> av / f) |> IdxExpr 39 | member this.Pretty = 40 | let (IdxExpr f) = this 41 | let sf = 42 | Map.toList f 43 | |> List.map fst 44 | |> List.sort 45 | |> List.choose (fun n -> 46 | if f.[n] = Rat.Zero then None 47 | elif f.[n] = Rat.One then Some n 48 | elif f.[n] = Rat.MinusOne then Some ("-" + n) 49 | elif n = "1" then Some (sprintf "%A" f.[n]) 50 | else Some (sprintf "%A*%s" f.[n] n)) 51 | if List.isEmpty sf then "0" else sf |> String.concat " + " 52 | static member name (IdxExpr f) = 53 | f |> Map.toList |> List.exactlyOne |> fst 54 | member this.Name = IdxExpr.name this 55 | static member eval idxEnv (IdxExpr f) = 56 | let idxEnv = idxEnv |> Map.add "1" Rat.One 57 | f |> Map.fold (fun s i v -> s + v * idxEnv.[i]) Rat.Zero 58 | static member subst (repl: Map) (IdxExpr f) = 59 | (IdxExpr.zero, f) ||> Map.fold (fun r i v -> 60 | match repl |> Map.tryFind i with 61 | | Some iv -> r + v * iv 62 | | None -> r + IdxExpr.factor i v) 63 | static member constVal (IdxExpr f) = 64 | match f |> Map.tryFind "1" with 65 | | Some v -> v 66 | | None -> Rat.Zero 67 | static member ofSeq indices 
values = 68 | Seq.zip indices values 69 | |> Map.ofSeq 70 | |> IdxExpr 71 | 72 | /// Matches an index expression that consists only of a constant. 73 | let (|ConstIdxExpr|_|) (IdxExpr f) = 74 | let f = f |> Map.toList |> List.filter (fun (_, v) -> v <> Rat.Zero) 75 | match f with 76 | | [] -> Some Rat.Zero 77 | | [i, v] when i = "1" -> Some v 78 | | _ -> None 79 | 80 | /// Matches an index expression that consists only of a single (non-constant) factor. 81 | let (|SingleIdxExpr|_|) (IdxExpr f) = 82 | let f = f |> Map.toList |> List.filter (fun (_, v) -> v <> Rat.Zero) 83 | match f with 84 | | [i, v] when i <> "1" -> Some (i, v) 85 | | _ -> None 86 | 87 | 88 | /// Index expressions for all indices of a tensor. 89 | [] 90 | type IdxExprs = 91 | IdxExprs of IdxExpr list 92 | with 93 | static member toMatrix inNames (IdxExprs idx) = 94 | let nIn = List.length inNames |> int64 95 | let nOut = idx |> List.length |> int64 96 | let m = HostTensor.zeros [nOut; nIn] 97 | idx |> List.iteri (fun r (IdxExpr f) -> 98 | f |> Map.iter (fun name v -> 99 | match inNames |> List.tryFindIndex ((=) name) with 100 | | Some c -> m.[[int64 r; int64 c]] <- v 101 | | None -> failwithf "dimension %s does not exist" name)) 102 | m 103 | member this.Pretty = 104 | let (IdxExprs idx) = this 105 | sprintf "%A" idx 106 | static member eval idxEnv (IdxExprs idx) = 107 | idx |> List.map (IdxExpr.eval idxEnv) 108 | static member subst repl (IdxExprs idx) = 109 | idx |> List.map (IdxExpr.subst repl) |> IdxExprs 110 | static member length (IdxExprs idx) = 111 | List.length idx 112 | 113 | type LeafOp = 114 | | Const of float 115 | | IdxValue of idx:IdxExpr 116 | | Argument of name:string * idxs:IdxExprs 117 | 118 | and UnaryOp = 119 | | Negate 120 | | Abs 121 | | Sgn 122 | | Log 123 | | Log10 124 | | Exp 125 | | Tanh 126 | | Sqrt 127 | | Sum of idx:string * lows:IdxExpr list * highs:IdxExpr list 128 | 129 | and BinaryOp = 130 | | Add 131 | | Substract 132 | | Multiply 133 | | Divide 134 | | Modulo 
135 | | Power 136 | | IdxIf of idx:IdxExpr * cmp:IdxComparison 137 | 138 | and IdxComparison = 139 | | EqualToZero 140 | | GreaterOrEqualToZero 141 | | Integer 142 | 143 | /// an element expression 144 | and [] 145 | ElemExpr = 146 | | Leaf of LeafOp 147 | | Unary of UnaryOp * ElemExpr 148 | | Binary of BinaryOp * ElemExpr * ElemExpr 149 | 150 | and [] 151 | ElemFunc = { 152 | Name: string 153 | DimNames: string list 154 | DimSize: Map 155 | Expr: ElemExpr 156 | ArgShapes: Map 157 | } with 158 | member this.Pretty = 159 | let dims = this.DimNames |> String.concat "; " 160 | sprintf "%s[%s] = %A" this.Name dims this.Expr 161 | member this.Shape = 162 | this.DimNames |> List.map (fun d -> this.DimSize.[d]) 163 | 164 | /// Returns all arguments occurring in the given expression. 165 | let rec extractArgs expr = 166 | match expr with 167 | | Leaf (Argument (name, idxs)) -> Set [name, idxs] 168 | | Leaf _ -> Set.empty 169 | | Unary (_, a) -> extractArgs a 170 | | Binary (_, a, b) -> Set.union (extractArgs a) (extractArgs b) 171 | 172 | /// Builds a function. 
173 | let func name dimNames dimSizes argShapes expr = 174 | for (argName, argIdx) in extractArgs expr do 175 | match argShapes |> Map.tryFind argName with 176 | | Some shp when IdxExprs.length argIdx <> List.length shp -> 177 | failwithf "shape dimensionality mismatch for argument %s" argName 178 | | Some shp -> () 179 | | None -> failwithf "no shape specified for argument %s" argName 180 | {Name=name; DimNames=dimNames; DimSize=dimSizes; Expr=expr; ArgShapes=argShapes} 181 | 182 | /// a constant value given by a ConstSpec 183 | let scalar v = Leaf (Const v) 184 | 185 | type ElemExpr with 186 | 187 | // elementwise unary 188 | static member (~+) (a: ElemExpr) = a 189 | static member (~-) (a: ElemExpr) = Unary(Negate, a) 190 | static member Abs (a: ElemExpr) = Unary(Abs, a) 191 | static member Sgn (a: ElemExpr) = Unary(Sgn, a) 192 | static member Log (a: ElemExpr) = Unary(Log, a) 193 | static member Log10 (a: ElemExpr) = Unary(Log10, a) 194 | static member Exp (a: ElemExpr) = Unary(Exp, a) 195 | static member Tanh (a: ElemExpr) = Unary(Tanh, a) 196 | static member Sqrt (a: ElemExpr) = Unary(Sqrt, a) 197 | 198 | // elementwise binary 199 | static member (+) (a: ElemExpr, b: ElemExpr) = Binary(Add, a, b) 200 | static member (-) (a: ElemExpr, b: ElemExpr) = Binary(Substract, a, b) 201 | static member (*) (a: ElemExpr, b: ElemExpr) = Binary(Multiply, a, b) 202 | static member (/) (a: ElemExpr, b: ElemExpr) = Binary(Divide, a, b) 203 | static member (%) (a: ElemExpr, b: ElemExpr) = Binary(Modulo, a, b) 204 | static member Pow (a: ElemExpr, b: ElemExpr) = Binary(Power, a, b) 205 | static member ( *** ) (a: ElemExpr, b: ElemExpr) = a ** b 206 | 207 | // elementwise binary with basetype 208 | static member (+) (a: ElemExpr, b: float) = a + (scalar b) 209 | static member (-) (a: ElemExpr, b: float) = a - (scalar b) 210 | static member (*) (a: ElemExpr, b: float) = a * (scalar b) 211 | static member (/) (a: ElemExpr, b: float) = a / (scalar b) 212 | static member (%) (a: 
ElemExpr, b: float) = a % (scalar b) 213 | static member Pow (a: ElemExpr, b: float) = a ** (scalar b) 214 | static member ( *** ) (a: ElemExpr, b: float) = a ** (scalar b) 215 | 216 | static member (+) (a: float, b: ElemExpr) = (scalar a) + b 217 | static member (-) (a: float, b: ElemExpr) = (scalar a) - b 218 | static member (*) (a: float, b: ElemExpr) = (scalar a) * b 219 | static member (/) (a: float, b: ElemExpr) = (scalar a) / b 220 | static member (%) (a: float, b: ElemExpr) = (scalar a) % b 221 | static member Pow (a: float, b: ElemExpr) = (scalar a) ** b 222 | static member ( *** ) (a: float, b: ElemExpr) = (scalar a) ** b 223 | 224 | member private this.PrettyAndPriority = 225 | match this with 226 | | Leaf (op) -> 227 | let myPri = 20 228 | let myStr = 229 | match op with 230 | | Const v -> sprintf "%g" v 231 | | IdxValue idx -> sprintf "(%A)" idx 232 | | Argument (name, idxs) -> sprintf "%s%A" name idxs 233 | myStr, myPri 234 | 235 | | Unary (op, a) -> 236 | let myPri = 10 237 | let aStr, aPri = a.PrettyAndPriority 238 | let aStr = 239 | if myPri > aPri then sprintf "(%s)" aStr 240 | else aStr 241 | let myStr = 242 | match op with 243 | | Negate -> sprintf "(-%s)" aStr 244 | | Abs -> sprintf "abs %s" aStr 245 | | Sgn -> sprintf "sgn %s" aStr 246 | | Log -> sprintf "log %s" aStr 247 | | Log10 -> sprintf "log10 %s" aStr 248 | | Exp -> sprintf "exp %s" aStr 249 | | Tanh -> sprintf "tanh %s" aStr 250 | | Sqrt -> sprintf "sqrt %s" aStr 251 | | Sum (sym, lows, highs) -> 252 | let lowsStr = 253 | match lows with 254 | | [ConstIdxExpr low] -> sprintf "%A" low 255 | | [low] -> sprintf "(%A)" low 256 | | _ -> sprintf "(max %A)" lows 257 | let highsStr = 258 | match highs with 259 | | [ConstIdxExpr high] -> sprintf "%A" high 260 | | [high] -> sprintf "(%A)" high 261 | | _ -> sprintf "(min %A)" highs 262 | sprintf "sum{%s}_%s^%s (%s)" sym lowsStr highsStr aStr 263 | myStr, myPri 264 | 265 | | Binary(op, a, b) -> 266 | let aStr, aPri = a.PrettyAndPriority 267 | let 
bStr, bPri = b.PrettyAndPriority 268 | match op with 269 | | Add | Substract | Multiply | Divide | Modulo | Power -> 270 | let mySym, myPri = 271 | match op with 272 | | Add -> "+", 1 273 | | Substract -> "-", 1 274 | | Multiply -> "*", 2 275 | | Divide -> "/", 2 276 | | Modulo -> "%", 2 277 | | Power -> "**", 5 278 | | _ -> failwith "unexpected" 279 | let aStr = 280 | if myPri > aPri then sprintf "(%s)" aStr 281 | else aStr 282 | let bStr = 283 | if myPri > bPri then sprintf "(%s)" bStr 284 | else bStr 285 | let myStr = sprintf "%s %s %s" aStr mySym bStr 286 | myStr, myPri 287 | | IdxIf (idx, cmp) -> 288 | let cmpStr = 289 | match cmp with 290 | | GreaterOrEqualToZero -> ">= 0" 291 | | EqualToZero -> "= 0" 292 | | Integer -> "is int" 293 | sprintf "if {%A %s} then (%s) else (%s)" idx cmpStr aStr bStr, 0 294 | 295 | member this.Pretty = this.PrettyAndPriority |> fst 296 | 297 | /// sign keeping type 298 | let sgn (a: ElemExpr) = 299 | ElemExpr.Sgn a 300 | 301 | /// square root 302 | let sqrtt (a: ElemExpr) = 303 | ElemExpr.Sqrt a 304 | 305 | /// index symbol for given dimension of the result 306 | let idxValue idx = 307 | Leaf (IdxValue idx) 308 | 309 | /// specified element of argument 310 | let arg name idx = 311 | Leaf (Argument (name, IdxExprs idx)) 312 | 313 | /// index of given name 314 | let pos name = IdxExpr.factor name Rat.One 315 | 316 | /// constant index value 317 | let idxConst v = IdxExpr.factor "1" v 318 | 319 | /// index value one 320 | let idxOne = idxConst Rat.One 321 | 322 | /// Summation over an index. 323 | let sum idx lows highs a = 324 | Unary (Sum (idx, lows, highs), a) 325 | 326 | /// Summation over an index using constant low and high values. 327 | let sumConstRng idx (low: int64) (high: int64) a = 328 | sum idx [IdxExpr.constant (Rat low)] [IdxExpr.constant (Rat high)] a 329 | 330 | /// Expression conditioned on index values. 
331 | let idxIf idx cmp thenExpr elseExpr = 332 | match cmp, idx with 333 | | EqualToZero, ConstIdxExpr v when v = Rat.Zero -> thenExpr 334 | | EqualToZero, ConstIdxExpr v -> elseExpr 335 | | GreaterOrEqualToZero, ConstIdxExpr v when v >= Rat.Zero -> thenExpr 336 | | GreaterOrEqualToZero, ConstIdxExpr v -> elseExpr 337 | | _ -> Binary (IdxIf (idx, cmp), thenExpr, elseExpr) 338 | 339 | /// Substitutes the specified size symbols with their replacements. 340 | let rec substIdx repl expr = 341 | let sub = substIdx repl 342 | match expr with 343 | | Leaf (IdxValue idx) -> Leaf (IdxValue (IdxExpr.subst repl idx)) 344 | | Leaf (Argument (name, idxs)) -> Leaf (Argument (name, IdxExprs.subst repl idxs)) 345 | | Leaf (op) -> Leaf (op) 346 | | Unary (Sum (idx, lows, highs), a) -> 347 | Unary (Sum (idx, lows |> List.map (IdxExpr.subst repl), highs |> List.map (IdxExpr.subst repl)), 348 | substIdx (repl |> Map.remove idx) a) 349 | | Unary (op, a) -> Unary (op, sub a) 350 | | Binary (IdxIf (idx, cmp), a, b) -> 351 | Binary (IdxIf (idx |> IdxExpr.subst repl, cmp), sub a, sub b) 352 | | Binary (op, a, b) -> Binary (op, sub a, sub b) 353 | 354 | /// Evaluates the given expression. 
355 | let rec evalExpr (argEnv: Map>) idxEnv expr = 356 | let subEval = evalExpr argEnv idxEnv 357 | match expr with 358 | | Leaf op -> 359 | match op with 360 | | Const v -> v 361 | | IdxValue idx -> idx |> IdxExpr.eval idxEnv |> float 362 | | Argument (name, idxs) -> 363 | let idxs = idxs |> IdxExprs.eval idxEnv |> List.map int64 364 | match argEnv |> Map.tryFind name with 365 | | Some arg -> arg.[idxs] 366 | | None -> failwithf "argument %s not present in argument environment" name 367 | 368 | | Unary (op, a) -> 369 | match op with 370 | | Negate -> -(subEval a) 371 | | Abs -> abs (subEval a) 372 | | Sgn -> Operators.sgn (subEval a) 373 | | Log -> log (subEval a) 374 | | Log10 -> log10 (subEval a) 375 | | Exp -> exp (subEval a) 376 | | Tanh -> tanh (subEval a) 377 | | Sqrt -> sqrt (subEval a) 378 | | Sum (sym, lows, highs) -> 379 | let low = lows |> List.map (IdxExpr.eval idxEnv) |> List.max |> ceil 380 | let high = highs |> List.map (IdxExpr.eval idxEnv) |> List.min |> floor 381 | seq {low .. high} 382 | |> Seq.map (fun v -> evalExpr argEnv (idxEnv |> Map.add sym v) a) 383 | |> Seq.sum 384 | 385 | | Binary (op, a, b) -> 386 | match op with 387 | | Add -> (subEval a) + (subEval b) 388 | | Substract -> (subEval a) - (subEval b) 389 | | Multiply -> (subEval a) * (subEval b) 390 | | Divide -> (subEval a) / (subEval b) 391 | | Modulo -> (subEval a) % (subEval b) 392 | | Power -> (subEval a) ** (subEval b) 393 | | IdxIf (idx, cmp) -> 394 | let idxVal = idx |> IdxExpr.eval idxEnv 395 | match cmp with 396 | | EqualToZero when idxVal = Rat.Zero -> subEval a 397 | | EqualToZero -> subEval b 398 | | GreaterOrEqualToZero when idxVal >= Rat.Zero -> subEval a 399 | | GreaterOrEqualToZero -> subEval b 400 | | Integer when Rat.isInteger idxVal -> subEval a 401 | | Integer -> subEval b 402 | 403 | /// Evaluates the given function. 
404 | let evalFunc argEnv (func: ElemFunc) = 405 | let fv = HostTensor.zeros func.Shape 406 | for pos in Tensor.Backend.TensorLayout.allIdxOfShape func.Shape do 407 | let idxEnv = 408 | List.zip pos func.DimNames 409 | |> List.fold (fun env (p, name) -> env |> Map.add name (Rat p)) Map.empty 410 | fv.[pos] <- evalExpr argEnv idxEnv func.Expr 411 | fv 412 | 413 | /// Calculates the derivative expression given the incoming derivative dExpr. 414 | let rec derivExpr syms constrs expr dExpr = 415 | // constrs >= 0 416 | let d = dExpr 417 | let rds = derivExpr syms constrs 418 | match expr with 419 | | Leaf op -> 420 | match op with 421 | | Const v -> [] 422 | | IdxValue idx -> [] 423 | | Argument (name, idxs) -> [(name, idxs), (syms, constrs, d)] 424 | | Unary (op, a) -> 425 | match op with 426 | | Negate -> -d |> rds a 427 | | Abs -> d * sgn a |> rds a 428 | | Sgn -> [] 429 | | Log -> d * (a ** -1.0) |> rds a 430 | | Log10 -> d |> rds (log a / log 10.0) 431 | | Exp -> d * exp a |> rds a 432 | | Tanh -> d * (1.0 - (tanh a)**2.0) |> rds a 433 | | Sqrt -> d * (1.0 / (2.0 * sqrtt a)) |> rds a 434 | | Sum (sym, lows, highs) -> 435 | // low limits: lows <= sym => sym - lows >= 0 436 | let lowConstrs = lows |> List.map (fun low -> IdxExpr.named sym - low) |> Set.ofList 437 | // high limits: sym <= highs => -sym + highs >= 0 438 | let highConstrs = highs |> List.map (fun high -> -IdxExpr.named sym + high) |> Set.ofList 439 | derivExpr (syms |> Set.add sym) (Set.unionMany [constrs; lowConstrs; highConstrs]) a d 440 | | Binary (op, a, b) -> 441 | let (.+) da db = List.append (rds a da) (rds b db) 442 | match op with 443 | | Add -> d .+ d 444 | | Substract -> d .+ (-d) 445 | | Multiply -> (d * b) .+ (a * d) 446 | | Divide -> d |> rds (a * b ** -1.0) 447 | | Modulo -> failwith "buggy" 448 | | Power -> (d * b * a**(b - 1.0)) .+ (d * a**b * log a) 449 | | IdxIf (idx, cmp) -> 450 | (idxIf idx cmp d (scalar 0.0)) .+ (idxIf idx cmp (scalar 0.0) d) 451 | 452 | 453 | /// Calculates the 
derivative functions of y w.r.t. all of its arguments. 454 | let derivFunc (y: ElemFunc) = 455 | // get dimension names and add constant bias dimension 456 | let ySyms = y.DimNames @ ["1"] |> Set.ofList 457 | 458 | // incoming derivative dy w.r.t. function y 459 | let dyArgName = sprintf "d%s" y.Name 460 | let dy = arg dyArgName (y.DimNames |> List.map (fun d -> IdxExpr.factor d Rat.One)) 461 | let argShapes = y.ArgShapes |> Map.add dyArgName y.Shape 462 | 463 | // Build constraints from ranges of y. 464 | // low limit: y_i >= 0 465 | let rngLowConstrs = y.DimNames |> List.map (fun name -> IdxExpr.named name) |> Set.ofList 466 | // high limit: y_i <= size_i-1 => -y_i + size_i - 1 >= 0 467 | let rngHighConstrs = 468 | y.DimSize 469 | |> Map.toSeq 470 | |> Seq.map (fun (name, size) -> -IdxExpr.named name + IdxExpr.constant (Rat (size-1L))) 471 | |> Set.ofSeq 472 | let rngConstrs = Set.union rngLowConstrs rngHighConstrs 473 | 474 | // Calculate derivative expressions w.r.t. all indexed arguments. 475 | let dxs = derivExpr ySyms rngConstrs y.Expr dy 476 | 477 | // Perform index substitution and nullspace summation on the derivatives of all arguments. 478 | let processDeriv xName (IdxExprs xIdxs) (ySyms: Set) (yConstrs: Set) dx = //(yIdxs1: Map) dx = 479 | // get names of used indices 480 | let yIdxNames1 = Set.toList ySyms 481 | 482 | // name the argument and its indices 483 | let dxName = sprintf "d%s" xName 484 | let dxIdxNames = xIdxs |> List.mapi (fun i _ -> sprintf "%s_%d" dxName i) 485 | let dxIdxSizes = dxIdxNames |> List.mapi (fun i name -> name, y.ArgShapes.[xName].[i]) |> Map.ofList 486 | 487 | // Add "1" dimension to indices for constant terms. 488 | let dxIdxs1, dxIdxNames1 = xIdxs @ [IdxExpr.one], dxIdxNames @ ["1"] 489 | 490 | // Construct matrix mapping from function indices to argument indices yToX[xDim, yDim]. 
491 | let yToX = IdxExprs.toMatrix yIdxNames1 (IdxExprs dxIdxs1) |> Tensor.convert 492 | 493 | // Compute the generalized inverse of it: 494 | // y = XToY .* x + Nullspace .* z 495 | let xToY, xSolvability, yNull = LinAlg.integerInverse yToX 496 | 497 | // Build constraint matrix C from constraints specified as index expressions. 498 | // Constraints are specified as: C .* y >= 0 499 | // This translates to: 500 | // C .* XToY .* x + C .* Nullspace .* z >= 0 501 | // C .* Nullspace .* z >= - C .* XToY .* x 502 | let yConstrs = yConstrs |> Set.toList |> IdxExprs 503 | let C = IdxExprs.toMatrix yIdxNames1 yConstrs 504 | 505 | // Compute the summation range constraints. 506 | let CNull = C .* Tensor.convert yNull 507 | let sumConstr = FourierMotzkin.solve CNull 508 | 509 | // Perform summation over nullspace. 510 | let rec buildSum summand sols sumSyms = 511 | match sols with 512 | | FourierMotzkin.Feasibility fs :: rSols -> 513 | let summand = buildSum summand rSols sumSyms 514 | // System is feasible if fs .* b <= 0, where b = - C .* XToY .* x 515 | let fsMat = -fs .* C .* xToY |> HostTensor.toList2D 516 | let fsIdxs = 517 | fsMat 518 | |> List.map (fun bFacs -> IdxExpr.ofSeq dxIdxNames1 bFacs) 519 | |> List.filter (fun ie -> 520 | // Filter inequalities that are always true. 521 | // Each inequality of the form cv + iv * "i" <= 0 is considered. 
522 | let cv = IdxExpr.constVal ie 523 | match ie - cv * IdxExpr.one with 524 | // cv - "i" <= 0 => cv <= "i" => always true for cv <= 0 because "i" >= 0 525 | | SingleIdxExpr (i, iv) when iv = Rat.MinusOne && cv <= Rat.Zero -> false 526 | // cv + "i" <= 0 => "i" <= -cv => always true for -cv >= size_i-1 because "i" <= size_i-1 527 | | SingleIdxExpr (i, iv) when iv = Rat.One && -cv >= Rat (dxIdxSizes.[i]-1L) -> false 528 | | _ -> true) 529 | (summand, fsIdxs) ||> List.fold (fun s fsIdx -> idxIf -fsIdx GreaterOrEqualToZero s (scalar 0.0)) 530 | | FourierMotzkin.Range rng :: rSols -> 531 | let sumSym = sprintf "%s_z%d" dxName rng.Idx 532 | let summand = buildSum summand rSols (sumSym::sumSyms) 533 | // The limits are given by 534 | // Low limits: x[Idx] >= BLow .* b - SLow .* z.[Idx+1L..] 535 | // High limits: x[Idx] <= BHigh .* b - SHigh .* z.[Idx+1L..] 536 | // where b = - C .* XToY .* x 537 | let bMat = -C .* xToY 538 | let bLowMat = rng.BLow .* bMat |> HostTensor.toList2D 539 | let bHighMat = rng.BHigh .* bMat |> HostTensor.toList2D 540 | let sLowMat = rng.SLow |> HostTensor.toList2D 541 | let sHighMat = rng.SHigh |> HostTensor.toList2D 542 | let idxExpr bMat sMat = 543 | List.zip bMat sMat 544 | |> List.map (fun (bFacs, sFacs) -> IdxExpr.ofSeq dxIdxNames1 bFacs + IdxExpr.ofSeq sumSyms sFacs) 545 | let lows, highs = idxExpr bLowMat sLowMat, idxExpr bHighMat sHighMat 546 | sum sumSym lows highs summand 547 | | [] -> 548 | let xToY = xToY |> HostTensor.toList2D 549 | let zToY = yNull |> Tensor.convert |> HostTensor.toList2D 550 | let subs = 551 | List.zip3 yIdxNames1 xToY zToY 552 | |> List.map (fun (name, argFacs, nsFacs) -> 553 | name, IdxExpr.ofSeq dxIdxNames1 argFacs + IdxExpr.ofSeq sumSyms nsFacs) 554 | |> Map.ofList 555 | |> Map.add "1" IdxExpr.one 556 | substIdx subs summand 557 | 558 | let dxSummed = buildSum dx sumConstr [] 559 | 560 | // Check that all y are integer. 561 | // Check is only required for y that contain non-integer coefficients. 
562 | let intIdxs = 563 | xToY 564 | |> HostTensor.toList2D 565 | |> List.filter (List.exists (Rat.isInteger >> not)) 566 | |> List.map (IdxExpr.ofSeq dxIdxNames1) 567 | let dxIntChecked = 568 | (dxSummed, intIdxs) ||> List.fold (fun s intIdx -> idxIf intIdx Integer s (scalar 0.0)) 569 | 570 | // Check solvability. 571 | let solIdxs = 572 | xSolvability 573 | |> Tensor.convert 574 | |> HostTensor.toList2D 575 | |> List.map (fun sFacs -> IdxExpr.ofSeq dxIdxNames1 sFacs) 576 | let dxSolChecked = 577 | (dxIntChecked, solIdxs) ||> List.fold (fun s solIdx -> idxIf solIdx EqualToZero s (scalar 0.0)) 578 | 579 | // Build derivative function. 580 | func dxName dxIdxNames dxIdxSizes argShapes dxSolChecked 581 | 582 | // Perform index substitution on the derivatives of all arguments and sum by argument. 583 | let dxFns = 584 | dxs 585 | |> List.map (fun ((xName, xIdxs), (syms, constrs, dx)) -> xName, processDeriv xName xIdxs syms constrs dx) 586 | |> List.groupBy fst 587 | |> List.map (fun (xName, dxs) -> 588 | xName, dxs |> List.map snd |> List.reduce (fun a {Expr=bExpr} -> {a with Expr=a.Expr + bExpr})) 589 | |> Map.ofList 590 | 591 | dxFns 592 | 593 | -------------------------------------------------------------------------------- /ElementsTests.fs: -------------------------------------------------------------------------------- 1 | namespace Elements 2 | 3 | open Xunit 4 | open Xunit.Abstractions 5 | open FsUnit.Xunit 6 | 7 | open Tensor 8 | open Tensor.Algorithm 9 | 10 | 11 | module DerivCheck = 12 | 13 | /// evaluates the Jacobian of f at x numerically with specified finite difference step 14 | let inline numDerivEpsilon (epsilon: 'T) (f: Tensor<'T> -> Tensor<'T>) (x: Tensor<'T>) = 15 | let y = f x 16 | let xElems, yElems = Tensor.nElems x, Tensor.nElems y 17 | let xShp = Tensor.shape x 18 | 19 | let jac = Tensor.zeros x.Dev [yElems; xElems] 20 | let xd = x |> Tensor.reshape [xElems] |> Tensor.copy 21 | for xi in 0L .. 
xElems-1L do 22 | let xiVal = xd.[[xi]] 23 | // f (x+epsilon) 24 | xd.[[xi]] <- xiVal + epsilon 25 | let ydf = xd |> Tensor.reshape xShp |> f |> Tensor.reshape [yElems] 26 | // f (x-epsilon) 27 | xd.[[xi]] <- xiVal - epsilon 28 | let ydb = xd |> Tensor.reshape xShp |> f |> Tensor.reshape [yElems] 29 | // [f (x+epsilon) - f (x-epsilon)] / (2 * epsilon) 30 | jac.[*, xi] <- (ydf - ydb) / (Tensor.scalar ydf.Dev (epsilon + epsilon)) 31 | xd.[[xi]] <- xiVal 32 | jac 33 | 34 | /// evaluates the Jacobian of f at x numerically 35 | let numDeriv f x = 36 | numDerivEpsilon 1e-5 f x 37 | 38 | let numDerivOfFunc argEnv (fn: Elements.ElemFunc) xName = 39 | let f xv = Elements.evalFunc (argEnv |> Map.add xName xv) fn 40 | numDeriv f argEnv.[xName] 41 | 42 | /// Calculates the Jacobian using the derivative of a function. 43 | let jacobianOfDerivFunc argEnv dInArg (dFn: Elements.ElemFunc) = 44 | let outElems = dFn.Shape |> List.fold (*) 1L 45 | let inElems = dFn.ArgShapes.[dInArg] |> List.fold (*) 1L 46 | let jac = HostTensor.zeros [inElems; outElems] 47 | for i in 0L .. inElems-1L do 48 | let dIn = HostTensor.zeros [inElems] 49 | dIn.[[i]] <- 1.0 50 | let dIn = dIn |> Tensor.reshape dFn.ArgShapes.[dInArg] 51 | let dArgEnv = argEnv |> Map.add dInArg dIn 52 | let dOut = Elements.evalFunc dArgEnv dFn 53 | jac.[i, *] <- Tensor.flatten dOut 54 | jac 55 | 56 | 57 | type ElementsTests (output: ITestOutputHelper) = 58 | 59 | let printfn format = Printf.kprintf (fun msg -> output.WriteLine(msg)) format 60 | 61 | let checkFuncDerivs orders argEnv fn = 62 | let rec doCheck order fn = 63 | let dFns = Elements.derivFunc fn 64 | let dInArg = "d" + fn.Name 65 | printfn "Checking %d. derivative of: %A" order fn 66 | for KeyValue(v, dFn) in dFns do 67 | printfn "%d. derivative of %s w.r.t. 
%s: %A" order fn.Name v dFn 68 | let nJac = DerivCheck.numDerivOfFunc argEnv fn v 69 | let aJac = DerivCheck.jacobianOfDerivFunc argEnv dInArg dFn 70 | if not (Tensor.almostEqual (nJac, aJac, 1e-3, 1e-3)) then 71 | printfn "Analytic Jacobian:\n%A" aJac 72 | printfn "Numeric Jacobian:\n%A" nJac 73 | printfn "Jacobian mismatch!!" 74 | failwith "Jacobian mismatch in function derivative check" 75 | else 76 | //printfn "Analytic Jacobian:\n%A" aJac 77 | //printfn "Numeric Jacobian:\n%A" nJac 78 | //printfn "Analytic and numeric Jacobians match." 79 | () 80 | 81 | if order < orders then 82 | doCheck (order+1) dFn 83 | doCheck 1 fn 84 | 85 | let randomDerivCheck orders iters (fn: Elements.ElemFunc) = 86 | let rnd = System.Random 123 87 | let rndTensor shp = HostTensor.randomUniform rnd (-1., 1.) shp 88 | for i in 1 .. iters do 89 | let argEnv = 90 | seq { 91 | if orders=2 then yield "d" + fn.Name, rndTensor fn.Shape 92 | for KeyValue(name, shp) in fn.ArgShapes do 93 | yield name, rndTensor shp 94 | if orders=2 then yield "d" + name, rndTensor shp 95 | if orders=2 then yield "dd" + name, rndTensor shp 96 | } |> Map.ofSeq 97 | checkFuncDerivs orders argEnv fn 98 | 99 | [] 100 | let ``EvalTest1`` () = 101 | let i, iSize = Elements.pos "i", 3L 102 | let j, jSize = Elements.pos "j", 4L 103 | let k, kSize = Elements.pos "k", 5L 104 | 105 | let xv = HostTensor.zeros [iSize; jSize] + 1.0 106 | let yv = HostTensor.zeros [jSize; jSize] + 2.0 107 | let zv = HostTensor.zeros [kSize] + 3.0 108 | 109 | let dimNames = [i.Name; j.Name; k.Name] 110 | let dimSizes = Map [i.Name, iSize; j.Name, jSize; k.Name, kSize] 111 | let argShapes = Map ["x", xv.Shape; "y", yv.Shape; "z", zv.Shape] 112 | 113 | let expr = Elements.arg "x" [i; j] + 2.0 * (Elements.arg "y" [j; j] * (Elements.arg "z" [k])**3.0) 114 | let func = Elements.func "f" dimNames dimSizes argShapes expr 115 | 116 | printfn "Evaluating:" 117 | printfn "x=\n%A" xv 118 | printfn "y=\n%A" yv 119 | printfn "z=\n%A" zv 120 | let argEnv = 
Map ["x", xv; "y", yv; "z", zv] 121 | let fv = Elements.evalFunc argEnv func 122 | printfn "f=\n%A" fv 123 | 124 | 125 | [] 126 | let ``DerivTest1`` () = 127 | let i, iSize = Elements.pos "i", 3L 128 | let j, jSize = Elements.pos "j", 4L 129 | let k, kSize = Elements.pos "k", 5L 130 | 131 | let xv = HostTensor.zeros [iSize; jSize] + 1.0 132 | let yv = HostTensor.zeros [jSize; jSize] + 2.0 133 | let zv = HostTensor.zeros [kSize] + 3.0 134 | 135 | let dimNames = [i.Name; j.Name; k.Name] 136 | let dimSizes = Map [i.Name, iSize; j.Name, jSize; k.Name, kSize] 137 | let argShapes = Map ["x", xv.Shape; "y", yv.Shape; "z", zv.Shape] 138 | 139 | let expr = Elements.arg "x" [i; j] + 2.0 * (Elements.arg "y" [j; j] * (Elements.arg "z" [k])**3.0) 140 | let func = Elements.func "f" dimNames dimSizes argShapes expr 141 | 142 | printfn "%A" func 143 | printfn "Ranges: %A" dimSizes 144 | let dFns = Elements.derivFunc func 145 | printfn "dFns:" 146 | for KeyValue(_, dFn) in dFns do 147 | printfn "%A" dFn 148 | 149 | [] 150 | let ``DerivTest2`` () = 151 | let i, iSize = Elements.pos "i", 3L 152 | let j, jSize = Elements.pos "j", 4L 153 | let k, kSize = Elements.pos "k", 5L 154 | 155 | let xv = HostTensor.zeros [iSize; jSize] + 1.0 156 | let yv = HostTensor.zeros [jSize; jSize] + 2.0 157 | let zv = HostTensor.zeros [kSize] + 3.0 158 | 159 | let dimNames = [i.Name; j.Name; k.Name] 160 | let dimSizes = Map [i.Name, iSize; j.Name, jSize; k.Name, kSize] 161 | let argShapes = Map ["x", xv.Shape] 162 | 163 | let expr = Elements.arg "x" [Rat 2*i; j] 164 | let func = Elements.func "f" dimNames dimSizes argShapes expr 165 | 166 | printfn "%A" func 167 | printfn "Ranges: %A" dimSizes 168 | let dFns = Elements.derivFunc func 169 | printfn "dFns:" 170 | for KeyValue(_, dFn) in dFns do 171 | printfn "%A" dFn 172 | 173 | [] 174 | let ``DerivTest3`` () = 175 | let i, iSize = Elements.pos "i", 3L 176 | let s, sSize = Elements.pos "s", 7L 177 | 178 | let dimNames = [i.Name] 179 | let dimSizes = Map 
[i.Name, iSize] 180 | let argShapes = Map ["x", [iSize; 10L]] 181 | 182 | let summand = Elements.arg "x" [i; s] 183 | let expr = Elements.sumConstRng "s" 0L (sSize-1L) summand 184 | let func = Elements.func "f" dimNames dimSizes argShapes expr 185 | 186 | printfn "%A" func 187 | printfn "Ranges: %A" dimSizes 188 | let dFns = Elements.derivFunc func 189 | printfn "dFns:" 190 | for KeyValue(_, dFn) in dFns do 191 | printfn "%A" dFn 192 | 193 | [] 194 | let ``DerivTest4`` () = 195 | let i, iSize = Elements.pos "i", 7L 196 | 197 | let dimNames = [i.Name] 198 | let dimSizes = Map [i.Name, iSize] 199 | let argShapes = Map ["x", [10L]] 200 | 201 | let expr = Elements.arg "x" [i + Elements.idxConst (Rat 2)] 202 | let func = Elements.func "f" dimNames dimSizes argShapes expr 203 | 204 | printfn "%A" func 205 | printfn "Ranges: %A" dimSizes 206 | let dFns = Elements.derivFunc func 207 | printfn "dFns:" 208 | for KeyValue(_, dFn) in dFns do 209 | printfn "%A" dFn 210 | 211 | [] 212 | let ``DerivCheck1`` () = 213 | let i, iSize = Elements.pos "i", 3L 214 | let j, jSize = Elements.pos "j", 4L 215 | let k, kSize = Elements.pos "k", 5L 216 | 217 | let rnd = System.Random 123 218 | let xv = HostTensor.randomUniform rnd (0., 1.) [iSize; jSize] 219 | let yv = HostTensor.randomUniform rnd (0., 1.) [jSize; jSize] 220 | let zv = HostTensor.randomUniform rnd (0., 1.) 
[kSize] 221 | 222 | let dimNames = [i.Name; j.Name; k.Name] 223 | let dimSizes = Map [i.Name, iSize; j.Name, jSize; k.Name, kSize] 224 | let argShapes = Map ["x", xv.Shape; "y", yv.Shape; "z", zv.Shape] 225 | 226 | let expr = Elements.arg "x" [i; j] ** 2.0 + 2.0 * (Elements.arg "y" [j; j] * (Elements.arg "z" [k])**3.0) 227 | let func = Elements.func "f" dimNames dimSizes argShapes expr 228 | 229 | printfn "x=\n%A" xv 230 | printfn "y=\n%A" yv 231 | printfn "z=\n%A" zv 232 | let argEnv = Map ["x", xv; "y", yv; "z", zv] 233 | let fv = Elements.evalFunc argEnv func 234 | checkFuncDerivs 1 argEnv func 235 | 236 | 237 | 238 | [] 239 | let ``DerivCheck2`` () = 240 | let i, iSize = Elements.pos "i", 3L 241 | let j, jSize = Elements.pos "j", 4L 242 | let k, kSize = Elements.pos "k", 5L 243 | 244 | let dimNames = [i.Name; j.Name; k.Name] 245 | let dimSizes = Map [i.Name, iSize; j.Name, jSize; k.Name, kSize] 246 | let argShapes = Map ["x", [iSize; jSize]; "y", [jSize; jSize]; "z", [kSize]] 247 | 248 | let expr = Elements.arg "x" [i; j] ** 2.0 + 2.0 * (Elements.arg "y" [j; j] * (Elements.arg "z" [k])**3.0) 249 | let func = Elements.func "f" dimNames dimSizes argShapes expr 250 | 251 | randomDerivCheck 2 3 func 252 | 253 | 254 | [] 255 | let ``DerivCheck3`` () = 256 | let r, rSize = Elements.pos "r", 2L 257 | let s, sSize = Elements.pos "s", 3L 258 | let n, nSize = Elements.pos "n", 4L 259 | 260 | let dimNames = [r.Name; s.Name; n.Name] 261 | let dimSizes = Map [r.Name, rSize; s.Name, sSize; n.Name, nSize] 262 | let argShapes = Map ["Sigma", [sSize; nSize; nSize]; "mu", [sSize; nSize]; "V", [rSize; nSize]] 263 | 264 | let Sigma = Elements.arg "Sigma" 265 | let mu = Elements.arg "mu" 266 | let V = Elements.arg "V" 267 | let expr = // added **2 to Sigma to make it positive 268 | sqrt (1. / (1. + 2. * Sigma[s;n;n]**2.)) * exp (- (mu[s;n] - V[r;n])**2. / (1. + 2. 
* Sigma[s;n;n])) 269 | let func = Elements.func "S" dimNames dimSizes argShapes expr 270 | 271 | randomDerivCheck 2 2 func 272 | 273 | [] 274 | let ``DerivCheck4`` () = 275 | let r, rSize = Elements.pos "r", 2L 276 | let s, sSize = Elements.pos "s", 3L 277 | let t, tSize = Elements.pos "t", 2L // =r 278 | let n, nSize = Elements.pos "n", 4L 279 | 280 | let dimNames = [r.Name; s.Name; t.Name; n.Name] 281 | let dimSizes = Map [r.Name, rSize; s.Name, sSize; t.Name, tSize; n.Name, nSize] 282 | let argShapes = Map ["Sigma", [sSize; nSize; nSize]; "mu", [sSize; nSize]; "V", [rSize; nSize]] 283 | 284 | let Sigma = Elements.arg "Sigma" 285 | let mu = Elements.arg "mu" 286 | let V = Elements.arg "V" 287 | let expr = // added **2 to Sigma to make it positive 288 | sqrt (1. / (1. + 4. * Sigma[s;n;n]**2.)) * exp (- 2. * (mu[s;n] - (V[r;n] + V[t;n])/2.)**2. / (1. + 4. * Sigma[s;n;n]) - 289 | (V[r;n] - V[t;n])**2. / 2.) 290 | let func = Elements.func "S" dimNames dimSizes argShapes expr 291 | 292 | randomDerivCheck 1 2 func 293 | 294 | [] 295 | let ``DerivCheck5`` () = 296 | let s, sSize = Elements.pos "s", 3L 297 | let t, tSize = Elements.pos "t", 2L 298 | let n, nSize = Elements.pos "n", 4L 299 | 300 | let dimNames = [t.Name; n.Name] 301 | let dimSizes = Map [t.Name, tSize; n.Name, nSize] 302 | let argShapes = Map ["x", [nSize; sSize; 3L*sSize]; "y", [tSize; nSize]] 303 | 304 | let x, y = Elements.arg "x", Elements.arg "y" 305 | //let summand = y[t;n] 306 | let summand = x[n ; s; Rat 2 * s]**2. + x[n; s; Rat 0 * s]**2. - y[t; Rat 0 * s] 307 | let expr = y[t;n]**2. 
* Elements.sumConstRng "s" 0L (sSize-1L) summand 308 | //let expr = Elements.simpleSum "s" 0L (sSize-1L) summand 309 | let func = Elements.func "S" dimNames dimSizes argShapes expr 310 | 311 | randomDerivCheck 1 5 func 312 | 313 | [] 314 | let ``DerivCheck6`` () = 315 | let s, sSize = Elements.pos "s", 3L 316 | let t, tSize = Elements.pos "t", 10L 317 | 318 | let dimNames = [s.Name; t.Name] 319 | let dimSizes = Map [s.Name, sSize; t.Name, tSize] 320 | let argShapes = Map ["x", [50L]] 321 | 322 | let x = Elements.arg "x" 323 | let expr = x[Rat 10*s+t]**2. 324 | let func = Elements.func "f" dimNames dimSizes argShapes expr 325 | 326 | randomDerivCheck 1 5 func 327 | 328 | [] 329 | let ``DerivDemo`` () = 330 | let i, iSize = Elements.pos "i", 3L 331 | let j, jSize = Elements.pos "j", 4L 332 | let k, kSize = Elements.pos "k", 5L // summation index 333 | 334 | let dimNames = [i.Name; j.Name] 335 | let dimSizes = Map [i.Name, iSize; j.Name, jSize] 336 | 337 | let argShapes = Map ["a",[iSize; kSize]; "b",[jSize; kSize]; "c",[iSize; iSize]; "d",[iSize + kSize]] 338 | 339 | let a, b, c, d = Elements.arg "a", Elements.arg "b", Elements.arg "c", Elements.arg "d" 340 | let summand = (a[i;k] + b[j;k])**2. * c[i;i] + (d[i+k])**3. 341 | let expr = exp (- Elements.sumConstRng "k" 0L (kSize-1L) summand) 342 | let func = Elements.func "f" dimNames dimSizes argShapes expr 343 | 344 | randomDerivCheck 1 5 func 345 | 346 | -------------------------------------------------------------------------------- /FourierMotzkin.fs: -------------------------------------------------------------------------------- 1 | /// Fourier-Motzkin elimination method for solving systems of inequalities. 2 | /// This algorithm is only practical for very small systems. 3 | module FourierMotzkin 4 | 5 | open Tensor 6 | open Tensor.Algorithm 7 | 8 | 9 | /// A range for a particular element of x in a system of inequalities. 
10 | /// The low limits and high limits for x[Idx] are given as follows: 11 | /// Low limits: x[Idx] >= BLow .* b + SLow .* x.[Idx+1L..] 12 | /// High limits: x[Idx] <= BHigh .* b + SHigh .* x.[Idx+1L..] 13 | /// If a low limit is larger than a high limit, then no solution exists. 14 | type Range = { 15 | /// Index of x element this solution is for. 16 | Idx: int64 17 | /// Matrix to multiply b with to obtain low limits. 18 | BLow: Tensor 19 | /// Matrix to multiply b with to obtain high limits. 20 | BHigh: Tensor 21 | /// Matrix to multiply x.[Idx+1L..] with to obtain low limits. 22 | SLow: Tensor 23 | /// Matrix to multiply x.[Idx+1L..] with to obtain high limits. 24 | SHigh: Tensor 25 | } 26 | 27 | /// Solution element. 28 | type Solution = 29 | /// Specifies range for a particular element of x. 30 | | Range of Range 31 | /// Specifies that system only has solution if B .* b <= 0. 32 | | Feasibility of B:Tensor 33 | 34 | /// A possible solution to a system of inequalities. 35 | type Solutions = Solution list 36 | 37 | /// Solves a system of inequalities of the form A .* x >= b for arbitrary b. 38 | let solve (A: Tensor) : Solutions = 39 | let m, n = 40 | match A.Shape with 41 | | [m; n] -> m, n 42 | | _ -> invalidArg "A" "A must be a matrix" 43 | let needFeasibilityCheck = 44 | A ==== Rat.Zero |> Tensor.allAxis 1 |> Tensor.any 45 | 46 | /// Elimination step. 47 | let rec eliminate (rA: Tensor) (rB: Tensor) rAs rBs = 48 | let k = List.length rAs |> int64 49 | //printfn "Elimination step %d:" k 50 | //printfn "rA=\n%A" rA 51 | //printfn "rb=\n%A" rb 52 | //printfn "arbitrary: %A" arb 53 | if k < n then 54 | let rA, rB = Tensor.copy rA, Tensor.copy rB 55 | let zRows = rA.[*, k] ==== Rat.Zero 56 | let nzRows = ~~~~zRows 57 | 58 | // Divide rows so that x_k = +1 or x_k = -1 or x_k = 0 in each inequality. 
59 | let facs = abs (rA.M(nzRows, NoMask).[*, k..k]) 60 | rA.M(nzRows, NoMask) <- rA.M(nzRows, NoMask) / facs 61 | rB.M(nzRows, NoMask) <- rB.M(nzRows, NoMask) / facs 62 | 63 | //printfn "after division:" 64 | //printfn "rA=\n%A" rA 65 | //printfn "rb=\n%A" rb 66 | 67 | // Check condition of x_k. 68 | if Tensor.all (rA.[*, k] ==== Rat.Zero) then 69 | // all the coefficients of x_k are zero, thus it is arbitrary 70 | //printfn "all x_k=0" 71 | eliminate rA rB (rA::rAs) (rB::rBs) 72 | elif Tensor.all (rA.[*, k] ==== Rat.One) || Tensor.all (rA.[*, k] ==== Rat.MinusOne) then 73 | // the coefficients of x_k are all +1 or -1 74 | //printfn "all x_k=+1 or all x_k=-1" 75 | eliminate (rA.M(zRows, NoMask)) (rB.M(zRows, NoMask)) (rA::rAs) (rB::rBs) 76 | elif Tensor.all ((rA.[*,k] ==== Rat.Zero) |||| (rA.[*,k] ==== Rat.One)) || 77 | Tensor.all ((rA.[*,k] ==== Rat.Zero) |||| (rA.[*,k] ==== Rat.MinusOne)) then 78 | // the coefficients of x_k are a mix of 0 and +1 or a mix of 0 and -1 79 | //printfn "x_k is mix of 0 and +1 or mix of 0 and -1" 80 | eliminate (rA.M(zRows, NoMask)) (rB.M(zRows, NoMask)) (rA::rAs) (rB::rBs) 81 | else 82 | //printfn "x_k has +1 and -1" 83 | // there is at least one pair of inequalities with a +1 and a -1 coefficient for x_k 84 | let pRows = rA.[*, k] ==== Rat.One |> Tensor.trueIdx |> Tensor.flatten |> HostTensor.toList 85 | let nRows = rA.[*, k] ==== Rat.MinusOne |> Tensor.trueIdx |> Tensor.flatten |> HostTensor.toList 86 | // for each pair augment the reduced system by their sum 87 | let nextRA = 88 | List.allPairs pRows nRows 89 | |> List.map (fun (p, n) -> rA.[p..p, *] + rA.[n..n, *]) 90 | |> List.append [rA.M(zRows, NoMask)] 91 | |> Tensor.concat 0 92 | let nextRB = 93 | List.allPairs pRows nRows 94 | |> List.map (fun (p, n) -> rB.[p..p, *] + rB.[n..n, *]) 95 | |> List.append [rB.M(zRows, NoMask)] 96 | |> Tensor.concat 0 97 | eliminate nextRA nextRB (rA::rAs) (rB::rBs) 98 | else 99 | let feasibility = 100 | if needFeasibilityCheck && rB.Shape.[0] 
> 0L then [Feasibility rB] 101 | else [] 102 | backSubst rAs rBs feasibility 103 | 104 | /// Backsubstitution step. 105 | and backSubst rAs rBs sols = 106 | match rAs, rBs with 107 | | rA::rAs, rB::rBs -> 108 | let k = List.length rAs |> int64 109 | 110 | // split B for lower and upper limit of x_j 111 | let Blow = rB.M(rA.[*,k] ==== Rat.One, NoMask) 112 | let Bhigh = -rB.M(rA.[*,k] ==== Rat.MinusOne, NoMask) 113 | 114 | // substitute the values into the system: x = [0; ...; 0; v_j; ...; v_n] 115 | // solution: B .* y - A .* x.[j..n] 116 | let S = -rA.[*, k+1L..] 117 | //printfn "S for %d=\n%A" k S 118 | 119 | // split C for lower and upper limit of x_j 120 | let Slow = S.M(rA.[*,k] ==== Rat.One, NoMask) 121 | let Shigh = -S.M(rA.[*,k] ==== Rat.MinusOne, NoMask) 122 | 123 | let sol = Range { 124 | Idx=k 125 | BLow=Blow; BHigh=Bhigh 126 | SLow=Slow; SHigh=Shigh 127 | } 128 | backSubst rAs rBs (sol::sols) 129 | | _ -> sols 130 | 131 | eliminate A (HostTensor.identity m) [] [] |> List.rev 132 | 133 | 134 | /// Checks system for feasibility. 135 | let feasible (fs: Tensor) (b: Tensor) = 136 | match b.Shape with 137 | | [l] when l = fs.Shape.[1] -> () 138 | | _ -> invalidArg "b" "b has wrong size" 139 | Rat.Zero >>== fs .* b |> Tensor.all 140 | 141 | 142 | /// Returns the range (xMin, xMax) for x.[sol.Idx] so that xMin <= x.[sol.Idx] <= xMax given x.[sol.Idx+1L..]. 143 | /// If xMin > xMax, then no solution exists. 
144 | let range (sol: Range) (b: Tensor) (xRight: Tensor) = 145 | match b.Shape with 146 | | [l] when l = sol.BLow.Shape.[1] -> () 147 | | _ -> invalidArg "b" "b has wrong size" 148 | match xRight.Shape with 149 | | [l] when l = sol.SLow.Shape.[1] -> () 150 | | _ -> invalidArg "xRight" "wrong number of substitution variables" 151 | 152 | let lows = sol.BLow .* b + sol.SLow .* xRight 153 | let highs = sol.BHigh .* b + sol.SHigh .* xRight 154 | Tensor.max lows, Tensor.min highs 155 | 156 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Program.fs: -------------------------------------------------------------------------------- 1 | open Elements 2 | 3 | let printAllDerivs fn = 4 | let dFns = Elements.derivFunc fn 5 | let dInArg = "d" + fn.Name 6 | printfn "Input: %A" fn 7 | for KeyValue(v, dFn) in dFns do 8 | printfn "Derivative of %s w.r.t. %s: %A" fn.Name v dFn 9 | 10 | 11 | let doDemo () = 12 | let i, iSize = Elements.pos "i", 3L 13 | let j, jSize = Elements.pos "j", 4L 14 | let k, kSize = Elements.pos "k", 5L // summation index 15 | 16 | let dimNames = [i.Name; j.Name] 17 | let dimSizes = Map [i.Name, iSize; j.Name, jSize] 18 | 19 | let argShapes = Map ["a",[iSize; kSize]; "b",[jSize; kSize]; "c",[iSize; iSize]; "d",[iSize + kSize]] 20 | 21 | let a, b, c, d = Elements.arg "a", Elements.arg "b", Elements.arg "c", Elements.arg "d" 22 | let summand = (a[i;k] + b[j;k])**2. * c[i;i] + (d[i+k])**3. 
23 | let expr = exp (- Elements.sumConstRng "k" 0L (kSize-1L) summand) 24 | let func = Elements.func "f" dimNames dimSizes argShapes expr 25 | 26 | printAllDerivs func 27 | 28 | 29 | [] 30 | let main argv = 31 | doDemo () 32 | 0 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Automatic Differentiation for Tensor Algebras 2 | ============================================= 3 | 4 | [![Build Status](https://travis-ci.org/surban/TensorAlgDiff.svg?branch=master)](https://travis-ci.org/surban/TensorAlgDiff) 5 | 6 | Read the corresponding [technical report for more details and examples](https://arxiv.org/abs/1711.01348). 7 | 8 | This code computes expressions for the derivatives of element-wise defined tensor-valued functions. 9 | It can handle arguments inside the functions that are indexed by arbitrary linear combinations of the function indices. 10 | Furthermore, the function may contain (nested) sums with arbitrary ranges (even linearly depending on other indices). 11 | An example is the matrix-valued function f(a,b,c,d), which is element-wise defined by the expression 12 | 13 | f[i; j] = exp (-sum{k}_0^4 (((a[i; k] + b[j; k]) ** 2 * c[i; i] + d[i + k] ** 3))) 14 | 15 | Deriving derivative expressions when the mapping between function indices and argument indices is not 1:1 requires special attention. 16 | For example, for the function `f_{ij} (x) = x_i^2`, the derivative of some loss `l=l(f(x))` w.r.t. `x` is `(dx)_i = dl / dx_i = \sum_j (df)_{ij} 2 x_i`; the sum is necessary because index `j` does not appear in the indices of `x`. 17 | Another example is `f_i (x) = x_{ii}^2`, where `x` is a matrix; here we have `(dx)_{ij} = \kronecker_{i=j} (df)_i 2 x_{ii}`; the Kronecker delta is necessary because the derivative is zero for off-diagonal elements. 
18 | Another indexing scheme is used by `f_{ij} (x) = exp x_{i+j}`; here the correct derivative is `(dx)_k = \sum_i (df)_{i,k-i} \exp x_k`, where the range of the sum must be chosen appropriately. 19 | 20 | Our algorithm can handle any case in which the indices of an argument are an *arbitrary linear combination* of the indices of the function, thus all of the above examples can be handled. 21 | Sums (and their ranges) and Kronecker deltas are automatically inserted into the derivatives as necessary. 22 | Additionally, the indices are transformed, if required (as in the last example). 23 | The algorithm outputs a symbolic expression that can be subsequently fed into a tensor algebra compiler. 24 | 25 | For the above expression the algorithm outputs: 26 | 27 | Derivative of f wrt. a: da[da_0; da_1] = sum{da_z0}_0^3 (((-(df[da_0; da_z0] * exp (-sum{k}_0^4 (((a[da_0; k] + b[da_z0; k]) ** 2 * c[da_0; da_0] + d[da_0 + k] ** 3))))) * c[da_0; da_0] * 2 * (a[da_0; da_1] + b[da_z0; da_1]) ** (2 - 1))) 28 | Derivative of f wrt. b: db[db_0; db_1] = sum{db_z0}_0^2 (((-(df[db_z0; db_0] * exp (-sum{k}_0^4 (((a[db_z0; k] + b[db_0; k]) ** 2 * c[db_z0; db_z0] + d[db_z0 + k] ** 3))))) * c[db_z0; db_z0] * 2 * (a[db_z0; db_1] + b[db_0; db_1]) ** (2 - 1))) 29 | Derivative of f wrt. c: dc[dc_0; dc_1] = if {dc_0 + -dc_1 = 0} then (sum{dc_z1}_0^4 (sum{dc_z0}_0^3 (((a[dc_1; dc_z1] + b[dc_z0; dc_z1]) ** 2 * (-(df[dc_1; dc_z0] * exp (-sum{k}_0^4 (((a[dc_1; k] + b[dc_z0; k]) ** 2 * c[dc_1; dc_1] + d[dc_1 + k] ** 3))))))))) else (0) 30 | Derivative of f wrt. 
d: dd[dd_0] = sum{dd_z1}_(max [0; -2 + dd_0])^(min [4; dd_0]) (sum{dd_z0}_0^3 (((-(df[dd_0 + -dd_z1; dd_z0] * exp (-sum{k}_0^4 (((a[dd_0 + -dd_z1; k] + b[dd_z0; k]) ** 2 * c[dd_0 + -dd_z1; dd_0 + -dd_z1] + d[dd_0 + -dd_z1 + k] ** 3))))) * 3 * d[dd_0] ** (3 - 1)))) 31 | 32 | Internally, the derivatives are stored as computational trees to avoid repeated computations and thus expression blowup that otherwise occurs in symbolic differentiation. 33 | This work can easily be employed in systems that generate C++ or CUDA code for expressions or be combined with a Tensor Algebra Compiler like . 34 | 35 | Running 36 | ------- 37 | 38 | 1. Install .NET Core 2.0 from (packages available for all operating systems). We tested our code on Ubuntu Linux. 39 | 40 | 2. To build and run the demo, execute `dotnet run` 41 | 42 | 3. To run the numeric verification tests, run `dotnet test` (takes approx. 2 minutes) 43 | 44 | Reference 45 | --------- 46 | 47 | When using this work or the provided code please refer to the following publication. 48 | 49 | Sebastian Urban, Patrick van der Smagt. Automatic Differentiation for Tensor Algebras. arXiv:1711.01348 [cs.SC], 2017. 50 | 51 | Note that we employ some algorithms implemented in our open-source Tensor library; their source is at . 52 | Documentation is available at . 
53 | 54 | License 55 | ------- 56 | 57 | Apache License 2.0 58 | -------------------------------------------------------------------------------- /TensorAlgDiff.fsproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Exe 4 | netcoreapp2.0 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /elemdiff.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/surban/TensorAlgDiff/e9995af704a6930385d71e0fe0b4ebf5d8e46e85/elemdiff.pdf --------------------------------------------------------------------------------