├── .editorconfig ├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── README.md ├── scripts ├── id_rsa.enc ├── travis-doc-upload.cfg └── travis-doc-upload.sh ├── src ├── activation_func.rs ├── cascade_params.rs ├── error.rs ├── error_func.rs ├── lib.rs ├── net_type.rs ├── stop_func.rs ├── train_algorithm.rs └── train_data.rs └── test_files └── xor.data /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | charset = utf-8 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | indent_style = space 9 | indent_size = 4 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | *.bk 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - stable 4 | - beta 5 | - nightly 6 | sudo: false 7 | script: 8 | - cargo build --verbose 9 | - cargo test --verbose 10 | - cargo build --features "double" --verbose 11 | - cargo test --features "double" --verbose 12 | - cargo doc --verbose 13 | after_success: curl https://raw.githubusercontent.com/afck/fann-rs/master/scripts/travis-doc-upload.sh | sh 14 | addons: 15 | apt: 16 | packages: 17 | - libfann-dev 18 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "fann" 3 | version = "0.1.8" 4 | authors = ["Andreas Fackler "] 5 | license = "LGPL-3.0" 6 | documentation = "https://afck.github.io/docs/fann-rs/fann" 7 | repository = "https://github.com/afck/fann-rs" 8 | readme = "README.md" 9 | description = """ 10 | Wrapper for the Fast Artificial Neural Networks library 11 | """ 12 | keywords = ["neural", "network", "fann", "classifier", "backpropagation"] 13 | 14 | [features] 15 | # Use double precision. 16 | double = ["fann-sys/double"] 17 | 18 | [dependencies] 19 | fann-sys = "0.1.8" 20 | libc = "~0.2.20" 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 
24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 
98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fann-rs 2 | 3 | [![Build Status](https://travis-ci.org/afck/fann-rs.svg?branch=master)](https://travis-ci.org/afck/fann-rs) 4 | [![Crates.io](http://meritbadge.herokuapp.com/fann)](https://crates.io/crates/fann) 5 | 6 | [Rust](http://www.rust-lang.org/) wrapper for the 7 | [Fast Artificial Neural Network](http://leenissen.dk/fann/wp/) (FANN) library. This crate provides a 8 | safe interface to FANN on top of the 9 | [low-level bindings fann-sys-rs](https://github.com/afck/fann-sys-rs). 10 | 11 | [Documentation](https://afck.github.io/docs/fann-rs/fann) 12 | 13 | 14 | ## Usage 15 | 16 | Add `fann` and `libc` to the list of dependencies in your `Cargo.toml`: 17 | 18 | ```toml 19 | [dependencies] 20 | fann = "*" 21 | libc = "*" 22 | ``` 23 | 24 | and this to your crate root: 25 | 26 | ```rust 27 | extern crate fann; 28 | extern crate libc; 29 | ``` 30 | 31 | Usage examples are included in the [Documentation](https://afck.github.io/docs/fann-rs/fann). 32 | -------------------------------------------------------------------------------- /scripts/id_rsa.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afck/fann-rs/bd824e813b7ad3c9d0c6cf30e579c2103e7879f9/scripts/id_rsa.enc -------------------------------------------------------------------------------- /scripts/travis-doc-upload.cfg: -------------------------------------------------------------------------------- 1 | PROJECT_NAME=fann-rs 2 | DOCS_REPO=afck/docs.git 3 | SSH_KEY_TRAVIS_ID=a33de41c8e99 4 | -------------------------------------------------------------------------------- /scripts/travis-doc-upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # License: CC0 1.0 Universal 4 | # https://creativecommons.org/publicdomain/zero/1.0/legalcode 5 | 6 | set -e 7 | 8 | . ./scripts/travis-doc-upload.cfg 9 | 10 | [ "$TRAVIS_BRANCH" = master ] 11 | 12 | [ "$TRAVIS_PULL_REQUEST" = false ] 13 | 14 | eval key=\$encrypted_${SSH_KEY_TRAVIS_ID}_key 15 | eval iv=\$encrypted_${SSH_KEY_TRAVIS_ID}_iv 16 | 17 | mkdir -p ~/.ssh 18 | openssl aes-256-cbc -K $key -iv $iv -in scripts/id_rsa.enc -out ~/.ssh/id_rsa -d 19 | chmod 600 ~/.ssh/id_rsa 20 | 21 | git clone --branch gh-pages git@github.com:$DOCS_REPO deploy_docs 22 | 23 | cd deploy_docs 24 | git config user.name "doc upload bot" 25 | git config user.email "nobody@example.com" 26 | rm -rf $PROJECT_NAME 27 | mv ../target/doc $PROJECT_NAME 28 | git add -A $PROJECT_NAME 29 | git commit -qm "doc upload for $PROJECT_NAME ($TRAVIS_REPO_SLUG)" 30 | git push -q origin gh-pages 31 | -------------------------------------------------------------------------------- /src/activation_func.rs: -------------------------------------------------------------------------------- 1 | pub use error::{FannError, FannErrorType, FannResult}; 2 | use fann_sys::*; 3 | 4 | /// The activation functions used for the neurons during training. 
They can either be set for a
5 | /// group of neurons using `set_activation_func_hidden` and `set_activation_func_output`, or for a
6 | /// single neuron using `set_activation_func`.
7 | ///
8 | /// Similarly, the steepness of an activation function is specified using
9 | /// `set_activation_steepness_hidden`, `set_activation_steepness_output` and
10 | /// `set_activation_steepness`.
11 | ///
12 | /// In the descriptions of the functions:
13 | ///
14 | /// * x is the input to the activation function,
15 | ///
16 | /// * y is the output,
17 | ///
18 | /// * s is the steepness and
19 | ///
20 | /// * d is the derivative.
21 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
22 | pub enum ActivationFunc {
23 | /// Linear activation function.
24 | ///
25 | /// * span: -inf < y < inf
26 | ///
27 | /// * y = x*s, d = 1*s
28 | ///
29 | /// * Can NOT be used in fixed point.
30 | Linear,
31 | /// Threshold activation function.
32 | ///
33 | /// * x < 0 -> y = 0, x >= 0 -> y = 1
34 | ///
35 | /// * Can NOT be used during training.
36 | Threshold,
37 | /// Symmetric threshold activation function.
38 | ///
39 | /// * x < 0 -> y = -1, x >= 0 -> y = 1
40 | ///
41 | /// * Can NOT be used during training.
42 | ThresholdSymmetric,
43 | /// Sigmoid activation function.
44 | ///
45 | /// * One of the most used activation functions.
46 | ///
47 | /// * span: 0 < y < 1
48 | ///
49 | /// * y = 1/(1 + exp(-2*s*x))
50 | ///
51 | /// * d = 2*s*y*(1 - y)
52 | Sigmoid,
53 | /// Stepwise linear approximation to sigmoid.
54 | ///
55 | /// * Faster than sigmoid but a bit less precise.
56 | SigmoidStepwise,
57 | /// Symmetric sigmoid activation function, a.k.a. tanh.
58 | ///
59 | /// * One of the most used activation functions.
60 | ///
61 | /// * span: -1 < y < 1
62 | ///
63 | /// * y = tanh(s*x) = 2/(1 + exp(-2*s*x)) - 1
64 | ///
65 | /// * d = s*(1-(y*y))
66 | SigmoidSymmetric,
67 | /// Stepwise linear approximation to symmetric sigmoid.
68 | ///
69 | /// * Faster than symmetric sigmoid but a bit less precise.
70 | SigmoidSymmetricStepwise,
71 | /// Gaussian activation function.
72 | ///
73 | /// * 0 when x = -inf, 1 when x = 0 and 0 when x = inf
74 | ///
75 | /// * span: 0 < y < 1
76 | ///
77 | /// * y = exp(-x*s*x*s)
78 | ///
79 | /// * d = -2*x*s*y*s
80 | Gaussian,
81 | /// Symmetric gaussian activation function.
82 | ///
83 | /// * -1 when x = -inf, 1 when x = 0 and -1 when x = inf
84 | ///
85 | /// * span: -1 < y < 1
86 | ///
87 | /// * y = exp(-x*s*x*s)*2-1
88 | ///
89 | /// * d = -2*x*s*(y+1)*s
90 | GaussianSymmetric,
91 | /// Stepwise linear approximation to gaussian.
92 | /// Faster than gaussian but a bit less precise.
93 | /// NOT implemented yet.
94 | GaussianStepwise,
95 | /// Fast (sigmoid-like) activation function defined by David Elliott.
96 | ///
97 | /// * span: 0 < y < 1
98 | ///
99 | /// * y = ((x*s) / 2) / (1 + |x*s|) + 0.5
100 | ///
101 | /// * d = s*1/(2*(1+|x*s|)*(1+|x*s|))
102 | Elliott,
103 | /// Fast (symmetric sigmoid-like) activation function defined by David Elliott.
104 | ///
105 | /// * span: -1 < y < 1
106 | ///
107 | /// * y = (x*s) / (1 + |x*s|)
108 | ///
109 | /// * d = s*1/((1+|x*s|)*(1+|x*s|))
110 | ElliottSymmetric,
111 | /// Bounded linear activation function.
112 | ///
113 | /// * span: 0 <= y <= 1
114 | ///
115 | /// * y = x*s, d = 1*s
116 | LinearPiece,
117 | /// Bounded linear activation function.
118 | ///
119 | /// * span: -1 <= y <= 1
120 | ///
121 | /// * y = x*s, d = 1*s
122 | LinearPieceSymmetric,
123 | /// Periodic sine activation function.
124 | ///
125 | /// * span: -1 <= y <= 1
126 | ///
127 | /// * y = sin(x*s)
128 | ///
129 | /// * d = s*cos(x*s)
130 | SinSymmetric,
131 | /// Periodic cosine activation function.
132 | ///
133 | /// * span: -1 <= y <= 1
134 | ///
135 | /// * y = cos(x*s)
136 | ///
137 | /// * d = s*-sin(x*s)
138 | CosSymmetric,
139 | /// Periodic sine activation function.
140 | ///
141 | /// * span: 0 <= y <= 1
142 | ///
143 | /// * y = sin(x*s)/2+0.5
144 | ///
145 | /// * d = s*cos(x*s)/2
146 | Sin,
147 | /// Periodic cosine activation function.
148 | ///
149 | /// * span: 0 <= y <= 1
150 | ///
151 | /// * y = cos(x*s)/2+0.5
152 | ///
153 | /// * d = s*-sin(x*s)/2
154 | Cos,
155 | }
156 |
157 | impl ActivationFunc {
158 | /// Create an `ActivationFunc` from a `fann_sys::fann_activationfunc_enum`.
159 | pub fn from_fann_activationfunc_enum(
160 | af_enum: fann_activationfunc_enum,
161 | ) -> FannResult<ActivationFunc> {
162 | match af_enum {
163 | FANN_NONE => Err(FannError {
164 | error_type: FannErrorType::IndexOutOfBound,
165 | error_str: "Neuron or layer index is out of bound.".to_owned(),
166 | }),
167 | FANN_LINEAR => Ok(ActivationFunc::Linear),
168 | FANN_THRESHOLD => Ok(ActivationFunc::Threshold),
169 | FANN_THRESHOLD_SYMMETRIC => Ok(ActivationFunc::ThresholdSymmetric),
170 | FANN_SIGMOID => Ok(ActivationFunc::Sigmoid),
171 | FANN_SIGMOID_STEPWISE => Ok(ActivationFunc::SigmoidStepwise),
172 | FANN_SIGMOID_SYMMETRIC => Ok(ActivationFunc::SigmoidSymmetric),
173 | FANN_SIGMOID_SYMMETRIC_STEPWISE => Ok(ActivationFunc::SigmoidSymmetricStepwise),
174 | FANN_GAUSSIAN => Ok(ActivationFunc::Gaussian),
175 | FANN_GAUSSIAN_SYMMETRIC => Ok(ActivationFunc::GaussianSymmetric),
176 | FANN_GAUSSIAN_STEPWISE => Ok(ActivationFunc::GaussianStepwise),
177 | FANN_ELLIOTT => Ok(ActivationFunc::Elliott),
178 | FANN_ELLIOTT_SYMMETRIC => Ok(ActivationFunc::ElliottSymmetric),
179 | FANN_LINEAR_PIECE => Ok(ActivationFunc::LinearPiece),
180 | FANN_LINEAR_PIECE_SYMMETRIC => Ok(ActivationFunc::LinearPieceSymmetric),
181 | FANN_SIN_SYMMETRIC => Ok(ActivationFunc::SinSymmetric),
182 | FANN_COS_SYMMETRIC => Ok(ActivationFunc::CosSymmetric),
183 | FANN_SIN => Ok(ActivationFunc::Sin),
184 | FANN_COS => Ok(ActivationFunc::Cos),
185 | }
186 | }
187 |
188 | /// Return the `fann_sys::fann_activationfunc_enum` corresponding to this `ActivationFunc`.
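///
/// A round trip through the sys enum returns the original variant; a minimal
/// sketch:
///
/// ```
/// use fann::ActivationFunc;
///
/// let af = ActivationFunc::Sigmoid;
/// let raw = af.to_fann_activationfunc_enum();
/// assert_eq!(ActivationFunc::from_fann_activationfunc_enum(raw), Ok(af));
/// ```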
189 | pub fn to_fann_activationfunc_enum(self) -> fann_activationfunc_enum { 190 | match self { 191 | ActivationFunc::Linear => FANN_LINEAR, 192 | ActivationFunc::Threshold => FANN_THRESHOLD, 193 | ActivationFunc::ThresholdSymmetric => FANN_THRESHOLD_SYMMETRIC, 194 | ActivationFunc::Sigmoid => FANN_SIGMOID, 195 | ActivationFunc::SigmoidStepwise => FANN_SIGMOID_STEPWISE, 196 | ActivationFunc::SigmoidSymmetric => FANN_SIGMOID_SYMMETRIC, 197 | ActivationFunc::SigmoidSymmetricStepwise => FANN_SIGMOID_SYMMETRIC_STEPWISE, 198 | ActivationFunc::Gaussian => FANN_GAUSSIAN, 199 | ActivationFunc::GaussianSymmetric => FANN_GAUSSIAN_SYMMETRIC, 200 | ActivationFunc::GaussianStepwise => FANN_GAUSSIAN_STEPWISE, 201 | ActivationFunc::Elliott => FANN_ELLIOTT, 202 | ActivationFunc::ElliottSymmetric => FANN_ELLIOTT_SYMMETRIC, 203 | ActivationFunc::LinearPiece => FANN_LINEAR_PIECE, 204 | ActivationFunc::LinearPieceSymmetric => FANN_LINEAR_PIECE_SYMMETRIC, 205 | ActivationFunc::SinSymmetric => FANN_SIN_SYMMETRIC, 206 | ActivationFunc::CosSymmetric => FANN_COS_SYMMETRIC, 207 | ActivationFunc::Sin => FANN_SIN, 208 | ActivationFunc::Cos => FANN_COS, 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /src/cascade_params.rs: -------------------------------------------------------------------------------- 1 | use activation_func::ActivationFunc; 2 | use fann_sys::fann_type; 3 | use libc::{c_float, c_uint}; 4 | 5 | /// Parameters for cascade training. 6 | #[derive(Clone, Debug, PartialEq)] 7 | pub struct CascadeParams { 8 | /// A number between 0 and 1 determining how large a fraction the mean square error should 9 | /// change within `output_stagnation_epochs` during training of the output connections, in 10 | /// order for the training to stagnate. After stagnation, training of the output connections 11 | /// ends and new candidates are prepared. 12 | /// 13 | /// This means: If the MSE does not change by a fraction of `output_change_fraction` during a 14 | /// period of `output_stagnation_epochs`, the training of the output connections is stopped 15 | /// because training has stagnated. 16 | pub output_change_fraction: c_float, 17 | /// The number of epochs training is allowed to continue without changing the MSE by a fraction 18 | /// of at least `output_change_fraction`. 19 | pub output_stagnation_epochs: c_uint, 20 | /// A number between 0 and 1 determining how large a fraction the mean square error should 21 | /// change within `candidate_stagnation_epochs` during training of the candidate neurons, in 22 | /// order for the training to stagnate. After stagnation, training of the candidate neurons is 23 | /// stopped and the best candidate is selected. 24 | /// 25 | /// This means: If the MSE does not change by a fraction of `candidate_change_fraction` during 26 | /// a period of `candidate_stagnation_epochs`, the training of the candidate neurons is stopped 27 | /// because training has stagnated. 28 | pub candidate_change_fraction: c_float, 29 | /// The number of epochs training is allowed to continue without changing the MSE by a fraction 30 | /// of `candidate_change_fraction`. 31 | pub candidate_stagnation_epochs: c_uint, 32 | /// A limit for how much the candidate neuron may be trained. It limits the ratio between the 33 | /// MSE and the candidate score. 34 | pub candidate_limit: fann_type, 35 | /// Multiplier for the weight of the candidate neuron before adding it to the network. 
Usually
36 | /// between 0 and 1, to make training less aggressive.
37 | pub weight_multiplier: fann_type,
38 | /// The maximum number of epochs the output connections may be trained after adding a new
39 | /// candidate neuron.
40 | pub max_out_epochs: c_uint,
41 | /// The maximum number of epochs the input connections to the candidates may be trained before
42 | /// adding a new candidate neuron.
43 | pub max_cand_epochs: c_uint,
44 | /// The activation functions for the candidate neurons.
45 | pub activation_functions: Vec<ActivationFunc>,
46 | /// The activation function steepness values for the candidate neurons.
47 | pub activation_steepnesses: Vec<fann_type>,
48 | /// The number of candidate neurons to be trained for each combination of activation function
49 | /// and steepness.
50 | pub num_candidate_groups: c_uint,
51 | }
52 |
53 | impl CascadeParams {
54 | /// The number of candidates used during training: the number of combinations of activation
55 | /// functions and steepnesses, times `num_candidate_groups`.
56 | ///
57 | /// For every combination of activation function and steepness, `num_candidate_groups` such
58 | /// neurons, with different initial weights, are trained.
59 | pub fn get_num_candidates(&self) -> c_uint {
60 | self.activation_functions.len() as c_uint
61 | * self.activation_steepnesses.len() as c_uint
62 | * self.num_candidate_groups
63 | }
64 | }
65 |
66 | impl Default for CascadeParams {
67 | fn default() -> CascadeParams {
68 | CascadeParams {
69 | output_change_fraction: 0.01,
70 | output_stagnation_epochs: 12,
71 | candidate_change_fraction: 0.01,
72 | candidate_stagnation_epochs: 12,
73 | candidate_limit: 1000.0,
74 | weight_multiplier: 0.4,
75 | max_out_epochs: 150,
76 | max_cand_epochs: 150,
77 | activation_functions: vec![
78 | ActivationFunc::Sigmoid,
79 | ActivationFunc::SigmoidSymmetric,
80 | ActivationFunc::Gaussian,
81 | ActivationFunc::GaussianSymmetric,
82 | ActivationFunc::Elliott,
83 | ActivationFunc::ElliottSymmetric,
84 | ActivationFunc::SinSymmetric,
85 | ActivationFunc::CosSymmetric,
86 | ActivationFunc::Sin,
87 | ActivationFunc::Cos,
88 | ],
89 | activation_steepnesses: vec![0.25, 0.5, 0.75, 1.0],
90 | num_candidate_groups: 2,
91 | }
92 | }
93 | }
94 |
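#[cfg(test)]
mod tests {
    use super::CascadeParams;

    // Minimal sanity check (an added sketch), assuming the defaults above:
    // 10 activation functions * 4 steepnesses * 2 candidate groups = 80 candidates.
    #[test]
    fn default_num_candidates() {
        assert_eq!(CascadeParams::default().get_num_candidates(), 80);
    }
}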
-------------------------------------------------------------------------------- /src/error.rs: --------------------------------------------------------------------------------
1 | use self::FannErrorType::*;
2 | use fann_sys::fann_errno_enum::*;
3 | use fann_sys::{fann_error, fann_get_errno, fann_get_errstr};
4 | use libc::c_int;
5 | use std::error::Error;
6 | use std::ffi::CStr;
7 | use std::fmt;
8 |
9 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
10 | pub enum FannErrorType {
11 | /// Unable to open configuration file for reading
12 | CantOpenConfigR,
13 | /// Unable to open configuration file for writing
14 | CantOpenConfigW,
15 | /// Wrong version of configuration file
16 | WrongConfigVersion,
17 | /// Error reading info from configuration file
18 | CantReadConfig,
19 | /// Error reading neuron info from configuration file
20 | CantReadNeuron,
21 | /// Error reading connections from configuration file
22 | CantReadConnections,
23 | /// Number of connections not equal to the number expected
24 | WrongNumConnections,
25 | /// Unable to open train data file for writing
26 | CantOpenTdW,
27 | /// Unable to open train data file for reading
28 | CantOpenTdR,
29 | /// Error reading training data from file
30 | CantReadTd,
31 | /// Unable to allocate memory
32 | CantAllocateMem,
33 | /// Unable to train with the selected activation function
34 | CantTrainActivation,
35 | /// Unable to use the selected activation function
36 | CantUseActivation,
37 | /// Irreconcilable differences between two `fann_train_data` structures
38 | TrainDataMismatch,
39 | /// Unable to use the selected training algorithm
40 | CantUseTrainAlg,
41 | /// Trying to take subset which is not within the training set
42 | TrainDataSubset,
43 | /// Index is out of bound
44 | IndexOutOfBound,
45 | /// Scaling parameters not present
46 | ScaleNotPresent,
47 | // Errors specific to the Rust wrapper:
48 | /// Failed to save file
49 | CantSaveFile,
50 | /// C function returned an error code, i. e. not 0, but did not specify error
51 | ErrorCodeReturned,
52 | }
53 |
54 | impl fmt::Display for FannErrorType {
55 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
56 | fmt::Display::fmt(
57 | match *self {
58 | CantOpenConfigR => "Unable to open configuration file for reading",
59 | CantOpenConfigW => "Unable to open configuration file for writing",
60 | WrongConfigVersion => "Wrong version of configuration file",
61 | CantReadConfig => "Error reading info from configuration file",
62 | CantReadNeuron => "Error reading neuron info from configuration file",
63 | CantReadConnections => "Error reading connections from configuration file",
64 | WrongNumConnections => "Number of connections not equal to the number expected",
65 | CantOpenTdW => "Unable to open train data file for writing",
66 | CantOpenTdR => "Unable to open train data file for reading",
67 | CantReadTd => "Error reading training data from file",
68 | CantAllocateMem => "Unable to allocate memory",
69 | CantTrainActivation => "Unable to train with the selected activation function",
70 | CantUseActivation => "Unable to use the selected activation function",
71 | TrainDataMismatch => "Irreconcilable differences between two Fann objects",
72 | CantUseTrainAlg => "Unable to use the selected training algorithm",
73 | TrainDataSubset => "Trying to take subset which is not within the training set",
74 | IndexOutOfBound => "Index is out of bound",
75 | ScaleNotPresent => "Scaling parameters not present",
76 | CantSaveFile => "Failed saving file",
77 | ErrorCodeReturned => "C function returned an error code but did not specify error",
78 | },
79 | f,
80 | )
81 | }
82 | }
83 |
84 | #[derive(Clone, Debug, Eq, PartialEq)]
85 | pub struct FannError {
86 | pub error_type: FannErrorType,
87 | pub error_str: String,
88 | }
89 |
90 | pub type FannResult<T> = Result<T, FannError>;
91 |
92 | impl fmt::Display for FannError {
93 | fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
94 | self.error_type.fmt(f)?;
95 | ": ".fmt(f)?;
96 | self.error_str.fmt(f)
97 | }
98 | }
99 |
100 | impl Error for FannError {
101 | fn description(&self) -> &str {
102 | &self.error_str[..]
103 | }
104 | }
105 |
106 | impl FannError {
107 | /// Returns an `Err` if the previous operation on `errdat` has resulted in an error, otherwise
108 | /// `Ok(())`.
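///
/// This mirrors the pattern used throughout the wrapper after each FFI call; a
/// minimal sketch (taken from `Fann::train_epoch`):
///
/// ```ignore
/// let mse = fann_train_epoch(self.raw, data.get_raw());
/// FannError::check_no_error(self.raw as *mut fann_error)?;
/// Ok(mse)
/// ```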
109 | pub unsafe fn check_no_error(errdat: *mut fann_error) -> FannResult<()> {
110 | if errdat.is_null() {
111 | return Err(FannError {
112 | error_type: FannErrorType::CantAllocateMem,
113 | error_str: "Unable to create a new object".to_owned(),
114 | });
115 | }
116 | let error_type = match fann_get_errno(errdat) {
117 | FANN_E_NO_ERROR => return Ok(()),
118 | FANN_E_CANT_OPEN_CONFIG_R => CantOpenConfigR,
119 | FANN_E_CANT_OPEN_CONFIG_W => CantOpenConfigW,
120 | FANN_E_WRONG_CONFIG_VERSION => WrongConfigVersion,
121 | FANN_E_CANT_READ_CONFIG => CantReadConfig,
122 | FANN_E_CANT_READ_NEURON => CantReadNeuron,
123 | FANN_E_CANT_READ_CONNECTIONS => CantReadConnections,
124 | FANN_E_WRONG_NUM_CONNECTIONS => WrongNumConnections,
125 | FANN_E_CANT_OPEN_TD_W => CantOpenTdW,
126 | FANN_E_CANT_OPEN_TD_R => CantOpenTdR,
127 | FANN_E_CANT_READ_TD => CantReadTd,
128 | FANN_E_CANT_ALLOCATE_MEM => CantAllocateMem,
129 | FANN_E_CANT_TRAIN_ACTIVATION => CantTrainActivation,
130 | FANN_E_CANT_USE_ACTIVATION => CantUseActivation,
131 | FANN_E_TRAIN_DATA_MISMATCH => TrainDataMismatch,
132 | FANN_E_CANT_USE_TRAIN_ALG => CantUseTrainAlg,
133 | FANN_E_TRAIN_DATA_SUBSET => TrainDataSubset,
134 | FANN_E_INDEX_OUT_OF_BOUND => IndexOutOfBound,
135 | FANN_E_SCALE_NOT_PRESENT => ScaleNotPresent,
136 | };
137 | let errstr_bytes = CStr::from_ptr(fann_get_errstr(errdat)).to_bytes().to_vec();
138 | let error_str_opt = String::from_utf8(errstr_bytes);
139 | Err(FannError {
140 | error_type,
141 | error_str: error_str_opt.unwrap_or_else(|_| "Invalid UTF-8 in error string".to_owned()),
142 | })
143 | }
144 |
145 | pub unsafe fn check_zero(
146 | result: c_int,
147 | errdat: *mut fann_error,
148 | error_str: &str,
149 | ) -> FannResult<()> {
150 | FannError::check_no_error(errdat)?;
151 | match result {
152 | 0 => Ok(()),
153 | _ => Err(FannError {
154 | error_type: FannErrorType::ErrorCodeReturned,
155 | error_str: error_str.to_owned(),
156 | }),
157 | }
158 | }
159 | }
160 | -------------------------------------------------------------------------------- /src/error_func.rs: --------------------------------------------------------------------------------
1 | use fann_sys::*;
2 |
3 | /// Error function used during training.
4 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
5 | pub enum ErrorFunc {
6 | /// Standard linear error function
7 | Linear,
8 | /// Tanh error function; usually better, but it may require a lower learning rate. This error
9 | /// function aggressively targets outputs that differ greatly from the desired values, while
10 | /// mostly ignoring outputs that differ only slightly. Not recommended for cascade or incremental training.
11 | Tanh,
12 | }
13 |
14 | impl ErrorFunc {
15 | /// Create an `ErrorFunc` from a `fann_sys::fann_errorfunc_enum`.
16 | pub fn from_errorfunc_enum(ef_enum: fann_errorfunc_enum) -> ErrorFunc {
17 | match ef_enum {
18 | FANN_ERRORFUNC_LINEAR => ErrorFunc::Linear,
19 | FANN_ERRORFUNC_TANH => ErrorFunc::Tanh,
20 | }
21 | }
22 |
23 | /// Return the `fann_sys::fann_errorfunc_enum` corresponding to this `ErrorFunc`.
24 | pub fn to_errorfunc_enum(self) -> fann_errorfunc_enum {
25 | match self {
26 | ErrorFunc::Linear => FANN_ERRORFUNC_LINEAR,
27 | ErrorFunc::Tanh => FANN_ERRORFUNC_TANH,
28 | }
29 | }
30 | }
31 | -------------------------------------------------------------------------------- /src/lib.rs: --------------------------------------------------------------------------------
1 | //! A Rust wrapper for the Fast Artificial Neural Network library.
2 | //!
3 | //!
A new neural network with random weights can be created with the `Fann::new` method, or, for 4 | //! different network topologies, with its variants `Fann::new_sparse` and `Fann::new_shortcut`. 5 | //! Existing neural networks can be saved to and loaded from files. 6 | //! 7 | //! Similarly, training data sets can be loaded from and saved to human-readable files, or training 8 | //! data can be provided directly to the network as slices of floating point numbers. 9 | //! 10 | //! Example: 11 | //! 12 | //! ``` 13 | //! extern crate fann; 14 | //! use fann::{ActivationFunc, Fann, TrainAlgorithm, QuickpropParams}; 15 | //! 16 | //! fn main() { 17 | //! // Create a new network with two input neurons, a hidden layer with three neurons, and one 18 | //! // output neuron. 19 | //! let mut fann = Fann::new(&[2, 3, 1]).unwrap(); 20 | //! // Configure the activation functions for the hidden and output neurons. 21 | //! fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); 22 | //! fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); 23 | //! // Use the Quickprop learning algorithm, with default parameters. 24 | //! // (Otherwise, Rprop would be used.) 25 | //! fann.set_train_algorithm(TrainAlgorithm::Quickprop(Default::default())); 26 | //! // Train for up to 500000 epochs, displaying progress information after intervals of 1000 27 | //! // epochs. Stop when the network's error on the training data drops to 0.001. 28 | //! let max_epochs = 500000; 29 | //! let epochs_between_reports = 1000; 30 | //! let desired_error = 0.001; 31 | //! // Train directly on data loaded from the file "xor.data". 32 | //! fann.on_file("test_files/xor.data") 33 | //! .with_reports(epochs_between_reports) 34 | //! .train(max_epochs, desired_error).unwrap(); 35 | //! // The network now approximates the XOR problem: 36 | //! assert!(fann.run(&[-1.0, 1.0]).unwrap()[0] > 0.9); 37 | //! assert!(fann.run(&[ 1.0, -1.0]).unwrap()[0] > 0.9); 38 | //! assert!(fann.run(&[ 1.0, 1.0]).unwrap()[0] < 0.1); 39 | //! assert!(fann.run(&[-1.0, -1.0]).unwrap()[0] < 0.1); 40 | //! } 41 | //! ``` 42 | //! 43 | //! FANN also supports cascade training, where the network's topology is changed during training by 44 | //! adding additional neurons: 45 | //! 46 | //! ``` 47 | //! extern crate fann; 48 | //! use fann::{ActivationFunc, CascadeParams, Fann}; 49 | //! 50 | //! fn main() { 51 | //! // Create a new network with two input neurons and one output neuron. 52 | //! let mut fann = Fann::new_shortcut(&[2, 1]).unwrap(); 53 | //! // Use the default cascade training parameters, but a higher weight multiplier: 54 | //! fann.set_cascade_params(&CascadeParams { 55 | //! weight_multiplier: 0.6, 56 | //! ..CascadeParams::default() 57 | //! }); 58 | //! // Add up to 50 neurons, displaying progress information after each. 59 | //! // Stop when the network's error on the training data drops to 0.001. 60 | //! let max_neurons = 50; 61 | //! let neurons_between_reports = 1; 62 | //! let desired_error = 0.001; 63 | //! // Train directly on data loaded from the file "xor.data". 64 | //! fann.on_file("test_files/xor.data") 65 | //! .with_reports(neurons_between_reports) 66 | //! .cascade() 67 | //! .train(max_neurons, desired_error).unwrap(); 68 | //! // The network now approximates the XOR problem: 69 | //! assert!(fann.run(&[-1.0, 1.0]).unwrap()[0] > 0.9); 70 | //! assert!(fann.run(&[ 1.0, -1.0]).unwrap()[0] > 0.9); 71 | //! assert!(fann.run(&[ 1.0, 1.0]).unwrap()[0] < 0.1); 72 | //! 
assert!(fann.run(&[-1.0, -1.0]).unwrap()[0] < 0.1);
73 | //! }
74 | //! ```
75 |
76 | extern crate fann_sys;
77 | extern crate libc;
78 |
79 | use fann_sys::*;
80 | use libc::{c_float, c_int, c_uint};
81 | use std::cell::RefCell;
82 | use std::ffi::CString;
83 | use std::mem::{forget, transmute};
84 | use std::path::Path;
85 | use std::ptr::{copy_nonoverlapping, null_mut};
86 |
87 | pub use activation_func::ActivationFunc;
88 | pub use cascade_params::CascadeParams;
89 | pub use error::{FannError, FannErrorType, FannResult};
90 | pub use error_func::ErrorFunc;
91 | pub use net_type::NetType;
92 | pub use stop_func::StopFunc;
93 | pub use train_algorithm::TrainAlgorithm;
94 | pub use train_algorithm::{BatchParams, IncrementalParams, QuickpropParams, RpropParams};
95 | pub use train_data::TrainData;
96 |
97 | mod activation_func;
98 | mod cascade_params;
99 | mod error;
100 | mod error_func;
101 | mod net_type;
102 | mod stop_func;
103 | mod train_algorithm;
104 | mod train_data;
105 |
106 | /// The type of weights, inputs and outputs in a neural network. It is defined as `c_float` by
107 | /// default, and as `c_double` if the `double` feature is configured.
108 | pub type FannType = fann_type;
109 |
110 | pub type Connection = fann_connection;
111 |
112 | /// Convert a path to a `CString`.
113 | fn to_filename<P: AsRef<Path>>(path: P) -> Result<CString, FannError> {
114 | match path.as_ref().to_str().map(CString::new) {
115 | None => Err(FannError {
116 | error_type: FannErrorType::CantOpenTdR,
117 | error_str: "File name contains invalid unicode characters".to_owned(),
118 | }),
119 | Some(Err(e)) => Err(FannError {
120 | error_type: FannErrorType::CantOpenTdR,
121 | error_str: format!(
122 | "File name contains a nul byte at position {}",
123 | e.nul_position()
124 | ),
125 | }),
126 | Some(Ok(cs)) => Ok(cs),
127 | }
128 | }
129 |
130 | /// Either an owned or a borrowed `TrainData`.
131 | enum CurrentTrainData<'a> {
132 | Own(FannResult<TrainData>),
133 | Ref(&'a TrainData),
134 | }
135 |
136 | // Thread-local container for a pointer to the current FannTrainer.
137 | // This is necessary because the raw fann_train_on_data_with_callback C function takes a function
138 | // pointer and not a closure. So instead of the user-supplied function we pass a function to it
139 | // which will call the callback stored in the trainer.
140 | // The 'static lifetime is a lie! But the trainer lives longer than the train method runs, and
141 | // afterwards resets this pointer to null again.
142 | thread_local!(static TRAINER: RefCell<*mut FannTrainer<'static>> = RefCell::new(null_mut()));
143 |
144 | #[derive(Clone, Copy, Debug)]
145 | pub enum CallbackResult {
146 | Stop,
147 | Continue,
148 | }
149 |
150 | impl CallbackResult {
151 | pub fn stop_if(condition: bool) -> CallbackResult {
152 | if condition {
153 | CallbackResult::Stop
154 | } else {
155 | CallbackResult::Continue
156 | }
157 | }
158 | }
159 |
160 | /// A training configuration. Create this with `Fann::on_data` or `Fann::on_file` and run the
161 | /// training with `train`.
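///
/// A minimal sketch (parameter values illustrative; see the crate-level docs for
/// a complete example):
///
/// ```no_run
/// let mut fann = fann::Fann::new(&[2, 3, 1]).unwrap();
/// fann.on_file("test_files/xor.data")
///     .with_reports(1000)
///     .train(100_000, 0.001)
///     .unwrap();
/// ```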
162 | pub struct FannTrainer<'a> {
163 | fann: &'a mut Fann,
164 | cur_data: CurrentTrainData<'a>,
165 | callback: Option<&'a dyn Fn(&Fann, &TrainData, c_uint) -> CallbackResult>,
166 | interval: c_uint,
167 | cascade: bool,
168 | }
169 |
170 | impl<'a> FannTrainer<'a> {
171 | fn with_data<'b>(fann: &'b mut Fann, data: &'b TrainData) -> FannTrainer<'b> {
172 | FannTrainer {
173 | fann,
174 | cur_data: CurrentTrainData::Ref(data),
175 | callback: None,
176 | interval: 0,
177 | cascade: false,
178 | }
179 | }
180 |
181 | fn with_file<P: AsRef<Path>>(fann: &mut Fann, path: P) -> FannTrainer {
182 | FannTrainer {
183 | fann,
184 | cur_data: CurrentTrainData::Own(TrainData::from_file(path)),
185 | callback: None,
186 | interval: 0,
187 | cascade: false,
188 | }
189 | }
190 |
191 | /// Activates printing reports periodically. Between two reports, `interval` neurons are added
192 | /// (for cascade training) or training goes on for `interval` epochs (otherwise).
193 | pub fn with_reports(self, interval: c_uint) -> FannTrainer<'a> {
194 | FannTrainer { interval, ..self }
195 | }
196 |
197 | /// Configures a callback to be called periodically during training. Every `interval` epochs
198 | /// (for regular training) or every time `interval` new neurons have been added (for cascade
199 | /// training), the callback runs. It receives as arguments:
200 | ///
201 | /// * a reference to the current `Fann`,
202 | /// * a reference to the training data,
203 | /// * the number of steps (added neurons or epochs) taken so far.
204 | pub fn with_callback(
205 | self,
206 | interval: c_uint,
207 | callback: &'a dyn Fn(&Fann, &TrainData, c_uint) -> CallbackResult,
208 | ) -> FannTrainer<'a> {
209 | FannTrainer {
210 | callback: Some(callback),
211 | interval,
212 | ..self
213 | }
214 | }
215 |
216 | /// Use the Cascade2 algorithm: This adds neurons to the neural network while training, starting
217 | /// with an ANN without any hidden layers. The network should use shortcut connections, so it
218 | /// needs to be created like this:
219 | ///
220 | /// ```
221 | /// let td = fann::TrainData::from_file("test_files/xor.data").unwrap();
222 | /// let fann = fann::Fann::new_shortcut(&[td.num_input(), td.num_output()]).unwrap();
223 | /// ```
224 | pub fn cascade(self) -> FannTrainer<'a> {
225 | FannTrainer {
226 | cascade: true,
227 | ..self
228 | }
229 | }
230 |
231 | extern "C" fn raw_callback(
232 | ann: *mut fann,
233 | td: *mut fann_train_data,
234 | _: c_uint,
235 | _: c_uint,
236 | _: c_float,
237 | steps: c_uint,
238 | ) -> c_int {
239 | // TODO: This is an ugly hack - find better ways to solve the following issues:
240 | // * The C callback is not a closure, so it cannot access the user-supplied argument.
241 | // https://aatch.github.io/blog/2015/01/17/unboxed-closures-and-ffi-callbacks doesn't
242 | // work here because the C callback doesn't take a user-defined pointer as an argument.
243 | // Instead, we store a pointer to the FannTrainer, which contains a fat pointer to the
244 | // callback, in a thread-local variable that is accessed by the raw callback.
245 | // * The lifetime isn't known at the point where the thread-local variable is declared, so
246 | // we just use 'static and transmute the pointer!
247 | // * The C callback is only given pointers to the raw structs, not to self and data. We
248 | // read these from the thread-local variable, too, and assert that they correspond to the
249 | // given raw structs.
250 | // * https://github.com/rust-lang/rust/issues/24010 seems to make it impossible to define a
251 | // trait that would act as a shortcut for Fn(...) -> CallbackResult.
252 | match TRAINER.with(|cell| unsafe {
253 | let trainer = *cell.borrow();
254 | let data = (*trainer).get_data().unwrap();
255 | assert_eq!(ann, (*trainer).fann.raw);
256 | assert_eq!(td, data.get_raw());
257 | let callback = (*trainer).callback.unwrap();
258 | callback((*trainer).fann, data, steps)
259 | }) {
260 | CallbackResult::Stop => -1,
261 | CallbackResult::Continue => 0,
262 | }
263 | }
264 |
265 | fn get_data(&'a self) -> FannResult<&'a TrainData> {
266 | match self.cur_data {
267 | CurrentTrainData::Ref(data) => Ok(data),
268 | CurrentTrainData::Own(ref result) => result.as_ref().map_err(FannError::clone),
269 | }
270 | }
271 |
272 | /// Train the network until either the mean square error drops below the `desired_error`, or
273 | /// the maximum number of steps is reached. If cascade training is activated, `max_steps`
274 | /// refers to the number of neurons that are added, otherwise it is the maximum number of
275 | /// training epochs.
276 | // Clippy's suggestion fails: https://github.com/rust-lang-nursery/rust-clippy/issues/3340
277 | #[cfg_attr(feature = "cargo-clippy", allow(useless_transmute))]
278 | pub fn train(&mut self, max_steps: c_uint, desired_error: c_float) -> FannResult<()> {
279 | unsafe {
280 | let raw_data = self.get_data()?.get_raw();
281 | if self.callback.is_some() {
282 | TRAINER.with(|cell| *cell.borrow_mut() = transmute(&mut *self));
283 | fann_set_callback(self.fann.raw, Some(FannTrainer::raw_callback));
284 | }
285 | let raw_train_fn = if self.cascade {
286 | fann_cascadetrain_on_data
287 | } else {
288 | fann_train_on_data
289 | };
290 | raw_train_fn(
291 | self.fann.raw,
292 | raw_data,
293 | max_steps,
294 | self.interval,
295 | desired_error,
296 | );
297 | if self.callback.is_some() {
298 | fann_set_callback(self.fann.raw, None);
299 | TRAINER.with(|cell| *cell.borrow_mut() = null_mut());
300 | }
301 | FannError::check_no_error(self.fann.raw as *mut fann_error)
302 | }
303 | }
304 | }
305 |
306 | pub struct Fann {
307 | // We don't consider setting and clearing the error string and number a mutation, and every
308 | // method should leave these fields cleared, either because it succeeded or because it read the
309 | // fields and returned the corresponding error.
310 | // We also don't consider writing the output data a mutation, as we don't provide access to it
311 | // and copy it before returning it.
312 | raw: *mut fann,
313 | }
314 |
315 | impl Fann {
316 | unsafe fn from_raw(raw: *mut fann) -> FannResult<Fann> {
317 | FannError::check_no_error(raw as *mut fann_error)?;
318 | Ok(Fann { raw })
319 | }
320 |
321 | /// Create a fully connected neural network.
322 | ///
323 | /// There will be a bias neuron in each layer except the output layer,
324 | /// and this bias neuron will be connected to all neurons in the next layer.
325 | /// When running the network, the bias nodes always emit 1.
326 | ///
327 | /// # Arguments
328 | ///
329 | /// * `layers` - Specifies the number of neurons in each layer, starting with the input and
330 | /// ending with the output layer.
331 | ///
332 | /// # Example
333 | ///
334 | /// ```
335 | /// // Creating a network with 2 input neurons, 1 output neuron,
336 | /// // and two hidden layers with 8 and 9 neurons.
337 | /// let layers = [2, 8, 9, 1];
338 | /// fann::Fann::new(&layers).unwrap();
339 | /// ```
340 | pub fn new(layers: &[c_uint]) -> FannResult<Fann> {
341 | Fann::new_sparse(1.0, layers)
342 | }
343 |
344 | /// Create a neural network that is not necessarily fully connected.
345 | ///
346 | /// There will be a bias neuron in each layer except the output layer,
347 | /// and this bias neuron will be connected to all neurons in the next layer.
348 | /// When running the network, the bias nodes always emit 1.
349 | ///
350 | /// # Arguments
351 | ///
352 | /// * `connection_rate` - The share of pairs of neurons in consecutive layers that will be
353 | /// connected.
354 | /// * `layers` - Specifies the number of neurons in each layer, starting with the input
355 | /// and ending with the output layer.
356 | pub fn new_sparse(connection_rate: c_float, layers: &[c_uint]) -> FannResult<Fann> {
357 | unsafe {
358 | Fann::from_raw(fann_create_sparse_array(
359 | connection_rate,
360 | layers.len() as c_uint,
361 | layers.as_ptr(),
362 | ))
363 | }
364 | }
365 |
366 | /// Create a neural network with shortcut connections, i. e. one that doesn't only connect each
367 | /// layer to its successor, but connects every layer to all later layers: each neuron has
368 | /// connections to all neurons in all subsequent layers.
369 | pub fn new_shortcut(layers: &[c_uint]) -> FannResult<Fann> {
370 | unsafe {
371 | Fann::from_raw(fann_create_shortcut_array(
372 | layers.len() as c_uint,
373 | layers.as_ptr(),
374 | ))
375 | }
376 | }
377 |
378 | /// Read a neural network from a file.
379 | pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<Fann> {
380 | let filename = to_filename(path)?;
381 | unsafe { Fann::from_raw(fann_create_from_file(filename.as_ptr())) }
382 | }
383 |
384 | /// Create a deep copy of a neural network.
385 | ///
386 | /// The `Clone` trait is intentionally not implemented, because this operation might fail.
387 | pub fn try_clone(&self) -> FannResult<Fann> {
388 | unsafe { Fann::from_raw(fann_copy(self.raw)) }
389 | }
390 |
391 | /// Save the network to a configuration file.
392 | ///
393 | /// The file will contain all information about the neural network, except parameters generated
394 | /// during training, like mean square error and the bit fail limit.
395 | pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
396 | let filename = to_filename(path)?;
397 | unsafe {
398 | let result = fann_save(self.raw, filename.as_ptr());
399 | FannError::check_zero(result, self.raw as *mut fann_error, "Error saving network")
400 | }
401 | }
402 |
403 | /// Give each connection a random weight between `min_weight` and `max_weight`.
404 | ///
405 | /// By default, weights in a new network are random between -0.1 and 0.1.
406 | pub fn randomize_weights(&mut self, min_weight: FannType, max_weight: FannType) {
407 | unsafe { fann_randomize_weights(self.raw, min_weight, max_weight) }
408 | }
409 |
410 | /// Initialize the weights using Widrow & Nguyen's algorithm.
411 | ///
412 | /// The algorithm developed by Derrick Nguyen and Bernard Widrow sets the weights in a way that
413 | /// can speed up training with the given training data. This technique is not always successful
414 | /// and in some cases can even be less efficient than a purely random initialization.
415 | pub fn init_weights(&mut self, train_data: &TrainData) {
416 | unsafe { fann_init_weights(self.raw, train_data.get_raw()) }
417 | }
418 |
419 | /// Print the connections of the network in a compact matrix, for easy viewing of its
420 | /// internals.
421 | ///
422 | /// The output for a small (2 2 1) network trained on the XOR problem:
423 | ///
424 | /// ```text
425 | /// Layer / Neuron 012345
426 | /// L 1 / N 3 BBa...
427 | /// L 1 / N 4 BBA...
428 | /// L 1 / N 5 ......
429 | /// L 2 / N 6 ...BBA
430 | /// L 2 / N 7 ......
431 | /// ```
432 | ///
433 | /// This network has five real neurons and two bias neurons, giving a total of seven
434 | /// neurons, numbered from 0 to 6. The connections between these neurons can be seen in the matrix.
435 | /// A "." marks a place where there is no connection, while a letter tells how strong the
436 | /// connection is, on a scale from a to z. The two real neurons in the hidden layer (neurons 3 and 4
437 | /// in layer 1) have connections from the three neurons in the previous layer, as is visible in
438 | /// the first two lines. The output neuron (6) has connections from the three neurons in the
439 | /// hidden layer (3 - 5), as is visible in the fourth line.
440 | ///
441 | /// To simplify the matrix, output neurons are not shown as neurons that connections can come
442 | /// from, and input and bias neurons are not shown as neurons that connections can go to.
443 | pub fn print_connections(&self) {
444 | unsafe { fann_print_connections(self.raw) }
445 | }
446 |
447 | /// Print all parameters and options of the network.
448 | pub fn print_parameters(&self) {
449 | unsafe { fann_print_parameters(self.raw) }
450 | }
451 |
452 | /// Return an `Err` if the size of the slice does not match the number of input neurons,
453 | /// otherwise `Ok(())`.
454 | fn check_input_size(&self, input: &[FannType]) -> FannResult<()> {
455 | let num_input = self.get_num_input() as usize;
456 | if input.len() == num_input {
457 | Ok(())
458 | } else {
459 | Err(FannError {
460 | error_type: FannErrorType::IndexOutOfBound,
461 | error_str: format!(
462 | "Input has length {}, but there are {} input neurons",
463 | input.len(),
464 | num_input
465 | ),
466 | })
467 | }
468 | }
469 |
470 | /// Return an `Err` if the size of the slice does not match the number of output neurons,
471 | /// otherwise `Ok(())`.
472 | fn check_output_size(&self, output: &[FannType]) -> FannResult<()> {
473 | let num_output = self.get_num_output() as usize;
474 | if output.len() == num_output {
475 | Ok(())
476 | } else {
477 | Err(FannError {
478 | error_type: FannErrorType::IndexOutOfBound,
479 | error_str: format!(
480 | "Output has length {}, but there are {} output neurons",
481 | output.len(),
482 | num_output
483 | ),
484 | })
485 | }
486 | }
487 |
488 | /// Train with a single pair of input and output. This is always incremental training (see
489 | /// `TrainAlgorithm`), since only one pattern is presented.
490 | pub fn train(&mut self, input: &[FannType], desired_output: &[FannType]) -> FannResult<()> {
491 | unsafe {
492 | self.check_input_size(input)?;
493 | self.check_output_size(desired_output)?;
494 | fann_train(self.raw, input.as_ptr(), desired_output.as_ptr());
495 | FannError::check_no_error(self.raw as *mut fann_error)?;
496 | }
497 | Ok(())
498 | }
499 |
500 | /// Create a training configuration for the given data set.
501 | pub fn on_data<'a>(&'a mut self, data: &'a TrainData) -> FannTrainer<'a> {
502 | FannTrainer::with_data(self, data)
503 | }
504 |
505 | /// Create a training configuration, reading the training data from the given file.
506 | pub fn on_file<P: AsRef<Path>>(&mut self, path: P) -> FannTrainer {
507 | FannTrainer::with_file(self, path)
508 | }
509 |
510 | /// Train one epoch with a set of training data, i. e.
each sample from the training data is
511 | /// considered exactly once.
512 | ///
513 | /// Returns the mean square error as it is calculated either before or during the actual
514 | /// training. This is not the actual MSE after the training epoch, but since calculating that
515 | /// would require going through the entire training set once more, this value is more than
516 | /// adequate to use during training.
517 | pub fn train_epoch(&mut self, data: &TrainData) -> FannResult<c_float> {
518 | unsafe {
519 | let mse = fann_train_epoch(self.raw, data.get_raw());
520 | FannError::check_no_error(self.raw as *mut fann_error)?;
521 | Ok(mse)
522 | }
523 | }
524 |
525 | /// Test with a single pair of input and output. This operation updates the mean square error
526 | /// but does not change the network.
527 | ///
528 | /// Returns the actual output of the network.
529 | pub fn test(
530 | &mut self,
531 | input: &[FannType],
532 | desired_output: &[FannType],
533 | ) -> FannResult<Vec<FannType>> {
534 | self.check_input_size(input)?;
535 | self.check_output_size(desired_output)?;
536 | let num_output = self.get_num_output() as usize;
537 | let mut result = Vec::with_capacity(num_output);
538 | unsafe {
539 | let output = fann_test(self.raw, input.as_ptr(), desired_output.as_ptr());
540 | FannError::check_no_error(self.raw as *mut fann_error)?;
541 | copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
542 | result.set_len(num_output);
543 | }
544 | Ok(result)
545 | }
546 |
547 | /// Test with a training data set and calculate the mean square error.
548 | pub fn test_data(&mut self, data: &TrainData) -> FannResult<c_float> {
549 | unsafe {
550 | let mse = fann_test_data(self.raw, data.get_raw());
551 | FannError::check_no_error(self.raw as *mut fann_error)?;
552 | Ok(mse)
553 | }
554 | }
555 |
556 | /// Get the mean square error.
557 | pub fn get_mse(&self) -> c_float {
558 | unsafe { fann_get_MSE(self.raw) }
559 | }
560 |
561 | /// Get the number of fail bits, i. e. the number of neurons which differed from the desired
562 | /// output by more than the bit fail limit since the previous reset.
563 | pub fn get_bit_fail(&self) -> c_uint {
564 | unsafe { fann_get_bit_fail(self.raw) }
565 | }
566 |
567 | /// Reset the mean square error and bit fail count.
568 | pub fn reset_mse_and_bit_fail(&mut self) {
569 | unsafe {
570 | fann_reset_MSE(self.raw);
571 | }
572 | }
573 |
574 | /// Run the input through the neural network and return the output. The length of the input
575 | /// must equal the number of input neurons and the length of the output will equal the number
576 | /// of output neurons.
577 | pub fn run(&self, input: &[FannType]) -> FannResult<Vec<FannType>> {
578 | self.check_input_size(input)?;
579 | let num_output = self.get_num_output() as usize;
580 | let mut result = Vec::with_capacity(num_output);
581 | unsafe {
582 | let output = fann_run(self.raw, input.as_ptr());
583 | FannError::check_no_error(self.raw as *mut fann_error)?;
584 | copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
585 | result.set_len(num_output);
586 | }
587 | Ok(result)
588 | }
589 |
590 | /// Get the number of input neurons.
591 | pub fn get_num_input(&self) -> c_uint {
592 | unsafe { fann_get_num_input(self.raw) }
593 | }
594 |
595 | /// Get the number of output neurons.
596 | pub fn get_num_output(&self) -> c_uint {
597 | unsafe { fann_get_num_output(self.raw) }
598 | }
599 |
600 | /// Get the total number of neurons, including the bias neurons.
601 | ///
602 | /// E. g.
577 |     pub fn run(&self, input: &[FannType]) -> FannResult<Vec<FannType>> {
578 |         self.check_input_size(input)?;
579 |         let num_output = self.get_num_output() as usize;
580 |         let mut result = Vec::with_capacity(num_output);
581 |         unsafe {
582 |             let output = fann_run(self.raw, input.as_ptr());
583 |             FannError::check_no_error(self.raw as *mut fann_error)?;
584 |             copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
585 |             result.set_len(num_output);
586 |         }
587 |         Ok(result)
588 |     }
589 |
590 |     /// Get the number of input neurons.
591 |     pub fn get_num_input(&self) -> c_uint {
592 |         unsafe { fann_get_num_input(self.raw) }
593 |     }
594 |
595 |     /// Get the number of output neurons.
596 |     pub fn get_num_output(&self) -> c_uint {
597 |         unsafe { fann_get_num_output(self.raw) }
598 |     }
599 |
600 |     /// Get the total number of neurons, including the bias neurons.
601 |     ///
602 |     /// E.g. a 2-4-2 network has 3 + 5 + 2 = 10 neurons (because two layers have bias neurons).
603 |     pub fn get_total_neurons(&self) -> c_uint {
604 |         unsafe { fann_get_total_neurons(self.raw) }
605 |     }
606 |
607 |     /// Get the total number of connections.
608 |     pub fn get_total_connections(&self) -> c_uint {
609 |         unsafe { fann_get_total_connections(self.raw) }
610 |     }
611 |
612 |     /// Get the type of the neural network.
613 |     pub fn get_network_type(&self) -> NetType {
614 |         let nt_enum = unsafe { fann_get_network_type(self.raw) };
615 |         NetType::from_nettype_enum(nt_enum)
616 |     }
617 |
618 |     /// Get the connection rate used when the network was created.
619 |     pub fn get_connection_rate(&self) -> c_float {
620 |         unsafe { fann_get_connection_rate(self.raw) }
621 |     }
622 |
623 |     /// Get the number of layers in the network.
624 |     pub fn get_num_layers(&self) -> c_uint {
625 |         unsafe { fann_get_num_layers(self.raw) }
626 |     }
627 |
628 |     /// Get the number of neurons in each layer of the network.
629 |     pub fn get_layer_sizes(&self) -> Vec<c_uint> {
630 |         let num_layers = self.get_num_layers() as usize;
631 |         let mut result = Vec::with_capacity(num_layers);
632 |         unsafe {
633 |             fann_get_layer_array(self.raw, result.as_mut_ptr());
634 |             result.set_len(num_layers);
635 |         }
636 |         result
637 |     }
638 |
639 |     /// Get the number of bias neurons in each layer of the network.
640 |     pub fn get_bias_counts(&self) -> Vec<c_uint> {
641 |         let num_layers = self.get_num_layers() as usize;
642 |         let mut result = Vec::with_capacity(num_layers);
643 |         unsafe {
644 |             fann_get_bias_array(self.raw, result.as_mut_ptr());
645 |             result.set_len(num_layers);
646 |         }
647 |         result
648 |     }
649 |
650 |     /// Get a list of all connections in the network.
651 |     pub fn get_connections(&self) -> Vec<Connection> {
652 |         let total = self.get_total_connections() as usize;
653 |         let mut result = Vec::with_capacity(total);
654 |         unsafe {
655 |             fann_get_connection_array(self.raw, result.as_mut_ptr());
656 |             result.set_len(total);
657 |         }
658 |         result
659 |     }
660 |
661 |     /// Set the weights of all given connections.
662 |     ///
663 |     /// Connections that don't already exist are ignored.
664 |     pub fn set_connections<'a, I: IntoIterator<Item = &'a Connection>>(&mut self, connections: I) {
665 |         for c in connections {
666 |             self.set_weight(c.from_neuron, c.to_neuron, c.weight);
667 |         }
668 |     }
669 |
670 |     /// Set the weight of the given connection.
671 |     pub fn set_weight(&mut self, from_neuron: c_uint, to_neuron: c_uint, weight: FannType) {
672 |         unsafe { fann_set_weight(self.raw, from_neuron, to_neuron, weight) }
673 |     }
674 |
675 |     /// Get the activation function for neuron number `neuron` in layer number `layer`, counting
676 |     /// the input layer as number 0. Input layer neurons do not have an activation function, so
677 |     /// `layer` must be at least 1.
678 |     pub fn get_activation_func(&self, layer: c_int, neuron: c_int) -> FannResult<ActivationFunc> {
679 |         let af_enum = unsafe { fann_get_activation_function(self.raw, layer, neuron) };
680 |         unsafe { FannError::check_no_error(self.raw as *mut fann_error)? };
681 |         ActivationFunc::from_fann_activationfunc_enum(af_enum)
682 |     }
683 |
684 |     /// Set the activation function for neuron number `neuron` in layer number `layer`, counting
685 |     /// the input layer as number 0. Input layer neurons do not have an activation function, so
686 |     /// `layer` must be at least 1.
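    ///
    /// A small sketch with hypothetical layer and neuron indices (the activation function
    /// names are the `ActivationFunc` variants used elsewhere in this crate):
    ///
    /// ```no_run
    /// use fann::{ActivationFunc, Fann};
    ///
    /// let mut fann = Fann::new(&[2, 3, 1]).unwrap();
    /// // Give the first neuron of the hidden layer (layer 1) a sine activation.
    /// fann.set_activation_func(ActivationFunc::Sin, 1, 0);
    /// ```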
687 |     pub fn set_activation_func(&mut self, af: ActivationFunc, layer: c_int, neuron: c_int) {
688 |         let af_enum = af.to_fann_activationfunc_enum();
689 |         unsafe { fann_set_activation_function(self.raw, af_enum, layer, neuron) }
690 |     }
691 |
692 |     /// Set the activation function for all hidden layers.
693 |     pub fn set_activation_func_hidden(&mut self, activation_func: ActivationFunc) {
694 |         unsafe {
695 |             let af_enum = activation_func.to_fann_activationfunc_enum();
696 |             fann_set_activation_function_hidden(self.raw, af_enum);
697 |         }
698 |     }
699 |
700 |     /// Set the activation function for the output layer.
701 |     pub fn set_activation_func_output(&mut self, activation_func: ActivationFunc) {
702 |         unsafe {
703 |             let af_enum = activation_func.to_fann_activationfunc_enum();
704 |             fann_set_activation_function_output(self.raw, af_enum)
705 |         }
706 |     }
707 |
708 |     /// Get the activation steepness for neuron number `neuron` in layer number `layer`.
709 |     #[cfg_attr(feature = "cargo-clippy", allow(float_cmp))]
710 |     pub fn get_activation_steepness(&self, layer: c_int, neuron: c_int) -> Option<FannType> {
711 |         let steepness = unsafe { fann_get_activation_steepness(self.raw, layer, neuron) };
712 |         // This returns exactly -1 if the neuron is not defined.
713 |         if steepness == -1.0 {
714 |             return None;
715 |         }
716 |         Some(steepness)
717 |     }
718 |
719 |     /// Set the activation steepness for neuron number `neuron` in layer number `layer`, counting
720 |     /// the input layer as number 0. Input layer neurons do not have an activation steepness, so
721 |     /// `layer` must be at least 1.
722 |     ///
723 |     /// The steepness determines how fast the function goes from minimum to maximum. A higher value
724 |     /// will result in more aggressive training.
725 |     ///
726 |     /// A steep activation function is adequate if outputs are binary, i.e. they are supposed to
727 |     /// be either almost 0 or almost 1.
728 |     ///
729 |     /// The default value is 0.5.
730 |     pub fn set_activation_steepness(&self, steepness: FannType, layer: c_int, neuron: c_int) {
731 |         unsafe { fann_set_activation_steepness(self.raw, steepness, layer, neuron) }
732 |     }
733 |
734 |     /// Set the activation steepness for layer number `layer`.
735 |     pub fn set_activation_steepness_layer(&self, steepness: FannType, layer: c_int) {
736 |         unsafe { fann_set_activation_steepness_layer(self.raw, steepness, layer) }
737 |     }
738 |
739 |     /// Set the activation steepness for all hidden layers.
740 |     pub fn set_activation_steepness_hidden(&self, steepness: FannType) {
741 |         unsafe { fann_set_activation_steepness_hidden(self.raw, steepness) }
742 |     }
743 |
744 |     /// Set the activation steepness for the output layer.
745 |     pub fn set_activation_steepness_output(&self, steepness: FannType) {
746 |         unsafe { fann_set_activation_steepness_output(self.raw, steepness) }
747 |     }
748 |
749 |     /// Get the error function used during training.
750 |     pub fn get_error_func(&self) -> ErrorFunc {
751 |         let ef_enum = unsafe { fann_get_train_error_function(self.raw) };
752 |         ErrorFunc::from_errorfunc_enum(ef_enum)
753 |     }
754 |
755 |     /// Set the error function used during training.
756 |     ///
757 |     /// The default is `Tanh`.
758 |     pub fn set_error_func(&mut self, ef: ErrorFunc) {
759 |         let ef_enum = ef.to_errorfunc_enum();
760 |         unsafe { fann_set_train_error_function(self.raw, ef_enum) }
761 |     }
762 |
763 |     /// Get the stop criterion for training.
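    ///
    /// For instance, a sketch that switches the criterion to the bit fail count (assuming
    /// `StopFunc` is re-exported at the crate root):
    ///
    /// ```no_run
    /// use fann::{Fann, StopFunc};
    ///
    /// let mut fann = Fann::new(&[2, 3, 1]).unwrap();
    /// // Stop training based on failed bits rather than the mean square error.
    /// fann.set_stop_func(StopFunc::Bit);
    /// assert_eq!(StopFunc::Bit, fann.get_stop_func());
    /// ```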
764 | pub fn get_stop_func(&self) -> StopFunc { 765 | let sf_enum = unsafe { fann_get_train_stop_function(self.raw) }; 766 | StopFunc::from_stopfunc_enum(sf_enum) 767 | } 768 | 769 | /// Set the stop criterion for training. 770 | /// 771 | /// The default is `Mse`. 772 | pub fn set_stop_func(&mut self, sf: StopFunc) { 773 | let sf_enum = sf.to_stopfunc_enum(); 774 | unsafe { fann_set_train_stop_function(self.raw, sf_enum) } 775 | } 776 | 777 | /// Get the bit fail limit. 778 | pub fn get_bit_fail_limit(&self) -> FannType { 779 | unsafe { fann_get_bit_fail_limit(self.raw) } 780 | } 781 | 782 | /// Set the bit fail limit. 783 | /// 784 | /// Each output neuron value that differs from the desired output by more than the bit fail 785 | /// limit is counted as a failed bit. 786 | pub fn set_bit_fail_limit(&mut self, bit_fail_limit: FannType) { 787 | unsafe { fann_set_bit_fail_limit(self.raw, bit_fail_limit) } 788 | } 789 | 790 | /// Get cascade training parameters. 791 | pub fn get_cascade_params(&self) -> CascadeParams { 792 | unsafe { 793 | let num_af = fann_get_cascade_activation_functions_count(self.raw) as usize; 794 | let af_enum_ptr = fann_get_cascade_activation_functions(self.raw); 795 | let af_enums = Vec::from_raw_parts(af_enum_ptr, num_af, num_af); 796 | let activation_functions = af_enums 797 | .iter() 798 | .map(|&af_enum| ActivationFunc::from_fann_activationfunc_enum(af_enum).unwrap()) 799 | .collect(); 800 | forget(af_enums); 801 | let num_st = fann_get_cascade_activation_steepnesses_count(self.raw) as usize; 802 | let mut activation_steepnesses = Vec::with_capacity(num_st); 803 | let st_ptr = fann_get_cascade_activation_steepnesses(self.raw); 804 | copy_nonoverlapping(st_ptr, activation_steepnesses.as_mut_ptr(), num_st); 805 | activation_steepnesses.set_len(num_st); 806 | CascadeParams { 807 | output_change_fraction: fann_get_cascade_output_change_fraction(self.raw), 808 | output_stagnation_epochs: fann_get_cascade_output_stagnation_epochs(self.raw), 809 | candidate_change_fraction: fann_get_cascade_candidate_change_fraction(self.raw), 810 | candidate_stagnation_epochs: fann_get_cascade_candidate_stagnation_epochs(self.raw), 811 | candidate_limit: fann_get_cascade_candidate_limit(self.raw), 812 | weight_multiplier: fann_get_cascade_weight_multiplier(self.raw), 813 | max_out_epochs: fann_get_cascade_max_out_epochs(self.raw), 814 | max_cand_epochs: fann_get_cascade_max_cand_epochs(self.raw), 815 | activation_functions, 816 | activation_steepnesses, 817 | num_candidate_groups: fann_get_cascade_num_candidate_groups(self.raw), 818 | } 819 | } 820 | } 821 | 822 | /// Set cascade training parameters. 
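    ///
    /// A hedged sketch that overrides a single field and keeps the defaults for everything
    /// else (field names as returned by `get_cascade_params` above):
    ///
    /// ```no_run
    /// use fann::{CascadeParams, Fann};
    ///
    /// let mut fann = Fann::new(&[2, 3, 1]).unwrap();
    /// // Cap the number of epochs spent training the output connections.
    /// fann.set_cascade_params(&CascadeParams {
    ///     max_out_epochs: 200,
    ///     ..CascadeParams::default()
    /// });
    /// ```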
823 | pub fn set_cascade_params(&mut self, params: &CascadeParams) { 824 | let af_enums: Vec<_> = params 825 | .activation_functions 826 | .iter() 827 | .map(|af| af.to_fann_activationfunc_enum()) 828 | .collect(); 829 | unsafe { 830 | fann_set_cascade_output_change_fraction(self.raw, params.output_change_fraction); 831 | fann_set_cascade_output_stagnation_epochs(self.raw, params.output_stagnation_epochs); 832 | fann_set_cascade_candidate_change_fraction(self.raw, params.candidate_change_fraction); 833 | fann_set_cascade_candidate_stagnation_epochs( 834 | self.raw, 835 | params.candidate_stagnation_epochs, 836 | ); 837 | fann_set_cascade_candidate_limit(self.raw, params.candidate_limit); 838 | fann_set_cascade_weight_multiplier(self.raw, params.weight_multiplier); 839 | fann_set_cascade_max_out_epochs(self.raw, params.max_out_epochs); 840 | fann_set_cascade_max_cand_epochs(self.raw, params.max_cand_epochs); 841 | fann_set_cascade_activation_functions( 842 | self.raw, 843 | af_enums.as_ptr(), 844 | af_enums.len() as c_uint, 845 | ); 846 | fann_set_cascade_activation_steepnesses( 847 | self.raw, 848 | params.activation_steepnesses.as_ptr(), 849 | params.activation_steepnesses.len() as c_uint, 850 | ); 851 | fann_set_cascade_num_candidate_groups(self.raw, params.num_candidate_groups); 852 | } 853 | } 854 | 855 | /// Get the currently configured training algorithm. 856 | pub fn get_train_algorithm(&self) -> TrainAlgorithm { 857 | let ft_enum = unsafe { fann_get_training_algorithm(self.raw) }; 858 | match ft_enum { 859 | FANN_TRAIN_INCREMENTAL => unsafe { 860 | TrainAlgorithm::Incremental(IncrementalParams { 861 | learning_momentum: fann_get_learning_momentum(self.raw), 862 | learning_rate: fann_get_learning_rate(self.raw), 863 | }) 864 | }, 865 | FANN_TRAIN_BATCH => unsafe { 866 | TrainAlgorithm::Batch(BatchParams { 867 | learning_rate: fann_get_learning_rate(self.raw), 868 | }) 869 | }, 870 | FANN_TRAIN_RPROP => unsafe { 871 | TrainAlgorithm::Rprop(RpropParams { 872 | decrease_factor: fann_get_rprop_decrease_factor(self.raw), 873 | increase_factor: fann_get_rprop_increase_factor(self.raw), 874 | delta_min: fann_get_rprop_delta_min(self.raw), 875 | delta_max: fann_get_rprop_delta_max(self.raw), 876 | delta_zero: fann_get_rprop_delta_zero(self.raw), 877 | }) 878 | }, 879 | FANN_TRAIN_QUICKPROP => unsafe { 880 | TrainAlgorithm::Quickprop(QuickpropParams { 881 | decay: fann_get_quickprop_decay(self.raw), 882 | mu: fann_get_quickprop_mu(self.raw), 883 | learning_rate: fann_get_learning_rate(self.raw), 884 | }) 885 | }, 886 | } 887 | } 888 | 889 | /// Set the algorithm to be used for training. 
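    ///
    /// A minimal sketch (assuming the parameter structs are re-exported at the crate root,
    /// as the tests in this file suggest):
    ///
    /// ```no_run
    /// use fann::{Fann, IncrementalParams, TrainAlgorithm};
    ///
    /// let mut fann = Fann::new(&[2, 3, 1]).unwrap();
    /// // Plain incremental backpropagation with a custom learning rate.
    /// fann.set_train_algorithm(TrainAlgorithm::Incremental(IncrementalParams {
    ///     learning_rate: 0.9,
    ///     ..Default::default()
    /// }));
    /// ```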
890 | pub fn set_train_algorithm(&mut self, ta: TrainAlgorithm) { 891 | match ta { 892 | TrainAlgorithm::Incremental(params) => unsafe { 893 | fann_set_training_algorithm(self.raw, FANN_TRAIN_INCREMENTAL); 894 | fann_set_learning_momentum(self.raw, params.learning_momentum); 895 | fann_set_learning_rate(self.raw, params.learning_rate); 896 | }, 897 | TrainAlgorithm::Batch(params) => unsafe { 898 | fann_set_training_algorithm(self.raw, FANN_TRAIN_BATCH); 899 | fann_set_learning_rate(self.raw, params.learning_rate); 900 | }, 901 | TrainAlgorithm::Rprop(params) => unsafe { 902 | fann_set_training_algorithm(self.raw, FANN_TRAIN_RPROP); 903 | fann_set_rprop_decrease_factor(self.raw, params.decrease_factor); 904 | fann_set_rprop_increase_factor(self.raw, params.increase_factor); 905 | fann_set_rprop_delta_min(self.raw, params.delta_min); 906 | fann_set_rprop_delta_max(self.raw, params.delta_max); 907 | fann_set_rprop_delta_zero(self.raw, params.delta_zero); 908 | }, 909 | TrainAlgorithm::Quickprop(params) => unsafe { 910 | fann_set_training_algorithm(self.raw, FANN_TRAIN_QUICKPROP); 911 | fann_set_quickprop_decay(self.raw, params.decay); 912 | fann_set_quickprop_mu(self.raw, params.mu); 913 | fann_set_learning_rate(self.raw, params.learning_rate); 914 | }, 915 | } 916 | } 917 | 918 | /// Calculate input scaling parameters for future use based on the given training data. 919 | pub fn set_input_scaling_params( 920 | &mut self, 921 | data: &TrainData, 922 | new_input_min: c_float, 923 | new_input_max: c_float, 924 | ) -> FannResult<()> { 925 | unsafe { 926 | let result = fann_set_input_scaling_params( 927 | self.raw, 928 | data.get_raw(), 929 | new_input_min, 930 | new_input_max, 931 | ); 932 | FannError::check_zero( 933 | result, 934 | self.raw as *mut fann_error, 935 | "Error calculating scaling parameters", 936 | ) 937 | } 938 | } 939 | 940 | /// Calculate output scaling parameters for future use based on the given training data. 941 | pub fn set_output_scaling_params( 942 | &mut self, 943 | data: &TrainData, 944 | new_output_min: c_float, 945 | new_output_max: c_float, 946 | ) -> FannResult<()> { 947 | unsafe { 948 | let result = fann_set_output_scaling_params( 949 | self.raw, 950 | data.get_raw(), 951 | new_output_min, 952 | new_output_max, 953 | ); 954 | FannError::check_zero( 955 | result, 956 | self.raw as *mut fann_error, 957 | "Error calculating scaling parameters", 958 | ) 959 | } 960 | } 961 | 962 | /// Calculate scaling parameters for future use based on the given training data. 963 | pub fn set_scaling_params( 964 | &mut self, 965 | data: &TrainData, 966 | new_input_min: c_float, 967 | new_input_max: c_float, 968 | new_output_min: c_float, 969 | new_output_max: c_float, 970 | ) -> FannResult<()> { 971 | unsafe { 972 | let result = fann_set_scaling_params( 973 | self.raw, 974 | data.get_raw(), 975 | new_input_min, 976 | new_input_max, 977 | new_output_min, 978 | new_output_max, 979 | ); 980 | FannError::check_zero( 981 | result, 982 | self.raw as *mut fann_error, 983 | "Error calculating scaling parameters", 984 | ) 985 | } 986 | } 987 | 988 | /// Clear scaling parameters. 989 | pub fn clear_scaling_params(&mut self) -> FannResult<()> { 990 | unsafe { 991 | FannError::check_zero( 992 | fann_clear_scaling_params(self.raw), 993 | self.raw as *mut fann_error, 994 | "Error clearing scaling parameters", 995 | ) 996 | } 997 | } 998 | 999 | /// Scale data in input vector before feeding it to the network, based on previously calculated 1000 | /// parameters. 
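    ///
    /// A sketch of the intended workflow (scaling parameters must have been calculated first,
    /// e.g. with `set_input_scaling_params`; the data file is the one used by the tests):
    ///
    /// ```no_run
    /// use fann::{Fann, TrainData};
    ///
    /// let mut fann = Fann::new(&[2, 3, 1]).unwrap();
    /// let data = TrainData::from_file("test_files/xor.data").unwrap();
    /// fann.set_input_scaling_params(&data, -1.0, 1.0).unwrap();
    /// // Scale a raw input in place before running it through the network.
    /// let mut input = vec![-1.0, 1.0];
    /// fann.scale_input(&mut input).unwrap();
    /// ```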
1001 | pub fn scale_input(&self, input: &mut [FannType]) -> FannResult<()> { 1002 | unsafe { 1003 | fann_scale_input(self.raw, input.as_mut_ptr()); 1004 | FannError::check_no_error(self.raw as *mut fann_error) 1005 | } 1006 | } 1007 | 1008 | /// Scale data in output vector before feeding it to the network, based on previously calculated 1009 | /// parameters. 1010 | pub fn scale_output(&self, output: &mut [FannType]) -> FannResult<()> { 1011 | unsafe { 1012 | fann_scale_output(self.raw, output.as_mut_ptr()); 1013 | FannError::check_no_error(self.raw as *mut fann_error) 1014 | } 1015 | } 1016 | 1017 | /// Descale data in input vector after feeding it to the network, based on previously calculated 1018 | /// parameters. 1019 | pub fn descale_input(&self, input: &mut [FannType]) -> FannResult<()> { 1020 | unsafe { 1021 | fann_descale_input(self.raw, input.as_mut_ptr()); 1022 | FannError::check_no_error(self.raw as *mut fann_error) 1023 | } 1024 | } 1025 | 1026 | /// Descale data in output vector after getting it from the network, based on previously 1027 | /// calculated parameters. 1028 | pub fn descale_output(&self, output: &mut [FannType]) -> FannResult<()> { 1029 | unsafe { 1030 | fann_descale_output(self.raw, output.as_mut_ptr()); 1031 | FannError::check_no_error(self.raw as *mut fann_error) 1032 | } 1033 | } 1034 | 1035 | // TODO: set_error_log: Always disable, due to different error handling? 1036 | // TODO: save_to_fixed? 1037 | // TODO: user_data methods? 1038 | } 1039 | 1040 | impl Drop for Fann { 1041 | fn drop(&mut self) { 1042 | unsafe { 1043 | fann_destroy(self.raw); 1044 | } 1045 | } 1046 | } 1047 | 1048 | #[cfg(test)] 1049 | mod tests { 1050 | use super::*; 1051 | use fann_sys; 1052 | use libc::c_uint; 1053 | use std::cell::RefCell; 1054 | use std::ptr::null_mut; 1055 | 1056 | const EPSILON: FannType = 0.2; 1057 | 1058 | #[test] 1059 | fn test_tutorial() { 1060 | let max_epochs = 500_000; 1061 | let desired_error = 0.0001; 1062 | let mut fann = Fann::new(&[2, 3, 1]).unwrap(); 1063 | fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); 1064 | fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); 1065 | fann.on_file("test_files/xor.data") 1066 | .train(max_epochs, desired_error) 1067 | .unwrap(); 1068 | assert!(EPSILON > (1.0 - fann.run(&[-1.0, 1.0]).unwrap()[0]).abs()); 1069 | assert!(EPSILON > (1.0 - fann.run(&[1.0, -1.0]).unwrap()[0]).abs()); 1070 | assert!(EPSILON > (-1.0 - fann.run(&[1.0, 1.0]).unwrap()[0]).abs()); 1071 | assert!(EPSILON > (-1.0 - fann.run(&[-1.0, -1.0]).unwrap()[0]).abs()); 1072 | } 1073 | 1074 | #[test] 1075 | fn test_activation_func() { 1076 | let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap(); 1077 | // Don't print the expected errors: 1078 | unsafe { 1079 | fann_sys::fann_set_error_log(fann.raw as *mut fann_sys::fann_error, null_mut()); 1080 | } 1081 | assert!(fann.get_activation_func(0, 1).is_err()); 1082 | assert!(fann.get_activation_func(4, 1).is_err()); 1083 | assert_eq!( 1084 | Ok(ActivationFunc::SigmoidStepwise), 1085 | fann.get_activation_func(2, 2) 1086 | ); 1087 | fann.set_activation_func(ActivationFunc::Sin, 2, 2); 1088 | assert_eq!(Ok(ActivationFunc::Sin), fann.get_activation_func(2, 2)); 1089 | } 1090 | 1091 | #[test] 1092 | fn test_train_algorithm() { 1093 | let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap(); 1094 | assert_eq!(TrainAlgorithm::default(), fann.get_train_algorithm()); 1095 | let quickprop = TrainAlgorithm::Quickprop(QuickpropParams { 1096 | decay: -0.0002, 1097 | ..Default::default() 1098 | }); 1099 | 
fann.set_train_algorithm(quickprop); 1100 | assert_eq!(quickprop, fann.get_train_algorithm()); 1101 | } 1102 | 1103 | #[test] 1104 | fn test_layer_sizes() { 1105 | let fann = Fann::new(&[4, 3, 3, 1]).unwrap(); 1106 | assert_eq!(vec![4, 3, 3, 1], fann.get_layer_sizes()); 1107 | assert_eq!(vec![1, 1, 1, 0], fann.get_bias_counts()); 1108 | } 1109 | 1110 | #[test] 1111 | fn test_get_set_connections() { 1112 | let mut fann = Fann::new(&[1, 1]).unwrap(); 1113 | let connection = Connection { 1114 | from_neuron: 1, 1115 | to_neuron: 2, 1116 | weight: 0.123, 1117 | }; 1118 | fann.set_connections(&[connection]); 1119 | assert_eq!(2, fann.get_total_connections()); // 2 because of the bias neuron in layer 0. 1120 | assert_eq!(connection, fann.get_connections()[1]); 1121 | } 1122 | 1123 | #[test] 1124 | fn test_cascade_params() { 1125 | let fann = Fann::new(&[1, 1]).unwrap(); 1126 | assert_eq!(CascadeParams::default(), fann.get_cascade_params()); 1127 | } 1128 | 1129 | #[test] 1130 | fn test_train_data_from_callback() { 1131 | let mut fann = Fann::new(&[2, 3, 1]).unwrap(); 1132 | fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); 1133 | fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); 1134 | let td = TrainData::from_callback( 1135 | 4, 1136 | 2, 1137 | 1, 1138 | Box::new(|num| match num { 1139 | 0 => (vec![-1.0, 1.0], vec![1.0]), 1140 | 1 => (vec![1.0, -1.0], vec![1.0]), 1141 | 2 => (vec![-1.0, -1.0], vec![-1.0]), 1142 | 3 => (vec![1.0, 1.0], vec![-1.0]), 1143 | _ => unreachable!(), 1144 | }), 1145 | ).unwrap(); 1146 | fann.on_data(&td).train(500_000, 0.0001).unwrap(); 1147 | assert!(EPSILON > (1.0 - fann.run(&[-1.0, 1.0]).unwrap()[0]).abs()); 1148 | assert!(EPSILON > (1.0 - fann.run(&[1.0, -1.0]).unwrap()[0]).abs()); 1149 | assert!(EPSILON > (-1.0 - fann.run(&[1.0, 1.0]).unwrap()[0]).abs()); 1150 | assert!(EPSILON > (-1.0 - fann.run(&[-1.0, -1.0]).unwrap()[0]).abs()); 1151 | } 1152 | 1153 | #[test] 1154 | fn test_train_callback() { 1155 | // Without a hidden layer, the XOR problem cannot be solved, so the training will only stop 1156 | // when the callback says so. 1157 | let mut fann = Fann::new(&[2, 1]).unwrap(); 1158 | fann.set_activation_func_output(ActivationFunc::LinearPiece); 1159 | let xor_data = TrainData::from_file("test_files/xor.data").unwrap(); 1160 | let raw = fann.raw; 1161 | let callback_epochs = RefCell::new(Vec::new()); 1162 | let cb = |fann: &Fann, train_data: &TrainData, epochs: c_uint| { 1163 | assert_eq!(raw, fann.raw); 1164 | unsafe { 1165 | assert_eq!(xor_data.get_raw(), train_data.get_raw()); 1166 | } 1167 | callback_epochs.borrow_mut().push(epochs); 1168 | CallbackResult::stop_if(epochs == 40) // Stop after 40 epochs. 1169 | }; 1170 | fann.on_data(&xor_data) 1171 | .with_callback(10, &cb) 1172 | .train(100, 0.1) 1173 | .unwrap(); 1174 | // The interval was 10 epochs. Also, FANN always runs the callback after the first epoch. 1175 | assert_eq!(vec![1, 10, 20, 30, 40], *callback_epochs.borrow()); 1176 | } 1177 | } 1178 | -------------------------------------------------------------------------------- /src/net_type.rs: -------------------------------------------------------------------------------- 1 | use fann_sys::*; 2 | 3 | /// Network types 4 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 5 | pub enum NetType { 6 | /// Each layer of neurons only has connections to the next layer. 7 | Layer, 8 | /// Each layer has connections to all following layers. 
9 |     Shortcut,
10 | }
11 |
12 | impl NetType {
13 |     /// Create a `NetType` from a `fann_sys::fann_nettype_enum`.
14 |     pub fn from_nettype_enum(nt_enum: fann_nettype_enum) -> NetType {
15 |         match nt_enum {
16 |             FANN_NETTYPE_LAYER => NetType::Layer,
17 |             FANN_NETTYPE_SHORTCUT => NetType::Shortcut,
18 |         }
19 |     }
20 | }
21 |
--------------------------------------------------------------------------------
/src/stop_func.rs:
--------------------------------------------------------------------------------
1 | use fann_sys::*;
2 |
3 | /// Stop criteria for training.
4 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
5 | pub enum StopFunc {
6 |     /// The mean square error of the whole output.
7 |     Mse,
8 |     /// The number of training data points where the output neuron's error was greater than the bit
9 |     /// fail limit. Every neuron is counted for every training data sample where it fails.
10 |     Bit,
11 | }
12 |
13 | impl StopFunc {
14 |     /// Create a `StopFunc` from a `fann_sys::fann_stopfunc_enum`.
15 |     pub fn from_stopfunc_enum(sf_enum: fann_stopfunc_enum) -> StopFunc {
16 |         match sf_enum {
17 |             FANN_STOPFUNC_MSE => StopFunc::Mse,
18 |             FANN_STOPFUNC_BIT => StopFunc::Bit,
19 |         }
20 |     }
21 |
22 |     /// Return the `fann_sys::fann_stopfunc_enum` corresponding to this `StopFunc`.
23 |     pub fn to_stopfunc_enum(self) -> fann_stopfunc_enum {
24 |         match self {
25 |             StopFunc::Mse => FANN_STOPFUNC_MSE,
26 |             StopFunc::Bit => FANN_STOPFUNC_BIT,
27 |         }
28 |     }
29 | }
30 |
--------------------------------------------------------------------------------
/src/train_algorithm.rs:
--------------------------------------------------------------------------------
1 | use libc::c_float;
2 |
3 | #[derive(Copy, Clone, Debug, PartialEq)]
4 | pub struct IncrementalParams {
5 |     /// A higher momentum can be used to speed up incremental training. It should be between 0
6 |     /// and 1; the default is 0.
7 |     pub learning_momentum: c_float,
8 |     /// The learning rate determines how aggressive training should be. The default is 0.7.
9 |     pub learning_rate: c_float,
10 | }
11 |
12 | impl Default for IncrementalParams {
13 |     fn default() -> IncrementalParams {
14 |         IncrementalParams {
15 |             learning_momentum: 0.0,
16 |             learning_rate: 0.7,
17 |         }
18 |     }
19 | }
20 |
21 | #[derive(Copy, Clone, Debug, PartialEq)]
22 | pub struct BatchParams {
23 |     /// The learning rate determines how aggressive training should be. The default is 0.7.
24 |     pub learning_rate: c_float,
25 | }
26 |
27 | impl Default for BatchParams {
28 |     fn default() -> BatchParams {
29 |         BatchParams { learning_rate: 0.7 }
30 |     }
31 | }
32 |
33 | #[derive(Copy, Clone, Debug, PartialEq)]
34 | pub struct RpropParams {
35 |     /// A value less than 1, used to decrease the step size during training. Default 0.5.
36 |     pub decrease_factor: c_float,
37 |     /// A value greater than 1, used to increase the step size during training. Default 1.2.
38 |     pub increase_factor: c_float,
39 |     /// The minimum step size. Default 0.0.
40 |     pub delta_min: c_float,
41 |     /// The maximum step size. Default 50.0.
42 |     pub delta_max: c_float,
43 |     /// The initial step size. Default 0.1.
44 |     pub delta_zero: c_float,
45 | }
46 |
47 | impl Default for RpropParams {
48 |     fn default() -> RpropParams {
49 |         RpropParams {
50 |             decrease_factor: 0.5,
51 |             increase_factor: 1.2,
52 |             delta_min: 0.0,
53 |             delta_max: 50.0,
54 |             delta_zero: 0.1,
55 |         }
56 |     }
57 | }
58 |
59 | #[derive(Copy, Clone, Debug, PartialEq)]
60 | pub struct QuickpropParams {
61 |     /// The factor by which weights should become smaller in each iteration, to ensure that
62 |     /// the weights don't grow too large during training. Should be a negative number close to
63 |     /// 0. The default is -0.0001.
64 |     pub decay: c_float,
65 |     /// The mu factor is used to increase or decrease the step size; it should always be greater
66 |     /// than 1. The default is 1.75.
67 |     pub mu: c_float,
68 |     /// The learning rate determines how aggressive training should be. The default is 0.7.
69 |     pub learning_rate: c_float,
70 | }
71 |
72 | impl Default for QuickpropParams {
73 |     fn default() -> QuickpropParams {
74 |         QuickpropParams {
75 |             decay: -0.0001,
76 |             mu: 1.75,
77 |             learning_rate: 0.7,
78 |         }
79 |     }
80 | }
81 |
82 | /// The training algorithms used when training on `fann_train_data` with functions like
83 | /// `fann_train_on_data` or `fann_train_on_file`. Incremental training alters the weights
84 | /// after each input pattern it is presented with, while batch training alters the weights only
85 | /// once, after it has been presented with all the patterns.
86 | #[derive(Copy, Clone, Debug, PartialEq)]
87 | pub enum TrainAlgorithm {
88 |     /// Standard backpropagation algorithm, where the weights are updated after each training
89 |     /// pattern. This means that the weights are updated many times during a single epoch, and some
90 |     /// problems will train very fast, while other, more advanced problems will not train very well.
91 |     Incremental(IncrementalParams),
92 |     /// Standard backpropagation algorithm, where the weights are updated after calculating the mean
93 |     /// square error for the whole training set. This means that the weights are only updated once
94 |     /// during an epoch. For this reason some problems will train more slowly with this algorithm. But
95 |     /// since the mean square error is calculated more correctly than in incremental training, some
96 |     /// problems will reach better solutions.
97 |     Batch(BatchParams),
98 |     /// A more advanced batch training algorithm which achieves good results for many problems.
99 |     /// `Rprop` is adaptive and therefore does not use the `learning_rate`. Some other parameters
100 |     /// can, however, be set to change the way `Rprop` works, but changing them is only recommended
101 |     /// for users with a deep understanding of the algorithm. The original RPROP training algorithm is
102 |     /// described by [Riedmiller and Braun, 1993], but the algorithm used here is a variant, iRPROP,
103 |     /// described by [Igel and Husken, 2000].
104 |     Rprop(RpropParams),
105 |     /// A more advanced batch training algorithm which achieves good results for many problems. The
106 |     /// quickprop training algorithm uses the `learning_rate` parameter along with other more
107 |     /// advanced parameters, but changing these is only recommended for users with a deep
108 |     /// understanding of the algorithm. Quickprop is described by [Fahlman, 1988].
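    ///
    /// A construction sketch (not from the original docs; it tunes `decay` and keeps the
    /// documented defaults for `mu` and `learning_rate`):
    ///
    /// ```
    /// use fann::{QuickpropParams, TrainAlgorithm};
    ///
    /// let algo = TrainAlgorithm::Quickprop(QuickpropParams {
    ///     decay: -0.01,
    ///     ..Default::default()
    /// });
    /// assert_ne!(algo, TrainAlgorithm::default()); // The default algorithm is Rprop.
    /// ```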
109 |     Quickprop(QuickpropParams),
110 | }
111 |
112 | impl Default for TrainAlgorithm {
113 |     fn default() -> TrainAlgorithm {
114 |         TrainAlgorithm::Rprop(Default::default())
115 |     }
116 | }
117 |
--------------------------------------------------------------------------------
/src/train_data.rs:
--------------------------------------------------------------------------------
1 | extern crate fann_sys;
2 |
3 | use super::{to_filename, Fann};
4 | use error::{FannError, FannErrorType, FannResult};
5 | use fann_sys::*;
6 | use libc::c_uint;
7 | use std::cell::RefCell;
8 | use std::path::Path;
9 | use std::ptr::copy_nonoverlapping;
10 |
11 | pub type TrainCallback = dyn Fn(c_uint) -> (Vec<fann_type>, Vec<fann_type>);
12 |
13 | // Thread-local container for user-supplied callback functions.
14 | // This is necessary because the raw fann_create_train_from_callback C function takes a function
15 | // pointer and not a closure. So instead of the user-supplied function we pass a function to it
16 | // which will call the content of CALLBACK.
17 | thread_local!(static CALLBACK: RefCell<Option<Box<TrainCallback>>> = RefCell::new(None));
18 |
19 | pub struct TrainData {
20 |     raw: *mut fann_train_data,
21 | }
22 |
23 | impl TrainData {
24 |     /// Read a file that stores training data.
25 |     ///
26 |     /// The file must be formatted like:
27 |     ///
28 |     /// ```text
29 |     /// num_train_data num_input num_output
30 |     /// inputdata separated by space
31 |     /// outputdata separated by space
32 |     /// .
33 |     /// .
34 |     /// .
35 |     /// inputdata separated by space
36 |     /// outputdata separated by space
37 |     /// ```
38 |     pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<TrainData> {
39 |         let filename = to_filename(path)?;
40 |         unsafe {
41 |             let raw = fann_read_train_from_file(filename.as_ptr());
42 |             FannError::check_no_error(raw as *mut fann_error)?;
43 |             Ok(TrainData { raw })
44 |         }
45 |     }
46 |
47 |     /// Create training data using the given callback, which for each number between `0` (inclusive)
48 |     /// and `num_data` (exclusive) returns a pair of input and output vectors with `num_input` and
49 |     /// `num_output` entries respectively.
50 |     pub fn from_callback(
51 |         num_data: c_uint,
52 |         num_input: c_uint,
53 |         num_output: c_uint,
54 |         cb: Box<TrainCallback>,
55 |     ) -> FannResult<TrainData> {
56 |         extern "C" fn raw_callback(
57 |             num: c_uint,
58 |             num_input: c_uint,
59 |             num_output: c_uint,
60 |             input: *mut fann_type,
61 |             output: *mut fann_type,
62 |         ) {
63 |             // Call the callback we stored in the thread-local container.
64 |             let (in_vec, out_vec) = CALLBACK.with(|cell| cell.borrow().as_ref().unwrap()(num));
65 |             // Make sure it returned data of the correct size, then copy the data.
66 |             assert_eq!(in_vec.len(), num_input as usize);
67 |             assert_eq!(out_vec.len(), num_output as usize);
68 |             unsafe {
69 |                 copy_nonoverlapping(in_vec.as_ptr(), input, in_vec.len());
70 |                 copy_nonoverlapping(out_vec.as_ptr(), output, out_vec.len());
71 |             }
72 |         }
73 |         unsafe {
74 |             // Put the callback into the thread-local container.
75 |             CALLBACK.with(|cell| *cell.borrow_mut() = Some(cb));
76 |             let raw = fann_create_train_from_callback(
77 |                 num_data,
78 |                 num_input,
79 |                 num_output,
80 |                 Some(raw_callback),
81 |             );
82 |             // Remove it from the thread-local container to free the memory.
83 |             CALLBACK.with(|cell| *cell.borrow_mut() = None);
84 |             FannError::check_no_error(raw as *mut fann_error)?;
85 |             Ok(TrainData { raw })
86 |         }
87 |     }
88 |
89 |     /// Save the training data to a file.
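    ///
    /// A short sketch (the output path is hypothetical):
    ///
    /// ```no_run
    /// use fann::TrainData;
    ///
    /// let data = TrainData::from_file("test_files/xor.data").unwrap();
    /// // Write the same data back out under a new name.
    /// data.save("xor_copy.data").unwrap();
    /// ```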
90 |     pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
91 |         let filename = to_filename(path)?;
92 |         unsafe {
93 |             let result = fann_save_train(self.raw, filename.as_ptr());
94 |             FannError::check_no_error(self.raw as *mut fann_error)?;
95 |             if result == -1 {
96 |                 Err(FannError {
97 |                     error_type: FannErrorType::CantSaveFile,
98 |                     error_str: "Error saving training data".to_owned(),
99 |                 })
100 |             } else {
101 |                 Ok(())
102 |             }
103 |         }
104 |     }
105 |
106 |     /// Merge the given data sets into a new one.
107 |     pub fn merge(data1: &TrainData, data2: &TrainData) -> FannResult<TrainData> {
108 |         unsafe {
109 |             let raw = fann_merge_train_data(data1.raw, data2.raw);
110 |             FannError::check_no_error(raw as *mut fann_error)?;
111 |             Ok(TrainData { raw })
112 |         }
113 |     }
114 |
115 |     /// Create a subset of the training data, starting at the given position and consisting of
116 |     /// `length` samples.
117 |     pub fn subset(&self, pos: c_uint, length: c_uint) -> FannResult<TrainData> {
118 |         unsafe {
119 |             let raw = fann_subset_train_data(self.raw, pos, length);
120 |             FannError::check_no_error(raw as *mut fann_error)?;
121 |             Ok(TrainData { raw })
122 |         }
123 |     }
124 |
125 |     /// Return the number of training patterns in the data.
126 |     pub fn length(&self) -> c_uint {
127 |         unsafe { fann_length_train_data(self.raw) }
128 |     }
129 |
130 |     /// Return the number of input values in each training pattern.
131 |     pub fn num_input(&self) -> c_uint {
132 |         unsafe { fann_num_input_train_data(self.raw) }
133 |     }
134 |
135 |     /// Return the number of output values in each training pattern.
136 |     pub fn num_output(&self) -> c_uint {
137 |         unsafe { fann_num_output_train_data(self.raw) }
138 |     }
139 |
140 |     /// Scale input and output in the training data using the parameters previously calculated for
141 |     /// the given network.
142 |     pub fn scale_for(&mut self, fann: &Fann) -> FannResult<()> {
143 |         unsafe {
144 |             fann_scale_train(fann.raw, self.raw);
145 |             FannError::check_no_error(fann.raw as *mut fann_error)?;
146 |             FannError::check_no_error(self.raw as *mut fann_error)
147 |         }
148 |     }
149 |
150 |     /// Descale input and output in the training data using the parameters previously calculated for
151 |     /// the given network.
152 |     pub fn descale_for(&mut self, fann: &Fann) -> FannResult<()> {
153 |         unsafe {
154 |             fann_descale_train(fann.raw, self.raw);
155 |             FannError::check_no_error(fann.raw as *mut fann_error)?;
156 |             FannError::check_no_error(self.raw as *mut fann_error)
157 |         }
158 |     }
159 |
160 |     /// Scale the inputs in the training data to the specified range.
161 |     pub fn scale_input(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
162 |         unsafe {
163 |             fann_scale_input_train_data(self.raw, new_min, new_max);
164 |             FannError::check_no_error(self.raw as *mut fann_error)
165 |         }
166 |     }
167 |
168 |     /// Scale the outputs in the training data to the specified range.
169 |     pub fn scale_output(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
170 |         unsafe {
171 |             fann_scale_output_train_data(self.raw, new_min, new_max);
172 |             FannError::check_no_error(self.raw as *mut fann_error)
173 |         }
174 |     }
175 |
176 |     /// Scale the inputs and outputs in the training data to the specified range.
177 |     pub fn scale(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
178 |         unsafe {
179 |             fann_scale_train_data(self.raw, new_min, new_max);
180 |             FannError::check_no_error(self.raw as *mut fann_error)
181 |         }
182 |     }
183 |
184 |     /// Shuffle the training data, randomizing the order of the samples. This is recommended for
185 |     /// incremental training, while it has no influence on batch training.
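    ///
    /// A small sketch:
    ///
    /// ```no_run
    /// use fann::TrainData;
    ///
    /// let mut data = TrainData::from_file("test_files/xor.data").unwrap();
    /// // Randomize the sample order in place before incremental training.
    /// data.shuffle();
    /// ```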
186 |     pub fn shuffle(&mut self) {
187 |         unsafe {
188 |             fann_shuffle_train_data(self.raw);
189 |         }
190 |     }
191 |
192 |     /// Get a pointer to the underlying raw `fann_train_data` structure.
193 |     pub unsafe fn get_raw(&self) -> *mut fann_train_data {
194 |         self.raw
195 |     }
196 |
197 |     // TODO: save_to_fixed?
198 | }
199 |
200 | impl Clone for TrainData {
201 |     fn clone(&self) -> TrainData {
202 |         unsafe {
203 |             let raw = fann_duplicate_train_data(self.raw);
204 |             if FannError::check_no_error(raw as *mut fann_error).is_err() {
205 |                 panic!("Unable to clone TrainData.");
206 |             }
207 |             TrainData { raw }
208 |         }
209 |     }
210 | }
211 |
212 | impl Drop for TrainData {
213 |     fn drop(&mut self) {
214 |         unsafe {
215 |             fann_destroy_train(self.raw);
216 |         }
217 |     }
218 | }
219 |
--------------------------------------------------------------------------------
/test_files/xor.data:
--------------------------------------------------------------------------------
1 | 4 2 1
2 | -1 -1
3 | -1
4 | -1 1
5 | 1
6 | 1 -1
7 | 1
8 | 1 1
9 | -1
10 |
--------------------------------------------------------------------------------