├── .document ├── .gitignore ├── .rspec ├── .rubocop.yml ├── .travis.yml ├── Gemfile ├── History.md ├── LICENSE.txt ├── README.md ├── Rakefile ├── lib ├── statsample-glm.rb └── statsample-glm │ ├── glm.rb │ ├── glm │ ├── base.rb │ ├── formula │ │ ├── formula.rb │ │ ├── token.rb │ │ └── wrapper.rb │ ├── irls │ │ ├── base.rb │ │ ├── logistic.rb │ │ └── poisson.rb │ ├── logistic.rb │ ├── mle │ │ ├── base.rb │ │ ├── logistic.rb │ │ ├── normal.rb │ │ └── probit.rb │ ├── normal.rb │ ├── poisson.rb │ ├── probit.rb │ └── regression.rb │ └── version.rb ├── spec ├── data │ ├── binary.csv │ ├── df.csv │ ├── logistic.csv │ ├── logistic_mle.csv │ └── normal.csv ├── formula_spec.rb ├── formula_wrapper_spec.rb ├── logistic_spec.rb ├── normal_spec.rb ├── poisson_spec.rb ├── probit_spec.rb ├── regression_spec.rb ├── shared_context │ ├── formula_checker.rb │ ├── parser_checker.rb │ └── reduce_formula.rb ├── spec_helper.rb └── token_spec.rb └── statsample-glm.gemspec /.document: -------------------------------------------------------------------------------- 1 | lib/**/*.rb 2 | bin/* 3 | - 4 | features/**/*.feature 5 | LICENSE.txt 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # rcov generated 2 | coverage 3 | coverage.data 4 | 5 | # rdoc generated 6 | rdoc 7 | 8 | # yard generated 9 | doc 10 | .yardoc 11 | 12 | # bundler 13 | .bundle 14 | 15 | # jeweler generated 16 | pkg 17 | 18 | # Have editor/IDE/OS specific files you need to ignore? 
Consider using a global gitignore: 19 | # 20 | # * Create a file at ~/.gitignore 21 | # * Include files you want ignored 22 | # * Run: git config --global core.excludesfile ~/.gitignore 23 | # 24 | # After doing this, these files will be ignored in all your git projects, 25 | # saving you from having to 'pollute' every project you touch with them 26 | # 27 | # Not sure what to needs to be ignored for particular editors/OSes? Here's some ideas to get you started. (Remember, remove the leading # of the line) 28 | # 29 | # For MacOS: 30 | # 31 | #.DS_Store 32 | 33 | # For TextMate 34 | #*.tmproj 35 | #tmtags 36 | 37 | # For emacs: 38 | #*~ 39 | #\#* 40 | #.\#* 41 | 42 | # For vim: 43 | #*.swp 44 | 45 | # For redcar: 46 | #.redcar 47 | 48 | # For rubinius: 49 | #*.rbc 50 | # Ignore Gemfile.lock for gems. See http://yehudakatz.com/2010/12/16/clarifying-the-roles-of-the-gemspec-and-gemfile/ 51 | Gemfile.lock 52 | *.gem 53 | -------------------------------------------------------------------------------- /.rspec: -------------------------------------------------------------------------------- 1 | --color -------------------------------------------------------------------------------- /.rubocop.yml: -------------------------------------------------------------------------------- 1 | AllCops: 2 | Include: 3 | - 'lib/statsample-glm/glm/formula/formula.rb' 4 | - 'lib/statsample-glm/glm/formula/wrapper.rb' 5 | - 'lib/statsample-glm/glm/formula/token.rb' 6 | - 'lib/statsample-glm/glm/regression.rb' 7 | Exclude: 8 | - 'lib/statsample-glm/glm/base.rb' 9 | - 'lib/statsample-glm/glm/logistic.rb' 10 | - 'lib/statsample-glm/glm/normal.rb' 11 | - 'lib/statsample-glm/glm/poisson.rb' 12 | - 'lib/statsample-glm/glm/probit.rb' 13 | - 'lib/statsample-glm/glm/irls/*' 14 | - 'lib/statsample-glm/glm/mle/*' 15 | - 'spec/*' 16 | - 'lib/statsample-glm/version.rb' 17 | - 'lib/statsample-glm/glm.rb' 18 | - 'lib/statsample-glm.rb' 19 | DisplayCopNames: true 20 | 21 | Metrics/AbcSize: 22 | Max: 20 
23 | 24 | Metrics/MethodLength: 25 | Max: 15 26 | 27 | Metrics/CyclomaticComplexity: 28 | Max: 7 29 | 30 | Metrics/ModuleLength: 31 | Max: 200 32 | 33 | Metrics/ClassLength: 34 | Max: 200 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: ruby 2 | cache: bundler 3 | rvm: 4 | - 2.0.0 5 | - 2.1.1 6 | - 2.2.1 7 | - 2.3.0 8 | 9 | script: "bundle exec rspec" 10 | 11 | install: 12 | - gem install bundler 13 | - bundle install 14 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | gemspec 3 | -------------------------------------------------------------------------------- /History.md: -------------------------------------------------------------------------------- 1 | # 0.2.1 2 | 3 | Minor test updates for compatibility with daru 0.1.1 4 | 5 | # 0.2.0 6 | 7 | * Added dependency on daru (0.1) for data structures and increased statsample dependency to 1.5. -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | This version of Statsample-GLM is licensed under the BSD 2-clause license. 2 | 3 | * http://sciruby.com 4 | * http://github.com/sciruby/sciruby/wiki/License 5 | 6 | You *must* read the Contributor Agreement before contributing code to the SciRuby Project. This is available online: 7 | 8 | * http://github.com/sciruby/sciruby/wiki/Contributor-Agreement 9 | 10 | ----- 11 | 12 | Copyright (c) 2013, Ankur Goel and the Ruby Science Foundation 13 | All rights reserved. 
14 | 15 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 16 | 17 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 18 | 19 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # statsample-glm 2 | 3 | [![Build Status](https://travis-ci.org/SciRuby/statsample-glm.svg?branch=master)](https://travis-ci.org/SciRuby/statsample-glm) 4 | 5 | [![Gem Version](https://badge.fury.io/rb/statsample-glm.svg)](http://badge.fury.io/rb/statsample-glm) 6 | 7 | Statsample-GLM is an extension of *Generalized Linear Models* to [Statsample](https://github.com/SciRuby/statsample), a suite of advanced statistics in Ruby. 8 | 9 | Requires ruby 2.0.0 or higher.
10 | 11 | ## Description 12 | 13 | Statsample-glm includes the following Generalized Linear Models: 14 | 15 | * Iteratively Reweighted Least Squares 16 | * Poisson Regression 17 | * Logistic Regression 18 | * Maximum Likelihood Models (Newton Raphson) 19 | * Logistic Regression 20 | * Probit Regression 21 | * Normal Regression 22 | 23 | Statsample-GLM was created by Ankur Goel as part of Google's Summer of Code 2013. It is part of [the SciRuby Project](http://sciruby.com). 24 | 25 | ## Installation 26 | 27 | `gem install statsample-glm` 28 | 29 | 30 | ## Usage 31 | 32 | To use the library 33 | 34 | `require 'statsample-glm'` 35 | 36 | ### Blogs 37 | 38 | * [Generalized Linear Models: Introduction and implementation in Ruby](http://v0dro.github.io/blog/2014/09/21/code-generalized-linear-models-introduction-and-implementation-in-ruby/). 39 | * [Formula language implementation in Statsample-GLM](http://lokeshh.github.io/blog/2016/07/19/formula-language-week7-8/) 40 | * [Addition of shortcut symbols in formula language](http://lokeshh.github.io/blog/2016/07/31/shortcut-symbols/) 41 | 42 | ### Case Studies 43 | 44 | * [Logistic Regression Analysis with daru and statsample-glm](http://nbviewer.ipython.org/github/SciRuby/sciruby-notebooks/blob/master/Data%20Analysis/Logistic%20Regression%20with%20daru%20and%20statsample-glm.ipynb) 45 | * [Logistic Regression involving categorical variable and use of formula language](http://nbviewer.jupyter.org/github/SciRuby/sciruby-notebooks/blob/master/Data%20Analysis/Categorical%20Data/examples/[Example]%20Formula%20language%20in%20Statsample-GLM.ipynb) 46 | 47 | ## Documentation 48 | 49 | The API doc is [online](http://rubygems.org/gems/statsample-glm). For more code examples see also the spec files in the source tree. 50 | 51 | ## Project home page 52 | 53 | http://github.com/sciruby/statsample-glm 54 | 55 | ## Copyright 56 | 57 | Copyright (c) 2013 Ankur Goel and the Ruby Science Foundation. See LICENSE.txt for further details.
58 | 59 | Statsample is (c) 2009-2013 Claudio Bustos and the Ruby Science Foundation. 60 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | require 'rake' 3 | require 'bundler/gem_tasks' 4 | require 'bundler' 5 | 6 | lib_folder = File.expand_path("../lib", __FILE__) 7 | 8 | begin 9 | Bundler.setup(:default, :development) 10 | rescue Bundler::BundlerError => e 11 | $stderr.puts e.message 12 | $stderr.puts "Run `bundle install` to install missing gems" 13 | exit e.status_code 14 | end 15 | 16 | desc "Open IRB with statsample-timeseries loaded." 17 | task :console do 18 | require 'irb' 19 | require 'irb/completion' 20 | $:.unshift File.expand_path("../lib", __FILE__) 21 | require 'statsample-glm' 22 | ARGV.clear 23 | IRB.start 24 | end 25 | 26 | task :pry do |task| 27 | cmd = [ 'pry', "-r '#{lib_folder}/statsample-glm.rb'" ] 28 | run *cmd 29 | end 30 | 31 | task :rubocop do |task| 32 | run 'rubocop' rescue nil 33 | end 34 | 35 | require 'rspec/core/rake_task' 36 | 37 | RSpec::Core::RakeTask.new(:spec) 38 | 39 | task :default => :spec 40 | 41 | require 'rdoc/task' 42 | Rake::RDocTask.new do |rdoc| 43 | $:.unshift File.expand_path("../lib", __FILE__) 44 | version = Statsample::GLM::VERSION 45 | 46 | rdoc.rdoc_dir = 'rdoc' 47 | rdoc.title = "statsample-glm #{version}" 48 | rdoc.rdoc_files.include('README*') 49 | rdoc.rdoc_files.include('lib/**/*.rb') 50 | end 51 | 52 | def run *cmd 53 | sh(cmd.join(" ")) 54 | end 55 | -------------------------------------------------------------------------------- /lib/statsample-glm.rb: -------------------------------------------------------------------------------- 1 | require 'daru' 2 | require 'statsample' 3 | require 'statsample-glm/glm' 4 | -------------------------------------------------------------------------------- /lib/statsample-glm/glm.rb: 
-------------------------------------------------------------------------------- 1 | require 'statsample-glm/glm/logistic' 2 | require 'statsample-glm/glm/probit' 3 | require 'statsample-glm/glm/poisson' 4 | require 'statsample-glm/glm/normal' 5 | require 'statsample-glm/glm/formula/formula.rb' 6 | require 'statsample-glm/glm/formula/wrapper.rb' 7 | require 'statsample-glm/glm/formula/token.rb' 8 | require 'statsample-glm/glm/regression' 9 | 10 | module Statsample 11 | module GLM 12 | include Statsample::VectorShorthands 13 | 14 | # = Generalized linear models 15 | # == Parameters 16 | # 17 | # * x = model matrix 18 | # * y = response vector 19 | # 20 | # == Usage 21 | # require 'statsample-glm' 22 | # x1 = Daru::Vector.new([0.537322309644812,-0.717124209978434,-0.519166718891331,0.434970973986765,-0.761822002215759,1.51170030921189,0.883854199811195,-0.908689798854196,1.70331977539793,-0.246971150634099,-1.59077593922623,-0.721548040910253,0.467025703920194,-0.510132788447137,0.430106510266798,-0.144353683251536,-1.54943800728303,0.849307651309298,-0.640304240933579,1.31462478279425,-0.399783455165345,0.0453055645017902,-2.58212161987746,-1.16484414309359,-1.08829266466281,-0.243893919684792,-1.96655661929441,0.301335373291024,-0.665832694463588,-0.0120650855753837,1.5116066367604,0.557300353673344,1.12829931872045,0.234443748015922,-2.03486690662651,0.275544751380246,-0.231465849558696,-0.356880153225012,-0.57746647541923,1.35758352580655,1.23971669378224,-0.662466275100489,0.313263561921793,-1.08783223256362,1.41964722846899,1.29325100940785,0.72153880625103,0.440580131022748,0.0351917814720056, -0.142353224879252]) 23 | # x2 = 
Daru::Vector.new([-0.866655707911859,-0.367820249977585,0.361486610435,0.857332626245179,0.133438466268095,0.716104533073575,1.77206093023382,-0.10136697295802,-0.777086491435508,-0.204573554913706,0.963353531412233,-1.10103024900542,-0.404372761837392,-0.230226345183469,0.0363730246866971,-0.838265540390497,1.12543549657924,-0.57929175648001,-0.747060244805248,0.58946979365152,-0.531952663697324,1.53338594419818,0.521992029051441,1.41631763288724,0.611402316795129,-0.518355638373296,-0.515192557101107,-0.672697937866108,1.84347042325327,-0.21195540664804,-0.269869371631611,0.296155694010096,-2.18097898069634,-1.21314663927206,1.49193669881581,1.38969280369493,-0.400680808117106,-1.87282814976479,1.82394870451051,0.637864732838274,-0.141155946382493,0.0699950644281617,1.32568550595165,-0.412599258349398,0.14436832227506,-1.16507785388489,-2.16782049922428,0.24318371493798,0.258954871320764,-0.151966534521183]) 24 | # y_pois = Daru::Vector.new([1,2,1,3,3,1,10,1,1,2,15,0,0,2,1,2,18,2,1,1,1,8,18,13,7,1,1,0,26,0,2,2,0,0,25,7,0,0,21,0,0,1,5,0,3,0,0,1,0,0]) 25 | # x=Daru::DataFrame.new({:x1 => x1,:x2 => x2, :y => y_pois}) 26 | # obj = Statsample::GLM.compute(x, :y, :poisson, {algorithm: :irls}) 27 | # #=> Logistic Regression object 28 | # 29 | # == Returns 30 | # GLM object for given method. 31 | def self.compute(data_set, dependent_column, method, opts={}) 32 | opts[:method] = method 33 | 34 | Kernel.const_get( 35 | "Statsample::GLM::#{method.capitalize}" 36 | ).new data_set, dependent_column, opts 37 | end 38 | 39 | # TODO: Decide whether to remove this or not. 
40 | # def self.fit_model(formula, df, method, opts={}) 41 | # reg = Statsample::GLM::Regression.new formula, df, method, opts 42 | # reg.fit_model 43 | # end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /lib/statsample-glm/glm/base.rb: -------------------------------------------------------------------------------- 1 | require 'statsample-glm/glm/irls/logistic' 2 | require 'statsample-glm/glm/irls/poisson' 3 | require 'statsample-glm/glm/mle/logistic' 4 | require 'statsample-glm/glm/mle/probit' 5 | require 'statsample-glm/glm/mle/normal' 6 | 7 | module Statsample 8 | module GLM 9 | class Base 10 | extend Gem::Deprecate 11 | 12 | def initialize ds, y, opts={} 13 | @opts = opts 14 | 15 | set_default_opts_if_any 16 | 17 | @dependent = ds[y] 18 | @data_set = ds.delete_vector y 19 | 20 | if @opts[:constant] 21 | add_constant_vector @opts[:constant] 22 | else 23 | add_constant_vector(1) if self.is_a? Statsample::GLM::Normal 24 | end 25 | 26 | algorithm = @opts[:algorithm].upcase 27 | method = @opts[:method].capitalize 28 | 29 | @regression = Object.const_get( 30 | "Statsample::GLM::#{algorithm}::#{method}" 31 | ).new @data_set, @dependent, @opts 32 | end 33 | 34 | 35 | # Returns the coefficients of trained model 36 | # 37 | # @param [Symbol] as_a Specifies the form of output 38 | # 39 | # @return [Vector, Hash, Array] coefficients of the model 40 | # 41 | # @example 42 | # require 'statsample-glm' 43 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 44 | # glm = Statsample::GLM.compute data_set, "y", :logistic, {constant: 1} 45 | # glm.coefficients as_a = :hash 46 | # # => 47 | # # {:x1=>-0.3124937545689041, :x2=>2.286713333462646, :constant=>0.675603176233328} 48 | # 49 | def coefficients as_a=:vector 50 | case as_a 51 | when :hash 52 | c = {} 53 | @data_set.vectors.to_a.each_with_index do |f,i| 54 | c[f.to_sym] = @regression.coefficients[i] 55 | end 56 | c 57 | when :array 58 | 
@regression.coefficients.to_a 59 | when :vector 60 | @regression.coefficients 61 | else 62 | raise ArgumentError, "as_a has to be one of :array, :hash, or :vector" 63 | end 64 | end 65 | 66 | # Returns the standard errors for the coefficient estimates 67 | # 68 | # @param [Symbol] as_a Specifies the form of output 69 | # 70 | # @return [Vector, Hash, Array] standard error 71 | # 72 | # @example 73 | # require 'statsample-glm' 74 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 75 | # glm = Statsample::GLM.compute data_set, "y", :logistic, {constant: 1} 76 | # glm.standard_errors 77 | # # # 78 | # # nil 79 | # # 0 0.4130813039878828 80 | # # 1 0.7194644911927432 81 | # # 2 0.40380565497038895 82 | # 83 | def standard_errors as_a=:vector 84 | case as_a 85 | when :hash 86 | se = {} 87 | @data_set.vectors.to_a.each_with_index do |f,i| 88 | se[f.to_sym] = @regression.standard_errors[i] 89 | end 90 | se 91 | when :array 92 | @regression.standard_errors.to_a 93 | when :vector 94 | @regression.standard_errors 95 | else 96 | raise ArgumentError, "as_a has to be one of :array, :hash, or :vector" 97 | end 98 | end 99 | 100 | # standard_error will be removed soon 101 | alias :standard_error :standard_errors 102 | deprecate :standard_error, :standard_errors, 2017, 1 103 | 104 | def iterations 105 | @regression.iterations 106 | end 107 | 108 | # Returns the values predicted by the model 109 | # 110 | # @return [Vector] vectors of predicted values 111 | # 112 | # @example 113 | # require 'statsample-glm' 114 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 115 | # glm = Statsample::GLM.compute data_set, "y", :logistic, constant: 1 116 | # glm.fitted_mean_values 117 | # # => 118 | # # # 119 | # # nil 120 | # # 0 0.18632025624516532 121 | # # 1 0.5146459448198846 122 | # # 2 0.84083523282549 123 | # # 3 0.9241524337773334 124 | # # 4 0.7718528863631826 125 | # # ... ... 
126 | # 127 | def fitted_mean_values 128 | @regression.fitted_mean_values 129 | end 130 | 131 | # Returns the residual for every data point 132 | # 133 | # @return [Vector] all residuals in a vector 134 | # 135 | # @example 136 | # require 'statsample-glm' 137 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 138 | # glm = Statsample::GLM.compute data_set, "y", :logistic, {constant: 1} 139 | # glm.residuals 140 | # # # 141 | # # y 142 | # # 0 -0.18632025624516532 143 | # # 1 -0.5146459448198846 144 | # # 2 0.15916476717451 145 | # # 3 -0.9241524337773334 146 | # # 4 0.2281471136368174 147 | # # ... ... 148 | # 149 | def residuals 150 | @regression.residuals 151 | end 152 | 153 | # Returns the degrees of freedom value. 154 | # 155 | # @return [Integer] the degrees of freedom 156 | # 157 | # @example 158 | # require 'statsample-glm' 159 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 160 | # glm = Statsample::GLM.compute data_set, "y", :logistic, constant: 1 161 | # glm.degrees_of_freedom 162 | # # => 47 163 | # 164 | def degrees_of_freedom 165 | @regression.degrees_of_freedom 166 | end 167 | 168 | # degrees_of_freedom will be removed soon 169 | alias :degree_of_freedom :degrees_of_freedom 170 | deprecate :degree_of_freedom, :degrees_of_freedom, 2017, 1 171 | 172 | # Returns the optimal value of the log-likelihood function when using MLE algorithm. 173 | # The optimal value is the value of the log-likelihood function at the MLE solution. 
174 | # 175 | # @return [Numeric] the optimal value of log-likelihood function 176 | # 177 | # @example 178 | # require 'statsample-glm' 179 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 180 | # glm = Statsample::GLM.compute data_set, "y", :logistic, constant: 1, algorithm: :mle 181 | # glm.log_likelihood 182 | # # => -21.4752278175261 183 | # 184 | def log_likelihood 185 | @regression.log_likelihood if @opts[:algorithm] == :mle 186 | end 187 | 188 | # Use the fitted GLM to obtain predictions on new data. 189 | # 190 | # == Arguments 191 | # 192 | # * new_data - a `Daru::DataFrame` containing new observations for the same 193 | # variables that were used to fit the model. The vectors must be given 194 | # in the same order as in the data frame that was originally used to fit 195 | # the model. If `new_data` is not provided, then the original data frame 196 | # which was used to fit the model, is used in place of `new_data`. 197 | # 198 | # == Returns 199 | # 200 | # A `Daru::Vector` containing the predictions. The predictions are 201 | # computed on the scale of the response variable (for example, for 202 | # the logistic regression model, the predictions are probabilities 203 | # on logit scale). 204 | # 205 | # == Usage 206 | # 207 | # require 'statsample-glm' 208 | # data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 209 | # glm = Statsample::GLM.compute data_set, "y", :logistic, {constant: 1} 210 | # new_data = Daru::DataFrame.new([[0.1, 0.2, 0.3], [-0.1, 0.0, 0.1]], 211 | # order: ["x1", "x2"]) 212 | # glm.predict new_data 213 | # # => 214 | # # # 215 | # # nil 216 | # # 0 0.6024496420392775 217 | # # 1 0.6486486378079906 218 | # # 2 0.6922216620285223 219 | # 220 | def predict new_data=nil 221 | if @opts[:constant] then 222 | new_data.add_vector :constant, [@opts[:constant]]*new_data.nrows 223 | end 224 | # Statsample::GLM::Normal model always has an intercept term, see #initialize 225 | if self.is_a? 
Statsample::GLM::Normal then 226 | new_data.add_vector :constant, [1.0]*new_data.nrows 227 | end 228 | 229 | @regression.predict new_data 230 | end 231 | 232 | private 233 | 234 | def set_default_opts_if_any 235 | @opts[:algorithm] ||= :irls 236 | @opts[:iterations] ||= 100 237 | @opts[:epsilon] ||= 1e-7 238 | @opts[:link] ||= :log 239 | end 240 | 241 | def create_vector arr 242 | Daru::Vector.new(arr) 243 | end 244 | 245 | def add_constant_vector x=1 246 | @data_set.add_vector :constant, [x]*@data_set.nrows 247 | end 248 | end 249 | end 250 | end 251 | -------------------------------------------------------------------------------- /lib/statsample-glm/glm/formula/formula.rb: -------------------------------------------------------------------------------- 1 | module Statsample 2 | module GLM 3 | # To process formula language 4 | class Formula 5 | attr_reader :tokens, :canonical_tokens 6 | 7 | def initialize(tokens) 8 | @tokens = tokens 9 | @canonical_tokens = parse_formula 10 | end 11 | 12 | def canonical_to_s 13 | canonical_tokens.join '+' 14 | end 15 | 16 | # private 17 | # TODO: Uncomment private after debuggin 18 | 19 | def parse_formula 20 | @tokens.inject([]) do |acc, token| 21 | acc + add_non_redundant_elements(token, acc) 22 | end 23 | end 24 | 25 | def add_non_redundant_elements(token, result_so_far) 26 | return [token] if token.value == '1' 27 | tokens = token.expand 28 | result_so_far = result_so_far.flat_map(&:expand) 29 | tokens -= result_so_far 30 | contract_if_possible tokens 31 | end 32 | 33 | def contract_if_possible(tokens) 34 | tokens.combination(2).each do |a, b| 35 | result = a.add b 36 | next unless result 37 | tokens.delete a 38 | tokens.delete b 39 | tokens << result 40 | return contract_if_possible tokens 41 | end 42 | tokens.sort 43 | end 44 | end 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /lib/statsample-glm/glm/formula/token.rb: 
-------------------------------------------------------------------------------- 1 | module Statsample 2 | module GLM 3 | # To encapsulate interaction as well as non-interaction terms 4 | class Token 5 | attr_reader :value, :full, :interact_terms 6 | 7 | def initialize(value, full = true) 8 | @interact_terms = value.include?(':') ? value.split(':') : [value] 9 | @full = coerce_full full 10 | end 11 | 12 | def value 13 | interact_terms.join(':') 14 | end 15 | 16 | def size 17 | # TODO: Return size 1 for value '1' also 18 | # CAn't do this at the moment because have to make 19 | # changes in sorting first 20 | value == '1' ? 0 : interact_terms.size 21 | end 22 | 23 | def condition_1?(other) 24 | # 1: ANYTHING + FACTOR- : ANYTHING = FACTOR : ANYTHING 25 | other.size == 2 && 26 | size == 1 && 27 | other.interact_terms.last == value && 28 | other.full.last == full.first && 29 | other.full.first == false 30 | end 31 | 32 | def condition_2?(other) 33 | # 2: ANYTHING + ANYTHING : FACTOR- = ANYTHING : FACTOR 34 | other.size == 2 && 35 | size == 1 && 36 | other.interact_terms.first == value && 37 | other.full.first == full.first && 38 | other.full.last == false 39 | end 40 | 41 | def add(other) # rubocop:disable Metrics/AbcSize 42 | # 1: ANYTHING + FACTOR- : ANYTHING = FACTOR : ANYTHING 43 | # 2: ANYTHING + ANYTHING : FACTOR- = ANYTHING : FACTOR 44 | if size > other.size 45 | other.add self 46 | 47 | elsif condition_1? other 48 | Token.new( 49 | "#{other.interact_terms.first}:#{value}", 50 | [true, other.full.last] 51 | ) 52 | 53 | elsif condition_2? other 54 | Token.new( 55 | "#{value}:#{other.interact_terms.last}", 56 | [other.full.first, true] 57 | ) 58 | 59 | elsif value == '1' && other.size == 1 60 | Token.new(other.value, true) 61 | end 62 | end 63 | 64 | def ==(other) 65 | value == other.value && 66 | full == other.full 67 | end 68 | 69 | alias eql? 
== 70 | 71 | def hash 72 | value.hash ^ full.hash 73 | end 74 | 75 | def <=>(other) 76 | size <=> other.size 77 | end 78 | 79 | def to_s 80 | interact_terms 81 | .zip(full) 82 | .map { |t, f| f ? t : t + '(-)' } 83 | .join ':' 84 | end 85 | 86 | def expand 87 | case size 88 | when 0 89 | [self] 90 | when 1 91 | [Token.new('1'), Token.new(value, false)] 92 | when 2 93 | a, b = interact_terms 94 | [Token.new('1'), Token.new(a, false), Token.new(b, false), 95 | Token.new(a + ':' + b, [false, false])] 96 | end 97 | end 98 | 99 | def to_df(df) 100 | case size 101 | when 1 102 | if df[value].category? 103 | df[value].contrast_code full: full.first 104 | else 105 | Daru::DataFrame.new value => df[value].to_a 106 | end 107 | when 2 108 | to_df_when_interaction(df) 109 | end 110 | end 111 | 112 | private 113 | 114 | def coerce_full(value) 115 | if value.is_a? Array 116 | value + Array.new((@interact_terms.size - value.size), true) 117 | else 118 | [value] * @interact_terms.size 119 | end 120 | end 121 | 122 | def to_df_when_interaction(df) 123 | case interact_terms.map { |t| df[t].category? 
} 124 | when [true, true] 125 | df.interact_code(interact_terms, full) 126 | when [false, false] 127 | to_df_numeric_interact_with_numeric df 128 | when [true, false] 129 | to_df_category_interact_with_numeric df 130 | when [false, true] 131 | to_df_numeric_interact_with_category df 132 | end 133 | end 134 | 135 | def to_df_numeric_interact_with_numeric(df) 136 | Daru::DataFrame.new value => (df[interact_terms.first] * 137 | df[interact_terms.last]).to_a 138 | end 139 | 140 | def to_df_category_interact_with_numeric(df) 141 | a, b = interact_terms 142 | Daru::DataFrame.new( 143 | df[a].contrast_code(full: full.first) 144 | .map { |dv| ["#{dv.name}:#{b}", (dv * df[b]).to_a] } 145 | .to_h 146 | ) 147 | end 148 | 149 | def to_df_numeric_interact_with_category(df) 150 | a, b = interact_terms 151 | Daru::DataFrame.new( 152 | df[b].contrast_code(full: full.last) 153 | .map { |dv| ["#{a}:#{dv.name}", (dv * df[a]).to_a] } 154 | .to_h 155 | ) 156 | end 157 | end 158 | end 159 | end 160 | -------------------------------------------------------------------------------- /lib/statsample-glm/glm/formula/wrapper.rb: -------------------------------------------------------------------------------- 1 | require_relative 'token' 2 | 3 | module Statsample 4 | module GLM 5 | # This class recognizes what terms are numeric 6 | # and accordingly forms groups which are fed to Formula 7 | # Once they are parsed with Formula, they are combined back 8 | class FormulaWrapper 9 | attr_reader :tokens, :y, :canonical_tokens 10 | 11 | # Initializes formula wrapper object to parse a given formula into 12 | # some tokens which do not overlap one another. 
13 | # @note Specify 0 as a term in the formula if you do not want constant 14 | # to be included in the parsed formula 15 | # @param [string] formula to parse 16 | # @param [Daru::DataFrame] df dataframe requried to know what vectors 17 | # are numerical 18 | # @example 19 | # df = Daru::DataFrame.from_csv 'spec/data/df.csv' 20 | # df.to_category 'c', 'd', 'e' 21 | # formula = Statsample::GLM::FormulaWrapper.new 'y~a+d:c', df 22 | # formula.canonical_to_s 23 | # #=> "1+c(-)+d(-):c+a" 24 | def initialize(formula, df) 25 | @df = df 26 | # @y store the LHS term that is name of vector to be predicted 27 | # @tokens store the RHS terms of the formula 28 | formula = formula.gsub(/\s+/, '') 29 | lhs, rhs = split_lhs_rhs formula 30 | @y = Token(lhs) 31 | @tokens = split_to_tokens reduce_formula(rhs) 32 | @tokens = @tokens.uniq.sort 33 | manage_constant_term 34 | @canonical_tokens = non_redundant_tokens 35 | end 36 | 37 | def split_lhs_rhs(expr) 38 | expr.split '~' 39 | end 40 | 41 | def reduce_formula(expr) 42 | # Split the expression to array 43 | expr = expr.split %r{(?=[+*/:()])|(?<=[+*/:()])} 44 | # Convert infix exp to postfix exp 45 | postfix_expr = to_postfix expr 46 | # Evaluate the expression 47 | eval_postfix postfix_expr 48 | end 49 | 50 | # Returns canonical tokens in a readable form. 51 | # @return [String] canonical tokens in a readable form. 
52 | # @note 'y~a+b(-)' means 'a' exist in full rank expansion 53 | # and 'b(-)' exist in reduced rank expansion 54 | # @example 55 | # df = Daru::DataFrame.from_csv 'spec/data/df.csv' 56 | # df.to_category 'c', 'd', 'e' 57 | # formula = Statsample::GLM::FormulaWrapper.new 'y~a+d:c', df 58 | # formula.canonical_to_s 59 | # #=> "1+c(-)+d(-):c+a" 60 | def canonical_to_s 61 | canonical_tokens.join '+' 62 | end 63 | 64 | # Returns tokens to produce non-redundant design matrix 65 | # @return [Array] array of tokens that do not produce redundant matrix 66 | def non_redundant_tokens 67 | groups = split_to_groups 68 | # TODO: An enhancement 69 | # Right now x:c appears as c:x 70 | groups.each { |k, v| groups[k] = strip_numeric v, k } 71 | groups.each { |k, v| groups[k] = Formula.new(v).canonical_tokens } 72 | groups.flat_map { |k, v| add_numeric v, k } 73 | end 74 | 75 | def to_s 76 | "#{@y}~#{@tokens.join '+'}" 77 | end 78 | 79 | private 80 | 81 | TOKEN_0 = Token.new '0' 82 | TOKEN_1 = Token.new '1' 83 | def Token(val, full = true) # rubocop:disable Style/MethodName 84 | return TOKEN_0 if val == '0' 85 | return TOKEN_1 if val == '1' 86 | Token.new(val, full) 87 | end 88 | 89 | # Removes intercept token if term '0' is found in the formula. 90 | # Intercept token remains if term '1' is found. 91 | # If neither term '0' nor term '1' is found then, 92 | # intercept token is added. 93 | def manage_constant_term 94 | @tokens.unshift Token('1') unless 95 | @tokens.include?(Token('1')) || 96 | @tokens.include?(Token('0')) 97 | @tokens.delete Token('0') 98 | end 99 | 100 | # Groups the tokens to gropus based on the numerical terms 101 | # they are interacting with. 
# Groups the tokens based on the numerical terms
# they are interacting with.
# @return [Hash] map from array-of-numeric-terms to grouped tokens
def split_to_groups
  @tokens.group_by { |t| extract_numeric t }
end

# Add numeric interaction term which was removed earlier
# @param [Array] tokens tokens on which to add numerical terms
# @param [Array] numeric array of numeric terms to add
def add_numeric(tokens, numeric)
  tokens.map do |t|
    terms = t.interact_terms + numeric
    if terms == ['1']
      Token('1')
    else
      terms = terms.reject { |i| i == '1' }
      Token(terms.join(':'), t.full)
    end
  end
end

# Strip numerical interacting terms
# @param [Array] tokens tokens from which to strip numeric
# @param [Array] numeric array of numeric terms to strip from tokens
# @return [Array] array of tokens with stripped numerical terms
def strip_numeric(tokens, numeric)
  tokens.map do |t|
    terms = t.interact_terms - numeric
    terms = ['1'] if terms.empty?
    Token(terms.join(':'))
  end
end

# Extract numeric interacting terms
# @param [Statsample::GLM::Token] token to extract numeric terms from
# @return [Array] array of numerical terms
def extract_numeric(token)
  terms = token.interact_terms
  return [] if terms == ['1']
  terms.reject { |t| @df[t].category? }
end

def split_to_tokens(formula)
  formula.split('+').map { |t| Token(t) }
end

# ==========BEGIN==========
# Helpers for reduce_formula
PRIORITY = %w(+ * / :).freeze
# True when op1 binds no tighter than op2 (and op2 is an operator).
def priority_le?(op1, op2)
  return false unless PRIORITY.include? op2
  PRIORITY.index(op1) <= PRIORITY.index(op2)
end

# Shunting-yard conversion from infix to postfix.
# to_postfix ['a', '+', 'b'] gives ['a', 'b', '+']
# FIX: previously this method destructively appended ')' to its
# argument array (`expr << ')'`); it now works on a copy so the
# caller's token list is left untouched. The return value is
# unchanged.
def to_postfix(expr) # rubocop:disable Metrics/MethodLength
  res_exp = []
  stack = ['(']
  (expr + [')']).each do |s|
    if s == '('
      stack.push '('
    elsif PRIORITY.include? s
      res_exp << stack.pop while priority_le?(s, stack.last)
      stack.push s
    elsif s == ')'
      res_exp << stack.pop until stack.last == '('
      stack.pop
    else
      res_exp << s
    end
  end
  res_exp
end

# Evaluate a postfix token stream into an expanded formula string.
# eval_postfix 'ab*' gives 'a+b+a:b'
def eval_postfix(expr)
  # Scan through each symbol
  stack = []
  expr.each do |s|
    if PRIORITY.include? s
      y = stack.pop
      x = stack.pop
      stack << apply_operation(s, x, y)
    else
      stack.push s
    end
  end
  stack.pop
end

# Expand the interaction ':' between two '+'-separated expressions.
# apply_interact_op 'a+b', 'c' gives 'a:c+b:c'
def apply_interact_op(x, y)
  left = x.split('+')
  right = y.split('+')
  left.product(right).map { |l, r| "#{l}:#{r}" }.join '+'
end

# Apply a single formula operator to two expanded sub-expressions.
# @raise [ArgumentError] when op is not one of + : * /
def apply_operation(op, x, y)
  case op
  when '+'
    [x, y].join op
  when ':'
    apply_interact_op x, y
  when '*'
    [x, y, apply_interact_op(x, y)].join '+'
  when '/'
    [x, apply_interact_op(x, y)].join '+'
  else
    raise ArgumentError, "Invalid operation #{op}."
  end
end
#==========END==========
# Use the fitted GLM to obtain predictions on new data.
#
# == Arguments
#
# * new_data_set - a `Daru::DataFrame` with the same variables, in the
#   same order, as the data frame used to fit the model. When omitted,
#   the original training data is used.
#
# == Returns
#
# A `Daru::Vector` of predictions on the scale of the response
# variable (e.g. probabilities for logistic regression).
#
def predict new_data_set=nil
  return @fitted_mean_values if new_data_set.nil?
  new_data_matrix = new_data_set.to_matrix
  coeffs = @coefficients.to_matrix :vertical
  create_vector measurement(new_data_matrix, coeffs).to_a.flatten
end

private

# Iteratively reweighted least squares: Newton-Raphson updates of the
# coefficient vector until the summed absolute change falls below
# @opts[:epsilon] or @opts[:iterations] passes have been made.
def irls
  max_iter = @opts[:iterations]
  b = Matrix.column_vector Array.new(@data_set.column_size, 0.0)

  1.upto(max_iter) do
    step = hessian(@data_set, b).inverse *
           jacobian(@data_set, b, @dependent)
    b_new = b - step
    delta = (b_new - b).map(&:abs).to_a.flatten.inject(:+)
    b = b_new
    break if delta < @opts[:epsilon]
  end

  @coefficients = create_vector(b.column_vectors[0])
  # NOTE(review): this reports the configured maximum, not the number
  # of iterations actually performed — confirm before relying on it.
  @iterations = max_iter
  @standard_errors = create_vector(
    hessian(@data_set, b).inverse
      .diagonal
      .map { |v| Math.sqrt(-v) }
  )
  @fitted_mean_values = create_vector measurement(@data_set, b).to_a.flatten
  @residuals = @dependent - @fitted_mean_values
  @degrees_of_freedom = @dependent.count - @data_set.column_size
end

# Wrap a plain array in a Daru::Vector.
def create_vector arr
  Daru::Vector.new(arr)
end
require 'statsample-glm/glm/irls/base'

module Statsample
  module GLM
    module IRLS
      # Logistic regression fitted via iteratively reweighted least
      # squares. Provides the link-specific mean, weight, jacobian and
      # hessian used by IRLS::Base.
      class Logistic < Statsample::GLM::IRLS::Base
        def initialize data_set, dependent, opts={}
          super data_set, dependent, opts
        end

        def to_s
          "Statsample::GLM::Logistic"
        end

        protected

        # Mean response: logistic sigmoid of the linear predictor x*b.
        def measurement x, b
          (x * b).map { |eta| 1 / (1 + Math.exp(-eta)) }
        end

        # Diagonal weight matrix; entries are the first derivatives of
        # the logit mean, p * (1 - p).
        def weight x, b
          probs = measurement(x, b).column_vectors.map(&:to_a).flatten
          derivs = probs.map { |p| p * (1 - p) }
          Matrix.diagonal(*derivs)
        end

        # Gradient of the log-likelihood: X' * (y - mu).
        def jacobian x, b, y
          mu = measurement(x, b).column_vectors.map(&:to_a).flatten
          residual = y.zip(mu).map { |yi, mi| yi - mi }
          x.transpose * Matrix.column_vector(residual)
        end

        # Negated information matrix: -(X' W X).
        def hessian x, b
          (x.transpose * weight(x, b) * x).map { |e| -e }
        end
      end
    end
  end
end
require 'statsample-glm/glm/base'

module Statsample
  module GLM
    # Logistic regression model wrapper; all fitting work is done by
    # Statsample::GLM::Base.
    class Logistic < Statsample::GLM::Base
      # @param data_set explanatory variables
      # @param dependent response variable
      # @param opts [Hash] fitting options
      # FIX: `opts` now defaults to an empty hash, matching the
      # signatures of the sibling Normal, Poisson and Probit classes
      # (it was a required positional argument here only). Existing
      # callers passing opts are unaffected.
      def initialize data_set, dependent, opts={}
        super data_set, dependent, opts
      end

      def to_s
        "Statsample::GLM::Logistic"
      end
    end
  end
end
# Standard errors of the coefficient estimates: square roots of the
# diagonal entries of the variance-covariance matrix, one per
# predictor, returned as a Daru::Vector.
def standard_errors
  errors = @data_set.vectors.to_a.each_index.map do |i|
    Math.sqrt(@var_cov_matrix[i, i])
  end
  create_vector errors
end
# Use the fitted GLM to obtain predictions on new data.
#
# == Arguments
#
# * new_data_set - a `Daru::DataFrame` with the same variables, in the
#   same order, as the data frame used to fit the model. When omitted,
#   the original training data is used.
#
# == Returns
#
# A `Daru::Vector` of predictions on the scale of the response
# variable.
#
def predict new_data_set=nil
  return @fitted_mean_values if new_data_set.nil?
  new_data_matrix = new_data_set.to_matrix
  # Normal#measurement expects a redundant trailing entry (the
  # sigma^2 slot) in the coefficient vector, which it discards
  # internally — hence the special case.
  b = if is_a?(Statsample::GLM::MLE::Normal)
        Matrix.column_vector(@coefficients.to_a + [nil])
      else
        @coefficients.to_matrix(:vertical)
      end
  create_vector measurement(new_data_matrix, b).to_a.flatten
end
# Calculate likelihood for matrices x and y, given b parameters:
# product of the per-observation likelihoods.
def likelihood x, y, b
  (0...x.row_size).inject(1) do |prod, i|
    xi = Matrix.rows([x.row(i).to_a.map(&:to_f)])
    prod * likelihood_i(xi, y[i, 0].to_f, b)
  end
end

# Calculate log likelihood for matrices x and y, given b parameters:
# sum of the per-observation log-likelihood terms.
def _log_likelihood x, y, b
  (0...x.row_size).inject(0) do |sum, i|
    xi = Matrix.rows([x.row(i).to_a.map(&:to_f)])
    sum + log_likelihood_i(xi, y[i, 0].to_f, b)
  end
end
# Wrap a plain array in a Daru::Vector (shared result helper).
def create_vector arr
  Daru::Vector.new(arr)
end

# F(B'Xi): logistic CDF of the linear predictor, clamped away from
# exactly 0 and 1 so the log-likelihood terms stay finite.
def f(b, xi)
  eta = (xi * b)[0, 0]
  prob = 1.0 / (1.0 + Math.exp(-eta))
  return 1e-15 if prob == 0.0
  return 0.999999999999999 if prob == 1.0
  prob
end

# Likelihood for x_i vector, y_i scalar and b parameters:
# F^y * (1 - F)^(1 - y)
def likelihood_i(xi, yi, b)
  fbx = f(b, xi)
  (fbx**yi) * ((1 - fbx)**(1 - yi))
end

# Log likelihood for x_i vector, y_i scalar and b parameters.
def log_likelihood_i(xi, yi, b)
  fbx = f(b, xi)
  yi.to_f * Math.log(fbx) + (1.0 - yi.to_f) * Math.log(1.0 - fbx)
end
# Second derivative (Hessian) of the logistic log-likelihood.
# x: Matrix (NxM), y: Matrix (Nx1), p2: Matrix (Mx1)
def second_derivative(x, y, p2)
  raise "x.rows!=y.rows" if x.row_size != y.row_size
  raise "x.columns!=p.rows" if x.column_size != p2.row_size
  k = x.column_size
  sd = Array.new(k) { Array.new(k, 0.0) }
  x.row_size.times do |i|
    row = x.row(i).to_a
    prob = p_minus(row, p2)
    w = prob * (1 - prob)
    k.times do |j|
      k.times do |l|
        sd[j][l] -= w * row[j] * row[l]
      end
    end
  end
  Matrix.rows(sd, true)
end

# Mean response for a design matrix and coefficient vector:
# elementwise logistic sigmoid of x*b.
def measurement x, b
  (x * b).map { |eta| 1 / (1 + Math.exp(-eta)) }
end

private

# Logistic probability 1/(1 + e^-(x.b)) for one observation row.
def p_minus(x_row, p)
  acc = 0.0
  x_row.each_index { |i| acc += x_row[i] * p[i, 0] }
  1 / (1 + Math.exp(-acc))
end

# Complementary form 1/(1 + e^(x.b)) for one observation row.
def p_plus(x_row, p)
  acc = 0.0
  x_row.each_index { |i| acc += x_row[i] * p[i, 0] }
  1 / (1 + Math.exp(acc))
end
# Instance methods of Statsample::GLM::MLE::Normal (normal-errors MLE).
# The coefficient vector carries sigma^2 as its final entry.

# Mean response: X * beta, dropping the trailing sigma^2 slot of the
# coefficient vector.
def measurement data_set, coefficients
  (data_set * coefficients[0..-2, 0]).map { |xb| xb }
end

# Total log-likelihood for X, Y and B (overrides the Base version);
# the last row of b is sigma^2.
def _log_likelihood(x, y, b)
  n = x.row_size.to_f
  sigma2 = b[b.row_size - 1, 0]
  betas = Matrix.columns([b.column(0).to_a[0...b.row_size - 1]])
  resid = y - (x * betas)
  quad = (1 / (2 * sigma2)) * resid.t * resid
  -(n / 2.0) * Math.log(2 * Math::PI) -
    (n / 2.0) * Math.log(sigma2) - quad[0, 0]
end

# First derivative for the Normal model.
# p should be [k+1, 1]: k betas followed by sigma^2.
def first_derivative(x, y, p)
  raise "x.rows != y.rows" if x.row_size != y.row_size
  raise "x.columns + 1 != p.rows" if x.column_size + 1 != p.row_size

  n = x.row_size
  k = x.column_size
  beta = Matrix.rows(Array.new(k) { |i| [p[i, 0]] })
  sigma2 = p[k, 0]
  sigma4 = sigma2 * sigma2
  resid = y - (x * beta)
  xte = x.transpose * resid
  ete = resid.transpose * resid
  # rows of the Jacobian: beta gradients, then the sigma^2 gradient
  rows = Array.new(k) { |i| [xte[i, 0] / sigma2] }
  rows << [ete[0, 0] / (2 * sigma4) - n / (2 * sigma2)]
  Matrix.rows(rows, true)
end

# Second derivative for the Normal model.
# p should be [k+1, 1]: k betas followed by sigma^2.
def second_derivative(x, y, p)
  raise "x.rows != y.rows" if x.row_size != y.row_size
  raise "x.columns + 1 != p.rows" if x.column_size + 1 != p.row_size
  k = x.column_size
  beta = Matrix.rows(Array.new(k) { |i| [p[i, 0]] })
  sigma2 = p[k, 0]
  sigma4 = sigma2 * sigma2
  sigma6 = sigma2 * sigma4
  resid = y - (x * beta)
  xtx = x.transpose * x
  xte = x.transpose * resid
  ete = resid.transpose * resid
  # rows of the Hessian: beta block, mixed terms, sigma^2 corner
  rows = Array.new(k) do |i|
    row = Array.new(k) { |j| -xtx[i, j] / sigma2 }
    row << (-xte[i, 0] / sigma4)
    row
  end
  last_row = Array.new(k) { |i| -xte[i, 0] / sigma4 }
  last_row << (2 * sigma4 - ete[0, 0] / sigma6)
  rows << last_row
  Matrix.rows(rows, true)
end
# Probit link: F(B'Xi) is the standard normal CDF, f(B'Xi) its PDF.
# GSL-backed versions are defined when GSL is available, pure-Ruby
# (distribution gem) versions otherwise.
if Statsample.has_gsl?
  # F(B'Xi)
  def f(b, x)
    GSL::Cdf::ugaussian_P((x * b)[0, 0])
  end

  # f(B'Xi)
  def ff(b, x)
    GSL::Ran::ugaussian_pdf((x * b)[0, 0])
  end
else
  def f(b, x) #:nodoc:
    Distribution::Normal.cdf((x * b)[0, 0])
  end

  def ff(b, x) #:nodoc:
    Distribution::Normal.pdf((x * b)[0, 0])
  end
end

# Log likelihood for x_i vector, y_i scalar and b parameters.
def log_likelihood_i(xi, yi, b)
  fbx = f(b, xi)
  yi.to_f * Math.log(fbx) + (1.0 - yi.to_f) * Math.log(1.0 - fbx)
end

# First derivative of the probit log-likelihood.
# x: Matrix (NxM), y: Matrix (Nx1), b: Matrix (Mx1)
def first_derivative(x, y, b)
  raise "x.rows!=y.rows" if x.row_size != y.row_size
  raise "x.columns!=p.rows" if x.column_size != b.row_size
  k = x.column_size
  fd = Array.new(k) { [0.0] }
  x.row_size.times do |i|
    xi = Matrix.rows([x.row(i).to_a])
    fbx = f(b, xi)
    scale = (y[i, 0] - fbx) / (fbx * (1 - fbx)) * ff(b, xi)
    k.times { |j| fd[j][0] += scale * xi[0, j] }
  end
  Matrix.rows(fd, true)
end
require 'statsample-glm/glm/base'

module Statsample
  module GLM
    # Linear (normal errors) regression model wrapper.
    class Normal < Statsample::GLM::Base
      def initialize data_set, dependent, opts={}
        super data_set, dependent, opts
      end

      def to_s
        "Statsample::GLM::Normal"
      end
    end

    # Poisson regression model wrapper.
    class Poisson < Statsample::GLM::Base
      def initialize data_set, dependent, opts={}
        super data_set, dependent, opts
      end

      def to_s
        "Statsample::GLM::Poisson"
      end
    end

    # Probit regression model wrapper.
    class Probit < Statsample::GLM::Base
      def initialize data_set, dependent, opts={}
        super data_set, dependent, opts
      end

      def to_s
        "Statsample::GLM::Probit"
      end
    end
  end
end
# Initializes a regression object to fit a model using the
# formula language.
# @param [String] formula formula for creating model
# @param [Daru::DataFrame] df dataframe to be used for the fitting model
# @param [Symbol] method method of regression.
#   For example, :logistic, :normal, etc.
# @example
#   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
#   df.to_category 'c', 'd', 'e'
#   reg = Statsample::GLM::Regression.new 'y~a+b:c', df, :logistic
def initialize(formula, df, method, opts = {})
  @formula = FormulaWrapper.new formula, df
  @df = df
  @method = method
  @opts = opts
end

# Returns the fitted model
# @return model associated with regression object obtained by applying
#   the formula language on the given dataframe with given method
# @example
#   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
#   df.to_category 'c', 'd', 'e'
#   reg = Statsample::GLM::Regression.new 'y~a+b:c', df, :logistic
#   mod = reg.model
#   mod.coefficients :hash
#   # => {:a=>-0.4315113121759436,
#   #     :"c_no:b"=>-0.23438037201383238,
#   #     :"c_yes:b"=>-0.23683973232674818,
#   #     :constant=>16.81450207777355}
def model
  # fit_model memoizes into @model on first call
  @model || fit_model
end

# Obtain predictions on new data
# @param [Daru::DataFrame] new_data the data to obtain predictions on
# @return [Daru::Vector] vector containing predictions for new data
# @example
#   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
#   df.to_category 'c', 'd', 'e'
#   reg = Statsample::GLM::Regression.new 'y~a+b:c', df, :logistic
#   reg.predict df.head(3)
def predict(new_data)
  model.predict df_for_prediction(new_data)
end
# Returns dataframe obtained through applying the formula
# on the given dataframe. It is used for obtaining predictions
# on new data.
# @param [Daru::DataFrame] df dataframe for which to obtain predictions
# @return [Daru::DataFrame] dataframe obtained after applying formula
# @example
#   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
#   df.to_category 'c', 'd', 'e'
#   reg = Statsample::GLM::Regression.new 'y~a+b:c', df, :logistic
#   reg.df_for_prediction df.head(3)
def df_for_prediction(df)
  # TODO: This code can be improved.
  # See https://github.com/v0dro/daru/issues/245
  df = Daru::DataFrame.new(df.to_h,
    order: @df.vectors.to_a & df.vectors.to_a
  )
  df.vectors.each do |vec|
    next unless @df[vec].category?
    # carry over category metadata from the training dataframe
    df[vec] = df[vec].to_category
    df[vec].categories = @df[vec].categories
    df[vec].base_category = @df[vec].base_category
  end
  canonicalize_df(df)
end

# Returns dataframe obtained through applying formula on the dataframe.
# It is used for fitting the model.
# @return [Daru::DataFrame] dataframe obtained after applying formula
# @example
#   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
#   df.to_category 'c', 'd', 'e'
#   reg = Statsample::GLM::Regression.new 'y~a+b:c', df, :logistic
#   reg.df_for_regression.head(3)
def df_for_regression
  df = canonicalize_df(@df)
  df[@formula.y.value] = @df[@formula.y.value]
  df
end

private

# Builds the design dataframe for the canonical tokens, skipping the
# intercept token.
# FIX: previously this method called `shift` on the array returned by
# @formula.canonical_tokens, destructively removing the intercept
# token from state owned by the formula wrapper; it now works on a
# copy via `drop`.
def canonicalize_df(orig_df)
  tokens = @formula.canonical_tokens
  tokens = tokens.drop(1) if tokens.first.value == '1'
  tokens.map { |t| t.to_df orig_df }.reduce(&:merge)
end
Token.new('1') 115 | @model = Statsample::GLM.compute( 116 | df_for_regression, 117 | @formula.y.value, 118 | @method, 119 | @opts 120 | ) 121 | end 122 | end 123 | end 124 | end 125 | -------------------------------------------------------------------------------- /lib/statsample-glm/version.rb: -------------------------------------------------------------------------------- 1 | module Statsample 2 | module GLM 3 | VERSION = "0.2.1" 4 | end 5 | end 6 | -------------------------------------------------------------------------------- /spec/data/binary.csv: -------------------------------------------------------------------------------- 1 | admit,gre,gpa,rank 2 | 0,380,3.61,3 3 | 1,660,3.67,3 4 | 1,800,4,1 5 | 1,640,3.19,4 6 | 0,520,2.93,4 7 | 1,760,3,2 8 | 1,560,2.98,1 9 | 0,400,3.08,2 10 | 1,540,3.39,3 11 | 0,700,3.92,2 12 | 0,800,4,4 13 | 0,440,3.22,1 14 | 1,760,4,1 15 | 0,700,3.08,2 16 | 1,700,4,1 17 | 0,480,3.44,3 18 | 0,780,3.87,4 19 | 0,360,2.56,3 20 | 0,800,3.75,2 21 | 1,540,3.81,1 22 | 0,500,3.17,3 23 | 1,660,3.63,2 24 | 0,600,2.82,4 25 | 0,680,3.19,4 26 | 1,760,3.35,2 27 | 1,800,3.66,1 28 | 1,620,3.61,1 29 | 1,520,3.74,4 30 | 1,780,3.22,2 31 | 0,520,3.29,1 32 | 0,540,3.78,4 33 | 0,760,3.35,3 34 | 0,600,3.4,3 35 | 1,800,4,3 36 | 0,360,3.14,1 37 | 0,400,3.05,2 38 | 0,580,3.25,1 39 | 0,520,2.9,3 40 | 1,500,3.13,2 41 | 1,520,2.68,3 42 | 0,560,2.42,2 43 | 1,580,3.32,2 44 | 1,600,3.15,2 45 | 0,500,3.31,3 46 | 0,700,2.94,2 47 | 1,460,3.45,3 48 | 1,580,3.46,2 49 | 0,500,2.97,4 50 | 0,440,2.48,4 51 | 0,400,3.35,3 52 | 0,640,3.86,3 53 | 0,440,3.13,4 54 | 0,740,3.37,4 55 | 1,680,3.27,2 56 | 0,660,3.34,3 57 | 1,740,4,3 58 | 0,560,3.19,3 59 | 0,380,2.94,3 60 | 0,400,3.65,2 61 | 0,600,2.82,4 62 | 1,620,3.18,2 63 | 0,560,3.32,4 64 | 0,640,3.67,3 65 | 1,680,3.85,3 66 | 0,580,4,3 67 | 0,600,3.59,2 68 | 0,740,3.62,4 69 | 0,620,3.3,1 70 | 0,580,3.69,1 71 | 0,800,3.73,1 72 | 0,640,4,3 73 | 0,300,2.92,4 74 | 0,480,3.39,4 75 | 0,580,4,2 76 | 0,720,3.45,4 77 | 0,720,4,3 78 | 
0,560,3.36,3 79 | 1,800,4,3 80 | 0,540,3.12,1 81 | 1,620,4,1 82 | 0,700,2.9,4 83 | 0,620,3.07,2 84 | 0,500,2.71,2 85 | 0,380,2.91,4 86 | 1,500,3.6,3 87 | 0,520,2.98,2 88 | 0,600,3.32,2 89 | 0,600,3.48,2 90 | 0,700,3.28,1 91 | 1,660,4,2 92 | 0,700,3.83,2 93 | 1,720,3.64,1 94 | 0,800,3.9,2 95 | 0,580,2.93,2 96 | 1,660,3.44,2 97 | 0,660,3.33,2 98 | 0,640,3.52,4 99 | 0,480,3.57,2 100 | 0,700,2.88,2 101 | 0,400,3.31,3 102 | 0,340,3.15,3 103 | 0,580,3.57,3 104 | 0,380,3.33,4 105 | 0,540,3.94,3 106 | 1,660,3.95,2 107 | 1,740,2.97,2 108 | 1,700,3.56,1 109 | 0,480,3.13,2 110 | 0,400,2.93,3 111 | 0,480,3.45,2 112 | 0,680,3.08,4 113 | 0,420,3.41,4 114 | 0,360,3,3 115 | 0,600,3.22,1 116 | 0,720,3.84,3 117 | 0,620,3.99,3 118 | 1,440,3.45,2 119 | 0,700,3.72,2 120 | 1,800,3.7,1 121 | 0,340,2.92,3 122 | 1,520,3.74,2 123 | 1,480,2.67,2 124 | 0,520,2.85,3 125 | 0,500,2.98,3 126 | 0,720,3.88,3 127 | 0,540,3.38,4 128 | 1,600,3.54,1 129 | 0,740,3.74,4 130 | 0,540,3.19,2 131 | 0,460,3.15,4 132 | 1,620,3.17,2 133 | 0,640,2.79,2 134 | 0,580,3.4,2 135 | 0,500,3.08,3 136 | 0,560,2.95,2 137 | 0,500,3.57,3 138 | 0,560,3.33,4 139 | 0,700,4,3 140 | 0,620,3.4,2 141 | 1,600,3.58,1 142 | 0,640,3.93,2 143 | 1,700,3.52,4 144 | 0,620,3.94,4 145 | 0,580,3.4,3 146 | 0,580,3.4,4 147 | 0,380,3.43,3 148 | 0,480,3.4,2 149 | 0,560,2.71,3 150 | 1,480,2.91,1 151 | 0,740,3.31,1 152 | 1,800,3.74,1 153 | 0,400,3.38,2 154 | 1,640,3.94,2 155 | 0,580,3.46,3 156 | 0,620,3.69,3 157 | 1,580,2.86,4 158 | 0,560,2.52,2 159 | 1,480,3.58,1 160 | 0,660,3.49,2 161 | 0,700,3.82,3 162 | 0,600,3.13,2 163 | 0,640,3.5,2 164 | 1,700,3.56,2 165 | 0,520,2.73,2 166 | 0,580,3.3,2 167 | 0,700,4,1 168 | 0,440,3.24,4 169 | 0,720,3.77,3 170 | 0,500,4,3 171 | 0,600,3.62,3 172 | 0,400,3.51,3 173 | 0,540,2.81,3 174 | 0,680,3.48,3 175 | 1,800,3.43,2 176 | 0,500,3.53,4 177 | 1,620,3.37,2 178 | 0,520,2.62,2 179 | 1,620,3.23,3 180 | 0,620,3.33,3 181 | 0,300,3.01,3 182 | 0,620,3.78,3 183 | 0,500,3.88,4 184 | 0,700,4,2 185 | 1,540,3.84,2 186 | 
0,500,2.79,4 187 | 0,800,3.6,2 188 | 0,560,3.61,3 189 | 0,580,2.88,2 190 | 0,560,3.07,2 191 | 0,500,3.35,2 192 | 1,640,2.94,2 193 | 0,800,3.54,3 194 | 0,640,3.76,3 195 | 0,380,3.59,4 196 | 1,600,3.47,2 197 | 0,560,3.59,2 198 | 0,660,3.07,3 199 | 1,400,3.23,4 200 | 0,600,3.63,3 201 | 0,580,3.77,4 202 | 0,800,3.31,3 203 | 1,580,3.2,2 204 | 1,700,4,1 205 | 0,420,3.92,4 206 | 1,600,3.89,1 207 | 1,780,3.8,3 208 | 0,740,3.54,1 209 | 1,640,3.63,1 210 | 0,540,3.16,3 211 | 0,580,3.5,2 212 | 0,740,3.34,4 213 | 0,580,3.02,2 214 | 0,460,2.87,2 215 | 0,640,3.38,3 216 | 1,600,3.56,2 217 | 1,660,2.91,3 218 | 0,340,2.9,1 219 | 1,460,3.64,1 220 | 0,460,2.98,1 221 | 1,560,3.59,2 222 | 0,540,3.28,3 223 | 0,680,3.99,3 224 | 1,480,3.02,1 225 | 0,800,3.47,3 226 | 0,800,2.9,2 227 | 1,720,3.5,3 228 | 0,620,3.58,2 229 | 0,540,3.02,4 230 | 0,480,3.43,2 231 | 1,720,3.42,2 232 | 0,580,3.29,4 233 | 0,600,3.28,3 234 | 0,380,3.38,2 235 | 0,420,2.67,3 236 | 1,800,3.53,1 237 | 0,620,3.05,2 238 | 1,660,3.49,2 239 | 0,480,4,2 240 | 0,500,2.86,4 241 | 0,700,3.45,3 242 | 0,440,2.76,2 243 | 1,520,3.81,1 244 | 1,680,2.96,3 245 | 0,620,3.22,2 246 | 0,540,3.04,1 247 | 0,800,3.91,3 248 | 0,680,3.34,2 249 | 0,440,3.17,2 250 | 0,680,3.64,3 251 | 0,640,3.73,3 252 | 0,660,3.31,4 253 | 0,620,3.21,4 254 | 1,520,4,2 255 | 1,540,3.55,4 256 | 1,740,3.52,4 257 | 0,640,3.35,3 258 | 1,520,3.3,2 259 | 1,620,3.95,3 260 | 0,520,3.51,2 261 | 0,640,3.81,2 262 | 0,680,3.11,2 263 | 0,440,3.15,2 264 | 1,520,3.19,3 265 | 1,620,3.95,3 266 | 1,520,3.9,3 267 | 0,380,3.34,3 268 | 0,560,3.24,4 269 | 1,600,3.64,3 270 | 1,680,3.46,2 271 | 0,500,2.81,3 272 | 1,640,3.95,2 273 | 0,540,3.33,3 274 | 1,680,3.67,2 275 | 0,660,3.32,1 276 | 0,520,3.12,2 277 | 1,600,2.98,2 278 | 0,460,3.77,3 279 | 1,580,3.58,1 280 | 1,680,3,4 281 | 1,660,3.14,2 282 | 0,660,3.94,2 283 | 0,360,3.27,3 284 | 0,660,3.45,4 285 | 0,520,3.1,4 286 | 1,440,3.39,2 287 | 0,600,3.31,4 288 | 1,800,3.22,1 289 | 1,660,3.7,4 290 | 0,800,3.15,4 291 | 0,420,2.26,4 292 | 
1,620,3.45,2 293 | 0,800,2.78,2 294 | 0,680,3.7,2 295 | 0,800,3.97,1 296 | 0,480,2.55,1 297 | 0,520,3.25,3 298 | 0,560,3.16,1 299 | 0,460,3.07,2 300 | 0,540,3.5,2 301 | 0,720,3.4,3 302 | 0,640,3.3,2 303 | 1,660,3.6,3 304 | 1,400,3.15,2 305 | 1,680,3.98,2 306 | 0,220,2.83,3 307 | 0,580,3.46,4 308 | 1,540,3.17,1 309 | 0,580,3.51,2 310 | 0,540,3.13,2 311 | 0,440,2.98,3 312 | 0,560,4,3 313 | 0,660,3.67,2 314 | 0,660,3.77,3 315 | 1,520,3.65,4 316 | 0,540,3.46,4 317 | 1,300,2.84,2 318 | 1,340,3,2 319 | 1,780,3.63,4 320 | 1,480,3.71,4 321 | 0,540,3.28,1 322 | 0,460,3.14,3 323 | 0,460,3.58,2 324 | 0,500,3.01,4 325 | 0,420,2.69,2 326 | 0,520,2.7,3 327 | 0,680,3.9,1 328 | 0,680,3.31,2 329 | 1,560,3.48,2 330 | 0,580,3.34,2 331 | 0,500,2.93,4 332 | 0,740,4,3 333 | 0,660,3.59,3 334 | 0,420,2.96,1 335 | 0,560,3.43,3 336 | 1,460,3.64,3 337 | 1,620,3.71,1 338 | 0,520,3.15,3 339 | 0,620,3.09,4 340 | 0,540,3.2,1 341 | 1,660,3.47,3 342 | 0,500,3.23,4 343 | 1,560,2.65,3 344 | 0,500,3.95,4 345 | 0,580,3.06,2 346 | 0,520,3.35,3 347 | 0,500,3.03,3 348 | 0,600,3.35,2 349 | 0,580,3.8,2 350 | 0,400,3.36,2 351 | 0,620,2.85,2 352 | 1,780,4,2 353 | 0,620,3.43,3 354 | 1,580,3.12,3 355 | 0,700,3.52,2 356 | 1,540,3.78,2 357 | 1,760,2.81,1 358 | 0,700,3.27,2 359 | 0,720,3.31,1 360 | 1,560,3.69,3 361 | 0,720,3.94,3 362 | 1,520,4,1 363 | 1,540,3.49,1 364 | 0,680,3.14,2 365 | 0,460,3.44,2 366 | 1,560,3.36,1 367 | 0,480,2.78,3 368 | 0,460,2.93,3 369 | 0,620,3.63,3 370 | 0,580,4,1 371 | 0,800,3.89,2 372 | 1,540,3.77,2 373 | 1,680,3.76,3 374 | 1,680,2.42,1 375 | 1,620,3.37,1 376 | 0,560,3.78,2 377 | 0,560,3.49,4 378 | 0,620,3.63,2 379 | 1,800,4,2 380 | 0,640,3.12,3 381 | 0,540,2.7,2 382 | 0,700,3.65,2 383 | 1,540,3.49,2 384 | 0,540,3.51,2 385 | 0,660,4,1 386 | 1,480,2.62,2 387 | 0,420,3.02,1 388 | 1,740,3.86,2 389 | 0,580,3.36,2 390 | 0,640,3.17,2 391 | 0,640,3.51,2 392 | 1,800,3.05,2 393 | 1,660,3.88,2 394 | 1,600,3.38,3 395 | 1,620,3.75,2 396 | 1,460,3.99,3 397 | 0,620,4,2 398 | 0,560,3.04,3 399 | 
0,460,2.63,2 400 | 0,700,3.65,2 401 | 0,600,3.89,3 402 | -------------------------------------------------------------------------------- /spec/data/df.csv: -------------------------------------------------------------------------------- 1 | y,a,b,c,d,e 2 | 0,6,62.1,no,female,A 3 | 1,18,34.7,yes,male,B 4 | 1,6,29.7,no,female,C 5 | 0,4,71,no,male,C 6 | 1,5,36.9,yes,male,B 7 | 0,11,58.7,no,female,B 8 | 0,8,63.3,no,male,B 9 | 1,21,20.4,yes,male,A 10 | 1,2,20.5,yes,male,C 11 | 0,11,59.2,no,male,B 12 | 0,1,76.4,yes,female,A 13 | 0,8,71.7,no,female,B 14 | 1,2,77.5,no,male,C 15 | 1,3,31.1,no,male,B -------------------------------------------------------------------------------- /spec/data/logistic.csv: -------------------------------------------------------------------------------- 1 | x1,x2,y 2 | 0.537322309644812,-0.866655707911859,0 3 | -0.717124209978434,-0.367820249977585,0 4 | -0.519166718891331,0.361486610435,1 5 | 0.434970973986765,0.857332626245179,0 6 | -0.761822002215759,0.133438466268095,1 7 | 1.51170030921189,0.716104533073575,1 8 | 0.883854199811195,1.77206093023382,1 9 | -0.908689798854196,-0.10136697295802,1 10 | 1.70331977539793,-0.777086491435508,0 11 | -0.246971150634099,-0.204573554913706,1 12 | -1.59077593922623,0.963353531412233,1 13 | -0.721548040910253,-1.10103024900542,1 14 | 0.467025703920194,-0.404372761837392,1 15 | -0.510132788447137,-0.230226345183469,0 16 | 0.430106510266798,0.0363730246866971,1 17 | -0.144353683251536,-0.838265540390497,0 18 | -1.54943800728303,1.12543549657924,1 19 | 0.849307651309298,-0.57929175648001,1 20 | -0.640304240933579,-0.747060244805248,0 21 | 1.31462478279425,0.58946979365152,1 22 | -0.399783455165345,-0.531952663697324,0 23 | 0.0453055645017902,1.53338594419818,1 24 | -2.58212161987746,0.521992029051441,1 25 | -1.16484414309359,1.41631763288724,1 26 | -1.08829266466281,0.611402316795129,1 27 | -0.243893919684792,-0.518355638373296,0 28 | -1.96655661929441,-0.515192557101107,0 29 | 
0.301335373291024,-0.672697937866108,1 30 | -0.665832694463588,1.84347042325327,1 31 | -0.0120650855753837,-0.21195540664804,0 32 | 1.5116066367604,-0.269869371631611,0 33 | 0.557300353673344,0.296155694010096,1 34 | 1.12829931872045,-2.18097898069634,0 35 | 0.234443748015922,-1.21314663927206,0 36 | -2.03486690662651,1.49193669881581,1 37 | 0.275544751380246,1.38969280369493,1 38 | -0.231465849558696,-0.400680808117106,0 39 | -0.356880153225012,-1.87282814976479,0 40 | -0.57746647541923,1.82394870451051,1 41 | 1.35758352580655,0.637864732838274,1 42 | 1.23971669378224,-0.141155946382493,0 43 | -0.662466275100489,0.0699950644281617,1 44 | 0.313263561921793,1.32568550595165,1 45 | -1.08783223256362,-0.412599258349398,1 46 | 1.41964722846899,0.14436832227506,1 47 | 1.29325100940785,-1.16507785388489,0 48 | 0.72153880625103,-2.16782049922428,0 49 | 0.440580131022748,0.24318371493798,0 50 | 0.0351917814720056,0.258954871320764,1 51 | -0.142353224879252,-0.151966534521183,1 -------------------------------------------------------------------------------- /spec/data/logistic_mle.csv: -------------------------------------------------------------------------------- 1 | a,b,c,y 2 | 0.751712137198122,-3.26835910896471,1.70092606756102,0.0 3 | 0.554214068711979,-2.95659724374009,2.66368360383625,0.0 4 | -1.85331644460541,-2.8293733798982,3.34679611669776,0.0 5 | -2.88610159022917,-0.738982407138072,4.74970154250585,0.0 6 | -2.60553098049408,0.561020310042032,5.48308397712906,0.0 7 | -4.27353214430849,1.62383436799538,5.35813425193809,0.0 8 | -4.77012590041537,1.22025583429631,6.41070111604149,0.0 9 | -6.92314832885029,2.86547174912175,8.7318591921429,0.0 10 | -7.56419504045061,4.94028695088613,8.94193466569898,0.0 11 | -8.63093667936286,4.27420502079032,9.27002100933965,0.0 12 | -8.99111149517962,5.10389362675139,11.7669513163404,0.0 13 | -9.99057638012989,7.87484596043453,12.4794035020575,0.0 14 | -10.3818782968703,8.84300238382604,13.7498993372395,0.0 15 | 
-11.0476824276485,9.44613324059569,13.502502696779,0.0 16 | -12.4344241411336,9.7051587072806,15.122117356301,0.0 17 | -13.6272942288038,10.419034297173,16.3289942692609,0.0 18 | -15.6202224973741,11.3788332186947,17.7367653402519,0.0 19 | -16.2922390869162,13.1516565107205,18.6939344625842,0.0 20 | -16.7159137753851,14.9076297530125,18.0246863680123,0.0 21 | -17.9501251836104,15.8533651031786,20.6826094110231,0.0 22 | -18.9898843433331,15.4331557188719,20.9101142300728,0.0 23 | -19.9085085832061,16.8542366574895,22.0721145894251,0.0 24 | -21.1466526906367,18.6785324946036,23.4977598550022,0.0 25 | -21.3675743976537,18.3208056358668,23.9121114020162,0.0 26 | -22.131396135351,20.7616214854992,24.1683442928502,0.0 27 | -23.1636319405886,21.1293492055766,25.2695476916398,0.0 28 | -24.1360762823224,21.7035705688561,27.9161820459799,0.0 29 | -25.3860721945057,23.3588003206916,27.8755285811136,0.0 30 | -27.2546274134222,24.9201403553661,28.9810564843169,0.0 31 | -28.84506179655,25.1681854417924,29.6749936376452,0.0 32 | -29.5648703289208,26.3989365511582,30.3290447782847,0.0 33 | -30.9473334996246,26.9014594846656,32.5460486188393,0.0 34 | -30.6333205969831,27.9838524246205,32.63981307435,0.0 35 | -31.1288464401156,28.0475385515748,34.4913247020856,0.0 36 | -32.9192755199859,30.8664195945179,35.0830004548623,0.0 37 | -33.5126863006676,31.7135111883763,35.963262553823,0.0 38 | -35.0626135760564,31.0828516323826,36.6124285861271,0.0 39 | -36.9524020008844,33.9722904006922,37.1501952879747,0.0 40 | -36.8859458347289,34.1203333509726,38.5388292299586,0.0 41 | -38.3029057124175,34.3054755372445,40.9141143587977,0.0 42 | -38.4946443648172,35.0506904211858,40.2609817417396,0.0 43 | -39.3511324722213,36.8560827180564,41.727133714198,0.0 44 | -41.7530250817957,37.750760484498,43.8282650147001,0.0 45 | -42.3509555518331,39.6673203525628,44.870018186039,0.0 46 | -42.244528415626,39.0788755186789,45.7949681462311,0.0 47 | -43.4340872344621,41.1905511155583,45.845002676574,0.0 48 | 
-45.6641407465385,42.0427797845611,46.7401996025804,0.0 49 | -46.3384022695138,42.4548123590202,47.6088880049696,0.0 50 | -47.2110032014801,43.9466417934236,49.2425272385793,0.0 51 | -48.5809364095671,45.2438778157098,49.7183364732398,0.0 52 | -48.8367941243197,45.8093002387181,51.8176699160374,0.0 53 | -49.3632780402121,47.2173660517208,51.910626136845,0.0 54 | -51.3304857529218,48.2292571877905,53.890768221757,0.0 55 | -52.6713212093626,48.0055890605259,54.2008024542977,0.0 56 | -53.0492822384067,50.9227140239038,55.8288448459473,0.0 57 | -53.5475407502719,50.399466736719,56.6653737429534,0.0 58 | -54.9713103390392,52.2174375936081,57.5892087088816,0.0 59 | -55.3824555656592,52.1624080567189,58.9048271122945,0.0 60 | -57.2392061954268,53.0027225338458,59.7821645733191,0.0 61 | -58.9104229152484,54.9892539015416,60.4219388424604,0.0 62 | -58.1824120199409,56.3542717685635,61.2574760633137,0.0 63 | -60.341564759061,56.7219808335487,61.0889655818828,0.0 64 | -61.7134385451754,58.6296683714548,63.962361122351,0.0 65 | -61.4485849309236,59.2722211731019,64.3931122256424,0.0 66 | -63.0320403297596,60.7473007790999,64.4546114078189,0.0 67 | -64.1041544802386,61.8740389113319,65.9649824945026,0.0 68 | -65.4439008347463,62.1787709650214,67.730652575126,0.0 69 | -65.3433374939793,63.9649079050439,67.2579323919331,0.0 70 | -66.1146018531461,64.8632381781096,69.639884237254,0.0 71 | -68.447747797377,65.4474036538936,69.6078430641903,0.0 72 | -69.294418370038,65.4181483428477,70.0720760179831,0.0 73 | -70.2849681665589,66.6959106050402,71.7102828712744,0.0 74 | -70.4280941800939,67.6086197232511,72.3531894807932,1.0 75 | -71.4902330755333,68.6959578032103,73.2224551274211,0.0 76 | -72.0927102365517,69.6523144077888,74.0072309103283,0.0 77 | -73.8239369250543,70.9727748471512,75.0623744862556,0.0 78 | -74.0048051089996,72.4351875425075,76.9734643262443,0.0 79 | -75.3894569624969,73.2005662031078,78.2395878278339,0.0 80 | -77.2087700153069,73.873044622658,79.2174499824634,1.0 
81 | -78.3312007339737,74.3177101419062,80.4500892602852,0.0 82 | -78.0152851918014,76.0879642141581,81.7807815886821,1.0 83 | -79.6348564839424,76.2685721204187,82.516078707618,0.0 84 | -81.9637989522169,77.0024986767795,82.4661448054153,0.0 85 | -81.7384252154184,79.6668712051742,84.6703671298337,0.0 86 | -83.6602393343142,80.4913108235325,85.1069139287146,0.0 87 | -84.5817566138972,80.0267527066757,85.7126502390058,0.0 88 | -84.1500643836875,82.9742935504215,87.7889543910899,0.0 89 | -85.4066570769733,82.3347246270609,87.2294546291805,0.0 90 | -87.7637024688064,83.3797423673815,88.7428383996369,1.0 91 | -87.8532422216585,84.5508112461818,90.4419761299396,1.0 92 | -89.4596482634063,86.7455889667711,90.0173299461518,0.0 93 | -90.0512708441363,87.7184551425266,92.6086387957092,0.0 94 | -91.7774688537336,88.3362578318951,92.5323490472281,0.0 95 | -91.8941758249984,89.1756154247095,93.1997545484098,0.0 96 | -93.6063748804845,89.6466643235472,95.9725910658342,0.0 97 | -94.0398325903283,90.3812804437975,96.7702387180417,0.0 98 | -94.8766180342569,91.4799550730711,97.2352819068861,1.0 99 | -96.6522324427766,93.3217737988283,97.4817099617913,1.0 100 | -96.2874457012073,93.6988909891918,98.8166711820106,0.0 101 | -97.7197245583556,95.1054536364693,100.672788663657,0.0 102 | -98.1481566037241,95.1218250594043,100.732200143903,0.0 103 | -99.2071669041095,96.2150468128895,101.714216717689,0.0 104 | -100.936418858708,98.4045788153991,102.613459593007,1.0 105 | -102.443560747664,99.9196460296025,103.024089431854,1.0 106 | -103.253024963442,100.800612294368,104.827681828628,1.0 107 | -103.749612770423,101.647413918771,105.735965604858,1.0 108 | -104.328852482254,101.52632439617,106.261555245069,1.0 109 | -106.909854948268,103.417639882281,107.655384160196,0.0 110 | -107.146277874964,103.735504363015,109.010106082004,1.0 111 | -108.299680591253,105.221040236174,109.433289426765,1.0 112 | -108.483982549154,106.450953357466,110.204651921964,1.0 113 | 
-109.852652647883,107.528286824546,111.038226237901,1.0 114 | -110.908323481598,107.109110976433,113.280776112873,1.0 115 | -112.467924011291,108.685764539417,113.830693757837,1.0 116 | -113.205608139008,110.215691301935,114.433169300046,0.0 117 | -114.088065910725,111.855183307649,115.490962636794,1.0 118 | -114.932088811468,112.821293025765,116.840371472296,1.0 119 | -116.346270551874,112.158706089612,117.351225094843,0.0 120 | -117.586523469426,113.660859727683,119.786169957249,0.0 121 | -117.399020777901,115.140026386953,119.538050647704,0.0 122 | -119.742251706502,116.297893641214,121.055886377645,1.0 123 | -120.811568959087,116.442419636884,122.285416410297,1.0 124 | -121.87197183902,117.772559468586,123.134175983163,0.0 125 | -121.203532711179,118.217415486896,124.033131679149,1.0 126 | -122.799770233652,119.066944773458,125.957266439052,0.0 127 | -123.25939666012,121.092648631783,126.957392689029,1.0 128 | -124.665594116895,122.943504248815,126.726922729708,1.0 129 | -126.45629792163,123.680027227716,128.578446920242,1.0 130 | -126.478395127934,124.860787507164,129.382354701582,1.0 131 | -128.287481953834,124.904366629726,130.857704333381,1.0 132 | -129.634518688906,126.002157605703,130.664679201595,1.0 133 | -130.832869120258,126.312676682549,131.028091005298,1.0 134 | -131.824704198069,128.491040094009,132.438881937649,1.0 135 | -131.027748413386,128.666245091681,134.386819143277,1.0 136 | -133.251129013031,129.041770903105,135.442572078667,0.0 137 | -134.69199358524,131.797330031224,135.990306488139,1.0 138 | -134.269836322383,131.114035752663,137.455991445092,1.0 139 | -136.322562482609,133.458573512031,137.051219239798,1.0 140 | -137.181797709245,134.690639765394,139.80399233003,1.0 141 | -138.996622047529,135.735702581706,140.012525838337,1.0 142 | -139.215574081352,135.215412440418,141.384571880875,1.0 143 | -140.25643517594,136.693217029473,141.619094359002,1.0 144 | -141.091696435808,138.544186252854,143.447467807219,1.0 145 | 
-141.615279226692,139.247015471284,143.781551942357,1.0 146 | -143.738916186876,140.442376021513,145.690446407267,1.0 147 | -143.406661315583,141.611755086021,146.834954933569,1.0 148 | -145.539994351071,141.864759154179,146.821868669993,1.0 149 | -146.759792237469,143.72389809576,147.648189754648,1.0 150 | -147.742745321685,143.773042536024,148.335412957969,1.0 151 | -147.827110918347,144.891157822512,150.991195838727,1.0 152 | -149.385928352478,146.506476517532,151.284254745032,1.0 153 | -150.977776481931,146.021859562023,152.768419217875,1.0 154 | -150.312752816396,148.730155452661,152.038272157343,1.0 155 | -151.558966920835,148.800062079608,153.169105510546,1.0 156 | -152.600091618579,150.820711027947,154.708689228766,1.0 157 | -153.408136649531,151.844749166301,156.485789011581,1.0 158 | -155.851442988913,152.727616801779,157.051864933563,1.0 159 | -155.804122660613,153.080067542298,157.928538460402,1.0 160 | -157.073427576551,153.343498427848,158.067904946476,1.0 161 | -157.107508909122,154.26235452322,160.34811507316,1.0 162 | -159.484841744779,155.071935569765,161.331016311668,1.0 163 | -160.494393486944,156.832393661499,162.766796789861,1.0 164 | -160.993630097021,157.641448785465,162.16966187911,1.0 165 | -162.552759845799,159.539687734098,163.085996347117,1.0 166 | -163.564091719115,159.348960279159,164.704726412269,1.0 167 | -163.726283737347,161.928264713823,165.777638968284,1.0 168 | -165.875733565084,162.238166456333,167.748146433357,1.0 169 | -166.396130106934,163.168469688741,167.777541018587,1.0 170 | -166.362936492002,164.060990842406,168.793540054343,1.0 171 | -167.190830641697,164.69633527615,169.716276093277,1.0 172 | -169.856337102559,165.948746195011,171.767907484416,1.0 173 | -170.437574285639,166.585793448766,172.423441930238,1.0 174 | -171.364589341033,167.713387609531,172.424258150602,1.0 175 | -171.719858914268,168.568258270241,174.186728478313,1.0 176 | -172.222432925774,169.341495037779,175.014592418557,1.0 177 | 
-174.925024864504,170.009200154178,175.823166747739,1.0 178 | -174.440609108192,172.585679891303,177.497935948773,1.0 179 | -175.685134632762,172.597439708769,178.865842811381,1.0 180 | -177.037735766697,174.337033018417,178.526154196112,1.0 181 | -177.695754636999,175.158079880788,179.090845324338,1.0 182 | -178.530965125958,175.998360662267,180.756244303945,1.0 183 | -179.33859605694,177.055963520977,182.537104759468,1.0 184 | -181.051026538749,178.749994896864,182.511752447157,1.0 185 | -181.910324655635,178.191126997934,183.205490333945,1.0 186 | -182.658814367512,179.821227820741,184.077256331919,1.0 187 | -184.92643232869,181.744602320009,185.895789796461,1.0 188 | -185.336918062898,181.363095955327,186.129937394056,1.0 189 | -185.267288986526,182.500514904395,188.252778726315,1.0 190 | -186.872953708712,184.965877299035,188.828219832583,1.0 191 | -188.918306879815,185.051974487609,190.899813628421,1.0 192 | -189.763521615511,185.68824128269,190.220429056165,1.0 193 | -190.644174012638,187.750343044182,191.493983126918,1.0 194 | -190.321219894606,187.817497818665,193.40572238254,1.0 195 | -192.79354360347,188.86697228881,193.159805548127,1.0 196 | -192.025564246072,189.147494777077,194.034017243282,1.0 197 | -193.32579637306,190.629543307974,195.316361684183,1.0 198 | -194.09803396593,192.184663556046,196.170525930209,1.0 199 | -196.951286736902,192.950569139836,198.865411666979,1.0 200 | -197.273285018295,193.455170048879,199.984457526702,1.0 201 | -197.367054178611,194.765784862729,199.831724804912,1.0 202 | -------------------------------------------------------------------------------- /spec/data/normal.csv: -------------------------------------------------------------------------------- 1 | ROLL,UNEM,HGRAD,INC 2 | 5501,8.1,9552,1923 3 | 5945,7,9680,1961 4 | 6629,7.3,9731,1979 5 | 7556,7.5,11666,2030 6 | 8716,7,14675,2112 7 | 9369,6.4,15265,2192 8 | 9920,6.5,15484,2235 9 | 10167,6.4,15723,2351 10 | 11084,6.3,16501,2411 11 | 12504,7.7,16890,2475 12 | 
13746,8.2,17203,2524 13 | 13656,7.5,17707,2674 14 | 13850,7.4,18108,2833 15 | 14145,8.2,18266,2863 16 | 14888,10.1,19308,2839 17 | 14991,9.2,18224,2898 18 | 14836,7.7,18997,3123 19 | 14478,5.7,19505,3195 20 | 14539,6.5,19800,3239 21 | 14395,7.5,19546,3129 22 | 14599,7.3,19117,3100 23 | 14969,9.2,18774,3008 24 | 15107,10.1,17813,2983 25 | 14831,7.5,17304,3069 26 | 15081,8.8,16756,3151 27 | 15127,9.1,16749,3127 28 | 15856,8.8,16925,3179 29 | 15938,7.8,17231,3207 30 | 16081,7,16816,3345 -------------------------------------------------------------------------------- /spec/formula_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper.rb' 2 | require 'shared_context/parser_checker.rb' 3 | 4 | describe Statsample::GLM::Formula do 5 | context '#parse_formula' do 6 | context 'no interaction' do 7 | include_context 'parser checker', '1+a+b' => 8 | '1+a(-)+b(-)' 9 | end 10 | 11 | context '2-way interaction' do 12 | context 'none reoccur' do 13 | include_context 'parser checker', '1+c+a:b' => 14 | '1+c(-)+b(-)+a(-):b' 15 | end 16 | 17 | context 'first reoccur' do 18 | include_context 'parser checker', '1+a+a:b' => 19 | '1+a(-)+a:b(-)' 20 | end 21 | 22 | context 'second reoccur' do 23 | include_context 'parser checker', '1+b+a:b' => 24 | '1+b(-)+a(-):b' 25 | end 26 | 27 | context 'both reoccur' do 28 | include_context 'parser checker', '1+a+b+a:b' => 29 | '1+a(-)+b(-)+a(-):b(-)' 30 | end 31 | end 32 | 33 | context 'complex cases' do 34 | include_context 'parser checker', '1+a+a:b+b:d' => 35 | '1+a(-)+a:b(-)+b:d(-)' 36 | end 37 | end 38 | end -------------------------------------------------------------------------------- /spec/formula_wrapper_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper.rb' 2 | require 'shared_context/reduce_formula.rb' 3 | 4 | describe Statsample::GLM::FormulaWrapper do 5 | context '#reduce_formula' do 6 | let(:df) { 
Daru::DataFrame.from_csv 'spec/data/df.csv' } 7 | 8 | before do 9 | df.to_category 'c', 'd', 'e' 10 | end 11 | 12 | context 'shortcut symbols' do 13 | context '*' do 14 | context 'two terms' do 15 | include_context 'reduce formula', 'y~a*b' => 'y~1+a+b+a:b' 16 | end 17 | 18 | context 'correct precedance' do 19 | context 'with :' do 20 | include_context 'reduce formula', 'y~a*b:c' => 21 | 'y~1+a+b:c+a:b:c' 22 | end 23 | 24 | context 'with +' do 25 | include_context 'reduce formula', 'y~a+b*c' => 26 | 'y~1+a+b+c+b:c' 27 | end 28 | end 29 | 30 | context 'more than two terms' do 31 | include_context 'reduce formula', 'y~a*b*c' => 32 | 'y~1+a+b+c+a:b+a:c+b:c+a:b:c' 33 | end 34 | end 35 | 36 | context '/' do 37 | context 'two terms' do 38 | include_context 'reduce formula', 'y~a/b' => 'y~1+a+a:b' 39 | end 40 | 41 | # TODO: Mismatch with Patsy 42 | xcontext 'more than two terms' do 43 | include_context 'reduce formula', 'y~a/b/c' => 44 | 'y~1+a+a:b+a:b:c' 45 | end 46 | 47 | context 'correct precedance' do 48 | context 'with :' do 49 | include_context 'reduce formula', 'y~a/b:c' => 50 | 'y~1+a+a:b:c' 51 | end 52 | 53 | context 'with +' do 54 | include_context 'reduce formula', 'y~a/b+c' => 55 | 'y~1+a+c+a:b' 56 | end 57 | end 58 | end 59 | 60 | context 'brackets' do 61 | context 'with + and :' do 62 | include_context 'reduce formula', 'y~(a+b):c' => 63 | 'y~1+a:c+b:c' 64 | end 65 | 66 | context 'with * and :' do 67 | include_context 'reduce formula', 'y~(a*b):c' => 68 | 'y~1+a:c+b:c+a:b:c' 69 | end 70 | 71 | xcontext 'with / and :' do 72 | include_context 'reduce formula', 'y~(a/b):c' => 73 | 'y~1+a:c+a:b:c' 74 | end 75 | 76 | # TODO: Mismatch with Patsy 77 | xcontext 'with * and /' do 78 | include_context 'reduce formula', 'y~(a*b)/c' => 79 | 'y~1+a+b+a:b+a:b:c' 80 | end 81 | end 82 | 83 | context 'corner cases' do 84 | context 'names of more than one character' do 85 | before do 86 | df['ax'] = df['a'] 87 | df['bx'] = df['b'] 88 | end 89 | include_context 'reduce 
formula', 'y~ax*bx:c' => 90 | 'y~1+ax+bx:c+ax:bx:c' 91 | end 92 | end 93 | 94 | context 'complex cases' do 95 | context 'example 1' do 96 | include_context 'reduce formula', 'y~(a+b)*(c+d)' => 97 | 'y~1+a+b+c+d+a:c+a:d+b:c+b:d' 98 | end 99 | end 100 | end 101 | end 102 | end 103 | -------------------------------------------------------------------------------- /spec/logistic_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper.rb' 2 | 3 | describe Statsample::GLM::Logistic do 4 | context "IRLS algorithm" do 5 | before do 6 | @data_set = Daru::DataFrame.from_csv "spec/data/logistic.csv" 7 | @data_set.vectors = Daru::Index.new([:x1,:x2,:y]) 8 | @glm = Statsample::GLM.compute @data_set, :y, :logistic, {constant: 1} 9 | end 10 | 11 | it "reports correct coefficients as an array" do 12 | expect_similar_array(@glm.coefficients(:array), [-0.312493754568903, 13 | 2.28671333346264, 14 | 0.675603176233325]) 15 | end 16 | 17 | it "reports correct coefficients as a hash" do 18 | expect_similar_hash(@glm.coefficients(:hash), {:constant => 0.675603176233325, 19 | :x1 => -0.312493754568903, 20 | :x2 => 2.28671333346264}) 21 | end 22 | 23 | it "reports correct coefficients as a Daru::Vector" do 24 | expect_similar_vector(@glm.coefficients, [-0.312493754568903, 25 | 2.28671333346264, 0.675603176233325]) 26 | end 27 | 28 | it "computes predictions on new data correctly" do 29 | new_data = Daru::DataFrame.new([[0.1, 0.2, 0.3], [-0.1, 0.0, 0.1]], 30 | order: [:x1, :x2]) 31 | #predictions obtained in R with predict.glm with type='response': 32 | predictions = [0.6024496420392773, 0.6486486378079901, 0.6922216620285218] 33 | expect_similar_vector @glm.predict(new_data), predictions 34 | end 35 | end 36 | 37 | context "MLE algorithm" do 38 | before do 39 | @data_set = Daru::DataFrame.from_csv("spec/data/logistic_mle.csv") 40 | @data_set.vectors = Daru::Index.new([:a,:b,:c,:y]) 41 | @glm = Statsample::GLM.compute @data_set,:y, 
:logistic, {constant: 1, algorithm: :mle} 42 | end 43 | 44 | it "reports correct log-likelihood" do 45 | expect(@glm.log_likelihood).to be_within(0.001).of(-38.8669) 46 | end 47 | 48 | it "report the correct number of iterations" do 49 | expect(@glm.iterations).to eq(7) 50 | end 51 | 52 | it "reports correct regression coefficients as a Daru::Vector" do 53 | expect_similar_vector(@glm.coefficients, [0.3270, 0.8147, -0.4031,-5.3658], 0.001) 54 | end 55 | 56 | it "reports correct standard deviations as a Daru::Vector" do 57 | expect_similar_vector(@glm.standard_errors, [0.4390, 0.4270, 0.3819,1.9045], 0.001) 58 | end 59 | 60 | it "reports correct regression coefficients as an array" do 61 | expect_similar_array(@glm.coefficients(:array), [0.3270, 0.8147, -0.4031,-5.3658], 0.001) 62 | end 63 | 64 | it "reports correct standard deviations as an array" do 65 | expect_similar_array(@glm.standard_errors(:array), [0.4390, 0.4270, 0.3819,1.9045], 0.001) 66 | end 67 | 68 | it "reports correct regression coefficients as a hash" do 69 | expect_similar_hash(@glm.coefficients(:hash), {:a => 0.3270, :b => 0.8147, :c => -0.4031, 70 | :constant => -5.3658}, 0.001) 71 | end 72 | 73 | it "reports correct standard deviations as a hash" do 74 | expect_similar_hash(@glm.standard_errors(:hash), {:a => 0.4390, :b => 0.4270, :c => 0.3819, 75 | :constant => 1.9045}, 0.001) 76 | end 77 | 78 | it "computes predictions on new data correctly" do 79 | new_data = Daru::DataFrame.new([[-1.0, -150.0], [0.0, 150.0], [1.0, 150.0]], 80 | order: ['a', 'b', 'c']) 81 | #predictions obtained with in R predict.glm with type='response': 82 | predictions = [0.002247048350428831, 0.999341821607089287] 83 | expect_similar_vector @glm.predict(new_data), predictions 84 | end 85 | end 86 | end 87 | -------------------------------------------------------------------------------- /spec/normal_spec.rb: -------------------------------------------------------------------------------- 1 | describe 
Statsample::GLM::Normal do 2 | context "MLE algorithm" do 3 | before do 4 | # Below data set taken from http://dl.dropbox.com/u/10246536/Web/RTutorialSeries/dataset_multipleRegression.csv 5 | @ds = Daru::DataFrame.from_csv "spec/data/normal.csv", 6 | order: ['ROLL', 'UNEM', 'HGRAD', 'INC'] 7 | end 8 | 9 | it "reports correct values as a Daru::Vector", focus: true do 10 | @glm = Statsample::GLM.compute @ds, 'ROLL', :normal, {algorithm: :mle} 11 | 12 | expect_similar_vector @glm.coefficients, [450.12450365911894, 13 | 0.4064837278023981, 4.27485769721736, -9153.254462671905] 14 | end 15 | 16 | it "reports correct values when constant is different from 1", focus: true do 17 | @glm = Statsample::GLM.compute @ds, 'ROLL', :normal, {constant: 2, algorithm: :mle} 18 | 19 | expect_similar_vector @glm.coefficients, [450.12450365911894, 20 | 0.4064837278023981, 4.27485769721736, -4576.627231335952] 21 | end 22 | 23 | it "computes predictions of new data correctly" do 24 | @glm = Statsample::GLM.compute @ds, 'ROLL', :normal, {algorithm: :mle} 25 | new_data = Daru::DataFrame.new([[7, 8, 9], 26 | [15000, 16000, 17000], 27 | [3000, 4000, 5000]], 28 | order: ['UNEM', 'HGRAD', 'INC']) 29 | # predictions obtained with predict.lm in R: 30 | predictions = [12919.44607162998, 18050.91200030885, 31 | 23182.37792898773] 32 | expect_similar_vector @glm.predict(new_data), predictions 33 | end 34 | end 35 | end 36 | -------------------------------------------------------------------------------- /spec/poisson_spec.rb: -------------------------------------------------------------------------------- 1 | describe Statsample::GLM::Poisson do 2 | context "IRLS algorithm" do 3 | before :each do 4 | x1 = 
Daru::Vector.new([0.537322309644812,-0.717124209978434,-0.519166718891331,0.434970973986765,-0.761822002215759,1.51170030921189,0.883854199811195,-0.908689798854196,1.70331977539793,-0.246971150634099,-1.59077593922623,-0.721548040910253,0.467025703920194,-0.510132788447137,0.430106510266798,-0.144353683251536,-1.54943800728303,0.849307651309298,-0.640304240933579,1.31462478279425,-0.399783455165345,0.0453055645017902,-2.58212161987746,-1.16484414309359,-1.08829266466281,-0.243893919684792,-1.96655661929441,0.301335373291024,-0.665832694463588,-0.0120650855753837,1.5116066367604,0.557300353673344,1.12829931872045,0.234443748015922,-2.03486690662651,0.275544751380246,-0.231465849558696,-0.356880153225012,-0.57746647541923,1.35758352580655,1.23971669378224,-0.662466275100489,0.313263561921793,-1.08783223256362,1.41964722846899,1.29325100940785,0.72153880625103,0.440580131022748,0.0351917814720056, -0.142353224879252]) 5 | x2 = Daru::Vector.new([-0.866655707911859,-0.367820249977585,0.361486610435,0.857332626245179,0.133438466268095,0.716104533073575,1.77206093023382,-0.10136697295802,-0.777086491435508,-0.204573554913706,0.963353531412233,-1.10103024900542,-0.404372761837392,-0.230226345183469,0.0363730246866971,-0.838265540390497,1.12543549657924,-0.57929175648001,-0.747060244805248,0.58946979365152,-0.531952663697324,1.53338594419818,0.521992029051441,1.41631763288724,0.611402316795129,-0.518355638373296,-0.515192557101107,-0.672697937866108,1.84347042325327,-0.21195540664804,-0.269869371631611,0.296155694010096,-2.18097898069634,-1.21314663927206,1.49193669881581,1.38969280369493,-0.400680808117106,-1.87282814976479,1.82394870451051,0.637864732838274,-0.141155946382493,0.0699950644281617,1.32568550595165,-0.412599258349398,0.14436832227506,-1.16507785388489,-2.16782049922428,0.24318371493798,0.258954871320764,-0.151966534521183]) 6 | @y_pois = 
Daru::Vector.new([1,2,1,3,3,1,10,1,1,2,15,0,0,2,1,2,18,2,1,1,1,8,18,13,7,1,1,0,26,0,2,2,0,0,25,7,0,0,21,0,0,1,5,0,3,0,0,1,0,0]) 7 | 8 | @df = Daru::DataFrame.new({:x1 => x1,:x2 => x2, :y => @y_pois}) 9 | 10 | end 11 | 12 | it "reports return values a Daru::Vector for IRLS" do 13 | @glm = Statsample::GLM.compute @df, :y, :poisson, {algorithm: :irls, constant: 1} 14 | 15 | expect_similar_vector(@glm.coefficients, [-0.586359358356708, 16 | 1.28511323439258, 17 | 0.32993246633711]) 18 | end 19 | 20 | it "reports return values as a hash for IRLS" do 21 | @glm = Statsample::GLM.compute @df, :y, :poisson, {algorithm: :irls, constant: 1} 22 | 23 | expect_similar_hash(@glm.coefficients(:hash), {:constant => 0.32993246633711, 24 | :x1 => -0.586359358356708, 25 | :x2 => 1.28511323439258}) 26 | end 27 | 28 | it "reports return values as an array for IRLS" do 29 | @glm = Statsample::GLM.compute @df, :y, :poisson, {algorithm: :irls, constant: 1} 30 | 31 | expect_similar_array(@glm.coefficients(:array), [-0.586359358356708, 32 | 1.28511323439258, 33 | 0.32993246633711]) 34 | end 35 | 36 | it "computes predictions on new data correctly" do 37 | @glm = Statsample::GLM.compute @df,:y,:poisson, {algorithm: :irls, constant: 1} 38 | new_data = Daru::DataFrame.new(x1: [-0.5, 0.5], x2: [-1.0, 1.0]) 39 | #predictions obtained in R with predict.glm with type='response': 40 | predictions = [0.5158181031737609, 3.7504132036471081] 41 | expect_similar_vector @glm.predict(new_data), predictions 42 | end 43 | end 44 | 45 | context "MLE algorithm" do 46 | # TODO: Implement MLE for poisson 47 | end 48 | end 49 | -------------------------------------------------------------------------------- /spec/probit_spec.rb: -------------------------------------------------------------------------------- 1 | describe Statsample::GLM::Probit do 2 | context "IRLS algorithm" do 3 | # TODO : Implement this! 
4 | end 5 | 6 | context "MLE algorithm" do 7 | before do 8 | @data_set = Daru::DataFrame.from_csv 'spec/data/logistic_mle.csv' 9 | @data_set.vectors = Daru::Index.new([:a,:b,:c,:y]) 10 | @glm = Statsample::GLM.compute @data_set, :y, :probit, 11 | {algorithm: :mle, constant: 1} 12 | end 13 | 14 | it "reports correct log-likelihood" do 15 | expect(@glm.log_likelihood).to be_within(0.0001).of(-38.31559) 16 | end 17 | 18 | it "reports correct regression coefficients as a Daru::Vector" do 19 | expect_similar_vector(@glm.coefficients,[0.1763,0.4483,-0.2240,-3.0670],0.001) 20 | end 21 | 22 | it "reports correct regression coefficients as an array" do 23 | expect_similar_array(@glm.coefficients(:array),[0.1763,0.4483,-0.2240,-3.0670],0.001) 24 | end 25 | 26 | it "reports correct regression coefficients as a hash" do 27 | expect_similar_hash(@glm.coefficients(:hash), {:a => 0.1763, :b => 0.4483, :c => -0.2240, 28 | :constant => -3.0670}, 0.001) 29 | end 30 | 31 | it "computes predictions on new data correctly" do 32 | new_data = Daru::DataFrame.new([[-50.0, -100.0], [50.0, 100.0], [50.0, 100.0]], 33 | order: ['a', 'b', 'c']) 34 | #predictions obtained with in R predict.glm with type='response': 35 | predictions = [0.2516918644447207, 0.9580621633922622] 36 | expect_similar_vector @glm.predict(new_data), predictions, delta=1e-4 37 | end 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /spec/regression_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper.rb' 2 | require 'shared_context/formula_checker.rb' 3 | 4 | describe Statsample::GLM::Regression do 5 | let(:df) { Daru::DataFrame.from_csv 'spec/data/df.csv' } 6 | let(:rank_df) { Daru::DataFrame.from_csv 'spec/data/binary.csv' } 7 | before do 8 | df.to_category 'c', 'd', 'e' 9 | rank_df.to_category 'rank' 10 | end 11 | 12 | context '#model' do 13 | context 'numerical' do 14 | let(:model) { described_class.new 
'y ~ a+b+a:b', df, :logistic } 15 | let(:expected_hash) { {:a => 1.14462, :b => -0.04292, :'a:b' => -0.03011, 16 | :constant => 4.73822 } } 17 | subject { model.model } 18 | 19 | it { is_expected.to be_a Statsample::GLM::Logistic } 20 | it 'verifies the coefficients' do 21 | expect_similar_hash(subject.coefficients(:hash), expected_hash, 1e-5) 22 | end 23 | end 24 | 25 | context 'category' do 26 | let(:model) { described_class.new 'y ~ 0+c', df, :logistic } 27 | let(:expected_hash) { {c_no: -0.6931, c_yes: 1.3863 } } 28 | subject { model.model } 29 | 30 | it { is_expected.to be_a Statsample::GLM::Logistic } 31 | it 'verifies the coefficients' do 32 | expect_similar_hash(subject.coefficients(:hash), expected_hash, 1e-4) 33 | end 34 | end 35 | 36 | context 'category and numeric' do 37 | let(:model) { described_class.new 'y ~ a+b:c', df, :logistic } 38 | let(:expected_hash) { {:constant => 16.8145, :a => -0.4315, 39 | :'c_no:b' => -0.2344, :'c_yes:b' => -0.2344} } 40 | subject { model.model } 41 | 42 | it { is_expected.to be_a Statsample::GLM::Logistic } 43 | it 'verifies the coefficients' do 44 | expect_similar_hash(subject.coefficients(:hash), expected_hash, 1e-2) 45 | end 46 | end 47 | 48 | context 'other regression types' do 49 | # TODO: Right now it only verifies appropriate model gets generated 50 | # with appropriate coefficients but it doesn't verify the actual values 51 | # of coefficients. 
52 | context 'normal' do 53 | let(:model) { described_class.new 'a ~ b:c', df, :normal, algorithm: :mle } 54 | subject { model.model } 55 | 56 | it { is_expected.to be_a Statsample::GLM::Normal } 57 | it { expect(subject.coefficients(:hash).keys).to eq( 58 | [:"c_no:b", :"c_yes:b", :constant]) } 59 | end 60 | 61 | context 'poisson' do 62 | let(:model) { described_class.new 'a ~ b:c', df, :poisson} 63 | subject { model.model } 64 | 65 | it { is_expected.to be_a Statsample::GLM::Poisson } 66 | it { expect(subject.coefficients(:hash).keys).to eq( 67 | [:"c_no:b", :"c_yes:b", :constant]) } 68 | end 69 | 70 | context 'probit' do 71 | let(:model) { described_class.new 'a ~ b:c', df, :probit, algorithm: :mle} 72 | subject { model.model } 73 | 74 | it { is_expected.to be_a Statsample::GLM::Probit } 75 | it { expect(subject.coefficients(:hash).keys).to eq( 76 | [:"c_no:b", :"c_yes:b", :constant]) } 77 | end 78 | end 79 | end 80 | 81 | context '#predict' do 82 | context 'numerical' do 83 | let(:model) { described_class.new 'y ~ a+b+a:b', df, :logistic } 84 | let(:new_data) { df.head 3 } 85 | subject { model.predict new_data } 86 | 87 | it 'verifies the prediction' do 88 | expect_similar_vector(subject, [0.0930, 0.9936, 0.9931], 1e-4) 89 | end 90 | end 91 | 92 | context 'category' do 93 | let(:model) { described_class.new 'y ~ 0+c', df, :logistic } 94 | let(:new_data) { df.head 3 } 95 | subject { model.predict new_data } 96 | 97 | it 'verifies the prediction' do 98 | expect_similar_vector(subject, [0.3333, 0.8, 0.3333], 1e-4) 99 | end 100 | end 101 | 102 | context 'category and numeric' do 103 | let(:model) { described_class.new 'y ~ a+b:c', df, :logistic } 104 | let(:new_data) { df.head 3 } 105 | subject { model.predict new_data } 106 | 107 | it 'verifies the prediction' do 108 | expect_similar_vector(subject, [0.4183, 0.6961, 0.9993], 1e-4) 109 | end 110 | end 111 | 112 | context "order doesn't matter" do 113 | let(:model) { described_class.new 'admit ~ gpa + gre + rank', 
114 | rank_df, :logistic } 115 | let(:new_data) do 116 | Daru::DataFrame.new({ 117 | 'gre' => [rank_df['gre'].mean]*4, 118 | 'gpa' => [rank_df['gpa'].mean]*4, 119 | 'rank' => [1,2,3,4] 120 | }, order: ['rank', 'gpa', 'gre']) 121 | end 122 | subject { model.predict new_data } 123 | 124 | it 'verfies the prediction' do 125 | expect_similar_vector(subject, [0.5166, 0.3523, 0.2186, 0.1847], 1e-4) 126 | end 127 | end 128 | end 129 | 130 | context '#df_for_regression' do 131 | context 'with intercept' do 132 | context 'no interaction' do 133 | include_context "formula checker", 'y~a+e' => %w[a e_B e_C y] 134 | end 135 | 136 | context '2-way interaction' do 137 | context 'interaction of numerical with numerical' do 138 | context 'none reoccur' do 139 | include_context 'formula checker', 'y~a:b' => 140 | %w[a:b y] 141 | end 142 | 143 | context 'one reoccur' do 144 | include_context 'formula checker', 'y~a+a:b' => 145 | %w[a a:b y] 146 | end 147 | 148 | context 'both reoccur' do 149 | include_context 'formula checker', 'y~a+b+a:b' => 150 | %w[a a:b b y] 151 | end 152 | end 153 | 154 | context 'interaction of category with numerical' do 155 | context 'none reoccur' do 156 | include_context 'formula checker', 'y~a:e' => 157 | %w[e_A:a e_B:a e_C:a y] 158 | end 159 | 160 | context 'one reoccur' do 161 | context 'numeric occur' do 162 | include_context 'formula checker', 'y~a+a:e' => 163 | %w[a e_B:a e_C:a y] 164 | end 165 | 166 | context 'category occur' do 167 | include_context 'formula checker', 'y~e+a:e' => 168 | %w[e_B e_C e_A:a e_B:a e_C:a y] 169 | end 170 | end 171 | 172 | context 'both reoccur' do 173 | include_context 'formula checker', 'y~a+e+a:e' => 174 | %w[a e_B e_C e_B:a e_C:a y] 175 | end 176 | end 177 | 178 | context 'interaction of category with category' do 179 | context 'none reoccur' do 180 | include_context 'formula checker', 'y~c:e' => 181 | %w[e_B e_C c_yes:e_A c_yes:e_B c_yes:e_C y] 182 | end 183 | 184 | context 'one reoccur' do 185 | include_context 
'formula checker', 'y~e+c:e' => 186 | %w[e_B e_C c_yes:e_A c_yes:e_B c_yes:e_C y] 187 | end 188 | 189 | context 'both reoccur' do 190 | include_context 'formula checker', 'y~c+e+c:e' => 191 | %w[c_yes e_B e_C c_yes:e_B c_yes:e_C y] 192 | end 193 | end 194 | end 195 | end 196 | 197 | context 'without intercept' do 198 | context 'no interaction' do 199 | include_context "formula checker", 'y~0+a+e' => %w[a e_A e_B e_C y] 200 | end 201 | 202 | context '2-way interaction' do 203 | context 'interaction of numerical with numerical' do 204 | context 'none reoccur' do 205 | include_context 'formula checker', 'y~0+a:b' => 206 | %w[a:b y] 207 | end 208 | 209 | context 'one reoccur' do 210 | include_context 'formula checker', 'y~0+a+a:b' => 211 | %w[a a:b y] 212 | end 213 | 214 | context 'both reoccur' do 215 | include_context 'formula checker', 'y~0+a+b+a:b' => 216 | %w[a a:b b y] 217 | end 218 | end 219 | 220 | context 'interaction of category with numerical' do 221 | context 'none reoccur' do 222 | include_context 'formula checker', 'y~0+a:e' => 223 | %w[e_A:a e_B:a e_C:a y] 224 | end 225 | 226 | context 'one reoccur' do 227 | context 'numeric occur' do 228 | include_context 'formula checker', 'y~0+a+a:e' => 229 | %w[a e_B:a e_C:a y] 230 | end 231 | 232 | context 'category occur' do 233 | include_context 'formula checker', 'y~0+e+a:e' => 234 | %w[e_A e_B e_C e_A:a e_B:a e_C:a y] 235 | end 236 | end 237 | 238 | context 'both reoccur' do 239 | include_context 'formula checker', 'y~0+a+e+a:e' => 240 | %w[a e_A e_B e_C e_B:a e_C:a y] 241 | end 242 | end 243 | 244 | context 'interaction of category with category' do 245 | context 'none reoccur' do 246 | include_context 'formula checker', 'y~0+c:e' => 247 | %w[c_no:e_A c_no:e_B c_no:e_C c_yes:e_A c_yes:e_B c_yes:e_C y] 248 | end 249 | 250 | context 'one reoccur' do 251 | include_context 'formula checker', 'y~0+e+c:e' => 252 | %w[e_A e_B e_C c_yes:e_A c_yes:e_B c_yes:e_C y] 253 | end 254 | 255 | context 'both reoccur' do 256 | 
include_context 'formula checker', 'y~0+c+e+c:e' => 257 | %w[c_yes c_no e_B e_C c_yes:e_B c_yes:e_C y] 258 | end 259 | end 260 | end 261 | end 262 | 263 | context 'shortcut symbols' do 264 | context 'symbol *' do 265 | include_context 'formula checker', 'y~0+a*c' => 266 | %w[a c_yes c_no c_yes:a y] 267 | end 268 | 269 | context 'symbol /' do 270 | include_context 'formula checker', 'y~a/c' => 271 | %w[a c_yes:a y] 272 | end 273 | end 274 | 275 | context 'corner case' do 276 | context 'example 1' do 277 | include_context 'formula checker', 'y~d:a+d:e' => 278 | %w[e_B e_C d_male:e_A d_male:e_B d_male:e_C d_female:a d_male:a y] 279 | end 280 | 281 | context 'example 2' do 282 | include_context 'formula checker', 'y~0+d:a+d:c' => 283 | %w[d_female:c_no d_male:c_no d_female:c_yes d_male:c_yes d_female:a d_male:a y] 284 | end 285 | end 286 | 287 | context 'complex examples' do 288 | context 'random example 1' do 289 | include_context 'formula checker', 'y~a+e+c:d+e:d' => 290 | %w[e_B e_C d_male c_yes:d_female c_yes:d_male e_B:d_male e_C:d_male a y] 291 | end 292 | 293 | context 'random example 2' do 294 | include_context 'formula checker', 'y~e+b+c+d:e+b:e+a:e+0' => 295 | %w[e_A e_B e_C c_yes d_male:e_A d_male:e_B d_male:e_C b e_B:b e_C:b e_A:a e_B:a e_C:a y] 296 | end 297 | end 298 | # TODO: Three way interaction 299 | end 300 | end -------------------------------------------------------------------------------- /spec/shared_context/formula_checker.rb: -------------------------------------------------------------------------------- 1 | RSpec.shared_context 'formula checker' do |params| 2 | let(:formula) { params.keys.first } 3 | let(:vectors) { params.values.first } 4 | 5 | let(:model) { described_class.new formula, df, :logistic } 6 | subject { model.df_for_regression } 7 | 8 | it { is_expected.to be_a Daru::DataFrame } 9 | its(:'vectors.to_a.sort') { is_expected.to eq vectors.sort } 10 | end 11 | 
-------------------------------------------------------------------------------- /spec/shared_context/parser_checker.rb: -------------------------------------------------------------------------------- 1 | RSpec.shared_context 'parser checker' do |params| 2 | let(:input) { params.keys.first } 3 | let(:parse_result) { params.values.first } 4 | 5 | let(:formula) do 6 | described_class.new( 7 | input.split('+').map { |i| Statsample::GLM::Token.new i } 8 | ) 9 | end 10 | subject { formula.canonical_to_s } 11 | 12 | it { is_expected.to eq parse_result } 13 | end 14 | -------------------------------------------------------------------------------- /spec/shared_context/reduce_formula.rb: -------------------------------------------------------------------------------- 1 | RSpec.shared_context 'reduce formula' do |params| 2 | let(:input) { params.keys.first } 3 | let(:result) { params.values.first } 4 | 5 | let(:formula) { described_class.new input, df } 6 | subject { formula.to_s } 7 | 8 | it { is_expected.to be_a String } 9 | it { is_expected.to eq result } 10 | end 11 | -------------------------------------------------------------------------------- /spec/spec_helper.rb: -------------------------------------------------------------------------------- 1 | require 'rspec/its' 2 | require 'rubygems' 3 | require 'bundler' 4 | begin 5 | Bundler.setup(:default, :development) 6 | rescue Bundler::BundlerError => e 7 | $stderr.puts e.message 8 | $stderr.puts "Run `bundle install` to install missing gems" 9 | exit e.status_code 10 | end 11 | 12 | require 'rspec' 13 | 14 | $LOAD_PATH.unshift(File.join(File.dirname(__FILE__), '..', 'lib')) 15 | $LOAD_PATH.unshift(File.dirname(__FILE__)) 16 | require 'statsample-glm' 17 | 18 | def expect_similar_vector(exp, obs, delta=1e-10,msg=nil) 19 | expect(exp.is_a? 
Daru::Vector).to be true 20 | expect(exp.size).to eq(obs.size) 21 | 22 | exp.to_a.each_with_index do |v,i| 23 | expect(v).to be_within(delta).of(obs[i]) 24 | end 25 | end 26 | 27 | def expect_similar_hash(exp, obs, delta=1e-10,msg=nil) 28 | expect(exp.is_a? Hash).to be true 29 | expect(exp.size).to eq(obs.size) 30 | 31 | exp.each_key do |k| 32 | expect(exp[k]).to be_within(delta).of(obs[k]) 33 | end 34 | end 35 | 36 | def expect_similar_array(exp, obs, delta=1e-10,msg=nil) 37 | expect(exp.is_a? Array).to be true 38 | expect(exp.size).to eq(obs.size) 39 | 40 | exp.each_with_index do |v,i| 41 | expect(v).to be_within(delta).of(obs[i]) 42 | end 43 | end 44 | 45 | def expect_equal_vector(exp,obs,delta=1e-10,msg=nil) 46 | expect(exp.is_a? Daru::Vector).to be true 47 | expect(exp.size).to eq(obs.size) 48 | 49 | exp.size.times do |i| 50 | expect(exp[i]).to be_within(delta).of(obs[i]) 51 | end 52 | end 53 | 54 | def expect_equal_matrix(exp,obs,delta=1e-10,msg=nil) 55 | expect(exp.row_size).to eq(obs.row_size) 56 | expect(exp.column_size).to eq(obs.column_size) 57 | 58 | exp.row_size.times do |i| 59 | exp.column_size.times do |j| 60 | expect(exp[i,j]).to be_within(delta).of(obs[i,j]) 61 | end 62 | end 63 | end 64 | -------------------------------------------------------------------------------- /spec/token_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper.rb' 2 | 3 | describe Statsample::GLM::Token do 4 | context '#initialize' do 5 | context 'no interaction' do 6 | context 'full' do 7 | subject(:token) { described_class.new 'a' } 8 | 9 | it { is_expected.to be_a described_class } 10 | its(:to_s) { is_expected.to eq 'a' } 11 | its(:full) { is_expected.to eq [true] } 12 | end 13 | 14 | context 'not-full' do 15 | subject(:token) { described_class.new 'a', false } 16 | 17 | it { is_expected.to be_a described_class } 18 | its(:to_s) { is_expected.to eq 'a(-)' } 19 | its(:full) { is_expected.to eq [false] } 20 | end 
21 | end 22 | 23 | context '2-way interaction' do 24 | subject(:token) { described_class.new 'a:b', [true, false] } 25 | 26 | it { is_expected.to be_a described_class } 27 | its(:to_s) { is_expected.to eq 'a:b(-)' } 28 | its(:full) { is_expected.to eq [true, false] } 29 | end 30 | end 31 | 32 | context '#to_df' do 33 | let(:df) { Daru::DataFrame.from_csv 'spec/data/df.csv' } 34 | before do 35 | df.to_category 'c', 'd', 'e' 36 | df['c'].categories = ['no', 'yes'] 37 | df['d'].categories = ['female', 'male'] 38 | df['e'].categories = ['A', 'B', 'C'] 39 | df['d'].base_category = 'female' 40 | end 41 | 42 | context 'no interaction' do 43 | context 'numerical' do 44 | context 'full rank' do 45 | let(:token) { Statsample::GLM::Token.new 'a', [true] } 46 | subject { token.to_df df } 47 | 48 | it { is_expected.to be_a Daru::DataFrame } 49 | it { expect(subject['a']).to eq df['a'] } 50 | end 51 | 52 | context 'reduced rank' do 53 | let(:token) { Statsample::GLM::Token.new 'a', [false] } 54 | subject { token.to_df df } 55 | 56 | it { is_expected.to be_a Daru::DataFrame } 57 | it { expect(subject['a']).to eq df['a'] } 58 | end 59 | end 60 | 61 | context 'category' do 62 | context 'full rank' do 63 | let(:token) { Statsample::GLM::Token.new 'e', [true] } 64 | subject { token.to_df df } 65 | it { is_expected.to be_a Daru:: 66 | DataFrame } 67 | its(:shape) { is_expected.to eq [14, 3] } 68 | its(:'vectors.to_a') { is_expected.to eq %w(e_A e_B e_C) } 69 | end 70 | 71 | context 'reduced rank' do 72 | let(:token) { Statsample::GLM::Token.new 'e', [false] } 73 | subject { token.to_df df } 74 | 75 | it { is_expected.to be_a Daru::DataFrame } 76 | its(:shape) { is_expected.to eq [14, 2] } 77 | its(:'vectors.to_a') { is_expected.to eq %w(e_B e_C) } 78 | end 79 | end 80 | end 81 | 82 | context '2-way interaction' do 83 | context 'numerical-numerical' do 84 | let(:token) { Statsample::GLM::Token.new 'a:b', [true, false] } 85 | subject { token.to_df df } 86 | 87 | it { is_expected.to be_a 
Daru::DataFrame } 88 | its(:shape) { is_expected.to eq [14, 1] } 89 | its(:'vectors.to_a') { is_expected.to eq ['a:b'] } 90 | it { expect(subject['a:b'].to_a).to eq (df['a']*df['b']).to_a } 91 | end 92 | 93 | context 'category-category' do 94 | context 'full-full' do 95 | let(:token) { Statsample::GLM::Token.new 'c:d', [true, true] } 96 | subject { token.to_df df } 97 | it { is_expected.to be_a Daru::DataFrame } 98 | its(:shape) { is_expected.to eq [14, 4] } 99 | its(:'vectors.to_a') { is_expected.to eq( 100 | ["c_no:d_female", "c_no:d_male", "c_yes:d_female", "c_yes:d_male"] 101 | ) } 102 | end 103 | 104 | context 'full-reduced' do 105 | let(:token) { Statsample::GLM::Token.new 'c:d', [true, false] } 106 | subject { token.to_df df } 107 | it { is_expected.to be_a Daru::DataFrame } 108 | its(:shape) { is_expected.to eq [14, 2] } 109 | its(:'vectors.to_a') { is_expected.to eq ['c_no:d_male', 'c_yes:d_male'] } 110 | end 111 | 112 | context 'reduced-full' do 113 | let(:token) { Statsample::GLM::Token.new 'c:d', [false, true] } 114 | subject { token.to_df df } 115 | it { is_expected.to be_a Daru::DataFrame } 116 | its(:shape) { is_expected.to eq [14, 2] } 117 | its(:'vectors.to_a') { is_expected.to eq ['c_yes:d_female', 'c_yes:d_male'] } 118 | end 119 | 120 | context 'reduced-reduced' do 121 | let(:token) { Statsample::GLM::Token.new 'c:d', [false, false] } 122 | subject { token.to_df df } 123 | it { is_expected.to be_a Daru::DataFrame } 124 | its(:shape) { is_expected.to eq [14, 1] } 125 | its(:'vectors.to_a') { is_expected.to eq ['c_yes:d_male'] } 126 | end 127 | end 128 | 129 | context 'numerical-category' do 130 | context 'full-full' do 131 | let(:token) { Statsample::GLM::Token.new 'a:c', [true, true] } 132 | subject { token.to_df df } 133 | it { is_expected.to be_a Daru::DataFrame } 134 | its(:shape) { is_expected.to eq [14, 2] } 135 | its(:'vectors.to_a') { is_expected.to eq ['a:c_no', 'a:c_yes'] } 136 | it { expect(subject['a:c_no'].to_a).to eq( 137 | [6, 0, 6, 
4, 0, 11, 8, 0, 0, 11, 0, 8, 2, 3]) } 138 | it { expect(subject['a:c_yes'].to_a).to eq( 139 | [0, 18, 0, 0, 5, 0, 0, 21, 2, 0, 1, 0, 0, 0]) } 140 | end 141 | 142 | context 'reduced-reduced' do 143 | let(:token) { Statsample::GLM::Token.new 'a:c', [false, false] } 144 | subject { token.to_df df } 145 | it { is_expected.to be_a Daru::DataFrame } 146 | its(:shape) { is_expected.to eq [14, 1] } 147 | its(:'vectors.to_a') { is_expected.to eq ['a:c_yes'] } 148 | it { expect(subject['a:c_yes'].to_a).to eq( 149 | [0, 18, 0, 0, 5, 0, 0, 21, 2, 0, 1, 0, 0, 0]) } 150 | end 151 | end 152 | 153 | context 'category-numerical' do 154 | context 'full-full' do 155 | let(:token) { Statsample::GLM::Token.new 'c:a', [true, true] } 156 | subject { token.to_df df } 157 | it { is_expected.to be_a Daru::DataFrame } 158 | its(:shape) { is_expected.to eq [14, 2] } 159 | its(:'vectors.to_a') { is_expected.to eq ['c_no:a', 'c_yes:a'] } 160 | it { expect(subject['c_no:a'].to_a).to eq( 161 | [6, 0, 6, 4, 0, 11, 8, 0, 0, 11, 0, 8, 2, 3]) } 162 | it { expect(subject['c_yes:a'].to_a).to eq( 163 | [0, 18, 0, 0, 5, 0, 0, 21, 2, 0, 1, 0, 0, 0]) } 164 | end 165 | 166 | context 'reduced-reduced' do 167 | let(:token) { Statsample::GLM::Token.new 'c:a', [false, false] } 168 | subject { token.to_df df } 169 | it { is_expected.to be_a Daru::DataFrame } 170 | its(:shape) { is_expected.to eq [14, 1] } 171 | its(:'vectors.to_a') { is_expected.to eq ['c_yes:a'] } 172 | it { expect(subject['c_yes:a'].to_a).to eq( 173 | [0, 18, 0, 0, 5, 0, 0, 21, 2, 0, 1, 0, 0, 0]) } 174 | end 175 | end 176 | end 177 | end 178 | end -------------------------------------------------------------------------------- /statsample-glm.gemspec: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | $:.unshift File.expand_path("../lib", __FILE__) 3 | 4 | require 'statsample-glm/version' 5 | 6 | Statsample::GLM::DESCRIPTION = < 0.1' 31 | spec.add_runtime_dependency 'statsample', '~> 2.0' 
32 | 33 | spec.add_development_dependency 'bundler', '~> 1.10' 34 | spec.add_development_dependency 'rake' 35 | spec.add_development_dependency 'rspec' 36 | spec.add_development_dependency 'awesome_print' 37 | spec.add_development_dependency 'rdoc', '~> 3.12' 38 | spec.add_development_dependency 'pry', '~> 0.10' 39 | spec.add_development_dependency 'pry-byebug' 40 | spec.add_development_dependency 'rspec-its' 41 | spec.add_development_dependency 'rubocop' 42 | end 43 | --------------------------------------------------------------------------------