├── .document ├── .gitignore ├── .rspec ├── Gemfile ├── LICENSE.txt ├── README.md ├── Rakefile ├── VERSION ├── lib └── nbayes.rb ├── nbayes.gemspec └── spec ├── nbayes_spec.rb └── spec_helper.rb /.document: -------------------------------------------------------------------------------- 1 | lib/**/*.rb 2 | bin/* 3 | - 4 | features/**/*.feature 5 | LICENSE.txt 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | Gemfile.lock 3 | 4 | # rcov generated 5 | coverage 6 | coverage.data 7 | # rdoc generated 8 | rdoc 9 | 10 | # yard generated 11 | doc 12 | .yardoc 13 | 14 | # bundler 15 | .bundle 16 | 17 | # jeweler generated 18 | pkg 19 | 20 | # Have editor/IDE/OS specific files you need to ignore? Consider using a global gitignore: 21 | # 22 | # * Create a file at ~/.gitignore 23 | # * Include files you want ignored 24 | # * Run: git config --global core.excludesfile ~/.gitignore 25 | # 26 | # After doing this, these files will be ignored in all your git projects, 27 | # saving you from having to 'pollute' every project you touch with them 28 | # 29 | # Not sure what to needs to be ignored for particular editors/OSes? Here's some ideas to get you started. 
(Remember, remove the leading # of the line) 30 | # 31 | # For MacOS: 32 | # 33 | #.DS_Store 34 | 35 | # For TextMate 36 | #*.tmproj 37 | #tmtags 38 | 39 | # For emacs: 40 | #*~ 41 | #\#* 42 | #.\#* 43 | 44 | # For vim: 45 | *.swp 46 | 47 | # For redcar: 48 | #.redcar 49 | 50 | # For rubinius: 51 | #*.rbc 52 | 53 | # other 54 | spec/tmp 55 | tmp 56 | -------------------------------------------------------------------------------- /.rspec: -------------------------------------------------------------------------------- 1 | --color 2 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "http://rubygems.org" 2 | # Add dependencies required to use your gem here. 3 | # Example: 4 | # gem "activesupport", ">= 2.3.5" 5 | 6 | # Add dependencies to develop your gem here. 7 | # Include everything needed to run rake, tests, features, etc. 8 | group :development do 9 | gem "rspec", ">= 3.9.0" 10 | gem "rdoc", ">= 3.0.0" 11 | gem "bundler", ">= 2.0.0" 12 | gem "jeweler", ">= 2.3.0" 13 | end 14 | gem 'simplecov', :require => false, :group => :test 15 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2016 Oasic Technologies LLC 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial 
portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nbayes 2 | 3 | ``` 4 | gem install nbayes 5 | ``` 6 | 7 | NBayes is a full-featured, Ruby implementation of ``Naive Bayes``. Some of the features include: 8 | 9 | * allows prior distribution on classes to be assumed uniform (optional) 10 | * generic to work with all types of tokens, not just text 11 | * outputs probabilities, instead of just class w/max probability 12 | * customizable constant value for Laplacian smoothing 13 | * optional and customizable purging of low-frequency tokens (for performance) 14 | * optional binarized mode 15 | * uses log probabilities to avoid underflow 16 | 17 | For more information, view this blog post: http://blog.oasic.net/2012/06/naive-bayes-for-ruby.html 18 | 19 | ## Contributing to nbayes 20 | 21 | * Check out the latest master to make sure the feature hasn't been implemented or the bug hasn't been fixed yet. 22 | * Check out the issue tracker to make sure someone already hasn't requested it and/or contributed it. 23 | * Fork the project. 24 | * Start a feature/bugfix branch. 25 | * Commit and push until you are happy with your contribution. 26 | * Make sure to add tests for it. This is important so I don't break it in a future version unintentionally. 27 | * Please try not to mess with the Rakefile, version, or history. 
If you want to have your own version, or if it is otherwise necessary, that is fine, but please isolate to its own commit so I can cherry-pick around it. 28 | 29 | ## Acknowledgements 30 | 31 | This project is supported by the GrammarBot [grammar checker](http://www.GrammarBot.io/) 32 | 33 | 34 | ## Copyright 35 | 36 | Copyright (c) 2012-2021 Oasic Technologies LLC. See LICENSE.txt for further details. 37 | 38 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | require 'rubygems' 4 | require 'bundler' 5 | begin 6 | Bundler.setup(:default, :development) 7 | rescue Bundler::BundlerError => e 8 | $stderr.puts e.message 9 | $stderr.puts "Run `bundle install` to install missing gems" 10 | exit e.status_code 11 | end 12 | require 'rake' 13 | 14 | require 'jeweler' 15 | Jeweler::Tasks.new do |gem| 16 | # gem is a Gem::Specification... see http://docs.rubygems.org/read/chapter/20 for more options 17 | gem.name = "nbayes" 18 | gem.homepage = "http://github.com/oasic/nbayes" 19 | gem.license = "MIT" 20 | gem.summary = %Q{Full-featured Ruby implementation of Naive Bayes classifier} 21 | gem.description = %Q{Ruby implementation of Naive Bayes that generates true probabilities per class, works with many token types, and provides lots of bells and whistles while being optimized for performance.} 22 | gem.email = "j@oasic.net" 23 | gem.authors = ["oasic"] 24 | # dependencies defined in Gemfile 25 | end 26 | Jeweler::RubygemsDotOrgTasks.new 27 | 28 | require 'rspec/core' 29 | require 'rspec/core/rake_task' 30 | RSpec::Core::RakeTask.new(:spec) do |spec| 31 | spec.pattern = FileList['spec/**/*_spec.rb'] 32 | end 33 | 34 | RSpec::Core::RakeTask.new(:rcov) do |spec| 35 | spec.pattern = 'spec/**/*_spec.rb' 36 | spec.rcov = true 37 | end 38 | 39 | task :default => :spec 40 | 41 | require 'rdoc/task' 42 | Rake::RDocTask.new do |rdoc| 43 | 
version = File.exist?('VERSION') ? File.read('VERSION') : "" 44 | 45 | rdoc.rdoc_dir = 'rdoc' 46 | rdoc.title = "nbayes #{version}" 47 | rdoc.rdoc_files.include('README*') 48 | rdoc.rdoc_files.include('lib/**/*.rb') 49 | end 50 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.3 2 | -------------------------------------------------------------------------------- /lib/nbayes.rb: -------------------------------------------------------------------------------- 1 | require 'yaml' 2 | 3 | # == NBayes::Base 4 | # 5 | # Robust implementation of NaiveBayes: 6 | # - using log probabilities to avoid floating point issues 7 | # - Laplacian smoothing for unseen tokens 8 | # - allows binarized or standard NB 9 | # - allows Prior distribution on category to be assumed uniform (optional) 10 | # - generic to work with all types of tokens, not just text 11 | 12 | 13 | module NBayes 14 | 15 | class Vocab 16 | attr_accessor :log_size, :tokens 17 | 18 | def initialize(options = {}) 19 | @tokens = Hash.new 20 | # for smoothing, use log of vocab size, rather than vocab size 21 | @log_size = options[:log_size] 22 | end 23 | 24 | def delete(token) 25 | tokens.delete(token) 26 | end 27 | 28 | def each(&block) 29 | tokens.keys.each(&block) 30 | end 31 | 32 | def size 33 | if log_size 34 | Math.log(tokens.count) 35 | else 36 | tokens.count 37 | end 38 | end 39 | 40 | def seen_token(token) 41 | tokens[token] = 1 42 | end 43 | end 44 | 45 | class Data 46 | attr_accessor :data 47 | def initialize(options = {}) 48 | @data = Hash.new 49 | #@data = { 50 | # "category1": { 51 | # "tokens": Hash.new(0), 52 | # "total_tokens": 0, 53 | # "examples": 0 54 | # }, 55 | # ... 56 | #} 57 | end 58 | 59 | def categories 60 | data.keys 61 | end 62 | 63 | def token_trained?(token, category) 64 | data[category] ? 
data[category][:tokens].has_key?(token) : false 65 | end 66 | 67 | def cat_data(category) 68 | unless data[category].is_a? Hash 69 | data[category] = new_category 70 | end 71 | data[category] 72 | end 73 | 74 | def category_stats 75 | tmp = [] 76 | total_example_count = total_examples 77 | self.each do |category| 78 | e = example_count(category) 79 | t = token_count(category) 80 | tmp << "For category #{category}, %d examples (%.02f%% of the total) and %d total_tokens" % [e, 100.0 * e / total_example_count, t] 81 | end 82 | tmp.join("\n") 83 | end 84 | 85 | def each(&block) 86 | data.keys.each(&block) 87 | end 88 | 89 | # Increment the number of training examples for this category 90 | def increment_examples(category) 91 | cat_data(category)[:examples] += 1 92 | end 93 | 94 | # Decrement the number of training examples for this category. 95 | # Delete the category if the examples counter is 0. 96 | def decrement_examples(category) 97 | cat_data(category)[:examples] -= 1 98 | delete_category(category) if cat_data(category)[:examples] < 1 99 | end 100 | 101 | def example_count(category) 102 | cat_data(category)[:examples] 103 | end 104 | 105 | def token_count(category) 106 | cat_data(category)[:total_tokens] 107 | end 108 | 109 | # XXX - Add Enumerable and see if I get inject? 110 | # Total number of training instances 111 | def total_examples 112 | sum = 0 113 | self.each {|category| sum += example_count(category) } 114 | sum 115 | end 116 | 117 | # Add this token to this category 118 | def add_token_to_category(category, token) 119 | cat_data(category)[:tokens][token] += 1 120 | cat_data(category)[:total_tokens] += 1 121 | end 122 | 123 | # Decrement the token counter in a category 124 | # If the counter is 0, delete the token. 125 | # If the total number of tokens is 0, delete the category. 
126 | def remove_token_from_category(category, token) 127 | cat_data(category)[:tokens][token] -= 1 128 | delete_token_from_category(category, token) if cat_data(category)[:tokens][token] < 1 129 | cat_data(category)[:total_tokens] -= 1 130 | delete_category(category) if cat_data(category)[:total_tokens] < 1 131 | end 132 | 133 | # How many times does this token appear in this category? 134 | def count_of_token_in_category(category, token) 135 | cat_data(category)[:tokens][token] 136 | end 137 | 138 | def delete_token_from_category(category, token) 139 | count = count_of_token_in_category(category, token) 140 | cat_data(category)[:tokens].delete(token) 141 | # Update this category's total token count 142 | cat_data(category)[:total_tokens] -= count 143 | end 144 | 145 | def purge_less_than(token, x) 146 | return if token_count_across_categories(token) >= x 147 | self.each do |category| 148 | delete_token_from_category(category, token) 149 | end 150 | true # Let caller know we removed this token 151 | end 152 | 153 | # XXX - TODO - use count_of_token_in_category 154 | # Return the total number of tokens we've seen across all categories 155 | def token_count_across_categories(token) 156 | data.keys.inject(0){|sum, cat| sum + @data[cat][:tokens][token] } 157 | end 158 | 159 | def reset_after_import 160 | self.each {|category| cat_data(category)[:tokens].default = 0 } 161 | end 162 | 163 | def new_category 164 | { 165 | :tokens => Hash.new(0), # holds freq counts 166 | :total_tokens => 0, 167 | :examples => 0 168 | } 169 | end 170 | 171 | def delete_category(category) 172 | data.delete(category) if data.has_key?(category) 173 | categories 174 | end 175 | 176 | end 177 | 178 | class Base 179 | 180 | attr_accessor :assume_uniform, :debug, :k, :vocab, :data 181 | attr_reader :binarized 182 | 183 | def initialize(options={}) 184 | @debug = false 185 | @k = 1 186 | @binarized = options[:binarized] || false 187 | @assume_uniform = false 188 | @vocab = Vocab.new(:log_size => 
options[:log_vocab]) 189 | @data = Data.new 190 | end 191 | 192 | # Allows removal of low frequency words that increase processing time and may overfit 193 | # - tokens with a count less than x (measured by summing across all classes) are removed 194 | # Ex: nb.purge_less_than(2) 195 | # 196 | # NOTE: this does not decrement the "examples" count, so purging is not *always* the same 197 | # as if the item was never added in the first place, but usually so 198 | def purge_less_than(x) 199 | remove_list = {} 200 | @vocab.each do |token| 201 | if data.purge_less_than(token, x) 202 | # print "removing #{token}\n" 203 | remove_list[token] = 1 204 | end 205 | end # each vocab word 206 | remove_list.keys.each {|token| @vocab.delete(token) } 207 | # print "total vocab size is now #{vocab.size}\n" 208 | end 209 | 210 | # Delete an entire category from the classification data 211 | def delete_category(category) 212 | data.delete_category(category) 213 | end 214 | 215 | def train(tokens, category) 216 | tokens = tokens.uniq if binarized 217 | data.increment_examples(category) 218 | tokens.each do |token| 219 | vocab.seen_token(token) 220 | data.add_token_to_category(category, token) 221 | end 222 | end 223 | 224 | # Be careful with this function: 225 | # * It decrements the number of examples for the category. 226 | # If the being-untrained category has no more examples, it is removed from the category list. 227 | # * It untrains already trained tokens; non-existing tokens are not considered. 
228 | def untrain(tokens, category) 229 | tokens = tokens.uniq if binarized 230 | data.decrement_examples(category) 231 | 232 | tokens.each do |token| 233 | if data.token_trained?(token, category) 234 | vocab.delete(token) 235 | data.remove_token_from_category(category, token) 236 | end 237 | end 238 | end 239 | 240 | def classify(tokens) 241 | print "classify: #{tokens.join(', ')}\n" if @debug 242 | probs = {} 243 | tokens = tokens.uniq if binarized 244 | probs = calculate_probabilities(tokens) 245 | print "results: #{probs.to_yaml}\n" if @debug 246 | probs.extend(NBayes::Result) 247 | probs 248 | end 249 | 250 | def category_stats 251 | data.category_stats 252 | end 253 | 254 | # Calculates the actual probability of a class given the tokens 255 | # (this is the work horse of the code) 256 | def calculate_probabilities(tokens) 257 | # P(class|words) = P(w1,...,wn|class) * P(class) / P(w1,...,wn) 258 | # = argmax P(w1,...,wn|class) * P(class) 259 | # 260 | # P(wi|class) = (count(wi, class) + k)/(count(w,class) + kV) 261 | prob_numerator = {} 262 | v_size = vocab.size 263 | 264 | cat_prob = Math.log(1 / data.categories.count.to_f) 265 | total_example_count = data.total_examples.to_f 266 | 267 | data.each do |category| 268 | unless assume_uniform 269 | cat_prob = Math.log(data.example_count(category) / total_example_count) 270 | end 271 | 272 | log_probs = 0 273 | denominator = (data.token_count(category) + @k * v_size).to_f 274 | tokens.each do |token| 275 | numerator = data.count_of_token_in_category(category, token) + @k 276 | log_probs += Math.log( numerator / denominator ) 277 | end 278 | prob_numerator[category] = log_probs + cat_prob 279 | end 280 | normalize(prob_numerator) 281 | end 282 | 283 | def normalize(prob_numerator) 284 | # calculate the denominator, which normalizes this into a probability; it's just the sum of all numerators from above 285 | normalizer = 0 286 | prob_numerator.each {|cat, numerator| normalizer += numerator } 287 | # One more 
caveat: 288 | # We're using log probabilities, so the numbers are negative and the smallest negative number is actually the largest prob. 289 | # To convert, we need to maintain the relative distance between all of the probabilities: 290 | # - divide log prob by normalizer: this keeps ratios the same, but reverses the ordering 291 | # - re-normalize based off new counts 292 | # - final calculation 293 | # Ex: -1,-1,-2 => -4/-1, -4/-1, -4/-2 294 | # - renormalize and calculate => 4/10, 4/10, 2/10 295 | intermed = {} 296 | renormalizer = 0 297 | prob_numerator.each do |cat, numerator| 298 | intermed[cat] = normalizer / numerator.to_f 299 | renormalizer += intermed[cat] 300 | end 301 | # calculate final probs 302 | final_probs = {} 303 | intermed.each do |cat, value| 304 | final_probs[cat] = value / renormalizer.to_f 305 | end 306 | final_probs 307 | end 308 | 309 | # called internally after yaml import to reset Hash defaults 310 | def reset_after_import 311 | data.reset_after_import 312 | end 313 | 314 | def self.from_yml(yml_data) 315 | nbayes = YAML.load(yml_data) 316 | nbayes.reset_after_import() # yaml does not properly set the defaults on the Hashes 317 | nbayes 318 | end 319 | 320 | # Loads class instance from a data file (e.g., yaml) 321 | def self.from(yml_file) 322 | File.open(yml_file, "rb") do |file| 323 | self.from_yml(file.read) 324 | end 325 | end 326 | 327 | # Load class instance 328 | def load(yml) 329 | if yml.nil? 330 | nbayes = NBayes::Base.new 331 | elsif yml[0..2] == "---" 332 | nbayes = self.class.from_yml(yml) 333 | else 334 | nbayes = self.class.from(yml) 335 | end 336 | nbayes 337 | end 338 | 339 | # Dumps class instance to a data file (e.g., yaml) or a string 340 | def dump(arg) 341 | if arg.instance_of? 
String 342 | File.open(arg, "w") {|f| YAML.dump(self, f) } 343 | else 344 | YAML.dump(arg) 345 | end 346 | end 347 | 348 | end 349 | 350 | module Result 351 | # Return the key having the largest value 352 | def max_class 353 | keys.max{ |a,b| self[a] <=> self[b] } 354 | end 355 | end 356 | 357 | end 358 | -------------------------------------------------------------------------------- /nbayes.gemspec: -------------------------------------------------------------------------------- 1 | # Generated by jeweler 2 | # DO NOT EDIT THIS FILE DIRECTLY 3 | # Instead, edit Jeweler::Tasks in Rakefile, and run 'rake gemspec' 4 | # -*- encoding: utf-8 -*- 5 | # stub: nbayes 0.1.3 ruby lib 6 | 7 | Gem::Specification.new do |s| 8 | s.name = "nbayes".freeze 9 | s.version = "0.1.3" 10 | 11 | s.required_rubygems_version = Gem::Requirement.new(">= 0".freeze) if s.respond_to? :required_rubygems_version= 12 | s.require_paths = ["lib".freeze] 13 | s.authors = ["oasic".freeze] 14 | s.date = "2020-06-26" 15 | s.description = "Ruby implementation of Naive Bayes that generates true probabilities per class, works with many token types, and provides lots of bells and whistles while being optimized for performance.".freeze 16 | s.email = "j@oasic.net".freeze 17 | s.extra_rdoc_files = [ 18 | "LICENSE.txt", 19 | "README.md" 20 | ] 21 | s.files = [ 22 | ".document", 23 | ".rspec", 24 | "Gemfile", 25 | "LICENSE.txt", 26 | "README.md", 27 | "Rakefile", 28 | "VERSION", 29 | "lib/nbayes.rb", 30 | "nbayes.gemspec", 31 | "spec/nbayes_spec.rb", 32 | "spec/spec_helper.rb" 33 | ] 34 | s.homepage = "http://github.com/oasic/nbayes".freeze 35 | s.licenses = ["MIT".freeze] 36 | s.rubygems_version = "2.6.14".freeze 37 | s.summary = "Full-featured Ruby implementation of Naive Bayes classifier".freeze 38 | 39 | if s.respond_to? 
:specification_version then 40 | s.specification_version = 4 41 | 42 | if Gem::Version.new(Gem::VERSION) >= Gem::Version.new('1.2.0') then 43 | s.add_development_dependency(%q<rspec>.freeze, [">= 3.9.0"]) 44 | s.add_development_dependency(%q<rdoc>.freeze, [">= 3.0.0"]) 45 | s.add_development_dependency(%q<bundler>.freeze, [">= 2.0.0"]) 46 | s.add_development_dependency(%q<jeweler>.freeze, [">= 2.3.0"]) 47 | else 48 | s.add_dependency(%q<rspec>.freeze, [">= 3.9.0"]) 49 | s.add_dependency(%q<rdoc>.freeze, [">= 3.0.0"]) 50 | s.add_dependency(%q<bundler>.freeze, [">= 2.0.0"]) 51 | s.add_dependency(%q<jeweler>.freeze, [">= 2.3.0"]) 52 | end 53 | else 54 | s.add_dependency(%q<rspec>.freeze, [">= 3.9.0"]) 55 | s.add_dependency(%q<rdoc>.freeze, [">= 3.0.0"]) 56 | s.add_dependency(%q<bundler>.freeze, [">= 2.0.0"]) 57 | s.add_dependency(%q<jeweler>.freeze, [">= 2.3.0"]) 58 | end 59 | end 60 | 61 | -------------------------------------------------------------------------------- /spec/nbayes_spec.rb: -------------------------------------------------------------------------------- 1 | require File.expand_path(File.dirname(__FILE__) + '/spec_helper') 2 | require 'fileutils' 3 | 4 | describe NBayes do 5 | let(:nbayes) { NBayes::Base.new } 6 | 7 | describe 'should assign equal probability to each class' do 8 | let(:results) { nbayes.classify(%w(a b c)) } 9 | 10 | before do 11 | nbayes.train(%w(a b c d e f g), 'classA') 12 | nbayes.train(%w(a b c d e f g), 'classB') 13 | end 14 | 15 | specify { expect(results['classA']).to eq(0.5) } 16 | specify { expect(results['classB']).to eq(0.5) } 17 | end 18 | 19 | describe 'should handle more than 2 classes' do 20 | let(:results) { nbayes.classify(%w(a a a a b c)) } 21 | 22 | before do 23 | nbayes.train(%w(a a a a), 'classA') 24 | nbayes.train(%w(b b b b), 'classB') 25 | nbayes.train(%w(c c), 'classC') 26 | end 27 | 28 | specify { expect(results.max_class).to eq('classA') } 29 | specify { expect(results['classA']).to be >= 0.4 } 30 | specify { expect(results['classB']).to be <= 0.3 } 31 | specify { expect(results['classC']).to be <= 
0.3 } 32 | end 33 | 34 | describe 'should use smoothing by default to eliminate errors' do 35 | context 'when dividing by zero' do 36 | let(:results) { nbayes.classify(%w(x y z)) } 37 | 38 | before do 39 | nbayes.train(%w(a a a a), 'classA') 40 | nbayes.train(%w(b b b b), 'classB') 41 | end 42 | 43 | specify { expect(results['classA']).to be >= 0.0 } 44 | specify { expect(results['classB']).to be >= 0.0 } 45 | end 46 | end 47 | 48 | describe 'should optionally purge low frequency data' do 49 | let(:results) { nbayes.classify(%w(c)) } 50 | let(:token_count) { nbayes.data.count_of_token_in_category('classB', 'c') } 51 | 52 | before do 53 | 100.times do 54 | nbayes.train(%w(a a a a), 'classA') 55 | nbayes.train(%w(b b b b), 'classB') 56 | end 57 | nbayes.train(%w(a), 'classA') 58 | nbayes.train(%w(c b), 'classB') 59 | end 60 | 61 | context 'before purge' do 62 | specify { expect(results.max_class).to eq('classB') } 63 | specify { expect(results['classB']).to be > 0.5 } 64 | specify { expect(token_count).to eq(1) } 65 | end 66 | 67 | context 'after purge' do 68 | before { nbayes.purge_less_than(2) } 69 | 70 | specify { expect(results['classA']).to eq(0.5) } 71 | specify { expect(results['classB']).to eq(0.5) } 72 | specify { expect(token_count).to be_zero } 73 | end 74 | end 75 | 76 | it 'works on all tokens - not just strings' do 77 | nbayes.train([1, 2, 3], 'low') 78 | nbayes.train([5, 6, 7], 'high') 79 | results = nbayes.classify([2]) 80 | expect(results.max_class).to eq('low') 81 | results = nbayes.classify([6]) 82 | expect(results.max_class).to eq('high') 83 | end 84 | 85 | describe 'should optionally allow class distribution to be assumed uniform' do 86 | context 'before uniform distribution' do 87 | let(:before_results) { nbayes.classify(['a']) } 88 | 89 | before do 90 | nbayes.train(%w(a a a a b), 'classA') 91 | nbayes.train(%w(a a a a), 'classA') 92 | nbayes.train(%w(a a a a), 'classB') 93 | end 94 | 95 | specify { expect(before_results.max_class).to 
eq('classA') } 96 | specify { expect(before_results['classA']).to be > 0.5 } 97 | 98 | context 'and after uniform distribution assumption' do 99 | let(:after_results) { nbayes.classify(['a']) } 100 | 101 | before { nbayes.assume_uniform = true } 102 | 103 | specify { expect(after_results.max_class).to eq('classB') } 104 | specify { expect(after_results['classB']).to be > 0.5 } 105 | end 106 | end 107 | end 108 | 109 | # In binarized mode, the frequency count is set to 1 for each token in each instance 110 | # For text, this is "set of words" rather than "bag of words" 111 | it 'should allow binarized mode' do 112 | # w/o binarized mode, token repetition can skew the results 113 | # def train_it 114 | nbayes.train(%w(a a a a a a a a a a a), 'classA') 115 | nbayes.train(%w(b b), 'classA') 116 | nbayes.train(%w(a c), 'classB') 117 | nbayes.train(%w(a c), 'classB') 118 | nbayes.train(%w(a c), 'classB') 119 | # end 120 | # train_it 121 | results = nbayes.classify(['a']) 122 | expect(results.max_class).to eq('classA') 123 | expect(results['classA']).to be > 0.5 124 | # this does not happen in binarized mode 125 | nbayes = NBayes::Base.new(binarized: true) 126 | nbayes.train(%w(a a a a a a a a a a a), 'classA') 127 | nbayes.train(%w(b b), 'classA') 128 | nbayes.train(%w(a c), 'classB') 129 | nbayes.train(%w(a c), 'classB') 130 | nbayes.train(%w(a c), 'classB') 131 | results = nbayes.classify(['a']) 132 | expect(results.max_class).to eq('classB') 133 | expect(results['classB']).to be > 0.5 134 | end 135 | 136 | it 'allows smoothing constant k to be set to any value' do 137 | # increasing k increases smoothing 138 | nbayes.train(%w(a a a c), 'classA') 139 | nbayes.train(%w(b b b d), 'classB') 140 | expect(nbayes.k).to eq(1) 141 | results = nbayes.classify(['c']) 142 | prob_k1 = results['classA'] 143 | nbayes.k = 5 144 | results = nbayes.classify(['c']) 145 | prob_k5 = results['classA'] 146 | expect(prob_k1).to be > prob_k5 # increasing smoothing constant dampens the effect 
of the rare token 'c' 147 | end 148 | 149 | it 'optionally allows using the log of vocab size during smoothing' do 150 | 10_000.times do 151 | nbayes.train([rand(100)], 'classA') 152 | nbayes.train(%w(b b b d), 'classB') 153 | end 154 | end 155 | 156 | describe 'saving' do 157 | let(:tmp_dir) { File.join(File.dirname(__FILE__), 'tmp') } 158 | let(:yml_file) { File.join(tmp_dir, 'test.yml') } 159 | 160 | before { FileUtils.mkdir(tmp_dir) unless File.exist?(tmp_dir) } 161 | 162 | after { FileUtils.rm(yml_file) if File.exist?(yml_file) } 163 | 164 | it 'should save to yaml and load from yaml' do 165 | nbayes.train(%w(a a a a), 'classA') 166 | nbayes.train(%w(b b b b), 'classB') 167 | results = nbayes.classify(['b']) 168 | expect(results['classB']).to be >= 0.5 169 | nbayes.dump(yml_file) 170 | expect(File.exist?(yml_file)).to eq(true) 171 | nbayes2 = NBayes::Base.from(yml_file) 172 | results = nbayes.classify(['b']) 173 | expect(results['classB']).to be >= 0.5 174 | end 175 | end 176 | 177 | it 'should dump to yaml string and load from yaml string' do 178 | nbayes.train(%w(a a a a), 'classA') 179 | nbayes.train(%w(b b b b), 'classB') 180 | results = nbayes.classify(['b']) 181 | expect(results['classB']).to be >= 0.5 182 | yml = nbayes.dump(nbayes) 183 | nbayes2 = NBayes::Base.new.load(yml) 184 | results = nbayes.classify(['b']) 185 | expect(results['classB']).to be >= 0.5 186 | end 187 | 188 | describe 'should delete a category' do 189 | before do 190 | nbayes.train(%w(a a a a), 'classA') 191 | nbayes.train(%w(b b b b), 'classB') 192 | expect(nbayes.data.categories).to eq(%w(classA classB)) 193 | expect(nbayes.delete_category('classB')).to eq(['classA']) 194 | end 195 | 196 | specify { expect(nbayes.data.categories).to eq(['classA']) } 197 | end 198 | 199 | describe 'should do nothing if asked to delete an inexistant category' do 200 | before { nbayes.train(%w(a a a a), 'classA') } 201 | 202 | specify { expect(nbayes.data.categories).to eq(['classA']) } 203 | specify 
{ expect(nbayes.delete_category('classB')).to eq(['classA']) } 204 | specify { expect(nbayes.data.categories).to eq(['classA']) } 205 | end 206 | 207 | describe 'should untrain a class' do 208 | let(:results) { nbayes.classify(%w(a b c)) } 209 | 210 | before do 211 | nbayes.train(%w(a b c d e f g), 'classA') 212 | nbayes.train(%w(a b c d e f g), 'classB') 213 | nbayes.train(%w(a b c d e f g), 'classB') 214 | nbayes.untrain(%w(a b c d e f g), 'classB') 215 | end 216 | 217 | specify { expect(results['classA']).to eq(0.5) } 218 | specify { expect(results['classB']).to eq(0.5) } 219 | end 220 | 221 | describe 'should remove the category when the only example is untrained' do 222 | before do 223 | nbayes.train(%w(a b c d e f g), 'classA') 224 | nbayes.untrain(%w(a b c d e f g), 'classA') 225 | end 226 | 227 | specify { expect(nbayes.data.categories).to eq([]) } 228 | end 229 | 230 | describe 'try untraining a non-existant category' do 231 | let(:results) { nbayes.classify(%w(a b c)) } 232 | 233 | before do 234 | nbayes.train(%w(a b c d e f g), 'classA') 235 | nbayes.train(%w(a b c d e f g), 'classB') 236 | nbayes.untrain(%w(a b c d e f g), 'classC') 237 | end 238 | 239 | specify { expect(nbayes.data.categories).to eq(%w(classA classB)) } 240 | specify { expect(results['classA']).to eq(0.5) } 241 | specify { expect(results['classB']).to eq(0.5) } 242 | end 243 | end 244 | -------------------------------------------------------------------------------- /spec/spec_helper.rb: -------------------------------------------------------------------------------- 1 | # These 2 lines MUST be first 2 | require 'simplecov' 3 | SimpleCov.start 4 | 5 | $LOAD_PATH.unshift(File.join(File.dirname(__FILE__), '..', 'lib')) 6 | $LOAD_PATH.unshift(File.dirname(__FILE__)) 7 | require 'rspec' 8 | require 'nbayes' 9 | 10 | # Requires supporting files with custom matchers and macros, etc, 11 | # in ./support/ and its subdirectories. 
12 | Dir["#{File.dirname(__FILE__)}/support/**/*.rb"].each {|f| require f} 13 | 14 | RSpec.configure do |config| 15 | 16 | end 17 | --------------------------------------------------------------------------------