├── README
├── cholesky.rkt
├── dirichlet.rkt
├── histogram.rkt
├── main.rkt
├── more-math.rkt
├── multivariate-normal.rkt
└── tests
    ├── cholesky.rkt
    └── multivariate-normal.rkt

/README:
--------------------------------------------------------------------------------
Racket Machine Learning

A collection of things I found useful for doing Machine Learning problem sets.

I wouldn't rely on anything that doesn't have a corresponding file in the tests/
folder. It's all in development.

The implementation of the Dirichlet distribution owes inspiration to Luc
Devroye's book, "Non-Uniform Random Variate Generation", which is freely
available online [1].

[1] http://luc.devroye.org/rnbookindex.html
--------------------------------------------------------------------------------
/cholesky.rkt:
--------------------------------------------------------------------------------
#lang typed/racket
(require math)
(provide cholesky)

(define mref matrix-ref)

(: compute-diagonal-entry : ([Matrix Real]
                             [Matrix Real]
                             Natural
                             ->
                             Real))
;; Computes the k'th diagonal entry of the Cholesky factor:
;;   sqrt(A[k,k] - sum_{j<k} L[k,j]^2)
;; Only columns 0..k-1 of row k of L are read.  Raises an error when the
;; quantity under the square root is negative.
(define (compute-diagonal-entry A L k)
  (define entry-squared
    (- (mref A k k)
       (for/sum: : Real
         ([j (in-range 0 k)])
         (sqr (mref L k j)))))
  (if (< entry-squared 0)
      (error 'cholesky
             (string-append
              "This matrix isn't positive definite! "
              "Computing diagonal ~a, A_kk is ~a. "
              "k'th partial row is ~a. "
              "Entry squared values is ~a. ")
             k
             (mref A k k)
             (for/vector: : [Vectorof Real]
               ([j (in-range 0 k)])
               (mref L k j))
             entry-squared)
      (sqrt entry-squared)))

(: build-cholesky-column : ([Matrix Real]
                            [Matrix Real]
                            Natural
                            Natural
                            ->
                            [Matrix Real]))
;; Given a positive definite matrix A and a partial cholesky factor of A such
;; that all columns less than k are already properly filled out, returns the
;; k'th column of the factor as an n x 1 matrix.  Entries above the diagonal
;; are 0; entries below it are
;;   (A[i,k] - sum_{j<k} L[i,j] * L[k,j]) / L[k,k]
(define (build-cholesky-column A L k n)
  (define diagonal-entry
    (compute-diagonal-entry A L k))

  (for/array: #:shape (vector n 1)
    ([i (in-range 0 n)])
    : Real
    (cond [(< i k) 0]
          [(= i k) diagonal-entry]
          [else
           (* (/ 1 diagonal-entry)
              (- (mref A i k)
                 (for/sum: : Real
                   ([j (in-range 0 k)])
                   (* (mref L i j) (mref L k j)))))])))

(: cholesky : ((Matrix Real) -> (Matrix Real)))
;; Returns the lower-triangular factor L with L * L^T = A, built one column at
;; a time and glued together with matrix-augment.  The number of columns built
;; is the number of rows of A.
(define (cholesky A)
  (define n (matrix-num-rows A))

  (for*/fold: : [Matrix Real]
    ;; Column 0 depends on no earlier columns (every sum over j < 0 is empty),
    ;; so the "partial factor" argument is never read.  A itself is passed as
    ;; the placeholder rather than constructing a 0x0 array, which is not a
    ;; valid matrix.
    ([L : [Matrix Real] (build-cholesky-column A A 0 n)])
    ([k : Natural (in-range 1 n)])
    (matrix-augment (list L
                          (build-cholesky-column A L k n)))))
--------------------------------------------------------------------------------
/dirichlet.rkt:
--------------------------------------------------------------------------------
#lang typed/racket

(require math
         "more-math.rkt"
         )

(provide dirichlet-dist
         )

(define-type RealVector [Vectorof Real])

(: dirichlet-dist : (RealVector
                     ->
                     (distribution RealVector RealVector)))
;; Builds a distribution object for the Dirichlet distribution with
;; concentration parameters alphas, pairing a density function with a sampler.
(define (dirichlet-dist alphas)
  (: dirichlet-pdf : (case-> (RealVector -> Flonum)
                             (RealVector (U Any False) -> Flonum)))
  ;; Density at xs, or the log-density when log? is true.  Raises an error if
  ;; any coordinate lies outside [0, 1].
  (define (dirichlet-pdf xs [log? #f])
    (if (for/or: : Boolean
          ([x : Real (in-vector xs)])
          (or (< x 0) (> x 1)))
        (error 'dirichlet-pdf
               (string-append "All elements of the vector must be between 0 "
                              "and 1, inclusive, given ~a.")
               xs)
        (if log?
            ;; log f(x) = sum_i (alpha_i - 1) * log(x_i)  -  log B(alpha).
            ;; The previous code returned the negated sum and never subtracted
            ;; the log normalizer, so it disagreed with the non-log branch.
            (- (for/fold: : Flonum
                 ([result : Flonum #i0])
                 ([alpha alphas]
                  [x xs])
                 (+ result (fl (* (sub1 alpha) (log (max 0 x))))))
               (fl (multivariate-beta alphas #:log? #t)))
            (/ (for/fold: : Flonum
                 ([result : Flonum #i1])
                 ([alpha alphas]
                  [x xs])
                 (* result (fl (expt (max 0 x) (sub1 alpha)))))
               (multivariate-beta alphas)))))

  (: dirichlet-sampler : (case-> (-> RealVector)
                                 (Integer -> [Listof RealVector])))
  ;; With no argument returns a single sample; with n returns a list of n
  ;; samples.
  (define (dirichlet-sampler [n #f])
    (if n
        (for/list: : [Listof RealVector]
          ([i (in-range 0 n)])
          (dirichlet-sample-one))
        (dirichlet-sample-one)))

  (: dirichlet-sample-one : (-> RealVector))
  ;; Draws one Gamma(alpha_i, 1) variate per parameter and divides each by
  ;; their total.
  (define (dirichlet-sample-one)
    (let* ((gammas (for/vector: : RealVector
                     ([alpha alphas])
                     (sample (gamma-dist alpha 1))))
           (sum (for/fold: : Flonum ([sum #i0]) ([gamma gammas])
                  (+ sum gamma))))
      (for/vector: : RealVector
        ([gamma gammas])
        (/ gamma sum))))


  (distribution dirichlet-pdf dirichlet-sampler))
--------------------------------------------------------------------------------
/histogram.rkt:
--------------------------------------------------------------------------------
#lang typed/racket

(require plot/typed
         (only-in plot/typed/utils linear-seq)
         (only-in racket/snip image-snip%)
         )


(require/typed racket
               [in-value (All (X) (X -> [Sequenceof X]))]
               [in-cycle (All (X) ([Sequenceof X] * -> [Sequenceof X]))]
               )

(provide (struct-out histogram)
         histogram->function
         hist-gen&render
         histogram->renderer
         generate-histogram
         )

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Histogram

;; bins           - accumulated (optionally normalized) weight per bin
;; left, right    - end points of the covered interval
;; number-of-bins - how many bins the interval is divided into
;; bin-width      - width of a single bin
;; which-bin      - maps a value in [left, right] to its bin index
(struct: histogram
  ([bins : [Vectorof Real]]
   [left : Real]
   [right : Real]
   [number-of-bins : Natural]
   [bin-width : Real]
   [which-bin : (Real -> Natural)]))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(: histogram->function : (histogram -> (Real -> Real)))
;; Views a histogram as a step function: the result maps a value to the
;; contents of the bin it falls in, and raises an error for values outside
;; [left, right].
(define/match (histogram->function h)
  [((histogram bins left right _ _ which-bin))
   (lambda: ([n : Real])
     (when (or (< n left) (> n right))
       (error 'histogram-as-function
              "The value ~a is not within the range of this histogram [~a,~a]"
              n left right))
     ((inst vector-ref Real) bins (which-bin n)))])

(: bucketize : ((Sequenceof Real)
                (Sequenceof Real)
                Natural
                (Real -> Natural)
                Real
                [#:normalize? Boolean]
                ->
                (Vectorof Real)))
;; Walks data and weights in parallel, adding each weight to the bin selected
;; by which-bin for the matching datum.  When normalize? is true the result is
;; passed through normalize-vector together with bin-width.
(define (bucketize data
                   weights
                   number-of-bins
                   which-bin
                   bin-width
                   #:normalize? [normalize? #f])
  (define bins
    (for/fold:
      ([hist : [Vectorof Real] (make-vector number-of-bins 0)])
      ([d : Real data]
       [n : Real weights])
      (let ((bin (which-bin d)))
        (vector-set! hist bin (+ (vector-ref hist bin) n))
        hist)))

  (if normalize? (normalize-vector bins bin-width) bins))

(: generate-which-bin : (Real Real Natural Real -> (Real -> Natural)))
;; generate a function which partitions the line segment [left, right] into
;; number-of-bins (almost) equally sized bins.
;;
;; NB: The last bin is inclusive of both endpoints whereas all other bins are
;; exclusive of the right-most endpoint.
(define (generate-which-bin left right number-of-bins bin-width)
  (lambda: ([value : Real])
    (if (and (>= value left) (<= value right))
        (max (min (exact-floor (/ (- value left) bin-width))
                  (sub1 number-of-bins)) ; catch the right end-point
             0) ; appease the type checker
        (error 'histogram-which-bin
               "Value ~a is not within range [~a,~a]."
               value left right))))

(: normalize-vector ((Vectorof Real) Real -> (Vectorof Real)))
;; Divides every entry by sum(bin-width * entry), so that afterwards the
;; entries times bin-width add up to 1.
(define (normalize-vector v bin-width)
  (define sum (for/sum: : Real ([value : Real (in-vector v)])
                (* bin-width value)))

  (for/vector: : (Vectorof Real) ([value : Real (in-vector v)])
    (/ value sum)))

(: sequence-of-ones : [Sequenceof Natural])
;; Endless stream of 1s, used as the default (uniform) weights.
(define sequence-of-ones (in-cycle (in-value 1)))

(: generate-histogram : ([Sequenceof Real]
                         Natural
                         [#:weights [Sequenceof Natural]]
                         [#:normalize? Boolean]
                         ->
                         histogram))
;; Builds a histogram of data over [floor(min data), ceiling(max data)] split
;; into number-of-bins equal bins.  Raises an error for empty data or zero
;; bins.
(define (generate-histogram data
                            number-of-bins
                            #:weights [weights sequence-of-ones]
                            #:normalize? [normalize? #t])
  ;; Materialize the data once: it is needed for the minimum, the maximum and
  ;; the bucketing pass, and a general sequence may not survive repeated
  ;; traversal.
  (define points (sequence->list data))
  (cond
    [(null? points)
     (error 'generate-histogram
            "Cannot build a histogram from an empty data sequence.")]
    [(zero? number-of-bins)
     (error 'generate-histogram
            "The number of bins must be positive.")]
    [else
     (define left (exact-floor (apply min (car points) (cdr points))))
     (define data-right (exact-ceiling (apply max (car points) (cdr points))))
     ;; When every datum equals the same integer the range collapses to a
     ;; point and the bin width would be 0 (division by zero when binning and
     ;; normalizing).  Widen the range to one unit in that case.
     (define right (if (> data-right left) data-right (add1 left)))
     (define total-width (max 0 (- right left)))
     (define bin-width (max 0 (/ total-width number-of-bins)))
     (define which-bin
       (generate-which-bin left right number-of-bins bin-width))

     (histogram (bucketize points
                           weights
                           number-of-bins
                           which-bin
                           bin-width
                           #:normalize? normalize?)
                left right
                number-of-bins
                bin-width
                which-bin)]))

(: histogram->renderer : (histogram -> renderer2d))
;; Turns a histogram into a plot renderer.
(define/match (histogram->renderer h)
  [((histogram _ left right number-of-bins _ _))
   ;; area-histogram takes bin *bounds*; n bins need n + 1 bounds for the
   ;; rendered bins to line up with the histogram's own bins.
   (area-histogram (histogram->function h)
                   (linear-seq left right (add1 number-of-bins)))])

;; poorly named, but skips the intermediate step for the impatient among us
(: hist-gen&render : ([Sequenceof Real]
                      Natural
                      [#:weights [Sequenceof Natural]]
                      [#:normalize? Boolean]
                      ->
                      renderer2d))
(define (hist-gen&render data
                         number-of-bins
                         #:weights [weights sequence-of-ones]
                         #:normalize? [normalize? #t])
  (histogram->renderer (generate-histogram data
                                           number-of-bins
                                           #:weights weights
                                           #:normalize? normalize?)))
--------------------------------------------------------------------------------
/main.rkt:
--------------------------------------------------------------------------------
#lang typed/racket

(require "cholesky.rkt"
         "dirichlet.rkt"
         "histogram.rkt"
         "more-math.rkt"
         )

(provide ;; cholesky
         cholesky
         ;; dirichlet
         dirichlet-dist
         ;; histogram
         (struct-out histogram)
         histogram->function
         hist-gen&render
         histogram->renderer
         generate-histogram
         ;; more-math
         multivariate-beta
         )
--------------------------------------------------------------------------------
/more-math.rkt:
--------------------------------------------------------------------------------
#lang typed/racket

(require math
         )
(provide multivariate-beta
         )

(: multivariate-beta : ([Vectorof Real] [#:log? Boolean] -> Real))
;; Multivariate beta function: the product of gamma over the entries of vec
;; divided by gamma of their sum.  With #:log? #t the same quantity is
;; computed in log space using log-gamma.
(define (multivariate-beta vec #:log? [log? #f])
  (if log?
      (- (for/sum: : Real ([v vec]) (log-gamma v))
         (log-gamma (for/sum: : Real ([v vec]) v)))
      (/ (for/product: : Real ([v vec]) (gamma v))
         (gamma (for/sum: : Real ([v vec]) v)))))
--------------------------------------------------------------------------------
/multivariate-normal.rkt:
--------------------------------------------------------------------------------
#lang typed/racket

(require math
         racket-ml
         (only-in racket/flonum flexpt)
         )

(provide multivariate-normal-dist
         pdf-multivariate-normal
         sample-multivariate-normal
         )

(: multivariate-normal-dist : ([Matrix Real]
                               [Matrix Real]
                               ->
                               (distribution [Matrix Real]
                                             [Matrix Flonum])))
;; Bundles the density and the sampler for the given mean and covariance into
;; a distribution object.
(define (multivariate-normal-dist mean covariance)
  (distribution (pdf-multivariate-normal mean covariance)
                (sample-multivariate-normal mean covariance)))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Note that the covariance matrices here must be nonnegative-definite in order
;; to produce true multivariate-normal distributions. I do not check this
;; property (primarily because there doesn't currently exist an eigenvalue
;; function and nonnegative-definite can be checked mostly easily with the
;; eigenvalues).

(: pdf-multivariate-normal : ([Matrix Real]
                              [Matrix Real]
                              ->
                              (PDF [Matrix Real])))
;; Returns the density function
;;   x |-> sqrt((2 pi)^-k / det) * exp(-1/2 (x - mean)^T covariance^-1 (x - mean))
;; where k is the size of covariance and det its determinant.  Raises an
;; error (after printing the covariance) when the determinant is not positive.
(define (pdf-multivariate-normal mean covariance)
  (define det (fl (matrix-determinant covariance)))
  (define k (fl (square-matrix-size covariance)))
  (define 2pi^-k (flexpt (/ 1 (* 2 pi)) k))
  (define coefficient
    (if (or (< 2pi^-k 0) (<= det 0))
        (begin
          (displayln covariance)
          (error 'pdf-multivariate-normal
                 (string-append
                  "I'll bet you the covariance matrix (should be printed above) "
                  "isn't nonngetaive-definite. Tried to take the square root of "
                  "a negative number.")))
        (* (sqrt 2pi^-k)
           (sqrt (/ 1.0 det)))))
  ;; The inverse does not depend on the query point, so compute it once here
  ;; instead of on every call of the returned pdf.
  (define covariance-inverse (matrix-inverse covariance))

  (: exponent : [Matrix Real] -> Flonum)
  (define (exponent x)
    (define difference (matrix- x mean))

    (- (* 0.5
          (fl (1x1-matrix->scalar (matrix* (matrix-transpose difference)
                                           covariance-inverse
                                           difference))))))

  (: pdf : (PDF [Matrix Real]))
  (define (pdf x [log? #f])
    (if log?
        (+ (log coefficient) (exponent x))
        (* coefficient (exp (exponent x)))))

  pdf)

(: sample-multivariate-normal : ([Matrix Real]
                                 [Matrix Real]
                                 ->
                                 (Sample [Matrix Flonum])))
;; Returns a sampler producing mean + L * z, where L is the Cholesky factor of
;; covariance and z is a column of independent standard normal draws.
(define (sample-multivariate-normal mean covariance)
  (define L (array->flarray (cholesky covariance)))

  (: sample : (Sample [Matrix Flonum]))
  (define (sample [samples #f])
    (if samples
        (for/list: : [Listof [Matrix Flonum]]
          ([i (in-range 0 samples)])
          (matrix+ mean (matrix* L (base-mvn (square-matrix-size covariance)))))
        (matrix+ mean (matrix* L (base-mvn (square-matrix-size covariance))))))

  sample)

(: base-mvn : (Natural -> [Matrix Flonum]))
;; An n x 1 matrix of independent draws from normal-dist 0 1.
(define (base-mvn n)
  (for/matrix: n 1
    ([i (in-range 0 n)])
    : Flonum
    (sample (normal-dist 0 1))))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Utilities

(: 1x1-matrix->scalar : (All (A) ([Matrix A] -> A)))
;; takes a 1x1 matrix and demotes the value to a scalar
;;
;; Rows and columns are checked separately: square-matrix-size would itself
;; raise on a non-square matrix, and the error message needs both dimensions
;; (the format string has two ~a slots; previously only one value was given).
(define (1x1-matrix->scalar m)
  (define rows (matrix-num-rows m))
  (define cols (matrix-num-cols m))
  (if (and (= 1 rows) (= 1 cols))
      (matrix-ref m 0 0)
      (error '1x1-matrix->scalar
             "Matrix is not a 1x1, actually is ~ax~a."
             rows
             cols)))
--------------------------------------------------------------------------------
/tests/cholesky.rkt:
--------------------------------------------------------------------------------
#lang racket

(require math
         rackunit
         "../cholesky.rkt"
         )

(define/provide-test-suite cholesky-tests
  (check-equal? (cholesky (matrix [[25 15 -5]
                                   [15 18  0]
                                   [-5  0 11]]))
                (matrix [[ 5 0 0]
                         [ 3 3 0]
                         [-1 1 3]]))

  (check-equal?
   (cholesky (matrix [[18  22  54  42]
                      [22  70  86  62]
                      [54  86 174 134]
                      [42  62 134 106]]))
   (matrix [[4.242640687119285  0                  0                  0]
            [ 5.185449728701349 6.565905201197403  0                  0]
            [12.727922061357857 3.0460384954008553 1.6497422479090704 0]
            [ 9.899494936611665 1.6245538642137891 1.849711005231382  1.3926212476455924]]))

  ;; cholesky returns the lower-triangular factor (as in the first test); the
  ;; expected value here used to be its transpose.
  (check-equal? (cholesky (matrix [[  4  12 -16]
                                   [ 12  37 -43]
                                   [-16 -43  98]]))
                (matrix [[ 2 0 0]
                         [ 6 1 0]
                         [-8 5 3]])))
--------------------------------------------------------------------------------
/tests/multivariate-normal.rkt:
--------------------------------------------------------------------------------
#lang racket

(require rackunit
         rackunit/text-ui
         math
         "../multivariate-normal.rkt"
         )

;; I'm not really sure how to check distributions. I pulled these values from
;; the mvtnorm R package, so I lightly trust them. I should at least be as good
;; as R, right?

;; Fixtures: column-vector means and covariance matrices shared by the tests.
(define zero-mean2 (matrix [[0] [0]]))
(define identity-covariance2 (identity-matrix 2))

(define zero-mean3 (matrix [[0] [0] [0]]))
(define identity-covariance3 (identity-matrix 3))

(define 10-5-0-mean (matrix [[10] [5] [0]]))
(define all-5s-covariance3 (matrix [[5 5 5]
                                    [5 5 5]
                                    [5 5 5]]))

;; Absolute tolerance handed to check-=.
;; NOTE(review): several expected densities below are far smaller than this
;; tolerance (around 1e-23 and below), so those checks only establish that the
;; computed value is close to zero.
(define epsilon 1e-7)

(define-test-suite mvn-tests
  (test-case
   "2d tests"
   (check-= (pdf (multivariate-normal-dist zero-mean2 identity-covariance2)
                 (matrix [[0] [0]]))
            0.1591549
            epsilon)

   (for ((pair '((10 5) (5 0) (1.5 3.5) (0.1 0.8)))
         (result '(1.143971e-28 5.931153e-07 0.0001130278 0.1149938)))
     ;; because we're using identity covariance and zero mean the reversed
     ;; coordinates should have the same value
     ;; additionally the sign of the coordinates shouldn't matter
     ;;
     ;; for* nests its clauses, so each of the two coordinate orders, e.g.
     ;; (10 5) and (5 10), is checked under all four sign combinations.  The
     ;; earlier plain `for` walked the sequences in parallel and stopped with
     ;; the shortest one, covering only two of the eight cases.
     (for* ((coords (in-list (list pair (reverse pair))))
            (signs (in-list (list (list - -) (list - +) (list + -) (list + +)))))
       (define x (first coords))
       (define y (second coords))
       (check-= (pdf (multivariate-normal-dist zero-mean2 identity-covariance2)
                     (matrix [[((first signs) x)]
                              [((second signs) y)]]))
                result
                epsilon))))

  (test-case
   "3d tests"
   (check-= (pdf (multivariate-normal-dist zero-mean3 identity-covariance3)
                 (matrix [[0] [0] [0]]))
            0.06349364
            epsilon)
   (check-= (pdf (multivariate-normal-dist 10-5-0-mean identity-covariance3)
                 (matrix [[0] [0] [0]]))
            4.563784e-29
            epsilon)
   (check-= (pdf (multivariate-normal-dist 10-5-0-mean identity-covariance3)
                 (matrix [[1] [1] [1]]))
            3.328899e-23
            epsilon)
   (check-= (pdf (multivariate-normal-dist 10-5-0-mean identity-covariance3)
                 (matrix [[10] [10] [10]]))
            4.563784e-29
            epsilon)
   ;; NOTE(review): (10 -10 10) is farther from the mean (10 5 0) than
   ;; (10 10 10) is, so its density should be smaller than the value above;
   ;; this expected value looks copied from the previous check and only
   ;; passes because of the absolute tolerance -- confirm against R.
   (check-= (pdf (multivariate-normal-dist 10-5-0-mean identity-covariance3)
                 (matrix [[10] [-10] [10]]))
            4.563784e-29
            epsilon)))

(run-tests mvn-tests)

--------------------------------------------------------------------------------