├── README.md ├── img └── Equations.PNG └── src ├── Classifier.hpp └── Error.hpp /README.md: -------------------------------------------------------------------------------- 1 | # Competitive Feature Learning 2 | 3 | ## Classification 4 | 5 | A feature is a set of values that becomes active when it receives a sufficiently similar input. For each input, the weighted error between the input and every feature is calculated; the resulting similarity (one minus the error) is compared with the feature's threshold, which determines whether the feature is active or inactive. 6 | 7 | The number of features that can respond to a single input is limited by the class size. When an input activates more features than the class size allows, the most similar features stay active and the rest become inactive. 8 | 9 | ## Learning 10 | 11 | Features are updated in response to each input. Learning occurs in three steps. First, the threshold of every feature is adjusted (whether or not the feature is active) to follow that feature's similarity to the input. Second, the weights of the active features are adjusted to reflect the importance of each value. Finally, the values of the active features are moved toward the input. 
12 | 13 | ![Equations](https://github.com/CarsonScott/Competitive-Feature-Learning/blob/master/img/Equations.PNG) 14 | -------------------------------------------------------------------------------- /img/Equations.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarsonScott/Competitive-Feature-Learning/94def226ec56e8d476df9d83905e13babdc33c2f/img/Equations.PNG -------------------------------------------------------------------------------- /src/Classifier.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CLASSIFIER_HPP_INCLUDED 2 | #define CLASSIFIER_HPP_INCLUDED 3 | 4 | #include "Error.hpp" 5 | struct Category 6 | { 7 | float label; 8 | float error; 9 | Category(int l=0, float e=0) 10 | { 11 | label = l; 12 | error = e; 13 | } 14 | }; 15 | 16 | struct CategorySet 17 | { 18 | Array labels; 19 | Array errors; 20 | 21 | void append(float label, float error) 22 | { 23 | labels.push_back(label); 24 | errors.push_back(error); 25 | } 26 | 27 | void append(Category category) 28 | { 29 | append(category.label, category.error); 30 | } 31 | 32 | void insert(int i, float label, float error) 33 | { 34 | labels.insert(labels.begin() + i, label); 35 | errors.insert(errors.begin() + i, error); 36 | 37 | } 38 | 39 | void erase(int i) 40 | { 41 | labels.erase(labels.begin() + i); 42 | errors.erase(errors.begin() + i); 43 | } 44 | 45 | int size() 46 | { 47 | return labels.size(); 48 | } 49 | 50 | Category operator [](int i) 51 | { 52 | return Category(labels[i], errors[i]); 53 | } 54 | }; 55 | 56 | class Classifier 57 | { 58 | int input_size; 59 | int class_size; 60 | Array inputs; 61 | Array thresholds; 62 | Matrix features; 63 | Matrix weights; 64 | Matrix errors; 65 | Matrix previous_errors; 66 | Matrix progress; 67 | Matrix sample_sizes; 68 | CategorySet active_categories; 69 | float learning_rate; 70 | float decay_rate; 71 | 72 | int get_sorted_index(Array 
arr, float val) 73 | { 74 | int index = 0; 75 | while(index < arr.size()) 76 | { 77 | if(val <= arr[index]) 78 | { 79 | return index; 80 | } 81 | else 82 | { 83 | index += 1; 84 | } 85 | } 86 | return arr.size(); 87 | } 88 | 89 | public: 90 | Classifier(){} 91 | Classifier(int in_size, int cls_size) 92 | { 93 | input_size = in_size; 94 | class_size = cls_size; 95 | learning_rate = 0.0001; 96 | decay_rate = 0.00001; 97 | for(int i = 0; i < input_size; i++) 98 | { 99 | inputs.push_back(0); 100 | } 101 | } 102 | 103 | void create() 104 | { 105 | Array feature; 106 | Array weight; 107 | Array error; 108 | Array previous_error; 109 | Array prog; 110 | for(int i = 0; i < input_size; i++) 111 | { 112 | feature.push_back(0); 113 | weight.push_back(.5); 114 | error.push_back(0); 115 | previous_error.push_back(1); 116 | prog.push_back(0); 117 | } 118 | features.push_back(feature); 119 | weights.push_back(weight); 120 | errors.push_back(error); 121 | previous_errors.push_back(previous_error); 122 | progress.push_back(prog); 123 | thresholds.push_back(.5); 124 | } 125 | 126 | void create(Array feature, Array weight, float threshold) 127 | { 128 | create(); 129 | features.back() = feature; 130 | weights.back() = weight; 131 | thresholds.back() = threshold; 132 | } 133 | 134 | void updateError(int c, int f) 135 | { 136 | previous_errors[c][f] = errors[c][f]; 137 | errors[c][f] = werr(inputs[f], features[c][f], weights[c][f]); 138 | } 139 | 140 | void updateProgress(int c, int f) 141 | { 142 | progress[c][f] = werr(previous_errors[c][f], errors[c][f], weights[c][f]); 143 | } 144 | 145 | void updateWeight(int c, int f) 146 | { 147 | weights[c][f] += learning_rate * ((progress[c][f]+1) / (1+errors[c][f])); 148 | } 149 | 150 | void updateFeature(int c, int f) 151 | { 152 | features[c][f] += learning_rate * errors[c][f]; 153 | } 154 | 155 | public: 156 | CategorySet classify(Array in) 157 | { 158 | inputs = in; 159 | CategorySet categories; 160 | 161 | for(int c = 0; c < 
features.size(); c++) 162 | { 163 | float similarity = 1-wmse(features[c], inputs, weights[c]); 164 | float threshold = thresholds[c]; 165 | 166 | if(similarity >= threshold) 167 | { 168 | float error = 1 - similarity; 169 | int index = get_sorted_index(categories.errors, error); 170 | if(categories.errors.size() > 0) 171 | { 172 | if(index < categories.errors.size()) 173 | { 174 | categories.insert(index, c, error); 175 | if(categories.errors.size() > class_size) 176 | { 177 | categories.erase(categories.errors.size()-1); 178 | } 179 | } 180 | } 181 | else 182 | { 183 | categories.append(c, error); 184 | } 185 | } 186 | thresholds[c] -= decay_rate; 187 | if(thresholds[c] > 1) 188 | { 189 | thresholds[c] = 1; 190 | } 191 | else if(thresholds[c] < 0) 192 | { 193 | thresholds[c] = 0; 194 | } 195 | 196 | } 197 | active_categories = categories; 198 | return active_categories; 199 | } 200 | 201 | void update() 202 | { 203 | for(int i = 0; i < active_categories.size(); i++) 204 | { 205 | int c = active_categories[i].label; 206 | thresholds[c] += learning_rate; 207 | for(int f = 0; f < input_size; f++) 208 | { 209 | updateError(c, f); 210 | updateProgress(c, f); 211 | updateWeight(c, f); 212 | updateFeature(c, f); 213 | 214 | if(weights[c][f] < 0) 215 | { 216 | weights[c][f] = 0; 217 | } 218 | else if(weights[c][f] > 1) 219 | { 220 | weights[c][f] = 1; 221 | } 222 | } 223 | } 224 | } 225 | 226 | 227 | Array decode(CategorySet c) 228 | { 229 | Array rep; 230 | for(int i = 0; i < inputs.size(); i++) 231 | { 232 | float val = 0; 233 | 234 | if(c.labels.size() > 0) 235 | { 236 | for(int j = 0; j < c.labels.size(); j++) 237 | { 238 | int index = c.labels[j]; 239 | val += features[index][i] * weights[index][i]; 240 | } 241 | val /= c.labels.size(); 242 | } 243 | 244 | rep.push_back(val); 245 | } 246 | 247 | return rep; 248 | } 249 | }; 250 | 251 | #endif // CLASSIFIER_HPP_INCLUDED 252 | -------------------------------------------------------------------------------- 
#ifndef ERROR_HPP_INCLUDED
#define ERROR_HPP_INCLUDED

#include <cassert>
#include <cmath>
#include <vector>

// Vector/matrix types shared by the classifier code. Typedefs replace the
// former macros so the names obey scope rules and appear in diagnostics.
typedef std::vector<float> Array;
typedef std::vector<Array> Matrix;

// All functions are `inline` so this header can be included from several
// translation units without multiple-definition link errors.

/// Signed error x1 - x2.
inline float err(float x1, float x2)
{
    return x1 - x2;
}

/// Weighted signed error (x1 - x2) * w.
inline float werr(float x1, float x2, float w)
{
    return err(x1, x2) * w;
}

/// Squared error (x1 - x2)^2.
inline float mse(float x1, float x2)
{
    const float e = err(x1, x2);
    return e * e;
}

/// Weighted squared error (x1 - x2)^2 * w.
inline float wmse(float x1, float x2, float w)
{
    return mse(x1, x2) * w;
}

/// Mean signed error over element pairs. x2 must be as long as x1.
/// Returns 0 for empty input (instead of 0/0 = NaN).
inline float err(const Array& x1, const Array& x2)
{
    assert(x2.size() == x1.size());
    if (x1.empty()) return 0;
    float e = 0;
    for (Array::size_type i = 0; i < x1.size(); i++)
    {
        e += err(x1[i], x2[i]);
    }
    return e / static_cast<float>(x1.size());
}

/// Mean weighted signed error. x2 and w must be as long as x1.
/// Returns 0 for empty input.
inline float werr(const Array& x1, const Array& x2, const Array& w)
{
    assert(x2.size() == x1.size() && w.size() == x1.size());
    if (x1.empty()) return 0;
    float e = 0;
    for (Array::size_type i = 0; i < x1.size(); i++)
    {
        e += werr(x1[i], x2[i], w[i]);
    }
    return e / static_cast<float>(x1.size());
}

/// Mean squared error. x2 must be as long as x1. Returns 0 for empty input.
inline float mse(const Array& x1, const Array& x2)
{
    assert(x2.size() == x1.size());
    if (x1.empty()) return 0;
    float e = 0;
    for (Array::size_type i = 0; i < x1.size(); i++)
    {
        e += mse(x1[i], x2[i]);
    }
    return e / static_cast<float>(x1.size());
}

/// Mean weighted squared error. x2 and w must be as long as x1.
/// Returns 0 for empty input.
inline float wmse(const Array& x1, const Array& x2, const Array& w)
{
    assert(x2.size() == x1.size() && w.size() == x1.size());
    if (x1.empty()) return 0;
    float e = 0;
    for (Array::size_type i = 0; i < x1.size(); i++)
    {
        e += wmse(x1[i], x2[i], w[i]);
    }
    return e / static_cast<float>(x1.size());
}

#endif // ERROR_HPP_INCLUDED