├── .gitignore ├── README.md ├── eval ├── README.md ├── matrix.h ├── model.h ├── run.c └── utils.h ├── models └── model.def └── train ├── README.md ├── model.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # virtualenv 3 | venv/ 4 | ENV/ 5 | 6 | # ide 7 | .cproject 8 | .project 9 | .settings 10 | .pydevproject 11 | 12 | # random executables 13 | run 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tinier-nn 2 | 3 | A tinier framework for deploying binarized neural networks on embedded systems. 4 | 5 | I wrote a better version in C++ as a gist [here](https://gist.github.com/codekansas/3cb447e3d95ccac4c5a56ea7ffb079ce). 6 | 7 | ## About 8 | 9 | The core of this framework is the use of the Binarized Neural Network (BNN) described in [Binarized Neural Networks: 10 | Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830). 11 | This architecture seemed ideal for use with embedded systems such as an Arduino (or Raspberry Pi), but to my knowledge 12 | no such implementation was already available. 13 | 14 | The system consists of two parts: 15 | - `train`: TensorFlow code for building a BNN and saving it in a simple text format. 16 | - `eval`: The inference part, which runs on the embedded system, is written in straight C. It reads the model into SRAM and 17 | performs matrix multiplications using a bitwise XOR, which (probably) leads to a big improvement in time and power 18 | consumption (although I haven't benchmarked anything). 19 | 20 | The two sample scripts, `train/model.py` and `eval/run.c`, demonstrate how to train a model to learn an XOR function. The model uses a lot more weights than would theoretically be necessary for this task, but together they demonstrate how to adapt the code to other use cases.
21 | 22 | ## Demo 23 | 24 | To run the demo, run: 25 | 26 | make eval/run 27 | cat models/model.def | eval/run 28 | 29 | The outputs show the predictions for an XOR function. 30 | 31 | To train the model, run: 32 | 33 | python train/model.py --save_path models/model.def 34 | 35 | This is how the `models/model.def` file was generated. 36 | 37 | ## Math Notes 38 | 39 | Weights and activations with values of -1 and 1 are encoded as binary values: `-1 -> 0, 1 -> 1`. Matrix multiplication 40 | is then done using the XOR operation. Here's an example: 41 | 42 | Using weights and activations with values of -1 and 1: 43 | 44 | - Vector-matrix operation is `[1, -1] * [1, -1; -1, 1] = [1 * 1 + -1 * -1, 1 * -1 + -1 * 1] = [2, -2]` 45 | - After applying the binary activation function `x > 0 ? 1 : -1` gives `[1, -1]` 46 | 47 | Using the equivalent binary (0/1) encoding: 48 | 49 | - Encoding the inputs and weights as bits: `[1, 0] * [1, 0; 0, 1]` 50 | - Applying XOR + sum (which counts mismatched bits): `[1 ^ 1 + 0 ^ 0; 1 ^ 0 + 0 ^ 1] = [0, 2]` 51 | - The activation function then becomes `x < (2 / 2) ? 1 : 0` (the threshold is half the input length), which gives `[1, 0]` 52 | 53 | Because the operations are done this way, I made it so that matrix dimensions must be multiples of the integer size. 54 | Padding can be used to make data line up correctly (although if someone wants to change this, LMK). 55 | 56 | ## To Do 57 | 58 | - On most Arduinos, Flash memory is bigger than SRAM by about a factor of 32. So it's not too bad to encode models 59 | as characters instead of bits (and it makes them easier to debug), although this is something that could be improved. 60 | - More examples would be awesome. 61 | - Matrix multiplication could be better, maybe. 62 | - Architectures besides feed-forward networks would be good. 63 | 64 | This was a project that I worked on for CalHacks 3.0 (although I never submitted it).
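The correspondence in the Math Notes section can be sanity-checked with a few lines of Python (a standalone sketch, not part of this repo): for length-n vectors over {-1, 1}, the true dot product equals n minus twice the popcount of the XOR of the packed bitmasks, which is why thresholding the mismatch count at half the input length reproduces the sign activation.

```python
# Sanity check for the Math Notes above (standalone; not part of the
# repo). With the encoding -1 -> 0, 1 -> 1, the dot product of two
# {-1, 1} vectors of length n is n - 2 * popcount(a XOR b).

def dot_pm1(a, b):
    """Dot product of two -1/1 vectors, computed directly."""
    return sum(x * y for x, y in zip(a, b))

def pack(v):
    """Packs a -1/1 vector into an int bitmask (-1 -> 0, 1 -> 1)."""
    bits = 0
    for x in v:
        bits = (bits << 1) | (1 if x > 0 else 0)
    return bits

def dot_xor(a, b):
    """The same dot product, via XOR + popcount on the packed bits."""
    mismatches = bin(pack(a) ^ pack(b)).count('1')
    return len(a) - 2 * mismatches

# The worked example from above: [1, -1] against the rows of
# [1, -1; -1, 1] gives [2, -2] either way.
x = [1, -1]
for row, expected in [([1, -1], 2), ([-1, 1], -2)]:
    assert dot_pm1(x, row) == dot_xor(x, row) == expected
```

This is exactly the trick `matmul` in `eval/matrix.h` uses: it never reconstructs the ±1 values, it just counts mismatched bits and compares the count against half the input length.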
65 | -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Evaluate trained BNN 2 | 3 | In principle, the training and inference steps for the BNN should be separate from one another. The files in this directory provide code for performing low-level inference on embedded devices, given a model that has been trained elsewhere. All of the memory for a model is allocated up front, when its layers are built. 4 | 5 | `tinier-nn/models/model.def` (the sample definition) contains weights trained for an XOR function. The code in this directory demonstrates how to initialize a model and perform a classification task. These weights were trained and saved using the train/model.py script. 6 | 7 | To run normally, the model needs to be piped to stdin (weight loading can be configured differently, depending on the application). From this directory, run the following: 8 | 9 | make run # Builds the executable. 10 | cat ../models/model.def | ./run -------------------------------------------------------------------------------- /eval/matrix.h: -------------------------------------------------------------------------------- 1 | /* 2 | * matrix.h 3 | * 4 | * Performs vector-matrix operations associated with the Binarized Neural 5 | * Network architecture. 6 | */ 7 | 8 | #ifndef MATRIX_H_ 9 | #define MATRIX_H_ 10 | 11 | #include "utils.h" 12 | #include <stdlib.h> 13 | 14 | 15 | typedef struct _matrix { 16 | /* Defines a matrix struct. */ 17 | dim_t w, h; 18 | data_t *data; 19 | } matrix; 20 | 21 | 22 | typedef struct _vector { 23 | /* Defines a vector struct. */ 24 | dim_t h; 25 | data_t *data; 26 | } vector; 27 | 28 | 29 | void instantiate_matrix(matrix *m, int w, int h) { 30 | /* Instantiates a matrix with width w and height h. 31 | * 32 | * This should only be called once; it provides the buffer, 33 | * and data is fed into it.
34 | */ 35 | if (w % INT_SIZE != 0 || h % INT_SIZE != 0) { 36 | log_str("Invalid matrix size requested.\n"); 37 | exit_failure(); 38 | } 39 | 40 | m->w = w; 41 | m->h = h; 42 | m->data = allocate_memory(w * h * INT_SIZE); 43 | } 44 | 45 | 46 | void instantiate_vector(vector *v, int h) { 47 | /* Instantiates a vector with height h. 48 | * 49 | * This should only be called once; it provides the buffer, 50 | * and data is fed into it. 51 | */ 52 | if (h % INT_SIZE != 0) { 53 | log_str("Invalid vector size requested.\n"); 54 | exit_failure(); 55 | } 56 | 57 | v->h = h; 58 | v->data = allocate_memory(h); 59 | } 60 | 61 | 62 | data_t bitsum(data_t x) { 63 | /* Calculates the bitsum of a data_t, i.e. the number of 1 bits. 64 | * This is useful for matrix multiplication. 65 | */ 66 | data_t c = 0; 67 | for (int i = 0; i < INT_SIZE; i++) { 68 | c += (x & 1); 69 | x >>= 1; 70 | } 71 | return c; 72 | } 73 | 74 | 75 | void matmul(vector *from, vector *to, matrix *by) { 76 | /* Multiplies binary vector "from" into "by" and puts the 77 | * result in "to". It is assumed that from.h == by.w 78 | * and to.h == by.h. 79 | * 80 | * The data in the matrix should be row-major, i.e. all w entries 81 | * of a row are stored before the next of the h rows. 82 | */ 83 | dim_t from_h = from->h / INT_SIZE, to_h = to->h / INT_SIZE; 84 | for (dim_t i = 0; i < to_h; i++) { 85 | data_t d = 0; 86 | for (dim_t j = 0; j < INT_SIZE; j++) { 87 | d <<= 1; 88 | data_t c = 0; 89 | for (dim_t k = 0; k < from_h; k++) { 90 | c += bitsum(by->data[((i * INT_SIZE) + j) * from_h + k] ^ 91 | from->data[k]); 92 | } 93 | 94 | // Threshold: fewer than half the bits mismatched means a positive pre-activation. 95 | d |= (c >= from->h / 2) ? 0 : 1; 96 | } 97 | to->data[i] = d; 98 | } 99 | } 100 | 101 | 102 | void print_mat(matrix *m) { 103 | for (dim_t i = 0; i < m->h; i++) { 104 | for (dim_t j = 0; j < m->w; j++) { 105 | log_str((m->data[(i * m->w + j) / INT_SIZE] & 106 | (1 << (j % INT_SIZE))) ?
"1" : "0"); 107 | } 108 | log_str("\n"); 109 | } 110 | } 111 | 112 | 113 | void print_vec(vector *v) { 114 | for (dim_t i = 0; i < v->h; i++) { 115 | log_str((v->data[i / INT_SIZE] & 116 | (1 << (i % INT_SIZE))) ? "1" : "0"); 117 | } 118 | log_str("\n"); 119 | } 120 | 121 | 122 | #endif /* MATRIX_H_ */ 123 | -------------------------------------------------------------------------------- /eval/model.h: -------------------------------------------------------------------------------- 1 | /* 2 | * model.h 3 | * 4 | * Components associated with the neural network architecture. 5 | */ 6 | 7 | #ifndef MODEL_H_ 8 | #define MODEL_H_ 9 | 10 | #include "matrix.h" 11 | 12 | 13 | typedef struct _dense { 14 | matrix weights; 15 | vector outputs; 16 | struct _dense *next; 17 | } dense; 18 | 19 | 20 | void build_layer(dense *buffer, dim_t num_input, dim_t num_output) { 21 | /* Builds a single feedforward dense layer. 22 | * 23 | * Args: 24 | * buffer: pointer to location to hold data 25 | * num_input: number of input dimensions 26 | * num_output: number of output dimensions 27 | * 28 | * The buffer is filled with a dense binary layer which performs a 29 | * linear transform from num_input to num_output 30 | * dimensions. 31 | */ 32 | 33 | instantiate_matrix(&buffer->weights, num_input, num_output); 34 | instantiate_vector(&buffer->outputs, num_output); 35 | buffer->next = NULL; 36 | } 37 | 38 | 39 | vector* get_result(vector *input, dense *head) { 40 | /* Passes through every layer in the network and returns a pointer to the 41 | * result. 42 | * 43 | * Args: 44 | * input: pointer to the input vector (to be processed) 45 | * head: pointer to the first dense layer in the network. 46 | * 47 | * Returns: 48 | * A pointer to the vector containing the output data.
49 | */ 50 | 51 | while (head != NULL) { 52 | matmul(input, &head->outputs, &head->weights); 53 | input = &head->outputs; 54 | head = head->next; 55 | } 56 | return input; 57 | } 58 | 59 | 60 | int load_model(dense *buffer, unsigned max_layers) { 61 | /* Loads a model into SRAM from a text file. 62 | * 63 | * Returns: 64 | * number of layers in the model (-1 if there was an error); 65 | */ 66 | 67 | if (next_char() != 'b' || next_char() != 'n' || next_char() != 'n') { 68 | log_str("Invalid magic string."); 69 | exit_failure(); 70 | } 71 | 72 | // Holds the dimensions of the layer. 73 | dim_t w, h; 74 | 75 | for (int layer = 0; layer < max_layers; layer++) { 76 | 77 | // Reads in width and height; 78 | get_dims(&w, &h); 79 | if (w == 0 || h == 0) { 80 | return layer; 81 | } 82 | 83 | // Instantiate layer itself. 84 | build_layer(&buffer[layer], w, h); 85 | 86 | // Connect previous layer to current layer. 87 | if (layer > 0) { 88 | buffer[layer-1].next = &buffer[layer]; 89 | } 90 | 91 | // Reads in data for the current model. 92 | dim_t max_v = (buffer[layer].weights.w * 93 | buffer[layer].weights.h) / INT_SIZE; 94 | for (dim_t i = 0; i < max_v; i++) { 95 | data_t d = 0; 96 | dim_t j = 0; 97 | while (j < INT_SIZE) { 98 | char c = next_char(); 99 | if (c == '\n') { 100 | continue; 101 | } 102 | 103 | d <<= 1; 104 | if (c == '1') { 105 | d |= 1; 106 | } 107 | j++; 108 | } 109 | buffer[layer].weights.data[i] = d; 110 | } 111 | } 112 | 113 | return -1; 114 | } 115 | 116 | 117 | #endif /* MODEL_H_ */ 118 | -------------------------------------------------------------------------------- /eval/run.c: -------------------------------------------------------------------------------- 1 | /* 2 | * run.c 3 | * 4 | * Sample script for evaluating a trained model. 5 | */ 6 | 7 | #include "model.h" 8 | 9 | #define MAX_LAYERS 10 10 | 11 | vector input_vec[4]; 12 | dense layers[MAX_LAYERS]; 13 | 14 | int main() { 15 | 16 | // Builds the network. 
17 | int num_layers = load_model(layers, MAX_LAYERS); // Load from stdin. 18 | printf("%d layers in model.\n", num_layers); 19 | 20 | // Run on XOR input vectors. 21 | for (int i = 0; i < 4; i++) { 22 | 23 | // Actually allocates space for the vector itself. 24 | instantiate_vector(&input_vec[i], 32); 25 | 26 | // Adds the XOR data (0 -> 00, 1 -> 01, 2 -> 10, 3 -> 11). 27 | input_vec[i].data[0] = i << 30; 28 | 29 | // Runs the network and prints the output. 30 | vector *result = get_result(&input_vec[i], layers); 31 | print_vec(result); 32 | } 33 | 34 | return 0; 35 | } 36 | -------------------------------------------------------------------------------- /eval/utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * utils.h 3 | * 4 | * When porting to a different embedded system, these values should be updated 5 | * for the particular compiler. 6 | */ 7 | 8 | #ifndef UTILS_H_ 9 | #define UTILS_H_ 10 | 11 | #include <stdio.h> 12 | #include <stdlib.h> 13 | 14 | /* 15 | * Defines types associated with matrices / vectors. 16 | */ 17 | typedef unsigned int dim_t; 18 | typedef unsigned int data_t; 19 | typedef int bool_t; 20 | 21 | /* 22 | * Defines the number of bits in a single int. 23 | */ 24 | #define INT_SIZE (sizeof(int) * 8) 25 | 26 | /* 27 | * Failure exit. Uses a standard library call, so it may need to be replaced 28 | * in order to work on other platforms. 29 | */ 30 | void exit_failure() { 31 | exit(EXIT_FAILURE); 32 | } 33 | 34 | /* 35 | * Generic operation for logging string information. It could just be ignored.
36 | */ 37 | void log_str(const char *x) { 38 | fprintf(stderr, "%s", x); 39 | } 40 | 41 | data_t* allocate_memory(dim_t n) { 42 | if (n % INT_SIZE != 0) { 43 | log_str("Invalid shape requested for memory allocation.\n"); 44 | exit_failure(); 45 | } 46 | return (data_t*) malloc(n / INT_SIZE); 47 | } 48 | 49 | /* 50 | * next_char and next_int define generic operations for reading and writing 51 | * from some data source, for loading models. These operations could be 52 | * changed to read from a serial port. 53 | */ 54 | char next_char() { 55 | return getchar(); 56 | } 57 | 58 | void get_dims(dim_t *w, dim_t *h) { 59 | scanf("%d,%d", w, h); 60 | } 61 | 62 | #endif /* UTILS_H_ */ 63 | -------------------------------------------------------------------------------- /models/model.def: -------------------------------------------------------------------------------- 1 | bnn 2 | 32,64 3 | 01011011100100010110101001110101 4 | 10000111111111100010011001111100 5 | 11011011000001111011000111001110 6 | 11000110001111111100011101001000 7 | 01110011000000101010001110010111 8 | 10101111010011111011011010101010 9 | 11100100000100100001010110111000 10 | 00010101000001111010000011100110 11 | 00010110101110011110111001100011 12 | 01110100000011100000111010110110 13 | 10110100001101101110111110101111 14 | 00010011010111001001011101011001 15 | 00100100000110101001001100010011 16 | 00001000001000111100101100011100 17 | 00111100010001000010101001011001 18 | 10001010011100010110110110101011 19 | 00110101001000101011100011000101 20 | 01000100111111011011100010000001 21 | 01001111100100011111110010110111 22 | 11010100011101110001010000110110 23 | 01100001011010001111000001100000 24 | 11100011000000111111011001001111 25 | 00101111010010000000010110100001 26 | 10000001010100101011011011000000 27 | 00111101100111100010011011101100 28 | 10110010101001111011100001000000 29 | 01110010101000000010110011000110 30 | 10011000011011110111011101000111 31 | 01101000111101111001001010000000 32 | 
11001011110110100100100000111100 33 | 01000010001000001000010011100110 34 | 11100101110010110000101000110110 35 | 11101010110000011110101011011000 36 | 10010010001111111001001101101111 37 | 11000000001001010001011000110000 38 | 01101101100110001111010001011111 39 | 00111000100001001110001110111000 40 | 00011000001010001100111111100100 41 | 11101101110100001101101000000101 42 | 01100000001100111011100110011001 43 | 00110101110101101001001010000110 44 | 01111001000010111101111011101011 45 | 01010100101101011001010111000100 46 | 10100010000100011101101010110111 47 | 00110001101110111011011101000101 48 | 00101101101100011101110101100111 49 | 11001011111011001000011011100011 50 | 00000010011110000101011001000001 51 | 00000110101101001110011010000110 52 | 01111000111101101011111100010011 53 | 11001010000001011101001011111011 54 | 01010100001000111100011010101001 55 | 00000001110110011000011001100011 56 | 10001101101011011110000000110111 57 | 00011011000100101011001001111001 58 | 01100100100100110010000101111011 59 | 01101110111011010001101010000110 60 | 01101011010111110001100010100001 61 | 10001101001010100001110001110001 62 | 11101011111110001111001110001111 63 | 00011101001101100101110000000101 64 | 01001100100010010000000011010111 65 | 01010011100110011011110101100100 66 | 10101000000111000001110110110011 67 | 64,64 68 | 0110000001101101100001011011010110110001111011111000110111000011 69 | 1001010100001101100111011100110110010111000001110110101101000111 70 | 1110011110100011101000111110000101010111000111010011010011100011 71 | 0011110010100110010010101001001100010000000011001010110100011100 72 | 1110110011100011110001100111101101001000111011100111111110000111 73 | 1000101110100001001110000001111000111111100101110110100011001110 74 | 1111101011100000001101010100001110010000100000110010010101011010 75 | 0111110001001111011001010001110100110011000010110010011000000001 76 | 0001001100001110010111011001011111100111011001001001011001001111 77 | 
1100010001100011100110101011101110001111010011110100110011101110 78 | 1000010110101001111100110000111110000110000001101010110010100010 79 | 0000111000010110010110011011001010101011100011110000110001011011 80 | 1011100111100101110001001111100011010100011010100101001000011100 81 | 1001110000010010111100110100111110000101011110111011001110010001 82 | 0001101110001001110101100100010011101000001010111010101110000101 83 | 0010100101000010000000111011010011101001101001111110101111000000 84 | 1000110000101111110110010101111110000001101011000011100101110101 85 | 0100000001010010101101000010100010010011010100010000101000001101 86 | 0101111110101001011010111101100100011010110011010100000100011110 87 | 1010001000110110111001001010001001100101100011001110001110000111 88 | 0111000011100100011001111111011010111001011101100101110100101000 89 | 1000110110000010101111110001011111001001010100001010010010111010 90 | 0101001101011011011000111111100000000101000010010111111001111110 91 | 0110000111010000010101000000111000001010011100111001011100111000 92 | 1100001000011010000001001011111110101101001011000011000010100100 93 | 0100111100101100110011111000110010011110101100100010000100111001 94 | 1011101100011010100100000111101111001101011111011101100101110001 95 | 0000111101111100100111100010101011110001011100000011011011111010 96 | 1001011000011101100011101101101011001100111110011110100111101111 97 | 0011001011001011011001011111011110110110010011100100100001000100 98 | 0011111111111010001001010010110011111001010110100100100101100101 99 | 1111101011010101100011011100100001111000100111101011011011100011 100 | 0001000101011111011010111111101110001101001101111111100010001100 101 | 1101001100011110101110101001001111001010100101000100110111110000 102 | 1000010000011111100100010001100110111100001000011110000011010000 103 | 0110000000011001101010000110010110101010000011001011101011011000 104 | 1101001001101111000111101111110010110010000011011111111101001011 105 | 
0100100011101000011010010011101000110101010100100000101110000001 106 | 1000101011110001011100011000000011110101010011100101110100110100 107 | 0111000111001100011001011010111000101001101111100000001100010111 108 | 1000010001100111101110001101111110001100010010011011110010011010 109 | 0000101010110000100100101110001111001111011011101111010001001100 110 | 1100001001101101111000100111011110111111001101101001100101100010 111 | 1001100101111010100101111011110000100001011000000100100111110111 112 | 1100010101111001100000111110011010010010100011000100010101111001 113 | 0111000101100111000101001011000101101000110110011100111000110001 114 | 0010001100000111100011110111100000111001010100110010100001000100 115 | 0010011011001001111010111011011000111000010111011010010010011100 116 | 1101110101001001001011101001000111111001101101110000010011100000 117 | 0000111101011110110101010101001000100110101011101110110010001100 118 | 0011101011000101001010001001010101110110101001000110001001001011 119 | 0100100011000110001111100010011111000010001100010111001111100001 120 | 0011110011101101101010001101011100001110100110001101010001000110 121 | 0110100010011100001001011010010111110110000011100101010010010011 122 | 1000001100010011111000100001111100011101100110000111110110001010 123 | 1101100000010011010100110010101110011110011000010101000111101011 124 | 1101001100000000110010100111111101110010001011000001010000101101 125 | 0101000111110100011010110011100001111000100100001101011100010110 126 | 0101101100101100011011000010001101100001101110100101001000000000 127 | 1100010110011101010000100010011000001000111111110101010000001101 128 | 1010001000111110110111011101111010011010100111010000101001101110 129 | 0001110110110111010110001001010111000110101001011000101011010111 130 | 0111111000101010000110000001101001110111001011100000011011111011 131 | 1101110100001001111110100010010101000111110010001111010101001011 132 | 64,32 133 | 0001001110010000110110010001001000100100000000001111000011110001 134 | 
0001101010010000100011001100001000000001001001010100111001101100 135 | 1010100100010011110001110111111100100110011011100011101001101010 136 | 0101001101011100010000011100001111011111000000000101110001101000 137 | 0001010110100001011110011110111011000111011110111010011010110000 138 | 0100101110001111010100010011110010100011100111101010110111100000 139 | 1000100000000101110110110010011110010101100000001001100011101011 140 | 0011000001000110100110111000111101000100001110010010000111101100 141 | 0011101110011100001010001011010101000101000100100101001001010001 142 | 1000001110001001100111111001001010001110011011001000010110110001 143 | 0011001000110111011001110001010010000110001100111011111101100010 144 | 1100111111010001011010001100100011000101001101100010010011110010 145 | 1001111001110001111100100101101000010100101111011011101101110010 146 | 1001101110010000111100000101011010100101011101111111011010100110 147 | 1010011101000011100110110100001111100111101101001001000101111001 148 | 0010100010011001111110001111011110100001000110100011111101100100 149 | 1100100111001100110110101011111100100111010100101011010001110000 150 | 1000010110010001110110111110000100001101000110011011111000100011 151 | 1000011001010100011111000111011010110001000100001000010001111000 152 | 1011100110011111101010010111111110110101001000100010110101000101 153 | 1001101100010100011110100000011001000110001000100101001111011010 154 | 0001000110001011011011111000001000001000001100011110110111111110 155 | 1010100000000101010000000100001000000111000101011101000110101110 156 | 0011001111010001111010011011111110100101011111000110010011101110 157 | 1111001010000110000000010110010001001101011111001111100001111000 158 | 1001000010000111010110010011111100011110001010100110010101100110 159 | 0010011010000011010110010110101000001111010001000010111110110001 160 | 1010001011010001001010010110010110000101011000001100010100111010 161 | 1011001110001000100111101111011110110101010110101110000011101011 162 | 
0001101100010101110101010100001010001001000110001000001000010100 163 | 0100010010001100110001000111000011111111001111101100010000001001 164 | 0010100010000101101110110110011111011101001100111000101101100111 165 | 0,0 -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # Train the BNN 2 | 3 | Training should be done off the embedded device. This directory provides a framework for defining and training BNNs in TensorFlow, then saving the weights so that they can be ported to an embedded device. 4 | 5 | model.py contains the model creation code for the binary NN. 6 | 7 | Reference paper: [https://arxiv.org/pdf/1602.02830v3.pdf](https://arxiv.org/pdf/1602.02830v3.pdf) 8 | 9 | The script itself trains a binary network to perform a simple XOR classification task (although this could be extended to other tasks). The weights are saved to an output file which can be read by the C evaluation code (in particular, eval/run.c demonstrates how to build input vectors and run them through the network). 10 | 11 | From start to finish, models can be trained and evaluated using: 12 | 13 | make eval/run 14 | python train/model.py --save_path models/model.def 15 | cat models/model.def | eval/run 16 | -------------------------------------------------------------------------------- /train/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import os 9 | import warnings 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | 15 | def add_binary_layer(prev_layer, num_outputs, inference, uid): 16 | """Adds a binary layer to the network. 17 | 18 | Args: 19 | prev_layer: 2D Tensor with shape (batch_size, num_hidden).
20 | num_outputs: int, number of dimensions in next hidden layer. 21 | inference: bool scalar Tensor; when True, stochastic binarization 22 | is used (main() feeds True during training). 23 | uid: int, the unique identifier. 24 | 25 | Returns: 26 | tuple (activation, updates) 27 | activation: output of this layer. 28 | updates: variables used to compute gradients. 29 | """ 30 | 31 | batch_size, num_inputs = [dim.value for dim in prev_layer.get_shape()] 32 | 33 | def _pos_neg_like(tensor): 34 | return (tf.constant(1, dtype=tf.float32, shape=tensor.get_shape()), 35 | tf.constant(-1, dtype=tf.float32, shape=tensor.get_shape())) 36 | 37 | with tf.variable_scope('binary_layer_%d' % uid) as vs: 38 | w = tf.get_variable('w', shape=(num_inputs, num_outputs), 39 | initializer=tf.random_normal_initializer(seed=uid), 40 | trainable=True) 41 | 42 | # Binarize the weight matrix. 43 | bin_w = tf.select(w > 0, *_pos_neg_like(w)) 44 | 45 | # Output has shape (batch_size, num_outputs) 46 | wx = tf.matmul(prev_layer, bin_w) 47 | 48 | with tf.variable_scope('activation'): 49 | p = tf.tanh(wx) 50 | 51 | def _binomial_activation(): 52 | sigma = tf.random_uniform( 53 | p.get_shape(), seed=uid) < (p + 1) / 2 54 | return tf.select(sigma, *_pos_neg_like(p)) 55 | 56 | def _binary_activation(): 57 | return tf.select(wx > 0, *_pos_neg_like(wx)) 58 | 59 | activation = tf.cond(inference, 60 | _binomial_activation, # Stochastic binarization (training). 61 | _binary_activation) # Deterministic binarization (evaluation). 62 | 63 | return activation, (p, prev_layer, w, bin_w) 64 | 65 | 66 | def build_model(input_var, layers=[]): 67 | """Builds a BNN model. 68 | 69 | Args: 70 | input_var: 2D Tensor with shape (batch_size, num_input_features). 71 | layers: list of ints, the dimensionality of each layer. 72 | 73 | Returns: 74 | tuple (output_layer, updates, inference) 75 | output_layer: 2D Tensor, the output of the model. 76 | updates: list of variables used to compute gradients.
77 | inference: scalar bool Tensor; feed True to enable stochastic 78 | binarization (used during training). 79 | """ 80 | 81 | inference = tf.placeholder(dtype=tf.bool, name='inference') 82 | updates = [] 83 | hidden_layer = input_var 84 | for i, num_hidden in enumerate(layers): 85 | if num_hidden % 32 != 0: 86 | warnings.warn('Hidden layers should be multiples of ' 87 | '32, not %d' % num_hidden) 88 | 89 | hidden_layer, update = add_binary_layer(hidden_layer, num_hidden, 90 | inference, i) 91 | updates.append(update) 92 | output_layer = hidden_layer 93 | 94 | return output_layer, updates, inference 95 | 96 | 97 | def get_loss(output, target): 98 | return tf.reduce_mean(tf.square(output - target)) 99 | 100 | 101 | def get_accuracy(output, target, element_wise=True): 102 | if not element_wise: 103 | output, target = tf.reduce_sum(output, [1]), tf.reduce_sum(target, [1]) 104 | eq = tf.equal(tf.greater(output, 0), tf.greater(target, 0)) 105 | return tf.reduce_mean(tf.cast(eq, tf.float32)) 106 | 107 | 108 | def binary_backprop(loss, output, updates): 109 | """Manually backpropagates gradient error. 110 | 111 | Args: 112 | loss: scalar Tensor, the model loss. 113 | output: 2D Tensor, the output of the model. 114 | updates: list of updates to use for backprop. 115 | 116 | Returns: 117 | backprop_updates: list of (grad, variable) tuples, the gradients.
118 | """ 119 | 120 | backprop_updates = [] 121 | loss_grad, = tf.gradients(loss, output) 122 | 123 | for p, prev_layer, w, bin_w in updates[::-1]: 124 | w_grad, loss_grad = tf.gradients(p, [bin_w, prev_layer], loss_grad) 125 | backprop_updates.append((w_grad, w)) 126 | 127 | return backprop_updates 128 | 129 | 130 | def save_model(path, binary_weights): 131 | with open(path, 'w') as f: # path is the output file itself, e.g. models/model.def. 132 | f.write('bnn') 133 | for weight in binary_weights: 134 | weight = weight.T 135 | f.write('\n%d,%d\n' % (weight.shape[1], weight.shape[0])) 136 | bstr = '\n'.join(''.join('0' if e < 0 else '1' for e in r) 137 | for r in weight) 138 | f.write(bstr) 139 | f.write('\n0,0') 140 | 141 | 142 | def main(): 143 | parser = argparse.ArgumentParser( 144 | description='Train binary neural network weights.') 145 | parser.add_argument( 146 | '--save_path', 147 | type=str, 148 | help='Where to save the weights and configuration.', 149 | required=True) 150 | parser.add_argument( 151 | '--num_train', 152 | type=int, 153 | help='Number of iterations to train model.', 154 | default=20000) 155 | parser.add_argument( 156 | '--eval_every', 157 | type=int, 158 | help='How often to evaluate the model.', 159 | default=1000) 160 | args = parser.parse_args() 161 | 162 | input_var = tf.placeholder(dtype=tf.float32, shape=(4, 32), 163 | name='input_placeholder') 164 | output_var = tf.placeholder(dtype=tf.float32, shape=(4, 32), 165 | name='output_placeholder') 166 | layer_sizes = [64, 64, 32] 167 | 168 | # Configures data for a simple XOR task.
169 | input_data = np.zeros(shape=(4, 32)) 170 | input_data[(0, 0, 1, 2), (0, 1, 1, 0)] = 1 171 | output_data = np.zeros(shape=(4, 32)) 172 | output_data[(0, 3), :] = 1 173 | input_data = input_data * 2 - 1 174 | output_data = output_data * 2 - 1 175 | 176 | feed_dict = { 177 | input_var: input_data, 178 | output_var: output_data, 179 | } 180 | 181 | model_scope = 'model' 182 | with tf.variable_scope(model_scope): 183 | output_layer, updates, inference = build_model( 184 | input_var, layers=layer_sizes) 185 | model_vars = tf.get_collection(tf.GraphKeys.VARIABLES, 186 | scope=model_scope) 187 | loss = get_loss(output_layer, output_var) 188 | accuracy = get_accuracy(output_layer, output_var) 189 | 190 | global_step = tf.Variable(0, name='global_step', trainable=False) 191 | learning_rate = 0.01 192 | 193 | gradients = binary_backprop(loss, output_layer, updates) 194 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) 195 | min_op = optimizer.apply_gradients(gradients, global_step=global_step) 196 | 197 | # Get model weights. 
198 | _, _, _, weights = zip(*updates) 199 | 200 | best_accuracy = 0 201 | with tf.Session() as sess: 202 | sess.run(tf.initialize_all_variables()) 203 | for i in xrange(1, args.num_train + 1): 204 | feed_dict[inference] = True # Stochastic binarization while training. 205 | sess.run(min_op, feed_dict=feed_dict) 206 | if i % args.eval_every == 0: 207 | feed_dict[inference] = False # Deterministic binarization for evaluation. 208 | accuracy_val, loss_val = sess.run([accuracy, loss], 209 | feed_dict=feed_dict) 210 | print('Epoch = %d: Loss = %.4f, Accuracy = %.4f' % 211 | (i, loss_val, accuracy_val)) 212 | if accuracy_val > best_accuracy: 213 | for entry in sess.run(output_layer, feed_dict=feed_dict): 214 | out_str = ''.join('0' if e < 0 else '1' for e in entry) 215 | print(out_str[::-1]) 216 | save_model(args.save_path, sess.run(weights)) 217 | best_accuracy = accuracy_val 218 | 219 | 220 | if __name__ == '__main__': 221 | main() 222 | -------------------------------------------------------------------------------- /train/requirements.txt: -------------------------------------------------------------------------------- 1 | funcsigs==1.0.2 2 | mock==2.0.0 3 | numpy==1.11.2 4 | pbr==1.10.0 5 | protobuf==3.0.0 6 | six==1.10.0 7 | tensorflow==0.11.0 8 | --------------------------------------------------------------------------------
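For reference, the `model.def` serialization that `train/model.py` writes and `eval/model.h` reads can be summarized in a few lines of Python (a standalone sketch to document the format; `parse_model_def` is a hypothetical helper, not part of the repo): a `bnn` magic string, then for each layer a `w,h` dimension line followed by `h` rows of `w` bits, terminated by a `0,0` line.

```python
# A minimal parser for the model.def format (a sketch to document the
# format; "parse_model_def" is not part of the repo). The format is a
# "bnn" magic string, then for each layer a "w,h" dimension line
# followed by h rows of w bits, terminated by a "0,0" line.

def parse_model_def(text):
    """Parses model.def text into a list of (w, h, rows) layers."""
    tokens = text.split()
    if tokens[0] != 'bnn':
        raise ValueError('bad magic string')
    layers = []
    pos = 1
    while True:
        w, h = (int(d) for d in tokens[pos].split(','))
        pos += 1
        if w == 0 and h == 0:
            return layers
        rows = tokens[pos:pos + h]
        pos += h
        # Each row holds one bit per input dimension.
        assert all(len(r) == w and set(r) <= set('01') for r in rows)
        layers.append((w, h, rows))

# A tiny hand-written 2x2 example (note that the real evaluator also
# requires dimensions to be multiples of 32):
sample = 'bnn\n2,2\n10\n01\n0,0'
print([(w, h) for w, h, _ in parse_model_def(sample)])
```

Running this against `models/model.def` should report the three layers trained by the sample script (32x64, 64x64, and 64x32).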