├── .gitignore
├── Makefile
├── README.md
├── data
│   ├── t10k-images.idx3-ubyte
│   ├── t10k-labels.idx1-ubyte
│   ├── train-images.idx3-ubyte
│   └── train-labels.idx1-ubyte
├── layer.cu
├── layer.h
├── main.cu
└── mnist.h

/.gitignore:
--------------------------------------------------------------------------------
CNN
a.out
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Note: compute_20 (Fermi) is deprecated; newer toolkits may need a different -arch
all:
	nvcc -lcuda -lcublas *.cu -o CNN -arch=compute_20 -Wno-deprecated-gpu-targets

run:
	./CNN
clean:
	rm CNN
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CNN on CUDA
Implementation of a convolutional neural network in CUDA. Tested on the MNIST dataset for 50 epochs, it reached an accuracy of 97.22% with a GPU training time of about 650 seconds.

### Test Setup
All tests were performed on an Nvidia GeForce 840M GPU running CUDA 8.0.61.

### Compiling and Execution
To compile, navigate to the repository root and run `make`.
The resulting executable can be run with `./CNN`.
--------------------------------------------------------------------------------
/data/t10k-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paramhanji/CUDA-CNN/95c9f335451131c4e3b21b065168199f01371ed2/data/t10k-images.idx3-ubyte
--------------------------------------------------------------------------------
/data/t10k-labels.idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paramhanji/CUDA-CNN/95c9f335451131c4e3b21b065168199f01371ed2/data/t10k-labels.idx1-ubyte
--------------------------------------------------------------------------------
/data/train-images.idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paramhanji/CUDA-CNN/95c9f335451131c4e3b21b065168199f01371ed2/data/train-images.idx3-ubyte
--------------------------------------------------------------------------------
/data/train-labels.idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paramhanji/CUDA-CNN/95c9f335451131c4e3b21b065168199f01371ed2/data/train-labels.idx1-ubyte
--------------------------------------------------------------------------------
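Note on the CUDA sources below: the cudaMalloc/cudaMemcpy/cudaMemset calls in layer.cu are not error-checked. A minimal sketch of a checking macro that could wrap them (an illustration only; CUDA_CHECK is a hypothetical name, not part of this repository):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical helper: abort with a message if a CUDA runtime call fails.
#define CUDA_CHECK(call)                                              \
	do {                                                          \
		cudaError_t err_ = (call);                            \
		if (err_ != cudaSuccess) {                            \
			fprintf(stderr, "CUDA error %s at %s:%d\n",   \
				cudaGetErrorString(err_),             \
				__FILE__, __LINE__);                  \
			exit(EXIT_FAILURE);                           \
		}                                                     \
	} while (0)

// Example: CUDA_CHECK(cudaMalloc(&output, sizeof(float) * O));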
/layer.cu:
--------------------------------------------------------------------------------
#include "layer.h"

// Constructor
Layer::Layer(int M, int N, int O)
{
	this->M = M;
	this->N = N;
	this->O = O;

	// Host-side staging buffers (variable-length arrays, a GCC extension)
	float h_bias[N];
	float h_weight[N][M];

	output = NULL;
	preact = NULL;
	bias = NULL;
	weight = NULL;

	// Initialise weights and biases uniformly in (-0.5, 0.5]
	for (int i = 0; i < N; ++i) {
		h_bias[i] = 0.5f - float(rand()) / float(RAND_MAX);

		for (int j = 0; j < M; ++j) {
			h_weight[i][j] = 0.5f - float(rand()) / float(RAND_MAX);
		}
	}

	cudaMalloc(&output, sizeof(float) * O);
	cudaMalloc(&preact, sizeof(float) * O);

	cudaMalloc(&bias, sizeof(float) * N);

	cudaMalloc(&weight, sizeof(float) * M * N);

	cudaMalloc(&d_output, sizeof(float) * O);
	cudaMalloc(&d_preact, sizeof(float) * O);
	cudaMalloc(&d_weight, sizeof(float) * M * N);

	cudaMemcpy(bias, h_bias, sizeof(float) * N, cudaMemcpyHostToDevice);

	cudaMemcpy(weight, h_weight, sizeof(float) * M * N, cudaMemcpyHostToDevice);
}

// Destructor
Layer::~Layer()
{
	cudaFree(output);
	cudaFree(preact);

	cudaFree(bias);

	cudaFree(weight);

	cudaFree(d_output);
	cudaFree(d_preact);
	cudaFree(d_weight);
}

// Copy one row of the dataset to the GPU
void Layer::setOutput(float *data)
{
	cudaMemcpy(output, data, sizeof(float) * O, cudaMemcpyHostToDevice);
}

// Reset GPU memory between iterations
void Layer::clear()
{
	cudaMemset(output, 0x00, sizeof(float) * O);
	cudaMemset(preact, 0x00, sizeof(float) * O);
}

void Layer::bp_clear()
{
	cudaMemset(d_output, 0x00, sizeof(float) * O);
	cudaMemset(d_preact, 0x00, sizeof(float) * O);
	cudaMemset(d_weight, 0x00, sizeof(float) * M * N);
}


// Despite the name, this is the logistic sigmoid, not a hard step
__device__ float step_function(float v)
{
	return 1.0f / (1.0f + expf(-v));
}

__global__ void apply_step_function(float *input, float *output, const int N)
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	// Each thread handles one contiguous chunk of the N elements
	for (int idx = N * pos / size; idx < N * (pos+1) / size; ++idx) {
		output[idx] = step_function(input[idx]);
	}
}
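// Aside (illustration only, not part of the original source): the chunked
// partitioning above can equivalently be written as the canonical CUDA
// grid-stride loop, which gives coalesced accesses within a warp.
__global__ void apply_step_function_strided(float *input, float *output, const int N)
{
	for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
	     idx += blockDim.x * gridDim.x) {
		output[idx] = step_function(input[idx]);
	}
}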
// Error for the output layer: (one-hot target) - output
__global__ void makeError(float *err, float *output, unsigned int Y, const int N)
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	for (int idx = N * pos / size; idx < N * (pos+1) / size; ++idx) {
		err[idx] = ((Y == idx ? 1.0f : 0.0f) - output[idx]);
	}
}

// Gradient-descent step: output += dt * grad
__global__ void apply_grad(float *output, float *grad, const int N)
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	for (int idx = N * pos / size; idx < N * (pos+1) / size; ++idx) {
		output[idx] += dt * grad[idx];
	}
}

// Convolution layer c1: six 5x5 kernels, 28x28 input -> 6x24x24 pre-activations
__global__ void fp_preact_c1(float input[28][28], float preact[6][24][24], float weight[6][5][5])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 5*5*6*24*24;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 5);
		const int i2 = ((idx /= 5 ) % 5);
		const int i3 = ((idx /= 5 ) % 6);
		const int i4 = ((idx /= 6 ) % 24);
		const int i5 = ((idx /= 24 ) % 24);

		atomicAdd(&preact[i3][i4][i5], weight[i3][i1][i2] * input[i4 + i1][i5 + i2]);
	}
}

__global__ void fp_bias_c1(float preact[6][24][24], float bias[6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*24*24;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 24);
		const int i3 = ((idx /= 24 ) % 24);

		preact[i1][i2][i3] += bias[i1];
	}
}

// Subsampling layer s1: one shared learned 4x4 kernel, stride 4, 6x24x24 -> 6x6x6
__global__ void fp_preact_s1(float input[6][24][24], float preact[6][6][6], float weight[1][4][4])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 4*4*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 4);
		const int i2 = ((idx /= 4 ) % 4);
		const int i3 = ((idx /= 4 ) % 6);
		const int i4 = ((idx /= 6 ) % 6);
		const int i5 = ((idx /= 6 ) % 6);

		atomicAdd(&preact[i3][i4][i5], weight[0][i1][i2] * input[i3][i4 * 4 + i1][i5 * 4 + i2]);
	}
}

__global__ void fp_bias_s1(float preact[6][6][6], float bias[1])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);

		preact[i1][i2][i3] += bias[0];
	}
}

// Fully connected layer f: 6x6x6 inputs -> 10 class scores
__global__ void fp_preact_f(float input[6][6][6], float preact[10], float weight[10][6][6][6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 10*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 10);
		const int i2 = ((idx /= 10 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);
		const int i4 = ((idx /= 6 ) % 6);

		atomicAdd(&preact[i1], weight[i1][i2][i3][i4] * input[i2][i3][i4]);
	}
}
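// Aside (illustration only, not in the original source): each kernel above
// unflattens a linear index n into multi-dimensional coordinates by repeated
// divide-and-modulo. For fp_preact_c1 the decoding is equivalent to:
//
//   i1 = n % 5;  n /= 5;   // kernel row
//   i2 = n % 5;  n /= 5;   // kernel column
//   i3 = n % 6;  n /= 6;   // output feature map
//   i4 = n % 24; n /= 24;  // output row
//   i5 = n % 24;           // output column
//
// so n enumerates every (kernel cell, feature map, output pixel) combination once.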
__global__ void fp_bias_f(float preact[10], float bias[10])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 10;

	for (int idx = N * pos / size; idx < N * (pos+1) / size; ++idx) {
		preact[idx] += bias[idx];
	}
}

// Gradient of the fully connected weights: outer product of d_preact and input
__global__ void bp_weight_f(float d_weight[10][6][6][6], float d_preact[10], float p_output[6][6][6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 10*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 10);
		const int i2 = ((idx /= 10 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);
		const int i4 = ((idx /= 6 ) % 6);

		d_weight[i1][i2][i3][i4] = d_preact[i1] * p_output[i2][i3][i4];
	}
}

__global__ void bp_bias_f(float bias[10], float d_preact[10])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 10;

	for (int idx = N * pos / size; idx < N * (pos+1) / size; ++idx) {
		bias[idx] += dt * d_preact[idx];
	}
}

// Propagate the output-layer gradient back to the subsampling layer
__global__ void bp_output_s1(float d_output[6][6][6], float n_weight[10][6][6][6], float nd_preact[10])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 10*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 10);
		const int i2 = ((idx /= 10 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);
		const int i4 = ((idx /= 6 ) % 6);

		atomicAdd(&d_output[i2][i3][i4], n_weight[i1][i2][i3][i4] * nd_preact[i1]);
	}
}

// Chain rule through the activation: d_preact = d_output * sigmoid'(preact)
__global__ void bp_preact_s1(float d_preact[6][6][6], float d_output[6][6][6], float preact[6][6][6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);

		const float o = step_function(preact[i1][i2][i3]);

		d_preact[i1][i2][i3] = d_output[i1][i2][i3] * o * (1 - o);
	}
}
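// Aside (illustration only, not in the original source): the factor
// o * (1 - o) above is the derivative of the logistic sigmoid. For
//   s(v) = 1 / (1 + exp(-v)),
//   s'(v) = exp(-v) / (1 + exp(-v))^2 = s(v) * (1 - s(v)),
// which is why the bp_preact_* kernels recompute o = step_function(preact)
// rather than storing the derivative separately during the forward pass.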
__global__ void bp_weight_s1(float d_weight[1][4][4], float d_preact[6][6][6], float p_output[6][24][24])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 1*4*4*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 1);
		const int i2 = ((idx /= 1 ) % 4);
		const int i3 = ((idx /= 4 ) % 4);
		const int i4 = ((idx /= 4 ) % 6);
		const int i5 = ((idx /= 6 ) % 6);
		const int i6 = ((idx /= 6 ) % 6);

		atomicAdd(&d_weight[i1][i2][i3], d_preact[i4][i5][i6] * p_output[i4][i5 * 4 + i2][i6 * 4 + i3]);
	}
}

__global__ void bp_bias_s1(float bias[1], float d_preact[6][6][6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*6*6;
	const float d = pow(6.0f, 3.0f);  // normalise by the number of gradient terms

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 6);
		const int i3 = ((idx /= 6 ) % 6);

		atomicAdd(&bias[0], dt * d_preact[i1][i2][i3] / d);
	}
}

__global__ void bp_output_c1(float d_output[6][24][24], float n_weight[1][4][4], float nd_preact[6][6][6])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 1*4*4*6*6*6;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 1);
		const int i2 = ((idx /= 1 ) % 4);
		const int i3 = ((idx /= 4 ) % 4);
		const int i4 = ((idx /= 4 ) % 6);
		const int i5 = ((idx /= 6 ) % 6);
		const int i6 = ((idx /= 6 ) % 6);

		atomicAdd(&d_output[i4][i5 * 4 + i2][i6 * 4 + i3], n_weight[i1][i2][i3] * nd_preact[i4][i5][i6]);
	}
}

__global__ void bp_preact_c1(float d_preact[6][24][24], float d_output[6][24][24], float preact[6][24][24])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*24*24;

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 24);
		const int i3 = ((idx /= 24 ) % 24);

		const float o = step_function(preact[i1][i2][i3]);

		d_preact[i1][i2][i3] = d_output[i1][i2][i3] * o * (1 - o);
	}
}

__global__ void bp_weight_c1(float d_weight[6][5][5], float d_preact[6][24][24], float p_output[28][28])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*5*5*24*24;
	const float d = pow(24.0f, 2.0f);

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 5);
		const int i3 = ((idx /= 5 ) % 5);
		const int i4 = ((idx /= 5 ) % 24);
		const int i5 = ((idx /= 24 ) % 24);

		atomicAdd(&d_weight[i1][i2][i3], d_preact[i1][i4][i5] * p_output[i4 + i2][i5 + i3] / d);
	}
}

__global__ void bp_bias_c1(float bias[6], float d_preact[6][24][24])
{
	const int pos = blockIdx.x * blockDim.x + threadIdx.x;
	const int size = blockDim.x * gridDim.x;

	const int N = 6*24*24;
	const float d = pow(24.0f, 2.0f);

	for (int n = N * pos / size; n < N * (pos+1) / size; ++n) {
		int idx = n;
		const int i1 = ((idx /= 1 ) % 6);
		const int i2 = ((idx /= 6 ) % 24);
		const int i3 = ((idx /= 24 ) % 24);

		atomicAdd(&bias[i1], dt * d_preact[i1][i2][i3] / d);
	}
}
--------------------------------------------------------------------------------
/layer.h:
--------------------------------------------------------------------------------
#include <cstdlib>
#include <vector>
#include <memory>
#include <cublas_v2.h>
#include <cuda.h>

#ifndef LAYER_H
#define LAYER_H

// Learning rate and stopping threshold for training
const static float dt = 1.0E-01f;
const static float threshold = 1.0E-02f;

class Layer {
	public:
	int M, N, O;      // input size, neuron count, activation count

	float *output;    // device buffer: activations
	float *preact;    // device buffer: pre-activations

	float *bias;
	float *weight;

	float *d_output;  // device buffers: gradients
	float *d_preact;
	float *d_weight;

	Layer(int M, int N, int O);

	~Layer();

	void setOutput(float *data);
	void clear();
	void bp_clear();
};
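// Usage sketch (illustration only, not part of the original header): main.cu
// instantiates this class once per layer, e.g.
//
//   Layer l_f(6*6*6, 10, 10);  // M inputs, N neurons, O activations
//   l_f.clear();               // zero output/preact before a forward pass
//   l_f.bp_clear();            // zero gradients before a backward pass
//
// All float* members point to device memory allocated in the constructor.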
// Utility CUDA kernel functions
__device__ float step_function(float v);
__global__ void apply_step_function(float *input, float *output, const int N);
__global__ void makeError(float *err, float *output, unsigned int Y, const int N);
__global__ void apply_grad(float *output, float *grad, const int N);

// Forward propagation kernels
__global__ void fp_preact_c1(float input[28][28], float preact[6][24][24], float weight[6][5][5]);
__global__ void fp_bias_c1(float preact[6][24][24], float bias[6]);
__global__ void fp_preact_s1(float input[6][24][24], float preact[6][6][6], float weight[1][4][4]);
__global__ void fp_bias_s1(float preact[6][6][6], float bias[1]);
__global__ void fp_preact_f(float input[6][6][6], float preact[10], float weight[10][6][6][6]);
__global__ void fp_bias_f(float preact[10], float bias[10]);

// Back propagation kernels
__global__ void bp_weight_f(float d_weight[10][6][6][6], float d_preact[10], float p_output[6][6][6]);
__global__ void bp_bias_f(float bias[10], float d_preact[10]);
__global__ void bp_output_s1(float d_output[6][6][6], float n_weight[10][6][6][6], float nd_preact[10]);
__global__ void bp_preact_s1(float d_preact[6][6][6], float d_output[6][6][6], float preact[6][6][6]);
__global__ void bp_weight_s1(float d_weight[1][4][4], float d_preact[6][6][6], float p_output[6][24][24]);
__global__ void bp_bias_s1(float bias[1], float d_preact[6][6][6]);
__global__ void bp_output_c1(float d_output[6][24][24], float n_weight[1][4][4], float nd_preact[6][6][6]);
__global__ void bp_preact_c1(float d_preact[6][24][24], float d_output[6][24][24], float preact[6][24][24]);
__global__ void bp_weight_c1(float d_weight[6][5][5], float d_preact[6][24][24], float p_output[28][28]);
__global__ void bp_bias_c1(float bias[6], float d_preact[6][24][24]);

#endif /* LAYER_H */
--------------------------------------------------------------------------------
/main.cu:
--------------------------------------------------------------------------------
#define USE_MNIST_LOADER
#define MNIST_DOUBLE
#include "mnist.h"
#include "layer.h"

#include <cuda.h>
#include <cstdio>
#include <time.h>

static mnist_data *train_set, *test_set;
static unsigned int train_cnt, test_cnt;

// Define layers of CNN
static Layer l_input = Layer(0, 0, 28*28);
static Layer l_c1 = Layer(5*5, 6, 24*24*6);
static Layer l_s1 = Layer(4*4, 1, 6*6*6);
static Layer l_f = Layer(6*6*6, 10, 10);

static void learn();
static unsigned int classify(double data[28][28]);
static void test();
static double forward_pass(double data[28][28]);
static double back_pass();

static inline void loaddata()
{
	mnist_load("data/train-images.idx3-ubyte", "data/train-labels.idx1-ubyte",
		&train_set, &train_cnt);
	mnist_load("data/t10k-images.idx3-ubyte", "data/t10k-labels.idx1-ubyte",
		&test_set, &test_cnt);
}

int main(int argc, const char **argv)
{
	srand(time(NULL));

	CUresult err = cuInit(0);
	if (err != CUDA_SUCCESS) {
		fprintf(stderr, "CUDA initialisation failed with error code - %d\n", err);
		return 1;
	}

	loaddata();
	learn();
	test();

	return 0;
}
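// Aside (illustration only, not in the original source): how the layer
// dimensions above follow from the architecture.
//   conv c1:  28 - 5 + 1 = 24  -> 6 feature maps of 24x24 (5x5 kernels, no padding)
//   pool s1:  24 / 4     = 6   -> 6 maps of 6x6 (4x4 windows, stride 4)
//   fc   f:   6*6*6 inputs     -> 10 class scores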
// Forward propagation of a single row in dataset
static double forward_pass(double data[28][28])
{
	float input[28][28];

	for (int i = 0; i < 28; ++i) {
		for (int j = 0; j < 28; ++j) {
			input[i][j] = data[i][j];
		}
	}

	l_input.clear();
	l_c1.clear();
	l_s1.clear();
	l_f.clear();

	clock_t start, end;
	start = clock();

	l_input.setOutput((float *)input);

	fp_preact_c1<<<64, 64>>>((float (*)[28])l_input.output, (float (*)[24][24])l_c1.preact, (float (*)[5][5])l_c1.weight);
	fp_bias_c1<<<64, 64>>>((float (*)[24][24])l_c1.preact, l_c1.bias);
	apply_step_function<<<64, 64>>>(l_c1.preact, l_c1.output, l_c1.O);

	fp_preact_s1<<<64, 64>>>((float (*)[24][24])l_c1.output, (float (*)[6][6])l_s1.preact, (float (*)[4][4])l_s1.weight);
	fp_bias_s1<<<64, 64>>>((float (*)[6][6])l_s1.preact, l_s1.bias);
	apply_step_function<<<64, 64>>>(l_s1.preact, l_s1.output, l_s1.O);

	fp_preact_f<<<64, 64>>>((float (*)[6][6])l_s1.output, l_f.preact, (float (*)[6][6][6])l_f.weight);
	fp_bias_f<<<64, 64>>>(l_f.preact, l_f.bias);
	apply_step_function<<<64, 64>>>(l_f.preact, l_f.output, l_f.O);

	end = clock();
	return ((double) (end - start)) / CLOCKS_PER_SEC;
}

// Back propagation to update weights
static double back_pass()
{
	clock_t start, end;

	start = clock();

	bp_weight_f<<<64, 64>>>((float (*)[6][6][6])l_f.d_weight, l_f.d_preact, (float (*)[6][6])l_s1.output);
	bp_bias_f<<<64, 64>>>(l_f.bias, l_f.d_preact);

	bp_output_s1<<<64, 64>>>((float (*)[6][6])l_s1.d_output, (float (*)[6][6][6])l_f.weight, l_f.d_preact);
	bp_preact_s1<<<64, 64>>>((float (*)[6][6])l_s1.d_preact, (float (*)[6][6])l_s1.d_output, (float (*)[6][6])l_s1.preact);
	bp_weight_s1<<<64, 64>>>((float (*)[4][4])l_s1.d_weight, (float (*)[6][6])l_s1.d_preact, (float (*)[24][24])l_c1.output);
	bp_bias_s1<<<64, 64>>>(l_s1.bias, (float (*)[6][6])l_s1.d_preact);

	bp_output_c1<<<64, 64>>>((float (*)[24][24])l_c1.d_output, (float (*)[4][4])l_s1.weight, (float (*)[6][6])l_s1.d_preact);
	bp_preact_c1<<<64, 64>>>((float (*)[24][24])l_c1.d_preact, (float (*)[24][24])l_c1.d_output, (float (*)[24][24])l_c1.preact);
	bp_weight_c1<<<64, 64>>>((float (*)[5][5])l_c1.d_weight, (float (*)[24][24])l_c1.d_preact, (float (*)[28])l_input.output);
	bp_bias_c1<<<64, 64>>>(l_c1.bias, (float (*)[24][24])l_c1.d_preact);


	apply_grad<<<64, 64>>>(l_f.weight, l_f.d_weight, l_f.M * l_f.N);
	apply_grad<<<64, 64>>>(l_s1.weight, l_s1.d_weight, l_s1.M * l_s1.N);
	apply_grad<<<64, 64>>>(l_c1.weight, l_c1.d_weight, l_c1.M * l_c1.N);

	end = clock();
	return ((double) (end - start)) / CLOCKS_PER_SEC;
}

// Unfold the input layer (unused, incomplete helper; only covers a 2x2 region)
static void unfold_input(double input[28][28], double unfolded[24*24][5*5])
{
	int a = 0;
	(void)unfold_input;  // silence unused-function warnings

	for (int i = 0; i < 2; ++i)
		for (int j = 0; j < 2; ++j) {
			int b = 0;
			for (int x = i; x < i + 2; ++x)
				for (int y = j; y < j+2; ++y)
					unfolded[a][b++] = input[x][y];
			a++;
		}
}
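// Aside (illustration only, not in the original source): learn() below
// measures per-example error as the L2 norm of l_f.d_preact, computed on the
// GPU with cublasSnrm2. An equivalent host-side computation, for reference
// (assumes <math.h> for sqrtf):
static float l2_error_host(const float d_preact[10])
{
	float sum = 0.0f;
	for (int i = 0; i < 10; ++i)
		sum += d_preact[i] * d_preact[i];  // squared (target - output) per class
	return sqrtf(sum);
}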
static void learn()
{
	static cublasHandle_t blas;
	cublasCreate(&blas);

	float err;
	int iter = 50;  // a negative value trains until the error threshold is reached

	double time_taken = 0.0;

	fprintf(stdout, "Learning\n");

	while (iter < 0 || iter-- > 0) {
		err = 0.0f;

		for (int i = 0; i < train_cnt; ++i) {
			float tmp_err;

			time_taken += forward_pass(train_set[i].data);

			l_f.bp_clear();
			l_s1.bp_clear();
			l_c1.bp_clear();

			// Euclidean (L2) norm of the error on train_set[i]
			makeError<<<10, 1>>>(l_f.d_preact, l_f.output, train_set[i].label, 10);
			cublasSnrm2(blas, 10, l_f.d_preact, 1, &tmp_err);
			err += tmp_err;

			time_taken += back_pass();
		}

		err /= train_cnt;
		fprintf(stdout, "error: %e, time_on_gpu: %lf\n", err, time_taken);

		if (err < threshold) {
			fprintf(stdout, "Training complete, error less than threshold\n\n");
			break;
		}

	}

	fprintf(stdout, "\n Time - %lf\n", time_taken);
}


// Returns the predicted label (0-9) for the given image
static unsigned int classify(double data[28][28])
{
	float res[10];

	forward_pass(data);

	unsigned int max = 0;

	cudaMemcpy(res, l_f.output, sizeof(float) * 10, cudaMemcpyDeviceToHost);

	// Pick the class with the highest activation
	for (int i = 1; i < 10; ++i) {
		if (res[max] < res[i]) {
			max = i;
		}
	}

	return max;
}

// Perform forward propagation of test data
static void test()
{
	int error = 0;

	for (int i = 0; i < test_cnt; ++i) {
		if (classify(test_set[i].data) != test_set[i].label) {
			++error;
		}
	}

	fprintf(stdout, "Error Rate: %.2lf%%\n",
		double(error) / double(test_cnt) * 100.0);
}
--------------------------------------------------------------------------------
/mnist.h:
--------------------------------------------------------------------------------
#ifndef __MNIST_H__
#define __MNIST_H__

/*
 * MNIST loader by Nuri Park - https://github.com/projectgalateia/mnist
 */

#ifdef USE_MNIST_LOADER /* Define this macro to enable the loader */

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Make the mnist_load function static.
 * Define this when the header is included multiple times.
 */
#ifdef MNIST_STATIC
#define _STATIC static
#else
#define _STATIC
#endif

/*
 * Make the loader read image data as doubles.
 * Unsigned char values are divided by 255.0, so results range from 0.0 to 1.0.
 */
#ifdef MNIST_DOUBLE
#define MNIST_DATA_TYPE double
#else
#define MNIST_DATA_TYPE unsigned char
#endif

typedef struct mnist_data {
	MNIST_DATA_TYPE data[28][28]; /* 28x28 data for the image */
	unsigned int label; /* label : 0 to 9 */
} mnist_data;

/*
 * For header-only inclusion, expose only the function prototype.
 */
#ifdef MNIST_HDR_ONLY

_STATIC int mnist_load(
	const char *image_filename,
	const char *label_filename,
	mnist_data **data,
	unsigned int *count);

#else

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Load an unsigned int from raw data.
 * MSB first (big-endian).
 */
static unsigned int mnist_bin_to_int(char *v)
{
	int i;
	unsigned int ret = 0;

	for (i = 0; i < 4; ++i) {
		ret <<= 8;
		ret |= (unsigned char)v[i];
	}

	return ret;
}
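/*
 * Aside (added for illustration, not part of the original header): the idx
 * files parsed below store all integers big-endian.
 *   image file: magic 2051, image count, rows (28), cols (28),
 *               then count * 28 * 28 pixel bytes
 *   label file: magic 2049, label count, then one byte (0-9) per label
 */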
/*
 * MNIST dataset loader.
 *
 * Returns 0 on success.
 * Check the comments below for the other return codes.
 */
_STATIC int mnist_load(
	const char *image_filename,
	const char *label_filename,
	mnist_data **data,
	unsigned int *count)
{
	int return_code = 0;
	int i;
	char tmp[4];

	unsigned int image_cnt, label_cnt;
	unsigned int image_dim[2];

	FILE *ifp = fopen(image_filename, "rb");
	FILE *lfp = fopen(label_filename, "rb");

	if (!ifp || !lfp) {
		return_code = -1; /* No such files */
		goto cleanup;
	}

	fread(tmp, 1, 4, ifp);
	if (mnist_bin_to_int(tmp) != 2051) {
		return_code = -2; /* Not a valid image file */
		goto cleanup;
	}

	fread(tmp, 1, 4, lfp);
	if (mnist_bin_to_int(tmp) != 2049) {
		return_code = -3; /* Not a valid label file */
		goto cleanup;
	}

	fread(tmp, 1, 4, ifp);
	image_cnt = mnist_bin_to_int(tmp);

	fread(tmp, 1, 4, lfp);
	label_cnt = mnist_bin_to_int(tmp);

	if (image_cnt != label_cnt) {
		return_code = -4; /* Element counts of the 2 files mismatch */
		goto cleanup;
	}

	for (i = 0; i < 2; ++i) {
		fread(tmp, 1, 4, ifp);
		image_dim[i] = mnist_bin_to_int(tmp);
	}

	if (image_dim[0] != 28 || image_dim[1] != 28) {
		return_code = -2; /* Not a valid image file */
		goto cleanup;
	}

	*count = image_cnt;
	*data = (mnist_data *)malloc(sizeof(mnist_data) * image_cnt);

	for (i = 0; i < image_cnt; ++i) {
		int j;
		unsigned char read_data[28 * 28];
		mnist_data *d = &(*data)[i];

		fread(read_data, 1, 28*28, ifp);

#ifdef MNIST_DOUBLE
		for (j = 0; j < 28*28; ++j) {
			d->data[j/28][j%28] = read_data[j] / 255.0;
		}
#else
		memcpy(d->data, read_data, 28*28);
#endif

		fread(tmp, 1, 1, lfp);
		d->label = tmp[0];
	}

cleanup:
	if (ifp) fclose(ifp);
	if (lfp) fclose(lfp);

	return return_code;
}

#endif /* MNIST_HDR_ONLY */

#ifdef __cplusplus
}
#endif

#endif /* USE_MNIST_LOADER */
#endif /* __MNIST_H__ */
--------------------------------------------------------------------------------