├── .gitignore
├── LICENSE
├── README.md
├── checkpoint
│   └── .gitignore
├── os_elm.py
├── train_iris.py
└── train_mnist.py
/.gitignore:
--------------------------------------------------------------------------------
*.out
.vscode/
__pycache__/
weights/*
datasets/*
results/*
models/*
*.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Otenim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TF-OS-ELM

## Overview

This repository provides a TensorFlow implementation of Online Sequential
Extreme Learning Machine (OS-ELM), introduced by Liang et al. in this [paper](http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=4012031).
You can run our OS-ELM module on either CPUs or GPUs.

OS-ELM trains faster than ordinary backpropagation-based neural networks,
and its training always converges to the globally optimal solution,
whereas backpropagation has to deal with the local-minima problem.
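
The claim about the global optimum follows from a standard least-squares
argument (the notation below matches `os_elm.py`: $H$ is the hidden-layer
output matrix, $T$ the target matrix, and $\beta$ the hidden-to-output
weights, the only parameters OS-ELM learns). Since the input-to-hidden
weights are fixed, the training objective is convex in $\beta$, so its
closed-form minimizer is the global optimum:

$$\beta^{*} = \underset{\beta}{\arg\min}\;\lVert H\beta - T \rVert^{2} = (H^{T}H)^{-1}H^{T}T$$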

## Dependencies

We tested our code with the following libraries.

* Python==3.6.0
* Numpy==1.14.1
* Tensorflow==1.6.0
* Keras==2.1.5
* scikit-learn==0.17.1

We used Keras only for downloading the MNIST dataset.

You don't have to use exactly the same version of each library,
but we cannot guarantee that the code works well with other versions.

All of the above libraries can be installed with the following command.

`$ pip install -U numpy Keras scikit-learn tensorflow`

If you want to run our OS-ELM module on GPUs, please install `tensorflow-gpu`
in addition to the above libraries.
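
If your machine has multiple GPUs, you can choose which one TensorFlow uses
with the standard `CUDA_VISIBLE_DEVICES` environment variable (a generic CUDA
mechanism, not something specific to this repository). For example:

`$ CUDA_VISIBLE_DEVICES=0 python train_mnist.py`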

## Usage

Here, we show how to train an OS-ELM model and make predictions with it.
For the sake of simplicity, we train the model on MNIST, a handwritten
digit dataset.

```python
from keras.datasets import mnist
from keras.utils import to_categorical
from os_elm import OS_ELM
import numpy as np
import tqdm

def softmax(a):
    # numerically stable softmax: subtract the row-wise max before exp.
    c = np.max(a, axis=-1).reshape(-1, 1)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a, axis=-1).reshape(-1, 1)
    return exp_a / sum_exp_a

def main():

    # ===========================================
    # Instantiate os-elm
    # ===========================================
    n_input_nodes = 784
    n_hidden_nodes = 1024
    n_output_nodes = 10

    os_elm = OS_ELM(
        # the number of input nodes.
        n_input_nodes=n_input_nodes,
        # the number of hidden nodes.
        n_hidden_nodes=n_hidden_nodes,
        # the number of output nodes.
        n_output_nodes=n_output_nodes,
        # loss function.
        # the default value is 'mean_squared_error'.
        # as other options, we support 'mean_absolute_error',
        # 'categorical_crossentropy', and 'binary_crossentropy'.
        loss='mean_squared_error',
        # activation function applied to the hidden nodes.
        # the default value is 'sigmoid'.
        # as other options, we support 'linear' and 'tanh'.
        # NOTE: OS-ELM can apply an activation function only to the hidden nodes.
        activation='sigmoid',
    )

    # ===========================================
    # Prepare dataset
    # ===========================================
    n_classes = n_output_nodes

    # load MNIST
    (x_train, t_train), (x_test, t_test) = mnist.load_data()
    # normalize pixel values to [0, 1]
    x_train = x_train.reshape(-1, n_input_nodes) / 255.
    x_test = x_test.reshape(-1, n_input_nodes) / 255.
    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)

    # convert labels into one-hot vectors.
    t_train = to_categorical(t_train, num_classes=n_classes)
    t_test = to_categorical(t_test, num_classes=n_classes)
    t_train = t_train.astype(np.float32)
    t_test = t_test.astype(np.float32)

    # divide the training dataset into two datasets:
    # (1) for the initial training phase
    # (2) for the sequential training phase
    # NOTE: the number of training samples for the initial training phase
    # must be much greater than the number of the model's hidden nodes.
    # here, we assign int(1.5 * n_hidden_nodes) training samples
    # to the initial training phase.
    border = int(1.5 * n_hidden_nodes)
    x_train_init = x_train[:border]
    x_train_seq = x_train[border:]
    t_train_init = t_train[:border]
    t_train_seq = t_train[border:]

    # ===========================================
    # Training
    # ===========================================
    # the initial training phase
    pbar = tqdm.tqdm(total=len(x_train), desc='initial training phase')
    os_elm.init_train(x_train_init, t_train_init)
    pbar.update(n=len(x_train_init))

    # the sequential training phase
    pbar.set_description('sequential training phase')
    batch_size = 64
    for i in range(0, len(x_train_seq), batch_size):
        x_batch = x_train_seq[i:i+batch_size]
        t_batch = t_train_seq[i:i+batch_size]
        os_elm.seq_train(x_batch, t_batch)
        pbar.update(n=len(x_batch))
    pbar.close()

    # ===========================================
    # Prediction
    # ===========================================
    # sample 10 validation samples from x_test
    n = 10
    x = x_test[:n]
    t = t_test[:n]

    # the 'predict' method returns the raw values of the output nodes.
    y = os_elm.predict(x)
    # apply the softmax function to the output values.
    y = softmax(y)

    # check the answers.
    for i in range(n):
        max_ind = np.argmax(y[i])
        print('========== sample index %d ==========' % i)
        print('estimated answer: class %d' % max_ind)
        print('estimated probability: %.3f' % y[i, max_ind])
        print('true answer: class %d' % np.argmax(t[i]))

    # ===========================================
    # Evaluation
    # ===========================================
    # we currently support 'loss' and 'accuracy' for 'metrics'.
    # NOTE: 'accuracy' is valid only for classification problems,
    # while 'loss' is always valid.
    # e.g. [loss] = os_elm.evaluate(x_test, t_test, metrics=['loss'])
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

    # ===========================================
    # Save model
    # ===========================================
    print('saving model parameters...')
    os_elm.save('./checkpoint/model.ckpt')

    # initialize weights of os_elm
    os_elm.initialize_variables()

    # ===========================================
    # Load model
    # ===========================================
    # If you want to load weights into a model,
    # the architecture of the model must be exactly the same
    # as the one used when the weights were saved.
    print('restoring model parameters...')
    os_elm.restore('./checkpoint/model.ckpt')

    # ===========================================
    # ReEvaluation
    # ===========================================
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

if __name__ == '__main__':
    main()
```

## Notes

The following formulas describe how OS-ELM is trained.
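
These are the standard OS-ELM update equations from Liang et al.'s paper,
written here to match the implementation in `os_elm.py` ($H_k$ is the
hidden-layer output for the $k$-th batch, $T_k$ the corresponding targets,
$P_k$ the state held in the variable `p`, and $\beta_k$ the output weights
`beta`).

The initial training phase solves a least-squares problem in closed form:

$$P_0 = (H_0^{T}H_0)^{-1}, \qquad \beta_0 = P_0 H_0^{T} T_0$$

The sequential training phase then updates the solution for every incoming
batch $k+1$ without revisiting past data:

$$P_{k+1} = P_k - P_k H_{k+1}^{T}\left(I + H_{k+1} P_k H_{k+1}^{T}\right)^{-1} H_{k+1} P_k$$

$$\beta_{k+1} = \beta_k + P_{k+1} H_{k+1}^{T}\left(T_{k+1} - H_{k+1}\beta_k\right)$$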

* **Important**: since the matrix inversion in the OS-ELM update formula
involves many conditional operations, training is not necessarily accelerated
even when it is executed on GPUs.
* In OS-ELM, you can apply an activation function only to the hidden nodes.
* OS-ELM always finds the globally optimal solution for its weight matrices
at every training step.
* If you feed all the training samples to OS-ELM in the initial training phase,
the computational procedure is exactly the same as that of ELM, so ELM can be
seen as a special case of OS-ELM.
* OS-ELM does not need to iterate over the same data samples,
while backpropagation-based models usually have to train for multiple epochs.
* OS-ELM does not update 'alpha', the weight matrix connecting the input nodes
and the hidden nodes. This makes OS-ELM train faster.
* OS-ELM does not need to compute gradients. The weight matrices are trained by
computing a few matrix products and one matrix inversion per batch.
* The computational complexity of the matrix inversion is about O(batch\_size^3),
so watch the cost when you increase batch\_size (see the sketch below).
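
As a rough illustration of that cubic cost, here is a hypothetical
micro-benchmark (not part of this repository; absolute timings depend on your
hardware and BLAS build). It inverts matrices of the same shape as the one
inverted during sequential training:

```python
import time

import numpy as np

for batch_size in [64, 256, 1024]:
    # build a well-conditioned (diagonally dominant) random matrix,
    # shaped like the (I + H P H^T) term inverted in seq_train.
    m = np.random.rand(batch_size, batch_size).astype(np.float32)
    m += batch_size * np.eye(batch_size, dtype=np.float32)
    start = time.time()
    np.linalg.inv(m)
    print('batch_size=%4d: %.4f sec' % (batch_size, time.time() - start))
```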

## Demo

You can execute the above sample code with the following command.

`$ python train_mnist.py`
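
An analogous demo for the Iris dataset is also included.

`$ python train_iris.py`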

## Todos

* support more activation functions
* support more loss functions
* provide benchmark results
--------------------------------------------------------------------------------
/checkpoint/.gitignore:
--------------------------------------------------------------------------------
checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
--------------------------------------------------------------------------------
/os_elm.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf

class OS_ELM(object):

    def __init__(
        self, n_input_nodes, n_hidden_nodes, n_output_nodes,
        activation='sigmoid', loss='mean_squared_error', name=None):

        if name is None:
            self.name = 'model'
        else:
            self.name = name

        self.__sess = tf.Session()
        self.__n_input_nodes = n_input_nodes
        self.__n_hidden_nodes = n_hidden_nodes
        self.__n_output_nodes = n_output_nodes

        if activation == 'sigmoid':
            self.__activation = tf.nn.sigmoid
        elif activation == 'linear' or activation is None:
            self.__activation = tf.identity
        elif activation == 'tanh':
            self.__activation = tf.tanh
        else:
            raise ValueError(
                'an unknown activation function \'%s\' was given.' % (activation)
            )

        if loss == 'mean_squared_error':
            self.__lossfun = tf.losses.mean_squared_error
        elif loss == 'mean_absolute_error':
            self.__lossfun = tf.keras.losses.mean_absolute_error
        elif loss == 'categorical_crossentropy':
            self.__lossfun = tf.keras.losses.categorical_crossentropy
        elif loss == 'binary_crossentropy':
            self.__lossfun = tf.keras.losses.binary_crossentropy
        else:
            raise ValueError(
                'an unknown loss function \'%s\' was given.' % loss
            )

        self.__is_finished_init_train = tf.get_variable(
            'is_finished_init_train',
            shape=[],
            dtype=bool,
            initializer=tf.constant_initializer(False),
        )
        self.__x = tf.placeholder(tf.float32, shape=(None, self.__n_input_nodes), name='x')
        self.__t = tf.placeholder(tf.float32, shape=(None, self.__n_output_nodes), name='t')
        # alpha: input-to-hidden weights. They are initialized randomly
        # and never updated (trainable=False).
        self.__alpha = tf.get_variable(
            'alpha',
            shape=[self.__n_input_nodes, self.__n_hidden_nodes],
            initializer=tf.random_uniform_initializer(-1, 1),
            trainable=False,
        )
        self.__bias = tf.get_variable(
            'bias',
            shape=[self.__n_hidden_nodes],
            initializer=tf.random_uniform_initializer(-1, 1),
            trainable=False,
        )
        # beta: hidden-to-output weights, the only weights OS-ELM learns.
        self.__beta = tf.get_variable(
            'beta',
            shape=[self.__n_hidden_nodes, self.__n_output_nodes],
            initializer=tf.zeros_initializer(),
            trainable=False,
        )
        # p: the inverse auto-correlation matrix (H^T H)^{-1} carried
        # between sequential updates.
        self.__p = tf.get_variable(
            'p',
            shape=[self.__n_hidden_nodes, self.__n_hidden_nodes],
            initializer=tf.zeros_initializer(),
            trainable=False,
        )

        # Finish initial training phase
        self.__finish_init_train = tf.assign(self.__is_finished_init_train, True)

        # Predict
        self.__predict = tf.matmul(self.__activation(tf.matmul(self.__x, self.__alpha) + self.__bias), self.__beta)

        # Loss
        self.__loss = self.__lossfun(self.__t, self.__predict)

        # Accuracy
        self.__accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.__predict, 1), tf.argmax(self.__t, 1)), tf.float32))

        # Initial training phase
        self.__init_train = self.__build_init_train_graph()

        # Sequential training phase
        self.__seq_train = self.__build_seq_train_graph()

        # Saver
        self.__saver = tf.train.Saver()

        # Initialize variables
        self.__sess.run(tf.global_variables_initializer())

    def predict(self, x):
        return self.__sess.run(self.__predict, feed_dict={self.__x: x})

    def evaluate(self, x, t, metrics=['loss']):
        met = []
        for m in metrics:
            if m == 'loss':
                met.append(self.__loss)
            elif m == 'accuracy':
                met.append(self.__accuracy)
            else:
                raise ValueError(
                    'an unknown metric \'%s\' was given.' % m
                )
        ret = self.__sess.run(met, feed_dict={self.__x: x, self.__t: t})
        return list(map(lambda x: float(x), ret))

    def init_train(self, x, t):
        if self.__sess.run(self.__is_finished_init_train):
            raise Exception(
                'the initial training phase has already finished. '
                'please call the \'seq_train\' method for further training.'
            )
        if len(x) < self.__n_hidden_nodes:
            raise ValueError(
                'in the initial training phase, the number of training samples '
                'must be greater than or equal to the number of hidden nodes. '
                'But this time len(x) = %d, while n_hidden_nodes = %d' % (len(x), self.__n_hidden_nodes)
            )
        self.__sess.run(self.__init_train, feed_dict={self.__x: x, self.__t: t})
        self.__sess.run(self.__finish_init_train)

    def seq_train(self, x, t):
        if not self.__sess.run(self.__is_finished_init_train):
            raise Exception(
                'you have not gone through the initial training phase yet. '
                'please first initialize the model\'s weights with the \'init_train\' '
                'method before calling the \'seq_train\' method.'
            )
        self.__sess.run(self.__seq_train, feed_dict={self.__x: x, self.__t: t})

    def __build_init_train_graph(self):
        # beta_0 = (H^T H)^{-1} H^T T, the closed-form least-squares solution.
        H = self.__activation(tf.matmul(self.__x, self.__alpha) + self.__bias)
        HT = tf.transpose(H)
        HTH = tf.matmul(HT, H)
        p = tf.assign(self.__p, tf.matrix_inverse(HTH))
        pHT = tf.matmul(p, HT)
        pHTt = tf.matmul(pHT, self.__t)
        init_train = tf.assign(self.__beta, pHTt)
        return init_train

    def __build_seq_train_graph(self):
        # recursive least-squares update:
        #   P <- P - P H^T (I + H P H^T)^{-1} H P
        #   beta <- beta + P H^T (T - H beta)
        # only a batch_size x batch_size matrix is inverted per step.
        H = self.__activation(tf.matmul(self.__x, self.__alpha) + self.__bias)
        HT = tf.transpose(H)
        batch_size = tf.shape(self.__x)[0]
        I = tf.eye(batch_size)
        Hp = tf.matmul(H, self.__p)
        HpHT = tf.matmul(Hp, HT)
        temp = tf.matrix_inverse(I + HpHT)
        pHT = tf.matmul(self.__p, HT)
        p = tf.assign(self.__p, self.__p - tf.matmul(tf.matmul(pHT, temp), Hp))
        pHT = tf.matmul(p, HT)
        Hbeta = tf.matmul(H, self.__beta)
        seq_train = self.__beta.assign(self.__beta + tf.matmul(pHT, self.__t - Hbeta))
        return seq_train

    def save(self, filepath):
        self.__saver.save(self.__sess, filepath)

    def restore(self, filepath):
        self.__saver.restore(self.__sess, filepath)

    def initialize_variables(self):
        for var in [self.__alpha, self.__bias, self.__beta, self.__p, self.__is_finished_init_train]:
            self.__sess.run(var.initializer)

    def __del__(self):
        self.__sess.close()

    @property
    def input_shape(self):
        return (self.__n_input_nodes,)

    @property
    def output_shape(self):
        return (self.__n_output_nodes,)

    @property
    def n_input_nodes(self):
        return self.__n_input_nodes

    @property
    def n_hidden_nodes(self):
        return self.__n_hidden_nodes

    @property
    def n_output_nodes(self):
        return self.__n_output_nodes
--------------------------------------------------------------------------------
/train_iris.py:
--------------------------------------------------------------------------------
from sklearn import datasets
from keras.utils import to_categorical
from os_elm import OS_ELM
import numpy as np
import tqdm

def softmax(a):
    # numerically stable softmax: subtract the row-wise max before exp.
    c = np.max(a, axis=-1).reshape(-1, 1)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a, axis=-1).reshape(-1, 1)
    return exp_a / sum_exp_a

def main():

    # ===========================================
    # Instantiate os-elm
    # ===========================================
    n_input_nodes = 4
    n_hidden_nodes = 16
    n_output_nodes = 3

    os_elm = OS_ELM(
        # the number of input nodes.
        n_input_nodes=n_input_nodes,
        # the number of hidden nodes.
        n_hidden_nodes=n_hidden_nodes,
        # the number of output nodes.
        n_output_nodes=n_output_nodes,
        # loss function.
        # the default value is 'mean_squared_error'.
        # as other options, we support 'mean_absolute_error',
        # 'categorical_crossentropy', and 'binary_crossentropy'.
        loss='mean_squared_error',
        # activation function applied to the hidden nodes.
        # the default value is 'sigmoid'.
        # as other options, we support 'linear' and 'tanh'.
        # NOTE: OS-ELM can apply an activation function only to the hidden nodes.
        activation='sigmoid',
    )

    # ===========================================
    # Prepare dataset
    # ===========================================
    n_classes = n_output_nodes

    # load Iris
    iris = datasets.load_iris()
    x_iris, t_iris = iris.data, iris.target

    # standardize each feature column (zero mean, unit variance)
    mean = np.mean(x_iris, axis=0)
    std = np.std(x_iris, axis=0)
    x_iris = (x_iris - mean) / std

    # convert labels into one-hot vectors.
    t_iris = to_categorical(t_iris, num_classes=n_classes)

    # shuffle dataset
    perm = np.random.permutation(len(x_iris))
    x_iris = x_iris[perm]
    t_iris = t_iris[perm]

    # divide dataset for training and testing
    border = int(len(x_iris) * 0.8)
    x_train, x_test = x_iris[:border], x_iris[border:]
    t_train, t_test = t_iris[:border], t_iris[border:]

    # divide the training dataset into two datasets:
    # (1) for the initial training phase
    # (2) for the sequential training phase
    # NOTE: the number of training samples for the initial training phase
    # must be much greater than the number of the model's hidden nodes.
    # here, we assign int(1.2 * n_hidden_nodes) training samples
    # to the initial training phase.
    border = int(1.2 * n_hidden_nodes)
    x_train_init = x_train[:border]
    x_train_seq = x_train[border:]
    t_train_init = t_train[:border]
    t_train_seq = t_train[border:]

    # ===========================================
    # Training
    # ===========================================
    # the initial training phase
    pbar = tqdm.tqdm(total=len(x_train), desc='initial training phase')
    os_elm.init_train(x_train_init, t_train_init)
    pbar.update(n=len(x_train_init))

    # the sequential training phase
    pbar.set_description('sequential training phase')
    batch_size = 8
    for i in range(0, len(x_train_seq), batch_size):
        x_batch = x_train_seq[i:i+batch_size]
        t_batch = t_train_seq[i:i+batch_size]
        os_elm.seq_train(x_batch, t_batch)
        pbar.update(n=len(x_batch))
    pbar.close()

    # ===========================================
    # Prediction
    # ===========================================
    # sample 10 validation samples from x_test
    n = 10
    x = x_test[:n]
    t = t_test[:n]

    # the 'predict' method returns the raw values of the output nodes.
    y = os_elm.predict(x)
    # apply the softmax function to the output values.
    y = softmax(y)

    # check the answers.
    for i in range(n):
        max_ind = np.argmax(y[i])
        print('========== sample index %d ==========' % i)
        print('estimated answer: class %d' % max_ind)
        print('estimated probability: %.3f' % y[i, max_ind])
        print('true answer: class %d' % np.argmax(t[i]))

    # ===========================================
    # Evaluation
    # ===========================================
    # we currently support 'loss' and 'accuracy' for 'metrics'.
    # NOTE: 'accuracy' is valid only for classification problems,
    # while 'loss' is always valid.
    # e.g. [loss] = os_elm.evaluate(x_test, t_test, metrics=['loss'])
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

    # ===========================================
    # Save model
    # ===========================================
    print('saving model parameters...')
    os_elm.save('./checkpoint/model.ckpt')

    # initialize weights of os_elm
    os_elm.initialize_variables()

    # ===========================================
    # Load model
    # ===========================================
    # If you want to load weights into a model,
    # the architecture of the model must be exactly the same
    # as the one used when the weights were saved.
    print('restoring model parameters...')
    os_elm.restore('./checkpoint/model.ckpt')

    # ===========================================
    # ReEvaluation
    # ===========================================
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
from keras.datasets import mnist
from keras.utils import to_categorical
from os_elm import OS_ELM
import numpy as np
import tqdm

def softmax(a):
    # numerically stable softmax: subtract the row-wise max before exp.
    c = np.max(a, axis=-1).reshape(-1, 1)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a, axis=-1).reshape(-1, 1)
    return exp_a / sum_exp_a

def main():

    # ===========================================
    # Instantiate os-elm
    # ===========================================
    n_input_nodes = 784
    n_hidden_nodes = 1024
    n_output_nodes = 10

    os_elm = OS_ELM(
        # the number of input nodes.
        n_input_nodes=n_input_nodes,
        # the number of hidden nodes.
        n_hidden_nodes=n_hidden_nodes,
        # the number of output nodes.
        n_output_nodes=n_output_nodes,
        # loss function.
        # the default value is 'mean_squared_error'.
        # as other options, we support 'mean_absolute_error',
        # 'categorical_crossentropy', and 'binary_crossentropy'.
        loss='mean_squared_error',
        # activation function applied to the hidden nodes.
        # the default value is 'sigmoid'.
        # as other options, we support 'linear' and 'tanh'.
        # NOTE: OS-ELM can apply an activation function only to the hidden nodes.
        activation='sigmoid',
    )

    # ===========================================
    # Prepare dataset
    # ===========================================
    n_classes = n_output_nodes

    # load MNIST
    (x_train, t_train), (x_test, t_test) = mnist.load_data()
    # normalize pixel values to [0, 1]
    x_train = x_train.reshape(-1, n_input_nodes) / 255.
    x_test = x_test.reshape(-1, n_input_nodes) / 255.
    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)

    # convert labels into one-hot vectors.
    t_train = to_categorical(t_train, num_classes=n_classes)
    t_test = to_categorical(t_test, num_classes=n_classes)
    t_train = t_train.astype(np.float32)
    t_test = t_test.astype(np.float32)

    # divide the training dataset into two datasets:
    # (1) for the initial training phase
    # (2) for the sequential training phase
    # NOTE: the number of training samples for the initial training phase
    # must be much greater than the number of the model's hidden nodes.
    # here, we assign int(1.5 * n_hidden_nodes) training samples
    # to the initial training phase.
    border = int(1.5 * n_hidden_nodes)
    x_train_init = x_train[:border]
    x_train_seq = x_train[border:]
    t_train_init = t_train[:border]
    t_train_seq = t_train[border:]

    # ===========================================
    # Training
    # ===========================================
    # the initial training phase
    pbar = tqdm.tqdm(total=len(x_train), desc='initial training phase')
    os_elm.init_train(x_train_init, t_train_init)
    pbar.update(n=len(x_train_init))

    # the sequential training phase
    pbar.set_description('sequential training phase')
    batch_size = 64
    for i in range(0, len(x_train_seq), batch_size):
        x_batch = x_train_seq[i:i+batch_size]
        t_batch = t_train_seq[i:i+batch_size]
        os_elm.seq_train(x_batch, t_batch)
        pbar.update(n=len(x_batch))
    pbar.close()

    # ===========================================
    # Prediction
    # ===========================================
    # sample 10 validation samples from x_test
    n = 10
    x = x_test[:n]
    t = t_test[:n]

    # the 'predict' method returns the raw values of the output nodes.
    y = os_elm.predict(x)
    # apply the softmax function to the output values.
    y = softmax(y)

    # check the answers.
    for i in range(n):
        max_ind = np.argmax(y[i])
        print('========== sample index %d ==========' % i)
        print('estimated answer: class %d' % max_ind)
        print('estimated probability: %.3f' % y[i, max_ind])
        print('true answer: class %d' % np.argmax(t[i]))

    # ===========================================
    # Evaluation
    # ===========================================
    # we currently support 'loss' and 'accuracy' for 'metrics'.
    # NOTE: 'accuracy' is valid only for classification problems,
    # while 'loss' is always valid.
    # e.g. [loss] = os_elm.evaluate(x_test, t_test, metrics=['loss'])
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

    # ===========================================
    # Save model
    # ===========================================
    print('saving model parameters...')
    os_elm.save('./checkpoint/model.ckpt')

    # initialize weights of os_elm
    os_elm.initialize_variables()

    # ===========================================
    # Load model
    # ===========================================
    # If you want to load weights into a model,
    # the architecture of the model must be exactly the same
    # as the one used when the weights were saved.
    print('restoring model parameters...')
    os_elm.restore('./checkpoint/model.ckpt')

    # ===========================================
    # ReEvaluation
    # ===========================================
    [loss, accuracy] = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------