├── .gitignore ├── .gitmodules ├── .travis.yml ├── HLearn.cabal ├── LICENSE ├── README.md ├── Setup.hs ├── bench └── allknn │ ├── AllKnn.java │ ├── README.md │ ├── allknn.R │ ├── allknn.flann │ ├── allknn.julia │ ├── allknn.scikit │ ├── cover_tree │ ├── COPYING │ ├── Makefile │ ├── README │ ├── cover_tree.cc │ ├── cover_tree.h │ ├── knn.cc │ ├── point.cc │ ├── point.h │ ├── stack.h │ └── test_nn.cc │ ├── runtest.sh │ └── time2sec.hs ├── cbits └── emd.c ├── examples ├── example0001-optimization-univariate.hs ├── example0002-optimization-multivariate.hs └── example0003-classification.hs ├── executables ├── README.md └── hlearn-allknn │ └── Main.hs ├── install ├── README.md ├── ubuntu-precise-extras.sh └── ubuntu-precise.sh ├── src └── HLearn │ ├── Classifiers │ └── Linear.hs │ ├── Data │ ├── Graph.hs │ ├── Image.hs │ ├── LoadData.hs │ ├── SpaceTree.hs │ └── SpaceTree │ │ ├── Algorithms.hs │ │ ├── Algorithms │ │ ├── Correlation.hs │ │ ├── KernelDensityEstimation.hs │ │ ├── NearestNeighbor.Old.hs │ │ ├── NearestNeighbor.hs │ │ └── RangeSearch.hs │ │ ├── Algorithms_Specialized.hs │ │ ├── CoverTree.hs │ │ ├── CoverTree │ │ └── Unsafe.hs │ │ ├── CoverTree_Specialized.hs │ │ └── Diagrams.hs │ ├── Evaluation │ ├── CrossValidation.hs │ ├── CrossValidationData.hs │ └── CrossValidationHom.hs │ ├── History.hs │ ├── History │ └── Timing.hs │ ├── Models │ └── Distributions.hs │ └── Optimization │ ├── Amoeba.hs │ ├── Conic.hs │ ├── Multivariate.hs │ ├── StepSize.hs │ ├── StepSize │ ├── AlmeidaLanglois.hs │ ├── Const.hs │ └── Linear.hs │ ├── StochasticGradientDescent.hs │ ├── TestFunctions.hs │ └── Univariate.hs ├── stack.yaml └── test ├── BashTests.hs ├── QuickCheck.hs ├── allknn-mlpack └── runtest.sh └── allknn-verify ├── dataset-10000x2.csv ├── dataset-10000x2.csv-distances ├── dataset-10000x2.csv-neighbors ├── dataset-10000x20.csv ├── dataset-10000x20.csv-distances ├── dataset-10000x20.csv-neighbors └── runtest.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.Rout 2 | *.csv 3 | *.class 4 | *.o 5 | 6 | results 7 | cover_tree 8 | 9 | hlearn-allknn 10 | hlearn-linear 11 | *.swp 12 | *.swo 13 | dist/ 14 | gitignore/ 15 | examples/old/ 16 | scripts/ 17 | 18 | .cabal-sandbox/ 19 | cabal.sandbox.config 20 | 21 | .stack-work/ 22 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "datasets"] 2 | path = datasets 3 | url = https://github.com/mikeizbicki/datasets.git 4 | [submodule "subhask"] 5 | path = subhask 6 | url = https://github.com/mikeizbicki/subhask.git 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # NB: don't set `language: haskell` here 2 | 3 | # The following enables several GHC versions to be tested; often it's enough to test only against the last release in a major GHC version. Feel free to omit lines listings versions you don't need/want testing for. 
4 | env: 5 | - CABALVER=1.22 GHCVER=7.10.1 LLVMVER=3.5 6 | - CABALVER=1.22 GHCVER=7.10.2 LLVMVER=3.5 7 | # - CABALVER=head GHCVER=head # see section about GHC HEAD snapshots 8 | 9 | before_install: 10 | - export PATH="~/.cabal/bin:/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:/usr/bin:$PATH" 11 | 12 | install: 13 | - travis_retry sudo curl https://raw.githubusercontent.com/mikeizbicki/HLearn/travis/install/ubuntu-precise.sh | sh 14 | #- mkdir /home/travis/.cabal 15 | #- echo 'optimization: False' >> /home/travis/.cabal/config 16 | # Sometimes GHC uses too much RAM for the travis server, causing the install script to fail. 17 | # In this case, we can fix the problem by reinstalling without optimizations. 18 | 19 | # Here starts the actual work to be performed for the package under test; any command which exits with a non-zero exit code causes the build to fail. 20 | script: 21 | - if [ -f configure.ac ]; then autoreconf -i; fi 22 | - cabal configure -fO0 --enable-tests --enable-benchmarks -v2 # -v2 provides useful information for debugging 23 | - cabal build # this builds all libraries and executables (including tests/benchmarks) 24 | - cabal test 25 | - cabal check 26 | - cabal sdist # tests that a source-distribution can be generated 27 | 28 | # Check that the resulting source distribution can be built & installed. 29 | # If there are no other `.tar.gz` files in `dist`, this can be even simpler: 30 | # `cabal install --force-reinstalls dist/*-*.tar.gz` 31 | - SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && 32 | (cd dist && cabal install --force-reinstalls "$SRC_TGZ") 33 | -------------------------------------------------------------------------------- /HLearn.cabal: -------------------------------------------------------------------------------- 1 | Name: HLearn 2 | Version: 2.0.1.0 3 | Synopsis: Homomorphic machine learning 4 | Category: Data Mining, Machine Learning, Algorithms, Data structures 5 | License: BSD3 6 | License-file: LICENSE 7 | Author: Mike Izbicki 8 | Maintainer: mike@izbicki.me 9 | Build-Type: Simple 10 | Cabal-Version: >=1.8 11 | homepage: http://github.com/mikeizbicki/HLearn/ 12 | bug-reports: http://github.com/mikeizbicki/HLearn/issues 13 | 14 | Description: 15 | HLearn is an experimental but high performance machine learning library. 16 | For example, it currently has the world's fastest nearest neighbor implementation. 17 | It uses the subhask library for fast numeric computations. 18 | The README on github contains much more information.
19 | 20 | source-repository head 21 | type: git 22 | location: http://github.com/mikeizbicki/hlearn 23 | 24 | -------------------------------------------------------------------------------- 25 | 26 | Library 27 | Build-Depends: 28 | -- common dependencies 29 | base >= 4.8 && <4.9, 30 | subhask == 0.1.1.0, 31 | 32 | -- control 33 | mtl >= 2.1.2, 34 | 35 | -- i/o 36 | ansi-terminal >= 0.6.1.1, 37 | directory >= 1.2, 38 | time >= 1.4.2 39 | 40 | -- visualization 41 | -- diagrams-svg >= 0.6, 42 | -- diagrams-lib >= 1.3, 43 | -- process >= 1.1 44 | -- graphviz >= 2999.16 45 | 46 | hs-source-dirs: 47 | src 48 | 49 | Exposed-modules: 50 | 51 | HLearn.History 52 | HLearn.History.Timing 53 | 54 | -- HLearn.Data.Graph 55 | -- HLearn.Data.Image 56 | HLearn.Data.LoadData 57 | HLearn.Data.SpaceTree 58 | HLearn.Data.SpaceTree.CoverTree 59 | HLearn.Data.SpaceTree.CoverTree_Specialized 60 | HLearn.Data.SpaceTree.CoverTree.Unsafe 61 | HLearn.Data.SpaceTree.Algorithms 62 | HLearn.Data.SpaceTree.Algorithms_Specialized 63 | 64 | -- HLearn.Evaluation.CrossValidation 65 | 66 | HLearn.Classifiers.Linear 67 | HLearn.Models.Distributions 68 | 69 | HLearn.Optimization.Multivariate 70 | HLearn.Optimization.Univariate 71 | 72 | -- HLearn.Optimization.Amoeba 73 | -- HLearn.Optimization.Conic 74 | -- HLearn.Optimization.StepSize 75 | -- HLearn.Optimization.StochasticGradientDescent 76 | -- HLearn.Optimization.StepSize.Linear 77 | -- HLearn.Optimization.StepSize.Const 78 | -- HLearn.Optimization.StepSize.AlmeidaLanglois 79 | 80 | Other-modules: 81 | 82 | Extensions: 83 | FlexibleInstances 84 | FlexibleContexts 85 | MultiParamTypeClasses 86 | FunctionalDependencies 87 | UndecidableInstances 88 | ScopedTypeVariables 89 | BangPatterns 90 | TypeOperators 91 | GeneralizedNewtypeDeriving 92 | TypeFamilies 93 | StandaloneDeriving 94 | GADTs 95 | KindSignatures 96 | ConstraintKinds 97 | DeriveDataTypeable 98 | RankNTypes 99 | ImpredicativeTypes 100 | 101 | DataKinds 102 | PolyKinds 103 | AutoDeriveTypeable 104 | TemplateHaskell 105 | QuasiQuotes 106 | RebindableSyntax 107 | NoImplicitPrelude 108 | UnboxedTuples 109 | MagicHash 110 | PolymorphicComponents 111 | ExplicitNamespaces 112 | EmptyDataDecls 113 | 114 | PartialTypeSignatures 115 | MultiWayIf 116 | 117 | -------------------------------------------------------------------------------- 118 | 119 | Test-Suite BashTests 120 | type: exitcode-stdio-1.0 121 | hs-source-dirs: test 122 | main-is: BashTests.hs 123 | build-depends: base, process 124 | 125 | --------------------------------------- 126 | 127 | Test-Suite QuickCheck-Unoptimized 128 | type: exitcode-stdio-1.0 129 | hs-source-dirs: test 130 | main-is: QuickCheck.hs 131 | 132 | ghc-options: 133 | -O0 134 | 135 | build-depends: 136 | subhask, 137 | HLearn, 138 | test-framework-quickcheck2 >= 0.3.0, 139 | test-framework >= 0.8.0 140 | 141 | Test-Suite QuickCheck-Optimized 142 | type: exitcode-stdio-1.0 143 | hs-source-dirs: test 144 | main-is: QuickCheck.hs 145 | 146 | ghc-options: 147 | -O1 148 | 149 | build-depends: 150 | subhask, 151 | HLearn, 152 | test-framework-quickcheck2 >= 0.3.0, 153 | test-framework >= 0.8.0 154 | 155 | --------------------------------------- 156 | 157 | Test-Suite Example0001 158 | type: exitcode-stdio-1.0 159 | hs-source-dirs: examples 160 | main-is: example0001-optimization-univariate.hs 161 | build-depends: HLearn, subhask, base 162 | 163 | Test-Suite Example0002 164 | type: exitcode-stdio-1.0 165 | hs-source-dirs: examples 166 | main-is: example0002-optimization-multivariate.hs 167 | 
build-depends: HLearn, subhask, base 168 | 169 | Test-Suite Example0003 170 | type: exitcode-stdio-1.0 171 | hs-source-dirs: examples 172 | main-is: example0003-classification.hs 173 | build-depends: HLearn, subhask, base 174 | 175 | -------------------------------------------------------------------------------- 176 | 177 | executable hlearn-allknn 178 | main-is: Main.hs 179 | hs-source-dirs: executables/hlearn-allknn 180 | 181 | ghc-options: 182 | -threaded 183 | -rtsopts 184 | 185 | -funbox-strict-fields 186 | -fexcess-precision 187 | 188 | -- -prof -osuf p_o -hisuf p_hi 189 | 190 | -fllvm 191 | -optlo-O3 192 | -optlo-enable-fp-mad 193 | -optlo-enable-no-infs-fp-math 194 | -optlo-enable-no-nans-fp-math 195 | -optlo-enable-unsafe-fp-math 196 | 197 | -- -ddump-to-file 198 | -- -ddump-rule-firings 199 | -- -ddump-rule-rewrites 200 | -- -ddump-rules 201 | -- -ddump-simpl 202 | -- -ddump-simpl-stats 203 | -- -dppr-debug 204 | -- -dsuppress-module-prefixes 205 | -- -dsuppress-uniques 206 | -- -dsuppress-idinfo 207 | -- -dsuppress-coercions 208 | -- -dsuppress-type-applications 209 | 210 | -- -ddump-cmm 211 | 212 | build-depends: 213 | base , 214 | HLearn , 215 | subhask , 216 | 217 | cmdargs, 218 | MonadRandom 219 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Michael Izbicki 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the author nor the names of his contributors 17 | may be used to endorse or promote products derived from this software 18 | without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE CONTRIBUTORS ``AS IS'' AND ANY EXPRESS 21 | OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 22 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE FOR 24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 26 | OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 27 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 28 | STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 29 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 30 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HLearn 2 | 3 | 4 | 5 | HLearn is a high performance machine learning library written in [Haskell](http://haskell.org). 6 | For example, it currently has the fastest nearest neighbor implementation for arbitrary metric spaces (see [this blog post](http://izbicki.me)). 7 | 8 | HLearn is also a research project. 
9 | The research goal is to discover the "best possible" interface for machine learning.
10 | This involves two competing demands:
11 | The library should be as fast as low-level libraries written in C/C++/Fortran/Assembly;
12 | but it should be as flexible as libraries written in high level languages like Python/R/Matlab.
13 | [Julia](http://julialang.org/) is making amazing progress in this direction,
14 | but HLearn is more ambitious.
15 | In particular, HLearn's goal is to be *faster* than the low level languages and *more flexible* than the high level languages.
16 |
17 | To achieve this goal, HLearn uses a very different interface than standard learning libraries.
18 | The H in HLearn stands for three separate concepts that are fundamental to HLearn's design:
19 |
20 | 1. The H stands for [Haskell](http://haskell.org).
21 | Machine learning is about estimating *functions* from data,
22 | so it makes sense that a functional programming language would be well suited for machine learning.
23 | But functional programming languages are not widely used in machine learning because they traditionally lack strong support for the fast numerical computations required for learning algorithms.
24 | HLearn uses the [SubHask](http://github.com/mikeizbicki/subhask) library to get this fast numeric support in Haskell.
25 | The two libraries are being developed in tandem with each other.
26 |
27 |
28 |
29 |
30 | 1. The H stands for [Homomorphisms](https://en.wikipedia.org/wiki/Homomorphism).
31 | Homomorphisms are a fundamental concept in [abstract algebra](https://en.wikipedia.org/wiki/Abstract_algebra),
32 | and HLearn exploits the algebraic structures inherent in learning systems.
33 | The following table gives a brief overview of what these structures give us:
34 |
35 | | Structure     | What we get                            |
36 | |:--------------|:--------------------------------------|
37 | | Monoid        | parallel batch training                |
38 | | Monoid        | online training                        |
39 | | Monoid        | fast cross-validation                  |
40 | | Abelian group | "untraining" of data points            |
41 | | Abelian group | more fast cross-validation             |
42 | | R-Module      | weighted data points                   |
43 | | Vector space  | fractionally weighted data points      |
44 | | Functor       | fast simple preprocessing of data      |
45 | | Monad         | fast complex preprocessing of data     |
46 |
47 | 1. The H stands for the [History monad](https://github.com/mikeizbicki/HLearn/blob/master/src/HLearn/History.hs).
48 | One of the most difficult tasks of developing a new learning algorithm is debugging the optimization procedure.
49 | There has previously been essentially no work on making this debugging process easier,
50 | and the `History` monad tries to solve this problem.
51 | It lets you thread debugging information throughout the optimization code *without modifying the original code*.
52 | Furthermore, there is no runtime overhead associated with this technique.
53 |
54 | The downside of HLearn's ambition is that it currently does not implement many of the popular machine learning techniques.
55 |
56 | ## More Documentation
57 |
58 | Due to the rapid pace of development, HLearn's documentation is sparse.
59 | That said, the [examples](https://github.com/mikeizbicki/HLearn/tree/master/examples) folder is a good place to start.
60 | The haddock documentation embedded within the code is decent;
61 | but unfortunately, hackage is unable to compile the haddocks because it uses an older version of GHC.
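To make the monoid rows of the table above concrete, here is a small self-contained sketch of why a `Monoid` instance yields parallel batch training. This is plain Haskell with a toy model, *not* HLearn's actual interface; the names `Gaussian`, `train1`, `train`, and `trainParallel` are made up for this example:

```haskell
import Data.List (foldl')

-- A toy model: the sufficient statistics of a 1-d Gaussian.
data Gaussian = Gaussian
    { count :: !Int     -- number of points seen
    , sumX  :: !Double  -- sum of the points
    , sumXX :: !Double  -- sum of the squared points
    }

-- Merging two trained models just adds their statistics.
instance Monoid Gaussian where
    mempty = Gaussian 0 0 0
    mappend (Gaussian n s ss) (Gaussian n' s' ss') =
        Gaussian (n + n') (s + s') (ss + ss')

-- Train on a single point.
train1 :: Double -> Gaussian
train1 x = Gaussian 1 x (x*x)

-- Batch training is then a fold, and it is a monoid homomorphism:
-- train (xs ++ ys) == train xs `mappend` train ys.
train :: [Double] -> Gaussian
train = foldl' (\m x -> m `mappend` train1 x) mempty

-- The homomorphism law is what licenses parallelism: train each
-- chunk on its own core, then merge the partial models.
trainParallel :: [[Double]] -> Gaussian
trainParallel = mconcat . map train
```

The same law gives online training for free: to incorporate a new point into an already-trained model, `mappend` the old model with `train1 x` instead of retraining from scratch, and the result matches sequential training (up to floating-point rounding).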
62 |
63 | HLearn has several academic papers:
64 |
65 | * ICML15 - [Faster Cover Trees](http://izbicki.me/public/papers/icml2015-faster-cover-trees.pdf)
66 | * ICML13 - [Algebraic Classifiers: a generic approach to fast cross-validation, online training, and parallel training](http://izbicki.me/public/papers/icml2013-algebraic-classifiers.pdf)
67 | * TFP13 - [HLearn: A Machine Learning Library for Haskell](http://izbicki.me/public/papers/tfp2013-hlearn-a-machine-learning-library-for-haskell.pdf)
68 |
69 | There are also a number of blog posts on [my personal website](http://izbicki.me).
70 | Unfortunately, they are mostly out of date with the latest version of HLearn.
71 | They might help you understand some of the main concepts in HLearn, but the code they use won't work at all.
72 |
73 | * [The categorical distribution's monoid/group/module Structure](http://izbicki.me/blog/the-categorical-distributions-algebraic-structure)
74 | * [The categorical distribution's functor/monad structure](http://izbicki.me/blog/functors-and-monads-for-analyzing-data)
75 | * [Markov Networks, monoids, and futurama](http://izbicki.me/blog/markov-networks-monoids-and-futurama)
76 | * [Solving NP-complete problems with HLearn, and how to write your own HomTrainer instances](http://izbicki.me/public/papers/monoids-for-approximating-np-complete-problems.pdf)
77 | * [Nuclear weapon statistics using monoids, groups, and modules](http://izbicki.me/blog/nuclear-weapon-statistics-using-monoids-groups-and-modules-in-haskell)
78 | * [Gaussian distributions form a monoid](http://izbicki.me/blog/gausian-distributions-are-monoids)
79 | * [HLearn cross-validates >400x faster than Weka](http://izbicki.me/blog/hlearn-cross-validates-400x-faster-than-weka)
80 | * [HLearn's code is shorter and clearer than Weka's](http://izbicki.me/blog/hlearns-code-is-shorter-and-clearer-than-wekas)
81 |
82 | ## Contributing
83 |
84 |
85 | I'd love to have you contribute, and I'd be happy to help you get started!
86 | Just [create an issue](https://github.com/mikeizbicki/hlearn/issues) to let me know you're interested and we can work something out.
87 |
-------------------------------------------------------------------------------- /Setup.hs: --------------------------------------------------------------------------------
1 | import Distribution.Simple
2 | main = defaultMain
3 |
-------------------------------------------------------------------------------- /bench/allknn/AllKnn.java: --------------------------------------------------------------------------------
1 | /**
2 | * This file is modified from https://github.com/zoq/benchmarks/blob/master/methods/weka/allknn.py
3 | */
4 |
5 | import java.io.*;
6 | import weka.core.*;
7 | import weka.core.neighboursearch.KDTree;
8 | import weka.core.converters.ConverterUtils.DataSource;
9 |
10 | /**
11 | * This class uses the weka library to implement All K-Nearest-Neighbors.
12 | */
13 | public class AllKnn {
14 | private static final String USAGE = String
15 | .format(" This program will calculate the all k-nearest-neighbors of a set\n"
16 | + "of points using kd-trees. You may specify a separate set of\n"
17 | + "reference points and query points, or just a reference set which\n"
18 | + "will be used as both the reference and query set.\n\n"
19 | + "Required options:\n"
20 | + "-r [string] File containing the reference dataset.\n"
21 | + "-k [int] Number of nearest neighbors to find.\n\n"
22 | + "Options:\n"
23 | + "-l [int] Leaf size for tree building. Default value 20.\n"
24 | + "-q [string] File containing query points (optional).\n"
25 | + " Default value ''.\n");
26 |
27 | public static void main(String args[]) {
28 | //Timers timer = new Timers();
29 | try {
30 | // Get the data set path.
31 | String referenceFile = Utils.getOption('r', args);
32 | String queryFile = Utils.getOption('q', args);
33 | if (referenceFile.length() == 0)
34 | throw new IllegalArgumentException("Required option: File containing " +
35 | "the reference dataset.");
36 |
37 | // Load input dataset.
38 | DataSource source = new DataSource(referenceFile);
39 | Instances referenceData = source.getDataSet();
40 |
41 | Instances queryData = null;
42 | if (queryFile.length() != 0)
43 | {
44 | source = new DataSource(queryFile);
45 | queryData = source.getDataSet();
46 | }
47 |
48 | //timer.StartTimer("total_time");
49 |
50 | // Get all the parameters.
51 | String leafSize = Utils.getOption('l', args);
52 | String neighbors = Utils.getOption('k', args);
53 |
54 | // Validate options.
55 | int k = 0;
56 | if (neighbors.length() == 0)
57 | {
58 | throw new IllegalArgumentException("Required option: Number of " +
59 | "nearest neighbors to find.");
60 | }
61 | else
62 | {
63 | k = Integer.parseInt(neighbors);
64 | if (k < 1 || k > referenceData.numInstances())
65 | throw new IllegalArgumentException("[Fatal] Invalid k");
66 | }
67 |
68 | int l = 20;
69 | if (leafSize.length() != 0)
70 | l = Integer.parseInt(leafSize);
71 |
72 | // Create KDTree.
73 | KDTree tree = new KDTree();
74 | tree.setMaxInstInLeaf(l);
75 | tree.setInstances(referenceData);
76 |
77 | // Perform All K-Nearest-Neighbors.
78 | if (queryFile.length() != 0)
79 | {
80 | for (int i = 0; i < queryData.numInstances(); i++)
81 | {
82 | Instances out = tree.kNearestNeighbours(queryData.instance(i), k);
83 | }
84 | }
85 | else
86 | {
87 | for (int i = 0; i < referenceData.numInstances(); i++)
88 | {
89 | Instances out = tree.kNearestNeighbours(referenceData.instance(i), k);
90 | }
91 | }
92 |
93 | //timer.StopTimer("total_time");
94 | //timer.PrintTimer("total_time");
95 | } catch (IOException e) {
96 | System.err.println(USAGE);
97 | } catch (Exception e) {
98 | e.printStackTrace();
99 | }
100 | }
101 | }
102 |
-------------------------------------------------------------------------------- /bench/allknn/README.md: --------------------------------------------------------------------------------
1 | # bench/allknn
2 |
3 | This folder contains scripts that run all-nearest-neighbor searches in a number of libraries.
4 | For the most part, the scripts are very bare-bones.
5 | For example, they don't even output the results.
6 |
7 | To run the scripts, you'll obviously first need to install the libraries.
8 | The `/install` folder in this repo contains scripts for installing all of these libraries.
9 | With all the libraries installed, just call the `runtest.sh` script with a single parameter that is the dataset to test on.
10 |
11 | The table below provides a brief description of the libraries compared against.
12 |
13 | | Library | Description |
14 | |-------------------|---------------------------|
15 | | [FLANN](http://www.cs.ubc.ca/research/flann/) | The Fast Library for Approximate Nearest Neighbor queries. This C++ library is the standard method for nearest neighbor in Matlab/Octave and the [OpenCV](http://opencv.org) computer vision toolkit. |
16 | | [Julia](http://julia.org) | A popular new language designed from the ground up for fast data processing.
Julia supports faster nearest neighbor queries using the [KDTrees.jl](https://github.com/JuliaGeometry/KDTrees.jl) package. | 17 | | [Langford's cover tree](http://hunch.net/~jl/projects/cover_tree/cover_tree.html) | A reference implementation for the cover tree data structure created by John Langford. The implementation is in C, and the data structure is widely included in C/C++ machine learning libraries. | 18 | | [MLPack](http://mlpack.org) | A C++ library for machine learning. MLPack was the first library to demonstrate the utility of generic programming in machine learning. The interface for nearest neighbor queries lets you use either a cover tree or kdtree. 19 | | [R](http://r-project.org) | A popular language for statisticians. Nearest neighbor queries are implemented in the [FNN](https://cran.r-project.org/web/packages/FNN/index.html) package, which provides bindings to the C-based [ANN](http://www.cs.umd.edu/~mount/ANN/) library for kdtrees. | 20 | | [scikit-learn](scikitlearn.org) | The Python machine learning toolkit. The documentation is very beginner friendly and easy to learn. The interface for nearest neighbor queries lets you use either a [ball tree](https://en.wikipedia.org/wiki/Ball_tree) or [kdtree](https://en.wikipedia.org/wiki/K-d_tree) to speed up the calculations. Both data structures were written in [Cython](http://cython.org). | 21 | | [Weka](http://weka.org) | A Java data mining tool with a popular GUI frontend. Nearest neighbor queries in Weka are very, very slow for me and not remotely competitive with any of the libraries above. | 22 | 23 | -------------------------------------------------------------------------------- /bench/allknn/allknn.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | library(FNN) 3 | 4 | args <- commandArgs(trailingOnly=TRUE) 5 | file <- args[1] 6 | 7 | k <- 1 8 | if (length(args)>=3) { 9 | k <- as.numeric(args[2]); 10 | } 11 | 12 | tree <- "cover_tree" 13 | #tree <- "kd_tree" 14 | #tree <- "brute" 15 | if (length(args)>=4) { 16 | tree <- args[3] 17 | } 18 | 19 | 20 | file_neighbors <- "neighbors_R.csv" 21 | file_distances <- "distances_R.csv" 22 | 23 | ####################################### 24 | ## load file 25 | 26 | cat("file = ", file, "\n") 27 | 28 | t0 <- Sys.time() 29 | cat("loading file......................") 30 | data <- read.csv(file, header=FALSE) 31 | t1 <- Sys.time() 32 | cat("done. ",difftime(t1,t0,units="secs"), " sec\n") 33 | 34 | ####################################### 35 | ## do allknn search 36 | 37 | t0 <- Sys.time() 38 | cat("finding neighbors.................") 39 | #print (data) 40 | #res <- get.knn(data[1:10000,],k=k,algorithm=tree) 41 | res <- get.knn(data,k=k,algorithm=tree) 42 | t1 <- Sys.time() 43 | cat("done. ",difftime(t1,t0,units="secs"), " sec\n") 44 | 45 | ####################################### 46 | ## output to files 47 | 48 | t0 <- Sys.time() 49 | cat("outputing neighbors...............") 50 | neighbors <- res[1] 51 | write.table(neighbors, file=file_neighbors,row.names=FALSE,col.names=FALSE,sep=",") 52 | t1 <- Sys.time() 53 | cat("done. ",difftime(t1,t0,units="secs"), " sec\n") 54 | 55 | t0 <- Sys.time() 56 | cat("outputing distances...............") 57 | write.csv(res[2], file=file_distances,row.names=FALSE) 58 | t1 <- Sys.time() 59 | cat("done. 
",difftime(t1,t0,units="secs"), " sec\n") 60 | -------------------------------------------------------------------------------- /bench/allknn/allknn.flann: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import numpy as np 5 | from pyflann import * 6 | 7 | print "loading csv" 8 | f = open(sys.argv[1]) 9 | data = np.loadtxt(fname=f, delimiter=',') 10 | 11 | print "building index" 12 | precision=1 13 | if len(sys.argv) > 3: 14 | precision=float(sys.argv[3]) 15 | flann = FLANN() 16 | params = flann.build_index(data, algorithm="autotuned", target_precision=precision) 17 | 18 | print "performing queries" 19 | result,dists = flann.nn_index(data, int(sys.argv[2]), checks=params["checks"]); 20 | -------------------------------------------------------------------------------- /bench/allknn/allknn.julia: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env julia 2 | 3 | println("compiling KDTrees") 4 | using KDTrees 5 | 6 | println("reading csv") 7 | data = transpose(readcsv(ARGS[1])) 8 | 9 | println("constructing kdtree") 10 | tree = KDTree(data) 11 | 12 | println("performing queries") 13 | for k in 1:size(data,2) 14 | p = data[:,k] 15 | knn(tree, vec(p), int(ARGS[2])) 16 | end 17 | -------------------------------------------------------------------------------- /bench/allknn/allknn.scikit: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | from sklearn.neighbors import NearestNeighbors 5 | import numpy as np 6 | 7 | f = open(sys.argv[1]) 8 | data = np.loadtxt(fname=f, delimiter=',') 9 | 10 | nbrs = NearestNeighbors(n_neighbors=int(sys.argv[2]), algorithm=sys.argv[3]).fit(data) 11 | distances, indices = nbrs.kneighbors(data) 12 | 13 | -------------------------------------------------------------------------------- /bench/allknn/cover_tree/COPYING: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc. 5 | 675 Mass Ave, Cambridge, MA 02139, USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Library General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 
27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. 
You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. 
You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. 
You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. 
If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 19yy 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License 307 | along with this program; if not, write to the Free Software 308 | Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 
311 |
312 | If the program is interactive, make it output a short notice like this
313 | when it starts in an interactive mode:
314 |
315 | Gnomovision version 69, Copyright (C) 19yy name of author
316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317 | This is free software, and you are welcome to redistribute it
318 | under certain conditions; type `show c' for details.
319 |
320 | The hypothetical commands `show w' and `show c' should show the appropriate
321 | parts of the General Public License. Of course, the commands you use may
322 | be called something other than `show w' and `show c'; they could even be
323 | mouse-clicks or menu items--whatever suits your program.
324 |
325 | You should also get your employer (if you work as a programmer) or your
326 | school, if any, to sign a "copyright disclaimer" for the program, if
327 | necessary. Here is a sample; alter the names:
328 |
329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330 | `Gnomovision' (which makes passes at compilers) written by James Hacker.
331 |
332 | <signature of Ty Coon>, 1 April 1989
333 | Ty Coon, President of Vice
334 |
335 | This General Public License does not permit incorporating your program into
336 | proprietary programs. If your program is a subroutine library, you may
337 | consider it more useful to permit linking proprietary applications with the
338 | library. If this is what you want to do, use the GNU Library General
339 | Public License instead of this License.
340 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/Makefile: --------------------------------------------------------------------------------
1 | CFLAGS=-O3 -ffast-math -funroll-loops -static # -msse -march=athlon
2 |
3 | all: test_nn knn
4 |
5 | cover_tree.o: cover_tree.cc cover_tree.h
6 | g++ -g -c -Wall $(CFLAGS) cover_tree.cc
7 |
8 | point.o: point.cc point.h
9 | g++ -g -c -Wall $(CFLAGS) point.cc
10 |
11 | test_nn: point.o cover_tree.o stack.h point.h test_nn.cc
12 | g++ -Wall $(CFLAGS) -o test_nn test_nn.cc cover_tree.o point.o
13 |
14 | knn: point.o cover_tree.o stack.h point.h knn.cc
15 | g++ -g -Wall $(CFLAGS) -o knn knn.cc cover_tree.o point.o
16 |
17 | clean:
18 | rm *.o
19 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/README: --------------------------------------------------------------------------------
1 | This is a first pass implementation of a cover tree in c/c++. It is
2 | somewhat optimized and not terribly robust.
3 |
4 |
5 | 'test_nn <k> <dataset>' checks the running times of tree creation and
6 | querying for nearest neighbors for every point in <dataset>.
7 |
8 | 'knn <k> <dataset> <queries> <dual|single>' computes the nearest neighbors of every point in
9 | <queries> amongst the points in <dataset>
10 |
11 | Copyright is owned by John Langford.
12 |
13 | The code is licensed under the GPL.
14 |
15 | The file 'cover_tree.{cc,h}' contains the core cover tree computation code.
16 | The file 'point.{cc,h}' contains the definition of a point as well as
17 | parsing and printing functions.
18 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/cover_tree.h: --------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 | #include <math.h>
4 | #include <string.h>
5 | #define NDEBUG
6 | #include <assert.h>
7 | #include "point.h"
8 |
9 | struct node {
10 | point p;
11 | float max_dist; // The maximum distance to any grandchild.
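// (This bound is what makes pruning possible: during a k-NN search, once the
// distance from the query to p minus max_dist exceeds the current k-th best
// distance, the entire subtree below p can be skipped.)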
12 | float parent_dist; // The distance to the parent.
13 | node* children;
14 | unsigned short int num_children; // The number of children.
15 | short int scale; // Essentially, an upper bound on the distance to any child.
16 | };
17 |
18 | void print(int depth, node &top_node);
19 |
20 | //construction
21 | node new_leaf(const point &p);
22 | node batch_create(v_array<point> points);
23 | //node insert(point, node *top_node); // not yet implemented
24 | //void remove(point, node *top_node); // not yet implemented
25 | //query
26 | void k_nearest_neighbor(const node &tope_node, const node &query,
27 | v_array<v_array<point> > &results, int k);
28 | void epsilon_nearest_neighbor(const node &tope_node, const node &query,
29 | v_array<v_array<point> > &results, float epsilon);
30 | void unequal_nearest_neighbor(const node &tope_node, const node &query,
31 | v_array<v_array<point> > &results);
32 | //information gathering
33 | int height_dist(const node top_node,v_array<int> &heights);
34 | void breadth_dist(const node top_node,v_array<int> &breadths);
35 | void depth_dist(int top_scale, const node top_node,v_array<int> &depths);
36 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/knn.cc: --------------------------------------------------------------------------------
1 | #include "cover_tree.h"
2 |
3 | // Compute the k nearest neighbors
4 |
5 | int main(int argc, char *argv[])
6 | {
7 | printf("log %f\n",log(-1));
8 |
9 | int k = atoi(argv[1]);
10 |
11 | v_array<point> set_of_points = parse_points(fopen(argv[2],"r"));
12 | v_array<point> set_of_queries = parse_points(fopen(argv[3],"r"));
13 |
14 | numdist=0;
15 | node top = batch_create(set_of_points);
16 | printf("numdist in batch_create: %lu\n", numdist);
17 |
18 | v_array<v_array<point> > res;
19 |
20 | numdist=0;
21 | if (!strcmp(argv[4],"dual")) {
22 | printf("dual tree search\n");
23 | node top_query = batch_create(set_of_queries);
24 | k_nearest_neighbor(top,top_query,res,k);
25 | }
26 | else {
27 | printf("single tree search\n");
28 | for (int i = 0; i < set_of_queries.index; i++)
29 | k_nearest_neighbor(top,new_leaf(set_of_queries[i]),res,k);
30 | }
31 | printf("numdist in k_nearest_neighbor: %lu\n", numdist);
32 |
33 | /*
34 | printf("Printing results\n");
35 | for (int i = 0; i < res.index; i++)
36 | {
37 | for (int j = 0; j < res[i].index; j++)
38 | print(res[i][j]);
39 | printf("\n");
40 | }
41 | */
42 | }
43 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/point.cc: --------------------------------------------------------------------------------
1 | #include "point.h"
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 | #include <string.h>
5 | #include <math.h>
6 | #include <iostream>
7 | using namespace std;
8 |
9 | //typedef float v4sf __attribute__ ((mode(V4SF)));
10 |
11 | const int batch = 120;//must be a multiple of 8
12 |
13 | int point_len = 0;
14 | unsigned long numdist=0;
15 |
16 | //Assumption: points are multiples of 8 long
17 | float distance(point p1, point p2, float upper_bound)
18 | {
19 | numdist++;
20 | float sum = 0.;
21 | float *end = p1 + point_len;
22 | upper_bound *= upper_bound;
23 | for (float *batch_end = p1 + batch; batch_end <= end; batch_end += batch)
24 | {
25 | for (; p1 != batch_end; p1+=2, p2+=2)
26 | {
27 | float d1 = *p1 - *p2;
28 | float d2 = *(p1+1) - *(p2+1);
29 | d1 *= d1;
30 | d2 *= d2;
31 | sum = sum + d1 + d2;
32 | }
33 | if (sum > upper_bound)
34 | return sqrt(sum);
35 | }
36 | for (; p1 != end; p1+=1, p2+=1)
37 | {
38 | float d1 = *p1 - *p2;
39 | float d2 = *(p1+1) - *(p2+1);
40 | d1 *= d1;
41 | d2 *= d2;
42 | sum = sum + d1 + d2;
43 | }
44 | return sqrt(sum);
45 | }
46 | /*
47 | //Assumption: points are multiples of 8 long
48 | float sse_distance(point p1, point p2, float upper_bound)
49 | {
50 | v4sf sum = {0.,0.,0.,0.};
51 | float *end = p1 + point_len;
52 | upper_bound *= upper_bound;
53 | for (float *batch_end = p1 + batch; batch_end <= end; batch_end += batch)
54 | {
55 | for (; p1 != batch_end; p1+=8, p2+=8)
56 | {
57 | v4sf v1 = __builtin_ia32_loadaps(p1);
58 | v4sf v2 = __builtin_ia32_loadaps(p2);
59 | v4sf v3 = __builtin_ia32_loadaps(p1+4);
60 | v4sf v4 = __builtin_ia32_loadaps(p2+4);
61 | v1 = __builtin_ia32_subps(v1, v2);
62 | v3 = __builtin_ia32_subps(v3, v4);
63 | v1 = __builtin_ia32_mulps(v1, v1);
64 | v3 = __builtin_ia32_mulps(v3, v3);
65 | v1 = __builtin_ia32_addps(v1,v3);
66 | sum = __builtin_ia32_addps(sum,v1);
67 | }
68 | v4sf temp = __builtin_ia32_addps(sum,__builtin_ia32_shufps(sum,sum,14));
69 | temp = __builtin_ia32_addss(temp,__builtin_ia32_shufps(temp,temp,1));
70 | if (((float *)&temp)[0] > upper_bound)
71 | {
72 | temp = __builtin_ia32_sqrtss(temp);
73 | return ((float *)&temp)[0];
74 | }
75 | }
76 | for (; p1 != end; p1+=8, p2+=8)
77 | {
78 | v4sf v1 = __builtin_ia32_loadaps(p1);
79 | v4sf v2 = __builtin_ia32_loadaps(p2);
80 | v4sf v3 = __builtin_ia32_loadaps(p1+4);
81 | v4sf v4 = __builtin_ia32_loadaps(p2+4);
82 | v1 = __builtin_ia32_subps(v1, v2);
83 | v3 = __builtin_ia32_subps(v3, v4);
84 | v1 = __builtin_ia32_mulps(v1, v1);
85 | v3 = __builtin_ia32_mulps(v3, v3);
86 | v1 = __builtin_ia32_addps(v1,v3);
87 | sum = __builtin_ia32_addps(sum,v1);
88 | }
89 | sum = __builtin_ia32_addps(sum,__builtin_ia32_shufps(sum,sum,14));
90 | sum = __builtin_ia32_addss(sum,__builtin_ia32_shufps(sum,sum,1));
91 | sum = __builtin_ia32_sqrtss(sum);
92 | return ((float *) & sum)[0];
93 | }*/
94 |
95 | /*
96 | float distance(point p1, point p2, float upper_bound)
97 | {
98 | return fabsf(p1 - p2);
99 | }
100 |
101 | v_array<point> parse_points(FILE *input)
102 | {
103 | v_array<point> ret;
104 | for (int i = 0; i< 1000; i++)
105 | push(ret,(float) i);
106 | return ret;
107 | }
108 |
109 | void print(point &p)
110 | {
111 | printf("%f ",p);
112 | printf("\n");
113 | }
114 |
115 | */
116 |
117 | v_array<point> parse_points(FILE *input)
118 | {
119 | v_array<point> parsed;
120 | char c;
121 | v_array<float> p;
122 | while ( (c = getc(input)) != EOF )
123 | {
124 | ungetc(c,input);
125 |
126 | while ((c = getc(input)) != '\n' )
127 | {
128 | while (c != '0' && c != '1' && c != '2' && c != '3'
129 | && c != '4' && c != '5' && c != '6' && c != '7'
130 | && c != '8' && c != '9' && c != '\n' && c != EOF && c != '-')
131 | c = getc(input);
132 | if (c != '\n' && c != EOF) {
133 | ungetc(c,input);
134 | float f;
135 | fscanf(input, "%f",&f);
136 | push(p,f);
137 | }
138 | else
139 | if (c == '\n')
140 | ungetc(c,input);
141 | }
142 |
143 | if (p.index %8 > 0)
144 | for (int i = 8 - p.index %8; i> 0; i--)
145 | push(p,(float) 0.);
146 | float *new_p;
147 | posix_memalign((void **)&new_p, 16, p.index*sizeof(float));
148 | memcpy(new_p,p.elements,sizeof(float)*p.index);
149 |
150 | if (point_len > 0 && point_len != p.index)
151 | {
152 | printf("Can't handle vectors of differing length, bailing\n");
153 | exit(0);
154 | }
155 |
156 | point_len = p.index;
157 | p.index = 0;
158 | push(parsed,new_p);
159 | }
160 | return parsed;
161 | }
162 |
163 | void print(point &p)
164 | {
165 | for (int i = 0; i < point_len; i++)
166 | printf("%f ",p[i]);
167 | printf("\n");
168 | }
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/point.h: --------------------------------------------------------------------------------
1 | #include "stack.h"
2 | #include <stdio.h>
3 |
4 | typedef float* point;
5 |
6 | extern unsigned long numdist;
7 |
8 | float complete_distance(point v1, point v2);
9 | float distance(point v1, point v2, float upper_bound);
10 | v_array<point> parse_points(FILE *input);
11 | void print(point &p);
12 |
-------------------------------------------------------------------------------- /bench/allknn/cover_tree/stack.h: --------------------------------------------------------------------------------
1 | #include <stdlib.h>
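// v_array<T> is the minimal growable array used throughout the cover tree
// code: `index` counts the live elements, `length` is the allocated
// capacity, and push() grows the buffer geometrically with realloc, so
// amortized insertion is O(1).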
2 | 3 | template class v_array{ 4 | public: 5 | int index; 6 | int length; 7 | T* elements; 8 | 9 | T last() { return elements[index-1];} 10 | void decr() { index--;} 11 | v_array() { index = 0; length=0; elements = NULL;} 12 | T& operator[](unsigned int i) { return elements[i]; } 13 | }; 14 | 15 | template void push(v_array& v, const T &new_ele) 16 | { 17 | while(v.index >= v.length) 18 | { 19 | v.length = 2*v.length + 3; 20 | v.elements = (T *)realloc(v.elements,sizeof(T) * v.length); 21 | } 22 | v[v.index++] = new_ele; 23 | } 24 | 25 | template void alloc(v_array& v, int length) 26 | { 27 | v.elements = (T *)realloc(v.elements, sizeof(T) * length); 28 | v.length = length; 29 | } 30 | 31 | template v_array pop(v_array > &stack) 32 | { 33 | if (stack.index > 0) 34 | return stack[--stack.index]; 35 | else 36 | return v_array(); 37 | } 38 | -------------------------------------------------------------------------------- /bench/allknn/cover_tree/test_nn.cc: -------------------------------------------------------------------------------- 1 | #include "cover_tree.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | using namespace std; 9 | 10 | float diff_timeval(timeval t1, timeval t2) 11 | { 12 | return (float) (t1.tv_sec - t2.tv_sec) + (t1.tv_usec - t2.tv_usec) * 1e-6; 13 | } 14 | 15 | float diff_clock(clock_t t1, clock_t t2) 16 | { 17 | return (float) (t1 - t2) / (float) CLOCKS_PER_SEC; 18 | } 19 | 20 | int compare(const void* p1, const void* p2) 21 | { 22 | if (p1 v = make_same(10000); 34 | if (argc <2 ) 35 | { 36 | cout << "usage: test_nn " << endl; 37 | exit(1); 38 | } 39 | int k = atoi(argv[1]); 40 | FILE* fd = fopen(argv[2],"r"); 41 | v_array v = parse_points(fd); 42 | printf("point length = %f\n",v[0]); 43 | //printf("point length = %i\n",v[0].index); 44 | //printf("first point = \n"); 45 | //print(v.elements[0]); 46 | 47 | cout << "fart" << endl; 48 | timeval parsed; 49 | clock_t parsed_clock = clock(); 50 | gettimeofday(&parsed,NULL); 51 | //printf("parse in %f seconds\n",parsed-start); 52 | 53 | cout << "batch_create" << endl; 54 | node top = batch_create(v); 55 | cout << "done." 
<< endl; 56 | timeval created; 57 | // clock_t created_clock = clock(); 58 | gettimeofday(&created, NULL); 59 | //printf("created in %f seconds\n",diff(created,parsed)); 60 | 61 | //print(0, top); 62 | /* v_array depths; 63 | depth_dist(top.scale, top, depths); 64 | 65 | printf("depth distribution = \n"); 66 | for (int i = 0; i < depths.index; i++) 67 | if (depths[i] > 0) 68 | printf("%i\t",i); 69 | printf("\n"); 70 | for (int i = 0; i < depths.index; i++) 71 | if (depths[i] > 0) 72 | printf("%i\t",depths[i]); 73 | printf("\n"); 74 | 75 | v_array heights; 76 | printf("max height = %i\n",height_dist(top, heights)); 77 | 78 | printf("height distribution = \n"); 79 | for (int i = 0; i < heights.index; i++) 80 | printf("%i\t",i); 81 | printf("\n"); 82 | for (int i = 0; i < heights.index; i++) 83 | printf("%i\t",heights[i]); 84 | printf("\n"); 85 | 86 | v_array breadths; 87 | breadth_dist(top,breadths); 88 | 89 | printf("breadth distribution = \n"); 90 | for (int i = 0; i < breadths.index; i++) 91 | if (breadths[i] > 0) 92 | printf("%i\t",i); 93 | printf("\n"); 94 | for (int i = 0; i < breadths.index; i++) 95 | if (breadths[i] > 0) 96 | printf("%i\t",breadths[i]); 97 | printf("\n");*/ 98 | 99 | cout << "v_array" << endl; 100 | v_array > res; 101 | /* for (int i = 0; i < v.index; i++) 102 | k_nearest_neighbor(top,new_leaf(v[i]),res,k); */ 103 | cout << "starting knn" << endl; 104 | k_nearest_neighbor(top,top,res,k); 105 | 106 | /* printf("Printing results\n"); 107 | for (int i = 0; i< res.index; i++) 108 | { 109 | for (int j = 0; j > brute_neighbors; 131 | for (int i=0; i < res.index && i < thresh; i++) { 132 | point this_point = res[i][0]; 133 | float upper_dist[k]; 134 | point min_points[k]; 135 | for (int j=0; j us; 165 | push(us,this_point); 166 | for (int j = 0; j&1 >> "stdout.$sanitizedcmd") 97 | echo "$runtime" >> "stderr.$sanitizedcmd" 98 | runtime=$(tail -1 <<< "$runtime") 99 | runseconds=$(../../time2sec.hs $runtime) 100 | 101 | echo "$runtime $runseconds" 102 | 103 | #cat "stdout.$sanitizedcmd" 104 | #cat "stderr.$sanitizedcmd" 105 | 106 | echo "$sanitizedcmd $runtime $runseconds" >> results 107 | done 108 | -------------------------------------------------------------------------------- /bench/allknn/time2sec.hs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env runghc 2 | 3 | import Data.List.Split 4 | import System.Environment 5 | 6 | main = do 7 | str:[] <- getArgs 8 | print $ str2seconds str 9 | 10 | str2seconds :: String -> Double 11 | str2seconds xs = case splitOn ":" xs of 12 | [s] -> read s 13 | [m,s] -> read m*60 + read s 14 | [h,m,s] -> read h*60*60+read m*60+read s 15 | 16 | -------------------------------------------------------------------------------- /cbits/emd.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | float emd_float(float *sig1, int len1, float *sig2, int len2, float *costmat) 5 | { 6 | CvMat *s1=cvCreateMatHeader(len1,1,CV_32F), 7 | *s2=cvCreateMatHeader(len2,1,CV_32F), 8 | *c =cvCreateMatHeader(len1,len2,CV_32F); 9 | 10 | s1->data.fl=sig1; 11 | s2->data.fl=sig2; 12 | c->data.fl=costmat; 13 | 14 | float lb=1; 15 | 16 | float ret=cvCalcEMD2(s1,s2,CV_DIST_USER,NULL,c,NULL,NULL,NULL); 17 | 18 | cvReleaseMat(&s1); 19 | cvReleaseMat(&s2); 20 | cvReleaseMat(&c); 21 | return ret; 22 | } 23 | 24 | /* 25 | int main() 26 | { 27 | float s1[]={1,4,4}; 28 | float s2[]={6,2,1}; 29 | float c[]= 30 | { 0, 1, 4 31 | , 1, 0, 1 32 | , 4, 1, 0 33 | }; 34 | 35 | 
float res=emd_float(s1,3,s2,3,c); 36 | printf("emd=%f\n",res); 37 | return 0; 38 | } 39 | */ 40 | -------------------------------------------------------------------------------- /examples/example0001-optimization-univariate.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE RebindableSyntax #-} 4 | 5 | import SubHask 6 | import SubHask.Algebra.Vector 7 | import SubHask.Category.Trans.Derivative 8 | 9 | import HLearn.History 10 | import HLearn.Optimization.Univariate 11 | 12 | import System.IO 13 | 14 | -------------------------------------------------------------------------------- 15 | 16 | x0 :: OrdField a => a 17 | x0 = (-5) 18 | 19 | optmethod :: OrdField a => (a -> a) -> a 20 | optmethod f = x1 $ evalHistory $ fminuncM_brent (maxIterations 20 || noProgress) x0 (return . f) 21 | 22 | -------------------------------------------------------------------------------- 23 | 24 | main = do 25 | 26 | let f_slopes :: Float = optmethod slopes 27 | d_slopes :: Double = optmethod slopes 28 | r_slopes :: Rational = optmethod slopes 29 | 30 | putStrLn $ "f_slopes = " ++ show f_slopes 31 | putStrLn $ "d_slopes = " ++ show d_slopes 32 | putStrLn $ "r_slopes = " ++ show r_slopes 33 | 34 | let f_slopesWithDiscontinuity :: Float = optmethod slopesWithDiscontinuity 35 | d_slopesWithDiscontinuity :: Double = optmethod slopesWithDiscontinuity 36 | r_slopesWithDiscontinuity :: Rational = optmethod slopesWithDiscontinuity 37 | 38 | putStrLn $ "f_slopesWithDiscontinuity = " ++ show f_slopesWithDiscontinuity 39 | putStrLn $ "d_slopesWithDiscontinuity = " ++ show d_slopesWithDiscontinuity 40 | putStrLn $ "r_slopesWithDiscontinuity = " ++ show r_slopesWithDiscontinuity 41 | 42 | putStrLn "done." 
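-- The optimizer is not limited to the library's built-in test functions; any
-- function over an 'OrdField' works. A minimal sketch using the 'optmethod'
-- driver defined above (the quadratic below is our own example, not part of HLearn):

quadMin :: Double
quadMin = optmethod (\x -> (x-3)*(x-3) + 1)   -- minimized at x = 3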
43 | -------------------------------------------------------------------------------- /examples/example0002-optimization-multivariate.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE RebindableSyntax #-} 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE RankNTypes #-} 7 | 8 | import SubHask 9 | import SubHask.Algebra.Vector 10 | import SubHask.Category.Trans.Derivative 11 | 12 | import HLearn.History 13 | import HLearn.Optimization.Univariate 14 | import HLearn.Optimization.Multivariate 15 | 16 | import System.IO 17 | 18 | -------------------------------------------------------------------------------- 19 | 20 | main = do 21 | let x0 = unsafeToModule [1,2,1,2] :: SVector 4 Double 22 | f = rosenbrock 23 | lineSearch = lineSearch_brent ( stop_brent 1e-12 || maxIterations 20 ) 24 | 25 | stop :: StopCondition a 26 | stop = maxIterations 20 27 | 28 | let cgd conj = evalHistory $ fminunc_cgd_ conj lineSearch stop x0 f 29 | 30 | putStrLn $ "steepestDescent = " ++ show (fx1 $ cgd steepestDescent) 31 | putStrLn $ "fletcherReeves = " ++ show (fx1 $ cgd fletcherReeves) 32 | putStrLn $ "polakRibiere = " ++ show (fx1 $ cgd polakRibiere) 33 | putStrLn $ "hestenesStiefel = " ++ show (fx1 $ cgd hestenesStiefel) 34 | 35 | putStrLn $ "bfgs = " ++ show (fx1 $ evalHistory $ fminunc_bfgs_ lineSearch stop x0 f ) 36 | -------------------------------------------------------------------------------- /examples/example0003-classification.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude #-} 2 | {-# LANGUAGE ScopedTypeVariables #-} 3 | {-# LANGUAGE RebindableSyntax #-} 4 | {-# LANGUAGE DataKinds #-} 5 | {-# LANGUAGE ConstraintKinds #-} 6 | {-# LANGUAGE RankNTypes #-} 7 | 8 | import SubHask 9 | import SubHask.Algebra.Array 10 | import SubHask.Algebra.Vector 11 | import SubHask.Algebra.Container 12 | 13 | import HLearn.Data.LoadData 14 | import HLearn.Classifiers.Linear 15 | import HLearn.History 16 | 17 | import qualified Prelude as P 18 | import System.IO 19 | 20 | -------------------------------------------------------------------------------- 21 | 22 | main = do 23 | xs :: BArray (Labeled' (SVector "dyn" Double) (Lexical String)) 24 | <- loadCSVLabeled' 0 "datasets/csv/uci/wine.csv" 25 | -- <- loadCSVLabeled' 8 "datasets/csv/uci/pima-indians-diabetes.csv" 26 | 27 | glm <- runHistory 28 | ( (displayFilter (maxReportLevel 2) dispIteration) 29 | + summaryTable 30 | ) 31 | $ trainLogisticRegression 1e-3 xs 32 | 33 | putStrLn $ "loss_01 = "++show (validate loss_01 (toList xs) glm) 34 | putStrLn $ "loss_logistic = "++show (validate loss_logistic (toList xs) glm) 35 | putStrLn $ "loss_hinge = "++show (validate loss_hinge (toList xs) glm) 36 | 37 | -- putStrLn "" 38 | -- print $ show $ weights glm!Lexical "1" 39 | -- print $ show $ weights glm!Lexical "2" 40 | -- print $ show $ weights glm!Lexical "3" 41 | 42 | putStrLn "done." 43 | -------------------------------------------------------------------------------- /executables/README.md: -------------------------------------------------------------------------------- 1 | HLearn comes with a number of executable commands. 2 | These executables provide a convenient interface to HLearn for shell scripts. 3 | 4 | Currently, only [hlearn-allknn](/executables/hlearn-allknn) is complete. 5 | It provides a fast method for nearest neighbor queries. 
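For example, a typical invocation points the executable at a CSV file of data points and asks for each point's nearest neighbors. The exact flags are documented by the executable itself; the sketch below is illustrative only, not the real interface:
```
hlearn-allknn -r dataset.csv -k 1
```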
6 | -------------------------------------------------------------------------------- /install/README.md: -------------------------------------------------------------------------------- 1 | I use these scripts to install HLearn on AWS Ubuntu Precise machines (e.g. using [this AMI](https://us-west-2.console.aws.amazon.com/ec2/home?region=us-west-2#LaunchInstanceWizard:ami=ami-17b0b127)) and for testing with Travis CI (which also uses Ubuntu Precise). 2 | The `ubuntu-precise.sh` script installs just the components needed for HLearn, whereas the `ubuntu-precise-extras.sh` script installs all the related packages needed to run the benchmarks and tests. 3 | It should be pretty straightforward to adapt these scripts to whatever your preferred OS is. 4 | 5 | To install on an Ubuntu Precise machine, pipe the script to a shell instance using: 6 | ``` 7 | curl https://raw.githubusercontent.com/mikeizbicki/HLearn/travis/install/ubuntu-precise.sh | sh 8 | ``` 9 | -------------------------------------------------------------------------------- /install/ubuntu-precise-extras.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd 3 | set -e 4 | 5 | # installing julia 6 | #sudo add-apt-repository -y ppa:staticfloat/juliareleases 7 | #sudo add-apt-repository -y ppa:staticfloat/julia-deps 8 | ##sudo add-apt-repository -y 'deb http://ppa.launchpad.net/staticfloat/juliareleases/ubuntu precise main' 9 | #sudo apt-get update 10 | #sudo apt-get install -qq julia 11 | 12 | # the julia package is not compatible with python's numpy due to conflicting lapack dependencies 13 | wget https://github.com/JuliaLang/julia/releases/download/v0.3.10/julia-0.3.10_c8ceeefcc1.tar.gz 14 | tar -xf julia-0.3.10_c8ceeefcc1.tar.gz 15 | cd julia 16 | make -j5 17 | sudo make install 18 | export PATH="/home/ubuntu/julia/julia-0.3.10/bin:$PATH" 19 | cd 20 | 21 | echo "Pkg.add(\"KDTrees\")" > setup.julia 22 | julia setup.julia 23 | 24 | # install R 25 | sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 26 | sudo add-apt-repository -y 'deb http://streaming.stat.iastate.edu/CRAN/bin/linux/ubuntu precise/' 27 | sudo apt-get update 28 | sudo apt-get install -qq r-base r-base-dev 29 | 30 | echo "install.packages(\"FNN\",repos=\"http://cran.us.r-project.org\")" > setup.R 31 | sudo Rscript setup.R 32 | 33 | # install scikit 34 | sudo apt-get install -qq subversion 35 | sudo apt-get install -qq python-dev python-pip 36 | sudo pip install numpy 37 | sudo pip install scipy 38 | sudo pip install sklearn 39 | 40 | # install flann 41 | sudo apt-get install -qq cmake unzip make 42 | wget http://www.cs.ubc.ca/research/flann/uploads/FLANN/flann-1.8.4-src.zip 43 | unzip flann-1.8.4-src.zip 44 | cd flann-1.8.4-src 45 | mkdir build 46 | cd build 47 | cmake .. 48 | make -j5 49 | sudo make install 50 | cd 51 | 52 | # install mlpack 53 | sudo apt-get install -qq libxml2-dev 54 | sudo apt-get install -qq libboost-all-dev libboost-program-options-dev libboost-test-dev libboost-random-dev 55 | sudo apt-get install -qq doxygen 56 | 57 | wget http://downloads.sourceforge.net/project/boost/boost/1.58.0/boost_1_58_0.tar.gz 58 | tar -xf boost_1_58_0.tar.gz 59 | cd boost_1_58_0 60 | ./bootstrap.sh 61 | ./b2 62 | sudo ./b2 install 63 | cd 64 | 65 | wget http://sourceforge.net/projects/arma/files/armadillo-5.200.2.tar.gz 66 | tar -xf armadillo-5.200.2.tar.gz 67 | cd armadillo-5.200.2 68 | cmake . 69 | make -j5 70 | sudo make install 71 | cd 72 | 73 | wget http://mlpack.org/files/mlpack-1.0.12.tar.gz 74 | tar -xf mlpack-1.0.12.tar.gz 75 | cd mlpack-1.0.12 76 | mkdir build 77 | cd build 78 | cmake .. 79 | make -j5 80 | sudo make install 81 | cd 82 | 83 | # weka 84 | sudo apt-get install -qq openjdk-7-jdk 85 | sudo apt-get install -qq weka 86 | -------------------------------------------------------------------------------- /install/ubuntu-precise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # travis automatically sets these variables; 6 | # this is a hacked test to check if we're running on travis, 7 | # if not, we need to set them manually 8 | if [ -z "$CABALVER" ]; then 9 | 10 | # set environment variables 11 | CABALVER=1.22 12 | GHCVER=7.10.2 13 | LLVMVER=3.5 14 | 15 | # install git 16 | sudo apt-get update -qq 17 | sudo apt-get install -qq git 18 | 19 | git config --global user.name 'InstallScript' 20 | git config --global user.email 'installscript@izbicki.me' 21 | 22 | # get hlearn code 23 | git clone https://github.com/mikeizbicki/hlearn 24 | cd hlearn 25 | git submodule update --init --recursive subhask 26 | fi 27 | 28 | # install numeric deps 29 | sudo apt-get install -qq libatlas3gf-base 30 | sudo apt-get install -qq libblas-dev liblapack-dev 31 | 32 | # update g++ version (required for llvm) 33 | sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test 34 | sudo apt-get update -qq 35 | sudo apt-get install -qq g++-4.8 36 | sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-4.8 90 37 | 38 | # update llvm 39 | wget -O - http://llvm.org/apt/llvm-snapshot.gpg.key|sudo apt-key add - 40 | sudo add-apt-repository "deb http://llvm.org/apt/precise/ llvm-toolchain-precise main" 41 | sudo add-apt-repository "deb http://llvm.org/apt/precise/ llvm-toolchain-precise-$LLVMVER main" 42 | sudo apt-get update 43 | sudo apt-get install -y llvm-$LLVMVER llvm-$LLVMVER-dev 44 | sudo ln -s /usr/bin/opt-$LLVMVER /usr/bin/opt 45 | sudo ln -s /usr/bin/llc-$LLVMVER /usr/bin/llc 46 | export PATH="/usr/bin:$PATH" 47 | 48 | # install haskell bits 49 | sudo add-apt-repository -y ppa:hvr/ghc 50 | sudo apt-get update -qq 51 | sudo apt-get install -qq cabal-install-$CABALVER ghc-$GHCVER 52 | export PATH="~/.cabal/bin:/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$PATH" 53 | 54 | # install the local version of subhask 55 | cd subhask 56 | cabal update 57 | cabal install -j4 58 | cd ..
59 | 60 | # install hlearn 61 | cabal --version 62 | echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" 63 | cabal install -j4 --only-dependencies --enable-tests --enable-benchmarks 64 | cabal install || cabal install --disable-optimization 65 | 66 | echo "Optimization: False" >> /home/travis/.cabal/config 67 | -------------------------------------------------------------------------------- /src/HLearn/Classifiers/Linear.hs: -------------------------------------------------------------------------------- 1 | module HLearn.Classifiers.Linear 2 | where 3 | 4 | import SubHask 5 | import SubHask.Category.Trans.Derivative 6 | import SubHask.Compatibility.Containers 7 | 8 | import HLearn.History 9 | import HLearn.Optimization.Univariate 10 | import HLearn.Optimization.Multivariate 11 | 12 | import qualified Prelude as P 13 | import qualified Data.List as L 14 | import Debug.Trace 15 | 16 | -------------------------------------------------------------------------------- 17 | 18 | -- | The data type to represent arbitrary linear classifiers. 19 | -- Important examples include least squares regression, logistic regression, and support vector machines. 20 | -- In statistics, these models are called generalized linear models (GLMs). 21 | data GLM x y = GLM 22 | { weights :: Map' y x 23 | , numdp :: Scalar x 24 | } 25 | 26 | deriving instance (Show x, Show y, Show (Scalar x)) => Show (GLM x y) 27 | 28 | type instance Scalar (GLM x y) = Scalar x 29 | 30 | -------------------------------------------------------------------------------- 31 | 32 | -- type Foldable' xs x = (Foldable xs, Elem xs~x, Scalar xs~Int) 33 | 34 | type IsFoldable xs x = {-forall xs.-} (Foldable xs, Elem xs~x, Scalar xs~Int) 35 | 36 | {-# INLINEABLE trainLogisticRegression #-} 37 | trainLogisticRegression :: 38 | ( Ord y 39 | , Show y 40 | , Hilbert x 41 | , BoundedField (Scalar x) 42 | ) => Scalar x -- ^ regularization parameter 43 | -> IsFoldable xys (Labeled' x y) => xys -- ^ dataset 44 | -> ( cxt (LineBracket (Scalar x)) 45 | , cxt (Iterator_cgd x) 46 | , cxt (Iterator_cgd (Scalar x)) 47 | , cxt (Iterator_brent (Scalar x)) 48 | , cxt (Backtracking x) 49 | , cxt (Map' y x) 50 | , cxt Int 51 | ) => History cxt (GLM x y) 52 | trainLogisticRegression lambda xs = trainGLM_ 53 | ( fminunc_cgd_ 54 | -- hestenesStiefel 55 | -- polakRibiere 56 | -- steepestDescent 57 | fletcherReeves 58 | (lineSearch_brent (stop_brent 1e-6 || maxIterations 50 || noProgress || fx1grows)) 59 | -- (backtracking (strongCurvature 1e-10)) 60 | -- (backtracking fx1grows) 61 | (mulTolerance 1e-9 || maxIterations 50 || noProgress {-- || fx1grows-}) 62 | ) 63 | loss_logistic 64 | lambda 65 | (toList xs) 66 | 67 | {-# INLINEABLE trainGLM_ #-} 68 | trainGLM_ :: forall xys x y cxt opt. 69 | ( Ord y 70 | , Show y 71 | , Hilbert x 72 | ) => Has_x1 opt x => (x -> C1 (x -> Scalar x) -> forall s.
History_ cxt s (opt x)) -- ^ optimization method 73 | -> (y -> Labeled' x y -> C2 (x -> Scalar x)) -- ^ loss function 74 | -> Scalar x -- ^ regularization parameter 75 | -> [Labeled' x y] -- ^ dataset 76 | -> History cxt (GLM x y) 77 | trainGLM_ optmethod loss lambda xs = do 78 | -- let ys = fromList $ map yLabeled' $ toList xs :: Set y 79 | -- ws <- fmap sum $ mapM go (toList ys) 80 | -- return $ GLM 81 | -- { weights = ws 82 | -- , numdp = fromIntegral $ length xs 83 | -- } 84 | let Just (y0,ys) = uncons (fromList $ map yLabeled' $ toList xs :: Set y) 85 | ws <- fmap sum $ mapM go (toList ys) 86 | return $ GLM 87 | { weights = insertAt y0 zero ws 88 | , numdp = fromIntegral $ length xs 89 | } 90 | where 91 | go :: Show y => y -> forall s. History_ cxt s (Map' y x) 92 | go y0 = beginFunction ("trainGLM("++show y0++")") $ do 93 | w <- fmap x1 $ optmethod zero (totalLoss y0) 94 | return $ singletonAt y0 w 95 | 96 | totalLoss :: y -> C1 (x -> Scalar x) 97 | totalLoss y0 = unsafeProveC1 f f' 98 | where 99 | g = foldMap (loss y0) xs 100 | 101 | f w = ( g $ w) + lambda*size w 102 | f' w = (derivative g $ w) + lambda*.w 103 | 104 | 105 | -------------------------------------------------------------------------------- 106 | -- loss functions 107 | 108 | classify :: (Ord y, Hilbert x) => GLM x y -> x -> y 109 | classify (GLM ws _) x 110 | = fst 111 | $ P.head 112 | $ L.sortBy (\(_,wx1) (_,wx2) -> compare wx2 wx1) 113 | $ toIxList 114 | $ imap (\_ w -> w<>x) ws 115 | 116 | validate :: 117 | ( Ord y 118 | , Hilbert x 119 | , cat <: (->) 120 | ) => (y -> Labeled' x y -> (x `cat` Scalar x)) -> [Labeled' x y] -> GLM x y -> Scalar x 121 | validate loss xs model@(GLM ws _) = (sum $ map go xs) -- /(fromIntegral $ length xs) 122 | where 123 | go xy@(Labeled' x y) = loss y' xy $ ws!y' 124 | where 125 | y' = classify model x 126 | 127 | {-# INLINEABLE loss_01 #-} 128 | loss_01 :: (HasScalar x, Eq y) => y -> Labeled' x y -> x -> Scalar x 129 | loss_01 y0 (Labeled' x y) = indicator $ y0/=y 130 | 131 | {-# INLINEABLE loss_squared #-} 132 | loss_squared :: (Hilbert x, Eq y) => y -> Labeled' x y -> C2 (x -> Scalar x) 133 | loss_squared y0 (Labeled' x y) = unsafeProveC2 f f' f'' 134 | where 135 | labelscore = bool2num $ y0==y 136 | 137 | f w = 0.5 * (w<>x-labelscore)**2 138 | f' w = x.*(w<>x-labelscore) 139 | f'' w = x><x 140 | 141 | {-# INLINEABLE loss_logistic #-} 142 | loss_logistic :: (Hilbert x, Eq y) => y -> Labeled' x y -> C2 (x -> Scalar x) 143 | loss_logistic y0 (Labeled' x y) = unsafeProveC2 f f' f'' 144 | where 145 | labelscore = bool2num $ y0==y 146 | 147 | f w = logSumOfExp2 zero $ -labelscore*w<>x 148 | f' w = -labelscore*(1-invlogit (labelscore*w<>x)) *. x 149 | f'' w = x><x .* (invlogit (labelscore*w<>x) * (1-invlogit (labelscore*w<>x))) 150 | 151 | {-# INLINEABLE loss_hinge #-} 152 | loss_hinge :: (Hilbert x, Eq y) => y -> Labeled' x y -> C2 (x -> Scalar x) 153 | loss_hinge y0 (Labeled' x y) = unsafeProveC2 f f' f'' 154 | where 155 | labelscore = bool2num $ y0==y 156 | 157 | f w = max 0 $ 1 -labelscore*w<>x 158 | f' w = if labelscore*(x<>w) > 1 then zero else -labelscore*.x 159 | f'' w = zero 160 | 161 | ---------------------------------------- 162 | -- helpers 163 | 164 | bool2num True = 1 165 | bool2num False = -1 166 | 167 | invlogit x = 1 / (1 + exp (-x)) 168 | 169 | -- | calculates log . sum .
map exp in a numerically stable way 170 | logSumOfExp xs = m + log (sum [ exp $ x-m | x <- xs ] ) 171 | where 172 | m = maximum xs 173 | 174 | -- | calculates log $ exp x1 + exp x2 in a numerically stable way 175 | logSumOfExp2 x1 x2 = big + log ( exp (small-big) + 1 ) 176 | where 177 | big = max x1 x2 178 | small = min x1 x2 179 | 180 | -- logSumOfExp2 x1 x2 = m + log ( exp (x1-m) + exp (x2-m) ) 181 | -- where 182 | -- m = max x1 x2 183 | -------------------------------------------------------------------------------- /src/HLearn/Data/Graph.hs: -------------------------------------------------------------------------------- 1 | module HLearn.Data.Graph 2 | where 3 | 4 | 5 | import SubHask 6 | import SubHask.Compatibility.HMatrix 7 | import SubHask.TemplateHaskell.Deriving 8 | 9 | import qualified Data.Vector.Generic as VG 10 | 11 | import Data.List (reverse,take,permutations) 12 | import Control.DeepSeq 13 | import Data.List (lines,filter,head,words,sort,replicate,take,nubBy,zip,zip3) 14 | import System.IO 15 | import System.Directory 16 | 17 | -------------------------------------------------------------------------------- 18 | 19 | class 20 | ( Metric v 21 | , Ord (Scalar v) 22 | , HasScalar v 23 | , ExpField (Scalar v) 24 | ) => KernelSpace v 25 | where 26 | kernel :: v -> v -> Scalar v 27 | 28 | selfKernel :: v -> Scalar v 29 | selfKernel v = kernel v v 30 | 31 | kernelNorm :: KernelSpace v => v -> Scalar v 32 | kernelNorm = sqrt . selfKernel 33 | 34 | kernelDistance :: KernelSpace v => v -> v -> Scalar v 35 | kernelDistance v1 v2 = sqrt $ selfKernel v1 + selfKernel v2 - 2* kernel v1 v2 36 | 37 | --------------------------------------- 38 | 39 | -- | A "SelfKernel" precomputes the kernel applied to itself twice. 40 | -- This is a common calculation in many kernelized algorithms, so this can greatly speed up computation. 
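-- For example (a sketch, valid for any 'KernelSpace' values v1 and v2):
--
-- > distance (mkSelfKernel v1) (mkSelfKernel v2) == kernelDistance v1 v2
--
-- but the left-hand side evaluates 'kernel' only once per call, because
-- k(v1,v1) and k(v2,v2) are cached inside the 'SelfKernel' wrappers.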
41 | -- newtype SelfKernel v = SelfKernel (Scalar v, v) 42 | -- deriveHierarchy ''SelfKernel [''Ord,''Boolean] 43 | -- deriveTypefamilies [''Scalar] ''SelfKernel 44 | -- deriveSingleInstance ''SelfKernel ''Eq 45 | -- deriveSingleInstance ''SelfKernel ''POrd 46 | -- deriveSingleInstance ''SelfKernel ''Ord 47 | -- deriveSingleInstance ''SelfKernel ''SupSemilattice 48 | -- deriveSingleInstance ''SelfKernel ''InfSemilattice 49 | -- deriveSingleInstance ''SelfKernel ''Lattic 50 | 51 | data SelfKernel v = SelfKernel !(Scalar v) !v 52 | 53 | type instance Logic (SelfKernel v) = Logic v 54 | 55 | instance (NFData v, NFData (Scalar v)) => NFData (SelfKernel v) where 56 | rnf (SelfKernel k v) = deepseq k $ rnf v 57 | 58 | instance (Show (Scalar v), Show v) => Show (SelfKernel v) where 59 | show (SelfKernel k v) = "SelfKernel "++show k++" "++show v 60 | 61 | mkSelfKernel :: KernelSpace v => v -> SelfKernel v 62 | mkSelfKernel v = SelfKernel (selfKernel v) v 63 | 64 | type instance Scalar (SelfKernel v) = Scalar v 65 | type instance Elem (SelfKernel v) = Elem v 66 | 67 | instance Eq v => Eq_ (SelfKernel v) where 68 | (SelfKernel _ v1)==(SelfKernel _ v2) = v1==v2 69 | 70 | instance (KernelSpace v, POrd v) => POrd_ (SelfKernel v) where 71 | inf (SelfKernel _ v1) (SelfKernel _ v2) = mkSelfKernel $ inf v1 v2 72 | 73 | instance (KernelSpace v, Lattice_ v) => Lattice_ (SelfKernel v) where 74 | sup (SelfKernel _ v1) (SelfKernel _ v2) = mkSelfKernel $ sup v1 v2 75 | 76 | instance (KernelSpace v, Ord v) => Ord_ (SelfKernel v) 77 | 78 | instance KernelSpace v => KernelSpace (SelfKernel v) where 79 | kernel (SelfKernel _ v1) (SelfKernel _ v2) = kernel v1 v2 80 | selfKernel (SelfKernel k _) = k 81 | 82 | instance KernelSpace v => Normed (SelfKernel v) where 83 | size = kernelNorm 84 | 85 | instance KernelSpace v => Metric (SelfKernel v) where 86 | distance = kernelDistance 87 | 88 | -------------------------------------------------------------------------------- 89 | -- Graph 90 | 91 | -- type Graph = SelfKernel Graph_ 92 | 93 | data Graph = Graph 94 | { graph :: Graph_ 95 | , memo :: [Double] 96 | } 97 | deriving (Show) 98 | 99 | type instance Logic Graph = Bool 100 | 101 | instance Eq_ Graph where 102 | g1==g2 = graph g1==graph g2 103 | 104 | -- instance POrd_ Graph 105 | -- instance Lattice_ Graph 106 | -- instance Ord_ Graph 107 | 108 | instance NFData Graph where 109 | rnf (Graph g m) = deepseq g $ rnf m 110 | 111 | mkGraph :: Graph_ -> Graph 112 | mkGraph g_ = Graph 113 | { graph = g_ 114 | , memo = go lambdas one [] 115 | } 116 | where 117 | 118 | go [] tm' ret = ret 119 | -- go (0:xs) tm' ret = go xs tm' $ 0:ret 120 | -- go (x:xs) tm' ret = go xs tm'' $ ((startVec g_ <> (tm'' `mXv` stopVec g_))):ret 121 | go (0:xs) tm' ret = go xs tm' $ ret+ [0] 122 | go (x:xs) tm' ret = go xs tm'' $ ret + [(startVec g_ <> (tm'' `mXv` stopVec g_))] 123 | where 124 | tm'' = tm' * transitionMatrix g_ 125 | 126 | instance Metric Graph where 127 | distance (Graph _ xs) (Graph _ ys) = sqrt $ sum 128 | [ lambda * (x-y)**2 129 | | (lambda,x,y) <- zip3 lambdas xs ys 130 | ] 131 | 132 | type instance Scalar Graph = Double 133 | 134 | lambdas = [1,1/2,1/4,1/8,1/16] 135 | 136 | distance' g1 g2 = distance (graph g1) (graph g2) 137 | 138 | ----------------------------- 139 | 140 | data Graph_ = Graph_ 141 | { transitionMatrix :: !(Matrix Double) 142 | , startVec :: !(DynVector 0 Double) 143 | , stopVec :: !(DynVector 0 Double) 144 | } 145 | deriving (Show) 146 | 147 | type instance Logic Graph_ = Bool 148 | 149 | instance Eq_ Graph_ 
where 150 | g1==g2 = transitionMatrix g1==transitionMatrix g2 151 | &&startVec g1==startVec g2 152 | &&stopVec g1 ==stopVec g2 153 | 154 | type instance Scalar Graph_ = Double 155 | 156 | instance NFData Graph_ where 157 | -- rnf g = () 158 | rnf g = deepseq (transitionMatrix g) 159 | $ deepseq (startVec g) 160 | $ rnf (stopVec g) 161 | 162 | -- instance POrd_ Graph_ 163 | -- instance Lattice_ Graph_ where 164 | -- pcompare g1 g2 = pcompare (toVector $ transitionMatrix g1) (toVector $ transitionMatrix g2) 165 | -- 166 | -- instance Ord_ Graph_ 167 | 168 | productGraph_ :: Graph_ -> Graph_ -> Graph_ 169 | productGraph_ g1 g2 = Graph_ 170 | { transitionMatrix = transitionMatrix g1 >< transitionMatrix g2 171 | , startVec = toVector $ startVec g1 >< startVec g2 172 | , stopVec = toVector $ stopVec g1 >< stopVec g2 173 | } 174 | 175 | instance KernelSpace Graph_ where 176 | kernel = mkKernelGraph_ lambdas 177 | 178 | -- mkKernelGraph :: [Double] -> Graph -> Graph -> Double 179 | -- mkKernelGraph xs (SelfKernel _ g1) (SelfKernel _ g2) = mkKernelGraph_ xs g1 g2 180 | 181 | mkKernelGraph_ :: [Double] -> Graph_ -> Graph_ -> Double 182 | mkKernelGraph_ xs g1 g2 = go xs one 0 183 | where 184 | gprod = productGraph_ g1 g2 185 | 186 | go [] tm' ret = ret 187 | go (0:xs) tm' ret = go xs tm' $ ret 188 | go (x:xs) tm' ret = go xs tm'' $ ret + (x *. (startVec gprod <> (tm'' `mXv` stopVec gprod))) 189 | where 190 | tm'' = tm' * transitionMatrix gprod 191 | 192 | -- mkDistanceGraph :: [Double] -> Graph -> Graph -> Double 193 | -- mkDistanceGraph lambdas (SelfKernel _ g1) (SelfKernel _ g2) 194 | -- = sqrt $ sum 195 | -- -- [ mkKernelGraph_ lambdas g1 g1 196 | -- -- + mkKernelGraph_ lambdas g2 g2 197 | -- -- - mkKernelGraph_ lambdas g1 g2 198 | -- -- - mkKernelGraph_ lambdas g1 g2 199 | -- -- 200 | -- -- ] 201 | -- 202 | -- -- [ (lambdas!(i-1))* 203 | -- -- ( startVec g1' <> ((transitionMatrix g1'^^^i) `mXv` stopVec g1') 204 | -- -- + startVec g2' <> ((transitionMatrix g2'^^^i) `mXv` stopVec g2') 205 | -- -- - startVec gp <> ((transitionMatrix gp^^^i) `mXv` stopVec gp) 206 | -- -- - startVec gp <> ((transitionMatrix gp^^^i) `mXv` stopVec gp) 207 | -- -- ) 208 | -- -- 209 | -- -- | i <- [1..length lambdas::Int] 210 | -- -- ] 211 | -- -- where 212 | -- -- g1'=productGraph_ g1 g1 213 | -- -- g2'=productGraph_ g2 g2 214 | -- -- gp=productGraph_ g1 g2 215 | -- 216 | -- [ (lambdas!i)* 217 | -- ( startVec g1 <> ((transitionMatrix g1^^^(i+1)) `mXv` stopVec g1) 218 | -- - startVec g2 <> ((transitionMatrix g2^^^(i+1)) `mXv` stopVec g2) 219 | -- ) **2 220 | -- | i <- [0..length lambdas-1::Int] 221 | -- ] 222 | 223 | (^^^) :: Ring r => r -> Int -> r 224 | (^^^) r 0 = one 225 | (^^^) r 1 = r 226 | (^^^) r i = r*(r^^^(i-1)) 227 | 228 | mag :: Graph_ -> Double 229 | mag g = startVec g <> (transitionMatrix g `mXv` stopVec g) 230 | 231 | instance Metric Graph_ where 232 | distance = kernelDistance 233 | 234 | edgeList2UndirectedGraph :: Int -> [(Int,Int)] -> Graph 235 | edgeList2UndirectedGraph numVertices edgeList = edgeList2Graph numVertices $ symmetrize edgeList 236 | 237 | edgeList2Graph :: Int -> [(Int,Int)] -> Graph 238 | -- edgeList2Graph numVertices edgeList = mkSelfKernel $ Graph_ 239 | edgeList2Graph numVertices edgeList = mkGraph $ Graph_ 240 | { transitionMatrix = mat -- +one 241 | , startVec=VG.replicate numVertices $ 1/fromIntegral numVertices 242 | , stopVec=VG.replicate numVertices $ 1/fromIntegral numVertices 243 | } 244 | where 245 | mat = mkMatrix numVertices numVertices 246 | $ map (fromIntegral :: Int -> 
Double) 247 | $ edgeList2AdjList numVertices numVertices edgeList 248 | 249 | edgeList2AdjList :: Int -> Int -> [(Int,Int)] -> [Int] 250 | edgeList2AdjList w h xs = go 0 0 (sort xs) 251 | where 252 | go r c [] = replicate (w-c+(h-r-1)*w) 0 253 | go r c ((xr,xc):xs) = if r==xr && c==xc 254 | then 1:go r' c' xs 255 | else 0:go r' c' ((xr,xc):xs) 256 | where 257 | c' = (c+1) `mod` w 258 | r' = if c+1 ==c' then r else r+1 259 | 260 | -- go i j [] ret = (reverse ret) ++ replicate (w-i+(h-j-1)*w) 0 261 | -- go i j ((xi,xj):xs) ret = if i==xi && j==xj 262 | -- then go i' j' xs (1:ret) 263 | -- else go i' j' ((xi,xj):xs) (0:ret) 264 | -- where 265 | -- j'=(j+1) `mod` w 266 | -- i'=if (j+1)==j' then i else i+1 267 | 268 | symmetrize :: [(Int,Int)] -> [(Int,Int)] 269 | symmetrize xs = sort $ go xs [] 270 | where 271 | go [] ret = ret 272 | go ((i,j):xs) ret = if i==j 273 | then go xs $ (i,j):ret 274 | else go xs $ (i,j):(j,i):ret 275 | 276 | -------------------------------------------------------------------------------- 277 | -- file IO 278 | 279 | -- | helper for "loadDirectory" 280 | isFileTypePLG :: FilePath -> Bool 281 | isFileTypePLG filepath = take 4 (reverse filepath) == "glp." 282 | 283 | -- | helper for "loadDirectory" 284 | isNonemptyGraph :: Graph -> Bool 285 | isNonemptyGraph (Graph v _) = VG.length (startVec v) > 0 286 | 287 | -- | loads a file in the PLG data format into a graph 288 | -- 289 | -- See: www.bioinformatik.uni-frankfurt.de/tools/vplg/ for a description of the file format 290 | loadPLG 291 | :: Bool -- ^ print debug info? 292 | -> FilePath -- ^ path of SUBDUE graph file 293 | -> IO Graph 294 | loadPLG debug filepath = {-# SCC loadPLG #-} do 295 | filedata <- liftM lines $ readFile filepath 296 | 297 | let edgeList = mkEdgeList filedata 298 | numEdges = length edgeList 299 | numVertices = length $ mkVertexList filedata 300 | 301 | when debug $ do 302 | putStrLn $ filepath++"; num vertices = "++show numVertices++"; num edges = "++show numEdges 303 | 304 | let ret = edgeList2UndirectedGraph numVertices edgeList 305 | deepseq ret $ return ret 306 | 307 | where 308 | mkVertexList xs = filter (\x -> head x == "|") $ map words xs 309 | mkEdgeList xs = map (\["=",v1,"=",t,"=",v2] -> (read v1-1::Int,{-t,-}read v2-1::Int)) 310 | $ filter (\x -> head x == "=") 311 | $ map words xs 312 | 313 | 314 | -- | loads a file in the SUBDUE data format into a graph 315 | -- 316 | -- FIXME: not implemented 317 | loadSubdue 318 | :: Bool -- ^ print debug info? 
319 | -> FilePath -- ^ path of SUBDUE graph file 320 | -> IO Graph_ 321 | loadSubdue debug filepath = do 322 | file <- liftM lines $ readFile filepath 323 | let numVertices = getNumVertices file 324 | numEdges = getNumEdges file 325 | putStrLn $ "num vertices = " ++ show numVertices 326 | putStrLn $ "num edges = " ++ show numEdges 327 | undefined 328 | where 329 | getNumVertices xs = length $ filter (\x -> head x == "v") $ map words xs 330 | getNumEdges xs = length $ filter (\x -> head x == "u") $ map words xs 331 | -- hin <- openFile filepath ReadMode 332 | -- hClose hin 333 | -- undefined 334 | 335 | 336 | -------------------------------------------------------------------------------- 337 | -- tests 338 | 339 | g1 = edgeList2UndirectedGraph 2 $ [(0,0)] 340 | g2 = edgeList2UndirectedGraph 2 $ [(0,0),(0,1)] 341 | g3 = edgeList2UndirectedGraph 2 $ [(0,0),(1,1)] 342 | g4 = edgeList2UndirectedGraph 2 $ [(1,0)] 343 | h1 = edgeList2UndirectedGraph 3 $ [(0,0),(1,1),(2,2)] 344 | h2 = edgeList2UndirectedGraph 3 $ [(0,0),(0,1),(0,2)] 345 | h3 = edgeList2UndirectedGraph 3 $ [(0,0),(0,1),(0,2),(2,2)] 346 | h4 = edgeList2UndirectedGraph 3 $ [(1,1),(0,1),(0,2),(2,2)] 347 | 348 | a = edgeList2UndirectedGraph 3 $ [(0,0),(0,1),(0,2)] 349 | b = edgeList2UndirectedGraph 3 $ [(0,0),(0,2)] 350 | c = edgeList2UndirectedGraph 3 $ [(0,0),(0,2),(1,1),(1,2)] 351 | d = edgeList2UndirectedGraph 3 $ [(0,0),(0,2),(1,1)] 352 | e = edgeList2UndirectedGraph 3 $ [(2,2),(1,1)] 353 | f = edgeList2UndirectedGraph 3 $ [(0,0),(0,1),(0,2),(1,1),(1,2)] 354 | g = edgeList2UndirectedGraph 3 $ [(0,0),(0,1),(0,2),(1,1),(1,2),(2,2)] 355 | h = edgeList2UndirectedGraph 3 $ [(1,1)] 356 | i = edgeList2UndirectedGraph 3 $ [] 357 | 358 | -------------------------------------------------------------------------------- /src/HLearn/Data/Image.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | 3 | module HLearn.Data.Image 4 | ( CIELab 5 | , RGB 6 | , rgb2cielab 7 | , ColorSig 8 | , loadColorSig 9 | ) 10 | where 11 | 12 | import SubHask 13 | import SubHask.TemplateHaskell.Deriving 14 | 15 | import qualified Data.Vector as V 16 | import qualified Data.Vector.Storable as VS 17 | import qualified Data.Vector.Generic as VG 18 | 19 | import System.IO 20 | import System.IO.Unsafe 21 | 22 | import Foreign.C 23 | import Foreign.Ptr 24 | import Foreign.ForeignPtr 25 | import Foreign.Marshal.Array 26 | 27 | import qualified Prelude as P 28 | 29 | -------------------------------------------------------------------------------- 30 | 31 | -- | An alternative way to represent colors than RGB that is closer to how humans actually perceive color. 
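-- One consequence worth noting: plain Euclidean distance between CIELab
-- coordinates is the classic CIE76 color difference ("Delta E"), which is
-- exactly what the 'Metric' instance in this module computes.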
32 | -- 33 | -- See wikipedia on 34 | data CIELab a = CIELab 35 | { l :: !a 36 | , a :: !a 37 | , b :: !a 38 | } 39 | deriving (Show) 40 | 41 | type instance Scalar (CIELab a) = a 42 | type instance Logic (CIELab a) = Logic a 43 | 44 | instance ValidEq a => Eq_ (CIELab a) where 45 | c1==c2 = l c1 == l c2 46 | && a c1 == a c2 47 | && b c1 == b c2 48 | 49 | instance NFData a => NFData (CIELab a) where 50 | rnf c = deepseq (l c) 51 | $ deepseq (a c) 52 | $ rnf (b c) 53 | 54 | instance Storable a => Storable (CIELab a) where 55 | sizeOf _ = 3*sizeOf (undefined::a) 56 | alignment _ = alignment (undefined::a) 57 | 58 | peek p = do 59 | l <- peek $ plusPtr p $ 0*sizeOf (undefined::a) 60 | a <- peek $ plusPtr p $ 1*sizeOf (undefined::a) 61 | b <- peek $ plusPtr p $ 2*sizeOf (undefined::a) 62 | return $ CIELab l a b 63 | 64 | poke p (CIELab l a b) = do 65 | poke (plusPtr p $ 0*sizeOf (undefined::a)) l 66 | poke (plusPtr p $ 1*sizeOf (undefined::a)) a 67 | poke (plusPtr p $ 2*sizeOf (undefined::a)) b 68 | 69 | -- | Implements formulas taken from the opencv page: 70 | -- http://docs.opencv.org/modules/imgproc/doc/miscellaneous_transformations.html?highlight=cvtcolor 71 | -- 72 | -- FIXME: 73 | -- We should either: 74 | -- * implement all of the color differences 75 | -- * use the haskell package 76 | -- * use opencv 77 | -- 78 | rgb2cielab :: (ClassicalLogic a, Ord a, ExpField a) => RGB a -> CIELab a 79 | rgb2cielab (RGB r g b) = CIELab l_ a_ b_ 80 | where 81 | x = 0.412453*r + 0.357580*g + 0.180423*b 82 | y = 0.212671*r + 0.715160*g + 0.072169*b 83 | z = 0.019334*r + 0.119193*g + 0.950227*b 84 | 85 | x' = x / 0.950456 86 | z' = z / 1.088754 87 | 88 | l_ = if y > 0.008856 89 | then 116*y**(1/3)-16 90 | else 903.3*y 91 | a_ = 500*(f x' - f y ) + delta 92 | b_ = 200*(f y - f z') + delta 93 | 94 | f t = if t > 0.008856 95 | then t**(1/3) 96 | else 7.787*t+16/116 97 | 98 | delta=0 99 | 100 | instance Metric (CIELab Float) where 101 | {-# INLINABLE distance #-} 102 | distance c1 c2 = sqrt $ (l c1-l c2)*(l c1-l c2) 103 | + (a c1-a c2)*(a c1-a c2) 104 | + (b c1-b c2)*(b c1-b c2) 105 | 106 | -------------------------------------------------------------------------------- 107 | 108 | -- | The standard method for representing colors on most computer monitors and display formats. 
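-- To compare two colors the way the rest of this module does, convert to
-- 'CIELab' first; a small sketch using only functions defined here:
--
-- > distance (rgb2cielab (RGB 1 0 0 :: RGB Float)) (rgb2cielab (RGB 0 0 1))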
109 | data RGB a = RGB 110 | { red :: !a 111 | , green :: !a 112 | , blue :: !a 113 | } 114 | deriving (Show) 115 | 116 | type instance Scalar (RGB a) = a 117 | type instance Logic (RGB a) = Logic a 118 | 119 | instance ValidEq a => Eq_ (RGB a) where 120 | c1==c2 = red c1 == red c2 121 | && green c1 == green c2 122 | && blue c1 == blue c2 123 | 124 | instance NFData a => NFData (RGB a) where 125 | rnf c = deepseq (red c) 126 | $ deepseq (green c) 127 | $ rnf (blue c) 128 | 129 | instance Storable a => Storable (RGB a) where 130 | sizeOf _ = 3*sizeOf (undefined::a) 131 | alignment _ = alignment (undefined::a) 132 | 133 | peek p = do 134 | r <- peek $ plusPtr p $ 0*sizeOf (undefined::a) 135 | g <- peek $ plusPtr p $ 1*sizeOf (undefined::a) 136 | b <- peek $ plusPtr p $ 2*sizeOf (undefined::a) 137 | return $ RGB r g b 138 | 139 | poke p (RGB r g b) = do 140 | poke (plusPtr p $ 0*sizeOf (undefined::a)) r 141 | poke (plusPtr p $ 1*sizeOf (undefined::a)) g 142 | poke (plusPtr p $ 2*sizeOf (undefined::a)) b 143 | 144 | instance Metric (RGB Float) where 145 | distance c1 c2 = sqrt $ (red c1-red c2)*(red c1-red c2) 146 | + (green c1-green c2)*(green c1-green c2) 147 | + (blue c1-blue c2)*(blue c1-blue c2) 148 | 149 | -------------------------------------------------------------------------------- 150 | 151 | -- | A signature is sparse representation of a histogram. 152 | -- This is used to implement the earth mover distance between images. 153 | data ColorSig a = ColorSig 154 | { rgbV :: !(StorableArray (CIELab a)) 155 | , weightV :: !(StorableArray a) 156 | } 157 | deriving (Show) 158 | 159 | type instance Scalar (ColorSig a) = Scalar a 160 | type instance Logic (ColorSig a) = Logic a 161 | 162 | instance NFData a => NFData (ColorSig a) where 163 | rnf a = deepseq (rgbV a) 164 | $ rnf (weightV a) 165 | 166 | instance Eq_ (ColorSig Float) where 167 | sig1==sig2 = rgbV sig1 == rgbV sig2 168 | && weightV sig1 == weightV sig2 169 | 170 | loadColorSig :: 171 | ( 172 | ) => Bool -- ^ print debug info? 
173 | -> FilePath -- ^ path of signature file 174 | -> IO (ColorSig Float) 175 | loadColorSig debug filepath = {-# SCC loadColorSig #-} do 176 | filedata <- liftM P.lines $ readFile filepath 177 | 178 | let (rgbs,ws) = P.unzip 179 | $ map (\[b,g,r,v] -> (rgb2cielab $ RGB r g b, v)) 180 | $ map (read.(\x->"["+x+"]")) filedata 181 | 182 | let totalWeight = sum ws 183 | ws' = map (/totalWeight) ws 184 | 185 | let ret = ColorSig 186 | { rgbV = fromList rgbs 187 | , weightV = fromList ws' 188 | } 189 | 190 | when debug $ do 191 | putStrLn $ "filepath="++show filepath 192 | putStrLn $ " filedata="++show filedata 193 | putStrLn $ "signature length=" ++ show (length filedata) 194 | 195 | deepseq ret $ return ret 196 | 197 | instance Metric (ColorSig Float) where 198 | distance = emd_float 199 | distanceUB = lb2distanceUB emlb_float 200 | 201 | foreign import ccall unsafe "emd_float" emd_float_ 202 | :: Ptr Float -> Int -> Ptr Float -> Int -> Ptr Float -> IO Float 203 | 204 | {-# INLINABLE emd_float #-} 205 | emd_float :: ColorSig Float -> ColorSig Float -> Float 206 | emd_float (ColorSig rgbV1 (ArrayT weightV1)) (ColorSig rgbV2 (ArrayT weightV2)) 207 | = {-# SCC emd_float #-} unsafeDupablePerformIO $ 208 | withForeignPtr fp1 $ \p1 -> 209 | withForeignPtr fp2 $ \p2 -> 210 | withForeignPtr fpcost $ \pcost -> 211 | emd_float_ p1 n1 p2 n2 pcost 212 | 213 | where 214 | (fp1,n1) = VS.unsafeToForeignPtr0 weightV1 215 | (fp2,n2) = VS.unsafeToForeignPtr0 weightV2 216 | 217 | vcost = {-# SCC vcost #-} VS.generate (n1*n2) $ \i -> distance 218 | (rgbV1 `VG.unsafeIndex` (i`div`n2)) 219 | (rgbV2 `VG.unsafeIndex` (i`mod`n2)) 220 | 221 | (fpcost,_) = VS.unsafeToForeignPtr0 vcost 222 | 223 | emlb_float :: ColorSig Float -> ColorSig Float -> Float 224 | emlb_float sig1 sig2 = distance (centroid sig1) (centroid sig2) 225 | 226 | centroid :: ColorSig Float -> CIELab Float 227 | centroid (ColorSig rgbV weightV) = go (VG.length rgbV-1) (CIELab 0 0 0) 228 | where 229 | go (-1) tot = tot 230 | go i tot = go (i-1) $ CIELab 231 | { l = l tot + l (rgbV `VG.unsafeIndex` i) * (weightV `VG.unsafeIndex` i) 232 | , a = a tot + a (rgbV `VG.unsafeIndex` i) * (weightV `VG.unsafeIndex` i) 233 | , b = b tot + b (rgbV `VG.unsafeIndex` i) * (weightV `VG.unsafeIndex` i) 234 | } 235 | 236 | -------------------------------------------------------------------------------- /src/HLearn/Data/LoadData.hs: -------------------------------------------------------------------------------- 1 | -- | This module handles loading data from disk. 2 | module HLearn.Data.LoadData 3 | where 4 | 5 | import SubHask 6 | import SubHask.Algebra.Array 7 | import SubHask.Algebra.Container 8 | import SubHask.Algebra.Parallel 9 | import SubHask.Compatibility.ByteString 10 | import SubHask.Compatibility.Cassava 11 | import SubHask.Compatibility.Containers 12 | import SubHask.TemplateHaskell.Deriving 13 | 14 | import HLearn.History.Timing 15 | import HLearn.Models.Distributions 16 | 17 | import qualified Prelude as P 18 | import Prelude (asTypeOf,unzip,head,take,drop,zipWith) 19 | import Control.Monad.ST 20 | import qualified Data.List as L 21 | import Data.Maybe 22 | import System.Directory 23 | import System.IO 24 | 25 | -------------------------------------------------------------------------------- 26 | 27 | {- 28 | FIXME: 29 | This code was written a long time ago to assist with the Cover Tree ICML paper. 30 | It needs to be updated to use the new subhask interface. 31 | This should be an easy project. 
32 | 33 | -- | This loads files in the format used by the BagOfWords UCI dataset. 34 | -- See: https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/readme.txt 35 | loadBagOfWords :: FilePath -> IO (BArray (Map' Int Float)) 36 | loadBagOfWords filepath = do 37 | hin <- openFile filepath ReadMode 38 | numdp :: Int <- liftM read $ hGetLine hin 39 | numdim :: Int <- liftM read $ hGetLine hin 40 | numlines :: Int <- liftM read $ hGetLine hin 41 | 42 | ret <- VGM.replicate numdp zero 43 | forM [0..numlines-1] $ \i -> do 44 | line <- hGetLine hin 45 | let [dp,dim,val] :: [Int] = map read $ L.words line 46 | curdp <- VGM.read ret (dp-1) 47 | VGM.write ret (dp-1) $ insertAt dim (fromIntegral val) curdp 48 | 49 | hClose hin 50 | VG.unsafeFreeze ret 51 | -} 52 | 53 | -- | Loads a dataset of strings in the unix words file format (i.e. one word per line). 54 | -- This format is also used by the UCI Bag Of Words dataset. 55 | -- See: https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/readme.txt 56 | loadWords :: (Monoid dp, Elem dp~Char, Eq dp, Constructible dp) => FilePath -> IO (BArray dp) 57 | loadWords filepath = do 58 | hin <- openFile filepath ReadMode 59 | contents <- hGetContents hin 60 | return $ fromList $ map fromList $ L.lines contents 61 | 62 | -------------------------------------------------------------------------------- 63 | 64 | -- | Returns all files in a subdirectory (and all descendant directories). 65 | -- Unlike "getDirectoryContents", this function prepends the directory's path to each filename. 66 | -- This is important so that we can tell where in the hierarchy the file is located. 67 | -- 68 | -- FIXME: 69 | -- This is relatively untested. 70 | -- It probably has bugs related to weird symbolic links. 71 | getDirectoryContentsRecursive :: FilePath -> IO [FilePath] 72 | getDirectoryContentsRecursive = fmap toList . go 73 | where 74 | go :: FilePath -> IO (Seq FilePath) 75 | go dirpath = do 76 | files <- getDirectoryContents dirpath 77 | fmap concat $ forM files $ \file -> case file of 78 | '.':_ -> return empty 79 | _ -> do 80 | let file' = dirpath++"/"++file 81 | isdir <- doesDirectoryExist file' 82 | contents <- if isdir 83 | then go file' 84 | else return empty 85 | return $ file' `cons` contents 86 | 87 | -- | A generic method for loading data points. 88 | -- Each file in a directory hierarchy corresponds to a single data point. 89 | -- 90 | -- The label assigned to the data point is simply the name of the file. 91 | -- This means each data point will have a distinct label. 
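-- As a concrete sketch, the PLG graph loader from "HLearn.Data.Graph" plugs
-- directly into this interface (the directory name here is hypothetical):
--
-- > gs <- loadDirectory (Just 100) (loadPLG False) isFileTypePLG isNonemptyGraph "datasets/plg"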
92 | -- For typical supervised learning tasks, you will want to prune the 93 | loadDirectory :: 94 | ( Eq a 95 | , NFData a 96 | ) => Maybe Int -- ^ maximum number of datapoints to load; Nothing for unlimitted 97 | -> (FilePath -> IO a) -- ^ function to load an individual file 98 | -> (FilePath -> Bool) -- ^ function to filter out invalid filenames 99 | -> (a -> Bool) -- ^ function to filter out malformed results 100 | -> FilePath -- ^ directory to load data from 101 | -> IO (BArray (Labeled' a FilePath)) -- ^ 102 | loadDirectory numdp loadFile validFilepath validResult dirpath = {-# SCC loadDirectory #-} do 103 | 104 | files <- timeIO "getDirectoryContentsRecursive" $ do 105 | xs <- getDirectoryContentsRecursive dirpath 106 | 107 | let takedp = case numdp of 108 | Nothing -> id 109 | Just n -> fmap (L.take n) 110 | 111 | return $ takedp $ L.filter validFilepath xs 112 | 113 | results <- timeIO "loadDirectory" $ do 114 | xs <- forM files $ \filepath -> do 115 | res <- loadFile filepath 116 | return $ Labeled' res filepath 117 | return $ L.filter (validResult . xLabeled') xs 118 | 119 | putStrLn $ " numdp: " ++ show (length files) 120 | 121 | return $ fromList results 122 | 123 | -- | Load a CSV file containing numeric attributes. 124 | {-# INLINABLE loadCSV #-} 125 | loadCSV :: 126 | ( NFData a 127 | , FromRecord a 128 | , FiniteModule a 129 | , Eq a 130 | , Show (Scalar a) 131 | ) => FilePath -> IO (BArray a) 132 | loadCSV filepath = do 133 | 134 | bs <- timeIO ("loading ["++filepath++"]") $ readFileByteString filepath 135 | 136 | let rse = decode NoHeader bs 137 | time "parsing csv file" rse 138 | 139 | rs <- case rse of 140 | Right rs -> return rs 141 | Left str -> error $ "failed to parse CSV file " ++ filepath ++ ": " ++ L.take 1000 str 142 | 143 | putStrLn " dataset info:" 144 | putStrLn $ " num dp: " ++ show ( size rs ) 145 | putStrLn $ " numdim: " ++ show ( dim $ rs!0 ) 146 | putStrLn "" 147 | 148 | return rs 149 | 150 | -- | FIXME: this should be combined with the CSV function above 151 | loadCSVLabeled' :: 152 | ( NFData x 153 | , FromRecord x 154 | , FiniteModule x 155 | , Eq x 156 | , Show (Scalar x) 157 | , Read (Scalar x) 158 | ) => Int -- ^ column of csv file containing the label 159 | -> FilePath -- ^ path to csv file 160 | -> IO (BArray (Labeled' x (Lexical String))) 161 | loadCSVLabeled' col filepath = do 162 | 163 | bs <- timeIO ("loading ["++filepath++"]") $ readFileByteString filepath 164 | 165 | let rse = decode NoHeader bs 166 | time "parsing csv file" rse 167 | 168 | rs :: BArray (BArray String) <- case rse of 169 | Right rs -> return rs 170 | Left str -> error $ "failed to parse CSV file " ++ filepath ++ ": " ++ L.take 1000 str 171 | 172 | let ret = fromList $ map go $ toList rs 173 | 174 | putStrLn " dataset info:" 175 | putStrLn $ " num dp: " ++ show ( size ret ) 176 | putStrLn $ " numdim: " ++ show ( dim $ xLabeled' $ ret!0 ) 177 | putStrLn "" 178 | 179 | return ret 180 | 181 | where 182 | go arr = Labeled' x y 183 | where 184 | y = Lexical $ arr!col 185 | x = unsafeToModule $ map read $ take (col) arrlist ++ drop (col+1) arrlist 186 | 187 | arrlist = toList arr 188 | 189 | ------------------------------------------------------------------------------- 190 | -- data preprocessing 191 | -- 192 | -- FIXME: 193 | -- Find a better location for all this code. 194 | 195 | -- | Uses an efficient 1-pass algorithm to calculate the mean variance. 196 | -- This is much faster than the 2-pass algorithms on large datasets, 197 | -- but has (slightly) worse numeric stability. 
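-- Concretely, the fold below maintains a running mean m_k and a running sum
-- of squared deviations q_k using the standard one-pass (Welford-style)
-- updates
--
-- > m_k = m_{k-1} + (x_k - m_{k-1}) / k
-- > q_k = q_{k-1} + (k-1) * (x_k - m_{k-1})^2 / k
--
-- and finally reports the unbiased variance q_n / (n-1).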
198 | -- 199 | -- See http://www.cs.berkeley.edu/~mhoemmen/cs194/Tutorials/variance.pdf for details. 200 | {-# INLINE meanAndVarianceInOnePass #-} 201 | meanAndVarianceInOnePass :: (Foldable xs, Field (Elem xs)) => xs -> (Elem xs, Elem xs) 202 | meanAndVarianceInOnePass ys = 203 | {-# SCC meanAndVarianceInOnePass #-} 204 | case uncons ys of 205 | Nothing -> error "meanAndVarianceInOnePass on empty container" 206 | Just (x,xs) -> (\(k,m,v) -> (m,v/(k-1))) $ foldl' go (2,x,0) xs 207 | where 208 | go (k,mk,qk) x = (k+1,mk',qk') 209 | where 210 | mk'=mk+(x-mk)/k 211 | qk'=qk+(k-1)*(x-mk)*(x-mk)/k 212 | 213 | -- | A wrapper around "meanAndVarianceInOnePass" 214 | {-# INLINE varianceInOnePass #-} 215 | varianceInOnePass :: (Foldable xs, Field (Elem xs)) => xs -> Elem xs 216 | varianceInOnePass = snd . meanAndVarianceInOnePass 217 | 218 | -- | Calculate the variance of each column, then sort so that the highest variance is first. 219 | -- This can be useful for preprocessing data. 220 | -- 221 | -- NOTE: 222 | -- The git history has a lot of versions of this function with different levels of efficiency. 223 | -- I need to write a blog post about how all the subtle haskellisms effect the runtime. 224 | {-# INLINABLE mkShuffleMap #-} 225 | mkShuffleMap :: forall v. 226 | ( FiniteModule v 227 | , VectorSpace v 228 | , Unboxable v 229 | , Unboxable (Scalar v) 230 | , Eq v 231 | , Elem (SetElem v (Elem v)) ~ Elem v 232 | , Elem (SetElem v (Scalar (Elem v))) ~ Scalar (Elem v) 233 | , IxContainer (SetElem v (Elem v)) 234 | ) => BArray v -> UArray Int 235 | mkShuffleMap vs = {-# SCC mkShuffleMap #-} if size vs==0 236 | then error "mkShuffleMap: called on empty array" 237 | else runST ( do 238 | -- FIXME: 239 | -- @smalldata@ should be a random subsample of the data. 240 | -- The size should also depend on the dimension. 241 | let smalldata = P.take 1000 $ toList vs 242 | 243 | -- let variances = fromList 244 | -- $ values 245 | -- $ varianceInOnePass 246 | -- $ VG.map Componentwise vs 247 | -- :: BArray (Scalar v) 248 | 249 | let variances 250 | = imap (\i _ -> varianceInOnePass $ (imap (\_ -> (!i)) smalldata)) 251 | $ values 252 | $ vs!0 253 | :: [Scalar v] 254 | 255 | return 256 | $ fromList 257 | $ map fst 258 | $ L.sortBy (\(_,v1) (_,v2) -> compare v2 v1) 259 | $ imap (,) 260 | $ variances 261 | ) 262 | 263 | -- | apply the shufflemap to the data set to get a better ordering of the data 264 | {-# INLINABLE apShuffleMap #-} 265 | apShuffleMap :: forall v. FiniteModule v => UArray Int -> v -> v 266 | apShuffleMap vmap v = unsafeToModule xs 267 | where 268 | xs :: [Scalar v] 269 | xs = generate1 (size vmap) $ \i -> v!(vmap!i) 270 | 271 | {-# INLINABLE generate1 #-} 272 | generate1 :: (Monoid v, Constructible v) => Int -> (Int -> Elem v) -> v 273 | generate1 n f = if n <= 0 274 | then zero 275 | else fromList1N n (f 0) (map f [1..n-1]) 276 | 277 | {- 278 | FIXME: 279 | All this needs to be reimplemented using the subhask interface. 280 | This requires fixing some of the features of subhask's linear algebra system. 
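A note on the linear algebra in rotatePCA below: after mean-centering, the
code builds the Gram matrix G = sum_i x_i x_i^T, whose eigenvectors (computed
with eigSH) are the principal components; multiplying each point by the
transposed eigenvector matrix rewrites it in the PCA basis.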
281 | 282 | -- | translate a dataset so the mean is zero 283 | {-# INLINABLE meanCenter #-} 284 | meanCenter :: 285 | ( VG.Vector v1 (v2 a) 286 | , VG.Vector v2 a 287 | , Real a 288 | ) => v1 (v2 a) -> v1 (v2 a) 289 | meanCenter dps = {-# SCC meanCenter #-} VG.map (\v -> VG.zipWith (-) v meanV) dps 290 | where 291 | meanV = {-# SCC meanV #-} VG.map (/ fromIntegral (VG.length dps)) $ VG.foldl1' (VG.zipWith (+)) dps 292 | 293 | -- | rotates the data using the PCA transform 294 | {-# INLINABLE rotatePCA #-} 295 | rotatePCA :: 296 | ( VG.Vector container dp 297 | , VG.Vector container [Float] 298 | , VG.Vector v a 299 | , dp ~ v a 300 | , Show a 301 | , a ~ Float 302 | ) => container dp -> container dp 303 | rotatePCA dps' = {-# SCC rotatePCA #-} VG.map rotate dps 304 | where 305 | -- rotate dp = VG.convert $ LA.single $ eigm LA.<> LA.double (VG.convert dp :: VS.Vector Float) 306 | rotate dp = {-# SCC convert #-} VG.convert $ LA.single $ (LA.trans eigm) LA.<> LA.double (VG.convert dp :: VS.Vector Float) 307 | dps = meanCenter dps' 308 | 309 | (eigv,eigm) = {-# SCC eigSH #-} LA.eigSH $ LA.double gramMatrix 310 | 311 | -- gramMatrix = {-# SCC gramMatrix #-} gramMatrix_ $ map VG.convert $ VG.toList dps 312 | -- gramMatrix = {-# SCC gramMatrix #-} LA.trans tmpm LA.<> tmpm 313 | -- where 314 | -- tmpm = LA.fromLists (VG.toList $ VG.map VG.toList dps) 315 | 316 | gramMatrix = {-# SCC gramMatrix #-} foldl1' (P.+) 317 | [ let dp' = VG.convert dp in LA.asColumn dp' LA.<> LA.asRow dp' | dp <- VG.toList dps ] 318 | 319 | gramMatrix_ :: (Ring a, Storable a) => [VS.Vector a] -> LA.Matrix a 320 | gramMatrix_ xs = runST ( do 321 | let dim = VG.length (head xs) 322 | m <- LA.newMatrix 0 dim dim 323 | 324 | forM_ xs $ \x -> do 325 | forM_ [0..dim-1] $ \i -> do 326 | forM_ [0..dim-1] $ \j -> do 327 | mij <- LA.unsafeReadMatrix m i j 328 | LA.unsafeWriteMatrix m i j $ mij + (x `VG.unsafeIndex` i)*(x `VG.unsafeIndex` j) 329 | 330 | LA.unsafeFreezeMatrix m 331 | ) 332 | 333 | 334 | {-# INLINABLE rotatePCADouble #-} 335 | -- | rotates the data using the PCA transform 336 | rotatePCADouble :: 337 | ( VG.Vector container (v Double) 338 | , VG.Vector container [Double] 339 | , VG.Vector v Double 340 | ) => container (v Double) -> container (v Double) 341 | rotatePCADouble dps' = VG.map rotate dps 342 | where 343 | rotate dp = VG.convert $ (LA.trans eigm) LA.<> (VG.convert dp :: VS.Vector Double) 344 | dps = meanCenter dps' 345 | 346 | (eigv,eigm) = LA.eigSH gramMatrix 347 | 348 | gramMatrix = LA.trans tmpm LA.<> tmpm 349 | where 350 | tmpm = LA.fromLists (VG.toList $ VG.map VG.toList dps) 351 | -} 352 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree.hs: -------------------------------------------------------------------------------- 1 | -- | The space tree abstraction was pioneered by the MLPack library. 2 | -- It provides a generic interface for all tree structures that support nearest neighbor queries. 3 | -- The MLPack paper gives full details. 4 | -- 5 | -- FIXME: 6 | -- Should this interface be incorporated into subhask's "Container" class hierarchy? 7 | -- 8 | -- FIXME: 9 | -- There is a close relation between the pruning folds described in the MLPack paper and the theory of ana/catamorphisms that haskellers love. 10 | -- Making this explicit would require some serious work, but would (maybe) provide an even simpler interface. 11 | module HLearn.Data.SpaceTree 12 | ( 13 | -- * Type classes 14 | SpaceTree (..)
15 | 16 | -- * Generic algorithms 17 | , stToList 18 | -- , stToListW 19 | , stHasNoChildren 20 | , stChildrenList 21 | , stDescendents 22 | , stNumDp 23 | , stNumNodes 24 | , stNumLeaves 25 | , stNumGhosts 26 | , stAveGhostChildren 27 | , stNumGhostSingletons 28 | , stNumGhostLeaves 29 | , stNumGhostSelfparent 30 | , stMaxLeaves 31 | , stAveLeaves 32 | , stMaxChildren 33 | , stAveChildren 34 | , stMaxDepth 35 | , stNumSingletons 36 | , stExtraLeaves 37 | ) 38 | where 39 | 40 | import qualified Data.List as L 41 | 42 | import Prelude (map) 43 | import SubHask 44 | import SubHask.Algebra.Array 45 | import SubHask.Monad 46 | import SubHask.Compatibility.Containers 47 | 48 | import HLearn.Models.Distributions 49 | 50 | ------------------------------------------------------------------------------- 51 | -- SpaceTree 52 | 53 | class 54 | ( Metric dp 55 | , Bounded (Scalar dp) 56 | , Logic (t dp) ~ Bool 57 | , Logic dp ~ Bool 58 | , Logic (LeafContainer t dp) ~ Bool 59 | , Logic (LeafContainer t (t dp)) ~ Bool 60 | , Logic (ChildContainer t (t dp)) ~ Bool 61 | , Eq_ (t dp) 62 | , Eq_ (ChildContainer t (t dp)) 63 | -- , Scalar dp ~ Scalar (t dp) 64 | , Elem (ChildContainer t (t dp)) ~ t dp 65 | , Elem (LeafContainer t dp) ~ dp 66 | , Constructible (LeafContainer t dp) 67 | , Constructible (ChildContainer t (t dp)) 68 | , Normed (LeafContainer t dp) 69 | , Normed (ChildContainer t (t dp)) 70 | , Foldable (LeafContainer t dp) 71 | , Foldable (ChildContainer t (t dp)) 72 | , dp ~ Elem (t dp) 73 | ) => SpaceTree t dp 74 | where 75 | 76 | type LeafContainer t :: * -> * 77 | type LeafContainer t = BArray 78 | 79 | type ChildContainer t :: * -> * 80 | type ChildContainer t = BArray 81 | 82 | stChildren :: t dp -> ChildContainer t (t dp) 83 | stLeaves :: t dp -> LeafContainer t dp 84 | 85 | stNode :: t dp -> dp 86 | stWeight :: t dp -> Scalar dp 87 | 88 | {-# INLINE stNodeW #-} 89 | stNodeW :: t dp -> Labeled' dp (Scalar dp) 90 | stNodeW t = Labeled' (stNode t) (stWeight t) 91 | 92 | {-# INLINABLE stMaxDescendentDistance #-} 93 | stMaxDescendentDistance :: t dp -> Scalar dp 94 | stMaxDescendentDistance t = maxBound 95 | 96 | ------------------------------------------------------------------------------- 97 | 98 | instance 99 | ( Metric dp 100 | , Logic dp~Bool 101 | , Bounded (Scalar dp) 102 | ) => SpaceTree BArray dp 103 | where 104 | 105 | type LeafContainer BArray = BArray 106 | type ChildContainer BArray = BArray 107 | 108 | stChildren _ = empty 109 | stLeaves = id 110 | stNode = (!0) 111 | stWeight = 1 112 | 113 | ------------------------------------------------------------------------------- 114 | -- generic algorithms 115 | 116 | {-# INLINABLE stToSeqDFS #-} 117 | stToSeqDFS :: SpaceTree t dp => t dp -> Seq dp 118 | stToSeqDFS t 119 | = stNode t `cons` (fromList $ toList $ stLeaves t) 120 | + (foldl' (+) empty $ map stToSeqDFS $ stChildrenList t) 121 | 122 | {-# INLINABLE stToSeqBFS #-} 123 | stToSeqBFS :: SpaceTree t dp => t dp -> Seq dp 124 | stToSeqBFS t = stNode t `cons` go t 125 | where 126 | go t = (fromList $ toList $ stLeaves t) 127 | + (fromList $ map stNode $ toList $ stChildren t) 128 | + (foldl' (+) empty $ map go $ stChildrenList t) 129 | 130 | {-# INLINABLE stToList #-} 131 | stToList :: SpaceTree t dp => t dp -> [dp] 132 | stToList = toList . 
stToSeqDFS 133 | 134 | -- {-# INLINABLE stToList #-} 135 | -- stToList :: (Eq dp, SpaceTree t dp) => t dp -> [dp] 136 | -- stToList t = if stHasNoChildren t && stWeight t > 0 137 | -- then (stNode t):(toList $ stLeaves t) 138 | -- else go (concat $ map stToList $ stChildrenList t) 139 | -- where 140 | -- go xs = if stWeight t > 0 141 | -- then (stNode t) : (toList (stLeaves t) ++ xs) 142 | -- else toList (stLeaves t) ++ xs 143 | 144 | -- {-# INLINABLE stToListW 145 | -- stToListW :: (Eq dp, SpaceTree t dp) => t dp -> [Weighted dp] 146 | -- stToListW t = if stHasNoChildren t && stWeight t > 0 147 | -- then [(stWeight t,stNode t)] 148 | -- else go (concat $ map stToListW $ stChildrenList t) 149 | -- where 150 | -- go xs = if stWeight t > 0 151 | -- then (stWeight t,stNode t) : xs 152 | -- else xs 153 | 154 | -- {-# INLINABLE toTagList #-} 155 | -- toTagList :: (Eq dp, SpaceTree (t tag) dp, Taggable t dp) => t tag dp -> [(dp,tag)] 156 | -- toTagList t = if stHasNoChildren t 157 | -- then [(stNode t,getTag t)] 158 | -- else go (concat $ map toTagList $ stChildrenList t) 159 | -- where 160 | -- go xs = if stNode t `elem` (map stNode $ stChildrenList t) 161 | -- then xs 162 | -- else (stNode t,getTag t) : xs 163 | 164 | {-# INLINABLE stHasNoChildren #-} 165 | stHasNoChildren :: SpaceTree t dp => t dp -> Bool 166 | stHasNoChildren t = isEmpty $ stChildren t 167 | 168 | {-# INLINABLE stChildrenList #-} 169 | stChildrenList :: SpaceTree t dp => t dp -> [t dp] 170 | stChildrenList t = toList $ stChildren t 171 | 172 | {-# INLINABLE stDescendents #-} 173 | stDescendents :: SpaceTree t dp => t dp -> [dp] 174 | stDescendents t = case tailMaybe $ go t of 175 | Just xs -> xs 176 | where 177 | go t = stNode t : L.concatMap go (stChildrenList t) ++ toList (stLeaves t) 178 | 179 | -- stDescendents t = if stHasNoChildren t 180 | -- then [stNode t] 181 | -- else L.concatMap stDescendents (stChildrenList t) ++ toList (stLeaves t) 182 | 183 | {-# INLINABLE stNumDp #-} 184 | stNumDp :: SpaceTree t dp => t dp -> Scalar dp 185 | stNumDp t = if stHasNoChildren t 186 | then stWeight t 187 | else stWeight t + sum (map stNumDp $ stChildrenList t) 188 | 189 | {-# INLINABLE stNumNodes #-} 190 | stNumNodes :: SpaceTree t dp => t dp -> Int 191 | stNumNodes t = if stHasNoChildren t 192 | then 1 193 | else 1 + sum (map stNumNodes $ stChildrenList t) 194 | 195 | {-# INLINABLE stNumLeaves #-} 196 | stNumLeaves :: 197 | ( Integral (Scalar (LeafContainer t dp)) 198 | , SpaceTree t dp 199 | ) => t dp -> Scalar dp 200 | stNumLeaves t = (fromIntegral $ size (stLeaves t)) + sum (map stNumLeaves $ toList $ stChildren t) 201 | 202 | {-# INLINABLE stNumGhosts #-} 203 | stNumGhosts :: SpaceTree t dp => t dp -> Int 204 | stNumGhosts t = (if stWeight t == 0 then 1 else 0) + if stHasNoChildren t 205 | then 0 206 | else sum (map stNumGhosts $ stChildrenList t) 207 | 208 | {-# INLINABLE stAveGhostChildren #-} 209 | stAveGhostChildren :: SpaceTree t dp => t dp -> Normal Double 210 | stAveGhostChildren t = 211 | ( if stWeight t == 0 212 | then train1Normal . fromIntegral . size $ stChildrenList t 213 | else zero 214 | ) 215 | + 216 | ( if stHasNoChildren t 217 | then zero 218 | else (reduce . 
map stAveGhostChildren $ stChildrenList t) 219 | ) 220 | 221 | {-# INLINABLE stMaxLeaves #-} 222 | stMaxLeaves :: SpaceTree t dp => t dp -> Int 223 | stMaxLeaves t = maximum $ (size $ toList $ stLeaves t):(map stMaxLeaves $ stChildrenList t) 224 | 225 | {-# INLINABLE stAveLeaves #-} 226 | stAveLeaves :: SpaceTree t dp => t dp -> Normal Double 227 | stAveLeaves t = (train1Normal . fromIntegral . size . toList $ stLeaves t) 228 | + (reduce . map stAveLeaves $ stChildrenList t) 229 | 230 | {-# INLINABLE stMaxChildren #-} 231 | stMaxChildren :: SpaceTree t dp => t dp -> Int 232 | stMaxChildren t = if stHasNoChildren t 233 | then 0 234 | else maximum $ (size $ stChildrenList t):(map stMaxChildren $ stChildrenList t) 235 | 236 | {-# INLINABLE stAveChildren #-} 237 | stAveChildren :: SpaceTree t dp => t dp -> Normal Double 238 | stAveChildren t = if stHasNoChildren t 239 | then zero 240 | else (train1Normal . fromIntegral . size $ stChildrenList t) 241 | + (reduce . map stAveChildren $ stChildrenList t) 242 | 243 | {-# INLINABLE stMaxDepth #-} 244 | stMaxDepth :: SpaceTree t dp => t dp -> Int 245 | stMaxDepth t = if stHasNoChildren t 246 | then 1 247 | else 1+maximum (map stMaxDepth $ stChildrenList t) 248 | 249 | {-# INLINABLE stNumSingletons #-} 250 | stNumSingletons :: SpaceTree t dp => t dp -> Int 251 | stNumSingletons t = if stHasNoChildren t 252 | then 0 253 | else sum (map stNumSingletons $ stChildrenList t) + if size (stChildrenList t) == 1 254 | then 1 255 | else 0 256 | 257 | {-# INLINABLE stNumGhostSingletons #-} 258 | stNumGhostSingletons :: SpaceTree t dp => t dp -> Int 259 | stNumGhostSingletons t = if stHasNoChildren t 260 | then 0 261 | else sum (map stNumGhostSingletons $ stChildrenList t) 262 | + if size (stChildrenList t) == 1 && stWeight t==0 263 | then 1 264 | else 0 265 | 266 | {-# INLINABLE stNumGhostLeaves #-} 267 | stNumGhostLeaves :: SpaceTree t dp => t dp -> Int 268 | stNumGhostLeaves t = if stHasNoChildren t 269 | then if stWeight t==0 270 | then 1 271 | else 0 272 | else sum (map stNumGhostLeaves $ stChildrenList t) 273 | 274 | {-# INLINABLE stNumGhostSelfparent #-} 275 | stNumGhostSelfparent :: (Eq dp, SpaceTree t dp) => t dp -> Int 276 | stNumGhostSelfparent t = if stHasNoChildren t 277 | then 0 278 | else sum (map stNumGhostSelfparent $ stChildrenList t) 279 | + if stWeight t==0 && stNode t `elem` map stNode (stChildrenList t) 280 | then 1 281 | else 0 282 | 283 | {-# INLINABLE stExtraLeaves #-} 284 | stExtraLeaves :: (Eq dp, SpaceTree t dp) => t dp -> Int 285 | stExtraLeaves t = if stHasNoChildren t 286 | then 0 287 | else sum (map stExtraLeaves $ stChildrenList t) 288 | + if supremum $ map (\c -> stNode c==stNode t && stHasNoChildren c) $ stChildrenList t 289 | then 1 290 | else 0 291 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Algorithms.hs: -------------------------------------------------------------------------------- 1 | -- | This module contains algorithms for efficient operations over space trees. 2 | -- Currently, only nearest neighbor queries are implemented. 3 | -- It would be easy to implement more query types, however. 4 | -- If there is another query type you want supported, ask me and I'll implement it for you. 5 | -- 6 | -- The paper gives full details on possible queries. 7 | module HLearn.Data.SpaceTree.Algorithms 8 | ( 9 | 10 | Neighbor (..) 
11 | 12 | , findNeighbor 13 | , findNeighbor_NoSort 14 | ) 15 | where 16 | 17 | import GHC.Exts (inline) 18 | import Data.List (sortBy) 19 | 20 | import SubHask 21 | import SubHask.Algebra.Array 22 | import SubHask.Algebra.Container 23 | import SubHask.Algebra.Vector 24 | import SubHask.Compatibility.Containers 25 | import SubHask.Monad 26 | import SubHask.TemplateHaskell.Deriving 27 | 28 | import HLearn.Data.SpaceTree 29 | 30 | ------------------------------------------------------------------------------- 31 | 32 | data Neighbor dp = Neighbor 33 | { neighbor :: !dp 34 | , neighborDistance :: !(Scalar dp) 35 | } 36 | 37 | deriving instance (Show dp, Show (Scalar dp)) => Show (Neighbor dp) 38 | 39 | instance (NFData dp, NFData (Scalar dp)) => NFData (Neighbor dp) where 40 | rnf (Neighbor _ _) = () 41 | 42 | type instance Logic (Neighbor dp) = Bool 43 | 44 | instance (Eq dp, Eq (Scalar dp)) => Eq_ (Neighbor dp) where 45 | (Neighbor dp1 dist1)==(Neighbor dp2 dist2) = dist1==dist2 && dp1==dp2 46 | 47 | ---------------------------------------- 48 | 49 | -- | Find the nearest neighbor of a node. 50 | -- 51 | -- NOTE: 52 | -- If we remove the call to "inline" on "foldr'", 53 | -- GHC 7.10 will pass dictionaries and everything becomes very slow. 54 | {-# INLINE findNeighbor #-} 55 | findNeighbor :: 56 | ( SpaceTree t dp 57 | , Bounded (Scalar dp) 58 | ) => Scalar dp -> t dp -> dp -> Neighbor dp 59 | findNeighbor ε t q = 60 | {-# SCC findNeighbor #-} 61 | go (Labeled' t startdist) startnode 62 | where 63 | startnode = if startdist == 0 64 | then Neighbor q maxBound 65 | else Neighbor (stNode t) startdist 66 | 67 | startdist = distance (stNode t) q 68 | 69 | go (Labeled' t dist) (Neighbor n distn) = if dist*ε > maxdist 70 | then Neighbor n distn 71 | else inline foldr' go leafres 72 | $ sortBy (\(Labeled' _ d1) (Labeled' _ d2) -> compare d2 d1) 73 | $ map (\t' -> Labeled' t' (distanceUB q (stNode t') (distnleaf+stMaxDescendentDistance t))) 74 | $ toList 75 | $ stChildren t 76 | where 77 | leafres@(Neighbor _ distnleaf) = inline foldr' 78 | (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n) 79 | (cata (stNode t) dist (Neighbor n distn)) 80 | (stLeaves t) 81 | 82 | maxdist = distn+stMaxDescendentDistance t 83 | 84 | cata !dp !dist (Neighbor n distn) = 85 | if dist==0 || dist>distn 86 | then Neighbor n distn 87 | else Neighbor dp dist 88 | 89 | ---------------------------------------- 90 | 91 | -- | Find the nearest neighbor of a node. 92 | -- Internally, this function does not sort the distances of the children before descending. 93 | -- In some (rare) cases this reduces the number of distance comparisons. 
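--
-- Both search variants in this module take the approximation ratio @ε@ as
-- their first argument. A minimal usage sketch (hypothetical @ct@ and @q@
-- values of matching point type):
--
-- > findNeighbor_NoSort 1   ct q  -- ε=1: a subtree is pruned only when it
-- >                               -- provably contains no closer point, so
-- >                               -- the search is exact
-- > findNeighbor_NoSort 1.2 ct q  -- ε>1: prunes more aggressively, trading
-- >                               -- exactness for speed; the intent is a
-- >                               -- neighbor at most about 1.2x farther
-- >                               -- than the true nearest neighbor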
94 | {-# INLINE findNeighbor_NoSort #-}
95 | findNeighbor_NoSort ::
96 |     ( SpaceTree t dp
97 |     , Bounded (Scalar dp)
98 |     ) => Scalar dp -> t dp -> dp -> Neighbor dp
99 | findNeighbor_NoSort ε t q =
100 |     {-# SCC findNeighbor_NoSort #-}
101 |     go t (Neighbor q maxBound)
102 |     where
103 |         go t res@(Neighbor _ distn) = if dist*ε > maxdist
104 |             then res
105 |             else inline foldr' go leafres $ stChildren t
106 |             where
107 |                 leafres = inline foldr'
108 |                     (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n)
109 |                     (cata (stNode t) dist res)
110 |                     (stLeaves t)
111 |
112 |                 dist = distanceUB q (stNode t) maxdist
113 |                 maxdist = distn+stMaxDescendentDistance t
114 |
115 |         cata !dp !dist (Neighbor n distn) =
116 |             if dist==0 || dist>distn
117 |                 then Neighbor n distn
118 |                 else Neighbor dp dist
119 | --------------------------------------------------------------------------------
/src/HLearn/Data/SpaceTree/Algorithms/Correlation.hs:
--------------------------------------------------------------------------------
1 | module HLearn.DataStructures.SpaceTree.Algorithms.Correlation
2 |     where
3 |
4 | import Debug.Trace
5 |
6 | import Control.DeepSeq
7 | import qualified Data.Map.Strict as Map
8 | import qualified Data.Set as Set
9 | import qualified Data.Strict.Either as Strict
10 | import qualified Data.Strict.Maybe as Strict
11 | import qualified Data.Strict.Tuple as Strict
12 | import GHC.TypeLits
13 |
14 | import HLearn.Algebra
15 | import HLearn.DataStructures.SpaceTree
16 |
17 | -------------------------------------------------------------------------------
18 | -- Correlation
19 |
20 | data Correlation dp = Correlation
21 |     { unCorrelation :: !(Ring dp)
22 |     , range :: !(Ring dp)
23 |     }
24 |     deriving (Read,Show,Eq,Ord)
25 |
26 | mkCorrelation :: HasRing dp => Ring dp -> Correlation dp
27 | mkCorrelation r = Correlation 0 r
28 |
29 | instance NFData (Correlation dp) where
30 |     rnf c = seq c ()
31 |
32 | ---------------------------------------
33 |
34 | instance (Eq (Ring dp), HasRing dp) => Abelian (Correlation dp)
35 | instance (Eq (Ring dp), HasRing dp) => Monoid (Correlation dp) where
36 |     mempty = Correlation 0 0
37 |     mappend (Correlation c1 r1) (Correlation c2 r2) = if r1 /= r2
38 |         then error "Correlation.Monoid./="
39 |         else Correlation (c1+c2) r1
40 |
41 | instance (Eq (Ring dp), HasRing dp) => Group (Correlation dp) where
42 |     inverse (Correlation c r) = Correlation (-c) r
43 |
44 | instance (Eq (Ring dp), HasRing dp) => HasRing (Correlation dp) where
45 |     type Ring (Correlation dp) = Ring dp
46 |
47 | instance (Eq (Ring dp), HasRing dp) => Module (Correlation dp) where
48 |     a .* (Correlation c r) = Correlation (c*a) r
49 |
50 | ---------------------------------------
51 |
52 | findCorrelationSingle :: SpaceTree t dp => t dp -> Ring dp -> dp -> Correlation dp
53 | findCorrelationSingle st range dp = prunefoldC (cor_catadp dp) (cor_cata dp) mempty st
54 |
55 | cor_catadp :: MetricSpace dp => dp -> dp -> Correlation dp -> Correlation dp
56 | cor_catadp query dp cor = if distance query dp < range cor
57 |     then cor { unCorrelation = unCorrelation cor+1 }
58 |     else cor
59 |
60 | cor_cata :: SpaceTree t dp => dp -> t dp -> Correlation dp -> Strict.Either (Correlation dp) (Correlation dp)
61 | cor_cata query st cor = case stIsMinDistanceDpFartherThanWithDistance st query (range cor) of
62 |     Strict.Nothing -> Strict.Left cor
63 |     Strict.Just dist -> if stMaxDistanceDp st query < range cor
64 |         then Strict.Left $ cor { unCorrelation = unCorrelation cor + numdp st }
65 |         else Strict.Right $
if dist < range cor 66 | then cor { unCorrelation = unCorrelation cor + stWeight st } 67 | else cor 68 | 69 | findCorrelationDual :: (Eq dp, SpaceTree t dp) => DualTree (t dp) -> Ring dp -> Correlation dp 70 | findCorrelationDual dual range = reduce $ 71 | map (\dp -> findCorrelationSingle (reference dual) range dp) (stToList $ query dual) 72 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Algorithms/KernelDensityEstimation.hs: -------------------------------------------------------------------------------- 1 | module HLearn.DataStructures.SpaceTree.Algorithms.KernelDensityEstimation 2 | where 3 | 4 | import Debug.Trace 5 | 6 | import Control.DeepSeq 7 | import qualified Data.Map.Strict as Map 8 | import qualified Data.Set as Set 9 | import qualified Data.Strict.Either as Strict 10 | import qualified Data.Strict.Maybe as Strict 11 | import qualified Data.Strict.Tuple as Strict 12 | import GHC.TypeLits 13 | 14 | import HLearn.Algebra 15 | import HLearn.DataStructures.SpaceTree 16 | import HLearn.Models.Distributions.Kernels 17 | 18 | ------------------------------------------------------------------------------- 19 | -- Density 20 | 21 | newtype Density kernel dp = Density { getdensity :: (Ring dp) } 22 | deriving (Read,Show,Eq,Ord) 23 | 24 | deriving instance NFData (Ring dp) => NFData (Density kernel dp) 25 | 26 | --------------------------------------- 27 | 28 | instance HasRing dp => Abelian (Density kernel dp) 29 | instance HasRing dp => Monoid (Density kernel dp) where 30 | {-# INLINE mempty #-} 31 | {-# INLINE mappend #-} 32 | mempty = Density 0 33 | mappend (Density a) (Density b) = Density $ a+b 34 | 35 | instance HasRing dp => Group (Density kernel dp) where 36 | {-# INLINE inverse #-} 37 | inverse (Density a) = Density (-a) 38 | 39 | instance HasRing dp => HasRing (Density kernel dp) where 40 | type Ring (Density kernel dp) = Ring dp 41 | 42 | instance HasRing dp => Module (Density kernel dp) where 43 | r .* (Density a) = Density $ r * a 44 | 45 | --------------------------------------- 46 | 47 | findDensity :: 48 | ( SpaceTree t dp 49 | , Function kernel (Ring (t dp)) (Ring (t dp)) 50 | ) => 51 | t dp -> Ring dp -> dp -> Density kernel dp 52 | findDensity st epsilon query = prunefoldC (kde_catadp bound query) (kde_cata bound query) mempty st 53 | where 54 | bound = epsilon/numdp st 55 | 56 | kde_catadp :: forall dp kernel. 57 | ( MetricSpace dp 58 | , Function kernel (Ring dp) (Ring dp) 59 | ) => Ring dp -> dp -> dp -> Density kernel dp -> Density kernel dp 60 | kde_catadp bound query dp kde = kde <> density' 61 | where 62 | f = function (undefined::kernel) 63 | density' = Density $ f $ distance dp query 64 | 65 | kde_cata :: forall t dp kernel. 
66 | ( SpaceTree t dp 67 | , Function kernel (Ring dp) (Ring dp) 68 | ) => Ring dp -> dp -> t dp -> Density kernel dp -> Strict.Either (Density kernel dp) (Density kernel dp) 69 | kde_cata bound query st kde = if kmin-kmax>bound 70 | then Strict.Left $ kde <> numdp st .* density' 71 | else Strict.Right $ kde <> stWeight st .* density' 72 | where 73 | f = function (undefined::kernel) 74 | density' = Density $ f $ distance (stNode st) query 75 | 76 | kmin = f $ stMinDistanceDp st query 77 | kmax = f $ stMaxDistanceDp st query 78 | 79 | ------------------------------------------------------------------------------- 80 | -- DensityMap 81 | 82 | newtype DensityMap kernel dp = DensityMap { dm2map :: Map.Map dp (Density kernel dp) } 83 | 84 | deriving instance (NFData dp, NFData (Ring dp)) => NFData (DensityMap kernel dp) 85 | 86 | instance (Ord dp, HasRing dp, Function kernel (Ring dp) (Ring dp)) => Monoid (DensityMap kernel dp) where 87 | mempty = DensityMap mempty 88 | mappend !(DensityMap rm1) !(DensityMap rm2) = DensityMap $ Map.unionWith (<>) rm1 rm2 89 | 90 | --------------------------------------- 91 | 92 | findDensityMap :: 93 | ( SpaceTree t dp 94 | , Function kernel (Ring dp) (Ring dp) 95 | , Ord dp 96 | ) => Ring dp -> DualTree (t dp) -> DensityMap kernel dp 97 | findDensityMap epsilon dual = reduce $ 98 | map (\dp -> DensityMap $ Map.singleton dp $ findDensity (reference dual) epsilon dp) (stToList $ query dual) 99 | 100 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Algorithms/NearestNeighbor.Old.hs: -------------------------------------------------------------------------------- 1 | 2 | {-# LANGUAGE DataKinds #-} 3 | 4 | module HLearn.DataStructures.SpaceTree.Algorithms.NearestNeighbor 5 | where 6 | 7 | import Debug.Trace 8 | 9 | import Control.Monad 10 | import Control.Monad.ST 11 | import Control.DeepSeq 12 | import qualified Data.Foldable as F 13 | import qualified Data.Map.Strict as Map 14 | import qualified Data.Vector as V 15 | import qualified Data.Vector.Mutable as VM 16 | 17 | import HLearn.Algebra 18 | import HLearn.DataStructures.SpaceTree 19 | 20 | ------------------------------------------------------------------------------- 21 | -- data types 22 | 23 | data Neighbor dp = Neighbor 24 | { neighbor :: !dp 25 | , neighborDistance :: !(Ring dp) 26 | } 27 | 28 | deriving instance (Read dp, Read (Ring dp)) => Read (Neighbor dp) 29 | deriving instance (Show dp, Show (Ring dp)) => Show (Neighbor dp) 30 | 31 | instance Eq (Ring dp) => Eq (Neighbor dp) where 32 | a == b = neighborDistance a == neighborDistance b 33 | 34 | instance Ord (Ring dp) => Ord (Neighbor dp) where 35 | compare a b = compare (neighborDistance a) (neighborDistance b) 36 | 37 | instance (NFData dp, NFData (Ring dp)) => NFData (Neighbor dp) where 38 | rnf n = deepseq (neighbor n) $ rnf (neighborDistance n) 39 | 40 | --------------------------------------- 41 | 42 | -- newtype KNN (k::Nat) dp = KNN { getknn :: [Neighbor dp] } 43 | newtype KNN (k::Nat) dp = KNN { getknn :: V.Vector (Neighbor dp) } 44 | 45 | deriving instance (Read dp, Read (Ring dp)) => Read (KNN k dp) 46 | deriving instance (Show dp, Show (Ring dp)) => Show (KNN k dp) 47 | deriving instance (NFData dp, NFData (Ring dp)) => NFData (KNN k dp) 48 | 49 | knn_maxdist :: forall k dp. (SingI k,Ord dp,Fractional (Ring dp)) => KNN k dp -> Ring dp 50 | knn_maxdist (KNN v) = if V.length v > 0 51 | then neighborDistance $ v V.! 
(V.length v-1) 52 | else inf 53 | 54 | inf :: Fractional n => n 55 | inf = 1/0 56 | 57 | --------------------------------------- 58 | 59 | newtype KNN2 (k::Nat) dp = KNN2 60 | { getknn2 :: Map.Map dp (KNN k dp) 61 | } 62 | 63 | deriving instance (Read dp, Read (Ring dp), Ord dp, Read (KNN k dp)) => Read (KNN2 k dp) 64 | deriving instance (Show dp, Show (Ring dp), Ord dp, Show (KNN k dp)) => Show (KNN2 k dp) 65 | deriving instance (NFData dp, NFData (Ring dp)) => NFData (KNN2 k dp) 66 | 67 | instance (SpaceTree t dp, Ord dp, SingI k) => Function (KNN2 k dp) (DualTree (t dp)) (KNN2 k dp) where 68 | function _ = knn2 69 | 70 | ------------------------------------------------------------------------------- 71 | -- algebra 72 | 73 | instance (SingI k, MetricSpace dp, Eq dp) => Monoid (KNN k dp) where 74 | mempty = KNN mempty 75 | mappend (KNN v1) (KNN v2) = KNN $ runST $ do 76 | v' <- VM.new k' 77 | go v' 0 0 0 78 | V.unsafeFreeze v' 79 | where 80 | go :: VM.MVector s (Neighbor dp) -> Int -> Int -> Int -> ST s () 81 | go v' i i1 i2 = if i>=k' 82 | then return () 83 | else if v1 V.! i1 < v2 V.! i2 84 | then VM.write v' i (v1 V.! i1) >> go v' (i+1) (i1+1) i2 85 | else VM.write v' i (v2 V.! i2) >> go v' (i+1) i1 (i2+1) 86 | 87 | k'=min k (V.length v1+V.length v2) 88 | k=fromIntegral $ fromSing (sing :: Sing k) 89 | -- 90 | -- mempty = KNN [] 91 | -- mappend (KNN xs) (KNN ys) = KNN $ take k $ interleave xs ys 92 | -- where 93 | -- k=fromIntegral $ fromSing (sing :: Sing k) 94 | 95 | instance (SingI k, MetricSpace dp, Ord dp) => Monoid (KNN2 k dp) where 96 | mempty = KNN2 mempty 97 | mappend (KNN2 x) (KNN2 y) = KNN2 $ Map.unionWith (<>) x y 98 | 99 | ------------------------------------------------------------------------------- 100 | -- dual tree 101 | 102 | knn2 :: (SpaceTree t dp, Ord dp, SingI k) => DualTree (t dp) -> KNN2 k dp 103 | knn2=knn2_fast 104 | 105 | knn2_fast :: (SpaceTree t dp, Ord dp, SingI k) => DualTree (t dp) -> KNN2 k dp 106 | knn2_fast = prunefold2init initKNN2 knn2_prune knn2_cata 107 | 108 | knn2_slow :: (SpaceTree t dp, Ord dp, SingI k) => DualTree (t dp) -> KNN2 k dp 109 | knn2_slow = prunefold2init initKNN2 noprune knn2_cata 110 | 111 | initKNN2 :: SpaceTree t dp => DualTree (t dp) -> KNN2 k dp 112 | initKNN2 dual = KNN2 $ Map.singleton qnode val 113 | where 114 | rnode = stNode $ reference dual 115 | qnode = stNode $ query dual 116 | val = KNN $ V.singleton $ Neighbor rnode (distance qnode rnode) 117 | 118 | knn2_prune :: forall k t dp. (SingI k, SpaceTree t dp, Ord dp) => KNN2 k dp -> DualTree (t dp) -> Bool 119 | knn2_prune knn2 dual = stMinDistance (reference dual) (query dual) > bound 120 | where 121 | bound = maxdist knn2 (reference dual) 122 | 123 | dist :: forall k t dp. (SingI k, MetricSpace dp, Ord dp) => KNN2 k dp -> dp -> Ring dp 124 | dist knn2 dp = knn_maxdist $ Map.findWithDefault mempty dp $ getknn2 knn2 125 | 126 | maxdist :: forall k t dp. 
(SingI k, SpaceTree t dp, Ord dp) => KNN2 k dp -> t dp -> Ring dp 127 | maxdist knn2 tree = if stIsLeaf tree 128 | then dist knn2 (stNode tree) 129 | else maximum 130 | $ (dist knn2 (stNode tree)) 131 | : (fmap (maxdist knn2) $ stChildren tree) 132 | 133 | knn2_cata :: (SingI k, Ord dp, MetricSpace dp) => DualTree dp -> KNN2 k dp -> KNN2 k dp 134 | knn2_cata !dual !knn2 = KNN2 $ Map.insertWith (<>) qnode knn' $ getknn2 knn2 135 | where 136 | rnode = reference dual 137 | qnode = query dual 138 | dualdist = distance rnode qnode 139 | knn' = KNN $ V.singleton $ Neighbor rnode dualdist 140 | 141 | 142 | ------------------------------------------------------------------------------- 143 | -- single tree 144 | 145 | init_neighbor :: SpaceTree t dp => dp -> t dp -> Neighbor dp 146 | init_neighbor query t = Neighbor 147 | { neighbor = stNode t 148 | , neighborDistance = distance query (stNode t) 149 | } 150 | 151 | nearestNeighbor :: SpaceTree t dp => dp -> t dp -> Neighbor dp 152 | nearestNeighbor query t = prunefoldinit (init_neighbor query) (nn_prune query) (nn_cata query) t 153 | 154 | nearestNeighbor_slow :: SpaceTree t dp => dp -> t dp -> Neighbor dp 155 | nearestNeighbor_slow query t = prunefoldinit undefined noprune (nn_cata query) t 156 | 157 | nn_prune :: SpaceTree t dp => dp -> Neighbor dp -> t dp -> Bool 158 | nn_prune query b t = neighborDistance b < distance query (stNode t) 159 | 160 | nn_cata :: MetricSpace dp => dp -> dp -> Neighbor dp -> Neighbor dp 161 | nn_cata query next current = if neighborDistance current < nextDistance 162 | then current 163 | else Neighbor next nextDistance 164 | where 165 | nextDistance = distance query next 166 | 167 | --------------------------------------- 168 | 169 | knn :: (SingI k, SpaceTree t dp, Eq dp) => dp -> t dp -> KNN k dp 170 | knn query t = prunefoldinit (init_knn query) (knn_prune query) (knn_cata query) t 171 | 172 | knn_prune :: forall k t dp. (SingI k, SpaceTree t dp) => dp -> KNN k dp -> t dp -> Bool 173 | knn_prune query res t = knnMaxDistance res < distance query (stNode t) && knnFull res 174 | 175 | knn_cata :: (SingI k, MetricSpace dp, Eq dp) => dp -> dp -> KNN k dp -> KNN k dp 176 | knn_cata query next current = KNN (V.singleton (Neighbor next $ distance query next)) <> current 177 | 178 | knnFull :: forall k dp. 
SingI k => KNN k dp -> Bool 179 | knnFull knn = V.length (getknn knn) > k 180 | where 181 | k = fromIntegral $ fromSing (sing :: Sing k) 182 | 183 | knnMaxDistance :: KNN k dp -> Ring dp 184 | knnMaxDistance (KNN xs) = neighborDistance $ V.last xs 185 | 186 | init_knn :: SpaceTree t dp => dp -> t dp -> KNN k dp 187 | init_knn query t = KNN $ V.singleton $ Neighbor (stNode t) (distance (stNode t) query) 188 | 189 | interleave :: (Eq a, Ord (Ring a)) => [Neighbor a] -> [Neighbor a] -> [Neighbor a] 190 | interleave xs [] = xs 191 | interleave [] ys = ys 192 | interleave (x:xs) (y:ys) = case compare x y of 193 | LT -> x:(interleave xs (y:ys)) 194 | GT -> y:(interleave (x:xs) ys) 195 | EQ -> if neighbor x == neighbor y 196 | then x:interleave xs ys 197 | else x:y:interleave xs ys 198 | 199 | --------------------------------------- 200 | 201 | knn2_single :: (SingI k, SpaceTree t dp, Eq dp, F.Foldable t, Ord dp) => DualTree (t dp) -> KNN2 k dp 202 | knn2_single dual = F.foldMap (\dp -> KNN2 $ Map.singleton dp $ knn dp $ reference dual) (query dual) 203 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Algorithms/NearestNeighbor.hs: -------------------------------------------------------------------------------- 1 | -- | This module contains algorithms for efficient operations over space trees. 2 | -- Currently, only nearest neighbor queries are implemented. 3 | -- It would be easy to implement more query types, however. 4 | -- If there is another query type you want supported, ask me and I'll implement it for you. 5 | -- 6 | -- The paper gives full details on possible queries. 7 | module HLearn.Data.SpaceTree.Algorithms 8 | ( 9 | 10 | Neighbor (..) 11 | , ValidNeighbor (..) 12 | 13 | , findNeighbor 14 | , findNeighbor_NoSort 15 | ) 16 | where 17 | 18 | import GHC.Exts (inline) 19 | import Data.List (sortBy) 20 | 21 | import SubHask 22 | import SubHask.Algebra.Array 23 | import SubHask.Algebra.Container 24 | import SubHask.Algebra.Vector 25 | import SubHask.Compatibility.Containers 26 | import SubHask.Monad 27 | import SubHask.TemplateHaskell.Deriving 28 | 29 | import HLearn.Data.SpaceTree 30 | 31 | ------------------------------------------------------------------------------- 32 | 33 | data Neighbor dp = Neighbor 34 | { neighbor :: !dp 35 | , neighborDistance :: !(Scalar dp) 36 | } 37 | 38 | deriving instance (Show dp, Show (Scalar dp)) => Show (Neighbor dp) 39 | 40 | instance (NFData dp, NFData (Scalar dp)) => NFData (Neighbor dp) where 41 | rnf (Neighbor _ _) = () 42 | 43 | type instance Logic (Neighbor dp) = Bool 44 | 45 | instance (Eq dp, Eq (Scalar dp)) => Eq_ (Neighbor dp) where 46 | (Neighbor dp1 dist1)==(Neighbor dp2 dist2) = dist1==dist2 && dp1==dp2 47 | 48 | ---------------------------------------- 49 | 50 | -- | Find the nearest neighbor of a node. 51 | -- 52 | -- NOTE: 53 | -- If we remove the call to "inline" on "foldr'", 54 | -- GHC 7.10 will pass dictionaries and everything becomes very slow. 
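--
-- A minimal usage sketch (hypothetical @tree@ and @q@ values; @tree@ is any
-- SpaceTree built over the same point type as @q@):
--
-- > let Neighbor nearest dist = findNeighbor tree q
-- > -- @nearest@ is the closest stored point, @dist@ its distance to @q@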
55 | {-# INLINE findNeighbor #-} 56 | findNeighbor :: 57 | ( SpaceTree t dp 58 | , Bounded (Scalar dp) 59 | ) => t dp -> dp -> Neighbor dp 60 | findNeighbor t q = 61 | {-# SCC findNeighbor #-} 62 | go (Labeled' t startdist) startnode 63 | where 64 | startnode = if startdist == 0 65 | then Neighbor q maxBound 66 | else Neighbor (stNode t) startdist 67 | 68 | startdist = distance (stNode t) q 69 | 70 | go (Labeled' t dist) (Neighbor n distn) = if dist > maxdist 71 | then Neighbor n distn 72 | else inline foldr' go leafres 73 | $ sortBy (\(Labeled' _ d1) (Labeled' _ d2) -> compare d2 d1) 74 | $ map (\t' -> Labeled' t' (distanceUB q (stNode t') (distnleaf+stMaxDescendentDistance t))) 75 | $ toList 76 | $ stChildren t 77 | where 78 | leafres@(Neighbor _ distnleaf) = inline foldr' 79 | (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n) 80 | (cata (stNode t) dist (Neighbor n distn)) 81 | (stLeaves t) 82 | 83 | maxdist = distn+stMaxDescendentDistance t 84 | 85 | cata !dp !dist (Neighbor n distn) = 86 | if dist==0 || dist>distn 87 | then Neighbor n distn 88 | else Neighbor dp dist 89 | 90 | -- {-# INLINE findAllNeighbors #-} 91 | -- findAllNeighbors :: 92 | -- ( SpaceTree t dp 93 | -- , Bounded (Scalar dp) 94 | -- ) => Scalar dp 95 | -- -> t dp 96 | -- -> [dp] 97 | -- -> All Constructible0 ( dp, Neighbor dp ) 98 | -- findAllNeighbors epsilon rtree qs = fromList $ map (\dp -> (dp, findNeighbor dp rtree)) qs 99 | 100 | ---------------------------------------- 101 | 102 | -- | Find the nearest neighbor of a node. 103 | -- Internally, this function does not sort the distances of the children before descending. 104 | -- In some (rare) cases this reduces the number of distance comparisons. 105 | {-# INLINE findNeighbor_NoSort #-} 106 | findNeighbor_NoSort :: 107 | ( SpaceTree t dp 108 | , Bounded (Scalar dp) 109 | ) => t dp -> dp -> Neighbor dp 110 | findNeighbor_NoSort t q = 111 | {-# SCC findNeighbor_NoSort #-} 112 | go t (Neighbor q maxBound) 113 | where 114 | go t res@(Neighbor _ distn) = if dist > maxdist 115 | then res 116 | else inline foldr' go leafres $ stChildren t 117 | where 118 | leafres = inline foldr' 119 | (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n) 120 | (cata (stNode t) dist res) 121 | (stLeaves t) 122 | 123 | dist = distanceUB q (stNode t) maxdist 124 | maxdist = distn+stMaxDescendentDistance t 125 | 126 | cata !dp !dist (Neighbor n distn) = 127 | if dist==0 || dist>distn 128 | then Neighbor n distn 129 | else Neighbor dp dist 130 | 131 | -- {-# INLINE findAllNeighbors_NoSort #-} 132 | -- findAllNeighbors_NoSort :: 133 | -- ( SpaceTree t dp 134 | -- , Bounded (Scalar dp) 135 | -- ) => Scalar dp 136 | -- -> t dp 137 | -- -> [dp] 138 | -- -> All Constructible0 ( dp, Neighbor dp ) 139 | -- findAllNeighbors_NoSort epsilon rtree qs = fromList $ map (\dp -> (dp, findNeighbor_NoSort dp rtree)) qs 140 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Algorithms/RangeSearch.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ScopedTypeVariables, DataKinds #-} 2 | module HLearn.DataStructures.SpaceTree.Algorithms.RangeSearch 3 | where 4 | 5 | import Debug.Trace 6 | 7 | import qualified Data.Map.Strict as Map 8 | 9 | import SubHask 10 | import SubHask.Monad 11 | import HLearn.DataStructures.SpaceTree 12 | 13 | ------------------------------------------------------------------------------- 14 | -- Range 15 | 16 | data Range dp = Range 17 | { rangedp :: 
!dp 18 | , rangedistance :: !(Scalar dp) 19 | } 20 | 21 | type instance Logic (Range dp) = Bool 22 | 23 | deriving instance (Read dp, Read (Scalar dp)) => Read (Range dp) 24 | deriving instance (Show dp, Show (Scalar dp)) => Show (Range dp) 25 | 26 | instance NFData (Range dp) where 27 | rnf dp = seq dp () 28 | 29 | instance Eq (Scalar dp) => Eq_ (Range dp) where 30 | r1 == r2 = rangedistance r1 == rangedistance r2 31 | 32 | instance Ord (Scalar dp) => POrd_ (Range dp) where 33 | inf r1 r2 = if rangedistance r1 < rangedistance r2 34 | then r1 35 | else r2 36 | 37 | instance Ord (Scalar dp) => Lattice_ (Range dp) where 38 | sup r1 r2 = if rangedistance r1 > rangedistance r2 39 | then r1 40 | else r2 41 | 42 | instance Ord (Scalar dp) => Ord_ (Range dp) 43 | 44 | --------------------------------------- 45 | 46 | {-# INLINABLE findRangeList #-} 47 | findRangeList :: 48 | ( SpaceTree t dp 49 | , Eq dp 50 | , CanError (Scalar dp) 51 | ) => t dp -> Scalar dp -> dp -> [dp] 52 | findRangeList tree maxdist query = 53 | prunefold (rl_prune maxdist query) (rl_cata maxdist query) [] tree 54 | 55 | {-# INLINABLE rl_prune #-} 56 | rl_prune :: 57 | ( SpaceTree t dp 58 | , Ord (Scalar dp) 59 | ) => Scalar dp -> dp -> [dp] -> t dp -> Bool 60 | rl_prune maxdist query xs tree = 61 | stMinDistanceDp tree query > maxdist 62 | 63 | {-# INLINABLE rl_cata #-} 64 | rl_cata :: 65 | ( MetricSpace dp 66 | , Logic dp~Bool 67 | ) => Scalar dp -> dp -> dp -> [dp] -> [dp] 68 | rl_cata maxdist query dp xs = if distance dp query < maxdist 69 | then dp:xs 70 | else xs 71 | 72 | -- {-# INLINABLE findRangeList #-} 73 | -- findRangeList :: 74 | -- ( SpaceTree t dp 75 | -- , Eq dp 76 | -- , CanError (Scalar dp) 77 | -- ) => t dp -> Scalar dp -> dp -> [dp] 78 | -- findRangeList tree maxdist query = 79 | -- prunefoldB_CanError (rl_catadp maxdist query) (rl_cata maxdist query) [] tree 80 | -- 81 | -- {-# INLINABLE rl_catadp #-} 82 | -- rl_catadp :: 83 | -- ( MetricSpace dp 84 | -- , CanError (Scalar dp) 85 | -- , Ord (Scalar dp) 86 | -- ) => Scalar dp -> dp -> dp -> [dp] -> [dp] 87 | -- rl_catadp !maxdist !query !dp !rl = {-# SCC rl_catadp #-} 88 | -- if isError dist 89 | -- then rl 90 | -- else dp:rl 91 | -- where 92 | -- dist = isFartherThanWithDistanceCanError dp query maxdist 93 | -- 94 | -- {-# INLINABLE rl_cata #-} 95 | -- rl_cata :: 96 | -- ( SpaceTree t dp 97 | -- , CanError (Scalar dp) 98 | -- , Eq dp 99 | -- ) => Scalar dp -> dp -> t dp -> [dp] -> [dp] 100 | -- rl_cata !maxdist !query !tree !rl = {-# SCC rl_cata #-} 101 | -- if isError dist 102 | -- then errorVal 103 | -- else if isFartherThan (stNode tree) query maxdist 104 | -- then rl 105 | -- else stNode tree:rl 106 | -- where 107 | -- dist = stIsMinDistanceDpFartherThanWithDistanceCanError tree query maxdist 108 | 109 | 110 | ------------ 111 | ---- test 112 | 113 | -- instance MetricSpace (Double,Double) where 114 | -- distance (a1,a2) (b1,b2) = sqrt $ (a1-b1)**2 + (a2-b2)**2 115 | -- 116 | -- instance POrd_ (Double,Double) where 117 | -- inf (a1,a2) (b1,b2) = (inf a1 b1, inf a2 b2) 118 | -- 119 | -- instance SupSemilattice (Double,Double) where 120 | -- sup (a1,a2) (b1,b2) = (sup a1 b1, sup a2 b2) 121 | -- 122 | -- instance Lattice (Double,Double) where 123 | -- 124 | -- instance POrd (Double,Double) where 125 | -- pcompare (a1,a2) (b1,b2) = case pcompare a1 a2 of 126 | -- PEQ -> pcompare b1 b2 127 | -- _ -> pcompare a1 a2 128 | -- 129 | -- instance Ord (Double,Double) 130 | -------------------------------------------------------------------------------- 
/src/HLearn/Data/SpaceTree/Algorithms_Specialized.hs: -------------------------------------------------------------------------------- 1 | -- | This module contains algorithms for efficient operations over space trees. 2 | -- Currently, only nearest neighbor queries are implemented. 3 | -- It would be easy to implement more query types, however. 4 | -- If there is another query type you want supported, ask me and I'll implement it for you. 5 | -- 6 | -- The paper gives full details on possible queries. 7 | module HLearn.Data.SpaceTree.Algorithms_Specialized 8 | ( 9 | 10 | Neighbor (..) 11 | , ValidNeighbor (..) 12 | 13 | , findNeighbor 14 | , findNeighbor_NoSort 15 | ) 16 | where 17 | 18 | import GHC.Exts (inline) 19 | import Data.List (sortBy) 20 | 21 | import SubHask 22 | import SubHask.Algebra.Array 23 | import SubHask.Algebra.Container 24 | import SubHask.Algebra.Vector 25 | import SubHask.Compatibility.Containers 26 | import SubHask.Monad 27 | import SubHask.TemplateHaskell.Deriving 28 | 29 | import HLearn.Data.SpaceTree 30 | 31 | ------------------------------------------------------------------------------- 32 | 33 | data Neighbor dp = Neighbor 34 | -- { neighbor :: !dp 35 | -- , neighborDistance :: !(Scalar dp) 36 | { neighbor :: !(Labeled' (UVector "dyn" Float) Int) 37 | , neighborDistance :: !Float 38 | } 39 | 40 | type ValidNeighbor dp = dp~(Labeled' (UVector "dyn" Float) Int) 41 | 42 | deriving instance (Show dp, Show (Scalar dp)) => Show (Neighbor dp) 43 | 44 | instance (NFData dp, NFData (Scalar dp)) => NFData (Neighbor dp) where 45 | rnf (Neighbor _ _) = () 46 | 47 | type instance Logic (Neighbor dp) = Bool 48 | 49 | instance (Eq dp, Eq (Scalar dp)) => Eq_ (Neighbor dp) where 50 | (Neighbor dp1 dist1)==(Neighbor dp2 dist2) = dist1==dist2 && dp1==dp2 51 | 52 | ---------------------------------------- 53 | 54 | -- | Find the nearest neighbor of a node. 55 | -- 56 | -- NOTE: 57 | -- If we remove the call to "inline" on "foldr'", 58 | -- GHC 7.10 will pass dictionaries and everything becomes very slow. 59 | {-# INLINE findNeighbor #-} 60 | findNeighbor :: 61 | ( SpaceTree t dp 62 | , Bounded (Scalar dp) 63 | , ValidNeighbor dp 64 | ) => Scalar dp -> t dp -> dp -> Neighbor dp 65 | findNeighbor ε t q = 66 | {-# SCC findNeighbor #-} 67 | go (Labeled' t startdist) startnode 68 | where 69 | startnode = if startdist == 0 70 | then Neighbor q maxBound 71 | else Neighbor (stNode t) startdist 72 | 73 | startdist = distance (stNode t) q 74 | 75 | go (Labeled' t dist) (Neighbor n distn) = if dist*ε > maxdist 76 | then Neighbor n distn 77 | else inline foldr' go leafres 78 | $ sortBy (\(Labeled' _ d1) (Labeled' _ d2) -> compare d2 d1) 79 | $ map (\t' -> Labeled' t' (distanceUB q (stNode t') (distnleaf+stMaxDescendentDistance t))) 80 | $ toList 81 | $ stChildren t 82 | where 83 | leafres@(Neighbor _ distnleaf) = inline foldr' 84 | (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n) 85 | (cata (stNode t) dist (Neighbor n distn)) 86 | (stLeaves t) 87 | 88 | maxdist = distn+stMaxDescendentDistance t 89 | 90 | cata !dp !dist (Neighbor n distn) = 91 | if dist==0 || dist>distn 92 | then Neighbor n distn 93 | else Neighbor dp dist 94 | 95 | ---------------------------------------- 96 | 97 | -- | Find the nearest neighbor of a node. 98 | -- Internally, this function does not sort the distances of the children before descending. 99 | -- In some (rare) cases this reduces the number of distance comparisons. 
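--
-- In this specialized module the @ValidNeighbor@ constraint pins the point
-- type to @Labeled' (UVector "dyn" Float) Int@, so 'Neighbor' is fully
-- monomorphic and no dictionaries survive into the inner loop. A sketch
-- (hypothetical @ct@ and @q@ of exactly that type):
--
-- > n = findNeighbor_NoSort 1 ct q   -- ε=1 gives the exact nearest neighbor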
100 | {-# INLINE findNeighbor_NoSort #-} 101 | findNeighbor_NoSort :: 102 | ( SpaceTree t dp 103 | , Bounded (Scalar dp) 104 | , ValidNeighbor dp 105 | ) => Scalar dp -> t dp -> dp -> Neighbor dp 106 | findNeighbor_NoSort ε t q = 107 | {-# SCC findNeighbor_NoSort #-} 108 | go t (Neighbor q maxBound) 109 | where 110 | go t res@(Neighbor _ distn) = if dist*ε > maxdist 111 | then res 112 | else inline foldr' go leafres $ stChildren t 113 | where 114 | leafres = inline foldr' 115 | (\dp n@(Neighbor _ distn') -> cata dp (distanceUB q dp distn') n) 116 | (cata (stNode t) dist res) 117 | (stLeaves t) 118 | 119 | dist = distanceUB q (stNode t) maxdist 120 | maxdist = distn+stMaxDescendentDistance t 121 | 122 | cata !dp !dist (Neighbor n distn) = 123 | if dist==0 || dist>distn 124 | then Neighbor n distn 125 | else Neighbor dp dist 126 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/CoverTree/Unsafe.hs: -------------------------------------------------------------------------------- 1 | -- | This module let's us tune the cover tree's expansion ratio. 2 | -- You might be able to get around 5% better performance by tuning this value to your specific application. 3 | -- But you probably don't want to do this. 4 | module HLearn.Data.SpaceTree.CoverTree.Unsafe 5 | ( setExprat 6 | , getExprat 7 | ) 8 | where 9 | 10 | import SubHask 11 | import System.IO.Unsafe 12 | import Data.IORef 13 | 14 | -------------------------------------------------------------------------------- 15 | 16 | {-# NOINLINE expratIORef #-} 17 | expratIORef = unsafePerformIO $ newIORef (1.3::Rational) 18 | 19 | setExprat :: Rational -> IO () 20 | setExprat r = writeIORef expratIORef r 21 | 22 | {-# INLINABLE getExprat #-} 23 | getExprat :: Field r => r 24 | getExprat = fromRational $ unsafePerformIO $ readIORef expratIORef 25 | 26 | -------------------------------------------------------------------------------- /src/HLearn/Data/SpaceTree/Diagrams.hs: -------------------------------------------------------------------------------- 1 | module HLearn.Data.SpaceTree.Diagrams 2 | where 3 | 4 | import qualified Prelude as P 5 | 6 | import SubHask 7 | 8 | import HLearn.Data.SpaceTree 9 | import HLearn.Models.Distributions 10 | 11 | import Diagrams.Prelude () 12 | import qualified Diagrams.Prelude as D 13 | import Diagrams.Backend.SVG hiding (size) 14 | 15 | -------------------------------------------------------------------------------- 16 | 17 | {- 18 | -- drawCT :: 19 | -- ( ValidCT exprat childC leafC dp 20 | -- , VG.Vector childC (QDiagram SVG R2 Any) 21 | -- , Integral (Scalar (leafC dp)) 22 | -- , Integral (Scalar (childC (CoverTree_ exprat childC leafC dp))) 23 | -- ) => P.FilePath 24 | -- -> CoverTree_ exprat childC leafC dp 25 | -- -> IO () 26 | drawCT path ct = renderSVG path (Dims 500 300) (diagramCT_ 0 ct) 27 | 28 | 29 | -- diagramCT node = diagramCT_ 0 node 30 | 31 | -- type instance Scalar R2 = Double 32 | 33 | -- diagramCT_ :: 34 | -- ( ValidCT exprat childC leafC dp 35 | -- ) => Int 36 | -- -> CoverTree_ exprat childC leafC dp 37 | -- -> Diagram a R2 38 | diagramCT_ (depth::Int) tree 39 | = mkConnections $ 40 | ( named (label++show depth) $ fontSize (Global 0.01) $ 41 | ( 42 | (text label D.<> strutY 0.5) 43 | === (text (show (sepdist tree)) D.<> strutY 0.5) 44 | -- === (text (show (maxDescendentDistance tree)) <> strutY 0.5) 45 | ) 46 | D.<> circle 1 # fc nodecolor 47 | ) 48 | === (pad 1.05 $ centerName (label++show (depth+1)) $ 49 | VG.foldr (|||) mempty $ 
VG.map (diagramCT_ (depth+1)) $ children tree) 50 | 51 | where 52 | label = intShow $ nodedp tree 53 | nodecolor = if ctBetterMovableNodes tree==0 --nodeWeight tree > 0 54 | then red 55 | else lightblue 56 | 57 | mkConnections = 58 | D.connect (label++show depth) (label++show (depth+1)) 59 | . apList (fmap 60 | (\key -> D.connect (label++show depth) (intShow key++show (depth+1))) 61 | (map nodedp $ toList $ children tree) 62 | ) 63 | 64 | centerName name = withName name $ \b a -> moveOriginTo (location b) a 65 | 66 | apList :: [a -> a] -> a -> a 67 | apList [] a = a 68 | apList (x:xs) a = apList xs (x a) 69 | 70 | intShow :: Show a => a -> String 71 | intShow a = P.filter go $ show a 72 | where 73 | go x 74 | | x=='.' = True 75 | | x==',' = True 76 | | x=='-' = True 77 | | x=='1' = True 78 | | x=='2' = True 79 | | x=='3' = True 80 | | x=='4' = True 81 | | x=='5' = True 82 | | x=='6' = True 83 | | x=='7' = True 84 | | x=='8' = True 85 | | x=='9' = True 86 | | x=='0' = True 87 | | otherwise = False 88 | 89 | -} 90 | -------------------------------------------------------------------------------- /src/HLearn/Evaluation/CrossValidation.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes,ScopedTypeVariables #-} 2 | module HLearn.Evaluation.CrossValidation 3 | where 4 | 5 | -- import Control.Monad 6 | import Control.Monad.Random hiding (fromList) 7 | -- import Control.Monad.ST 8 | import Control.Monad.Trans (lift) 9 | import Data.Array.ST 10 | import qualified GHC.Arr as Arr 11 | import qualified Data.Foldable as F 12 | import qualified Data.Map as Map 13 | import qualified Data.Set as S 14 | import qualified Data.Vector as V 15 | -- import qualified Data.Vector as V 16 | 17 | import Debug.Trace 18 | 19 | import qualified Data.DList as D 20 | import Prelude (take,drop,map,filter,zip) 21 | 22 | import SubHask 23 | import SubHask.Monad 24 | 25 | import HLearn.History 26 | import HLearn.Models.Distributions.Common 27 | import HLearn.Models.Distributions.Univariate.Normal 28 | import HLearn.Models.Classifiers.Common 29 | -- import qualified Control.ConstraintKinds as CK 30 | 31 | ------------------------------------------------------------------------------- 32 | -- standard k-fold cross validation 33 | 34 | type MonadRandom_ m = (Monad Hask m, MonadRandom m) 35 | 36 | type SamplingMethod = forall dp r. 
(Eq dp, MonadRandom_ r) => [dp] -> r [([dp],[dp])] 37 | 38 | repeatExperiment :: Int -> SamplingMethod -> SamplingMethod 39 | repeatExperiment n f xs = 40 | liftM concat $ forM [1..n] $ \_ -> do 41 | xs' <- shuffle xs 42 | f xs' 43 | 44 | trainingPercent :: Double -> SamplingMethod 45 | trainingPercent percent xs = do 46 | xs' <- shuffle xs 47 | let n = round $ percent * fromIntegral (length xs') 48 | return [ (take n xs', drop n xs') ] 49 | 50 | setMaxDatapoints :: Int -> SamplingMethod -> SamplingMethod 51 | setMaxDatapoints n f xs = do 52 | xs' <- shuffle xs 53 | f $ take n xs' 54 | 55 | kfold :: Int -> SamplingMethod 56 | kfold k xs = do 57 | xs' <- shuffle xs 58 | let step = floor $ (fromIntegral $ length xs :: Double) / fromIntegral k 59 | trace ("step="++show step) $ return 60 | [ ( take ((i)*step) xs' ++ drop ((i+1)*step) xs' 61 | , take step $ drop (i*step) xs' 62 | ) 63 | | i <- [0..k-1] 64 | ] 65 | 66 | -- leaveOneOut :: SamplingMethod 67 | -- leaveOneOut xs = return $ map (\x -> [x]) xs 68 | -- 69 | -- withPercent :: Double -> SamplingMethod -> SamplingMethod 70 | -- withPercent p f xs = do 71 | -- xs' <- shuffle xs 72 | -- f $ take (floor $ (fromIntegral $ length xs') * p) xs' 73 | -- 74 | -- repeatExperiment :: Int -> SamplingMethod -> SamplingMethod 75 | -- repeatExperiment i f xs = do 76 | -- liftM concat $ forM [1..i] $ \i -> do 77 | -- f xs 78 | -- 79 | -- kfold :: Int -> SamplingMethod 80 | -- kfold k xs = do 81 | -- xs' <- shuffle xs 82 | -- return [takeEvery k $ drop j xs' | j<-[0..k-1]] 83 | -- where 84 | -- takeEvery n [] = [] 85 | -- takeEvery n xs = head xs : (takeEvery n $ drop n xs) 86 | -- 87 | -- numSamples :: Int -> SamplingMethod -> SamplingMethod 88 | -- numSamples n f dps = f $ take n dps 89 | 90 | -- | randomly shuffles a list in time O(n log n); see http://www.haskell.org/haskellwiki/Random_shuffle 91 | shuffle :: (Eq a, MonadRandom_ m) => [a] -> m [a] 92 | shuffle xs = do 93 | let l = length xs 94 | rands <- take l `liftM` getRandomRs (0, l-1) 95 | let ar = runSTArray ( do 96 | ar <- Arr.thawSTArray (Arr.listArray (0, l-1) xs) 97 | forM_ (zip [0..(l-1)] rands) $ \(i, j) -> do 98 | vi <- Arr.readSTArray ar i 99 | vj <- Arr.readSTArray ar j 100 | Arr.writeSTArray ar j vi 101 | Arr.writeSTArray ar i vj 102 | return ar 103 | ) 104 | return (Arr.elems ar) 105 | 106 | --------------------------------------- 107 | 108 | type LossFunction = forall model. 109 | ( Classifier model 110 | -- , HomTrainer model 111 | , Labeled (Datapoint model) 112 | , Eq (Label (Datapoint model)) 113 | , Eq (Datapoint model) 114 | ) => model -> [Datapoint model] -> Double 115 | 116 | accuracy :: LossFunction 117 | accuracy model dataL = (fromIntegral $ length $ filter (==False) resultsL) / (fromIntegral $ length dataL) 118 | where 119 | resultsL = map (\(l1,l2) -> l1/=l2) $ zip trueL classifyL 120 | trueL = map getLabel dataL 121 | classifyL = map (classify model . getAttributes) dataL 122 | 123 | errorRate :: LossFunction 124 | errorRate model dataL = 1 - accuracy model dataL 125 | 126 | --------------------------------------- 127 | 128 | 129 | validateM :: forall model g container m. 
130 | -- ( HomTrainer model 131 | ( Classifier model 132 | , RandomGen g 133 | , Eq (Datapoint model) 134 | , Eq (Label (Datapoint model)) 135 | , Foldable (container (Datapoint model)) 136 | , Constructible (container (Datapoint model)) 137 | , Elem (container (Datapoint model)) ~ Datapoint model 138 | , HistoryMonad m 139 | ) => SamplingMethod 140 | -> LossFunction 141 | -> container (Datapoint model) 142 | -> (container (Datapoint model) -> m model) 143 | -> RandT g m (Normal Double) 144 | validateM samplingMethod loss xs trainM = do 145 | xs' <- samplingMethod $ toList xs 146 | lift $ collectReports $ fmap trainNormal $ forM xs' $ \(trainingset, testset) -> do 147 | model <- trainM (fromList trainingset) 148 | return $ loss model testset 149 | 150 | 151 | -------------------------------------------------------------------------------- /src/HLearn/Evaluation/CrossValidationData.hs: -------------------------------------------------------------------------------- 1 | module HLearn.Evaluation.CrossValidationData 2 | where 3 | 4 | import GHC.TypeLits 5 | import HLearn.Algebra 6 | 7 | data FoldType = LeaveOneOut 8 | | KFold Nat 9 | 10 | data CrossValidation model (fold :: FoldType) = CrossValidation 11 | { modelL 12 | -------------------------------------------------------------------------------- /src/HLearn/Evaluation/CrossValidationHom.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ConstraintKinds,Rank2Types,KindSignatures #-} 2 | module HLearn.Evaluation.CrossValidationHom 3 | where 4 | 5 | import qualified Data.Vector as V 6 | import qualified Data.Sequence as Seq 7 | 8 | import HLearn.Algebra 9 | import HLearn.Models.Distributions 10 | import HLearn.Models.Classifiers.Common 11 | 12 | import HLearn.DataStructures.CoverTree 13 | import HLearn.DataStructures.SpaceTree.Algorithms.NearestNeighbor 14 | import HLearn.Metrics.Lebesgue 15 | import qualified Data.Vector.Unboxed as VU 16 | 17 | ------------------------------------------------------------------------------- 18 | -- data types 19 | 20 | data CrossValidation model result (infer:: k) lossfunction lossresult = CrossValidation 21 | { dataL :: Seq.Seq (Datapoint model) 22 | , resultL :: Seq.Seq result 23 | , lossL :: Seq.Seq lossresult 24 | , loss :: Normal lossresult lossresult 25 | , model :: model 26 | } 27 | 28 | type family Get_Model (a:: *) 29 | type instance Get_Model (CrossValidation model result infer lossfunction lossresult) = model 30 | 31 | type family Set_Model (a:: *) (model':: *) 32 | type instance Set_Model (CrossValidation model result infer lossfunction lossresult) model' = 33 | CrossValidation model' result infer lossfunction lossresult 34 | 35 | 36 | type DP = L2 (VU.Vector Double) 37 | type Test = CrossValidation (CoverTree DP) (KNN2 1 DP) KNN ErrorRate Double 38 | 39 | --------------------------------------- 40 | 41 | data InferType (a::a) 42 | 43 | instance 44 | ( HomTrainer model 45 | , PDF model 46 | , dp ~ Datapoint model 47 | , prob ~ Probability model 48 | ) => Function (InferType PDF) (model,dp) prob where 49 | function _ (m,dp) = pdf m dp 50 | 51 | function2 a = curry $ function a 52 | 53 | --------------------------------------- 54 | 55 | data SquaredLoss 56 | instance Num a => Function SquaredLoss (a,a) a where 57 | function _ (result,target) = (target-result)*(target-result) 58 | 59 | data ErrorRate 60 | instance (Num a, Eq a) => Function ErrorRate (a,a) a where 61 | function _ (result,target) = indicator $ target==result 62 | 63 | -- 
------------------------------------------------------------------------------- 64 | -- algebra 65 | 66 | instance 67 | ( HomTrainer model 68 | , Labeled (Datapoint model) 69 | , Monoid result 70 | , Function (InferType infertype) (model,Datapoint model) result 71 | , Function lossfunction (Label (Datapoint model),Label (Datapoint model)) lossresult 72 | , result ~ Label (Datapoint model) 73 | , Num lossresult 74 | ) => Monoid (CrossValidation model result infertype lossfunction lossresult) 75 | where 76 | mempty = CrossValidation 77 | { dataL = mempty 78 | , resultL = mempty 79 | , lossL = mempty 80 | , loss = mempty 81 | , model = mempty 82 | } 83 | 84 | mappend cv1 cv2 = CrossValidation 85 | { dataL = dataL' 86 | , resultL = resultL' 87 | , lossL = lossL' 88 | , loss = train lossL' 89 | , model = model cv1 <> model cv2 90 | } 91 | where 92 | dataL' = dataL cv1 <> dataL cv2 93 | resultL' = resultL1 <> resultL2 94 | lossL' = fmap loss $ Seq.zip resultL' $ fmap getLabel dataL' 95 | resultL1 = fmap (\(x,r) -> infer (model cv2) x <> r) $ Seq.zip (dataL cv1) (resultL cv1) 96 | resultL2 = fmap (\(x,r) -> infer (model cv1) x <> r) $ Seq.zip (dataL cv2) (resultL cv2) 97 | 98 | infer = function2 (undefined :: InferType infertype) 99 | loss = function (undefined :: lossfunction) 100 | 101 | -- data CoverTree 102 | -- data KNN 103 | -- 104 | -- data UndefinedModel 105 | -- 106 | -- type DefaultCV = CrossValidation UndefinedModel KNN KNN ErrorRate Double 107 | -- 108 | -- train xs :: DefaultCV [tlr| model=NearestNeighbor [tlr| k=5 |], lossfunction=ErrorRate |] 109 | 110 | instance 111 | ( Monoid (CrossValidation model result infertype lossfunction lossresult) 112 | , HomTrainer model 113 | , Num lossresult 114 | ) => HomTrainer (CrossValidation model result infertype lossfunction lossresult) 115 | where 116 | 117 | type Datapoint (CrossValidation model result infertype lossfunction lossresult) = Datapoint model 118 | 119 | train1dp dp = CrossValidation 120 | { dataL = Seq.singleton dp 121 | , resultL = mempty 122 | , lossL = mempty 123 | , loss = mempty 124 | , model = train1dp dp 125 | } 126 | 127 | ------------------------------------------------------------------------------- 128 | -- Junk 129 | 130 | class (Monoid model) => Hom container model where 131 | type HomDP model :: * 132 | -- type HomDomain model :: * -> * 133 | 134 | hom1dp :: HomDP model -> model 135 | hom :: container (HomDP model) -> model 136 | homAdd1dp :: model -> HomDP model -> model 137 | homAddBatch :: model -> container (HomDP model) -> model 138 | 139 | 140 | -- class LossFunction f where 141 | -- loss :: f -> model -> Datapoint model -> Ring f 142 | -------------------------------------------------------------------------------- /src/HLearn/History/Timing.hs: -------------------------------------------------------------------------------- 1 | -- | 2 | -- 3 | -- FIXME: incorporate this into the History monad 4 | module HLearn.History.Timing 5 | where 6 | 7 | import SubHask (NFData, deepseq) 8 | 9 | import Prelude 10 | import Data.Time.Clock 11 | import Numeric 12 | import System.CPUTime 13 | import System.IO 14 | 15 | time :: NFData a => String -> a -> IO a 16 | time str a = timeIO str $ deepseq a $ return a 17 | 18 | timeIO :: NFData a => String -> IO a -> IO a 19 | timeIO str f = do 20 | hPutStr stderr $ str ++ replicate (45-length str) '.' 21 | hFlush stderr 22 | cputime1 <- getCPUTime 23 | realtime1 <- getCurrentTime >>= return . 
utctDayTime
24 |     ret <- f
25 |     deepseq ret $ return ()
26 |     cputime2 <- getCPUTime
27 |     realtime2 <- getCurrentTime >>= return . utctDayTime
28 | 
29 |     hPutStrLn stderr $ "done"
30 |         ++ ". real time=" ++ show (realtime2-realtime1)
31 |         ++ "; cpu time=" ++ showFFloat (Just 6) ((fromIntegral $ cputime2-cputime1)/1e12 :: Double) "" ++ "s"
32 |     return ret
33 | 
34 | 
-------------------------------------------------------------------------------- /src/HLearn/Models/Distributions.hs: --------------------------------------------------------------------------------
1 | 
2 | module HLearn.Models.Distributions
3 |     where
4 | 
5 | import SubHask
6 | import SubHask.TemplateHaskell.Deriving
7 | 
8 | --------------------------------------------------------------------------------
9 | 
10 | -- | In measure theory, it is common to treat discrete and continuous distributions the same.
11 | -- This lets us more easily generalize to multivariate distributions.
12 | --
13 | -- FIXME:
14 | -- We need to think carefully about this class hierarchy.
15 | -- Every distribution has a cdf, but cdfs are often difficult to calculate.
16 | -- Not every distribution has a density, but most of the useful ones do.
17 | class Distribution d where
18 | 
19 |     mean :: d -> Elem d
20 | 
21 |     -- | The density (or mass) function of the distribution.
22 |     pdf :: d -> Elem d -> Scalar d
23 | 
24 |     -- | Gives a result proportional to the pdf.
25 |     -- This is usually faster to compute, and many algorithms don't require the pdf be normalized.
26 |     pdf_ :: d -> Elem d -> Scalar d
27 |     pdf_ = pdf
28 | 
29 | --------------------------------------------------------------------------------
30 | 
31 | -- | Stores the unnormalized raw 0th, 1st, and 2nd moments of a distribution.
32 | -- These are sufficient statistics for many distributions.
33 | -- We can then easily construct those distributions from these statistics.
34 | --
35 | -- FIXME:
36 | -- This data type has some numeric stability built-in.
37 | -- But for many distributions, there exist methods of calculating the parameters that are even more stable.
38 | data Moments v = Moments
39 |     { m0 :: !(Scalar v)
40 |     , m1 :: !v
41 |     , m2 :: !(v><v)
42 |     }
43 | 
44 | type instance Scalar (Moments v) = Scalar v
45 | type instance Elem (Moments v) = v
46 | type instance Moments v >< r = Moments (v><r)
47 | 
48 | -------------------
49 | -- algebra
50 | 
51 | -- | The moments of two samples add componentwise to give the moments
52 | -- of the combined sample, so the instances below make training a
53 | -- monoid homomorphism.
54 | 
55 | instance Hilbert v => Semigroup (Moments v) where
56 |     (Moments a1 b1 c1)+(Moments a2 b2 c2) = Moments (a1+a2) (b1+b2) (c1+c2)
57 | 
58 | instance Hilbert v => Monoid (Moments v) where
59 |     zero = Moments zero zero zero
60 | 
61 | instance Hilbert v => Abelian (Moments v)
62 | 
63 | instance Hilbert v => Cancellative (Moments v) where
64 |     (Moments a1 b1 c1)-(Moments a2 b2 c2) = Moments (a1-a2) (b1-b2) (c1-c2)
65 | 
66 | instance Hilbert v => Group (Moments v) where
67 |     negate (Moments a b c) = Moments (negate a) (negate b) (negate c)
68 | 
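Because these instances act componentwise on the sufficient statistics, training distributes over dataset concatenation. A small sketch of the resulting homomorphism property, assuming `v ~ Double` (so `v><v` is again `Double`) and an `Eq` instance for `Moments`, neither of which this module pins down:

-- Fold a dataset into its Moments using the Constructible instance below.
trainMoments :: [Double] -> Moments Double
trainMoments = foldr (\x m -> singleton x + m) zero

-- Training on two halves and merging agrees with training on the whole:
-- both sides hold m0 = 4, m1 = 1+2+3+4 = 10, m2 = 1+4+9+16 = 30.
homomorphismHolds :: Bool
homomorphismHolds =
    trainMoments [1,2,3,4] == trainMoments [1,2] + trainMoments [3,4]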
69 | instance Hilbert v => Module (Moments v) where
70 |     (Moments a b c).*r = Moments (r*a) (b.*r) (c.*r)
71 | 
72 | instance Hilbert v => FreeModule (Moments v) where
73 |     -- TODO: what is the probabilistic interpretation of this?
74 |     (Moments a1 b1 c1).*.(Moments a2 b2 c2) = Moments (a1*a2) (b1.*.b2) (c1.*.c2)
75 | 
76 | instance Hilbert v => VectorSpace (Moments v) where
77 |     (Moments a b c)./r = Moments (a/r) (b./r) (c./r)
78 | 
79 |     (Moments a1 b1 c1)./.(Moments a2 b2 c2) = Moments (a1/a2) (b1./.b2) (c1./.c2)
80 | 
81 | -------------------
82 | -- container hierarchy
83 | 
84 | instance Hilbert v => Constructible (Moments v) where
85 |     singleton v = Moments 1 v (v><v)
86 | 
87 | --------------------------------------------------------------------------------
88 | 
89 | -- | The normal distribution, represented by its sufficient statistics
90 | -- (the unnormalized moments of the training sample).
91 | newtype Normal v = Normal (Moments v)
92 | 
93 | type instance Scalar (Normal v) = Scalar v
94 | type instance Elem (Normal v) = v
95 | 
96 | -------------------
97 | -- algebra
98 | 
99 | -- | Like 'Moments', a 'Normal' is trained homomorphically:
100 | -- merging two trained normals gives the normal trained on the
101 | -- union of their datasets.
102 | instance Hilbert v => Semigroup (Normal v) where
103 |     (Normal n1)+(Normal n2)=Normal $ n1+n2
104 | 
105 | instance Hilbert v => Monoid (Normal v) where
106 |     zero = Normal zero
107 | 
108 | train1Normal :: Hilbert v => v -> Normal v
109 | train1Normal v = Normal $ singleton v
110 | 
111 | instance (FiniteModule v, Hilbert v) => Distribution (Normal v) where
112 | 
113 |     mean (Normal (Moments m0 m1 m2)) = m1 ./ m0
114 | 
115 |     pdf (Normal (Moments m0 m1 m2)) v
116 |         = (2*pi*size sigma)**(-fromIntegral (dim v)/2)*exp((-1/2)*(v' `vXm` reciprocal sigma)<>v')
117 |         where
118 |             v' = v - mu
119 | 
120 |             mu = m1 ./ m0
121 |             sigma = 1 + m2 ./ m0 - mu><mu
122 | 
123 | -- | The covariance of the distribution.
124 | variance :: Hilbert v => Normal v -> v><v
125 | variance (Normal (Moments m0 m1 m2)) = m2 ./ m0 - mu><mu
126 |     where
127 |         mu = m1 ./ m0
-------------------------------------------------------------------------------- /src/HLearn/Optimization/Amoeba.hs: --------------------------------------------------------------------------------
1 | module HLearn.Optimization.Amoeba
2 |     where
3 | 
4 | import Control.Monad
5 | import Control.Monad.ST
6 | import Data.List
7 | import Data.STRef
8 | import Debug.Trace
9 | 
10 | import qualified Data.Vector as V
11 | import qualified Data.Vector.Mutable as VM
12 | import qualified Data.Vector.Generic as VG
13 | import qualified Data.Vector.Generic.Mutable as VGM
14 | import qualified Data.Vector.Algorithms.Intro as Intro
15 | 
16 | import Numeric.LinearAlgebra hiding ((<>))
17 | import qualified Numeric.LinearAlgebra as LA
18 | 
19 | findMinAmoeba f x0 = runST $ do
20 | 
21 |     -- initialize simplex
22 |     vec <- VM.new (VG.length x0+1)
23 |     VGM.write vec 0 (f x0,x0)
24 |     forM [1..VGM.length vec-1] $ \i -> do
25 |         e_i <- VGM.replicate (VG.length x0) 0
26 |         VGM.write e_i (i-1) 1
27 |         e_i' <- VG.freeze e_i
28 |         let x_i = x0 `LA.add` e_i'
29 |         VGM.write vec i (f x_i,x_i)
30 | 
31 |     -- iterate
32 |     vec' <- itrM 1000 (stepAmoeba f) vec
33 | 
34 |     -- return
35 |     (_,ret) <- VGM.read vec 0
36 |     return ret
37 | 
38 | stepAmoeba f vec = stepAmoebaRaw 1 2 (-1/2) (1/2) f vec
39 | 
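The four literals `stepAmoeba` feeds to `stepAmoebaRaw` are the textbook Nelder-Mead coefficients: reflection alpha = 1, expansion gamma = 2, contraction ro = -1/2, and shrink sigma = 1/2. A hypothetical call to the driver on the sphere function (the binding name and starting point are mine; `<.>` is hmatrix's dot product):

-- Minimize f(v) = <v,v>; after itrM's 1000 steps the returned point
-- should sit very close to the origin.
sphereArgmin :: LA.Vector Double
sphereArgmin = findMinAmoeba (\v -> v LA.<.> v) (LA.fromList [5,3,2])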
40 | stepAmoebaRaw ::
41 |     ( Fractional b
42 |     , Ord b
43 |     , VGM.MVector vec (b,a)
44 | --     , a ~ LA.Matrix b
45 |     , a ~ LA.Vector b
46 |     , Field b
47 |     , vec ~ VM.MVector
48 |     , b ~ Double
49 |     ) => b
50 |     -> b
51 |     -> b
52 |     -> b
53 |     -> (a -> b)
54 |     -> vec s (b,a)
55 |     -> ST s (vec s (b,a))
56 | stepAmoebaRaw alpha gamma ro sigma f vec = do
57 | 
58 |     Intro.sortBy (\a b -> compare (fst a) (fst b)) vec
59 | 
60 |     (f_1,x_1) <- VGM.read vec 0
61 |     (f_n,x_n) <- VGM.read vec $ VGM.length vec -2    -- second-worst point
62 |     (f_n1,x_n1) <- VGM.read vec $ VGM.length vec -1  -- worst point
63 | 
64 |     x_0 <- liftM ( scale (1/fromIntegral (VGM.length vec-1))
65 |                  . foldl1' (LA.add)
66 |                  . init
67 |                  . map snd
68 |                  . V.toList
69 |                  ) $ VG.unsafeFreeze vec
70 | 
71 |     let x_r = x_0 `LA.add` (scale alpha $ x_0 `LA.sub` x_n1)
72 |         f_r = f x_r
73 | 
74 |         x_e = x_0 `LA.add` (scale gamma $ x_0 `LA.sub` x_n1)
75 |         f_e = f x_e
76 | 
77 |         x_c = x_0 `LA.add` (scale ro $ x_0 `LA.sub` x_n1)
78 |         f_c = f x_c
79 | 
80 |     -- check reflection: accept x_r if it beats the second-worst point
81 |     if f_1 <= f_r && f_r < f_n
82 |         then VGM.write vec (VGM.length vec-1) (f_r,x_r)
83 | 
84 |     -- check expansion
85 |         else if f_r < f_1
86 |             then if f_e < f_r
87 |                 then VGM.write vec (VGM.length vec-1) (f_e,x_e)
88 |                 else VGM.write vec (VGM.length vec-1) (f_r,x_r)
89 | 
90 |     -- check contraction
91 |             else if f_c < f_n1
92 |                 then VGM.write vec (VGM.length vec-1) (f_c,x_c)
93 | 
94 |     -- reduction
95 |                 else forM_ [1..VGM.length vec-1] $ \i -> do
96 |                     (f_i,x_i) <- VGM.read vec i
97 |                     let x_i' = x_1 `LA.add` (scale sigma $ x_i `LA.sub` x_1)
98 |                         f_i' = f x_i'
99 |                     VGM.write vec i (f_i',x_i')
100 | 
101 |     return vec
102 | --     refMinVal <- newSTRef (-infinity)
103 | --     refMinIndex <- newSTRef 0
104 | --     forM [0..VGM.length vec-1] $ \i -> do
105 | --         ival <- VGM.read vec i
106 | --         minVal <- readSTRef refMinVal
107 | --         if minVal < fst ival
108 | --             then return ()
109 | --             else do
110 | --                 writeSTRef refMinVal $ fst ival
111 | --                 writeSTRef refMinIndex i
112 | --     undefined
113 | 
114 | itrM :: Monad m => Int -> (a -> m a) -> a -> m a
115 | itrM 0 f a = return a
116 | itrM i f a = do
117 |     a' <- f a
118 | --     if a' == a
119 | --         then trace ("no movement\n a="++show a++"\n a'="++show a') $ return ()
120 | --         else trace ("itrM i="++show i++"; f a ="++show a'++"; a="++show a) $ return ()
121 |     itrM (i-1) f a'
122 | 
-------------------------------------------------------------------------------- /src/HLearn/Optimization/Conic.hs: --------------------------------------------------------------------------------
1 | module HLearn.Optimization.Conic
2 |     where
3 | 
4 | import Control.DeepSeq
5 | import Control.Monad
6 | import Control.Monad.Random
7 | import Control.Monad.ST
8 | import Data.List
9 | import Data.List.Extras
10 | import Debug.Trace
11 | import qualified Data.Vector as V
12 | import qualified Data.Vector.Mutable as VM
13 | import qualified Data.Vector.Storable as VS
14 | import qualified Data.Vector.Storable.Mutable as VSM
15 | import qualified Data.Vector.Generic as VG
16 | import qualified Data.Vector.Generic.Mutable as VGM
17 | import qualified Data.Vector.Algorithms.Intro as Intro
18 | import Numeric.LinearAlgebra hiding ((<>))
19 | import qualified Numeric.LinearAlgebra as LA
20 | import Data.Random.Normal
21 | 
22 | import HLearn.Algebra
23 | import qualified HLearn.Optimization.Common as Recipe
24 | import qualified HLearn.Optimization.LineMinimization as LineMin
25 | 
26 | -------------------------------------------------------------------------------
27 | 
28 | conicprojection :: Matrix Double -> Matrix Double
29 | conicprojection m = cmap realPart $ u LA.<> lambda' LA.<> trans u
30 |     where
31 |         lambda' = cmap (\x -> if realPart x < 0 then 0 else x) lambda
32 |         lambda = diagRect 0 l (VG.length l) (VG.length l)
33 |         (l,u) = eig m
34 | 
35 | -------------------------------------------------------------------------------
36 | 
37 | data RandomConicPersuit a = RandomConicPersuit
38 |     { _stdgen :: !StdGen
39 |     , _soln :: !a
40 |     , _fx :: !(Scalar a)
41 |     , _solnlast :: !a
42 |     }
43 | 
44 | -- data OptInfo a = OptInfo
45 | --     { _stdgen :: !StdGen
46 | --     , _x :: !a
47 | --     , _fx ::
48 | --     , _xold :: !a
49 | --     }
50 | 
51 | itr :: Int -> (tmp -> a) -> (tmp -> Bool) -> (tmp -> tmp) -> tmp -> a
52 | 
itr i result stop step init = --trace ("i="++show i++"; fx="++show (_fx init)) $ --trace ("i="++show i++"; init="++show init) $ 53 | if i==0 || stop init 54 | then result init 55 | else itr (i-1) result stop step (step init) 56 | 57 | randomConicPersuit f x0 = _soln $ itr 10 id (\x -> False) (step_RandomConicPersuit f) init 58 | where 59 | init = RandomConicPersuit 60 | { _stdgen = mkStdGen $ round $ sumElements x0 61 | , _soln = x0 62 | , _fx = f x0 63 | , _solnlast = x0 64 | } 65 | 66 | conicProjection f x0 = _soln $ argmin _fx [itr 1 id (\x -> False) (step_RandomConicPersuit f) (init i) | i <- [0..10]] 67 | where 68 | init i = RandomConicPersuit 69 | { _stdgen = mkStdGen $ i+(round $ sumElements x0) 70 | , _soln = x0 71 | , _fx = f x0 72 | , _solnlast = x0 73 | } 74 | 75 | stop_RandomConicPersuit (RandomConicPersuit stdgen soln fx solnlast) = undefined 76 | 77 | step_RandomConicPersuit f (RandomConicPersuit stdgen soln fx solnlast) = --trace ("lambda = "++show lambda++ "; phi="++show phi) $ 78 | ret 79 | where 80 | -- (x', stdgen') = runRand ((LA.fromList . take (rows soln)) `liftM` getRandomRs (-1,1)) stdgen 81 | normL 0 (xs,g) = (xs,g) 82 | normL i (xs,g) = normL (i-1) (x:xs,g') 83 | where 84 | (x,g') = normal g 85 | 86 | (x'std, stdgen') = normL (rows soln) ([],stdgen) 87 | x' = VS.fromList x'std 88 | -- (lambda,phi::Matrix Double) = eigSH soln 89 | -- q = (diag $ cmap (sqrt.abs) lambda) LA.<> phi 90 | -- x' = (scale (1-kappa) q - scale kappa (ident (rows q))) LA.<> LA.fromList x'std 91 | -- kappa = 1e-4 92 | 93 | y' = asColumn x' LA.<> asRow x' 94 | -- y' = asColumn (LA.fromList x') LA.<> asRow (LA.fromList x') 95 | 96 | g_alpha alpha = f $ scale alpha y' + soln 97 | alpha_hat = error "step_randomConic" -- LineMin._x $ runOptimization $ LineMin.brent g_alpha (LineMin.lineBracket g_alpha 0 1) 98 | 99 | alpha_hat_y' = scale alpha_hat y' 100 | g_beta beta = f $ alpha_hat_y' + scale beta soln 101 | beta_hat = error "step_randomConic" -- LineMin._x $ LineMin.brent g_beta (LineMin.lineBracket g_beta 0 1) 102 | 103 | soln' = alpha_hat_y' + scale beta_hat soln 104 | 105 | ret = RandomConicPersuit 106 | { _stdgen = stdgen' 107 | , _soln = soln' 108 | , _fx = f soln' 109 | , _solnlast = soln 110 | } 111 | 112 | -------------------------------------------------------------------------------- /src/HLearn/Optimization/StepSize.hs: -------------------------------------------------------------------------------- 1 | module HLearn.Optimization.StepSize 2 | ( 3 | -- * fancy step sizes 4 | -- ** Almeida Langlois 5 | lrAlmeidaLanglois 6 | 7 | 8 | -- * simple step sizes 9 | -- ** linear decrease 10 | , lrLinear 11 | , eta 12 | -- , gamma 13 | 14 | -- ** constant step size 15 | , lrConst 16 | -- , step 17 | ) 18 | where 19 | 20 | import HLearn.Optimization.StepSize.Linear 21 | import HLearn.Optimization.StepSize.Const 22 | import HLearn.Optimization.StepSize.AlmeidaLanglois 23 | -------------------------------------------------------------------------------- /src/HLearn/Optimization/StepSize/AlmeidaLanglois.hs: -------------------------------------------------------------------------------- 1 | -- | 2 | -- 3 | -- See: Almeida and Langlois, "Parameter Adaptation in Stochastic Optimization" equation 4.6 4 | module HLearn.Optimization.StepSize.AlmeidaLanglois 5 | where 6 | 7 | import SubHask hiding (Functor(..), Applicative(..), Monad(..), Then(..), fail, return) 8 | import HLearn.History 9 | import HLearn.Optimization.Common 10 | 11 | import qualified Data.Vector.Generic as VG 12 | 13 | -- 
lrAlmeidaLanglois ::
14 | --     ( VectorSpace v
15 | --     , VG.Vector u r
16 | --     , u r ~ v
17 | --     , Field r
18 | --     ) => Scalar v
19 | --     -> Scalar v
20 | --     -> v
21 | --     -> AlmeidaLanglois v
22 | --     -> History (AlmeidaLanglois v)
23 | -- lrAlmeidaLanglois k gamma grad AlmeidaLangloisInit = return $ AlmeidaLanglois
24 | --     { d = grad
25 | --     , v = VG.map (const 100) grad
26 | --     , p = VG.map (const 0.1) grad
27 | --     }
28 | -- lrAlmeidaLanglois k gamma grad al = return $ AlmeidaLanglois
29 | --     { d = grad
30 | --     , v = v'
31 | --     , p = p al - k *. p al .*. d al .*. grad ./. v'
32 | --     }
33 | --     where
34 | --         v' = gamma *. v al + (1 - gamma) *. grad .*. grad
35 | 
36 | lrAlmeidaLanglois :: (VectorSpace (v r), VG.Vector v r) => Hyperparams (v r)
37 | lrAlmeidaLanglois = Hyperparams
38 |     { k = 0.01
39 |     , gamma = 0.9
40 |     , u = 1.01
41 |     }
42 | 
43 | data Hyperparams v = Hyperparams
44 |     { k :: !(Scalar v)
45 |     , gamma :: !(Scalar v)
46 |     , u :: !(Scalar v)
47 |     }
48 | 
49 | -- | The cryptic variable names in this data type match the variable names used in the original paper.
50 | data Params v = Params
51 |     { d :: !v -- ^ gradient
52 |     , v :: !v -- ^ component-wise running average of the gradient squared
53 |     , p :: !v -- ^ actual step size
54 |     }
55 | 
56 | instance (r ~ Scalar (v r), VectorSpace (v r), VG.Vector v r) => LearningRate Hyperparams Params (v r) where
57 |     lrInit (Hyperparams _ _ u) x = Params
58 |         { d = VG.map (const $ 1/u) x
59 |         , v = VG.map (const $ 1/u) x
60 |         , p = VG.map (const $ 0.1) x
61 |         }
62 | 
63 |     lrStep (Hyperparams k gamma _) (Params d0 v0 p0) grad = return $ Params
64 |         { d = grad
65 |         , v = v'
66 |         , p = p0 + k *. p0 .*. d0 .*. grad ./. v'
67 |         }
68 |         where
69 |             v' = gamma *. v0 + (1 - gamma) *. grad .*. grad
70 | 
71 |     lrApply _ (Params d v p) x = p .*. x
72 | 
-------------------------------------------------------------------------------- /src/HLearn/Optimization/StepSize/Const.hs: --------------------------------------------------------------------------------
1 | -- | For setting a constant step size.
2 | module HLearn.Optimization.StepSize.Const
3 |     ( lrConst
4 |     , Hyperparams (..)
5 |     )
6 |     where
7 | 
8 | import SubHask hiding (Functor(..), Applicative(..), Monad(..), Then(..), fail, return)
9 | import HLearn.History
10 | import HLearn.Optimization.Common
11 | 
12 | lrConst :: VectorSpace v => Hyperparams v
13 | lrConst = Hyperparams
14 |     { step = 0.001
15 |     }
16 | 
17 | newtype Hyperparams v = Hyperparams
18 |     { step :: Scalar v
19 |     }
20 | 
21 | data Params v = Params
22 | 
23 | instance VectorSpace v => LearningRate Hyperparams Params v where
24 |     lrInit _ _ = Params
25 | 
26 |     lrStep (Hyperparams step) _ _ = return Params
27 | 
28 |     lrApply (Hyperparams step) _ v = step *. v
29 | 
30 | 
31 | 
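Const.hs is the degenerate member of the `LearningRate` family: `lrStep` ignores the gradient entirely and `lrApply` only scales it, which makes it a useful smoke test for the SGD driver defined later in this tree. Since the module exports `Hyperparams (..)`, the default can be overridden with plain record-update syntax; a hedged sketch (the `myRate` name and the 0.01 value are illustrative):

-- A larger constant step than the 0.001 default.
myRate :: VectorSpace v => Hyperparams v
myRate = lrConst { step = 0.01 }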
-------------------------------------------------------------------------------- /src/HLearn/Optimization/StepSize/Linear.hs: --------------------------------------------------------------------------------
1 | -- | This is the classic formula for adjusting the learning rate, from Robbins and Monro's original 1951 paper "A stochastic approximation method".
2 | -- Asymptotically, it has optimal convergence on strongly convex functions.
3 | -- But in practice setting good hyperparameters is difficult, and many important functions are not strongly convex.
4 | module HLearn.Optimization.StepSize.Linear
5 |     ( lrLinear
6 |     , eta
7 |     , gamma
8 |     , Hyperparams (..)
9 |     )
10 |     where
11 | 
12 | import SubHask hiding (Functor(..), Applicative(..), Monad(..), Then(..), fail, return)
13 | import HLearn.History
14 | import HLearn.Optimization.Common
15 | 
16 | {-
17 | -- | This is a slight generalization of "lrLinear" taken from the paper "Stochastic Gradient Descent Tricks"
18 | lrBottou ::
19 |     ( Floating (Scalar v)
20 |     ) => Scalar v -- ^ the initial value
21 |       -> Scalar v -- ^ the rate of decrease (recommended: the smallest eigenvalue of the Hessian)
22 |       -> Scalar v -- ^ the exponent;
23 |                   -- this should be negative to ensure the learning rate decreases;
24 |                   -- setting it to `-1` gives "lrLinear"
25 |       -> LearningRate v
26 | lrBottou gamma lambda exp = do
27 |     t <- currentItr
28 |     return $ \v -> v .* (gamma * ( 1 + gamma * lambda * t) ** exp)
29 | -}
30 | 
31 | lrLinear :: VectorSpace v => Hyperparams v
32 | lrLinear = Hyperparams
33 |     { eta = 0.001
34 |     , gamma = 0.1
35 |     }
36 | 
37 | data Hyperparams v = Hyperparams
38 |     { eta :: !(Scalar v) -- ^ the initial value for the step size
39 |     , gamma :: !(Scalar v) -- ^ the rate of decrease
40 |     }
41 | 
42 | newtype Params v = Params
43 |     { step :: Scalar v
44 |     }
45 | 
46 | instance VectorSpace v => LearningRate Hyperparams Params v where
47 |     lrInit (Hyperparams eta gamma) _ = Params eta
48 | 
49 |     lrStep (Hyperparams eta gamma) _ _ = do
50 |         t <- currentItr
51 |         return $ Params $ eta / (1 + gamma * t)
52 | 
53 |     lrApply _ (Params r) v = r *. v
54 | 
-------------------------------------------------------------------------------- /src/HLearn/Optimization/StochasticGradientDescent.hs: --------------------------------------------------------------------------------
1 | -- |
2 | --
3 | -- Important references:
4 | --
5 | -- 1. Bach and Moulines, "Non-Asymptotic Analysis of Stochastic Approximation Algorithms for Machine Learning"
6 | -- 2. Xu, "Towards optimal one pass large scale learning with averaged stochastic gradient descent"
7 | --
8 | module HLearn.Optimization.StochasticGradientDescent
9 | --     (
10 | --     )
11 |     where
12 | 
13 | import SubHask
14 | 
15 | import HLearn.History
16 | import HLearn.Optimization.Multivariate
17 | 
18 | import qualified Data.Vector.Generic as VG
19 | 
20 | --------------------------------------------------------------------------------
21 | 
22 | class LearningRate init step v | init -> step, step -> init where
23 |     lrStep :: init v -> step v -> v -> History (step v)
24 |     lrInit :: init v -> v -> step v
25 |     lrApply :: init v -> step v -> v -> v
26 | 
27 | --------------------------------------------------------------------------------
28 | 
29 | data SGD container step dp v = SGD
30 |     { __dataset :: !(container dp)
31 |     , __x1 :: !v
32 |     , __step :: !(step v)
33 |     }
34 |     deriving Typeable
35 | 
36 | instance Show (SGD a b c d) where
37 |     show _ = "SGD"
38 | 
39 | instance Has_x1 (SGD container step dp) v where x1 = __x1
40 | 
41 | type StochasticMethod container dp v
42 |     = container dp -> (dp -> v -> v) -> History (v -> v)
43 | 
44 | randomSample ::
45 |     ( VG.Vector container dp
46 |     ) => StochasticMethod container dp v
47 | randomSample !dataset !f = do
48 |     t <- currentItr
49 |     let a = 1664525
50 |         c = 1013904223
51 |     return $ f $ dataset VG.! ((a*t + c) `mod` VG.length dataset)
52 | 
53 | linearScan ::
54 |     ( VG.Vector container dp
55 |     ) => StochasticMethod container dp v
56 | linearScan !dataset !f = do
57 |     t <- currentItr
58 |     return $ f $ dataset VG.! ( t `mod` VG.length dataset )
59 | 
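`randomSample` and `linearScan` look at a single datapoint per iteration, and `minibatch` just below generalizes them to a block of points. For contrast, a full-batch method also fits the `StochasticMethod` shape, recovering ordinary deterministic gradient descent; this helper is an illustration, not part of the module:

-- Sum the gradient contribution of every datapoint each iteration.
fullBatch ::
    ( VG.Vector container dp
    , Semigroup v
    ) => StochasticMethod container dp v
fullBatch !dataset !f = return $ \x ->
    VG.foldl' (\acc dp -> acc + f dp x) (f (VG.head dataset) x) (VG.tail dataset)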
60 | minibatch :: forall container dp v.
61 |     ( VG.Vector container dp
62 |     , Semigroup v
63 |     ) => Int -> StochasticMethod container dp v
64 | minibatch !size !dataset !f = do
65 |     t <- currentItr
66 |     let start = t `mod` VG.length dataset
67 |         minidata = VG.slice start (min (VG.length dataset-start) size) dataset
68 | 
69 |     let go (-1) ret = ret
70 |         go j ret = go (j-1) $ ret + (f $ minidata `VG.unsafeIndex` j)
71 | 
72 |     return $ go (VG.length minidata-1) id
73 | 
74 | ---------------------------------------
75 | 
76 | stochasticGradientDescent ::
77 |     ( VectorSpace v
78 |     , LearningRate hyperparams params v
79 |     , Typeable v
80 |     , Typeable container
81 |     , Typeable dp
82 |     , Typeable params
83 |     ) => StochasticMethod container dp v
84 |       -> hyperparams v
85 |       -> container dp
86 |       -> (dp -> v -> v)
87 |       -> v
88 |       -> StopCondition_ (SGD container params dp v)
89 |       -> History (SGD container params dp v)
90 | stochasticGradientDescent !sm !lr !dataset !f' !x0 !stops = iterate
91 |     ( stepSGD sm lr f' )
92 |     ( SGD dataset x0 $ lrInit lr x0 )
93 |     stops
94 | 
95 | stepSGD ::
96 |     ( VectorSpace v
97 |     , LearningRate hyperparams params v
98 |     , Typeable v
99 |     , Typeable container
100 |     , Typeable dp
101 |     , Typeable params
102 |     ) => StochasticMethod container dp v
103 |       -> hyperparams v
104 |       -> (dp -> v -> v)
105 |       -> SGD container params dp v
106 |       -> History (SGD container params dp v)
107 | stepSGD !sm !hyperparams !f' !sgd = do
108 |     calcgrad <- sm (__dataset sgd) f'
109 |     let grad = calcgrad $ __x1 sgd
110 |     step <- lrStep hyperparams (__step sgd) grad
111 |     report $ sgd
112 |         { __x1 = __x1 sgd - (lrApply hyperparams step grad)
113 |         , __step = step
114 |         }
-------------------------------------------------------------------------------- /src/HLearn/Optimization/TestFunctions.hs: --------------------------------------------------------------------------------
1 | module HLearn.Optimization.TestFunctions
2 |     where
3 | 
4 | 
5 | import SubHask
6 | import SubHask.Category.Trans.Derivative
7 | 
8 | import HLearn.History
9 | import HLearn.Optimization.Multivariate
10 | import HLearn.Optimization.Univariate
11 | 
12 | --------------------------------------------------------------------------------
13 | 
14 | -- import SubHask.Compatibility.Vector
15 | -- import SubHask.Compatibility.HMatrix
16 | -- import qualified Data.Vector.Generic as VG
17 | 
18 | sphere :: Hilbert v => C2 (v -> Scalar v)
19 | sphere = unsafeProveC2 f f' f''
20 |     where
21 |         f v = let x = size v in x*x
22 |         f' v = 2*.v
23 |         f'' v = 2
24 | 
25 | logloss :: ( IsScalar v, Floating v ) => C2 ( v -> v )
26 | logloss = unsafeProveC2 f f' f''
27 |     where
28 |         f x = log (1 + exp (-x) )
29 |         f' x = -1/(1 + exp x)
30 |         f'' x = exp x / (1 + exp x)**2
31 | 
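With the SGD driver above and a step-size policy from HLearn.Optimization.StepSize, the only missing piece is a per-datapoint gradient in the `dp -> v -> v` shape. A hedged end-to-end sketch for one-dimensional least squares (the `dataset`, `w0`, and `stops` bindings are assumptions, and `stops` must be whatever `StopCondition_` the caller wants):

-- Per-datapoint gradient of (w*x - y)^2 with respect to the weight w.
lsGrad :: (Double,Double) -> Double -> Double
lsGrad (x,y) w = 2*(w*x - y)*x

-- fit :: History (SGD V.Vector Params (Double,Double) Double)
-- fit = stochasticGradientDescent (minibatch 16) lrLinear dataset lsGrad w0 stops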
32 | {-
33 | f :: Hilbert v => v -> v -> Scalar v
34 | f x y = (x-y)<>(x-y)
35 | 
36 | f' :: Hilbert v => v -> v -> v
37 | f' x y = 2*.(x-y)
38 | 
39 | f'' :: (TensorAlgebra v, Module (Outer v), Hilbert v) => v -> v -> Outer v
40 | f'' x y = 2*.1
41 | 
42 | u = VG.fromList [1,2,3,4,5,6,7] :: Vector Double
43 | v = VG.fromList [2,3,1,0,-2,-3,-1] :: Vector Double
44 | w = VG.fromList [1,1] :: Vector Double
45 | 
46 | beale :: (FiniteModule v, Hilbert v) => v -> Scalar v
47 | beale v = (1.5-x+x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2
48 |     where
49 |         x=v!0
50 |         y=v!1
51 | 
52 | beale' :: (FiniteModule v, Hilbert v) => v -> v
53 | beale' v = unsafeToModule
54 |     [ 2*(-1+y)*(1.5-x+x*y) + 2*(-1+y**2)*(2.25 - x + x*y**2) + 2*(-1+y**3)*(2.625 - x + x*y**3)
55 |     , 2*x*(1.5-x+x*y) + 2*x*2*y*(2.25 - x + x*y**2) + 2*x*3*y*y*(2.625 - x + x*y**3)
56 |     ]
57 |     where
58 |         x=v!0
59 |         y=v!1
60 | 
61 | -- beale'' :: (FiniteModule v, TensorAlgebra, Hilbert v) => v -> Outer v
62 | -- beale'' v = mkMatrix 2 2
63 | --     [ 2*(-1+y)*(-1+y)**2 + 2*(-1+y**2)*(- 1 + y**2)**2 + 2*(-1+y**3)*(-1 + y**3)**2
64 | --     , 2*(-1+y)*(1.5-x+x*y)**2
65 | --       + 2*(-1+y)*(1.5-x+x*y)**2
66 | --       + 2*(-1+y**2)*(2.25 - x + x*y**2)**2
67 | --       + 2*(-1+y**3)*(2.625 - x + x*y**3)**2
68 | --     , 2*x*(1.5-x+x*y)**2 + 2*x*2*y*(2.25 - x + x*y**2)**2 + 2*x*3*y*y*(2.625 - x + x*y**3)**2
69 | --     , 2*x*(1.5-x+x*y)**2 + 2*x*2*y*(2.25 - x + x*y**2)**2 + 2*x*3*y*y*(2.625 - x + x*y**3)**2
70 | --     ]
71 | -}
-------------------------------------------------------------------------------- /stack.yaml: --------------------------------------------------------------------------------
1 | flags: {}
2 | packages:
3 | - '.'
4 | - subhask/
5 | extra-deps:
6 |     [ gamma-0.9.0.2
7 |     , continued-fractions-0.9.1.1
8 |     , converge-0.1.0.1
9 |     ]
10 | resolver: lts-5.14
11 | 
-------------------------------------------------------------------------------- /test/BashTests.hs: --------------------------------------------------------------------------------
1 | {-
2 |  - This file is a hack so that we can run bash script tests using cabal.
3 |  - See: http://stackoverflow.com/questions/31213883/how-to-use-cabal-with-bash-tests
4 |  -
5 |  - Note that running `cabal test` will also run `cabal install`.
6 |  -
7 |  -}
8 | 
9 | import Control.Monad
10 | import System.Exit
11 | import System.Process
12 | 
13 | runBashTest :: String -> IO ()
14 | runBashTest cmd = do
15 |     putStr $ cmd ++ "..."
16 |     ExitSuccess <- system cmd
17 |     return ()
18 | 
19 | main = do
20 |     system "cabal install"
21 |     sequence_
22 |         [ runBashTest $ "./test/allknn-verify/runtest.sh "
23 |             ++ " " ++ dataset
24 |             ++ " " ++ treetype
25 |             ++ " " ++ foldtype
26 |             ++ " " ++ rotate
27 | 
28 |         | dataset <-
29 |             [ "./test/allknn-verify/dataset-10000x2.csv"
30 | --             , "./test/allknn-verify/dataset-10000x20.csv"
31 | --             , "./test/allknn-verify/mnist-10000.csv"
32 |             ]
33 |         , treetype <-
34 |             [ "--treetype=simplified"
35 |             , "--treetype=ancestor"
36 |             ]
37 |         , foldtype <-
38 |             [ "--fold=foldsort"
39 |             , "--fold=fold"
40 |             ]
41 |         , rotate <-
42 |             [ "--rotate=norotate"
43 |             , "--rotate=variance"
44 |             ]
45 |         ]
46 | 
-------------------------------------------------------------------------------- /test/QuickCheck.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE NoImplicitPrelude #-}
2 | {-# LANGUAGE TemplateHaskell #-}
3 | {-# LANGUAGE DataKinds #-}
4 | 
5 | module Main
6 |     where
7 | 
8 | import HLearn.Data.SpaceTree.CoverTree
9 | 
10 | import SubHask
11 | import SubHask.Algebra.Container
12 | import SubHask.Algebra.Vector
13 | import SubHask.TemplateHaskell.Test
14 | 
15 | import Test.Framework (defaultMain, testGroup)
16 | import Test.Framework.Providers.QuickCheck2 (testProperty)
17 | import Test.Framework.Runners.Console
18 | import Test.Framework.Runners.Options
19 | 
20 | --------------------------------------------------------------------------------
21 | 
22 | main = defaultMainWithOpts
23 |     [ testGroup "CoverTree_"
24 |         [ $( mkSpecializedClassTests [t| UCoverTree (UVector "dyn" Float) |] [ ''Constructible ] )
25 |         , $( mkSpecializedClassTests [t| BCoverTree (UVector "dyn" Float) |] [ ''Constructible ] )
26 |         ]
27 |     ]
28 |     $ RunnerOptions
29 |         { ropt_threads = Nothing
30 |         , ropt_test_options = Nothing
31 |         , ropt_test_patterns = Nothing
32 |         , ropt_xml_output = Nothing
33 |         ,
ropt_xml_nested = Nothing 34 | , ropt_color_mode = Just ColorAlways 35 | , ropt_hide_successes = Just True 36 | , ropt_list_only = Just True 37 | } 38 | -------------------------------------------------------------------------------- /test/allknn-mlpack/runtest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | K=1 4 | 5 | tmpdir=$(mktemp --tmpdir -d) 6 | 7 | echo "--------------------------------------------------------------------------------" 8 | echo "tempdir=$tmpdir" 9 | echo "--------------------------------------------------------------------------------" 10 | hlearn_neighbors="$tmpdir/neighbors_hlearn.csv" 11 | hlearn_distances="$tmpdir/distances_hlearn.csv" 12 | mlpack_neighbors="$tmpdir/neighbors_mlpack.csv" 13 | mlpack_distances="$tmpdir/distances_mlpack.csv" 14 | 15 | hlearn-allknn -k $K -r $@ -n "$hlearn_neighbors" -d "$hlearn_distances" +RTS -K1000M -N1 16 | allknn -r $1 -n "$mlpack_neighbors" -d "$mlpack_distances" -k $K -v 17 | 18 | echo "-------------------------------------" 19 | echo "num differences: " `diff $hlearn_neighbors $mlpack_neighbors | wc -l` " / " `cat $1 | wc -l` 20 | diff $hlearn_neighbors $mlpack_neighbors > /dev/null 21 | -------------------------------------------------------------------------------- /test/allknn-verify/runtest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | K=1 4 | 5 | tmpdir=$(mktemp --tmpdir -d) 6 | 7 | echo "--------------------------------------------------------------------------------" 8 | echo "tempdir=$tmpdir" 9 | echo "--------------------------------------------------------------------------------" 10 | hlearn_neighbors="$tmpdir/neighbors_hlearn.csv" 11 | hlearn_distances="$tmpdir/distances_hlearn.csv" 12 | 13 | hlearn-allknn -k $K -r $@ -n "$hlearn_neighbors" -d "$hlearn_distances" 14 | 15 | echo "-------------------------------------" 16 | echo "num differences: " `diff "$hlearn_neighbors" "$1-neighbors" | wc -l` " / " `cat $1 | wc -l` 17 | diff "$hlearn_neighbors" "$1-neighbors" > /dev/null 18 | --------------------------------------------------------------------------------
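Both runtest.sh scripts end with the same contract: the final `diff` determines the exit code, so a test passes only when the produced neighbor files match the reference exactly. That makes them drivable from test/BashTests.hs as well; a hedged sketch reusing its `runBashTest` helper (assumes mlpack's `allknn` binary is on the PATH, as the mlpack script requires):

-- Run the hlearn-vs-mlpack comparison over the bundled datasets.
mlpackTests :: IO ()
mlpackTests = mapM_ runBashTest
    [ "./test/allknn-mlpack/runtest.sh " ++ dataset
    | dataset <-
        [ "./test/allknn-verify/dataset-10000x2.csv"
        , "./test/allknn-verify/dataset-10000x20.csv"
        ]
    ]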