├── .envrc ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── default.nix ├── flake.lock ├── flake.nix ├── tests ├── test.nix └── test.py └── tools └── approximate.py /.envrc: -------------------------------------------------------------------------------- 1 | use flake 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: 'Unit Test' 2 | on: 3 | workflow_dispatch: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | jobs: 9 | tests: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout repository 13 | uses: actions/checkout@v4 14 | 15 | - name: Install nix 16 | uses: cachix/install-nix-action@v31 17 | with: 18 | extra_nix_config: | 19 | experimental-features = nix-command flakes ca-derivations 20 | extra-experimental-features = nix-command flakes ca-derivations 21 | access-tokens = github.com=${{ secrets.GITHUB_TOKEN }} 22 | extra-platforms = i686-linux aarch64-linux arm-linux 23 | log-lines = 25 24 | 25 | - name: Setup GitHub Actions cache for Nix 26 | uses: DeterminateSystems/magic-nix-cache-action@main 27 | 28 | - name: Run unit test 29 | run: | 30 | nix run . 
31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | result* 2 | .vscode 3 | .direnv 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yuhui Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nix-math 2 | 3 | Experimental mathematical library in pure Nix, using no external library. 4 | 5 | ## Why? 6 | 7 | 1. Because I can. 8 | 2. Because I want to approximate the network latency between my servers using their latitudes/longitudes. 
No, I don't want to run `traceroute` between my servers every now and then. 9 | 10 | ## Limitations 11 | 12 | - Nix does not provide some lower level operations, such as bit operations on floating point numbers. This leads to computation inaccuracies, and it's impossible to get exactly the same result as GLibC or `numpy`, even if I have their code as reference. For functions where getting exactly the same result is impossible, I target an error within 0.0001%. 13 | 14 | - This library does not support the following features. Although I added exceptions for some situations, this is by no means comprehensive. Please treat these as undefined behavior, and submit an issue when you encounter one. 15 | - Floating point infinity (+inf, -inf) 16 | - NaN 17 | - Floating point overflow and underflow 18 | - Operations where the value is extremely small (around `1e-38`) or extremely large (around `1e38`) 19 | - Imaginary numbers 20 | 21 | ## Usage 22 | 23 | ```nix 24 | { 25 | inputs = { 26 | nix-math.url = "github:xddxdd/nix-math"; 27 | }; 28 | 29 | outputs = inputs: let 30 | math = inputs.nix-math.lib.math; 31 | in{ 32 | value = math.sin (math.deg2rad 45); 33 | }; 34 | } 35 | ``` 36 | 37 | ## Provided functions 38 | 39 | - `abs [x]`: Absolute value of `x` 40 | - `arange [min] [max] [step]`: Create a list of numbers from `min` (inclusive) to `max` (exclusive), adding `step` each time. 41 | - `arange2 [min] [max] [step]`: Same as `arange`, but includes `max` as well. 42 | - `atan [x]`: Arctangent function. Returns radians. 43 | - `cos [x]`: Trigonometric function. Takes radians as input. 44 | - `deg2rad [x]`: Converts degrees to radians. 45 | - `div [a] [b]`: Floor division: the largest integer not greater than `a / b` (same as Python's `a // b`). 46 | - `exp [x]`: Exponential function. Returns `e^x`. 47 | - `fabs [x]`: Absolute value of `x` 48 | - `factorial [x]`: Returns factorial of `x`. `x` is an integer, `x >= 0`. 49 | - `haversine [lat1] [lon1] [lat2] [lon2]`: Returns the distance between two points on Earth for the given latitude/longitude.
Uses 6371km as Earth radius; the result is in meters. 50 | - `haversine' [radius] [lat1] [lon1] [lat2] [lon2]`: Returns the distance between two points on a sphere of the given radius (in the same unit as `radius`) for the given latitude/longitude. 51 | - `int [x]`: Integer part of `x` (truncates toward zero). 52 | - `ln [x]`: Logarithmic function. Returns `log_e x`. 53 | - `log [a] [b]`: Logarithmic function. Returns `log_a b`. 54 | - `log10 [x]`: Logarithmic function. Returns `log_10 x`. 55 | - `log2 [x]`: Logarithmic function. Returns `log_2 x`. 56 | - `mod [a] [b]`: Remainder of dividing `a` by `b`. The result has the same sign as `b` (same as Python's `a % b`). 57 | - `pow [a] [b]`: Returns `a` to the power of `b`. Supports floating point `b`; a non-integer `b` requires `a >= 0`. 58 | - `round [x]`: Round `x` to the nearest integer. For input ending with `.5`, will round to the nearest even number (same logic as `np.round`). 59 | - `sin [x]`: Trigonometric function. Takes radians as input. 60 | - `sqrt [x]`: Square root of `x`. `x >= 0`. 61 | - `tan [x]`: Trigonometric function. Takes radians as input. 62 | 63 | ## Implementation Details 64 | 65 | - `sin` is implemented with its Taylor series after reducing `x` into `[0, 2*pi)`: for `x >= 0`, `sin(x) = x - x^3/3! + x^5/5! - ...`, and `sin(x) = -sin(-x)` for `x < 0`. Calculation is repeated until the absolute value of the next term in the series is less than epsilon (`1e-10`). 66 | - `cos` is `cos(x) = sin (pi/2 - x)`. 67 | - `tan` is `tan(x) = sin x / cos x`. 68 | - For `sin`, `cos` and `tan`, result error is within 0.0001% as checked by unit test. 69 | - `atan` is implemented with a polynomial approximation on `[0, 1]`; other inputs use `atan(x) = -atan(-x)` and `atan(x) = pi/2 - atan(1/x)`. This is faster and more accurate than using its Taylor series, because its Taylor series does not converge fast enough and may cause a "max-call-depth exceeded" error. 70 | - For `atan`, result error is within 0.0001%. 71 | - `sqrt` is implemented with Newton's method. Iteration is repeated until two successive estimates differ by less than epsilon (`1e-10`). 72 | - For `sqrt`, result error is within `1e-10`. 73 | - `exp` is implemented as `e^x = e^int(x) * P(x - int(x))`, where `P` is a polynomial approximation of `e^t` for `t` in `[0, 1)` (negative `x` uses `e^x = 1 / e^-x`).
This is faster and more accurate than using its Taylor series, because its Taylor series does not converge fast enough and may cause a "max-call-depth exceeded" error. 74 | - For `exp`, result error is within 0.0001%. 75 | - `ln` (natural logarithm) is implemented with its Taylor series: 76 | - For `1 <= x <= 1.9`, `ln(x) = (x-1)/1 - (x-1)^2/2 + (x-1)^3/3 - ...`. Calculation is repeated until the absolute value of the next term in the series is less than epsilon (`1e-10`). 77 | - For `x > 1.9`, `ln(x) = 2 * ln(sqrt(x))` 78 | - For `0 < x < 1`, `ln(x) = -ln(1/x)` 79 | - Although the Taylor series converges for `0 < x <= 2`, calculation outside `1 <= x <= 1.9` is very slow and may cause a max-call-depth exceeded error. 80 | - For `ln`, result error is within 0.0001%. 81 | - `pow` is `pow(x, y) = exp(y * ln(x))`; when `y` is an integer (within epsilon), repeated multiplication is used instead. 82 | - `log` is `log(x, y) = ln(y) / ln(x)`. 83 | - `log2` is `log2(x) = ln(x) / ln(2)`. 84 | - `log10` is `log10(x) = ln(x) / ln(10)`. 85 | - For `pow`, `log`, `log2`, `log10`, result error is within 0.0001%. 86 | - `haversine` is implemented based on <https://stackoverflow.com/questions/27928/calculate-distance-between-two-latitude-longitude-points-haversine-formula>. 87 | 88 | ## Unit test 89 | 90 | The unit test is defined in `tests/test.py`. It invokes `tests/test.nix`, which evaluates the mathematical functions over a range of inputs, and compares the output to the same function from NumPy. 91 | 92 | To run the unit test: 93 | 94 | ```bash 95 | nix run . 96 | ``` 97 | 98 | ## License 99 | 100 | MIT. 101 | -------------------------------------------------------------------------------- /default.nix: -------------------------------------------------------------------------------- 1 | { lib, ...
}: 2 | rec { 3 | inherit (builtins) floor ceil; 4 | 5 | pi = 3.14159265358979323846264338327950288; 6 | e = 2.718281828459045235360287471352; 7 | epsilon = _pow_int (0.1) 10; 8 | 9 | sum = builtins.foldl' builtins.add 0; 10 | multiply = builtins.foldl' builtins.mul 1; 11 | 12 | # Absolute value of `x` 13 | abs = x: if x < 0 then 0 - x else x; 14 | 15 | # Absolute value of `x` 16 | fabs = abs; 17 | 18 | # Returns `x` to the power of `times`. **Only supports integer `times`!** 19 | # Internal use only. Users should use `pow`, which supports floating point exponents. 20 | _pow_int = 21 | x: times: 22 | if times == 0 then 23 | 1 24 | else if times < 0 then 25 | 1.0 / (_pow_int x (0 - times)) # `1.0 /` forces float division; `1 / 4` would truncate to 0 for an integer `x` 26 | else 27 | multiply (lib.replicate times x); 28 | 29 | # Create a list of numbers from `min` (inclusive) to `max` (exclusive), adding `step` each time. 30 | arange = 31 | min: max: step: 32 | let 33 | count = floor ((max - min) / step); # NOTE: with a floating-point `step`, rounding can leave the quotient just below an integer and drop the last element 34 | in 35 | lib.genList (i: min + step * i) count; 36 | 37 | # Create a list of numbers from `min` (inclusive) to `max` (inclusive), adding `step` each time. See the rounding NOTE on `arange`. 38 | arange2 = 39 | min: max: step: 40 | arange min (max + step) step; 41 | 42 | # Calculate x^0*poly[0] + x^1*poly[1] + ... + x^n*poly[n] 43 | polynomial = 44 | x: poly: 45 | let 46 | step = i: (_pow_int x i) * (builtins.elemAt poly i); 47 | in 48 | sum (lib.genList step (builtins.length poly)); 49 | 50 | parseFloat = builtins.fromJSON; 51 | 52 | int = x: if x < 0 then -int (0 - x) else builtins.floor x; # Integer part of `x`, truncating toward zero 53 | 54 | round = # Round to the nearest integer; ties (within `epsilon` of .5) go to the even neighbour, like `np.round` 55 | x: 56 | let 57 | intPart = builtins.floor x; 58 | intIsEven = 0 == mod intPart 2; 59 | fractionPart = x - intPart; 60 | in 61 | if abs (fractionPart - 0.5) < epsilon then 62 | if intIsEven then intPart else intPart + 1 63 | else if fractionPart < 0.5 then 64 | intPart 65 | else 66 | intPart + 1; 67 |
68 | # Whether `x` has a non-zero fractional part (always false for integers). 69 | hasFraction = 70 | x: 71 | # Compare against `floor` directly instead of inspecting `toString x`: 72 | # `toString` rounds floats to 6 decimals, so a quotient such as 73 | # -2.9999999 printed as "-3.000000" and `div` treated it as exact 74 | # (returning -2 instead of -3). Nix compares ints and floats 75 | # numerically, so `floor 2.0 != 2.0` is false. 76 | builtins.floor x 77 | != x; 78 | 79 | # Floor division: the largest integer not greater than `a / b` (same as Python's `a // b`). 80 | div = 81 | a: b: 82 | let 83 | divideExactly = !(hasFraction (1.0 * a / b)); 84 | offset = if divideExactly then 0 else (0 - 1); # for a non-integer quotient q, floor(-q) = -floor(q) - 1 85 | in 86 | if b < 0 then 87 | offset - div a (0 - b) 88 | else if a < 0 then 89 | offset - div (0 - a) b 90 | else 91 | floor (1.0 * a / b); 92 | 93 | # Remainder of dividing `a` by `b`; the result has the same sign as `b` (same as Python's `a % b`). 94 | mod = 95 | a: b: 96 | if b < 0 then 97 | 0 - mod (0 - a) (0 - b) 98 | else if a < 0 then 99 | mod (b - mod (0 - a) b) b 100 | else 101 | a - b * (div a b); 102 | 103 | # Returns factorial of `x`. `x` is an integer, `x >= 0`. 104 | factorial = x: multiply (lib.range 1 x); 105 | 106 | # Trigonometric function. Takes radians as input. 107 | # Taylor series on x reduced into [0, 2*pi): sin(x) = x - x^3/3! + x^5/5! - ...; negative x uses sin(x) = -sin(-x). 108 | sin = 109 | x: 110 | let 111 | x' = mod (1.0 * x) (2 * pi); 112 | step = i: (_pow_int (0 - 1) (i - 1)) * multiply (lib.genList (j: x' / (j + 1)) (i * 2 - 1)); 113 | helper = 114 | tmp: i: 115 | let 116 | value = step i; 117 | in 118 | if (fabs value) < epsilon then tmp else helper (tmp + value) (i + 1); 119 | in 120 | if x < 0 then -sin (0 - x) else helper 0 1; 121 | 122 | # Trigonometric function. Takes radians as input. 123 | cos = x: sin (0.5 * pi - x); 124 | 125 | # Trigonometric function. Takes radians as input. 126 | tan = x: (sin x) / (cos x); 127 | 128 | # Arctangent function. Polynomial approximation.
129 | atan = 130 | x: 131 | let 132 | poly = builtins.fromJSON "[-3.45783607234591e-15, 0.99999999999744, 5.257304414192212e-10, -0.33333336391488594, 8.433269318729302e-07, 0.1999866363777591, 0.00013446991236889277, -0.14376659407790288, 0.00426000182788111, 0.097197156521193, 0.030912220117352136, -0.133167493353323, 0.020663690408239177, 0.11398478735740854, -0.06791459641806276, -0.06663597903061667, 0.11580255232044795, -0.07215375057397233, 0.022284945086684438, -0.0028573630133916046]"; 133 | in 134 | if x < 0 then 135 | -atan (0 - x) 136 | else if x > 1 then # atan(x) = pi/2 - atan(1/x) keeps the polynomial's input within [0, 1] 137 | pi / 2 - atan (1.0 / x) # `1.0 /` avoids integer division, which made e.g. `atan 2` evaluate `atan 0` 138 | else 139 | polynomial x poly; 140 | 141 | # Exponential function: e^x = e^int(x) * P(x - int(x)), where P approximates e^t for t in [0, 1); negative x uses 1 / e^-x. 142 | # (https://github.com/nadavrot/fast_log) 143 | exp = 144 | x: 145 | let 146 | x_int = int x; 147 | x_decimal = x - x_int; 148 | decimal_poly = builtins.fromJSON "[0.9999999999999997, 0.9999999999999494, 0.5000000000013429, 0.16666666664916754, 0.04166666680065545, 0.008333332669176907, 0.001388891142716621, 0.00019840730702746657, 2.481076351588151e-05, 2.744709498016379e-06, 2.846575263734758e-07, 2.0215584670370862e-08, 3.542885385105854e-09]"; 149 | in 150 | if x < 0 then 1 / (exp (0 - x)) else (_pow_int e x_int) * (polynomial x_decimal decimal_poly); 151 | 152 | # Natural logarithm (base e).
153 | # Taylor series: for 1 <= x <= 1.9, ln(x) = (x-1)/1 - (x-1)^2/2 + (x-1)^3/3 - ... 154 | # (https://en.wikipedia.org/wiki/Logarithm#Taylor_series) 155 | # For x > 1.9, ln(x) = 2 * ln(sqrt(x)) 156 | # For 0 < x < 1, ln(x) = -ln(1/x) 157 | # 158 | # Although the Taylor series converges for 0 < x <= 2, calculation outside 159 | # 1 <= x <= 1.9 is very slow and may cause max-call-depth exceeded 160 | ln = 161 | x: 162 | let 163 | step = i: (_pow_int (0 - 1) (i - 1)) * (_pow_int (1.0 * x - 1.0) i) / i; 164 | helper = 165 | tmp: i: 166 | let 167 | value = step i; 168 | in 169 | if (fabs value) < epsilon then tmp else helper (tmp + value) (i + 1); 170 | in 171 | if x <= 0 then 172 | throw "ln(x<=0) returns invalid value" 173 | else if x < 1 then 174 | -ln (1 / x) 175 | else if x > 1.9 then 176 | 2 * (ln (sqrt x)) 177 | else 178 | helper 0 1; 179 | 180 | # Power function that supports float. 181 | # pow(x, y) = exp(y * ln(x)); exponents within `epsilon` of an integer use `_pow_int` instead. 182 | pow = 183 | x: times: 184 | let 185 | is_int_times = abs (times - int times) < epsilon; 186 | in 187 | if is_int_times then 188 | _pow_int x (int times) 189 | else if x == 0 then 190 | 0 191 | else if x < 0 then 192 | throw "Calculating power of negative base and decimal exponential is not supported" 193 | else 194 | exp (times * ln x); 195 | 196 | log = base: x: (ln x) / (ln base); # log_base(x) 197 | log2 = log 2; 198 | log10 = log 10; 199 | 200 | # Degrees to radians. 201 | deg2rad = x: x * pi / 180; 202 | 203 | # Square root of `x` using Newton's method. `x >= 0`; inputs `x <= 0` return 0. 204 | sqrt = 205 | x: 206 | let 207 | helper = 208 | tmp: 209 | let 210 | value = (tmp + 1.0 * x / tmp) / 2; 211 | in 212 | if (fabs (value - tmp)) < epsilon then value else helper value; 213 | in 214 | if x <= 0 then 0 else helper (1.0 * x); # only non-positive inputs short-circuit: a cutoff at `epsilon` returned 0 for small positive `x` such as 1e-12 (root 1e-6) 215 | 216 | # Returns distance (in meters, using a 6371000 m radius) between two points on Earth for the given latitude/longitude (in degrees).
217 | # https://stackoverflow.com/questions/27928/calculate-distance-between-two-latitude-longitude-points-haversine-formula 218 | haversine = haversine' 6371000; 219 | haversine' = 220 | radius: lat1: lon1: lat2: lon2: 221 | let 222 | rad_lat = deg2rad ((1.0 * lat2) - (1.0 * lat1)); 223 | rad_lon = deg2rad ((1.0 * lon2) - (1.0 * lon1)); 224 | a = 225 | (sin (rad_lat / 2)) * (sin (rad_lat / 2)) 226 | + 227 | (cos (deg2rad (1.0 * lat1))) 228 | * (cos (deg2rad (1.0 * lat2))) 229 | * (sin (rad_lon / 2)) 230 | * (sin (rad_lon / 2)); 231 | # Clamp floating-point rounding error so both square roots below see values in [0, 1]. 232 | a' = if a < 0 then 0 else if a > 1 then 1 else a; 233 | # c = 2 * atan2(sqrt a, sqrt (1 - a)). Dividing by sqrt (1 - a) fails near antipodal 234 | # points (a -> 1), so for a > 0.5 use the equivalent pi - 2 * atan (sqrt (1 - a) / sqrt a), 235 | # whose denominator is at least sqrt 0.5. 236 | c = 237 | if a' > 0.5 then 238 | pi - 2 * atan ((sqrt (1 - a')) / (sqrt a')) 239 | else 240 | 2 * atan ((sqrt a') / (sqrt (1 - a'))); 241 | result = radius * c; 242 | in 243 | if result < 0 then 0 else result; 244 | } 245 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-parts": { 4 | "inputs": { 5 | "nixpkgs-lib": "nixpkgs-lib" 6 | }, 7 | "locked": { 8 | "lastModified": 1733312601, 9 | "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", 10 | "owner": "hercules-ci", 11 | "repo": "flake-parts", 12 | "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "hercules-ci", 17 | "repo": "flake-parts", 18 | "type": "github" 19 | } 20 | }, 21 | "nixpkgs": { 22 | "locked": { 23 | "lastModified": 1735005654, 24 | "narHash": "sha256-62vInUmo5WS+0ZXhobdeMUdr1GHYfQHb2YasELZ6kk0=", 25 | "owner": "NixOS", 26 | "repo": "nixpkgs", 27 | "rev": "0a4b89adfe914aa10c33eaee34c93ea698a4ee2b", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "NixOS", 32 | "ref": "nixos-unstable-small", 33 | "repo": "nixpkgs", 34 | "type": "github" 35 | } 36 | }, 37 | "nixpkgs-lib": { 38 | "locked": { 39 | "lastModified": 1733096140, 40 | "narHash": "sha256-1qRH7uAUsyQI7R1Uwl4T+XvdNv778H0Nb5njNrqvylY=", 41 | "type": "tarball", 42 | "url":
"https://github.com/NixOS/nixpkgs/archive/5487e69da40cbd611ab2cadee0b4637225f7cfae.tar.gz" 43 | }, 44 | "original": { 45 | "type": "tarball", 46 | "url": "https://github.com/NixOS/nixpkgs/archive/5487e69da40cbd611ab2cadee0b4637225f7cfae.tar.gz" 47 | } 48 | }, 49 | "root": { 50 | "inputs": { 51 | "flake-parts": "flake-parts", 52 | "nixpkgs": "nixpkgs" 53 | } 54 | } 55 | }, 56 | "root": "root", 57 | "version": 7 58 | } 59 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | inputs = { 3 | nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small"; 4 | flake-parts.url = "github:hercules-ci/flake-parts"; 5 | }; 6 | 7 | outputs = 8 | { 9 | self, 10 | nixpkgs, 11 | flake-parts, 12 | ... 13 | }@inputs: 14 | flake-parts.lib.mkFlake { inherit inputs; } { 15 | systems = [ 16 | "x86_64-linux" 17 | "aarch64-linux" 18 | ]; 19 | 20 | flake = { 21 | lib.math = import ./default.nix { inherit (nixpkgs) lib; }; 22 | test.mathOutput = import ./tests/test.nix { 23 | inherit (nixpkgs) lib; 24 | inherit (self.lib) math; 25 | }; 26 | }; 27 | 28 | perSystem = 29 | { 30 | config, 31 | system, 32 | pkgs, 33 | ... 34 | }: 35 | let 36 | python3 = pkgs.python3.withPackages ( 37 | ps: with ps; [ 38 | numpy 39 | pytest 40 | ] 41 | ); 42 | in 43 | { 44 | apps.default = { 45 | type = "app"; 46 | program = builtins.toString ( 47 | pkgs.writeShellScript "test" '' 48 | set -euo pipefail 49 | exec ${python3}/bin/python3 -m pytest --verbose ${self}/tests/test.py 50 | '' 51 | ); 52 | }; 53 | 54 | devShells.default = python3.env; 55 | }; 56 | }; 57 | } 58 | -------------------------------------------------------------------------------- /tests/test.nix: -------------------------------------------------------------------------------- 1 | { 2 | lib, 3 | math, 4 | ... 
5 | }: 6 | let 7 | testOnInputs = 8 | inputs: fn: 9 | builtins.listToAttrs ( 10 | builtins.map (v: { 11 | # toJSON will output float upto precision limit 12 | # toString will round the float, causing imprecisions during check 13 | # 14 | # Example: 15 | # 16 | # nix-repl> builtins.toJSON (0-10+0.001*9985) 17 | # "-0.015000000000000568" 18 | # 19 | # nix-repl> builtins.toString (0-10+0.001*9985) 20 | # "-0.015000" 21 | name = builtins.toJSON v; 22 | value = fn v; 23 | }) inputs 24 | ); 25 | 26 | testRange = 27 | min: max: step: 28 | testOnInputs (math.arange2 min max step); 29 | 30 | tests = { 31 | "atan" = testRange (0 - 10) 10 0.001 math.atan; 32 | "cos" = testRange (0 - 10) 10 0.001 math.cos; 33 | "deg2rad" = testRange (0 - 360) 360 0.001 math.deg2rad; 34 | "div_-2.5" = testRange (0 - 10) 10 0.001 (x: math.div x (0 - 2.5)); 35 | "div_3_int" = testOnInputs (builtins.genList (x: x - 5) 11) (x: math.div x 3); 36 | "div_3" = testRange (0 - 10) 10 0.001 (x: math.div x 3); 37 | "div_4.5" = testRange (0 - 10) 10 0.001 (x: math.div x 4.5); 38 | "exp_large" = testRange (0 - 700) 700 0.1 math.exp; 39 | "exp_small" = testRange (0 - 2) 2 0.001 math.exp; 40 | "fabs" = testRange (0 - 2) 2 0.001 math.fabs; 41 | "factorial" = testRange 0 10 1 math.factorial; 42 | "int" = testRange (0 - 10) 10 0.001 math.int; 43 | "ln_large" = testRange 1 10000 1 math.ln; 44 | "ln_small" = testRange 0.001 2 0.001 math.ln; 45 | "log10" = testRange 1 10000 1 math.log10; 46 | "log2" = testRange 1 10000 1 math.log2; 47 | "mod_-2.5" = testRange (0 - 10) 10 0.001 (x: math.mod x (0 - 2.5)); 48 | "mod_3_int" = testOnInputs (builtins.genList (x: x - 5) 11) (x: math.mod x 3); 49 | "mod_3" = testRange (0 - 10) 10 0.001 (x: math.mod x 3); 50 | "mod_4.5" = testRange (0 - 10) 10 0.001 (x: math.mod x 4.5); 51 | "pow_-2.5_x" = testRange 1 100 1 (math.pow (0 - 2.5)); 52 | # Avoid `pow 0 -2` since that is undefined 53 | "pow_x_-2_positive" = testRange 0.001 10 0.001 (x: math.pow x (0 - 2)); 54 | 
"pow_x_-2_negative" = testRange (0 - 10) (0 - 0.001) 0.001 (x: math.pow x (0 - 2)); 55 | "pow_x_0" = testRange (0 - 10) 10 0.001 (x: math.pow x 0); 56 | "pow_x_3" = testRange (0 - 10) 10 0.001 (x: math.pow x 3); 57 | "pow_4.5_x" = testRange 1 100 0.01 (math.pow 4.5); 58 | "round" = testRange (0 - 10) 10 0.001 math.round; 59 | "sin" = testRange (0 - 10) 10 0.001 math.sin; 60 | "sqrt" = testRange 0 10 0.001 math.sqrt; 61 | "tan" = testRange (0 - 10) 10 0.001 math.tan; 62 | }; 63 | in 64 | lib.mapAttrs (k: builtins.toJSON) tests 65 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from typing import Callable, Dict 4 | import numpy as np 5 | import os 6 | import pytest 7 | import subprocess 8 | 9 | EPSILON = 1e-10 10 | SCRIPT_PATH = os.path.realpath(os.path.dirname(__file__)) 11 | FLAKE_PATH = os.path.realpath(os.path.join(SCRIPT_PATH, "..")) 12 | 13 | 14 | def compare_absolute(epsilon: float): 15 | def fn(expected: float, actual: float) -> bool: 16 | return np.fabs(expected - actual) <= epsilon 17 | 18 | return fn 19 | 20 | 21 | def compare_ratio(error: float, epsilon: float = EPSILON): 22 | def fn(expected: float, actual: float) -> bool: 23 | # Avoid division by zero 24 | if np.fabs(expected) < epsilon and np.fabs(actual) < epsilon: 25 | return True 26 | return np.fabs(actual / expected - 1) < error 27 | 28 | return fn 29 | 30 | 31 | def get_nix_output(test_item: str) -> Dict[str, str]: 32 | return json.loads( 33 | subprocess.check_output( 34 | ["nix", "eval", "--raw", f'{FLAKE_PATH}#test.mathOutput."{test_item}"'] 35 | ) 36 | ) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "test_item,ground_truth,comparator", 41 | [ 42 | ("atan", np.arctan, compare_ratio(0.000001)), 43 | ("cos", np.cos, compare_ratio(0.000001)), 44 | ("deg2rad", np.deg2rad, compare_ratio(0.000001)), 45 | ("div_-2.5", lambda x: x // 
-2.5, compare_absolute(EPSILON)), 46 | ("div_3_int", lambda x: x // 3, compare_absolute(EPSILON)), 47 | ("div_3", lambda x: x // 3, compare_absolute(EPSILON)), 48 | ("div_4.5", lambda x: x // 4.5, compare_absolute(EPSILON)), 49 | ("exp_large", np.exp, compare_ratio(0.000001)), 50 | ("exp_small", np.exp, compare_ratio(0.000001)), 51 | ("fabs", np.fabs, compare_absolute(EPSILON)), 52 | ("factorial", lambda x: math.factorial(int(x)), compare_absolute(EPSILON)), 53 | ("int", int, compare_absolute(EPSILON)), 54 | ("ln_large", np.log, compare_ratio(0.000001)), 55 | ("ln_small", np.log, compare_ratio(0.000001)), 56 | ("log10", np.log10, compare_ratio(0.000001)), 57 | ("log2", np.log2, compare_ratio(0.000001)), 58 | ("mod_-2.5", lambda x: x % -2.5, compare_absolute(EPSILON)), 59 | ("mod_3_int", lambda x: x % 3, compare_absolute(EPSILON)), 60 | ("mod_3", lambda x: x % 3, compare_absolute(EPSILON)), 61 | ("mod_4.5", lambda x: x % 4.5, compare_absolute(EPSILON)), 62 | ("pow_x_0", lambda x: np.power(x, 0), compare_absolute(EPSILON)), 63 | ("pow_x_3", lambda x: np.power(x, 3), compare_absolute(EPSILON)), 64 | ("pow_x_-2_positive", lambda x: np.power(x, -2), compare_absolute(EPSILON)), 65 | ("pow_x_-2_negative", lambda x: np.power(x, -2), compare_absolute(EPSILON)), 66 | ("pow_-2.5_x", lambda x: np.power(-2.5, x), compare_ratio(0.000001)), 67 | ("pow_4.5_x", lambda x: np.power(4.5, x), compare_ratio(0.000001)), 68 | ("round", np.round, compare_absolute(EPSILON)), 69 | ("sin", np.sin, compare_ratio(0.000001)), 70 | ("sqrt", np.sqrt, compare_absolute(EPSILON)), 71 | ("tan", np.tan, compare_ratio(0.000001)), 72 | ], 73 | ) 74 | def test_runner( 75 | test_item: str, 76 | ground_truth: Callable[[float], float], 77 | comparator: Callable[[float, float], float], 78 | ): 79 | test_results = get_nix_output(test_item) 80 | has_failure = False 81 | for input, output in test_results.items(): 82 | expected = ground_truth(float(input)) 83 | if not comparator(expected, float(output)): 84 | 
has_failure = True 85 | print( 86 | f"FAIL: test {test_item} input {input} expected {expected} actual {output}" 87 | ) 88 | if has_failure: 89 | raise RuntimeError("Some items did not pass test") 90 | -------------------------------------------------------------------------------- /tools/approximate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env nix-shell 2 | #!nix-shell -i python3 -p python3 -p python3Packages.numpy 3 | import json 4 | from typing import Callable, Iterable, List, Optional, Tuple 5 | import numpy as np 6 | from numpy.polynomial.polynomial import Polynomial 7 | 8 | EPSILON = 1e-10 9 | 10 | class Approximate: 11 | def __init__(self, fn: Callable[[Iterable[float]], Iterable[float]], linspace: Tuple[float, float, float], max_poly_degrees: Optional[int] = None, target_error_percent: Optional[float]=None): 12 | self.fn = fn 13 | self.linspace = linspace 14 | self.input = np.linspace(*linspace) 15 | self.expected = fn(self.input) 16 | 17 | if not max_poly_degrees and not target_error_percent: 18 | raise ValueError("Either max_poly_degrees or target_error_percent must be set to specify search range") 19 | self.max_poly_degrees = max_poly_degrees 20 | self.target_error_percent = target_error_percent 21 | 22 | def _fit(self, deg: int) -> Tuple[float, Polynomial]: 23 | fit: Polynomial = Polynomial.fit(self.input, self.expected, deg, domain=(self.linspace[0], self.linspace[1]), window=(self.linspace[0], self.linspace[1])) 24 | result = fit(self.input) 25 | error_percent = np.fabs((result - self.expected) / self.expected) 26 | max_error_percent = np.max(error_percent[error_percent < 1e308] * 100) 27 | return max_error_percent, fit 28 | 29 | def _run_max_poly_degrees(self) -> Tuple[float, Polynomial]: 30 | error, poly = self._fit(1) 31 | for deg in range(2, self.max_poly_degrees+1): 32 | _error, _poly = self._fit(deg) 33 | if _error < error: 34 | error = _error 35 | poly = _poly 36 | return error, poly 37 
| 38 | def _run_target_error_percent(self) -> Tuple[float, Polynomial]: 39 | deg = 0 40 | while True: 41 | deg += 1 42 | error, poly = self._fit(deg) 43 | if error <= self.target_error_percent: 44 | return error, poly 45 | 46 | def run(self) -> Tuple[float, Polynomial]: 47 | if self.max_poly_degrees: 48 | return self._run_max_poly_degrees() 49 | elif self.target_error_percent: 50 | return self._run_target_error_percent() 51 | else: 52 | raise NotImplementedError() 53 | 54 | def explain(self) -> Polynomial: 55 | error, poly = self.run() 56 | print(f"Degree: {poly.degree()}") 57 | print(f"Error %: {error}") 58 | print(f"Coefficients: {json.dumps(json.dumps(list(poly.coef)))}") 59 | return poly 60 | 61 | Approximate(np.exp, (0, 1, 10000), max_poly_degrees=100).explain() 62 | # Approximate(np.exp, (0, 1, 10000), target_error_percent=1e-4).explain() 63 | Approximate(np.arctan, (0, 1, 10000), max_poly_degrees=100).explain() 64 | --------------------------------------------------------------------------------