├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── src ├── lib.rs └── main.rs └── tests └── xor.rs /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /target/ 3 | **/*.rs.bk 4 | mnist/*.csv 5 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "bitflags" 3 | version = "1.0.1" 4 | source = "registry+https://github.com/rust-lang/crates.io-index" 5 | 6 | [[package]] 7 | name = "cloudabi" 8 | version = "0.0.3" 9 | source = "registry+https://github.com/rust-lang/crates.io-index" 10 | dependencies = [ 11 | "bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 12 | ] 13 | 14 | [[package]] 15 | name = "csv" 16 | version = "1.0.0-beta.5" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | dependencies = [ 19 | "csv-core 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 20 | "serde 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)", 21 | ] 22 | 23 | [[package]] 24 | name = "csv-core" 25 | version = "0.1.4" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | dependencies = [ 28 | "memchr 2.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 29 | ] 30 | 31 | [[package]] 32 | name = "either" 33 | version = "1.5.0" 34 | source = "registry+https://github.com/rust-lang/crates.io-index" 35 | 36 | [[package]] 37 | name = "fuchsia-zircon" 38 | version = "0.3.3" 39 | source = "registry+https://github.com/rust-lang/crates.io-index" 40 | dependencies = [ 41 | "bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 42 | "fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 43 | ] 44 | 45 | [[package]] 46 | name = "fuchsia-zircon-sys" 47 | version = "0.3.3" 48 | source = "registry+https://github.com/rust-lang/crates.io-index" 49 
| 50 | [[package]] 51 | name = "itertools" 52 | version = "0.7.8" 53 | source = "registry+https://github.com/rust-lang/crates.io-index" 54 | dependencies = [ 55 | "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", 56 | ] 57 | 58 | [[package]] 59 | name = "libc" 60 | version = "0.2.40" 61 | source = "registry+https://github.com/rust-lang/crates.io-index" 62 | 63 | [[package]] 64 | name = "matrixmultiply" 65 | version = "0.1.14" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | dependencies = [ 68 | "rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 69 | ] 70 | 71 | [[package]] 72 | name = "memchr" 73 | version = "2.0.1" 74 | source = "registry+https://github.com/rust-lang/crates.io-index" 75 | dependencies = [ 76 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 77 | ] 78 | 79 | [[package]] 80 | name = "ndarray" 81 | version = "0.11.2" 82 | source = "registry+https://github.com/rust-lang/crates.io-index" 83 | dependencies = [ 84 | "itertools 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)", 85 | "matrixmultiply 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)", 86 | "num-complex 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)", 87 | "num-traits 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)", 88 | ] 89 | 90 | [[package]] 91 | name = "num-complex" 92 | version = "0.1.43" 93 | source = "registry+https://github.com/rust-lang/crates.io-index" 94 | dependencies = [ 95 | "num-traits 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 96 | ] 97 | 98 | [[package]] 99 | name = "num-traits" 100 | version = "0.1.43" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | dependencies = [ 103 | "num-traits 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 104 | ] 105 | 106 | [[package]] 107 | name = "num-traits" 108 | version = "0.2.2" 109 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 110 | 111 | [[package]] 112 | name = "rand" 113 | version = "0.5.0-pre.0" 114 | source = "registry+https://github.com/rust-lang/crates.io-index" 115 | dependencies = [ 116 | "cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)", 117 | "fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 118 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 119 | "rand_core 0.1.0-pre.0 (registry+https://github.com/rust-lang/crates.io-index)", 120 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 121 | ] 122 | 123 | [[package]] 124 | name = "rand_core" 125 | version = "0.1.0-pre.0" 126 | source = "registry+https://github.com/rust-lang/crates.io-index" 127 | 128 | [[package]] 129 | name = "rawpointer" 130 | version = "0.1.0" 131 | source = "registry+https://github.com/rust-lang/crates.io-index" 132 | 133 | [[package]] 134 | name = "serde" 135 | version = "1.0.37" 136 | source = "registry+https://github.com/rust-lang/crates.io-index" 137 | 138 | [[package]] 139 | name = "tsetlin_machine" 140 | version = "0.1.0" 141 | dependencies = [ 142 | "csv 1.0.0-beta.5 (registry+https://github.com/rust-lang/crates.io-index)", 143 | "ndarray 0.11.2 (registry+https://github.com/rust-lang/crates.io-index)", 144 | "rand 0.5.0-pre.0 (registry+https://github.com/rust-lang/crates.io-index)", 145 | ] 146 | 147 | [[package]] 148 | name = "winapi" 149 | version = "0.3.4" 150 | source = "registry+https://github.com/rust-lang/crates.io-index" 151 | dependencies = [ 152 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 153 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 154 | ] 155 | 156 | [[package]] 157 | name = "winapi-i686-pc-windows-gnu" 158 | version = "0.4.0" 159 | source = "registry+https://github.com/rust-lang/crates.io-index" 160 | 161 | [[package]] 162 | name = 
"winapi-x86_64-pc-windows-gnu" 163 | version = "0.4.0" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | 166 | [metadata] 167 | "checksum bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b3c30d3802dfb7281680d6285f2ccdaa8c2d8fee41f93805dba5c4cf50dc23cf" 168 | "checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" 169 | "checksum csv 1.0.0-beta.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e7a9e063dcebdb56c306f23e672bfd31df3da8ec5f6d696b35f2c29c2a9572f0" 170 | "checksum csv-core 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4dd8e6d86f7ba48b4276ef1317edc8cc36167546d8972feb4a2b5fec0b374105" 171 | "checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" 172 | "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" 173 | "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" 174 | "checksum itertools 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)" = "f58856976b776fedd95533137617a02fb25719f40e7d9b01c7043cd65474f450" 175 | "checksum libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)" = "6fd41f331ac7c5b8ac259b8bf82c75c0fb2e469bbf37d2becbba9a6a2221965b" 176 | "checksum matrixmultiply 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "cac1a66eab356036af85ea093101a14223dc6e3f4c02a59b7d572e5b93270bf7" 177 | "checksum memchr 2.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "796fba70e76612589ed2ce7f45282f5af869e0fdd7cc6199fa1aa1f1d591ba9d" 178 | "checksum ndarray 0.11.2 (registry+https://github.com/rust-lang/crates.io-index)" = 
"0e3d24c5ba54015d7d5203ca6f00d4cc16c71042bf7f7be26f091236f390a16a" 179 | "checksum num-complex 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "b288631d7878aaf59442cffd36910ea604ecd7745c36054328595114001c9656" 180 | "checksum num-traits 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31" 181 | "checksum num-traits 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dee092fcdf725aee04dd7da1d21debff559237d49ef1cb3e69bcb8ece44c7364" 182 | "checksum rand 0.5.0-pre.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bdabe8654fc412303409a9ddb23cd8878cd54b1f0a7112b6cc95eb9fe684f3f1" 183 | "checksum rand_core 0.1.0-pre.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2362a41734390a5953cfbf12dbb74b4573137f7ba9dad344bba804ea4355f23a" 184 | "checksum rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ebac11a9d2e11f2af219b8b8d833b76b1ea0e054aa0e8d8e9e4cbde353bdf019" 185 | "checksum serde 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "d3bcee660dcde8f52c3765dd9ca5ee36b4bf35470a738eb0bd5a8752b0389645" 186 | "checksum winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "04e3bd221fcbe8a271359c04f21a76db7d0c6028862d1bb5512d85e1e2eb5bb3" 187 | "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 188 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 189 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tsetlin_machine" 3 | version = "0.1.0" 4 | authors = ["Khaled Sharif "] 5 | 6 | [dependencies] 7 | 
rand = "0.5.0-pre.0" 8 | csv = "1.0.0-beta.5" 9 | ndarray = "0.11.2" 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Tsetlin Machine implementation in Rust 2 | 3 | A "Tsetlin Machine" solves complex pattern recognition problems with easy-to-interpret 4 | propositional formulas, and is composed of a collective of 5 | [Tsetlin Automata](https://en.wikipedia.org/wiki/Learning_automata). The idea of 6 | the machine was proposed in 7 | [a paper by Ole-Christoffer Granmo](https://arxiv.org/abs/1804.01508). 8 | 9 | ## Running the XOR test 10 | 11 | 12 | - Clone this repository using `git clone https://github.com/KhaledSharif/TsetlinMachine.git` 13 | - Inside the repository root folder, run `cargo test` 14 | - The test will run the XOR example found in `tests/xor.rs` 15 | - The test will only pass if the Tsetlin Machine reaches an accuracy greater than 99% on XOR 16 | 17 | ## Training and testing on MNIST 18 | 19 | - Get [the MNIST data from Kaggle](https://www.kaggle.com/c/digit-recognizer/data) in CSV form 20 | - Create a folder called `mnist` in the same folder that contains `src` and `tests` 21 | - Copy `train.csv` and `test.csv` into the newly created `mnist` folder 22 | - Run `cargo run` from the repository root folder 23 | - Read the code inside `src/main.rs` to get a better understanding 24 | 25 | ## Example XOR code 26 | 27 | ```rust 28 | let mut tm = TsetlinMachine::new(); 29 | tm.create(2, 2, 10); 30 | 31 | let mut rng = thread_rng(); 32 | let mut average_error : f32 = 1.0; 33 | 34 | for e in 0..1000 35 | { 36 | let input_vector = &inputs[e % 4]; 37 | { 38 | let output_vector = tm.activate(input_vector.to_vec()); 39 | let mut correct = false; 40 | if (input_vector[0] == input_vector[1]) && (!output_vector[0] && output_vector[1]) 41 | { 42 | correct = true; 43 | } 44 | else if output_vector[0] && !output_vector[1] 45 | { 46 | 
correct = true; 47 | } 48 | average_error = 0.99 * average_error + 0.01 * (if !correct {1.0} else {0.0}); 49 | println!("{} {} -> {} {} | {}", input_vector[0], input_vector[1], output_vector[0], output_vector[1], average_error); 50 | } 51 | tm.learn(&outputs[e % 4], 4.0, 4.0, &mut rng); 52 | } 53 | ``` 54 | 55 | ## Example XOR output 56 | 57 | ``` 58 | true true -> false true | 0.00007643679 59 | false false -> false true | 0.00007567242 60 | false true -> true false | 0.0000749157 61 | true false -> true false | 0.00007416654 62 | true true -> false true | 0.000073424875 63 | ``` 64 | 65 | ## Original implementation 66 | 67 | This repository is [a translation of this repository](https://github.com/222464/TsetlinMachine), 68 | which is an implementation of the Tsetlin Machine in C++. 69 | 70 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate rand; 2 | use rand::Rng; 3 | use rand::ThreadRng; 4 | use rand::distributions::Uniform; 5 | use std::cmp::min; 6 | 7 | pub struct TsetlinMachine // Tsetlin machine: the latest input and output vectors plus one group of clauses per output 8 | { 9 | input_states : Vec<bool>, // input vector stored by the most recent activate() call 10 | output_states : Vec<bool>, // one boolean per output, refreshed by activate() 11 | outputs : Vec<Output>, // clauses and signed clause sum for each output 12 | } 13 | 14 | impl TsetlinMachine 15 | { 16 | pub fn new() -> TsetlinMachine // Returns an empty machine; create() gives it its inputs, outputs and clauses 17 | { 18 | TsetlinMachine 19 | { 20 | input_states : Vec::new(), 21 | output_states : Vec::new(), 22 | outputs : Vec::new(), 23 | } 24 | } 25 | 26 | fn inclusion_update(&mut self, oi : usize, ci : usize, ai : usize) // Re-syncs the inclusion list of clause ci (output oi) with automaton ai: listed exactly when its state is above zero 27 | { 28 | let inclusion = self.outputs[oi].clauses[ci].automata_states[ai] > 0; 29 | let it = self.outputs[oi].clauses[ci].inclusions.iter().position(|&s| s == ai); 30 | if inclusion 31 | { 32 | if it.is_none() 33 | { 34 | self.outputs[oi].clauses[ci].inclusions.push(ai); 35 | } 36 | } 37 | else 38 | { 39 | if let Some(position) = it // automaton is listed but no longer included: drop it from the list 40 | { 41 | self.outputs[oi].clauses[ci].inclusions.remove(position); 42 | } 43 | } 44 | } 45 | 46 | fn modify_phase_one(&mut self, oi : usize, ci : 
usize, s_inverse : f32, s_inverse_conjugate : f32, rng : &mut ThreadRng) // Randomised feedback for clause ci of output oi: each automaton is compared against its own uniform sample and moves up with probability s_inverse_conjugate or down with probability s_inverse 47 | { 48 | let clause_state = self.outputs[oi].clauses[ci].state; 49 | for ai in 0..self.outputs[oi].clauses[ci].automata_states.len() 50 | { 51 | let input = if ai >= self.input_states.len() {!self.input_states[ai - self.input_states.len()]} else {self.input_states[ai]}; // automata in the upper half read the negated input 52 | let inclusion = self.outputs[oi].clauses[ci].automata_states[ai] > 0; 53 | let s : f32 = rng.sample(Uniform); 54 | if clause_state // clause fired on the last activation 55 | { 56 | if input // literal is true: step towards inclusion, whether or not it is currently included 57 | { 58 | if inclusion 59 | { 60 | if s < s_inverse_conjugate 61 | { 62 | self.outputs[oi].clauses[ci].automata_states[ai] += 1; 63 | self.inclusion_update(oi, ci, ai); 64 | } 65 | } 66 | else 67 | { 68 | if s < s_inverse_conjugate 69 | { 70 | self.outputs[oi].clauses[ci].automata_states[ai] += 1; 71 | self.inclusion_update(oi, ci, ai); 72 | } 73 | } 74 | } 75 | else 76 | { 77 | if !inclusion && s < s_inverse // literal is false: only excluded automata step further down 78 | { 79 | self.outputs[oi].clauses[ci].automata_states[ai] -= 1; 80 | self.inclusion_update(oi, ci, ai); 81 | } 82 | } 83 | } 84 | else // clause did not fire: every automaton steps down with probability s_inverse, in all four branches below 85 | { 86 | if input 87 | { 88 | if inclusion 89 | { 90 | if s < s_inverse 91 | { 92 | self.outputs[oi].clauses[ci].automata_states[ai] -= 1; 93 | self.inclusion_update(oi, ci, ai); 94 | } 95 | } 96 | else 97 | { 98 | if s < s_inverse 99 | { 100 | self.outputs[oi].clauses[ci].automata_states[ai] -= 1; 101 | self.inclusion_update(oi, ci, ai); 102 | } 103 | } 104 | } 105 | else { 106 | if inclusion 107 | { 108 | if s < s_inverse 109 | { 110 | self.outputs[oi].clauses[ci].automata_states[ai] -= 1; 111 | self.inclusion_update(oi, ci, ai); 112 | } 113 | } 114 | else 115 | { 116 | if s < s_inverse 117 | { 118 | self.outputs[oi].clauses[ci].automata_states[ai] -= 1; 119 | self.inclusion_update(oi, ci, ai); 120 | } 121 | } 122 | } 123 | } 124 | } 125 | } 126 | 127 | fn modify_phase_two(&mut self, oi : usize, ci : usize) // Deterministic feedback: if the clause fired, every excluded automaton whose literal is false moves one step up 128 | { 129 | let clause_state = self.outputs[oi].clauses[ci].state; 130 | for ai in 
0..self.outputs[oi].clauses[ci].automata_states.len() 131 | { 132 | let input = if ai >= self.input_states.len() {!self.input_states[ai - self.input_states.len()]} else {self.input_states[ai]}; // automata in the upper half read the negated input 133 | let inclusion = self.outputs[oi].clauses[ci].automata_states[ai] > 0; 134 | if clause_state && !input && !inclusion 135 | { 136 | self.outputs[oi].clauses[ci].automata_states[ai] += 1; 137 | self.inclusion_update(oi, ci, ai); 138 | } 139 | } 140 | } 141 | 142 | pub fn create(&mut self, number_of_inputs : usize, number_of_outputs : usize, clauses_per_output : usize) // Sizes the machine: number_of_outputs outputs with clauses_per_output clauses each, and two automata per input in every clause (new automata start at state 0) 143 | { 144 | self.input_states.resize(number_of_inputs, false); 145 | self.outputs.resize(number_of_outputs, create_null_output()); 146 | for oi in 0..number_of_outputs 147 | { 148 | self.outputs[oi].clauses.resize(clauses_per_output, create_null_clause()); 149 | for ci in 0..clauses_per_output 150 | { 151 | self.outputs[oi].clauses[ci].automata_states.resize(number_of_inputs * 2, 0); // lower half: inputs as given, upper half: negated inputs 152 | } 153 | } 154 | self.output_states.resize(number_of_outputs, false); 155 | } 156 | 157 | pub fn learn(&mut self, target_output_states : &Vec<bool>, s : f32, t : f32, rng : &mut ThreadRng) // Applies one round of feedback towards target_output_states, using the clause states and sums left by the last activate(); panics if target_output_states has fewer entries than there are outputs 158 | { 159 | let s_inv = 1.0 / s; 160 | let s_inv_conj = 1.0 - s_inv; 161 | for oi in 0..self.outputs.len() 162 | { 163 | let clamped_sum = t.min((-t).max(self.outputs[oi].sum as f32)); // clause sum limited to [-t, t] 164 | let rescale = 1.0 / ((2.0 * t) as f32); 165 | let probability_feedback_alpha = (t - clamped_sum) * rescale; // falls to 0 as the sum reaches +t 166 | let probability_feedback_beta = (t + clamped_sum) * rescale; // falls to 0 as the sum reaches -t 167 | 168 | for ci in 0..self.outputs[oi].clauses.len() 169 | { 170 | let s : f32 = rng.sample(Uniform); 171 | if ci % 2 == 0 // even-indexed clauses add to the output sum in activate() 172 | { 173 | if target_output_states[oi] 174 | { 175 | if s < probability_feedback_alpha 176 | { 177 | self.modify_phase_one(oi, ci, s_inv, s_inv_conj, rng); 178 | } 179 | } 180 | else if s < probability_feedback_beta 181 | { 182 | self.modify_phase_two(oi, ci); 183 | } 184 | } 185 | else // odd-indexed clauses subtract from the output sum, so the two feedback kinds swap 186 | { 187 | if target_output_states[oi] 188 | { 189 | if 
s < probability_feedback_alpha 190 | { 191 | self.modify_phase_two(oi, ci); 192 | } 193 | } 194 | else if s < probability_feedback_beta 195 | { 196 | self.modify_phase_one(oi, ci, s_inv, s_inv_conj, rng); 197 | } 198 | } 199 | } 200 | } 201 | } 202 | 203 | pub fn activate(&mut self, input_states : Vec<bool>) -> &Vec<bool> // Stores input_states, evaluates every clause, and returns one boolean per output: true when the signed clause sum is positive 204 | { 205 | self.input_states = input_states; 206 | for (outputs_index, outputs_element) in self.outputs.iter_mut().enumerate() 207 | { 208 | let mut sum = 0; 209 | for (clauses_index, clauses_element) in outputs_element.clauses.iter_mut().enumerate() 210 | { 211 | let mut state = true; // a clause with no included automata evaluates to true 212 | for cit in clauses_element.inclusions.iter() 213 | { 214 | let ai = *cit; 215 | if ai >= self.input_states.len() 216 | { 217 | state = state && !self.input_states[min(self.input_states.len() - 1, ai - self.input_states.len())]; // negated literal; min() keeps the index inside the input vector 218 | } 219 | else 220 | { 221 | state = state && self.input_states[ai]; 222 | } 223 | } 224 | clauses_element.state = state; 225 | { 226 | let _state = if state {1} else {0}; 227 | sum += if clauses_index % 2 == 0 {_state} else {-_state}; // even clauses vote for the output, odd clauses vote against it 228 | } 229 | } 230 | outputs_element.sum = sum; 231 | self.output_states[outputs_index] = sum > 0; 232 | } 233 | &self.output_states 234 | } 235 | } 236 | 237 | struct Clause // One conjunctive clause: its automata, the indices currently included, and its last evaluation 238 | { 239 | automata_states : Vec<i32>, // NOTE(review): element type was lost in extraction; i32 assumed (signed counter, initialised with 0) - confirm against upstream 240 | inclusions : Vec<usize>, // indices of automata whose state is above zero 241 | state : bool, // clause value computed by the last activate() 242 | } 243 | 244 | struct Output // All clauses feeding one output plus their signed sum from the last activate() 245 | { 246 | clauses : Vec<Clause>, 247 | sum : i32, 248 | } 249 | 250 | fn create_null_output() -> Output // Output with no clauses and a zero sum 251 | { 252 | Output 253 | { 254 | clauses: create_null_clauses_vector(), 255 | sum: 0, 256 | } 257 | } 258 | 259 | fn create_null_clauses_vector() -> Vec<Clause> // Empty clause list 260 | { 261 | Vec::new() 262 | } 263 | 264 | fn create_null_clause() -> Clause // Clause with no automata, no inclusions and a false state 265 | { 266 | Clause 267 | { 268 | automata_states: Vec::new(), 269 | inclusions: Vec::new(), 270 | state: false, 271 | } 272 | } 273 | 274 | impl Clone for Output 275 | { 276 | fn clone(&self) -> Output 277 | { 278 | let c = &self.clauses; 279 | Output 280 | { 281 | clauses: c.to_vec(), 282 | 
sum: self.sum, 283 | } 284 | } 285 | } 286 | 287 | impl Clone for Clause 288 | { 289 | fn clone(&self) -> Clause 290 | { 291 | let a = &self.automata_states; 292 | let i = &self.inclusions; 293 | Clause 294 | { 295 | automata_states: a.to_vec(), 296 | inclusions: i.to_vec(), 297 | state: self.state, 298 | } 299 | } 300 | } 301 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | extern crate rand; 2 | use rand::thread_rng; 3 | 4 | extern crate tsetlin_machine; 5 | use tsetlin_machine::TsetlinMachine; 6 | 7 | extern crate csv; 8 | use csv::Reader; 9 | 10 | fn main() // Loads the MNIST CSV files, then alternates a training pass and an evaluation pass forever, printing running accuracy estimates 11 | { 12 | let training_file_path = "mnist/train.csv"; 13 | let testing_file_path = "mnist/test.csv"; 14 | 15 | let (training_inputs, training_outputs) = read(training_file_path); 16 | 17 | println!("Training dataset"); 18 | println!("Inputs length: {}, Inputs[0] length: {}", training_inputs.len(), training_inputs[0].len()); 19 | println!("Outputs length: {}, Outputs[0] length: {}", training_outputs.len(), training_outputs[0].len()); 20 | 21 | let (testing_inputs, testing_outputs) = read(testing_file_path); 22 | 23 | println!("Testing dataset"); 24 | println!("Inputs length: {}, Inputs[0] length: {}", testing_inputs.len(), testing_inputs[0].len()); 25 | println!("Outputs length: {}, Outputs[0] length: {}", testing_outputs.len(), testing_outputs[0].len()); 26 | 27 | let mut tm = TsetlinMachine::new(); 28 | tm.create(training_inputs[0].len(), training_outputs[0].len(), 10); 29 | 30 | let mut rng = thread_rng(); 31 | let mut average_training_error : f32 = 1.0; // exponential moving average of the per-sample error, starting pessimistic 32 | let mut average_testing_error : f32 = 1.0; 33 | 34 | loop // no exit condition: runs until the process is interrupted 35 | { 36 | for e in 0 .. 
training_inputs.len() // every training sample, including the last one 37 | { 38 | let expected_output_vector = &training_outputs[e]; 39 | { 40 | { 41 | let input_vector = &training_inputs[e]; 42 | let correct = check_two_vectors(expected_output_vector, tm.activate(input_vector.to_vec())); 43 | average_training_error = 0.99 * average_training_error + 0.01 * (if !correct {1.0} else {0.0}); 44 | } 45 | 46 | if e % 10 == 0 47 | { 48 | println!( 49 | "{}% of training dataset | {}% training accuracy", 50 | (e as f32 / training_inputs.len() as f32) * 100.0, 51 | (1.0 - average_training_error) * 100.0, 52 | ); 53 | } 54 | } 55 | tm.learn(expected_output_vector, 4.0, 4.0, &mut rng); 56 | } 57 | 58 | for f in 0 .. testing_inputs.len() // every testing sample, including the last one; no learning happens in this pass 59 | { 60 | let expected_output_vector = &testing_outputs[f]; 61 | { 62 | { 63 | let input_vector = &testing_inputs[f]; 64 | let correct = check_two_vectors(expected_output_vector, tm.activate(input_vector.to_vec())); 65 | average_testing_error = 0.99 * average_testing_error + 0.01 * (if !correct {1.0} else {0.0}); 66 | } 67 | 68 | if f % 10 == 0 69 | { 70 | println!( 71 | "{}% of testing dataset | {}% testing accuracy", 72 | (f as f32 / testing_inputs.len() as f32) * 100.0, 73 | (1.0 - average_testing_error) * 100.0, 74 | ); 75 | } 76 | } 77 | } 78 | } 79 | } 80 | 81 | struct LabelWithData // One CSV row: one-hot label plus thresholded data values 82 | { 83 | label : Vec<bool>, 84 | data : Vec<bool>, 85 | } 86 | 87 | type BooleanMatrix = Vec<Vec<bool>>; 88 | type BooleanMatrixTuple = (BooleanMatrix, BooleanMatrix); 89 | 90 | fn read(file_path : &str) -> BooleanMatrixTuple // Reads a CSV file whose first column is the label and whose remaining columns are data; returns (inputs, one-hot outputs) and panics on unreadable or non-numeric content 91 | { 92 | fn one_hot_encoder(n: u16) -> Vec<bool> // 10-element vector with only index n set; panics when n is 10 or more 93 | { 94 | let mut m : Vec<bool> = vec![false; 10]; 95 | m[n as usize] = true; 96 | return m; 97 | } 98 | 99 | fn converter(list: Vec<LabelWithData>) -> BooleanMatrixTuple // Splits the rows into parallel input and output matrices 100 | { 101 | let mut inputs : BooleanMatrix = Vec::new(); 102 | let mut outputs : BooleanMatrix = Vec::new(); 103 | 104 | for l in list 105 | { 106 | inputs.push(l.data); 107 | outputs.push(l.label); 108 | } 109 | 110 | (inputs, outputs) 111 | } 112 | 113 | let 
rdr_with_error = Reader::from_path(file_path); 114 | assert!(!rdr_with_error.is_err()); 115 | 116 | let mut rdr = rdr_with_error.ok().unwrap(); 117 | 118 | let mut training_data : Vec = Vec::new(); 119 | for result in rdr.records() 120 | { 121 | let record = result.ok().unwrap(); 122 | let vector : Vec = record 123 | .iter() 124 | .map(|x| x.parse::().unwrap()) 125 | .collect(); 126 | 127 | { 128 | let label = one_hot_encoder(vector[0]); 129 | let data = vector[1..].to_vec().iter().map(|&x| x > 128).collect(); 130 | 131 | { 132 | let lwv = LabelWithData 133 | { 134 | label : label, 135 | data : data, 136 | }; 137 | training_data.push(lwv); 138 | } 139 | } 140 | } 141 | 142 | converter(training_data) 143 | } 144 | 145 | fn check_two_vectors(y_true : &Vec, y_pred : &Vec) -> bool 146 | { 147 | assert_eq!(y_true.len(), y_pred.len()); 148 | assert!(y_true.len() > 0); 149 | 150 | for i in 0..y_true.len() 151 | { 152 | if (y_true[i] && !y_pred[i]) || (!y_true[i] && y_pred[i]) 153 | { 154 | return false; 155 | } 156 | } 157 | return true; 158 | } 159 | -------------------------------------------------------------------------------- /tests/xor.rs: -------------------------------------------------------------------------------- 1 | extern crate tsetlin_machine; 2 | use tsetlin_machine::TsetlinMachine; 3 | 4 | extern crate rand; 5 | use rand::thread_rng; 6 | 7 | #[test] 8 | fn test_xor_convergence() 9 | { 10 | let inputs : Vec> = 11 | [ 12 | [ 0, 0 ], 13 | [ 0, 1 ], 14 | [ 1, 0 ], 15 | [ 1, 1 ], 16 | ] 17 | .iter() 18 | .map(|x| x.iter().map(|&y| y == 1).collect::>().to_vec()) 19 | .collect(); 20 | 21 | let outputs : Vec> = 22 | [ 23 | [ 0, 1 ], 24 | [ 1, 0 ], 25 | [ 1, 0 ], 26 | [ 0, 1 ], 27 | ] 28 | .iter() 29 | .map(|x| x.iter().map(|&y| y == 1).collect::>().to_vec()) 30 | .collect(); 31 | 32 | let mut tm = TsetlinMachine::new(); 33 | tm.create(2, 2, 10); 34 | 35 | let mut rng = thread_rng(); 36 | let mut average_error : f32 = 1.0; 37 | 38 | for e in 0..5000 39 | { 40 | 
let input_vector = &inputs[e % 4]; 41 | { 42 | let output_vector = tm.activate(input_vector.to_vec()); 43 | let mut correct = false; 44 | if (input_vector[0] == input_vector[1]) && (!output_vector[0] && output_vector[1]) 45 | { 46 | correct = true; 47 | } 48 | else if output_vector[0] && !output_vector[1] 49 | { 50 | correct = true; 51 | } 52 | average_error = 0.99 * average_error + 0.01 * (if !correct {1.0} else {0.0}); 53 | } 54 | tm.learn(&outputs[e % 4], 4.0, 4.0, &mut rng); 55 | if average_error < 0.01 56 | { 57 | break; 58 | } 59 | } 60 | 61 | assert!(average_error < 0.01); 62 | } --------------------------------------------------------------------------------