// ├── MatrixMultiplication.cpp
// └── README.md
//
// /MatrixMultiplication.cpp:
// --------------------------------------------------------------------------------
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <vector>

/*
 * Generate Algorithm (strided form, also used by the Winograd recursion)
 *     matA a M*K matrix, row-major, strideA floats between consecutive rows
 *     matB a K*N matrix, row-major, strideB floats between consecutive rows
 *     matC a M*N matrix, row-major, strideC floats between consecutive rows
 *     matC = matA * matB
 * For a densely stored matrix the stride equals its column count
 * (strideA = K, strideB = N, strideC = N).
 */
static void mm_generate(const float* matA, const float* matB, float* matC,
                        const int M, const int N, const int K,
                        const int strideA, const int strideB, const int strideC)
{
    for (int i = 0; i < M; i++)
    {
        const float* rowA = matA + static_cast<size_t>(i) * strideA;
        float* rowC = matC + static_cast<size_t>(i) * strideC;
        for (int j = 0; j < N; j++)
        {
            float sum = 0.0f;
            for (int k = 0; k < K; k++)
            {
                sum += rowA[k] * matB[static_cast<size_t>(k) * strideB + j];
            }
            rowC[j] = sum;
        }
    }
}

/*
 * Generate Algorithm
 *     matA a M*K matrix
 *     matB a K*N matrix
 *     matC a M*N matrix
 *     matC = matA * matB
 * All three matrices are dense row-major arrays; this simply forwards to the
 * strided kernel so that there is a single reference implementation.
 */
static void mm_generate(const float* matA, const float* matB, float* matC,
                        const int M, const int N, const int K)
{
    mm_generate(matA, matB, matC, M, N, K, K, N, N);
}

/*
 * Print a dense row-major M*N matrix to stdout, one row per line.
 */
static void showMatrix(const float* C, int M, int N)
{
    printf("\n=============================\n");
    for (int i = 0; i < M; i++)
    {
        for (int j = 0; j < N; j++)
        {
            printf("%f ", C[i * N + j]);
        }
        printf("\n");
    }
    printf("\n");
}

/*
 * Block helpers for mm_strassen.
 * Each reads rows*cols blocks out of a row-major matrix whose rows are
 * `stride` floats apart and writes a dense rows*cols result into dst.
 */

/* dst = src */
static void mm_block_copy(const float* src, const int stride, float* dst,
                          const int rows, const int cols)
{
    for (int i = 0; i < rows; i++)
    {
        const float* s = src + static_cast<size_t>(i) * stride;
        float* d = dst + static_cast<size_t>(i) * cols;
        for (int j = 0; j < cols; j++)
        {
            d[j] = s[j];
        }
    }
}

/* dst = x + y (x and y are blocks of the same parent matrix) */
static void mm_block_add(const float* x, const float* y, const int stride, float* dst,
                         const int rows, const int cols)
{
    for (int i = 0; i < rows; i++)
    {
        const float* xr = x + static_cast<size_t>(i) * stride;
        const float* yr = y + static_cast<size_t>(i) * stride;
        float* d = dst + static_cast<size_t>(i) * cols;
        for (int j = 0; j < cols; j++)
        {
            d[j] = xr[j] + yr[j];
        }
    }
}

/* dst = x - y (x and y are blocks of the same parent matrix) */
static void mm_block_sub(const float* x, const float* y, const int stride, float* dst,
                         const int rows, const int cols)
{
    for (int i = 0; i < rows; i++)
    {
        const float* xr = x + static_cast<size_t>(i) * stride;
        const float* yr = y + static_cast<size_t>(i) * stride;
        float* d = dst + static_cast<size_t>(i) * cols;
        for (int j = 0; j < cols; j++)
        {
            d[j] = xr[j] - yr[j];
        }
    }
}

/*
 * Strassen Algorithm
 *     matA a M*K matrix
 *     matB a K*N matrix
 *     matC a M*N matrix
 *     matC = matA * matB
 *     M1 = (A11+A22)*(B11+B22)
 *     M2 = (A21+A22)*B11
 *     M3 = A11*(B12-B22)
 *     M4 = A22*(B21-B11)
 *     M5 = (A11+A12)*B22
 *     M6 = (A21-A11)*(B11+B12)
 *     M7 = (A12-A22)*(B21+B22)
 *     C11 = M1+M4-M5+M7
 *     C12 = M3+M5
 *     C21 = M2+M4
 *     C22 = M1-M2+M3+M6
 * All matrices are dense row-major arrays. The recursion falls back to
 * mm_generate as soon as M <= 2 or any dimension is odd (or smaller than 2),
 * because the matrices can then no longer be split into equal quadrants.
 */
static void mm_strassen(const float* matA, const float* matB, float* matC,
                        const int M, const int N, const int K)
{
    if ((M <= 2) || N < 2 || K < 2 || M % 2 != 0 || N % 2 != 0 || K % 2 != 0)
    {
        mm_generate(matA, matB, matC, M, N, K);
        return;
    }

    const int hM = M / 2;
    const int hN = N / 2;
    const int hK = K / 2;

    // Top-left corners of the four quadrants of A (stride K) and B (stride N).
    const float* A11 = matA;
    const float* A12 = matA + hK;
    const float* A21 = matA + static_cast<size_t>(hM) * K;
    const float* A22 = A21 + hK;
    const float* B11 = matB;
    const float* B12 = matB + hN;
    const float* B21 = matB + static_cast<size_t>(hK) * N;
    const float* B22 = B21 + hN;

    const size_t prodSize = static_cast<size_t>(hM) * hN;

    // The two operand buffers are reused for all seven products; the
    // recursive call has finished with them before they are overwritten.
    std::vector<float> lhs(static_cast<size_t>(hM) * hK);
    std::vector<float> rhs(static_cast<size_t>(hK) * hN);

    //M1 = (A11+A22)*(B11+B22)
    std::vector<float> M1(prodSize);
    mm_block_add(A11, A22, K, lhs.data(), hM, hK);
    mm_block_add(B11, B22, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M1.data(), hM, hN, hK);

    //M2 = (A21+A22)*B11
    std::vector<float> M2(prodSize);
    mm_block_add(A21, A22, K, lhs.data(), hM, hK);
    mm_block_copy(B11, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M2.data(), hM, hN, hK);

    //M3 = A11*(B12-B22)
    std::vector<float> M3(prodSize);
    mm_block_copy(A11, K, lhs.data(), hM, hK);
    mm_block_sub(B12, B22, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M3.data(), hM, hN, hK);

    //M4 = A22*(B21-B11)
    std::vector<float> M4(prodSize);
    mm_block_copy(A22, K, lhs.data(), hM, hK);
    mm_block_sub(B21, B11, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M4.data(), hM, hN, hK);

    //M5 = (A11+A12)*B22
    std::vector<float> M5(prodSize);
    mm_block_add(A11, A12, K, lhs.data(), hM, hK);
    mm_block_copy(B22, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M5.data(), hM, hN, hK);

    //M6 = (A21-A11)*(B11+B12)
    std::vector<float> M6(prodSize);
    mm_block_sub(A21, A11, K, lhs.data(), hM, hK);
    mm_block_add(B11, B12, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M6.data(), hM, hN, hK);

    //M7 = (A12-A22)*(B21+B22)
    std::vector<float> M7(prodSize);
    mm_block_sub(A12, A22, K, lhs.data(), hM, hK);
    mm_block_add(B21, B22, N, rhs.data(), hK, hN);
    mm_strassen(lhs.data(), rhs.data(), M7.data(), hM, hN, hK);

    for (int i = 0; i < hM; i++)
    {
        float* top = matC + static_cast<size_t>(i) * N;
        float* bottom = matC + static_cast<size_t>(i + hM) * N;
        for (int j = 0; j < hN; j++)
        {
            const size_t idx = static_cast<size_t>(i) * hN + j;
            //C11 = M1+M4-M5+M7
            top[j] = M1[idx] + M4[idx] - M5[idx] + M7[idx];
            //C12 = M3+M5
            top[j + hN] = M3[idx] + M5[idx];
            //C21 = M2+M4
            bottom[j] = M2[idx] + M4[idx];
            //C22 = M1-M2+M3+M6
            bottom[j + hN] = M1[idx] - M2[idx] + M3[idx] + M6[idx];
        }
    }
}

/*
 * "Coppersmith-Winograd" Algorithm (strided form)
 *     matA M*K, strideA is the col num of matA, initial value is K
 *     matB K*N, strideB is the col num of matB, initial value is N
 *     matC M*N, strideC is the col num of matC, initial value is N
 *     matC = matA * matB
 *     S1 = A21 + A22     T1 = B12 - B11
 *     S2 = S1 - A11      T2 = B22 - T1
 *     S3 = A11 - A21     T3 = B22 - B12
 *     S4 = A12 - S2      T4 = T2 - B21
 *     M1 = A11 * B11     U1 = M1 + M2
 *     M2 = A12 * B21     U2 = M1 + M6
 *     M3 = S4 * B22      U3 = U2 + M7
 *     M4 = A22 * T4      U4 = U2 + M5
 *     M5 = S1 * T1       U5 = U4 + M3
 *     M6 = S2 * T2       U6 = U3 - M4
 *     M7 = S3 * T3       U7 = U3 + M5
 *     C11 = U1
 *     C12 = U5
 *     C21 = U6
 *     C22 = U7
 * NOTE: this 7-multiplication schedule is Winograd's variant of Strassen's
 * algorithm; the function name is kept so existing callers keep working.
 * Thanks to the strides, the quadrants of matA and matB are multiplied in
 * place (no copies); only the S*, T* and M* temporaries are dense buffers.
 * The recursion falls back to mm_generate as soon as M <= 2 or any dimension
 * is odd (or smaller than 2).
 */
static void mm_CoppersmithWinograd(const float* matA, const float* matB, float* matC,
                                   const int M, const int N, const int K,
                                   const int strideA, const int strideB, const int strideC)
{
    if ((M <= 2) || N < 2 || K < 2 || M % 2 != 0 || N % 2 != 0 || K % 2 != 0)
    {
        mm_generate(matA, matB, matC, M, N, K, strideA, strideB, strideC);
        return;
    }

    const int hM = M / 2;
    const int hN = N / 2;
    const int hK = K / 2;

    const size_t sSize = static_cast<size_t>(hM) * hK;
    std::vector<float> S1(sSize), S2(sSize), S3(sSize), S4(sSize);
    for (int i = 0; i < hM; i++)
    {
        const float* top = matA + static_cast<size_t>(i) * strideA;         // row i of [A11 | A12]
        const float* bottom = matA + static_cast<size_t>(i + hM) * strideA; // row i of [A21 | A22]
        for (int j = 0; j < hK; j++)
        {
            const size_t idx = static_cast<size_t>(i) * hK + j;
            //S1 = A21 + A22
            S1[idx] = bottom[j] + bottom[j + hK];
            //S2 = S1 - A11
            S2[idx] = S1[idx] - top[j];
            //S3 = A11 - A21
            S3[idx] = top[j] - bottom[j];
            //S4 = A12 - S2
            S4[idx] = top[j + hK] - S2[idx];
        }
    }

    const size_t tSize = static_cast<size_t>(hK) * hN;
    std::vector<float> T1(tSize), T2(tSize), T3(tSize), T4(tSize);
    for (int i = 0; i < hK; i++)
    {
        const float* top = matB + static_cast<size_t>(i) * strideB;         // row i of [B11 | B12]
        const float* bottom = matB + static_cast<size_t>(i + hK) * strideB; // row i of [B21 | B22]
        for (int j = 0; j < hN; j++)
        {
            const size_t idx = static_cast<size_t>(i) * hN + j;
            //T1 = B12 - B11
            T1[idx] = top[j + hN] - top[j];
            //T2 = B22 - T1
            T2[idx] = bottom[j + hN] - T1[idx];
            //T3 = B22 - B12
            T3[idx] = bottom[j + hN] - top[j + hN];
            //T4 = T2 - B21
            T4[idx] = T2[idx] - bottom[j];
        }
    }

    const size_t prodSize = static_cast<size_t>(hM) * hN;
    const float* A12 = matA + hK;
    const float* A22 = matA + static_cast<size_t>(hM) * strideA + hK;
    const float* B21 = matB + static_cast<size_t>(hK) * strideB;
    const float* B22 = B21 + hN;

    //M1 = A11 * B11
    std::vector<float> M1(prodSize);
    mm_CoppersmithWinograd(matA, matB, M1.data(), hM, hN, hK, strideA, strideB, hN);

    //M2 = A12 * B21
    std::vector<float> M2(prodSize);
    mm_CoppersmithWinograd(A12, B21, M2.data(), hM, hN, hK, strideA, strideB, hN);

    //M3 = S4 * B22
    std::vector<float> M3(prodSize);
    mm_CoppersmithWinograd(S4.data(), B22, M3.data(), hM, hN, hK, hK, strideB, hN);

    //M4 = A22 * T4
    std::vector<float> M4(prodSize);
    mm_CoppersmithWinograd(A22, T4.data(), M4.data(), hM, hN, hK, strideA, hN, hN);

    //M5 = S1 * T1
    std::vector<float> M5(prodSize);
    mm_CoppersmithWinograd(S1.data(), T1.data(), M5.data(), hM, hN, hK, hK, hN, hN);

    //M6 = S2 * T2
    std::vector<float> M6(prodSize);
    mm_CoppersmithWinograd(S2.data(), T2.data(), M6.data(), hM, hN, hK, hK, hN, hN);

    //M7 = S3 * T3
    std::vector<float> M7(prodSize);
    mm_CoppersmithWinograd(S3.data(), T3.data(), M7.data(), hM, hN, hK, hK, hN, hN);

    //C11 = U1 = M1 + M2
    //C12 = U5 = U4 + M3 = U2 + M5 + M3 = M1 + M6 + M5 + M3
    //C21 = U6 = U3 - M4 = U2 + M7 - M4 = M1 + M6 + M7 - M4
    //C22 = U7 = U3 + M5 = U2 + M7 + M5 = M1 + M6 + M7 + M5
    for (int i = 0; i < hM; i++)
    {
        float* top = matC + static_cast<size_t>(i) * strideC;
        float* bottom = matC + static_cast<size_t>(i + hM) * strideC;
        for (int j = 0; j < hN; j++)
        {
            const size_t idx = static_cast<size_t>(i) * hN + j;
            top[j] = M1[idx] + M2[idx];
            top[j + hN] = M1[idx] + M6[idx] + M5[idx] + M3[idx];
            bottom[j] = M1[idx] + M6[idx] + M7[idx] - M4[idx];
            bottom[j + hN] = M1[idx] + M6[idx] + M7[idx] + M5[idx];
        }
    }
}

/*
 * "Coppersmith-Winograd" Algorithm for dense row-major matrices
 *     matA M*K, matB K*N, matC M*N, matC = matA * matB
 */
static void mm_CoppersmithWinograd(const float* matA, const float* matB, float* matC,
                                   const int M, const int N, const int K)
{
    mm_CoppersmithWinograd(matA, matB, matC, M, N, K, K, N, N);
}

/*
 * It is the test function
 * Runs 10 rounds. Each round fills an M*K and a K*N matrix with random
 * integers in [0, rangeTop), multiplies them with all three algorithms,
 * prints the CPU time of each, and compares BOTH fast results against the
 * Generate result. On the first mismatch every matrix is dumped and the
 * test stops.
 * The comparison is exact: with small integer inputs every intermediate
 * value is an integer that float represents exactly, so all three
 * algorithms must agree bit for bit. Large rangeTop/K values can exceed
 * that range and produce rounding differences.
 */
static void mm_test(int M, int N, int K, int rangeTop)
{
    if (M < 0 || N < 0 || K < 0 || rangeTop <= 0)
    {
        // rand() % rangeTop would divide by zero, negative sizes are meaningless.
        fprintf(stderr, "mm_test: invalid arguments\n");
        return;
    }

    srand(static_cast<unsigned>(time(0)));

    const size_t sizeA = static_cast<size_t>(M) * K;
    const size_t sizeB = static_cast<size_t>(K) * N;
    const size_t sizeC = static_cast<size_t>(M) * N;

    for (int i = 0; i < 10; i++)
    {
        // Vectors release their storage on every exit path of the round.
        std::vector<float> mA(sizeA);
        std::vector<float> mB(sizeB);
        std::vector<float> mC(sizeC); // Strassen result
        std::vector<float> mD(sizeC); // Generate result (reference)
        std::vector<float> mE(sizeC); // Winograd result

        for (size_t j = 0; j < sizeA; j++)
        {
            mA[j] = static_cast<float>(rand() % rangeTop);
        }
        for (size_t j = 0; j < sizeB; j++)
        {
            mB[j] = static_cast<float>(rand() % rangeTop);
        }

        clock_t start = clock();
        mm_strassen(mA.data(), mB.data(), mC.data(), M, N, K);
        clock_t end = clock();
        double endtime = static_cast<double>(end - start) / CLOCKS_PER_SEC;
        printf("Strassen%d time: %fms\n", i, endtime * 1000);

        start = clock();
        mm_generate(mA.data(), mB.data(), mD.data(), M, N, K);
        end = clock();
        endtime = static_cast<double>(end - start) / CLOCKS_PER_SEC;
        printf("Generate%d time: %fms\n", i, endtime * 1000);

        start = clock();
        mm_CoppersmithWinograd(mA.data(), mB.data(), mE.data(), M, N, K);
        end = clock();
        endtime = static_cast<double>(end - start) / CLOCKS_PER_SEC;
        printf("Winograd%d time: %fms\n", i, endtime * 1000);

        for (size_t j = 0; j < sizeC; j++)
        {
            // Check Strassen AND Winograd against the reference result.
            if (mC[j] != mD[j] || mE[j] != mD[j])
            {
                printf("========A========\n");
                showMatrix(mA.data(), M, K);
                printf("========B========\n");
                showMatrix(mB.data(), K, N);
                printf("========Strassen========\n");
                showMatrix(mC.data(), M, N);
                printf("========Generate========\n");
                showMatrix(mD.data(), M, N);
                printf("========Winograd========\n");
                showMatrix(mE.data(), M, N);
                return;
            }
        }
        printf("\n");
    }
}

/*
 * Benchmark driver. Define MM_NO_MAIN to compile only the algorithms,
 * e.g. when this file is built into a test harness that has its own main.
 */
#ifndef MM_NO_MAIN
int main()
{
    const int M = 1000;
    const int N = 1000;
    const int K = 2000;
    const int rangeTop = 10;
    mm_test(M, N, K, rangeTop);
    return 0;
}
#endif

// --------------------------------------------------------------------------------
// /README.md:
// --------------------------------------------------------------------------------
// # Matrix-Multiplication
// Three Matrix-Multiplication-Algorithms: Generate Algorithm, Strassen Algorithm and Coppersmith-Winograd Algorithm.
// --------------------------------------------------------------------------------