├── .ipynb_checkpoints
├── CA_HOUSE_LINEAR_REGRESSION-checkpoint.ipynb
├── MNIST_LINEAR_CLASSIFIER-checkpoint.ipynb
└── MNIST_MLN_CLASSIFIER-checkpoint.ipynb
├── CA_HOUSE_LINEAR_REGRESSION.ipynb
├── DIGIT_MLN_DROPOUT.ipynb
├── MNIST_AUTOENCODER.ipynb
├── MNIST_BILSTM.ipynb
├── MNIST_CNN.ipynb
├── MNIST_LINEAR_CLASSIFIER.ipynb
├── MNIST_LSTM.ipynb
├── MNIST_MLN_BN.ipynb
├── MNIST_MLN_CLASSIFIER.ipynb
├── MNIST_MLN_CLASSIFIER_DROPOUT.ipynb
├── README.md
└── datasets
└── housing.csv
/.ipynb_checkpoints/CA_HOUSE_LINEAR_REGRESSION-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import pandas as pd\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {
18 | "collapsed": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "housing_data = pd.read_csv('datasets/housing.csv')"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "data": {
32 | "text/html": [
33 | "
\n",
34 | "\n",
47 | "
\n",
48 | " \n",
49 | " \n",
50 | " | \n",
51 | " longitude | \n",
52 | " latitude | \n",
53 | " housing_median_age | \n",
54 | " total_rooms | \n",
55 | " total_bedrooms | \n",
56 | " population | \n",
57 | " households | \n",
58 | " median_income | \n",
59 | " median_house_value | \n",
60 | " ocean_proximity | \n",
61 | "
\n",
62 | " \n",
63 | " \n",
64 | " \n",
65 | " 0 | \n",
66 | " -122.23 | \n",
67 | " 37.88 | \n",
68 | " 41.0 | \n",
69 | " 880.0 | \n",
70 | " 129.0 | \n",
71 | " 322.0 | \n",
72 | " 126.0 | \n",
73 | " 8.3252 | \n",
74 | " 452600.0 | \n",
75 | " NEAR BAY | \n",
76 | "
\n",
77 | " \n",
78 | " 1 | \n",
79 | " -122.22 | \n",
80 | " 37.86 | \n",
81 | " 21.0 | \n",
82 | " 7099.0 | \n",
83 | " 1106.0 | \n",
84 | " 2401.0 | \n",
85 | " 1138.0 | \n",
86 | " 8.3014 | \n",
87 | " 358500.0 | \n",
88 | " NEAR BAY | \n",
89 | "
\n",
90 | " \n",
91 | " 2 | \n",
92 | " -122.24 | \n",
93 | " 37.85 | \n",
94 | " 52.0 | \n",
95 | " 1467.0 | \n",
96 | " 190.0 | \n",
97 | " 496.0 | \n",
98 | " 177.0 | \n",
99 | " 7.2574 | \n",
100 | " 352100.0 | \n",
101 | " NEAR BAY | \n",
102 | "
\n",
103 | " \n",
104 | " 3 | \n",
105 | " -122.25 | \n",
106 | " 37.85 | \n",
107 | " 52.0 | \n",
108 | " 1274.0 | \n",
109 | " 235.0 | \n",
110 | " 558.0 | \n",
111 | " 219.0 | \n",
112 | " 5.6431 | \n",
113 | " 341300.0 | \n",
114 | " NEAR BAY | \n",
115 | "
\n",
116 | " \n",
117 | " 4 | \n",
118 | " -122.25 | \n",
119 | " 37.85 | \n",
120 | " 52.0 | \n",
121 | " 1627.0 | \n",
122 | " 280.0 | \n",
123 | " 565.0 | \n",
124 | " 259.0 | \n",
125 | " 3.8462 | \n",
126 | " 342200.0 | \n",
127 | " NEAR BAY | \n",
128 | "
\n",
129 | " \n",
130 | "
\n",
131 | "
"
132 | ],
133 | "text/plain": [
134 | " longitude latitude housing_median_age total_rooms total_bedrooms \\\n",
135 | "0 -122.23 37.88 41.0 880.0 129.0 \n",
136 | "1 -122.22 37.86 21.0 7099.0 1106.0 \n",
137 | "2 -122.24 37.85 52.0 1467.0 190.0 \n",
138 | "3 -122.25 37.85 52.0 1274.0 235.0 \n",
139 | "4 -122.25 37.85 52.0 1627.0 280.0 \n",
140 | "\n",
141 | " population households median_income median_house_value ocean_proximity \n",
142 | "0 322.0 126.0 8.3252 452600.0 NEAR BAY \n",
143 | "1 2401.0 1138.0 8.3014 358500.0 NEAR BAY \n",
144 | "2 496.0 177.0 7.2574 352100.0 NEAR BAY \n",
145 | "3 558.0 219.0 5.6431 341300.0 NEAR BAY \n",
146 | "4 565.0 259.0 3.8462 342200.0 NEAR BAY "
147 | ]
148 | },
149 | "execution_count": 3,
150 | "metadata": {},
151 | "output_type": "execute_result"
152 | }
153 | ],
154 | "source": [
155 | "housing_data.head()\n",
156 | "# 社区数据"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 4,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "\n",
169 | "RangeIndex: 20640 entries, 0 to 20639\n",
170 | "Data columns (total 10 columns):\n",
171 | "longitude 20640 non-null float64\n",
172 | "latitude 20640 non-null float64\n",
173 | "housing_median_age 20640 non-null float64\n",
174 | "total_rooms 20640 non-null float64\n",
175 | "total_bedrooms 20433 non-null float64\n",
176 | "population 20640 non-null float64\n",
177 | "households 20640 non-null float64\n",
178 | "median_income 20640 non-null float64\n",
179 | "median_house_value 20640 non-null float64\n",
180 | "ocean_proximity 20640 non-null object\n",
181 | "dtypes: float64(9), object(1)\n",
182 | "memory usage: 1.6+ MB\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "housing_data.info()"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 5,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "\n",
200 | "RangeIndex: 20640 entries, 0 to 20639\n",
201 | "Data columns (total 10 columns):\n",
202 | "longitude 20640 non-null float64\n",
203 | "latitude 20640 non-null float64\n",
204 | "housing_median_age 20640 non-null float64\n",
205 | "total_rooms 20640 non-null float64\n",
206 | "total_bedrooms 20640 non-null float64\n",
207 | "population 20640 non-null float64\n",
208 | "households 20640 non-null float64\n",
209 | "median_income 20640 non-null float64\n",
210 | "median_house_value 20640 non-null float64\n",
211 | "ocean_proximity 20640 non-null object\n",
212 | "dtypes: float64(9), object(1)\n",
213 | "memory usage: 1.6+ MB\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "median = housing_data['total_bedrooms'].median()\n",
219 | "housing_data = housing_data.fillna(median)\n",
220 | "\n",
221 | "housing_data.info()"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 6,
227 | "metadata": {
228 | "collapsed": true
229 | },
230 | "outputs": [],
231 | "source": [
232 | "housing_data = housing_data.drop('ocean_proximity', axis=1)\n",
233 | "\n"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 7,
239 | "metadata": {},
240 | "outputs": [
241 | {
242 | "name": "stdout",
243 | "output_type": "stream",
244 | "text": [
245 | "\n",
246 | "RangeIndex: 20640 entries, 0 to 20639\n",
247 | "Data columns (total 9 columns):\n",
248 | "longitude 20640 non-null float64\n",
249 | "latitude 20640 non-null float64\n",
250 | "housing_median_age 20640 non-null float64\n",
251 | "total_rooms 20640 non-null float64\n",
252 | "total_bedrooms 20640 non-null float64\n",
253 | "population 20640 non-null float64\n",
254 | "households 20640 non-null float64\n",
255 | "median_income 20640 non-null float64\n",
256 | "median_house_value 20640 non-null float64\n",
257 | "dtypes: float64(9)\n",
258 | "memory usage: 1.4 MB\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 | "housing_data.info()"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 8,
269 | "metadata": {
270 | "collapsed": true
271 | },
272 | "outputs": [],
273 | "source": [
274 | "\n",
275 | "from sklearn.model_selection import train_test_split\n",
276 | "\n",
277 | "train_set, test_set = train_test_split(housing_data, test_size=0.2, random_state=33)"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 9,
283 | "metadata": {
284 | "collapsed": true
285 | },
286 | "outputs": [],
287 | "source": [
288 | "y_train = train_set['median_house_value']\n",
289 | "X_train = train_set.drop('median_house_value', axis=1)\n",
290 | "\n",
291 | "y_test = test_set['median_house_value']\n",
292 | "X_test = test_set.drop('median_house_value', axis=1)\n",
293 | "\n",
294 | "from sklearn.linear_model import LinearRegression\n",
295 | "\n",
296 | "lr = LinearRegression()\n",
297 | "\n",
298 | "lr.fit(X_train, y_train)\n",
299 | "\n",
300 | "from sklearn.metrics.regression import mean_squared_error\n",
301 | "\n",
302 | "y_predict = lr.predict(X_test)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 10,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/plain": [
313 | "69990.825384776152"
314 | ]
315 | },
316 | "execution_count": 10,
317 | "metadata": {},
318 | "output_type": "execute_result"
319 | }
320 | ],
321 | "source": [
322 | "import numpy as np\n",
323 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 11,
329 | "metadata": {
330 | "collapsed": true
331 | },
332 | "outputs": [],
333 | "source": [
334 | "from sklearn.preprocessing import StandardScaler\n",
335 | "\n",
336 | "ss = StandardScaler()\n",
337 | "\n",
338 | "X_train = ss.fit_transform(X_train)\n",
339 | "X_test = ss.fit_transform(X_test)"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 13,
345 | "metadata": {},
346 | "outputs": [
347 | {
348 | "name": "stdout",
349 | "output_type": "stream",
350 | "text": [
351 | "69620.6890544\n",
352 | "69901.4152052\n"
353 | ]
354 | }
355 | ],
356 | "source": [
357 | "lr = LinearRegression()\n",
358 | "\n",
359 | "lr.fit(X_train, y_train)\n",
360 | "\n",
361 | "from sklearn.metrics.regression import mean_squared_error\n",
362 | "\n",
363 | "\n",
364 | "y_predict = lr.predict(X_train)\n",
365 | "\n",
366 | "import numpy as np\n",
367 | "\n",
368 | "\n",
369 | "print np.sqrt(mean_squared_error(y_train, y_predict))\n",
370 | "\n",
371 | "y_predict = lr.predict(X_test)\n",
372 | "\n",
373 | "\n",
374 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 14,
380 | "metadata": {},
381 | "outputs": [
382 | {
383 | "name": "stdout",
384 | "output_type": "stream",
385 | "text": [
386 | "0.0\n",
387 | "79017.4707616\n"
388 | ]
389 | }
390 | ],
391 | "source": [
392 | "from sklearn.tree import DecisionTreeRegressor\n",
393 | "\n",
394 | "dtr = DecisionTreeRegressor()\n",
395 | "\n",
396 | "dtr.fit(X_train, y_train)\n",
397 | "\n",
398 | "y_predict = dtr.predict(X_train)\n",
399 | "\n",
400 | "print np.sqrt(mean_squared_error(y_train, y_predict))\n",
401 | "\n",
402 | "from sklearn.metrics.regression import mean_squared_error\n",
403 | "\n",
404 | "y_predict = dtr.predict(X_test)\n",
405 | "\n",
406 | "import numpy as np\n",
407 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {
414 | "collapsed": true
415 | },
416 | "outputs": [],
417 | "source": []
418 | }
419 | ],
420 | "metadata": {
421 | "kernelspec": {
422 | "display_name": "Python 2",
423 | "language": "python",
424 | "name": "python2"
425 | },
426 | "language_info": {
427 | "codemirror_mode": {
428 | "name": "ipython",
429 | "version": 2
430 | },
431 | "file_extension": ".py",
432 | "mimetype": "text/x-python",
433 | "name": "python",
434 | "nbconvert_exporter": "python",
435 | "pygments_lexer": "ipython2",
436 | "version": "2.7.13"
437 | }
438 | },
439 | "nbformat": 4,
440 | "nbformat_minor": 2
441 | }
442 |
--------------------------------------------------------------------------------
/CA_HOUSE_LINEAR_REGRESSION.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import pandas as pd\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {
18 | "collapsed": true
19 | },
20 | "outputs": [],
21 | "source": [
22 | "housing_data = pd.read_csv('datasets/housing.csv')"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "data": {
32 | "text/html": [
33 | "\n",
34 | "\n",
47 | "
\n",
48 | " \n",
49 | " \n",
50 | " | \n",
51 | " longitude | \n",
52 | " latitude | \n",
53 | " housing_median_age | \n",
54 | " total_rooms | \n",
55 | " total_bedrooms | \n",
56 | " population | \n",
57 | " households | \n",
58 | " median_income | \n",
59 | " median_house_value | \n",
60 | " ocean_proximity | \n",
61 | "
\n",
62 | " \n",
63 | " \n",
64 | " \n",
65 | " 0 | \n",
66 | " -122.23 | \n",
67 | " 37.88 | \n",
68 | " 41.0 | \n",
69 | " 880.0 | \n",
70 | " 129.0 | \n",
71 | " 322.0 | \n",
72 | " 126.0 | \n",
73 | " 8.3252 | \n",
74 | " 452600.0 | \n",
75 | " NEAR BAY | \n",
76 | "
\n",
77 | " \n",
78 | " 1 | \n",
79 | " -122.22 | \n",
80 | " 37.86 | \n",
81 | " 21.0 | \n",
82 | " 7099.0 | \n",
83 | " 1106.0 | \n",
84 | " 2401.0 | \n",
85 | " 1138.0 | \n",
86 | " 8.3014 | \n",
87 | " 358500.0 | \n",
88 | " NEAR BAY | \n",
89 | "
\n",
90 | " \n",
91 | " 2 | \n",
92 | " -122.24 | \n",
93 | " 37.85 | \n",
94 | " 52.0 | \n",
95 | " 1467.0 | \n",
96 | " 190.0 | \n",
97 | " 496.0 | \n",
98 | " 177.0 | \n",
99 | " 7.2574 | \n",
100 | " 352100.0 | \n",
101 | " NEAR BAY | \n",
102 | "
\n",
103 | " \n",
104 | " 3 | \n",
105 | " -122.25 | \n",
106 | " 37.85 | \n",
107 | " 52.0 | \n",
108 | " 1274.0 | \n",
109 | " 235.0 | \n",
110 | " 558.0 | \n",
111 | " 219.0 | \n",
112 | " 5.6431 | \n",
113 | " 341300.0 | \n",
114 | " NEAR BAY | \n",
115 | "
\n",
116 | " \n",
117 | " 4 | \n",
118 | " -122.25 | \n",
119 | " 37.85 | \n",
120 | " 52.0 | \n",
121 | " 1627.0 | \n",
122 | " 280.0 | \n",
123 | " 565.0 | \n",
124 | " 259.0 | \n",
125 | " 3.8462 | \n",
126 | " 342200.0 | \n",
127 | " NEAR BAY | \n",
128 | "
\n",
129 | " \n",
130 | "
\n",
131 | "
"
132 | ],
133 | "text/plain": [
134 | " longitude latitude housing_median_age total_rooms total_bedrooms \\\n",
135 | "0 -122.23 37.88 41.0 880.0 129.0 \n",
136 | "1 -122.22 37.86 21.0 7099.0 1106.0 \n",
137 | "2 -122.24 37.85 52.0 1467.0 190.0 \n",
138 | "3 -122.25 37.85 52.0 1274.0 235.0 \n",
139 | "4 -122.25 37.85 52.0 1627.0 280.0 \n",
140 | "\n",
141 | " population households median_income median_house_value ocean_proximity \n",
142 | "0 322.0 126.0 8.3252 452600.0 NEAR BAY \n",
143 | "1 2401.0 1138.0 8.3014 358500.0 NEAR BAY \n",
144 | "2 496.0 177.0 7.2574 352100.0 NEAR BAY \n",
145 | "3 558.0 219.0 5.6431 341300.0 NEAR BAY \n",
146 | "4 565.0 259.0 3.8462 342200.0 NEAR BAY "
147 | ]
148 | },
149 | "execution_count": 3,
150 | "metadata": {},
151 | "output_type": "execute_result"
152 | }
153 | ],
154 | "source": [
155 | "housing_data.head()\n",
156 | "# 社区数据"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 4,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "\n",
169 | "RangeIndex: 20640 entries, 0 to 20639\n",
170 | "Data columns (total 10 columns):\n",
171 | "longitude 20640 non-null float64\n",
172 | "latitude 20640 non-null float64\n",
173 | "housing_median_age 20640 non-null float64\n",
174 | "total_rooms 20640 non-null float64\n",
175 | "total_bedrooms 20433 non-null float64\n",
176 | "population 20640 non-null float64\n",
177 | "households 20640 non-null float64\n",
178 | "median_income 20640 non-null float64\n",
179 | "median_house_value 20640 non-null float64\n",
180 | "ocean_proximity 20640 non-null object\n",
181 | "dtypes: float64(9), object(1)\n",
182 | "memory usage: 1.6+ MB\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "housing_data.info()"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 5,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "\n",
200 | "RangeIndex: 20640 entries, 0 to 20639\n",
201 | "Data columns (total 10 columns):\n",
202 | "longitude 20640 non-null float64\n",
203 | "latitude 20640 non-null float64\n",
204 | "housing_median_age 20640 non-null float64\n",
205 | "total_rooms 20640 non-null float64\n",
206 | "total_bedrooms 20640 non-null float64\n",
207 | "population 20640 non-null float64\n",
208 | "households 20640 non-null float64\n",
209 | "median_income 20640 non-null float64\n",
210 | "median_house_value 20640 non-null float64\n",
211 | "ocean_proximity 20640 non-null object\n",
212 | "dtypes: float64(9), object(1)\n",
213 | "memory usage: 1.6+ MB\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "median = housing_data['total_bedrooms'].median()\n",
219 | "housing_data = housing_data.fillna(median)\n",
220 | "\n",
221 | "housing_data.info()"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 6,
227 | "metadata": {
228 | "collapsed": true
229 | },
230 | "outputs": [],
231 | "source": [
232 | "housing_data = housing_data.drop('ocean_proximity', axis=1)\n",
233 | "\n"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 7,
239 | "metadata": {},
240 | "outputs": [
241 | {
242 | "name": "stdout",
243 | "output_type": "stream",
244 | "text": [
245 | "\n",
246 | "RangeIndex: 20640 entries, 0 to 20639\n",
247 | "Data columns (total 9 columns):\n",
248 | "longitude 20640 non-null float64\n",
249 | "latitude 20640 non-null float64\n",
250 | "housing_median_age 20640 non-null float64\n",
251 | "total_rooms 20640 non-null float64\n",
252 | "total_bedrooms 20640 non-null float64\n",
253 | "population 20640 non-null float64\n",
254 | "households 20640 non-null float64\n",
255 | "median_income 20640 non-null float64\n",
256 | "median_house_value 20640 non-null float64\n",
257 | "dtypes: float64(9)\n",
258 | "memory usage: 1.4 MB\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 | "housing_data.info()"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 8,
269 | "metadata": {
270 | "collapsed": true
271 | },
272 | "outputs": [],
273 | "source": [
274 | "\n",
275 | "from sklearn.model_selection import train_test_split\n",
276 | "\n",
277 | "train_set, test_set = train_test_split(housing_data, test_size=0.2, random_state=33)"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 9,
283 | "metadata": {
284 | "collapsed": true
285 | },
286 | "outputs": [],
287 | "source": [
288 | "y_train = train_set['median_house_value']\n",
289 | "X_train = train_set.drop('median_house_value', axis=1)\n",
290 | "\n",
291 | "y_test = test_set['median_house_value']\n",
292 | "X_test = test_set.drop('median_house_value', axis=1)\n",
293 | "\n",
294 | "from sklearn.linear_model import LinearRegression\n",
295 | "\n",
296 | "lr = LinearRegression()\n",
297 | "\n",
298 | "lr.fit(X_train, y_train)\n",
299 | "\n",
300 | "from sklearn.metrics.regression import mean_squared_error\n",
301 | "\n",
302 | "y_predict = lr.predict(X_test)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 10,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/plain": [
313 | "69990.825384776152"
314 | ]
315 | },
316 | "execution_count": 10,
317 | "metadata": {},
318 | "output_type": "execute_result"
319 | }
320 | ],
321 | "source": [
322 | "import numpy as np\n",
323 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 11,
329 | "metadata": {
330 | "collapsed": true
331 | },
332 | "outputs": [],
333 | "source": [
334 | "from sklearn.preprocessing import StandardScaler\n",
335 | "\n",
336 | "ss = StandardScaler()\n",
337 | "\n",
338 | "X_train = ss.fit_transform(X_train)\n",
339 | "X_test = ss.fit_transform(X_test)"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 13,
345 | "metadata": {},
346 | "outputs": [
347 | {
348 | "name": "stdout",
349 | "output_type": "stream",
350 | "text": [
351 | "69620.6890544\n",
352 | "69901.4152052\n"
353 | ]
354 | }
355 | ],
356 | "source": [
357 | "lr = LinearRegression()\n",
358 | "\n",
359 | "lr.fit(X_train, y_train)\n",
360 | "\n",
361 | "from sklearn.metrics.regression import mean_squared_error\n",
362 | "\n",
363 | "\n",
364 | "y_predict = lr.predict(X_train)\n",
365 | "\n",
366 | "import numpy as np\n",
367 | "\n",
368 | "\n",
369 | "print np.sqrt(mean_squared_error(y_train, y_predict))\n",
370 | "\n",
371 | "y_predict = lr.predict(X_test)\n",
372 | "\n",
373 | "\n",
374 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 14,
380 | "metadata": {},
381 | "outputs": [
382 | {
383 | "name": "stdout",
384 | "output_type": "stream",
385 | "text": [
386 | "0.0\n",
387 | "79017.4707616\n"
388 | ]
389 | }
390 | ],
391 | "source": [
392 | "from sklearn.tree import DecisionTreeRegressor\n",
393 | "\n",
394 | "dtr = DecisionTreeRegressor()\n",
395 | "\n",
396 | "dtr.fit(X_train, y_train)\n",
397 | "\n",
398 | "y_predict = dtr.predict(X_train)\n",
399 | "\n",
400 | "print np.sqrt(mean_squared_error(y_train, y_predict))\n",
401 | "\n",
402 | "from sklearn.metrics.regression import mean_squared_error\n",
403 | "\n",
404 | "y_predict = dtr.predict(X_test)\n",
405 | "\n",
406 | "import numpy as np\n",
407 | "print np.sqrt(mean_squared_error(y_test, y_predict))"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {
414 | "collapsed": true
415 | },
416 | "outputs": [],
417 | "source": []
418 | }
419 | ],
420 | "metadata": {
421 | "kernelspec": {
422 | "display_name": "Python 2",
423 | "language": "python",
424 | "name": "python2"
425 | },
426 | "language_info": {
427 | "codemirror_mode": {
428 | "name": "ipython",
429 | "version": 2
430 | },
431 | "file_extension": ".py",
432 | "mimetype": "text/x-python",
433 | "name": "python",
434 | "nbconvert_exporter": "python",
435 | "pygments_lexer": "ipython2",
436 | "version": "2.7.13"
437 | }
438 | },
439 | "nbformat": 4,
440 | "nbformat_minor": 2
441 | }
442 |
--------------------------------------------------------------------------------
/MNIST_AUTOENCODER.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "from tensorflow.examples.tutorials.mnist import input_data\n",
13 | "import numpy as np\n",
14 | "import matplotlib\n",
15 | "matplotlib.use('nbagg')\n",
16 | "\n",
17 | "import matplotlib.pyplot as plt"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {
24 | "collapsed": false
25 | },
26 | "outputs": [
27 | {
28 | "name": "stdout",
29 | "output_type": "stream",
30 | "text": [
31 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
32 | "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
33 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
34 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
35 | ]
36 | }
37 | ],
38 | "source": [
39 | "mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {
46 | "collapsed": false
47 | },
48 | "outputs": [
49 | {
50 | "name": "stdout",
51 | "output_type": "stream",
52 | "text": [
53 | "(55000, 784) (55000, 10)\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "print mnist.train.images.shape, mnist.train.labels.shape"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 4,
64 | "metadata": {
65 | "collapsed": false
66 | },
67 | "outputs": [],
68 | "source": [
69 | "INPUT_UNITS = 784\n",
70 | "HIDDEN_UNITS = [256]\n",
71 | "ENCODED_UNITS = 128\n",
72 | "\n",
73 | "LEARNING_RATE = 1e-3\n",
74 | "TRAINING_ITER = 2000\n",
75 | "BATCH_SIZE = 100\n",
76 | "\n",
77 | "def init_weights(shape):\n",
78 | " return tf.Variable(tf.truncated_normal(shape, stddev=0.1))\n",
79 | "\n",
80 | "def init_biases(shape):\n",
81 | " return tf.Variable(tf.zeros(shape) + 0.1)\n",
82 | "\n",
83 | "def add_layer(x, weights, biases, activation_function = None):\n",
84 | " scores = tf.matmul(x, weights) + biases\n",
85 | " \n",
86 | " if activation_function:\n",
87 | " return activation_function(scores)\n",
88 | " else:\n",
89 | " return scores\n",
90 | "\n",
91 | "weights = {\n",
92 | " 'encoder_l1': init_weights([INPUT_UNITS, HIDDEN_UNITS[0]]), \n",
93 | " 'encoder_l2': init_weights([HIDDEN_UNITS[0], ENCODED_UNITS]),\n",
94 | " 'decoder_l1': init_weights([ENCODED_UNITS, HIDDEN_UNITS[0]]),\n",
95 | " 'decoder_l2': init_weights([HIDDEN_UNITS[0], INPUT_UNITS])\n",
96 | "}\n",
97 | "\n",
98 | "biases = {\n",
99 | " 'encoder_l1': init_biases([HIDDEN_UNITS[0]]), \n",
100 | " 'encoder_l2': init_biases([ENCODED_UNITS]),\n",
101 | " 'decoder_l1': init_biases(HIDDEN_UNITS[0]),\n",
102 | " 'decoder_l2': init_biases([INPUT_UNITS])\n",
103 | "}\n",
104 | "\n",
105 | "def encoder(x):\n",
106 | " encoder_h1 = add_layer(x, weights['encoder_l1'], biases['encoder_l1'], tf.nn.relu)\n",
107 | " return add_layer(encoder_h1, weights['encoder_l2'], biases['encoder_l2'], tf.nn.relu) \n",
108 | "\n",
109 | "def decoder(x):\n",
110 | " decoder_h1 = add_layer(x, weights['decoder_l1'], biases['decoder_l1'], tf.nn.relu)\n",
111 | " return add_layer(decoder_h1, weights['decoder_l2'], biases['decoder_l2'], tf.nn.relu)\n",
112 | "\n",
113 | "x = tf.placeholder(tf.float32, [None, 784])\n",
114 | "\n",
115 | "encoder_op = encoder(x)\n",
116 | "x_ = decoder(encoder_op)\n",
117 | "\n",
118 | "loss = tf.reduce_mean(tf.square(x - x_))\n",
119 | "\n",
120 | "train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)\n"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 5,
126 | "metadata": {
127 | "collapsed": false
128 | },
129 | "outputs": [
130 | {
131 | "name": "stdout",
132 | "output_type": "stream",
133 | "text": [
134 | "0.200395\n",
135 | "0.0587475\n",
136 | "0.039646\n",
137 | "0.0379133\n",
138 | "0.0267893\n",
139 | "0.026916\n",
140 | "0.0215286\n",
141 | "0.0191278\n",
142 | "0.0185347\n",
143 | "0.0168091\n",
144 | "0.0186432\n",
145 | "0.018459\n",
146 | "0.0159615\n",
147 | "0.0166012\n",
148 | "0.0161047\n",
149 | "0.0142662\n",
150 | "0.0149797\n",
151 | "0.0151406\n",
152 | "0.0143366\n",
153 | "0.0143979\n",
154 | "0.0138252\n",
155 | "0.0138254\n",
156 | "0.0128702\n",
157 | "0.0133265\n",
158 | "0.0130669\n",
159 | "0.0127066\n",
160 | "0.0140361\n",
161 | "0.012641\n",
162 | "0.013777\n",
163 | "0.0137268\n",
164 | "0.0132932\n",
165 | "0.0122787\n",
166 | "0.011369\n",
167 | "0.0119215\n",
168 | "0.0114884\n",
169 | "0.0122204\n",
170 | "0.0121677\n",
171 | "0.0114692\n",
172 | "0.0120117\n",
173 | "0.0117096\n"
174 | ]
175 | },
176 | {
177 | "data": {
178 | "application/javascript": [
179 | "/* Put everything inside the global mpl namespace */\n",
180 | "window.mpl = {};\n",
181 | "\n",
182 | "\n",
183 | "mpl.get_websocket_type = function() {\n",
184 | " if (typeof(WebSocket) !== 'undefined') {\n",
185 | " return WebSocket;\n",
186 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
187 | " return MozWebSocket;\n",
188 | " } else {\n",
189 | " alert('Your browser does not have WebSocket support.' +\n",
190 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
191 | " 'Firefox 4 and 5 are also supported but you ' +\n",
192 | " 'have to enable WebSockets in about:config.');\n",
193 | " };\n",
194 | "}\n",
195 | "\n",
196 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
197 | " this.id = figure_id;\n",
198 | "\n",
199 | " this.ws = websocket;\n",
200 | "\n",
201 | " this.supports_binary = (this.ws.binaryType != undefined);\n",
202 | "\n",
203 | " if (!this.supports_binary) {\n",
204 | " var warnings = document.getElementById(\"mpl-warnings\");\n",
205 | " if (warnings) {\n",
206 | " warnings.style.display = 'block';\n",
207 | " warnings.textContent = (\n",
208 | " \"This browser does not support binary websocket messages. \" +\n",
209 | " \"Performance may be slow.\");\n",
210 | " }\n",
211 | " }\n",
212 | "\n",
213 | " this.imageObj = new Image();\n",
214 | "\n",
215 | " this.context = undefined;\n",
216 | " this.message = undefined;\n",
217 | " this.canvas = undefined;\n",
218 | " this.rubberband_canvas = undefined;\n",
219 | " this.rubberband_context = undefined;\n",
220 | " this.format_dropdown = undefined;\n",
221 | "\n",
222 | " this.image_mode = 'full';\n",
223 | "\n",
224 | " this.root = $('');\n",
225 | " this._root_extra_style(this.root)\n",
226 | " this.root.attr('style', 'display: inline-block');\n",
227 | "\n",
228 | " $(parent_element).append(this.root);\n",
229 | "\n",
230 | " this._init_header(this);\n",
231 | " this._init_canvas(this);\n",
232 | " this._init_toolbar(this);\n",
233 | "\n",
234 | " var fig = this;\n",
235 | "\n",
236 | " this.waiting = false;\n",
237 | "\n",
238 | " this.ws.onopen = function () {\n",
239 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
240 | " fig.send_message(\"send_image_mode\", {});\n",
241 | " if (mpl.ratio != 1) {\n",
242 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
243 | " }\n",
244 | " fig.send_message(\"refresh\", {});\n",
245 | " }\n",
246 | "\n",
247 | " this.imageObj.onload = function() {\n",
248 | " if (fig.image_mode == 'full') {\n",
249 | " // Full images could contain transparency (where diff images\n",
250 | " // almost always do), so we need to clear the canvas so that\n",
251 | " // there is no ghosting.\n",
252 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
253 | " }\n",
254 | " fig.context.drawImage(fig.imageObj, 0, 0);\n",
255 | " };\n",
256 | "\n",
257 | " this.imageObj.onunload = function() {\n",
258 | " this.ws.close();\n",
259 | " }\n",
260 | "\n",
261 | " this.ws.onmessage = this._make_on_message_function(this);\n",
262 | "\n",
263 | " this.ondownload = ondownload;\n",
264 | "}\n",
265 | "\n",
266 | "mpl.figure.prototype._init_header = function() {\n",
267 | " var titlebar = $(\n",
268 | " '');\n",
270 | " var titletext = $(\n",
271 | " '');\n",
273 | " titlebar.append(titletext)\n",
274 | " this.root.append(titlebar);\n",
275 | " this.header = titletext[0];\n",
276 | "}\n",
277 | "\n",
278 | "\n",
279 | "\n",
280 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
281 | "\n",
282 | "}\n",
283 | "\n",
284 | "\n",
285 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
286 | "\n",
287 | "}\n",
288 | "\n",
289 | "mpl.figure.prototype._init_canvas = function() {\n",
290 | " var fig = this;\n",
291 | "\n",
292 | " var canvas_div = $('');\n",
293 | "\n",
294 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
295 | "\n",
296 | " function canvas_keyboard_event(event) {\n",
297 | " return fig.key_event(event, event['data']);\n",
298 | " }\n",
299 | "\n",
300 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
301 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
302 | " this.canvas_div = canvas_div\n",
303 | " this._canvas_extra_style(canvas_div)\n",
304 | " this.root.append(canvas_div);\n",
305 | "\n",
306 | " var canvas = $('');\n",
307 | " canvas.addClass('mpl-canvas');\n",
308 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
309 | "\n",
310 | " this.canvas = canvas[0];\n",
311 | " this.context = canvas[0].getContext(\"2d\");\n",
312 | "\n",
313 | " var backingStore = this.context.backingStorePixelRatio ||\n",
314 | "\tthis.context.webkitBackingStorePixelRatio ||\n",
315 | "\tthis.context.mozBackingStorePixelRatio ||\n",
316 | "\tthis.context.msBackingStorePixelRatio ||\n",
317 | "\tthis.context.oBackingStorePixelRatio ||\n",
318 | "\tthis.context.backingStorePixelRatio || 1;\n",
319 | "\n",
320 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
321 | "\n",
322 | " var rubberband = $('');\n",
323 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
324 | "\n",
325 | " var pass_mouse_events = true;\n",
326 | "\n",
327 | " canvas_div.resizable({\n",
328 | " start: function(event, ui) {\n",
329 | " pass_mouse_events = false;\n",
330 | " },\n",
331 | " resize: function(event, ui) {\n",
332 | " fig.request_resize(ui.size.width, ui.size.height);\n",
333 | " },\n",
334 | " stop: function(event, ui) {\n",
335 | " pass_mouse_events = true;\n",
336 | " fig.request_resize(ui.size.width, ui.size.height);\n",
337 | " },\n",
338 | " });\n",
339 | "\n",
340 | " function mouse_event_fn(event) {\n",
341 | " if (pass_mouse_events)\n",
342 | " return fig.mouse_event(event, event['data']);\n",
343 | " }\n",
344 | "\n",
345 | " rubberband.mousedown('button_press', mouse_event_fn);\n",
346 | " rubberband.mouseup('button_release', mouse_event_fn);\n",
347 | " // Throttle sequential mouse events to 1 every 20ms.\n",
348 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
349 | "\n",
350 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
351 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
352 | "\n",
353 | " canvas_div.on(\"wheel\", function (event) {\n",
354 | " event = event.originalEvent;\n",
355 | " event['data'] = 'scroll'\n",
356 | " if (event.deltaY < 0) {\n",
357 | " event.step = 1;\n",
358 | " } else {\n",
359 | " event.step = -1;\n",
360 | " }\n",
361 | " mouse_event_fn(event);\n",
362 | " });\n",
363 | "\n",
364 | " canvas_div.append(canvas);\n",
365 | " canvas_div.append(rubberband);\n",
366 | "\n",
367 | " this.rubberband = rubberband;\n",
368 | " this.rubberband_canvas = rubberband[0];\n",
369 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
370 | " this.rubberband_context.strokeStyle = \"#000000\";\n",
371 | "\n",
372 | " this._resize_canvas = function(width, height) {\n",
373 | " // Keep the size of the canvas, canvas container, and rubber band\n",
374 | "        // canvas in sync.\n",
375 | " canvas_div.css('width', width)\n",
376 | " canvas_div.css('height', height)\n",
377 | "\n",
378 | " canvas.attr('width', width * mpl.ratio);\n",
379 | " canvas.attr('height', height * mpl.ratio);\n",
380 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
381 | "\n",
382 | " rubberband.attr('width', width);\n",
383 | " rubberband.attr('height', height);\n",
384 | " }\n",
385 | "\n",
386 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
387 | " // upon first draw.\n",
388 | " this._resize_canvas(600, 600);\n",
389 | "\n",
390 | " // Disable right mouse context menu.\n",
391 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
392 | " return false;\n",
393 | " });\n",
394 | "\n",
395 | " function set_focus () {\n",
396 | " canvas.focus();\n",
397 | " canvas_div.focus();\n",
398 | " }\n",
399 | "\n",
400 | " window.setTimeout(set_focus, 100);\n",
401 | "}\n",
402 | "\n",
403 | "mpl.figure.prototype._init_toolbar = function() {\n",
404 | " var fig = this;\n",
405 | "\n",
406 | " var nav_element = $('')\n",
407 | " nav_element.attr('style', 'width: 100%');\n",
408 | " this.root.append(nav_element);\n",
409 | "\n",
410 | " // Define a callback function for later on.\n",
411 | " function toolbar_event(event) {\n",
412 | " return fig.toolbar_button_onclick(event['data']);\n",
413 | " }\n",
414 | " function toolbar_mouse_event(event) {\n",
415 | " return fig.toolbar_button_onmouseover(event['data']);\n",
416 | " }\n",
417 | "\n",
418 | " for(var toolbar_ind in mpl.toolbar_items) {\n",
419 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
420 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
421 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
422 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
423 | "\n",
424 | " if (!name) {\n",
425 | " // put a spacer in here.\n",
426 | " continue;\n",
427 | " }\n",
428 | " var button = $('');\n",
429 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
430 | " 'ui-button-icon-only');\n",
431 | " button.attr('role', 'button');\n",
432 | " button.attr('aria-disabled', 'false');\n",
433 | " button.click(method_name, toolbar_event);\n",
434 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
435 | "\n",
436 | " var icon_img = $('');\n",
437 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
438 | " icon_img.addClass(image);\n",
439 | " icon_img.addClass('ui-corner-all');\n",
440 | "\n",
441 | " var tooltip_span = $('');\n",
442 | " tooltip_span.addClass('ui-button-text');\n",
443 | " tooltip_span.html(tooltip);\n",
444 | "\n",
445 | " button.append(icon_img);\n",
446 | " button.append(tooltip_span);\n",
447 | "\n",
448 | " nav_element.append(button);\n",
449 | " }\n",
450 | "\n",
451 | " var fmt_picker_span = $('');\n",
452 | "\n",
453 | " var fmt_picker = $('');\n",
454 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
455 | " fmt_picker_span.append(fmt_picker);\n",
456 | " nav_element.append(fmt_picker_span);\n",
457 | " this.format_dropdown = fmt_picker[0];\n",
458 | "\n",
459 | " for (var ind in mpl.extensions) {\n",
460 | " var fmt = mpl.extensions[ind];\n",
461 | " var option = $(\n",
462 | " '', {selected: fmt === mpl.default_extension}).html(fmt);\n",
463 | " fmt_picker.append(option)\n",
464 | " }\n",
465 | "\n",
466 | " // Add hover states to the ui-buttons\n",
467 | " $( \".ui-button\" ).hover(\n",
468 | " function() { $(this).addClass(\"ui-state-hover\");},\n",
469 | " function() { $(this).removeClass(\"ui-state-hover\");}\n",
470 | " );\n",
471 | "\n",
472 | " var status_bar = $('');\n",
473 | " nav_element.append(status_bar);\n",
474 | " this.message = status_bar[0];\n",
475 | "}\n",
476 | "\n",
477 | "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
478 | " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
479 | " // which will in turn request a refresh of the image.\n",
480 | " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
481 | "}\n",
482 | "\n",
483 | "mpl.figure.prototype.send_message = function(type, properties) {\n",
484 | " properties['type'] = type;\n",
485 | " properties['figure_id'] = this.id;\n",
486 | " this.ws.send(JSON.stringify(properties));\n",
487 | "}\n",
488 | "\n",
489 | "mpl.figure.prototype.send_draw_message = function() {\n",
490 | " if (!this.waiting) {\n",
491 | " this.waiting = true;\n",
492 | " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
493 | " }\n",
494 | "}\n",
495 | "\n",
496 | "\n",
497 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
498 | " var format_dropdown = fig.format_dropdown;\n",
499 | " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
500 | " fig.ondownload(fig, format);\n",
501 | "}\n",
502 | "\n",
503 | "\n",
504 | "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
505 | " var size = msg['size'];\n",
506 | " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
507 | " fig._resize_canvas(size[0], size[1]);\n",
508 | " fig.send_message(\"refresh\", {});\n",
509 | " };\n",
510 | "}\n",
511 | "\n",
512 | "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
513 | " var x0 = msg['x0'] / mpl.ratio;\n",
514 | " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
515 | " var x1 = msg['x1'] / mpl.ratio;\n",
516 | " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
517 | " x0 = Math.floor(x0) + 0.5;\n",
518 | " y0 = Math.floor(y0) + 0.5;\n",
519 | " x1 = Math.floor(x1) + 0.5;\n",
520 | " y1 = Math.floor(y1) + 0.5;\n",
521 | " var min_x = Math.min(x0, x1);\n",
522 | " var min_y = Math.min(y0, y1);\n",
523 | " var width = Math.abs(x1 - x0);\n",
524 | " var height = Math.abs(y1 - y0);\n",
525 | "\n",
526 | " fig.rubberband_context.clearRect(\n",
527 | " 0, 0, fig.canvas.width, fig.canvas.height);\n",
528 | "\n",
529 | " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
530 | "}\n",
531 | "\n",
532 | "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
533 | " // Updates the figure title.\n",
534 | " fig.header.textContent = msg['label'];\n",
535 | "}\n",
536 | "\n",
537 | "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
538 | " var cursor = msg['cursor'];\n",
539 | " switch(cursor)\n",
540 | " {\n",
541 | " case 0:\n",
542 | " cursor = 'pointer';\n",
543 | " break;\n",
544 | " case 1:\n",
545 | " cursor = 'default';\n",
546 | " break;\n",
547 | " case 2:\n",
548 | " cursor = 'crosshair';\n",
549 | " break;\n",
550 | " case 3:\n",
551 | " cursor = 'move';\n",
552 | " break;\n",
553 | " }\n",
554 | " fig.rubberband_canvas.style.cursor = cursor;\n",
555 | "}\n",
556 | "\n",
557 | "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
558 | " fig.message.textContent = msg['message'];\n",
559 | "}\n",
560 | "\n",
561 | "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
562 | " // Request the server to send over a new figure.\n",
563 | " fig.send_draw_message();\n",
564 | "}\n",
565 | "\n",
566 | "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
567 | " fig.image_mode = msg['mode'];\n",
568 | "}\n",
569 | "\n",
570 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
571 | " // Called whenever the canvas gets updated.\n",
572 | " this.send_message(\"ack\", {});\n",
573 | "}\n",
574 | "\n",
575 | "// A function to construct a web socket function for onmessage handling.\n",
576 | "// Called in the figure constructor.\n",
577 | "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
578 | " return function socket_on_message(evt) {\n",
579 | " if (evt.data instanceof Blob) {\n",
580 | " /* FIXME: We get \"Resource interpreted as Image but\n",
581 | " * transferred with MIME type text/plain:\" errors on\n",
582 | " * Chrome. But how to set the MIME type? It doesn't seem\n",
583 | " * to be part of the websocket stream */\n",
584 | " evt.data.type = \"image/png\";\n",
585 | "\n",
586 | " /* Free the memory for the previous frames */\n",
587 | " if (fig.imageObj.src) {\n",
588 | " (window.URL || window.webkitURL).revokeObjectURL(\n",
589 | " fig.imageObj.src);\n",
590 | " }\n",
591 | "\n",
592 | " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
593 | " evt.data);\n",
594 | " fig.updated_canvas_event();\n",
595 | " fig.waiting = false;\n",
596 | " return;\n",
597 | " }\n",
598 | " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
599 | " fig.imageObj.src = evt.data;\n",
600 | " fig.updated_canvas_event();\n",
601 | " fig.waiting = false;\n",
602 | " return;\n",
603 | " }\n",
604 | "\n",
605 | " var msg = JSON.parse(evt.data);\n",
606 | " var msg_type = msg['type'];\n",
607 | "\n",
608 | " // Call the \"handle_{type}\" callback, which takes\n",
609 | " // the figure and JSON message as its only arguments.\n",
610 | " try {\n",
611 | " var callback = fig[\"handle_\" + msg_type];\n",
612 | " } catch (e) {\n",
613 | " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
614 | " return;\n",
615 | " }\n",
616 | "\n",
617 | " if (callback) {\n",
618 | " try {\n",
619 | " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
620 | " callback(fig, msg);\n",
621 | " } catch (e) {\n",
622 | " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
623 | " }\n",
624 | " }\n",
625 | " };\n",
626 | "}\n",
627 | "\n",
628 | "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
629 | "mpl.findpos = function(e) {\n",
630 | " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
631 | " var targ;\n",
632 | " if (!e)\n",
633 | " e = window.event;\n",
634 | " if (e.target)\n",
635 | " targ = e.target;\n",
636 | " else if (e.srcElement)\n",
637 | " targ = e.srcElement;\n",
638 | " if (targ.nodeType == 3) // defeat Safari bug\n",
639 | " targ = targ.parentNode;\n",
640 | "\n",
641 | " // jQuery normalizes the pageX and pageY\n",
642 | " // pageX,Y are the mouse positions relative to the document\n",
643 | " // offset() returns the position of the element relative to the document\n",
644 | " var x = e.pageX - $(targ).offset().left;\n",
645 | " var y = e.pageY - $(targ).offset().top;\n",
646 | "\n",
647 | " return {\"x\": x, \"y\": y};\n",
648 | "};\n",
649 | "\n",
650 | "/*\n",
651 | " * return a copy of an object with only non-object keys\n",
652 | " * we need this to avoid circular references\n",
653 | " * http://stackoverflow.com/a/24161582/3208463\n",
654 | " */\n",
655 | "function simpleKeys (original) {\n",
656 | " return Object.keys(original).reduce(function (obj, key) {\n",
657 | " if (typeof original[key] !== 'object')\n",
658 | " obj[key] = original[key]\n",
659 | " return obj;\n",
660 | " }, {});\n",
661 | "}\n",
662 | "\n",
663 | "mpl.figure.prototype.mouse_event = function(event, name) {\n",
664 | " var canvas_pos = mpl.findpos(event)\n",
665 | "\n",
666 | " if (name === 'button_press')\n",
667 | " {\n",
668 | " this.canvas.focus();\n",
669 | " this.canvas_div.focus();\n",
670 | " }\n",
671 | "\n",
672 | " var x = canvas_pos.x * mpl.ratio;\n",
673 | " var y = canvas_pos.y * mpl.ratio;\n",
674 | "\n",
675 | " this.send_message(name, {x: x, y: y, button: event.button,\n",
676 | " step: event.step,\n",
677 | " guiEvent: simpleKeys(event)});\n",
678 | "\n",
679 | " /* This prevents the web browser from automatically changing to\n",
680 | " * the text insertion cursor when the button is pressed. We want\n",
681 | " * to control all of the cursor setting manually through the\n",
682 | " * 'cursor' event from matplotlib */\n",
683 | " event.preventDefault();\n",
684 | " return false;\n",
685 | "}\n",
686 | "\n",
687 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
688 | " // Handle any extra behaviour associated with a key event\n",
689 | "}\n",
690 | "\n",
691 | "mpl.figure.prototype.key_event = function(event, name) {\n",
692 | "\n",
693 | " // Prevent repeat events\n",
694 | " if (name == 'key_press')\n",
695 | " {\n",
696 | " if (event.which === this._key)\n",
697 | " return;\n",
698 | " else\n",
699 | " this._key = event.which;\n",
700 | " }\n",
701 | " if (name == 'key_release')\n",
702 | " this._key = null;\n",
703 | "\n",
704 | " var value = '';\n",
705 | " if (event.ctrlKey && event.which != 17)\n",
706 | " value += \"ctrl+\";\n",
707 | " if (event.altKey && event.which != 18)\n",
708 | " value += \"alt+\";\n",
709 | " if (event.shiftKey && event.which != 16)\n",
710 | " value += \"shift+\";\n",
711 | "\n",
712 | " value += 'k';\n",
713 | " value += event.which.toString();\n",
714 | "\n",
715 | " this._key_event_extra(event, name);\n",
716 | "\n",
717 | " this.send_message(name, {key: value,\n",
718 | " guiEvent: simpleKeys(event)});\n",
719 | " return false;\n",
720 | "}\n",
721 | "\n",
722 | "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
723 | " if (name == 'download') {\n",
724 | " this.handle_save(this, null);\n",
725 | " } else {\n",
726 | " this.send_message(\"toolbar_button\", {name: name});\n",
727 | " }\n",
728 | "};\n",
729 | "\n",
730 | "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
731 | " this.message.textContent = tooltip;\n",
732 | "};\n",
733 | "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
734 | "\n",
735 | "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
736 | "\n",
737 | "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
738 | " // Create a \"websocket\"-like object which calls the given IPython comm\n",
739 | " // object with the appropriate methods. Currently this is a non binary\n",
740 | " // socket, so there is still some room for performance tuning.\n",
741 | " var ws = {};\n",
742 | "\n",
743 | " ws.close = function() {\n",
744 | " comm.close()\n",
745 | " };\n",
746 | " ws.send = function(m) {\n",
747 | " //console.log('sending', m);\n",
748 | " comm.send(m);\n",
749 | " };\n",
750 | " // Register the callback with on_msg.\n",
751 | " comm.on_msg(function(msg) {\n",
752 | " //console.log('receiving', msg['content']['data'], msg);\n",
753 | "        // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
754 | " ws.onmessage(msg['content']['data'])\n",
755 | " });\n",
756 | " return ws;\n",
757 | "}\n",
758 | "\n",
759 | "mpl.mpl_figure_comm = function(comm, msg) {\n",
760 | " // This is the function which gets called when the mpl process\n",
761 | " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
762 | "\n",
763 | " var id = msg.content.data.id;\n",
764 | " // Get hold of the div created by the display call when the Comm\n",
765 | " // socket was opened in Python.\n",
766 | " var element = $(\"#\" + id);\n",
767 | " var ws_proxy = comm_websocket_adapter(comm)\n",
768 | "\n",
769 | " function ondownload(figure, format) {\n",
770 | " window.open(figure.imageObj.src);\n",
771 | " }\n",
772 | "\n",
773 | " var fig = new mpl.figure(id, ws_proxy,\n",
774 | " ondownload,\n",
775 | " element.get(0));\n",
776 | "\n",
777 | " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
778 | " // web socket which is closed, not our websocket->open comm proxy.\n",
779 | " ws_proxy.onopen();\n",
780 | "\n",
781 | " fig.parent_element = element.get(0);\n",
782 | " fig.cell_info = mpl.find_output_cell(\"\");\n",
783 | " if (!fig.cell_info) {\n",
784 | " console.error(\"Failed to find cell for figure\", id, fig);\n",
785 | " return;\n",
786 | " }\n",
787 | "\n",
788 | " var output_index = fig.cell_info[2]\n",
789 | " var cell = fig.cell_info[0];\n",
790 | "\n",
791 | "};\n",
792 | "\n",
793 | "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
794 | " var width = fig.canvas.width/mpl.ratio\n",
795 | " fig.root.unbind('remove')\n",
796 | "\n",
797 | " // Update the output cell to use the data from the current canvas.\n",
798 | " fig.push_to_output();\n",
799 | " var dataURL = fig.canvas.toDataURL();\n",
800 | " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
801 | " // the notebook keyboard shortcuts fail.\n",
802 | " IPython.keyboard_manager.enable()\n",
803 | " $(fig.parent_element).html('
');\n",
804 | " fig.close_ws(fig, msg);\n",
805 | "}\n",
806 | "\n",
807 | "mpl.figure.prototype.close_ws = function(fig, msg){\n",
808 | " fig.send_message('closing', msg);\n",
809 | " // fig.ws.close()\n",
810 | "}\n",
811 | "\n",
812 | "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
813 | " // Turn the data on the canvas into data in the output cell.\n",
814 | " var width = this.canvas.width/mpl.ratio\n",
815 | " var dataURL = this.canvas.toDataURL();\n",
816 | " this.cell_info[1]['text/html'] = '
';\n",
817 | "}\n",
818 | "\n",
819 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
820 | " // Tell IPython that the notebook contents must change.\n",
821 | " IPython.notebook.set_dirty(true);\n",
822 | " this.send_message(\"ack\", {});\n",
823 | " var fig = this;\n",
824 | " // Wait a second, then push the new image to the DOM so\n",
825 | " // that it is saved nicely (might be nice to debounce this).\n",
826 | " setTimeout(function () { fig.push_to_output() }, 1000);\n",
827 | "}\n",
828 | "\n",
829 | "mpl.figure.prototype._init_toolbar = function() {\n",
830 | " var fig = this;\n",
831 | "\n",
832 | " var nav_element = $('')\n",
833 | " nav_element.attr('style', 'width: 100%');\n",
834 | " this.root.append(nav_element);\n",
835 | "\n",
836 | " // Define a callback function for later on.\n",
837 | " function toolbar_event(event) {\n",
838 | " return fig.toolbar_button_onclick(event['data']);\n",
839 | " }\n",
840 | " function toolbar_mouse_event(event) {\n",
841 | " return fig.toolbar_button_onmouseover(event['data']);\n",
842 | " }\n",
843 | "\n",
844 | " for(var toolbar_ind in mpl.toolbar_items){\n",
845 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
846 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
847 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
848 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
849 | "\n",
850 | " if (!name) { continue; };\n",
851 | "\n",
852 | " var button = $('');\n",
853 | " button.click(method_name, toolbar_event);\n",
854 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
855 | " nav_element.append(button);\n",
856 | " }\n",
857 | "\n",
858 | " // Add the status bar.\n",
859 | " var status_bar = $('');\n",
860 | " nav_element.append(status_bar);\n",
861 | " this.message = status_bar[0];\n",
862 | "\n",
863 | " // Add the close button to the window.\n",
864 | " var buttongrp = $('');\n",
865 | " var button = $('');\n",
866 | " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
867 | " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
868 | " buttongrp.append(button);\n",
869 | " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
870 | " titlebar.prepend(buttongrp);\n",
871 | "}\n",
872 | "\n",
873 | "mpl.figure.prototype._root_extra_style = function(el){\n",
874 | " var fig = this\n",
875 | " el.on(\"remove\", function(){\n",
876 | "\tfig.close_ws(fig, {});\n",
877 | " });\n",
878 | "}\n",
879 | "\n",
880 | "mpl.figure.prototype._canvas_extra_style = function(el){\n",
881 | "    // this is important to make the div 'focusable'\n",
882 | " el.attr('tabindex', 0)\n",
883 | "    // reach out to IPython and tell the keyboard manager to turn itself\n",
884 | " // off when our div gets focus\n",
885 | "\n",
886 | " // location in version 3\n",
887 | " if (IPython.notebook.keyboard_manager) {\n",
888 | " IPython.notebook.keyboard_manager.register_events(el);\n",
889 | " }\n",
890 | " else {\n",
891 | " // location in version 2\n",
892 | " IPython.keyboard_manager.register_events(el);\n",
893 | " }\n",
894 | "\n",
895 | "}\n",
896 | "\n",
897 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
898 | " var manager = IPython.notebook.keyboard_manager;\n",
899 | " if (!manager)\n",
900 | " manager = IPython.keyboard_manager;\n",
901 | "\n",
902 | " // Check for shift+enter\n",
903 | " if (event.shiftKey && event.which == 13) {\n",
904 | " this.canvas_div.blur();\n",
905 | " // select the cell after this one\n",
906 | " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
907 | " IPython.notebook.select(index + 1);\n",
908 | " }\n",
909 | "}\n",
910 | "\n",
911 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
912 | " fig.ondownload(fig, null);\n",
913 | "}\n",
914 | "\n",
915 | "\n",
916 | "mpl.find_output_cell = function(html_output) {\n",
917 | " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
918 | " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
919 | " // IPython event is triggered only after the cells have been serialised, which for\n",
920 | " // our purposes (turning an active figure into a static one), is too late.\n",
921 | " var cells = IPython.notebook.get_cells();\n",
922 | " var ncells = cells.length;\n",
923 | " for (var i=0; i= 3 moved mimebundle to data attribute of output\n",
930 | " data = data.data;\n",
931 | " }\n",
932 | " if (data['text/html'] == html_output) {\n",
933 | " return [cell, data, j];\n",
934 | " }\n",
935 | " }\n",
936 | " }\n",
937 | " }\n",
938 | "}\n",
939 | "\n",
940 | "// Register the function which deals with the matplotlib target/channel.\n",
941 | "// The kernel may be null if the page has been refreshed.\n",
942 | "if (IPython.notebook.kernel != null) {\n",
943 | " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
944 | "}\n"
945 | ],
946 | "text/plain": [
947 | ""
948 | ]
949 | },
950 | "metadata": {},
951 | "output_type": "display_data"
952 | },
953 | {
954 | "data": {
955 | "text/html": [
956 | "
"
957 | ],
958 | "text/plain": [
959 | ""
960 | ]
961 | },
962 | "metadata": {},
963 | "output_type": "display_data"
964 | }
965 | ],
966 | "source": [
967 | "init = tf.global_variables_initializer()\n",
968 | "\n",
969 | "with tf.Session() as sess:\n",
970 | " sess.run(init)\n",
971 | " for i in range(TRAINING_ITER):\n",
972 | " batch_xs, _ = mnist.train.next_batch(BATCH_SIZE)\n",
973 | " sess.run(train_step, feed_dict = {x: batch_xs})\n",
974 | " if i % 50 == 0:\n",
975 | " print sess.run(loss, feed_dict = {x: batch_xs})\n",
976 | " encode_decode = sess.run(x_, feed_dict={x: mnist.test.images[:10]})\n",
977 | " # Compare original images with their reconstructions\n",
978 | " f, a = plt.subplots(2, 10, figsize=(10, 2))\n",
979 | " for i in range(10):\n",
980 | " a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))\n",
981 | " a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))\n",
982 | " plt.show()"
983 | ]
984 | },
985 | {
986 | "cell_type": "code",
987 | "execution_count": null,
988 | "metadata": {
989 | "collapsed": false
990 | },
991 | "outputs": [],
992 | "source": []
993 | },
994 | {
995 | "cell_type": "code",
996 | "execution_count": null,
997 | "metadata": {
998 | "collapsed": true
999 | },
1000 | "outputs": [],
1001 | "source": []
1002 | },
1003 | {
1004 | "cell_type": "code",
1005 | "execution_count": null,
1006 | "metadata": {
1007 | "collapsed": true
1008 | },
1009 | "outputs": [],
1010 | "source": []
1011 | }
1012 | ],
1013 | "metadata": {
1014 | "kernelspec": {
1015 | "display_name": "Python 2",
1016 | "language": "python",
1017 | "name": "python2"
1018 | },
1019 | "language_info": {
1020 | "codemirror_mode": {
1021 | "name": "ipython",
1022 | "version": 2
1023 | },
1024 | "file_extension": ".py",
1025 | "mimetype": "text/x-python",
1026 | "name": "python",
1027 | "nbconvert_exporter": "python",
1028 | "pygments_lexer": "ipython2",
1029 | "version": "2.7.11"
1030 | }
1031 | },
1032 | "nbformat": 4,
1033 | "nbformat_minor": 0
1034 | }
1035 |
--------------------------------------------------------------------------------
/MNIST_BILSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from tensorflow.examples.tutorials.mnist import input_data\n",
12 | "import tensorflow as tf"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "collapsed": false
20 | },
21 | "outputs": [
22 | {
23 | "name": "stdout",
24 | "output_type": "stream",
25 | "text": [
26 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
27 | "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
28 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
29 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "metadata": {
41 | "collapsed": false
42 | },
43 | "outputs": [],
44 | "source": [
45 | "HIDDEN_UNITS = 256\n",
46 | "N_CLASSES = 10\n",
47 | "LEARNING_RATE = 1e-2\n",
48 | "TRAINING_ITER = 1000\n",
49 | "BATCH_SIZE = 100\n",
50 | "\n",
51 | "def BiRNN(x, weights, biases):\n",
52 | " x = tf.transpose(x, [1,0,2])\n",
53 | " x = tf.reshape(x, [-1, 28])\n",
54 | " x = tf.split(x, 28)\n",
55 | " \n",
56 | " lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN_UNITS)\n",
57 | " lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN_UNITS)\n",
58 | " \n",
59 | " outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)\n",
60 | " \n",
61 | " return tf.matmul(outputs[-1], weights) + biases\n",
62 | "\n",
63 | "\n",
64 | "x = tf.placeholder(tf.float32, [None, 784])\n",
65 | "y = tf.placeholder(tf.float32, [None, 10])\n",
66 | "\n",
67 | "x_2d = tf.reshape(x, [-1, 28, 28])\n",
68 | "\n",
69 | "weights = {'output': tf.Variable(tf.random_normal([2 * HIDDEN_UNITS, N_CLASSES]))}\n",
70 | "biases = {'output': tf.Variable(tf.zeros([N_CLASSES]) + 0.1)}\n",
71 | "\n",
72 | "predictions = BiRNN(x_2d, weights['output'], biases['output'])\n",
73 | "\n",
74 | "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=predictions))\n",
75 | "train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)\n",
76 | "\n",
77 | "accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y, 1), tf.arg_max(predictions, 1)), tf.float32))\n",
78 | "\n",
79 | "\n",
80 | "init = tf.global_variables_initializer()\n"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 5,
86 | "metadata": {
87 | "collapsed": false
88 | },
89 | "outputs": [
90 | {
91 | "name": "stdout",
92 | "output_type": "stream",
93 | "text": [
94 | "0.15\n",
95 | "0.73\n",
96 | "0.84\n",
97 | "0.92\n",
98 | "0.95\n",
99 | "0.91\n",
100 | "0.95\n",
101 | "0.98\n",
102 | "0.99\n",
103 | "0.97\n",
104 | "0.97\n",
105 | "0.99\n",
106 | "0.99\n",
107 | "0.98\n",
108 | "0.99\n",
109 | "0.99\n"
110 | ]
111 | },
112 | {
113 | "ename": "KeyboardInterrupt",
114 | "evalue": "",
115 | "output_type": "error",
116 | "traceback": [
117 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
118 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
119 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTRAINING_ITER\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mbatch_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_ys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmnist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mBATCH_SIZE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_ys\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m50\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mprint\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_ys\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
120 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
121 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 963\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 966\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 967\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
122 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
123 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1021\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1022\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1023\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1024\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
124 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1002\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1003\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1004\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1005\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1006\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
125 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "with tf.Session() as sess:  # run the graph defined in earlier cells\n",
131 | "    sess.run(init)  # initialize all variables\n",
132 | "    for i in range(TRAINING_ITER):\n",
133 | "        batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)  # next training mini-batch\n",
134 | "        sess.run(train_step, feed_dict = {x: batch_xs, y: batch_ys})  # one optimizer step\n",
135 | "        if i % 50 == 0:\n",
136 | "            print sess.run(accuracy, feed_dict = {x: batch_xs, y: batch_ys})  # accuracy on the current training batch\n",
137 | "    print sess.run(accuracy, feed_dict = {x: mnist.validation.images, y: mnist.validation.labels})  # final accuracy on the validation split"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {
144 | "collapsed": true
145 | },
146 | "outputs": [],
147 | "source": []
148 | }
149 | ],
150 | "metadata": {
151 | "kernelspec": {
152 | "display_name": "Python 2",
153 | "language": "python",
154 | "name": "python2"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 2
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython2",
166 | "version": "2.7.11"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 0
171 | }
172 |
--------------------------------------------------------------------------------
/MNIST_LSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "from tensorflow.examples.tutorials.mnist import input_data"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "collapsed": false
20 | },
21 | "outputs": [
22 | {
23 | "name": "stdout",
24 | "output_type": "stream",
25 | "text": [
26 | "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
27 | "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
28 | "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
29 | "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)  # load MNIST with one-hot labels (extracts archives under MNIST_data/, see log below)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "metadata": {
41 | "collapsed": false
42 | },
43 | "outputs": [],
44 | "source": [
45 | "LSTM_UNITS = 256    # hidden units in the LSTM cell\n",
46 | "BATCH_SIZE = 100    # training mini-batch size\n",
47 | "LEARNING_RATE = 1e-3\n",
48 | "N_CLASSES = 10      # digits 0-9\n",
49 | "TRAINING_ITER = 1000\n",
50 | "\n",
51 | "def LSTM(x):\n",
52 | "    \"\"\"Return the per-timestep LSTM outputs for x of shape [batch, time, features].\"\"\"\n",
53 | "    lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=LSTM_UNITS)\n",
54 | "    \n",
55 | "    # Let dynamic_rnn build a zero state sized to the runtime batch instead of\n",
56 | "    # pinning it to BATCH_SIZE; the fixed-size state made evaluation with any\n",
57 | "    # other batch size (e.g. the 5000-row validation set) fail with a concat\n",
58 | "    # shape mismatch. Passing dtype lets the state follow the fed batch.\n",
59 | "    outputs, final_state = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=x, dtype=tf.float32, time_major=False)\n",
60 | "    \n",
61 | "    return tf.unstack(tf.transpose(outputs, [1, 0, 2]))  # list over time of [batch, LSTM_UNITS]\n",
62 | "\n",
63 | "def init_weights(shape):\n",
64 | "    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))  # small random init\n",
65 | "\n",
66 | "def init_biases(shape):\n",
67 | "    return tf.Variable(tf.zeros(shape) + 0.1)  # small positive bias\n",
68 | "\n",
69 | "def add_layer(x, weights, biases, activation_function=None):\n",
70 | "    # Fully connected layer; activation_function=None yields raw logits.\n",
71 | "    results = tf.matmul(x, weights) + biases\n",
72 | "    if activation_function:\n",
73 | "        return activation_function(results)\n",
74 | "    else:\n",
75 | "        return results\n",
76 | "\n",
77 | "weights = {'output': init_weights([LSTM_UNITS, N_CLASSES])}\n",
78 | "biases = {'output': init_biases([N_CLASSES])}\n",
79 | "\n",
80 | "x = tf.placeholder(tf.float32, [None, 784])\n",
81 | "y = tf.placeholder(tf.float32, [None, N_CLASSES])\n",
82 | "\n",
83 | "x_2d = tf.reshape(x, [-1, 28, 28])  # each image as a 28-step sequence of 28 features\n",
84 | "\n",
85 | "lstm_outputs = LSTM(x_2d)\n",
86 | "\n",
87 | "predictions = add_layer(lstm_outputs[-1], weights['output'], biases['output'], None)  # classify from the last timestep\n",
88 | "\n",
89 | "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=predictions))\n",
90 | "\n",
91 | "train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)\n",
92 | "\n",
93 | "accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(predictions, 1)), tf.float32))  # tf.argmax replaces the deprecated tf.arg_max alias\n"
94 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 4,
97 | "metadata": {
98 | "collapsed": false
99 | },
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "TRAIN_ACC@0: 0.27\n",
106 | "VALID_ACC@0:"
107 | ]
108 | },
109 | {
110 | "ename": "InvalidArgumentError",
111 | "evalue": "ConcatOp : Dimensions of inputs should match: shape[0] = [5000,28] vs. shape[1] = [100,256]\n\t [[Node: rnn/while/basic_lstm_cell/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](rnn/while/TensorArrayReadV3, rnn/while/Identity_3, rnn/while/basic_lstm_cell/basic_lstm_cell/concat/axis)]]\n\nCaused by op u'rnn/while/basic_lstm_cell/basic_lstm_cell/concat', defined at:\n File \"/Users/miaofan/anaconda/lib/python2.7/runpy.py\", line 162, in _run_module_as_main\n \"__main__\", fname, loader, pkg_name)\n File \"/Users/miaofan/anaconda/lib/python2.7/runpy.py\", line 72, in _run_code\n exec code in run_globals\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/__main__.py\", line 3, in \n app.launch_new_instance()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/traitlets/config/application.py\", line 658, in launch_instance\n app.start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelapp.py\", line 405, in start\n ioloop.IOLoop.instance().start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/ioloop.py\", line 162, in start\n super(ZMQIOLoop, self).start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/ioloop.py\", line 883, in start\n handler_func(fd_obj, events)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/stack_context.py\", line 275, in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 440, in _handle_events\n self._handle_recv()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 472, in _handle_recv\n self._run_callback(callback, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 414, in _run_callback\n callback(*args, **kwargs)\n File 
\"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/stack_context.py\", line 275, in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 260, in dispatcher\n return self.dispatch_shell(stream, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 212, in dispatch_shell\n handler(stream, idents, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 370, in execute_request\n user_expressions, allow_stdin)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/ipkernel.py\", line 175, in do_execute\n shell.run_cell(code, store_history=store_history, silent=silent)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2717, in run_cell\n interactivity=interactivity, compiler=compiler, result=result)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2821, in run_ast_nodes\n if self.run_code(code, result):\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2881, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"\", line 38, in \n lstm_outputs = LSTM(x_2d)\n File \"\", line 13, in LSTM\n outputs, final_state = tf.nn.dynamic_rnn(cell=lstm_cell,inputs=x,initial_state=init_state,time_major=False)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 546, in dynamic_rnn\n dtype=dtype)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 713, in _dynamic_rnn_loop\n swap_memory=swap_memory)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2605, in while_loop\n result = context.BuildLoop(cond, body, loop_vars, shape_invariants)\n File 
\"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2438, in BuildLoop\n pred, body, original_loop_vars, loop_vars, shape_invariants)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2388, in _BuildLoop\n body_result = body(*packed_vars_for_body)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 698, in _time_step\n (output, new_state) = call_cell()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 684, in \n call_cell = lambda: cell(input_t, state)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py\", line 179, in __call__\n concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py\", line 751, in _linear\n res = math_ops.matmul(array_ops.concat(args, 1), weights)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py\", line 1034, in concat\n name=name)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py\", line 519, in _concat_v2\n name=name)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py\", line 763, in apply_op\n op_def=op_def)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py\", line 2327, in create_op\n original_op=self._default_original_op, op_def=op_def)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py\", line 1226, in __init__\n self._traceback = _extract_stack()\n\nInvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [5000,28] vs. 
shape[1] = [100,256]\n\t [[Node: rnn/while/basic_lstm_cell/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](rnn/while/TensorArrayReadV3, rnn/while/Identity_3, rnn/while/basic_lstm_cell/basic_lstm_cell/concat/axis)]]\n",
112 | "output_type": "error",
113 | "traceback": [
114 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
115 | "\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
116 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m50\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mprint\u001b[0m \u001b[0;34m'TRAIN_ACC@%d:'\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch_ys\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mprint\u001b[0m \u001b[0;34m'VALID_ACC@%d:'\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmnist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmnist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
117 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
118 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 963\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 966\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 967\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
119 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
120 | "\u001b[0;32m/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1033\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1034\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1035\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_def\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1036\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
121 | "\u001b[0;31mInvalidArgumentError\u001b[0m: ConcatOp : Dimensions of inputs should match: shape[0] = [5000,28] vs. shape[1] = [100,256]\n\t [[Node: rnn/while/basic_lstm_cell/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](rnn/while/TensorArrayReadV3, rnn/while/Identity_3, rnn/while/basic_lstm_cell/basic_lstm_cell/concat/axis)]]\n\nCaused by op u'rnn/while/basic_lstm_cell/basic_lstm_cell/concat', defined at:\n File \"/Users/miaofan/anaconda/lib/python2.7/runpy.py\", line 162, in _run_module_as_main\n \"__main__\", fname, loader, pkg_name)\n File \"/Users/miaofan/anaconda/lib/python2.7/runpy.py\", line 72, in _run_code\n exec code in run_globals\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/__main__.py\", line 3, in \n app.launch_new_instance()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/traitlets/config/application.py\", line 658, in launch_instance\n app.start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelapp.py\", line 405, in start\n ioloop.IOLoop.instance().start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/ioloop.py\", line 162, in start\n super(ZMQIOLoop, self).start()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/ioloop.py\", line 883, in start\n handler_func(fd_obj, events)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/stack_context.py\", line 275, in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 440, in _handle_events\n self._handle_recv()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 472, in _handle_recv\n self._run_callback(callback, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py\", line 414, in _run_callback\n callback(*args, **kwargs)\n File 
\"/Users/miaofan/anaconda/lib/python2.7/site-packages/tornado/stack_context.py\", line 275, in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 260, in dispatcher\n return self.dispatch_shell(stream, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 212, in dispatch_shell\n handler(stream, idents, msg)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/kernelbase.py\", line 370, in execute_request\n user_expressions, allow_stdin)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/ipykernel/ipkernel.py\", line 175, in do_execute\n shell.run_cell(code, store_history=store_history, silent=silent)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2717, in run_cell\n interactivity=interactivity, compiler=compiler, result=result)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2821, in run_ast_nodes\n if self.run_code(code, result):\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/IPython/core/interactiveshell.py\", line 2881, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"\", line 38, in \n lstm_outputs = LSTM(x_2d)\n File \"\", line 13, in LSTM\n outputs, final_state = tf.nn.dynamic_rnn(cell=lstm_cell,inputs=x,initial_state=init_state,time_major=False)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 546, in dynamic_rnn\n dtype=dtype)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 713, in _dynamic_rnn_loop\n swap_memory=swap_memory)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2605, in while_loop\n result = context.BuildLoop(cond, body, loop_vars, shape_invariants)\n File 
\"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2438, in BuildLoop\n pred, body, original_loop_vars, loop_vars, shape_invariants)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py\", line 2388, in _BuildLoop\n body_result = body(*packed_vars_for_body)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 698, in _time_step\n (output, new_state) = call_cell()\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.py\", line 684, in \n call_cell = lambda: cell(input_t, state)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py\", line 179, in __call__\n concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py\", line 751, in _linear\n res = math_ops.matmul(array_ops.concat(args, 1), weights)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py\", line 1034, in concat\n name=name)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py\", line 519, in _concat_v2\n name=name)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py\", line 763, in apply_op\n op_def=op_def)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py\", line 2327, in create_op\n original_op=self._default_original_op, op_def=op_def)\n File \"/Users/miaofan/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py\", line 1226, in __init__\n self._traceback = _extract_stack()\n\nInvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [5000,28] vs. 
shape[1] = [100,256]\n\t [[Node: rnn/while/basic_lstm_cell/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](rnn/while/TensorArrayReadV3, rnn/while/Identity_3, rnn/while/basic_lstm_cell/basic_lstm_cell/concat/axis)]]\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "init = tf.global_variables_initializer()  # op that initializes every variable defined above\n",
127 | "\n",
128 | "with tf.Session() as sess:\n",
129 | "    sess.run(init)\n",
130 | "    \n",
131 | "    for i in range(TRAINING_ITER):\n",
132 | "        batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)  # next mini-batch of flattened 28x28 images\n",
133 | "        sess.run(train_step, feed_dict= {x: batch_xs, y: batch_ys})  # one Adam step\n",
134 | "        if i % 50 == 0:\n",
135 | "            print 'TRAIN_ACC@%d:' %i, sess.run(accuracy, feed_dict = {x: batch_xs, y:batch_ys})  # accuracy on the current training batch\n",
136 | "    # VALID_ACC is disabled: the recorded run crashed here because the LSTM's\n",
137 | "    # initial_state is pinned to BATCH_SIZE=100 while validation feeds 5000 rows (see error output above).\n",
138 | "    #print 'VALID_ACC@%d:' %i, sess.run(accuracy, feed_dict = {x: mnist.validation.images, y: mnist.validation.labels})"
139 | ]
140 | }
141 | ],
142 | "metadata": {
143 | "kernelspec": {
144 | "display_name": "Python 2",
145 | "language": "python",
146 | "name": "python2"
147 | },
148 | "language_info": {
149 | "codemirror_mode": {
150 | "name": "ipython",
151 | "version": 2
152 | },
153 | "file_extension": ".py",
154 | "mimetype": "text/x-python",
155 | "name": "python",
156 | "nbconvert_exporter": "python",
157 | "pygments_lexer": "ipython2",
158 | "version": "2.7.11"
159 | }
160 | },
161 | "nbformat": 4,
162 | "nbformat_minor": 0
163 | }
164 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep_Learning_with_TensorFlow
2 | Deep learning examples implemented in Python with TensorFlow, practiced mainly on the MNIST dataset.
3 |
--------------------------------------------------------------------------------