├── .npmrc ├── argamak.jpg ├── .formatter.exs ├── .gitignore ├── test ├── argamak_test.gleam ├── argamak │ ├── format_test.gleam │ ├── space_test.gleam │ └── tensor_test.gleam └── argamak_test_ffi.mjs ├── package.json ├── deno.json ├── NOTICE ├── gleam.toml ├── .github └── workflows │ └── ci.yml ├── manifest.toml ├── src ├── argamak │ ├── format.gleam │ ├── axis.gleam │ └── space.gleam ├── argamak_ffi.ex └── argamak_ffi.mjs ├── CHANGELOG.md ├── README.md └── LICENSE /.npmrc: -------------------------------------------------------------------------------- 1 | loglevel = "error" 2 | -------------------------------------------------------------------------------- /argamak.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tynanbe/argamak/HEAD/argamak.jpg -------------------------------------------------------------------------------- /.formatter.exs: -------------------------------------------------------------------------------- 1 | [ 2 | inputs: [".formatter.exs", "{src,test}/*.ex"] 3 | ] 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.beam 2 | *.ez 3 | *.log 4 | build 5 | node_modules 6 | erl_crash.dump 7 | -------------------------------------------------------------------------------- /test/argamak_test.gleam: -------------------------------------------------------------------------------- 1 | import gleeunit 2 | 3 | pub fn main() { 4 | gleeunit.main() 5 | } 6 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": { 3 | "@tensorflow/tfjs-node": "^4.17.0" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /deno.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "fmt": { 3 | "include": [ 4 | "CHANGELOG.md", 5 | "README.md", 6 | "deno.json", 7 | "src", 8 | "test" 9 | ] 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /test/argamak/format_test.gleam: -------------------------------------------------------------------------------- 1 | import gleeunit/should 2 | import argamak/format 3 | 4 | pub fn to_string_test() { 5 | format.float32() 6 | |> format.to_string 7 | |> should.equal("Format(Float32)") 8 | 9 | format.int32() 10 | |> format.to_string 11 | |> should.equal("Format(Int32)") 12 | } 13 | -------------------------------------------------------------------------------- /test/argamak_test_ffi.mjs: -------------------------------------------------------------------------------- 1 | import { tensor as tf_tensor } from "@tensorflow/tfjs-node"; 2 | import { inspect } from "../gleam_stdlib/gleam_stdlib.mjs"; 3 | 4 | export const tensor = (x) => tf_tensor(eval(inspect(x))); 5 | 6 | export const shape = (x) => x.shape; 7 | 8 | export const type = (x) => x.dtype; 9 | 10 | export const infinity = () => Infinity; 11 | 12 | export const neg_infinity = () => -Infinity; 13 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | LEGAL NOTICE INFORMATION 2 | ------------------------ 3 | 4 | All the files in this distribution are copyright to the terms below. 5 | 6 | Copyright 2021 Tynan Beatty 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 
10 | You may obtain a copy of the License at 11 | 12 | https://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | -------------------------------------------------------------------------------- /gleam.toml: -------------------------------------------------------------------------------- 1 | name = "argamak" 2 | version = "1.1.0" 3 | description = "A tensor library for the Gleam programming language" 4 | licences = ["Apache-2.0"] 5 | gleam = ">= 0.34.0" 6 | 7 | [repository] 8 | repo = "argamak" 9 | user = "tynanbe" 10 | type = "github" 11 | 12 | [[links]] 13 | href = "https://gleam.run/" 14 | title = "Website" 15 | 16 | [dependencies] 17 | gleam_stdlib = "~> 0.34 or ~> 1.0" 18 | nx = "~> 0.5 or ~> 1.0" 19 | 20 | [dev-dependencies] 21 | gleeunit = "~> 1.0" 22 | rad = "~> 1.1" 23 | 24 | [rad] 25 | targets = ["erlang", "javascript"] 26 | 27 | [[rad.formatters]] 28 | name = "elixir" 29 | check = ["sh", "-euc", """ 30 | mix format --check-formatted 31 | echo -n 'Checked all files in `src` and `test`' 32 | """] 33 | run = ["sh", "-euc", """ 34 | mix format 35 | echo -n 'Formatted all files in `src` and `test`' 36 | """] 37 | 38 | [[rad.formatters]] 39 | name = "javascript" 40 | check = ["deno", "fmt", "--check"] 41 | run = ["deno", "fmt"] 42 | 43 | [[rad.tasks]] 44 | path = ["init"] 45 | run = ["npm", "ci"] 46 | shortdoc = "Initialize argamak" 47 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - "v*.*.*" 8 | pull_request: 9 | 10 | jobs: 11 | test: 
12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - uses: erlef/setup-beam@v1 17 | with: 18 | otp-version: "26" 19 | rebar3-version: "3" 20 | elixir-version: "1.16" 21 | gleam-version: "1.0.0-rc2" 22 | 23 | - id: cache-gleam 24 | uses: actions/cache@v3 25 | with: 26 | path: build/packages 27 | key: ${{ runner.os }}-gleam-${{ hashFiles('manifest.toml') }} 28 | 29 | - uses: denoland/setup-deno@v1 30 | with: 31 | deno-version: "v1.x" 32 | 33 | - id: cache-deno 34 | uses: actions/cache@v3 35 | with: 36 | path: | 37 | ~/.deno 38 | ~/.cache/deno 39 | key: ${{ runner.os }}-deno-${{ hashFiles('deno.lock') }} 40 | 41 | - uses: actions/setup-node@v3 42 | with: 43 | node-version: "20" 44 | 45 | - id: cache-node 46 | uses: actions/cache@v3 47 | with: 48 | path: node_modules 49 | key: ${{ runner.os }}-node-${{ hashFiles('package-lock.json') }} 50 | 51 | - run: echo "$PWD/build/packages/rad/priv" >> $GITHUB_PATH 52 | 53 | - if: ${{ !steps.cache-gleam.outputs.cache-hit }} 54 | run: gleam deps download 55 | 56 | - if: ${{ !steps.cache-node.outputs.cache-hit }} 57 | run: rad init 58 | 59 | - run: rad test 60 | 61 | - run: rad format --check 62 | -------------------------------------------------------------------------------- /manifest.toml: -------------------------------------------------------------------------------- 1 | # This file was generated by Gleam 2 | # You typically do not need to edit this file 3 | 4 | packages = [ 5 | { name = "complex", version = "0.5.0", build_tools = ["mix"], requirements = [], otp_app = "complex", source = "hex", outer_checksum = "2683BD3C184466CFB94FAD74CBFDDFAA94B860E27AD4CA1BFFE3BFF169D91EF1" }, 6 | { name = "gleam_community_ansi", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_community_colour", "gleam_stdlib"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "FE79E08BF97009729259B6357EC058315B6FBB916FAD1C2FF9355115FEB0D3A4" }, 7 | { name = "gleam_community_colour", 
version = "1.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "A49A5E3AE8B637A5ACBA80ECB9B1AFE89FD3D5351FF6410A42B84F666D40D7D5" }, 8 | { name = "gleam_http", version = "3.5.3", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_http", source = "hex", outer_checksum = "C2FC3322203B16F897C1818D9810F5DEFCE347F0751F3B44421E1261277A7373" }, 9 | { name = "gleam_httpc", version = "2.1.2", build_tools = ["gleam"], requirements = ["gleam_http", "gleam_stdlib"], otp_app = "gleam_httpc", source = "hex", outer_checksum = "ACD05CA3BAC7780DF5FFAE334621FD199D1B490FAF6ECDFF74316CAA61CE88E6" }, 10 | { name = "gleam_json", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "8B197DD5D578EA6AC2C0D4BDC634C71A5BCA8E7DB5F47091C263ECB411A60DF3" }, 11 | { name = "gleam_stdlib", version = "0.35.1", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "5443EEB74708454B65650FEBBB1EF5175057D1DEC62AEA9D7C6D96F41DA79152" }, 12 | { name = "gleeunit", version = "1.0.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "D364C87AFEB26BDB4FB8A5ABDE67D635DC9FA52D6AB68416044C35B096C6882D" }, 13 | { name = "glint", version = "0.15.0", build_tools = ["gleam"], requirements = ["gleam_community_ansi", "gleam_community_colour", "gleam_stdlib", "snag"], otp_app = "glint", source = "hex", outer_checksum = "D5324DBE11F57BF0B303D99EA086D66B8DC319EE59C1355C76EBB1544187C237" }, 14 | { name = "nx", version = "0.6.4", build_tools = ["mix"], requirements = ["complex", "telemetry"], otp_app = "nx", source = "hex", outer_checksum = "BB9C2E2E3545B5EB4739D69046A988DAAA212D127DBA7D97801C291616AFF6D6" }, 15 | { name = "rad", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_http", "gleam_httpc", 
"gleam_json", "gleam_stdlib", "glint", "shellout", "snag", "thoas", "tomerl"], otp_app = "rad", source = "hex", outer_checksum = "F2427C9DC1969B715B2D28694183C0D9498652764933C1BED19C5D77C77D33C1" }, 16 | { name = "shellout", version = "1.6.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "shellout", source = "hex", outer_checksum = "E2FCD18957F0E9F67E1F497FC9FF57393392F8A9BAEAEA4779541DE7A68DD7E0" }, 17 | { name = "snag", version = "0.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "snag", source = "hex", outer_checksum = "54D32E16E33655346AA3E66CBA7E191DE0A8793D2C05284E3EFB90AD2CE92BCC" }, 18 | { name = "telemetry", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "telemetry", source = "hex", outer_checksum = "DAD9CE9D8EFFC621708F99EAC538EF1CBE05D6A874DD741DE2E689C47FEAFED5" }, 19 | { name = "thoas", version = "0.4.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "4918D50026C073C4AB1388437132C77A6F6F7C8AC43C60C13758CC0ADCE2134E" }, 20 | { name = "tomerl", version = "0.5.0", build_tools = ["rebar3"], requirements = [], otp_app = "tomerl", source = "hex", outer_checksum = "2A7FB62F9EBF0E75561B39255638BC2B805B437C86FEC538657E7C3B576979FA" }, 21 | ] 22 | 23 | [requirements] 24 | gleam_stdlib = { version = "~> 0.34 or ~> 1.0" } 25 | gleeunit = { version = "~> 1.0" } 26 | nx = { version = "~> 0.5 or ~> 1.0" } 27 | rad = { version = "~> 1.1" } 28 | -------------------------------------------------------------------------------- /src/argamak/format.gleam: -------------------------------------------------------------------------------- 1 | import gleam/string 2 | 3 | /// Numerical formats for tensors. 4 | /// 5 | /// Each `Format` uses a set number of bits to represent every `Float`-like or 6 | /// `Int`-like value. 
7 | /// 8 | pub opaque type Format(a) { 9 | Format(a) 10 | } 11 | 12 | /// A 32-bit floating point type, argamak's standard for working with floats. 13 | /// 14 | pub type Float32 { 15 | Float32 16 | } 17 | 18 | /// Creates a 32-bit floating point `Format`, argamak's standard for working 19 | /// with floats. 20 | /// 21 | pub fn float32() -> Format(Float32) { 22 | Format(Float32) 23 | } 24 | 25 | /// A 32-bit signed integer type, argamak's standard for working with ints. 26 | /// 27 | pub type Int32 { 28 | Int32 29 | } 30 | 31 | /// Creates a 32-bit signed integer `Format`, argamak's standard for working 32 | /// with ints. 33 | /// 34 | pub fn int32() -> Format(Int32) { 35 | Format(Int32) 36 | } 37 | 38 | @target(erlang) 39 | /// A 64-bit floating point type. 40 | /// 41 | pub type Float64 { 42 | Float64 43 | } 44 | 45 | @target(erlang) 46 | /// Creates a 64-bit floating point `Format`. 47 | /// 48 | pub fn float64() -> Format(Float64) { 49 | Format(Float64) 50 | } 51 | 52 | @target(erlang) 53 | /// A 64-bit signed integer type. 54 | /// 55 | pub type Int64 { 56 | Int64 57 | } 58 | 59 | @target(erlang) 60 | /// Creates a 64-bit signed integer `Format`. 61 | /// 62 | pub fn int64() -> Format(Int64) { 63 | Format(Int64) 64 | } 65 | 66 | @target(erlang) 67 | /// A 64-bit unsigned integer type. 68 | /// 69 | pub type Uint64 { 70 | Uint64 71 | } 72 | 73 | @target(erlang) 74 | /// Creates a 64-bit unsigned integer `Format`. 75 | /// 76 | pub fn uint64() -> Format(Uint64) { 77 | Format(Uint64) 78 | } 79 | 80 | @target(erlang) 81 | /// A 32-bit unsigned integer type. 82 | /// 83 | pub type Uint32 { 84 | Uint32 85 | } 86 | 87 | @target(erlang) 88 | /// Creates a 32-bit unsigned integer `Format`. 89 | /// 90 | pub fn uint32() -> Format(Uint32) { 91 | Format(Uint32) 92 | } 93 | 94 | @target(erlang) 95 | /// A 16-bit brain floating point type. 
96 | /// 97 | pub type Bfloat16 { 98 | Bfloat16 99 | } 100 | 101 | @target(erlang) 102 | /// Creates a 16-bit brain floating point `Format`. 103 | /// 104 | pub fn bfloat16() -> Format(Bfloat16) { 105 | Format(Bfloat16) 106 | } 107 | 108 | @target(erlang) 109 | /// A 16-bit floating point type. 110 | /// 111 | pub type Float16 { 112 | Float16 113 | } 114 | 115 | @target(erlang) 116 | /// Creates a 16-bit floating point `Format`. 117 | /// 118 | pub fn float16() -> Format(Float16) { 119 | Format(Float16) 120 | } 121 | 122 | @target(erlang) 123 | /// A 16-bit signed integer type. 124 | /// 125 | pub type Int16 { 126 | Int16 127 | } 128 | 129 | @target(erlang) 130 | /// Creates a 16-bit signed integer `Format`. 131 | /// 132 | pub fn int16() -> Format(Int16) { 133 | Format(Int16) 134 | } 135 | 136 | @target(erlang) 137 | /// A 16-bit unsigned integer type. 138 | /// 139 | pub type Uint16 { 140 | Uint16 141 | } 142 | 143 | @target(erlang) 144 | /// Creates a 16-bit unsigned integer `Format`. 145 | /// 146 | pub fn uint16() -> Format(Uint16) { 147 | Format(Uint16) 148 | } 149 | 150 | @target(erlang) 151 | /// An 8-bit signed integer type. 152 | /// 153 | pub type Int8 { 154 | Int8 155 | } 156 | 157 | @target(erlang) 158 | /// Creates an 8-bit signed integer `Format`. 159 | /// 160 | pub fn int8() -> Format(Int8) { 161 | Format(Int8) 162 | } 163 | 164 | @target(erlang) 165 | /// An 8-bit unsigned integer type. 166 | /// 167 | pub type Uint8 { 168 | Uint8 169 | } 170 | 171 | @target(erlang) 172 | /// Creates an 8-bit unsigned integer `Format`. 173 | /// 174 | pub fn uint8() -> Format(Uint8) { 175 | Format(Uint8) 176 | } 177 | 178 | /// A type for `Native` format representations. 179 | /// 180 | pub type Native 181 | 182 | /// Converts a given `Format` into its native representation. 
183 | /// 184 | pub fn to_native(format: Format(a)) -> Native { 185 | let Format(x) = format 186 | do_to_native(x) 187 | } 188 | 189 | @external(erlang, "argamak_ffi", "format_to_native") 190 | @external(javascript, "../argamak_ffi.mjs", "format_to_native") 191 | fn do_to_native(format: a) -> Native 192 | 193 | /// Converts a `Format` into a `String`. 194 | /// 195 | pub fn to_string(format: Format(a)) -> String { 196 | string.inspect(format) 197 | } 198 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 1.1.0 - 2024-02-17 4 | 5 | - Argamak now supports `gleam_stdlib` v1.0. 6 | - Argamak now requires Gleam v0.34 or later. 7 | 8 | ## 1.0.0 - 2023-12-22 9 | 10 | - Argamak now requires Gleam v0.33 or later. 11 | - The `tensor` module's `concat` function now returns an `AxisNotFound` error 12 | when the given find function is `False` for every `Axis`. 13 | 14 | ## v0.4.0 - 2023-09-11 15 | 16 | - Argamak now requires Gleam v0.30 or later. 17 | - The `tensor` module gains the `concat` slicing/joining function. 
18 | 19 | ## v0.3.0 - 2022-11-20 20 | 21 | - The `tensor` module gains the `TensorResult` type; the `from_bool` and 22 | `from_bools` creation functions; the `size` reflection function; the `squeeze` 23 | transformation function; the `equal`, `not_equal`, `greater`, 24 | `greater_or_equal`, `less`, `less_or_equal`, `logical_and`, `logical_or`, 25 | `logical_xor`, and `logical_not` logical functions; the `add`, `subtract`, 26 | `multiply`, `divide`, `try_divide`, `remainder`, `try_remainder`, `modulo`, 27 | `try_modulo`, `power`, `max`, and `min` arithmetic functions; the 28 | `absolute_value`, `negate`, `sign`, `ceiling`, `floor`, `round`, `exp`, 29 | `square_root`, and `ln` basic math functions; the `all`, `in_situ_all`, `any`, 30 | `in_situ_any`, `arg_max`, `in_situ_arg_max`, `arg_min`, `in_situ_arg_min`, 31 | `max_over`, `in_situ_max_over`, `min_over`, `in_situ_min_over`, `sum`, 32 | `in_situ_sum`, `product`, `in_situ_product`, `mean`, and `in_situ_mean` 33 | reduction functions; the `to_bool`, `to_floats`, `to_ints`, and `to_bools` 34 | conversion functions; and the `debug` and `print_data` utility functions. 35 | - The `tensor` module's `as_format` function has been renamed to `reformat` and 36 | now takes a `Format` record instead of a function reference. 37 | - The `tensor` module no longer includes the `to_list` function. 38 | - The `Tensor` type signature now includes only the numeric format as a generic. 39 | - The `axis` module has been added with the `Axis` and `Axes` types; the `name` 40 | and `size` reflection functions; and the `rename` and `resize` transformation 41 | functions. 42 | - The `space` module gains the `from_list` creation function; and the `map` and 43 | `merge` transformation functions. 44 | - The `space` module no longer includes the `elements` and `map_elements` 45 | functions. 46 | - The `space` module's `d0` function has been renamed to `new` and now returns 47 | an empty `Space` record directly. 
48 | - The `space` module and its `Space` and `SpaceError` types have been reworked: 49 | The `Space` type signature no longer includes any generics, and the 50 | constructors `D0` through `D6` have been removed. 51 | - The `Format` type has been reworked and now includes the numeric format as a 52 | generic. 53 | - Several numeric format types have been added to the `format` module. 54 | - The `util` module has been removed. 55 | 56 | ## v0.2.0 - 2022-09-29 57 | 58 | - The `space` module gets an updated `d1` function so the dimension size can be 59 | given. 60 | - The `tensor` module gains the `broadcast`, and `broadcast_over` functions for 61 | all compilation targets. 62 | - Argamak now compiles and runs with the JavaScript target. 63 | - Argamak now uses the `gleam` build tool. 64 | 65 | ## v0.1.0 - 2022-01-20 66 | 67 | - Initial release! 68 | - The `format` module gains the `Format` and `Native` (JavaScript planned) types 69 | for all compilation targets, along with the `float32`, `int32`, `to_native`, 70 | and `to_string` functions for all compilation targets (JavaScript planned) and 71 | the `bfloat16`, `float16`, `float64`, `int16`, `int64`, `int8`, `uint16`, 72 | `uint32`, `uint64`, and `uint8` functions for the Erlang compilation target. 73 | - The `space` module gains the `D0`, `D1`, `D2`, `D3`, `D4`, `D5`, `D6`, 74 | `Space`, and `SpaceError` (with `SpaceErrors` alias) types for all compilation 75 | targets, along with the `axes`, `d0`, `d1`, `d2`, `d3`, `d4`, `d5`, `d6`, 76 | `degree`, `elements`, `map_elements`, `shape`, and `to_string` functions for 77 | all compilation targets (JavaScript planned). 
78 | - The `tensor` module gains the `Native` (JavaScript planned), `Tensor`, and 79 | `TensorError` types for all compilation targets, along with the `as_format`, 80 | `axes`, `format`, `from_float`, `from_floats`, `from_int`, `from_ints`, 81 | `from_native`, `print`, `rank`, `reshape`, `shape`, `space`, `to_float`, 82 | `to_int`, `to_list`, and `to_native` functions for all compilation targets 83 | (JavaScript planned). 84 | - The `util` module gains the `UtilError` type for all compilation targets, 85 | along with the `record_to_string` function for all compilation targets. 86 | -------------------------------------------------------------------------------- /src/argamak/axis.gleam: -------------------------------------------------------------------------------- 1 | import gleam/result 2 | import gleam/string 3 | 4 | /// The elements that comprise a `Space`. 5 | /// 6 | /// Except for `Infer`, every `Axis` has a `size` corresponding to the number of 7 | /// values that fit along that `Axis` when a `Tensor` is put into a `Space` 8 | /// containing that `Axis`. 9 | /// 10 | /// An `Axis` can be given a unique `name` using the `Axis` or `Infer` 11 | /// constructors. Single-letter constructors are also provided for convenience. 12 | /// 13 | /// The special `Infer` constructor can be used once per `Space`. It will be 14 | /// replaced and have its `size` computed when a `Tensor` is put into that 15 | /// `Space`. 
16 | /// 17 | pub type Axis { 18 | Axis(name: String, size: Int) 19 | Infer(name: String) 20 | A(size: Int) 21 | B(size: Int) 22 | C(size: Int) 23 | D(size: Int) 24 | E(size: Int) 25 | F(size: Int) 26 | G(size: Int) 27 | H(size: Int) 28 | I(size: Int) 29 | J(size: Int) 30 | K(size: Int) 31 | L(size: Int) 32 | M(size: Int) 33 | N(size: Int) 34 | O(size: Int) 35 | P(size: Int) 36 | Q(size: Int) 37 | R(size: Int) 38 | S(size: Int) 39 | T(size: Int) 40 | U(size: Int) 41 | V(size: Int) 42 | W(size: Int) 43 | X(size: Int) 44 | Y(size: Int) 45 | Z(size: Int) 46 | } 47 | 48 | /// An `Axis` list. 49 | /// 50 | pub type Axes = 51 | List(Axis) 52 | 53 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 54 | // Reflection Functions // 55 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 56 | 57 | /// Returns the name of a given `Axis`. 58 | /// 59 | /// ## Examples 60 | /// 61 | /// ```gleam 62 | /// > name(A(1)) 63 | /// "A" 64 | /// 65 | /// > name(Axis(name: "Sparkle", size: 99)) 66 | /// "Sparkle" 67 | /// 68 | /// > name(Infer("Silver")) 69 | /// "Silver" 70 | /// ``` 71 | pub fn name(x: Axis) -> String { 72 | case x { 73 | Axis(name: name, ..) | Infer(name: name) -> name 74 | _else -> 75 | x 76 | |> string.inspect 77 | |> string.first 78 | |> result.unwrap(or: "") 79 | } 80 | } 81 | 82 | /// Returns the size of a given `Axis`. 83 | /// 84 | /// The size of `Infer` is always `0`. 85 | /// 86 | /// ## Examples 87 | /// 88 | /// ```gleam 89 | /// > size(A(1)) 90 | /// 1 91 | /// 92 | /// > size(Axis(name: "Sparkle", size: 99)) 93 | /// 99 94 | /// 95 | /// > size(Infer("Silver")) 96 | /// 0 97 | /// ``` 98 | pub fn size(x: Axis) -> Int { 99 | case x { 100 | Infer(_) -> 0 101 | Axis(size: size, ..) 
102 | | A(size) 103 | | B(size) 104 | | C(size) 105 | | D(size) 106 | | E(size) 107 | | F(size) 108 | | G(size) 109 | | H(size) 110 | | I(size) 111 | | J(size) 112 | | K(size) 113 | | L(size) 114 | | M(size) 115 | | N(size) 116 | | O(size) 117 | | P(size) 118 | | Q(size) 119 | | R(size) 120 | | S(size) 121 | | T(size) 122 | | U(size) 123 | | V(size) 124 | | W(size) 125 | | X(size) 126 | | Y(size) 127 | | Z(size) -> size 128 | } 129 | } 130 | 131 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 132 | // Transformation Functions // 133 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 134 | 135 | /// Renames the given `Axis`, retaining its `size`. 136 | /// 137 | /// If an `Axis` is renamed to a single capital letter (from `"A"` to `"Z"` 138 | /// inclusive), the single-letter convenience constructor will be used for the 139 | /// new `Axis`. 140 | /// 141 | /// ## Examples 142 | /// 143 | /// ```gleam 144 | /// > let x = A(1) 145 | /// > rename(x, "B") 146 | /// B(1) 147 | /// 148 | /// > let x = Axis("Thing", 3) 149 | /// > rename(x, "Y") 150 | /// Y(3) 151 | /// 152 | /// > let x = Axis(name: "Sparkle", size: 99) 153 | /// > rename(x, "Shine") 154 | /// Axis("Shine", 99) 155 | /// 156 | /// > let x = Infer("Silver") 157 | /// > rename(x, "Gold") 158 | /// Infer("Gold") 159 | /// ``` 160 | /// 161 | pub fn rename(x: Axis, name: String) -> Axis { 162 | let size = size(x) 163 | let is_infer = case x { 164 | Infer(_) -> True 165 | _else -> False 166 | } 167 | let x = case name { 168 | "A" -> A 169 | "B" -> B 170 | "C" -> C 171 | "D" -> D 172 | "E" -> E 173 | "F" -> F 174 | "G" -> G 175 | "H" -> H 176 | "I" -> I 177 | "J" -> J 178 | "K" -> K 179 | "L" -> L 180 | "M" -> M 181 | "N" -> N 182 | "O" -> O 183 | "P" -> P 184 | "Q" -> Q 185 | "R" -> R 186 | "S" -> S 187 | "T" -> T 188 | "U" -> U 189 | "V" -> V 190 | "W" -> W 191 | "X" -> X 192 | "Y" -> Y 193 | "Z" -> Z 194 | _else if is_infer -> fn(_) { Infer(name) } 195 | _else -> Axis(name: name, size: _) 196 | } 197 | x(size) 
198 | } 199 | 200 | /// Changes the `size` of the given `Axis`. 201 | /// 202 | /// Resizing an `Infer` returns an `Axis` record that will no longer have its 203 | /// `size` automatically computed. 204 | /// 205 | /// ## Examples 206 | /// 207 | /// ```gleam 208 | /// > let x = A(1) 209 | /// > resize(x, 3) 210 | /// A(3) 211 | /// 212 | /// > let x = Axis("Y", 2) 213 | /// > resize(x, 3) 214 | /// Y(3) 215 | /// 216 | /// > let x = Axis(name: "Sparkle", size: 99) 217 | /// > resize(x, 42) 218 | /// Axis("Sparkle", 42) 219 | /// 220 | /// > let x = Infer("A") 221 | /// > resize(x, 1) 222 | /// A(1) 223 | /// ``` 224 | /// 225 | pub fn resize(x: Axis, size: Int) -> Axis { 226 | let x = case name(x) { 227 | "A" -> A 228 | "B" -> B 229 | "C" -> C 230 | "D" -> D 231 | "E" -> E 232 | "F" -> F 233 | "G" -> G 234 | "H" -> H 235 | "I" -> I 236 | "J" -> J 237 | "K" -> K 238 | "L" -> L 239 | "M" -> M 240 | "N" -> N 241 | "O" -> O 242 | "P" -> P 243 | "Q" -> Q 244 | "R" -> R 245 | "S" -> S 246 | "T" -> T 247 | "U" -> U 248 | "V" -> V 249 | "W" -> W 250 | "X" -> X 251 | "Y" -> Y 252 | "Z" -> Z 253 | name -> Axis(name: name, size: _) 254 | } 255 | x(size) 256 | } 257 | -------------------------------------------------------------------------------- /src/argamak_ffi.ex: -------------------------------------------------------------------------------- 1 | defmodule :argamak_ffi do 2 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 3 | # Constants # 4 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 5 | 6 | @result :gleam@result 7 | 8 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 9 | # Tensor Creation Functions # 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 11 | 12 | def tensor(x, format), do: fn -> Nx.tensor(x, type: format) end |> result 13 | 14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 15 | # Tensor Reflection Functions # 16 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 17 | 18 | def size(x), do: Nx.size(x) 19 | 20 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 21 | # Tensor 
Transformation Functions # 22 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 23 | 24 | def reformat(x, like: y), do: Nx.as_type(x, Nx.type(y)) 25 | def reformat(x, format), do: Nx.as_type(x, format) 26 | 27 | def reshape(x, shape), 28 | do: fn -> Nx.reshape(x, :erlang.list_to_tuple(shape)) end |> shape_result 29 | 30 | def broadcast(x, shape), 31 | do: fn -> Nx.broadcast(x, :erlang.list_to_tuple(shape)) end |> shape_result 32 | 33 | def squeeze(x, i), do: Nx.squeeze(x, axes: i) 34 | 35 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 36 | # Tensor Logical Functions # 37 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 38 | 39 | def equal(a, b), 40 | do: fn -> Nx.equal(a, b) |> reformat(like: a) end |> broadcast_result 41 | 42 | def not_equal(a, b), 43 | do: fn -> Nx.not_equal(a, b) |> reformat(like: a) end |> broadcast_result 44 | 45 | def greater(a, b), 46 | do: fn -> Nx.greater(a, b) |> reformat(like: a) end |> broadcast_result 47 | 48 | def greater_or_equal(a, b), 49 | do: fn -> Nx.greater_equal(a, b) |> reformat(like: a) end |> broadcast_result 50 | 51 | def less(a, b), 52 | do: fn -> Nx.less(a, b) |> reformat(like: a) end |> broadcast_result 53 | 54 | def less_or_equal(a, b), 55 | do: fn -> Nx.less_equal(a, b) |> reformat(like: a) end |> broadcast_result 56 | 57 | def logical_and(a, b), 58 | do: fn -> Nx.logical_and(a, b) |> reformat(like: a) end |> broadcast_result 59 | 60 | def logical_or(a, b), 61 | do: fn -> Nx.logical_or(a, b) |> reformat(like: a) end |> broadcast_result 62 | 63 | def logical_xor(a, b), 64 | do: fn -> Nx.logical_xor(a, b) |> reformat(like: a) end |> broadcast_result 65 | 66 | def logical_not(x), do: Nx.logical_not(x) |> reformat(like: x) 67 | 68 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 69 | # Tensor Arithmetic Functions # 70 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 71 | 72 | def add(a, b), do: fn -> Nx.add(a, b) end |> broadcast_result 73 | 74 | def subtract(a, b), do: fn -> Nx.subtract(a, b) end |> broadcast_result 75 | 76 | def 
multiply(a, b), do: fn -> Nx.multiply(a, b) end |> broadcast_result 77 | 78 | def divide(a, b), 79 | do: fn -> Nx.divide(a, b) end |> broadcast_result |> @result.map(&clip_reformat(&1, like: a)) 80 | 81 | def remainder(a, b), do: fn -> Nx.remainder(a, b) end |> broadcast_result 82 | 83 | def power(a, b), do: fn -> Nx.pow(a, b) end |> broadcast_result 84 | 85 | def max(a, b), do: fn -> Nx.max(a, b) end |> broadcast_result 86 | 87 | def min(a, b), do: fn -> Nx.min(a, b) end |> broadcast_result 88 | 89 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 90 | # Tensor Basic Math Functions # 91 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 92 | 93 | def absolute_value(x), do: Nx.abs(x) 94 | 95 | def negate(x), do: Nx.negate(x) 96 | 97 | def sign(x), do: Nx.sign(x) 98 | 99 | def ceiling(x), do: Nx.ceil(x) 100 | 101 | def floor(x), do: Nx.floor(x) 102 | 103 | def round(x), do: Nx.round(x) 104 | 105 | def exp(x), do: Nx.exp(x) |> clip_reformat(like: x) 106 | 107 | def square_root(x), 108 | do: fn -> Nx.sqrt(x) end |> result |> @result.map(&clip_reformat(&1, like: x)) 109 | 110 | def ln(x), 111 | do: fn -> Nx.log(x) end |> result |> @result.map(&clip_reformat(&1, like: x)) 112 | 113 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 114 | # Tensor Reduction Functions # 115 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 116 | 117 | def all(x, i), do: Nx.all(x, axes: i) |> reformat(like: x) 118 | 119 | def any(x, i), do: Nx.any(x, axes: i) |> reformat(like: x) 120 | 121 | def arg_max(x, i), do: Nx.argmax(x, axis: i) |> reformat(like: x) 122 | 123 | def arg_min(x, i), do: Nx.argmin(x, axis: i) |> reformat(like: x) 124 | 125 | def max_over(x, i), do: Nx.reduce_max(x, axes: i) 126 | 127 | def min_over(x, i), do: Nx.reduce_min(x, axes: i) 128 | 129 | def sum(x, i), do: Nx.sum(x, axes: i) |> clip_reformat(like: x) 130 | 131 | def product(x, i), do: Nx.product(x, axes: i) |> clip_reformat(like: x) 132 | 133 | def mean(x, i), do: Nx.mean(x, axes: i) |> reformat(like: x) 134 | 135 | # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 136 | # Tensor Slicing & Joining Functions # 137 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 138 | 139 | def concat(xs, i), do: Nx.concatenate(xs, axis: i) 140 | 141 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 142 | # Tensor Conversion Functions # 143 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 144 | 145 | def to_float(x), 146 | do: fn -> clip_reformat(x, float_format(x)) |> Nx.to_number() end |> shape_result 147 | 148 | def to_int(x), 149 | do: fn -> clip_reformat(x, int_format(x)) |> Nx.to_number() end |> shape_result 150 | 151 | def to_floats(x), do: clip_reformat(x, float_format(x)) |> Nx.to_flat_list() 152 | 153 | def to_ints(x), do: clip_reformat(x, int_format(x)) |> Nx.to_flat_list() 154 | 155 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 156 | # Tensor Utility Functions # 157 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 158 | 159 | def prepare_to_string(x) do 160 | {format, _} = Nx.type(x) 161 | is_int = Enum.any?([:s, :u], fn x -> x == format end) 162 | 163 | x 164 | |> Nx.to_flat_list() 165 | |> Enum.reverse() 166 | |> Enum.reduce({[], 0}, fn x, {xs, item_width} -> 167 | x = 168 | case x do 169 | :infinity -> 170 | "Infinity" 171 | 172 | :neg_infinity -> 173 | "-Infinity" 174 | 175 | _else -> 176 | x = 177 | ~s(~.3#{if is_int, do: "f", else: "g"}) 178 | |> :io_lib.format([x + 0.0]) 179 | |> :erlang.list_to_binary() 180 | |> String.trim_trailing("e+0") 181 | 182 | x = 183 | if String.contains?(x, "e") do 184 | x 185 | else 186 | String.trim_trailing(x, "0") 187 | end 188 | 189 | if is_int do 190 | String.trim_trailing(x, ".") 191 | else 192 | String.replace_trailing(x, ".", ".0") 193 | end 194 | end 195 | 196 | {[x | xs], Kernel.max(String.length(x), item_width)} 197 | end) 198 | end 199 | 200 | def columns() do 201 | case :io.columns() do 202 | {:ok, columns} -> columns 203 | _else -> 0 204 | end 205 | end 206 | 207 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 208 | # Format Functions # 209 | # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 210 | 211 | def format_to_native(x) do 212 | case x do 213 | :float64 -> {:f, 64} 214 | :int64 -> {:s, 64} 215 | :uint64 -> {:u, 64} 216 | :float32 -> {:f, 32} 217 | :int32 -> {:s, 32} 218 | :uint32 -> {:u, 32} 219 | :bfloat16 -> {:bf, 16} 220 | :float16 -> {:f, 16} 221 | :int16 -> {:s, 16} 222 | :uint16 -> {:u, 16} 223 | :int8 -> {:s, 8} 224 | :uint8 -> {:u, 8} 225 | end 226 | end 227 | 228 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 229 | # Private Tensor Functions # 230 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 231 | 232 | defp clip_reformat(x, like: y) do 233 | if Nx.type(x) == Nx.type(y) do 234 | # Don't clip if reformat is noop. 235 | x 236 | else 237 | x 238 | |> reformat(like: y) 239 | |> clip(based_on: x) 240 | end 241 | end 242 | 243 | defp clip_reformat(x, format) do 244 | x 245 | |> reformat(format) 246 | |> clip(based_on: x) 247 | end 248 | 249 | defp clip(x, based_on: y) do 250 | format = Nx.type(x) 251 | 252 | scalar = fn f -> 253 | format 254 | |> f.() 255 | |> Nx.from_binary(format) 256 | |> Nx.reshape({}) 257 | end 258 | 259 | min = scalar.(&Nx.Type.min_finite_binary/1) 260 | max = scalar.(&Nx.Type.max_finite_binary/1) 261 | 262 | less = Nx.less(y, min) 263 | greater = Nx.greater(y, max) 264 | 265 | replace = fn x, predicate, with: value -> Nx.select(predicate, value, x) end 266 | 267 | x 268 | |> replace.(less, with: min) 269 | |> replace.(greater, with: max) 270 | end 271 | 272 | defp float_format(x), do: Nx.type(x) |> Nx.Type.to_floating() 273 | 274 | defp int_format(x) do 275 | format = Nx.type(x) 276 | if Nx.Type.integer?(format), do: format, else: {:s, 32} 277 | end 278 | 279 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 280 | # Private Result Functions # 281 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 282 | 283 | defp broadcast_result(f), do: result(f, or: :cannot_broadcast) 284 | 285 | defp shape_result(f), do: result(f, or: :incompatible_shape) 286 | 287 | defp result(f, opts \\ []) do 288 | 
try do 289 | x = f.() 290 | 291 | case x |> Nx.is_nan() |> Nx.any() |> Nx.to_number() do 292 | 0 -> {:ok, x} 293 | 1 -> {:error, :invalid_data} 294 | end 295 | rescue 296 | ArithmeticError -> {:error, :invalid_data} 297 | _else -> {:error, Keyword.get(opts, :or, :invalid_data)} 298 | end 299 | end 300 | end 301 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # argamak 🐎 2 | 3 | [![Hex Package](https://img.shields.io/hexpm/v/argamak?color=ffaff3&label&labelColor=2f2f2f&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHZpZXdCb3g9IjAgMCAyNCAyNCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBmaWxsPSIjZmVmZWZjIiBkPSJNIDYuMjgzMiwxLjU5OTYgOS4yODMyLDYuNzk0OSBIIDE0LjcwNTEgTCAxNy43MDUxLDEuNTk5NiBaIE0gMTguMTQwNywxLjg0MzggbCAtMyw1LjE5NzMgMi43MTQ5LDQuNjk5MiBoIDYgeiBNIDUuODUzNSwxLjg1NTUgMC4xNDQ1LDExLjc0MDIgSCA2LjE0NDUgTCA4Ljg1MTYsNy4wNDg4IFogTSAwLjE0NDUsMTIuMjQwMiA1Ljg1MzUsMjIuMTI3IDguODUxNiwxNi45MzM2IDYuMTQ0NSwxMi4yNDAyIFogbSAxNy43MTEsMCAtMi43MTQ5LDQuNzAxMiAzLDUuMTk1MyA1LjcxNDksLTkuODk2NSB6IE0gOS4yODMyLDE3LjE4NzUgNi4yODUyLDIyLjM4MDkgSCAxNy43MDMyIEwgMTQuNzA1MSwxNy4xODc1IFoiLz48L3N2Zz4K)](https://hex.pm/packages/argamak) 4 | [![Hex 
Docs](https://img.shields.io/badge/hex-docs-ffaff3?label&labelColor=2f2f2f&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHZpZXdCb3g9IjAgMCAyNiAyOCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBmaWxsPSIjZmVmZWZjIiBkPSJNMjUuNjA5IDcuNDY5YzAuMzkxIDAuNTYyIDAuNSAxLjI5NyAwLjI4MSAyLjAxNmwtNC4yOTcgMTQuMTU2Yy0wLjM5MSAxLjMyOC0xLjc2NiAyLjM1OS0zLjEwOSAyLjM1OWgtMTQuNDIyYy0xLjU5NCAwLTMuMjk3LTEuMjY2LTMuODc1LTIuODkxLTAuMjUtMC43MDMtMC4yNS0xLjM5MS0wLjAzMS0xLjk4NCAwLjAzMS0wLjMxMyAwLjA5NC0wLjYyNSAwLjEwOS0xIDAuMDE2LTAuMjUtMC4xMjUtMC40NTMtMC4wOTQtMC42NDEgMC4wNjMtMC4zNzUgMC4zOTEtMC42NDEgMC42NDEtMS4wNjIgMC40NjktMC43ODEgMS0yLjA0NyAxLjE3Mi0yLjg1OSAwLjA3OC0wLjI5Ny0wLjA3OC0wLjY0MSAwLTAuOTA2IDAuMDc4LTAuMjk3IDAuMzc1LTAuNTE2IDAuNTMxLTAuNzk3IDAuNDIyLTAuNzE5IDAuOTY5LTIuMTA5IDEuMDQ3LTIuODQ0IDAuMDMxLTAuMzI4LTAuMTI1LTAuNjg4LTAuMDMxLTAuOTM4IDAuMTA5LTAuMzU5IDAuNDUzLTAuNTE2IDAuNjg4LTAuODI4IDAuMzc1LTAuNTE2IDEtMiAxLjA5NC0yLjgyOCAwLjAzMS0wLjI2Ni0wLjEyNS0wLjUzMS0wLjA3OC0wLjgxMiAwLjA2My0wLjI5NyAwLjQzOC0wLjYwOSAwLjY4OC0wLjk2OSAwLjY1Ni0wLjk2OSAwLjc4MS0zLjEwOSAyLjc2Ni0yLjU0N2wtMC4wMTYgMC4wNDdjMC4yNjYtMC4wNjMgMC41MzEtMC4xNDEgMC43OTctMC4xNDFoMTEuODkxYzAuNzM0IDAgMS4zOTEgMC4zMjggMS43ODEgMC44NzUgMC40MDYgMC41NjIgMC41IDEuMjk3IDAuMjgxIDIuMDMxbC00LjI4MSAxNC4xNTZjLTAuNzM0IDIuNDA2LTEuMTQxIDIuOTM4LTMuMTI1IDIuOTM4aC0xMy41NzhjLTAuMjAzIDAtMC40NTMgMC4wNDctMC41OTQgMC4yMzQtMC4xMjUgMC4xODctMC4xNDEgMC4zMjgtMC4wMTYgMC42NzIgMC4zMTMgMC45MDYgMS4zOTEgMS4wOTQgMi4yNSAxLjA5NGgxNC40MjJjMC41NzggMCAxLjI1LTAuMzI4IDEuNDIyLTAuODkxbDQuNjg4LTE1LjQyMmMwLjA5NC0wLjI5NyAwLjA5NC0wLjYwOSAwLjA3OC0wLjg5MSAwLjM1OSAwLjE0MSAwLjY4OCAwLjM1OSAwLjkyMiAwLjY3MnpNOC45ODQgNy41Yy0wLjA5NCAwLjI4MSAwLjA2MyAwLjUgMC4zNDQgMC41aDkuNWMwLjI2NiAwIDAuNTYyLTAuMjE5IDAuNjU2LTAuNWwwLjMyOC0xYzAuMDk0LTAuMjgxLTAuMDYzLTAuNS0wLjM0NC0wLjVoLTkuNWMtMC4yNjYgMC0wLjU2MiAwLjIxOS0wLjY1NiAwLjV6TTcuNjg4IDExLjVjLTAuMDk0IDAuMjgxIDAuMDYzIDAuNSAwLjM0NCAwLjVoOS41YzAuMjY2IDAgMC41NjItMC4yMTkgMC42NTYtMC41bDAuMzI4LTFjMC4wOTQtMC4yODEtMC4wNjMtMC41LTAuMzQ0LTAuNWgtOS41Yy0wLjI2NiAwLTAuNTYyIDAuMj
E5LTAuNjU2IDAuNXoiPjwvcGF0aD48L3N2Zz4K)](https://hexdocs.pm/argamak/) 5 | [![License](https://img.shields.io/hexpm/l/argamak?color=ffaff3&label&labelColor=2f2f2f&logo=data:image/svg+xml;base64,PHN2ZyB2ZXJzaW9uPSIxLjEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgd2lkdGg9IjM0IiBoZWlnaHQ9IjI4IiB2aWV3Qm94PSIwIDAgMzQgMjgiPgo8cGF0aCBmaWxsPSIjZmVmZWZjIiBkPSJNMjcgN2wtNiAxMWgxMnpNNyA3bC02IDExaDEyek0xOS44MjggNGMtMC4yOTcgMC44NDQtMC45ODQgMS41MzEtMS44MjggMS44Mjh2MjAuMTcyaDkuNWMwLjI4MSAwIDAuNSAwLjIxOSAwLjUgMC41djFjMCAwLjI4MS0wLjIxOSAwLjUtMC41IDAuNWgtMjFjLTAuMjgxIDAtMC41LTAuMjE5LTAuNS0wLjV2LTFjMC0wLjI4MSAwLjIxOS0wLjUgMC41LTAuNWg5LjV2LTIwLjE3MmMtMC44NDQtMC4yOTctMS41MzEtMC45ODQtMS44MjgtMS44MjhoLTcuNjcyYy0wLjI4MSAwLTAuNS0wLjIxOS0wLjUtMC41di0xYzAtMC4yODEgMC4yMTktMC41IDAuNS0wLjVoNy42NzJjMC40MjItMS4xNzIgMS41MTYtMiAyLjgyOC0yczIuNDA2IDAuODI4IDIuODI4IDJoNy42NzJjMC4yODEgMCAwLjUgMC4yMTkgMC41IDAuNXYxYzAgMC4yODEtMC4yMTkgMC41LTAuNSAwLjVoLTcuNjcyek0xNyA0LjI1YzAuNjg4IDAgMS4yNS0wLjU2MiAxLjI1LTEuMjVzLTAuNTYyLTEuMjUtMS4yNS0xLjI1LTEuMjUgMC41NjItMS4yNSAxLjI1IDAuNTYyIDEuMjUgMS4yNSAxLjI1ek0zNCAxOGMwIDMuMjE5LTQuNDUzIDQuNS03IDQuNXMtNy0xLjI4MS03LTQuNXYwYzAtMC42MDkgNS40NTMtMTAuMjY2IDYuMTI1LTExLjQ4NCAwLjE3Mi0wLjMxMyAwLjUxNi0wLjUxNiAwLjg3NS0wLjUxNnMwLjcwMyAwLjIwMyAwLjg3NSAwLjUxNmMwLjY3MiAxLjIxOSA2LjEyNSAxMC44NzUgNi4xMjUgMTEuNDg0djB6TTE0IDE4YzAgMy4yMTktNC40NTMgNC41LTcgNC41cy03LTEuMjgxLTctNC41djBjMC0wLjYwOSA1LjQ1My0xMC4yNjYgNi4xMjUtMTEuNDg0IDAuMTcyLTAuMzEzIDAuNTE2LTAuNTE2IDAuODc1LTAuNTE2czAuNzAzIDAuMjAzIDAuODc1IDAuNTE2YzAuNjcyIDEuMjE5IDYuMTI1IDEwLjg3NSA2LjEyNSAxMS40ODR6Ij48L3BhdGg+Cjwvc3ZnPgo=)](https://github.com/tynanbe/argamak/blob/main/LICENSE) 6 | [![Build](https://img.shields.io/github/actions/workflow/status/tynanbe/argamak/ci.yml?branch=main&color=ffaff3&label&labelColor=2f2f2f&logo=github-actions&logoColor=fefefc)](https://github.com/tynanbe/argamak/actions) 7 | 8 | A Gleam library for tensor maths. 
9 | 10 | > “I admire the elegance of your method of computation; it must be nice to ride 11 | > through these fields upon the horse of true mathematics while the like of us 12 | > have to make our way laboriously on foot.” 13 | > 14 | > —Albert Einstein, to Tullio Levi-Civita, circa 1915–1917 15 | 16 |

Argamak: A shiny steed.

17 | 18 | ## Installation 19 | 20 | ### As a dependency of your Gleam project 21 | 22 | • Add `argamak` to `gleam.toml` from the command line 23 | 24 | ```shell 25 | $ gleam add argamak 26 | ``` 27 | 28 | ### As a dependency of your Mix project 29 | 30 | • Add `argamak` to `mix.exs` 31 | 32 | ```elixir 33 | defp deps do 34 | [ 35 | {:argamak, "~> 1.1"}, 36 | ] 37 | end 38 | ``` 39 | 40 | ### As a dependency of your Rebar3 project 41 | 42 | • Add `argamak` to `rebar.config` 43 | 44 | ```erlang 45 | {deps, [ 46 | {argamak, "1.1.0"} 47 | ]}. 48 | ``` 49 | 50 | ### JavaScript 51 | 52 | The `@tensorflow/tfjs` package is a runtime requirement for `argamak`; however, 53 | its import path in the `argamak_ffi.mjs` module might need adjustment, depending 54 | on your use case. It can be used as is in your Node.js project after running 55 | `npm install @tensorflow/tfjs-node` or an equivalent command for your package 56 | manager of choice. 57 | 58 | ## Usage 59 | 60 | ```gleam 61 | // derby.gleam 62 | import gleam/function 63 | import gleam/io 64 | import gleam/list 65 | import gleam/result 66 | import gleam/string 67 | import argamak/axis.{Axis, Infer} 68 | import argamak/space 69 | import argamak/tensor.{type TensorError, InvalidData} 70 | 71 | pub fn announce_winner( 72 | from horses: List(String), 73 | with times: List(Float), 74 | ) -> Result(Nil, TensorError) { 75 | // Space records help maintain a clear understanding of a Tensor's data. 76 | // 77 | // We begin by creating a two-dimensional Space with "Horse" and "Trial" Axes. 78 | // The "Trial" Axis size is two because horses always run twice in our derby. 79 | // The "Horse" Axis size will be inferred based on the data when a Tensor is 80 | // put into our Space (perhaps we won't always know how many horses will run). 
81 | // 82 | use d2 <- result.try( 83 | space.d2(Infer(name: "Horse"), Axis(name: "Trial", size: 2)) 84 | |> result.map_error(with: tensor.SpaceErrors), 85 | ) 86 | 87 | // Every Tensor has a numerical Format, a Space, and some data. 88 | // A 2d Tensor can be visualized like a table or matrix. 89 | // 90 | // Tensor( 91 | // Format(Float32) 92 | // Space(Axis("Horse", 5), Axis("Trial", 2)) 93 | // 94 | // Trial 95 | // H [[horse1_time1, horse1_time2], 96 | // o [horse2_time1, horse2_time2], 97 | // r [horse3_time1, horse3_time2], 98 | // s [horse4_time1, horse4_time2], 99 | // e [horse5_time1, horse5_time2]], 100 | // ) 101 | // 102 | // Next we create a Tensor from a List of times and put it into our 2d Space. 103 | // 104 | use x <- result.try(tensor.from_floats(of: times, into: d2)) 105 | 106 | let announce = function.compose(string.inspect, io.println) 107 | 108 | announce("Trial times per horse") 109 | tensor.print(x) 110 | 111 | // Axes can be referenced by name. 112 | // 113 | // Here we reduce away the "Trial" Axis to get each horse's mean run time. 114 | // 115 | announce("Mean time per horse") 116 | let mean_times = 117 | x 118 | |> tensor.mean(with: fn(a) { axis.name(a) == "Trial" }) 119 | |> tensor.debug 120 | 121 | // This catch-all function will reduce away all Axes, although at this point 122 | // only the "Horse" Axis remains. 123 | // 124 | let all_axes = fn(_) { True } 125 | 126 | // We get a String representation of the minimum mean time. 127 | // 128 | announce("Fastest mean time") 129 | let time = 130 | mean_times 131 | |> tensor.min_over(with: all_axes) 132 | |> tensor.debug 133 | |> tensor.to_string(return: tensor.Data, wrap_at: 0) 134 | 135 | // And we get an index number, followed by the name of the winning horse. 
136 | // 137 | announce("Fastest horse") 138 | use horse <- result.try( 139 | mean_times 140 | |> tensor.arg_min(with: all_axes) 141 | |> tensor.debug 142 | |> tensor.to_int, 143 | ) 144 | use horse <- result.try( 145 | horses 146 | |> list.at(get: horse) 147 | |> result.replace_error(InvalidData), 148 | ) 149 | 150 | // Finally, we make our announcement! 151 | // 152 | { horse <> " wins the day with a mean time of " <> time <> " minutes!" } 153 | |> announce 154 | |> Ok 155 | } 156 | ``` 157 | 158 | ### Example 159 | 160 | ```gleam 161 | > derby.announce_winner( 162 | > from: ["Pony Express", "Hay Girl", "Low Rider"], 163 | > with: [1.2, 1.3, 1.3, 1.0, 1.5, 0.9], 164 | > ) 165 | "Trial times per horse" 166 | Tensor( 167 | Format(Float32), 168 | Space(Axis("Horse", 3), Axis("Trial", 2)), 169 | [[1.2, 1.3], 170 | [1.3, 1.0], 171 | [1.5, 0.9]], 172 | ) 173 | "Mean time per horse" 174 | Tensor( 175 | Format(Float32), 176 | Space(Axis("Horse", 3)), 177 | [1.25, 1.15, 1.2], 178 | ) 179 | "Fastest mean time" 180 | Tensor( 181 | Format(Float32), 182 | Space(), 183 | 1.15, 184 | ) 185 | "Fastest horse" 186 | Tensor( 187 | Format(Float32), 188 | Space(), 189 | 1.0, 190 | ) 191 | "Hay Girl wins the day with a mean time of 1.15 minutes!" 192 | Ok(Nil) 193 | ``` 194 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /src/argamak_ffi.mjs: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs-node"; 2 | import { inspect } from "../gleam_stdlib/gleam_stdlib.mjs"; 3 | import { Error as GleamError, List, Ok, Result, toList } from "./gleam.mjs"; 4 | import { 5 | CannotBroadcast, 6 | IncompatibleShape, 7 | InvalidData, 8 | } from "./argamak/tensor.mjs"; 9 | 10 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 11 | // Constants // 12 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 13 | 14 | const Nil = undefined; 15 | 16 | const Tensor = tf.Tensor.prototype; 17 | 18 | const Extrema = { 19 | int32: { min: -2_147_483_648, max: 2_147_483_647 }, 20 | float32: { 21 | min: -340_282_346_638_528_859_811_704_183_484_516_925_440, 22 | max: 340_282_346_638_528_859_811_704_183_484_516_925_440, 23 | }, 24 | }; 25 | 26 | class Fn { 27 | constructor(f) { 28 | this.f = f; 29 | } 30 | 31 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 32 | // Tensor Methods // 33 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 34 | 35 | static clip_reformat_like(x) { 36 | return x.dtype === this.dtype 37 | // Don't clip if reformat is noop. 38 | ? this 39 | : this.reformat_like(x).clip_based_on(this); 40 | } 41 | 42 | static clip_reformat(x) { 43 | return (x === this.dtype ? 
this : this.cast(x)).clip_based_on(this); 44 | } 45 | 46 | static clip_based_on(x) { 47 | let format = this.dtype; 48 | 49 | let scalar = (x) => tf.scalar(Extrema[format][x], format); 50 | 51 | let min = scalar("min"); 52 | let max = scalar("max"); 53 | 54 | let less = x.less(min); 55 | let greater = x.greater(max); 56 | 57 | Tensor.replace = function (predicate, value) { 58 | return value.where(predicate, this); 59 | }; 60 | 61 | return this.replace(less, min).replace(greater, max); 62 | } 63 | 64 | static reformat_like(x) { 65 | x = x.dtype; 66 | return x === this.dtype ? this : this.cast(x); 67 | } 68 | 69 | static to_number() { 70 | if (Number.isFinite(this)) { 71 | return this; 72 | } else if (tf.util.isScalarShape(this.shape)) { 73 | return this.arraySync(); 74 | } 75 | throw new Error(Nil); 76 | } 77 | 78 | static to_flat_list() { 79 | return toList(this.reshape([-1]).arraySync()); 80 | } 81 | 82 | static add_fn(x) { 83 | this[x] = Fn[x]; 84 | return this; 85 | } 86 | 87 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 88 | // Result Methods // 89 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 90 | 91 | static map(f) { 92 | return this.isOk() ? new Ok(f(this[0])) : this; 93 | } 94 | 95 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 96 | // Fn Methods // 97 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 98 | 99 | broadcast_result() { 100 | return this.result(CannotBroadcast); 101 | } 102 | 103 | shape_result() { 104 | return this.result(IncompatibleShape); 105 | } 106 | 107 | result(error_type = InvalidData) { 108 | try { 109 | let x = this.f(); 110 | // Detect any NaN 111 | return tf.equal(x, x).all().arraySync() 112 | ? 
new Ok(x) 113 | : new GleamError(new InvalidData()); 114 | } catch { 115 | return new GleamError(new error_type()); 116 | } 117 | } 118 | } 119 | 120 | Tensor.add_fn = Fn.add_fn; 121 | 122 | Tensor.add_fn("clip_reformat_like") 123 | .add_fn("clip_reformat") 124 | .add_fn("clip_based_on") 125 | .add_fn("reformat_like") 126 | .add_fn("to_number") 127 | .add_fn("to_flat_list"); 128 | 129 | Result.prototype.map = Fn.map; 130 | 131 | const fn = (f) => new Fn(f); 132 | 133 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 134 | // Tensor Creation Functions // 135 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 136 | 137 | export const tensor = (x, format) => 138 | fn(() => { 139 | if (x instanceof List) { 140 | x = x.toArray(); 141 | if (!x.length) { 142 | throw new Error(Nil); 143 | } 144 | } 145 | let shape = Array.isArray(x) ? [x.length] : []; 146 | return tf.tensor(x, shape, format); 147 | }).result(); 148 | 149 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 150 | // Tensor Reflection Functions // 151 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 152 | 153 | export const size = (x) => x.size; 154 | 155 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 156 | // Tensor Transformation Functions // 157 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 158 | 159 | export const reformat = (x, format) => format === x.dtype ? 
x : x.cast(format); 160 | 161 | export const reshape = (x, shape) => 162 | fn(() => x.reshape(shape.toArray())).shape_result(); 163 | 164 | export const broadcast = (x, shape) => 165 | fn(() => x.broadcastTo(shape.toArray())).shape_result(); 166 | 167 | export const squeeze = (x, i) => x.squeeze(i.toArray()); 168 | 169 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 170 | // Tensor Logical Functions // 171 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 172 | 173 | export const equal = (a, b) => 174 | fn(() => a.equal(b).reformat_like(a)).broadcast_result(); 175 | 176 | export const not_equal = (a, b) => 177 | fn(() => a.notEqual(b).reformat_like(a)).broadcast_result(); 178 | 179 | export const greater = (a, b) => 180 | fn(() => a.greater(b).reformat_like(a)).broadcast_result(); 181 | 182 | export const greater_or_equal = (a, b) => 183 | fn(() => a.greaterEqual(b).reformat_like(a)).broadcast_result(); 184 | 185 | export const less = (a, b) => 186 | fn(() => a.less(b).reformat_like(a)).broadcast_result(); 187 | 188 | export const less_or_equal = (a, b) => 189 | fn(() => a.lessEqual(b).reformat_like(a)).broadcast_result(); 190 | 191 | export const logical_and = (a, b) => 192 | fn(() => a.cast("bool").logicalAnd(b.cast("bool")).reformat_like(a)) 193 | .broadcast_result(); 194 | 195 | export const logical_or = (a, b) => 196 | fn(() => a.cast("bool").logicalOr(b.cast("bool")).reformat_like(a)) 197 | .broadcast_result(); 198 | 199 | export const logical_xor = (a, b) => 200 | fn(() => a.cast("bool").logicalXor(b.cast("bool")).reformat_like(a)) 201 | .broadcast_result(); 202 | 203 | export const logical_not = (x) => x.cast("bool").logicalNot().reformat_like(x); 204 | 205 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 206 | // Tensor Arithmetic Functions // 207 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 208 | 209 | export const add = (a, b) => fn(() => a.add(b)).broadcast_result(); 210 | 211 | export const subtract = (a, b) => fn(() => a.sub(b)).broadcast_result(); 212 | 
213 | export const multiply = (a, b) => fn(() => a.mul(b)).broadcast_result(); 214 | 215 | export const divide = (a, b) => 216 | fn(() => a.div(b)) 217 | .broadcast_result() 218 | .map((x) => x.clip_reformat_like(a)); 219 | 220 | export const modulo = (a, b) => fn(() => a.mod(b)).broadcast_result(); 221 | 222 | export const power = (a, b) => fn(() => a.pow(b)).broadcast_result(); 223 | 224 | export const max = (a, b) => fn(() => a.maximum(b)).broadcast_result(); 225 | 226 | export const min = (a, b) => fn(() => a.minimum(b)).broadcast_result(); 227 | 228 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 229 | // Tensor Basic Math Functions // 230 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 231 | 232 | export const absolute_value = tf.abs; 233 | 234 | export const negate = tf.neg; 235 | 236 | export const sign = tf.sign; 237 | 238 | export const ceiling = (x) => x.cast("float32").ceil().reformat_like(x); 239 | 240 | export const floor = (x) => x.cast("float32").floor().reformat_like(x); 241 | 242 | export const round = tf.round; 243 | 244 | export const exp = (x) => x.exp().clip_reformat_like(x); 245 | 246 | export const square_root = (x) => 247 | fn(() => x.cast("float32").sqrt()) 248 | .result() 249 | .map((y) => y.clip_reformat_like(x)); 250 | 251 | export const ln = (x) => 252 | fn(() => x.cast("float32").log()) 253 | .result() 254 | .map((y) => y.clip_reformat_like(x)); 255 | 256 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 257 | // Tensor Reduction Functions // 258 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 259 | 260 | export const all = (x, i) => x.cast("bool").all(i.toArray()).reformat_like(x); 261 | 262 | export const any = (x, i) => x.cast("bool").any(i.toArray()).reformat_like(x); 263 | 264 | export const arg_max = (x, i) => x.argMax(i).reformat_like(x); 265 | 266 | export const arg_min = (x, i) => x.argMin(i).reformat_like(x); 267 | 268 | export const max_over = (x, i) => x.max(i.toArray()); 269 | 270 | export const min_over = (x, i) => 
x.min(i.toArray()); 271 | 272 | export const sum = (x, i) => x.sum(i.toArray()).clip_reformat_like(x); 273 | 274 | export const product = (x, i) => x.prod(i.toArray()).clip_reformat_like(x); 275 | 276 | export const mean = (x, i) => x.mean(i.toArray()).reformat_like(x); 277 | 278 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 279 | // Tensor Slicing & Joining Functions // 280 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 281 | 282 | export const concat = (xs, i) => tf.concat(xs.toArray(), i); 283 | 284 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 285 | // Tensor Conversion Functions // 286 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 287 | 288 | export const to_float = (x) => 289 | fn(() => x.clip_reformat("float32").to_number()).shape_result(); 290 | 291 | export const to_int = (x) => 292 | fn(() => x.clip_reformat("int32").to_number()).shape_result(); 293 | 294 | export const to_floats = (x) => x.clip_reformat("float32").to_flat_list(); 295 | 296 | export const to_ints = (x) => x.clip_reformat("int32").to_flat_list(); 297 | 298 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 299 | // Tensor Utility Functions // 300 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 301 | 302 | export function prepare_to_string(x) { 303 | let is_int = "int32" === x.dtype; 304 | 305 | let [xs, item_width] = x 306 | .reshape([-1]) 307 | .arraySync() 308 | .reduce( 309 | ([xs, item_width], x) => { 310 | x = Number(x).toFixed(3); 311 | x = x.includes("e") ? x : trim_trailing(x, "0"); 312 | if (is_int) { 313 | x = trim_trailing(x, "."); 314 | } else { 315 | x = x.substr(-1) === "." ? `${x}0` : x; 316 | } 317 | return [[...xs, x], Math.max(x.length, item_width)]; 318 | }, 319 | [[], 0], 320 | ); 321 | 322 | return [toList(xs), item_width]; 323 | } 324 | 325 | export function columns() { 326 | let stdout = process.stdout; 327 | return stdout.isTTY ? 
stdout.columns : 0; 328 | } 329 | 330 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 331 | // Format Functions // 332 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 333 | 334 | export const format_to_native = (x) => inspect(x).toLowerCase(); 335 | 336 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 337 | // Private Functions // 338 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 339 | 340 | function trim_trailing(x, character) { 341 | x = `${x}`; 342 | let i = x.length; 343 | while (x.charAt(--i) === character) { 344 | x = x.slice(0, i); 345 | } 346 | return x; 347 | } 348 | -------------------------------------------------------------------------------- /test/argamak/space_test.gleam: -------------------------------------------------------------------------------- 1 | import gleam/list 2 | import gleeunit/should 3 | import argamak/axis.{A, Axis, B, C, D, E, Infer, Z} 4 | import argamak/space.{ 5 | CannotInfer, CannotMerge, DuplicateName, InvalidSize, SpaceError, 6 | } 7 | 8 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 9 | // Creation Functions // 10 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 11 | 12 | pub fn new_test() { 13 | space.new() 14 | |> space.axes 15 | |> should.equal([]) 16 | } 17 | 18 | pub fn d1_test() { 19 | let a = A(size: 1) 20 | let axis = Axis(name: "Sparkle", size: 9) 21 | let infer = Infer(name: "Shine") 22 | 23 | a 24 | |> space.d1 25 | |> should.be_ok 26 | 27 | axis 28 | |> space.d1 29 | |> should.be_ok 30 | 31 | infer 32 | |> space.d1 33 | |> should.be_ok 34 | 35 | let a = A(size: 0) 36 | space.d1(a) 37 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 38 | 39 | let a = A(size: -1) 40 | space.d1(a) 41 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 42 | } 43 | 44 | pub fn d2_test() { 45 | let a = A(size: 1) 46 | let b = B(size: 3) 47 | let axis = Axis(name: "Sparkle", size: 9) 48 | let infer = Infer(name: "Shine") 49 | 50 | space.d2(a, axis) 51 | |> should.be_ok 52 | 53 | space.d2(a, infer) 54 | |> should.be_ok 55 | 56 | 
let a = A(size: 0) 57 | space.d2(a, b) 58 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 59 | 60 | let a = A(size: -1) 61 | space.d2(a, b) 62 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 63 | 64 | let axis_a = Axis(name: "A", size: 1) 65 | space.d2(a, axis_a) 66 | |> should.equal( 67 | Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 68 | ) 69 | 70 | space.d2(Infer(name: "A"), infer) 71 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 72 | } 73 | 74 | pub fn d3_test() { 75 | let a = A(size: 1) 76 | let b = B(size: 3) 77 | let axis = Axis(name: "Sparkle", size: 9) 78 | let infer = Infer(name: "Shine") 79 | 80 | space.d3(a, b, axis) 81 | |> should.be_ok 82 | 83 | space.d3(a, infer, axis) 84 | |> should.be_ok 85 | 86 | let a = A(size: 0) 87 | space.d3(a, b, axis) 88 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 89 | 90 | let a = A(size: -1) 91 | space.d3(a, b, infer) 92 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 93 | 94 | let axis_a = Axis(name: "A", size: 1) 95 | space.d3(a, axis_a, axis) 96 | |> should.equal( 97 | Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 98 | ) 99 | 100 | space.d3(Infer(name: "A"), infer, axis) 101 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 102 | } 103 | 104 | pub fn d4_test() { 105 | let a = A(size: 1) 106 | let b = B(size: 3) 107 | let c = C(size: 9) 108 | let axis = Axis(name: "Sparkle", size: 9) 109 | let infer = Infer(name: "Shine") 110 | 111 | space.d4(a, b, c, axis) 112 | |> should.be_ok 113 | 114 | space.d4(a, b, infer, axis) 115 | |> should.be_ok 116 | 117 | let a = A(size: 0) 118 | space.d4(a, b, c, axis) 119 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 120 | 121 | let a = A(size: -1) 122 | space.d4(a, b, c, infer) 123 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 124 | 125 | let axis_a = Axis(name: "A", size: 1) 126 | space.d4(a, b, axis_a, axis) 127 | |> should.equal( 128 | 
Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 129 | ) 130 | 131 | space.d4(Infer(name: "A"), infer, b, axis) 132 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 133 | } 134 | 135 | pub fn d5_test() { 136 | let a = A(size: 1) 137 | let b = B(size: 3) 138 | let c = C(size: 9) 139 | let d = D(size: 27) 140 | let axis = Axis(name: "Sparkle", size: 9) 141 | let infer = Infer(name: "Shine") 142 | 143 | space.d5(a, b, c, d, axis) 144 | |> should.be_ok 145 | 146 | space.d5(a, b, c, infer, axis) 147 | |> should.be_ok 148 | 149 | let a = A(size: 0) 150 | space.d5(a, b, c, d, axis) 151 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 152 | 153 | let a = A(size: -1) 154 | space.d5(a, b, c, d, infer) 155 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 156 | 157 | let axis_a = Axis(name: "A", size: 1) 158 | space.d5(a, b, c, axis_a, axis) 159 | |> should.equal( 160 | Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 161 | ) 162 | 163 | space.d5(Infer(name: "A"), infer, b, c, axis) 164 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 165 | } 166 | 167 | pub fn d6_test() { 168 | let a = A(size: 1) 169 | let b = B(size: 3) 170 | let c = C(size: 9) 171 | let d = D(size: 27) 172 | let e = E(size: 81) 173 | let axis = Axis(name: "Sparkle", size: 9) 174 | let infer = Infer(name: "Shine") 175 | 176 | space.d6(a, b, c, d, e, axis) 177 | |> should.be_ok 178 | 179 | space.d6(a, b, c, d, infer, axis) 180 | |> should.be_ok 181 | 182 | let a = A(size: 0) 183 | space.d6(a, b, c, d, e, axis) 184 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 185 | 186 | let a = A(size: -1) 187 | space.d6(a, b, c, d, e, infer) 188 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 189 | 190 | let axis_a = Axis(name: "A", size: 1) 191 | space.d6(a, b, c, d, axis_a, axis) 192 | |> should.equal( 193 | Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 194 | ) 195 | 196 | 
space.d6(Infer(name: "A"), infer, b, c, d, axis) 197 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 198 | } 199 | 200 | pub fn from_list_test() { 201 | let a = A(size: 1) 202 | let b = B(size: 3) 203 | let c = C(size: 9) 204 | let d = D(size: 27) 205 | let e = E(size: 81) 206 | let z = Z(size: 243) 207 | let axis = Axis(name: "Sparkle", size: 9) 208 | let infer = Infer(name: "Shine") 209 | 210 | [a, b, c, d, e, z, axis] 211 | |> space.from_list 212 | |> should.be_ok 213 | 214 | [a, b, c, d, e, infer, axis] 215 | |> space.from_list 216 | |> should.be_ok 217 | 218 | let a = A(size: 0) 219 | [a, b, c, d, e, z, axis] 220 | |> space.from_list 221 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 222 | 223 | let a = A(size: -1) 224 | [a, b, c, d, e, z, infer] 225 | |> space.from_list 226 | |> should.equal(Error([SpaceError(InvalidSize, [a])])) 227 | 228 | let axis_a = Axis(name: "A", size: 1) 229 | [a, b, c, d, e, axis_a, axis] 230 | |> space.from_list 231 | |> should.equal( 232 | Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), 233 | ) 234 | 235 | [Infer(name: "A"), infer, b, c, d, e, axis] 236 | |> space.from_list 237 | |> should.equal(Error([SpaceError(CannotInfer, [infer])])) 238 | } 239 | 240 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 241 | // Reflection Functions // 242 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 243 | 244 | pub fn axes_test() { 245 | space.new() 246 | |> space.axes 247 | |> should.equal([]) 248 | 249 | let a = A(size: 1) 250 | 251 | let assert Ok(d1) = space.d1(a) 252 | d1 253 | |> space.axes 254 | |> should.equal([a]) 255 | 256 | let xs = [ 257 | a, 258 | B(size: 3), 259 | C(size: 9), 260 | D(size: 27), 261 | E(size: 81), 262 | Z(size: 243), 263 | Infer(name: "Shine"), 264 | Axis(name: "Sparkle", size: 9), 265 | ] 266 | let assert Ok(d8) = space.from_list(xs) 267 | d8 268 | |> space.axes 269 | |> should.equal(xs) 270 | } 271 | 272 | pub fn degree_test() { 273 | space.new() 274 | |> 
space.degree 275 | |> should.equal(0) 276 | 277 | let a = A(size: 1) 278 | 279 | let assert Ok(d1) = space.d1(a) 280 | d1 281 | |> space.degree 282 | |> should.equal(1) 283 | 284 | let xs = [ 285 | a, 286 | B(size: 3), 287 | C(size: 9), 288 | D(size: 27), 289 | E(size: 81), 290 | Z(size: 243), 291 | Infer(name: "Shine"), 292 | Axis(name: "Sparkle", size: 9), 293 | ] 294 | let assert Ok(d8) = space.from_list(xs) 295 | d8 296 | |> space.degree 297 | |> should.equal(8) 298 | } 299 | 300 | pub fn shape_test() { 301 | space.new() 302 | |> space.shape 303 | |> should.equal([]) 304 | 305 | let a = A(size: 1) 306 | 307 | let assert Ok(d1) = space.d1(a) 308 | d1 309 | |> space.shape 310 | |> should.equal([1]) 311 | 312 | let xs = [ 313 | a, 314 | B(size: 3), 315 | C(size: 9), 316 | D(size: 27), 317 | E(size: 81), 318 | Z(size: 243), 319 | Infer(name: "Shine"), 320 | Axis(name: "Sparkle", size: 9), 321 | ] 322 | let assert Ok(d8) = space.from_list(xs) 323 | d8 324 | |> space.shape 325 | |> should.equal([1, 3, 9, 27, 81, 243, 0, 9]) 326 | } 327 | 328 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 329 | // Transformation Functions // 330 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 331 | 332 | pub fn map_test() { 333 | let resize = axis.resize(_, 3) 334 | 335 | let assert Ok(d0) = 336 | space.new() 337 | |> space.map(with: resize) 338 | d0 339 | |> space.axes 340 | |> should.equal([]) 341 | 342 | let a = A(size: 1) 343 | let assert Ok(d1) = space.d1(a) 344 | let assert Ok(d1) = space.map(d1, with: resize) 345 | d1 346 | |> space.axes 347 | |> should.equal([A(size: 3)]) 348 | 349 | let xs = [ 350 | a, 351 | B(size: 3), 352 | C(size: 9), 353 | D(size: 27), 354 | E(size: 81), 355 | Z(size: 243), 356 | Infer(name: "Shine"), 357 | Axis(name: "Sparkle", size: 9), 358 | ] 359 | let assert Ok(d8) = space.from_list(xs) 360 | let assert Ok(d8) = space.map(d8, with: resize) 361 | 362 | d8 363 | |> space.axes 364 | |> list.map(with: axis.size) 365 | |> list.all(fn(x) { x == 3 }) 366 | 
|> should.be_true 367 | 368 | d8 369 | |> space.map(with: axis.resize(_, 0)) 370 | |> should.be_error 371 | 372 | d8 373 | |> space.map(with: axis.rename(_, "A")) 374 | |> should.be_error 375 | 376 | d8 377 | |> space.map(with: fn(a) { 378 | a 379 | |> axis.name 380 | |> Infer 381 | }) 382 | |> should.be_error 383 | } 384 | 385 | pub fn merge_test() { 386 | let a = A(size: 1) 387 | let infer = Infer(name: "Shine") 388 | let axis = Axis(name: "Sparkle", size: 9) 389 | 390 | let d0 = space.new() 391 | 392 | let assert Ok(d1) = space.d1(a) 393 | let assert Ok(d1) = space.merge(d1, d0) 394 | d1 395 | |> space.axes 396 | |> should.equal([a]) 397 | 398 | let assert Ok(d2) = space.d2(a, axis) 399 | let assert Ok(d2) = space.merge(d2, d0) 400 | d2 401 | |> space.axes 402 | |> should.equal([a, axis]) 403 | 404 | let assert Ok(d3) = space.d3(a, infer, axis) 405 | let assert Ok(d3) = space.merge(d3, d0) 406 | d3 407 | |> space.axes 408 | |> should.equal([a, infer, axis]) 409 | 410 | let assert Ok(d1_axis) = space.d1(Axis(name: "A", size: 9)) 411 | let assert Ok(d1) = space.merge(d1, d1_axis) 412 | d1 413 | |> space.axes 414 | |> should.equal([A(size: 9)]) 415 | 416 | let assert Ok(d1) = 417 | axis 418 | |> axis.resize(1) 419 | |> space.d1 420 | let assert Ok(d3) = space.merge(d3, d1) 421 | d3 422 | |> space.axes 423 | |> should.equal([a, infer, axis]) 424 | 425 | let assert Ok(d1) = 426 | axis 427 | |> axis.name 428 | |> Infer 429 | |> space.d1 430 | let assert Ok(d3) = space.map(d3, with: axis.resize(_, 1)) 431 | let assert Ok(d3) = space.merge(d1, d3) 432 | d3 433 | |> space.axes 434 | |> should.equal([a, Axis(name: "Shine", size: 1), Infer(name: "Sparkle")]) 435 | 436 | let assert Ok(d1) = space.d1(a) 437 | space.merge(d1, d3) 438 | |> should.equal( 439 | Error([SpaceError(CannotMerge, [Infer(name: "Sparkle"), A(size: 1)])]), 440 | ) 441 | 442 | let assert Ok(d2) = space.d2(Infer(name: "Shine"), axis) 443 | space.merge(d3, d2) 444 | |> 
should.equal(Error([SpaceError(CannotInfer, [Infer(name: "Sparkle")])])) 445 | } 446 | 447 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 448 | // Conversion Functions // 449 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 450 | 451 | pub fn to_string_test() { 452 | space.new() 453 | |> space.to_string 454 | |> should.equal("Space()") 455 | 456 | let a = A(size: 1) 457 | 458 | let assert Ok(d1) = space.d1(a) 459 | d1 460 | |> space.to_string 461 | |> should.equal("Space(A(1))") 462 | 463 | let xs = [ 464 | a, 465 | B(size: 3), 466 | C(size: 9), 467 | D(size: 27), 468 | E(size: 81), 469 | Z(size: 243), 470 | Infer(name: "Shine"), 471 | Axis(name: "Sparkle", size: 9), 472 | ] 473 | let assert Ok(d8) = space.from_list(xs) 474 | d8 475 | |> space.to_string 476 | |> should.equal( 477 | "Space(A(1), B(3), C(9), D(27), E(81), Z(243), Infer(\"Shine\"), Axis(\"Sparkle\", 9))", 478 | ) 479 | } 480 | -------------------------------------------------------------------------------- /src/argamak/space.gleam: -------------------------------------------------------------------------------- 1 | import gleam/dict 2 | import gleam/int 3 | import gleam/list 4 | import gleam/result 5 | import gleam/string 6 | import argamak/axis.{type Axes, type Axis, Axis, Infer} 7 | 8 | /// An n-dimensional `Space` containing `Axes` of various sizes. 9 | /// 10 | pub opaque type Space { 11 | Space(axes: Axes) 12 | } 13 | 14 | /// An error returned when attempting to create an invalid `Space`. 15 | /// 16 | pub type SpaceError { 17 | CannotMerge 18 | CannotInfer 19 | DuplicateName 20 | InvalidSize 21 | SpaceError(reason: SpaceError, axes: Axes) 22 | } 23 | 24 | /// A `SpaceError` list. 25 | /// 26 | pub type SpaceErrors = 27 | List(SpaceError) 28 | 29 | /// A `Result` alias type for spaces. 
30 | /// 31 | pub type SpaceResult = 32 | Result(Space, SpaceErrors) 33 | 34 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 35 | // Creation Functions // 36 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 37 | 38 | /// Results in a dimensionless `Space`. 39 | /// 40 | /// ## Examples 41 | /// 42 | /// ```gleam 43 | /// > new() |> axes 44 | /// [] 45 | /// ``` 46 | /// 47 | pub fn new() -> Space { 48 | Space(axes: []) 49 | } 50 | 51 | /// Results in a one-dimensional `Space` on success, or `SpaceErrors` on 52 | /// failure. 53 | /// 54 | /// ## Examples 55 | /// 56 | /// ```gleam 57 | /// > import argamak/axis.{Infer} 58 | /// > let assert Ok(space) = d1(Infer("A")) 59 | /// > axes(space) 60 | /// [Infer("A")] 61 | /// ``` 62 | /// 63 | pub fn d1(a: Axis) -> SpaceResult { 64 | [a] 65 | |> Space 66 | |> validate 67 | } 68 | 69 | /// Results in a two-dimensional `Space` on success, or `SpaceErrors` on 70 | /// failure. 71 | /// 72 | /// ## Examples 73 | /// 74 | /// ```gleam 75 | /// > import argamak/axis.{A, B} 76 | /// > let assert Ok(space) = d2(A(2), B(2)) 77 | /// > axes(space) 78 | /// [A(2), B(2)] 79 | /// ``` 80 | /// 81 | pub fn d2(a: Axis, b: Axis) -> SpaceResult { 82 | [a, b] 83 | |> Space 84 | |> validate 85 | } 86 | 87 | /// Results in a three-dimensional `Space` on success, or `SpaceErrors` on 88 | /// failure. 89 | /// 90 | /// ## Examples 91 | /// 92 | /// ```gleam 93 | /// > import argamak/axis.{A, B, Infer} 94 | /// > let assert Ok(space) = d3(A(2), B(2), Infer("C")) 95 | /// > axes(space) 96 | /// [A(2), B(2), Infer("C")] 97 | /// ``` 98 | /// 99 | pub fn d3(a: Axis, b: Axis, c: Axis) -> SpaceResult { 100 | [a, b, c] 101 | |> Space 102 | |> validate 103 | } 104 | 105 | /// Results in a four-dimensional `Space` on success, or `SpaceErrors` on 106 | /// failure. 
107 | /// 108 | /// ## Examples 109 | /// 110 | /// ```gleam 111 | /// > import argamak/axis.{A, B, D, Infer} 112 | /// > let assert Ok(space) = d4(A(2), B(2), Infer("C"), D(1)) 113 | /// > axes(space) 114 | /// [A(2), B(2), Infer("C"), D(1)] 115 | /// ``` 116 | /// 117 | pub fn d4(a: Axis, b: Axis, c: Axis, d: Axis) -> SpaceResult { 118 | [a, b, c, d] 119 | |> Space 120 | |> validate 121 | } 122 | 123 | /// Results in a five-dimensional `Space` on success, or `SpaceErrors` on 124 | /// failure. 125 | /// 126 | /// ## Examples 127 | /// 128 | /// ```gleam 129 | /// > import argamak/axis.{A, B, C, D, E} 130 | /// > let assert Ok(space) = d5(A(5), B(4), C(3), D(2), E(1)) 131 | /// > axes(space) 132 | /// [A(5), B(4), C(3), D(2), E(1)] 133 | /// ``` 134 | /// 135 | pub fn d5(a: Axis, b: Axis, c: Axis, d: Axis, e: Axis) -> SpaceResult { 136 | [a, b, c, d, e] 137 | |> Space 138 | |> validate 139 | } 140 | 141 | /// Results in a six-dimensional `Space` on success, or `SpaceErrors` on 142 | /// failure. 143 | /// 144 | /// ## Examples 145 | /// 146 | /// ```gleam 147 | /// > import argamak/axis.{A, B, C, D, E, F} 148 | /// > let assert Ok(space) = d6(A(9), B(9), C(9), D(9), E(9), F(9)) 149 | /// > axes(space) 150 | /// [A(9), B(9), C(9), D(9), E(9), F(9)] 151 | /// ``` 152 | /// 153 | pub fn d6(a: Axis, b: Axis, c: Axis, d: Axis, e: Axis, f: Axis) -> SpaceResult { 154 | [a, b, c, d, e, f] 155 | |> Space 156 | |> validate 157 | } 158 | 159 | /// Results in a `Space` created from a list of `Axes` on success, or 160 | /// `SpaceErrors` on failure. 
161 | /// 162 | /// ## Examples 163 | /// 164 | /// ```gleam 165 | /// > import argamak/axis.{A, B, C, D, E, F, Z} 166 | /// > let assert Ok(space) = from_list([A(9), B(9), C(9), D(9), E(9), F(9), Z(9)]) 167 | /// > axes(space) 168 | /// [A(9), B(9), C(9), D(9), E(9), F(9), Z(9)] 169 | /// ``` 170 | /// 171 | pub fn from_list(x: Axes) -> SpaceResult { 172 | x 173 | |> Space 174 | |> validate 175 | } 176 | 177 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 178 | // Reflection Functions // 179 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 180 | 181 | /// Returns the axes of a given `Space`. 182 | /// 183 | /// ## Examples 184 | /// 185 | /// ```gleam 186 | /// > new() |> axes 187 | /// [] 188 | /// 189 | /// > import argamak/axis.{A, B, Infer} 190 | /// > let assert Ok(space) = d1(Infer("A")) 191 | /// > axes(space) 192 | /// [Infer("A")] 193 | /// 194 | /// > let assert Ok(space) = d3(A(2), B(2), Infer("C")) 195 | /// > axes(space) 196 | /// [A(2), B(2), Infer("C")] 197 | /// ``` 198 | /// 199 | pub fn axes(x: Space) -> Axes { 200 | x.axes 201 | } 202 | 203 | /// Returns the degree of a given `Space`. 204 | /// 205 | /// ## Examples 206 | /// 207 | /// ```gleam 208 | /// > new() |> degree 209 | /// 0 210 | /// 211 | /// > import argamak/axis.{A, B, Infer} 212 | /// > let assert Ok(space) = d1(Infer("A")) 213 | /// > degree(space) 214 | /// 1 215 | /// 216 | /// > let assert Ok(space) = d3(A(2), B(2), Infer("C")) 217 | /// > degree(space) 218 | /// 3 219 | /// ``` 220 | /// 221 | pub fn degree(x: Space) -> Int { 222 | x 223 | |> axes 224 | |> list.length 225 | } 226 | 227 | /// Returns the shape of a given `Space`. 
228 | /// 229 | /// ## Examples 230 | /// 231 | /// ```gleam 232 | /// > new() |> shape 233 | /// [] 234 | /// 235 | /// > import argamak/axis.{A, B, Infer} 236 | /// > let assert Ok(space) = d1(Infer("A")) 237 | /// > shape(space) 238 | /// [0] 239 | /// 240 | /// > let assert Ok(space) = d3(A(2), B(2), Infer("C")) 241 | /// > shape(space) 242 | /// [2, 2, 0] 243 | /// ``` 244 | /// 245 | pub fn shape(x: Space) -> List(Int) { 246 | x 247 | |> axes 248 | |> list.map(with: axis.size) 249 | } 250 | 251 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 252 | // Transformation Functions // 253 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 254 | 255 | /// Results in a new `Space` with the same number of dimensions as the given 256 | /// `Space` on success, or `SpaceErrors` on failure. 257 | /// 258 | /// Applies the given function to each `Axis` of the `Space`. 259 | /// 260 | /// ## Examples 261 | /// 262 | /// ```gleam 263 | /// > import argamak/axis.{B, C, Infer} 264 | /// > let assert Ok(space) = map(new(), with: fn(_) { C(3) }) 265 | /// > axes(space) 266 | /// [] 267 | /// 268 | /// > let assert Ok(space) = d1(Infer("A")) 269 | /// > let assert Ok(space) = map(space, with: fn(_) { C(3) }) 270 | /// > axes(space) 271 | /// [C(3)] 272 | /// 273 | /// > let assert Ok(space) = d3(Infer("A"), B(2), C(2)) 274 | /// > let assert Ok(space) = map(space, with: fn(axis) { 275 | /// > case axis { 276 | /// > Infer(_) -> axis.resize(axis, 4) 277 | /// > _else -> axis 278 | /// > } 279 | /// > }) 280 | /// > axes(space) 281 | /// [A(4), B(2), C(2)] 282 | /// ``` 283 | /// 284 | pub fn map(x: Space, with fun: fn(Axis) -> Axis) -> SpaceResult { 285 | x 286 | |> axes 287 | |> list.map(with: fun) 288 | |> Space 289 | |> validate 290 | } 291 | 292 | /// Results in a new `Space` that is the element-wise maximum of the given 293 | /// spaces on success, or `SpaceErrors` on failure. 294 | /// 295 | /// Spaces are merged tail-first, and corresponding `Axis` names must match. 
296 | /// 297 | /// ## Examples 298 | /// 299 | /// ```gleam 300 | /// > import argamak/axis.{Axis, Infer, X, Y} 301 | /// > let assert Ok(a) = d1(Infer("X")) 302 | /// > merge(a, new()) |> result.map(with: axes) 303 | /// Ok([Infer("X")]) 304 | /// 305 | /// > let assert Ok(b) = d2(Axis("Sparkle", 2), X(2)) 306 | /// > merge(a, b) |> result.map(with: axes) 307 | /// Ok([Axis("Sparkle", 2), Infer("X")]) 308 | /// 309 | /// > let assert Ok(c) = d3(Infer("X"), Axis("Sparkle", 3), Y(3)) 310 | /// > merge(b, c) 311 | /// Error([SpaceError(CannotMerge, [Y(3), X(2)])]) 312 | /// ``` 313 | /// 314 | pub fn merge(a: Space, b: Space) -> SpaceResult { 315 | let index = fn(x: Space) { 316 | x 317 | |> axes 318 | |> list.index_map(with: fn(axis, index) { #(index, axis) }) 319 | |> dict.from_list 320 | } 321 | let a_index = index(a) 322 | let b_index = index(b) 323 | 324 | let a_size = dict.size(a_index) 325 | let b_size = dict.size(b_index) 326 | 327 | let #(x, dict) = case a_size < b_size { 328 | True -> #(axes(b), a_index) 329 | False -> #(axes(a), b_index) 330 | } 331 | let offset = int.absolute_value(a_size - b_size) 332 | 333 | let #(x, errors) = 334 | x 335 | |> list.index_map(with: fn(a_axis, index) { 336 | let b_axis = 337 | dict 338 | |> dict.get(index - offset) 339 | |> result.unwrap(or: a_axis) 340 | let a_name = axis.name(a_axis) 341 | let b_name = axis.name(b_axis) 342 | let a_size = axis.size(a_axis) 343 | let b_size = axis.size(b_axis) 344 | let should_infer = a_axis == Infer(a_name) || b_axis == Infer(b_name) 345 | case a_name == b_name { 346 | True if should_infer -> 347 | a_name 348 | |> Infer 349 | |> Ok 350 | True -> 351 | a_axis 352 | |> axis.resize(int.max(a_size, b_size)) 353 | |> Ok 354 | False -> 355 | CannotMerge 356 | |> SpaceError(axes: [a_axis, b_axis]) 357 | |> Error 358 | } 359 | }) 360 | |> list.partition(with: result.is_ok) 361 | 362 | use x <- result.try(case errors { 363 | [] -> 364 | x 365 | |> result.all 366 | |> result.map_error(with: 
fn(error) { [error] }) 367 | _else -> 368 | errors 369 | |> list.map(with: fn(x) { 370 | let assert Error(x) = x 371 | x 372 | }) 373 | |> Error 374 | }) 375 | 376 | x 377 | |> Space 378 | |> validate 379 | } 380 | 381 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 382 | // Conversion Functions // 383 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 384 | 385 | /// Converts a `Space` into a `String`. 386 | /// 387 | /// ## Examples 388 | /// 389 | /// ```gleam 390 | /// > new() |> to_string 391 | /// "Space()" 392 | /// 393 | /// > import argamak/axis.{A, B, Axis, Infer} 394 | /// > let assert Ok(space) = d1(Axis("Sparkle", 2)) 395 | /// > to_string(space) 396 | /// "Space(Axis(\"Sparkle\", 2))" 397 | /// 398 | /// > let assert Ok(space) = d3(A(2), B(2), Infer("C")) 399 | /// > to_string(space) 400 | /// "Space(A(2), B(2), Infer(\"C\"))" 401 | /// ``` 402 | /// 403 | pub fn to_string(x: Space) -> String { 404 | let axes = { 405 | use x <- list.map(axes(x)) 406 | let name = axis.name(x) 407 | let size = 408 | x 409 | |> axis.size 410 | |> int.to_string 411 | case x { 412 | Axis(..) -> "Axis(\"" <> name <> "\", " <> size <> ")" 413 | Infer(_) -> "Infer(\"" <> name <> "\")" 414 | _else -> name <> "(" <> size <> ")" 415 | } 416 | } 417 | "Space(" <> string.join(axes, with: ", ") <> ")" 418 | } 419 | 420 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 421 | // Private Functions // 422 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 423 | 424 | /// Results in the given `Space` on success, or `SpaceErrors` on failure. 425 | /// 426 | /// Ensures that no axes are duplicated, that there is at most a single 427 | /// inferred dimension size, and that no other dimension sizes are less than 428 | /// one. 
429 | /// 430 | /// ## Examples 431 | /// 432 | /// ```gleam 433 | /// > import argamak/axis.{Infer, X, Y, Z} 434 | /// > validate(d3(X(1), Infer("Y"), Z(1))) 435 | /// Ok(space) 436 | /// 437 | /// > validate(d2(X(1), Infer("X"))) 438 | /// Error([SpaceError(DuplicateName, [Infer("X")])]) 439 | /// 440 | /// > validate(d2(Infer("X"), Infer("Y"))) 441 | /// Error([SpaceError(CannotInfer, [Infer("Y")])]) 442 | /// 443 | /// > validate(d2(X(0), Y(1))) 444 | /// Error([SpaceError(InvalidSize, [X(0)])]) 445 | /// 446 | /// > validate(d3(X(-2), Infer("X"), Infer("Z"))) 447 | /// Error([ 448 | /// SpaceError(InvalidSize, [X(-2)]), 449 | /// SpaceError(DuplicateName, [Infer("X")]), 450 | /// SpaceError(CannotInfer, [Infer("Z")]), 451 | /// ]) 452 | /// ``` 453 | /// 454 | fn validate(space: Space) -> SpaceResult { 455 | let ValidateAcc(_, _, results: results) = { 456 | use acc, axis <- list.fold( 457 | over: axes(space), 458 | from: ValidateAcc(names: [], inferred: False, results: []), 459 | ) 460 | let name = axis.name(axis) 461 | let size = axis.size(axis) 462 | let errors = 463 | [ 464 | Invalid(error: DuplicateName, when: list.contains(acc.names, any: name)), 465 | Invalid(error: CannotInfer, when: acc.inferred && axis == Infer(name)), 466 | Invalid(error: InvalidSize, when: size < 1 && axis != Infer(name)), 467 | ] 468 | |> list.map(with: fn(invalid) { 469 | case invalid.when { 470 | False -> [] 471 | True -> [SpaceError(reason: invalid.error, axes: [axis])] 472 | } 473 | }) 474 | |> list.flatten 475 | let result = case errors { 476 | [] -> Ok(axis) 477 | _else -> Error(errors) 478 | } 479 | ValidateAcc( 480 | names: [name, ..acc.names], 481 | inferred: acc.inferred || axis == Infer(name), 482 | results: [result, ..acc.results], 483 | ) 484 | } 485 | 486 | case list.any(in: results, satisfying: result.is_error) { 487 | False -> Ok(space) 488 | True -> 489 | results 490 | |> list.reverse 491 | |> list.map(with: fn(result) { 492 | case result { 493 | Ok(_) -> [] 494 | 
Error(errors) -> errors 495 | } 496 | }) 497 | |> list.flatten 498 | |> Error 499 | } 500 | } 501 | 502 | type ValidateAcc { 503 | ValidateAcc( 504 | names: List(String), 505 | inferred: Bool, 506 | results: List(Result(Axis, SpaceErrors)), 507 | ) 508 | } 509 | 510 | type Invalid { 511 | Invalid(error: SpaceError, when: Bool) 512 | } 513 | -------------------------------------------------------------------------------- /test/argamak/tensor_test.gleam: -------------------------------------------------------------------------------- 1 | import gleam/dynamic.{type Dynamic} 2 | import gleam/float 3 | import gleam/int 4 | import gleam/list 5 | import gleam/order.{Eq} 6 | import gleeunit/should 7 | import argamak/axis.{A, B, C, D, E, F, Infer, Z} 8 | import argamak/format 9 | import argamak/space 10 | import argamak/tensor.{type Tensor} 11 | 12 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 13 | // Creation Functions // 14 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 15 | 16 | pub fn from_float_test() { 17 | 0.0 18 | |> tensor.from_float 19 | |> should_share_native_format 20 | |> tensor.to_float 21 | |> should.equal(Ok(0.0)) 22 | } 23 | 24 | pub fn from_int_test() { 25 | 0 26 | |> tensor.from_int 27 | |> should_share_native_format 28 | |> tensor.to_int 29 | |> should.equal(Ok(0)) 30 | } 31 | 32 | pub fn from_bool_test() { 33 | True 34 | |> tensor.from_bool 35 | |> should_share_native_format 36 | |> tensor.to_bool 37 | |> should.equal(Ok(True)) 38 | 39 | False 40 | |> tensor.from_bool 41 | |> tensor.to_bool 42 | |> should.equal(Ok(False)) 43 | } 44 | 45 | pub fn from_floats_test() { 46 | let xs = 47 | list.range(from: 1, to: 64) 48 | |> list.map(with: int.to_float) 49 | 50 | let d0 = space.new() 51 | let assert Ok(d1) = space.d1(Infer("A")) 52 | let assert Ok(d2) = space.d2(A(2), B(32)) 53 | let assert Ok(d3) = space.d3(A(2), B(2), C(16)) 54 | let assert Ok(d4) = space.d4(A(2), B(2), C(2), D(8)) 55 | let assert Ok(d5) = space.d5(A(2), B(2), C(2), D(2), E(4)) 56 | let 
assert Ok(d6) = space.d6(A(2), B(2), C(2), D(2), E(2), F(2)) 57 | 58 | xs 59 | |> tensor.from_floats(into: d0) 60 | |> should.equal(Error(tensor.IncompatibleShape)) 61 | 62 | let assert Ok(x) = tensor.from_floats(of: xs, into: d1) 63 | x 64 | |> should_share_native_format 65 | |> tensor.to_floats 66 | |> should.equal(xs) 67 | 68 | let assert Ok(x) = tensor.from_floats(of: xs, into: d2) 69 | x 70 | |> tensor.to_floats 71 | |> should.equal(xs) 72 | 73 | let assert Ok(x) = tensor.from_floats(of: xs, into: d3) 74 | x 75 | |> tensor.to_floats 76 | |> should.equal(xs) 77 | 78 | let assert Ok(x) = tensor.from_floats(of: xs, into: d4) 79 | x 80 | |> tensor.to_floats 81 | |> should.equal(xs) 82 | 83 | let assert Ok(x) = tensor.from_floats(of: xs, into: d5) 84 | x 85 | |> tensor.to_floats 86 | |> should.equal(xs) 87 | 88 | let assert Ok(x) = tensor.from_floats(of: xs, into: d6) 89 | x 90 | |> tensor.to_floats 91 | |> should.equal(xs) 92 | } 93 | 94 | pub fn from_ints_test() { 95 | let xs = list.range(from: 1, to: 64) 96 | 97 | let d0 = space.new() 98 | let assert Ok(d1) = space.d1(Infer("A")) 99 | let assert Ok(d2) = space.d2(A(2), B(32)) 100 | let assert Ok(d3) = space.d3(A(2), B(2), C(16)) 101 | let assert Ok(d4) = space.d4(A(2), B(2), C(2), D(8)) 102 | let assert Ok(d5) = space.d5(A(2), B(2), C(2), D(2), E(4)) 103 | let assert Ok(d6) = space.d6(A(2), B(2), C(2), D(2), E(2), F(2)) 104 | 105 | xs 106 | |> tensor.from_ints(into: d0) 107 | |> should.equal(Error(tensor.IncompatibleShape)) 108 | 109 | let assert Ok(x) = tensor.from_ints(of: xs, into: d1) 110 | x 111 | |> should_share_native_format 112 | |> tensor.to_ints 113 | |> should.equal(xs) 114 | 115 | let assert Ok(x) = tensor.from_ints(of: xs, into: d2) 116 | x 117 | |> tensor.to_ints 118 | |> should.equal(xs) 119 | 120 | let assert Ok(x) = tensor.from_ints(of: xs, into: d3) 121 | x 122 | |> tensor.to_ints 123 | |> should.equal(xs) 124 | 125 | let assert Ok(x) = tensor.from_ints(of: xs, into: d4) 126 | x 127 | |> 
tensor.to_ints 128 | |> should.equal(xs) 129 | 130 | let assert Ok(x) = tensor.from_ints(of: xs, into: d5) 131 | x 132 | |> tensor.to_ints 133 | |> should.equal(xs) 134 | 135 | let assert Ok(x) = tensor.from_ints(of: xs, into: d6) 136 | x 137 | |> tensor.to_ints 138 | |> should.equal(xs) 139 | } 140 | 141 | pub fn from_bools_test() { 142 | let xs = 143 | 1 144 | |> list.range(to: 64) 145 | |> list.map(with: fn(x) { 146 | case x % 3 { 147 | 0 -> False 148 | _else -> True 149 | } 150 | }) 151 | 152 | let d0 = space.new() 153 | let assert Ok(d1) = space.d1(Infer("A")) 154 | let assert Ok(d2) = space.d2(A(2), B(32)) 155 | let assert Ok(d3) = space.d3(A(2), B(2), C(16)) 156 | let assert Ok(d4) = space.d4(A(2), B(2), C(2), D(8)) 157 | let assert Ok(d5) = space.d5(A(2), B(2), C(2), D(2), E(4)) 158 | let assert Ok(d6) = space.d6(A(2), B(2), C(2), D(2), E(2), F(2)) 159 | 160 | xs 161 | |> tensor.from_bools(into: d0) 162 | |> should.equal(Error(tensor.IncompatibleShape)) 163 | 164 | let assert Ok(x) = tensor.from_bools(of: xs, into: d1) 165 | x 166 | |> should_share_native_format 167 | |> tensor.to_bools 168 | |> should.equal(xs) 169 | 170 | let assert Ok(x) = tensor.from_bools(of: xs, into: d2) 171 | x 172 | |> tensor.to_bools 173 | |> should.equal(xs) 174 | 175 | let assert Ok(x) = tensor.from_bools(of: xs, into: d3) 176 | x 177 | |> tensor.to_bools 178 | |> should.equal(xs) 179 | 180 | let assert Ok(x) = tensor.from_bools(of: xs, into: d4) 181 | x 182 | |> tensor.to_bools 183 | |> should.equal(xs) 184 | 185 | let assert Ok(x) = tensor.from_bools(of: xs, into: d5) 186 | x 187 | |> tensor.to_bools 188 | |> should.equal(xs) 189 | 190 | let assert Ok(x) = tensor.from_bools(of: xs, into: d6) 191 | x 192 | |> tensor.to_bools 193 | |> should.equal(xs) 194 | } 195 | 196 | pub fn from_native_test() { 197 | let assert Ok(space) = space.d2(A(2), Infer("B")) 198 | 199 | let assert Ok(x) = 200 | [[1, 2], [3, 4]] 201 | |> dynamic.from 202 | |> native_tensor 203 | |> 
tensor.from_native(into: space, with: format.int32()) 204 | x 205 | |> should_share_native_format 206 | |> tensor.to_ints 207 | |> should.equal([1, 2, 3, 4]) 208 | x 209 | |> tensor.axes 210 | |> should.equal([A(2), B(2)]) 211 | } 212 | 213 | @external(erlang, "Elixir.Nx", "tensor") 214 | @external(javascript, "../argamak_test_ffi.mjs", "tensor") 215 | fn native_tensor(data: Dynamic) -> tensor.Native 216 | 217 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 218 | // Reflection Functions // 219 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 220 | 221 | pub fn format_test() { 222 | let assert Ok(d1) = space.d1(Infer("A")) 223 | 224 | 0.0 225 | |> tensor.from_float 226 | |> tensor.format 227 | |> should.equal(format.float32()) 228 | 229 | 0 230 | |> tensor.from_int 231 | |> tensor.format 232 | |> should.equal(format.int32()) 233 | 234 | let assert Ok(x) = tensor.from_floats([0.0], into: d1) 235 | x 236 | |> tensor.format 237 | |> should.equal(format.float32()) 238 | 239 | let assert Ok(x) = tensor.from_ints(of: [0], into: d1) 240 | x 241 | |> tensor.format 242 | |> should.equal(format.int32()) 243 | } 244 | 245 | pub fn space_test() { 246 | let xs = [1, 2, 3, 4, 5, 6, 7, 8] 247 | 248 | 0.0 249 | |> tensor.from_float 250 | |> tensor.space 251 | |> should.equal(space.new()) 252 | 253 | let assert Ok(space) = space.d1(Infer("A")) 254 | let assert Ok(x) = tensor.from_ints(of: xs, into: space) 255 | x 256 | |> tensor.space 257 | |> space.axes 258 | |> should.equal([A(8)]) 259 | 260 | let assert Ok(space) = space.d3(A(2), B(2), C(2)) 261 | let assert Ok(x) = tensor.from_ints(of: xs, into: space) 262 | x 263 | |> tensor.space 264 | |> should.equal(space) 265 | } 266 | 267 | pub fn axes_test() { 268 | let xs = [0.0] 269 | 270 | let assert Ok(d1) = space.d1(Infer("A")) 271 | let assert Ok(d2) = space.d2(A(1), B(1)) 272 | let assert Ok(d3) = space.d3(A(1), B(1), C(1)) 273 | let assert Ok(d4) = space.d4(A(1), B(1), C(1), D(1)) 274 | let assert Ok(d5) = space.d5(A(1), B(1), C(1), 
D(1), E(1))
  let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), E(1), F(1))
  let assert Ok(d7) =
    space.from_list([A(1), B(1), C(1), D(1), E(1), F(1), Z(1)])

  0.0
  |> tensor.from_float
  |> tensor.axes
  |> should.equal([])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d1)
  x
  |> tensor.axes
  |> should.equal([A(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d2)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d3)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1), C(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d4)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1), C(1), D(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d5)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1), C(1), D(1), E(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d6)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1), C(1), D(1), E(1), F(1)])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d7)
  x
  |> tensor.axes
  |> should.equal([A(1), B(1), C(1), D(1), E(1), F(1), Z(1)])
}

// Asserts that tensor.rank equals the number of axes, from rank 0 through 7.
pub fn rank_test() {
  let xs = [0.0]

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(d2) = space.d2(A(1), B(1))
  let assert Ok(d3) = space.d3(A(1), B(1), C(1))
  let assert Ok(d4) = space.d4(A(1), B(1), C(1), D(1))
  let assert Ok(d5) = space.d5(A(1), B(1), C(1), D(1), E(1))
  let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), E(1), F(1))
  let assert Ok(d7) =
    space.from_list([A(1), B(1), C(1), D(1), E(1), F(1), Z(1)])

  0.0
  |> tensor.from_float
  |> tensor.rank
  |> should.equal(0)

  let assert Ok(x) = tensor.from_floats(of: xs, into: d1)
  x
  |> tensor.rank
  |> should.equal(1)

342 | let assert Ok(x) = tensor.from_floats(of: xs, into: d2) 343 | x 344 | |> tensor.rank 345 | |> should.equal(2) 346 | 347 | let assert Ok(x) = tensor.from_floats(of: xs, into: d3) 348 | x 349 | |> tensor.rank 350 | |> should.equal(3) 351 | 352 | let assert Ok(x) = tensor.from_floats(of: xs, into: d4) 353 | x 354 | |> tensor.rank 355 | |> should.equal(4) 356 | 357 | let assert Ok(x) = tensor.from_floats(of: xs, into: d5) 358 | x 359 | |> tensor.rank 360 | |> should.equal(5) 361 | 362 | let assert Ok(x) = tensor.from_floats(of: xs, into: d6) 363 | x 364 | |> tensor.rank 365 | |> should.equal(6) 366 | 367 | let assert Ok(x) = tensor.from_floats(of: xs, into: d7) 368 | x 369 | |> tensor.rank 370 | |> should.equal(7) 371 | } 372 | 373 | pub fn shape_test() { 374 | let xs = 375 | list.range(from: 1, to: 720) 376 | |> list.map(with: int.to_float) 377 | 378 | let assert Ok(d1) = space.d1(Infer("A")) 379 | let assert Ok(d2) = space.d2(A(1), B(2)) 380 | let assert Ok(d3) = space.d3(A(1), B(2), C(3)) 381 | let assert Ok(d4) = space.d4(A(1), B(2), C(3), D(4)) 382 | let assert Ok(d5) = space.d5(A(1), B(2), C(3), D(4), E(5)) 383 | let assert Ok(d6) = space.d6(A(1), B(2), C(3), D(4), E(5), F(6)) 384 | let assert Ok(d7) = 385 | space.from_list([A(1), B(2), C(3), D(4), E(5), F(6), Z(1)]) 386 | 387 | 0.0 388 | |> tensor.from_float 389 | |> tensor.shape 390 | |> should.equal([]) 391 | 392 | let assert Ok(x) = tensor.from_floats(of: [1.0], into: d1) 393 | x 394 | |> tensor.shape 395 | |> should.equal([1]) 396 | 397 | let assert Ok(x) = tensor.from_floats(of: [1.0, 2.0], into: d2) 398 | x 399 | |> tensor.shape 400 | |> should.equal([1, 2]) 401 | 402 | let assert Ok(x) = 403 | tensor.from_floats(of: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], into: d3) 404 | x 405 | |> tensor.shape 406 | |> should.equal([1, 2, 3]) 407 | 408 | let assert Ok(x) = 409 | tensor.from_floats(of: list.take(from: xs, up_to: 24), into: d4) 410 | x 411 | |> tensor.shape 412 | |> should.equal([1, 2, 3, 4]) 413 | 414 | let 
assert Ok(x) =
    tensor.from_floats(of: list.take(from: xs, up_to: 120), into: d5)
  x
  |> tensor.shape
  |> should.equal([1, 2, 3, 4, 5])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d6)
  x
  |> tensor.shape
  |> should.equal([1, 2, 3, 4, 5, 6])

  let assert Ok(x) = tensor.from_floats(of: xs, into: d7)
  x
  |> tensor.shape
  |> should.equal([1, 2, 3, 4, 5, 6, 1])
}

// Asserts that tensor.size is the total element count across all axes.
pub fn size_test() {
  0.0
  |> tensor.from_float
  |> tensor.size
  |> should.equal(1)

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_ints(of: [1, 2, 3], into: d1)
  x
  |> tensor.size
  |> should.equal(3)

  let assert Ok(d3) = space.d3(A(2), B(2), C(2))
  let assert Ok(x) = tensor.from_ints(of: [1, 2, 3, 4, 5, 6, 7, 8], into: d3)
  x
  |> tensor.size
  |> should.equal(8)
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
// Transformation Functions                //
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//

// Asserts that tensor.reformat converts between int and float formats.
pub fn reformat_test() {
  0
  |> tensor.from_int
  |> tensor.reformat(apply: format.float32())
  |> should_share_native_format
  |> tensor.format
  |> should.equal(format.float32())

  0.0
  |> tensor.from_float
  |> tensor.reformat(apply: format.int32())
  |> should_share_native_format
  |> tensor.format
  |> should.equal(format.int32())
}

// Asserts that tensor.broadcast expands a tensor into a larger space,
// repeating its data to fill the new axes.
pub fn broadcast_test() {
  let assert Ok(d1) = space.d1(A(3))
  let assert Ok(d2) = space.d2(A(2), B(3))

  let assert Ok(x) =
    0
    |> tensor.from_int
    |> tensor.broadcast(into: d1)
  x
  |> should_share_native_format
  |> tensor.space
  |> space.axes
  |> should.equal(space.axes(d1))
  x
  |> tensor.to_ints
  |> should.equal([0, 0, 0])

  let assert Ok(x) = tensor.broadcast(from: x, into: d2)
  x
  |>
tensor.space
  |> space.axes
  |> should.equal(space.axes(d2))
  x
  |> tensor.to_ints
  |> should.equal([0, 0, 0, 0, 0, 0])
}

// Asserts that tensor.broadcast_over maps existing axes onto target axes by
// name, repeating data along the unmatched axes.
pub fn broadcast_over_test() {
  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(d2) = space.d2(A(3), B(2))
  let assert Ok(d3) = space.d3(A(3), B(2), C(2))

  let xs = [1, 2, 3]
  let assert Ok(x) = tensor.from_ints(of: xs, into: d1)
  let assert Ok(x) =
    tensor.broadcast_over(from: x, into: d2, with: fn(_) { "A" })
  x
  |> should_share_native_format
  |> tensor.space
  |> space.axes
  |> should.equal(space.axes(d2))
  x
  |> tensor.to_ints
  |> should.equal(list.flat_map(over: xs, with: list.repeat(item: _, times: 2)))

  let xs = [1, 2, 3, 4, 5, 6]

  let assert Ok(x) = tensor.from_ints(of: xs, into: d2)
  let assert Ok(y) = tensor.broadcast_over(from: x, into: d3, with: axis.name)
  y
  |> tensor.space
  |> space.axes
  |> should.equal(space.axes(d3))
  y
  |> tensor.to_ints
  |> should.equal(list.flat_map(over: xs, with: list.repeat(item: _, times: 2)))

  let assert Ok(y) =
    tensor.broadcast_over(from: x, into: d3, with: fn(axis) {
      case axis.name(axis) {
        "A" -> "A"
        "B" -> "C"
        name -> name
      }
    })
  y
  |> tensor.space
  |> space.axes
  |> should.equal(space.axes(d3))
  y
  |> tensor.to_ints
  |> should.equal(
    xs
    |> list.sized_chunk(into: 2)
    |> list.flat_map(with: list.repeat(item: _, times: 2))
    |> list.flatten,
  )
}

// Asserts that tensor.reshape moves data between compatible spaces,
// up through rank 6 and back down to rank 0.
pub fn reshape_test() {
  let d0 = space.new()
  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(d2) = space.d2(Infer("A"), B(1))
  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(d4) = space.d4(A(1), B(1), Infer("C"), D(1))
  let assert Ok(d5) = space.d5(A(1), B(1), C(1), Infer("D"),
E(1))
  let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), Infer("E"), F(1))

  let assert Ok(x) =
    0.0
    |> tensor.from_float
    |> tensor.reshape(into: d1)
  x
  |> should_share_native_format
  |> tensor.shape
  |> should.equal([1])

  let assert Ok(x) = tensor.reshape(put: x, into: d2)
  x
  |> tensor.shape
  |> should.equal([1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d3)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d4)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d5)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d6)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1, 1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d5)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d4)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d3)
  x
  |> tensor.shape
  |> should.equal([1, 1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d2)
  x
  |> tensor.shape
  |> should.equal([1, 1])

  let assert Ok(x) = tensor.reshape(put: x, into: d1)
  x
  |> tensor.shape
  |> should.equal([1])

  let assert Ok(x) = tensor.reshape(put: x, into: d0)
  x
  |> tensor.shape
  |> should.equal([])
}

// Asserts that tensor.squeeze drops only size-1 axes selected by the callback.
pub fn squeeze_test() {
  0
  |> tensor.from_int
  |> tensor.squeeze(with: fn(_) { True })
  |> should_share_native_format
  |> tensor.axes
  |> should.equal([])

  3.0
  |> tensor.from_float
  |> tensor.squeeze(with: fn(_) { False })
  |>
should_share_native_format
  |> tensor.axes
  |> should.equal([])

  let assert Ok(d1) = space.d1(Infer("A"))

  let assert Ok(x) = tensor.from_ints(of: [3], into: d1)
  x
  |> tensor.squeeze(with: fn(_) { True })
  |> tensor.axes
  |> should.equal([])
  x
  |> tensor.squeeze(with: fn(_) { False })
  |> tensor.axes
  |> should.equal([A(1)])

  let xs = [1, 2]

  let assert Ok(x) = tensor.from_ints(xs, into: d1)
  x
  |> tensor.squeeze(with: fn(_) { True })
  |> tensor.axes
  |> should.equal([A(2)])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_ints(xs, into: d3)
  x
  |> tensor.squeeze(with: fn(_) { True })
  |> tensor.axes
  |> should.equal([B(2)])
  x
  |> tensor.squeeze(with: fn(x) { axis.name(x) == "C" })
  |> tensor.axes
  |> should.equal([A(1), B(2)])
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
// Logical Functions                       //
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//

// Asserts broadcasting rules for binary ops: compatible shapes broadcast,
// incompatible shapes return an error.
pub fn broadcastable_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.equal(is: a, to: tensor.from_int(4))
  x
  |> should_share_native_format
  |> tensor.axes
  |> should.equal([B(2)])

  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_ints(of: [4, 4, 5, 5], into: d2)
  let assert Ok(x) = tensor.equal(is: a, to: b)
  x
  |> tensor.axes
  |> should.equal([A(2), B(2)])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(c) = tensor.from_ints(of: [4, 5, 6], into: d3)
  b
  |> tensor.equal(to: c)
  |> should.be_error

  let assert Ok(d3) = space.d3(C(1), Infer("A"), B(1))
  let assert Ok(c) = tensor.reshape(put: c, into: d3)
  b
  |>
tensor.equal(to: c)
  |> should.be_error

  let assert Ok(a) = tensor.from_floats(of: [5.0, 4.0], into: d1)
  let assert Ok(x) = tensor.equal(is: a, to: tensor.from_float(4.0))
  x
  |> should_share_native_format
  |> tensor.to_floats
  |> should.equal([0.0, 1.0])

  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.equal(is: a, to: b)
  x
  |> tensor.to_floats
  |> should.equal([0.0, 1.0, 1.0, 0.0])
}

// Asserts element-wise equality against scalar and broadcast operands.
pub fn equal_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.equal(is: a, to: tensor.from_int(4))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.equal(is: a, to: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 1, 1, 0])
}

// Asserts element-wise inequality against scalar and broadcast operands.
pub fn not_equal_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.not_equal(is: a, to: tensor.from_int(4))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.not_equal(is: a, to: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0, 0, 1])
}

// Asserts element-wise greater-than against scalar and broadcast operands.
pub fn greater_test() {
  let assert Ok(d1) =
space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.greater(is: a, than: tensor.from_int(4))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.greater(is: a, than: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0, 0, 0])
}

// Asserts element-wise greater-or-equal against scalar and broadcast operands.
pub fn greater_or_equal_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.greater_or_equal(is: a, to: tensor.from_int(4))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.greater_or_equal(is: a, to: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 1, 1, 0])
}

// Asserts element-wise less-than against scalar and broadcast operands.
pub fn less_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.less(is: a, than: tensor.from_int(5))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.less(is: a, than: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 0, 0, 1])
}

// Asserts element-wise less-or-equal against scalar and broadcast operands.
pub fn less_or_equal_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [5, 4], into: d1)
  let assert Ok(x) = tensor.less_or_equal(is: a, to: tensor.from_int(5))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [4.0, 4.0, 5.0, 5.0], into: d2)
  let assert Ok(x) = tensor.less_or_equal(is: a, to: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 1, 1, 1])
}

// Asserts element-wise logical AND, treating nonzero values as truthy.
pub fn logical_and_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [9, 0], into: d1)
  let assert Ok(x) = tensor.logical_and(a, tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 0.0], into: d2)
  let assert Ok(x) = tensor.logical_and(a, b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 0, 1, 0])
}

// Asserts element-wise logical OR, treating nonzero values as truthy.
pub fn logical_or_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [9, 0], into: d1)
  let assert Ok(x) = tensor.logical_or(a, tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 0.0], into: d2)
  let assert Ok(x) = tensor.logical_or(a, b)
  x
  |> should_share_native_format
  |>
tensor.to_ints
  |> should.equal([1, 1, 1, 0])
}

// Asserts element-wise logical XOR, treating nonzero values as truthy.
pub fn logical_xor_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [9, 0], into: d1)
  let assert Ok(x) = tensor.logical_xor(a, tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 1])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 0.0], into: d2)
  let assert Ok(x) = tensor.logical_xor(a, b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 1, 0, 0])
}

// Asserts element-wise logical NOT: nonzero maps to 0, zero maps to 1.
pub fn logical_not_test() {
  3
  |> tensor.from_int
  |> tensor.logical_not
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(0))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.3], into: d1)
  x
  |> tensor.logical_not
  |> should_share_native_format
  |> tensor.to_floats
  |> should.equal([0.0])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_ints([-1, 8, 0], into: d3)
  let x = tensor.logical_not(x)
  x
  |> tensor.to_ints
  |> should.equal([0, 0, 1])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
// Arithmetic Functions                    //
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//

// Asserts element-wise addition with scalar and broadcast operands.
pub fn add_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [0, 9], into: d1)
  let assert Ok(x) = tensor.add(a, tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([3, 12])

  let a = tensor.reformat(a, apply: format.float32())
  let assert
Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 0.0], into: d2)
  let assert Ok(x) = tensor.add(a, b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 13, 5, 9])
}

// Asserts element-wise subtraction with scalar and broadcast operands.
pub fn subtract_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [0, 9], into: d1)
  let assert Ok(x) = tensor.subtract(from: a, value: tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([-3, 6])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 0.0], into: d2)
  let assert Ok(x) = tensor.subtract(from: a, value: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 5, -5, 9])
}

// Asserts element-wise multiplication with scalar and broadcast operands.
pub fn multiply_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) = tensor.multiply(a, tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([3, 27])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 9.0], into: d2)
  let assert Ok(x) = tensor.multiply(a, b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 36, 5, 81])
}

// Asserts element-wise division; dividing by a nonfinite result is InvalidData.
pub fn divide_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) = tensor.divide(from: a, by: tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 3])

  let a = tensor.reformat(a, apply:
format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, 9.0], into: d2)
  let assert Ok(x) = tensor.divide(from: a, by: b)
  x
  |> should_share_native_format
  |> tensor.to_floats
  |> should_loosely_equal([0.0, 2.25, 0.2, 1.0])

  let assert Ok(x) =
    [infinity()]
    |> dynamic.from
    |> native_tensor
    |> tensor.from_native(into: space.new(), with: format.float32())
  x
  |> tensor.divide(by: x)
  |> should.equal(Error(tensor.InvalidData))
}

// Asserts that try_divide divides like divide but errors on division by zero.
pub fn try_divide_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) = tensor.try_divide(from: a, by: tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 3])

  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_ints(of: [0, 4, 5, 9], into: d2)
  a
  |> tensor.try_divide(by: b)
  |> should.equal(Error(tensor.ZeroDivision))
}

// Asserts remainder semantics: result takes the dividend's sign; a zero
// divisor yields zero rather than an error.
pub fn remainder_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [13, -13], into: d1)
  let assert Ok(x) = tensor.remainder(from: a, divided_by: tensor.from_int(0))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 0])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [3.0, 3.0, -3.0, -3.0], into: d2)
  let assert Ok(x) = tensor.remainder(from: a, divided_by: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, -1, 1, -1])
}

// Asserts that try_remainder errors on a zero divisor.
pub fn try_remainder_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let
assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) =
    tensor.try_remainder(from: a, divided_by: tensor.from_int(3))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 0])

  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_ints(of: [0, 4, 5, 9], into: d2)
  a
  |> tensor.try_remainder(divided_by: b)
  |> should.equal(Error(tensor.ZeroDivision))
}

// Asserts modulo semantics: result takes the divisor's sign; a zero divisor
// yields zero rather than an error.
pub fn modulo_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [13, -13], into: d1)
  let assert Ok(x) = tensor.modulo(from: a, divided_by: tensor.from_int(0))
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([0, 0])

  let a = tensor.reformat(a, apply: format.float32())
  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_floats(of: [3.0, 3.0, -3.0, -3.0], into: d2)
  let assert Ok(x) = tensor.modulo(from: a, divided_by: b)
  x
  |> should_share_native_format
  |> tensor.to_ints
  |> should.equal([1, 2, -2, -1])
}

// Asserts that try_modulo errors on a zero divisor.
pub fn try_modulo_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) = tensor.try_modulo(from: a, divided_by: tensor.from_int(3))
  x
  |> tensor.to_ints
  |> should.equal([1, 0])

  let assert Ok(d2) = space.d2(Infer("A"), B(2))
  let assert Ok(b) = tensor.from_ints(of: [0, 4, 5, 9], into: d2)
  a
  |> tensor.try_modulo(divided_by: b)
  |> should.equal(Error(tensor.ZeroDivision))
}

// Asserts element-wise exponentiation with scalar and broadcast operands.
pub fn power_test() {
  let assert Ok(d1) = space.d1(Infer("B"))
  let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1)
  let assert Ok(x) = tensor.power(raise: a, to_the: tensor.from_int(3))
| x 1095 | |> should_share_native_format 1096 | |> tensor.to_ints 1097 | |> should.equal([1, 729]) 1098 | 1099 | let a = tensor.reformat(a, apply: format.float32()) 1100 | let assert Ok(d2) = space.d2(Infer("A"), B(2)) 1101 | let assert Ok(b) = tensor.from_floats(of: [0.0, 0.4, 0.5, 0.9], into: d2) 1102 | let assert Ok(x) = tensor.power(raise: a, to_the: b) 1103 | x 1104 | |> should_share_native_format 1105 | |> tensor.to_floats 1106 | |> should_loosely_equal([1.0, 2.408, 1.0, 7.225]) 1107 | } 1108 | 1109 | pub fn max_test() { 1110 | let assert Ok(d1) = space.d1(Infer("B")) 1111 | let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1) 1112 | let assert Ok(x) = tensor.max(a, tensor.from_int(3)) 1113 | x 1114 | |> should_share_native_format 1115 | |> tensor.to_ints 1116 | |> should.equal([3, 9]) 1117 | 1118 | let a = tensor.reformat(a, apply: format.float32()) 1119 | let assert Ok(d2) = space.d2(Infer("A"), B(2)) 1120 | let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, -9.0], into: d2) 1121 | let assert Ok(x) = tensor.max(a, b) 1122 | x 1123 | |> should_share_native_format 1124 | |> tensor.to_ints 1125 | |> should.equal([1, 9, 5, 9]) 1126 | } 1127 | 1128 | pub fn min_test() { 1129 | let assert Ok(d1) = space.d1(Infer("B")) 1130 | let assert Ok(a) = tensor.from_ints(of: [1, 9], into: d1) 1131 | let assert Ok(x) = tensor.min(a, tensor.from_int(3)) 1132 | x 1133 | |> should_share_native_format 1134 | |> tensor.to_ints 1135 | |> should.equal([1, 3]) 1136 | 1137 | let a = tensor.reformat(a, apply: format.float32()) 1138 | let assert Ok(d2) = space.d2(Infer("A"), B(2)) 1139 | let assert Ok(b) = tensor.from_floats(of: [0.0, 4.0, 5.0, -9.0], into: d2) 1140 | let assert Ok(x) = tensor.min(a, b) 1141 | x 1142 | |> should_share_native_format 1143 | |> tensor.to_ints 1144 | |> should.equal([0, 4, 1, -9]) 1145 | } 1146 | 1147 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1148 | // Basic Math Functions // 1149 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1150 | 1151 
// Asserts element-wise absolute value across scalar, float, and int tensors.
pub fn absolute_value_test() {
  3
  |> tensor.from_int
  |> tensor.absolute_value
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(3))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.3], into: d1)
  x
  |> tensor.absolute_value
  |> should_share_native_format
  |> tensor.to_floats
  |> should_loosely_equal([0.3])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_ints([-1, 8, 0], into: d3)
  let x = tensor.absolute_value(x)
  x
  |> tensor.to_ints
  |> should.equal([1, 8, 0])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

// Asserts element-wise negation across scalar, float, and int tensors.
pub fn negate_test() {
  3
  |> tensor.from_int
  |> tensor.negate
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(-3))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.3], into: d1)
  x
  |> tensor.negate
  |> should_share_native_format
  |> tensor.to_floats
  |> should_loosely_equal([0.3])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_ints([-1, 8, 0], into: d3)
  let x = tensor.negate(x)
  x
  |> tensor.to_ints
  |> should.equal([1, -8, 0])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

// Asserts element-wise sign: -1, 0, or 1 depending on the value.
pub fn sign_test() {
  3
  |> tensor.from_int
  |> tensor.sign
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(1))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.3], into: d1)
  x
  |> tensor.sign
  |> should_share_native_format
  |> tensor.to_floats
  |> should.equal([-1.0])

  let assert Ok(d3) =
space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_ints([-1, 8, 0], into: d3)
  let x = tensor.sign(x)
  x
  |> tensor.to_ints
  |> should.equal([-1, 1, 0])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

// Asserts element-wise ceiling; ints pass through unchanged.
pub fn ceiling_test() {
  3
  |> tensor.from_int
  |> tensor.ceiling
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(3))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.5], into: d1)
  x
  |> tensor.ceiling
  |> should_share_native_format
  |> tensor.to_floats
  |> should.equal([0.0])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_floats([-1.2, 7.8, 0.0], into: d3)
  let x = tensor.ceiling(x)
  x
  |> tensor.to_floats
  |> should.equal([-1.0, 8.0, 0.0])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

// Asserts element-wise floor; ints pass through unchanged.
pub fn floor_test() {
  3
  |> tensor.from_int
  |> tensor.floor
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(3))

  let assert Ok(d1) = space.d1(Infer("A"))
  let assert Ok(x) = tensor.from_floats([-0.5], into: d1)
  x
  |> tensor.floor
  |> should_share_native_format
  |> tensor.to_floats
  |> should.equal([-1.0])

  let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1))
  let assert Ok(x) = tensor.from_floats([-1.2, 7.8, 0.0], into: d3)
  let x = tensor.floor(x)
  x
  |> tensor.to_floats
  |> should.equal([-2.0, 7.0, 0.0])
  x
  |> tensor.axes
  |> should.equal([A(1), B(3), C(1)])
}

// Asserts element-wise rounding; ints pass through unchanged.
pub fn round_test() {
  3
  |> tensor.from_int
  |> tensor.round
  |> should_share_native_format
  |> tensor.to_int
  |> should.equal(Ok(3))
| 1294 | // For (+/-)0.5, TensorFlow currently rounds to 0. 1295 | let assert Ok(d1) = space.d1(Infer("A")) 1296 | let assert Ok(x) = tensor.from_floats([-1.5], into: d1) 1297 | x 1298 | |> tensor.round 1299 | |> should_share_native_format 1300 | |> tensor.to_floats 1301 | |> should.equal([-2.0]) 1302 | 1303 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1304 | let assert Ok(x) = tensor.from_floats([-1.2, 7.8, 0.0], into: d3) 1305 | let x = tensor.round(x) 1306 | x 1307 | |> tensor.to_floats 1308 | |> should.equal([-1.0, 8.0, 0.0]) 1309 | x 1310 | |> tensor.axes 1311 | |> should.equal([A(1), B(3), C(1)]) 1312 | } 1313 | 1314 | pub fn exp_test() { 1315 | 3 1316 | |> tensor.from_int 1317 | |> tensor.exp 1318 | |> should_share_native_format 1319 | |> tensor.to_int 1320 | |> should.equal(Ok(20)) 1321 | 1322 | let assert Ok(d1) = space.d1(Infer("A")) 1323 | let assert Ok(x) = tensor.from_floats([-1.5], into: d1) 1324 | x 1325 | |> tensor.exp 1326 | |> should_share_native_format 1327 | |> tensor.to_floats 1328 | |> should_loosely_equal([0.223]) 1329 | 1330 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1331 | let assert Ok(x) = tensor.from_floats([-1.2, 7.8, 0.0], into: d3) 1332 | let x = tensor.exp(x) 1333 | x 1334 | |> tensor.to_floats 1335 | |> should_loosely_equal([0.301, 2440.603, 1.0]) 1336 | x 1337 | |> tensor.axes 1338 | |> should.equal([A(1), B(3), C(1)]) 1339 | 1340 | let assert Ok(x) = tensor.from_ints([-90, 90, 0], into: d1) 1341 | x 1342 | |> tensor.exp 1343 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 1344 | |> should.equal( 1345 | "Tensor( 1346 | Format(Int32), 1347 | Space(A(3)), 1348 | [ 0, 2147483647, 1], 1349 | )", 1350 | ) 1351 | } 1352 | 1353 | pub fn square_root_test() { 1354 | let assert Ok(x) = 1355 | 3 1356 | |> tensor.from_int 1357 | |> tensor.square_root 1358 | x 1359 | |> should_share_native_format 1360 | |> tensor.to_int 1361 | |> should.equal(Ok(1)) 1362 | 1363 | let assert Ok(d1) = space.d1(Infer("A")) 1364 | let 
assert Ok(x) = tensor.from_floats([1.5], into: d1) 1365 | let assert Ok(x) = tensor.square_root(x) 1366 | x 1367 | |> should_share_native_format 1368 | |> tensor.to_floats 1369 | |> should_loosely_equal([1.225]) 1370 | 1371 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1372 | let assert Ok(x) = tensor.from_floats([1.2, 7.8, 0.0], into: d3) 1373 | let assert Ok(x) = tensor.square_root(x) 1374 | x 1375 | |> tensor.to_floats 1376 | |> should_loosely_equal([1.095, 2.793, 0.0]) 1377 | x 1378 | |> tensor.axes 1379 | |> should.equal([A(1), B(3), C(1)]) 1380 | 1381 | let assert Ok(x) = tensor.from_ints([1, 90, 0], into: d1) 1382 | let assert Ok(x) = tensor.square_root(x) 1383 | x 1384 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 1385 | |> should.equal( 1386 | "Tensor( 1387 | Format(Int32), 1388 | Space(A(3)), 1389 | [1, 9, 0], 1390 | )", 1391 | ) 1392 | 1393 | -1 1394 | |> tensor.from_int 1395 | |> tensor.square_root 1396 | |> should.equal(Error(tensor.InvalidData)) 1397 | 1398 | let assert Ok(x) = tensor.from_floats(of: [-0.1], into: d1) 1399 | x 1400 | |> tensor.square_root 1401 | |> should.equal(Error(tensor.InvalidData)) 1402 | } 1403 | 1404 | pub fn ln_test() { 1405 | let assert Ok(x) = 1406 | 3 1407 | |> tensor.from_int 1408 | |> tensor.ln 1409 | x 1410 | |> should_share_native_format 1411 | |> tensor.to_int 1412 | |> should.equal(Ok(1)) 1413 | 1414 | let assert Ok(d1) = space.d1(Infer("A")) 1415 | let assert Ok(x) = tensor.from_floats([1.5], into: d1) 1416 | let assert Ok(x) = tensor.ln(x) 1417 | x 1418 | |> should_share_native_format 1419 | |> tensor.to_floats 1420 | |> should_loosely_equal([0.405]) 1421 | 1422 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1423 | let assert Ok(x) = tensor.from_floats([1.2, 7.8, 0.0], into: d3) 1424 | let assert Ok(x) = tensor.ln(x) 1425 | x 1426 | |> tensor.to_floats 1427 | |> should_loosely_equal([0.182, 2.054, float32_min]) 1428 | x 1429 | |> tensor.axes 1430 | |> should.equal([A(1), B(3), C(1)]) 
1431 | 1432 | let assert Ok(x) = tensor.from_ints([1, 90, 0], into: d1) 1433 | let assert Ok(x) = tensor.ln(x) 1434 | x 1435 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 1436 | |> should.equal( 1437 | "Tensor( 1438 | Format(Int32), 1439 | Space(A(3)), 1440 | [ 0, 4, -2147483648], 1441 | )", 1442 | ) 1443 | 1444 | -1 1445 | |> tensor.from_int 1446 | |> tensor.ln 1447 | |> should.equal(Error(tensor.InvalidData)) 1448 | 1449 | let assert Ok(x) = tensor.from_floats(of: [-0.1], into: d1) 1450 | x 1451 | |> tensor.ln 1452 | |> should.equal(Error(tensor.InvalidData)) 1453 | } 1454 | 1455 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1456 | // Reduction Functions // 1457 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1458 | 1459 | pub fn in_situ_test() { 1460 | 0 1461 | |> tensor.from_int 1462 | |> tensor.in_situ_all(with: fn(_) { True }) 1463 | |> should_share_native_format 1464 | |> tensor.to_int 1465 | |> should.equal(Ok(0)) 1466 | 1467 | -3.0 1468 | |> tensor.from_float 1469 | |> tensor.in_situ_all(with: fn(_) { False }) 1470 | |> should_share_native_format 1471 | |> tensor.to_float 1472 | |> should.equal(Ok(1.0)) 1473 | 1474 | let assert Ok(d1) = space.d1(Infer("A")) 1475 | 1476 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1477 | let y = tensor.in_situ_all(from: x, with: fn(_) { True }) 1478 | y 1479 | |> tensor.to_ints 1480 | |> should.equal([1]) 1481 | y 1482 | |> tensor.axes 1483 | |> should.equal([A(1)]) 1484 | let y = tensor.in_situ_all(from: x, with: fn(_) { False }) 1485 | y 1486 | |> tensor.axes 1487 | |> should.equal([A(1)]) 1488 | 1489 | let xs = [0, 1] 1490 | 1491 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1492 | x 1493 | |> tensor.in_situ_all(with: fn(_) { True }) 1494 | |> tensor.axes 1495 | |> should.equal([A(1)]) 1496 | 1497 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1498 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1499 | let y = tensor.in_situ_all(from: x, with: fn(_) { True }) 1500 | y 1501 | |> 
tensor.axes 1502 | |> should.equal([A(1), B(1), C(1)]) 1503 | y 1504 | |> tensor.squeeze(with: fn(_) { True }) 1505 | |> tensor.to_int 1506 | |> should.equal(Ok(0)) 1507 | let y = tensor.in_situ_all(from: x, with: fn(x) { axis.name(x) == "C" }) 1508 | y 1509 | |> tensor.axes 1510 | |> should.equal([A(1), B(2), C(1)]) 1511 | y 1512 | |> tensor.squeeze(with: fn(_) { True }) 1513 | |> tensor.to_ints 1514 | |> should.equal([0, 1]) 1515 | } 1516 | 1517 | pub fn all_test() { 1518 | 0.0 1519 | |> tensor.from_float 1520 | |> tensor.all(with: fn(_) { True }) 1521 | |> should_share_native_format 1522 | |> tensor.to_float 1523 | |> should.equal(Ok(0.0)) 1524 | 1525 | 3 1526 | |> tensor.from_int 1527 | |> tensor.all(with: fn(_) { False }) 1528 | |> should_share_native_format 1529 | |> tensor.to_int 1530 | |> should.equal(Ok(1)) 1531 | 1532 | let assert Ok(d1) = space.d1(Infer("A")) 1533 | 1534 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1535 | x 1536 | |> tensor.all(with: fn(_) { True }) 1537 | |> tensor.to_int 1538 | |> should.equal(Ok(1)) 1539 | x 1540 | |> tensor.all(with: fn(_) { False }) 1541 | |> tensor.axes 1542 | |> should.equal([A(1)]) 1543 | 1544 | let xs = [0, 1] 1545 | 1546 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1547 | x 1548 | |> tensor.all(with: fn(_) { True }) 1549 | |> tensor.to_int 1550 | |> should.equal(Ok(0)) 1551 | 1552 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1553 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1554 | x 1555 | |> tensor.all(with: fn(_) { True }) 1556 | |> tensor.to_int 1557 | |> should.equal(Ok(0)) 1558 | let y = tensor.all(from: x, with: fn(x) { axis.name(x) == "B" }) 1559 | y 1560 | |> tensor.squeeze(with: fn(_) { True }) 1561 | |> tensor.to_int 1562 | |> should.equal(Ok(0)) 1563 | y 1564 | |> tensor.axes 1565 | |> should.equal([A(1), C(1)]) 1566 | } 1567 | 1568 | pub fn any_test() { 1569 | 0.0 1570 | |> tensor.from_float 1571 | |> tensor.any(with: fn(_) { True }) 1572 | |> 
should_share_native_format 1573 | |> tensor.to_float 1574 | |> should.equal(Ok(0.0)) 1575 | 1576 | 3 1577 | |> tensor.from_int 1578 | |> tensor.any(with: fn(_) { False }) 1579 | |> should_share_native_format 1580 | |> tensor.to_int 1581 | |> should.equal(Ok(1)) 1582 | 1583 | let assert Ok(d1) = space.d1(Infer("A")) 1584 | 1585 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1586 | x 1587 | |> tensor.any(with: fn(_) { True }) 1588 | |> tensor.to_int 1589 | |> should.equal(Ok(1)) 1590 | x 1591 | |> tensor.any(with: fn(_) { False }) 1592 | |> tensor.axes 1593 | |> should.equal([A(1)]) 1594 | 1595 | let xs = [0, 1] 1596 | 1597 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1598 | x 1599 | |> tensor.any(with: fn(_) { True }) 1600 | |> tensor.to_int 1601 | |> should.equal(Ok(1)) 1602 | 1603 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(1)) 1604 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1605 | x 1606 | |> tensor.any(with: fn(_) { True }) 1607 | |> tensor.to_int 1608 | |> should.equal(Ok(1)) 1609 | let y = tensor.any(from: x, with: fn(x) { axis.name(x) == "B" }) 1610 | y 1611 | |> tensor.squeeze(with: fn(_) { True }) 1612 | |> tensor.to_int 1613 | |> should.equal(Ok(1)) 1614 | y 1615 | |> tensor.axes 1616 | |> should.equal([A(1), C(1)]) 1617 | } 1618 | 1619 | pub fn arg_max_test() { 1620 | 0.0 1621 | |> tensor.from_float 1622 | |> tensor.arg_max(with: fn(_) { True }) 1623 | |> should_share_native_format 1624 | |> tensor.to_float 1625 | |> should.equal(Ok(0.0)) 1626 | 1627 | 3 1628 | |> tensor.from_int 1629 | |> tensor.arg_max(with: fn(_) { False }) 1630 | |> should_share_native_format 1631 | |> tensor.to_int 1632 | |> should.equal(Ok(0)) 1633 | 1634 | let assert Ok(d1) = space.d1(Infer("A")) 1635 | 1636 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1637 | x 1638 | |> tensor.arg_max(with: fn(_) { True }) 1639 | |> tensor.to_int 1640 | |> should.equal(Ok(0)) 1641 | x 1642 | |> tensor.arg_max(with: fn(_) { False }) 1643 | |> tensor.axes 
1644 | |> should.equal([]) 1645 | 1646 | let xs = [1, 4, 3, 2] 1647 | 1648 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1649 | x 1650 | |> tensor.arg_max(with: fn(_) { True }) 1651 | |> tensor.to_int 1652 | |> should.equal(Ok(1)) 1653 | 1654 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1655 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1656 | x 1657 | |> tensor.arg_max(with: fn(_) { True }) 1658 | |> tensor.to_ints 1659 | |> should.equal([0, 0, 0, 0]) 1660 | let y = tensor.arg_max(from: x, with: fn(x) { axis.name(x) == "C" }) 1661 | y 1662 | |> tensor.to_ints 1663 | |> should.equal([1, 0]) 1664 | y 1665 | |> tensor.axes 1666 | |> should.equal([A(1), B(2)]) 1667 | } 1668 | 1669 | pub fn arg_min_test() { 1670 | 0.0 1671 | |> tensor.from_float 1672 | |> tensor.arg_min(with: fn(_) { True }) 1673 | |> should_share_native_format 1674 | |> tensor.to_float 1675 | |> should.equal(Ok(0.0)) 1676 | 1677 | 3 1678 | |> tensor.from_int 1679 | |> tensor.arg_min(with: fn(_) { False }) 1680 | |> should_share_native_format 1681 | |> tensor.to_int 1682 | |> should.equal(Ok(0)) 1683 | 1684 | let assert Ok(d1) = space.d1(Infer("A")) 1685 | 1686 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1687 | x 1688 | |> tensor.arg_min(with: fn(_) { True }) 1689 | |> tensor.to_int 1690 | |> should.equal(Ok(0)) 1691 | x 1692 | |> tensor.arg_min(with: fn(_) { False }) 1693 | |> tensor.axes 1694 | |> should.equal([]) 1695 | 1696 | let xs = [1, 4, 3, 2] 1697 | 1698 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1699 | x 1700 | |> tensor.arg_min(with: fn(_) { True }) 1701 | |> tensor.to_int 1702 | |> should.equal(Ok(0)) 1703 | 1704 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1705 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1706 | x 1707 | |> tensor.arg_min(with: fn(_) { True }) 1708 | |> tensor.to_ints 1709 | |> should.equal([0, 0, 0, 0]) 1710 | let y = tensor.arg_min(from: x, with: fn(x) { axis.name(x) == "C" }) 1711 | y 1712 | |> tensor.to_ints 
1713 | |> should.equal([0, 1]) 1714 | y 1715 | |> tensor.axes 1716 | |> should.equal([A(1), B(2)]) 1717 | } 1718 | 1719 | pub fn max_over_test() { 1720 | 0.0 1721 | |> tensor.from_float 1722 | |> tensor.max_over(with: fn(_) { True }) 1723 | |> should_share_native_format 1724 | |> tensor.to_float 1725 | |> should.equal(Ok(0.0)) 1726 | 1727 | 3 1728 | |> tensor.from_int 1729 | |> tensor.max_over(with: fn(_) { False }) 1730 | |> should_share_native_format 1731 | |> tensor.to_int 1732 | |> should.equal(Ok(3)) 1733 | 1734 | let assert Ok(d1) = space.d1(Infer("A")) 1735 | 1736 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1737 | x 1738 | |> tensor.max_over(with: fn(_) { True }) 1739 | |> tensor.to_int 1740 | |> should.equal(Ok(3)) 1741 | x 1742 | |> tensor.max_over(with: fn(_) { False }) 1743 | |> tensor.axes 1744 | |> should.equal([A(1)]) 1745 | 1746 | let xs = [1, 4, 3, 2] 1747 | 1748 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1749 | x 1750 | |> tensor.max_over(with: fn(_) { True }) 1751 | |> tensor.to_int 1752 | |> should.equal(Ok(4)) 1753 | 1754 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1755 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1756 | x 1757 | |> tensor.max_over(with: fn(_) { True }) 1758 | |> tensor.to_int 1759 | |> should.equal(Ok(4)) 1760 | let y = tensor.max_over(from: x, with: fn(x) { axis.name(x) == "C" }) 1761 | y 1762 | |> tensor.to_ints 1763 | |> should.equal([4, 3]) 1764 | y 1765 | |> tensor.axes 1766 | |> should.equal([A(1), B(2)]) 1767 | } 1768 | 1769 | pub fn min_over_test() { 1770 | 0.0 1771 | |> tensor.from_float 1772 | |> tensor.min_over(with: fn(_) { True }) 1773 | |> should_share_native_format 1774 | |> tensor.to_float 1775 | |> should.equal(Ok(0.0)) 1776 | 1777 | 3 1778 | |> tensor.from_int 1779 | |> tensor.min_over(with: fn(_) { False }) 1780 | |> should_share_native_format 1781 | |> tensor.to_int 1782 | |> should.equal(Ok(3)) 1783 | 1784 | let assert Ok(d1) = space.d1(Infer("A")) 1785 | 1786 | let 
assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1787 | x 1788 | |> tensor.min_over(with: fn(_) { True }) 1789 | |> tensor.to_int 1790 | |> should.equal(Ok(3)) 1791 | x 1792 | |> tensor.min_over(with: fn(_) { False }) 1793 | |> tensor.axes 1794 | |> should.equal([A(1)]) 1795 | 1796 | let xs = [1, 4, 3, 2] 1797 | 1798 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1799 | x 1800 | |> tensor.min_over(with: fn(_) { True }) 1801 | |> tensor.to_int 1802 | |> should.equal(Ok(1)) 1803 | 1804 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1805 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1806 | x 1807 | |> tensor.min_over(with: fn(_) { True }) 1808 | |> tensor.to_int 1809 | |> should.equal(Ok(1)) 1810 | let y = tensor.min_over(from: x, with: fn(x) { axis.name(x) == "C" }) 1811 | y 1812 | |> tensor.to_ints 1813 | |> should.equal([1, 2]) 1814 | y 1815 | |> tensor.axes 1816 | |> should.equal([A(1), B(2)]) 1817 | } 1818 | 1819 | pub fn sum_test() { 1820 | 0.0 1821 | |> tensor.from_float 1822 | |> tensor.sum(with: fn(_) { True }) 1823 | |> should_share_native_format 1824 | |> tensor.to_float 1825 | |> should.equal(Ok(0.0)) 1826 | 1827 | 3 1828 | |> tensor.from_int 1829 | |> tensor.sum(with: fn(_) { False }) 1830 | |> should_share_native_format 1831 | |> tensor.to_int 1832 | |> should.equal(Ok(3)) 1833 | 1834 | let assert Ok(d1) = space.d1(Infer("A")) 1835 | 1836 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1837 | x 1838 | |> tensor.sum(with: fn(_) { True }) 1839 | |> tensor.to_int 1840 | |> should.equal(Ok(3)) 1841 | x 1842 | |> tensor.sum(with: fn(_) { False }) 1843 | |> tensor.axes 1844 | |> should.equal([A(1)]) 1845 | 1846 | let xs = [-1, 4, 3, 2] 1847 | 1848 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1849 | x 1850 | |> tensor.sum(with: fn(_) { True }) 1851 | |> tensor.to_int 1852 | |> should.equal(Ok(8)) 1853 | 1854 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1855 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1856 | 
x 1857 | |> tensor.sum(with: fn(_) { True }) 1858 | |> tensor.to_int 1859 | |> should.equal(Ok(8)) 1860 | let y = tensor.sum(from: x, with: fn(x) { axis.name(x) == "C" }) 1861 | y 1862 | |> tensor.to_ints 1863 | |> should.equal([3, 5]) 1864 | y 1865 | |> tensor.axes 1866 | |> should.equal([A(1), B(2)]) 1867 | } 1868 | 1869 | pub fn product_test() { 1870 | 0.0 1871 | |> tensor.from_float 1872 | |> tensor.product(with: fn(_) { True }) 1873 | |> should_share_native_format 1874 | |> tensor.to_float 1875 | |> should.equal(Ok(0.0)) 1876 | 1877 | 3 1878 | |> tensor.from_int 1879 | |> tensor.product(with: fn(_) { False }) 1880 | |> should_share_native_format 1881 | |> tensor.to_int 1882 | |> should.equal(Ok(3)) 1883 | 1884 | let assert Ok(d1) = space.d1(Infer("A")) 1885 | 1886 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1887 | x 1888 | |> tensor.product(with: fn(_) { True }) 1889 | |> tensor.to_int 1890 | |> should.equal(Ok(3)) 1891 | x 1892 | |> tensor.product(with: fn(_) { False }) 1893 | |> tensor.axes 1894 | |> should.equal([A(1)]) 1895 | 1896 | let xs = [-1, 4, 3, 2] 1897 | 1898 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1899 | x 1900 | |> tensor.product(with: fn(_) { True }) 1901 | |> tensor.to_int 1902 | |> should.equal(Ok(-24)) 1903 | 1904 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1905 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1906 | x 1907 | |> tensor.product(with: fn(_) { True }) 1908 | |> tensor.to_int 1909 | |> should.equal(Ok(-24)) 1910 | let y = tensor.product(from: x, with: fn(x) { axis.name(x) == "C" }) 1911 | y 1912 | |> tensor.to_ints 1913 | |> should.equal([-4, 6]) 1914 | y 1915 | |> tensor.axes 1916 | |> should.equal([A(1), B(2)]) 1917 | } 1918 | 1919 | pub fn mean_test() { 1920 | 0.0 1921 | |> tensor.from_float 1922 | |> tensor.mean(with: fn(_) { True }) 1923 | |> should_share_native_format 1924 | |> tensor.to_float 1925 | |> should.equal(Ok(0.0)) 1926 | 1927 | 3 1928 | |> tensor.from_int 1929 | |> 
tensor.mean(with: fn(_) { False }) 1930 | |> should_share_native_format 1931 | |> tensor.to_int 1932 | |> should.equal(Ok(3)) 1933 | 1934 | let assert Ok(d1) = space.d1(Infer("A")) 1935 | 1936 | let assert Ok(x) = tensor.from_ints(of: [3], into: d1) 1937 | x 1938 | |> tensor.mean(with: fn(_) { True }) 1939 | |> tensor.to_int 1940 | |> should.equal(Ok(3)) 1941 | x 1942 | |> tensor.mean(with: fn(_) { False }) 1943 | |> tensor.axes 1944 | |> should.equal([A(1)]) 1945 | 1946 | let xs = [-1, 4, 3, 2] 1947 | 1948 | let assert Ok(x) = tensor.from_ints(xs, into: d1) 1949 | x 1950 | |> tensor.mean(with: fn(_) { True }) 1951 | |> tensor.to_int 1952 | |> should.equal(Ok(2)) 1953 | 1954 | let assert Ok(d3) = space.d3(A(1), Infer("B"), C(2)) 1955 | let assert Ok(x) = tensor.from_ints(xs, into: d3) 1956 | let f = fn(x) { axis.name(x) == "C" } 1957 | 1958 | x 1959 | |> tensor.mean(with: fn(_) { True }) 1960 | |> tensor.to_int 1961 | |> should.equal(Ok(2)) 1962 | let y = tensor.mean(from: x, with: f) 1963 | y 1964 | |> tensor.to_ints 1965 | |> should.equal([1, 2]) 1966 | y 1967 | |> tensor.axes 1968 | |> should.equal([A(1), B(2)]) 1969 | 1970 | x 1971 | |> tensor.reformat(apply: format.float32()) 1972 | |> tensor.mean(with: f) 1973 | |> tensor.to_floats 1974 | |> should_loosely_equal([1.5, 2.5]) 1975 | } 1976 | 1977 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1978 | // Slicing & Joining Functions // 1979 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 1980 | 1981 | pub fn concat_test() { 1982 | let assert Ok(d1) = space.d1(Infer("A")) 1983 | let ints = [0, 1, 2, 3] 1984 | let assert Ok(x) = tensor.from_ints(of: ints, into: d1) 1985 | 1986 | let assert Ok(x) = tensor.concat([x], with: fn(_) { True }) 1987 | x 1988 | |> should_share_native_format 1989 | |> tensor.axes 1990 | |> should.equal([A(4)]) 1991 | x 1992 | |> tensor.to_ints 1993 | |> should.equal(ints) 1994 | 1995 | let ints = 1996 | ints 1997 | |> list.repeat(times: 3) 1998 | |> list.flatten 1999 | let assert Ok(x) = 2000 
| x 2001 | |> list.repeat(times: 3) 2002 | |> tensor.concat(with: fn(_) { True }) 2003 | x 2004 | |> should_share_native_format 2005 | |> tensor.axes 2006 | |> should.equal([A(12)]) 2007 | x 2008 | |> tensor.to_ints 2009 | |> should.equal(ints) 2010 | 2011 | let assert Ok(d2) = space.d2(A(2), Infer("B")) 2012 | let a_floats = [0.0, 1.0, 2.0, 3.0] 2013 | let b_floats = [4.0, 5.0] 2014 | let assert Ok(a) = tensor.from_floats(of: a_floats, into: d2) 2015 | let assert Ok(b) = tensor.from_floats(of: b_floats, into: d2) 2016 | let assert Ok(x) = tensor.concat([a, b], with: fn(a) { axis.name(a) == "B" }) 2017 | x 2018 | |> should_share_native_format 2019 | |> tensor.axes 2020 | |> should.equal([A(2), B(3)]) 2021 | x 2022 | |> tensor.to_floats 2023 | |> should.equal([0.0, 1.0, 4.0, 2.0, 3.0, 5.0]) 2024 | 2025 | [a, b] 2026 | |> tensor.concat(with: fn(_) { False }) 2027 | |> should.equal(Error(tensor.AxisNotFound)) 2028 | 2029 | let error = Error(tensor.IncompatibleShape) 2030 | 2031 | [a, b] 2032 | |> tensor.concat(with: fn(a) { axis.name(a) == "A" }) 2033 | |> should.equal(error) 2034 | 2035 | let assert Ok(b) = tensor.reshape(put: b, into: d1) 2036 | [a, b] 2037 | |> tensor.concat(with: fn(a) { axis.name(a) == "A" }) 2038 | |> should.equal(error) 2039 | 2040 | let assert Ok(d2) = space.d2(Infer("C"), B(1)) 2041 | let assert Ok(b) = tensor.reshape(put: b, into: d2) 2042 | [a, b] 2043 | |> tensor.concat(with: fn(a) { axis.name(a) == "B" }) 2044 | |> should.equal(error) 2045 | } 2046 | 2047 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 2048 | // Conversion Functions // 2049 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 2050 | 2051 | const float32_min = -340_282_346_638_528_859_811_704_183_484_516_925_440.0 2052 | 2053 | const float32_max = 340_282_346_638_528_859_811_704_183_484_516_925_440.0 2054 | 2055 | const int32_min = -2_147_483_648 2056 | 2057 | const int32_max = 2_147_483_647 2058 | 2059 | pub fn to_float_test() { 2060 | 0.0 2061 | |> tensor.from_float 2062 | |> 
tensor.to_float 2063 | |> should.equal(Ok(0.0)) 2064 | 2065 | 0 2066 | |> tensor.from_int 2067 | |> tensor.to_float 2068 | |> should.equal(Ok(0.0)) 2069 | 2070 | let d0 = space.new() 2071 | 2072 | let assert Ok(x) = 2073 | [neg_infinity()] 2074 | |> dynamic.from 2075 | |> native_tensor 2076 | |> tensor.from_native(into: d0, with: format.float32()) 2077 | x 2078 | |> tensor.to_float 2079 | |> should.equal(Ok(float32_min)) 2080 | 2081 | let assert Ok(x) = 2082 | [infinity()] 2083 | |> dynamic.from 2084 | |> native_tensor 2085 | |> tensor.from_native(into: d0, with: format.float32()) 2086 | x 2087 | |> tensor.to_float 2088 | |> should.equal(Ok(float32_max)) 2089 | 2090 | let assert Ok(d1) = space.d1(Infer("A")) 2091 | 2092 | let assert Ok(x) = tensor.from_floats(of: [0.0], into: d1) 2093 | x 2094 | |> tensor.to_float 2095 | |> should.equal(Error(tensor.IncompatibleShape)) 2096 | 2097 | let assert Ok(x) = tensor.from_floats(of: [0.0, 1.0], into: d1) 2098 | x 2099 | |> tensor.to_float 2100 | |> should.equal(Error(tensor.IncompatibleShape)) 2101 | 2102 | let assert Ok(x) = tensor.from_ints(of: [0], into: d1) 2103 | x 2104 | |> tensor.to_float 2105 | |> should.equal(Error(tensor.IncompatibleShape)) 2106 | 2107 | let assert Ok(x) = tensor.from_ints(of: [0, 1], into: d1) 2108 | x 2109 | |> tensor.to_float 2110 | |> should.equal(Error(tensor.IncompatibleShape)) 2111 | } 2112 | 2113 | pub fn to_int_test() { 2114 | 0.0 2115 | |> tensor.from_float 2116 | |> tensor.to_int 2117 | |> should.equal(Ok(0)) 2118 | 2119 | 0 2120 | |> tensor.from_int 2121 | |> tensor.to_int 2122 | |> should.equal(Ok(0)) 2123 | 2124 | let d0 = space.new() 2125 | 2126 | let assert Ok(x) = 2127 | [neg_infinity()] 2128 | |> dynamic.from 2129 | |> native_tensor 2130 | |> tensor.from_native(into: d0, with: format.float32()) 2131 | x 2132 | |> tensor.to_int 2133 | |> should.equal(Ok(int32_min)) 2134 | 2135 | let assert Ok(x) = 2136 | [infinity()] 2137 | |> dynamic.from 2138 | |> native_tensor 2139 | |> 
tensor.from_native(into: d0, with: format.float32()) 2140 | x 2141 | |> tensor.to_int 2142 | |> should.equal(Ok(int32_max)) 2143 | 2144 | let assert Ok(d1) = space.d1(Infer("A")) 2145 | 2146 | let assert Ok(x) = tensor.from_floats(of: [0.0], into: d1) 2147 | x 2148 | |> tensor.to_int 2149 | |> should.equal(Error(tensor.IncompatibleShape)) 2150 | 2151 | let assert Ok(x) = tensor.from_floats(of: [0.0, 1.0], into: d1) 2152 | x 2153 | |> tensor.to_int 2154 | |> should.equal(Error(tensor.IncompatibleShape)) 2155 | 2156 | let assert Ok(x) = tensor.from_ints(of: [0], into: d1) 2157 | x 2158 | |> tensor.to_int 2159 | |> should.equal(Error(tensor.IncompatibleShape)) 2160 | 2161 | let assert Ok(x) = tensor.from_ints(of: [0, 1], into: d1) 2162 | x 2163 | |> tensor.to_int 2164 | |> should.equal(Error(tensor.IncompatibleShape)) 2165 | } 2166 | 2167 | pub fn to_bool_test() { 2168 | 0.0 2169 | |> tensor.from_float 2170 | |> tensor.to_bool 2171 | |> should.equal(Ok(False)) 2172 | 2173 | 1 2174 | |> tensor.from_int 2175 | |> tensor.to_bool 2176 | |> should.equal(Ok(True)) 2177 | 2178 | let d0 = space.new() 2179 | 2180 | let assert Ok(x) = 2181 | [neg_infinity()] 2182 | |> dynamic.from 2183 | |> native_tensor 2184 | |> tensor.from_native(into: d0, with: format.float32()) 2185 | x 2186 | |> tensor.to_bool 2187 | |> should.equal(Ok(True)) 2188 | 2189 | let assert Ok(x) = 2190 | [infinity()] 2191 | |> dynamic.from 2192 | |> native_tensor 2193 | |> tensor.from_native(into: d0, with: format.float32()) 2194 | x 2195 | |> tensor.to_bool 2196 | |> should.equal(Ok(True)) 2197 | 2198 | let assert Ok(d1) = space.d1(Infer("A")) 2199 | 2200 | let assert Ok(x) = tensor.from_floats(of: [0.0], into: d1) 2201 | x 2202 | |> tensor.to_bool 2203 | |> should.equal(Error(tensor.IncompatibleShape)) 2204 | 2205 | let assert Ok(x) = tensor.from_floats(of: [0.0, 1.0], into: d1) 2206 | x 2207 | |> tensor.to_bool 2208 | |> should.equal(Error(tensor.IncompatibleShape)) 2209 | 2210 | let assert Ok(x) = 
tensor.from_ints(of: [0], into: d1) 2211 | x 2212 | |> tensor.to_bool 2213 | |> should.equal(Error(tensor.IncompatibleShape)) 2214 | 2215 | let assert Ok(x) = tensor.from_ints(of: [0, 1], into: d1) 2216 | x 2217 | |> tensor.to_bool 2218 | |> should.equal(Error(tensor.IncompatibleShape)) 2219 | } 2220 | 2221 | pub fn to_floats_test() { 2222 | let assert Ok(d1) = space.d1(Infer("A")) 2223 | let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), E(1), Infer("F")) 2224 | 2225 | let xs = [1.0, 2.0, 3.0] 2226 | 2227 | 0.0 2228 | |> tensor.from_float 2229 | |> tensor.to_floats 2230 | |> should.equal([0.0]) 2231 | 2232 | let assert Ok(x) = tensor.from_floats(of: xs, into: d1) 2233 | x 2234 | |> tensor.to_floats 2235 | |> should.equal(xs) 2236 | 2237 | let assert Ok(x) = tensor.from_floats(of: xs, into: d6) 2238 | x 2239 | |> tensor.to_floats 2240 | |> should.equal(xs) 2241 | 2242 | let ys = [1, 2, 3] 2243 | 2244 | 0 2245 | |> tensor.from_int 2246 | |> tensor.to_floats 2247 | |> should.equal([0.0]) 2248 | 2249 | let assert Ok(x) = tensor.from_ints(of: ys, into: d1) 2250 | x 2251 | |> tensor.to_floats 2252 | |> should.equal(xs) 2253 | 2254 | let assert Ok(x) = tensor.from_ints(of: ys, into: d6) 2255 | x 2256 | |> tensor.to_floats 2257 | |> should.equal(xs) 2258 | 2259 | let assert Ok(x) = 2260 | [neg_infinity(), infinity()] 2261 | |> dynamic.from 2262 | |> native_tensor 2263 | |> tensor.from_native(into: d1, with: format.float32()) 2264 | x 2265 | |> tensor.to_floats 2266 | |> should.equal([float32_min, float32_max]) 2267 | } 2268 | 2269 | pub fn to_ints_test() { 2270 | let assert Ok(d1) = space.d1(Infer("A")) 2271 | let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), E(1), Infer("F")) 2272 | 2273 | let xs = [1, 2, 3] 2274 | 2275 | 0 2276 | |> tensor.from_int 2277 | |> tensor.to_ints 2278 | |> should.equal([0]) 2279 | 2280 | let assert Ok(x) = tensor.from_ints(of: xs, into: d1) 2281 | x 2282 | |> tensor.to_ints 2283 | |> should.equal(xs) 2284 | 2285 | let assert Ok(x) = 
tensor.from_ints(of: xs, into: d6) 2286 | x 2287 | |> tensor.to_ints 2288 | |> should.equal(xs) 2289 | 2290 | let ys = [1.0, 2.0, 3.0] 2291 | 2292 | 0.0 2293 | |> tensor.from_float 2294 | |> tensor.to_ints 2295 | |> should.equal([0]) 2296 | 2297 | let assert Ok(x) = tensor.from_floats(of: ys, into: d1) 2298 | x 2299 | |> tensor.to_ints 2300 | |> should.equal(xs) 2301 | 2302 | let assert Ok(x) = tensor.from_floats(of: ys, into: d6) 2303 | x 2304 | |> tensor.to_ints 2305 | |> should.equal(xs) 2306 | 2307 | let assert Ok(x) = 2308 | [neg_infinity(), infinity()] 2309 | |> dynamic.from 2310 | |> native_tensor 2311 | |> tensor.from_native(into: d1, with: format.float32()) 2312 | x 2313 | |> tensor.to_ints 2314 | |> should.equal([int32_min, int32_max]) 2315 | } 2316 | 2317 | pub fn to_bools_test() { 2318 | let assert Ok(d1) = space.d1(Infer("A")) 2319 | let assert Ok(d6) = space.d6(A(1), B(1), C(1), D(1), E(1), Infer("F")) 2320 | 2321 | let xs = [True, False, True] 2322 | 2323 | let ys = [1, 0, -3] 2324 | 2325 | 1 2326 | |> tensor.from_int 2327 | |> tensor.to_bools 2328 | |> should.equal([True]) 2329 | 2330 | let assert Ok(x) = tensor.from_ints(of: ys, into: d1) 2331 | x 2332 | |> tensor.to_bools 2333 | |> should.equal(xs) 2334 | 2335 | let assert Ok(x) = tensor.from_ints(of: ys, into: d6) 2336 | x 2337 | |> tensor.to_bools 2338 | |> should.equal(xs) 2339 | 2340 | let ys = [1.0, 0.0, -3.0] 2341 | 2342 | 0.0 2343 | |> tensor.from_float 2344 | |> tensor.to_bools 2345 | |> should.equal([False]) 2346 | 2347 | let assert Ok(x) = tensor.from_floats(of: ys, into: d1) 2348 | x 2349 | |> tensor.to_bools 2350 | |> should.equal(xs) 2351 | 2352 | let assert Ok(x) = tensor.from_floats(of: ys, into: d6) 2353 | x 2354 | |> tensor.to_bools 2355 | |> should.equal(xs) 2356 | 2357 | let assert Ok(x) = 2358 | [neg_infinity(), infinity()] 2359 | |> dynamic.from 2360 | |> native_tensor 2361 | |> tensor.from_native(into: d1, with: format.float32()) 2362 | x 2363 | |> tensor.to_bools 2364 | |> 
should.equal([True, True]) 2365 | } 2366 | 2367 | pub fn to_native_test() { 2368 | let assert Ok(space) = space.d3(A(2), Infer("B"), C(2)) 2369 | let assert Ok(x) = 2370 | [1, 2, 3, 4, 5, 6, 7, 8] 2371 | |> tensor.from_ints(into: space) 2372 | x 2373 | |> tensor.to_native 2374 | |> native_shape 2375 | |> should.equal(dynamic.from(#(2, 2, 2))) 2376 | } 2377 | 2378 | @external(erlang, "Elixir.Nx", "shape") 2379 | @external(javascript, "../argamak_test_ffi.mjs", "shape") 2380 | fn native_shape(tensor: tensor.Native) -> Dynamic 2381 | 2382 | pub fn to_string_test() { 2383 | 0.0 2384 | |> tensor.from_float 2385 | |> tensor.to_string(return: tensor.Data, wrap_at: 0) 2386 | |> should.equal("0.0") 2387 | 2388 | 0 2389 | |> tensor.from_int 2390 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 2391 | |> should.equal( 2392 | "Tensor( 2393 | Format(Int32), 2394 | Space(), 2395 | 0, 2396 | )", 2397 | ) 2398 | 2399 | let assert Ok(d1) = space.d1(Infer("A")) 2400 | 2401 | let assert Ok(x) = tensor.from_floats(of: [0.0], into: d1) 2402 | x 2403 | |> tensor.to_string(return: tensor.Data, wrap_at: 0) 2404 | |> should.equal("[0.0]") 2405 | 2406 | let assert Ok(x) = tensor.from_ints(of: [0], into: d1) 2407 | x 2408 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 2409 | |> should.equal( 2410 | "Tensor( 2411 | Format(Int32), 2412 | Space(A(1)), 2413 | [0], 2414 | )", 2415 | ) 2416 | 2417 | let assert Ok(x) = tensor.from_ints(of: [101, 3, 225, 4_000_000], into: d1) 2418 | x 2419 | |> tensor.to_string(return: tensor.Data, wrap_at: 30) 2420 | |> should.equal("[ 101, 3, 225,\n 4000000]") 2421 | 2422 | let assert Ok(d2) = space.d2(A(2), B(2)) 2423 | let assert Ok(x) = 2424 | [0.0, 2.25, 0.20000000298023224, -1.0] 2425 | |> tensor.from_floats(into: d2) 2426 | x 2427 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 2428 | |> should.equal( 2429 | "Tensor( 2430 | Format(Float32), 2431 | Space(A(2), B(2)), 2432 | [[ 0.0, 2.25], 2433 | [ 0.2, -1.0]], 2434 | )", 2435 | ) 
2436 | 2437 | let assert Ok(d4) = space.d4(A(1), Infer("B"), C(2), D(2)) 2438 | let assert Ok(x) = 2439 | [1, 2, 3, 44, 5, 6789, 10, 11, 12, 132, 5, 7] 2440 | |> tensor.from_ints(into: d4) 2441 | x 2442 | |> tensor.to_string(return: tensor.Record, wrap_at: 0) 2443 | |> should.equal( 2444 | "Tensor( 2445 | Format(Int32), 2446 | Space(A(1), B(3), C(2), D(2)), 2447 | [[[[ 1, 2], 2448 | [ 3, 44]], 2449 | [[ 5, 6789], 2450 | [ 10, 11]], 2451 | [[ 12, 132], 2452 | [ 5, 7]]]], 2453 | )", 2454 | ) 2455 | 2456 | let assert Ok(d6) = space.d6(A(1), B(1), Infer("C"), D(3), E(3), F(3)) 2457 | let assert Ok(x) = 2458 | 1 2459 | |> list.range(to: 54) 2460 | |> tensor.from_ints(into: d6) 2461 | x 2462 | |> tensor.to_string(return: tensor.Data, wrap_at: 0) 2463 | |> should.equal( 2464 | "[[[[[[ 1, 2, 3], 2465 | [ 4, 5, 6], 2466 | [ 7, 8, 9]], 2467 | [[10, 11, 12], 2468 | [13, 14, 15], 2469 | [16, 17, 18]], 2470 | [[19, 20, 21], 2471 | [22, 23, 24], 2472 | [25, 26, 27]]], 2473 | [[[28, 29, 30], 2474 | [31, 32, 33], 2475 | [34, 35, 36]], 2476 | [[37, 38, 39], 2477 | [40, 41, 42], 2478 | [43, 44, 45]], 2479 | [[46, 47, 48], 2480 | [49, 50, 51], 2481 | [52, 53, 54]]]]]]", 2482 | ) 2483 | 2484 | let assert Ok(x) = 2485 | [neg_infinity(), infinity()] 2486 | |> dynamic.from 2487 | |> native_tensor 2488 | |> tensor.from_native(into: d1, with: format.float32()) 2489 | x 2490 | |> tensor.to_string(return: tensor.Data, wrap_at: 0) 2491 | |> should.equal("[-Infinity, Infinity]") 2492 | } 2493 | 2494 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 2495 | // Private Functions // 2496 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// 2497 | 2498 | pub fn fit_test() { 2499 | let assert Ok(d1) = space.d1(Infer("A")) 2500 | let assert Ok(d2) = space.d2(Infer("A"), B(1)) 2501 | 2502 | 0.0 2503 | |> tensor.from_float 2504 | |> tensor.shape 2505 | |> should.equal([]) 2506 | 2507 | let assert Ok(x) = tensor.from_floats(of: [1.0, 2.0, 3.0], into: d1) 2508 | x 2509 | |> tensor.shape 2510 | |> 
should.equal([3]) 2511 | 2512 | let assert Ok(x) = tensor.reshape(put: x, into: d2) 2513 | x 2514 | |> tensor.shape 2515 | |> should.equal([3, 1]) 2516 | } 2517 | 2518 | fn should_share_native_format(x: Tensor(a)) -> Tensor(a) { 2519 | x 2520 | |> tensor.to_native 2521 | |> native_format 2522 | |> should.equal(format.to_native(tensor.format(x))) 2523 | 2524 | x 2525 | } 2526 | 2527 | @external(erlang, "Elixir.Nx", "type") 2528 | @external(javascript, "../argamak_test_ffi.mjs", "type") 2529 | fn native_format(tensor: tensor.Native) -> format.Native 2530 | 2531 | fn should_loosely_equal(a: List(Float), b: List(Float)) -> Nil { 2532 | a 2533 | |> list.zip(b) 2534 | |> list.map(with: fn(pair) { 2535 | float.loosely_compare(pair.0, with: pair.1, tolerating: 0.002) 2536 | }) 2537 | |> list.all(satisfying: fn(x) { x == Eq }) 2538 | |> should.be_true 2539 | } 2540 | 2541 | type Infinity { 2542 | Infinity 2543 | NegInfinity 2544 | } 2545 | 2546 | @external(javascript, "../argamak_test_ffi.mjs", "infinity") 2547 | fn infinity() -> Dynamic { 2548 | dynamic.from(Infinity) 2549 | } 2550 | 2551 | @external(javascript, "../argamak_test_ffi.mjs", "neg_infinity") 2552 | fn neg_infinity() -> Dynamic { 2553 | dynamic.from(NegInfinity) 2554 | } 2555 | --------------------------------------------------------------------------------