├── .gitignore
├── README.md
├── eval
│   ├── README.md
│   ├── matrix.h
│   ├── model.h
│   ├── run.c
│   └── utils.h
├── models
│   └── model.def
└── train
    ├── README.md
    ├── model.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # virtualenv
3 | venv/
4 | ENV/
5 |
6 | # ide
7 | .cproject
8 | .project
9 | .settings
10 | .pydevproject
11 |
12 | # random executables
13 | run
14 |
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tinier-nn
2 |
3 | A tinier framework for deploying binarized neural networks on embedded systems.
4 |
5 | I wrote a better version in C++ as a gist [here](https://gist.github.com/codekansas/3cb447e3d95ccac4c5a56ea7ffb079ce).
6 |
7 | ## About
8 |
9 | The core of this framework is the use of the Binarized Neural Network (BNN) described in [Binarized Neural Networks:
10 | Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830).
11 | This framework seemed ideal for use with embedded systems such as an Arduino (or Raspberry Pi), but to my knowledge
12 | this wasn't already available.
13 |
14 | The system consists of two parts:
15 | - `train`: TensorFlow code for building a BNN and saving it in a specific format.
16 | - `eval`: The inference part, which runs on the system, is written in straight C. It reads the model into SRAM and
17 | performs matrix multiplications using a bitwise XOR, which (probably) leads to a big improvement in time and power
18 | consumption (although I haven't benchmarked anything).
19 |
20 | The two sample scripts, `train/model.py` and `eval/run.c`, demonstrate how to train a model to discriminate an XOR function. The model uses many more weights than would theoretically be necessary for this task, but together the scripts show how to adapt the code to other use cases.
21 |
22 | ## Demo
23 |
24 | To run the demo, run:
25 |
26 | make eval/run
27 | cat models/model.def | eval/run
28 |
29 | The outputs show the predictions for an XOR function.
30 |
31 | To train the model, run:
32 |
33 | python train/model.py --save_path models/model.def
34 |
35 | This is how the `models/model.def` file was generated.
36 |
37 | ## Math Notes
38 |
39 | Weights and activations with values of -1 and 1 are encoded as binary values: `-1 -> 0, 1 -> 1`. Matrix multiplication
40 | is then done using the XOR operation. Here's an example:
41 |
42 | Using binary weights and activations of -1 and 1:
43 |
44 | - Vector-matrix operation is `[1, -1] * [1, -1; -1, 1] = [1 * 1 + -1 * -1, 1 * -1 + -1 * 1] = [2, -2]`
45 | - After applying the binary activation function `x > 0 ? 1 : -1` gives `[1, -1]`
46 |
47 | Using the binary encoding (0 and 1):
48 |
49 | - Encoding the inputs as binary weights: `[1, 0] * [1, 0; 0, 1]`
50 | - Applying XOR + sum: `[1 ^ 1 + 0 ^ 0; 1 ^ 0 + 0 ^ 1] = [0, 2]`
51 | - Activation function then becomes `x < (2 / 2) ? 1 : 0` which gives `[1, 0]`
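The equivalence between the two views above can be sketched in a few lines of Python (an illustration only, not part of the repo):

```python
def dot_pm1(a, b):
    # Dot product over {-1, +1} values.
    return sum(x * y for x, y in zip(a, b))

def xor_popcount(a_bits, b_bits):
    # Number of mismatched bits between two 0/1 vectors.
    return sum(x ^ y for x, y in zip(a_bits, b_bits))

a = [1, -1]              # activations in {-1, +1}
W = [[1, -1], [-1, 1]]   # weight matrix rows

a_bits = [1 if x > 0 else 0 for x in a]
n = len(a)
for row in W:
    row_bits = [1 if x > 0 else 0 for x in row]
    # Each matching bit contributes +1 and each mismatched bit -1, so the
    # +/-1 dot product equals n - 2 * (number of mismatches).
    assert dot_pm1(a, row) == n - 2 * xor_popcount(a_bits, row_bits)
```

This is why the C code only needs a bitwise XOR plus a popcount, and why the threshold `dot > 0` becomes `mismatches < n / 2`.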
52 |
53 | Because the operations are done this way, matrix dimensions must be multiples of the integer size (the number of bits in an `int`).
54 | Padding can be used to make the data line up correctly (although if someone wants to change this, let me know).
55 |
56 | ## To Do
57 |
58 | - On most Arduinos, Flash memory is bigger than SRAM by about a factor of 32, so it's not too bad to encode models
59 |   as characters instead of bits (and it makes them easier to debug), although this is something that could be improved.
60 | - More examples would be awesome.
61 | - Matrix multiplication could be better, maybe.
62 | - Architectures besides feed-forward networks would be good.
63 |
64 | This was a project that I worked on for CalHacks 3.0 (although I never submitted it).
65 |
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | # Evaluate trained BNN
2 |
3 | In principle, the training and inference steps for the BNN should be separate from one another. The files in this directory provide code for performing low-level inference on embedded devices, given a model that has been trained elsewhere. This handles memory management appropriately.
4 |
5 | `tinier-nn/models/model.def` (the sample definition) contains weights trained for an XOR function. This file demonstrates how to initialize a model and perform a classification task. These weights were trained and saved using the `train/model.py` script.
6 |
7 | To run normally, the model needs to be piped to stdin (weight loading can be configured differently, depending on the application). From this directory, run the following:
8 |
9 | make run # To build the file itself.
10 | cat ../models/model.def | ./run
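For reference, the format of `model.def` can be sketched as a small parser (an illustration, not part of the repo; the format is inferred from `save_model` in `train/model.py` and `load_model` in `model.h`): a `bnn` magic string, then for each layer a `width,height` header followed by `height` rows of `width` bits, terminated by a `0,0` header.

```python
def parse_model_def(text):
    # Note: the C loader additionally requires dimensions to be
    # multiples of INT_SIZE; this sketch does not check that.
    tokens = iter(text.split())
    if next(tokens) != 'bnn':
        raise ValueError('invalid magic string')
    layers = []
    for header in tokens:
        w, h = map(int, header.split(','))
        if w == 0 and h == 0:
            return layers  # "0,0" terminates the model.
        rows = [next(tokens) for _ in range(h)]
        layers.append((w, h, rows))
    raise ValueError('missing 0,0 terminator')
```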
--------------------------------------------------------------------------------
/eval/matrix.h:
--------------------------------------------------------------------------------
1 | /*
2 | * matrix.h
3 | *
4 | * Performs vector-matrix operations associated with the Binarized Neural
5 | * Network architecture.
6 | */
7 |
8 | #ifndef MATRIX_H_
9 | #define MATRIX_H_
10 |
11 | #include "utils.h"
 12 | #include <stdio.h>
13 |
14 |
15 | typedef struct _matrix {
16 | /* Defines a matrix struct. */
17 | dim_t w, h;
18 | data_t *data;
19 | } matrix;
20 |
21 |
22 | typedef struct _vector {
23 | /* Defines a vector struct. */
24 | dim_t h;
25 | data_t *data;
26 | } vector;
27 |
28 |
29 | void instantiate_matrix(matrix *m, int w, int h) {
30 | /* Instantiates a matrix with width w and height h.
31 | *
32 | * This should only be called once; it provides the buffer,
33 | * and data is fed into it.
34 | */
35 | if (w % INT_SIZE != 0 || h % INT_SIZE != 0) {
36 | log_str("Invalid matrix size requested.\n");
37 | exit_failure();
38 | }
39 |
40 | m->w = w;
41 | m->h = h;
42 | m->data = allocate_memory(w * h * INT_SIZE);
43 | }
44 |
45 |
46 | void instantiate_vector(vector *v, int h) {
47 | /* Instantiates a vector with height h.
48 | *
49 | * This should only be called once; it provides the buffer,
50 | * and data is fed into it.
51 | */
52 | if (h % INT_SIZE != 0) {
53 | log_str("Invalid vector size requested.\n");
54 | exit_failure();
55 | }
56 |
57 | v->h = h;
58 | v->data = allocate_memory(h);
59 | }
60 |
61 |
62 | data_t bitsum(data_t x) {
 63 |     /* Calculates the bitsum of a data_t, i.e. the number of 1 bits.
64 | * This is useful for matrix multiplication.
65 | */
66 | data_t c = 0;
67 | for (int i = 0; i < INT_SIZE; i++) {
68 | c += (x & 1);
69 | x >>= 1;
70 | }
71 | return c;
72 | }
73 |
74 |
75 | void matmul(vector *from, vector *to, matrix *by) {
 76 |     /* Multiplies binary vector "from" by matrix "by" and puts the
77 | * result in "to". It is assumed that from.h == by.w
78 | * and to.h == by.h.
79 | *
80 | * The data in the matrix should be row-major, e.g. go through
81 | * all of 1 .. w each step of 1 .. h.
82 | */
83 | dim_t from_h = from->h / INT_SIZE, to_h = to->h / INT_SIZE;
84 | for (dim_t i = 0; i < to_h; i++) {
85 | data_t d = 0;
86 | for (dim_t j = 0; j < INT_SIZE; j++) {
87 | d <<= 1;
88 | data_t c = 0;
89 | for (dim_t k = 0; k < from_h; k++) {
90 | c += bitsum(by->data[((i * INT_SIZE) + j) * from_h + k] ^
91 | from->data[k]);
92 | }
93 |
94 | // Threshold function.
95 | d |= (c >= from->h / 2) ? 0 : 1;
96 | }
97 | to->data[i] = d;
98 | }
99 | }
100 |
101 |
102 | void print_mat(matrix *m) {
103 | for (dim_t i = 0; i < m->h; i++) {
104 | for (dim_t j = 0; j < m->w; j++) {
105 | log_str((m->data[(i * m->w + j) / INT_SIZE] &
106 | (1 << (j % INT_SIZE))) ? "1" : "0");
107 | }
108 | log_str("\n");
109 | }
110 | }
111 |
112 |
113 | void print_vec(vector *v) {
114 | for (dim_t i = 0; i < v->h; i++) {
115 | log_str((v->data[i / INT_SIZE] &
116 | (1 << (i % INT_SIZE))) ? "1" : "0");
117 | }
118 | log_str("\n");
119 | }
120 |
121 |
122 | #endif /* MATRIX_H_ */
123 |
--------------------------------------------------------------------------------
/eval/model.h:
--------------------------------------------------------------------------------
1 | /*
2 | * model.h
3 | *
4 | * Components associated with neural network architecture.
5 | */
6 |
7 | #ifndef MODEL_H_
8 | #define MODEL_H_
9 |
10 | #include "matrix.h"
11 |
12 |
13 | typedef struct _dense {
14 | matrix weights;
15 | vector outputs;
16 | struct _dense *next;
17 | } dense;
18 |
19 |
20 | void build_layer(dense *buffer, dim_t num_input, dim_t num_output) {
21 | /* Builds a single feedforward dense layer.
22 | *
23 | * Args:
24 | * buffer: pointer to location to hold data
25 | * num_input: number of input dimensions
26 | * num_output: number of output dimensions
27 | *
28 | * Returns:
29 | * A dense binary layer which performs a linear transform from
30 | * num_input to num_output dimensions.
31 | */
32 |
33 | instantiate_matrix(&buffer->weights, num_input, num_output);
34 | instantiate_vector(&buffer->outputs, num_output);
35 | buffer->next = NULL;
36 | }
37 |
38 |
39 | vector* get_result(vector *input, dense *head) {
40 | /* Passes through every layer in the network and returns a pointer to the
41 | * result.
42 | *
43 | * Args:
44 | * input: pointer to the input vector (to be processed)
 45 |  *   head: pointer to the first dense layer in a network.
46 | *
47 | * Returns:
48 | * A pointer to the vector containing the output data.
49 | */
50 |
51 | while (head != NULL) {
52 | matmul(input, &head->outputs, &head->weights);
53 | input = &head->outputs;
54 | head = head->next;
55 | }
56 | return input;
57 | }
58 |
59 |
60 | int load_model(dense *buffer, unsigned max_layers) {
61 | /* Loads a model into SRAM from a text file.
62 | *
63 | * Returns:
 64 |  *    number of layers in the model (-1 if there was an error).
65 | */
66 |
67 | if (next_char() != 'b' || next_char() != 'n' || next_char() != 'n') {
68 | log_str("Invalid magic string.");
69 | exit_failure();
70 | }
71 |
72 | // Holds the dimensions of the layer.
73 | dim_t w, h;
74 |
75 | for (int layer = 0; layer < max_layers; layer++) {
76 |
77 | // Reads in width and height;
78 | get_dims(&w, &h);
79 | if (w == 0 || h == 0) {
80 | return layer;
81 | }
82 |
83 | // Instantiate layer itself.
84 | build_layer(&buffer[layer], w, h);
85 |
86 | // Connect previous layer to current layer.
87 | if (layer > 0) {
88 | buffer[layer-1].next = &buffer[layer];
89 | }
90 |
91 | // Reads in data for the current model.
92 | dim_t max_v = (buffer[layer].weights.w *
93 | buffer[layer].weights.h) / INT_SIZE;
94 | for (dim_t i = 0; i < max_v; i++) {
95 | data_t d = 0;
96 | dim_t j = 0;
97 | while (j < INT_SIZE) {
98 | char c = next_char();
99 | if (c == '\n') {
100 | continue;
101 | }
102 |
103 | d <<= 1;
104 | if (c == '1') {
105 | d |= 1;
106 | }
107 | j++;
108 | }
109 | buffer[layer].weights.data[i] = d;
110 | }
111 | }
112 |
113 | return -1;
114 | }
115 |
116 |
117 | #endif /* MODEL_H_ */
118 |
--------------------------------------------------------------------------------
/eval/run.c:
--------------------------------------------------------------------------------
1 | /*
2 | * run.c
3 | *
4 | * Sample script for evaluating a trained model.
5 | */
6 |
7 | #include "model.h"
8 |
9 | #define MAX_LAYERS 10
10 |
11 | vector input_vec[4];
12 | dense layers[MAX_LAYERS];
13 |
14 | int main() {
15 |
16 | // Builds the network.
17 | int num_layers = load_model(layers, MAX_LAYERS); // Load from stdin.
18 | printf("%d layers in model.\n", num_layers);
19 |
20 | // Run on XOR input vectors.
21 | for (int i = 0; i < 4; i++) {
22 |
23 | // Actually allocates space for the vector itself.
24 | instantiate_vector(&input_vec[i], 32);
25 |
26 | // Adds the XOR data (0 -> 00, 1 -> 01, 2 -> 10, 3 -> 11).
27 | input_vec[i].data[0] = i << 30;
28 |
29 | // Runs the network and prints the output.
30 | vector *result = get_result(&input_vec[i], layers);
31 | print_vec(result);
32 | }
33 |
34 | return 0;
35 | }
36 |
--------------------------------------------------------------------------------
/eval/utils.h:
--------------------------------------------------------------------------------
1 | /*
2 | * utils.h
3 | *
4 | * When porting to a different embedded system, these values should be updated
5 | * for the particular compiler.
6 | */
7 |
8 | #ifndef UTILS_H_
9 | #define UTILS_H_
10 |
 11 | #include <stdio.h>
 12 | #include <stdlib.h>
13 |
14 | /*
15 | * Defines types associated with matrices / vectors.
16 | */
17 | typedef unsigned int dim_t;
18 | typedef unsigned int data_t;
19 | typedef int bool_t;
20 |
21 | /*
22 | * Defines the number of bits in a single int.
23 | */
24 | #define INT_SIZE (sizeof(int) * 8)
25 |
26 | /*
27 | * Failure exit. Uses a system call, so it should probably be generic in order
28 | * to work on other platforms.
29 | */
30 | void exit_failure() {
31 | exit(EXIT_FAILURE);
32 | }
33 |
34 | /*
35 | * Generic operation for logging string information. It could just be ignored.
36 | */
37 | void log_str(const char *x) {
38 | fprintf(stderr, "%s", x);
39 | }
40 |
41 | data_t* allocate_memory(dim_t n) {
42 | if (n % INT_SIZE != 0) {
43 | log_str("Invalid shape requested for memory allocation.\n");
44 | exit_failure();
45 | }
46 | return (data_t*) malloc(n / INT_SIZE);
47 | }
48 |
49 | /*
 50 |  * next_char and get_dims define generic operations for reading from some
 51 |  * data source, for loading models. These operations could be changed to
 52 |  * read from a serial port.
53 | */
54 | char next_char() {
55 | return getchar();
56 | }
57 |
58 | void get_dims(dim_t *w, dim_t *h) {
 59 |     scanf("%u,%u", w, h);
60 | }
61 |
62 | #endif /* UTILS_H_ */
63 |
--------------------------------------------------------------------------------
/models/model.def:
--------------------------------------------------------------------------------
1 | bnn
2 | 32,64
3 | 01011011100100010110101001110101
4 | 10000111111111100010011001111100
5 | 11011011000001111011000111001110
6 | 11000110001111111100011101001000
7 | 01110011000000101010001110010111
8 | 10101111010011111011011010101010
9 | 11100100000100100001010110111000
10 | 00010101000001111010000011100110
11 | 00010110101110011110111001100011
12 | 01110100000011100000111010110110
13 | 10110100001101101110111110101111
14 | 00010011010111001001011101011001
15 | 00100100000110101001001100010011
16 | 00001000001000111100101100011100
17 | 00111100010001000010101001011001
18 | 10001010011100010110110110101011
19 | 00110101001000101011100011000101
20 | 01000100111111011011100010000001
21 | 01001111100100011111110010110111
22 | 11010100011101110001010000110110
23 | 01100001011010001111000001100000
24 | 11100011000000111111011001001111
25 | 00101111010010000000010110100001
26 | 10000001010100101011011011000000
27 | 00111101100111100010011011101100
28 | 10110010101001111011100001000000
29 | 01110010101000000010110011000110
30 | 10011000011011110111011101000111
31 | 01101000111101111001001010000000
32 | 11001011110110100100100000111100
33 | 01000010001000001000010011100110
34 | 11100101110010110000101000110110
35 | 11101010110000011110101011011000
36 | 10010010001111111001001101101111
37 | 11000000001001010001011000110000
38 | 01101101100110001111010001011111
39 | 00111000100001001110001110111000
40 | 00011000001010001100111111100100
41 | 11101101110100001101101000000101
42 | 01100000001100111011100110011001
43 | 00110101110101101001001010000110
44 | 01111001000010111101111011101011
45 | 01010100101101011001010111000100
46 | 10100010000100011101101010110111
47 | 00110001101110111011011101000101
48 | 00101101101100011101110101100111
49 | 11001011111011001000011011100011
50 | 00000010011110000101011001000001
51 | 00000110101101001110011010000110
52 | 01111000111101101011111100010011
53 | 11001010000001011101001011111011
54 | 01010100001000111100011010101001
55 | 00000001110110011000011001100011
56 | 10001101101011011110000000110111
57 | 00011011000100101011001001111001
58 | 01100100100100110010000101111011
59 | 01101110111011010001101010000110
60 | 01101011010111110001100010100001
61 | 10001101001010100001110001110001
62 | 11101011111110001111001110001111
63 | 00011101001101100101110000000101
64 | 01001100100010010000000011010111
65 | 01010011100110011011110101100100
66 | 10101000000111000001110110110011
67 | 64,64
68 | 0110000001101101100001011011010110110001111011111000110111000011
69 | 1001010100001101100111011100110110010111000001110110101101000111
70 | 1110011110100011101000111110000101010111000111010011010011100011
71 | 0011110010100110010010101001001100010000000011001010110100011100
72 | 1110110011100011110001100111101101001000111011100111111110000111
73 | 1000101110100001001110000001111000111111100101110110100011001110
74 | 1111101011100000001101010100001110010000100000110010010101011010
75 | 0111110001001111011001010001110100110011000010110010011000000001
76 | 0001001100001110010111011001011111100111011001001001011001001111
77 | 1100010001100011100110101011101110001111010011110100110011101110
78 | 1000010110101001111100110000111110000110000001101010110010100010
79 | 0000111000010110010110011011001010101011100011110000110001011011
80 | 1011100111100101110001001111100011010100011010100101001000011100
81 | 1001110000010010111100110100111110000101011110111011001110010001
82 | 0001101110001001110101100100010011101000001010111010101110000101
83 | 0010100101000010000000111011010011101001101001111110101111000000
84 | 1000110000101111110110010101111110000001101011000011100101110101
85 | 0100000001010010101101000010100010010011010100010000101000001101
86 | 0101111110101001011010111101100100011010110011010100000100011110
87 | 1010001000110110111001001010001001100101100011001110001110000111
88 | 0111000011100100011001111111011010111001011101100101110100101000
89 | 1000110110000010101111110001011111001001010100001010010010111010
90 | 0101001101011011011000111111100000000101000010010111111001111110
91 | 0110000111010000010101000000111000001010011100111001011100111000
92 | 1100001000011010000001001011111110101101001011000011000010100100
93 | 0100111100101100110011111000110010011110101100100010000100111001
94 | 1011101100011010100100000111101111001101011111011101100101110001
95 | 0000111101111100100111100010101011110001011100000011011011111010
96 | 1001011000011101100011101101101011001100111110011110100111101111
97 | 0011001011001011011001011111011110110110010011100100100001000100
98 | 0011111111111010001001010010110011111001010110100100100101100101
99 | 1111101011010101100011011100100001111000100111101011011011100011
100 | 0001000101011111011010111111101110001101001101111111100010001100
101 | 1101001100011110101110101001001111001010100101000100110111110000
102 | 1000010000011111100100010001100110111100001000011110000011010000
103 | 0110000000011001101010000110010110101010000011001011101011011000
104 | 1101001001101111000111101111110010110010000011011111111101001011
105 | 0100100011101000011010010011101000110101010100100000101110000001
106 | 1000101011110001011100011000000011110101010011100101110100110100
107 | 0111000111001100011001011010111000101001101111100000001100010111
108 | 1000010001100111101110001101111110001100010010011011110010011010
109 | 0000101010110000100100101110001111001111011011101111010001001100
110 | 1100001001101101111000100111011110111111001101101001100101100010
111 | 1001100101111010100101111011110000100001011000000100100111110111
112 | 1100010101111001100000111110011010010010100011000100010101111001
113 | 0111000101100111000101001011000101101000110110011100111000110001
114 | 0010001100000111100011110111100000111001010100110010100001000100
115 | 0010011011001001111010111011011000111000010111011010010010011100
116 | 1101110101001001001011101001000111111001101101110000010011100000
117 | 0000111101011110110101010101001000100110101011101110110010001100
118 | 0011101011000101001010001001010101110110101001000110001001001011
119 | 0100100011000110001111100010011111000010001100010111001111100001
120 | 0011110011101101101010001101011100001110100110001101010001000110
121 | 0110100010011100001001011010010111110110000011100101010010010011
122 | 1000001100010011111000100001111100011101100110000111110110001010
123 | 1101100000010011010100110010101110011110011000010101000111101011
124 | 1101001100000000110010100111111101110010001011000001010000101101
125 | 0101000111110100011010110011100001111000100100001101011100010110
126 | 0101101100101100011011000010001101100001101110100101001000000000
127 | 1100010110011101010000100010011000001000111111110101010000001101
128 | 1010001000111110110111011101111010011010100111010000101001101110
129 | 0001110110110111010110001001010111000110101001011000101011010111
130 | 0111111000101010000110000001101001110111001011100000011011111011
131 | 1101110100001001111110100010010101000111110010001111010101001011
132 | 64,32
133 | 0001001110010000110110010001001000100100000000001111000011110001
134 | 0001101010010000100011001100001000000001001001010100111001101100
135 | 1010100100010011110001110111111100100110011011100011101001101010
136 | 0101001101011100010000011100001111011111000000000101110001101000
137 | 0001010110100001011110011110111011000111011110111010011010110000
138 | 0100101110001111010100010011110010100011100111101010110111100000
139 | 1000100000000101110110110010011110010101100000001001100011101011
140 | 0011000001000110100110111000111101000100001110010010000111101100
141 | 0011101110011100001010001011010101000101000100100101001001010001
142 | 1000001110001001100111111001001010001110011011001000010110110001
143 | 0011001000110111011001110001010010000110001100111011111101100010
144 | 1100111111010001011010001100100011000101001101100010010011110010
145 | 1001111001110001111100100101101000010100101111011011101101110010
146 | 1001101110010000111100000101011010100101011101111111011010100110
147 | 1010011101000011100110110100001111100111101101001001000101111001
148 | 0010100010011001111110001111011110100001000110100011111101100100
149 | 1100100111001100110110101011111100100111010100101011010001110000
150 | 1000010110010001110110111110000100001101000110011011111000100011
151 | 1000011001010100011111000111011010110001000100001000010001111000
152 | 1011100110011111101010010111111110110101001000100010110101000101
153 | 1001101100010100011110100000011001000110001000100101001111011010
154 | 0001000110001011011011111000001000001000001100011110110111111110
155 | 1010100000000101010000000100001000000111000101011101000110101110
156 | 0011001111010001111010011011111110100101011111000110010011101110
157 | 1111001010000110000000010110010001001101011111001111100001111000
158 | 1001000010000111010110010011111100011110001010100110010101100110
159 | 0010011010000011010110010110101000001111010001000010111110110001
160 | 1010001011010001001010010110010110000101011000001100010100111010
161 | 1011001110001000100111101111011110110101010110101110000011101011
162 | 0001101100010101110101010100001010001001000110001000001000010100
163 | 0100010010001100110001000111000011111111001111101100010000001001
164 | 0010100010000101101110110110011111011101001100111000101101100111
165 | 0,0
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # Train the BNN
2 |
3 | Training should be done off the embedded device. This directory provides a framework for defining and training BNNs in TensorFlow, then saving the weights so that they can be ported to an embedded device.
4 |
5 | This directory contains the model creators for the binary neural network.
6 |
7 | Reference paper: [https://arxiv.org/pdf/1602.02830v3.pdf](https://arxiv.org/pdf/1602.02830v3.pdf)
8 |
9 | The script itself trains a binary network to perform a simple XOR classification task (although this could be extended to other tasks). The weights are saved to an output file which can be read by the C evaluation code (in particular, eval/run.c demonstrates how to build input vectors and run them through the network).
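The sign-based encoding of the saved weights can be sketched as a hypothetical helper (mirroring the logic of `save_model` in `model.py`, but not part of the repo): each weight matrix is transposed, then each entry becomes `'0'` if negative and `'1'` otherwise, one row of bits per output unit.

```python
import numpy as np

def binarize_rows(weight):
    # Transpose so each row of bits corresponds to one output unit,
    # matching the layout written to model.def.
    weight = weight.T
    return [''.join('0' if e < 0 else '1' for e in row) for row in weight]

w = np.array([[0.3, -1.2],
              [-0.5, 0.7]])
print(binarize_rows(w))  # → ['10', '01']
```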
10 |
11 | From start to finish, models can be trained and evaluated using:
12 |
13 | make eval/run
14 | python train/model.py --save_path models/model.def
15 | cat models/model.def | eval/run
16 |
--------------------------------------------------------------------------------
/train/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 | import os
9 | import warnings
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 |
15 | def add_binary_layer(prev_layer, num_outputs, inference, uid):
16 | """Adds a binary layer to the network.
17 |
18 | Args:
19 | prev_layer: 2D Tensor with shape (batch_size, num_hidden).
20 | num_outputs: int, number of dimensions in next hidden layer.
21 | inference: bool scalar Tensor, tells the network if it is
22 | the inference step.
23 | uid: int, the unique identifier.
24 |
25 | Returns:
26 | tuple (activation, updates)
27 | activation: output of this layer.
28 | updates: variables used to compute gradients.
29 | """
30 |
31 | batch_size, num_inputs = [dim.value for dim in prev_layer.get_shape()]
32 |
33 | def _pos_neg_like(tensor):
34 | return (tf.constant(1, dtype=tf.float32, shape=tensor.get_shape()),
35 | tf.constant(-1, dtype=tf.float32, shape=tensor.get_shape()))
36 |
37 | with tf.variable_scope('binary_layer_%d' % uid) as vs:
38 | w = tf.get_variable('w', shape=(num_inputs, num_outputs),
39 | initializer=tf.random_normal_initializer(seed=uid),
40 | trainable=True)
41 |
42 | # Binarize the weight matrix.
43 | bin_w = tf.select(w > 0, *_pos_neg_like(w))
44 |
45 | # Output has shape (batch_size, num_outputs)
46 | wx = tf.matmul(prev_layer, bin_w)
47 |
48 | with tf.variable_scope('activation'):
49 | p = tf.tanh(wx)
50 |
51 | def _binomial_activation():
52 | sigma = tf.random_uniform(
53 | p.get_shape(), seed=uid) < (p + 1) / 2
54 | return tf.select(sigma, *_pos_neg_like(p))
55 |
56 | def _binary_activation():
57 | return tf.select(wx > 0, *_pos_neg_like(wx))
58 |
59 | activation = tf.cond(inference,
 60 |                              _binomial_activation,  # Stochastic binarization.
 61 |                              _binary_activation)  # Deterministic binarization.
62 |
63 | return activation, (p, prev_layer, w, bin_w)
64 |
65 |
66 | def build_model(input_var, layers=[]):
67 | """Builds a BNN model.
68 |
69 | Args:
70 | input_var: 2D Tensor with shape (batch_size, num_input_features).
71 | layers: list of ints, the dimensionality of each layer.
72 |
73 | Returns:
74 | tuple (output_layer, updates, inference)
75 | output_layer: 2D Tensor, the output of the model.
76 | updates: list of variables used to compute gradients.
77 | inference: scalar bool Tensor, used to tell the model when to
78 | use inference mode.
79 | """
80 |
81 | inference = tf.placeholder(dtype=tf.bool, name='inference')
82 | updates = []
83 | hidden_layer = input_var
84 | for i, num_hidden in enumerate(layers):
85 | if num_hidden % 16 != 0:
86 | warnings.warn('Hidden layers should be multiples of '
87 | '16, not %d' % num_hidden)
88 |
89 | hidden_layer, update = add_binary_layer(hidden_layer, num_hidden,
90 | inference, i)
91 | updates.append(update)
92 | output_layer = hidden_layer
93 |
94 | return output_layer, updates, inference
95 |
96 |
97 | def get_loss(output, target):
98 | return tf.reduce_mean(tf.square(output - target))
99 |
100 |
101 | def get_accuracy(output, target, element_wise=True):
102 | if not element_wise:
103 | output, target = tf.reduce_sum(output, [1]), tf.reduce_sum(target, [1])
104 | eq = tf.equal(tf.greater(output, 0), tf.greater(target, 0))
105 | return tf.reduce_mean(tf.cast(eq, tf.float32))
106 |
107 |
108 | def binary_backprop(loss, output, updates):
109 | """Manually backpropagates gradient error.
110 |
111 | Args:
112 | loss: scalar Tensor, the model loss.
113 |         output: 2D Tensor, the output of the model (used to seed the gradient).
114 | updates: list of updates to use for backprop.
115 |
116 | Returns:
117 | backprop_updates: list of (grad, variable) tuples, the gradients.
118 | """
119 |
120 | backprop_updates = []
121 | loss_grad, = tf.gradients(loss, output)
122 |
123 | for p, prev_layer, w, bin_w in updates[::-1]:
124 | w_grad, loss_grad = tf.gradients(p, [bin_w, prev_layer], loss_grad)
125 | backprop_updates.append((w_grad, w))
126 |
127 | return backprop_updates
128 |
129 |
130 | def save_model(path, binary_weights):
131 | with open(os.path.join(path, 'model.def'), 'w') as f:
132 | f.write('bnn')
133 | for i, weight in enumerate(binary_weights):
134 | weight = weight.T
135 | f.write('\n%d,%d\n' % (weight.shape[1], weight.shape[0]))
136 | bstr = '\n'.join(''.join('0' if e < 0 else '1' for e in r)
137 | for r in weight)
138 | f.write(bstr)
139 | f.write('\n0,0')
140 |
141 |
142 | def main():
143 | parser = argparse.ArgumentParser(
144 | description='Train binary neural network weights.')
145 | parser.add_argument(
146 | '--save_path',
147 | type=str,
148 | help='Where to save the weights and configuration.',
149 | required=True)
150 | parser.add_argument(
151 | '--num_train',
152 | type=int,
153 | help='Number of iterations to train model.',
154 | default=20000)
155 | parser.add_argument(
156 | '--eval_every',
157 | type=int,
158 |         help='How often to evaluate the model.',
159 | default=1000)
160 | args = parser.parse_args()
161 |
162 | input_var = tf.placeholder(dtype=tf.float32, shape=(4, 32),
163 | name='input_placeholder')
164 | output_var = tf.placeholder(dtype=tf.float32, shape=(4, 32),
165 | name='output_placeholder')
166 | layer_sizes = [64, 64, 32]
167 |
168 | # Configures data for a simple XOR task.
169 | input_data = np.zeros(shape=(4, 32))
170 | input_data[(0, 0, 1, 2), (0, 1, 1, 0)] = 1
171 | output_data = np.zeros(shape=(4, 32))
172 | output_data[(0, 3), :] = 1
173 | input_data = input_data * 2 - 1
174 | output_data = output_data * 2 - 1
175 |
176 | feed_dict = {
177 | input_var: input_data,
178 | output_var: output_data,
179 | }
180 |
181 | model_scope = 'model'
182 | with tf.variable_scope(model_scope):
183 | output_layer, updates, inference = build_model(
184 | input_var, layers=layer_sizes)
185 | model_vars = tf.get_collection(tf.GraphKeys.VARIABLES,
186 | scope=model_scope)
187 | loss = get_loss(output_layer, output_var)
188 | accuracy = get_accuracy(output_layer, output_var)
189 |
190 | global_step = tf.Variable(0, name='global_step', trainable=False)
191 | learning_rate = 0.01
192 |
193 | gradients = binary_backprop(loss, output_layer, updates)
194 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
195 | min_op = optimizer.apply_gradients(gradients, global_step=global_step)
196 |
197 | # Get model weights.
198 | _, _, _, weights = zip(*updates)
199 |
200 | best_accuracy = 0
201 | with tf.Session() as sess:
202 | sess.run(tf.initialize_all_variables())
203 | for i in xrange(1, args.num_train + 1):
204 | feed_dict[inference] = True
205 | sess.run(min_op, feed_dict=feed_dict)
206 | if i % args.eval_every == 0:
207 | feed_dict[inference] = False
208 | accuracy_val, loss_val = sess.run([accuracy, loss],
209 | feed_dict=feed_dict)
210 | print('Epoch = %d: Loss = %.4f, Accuracy = %.4f' %
211 | (i, loss_val, accuracy_val))
212 | if accuracy_val > best_accuracy:
213 | for entry in sess.run(output_layer, feed_dict=feed_dict):
214 | out_str = ''.join('0' if i < 0 else '1' for i in entry)
215 | print(out_str[::-1])
216 | save_model(args.save_path, sess.run(weights))
217 | best_accuracy = accuracy_val
218 |
219 |
220 | if __name__ == '__main__':
221 | main()
222 |
--------------------------------------------------------------------------------
/train/requirements.txt:
--------------------------------------------------------------------------------
1 | funcsigs==1.0.2
2 | mock==2.0.0
3 | numpy==1.11.2
4 | pbr==1.10.0
5 | protobuf==3.0.0
6 | six==1.10.0
7 | tensorflow==0.11.0
8 |
--------------------------------------------------------------------------------