├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── ann.c ├── cnn.lua ├── data └── placeholder.txt ├── mnist.c ├── network.lua ├── testconv.lua └── testmnist.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *-ubyte 2 | *.dll 3 | *.pgm 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 codingnow.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | LUA_INC=-I /usr/local/include 2 | LUA_LIB=-L /usr/local/bin -llua54 3 | CFLAGS=-Wall -O2 4 | SHARED=--shared 5 | SO=dll 6 | 7 | all : mnist.$(SO) ann.$(SO) 8 | 9 | mnist.$(SO) : mnist.c 10 | gcc -o $@ $(SHARED) $(CFLAGS) $^ $(LUA_INC) $(LUA_LIB) 11 | 12 | ann.$(SO) : ann.c 13 | gcc -o $@ $(SHARED) $(CFLAGS) $^ $(LUA_INC) $(LUA_LIB) 14 | 15 | clean : 16 | rm -f *.$(SO) 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Toy neural network 2 | 3 | It's a C/Lua implementation of a feedforward neural network in the book "[Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com/)". 4 | 5 | 6 | 1. Download MNIST data from http://yann.lecun.com/exdb/mnist/ , and put them into data/ 7 | 2. Build lua modules mnist and ann with lua 5.4 8 | 3. run `lua network.lua` 9 | -------------------------------------------------------------------------------- /ann.c: -------------------------------------------------------------------------------- 1 | #define LUA_LIB 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | struct signal { 12 | int n; 13 | float data[1]; 14 | }; 15 | 16 | static inline struct signal * 17 | check_signal(lua_State *L, int index) { 18 | return (struct signal *)luaL_checkudata(L, index, "ANN_SIGNAL"); 19 | } 20 | 21 | static int 22 | lsignal_toarray(lua_State *L) { 23 | struct signal * s = check_signal(L, 1); 24 | lua_createtable(L, s->n, 0); 25 | int i; 26 | for (i=0;in;i++) { 27 | lua_pushnumber(L, s->data[i]); 28 | lua_rawseti(L, -2, i+1); 29 | } 30 | return 1; 31 | } 32 | 33 | static int 34 | lsignal_size(lua_State *L) { 35 | struct signal * s = check_signal(L, 1); 36 | lua_pushinteger(L, s->n); 37 | return 1; 38 | } 39 | 40 | static void 41 | init_signal_with_string(lua_State *L, struct signal *s, int index) { 42 | size_t sz; 43 | const uint8_t * image = (const uint8_t *)luaL_checklstring(L, index, &sz); 44 | if (sz != s->n) 45 | luaL_error(L, "Invalid image size %d != %d", (int)sz, s->n); 46 | int i; 47 | for (i=0;in;i++) { 48 | s->data[i] = image[i] / 255.0f; 49 | } 50 | } 51 | 52 | static void 53 | init_signal_with_table(lua_State *L, struct signal *s, int index) { 54 | int i; 55 | for (i=0;in;i++) { 56 | if (lua_geti(L, index, i+1) != LUA_TNUMBER) 57 | luaL_error(L, "Invalid signal init %d", i+1); 58 | s->data[i] = lua_tonumber(L, -1); 59 | lua_pop(L, 1); 60 | } 61 | if (lua_geti(L, index, i+1) != LUA_TNIL) 62 | luaL_error(L, "Invalid signal init table (too long)"); 63 | lua_pop(L, 1); 64 | } 65 | 66 | static void 67 | init_signal_n(lua_State *L, struct signal *s, int n) { 68 | if (n < 0 || n >= s->n) 69 | luaL_error(L, "Invalid n (%d)", n); 70 | memset(s->data, 0, sizeof(s->data[0]) * s->n); 71 | s->data[n] = 1.0f; 72 | } 73 | 74 | static int 75 | lsignal_init(lua_State *L) { 76 | struct signal * s = check_signal(L, 1); 77 | switch (lua_type(L, 2)) { 78 | case LUA_TSTRING: 79 | init_signal_with_string(L, s, 2); 80 | break; 81 | case LUA_TNUMBER: 82 | init_signal_n(L, s, luaL_checkinteger(L, 2)); 83 | break; 84 | case LUA_TTABLE: 85 | init_signal_with_table(L, s, 2); 86 | break; 87 | case LUA_TNIL: 88 | case LUA_TNONE: 89 | memset(s->data, 0, sizeof(s->data[0]) * s->n); 90 | break; 91 | default: 92 | return luaL_argerror(L, 2, "Invalid signal init arg"); 93 | } 94 | lua_settop(L, 1); 95 | return 1; 96 | } 97 | 98 | static int 99 | lsignal_max(lua_State *L) { 100 | struct signal * s = check_signal(L, 1); 101 | float m = s->data[0]; 102 | float sum = m; 103 | int idx = 0; 104 | int i; 105 | for (i=1;in;i++) { 106 | if (s->data[i] > m) { 107 | m = s->data[i]; 108 | idx = i; 109 | } 110 | sum += s->data[i]; 111 | } 112 | lua_pushinteger(L, idx); 113 | lua_pushnumber(L, m / sum); 114 | return 2; 115 | } 116 | 117 | static int 118 | lsignal_accumulate(lua_State *L) { 119 | struct signal * s = check_signal(L, 1); 120 | struct signal * delta = check_signal(L, 2); 121 | if (s->n != delta->n) 122 | return luaL_error(L, "signal size %d != %d", s->n, delta->n); 123 | int i; 124 | if (lua_type(L, 3) == LUA_TNUMBER) { 125 | float eta = lua_tonumber(L, 3); 126 | for (i=0;in;i++) { 127 | s->data[i] += delta->data[i] * eta; 128 | } 129 | } else { 130 | for (i=0;in;i++) { 131 | s->data[i] += delta->data[i]; 132 | } 133 | } 134 | lua_settop(L, 1); 135 | return 1; 136 | } 137 | 138 | static inline float 139 | sigmoid(float z) { 140 | return 1.0f / (1.0f + expf(-z)); 141 | } 142 | 143 | static int 144 | lsignal_sigmoid(lua_State *L) { 145 | struct signal * s = check_signal(L, 1); 146 | int i; 147 | for (i=0;in;i++) { 148 | s->data[i] = sigmoid(s->data[i]); 149 | } 150 | lua_settop(L, 1); 151 | return 1; 152 | } 153 | 154 | static int 155 | lsignal_relu(lua_State *L) { 156 | struct signal * s = check_signal(L, 1); 157 | int i; 158 | for (i=0;in;i++) { 159 | if (s->data[i] < 0) 160 | s->data[i] = 0; 161 | } 162 | lua_settop(L, 1); 163 | return 1; 164 | } 165 | 166 | static inline float 167 | sigmoid_prime(float s) { 168 | return s * (1-s); 169 | } 170 | 171 | static void 172 | addfloat(lua_State *L, luaL_Buffer *b, float f) { 173 | char tmp[16]; 174 | int n = snprintf(tmp+1, sizeof(tmp)-1, "%.5g", f); 175 | int k; 176 | for (k=n+1;kn;i++) { 197 | addfloat(L, &b, s->data[i]); 198 | } 199 | luaL_addchar(&b, ']'); 200 | luaL_pushresult(&b); 201 | return 1; 202 | } 203 | 204 | static int 205 | lsignal_image(lua_State *L) { 206 | struct signal * s = check_signal(L, 1); 207 | luaL_Buffer b; 208 | luaL_buffinit(L, &b); 209 | int i; 210 | for (i=0;in;i++) { 211 | float v = s->data[i]; 212 | int c; 213 | if (v <= 0) 214 | c = 0; 215 | else if (v >= 1.0f) 216 | c = 255; 217 | else 218 | c = v * 255; 219 | luaL_addchar(&b, c); 220 | } 221 | luaL_pushresult(&b); 222 | return 1; 223 | } 224 | 225 | static inline void 226 | gaussrand(float r[2], float deviation) { 227 | float V1, V2, S; 228 | do { 229 | float U1 = (double)rand() / RAND_MAX; 230 | float U2 = (double)rand() / RAND_MAX; 231 | 232 | V1 = 2 * U1 - 1; 233 | V2 = 2 * U2 - 1; 234 | S = V1 * V1 + V2 * V2; 235 | } while (S >= 1 || S == 0); 236 | 237 | float X = sqrtf(-2 * logf(S) / S) * deviation; 238 | r[0] = V1 * X; 239 | r[1] = V2 * X; 240 | } 241 | 242 | static void 243 | randn(float *f, int n, float deviation) { 244 | int i; 245 | for (i=0;idata, s->n, deviation); 260 | lua_settop(L, 1); 261 | return 1; 262 | } 263 | 264 | static int 265 | lsignal(lua_State *L) { 266 | int n = luaL_checkinteger(L, 1); 267 | size_t sz = sizeof(struct signal) + sizeof(float) * (n-1); 268 | struct signal * s = (struct signal *)lua_newuserdatauv(L, sz, 0); 269 | memset(s->data, 0, sizeof(s->data[0]) * n); 270 | s->n = n; 271 | if (luaL_newmetatable(L, "ANN_SIGNAL")) { 272 | lua_pushvalue(L, -1); 273 | lua_setfield(L, -2, "__index"); 274 | luaL_Reg l[] = { 275 | { "toarray", lsignal_toarray }, 276 | { "image", lsignal_image }, 277 | { "init", lsignal_init }, 278 | { "randn", lsignal_randn }, 279 | { "max", lsignal_max }, 280 | { "size", lsignal_size }, 281 | { "accumulate", lsignal_accumulate }, 282 | { "sigmoid", lsignal_sigmoid }, 283 | { "relu", lsignal_relu }, 284 | { "__tostring", lsignal_dump }, 285 | { NULL, NULL }, 286 | }; 287 | luaL_setfuncs(L, l, 0); 288 | } 289 | lua_setmetatable(L, -2); 290 | 291 | return 1; 292 | } 293 | 294 | struct weight { 295 | int w; 296 | int h; 297 | float data[1]; 298 | }; 299 | 300 | static inline struct weight * 301 | check_weight(lua_State *L, int index) { 302 | return (struct weight *)luaL_checkudata(L, index, "ANN_WEIGHT"); 303 | } 304 | 305 | static int 306 | lweight_zero(lua_State *L) { 307 | struct weight *w = check_weight(L, 1); 308 | int s = w->w * w->h; 309 | memset(w->data, 0, sizeof(w->data[0]) * s); 310 | lua_settop(L, 1); 311 | return 1; 312 | } 313 | 314 | static int 315 | lweight_size(lua_State *L) { 316 | struct weight *w = check_weight(L, 1); 317 | lua_pushinteger(L, w->w); 318 | lua_pushinteger(L, w->h); 319 | return 2; 320 | } 321 | 322 | static int 323 | lweight_randn(lua_State *L) { 324 | struct weight *w = check_weight(L, 1); 325 | float deviation = luaL_optnumber(L, 2, 1.0f); 326 | randn(w->data, w->w * w->h, deviation); 327 | lua_settop(L, 1); 328 | return 1; 329 | } 330 | 331 | static int 332 | lweight_dump(lua_State *L) { 333 | struct weight *w = check_weight(L, 1); 334 | luaL_Buffer b; 335 | luaL_buffinit(L, &b); 336 | luaL_addchar(&b, '['); 337 | int i,j; 338 | const float * f = w->data; 339 | for (i=0;ih;i++) { 340 | luaL_addlstring(&b, "[ ", 2); 341 | for (j=0;jw;j++) { 342 | addfloat(L, &b, *f); 343 | ++f; 344 | } 345 | luaL_addchar(&b, ']'); 346 | if (ih-1) { 347 | luaL_addlstring(&b, "\n ", 2); 348 | } 349 | } 350 | luaL_addchar(&b, ']'); 351 | luaL_pushresult(&b); 352 | return 1; 353 | } 354 | 355 | static int 356 | lweight_accumulate(lua_State *L) { 357 | struct weight * s = check_weight(L, 1); 358 | struct weight * delta = check_weight(L, 2); 359 | if (s->w != delta->w || s->h != delta->h) 360 | return luaL_error(L, "weight size (%d, %d) != (%d, %d)", s->w, s->h, delta->w, delta->h); 361 | int i; 362 | int sz = s->w * s->h; 363 | if (lua_type(L, 3) == LUA_TNUMBER) { 364 | float eta = lua_tonumber(L, 3); 365 | for (i=0;idata[i] += delta->data[i] * eta; 367 | } 368 | } else { 369 | for (i=0;idata[i] += delta->data[i]; 371 | } 372 | } 373 | lua_settop(L, 1); 374 | return 1; 375 | } 376 | 377 | static int 378 | lweight_import(lua_State *L) { 379 | struct weight * w = check_weight(L, 1); 380 | luaL_checktype(L, 2, LUA_TTABLE); 381 | int i,j; 382 | float *data = w->data; 383 | for (i=0;ih;i++) { 384 | if (lua_geti(L, 2, i+1) != LUA_TTABLE) { 385 | return luaL_error(L, "Invalid source [%d]", i+1); 386 | } 387 | for (j=0;jw;j++) { 388 | if (lua_geti(L, -1, j+1) != LUA_TNUMBER) { 389 | return luaL_error(L, "Invalid source [%d][%d]", i+1, j+1); 390 | } 391 | *data = lua_tonumber(L, -1); 392 | lua_pop(L, 1); 393 | ++data; 394 | } 395 | if (lua_geti(L, -1, w->w+1) != LUA_TNIL) 396 | return luaL_error(L, "Invalid source [%d] (too long)", i+1); 397 | lua_pop(L, 2); 398 | } 399 | return 0; 400 | } 401 | 402 | static int 403 | lweight(lua_State *L) { 404 | int width = luaL_checkinteger(L, 1); 405 | int height = luaL_checkinteger(L, 2); 406 | int s = width * height; 407 | size_t sz = sizeof(struct weight) + sizeof(float) * (s-1); 408 | struct weight * w = (struct weight *)lua_newuserdatauv(L, sz, 0); 409 | w->w = width; 410 | w->h = height; 411 | if (luaL_newmetatable(L, "ANN_WEIGHT")) { 412 | lua_pushvalue(L, -1); 413 | lua_setfield(L, -2, "__index"); 414 | luaL_Reg l[] = { 415 | { "import", lweight_import }, 416 | { "zero", lweight_zero }, 417 | { "randn", lweight_randn }, 418 | { "size", lweight_size }, 419 | { "accumulate", lweight_accumulate }, 420 | { "__tostring", lweight_dump }, 421 | { NULL, NULL }, 422 | }; 423 | luaL_setfuncs(L, l, 0); 424 | } 425 | lua_setmetatable(L, -2); 426 | return 1; 427 | } 428 | 429 | static int 430 | lprop(lua_State *L) { 431 | struct signal * input = check_signal(L, 1); 432 | struct signal * output = check_signal(L, 2); 433 | struct weight * w = check_weight(L, 3); 434 | if (input->n != w->w || output->n != w->h) { 435 | return luaL_error(L, "Invalid weight (%d , %d) != (%d , %d)", w->w, w->h, input->n, output->n); 436 | } 437 | int i,j; 438 | const float * c = w->data; 439 | for (i=0;in;i++) { 440 | float s = 0; 441 | for (j=0;jn;j++) { 442 | s += input->data[j] * (*c); 443 | ++c; 444 | } 445 | output->data[i] = s; 446 | } 447 | return 0; 448 | } 449 | 450 | // source(w) <----w(w,h)---- delta(h) 451 | 452 | static int 453 | lbackprop_weight(lua_State *L) { 454 | struct signal * source = check_signal(L, 1); 455 | struct signal * delta = check_signal(L, 2); 456 | struct weight * w = check_weight(L, 3); 457 | if (source->n != w->w || delta->n != w->h) { 458 | return luaL_error(L, "Invalid weight (%d , %d) != (%d, %d)", w->w, w->h, source->n, delta->n); 459 | } 460 | int i,j; 461 | float * nabla = w->data; 462 | for (i=0;in;i++) { 463 | float d = delta->data[i]; 464 | for (j=0;jn;j++) { 465 | *nabla = d * source->data[j]; 466 | ++nabla; 467 | } 468 | } 469 | return 0; 470 | } 471 | 472 | 473 | // output_delta(w) <----w(w,h)----- delta(h) 474 | 475 | static int 476 | lbackprop_bias(lua_State *L) { 477 | struct signal * output = check_signal(L, 1); 478 | struct signal * delta = check_signal(L, 2); 479 | struct weight * w = check_weight(L, 3); 480 | if (output->n != w->w || delta->n != w->h) { 481 | return luaL_error(L, "Invalid weight (%d , %d) != (%d, %d)", w->w, w->h, output->n, delta->n); 482 | } 483 | int i,j; 484 | for (i=0;in;i++) { 485 | const float * weight = &w->data[i]; 486 | float d = 0; 487 | for (j=0;jn;j++) { 488 | d += delta->data[j] * (*weight); 489 | weight += w->w; 490 | } 491 | output->data[i] = d; 492 | } 493 | return 0; 494 | } 495 | 496 | static int 497 | lbackprop_sigmoid(lua_State *L) { 498 | struct signal * s = check_signal(L, 1); 499 | struct signal * input = check_signal(L, 2); 500 | if (s->n != input->n) 501 | return luaL_error(L, "Invalid signal size"); 502 | int i; 503 | for (i=0;in;i++) { 504 | input->data[i] *= sigmoid_prime(s->data[i]); 505 | } 506 | return 0; 507 | } 508 | 509 | static int 510 | lbackprop_relu(lua_State *L) { 511 | struct signal * s = check_signal(L, 1); 512 | struct signal * input = check_signal(L, 2); 513 | if (s->n != input->n) 514 | return luaL_error(L, "Invalid signal size"); 515 | int i; 516 | for (i=0;in;i++) { 517 | if (s->data[i] <= 0) 518 | input->data[i] = 0; 519 | } 520 | return 0; 521 | } 522 | 523 | static void 524 | softmax(struct signal *a, struct signal *output) { 525 | int i; 526 | float m = a->data[0]; 527 | for (i=1;in;i++) { 528 | if (a->data[i] > m) 529 | m = a->data[i]; 530 | } 531 | float sum = 0; 532 | for (i=0;in;i++) { 533 | float exp_a = expf(a->data[i] - m); 534 | output->data[i] = exp_a; 535 | sum += exp_a; 536 | } 537 | float inv_sum = 1.0f / sum; 538 | for (i=0;in;i++) { 539 | output->data[i] *= inv_sum; 540 | } 541 | } 542 | 543 | 544 | static int 545 | lsignal_softmax(lua_State *L) { 546 | struct signal * a = check_signal(L, 1); 547 | struct signal * b = check_signal(L, 2); 548 | struct signal * output = check_signal(L, 3); 549 | if (a->n != b->n || a->n != output->n) 550 | return luaL_error(L, "Invalid signal size"); 551 | softmax(a, output); 552 | int i; 553 | for (i=0;in;i++) { 554 | output->data[i] -= b->data[i]; 555 | } 556 | return 0; 557 | } 558 | 559 | // filter for convolution with stride 1. 560 | struct filter { 561 | int size; // (size * size) filter 562 | int pooling; 563 | int n; 564 | int src_w; 565 | int src_h; 566 | float f[1]; // bias[n] + weight[size * size * n] 567 | }; 568 | 569 | static inline size_t 570 | filter_size(int size, int n) { 571 | int nfloat = (size * size + 1) * n; 572 | return sizeof(struct filter) + (nfloat - 1) * sizeof(float); 573 | } 574 | 575 | static inline float * 576 | filter_weight(struct filter *f, int n) { 577 | return f->f + f->n + f->size * f->size * n; 578 | } 579 | 580 | static inline float 581 | filter_bias(struct filter *f, int n) { 582 | return f->f[n]; 583 | } 584 | 585 | static struct filter * 586 | check_filter(lua_State *L, int index) { 587 | return luaL_checkudata(L, index, "ANN_FILTER"); 588 | } 589 | 590 | static int 591 | lfilter_randn(lua_State *L) { 592 | struct filter *f = check_filter(L, 1); 593 | float deviation = luaL_optnumber(L, 2, 1.0f); 594 | int n = f->n * (1 + f->size * f->size); 595 | randn(f->f, n, deviation); 596 | lua_settop(L, 1); 597 | return 1; 598 | } 599 | 600 | static int 601 | lfilter_zero(lua_State *L) { 602 | struct filter *f = check_filter(L, 1); 603 | int n = f->n * (1 + f->size * f->size); 604 | memset(f->f, 0, n * sizeof(float)); 605 | lua_settop(L, 1); 606 | return 1; 607 | } 608 | 609 | static int 610 | lfilter_dump(lua_State *L) { 611 | struct filter * f = check_filter(L, 1); 612 | int i,j,k; 613 | luaL_Buffer b; 614 | luaL_buffinit(L, &b); 615 | 616 | for (i=0; in; i++) { 617 | addfloat(L, &b, filter_bias(f, i)); 618 | luaL_addchar(&b, '\n'); 619 | float * w = filter_weight(f, i); 620 | for (j=0;jsize;j++) { 621 | luaL_addlstring(&b, " [", 3); 622 | for (k=0;ksize;k++) { 623 | addfloat(L, &b, *w); 624 | ++w; 625 | } 626 | luaL_addlstring(&b, "]\n", 2); 627 | } 628 | } 629 | luaL_pushresult(&b); 630 | return 1; 631 | } 632 | 633 | static void 634 | set_arg(lua_State *L, const char *key, int v) { 635 | lua_pushinteger(L, v); 636 | lua_setfield(L, -2, key); 637 | } 638 | 639 | static inline void 640 | filter_output_size(struct filter *f, int *w, int *h) { 641 | *w = f->src_w - f->size + 1; 642 | *h = f->src_h - f->size + 1; 643 | } 644 | 645 | static int 646 | lfilter_args(lua_State *L) { 647 | struct filter * f = check_filter(L, 1); 648 | lua_newtable(L); 649 | set_arg(L, "size", f->size); 650 | set_arg(L, "n", f->n); 651 | set_arg(L, "w", f->src_w); 652 | set_arg(L, "h", f->src_h); 653 | set_arg(L, "pooling", f->pooling); 654 | int dw, dh; 655 | filter_output_size(f, &dw, &dh); 656 | set_arg(L, "cw", dw); 657 | set_arg(L, "ch", dh); 658 | set_arg(L, "conv_size", dw * dh * f->n); 659 | dw /= f->pooling; 660 | dh /= f->pooling; 661 | set_arg(L, "pw", dw); 662 | set_arg(L, "ph", dh); 663 | set_arg(L, "output_size", dw * dh * f->n); 664 | 665 | return 1; 666 | } 667 | 668 | static inline float 669 | conv_dot(const float *src, int stride, const float *f, int fsize) { 670 | int i,j; 671 | float s = 0; 672 | const float *line = src; 673 | for (i=0;isrc_w * f->src_h; 708 | int dw,dh; 709 | filter_output_size(f, &dw, &dh); 710 | int output_size = dw * dh; 711 | if (input_size != input->n) 712 | return luaL_error(L, "Invalid input signal size %d * %d != %d", f->src_w, f->src_h, input->n); 713 | if (output_size * f->n != output->n) 714 | return luaL_error(L, "Invalid output signal size %d * %d * %d != %d", dw, dh, f->n, output->n); 715 | 716 | int i; 717 | float *oimg = output->data; 718 | for (i=0;in;i++) { 719 | conv2dpool(input->data, f->src_w, f->src_h, oimg, f->size, filter_weight(f, i), filter_bias(f, i)); 720 | oimg += output_size; 721 | } 722 | return 0; 723 | } 724 | 725 | static inline float 726 | pooling_max(const float *src, int x, int y, int pooling, int stride) { 727 | int i,j; 728 | src += y * pooling * stride + x * pooling; 729 | float m = *src; 730 | for (i=0;i m) 734 | m = v; 735 | } 736 | src += stride; 737 | } 738 | return m; 739 | } 740 | 741 | static int 742 | lfilter_maxpooling(lua_State *L) { 743 | struct filter *f = check_filter(L, 1); 744 | struct signal *input = check_signal(L, 2); 745 | struct signal *output = check_signal(L, 3); 746 | 747 | int dw,dh; 748 | filter_output_size(f, &dw, &dh); 749 | int input_size = dw * dh; 750 | if (input_size * f->n != input->n) 751 | return luaL_error(L, "Invalid input signal size %d * %d * %d != %d", dw, dh, f->n, input->n); 752 | int pw = dw / f->pooling; 753 | int ph = dh / f->pooling; 754 | int output_size = pw * ph; 755 | if (output_size * f->n != output->n) 756 | return luaL_error(L, "Invalid output signal size %d * %d * %d != %d", pw, ph, f->n, output->n); 757 | 758 | int i,j,k; 759 | float *ptr = output->data; 760 | const float * src = input->data; 761 | int pooling_size = f->pooling; 762 | for (i=0;in;i++) { 763 | for (j=0;j maxv) { 783 | maxv = v; 784 | m = &conv[j]; 785 | } 786 | conv[j] = 0; 787 | } 788 | conv += stride; 789 | } 790 | *m = delta; 791 | } 792 | 793 | static void 794 | pooling_max_backprop(const float *delta_img, float *conv_img, int w, int h, int pooling) { 795 | int i,j; 796 | int y = h - pooling + 1; 797 | int x = w - pooling + 1; 798 | int stride = w * pooling; 799 | for (i=0;in / f->n; 834 | const float * ptr = delta->data; 835 | int i,j; 836 | for (i=0;in;i++) { 837 | float s = 0; 838 | for (j=0;jf[i] = s; 843 | } 844 | 845 | return 0; 846 | } 847 | 848 | static int 849 | lbackprop_maxpooling(lua_State *L) { 850 | struct filter *f = check_filter(L, 1); 851 | struct signal *conv = check_signal(L, 2); 852 | struct signal *delta = check_signal(L, 3); 853 | 854 | int dw,dh; 855 | filter_output_size(f, &dw, &dh); 856 | int conv_size = dw * dh; 857 | 858 | int pw = dw / f->pooling; 859 | int ph = dh / f->pooling; 860 | int output_size = pw * ph; 861 | 862 | if (conv_size * f->n != conv->n) 863 | return luaL_error(L, "Invalid input convolution size %d * %d != %d", dw, dh, f->n, conv->n); 864 | 865 | if (output_size * f->n != delta->n) 866 | return luaL_error(L, "Invalid output signal size %d * %d * %d != %d", pw, ph, f->n, delta->n); 867 | 868 | int i; 869 | const float * delta_img = delta->data; 870 | float * conv_img = conv->data; 871 | for (i=0;in;i++) { 872 | pooling_max_backprop(delta_img, conv_img, dw, dh, f->pooling); 873 | delta_img += output_size; 874 | conv_img += conv_size; 875 | } 876 | 877 | return 0; 878 | } 879 | 880 | static int 881 | lbackprop_conv_weight(lua_State *L) { 882 | struct filter *f = check_filter(L, 1); 883 | struct signal *input = check_signal(L, 2); 884 | struct signal *delta = check_signal(L, 3); 885 | 886 | int input_size = f->src_w * f->src_h; 887 | int dw,dh; 888 | filter_output_size(f, &dw, &dh); 889 | int delta_size = dw * dh; 890 | 891 | if (input_size != input->n) 892 | return luaL_error(L, "Invalid input signal size %d * %d != %d", f->src_w, f->src_h, input->n); 893 | 894 | if (delta_size * f->n != delta->n) 895 | return luaL_error(L, "Invalid input delta size %d * %d != %d", dw, dh, f->n, delta->n); 896 | 897 | const float * input_img = input->data; 898 | float * delta_img = delta->data; 899 | int i,j,k; 900 | for (i=0;in;i++) { 901 | const float * line = input_img; 902 | float * w = filter_weight(f, i); 903 | for (j=0;jsize;j++) { 904 | for (k=0;ksize;k++) { 905 | *w = calc_filter_weight(line + k, f->src_w, delta_img, dw, dh); 906 | ++w; 907 | } 908 | line += f->src_w; 909 | } 910 | delta_img += delta_size; 911 | } 912 | 913 | return 0; 914 | } 915 | 916 | static int 917 | lfilter_clone(lua_State *L) { 918 | struct filter *f = check_filter(L, 1); 919 | size_t sz = lua_rawlen(L, 1); 920 | void * c = lua_newuserdatauv(L, sz, 0); 921 | memcpy(c, f, sz); 922 | lua_getmetatable(L, 1); 923 | lua_setmetatable(L, -2); 924 | 925 | return 1; 926 | } 927 | 928 | static int 929 | lfilter_accumulate(lua_State *L) { 930 | struct filter * f = check_filter(L, 1); 931 | struct filter * delta = check_filter(L, 2); 932 | if (f->size != delta->size || f->n != delta->n) 933 | return luaL_error(L, "filter size (%d , %d) != (%d , %d)", f->size, f->n, delta->size, delta->n); 934 | int i; 935 | int nfloat = (f->size * f->size + 1) * f->n; 936 | if (lua_type(L, 3) == LUA_TNUMBER) { 937 | float eta = lua_tonumber(L, 3); 938 | for (i=0;if[i] += delta->f[i] * eta; 940 | } 941 | } else { 942 | for (i=0;if[i] += delta->f[i]; 944 | } 945 | } 946 | lua_settop(L, 1); 947 | return 1; 948 | } 949 | 950 | static int 951 | lfilter_export(lua_State *L) { 952 | struct filter * f = check_filter(L, 1); 953 | lua_createtable(L, f->n, 0); 954 | int i,j; 955 | int size = f->size * f->size; 956 | for (i=0;in;i++) { 957 | lua_createtable(L, size, 1); 958 | const float * w = filter_weight(f, i); 959 | for (j=0;jsize * f->size; 976 | for (i=0;in;i++) { 977 | if (lua_rawgeti(L, 2, i+1) != LUA_TTABLE) 978 | return luaL_error(L, "[%d] is not a table (%s)", i+1, lua_typename(L, lua_type(L, -1))); 979 | if (lua_getfield(L, -1, "bias") != LUA_TNUMBER) { 980 | return luaL_error(L, "[%d] missing bias", i+1); 981 | } 982 | f->f[i] = lua_tonumber(L, -1); 983 | lua_pop(L, 1); 984 | float *w = filter_weight(f, i); 985 | for (j=0;jsize = size; 1008 | f->n = n; 1009 | f->src_w = src_w; 1010 | f->src_h = src_h; 1011 | f->pooling = pooling; 1012 | 1013 | if (luaL_newmetatable(L, "ANN_FILTER")) { 1014 | lua_pushvalue(L, -1); 1015 | lua_setfield(L, -2, "__index"); 1016 | luaL_Reg l[] = { 1017 | { "clone", lfilter_clone }, 1018 | { "accumulate", lfilter_accumulate }, 1019 | { "randn", lfilter_randn }, 1020 | { "zero", lfilter_zero }, 1021 | { "__tostring", lfilter_dump }, 1022 | { "args", lfilter_args }, 1023 | { "convolution", lfilter_convolution }, 1024 | { "maxpooling", lfilter_maxpooling }, 1025 | { "export", lfilter_export }, 1026 | { "import", lfilter_import }, 1027 | { "backprop_maxpooling", lbackprop_maxpooling }, 1028 | { "backprop_conv_bias", lbackprop_conv_bias}, 1029 | { "backprop_conv_weight", lbackprop_conv_weight}, 1030 | { NULL, NULL }, 1031 | }; 1032 | luaL_setfuncs(L, l, 0); 1033 | } 1034 | lua_setmetatable(L, -2); 1035 | 1036 | return 1; 1037 | } 1038 | 1039 | LUAMOD_API int 1040 | luaopen_ann(lua_State *L) { 1041 | luaL_checkversion(L); 1042 | luaL_Reg l[] = { 1043 | { "signal" , lsignal }, 1044 | { "weight", lweight }, 1045 | { "prop", lprop }, 1046 | { "backprop_weight", lbackprop_weight }, 1047 | { "backprop_bias", lbackprop_bias }, 1048 | { "softmax_error", lsignal_softmax }, 1049 | { "backprop_sigmoid", lbackprop_sigmoid }, 1050 | { "backprop_relu", lbackprop_relu }, 1051 | { "convpool_filter", lconvpool_filter }, 1052 | { NULL, NULL }, 1053 | }; 1054 | luaL_newlib(L, l); 1055 | return 1; 1056 | } 1057 | -------------------------------------------------------------------------------- /cnn.lua: -------------------------------------------------------------------------------- 1 | local mnist = require "mnist" 2 | local ann = require "ann" 3 | 4 | local labels = mnist.labels "data/train-labels.idx1-ubyte" 5 | local images = mnist.images "data/train-images.idx3-ubyte" 6 | 7 | local network = {} ; network.__index = network 8 | 9 | function network.new(args) 10 | local filter = ann.convpool_filter(args.filter_size, 11 | args.col, 12 | args.row, 13 | args.filter_n, 14 | 2):randn(0.01) 15 | local conv_args = filter:args() 16 | local n = { 17 | filter = filter, 18 | input = ann.signal(args.col * args.row), 19 | conv = ann.signal(conv_args.conv_size), 20 | pooling = ann.signal(conv_args.output_size), 21 | hidden = ann.signal(args.hidden), 22 | output = ann.signal(args.output), 23 | weight_ih = ann.weight(conv_args.output_size, args.hidden):randn(), 24 | weight_ho = ann.weight(args.hidden, args.output):randn(), 25 | bias_hidden = ann.signal(args.hidden):randn(), 26 | bias_output = ann.signal(args.output):randn(), 27 | } 28 | 29 | return setmetatable(n, network) 30 | end 31 | 32 | function network:feedforward(image) 33 | self.input:init(image) 34 | self.filter:convolution(self.input, self.conv) 35 | self.filter:maxpooling(self.conv, self.pooling) 36 | self.pooling:relu() 37 | ann.prop(self.pooling, self.hidden, self.weight_ih) 38 | self.hidden:accumulate(self.bias_hidden):sigmoid() 39 | ann.prop(self.hidden, self.output, self.weight_ho) 40 | return self.output:accumulate(self.bias_output) 41 | end 42 | 43 | local function shffule_training_data(t) 44 | local n = #t 45 | for i = 1, n - 1 do 46 | local r = math.random(i, n) 47 | t[i] , t[r] = t[r], t[i] 48 | end 49 | end 50 | 51 | function network:train(training_data, batch_size, eta) 52 | shffule_training_data(training_data) 53 | 54 | local eta_ = - eta / batch_size 55 | local filter_delta = self.filter:clone() 56 | local filter_delta_s = self.filter:clone() 57 | local dw_ih = ann.weight(self.weight_ih:size()) 58 | local dw_ih_s = ann.weight(self.weight_ih:size()) 59 | local dw_ho = ann.weight(self.weight_ho:size()) 60 | local dw_ho_s = ann.weight(self.weight_ho:size()) 61 | local db_output_s = ann.signal(self.output:size()) 62 | local db_hidden = ann.signal(self.hidden:size()) 63 | local db_hidden_s = ann.signal(self.hidden:size()) 64 | local db_pooling = ann.signal(self.pooling:size()) 65 | 66 | local db_output 67 | 68 | local function backprop(expect) 69 | -- calc error 70 | ann.softmax_error(self.output, expect, db_output) 71 | -- backprop from output to hidden 72 | ann.backprop_weight(self.hidden, db_output, dw_ho) 73 | ann.backprop_bias(db_hidden, db_output, self.weight_ho) 74 | ann.backprop_sigmoid(self.hidden, db_hidden) 75 | 76 | -- backprop from hidden to pooling 77 | ann.backprop_weight(self.pooling, db_hidden, dw_ih) 78 | ann.backprop_bias(db_pooling, db_hidden, self.weight_ih) 79 | 80 | -- backprop convpooling 81 | ann.backprop_relu(self.pooling, db_pooling) 82 | filter_delta:backprop_conv_bias(db_pooling) 83 | filter_delta:backprop_maxpooling(self.conv, db_pooling) 84 | filter_delta:backprop_conv_weight(self.input, self.conv) 85 | end 86 | 87 | local scale = 1 / db_pooling:size() 88 | 89 | for i = 1, #training_data, batch_size do 90 | self:feedforward(training_data[i].image) 91 | db_output = db_output_s 92 | backprop(training_data[i].expect) 93 | 94 | db_output = self.output 95 | db_hidden_s, db_hidden = db_hidden, db_hidden_s 96 | dw_ih_s, dw_ih = dw_ih, dw_ih_s 97 | dw_ho_s, dw_ho = dw_ho, dw_ho_s 98 | filter_delta, filter_delta_s = filter_delta_s, filter_delta 99 | 100 | for j = 1, batch_size-1 do 101 | local image = training_data[i+j] 102 | if image then 103 | self:feedforward(image.image) 104 | backprop(training_data[i+j].expect) 105 | dw_ih_s:accumulate(dw_ih) 106 | dw_ho_s:accumulate(dw_ho) 107 | db_output_s:accumulate(db_output) 108 | db_hidden_s:accumulate(db_hidden) 109 | filter_delta_s:accumulate(filter_delta) 110 | else 111 | eta_ = - eta / j 112 | break 113 | end 114 | end 115 | 116 | self.weight_ih:accumulate(dw_ih_s, eta_) 117 | self.weight_ho:accumulate(dw_ho_s, eta_) 118 | self.bias_hidden:accumulate(db_hidden_s, eta_) 119 | self.bias_output:accumulate(db_output_s, eta_) 120 | self.filter:accumulate(filter_delta_s, eta_ * scale) 121 | end 122 | end 123 | 124 | local function gen_training_data() 125 | local result = {} 126 | for i = 0, 9 do 127 | result[i] = ann.signal(10):init(i) 128 | end 129 | local training = {} 130 | for i = 1, #images do 131 | training[i] = { 132 | image = images[i], 133 | expect = result[labels[i] ], 134 | value = labels[i], 135 | } 136 | end 137 | return training 138 | end 139 | 140 | local n = network.new { 141 | row = images.row, 142 | col = images.col, 143 | filter_size = 5, 144 | filter_n = 30, 145 | hidden = 30, 146 | output = 10, 147 | } 148 | 149 | local data = gen_training_data() 150 | 151 | local labels = mnist.labels "data/t10k-labels.idx1-ubyte" 152 | local images = mnist.images "data/t10k-images.idx3-ubyte" 153 | 154 | local function test() 155 | local s = 0 156 | for idx = 1, #labels do 157 | local r, p = n:feedforward(images[idx]):max() 158 | local label = labels[idx] 159 | if r~=label then 160 | s = s + 1 161 | end 162 | end 163 | return (s / #labels * 100) .."%" 164 | end 165 | 166 | for i = 1, 30 do 167 | n:train(data,10,3.0) 168 | print("Epoch", i, test()) 169 | end 170 | -------------------------------------------------------------------------------- /data/placeholder.txt: -------------------------------------------------------------------------------- 1 | 4 files download from http://yann.lecun.com/exdb/mnist/ 2 | 3 | t10k-images.idx3-ubyte 4 | t10k-labels.idx1-ubyte 5 | train-images.idx3-ubyte 6 | train-labels.idx1-ubyte -------------------------------------------------------------------------------- /mnist.c: -------------------------------------------------------------------------------- 1 | #define LUA_LIB 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | static int 9 | label_get(lua_State *L) { 10 | uint8_t *data = (uint8_t *)luaL_checkudata(L, 1, "MNIST_LABELS"); 11 | int n = luaL_checkinteger(L, 2); 12 | int sz = lua_rawlen(L, 1); 13 | if (n <= 0 || n > sz) { 14 | return luaL_error(L, "Out of range %d [1, %d]", n, sz); 15 | } 16 | lua_pushinteger(L, data[n-1]); 17 | return 1; 18 | } 19 | 20 | static int 21 | label_len(lua_State *L) { 22 | luaL_checkudata(L, 1, "MNIST_LABELS"); 23 | int sz = lua_rawlen(L, 1); 24 | lua_pushinteger(L, sz); 25 | return 1; 26 | } 27 | 28 | static uint32_t 29 | read_uint32(FILE *f) { 30 | uint8_t bytes[4] = {0} ; 31 | fread(bytes, 1, 4, f); 32 | return bytes[0] << 24 | bytes[1] << 16 | bytes[2] << 8 | bytes[3]; 33 | } 34 | 35 | static int 36 | read_labels(lua_State *L) { 37 | const char * filename = luaL_checkstring(L, 1); 38 | FILE *f = fopen(filename, "rb"); 39 | if (f == NULL) 40 | return luaL_error(L, "Can't open %s", filename); 41 | uint32_t magic = read_uint32(f); 42 | if (magic != 2049) 43 | return luaL_error(L, "Invalid magic number %d (Should be 2049)", magic); 44 | uint32_t number = read_uint32(f); 45 | void *data = lua_newuserdatauv(L, number, 0); 46 | if (fread(data, 1, number, f) != number) 47 | return luaL_error(L, "Invalid labels number (%d)", number); 48 | fclose(f); 49 | if (luaL_newmetatable(L, "MNIST_LABELS")) { 50 | luaL_Reg l[] = { 51 | { "__index", label_get }, 52 | { "__len", label_len }, 53 | { NULL, NULL }, 54 | }; 55 | luaL_setfuncs(L, l, 0); 56 | } 57 | lua_setmetatable(L, -2); 58 | return 1; 59 | } 60 | 61 | struct image_meta { 62 | uint32_t n; 63 | uint32_t row; 64 | uint32_t col; 65 | }; 66 | 67 | static int 68 | image_len(lua_State *L) { 69 | luaL_checkudata(L, 1, "MNIST_IMAGES"); 70 | lua_getiuservalue(L, 1, 1); 71 | struct image_meta *meta = (struct image_meta *)lua_touserdata(L, -1); 72 | lua_pushinteger(L, meta->n); 73 | return 1; 74 | } 75 | 76 | static int 77 | image_attrib(lua_State *L, struct image_meta *meta, const char *what) { 78 | if (strcmp(what, "row") == 0) { 79 | lua_pushinteger(L, meta->row); 80 | return 1; 81 | } else if (strcmp(what, "col") == 0) { 82 | lua_pushinteger(L, meta->col); 83 | return 1; 84 | } 85 | return luaL_error(L, "Can't get .%s", what); 86 | } 87 | 88 | static int 89 | image_get(lua_State *L) { 90 | luaL_checkudata(L, 1, "MNIST_IMAGES"); 91 | lua_getiuservalue(L, 1, 1); 92 | struct image_meta *meta = (struct image_meta *)lua_touserdata(L, -1); 93 | if (lua_type(L, 2) == LUA_TSTRING) { 94 | return image_attrib(L, meta, lua_tostring(L, 2)); 95 | } 96 | int idx = luaL_checkinteger(L, 2); 97 | if (idx <= 0 || idx > meta->n) { 98 | return luaL_error(L, "Out of range %d [1, %d]", idx, meta->n); 99 | } 100 | size_t stride = meta->row * meta->col; 101 | const char * image = (const char *)lua_touserdata(L, 1); 102 | image = image + stride * (idx-1); 103 | lua_pushlstring(L, image, stride); 104 | return 1; 105 | } 106 | 107 | static int 108 | read_images(lua_State *L) { 109 | struct image_meta *meta = (struct image_meta *)lua_newuserdatauv(L, sizeof(*meta), 0); 110 | const char * filename = luaL_checkstring(L, 1); 111 | FILE *f = fopen(filename, "rb"); 112 | if (f == NULL) 113 | return luaL_error(L, "Can't open %s", filename); 114 | uint32_t magic = read_uint32(f); 115 | if (magic != 2051) 116 | return luaL_error(L, "Invalid magic number %d (Should be 2051)", magic); 117 | meta->n = read_uint32(f); 118 | meta->row = read_uint32(f); 119 | meta->col = read_uint32(f); 120 | size_t sz = meta->n * meta->row * meta->col; 121 | void *data = lua_newuserdatauv(L, sz, 1); 122 | lua_pushvalue(L, -2); 123 | lua_setiuservalue(L, -2, 1); 124 | if (fread(data, 1, sz, f) != sz) 125 | return luaL_error(L, "Invalid images size %dx%dx%d", meta->n, meta->row, meta->col); 126 | fclose(f); 127 | if (luaL_newmetatable(L, "MNIST_IMAGES")) { 128 | luaL_Reg l[] = { 129 | { "__index", image_get }, 130 | { "__len", image_len }, 131 | { NULL, NULL }, 132 | }; 133 | luaL_setfuncs(L, l, 0); 134 | } 135 | lua_setmetatable(L, -2); 136 | return 1; 137 | } 138 | 139 | static int 140 | gen_pgm(lua_State *L) { 141 | size_t sz = 0; 142 | const uint8_t * image = (const uint8_t *)luaL_checklstring(L, 1, &sz); 143 | int row = luaL_checkinteger(L, 2); 144 | int col = luaL_checkinteger(L, 3); 145 | size_t stride = row * col; 146 | if (stride != sz) 147 | return luaL_error(L, "Invalid %d x %d", row, col); 148 | luaL_Buffer b; 149 | luaL_buffinit(L, &b); 150 | lua_pushfstring(L, "P5\n%d %d\n255\n", row, col); 151 | luaL_addvalue(&b); 152 | char * buffer = luaL_prepbuffsize(&b, stride); 153 | memcpy(buffer, image, stride); 154 | luaL_addsize(&b, stride); 155 | luaL_pushresult(&b); 156 | return 1; 157 | } 158 | 159 | LUAMOD_API int 160 | luaopen_mnist(lua_State *L) { 161 | luaL_checkversion(L); 162 | luaL_Reg l[] = { 163 | { "labels", read_labels }, 164 | { "images", read_images }, 165 | { "pgm", gen_pgm }, 166 | { NULL, NULL }, 167 | }; 168 | luaL_newlib(L, l); 169 | return 1; 170 | } 171 | -------------------------------------------------------------------------------- /network.lua: -------------------------------------------------------------------------------- 1 | local mnist = require "mnist" 2 | local ann = require "ann" 3 | 4 | local labels = mnist.labels "data/train-labels.idx1-ubyte" 5 | local images = mnist.images "data/train-images.idx3-ubyte" 6 | 7 | local network = {} ; network.__index = network 8 | 9 | function network.new(args) 10 | local n = { 11 | input = ann.signal(args.input), 12 | hidden = ann.signal(args.hidden), 13 | output = ann.signal(args.output), 14 | weight_ih = ann.weight(args.input, args.hidden):randn(), 15 | weight_ho = ann.weight(args.hidden, args.output):randn(), 16 | bias_hidden = ann.signal(args.hidden):randn(), 17 | bias_output = ann.signal(args.output):randn(), 18 | } 19 | 20 | return setmetatable(n, network) 21 | end 22 | 23 | function network:feedforward(image) 24 | self.input:init(image) 25 | ann.prop(self.input, self.hidden, self.weight_ih) 26 | self.hidden:accumulate(self.bias_hidden):sigmoid() 27 | ann.prop(self.hidden, self.output, self.weight_ho) 28 | return self.output:accumulate(self.bias_output) 29 | end 30 | 31 | local function shffule_training_data(t) 32 | local n = #t 33 | for i = 1, n - 1 do 34 | local r = math.random(i, n) 35 | t[i] , t[r] = t[r], t[i] 36 | end 37 | end 38 | 39 | function network:train(training_data, batch_size, eta) 40 | shffule_training_data(training_data) 41 | 42 | local eta_ = - eta / batch_size 43 | local dw_ih = ann.weight(self.weight_ih:size()) 44 | local dw_ih_s = ann.weight(self.weight_ih:size()) 45 | local dw_ho = ann.weight(self.weight_ho:size()) 46 | local dw_ho_s = ann.weight(self.weight_ho:size()) 47 | local db_output_s = ann.signal(self.output:size()) 48 | local db_hidden = ann.signal(self.hidden:size()) 49 | local db_hidden_s = ann.signal(self.hidden:size()) 50 | 51 | local db_output 52 | 53 | local function backprop(expect) 54 | -- calc error 55 | ann.softmax_error(self.output, expect, db_output) 56 | -- backprop from output to hidden 57 | ann.backprop_weight(self.hidden, db_output, dw_ho) 58 | ann.backprop_bias(db_hidden, db_output, self.weight_ho) 59 | ann.backprop_sigmoid(self.hidden, db_hidden) 60 | -- backprop from hidden to input 61 | ann.backprop_weight(self.input, db_hidden, dw_ih) 62 | end 63 | 64 | for i = 1, #training_data, batch_size do 65 | self:feedforward(training_data[i].image) 66 | db_output = db_output_s 67 | backprop(training_data[i].expect) 68 | db_output = self.output 69 | db_hidden_s, db_hidden = db_hidden, db_hidden_s 70 | dw_ih_s, dw_ih = dw_ih, dw_ih_s 71 | dw_ho_s, dw_ho = dw_ho, dw_ho_s 72 | 73 | for j = 1, batch_size-1 do 74 | local image = training_data[i+j] 75 | if image then 76 | self:feedforward(image.image) 77 | backprop(training_data[i+j].expect) 78 | dw_ih_s:accumulate(dw_ih) 79 | dw_ho_s:accumulate(dw_ho) 80 | db_output_s:accumulate(db_output) 81 | db_hidden_s:accumulate(db_hidden) 82 | else 83 | eta_ = - eta / j 84 | break 85 | end 86 | end 87 | 88 | self.weight_ih:accumulate(dw_ih_s, eta_) 89 | self.weight_ho:accumulate(dw_ho_s, eta_) 90 | self.bias_hidden:accumulate(db_hidden_s, eta_) 91 | self.bias_output:accumulate(db_output_s, eta_) 92 | end 93 | end 94 | 95 | local function gen_training_data() 96 | local result = {} 97 | for i = 0, 9 do 98 | result[i] = ann.signal(10):init(i) 99 | end 100 | local training = {} 101 | for i = 1, #images do 102 | training[i] = { 103 | image = images[i], 104 | expect = result[labels[i] ], 105 | value = labels[i], 106 | } 107 | end 108 | return training 109 | end 110 | 111 | local n = network.new { 112 | input = images.row * images.col, 113 | hidden = 30, 114 | output = 10, 115 | } 116 | 117 | local data = gen_training_data() 118 | 119 | local labels = mnist.labels "data/t10k-labels.idx1-ubyte" 120 | local images = mnist.images "data/t10k-images.idx3-ubyte" 121 | 122 | local function test() 123 | local s = 0 124 | for idx = 1, #labels do 125 | local r, p = n:feedforward(images[idx]):max() 126 | local label = labels[idx] 127 | if r~=label then 128 | s = s + 1 129 | end 130 | end 131 | return (s / #labels * 100) .."%" 132 | end 133 | 134 | for i = 1, 30 do 135 | n:train(data,20,3.0) 136 | print("Epoch", i, test()) 137 | end 138 | -------------------------------------------------------------------------------- /testconv.lua: -------------------------------------------------------------------------------- 1 | local mnist = require "mnist" 2 | local ann = require "ann" 3 | local images = mnist.images "data/train-images.idx3-ubyte" 4 | 5 | local x, y = images.col, images.row 6 | 7 | local image = ann.signal(x * y):init(images[1]) 8 | 9 | local filter = ann.convpool_filter( 10 | 3, -- size 11 | x,y, 12 | 2, -- filter number 13 | 2) -- pooling 2x2 14 | 15 | filter:import { 16 | { bias = 0, 17 | 0, -1, 0, 18 | 0, 2, 0, 19 | 0, -1, 0 } , 20 | { bias = 0.1, 21 | 0, -0.5, 0, 22 | 0, 1, 0, 23 | 0, -0.5, 0 } 24 | } 25 | 26 | local args = filter:args() 27 | local conv = ann.signal(args.conv_size) 28 | local result = ann.signal(args.output_size) 29 | local expect = ann.signal(args.output_size) 30 | 31 | filter:convolution(image, conv) 32 | filter:maxpooling(conv, expect) 33 | 34 | --[[ 35 | local f = assert(io.open("conv.pgm", "wb")) 36 | f:write(mnist.pgm(expect:image(), args.pw, args.ph * 2)) 37 | f:close() 38 | ]] 39 | 40 | filter:randn() 41 | 42 | local delta = filter:clone() 43 | local scale = - 0.5 / args.output_size 44 | 45 | for i = 1, 20000 do 46 | filter:convolution(image, conv) 47 | filter:maxpooling(conv, result) 48 | result:accumulate(expect, -1) -- error 49 | 50 | delta:backprop_conv_bias(result) 51 | delta:backprop_maxpooling(conv, result) 52 | delta:backprop_conv_weight(image, conv) 53 | 54 | filter:accumulate(delta, scale) 55 | end 56 | 57 | print(filter) 58 | 59 | filter:convolution(image, conv) 60 | filter:maxpooling(conv, expect) 61 | --[[ 62 | local f = assert(io.open("conv2.pgm", "wb")) 63 | f:write(mnist.pgm(expect:image(), args.pw, args.ph * 2)) 64 | f:close() 65 | ]] -------------------------------------------------------------------------------- /testmnist.lua: -------------------------------------------------------------------------------- 1 | local mnist = require "mnist" 2 | 3 | local labels = mnist.labels "data/train-labels.idx1-ubyte" 4 | print(#labels) 5 | 6 | for i = 1, 10 do 7 | print(labels[i]) 8 | end 9 | 10 | local images = mnist.images "data/train-images.idx3-ubyte" 11 | local image = images[1] 12 | local size = images.row * images.col 13 | 14 | local f = assert(io.open("image1.pgm", "wb")) 15 | f:write(mnist.pgm(image, images.row, images.col)) 16 | f:close() 17 | 18 | --------------------------------------------------------------------------------