├── go.mod ├── fit ├── package.go ├── loess_test.go ├── loess.go └── lsquares.go ├── graph ├── graphalg │ ├── package.go │ ├── order_test.go │ ├── marks_test.go │ ├── graph_test.go │ ├── dom_test.go │ ├── visit.go │ ├── order.go │ ├── multigraph.go │ ├── marks.go │ ├── scc_test.go │ ├── dom.go │ └── scc.go ├── graphout │ ├── package.go │ └── dot.go ├── weighted.go ├── eq.go ├── graph.go ├── subgraph_test.go └── subgraph.go ├── vec ├── package.go └── vec.go ├── go.sum ├── README.md ├── scale ├── util.go ├── err.go ├── package.go ├── ticks_test.go ├── interface.go ├── ticks.go ├── linear.go ├── linear_test.go ├── log.go └── log_test.go ├── mathx ├── package.go ├── sign.go ├── beta_test.go ├── gamma_test.go ├── choose.go ├── gamma.go └── beta.go ├── stats ├── kdekernel_string.go ├── kdeboundarymethod_string.go ├── locationhypothesis_string.go ├── hypergdist_test.go ├── package.go ├── normaldist_test.go ├── tdist.go ├── dist_test.go ├── linearhist.go ├── binomdist_test.go ├── deltadist.go ├── sample_test.go ├── util_test.go ├── hist.go ├── loghist.go ├── binomdist.go ├── kde_test.go ├── ttest_test.go ├── utest_test.go ├── stream.go ├── alg.go ├── hypergdist.go ├── tdist_test.go ├── normaldist.go ├── ttest.go ├── quantileci_test.go ├── dist.go ├── quantileci.go ├── utest.go ├── sample.go └── udist_test.go ├── LICENSE ├── internal └── mathtest │ └── mathtest.go └── cmd └── dist ├── dist.go └── plot.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/aclements/go-moremath 2 | 3 | go 1.22 4 | 5 | require gonum.org/v1/gonum v0.15.1 6 | -------------------------------------------------------------------------------- /fit/package.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | // Package fit provides functions for fitting models to data. 6 | package fit 7 | -------------------------------------------------------------------------------- /graph/graphalg/package.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package graphalg implements common graph algorithms. 6 | package graphalg 7 | -------------------------------------------------------------------------------- /graph/graphout/package.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package graphout implements functions to write graphs to common 6 | // graph formats. 7 | package graphout 8 | -------------------------------------------------------------------------------- /vec/package.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package vec provides functions for float64 vectors. 
// clamp restricts x to the unit interval: anything below 0 becomes 0
// and anything above 1 becomes 1. Every other value, including NaN
// (which fails both comparisons), is returned unchanged.
func clamp(x float64) float64 {
	switch {
	case x < 0:
		return 0
	case x > 1:
		return 1
	default:
		return x
	}
}
// RangeErr is an error that indicates some argument or value is out
// of range. The string value is the error message itself.
type RangeErr string

// RangeErr satisfies the built-in error interface.
var _ error = RangeErr("")

// Error returns the message carried by the RangeErr.
func (e RangeErr) Error() string {
	return string(e)
}
4 | 5 | package mathx 6 | 7 | // Sign returns the sign of x: -1 if x < 0, 0 if x == 0, 1 if x > 0. 8 | // If x is NaN, it returns NaN. 9 | func Sign(x float64) float64 { 10 | if x == 0 { 11 | return 0 12 | } else if x < 0 { 13 | return -1 14 | } else if x > 0 { 15 | return 1 16 | } 17 | return nan 18 | } 19 | -------------------------------------------------------------------------------- /scale/package.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package scale provides abstractions for scales that map from one 6 | // domain to another and provide methods for indicating human-readable 7 | // intervals in the input domain. The most common type of scale is a 8 | // quantitative scale, such as a linear or log scale, which is 9 | // captured by the Quantitative interface. 10 | package scale 11 | -------------------------------------------------------------------------------- /stats/kdeboundarymethod_string.go: -------------------------------------------------------------------------------- 1 | // generated by stringer -type=KDEBoundaryMethod; DO NOT EDIT 2 | 3 | package stats 4 | 5 | import "fmt" 6 | 7 | const _KDEBoundaryMethod_name = "BoundaryReflect" 8 | 9 | var _KDEBoundaryMethod_index = [...]uint8{0, 15} 10 | 11 | func (i KDEBoundaryMethod) String() string { 12 | if i < 0 || i+1 >= KDEBoundaryMethod(len(_KDEBoundaryMethod_index)) { 13 | return fmt.Sprintf("KDEBoundaryMethod(%d)", i) 14 | } 15 | return _KDEBoundaryMethod_name[_KDEBoundaryMethod_index[i]:_KDEBoundaryMethod_index[i+1]] 16 | } 17 | -------------------------------------------------------------------------------- /stats/locationhypothesis_string.go: -------------------------------------------------------------------------------- 1 | // generated by stringer -type LocationHypothesis; DO 
NOT EDIT 2 | 3 | package stats 4 | 5 | import "fmt" 6 | 7 | const _LocationHypothesis_name = "LocationLessLocationDiffersLocationGreater" 8 | 9 | var _LocationHypothesis_index = [...]uint8{0, 12, 27, 42} 10 | 11 | func (i LocationHypothesis) String() string { 12 | i -= -1 13 | if i < 0 || i+1 >= LocationHypothesis(len(_LocationHypothesis_index)) { 14 | return fmt.Sprintf("LocationHypothesis(%d)", i+-1) 15 | } 16 | return _LocationHypothesis_name[_LocationHypothesis_index[i]:_LocationHypothesis_index[i+1]] 17 | } 18 | -------------------------------------------------------------------------------- /stats/hypergdist_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | ) 11 | 12 | func TestHypergeometricDist(t *testing.T) { 13 | dist1 := HypergeometicDist{N: 50, K: 5, Draws: 10} 14 | testFunc(t, fmt.Sprintf("%+v.PMF", dist1), dist1.PMF, 15 | map[float64]float64{ 16 | -0.1: 0, 17 | 4: 0.003964583058, 18 | 4.9: 0.003964583058, // Test rounding 19 | 5: 0.000118937492, 20 | 5.9: 0.000118937492, 21 | 6: 0, 22 | }) 23 | testDiscreteCDF(t, fmt.Sprintf("%+v.CDF", dist1), dist1) 24 | } 25 | -------------------------------------------------------------------------------- /graph/graphalg/order_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
// Weighted represents a weighted directed graph.
type Weighted interface {
	Graph

	// OutWeight returns the weight of the e'th edge out from node
	// i. e must be in the range [0, len(Out(i))).
	OutWeight(i, e int) float64
}

// WeightedUnit wraps a graph as a weighted graph where all edges have
// weight 1.
type WeightedUnit struct {
	Graph
}

// OutWeight returns 1 for every edge; both i and e are ignored.
func (w WeightedUnit) OutWeight(i, e int) float64 {
	return 1
}
// Package-level shorthands for positive infinity and NaN.
var (
	inf = math.Inf(1)
	nan = math.NaN()
)

// TODO: Put all errors in the same place and maybe unify them.

// ErrSamplesEqual is the error value reporting that all samples are
// equal.
var ErrSamplesEqual = errors.New("all samples are equal")
"github.com/aclements/go-moremath/internal/mathtest" 11 | ) 12 | 13 | func TestBetaInc(t *testing.T) { 14 | // Example values from MATLAB betainc documentation. 15 | WantFunc(t, "I_0.5(%v, 3)", 16 | func(a float64) float64 { return BetaInc(0.5, a, 3) }, 17 | map[float64]float64{ 18 | 0: 1.00000000000000, 19 | 1: 0.87500000000000, 20 | 2: 0.68750000000000, 21 | 3: 0.50000000000000, 22 | 4: 0.34375000000000, 23 | 5: 0.22656250000000, 24 | 6: 0.14453125000000, 25 | 7: 0.08984375000000, 26 | 8: 0.05468750000000, 27 | 9: 0.03271484375000, 28 | 10: 0.01928710937500}) 29 | } 30 | -------------------------------------------------------------------------------- /graph/graphalg/graph_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import "github.com/aclements/go-moremath/graph" 8 | 9 | // Example graph from Muchnick, "Advanced Compiler Design & 10 | // Implementation", figure 8.21. 11 | var graphMuchnick = graph.MakeBiGraph(graph.IntGraph{ 12 | 0: {1}, 13 | 1: {2}, 14 | 2: {3, 4}, 15 | 3: {2}, 16 | 4: {5, 6}, 17 | 5: {7}, 18 | 6: {7}, 19 | 7: {}, 20 | }) 21 | 22 | // Example graph from 23 | // https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec04-SSA.pdf 24 | // slide 24. 25 | var graphCS252 = graph.MakeBiGraph(graph.IntGraph{ 26 | 0: {1}, 27 | 1: {2, 5}, 28 | 2: {3, 4}, 29 | 3: {6}, 30 | 4: {6}, 31 | 5: {1, 7}, 32 | 6: {7}, 33 | 7: {8}, 34 | 8: {}, 35 | }) 36 | -------------------------------------------------------------------------------- /stats/normaldist_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "testing" 11 | ) 12 | 13 | func TestNormalDist(t *testing.T) { 14 | d := StdNormal 15 | 16 | testFunc(t, fmt.Sprintf("%+v.PDF", d), d.PDF, map[float64]float64{ 17 | -10000: 0, // approx 18 | -1: 1 / math.Sqrt(2*math.Pi) * math.Exp(-0.5), 19 | 0: 1 / math.Sqrt(2*math.Pi), 20 | 1: 1 / math.Sqrt(2*math.Pi) * math.Exp(-0.5), 21 | 10000: 0, // approx 22 | }) 23 | 24 | testFunc(t, fmt.Sprintf("%+v.CDF", d), d.CDF, map[float64]float64{ 25 | -10000: 0, // approx 26 | 0: 0.5, 27 | 10000: 1, // approx 28 | }) 29 | 30 | d2 := NormalDist{Mu: 2, Sigma: 5} 31 | testInvCDF(t, d, false) 32 | testInvCDF(t, d2, false) 33 | } 34 | -------------------------------------------------------------------------------- /stats/tdist.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "math" 9 | 10 | "github.com/aclements/go-moremath/mathx" 11 | ) 12 | 13 | // A TDist is a Student's t-distribution with V degrees of freedom. 
14 | type TDist struct { 15 | V float64 16 | } 17 | 18 | func lgamma(x float64) float64 { 19 | y, _ := math.Lgamma(x) 20 | return y 21 | } 22 | 23 | func (t TDist) PDF(x float64) float64 { 24 | return math.Exp(lgamma((t.V+1)/2)-lgamma(t.V/2)) / 25 | math.Sqrt(t.V*math.Pi) * math.Pow(1+(x*x)/t.V, -(t.V+1)/2) 26 | } 27 | 28 | func (t TDist) CDF(x float64) float64 { 29 | if x == 0 { 30 | return 0.5 31 | } else if x > 0 { 32 | return 1 - 0.5*mathx.BetaInc(t.V/(t.V+x*x), t.V/2, 0.5) 33 | } else if x < 0 { 34 | return 1 - t.CDF(-x) 35 | } else { 36 | return math.NaN() 37 | } 38 | } 39 | 40 | func (t TDist) Bounds() (float64, float64) { 41 | return -4, 4 42 | } 43 | -------------------------------------------------------------------------------- /graph/graphalg/dom_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package graphalg 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func TestIDom(t *testing.T) { 13 | idom := IDom(graphMuchnick, 0) 14 | want := []int{0: -1, 1: 0, 2: 1, 3: 2, 4: 2, 5: 4, 6: 4, 7: 4} 15 | if !reflect.DeepEqual(want, idom) { 16 | t.Errorf("graphMuchnick: want %v, got %v", want, idom) 17 | } 18 | 19 | idom = IDom(graphCS252, 0) 20 | want = []int{0: -1, 1: 0, 2: 1, 3: 2, 4: 2, 5: 1, 6: 2, 7: 1, 8: 7} 21 | if !reflect.DeepEqual(want, idom) { 22 | t.Errorf("graphCS252: want %v, got %v", want, idom) 23 | } 24 | } 25 | 26 | func TestDomFrontier(t *testing.T) { 27 | df := DomFrontier(graphCS252, 0, nil) 28 | want := [][]int{ 29 | 0: {}, 30 | 1: {1}, 31 | 2: {7}, 32 | 3: {6}, 33 | 4: {6}, 34 | 5: {1, 7}, 35 | 6: {7}, 36 | 7: {}, 37 | 8: {}, 38 | } 39 | if !reflect.DeepEqual(want, df) { 40 | t.Errorf("want %v, got %v", want, df) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /stats/dist_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
// funnyCDF is a test CDF that starts rising at left and has a flat
// section in the middle: it climbs from 0 to 0.5 over [left, left+1),
// stays at 0.5 over [left+1, left+2), and climbs from 0.5 to 1 over
// [left+2, left+3).
type funnyCDF struct {
	left float64
}

// CDF evaluates the piecewise-linear function described on funnyCDF.
func (f funnyCDF) CDF(x float64) float64 {
	if x < f.left {
		return 0
	}
	if x < f.left+1 {
		return (x - f.left) / 2
	}
	if x < f.left+2 {
		return 0.5
	}
	if x < f.left+3 {
		return (x-f.left-2)/2 + 0.5
	}
	return 1
}

// Bounds returns the interval [left, left+3] outside of which CDF is
// constant.
func (f funnyCDF) Bounds() (float64, float64) {
	lo := f.left
	return lo, lo + 3
}
// LinearHist is a Histogram with uniformly-sized bins.
type LinearHist struct {
	min, max  float64
	delta     float64 // 1/bin width (to avoid division in hot path)
	low, high uint    // samples that fell below / above the binned range
	bins      []uint
}

// NewLinearHist returns an empty histogram with nbins uniformly-sized
// bins starting at min and ending at max. Samples below min are
// counted as "low"; samples that map at or past the end of the last
// bin are counted as "high".
func NewLinearHist(min, max float64, nbins int) *LinearHist {
	delta := float64(nbins) / (max - min)
	return &LinearHist{min, max, delta, 0, 0, make([]uint, nbins)}
}

// bin returns the index of the bin that x falls in. It returns -1 if
// x lies below the first bin (or is NaN) and len(h.bins) if x lies
// past the last bin.
//
// The range checks are done on the floating-point position before
// converting to int. A bare int(...) conversion truncates toward
// zero, which would put samples just below min into bin 0, and its
// result is implementation-defined when the value does not fit in an
// int, which could misfile very large samples.
func (h *LinearHist) bin(x float64) int {
	pos := h.delta * (x - h.min)
	if !(pos >= 0) { // negative or NaN
		return -1
	}
	if n := len(h.bins); pos >= float64(n) {
		return n
	}
	return int(pos)
}

// Add records the sample x in its bin, or in the low or high overflow
// counter if x is outside the binned range.
func (h *LinearHist) Add(x float64) {
	bin := h.bin(x)
	if bin < 0 {
		h.low++
	} else if bin >= len(h.bins) {
		h.high++
	} else {
		h.bins[bin]++
	}
}

// Counts returns the number of samples below the binned range, the
// per-bin counts, and the number of samples above the binned range.
// The returned slice is the histogram's own storage, not a copy.
func (h *LinearHist) Counts() (uint, []uint, uint) {
	return h.low, h.bins, h.high
}

// BinToValue returns the sample value at (possibly fractional) bin
// position bin; integer positions are the lower edges of the bins.
func (h *LinearHist) BinToValue(bin float64) float64 {
	return h.min + bin/h.delta
}
4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "testing" 11 | ) 12 | 13 | func TestBinomialDist(t *testing.T) { 14 | dist := BinomialDist{N: 5, P: 0.2} 15 | testFunc(t, fmt.Sprintf("%+v.PMF", dist), dist.PMF, 16 | map[float64]float64{ 17 | -1000: 0, 18 | -1: 0, 19 | 0: 0.32768, 20 | 1: 0.4096, 21 | 2: 0.2048, 22 | 3: 0.0512, 23 | 4: 0.0064, 24 | 5: math.Pow(dist.P, 5), 25 | 6: 0, 26 | 1000: 0, 27 | }) 28 | testDiscreteCDF(t, fmt.Sprintf("%+v.CDF", dist), dist) 29 | 30 | dist = BinomialDist{N: 30, P: 0.5} 31 | norm := dist.NormalApprox() 32 | for k := 10; k <= 20; k++ { 33 | b := dist.PMF(float64(k)) 34 | n := norm.CDF(float64(k)+0.5) - norm.CDF(float64(k)-0.5) 35 | 36 | // The normal approximation isn't actually very close, 37 | // even with high N and P near 0.5, so we only check 38 | // the center of the distribution and we're pretty 39 | // lax. 40 | err := math.Abs(b/n - 1) 41 | if err > 0.01 { 42 | t.Errorf("want %v ≅ %v at %d", b, n, k) 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /stats/deltadist.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | // DeltaDist is the Dirac delta function, centered at T, with total 8 | // area 1. 9 | // 10 | // The CDF of the Dirac delta function is the Heaviside step function, 11 | // centered at T. Specifically, f(T) == 1. 
12 | type DeltaDist struct { 13 | T float64 14 | } 15 | 16 | func (d DeltaDist) PDF(x float64) float64 { 17 | if x == d.T { 18 | return inf 19 | } 20 | return 0 21 | } 22 | 23 | func (d DeltaDist) pdfEach(xs []float64) []float64 { 24 | res := make([]float64, len(xs)) 25 | for i, x := range xs { 26 | if x == d.T { 27 | res[i] = inf 28 | } 29 | } 30 | return res 31 | } 32 | 33 | func (d DeltaDist) CDF(x float64) float64 { 34 | if x >= d.T { 35 | return 1 36 | } 37 | return 0 38 | } 39 | 40 | func (d DeltaDist) cdfEach(xs []float64) []float64 { 41 | res := make([]float64, len(xs)) 42 | for i, x := range xs { 43 | res[i] = d.CDF(x) 44 | } 45 | return res 46 | } 47 | 48 | func (d DeltaDist) InvCDF(y float64) float64 { 49 | if y < 0 || y > 1 { 50 | return nan 51 | } 52 | return d.T 53 | } 54 | 55 | func (d DeltaDist) Bounds() (float64, float64) { 56 | return d.T - 1, d.T + 1 57 | } 58 | -------------------------------------------------------------------------------- /graph/graphalg/visit.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import ( 8 | "github.com/aclements/go-moremath/graph" 9 | ) 10 | 11 | // Euler visits a graph using an Euler tour. 12 | // 13 | // For a tree, the Euler tour is well-defined and unique (given an 14 | // ordering of the children of a node). For a general graph, this uses 15 | // the tree formed by the pre-order traversal of the graph. 16 | type Euler struct { 17 | // Enter is called when a node a first visited. It may be nil. 18 | Enter func(n int) 19 | 20 | // Exit is called when all of the children of n have been 21 | // visited. It may be nil. 22 | // 23 | // Calls to Enter and Exit are always paired in nested order. 
24 | Exit func(n int) 25 | } 26 | 27 | // Visit performs a Euler tour over g starting at root and invokes the 28 | // callbacks on e. 29 | func (e Euler) Visit(g graph.Graph, root int) { 30 | visited := NewNodeMarks() 31 | var visit func(n int) 32 | visit = func(n int) { 33 | if e.Enter != nil { 34 | e.Enter(n) 35 | } 36 | visited.Mark(n) 37 | for _, succ := range g.Out(n) { 38 | if !visited.Test(succ) { 39 | visit(succ) 40 | } 41 | } 42 | if e.Exit != nil { 43 | e.Exit(n) 44 | } 45 | } 46 | visit(root) 47 | } 48 | -------------------------------------------------------------------------------- /graph/graphalg/order.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import ( 8 | "github.com/aclements/go-moremath/graph" 9 | ) 10 | 11 | // PreOrder returns the nodes of g visited in pre-order. 12 | func PreOrder(g graph.Graph, root int) []int { 13 | visited := NewNodeMarks() 14 | out := []int{} 15 | var visit func(n int) 16 | visit = func(n int) { 17 | out = append(out, n) 18 | visited.Mark(n) 19 | for _, succ := range g.Out(n) { 20 | if !visited.Test(succ) { 21 | visit(succ) 22 | } 23 | } 24 | } 25 | visit(root) 26 | 27 | return out 28 | } 29 | 30 | // PostOrder returns the nodes of g visited in post-order. 31 | func PostOrder(g graph.Graph, root int) []int { 32 | visited := NewNodeMarks() 33 | out := []int{} 34 | var visit func(n int) 35 | visit = func(n int) { 36 | visited.Mark(n) 37 | for _, succ := range g.Out(n) { 38 | if !visited.Test(succ) { 39 | visit(succ) 40 | } 41 | } 42 | out = append(out, n) 43 | } 44 | visit(root) 45 | 46 | return out 47 | } 48 | 49 | // Reverse reverses xs in place and returns the slice. 
// This is useful in conjunction with PreOrder and PostOrder to compute
// reverse post-order and reverse pre-order.
func Reverse(xs []int) []int {
	n := len(xs)
	// Swap each element in the first half with its mirror image.
	for i := 0; i < n/2; i++ {
		mirror := n - 1 - i
		xs[i], xs[mirror] = xs[mirror], xs[i]
	}
	return xs
}
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 The Go Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Google Inc. nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
// Tolerance state shared by Aeq and WantFunc. It is package-level and
// mutated by SetAeqDigits.
var (
	aeqDigits int     // number of digits most recently passed to SetAeqDigits
	aeqFactor float64 // 1 - 10^(1-aeqDigits); see Aeq for how it is applied
)

func init() {
	// Default tolerance.
	SetAeqDigits(8)
}

// SetAeqDigits sets the number of digits used by Aeq and returns the
// previous setting, so a caller can restore it with
//
//	defer SetAeqDigits(SetAeqDigits(n))
func SetAeqDigits(digits int) int {
	prev := aeqDigits
	aeqDigits, aeqFactor = digits, 1-math.Pow(10, float64(-digits+1))
	return prev
}

// Aeq returns true if expect and got are equal up to the current
// number of aeq digits set by SetAeqDigits (8 by default).
//
// Concretely, each value scaled by aeqFactor must not exceed the
// other. If both values are negative they are negated first; values of
// opposite sign never compare equal, and neither does NaN.
func Aeq(expect, got float64) bool {
	a, b := expect, got
	if a < 0 && b < 0 {
		a, b = -a, -b
	}
	if a*aeqFactor > b {
		return false
	}
	return b*aeqFactor <= a
}

// WantFunc calls f on every key of vals, in ascending key order, and
// reports a test error for each key whose result is neither Aeq to the
// mapped value nor NaN where NaN is wanted.
//
// name labels the failures. If it contains "%v" it is used as a format
// string with the input substituted; otherwise the label is "name(x)".
func WantFunc(t *testing.T, name string, f func(float64) float64, vals map[float64]float64) {
	inputs := make([]float64, 0, len(vals))
	for in := range vals {
		inputs = append(inputs, in)
	}
	// Map iteration order is random; sort so failures are reported stably.
	sort.Float64s(inputs)

	nameIsFormat := strings.Contains(name, "%v")
	for _, in := range inputs {
		want, got := vals[in], f(in)
		if (math.IsNaN(want) && math.IsNaN(got)) || Aeq(want, got) {
			continue
		}
		label := fmt.Sprintf("%s(%v)", name, in)
		if nameIsFormat {
			label = fmt.Sprintf(name, in)
		}
		t.Errorf("want %s=%v, got %v", label, want, got)
	}
}
// TestGammaInc compares GammaInc(1, x) and GammaInc(2, x) against
// reference values using WantFunc (dot-imported from mathtest).
func TestGammaInc(t *testing.T) {
	// Shape parameter a = 1.
	WantFunc(t, "GammaInc(1, %v)",
		func(x float64) float64 { return GammaInc(1, x) },
		map[float64]float64{
			0.1: 0.095162581964040441,
			0.2: 0.18126924692201815,
			0.3: 0.25918177931828207,
			0.4: 0.32967995396436056,
			0.5: 0.39346934028736652,
			0.6: 0.45118836390597361,
			0.7: 0.50341469620859047,
			0.8: 0.55067103588277833,
			0.9: 0.59343034025940089,
			1:   0.63212055882855778,
			2:   0.86466471676338730,
			3:   0.95021293163213605,
			4:   0.98168436111126578,
			5:   0.99326205300091452,
			6:   0.99752124782333362,
			7:   0.99908811803444553,
			8:   0.99966453737209748,
			9:   0.99987659019591335,
			10:  0.99995460007023750,
		})
	// Shape parameter a = 2.
	WantFunc(t, "GammaInc(2, %v)",
		func(x float64) float64 { return GammaInc(2, x) },
		map[float64]float64{
			1:  0.26424111765711528,
			2:  0.59399415029016167,
			3:  0.80085172652854419,
			4:  0.90842180555632912,
			5:  0.95957231800548726,
			6:  0.98264873476333547,
			7:  0.99270494427556388,
			8:  0.99698083634887735,
			9:  0.99876590195913317,
			10: 0.99950060077261271,
		})

	// TODO: Test strange values.
}
const smallFactLimit = 20 // 20! => 62 bits

// smallFact[i] holds i! for 0 <= i <= smallFactLimit; filled in by init.
var smallFact [smallFactLimit + 1]int64

func init() {
	smallFact[0] = 1
	for i := 1; i <= smallFactLimit; i++ {
		smallFact[i] = smallFact[i-1] * int64(i)
	}
}

// Choose returns the binomial coefficient of n and k.
//
// It returns 1 when k == 0 or k == n, and 0 when k < 0 or k > n. For
// n <= smallFactLimit the result is computed with exact integer
// arithmetic; for larger n it is exp of the log-gamma based lchoose and
// is therefore approximate.
func Choose(n, k int) float64 {
	switch {
	case k == 0 || k == n:
		return 1
	case k < 0 || n < k:
		return 0
	case n > smallFactLimit:
		return math.Exp(lchoose(n, k))
	}

	// Here 0 < k < n <= smallFactLimit, so k indexes smallFact safely.
	//
	// It's faster to do several integer multiplications than it is to
	// do an extra integer division. Remarkably, this is also faster
	// than pre-computing Pascal's triangle (presumably because this is
	// very cache efficient).
	falling := int64(1) // n * (n-1) * ... * (n-k+1)
	for f := int64(n); f > int64(n-k); f-- {
		falling *= f
	}
	return float64(falling / smallFact[k])
}

// Lchoose returns math.Log(Choose(n, k)) for 0 <= k <= n.
//
// Note that when k < 0 or k > n it returns NaN, not the -Inf that
// math.Log(Choose(n, k)) would give.
func Lchoose(n, k int) float64 {
	switch {
	case k == 0 || k == n:
		return 0
	case k < 0 || n < k:
		return math.NaN()
	}
	return lchoose(n, k)
}

// lchoose computes log(n! / (k! (n-k)!)) from log-gamma values. It does
// no argument checking; the Lgamma sign results are ignored.
func lchoose(n, k int) float64 {
	lgN, _ := math.Lgamma(float64(n + 1))
	lgK, _ := math.Lgamma(float64(k + 1))
	lgNK, _ := math.Lgamma(float64(n - k + 1))
	return lgN - lgK - lgNK
}
10 | // 11 | // If g is a weighted graph, each edge in the result receives the sum 12 | // of the weights of the combined edges in g. If g is not weighted, 13 | // each edge in g is assumed to have a weight of 1. 14 | func SimplifyMulti(g graph.Graph) graph.Weighted { 15 | gw, ok := g.(graph.Weighted) 16 | if !ok { 17 | gw = graph.WeightedUnit{g} 18 | } 19 | 20 | indexes := make([]int, gw.NumNodes()+1) 21 | var edges []int 22 | var weights []float64 23 | edgeMap := make(map[int]int) 24 | for n := range indexes[1:] { 25 | for k := range edgeMap { 26 | delete(edgeMap, k) 27 | } 28 | for i, o := range gw.Out(n) { 29 | if idx, ok := edgeMap[o]; ok { 30 | // Already have an edge. 31 | weights[idx] += gw.OutWeight(n, i) 32 | } else { 33 | edgeMap[o] = len(edges) 34 | edges = append(edges, o) 35 | weights = append(weights, gw.OutWeight(n, i)) 36 | } 37 | } 38 | indexes[n+1] = len(edges) 39 | } 40 | return &simplified{indexes, edges, weights} 41 | } 42 | 43 | type simplified struct { 44 | indexes []int 45 | edges []int 46 | weights []float64 47 | } 48 | 49 | func (g *simplified) NumNodes() int { 50 | return len(g.indexes) - 1 51 | } 52 | 53 | func (g *simplified) Out(n int) []int { 54 | return g.edges[g.indexes[n]:g.indexes[n+1]] 55 | } 56 | 57 | func (g *simplified) OutWeight(n, e int) float64 { 58 | return g.weights[g.indexes[n]:g.indexes[n+1]][e] 59 | } 60 | -------------------------------------------------------------------------------- /scale/ticks_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package scale 6 | 7 | import "testing" 8 | 9 | type testTicker struct{} 10 | 11 | func (testTicker) CountTicks(level int) int { 12 | c := 10 - level 13 | if c < 1 { 14 | c = 1 15 | } 16 | return c 17 | } 18 | 19 | func (t testTicker) TicksAtLevel(level int) interface{} { 20 | m := make([]float64, t.CountTicks(level)) 21 | for i := 0; i < len(m); i++ { 22 | m[i] = float64(i) 23 | } 24 | return m 25 | } 26 | 27 | func TestTicks(t *testing.T) { 28 | check := func(o TickOptions, want int) { 29 | wantL, wantOK := want, true 30 | if want == -999 { 31 | wantL, wantOK = 0, false 32 | } 33 | for _, guess := range []int{0, -50, 50} { 34 | l, ok := o.FindLevel(testTicker{}, guess) 35 | if l != wantL || ok != wantOK { 36 | t.Errorf("%+v.FindLevel with guess %v returned %v, %v; wanted %v, %v", o, guess, l, ok, wantL, wantOK) 37 | } 38 | } 39 | } 40 | 41 | // Argument sanity checking. 42 | check(TickOptions{}, -999) 43 | check(TickOptions{MinLevel: 10, MaxLevel: 9}, -999) 44 | 45 | // Just max constraint. 46 | check(TickOptions{Max: 1}, 9) 47 | check(TickOptions{Max: 6}, 4) 48 | check(TickOptions{Max: 20}, -10) 49 | 50 | // Max and level constraints. 51 | check(TickOptions{Max: 1, MaxLevel: 9}, 9) 52 | check(TickOptions{Max: 1, MaxLevel: 8}, -999) 53 | check(TickOptions{Max: 1, MinLevel: 9, MaxLevel: 1000}, 9) 54 | check(TickOptions{Max: 1, MinLevel: 10, MaxLevel: 1000}, 10) 55 | 56 | check(TickOptions{Max: 6, MaxLevel: 9}, 4) 57 | check(TickOptions{Max: 6, MaxLevel: 3}, -999) 58 | check(TickOptions{Max: 6, MinLevel: 10, MaxLevel: 11}, 10) 59 | } 60 | -------------------------------------------------------------------------------- /graph/graph.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
// Graph represents a directed graph. The nodes of the graph must be
// densely numbered starting at 0.
type Graph interface {
	// NumNodes returns the number of nodes in this graph.
	NumNodes() int

	// Out returns the nodes to which node i points.
	Out(i int) []int
}

// BiGraph extends Graph to graphs that represent both out-edges and
// in-edges.
type BiGraph interface {
	Graph

	// In returns the nodes which point to node i.
	In(i int) []int
}

// MakeBiGraph constructs a BiGraph from what may be a unidirectional
// Graph. If g is already a BiGraph, this returns g.
//
// Otherwise the in-edges are computed once, up front, by walking every
// node's out-edges; the result wraps g and does not track later changes
// to it.
func MakeBiGraph(g Graph) BiGraph {
	if bg, ok := g.(BiGraph); ok {
		return bg
	}

	n := g.NumNodes()
	in := make([][]int, n)
	for from := 0; from < n; from++ {
		for _, to := range g.Out(from) {
			in[to] = append(in[to], from)
		}
	}
	return &bigraph{Graph: g, preds: in}
}

// bigraph adds precomputed predecessor lists to a Graph.
type bigraph struct {
	Graph
	preds [][]int // preds[i] lists the nodes with an edge to i
}

// In returns the nodes which point to node i.
func (b *bigraph) In(i int) []int {
	return b.preds[i]
}

// IntGraph is a basic Graph g where g[i] is the list of out-edge
// indexes of node i.
type IntGraph [][]int

// NumNodes returns the number of nodes in g.
func (g IntGraph) NumNodes() int {
	return len(g)
}

// Out returns the nodes to which node i points.
func (g IntGraph) Out(i int) []int {
	return g[i]
}

// Edge identifies an edge in a graph. Given Graph g, Edge e
// represents edge g.Out(e.Node)[e.Edge].
type Edge struct {
	Node int // Node ID
	Edge int // Edge index
}
// TestLOESS_NIST fits LOESS(xs, ys, 1, 0.33) to a fixed 21-point data
// set and compares the fit at each x against reference values, to 7
// digits.
func TestLOESS_NIST(t *testing.T) {
	// LOWESS example from the NIST handbook.
	xs := []float64{0.5578196,
		2.0217271,
		2.5773252,
		3.4140288,
		4.3014084,
		4.7448394,
		5.1073781,
		6.5411662,
		6.7216176,
		7.2600583,
		8.1335874,
		9.1224379,
		11.9296663,
		12.3797674,
		13.2728619,
		14.2767453,
		15.3731026,
		15.6476637,
		18.5605355,
		18.5866354,
		18.7572812,
	}
	ys := []float64{18.63654,
		103.49646,
		150.35391,
		190.51031,
		208.70115,
		213.71135,
		228.49353,
		233.55387,
		234.55054,
		223.89225,
		227.68339,
		223.91982,
		168.01999,
		164.95750,
		152.61107,
		160.78742,
		168.55567,
		152.42658,
		221.70702,
		222.69040,
		243.18828,
	}

	// The inner SetAeqDigits(7) runs now and returns the previous
	// setting; the deferred outer call restores it when the test ends.
	defer mathtest.SetAeqDigits(mathtest.SetAeqDigits(7))
	// Several keys below carry fewer digits than the matching xs entry,
	// so the fit is evaluated close to, not exactly at, those points.
	mathtest.WantFunc(t, "LOESS", LOESS(xs, ys, 1, 0.33),
		map[float64]float64{
			0.5578196: 20.59302,
			2.0217271: 107.1603,
			2.5773252: 139.7674,
			3.4140288: 174.2630,
			4.301408:  207.2334,
			4.744839:  216.6616,
			5.107378:  220.5445,
			6.541166:  229.8607,
			6.721618:  229.8347,
			7.260058:  229.4301,
			8.133587:  226.6045,
			9.122438:  220.3904,
			11.929666: 172.3480,
			12.379767: 163.8417,
			13.272862: 161.8490,
			14.27675:  160.3351,
			15.37310:  160.1920,
			15.64766:  161.0556,
			18.56054:  227.3400,
			18.58664:  227.8985,
			18.75728:  231.5586,
		})
}
All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/aclements/go-moremath/internal/mathtest" 12 | "github.com/aclements/go-moremath/vec" 13 | ) 14 | 15 | var aeq = mathtest.Aeq 16 | var testFunc = mathtest.WantFunc 17 | 18 | func testDiscreteCDF(t *testing.T, name string, dist DiscreteDist) { 19 | // Build the expected CDF out of the PMF. 20 | l, h := dist.Bounds() 21 | s := dist.Step() 22 | want := map[float64]float64{l - 0.1: 0, h: 1} 23 | sum := 0.0 24 | for x := l; x < h; x += s { 25 | sum += dist.PMF(x) 26 | want[x] = sum 27 | want[x+s/2] = sum 28 | } 29 | 30 | testFunc(t, name, dist.CDF, want) 31 | } 32 | 33 | func testInvCDF(t *testing.T, dist Dist, bounded bool) { 34 | inv := InvCDF(dist) 35 | name := fmt.Sprintf("InvCDF(%+v)", dist) 36 | cdfName := fmt.Sprintf("CDF(%+v)", dist) 37 | 38 | // Test bounds. 39 | vals := map[float64]float64{-0.01: nan, 1.01: nan} 40 | if !bounded { 41 | vals[0] = -inf 42 | vals[1] = inf 43 | } 44 | testFunc(t, name, inv, vals) 45 | 46 | if bounded { 47 | lo, hi := inv(0), inv(1) 48 | vals := map[float64]float64{ 49 | lo - 0.01: 0, lo: 0, 50 | hi: 1, hi + 0.01: 1, 51 | } 52 | testFunc(t, cdfName, dist.CDF, vals) 53 | if got := dist.CDF(lo + 0.01); !(got > 0) { 54 | t.Errorf("%s(0)=%v, but %s(%v)=0", name, lo, cdfName, lo+0.01) 55 | } 56 | if got := dist.CDF(hi - 0.01); !(got < 1) { 57 | t.Errorf("%s(1)=%v, but %s(%v)=1", name, hi, cdfName, hi-0.01) 58 | } 59 | } 60 | 61 | // Test points between. 
// Vectorize returns a function g(xs) that applies f to each x in xs.
//
// f may be evaluated in parallel and in any order.
func Vectorize(f func(float64) float64) func(xs []float64) []float64 {
	mapped := func(xs []float64) []float64 {
		return Map(f, xs)
	}
	return mapped
}

// Map returns f(x) for each x in xs, as a new slice of the same length.
//
// f may be evaluated in parallel and in any order.
func Map(f func(float64) float64, xs []float64) []float64 {
	// TODO(austin) Parallelize
	out := make([]float64, len(xs))
	for i := 0; i < len(xs); i++ {
		out[i] = f(xs[i])
	}
	return out
}

// Linspace returns num values spaced evenly between lo and hi,
// inclusive. If num is 1, this returns an array consisting of lo.
func Linspace(lo, hi float64, num int) []float64 {
	out := make([]float64, num)
	if num == 1 {
		out[0] = lo
		return out
	}
	span, last := hi-lo, float64(num-1)
	for i := range out {
		out[i] = lo + float64(i)*span/last
	}
	return out
}

// Logspace returns num values spaced evenly on a logarithmic scale
// between base**lo and base**hi, inclusive.
func Logspace(lo, hi float64, num int, base float64) []float64 {
	raise := func(exp float64) float64 { return math.Pow(base, exp) }
	return Map(raise, Linspace(lo, hi, num))
}

// Sum returns the sum of xs, accumulated left to right. The sum of an
// empty slice is 0.
func Sum(xs []float64) float64 {
	var total float64
	for i := range xs {
		total += xs[i]
	}
	return total
}

// Concat returns the concatenation of its arguments. It does not
// modify its inputs.
func Concat(xss ...[]float64) []float64 {
	n := 0
	for i := range xss {
		n += len(xss[i])
	}
	// Allocate exactly once so the appends below never touch an input's
	// backing array.
	joined := make([]float64, 0, n)
	for _, xs := range xss {
		joined = append(joined, xs...)
	}
	return joined
}
// NodeMarks is a structure for marking nodes in a graph. It is a bit
// set indexed by node number: bit i%32 of marks[i/32] is the mark for
// node i. The zero value is an empty mark set.
type NodeMarks struct {
	marks []uint32
}

// Test returns whether node i is marked. Indexes that are negative or
// beyond the current storage are reported as unmarked.
func (m NodeMarks) Test(i int) bool {
	if i < 0 || i/32 >= len(m.marks) {
		return false
	}
	return m.marks[i/32]&(1<<uint(i%32)) != 0
}

// Mark marks node i, growing the underlying storage if necessary.
func (m *NodeMarks) Mark(i int) {
	if i/32 >= len(m.marks) {
		m.grow(i)
	}
	m.marks[i/32] |= 1 << uint(i%32)
}

// Unmark clears the mark on node i. Indexes beyond the current storage
// are already unmarked, so they are ignored.
func (m *NodeMarks) Unmark(i int) {
	if i/32 >= len(m.marks) {
		return
	}
	m.marks[i/32] &^= 1 << uint(i%32)
}

// grow enlarges m.marks so that it can hold the mark for node i,
// preserving all existing marks.
func (m *NodeMarks) grow(i int) {
	n := i/32 + 1
	// Round n up to a power of two. The loop must run while k is still
	// too small; with the comparison reversed k stays 1, the slice
	// shrinks to a single word and Mark indexes out of range.
	k := 1
	for k < n {
		k <<= 1
	}
	marks := make([]uint32, k)
	copy(marks, m.marks)
	m.marks = marks
}

// Next returns the index of the next set mark after mark i, or -1 if
// there are no set marks after i.
//
// This is typically used to loop over set marks like:
//
//	for i := m.Next(-1); i >= 0; i = m.Next(i) { ... }
func (m NodeMarks) Next(i int) int {
	i++
	if i < 0 {
		i = 0
	}
	// Start with the block containing i.
	if i/32 >= len(m.marks) {
		return -1
	}
	// Shift away the bits below i so only marks at or after i remain.
	b0 := m.marks[i/32] >> uint(i%32)
	if b0 != 0 {
		return i + bits.TrailingZeros32(b0)
	}
	// Scan the remaining blocks.
	for bi := (i / 32) + 1; bi < len(m.marks); bi++ {
		b := m.marks[bi]
		if b != 0 {
			return 32*bi + bits.TrailingZeros32(b)
		}
	}
	return -1
}

// NewNodeMarks returns a node mark set with no marks set and initial
// room for 1024 nodes.
func NewNodeMarks() *NodeMarks {
	// This is small enough to get inlined, allowing the initial
	// marks slice to get stack-allocated.
	return &NodeMarks{make([]uint32, 1024/32)}
}
// A Quantitative scale is an invertible function from some continuous
// input domain to an output range of [0, 1].
type Quantitative interface {
	// Map maps from a value x in the input domain to [0, 1]. If x
	// is outside the input domain and clamping is enabled, x will
	// first be clamped to the input domain.
	Map(x float64) float64

	// Unmap is the inverse of Map. That is, if x is in the input
	// domain or clamping is disabled, x = Unmap(Map(x)). If
	// clamping is enabled and y is outside [0,1], the results are
	// undefined.
	Unmap(y float64) float64

	// SetClamp sets the clamping mode of this scale.
	SetClamp(bool)

	// Ticks returns major and minor ticks that satisfy the
	// constraints given by o. These ticks will have "nice" values
	// within the input domain. Both arrays are sorted in
	// ascending order and minor includes ticks in major.
	Ticks(o TickOptions) (major, minor []float64)

	// Nice expands the input domain of this scale to "nice"
	// values for covering the input domain satisfying the
	// constraints given by o. After calling Nice(o), the first
	// and last major ticks returned by Ticks(o) will equal the
	// lower and upper bounds of the input domain.
	Nice(o TickOptions)

	// A Quantitative scale is also a Ticker.
	Ticker
}

// A QQ maps from a source Quantitative scale to a destination
// Quantitative scale.
type QQ struct {
	Src, Dest Quantitative
}

// Map maps from a value x in the source scale's input domain to a
// value y in the destination scale's input domain. It does so by
// mapping x through Src and then unmapping the result through Dest.
func (q QQ) Map(x float64) float64 {
	return q.Dest.Unmap(q.Src.Map(x))
}

// Unmap is the reverse of Map: it maps from a value in the destination
// scale's input domain (the parameter x) to a value in the source
// scale's input domain, by mapping x through Dest and then unmapping
// the result through Src.
func (q QQ) Unmap(x float64) float64 {
	return q.Src.Unmap(q.Dest.Map(x))
}
// TODO: Implement histograms on top of scales.

// Histogram is a set of sample counts grouped into bins, together with
// counts of the samples that fell below and above the binned range.
type Histogram interface {
	// Add adds a sample with value x to histogram h.
	Add(x float64)

	// Counts returns the number of samples less than the lowest
	// bin, a slice of the number of samples in each bin,
	// and the number of samples greater than the highest bin.
	Counts() (under uint, counts []uint, over uint)

	// BinToValue returns the value that would appear at the given
	// bin index.
	//
	// For integral values of bin, BinToValue returns the lower
	// bound of bin. That is, a sample value x will be in bin if
	// bin is integral and
	//
	//	BinToValue(bin) <= x < BinToValue(bin + 1)
	//
	// For non-integral values of bin, BinToValue interpolates
	// between the lower and upper bounds of math.Floor(bin).
	//
	// BinToValue is undefined if bin > 1 + the number of bins.
	BinToValue(bin float64) float64
}

// HistogramQuantile returns the x such that n*q samples in hist are
// <= x, assuming values are distributed within each bin according to
// hist's distribution.
//
// If the q'th sample falls below the lowest bin or above the highest
// bin, returns NaN. It also returns NaN if q is negative or NaN.
func HistogramQuantile(hist Histogram, q float64) float64 {
	// Reject negative and NaN q up front: converting such a product to
	// uint below has no defined result.
	if !(q >= 0) {
		return math.NaN()
	}

	under, counts, over := hist.Counts()
	total := under + over
	for _, count := range counts {
		total += count
	}

	// goal is the rank of the wanted sample among all samples,
	// including those below the lowest bin.
	goal := uint(float64(total) * q)
	if goal <= under || goal > total-over {
		return math.NaN()
	}
	// The bins do not contain the under samples, so make the rank
	// relative to the first bin. Now 1 <= goal <= sum(counts).
	goal -= under

	for bin, count := range counts {
		if count > goal {
			return hist.BinToValue(float64(bin) + float64(goal)/float64(count))
		}
		goal -= count
	}
	// Every bin was consumed, so the goal sits exactly on the upper edge
	// of the last bin.
	return hist.BinToValue(float64(len(counts)))
}

// HistogramIQR returns the interquartile range of the samples in
// hist, that is, the 0.75 quantile minus the 0.25 quantile. The result
// is NaN if either quantile is NaN.
func HistogramIQR(hist Histogram) float64 {
	return HistogramQuantile(hist, 0.75) - HistogramQuantile(hist, 0.25)
}
// LogHist is a Histogram with logarithmically-spaced bins.
type LogHist struct {
	b         int     // logarithm base
	m         float64 // bins per factor of b
	mOverLogb float64 // m / ln(b), so bin index = floor(mOverLogb * ln(x))
	low, high uint    // samples below the first bin / at or above the last bin's upper bound
	bins      []uint
}

// NewLogHist returns an empty logarithmic histogram with bins for
// integral values of m * log_b(x) up to x = max.
//
// The first bin starts at x = 1. If max is too small for even one bin,
// the histogram has no bins and every sample is counted as under or
// over.
func NewLogHist(b int, m float64, max float64) *LogHist {
	// TODO(austin) Minimum value as well? If the samples are
	// actually integral, having fractional bin boundaries can
	// mess up smoothing.
	mOverLogb := m / math.Log(float64(b))
	nbins := int(math.Ceil(mOverLogb * math.Log(max)))
	if nbins < 0 {
		// max < 1: a negative length would make the allocation panic.
		nbins = 0
	}
	return &LogHist{b: b, m: m, mOverLogb: mOverLogb, low: 0, high: 0, bins: make([]uint, nbins)}
}

// bin returns the index of the bin containing x. It returns -1 if x is
// below the first bin (including x <= 0 and NaN) and len(h.bins) if x
// is above the last bin.
func (h *LogHist) bin(x float64) int {
	// Floor, not truncation: a plain int conversion rounds toward zero
	// and would put values just below 1 into bin 0.
	idx := math.Floor(h.mOverLogb * math.Log(x))
	switch {
	case math.IsNaN(idx) || idx < 0:
		return -1
	case idx >= float64(len(h.bins)):
		// Clamp before converting; this also covers +Inf.
		return len(h.bins)
	}
	return int(idx)
}

// Add adds a sample with value x to h.
func (h *LogHist) Add(x float64) {
	bin := h.bin(x)
	switch {
	case bin < 0:
		h.low++
	case bin >= len(h.bins):
		h.high++
	default:
		h.bins[bin]++
	}
}

// Counts returns the number of samples below the first bin, the
// per-bin counts, and the number of samples above the last bin. The
// returned slice is h's own storage, not a copy.
func (h *LogHist) Counts() (uint, []uint, uint) {
	return h.low, h.bins, h.high
}

// BinToValue returns b**(bin/m), the value at the given (possibly
// fractional) bin index.
func (h *LogHist) BinToValue(bin float64) float64 {
	return math.Pow(float64(h.b), bin/h.m)
}

// At returns the count of the bin containing x, or 0 if x falls
// outside the binned range.
func (h *LogHist) At(x float64) float64 {
	bin := h.bin(x)
	if bin < 0 || bin >= len(h.bins) {
		return 0
	}
	return float64(h.bins[bin])
}

// Bounds returns the range of values covered by the recorded samples:
// from the lower edge of the first non-empty bin to the upper edge of
// the last non-empty bin. If any sample fell below (above) the binned
// range, or if no bin has a sample, the lower (upper) bound is that of
// the whole binned range.
func (h *LogHist) Bounds() (float64, float64) {
	// XXX Plot will plot this on a linear axis. Maybe this
	// should be able to return the natural axis?
	// Maybe then we could also give it the bins for the tics.
	lowbin := 0
	if h.low == 0 {
		for bin, count := range h.bins {
			if count > 0 {
				lowbin = bin
				break
			}
		}
	}
	highbin := len(h.bins)
	if h.high == 0 {
		// Scan from the top for the last non-empty bin; its upper
		// edge is the lower edge of the following bin.
		for bin := len(h.bins) - 1; bin >= 0; bin-- {
			if h.bins[bin] > 0 {
				highbin = bin + 1
				break
			}
		}
	}
	return h.BinToValue(float64(lowbin)), h.BinToValue(float64(highbin))
}
4 | 5 | package stats 6 | 7 | import ( 8 | "math" 9 | 10 | "github.com/aclements/go-moremath/mathx" 11 | ) 12 | 13 | // BinomialDist is a binomial distribution. 14 | type BinomialDist struct { 15 | // N is the number of independent Bernoulli trials. N >= 0. 16 | // 17 | // If N=1, this is equivalent to the Bernoulli distribution. 18 | N int 19 | 20 | // P is the probability of success in each trial. 0 <= P <= 1. 21 | P float64 22 | } 23 | 24 | // PMF is the probability of getting exactly int(k) successes in d.N 25 | // independent Bernoulli trials with probability d.P. 26 | func (d BinomialDist) PMF(k float64) float64 { 27 | ki := int(math.Floor(k)) 28 | if ki < 0 || ki > d.N { 29 | return 0 30 | } 31 | return mathx.Choose(d.N, ki) * math.Pow(d.P, float64(ki)) * math.Pow(1-d.P, float64(d.N-ki)) 32 | } 33 | 34 | // CDF is the probability of getting k or fewer successes in d.N 35 | // independent Bernoulli trials with probability d.P. 36 | func (d BinomialDist) CDF(k float64) float64 { 37 | k = math.Floor(k) 38 | ki := int(k) 39 | if ki < 0 { 40 | return 0 41 | } else if ki >= d.N { 42 | return 1 43 | } 44 | 45 | return mathx.BetaInc(1-d.P, float64(d.N-ki), k+1) 46 | } 47 | 48 | func (d BinomialDist) Bounds() (float64, float64) { 49 | return 0, float64(d.N) 50 | } 51 | 52 | func (d BinomialDist) Step() float64 { 53 | return 1 54 | } 55 | 56 | func (d BinomialDist) Mean() float64 { 57 | return float64(d.N) * d.P 58 | } 59 | 60 | func (d BinomialDist) Variance() float64 { 61 | return float64(d.N) * d.P * (1 - d.P) 62 | } 63 | 64 | // NormalApprox returns a normal distribution approximation of 65 | // binomial distribution d. 66 | // 67 | // Because the binomial distribution is discrete and the normal 68 | // distribution is continuous, the caller must apply a continuity 69 | // correction when using this approximation. 
Specifically, if b is the 70 | // binomial distribution and n is the normal approximation, operations 71 | // map as follows: 72 | // 73 | // b.PMF(k) => n.CDF(k+0.5) - n.CDF(k-0.5) 74 | // b.CDF(k) => n.CDF(k+0.5) 75 | func (d BinomialDist) NormalApprox() NormalDist { 76 | return NormalDist{Mu: d.Mean(), Sigma: math.Sqrt(d.Variance())} 77 | } 78 | -------------------------------------------------------------------------------- /cmd/dist/dist.go: -------------------------------------------------------------------------------- 1 | // dist reads newline-separated numbers and describes their distribution. 2 | // 3 | // For example, 4 | // 5 | // $ seq 1 20 | grep -v 1 | dist 6 | // N 9 sum 64 mean 7.11111 gmean 5.78509 std dev 5.34894 variance 28.6111 7 | // 8 | // min 2 9 | // 1%ile 2 10 | // 5%ile 2 11 | // 25%ile 3.66667 12 | // median 6 13 | // 75%ile 8.33333 14 | // 95%ile 20 15 | // 99%ile 20 16 | // max 20 17 | // 18 | // ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣠⠖⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠦⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡖ 0.1 19 | // ⠀⠀⠀⠀⠀⠀⠀⢀⣠⠴⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠲⢤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇ 20 | // ⠠⠤⠤⠤⠤⠴⠒⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠲⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠴⠒⠋⠉⠉⠀⠀⠉⠉⠙⠒⠦⠤⠤⠤⠤⠄⠧ 0.0 21 | // ⠈⠉⠉⠉⠉⠙⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠁ 22 | // 0 10 20 23 | package main 24 | 25 | import ( 26 | "bufio" 27 | "fmt" 28 | "io" 29 | "math" 30 | "os" 31 | "strconv" 32 | "strings" 33 | 34 | "github.com/aclements/go-moremath/stats" 35 | ) 36 | 37 | func main() { 38 | s := readInput(os.Stdin) 39 | if len(s.Xs) == 0 { 40 | fmt.Fprintln(os.Stderr, "no input") 41 | return 42 | } 43 | s.Sort() 44 | 45 | fmt.Printf("N %d sum %.6g mean %.6g", len(s.Xs), s.Sum(), s.Mean()) 46 | gmean := s.GeoMean() 47 | if !math.IsNaN(gmean) { 48 | fmt.Printf(" gmean %.6g", gmean) 49 | } 50 | fmt.Printf(" std dev %.6g variance %.6g\n", s.StdDev(), s.Variance()) 51 | fmt.Println() 52 | 53 | // Quartiles and tails. 
54 | labels := map[int]string{0: "min", 50: "median", 100: "max"} 55 | for _, p := range []int{0, 1, 5, 25, 50, 75, 95, 99, 100} { 56 | label, ok := labels[p] 57 | if !ok { 58 | label = fmt.Sprintf("%d%%ile", p) 59 | } 60 | fmt.Printf("%8s %.6g\n", label, s.Quantile(float64(p)/100)) 61 | } 62 | fmt.Println() 63 | 64 | // Kernel density estimate. 65 | kde := &stats.KDE{Sample: s} 66 | FprintPDF(os.Stdout, kde) 67 | } 68 | 69 | func readInput(r io.Reader) (sample stats.Sample) { 70 | scanner := bufio.NewScanner(r) 71 | for scanner.Scan() { 72 | l := scanner.Text() 73 | l = strings.TrimSpace(l) 74 | if l == "" { 75 | continue 76 | } 77 | value, err := strconv.ParseFloat(l, 64) 78 | if err != nil { 79 | fmt.Fprintln(os.Stderr, err) 80 | os.Exit(1) 81 | } 82 | 83 | sample.Xs = append(sample.Xs, value) 84 | } 85 | if err := scanner.Err(); err != nil { 86 | fmt.Fprintln(os.Stderr, err) 87 | os.Exit(1) 88 | } 89 | 90 | return 91 | } 92 | -------------------------------------------------------------------------------- /graph/subgraph_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package graph 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "testing" 11 | ) 12 | 13 | var graph1 = IntGraph{ 14 | 0: {0, 1}, 15 | 1: {0, 2}, 16 | 2: {3, 4}, 17 | 3: {4}, 18 | 4: {}, 19 | } 20 | 21 | func TestSubgraphKeep(t *testing.T) { 22 | var want = IntGraph{ 23 | 0: {}, // Was node 4 24 | 1: {2}, // Was node 1 25 | 2: {1, 2}, // Was node 0 26 | } 27 | g2 := SubgraphKeep(graph1, []int{4, 1, 0}, []Edge{{0, 1}, {0, 0}, {1, 0}}) 28 | if !Equal(want, g2) { 29 | t.Fatalf("want:\n%sgot:\n%s", pgraph(want), pgraph(g2)) 30 | } 31 | 32 | nodeMap := g2.NodeMap(func(node int) interface{} { return node }) 33 | newToOldNode := []int{4, 1, 0} 34 | for newNode, oldNode := range newToOldNode { 35 | if got := nodeMap(newNode).(int); got != oldNode { 36 | t.Errorf("nodeMap(%d) = %d, want %d", newNode, got, oldNode) 37 | } 38 | } 39 | 40 | edgeMap := g2.EdgeMap(func(node, edge int) interface{} { return Edge{node, edge} }) 41 | newToOldEdge := map[Edge]Edge{ 42 | Edge{1, 0}: Edge{1, 0}, 43 | Edge{2, 0}: Edge{0, 1}, 44 | Edge{2, 1}: Edge{0, 0}, 45 | } 46 | for newEdge, oldEdge := range newToOldEdge { 47 | if got := edgeMap(newEdge.Node, newEdge.Edge); got != oldEdge { 48 | t.Errorf("edgeMap(%d, %d) = %v, want %v", newEdge.Node, newEdge.Edge, got, oldEdge) 49 | } 50 | } 51 | } 52 | 53 | func TestSubgraphRemove(t *testing.T) { 54 | // Test automatic edge removal. 55 | var want = IntGraph{ 56 | 0: {0}, // Was node 0 57 | 1: {2}, // Was node 3 58 | 2: {}, // Was node 4 59 | } 60 | g2 := SubgraphRemove(graph1, []int{1, 2}, nil) 61 | if !Equal(want, g2) { 62 | t.Fatalf("want:\n%sgot:\n%s", pgraph(want), pgraph(g2)) 63 | } 64 | 65 | // Test edge removal. 
// GammaInc returns the value of the incomplete gamma function (also
// known as the regularized gamma function):
//
//	P(a, x) = 1 / Γ(a) * ∫₀ˣ exp(-t) t**(a-1) dt
//
// It returns NaN if a <= 0, x < 0, or either argument is NaN. For
// finite a, GammaInc(a, +Inf) is 1.
//
// It panics if the underlying expansion fails to converge within its
// iteration budget.
func GammaInc(a, x float64) float64 {
	// Based on Numerical Recipes in C, section 6.2.

	if a <= 0 || x < 0 || math.IsNaN(a) || math.IsNaN(x) {
		return math.NaN()
	}
	if math.IsInf(x, 1) && !math.IsInf(a, 1) {
		// The integral covers all of Γ(a). Answer directly: at
		// x = +Inf the continued fraction evaluates 0*Inf = NaN on
		// every step and would never converge.
		return 1
	}

	if x < a+1 {
		// Use the series representation, which converges more
		// rapidly in this range.
		return gammaIncSeries(a, x)
	}
	// Use the continued fraction representation.
	return 1 - gammaIncCF(a, x)
}

// GammaIncComp returns the complement of the incomplete gamma
// function 1 - GammaInc(a, x). This is more numerically stable for
// values near 0.
//
// It returns NaN if a <= 0, x < 0, or either argument is NaN. For
// finite a, GammaIncComp(a, +Inf) is 0.
func GammaIncComp(a, x float64) float64 {
	if a <= 0 || x < 0 || math.IsNaN(a) || math.IsNaN(x) {
		return math.NaN()
	}
	if math.IsInf(x, 1) && !math.IsInf(a, 1) {
		// Nothing of the integral remains beyond +Inf; see GammaInc
		// for why this case cannot go through gammaIncCF.
		return 0
	}

	if x < a+1 {
		return 1 - gammaIncSeries(a, x)
	}
	return gammaIncCF(a, x)
}

// gammaIncSeries evaluates P(a, x) using the series
//
//	exp(-x) x**a / Γ(a) * Σₙ x**n / (a (a+1) ... (a+n))
//
// The caller must ensure a > 0 and x >= 0. It panics if the series
// has not converged after maxIterations terms.
func gammaIncSeries(a, x float64) float64 {
	const maxIterations = 200
	const epsilon = 3e-14

	if x == 0 {
		// P(a, 0) = 0; returning early also avoids Log(0) below.
		return 0
	}

	ap := a
	del := 1 / a // current term of the series
	sum := del
	for n := 0; n < maxIterations; n++ {
		ap++
		del *= x / ap
		sum += del
		if math.Abs(del) < math.Abs(sum)*epsilon {
			// a > 0, so Γ(a) > 0 and the sign from Lgamma is
			// not needed.
			lga, _ := math.Lgamma(a)
			return sum * math.Exp(-x+a*math.Log(x)-lga)
		}
	}
	panic("a too large; failed to converge")
}

// gammaIncCF evaluates the complement 1 - P(a, x) using a continued
// fraction. The caller must ensure a > 0 and that x is finite and
// positive. It panics if the fraction has not converged after
// maxIterations steps.
func gammaIncCF(a, x float64) float64 {
	const maxIterations = 200
	const epsilon = 3e-14

	// raiseZero keeps intermediate values away from exactly zero so
	// that the divisions below stay finite.
	raiseZero := func(z float64) float64 {
		if math.Abs(z) < math.SmallestNonzeroFloat64 {
			return math.SmallestNonzeroFloat64
		}
		return z
	}

	b := x + 1 - a
	c := math.MaxFloat64
	d := 1 / b
	h := d // running value of the continued fraction

	for i := 1; i <= maxIterations; i++ {
		an := -float64(i) * (float64(i) - a)
		b += 2
		d = raiseZero(an*d + b)
		c = raiseZero(b + an/c)
		d = 1 / d
		del := d * c
		h *= del
		if math.Abs(del-1) < epsilon {
			// a > 0, so Γ(a) > 0 and the sign from Lgamma is
			// not needed.
			lga, _ := math.Lgamma(a)
			return math.Exp(-x+a*math.Log(x)-lga) * h
		}
	}
	panic("a too large; failed to converge")
}
4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | ) 11 | 12 | func TestKDEOneSample(t *testing.T) { 13 | x := float64(5) 14 | 15 | // Unweighted, fixed bandwidth 16 | kde := KDE{ 17 | Sample: Sample{Xs: []float64{x}}, 18 | Kernel: GaussianKernel, 19 | Bandwidth: 1, 20 | } 21 | if e, g := StdNormal.PDF(0), kde.PDF(x); !aeq(e, g) { 22 | t.Errorf("bad PDF value at sample: expected %g, got %g", e, g) 23 | } 24 | if e, g := 0.0, kde.PDF(-10000); !aeq(e, g) { 25 | t.Errorf("bad PDF value at low tail: expected %g, got %g", e, g) 26 | } 27 | if e, g := 0.0, kde.PDF(10000); !aeq(e, g) { 28 | t.Errorf("bad PDF value at high tail: expected %g, got %g", e, g) 29 | } 30 | 31 | if e, g := 0.5, kde.CDF(x); !aeq(e, g) { 32 | t.Errorf("bad CDF value at sample: expected %g, got %g", e, g) 33 | } 34 | if e, g := 0.0, kde.CDF(-10000); !aeq(e, g) { 35 | t.Errorf("bad CDF value at low tail: expected %g, got %g", e, g) 36 | } 37 | if e, g := 1.0, kde.CDF(10000); !aeq(e, g) { 38 | t.Errorf("bad CDF value at high tail: expected %g, got %g", e, g) 39 | } 40 | 41 | low, high := kde.Bounds() 42 | if e, g := x-2, low; e < g { 43 | t.Errorf("bad low bound: expected %g, got %g", e, g) 44 | } 45 | if e, g := x+2, high; e > g { 46 | t.Errorf("bad high bound: expected %g, got %g", e, g) 47 | } 48 | 49 | kde = KDE{ 50 | Sample: Sample{Xs: []float64{x}}, 51 | Kernel: EpanechnikovKernel, 52 | Bandwidth: 2, 53 | } 54 | testFunc(t, fmt.Sprintf("%+v.PDF", kde), kde.PDF, map[float64]float64{ 55 | x - 2: 0, 56 | x - 1: 0.5625 / 2, 57 | x: 0.75 / 2, 58 | x + 1: 0.5625 / 2, 59 | x + 2: 0, 60 | }) 61 | testFunc(t, fmt.Sprintf("%+v.CDF", kde), kde.CDF, map[float64]float64{ 62 | x - 2: 0, 63 | x - 1: 0.15625, 64 | x: 0.5, 65 | x + 1: 0.84375, 66 | x + 2: 1, 67 | }) 68 | } 69 | 70 | func TestKDETwoSamples(t *testing.T) { 71 | kde := KDE{ 72 | Sample: Sample{Xs: []float64{1, 3}}, 73 | Kernel: GaussianKernel, 74 | Bandwidth: 2, 75 | } 76 | testFunc(t, "PDF", kde.PDF, map[float64]float64{ 77 | 0: 
0.120395730, 78 | 1: 0.160228251, 79 | 2: 0.176032663, 80 | 3: 0.160228251, 81 | 4: 0.120395730}) 82 | 83 | testFunc(t, "CDF", kde.CDF, map[float64]float64{ 84 | 0: 0.187672369, 85 | 1: 0.329327626, 86 | 2: 0.5, 87 | 3: 0.670672373, 88 | 4: 0.812327630}) 89 | } 90 | -------------------------------------------------------------------------------- /stats/ttest_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import "testing" 8 | 9 | func TestTTest(t *testing.T) { 10 | s1 := Sample{Xs: []float64{2, 1, 3, 4}} 11 | s2 := Sample{Xs: []float64{6, 5, 7, 9}} 12 | 13 | check := func(want, got *TTestResult) { 14 | if want.N1 != got.N1 || want.N2 != got.N2 || 15 | !aeq(want.T, got.T) || !aeq(want.DoF, got.DoF) || 16 | want.AltHypothesis != got.AltHypothesis || 17 | !aeq(want.P, got.P) { 18 | t.Errorf("want %+v, got %+v", want, got) 19 | } 20 | } 21 | check3 := func(test func(alt LocationHypothesis) (*TTestResult, error), n1, n2 int, t, dof float64, pless, pdiff, pgreater float64) { 22 | want := &TTestResult{N1: n1, N2: n2, T: t, DoF: dof} 23 | 24 | want.AltHypothesis = LocationLess 25 | want.P = pless 26 | got, _ := test(want.AltHypothesis) 27 | check(want, got) 28 | 29 | want.AltHypothesis = LocationDiffers 30 | want.P = pdiff 31 | got, _ = test(want.AltHypothesis) 32 | check(want, got) 33 | 34 | want.AltHypothesis = LocationGreater 35 | want.P = pgreater 36 | got, _ = test(want.AltHypothesis) 37 | check(want, got) 38 | } 39 | 40 | check3(func(alt LocationHypothesis) (*TTestResult, error) { 41 | return TwoSampleTTest(s1, s1, alt) 42 | }, 4, 4, 0, 6, 43 | 0.5, 1, 0.5) 44 | check3(func(alt LocationHypothesis) (*TTestResult, error) { 45 | return TwoSampleWelchTTest(s1, s1, alt) 46 | }, 4, 4, 0, 6, 47 | 0.5, 1, 0.5) 48 | 49 | check3(func(alt 
// lgamma returns the natural logarithm of |Γ(x)|, discarding the sign
// that math.Lgamma also reports.
func lgamma(x float64) float64 {
	y, _ := math.Lgamma(x)
	return y
}

// Beta returns the value of the complete beta function B(a, b).
//
// The sign of the result follows the signs of the three gamma
// factors, so negative non-integer arguments are handled.
func Beta(a, b float64) float64 {
	// B(a,b) = Γ(a)Γ(b) / Γ(a+b)
	//
	// Work in log space for range, and carry the signs separately:
	// math.Lgamma returns ln|Γ| plus the sign of Γ, and dividing by
	// a sign is the same as multiplying by it.
	la, sa := math.Lgamma(a)
	lb, sb := math.Lgamma(b)
	lab, sab := math.Lgamma(a + b)
	return float64(sa*sb*sab) * math.Exp(la+lb-lab)
}

// BetaInc returns the value of the regularized incomplete beta
// function Iₓ(a, b) = 1 / B(a, b) * ∫₀ˣ tᵃ⁻¹ (1-t)ᵇ⁻¹ dt.
//
// This is not to be confused with the "incomplete beta function",
// which can be computed as BetaInc(x, a, b)*Beta(a, b).
//
// If x < 0 or x > 1, or if any argument is NaN, returns NaN.
func BetaInc(x, a, b float64) float64 {
	// Based on Numerical Recipes in C, section 6.4. This uses the
	// continued fraction definition of I:
	//
	//   (xᵃ*(1-x)ᵇ)/(a*B(a,b)) * (1/(1+(d₁/(1+(d₂/(1+...))))))
	//
	// where B(a,b) is the beta function and
	//
	//   d_{2m+1} = -(a+m)(a+b+m)x/((a+2m)(a+2m+1))
	//   d_{2m}   = m(b-m)x/((a+2m-1)(a+2m))
	if math.IsNaN(x) || math.IsNaN(a) || math.IsNaN(b) {
		// A NaN would otherwise reach betacf, which can never
		// satisfy its convergence test and would panic.
		return math.NaN()
	}
	if x < 0 || x > 1 {
		return math.NaN()
	}
	bt := 0.0
	if 0 < x && x < 1 {
		// Compute the coefficient before the continued
		// fraction.
		bt = math.Exp(lgamma(a+b) - lgamma(a) - lgamma(b) +
			a*math.Log(x) + b*math.Log(1-x))
	}
	if x < (a+1)/(a+b+2) {
		// Compute continued fraction directly.
		return bt * betacf(x, a, b) / a
	}
	// Compute continued fraction after symmetry transform.
	return 1 - bt*betacf(1-x, b, a)/b
}

// betacf is the continued fraction component of the regularized
// incomplete beta function Iₓ(a, b).
//
// It panics if the fraction has not converged after maxIterations
// double steps.
func betacf(x, a, b float64) float64 {
	const maxIterations = 200
	const epsilon = 3e-14

	// raiseZero keeps intermediate values away from exactly zero so
	// that the divisions below stay finite.
	raiseZero := func(z float64) float64 {
		if math.Abs(z) < math.SmallestNonzeroFloat64 {
			return math.SmallestNonzeroFloat64
		}
		return z
	}

	c := 1.0
	d := 1 / raiseZero(1-(a+b)*x/(a+1))
	h := d // running value of the continued fraction
	for m := 1; m <= maxIterations; m++ {
		mf := float64(m)

		// Even step of the recurrence.
		numer := mf * (b - mf) * x / ((a + 2*mf - 1) * (a + 2*mf))
		d = 1 / raiseZero(1+numer*d)
		c = raiseZero(1 + numer/c)
		h *= d * c

		// Odd step of the recurrence.
		numer = -(a + mf) * (a + b + mf) * x / ((a + 2*mf) * (a + 2*mf + 1))
		d = 1 / raiseZero(1+numer*d)
		c = raiseZero(1 + numer/c)
		hfac := d * c
		h *= hfac

		// Converged once the latest factor no longer moves h.
		if math.Abs(hfac-1) < epsilon {
			return h
		}
	}
	panic("betainc: a or b too big; failed to converge")
}
// StreamStats tracks basic statistics for a stream of data in O(1)
// space.
//
// StreamStats should be initialized to its zero value.
type StreamStats struct {
	// Count is the number of samples added.
	Count uint
	// Total is the sum of all samples. Min and Max are the smallest
	// and largest samples; they are only meaningful when Count > 0.
	Total, Min, Max float64

	// Numerically stable online mean
	mean          float64
	meanOfSquares float64

	// Online variance: running sum of squared deviations from the
	// mean.
	vM2 float64
}

// Add updates s's statistics with sample value x.
func (s *StreamStats) Add(x float64) {
	s.Total += x
	if s.Count == 0 {
		s.Min, s.Max = x, x
	} else {
		if x < s.Min {
			s.Min = x
		}
		if x > s.Max {
			s.Max = x
		}
	}
	s.Count++

	// Update online mean, mean of squares, and variance. Online
	// variance based on Wikipedia's presentation ("Algorithms for
	// calculating variance") of Knuth's formulation of Welford
	// 1962.
	delta := x - s.mean
	s.mean += delta / float64(s.Count)
	s.meanOfSquares += (x*x - s.meanOfSquares) / float64(s.Count)
	s.vM2 += delta * (x - s.mean)
}

// Weight returns the number of samples added to s as a float64.
func (s *StreamStats) Weight() float64 {
	return float64(s.Count)
}

// Mean returns the arithmetic mean of the samples added to s, or 0
// if there are none.
func (s *StreamStats) Mean() float64 {
	return s.mean
}

// Variance returns the sample variance (with the Count-1 divisor) of
// the samples added to s. It is NaN with fewer than two samples.
func (s *StreamStats) Variance() float64 {
	if s.Count < 2 {
		// Count is unsigned, so Count-1 would wrap around to a
		// huge value when Count is 0.
		return math.NaN()
	}
	return s.vM2 / float64(s.Count-1)
}

// StdDev returns the square root of Variance.
func (s *StreamStats) StdDev() float64 {
	return math.Sqrt(s.Variance())
}

// RMS returns the root mean square of the samples added to s.
func (s *StreamStats) RMS() float64 {
	return math.Sqrt(s.meanOfSquares)
}

// Combine updates s's statistics as if all samples added to o were
// added to s. o is not modified.
func (s *StreamStats) Combine(o *StreamStats) {
	if o.Count == 0 {
		// Nothing to merge. o's zero-valued Min and Max must not
		// be compared against real data.
		return
	}
	if s.Count == 0 {
		// s holds no data, so its zero-valued Min and Max are
		// placeholders; adopt o's statistics wholesale.
		*s = *o
		return
	}

	count := s.Count + o.Count

	// Compute combined online variance statistics
	delta := o.mean - s.mean
	mean := s.mean + delta*float64(o.Count)/float64(count)
	vM2 := s.vM2 + o.vM2 + delta*delta*float64(s.Count)*float64(o.Count)/float64(count)

	s.Count = count
	s.Total += o.Total
	if o.Min < s.Min {
		s.Min = o.Min
	}
	if o.Max > s.Max {
		s.Max = o.Max
	}
	s.mean = mean
	s.meanOfSquares += (o.meanOfSquares - s.meanOfSquares) * float64(o.Count) / float64(count)
	s.vM2 = vM2
}

// String returns a one-line summary of s's statistics.
func (s *StreamStats) String() string {
	return fmt.Sprintf("Count=%d Total=%g Min=%g Mean=%g RMS=%g Max=%g StdDev=%g", s.Count, s.Total, s.Min, s.Mean(), s.RMS(), s.Max, s.StdDev())
}
// bisectBool implements the bisection method on a boolean function.
// It returns x1, x2 ∈ [low, high], x1 < x2 such that f(x1) != f(x2)
// and x2 - x1 <= xtol.
//
// If the bracket can no longer be split because its midpoint rounds
// to one of its endpoints, the current bracket is returned even if it
// is wider than xtol.
//
// If f(low) == f(high), it panics.
func bisectBool(f func(float64) bool, low, high, xtol float64) (x1, x2 float64) {
	flow, fhigh := f(low), f(high)
	if flow == fhigh {
		panic(fmt.Sprintf("root of f is not bracketed by [low, high]; f(%g)=%v f(%g)=%v", low, flow, high, fhigh))
	}

	// Invariant: f(low) == flow and f(high) != flow, so only the
	// endpoints themselves need updating as the bracket shrinks.
	for high-low > xtol {
		mid := (high + low) / 2
		if mid == high || mid == low {
			// Out of floating-point resolution.
			break
		}
		if f(mid) == flow {
			low = mid
		} else {
			high = mid
		}
	}
	return low, high
}
107 | func series(f func(float64) float64) float64 { 108 | y, yp := 0.0, 1.0 109 | for n := 0.0; y != yp; n++ { 110 | yp = y 111 | y += f(n) 112 | } 113 | return y 114 | } 115 | -------------------------------------------------------------------------------- /stats/hypergdist.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "math" 9 | 10 | "github.com/aclements/go-moremath/mathx" 11 | ) 12 | 13 | // HypergeometicDist is a hypergeometric distribution. 14 | type HypergeometicDist struct { 15 | // N is the size of the population. N >= 0. 16 | N int 17 | 18 | // K is the number of successes in the population. 0 <= K <= N. 19 | K int 20 | 21 | // Draws is the number of draws from the population. This is 22 | // usually written "n", but is called Draws here because of 23 | // limitations on Go identifier naming. 0 <= Draws <= N. 24 | Draws int 25 | } 26 | 27 | // PMF is the probability of getting exactly int(k) successes in 28 | // d.Draws draws with replacement from a population of size d.N that 29 | // contains exactly d.K successes. 30 | func (d HypergeometicDist) PMF(k float64) float64 { 31 | ki := int(math.Floor(k)) 32 | l, h := d.bounds() 33 | if ki < l || ki > h { 34 | return 0 35 | } 36 | return d.pmf(ki) 37 | } 38 | 39 | func (d HypergeometicDist) pmf(k int) float64 { 40 | return math.Exp(mathx.Lchoose(d.K, k) + mathx.Lchoose(d.N-d.K, d.Draws-k) - mathx.Lchoose(d.N, d.Draws)) 41 | } 42 | 43 | // CDF is the probability of getting int(k) or fewer successes in 44 | // d.Draws draws with replacement from a population of size d.N that 45 | // contains exactly d.K successes. 46 | func (d HypergeometicDist) CDF(k float64) float64 { 47 | // Based on Klotz, A Computational Approach to Statistics. 
48 | ki := int(math.Floor(k)) 49 | l, h := d.bounds() 50 | if ki < l { 51 | return 0 52 | } else if ki >= h { 53 | return 1 54 | } 55 | // Use symmetry to compute the smaller sum. 56 | flip := false 57 | if ki > (d.Draws+1)/(d.N+1)*(d.K+1) { 58 | flip = true 59 | ki = d.K - ki - 1 60 | d.Draws = d.N - d.Draws 61 | } 62 | p := d.pmf(ki) * d.sum(ki) 63 | if flip { 64 | p = 1 - p 65 | } 66 | return p 67 | } 68 | 69 | func (d HypergeometicDist) sum(k int) float64 { 70 | const epsilon = 1e-14 71 | sum, ak := 1.0, 1.0 72 | L := maxint(0, d.Draws+d.K-d.N) 73 | for dk := 1; dk <= k-L && ak/sum > epsilon; dk++ { 74 | ak *= float64(1+k-dk) / float64(d.Draws-k+dk) 75 | ak *= float64(d.N-d.K-d.Draws+k+1-dk) / float64(d.K-k+dk) 76 | sum += ak 77 | } 78 | return sum 79 | } 80 | 81 | func (d HypergeometicDist) bounds() (int, int) { 82 | return maxint(0, d.Draws+d.K-d.N), minint(d.Draws, d.K) 83 | } 84 | 85 | func (d HypergeometicDist) Bounds() (float64, float64) { 86 | l, h := d.bounds() 87 | return float64(l), float64(h) 88 | } 89 | 90 | func (d HypergeometicDist) Step() float64 { 91 | return 1 92 | } 93 | 94 | func (d HypergeometicDist) Mean() float64 { 95 | return float64(d.Draws*d.K) / float64(d.N) 96 | } 97 | 98 | func (d HypergeometicDist) Variance() float64 { 99 | return float64(d.Draws*d.K*(d.N-d.K)*(d.N-d.Draws)) / 100 | float64(d.N*d.N*(d.N-1)) 101 | } 102 | -------------------------------------------------------------------------------- /stats/tdist_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package stats 6 | 7 | import "testing" 8 | 9 | func TestT(t *testing.T) { 10 | testFunc(t, "PDF(%v|v=1)", TDist{1}.PDF, map[float64]float64{ 11 | -10: 0.0031515830315226806, 12 | -9: 0.0038818278802901312, 13 | -8: 0.0048970751720583188, 14 | -7: 0.0063661977236758151, 15 | -6: 0.0086029698968592104, 16 | -5: 0.012242687930145799, 17 | -4: 0.018724110951987692, 18 | -3: 0.031830988618379075, 19 | -2: 0.063661977236758149, 20 | -1: 0.15915494309189537, 21 | 0: 0.31830988618379075, 22 | 1: 0.15915494309189537, 23 | 2: 0.063661977236758149, 24 | 3: 0.031830988618379075, 25 | 4: 0.018724110951987692, 26 | 5: 0.012242687930145799, 27 | 6: 0.0086029698968592104, 28 | 7: 0.0063661977236758151, 29 | 8: 0.0048970751720583188, 30 | 9: 0.0038818278802901312}) 31 | testFunc(t, "PDF(%v|v=5)", TDist{5}.PDF, map[float64]float64{ 32 | -10: 4.0989816415343313e-05, 33 | -9: 7.4601664362590413e-05, 34 | -8: 0.00014444303269563934, 35 | -7: 0.00030134402928803911, 36 | -6: 0.00068848154013743002, 37 | -5: 0.0017574383788078445, 38 | -4: 0.0051237270519179133, 39 | -3: 0.017292578800222964, 40 | -2: 0.065090310326216455, 41 | -1: 0.21967979735098059, 42 | 0: 0.3796066898224944, 43 | 1: 0.21967979735098059, 44 | 2: 0.065090310326216455, 45 | 3: 0.017292578800222964, 46 | 4: 0.0051237270519179133, 47 | 5: 0.0017574383788078445, 48 | 6: 0.00068848154013743002, 49 | 7: 0.00030134402928803911, 50 | 8: 0.00014444303269563934, 51 | 9: 7.4601664362590413e-05}) 52 | 53 | testFunc(t, "CDF(%v|v=1)", TDist{1}.CDF, map[float64]float64{ 54 | -10: 0.03172551743055356, 55 | -9: 0.035223287477277272, 56 | -8: 0.039583424160565539, 57 | -7: 0.045167235300866547, 58 | -6: 0.052568456711253424, 59 | -5: 0.06283295818900117, 60 | -4: 0.077979130377369324, 61 | -3: 0.10241638234956672, 62 | -2: 0.14758361765043321, 63 | -1: 0.24999999999999978, 64 | 0: 0.5, 65 | 1: 0.75000000000000022, 66 | 2: 0.85241638234956674, 67 | 3: 0.89758361765043326, 68 | 4: 0.92202086962263075, 69 | 5: 0.93716704181099886, 
70 | 6: 0.94743154328874657, 71 | 7: 0.95483276469913347, 72 | 8: 0.96041657583943452, 73 | 9: 0.96477671252272279}) 74 | testFunc(t, "CDF(%v|v=5)", TDist{5}.CDF, map[float64]float64{ 75 | -10: 8.5473787871481787e-05, 76 | -9: 0.00014133998712194845, 77 | -8: 0.00024645333028622187, 78 | -7: 0.00045837375719920225, 79 | -6: 0.00092306914479700695, 80 | -5: 0.0020523579900266612, 81 | -4: 0.0051617077404157259, 82 | -3: 0.015049623948731284, 83 | -2: 0.05096973941492914, 84 | -1: 0.18160873382456127, 85 | 0: 0.5, 86 | 1: 0.81839126617543867, 87 | 2: 0.9490302605850709, 88 | 3: 0.98495037605126878, 89 | 4: 0.99483829225958431, 90 | 5: 0.99794764200997332, 91 | 6: 0.99907693085520299, 92 | 7: 0.99954162624280074, 93 | 8: 0.99975354666971372, 94 | 9: 0.9998586600128780}) 95 | } 96 | -------------------------------------------------------------------------------- /graph/graphalg/scc_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import ( 8 | "reflect" 9 | "sort" 10 | "testing" 11 | 12 | "github.com/aclements/go-moremath/graph" 13 | "github.com/aclements/go-moremath/graph/graphout" 14 | ) 15 | 16 | type sccTest struct { 17 | g graph.Graph 18 | components [][]int // Component -> Sub-node IDs (sorted) 19 | edges graph.Graph 20 | } 21 | 22 | // SCC example from CLRS. 23 | var clrsSCC = sccTest{ 24 | graph.IntGraph{ 25 | 0: {1}, 26 | 1: {2, 4, 5}, 27 | 2: {3, 6}, 28 | 3: {2, 7}, 29 | 4: {0, 5}, 30 | 5: {6}, 31 | 6: {5, 7}, 32 | 7: {7}, 33 | }, 34 | [][]int{ 35 | 0: {7}, 36 | 1: {5, 6}, 37 | 2: {2, 3}, 38 | 3: {0, 1, 4}, 39 | }, 40 | graph.IntGraph{ 41 | 0: {}, 42 | 1: {0}, 43 | 2: {0, 1}, 44 | 3: {1, 2}, 45 | }, 46 | } 47 | 48 | // SCC example from Sedgewick, Algorithms in C, Part 5, Third Edition, p. 199. 
49 | var sedgewickSCC = sccTest{ 50 | graph.IntGraph{ 51 | 0: {2}, 52 | 1: {0}, 53 | 2: {3, 4}, 54 | 3: {2, 4}, 55 | 4: {5, 6}, 56 | 5: {0, 3}, 57 | 6: {0, 7}, 58 | 7: {8}, 59 | 8: {7}, 60 | 9: {6, 8, 12}, 61 | 10: {9}, 62 | 11: {4, 9}, 63 | 12: {10, 11}, 64 | }, 65 | [][]int{ 66 | 3: {9, 10, 11, 12}, 67 | 2: {1}, 68 | 1: {0, 2, 3, 4, 5, 6}, 69 | 0: {7, 8}, 70 | }, 71 | graph.IntGraph{ 72 | 3: {0, 1}, 73 | 2: {1}, 74 | 1: {0}, 75 | 0: {}, 76 | }, 77 | } 78 | 79 | // SCC example from 80 | // https://algs4.cs.princeton.edu/lectures/42DirectedGraphs-2x2.pdf 81 | // 82 | // This is very similar to the Sedgewick graph, but not the same. 83 | // (Maybe this is from the fourth edition?) 84 | var sedgewick2SCC = sccTest{ 85 | graph.IntGraph{ 86 | 0: {1, 5}, 87 | 1: {}, 88 | 2: {0, 3}, 89 | 3: {2, 5}, 90 | 4: {2, 3}, 91 | 5: {4}, 92 | 6: {0, 4, 8, 9}, 93 | 7: {6, 9}, 94 | 8: {6}, 95 | 9: {10, 11}, 96 | 10: {12}, 97 | 11: {4, 12}, 98 | 12: {9}, 99 | }, 100 | [][]int{ 101 | 0: {1}, 102 | 1: {0, 2, 3, 4, 5}, 103 | 2: {9, 10, 11, 12}, 104 | 3: {6, 8}, 105 | 4: {7}, 106 | }, 107 | graph.IntGraph{ 108 | 0: {}, 109 | 1: {0}, 110 | 2: {1}, 111 | 3: {1, 2}, 112 | 4: {2, 3}, 113 | }, 114 | } 115 | 116 | func TestSCC(t *testing.T) { 117 | t.Run("clrs", func(t *testing.T) { testSCC(t, clrsSCC) }) 118 | t.Run("sedgewick", func(t *testing.T) { testSCC(t, sedgewickSCC) }) 119 | t.Run("sedgewick2", func(t *testing.T) { testSCC(t, sedgewick2SCC) }) 120 | } 121 | 122 | func testSCC(t *testing.T, test sccTest) { 123 | scc := SCC(test.g, SCCEdges) 124 | 125 | // Check components. 126 | var components [][]int 127 | for i := 0; i < scc.NumNodes(); i++ { 128 | comp := append([]int{}, scc.Subnodes(i)...) 129 | sort.Ints(comp) 130 | components = append(components, comp) 131 | } 132 | if !reflect.DeepEqual(test.components, components) { 133 | t.Errorf("want components:\n%v\ngot:\n%v", test.components, components) 134 | } 135 | 136 | // Check the edges. 
137 | if !graph.Equal(test.edges, scc) { 138 | t.Errorf("want edges:\n%s\ngot:\n%s\n", 139 | graphout.Dot{}.Sprint(test.edges), 140 | graphout.Dot{}.Sprint(scc), 141 | ) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /fit/loess.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package fit 6 | 7 | import ( 8 | "math" 9 | "sort" 10 | ) 11 | 12 | // LOESS computes the locally-weighted least squares polynomial 13 | // regression to the data (xs[i], ys[i]). 0 < span <= 1 is the 14 | // smoothing parameter, where smaller values fit the data more 15 | // tightly. Degree is typically 2 and span is typically between 0.5 16 | // and 0.75. 17 | // 18 | // The regression is "local" because the weights used for the 19 | // polynomial regression depend on the x at which the regression 20 | // function is evaluated. The weight of observation i is 21 | // W((x-xs[i])/d(x)) where d(x) is the distance from x to the 22 | // span*len(xs)'th closest point to x and W is the tricube weight 23 | // function W(u) = (1-|u|³)³ for |u| < 1, 0 otherwise. One consequence 24 | // of this is that only the span*len(xs) points closest to x affect 25 | // the regression at x, and that the effect of these points falls off 26 | // further from x. 27 | // 28 | // # References 29 | // 30 | // Cleveland, William S., and Susan J. Devlin. "Locally weighted 31 | // regression: an approach to regression analysis by local fitting." 32 | // Journal of the American Statistical Association 83.403 (1988): 33 | // 596-610. 
34 | // 35 | // http://www.itl.nist.gov/div898/handbook/pmd/section1/dep/dep144.htm 36 | func LOESS(xs, ys []float64, degree int, span float64) func(x float64) float64 { 37 | if degree < 0 { 38 | panic("degree must be non-negative") 39 | } 40 | if span <= 0 { 41 | panic("span must be positive") 42 | } 43 | 44 | // q is the window width in data points. 45 | q := int(math.Ceil(span * float64(len(xs)))) 46 | if q >= len(xs) { 47 | q = len(xs) 48 | } 49 | 50 | // Sort xs. 51 | if !sort.Float64sAreSorted(xs) { 52 | xs = append([]float64(nil), xs...) 53 | ys = append([]float64(nil), ys...) 54 | sort.Sort(&pairSlice{xs, ys}) 55 | } 56 | 57 | return func(x float64) float64 { 58 | // Find the q points closest to x. 59 | n := 0 60 | if len(xs) > q { 61 | n = sort.Search(len(xs)-q, func(i int) bool { 62 | // The cut-off between xs[i:i+q] and 63 | // xs[i+1:i+1+q] is avg(xs[i], 64 | // xs[i+q]). 65 | return (xs[i] + xs[i+q]) >= x*2 66 | }) 67 | } 68 | closest := xs[n : n+q] 69 | 70 | // Compute the distance to the q'th farthest point. 71 | // This will be either the first or last point in 72 | // closest. 73 | d := x - closest[0] 74 | if closest[q-1]-x > d { 75 | d = closest[q-1] - x 76 | } 77 | 78 | // Compute the weights. 79 | weights := make([]float64, q) 80 | for i, c := range closest { 81 | // u is the normalized distance from x to 82 | // closest[i]. 83 | u := math.Abs(x-c) / d 84 | // Compute the tricube weight (1-|u|³)³ for 85 | // |u| < 1. We know 0 <= u <= 1, so we can 86 | // simplify this a bit. 87 | tmp := 1 - u*u*u 88 | weights[i] = tmp * tmp * tmp 89 | } 90 | 91 | // Compute the polynomial regression at x. 92 | pr := PolynomialRegression(closest, ys[n:n+q], weights, degree) 93 | 94 | // Evaluate the polynomial at x. 
95 | return pr.F(x) 96 | } 97 | } 98 | 99 | type pairSlice struct { 100 | xs, ys []float64 101 | } 102 | 103 | func (s *pairSlice) Len() int { 104 | return len(s.xs) 105 | } 106 | 107 | func (s *pairSlice) Less(i, j int) bool { 108 | return s.xs[i] < s.xs[j] 109 | } 110 | 111 | func (s *pairSlice) Swap(i, j int) { 112 | s.xs[i], s.xs[j] = s.xs[j], s.xs[i] 113 | s.ys[i], s.ys[j] = s.ys[j], s.ys[i] 114 | } 115 | -------------------------------------------------------------------------------- /scale/ticks.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package scale 6 | 7 | // TickOptions specifies constraints for constructing scale ticks. 8 | // 9 | // A Ticks method will return the ticks at the lowest level (largest 10 | // number of ticks) that satisfies all of the constraints. The exact 11 | // meaning of the tick level differs between scale types, but for all 12 | // scales higher tick levels result in ticks that are further apart 13 | // (fewer ticks in a given interval). In general, the minor ticks are 14 | // the ticks from one level below the major ticks. 15 | type TickOptions struct { 16 | // Max is the maximum number of major ticks to return. 17 | Max int 18 | 19 | // MinLevel and MaxLevel are the minimum and maximum tick 20 | // levels to accept, respectively. If they are both 0, there is 21 | // no limit on acceptable tick levels. 22 | MinLevel, MaxLevel int 23 | } 24 | 25 | // A Ticker computes tick marks for a scale. The "level" of the ticks 26 | // controls how many ticks there are and how closely they are spaced. 27 | // Higher levels have fewer ticks, while lower levels have more ticks. 28 | // For example, on a numerical scale, one could have ticks at every 29 | // n*(10^level). 
30 | type Ticker interface { 31 | // CountTicks returns the number of ticks at level in this 32 | // scale's input range. This is equivalent to 33 | // len(TicksAtLevel(level)), but should be much more 34 | // efficient. CountTicks is a weakly monotonically decreasing 35 | // function of level. 36 | CountTicks(level int) int 37 | 38 | // TicksAtLevel returns a slice of "nice" tick values in 39 | // increasing order at level in this scale's input range. 40 | // Typically, TicksAtLevel(l+1) is a subset of 41 | // TicksAtLevel(l). That is, higher levels remove ticks from 42 | // lower levels. 43 | TicksAtLevel(level int) interface{} 44 | } 45 | 46 | // FindLevel returns the lowest level that satisfies the constraints 47 | // given by o: 48 | // 49 | // * ticker.CountTicks(level) <= o.Max 50 | // 51 | // * o.MinLevel <= level <= o.MaxLevel (if MinLevel and MaxLevel != 0). 52 | // 53 | // If the constraints cannot be satisfied, it returns 0, false. 54 | // 55 | // guess is the level to start the optimization at. 56 | func (o *TickOptions) FindLevel(ticker Ticker, guess int) (int, bool) { 57 | minLevel, maxLevel := o.MinLevel, o.MaxLevel 58 | if minLevel == 0 && maxLevel == 0 { 59 | minLevel, maxLevel = -1000, 1000 60 | } else if minLevel > maxLevel { 61 | return 0, false 62 | } 63 | if o.Max < 1 { 64 | return 0, false 65 | } 66 | 67 | // Start with the initial guess. 68 | l := guess 69 | if l < minLevel { 70 | l = minLevel 71 | } else if l > maxLevel { 72 | l = maxLevel 73 | } 74 | 75 | // Optimize count against o.Max. 76 | if ticker.CountTicks(l) <= o.Max { 77 | // We're satisfying the o.Max and min/maxLevel 78 | // constraints. count is monotonically decreasing, so 79 | // decrease level to increase the count until we 80 | // violate either o.Max or minLevel. 81 | for l--; l >= minLevel && ticker.CountTicks(l) <= o.Max; l-- { 82 | } 83 | // We went one too far. 84 | l++ 85 | } else { 86 | // We're over o.Max. 
Increase level to decrease the 87 | // count until we go below o.Max. This may cause us to 88 | // violate maxLevel. 89 | for l++; l <= maxLevel && ticker.CountTicks(l) > o.Max; l++ { 90 | } 91 | if l > maxLevel { 92 | // We can't satisfy both o.Max and maxLevel. 93 | return 0, false 94 | } 95 | } 96 | 97 | // At this point l is the lowest value that satisfies the 98 | // o.Max, minLevel, and maxLevel constraints. 99 | 100 | return l, true 101 | } 102 | -------------------------------------------------------------------------------- /stats/normaldist.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "math" 9 | "math/rand" 10 | ) 11 | 12 | // NormalDist is a normal (Gaussian) distribution with mean Mu and 13 | // standard deviation Sigma. 14 | type NormalDist struct { 15 | Mu, Sigma float64 16 | } 17 | 18 | // StdNormal is the standard normal distribution (Mu = 0, Sigma = 1) 19 | var StdNormal = NormalDist{0, 1} 20 | 21 | // 1/sqrt(2 * pi) 22 | const invSqrt2Pi = 0.39894228040143267793994605993438186847585863116493465766592583 23 | 24 | func (n NormalDist) PDF(x float64) float64 { 25 | z := x - n.Mu 26 | return math.Exp(-z*z/(2*n.Sigma*n.Sigma)) * invSqrt2Pi / n.Sigma 27 | } 28 | 29 | func (n NormalDist) pdfEach(xs []float64) []float64 { 30 | res := make([]float64, len(xs)) 31 | if n.Mu == 0 && n.Sigma == 1 { 32 | // Standard normal fast path 33 | for i, x := range xs { 34 | res[i] = math.Exp(-x*x/2) * invSqrt2Pi 35 | } 36 | } else { 37 | a := -1 / (2 * n.Sigma * n.Sigma) 38 | b := invSqrt2Pi / n.Sigma 39 | for i, x := range xs { 40 | z := x - n.Mu 41 | res[i] = math.Exp(z*z*a) * b 42 | } 43 | } 44 | return res 45 | } 46 | 47 | func (n NormalDist) CDF(x float64) float64 { 48 | return 
math.Erfc(-(x-n.Mu)/(n.Sigma*math.Sqrt2)) / 2 49 | } 50 | 51 | func (n NormalDist) cdfEach(xs []float64) []float64 { 52 | res := make([]float64, len(xs)) 53 | a := 1 / (n.Sigma * math.Sqrt2) 54 | for i, x := range xs { 55 | res[i] = math.Erfc(-(x-n.Mu)*a) / 2 56 | } 57 | return res 58 | } 59 | 60 | func (n NormalDist) InvCDF(p float64) (x float64) { 61 | // This is based on Peter John Acklam's inverse normal CDF 62 | // algorithm: http://home.online.no/~pjacklam/notes/invnorm/ 63 | const ( 64 | a1 = -3.969683028665376e+01 65 | a2 = 2.209460984245205e+02 66 | a3 = -2.759285104469687e+02 67 | a4 = 1.383577518672690e+02 68 | a5 = -3.066479806614716e+01 69 | a6 = 2.506628277459239e+00 70 | 71 | b1 = -5.447609879822406e+01 72 | b2 = 1.615858368580409e+02 73 | b3 = -1.556989798598866e+02 74 | b4 = 6.680131188771972e+01 75 | b5 = -1.328068155288572e+01 76 | 77 | c1 = -7.784894002430293e-03 78 | c2 = -3.223964580411365e-01 79 | c3 = -2.400758277161838e+00 80 | c4 = -2.549732539343734e+00 81 | c5 = 4.374664141464968e+00 82 | c6 = 2.938163982698783e+00 83 | 84 | d1 = 7.784695709041462e-03 85 | d2 = 3.224671290700398e-01 86 | d3 = 2.445134137142996e+00 87 | d4 = 3.754408661907416e+00 88 | 89 | plow = 0.02425 90 | phigh = 1 - plow 91 | ) 92 | 93 | if p < 0 || p > 1 { 94 | return nan 95 | } else if p == 0 { 96 | return -inf 97 | } else if p == 1 { 98 | return inf 99 | } 100 | 101 | if p < plow { 102 | // Rational approximation for lower region. 103 | q := math.Sqrt(-2 * math.Log(p)) 104 | x = (((((c1*q+c2)*q+c3)*q+c4)*q+c5)*q + c6) / 105 | ((((d1*q+d2)*q+d3)*q+d4)*q + 1) 106 | } else if phigh < p { 107 | // Rational approximation for upper region. 108 | q := math.Sqrt(-2 * math.Log(1-p)) 109 | x = -(((((c1*q+c2)*q+c3)*q+c4)*q+c5)*q + c6) / 110 | ((((d1*q+d2)*q+d3)*q+d4)*q + 1) 111 | } else { 112 | // Rational approximation for central region. 
113 | q := p - 0.5 114 | r := q * q 115 | x = (((((a1*r+a2)*r+a3)*r+a4)*r+a5)*r + a6) * q / 116 | (((((b1*r+b2)*r+b3)*r+b4)*r+b5)*r + 1) 117 | } 118 | 119 | // Refine approximation. 120 | e := 0.5*math.Erfc(-x/math.Sqrt2) - p 121 | u := e * math.Sqrt(2*math.Pi) * math.Exp(x*x/2) 122 | x = x - u/(1+x*u/2) 123 | 124 | // Adjust from standard normal. 125 | return x*n.Sigma + n.Mu 126 | } 127 | 128 | func (n NormalDist) Rand(r *rand.Rand) float64 { 129 | var x float64 130 | if r == nil { 131 | x = rand.NormFloat64() 132 | } else { 133 | x = r.NormFloat64() 134 | } 135 | return x*n.Sigma + n.Mu 136 | } 137 | 138 | func (n NormalDist) Bounds() (float64, float64) { 139 | const stddevs = 3 140 | return n.Mu - stddevs*n.Sigma, n.Mu + stddevs*n.Sigma 141 | } 142 | 143 | func (n NormalDist) Mean() float64 { 144 | return n.Mu 145 | } 146 | 147 | func (n NormalDist) Variance() float64 { 148 | return n.Sigma * n.Sigma 149 | } 150 | -------------------------------------------------------------------------------- /graph/graphalg/dom.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import "github.com/aclements/go-moremath/graph" 8 | 9 | // IDom returns the immediate dominator of each node of g. Nodes that 10 | // don't have an immediate dominator (including root) are assigned -1. 11 | func IDom(g graph.BiGraph, root int) []int { 12 | // This implements the "engineered algorithm" of Cooper, 13 | // Harvey, and Kennedy, "A Simple, Fast Dominance Algorithm", 14 | // 2001. 15 | // 16 | // Unlike in Cooper, we mostly use the original node naming, 17 | // but "intersect" maps into the post-order node naming as 18 | // needed. 19 | 20 | po := PostOrder(g, root) 21 | 22 | // Compute the post-order node naming for the "intersect" 23 | // routine. 
poNum maps from node to post-order name. 24 | poNum := make([]int, g.NumNodes()) 25 | for i, n := range po { 26 | poNum[n] = i 27 | } 28 | 29 | rpo, po := Reverse(po), nil 30 | 31 | // Initialize IDom. 32 | idom := make([]int, g.NumNodes()) 33 | for i := range idom { 34 | idom[i] = -1 35 | } 36 | idom[root] = root 37 | 38 | // Iterate to convergence. 39 | changed := true 40 | for changed { 41 | changed = false 42 | for _, b := range rpo { 43 | if b == root { 44 | continue 45 | } 46 | 47 | newIdom := -1 48 | for _, p := range g.In(b) { 49 | if idom[p] == -1 { 50 | continue 51 | } 52 | if newIdom == -1 { 53 | newIdom = p 54 | continue 55 | } 56 | newIdom = intersect(idom, poNum, p, newIdom) 57 | } 58 | 59 | if idom[b] != newIdom { 60 | idom[b] = newIdom 61 | changed = true 62 | } 63 | } 64 | } 65 | 66 | // Clear root's dominator, which is currently a self-loop. 67 | idom[root] = -1 68 | 69 | return idom 70 | } 71 | 72 | func intersect(idom, poNum []int, b1, b2 int) int { 73 | for b1 != b2 { 74 | for poNum[b1] < poNum[b2] { 75 | b1 = idom[b1] 76 | } 77 | for poNum[b2] < poNum[b1] { 78 | b2 = idom[b2] 79 | } 80 | } 81 | return b1 82 | } 83 | 84 | // DomFrontier returns the dominance frontier of each node in g. idom 85 | // must be IDom(g, root). idom may be nil, in which case this computes 86 | // IDom. 87 | func DomFrontier(g graph.BiGraph, root int, idom []int) [][]int { 88 | // This implements the dominance frontier algorithm of Cooper, 89 | // Harvey, and Kennedy, "A Simple, Fast Dominance Algorithm", 90 | // 2001. 91 | 92 | if idom == nil { 93 | idom = IDom(g, root) 94 | } 95 | 96 | df := make([][]int, g.NumNodes()) 97 | for b, bdom := range idom { 98 | preds := g.In(b) 99 | if len(preds) < 2 { 100 | continue 101 | } 102 | 103 | for _, pred := range preds { 104 | runner := pred 105 | for runner != bdom { 106 | // Add b to runner's DF set. 
107 | for _, rdf := range df[runner] { 108 | if rdf == b { 109 | goto found 110 | } 111 | } 112 | df[runner] = append(df[runner], b) 113 | found: 114 | runner = idom[runner] 115 | } 116 | } 117 | } 118 | 119 | // Make sure empty sets are filled in. 120 | for i := range df { 121 | if df[i] == nil { 122 | df[i] = []int{} 123 | } 124 | } 125 | return df 126 | } 127 | 128 | // Dom computes the dominator tree from the immediate dominators (as 129 | // computed by IDom). The nodes of the resulting DomTree have the same 130 | // numbering as the nodes in the original graph. 131 | func Dom(idom []int) *DomTree { 132 | children := make([][]int, len(idom)) 133 | 134 | // Chop up a single slice used to store the children. 135 | cspace := make([]int, len(idom)) 136 | for _, parent := range idom { 137 | if parent != -1 { 138 | cspace[parent]++ 139 | } 140 | } 141 | used := 0 142 | for i, n := range cspace { 143 | children[i] = cspace[used : used : used+n] 144 | used += n 145 | } 146 | 147 | // Actually create the children tree now. 148 | for node, parent := range idom { 149 | if parent != -1 { 150 | children[parent] = append(children[parent], node) 151 | } 152 | } 153 | 154 | return &DomTree{idom, children} 155 | } 156 | 157 | // DomTree is a dominator tree. 158 | // 159 | // It also satisfies the BiGraph interface, which edges pointing 160 | // toward children. 
// DomTree is a dominator tree.
//
// It also satisfies the BiGraph interface, with edges pointing
// toward children.
type DomTree struct {
	idom     []int   // idom[n] is n's immediate dominator, or -1 if it has none
	children [][]int // children[n] lists the nodes whose immediate dominator is n
}

// IDom returns the immediate dominator of node n, or -1 if n has
// none.
func (t *DomTree) IDom(n int) int {
	return t.idom[n]
}

// NumNodes returns the number of nodes in the tree.
func (t *DomTree) NumNodes() int {
	return len(t.idom)
}

// In returns the predecessors of n in the tree: its immediate
// dominator, or no nodes at all if n has no immediate dominator.
func (t *DomTree) In(n int) []int {
	if t.idom[n] == -1 {
		// -1 is the "no dominator" marker, not a node ID, so it
		// must not be reported as a predecessor.
		return nil
	}
	// Cap the slice at one element so that appending to the result
	// cannot overwrite the entries for other nodes.
	return t.idom[n : n+1 : n+1]
}

// Out returns the children of n in the tree.
func (t *DomTree) Out(n int) []int {
	return t.children[n]
}
29 | gNodes := g.NumNodes() 30 | oldToNew := make(map[int]int, len(nodes)) 31 | for newNode, oldNode := range nodes { 32 | if oldNode < 0 || oldNode >= gNodes { 33 | panic("node not in underlying graph") 34 | } 35 | if _, ok := oldToNew[oldNode]; ok { 36 | panic("duplicate node") 37 | } 38 | oldToNew[oldNode] = newNode 39 | } 40 | 41 | // Construct new nodes. 42 | newNodes := make([]listSubgraphNode, len(nodes)) 43 | for i, oldNode := range nodes { 44 | newNodes[i].oldNode = oldNode 45 | } 46 | 47 | // Map old edge indexes to new node IDs. 48 | for _, oldEdge := range edges { 49 | newNode := &newNodes[oldToNew[oldEdge.Node]] 50 | oldTo := g.Out(oldEdge.Node)[oldEdge.Edge] 51 | newTo := oldToNew[oldTo] 52 | 53 | newNode.out = append(newNode.out, newTo) 54 | newNode.oldEdges = append(newNode.oldEdges, oldEdge.Edge) 55 | } 56 | 57 | return &listSubgraph{g, newNodes} 58 | } 59 | 60 | // SubgraphRemove returns a subgraph of g that removes the given nodes 61 | // and edges from g, as well as all edges incident to those nodes. 62 | func SubgraphRemove(g Graph, nodes []int, edges []Edge) Subgraph { 63 | // Collect the set of nodes and edges to remove. 64 | rmNodes := make(map[int]struct{}, len(nodes)) 65 | for _, node := range nodes { 66 | rmNodes[node] = struct{}{} 67 | } 68 | rmEdges := make(map[Edge]struct{}, len(edges)) 69 | for _, edge := range edges { 70 | rmEdges[edge] = struct{}{} 71 | } 72 | 73 | // Create new-to-old and old-to-new node mappings. 74 | newNodes := make([]listSubgraphNode, 0, g.NumNodes()-len(rmNodes)) 75 | oldToNew := make(map[int]int, cap(newNodes)) 76 | for oldNode := 0; oldNode < g.NumNodes(); oldNode++ { 77 | if _, ok := rmNodes[oldNode]; ok { 78 | continue 79 | } 80 | newNode := len(newNodes) 81 | newNodes = append(newNodes, listSubgraphNode{oldNode: oldNode}) 82 | oldToNew[oldNode] = newNode 83 | } 84 | 85 | // Create edge mappings. 
86 | for i := range newNodes { 87 | newNode := &newNodes[i] 88 | oldNode := newNode.oldNode 89 | oldOut := g.Out(oldNode) 90 | for j, oldNode2 := range oldOut { 91 | if _, ok := rmNodes[oldNode2]; ok { 92 | // Target node removed. 93 | continue 94 | } 95 | if _, ok := rmEdges[Edge{oldNode, j}]; ok { 96 | // Edge removed. 97 | continue 98 | } 99 | newNode.out = append(newNode.out, oldToNew[oldNode2]) 100 | newNode.oldEdges = append(newNode.oldEdges, j) 101 | } 102 | } 103 | 104 | return &listSubgraph{g, newNodes} 105 | } 106 | 107 | type listSubgraph struct { 108 | underlying Graph 109 | nodes []listSubgraphNode 110 | } 111 | 112 | type listSubgraphNode struct { 113 | out []int // Adjacency list 114 | oldNode int // Node ID in underlying graph 115 | oldEdges []int // New edge index -> old edge index 116 | } 117 | 118 | func (s *listSubgraph) NumNodes() int { 119 | return len(s.nodes) 120 | } 121 | 122 | func (s *listSubgraph) Out(node int) []int { 123 | return s.nodes[node].out 124 | } 125 | 126 | func (s *listSubgraph) Underlying() Graph { 127 | return s.underlying 128 | } 129 | 130 | func (s *listSubgraph) NodeMap(underlyingMap func(node int) interface{}) func(node int) interface{} { 131 | return func(node int) interface{} { 132 | return underlyingMap(s.nodes[node].oldNode) 133 | } 134 | } 135 | 136 | func (s *listSubgraph) EdgeMap(underlyingMap func(node, edge int) interface{}) func(node, edge int) interface{} { 137 | return func(node, edge int) interface{} { 138 | newNode := &s.nodes[node] 139 | return underlyingMap(newNode.oldNode, newNode.oldEdges[edge]) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /graph/graphout/dot.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package graphout 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "os" 11 | "strings" 12 | 13 | "github.com/aclements/go-moremath/graph" 14 | ) 15 | 16 | // Dot contains options for generating a Graphviz Dot graph from a 17 | // Graph. 18 | type Dot struct { 19 | // Name is the name given to the graph. Usually this can be 20 | // left blank. 21 | Name string 22 | 23 | // Label returns the string to use as a label for the given 24 | // node. If nil, nodes are labeled with their node numbers. 25 | Label func(node int) string 26 | 27 | // NodeAttrs, if non-nil, returns a set of attributes for a 28 | // node. If this includes a "label" attribute, it overrides 29 | // the label returned by Label. 30 | NodeAttrs func(node int) []DotAttr 31 | 32 | // EdgeAttrs, if non-nil, returns a set of attributes for an 33 | // edge. 34 | EdgeAttrs func(node, edge int) []DotAttr 35 | } 36 | 37 | // DotAttr is an attribute for a Dot node or edge. 38 | type DotAttr struct { 39 | Name string 40 | // Val is the value of this attribute. It may be a string 41 | // (which will be escaped), bool, int, uint, float64 or 42 | // DotLiteral. 43 | Val interface{} 44 | } 45 | 46 | // DotLiteral is a string literal that should be passed to dot 47 | // unescaped. 48 | type DotLiteral string 49 | 50 | func defaultLabel(node int) string { 51 | return fmt.Sprintf("%d", node) 52 | } 53 | 54 | // Print writes the Dot form of g to os.Stdout. 55 | func (d Dot) Print(g graph.Graph) error { 56 | return d.Fprint(os.Stdout, g) 57 | } 58 | 59 | // Sprint returns the Dot form of g as a string. 60 | func (d Dot) Sprint(g graph.Graph) string { 61 | var buf strings.Builder 62 | d.Fprint(&buf, g) 63 | return buf.String() 64 | } 65 | 66 | // Fprint writes the Dot form of g to w. 
67 | func (d Dot) Fprint(w io.Writer, g graph.Graph) error { 68 | label := d.Label 69 | if label == nil { 70 | label = defaultLabel 71 | } 72 | 73 | _, err := fmt.Fprintf(w, "digraph %s {\n", DotString(d.Name)) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | for i := 0; i < g.NumNodes(); i++ { 79 | // Define node. 80 | var attrList []DotAttr 81 | var haveLabel bool 82 | if d.NodeAttrs != nil { 83 | attrList = d.NodeAttrs(i) 84 | for _, attr := range attrList { 85 | if attr.Name == "label" { 86 | haveLabel = true 87 | break 88 | } 89 | } 90 | } 91 | if !haveLabel { 92 | attrList = attrList[:len(attrList):len(attrList)] 93 | attrList = append(attrList, DotAttr{"label", label(i)}) 94 | } 95 | _, err = fmt.Fprintf(w, "n%d%s;\n", i, formatAttrs(attrList)) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | // Connect node. 101 | for j, out := range g.Out(i) { 102 | var attrs string 103 | if d.EdgeAttrs != nil { 104 | attrs = formatAttrs(d.EdgeAttrs(i, j)) 105 | } 106 | _, err = fmt.Fprintf(w, "n%d -> n%d%s;\n", i, out, attrs) 107 | if err != nil { 108 | return err 109 | } 110 | } 111 | } 112 | 113 | _, err = fmt.Fprintf(w, "}\n") 114 | return err 115 | } 116 | 117 | // DotString returns s as a quoted dot string. 118 | // 119 | // Users of the Dot type don't need to call this, since it will 120 | // automatically quote strings. However, this is useful for building 121 | // custom dot output. 122 | func DotString(s string) string { 123 | buf := []byte{'"'} 124 | for i := 0; i < len(s); i++ { 125 | switch s[i] { 126 | case '\n': 127 | buf = append(buf, '\\', 'n') 128 | case '\\', '"', '{', '}', '<', '>', '|': 129 | // TODO: Option to allow formatting 130 | // characters? Maybe private use code points 131 | // to encode formatting characters? Or 132 | // something more usefully structured? 
133 | buf = append(buf, '\\', s[i]) 134 | default: 135 | buf = append(buf, s[i]) 136 | } 137 | } 138 | buf = append(buf, '"') 139 | return string(buf) 140 | } 141 | 142 | // formatAttrs formats attrs as a dot attribute set, including the 143 | // surrounding brackets. If attrs is empty, it returns an empty 144 | // string. 145 | func formatAttrs(attrs []DotAttr) string { 146 | if len(attrs) == 0 { 147 | return "" 148 | } 149 | var buf strings.Builder 150 | buf.WriteString(" [") 151 | for i, attr := range attrs { 152 | if i > 0 { 153 | buf.WriteString(",") 154 | } 155 | buf.WriteString(attr.Name) 156 | buf.WriteString("=") 157 | switch val := attr.Val.(type) { 158 | case string: 159 | buf.WriteString(DotString(val)) 160 | case int, uint, float64: 161 | fmt.Fprintf(&buf, "%v", val) 162 | case DotLiteral: 163 | buf.WriteString(string(val)) 164 | default: 165 | panic(fmt.Sprintf("dot attribute %s had unknown type %T", attr.Name, attr.Val)) 166 | } 167 | } 168 | buf.WriteString("]") 169 | return buf.String() 170 | } 171 | -------------------------------------------------------------------------------- /stats/ttest.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "errors" 9 | "math" 10 | ) 11 | 12 | // A TTestResult is the result of a t-test. 13 | type TTestResult struct { 14 | // N1 and N2 are the sizes of the input samples. For a 15 | // one-sample t-test, N2 is 0. 16 | N1, N2 int 17 | 18 | // T is the value of the t-statistic for this t-test. 19 | T float64 20 | 21 | // DoF is the degrees of freedom for this t-test. 22 | DoF float64 23 | 24 | // AltHypothesis specifies the alternative hypothesis tested 25 | // by this test against the null hypothesis that there is no 26 | // difference in the means of the samples. 
27 | AltHypothesis LocationHypothesis 28 | 29 | // P is p-value for this t-test for the given null hypothesis. 30 | P float64 31 | } 32 | 33 | func newTTestResult(n1, n2 int, t, dof float64, alt LocationHypothesis) *TTestResult { 34 | dist := TDist{dof} 35 | var p float64 36 | switch alt { 37 | case LocationDiffers: 38 | p = 2 * (1 - dist.CDF(math.Abs(t))) 39 | case LocationLess: 40 | p = dist.CDF(t) 41 | case LocationGreater: 42 | p = 1 - dist.CDF(t) 43 | } 44 | return &TTestResult{N1: n1, N2: n2, T: t, DoF: dof, AltHypothesis: alt, P: p} 45 | } 46 | 47 | // A TTestSample is a sample that can be used for a one or two sample 48 | // t-test. 49 | type TTestSample interface { 50 | Weight() float64 51 | Mean() float64 52 | Variance() float64 53 | } 54 | 55 | var ( 56 | ErrSampleSize = errors.New("sample is too small") 57 | ErrZeroVariance = errors.New("sample has zero variance") 58 | ErrMismatchedSamples = errors.New("samples have different lengths") 59 | ) 60 | 61 | // TwoSampleTTest performs a two-sample (unpaired) Student's t-test on 62 | // samples x1 and x2. This is a test of the null hypothesis that x1 63 | // and x2 are drawn from populations with equal means. It assumes x1 64 | // and x2 are independent samples, that the distributions have equal 65 | // variance, and that the populations are normally distributed. 66 | func TwoSampleTTest(x1, x2 TTestSample, alt LocationHypothesis) (*TTestResult, error) { 67 | n1, n2 := x1.Weight(), x2.Weight() 68 | if n1 == 0 || n2 == 0 { 69 | return nil, ErrSampleSize 70 | } 71 | v1, v2 := x1.Variance(), x2.Variance() 72 | if v1 == 0 && v2 == 0 { 73 | return nil, ErrZeroVariance 74 | } 75 | 76 | dof := n1 + n2 - 2 77 | v12 := ((n1-1)*v1 + (n2-1)*v2) / dof 78 | t := (x1.Mean() - x2.Mean()) / math.Sqrt(v12*(1/n1+1/n2)) 79 | return newTTestResult(int(n1), int(n2), t, dof, alt), nil 80 | } 81 | 82 | // TwoSampleWelchTTest performs a two-sample (unpaired) Welch's t-test 83 | // on samples x1 and x2. 
// PairedTTest performs a two-sample paired t-test on samples x1 and
// x2. If μ0 is non-zero, this tests if the average of the difference
// is significantly different from μ0.
//
// It returns a nil result and an error if the test cannot be
// performed: ErrMismatchedSamples if x1 and x2 have different
// lengths, ErrSampleSize if they have fewer than two elements, and
// ErrZeroVariance if the standard deviation of the pairwise
// differences is zero (for example, when x1 and x2 are identical).
func PairedTTest(x1, x2 []float64, μ0 float64, alt LocationHypothesis) (*TTestResult, error) {
	if len(x1) != len(x2) {
		return nil, ErrMismatchedSamples
	}
	if len(x1) <= 1 {
		// TODO: Can we still do this with n == 1?
		return nil, ErrSampleSize
	}

	dof := float64(len(x1) - 1)

	// The test is run on the element-wise differences x1[i] - x2[i].
	diff := make([]float64, len(x1))
	for i := range x1 {
		diff[i] = x1[i] - x2[i]
	}
	sd := StdDev(diff)
	if sd == 0 {
		// TODO: Can we still do the test?
		return nil, ErrZeroVariance
	}
	t := (Mean(diff) - μ0) * math.Sqrt(float64(len(x1))) / sd
	return newTTestResult(len(x1), len(x2), t, dof, alt), nil
}
135 | func OneSampleTTest(x TTestSample, μ0 float64, alt LocationHypothesis) (*TTestResult, error) { 136 | n, v := x.Weight(), x.Variance() 137 | if n == 0 { 138 | return nil, ErrSampleSize 139 | } 140 | if v == 0 { 141 | // TODO: Can we still do the test? 142 | return nil, ErrZeroVariance 143 | } 144 | dof := n - 1 145 | t := (x.Mean() - μ0) * math.Sqrt(n) / math.Sqrt(v) 146 | return newTTestResult(int(n), 0, t, dof, alt), nil 147 | } 148 | -------------------------------------------------------------------------------- /scale/linear.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package scale 6 | 7 | import ( 8 | "math" 9 | 10 | "github.com/aclements/go-moremath/vec" 11 | ) 12 | 13 | type Linear struct { 14 | // Min and Max specify the lower and upper bounds of the input 15 | // domain. The input domain [Min, Max] will be linearly mapped 16 | // to the output range [0, 1]. 17 | Min, Max float64 18 | 19 | // Base specifies a base for computing ticks. Ticks will be 20 | // placed at powers of Base; that is at n*Base^l for n ∈ ℤ and 21 | // some integer tick level l. As a special case, a base of 0 22 | // alternates between ticks at n*10^⌊l/2⌋ and ticks at 23 | // 5n*10^⌊l/2⌋. 24 | Base int 25 | 26 | // If Clamp is true, the input is clamped to [Min, Max]. 27 | Clamp bool 28 | } 29 | 30 | // *Linear is a Quantitative scale. 
31 | var _ Quantitative = &Linear{} 32 | 33 | func (s Linear) Map(x float64) float64 { 34 | if s.Min == s.Max { 35 | return 0.5 36 | } 37 | y := (x - s.Min) / (s.Max - s.Min) 38 | if s.Clamp { 39 | y = clamp(y) 40 | } 41 | return y 42 | } 43 | 44 | func (s Linear) Unmap(y float64) float64 { 45 | return y*(s.Max-s.Min) + s.Min 46 | } 47 | 48 | func (s *Linear) SetClamp(clamp bool) { 49 | s.Clamp = clamp 50 | } 51 | 52 | // ebase sanity checks and returns the "effective base" of this scale. 53 | // If s.Base is 0, it returns 10. If s.Base is 1 or negative, it 54 | // panics. 55 | func (s Linear) ebase() int { 56 | if s.Base == 0 { 57 | return 10 58 | } else if s.Base == 1 { 59 | panic("scale.Linear cannot have a base of 1") 60 | } else if s.Base < 0 { 61 | panic("scale.Linear cannot have a negative base") 62 | } 63 | return s.Base 64 | } 65 | 66 | // In the default base, the tick levels are: 67 | // 68 | // Level -2 is a major tick at -0.1, 0, 0.1, etc. 69 | // Level -1 is a major tick at -1, -0.5, 0, 0.5, 1, etc. 70 | // Level 0 is a major tick at -1, 0, 1, etc. 71 | // Level 1 is a major tick at -10, -5, 0, 5, 10, etc. 72 | // Level 2 is a major tick at -10, 0, 10, etc. 73 | // 74 | // That is, level 0 is unit intervals, and we alternate between 75 | // interval *= 5 and interval *= 2. Combined, these give us interval 76 | // *= 10 at every other level. 77 | // 78 | // In non-default bases, level 0 is the same and we alternate between 79 | // interval *= 1 (for consistency) and interval *= base. 80 | 81 | func (s *Linear) guessLevel() int { 82 | return 2 * int(math.Log(s.Max-s.Min)/math.Log(float64(s.ebase()))) 83 | } 84 | 85 | func (s *Linear) spacingAtLevel(level int, roundOut bool) (firstN, lastN, spacing float64) { 86 | // Watch out! Integer division is round toward zero, but we 87 | // need round down, and modulus is signed. 
88 | exp, double := math.Floor(float64(level)/2), (level%2 == 1 || level%2 == -1) 89 | spacing = math.Pow(float64(s.ebase()), exp) 90 | if double && s.Base == 0 { 91 | spacing *= 5 92 | } 93 | 94 | // Add a tiny bit of slack to the floor and ceiling below so 95 | // that rounding errors don't significantly affect tick marks. 96 | slack := (s.Max - s.Min) * 1e-10 97 | 98 | if roundOut { 99 | firstN = math.Floor((s.Min + slack) / spacing) 100 | lastN = math.Ceil((s.Max - slack) / spacing) 101 | } else { 102 | firstN = math.Ceil((s.Min - slack) / spacing) 103 | lastN = math.Floor((s.Max + slack) / spacing) 104 | } 105 | return 106 | } 107 | 108 | // CountTicks returns the number of ticks in [s.Min, s.Max] at the 109 | // given tick level. 110 | func (s Linear) CountTicks(level int) int { 111 | return linearTicker{&s, false}.CountTicks(level) 112 | } 113 | 114 | // TicksAtLevel returns the tick locations in [s.Min, s.Max] as a 115 | // []float64 at the given tick level in ascending order. 116 | func (s Linear) TicksAtLevel(level int) interface{} { 117 | return linearTicker{&s, false}.TicksAtLevel(level) 118 | } 119 | 120 | type linearTicker struct { 121 | s *Linear 122 | roundOut bool 123 | } 124 | 125 | func (t linearTicker) CountTicks(level int) int { 126 | firstN, lastN, _ := t.s.spacingAtLevel(level, t.roundOut) 127 | return int(lastN - firstN + 1) 128 | } 129 | 130 | func (t linearTicker) TicksAtLevel(level int) interface{} { 131 | firstN, lastN, spacing := t.s.spacingAtLevel(level, t.roundOut) 132 | n := int(lastN - firstN + 1) 133 | return vec.Linspace(firstN*spacing, lastN*spacing, n) 134 | } 135 | 136 | func (s Linear) Ticks(o TickOptions) (major, minor []float64) { 137 | if o.Max <= 0 { 138 | return nil, nil 139 | } else if s.Min == s.Max { 140 | return []float64{s.Min}, []float64{s.Min} 141 | } else if s.Min > s.Max { 142 | s.Min, s.Max = s.Max, s.Min 143 | } 144 | 145 | level, ok := o.FindLevel(linearTicker{&s, false}, s.guessLevel()) 146 | if !ok { 147 | 
return nil, nil 148 | } 149 | return s.TicksAtLevel(level).([]float64), s.TicksAtLevel(level - 1).([]float64) 150 | } 151 | 152 | func (s *Linear) Nice(o TickOptions) { 153 | if s.Min == s.Max { 154 | s.Min -= 0.5 155 | s.Max += 0.5 156 | } else if s.Min > s.Max { 157 | s.Min, s.Max = s.Max, s.Min 158 | } 159 | 160 | level, ok := o.FindLevel(linearTicker{s, true}, s.guessLevel()) 161 | if !ok { 162 | return 163 | } 164 | 165 | firstN, lastN, spacing := s.spacingAtLevel(level, true) 166 | s.Min = firstN * spacing 167 | s.Max = lastN * spacing 168 | } 169 | -------------------------------------------------------------------------------- /fit/lsquares.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package fit 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "strings" 11 | 12 | "gonum.org/v1/gonum/mat" 13 | ) 14 | 15 | // LinearLeastSquares computes the least squares fit for the function 16 | // 17 | // f(x) = Β₀terms₀(x) + Β₁terms₁(x) + ... 18 | // 19 | // to the data (xs[i], ys[i]). It returns the parameters Β₀, Β₁, ... 20 | // that minimize the sum of the squares of the residuals of f: 21 | // 22 | // ∑ (ys[i] - f(xs[i]))² 23 | // 24 | // If weights is non-nil, it is used to weight these residuals: 25 | // 26 | // ∑ weights[i] × (ys[i] - f(xs[i]))² 27 | // 28 | // The function f is specified by one Go function for each linear 29 | // term. For efficiency, the Go function is vectorized: it will be 30 | // passed a slice of x values in xs and must fill the slice termOut 31 | // with the value of the term for each value in xs. 32 | // 33 | // Note that this is called a "linear" least squares fit because the 34 | // fitted function is linear in the computed parameters. The function 35 | // need not be linear in x. 
// LinearLeastSquares returns the parameters Β₀, Β₁, ... of the least
// squares fit of f(x) = Β₀terms₀(x) + Β₁terms₁(x) + ... to the data
// (xs[i], ys[i]), optionally weighting each residual by weights[i].
//
// It panics if len(xs) != len(ys), or if weights is non-nil and
// len(xs) != len(weights).
func LinearLeastSquares(xs, ys, weights []float64, terms ...func(xs, termOut []float64)) (params []float64) {
	// The optimal parameters are found by solving for Β̂ in the
	// "normal equations":
	//
	//    (𝐗ᵀ𝐖𝐗)Β̂ = 𝐗ᵀ𝐖𝐲
	//
	// where 𝐖 is a diagonal weight matrix (or the identity matrix
	// for the unweighted case).

	// TODO: Consider using orthogonal decomposition.

	// TODO: Consider providing a multidimensional version of
	// this.

	if len(xs) != len(ys) {
		panic("len(xs) != len(ys)")
	}
	if weights != nil && len(xs) != len(weights) {
		panic("len(xs) != len(weights)")
	}

	// Construct 𝐗ᵀ. This is the more convenient representation
	// for efficiently calling the term functions: row i of 𝐗ᵀ is
	// the contiguous run of len(xs) values that term i fills in.
	xTVals := make([]float64, len(terms)*len(xs))
	for i, term := range terms {
		term(xs, xTVals[i*len(xs):i*len(xs)+len(xs)])
	}
	XT := mat.NewDense(len(terms), len(xs), xTVals)
	// NOTE(review): per the gonum docs T() is a transposed view of
	// XT rather than a copy, so XT must not be modified after this
	// point — confirm against the gonum version in use.
	X := XT.T()

	// Construct 𝐗ᵀ𝐖.
	var XTW *mat.Dense
	if weights == nil {
		// 𝐖 is the identity matrix.
		XTW = XT
	} else {
		// Since 𝐖 is a diagonal matrix, we do this directly.
		// Work on a copy so that scaling the rows below leaves
		// XT (and therefore X) untouched.
		XTW = mat.DenseCopyOf(XT)
		WDiag := mat.NewVecDense(len(weights), weights)
		for row := 0; row < len(terms); row++ {
			// Scale the row element-wise by the weights, in
			// place (presumably RowView aliases XTW's storage;
			// the result is never read from rowView itself).
			rowView := XTW.RowView(row).(*mat.VecDense)
			rowView.MulElemVec(rowView, WDiag)
		}
	}

	// Construct 𝐲.
	y := mat.NewVecDense(len(ys), ys)

	// Compute Β̂.
	lhs := mat.NewDense(len(terms), len(terms), nil)
	lhs.Mul(XTW, X)

	rhs := mat.NewVecDense(len(terms), nil)
	rhs.MulVec(XTW, y)

	// B is backed by BVals, so the solution is read back through
	// BVals after the solve.
	BVals := make([]float64, len(terms))
	B := mat.NewVecDense(len(terms), BVals)
	// NOTE(review): the error returned by SolveVec is discarded, so
	// a failed or poorly conditioned solve is not reported to the
	// caller.
	B.SolveVec(lhs, rhs)
	return BVals
}
We 101 | // have the terms functions, so we can construct F, though it won't be 102 | // very efficient. 103 | type PolynomialRegressionResult struct { 104 | // Coefficients is the coefficients of the fitted polynomial. 105 | // Coefficients[i] is the coefficient of the x^i term. 106 | Coefficients []float64 107 | 108 | // F evaluates the fitted polynomial at x. 109 | F func(x float64) float64 110 | } 111 | 112 | func (r PolynomialRegressionResult) String() string { 113 | var terms []string 114 | for pow, factor := range r.Coefficients { 115 | switch { 116 | case factor == 0: 117 | continue 118 | case pow == 0: 119 | terms = append(terms, fmt.Sprintf("%v", factor)) 120 | case pow == 1: 121 | terms = append(terms, fmt.Sprintf("%vx", factor)) 122 | default: 123 | terms = append(terms, fmt.Sprintf("%vx^%d", factor, pow)) 124 | } 125 | } 126 | if len(terms) == 0 { 127 | return "0" 128 | } 129 | return strings.Join(terms, "+") 130 | } 131 | 132 | // PolynomialRegression performs a least squares regression with a 133 | // polynomial of the given degree. If weights is non-nil, it is used 134 | // to weight the residuals. 135 | func PolynomialRegression(xs, ys, weights []float64, degree int) PolynomialRegressionResult { 136 | terms := make([]func(xs, termOut []float64), degree+1) 137 | terms[0] = func(xs, termsOut []float64) { 138 | for i := range termsOut { 139 | termsOut[i] = 1 140 | } 141 | } 142 | if degree >= 1 { 143 | terms[1] = func(xs, termOut []float64) { 144 | copy(termOut, xs) 145 | } 146 | } 147 | if degree >= 2 { 148 | terms[2] = func(xs, termOut []float64) { 149 | for i, x := range xs { 150 | termOut[i] = x * x 151 | } 152 | } 153 | } 154 | for d := 3; d < len(terms); d++ { 155 | d := d 156 | terms[d] = func(xs, termOut []float64) { 157 | for i, x := range xs { 158 | termOut[i] = math.Pow(x, float64(d+1)) 159 | } 160 | } 161 | } 162 | 163 | coeffs := LinearLeastSquares(xs, ys, weights, terms...) 
164 | f := func(x float64) float64 { 165 | y := coeffs[0] 166 | xp := x 167 | for _, c := range coeffs[1:] { 168 | y += xp * c 169 | xp *= x 170 | } 171 | return y 172 | } 173 | return PolynomialRegressionResult{coeffs, f} 174 | } 175 | -------------------------------------------------------------------------------- /scale/linear_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package scale 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/aclements/go-moremath/internal/mathtest" 12 | "github.com/aclements/go-moremath/vec" 13 | ) 14 | 15 | func TestLinear(t *testing.T) { 16 | l := Linear{Min: -10, Max: 10} 17 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 18 | map[float64]float64{ 19 | -20: -0.5, 20 | -10: 0, 21 | 0: 0.5, 22 | 10: 1, 23 | 20: 1.5, 24 | }) 25 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 26 | map[float64]float64{ 27 | -0.5: -20, 28 | 0: -10, 29 | 0.5: 0, 30 | 1: 10, 31 | 1.5: 20, 32 | }) 33 | 34 | l.SetClamp(true) 35 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 36 | map[float64]float64{ 37 | -20: 0, 38 | -10: 0, 39 | 0: 0.5, 40 | 10: 1, 41 | 20: 1, 42 | }) 43 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 44 | map[float64]float64{ 45 | 0: -10, 46 | 0.5: 0, 47 | 1: 10, 48 | }) 49 | 50 | l = Linear{Min: 5, Max: 5} 51 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 52 | map[float64]float64{ 53 | -10: 0.5, 54 | 0: 0.5, 55 | 10: 0.5, 56 | }) 57 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 58 | map[float64]float64{ 59 | 0: 5, 60 | 0.5: 5, 61 | 1: 5, 62 | }) 63 | } 64 | 65 | func ticksEq(major, wmajor, minor, wminor []float64) bool { 66 | // TODO: It would be nice to have a deep Aeq. 
It could also 67 | // support checking predicates like LE(5) or IsNaN within 68 | // structures, which could be used in WantFunc. Heck, deep Aeq 69 | // could subsume WantFunc where the left side is a function 70 | // and the right side is a map from arguments to results, but 71 | // maybe it would be harder to produce a good error message. 72 | if len(major) != len(wmajor) || len(minor) != len(wminor) { 73 | return false 74 | } 75 | for i, v := range major { 76 | if !mathtest.Aeq(wmajor[i], v) { 77 | return false 78 | } 79 | } 80 | for i, v := range minor { 81 | if !mathtest.Aeq(wminor[i], v) { 82 | return false 83 | } 84 | } 85 | return true 86 | } 87 | 88 | func TestLinearTicks(t *testing.T) { 89 | m := func(m int) TickOptions { 90 | return TickOptions{Max: m} 91 | } 92 | 93 | l := Linear{Min: 0, Max: 100} 94 | major, minor := l.Ticks(m(5)) 95 | wmajor, wminor := vec.Linspace(0, 100, 3), vec.Linspace(0, 100, 11) 96 | if !ticksEq(major, wmajor, minor, wminor) { 97 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 98 | } 99 | 100 | major, minor = l.Ticks(m(2)) 101 | wmajor, wminor = vec.Linspace(0, 100, 2), vec.Linspace(0, 100, 3) 102 | if !ticksEq(major, wmajor, minor, wminor) { 103 | t.Errorf("%v.Ticks(2) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 104 | } 105 | 106 | l.Nice(m(2)) 107 | major, minor = l.Ticks(m(2)) 108 | if !ticksEq(major, wmajor, minor, wminor) { 109 | t.Errorf("%v.Ticks(2) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 110 | } 111 | 112 | l = Linear{Min: 15.4, Max: 16.6} 113 | major, minor = l.Ticks(m(5)) 114 | wmajor, wminor = vec.Linspace(15.5, 16.5, 3), vec.Linspace(15.4, 16.6, 13) 115 | if !ticksEq(major, wmajor, minor, wminor) { 116 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 117 | } 118 | 119 | l.Nice(m(5)) 120 | major, minor = l.Ticks(m(5)) 121 | wmajor, wminor = vec.Linspace(15, 17, 5), vec.Linspace(15, 17, 21) 122 | if !ticksEq(major, wmajor, 
minor, wminor) { 123 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 124 | } 125 | 126 | // Test negative tick levels. 127 | l = Linear{Min: 9.9989, Max: 10} 128 | major, minor = l.Ticks(m(2)) 129 | wmajor, wminor = vec.Linspace(9.999, 10, 2), vec.Linspace(9.999, 10, 3) 130 | if !ticksEq(major, wmajor, minor, wminor) { 131 | t.Errorf("%v.Ticks(2) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 132 | } 133 | 134 | l.Nice(m(2)) 135 | major, minor = l.Ticks(m(2)) 136 | wmajor, wminor = vec.Linspace(9.995, 10, 2), vec.Linspace(9.995, 10, 6) 137 | if !ticksEq(major, wmajor, minor, wminor) { 138 | t.Errorf("%v.Ticks(2) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 139 | } 140 | 141 | // Test non-default bases. 142 | l = Linear{Min: 2, Max: 9, Base: 2} 143 | major, minor = l.Ticks(m(5)) 144 | wmajor, wminor = vec.Linspace(2, 8, 4), vec.Linspace(2, 9, 8) 145 | if !ticksEq(major, wmajor, minor, wminor) { 146 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 147 | } 148 | 149 | l.Nice(m(5)) 150 | major, minor = l.Ticks(m(5)) 151 | wmajor, wminor = vec.Linspace(2, 10, 5), vec.Linspace(2, 10, 9) 152 | if !ticksEq(major, wmajor, minor, wminor) { 153 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 154 | } 155 | 156 | // Test Min==Max. 
// Log is a scale that maps the input domain [Min, Max] onto the
// output range [0, 1] logarithmically. Use NewLog to construct one
// with validated arguments.
type Log struct {
	// NOTE(review): unexported, zero-size first field; presumably
	// present to prevent unkeyed composite literals — confirm.
	private struct{}

	// Min and Max specify the lower and upper bounds of the input
	// domain. The input domain [Min, Max] will be mapped to the
	// output range [0, 1]. The domain [Min, Max] must not include
	// 0.
	Min, Max float64

	// Base specifies the base of the logarithm for computing
	// ticks. Ticks will be placed at Base^((2^l)*n) for tick
	// level l ∈ ℕ and n ∈ ℤ. Typically l is 0, in which case this
	// is simply Base^n.
	Base int

	// If Clamp is true, the input is clamped to [Min, Max].
	Clamp bool

	// TODO: Let the user specify the minor ticks. Default to [1,
	// .. 9], but [1, 3] and [1, 2, 5] are common.
}
36 | func NewLog(min, max float64, base int) (Log, error) { 37 | if min > max { 38 | min, max = max, min 39 | } 40 | 41 | if base <= 1 { 42 | return Log{}, RangeErr("Log scale base must be 2 or more") 43 | } 44 | if min <= 0 && max >= 0 { 45 | return Log{}, RangeErr("Log scale range cannot include 0") 46 | } 47 | 48 | return Log{Min: min, Max: max, Base: base}, nil 49 | } 50 | 51 | func (s *Log) ebounds() (bool, float64, float64) { 52 | if s.Min < 0 { 53 | return true, -s.Max, -s.Min 54 | } 55 | return false, s.Min, s.Max 56 | } 57 | 58 | func (s Log) Map(x float64) float64 { 59 | neg, min, max := s.ebounds() 60 | if neg { 61 | x = -x 62 | } 63 | if x <= 0 { 64 | return math.NaN() 65 | } 66 | if min == max { 67 | return 0.5 68 | } 69 | 70 | logMin, logMax := math.Log(min), math.Log(max) 71 | y := (math.Log(x) - logMin) / (logMax - logMin) 72 | if neg { 73 | y = 1 - y 74 | } 75 | if s.Clamp { 76 | y = clamp(y) 77 | } 78 | return y 79 | } 80 | 81 | func (s Log) Unmap(y float64) float64 { 82 | neg, min, max := s.ebounds() 83 | if neg { 84 | y = 1 - y 85 | } 86 | logMin, logMax := math.Log(min), math.Log(max) 87 | x := math.Exp(y*(logMax-logMin) + logMin) 88 | if neg { 89 | x = -x 90 | } 91 | return x 92 | } 93 | 94 | func (s *Log) SetClamp(clamp bool) { 95 | s.Clamp = clamp 96 | } 97 | 98 | // The tick levels are: 99 | // 100 | // Level 0 is a major tick at Base^n (1, 10, 100, ...) 101 | // Level 1 is a major tick at Base^(2*n) (1, 100, 10000, ...) 102 | // Level 2 is a major tick at Base^(4*n) (1, 10000, 100000000, ...) 103 | // 104 | // That is, each level eliminates every other tick. Levels below 0 are 105 | // not defined. 106 | 107 | func logb(x float64, b float64) float64 { 108 | return math.Log(x) / math.Log(b) 109 | } 110 | 111 | func (s *Log) spacingAtLevel(level int, roundOut bool) (firstN, lastN, ebase float64) { 112 | _, min, max := s.ebounds() 113 | 114 | // Compute the effective base at this level. 
115 | ebase = math.Pow(float64(s.Base), math.Pow(2, float64(level))) 116 | lmin, lmax := logb(min, ebase), logb(max, ebase) 117 | 118 | // Add a tiny bit of slack to the floor and ceiling so that 119 | // rounding errors don't significantly affect tick marks. 120 | slack := (lmax - lmin) * 1e-10 121 | 122 | if roundOut { 123 | firstN = math.Floor(lmin + slack) 124 | lastN = math.Ceil(lmax - slack) 125 | } else { 126 | firstN = math.Ceil(lmin - slack) 127 | lastN = math.Floor(lmax + slack) 128 | } 129 | 130 | return 131 | } 132 | 133 | func (s *Log) CountTicks(level int) int { 134 | return logTicker{s, false}.CountTicks(level) 135 | } 136 | 137 | func (s *Log) TicksAtLevel(level int) interface{} { 138 | return logTicker{s, false}.TicksAtLevel(level) 139 | } 140 | 141 | type logTicker struct { 142 | s *Log 143 | roundOut bool 144 | } 145 | 146 | func (t logTicker) CountTicks(level int) int { 147 | if level < 0 { 148 | const maxInt = int(^uint(0) >> 1) 149 | return maxInt 150 | } 151 | 152 | firstN, lastN, _ := t.s.spacingAtLevel(level, t.roundOut) 153 | return int(lastN - firstN + 1) 154 | } 155 | 156 | func (t logTicker) TicksAtLevel(level int) interface{} { 157 | neg, min, max := t.s.ebounds() 158 | ticks := []float64{} 159 | 160 | if level < 0 { 161 | // Minor ticks for level 0. Get the major 162 | // ticks, but round out so we can fill in 163 | // minor ticks outside of the major ticks. 164 | firstN, lastN, _ := t.s.spacingAtLevel(0, true) 165 | for n := firstN; n <= lastN; n++ { 166 | tick := math.Pow(float64(t.s.Base), n) 167 | step := tick 168 | for i := 0; i < t.s.Base-1; i++ { 169 | if min <= tick && tick <= max { 170 | ticks = append(ticks, tick) 171 | } 172 | tick += step 173 | } 174 | } 175 | } else { 176 | firstN, lastN, base := t.s.spacingAtLevel(level, t.roundOut) 177 | for n := firstN; n <= lastN; n++ { 178 | ticks = append(ticks, math.Pow(base, n)) 179 | } 180 | } 181 | 182 | if neg { 183 | // Negate and reverse order of ticks. 
184 | for i := 0; i < (len(ticks)+1)/2; i++ { 185 | j := len(ticks) - i - 1 186 | ticks[i], ticks[j] = -ticks[j], -ticks[i] 187 | } 188 | } 189 | 190 | return ticks 191 | } 192 | 193 | func (s Log) Ticks(o TickOptions) (major, minor []float64) { 194 | if o.Max <= 0 { 195 | return nil, nil 196 | } else if s.Min == s.Max { 197 | return []float64{s.Min}, []float64{s.Max} 198 | } 199 | t := logTicker{&s, false} 200 | 201 | level, ok := o.FindLevel(t, 0) 202 | if !ok { 203 | return nil, nil 204 | } 205 | return t.TicksAtLevel(level).([]float64), t.TicksAtLevel(level - 1).([]float64) 206 | } 207 | 208 | func (s *Log) Nice(o TickOptions) { 209 | if s.Min == s.Max { 210 | return 211 | } 212 | neg, _, _ := s.ebounds() 213 | t := logTicker{s, true} 214 | 215 | level, ok := o.FindLevel(t, 0) 216 | if !ok { 217 | return 218 | } 219 | firstN, lastN, base := s.spacingAtLevel(level, true) 220 | s.Min = math.Pow(base, firstN) 221 | s.Max = math.Pow(base, lastN) 222 | if neg { 223 | s.Min, s.Max = -s.Max, -s.Min 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /cmd/dist/plot.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "math" 11 | "unicode/utf8" 12 | 13 | "github.com/aclements/go-moremath/scale" 14 | "github.com/aclements/go-moremath/stats" 15 | "github.com/aclements/go-moremath/vec" 16 | ) 17 | 18 | const ( 19 | // printSamples is the number of points on the X axis to 20 | // sample a function at for printing. 21 | printSamples = 500 22 | 23 | // printWidth is the width of the plot area in dots. 24 | printWidth = 70 * 2 25 | // printHeight is the height of the plot area in dots. 
26 | printHeight = 3 * 4 27 | 28 | printXMargin = 1 29 | printYMargin = 1 30 | ) 31 | 32 | // FprintPDF prints a Unicode representation of the PDF of each 33 | // distribution in dists to w. Multiple distributions are printed 34 | // stacked vertically and on the same X axis (but possibly different Y 35 | // axes). 36 | func FprintPDF(w io.Writer, dists ...stats.Dist) error { 37 | xscale, xs := commonScale(dists...) 38 | for _, d := range dists { 39 | if err := fprintFn(w, d.PDF, xscale, xs); err != nil { 40 | return err 41 | } 42 | } 43 | return fprintScale(w, xscale) 44 | } 45 | 46 | // FprintCDF is equivalent to FprintPDF, but prints the CDF of each 47 | // distribution. 48 | func FprintCDF(w io.Writer, dists ...stats.Dist) error { 49 | xscale, xs := commonScale(dists...) 50 | for _, d := range dists { 51 | if err := fprintFn(w, d.CDF, xscale, xs); err != nil { 52 | return err 53 | } 54 | } 55 | return fprintScale(w, xscale) 56 | } 57 | 58 | // makeScale creates a linear scale from [x1, x2) to [y1, y2). 
59 | func makeScale(x1, x2 float64, y1, y2 int) scale.QQ { 60 | return scale.QQ{ 61 | Src: &scale.Linear{Min: x1, Max: x2, Clamp: true}, 62 | Dest: &scale.Linear{Min: float64(y1), Max: float64(y2) - 1e-10}, 63 | } 64 | } 65 | 66 | func commonScale(dist ...stats.Dist) (xscale scale.QQ, xs []float64) { 67 | var l, h float64 68 | if len(dist) == 0 { 69 | l, h = -1, 1 70 | } else { 71 | l, h = dist[0].Bounds() 72 | for _, d := range dist[1:] { 73 | dl, dh := d.Bounds() 74 | l, h = math.Min(l, dl), math.Max(h, dh) 75 | } 76 | } 77 | xscale = makeScale(l, h, printXMargin, printWidth-printXMargin) 78 | //xscale.Src.Nice(10) 79 | src := xscale.Src.(*scale.Linear) 80 | xs = vec.Linspace(src.Min, src.Max, printSamples) 81 | return 82 | } 83 | 84 | func fprintScale(w io.Writer, sc scale.QQ) error { 85 | img := make([][]bool, printWidth) 86 | for i := range img { 87 | if i < printXMargin || i >= printWidth-printXMargin { 88 | img[i] = make([]bool, 2) 89 | } else { 90 | img[i] = []bool{true, false} 91 | } 92 | } 93 | major, _ := sc.Src.Ticks(scale.TickOptions{Max: 3}) 94 | labels := make([]string, len(major)) 95 | lpos := make([]int, len(major)) 96 | for i, tick := range major { 97 | x := int(sc.Map(tick)) 98 | img[x][1] = true 99 | // TODO: It would be nice if the scale could format 100 | // these ticks in a consistent way. 
101 | labels[i] = fmt.Sprintf("%g", tick) 102 | width := len(labels[i]) 103 | lpos[i] = minint(maxint(x/2-width/2, 0), (printWidth+1)/2-width) 104 | } 105 | if err := fprintImage(w, img, []string{""}); err != nil { 106 | return err 107 | } 108 | curpos := 0 109 | for i, label := range labels { 110 | gap := lpos[i] - curpos 111 | if i > 0 { 112 | gap = maxint(gap, 1) 113 | } 114 | _, err := fmt.Fprintf(w, "%*s%s", gap, "", label) 115 | if err != nil { 116 | return err 117 | } 118 | curpos += gap + len(label) 119 | } 120 | _, err := fmt.Fprintf(w, "\n") 121 | return err 122 | } 123 | 124 | func fprintFn(w io.Writer, fn func(float64) float64, xscale scale.QQ, xs []float64) error { 125 | ys := vec.Map(fn, xs) 126 | 127 | yl, yh := stats.Bounds(ys) 128 | if yl > 0 && yl-(yh-yl)*0.1 <= 0 { 129 | yl = 0 130 | } 131 | yscale := makeScale(yh, yl, printYMargin, printHeight-printYMargin) 132 | 133 | // Render the function to an image. 134 | img := make([][]bool, printWidth+2) 135 | for i := range img { 136 | img[i] = make([]bool, printHeight) 137 | } 138 | for i, x := range xs { 139 | img[int(xscale.Map(x))][int(yscale.Map(ys[i]))] = true 140 | } 141 | 142 | // Render Y axis. 
// fprintImage writes the bitmap img to w as lines of Unicode Braille
// characters.
//
// img is indexed as img[x][y]. Each output character encodes a cell
// 2 dots wide and 4 dots tall; dots outside img (or outside a short
// column) are treated as unset. The number of output lines is
// determined by the height of img[0]. trail[r], if present, is
// appended to output line r; lines beyond len(trail) get no trailer.
// An empty img produces no output.
//
// It returns the first error reported by w.
func fprintImage(w io.Writer, img [][]bool, trail []string) error {
	if len(img) == 0 {
		// There is no column to take the height from.
		return nil
	}

	var x, y int
	// bit reports the dot at offset (ox, oy) from the current cell
	// origin (x, y), as 0 or 1.
	bit := func(ox, oy int) byte {
		if x+ox < len(img) && y+oy < len(img[x+ox]) && img[x+ox][y+oy] {
			return 1
		}
		return 0
	}

	maxTrail := 0
	for _, s := range trail {
		if len(s) > maxTrail {
			maxTrail = len(s)
		}
	}
	// Each Braille rune encodes to 3 bytes of UTF-8; leave room for
	// the longest trailer and the newline.
	buf := make([]byte, 3*(len(img)+1)/2+maxTrail+1)
	for y = 0; y < len(img[0]); y += 4 {
		bufpos := 0
		for x = 0; x < len(img); x += 2 {
			// Grab the 2x4 cell of pixels and encode it
			// into a byte with the following bit layout:
			//   0 3
			//   1 4
			//   2 5
			//   6 7
			cell := bit(0, 0)<<0 | bit(1, 0)<<3
			cell |= bit(0, 1)<<1 | bit(1, 1)<<4
			cell |= bit(0, 2)<<2 | bit(1, 2)<<5
			cell |= bit(0, 3)<<6 | bit(1, 3)<<7
			// Translate cell into the Unicode Braille space.
			r := 0x2800 + rune(cell)
			bufpos += utf8.EncodeRune(buf[bufpos:], r)
		}
		if row := y / 4; row < len(trail) {
			bufpos += copy(buf[bufpos:], trail[row])
		}
		buf[bufpos] = '\n'
		if _, err := w.Write(buf[:bufpos+1]); err != nil {
			return err
		}
	}
	return nil
}
// maxint returns the larger of a and b.
func maxint(a, b int) int {
	if b > a {
		return b
	}
	return a
}

// minint returns the smaller of a and b.
func minint(a, b int) int {
	if b < a {
		return b
	}
	return a
}
:= make([]float64, n+1) 56 | t.Logf("normal approximation to B(%d,%v):", n, p) 57 | for i := range bs { 58 | bs[i] = norm.CDF(float64(i)+0.5) - norm.CDF(float64(i)-0.5) 59 | t.Logf(" %d | %v", i, bs[i]) 60 | } 61 | return bs 62 | } 63 | 64 | // Confidence is so low that it has to fall directly around 65 | // the quantile. 66 | binomBuckets(4, 0.5) // Just for logging 67 | res = QuantileCI(4, 0.5, 0.001) 68 | check(2, 3, 0.375, false) 69 | checkSample(2, 3) 70 | res = QuantileCI(4, 0.25, 0.001) 71 | check(1, 2, 0.421875, false) 72 | checkSample(1, 2) 73 | // Quantile near 0. 74 | res = QuantileCI(4, 0, 0.001) 75 | check(0, 1, 1, false) 76 | checkSample(-inf, 1) 77 | res = QuantileCI(4, 0.0001, 0.001) 78 | check(0, 1, binomBuckets(4, 0.0001)[0], false) 79 | // Quantile near 1. 80 | res = QuantileCI(4, 1, 0.001) 81 | check(4, 5, 1, false) 82 | checkSample(4, inf) 83 | res = QuantileCI(4, 0.999, 0.001) 84 | check(4, 5, binomBuckets(4, 0.999)[4], false) 85 | // Confidence is exactly the PMF. 86 | res = QuantileCI(4, 0.5, 0.375) 87 | check(2, 3, 0.375, false) 88 | // And just beyond the PMF. This should be left-biased. 89 | res = QuantileCI(4, 0.5, 0.3750001) 90 | check(1, 3, 0.375+0.25, true) 91 | // Confidence is 1 or nearly 1. 92 | res = QuantileCI(4, 0.5, 1) 93 | check(0, 5, 1, false) 94 | res = QuantileCI(4, 0.5, 0.99) 95 | check(0, 5, 1, false) 96 | // Confidence is enough to trim one bucket. This should be 97 | // left-biased. 98 | res = QuantileCI(4, 0.5, 0.99-0.0625) 99 | check(0, 4, 0.375+2*0.25+0.0625, true) 100 | 101 | // Odd sample size with very low confidence. This should be 102 | // left-biased. 103 | binomBuckets(5, 0.5) // Just for logging 104 | res = QuantileCI(5, 0.5, 0.001) 105 | check(2, 3, 0.3125, true) 106 | // Confidence is exactly the PMF. This should be left-biased. 107 | res = QuantileCI(5, 0.5, 0.3125) 108 | check(2, 3, 0.3125, true) 109 | // And just beyond the PMF. 
110 | res = QuantileCI(5, 0.5, 0.3125001) 111 | check(2, 4, 0.3125*2, false) 112 | // Confidence is 1 or nearly 1. 113 | res = QuantileCI(5, 0.5, 1) 114 | check(0, 6, 1, false) 115 | res = QuantileCI(5, 0.5, 0.99) 116 | check(0, 6, 1, false) 117 | // Confidence trims one bucket. 118 | res = QuantileCI(5, 0.5, 0.99-0.03125) 119 | check(0, 5, 1-0.03125, true) 120 | 121 | // Test normal approximation with even sample size. 122 | defer func(x int) { quantileCIApproxThreshold = x }(quantileCIApproxThreshold) 123 | quantileCIApproxThreshold = 0 124 | n := normBuckets(4, 0.5) 125 | // Low confidence directly around the quantile. 126 | res = QuantileCI(4, 0.5, 0.001) 127 | check(2, 3, n[2], false) 128 | // Confidence exactly equal to the center band. 129 | res = QuantileCI(4, 0.5, n[2]) 130 | check(2, 3, n[2], false) 131 | // And just above. This should be left-biased. 132 | res = QuantileCI(4, 0.5, n[2]+0.00001) 133 | check(1, 3, n[1]+n[2], true) 134 | // Confidence is 1. 135 | res = QuantileCI(4, 0.5, 1) 136 | check(0, 5, 1, false) 137 | // Confidence is nearly 1. Because of the approximation, we 138 | // have to drop fairly low before we lose a tail, so this is 139 | // still the full range. 140 | res = QuantileCI(4, 0.5, 0.99) 141 | check(0, 5, 1, false) 142 | // Confidence is low enough to lose the right-most band. This 143 | // should be left-biased. 144 | res = QuantileCI(4, 0.5, 0.90) 145 | check(0, 4, n[0]+n[1]+n[2]+n[3], true) 146 | 147 | // Test normal approximation with odd sample size. 148 | n = normBuckets(5, 0.5) 149 | // Low confidence directly around the quantile. Left-biased. 150 | res = QuantileCI(5, 0.5, 0.001) 151 | check(2, 3, n[2], true) 152 | // Confidence exactly equal to the mode band. Left-biased. 153 | res = QuantileCI(5, 0.5, n[2]) 154 | check(2, 3, n[2], true) 155 | // And just above. Symmetric. 156 | res = QuantileCI(5, 0.5, n[2]+0.00001) 157 | check(2, 4, n[2]+n[3], false) 158 | 159 | // Test normal approximation degenerate cases. 
160 | res = QuantileCI(5, 0, 0.95) // 0%ile 161 | check(0, 1, 1, false) 162 | res = QuantileCI(5, 0.001, 0.95) 163 | check(0, 1, 1, false) 164 | res = QuantileCI(5, 1, 0.95) // 100%ile 165 | check(5, 6, 1, false) 166 | res = QuantileCI(5, 0.999, 0.95) 167 | check(5, 6, 1, false) 168 | } 169 | 170 | func BenchmarkQuantileCI(b *testing.B) { 171 | defer func(x int) { quantileCIApproxThreshold = x }(quantileCIApproxThreshold) 172 | for n := 5; n <= 100; n += 5 { 173 | for _, approx := range []bool{false, true} { 174 | if approx { 175 | quantileCIApproxThreshold = 0 176 | } else { 177 | quantileCIApproxThreshold = 1000 178 | } 179 | 180 | b.Run(fmt.Sprintf("n=%d/approx=%v", n, approx), func(b *testing.B) { 181 | for i := 0; i < b.N; i++ { 182 | QuantileCI(n, 0.5, 0.95) 183 | } 184 | }) 185 | } 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /scale/log_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package scale 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "testing" 11 | 12 | "github.com/aclements/go-moremath/internal/mathtest" 13 | "github.com/aclements/go-moremath/vec" 14 | ) 15 | 16 | func TestLog(t *testing.T) { 17 | l, err := NewLog(0, 10, 10) 18 | if _, ok := err.(RangeErr); !ok { 19 | t.Errorf("want RangeErr; got %v", err) 20 | } 21 | l, err = NewLog(-10, 0, 10) 22 | if _, ok := err.(RangeErr); !ok { 23 | t.Errorf("want RangeErr; got %v", err) 24 | } 25 | l, err = NewLog(-10, 10, 10) 26 | if _, ok := err.(RangeErr); !ok { 27 | t.Errorf("want RangeErr; got %v", err) 28 | } 29 | l, err = NewLog(10, 20, 0) 30 | if _, ok := err.(RangeErr); !ok { 31 | t.Errorf("want RangeErr; got %v", err) 32 | } 33 | 34 | l, _ = NewLog(1, 10, 10) 35 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 36 | map[float64]float64{ 37 | -1: math.NaN(), 38 | 0: math.NaN(), 39 | 0.1: -1, 40 | 1: 0, 41 | math.Pow(10, 0.5): 0.5, 42 | 10: 1, 43 | 100: 2, 44 | }) 45 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 46 | map[float64]float64{ 47 | -1: 0.1, 48 | 0: 1, 49 | 0.5: math.Pow(10, 0.5), 50 | 1: 10, 51 | 2: 100, 52 | }) 53 | 54 | l.SetClamp(true) 55 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 56 | map[float64]float64{ 57 | -1: math.NaN(), 58 | 0: math.NaN(), 59 | 0.1: 0, 60 | 1: 0, 61 | math.Pow(10, 0.5): 0.5, 62 | 10: 1, 63 | 100: 1, 64 | }) 65 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 66 | map[float64]float64{ 67 | 0: 1, 68 | 0.5: math.Pow(10, 0.5), 69 | 1: 10, 70 | }) 71 | 72 | l, _ = NewLog(-1, -10, 10) 73 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 74 | map[float64]float64{ 75 | 1: math.NaN(), 76 | 0: math.NaN(), 77 | -0.1: 2, 78 | -1: 1, 79 | -math.Pow(10, 0.5): 0.5, 80 | -10: 0, 81 | -100: -1, 82 | }) 83 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 84 | map[float64]float64{ 85 | 2: -0.1, 86 | 1: -1, 87 | 0.5: -math.Pow(10, 0.5), 88 | 0: -10, 89 | -1: -100, 90 | }) 91 | 92 | l, _ = NewLog(5, 5, 
10) 93 | mathtest.WantFunc(t, fmt.Sprintf("%v.Map", l), l.Map, 94 | map[float64]float64{ 95 | -1: math.NaN(), 96 | 0: math.NaN(), 97 | 1: 0.5, 98 | 10: 0.5, 99 | }) 100 | mathtest.WantFunc(t, fmt.Sprintf("%v.Unmap", l), l.Unmap, 101 | map[float64]float64{ 102 | 0: 5, 103 | 0.5: 5, 104 | 1: 5, 105 | }) 106 | } 107 | 108 | func TestLogTicks(t *testing.T) { 109 | m := func(m int) TickOptions { 110 | return TickOptions{Max: m} 111 | } 112 | 113 | // Test the obvious. 114 | l, _ := NewLog(1, 10, 10) 115 | major, minor := l.Ticks(m(5)) 116 | wmajor, wminor := vec.Logspace(0, 1, 2, 10), vec.Linspace(1, 10, 10) 117 | if !ticksEq(major, wmajor, minor, wminor) { 118 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 119 | } 120 | 121 | // Test two orders of magnitude. 122 | l, _ = NewLog(1, 100, 10) 123 | major, minor = l.Ticks(m(5)) 124 | wmajor, wminor = vec.Logspace(0, 2, 3, 10), vec.Concat(vec.Linspace(1, 9, 9), vec.Linspace(10, 100, 10)) 125 | if !ticksEq(major, wmajor, minor, wminor) { 126 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 127 | } 128 | 129 | // Test many orders of magnitude (higher tick levels). 130 | l, _ = NewLog(1, 1e8, 10) 131 | major, minor = l.Ticks(m(5)) 132 | wmajor, wminor = vec.Logspace(0, 4, 5, 100), vec.Logspace(0, 8, 9, 10) 133 | if !ticksEq(major, wmajor, minor, wminor) { 134 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 135 | } 136 | 137 | major, minor = l.Ticks(m(4)) 138 | wmajor, wminor = vec.Logspace(0, 2, 3, 10000), vec.Logspace(0, 4, 5, 100) 139 | if !ticksEq(major, wmajor, minor, wminor) { 140 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 141 | } 142 | 143 | // Test minor ticks outside major ticks. 
144 | l, _ = NewLog(0.91, 200, 10) 145 | major, minor = l.Ticks(m(5)) 146 | wmajor, wminor = vec.Logspace(0, 2, 3, 10), vec.Concat(vec.Linspace(1, 9, 9), vec.Linspace(10, 100, 10), []float64{200}) 147 | if !ticksEq(major, wmajor, minor, wminor) { 148 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 149 | } 150 | 151 | // Test nicing. 152 | l.Nice(m(5)) 153 | major, minor = l.Ticks(m(5)) 154 | wmajor, wminor = vec.Logspace(-1, 3, 5, 10), vec.Concat(vec.Linspace(0.1, 0.9, 9), vec.Linspace(1, 9, 9), vec.Linspace(10, 90, 9), vec.Linspace(100, 1000, 10)) 155 | if !ticksEq(major, wmajor, minor, wminor) { 156 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 157 | } 158 | 159 | // Test negative ticks. 160 | neg := vec.Vectorize(func(x float64) float64 { return -x }) 161 | l, _ = NewLog(-1, -100, 10) 162 | major, minor = l.Ticks(m(5)) 163 | wmajor, wminor = neg(vec.Logspace(2, 0, 3, 10)), neg(vec.Concat(vec.Linspace(100, 10, 10), vec.Linspace(9, 1, 9))) 164 | if !ticksEq(major, wmajor, minor, wminor) { 165 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 166 | } 167 | 168 | major, minor = l.Ticks(m(2)) 169 | wmajor, wminor = neg(vec.Logspace(1, 0, 2, 100)), neg(vec.Logspace(2, 0, 3, 10)) 170 | if !ticksEq(major, wmajor, minor, wminor) { 171 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 172 | } 173 | 174 | l.Nice(m(5)) 175 | major, minor = l.Ticks(m(5)) 176 | wmajor, wminor = neg(vec.Logspace(2, 0, 3, 10)), neg(vec.Concat(vec.Linspace(100, 10, 10), vec.Linspace(9, 1, 9))) 177 | if !ticksEq(major, wmajor, minor, wminor) { 178 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 179 | } 180 | 181 | // Test Min==Max. 
182 | l, _ = NewLog(5, 5, 10) 183 | major, minor = l.Ticks(m(5)) 184 | wmajor, wminor = []float64{5}, []float64{5} 185 | if !ticksEq(major, wmajor, minor, wminor) { 186 | t.Errorf("%v.Ticks(5) = %v, %v; want %v, %v", l, major, minor, wmajor, wminor) 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /graph/graphalg/scc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package graphalg 6 | 7 | import ( 8 | "sort" 9 | 10 | "github.com/aclements/go-moremath/graph" 11 | ) 12 | 13 | // SCCFlags is a set of optional analyses to perform when constructing 14 | // strongly-connected components. 15 | type SCCFlags int 16 | 17 | const ( 18 | // SCCSubnodeComponent instructs SCC to record a mapping from 19 | // subnode to component ID containing that subnode. 20 | SCCSubnodeComponent SCCFlags = 1 << iota 21 | 22 | // SCCEdges instructs SCC to record edges between components. 23 | // Otherwise, the resulting SCC graph will have a node for 24 | // each strongly-connected component, but no edges. 25 | SCCEdges 26 | ) 27 | 28 | // SCC computes the strongly-connected components of graph g. 29 | // 30 | // This implements Tarjan's strongly connected components algorithm 31 | // [1]. It runs in O(V + E) time and O(V) space. 32 | // 33 | // [1] Tarjan, R. E. (1972), "Depth-first search and linear graph 34 | // algorithms", SIAM Journal on Computing, 1 (2): 146–160. 35 | func SCC(g graph.Graph, flags SCCFlags) *SCCGraph { 36 | var sccs SCCGraph 37 | 38 | if flags&SCCEdges != 0 { 39 | // Edge construction requires sub-graph ID -> 40 | // component ID mapping. 
41 | flags |= SCCSubnodeComponent 42 | } 43 | 44 | // This is based on the presentation of Tarjan's algorithm in 45 | // Sedgewick, Algorithms in C, Part 5, Third Edition, p. 202. 46 | // This is a fair bit simpler than Tarjan's original 47 | // presentation. We further simplify it by combining "pre" and 48 | // "low" into just "low", since pre is only ever used as a 49 | // visited mark. We instead start node indexing at 1 and use 50 | // low[x] == 0 to indicate node x has not been visited. 51 | numNodes := g.NumNodes() 52 | // For low, 0 means "not visited", ^uint(0) means "processed". 53 | low := make([]uint, numNodes) 54 | stack := []int{} 55 | index := uint(1) 56 | 57 | // We construct out-edges of each component by maintaining a 58 | // parallel stack of seen out-edges. As we pop a component off 59 | // the primary stack, we match it with this sack by recording 60 | // the length of the stack when an edge was pushed. 61 | type outEdge struct { 62 | cid int 63 | stackLen int 64 | } 65 | var out []outEdge 66 | 67 | var connect func(nid int) 68 | connect = func(nid int) { 69 | // Set the depth of n to the next unused index. 70 | low[nid] = index 71 | min := index 72 | index++ 73 | stackPos := len(stack) 74 | stack = append(stack, nid) 75 | 76 | // Process successors of n. 77 | for _, oid := range g.Out(nid) { 78 | if low[oid] == 0 { 79 | // Successor has not yet been visited. 80 | connect(oid) 81 | } 82 | if low[oid] < min { 83 | min = low[oid] 84 | } 85 | 86 | if flags&SCCEdges != 0 && low[oid] == ^uint(0) { 87 | // Successor is in another component. 88 | // Record the out-edge from this 89 | // component. 90 | cid := sccs.subnodeComponent[oid] 91 | out = append(out, outEdge{cid, stackPos}) 92 | } 93 | } 94 | 95 | if min < low[nid] { 96 | // Node n is not the root of an SCC. 97 | low[nid] = min 98 | return 99 | } 100 | 101 | // Node n is a root of an SCC. Pop the stack to 102 | // construct the component. 
103 | cid := len(sccs.subnodeIndexes) 104 | var i int 105 | for i = len(stack) - 1; i >= 0; i-- { 106 | oid := stack[i] 107 | // Set low such that it can never be less than 108 | // min. This also indicates we've connected 109 | // oid to a component. 110 | low[oid] = ^uint(0) 111 | if flags&SCCSubnodeComponent != 0 { 112 | sccs.subnodeComponent[oid] = cid 113 | } 114 | if oid == nid { 115 | break 116 | } 117 | } 118 | sccs.subnodeIndexes = append(sccs.subnodeIndexes, len(sccs.subnodes)) 119 | sccs.subnodes = append(sccs.subnodes, stack[i:]...) 120 | stack = stack[:i] 121 | 122 | // Collect out-edges of this SCC. 123 | if flags&SCCEdges != 0 { 124 | outStart := len(sccs.out) 125 | sccs.outIndexes = append(sccs.outIndexes, outStart) 126 | // Pop the out-edge stack until it 127 | // aligns with the node stack. 128 | for i = len(out) - 1; i >= 0; i-- { 129 | if out[i].stackLen < len(stack) { 130 | break 131 | } 132 | sccs.out = append(sccs.out, out[i].cid) 133 | } 134 | i++ 135 | out = out[:i] 136 | // Dedup component IDs. 137 | sort.Ints(sccs.out[outStart:]) 138 | i = outStart 139 | for j := outStart; j < len(sccs.out); j++ { 140 | if i == outStart || sccs.out[i-1] != sccs.out[j] { 141 | sccs.out[i] = sccs.out[j] 142 | i++ 143 | } 144 | } 145 | sccs.out = sccs.out[:i] 146 | } 147 | } 148 | 149 | sccs.subnodes = make([]int, 0, numNodes) 150 | if flags&SCCSubnodeComponent != 0 { 151 | sccs.subnodeComponent = make([]int, numNodes) 152 | } 153 | 154 | for nid := range low { 155 | // If node n is not yet visited, then connect it. 156 | if low[nid] == 0 { 157 | connect(nid) 158 | } 159 | } 160 | sccs.subnodeIndexes = append(sccs.subnodeIndexes, len(sccs.subnodes)) 161 | if flags&SCCEdges != 0 { 162 | sccs.outIndexes = append(sccs.outIndexes, len(sccs.out)) 163 | } 164 | 165 | return &sccs 166 | } 167 | 168 | // SCCGraph is a set of strongly-connected components of another 169 | // graph. 170 | // 171 | // Each strongly-connected component is a node in this graph. 
// SCCGraph holds the strongly-connected components of another graph.
//
// Every component is one node of this graph, and components are
// numbered in reverse topological sort order.
//
// When built with flag SCCEdges, the graph additionally has an edge
// between two components wherever the underlying graph has an edge
// between their members.
type SCCGraph struct {
	subnodes       []int // Concatenated list of sub-graph nodes in each component
	subnodeIndexes []int // Component ID -> subnodes base index

	subnodeComponent []int // Sub-node ID -> component ID

	out        []int // Concatenated list of out-edges of each component
	outIndexes []int // Component ID -> out base index
}

// Subnodes returns the IDs of the underlying-graph nodes that make up
// component cid. The result is a sub-slice of g's internal storage.
func (g *SCCGraph) Subnodes(cid int) []int {
	start, end := g.subnodeIndexes[cid], g.subnodeIndexes[cid+1]
	return g.subnodes[start:end]
}

// SubnodeComponent maps subID, a node ID in the underlying graph, to
// the ID of the component (a node ID in g) that contains it.
//
// It panics unless g was constructed with flag SCCSubnodeComponent.
func (g *SCCGraph) SubnodeComponent(subID int) (componentID int) {
	if mapping := g.subnodeComponent; mapping != nil {
		return mapping[subID]
	}
	panic("SCCGraph constructed without SCCSubnodeComponent flag")
}

// NumNodes returns the number of strongly-connected components in g.
func (g *SCCGraph) NumNodes() int {
	// subnodeIndexes carries one trailing terminator entry.
	entries := len(g.subnodeIndexes)
	return entries - 1
}

// Out returns the IDs of the components reachable by at least one
// underlying-graph edge from component cid.
//
// It returns nil when no out-edges were recorded, which is always the
// case if g was constructed without flag SCCEdges.
func (g *SCCGraph) Out(cid int) []int {
	if g.out != nil {
		start, end := g.outIndexes[cid], g.outIndexes[cid+1]
		return g.out[start:end]
	}
	return nil
}
// A Dist is a continuous statistical distribution. In addition to the
// CDF and Bounds methods of DistCommon, it exposes its probability
// density function.
type Dist interface {
	DistCommon

	// PDF returns the value of the probability density function
	// of this distribution at x.
	PDF(x float64) float64
}

// A DiscreteDist is a discrete statistical distribution. In addition
// to the CDF and Bounds methods of DistCommon, it exposes its
// probability mass function and the spacing of its defined points.
//
// Most discrete distributions are defined only at integral values of
// the random variable. However, some are defined at other intervals,
// so this interface takes a float64 value for the random variable.
// The probability mass function rounds down to the nearest defined
// point. Note that float64 values can exactly represent integer
// values between ±2**53, so this generally shouldn't be an issue for
// integer-valued distributions (likewise, for half-integer-valued
// distributions, float64 can exactly represent all values between
// ±2**52).
type DiscreteDist interface {
	DistCommon

	// PMF returns the value of the probability mass function
	// Pr[X = x'], where x' is x rounded down to the nearest
	// defined point on the distribution.
	//
	// Note for implementers: for integer-valued distributions,
	// round x using int(math.Floor(x)). Do not use int(x), since
	// that truncates toward zero (unless all x <= 0 are handled
	// the same).
	PMF(x float64) float64

	// Step returns s, where the distribution is defined for sℕ;
	// that is, its defined points are integer multiples of s.
	Step() float64
}

// TODO: Add a Support method for finite support distributions? Or
// maybe just another return value from Bounds indicating that the
// bounds are exact?

// TODO: Plot method to return a pre-configured Plot object with
// reasonable bounds and an integral function? Have to distinguish
// PDF/CDF/InvCDF. Three methods? Argument?
//
// Doesn't have to be a method of Dist.
Could be just a function that 88 | // takes a Dist and uses Bounds. 89 | 90 | // InvCDF returns the inverse CDF function of the given distribution 91 | // (also known as the quantile function or the percent point 92 | // function). This is a function f such that f(dist.CDF(x)) == x. If 93 | // dist.CDF is only weakly monotonic (that it, there are intervals 94 | // over which it is constant) and y > 0, f returns the smallest x that 95 | // satisfies this condition. In general, the inverse CDF is not 96 | // well-defined for y==0, but for convenience if y==0, f returns the 97 | // largest x that satisfies this condition. For distributions with 98 | // infinite support both the largest and smallest x are -Inf; however, 99 | // for distributions with finite support, this is the lower bound of 100 | // the support. 101 | // 102 | // If y < 0 or y > 1, f returns NaN. 103 | // 104 | // If dist implements InvCDF(float64) float64, this returns that 105 | // method. Otherwise, it returns a function that uses a generic 106 | // numerical method to construct the inverse CDF at y by finding x 107 | // such that dist.CDF(x) == y. This may have poor precision around 108 | // points of discontinuity, including f(0) and f(1). 109 | func InvCDF(dist DistCommon) func(y float64) (x float64) { 110 | type invCDF interface { 111 | InvCDF(float64) float64 112 | } 113 | if dist, ok := dist.(invCDF); ok { 114 | return dist.InvCDF 115 | } 116 | 117 | // Otherwise, use a numerical algorithm. 118 | // 119 | // TODO: For discrete distributions, use the step size to 120 | // inform this computation. 
121 | return func(y float64) (x float64) { 122 | const almostInf = 1e100 123 | const xtol = 1e-16 124 | 125 | if y < 0 || y > 1 { 126 | return nan 127 | } else if y == 0 { 128 | l, _ := dist.Bounds() 129 | if dist.CDF(l) == 0 { 130 | // Finite support 131 | return l 132 | } else { 133 | // Infinite support 134 | return -inf 135 | } 136 | } else if y == 1 { 137 | _, h := dist.Bounds() 138 | if dist.CDF(h) == 1 { 139 | // Finite support 140 | return h 141 | } else { 142 | // Infinite support 143 | return inf 144 | } 145 | } 146 | 147 | // Find loX, hiX for which cdf(loX) < y <= cdf(hiX). 148 | var loX, loY, hiX, hiY float64 149 | x1, y1 := 0.0, dist.CDF(0) 150 | xdelta := 1.0 151 | if y1 < y { 152 | hiX, hiY = x1, y1 153 | for hiY < y && hiX != inf { 154 | loX, loY, hiX = hiX, hiY, hiX+xdelta 155 | hiY = dist.CDF(hiX) 156 | xdelta *= 2 157 | } 158 | } else { 159 | loX, loY = x1, y1 160 | for y <= loY && loX != -inf { 161 | hiX, hiY, loX = loX, loY, loX-xdelta 162 | loY = dist.CDF(loX) 163 | xdelta *= 2 164 | } 165 | } 166 | if loX == -inf { 167 | return loX 168 | } else if hiX == inf { 169 | return hiX 170 | } 171 | 172 | // Use bisection on the interval to find the smallest 173 | // x at which cdf(x) <= y. 174 | _, x = bisectBool(func(x float64) bool { 175 | return dist.CDF(x) < y 176 | }, loX, hiX, xtol) 177 | return 178 | } 179 | } 180 | 181 | // Rand returns a random number generator that draws from the given 182 | // distribution. The returned generator takes an optional source of 183 | // randomness; if this is nil, it uses the default global source. 184 | // 185 | // If dist implements Rand(*rand.Rand) float64, Rand returns that 186 | // method. Otherwise, it returns a generic generator based on dist's 187 | // inverse CDF (which may in turn use an efficient implementation or a 188 | // generic numerical implementation; see InvCDF). 
189 | func Rand(dist DistCommon) func(*rand.Rand) float64 { 190 | type distRand interface { 191 | Rand(*rand.Rand) float64 192 | } 193 | if dist, ok := dist.(distRand); ok { 194 | return dist.Rand 195 | } 196 | 197 | // Otherwise, use a generic algorithm. 198 | inv := InvCDF(dist) 199 | return func(r *rand.Rand) float64 { 200 | var y float64 201 | for y == 0 { 202 | if r == nil { 203 | y = rand.Float64() 204 | } else { 205 | y = r.Float64() 206 | } 207 | } 208 | return inv(y) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /stats/quantileci.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package stats 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | ) 11 | 12 | // QuantileCIResult is the confidence interval for a quantile. 13 | type QuantileCIResult struct { 14 | // Quantile is the quantile of this confidence interval. This 15 | // is simply a copy of the argument to QuantileCI. 16 | Quantile float64 17 | 18 | // N is the sample size. 19 | N int 20 | 21 | // Confidence is the actual confidence level of this interval. 22 | // This will be >= the requested confidence. 23 | Confidence float64 24 | 25 | // LoOrder and HiOrder are the order statistics that bound the 26 | // confidence interval. By convention, these are 1-based, so 27 | // given an ordered slice of samples Xs, the CI is 28 | // Xs[LoOrder-1] to Xs[HiOrder-1]. 29 | // 30 | // These may be outside the range of the sample, which 31 | // indicates that corresponding bound is negative or positive 32 | // infinity. This can happen, for example, if the sample is 33 | // too small for a high confidence level, or the quantile is 34 | // close to 0 or 1. 
// QuantileCI returns the bounds of the confidence interval of the
// q'th quantile in a sample of size n.
//
// n is the sample size, q is the quantile (used directly as the
// success probability of a binomial distribution, so presumably in
// [0, 1]; this is not checked here), and confidence is the requested
// confidence level. A confidence of 1 or more yields the full range
// LoOrder=0, HiOrder=n+1 with Confidence 1.
//
// For n <= quantileCIApproxThreshold the interval is built from exact
// binomial probabilities; above that a normal approximation with
// continuity correction is used. In both cases the interval is
// left-biased when there is a tie, Ambiguous records such ties, and
// the returned LoOrder and HiOrder are clamped to [0, n+1].
func QuantileCI(n int, q, confidence float64) QuantileCIResult {
	// debug enables tracing of the computation to standard output.
	const debug = false

	var res QuantileCIResult
	res.N = n
	res.Quantile = q

	// Full confidence can only be given by the whole range.
	if confidence >= 1 {
		res.Confidence = 1
		res.LoOrder = 0
		res.HiOrder = n + 1
		return res
	}

	if debug {
		fmt.Printf("QuantileCI(%v, %v, %v)\n", n, q, confidence)
	}

	// There's a dearth of good information online about how to
	// compute this, especially in corner cases. Some useful
	// online resources:
	//
	// https://online.stat.psu.edu/stat415/book/export/html/835 -
	// The concept of intervals, some worked examples.
	//
	// http://www.milefoot.com/math/stat/ci-medians.htm - Good
	// walk through of summing up binomial probabilities,
	// continuity correction for the normal approximation.

	// The sampling distribution for order statistics is the
	// binomial distribution. In this distribution, k is how many
	// samples come before the population quantile; or,
	// alternatively, an index into the intervals between samples
	// (where 0 is the interval from -∞ to the first sample).
	// Hence, PMF(k) gives the probability that the population
	// quantile falls in interval k, or between s.Xs[k-1] and
	// s.Xs[k].
	samp := BinomialDist{N: n, P: q}

	// l and r are the left and right order statistics of the
	// confidence interval.
	var l, r int
	if samp.N <= quantileCIApproxThreshold {
		if debug {
			for i := 0; i <= samp.N; i++ {
				fmt.Printf(" %d | %v\n", i, samp.PMF(float64(i)))
			}
		}

		// Start with the mode and accumulate probabilities in
		// decreasing order until we pass the confidence
		// level. This uses the fact that the probabilities
		// decrease monotonically as you move out from the
		// mode.
		//
		// The binomial distribution can have two equal modes.
		// Since we want to left-bias our result, we start
		// with the lower of the two.
		x := int(math.Ceil(float64(samp.N+1)*samp.P) - 1)
		if samp.P == 0 { // Special case of the mode
			x = 0
		}
		accum := samp.PMF(float64(x))
		if debug {
			fmt.Printf(" start %d => %v\n", x, accum)
		}

		// Compute the neighboring probabilities so we can
		// incrementally add and update them. [l, r) is the
		// interval we've summed; lp and rp are the
		// probabilities just outside it on each side.
		l, r = x, x+1
		lp, rp := samp.PMF(float64(l-1)), samp.PMF(float64(r))
		// If the binomial distribution has two modes, then
		// our initial selection is ambiguous.
		res.Ambiguous = rp == accum

		// Accumulate probabilities to reach the desired
		// confidence level. We defend against accumulation
		// errors by stopping if there's no more to
		// accumulate.
		//
		// For the particular case of q=0.5, the distribution
		// is symmetric and we could just use InvCDF like we
		// do in the normal approximation. But that doesn't
		// generalize to other quantiles, and InvCDF isn't
		// particularly efficient on the binomial distribution
		// anyway.
		for accum < confidence && (lp > 0 || rp > 0) {
			// A tie between the two candidates means the
			// step we are about to take could equally have
			// gone the other way. Only the last step taken
			// determines the final value.
			res.Ambiguous = lp == rp
			if lp >= rp { // Left-bias
				accum += lp
				if debug {
					fmt.Printf(" +left %d => %v\n", l-1, accum)
				}
				l--
				lp = samp.PMF(float64(l - 1))
			} else {
				accum += rp
				if debug {
					fmt.Printf(" +right %d => %v\n", r, accum)
				}
				r++
				rp = samp.PMF(float64(r))
			}
		}
		res.Confidence = accum

		if debug {
			fmt.Printf(" final [%d,%d) => %v (ambiguous %v)\n", l, r, accum, res.Ambiguous)
		}
	} else {
		// Use the normal approximation.
		norm := samp.NormalApprox()
		alpha := (1 - confidence) / 2

		// Find the center "confidence" weight of the
		// distribution.
		l1 := norm.InvCDF(alpha)
		r1 := 2*norm.Mu - l1 // Symmetric around mean.

		// Find the band of the discrete binomial distribution
		// containing [l1, r1]. Because of the continuity
		// correction, point k in the binomial distribution
		// corresponds to band [k-0.5, k+0.5] in the normal
		// distribution. Hence, we round out to ℕ + 0.5
		// boundaries and then recover k.
		//
		// For example, let's say mu=2 and confidence is
		// really low. If [l1, r1] is [1.9, 2.1], that rounds
		// out to [1.5, 2.5], which is the band [2, 3) in the
		// binomial distribution. But if [l1, r1] is [1.4,
		// 2.6], that rounds out to [0.5, 3.5], which is the
		// band [1, 4) in the binomial distribution.
		floorInt := func(x float64) int {
			// int(x) truncates toward 0, so floor first.
			return int(math.Floor(x))
		}
		l = floorInt(math.Floor(l1-0.5)+0.5) + 1
		r = floorInt(math.Ceil(r1-0.5)+0.5) + 1

		if debug {
			fmt.Printf(" [%v,%v] rounds to [%v,%v]\n", l1, r1, l, r)
		}

		// The actual confidence on the binomial
		// distribution is
		//
		//   Pr[l <= X < r] = Pr[X <= r - 1] - Pr[X <= l - 1]
		//
		// To translate this into the normal
		// approximation, we add 0.5 to each bound for
		// the continuity correction, giving r - 0.5 and
		// l - 0.5.
		cdf := func(l, r int) float64 {
			return norm.CDF(float64(r)-0.5) - norm.CDF(float64(l)-0.5)
		}
		res.Confidence = cdf(l, r)
		// The computed interval is always symmetric.
		// Try left-biasing it (dropping the right-most
		// band) and see if we can do better while still
		// satisfying the confidence level.
		rBiased := r - 1
		if debug {
			fmt.Printf(" unbiased %v, biased %v\n", res.Confidence, cdf(l, rBiased))
		}
		if aBiased := cdf(l, rBiased); aBiased >= confidence && aBiased < res.Confidence {
			if debug {
				fmt.Printf(" taking biased\n")
			}
			res.Confidence, res.Ambiguous = aBiased, true
			r = rBiased
		}
		if l <= 0 && r >= n+1 {
			// The CI covers everything, but
			// because the normal distribution has
			// infinite support, the confidence
			// computed by CDF won't be quite 1.
			// Certainly the quantile falls between
			// -inf and +inf. This can happen even
			// in the biasing case, so we check
			// this in any case.
			if debug {
				fmt.Printf(" adjusting for full range\n")
			}
			res.Confidence = 1
			res.Ambiguous = false
		}
	}

	// Clamp to the valid order-statistic range: 0 stands for -inf
	// and n+1 for +inf.
	if l < 0 {
		l = 0
	}
	if r > n+1 {
		r = n + 1
	}
	res.LoOrder, res.HiOrder = l, r
	return res
}
// A MannWhitneyUTestResult is the result of a Mann-Whitney U-test, as
// returned by MannWhitneyUTest.
type MannWhitneyUTestResult struct {
	// N1 and N2 are the sizes of the first and second input
	// samples, respectively.
	N1, N2 int

	// U is the value of the Mann-Whitney U statistic for this
	// test, generalized by counting ties as 0.5. It is the
	// statistic of the first sample.
	//
	// Given the Cartesian product of the two samples, this is the
	// number of pairs in which the value from the first sample is
	// greater than the value of the second, plus 0.5 times the
	// number of pairs where the values from the two samples are
	// equal. Hence, U is always an integer multiple of 0.5 (it is
	// a whole integer if there are no ties) in the range [0, N1*N2].
	//
	// U statistics always come in pairs, depending on which
	// sample is "first". The mirror U for the other sample can be
	// calculated as N1*N2 - U.
	//
	// There are many equivalent statistics with slightly
	// different definitions. The Wilcoxon (1945) W statistic
	// (generalized for ties) is U + (N1(N1+1))/2. It is also
	// common to use 2U to eliminate the half steps and Smid
	// (1956) uses N1*N2 - 2U to additionally center the
	// distribution.
	U float64

	// AltHypothesis specifies the alternative hypothesis tested
	// by this test against the null hypothesis that there is no
	// difference in the locations of the samples.
	AltHypothesis LocationHypothesis

	// P is the p-value of the test: the probability, under the
	// null hypothesis, of observing a U statistic at least as
	// extreme as U in the direction (or directions) selected by
	// AltHypothesis.
	P float64
}
106 | // 107 | // Computing the exact U distribution is expensive for large sample 108 | // sizes, so this uses a normal approximation for sample sizes larger 109 | // than MannWhitneyExactLimit if there are no ties or 110 | // MannWhitneyTiesExactLimit if there are ties. This normal 111 | // approximation uses both the tie correction and the continuity 112 | // correction. 113 | // 114 | // This can fail with ErrSampleSize if either sample is empty or 115 | // ErrSamplesEqual if all sample values are equal. 116 | // 117 | // This is also known as a Mann-Whitney-Wilcoxon test and is 118 | // equivalent to the Wilcoxon rank-sum test, though the Wilcoxon 119 | // rank-sum test differs in nomenclature. 120 | // 121 | // [1] Mann, Henry B.; Whitney, Donald R. (1947). "On a Test of 122 | // Whether one of Two Random Variables is Stochastically Larger than 123 | // the Other". Annals of Mathematical Statistics 18 (1): 50–60. 124 | // 125 | // [2] Klotz, J. H. (1966). "The Wilcoxon, Ties, and the Computer". 126 | // Journal of the American Statistical Association 61 (315): 772-787. 127 | func MannWhitneyUTest(x1, x2 []float64, alt LocationHypothesis) (*MannWhitneyUTestResult, error) { 128 | n1, n2 := len(x1), len(x2) 129 | if n1 == 0 || n2 == 0 { 130 | return nil, ErrSampleSize 131 | } 132 | 133 | // Compute the U statistic and tie vector T. 134 | x1 = append([]float64(nil), x1...) 135 | x2 = append([]float64(nil), x2...) 136 | sort.Float64s(x1) 137 | sort.Float64s(x2) 138 | merged, labels := labeledMerge(x1, x2) 139 | 140 | R1 := 0.0 141 | T, hasTies := []int{}, false 142 | for i := 0; i < len(merged); { 143 | rank1, nx1, v1 := i+1, 0, merged[i] 144 | // Consume samples that tie this sample (including itself). 145 | for ; i < len(merged) && merged[i] == v1; i++ { 146 | if labels[i] == 1 { 147 | nx1++ 148 | } 149 | } 150 | // Assign all tied samples the average rank of the 151 | // samples, where merged[0] has rank 1. 
152 | if nx1 != 0 { 153 | rank := float64(i+rank1) / 2 154 | R1 += rank * float64(nx1) 155 | } 156 | T = append(T, i-rank1+1) 157 | if i > rank1 { 158 | hasTies = true 159 | } 160 | } 161 | U1 := R1 - float64(n1*(n1+1))/2 162 | 163 | // Compute the smaller of U1 and U2 164 | U2 := float64(n1*n2) - U1 165 | Usmall := math.Min(U1, U2) 166 | 167 | var p float64 168 | if !hasTies && n1 <= MannWhitneyExactLimit && n2 <= MannWhitneyExactLimit || 169 | hasTies && n1 <= MannWhitneyTiesExactLimit && n2 <= MannWhitneyTiesExactLimit { 170 | // Use exact U distribution. U1 will be an integer. 171 | if len(T) == 1 { 172 | // All values are equal. Test is meaningless. 173 | return nil, ErrSamplesEqual 174 | } 175 | 176 | dist := UDist{N1: n1, N2: n2, T: T} 177 | switch alt { 178 | case LocationDiffers: 179 | if U1 == U2 { 180 | // The distribution is symmetric about 181 | // Usmall. Since the distribution is 182 | // discrete, the CDF is discontinuous 183 | // and if simply double CDF(Usmall), 184 | // we'll double count the 185 | // (non-infinitesimal) probability 186 | // mass at Usmall. What we want is 187 | // just the integral of the whole CDF, 188 | // which is 1. 189 | p = 1 190 | } else { 191 | p = dist.CDF(Usmall) * 2 192 | } 193 | 194 | case LocationLess: 195 | p = dist.CDF(U1) 196 | 197 | case LocationGreater: 198 | p = 1 - dist.CDF(U1-1) 199 | } 200 | } else { 201 | // Use normal approximation (with tie and continuity 202 | // correction). 203 | t := tieCorrection(T) 204 | N := float64(n1 + n2) 205 | μ_U := float64(n1*n2) / 2 206 | σ_U := math.Sqrt(float64(n1*n2) * ((N + 1) - t/(N*(N-1))) / 12) 207 | if σ_U == 0 { 208 | return nil, ErrSamplesEqual 209 | } 210 | numer := U1 - μ_U 211 | // Perform continuity correction. 
212 | switch alt { 213 | case LocationDiffers: 214 | numer -= mathx.Sign(numer) * 0.5 215 | case LocationLess: 216 | numer += 0.5 217 | case LocationGreater: 218 | numer -= 0.5 219 | } 220 | z := numer / σ_U 221 | switch alt { 222 | case LocationDiffers: 223 | p = 2 * math.Min(StdNormal.CDF(z), 1-StdNormal.CDF(z)) 224 | case LocationLess: 225 | p = StdNormal.CDF(z) 226 | case LocationGreater: 227 | p = 1 - StdNormal.CDF(z) 228 | } 229 | } 230 | 231 | return &MannWhitneyUTestResult{N1: n1, N2: n2, U: U1, 232 | AltHypothesis: alt, P: p}, nil 233 | } 234 | 235 | // labeledMerge merges sorted lists x1 and x2 into sorted list merged. 236 | // labels[i] is 1 or 2 depending on whether merged[i] is a value from 237 | // x1 or x2, respectively. 238 | func labeledMerge(x1, x2 []float64) (merged []float64, labels []byte) { 239 | merged = make([]float64, len(x1)+len(x2)) 240 | labels = make([]byte, len(x1)+len(x2)) 241 | 242 | i, j, o := 0, 0, 0 243 | for i < len(x1) && j < len(x2) { 244 | if x1[i] < x2[j] { 245 | merged[o] = x1[i] 246 | labels[o] = 1 247 | i++ 248 | } else { 249 | merged[o] = x2[j] 250 | labels[o] = 2 251 | j++ 252 | } 253 | o++ 254 | } 255 | for ; i < len(x1); i++ { 256 | merged[o] = x1[i] 257 | labels[o] = 1 258 | o++ 259 | } 260 | for ; j < len(x2); j++ { 261 | merged[o] = x2[j] 262 | labels[o] = 2 263 | o++ 264 | } 265 | return 266 | } 267 | 268 | // tieCorrection computes the tie correction factor Σ_j (t_j³ - t_j) 269 | // where t_j is the number of ties in the j'th rank. 270 | func tieCorrection(ties []int) float64 { 271 | t := 0 272 | for _, tie := range ties { 273 | t += tie*tie*tie - tie 274 | } 275 | return float64(t) 276 | } 277 | -------------------------------------------------------------------------------- /stats/sample.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 
// Sample is a collection of possibly weighted data points.
type Sample struct {
	// Xs is the slice of sample values.
	Xs []float64

	// Weights[i] is the weight of sample Xs[i]. If Weights is
	// nil, all Xs have weight 1. Weights must have the same
	// length as Xs and all values must be non-negative.
	Weights []float64

	// Sorted indicates that Xs is sorted in ascending order.
	Sorted bool
}

// Bounds returns the minimum and maximum values of xs.
//
// If xs is empty, both results are NaN.
func Bounds(xs []float64) (min float64, max float64) {
	if len(xs) == 0 {
		return math.NaN(), math.NaN()
	}
	min, max = xs[0], xs[0]
	for _, x := range xs[1:] {
		if x < min {
			min = x
		}
		if x > max {
			max = x
		}
	}
	return min, max
}

// Bounds returns the minimum and maximum values of the Sample.
//
// If the Sample is weighted, this ignores samples with zero weight.
// If the Sample is empty or every weight is zero, both results are
// NaN. Infinite sample values are valid and are returned as bounds.
//
// This is constant time if s.Sorted and there are no zero-weighted
// values.
func (s Sample) Bounds() (min float64, max float64) {
	if len(s.Xs) == 0 || (!s.Sorted && s.Weights == nil) {
		return Bounds(s.Xs)
	}

	if s.Sorted {
		if s.Weights == nil {
			return s.Xs[0], s.Xs[len(s.Xs)-1]
		}
		// Trim zero-weighted values off both ends; what remains
		// is bracketed by the extremes.
		first, last := 0, len(s.Weights)-1
		for first <= last && s.Weights[first] == 0 {
			first++
		}
		if first > last {
			// No value carries any weight.
			return math.NaN(), math.NaN()
		}
		for s.Weights[last] == 0 {
			last--
		}
		return s.Xs[first], s.Xs[last]
	}

	// Unsorted and weighted: scan everything. Track explicitly
	// whether any value was considered rather than inspecting the
	// ±Inf starting values afterwards, since ±Inf may legitimately
	// appear in the data.
	min, max = math.Inf(1), math.Inf(-1)
	seen := false
	for i, x := range s.Xs {
		if s.Weights[i] == 0 || math.IsNaN(x) {
			continue
		}
		seen = true
		if x < min {
			min = x
		}
		if x > max {
			max = x
		}
	}
	if !seen {
		return math.NaN(), math.NaN()
	}
	return min, max
}
127 | func (s Sample) Mean() float64 { 128 | if len(s.Xs) == 0 || s.Weights == nil { 129 | return Mean(s.Xs) 130 | } 131 | 132 | m, wsum := 0.0, 0.0 133 | for i, x := range s.Xs { 134 | // Use weighted incremental mean: 135 | // m_i = (1 - w_i/wsum_i) * m_(i-1) + (w_i/wsum_i) * x_i 136 | // = m_(i-1) + (x_i - m_(i-1)) * (w_i/wsum_i) 137 | w := s.Weights[i] 138 | wsum += w 139 | m += (x - m) * w / wsum 140 | } 141 | return m 142 | } 143 | 144 | // MeanCI returns the arithmetic mean of xs and its confidence 145 | // interval based on the sample standard deviation. 146 | func MeanCI(xs []float64, confidence float64) (mean, lo, hi float64) { 147 | mean = Mean(xs) 148 | 149 | var w float64 150 | if confidence <= 0 { 151 | // At confidence level 0, the CI width is 0. 152 | w = 0 153 | } else if confidence >= 1 || len(xs) <= 1 { 154 | // With confidence level 1, we don't know anything. 155 | // This is also the case if the sample is too small to 156 | // have a CI. 157 | w = math.Inf(1) 158 | } else { 159 | s := StdDev(xs) 160 | tdist := TDist{V: float64(len(xs) - 1)} 161 | alpha := (1 - confidence) / 2 162 | t := -InvCDF(tdist)(alpha) 163 | w = t * s / math.Sqrt(float64(len(xs))) 164 | } 165 | 166 | return mean, mean - w, mean + w 167 | } 168 | 169 | // MeanCI returns the arithmetic mean of the Sample and its confidence 170 | // interval based on the sample standard deviation. 171 | func (s Sample) MeanCI(confidence float64) (mean, lo, hi float64) { 172 | if len(s.Xs) == 0 || s.Weights == nil { 173 | return MeanCI(s.Xs, confidence) 174 | } 175 | // TODO(austin) 176 | panic("Weighted MeanCI not implemented") 177 | } 178 | 179 | // GeoMean returns the geometric mean of xs. xs must be positive. 
180 | func GeoMean(xs []float64) float64 { 181 | if len(xs) == 0 { 182 | return math.NaN() 183 | } 184 | m := 0.0 185 | for i, x := range xs { 186 | if x <= 0 { 187 | return math.NaN() 188 | } 189 | lx := math.Log(x) 190 | m += (lx - m) / float64(i+1) 191 | } 192 | return math.Exp(m) 193 | } 194 | 195 | // GeoMean returns the geometric mean of the Sample. All samples 196 | // values must be positive. 197 | func (s Sample) GeoMean() float64 { 198 | if len(s.Xs) == 0 || s.Weights == nil { 199 | return GeoMean(s.Xs) 200 | } 201 | 202 | m, wsum := 0.0, 0.0 203 | for i, x := range s.Xs { 204 | w := s.Weights[i] 205 | wsum += w 206 | lx := math.Log(x) 207 | m += (lx - m) * w / wsum 208 | } 209 | return math.Exp(m) 210 | } 211 | 212 | // Variance returns the sample variance of xs. 213 | func Variance(xs []float64) float64 { 214 | if len(xs) == 0 { 215 | return math.NaN() 216 | } else if len(xs) <= 1 { 217 | return 0 218 | } 219 | 220 | // Based on Wikipedia's presentation of Welford 1962 221 | // (http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm). 222 | // This is more numerically stable than the standard two-pass 223 | // formula and not prone to massive cancellation. 224 | mean, M2 := 0.0, 0.0 225 | for n, x := range xs { 226 | delta := x - mean 227 | mean += delta / float64(n+1) 228 | M2 += delta * (x - mean) 229 | } 230 | return M2 / float64(len(xs)-1) 231 | } 232 | 233 | func (s Sample) Variance() float64 { 234 | if len(s.Xs) == 0 || s.Weights == nil { 235 | return Variance(s.Xs) 236 | } 237 | // TODO(austin) 238 | panic("Weighted Variance not implemented") 239 | } 240 | 241 | // StdDev returns the sample standard deviation of xs. 242 | func StdDev(xs []float64) float64 { 243 | return math.Sqrt(Variance(xs)) 244 | } 245 | 246 | // StdDev returns the sample standard deviation of the Sample. 
247 | func (s Sample) StdDev() float64 { 248 | if len(s.Xs) == 0 || s.Weights == nil { 249 | return StdDev(s.Xs) 250 | } 251 | // TODO(austin) 252 | panic("Weighted StdDev not implemented") 253 | } 254 | 255 | // Quantile returns the sample value X at which q*weight of the sample 256 | // is <= X. This uses interpolation method R8 from Hyndman and Fan 257 | // (1996). 258 | // 259 | // q will be capped to the range [0, 1]. If len(xs) == 0 or all 260 | // weights are 0, returns NaN. 261 | // 262 | // Quantile(0.5) is the median. Quantile(0.25) and Quantile(0.75) are 263 | // the first and third quartiles, respectively. Quantile(P/100) is the 264 | // P'th percentile. 265 | // 266 | // See also function QuantileCI. 267 | // 268 | // This is constant time if s.Sorted and s.Weights == nil. 269 | func (s Sample) Quantile(q float64) float64 { 270 | if len(s.Xs) == 0 { 271 | return math.NaN() 272 | } else if q <= 0 { 273 | min, _ := s.Bounds() 274 | return min 275 | } else if q >= 1 { 276 | _, max := s.Bounds() 277 | return max 278 | } 279 | 280 | if !s.Sorted { 281 | // TODO(austin) Use select algorithm instead 282 | s = *s.Copy().Sort() 283 | } 284 | 285 | if s.Weights == nil { 286 | N := float64(len(s.Xs)) 287 | //n := q * (N + 1) // R6 288 | n := 1/3.0 + q*(N+1/3.0) // R8 289 | kf, frac := math.Modf(n) 290 | k := int(kf) 291 | if k <= 0 { 292 | return s.Xs[0] 293 | } else if k >= len(s.Xs) { 294 | return s.Xs[len(s.Xs)-1] 295 | } 296 | return s.Xs[k-1] + frac*(s.Xs[k]-s.Xs[k-1]) 297 | } else { 298 | // TODO(austin): Implement interpolation 299 | 300 | target := s.Weight() * q 301 | 302 | // TODO(austin) If we had cumulative weights, we could 303 | // do this in log time. 304 | for i, weight := range s.Weights { 305 | target -= weight 306 | if target < 0 { 307 | return s.Xs[i] 308 | } 309 | } 310 | return s.Xs[len(s.Xs)-1] 311 | } 312 | } 313 | 314 | // IQR returns the interquartile range of the Sample. 
// sampleSorter implements sort.Interface over a slice of values and
// its parallel slice of weights, ordering by value and keeping each
// weight attached to its value.
type sampleSorter struct {
	xs      []float64
	weights []float64
}

// Len returns the number of values being sorted.
func (ss *sampleSorter) Len() int {
	n := len(ss.xs)
	return n
}

// Less reports whether the value at index i is smaller than the
// value at index j. Weights play no part in the ordering.
func (ss *sampleSorter) Less(i, j int) bool {
	a, b := ss.xs[i], ss.xs[j]
	return a < b
}

// Swap exchanges positions i and j in both the values and the
// weights.
func (ss *sampleSorter) Swap(i, j int) {
	xi, xj := ss.xs[i], ss.xs[j]
	ss.xs[i] = xj
	ss.xs[j] = xi

	wi, wj := ss.weights[i], ss.weights[j]
	ss.weights[i] = wj
	ss.weights[j] = wi
}
// aeqTable reports whether tables a and b have the same shape and
// every pair of corresponding entries differs by less than 0.000001.
func aeqTable(a, b [][]float64) bool {
	// "%f" precision
	const tolerance = 0.000001

	if len(a) != len(b) {
		return false
	}
	for i, rowA := range a {
		rowB := b[i]
		if len(rowA) != len(rowB) {
			return false
		}
		for j, v := range rowA {
			if math.Abs(v-rowB[j]) >= tolerance {
				return false
			}
		}
	}
	return true
}
| return out 81 | } 82 | 83 | // Compare against tables given in Mann, Whitney (1947). 84 | got3 := makeTable(3) 85 | if !aeqTable(got3, udist3) { 86 | t.Errorf("For n=3, want:\n%sgot:\n%s", fmtTable(udist3), fmtTable(got3)) 87 | } 88 | 89 | got5 := makeTable(5) 90 | if !aeqTable(got5, udist5) { 91 | t.Errorf("For n=5, want:\n%sgot:\n%s", fmtTable(udist5), fmtTable(got5)) 92 | } 93 | } 94 | 95 | func BenchmarkUDist(b *testing.B) { 96 | for i := 0; i < b.N; i++ { 97 | // R uses the exact distribution up to N=50. 98 | // N*M/2=1250 is the hardest point to get the CDF for. 99 | UDist{N1: 50, N2: 50}.CDF(1250) 100 | } 101 | } 102 | 103 | func TestUDistTies(t *testing.T) { 104 | makeTable := func(m, N int, t []int, minx, maxx float64) [][]float64 { 105 | out := [][]float64{} 106 | dist := UDist{N1: m, N2: N - m, T: t} 107 | for x := minx; x <= maxx; x += 0.5 { 108 | // Convert x from uQt' to uQv'. 109 | U := x - float64(m*m)/2 110 | P := dist.CDF(U) 111 | if len(out) == 0 || !aeq(out[len(out)-1][1], P) { 112 | out = append(out, []float64{x, P}) 113 | } 114 | } 115 | return out 116 | } 117 | fmtTable := func(table [][]float64) string { 118 | out := "" 119 | for _, row := range table { 120 | out += fmt.Sprintf("%5.1f %f\n", row[0], row[1]) 121 | } 122 | return out 123 | } 124 | 125 | // Compare against Table 1 from Klotz (1966). 
126 | got := makeTable(5, 10, []int{1, 1, 2, 1, 1, 2, 1, 1}, 12.5, 19.5) 127 | want := [][]float64{ 128 | {12.5, 0.003968}, {13.5, 0.007937}, 129 | {15.0, 0.023810}, {16.5, 0.047619}, 130 | {17.5, 0.071429}, {18.0, 0.087302}, 131 | {19.0, 0.134921}, {19.5, 0.138889}, 132 | } 133 | if !aeqTable(got, want) { 134 | t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got)) 135 | } 136 | 137 | got = makeTable(10, 21, []int{6, 5, 4, 3, 2, 1}, 52, 87) 138 | want = [][]float64{ 139 | {52.0, 0.000014}, {56.5, 0.000128}, 140 | {57.5, 0.000145}, {60.0, 0.000230}, 141 | {61.0, 0.000400}, {62.0, 0.000740}, 142 | {62.5, 0.000797}, {64.0, 0.000825}, 143 | {64.5, 0.001165}, {65.5, 0.001477}, 144 | {66.5, 0.002498}, {67.0, 0.002725}, 145 | {67.5, 0.002895}, {68.0, 0.003150}, 146 | {68.5, 0.003263}, {69.0, 0.003518}, 147 | {69.5, 0.003603}, {70.0, 0.005648}, 148 | {70.5, 0.005818}, {71.0, 0.006626}, 149 | {71.5, 0.006796}, {72.0, 0.008157}, 150 | {72.5, 0.009688}, {73.0, 0.009801}, 151 | {73.5, 0.010430}, {74.0, 0.011111}, 152 | {74.5, 0.014230}, {75.0, 0.014612}, 153 | {75.5, 0.017249}, {76.0, 0.018307}, 154 | {76.5, 0.020178}, {77.0, 0.022270}, 155 | {77.5, 0.023189}, {78.0, 0.026931}, 156 | {78.5, 0.028207}, {79.0, 0.029979}, 157 | {79.5, 0.030931}, {80.0, 0.038969}, 158 | {80.5, 0.043063}, {81.0, 0.044262}, 159 | {81.5, 0.046389}, {82.0, 0.049581}, 160 | {82.5, 0.056300}, {83.0, 0.058027}, 161 | {83.5, 0.063669}, {84.0, 0.067454}, 162 | {84.5, 0.074122}, {85.0, 0.077425}, 163 | {85.5, 0.083498}, {86.0, 0.094079}, 164 | {86.5, 0.096693}, {87.0, 0.101132}, 165 | } 166 | if !aeqTable(got, want) { 167 | t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got)) 168 | } 169 | 170 | got = makeTable(8, 16, []int{2, 2, 2, 2, 2, 2, 2, 2}, 32, 54) 171 | want = [][]float64{ 172 | {32.0, 0.000078}, {34.0, 0.000389}, 173 | {36.0, 0.001088}, {38.0, 0.002642}, 174 | {40.0, 0.005905}, {42.0, 0.011500}, 175 | {44.0, 0.021057}, {46.0, 0.035664}, 176 | {48.0, 0.057187}, {50.0, 0.086713}, 177 
| {52.0, 0.126263}, {54.0, 0.175369}, 178 | } 179 | if !aeqTable(got, want) { 180 | t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got)) 181 | } 182 | 183 | // Check remaining tables from Klotz against the reference 184 | // implementation. 185 | checkRef := func(n1 int, tie []int) { 186 | wantPMF1, wantCDF1 := udistRef(n1, tie) 187 | 188 | dist := UDist{N1: n1, N2: sumint(tie) - n1, T: tie} 189 | gotPMF, wantPMF := [][]float64{}, [][]float64{} 190 | gotCDF, wantCDF := [][]float64{}, [][]float64{} 191 | N := sumint(tie) 192 | for U := 0.0; U <= float64(n1*(N-n1)); U += 0.5 { 193 | gotPMF = append(gotPMF, []float64{U, dist.PMF(U)}) 194 | gotCDF = append(gotCDF, []float64{U, dist.CDF(U)}) 195 | wantPMF = append(wantPMF, []float64{U, wantPMF1[int(U*2)]}) 196 | wantCDF = append(wantCDF, []float64{U, wantCDF1[int(U*2)]}) 197 | } 198 | if !aeqTable(wantPMF, gotPMF) { 199 | t.Errorf("For PMF of n1=%v, t=%v, want:\n%sgot:\n%s", n1, tie, fmtTable(wantPMF), fmtTable(gotPMF)) 200 | } 201 | if !aeqTable(wantCDF, gotCDF) { 202 | t.Errorf("For CDF of n1=%v, t=%v, want:\n%sgot:\n%s", n1, tie, fmtTable(wantCDF), fmtTable(gotCDF)) 203 | } 204 | } 205 | checkRef(5, []int{1, 1, 2, 1, 1, 2, 1, 1}) 206 | checkRef(5, []int{1, 1, 2, 1, 1, 1, 2, 1}) 207 | checkRef(5, []int{1, 3, 1, 2, 1, 1, 1}) 208 | checkRef(8, []int{1, 2, 1, 1, 1, 1, 2, 2, 1, 2}) 209 | checkRef(12, []int{3, 3, 4, 3, 4, 5}) 210 | checkRef(10, []int{1, 2, 3, 4, 5, 6}) 211 | } 212 | 213 | func BenchmarkUDistTies(b *testing.B) { 214 | // Worst case: just one tie. 215 | n := 20 216 | t := make([]int, 2*n-1) 217 | for i := range t { 218 | t[i] = 1 219 | } 220 | t[0] = 2 221 | 222 | for i := 0; i < b.N; i++ { 223 | UDist{N1: n, N2: n, T: t}.CDF(float64(n*n) / 2) 224 | } 225 | } 226 | 227 | func XTestPrintUmemo(t *testing.T) { 228 | // Reproduce table from Cheung, Klotz. 
// udistRef computes the PMF and CDF of the U distribution for two
// samples of sizes n1 and sum(t)-n1 with tie vector t. The returned
// pmf and cdf are indexed by 2*U.
//
// This uses the "graphical method" of Klotz (1966). It is very slow
// (Θ(∏ (t[i]+1)) = Ω(2^|t|)), but very correct, and hence useful as a
// reference for testing faster implementations.
//
// t must be non-empty: the enumeration below indexes u[0]
// unconditionally.
func udistRef(n1 int, t []int) (pmf, cdf []float64) {
	// Enumerate all u vectors for which 0 <= u_i <= t_i. Count
	// the number of permutations of two samples of sizes n1 and
	// sum(t)-n1 with tie vector t and accumulate these counts by
	// their U statistics in counts[2*U].
	//
	// 2*U ranges over [0, 2*n1*(sum(t)-n1)], hence the length.
	// (Presumably sumint returns the sum of its argument; it is
	// defined elsewhere.)
	counts := make([]int, 1+2*n1*(sumint(t)-n1))

	// u is advanced like an odometer whose i'th digit runs from 0
	// to t[i], with u[0] the fastest-moving digit.
	u := make([]int, len(t))
	u[0] = -1 // Get enumeration started.
enumu:
	for {
		// Compute the next u vector.
		u[0]++
		for i := 0; i < len(u) && u[i] > t[i]; i++ {
			if i == len(u)-1 {
				// All u vectors have been enumerated.
				break enumu
			}
			// Carry.
			u[i+1]++
			u[i] = 0
		}

		// Is this a legal u vector?
		if sumint(u) != n1 {
			// Klotz (1966) has a method for directly
			// enumerating legal u vectors, but the point
			// of this is to be correct, not fast.
			continue
		}

		// Compute 2*U statistic for this u vector. vsum is the
		// running total of v over the tie groups already
		// visited, where v_i = t_i - u_i.
		twoU, vsum := 0, 0
		for i, u_i := range u {
			v_i := t[i] - u_i
			// U = U + vsum*u_i + u_i*v_i/2
			twoU += 2*vsum*u_i + u_i*v_i
			vsum += v_i
		}

		// Compute Π choose(t_i, u_i). This is the number of
		// ways of permuting the input sample under u. Adding
		// 0.5 before truncating rounds the floating-point
		// binomial coefficient to the nearest integer.
		prod := 1
		for i, u_i := range u {
			prod *= int(mathx.Choose(t[i], u_i) + 0.5)
		}

		// Accumulate the permutations on this u path.
		counts[twoU] += prod

		// Debugging aid, disabled by default.
		if false {
			// Print a table in the form of Klotz's
			// "direct enumeration" example.
			//
			// Convert 2U = 2UQV' to UQt' used in Klotz
			// examples.
			UQt := float64(twoU)/2 + float64(n1*n1)/2
			fmt.Printf("%+v %f %-2d\n", u, UQt, prod)
		}
	}

	// Convert counts into probabilities for PMF and CDF. total is
	// choose(sum(t), n1), rounded to the nearest integer; the CDF
	// is the running sum of the PMF.
	pmf = make([]float64, len(counts))
	cdf = make([]float64, len(counts))
	total := int(mathx.Choose(sumint(t), n1) + 0.5)
	for i, count := range counts {
		pmf[i] = float64(count) / float64(total)
		if i > 0 {
			cdf[i] = cdf[i-1]
		}
		cdf[i] += pmf[i]
	}
	return
}