├── site
│   └── readme_as_index
├── .gitignore
├── source
│   └── dstats
│       ├── package.d
│       ├── kerneldensity.d
│       ├── pca.d
│       ├── infotheory.d
│       ├── summary.d
│       ├── random.d
│       └── sort.d
├── dub.json
├── .github
│   └── workflows
│       └── d.yml
└── README.md

--------------------------------------------------------------------------------
/site/readme_as_index:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .dub
2 | dub.selections.json
3 | __*
4 | *.exe
5 | *.[oa]
6 | *-test-library
--------------------------------------------------------------------------------
/source/dstats/package.d:
--------------------------------------------------------------------------------
1 | /**Convenience module that simply publicly imports everything else.*/
2 | /*
3 |  * License:
4 |  * If you think this module is even copyrightable, you have issues.
5 |  */
6 | 
7 | module dstats;
8 | 
9 | public import dstats.alloc, dstats.sort, dstats.base, dstats.cor,
10 |     dstats.distrib, dstats.infotheory, dstats.random,
11 |     dstats.summary, dstats.tests, dstats.regress, dstats.kerneldensity,
12 |     dstats.pca;
--------------------------------------------------------------------------------
/dub.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "dstats",
3 |     "description": "a statistics package for D",
4 |     "authors": ["David Simcha", "Don Clugston"],
5 |     "homepage": "https://github.com/DlangScience/dstats",
6 |     "license": "various",
7 | 
8 |     "buildTypes": {
9 |         "DSddox": {
10 |             "buildOptions": ["syntaxOnly"],
11 |             "dflags": ["-c", "-Df__dummy.html", "-Xfdocs.json"],
12 |             "postBuildCommands": [
13 |                 "rm -rf site/api",
14 |                 "ddox filter --min-protection=Protected docs.json",
15 |                 "ddox generate-html --navigation-type=ModuleTree docs.json site/api"
16 |             ]
17 |         }
18 |     }
19 | }
--------------------------------------------------------------------------------
/.github/workflows/d.yml:
--------------------------------------------------------------------------------
1 | name: D
2 | 
3 | on:
4 |   push:
5 |     branches: [ master ]
6 |   pull_request:
7 |     branches: [ master ]
8 | 
9 | jobs:
10 |   build:
11 |     runs-on: ${{ matrix.os }}
12 |     strategy:
13 |       matrix:
14 |         os: [ubuntu-latest, macos-latest]
15 |         dc: [ldc-latest, dmd-latest]
16 |         include:
17 |           - os: ubuntu-latest
18 |             dc: gdc
19 | 
20 |     steps:
21 |       - uses: actions/checkout@v2
22 | 
23 |       - name: Install D compiler
24 |         if: matrix.dc == 'ldc-latest' || matrix.dc == 'dmd-latest'
25 |         uses: dlang-community/setup-dlang@v1
26 |         with:
27 |           compiler: ${{ matrix.dc }}
28 | 
29 |       - name: Install GDC
30 |         if: matrix.dc == 'gdc'
31 |         run: |
32 |           sudo apt install gdc
33 |           wget https://github.com/dlang/dub/releases/download/v1.23.0/dub-v1.23.0-linux-x86_64.tar.gz
34 |           tar xvf dub-v1.23.0-linux-x86_64.tar.gz
35 | 
36 |       - name: 'Build & Test'
37 |         run: |
38 |           export PATH=$PATH:`pwd` # for GDC build
39 |           # Build the project, with its main file included, without unittests
40 |           dub build
41 |           # Build and run tests, as defined by the `unittest` configuration
42 |           # In this mode, `mainSourceFile` is excluded and `version (unittest)` blocks are included
43 |           # See https://dub.pm/package-format-json.html#configurations
44 |           dub test
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | dstats
2 | ======
3 | 
4 | A statistics library for D, emphasising a middle ground between performance and ease
5 | of use. This repository is a fork of David Simcha's https://github.com/dsimcha/dstats
6 | created to bring the library up to date and enable dub support.
7 | 
8 | ### [Click here for documentation](https://DlangScience.github.io/dstats/api)
9 | 
10 | Building
11 | --------
12 | 
13 | ### dub
14 | Simply add dstats as a dependency in your project's dub.json.
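15 | 
16 | For example, a minimal dependencies section (the version constraint below is
17 | illustrative; use whatever dub version specifier fits your project):
18 | 
19 | ```json
20 | "dependencies": {
21 |     "dstats": "*"
22 | }
23 | ```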
24 | 
25 | The SciD version of dstats is not currently supported in dub.
26 | 
27 | ### manual
28 | 
29 | This library has no mandatory dependencies other than the latest versions of Phobos
30 | and DMD.
31 | To build, simply unpack all the files into an empty directory and run:
32 | 
33 | ```sh
34 | dmd -O -inline -release -lib -ofdstats.lib *.d
35 | ```
36 | 
37 | SciD is an optional dependency, as dstats is slowly being integrated into it.
38 | If used, it enables a few extra features and faster implementations of some
39 | algorithms.
40 | To build with SciD enabled, make sure your SciD directory is in your import path and
41 | run:
42 | 
43 | ```sh
44 | dmd -O -inline -release -lib -ofdstats.lib -version=scid *.d
45 | ```
46 | 
47 | You'll then need to link in your SciD library and your BLAS and LAPACK libraries
48 | when compiling an application that uses dstats.
49 | 
50 | Conventions
51 | -----------
52 | 
53 | 1. A delicate balance between ease of use, flexibility and performance should be maintained.
54 | There are plenty of good libraries for hardcore numerics programmers that emphasize performance
55 | above all else, and plenty of good statistics packages for people who are basically
56 | non-programmers and aren't doing large-scale analyses or analyses in the context of larger
57 | programs. The distribution seems very bimodal. This library targets the middle ground and
58 | recognizes the principles of tradeoffs and diminishing returns with regard to performance,
59 | flexibility and ease of use.
60 | 
61 | 2. Everything should work with the lowest common denominator generic range possible. It's
62 | frustrating to have to write tons of boilerplate code just to translate data from one format
63 | into another. Also, even when the data is already in an array, it often needs to be copied so
64 | that it can be reordered without the reordering being visible to the caller. In these cases,
65 | it can be copied just as easily whether the input data is in the form of an array or some
66 | other range.
67 | 
68 | 3. Throwing exceptions vs. returning NaN: The convention here is that an exception should be
69 | thrown if a primitive parameter (i.e. an int or a float) is not in the acceptable range,
70 | because such values can trivially be checked upfront and should not occur by accident in most
71 | cases, except for bugs internal to dstats. If the errant function parameter is the dataset,
72 | i.e. a range of some kind, then a NaN should be returned, because when doing large-scale
73 | analyses, a few pieces of data are expected to be defective in ways that are not easy to
74 | check upfront and should not halt the whole analysis.
75 | 
76 | In general, this means that dstats.distrib should throw on invalid parameters, and all other
77 | modules should return a NaN. Any other result is most likely a bug. Cases where dstats.tests
78 | calls into dstats.distrib, resulting in thrown exceptions, are unfortunately too common and
79 | need to be fixed. The sketch after this list illustrates the intended convention.
80 | 
81 | 4. License: Each file contains a license header. All modules that are exclusively written by
82 | the main author (David Simcha) are licensed under the Boost license, so that pieces of them
83 | may freely be incorporated into Phobos and attribution is not required for binaries. Some
84 | modules consist of code borrowed from other places and are thus required to conform to the
85 | terms of those licenses. All are under permissive (i.e. non-copyleft) open source licenses,
86 | but some may require binary attribution.
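87 | 
88 | For illustration, a sketch of convention 3 (the calls shown are illustrative;
89 | the exact exception types and the point at which NaN appears vary by module):
90 | 
91 | ```d
92 | import std.math : isNaN;
93 | import dstats;
94 | 
95 | void main() {
96 |     // An out-of-range primitive parameter (here, a negative standard
97 |     // deviation) is trivially checkable upfront, so dstats.distrib throws.
98 |     bool threw = false;
99 |     try {
100 |         auto p = normalPDF(1.0, 0, -1);
101 |     } catch(Exception) {
102 |         threw = true;
103 |     }
104 |     assert(threw);
105 | 
106 |     // A defective observation inside a dataset propagates as NaN rather
107 |     // than halting a large analysis.
108 |     assert(isNaN(scottBandwidth([1.0, 2, double.nan, 4, 5])));
109 | }
110 | ```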
111 | 
112 | Known Problems
113 | --------------
114 | 
115 | https://issues.dlang.org/show_bug.cgi?id=9449 causes a segfault in `dstats.tests.friedmanTest` on the line `Mean[len] colMeans;`. This is a bug in the DMD backend and does not affect ldc or gdc.
--------------------------------------------------------------------------------
/source/dstats/kerneldensity.d:
--------------------------------------------------------------------------------
1 | /**This module contains a small but growing library for performing kernel
2 |  * density estimation.
3 |  *
4 |  * Author: David Simcha
5 |  */
6 | /*
7 |  * License:
8 |  * Boost Software License - Version 1.0 - August 17th, 2003
9 |  *
10 |  * Permission is hereby granted, free of charge, to any person or organization
11 |  * obtaining a copy of the software and accompanying documentation covered by
12 |  * this license (the "Software") to use, reproduce, display, distribute,
13 |  * execute, and transmit the Software, and to prepare derivative works of the
14 |  * Software, and to permit third-parties to whom the Software is furnished to
15 |  * do so, all subject to the following:
16 |  *
17 |  * The copyright notices in the Software and this entire statement, including
18 |  * the above license grant, this restriction and the following disclaimer,
19 |  * must be included in all copies of the Software, in whole or in part, and
20 |  * all derivative works of the Software, unless such copies or derivative
21 |  * works are solely in the form of machine-executable object code generated by
22 |  * a source language processor.
23 |  *
24 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 |  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
27 |  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
28 |  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
29 |  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30 |  * DEALINGS IN THE SOFTWARE.
31 |  */
32 | module dstats.kerneldensity;
33 | 
34 | import std.conv, std.math, std.exception, std.traits, std.range,
35 |     std.array, std.typetuple, dstats.distrib;
36 | 
37 | import std.algorithm : min, max;
38 | 
39 | import dstats.alloc, dstats.base, dstats.summary;
40 | 
41 | version(unittest) {
42 |     import dstats.random, std.stdio;
43 | }
44 | 
45 | /**Estimates densities in the 1-dimensional case. The 1-D case is common
46 |  * enough to be treated separately, since doing so enables some significant
47 |  * optimizations that are otherwise not feasible.
48 |  *
49 |  * Under the hood, this works by binning the data into a large number of bins
50 |  * (currently 1,000), convolving it with the kernel function to smooth it, and
51 |  * then using linear interpolation to evaluate the density estimates. This
52 |  * will produce results that are different from the textbook definition of
53 |  * kernel density estimation, but to an extent that's negligible in most cases.
54 |  * It also means that constructing this object is relatively expensive, but
55 |  * evaluating a density estimate can be done in O(1) time complexity afterwards.
56 |  */
57 | class KernelDensity1D {
58 | private:
59 |     immutable double[] bins;
60 |     immutable double[] cumulative;
61 |     immutable double minElem;
62 |     immutable double maxElem;
63 |     immutable double diffNeg1Nbin;
64 | 
65 | 
66 |     this(immutable double[] bins, immutable double[] cumulative,
67 |          double minElem, double maxElem) {
68 |         this.bins = bins;
69 |         this.cumulative = cumulative;
70 |         this.minElem = minElem;
71 |         this.maxElem = maxElem;
72 |         this.diffNeg1Nbin = bins.length / (maxElem - minElem);
73 |     }
74 | 
75 |     private static double findEdgeBuffer(C)(C kernel) {
76 |         // Search for the approx. point where the kernel's density is 0.001 *
77 |         // what it is at zero.
78 |         immutable zeroVal = kernel(0);
79 |         double ret = 1;
80 | 
81 |         double factor = 4;
82 |         double kernelVal;
83 | 
84 |         do {
85 |             while(kernel(ret) / zeroVal > 1e-3) {
86 |                 ret *= factor;
87 |             }
88 | 
89 |             factor = (factor - 1) / 2 + 1;
90 |             while(kernel(ret) / zeroVal < 1e-4) {
91 |                 ret /= factor;
92 |             }
93 | 
94 |             kernelVal = kernel(ret) / zeroVal;
95 |         } while((kernelVal > 1e-3 || kernelVal < 1e-4) && factor > 1);
96 | 
97 |         return ret;
98 |     }
99 | 
100 | public:
101 |     /**Construct a kernel density estimation object from a callable object.
102 |      * R must be a range of numeric types. C must be a kernel function,
103 |      * delegate, or class or struct with overloaded opCall. The kernel
104 |      * function is assumed to be symmetric about zero, to take its maximum
105 |      * value at zero and to be unimodal.
106 |      *
107 |      * edgeBuffer determines how much space below and above the smallest and
108 |      * largest observed value will be allocated when doing the binning.
109 |      * Values less than reduce!min(range) - edgeBuffer or greater than
110 |      * reduce!max(range) + edgeBuffer will be assigned a density of zero.
111 |      * If this value is left at its default, it will be set to a value at which
112 |      * the kernel is somewhere between 1e-3 and 1e-4 times its value at zero.
113 |      *
114 |      * The bandwidth of the kernel is indirectly selected by parametrizing the
115 |      * kernel function.
116 |      *
117 |      * Examples:
118 |      * ---
119 |      * auto randNums = randArray!rNormal(1_000, 0, 1);
120 |      * auto kernel = parametrize!normalPDF(0, 0.01);
121 |      * auto density = KernelDensity1D.fromCallable(kernel, randNums);
122 |      * writeln(normalPDF(1, 0, 1), " ", density(1));  // Should be about the same.
123 |      * ---
124 |      */
125 |     static KernelDensity1D fromCallable(C, R)
126 |     (scope C kernel, R range, double edgeBuffer = double.nan)
127 |     if(isForwardRange!R && is(typeof(kernel(2.0)) : double)) {
128 |         enum nBin = 1000;
129 |         auto alloc = newRegionAllocator();
130 | 
131 |         uint N = 0;
132 |         double minElem = double.infinity;
133 |         double maxElem = -double.infinity;
134 |         foreach(elem; range) {
135 |             minElem = min(minElem, elem);
136 |             maxElem = max(maxElem, elem);
137 |             N++;
138 |         }
139 | 
140 |         if(isNaN(edgeBuffer)) {
141 |             edgeBuffer = findEdgeBuffer(kernel);
142 |         }
143 |         minElem -= edgeBuffer;
144 |         maxElem += edgeBuffer;
145 | 
146 |         // Using ints here because they convert faster to floats than uints do.
147 |         auto binsRaw = alloc.uninitializedArray!(int[])(nBin);
148 |         binsRaw[] = 0;
149 | 
150 |         foreach(elemRaw; range) {
151 |             double elem = elemRaw - minElem;
152 |             elem /= (maxElem - minElem);
153 |             elem *= nBin;
154 |             auto bin = to!uint(elem);
155 |             if(bin == nBin) {
156 |                 bin--;
157 |             }
158 | 
159 |             binsRaw[bin]++;
160 |         }
161 | 
162 |         // Convolve the binned data with our kernel. Since N is fairly small
163 |         // we'll use a simple O(N^2) algorithm. According to my measurements,
164 |         // this is actually comparable in speed to using an FFT (and a lot
165 |         // simpler and more space efficient) because:
166 |         //
167 |         // 1. We can take advantage of kernel symmetry.
168 |         //
169 |         // 2. We can take advantage of the sparsity of binsRaw. (We don't
170 |         //    need to convolve the zero count bins.)
171 |         //
172 |         // 3. We don't need to do any zero padding to get a non-cyclic
173 |         //    convolution.
174 |         //
175 |         // 4. We don't need to convolve the tails of the kernel function,
176 |         //    where the contribution to the final density estimate would be
177 |         //    negligible.
178 |         auto binsCooked = uninitializedArray!(double[])(nBin);
179 |         binsCooked[] = 0;
180 | 
181 |         auto kernelPoints = alloc.uninitializedArray!(double[])(nBin);
182 |         immutable stepSize = (maxElem - minElem) / nBin;
183 | 
184 |         kernelPoints[0] = kernel(0);
185 |         immutable stopAt = kernelPoints[0] * 1e-10;
186 |         foreach(ptrdiff_t i; 1..kernelPoints.length) {
187 |             kernelPoints[i] = kernel(stepSize * i);
188 | 
189 |             // Don't bother convolving stuff that contributes negligibly.
190 |             if(kernelPoints[i] < stopAt) {
191 |                 kernelPoints = kernelPoints[0..i];
192 |                 break;
193 |             }
194 |         }
195 | 
196 |         foreach(i, count; binsRaw) if(count > 0) {
197 |             binsCooked[i] += kernelPoints[0] * count;
198 | 
199 |             foreach(offset; 1..min(kernelPoints.length, max(i + 1, nBin - i))) {
200 |                 immutable kernelVal = kernelPoints[offset];
201 | 
202 |                 if(i >= offset) {
203 |                     binsCooked[i - offset] += kernelVal * count;
204 |                 }
205 | 
206 |                 if(i + offset < nBin) {
207 |                     binsCooked[i + offset] += kernelVal * count;
208 |                 }
209 |             }
210 |         }
211 | 
212 |         binsCooked[] /= sum(binsCooked);
213 |         binsCooked[] *= nBin / (maxElem - minElem);  // Make it a density.
214 | 
215 |         auto cumulative = uninitializedArray!(double[])(nBin);
216 |         cumulative[0] = binsCooked[0];
217 |         foreach(i; 1..nBin) {
218 |             cumulative[i] = cumulative[i - 1] + binsCooked[i];
219 |         }
220 |         cumulative[] /= cumulative[$ - 1];
221 | 
222 |         return new typeof(this)(
223 |             assumeUnique(binsCooked), assumeUnique(cumulative),
224 |             minElem, maxElem);
225 |     }
226 | 
227 |     /**Construct a kernel density estimator from an alias.*/
228 |     static KernelDensity1D fromAlias(alias kernel, R)
229 |     (R range, double edgeBuffer = double.nan)
230 |     if(isForwardRange!R && is(typeof(kernel(2.0)) : double)) {
231 |         static double kernelFun(double x) {
232 |             return kernel(x);
233 |         }
234 | 
235 |         return fromCallable(&kernelFun, range, edgeBuffer);
236 |     }
237 | 
238 |     /**Construct a kernel density estimator using the default kernel, which is
239 |      * a Gaussian kernel with the Scott bandwidth.
240 |      */
241 |     static KernelDensity1D fromDefaultKernel(R)
242 |     (R range, double edgeBuffer = double.nan)
243 |     if(isForwardRange!R && is(ElementType!R : double)) {
244 |         immutable bandwidth = scottBandwidth(range.save);
245 | 
246 |         double kernel(double x) {
247 |             return normalPDF(x, 0, bandwidth);
248 |         }
249 | 
250 |         return fromCallable(&kernel, range, edgeBuffer);
251 |     }
252 | 
253 |     /**Compute the probability density at a given point.*/
254 |     double opCall(double x) const {
255 |         if(x < minElem || x > maxElem) {
256 |             return 0;
257 |         }
258 | 
259 |         x -= minElem;
260 |         x *= diffNeg1Nbin;
261 | 
262 |         immutable fract = x - floor(x);
263 |         immutable upper = to!size_t(ceil(x));
264 |         immutable lower = to!size_t(floor(x));
265 | 
266 |         if(upper == bins.length) {
267 |             return bins[$ - 1];
268 |         }
269 | 
270 |         immutable ret = fract * bins[upper] + (1 - fract) * bins[lower];
271 |         return max(0, ret);  // Compensate for roundoff
272 |     }
273 | 
274 |     /**Compute the cumulative density, i.e. the integral from -infinity to x.*/
275 |     double cdf(double x) const {
276 |         if(x <= minElem) {
277 |             return 0;
278 |         } else if(x >= maxElem) {
279 |             return 1;
280 |         }
281 | 
282 |         x -= minElem;
283 |         x *= diffNeg1Nbin;
284 | 
285 |         immutable fract = x - floor(x);
286 |         immutable upper = to!size_t(ceil(x));
287 |         immutable lower = to!size_t(floor(x));
288 | 
289 |         if(upper == cumulative.length) {
290 |             return 1;
291 |         }
292 | 
293 |         return fract * cumulative[upper] + (1 - fract) * cumulative[lower];
294 |     }
295 | 
296 |     /**Compute the cumulative density from the rhs, i.e. the integral from
297 |      * x to infinity.
298 |      */
299 |     double cdfr(double x) const {
300 |         // Here, we can get away with just returning 1 - cdf b/c
301 |         // there are inaccuracies several orders of magnitude bigger than
302 |         // the rounding error.
303 |         return 1.0 - cdf(x);
304 |     }
305 | }
306 | 
307 | unittest {
308 |     auto kde = KernelDensity1D.fromCallable(parametrize!normalPDF(0, 1), [0]);
309 |     assert(approxEqual(kde(1), normalPDF(1)));
310 |     assert(approxEqual(kde.cdf(1), normalCDF(1)));
311 |     assert(approxEqual(kde.cdfr(1), normalCDFR(1)));
312 | 
313 |     // This is purely to see if fromAlias works.
314 |     auto cosKde = KernelDensity1D.fromAlias!cos([0], 1);
315 | 
316 |     // Make sure fromDefaultKernel at least instantiates.
317 |     auto defaultKde = KernelDensity1D.fromDefaultKernel([1, 2, 3]);
318 | }
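A brief usage sketch for KernelDensity1D (illustrative only; randArray!rNormal
is the same helper used elsewhere in this file's unittests):

```d
unittest {
    auto data = randArray!rNormal(1_000, 0, 1);
    auto kde = KernelDensity1D.fromDefaultKernel(data);

    // All of the binning and convolution work happens at construction;
    // each density or CDF query afterwards is O(1).
    assert(kde.cdf(0) > 0 && kde.cdf(0) < 1);
    assert(approxEqual(kde.cdf(0) + kde.cdfr(0), 1.0));
}
```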
319 | 
320 | /**Uses Scott's Rule to select the bandwidth of the Gaussian kernel density
321 |  * estimator. This is 1.06 * min(stdev(data), interquartileRange(data) / 1.34) *
322 |  * N ^^ -0.2. R must be a forward range of numeric types.
323 |  *
324 |  * Examples:
325 |  * ---
326 |  * immutable bandwidth = scottBandwidth(data);
327 |  * auto kernel = parametrize!normalPDF(0, bandwidth);
328 |  * auto kde = KernelDensity1D.fromCallable(kernel, data);
329 |  * ---
330 |  *
331 |  * References:
332 |  * Scott, D. W. (1992) Multivariate Density Estimation: Theory, Practice,
333 |  * and Visualization. Wiley.
334 |  */
335 | double scottBandwidth(R)(R data)
336 | if(isForwardRange!R && is(ElementType!R : double)) {
337 | 
338 |     immutable summary = meanStdev(data.save);
339 |     immutable interquartile = interquantileRange(data.save, 0.25) / 1.34;
340 |     immutable sigmaHat = min(summary.stdev, interquartile);
341 | 
342 |     return 1.06 * sigmaHat * (summary.N ^^ -0.2);
343 | }
344 | 
345 | unittest {
346 |     // Values from R.
347 | assert(approxEqual(scottBandwidth([1,2,3,4,5]), 1.14666)); 348 | assert(approxEqual(scottBandwidth([1,2,2,2,2,8,8,8,8]), 2.242446)); 349 | } 350 | 351 | /**Construct an N-dimensional kernel density estimator. This is done using 352 | * the textbook definition of kernel density estimation, since the binning 353 | * and convolving method used in the 1-D case would rapidly become 354 | * unfeasible w.r.t. memory usage as dimensionality increased. 355 | * 356 | * Eventually, a 2-D estimator might be added as another special case, but 357 | * beyond 2-D, bin-and-convolute clearly isn't feasible. 358 | * 359 | * This class can be used for 1-D estimation instead of KernelDensity1D, and 360 | * will work properly. This is useful if: 361 | * 362 | * 1. You can't accept even the slightest deviation from the results that the 363 | * textbook definition of kernel density estimation would produce. 364 | * 365 | * 2. You are only going to evaluate at a few points and want to avoid the 366 | * up-front cost of the convolution used in the 1-D case. 367 | * 368 | * 3. You're using some weird kernel that doesn't meet the assumptions 369 | * required for KernelDensity1D. 370 | */ 371 | class KernelDensity { 372 | private immutable double[][] points; 373 | private double delegate(double[]...) kernel; 374 | 375 | private this(immutable double[][] points) { 376 | this.points = points; 377 | } 378 | 379 | /**Returns the number of dimensions in the estimator.*/ 380 | uint nDimensions() const @property { 381 | // More than uint.max dimensions is absolutely implausible. 382 | assert(points.length <= uint.max); 383 | return cast(uint) points.length; 384 | } 385 | 386 | /**Construct a kernel density estimator from a kernel provided as a callable 387 | * object (such as a function pointer, delegate, or class with overloaded 388 | * opCall). R must be either a range of ranges, multiple ranges passed in 389 | * as variadic arguments, or a single range for the 1D case. Each range 390 | * represents the values of one variable in the joint distribution. 391 | * kernel must accept either an array of doubles or the same number of 392 | * arguments as the number of dimensions, and must return a floating point 393 | * number. 394 | * 395 | * Examples: 396 | * --- 397 | * // Create an estimate of the density of the joint distribution of 398 | * // hours sleep and programming skill. 399 | * auto programmingSkill = [8,6,7,5,3,0,9]; 400 | * auto hoursSleep = [3,6,2,4,3,5,8]; 401 | * 402 | * // Make a 2D Gaussian kernel function with bandwidth 0.5 in both 403 | * // dimensions and covariance zero. 404 | * static double myKernel(double x1, double x2) { 405 | * return normalPDF(x1, 0, 0.5) * normalPDF(x2, 0, 0.5); 406 | * } 407 | * 408 | * auto estimator = KernelDensity.fromCallable 409 | * (&myKernel, programmingSkill, hoursSleep); 410 | * 411 | * // Estimate the density at programming skill 1, 2 hours sleep. 
412 | * auto density = estimator(1, 2); 413 | * --- 414 | */ 415 | static KernelDensity fromCallable(C, R...)(C kernel, R ranges) 416 | if(allSatisfy!(isInputRange, R)) { 417 | auto kernelWrapped = wrapToArrayVariadic(kernel); 418 | 419 | static if(R.length == 1 && isInputRange!(typeof(ranges[0].front))) { 420 | alias ranges[0] data; 421 | } else { 422 | alias ranges data; 423 | } 424 | 425 | double[][] points; 426 | foreach(range; data) { 427 | double[] asDoubles; 428 | 429 | static if(hasLength!(typeof(range))) { 430 | asDoubles = uninitializedArray!(double[])(range.length); 431 | 432 | size_t i = 0; 433 | foreach(elem; range) { 434 | asDoubles[i++] = elem; 435 | } 436 | } else { 437 | auto app = appender(&asDoubles); 438 | foreach(elem; range) { 439 | app.put(elem); 440 | } 441 | } 442 | 443 | points ~= asDoubles; 444 | } 445 | 446 | dstatsEnforce(points.length, 447 | "Can't construct a zero dimensional kernel density estimator."); 448 | 449 | foreach(arr; points[1..$]) { 450 | dstatsEnforce(arr.length == points[0].length, 451 | "All ranges must be the same length to construct a KernelDensity."); 452 | } 453 | 454 | auto ret = new KernelDensity(assumeUnique(points)); 455 | ret.kernel = kernelWrapped; 456 | 457 | return ret; 458 | } 459 | 460 | /**Estimate the density at the point given by x. The variables in X are 461 | * provided in the same order as the ranges were provided for estimation. 462 | */ 463 | double opCall(double[] x...) const { 464 | dstatsEnforce(x.length == points.length, 465 | "Dimension mismatch when evaluating kernel density."); 466 | double sum = 0; 467 | 468 | auto alloc = newRegionAllocator(); 469 | auto dataPoint = alloc.uninitializedArray!(double[])(points.length); 470 | foreach(i; 0..points[0].length) { 471 | foreach(j; 0..points.length) { 472 | dataPoint[j] = x[j] - points[j][i]; 473 | } 474 | 475 | sum += kernel(dataPoint); 476 | } 477 | 478 | sum /= points[0].length; 479 | return sum; 480 | } 481 | } 482 | 483 | unittest { 484 | auto data = randArray!rNormal(100, 0, 1); 485 | auto kernel = parametrize!normalPDF(0, scottBandwidth(data)); 486 | auto kde = KernelDensity.fromCallable(kernel, data); 487 | auto kde1 = KernelDensity1D.fromCallable(kernel, data); 488 | foreach(i; 0..5) { 489 | assert(abs(kde(i) - kde1(i)) < 0.01); 490 | } 491 | 492 | // Make sure example compiles. 493 | auto programmingSkill = [8,6,7,5,3,0,9]; 494 | auto hoursSleep = [3,6,2,4,3,5,8]; 495 | 496 | // Make a 2D Gaussian kernel function with bandwidth 0.5 in both 497 | // dimensions and covariance zero. 498 | static double myKernel(double x1, double x2) { 499 | return normalPDF(x1, 0, 0.5) * normalPDF(x2, 0, 0.5); 500 | } 501 | 502 | auto estimator = KernelDensity.fromCallable 503 | (&myKernel, programmingSkill, hoursSleep); 504 | 505 | // Estimate the density at programming skill 1, 2 hours sleep. 506 | auto density = estimator(1, 2); 507 | 508 | // Test instantiating from functor. 509 | auto foo = KernelDensity.fromCallable(estimator, hoursSleep); 510 | } 511 | 512 | 513 | private: 514 | 515 | double delegate(double[]...) wrapToArrayVariadic(C)(C callable) { 516 | static if(is(C == delegate) || isFunctionPointer!C) { 517 | alias callable fun; 518 | } else { // It's a functor. 519 | alias callable.opCall fun; 520 | } 521 | 522 | alias ParameterTypeTuple!fun params; 523 | static if(params.length == 1 && is(params[0] == double[])) { 524 | // Already in the right form. 
525 | static if(is(C == delegate) && is(ReturnType!C == double)) { 526 | return callable; 527 | } else static if(is(ReturnType!(callable.opCall) == double)) { 528 | return &callable.opCall; 529 | } else { // Need to forward. 530 | double forward(double[] args...) { 531 | return fun(args); 532 | } 533 | 534 | return &forward; 535 | } 536 | } else { // Need to convert to single arguments and forward. 537 | static assert(allSatisfy!(isFloatingPoint, params)); 538 | 539 | double doCall(double[] args...) { 540 | assert(args.length == params.length); 541 | mixin("return fun(" ~ makeCallList(params.length) ~ ");"); 542 | } 543 | 544 | return &doCall; 545 | } 546 | } 547 | 548 | // CTFE function for forwarding elements of an array as single function 549 | // arguments. 550 | string makeCallList(uint N) { 551 | string ret; 552 | foreach(i; 0..N - 1) { 553 | ret ~= "args[" ~ to!string(i) ~ "], "; 554 | } 555 | 556 | ret ~= "args[" ~ to!string(N - 1) ~ "]"; 557 | return ret; 558 | } 559 | -------------------------------------------------------------------------------- /source/dstats/pca.d: -------------------------------------------------------------------------------- 1 | /** 2 | This module contains a basic implementation of principal component analysis, 3 | based on the NIPALS algorithm. This is fast when you only need the first 4 | few components (which is usually the case since PCA's main uses are 5 | visualization and dimensionality reduction). However, convergence slows 6 | drastically after the first few components have been removed and most of 7 | the matrix is just noise. 8 | 9 | References: 10 | 11 | en.wikipedia.org/wiki/Principal_component_analysis#Computing_principal_components_iteratively 12 | 13 | Author: David Simcha 14 | */ 15 | 16 | /* 17 | * License: 18 | * Boost Software License - Version 1.0 - August 17th, 2003 19 | * 20 | * Permission is hereby granted, free of charge, to any person or organization 21 | * obtaining a copy of the software and accompanying documentation covered by 22 | * this license (the "Software") to use, reproduce, display, distribute, 23 | * execute, and transmit the Software, and to prepare derivative works of the 24 | * Software, and to permit third-parties to whom the Software is furnished to 25 | * do so, all subject to the following: 26 | * 27 | * The copyright notices in the Software and this entire statement, including 28 | * the above license grant, this restriction and the following disclaimer, 29 | * must be included in all copies of the Software, in whole or in part, and 30 | * all derivative works of the Software, unless such copies or derivative 31 | * works are solely in the form of machine-executable object code generated by 32 | * a source language processor. 33 | * 34 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 35 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 36 | * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 37 | * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 38 | * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 39 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 40 | * DEALINGS IN THE SOFTWARE. 
41 |  */
42 | module dstats.pca;
43 | 
44 | import std.range, dstats.base, dstats.alloc, std.numeric, std.stdio, std.math,
45 |     std.algorithm, std.array, dstats.summary, dstats.random, std.conv,
46 |     std.exception, dstats.regress, std.traits;
47 | 
48 | /// Result holder
49 | struct PrincipalComponent {
50 |     /// The projection of the data onto the first principal component.
51 |     double[] x;
52 | 
53 |     /// The vector representing the first principal component loadings.
54 |     double[] rotation;
55 | }
56 | 
57 | /**
58 | Sets options for principal component analysis. The default options are
59 | also the values in PrinCompOptions.init.
60 | */
61 | struct PrinCompOptions {
62 |     /// Center each column to zero mean. Default value: true.
63 |     bool zeroMean = true;
64 | 
65 |     /**
66 |     Scale each column to unit variance. Note that, if this option is set to
67 |     true, zeroMean is ignored and the mean of each column is set to zero even
68 |     if zeroMean is false. Default value: false.
69 |     */
70 |     bool unitVariance = false;
71 | 
72 |     /**
73 |     Overwrite input matrix instead of copying. Ignored if the matrix
74 |     passed in does not have assignable, lvalue elements and centering or
75 |     scaling is enabled. Default value: false.
76 |     */
77 |     bool destructive = false;
78 | 
79 |     /**
80 |     Effectively transpose the matrix. If enabled, treat each column as a
81 |     data point and each row as a dimension. If disabled, do the opposite.
82 |     Note that, if this is enabled, each row will be scaled and centered,
83 |     not each column. Default value: false.
84 |     */
85 |     bool transpose = false;
86 | 
87 |     /**
88 |     Relative error at which to stop the optimization procedure. Default: 1e-4
89 |     */
90 |     double relError = 1.0e-4;
91 | 
92 |     /**
93 |     Absolute error at which to stop the optimization procedure. Default: 1e-5
94 |     */
95 |     double absError = 1.0e-5;
96 | 
97 |     /**
98 |     Maximum iterations for the optimization procedure. After this many
99 |     iterations, the algorithm gives up and calls the solution "good enough"
100 |     no matter what. For exploratory analyses, "good enough" solutions
101 |     can sometimes be had quickly by making this value small. Default: uint.max
102 |     */
103 |     uint maxIter = uint.max;
104 | 
105 |     private void doCenterScaleTransposed(R)(R data) {
106 |         foreach(row; data.save) {
107 |             immutable msd = meanStdev(row.save);
108 | 
109 |             foreach(ref elem; row) {
110 |                 // Already checked for whether we're supposed to be normalizing.
111 |                 elem -= msd.mean;
112 |                 if(unitVariance) elem /= msd.stdev;
113 |             }
114 |         }
115 |     }
116 | 
117 |     private void doCenterScale(R)(R data) {
118 |         if(!zeroMean && !unitVariance) return;
119 |         if(data.empty) {
120 |             return;
121 |         }
122 | 
123 |         if(transpose) return doCenterScaleTransposed(data);
124 | 
125 |         auto alloc = newRegionAllocator();
126 |         immutable rowLen = walkLength(data.front.save);
127 | 
128 |         auto summs = alloc.uninitializedArray!(MeanSD[])(rowLen);
129 |         summs[] = MeanSD.init;
130 |         foreach(row; data) {
131 |             size_t i = 0;
132 |             foreach(elem; row) {
133 |                 enforce(i < rowLen, "Matrix must be rectangular for PCA.");
134 |                 summs[i++].put(elem);
135 |             }
136 | 
137 |             enforce(i == rowLen, "Matrix must be rectangular for PCA.");
138 |         }
139 | 
140 |         foreach(row; data) {
141 |             size_t i = 0;
142 |             foreach(ref elem; row) {
143 |                 elem -= summs[i].mean;
144 |                 if(unitVariance) elem /= summs[i].stdev;
145 |                 i++;
146 |             }
147 |         }
148 |     }
149 | }
150 | 
151 | 
152 | /**
153 | Uses expectation-maximization to compute the first principal component of the input matrix.
154 | Since there are a lot of options, they are controlled by a PrinCompOptions 155 | struct. (See above. PrinCompOptions.init contains the default values.) 156 | To have the results returned in a pre-allocated space, pass an explicit value 157 | for buf. 158 | */ 159 | PrincipalComponent firstComponent(Ror)( 160 | Ror data, 161 | PrinCompOptions opts = PrinCompOptions.init, 162 | PrincipalComponent buf = PrincipalComponent.init 163 | ) { 164 | auto alloc = newRegionAllocator(); 165 | 166 | PrincipalComponent doNonDestructive() { 167 | double[][] dataFixed; 168 | 169 | if(opts.transpose) { 170 | dataFixed = transposeDup(data, alloc); 171 | } else { 172 | dataFixed = doubleTempdupMatrix(data, alloc); 173 | } 174 | 175 | opts.transpose = false; // We already transposed if necessary. 176 | opts.doCenterScale(dataFixed); 177 | return firstComponentImpl(dataFixed, buf, opts); 178 | } 179 | 180 | static if(!hasLvalueElements!(ElementType!Ror) || 181 | !hasAssignableElements!(ElementType!Ror)) { 182 | if(opts.zeroMean || opts.unitVariance) { 183 | return doNonDestructive(); 184 | } else { 185 | return firstComponentImpl(data, buf, opts); 186 | } 187 | } else { 188 | if(!opts.destructive) { 189 | return doNonDestructive; 190 | } 191 | 192 | opts.doCenterScale(data); 193 | return firstComponentImpl(data, buf, opts); 194 | } 195 | } 196 | 197 | private PrincipalComponent firstComponentImpl(Ror)( 198 | Ror data, 199 | PrincipalComponent buf, 200 | PrinCompOptions opts 201 | ) { 202 | auto alloc = newRegionAllocator(); 203 | 204 | if(data.empty) return typeof(return).init; 205 | size_t rowLen = walkLength(data.front.save); 206 | size_t colLen = walkLength(data.save); 207 | 208 | immutable transposed = opts.transpose; 209 | if(transposed) swap(rowLen, colLen); 210 | 211 | auto t = alloc.uninitializedArray!(double[])(rowLen); 212 | auto p = (buf.rotation.length >= rowLen) ? 213 | buf.rotation[0..rowLen] : new double[rowLen]; 214 | p[] = 1; 215 | 216 | bool approxEqualOrNotFinite(const double[] a, const double[] b) { 217 | foreach(i; 0..a.length) { 218 | if(!isFinite(a[i]) || !isFinite(b[i])) { 219 | return true; 220 | } else if(!approxEqual(a[i], b[i], opts.relError, opts.absError)) { 221 | return false; 222 | } 223 | } 224 | 225 | return true; 226 | } 227 | 228 | uint iter; 229 | for(; iter < opts.maxIter; iter++) { 230 | t[] = 0; 231 | 232 | if(transposed) { 233 | auto dps = alloc.uninitializedArray!(double[])(colLen); 234 | scope(exit) alloc.freeLast(); 235 | dps[] = 0; 236 | 237 | size_t i = 0; 238 | foreach(row; data.save) { 239 | scope(exit) i++; 240 | 241 | static if(is(typeof(row) : const(double)[])) { 242 | // Take advantage of array ops. 243 | dps[] += p[i] * row[]; 244 | } else { 245 | size_t j = 0; 246 | foreach(elem; row) { 247 | scope(exit) j++; 248 | dps[j] += p[i] * elem; 249 | } 250 | } 251 | } 252 | 253 | i = 0; 254 | foreach(row; data.save) { 255 | scope(exit) i++; 256 | t[i] += dotProduct(row, dps); 257 | } 258 | 259 | } else { 260 | foreach(row; data.save) { 261 | immutable dp = dotProduct(p, row); 262 | static if( is(typeof(row) : const(double)[] )) { 263 | // Use array op optimization if possible. 264 | t[] += row[] * dp; 265 | } else { 266 | size_t i = 0; 267 | foreach(elem; row.save) { 268 | t[i++] += elem * dp; 269 | } 270 | } 271 | } 272 | } 273 | 274 | immutable tMagnitude = magnitude(t); 275 | t[] /= tMagnitude; 276 | 277 | if(approxEqualOrNotFinite(t, p)) { 278 | p[] = t[]; 279 | break; 280 | } 281 | 282 | p[] = t[]; 283 | } 284 | 285 | auto x = (buf.x.length >= colLen) ? 
286 | buf.x[0..colLen] : new double[colLen]; 287 | size_t i = 0; 288 | 289 | if(transposed) { 290 | x[] = 0; 291 | 292 | size_t rowIndex = 0; 293 | foreach(row; data) { 294 | scope(exit) rowIndex++; 295 | size_t colIndex = 0; 296 | 297 | foreach(elem; row) { 298 | scope(exit) colIndex++; 299 | x[colIndex] += p[rowIndex] * elem; 300 | } 301 | } 302 | 303 | } else { 304 | foreach(row; data) { 305 | x[i++] = dotProduct(p, row); 306 | } 307 | } 308 | 309 | return PrincipalComponent(x, p); 310 | } 311 | 312 | /// Used for removeComponent(). 313 | enum Transposed : bool { 314 | 315 | /// 316 | yes = true, 317 | 318 | /// 319 | no = false 320 | } 321 | 322 | /** 323 | Remove the principal component specified by the given rotation vector from 324 | data. data must have assignable elements. Transposed controls whether 325 | rotation is considered a loading for the transposed matrix or the matrix 326 | as-is. 327 | */ 328 | void removeComponent(Ror, R)( 329 | Ror data, 330 | R rotation, 331 | Transposed transposed = Transposed.no 332 | ) { 333 | double[2] regressBuf; 334 | 335 | immutable rotMagNeg1 = 1.0 / magnitude(rotation.save); 336 | 337 | if(transposed) { 338 | auto alloc = newRegionAllocator(); 339 | auto dps = alloc.uninitializedArray!(double[]) 340 | (walkLength(data.front.save)); 341 | dps[] = 0; 342 | 343 | auto r2 = rotation.save; 344 | foreach(row; data.save) { 345 | scope(exit) r2.popFront(); 346 | 347 | size_t j = 0; 348 | 349 | foreach(elem; row) { 350 | scope(exit) j++; 351 | dps[j] += r2.front * elem; 352 | } 353 | } 354 | 355 | dps[] *= rotMagNeg1; 356 | 357 | r2 = rotation.save; 358 | foreach(row; data.save) { 359 | scope(exit) r2.popFront(); 360 | 361 | auto rs = row.save; 362 | for(size_t j = 0; !rs.empty; rs.popFront, j++) { 363 | rs.front = rs.front - r2.front * dps[j]; 364 | } 365 | } 366 | 367 | } else { 368 | foreach(row; data.save) { 369 | immutable dotProd = dotProduct(rotation, row); 370 | immutable coeff = dotProd * rotMagNeg1; 371 | 372 | auto rs = row.save; 373 | auto rots = rotation.save; 374 | while(!rs.empty && !rots.empty) { 375 | scope(exit) { 376 | rs.popFront(); 377 | rots.popFront(); 378 | } 379 | 380 | rs.front = rs.front - rots.front * coeff; 381 | } 382 | } 383 | } 384 | } 385 | 386 | /** 387 | Computes the first N principal components of the matrix. More efficient than 388 | calling firstComponent and removeComponent repeatedly because copying and 389 | transposing, if enabled, only happen once. 390 | */ 391 | PrincipalComponent[] firstNComponents(Ror)( 392 | Ror data, 393 | uint n, 394 | PrinCompOptions opts = PrinCompOptions.init, 395 | PrincipalComponent[] buf = null 396 | ) { 397 | 398 | auto alloc = newRegionAllocator(); 399 | 400 | PrincipalComponent[] doNonDestructive() { 401 | double[][] dataFixed; 402 | 403 | if(opts.transpose) { 404 | dataFixed = transposeDup(data, alloc); 405 | } else { 406 | dataFixed = doubleTempdupMatrix(data, alloc); 407 | } 408 | 409 | opts.transpose = false; // We already transposed if necessary. 
410 |         opts.doCenterScale(dataFixed);
411 |         return firstNComponentsImpl(dataFixed, n, opts, buf);
412 |     }
413 | 
414 |     static if(!hasLvalueElements!(ElementType!Ror) ||
415 |         !hasAssignableElements!(ElementType!Ror)) {
416 |         return doNonDestructive();
417 |     } else {
418 |         if(!opts.destructive) {
419 |             return doNonDestructive();
420 |         }
421 | 
422 |         opts.doCenterScale(data);
423 |         return firstNComponentsImpl(data, n, opts, buf);
424 |     }
425 | }
426 | 
427 | private PrincipalComponent[] firstNComponentsImpl(Ror)(Ror data, uint n,
428 |     PrinCompOptions opts, PrincipalComponent[] buf = null) {
429 | 
430 |     opts.destructive = true;  // We already copied if necessary.
431 |     opts.unitVariance = false;  // Already did this.
432 | 
433 |     buf.length = n;
434 |     foreach(comp; 0..n) {
435 |         if(comp != 0) {
436 |             removeComponent(data, buf[comp - 1].rotation,
437 |                 cast(Transposed) opts.transpose);
438 |         }
439 | 
440 |         buf[comp] = firstComponent(data, opts, buf[comp]);
441 |     }
442 | 
443 |     return buf;
444 | }
445 | 
446 | private double magnitude(R)(R x) {
447 |     return sqrt(reduce!"a + b * b"(0.0, x));
448 | }
449 | 
450 | // Convert a single row to a double[] on the given region allocator.
451 | double[] doubleTempdup(R)(R range, RegionAllocator alloc) {
452 |     return alloc.array(map!(to!double)(range));
453 | }
454 | 
455 | private double[][] doubleTempdupMatrix(R)(R data, RegionAllocator alloc) {
456 |     auto dataFixed = alloc.uninitializedArray!(double[][])
457 |         (data.length);
458 |     foreach(i, ref elem; dataFixed) {
459 |         elem = doubleTempdup(data[i], alloc);
460 |     }
461 | 
462 |     return dataFixed;
463 | }
464 | 
465 | private double[][] transposeDup(Ror)(Ror data, RegionAllocator alloc) {
466 |     if(data.empty) return null;
467 | 
468 |     immutable rowLen = walkLength(data.front.save);
469 |     immutable colLen = walkLength(data.save);
470 |     auto ret = alloc.uninitializedArray!(double[][])(rowLen, colLen);
471 | 
472 |     size_t i = 0;
473 |     foreach(row; data) {
474 |         scope(exit) i++;
475 |         if(i == colLen) break;
476 | 
477 |         size_t j = 0;
478 |         foreach(col; row) {
479 |             scope(exit) j++;
480 |             if(j == rowLen) break;
481 |             ret[j][i] = col;
482 |         }
483 | 
484 |         dstatsEnforce(j == rowLen, "Matrices must be rectangular for PCA.");
485 |     }
486 | 
487 |     dstatsEnforce(i == colLen, "Matrices must be rectangular for PCA.");
488 |     return ret;
489 | }
490 | 
491 | version(unittest) {
492 |     // There are two equally valid answers for PCA that differ only by sign.
493 |     // This tests whether one of them matches the test value.
494 |     bool plusMinusAe(T, U)(T lhs, U rhs) {
495 |         return approxEqual(lhs, rhs) || approxEqual(lhs, map!"-a"(rhs));
496 |     }
497 | }
498 | 
499 | unittest {
500 |     // Values from R's prcomp function. Not testing the 4th component because
501 |     // it's mostly numerical fuzz.
502 | 503 | static double[][] getMat() { 504 | return [[3,6,2,4], [3,6,8,8], [6,7,5,3], [0,9,3,1]]; 505 | } 506 | 507 | auto mat = getMat(); 508 | auto allComps = firstNComponents(mat, 3); 509 | 510 | assert(plusMinusAe(allComps[0].x, [1.19, -5.11, -0.537, 4.45])); 511 | assert(plusMinusAe(allComps[0].rotation, [-0.314, 0.269, -0.584, -0.698])); 512 | 513 | assert(plusMinusAe(allComps[1].x, [0.805, -1.779, 2.882, -1.908])); 514 | assert(plusMinusAe(allComps[1].rotation, [0.912, -0.180, -0.2498, -0.2713])); 515 | 516 | assert(plusMinusAe(allComps[2].x, [2.277, -0.1055, -1.2867, -0.8849])); 517 | assert(plusMinusAe(allComps[2].rotation, [-0.1578, -0.5162, -0.704, 0.461])); 518 | 519 | auto comp1 = firstComponent(mat); 520 | assert(plusMinusAe(comp1.x, allComps[0].x)); 521 | assert(plusMinusAe(comp1.rotation, allComps[0].rotation)); 522 | 523 | // Test transposed. 524 | PrinCompOptions opts; 525 | opts.transpose = true; 526 | const double[][] m2 = mat; 527 | auto allCompsT = firstNComponents(m2, 3, opts); 528 | 529 | assert(plusMinusAe(allCompsT[0].x, [-3.2045, 6.3829695, -0.7227162, -2.455])); 530 | assert(plusMinusAe(allCompsT[0].rotation, [0.3025, 0.05657, 0.25142, 0.91763])); 531 | 532 | assert(plusMinusAe(allCompsT[1].x, [-3.46136, -0.6365, 1.75111, 2.3468])); 533 | assert(plusMinusAe(allCompsT[1].rotation, 534 | [-0.06269096, 0.88643747, -0.4498119, 0.08926183])); 535 | 536 | assert(plusMinusAe(allCompsT[2].x, 537 | [2.895362e-03, 3.201053e-01, -1.631345e+00, 1.308344e+00])); 538 | assert(plusMinusAe(allCompsT[2].rotation, 539 | [0.87140678, -0.14628160, -0.4409721, -0.15746595])); 540 | 541 | auto comp1T = firstComponent(m2, opts); 542 | assert(plusMinusAe(comp1T.x, allCompsT[0].x)); 543 | assert(plusMinusAe(comp1T.rotation, allCompsT[0].rotation)); 544 | 545 | // Test with scaling. 
546 |     opts.unitVariance = true;
547 |     opts.transpose = false;
548 |     auto allCompsScale = firstNComponents(mat, 3, opts);
549 |     assert(plusMinusAe(allCompsScale[0].x,
550 |         [6.878307e-02, -1.791647e+00, -3.733826e-01, 2.096247e+00]));
551 |     assert(plusMinusAe(allCompsScale[0].rotation,
552 |         [-0.3903603, 0.5398265, -0.4767623, -0.5735014]));
553 | 
554 |     assert(plusMinusAe(allCompsScale[1].x,
555 |         [6.804833e-01, -9.412491e-01, 9.231432e-01, -6.623774e-01]));
556 |     assert(plusMinusAe(allCompsScale[1].rotation,
557 |         [0.7355678, -0.2849885, -0.5068900, -0.3475401]));
558 | 
559 |     assert(plusMinusAe(allCompsScale[2].x,
560 |         [9.618048e-01, 1.428492e-02, -8.120905e-01, -1.639992e-01]));
561 |     assert(plusMinusAe(allCompsScale[2].rotation,
562 |         [-0.4925027, -0.5721616, -0.5897120, 0.2869006]));
563 | 
564 |     auto comp1S = firstComponent(m2, opts);
565 |     assert(plusMinusAe(comp1S.x, allCompsScale[0].x));
566 |     assert(plusMinusAe(comp1S.rotation, allCompsScale[0].rotation));
567 | 
568 |     opts.transpose = true;
569 |     auto allTScale = firstNComponents(mat, 3, opts);
570 | 
571 |     assert(plusMinusAe(allTScale[0].x,
572 |         [-1.419319e-01, 2.141908e+00, -8.368606e-01, -1.163116e+00]));
573 |     assert(plusMinusAe(allTScale[0].rotation,
574 |         [0.5361711, -0.2270814, 0.5685768, 0.5810981]));
575 | 
576 |     assert(plusMinusAe(allTScale[1].x,
577 |         [-1.692899e+00, 4.929717e-01, 3.049089e-01, 8.950189e-01]));
578 |     assert(plusMinusAe(allTScale[1].rotation,
579 |         [0.3026505, 0.7906601, -0.3652524, 0.3871047]));
580 | 
581 |     assert(plusMinusAe(allTScale[2].x,
582 |         [2.035977e-01, 2.705193e-02, -9.113051e-01, 6.806556e-01]));
583 |     assert(plusMinusAe(allTScale[2].rotation,
584 |         [0.7333168, -0.3396207, -0.4837054, -0.3360555]));
585 | 
586 |     auto comp1ST = firstComponent(m2, opts);
587 |     assert(plusMinusAe(comp1ST.x, allTScale[0].x));
588 |     assert(plusMinusAe(comp1ST.rotation, allTScale[0].rotation));
589 | 
590 |     void compAll(PrincipalComponent[] lhs, PrincipalComponent[] rhs) {
591 |         assert(lhs.length == rhs.length);
592 |         foreach(i, elem; lhs) {
593 |             assert(plusMinusAe(elem.x, rhs[i].x));
594 |             assert(plusMinusAe(elem.rotation, rhs[i].rotation));
595 |         }
596 |     }
597 | 
598 |     opts.destructive = true;
599 |     auto allDestructive = firstNComponents(mat, 3, opts);
600 |     compAll(allTScale, allDestructive);
601 |     compAll([firstComponent(getMat(), opts)], allDestructive[0..1]);
602 | 
603 |     mat = getMat();
604 |     opts.transpose = false;
605 |     allDestructive = firstNComponents(mat, 3, opts);
606 |     compAll(allDestructive, allCompsScale);
607 |     compAll([firstComponent(getMat(), opts)], allDestructive[0..1]);
608 | 
609 |     mat = getMat();
610 |     opts.unitVariance = false;
611 |     allDestructive = firstNComponents(mat, 3, opts);
612 |     compAll(allDestructive, allComps);
613 |     compAll([firstComponent(getMat(), opts)], allDestructive[0..1]);
614 | 
615 |     mat = getMat();
616 |     opts.transpose = true;
617 |     allDestructive = firstNComponents(mat, 3, opts);
618 |     compAll(allDestructive, allCompsT);
619 |     compAll([firstComponent(getMat(), opts)], allDestructive[0..1]);
620 | }
--------------------------------------------------------------------------------
/source/dstats/infotheory.d:
--------------------------------------------------------------------------------
1 | /**Basic information theory. Joint entropy, mutual information, conditional
2 |  * mutual information. This module uses the base 2 definition of these
3 |  * quantities, i.e., entropy, mutual info, etc. are output in bits.
4 | * 5 | * Author: David Simcha*/ 6 | /* 7 | * License: 8 | * Boost Software License - Version 1.0 - August 17th, 2003 9 | * 10 | * Permission is hereby granted, free of charge, to any person or organization 11 | * obtaining a copy of the software and accompanying documentation covered by 12 | * this license (the "Software") to use, reproduce, display, distribute, 13 | * execute, and transmit the Software, and to prepare derivative works of the 14 | * Software, and to permit third-parties to whom the Software is furnished to 15 | * do so, all subject to the following: 16 | * 17 | * The copyright notices in the Software and this entire statement, including 18 | * the above license grant, this restriction and the following disclaimer, 19 | * must be included in all copies of the Software, in whole or in part, and 20 | * all derivative works of the Software, unless such copies or derivative 21 | * works are solely in the form of machine-executable object code generated by 22 | * a source language processor. 23 | * 24 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 27 | * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 28 | * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 29 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 30 | * DEALINGS IN THE SOFTWARE. 31 | */ 32 | 33 | module dstats.infotheory; 34 | 35 | import std.traits, std.math, std.typetuple, std.functional, std.range, 36 | std.array, std.typecons, std.algorithm; 37 | 38 | import dstats.base, dstats.alloc; 39 | import dstats.summary : sum; 40 | import dstats.distrib : chiSquareCDFR; 41 | 42 | import dstats.tests : toContingencyScore, gTestContingency; 43 | 44 | version(unittest) { 45 | import std.stdio, std.bigint, dstats.tests : gTestObs; 46 | } 47 | 48 | /**This function calculates the Shannon entropy of a forward range that is 49 | * treated as frequency counts of a set of discrete observations. 50 | * 51 | * Examples: 52 | * --- 53 | * double uniform3 = entropyCounts([4, 4, 4]); 54 | * assert(approxEqual(uniform3, log2(3))); 55 | * double uniform4 = entropyCounts([5, 5, 5, 5]); 56 | * assert(approxEqual(uniform4, 2)); 57 | * --- 58 | */ 59 | double entropyCounts(T)(T data) 60 | if(isForwardRange!(T) && doubleInput!(T)) { 61 | auto save = data.save(); 62 | return entropyCounts(save, sum!(T, double)(data)); 63 | } 64 | 65 | double entropyCounts(T)(T data, double n) 66 | if(isIterable!(T)) { 67 | immutable double nNeg1 = 1.0 / n; 68 | double entropy = 0; 69 | foreach(value; data) { 70 | if(value == 0) 71 | continue; 72 | double pxi = cast(double) value * nNeg1; 73 | entropy -= pxi * log2(pxi); 74 | } 75 | return entropy; 76 | } 77 | 78 | unittest { 79 | double uniform3 = entropyCounts([4, 4, 4].dup); 80 | assert(approxEqual(uniform3, log2(3.0))); 81 | double uniform4 = entropyCounts([5, 5, 5, 5].dup); 82 | assert(approxEqual(uniform4, 2)); 83 | assert(entropyCounts([2,2].dup)==1); 84 | assert(entropyCounts([5.1,5.1,5.1,5.1].dup)==2); 85 | assert(approxEqual(entropyCounts([1,2,3,4,5].dup), 2.1492553971685)); 86 | } 87 | 88 | template FlattenType(T...) { 89 | alias FlattenTypeImpl!(T).ret FlattenType; 90 | } 91 | 92 | template FlattenTypeImpl(T...) 
{ 93 | static if(T.length == 0) { 94 | alias TypeTuple!() ret; 95 | } else { 96 | T[0] j; 97 | static if(is(typeof(j._jointRanges))) { 98 | alias TypeTuple!(typeof(j._jointRanges), FlattenType!(T[1..$])) ret; 99 | } else { 100 | alias TypeTuple!(T[0], FlattenType!(T[1..$])) ret; 101 | } 102 | } 103 | } 104 | 105 | private Joint!(FlattenType!(T, U)) flattenImpl(T, U...)(T start, U rest) { 106 | static if(rest.length == 0) { 107 | return start; 108 | } else static if(is(typeof(rest[0]._jointRanges))) { 109 | return flattenImpl(jointImpl(start.tupleof, rest[0]._jointRanges), rest[1..$]); 110 | } else { 111 | return flattenImpl(jointImpl(start.tupleof, rest[0]), rest[1..$]); 112 | } 113 | } 114 | 115 | Joint!(FlattenType!(T)) flatten(T...)(T args) { 116 | static assert(args.length > 0); 117 | static if(is(typeof(args[0]._jointRanges))) { 118 | auto myTuple = args[0]; 119 | } else { 120 | auto myTuple = jointImpl(args[0]); 121 | } 122 | static if(args.length == 1) { 123 | return myTuple; 124 | } else { 125 | return flattenImpl(myTuple, args[1..$]); 126 | } 127 | } 128 | 129 | /**Bind a set of ranges together to represent a joint probability distribution. 130 | * 131 | * Examples: 132 | * --- 133 | * auto foo = [1,2,3,1,1]; 134 | * auto bar = [2,4,6,2,2]; 135 | * auto e = entropy(joint(foo, bar)); // Calculate joint entropy of foo, bar. 136 | * --- 137 | */ 138 | Joint!(FlattenType!(T)) joint(T...)(T args) { 139 | return jointImpl(flatten(args).tupleof); 140 | } 141 | 142 | Joint!(T) jointImpl(T...)(T args) { 143 | return Joint!(T)(args); 144 | } 145 | 146 | /**Iterate over a set of ranges by value in lockstep and return an ObsEnt, 147 | * which is used internally by entropy functions on each iteration.*/ 148 | struct Joint(T...) { 149 | T _jointRanges; 150 | 151 | @property ObsEnt!(ElementsTuple!(T)) front() { 152 | alias ElementsTuple!(T) E; 153 | alias ObsEnt!(E) rt; 154 | rt ret; 155 | foreach(ti, elem; _jointRanges) { 156 | ret.tupleof[ti] = elem.front; 157 | } 158 | return ret; 159 | } 160 | 161 | void popFront() { 162 | foreach(ti, elem; _jointRanges) { 163 | _jointRanges[ti].popFront; 164 | } 165 | } 166 | 167 | @property bool empty() { 168 | foreach(elem; _jointRanges) { 169 | if(elem.empty) { 170 | return true; 171 | } 172 | } 173 | return false; 174 | } 175 | 176 | static if(T.length > 0 && allSatisfy!(hasLength, T)) { 177 | @property size_t length() { 178 | size_t ret = size_t.max; 179 | foreach(range; _jointRanges) { 180 | auto len = range.length; 181 | if(len < ret) { 182 | ret = len; 183 | } 184 | } 185 | return ret; 186 | } 187 | } 188 | } 189 | 190 | template ElementsTuple(T...) { 191 | static if(T.length == 1) { 192 | alias TypeTuple!(Unqual!(ElementType!(T[0]))) ElementsTuple; 193 | } else { 194 | alias TypeTuple!(Unqual!(ElementType!(T[0])), ElementsTuple!(T[1..$])) 195 | ElementsTuple; 196 | } 197 | } 198 | 199 | private template Comparable(T) { 200 | enum bool Comparable = is(typeof({ 201 | T a; 202 | T b; 203 | return a < b; })); 204 | } 205 | 206 | static assert(Comparable!ubyte); 207 | static assert(Comparable!ubyte); 208 | 209 | struct ObsEnt(T...) { 210 | T compRep; 211 | alias compRep this; 212 | 213 | static if(anySatisfy!(hasIndirections, T)) { 214 | 215 | // Then there's indirection involved. We can't just do all our 216 | // comparison and hashing operations bitwise. 
217 | hash_t toHash() const nothrow @trusted { 218 | hash_t sum = 0; 219 | foreach(i, elem; this.tupleof) { 220 | sum *= 11; 221 | static if(is(elem : long) && elem.sizeof <= hash_t.sizeof) { 222 | sum += elem; 223 | } else static if(__traits(compiles, elem.toHash)) { 224 | sum += elem.toHash; 225 | } else { 226 | auto ti = typeid(typeof(elem)); 227 | sum += ti.getHash(&elem); 228 | } 229 | } 230 | return sum; 231 | } 232 | 233 | bool opEquals(const ref typeof(this) rhs) const { 234 | foreach(ti, elem; this.tupleof) { 235 | if(elem != rhs.tupleof[ti]) 236 | return false; 237 | } 238 | return true; 239 | } 240 | } 241 | // Else just use the default runtime functions for hash and equality. 242 | 243 | 244 | static if(allSatisfy!(Comparable, T)) { 245 | int opCmp(const ref typeof(this) rhs) const { 246 | foreach(ti, elem; this.tupleof) { 247 | if(rhs.tupleof[ti] < elem) { 248 | return -1; 249 | } else if(rhs.tupleof[ti] > elem) { 250 | return 1; 251 | } 252 | } 253 | return 0; 254 | } 255 | } 256 | } 257 | 258 | // Whether we can use StackTreeAA, or whether we have to use a regular AA for 259 | // entropy. 260 | package template NeedsHeap(T) { 261 | static if(!hasIndirections!(ForeachType!(T))) { 262 | enum bool NeedsHeap = false; 263 | } else static if(isArray!(T)) { 264 | enum bool NeedsHeap = false; 265 | } else static if(is(Joint!(typeof(T.init.tupleof))) 266 | && is(T == Joint!(typeof(T.init.tupleof))) 267 | && allSatisfy!(isArray, typeof(T.init.tupleof))) { 268 | enum bool NeedsHeap = false; 269 | } else { 270 | enum bool NeedsHeap = true; 271 | } 272 | } 273 | 274 | unittest { 275 | auto foo = filter!"a"(cast(uint[][]) [[1]]); 276 | auto bar = filter!("a")([1,2,3][]); 277 | static assert(NeedsHeap!(typeof(foo))); 278 | static assert(!NeedsHeap!(typeof(bar))); 279 | static assert(NeedsHeap!(Joint!(uint[], typeof(foo)))); 280 | static assert(!NeedsHeap!(Joint!(uint[], typeof(bar)))); 281 | static assert(!NeedsHeap!(Joint!(uint[], uint[]))); 282 | } 283 | 284 | /**Calculates the joint entropy of a set of observations. Each input range 285 | * represents a vector of observations. If only one range is given, this reduces 286 | * to the plain old entropy. Input range must have a length. 287 | * 288 | * Note: This function specializes if ElementType!(T) is a byte, ubyte, or 289 | * char, resulting in a much faster entropy calculation. When possible, try 290 | * to provide data in the form of a byte, ubyte, or char. 291 | * 292 | * Examples: 293 | * --- 294 | * int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3]; 295 | * double entropyFoo = entropy(foo); // Plain old entropy of foo. 296 | * assert(approxEqual(entropyFoo, log2(3))); 297 | * int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3]; 298 | * double HFooBar = entropy(joint(foo, bar)); // Joint entropy of foo and bar. 299 | * assert(approxEqual(HFooBar, log2(9))); 300 | * --- 301 | */ 302 | double entropy(T)(T data) 303 | if(isIterable!(T)) { 304 | static if(!hasLength!(T)) { 305 | return entropyImpl!(uint, T)(data); 306 | } else { 307 | if(data.length <= ubyte.max) { 308 | return entropyImpl!(ubyte, T)(data); 309 | } else if(data.length <= ushort.max) { 310 | return entropyImpl!(ushort, T)(data); 311 | } else { 312 | return entropyImpl!(uint, T)(data); 313 | } 314 | } 315 | } 316 | 317 | private double entropyImpl(U, T)(T data) 318 | if((ForeachType!(T).sizeof > 1 || is(ForeachType!T == struct)) && !NeedsHeap!(T)) { 319 | // Generic version. 
320 | auto alloc = newRegionAllocator(); 321 | alias ForeachType!(T) E; 322 | 323 | static if(hasLength!T) { 324 | auto counts = StackHash!(E, U)(max(20, data.length / 20), alloc); 325 | } else { 326 | auto counts = StackTreeAA!(E, U)(alloc); 327 | } 328 | uint N; 329 | 330 | foreach(elem; data) { 331 | counts[elem]++; 332 | N++; 333 | } 334 | 335 | double ans = entropyCounts(counts.values, N); 336 | return ans; 337 | } 338 | 339 | private double entropyImpl(U, T)(T data) 340 | if(ForeachType!(T).sizeof > 1 && NeedsHeap!(T)) { // Generic version. 341 | alias ForeachType!(T) E; 342 | 343 | uint len = 0; 344 | U[E] counts; 345 | foreach(elem; data) { 346 | len++; 347 | counts[elem]++; 348 | } 349 | return entropyCounts(counts, len); 350 | } 351 | 352 | private double entropyImpl(U, T)(T data) // byte/char specialization 353 | if(ForeachType!(T).sizeof == 1 && !is(ForeachType!T == struct)) { 354 | alias ForeachType!(T) E; 355 | 356 | U[ubyte.max + 1] counts; 357 | 358 | uint min = ubyte.max, max = 0, len = 0; 359 | foreach(elem; data) { 360 | len++; 361 | static if(is(E == byte)) { 362 | // Keep adjacent elements adjacent. In real world use cases, 363 | // probably will have ranges like [-1, 1]. 364 | ubyte e = cast(ubyte) (cast(ubyte) (elem) + byte.max); 365 | } else { 366 | ubyte e = cast(ubyte) elem; 367 | } 368 | counts[e]++; 369 | if(e > max) { 370 | max = e; 371 | } 372 | if(e < min) { 373 | min = e; 374 | } 375 | } 376 | 377 | return entropyCounts(counts.ptr[min..max + 1], len); 378 | } 379 | 380 | unittest { 381 | { // Generic version. 382 | int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3]; 383 | double entropyFoo = entropy(foo); 384 | assert(approxEqual(entropyFoo, log2(3.0))); 385 | int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3]; 386 | auto stuff = joint(foo, bar); 387 | double jointEntropyFooBar = entropy(joint(foo, bar)); 388 | assert(approxEqual(jointEntropyFooBar, log2(9.0))); 389 | } 390 | { // byte specialization 391 | byte[] foo = [-1, -1, -1, 2, 2, 2, 3, 3, 3]; 392 | double entropyFoo = entropy(foo); 393 | assert(approxEqual(entropyFoo, log2(3.0))); 394 | string bar = "ACTGGCTA"; 395 | assert(entropy(bar) == 2); 396 | } 397 | { // NeedsHeap version. 398 | string[] arr = ["1", "1", "1", "2", "2", "2", "3", "3", "3"]; 399 | auto m = map!("a")(arr); 400 | assert(approxEqual(entropy(m), log2(3.0))); 401 | } 402 | } 403 | 404 | /**Calculate the conditional entropy H(data | cond).*/ 405 | double condEntropy(T, U)(T data, U cond) 406 | if(isInputRange!(T) && isInputRange!(U)) { 407 | static if(isForwardRange!U) { 408 | alias cond condForward; 409 | } else { 410 | auto alloc = newRegionAllocator(); 411 | auto condForward = alloc.array(cond); 412 | } 413 | 414 | return entropy(joint(data, condForward.save)) - entropy(condForward.save); 415 | } 416 | 417 | unittest { 418 | // This shouldn't be easy to screw up. Just really basic. 419 | int[] foo = [1,2,2,1,1]; 420 | int[] bar = [1,2,3,1,2]; 421 | assert(approxEqual(entropy(foo) - condEntropy(foo, bar), 422 | mutualInfo(foo, bar))); 423 | } 424 | 425 | private double miContingency(double observed, double expected) { 426 | return (observed == 0) ? 0 : 427 | (observed * log2(observed / expected)); 428 | } 429 | 430 | 431 | /**Calculates the mutual information of two vectors of discrete observations. 
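 *
 * Examples:
 * ---
 * // A small sanity check: pairing a vector with itself reduces I(x; x)
 * // to H(x), here log2(3) bits for three equally frequent values.
 * auto x = [0, 0, 1, 1, 2, 2];
 * assert(approxEqual(mutualInfo(x, x), log2(3.0)));
 * ---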
432 | */ 433 | double mutualInfo(T, U)(T x, U y) 434 | if(isInputRange!(T) && isInputRange!(U)) { 435 | uint xFreedom, yFreedom, n; 436 | typeof(return) ret; 437 | 438 | static if(!hasLength!T && !hasLength!U) { 439 | ret = toContingencyScore!(T, U, uint) 440 | (x, y, &miContingency, xFreedom, yFreedom, n); 441 | } else { 442 | immutable minLen = min(x.length, y.length); 443 | if(minLen <= ubyte.max) { 444 | ret = toContingencyScore!(T, U, ubyte) 445 | (x, y, &miContingency, xFreedom, yFreedom, n); 446 | } else if(minLen <= ushort.max) { 447 | ret = toContingencyScore!(T, U, ushort) 448 | (x, y, &miContingency, xFreedom, yFreedom, n); 449 | } else { 450 | ret = toContingencyScore!(T, U, uint) 451 | (x, y, &miContingency, xFreedom, yFreedom, n); 452 | } 453 | } 454 | 455 | return ret / n; 456 | } 457 | 458 | unittest { 459 | // Values from R, but converted from base e to base 2. 460 | assert(approxEqual(mutualInfo(bin([1,2,3,3,8].dup, 10), 461 | bin([8,6,7,5,3].dup, 10)), 1.921928)); 462 | assert(approxEqual(mutualInfo(bin([1,2,1,1,3,4,3,6].dup, 2), 463 | bin([2,7,9,6,3,1,7,40].dup, 2)), .2935645)); 464 | assert(approxEqual(mutualInfo(bin([1,2,1,1,3,4,3,6].dup, 4), 465 | bin([2,7,9,6,3,1,7,40].dup, 4)), .5435671)); 466 | 467 | } 468 | 469 | /** 470 | Calculates the mutual information of a contingency table representing a joint 471 | discrete probability distribution. Takes a set of finite forward ranges, 472 | one for each column in the contingency table. These can be expressed either as 473 | a tuple of ranges or a range of ranges. 474 | */ 475 | double mutualInfoTable(T...)(T table) { 476 | // This function is really just included to give conceptual unity to 477 | // the infotheory module. 478 | return gTestContingency(table).mutualInfo; 479 | } 480 | 481 | /** 482 | Calculates the conditional mutual information I(x, y | z) from a set of 483 | observations. 484 | */ 485 | double condMutualInfo(T, U, V)(T x, U y, V z) { 486 | auto ret = entropy(joint(x, z)) - entropy(joint(x, y, z)) - entropy(z) 487 | + entropy(joint(y, z)); 488 | return max(ret, 0); 489 | } 490 | 491 | unittest { 492 | // Values from Matlab mi package by Hanchuan Peng. 493 | auto res = condMutualInfo([1,2,1,2,1,2,1,2].dup, [3,1,2,3,4,2,1,2].dup, 494 | [1,2,3,1,2,3,1,2].dup); 495 | assert(approxEqual(res, 0.4387)); 496 | res = condMutualInfo([1,2,3,1,2].dup, [2,1,3,2,1].dup, 497 | joint([1,1,1,2,2].dup, [2,2,2,1,1].dup)); 498 | assert(approxEqual(res, 1.3510)); 499 | } 500 | 501 | /**Calculates the entropy of any old input range of observations more quickly 502 | * than entropy(), provided that all equal values are adjacent. If the input 503 | * is sorted by more than one key, i.e. structs, the result will be the joint 504 | * entropy of all of the keys. The compFun alias will be used to compare 505 | * adjacent elements and determine how many instances of each value exist.*/ 506 | double entropySorted(alias compFun = "a == b", T)(T data) 507 | if(isInputRange!(T)) { 508 | alias ElementType!(T) E; 509 | alias binaryFun!(compFun) comp; 510 | immutable n = data.length; 511 | immutable nrNeg1 = 1.0L / n; 512 | 513 | double sum = 0.0; 514 | int nSame = 1; 515 | auto last = data.front; 516 | data.popFront; 517 | foreach(elem; data) { 518 | if(comp(elem, last)) { 519 | nSame++; 520 | } else { 521 | immutable p = nSame * nrNeg1; 522 | nSame = 1; 523 | sum -= p * log2(p); 524 | } 525 | last = elem; 526 | } 527 | // Handle last run. 
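    // The loop above only flushes a run's probability mass when it sees a
    // differing element, so the final run still needs its -p * log2(p)
    // term added here.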
528 | immutable p = nSame * nrNeg1; 529 | sum -= p * log2(p); 530 | 531 | return sum; 532 | } 533 | 534 | unittest { 535 | uint[] foo = [1U,2,3,1,3,2,6,3,1,6,3,2,2,1,3,5,2,1].dup; 536 | auto sorted = foo.dup; 537 | sort(sorted); 538 | assert(approxEqual(entropySorted(sorted), entropy(foo))); 539 | } 540 | 541 | /** 542 | Much faster implementations of information theory functions for the special 543 | but common case where all observations are integers on the range [0, nBin$(RPAREN). 544 | This is the case, for example, when the observations have been previously 545 | binned using, for example, dstats.base.frqBin(). 546 | 547 | Note that, due to the optimizations used, joint() cannot be used with 548 | the member functions of this struct, except entropy(). 549 | 550 | For those looking for hard numbers, this seems to be on the order of 10x 551 | faster than the generic implementations according to my quick and dirty 552 | benchmarks. 553 | */ 554 | struct DenseInfoTheory { 555 | private uint nBin; 556 | 557 | // Saves space and makes things cache efficient by using the smallest 558 | // integer width necessary for binning. 559 | double selectSize(alias fun, T...)(T args) { 560 | static if(allSatisfy!(hasLength, T)) { 561 | immutable len = args[0].length; 562 | 563 | if(len <= ubyte.max) { 564 | return fun!ubyte(args); 565 | } else if(len <= ushort.max) { 566 | return fun!ushort(args); 567 | } else { 568 | return fun!uint(args); 569 | } 570 | 571 | // For now, assume that noone is going to have more than 572 | // 4 billion observations. 573 | } else { 574 | return fun!uint(args); 575 | } 576 | } 577 | 578 | /** 579 | Constructs a DenseInfoTheory object for nBin bins. The values taken by 580 | each observation must then be on the interval [0, nBin$(RPAREN). 581 | */ 582 | this(uint nBin) { 583 | this.nBin = nBin; 584 | } 585 | 586 | /** 587 | Computes the entropy of a set of observations. Note that, for this 588 | function, the joint() function can be used to compute joint entropies 589 | as long as each individual range contains only integers on [0, nBin$(RPAREN). 590 | */ 591 | double entropy(R)(R range) if(isIterable!R) { 592 | return selectSize!entropyImpl(range); 593 | } 594 | 595 | private double entropyImpl(Uint, R)(R range) { 596 | auto alloc = newRegionAllocator(); 597 | uint n = 0; 598 | 599 | static if(is(typeof(range._jointRanges))) { 600 | // Compute joint entropy. 601 | immutable nRanges = range._jointRanges.length; 602 | auto counts = alloc.uninitializedArray!(Uint[])(nBin ^^ nRanges); 603 | counts[] = 0; 604 | 605 | Outer: 606 | while(true) { 607 | uint multiplier = 1; 608 | uint index = 0; 609 | 610 | foreach(ti, Unused; typeof(range._jointRanges)) { 611 | if(range._jointRanges[ti].empty) break Outer; 612 | immutable rFront = range._jointRanges[ti].front; 613 | assert(rFront < nBin); // Enforce is too costly here. 
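                    // Flatten the joint observation into a single bucket
                    // index by treating the tuple as a base-nBin number:
                    // index = r0 + r1 * nBin + r2 * nBin ^^ 2, and so on,
                    // one digit per range.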
614 | 615 | index += multiplier * cast(uint) rFront; 616 | range._jointRanges[ti].popFront(); 617 | multiplier *= nBin; 618 | } 619 | 620 | counts[index]++; 621 | n++; 622 | } 623 | 624 | return entropyCounts(counts, n); 625 | } else { 626 | auto counts = alloc.uninitializedArray!(Uint[])(nBin); 627 | 628 | counts[] = 0; 629 | foreach(elem; range) { 630 | counts[elem]++; 631 | n++; 632 | } 633 | 634 | return entropyCounts(counts, n); 635 | } 636 | } 637 | 638 | /// I(x; y) 639 | double mutualInfo(R1, R2)(R1 x, R2 y) 640 | if(isIterable!R1 && isIterable!R2) { 641 | return selectSize!mutualInfoImpl(x, y); 642 | } 643 | 644 | private double mutualInfoImpl(Uint, R1, R2)(R1 x, R2 y) { 645 | auto alloc = newRegionAllocator(); 646 | auto joint = alloc.uninitializedArray!(Uint[])(nBin * nBin); 647 | auto margx = alloc.uninitializedArray!(Uint[])(nBin); 648 | auto margy = alloc.uninitializedArray!(Uint[])(nBin); 649 | joint[] = 0; 650 | margx[] = 0; 651 | margy[] = 0; 652 | uint n; 653 | 654 | while(!x.empty && !y.empty) { 655 | immutable xFront = cast(uint) x.front; 656 | immutable yFront = cast(uint) y.front; 657 | assert(xFront < nBin); 658 | assert(yFront < nBin); 659 | 660 | joint[xFront * nBin + yFront]++; 661 | margx[xFront]++; 662 | margy[yFront]++; 663 | n++; 664 | x.popFront(); 665 | y.popFront(); 666 | } 667 | 668 | auto ret = entropyCounts(margx, n) + entropyCounts(margy, n) - 669 | entropyCounts(joint, n); 670 | return max(0, ret); 671 | } 672 | 673 | /** 674 | Calculates the P-value for I(X; Y) assuming x and y both have supports 675 | of [0, nBin$(RPAREN). The P-value is calculated using a Chi-Square approximation. 676 | It is asymptotically correct, but is approximate for finite sample size. 677 | 678 | Parameters: 679 | mutualInfo: I(x; y), in bits 680 | n: The number of samples used to calculate I(x; y) 681 | */ 682 | double mutualInfoPval(double mutualInfo, double n) { 683 | immutable df = (nBin - 1) ^^ 2; 684 | 685 | immutable testStat = mutualInfo * 2 * LN2 * n; 686 | return chiSquareCDFR(testStat, df); 687 | } 688 | 689 | /// H(X | Y) 690 | double condEntropy(R1, R2)(R1 x, R2 y) 691 | if(isIterable!R1 && isIterable!R2) { 692 | return selectSize!condEntropyImpl(x, y); 693 | } 694 | 695 | private double condEntropyImpl(Uint, R1, R2)(R1 x, R2 y) { 696 | auto alloc = newRegionAllocator(); 697 | auto joint = alloc.uninitializedArray!(Uint[])(nBin * nBin); 698 | auto margy = alloc.uninitializedArray!(Uint[])(nBin); 699 | joint[] = 0; 700 | margy[] = 0; 701 | uint n; 702 | 703 | while(!x.empty && !y.empty) { 704 | immutable xFront = cast(uint) x.front; 705 | immutable yFront = cast(uint) y.front; 706 | assert(xFront < nBin); 707 | assert(yFront < nBin); 708 | 709 | joint[xFront * nBin + yFront]++; 710 | margy[yFront]++; 711 | n++; 712 | x.popFront(); 713 | y.popFront(); 714 | } 715 | 716 | auto ret = entropyCounts(joint, n) - entropyCounts(margy, n); 717 | return max(0, ret); 718 | } 719 | 720 | /// I(X; Y | Z) 721 | double condMutualInfo(R1, R2, R3)(R1 x, R2 y, R3 z) 722 | if(allSatisfy!(isIterable, R1, R2, R3)) { 723 | return selectSize!condMutualInfoImpl(x, y, z); 724 | } 725 | 726 | private double condMutualInfoImpl(Uint, R1, R2, R3)(R1 x, R2 y, R3 z) { 727 | auto alloc = newRegionAllocator(); 728 | immutable nBinSq = nBin * nBin; 729 | auto jointxyz = alloc.uninitializedArray!(Uint[])(nBin * nBin * nBin); 730 | auto jointxz = alloc.uninitializedArray!(Uint[])(nBinSq); 731 | auto jointyz = alloc.uninitializedArray!(Uint[])(nBinSq); 732 | auto margz = 
alloc.uninitializedArray!(Uint[])(nBin); 733 | jointxyz[] = 0; 734 | jointxz[] = 0; 735 | jointyz[] = 0; 736 | margz[] = 0; 737 | uint n = 0; 738 | 739 | while(!x.empty && !y.empty && !z.empty) { 740 | immutable xFront = cast(uint) x.front; 741 | immutable yFront = cast(uint) y.front; 742 | immutable zFront = cast(uint) z.front; 743 | assert(xFront < nBin); 744 | assert(yFront < nBin); 745 | assert(zFront < nBin); 746 | 747 | jointxyz[xFront * nBinSq + yFront * nBin + zFront]++; 748 | jointxz[xFront * nBin + zFront]++; 749 | jointyz[yFront * nBin + zFront]++; 750 | margz[zFront]++; 751 | n++; 752 | 753 | x.popFront(); 754 | y.popFront(); 755 | z.popFront(); 756 | } 757 | 758 | auto ret = entropyCounts(jointxz, n) - entropyCounts(jointxyz, n) - 759 | entropyCounts(margz, n) + entropyCounts(jointyz, n); 760 | return max(0, ret); 761 | } 762 | } 763 | 764 | unittest { 765 | alias ae = approxEqual; 766 | auto dense = DenseInfoTheory(3); 767 | auto a = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]; 768 | auto b = [1, 2, 2, 2, 0, 0, 1, 1, 1, 1, 0, 0]; 769 | auto c = [1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0]; 770 | 771 | // need some approxEquals in here because some methods do floating 772 | // point ops in hash-dependent orders 773 | 774 | assert(entropy(a) == dense.entropy(a)); 775 | assert(entropy(b) == dense.entropy(b)); 776 | assert(entropy(c) == dense.entropy(c)); 777 | assert(ae(entropy(joint(a, c)), dense.entropy(joint(c, a)))); 778 | assert(ae(entropy(joint(a, b)), dense.entropy(joint(a, b)))); 779 | assert(entropy(joint(c, b)) == dense.entropy(joint(c, b))); 780 | 781 | assert(condEntropy(a, c) == dense.condEntropy(a, c)); 782 | assert(ae(condEntropy(a, b), dense.condEntropy(a, b))); 783 | assert(condEntropy(c, b) == dense.condEntropy(c, b)); 784 | 785 | assert(ae(mutualInfo(a, c), dense.mutualInfo(c, a))); 786 | assert(ae(mutualInfo(a, b), dense.mutualInfo(a, b))); 787 | assert(ae(mutualInfo(c, b), dense.mutualInfo(c, b))); 788 | 789 | assert(ae(condMutualInfo(a, b, c), dense.condMutualInfo(a, b, c))); 790 | assert(ae(condMutualInfo(a, c, b), dense.condMutualInfo(a, c, b))); 791 | assert(ae(condMutualInfo(b, c, a), dense.condMutualInfo(b, c, a))); 792 | 793 | // Test P-value stuff. 794 | immutable pDense = dense.mutualInfoPval(dense.mutualInfo(a, b), a.length); 795 | immutable pNotDense = gTestObs(a, b).p; 796 | assert(approxEqual(pDense, pNotDense)); 797 | } 798 | -------------------------------------------------------------------------------- /source/dstats/summary.d: -------------------------------------------------------------------------------- 1 | /**Summary statistics such as mean, median, sum, variance, skewness, kurtosis. 2 | * Except for median and median absolute deviation, which cannot be calculated 3 | * online, all summary statistics have both an input range interface and an 4 | * output range interface. 5 | * 6 | * Notes: The put method on the structs defined in this module returns this by 7 | * ref. The use case for returning this is to enable these structs 8 | * to be used with std.algorithm.reduce. The rationale for returning 9 | * by ref is that the return value usually won't be used, and the 10 | * overhead of returning a large struct by value should be avoided. 11 | * 12 | * Bugs: This whole module assumes that input will be doubles or types implicitly 13 | * convertible to double. No allowances are made for user-defined numeric 14 | * types such as BigInts. This is necessary for simplicity. 
However,
15 | * if you have a function that converts your data to doubles, most of
16 | * these functions work with any input range, so you can simply map
17 | * this function onto your range.
18 | *
19 | * Author: David Simcha
20 | */
21 | /*
22 | * Copyright (C) 2008-2010 David Simcha
23 | *
24 | * License:
25 | * Boost Software License - Version 1.0 - August 17th, 2003
26 | *
27 | * Permission is hereby granted, free of charge, to any person or organization
28 | * obtaining a copy of the software and accompanying documentation covered by
29 | * this license (the "Software") to use, reproduce, display, distribute,
30 | * execute, and transmit the Software, and to prepare derivative works of the
31 | * Software, and to permit third-parties to whom the Software is furnished to
32 | * do so, all subject to the following:
33 | *
34 | * The copyright notices in the Software and this entire statement, including
35 | * the above license grant, this restriction and the following disclaimer,
36 | * must be included in all copies of the Software, in whole or in part, and
37 | * all derivative works of the Software, unless such copies or derivative
38 | * works are solely in the form of machine-executable object code generated by
39 | * a source language processor.
40 | *
41 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 | * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
44 | * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
45 | * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
46 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
47 | * DEALINGS IN THE SOFTWARE.
48 | */
49 |
50 |
51 | module dstats.summary;
52 |
53 | import std.functional, std.conv, std.range, std.array,
54 | std.traits, std.math;
55 |
56 | import std.algorithm : reduce, min, max, swap, map, filter;
57 |
58 | import dstats.sort, dstats.base, dstats.alloc;
59 |
60 | version(unittest) {
61 | import std.stdio, dstats.random;
62 | }
63 |
64 | /**Finds median of an input range in O(N) time on average. In the case of an
65 | * even number of elements, the mean of the two middle elements is returned.
66 | * This is a convenience function designed specifically for numeric types,
67 | * where the averaging of the two middle elements is desired. A more general
68 | * selection algorithm that can handle any type with a total ordering, as well
69 | * as selecting any position in the ordering, can be found at
70 | * dstats.sort.quickSelect() and dstats.sort.partitionK().
71 | * Allocates memory, does not reorder input data.*/
72 | double median(T)(T data)
73 | if(doubleInput!(T)) {
74 | auto alloc = newRegionAllocator();
75 | auto dataDup = alloc.array(data);
76 | return medianPartition(dataDup);
77 | }
78 |
79 | /**Median finding as in median(), but will partition input data such that
80 | * elements less than the median will have smaller indices than that of the
81 | * median, and elements larger than the median will have larger indices than
82 | * that of the median. Useful both for its partitioning and to avoid
83 | * memory allocations.
Requires a random access range with swappable 84 | * elements.*/ 85 | double medianPartition(T)(T data) 86 | if(isRandomAccessRange!(T) && 87 | is(ElementType!(T) : double) && 88 | hasSwappableElements!(T) && 89 | hasLength!(T)) 90 | { 91 | if(data.length == 0) { 92 | return double.nan; 93 | } 94 | // Upper half of median in even length case is just the smallest element 95 | // with an index larger than the lower median, after the array is 96 | // partially sorted. 97 | if(data.length == 1) { 98 | return data[0]; 99 | } else if(data.length & 1) { //Is odd. 100 | return cast(double) partitionK(data, data.length / 2); 101 | } else { 102 | auto lower = partitionK(data, data.length / 2 - 1); 103 | auto upper = ElementType!(T).max; 104 | 105 | // Avoid requiring slicing to be supported. 106 | foreach(i; data.length / 2..data.length) { 107 | if(data[i] < upper) { 108 | upper = data[i]; 109 | } 110 | } 111 | return lower * 0.5 + upper * 0.5; 112 | } 113 | } 114 | 115 | unittest { 116 | float brainDeadMedian(float[] foo) { 117 | qsort(foo); 118 | if(foo.length & 1) 119 | return foo[$ / 2]; 120 | return (foo[$ / 2] + foo[$ / 2 - 1]) / 2; 121 | } 122 | 123 | float[] test = new float[1000]; 124 | size_t upperBound, lowerBound; 125 | foreach(testNum; 0..1000) { 126 | foreach(ref e; test) { 127 | e = uniform(0f, 1000f); 128 | } 129 | do { 130 | upperBound = uniform(0u, test.length); 131 | lowerBound = uniform(0u, test.length); 132 | } while(lowerBound == upperBound); 133 | if(lowerBound > upperBound) { 134 | swap(lowerBound, upperBound); 135 | } 136 | auto quickRes = median(test[lowerBound..upperBound]); 137 | auto accurateRes = brainDeadMedian(test[lowerBound..upperBound]); 138 | 139 | // Off by some tiny fraction in even N case because of division. 140 | // No idea why, but it's too small a rounding error to care about. 141 | assert(approxEqual(quickRes, accurateRes)); 142 | } 143 | 144 | // Make sure everything works with lowest common denominator range type. 145 | static struct Count { 146 | uint num; 147 | uint upTo; 148 | @property size_t front() { 149 | return num; 150 | } 151 | void popFront() { 152 | num++; 153 | } 154 | @property bool empty() { 155 | return num >= upTo; 156 | } 157 | } 158 | 159 | Count a; 160 | a.upTo = 100; 161 | assert(approxEqual(median(a), 49.5)); 162 | } 163 | 164 | /**Plain old data holder struct for median, median absolute deviation. 165 | * Alias this'd to the median absolute deviation member. 166 | */ 167 | struct MedianAbsDev { 168 | double median; 169 | double medianAbsDev; 170 | 171 | alias medianAbsDev this; 172 | } 173 | 174 | /**Calculates the median absolute deviation of a dataset. This is the median 175 | * of all absolute differences from the median of the dataset. 176 | * 177 | * Returns: A MedianAbsDev struct that contains the median (since it is 178 | * computed anyhow) and the median absolute deviation. 179 | * 180 | * Notes: No bias correction is used in this implementation, since using 181 | * one would require assumptions about the underlying distribution of the data. 
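 *
 * Examples:
 * ---
 * // Worked by hand: the median of [1, 2, 3, 4, 100] is 3 and the
 * // absolute deviations from it are [2, 1, 0, 1, 97], whose median is 1.
 * auto res = medianAbsDev([1.0, 2, 3, 4, 100]);
 * assert(res.median == 3 && res.medianAbsDev == 1);
 * ---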
182 | */
183 | MedianAbsDev medianAbsDev(T)(T data)
184 | if(doubleInput!(T)) {
185 | auto alloc = newRegionAllocator();
186 | auto dataDup = alloc.array(data);
187 | immutable med = medianPartition(dataDup);
188 | immutable len = dataDup.length;
189 | alloc.freeLast();
190 |
191 | double[] devs = alloc.uninitializedArray!(double[])(len);
192 |
193 | size_t i = 0;
194 | foreach(elem; data) {
195 | devs[i++] = abs(med - elem);
196 | }
197 | auto ret = medianPartition(devs);
198 | alloc.freeLast();
199 | return MedianAbsDev(med, ret);
200 | }
201 |
202 | unittest {
203 | assert(approxEqual(medianAbsDev([7,1,8,2,8,1,9,2,8,4,5,9].dup).medianAbsDev, 2.5L));
204 | assert(approxEqual(medianAbsDev([8,6,7,5,3,0,999].dup).medianAbsDev, 2.0L));
205 | }
206 |
207 | /**Computes the interquantile range of data at the given quantile value in O(N)
208 | * time complexity. For example, using a quantile value of either 0.25 or 0.75
209 | * will give the interquartile range. (This is the default since the
210 | * interquartile range is by far the most commonly used interquantile range.)
211 | * Using a quantile value of 0.2 or 0.8 will give the interquintile range.
212 | *
213 | * If the quantile point falls between two indices, linear interpolation is
214 | * used.
215 | *
216 | * This function is somewhat more efficient than simply finding the upper and
217 | * lower quantile and subtracting them.
218 | *
219 | * Tip: A quantile of 0 or 1 is handled as a special case and will compute the
220 | * plain old range of the data in a single pass.
221 | */
222 | double interquantileRange(R)(R data, double quantile = 0.25)
223 | if(doubleInput!R) {
224 | alias quantile q; // Save typing.
225 | dstatsEnforce(q >= 0 && q <= 1,
226 | "Quantile must be between 0, 1 for interquantileRange.");
227 |
228 | auto alloc = newRegionAllocator();
229 | if(q > 0.5) {
230 | q = 1.0 - q;
231 | }
232 |
233 | if(q == 0) { // Special case: Compute the plain old range.
234 | double minElem = double.infinity;
235 | double maxElem = -double.infinity;
236 |
237 | foreach(elem; data) {
238 | minElem = min(minElem, elem);
239 | maxElem = max(maxElem, elem);
240 | }
241 |
242 | return maxElem - minElem;
243 | }
244 |
245 | // Common case.
246 | auto duped = alloc.array(data);
247 | immutable double N = duped.length;
248 | if(duped.length < 2) {
249 | return double.nan; // Can't do it.
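        // (The quantile interpolation below needs at least two order
        // statistics to work with.)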
250 | } 251 | 252 | immutable lowEnd = to!size_t((N - 1) * q); 253 | immutable lowFract = (N - 1) * q - lowEnd; 254 | 255 | partitionK(duped, lowEnd); 256 | immutable lowQuantile1 = duped[lowEnd]; 257 | double minAbove = double.infinity; 258 | 259 | foreach(elem; duped[lowEnd + 1..$]) { 260 | minAbove = min(minAbove, elem); 261 | } 262 | 263 | immutable lowerQuantile = 264 | lowFract * minAbove + (1 - lowFract) * lowQuantile1; 265 | 266 | immutable highEnd = to!size_t((N - 1) * (1.0 - q) - lowEnd); 267 | immutable highFract = (N - 1) * (1.0 - q) - lowEnd - highEnd; 268 | duped = duped[lowEnd..$]; 269 | assert(highEnd < duped.length - 1); 270 | 271 | partitionK(duped, highEnd); 272 | immutable minAbove2 = reduce!min(double.infinity, duped[highEnd + 1..$]); 273 | immutable upperQuantile = minAbove2 * highFract 274 | + duped[highEnd] * (1 - highFract); 275 | 276 | return upperQuantile - lowerQuantile; 277 | } 278 | 279 | unittest { 280 | // 0 3 5 6 7 8 9 281 | assert(approxEqual(interquantileRange([1,2,3,4,5,6,7,8]), 3.5)); 282 | assert(approxEqual(interquantileRange([1,2,3,4,5,6,7,8,9]), 4)); 283 | assert(interquantileRange([1,9,2,4,3,6,8], 0) == 8); 284 | assert(approxEqual(interquantileRange([8,6,7,5,3,0,9], 0.2), 4.4)); 285 | } 286 | 287 | /**Output range to calculate the mean online. Getter for mean costs a branch to 288 | * check for N == 0. This struct uses O(1) space and does *NOT* store the 289 | * individual elements. 290 | * 291 | * Note: This struct can implicitly convert to the value of the mean. 292 | * 293 | * Examples: 294 | * --- 295 | * Mean summ; 296 | * summ.put(1); 297 | * summ.put(2); 298 | * summ.put(3); 299 | * summ.put(4); 300 | * summ.put(5); 301 | * assert(summ.mean == 3); 302 | * ---*/ 303 | struct Mean { 304 | private: 305 | double result = 0; 306 | double k = 0; 307 | 308 | public: 309 | ///// Allow implicit casting to double, by returning the current mean. 310 | alias mean this; 311 | 312 | /// 313 | void put(double element) pure nothrow @safe { 314 | result += (element - result) / ++k; 315 | } 316 | 317 | /**Adds the contents of rhs to this instance. 318 | * 319 | * Examples: 320 | * --- 321 | * Mean mean1, mean2, combined; 322 | * foreach(i; 0..5) { 323 | * mean1.put(i); 324 | * } 325 | * 326 | * foreach(i; 5..10) { 327 | * mean2.put(i); 328 | * } 329 | * 330 | * mean1.put(mean2); 331 | * 332 | * foreach(i; 0..10) { 333 | * combined.put(i); 334 | * } 335 | * 336 | * assert(approxEqual(combined.mean, mean1.mean)); 337 | * --- 338 | */ 339 | void put(typeof(this) rhs) pure nothrow @safe { 340 | immutable totalN = k + rhs.k; 341 | result = result * (k / totalN) + rhs.result * (rhs.k / totalN); 342 | k = totalN; 343 | } 344 | 345 | const pure nothrow @property @safe { 346 | 347 | /// 348 | double sum() { 349 | return result * k; 350 | } 351 | 352 | /// 353 | double mean() { 354 | return (k == 0) ? double.nan : result; 355 | } 356 | 357 | /// 358 | double N() { 359 | return k; 360 | } 361 | 362 | /**Simply returns this. 
Useful in generic programming contexts.*/ 363 | Mean toMean() { 364 | return this; 365 | } 366 | } 367 | 368 | /// 369 | string toString() const { 370 | return to!(string)(mean); 371 | } 372 | } 373 | 374 | /**Finds the arithmetic mean of any input range whose elements are implicitly 375 | * convertible to double.*/ 376 | Mean mean(T)(T data) 377 | if(doubleIterable!(T)) { 378 | 379 | static if(isRandomAccessRange!T && hasLength!T) { 380 | // This is optimized for maximum instruction level parallelism: 381 | // The loop is unrolled such that there are 1 / (nILP)th the data 382 | // dependencies of the naive algorithm. 383 | enum nILP = 8; 384 | 385 | Mean ret; 386 | size_t i = 0; 387 | if(data.length > 2 * nILP) { 388 | double k = 0; 389 | double[nILP] means = 0; 390 | for(; i + nILP < data.length; i += nILP) { 391 | immutable kNeg1 = 1 / ++k; 392 | 393 | foreach(j; StaticIota!nILP) { 394 | means[j] += (data[i + j] - means[j]) * kNeg1; 395 | } 396 | } 397 | 398 | ret.k = k; 399 | ret.result = means[0]; 400 | foreach(m; means[1..$]) { 401 | ret.put( Mean(m, k)); 402 | } 403 | } 404 | 405 | // Handle the remainder. 406 | for(; i < data.length; i++) { 407 | ret.put(data[i]); 408 | } 409 | return ret; 410 | 411 | } else { 412 | // Just submit everything to a single Mean struct and return it. 413 | Mean meanCalc; 414 | 415 | foreach(element; data) { 416 | meanCalc.put(element); 417 | } 418 | return meanCalc; 419 | } 420 | } 421 | 422 | /**Output range to calculate the geometric mean online. 423 | * Operates similarly to dstats.summary.Mean*/ 424 | struct GeometricMean { 425 | private: 426 | Mean m; 427 | public: 428 | /////Allow implicit casting to double, by returning current geometric mean. 429 | alias geoMean this; 430 | 431 | /// 432 | void put(double element) pure nothrow @safe { 433 | m.put(log2(element)); 434 | } 435 | 436 | /// Combine two GeometricMean's. 437 | void put(typeof(this) rhs) pure nothrow @safe { 438 | m.put(rhs.m); 439 | } 440 | 441 | const pure nothrow @property { 442 | /// 443 | double geoMean() { 444 | return exp2(m.mean); 445 | } 446 | 447 | /// 448 | double N() { 449 | return m.k; 450 | } 451 | } 452 | 453 | /// 454 | string toString() const { 455 | return to!(string)(geoMean); 456 | } 457 | } 458 | 459 | /**Calculates the geometric mean of any input range that has elements implicitly 460 | * convertible to double*/ 461 | double geometricMean(T)(T data) 462 | if(doubleIterable!(T)) { 463 | // This is relatively seldom used and the log function is the bottleneck 464 | // anyhow, not worth ILP optimizing. 465 | GeometricMean m; 466 | foreach(elem; data) { 467 | m.put(elem); 468 | } 469 | return m.geoMean; 470 | } 471 | 472 | unittest { 473 | string[] data = ["1", "2", "3", "4", "5"]; 474 | auto foo = map!(to!(uint))(data); 475 | 476 | auto result = geometricMean(map!(to!(uint))(data)); 477 | assert(approxEqual(result, 2.60517)); 478 | 479 | Mean mean1, mean2, combined; 480 | foreach(i; 0..5) { 481 | mean1.put(i); 482 | } 483 | 484 | foreach(i; 5..10) { 485 | mean2.put(i); 486 | } 487 | 488 | mean1.put(mean2); 489 | 490 | foreach(i; 0..10) { 491 | combined.put(i); 492 | } 493 | 494 | assert(approxEqual(combined.mean, mean1.mean), 495 | text(combined.mean, " ", mean1.mean)); 496 | assert(combined.N == mean1.N); 497 | } 498 | 499 | /**Finds the sum of an input range whose elements implicitly convert to double. 500 | * User has option of making U a different type than T to prevent overflows 501 | * on large array summing operations. 
However, by default, the return type is
502 | * the element type of T (the same numeric type as the input's elements).*/
503 | U sum(T, U = Unqual!(ForeachType!(T)))(T data)
504 | if(doubleIterable!(T)) {
505 |
506 | static if(isRandomAccessRange!T && hasLength!T) {
507 | enum nILP = 8;
508 | U[nILP] sum = 0;
509 |
510 | size_t i = 0;
511 | if(data.length > 2 * nILP) {
512 |
513 | for(; i + nILP < data.length; i += nILP) {
514 | foreach(j; StaticIota!nILP) {
515 | sum[j] += data[i + j];
516 | }
517 | }
518 |
519 | foreach(j; 1..nILP) {
520 | sum[0] += sum[j];
521 | }
522 | }
523 |
524 | for(; i < data.length; i++) {
525 | sum[0] += data[i];
526 | }
527 |
528 | return sum[0];
529 | } else {
530 | U sum = 0;
531 | foreach(elem; data) {
532 | sum += elem;
533 | }
534 |
535 | return sum;
536 | }
537 | }
538 |
539 | unittest {
540 | assert(sum([1,2,3,4,5,6,7,8,9,10][]) == 55);
541 | assert(sum(filter!"true"([1,2,3,4,5,6,7,8,9,10][])) == 55);
542 | assert(sum(cast(int[]) [1,2,3,4,5])==15);
543 | assert(approxEqual( sum(cast(int[]) [40.0, 40.1, 5.2]), 85.3));
544 | assert(mean(cast(int[]) [1,2,3]).mean == 2);
545 | assert(mean(cast(int[]) [1.0, 2.0, 3.0]).mean == 2.0);
546 | assert(mean([1, 2, 5, 10, 17][]).mean == 7);
547 | assert(mean([1, 2, 5, 10, 17][]).sum == 35);
548 | assert(approxEqual(mean([8,6,7,5,3,0,9,3,6,2,4,3,6][]).mean, 4.769231));
549 |
550 | // Test the OO struct a little, since we're using the new ILP algorithm.
551 | Mean m;
552 | m.put(1);
553 | m.put(2);
554 | m.put(5);
555 | m.put(10);
556 | m.put(17);
557 | assert(m.mean == 7);
558 |
559 | foreach(i; 0..100) {
560 | // Monte Carlo test the unrolled version.
561 | auto foo = randArray!rNormal(uniform(5, 100), 0, 1);
562 | auto res1 = mean(foo);
563 | Mean res2;
564 | foreach(elem; foo) {
565 | res2.put(elem);
566 | }
567 |
568 | foreach(ti, elem; res1.tupleof) {
569 | assert(approxEqual(elem, res2.tupleof[ti]));
570 | }
571 | }
572 | }
573 |
574 |
575 | /**Output range to compute mean, stdev, variance online. Getter methods
576 | * for stdev, var cost a few floating point ops. Getter for mean costs
577 | * a single branch to check for N == 0. These getters use relatively expensive
578 | * floating point ops; if you only need the mean, try Mean. This struct uses
579 | * O(1) space and does *NOT* store the individual elements.
580 | *
581 | * Note: This struct can implicitly convert to a Mean struct.
582 | *
583 | * References: Computing Higher-Order Moments Online.
584 | * http://people.xiph.org/~tterribe/notes/homs.html
585 | *
586 | * Examples:
587 | * ---
588 | * MeanSD summ;
589 | * summ.put(1);
590 | * summ.put(2);
591 | * summ.put(3);
592 | * summ.put(4);
593 | * summ.put(5);
594 | * assert(summ.mean == 3);
595 | * assert(summ.stdev == sqrt(2.5));
596 | * assert(summ.var == 2.5);
597 | * ---*/
598 | struct MeanSD {
599 | private:
600 | double _mean = 0;
601 | double _var = 0;
602 | double _k = 0;
603 | public:
604 | ///
605 | void put(double element) pure nothrow @safe {
606 | immutable kMinus1 = _k;
607 | immutable delta = element - _mean;
608 | immutable deltaN = delta / ++_k;
609 |
610 | _mean += deltaN;
611 | _var += kMinus1 * deltaN * delta;
612 | return;
613 | }
614 |
615 | /// Combine two MeanSD's.
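    /// (Uses the parallel form of Welford's method: the pooled sum of
    /// squared deviations is _var + rhs._var plus a between-groups term
    /// (_k * rhs._k / totalN) * delta ^^ 2, where delta is the difference
    /// between the two means.)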
616 | void put(typeof(this) rhs) pure nothrow @safe { 617 | if(_k == 0) { 618 | foreach(ti, elem; rhs.tupleof) { 619 | this.tupleof[ti] = elem; 620 | } 621 | 622 | return; 623 | } else if(rhs._k == 0) { 624 | return; 625 | } 626 | 627 | immutable totalN = _k + rhs._k; 628 | immutable delta = rhs._mean - _mean; 629 | _mean = _mean * (_k / totalN) + rhs._mean * (rhs._k / totalN); 630 | 631 | _var = _var + rhs._var + (_k / totalN * rhs._k * delta * delta); 632 | _k = totalN; 633 | } 634 | 635 | const pure nothrow @property @safe { 636 | 637 | /// 638 | double sum() { 639 | return _k * _mean; 640 | } 641 | 642 | /// 643 | double mean() { 644 | return (_k == 0) ? double.nan : _mean; 645 | } 646 | 647 | /// 648 | double stdev() { 649 | return sqrt(var); 650 | } 651 | 652 | /// 653 | double var() { 654 | return (_k < 2) ? double.nan : _var / (_k - 1); 655 | } 656 | 657 | /** 658 | Mean squared error. In other words, a biased estimate of variance. 659 | */ 660 | double mse() { 661 | return (_k < 1) ? double.nan : _var / _k; 662 | } 663 | 664 | /// 665 | double N() { 666 | return _k; 667 | } 668 | 669 | /**Converts this struct to a Mean struct. Also called when an 670 | * implicit conversion via alias this takes place. 671 | */ 672 | Mean toMean() { 673 | return Mean(_mean, _k); 674 | } 675 | 676 | /**Simply returns this. Useful in generic programming contexts.*/ 677 | MeanSD toMeanSD() const { 678 | return this; 679 | } 680 | } 681 | 682 | alias toMean this; 683 | 684 | /// 685 | string toString() const { 686 | return text("N = ", cast(ulong) _k, "\nMean = ", mean, "\nVariance = ", 687 | var, "\nStdev = ", stdev); 688 | } 689 | } 690 | 691 | /**Puts all elements of data into a MeanSD struct, 692 | * then returns this struct. This can be faster than doing this manually 693 | * due to ILP optimizations.*/ 694 | MeanSD meanStdev(T)(T data) 695 | if(doubleIterable!(T)) { 696 | 697 | MeanSD ret; 698 | 699 | static if(isRandomAccessRange!T && hasLength!T) { 700 | // Optimize for instruction level parallelism. 701 | enum nILP = 6; 702 | double k = 0; 703 | double[nILP] means = 0; 704 | double[nILP] variances = 0; 705 | size_t i = 0; 706 | 707 | if(data.length > 2 * nILP) { 708 | for(; i + nILP < data.length; i += nILP) { 709 | immutable kMinus1 = k; 710 | immutable kNeg1 = 1 / ++k; 711 | 712 | foreach(j; StaticIota!nILP) { 713 | immutable double delta = data[i + j] - means[j]; 714 | immutable deltaN = delta * kNeg1; 715 | 716 | means[j] += deltaN; 717 | variances[j] += kMinus1 * deltaN * delta; 718 | } 719 | } 720 | 721 | ret._mean = means[0]; 722 | ret._var = variances[0]; 723 | ret._k = k; 724 | 725 | foreach(j; 1..nILP) { 726 | ret.put( MeanSD(means[j], variances[j], k)); 727 | } 728 | } 729 | 730 | // Handle remainder. 
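        // (The unrolled lanes were already merged via put() above; anything
        // that didn't fit in a full block of nILP is accumulated one at a
        // time.)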
731 | for(; i < data.length; i++) {
732 | ret.put(data[i]);
733 | }
734 | } else {
735 | foreach(elem; data) {
736 | ret.put(elem);
737 | }
738 | }
739 | return ret;
740 | }
741 |
742 | /**Finds the variance of an input range with members implicitly convertible
743 | * to doubles.*/
744 | double variance(T)(T data)
745 | if(doubleIterable!(T)) {
746 | return meanStdev(data).var;
747 | }
748 |
749 | /**Calculate the standard deviation of an input range with members
750 | * implicitly convertible to double.*/
751 | double stdev(T)(T data)
752 | if(doubleIterable!(T)) {
753 | return meanStdev(data).stdev;
754 | }
755 |
756 | unittest {
757 | auto res = meanStdev(cast(int[]) [3, 1, 4, 5]);
758 | assert(approxEqual(res.stdev, 1.7078));
759 | assert(approxEqual(res.mean, 3.25));
760 | res = meanStdev(cast(double[]) [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
761 | assert(approxEqual(res.stdev, 2.160247));
762 | assert(approxEqual(res.mean, 4));
763 | assert(approxEqual(res.sum, 28));
764 |
765 | MeanSD mean1, mean2, combined;
766 | foreach(i; 0..5) {
767 | mean1.put(i);
768 | }
769 |
770 | foreach(i; 5..10) {
771 | mean2.put(i);
772 | }
773 |
774 | mean1.put(mean2);
775 |
776 | foreach(i; 0..10) {
777 | combined.put(i);
778 | }
779 |
780 | assert(approxEqual(combined.mean, mean1.mean));
781 | assert(approxEqual(combined.stdev, mean1.stdev));
782 | assert(combined.N == mean1.N);
783 | assert(approxEqual(combined.mean, 4.5));
784 | assert(approxEqual(combined.stdev, 3.027650));
785 |
786 | foreach(i; 0..100) {
787 | // Monte Carlo test the unrolled version.
788 | auto foo = randArray!rNormal(uniform(5, 100), 0, 1);
789 | auto res1 = meanStdev(foo);
790 | MeanSD res2;
791 | foreach(elem; foo) {
792 | res2.put(elem);
793 | }
794 |
795 | foreach(ti, elem; res1.tupleof) {
796 | assert(approxEqual(elem, res2.tupleof[ti]));
797 | }
798 |
799 | MeanSD resCornerCase; // Test corner cases where one of the N's is 0.
800 | resCornerCase.put(res1);
801 | MeanSD dummy;
802 | resCornerCase.put(dummy);
803 | foreach(ti, elem; res1.tupleof) {
804 | assert(elem == resCornerCase.tupleof[ti]);
805 | }
806 | }
807 | }
808 |
809 | /**Output range to compute mean, stdev, variance, skewness, kurtosis, min, and
810 | * max online. Using this struct is relatively expensive, so if you just need
811 | * mean and/or stdev, try MeanSD or Mean. Getter methods for stdev,
812 | * var cost a few floating point ops. Getter for mean costs a single branch to
813 | * check for N == 0. Getters for skewness and kurtosis cost a whole bunch of
814 | * floating point ops. This struct uses O(1) space and does *NOT* store the
815 | * individual elements.
816 | *
817 | * Note: This struct can implicitly convert to a MeanSD.
818 | *
819 | * References: Computing Higher-Order Moments Online.
820 | * http://people.xiph.org/~tterribe/notes/homs.html 821 | * 822 | * Examples: 823 | * --- 824 | * Summary summ; 825 | * summ.put(1); 826 | * summ.put(2); 827 | * summ.put(3); 828 | * summ.put(4); 829 | * summ.put(5); 830 | * assert(summ.N == 5); 831 | * assert(summ.mean == 3); 832 | * assert(summ.stdev == sqrt(2.5)); 833 | * assert(summ.var == 2.5); 834 | * assert(approxEqual(summ.kurtosis, -1.9120)); 835 | * assert(summ.min == 1); 836 | * assert(summ.max == 5); 837 | * assert(summ.sum == 15); 838 | * ---*/ 839 | struct Summary { 840 | private: 841 | double _mean = 0; 842 | double _m2 = 0; 843 | double _m3 = 0; 844 | double _m4 = 0; 845 | double _k = 0; 846 | double _min = double.infinity; 847 | double _max = -double.infinity; 848 | public: 849 | /// 850 | void put(double element) pure nothrow @safe { 851 | immutable kMinus1 = _k; 852 | immutable kNeg1 = 1.0 / ++_k; 853 | _min = (element < _min) ? element : _min; 854 | _max = (element > _max) ? element : _max; 855 | 856 | immutable delta = element - _mean; 857 | immutable deltaN = delta * kNeg1; 858 | _mean += deltaN; 859 | 860 | _m4 += kMinus1 * deltaN * (_k * _k - 3 * _k + 3) * deltaN * deltaN * delta + 861 | 6 * _m2 * deltaN * deltaN - 4 * deltaN * _m3; 862 | _m3 += kMinus1 * deltaN * (_k - 2) * deltaN * delta - 3 * delta * _m2 * kNeg1; 863 | _m2 += kMinus1 * deltaN * delta; 864 | } 865 | 866 | /// Combine two Summary's. 867 | void put(typeof(this) rhs) pure nothrow @safe { 868 | if(_k == 0) { 869 | foreach(ti, elem; rhs.tupleof) { 870 | this.tupleof[ti] = elem; 871 | } 872 | 873 | return; 874 | } else if(rhs._k == 0) { 875 | return; 876 | } 877 | 878 | immutable totalN = _k + rhs._k; 879 | immutable delta = rhs._mean - _mean; 880 | immutable deltaN = delta / totalN; 881 | _mean = _mean * (_k / totalN) + rhs._mean * (rhs._k / totalN); 882 | 883 | _m4 = _m4 + rhs._m4 + 884 | deltaN * _k * deltaN * rhs._k * deltaN * delta * 885 | (_k * _k - _k * rhs._k + rhs._k * rhs._k) + 886 | 6 * deltaN * _k * deltaN * _k * rhs._m2 + 887 | 6 * deltaN * rhs._k * deltaN * rhs._k * _m2 + 888 | 4 * deltaN * _k * rhs._m3 - 889 | 4 * deltaN * rhs._k * _m3; 890 | 891 | _m3 = _m3 + rhs._m3 + deltaN * _k * deltaN * rhs._k * (_k - rhs._k) + 892 | 3 * deltaN * _k * rhs._m2 - 893 | 3 * deltaN * rhs._k * _m2; 894 | 895 | _m2 = _m2 + rhs._m2 + (_k / totalN * rhs._k * delta * delta); 896 | 897 | _k = totalN; 898 | _max = (_max > rhs._max) ? _max : rhs._max; 899 | _min = (_min < rhs._min) ? _min : rhs._min; 900 | } 901 | 902 | const pure nothrow @property @safe { 903 | 904 | /// 905 | double sum() { 906 | return _mean * _k; 907 | } 908 | 909 | /// 910 | double mean() { 911 | return (_k == 0) ? double.nan : _mean; 912 | } 913 | 914 | /// 915 | double stdev() { 916 | return sqrt(var); 917 | } 918 | 919 | /// 920 | double var() { 921 | return (_k < 2) ? double.nan : _m2 / (_k - 1); 922 | } 923 | 924 | /** 925 | Mean squared error. In other words, a biased estimate of variance. 926 | */ 927 | double mse() { 928 | return (_k < 1) ? double.nan : _m2 / _k; 929 | } 930 | 931 | /// 932 | double skewness() { 933 | immutable sqM2 = sqrt(_m2); 934 | return _m3 / (sqM2 * sqM2 * sqM2) * sqrt(_k); 935 | } 936 | 937 | /// 938 | double kurtosis() { 939 | return _m4 / _m2 * _k / _m2 - 3; 940 | } 941 | 942 | /// 943 | double N() { 944 | return _k; 945 | } 946 | 947 | /// 948 | double min() { 949 | return _min; 950 | } 951 | 952 | /// 953 | double max() { 954 | return _max; 955 | } 956 | 957 | /**Converts this struct to a MeanSD. 
Called via alias this when an
958 | * implicit conversion is attempted.
959 | */
960 | MeanSD toMeanSD() {
961 | return MeanSD(_mean, _m2, _k);
962 | }
963 | }
964 |
965 | alias toMeanSD this;
966 |
967 | ///
968 | string toString() const {
969 | return text("N = ", roundTo!long(_k),
970 | "\nMean = ", mean,
971 | "\nVariance = ", var,
972 | "\nStdev = ", stdev,
973 | "\nSkewness = ", skewness,
974 | "\nKurtosis = ", kurtosis,
975 | "\nMin = ", _min,
976 | "\nMax = ", _max);
977 | }
978 | }
979 |
980 | unittest {
981 | // Everything else is tested indirectly through kurtosis, skewness. Test
982 | // put(typeof(this)).
983 |
984 | Summary mean1, mean2, combined;
985 | foreach(i; 0..5) {
986 | mean1.put(i);
987 | }
988 |
989 | foreach(i; 5..10) {
990 | mean2.put(i);
991 | }
992 |
993 | auto m1_2 = mean1;
994 | auto m2_2 = mean2;
995 | m1_2.put(m2_2);
996 |
997 | mean1.put(mean2);
998 |
999 | foreach(i; 0..10) {
1000 | combined.put(i);
1001 | }
1002 |
1003 | foreach(ti, elem; mean1.tupleof) {
1004 | assert(approxEqual(elem, combined.tupleof[ti]));
1005 | }
1006 |
1007 | Summary summCornerCase; // Case where one N is zero.
1008 | summCornerCase.put(mean1);
1009 | Summary dummy;
1010 | summCornerCase.put(dummy);
1011 | foreach(ti, elem; summCornerCase.tupleof) {
1012 | assert(elem == mean1.tupleof[ti]);
1013 | }
1014 | }
1015 |
1016 | /**Excess kurtosis relative to the normal distribution. High kurtosis means
1017 | * that the variance is due to infrequent, large deviations from the mean. Low
1018 | * kurtosis means that the variance is due to frequent, small deviations from
1019 | * the mean. The normal distribution is defined as having kurtosis of 0.
1020 | * Input must be an input range with elements implicitly convertible to double.*/
1021 | double kurtosis(T)(T data)
1022 | if(doubleIterable!(T)) {
1023 | // This is too infrequently used and has too much ILP within a single
1024 | // iteration to be worth ILP optimizing.
1025 | Summary kCalc;
1026 | foreach(elem; data) {
1027 | kCalc.put(elem);
1028 | }
1029 | return kCalc.kurtosis;
1030 | }
1031 |
1032 | unittest {
1033 | // Values from Matlab.
1034 | assert(approxEqual(kurtosis([1, 1, 1, 1, 10].dup), 0.25));
1035 | assert(approxEqual(kurtosis([2.5, 3.5, 4.5, 5.5].dup), -1.36));
1036 | assert(approxEqual(kurtosis([1,2,2,2,2,2,100].dup), 2.1657));
1037 | }
1038 |
1039 | /**Skewness is a measure of symmetry of a distribution. Positive skewness
1040 | * means that the right tail is longer/fatter than the left tail. Negative
1041 | * skewness means the left tail is longer/fatter than the right tail. Zero
1042 | * skewness indicates a symmetrical distribution. Input must be an input
1043 | * range with elements implicitly convertible to double.*/
1044 | double skewness(T)(T data)
1045 | if(doubleIterable!(T)) {
1046 | // This is too infrequently used and has too much ILP within a single
1047 | // iteration to be worth ILP optimizing.
1048 | Summary sCalc;
1049 | foreach(elem; data) {
1050 | sCalc.put(elem);
1051 | }
1052 | return sCalc.skewness;
1053 | }
1054 |
1055 | unittest {
1056 | // Values from Octave.
1057 | assert(approxEqual(skewness([1,2,3,4,5].dup), 0));
1058 | assert(approxEqual(skewness([3,1,4,1,5,9,2,6,5].dup), 0.5443));
1059 | assert(approxEqual(skewness([2,7,1,8,2,8,1,8,2,8,4,5,9].dup), -0.0866));
1060 |
1061 | // Test handling of ranges that are not arrays.
1062 | string[] stringy = ["3", "1", "4", "1", "5", "9", "2", "6", "5"]; 1063 | auto intified = map!(to!(int))(stringy); 1064 | assert(approxEqual(skewness(intified), 0.5443)); 1065 | } 1066 | 1067 | /**Convenience function. Puts all elements of data into a Summary struct, 1068 | * and returns this struct.*/ 1069 | Summary summary(T)(T data) 1070 | if(doubleIterable!(T)) { 1071 | // This is too infrequently used and has too much ILP within a single 1072 | // iteration to be worth ILP optimizing. 1073 | Summary summ; 1074 | foreach(elem; data) { 1075 | summ.put(elem); 1076 | } 1077 | return summ; 1078 | } 1079 | // Just a convenience function for a well-tested struct. No unittest really 1080 | // necessary. (Famous last words.) 1081 | 1082 | /// 1083 | struct ZScore(T) if(isForwardRange!(T) && is(ElementType!(T) : double)) { 1084 | private: 1085 | T range; 1086 | double mean; 1087 | double sdNeg1; 1088 | 1089 | double z(double elem) { 1090 | return (elem - mean) * sdNeg1; 1091 | } 1092 | 1093 | public: 1094 | this(T range) { 1095 | this.range = range; 1096 | auto msd = meanStdev(range); 1097 | this.mean = msd.mean; 1098 | this.sdNeg1 = 1.0 / msd.stdev; 1099 | } 1100 | 1101 | this(T range, double mean, double sd) { 1102 | this.range = range; 1103 | this.mean = mean; 1104 | this.sdNeg1 = 1.0 / sd; 1105 | } 1106 | 1107 | /// 1108 | @property double front() { 1109 | return z(range.front); 1110 | } 1111 | 1112 | /// 1113 | void popFront() { 1114 | range.popFront; 1115 | } 1116 | 1117 | /// 1118 | @property bool empty() { 1119 | return range.empty; 1120 | } 1121 | 1122 | static if(isForwardRange!(T)) { 1123 | /// 1124 | @property typeof(this) save() { 1125 | auto ret = this; 1126 | ret.range = range.save; 1127 | return ret; 1128 | } 1129 | } 1130 | 1131 | static if(isRandomAccessRange!(T)) { 1132 | /// 1133 | double opIndex(size_t index) { 1134 | return z(range[index]); 1135 | } 1136 | } 1137 | 1138 | static if(isBidirectionalRange!(T)) { 1139 | /// 1140 | @property double back() { 1141 | return z(range.back); 1142 | } 1143 | 1144 | /// 1145 | void popBack() { 1146 | range.popBack; 1147 | } 1148 | } 1149 | 1150 | static if(hasLength!(T)) { 1151 | /// 1152 | @property size_t length() { 1153 | return range.length; 1154 | } 1155 | } 1156 | } 1157 | 1158 | /**Returns a range with whatever properties T has (forward range, random 1159 | * access range, bidirectional range, hasLength, etc.), 1160 | * of the z-scores of the underlying 1161 | * range. A z-score of an element in a range is defined as 1162 | * (element - mean(range)) / stdev(range). 1163 | * 1164 | * Notes: 1165 | * 1166 | * If the data contained in the range is a sample of a larger population, 1167 | * rather than an entire population, then technically, the results output 1168 | * from the ZScore range are T statistics, not Z statistics. This is because 1169 | * the sample mean and standard deviation are only estimates of the population 1170 | * parameters. This does not affect the mechanics of using this range, 1171 | * but it does affect the interpretation of its output. 1172 | * 1173 | * Accessing elements of this range is fairly expensive, as a 1174 | * floating point multiply is involved. Also, constructing this range is 1175 | * costly, as the entire input range has to be iterated over to find the 1176 | * mean and standard deviation. 
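 *
 * Examples:
 * ---
 * // For [1, 2, 3, 4, 5] the sample mean is 3 and the sample standard
 * // deviation is sqrt(2.5), so the first z-score is (1 - 3) / sqrt(2.5).
 * auto z = zScore([1.0, 2, 3, 4, 5]);
 * assert(approxEqual(z.front, (1.0 - 3.0) / sqrt(2.5)));
 * ---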
1177 | */
1178 | ZScore!(T) zScore(T)(T range)
1179 | if(isForwardRange!(T) && doubleInput!(T)) {
1180 | return ZScore!(T)(range);
1181 | }
1182 |
1183 | /**Allows the construction of a ZScore range with precomputed mean and
1184 | * stdev.
1185 | */
1186 | ZScore!(T) zScore(T)(T range, double mean, double sd)
1187 | if(isForwardRange!(T) && doubleInput!(T)) {
1188 | return ZScore!(T)(range, mean, sd);
1189 | }
1190 |
1191 | unittest {
1192 | int[] arr = [1,2,3,4,5];
1193 | auto m = mean(arr).mean;
1194 | auto sd = stdev(arr);
1195 | auto z = zScore(arr);
1196 |
1197 | size_t pos = 0;
1198 | foreach(elem; z) {
1199 | assert(approxEqual(elem, (arr[pos++] - m) / sd));
1200 | }
1201 |
1202 | assert(z.length == 5);
1203 | foreach(i; 0..z.length) {
1204 | assert(approxEqual(z[i], (arr[i] - m) / sd));
1205 | }
1206 | }
1207 |
--------------------------------------------------------------------------------
/source/dstats/random.d:
--------------------------------------------------------------------------------
1 | /**Generates random samples from various probability distributions.
2 | * These are mostly D ports of the NumPy random number generators.*/
3 |
4 | /* This library is a D port of a large portion of the Numpy random number
5 | * library. A few distributions were excluded because they were too obscure
6 | * to be tested properly. They may be included at some point in the future.
7 | *
8 | * Port to D copyright 2009 David Simcha.
9 | *
10 | * The original C code is available under the licenses below. No additional
11 | * restrictions shall apply to this D translation. Eventually, I will try to
12 | * discuss the licensing issues with the original authors of Numpy and
13 | * make this sane enough that this module can be included in Phobos without
14 | * concern. For now, it's free enough that you can at least use it in
15 | * personal projects without any serious issues.
16 | *
17 | * Main Numpy license:
18 | *
19 | * Copyright (c) 2005-2009, NumPy Developers.
20 | * All rights reserved.
21 | *
22 | * Redistribution and use in source and binary forms, with or without
23 | * modification, are permitted provided that the following conditions are
24 | * met:
25 | *
26 | * * Redistributions of source code must retain the above copyright
27 | * notice, this list of conditions and the following disclaimer.
28 | *
29 | * * Redistributions in binary form must reproduce the above
30 | * copyright notice, this list of conditions and the following
31 | * disclaimer in the documentation and/or other materials provided
32 | * with the distribution.
33 | *
34 | * * Neither the name of the NumPy Developers nor the names of any
35 | * contributors may be used to endorse or promote products derived
36 | * from this software without specific prior written permission.
37 | *
38 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
39 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
40 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
41 | * A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 42 | * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 43 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 44 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 45 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 46 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 47 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 48 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 49 | * 50 | * distribution.c license: 51 | * 52 | * Copyright 2005 Robert Kern (robert.kern@gmail.com) 53 | * 54 | * Permission is hereby granted, free of charge, to any person obtaining a 55 | * copy of this software and associated documentation files (the 56 | * "Software"), to deal in the Software without restriction, including 57 | * without limitation the rights to use, copy, modify, merge, publish, 58 | * distribute, sublicense, and/or sell copies of the Software, and to 59 | * permit persons to whom the Software is furnished to do so, subject to 60 | * the following conditions: 61 | * 62 | * The above copyright notice and this permission notice shall be included 63 | * in all copies or substantial portions of the Software. 64 | * 65 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 66 | * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 67 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 68 | * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 69 | * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 70 | * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 71 | * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 72 | */ 73 | 74 | /* The implementations of rHypergeometricHyp() and rHypergeometricHrua() 75 | * were adapted from Ivan Frohne's rv.py which has this 76 | * license: 77 | * 78 | * Copyright 1998 by Ivan Frohne; Wasilla, Alaska, U.S.A. 79 | * All Rights Reserved 80 | * 81 | * Permission to use, copy, modify and distribute this software and its 82 | * documentation for any purpose, free of charge, is granted subject to the 83 | * following conditions: 84 | * The above copyright notice and this permission notice shall be included in 85 | * all copies or substantial portions of the software. 86 | * 87 | * THE SOFTWARE AND DOCUMENTATION IS PROVIDED WITHOUT WARRANTY OF ANY KIND, 88 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO MERCHANTABILITY, FITNESS 89 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHOR 90 | * OR COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM OR DAMAGES IN A CONTRACT 91 | * ACTION, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 92 | * SOFTWARE OR ITS DOCUMENTATION. 93 | */ 94 | 95 | /* References: 96 | * 97 | * Devroye, Luc. _Non-Uniform Random Variate Generation_. 98 | * Springer-Verlag, New York, 1986. 99 | * http://cgm.cs.mcgill.ca/~luc/rnbookindex.html 100 | * 101 | * Kachitvichyanukul, V. and Schmeiser, B. W. Binomial Random Variate 102 | * Generation. Communications of the ACM, 31, 2 (February, 1988) 216. 103 | * 104 | * Hoermann, W. The Transformed Rejection Method for Generating Poisson Random 105 | * Variables. Insurance: Mathematics and Economics, (to appear) 106 | * http://citeseer.csail.mit.edu/151115.html 107 | * 108 | * Marsaglia, G. and Tsang, W. W. A Simple Method for Generating Gamma 109 | * Variables. 
ACM Transactions on Mathematical Software, Vol. 26, No. 3,
110 |  * September 2000, Pages 363-372.
111 |  */
112 | 
113 | 
114 | /* Unit tests are non-deterministic.  They prove that the distributions
115 |  * are reasonable by using K-S tests and summary stats, but cannot
116 |  * deterministically prove correctness.*/
117 | 
118 | module dstats.random;
119 | 
120 | import std.math, dstats.distrib, std.traits, std.typetuple,
121 |     std.exception, std.mathspecial, std.array;
122 | import std.algorithm : min, max;
123 | public import std.random;  //For uniform distrib.
124 | 
125 | import dstats.alloc, dstats.base;
126 | 
127 | version(unittest) {
128 |     import std.stdio, dstats.tests, dstats.summary, std.range, core.memory;
129 | }
130 | 
131 | /**Convenience function to allow one-statement creation of arrays of random
132 |  * numbers.
133 |  *
134 |  * Examples:
135 |  * ---
136 |  * // Create an array of 10 random numbers distributed Normal(0, 1).
137 |  * auto normals = randArray!rNormal(10, 0, 1);
138 |  * ---
139 |  */
140 | auto randArray(alias randFun, Args...)(size_t N, auto ref Args args) {
141 |     alias typeof(randFun(args)) R;
142 |     return randArray!(R, randFun, Args)(N, args);
143 | }
144 | 
145 | unittest {
146 |     // Just check if it compiles.
147 |     auto nums = randArray!rNormal(5, 0, 1);
148 |     auto nums2 = randArray!rBinomial(10, 5, 0.5);
149 | }
150 | 
151 | /**Allows the creation of an array of random numbers with an explicitly
152 |  * specified type.  Useful, for example, when single-precision floats are all
153 |  * you need.
154 |  *
155 |  * Examples:
156 |  * ---
157 |  * // Create an array of 10 single-precision floats distributed Normal(0, 1).
158 |  * float[] normals = randArray!(float, rNormal)(10, 0, 1);
159 |  * ---
160 |  */
161 | R[] randArray(R, alias randFun, Args...)(size_t N, auto ref Args args) {
162 |     auto ret = uninitializedArray!(R[])(N);
163 |     foreach(ref elem; ret) {
164 |         elem = randFun(args);
165 |     }
166 | 
167 |     return ret;
168 | }
169 | 
170 | ///
171 | struct RandRange(alias randFun, T...) {
172 | private:
173 |     T args;
174 |     double normData = double.nan;  // TLS stuff for normal.
175 |     typeof(randFun(args)) frontElem;
176 | public:
177 |     enum bool empty = false;
178 | 
179 |     this(T args) {
180 |         this.args = args;
181 |         popFront;
182 |     }
183 | 
184 |     @property typeof(randFun(args)) front() {
185 |         return frontElem;
186 |     }
187 | 
188 |     void popFront() {
189 |         /* This is a kludge to make the contents of this range deterministic
190 |          * given the state of the underlying random number generator without
191 |          * a massive redesign.  We store the state in this struct and
192 |          * swap w/ the TLS data for rNormal on each call to popFront.  This has to
193 |          * be done no matter what distribution we're using b/c a lot of others
194 |          * rely on the normal.*/
195 |         auto lastNormPtr = &lastNorm;  // Cache ptr once, avoid repeated TLS lookup.
196 |         auto temp = *lastNormPtr;      // Store old state.
197 |         *lastNormPtr = normData;       // Replace it.
198 |         this.frontElem = randFun(args);
199 |         normData = *lastNormPtr;
200 |         *lastNormPtr = temp;
201 |     }
202 | 
203 |     @property typeof(this) save() {
204 |         return this;
205 |     }
206 | }
207 | 
208 | /**Turn a random number generator function into an infinite range.
209 |  * Params is a tuple of the distribution parameters.  This is specified
210 |  * in the same order as when calling the function directly.
211 |  *
212 |  * The sequence generated by this range is deterministic and repeatable given
213 |  * the state of the underlying random number generator.
If the underlying
214 |  * random number generator is explicitly specified, as opposed to using the
215 |  * default thread-local global RNG, it is copied when the struct is copied.
216 |  * See below for an example of this behavior.
217 |  *
218 |  * Examples:
219 |  * ---
220 |  * // Print out some summary statistics for 10,000 Poisson-distributed
221 |  * // random numbers w/ Poisson parameter 2.
222 |  * auto gen = Random(unpredictableSeed);
223 |  * auto pois1k = take(randRange!rPoisson(2, gen), 10_000);
224 |  * writeln( summary(pois1k) );
225 |  * writeln( summary(pois1k) );  // Exact same results as first call.
226 |  * ---
227 |  */
228 | RandRange!(randFun, T) randRange(alias randFun, T...)(T params) {
229 |     alias RandRange!(randFun, T) RT;
230 |     RT ret;  // Bypass the ctor b/c it's screwy.
231 |     ret.args = params;
232 |     ret.popFront;
233 |     return ret;
234 | }
235 | 
236 | unittest {
237 |     // The thing to test here is that the results are deterministic given
238 |     // an underlying RNG.
239 | 
240 |     {
241 |         auto norms = take(randRange!rNormal(0, 1, Random(unpredictableSeed)), 99);
242 |         auto arr1 = array(norms);
243 |         auto arr2 = array(norms);
244 |         assert(arr1 == arr2);
245 |     }
246 | 
247 |     {
248 |         auto binomSmall = take(randRange!rBinomial(20, 0.5, Random(unpredictableSeed)), 99);
249 |         auto arr1 = array(binomSmall);
250 |         auto arr2 = array(binomSmall);
251 |         assert(arr1 == arr2);
252 |     }
253 | 
254 |     {
255 |         auto binomLarge = take(randRange!rBinomial(20000, 0.4, Random(unpredictableSeed)), 99);
256 |         auto arr1 = array(binomLarge);
257 |         auto arr2 = array(binomLarge);
258 |         assert(arr1 == arr2);
259 |     }
260 |     writeln("Passed RandRange test.");
261 | }
262 | 
263 | // Thread local data for normal distrib. that is preserved across calls.
264 | private static double lastNorm = double.nan;
265 | 
266 | ///
267 | double rNormal(RGen = Random)(double mean, double sd, ref RGen gen = rndGen) {
268 |     dstatsEnforce(sd > 0, "Standard deviation must be > 0 for rNormal.");
269 | 
270 |     double lr = lastNorm;
271 |     if (!isNaN(lr)) {
272 |         lastNorm = double.nan;
273 |         return lr * sd + mean;
274 |     }
275 | 
276 |     double x1 = void, x2 = void, r2 = void;
277 |     do {
278 |         x1 = uniform(-1.0L, 1.0L, gen);
279 |         x2 = uniform(-1.0L, 1.0L, gen);
280 |         r2 = x1 * x1 + x2 * x2;
281 |     } while (r2 > 1.0L || r2 == 0.0L);
282 |     double f = sqrt(-2.0L * log(r2) / r2);
283 |     lastNorm = f * x1;
284 |     return f * x2 * sd + mean;
285 | }
286 | 
287 | 
288 | unittest {
289 |     auto observ = randArray!rNormal(100_000, 0, 1);
290 |     auto ksRes = ksTest(observ, parametrize!(normalCDF)(0.0L, 1.0L));
291 |     auto summ = summary(observ);
292 | 
293 |     writeln("100k samples from normal(0, 1): K-S P-val: ", ksRes.p);
294 |     writeln("\tMean Expected: 0 Observed: ", summ.mean);
295 |     writeln("\tMedian Expected: 0 Observed: ", median(observ));
296 |     writeln("\tStdev Expected: 1 Observed: ", summ.stdev);
297 |     writeln("\tKurtosis Expected: 0 Observed: ", summ.kurtosis);
298 |     writeln("\tSkewness Expected: 0 Observed: ", summ.skewness);
299 | }
300 | 
301 | ///
302 | double rCauchy(RGen = Random)(double X0, double gamma, ref RGen gen = rndGen) {
303 |     dstatsEnforce(gamma > 0, "gamma must be > 0 for Cauchy distribution.");
304 | 
305 |     return (rNormal(0, 1, gen) / rNormal(0, 1, gen)) * gamma + X0;
306 | }
307 | 
308 | unittest {
309 |     auto observ = randArray!rCauchy(100_000, 2, 5);
310 |     auto ksRes = ksTest(observ, parametrize!(cauchyCDF)(2.0L, 5.0L));
311 | 
312 |     auto summ = summary(observ);
313 |     writeln("100k samples from Cauchy(2, 5): K-S P-val: ", ksRes.p);
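    // Note: the Cauchy distribution has no finite mean, variance, or higher
    // moments, hence the N/A expected values below; only the median is
    // informative.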
writeln("\tMean Expected: N/A Observed: ", summ.mean); 315 | writeln("\tMedian Expected: 2 Observed: ", median(observ)); 316 | writeln("\tStdev Expected: N/A Observed: ", summ.stdev); 317 | writeln("\tKurtosis Expected: N/A Observed: ", summ.kurtosis); 318 | writeln("\tSkewness Expected: N/A Observed: ", summ.skewness); 319 | } 320 | 321 | /// 322 | double rStudentsT(RGen = Random)(double df, ref RGen gen = rndGen) { 323 | dstatsEnforce(df > 0, "Student's T distribution must have >0 degrees of freedom."); 324 | 325 | double N = rNormal(0, 1, gen); 326 | double G = stdGamma(df / 2, gen); 327 | double X = sqrt(df / 2) * N / sqrt(G); 328 | return X; 329 | } 330 | 331 | unittest { 332 | auto observ = randArray!rStudentsT(100_000, 5); 333 | auto ksRes = ksTest(observ, parametrize!(studentsTCDF)(5)); 334 | 335 | auto summ = summary(observ); 336 | writeln("100k samples from T(5): K-S P-val: ", ksRes.p); 337 | writeln("\tMean Expected: 0 Observed: ", summ.mean); 338 | writeln("\tMedian Expected: 0 Observed: ", median(observ)); 339 | writeln("\tStdev Expected: 1.2909 Observed: ", summ.stdev); 340 | writeln("\tKurtosis Expected: 6 Observed: ", summ.kurtosis); 341 | writeln("\tSkewness Expected: 0 Observed: ", summ.skewness); 342 | } 343 | 344 | /// 345 | double rFisher(RGen = Random)(double df1, double df2, ref RGen gen = rndGen) { 346 | dstatsEnforce(df1 > 0 && df2 > 0, 347 | "df1 and df2 must be >0 for the Fisher distribution."); 348 | 349 | return (rChiSquare(df1, gen) * df2) / 350 | (rChiSquare(df2, gen) * df1); 351 | } 352 | 353 | unittest { 354 | auto observ = randArray!rFisher(100_000, 5, 7); 355 | auto ksRes = ksTest(observ, parametrize!(fisherCDF)(5, 7)); 356 | writeln("100k samples from fisher(5, 7): K-S P-val: ", ksRes.p); 357 | writeln("\tMean Expected: ", 7.0 / 5, " Observed: ", mean(observ)); 358 | writeln("\tMedian Expected: ?? Observed: ", median(observ)); 359 | writeln("\tStdev Expected: ?? Observed: ", stdev(observ)); 360 | writeln("\tKurtosis Expected: ?? Observed: ", kurtosis(observ)); 361 | writeln("\tSkewness Expected: ?? 
Observed: ", skewness(observ)); 362 | GC.free(observ.ptr); 363 | } 364 | 365 | /// 366 | double rChiSquare(RGen = Random)(double df, ref RGen gen = rndGen) { 367 | dstatsEnforce(df > 0, "df must be > 0 for chiSquare distribution."); 368 | 369 | return 2.0 * stdGamma(df / 2.0L, gen); 370 | } 371 | 372 | unittest { 373 | double df = 5; 374 | double[] observ = new double[100_000]; 375 | foreach(ref elem; observ) 376 | elem = rChiSquare(df); 377 | auto ksRes = ksTest(observ, parametrize!(chiSquareCDF)(5)); 378 | writeln("100k samples from Chi-Square: K-S P-val: ", ksRes.p); 379 | writeln("\tMean Expected: ", df, " Observed: ", mean(observ)); 380 | writeln("\tMedian Expected: ", df - (2.0L / 3.0L), " Observed: ", median(observ)); 381 | writeln("\tStdev Expected: ", sqrt(2 * df), " Observed: ", stdev(observ)); 382 | writeln("\tKurtosis Expected: ", 12 / df, " Observed: ", kurtosis(observ)); 383 | writeln("\tSkewness Expected: ", sqrt(8 / df), " Observed: ", skewness(observ)); 384 | GC.free(observ.ptr); 385 | } 386 | 387 | /// 388 | int rPoisson(RGen = Random)(double lam, ref RGen gen = rndGen) { 389 | dstatsEnforce(lam > 0, "lambda must be >0 for Poisson distribution."); 390 | 391 | static int poissonMult(ref RGen gen, double lam) { 392 | double U = void; 393 | 394 | double enlam = exp(-lam); 395 | int X = 0; 396 | double prod = 1.0; 397 | while (true) { 398 | U = uniform(0.0L, 1.0L, gen); 399 | prod *= U; 400 | if (prod > enlam) { 401 | X += 1; 402 | } else { 403 | return X; 404 | } 405 | } 406 | assert(0); 407 | } 408 | 409 | enum double LS2PI = 0.91893853320467267; 410 | enum double TWELFTH = 0.083333333333333333333333; 411 | static int poissonPtrs(ref RGen gen, double lam) { 412 | int k; 413 | double U = void, V = void, us = void; 414 | 415 | double slam = sqrt(lam); 416 | double loglam = log(lam); 417 | double b = 0.931 + 2.53*slam; 418 | double a = -0.059 + 0.02483*b; 419 | double invalpha = 1.1239 + 1.1328/(b-3.4); 420 | double vr = 0.9277 - 3.6224/(b-2); 421 | 422 | while (true) { 423 | U = uniform(-0.5L, 0.5L, gen); 424 | V = uniform(0.0L, 1.0L, gen); 425 | us = 0.5 - abs(U); 426 | k = cast(int) floor((2*a/us + b)*U + lam + 0.43); 427 | if ((us >= 0.07) && (V <= vr)) { 428 | return k; 429 | } 430 | if ((k < 0) || ((us < 0.013) && (V > us))) { 431 | continue; 432 | } 433 | if ((log(V) + log(invalpha) - log(a/(us*us)+b)) <= 434 | (-lam + k*loglam - logGamma(k+1))) { 435 | return k; 436 | } 437 | } 438 | assert(0); 439 | } 440 | 441 | 442 | if (lam >= 10) { 443 | return poissonPtrs(gen, lam); 444 | } else if (lam == 0) { 445 | return 0; 446 | } else { 447 | return poissonMult(gen, lam); 448 | } 449 | } 450 | 451 | unittest { 452 | double lambda = 15L; 453 | int[] observ = new int[100_000]; 454 | foreach(ref elem; observ) 455 | elem = rPoisson(lambda); 456 | writeln("100k samples from poisson(", lambda, "):"); 457 | writeln("\tMean Expected: ", lambda, 458 | " Observed: ", mean(observ)); 459 | writeln("\tMedian Expected: ?? 
Observed: ", median(observ)); 460 | writeln("\tStdev Expected: ", sqrt(lambda), 461 | " Observed: ", stdev(observ)); 462 | writeln("\tKurtosis Expected: ", 1 / lambda, 463 | " Observed: ", kurtosis(observ)); 464 | writeln("\tSkewness Expected: ", 1 / sqrt(lambda), 465 | " Observed: ", skewness(observ)); 466 | GC.free(observ.ptr); 467 | } 468 | 469 | /// 470 | int rBernoulli(RGen = Random)(double P = 0.5, ref RGen gen = rndGen) { 471 | dstatsEnforce(P >= 0 && P <= 1, "P must be between 0, 1 for Bernoulli distribution."); 472 | 473 | double pVal = uniform(0.0L, 1.0L, gen); 474 | return cast(int) (pVal <= P); 475 | } 476 | 477 | private struct BinoState { 478 | bool has_binomial; 479 | int nsave; 480 | double psave; 481 | int m; 482 | double r,q,fm,p1,xm,xl,xr,c,laml,lamr,p2,p3,p4; 483 | double a,u,v,s,F,rho,t,A,nrq,x1,x2,f1,f2,z,z2,w,w2,x; 484 | } 485 | 486 | private BinoState* binoState() { 487 | // Store BinoState structs on heap rather than directly in TLS. 488 | 489 | static BinoState* stateTLS; 490 | auto tlsPtr = stateTLS; 491 | if (tlsPtr is null) { 492 | tlsPtr = new BinoState; 493 | stateTLS = tlsPtr; 494 | } 495 | return tlsPtr; 496 | } 497 | 498 | 499 | private int rBinomialBtpe(RGen = Random)(int n, double p, ref RGen gen = rndGen) { 500 | auto state = binoState; 501 | double r,q,fm,p1,xm,xl,xr,c,laml,lamr,p2,p3,p4; 502 | double a,u,v,s,F,rho,t,A,nrq,x1,x2,f1,f2,z,z2,w,w2,x; 503 | int m,y,k,i; 504 | 505 | if (!(state.has_binomial) || 506 | (state.nsave != n) || 507 | (state.psave != p)) { 508 | /* initialize */ 509 | state.nsave = n; 510 | state.psave = p; 511 | state.has_binomial = 1; 512 | state.r = r = min(p, 1.0-p); 513 | state.q = q = 1.0 - r; 514 | state.fm = fm = n*r+r; 515 | state.m = m = cast(int)floor(state.fm); 516 | state.p1 = p1 = floor(2.195*sqrt(n*r*q)-4.6*q) + 0.5; 517 | state.xm = xm = m + 0.5; 518 | state.xl = xl = xm - p1; 519 | state.xr = xr = xm + p1; 520 | state.c = c = 0.134 + 20.5/(15.3 + m); 521 | a = (fm - xl)/(fm-xl*r); 522 | state.laml = laml = a*(1.0 + a/2.0); 523 | a = (xr - fm)/(xr*q); 524 | state.lamr = lamr = a*(1.0 + a/2.0); 525 | state.p2 = p2 = p1*(1.0 + 2.0*c); 526 | state.p3 = p3 = p2 + c/laml; 527 | state.p4 = p4 = p3 + c/lamr; 528 | } else { 529 | r = state.r; 530 | q = state.q; 531 | fm = state.fm; 532 | m = state.m; 533 | p1 = state.p1; 534 | xm = state.xm; 535 | xl = state.xl; 536 | xr = state.xr; 537 | c = state.c; 538 | laml = state.laml; 539 | lamr = state.lamr; 540 | p2 = state.p2; 541 | p3 = state.p3; 542 | p4 = state.p4; 543 | } 544 | 545 | /* sigh ... 
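     * BTPE (Kachitvichyanukul and Schmeiser, 1988; see the references at the
     * top of this module) is specified as a flow chart of numbered steps,
     * which is why the control flow below is goto-based.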
*/ 546 | Step10: 547 | nrq = n*r*q; 548 | u = uniform(0.0L, p4, gen); 549 | v = uniform(0.0L, 1.0L, gen); 550 | if (u > p1) goto Step20; 551 | y = cast(int)floor(xm - p1*v + u); 552 | goto Step60; 553 | 554 | Step20: 555 | if (u > p2) goto Step30; 556 | x = xl + (u - p1)/c; 557 | v = v*c + 1.0 - fabs(m - x + 0.5)/p1; 558 | if (v > 1.0) goto Step10; 559 | y = cast(int)floor(x); 560 | goto Step50; 561 | 562 | Step30: 563 | if (u > p3) goto Step40; 564 | y = cast(int)floor(xl + log(v)/laml); 565 | if (y < 0) goto Step10; 566 | v = v*(u-p2)*laml; 567 | goto Step50; 568 | 569 | Step40: 570 | y = cast(int)floor(xr - log(v)/lamr); 571 | if (y > n) goto Step10; 572 | v = v*(u-p3)*lamr; 573 | 574 | Step50: 575 | k = cast(int) abs(y - m); 576 | if ((k > 20) && (k < ((nrq)/2.0 - 1))) goto Step52; 577 | 578 | s = r/q; 579 | a = s*(n+1); 580 | F = 1.0; 581 | if (m < y) { 582 | for (i=m; i<=y; i++) { 583 | F *= (a/i - s); 584 | } 585 | } else if (m > y) { 586 | for (i=y; i<=m; i++) { 587 | F /= (a/i - s); 588 | } 589 | } else { 590 | if (v > F) goto Step10; 591 | goto Step60; 592 | } 593 | 594 | Step52: 595 | rho = (k/(nrq))*((k*(k/3.0 + 0.625) + 0.16666666666666666)/nrq + 0.5); 596 | t = -k*k/(2*nrq); 597 | A = log(v); 598 | if (A < (t - rho)) goto Step60; 599 | if (A > (t + rho)) goto Step10; 600 | 601 | x1 = y+1; 602 | f1 = m+1; 603 | z = n+1-m; 604 | w = n-y+1; 605 | x2 = x1*x1; 606 | f2 = f1*f1; 607 | z2 = z*z; 608 | w2 = w*w; 609 | if (A > (xm*log(f1/x1) 610 | + (n-m+0.5)*log(z/w) 611 | + (y-m)*log(w*r/(x1*q)) 612 | + (13680.-(462.-(132.-(99.-140./f2)/f2)/f2)/f2)/f1/166320. 613 | + (13680.-(462.-(132.-(99.-140./z2)/z2)/z2)/z2)/z/166320. 614 | + (13680.-(462.-(132.-(99.-140./x2)/x2)/x2)/x2)/x1/166320. 615 | + (13680.-(462.-(132.-(99.-140./w2)/w2)/w2)/w2)/w/166320.)) { 616 | goto Step10; 617 | } 618 | 619 | Step60: 620 | if (p > 0.5) { 621 | y = n - y; 622 | } 623 | 624 | return y; 625 | } 626 | 627 | private int rBinomialInversion(RGen = Random)(int n, double p, ref RGen gen = rndGen) { 628 | auto state = binoState; 629 | double q, qn, np, px, U; 630 | int X, bound; 631 | 632 | if (!(state.has_binomial) || 633 | (state.nsave != n) || 634 | (state.psave != p)) { 635 | state.nsave = n; 636 | state.psave = p; 637 | state.has_binomial = 1; 638 | state.q = q = 1.0 - p; 639 | state.r = qn = exp(n * log(q)); 640 | state.c = np = n*p; 641 | state.m = bound = cast(int) min(n, np + 10.0*sqrt(np*q + 1)); 642 | } else { 643 | q = state.q; 644 | qn = state.r; 645 | np = state.c; 646 | bound = cast(int) state.m; 647 | } 648 | X = 0; 649 | px = qn; 650 | U = uniform(0.0L, 1.0L, gen); 651 | while (U > px) { 652 | X++; 653 | if (X > bound) { 654 | X = 0; 655 | px = qn; 656 | U = uniform(0.0L, 1.0L, gen); 657 | } else { 658 | U -= px; 659 | px = ((n-X+1) * p * px)/(X*q); 660 | } 661 | } 662 | return X; 663 | } 664 | 665 | /// 666 | int rBinomial(RGen = Random)(int n, double p, ref RGen gen = rndGen) { 667 | dstatsEnforce(n >= 0, "n must be >= 0 for binomial distribution."); 668 | dstatsEnforce(p >= 0 && p <= 1, "p must be between 0, 1 for binomial distribution."); 669 | 670 | if (p <= 0.5) { 671 | if (p*n <= 30.0) { 672 | return rBinomialInversion(n, p, gen); 673 | } else { 674 | return rBinomialBtpe(n, p, gen); 675 | } 676 | } else { 677 | double q = 1.0-p; 678 | if (q*n <= 30.0) { 679 | return n - rBinomialInversion(n, q, gen); 680 | } else { 681 | return n - rBinomialBtpe(n, q, gen); 682 | } 683 | } 684 | } 685 | 686 | unittest { 687 | void testBinom(int n, double p) { 688 | int[] observ = new int[100_000]; 689 | 
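    // The calls at the end of this unittest use (1000, 0.6) to exercise the
    // BTPE path of rBinomial and (3, 0.7) to exercise the inversion path.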
foreach(ref elem; observ) 690 | elem = rBinomial(n, p); 691 | writeln("100k samples from binom.(", n, ", ", p, "):"); 692 | writeln("\tMean Expected: ", n * p, 693 | " Observed: ", mean(observ)); 694 | writeln("\tMedian Expected: ", n * p, " Observed: ", median(observ)); 695 | writeln("\tStdev Expected: ", sqrt(n * p * (1 - p)), 696 | " Observed: ", stdev(observ)); 697 | writeln("\tKurtosis Expected: ", (1 - 6 * p * (1 - p)) / (n * p * (1 - p)), 698 | " Observed: ", kurtosis(observ)); 699 | writeln("\tSkewness Expected: ", (1 - 2 * p) / (sqrt(n * p * (1 - p))), 700 | " Observed: ", skewness(observ)); 701 | GC.free(observ.ptr); 702 | } 703 | 704 | testBinom(1000, 0.6); 705 | testBinom(3, 0.7); 706 | } 707 | 708 | private int hypergeoHyp(RGen = Random)(int good, int bad, int sample, ref RGen gen = rndGen) { 709 | int Z = void; 710 | double U = void; 711 | 712 | int d1 = bad + good - sample; 713 | double d2 = cast(double)min(bad, good); 714 | 715 | double Y = d2; 716 | int K = sample; 717 | while (Y > 0.0) { 718 | U = uniform(0.0L, 1.0L, gen); 719 | Y -= cast(int)floor(U + Y/(d1 + K)); 720 | K--; 721 | if (K == 0) break; 722 | } 723 | Z = cast(int)(d2 - Y); 724 | if (good > bad) Z = sample - Z; 725 | return Z; 726 | } 727 | 728 | private enum double D1 = 1.7155277699214135; 729 | private enum double D2 = 0.8989161620588988; 730 | private int hypergeoHrua(RGen = Random)(int good, int bad, int sample, ref RGen gen = rndGen) { 731 | int Z = void; 732 | double T = void, W = void, X = void, Y = void; 733 | 734 | int mingoodbad = min(good, bad); 735 | int popsize = good + bad; 736 | int maxgoodbad = max(good, bad); 737 | int m = min(sample, popsize - sample); 738 | double d4 = (cast(double)mingoodbad) / popsize; 739 | double d5 = 1.0 - d4; 740 | double d6 = m*d4 + 0.5; 741 | double d7 = sqrt((popsize - m) * sample * d4 *d5 / (popsize-1) + 0.5); 742 | double d8 = D1*d7 + D2; 743 | int d9 = cast(int)floor(cast(double)((m+1)*(mingoodbad+1))/(popsize+2)); 744 | double d10 = (logGamma(d9+1) + logGamma(mingoodbad-d9+1) + logGamma(m-d9+1) + 745 | logGamma(maxgoodbad-m+d9+1)); 746 | double d11 = min(min(m, mingoodbad)+1.0, floor(d6+16*d7)); 747 | /* 16 for 16-decimal-digit precision in D1 and D2 */ 748 | 749 | while (true) { 750 | X = uniform(0.0L, 1.0L, gen); 751 | Y = uniform(0.0L, 1.0L, gen); 752 | W = d6 + d8*(Y- 0.5)/X; 753 | 754 | /* fast rejection: */ 755 | if ((W < 0.0) || (W >= d11)) continue; 756 | 757 | Z = cast(int)floor(W); 758 | T = d10 - (logGamma(Z+1) + logGamma(mingoodbad-Z+1) + logGamma(m-Z+1) + 759 | logGamma(maxgoodbad-m+Z+1)); 760 | 761 | /* fast acceptance: */ 762 | if ((X*(4.0-X)-3.0) <= T) break; 763 | 764 | /* fast rejection: */ 765 | if (X*(X-T) >= 1) continue; 766 | 767 | if (2.0*log(X) <= T) break; /* acceptance */ 768 | } 769 | 770 | /* this is a correction to HRUA* by Ivan Frohne in rv.py */ 771 | if (good > bad) Z = m - Z; 772 | 773 | /* another fix from rv.py to allow sample to exceed popsize/2 */ 774 | if (m < sample) Z = good - Z; 775 | 776 | return Z; 777 | } 778 | 779 | /// 780 | int rHypergeometric(RGen = Random)(int n1, int n2, int n, ref RGen gen = rndGen) { 781 | dstatsEnforce(n <= n1 + n2, "n must be <= n1 + n2 for hypergeometric distribution."); 782 | dstatsEnforce(n1 >= 0 && n2 >= 0 && n >= 0, 783 | "n, n1, n2 must be >= 0 for hypergeometric distribution."); 784 | 785 | alias n1 good; 786 | alias n2 bad; 787 | alias n sample; 788 | if (sample > 10) { 789 | return hypergeoHrua(good, bad, sample, gen); 790 | } else { 791 | return hypergeoHyp(good, bad, sample, gen); 
792 | } 793 | } 794 | 795 | unittest { 796 | 797 | static double hyperStdev(int n1, int n2, int n) { 798 | return sqrt(cast(double) n * (cast(double) n1 / (n1 + n2)) 799 | * (1 - cast(double) n1 / (n1 + n2)) * (n1 + n2 - n) / (n1 + n2 - 1)); 800 | } 801 | 802 | static double hyperSkew(double n1, double n2, double n) { 803 | double N = n1 + n2; 804 | alias n1 m; 805 | double numer = (N - 2 * m) * sqrt(N - 1) * (N - 2 * n); 806 | double denom = sqrt(n * m * (N - m) * (N - n)) * (N - 2); 807 | return numer / denom; 808 | } 809 | 810 | void testHyper(int n1, int n2, int n) { 811 | int[] observ = new int[100_000]; 812 | foreach(ref elem; observ) 813 | elem = rHypergeometric(n1, n2, n); 814 | auto ksRes = ksTest(observ, parametrize!(hypergeometricCDF)(n1, n2, n)); 815 | writeln("100k samples from hypergeom.(", n1, ", ", n2, ", ", n, "):"); 816 | writeln("\tMean Expected: ", n * cast(double) n1 / (n1 + n2), 817 | " Observed: ", mean(observ)); 818 | writeln("\tMedian Expected: ?? Observed: ", median(observ)); 819 | writeln("\tStdev Expected: ", hyperStdev(n1, n2, n), 820 | " Observed: ", stdev(observ)); 821 | writeln("\tKurtosis Expected: ?? Observed: ", kurtosis(observ)); 822 | writeln("\tSkewness Expected: ", hyperSkew(n1, n2, n), " Observed: ", skewness(observ)); 823 | GC.free(observ.ptr); 824 | } 825 | 826 | testHyper(4, 5, 2); 827 | testHyper(120, 105, 70); 828 | } 829 | 830 | private int rGeomSearch(RGen = Random)(double p, ref RGen gen = rndGen) { 831 | int X = 1; 832 | double sum = p, prod = p; 833 | double q = 1.0 - p; 834 | double U = uniform(0.0L, 1.0L, gen); 835 | while (U > sum) { 836 | prod *= q; 837 | sum += prod; 838 | X++; 839 | } 840 | return X; 841 | } 842 | 843 | private int rGeomInvers(RGen = Random)(double p, ref RGen gen = rndGen) { 844 | return cast(int)ceil(log(1.0-uniform(0.0L, 1.0L, gen))/log(1.0-p)); 845 | } 846 | 847 | int rGeometric(RGen = Random)(double p, ref RGen gen = rndGen) { 848 | dstatsEnforce(p >= 0 && p <= 1, "p must be between 0, 1 for geometric distribution."); 849 | 850 | if (p >= 0.333333333333333333333333) { 851 | return rGeomSearch(p, gen); 852 | } else { 853 | return rGeomInvers(p, gen); 854 | } 855 | } 856 | 857 | unittest { 858 | 859 | void testGeom(double p) { 860 | int[] observ = new int[100_000]; 861 | foreach(ref elem; observ) 862 | elem = rGeometric(p); 863 | writeln("100k samples from geometric.(", p, "):"); 864 | writeln("\tMean Expected: ", 1 / p, 865 | " Observed: ", mean(observ)); 866 | writeln("\tMedian Expected: ", ceil(-log(2.0) / log(1 - p)), 867 | " Observed: ", median(observ)); 868 | writeln("\tStdev Expected: ", sqrt((1 - p) / (p * p)), 869 | " Observed: ", stdev(observ)); 870 | writeln("\tKurtosis Expected: ", 6 + (p * p) / (1 - p), 871 | " Observed: ", kurtosis(observ)); 872 | writeln("\tSkewness Expected: ", (2 - p) / sqrt(1 - p), 873 | " Observed: ", skewness(observ)); 874 | GC.free(observ.ptr); 875 | } 876 | 877 | testGeom(0.1); 878 | testGeom(0.74); 879 | 880 | } 881 | 882 | /// 883 | int rNegBinom(RGen = Random)(double n, double p, ref RGen gen = rndGen) { 884 | dstatsEnforce(n >= 0, "n must be >= 0 for negative binomial distribution."); 885 | dstatsEnforce(p >= 0 && p <= 1, 886 | "p must be between 0, 1 for negative binomial distribution."); 887 | 888 | double Y = stdGamma(n, gen); 889 | Y *= (1 - p) / p; 890 | return rPoisson(Y, gen); 891 | } 892 | 893 | unittest { 894 | Random gen; 895 | gen.seed(unpredictableSeed); 896 | double p = 0.3L; 897 | int n = 30; 898 | int[] observ = new int[100_000]; 899 | foreach(ref elem; 
observ) 900 | elem = rNegBinom(n, p); 901 | writeln("100k samples from neg. binom.(", n,", ", p, "):"); 902 | writeln("\tMean Expected: ", n * (1 - p) / p, 903 | " Observed: ", mean(observ)); 904 | writeln("\tMedian Expected: ?? Observed: ", median(observ)); 905 | writeln("\tStdev Expected: ", sqrt(n * (1 - p) / (p * p)), 906 | " Observed: ", stdev(observ)); 907 | writeln("\tKurtosis Expected: ", (6 - p * (6 - p)) / (n * (1 - p)), 908 | " Observed: ", kurtosis(observ)); 909 | writeln("\tSkewness Expected: ", (2 - p) / sqrt(n * (1 - p)), 910 | " Observed: ", skewness(observ)); 911 | GC.free(observ.ptr); 912 | } 913 | 914 | /// 915 | double rLaplace(RGen = Random)(double mu = 0, double b = 1, ref RGen gen = rndGen) { 916 | dstatsEnforce(b > 0, "b must be > 0 for Laplace distribution."); 917 | 918 | double p = uniform(0.0L, 1.0L, gen); 919 | return invLaplaceCDF(p, mu, b); 920 | } 921 | 922 | unittest { 923 | Random gen; 924 | gen.seed(unpredictableSeed); 925 | double[] observ = new double[100_000]; 926 | foreach(ref elem; observ) 927 | elem = rLaplace(); 928 | auto ksRes = ksTest(observ, parametrize!(laplaceCDF)(0.0L, 1.0L)); 929 | writeln("100k samples from Laplace(0, 1): K-S P-val: ", ksRes.p); 930 | writeln("\tMean Expected: 0 Observed: ", mean(observ)); 931 | writeln("\tMedian Expected: 0 Observed: ", median(observ)); 932 | writeln("\tStdev Expected: 1.414 Observed: ", stdev(observ)); 933 | writeln("\tKurtosis Expected: 3 Observed: ", kurtosis(observ)); 934 | writeln("\tSkewness Expected: 0 Observed: ", skewness(observ)); 935 | GC.free(observ.ptr); 936 | } 937 | 938 | /// 939 | double rExponential(RGen = Random)(double lambda, ref RGen gen = rndGen) { 940 | dstatsEnforce(lambda > 0, "lambda must be > 0 for exponential distribution."); 941 | 942 | double p = uniform(0.0L, 1.0L, gen); 943 | return -log(p) / lambda; 944 | } 945 | 946 | unittest { 947 | double[] observ = new double[100_000]; 948 | foreach(ref elem; observ) 949 | elem = rExponential(2.0L); 950 | auto ksRes = ksTest(observ, parametrize!(gammaCDF)(2, 1)); 951 | writeln("100k samples from exponential(2): K-S P-val: ", ksRes.p); 952 | writeln("\tMean Expected: 0.5 Observed: ", mean(observ)); 953 | writeln("\tMedian Expected: 0.3465 Observed: ", median(observ)); 954 | writeln("\tStdev Expected: 0.5 Observed: ", stdev(observ)); 955 | writeln("\tKurtosis Expected: 6 Observed: ", kurtosis(observ)); 956 | writeln("\tSkewness Expected: 2 Observed: ", skewness(observ)); 957 | GC.free(observ.ptr); 958 | } 959 | 960 | private double stdGamma(RGen = Random)(double shape, ref RGen gen) { 961 | double b = void, c = void; 962 | double U = void, V = void, X = void, Y = void; 963 | 964 | if (shape == 1.0) { 965 | return rExponential(1.0, gen); 966 | } else if (shape < 1.0) { 967 | for (;;) { 968 | U = uniform(0.0L, 1.0, gen); 969 | V = rExponential(1.0, gen); 970 | if (U <= 1.0 - shape) { 971 | X = pow(U, 1.0/shape); 972 | if (X <= V) { 973 | return X; 974 | } 975 | } else { 976 | Y = -log((1-U)/shape); 977 | X = pow(1.0 - shape + shape*Y, 1./shape); 978 | if (X <= (V + Y)) { 979 | return X; 980 | } 981 | } 982 | } 983 | } else { 984 | b = shape - 1./3.; 985 | c = 1./sqrt(9*b); 986 | for (;;) { 987 | do { 988 | X = rNormal(0.0L, 1.0L, gen); 989 | V = 1.0 + c*X; 990 | } while (V <= 0.0); 991 | 992 | V = V*V*V; 993 | U = uniform(0.0L, 1.0L, gen); 994 | if (U < 1.0 - 0.0331*(X*X)*(X*X)) return (b*V); 995 | if (log(U) < 0.5*X*X + b*(1. 
- V + log(V))) return (b*V);
996 |         }
997 |     }
998 | }
999 | 
1000 | ///
1001 | double rGamma(RGen = Random)(double a, double b, ref RGen gen = rndGen) {
1002 |     dstatsEnforce(a > 0, "a must be > 0 for gamma distribution.");
1003 |     dstatsEnforce(b > 0, "b must be > 0 for gamma distribution.");
1004 | 
1005 |     return stdGamma(b, gen) / a;
1006 | }
1007 | 
1008 | unittest {
1009 |     double[] observ = new double[100_000];
1010 |     foreach(ref elem; observ)
1011 |         elem = rGamma(2.0L, 3.0L);
1012 |     auto ksRes = ksTest(observ, parametrize!(gammaCDF)(2, 3));
1013 |     writeln("100k samples from gamma(2, 3): K-S P-val: ", ksRes.p);
1014 |     writeln("\tMean Expected: 1.5 Observed: ", mean(observ));
1015 |     writeln("\tMedian Expected: ?? Observed: ", median(observ));
1016 |     writeln("\tStdev Expected: 0.866 Observed: ", stdev(observ));
1017 |     writeln("\tKurtosis Expected: 2 Observed: ", kurtosis(observ));
1018 |     writeln("\tSkewness Expected: 1.15 Observed: ", skewness(observ));
1019 |     GC.free(observ.ptr);
1020 | }
1021 | 
1022 | ///
1023 | double rBeta(RGen = Random)(double a, double b, ref RGen gen = rndGen) {
1024 |     dstatsEnforce(a > 0, "a must be > 0 for beta distribution.");
1025 |     dstatsEnforce(b > 0, "b must be > 0 for beta distribution.");
1026 | 
1027 |     double Ga = void, Gb = void;
1028 | 
1029 |     if ((a <= 1.0) && (b <= 1.0)) {
1030 |         double U, V, X, Y;
1031 |         /* Use Johnk's algorithm */
1032 | 
1033 |         while (1) {
1034 |             U = uniform(0.0L, 1.0L, gen);
1035 |             V = uniform(0.0L, 1.0L, gen);
1036 |             X = pow(U, 1.0/a);
1037 |             Y = pow(V, 1.0/b);
1038 | 
1039 |             if ((X + Y) <= 1.0) {
1040 |                 return X / (X + Y);
1041 |             }
1042 |         }
1043 |     } else {
1044 |         Ga = stdGamma(a, gen);
1045 |         Gb = stdGamma(b, gen);
1046 |         return Ga/(Ga + Gb);
1047 |     }
1048 |     assert(0);
1049 | }
1050 | 
1051 | unittest {
1052 |     double delegate(double) paramBeta(double a, double b) {
1053 |         double parametrizedBeta(double x) {
1054 |             return betaIncomplete(a, b, x);
1055 |         }
1056 |         return &parametrizedBeta;
1057 |     }
1058 | 
1059 |     static double betaStdev(double a, double b) {
1060 |         return sqrt(a * b / ((a + b) * (a + b) * (a + b + 1)));
1061 |     }
1062 | 
1063 |     static double betaSkew(double a, double b) {
1064 |         auto numer = 2 * (b - a) * sqrt(a + b + 1);
1065 |         auto denom = (a + b + 2) * sqrt(a * b);
1066 |         return numer / denom;
1067 |     }
1068 | 
1069 |     static double betaKurtosis(double a, double b) {
1070 |         double numer = a * a * a - a * a * (2 * b - 1) + b * b * (b + 1) - 2 * a * b * (b + 2);
1071 |         double denom = a * b * (a + b + 2) * (a + b + 3);
1072 |         return 6 * numer / denom;
1073 |     }
1074 | 
1075 |     void testBeta(double a, double b) {
1076 |         double[] observ = new double[100_000];
1077 |         foreach(ref elem; observ)
1078 |             elem = rBeta(a, b);
1079 |         auto ksRes = ksTest(observ, paramBeta(a, b));
1080 |         auto summ = summary(observ);
1081 |         writeln("100k samples from beta(", a, ", ", b, "): K-S P-val: ", ksRes.p);
1082 |         writeln("\tMean Expected: ", a / (a + b), " Observed: ", summ.mean);
1083 |         writeln("\tMedian Expected: ?? Observed: ", median(observ));
1084 |         writeln("\tStdev Expected: ", betaStdev(a, b), " Observed: ", summ.stdev);
1085 |         writeln("\tKurtosis Expected: ", betaKurtosis(a, b), " Observed: ", summ.kurtosis);
1086 |         writeln("\tSkewness Expected: ", betaSkew(a, b), " Observed: ", summ.skewness);
1087 |         GC.free(observ.ptr);
1088 |     }
1089 | 
1090 |     testBeta(0.5, 0.7);
1091 |     testBeta(5, 3);
1092 | }
1093 | 
1094 | ///
1095 | double rLogistic(RGen = Random)(double loc, double scale, ref RGen gen = rndGen) {
1096 |     dstatsEnforce(scale > 0, "scale must be > 0 for logistic distribution.");
1097 | 
1098 |     double U = uniform(0.0L, 1.0L, gen);
1099 |     return loc + scale * log(U/(1.0 - U));
1100 | }
1101 | 
1102 | unittest {
1103 |     double[] observ = new double[100_000];
1104 |     foreach(ref elem; observ)
1105 |         elem = rLogistic(2.0L, 3.0L);
1106 |     auto ksRes = ksTest(observ, parametrize!(logisticCDF)(2, 3));
1107 |     writeln("100k samples from logistic(2, 3): K-S P-val: ", ksRes.p);
1108 |     writeln("\tMean Expected: 2 Observed: ", mean(observ));
1109 |     writeln("\tMedian Expected: 2 Observed: ", median(observ));
1110 |     writeln("\tStdev Expected: ", PI * sqrt(3.0), " Observed: ", stdev(observ));  // stdev = scale * PI / sqrt(3).
1111 |     writeln("\tKurtosis Expected: 1.2 Observed: ", kurtosis(observ));
1112 |     writeln("\tSkewness Expected: 0 Observed: ", skewness(observ));
1113 |     GC.free(observ.ptr);
1114 | }
1115 | 
1116 | ///
1117 | double rLogNormal(RGen = Random)(double mu, double sigma, ref RGen gen = rndGen) {
1118 |     dstatsEnforce(sigma > 0, "sigma must be > 0 for log-normal distribution.");
1119 | 
1120 |     return exp(rNormal(mu, sigma, gen));
1121 | }
1122 | 
1123 | unittest {
1124 |     auto observ = randArray!rLogNormal(100_000, -2, 1);
1125 |     auto ksRes = ksTest(observ, paramFunctor!(logNormalCDF)(-2, 1));
1126 | 
1127 |     auto summ = summary(observ);
1128 |     writeln("100k samples from log-normal(-2, 1): K-S P-val: ", ksRes.p);
1129 |     writeln("\tMean Expected: ", exp(-1.5), " Observed: ", summ.mean);
1130 |     writeln("\tMedian Expected: ", exp(-2.0L), " Observed: ", median(observ));
1131 |     writeln("\tStdev Expected: ", sqrt((exp(1.) - 1) * exp(-4.0L + 1)),
1132 |         " Observed: ", summ.stdev);
1133 |     writeln("\tKurtosis Expected: ?? Observed: ", summ.kurtosis);
1134 |     writeln("\tSkewness Expected: ", (exp(1.) + 2) * sqrt(exp(1.) - 1),
1135 |         " Observed: ", summ.skewness);
1136 | }
1137 | 
1138 | ///
1139 | double rWeibull(RGen = Random)(double shape, double scale = 1, ref RGen gen = rndGen) {
1140 |     dstatsEnforce(shape > 0, "shape must be > 0 for weibull distribution.");
1141 |     dstatsEnforce(scale > 0, "scale must be > 0 for weibull distribution.");
1142 | 
1143 |     return pow(rExponential(1, gen), 1.
/ shape) * scale; 1144 | } 1145 | 1146 | unittest { 1147 | double[] observ = new double[100_000]; 1148 | foreach(ref elem; observ) 1149 | elem = rWeibull(2.0L, 3.0L); 1150 | auto ksRes = ksTest(observ, parametrize!(weibullCDF)(2.0, 3.0)); 1151 | writeln("100k samples from weibull(2, 3): K-S P-val: ", ksRes.p); 1152 | GC.free(observ.ptr); 1153 | } 1154 | 1155 | /// 1156 | double rWald(RGen = Random)(double mu, double lambda, ref RGen gen = rndGen) { 1157 | dstatsEnforce(mu > 0, "mu must be > 0 for Wald distribution."); 1158 | dstatsEnforce(lambda > 0, "lambda must be > 0 for Wald distribution."); 1159 | 1160 | alias mu mean; 1161 | alias lambda scale; 1162 | 1163 | double mu_2l = mean / (2*scale); 1164 | double Y = rNormal(0, 1, gen); 1165 | Y = mean*Y*Y; 1166 | double X = mean + mu_2l*(Y - sqrt(4*scale*Y + Y*Y)); 1167 | double U = uniform(0.0L, 1.0L, gen); 1168 | if (U <= mean/(mean+X)) { 1169 | return X; 1170 | } else 1171 | 1172 | { 1173 | return mean*mean/X; 1174 | } 1175 | } 1176 | 1177 | unittest { 1178 | auto observ = randArray!rWald(100_000, 4, 7); 1179 | auto ksRes = ksTest(observ, parametrize!(waldCDF)(4, 7)); 1180 | 1181 | auto summ = summary(observ); 1182 | writeln("100k samples from wald(4, 7): K-S P-val: ", ksRes.p); 1183 | writeln("\tMean Expected: ", 4, " Observed: ", summ.mean); 1184 | writeln("\tMedian Expected: ?? Observed: ", median(observ)); 1185 | writeln("\tStdev Expected: ", sqrt(64.0 / 7), " Observed: ", summ.stdev); 1186 | writeln("\tKurtosis Expected: ", 15.0 * 4 / 7, " Observed: ", summ.kurtosis); 1187 | writeln("\tSkewness Expected: ", 3 * sqrt(4.0 / 7), " Observed: ", summ.skewness); 1188 | } 1189 | 1190 | /// 1191 | double rRayleigh(RGen = Random)(double mode, ref RGen gen = rndGen) { 1192 | dstatsEnforce(mode > 0, "mode must be > 0 for Rayleigh distribution."); 1193 | 1194 | return mode*sqrt(-2.0 * log(1.0 - uniform(0.0L, 1.0L, gen))); 1195 | } 1196 | 1197 | unittest { 1198 | auto observ = randArray!rRayleigh(100_000, 3); 1199 | auto ksRes = ksTest(observ, parametrize!(rayleighCDF)(3)); 1200 | writeln("100k samples from rayleigh(3): K-S P-val: ", ksRes.p); 1201 | } 1202 | 1203 | deprecated { 1204 | alias rNorm = rNormal; 1205 | alias rLogNorm = rLogNormal; 1206 | alias rStudentT = rStudentsT; 1207 | } 1208 | -------------------------------------------------------------------------------- /source/dstats/sort.d: -------------------------------------------------------------------------------- 1 | /** 2 | A comprehensive sorting library for statistical functions. Each function 3 | takes N arguments, which are arrays or array-like objects, sorts the first 4 | and sorts the rest in lockstep. For merge and insertion sort, if the last 5 | argument is a ulong*, increments the dereference of this ulong* by the bubble 6 | sort distance between the first argument and the sorted version of the first 7 | argument. This is useful for some statistical calculations. 8 | 9 | All sorting functions have the precondition that all parallel input arrays 10 | must have the same length. 11 | 12 | Notes: 13 | 14 | Comparison functions must be written such that compFun(x, x) == false. 15 | For example, "a < b" is good, "a <= b" is not. 16 | 17 | These functions are heavily optimized for sorting arrays of 18 | ints and floats (by far the most common case when doing statistical 19 | calculations). In these cases, they can be several times faster than the 20 | equivalent functions in std.algorithm. 
Since sorting is extremely important 21 | for non-parametric statistics, this results in important real-world 22 | performance gains. However, it comes at a price in terms of generality: 23 | 24 | 1. They assume that what they are sorting is cheap to copy via normal 25 | assignment. 26 | 27 | 2. They don't work at all with general ranges, only arrays and maybe 28 | ranges very similar to arrays. 29 | 30 | 3. All tuning and micro-optimization is done with ints and floats, not 31 | classes, large structs, strings, etc. 32 | 33 | Examples: 34 | --- 35 | auto foo = [3, 1, 2, 4, 5].dup; 36 | auto bar = [8, 6, 7, 5, 3].dup; 37 | qsort(foo, bar); 38 | assert(foo == [1, 2, 3, 4, 5]); 39 | assert(bar == [6, 7, 8, 5, 3]); 40 | auto baz = [1.0, 0, -1, -2, -3].dup; 41 | mergeSort!("a > b")(bar, foo, baz); 42 | assert(bar == [8, 7, 6, 5, 3]); 43 | assert(foo == [3, 2, 1, 4, 5]); 44 | assert(baz == [-1.0, 0, 1, -2, -3]); 45 | --- 46 | 47 | Author: David Simcha 48 | */ 49 | /* 50 | * License: 51 | * Boost Software License - Version 1.0 - August 17th, 2003 52 | * 53 | * Permission is hereby granted, free of charge, to any person or organization 54 | * obtaining a copy of the software and accompanying documentation covered by 55 | * this license (the "Software") to use, reproduce, display, distribute, 56 | * execute, and transmit the Software, and to prepare derivative works of the 57 | * Software, and to permit third-parties to whom the Software is furnished to 58 | * do so, all subject to the following: 59 | * 60 | * The copyright notices in the Software and this entire statement, including 61 | * the above license grant, this restriction and the following disclaimer, 62 | * must be included in all copies of the Software, in whole or in part, and 63 | * all derivative works of the Software, unless such copies or derivative 64 | * works are solely in the form of machine-executable object code generated by 65 | * a source language processor. 66 | * 67 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 68 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 69 | * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 70 | * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 71 | * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 72 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 73 | * DEALINGS IN THE SOFTWARE. 74 | */ 75 | 76 | module dstats.sort; 77 | 78 | import std.traits, std.algorithm, std.math, std.functional, std.math, std.typecons, 79 | std.typetuple, std.range, std.array, std.traits, std.ascii : whitespace; 80 | 81 | import dstats.alloc; 82 | 83 | version(unittest) { 84 | import std.stdio, std.random; 85 | } 86 | 87 | class SortException : Exception { 88 | this(string msg) { 89 | super(msg); 90 | } 91 | } 92 | 93 | /* CTFE function. Used in isSimpleComparison.*/ 94 | /*private*/ string removeWhitespace(string input) pure nothrow { 95 | string ret; 96 | foreach(elem; input) { 97 | bool shouldAppend = true; 98 | foreach(whiteChar; whitespace) { 99 | if(elem == whiteChar) { 100 | shouldAppend = false; 101 | break; 102 | } 103 | } 104 | 105 | if(shouldAppend) { 106 | ret ~= elem; 107 | } 108 | } 109 | return ret; 110 | } 111 | 112 | /* Conservatively tests whether the comparison function is simple enough that 113 | * we can get away with comparing floats as if they were ints. 
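 * (IEEE 754 values of the same sign order the same way as their bit patterns
 * read as integers; prepareForSorting() below flips negative values so that a
 * single integer comparison orders them correctly across signs.)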
114 |  */
115 | /*private*/ template isSimpleComparison(alias comp) {
116 |     static if(!isSomeString!(typeof(comp))) {
117 |         enum bool isSimpleComparison = false;
118 |     } else {
119 |         enum bool isSimpleComparison =
120 |             removeWhitespace(comp) == "a<b" ||
121 |             removeWhitespace(comp) == "a>b";
122 |     }
123 | }
124 | 
125 | /*private*/ bool intIsNaN(I)(I i) {
126 |     static if(is(I == int) || is(I == uint)) {
127 |         // IEEE 754 single precision float has a 23-bit significand stored in the
128 |         // lowest order bits, followed by an 8-bit exponent.  A NaN is when the
129 |         // exponent bits are all ones and the significand is nonzero.
130 |         enum uint significandMask = 0b111_1111_1111_1111_1111_1111UL;
131 |         enum uint exponentMask = 0b1111_1111UL << 23;
132 |     } else static if(is(I == long) || is(I == ulong)) {
133 |         // IEEE 754 double precision float has a 52-bit significand stored in the
134 |         // lowest order bits, followed by an 11-bit exponent.  A NaN is when the
135 |         // exponent bits are all ones and the significand is nonzero.
136 |         enum ulong significandMask =
137 |             0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111UL;
138 |         enum ulong exponentMask = 0b111_1111_1111UL << 52;
139 |     } else {
140 |         static assert(0);
141 |     }
142 | 
143 |     return ((i & exponentMask) == exponentMask) && ((i & significandMask) != 0);
144 | }
145 | 
146 | unittest {
147 |     // Test on randomly generated integers punned to floats.  We expect that
148 |     // about 1 in 256 will be NaNs.
149 |     foreach(i; 0..10_000) {
150 |         uint randInt = uniform(0U, uint.max);
151 |         assert(std.math.isNaN(*(cast(float*) &randInt)) == intIsNaN(randInt));
152 |     }
153 | 
154 |     // Test on randomly generated integers punned to doubles.  We expect that
155 |     // about 1 in 2048 will be NaNs.
156 |     foreach(i; 0..1_000_000) {
157 |         ulong randInt = (cast(ulong) uniform(0U, uint.max) << 32) + uniform(0U, uint.max);
158 |         assert(std.math.isNaN(*(cast(double*) &randInt)) == intIsNaN(randInt));
159 |     }
160 | }
161 | 
162 | /* Check for NaN and do some bit twiddling so that a float or double can be
163 |  * compared as an integer.  This results in approximately a 40% speedup
164 |  * compared to just sorting as floats.
165 |  */
166 | auto prepareForSorting(alias comp, T)(T arr) {
167 |     static if(isSimpleComparison!comp) {
168 |         static if(is(T == real[])) {
169 |             foreach(elem; arr) {
170 |                 if(isNaN(elem)) {
171 |                     throw new SortException("Can't sort NaNs.");
172 |                 }
173 |             }
174 | 
175 |             return arr;
176 |         } else static if(is(T == double[]) || is(T == float[])) {
177 |             static if(is(T == double[])) {
178 |                 alias long Int;
179 |                 enum signMask = 1UL << 63;
180 |             } else {
181 |                 alias int Int;
182 |                 enum signMask = 1U << 31;
183 |             }
184 | 
185 |             Int[] intArr = cast(Int[]) arr;
186 |             foreach(i, ref elem; intArr) {
187 |                 if(intIsNaN(elem)) {
188 |                     // Roll back the bit twiddling in case someone catches the
189 |                     // exception, so that they don't see corrupted values.
190 |                     postProcess!comp(intArr[0..i]);
191 | 
192 |                     throw new SortException("Can't sort NaNs.");
193 |                 }
194 | 
195 |                 if(elem & signMask) {
196 |                     // Negative.
197 |                     elem ^= signMask;
198 |                     elem = ~elem;
199 |                 }
200 |             }
201 | 
202 |             return intArr;
203 |         } else {
204 |             return arr;
205 |         }
206 | 
207 |     } else {
208 |         return arr;
209 |     }
210 | }
211 | 
212 | /*private*/ void postProcess(alias comp, T)(T arr)
213 | if(!isSimpleComparison!comp || (!is(T == double[]) && !is(T == float[]))) {}
214 | 
215 | /* Undo bit twiddling from prepareForSorting() to get back original
216 |  * floating point numbers.
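 * This runs even if sorting throws (see the try/finally in qsort), so that
 * callers never observe the bit-twiddled values.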
217 |  */
218 | /*private*/ void postProcess(alias comp, F)(F arr)
219 | if((is(F == double[]) || is(F == float[])) && isSimpleComparison!comp) {
220 |     static if(is(F == double[])) {
221 |         alias long Int;
222 |         enum mask = 1UL << 63;
223 |     } else {
224 |         alias int Int;
225 |         enum mask = 1U << 31;
226 |     }
227 | 
228 |     Int[] useMe = cast(Int[]) arr;
229 |     foreach(ref elem; useMe) {
230 |         if(elem & mask) {
231 |             elem = ~elem;
232 |             elem ^= mask;
233 |         }
234 |     }
235 | }
236 | 
237 | version(unittest) {
238 |     static void testFloating(alias fun, F)() {
239 |         F[] testL = new F[1_000];
240 |         foreach(ref e; testL) {
241 |             e = uniform(-1_000_000, 1_000_000);
242 |         }
243 |         auto testL2 = testL.dup;
244 | 
245 |         static if(__traits(isSame, fun, mergeSortTemp)) {
246 |             auto temp1 = testL.dup;
247 |             auto temp2 = testL.dup;
248 |         }
249 | 
250 |         foreach(i; 0..200) {
251 |             randomShuffle(zip(testL, testL2));
252 |             uint len = uniform(0, 1_000);
253 | 
254 |             static if(__traits(isSame, fun, mergeSortTemp)) {
255 |                 fun!"a > b"(testL[0..len], testL2[0..len], temp1[0..len], temp2[0..len]);
256 |             } else {
257 |                 fun!("a > b")(testL[0..len], testL2[0..len]);
258 |             }
259 | 
260 |             assert(isSorted!("a > b")(testL[0..len]));
261 |             assert(testL == testL2, fun.stringof ~ '\t' ~ F.stringof);
262 |         }
263 |     }
264 | }
265 | 
266 | void rotateLeft(T)(T input)
267 | if(isRandomAccessRange!(T)) {
268 |     if(input.length < 2) return;
269 |     ElementType!(T) temp = input[0];
270 |     foreach(i; 1..input.length) {
271 |         input[i-1] = input[i];
272 |     }
273 |     input[$-1] = temp;
274 | }
275 | 
276 | void rotateRight(T)(T input)
277 | if(isRandomAccessRange!(T)) {
278 |     if(input.length < 2) return;
279 |     ElementType!(T) temp = input[$-1];
280 |     for(size_t i = input.length - 1; i > 0; i--) {
281 |         input[i] = input[i-1];
282 |     }
283 |     input[0] = temp;
284 | }
285 | 
286 | /* Returns the index, NOT the value, of the median of the first, middle, last
287 |  * elements of data.*/
288 | size_t medianOf3(alias compFun, T)(T[] data) {
289 |     alias binaryFun!(compFun) comp;
290 |     immutable size_t mid = data.length / 2;
291 |     immutable uint result = ((cast(uint) (comp(data[0], data[mid]))) << 2) |
292 |                             ((cast(uint) (comp(data[0], data[$ - 1]))) << 1) |
293 |                             (cast(uint) (comp(data[mid], data[$ - 1])));
294 | 
295 |     assert(result != 2 && result != 5 && result < 8);  // Cases 2, 5 can't happen.
296 |     switch(result) {
297 |         case 1:  // 001
298 |         case 6:  // 110
299 |             return data.length - 1;
300 |         case 3:  // 011
301 |         case 4:  // 100
302 |             return 0;
303 |         case 0:  // 000
304 |         case 7:  // 111
305 |             return mid;
306 |         default:
307 |             assert(0);
308 |     }
309 |     assert(0);
310 | }
311 | 
312 | unittest {
313 |     assert(medianOf3!("a < b")([1,2,3,4,5]) == 2);
314 |     assert(medianOf3!("a < b")([1,2,5,4,3]) == 4);
315 |     assert(medianOf3!("a < b")([3,2,1,4,5]) == 0);
316 |     assert(medianOf3!("a < b")([5,2,3,4,1]) == 2);
317 |     assert(medianOf3!("a < b")([5,2,1,4,3]) == 4);
318 |     assert(medianOf3!("a < b")([3,2,5,4,1]) == 0);
319 | }
320 | 
321 | 
322 | /**Quick sort.  Unstable, O(N log N) time average, worst
323 |  * case, O(log N) space, small constant term in time complexity.
324 |  *
325 |  * In this implementation, the following steps are taken to avoid the
326 |  * O(N^2) worst case of naive quick sorts:
327 |  *
328 |  * 1.  At each recursion, the median of the first, middle and last elements of
329 |  * the array is used as the pivot.
330 |  *
331 |  * 2.  To handle the case of few unique elements, the "Fit Pivot" technique
332 |  * previously described by Andrei Alexandrescu is used.
This allows 333 | * reasonable performance with few unique elements, with zero overhead 334 | * in other cases. 335 | * 336 | * 3. After a much larger than expected amount of recursion has occured, 337 | * this function transitions to a heap sort. This guarantees an O(N log N) 338 | * worst case.*/ 339 | T[0] qsort(alias compFun = "a < b", T...)(T data) 340 | if(T.length != 0) 341 | in { 342 | assert(data.length > 0); 343 | size_t len = data[0].length; 344 | foreach(array; data[1..$]) { 345 | assert(array.length == len); 346 | } 347 | } do { 348 | if(data[0].length < 25) { 349 | // Skip computing logarithm rather than waiting until qsortImpl to 350 | // do this. 351 | return insertionSort!compFun(data); 352 | } 353 | 354 | // Determines the transition point to a heap sort. 355 | uint TTL = cast(uint) (log2(cast(real) data[0].length) * 2); 356 | 357 | auto toSort = prepareForSorting!compFun(data[0]); 358 | 359 | /* qsort() throws if an invalid comparison function is passed. Even in 360 | * this case, the data should be post-processed so the bit twiddling 361 | * hacks for floats can be undone. 362 | */ 363 | try { 364 | qsortImpl!(compFun)(toSort, data[1..$], TTL); 365 | } finally { 366 | postProcess!compFun(data[0]); 367 | } 368 | 369 | return data[0]; 370 | } 371 | 372 | //TTL = time to live, before transitioning to heap sort. 373 | void qsortImpl(alias compFun, T...)(T data, uint TTL) { 374 | alias binaryFun!(compFun) comp; 375 | if(data[0].length < 25) { 376 | insertionSortImpl!(compFun)(data); 377 | return; 378 | } 379 | if(TTL == 0) { 380 | heapSortImpl!(compFun)(data); 381 | return; 382 | } 383 | TTL--; 384 | 385 | { 386 | immutable size_t med3 = medianOf3!(comp)(data[0]); 387 | foreach(array; data) { 388 | auto temp = array[med3]; 389 | array[med3] = array[$ - 1]; 390 | array[$ - 1] = temp; 391 | } 392 | } 393 | 394 | T less, greater; 395 | size_t lessI = size_t.max, greaterI = data[0].length - 1; 396 | 397 | auto pivot = data[0][$ - 1]; 398 | if(comp(pivot, pivot)) { 399 | throw new SortException 400 | ("Comparison function must be such that compFun(x, x) == false."); 401 | } 402 | 403 | while(true) { 404 | while(comp(data[0][++lessI], pivot)) {} 405 | while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {} 406 | 407 | if(lessI < greaterI) { 408 | foreach(array; data) { 409 | auto temp = array[lessI]; 410 | array[lessI] = array[greaterI]; 411 | array[greaterI] = temp; 412 | } 413 | } else break; 414 | } 415 | 416 | foreach(ti, array; data) { 417 | auto temp = array[$ - 1]; 418 | array[$ - 1] = array[lessI]; 419 | array[lessI] = temp; 420 | less[ti] = array[0..min(lessI, greaterI + 1)]; 421 | greater[ti] = array[lessI + 1..$]; 422 | } 423 | // Allow tail recursion optimization for larger block. This guarantees 424 | // that, given a reasonable amount of stack space, no stack overflow will 425 | // occur even in pathological cases. 426 | if(greater[0].length > less[0].length) { 427 | qsortImpl!(compFun)(less, TTL); 428 | qsortImpl!(compFun)(greater, TTL); 429 | return; 430 | } else { 431 | qsortImpl!(compFun)(greater, TTL); 432 | qsortImpl!(compFun)(less, TTL); 433 | } 434 | } 435 | 436 | unittest { 437 | { // Test integer. 
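        // A minimal lockstep-sort sketch (illustrative, not from the original
        // test suite): qsort sorts its first argument and reorders the other
        // arrays in parallel.
        {
            auto keys = [3, 1, 2].dup;
            auto vals = [30, 10, 20].dup;
            qsort(keys, vals);
            assert(keys == [1, 2, 3]);
            assert(vals == [10, 20, 30]);
        }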
438 | uint[] test = new uint[1_000]; 439 | foreach(ref e; test) { 440 | e = uniform(0, 100); 441 | } 442 | auto test2 = test.dup; 443 | foreach(i; 0..1_000) { 444 | randomShuffle(zip(test, test2)); 445 | uint len = uniform(0, 1_000); 446 | qsort(test[0..len], test2[0..len]); 447 | assert(isSorted(test[0..len])); 448 | assert(test == test2); 449 | } 450 | } 451 | 452 | testFloating!(qsort, float)(); 453 | testFloating!(qsort, double)(); 454 | testFloating!(qsort, real)(); 455 | 456 | auto nanArr = [double.nan, 1.0]; 457 | try { 458 | qsort(nanArr); 459 | assert(0); 460 | } catch(SortException) {} 461 | } 462 | 463 | /* Keeps track of what array merge sort data is in. This is a speed hack to 464 | * copy back and forth less.*/ 465 | /*private*/ enum { 466 | DATA, 467 | TEMP 468 | } 469 | 470 | /**Merge sort. O(N log N) time, O(N) space, small constant. Stable sort. 471 | * If last argument is a ulong* instead of an array-like type, 472 | * the dereference of the ulong* will be incremented by the bubble sort 473 | * distance between the input array and the sorted version. This is useful 474 | * in some statistics functions such as Kendall's tau.*/ 475 | T[0] mergeSort(alias compFun = "a < b", T...)(T data) 476 | if(T.length != 0) 477 | in { 478 | assert(data.length > 0); 479 | size_t len = data[0].length; 480 | foreach(array; data[1..$]) { 481 | static if(!is(typeof(array) == ulong*)) 482 | assert(array.length == len); 483 | } 484 | } do { 485 | if(data[0].length < 65) { //Avoid mem allocation. 486 | return insertionSortImpl!(compFun)(data); 487 | } 488 | static if(is(T[$ - 1] == ulong*)) { 489 | enum dl = data.length - 1; 490 | alias data[$ - 1] swapCount; 491 | } else { 492 | enum dl = data.length; 493 | alias TypeTuple!() swapCount; // Place holder. 494 | } 495 | 496 | auto keyArr = prepareForSorting!compFun(data[0]); 497 | auto toSort = TypeTuple!(keyArr, data[1..dl]); 498 | 499 | typeof(toSort) temp; 500 | auto alloc = newRegionAllocator(); 501 | foreach(i, array; temp) { 502 | temp[i] = alloc.uninitializedArray!(typeof(temp[i][0])[])(data[i].length); 503 | } 504 | 505 | uint res = mergeSortImpl!(compFun)(toSort, temp, swapCount); 506 | if(res == TEMP) { 507 | foreach(ti, array; temp) { 508 | toSort[ti][0..$] = temp[ti][0..$]; 509 | } 510 | } 511 | 512 | postProcess!compFun(data[0]); 513 | return data[0]; 514 | } 515 | 516 | unittest { 517 | uint[] test = new uint[1_000], stability = new uint[1_000]; 518 | uint[] temp1 = new uint[1_000], temp2 = new uint[1_000]; 519 | foreach(ref e; test) { 520 | e = uniform(0, 100); //Lots of ties. 521 | } 522 | foreach(i; 0..100) { 523 | ulong mergeCount = 0, bubbleCount = 0; 524 | foreach(j, ref e; stability) { 525 | e = cast(uint) j; 526 | } 527 | randomShuffle(test); 528 | uint len = uniform(0, 1_000); 529 | // Testing bubble sort distance against bubble sort, 530 | // since bubble sort distance computed by bubble sort 531 | // is straightforward, unlikely to contain any subtle bugs. 532 | bubbleSort(test[0..len].dup, &bubbleCount); 533 | if(i & 1) // Test both temp and non-temp branches. 534 | mergeSort(test[0..len], stability[0..len], &mergeCount); 535 | else 536 | mergeSortTemp(test[0..len], stability[0..len], temp1[0..len], 537 | temp2[0..len], &mergeCount); 538 | assert(bubbleCount == mergeCount); 539 | assert(isSorted(test[0..len])); 540 | foreach(j; 1..len) { 541 | if(test[j - 1] == test[j]) { 542 | assert(stability[j - 1] < stability[j]); 543 | } 544 | } 545 | } 546 | // Test without swapCounts. 
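    // (Only ordering and stability are verified here; the bubble sort distance
    // was already checked against bubbleSort in the loop above.)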
547 |     foreach(i; 0..1000) {
548 |         foreach(j, ref e; stability) {
549 |             e = cast(uint) j;
550 |         }
551 |         randomShuffle(test);
552 |         uint len = uniform(0, 1_000);
553 |         if(i & 1)  // Test both temp and non-temp branches.
554 |             mergeSort(test[0..len], stability[0..len]);
555 |         else
556 |             mergeSortTemp(test[0..len], stability[0..len], temp1[0..len],
557 |                 temp2[0..len]);
558 |         assert(isSorted(test[0..len]));
559 |         foreach(j; 1..len) {
560 |             if(test[j - 1] == test[j]) {
561 |                 assert(stability[j - 1] < stability[j]);
562 |             }
563 |         }
564 |     }
565 | 
566 |     testFloating!(mergeSort, float)();
567 |     testFloating!(mergeSort, double)();
568 |     testFloating!(mergeSort, real)();
569 | 
570 |     testFloating!(mergeSortTemp, float)();
571 |     testFloating!(mergeSortTemp, double)();
572 |     testFloating!(mergeSortTemp, real)();
573 | }
574 | 
575 | /**Merge sort, allowing caller to provide a temp variable.  This allows
576 |  * recycling instead of repeated allocations.  If D is data, T is temp,
577 |  * and U is a ulong* for calculating bubble sort distance, this can be called
578 |  * as mergeSortTemp(D, D, D, T, T, T, U) or mergeSortTemp(D, D, D, T, T, T)
579 |  * where each D has a T of corresponding type.
580 |  *
581 |  * Examples:
582 |  * ---
583 |  * int[] foo = [3, 1, 2, 4, 5].dup;
584 |  * int[] temp = new int[5];
585 |  * mergeSortTemp!("a < b")(foo, temp);
586 |  * assert(foo == [1, 2, 3, 4, 5]);  // The contents of temp will be undefined.
587 |  * foo = [3, 1, 2, 4, 5].dup;
588 |  * real[] bar = [3.14L, 15.9, 26.5, 35.8, 97.9];
589 |  * real[] temp2 = new real[5];
590 |  * mergeSortTemp(foo, bar, temp, temp2);
591 |  * assert(foo == [1, 2, 3, 4, 5]);
592 |  * assert(bar == [15.9L, 26.5, 3.14, 35.8, 97.9]);
593 |  * // The contents of both temp and temp2 will be undefined.
594 |  * ---
595 |  */
596 | T[0] mergeSortTemp(alias compFun = "a < b", T...)(T data)
597 | if(T.length != 0)
598 | in {
599 |     assert(data.length > 0);
600 |     size_t len = data[0].length;
601 |     foreach(array; data[1..$]) {
602 |         static if(!is(typeof(array) == ulong*))
603 |             assert(array.length == len);
604 |     }
605 | } do {
606 |     static if(is(T[$ - 1] == ulong*)) {
607 |         enum dl = data.length - 1;
608 |     } else {
609 |         enum dl = data.length;
610 |     }
611 | 
612 |     auto keyArr = prepareForSorting!compFun(data[0]);
613 |     auto keyTemp = cast(typeof(keyArr)) data[dl / 2];
614 |     auto toSort = TypeTuple!(
615 |         keyArr,
616 |         data[1..dl / 2],
617 |         keyTemp,
618 |         data[dl / 2 + 1..$]
619 |     );
620 | 
621 |     uint res = mergeSortImpl!(compFun)(toSort);
622 | 
623 |     if(res == TEMP) {
624 |         foreach(ti, array; toSort[0..$ / 2]) {
625 |             toSort[ti][0..$] = toSort[ti + dl / 2][0..$];
626 |         }
627 |     }
628 | 
629 |     postProcess!compFun(data[0]);
630 |     return data[0];
631 | }
632 | 
633 | /*private*/ uint mergeSortImpl(alias compFun = "a < b", T...)(T dataIn) {
634 |     static if(is(T[$ - 1] == ulong*)) {
635 |         alias dataIn[$ - 1] swapCount;
636 |         alias dataIn[0..dataIn.length / 2] data;
637 |         alias dataIn[dataIn.length / 2..$ - 1] temp;
638 |     } else {  // Make empty dummy tuple.
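        // An empty TypeTuple expands to nothing, letting the no-swapCount case
        // share the same code paths as the ulong* variant above.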
        alias TypeTuple!() swapCount;
        alias dataIn[0..dataIn.length / 2] data;
        alias dataIn[dataIn.length / 2..$] temp;
    }

    if(data[0].length < 50) {
        insertionSortImpl!(compFun)(data, swapCount);
        return DATA;
    }
    size_t half = data[0].length / 2;
    typeof(data) left, right, tempLeft, tempRight;
    foreach(ti, array; data) {
        left[ti] = array[0..half];
        right[ti] = array[half..$];
        tempLeft[ti] = temp[ti][0..half];
        tempRight[ti] = temp[ti][half..$];
    }

    /* Implementation note: The lloc, rloc stuff is a hack to avoid constantly
     * copying data back and forth between the data and temp arrays.
     * Instead of copying every time, I keep track of which array the last
     * merge went into, and only copy at the end or if the two sides ended up
     * in different arrays.*/
    uint lloc = mergeSortImpl!(compFun)(left, tempLeft, swapCount);
    uint rloc = mergeSortImpl!(compFun)(right, tempRight, swapCount);
    if(lloc == DATA && rloc == TEMP) {
        foreach(ti, array; tempLeft) {
            array[] = left[ti][];
        }
        lloc = TEMP;
    } else if(lloc == TEMP && rloc == DATA) {
        foreach(ti, array; tempRight) {
            array[] = right[ti][];
        }
    }
    if(lloc == DATA) {
        merge!(compFun)(left, right, temp, swapCount);
        return TEMP;
    } else {
        merge!(compFun)(tempLeft, tempRight, data, swapCount);
        return DATA;
    }
}

/*private*/ void merge(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;

    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;  // Length after removing swapCount.
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    static assert(dl % 3 == 0);
    alias data[0..dl / 3] left;
    alias data[dl / 3..dl * 2 / 3] right;
    alias data[dl * 2 / 3..dl] result;
    static assert(left.length == right.length && right.length == result.length);
    size_t i = 0, l = 0, r = 0;
    while(l < left[0].length && r < right[0].length) {
        if(comp(right[0][r], left[0][l])) {

            static if(is(T[$ - 1] == ulong*)) {
                *swapCount += left[0].length - l;
            }

            foreach(ti, array; result) {
                result[ti][i] = right[ti][r];
            }
            r++;
        } else {
            foreach(ti, array; result) {
                result[ti][i] = left[ti][l];
            }
            l++;
        }
        i++;
    }
    if(right[0].length > r) {
        foreach(ti, array; result) {
            result[ti][i..$] = right[ti][r..$];
        }
    } else {
        foreach(ti, array; result) {
            result[ti][i..$] = left[ti][l..$];
        }
    }
}

/**In-place merge sort, based on C++ STL's stable_sort(). O(N log^2 N)
 * time complexity, O(1) space complexity, stable. Much slower than plain
 * old mergeSort(), so only use it if you really need the O(1) space.
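 *
 * A minimal usage sketch (an illustrative addition): sorts the first array
 * and reorders the second in lockstep.
 * ---
 * auto keys = [3.0, 1.0, 2.0];
 * auto vals = [30, 10, 20];
 * mergeSortInPlace(keys, vals);
 * assert(keys == [1.0, 2.0, 3.0]);
 * assert(vals == [10, 20, 30]);
 * ---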
 */
T[0] mergeSortInPlace(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
} do {
    auto toSort = prepareForSorting!compFun(data[0]);
    mergeSortInPlaceImpl!compFun(toSort, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

/*private*/ T[0] mergeSortInPlaceImpl(alias compFun, T...)(T data) {
    if (data[0].length <= 100)
        return insertionSortImpl!(compFun)(data);

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..$ / 2];
        right[ti] = array[$ / 2..$];
    }

    mergeSortInPlace!(compFun, T)(right);
    mergeSortInPlace!(compFun, T)(left);
    mergeInPlace!(compFun)(data, data[0].length / 2);
    return data[0];
}

unittest {
    uint[] test = new uint[1_000], stability = new uint[1_000];
    foreach(ref e; test) {
        e = uniform(0, 100);  // Lots of ties.
    }
    uint[] test2 = test.dup;
    foreach(i; 0..1000) {
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(zip(test, test2));
        uint len = uniform(0, 1_000);
        mergeSortInPlace(test[0..len], test2[0..len], stability[0..len]);
        assert(isSorted(test[0..len]));
        assert(test == test2);
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }

    testFloating!(mergeSortInPlace, float)();
    testFloating!(mergeSortInPlace, double)();
    testFloating!(mergeSortInPlace, real)();
}

// Loosely based on C++ STL's __merge_without_buffer().
/*private*/ void mergeInPlace(alias compFun = "a < b", T...)(T data, size_t middle) {
    alias binaryFun!(compFun) comp;

    static size_t largestLess(T)(T[] data, T value) {
        return assumeSorted!(comp)(data).lowerBound(value).length;
    }

    static size_t smallestGr(T)(T[] data, T value) {
        return data.length -
            assumeSorted!(comp)(data).upperBound(value).length;
    }

    if (data[0].length < 2 || middle == 0 || middle == data[0].length) {
        return;
    }

    if (data[0].length == 2) {
        if(comp(data[0][1], data[0][0])) {
            foreach(array; data) {
                auto temp = array[0];
                array[0] = array[1];
                array[1] = temp;
            }
        }
        return;
    }

    size_t half1, half2;

    if (middle > data[0].length - middle) {
        half1 = middle / 2;
        auto pivot = data[0][half1];
        half2 = largestLess(data[0][middle..$], pivot);
    } else {
        half2 = (data[0].length - middle) / 2;
        auto pivot = data[0][half2 + middle];
        half1 = smallestGr(data[0][0..middle], pivot);
    }

    foreach(array; data) {
        bringToFront(array[half1..middle], array[middle..middle + half2]);
    }
    size_t newMiddle = half1 + half2;

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..newMiddle];
        right[ti] = array[newMiddle..$];
    }

    mergeInPlace!(compFun, T)(left, half1);
    mergeInPlace!(compFun, T)(right, half2 + middle - newMiddle);
}
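
/* Worked example of a single mergeInPlace() step (an illustrative addition):
 * with data = [1, 3, 5, 2, 4, 6] and middle = 3, the pivot is drawn from the
 * right half:  half2 = 1, pivot = data[half2 + middle] = 4, and half1 = 2,
 * since two elements of the sorted left half are <= 4.  bringToFront() then
 * rotates [5, 2] into [2, 5], giving [1, 3, 2, 5, 4, 6] with newMiddle = 3,
 * and the two recursive calls merge [1, 3 | 2] and [5 | 4, 6] independently.
 */
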
/**Heap sort. Unstable, O(N log N) time average and worst case, O(1) space,
 * large constant term in the time complexity.*/
T[0] heapSort(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
} do {
    auto toSort = prepareForSorting!compFun(data[0]);
    heapSortImpl!compFun(toSort, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

/*private*/ T[0] heapSortImpl(alias compFun, T...)(T input) {
    // Heap sort has such a huge constant that insertion sort is faster up to
    // about N = 100 (for reals; even larger for smaller types).
    if(input[0].length <= 100) {
        return insertionSortImpl!(compFun)(input);
    }

    makeMultiHeap!(compFun)(input);
    for(size_t end = input[0].length - 1; end > 0; end--) {
        foreach(ti, ia; input) {
            auto temp = ia[end];
            ia[end] = ia[0];
            ia[0] = temp;
        }
        multiSiftDown!(compFun)(input, 0, end);
    }
    return input[0];
}

unittest {
    uint[] test = new uint[1_000];
    foreach(ref e; test) {
        e = uniform(0, 100_000);
    }
    auto test2 = test.dup;
    foreach(i; 0..1_000) {
        randomShuffle(zip(test, test2));
        uint len = uniform(0, 1_000);
        heapSort(test[0..len], test2[0..len]);
        assert(isSorted(test[0..len]));
        assert(test == test2);
    }

    testFloating!(heapSort, float)();
    testFloating!(heapSort, double)();
    testFloating!(heapSort, real)();
}

/**Heapifies the parallel arrays in input, ordered by the first array.*/
void makeMultiHeap(alias compFun = "a < b", T...)(T input) {
    if(input[0].length < 2)
        return;
    alias binaryFun!(compFun) comp;
    for(sizediff_t start = (input[0].length - 1) / 2; start >= 0; start--) {
        multiSiftDown!(compFun)(input, start, input[0].length);
    }
}

/**Restores the heap property of the parallel arrays in input by sifting the
 * element at root down within the heap that ends at index end.*/
void multiSiftDown(alias compFun = "a < b", T...)
(T input, size_t root, size_t end) {
    alias binaryFun!(compFun) comp;
    alias input[0] a;
    while(root * 2 + 1 < end) {
        size_t child = root * 2 + 1;
        if(child + 1 < end && comp(a[child], a[child + 1])) {
            child++;
        }
        if(comp(a[root], a[child])) {
            foreach(ia; input) {
                auto temp = ia[root];
                ia[root] = ia[child];
                ia[child] = temp;
            }
            root = child;
        }
        else return;
    }
}
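
/* A minimal sketch of the two heap primitives above (an illustrative
 * addition):  makeMultiHeap() heapifies parallel arrays, ordered by the
 * first one, and multiSiftDown() restores the heap property after the root
 * changes.
 *
 *     auto keys = [2, 9, 4, 1];
 *     auto vals = [20, 90, 40, 10];
 *     makeMultiHeap(keys, vals);              // Default "a < b" -> max-heap.
 *     assert(keys[0] == 9 && vals[0] == 90);  // Largest key at the root.
 */
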
/**Insertion sort. O(N^2) worst- and average-case time, O(1) space, VERY
 * small constant, which is why it's useful for sorting small subarrays in
 * divide and conquer algorithms. If the last argument is a ulong*, increments
 * the dereference of this argument by the bubble sort distance between the
 * input array and the sorted version of the input.*/
T[0] insertionSort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        static if(!is(typeof(array) == ulong*))
            assert(array.length == len);
    }
} do {
    auto toSort = prepareForSorting!compFun(data[0]);
    insertionSortImpl!compFun(toSort, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

private template IndexType(T) {
    alias typeof(T.init[0]) IndexType;
}

/*private*/ T[0] insertionSortImpl(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    alias data[0] keyArray;
    if(keyArray.length < 2) {
        return keyArray;
    }

    // Yes, I measured this:  caching this value is actually faster on DMD.
    immutable maxJ = keyArray.length - 1;
    for(size_t i = keyArray.length - 2; i != size_t.max; --i) {
        size_t j = i;

        Tuple!(staticMap!(IndexType, typeof(data[0..dl]))) temp = void;
        foreach(ti, Type; typeof(data[0..dl])) {
            static if(hasElaborateAssign!Type) {
                emplace(&(temp.field[ti]), data[ti][i]);
            } else {
                temp.field[ti] = data[ti][i];
            }
        }

        for(; j < maxJ && comp(keyArray[j + 1], temp.field[0]); ++j) {
            // It's faster to do all copying here than to call rotateLeft()
            // later, probably due to better ILP.
            foreach(array; data[0..dl]) {
                array[j] = array[j + 1];
            }
        }

        foreach(ti, Unused; typeof(temp.field)) {
            data[ti][j] = temp.field[ti];
        }

        static if(is(typeof(swapCount))) {
            *swapCount += (j - i);  // Accumulate the bubble sort distance.
        }
    }

    return keyArray;
}

unittest {
    uint[] test = new uint[100], stability = new uint[100];
    foreach(ref e; test) {
        e = uniform(0, 100);  // Lots of ties.
    }
    foreach(i; 0..1_000) {
        ulong insertCount = 0, bubbleCount = 0;
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(test);
        uint len = uniform(0, 100);
        // Testing bubble sort distance against bubble sort, since the bubble
        // sort distance computed by bubble sort is straightforward and
        // unlikely to contain any subtle bugs.
        bubbleSort(test[0..len].dup, &bubbleCount);
        insertionSort(test[0..len], stability[0..len], &insertCount);
        assert(bubbleCount == insertCount);
        assert(isSorted(test[0..len]));
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }
}
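
/* A minimal usage sketch of insertionSort()'s distance counting (an
 * illustrative addition):  the trailing ulong* accumulates the bubble sort
 * distance, i.e. the number of inversions in the input.
 *
 *     ulong inversions = 0;
 *     auto arr = [3, 1, 2];
 *     insertionSort(arr, &inversions);
 *     assert(arr == [1, 2, 3] && inversions == 2);
 */
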
// Kept around only because it's easy to implement, and therefore good for
// testing more complex sort functions against. Especially useful for bubble
// sort distance, since it's straightforward with a bubble sort, and not with
// a merge sort or insertion sort.
version(unittest) {
    T[0] bubbleSort(alias compFun = "a < b", T...)(T data) {
        alias binaryFun!(compFun) comp;
        static if(is(T[$ - 1] == ulong*))
            enum dl = data.length - 1;
        else enum dl = data.length;
        if(data[0].length < 2)
            return data[0];
        bool swapExecuted;
        foreach(i; 0..data[0].length) {
            swapExecuted = false;
            foreach(j; 1..data[0].length) {
                if(comp(data[0][j], data[0][j - 1])) {
                    swapExecuted = true;
                    static if(is(T[$ - 1] == ulong*))
                        (*(data[$ - 1]))++;
                    foreach(array; data[0..dl])
                        swap(array[j - 1], array[j]);
                }
            }
            if(!swapExecuted) return data[0];
        }
        return data[0];
    }
}

unittest {
    // Sanity check for bubble sort distance.
    uint[] test = [4, 5, 3, 2, 1];
    ulong dist = 0;
    bubbleSort(test, &dist);
    assert(dist == 9);
    dist = 0;
    test = [6, 1, 2, 4, 5, 3];
    bubbleSort(test, &dist);
    assert(dist == 7);
}

/**Returns the kth largest/smallest element (depending on compFun, 0-indexed)
 * in the input array in average-case O(N) time. Allocates memory, does not
 * modify the input array.*/
T quickSelect(alias compFun = "a < b", T)(T[] data, sizediff_t k) {
    auto alloc = newRegionAllocator();
    auto dataDup = alloc.array(data);
    return partitionK!(compFun)(dataDup, k);
}
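
/* A minimal usage sketch (an illustrative addition):  quickSelect() returns
 * an order statistic and leaves its input untouched.
 *
 *     auto data = [5, 1, 4, 2, 3];
 *     assert(quickSelect(data, 0) == 1);  // Smallest, under default "a < b".
 *     assert(quickSelect(data, 3) == 4);  // 4th smallest (0-indexed).
 *     assert(data == [5, 1, 4, 2, 3]);    // Input unmodified.
 */
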
/**Partitions the input data according to compFun, such that position k contains
 * the kth largest/smallest element according to compFun. For all elements e
 * with indices < k, !compFun(data[k], e) is guaranteed to be true. For all
 * elements e with indices > k, !compFun(e, data[k]) is guaranteed to be true.
 * For example, if compFun is "a < b", all elements with indices < k will be
 * <= data[k], and all elements with indices > k will be >= data[k].
 * Reorders any additional input arrays in lockstep.
 *
 * Examples:
 * ---
 * auto foo = [3, 1, 5, 4, 2].dup;
 * auto secondSmallest = partitionK(foo, 1);
 * assert(secondSmallest == 2);
 * foreach(elem; foo[0..1]) {
 *     assert(elem <= foo[1]);
 * }
 * foreach(elem; foo[2..$]) {
 *     assert(elem >= foo[1]);
 * }
 * ---
 *
 * Returns: The kth element of the array.
 */
ElementType!(T[0]) partitionK(alias compFun = "a < b", T...)(T data, ptrdiff_t k)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
} do {
    // Don't use the float-to-int trick here:  it's actually slower, because
    // the main part of the algorithm is O(N), not O(N log N).
    return partitionKImpl!compFun(data, k);
}

/*private*/ ElementType!(T[0]) partitionKImpl(alias compFun, T...)(T data, ptrdiff_t k) {
    alias binaryFun!(compFun) comp;

    // Move the median-of-3 pivot to the end of the arrays, in lockstep.
    {
        immutable size_t med3 = medianOf3!(comp)(data[0]);
        foreach(array; data) {
            auto temp = array[med3];
            array[med3] = array[$ - 1];
            array[$ - 1] = temp;
        }
    }

    ptrdiff_t lessI = -1, greaterI = data[0].length - 1;
    auto pivot = data[0][$ - 1];
    while(true) {
        while(comp(data[0][++lessI], pivot)) {}
        while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {}

        if(lessI < greaterI) {
            foreach(array; data) {
                auto temp = array[lessI];
                array[lessI] = array[greaterI];
                array[greaterI] = temp;
            }
        } else break;
    }
    foreach(array; data) {
        auto temp = array[lessI];
        array[lessI] = array[$ - 1];
        array[$ - 1] = temp;
    }

    if((greaterI < k && lessI >= k) || lessI == k) {
        return data[0][k];
    } else if(lessI < k) {
        foreach(ti, array; data) {
            data[ti] = array[lessI + 1..$];
        }
        return partitionK!(compFun, T)(data, k - lessI - 1);
    } else {
        foreach(ti, array; data) {
            data[ti] = array[0..min(greaterI + 1, lessI)];
        }
        return partitionK!(compFun, T)(data, k);
    }
}

template ArrayElemType(T : T[]) {
    alias T ArrayElemType;
}

unittest {
    enum n = 1000;
    uint[] test = new uint[n];
    uint[] test2 = new uint[n];
    uint[] lockstep = new uint[n];
    foreach(ref e; test) {
        e = uniform(0, 1000);
    }
    foreach(i; 0..1_000) {
        test2[] = test[];
        lockstep[] = test[];
        uint len = uniform(0, n - 1) + 1;
        qsort!("a > b")(test2[0..len]);
        int k = uniform(0, len);
        auto qsRes = partitionK!("a > b")(test[0..len], lockstep[0..len], k);
        assert(qsRes == test2[k]);
        foreach(elem; test[0..k]) {
            assert(elem >= test[k]);
        }
        foreach(elem; test[k + 1..len]) {
            assert(elem <= test[k]);
        }
        assert(test == lockstep);
    }
}

/**Given a set of data points entered through the put function, this output
 * range maintains the invariant that the top N according to compFun will be
 * contained in the data structure. Uses a heap internally, O(log N) insertion
 * time. Good for finding the largest/smallest N elements of a very large
 * dataset that cannot be sorted quickly in its entirety, and may not even fit
 * in memory. If fewer than N datapoints have been entered, all are contained
 * in the structure.
 *
 * Examples:
 * ---
 * Random gen;
 * gen.seed(unpredictableSeed);
 * uint[] nums = seq(0U, 100U);
 * auto less = TopN!(uint, "a < b")(10);
 * auto more = TopN!(uint, "a > b")(10);
 * randomShuffle(nums, gen);
 * foreach(n; nums) {
 *     less.put(n);
 *     more.put(n);
 * }
 * assert(less.getSorted == [0U, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
 * assert(more.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
 * ---
 */
struct TopN(T, alias compFun = "a > b") {
private:
    alias binaryFun!(compFun) comp;
    uint n;
    uint nAdded;

    T[] nodes;
public:
    /**The parameter ntop controls how many elements are retained.*/
    this(uint ntop) {
        n = ntop;
        nodes = new T[n];
    }

    /**Inserts an element into the TopN struct.*/
    void put(T elem) {
        if(nAdded < n) {
            nodes[nAdded] = elem;
            if(nAdded == n - 1) {
                makeMultiHeap!(comp)(nodes);
            }
            nAdded++;
        } else if(comp(elem, nodes[0])) {
            nodes[0] = elem;
            multiSiftDown!(comp)(nodes, 0, nodes.length);
        }
    }

    /**Get the elements currently in the struct. Returns a reference to the
     * internal state; the elements will be in an arbitrary order. Cheap.*/
    T[] getElements() {
        return nodes[0..min(n, nAdded)];
    }

    /**Returns the elements sorted by compFun. The returned array is a newly
     * allocated copy of the internal state. Not cheap.*/
    T[] getSorted() {
        return qsort!(comp)(nodes[0..min(n, nAdded)].dup);
    }
}

unittest {
    alias TopN!(uint, "a < b") TopNLess;
    alias TopN!(uint, "a > b") TopNGreater;
    Random gen;
    gen.seed(unpredictableSeed);
    uint[] nums = new uint[100];
    foreach(i, ref n; nums) {
        n = cast(uint) i;
    }
    foreach(i; 0..100) {
        auto less = TopNLess(10);
        auto more = TopNGreater(10);
        randomShuffle(nums, gen);
        foreach(n; nums) {
            less.put(n);
            more.put(n);
        }
        assert(less.getSorted == [0U, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert(more.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
    }
    foreach(i; 0..100) {
        auto less = TopNLess(10);
        auto more = TopNGreater(10);
        randomShuffle(nums, gen);
        foreach(n; nums[0..5]) {
            less.put(n);
            more.put(n);
        }
        assert(less.getSorted == qsort!("a < b")(nums[0..5]));
        assert(more.getSorted == qsort!("a > b")(nums[0..5]));
    }
}
--------------------------------------------------------------------------------