├── Setup.lhs
├── .gitignore
├── Algorithms
│   └── MachineLearning
│       ├── Utilities.hs
│       ├── LinearClassification.hs
│       ├── BasisFunctions.hs
│       ├── LinearAlgebra.hs
│       ├── Tests
│       │   ├── Driver.hs
│       │   └── Data.hs
│       ├── Framework.hs
│       └── LinearRegression.hs
├── machine-learning.cabal
└── LICENSE

--------------------------------------------------------------------------------
/Setup.lhs:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env runhaskell
2 | 
3 | > import Distribution.Simple
4 | > main = defaultMain
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Cabal temporary stuff
2 | dist/
3 | 
4 | # Output from running the executable
5 | *.dat
6 | *.ps
--------------------------------------------------------------------------------
/Algorithms/MachineLearning/Utilities.hs:
--------------------------------------------------------------------------------
1 | -- | We commit the usual sin of lumping a load of useful functions with no clear
2 | -- home in a "utilities" module
3 | module Algorithms.MachineLearning.Utilities where
4 | 
5 | import Data.List
6 | import Data.Ord
7 | 
8 | import System.Random
9 | 
10 | 
11 | newtype K a = K { unK :: a }
12 | 
13 | instance Functor K where
14 |   fmap f (K x) = K (f x)
15 | 
16 | 
17 | square :: Num a => a -> a
18 | square x = x * x
19 | 
20 | singleton :: a -> [a]
21 | singleton x = [x]
22 | 
23 | void :: Monad m => m a -> m ()
24 | void ma = ma >> return ()
25 | 
26 | rationalToDouble :: Rational -> Double
27 | rationalToDouble = realToFrac
28 | 
29 | onLeft :: (a -> c) -> (a, b) -> (c, b)
30 | onLeft f (x, y) = (f x, y)
31 | 
32 | onRight :: (b -> c) -> (a, b) -> (a, c)
33 | onRight f (x, y) = (x, f y)
34 | 
35 | shuffle :: StdGen -> [a] -> [a]
36 | shuffle gen xs = map snd $ sortBy (comparing fst) (zip (randoms gen :: [Double]) xs)
37 | 
38 | chunk :: Int -> [a] -> [[a]]
39 | chunk _ [] = []
40 | chunk n xs = this : chunk n rest
41 |   where
42 |     (this, rest) = splitAt n xs
43 | 
44 | sample :: StdGen -> Int -> [a] -> [a]
45 | sample gen n xs = take n (shuffle gen xs)
46 | 
47 | eqWithin :: Double -> Double -> Double -> Bool
48 | eqWithin jitter left right = abs (left - right) < jitter
49 | 
50 | enumAsList :: (Enum a, Bounded a) => [a]
51 | enumAsList = enumFromTo minBound maxBound
52 | 
53 | enumSize :: (Enum a, Bounded a) => a -> Int
54 | enumSize what_enum = length (enumAsList `asTypeOf` [what_enum])
--------------------------------------------------------------------------------
/Algorithms/MachineLearning/LinearClassification.hs:
--------------------------------------------------------------------------------
1 | module Algorithms.MachineLearning.LinearClassification (
2 |     DiscriminantModel,
3 |     regressLinearClassificationModel,
4 |   ) where
5 | 
6 | import Algorithms.MachineLearning.Framework
7 | import Algorithms.MachineLearning.LinearAlgebra
8 | import Algorithms.MachineLearning.LinearRegression
9 | import Algorithms.MachineLearning.Utilities
10 | 
11 | import Data.Ord
12 | import Data.List
13 | import Data.Maybe
14 | 
15 | 
16 | data (Bounded classes, Enum classes) => DiscriminantModel input classes = DiscriminantModel {
17 |     dm_class_models :: AnyModel input (Vector Double)
18 |   }
19 | 
20 | instance (Vectorable input, Bounded classes, Enum classes) => Model (DiscriminantModel input classes) input classes where
21 |   predict model input = snd $ maximumBy (comparing fst) predictions
22 |     where
23 |       predictions = toList (predict (dm_class_models model) input) `zip` enumAsList
24 | 
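-- How classification by least squares works here: each class gets its own linear
-- model, trained against a 1-of-K coding of the class labels (see
-- classToCharacteristicVector below), and 'predict' picks the class whose
-- regressed output is largest. With K = 3 classes, for example, the second class
-- is encoded as the target vector [0, 1, 0].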
25 | 
26 | regressLinearClassificationModel :: (Vectorable input, Vectorable classes, Bounded classes, Enum classes, Eq classes)
27 |                                  => [input -> Double]     -- ^ Basis functions
28 |                                  -> DataSet input classes -- ^ Class mapping to use for training
29 |                                  -> DiscriminantModel input classes
30 | regressLinearClassificationModel basis_fns ds = DiscriminantModel { dm_class_models = class_models }
31 |   where
32 |     class_models = AnyModel $ regressLinearModel basis_fns (fmapDataSetTarget classToCharacteristicVector ds)
33 |     indexed_classes = enumAsList `zip` [0..]
34 |     classToCharacteristicVector the_class = fromList $ replicate index 0 ++ [1] ++ replicate (size - index - 1) 0
35 |       where
36 |         size = enumSize the_class
37 |         index = fromJust $ lookup the_class indexed_classes
--------------------------------------------------------------------------------
/Algorithms/MachineLearning/BasisFunctions.hs:
--------------------------------------------------------------------------------
1 | -- | Basis functions of various kinds, useful e.g. with the LinearRegression module
2 | module Algorithms.MachineLearning.BasisFunctions where
3 | 
4 | import Algorithms.MachineLearning.Framework
5 | import Algorithms.MachineLearning.LinearAlgebra
6 | import Algorithms.MachineLearning.Utilities
7 | 
8 | 
9 | -- | Basis function that is 1 everywhere
10 | constantBasis :: a -> Double
11 | constantBasis = const 1
12 | 
13 | -- | /Unnormalized/ 1D Gaussian, suitable for use as a basis function.
14 | gaussianBasis :: Mean     -- ^ Mean of the Gaussian
15 |               -> Variance -- ^ Variance of the Gaussian
16 |               -> Double   -- ^ Point on X axis to sample
17 |               -> Double
18 | gaussianBasis mean variance x = exp (negate $ (square (x - mean)) / (2 * variance))
19 | 
20 | -- | Family of gaussian basis functions with constant variance and the given means, with
21 | -- a constant basis function to capture the mean of the target variable.
22 | gaussianBasisFamily :: [Mean] -> Variance -> [Double -> Double]
23 | gaussianBasisFamily means variance = constantBasis : map (flip gaussianBasis variance) means
24 | 
25 | -- | /Unnormalized/ multi-dimensional Gaussian, suitable for use as a basis function.
26 | multivariateGaussianBasis :: Vector Mean     -- ^ Mean of the Gaussian
27 |                           -> Matrix Variance -- ^ Covariance matrix
28 |                           -> Vector Double   -- ^ Point to sample
29 |                           -> Double
30 | multivariateGaussianBasis mean covariance x = exp (negate $ (deviation <.> (inv covariance <> deviation)) / 2)
31 |   where deviation = x - mean
32 | 
33 | -- | Family of multi-dimensional gaussian basis functions with constant, isotropic variance and
34 | -- the given means, with a constant basis function to capture the mean of the target variable.
35 | multivariateIsotropicGaussianBasisFamily :: [Vector Mean] -> Variance -> [Vector Double -> Double]
36 | multivariateIsotropicGaussianBasisFamily means common_variance = constantBasis : map (flip multivariateGaussianBasis covariance) means
37 |   -- The covariance of an isotropic Gaussian with variance v is v * I;
38 |   -- 'multivariateGaussianBasis' inverts it itself, matching the 1D 'gaussianBasis'
39 |   where covariance = common_variance .* ident (dim (head means))
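
-- A sketch of typical use (the name 'someBases' is illustrative, not part of the
-- library): nine Gaussian bumps spread evenly over [-1, 1] with a common variance
-- of 0.1, plus the constant basis contributed by the family function itself:
--
-- > someBases :: [Double -> Double]
-- > someBases = gaussianBasisFamily (map rationalToDouble [-1,-0.75..1]) 0.1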
--------------------------------------------------------------------------------
/Algorithms/MachineLearning/LinearAlgebra.hs:
--------------------------------------------------------------------------------
1 | -- | Linear algebra used in the machine learning library: just HMatrix re-exports and some
2 | -- other useful functions I have built up.
3 | module Algorithms.MachineLearning.LinearAlgebra (
4 |     module Numeric.LinearAlgebra,
5 |     module Algorithms.MachineLearning.LinearAlgebra
6 |   ) where
7 | 
8 | import Numeric.LinearAlgebra
9 | 
10 | 
11 | -- | Given the input functions:
12 | --
13 | -- @[f_1, f_2, ..., f_n]@
14 | --
15 | -- and input matrix:
16 | --
17 | -- @
18 | -- [ r_1
19 | -- , r_2
20 | -- , ...
21 | -- , r_m ]
22 | -- @
23 | --
24 | -- returns the output matrix:
25 | --
26 | -- @
27 | -- [ f_1(r_1), f_2(r_1), ..., f_n(r_1)
28 | -- , f_1(r_2), f_2(r_2), ..., f_n(r_2)
29 | -- , ...
30 | -- , f_1(r_m), f_2(r_m), ..., f_n(r_m) ]
31 | -- @
32 | applyMatrix :: [Vector Double -> Double] -> Matrix Double -> Matrix Double
33 | applyMatrix fns inputs = fromLists (map (\r -> map ($ r) fns) rs)
34 |   where
35 |     rs = toRows inputs
36 | 
37 | -- | Given the input functions:
38 | --
39 | -- @[f_1, f_2, ..., f_n]@
40 | --
41 | -- and input:
42 | --
43 | -- @x@
44 | --
45 | -- returns the output vector:
46 | --
47 | -- @
48 | -- [ f_1(x)
49 | -- , f_2(x)
50 | -- , ...
51 | -- , f_n(x) ]
52 | -- @
53 | applyVector :: [inputs -> Double] -> inputs -> Vector Double
54 | applyVector fns inputs = fromList $ map ($ inputs) fns
55 | 
56 | -- | Summation of the elements in a vector.
57 | vectorSum :: Element a => Vector a -> a
58 | vectorSum v = constant 1 (dim v) <.> v
59 | 
60 | -- | The sum of the squares of the elements of the vector
61 | vectorSumSquares :: Element a => Vector a -> a
62 | vectorSumSquares v = v <.> v
63 | 
64 | -- | Mean of the elements in a vector.
65 | vectorMean :: Element a => Vector a -> a
66 | vectorMean v = (vectorSum v) / fromIntegral (dim v)
67 | 
68 | -- | Column-wise summation of a matrix.
69 | sumColumns :: Element a => Matrix a -> Vector a
70 | sumColumns m = constant 1 (rows m) <> m
71 | 
72 | -- | Create a constant matrix of the given dimension, analogously to 'constant'.
73 | constantM :: Element a => a -> Int -> Int -> Matrix a
74 | constantM elt row_count col_count = reshape col_count (constant elt (row_count * col_count)) -- 'reshape' takes the desired number of columns
75 | 
76 | matrixToVector :: Element a => Matrix a -> Vector a
77 | matrixToVector m
78 |   | rows m == 1 -- Row vector
79 |     || cols m == 1 -- Column vector
80 |   = flatten m
81 |   | otherwise
82 |   = error "matrixToVector: matrix is neither a row nor a column vector"
83 | 
84 | matrixTrace :: Element a => Matrix a -> a
85 | matrixTrace = vectorSum .
takeDiag -------------------------------------------------------------------------------- /machine-learning.cabal: -------------------------------------------------------------------------------- 1 | Name: machine-learning 2 | Version: 0.1 3 | Cabal-Version: >= 1.2 4 | Category: Algorithms 5 | Synopsis: Machine learning algorithms 6 | Description: Algorithms for machine learning and pattern recognition, based on the book "Pattern Recognition and Machine Learning" by Christopher Bishop 7 | -- Because of HMatrix and HTam, I have to release this package under GPL 8 | License: GPL 9 | License-File: LICENSE 10 | Author: Max Bolingbroke 11 | Maintainer: batterseapower@hotmail.com 12 | Homepage: http://bsp.lighthouseapp.com/projects/15641-hs-machine-learning/overview 13 | Build-Type: Simple 14 | 15 | Flag Tests 16 | Description: Enable building the tests 17 | Default: False 18 | 19 | Flag SplitBase 20 | Description: Choose the new smaller, split-up base package 21 | Default: True 22 | 23 | 24 | Library 25 | Exposed-Modules: Algorithms.MachineLearning.BasisFunctions 26 | Algorithms.MachineLearning.Framework 27 | Algorithms.MachineLearning.LinearClassification 28 | Algorithms.MachineLearning.LinearRegression 29 | 30 | Other-Modules: Algorithms.MachineLearning.LinearAlgebra 31 | Algorithms.MachineLearning.Utilities 32 | 33 | Build-Depends: hmatrix >= 0.4.0.0, random >= 1.0.0.0 34 | if flag(splitBase) 35 | Build-Depends: base >= 3 36 | else 37 | Build-Depends: base < 3 38 | 39 | Extensions: PatternSignatures 40 | MultiParamTypeClasses 41 | FunctionalDependencies 42 | TypeSynonymInstances 43 | FlexibleInstances 44 | FlexibleContexts 45 | ExistentialQuantification 46 | ScopedTypeVariables 47 | Ghc-Options: -O2 -fvia-C -Wall 48 | 49 | Executable machine-learning-tests 50 | Main-Is: Algorithms/MachineLearning/Tests/Driver.hs 51 | 52 | -- I just need HTam for the GNUPlot module. 
Probably I should just write my 53 | -- own GNUPlot interface module instead of depending on such a weird package :-) 54 | Build-Depends: hmatrix >= 0.4.0.0, HTam, random >= 1.0.0.0 55 | if flag(splitBase) 56 | Build-Depends: base >= 3, process >= 1.0.0.0 57 | else 58 | Build-Depends: base < 3 59 | 60 | Extensions: PatternSignatures 61 | MultiParamTypeClasses 62 | FunctionalDependencies 63 | TypeSynonymInstances 64 | FlexibleInstances 65 | FlexibleContexts 66 | ExistentialQuantification 67 | ScopedTypeVariables 68 | Ghc-Options: -O2 -fvia-C -Wall 69 | 70 | if !flag(tests) 71 | Buildable: False -------------------------------------------------------------------------------- /Algorithms/MachineLearning/Tests/Driver.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Algorithms.MachineLearning.BasisFunctions 4 | import Algorithms.MachineLearning.Framework 5 | import Algorithms.MachineLearning.LinearAlgebra 6 | import Algorithms.MachineLearning.LinearClassification 7 | import Algorithms.MachineLearning.LinearRegression 8 | import Algorithms.MachineLearning.Tests.Data 9 | import Algorithms.MachineLearning.Utilities 10 | 11 | import GNUPlot 12 | 13 | import Data.List 14 | import Data.Ord 15 | 16 | import System.Cmd 17 | import System.Random 18 | 19 | 20 | basisFunctions :: [Double -> Double] 21 | basisFunctions 22 | -- = [constantBasis] 23 | = gaussianBasisFamily (map rationalToDouble [-1,-0.5..1]) 0.09 24 | -- = gaussianBasisFamily (map rationalToDouble [-1,-0.9..1]) 0.04 25 | 26 | basisFunctions2D :: [(Double, Double) -> Double] 27 | basisFunctions2D 28 | = [\(x, _) -> x, \(_, y) -> y] 29 | -- = map ((. \(x, y) -> fromList [x, y])) $ multivariateIsotropicGaussianBasisFamily [fromList [x, y] | x <- range, y <- range] 0.1 30 | where range = map rationalToDouble [-1,-0.8..1] 31 | 32 | sampleFunction :: (Double -> a) -> [(Double, a)] 33 | sampleFunction f = map (\(x :: Rational) -> let x' = rationalToDouble x in (x', f x')) 34 | [0,0.01..1.0] 35 | 36 | sampleFunction2D :: ((Double, Double) -> a) -> [((Double, Double), a)] 37 | sampleFunction2D f = map (\((x, y) :: (Rational, Rational)) -> let x' = rationalToDouble x; y' = rationalToDouble y in ((x', y'), f (x', y'))) 38 | [(x, y) | x <- [-1.0,-0.99..1.0], y <- [-1.0,-0.99..1.0]] 39 | 40 | evaluate :: (Vectorable input, Vectorable target, Model model input target, MetricSpace target) => model -> DataSet input target -> IO () 41 | evaluate model true_data = do 42 | putStrLn $ "Target Raw Means = " ++ show (map vectorMean (toColumns $ ds_targets true_data)) 43 | putStrLn $ "Error = " ++ show (modelSumSquaredError model true_data) 44 | 45 | plot :: [[(Double, Target)]] -> IO () 46 | plot sampless = do 47 | plotPaths [EPS "output.ps"] (map (sortBy (comparing fst)) sampless) 48 | void $ rawSystem "open" ["output.ps"] 49 | 50 | plotClasses :: [((Double, Double), Class)] -> IO () 51 | plotClasses classess = do 52 | let -- Utilize a hack to obtain color output :-) 53 | color_eps filename = [EPS filename, ColorBox (Just [";set terminal postscript enhanced color"])] 54 | red_cross_style = (Points, CustomStyle [PointType 1]) 55 | blue_circle_style = (Points, CustomStyle [PointType 6]) -- These are actually /green/ circles, but who's counting? 
56 | generations = [ (red_cross_style, [position | (position, RedCross) <- classess]) 57 | , (blue_circle_style, [position | (position, BlueCircle) <- classess]) ] 58 | plot2dMultiGen (color_eps "output.ps") generations 59 | void $ rawSystem "open" ["output.ps"] 60 | 61 | 62 | linearModelTest :: IO () 63 | linearModelTest = do 64 | -- Do the regression 65 | let used_data = sinDataSet 66 | model = regressLinearModel basisFunctions used_data 67 | 68 | -- Show some model statistics 69 | evaluate model used_data 70 | putStrLn $ "Model For Target:\n" ++ show model 71 | 72 | -- Show some graphical information about the model 73 | plot [dataSetToSampleList used_data, sampleFunction $ predict model] 74 | 75 | bayesianLinearModelTest :: IO () 76 | bayesianLinearModelTest = do 77 | gen <- newStdGen 78 | let used_data = sampleDataSet gen 10 sinDataSet 79 | (model, variance_model) = regressBayesianLinearModel 1 (1 / 0.09) basisFunctions used_data 80 | 81 | -- Show some model statistics 82 | evaluate model used_data 83 | putStrLn $ "Model For Target:\n" ++ show model 84 | putStrLn $ "Model For Variance:\n" ++ show variance_model 85 | 86 | -- Show some graphical information about the model 87 | plot [dataSetToSampleList used_data, sampleFunction $ predict model, sampleFunction $ (sqrt . predict variance_model)] 88 | 89 | emBayesianLinearModelTest :: IO () 90 | emBayesianLinearModelTest = do 91 | gen <- newStdGen 92 | let used_data = sampleDataSet gen 10 sinDataSet 93 | (model, variance_model, gamma) = regressEMBayesianLinearModel 1 (1 / 0.09) basisFunctions used_data 94 | 95 | -- Show some model statistics 96 | evaluate model used_data 97 | putStrLn $ "Model For Target:\n" ++ show model 98 | putStrLn $ "Model For Variance:\n" ++ show variance_model 99 | putStrLn $ "Gamma = " ++ show gamma 100 | 101 | -- Show some graphical information about the model 102 | plot [dataSetToSampleList used_data, sampleFunction $ predict model, sampleFunction $ (sqrt . predict variance_model)] 103 | 104 | linearClassificationModelTest :: IO () 105 | linearClassificationModelTest = do 106 | let used_data = classificationDataSet 107 | model = regressLinearClassificationModel basisFunctions2D used_data 108 | 109 | -- Show some model statistics 110 | evaluate model used_data 111 | 112 | -- Show some graphical information about the model 113 | plotClasses (dataSetToSampleList classificationDataSet) 114 | plotClasses (sampleFunction2D $ predict model) 115 | 116 | main :: IO () 117 | main = linearClassificationModelTest -------------------------------------------------------------------------------- /Algorithms/MachineLearning/Framework.hs: -------------------------------------------------------------------------------- 1 | -- | The "framework" provides all the core classes and types used ubiquitously by 2 | -- the machine learning algorithms. 3 | module Algorithms.MachineLearning.Framework where 4 | 5 | import Algorithms.MachineLearning.LinearAlgebra 6 | import Algorithms.MachineLearning.Utilities 7 | 8 | import Numeric.LinearAlgebra 9 | 10 | import Data.List 11 | 12 | import System.Random 13 | 14 | 15 | -- 16 | -- Ubiquitous synonyms for documentation purposes 17 | -- 18 | 19 | -- | The target is the variable you wish to predict with your machine learning algorithm. 20 | type Target = Double 21 | 22 | type Weight = Double 23 | 24 | -- | Commonly called the "average" of a set of data. 25 | type Mean = Double 26 | 27 | -- | Variance is the mean squared deviation from the mean. Must be positive. 
28 | type Variance = Double 29 | 30 | -- | Precision is the inverse of variance. Must be positive. 31 | type Precision = Double 32 | 33 | -- | A positive coefficient indicating how strongly regularization should be applied. A good 34 | -- choice for this parameter might be your belief about the variance of the inherent noise 35 | -- in the samples (1/beta) divided by your belief about the variance of the weights that 36 | -- should be learnt by the model (1/alpha). 37 | -- 38 | -- Commonly written as lambda. 39 | -- 40 | -- See also equation 3.55 and 3.28 in Bishop. 41 | type RegularizationCoefficient = Double 42 | 43 | -- | A positive number that indicates the number of fully determined parameters in a learnt 44 | -- model. If all your parameters are determined, it will be equal to the number of parameters 45 | -- available, and if your data did not support any parameters it will be simply 0. 46 | -- 47 | -- See also section 3.5.3 of Bishop. 48 | type EffectiveNumberOfParameters = Double 49 | 50 | -- 51 | -- Injections to and from vectors 52 | -- 53 | class Vectorable a where 54 | toVector :: a -> Vector Double 55 | fromVector :: Vector Double -> a 56 | 57 | instance Vectorable Double where 58 | toVector = flip constant 1 59 | fromVector = flip (@>) 0 60 | 61 | instance Vectorable (Double, Double) where 62 | toVector (x, y) = 2 |> [x, y] 63 | fromVector vec = (vec @> 0, vec @> 1) 64 | 65 | instance Vectorable (Vector Double) where 66 | toVector = id 67 | fromVector = id 68 | 69 | -- 70 | -- Labelled data set 71 | -- 72 | 73 | data DataSet input target = DataSet { 74 | ds_inputs :: Matrix Double, -- One row per sample, one column per input variable 75 | ds_targets :: Matrix Target -- One row per sample, one column per target variable 76 | } 77 | 78 | fmapDataSetInput :: (Vectorable input, Vectorable input', Vectorable target) => (input -> input') -> DataSet input target -> DataSet input' target 79 | fmapDataSetInput f = dataSetFromSampleList . fmap (onLeft f) . dataSetToSampleList 80 | 81 | fmapDataSetTarget :: (Vectorable input, Vectorable target, Vectorable target') => (target -> target') -> DataSet input target -> DataSet input target' 82 | fmapDataSetTarget f = dataSetFromSampleList . fmap (onRight f) . dataSetToSampleList 83 | 84 | dataSetFromSampleList :: (Vectorable input, Vectorable target) => [(input, target)] -> DataSet input target 85 | dataSetFromSampleList elts 86 | | null elts = error "dataSetFromSampleList: no data supplied" 87 | | otherwise = DataSet { 88 | ds_inputs = fromRows $ map (toVector . fst) elts, 89 | ds_targets = fromRows $ map (toVector . 
snd) elts
90 |   }
91 | 
92 | dataSetToSampleList :: (Vectorable input, Vectorable target) => DataSet input target -> [(input, target)]
93 | dataSetToSampleList ds = zip (dataSetInputs ds) (dataSetTargets ds)
94 | 
95 | dataSetInputs :: Vectorable input => DataSet input target -> [input]
96 | dataSetInputs ds = map fromVector $ toRows $ ds_inputs ds
97 | 
98 | dataSetTargets :: Vectorable target => DataSet input target -> [target]
99 | dataSetTargets ds = map fromVector $ toRows $ ds_targets ds
100 | 
101 | dataSetInputLength :: DataSet input target -> Int
102 | dataSetInputLength ds = cols (ds_inputs ds)
103 | 
104 | dataSetSize :: DataSet input target -> Int
105 | dataSetSize ds = rows (ds_inputs ds)
106 | 
107 | binDataSet :: StdGen -> Int -> DataSet input target -> [DataSet input target]
108 | binDataSet gen bins = transformDataSetAsVectors binDataSet'
109 |   where
110 |     binDataSet' ds = map dataSetFromSampleList $ chunk bin_size shuffled_samples
111 |       where
112 |         shuffled_samples = shuffle gen (dataSetToSampleList ds)
113 |         bin_size = ceiling $ (fromIntegral $ dataSetSize ds :: Double) / (fromIntegral bins)
114 | 
115 | sampleDataSet :: StdGen -> Int -> DataSet input target -> DataSet input target
116 | sampleDataSet gen n = unK . transformDataSetAsVectors (K . dataSetFromSampleList . sample gen n . dataSetToSampleList)
117 | 
118 | transformDataSetAsVectors :: Functor f => (DataSet (Vector Double) (Vector Double) -> f (DataSet (Vector Double) (Vector Double))) -> DataSet input target -> f (DataSet input target)
119 | transformDataSetAsVectors transform input = fmap castDataSet (transform (castDataSet input))
120 |   where
121 |     castDataSet :: DataSet input1 target1 -> DataSet input2 target2
122 |     castDataSet ds = DataSet {
123 |         ds_inputs = ds_inputs ds,
124 |         ds_targets = ds_targets ds
125 |       }
126 | 
127 | --
128 | -- Metric spaces
129 | --
130 | 
131 | class MetricSpace a where
132 |   distance :: a -> a -> Double
133 | 
134 | instance MetricSpace Double where
135 |   distance x y = abs (x - y)
136 | 
137 | instance MetricSpace (Vector Double) where
138 |   -- Take the square root so that 'distance' is a true metric; squaring it in
139 |   -- 'modelSumSquaredError' then recovers the sum of squared errors
140 |   distance x y = sqrt (vectorSumSquares (x - y))
141 | 
142 | --
143 | -- Models
144 | --
145 | 
146 | class Model model input target | model -> input target where
147 |   predict :: model -> input -> target
148 | 
149 | data AnyModel input output = forall model. Model model input output => AnyModel { theModel :: model }
150 | 
151 | instance Model (AnyModel input output) input output where
152 |   predict (AnyModel model) = predict model
153 | 
154 | modelSumSquaredError :: (Model model input target, MetricSpace target, Vectorable input, Vectorable target) => model -> DataSet input target -> Double
155 | modelSumSquaredError model ds = sum [sample_error * sample_error | sample_error <- sample_errors]
156 |   where
157 |     sample_errors = zipWith distance (dataSetTargets ds) (map (predict model) (dataSetInputs ds))
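
-- An illustrative helper, not part of the original API: the mean squared error
-- per sample, defined on top of 'modelSumSquaredError'.
modelMeanSquaredError :: (Model model input target, MetricSpace target, Vectorable input, Vectorable target) => model -> DataSet input target -> Double
modelMeanSquaredError model ds = modelSumSquaredError model ds / fromIntegral (dataSetSize ds)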
--------------------------------------------------------------------------------
/Algorithms/MachineLearning/Tests/Data.hs:
--------------------------------------------------------------------------------
1 | -- | Sample data sets, used to drive the machine learning library during development.
2 | module Algorithms.MachineLearning.Tests.Data where
3 | 
4 | import Algorithms.MachineLearning.Framework
5 | import Algorithms.MachineLearning.LinearAlgebra
6 | 
7 | 
8 | -- | Dataset of noisy samples from a sine function:
9 | --
10 | -- @
11 | -- input = U(0, 1)
12 | -- target = sin(2πx) + N(0, 0.09)
13 | -- @
14 | --
15 | -- Source: http://research.microsoft.com/~cmbishop/PRML/webdatasets/curvefitting.txt
16 | sinDataSet :: DataSet Double Double
17 | sinDataSet = dataSetFromSampleList [
18 |     (0.000000, 0.349486),
19 |     (0.111111, 0.830839),
20 |     (0.222222, 1.007332),
21 |     (0.333333, 0.971507),
22 |     (0.444444, 0.133066),
23 |     (0.555556, 0.166823),
24 |     (0.666667, -0.848307),
25 |     (0.777778, -0.445686),
26 |     (0.888889, -0.563567),
27 |     (1.000000, 0.261502)
28 |   ]
29 | 
30 | -- | Classes for use in demonstrating point classification
31 | data Class = RedCross
32 |            | BlueCircle
33 |            deriving (Bounded, Enum, Eq)
34 | 
35 | -- This instance is only used to encode the Class within a DataSet: after that we
36 | -- (typically) use a custom 1-of-K coding in LinearClassification. Does this point
37 | -- to a refactoring not to store inputs in a matrix in DataSet?
38 | instance Vectorable Class where
39 |   toVector x = constant (fromIntegral $ fromEnum x) 1
40 |   fromVector v = toEnum (round (v @> 0))
41 | 
42 | instance MetricSpace Class where
43 |   distance x y = if x == y then 0 else 1
44 | 
45 | -- | Dataset of point classifications, with the Gaussian mixture configured as follows:
46 | --
47 | -- @
48 | -- mix.priors = [0.5 0.25 0.25];
49 | -- mix.centres = [0 -0.1; 1 1; 1 -1];
50 | -- mix.covars(:,:,1) = [0.625 -0.2165; -0.2165 0.875];
51 | -- mix.covars(:,:,2) = [0.2241 -0.1368; -0.1368 0.9759];
52 | -- mix.covars(:,:,3) = [0.2375 0.1516; 0.1516 0.4125];
53 | -- @
54 | --
55 | -- Source: http://research.microsoft.com/~cmbishop/prml/webdatasets/classification.txt
56 | classificationDataSet :: DataSet (Double, Double) Class
57 | classificationDataSet = dataSetFromSampleList [
58 |     ((1.208985, 0.421448), RedCross),
59 |     ((0.504542, -0.285730), BlueCircle),
60 |     ((0.630568, 1.054712), RedCross),
61 |     ((1.056364, 0.601873), RedCross),
62 |     ((1.095326, -1.447579), BlueCircle),
63 |     ((-0.210165, 0.000284), BlueCircle),
64 |     ((-0.367151, -1.255189), BlueCircle),
65 |     ((0.868013, -1.063465), RedCross),
66 |     ((1.704441, -0.644833), RedCross),
67 |     ((0.565619, -1.637858), BlueCircle),
68 |     ((0.598389, -1.477808), RedCross),
69 |     ((0.580927, -0.783898), BlueCircle),
70 |     ((1.183283, -1.797936), RedCross),
71 |     ((0.331843, -1.869486), RedCross),
72 |     ((-0.051195, 0.989475), BlueCircle),
73 |     ((2.427090, 0.173557), RedCross),
74 |     ((1.603778, -0.030691), BlueCircle),
75 |     ((1.286206, -1.079916), RedCross),
76 |     ((-1.243951, 1.005355), BlueCircle),
77 |     ((1.181748, 1.523744), RedCross),
78 |     ((0.896222, 1.899568), RedCross),
79 |     ((-0.366207, -0.664987), BlueCircle),
80 |     ((-0.078800, 1.007368), BlueCircle),
81 |     ((-1.351435, 1.766786), BlueCircle),
82 |     ((-0.220423, -0.442405), BlueCircle),
83 |     ((0.836253, -1.927526), RedCross),
84 |     ((0.039899, -1.435842), RedCross),
85 |     ((0.256755, 0.946722), RedCross),
86 |     ((0.974836, -0.944967), RedCross),
87 |     ((0.705256, -2.618644), RedCross),
88 |     ((0.738188, -1.666242), RedCross),
89 |     ((1.245931, -2.200826), RedCross),
90 |     ((0.297604, 0.159463), BlueCircle),
91 |     ((-2.210680, 1.195815), BlueCircle),
92 |     ((-0.872624, -0.131252), BlueCircle),
93 |     ((1.112762, -0.653777), RedCross),
94 |     ((1.123989, -1.347470), RedCross),
95 |     ((0.750833, 0.811870), RedCross),
96 | 
((-0.183497, 1.416116), BlueCircle), 97 | ((0.287582, -1.342512), RedCross), 98 | ((1.092719, 1.380559), RedCross), 99 | ((0.719502, 1.594624), RedCross), 100 | ((-1.016254, 0.651607), BlueCircle), 101 | ((0.379677, 2.802498), RedCross), 102 | ((0.150675, 0.474679), BlueCircle), 103 | ((-0.116477, 0.437483), BlueCircle), 104 | ((1.122528, 0.698541), RedCross), 105 | ((0.953551, 1.088368), RedCross), 106 | ((-0.000228, 0.347187), BlueCircle), 107 | ((0.505024, 0.455407), BlueCircle), 108 | ((0.113753, 0.559572), BlueCircle), 109 | ((-0.677993, 0.322716), BlueCircle), 110 | ((1.114811, -0.735813), RedCross), 111 | ((0.344114, -1.770137), RedCross), 112 | ((0.684242, -0.636027), BlueCircle), 113 | ((-0.684629, -0.300568), BlueCircle), 114 | ((-0.362677, -0.669101), BlueCircle), 115 | ((0.604984, -1.558581), RedCross), 116 | ((0.514202, -0.225827), RedCross), 117 | ((0.227014, -1.579346), BlueCircle), 118 | ((1.044068, -1.491114), RedCross), 119 | ((0.314855, -2.535762), BlueCircle), 120 | ((1.187904, -1.367278), RedCross), 121 | ((0.517132, 1.375811), RedCross), 122 | ((1.244285, -0.764164), RedCross), 123 | ((-0.831841, 1.728708), BlueCircle), 124 | ((1.719616, -2.491282), BlueCircle), 125 | ((0.594216, 1.137571), BlueCircle), 126 | ((0.939919, -0.474988), RedCross), 127 | ((-0.918736, -0.748474), BlueCircle), 128 | ((0.913760, -1.194336), RedCross), 129 | ((0.893221, -1.569459), RedCross), 130 | ((0.653152, 0.510498), RedCross), 131 | ((0.766890, -1.577565), RedCross), 132 | ((0.868315, -1.966740), BlueCircle), 133 | ((0.874218, 0.514959), BlueCircle), 134 | ((-0.559543, 1.749552), BlueCircle), 135 | ((1.526669, -1.797734), BlueCircle), 136 | ((1.843439, -0.363161), RedCross), 137 | ((1.163746, 2.062245), RedCross), 138 | ((0.565749, -2.432301), BlueCircle), 139 | ((1.016715, 2.878822), RedCross), 140 | ((1.433979, -1.944960), BlueCircle), 141 | ((-0.510225, 0.295742), BlueCircle), 142 | ((-0.385261, 0.278145), BlueCircle), 143 | ((1.042889, -0.564351), RedCross), 144 | ((-0.607265, 1.885851), BlueCircle), 145 | ((-0.355286, -1.813131), BlueCircle), 146 | ((-0.790644, -0.790761), BlueCircle), 147 | ((1.372382, 0.879619), RedCross), 148 | ((1.133019, -0.300956), RedCross), 149 | ((1.395009, -1.006842), RedCross), 150 | ((0.887843, 0.222319), BlueCircle), 151 | ((1.484690, 0.095074), RedCross), 152 | ((1.268061, 1.832532), RedCross), 153 | ((0.124568, 0.910824), BlueCircle), 154 | ((1.061504, -0.768175), BlueCircle), 155 | ((0.298551, 2.573175), RedCross), 156 | ((0.241114, -0.613155), RedCross), 157 | ((-0.423781, -1.524901), BlueCircle), 158 | ((0.528691, -0.939526), RedCross), 159 | ((1.601252, 1.791658), RedCross), 160 | ((0.793609, 0.812783), BlueCircle), 161 | ((0.327097, 0.326998), RedCross), 162 | ((1.131868, -0.985696), BlueCircle), 163 | ((1.273154, 1.656441), RedCross), 164 | ((-0.816691, 0.961580), BlueCircle), 165 | ((0.669064, 1.162614), RedCross), 166 | ((-0.453759, -1.146883), BlueCircle), 167 | ((2.055105, 0.025811), RedCross), 168 | ((0.463119, -0.813294), BlueCircle), 169 | ((0.802392, -0.140807), BlueCircle), 170 | ((-0.730255, -0.145175), BlueCircle), 171 | ((0.569256, 0.567628), BlueCircle), 172 | ((0.486947, 1.130519), RedCross), 173 | ((1.793588, -1.426926), RedCross), 174 | ((1.178831, -0.581314), BlueCircle), 175 | ((0.480055, 1.257981), RedCross), 176 | ((0.683732, 0.190071), BlueCircle), 177 | ((-0.119082, -0.004020), BlueCircle), 178 | ((-1.251554, -0.176027), BlueCircle), 179 | ((1.094741, -1.099305), RedCross), 180 | ((-0.238250, -1.277484), BlueCircle), 181 | 
((-0.661556, 1.327722), BlueCircle), 182 | ((1.442837, 1.241720), RedCross), 183 | ((1.202320, 0.489702), RedCross), 184 | ((0.932890, 0.296430), RedCross), 185 | ((0.665568, -1.314006), RedCross), 186 | ((-0.058993, 1.322294), BlueCircle), 187 | ((0.209525, -1.006357), RedCross), 188 | ((1.023340, 0.219375), RedCross), 189 | ((1.324444, 0.446567), BlueCircle), 190 | ((1.453910, -1.151325), RedCross), 191 | ((0.616303, 0.974796), RedCross), 192 | ((1.492010, -0.885984), RedCross), 193 | ((1.738658, 0.686807), BlueCircle), 194 | ((0.900582, -0.280724), RedCross), 195 | ((0.961914, -0.053991), BlueCircle), 196 | ((1.819706, -0.953273), BlueCircle), 197 | ((1.581289, -0.340552), RedCross), 198 | ((0.520837, -0.680639), BlueCircle), 199 | ((1.433771, -0.914798), RedCross), 200 | ((0.611594, -1.691685), RedCross), 201 | ((1.591513, -0.978986), BlueCircle), 202 | ((1.282094, 0.113769), RedCross), 203 | ((0.985715, 0.275551), RedCross), 204 | ((-1.805143, 2.628696), BlueCircle), 205 | ((1.473100, -0.241372), RedCross), 206 | ((-0.242212, -1.040151), BlueCircle), 207 | ((1.175525, -1.662026), RedCross), 208 | ((0.696040, 0.154387), RedCross), 209 | ((1.457713, 1.608681), RedCross), 210 | ((0.883215, 1.330538), RedCross), 211 | ((-0.681209, 0.622394), BlueCircle), 212 | ((-0.355082, 0.432941), BlueCircle), 213 | ((0.633011, -1.194431), RedCross), 214 | ((0.782723, 1.060008), BlueCircle), 215 | ((0.670180, -0.766999), BlueCircle), 216 | ((-0.047154, 0.698693), BlueCircle), 217 | ((0.287385, -1.097756), RedCross), 218 | ((0.069561, 1.632585), BlueCircle), 219 | ((1.013230, 1.111551), RedCross), 220 | ((0.639065, -0.697237), RedCross), 221 | ((1.174621, 2.240022), BlueCircle), 222 | ((1.322020, 0.040277), BlueCircle), 223 | ((0.019127, 0.105667), BlueCircle), 224 | ((0.584584, 1.101914), RedCross), 225 | ((1.157265, -0.665947), RedCross), 226 | ((1.565230, -0.840790), RedCross), 227 | ((1.759315, 0.963703), BlueCircle), 228 | ((1.687068, -1.086466), RedCross), 229 | ((0.578314, -0.340961), BlueCircle), 230 | ((0.118925, -1.487694), BlueCircle), 231 | ((0.471201, 0.330872), BlueCircle), 232 | ((-0.268209, -0.353477), RedCross), 233 | ((1.625390, -1.718798), RedCross), 234 | ((1.117791, 2.752549), RedCross), 235 | ((-0.194552, -0.752687), BlueCircle), 236 | ((0.769548, -2.066152), RedCross), 237 | ((0.186062, 0.022072), BlueCircle), 238 | ((1.771337, -0.393550), RedCross), 239 | ((-1.300597, 0.962803), BlueCircle), 240 | ((0.708730, -1.013371), RedCross), 241 | ((-0.624235, -0.892995), BlueCircle), 242 | ((0.377055, -1.296098), RedCross), 243 | ((0.804404, -0.856253), BlueCircle), 244 | ((1.359887, -0.974291), RedCross), 245 | ((-0.115505, 0.228439), BlueCircle), 246 | ((0.913645, -0.344936), BlueCircle), 247 | ((0.318875, -0.886290), BlueCircle), 248 | ((0.822157, 0.102548), RedCross), 249 | ((-0.281208, 1.302572), BlueCircle), 250 | ((0.044639, -1.107980), BlueCircle), 251 | ((-0.029205, -2.033973), RedCross), 252 | ((0.879914, -2.000582), BlueCircle), 253 | ((0.601936, -0.503923), RedCross), 254 | ((-0.490114, -0.841122), BlueCircle), 255 | ((1.847075, 2.362322), RedCross), 256 | ((-0.279703, 0.753196), BlueCircle), 257 | ((1.953357, -0.746632), RedCross) 258 | ] -------------------------------------------------------------------------------- /Algorithms/MachineLearning/LinearRegression.hs: -------------------------------------------------------------------------------- 1 | -- | Linear regression models, as discussed in chapter 3 of Bishop. 
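-- (Specifically: least-squares fitting from section 3.1, Bayesian regression from
-- section 3.3, and the evidence approximation from section 3.5.)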
2 | module Algorithms.MachineLearning.LinearRegression (
3 |     LinearModel, BayesianVarianceModel,
4 |     regressLinearModel, regressRegularizedLinearModel, regressBayesianLinearModel,
5 |     regressEMBayesianLinearModel, regressFullyDeterminedEMBayesianLinearModel
6 |   ) where
7 | 
8 | import Algorithms.MachineLearning.Framework
9 | import Algorithms.MachineLearning.LinearAlgebra
10 | import Algorithms.MachineLearning.Utilities
11 | 
12 | 
13 | data LinearModel input target = LinearModel {
14 |     lm_basis_fns :: [input -> Double],
15 |     lm_weights   :: Matrix Weight -- One column per target variable,
16 |                                   -- one row per basis function output
17 |   }
18 | 
19 | instance Show (LinearModel input target) where
20 |   show model = "Weights: " ++ show (lm_weights model)
21 | 
22 | instance (Vectorable input, Vectorable target) => Model (LinearModel input target) input target where
23 |   predict model input = fromVector $ (trans $ lm_weights model) <> phi_app_x
24 |     where
25 |       phi_app_x = applyVector (lm_basis_fns model) input
26 | 
27 | 
28 | data BayesianVarianceModel input = BayesianVarianceModel {
29 |     bvm_basis_fns :: [input -> Double],
30 |     bvm_inv_hessian :: Matrix Weight, -- Equivalent to the weight distribution covariance matrix
31 |     bvm_beta :: Precision
32 |   }
33 | 
34 | instance Show (BayesianVarianceModel input) where
35 |   show model = "Inverse Hessian: " ++ show (bvm_inv_hessian model) ++ "\n" ++
36 |                "Beta: " ++ show (bvm_beta model)
37 | 
38 | instance (Vectorable input) => Model (BayesianVarianceModel input) input Variance where
39 |   predict model input = recip (bvm_beta model) + (phi_app_x <> bvm_inv_hessian model) <.> phi_app_x
40 |     where
41 |       phi_app_x = applyVector (bvm_basis_fns model) input
42 | 
43 | 
44 | regressDesignMatrix :: (Vectorable input) => [input -> Double] -> Matrix Double -> Matrix Double
45 | regressDesignMatrix basis_fns inputs
46 |   = applyMatrix (map (. fromVector) basis_fns) inputs -- One row per sample, one column per basis function
47 | 
48 | -- | Regularized pseudo-inverse of a matrix, with regularization coefficient lambda.
49 | regularizedPinv :: RegularizationCoefficient -> Matrix Double -> Matrix Double
50 | regularizedPinv lambda phi = regularizedPrePinv lambda 1 phi <> trans phi
51 | 
52 | -- | Just the left portion of the formula for the pseudo-inverse, with coefficients alpha and beta, i.e.:
53 | --
54 | -- > (alpha * _I_ + beta * _phi_ ^ T * _phi_) ^ -1
55 | regularizedPrePinv :: Precision -> Precision -> Matrix Double -> Matrix Double
56 | regularizedPrePinv alpha beta phi = inv $ (alpha .* (ident (cols phi))) + (beta .* (trans phi <> phi))
57 | 
58 | 
59 | -- | Regress a basic linear model with no regularization at all onto the given data using the
60 | -- supplied basis functions.
61 | --
62 | -- The resulting model is likely to suffer from overfitting, and may not be well defined if the basis
63 | -- functions are close to collinear.
64 | --
65 | -- However, the model will be the optimal model for the data given the basis in least-squares terms. It
66 | -- is also very quick to find, since there is a closed form solution.
67 | --
68 | -- Equation 3.15 in Bishop.
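--
-- A usage sketch, mirroring the test driver (@ds@ stands for any @DataSet Double Double@):
--
-- > model = regressLinearModel (gaussianBasisFamily (map rationalToDouble [-1,-0.5..1]) 0.09) ds
-- > predict model 0.5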
69 | regressLinearModel
70 |   :: (Vectorable input) => [input -> Double] -> DataSet input target -> LinearModel input target
71 | regressLinearModel basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
72 |   where
73 |     design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
74 |     weights = pinv design_matrix <> ds_targets ds
75 | 
76 | -- | Regress a basic linear model with a sum-of-squares regularization term. This penalizes models with weight
77 | -- vectors of large magnitudes and hence ameliorates the over-fitting problem of 'regressLinearModel'.
78 | -- The strength of the regularization is controlled by the lambda parameter. If lambda is 0 then this function
79 | -- is equivalent to the unregularized regression.
80 | --
81 | -- The resulting model will be optimal in terms of least-squares penalized by lambda times the sum-of-squares of
82 | -- the weight vector. Like 'regressLinearModel', a closed form solution is used to find the model quickly.
83 | --
84 | -- Equation 3.28 in Bishop.
85 | regressRegularizedLinearModel
86 |   :: (Vectorable input) => RegularizationCoefficient -> [input -> Double] -> DataSet input target -> LinearModel input target
87 | regressRegularizedLinearModel lambda basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
88 |   where
89 |     design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
90 |     weights = regularizedPinv lambda design_matrix <> ds_targets ds
91 | 
92 | 
93 | -- | Determine the mean weight and inverse hessian matrix given alpha, beta, the design matrix and the targets.
94 | bayesianPosteriorParameters :: Precision -> Precision -> Matrix Double -> Matrix Double -> (Matrix Double, Matrix Double)
95 | bayesianPosteriorParameters alpha beta design_matrix targets = (weights, inv_hessian)
96 |   where
97 |     inv_hessian = regularizedPrePinv alpha beta design_matrix
98 |     weights = beta .* inv_hessian <> trans design_matrix <> targets
99 | 
100 | -- | Bayesian linear regression, using an isotropic Gaussian prior for the weights centred at the origin. The precision
101 | -- of the weight prior is controlled by the parameter alpha, and our belief about the inherent noise in the data is
102 | -- controlled by the precision parameter beta.
103 | --
104 | -- Bayesian linear regression with this prior is entirely equivalent to calling 'regressRegularizedLinearModel' with
105 | -- lambda = alpha / beta. However, the twist is that we can use our knowledge of the prior to also make an estimate
106 | -- for the variance of the true value about any input point.
107 | --
108 | -- For the case of multiple target variables, this function makes the naive Bayesian assumption that the probability
109 | -- distributions on output variables are independent, and takes as an error metric the unweighted sum-squared error
110 | -- in all the targets. The variance model is common to all the target variables.
111 | --
112 | -- Equations 3.53, 3.54 and 3.59 in Bishop.
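--
-- For example, with a weight-prior precision of 1 and a noise precision of 1/0.09
-- (the values the test driver uses):
--
-- > (model, variance_model) = regressBayesianLinearModel 1 (1 / 0.09) basis_fns ds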
113 | regressBayesianLinearModel
114 |   :: (Vectorable input)
115 |   => Precision -- ^ Precision of Gaussian weight prior
116 |   -> Precision -- ^ Precision of noise on samples
117 |   -> [input -> Double] -> DataSet input target -> (LinearModel input target, BayesianVarianceModel input)
118 | regressBayesianLinearModel alpha beta basis_fns ds
119 |   = (LinearModel { lm_basis_fns = basis_fns, lm_weights = weights },
120 |      BayesianVarianceModel { bvm_basis_fns = basis_fns, bvm_inv_hessian = inv_hessian, bvm_beta = beta })
121 |   where
122 |     design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
123 |     (weights, inv_hessian) = bayesianPosteriorParameters alpha beta design_matrix (ds_targets ds)
124 | 
125 | -- | Evidence-maximising Bayesian linear regression, using an isotropic Gaussian prior for the weights centred at the
126 | -- origin. The precision of the weight prior is controlled by the parameter alpha, and our belief about the inherent
127 | -- noise in the data is controlled by the precision parameter beta.
128 | --
129 | -- This is similar to 'regressBayesianLinearModel', but rather than just relying on the supplied values for alpha and beta
130 | -- an iterative procedure is used to try and find values that are best supported by the supplied training data. This is
131 | -- an excellent way of finding a reasonable trade-off between over-fitting of the training set with a complex model and
132 | -- accuracy of the model.
133 | --
134 | -- As a bonus, this function returns gamma, the effective number of parameters used by the regressed model.
135 | --
136 | -- For the case of multiple target variables, this function makes the naive Bayesian assumption that the probability
137 | -- distributions on output variables are independent, and takes as an error metric the unweighted sum-squared error
138 | -- in all the targets. The variance model is common to all the target variables.
139 | --
140 | -- Equations 3.87, 3.92 and 3.95 in Bishop.
141 | regressEMBayesianLinearModel
142 |   :: (Vectorable input, Vectorable target, MetricSpace target)
143 |   => Precision -- ^ Initial estimate of Gaussian weight prior
144 |   -> Precision -- ^ Initial estimate for precision of noise on samples
145 |   -> [input -> Double] -> DataSet input target -> (LinearModel input target, BayesianVarianceModel input, EffectiveNumberOfParameters)
146 | regressEMBayesianLinearModel initial_alpha initial_beta basis_fns ds
147 |   = convergeOnEMBayesianLinearModel loopWorker design_matrix initial_alpha initial_beta basis_fns ds
148 |   where
149 |     n = fromIntegral $ dataSetSize ds
150 | 
151 |     design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
152 |     -- The unscaled eigenvalues will be positive because phi ^ T * phi is positive definite.
153 |     (unscaled_eigenvalues, _) = eigSH (trans design_matrix <> design_matrix)
154 | 
155 |     loopWorker alpha beta = (n - gamma, gamma)
156 |       where
157 |         -- We save computation by calculating eigenvalues once for the design matrix and rescaling each iteration
158 |         eigenvalues = beta .* unscaled_eigenvalues
159 |         gamma = vectorSum (eigenvalues / (addConstant alpha eigenvalues))
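-- A usage sketch for the EM variant, again mirroring the test driver (the initial
-- alpha and beta are just starting guesses that the iteration refines):
--
-- > (model, variance_model, gamma) = regressEMBayesianLinearModel 1 (1 / 0.09) basis_fns ds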
160 | 
161 | -- | Evidence-maximising Bayesian linear regression, using an isotropic Gaussian prior for the weights centred at the
162 | -- origin. The precision of the weight prior is controlled by the parameter alpha, and our belief about the inherent
163 | -- noise in the data is controlled by the precision parameter beta.
164 | --
165 | -- This is similar to 'regressEMBayesianLinearModel', but suitable only for the situation where there is much more
166 | -- training data than there are basis functions you want to assign weights to. Due to the introduction of this
167 | -- constraint, it is much faster than the other function and yet produces results of similar quality.
168 | --
169 | -- Like with 'regressEMBayesianLinearModel', the effective number of parameters, gamma, used by the regressed model
170 | -- is returned. However, because for this function to make sense you need to be sure that there is sufficient data
171 | -- that all the parameters are determined, the returned gamma is always just the number of basis functions (and
172 | -- hence weights).
173 | --
174 | -- For the case of multiple target variables, this function makes the naive Bayesian assumption that the probability
175 | -- distributions on output variables are independent, and takes as an error metric the unweighted sum-squared error
176 | -- in all the targets. The variance model is common to all the target variables.
177 | --
178 | -- Equations 3.98 and 3.99 in Bishop.
179 | regressFullyDeterminedEMBayesianLinearModel
180 |   :: (Vectorable input, Vectorable target, MetricSpace target)
181 |   => Precision -- ^ Initial estimate of Gaussian weight prior
182 |   -> Precision -- ^ Initial estimate for precision of noise on samples
183 |   -> [input -> Double] -> DataSet input target -> (LinearModel input target, BayesianVarianceModel input, EffectiveNumberOfParameters)
184 | regressFullyDeterminedEMBayesianLinearModel initial_alpha initial_beta basis_fns ds
185 |   = convergeOnEMBayesianLinearModel loopWorker design_matrix initial_alpha initial_beta basis_fns ds
186 |   where
187 |     n = fromIntegral $ dataSetSize ds
188 |     m = fromIntegral $ length basis_fns
189 | 
190 |     design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
191 | 
192 |     -- In the limit n >> m, n - gamma = n, so we use that as the beta numerator.
193 |     -- We assume all parameters are determined because n >> m, so we return m as gamma.
194 |     loopWorker _ _ = (n, m)
195 | 
196 | convergeOnEMBayesianLinearModel
197 |   :: (Vectorable input, Vectorable target, MetricSpace target)
198 |   => (Precision -> Precision -> (Double, EffectiveNumberOfParameters)) -- ^ Loop worker: given alpha and beta, return new beta numerator and gamma
199 |   -> Matrix Double     -- ^ Design matrix
200 |   -> Precision         -- ^ Initial alpha
201 |   -> Precision         -- ^ Initial beta
202 |   -> [input -> Double] -- ^ Basis functions
203 |   -> DataSet input target
204 |   -> (LinearModel input target, BayesianVarianceModel input, EffectiveNumberOfParameters)
205 | convergeOnEMBayesianLinearModel loop_worker design_matrix initial_alpha initial_beta basis_fns ds
206 |   = loop eps initial_alpha initial_beta False
207 |   where
208 |     loop threshold alpha beta done
209 |       | done      = (linear_model, BayesianVarianceModel { bvm_basis_fns = basis_fns, bvm_inv_hessian = inv_hessian, bvm_beta = beta }, gamma)
210 |       | otherwise = loop (threshold * 2) alpha' beta' (eqWithin threshold alpha alpha' && eqWithin threshold beta beta')
211 |       where
212 |         (weights, inv_hessian) = bayesianPosteriorParameters alpha beta design_matrix (ds_targets ds)
213 |         linear_model = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
214 | 
215 |         (beta_numerator, gamma) = loop_worker alpha beta
216 | 
217 |         -- This alpha computation is not the most efficient way to get the result, but it is idiomatic.
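        -- (Compare equation 3.92 in Bishop, which treats the single target case:
        -- alpha = gamma / (m_N^T m_N).)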
218 |         -- This is the modification to the algorithm in Bishop that generalises the result to the case
219 |         -- of multiple target variables, but to prove that this is the right thing to do I had to make the
220 |         -- naive Bayesian assumption.
221 |         --
222 |         -- The reason that this is correct is that under naive Bayes:
223 |         --
224 |         --  dE(W)      K     T             T
225 |         -- ------- = \Sigma W  * W  = Tr (W  * W)
226 |         -- d\alpha   k = 1   k    k
227 |         alpha' = gamma / (matrixTrace $ (trans weights) <> weights)
228 |         beta' = beta_numerator / modelSumSquaredError linear_model ds
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                     GNU GENERAL PUBLIC LICENSE
2 |                        Version 2, June 1991
3 | 
4 |  Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5 |  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6 |  Everyone is permitted to copy and distribute verbatim copies
7 |  of this license document, but changing it is not allowed.
8 | 
9 |                             Preamble
10 | 
11 |   The licenses for most software are designed to take away your
12 | freedom to share and change it.  By contrast, the GNU General Public
13 | License is intended to guarantee your freedom to share and change free
14 | software--to make sure the software is free for all its users.  This
15 | General Public License applies to most of the Free Software
16 | Foundation's software and to any other program whose authors commit to
17 | using it.  (Some other Free Software Foundation software is covered by
18 | the GNU Lesser General Public License instead.)  You can apply it to
19 | your programs, too.
20 | 
21 |   When we speak of free software, we are referring to freedom, not
22 | price.  Our General Public Licenses are designed to make sure that you
23 | have the freedom to distribute copies of free software (and charge for
24 | this service if you wish), that you receive source code or can get it
25 | if you want it, that you can change the software or use pieces of it
26 | in new free programs; and that you know you can do these things.
27 | 
28 |   To protect your rights, we need to make restrictions that forbid
29 | anyone to deny you these rights or to ask you to surrender the rights.
30 | These restrictions translate to certain responsibilities for you if you
31 | distribute copies of the software, or if you modify it.
32 | 
33 |   For example, if you distribute copies of such a program, whether
34 | gratis or for a fee, you must give the recipients all the rights that
35 | you have.  You must make sure that they, too, receive or can get the
36 | source code.  And you must show them these terms so they know their
37 | rights.
38 | 
39 |   We protect your rights with two steps: (1) copyright the software, and
40 | (2) offer you this license which gives you legal permission to copy,
41 | distribute and/or modify the software.
42 | 
43 |   Also, for each author's protection and ours, we want to make certain
44 | that everyone understands that there is no warranty for this free
45 | software.  If the software is modified by someone else and passed on, we
46 | want its recipients to know that what they have is not the original, so
47 | that any problems introduced by others will not reflect on the original
48 | authors' reputations.
49 | 
50 |   Finally, any free program is threatened constantly by software
51 | patents.
We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. 
(Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 
165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. 
Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. --------------------------------------------------------------------------------