├── .rspec ├── lib ├── brains │ ├── version.rb │ ├── gson.jar │ ├── brains.jar │ ├── commons-cli.jar │ ├── commons-lang3.jar │ └── net.rb └── brains.rb ├── spec ├── spec_helper.rb └── brains_spec.rb ├── Gemfile ├── .travis.yml ├── Rakefile ├── .gitignore ├── bin ├── setup └── console ├── CHANGELOG.md ├── LICENSE.txt ├── example ├── xor.rb ├── sine_function.rb ├── colors.rb ├── iris.rb └── iris.data ├── brains.gemspec ├── CODE_OF_CONDUCT.md ├── iris.data └── README.md /.rspec: -------------------------------------------------------------------------------- 1 | --format documentation 2 | --color 3 | -------------------------------------------------------------------------------- /lib/brains/version.rb: -------------------------------------------------------------------------------- 1 | module Brains 2 | VERSION = "0.2.2" 3 | end 4 | -------------------------------------------------------------------------------- /lib/brains/gson.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jedld/brains-jruby/HEAD/lib/brains/gson.jar -------------------------------------------------------------------------------- /lib/brains/brains.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jedld/brains-jruby/HEAD/lib/brains/brains.jar -------------------------------------------------------------------------------- /spec/spec_helper.rb: -------------------------------------------------------------------------------- 1 | $LOAD_PATH.unshift File.expand_path('../../lib', __FILE__) 2 | require 'brains' 3 | -------------------------------------------------------------------------------- /lib/brains/commons-cli.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jedld/brains-jruby/HEAD/lib/brains/commons-cli.jar -------------------------------------------------------------------------------- 
/lib/brains/commons-lang3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jedld/brains-jruby/HEAD/lib/brains/commons-lang3.jar -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | 3 | # Specify your gem's dependencies in brains.gemspec 4 | gemspec 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: ruby 3 | rvm: 4 | - 2.3.1 5 | before_install: gem install bundler -v 1.12.3 6 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rspec/core/rake_task" 3 | 4 | RSpec::Core::RakeTask.new(:spec) 5 | 6 | task :default => :spec 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /Gemfile.lock 4 | /_yardoc/ 5 | /coverage/ 6 | /doc/ 7 | /pkg/ 8 | /spec/reports/ 9 | /tmp/ 10 | *.gem 11 | -------------------------------------------------------------------------------- /bin/setup: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | IFS=$'\n\t' 4 | set -vx 5 | 6 | bundle install 7 | 8 | # Do any other automated setup that you need to do here 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | Version 0.2.2 2 | ------------- 3 | 4 | * Bug fix for softmax function support in 
the output layer 5 | 6 | Version 0.2.1 7 | ------------- 8 | 9 | * Bug fix for recurrent neural network support 10 | 11 | Version 0.2.0 12 | ------------- 13 | 14 | * Support for multilayer recurrent neural networks 15 | -------------------------------------------------------------------------------- /lib/brains.rb: -------------------------------------------------------------------------------- 1 | require "brains/version" 2 | require "brains/brains.jar" 3 | require "brains/gson.jar" 4 | require "brains/commons-lang3.jar" 5 | require "brains/commons-cli.jar" 6 | require "brains/net" 7 | require "json" 8 | 9 | module Brains 10 | class Config 11 | attr_accessor :neurons_per_layer, :input_neurons, :output_neurons 12 | end 13 | end 14 | -------------------------------------------------------------------------------- /bin/console: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | 3 | require "bundler/setup" 4 | require "brains" 5 | 6 | # You can add fixtures and/or initialization code here to make experimenting 7 | # with your gem easier. You can also use a different console, if you like. 8 | 9 | # (If you use this, don't forget to add pry to your Gemfile!) 
10 | # require "pry" 11 | # Pry.start 12 | 13 | require "irb" 14 | IRB.start 15 | -------------------------------------------------------------------------------- /spec/brains_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper' 2 | 3 | describe Brains do 4 | it 'has a version number' do 5 | expect(Brains::VERSION).not_to be nil 6 | end 7 | 8 | it 'Train the sin function using neural networks' do 9 | # 1 input neuron, 1 output neuron, 20 total neurons 10 | brain = Brains::Net.new(1, 1, 20) 11 | brain.randomize_weights 12 | test_cases = [ 13 | [ [1.0] , [Math.sin(1.0)] ], 14 | [ [0.9] , [Math.sin(0.9)] ], 15 | [ [0.5] , [Math.sin(0.5)] ], 16 | [ [0.3] , [Math.sin(0.3)] ], 17 | [ [0.1] , [Math.sin(0.1)] ], 18 | [ [0.2] , [Math.sin(0.2)] ], 19 | [ [0.8] , [Math.sin(0.8)] ], 20 | [ [0] , [Math.sin(0)] ], 21 | ] 22 | expect((brain.feed([0.6])[0] - Math.sin(0.6)).abs).to be > 0.1 23 | brain.optimize(test_cases, 0.001) { |i, total_errors| 24 | puts "#{i} -> #{total_errors}" 25 | } 26 | expect((brain.feed([0.6])[0] - Math.sin(0.6)).abs).to be < 0.1 27 | expect(JSON.parse(brain.to_json).select { |k,v| ['n'].include? 
k}).to eq ({'n' => 20}) 28 | end 29 | end 30 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Joseph Emmanuel Dayo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /example/xor.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | 3 | require 'brains' 4 | 5 | 6 | # Build a 3 layer network: 2 input neurons, 4 hidden neurons, 1 output neuron 7 | # Bias neurons are automatically added to input + hidden layers; no need to specify these 8 | # 5 = 4 in one hidden layer + 1 output neuron (input neurons not counted) 9 | 10 | nn = Brains::Net.create(2, 1, 1, { neurons_per_layer: 4 }) 11 | nn.randomize_weights 12 | 13 | # A B A XOR B 14 | # 1 1 0 15 | # 1 0 1 16 | # 0 1 1 17 | # 0 0 0 18 | 19 | training_data = [ 20 | [[0.9, 0.9], [0.1]], 21 | [[0.9, 0.1], [0.9]], 22 | [[0.1, 0.9], [0.9]], 23 | [[0.1, 0.1], [0.1]], 24 | ] 25 | 26 | # test on untrained data 27 | test_data = [ 28 | [0.9, 0.9], 29 | [0.9, 0.1], 30 | [0.1, 0.9], 31 | [0.1, 0.1] 32 | ] 33 | 34 | results = test_data.collect { |item| 35 | nn.feed(item) 36 | } 37 | 38 | p results 39 | 40 | result = nn.optimize(training_data, 0.01, 1_000_000 ) { |i, error| 41 | puts "#{i} #{error}" 42 | } 43 | 44 | puts "after training" 45 | 46 | results = test_data.collect { |item| 47 | nn.feed(item) 48 | } 49 | 50 | 51 | p results 52 | 53 | state = nn.to_json 54 | puts state 55 | 56 | nn2 = Brains::Net.load(state) 57 | 58 | results2 = test_data.collect { |item| 59 | nn2.feed(item) 60 | } 61 | 62 | puts "use saved state" 63 | 64 | p results2 65 | -------------------------------------------------------------------------------- /brains.gemspec: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | lib = File.expand_path('../lib', __FILE__) 3 | $LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib) 4 | require 'brains/version' 5 | 6 | Gem::Specification.new do |spec| 7 | spec.name = "brains" 8 | spec.version = Brains::VERSION 9 | spec.authors = ["Joseph Emmanuel Dayo"] 10 | spec.email = 
["joseph.dayo@gmail.com"] 11 | 12 | spec.summary = %q{A feedforward neural network library for JRuby} 13 | spec.description = %q{A feedforward neural network library for JRuby. Aims to provide a quick way to get started on machine learning with ruby } 14 | spec.homepage = "https://github.com/jedld/brains-jruby" 15 | spec.license = "MIT" 16 | spec.platform = "java" 17 | 18 | # Prevent pushing this gem to RubyGems.org. To allow pushes either set the 'allowed_push_host' 19 | # to allow pushing to a single host or delete this section to allow pushing to any host. 20 | if spec.respond_to?(:metadata) 21 | spec.metadata['allowed_push_host'] = "https://rubygems.org" 22 | else 23 | raise "RubyGems 2.0 or newer is required to protect against public gem pushes." 24 | end 25 | 26 | spec.files = `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) } 27 | spec.bindir = "exe" 28 | spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) } 29 | spec.require_paths = ["lib"] 30 | 31 | spec.add_development_dependency "bundler", "~> 1.12" 32 | spec.add_development_dependency "rake", "~> 10.0" 33 | spec.add_development_dependency "rspec", "~> 3.0" 34 | end 35 | -------------------------------------------------------------------------------- /example/sine_function.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | 3 | require 'brains' 4 | 5 | # RNN to approximate a sine function sequence 6 | 7 | def generate_sine_test_data(start_t, end_t) 8 | inputs = [] 9 | outputs = [] 10 | 11 | (start_t...end_t).each do |t| 12 | inputs << [Math.sin(t)] 13 | outputs << [Math.sin(t + 1)] 14 | end 15 | 16 | [[inputs, outputs]] 17 | end 18 | 19 | 20 | 21 | training_data = generate_sine_test_data(0, 10) 22 | 23 | testing_data = generate_sine_test_data(11, 20) 24 | 25 | # input sequence 26 | input_sequence = training_data[0][0].map { |a| a[0] } 27 | output_sequence = training_data[0][1].map { |a| a[0] } 28 
| test_input_sequence = testing_data[0][0].map { |a| a[0] } 29 | test_output_sequence = testing_data[0][1].map { |a| a[0] } 30 | 31 | nn = Brains::Net.create(1, 1, 1, { neurons_per_layer: 3, 32 | learning_rate: 0.01, 33 | recurrent: true, 34 | output_function: :htan, 35 | }) 36 | 37 | # randomize weights before training 38 | nn.randomize_weights 39 | 40 | results = nn.feed(testing_data[0][0]) 41 | results.each_with_index do |a, index| 42 | puts "#{test_input_sequence[index]} => #{a[0]}" 43 | end 44 | 45 | result = nn.optimize_recurrent(training_data, 0.001, 100_000_000, 10_000 ) { |i, error| 46 | puts "#{i} #{error}" 47 | } 48 | 49 | results = nn.feed(training_data[0][0]) 50 | puts " Training data" 51 | results.each_with_index do |a, index| 52 | puts "#{input_sequence[index]} => #{a[0]} (#{output_sequence[index]})" 53 | end 54 | puts " Testing data" 55 | results = nn.feed(testing_data[0][0]) 56 | 57 | results.each_with_index do |a, index| 58 | puts "#{test_input_sequence[index]} => #{a[0]} (#{test_output_sequence[index]})" 59 | end 60 | 61 | puts nn.to_json 62 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Code of Conduct 2 | 3 | As contributors and maintainers of this project, and in the interest of 4 | fostering an open and welcoming community, we pledge to respect all people who 5 | contribute through reporting issues, posting feature requests, updating 6 | documentation, submitting pull requests or patches, and other activities. 7 | 8 | We are committed to making participation in this project a harassment-free 9 | experience for everyone, regardless of level of experience, gender, gender 10 | identity and expression, sexual orientation, disability, personal appearance, 11 | body size, race, ethnicity, age, religion, or nationality. 
12 | 13 | Examples of unacceptable behavior by participants include: 14 | 15 | * The use of sexualized language or imagery 16 | * Personal attacks 17 | * Trolling or insulting/derogatory comments 18 | * Public or private harassment 19 | * Publishing other's private information, such as physical or electronic 20 | addresses, without explicit permission 21 | * Other unethical or unprofessional conduct 22 | 23 | Project maintainers have the right and responsibility to remove, edit, or 24 | reject comments, commits, code, wiki edits, issues, and other contributions 25 | that are not aligned to this Code of Conduct, or to ban temporarily or 26 | permanently any contributor for other behaviors that they deem inappropriate, 27 | threatening, offensive, or harmful. 28 | 29 | By adopting this Code of Conduct, project maintainers commit themselves to 30 | fairly and consistently applying these principles to every aspect of managing 31 | this project. Project maintainers who do not follow or enforce the Code of 32 | Conduct may be permanently removed from the project team. 33 | 34 | This code of conduct applies both within project spaces and in public spaces 35 | when an individual is representing the project or its community. 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 38 | reported by contacting a project maintainer at joseph.dayo@gmail.com. All 39 | complaints will be reviewed and investigated and will result in a response that 40 | is deemed necessary and appropriate to the circumstances. Maintainers are 41 | obligated to maintain confidentiality with regard to the reporter of an 42 | incident. 
43 | 44 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 45 | version 1.3.0, available at 46 | [http://contributor-covenant.org/version/1/3/0/][version] 47 | 48 | [homepage]: http://contributor-covenant.org 49 | [version]: http://contributor-covenant.org/version/1/3/0/ -------------------------------------------------------------------------------- /example/colors.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | 3 | require 'brains' 4 | 5 | #This neural network will identify the main color name based on rgb values 6 | RED = [0,0,1] 7 | GREEN = [0,1,0] 8 | BLUE = [1,0,0] 9 | 10 | def color_value(color_value) 11 | [ 12 | Integer(color_value[0..1], 16).to_f / 0xff.to_f, 13 | Integer(color_value[2..3], 16).to_f / 0xff.to_f, 14 | Integer(color_value[4..5], 16).to_f / 0xff.to_f, 15 | ] 16 | end 17 | 18 | def color_desc(result) 19 | return "blue" if (result[0] > result[1] && result[0] > result[2]) 20 | return "green" if (result[1] > result[0] && result[1] > result[2]) 21 | return "red" if (result[2] > result[0] && result[2] > result[1]) 22 | end 23 | 24 | label_encodings = { 25 | "Blue" => [1, 0, 0], 26 | "Green" => [0, 1, 0], 27 | "Red" => [0, 0 ,1] 28 | } 29 | #0000ff 30 | training_data = [ 31 | # red 32 | [color_value('E32636'), RED], 33 | [color_value('8B0000'), RED], 34 | [color_value('800000'), RED], 35 | [color_value('65000B'), RED], 36 | [color_value('674846'), RED], 37 | 38 | #green 39 | [color_value('8F9779'), GREEN], 40 | [color_value('568203'), GREEN], 41 | [color_value('013220'), GREEN], 42 | [color_value('00FF00'), GREEN], 43 | [color_value('006400'), GREEN], 44 | [color_value('00A877'), GREEN], 45 | 46 | #blue 47 | [color_value('89CFF0'), BLUE], 48 | [color_value('ADD8E6'), BLUE], 49 | [color_value('0000FF'), BLUE], 50 | [color_value('0070BB'), BLUE], 51 | [color_value('545AA7'), BLUE], 52 | [color_value('4C516D'), BLUE], 53 | ] 54 | 55 | nn = Brains::Net.create(3, 
3, 1, { 56 | neurons_per_layer: 3, 57 | output_function: :softmax, 58 | error: :cross_entropy }) 59 | 60 | # randomize weights before training 61 | nn.randomize_weights 62 | 63 | 64 | # test on untrained data 65 | #0000ee #C41E3A 66 | test_data = [ 67 | [color_value('0087BD') , 'blue'], # blue 68 | [color_value('C80815') , 'red'], # venetian red 69 | [color_value('009E60') , 'green'], # Shamrock green 70 | [color_value('00FF00') , 'green'], # green 71 | [color_value('333399') , 'blue'], # blue 72 | ] 73 | 74 | correct = 0 75 | test_data.each_with_index { |item , index| 76 | c = color_desc(nn.feed(item[0])) 77 | correct +=1 if (c == item[1]) 78 | puts c 79 | } 80 | 81 | puts "#{correct}/#{test_data.length}" 82 | 83 | result = nn.optimize(training_data, 0.25, 100_000, 100 ) { |i, error| 84 | puts "#{i} #{error}" 85 | } 86 | 87 | p result 88 | 89 | puts "after training" 90 | 91 | correct = 0 92 | test_data.each_with_index { |item , index| 93 | r = nn.feed(item[0]) 94 | c = color_desc(r) 95 | correct +=1 if (c == item[1]) 96 | puts "#{r} -> #{c}" 97 | } 98 | 99 | puts "#{correct}/#{test_data.length}" 100 | 101 | puts nn.to_json 102 | -------------------------------------------------------------------------------- /example/iris.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | 3 | require 'brains' 4 | 5 | # This neural network will predict the species of an iris based on sepal and petal size 6 | # Dataset: http://en.wikipedia.org/wiki/Iris_flower_data_set 7 | 8 | rows = File.readlines("iris.data").map {|l| l.chomp.split(',') } 9 | 10 | rows.shuffle! 
11 | 12 | label_encodings = { 13 | "Iris-setosa" => [1, 0, 0], 14 | "Iris-versicolor" => [0, 1, 0], 15 | "Iris-virginica" => [0, 0 ,1] 16 | } 17 | 18 | x_data = rows.map {|row| row[0,4].map(&:to_f) } 19 | y_data = rows.map {|row| label_encodings[row[4]] } 20 | 21 | # Normalize data values before feeding into network 22 | normalize = -> (val, high, low) { (val - low) / (high - low) } # maps input to float between 0 and 1 23 | 24 | columns = (0..3).map do |i| 25 | x_data.map {|row| row[i] } 26 | end 27 | 28 | x_data.map! do |row| 29 | row.map.with_index do |val, j| 30 | max, min = columns[j].max, columns[j].min 31 | normalize.(val, max, min) 32 | end 33 | end 34 | 35 | x_train = x_data.slice(0, 100) 36 | y_train = y_data.slice(0, 100) 37 | 38 | x_test = x_data.slice(100, 50) 39 | y_test = y_data.slice(100, 50) 40 | 41 | test_cases = [] 42 | x_train.each_with_index do |x, index| 43 | test_cases << [x, y_train[index] ] 44 | end 45 | 46 | validation_cases = [] 47 | x_test.each_with_index do |x, index| 48 | test_cases << [x, y_test[index] ] 49 | end 50 | 51 | # Build a 3 layer network: 4 input neurons, 4 hidden neurons, 3 output neurons 52 | # Bias neurons are automatically added to input + hidden layers; no need to specify these 53 | nn = Brains::Net.create(4, 3, 1, { neurons_per_layer: 4 }) 54 | nn.randomize_weights 55 | 56 | prediction_success = -> (actual, ideal) { 57 | predicted = (0..2).max_by {|i| actual[i] } 58 | ideal[predicted] == 1 59 | } 60 | 61 | mse = -> (actual, ideal) { 62 | errors = actual.zip(ideal).map {|a, i| a - i } 63 | (errors.inject(0) {|sum, err| sum += err**2}) / errors.length.to_f 64 | } 65 | 66 | error_rate = -> (errors, total) { ((errors / total.to_f) * 100).round } 67 | 68 | run_test = -> (nn, inputs, expected_outputs) { 69 | success, failure, errsum = 0,0,0 70 | inputs.each.with_index do |input, i| 71 | output = nn.feed input 72 | prediction_success.(output, expected_outputs[i]) ? 
success += 1 : failure += 1 73 | errsum += mse.(output, expected_outputs[i]) 74 | end 75 | [success, failure, errsum / inputs.length.to_f] 76 | } 77 | 78 | puts "Testing the untrained network..." 79 | 80 | success, failure, avg_mse = run_test.(nn, x_test, y_test) 81 | 82 | puts "Untrained classification success: #{success}, failure: #{failure} (classification error: #{error_rate.(failure, x_test.length)}%, mse: #{(avg_mse * 100).round(2)}%)" 83 | 84 | 85 | puts "\nTraining the network...\n\n" 86 | 87 | t1 = Time.now 88 | result = nn.optimize(test_cases, 0.01, 1_000 ) { |i, error| 89 | puts "#{i} #{error}" 90 | } 91 | 92 | # puts result 93 | puts "\nDone training the network: #{result[:iterations]} iterations, #{(result[:error] * 100).round(2)}% mse, #{(Time.now - t1).round(1)}s" 94 | 95 | 96 | puts "\nTesting the trained network..." 97 | 98 | success, failure, avg_mse = run_test.(nn, x_test, y_test) 99 | 100 | puts "Trained classification success: #{success}, failure: #{failure} (classification error: #{error_rate.(failure, x_test.length)}%, mse: #{(avg_mse * 100).round(2)}%)" 101 | -------------------------------------------------------------------------------- /iris.data: -------------------------------------------------------------------------------- 1 | 5.1,3.5,1.4,0.2,Iris-setosa 2 | 4.9,3.0,1.4,0.2,Iris-setosa 3 | 4.7,3.2,1.3,0.2,Iris-setosa 4 | 4.6,3.1,1.5,0.2,Iris-setosa 5 | 5.0,3.6,1.4,0.2,Iris-setosa 6 | 5.4,3.9,1.7,0.4,Iris-setosa 7 | 4.6,3.4,1.4,0.3,Iris-setosa 8 | 5.0,3.4,1.5,0.2,Iris-setosa 9 | 4.4,2.9,1.4,0.2,Iris-setosa 10 | 4.9,3.1,1.5,0.1,Iris-setosa 11 | 5.4,3.7,1.5,0.2,Iris-setosa 12 | 4.8,3.4,1.6,0.2,Iris-setosa 13 | 4.8,3.0,1.4,0.1,Iris-setosa 14 | 4.3,3.0,1.1,0.1,Iris-setosa 15 | 5.8,4.0,1.2,0.2,Iris-setosa 16 | 5.7,4.4,1.5,0.4,Iris-setosa 17 | 5.4,3.9,1.3,0.4,Iris-setosa 18 | 5.1,3.5,1.4,0.3,Iris-setosa 19 | 5.7,3.8,1.7,0.3,Iris-setosa 20 | 5.1,3.8,1.5,0.3,Iris-setosa 21 | 5.4,3.4,1.7,0.2,Iris-setosa 22 | 5.1,3.7,1.5,0.4,Iris-setosa 23 | 
4.6,3.6,1.0,0.2,Iris-setosa 24 | 5.1,3.3,1.7,0.5,Iris-setosa 25 | 4.8,3.4,1.9,0.2,Iris-setosa 26 | 5.0,3.0,1.6,0.2,Iris-setosa 27 | 5.0,3.4,1.6,0.4,Iris-setosa 28 | 5.2,3.5,1.5,0.2,Iris-setosa 29 | 5.2,3.4,1.4,0.2,Iris-setosa 30 | 4.7,3.2,1.6,0.2,Iris-setosa 31 | 4.8,3.1,1.6,0.2,Iris-setosa 32 | 5.4,3.4,1.5,0.4,Iris-setosa 33 | 5.2,4.1,1.5,0.1,Iris-setosa 34 | 5.5,4.2,1.4,0.2,Iris-setosa 35 | 4.9,3.1,1.5,0.1,Iris-setosa 36 | 5.0,3.2,1.2,0.2,Iris-setosa 37 | 5.5,3.5,1.3,0.2,Iris-setosa 38 | 4.9,3.1,1.5,0.1,Iris-setosa 39 | 4.4,3.0,1.3,0.2,Iris-setosa 40 | 5.1,3.4,1.5,0.2,Iris-setosa 41 | 5.0,3.5,1.3,0.3,Iris-setosa 42 | 4.5,2.3,1.3,0.3,Iris-setosa 43 | 4.4,3.2,1.3,0.2,Iris-setosa 44 | 5.0,3.5,1.6,0.6,Iris-setosa 45 | 5.1,3.8,1.9,0.4,Iris-setosa 46 | 4.8,3.0,1.4,0.3,Iris-setosa 47 | 5.1,3.8,1.6,0.2,Iris-setosa 48 | 4.6,3.2,1.4,0.2,Iris-setosa 49 | 5.3,3.7,1.5,0.2,Iris-setosa 50 | 5.0,3.3,1.4,0.2,Iris-setosa 51 | 7.0,3.2,4.7,1.4,Iris-versicolor 52 | 6.4,3.2,4.5,1.5,Iris-versicolor 53 | 6.9,3.1,4.9,1.5,Iris-versicolor 54 | 5.5,2.3,4.0,1.3,Iris-versicolor 55 | 6.5,2.8,4.6,1.5,Iris-versicolor 56 | 5.7,2.8,4.5,1.3,Iris-versicolor 57 | 6.3,3.3,4.7,1.6,Iris-versicolor 58 | 4.9,2.4,3.3,1.0,Iris-versicolor 59 | 6.6,2.9,4.6,1.3,Iris-versicolor 60 | 5.2,2.7,3.9,1.4,Iris-versicolor 61 | 5.0,2.0,3.5,1.0,Iris-versicolor 62 | 5.9,3.0,4.2,1.5,Iris-versicolor 63 | 6.0,2.2,4.0,1.0,Iris-versicolor 64 | 6.1,2.9,4.7,1.4,Iris-versicolor 65 | 5.6,2.9,3.6,1.3,Iris-versicolor 66 | 6.7,3.1,4.4,1.4,Iris-versicolor 67 | 5.6,3.0,4.5,1.5,Iris-versicolor 68 | 5.8,2.7,4.1,1.0,Iris-versicolor 69 | 6.2,2.2,4.5,1.5,Iris-versicolor 70 | 5.6,2.5,3.9,1.1,Iris-versicolor 71 | 5.9,3.2,4.8,1.8,Iris-versicolor 72 | 6.1,2.8,4.0,1.3,Iris-versicolor 73 | 6.3,2.5,4.9,1.5,Iris-versicolor 74 | 6.1,2.8,4.7,1.2,Iris-versicolor 75 | 6.4,2.9,4.3,1.3,Iris-versicolor 76 | 6.6,3.0,4.4,1.4,Iris-versicolor 77 | 6.8,2.8,4.8,1.4,Iris-versicolor 78 | 6.7,3.0,5.0,1.7,Iris-versicolor 79 | 6.0,2.9,4.5,1.5,Iris-versicolor 80 | 
5.7,2.6,3.5,1.0,Iris-versicolor 81 | 5.5,2.4,3.8,1.1,Iris-versicolor 82 | 5.5,2.4,3.7,1.0,Iris-versicolor 83 | 5.8,2.7,3.9,1.2,Iris-versicolor 84 | 6.0,2.7,5.1,1.6,Iris-versicolor 85 | 5.4,3.0,4.5,1.5,Iris-versicolor 86 | 6.0,3.4,4.5,1.6,Iris-versicolor 87 | 6.7,3.1,4.7,1.5,Iris-versicolor 88 | 6.3,2.3,4.4,1.3,Iris-versicolor 89 | 5.6,3.0,4.1,1.3,Iris-versicolor 90 | 5.5,2.5,4.0,1.3,Iris-versicolor 91 | 5.5,2.6,4.4,1.2,Iris-versicolor 92 | 6.1,3.0,4.6,1.4,Iris-versicolor 93 | 5.8,2.6,4.0,1.2,Iris-versicolor 94 | 5.0,2.3,3.3,1.0,Iris-versicolor 95 | 5.6,2.7,4.2,1.3,Iris-versicolor 96 | 5.7,3.0,4.2,1.2,Iris-versicolor 97 | 5.7,2.9,4.2,1.3,Iris-versicolor 98 | 6.2,2.9,4.3,1.3,Iris-versicolor 99 | 5.1,2.5,3.0,1.1,Iris-versicolor 100 | 5.7,2.8,4.1,1.3,Iris-versicolor 101 | 6.3,3.3,6.0,2.5,Iris-virginica 102 | 5.8,2.7,5.1,1.9,Iris-virginica 103 | 7.1,3.0,5.9,2.1,Iris-virginica 104 | 6.3,2.9,5.6,1.8,Iris-virginica 105 | 6.5,3.0,5.8,2.2,Iris-virginica 106 | 7.6,3.0,6.6,2.1,Iris-virginica 107 | 4.9,2.5,4.5,1.7,Iris-virginica 108 | 7.3,2.9,6.3,1.8,Iris-virginica 109 | 6.7,2.5,5.8,1.8,Iris-virginica 110 | 7.2,3.6,6.1,2.5,Iris-virginica 111 | 6.5,3.2,5.1,2.0,Iris-virginica 112 | 6.4,2.7,5.3,1.9,Iris-virginica 113 | 6.8,3.0,5.5,2.1,Iris-virginica 114 | 5.7,2.5,5.0,2.0,Iris-virginica 115 | 5.8,2.8,5.1,2.4,Iris-virginica 116 | 6.4,3.2,5.3,2.3,Iris-virginica 117 | 6.5,3.0,5.5,1.8,Iris-virginica 118 | 7.7,3.8,6.7,2.2,Iris-virginica 119 | 7.7,2.6,6.9,2.3,Iris-virginica 120 | 6.0,2.2,5.0,1.5,Iris-virginica 121 | 6.9,3.2,5.7,2.3,Iris-virginica 122 | 5.6,2.8,4.9,2.0,Iris-virginica 123 | 7.7,2.8,6.7,2.0,Iris-virginica 124 | 6.3,2.7,4.9,1.8,Iris-virginica 125 | 6.7,3.3,5.7,2.1,Iris-virginica 126 | 7.2,3.2,6.0,1.8,Iris-virginica 127 | 6.2,2.8,4.8,1.8,Iris-virginica 128 | 6.1,3.0,4.9,1.8,Iris-virginica 129 | 6.4,2.8,5.6,2.1,Iris-virginica 130 | 7.2,3.0,5.8,1.6,Iris-virginica 131 | 7.4,2.8,6.1,1.9,Iris-virginica 132 | 7.9,3.8,6.4,2.0,Iris-virginica 133 | 6.4,2.8,5.6,2.2,Iris-virginica 134 | 
6.3,2.8,5.1,1.5,Iris-virginica 135 | 6.1,2.6,5.6,1.4,Iris-virginica 136 | 7.7,3.0,6.1,2.3,Iris-virginica 137 | 6.3,3.4,5.6,2.4,Iris-virginica 138 | 6.4,3.1,5.5,1.8,Iris-virginica 139 | 6.0,3.0,4.8,1.8,Iris-virginica 140 | 6.9,3.1,5.4,2.1,Iris-virginica 141 | 6.7,3.1,5.6,2.4,Iris-virginica 142 | 6.9,3.1,5.1,2.3,Iris-virginica 143 | 5.8,2.7,5.1,1.9,Iris-virginica 144 | 6.8,3.2,5.9,2.3,Iris-virginica 145 | 6.7,3.3,5.7,2.5,Iris-virginica 146 | 6.7,3.0,5.2,2.3,Iris-virginica 147 | 6.3,2.5,5.0,1.9,Iris-virginica 148 | 6.5,3.0,5.2,2.0,Iris-virginica 149 | 6.2,3.4,5.4,2.3,Iris-virginica 150 | 5.9,3.0,5.1,1.8,Iris-virginica 151 | -------------------------------------------------------------------------------- /example/iris.data: -------------------------------------------------------------------------------- 1 | 5.1,3.5,1.4,0.2,Iris-setosa 2 | 4.9,3.0,1.4,0.2,Iris-setosa 3 | 4.7,3.2,1.3,0.2,Iris-setosa 4 | 4.6,3.1,1.5,0.2,Iris-setosa 5 | 5.0,3.6,1.4,0.2,Iris-setosa 6 | 5.4,3.9,1.7,0.4,Iris-setosa 7 | 4.6,3.4,1.4,0.3,Iris-setosa 8 | 5.0,3.4,1.5,0.2,Iris-setosa 9 | 4.4,2.9,1.4,0.2,Iris-setosa 10 | 4.9,3.1,1.5,0.1,Iris-setosa 11 | 5.4,3.7,1.5,0.2,Iris-setosa 12 | 4.8,3.4,1.6,0.2,Iris-setosa 13 | 4.8,3.0,1.4,0.1,Iris-setosa 14 | 4.3,3.0,1.1,0.1,Iris-setosa 15 | 5.8,4.0,1.2,0.2,Iris-setosa 16 | 5.7,4.4,1.5,0.4,Iris-setosa 17 | 5.4,3.9,1.3,0.4,Iris-setosa 18 | 5.1,3.5,1.4,0.3,Iris-setosa 19 | 5.7,3.8,1.7,0.3,Iris-setosa 20 | 5.1,3.8,1.5,0.3,Iris-setosa 21 | 5.4,3.4,1.7,0.2,Iris-setosa 22 | 5.1,3.7,1.5,0.4,Iris-setosa 23 | 4.6,3.6,1.0,0.2,Iris-setosa 24 | 5.1,3.3,1.7,0.5,Iris-setosa 25 | 4.8,3.4,1.9,0.2,Iris-setosa 26 | 5.0,3.0,1.6,0.2,Iris-setosa 27 | 5.0,3.4,1.6,0.4,Iris-setosa 28 | 5.2,3.5,1.5,0.2,Iris-setosa 29 | 5.2,3.4,1.4,0.2,Iris-setosa 30 | 4.7,3.2,1.6,0.2,Iris-setosa 31 | 4.8,3.1,1.6,0.2,Iris-setosa 32 | 5.4,3.4,1.5,0.4,Iris-setosa 33 | 5.2,4.1,1.5,0.1,Iris-setosa 34 | 5.5,4.2,1.4,0.2,Iris-setosa 35 | 4.9,3.1,1.5,0.1,Iris-setosa 36 | 5.0,3.2,1.2,0.2,Iris-setosa 37 | 
5.5,3.5,1.3,0.2,Iris-setosa 38 | 4.9,3.1,1.5,0.1,Iris-setosa 39 | 4.4,3.0,1.3,0.2,Iris-setosa 40 | 5.1,3.4,1.5,0.2,Iris-setosa 41 | 5.0,3.5,1.3,0.3,Iris-setosa 42 | 4.5,2.3,1.3,0.3,Iris-setosa 43 | 4.4,3.2,1.3,0.2,Iris-setosa 44 | 5.0,3.5,1.6,0.6,Iris-setosa 45 | 5.1,3.8,1.9,0.4,Iris-setosa 46 | 4.8,3.0,1.4,0.3,Iris-setosa 47 | 5.1,3.8,1.6,0.2,Iris-setosa 48 | 4.6,3.2,1.4,0.2,Iris-setosa 49 | 5.3,3.7,1.5,0.2,Iris-setosa 50 | 5.0,3.3,1.4,0.2,Iris-setosa 51 | 7.0,3.2,4.7,1.4,Iris-versicolor 52 | 6.4,3.2,4.5,1.5,Iris-versicolor 53 | 6.9,3.1,4.9,1.5,Iris-versicolor 54 | 5.5,2.3,4.0,1.3,Iris-versicolor 55 | 6.5,2.8,4.6,1.5,Iris-versicolor 56 | 5.7,2.8,4.5,1.3,Iris-versicolor 57 | 6.3,3.3,4.7,1.6,Iris-versicolor 58 | 4.9,2.4,3.3,1.0,Iris-versicolor 59 | 6.6,2.9,4.6,1.3,Iris-versicolor 60 | 5.2,2.7,3.9,1.4,Iris-versicolor 61 | 5.0,2.0,3.5,1.0,Iris-versicolor 62 | 5.9,3.0,4.2,1.5,Iris-versicolor 63 | 6.0,2.2,4.0,1.0,Iris-versicolor 64 | 6.1,2.9,4.7,1.4,Iris-versicolor 65 | 5.6,2.9,3.6,1.3,Iris-versicolor 66 | 6.7,3.1,4.4,1.4,Iris-versicolor 67 | 5.6,3.0,4.5,1.5,Iris-versicolor 68 | 5.8,2.7,4.1,1.0,Iris-versicolor 69 | 6.2,2.2,4.5,1.5,Iris-versicolor 70 | 5.6,2.5,3.9,1.1,Iris-versicolor 71 | 5.9,3.2,4.8,1.8,Iris-versicolor 72 | 6.1,2.8,4.0,1.3,Iris-versicolor 73 | 6.3,2.5,4.9,1.5,Iris-versicolor 74 | 6.1,2.8,4.7,1.2,Iris-versicolor 75 | 6.4,2.9,4.3,1.3,Iris-versicolor 76 | 6.6,3.0,4.4,1.4,Iris-versicolor 77 | 6.8,2.8,4.8,1.4,Iris-versicolor 78 | 6.7,3.0,5.0,1.7,Iris-versicolor 79 | 6.0,2.9,4.5,1.5,Iris-versicolor 80 | 5.7,2.6,3.5,1.0,Iris-versicolor 81 | 5.5,2.4,3.8,1.1,Iris-versicolor 82 | 5.5,2.4,3.7,1.0,Iris-versicolor 83 | 5.8,2.7,3.9,1.2,Iris-versicolor 84 | 6.0,2.7,5.1,1.6,Iris-versicolor 85 | 5.4,3.0,4.5,1.5,Iris-versicolor 86 | 6.0,3.4,4.5,1.6,Iris-versicolor 87 | 6.7,3.1,4.7,1.5,Iris-versicolor 88 | 6.3,2.3,4.4,1.3,Iris-versicolor 89 | 5.6,3.0,4.1,1.3,Iris-versicolor 90 | 5.5,2.5,4.0,1.3,Iris-versicolor 91 | 5.5,2.6,4.4,1.2,Iris-versicolor 92 | 
6.1,3.0,4.6,1.4,Iris-versicolor 93 | 5.8,2.6,4.0,1.2,Iris-versicolor 94 | 5.0,2.3,3.3,1.0,Iris-versicolor 95 | 5.6,2.7,4.2,1.3,Iris-versicolor 96 | 5.7,3.0,4.2,1.2,Iris-versicolor 97 | 5.7,2.9,4.2,1.3,Iris-versicolor 98 | 6.2,2.9,4.3,1.3,Iris-versicolor 99 | 5.1,2.5,3.0,1.1,Iris-versicolor 100 | 5.7,2.8,4.1,1.3,Iris-versicolor 101 | 6.3,3.3,6.0,2.5,Iris-virginica 102 | 5.8,2.7,5.1,1.9,Iris-virginica 103 | 7.1,3.0,5.9,2.1,Iris-virginica 104 | 6.3,2.9,5.6,1.8,Iris-virginica 105 | 6.5,3.0,5.8,2.2,Iris-virginica 106 | 7.6,3.0,6.6,2.1,Iris-virginica 107 | 4.9,2.5,4.5,1.7,Iris-virginica 108 | 7.3,2.9,6.3,1.8,Iris-virginica 109 | 6.7,2.5,5.8,1.8,Iris-virginica 110 | 7.2,3.6,6.1,2.5,Iris-virginica 111 | 6.5,3.2,5.1,2.0,Iris-virginica 112 | 6.4,2.7,5.3,1.9,Iris-virginica 113 | 6.8,3.0,5.5,2.1,Iris-virginica 114 | 5.7,2.5,5.0,2.0,Iris-virginica 115 | 5.8,2.8,5.1,2.4,Iris-virginica 116 | 6.4,3.2,5.3,2.3,Iris-virginica 117 | 6.5,3.0,5.5,1.8,Iris-virginica 118 | 7.7,3.8,6.7,2.2,Iris-virginica 119 | 7.7,2.6,6.9,2.3,Iris-virginica 120 | 6.0,2.2,5.0,1.5,Iris-virginica 121 | 6.9,3.2,5.7,2.3,Iris-virginica 122 | 5.6,2.8,4.9,2.0,Iris-virginica 123 | 7.7,2.8,6.7,2.0,Iris-virginica 124 | 6.3,2.7,4.9,1.8,Iris-virginica 125 | 6.7,3.3,5.7,2.1,Iris-virginica 126 | 7.2,3.2,6.0,1.8,Iris-virginica 127 | 6.2,2.8,4.8,1.8,Iris-virginica 128 | 6.1,3.0,4.9,1.8,Iris-virginica 129 | 6.4,2.8,5.6,2.1,Iris-virginica 130 | 7.2,3.0,5.8,1.6,Iris-virginica 131 | 7.4,2.8,6.1,1.9,Iris-virginica 132 | 7.9,3.8,6.4,2.0,Iris-virginica 133 | 6.4,2.8,5.6,2.2,Iris-virginica 134 | 6.3,2.8,5.1,1.5,Iris-virginica 135 | 6.1,2.6,5.6,1.4,Iris-virginica 136 | 7.7,3.0,6.1,2.3,Iris-virginica 137 | 6.3,3.4,5.6,2.4,Iris-virginica 138 | 6.4,3.1,5.5,1.8,Iris-virginica 139 | 6.0,3.0,4.8,1.8,Iris-virginica 140 | 6.9,3.1,5.4,2.1,Iris-virginica 141 | 6.7,3.1,5.6,2.4,Iris-virginica 142 | 6.9,3.1,5.1,2.3,Iris-virginica 143 | 5.8,2.7,5.1,1.9,Iris-virginica 144 | 6.8,3.2,5.9,2.3,Iris-virginica 145 | 6.7,3.3,5.7,2.5,Iris-virginica 146 | 
6.7,3.0,5.2,2.3,Iris-virginica 147 | 6.3,2.5,5.0,1.9,Iris-virginica 148 | 6.5,3.0,5.2,2.0,Iris-virginica 149 | 6.2,3.4,5.4,2.3,Iris-virginica 150 | 5.9,3.0,5.1,1.8,Iris-virginica 151 | -------------------------------------------------------------------------------- /lib/brains/net.rb: -------------------------------------------------------------------------------- 1 | require 'java' 2 | 3 | module Brains 4 | class Net 5 | attr_accessor :nn, :config 6 | 7 | def self.create(input, output, total_hidden_layers = 1, opts = {}) 8 | neurons_per_layer = opts[:neurons_per_layer] || 5 9 | 10 | config = com.dayosoft.nn.NeuralNet::Config.new(input, output, total_hidden_layers * neurons_per_layer + output) 11 | config.bias = opts[:bias] || 1.0 12 | config.outputBias = opts[:output_bias] || 1.0 13 | config.learningRate = opts[:learning_rate] || 0.1 14 | config.neuronsPerLayer = neurons_per_layer 15 | config.momentumFactor = opts[:momentum_factor] || 0.5 16 | config.isRecurrent = opts[:recurrent] || false 17 | config.backPropagationAlgorithm = opt_t_back_alg(opts[:train_method] || :standard) 18 | config.activationFunctionType = opt_to_func(opts[:activation_function] || :htan) 19 | config.outputActivationFunctionType = opt_to_func(opts[:output_function] || :sigmoid) 20 | config.errorFormula = opt_to_error_func(opts[:error] || :mean_squared) 21 | nn = com.dayosoft.nn.NeuralNet.new(config); 22 | 23 | Brains::Net.new.set_nn(nn).set_config(config) 24 | end 25 | 26 | def self.load(json_string) 27 | nn = com.dayosoft.nn.NeuralNet::loadStateFromJsonString(nil, json_string) 28 | config = nn.getConfig 29 | 30 | Brains::Net.new.set_nn(nn).set_config(config) 31 | end 32 | 33 | def randomize_weights(min = -1, max = 1) 34 | @nn.randomizeWeights(min, max) 35 | end 36 | 37 | def dump_weights 38 | @nn.dumpWeights.to_a.map(&:to_a) 39 | end 40 | 41 | def dump_biases 42 | @nn.dumpWeights.to_a.map(&:to_a) 43 | end 44 | 45 | def optimize(test_cases, target_error = 0.01, max_epoch = 1_000_000_000, 
callback_interval = 1000, &callback) 46 | inputs = [] 47 | outputs = [] 48 | 49 | test_cases.each do |item| 50 | inputs << item[0].to_java(Java::double) 51 | outputs << item[1].to_java(Java::double) 52 | end 53 | 54 | result = @nn.optimize(java.util.ArrayList.new(inputs), java.util.ArrayList.new(outputs), target_error, max_epoch, 55 | callback_interval, callback) 56 | { iterations: result.first, error: result.second } 57 | end 58 | 59 | def optimize_recurrent(test_cases, target_error = 0.01, max_epoch = 1_000_000_000, 60 | callback_interval = 1000, max_layers = 0, &callback) 61 | inputs = [] 62 | outputs = [] 63 | 64 | input_set = java.util.ArrayList.new 65 | output_set = java.util.ArrayList.new 66 | 67 | test_cases.each do |item| 68 | 69 | inputs = java.util.ArrayList.new 70 | outputs = java.util.ArrayList.new 71 | 72 | item[0].each do |item| 73 | inputs.add(item.to_java(Java::double)) 74 | end 75 | 76 | item[1].each do |item| 77 | outputs.add(item.to_java(Java::double)) 78 | end 79 | 80 | input_set.add(inputs) 81 | output_set.add(outputs) 82 | end 83 | 84 | result = @nn.optimizeRecurrent(input_set, output_set, target_error, max_layers, max_epoch, 85 | callback_interval, callback) 86 | 87 | { iterations: result.first, error: result.second } 88 | end 89 | 90 | def feed(input) 91 | # recurrent mode when array of array is passed. 92 | if input && input.size > 0 && input[0].kind_of?(Array) 93 | feed_recurrent(input) 94 | else 95 | result = @nn.feed(input.to_java(Java::double)).to_a 96 | @nn.updatePreviousOutputs if config.isRecurrent 97 | 98 | result 99 | end 100 | end 101 | 102 | # For a recurrent network, this resets hidden states back to zero 103 | def reset 104 | if config.isRecurrent 105 | @nn.resetRecurrentStates 106 | else 107 | puts "Warning not a recurrent network. 
This does nothing" 108 | end 109 | end 110 | 111 | def feed_recurrent(inputs) 112 | inputset = java.util.ArrayList.new 113 | inputs.each do |input| 114 | inputset.add(input.to_java(Java::double)) 115 | end 116 | 117 | output_set = @nn.feed(inputset).to_a 118 | output_set.collect do |output| 119 | output.to_a 120 | end 121 | end 122 | 123 | def to_json 124 | @nn.saveStateToJson 125 | end 126 | 127 | def set_nn(nn) 128 | @nn = nn 129 | self 130 | end 131 | 132 | def set_config(config) 133 | @config = config 134 | self 135 | end 136 | 137 | protected 138 | 139 | def initialize 140 | end 141 | 142 | private 143 | 144 | def self.opt_to_func(func) 145 | case func 146 | when :htan 147 | com.dayosoft.nn.Neuron::HTAN 148 | when :sigmoid 149 | com.dayosoft.nn.Neuron::SIGMOID 150 | when :softmax 151 | com.dayosoft.nn.Neuron::SOFTMAX 152 | when :rectifier 153 | com.dayosoft.nn.Neuron::RECTIFIER 154 | else 155 | raise "invalid activation function #{func}" 156 | end 157 | end 158 | 159 | def self.opt_t_back_alg(func) 160 | case func 161 | when :standard 162 | com.dayosoft.nn.NeuralNet::Config::STANDARD_BACKPROPAGATION 163 | when :rprop 164 | com.dayosoft.nn.NeuralNet::Config::RPROP_BACKPROPAGATION 165 | else 166 | raise "Invalid backpropagation method #{func}" 167 | end 168 | end 169 | 170 | def self.opt_to_error_func(func) 171 | case func 172 | when :mean_squared 173 | com.dayosoft.nn.NeuralNet::Config::MEAN_SQUARED 174 | when :cross_entropy 175 | com.dayosoft.nn.NeuralNet::Config::CROSS_ENTROPY 176 | else 177 | raise "Invalid Error Function #{func}" 178 | end 179 | end 180 | end 181 | end 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Brains 2 | 3 | A Feedforward neural network toolkit for JRuby. Easily add machine learning features 4 | to your ruby application using this Gem. 
Though there are faster native C implementations 5 | available (e.g. FANN), we need something that is simple, beginner-friendly and just works. 6 | 7 | This Java-based implementation provides a balance of performance and ease of use. 8 | 9 | ## Installation 10 | 11 | Do note that this gem requires JRuby as it uses a Java backend to run the neural network 12 | computations. 13 | 14 | Add this line to your application's Gemfile: 15 | 16 | ```ruby 17 | gem 'brains' 18 | ``` 19 | 20 | And then execute: 21 | 22 | $ bundle 23 | 24 | Or install it yourself as: 25 | 26 | $ gem install brains 27 | 28 | ## Features 29 | 30 | * Customizable network parameters depending on requirements 31 | * Fast (a bit slower than FANN but significantly faster than a pure Ruby implementation) 32 | * NN backend implementation in Java which allows for a platform-agnostic implementation 33 | 34 | ## Usage 35 | 36 | The brains gem contains facilities for training and using the feedforward neural network. 37 | 38 | Training (XOR example) 39 | -------- 40 | 41 | Initialize the neural net backend: 42 | 43 | ```ruby 44 | require 'brains' 45 | 46 | 47 | # Build a 3-layer network: 2 input neurons, 4 hidden neurons, 1 output neuron. 48 | # Bias neurons are automatically added to input + hidden layers; no need to specify these. 49 | 50 | nn = Brains::Net.create(2, 1, 1, { neurons_per_layer: 4 }) # (inputs, outputs, hidden layers, options) 51 | nn.randomize_weights 52 | ``` 53 | 54 | Consider that we want to train the neural network to handle XOR computations: 55 | 56 | ``` 57 | A B A XOR B 58 | 1 1 0 59 | 1 0 1 60 | 0 1 1 61 | 0 0 0 62 | ``` 63 | 64 | First we build the training data.
This is an array of arrays with each item 65 | in the following format: 66 | 67 | ``` 68 | [ 69 | [[input1, input2, input3....], [expected1, expected2, expected3 ...]] 70 | [[input1, input2, input3....], [expected1, expected2, expected3 ...]] 71 | ] 72 | ``` 73 | 74 | ```ruby 75 | training_data = [ 76 | [[0.9, 0.9], [0.1]], 77 | [[0.9, 0.1], [0.9]], 78 | [[0.1, 0.9], [0.9]], 79 | [[0.1, 0.1], [0.1]], 80 | ] 81 | ``` 82 | Note that we map 1 = 0.9 and 0 = 0.1 since using absolute 1 and 0s might cause 83 | issues with certain neural networks. There are other techniques to "normalize" 84 | input, but this is beyond the scope of this example. 85 | 86 | Start training on the data by calling optimize. Here we use 0.01 as the expected 87 | MSE error before terminating and 1000 as the max epochs. 88 | 89 | ```ruby 90 | result = nn.optimize(training_data, 0.01, 1_000 ) { |i, error| 91 | puts "#{i} #{error}" 92 | } 93 | ``` 94 | 95 | To test the neural network you can call the feed method. 96 | 97 | nn.feed( [test_input1, test_input2, .....]) => [output1, output2, ...] 98 | 99 | Check if the network is trained. There are more advanced and proper techniques to check if 100 | a network is sufficiently trained, but this is beyond the scope of this example. 
101 | 102 | ```ruby 103 | # test on untrained data 104 | test_data = [ 105 | [0.9, 0.9], 106 | [0.9, 0.1], 107 | [0.1, 0.9], 108 | [0.1, 0.1] 109 | ] 110 | 111 | results = test_data.collect { |item| 112 | nn.feed(item) 113 | } 114 | 115 | p results 116 | 117 | [[0.19717958808009528], [0.7983320405281495], [0.8386219299757574], [0.16609147896733775]] 118 | ``` 119 | 120 | Using the test data we can see the correlation, and the neural network function now approximates 121 | the XOR function with the desired error: 122 | 123 | ``` 124 | [0.9, 0.9] => [0.19717958808009528] 125 | [0.9, 0.1] => [0.7983320405281495] 126 | [0.1, 0.9] => [0.8386219299757574] 127 | [0.1, 0.1] => [0.16609147896733775] 128 | ``` 129 | 130 | Saving brain state 131 | ================== 132 | 133 | Save the neural network state at any time to a string using `to_json`: 134 | 135 | ```ruby 136 | saved_state = nn.to_json 137 | ``` 138 | 139 | You can then save it to a file, and later load it back using `load()`: 140 | 141 | ```ruby 142 | nn = Brains::Net.load(saved_state) 143 | 144 | 145 | # use 146 | nn.feed([0.9, 0.9]) 147 | ``` 148 | 149 | For other samples please take a look at the example folder. 150 | 151 | The Java Neural Network backend is based on: 152 | 153 | https://github.com/jedld/brains 154 | 155 | You can compile the Java source code as brains.jar and use it directly with this gem. 156 | 157 | ## RNNs (Recurrent Neural Networks) 158 | 159 | For recurrent neural networks, look at the sine function in the examples. Only 160 | the backpropagation through time (BPTT) training algorithm is 161 | supported for now. 162 | 163 | ## Development 164 | 165 | After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment. 166 | 167 | To install this gem onto your local machine, run `bundle exec rake install`.
To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org). 168 | 169 | ## Resources 170 | 171 | Machine learning is still a rapidly evolving field and research is ongoing on various aspects of it. This is just the tip of the iceberg, the field of machine learning is extremely complex, below are various resources for the average developer to get started: 172 | 173 | ftp://ftp.sas.com/pub/neural/FAQ.html#questions 174 | 175 | ## Contributing 176 | 177 | Bug reports and pull requests are welcome on GitHub at https://github.com/[USERNAME]/brains. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct. 178 | 179 | 180 | ## License 181 | 182 | The gem is available as open source under the terms of the [MIT License](http://opensource.org/licenses/MIT). 183 | --------------------------------------------------------------------------------