├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── common.mk ├── src ├── cusync.cu ├── examples │ ├── sync-overhead │ │ └── sync-overhead.cu │ └── twoMatMuls │ │ ├── Makefile │ │ └── matrixMul.cu ├── include │ ├── cusync │ │ ├── cusync.h │ │ ├── cusync_defines.h │ │ ├── cusync_device_defs.h │ │ ├── device-functions.h │ │ ├── policies.h │ │ ├── tile-orders.h │ │ └── wait-kernel.h │ └── cutlass │ │ └── cusync-cutlass │ │ └── include │ │ └── cutlass │ │ ├── conv │ │ ├── device │ │ │ └── cusyncimplicit_gemm_convolution.h │ │ ├── kernel │ │ │ ├── cusyncdefault_conv2d_fprop.h │ │ │ └── implicit_cusyncgemm_convolution.h │ │ └── threadblock │ │ │ └── implicit_cusyncgemm_pipelined.h │ │ └── gemm │ │ ├── device │ │ └── cusyncgemm.h │ │ ├── kernel │ │ ├── cusyncgemm.h │ │ └── default_cusyncgemm.h │ │ └── threadblock │ │ ├── cusync_threadblock_swizzle.h │ │ ├── cusyncmma_multistage.h │ │ ├── cusyncmma_pipelined.h │ │ └── default_cusyncmma.h └── ml-bench │ ├── README.md │ ├── common.mk │ ├── plots │ ├── Makefile │ ├── common.py │ ├── mlp-gpt3-a100.png │ ├── mlp-gpt3-v100.png │ ├── mlp-llama-a100.png │ ├── mlp-llama-v100.png │ └── plotGPT.py │ ├── transformer │ ├── Makefile │ ├── allreduce_times.py │ ├── attention.cu │ ├── common.h │ ├── eval_mlp.py │ ├── mlp-lib.cu │ ├── mlp.cu │ ├── results │ │ ├── allreduce_times-12288 │ │ ├── attention-results │ │ ├── attention-results-gpt-3-cuda-12.2 │ │ ├── attention-results-gpt3 │ │ ├── attention-results-llama │ │ ├── attention-stream-k-output │ │ ├── mlp-gpt3-a100.csv │ │ ├── mlp-gpt3-v100.csv │ │ ├── mlp-llama-a100.csv │ │ ├── mlp-llama-v100.csv │ │ ├── mlp-results-2 │ │ ├── mlp-results-gpt-3 │ │ ├── mlp-results-gpt3-cuda-12.2 │ │ ├── mlp-results-in-paper │ │ ├── mlp-results-llama │ │ └── mlp-stream-k-output │ ├── streamk.cu │ ├── tile_sizes_db.py │ └── torch-baselines │ │ ├── cublasBaseline.py │ │ ├── torchAttention.py │ │ └── torchmlp.py │ └── volta_conv2d │ ├── Makefile │ ├── eval_resnet.py │ ├── resnet.cu │ ├── resnet_results.csv │ ├── torchconv2d.py │ ├── vgg-results-cuda-12.2 │ └── vgg.cu └── tests ├── cusync-test.h └── simple-test.cu /.gitignore: -------------------------------------------------------------------------------- 1 | *.i 2 | *.ii 3 | *.gpu 4 | *.ptx 5 | *.cubin 6 | *.fatbin 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cutlass/nvidia-cutlass"] 2 | path = src/include/cutlass/nvidia-cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "googletest"] 5 | path = tests/googletest 6 | url = https://github.com/google/googletest.git 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.11-py3 2 | RUN apt-get update 3 | RUN pip3 install matplotlib 4 | RUN apt-get install cmake -y 5 | RUN mkdir /cusync 6 | COPY . /cusync/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | include common.mk 2 | 3 | GOOGLE_TEST = tests/googletest 4 | GOOGLE_TEST_BUILD = $(GOOGLE_TEST)/build 5 | TEST_INCLUDE_DIRS = -Isrc/include/ -I$(GOOGLE_TEST)/googletest/include/ -L$(GOOGLE_TEST_BUILD)/lib/ 6 | TEST_LFLAGS = -lgtest -lpthread 7 | GOOGLE_TEST_MAIN = $(GOOGLE_TEST)/googletest/src/gtest_main.cc 8 | CUSYNC_SRC_FILES = src/cusync.cu 9 | 10 | tests: run-simple-test 11 | 12 | build-googletest: $(GOOGLE_TEST) 13 | mkdir -p $(GOOGLE_TEST_BUILD) && cd $(GOOGLE_TEST_BUILD) && cmake .. && make -j 14 | 15 | simple-test: build-googletest $(shell find src/include/cusync -type f) 16 | $(NVCC) tests/$@.cu $(CUSYNC_SRC_FILES) $(TEST_INCLUDE_DIRS) $(TEST_LFLAGS) $(GOOGLE_TEST_MAIN) $(ARCH_CODE_FLAGS) -O3 -Xcompiler=-fopenmp,-O3,-Wall -o $@ -g -G 17 | 18 | run-simple-test: simple-test 19 | ./simple-test -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | CuSync 2 | --------------------- 3 | 4 | CuSync is a framework to synchronize tile-based CUDA kernels in a fine-grained manner. 5 | With CuSync, a programmer can write policies to synchronize dependent tiles, i.e. thread blocks, of a chain of producer and consumer kernels. 
6 | Synchronizing thread blocks instead of kernels allows concurrent execution of independent thread blocks, thereby improving utilization during the last thread block waves on the GPU.
7 | More details are available at https://arxiv.org/abs/2305.13450.
8 | 
9 | ## Performance
10 | 
11 | The graphs below show the percentage improvement over GPT3 and LLAMA MLPs implemented with optimized NVIDIA CUTLASS GeMMs on NVIDIA Tesla A100 and NVIDIA Tesla V100 GPUs, for 8-way model parallelism of GPT3 175B (H=12288, FP16) and LLAMA 65.2B (H=8192, FP16).
12 | NVIDIA CUTLASS StreamK is another method to optimize the utilization during the last thread block wave.
13 | In the experiments below, PyTorch performs only the GeMMs and not pointwise computations like GeLU, while the CUTLASS implementations fuse these computations with the first GeMM.
14 | 
15 | #### NVIDIA Tesla A100 SXM4 80GB with CUDA 12.2
16 | ![](https://github.com/parasailteam/cusync/blob/main/src/ml-bench/plots/mlp-gpt3-a100.png?raw=true)
17 | ![](https://github.com/parasailteam/cusync/blob/main/src/ml-bench/plots/mlp-llama-a100.png?raw=true)
18 | 
19 | #### NVIDIA Tesla V100 SXM2 32GB with CUDA 12.2
20 | ![](https://github.com/parasailteam/cusync/blob/main/src/ml-bench/plots/mlp-gpt3-v100.png?raw=true)
21 | ![](https://github.com/parasailteam/cusync/blob/main/src/ml-bench/plots/mlp-llama-v100.png?raw=true)
22 | 
23 | ## Usage
24 | 
25 | #### Clone
26 | Clone the repo and its submodules using
27 | 
28 | ```git clone --recurse-submodules https://github.com/parasailteam/cusync.git```
29 | 
30 | If the repo is already cloned without submodules, fetch the submodules using
31 | 
32 | ```git submodule update --init --recursive```
33 | 
34 | #### Example
35 | An example of synchronizing two dependent GeMMs is provided in `src/examples/`. Moreover, there are small tests in `tests/` that can be used as examples.
36 | 
37 | #### CuSync + CUTLASS
38 | 
39 | The repo also provides CUTLASS GeMM structs augmented with CuSync structures in `src/include/cutlass/cusync-cutlass/`.
40 | The MLP code in `src/ml-bench/transformer` provides a good example of using CuSync with CUTLASS.
41 | 
42 | ## Tests
43 | 
44 | Run the tests using `make tests`.
45 | 
46 | ## Evaluation
47 | 
48 | Instructions are in `src/ml-bench/README.md`.
49 | 
50 | ## Contributing
51 | 
52 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
53 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
54 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
55 | 
56 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
57 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
58 | provided by the bot. You will only need to do this once across all repos using our CLA.
59 | 
60 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
61 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
62 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
63 | 
64 | ## Trademarks
65 | 
66 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
67 | trademarks or logos is subject to and must follow
68 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
69 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 70 | Any use of third-party trademarks or logos are subject to those third-party's policies. 71 | 72 | 73 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 
4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /common.mk: -------------------------------------------------------------------------------- 1 | ARCH_CODE_FLAGS=-gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 2 | NVCC=/usr/local/cuda/bin/nvcc -------------------------------------------------------------------------------- /src/cusync.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "cusync/wait-kernel.h" 5 | 6 | /* 7 | * The wait kernel waits until the value of semaphore has reached the given value. 
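 * The host-side wrapper cusync::invokeWaitKernel (declared in wait-kernel.h)
 * launches this kernel on a stream so that kernels enqueued later on that
 * stream do not start until the semaphore reaches the given value;
 * CuStage::invokeWaitKernel in cusync.h uses it to delay a consumer stream
 * until the producer kernel has started executing.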
8 | * @semaphore: Address of the unsigned integer semaphore
9 | * @givenValue: Value the semaphore must reach
10 | */
11 | __global__
12 | void waitKernel(volatile uint32_t* semaphore, uint32_t givenValue) {
13 |   if (threadIdx.x == 0) {
14 |     uint32_t currVal = globalLoad(semaphore);
15 |     while(currVal < givenValue) {
16 |       currVal = globalVolatileLoad(semaphore);
17 |     }
18 |   }
19 | }
20 | 
21 | namespace cusync {
22 |   void invokeWaitKernel(uint32_t* semaphore, uint32_t givenValue, cudaStream_t stream) {
23 |     waitKernel<<<1,1,0,stream>>>((uint32_t*)semaphore, givenValue);
24 |   }
25 | }
26 | 
--------------------------------------------------------------------------------
/src/examples/sync-overhead/sync-overhead.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | 
3 | #include <sys/time.h>
4 | #include <stdlib.h>
5 | 
6 | 
7 | #define CUDA_CHECK(cmd) do {                        \
8 |   cudaError_t e = cmd;                              \
9 |   if( e != cudaSuccess ) {                          \
10 |     printf("Failed: Cuda error %s:%d '%s'\n",       \
11 |         __FILE__,__LINE__,cudaGetErrorString(e));   \
12 |     exit(EXIT_FAILURE);                             \
13 |   }                                                 \
14 | } while(0);
15 | 
16 | static double convertTimeValToDouble(struct timeval _time) {
17 |   return ((double)_time.tv_sec)*1e6 + ((double)_time.tv_usec);
18 | }
19 | 
20 | static struct timeval getTimeOfDay () {
21 |   struct timeval _time;
22 | 
23 |   if (gettimeofday (&_time, NULL) == -1) {
24 |     fprintf (stderr, "gettimeofday returned -1\n");
25 |     perror ("");
26 |     abort ();
27 |   }
28 | 
29 |   return _time;
30 | }
31 | 
32 | static double timeInMicroSeconds() {
33 |   return convertTimeValToDouble(getTimeOfDay());
34 | }
35 | 
36 | static double getCurrentTime() {
37 |   return timeInMicroSeconds();
38 | }
39 | 
40 | __global__ void kernel1(float *in, int i, volatile int* sync, bool cansync) {
41 |   int linearid = threadIdx.x + blockIdx.x * blockDim.x;
42 |   in[linearid] = i;
43 |   __syncthreads();
44 |   if (cansync && threadIdx.x == 0)
45 |     sync[blockIdx.x] += 1;
46 | }
47 | 
48 | __global__ void kernel2(float *out, float *in, volatile int* sync, bool cansync, int iter) {
49 |   if (cansync && threadIdx.x == 0) {
50 |     for (int i = threadIdx.x; i < 1; i += blockDim.x) {
51 |       while (sync[i] < iter + 1);
52 |     }
53 |     // sync[blockIdx.x] = 0;
54 |   }
55 |   __syncthreads();
56 |   int linearid = threadIdx.x + blockIdx.x * blockDim.x;
57 |   out[linearid] = in[linearid] + 1;
58 | }
59 | 
60 | int main() {
61 |   float* in, *out;
62 |   size_t size = 1 << 20;
63 |   CUDA_CHECK(cudaMalloc(&in, size));
64 |   CUDA_CHECK(cudaMalloc(&out, size));
65 |   int* sync;
66 |   CUDA_CHECK(cudaMalloc(&sync, size));
67 |   CUDA_CHECK(cudaMemset(sync, 0, size));
68 |   unsigned int threads = 128;
69 |   dim3 grid = {80*2 * (1024/threads), 1, 1};
70 |   dim3 block = {threads, 1, 1};
71 |   cudaStream_t prodstream, constream;
72 | 
73 |   int highestPriority;
74 |   int lowestPriority;
75 |   
76 |   CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&lowestPriority, &highestPriority));
77 | 
78 |   CUDA_CHECK(cudaStreamCreateWithPriority(&prodstream, 0, highestPriority));
79 |   CUDA_CHECK(cudaStreamCreateWithPriority(&constream, 0, lowestPriority));
80 |   CUDA_CHECK(cudaDeviceSynchronize());
81 |   double sync_exec = 0;
82 |   for (int i = 0; i < 110; i++) {
83 |     double s = getCurrentTime();
84 |     kernel1<<<grid, block, 0, prodstream>>>(in, 0, sync, true);
85 |     kernel2<<<grid, block, 0, constream>>>(out, in, sync, true, i);
86 |     CUDA_CHECK(cudaDeviceSynchronize());
87 |     double t = getCurrentTime();
88 |     if (i >= 10)
89 |       sync_exec += t - s;
90 |   }
91 | 
92 |   printf("exec with sync %lf\n", sync_exec);
93 |   CUDA_CHECK(cudaDeviceSynchronize());
94 | 
95 |   double exec = 0;
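  // Baseline: repeat the same kernel pair with synchronization disabled
  // (cansync = false) to measure the overhead added by the semaphore exchange.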
96 |   for (int i = 0; i < 100; i++) {
97 |     double s = getCurrentTime();
98 |     kernel1<<<grid, block, 0, prodstream>>>(in, 0, sync, false);
99 |     kernel2<<<grid, block, 0, constream>>>(out, in, sync, false, i);
100 |     CUDA_CHECK(cudaDeviceSynchronize());
101 |     double t = getCurrentTime();
102 |     exec += t - s;
103 |   }
104 |   printf("exec without sync %lf\n", exec);
105 |   printf("Overhead %lf %%\n", (sync_exec - exec)/exec * 100.);
106 | }
--------------------------------------------------------------------------------
/src/examples/twoMatMuls/Makefile:
--------------------------------------------------------------------------------
1 | include ../../../common.mk
2 | all: matrixMul
3 | 
4 | matrixMul: matrixMul.cu
5 | 	$(NVCC) -I../../include/ matrixMul.cu ../../cusync.cu -o matrixMul $(ARCH_CODE_FLAGS)
6 | 
--------------------------------------------------------------------------------
/src/examples/twoMatMuls/matrixMul.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | /**
5 |  * Matrix multiplication: C = A * B.
6 |  * Host code.
7 |  *
8 |  * This sample implements matrix multiplication which makes use of shared memory
9 |  * to ensure data reuse, the matrix multiplication is done using tiling
10 |  * approach. It has been written for clarity of exposition to illustrate various
11 |  * CUDA programming principles, not with the goal of providing the most
12 |  * performant generic kernel for matrix multiplication. See also: V. Volkov and
13 |  * J. Demmel, "Benchmarking GPUs to tune dense linear algebra," in Proc. 2008
14 |  * ACM/IEEE Conf. on Supercomputing (SC '08), Piscataway, NJ: IEEE Press, 2008,
15 |  * pp. Art. 31:1-11.
16 |  */
17 | 
18 | // System includes
19 | #include <stdio.h>
20 | #include <assert.h>
21 | 
22 | // CUDA runtime
23 | #include <cuda_runtime.h>
24 | 
25 | // CuSync include
26 | #include <cusync/cusync.h>
27 | 
28 | /**
29 |  * Matrix multiplication (CUDA Kernel) on the device: C = A * B
30 |  * wA is A's width and wB is B's width
31 |  */
32 | 
33 | using namespace cusync;
34 | 
35 | //Define Producer and Consumer CuStage
36 | const int BLOCK_SIZE = 32;
37 | using Sync = TileSync<TransposeXYOrder, BLOCK_SIZE, BLOCK_SIZE>; //the policy and both stages must agree on the tile order
38 | using ProdCuStage = CuStage<TransposeXYOrder, NoSync, Sync>;
39 | using ConsCuStage = CuStage<TransposeXYOrder, Sync, NoSync>;
40 | 
41 | template<typename CuStageTy>
42 | __global__ void MatrixMulCUDA(CuStageTy custage, float *C, float *A,
43 |                               float *B, int wA, int wB) {
44 |   __shared__ int tileSh[3];
45 |   // Get tile to compute by this thread block
46 |   dim3 tile = custage.tile((dim3*)&tileSh[0]);
47 | 
48 |   // Block index
49 |   int bx = tile.x;
50 |   int by = tile.y;
51 | 
52 |   // Thread index
53 |   int tx = threadIdx.x;
54 |   int ty = threadIdx.y;
55 | 
56 |   // Index of the first sub-matrix of A processed by the block
57 |   int aBegin = wA * BLOCK_SIZE * by;
58 | 
59 |   // Index of the last sub-matrix of A processed by the block
60 |   int aEnd = aBegin + wA - 1;
61 | 
62 |   // Step size used to iterate through the sub-matrices of A
63 |   int aStep = BLOCK_SIZE;
64 | 
65 |   // Index of the first sub-matrix of B processed by the block
66 |   int bBegin = BLOCK_SIZE * bx;
67 | 
68 |   // Step size used to iterate through the sub-matrices of B
69 |   int bStep = BLOCK_SIZE * wB;
70 | 
71 |   // Csub is used to store the element of the block sub-matrix
72 |   // that is computed by the thread
73 |   float Csub = 0;
74 | 
75 |   // Loop over all the sub-matrices of A and B
76 |   // required to compute the block sub-matrix
77 |   for (int a = aBegin, b = bBegin; a <= aEnd; a += aStep, b += bStep) {
78 |     // Declaration of the shared memory array As used to
79 |     // store the sub-matrix of A
80 |     __shared__ float
As[BLOCK_SIZE][BLOCK_SIZE];
81 | 
82 |     // Declaration of the shared memory array Bs used to
83 |     // store the sub-matrix of B
84 |     __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
85 | 
86 |     // Load the matrices from device memory
87 |     // to shared memory; each thread loads
88 |     // one element of each matrix
89 |     // Wait for tile of A to be computed by producer kernel
90 | 
91 |     dim3 tile = {(uint32_t)(a - aBegin), (uint32_t)by * BLOCK_SIZE, 1};
92 |     custage.wait(tile);
93 | 
94 |     As[ty][tx] = A[a + wA * ty + tx];
95 |     Bs[ty][tx] = B[b + wB * ty + tx];
96 | 
97 |     // Synchronize to make sure the matrices are loaded
98 |     __syncthreads();
99 | 
100 |     // Multiply the two matrices together;
101 |     // each thread computes one element
102 |     // of the block sub-matrix
103 | #pragma unroll
104 |     for (int k = 0; k < BLOCK_SIZE; ++k) {
105 |       Csub += As[ty][k] * Bs[k][tx];
106 |     }
107 | 
108 |     // Synchronize to make sure that the preceding
109 |     // computation is done before loading two new
110 |     // sub-matrices of A and B in the next iteration
111 |     __syncthreads();
112 |   }
113 | 
114 |   // Write the block sub-matrix to device memory;
115 |   // each thread writes one element
116 |   int c = wB * BLOCK_SIZE * by + BLOCK_SIZE * bx;
117 |   C[c + wB * ty + tx] = Csub;
118 | 
119 |   // Post the status of tile when computed
120 |   custage.post({(uint32_t)bx * BLOCK_SIZE, (uint32_t)by * BLOCK_SIZE, 0});
121 | }
122 | 
123 | void ConstantInit(float *data, int size, float val) {
124 |   for (int i = 0; i < size; ++i) {
125 |     data[i] = val;
126 |   }
127 | }
128 | 
129 | /**
130 |  * Run a simple test of matrix multiplication using CUDA
131 |  */
132 | int MatrixMultiply(int argc, char **argv, int block_size, const dim3 &dimsA,
133 |                    const dim3 &dimsB, const dim3 &dimsD) {
134 |   // Allocate host memory for matrices A and B
135 |   unsigned int size_A = dimsA.x * dimsA.y;
136 |   unsigned int mem_size_A = sizeof(float) * size_A;
137 |   float *h_A;
138 |   CUDA_CHECK(cudaMallocHost(&h_A, mem_size_A));
139 |   unsigned int size_B = dimsB.x * dimsB.y;
140 |   unsigned int mem_size_B = sizeof(float) * size_B;
141 |   float *h_B;
142 |   CUDA_CHECK(cudaMallocHost(&h_B, mem_size_B));
143 |   float *h_D;
144 |   CUDA_CHECK(cudaMallocHost(&h_D, mem_size_A));
145 | 
146 |   cudaStream_t prod_stream, cons_stream;
147 | 
148 |   // Initialize host memory
149 |   const float valB = 0.01f;
150 |   ConstantInit(h_A, size_A, 1.0f);
151 |   ConstantInit(h_B, size_B, valB);
152 |   ConstantInit(h_D, size_B, valB);
153 | 
154 |   // Allocate device memory
155 |   float *d_A, *d_B, *d_C, *d_D, *d_E;
156 | 
157 |   // Allocate host matrix C and E
158 |   dim3 dimsC(dimsB.x, dimsA.y, 1);
159 |   unsigned int mem_size_C = dimsC.x * dimsC.y * sizeof(float);
160 |   float *h_C;
161 |   CUDA_CHECK(cudaMallocHost(&h_C, mem_size_C));
162 | 
163 |   dim3 dimsE(dimsB.x, dimsA.y, 1);
164 |   unsigned int mem_size_E = dimsC.x * dimsC.y * sizeof(float);
165 |   float *h_E;
166 |   CUDA_CHECK(cudaMallocHost(&h_E, mem_size_E));
167 | 
168 |   if (h_C == NULL) {
169 |     fprintf(stderr, "Failed to allocate host matrix C!\n");
170 |     exit(EXIT_FAILURE);
171 |   }
172 | 
173 |   CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_A), mem_size_A));
174 |   CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_B), mem_size_B));
175 |   CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_C), mem_size_C));
176 |   CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_D), mem_size_B));
177 |   CUDA_CHECK(cudaMalloc(reinterpret_cast<void **>(&d_E), mem_size_E));
178 | 
179 |   // Allocate CUDA events that we'll use for timing
180 |   cudaEvent_t start, stop;
181 |   CUDA_CHECK(cudaEventCreate(&start));
182 | 
CUDA_CHECK(cudaEventCreate(&stop));
183 | 
184 |   CUDA_CHECK(cudaStreamCreateWithFlags(&cons_stream, cudaStreamNonBlocking));
185 |   CUDA_CHECK(cudaStreamCreateWithFlags(&prod_stream, cudaStreamNonBlocking));
186 | 
187 |   // copy host memory to device
188 |   CUDA_CHECK(
189 |       cudaMemcpy(d_A, h_A, mem_size_A, cudaMemcpyHostToDevice));
190 |   CUDA_CHECK(
191 |       cudaMemcpy(d_B, h_B, mem_size_B, cudaMemcpyHostToDevice));
192 |   CUDA_CHECK(
193 |       cudaMemcpy(d_D, h_D, mem_size_B, cudaMemcpyHostToDevice));
194 | 
195 |   // Setup execution parameters
196 |   dim3 threads(block_size, block_size, 1);
197 |   dim3 grid(dimsB.x / threads.x, dimsA.y / threads.y, 1);
198 | 
199 |   // Create CuSync and CuStage
200 |   Sync sync;
201 |   dim3 tilesize = threads;
202 |   ProdCuStage prod(grid, tilesize, NoSync(), sync);
203 |   ConsCuStage cons(grid, tilesize, sync, NoSync());
204 |   CuSync::setProducerConsumerPair(prod, cons);
205 | 
206 |   // Create and start timer
207 |   printf("Computing result using CUDA Kernel...\n");
208 | 
209 |   assert (block_size == 32);
210 |   // Invoke producer kernel (C = A * B)
211 |   MatrixMulCUDA<ProdCuStage>
212 |       <<<grid, threads, 0, prod_stream>>>(prod, d_C, d_A, d_B, dimsA.x, dimsB.x);
213 | 
214 |   //Invoke wait kernel
215 |   prod.invokeWaitKernel(cons_stream);
216 | 
217 |   //Invoke consumer kernel (E = C * D)
218 |   MatrixMulCUDA<ConsCuStage>
219 |       <<<grid, threads, 0, cons_stream>>>(cons, d_E, d_C, d_D, dimsA.x, dimsB.x);
220 | 
221 |   CUDA_CHECK(cudaDeviceSynchronize());
222 | 
223 |   //for next run increment the iteration counter
224 |   prod.incrementIter();
225 |   cons.incrementIter();
226 | 
227 |   printf("Execution done\n");
228 | 
229 |   // Copy result from device to host
230 |   CUDA_CHECK(
231 |       cudaMemcpy(h_C, d_C, mem_size_C, cudaMemcpyDeviceToHost));
232 |   CUDA_CHECK(
233 |       cudaMemcpy(h_E, d_E, mem_size_C, cudaMemcpyDeviceToHost));
234 | 
235 |   printf("Checking computed result for correctness: \n");
236 |   bool correct = true;
237 | 
238 |   // test relative error by the formula
239 |   // |<x, y>_cpu - <x, y>_gpu| / <|x|, |y|>  < eps
240 |   double eps = 1.e-5;  // machine zero
241 |   // Check C
242 |   for (int i = 0; i < static_cast<int>(dimsC.x * dimsC.y); i++) {
243 |     double abs_err = fabs(h_C[i] - (dimsA.x * valB));
244 |     double dot_length = dimsA.x;
245 |     double abs_val = fabs(h_C[i]);
246 |     double rel_err = abs_err / abs_val / dot_length;
247 | 
248 |     if (rel_err > eps) {
249 |       printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term is > %E\n", i,
250 |              h_C[i], dimsA.x * valB, eps);
251 |       correct = false;
252 |       break;
253 |     }
254 |   }
255 | 
256 |   printf("C results: %s\n", correct ? "PASS" : "FAIL");
257 | 
258 |   for (int i = 0; i < static_cast<int>(dimsC.x * dimsC.y); i++) {
259 |     double abs_err = fabs(h_E[i] - (dimsA.x * valB * dimsA.x * valB));
260 |     double dot_length = dimsA.x;
261 |     double abs_val = fabs(h_E[i]);
262 |     double rel_err = abs_err / abs_val / dot_length;
263 | 
264 |     if (rel_err > eps) {
265 |       printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term is > %E\n", i,
266 |              h_E[i], dimsA.x * valB * dimsA.x * valB, eps);
267 |       correct = false;
268 |       break;
269 |     }
270 |   }
271 | 
272 |   printf("E results: %s\n", correct ?
"PASS" : "FAIL"); 273 | 274 | // Clean up memory 275 | CUDA_CHECK(cudaFreeHost(h_A)); 276 | CUDA_CHECK(cudaFreeHost(h_B)); 277 | CUDA_CHECK(cudaFreeHost(h_C)); 278 | CUDA_CHECK(cudaFreeHost(h_D)); 279 | CUDA_CHECK(cudaFreeHost(h_E)); 280 | CUDA_CHECK(cudaFree(d_A)); 281 | CUDA_CHECK(cudaFree(d_B)); 282 | CUDA_CHECK(cudaFree(d_C)); 283 | CUDA_CHECK(cudaFree(d_D)); 284 | CUDA_CHECK(cudaFree(d_E)); 285 | CUDA_CHECK(cudaEventDestroy(start)); 286 | CUDA_CHECK(cudaEventDestroy(stop)); 287 | 288 | if (correct) { 289 | return EXIT_SUCCESS; 290 | } else { 291 | return EXIT_FAILURE; 292 | } 293 | } 294 | 295 | /** 296 | * Program main 297 | */ 298 | int main(int argc, char **argv) { 299 | printf("[Matrix Multiply Using CUDA] - Starting...\n"); 300 | 301 | // This will pick the best possible CUDA capable device, otherwise 302 | // override the device ID based on input provided at the command line 303 | 304 | int block_size = BLOCK_SIZE; 305 | 306 | dim3 dimsA(4 * 2 * block_size, 4 * 2 * block_size, 1); 307 | dim3 dimsB = dimsA; 308 | dim3 dimsD = dimsA; 309 | 310 | printf("MatrixA(%d,%d), MatrixB(%d,%d), MatrixD(%d,%d)\n", dimsA.x, dimsA.y, dimsB.x, 311 | dimsB.y, dimsD.x, dimsD.y); 312 | 313 | int matrix_result = MatrixMultiply(argc, argv, block_size, dimsA, dimsB, dimsD); 314 | 315 | exit(matrix_result); 316 | } 317 | -------------------------------------------------------------------------------- /src/include/cusync/cusync.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | // 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "cusync_defines.h" 13 | 14 | #pragma once 15 | 16 | #define CUDA_CHECK(cmd) do { \ 17 | cudaError_t e = cmd; \ 18 | if( e != cudaSuccess ) { \ 19 | printf("Failed: Cuda error %s:%d '%s'\n", \ 20 | __FILE__,__LINE__,cudaGetErrorString(e)); \ 21 | exit(EXIT_FAILURE); \ 22 | } \ 23 | } while(0); 24 | 25 | #define DIVUP(x, y) (((x) + (y) - 1)/(y)); 26 | 27 | #include "policies.h" 28 | #include "wait-kernel.h" 29 | 30 | namespace cusync { 31 | /* 32 | * A test class to access private members of CuStage 33 | */ 34 | class CuSyncTest; 35 | class CuSync; 36 | 37 | /* 38 | * List of CuSync errors. 39 | * CuSyncErrorNotProducer : Operation is performed on a CuStage which is not a producer 40 | * CuSyncErrorNotConsumer : Operation is performed on a CuStage which is not a producer 41 | * CuSyncErrorNotInitialized : CuStage is not initialized 42 | * CuSyncErrorInvalidLinearBlockIndex: TileOrder do not cover all thread blocks to a linear index 43 | * CuSyncErrorCUDAError : Internal CUDA Error, use cudaGetLastError() 44 | * CuSyncSuccess : Operation sucess 45 | */ 46 | enum CuSyncError { 47 | CuSyncErrorNotProducer, 48 | CuSyncErrorNotConsumer, 49 | CuSyncErrorNotInitialized, 50 | CuSyncErrorInvalidLinearBlockIndex, 51 | CuSyncErrorCUDAError, 52 | CuSyncSuccess 53 | }; 54 | 55 | /* 56 | * List of optimizations for a CuStage that avoids certain operations 57 | * performed by a CuStage for a specific scenario. 58 | * NoOptimizations : No optimization is performed 59 | * NoAtomicAdd : Use memory write instead of atomic add. Useful when each 60 | * tile is associated with a distinct semaphore 61 | * AvoidWaitKernel : Avoid calling wait kernel. 
Useful when thread blocks of all
62 |    *                   CuStages can be allocated within a single wave
63 |    * AvoidCustomOrder: Avoid assigning tiles in the specific order but use CUDA's
64 |    *                   arbitrary order. Useful when thread blocks of N dependent
65 |    *                   CuStages can be allocated within (N - 1) waves
66 |    * ReorderTileLoads: Reorder tile loads to overlap computation of one input's tile
67 |    *                   with the loading of the other input's tile
68 |    */
69 |   enum Optimizations {
70 |     NoOptimization   = 0,
71 |     NoAtomicAdd      = 1 << 0,
72 |     AvoidWaitKernel  = 1 << 1,
73 |     AvoidCustomOrder = 1 << 2,
74 |     ReorderTileLoads = 1 << 3
75 |   };
76 | 
77 |   /*
78 |    * A CuStage is associated with a single kernel. A CuStage contains the following
79 |    * information about its kernel:
80 |    * 1. grid and tile size of the kernel
81 |    * 2. grid size of its producer kernel
82 |    * 3. input and output synchronization policies for the kernel
83 |    *
84 |    * Moreover, CuStage contains pointers to the tile order and an array of semaphores
85 |    * for the tile synchronization policies.
86 |    */
87 |   template<typename TileOrder,
88 |            typename InputSyncPolicy,
89 |            typename OutputSyncPolicy,
90 |            int Opts = NoOptimization>
91 | 
92 |   class CuStage {
93 |   private:
94 |     //grid size of this stage
95 |     dim3 grid_;
96 |     //grid size of the producer stage
97 |     dim3 prodGrid_;
98 |     //tile size of this stage
99 |     dim3 tileSize_;
100 | 
101 |     //Number of runs of stage kernels invoked
102 |     int iter;
103 |     //Producer Sync policy of the stage
104 |     InputSyncPolicy inputPolicy_;
105 |     //Consumer Sync policy of the stage
106 |     OutputSyncPolicy outputPolicy_;
107 | 
108 |     //GPU pointer to array of order of tiles
109 |     dim3* tileOrder;
110 |     //GPU pointer to counter of tile for index in tile order
111 |     uint32_t* tileCounter;
112 | 
113 |     //GPU pointer to wait kernel semaphore
114 |     int* kernelExecuted_;
115 | 
116 |     volatile uint32_t* tileStatusWrite_;
117 |     volatile uint32_t* tileStatusRead_;
118 | 
119 |     //CuSyncTest and CuSync can access private members
120 |     friend class CuSyncTest;
121 |     friend class CuSync;
122 | 
123 |     //Call TileOrder parameter to generate tile order and store
124 |     //it in tileOrder
125 |     CuSyncError buildScheduleBuffer() {
126 |       dim3* hTileOrder = new dim3[numTiles()];
127 |       bool errInvalidLinearBlockIndex = false;
128 | 
129 |       CUDA_CHECK(cudaMalloc(&tileCounter, sizeof(int)));
130 |       CUDA_CHECK(cudaMemset(tileCounter, 0, sizeof(int)));
131 |       CUDA_CHECK(cudaMalloc(&tileOrder, sizeof(*tileOrder) * numTiles()));
132 | 
133 |       dim3 invalidBlock = {numTiles(), 0, 0};
134 |       for (uint32_t id = 0; id < numTiles(); id++) {
135 |         hTileOrder[id] = invalidBlock;
136 |       }
137 | 
138 |       for (uint32_t z = 0; z < grid_.z; z++) {
139 |       for (uint32_t y = 0; y < grid_.y; y++) {
140 |       for (uint32_t x = 0; x < grid_.x; x++) {
141 |         size_t id = TileOrder().blockIndex(grid_, {x, y, z});
142 |         if (hTileOrder[id].x == invalidBlock.x) {
143 |           hTileOrder[id] = {x, y, z};
144 |         } else {
145 |           errInvalidLinearBlockIndex = true;
146 |         }
147 |       }}}
148 | 
149 |       CUDA_CHECK(cudaMemcpy(tileOrder, hTileOrder,
150 |                             sizeof(*tileOrder) * numTiles(),
151 |                             cudaMemcpyHostToDevice));
152 |       delete[] hTileOrder;
153 | 
154 |       if (errInvalidLinearBlockIndex) return CuSyncErrorInvalidLinearBlockIndex;
155 | 
156 |       return CuSyncSuccess;
157 |     }
158 | 
159 |     //Set the producer grid
160 |     template<typename ProdCuStage>
161 |     void setProdGrid(ProdCuStage& prod) {prodGrid_ = prod.grid();}
162 | 
163 |     //Get tile status semaphore arrays
164 |     CUSYNC_DEVICE_HOST
165 |     volatile uint32_t* getTileStatusToPost() {return tileStatusWrite_;}
166 |     CUSYNC_DEVICE_HOST
167 |     volatile uint32_t* getTileStatusToWait() {return tileStatusRead_;}
168 | 
169 |     //Set tile status
semaphore arrays
170 |     CUSYNC_HOST
171 |     void setTileStatusToPost(volatile uint32_t* ptr) {tileStatusWrite_ = ptr;}
172 |     CUSYNC_HOST
173 |     void setTileStatusToWait(volatile uint32_t* ptr) {tileStatusRead_ = ptr;}
174 | 
175 |   public:
176 |     CuStage(dim3 grid, dim3 tileSize, InputSyncPolicy inputPolicy, OutputSyncPolicy outputPolicy) :
177 |       grid_(grid),
178 |       prodGrid_(0), //set by CuSync::set* methods
179 |       tileSize_(tileSize),
180 |       iter(1),      //run counter starts from 1
181 |       inputPolicy_(inputPolicy),
182 |       outputPolicy_(outputPolicy) {
183 | 
184 |       buildScheduleBuffer();
185 | 
186 |       if (isProducer()) {
187 |         //Allocate tile status semaphore array for all tiles
188 |         //CuSync::set* methods set this array to consumer stages
189 |         CUDA_CHECK(cudaMalloc(&tileStatusWrite_, numTiles() * sizeof(int)));
190 |         CUDA_CHECK(cudaMemset((uint32_t*)tileStatusWrite_, 0, numTiles() * sizeof(int)));
191 | 
192 |         //Allocate wait kernel semaphore
193 |         if (!getAvoidWaitKernel()) {
194 |           CUDA_CHECK(cudaMalloc(&kernelExecuted_, sizeof(int)));
195 |           CUDA_CHECK(cudaMemset(kernelExecuted_, 0, sizeof(int)));
196 |         }
197 |       }
198 |     }
199 | 
200 |     //Return grid size of this stage
201 |     dim3 grid() {return grid_;}
202 | 
203 |     CuSyncError invokeWaitKernel(cudaStream_t stream) {
204 |       if (!isProducer()) return CuSyncErrorNotProducer;
205 |       if (!getAvoidWaitKernel())
206 |         cusync::invokeWaitKernel((uint32_t*)kernelExecuted_, iter, stream);
207 |       if (cudaGetLastError() != cudaSuccess) return CuSyncErrorCUDAError;
208 |       return CuSyncSuccess;
209 |     }
210 | 
211 | 
212 |     void incrementIter() {iter += 1;}
213 | 
214 |     CUSYNC_DEVICE_HOST
215 |     CuStage(): iter(1) {}
216 | 
217 |     /*
218 |      * Getters and setters for private variables.
219 |      */
220 |     //Getters for optimizations
221 |     CUSYNC_DEVICE_HOST
222 |     bool getNoAtomicAdd     () {return Opts & NoAtomicAdd;     }
223 |     CUSYNC_DEVICE_HOST
224 |     bool getAvoidWaitKernel () {return Opts & AvoidWaitKernel; }
225 |     CUSYNC_DEVICE_HOST
226 |     bool getReorderTileLoads() {return Opts & ReorderTileLoads;}
227 |     CUSYNC_DEVICE_HOST
228 |     bool getAvoidCustomOrder() {return Opts & AvoidCustomOrder;}
229 | 
230 |     //A producer has a sync policy for its output
231 |     CUSYNC_DEVICE_HOST
232 |     bool isProducer() {return !std::is_same<OutputSyncPolicy, NoSync>::value;}
233 | 
234 |     //A consumer has a sync policy for its input
235 |     CUSYNC_DEVICE_HOST
236 |     bool isConsumer() {return !std::is_same<InputSyncPolicy, NoSync>::value;}
237 | 
238 |     /*
239 |      * Returns total number of thread blocks
240 |      */
241 |     CUSYNC_DEVICE_HOST
242 |     uint32_t numTiles() {return grid_.x * grid_.y * grid_.z;}
243 | 
244 |     /*
245 |      * Returns the tile index of tile using the input policy
246 |      */
247 |     CUSYNC_DEVICE
248 |     uint32_t waitTileIndex(dim3 tile) {
249 |       return inputPolicy_.tileIndex(tile, prodGrid_);
250 |     }
251 | 
252 |     /*
253 |      * Return semaphore value for the tile index
254 |      */
255 |     CUSYNC_DEVICE
256 |     uint32_t waitSemValue(dim3 tile) {
257 |       return globalVolatileLoad(&tileStatusRead_[waitTileIndex(tile)]);
258 |     }
259 | 
260 |     /*
261 |      * Return expected wait value for the tile
262 |      */
263 |     CUSYNC_DEVICE
264 |     uint32_t expectedWaitValue(dim3 tile) {
265 |       return inputPolicy_.waitValue(tile, prodGrid_);
266 |     }
267 | 
268 |     /*
269 |      * Wait until the semaphore of the tile reaches the wait value
270 |      */
271 |     CUSYNC_DEVICE
272 |     CuSyncError wait(dim3& tile, uint32_t waitingThread = 0, bool callSync = true);
273 | 
274 |     /*
275 |      * Post the status of completion of tile.
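     *
     * Kernel-side usage sketch (mirroring examples/twoMatMuls/matrixMul.cu;
     * sharedTile and depTile are illustrative names):
     *   dim3 tile = custage.tile((dim3*)&sharedTile); // tile for this thread block
     *   custage.wait(depTile);  // consumer: block until the producer posts depTile
     *   ...compute the tile...
     *   custage.post(tile);     // producer: publish completion of this tile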
276 |      */
277 |     CUSYNC_DEVICE
278 |     CuSyncError post(const dim3& tile, uint32_t postThread = 0);
279 | 
280 |     /*
281 |      * Returns the next tile to process and sets the wait kernel's semaphore if valid
282 |      */
283 |     CUSYNC_DEVICE
284 |     dim3 tile(dim3* shared_storage);
285 |   };
286 | 
287 |   struct CuSync {
288 |     template<typename Stage1, typename Stage2>
289 |     static CuSyncError setProducerConsumerPair(Stage1& prod, Stage2& cons) {
290 |       if (!prod.isProducer()) return CuSyncErrorNotProducer;
291 |       if (!cons.isConsumer()) return CuSyncErrorNotConsumer;
292 |       if (prod.getTileStatusToPost() == nullptr)
293 |         return CuSyncErrorNotInitialized;
294 | 
295 |       cons.setProdGrid(prod);
296 |       cons.setTileStatusToWait(prod.getTileStatusToPost());
297 |       return CuSyncSuccess;
298 |     }
299 |   };
300 | }
301 | 
302 | #if (defined(__NVCC__) || defined(__CUDACC__))
303 |   #include "cusync_device_defs.h"
304 |   #include "tile-orders.h"
305 |   #include "device-functions.h"
306 | #endif
307 | 
--------------------------------------------------------------------------------
/src/include/cusync/cusync_defines.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | //
4 | #if (defined(__CUDACC__) || defined(__NVCC__))
5 |   #define CUSYNC_DEVICE __device__ __forceinline__
6 | #else
7 |   #define CUSYNC_DEVICE
8 | #endif
9 | 
10 | #if (defined(__CUDACC__) || defined(__NVCC__))
11 |   #define CUSYNC_HOST __host__ __forceinline__
12 | #else
13 |   #define CUSYNC_HOST
14 | #endif
15 | 
16 | #if (defined(__CUDACC__) || defined(__NVCC__))
17 |   #define CUSYNC_DEVICE_HOST __device__ __host__ __forceinline__
18 | #else
19 |   #define CUSYNC_DEVICE_HOST
20 | #endif
21 | 
22 | #if (defined(__CUDACC__) || defined(__NVCC__))
23 |   #define CUSYNC_GLOBAL __global__
24 | #else
25 |   #define CUSYNC_GLOBAL
26 | #endif
27 | 
--------------------------------------------------------------------------------
/src/include/cusync/cusync_device_defs.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | #include "tile-orders.h"
5 | #include "policies.h"
6 | #include "device-functions.h"
7 | #include "wait-kernel.h"
8 | 
9 | #pragma once
10 | 
11 | #define CUDA_CHECK(cmd) do {                        \
12 |   cudaError_t e = cmd;                              \
13 |   if( e != cudaSuccess ) {                          \
14 |     printf("Failed: Cuda error %s:%d '%s'\n",       \
15 |         __FILE__,__LINE__,cudaGetErrorString(e));   \
16 |     exit(EXIT_FAILURE);                             \
17 |   }                                                 \
18 | } while(0);
19 | 
20 | #define DIVUP(x, y) (((x) + (y) - 1)/(y));
21 | 
22 | namespace cusync {
23 | #define CUSTAGE_METHOD_DEF(RET_TYPE)   \
24 |   template<typename TileOrder,         \
25 |            typename InputSyncPolicy,  \
26 |            typename OutputSyncPolicy, \
27 |            int Opts>                  \
28 |   CUSYNC_DEVICE RET_TYPE CuStage<TileOrder, InputSyncPolicy, OutputSyncPolicy, Opts>::
29 | 
30 |   /*
31 |    * Wait until the semaphore of the tile reaches the wait value
32 |    */
33 |   CUSTAGE_METHOD_DEF(CuSyncError) wait(dim3& tile, uint32_t waitingThread, bool callSync) {
34 |     if (!isConsumer()) return CuSyncErrorNotConsumer;
35 |     if (!inputPolicy_.isSync(tile, prodGrid_)) return CuSyncSuccess;
36 | 
37 |     if (threadIdx.x == waitingThread && threadIdx.y == 0 && threadIdx.z == 0) {
38 |       uint32_t w = inputPolicy_.waitValue(tile, prodGrid_);
39 |       uint32_t idx = inputPolicy_.tileIndex(tile, prodGrid_);
40 |       auto v = globalLoad(&tileStatusRead_[idx]);
41 |       while(v < iter * w) {
42 |         v = globalVolatileLoad(&tileStatusRead_[idx]);
43 |       }
44 |     }
45 | 
46 |     if (callSync)
47 |       __syncthreads();
48 | 
49 |     return CuSyncSuccess;
50 |   }
51 | 
52 |   /*
53 |    * Post the status of completion of tile.
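   * By default the semaphore at outputPolicy_.tileIndex(tile, grid_) is
   * incremented atomically by the policy's postValue; with the NoAtomicAdd
   * optimization it is instead overwritten with postValue * iter through a
   * release store, which is valid only when each tile owns a distinct
   * semaphore (e.g., TileSync).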
54 |    */
55 |   CUSTAGE_METHOD_DEF(CuSyncError) post(const dim3& tile, uint32_t postThread) {
56 |     if (!isProducer()) return CuSyncErrorNotProducer;
57 |     __syncthreads();
58 |     if (threadIdx.x == postThread && threadIdx.y == 0 && threadIdx.z == 0) {
59 |       __threadfence_system();
60 |       uint32_t idx = outputPolicy_.tileIndex(tile, grid_);
61 |       if (!getNoAtomicAdd()) {
62 |         atomicAdd((int*)&tileStatusWrite_[idx],
63 |                   outputPolicy_.postValue(tile, grid_));
64 |       } else {
65 |         uint32_t val = outputPolicy_.postValue(tile, grid_) * iter;
66 |         asm volatile ("st.global.release.gpu.u32 [%0], {%1};" :: "l"((int*)&tileStatusWrite_[idx]), "r"(val));
67 |       }
68 |     }
69 | 
70 |     __syncwarp();
71 |     return CuSyncSuccess;
72 |   }
73 | 
74 |   /*
75 |    * Returns the next tile to process and sets the wait kernel's semaphore if valid
76 |    */
77 |   CUSTAGE_METHOD_DEF(dim3) tile(dim3* shared_storage) {
78 |     if (!getAvoidWaitKernel()) {
79 |       if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 &&
80 |           blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && isProducer()) {
81 |         *kernelExecuted_ = iter;
82 |       }
83 |     }
84 |     if (!getAvoidCustomOrder()) {
85 |       if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
86 |         if (shared_storage != nullptr) {
87 |           uint32_t linear_id = atomicAdd(tileCounter, 1);
88 |           if (linear_id == numTiles() - 1) {
89 |             *tileCounter = 0;
90 |           }
91 |           *shared_storage = tileOrder[linear_id];
92 |         }
93 |       }
94 | 
95 |       if (shared_storage != nullptr) {
96 |         __syncthreads();
97 |         return *shared_storage;
98 |       }
99 |       return blockIdx;
100 |     } else {
101 |       return blockIdx;
102 |     }
103 |   }
104 | }
105 | 
--------------------------------------------------------------------------------
/src/include/cusync/device-functions.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | #include "cusync_defines.h"
5 | 
6 | #pragma once
7 | 
8 | /*
9 |  * Volatile load of an unsigned integer from global memory
10 |  * @addr: global memory address of an unsigned integer
11 |  *
12 |  * Returns the loaded unsigned integer
13 |  */
14 | CUSYNC_DEVICE
15 | static uint32_t globalVolatileLoad(volatile uint32_t* addr) {
16 |   uint32_t val;
17 |   asm volatile ("ld.global.acquire.gpu.u32 {%0}, [%1];" : "=r"(val) : "l"(addr));
18 |   return val;
19 | }
20 | 
21 | /*
22 |  * Load of an unsigned integer from global memory and caching in L2 cache
23 |  * @addr: global memory address of an unsigned integer
24 |  *
25 |  * Returns the loaded unsigned integer
26 |  */
27 | CUSYNC_DEVICE
28 | static uint32_t globalLoad(volatile uint32_t* addr) {
29 |   uint32_t val;
30 |   asm volatile ("ld.global.cg.u32 {%0}, [%1];" : "=r"(val) : "l"(addr));
31 |   return val;
32 | }
33 | 
--------------------------------------------------------------------------------
/src/include/cusync/policies.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | #pragma once
5 | 
6 | namespace cusync {
7 |   /*
8 |    * A synchronization policy (SyncPolicy) is a struct of four methods:
9 |    * uint32_t waitValue(const dim3& tile, const dim3& grid) returns completed value for the tile
10 |    * uint32_t tileIndex(const dim3& tile, const dim3& grid) returns semaphore index for the tile
11 |    * bool     isSync   (const dim3& tile, const dim3& grid) returns whether the tile must synchronize on its semaphore
12 |    * uint32_t postValue(const dim3& tile, const dim3& grid) returns the value of semaphore when tile is processed
13 |    */
14 | 
15 |   /*
16 |    * No Synchronization Policy. A CuStage will not call any methods of this policy.
17 |    */
18 |   struct NoSync {
19 |     CUSYNC_DEVICE_HOST
20 |     NoSync() {}
21 | 
22 |   #if (defined(__CUDACC__) || defined(__NVCC__))
23 |     CUSYNC_DEVICE
24 |     uint32_t waitValue(const dim3& tile, const dim3& grid) {return 0;}
25 |     CUSYNC_DEVICE
26 |     uint32_t tileIndex(const dim3& tile, const dim3& grid) {return 0;}
27 |     CUSYNC_DEVICE
28 |     bool isSync   (const dim3& tile, const dim3& grid) {return false;}
29 |     CUSYNC_DEVICE
30 |     uint32_t postValue(const dim3& tile, const dim3& grid) {return 0;}
31 |   #endif
32 |   };
33 | 
34 |   /*
35 |    * RowSync policy assigns the same semaphore to tiles sharing the same row (x index of tile)
36 |    */
37 |   template<uint32_t TileM>
38 |   struct RowSync {
39 |     //Value to wait on for each row
40 |     uint32_t waitValue_;
41 |     //Value to post when the tile is computed
42 |     uint32_t postValue_;
43 | 
44 |     /*
45 |      * Default constructor for RowSync initializes wait and post value to 0
46 |      */
47 |     CUSYNC_DEVICE_HOST
48 |     RowSync() : waitValue_(0), postValue_(0) {}
49 | 
50 |     /*
51 |      * Initializes post value to 1 and wait value to the given value
52 |      */
53 |     RowSync(uint32_t waitValue) : waitValue_(waitValue), postValue_(1) {}
54 | 
55 |     /*
56 |      * Initializes post value and wait value
57 |      */
58 |     RowSync(uint32_t waitValue, uint32_t postValue) :
59 |       waitValue_(waitValue), postValue_(postValue) {}
60 | 
61 |   #if (defined(__CUDACC__) || defined(__NVCC__))
62 |     /*
63 |      * Returns the wait value
64 |      */
65 |     CUSYNC_DEVICE
66 |     uint32_t waitValue(const dim3& tile, const dim3& grid) {
67 |       return waitValue_;
68 |     }
69 | 
70 |     /*
71 |      * Returns the tile index as the x index of tile
72 |      */
73 |     CUSYNC_DEVICE
74 |     uint32_t tileIndex(const dim3& tile, const dim3& grid) {
75 |       return tile.x/TileM;
76 |     }
77 | 
78 |     /*
79 |      * Returns true only when the tile is the first tile of column,
80 |      * i.e., y index is 0
81 |      */
82 |     CUSYNC_DEVICE
83 |     bool isSync(const dim3& tile, const dim3& grid) {
84 |       return tile.z == 1;
85 |     }
86 | 
87 |     /*
88 |      * Returns the post value
89 |      */
90 |     CUSYNC_DEVICE
91 |     uint32_t postValue(const dim3& tile, const dim3& grid) {
92 |       return postValue_;
93 |     }
94 |   #endif
95 |   };
96 | 
97 |   /*
98 |    * TileSync assigns distinct semaphore to each tile
99 |    */
100 |   template<typename TileOrder, uint32_t TileM, uint32_t TileN>
101 |   struct TileSync {
102 |     uint32_t waitValue_;
103 |     uint32_t postValue_;
104 | 
105 |     /*
106 |      * Initializes both wait and post value to 1
107 |      */
108 |     CUSYNC_DEVICE_HOST
109 |     TileSync(): waitValue_(1), postValue_(1) {}
110 | 
111 |     /*
112 |      * Initializes both wait and post values to given values
113 |      */
114 |     TileSync(uint32_t waitValue, uint32_t postValue):
115 |       waitValue_(waitValue), postValue_(postValue) {}
116 | 
117 |   #if (defined(__CUDACC__) || defined(__NVCC__))
118 |     /*
119 |      * Return the wait value
120 |      */
121 |     CUSYNC_DEVICE_HOST
122 |     uint32_t waitValue(const dim3& tile, const dim3& grid) {
123 |       return waitValue_;
124 |     }
125 | 
126 |     /*
127 |      * Return the post
value
128 |      */
129 |     CUSYNC_DEVICE_HOST
130 |     uint32_t postValue(const dim3& tile, const dim3& grid)
131 |     {return postValue_;}
132 | 
133 |     /*
134 |      * Return the linear tile index for the grid
135 |      */
136 |     CUSYNC_DEVICE
137 |     constexpr uint32_t tileIndex(const dim3& tile, const dim3& grid) {
138 |       return TileOrder().tileIndex({tile.x/TileM, tile.y/TileN, 0}, grid);
139 |     }
140 | 
141 |     /*
142 |      * Always synchronize on a tile
143 |      */
144 |     CUSYNC_DEVICE
145 |     bool isSync(const dim3& tile, const dim3& grid) {
146 |       return tile.y%TileN == 0;
147 |     }
148 |   #endif
149 |   };
150 | 
151 |   /*
152 |    * Synchronizes tiles of 2D Implicit GeMM Convolution for given values of
153 |    * 2D Convolution kernel size (R x S).
154 |    *
155 |    * The implicit GeMM algorithm converts a Conv2D of B input images of size [P, Q, C]
156 |    * with a kernel matrix of size [R, S] into a GeMM of matrices
157 |    * [B∗P∗Q, C∗R∗S] x [C∗R∗S, C]. Therefore, a tile {x,y} of the consumer Conv2D
158 |    * synchronizes on the tile {x, y/(R*S)} of its producer Conv2D.
159 |    */
160 |   template<typename TileOrder, uint32_t TileM, uint32_t TileN, uint32_t R, uint32_t S>
161 |   struct Conv2DTileSync {
162 |     uint32_t waitValue_;
163 |     uint32_t postValue_;
164 | 
165 |     /*
166 |      * Initializes both wait and post value to 1
167 |      */
168 |     CUSYNC_DEVICE_HOST
169 |     Conv2DTileSync(): waitValue_(1), postValue_(1) {}
170 | 
171 |     /*
172 |      * Initializes both wait and post value to given values
173 |      */
174 |     Conv2DTileSync(uint32_t waitValue, uint32_t postValue):
175 |       waitValue_(waitValue), postValue_(postValue) {}
176 | 
177 |   #if (defined(__CUDACC__) || defined(__NVCC__))
178 |     /*
179 |      * Returns the wait value
180 |      */
181 |     CUSYNC_DEVICE
182 |     uint32_t waitValue(const dim3& tile, const dim3& grid) {
183 |       return waitValue_;
184 |     }
185 | 
186 |     /*
187 |      * Returns the post value
188 |      */
189 |     CUSYNC_DEVICE
190 |     uint32_t postValue(const dim3& tile, const dim3& grid)
191 |     {return postValue_;}
192 | 
193 |     /*
194 |      * Returns the tile index
195 |      */
196 |     CUSYNC_DEVICE
197 |     uint32_t tileIndex(const dim3& tile, const dim3& grid) {
198 |       return TileOrder().tileIndex({tile.x/TileM, (tile.y/TileN)/(R*S), 0}, grid);
199 |     }
200 | 
201 |     /*
202 |      * Synchronizes tiles only when it is a multiple of
203 |      * the conv kernel size
204 |      */
205 |     CUSYNC_DEVICE
206 |     bool isSync(const dim3& tile, const dim3& grid) {
207 |       return (tile.y/TileN) % (R * S) == 0;
208 |     }
209 |   #endif
210 |   };
211 | 
212 | #if 0
213 |   /*
214 |    * Other experimental sync policies
215 |    */
216 |   #define BatchedRows 2
217 | 
218 |   struct BatchedRowSync {
219 |     uint32_t waitValue_;
220 |     uint32_t postValue_;
221 |     __device__ __host__ BatchedRowSync() : waitValue_(0), postValue_(0) {}
222 |     __device__ __host__ BatchedRowSync(uint32_t waitValue) : waitValue_(waitValue), postValue_(1) {}
223 |     __device__ __host__ BatchedRowSync(uint32_t waitValue, uint32_t postValue) :
224 |       waitValue_(waitValue), postValue_(postValue) {}
225 | 
226 |     __device__ bool canBatch(const dim3& tile) {
227 |       return true;
228 |     }
229 | 
230 |     __device__ uint32_t waitValue(const dim3& tile, const dim3& grid) {
231 |       return waitValue_ * BatchedRows;
232 |     }
233 | 
234 |     __device__ uint32_t tileIndex(const dim3& tile, const dim3& grid) {
235 |       return tile.x/BatchedRows;
236 |     }
237 | 
238 |     __device__ bool isSync(const dim3& tile) {
239 |       return tile.y == 0;
240 |     }
241 | 
242 |     __device__ uint32_t postValue(const dim3& tile, const dim3& grid) {
243 |       return postValue_;
244 |     }
245 |   };
246 | 
247 |   struct BatchedRowSync2 {
248 |     uint32_t waitValue_;
249 |     uint32_t postValue_;
250 |     __device__ __host__
BatchedRowSync2() : waitValue_(0), postValue_(0) {} 251 | __device__ __host__ BatchedRowSync2(uint32_t waitValue) : waitValue_(waitValue), postValue_(1) {} 252 | __device__ __host__ BatchedRowSync2(uint32_t waitValue, uint32_t postValue) : 253 | waitValue_(waitValue), postValue_(postValue) {} 254 | 255 | __device__ bool canBatch(const dim3& tile) { 256 | if (tile.x >= BatchedRows) 257 | return true; 258 | return false; 259 | } 260 | 261 | __device__ uint32_t waitValue(const dim3& tile, const dim3& grid) { 262 | if (canBatch(tile)) 263 | return waitValue_ * (grid.x - BatchedRows); 264 | return waitValue_; 265 | } 266 | 267 | __device__ uint32_t tileIndex(const dim3& tile, const dim3& grid) { 268 | if (canBatch(tile)) { 269 | return BatchedRows; 270 | } 271 | return tile.x; 272 | } 273 | 274 | __device__ bool isSync(const dim3& tile) { 275 | return tile.y == 0; 276 | } 277 | 278 | __device__ uint32_t postValue(const dim3& tile, const dim3& grid) { 279 | return postValue_; 280 | } 281 | }; 282 | 283 | struct TileFirstAndRowSync { 284 | uint32_t waitTileValue_; 285 | uint32_t postTileValue_; 286 | uint32_t waitRowValue_; 287 | uint32_t postRowValue_; 288 | 289 | __device__ __host__ TileFirstAndRowSync() {} 290 | __device__ __host__ TileFirstAndRowSync(uint32_t waitTileValue, uint32_t postTileValue, 291 | uint32_t waitRowValue) : 292 | waitTileValue_(waitTileValue), postTileValue_(postTileValue), waitRowValue_(waitRowValue), postRowValue_(1) {} 293 | // __device__ __host__ TileFirstAndRowSync(uint32_t waitValue, uint32_t postValue) : 294 | // waitValue_(waitValue), postValue_(postValue) {} 295 | 296 | __device__ int tileBatch(const dim3& tile) { 297 | if (isTileSync(tile)) 298 | return 8; 299 | return 1; 300 | } 301 | 302 | __device__ bool isTileSync(const dim3& tile) { 303 | if (tile.x < 1) { 304 | return true; 305 | } 306 | return false; 307 | } 308 | 309 | __device__ bool isRowSync(const dim3& tile) { 310 | return !isTileSync(tile); 311 | } 312 | 313 | __device__ uint32_t waitValue(const dim3& tile, const dim3& grid) { 314 | if (isTileSync(tile)) { 315 | return waitTileValue_ * tileBatch(tile); 316 | } 317 | 318 | return waitRowValue_; 319 | } 320 | 321 | __device__ uint32_t tileIndex(const dim3& tile, const dim3& grid) { 322 | if (isTileSync(tile)) { 323 | return (tile.x * 48 + tile.y)/tileBatch(tile); 324 | } 325 | return 1 * 48 + tile.x; 326 | } 327 | 328 | __device__ bool isSync(const dim3& tile, const dim3& grid) { 329 | if (isTileSync(tile)) 330 | return true; 331 | else 332 | return tile.y == 0; 333 | } 334 | 335 | __device__ uint32_t postValue(const dim3& tile, const dim3& grid) { 336 | if (isTileSync(tile)) { 337 | return postTileValue_; 338 | } 339 | 340 | return postRowValue_; 341 | } 342 | }; 343 | 344 | struct FirstTileSync { 345 | uint32_t waitValue_; 346 | uint32_t postValue_; 347 | 348 | __device__ __host__ FirstTileSync(): waitValue_(1), postValue_(1) {} 349 | __device__ __host__ FirstTileSync(uint32_t waitValue, uint32_t postValue): 350 | waitValue_(waitValue), postValue_(postValue) {} 351 | 352 | __device__ __host__ uint32_t waitValue(const dim3& tile, const dim3& grid) { 353 | return waitValue_; 354 | } 355 | 356 | __device__ __host__ uint32_t postValue(const dim3& tile, const dim3& grid) 357 | {return postValue_;} 358 | 359 | __device__ constexpr uint32_t tileIndex(const dim3& tile, const dim3& grid) { 360 | return (tile.x * grid.y + tile.y); 361 | } 362 | 363 | __device__ bool isSync(const dim3& tile, const dim3& grid) { 364 | return tile.y == 0; 365 | } 366 | }; 367 | 
#endif
368 | }
369 | 
--------------------------------------------------------------------------------
/src/include/cusync/tile-orders.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | #pragma once
5 | 
6 | namespace cusync {
7 |   /*
8 |    * A tile order generates processing order of tiles and maps tiles to thread blocks.
9 |    * A tile order is declared as follows and must subclass GenericTileOrder:
10 |    * struct TileOrder : public GenericTileOrder<TileOrder> {
11 |    *  uint32_t blockIndex(const dim3& grid, const dim3& block);
12 |    *  dim3 tileToBlock(const dim3& tile);
13 |    * }
14 |    * GenericTileOrder assumes its subclasses do not have state.
15 |    * If a stateful order is needed the subclass should override
16 |    * the GenericTileOrder::tileIndex method.
17 |    */
18 | 
19 |   template<typename Child>
20 |   struct GenericTileOrder {
21 |     /*
22 |      * Returns a linear index of the thread block in the grid.
23 |      */
24 |     CUSYNC_DEVICE_HOST
25 |     uint32_t blockIndex(const dim3& grid, const dim3& block)
26 |     {return 0;}
27 | 
28 |     /*
29 |      * Maps a tile to a thread block.
30 |      */
31 |     CUSYNC_DEVICE
32 |     dim3 tileToBlock(const dim3& tile)
33 |     {return {0,0,0};}
34 | 
35 |     /*
36 |      * Returns a linear tile index.
37 |      */
38 |     CUSYNC_DEVICE
39 |     uint32_t tileIndex(const dim3& tile, const dim3& grid) {
40 |       dim3 block = Child().tileToBlock(tile);
41 |       return Child().blockIndex(grid, block);
42 |     }
43 |   };
44 | 
45 |   /*
46 |    * TransposeXYOrder generates tile indices first for the X-dimension, then
47 |    * the Y-dimension, and finally the Z-dimension.
48 |    * It maps a tile {x, y, z} to threadblock {y, x, z}
49 |    */
50 |   struct TransposeXYOrder : public GenericTileOrder<TransposeXYOrder> {
51 |     CUSYNC_DEVICE_HOST
52 |     uint32_t blockIndex(const dim3& grid, const dim3& block) {
53 |       return block.x + block.y * grid.x + block.z * grid.x * grid.y;
54 |     }
55 | 
56 |     CUSYNC_DEVICE
57 |     dim3 tileToBlock(const dim3& tile) {
58 |       return dim3{tile.y, tile.x, tile.z};
59 |     }
60 |   };
61 | 
62 |   /*
63 |    * IdentityOrder orders tile indices first for X-dimension, then Y-dim, and finally Z-dim.
64 |    * It maps a tile {x,y,z} to threadblock {x,y,z}.
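   * For example, with grid {4, 2, 1}, block {3, 1, 0} gets the linear index
   * 3 + 1*4 + 0*4*2 = 7.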
62 | /*
63 |  * IdentityOrder generates tile indices first for the X-dimension, then the Y-dimension, and finally the Z-dimension.
64 |  * It maps a tile {x,y,z} to threadblock {x,y,z}.
65 |  */
66 | struct IdentityOrder : public GenericTileOrder<IdentityOrder> {
67 |   CUSYNC_DEVICE_HOST
68 |   uint32_t blockIndex(const dim3& grid, const dim3& block) {
69 |     return block.x + block.y * grid.x + block.z * grid.x * grid.y;
70 |   }
71 | 
72 |   CUSYNC_DEVICE
73 |   dim3 tileToBlock(const dim3& tile) {
74 |     return dim3{tile.x, tile.y, tile.z};
75 |   }
76 | };
77 | 
78 | #if 0
79 | //Experimental Orders
80 | struct OrderZXY {
81 |   __device__ __host__ __forceinline__
82 |   uint32_t operator()(const dim3& grid, const dim3& tile) {
83 |     return tile.z + tile.x * grid.z + tile.y * grid.x * grid.z;
84 |   }
85 | 
86 |   __device__ __host__ __forceinline__
87 |   dim3 tileToBlock(const dim3& tile) {
88 |     return dim3{tile.y, tile.x, tile.z};
89 |   }
90 | 
91 |   __device__ __host__ __forceinline__
92 |   uint32_t tileIndex(const dim3& tile, const dim3& grid) {
93 |     dim3 block = tileToBlock(tile);
94 |     return this->operator()(grid, block);
95 |   }
96 | };
97 | 
98 | struct OrderZXY2 {
99 |   __device__ __host__ __forceinline__
100 |   uint32_t operator()(const dim3& grid, const dim3& tile) {
101 |     return tile.x + tile.y * grid.x + tile.z * grid.x * grid.y;
102 |   }
103 | 
104 |   __device__ __host__ __forceinline__
105 |   dim3 tileToBlock(const dim3& tile) {
106 |     return dim3{tile.y, tile.x, tile.z};
107 |   }
108 | 
109 |   __device__ __host__ __forceinline__
110 |   uint32_t tileIndex(const dim3& tile, const dim3& grid) {
111 |     dim3 block = tileToBlock(tile);
112 |     return OrderZXY()(grid, block);
113 |   }
114 | };
115 | #endif
116 | }
117 | 
--------------------------------------------------------------------------------
/src/include/cusync/wait-kernel.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | 
4 | #include <cuda_runtime.h>
5 | #include "device-functions.h"
6 | 
7 | #pragma once
8 | 
9 | namespace cusync {
10 | void invokeWaitKernel(uint32_t* semaphore, uint32_t givenValue, cudaStream_t stream);
11 | }
12 | 
--------------------------------------------------------------------------------
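The wait kernel above is the host-side hook for ordering work across streams. A minimal sketch of how it might be driven is below; the device-semaphore allocation and the spin-until-value semantics are assumptions for illustration, not something this header guarantees:

```cpp
// Illustrative sketch only. Assumes invokeWaitKernel enqueues a kernel on
// `consumerStream` that returns once *d_semaphore reaches `expected`.
#include <cuda_runtime.h>
#include <cusync/wait-kernel.h>

void waitForProducer(cudaStream_t consumerStream, uint32_t* d_semaphore,
                     uint32_t expected) {
  // Work enqueued on consumerStream after this call starts only once the
  // producer has posted `expected` to the semaphore.
  cusync::invokeWaitKernel(d_semaphore, expected, consumerStream);
}
```

Because the wait is itself a kernel on the consumer stream, no host-side synchronization is needed between producer and consumer.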
/src/include/cutlass/cusync-cutlass/include/cutlass/conv/device/cusyncimplicit_gemm_convolution.h:
--------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  * SPDX-License-Identifier: BSD-3-Clause
4 |  *
5 |  * Redistribution and use in source and binary forms, with or without
6 |  * modification, are permitted provided that the following conditions are met:
7 |  *
8 |  * 1. Redistributions of source code must retain the above copyright notice, this
9 |  * list of conditions and the following disclaimer.
10 |  *
11 |  * 2. Redistributions in binary form must reproduce the above copyright notice,
12 |  * this list of conditions and the following disclaimer in the documentation
13 |  * and/or other materials provided with the distribution.
14 |  *
15 |  * 3. Neither the name of the copyright holder nor the names of its
16 |  * contributors may be used to endorse or promote products derived from
17 |  * this software without specific prior written permission.
18 |  *
19 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |  *
30 |  **************************************************************************************************/
31 | /* \file
32 |    \brief Template for device-level Implicit GEMM Convolution
33 | */
34 | 
35 | #pragma once
36 | 
37 | #include <limits>
38 | 
39 | #include "cutlass/cutlass.h"
40 | #include "cutlass/device_kernel.h"
41 | #include "cutlass/conv/convolution.h"
42 | 
43 | /////////////////////////////////////////////////////////////////////////////////////////////////
44 | 
45 | namespace cutlass {
46 | namespace conv {
47 | namespace device {
48 | 
49 | /////////////////////////////////////////////////////////////////////////////////////////////////
50 | 
51 | template <typename ImplicitGemmKernel_>
52 | class CuSyncImplicitGemmConvolution {
53 | public:
54 | 
55 |   using UnderlyingKernel = ImplicitGemmKernel_;
56 | 
57 |   using ElementA = typename UnderlyingKernel::ElementA;
58 |   using LayoutA = typename UnderlyingKernel::LayoutA;
59 |   using ElementB = typename UnderlyingKernel::ElementB;
60 |   using LayoutB = typename UnderlyingKernel::LayoutB;
61 |   using ElementC = typename UnderlyingKernel::ElementC;
62 |   using LayoutC = typename UnderlyingKernel::LayoutC;
63 |   using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
64 |   using ElementCompute = typename UnderlyingKernel::ElementCompute;
65 |   using OperatorClass = typename UnderlyingKernel::OperatorClass;
66 |   using ArchTag = typename UnderlyingKernel::ArchTag;
67 |   using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
68 |   using WarpShape = typename UnderlyingKernel::WarpShape;
69 |   using InstructionShape = typename UnderlyingKernel::InstructionShape;
70 |   using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
71 |   using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
72 |   static int const kStages = UnderlyingKernel::kStages;
73 |   static int const kConvDim = UnderlyingKernel::kConvDim;
74 |   using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
75 |   using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
76 |   using MathOperator = typename UnderlyingKernel::MathOperator;
77 | 
78 |   static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
79 |   static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
80 |   static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
81 |   static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
82 | 
83 |   static int const kWarpCount =
84 |     (ThreadblockShape::kM / WarpShape::kM) *
85 |     (ThreadblockShape::kN / WarpShape::kN) *
86 |     (ThreadblockShape::kK / WarpShape::kK);
87 | 
88 |   /// Argument structure
89 |   using Arguments = typename UnderlyingKernel::Arguments;
90 | 
91 |   /// Kernel parameters object
92 |   typename UnderlyingKernel::Params params_;
93 | 
94 | public:
95 | 
96 |   /// Constructs Implicit GEMM
97 |   CuSyncImplicitGemmConvolution() { }
98 | 
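/*
 * Worked example (illustrative): with ThreadblockShape = 128x128x32 and
 * WarpShape = 64x64x32, kWarpCount = (128/64) * (128/64) * (32/32) = 4,
 * so run() below launches thread blocks of 32 * 4 = 128 threads.
 */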
99 |   /// Determines whether the Implicit GEMM can execute the given problem.
100 |   static Status can_implement(Arguments const &args) {
101 | 
102 |     // dispatch to iterators
103 |     Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size);
104 |     if (Status::kSuccess != status) {
105 |       return status;
106 |     }
107 | 
108 |     status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size);
109 |     if (Status::kSuccess != status) {
110 |       return status;
111 |     }
112 | 
113 |     // check group conv constraint
114 |     if (args.problem_size.groups != 1) {
115 |       if (kGroupMode == conv::GroupMode::kNone) {
116 |         return Status::kErrorInvalidProblem;
117 |       }
118 | 
119 |       // C and K should be multiples of groups
120 |       if (args.problem_size.K % args.problem_size.groups ||
121 |           args.problem_size.C % args.problem_size.groups) {
122 |         return Status::kErrorInvalidProblem;
123 |       }
124 | 
125 |       // split-k is not supported
126 |       if (args.problem_size.split_k_slices != 1) {
127 |         return Status::kErrorInvalidProblem;
128 |       }
129 | 
130 |       int k_per_group = args.problem_size.K / args.problem_size.groups;
131 |       // k_per_group should be a multiple of ThreadblockShape N; one CTA calculates one group
132 |       if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) {
133 |         return Status::kErrorInvalidProblem;
134 |       }
135 |       // ThreadblockShape::kN should be divisible by k_per_group; one CTA calculates multiple groups
136 |       if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) {
137 |         return Status::kErrorInvalidProblem;
138 |       }
139 | 
140 |       // current optimized iterator algo only supports SingleGroup mode
141 |       if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized &&
142 |           kGroupMode != conv::GroupMode::kSingleGroup) {
143 |         return Status::kErrorInvalidProblem;
144 |       }
145 |     }
146 | 
147 |     static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess;
148 |     if (kConvolutionalOperator == conv::Operator::kFprop) {
149 |       if (args.problem_size.K % kAlignmentC)
150 |         return Status::kErrorMisalignedOperand;
151 |     } else if (kConvolutionalOperator == conv::Operator::kDgrad) {
152 |       if (args.problem_size.C % kAlignmentC)
153 |         return Status::kErrorMisalignedOperand;
154 |     } else if (kConvolutionalOperator == conv::Operator::kWgrad) {
155 |       if (args.problem_size.C % kAlignmentC)
156 |         return Status::kErrorMisalignedOperand;
157 |     }
158 | 
159 |     // check for unsupported problem sizes for strided dgrad implementation
160 |     if (kConvolutionalOperator == conv::Operator::kDgrad &&
161 |         kStrideSupport == conv::StrideSupport::kStrided) {
162 | 
163 |       // split-k (serial or parallel) is not supported for strided dgrad
164 |       if(args.problem_size.split_k_slices > 1) {
165 |         return Status::kErrorNotSupported;
166 |       }
167 | 
168 |       // dilation > {1x1} is not supported for strided dgrad
169 |       if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) {
170 |         return Status::kErrorNotSupported;
171 |       }
172 |     }
173 | 
174 |     // Determine grid shape
175 |     ThreadblockSwizzle threadblock_swizzle;
176 | 
177 |     dim3 grid = threadblock_swizzle.get_grid_shape(
178 |       threadblock_swizzle.get_tiled_shape(
179 |         kConvolutionalOperator,
180 |         args.problem_size,
181 |         {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
182 |         args.problem_size.split_k_slices));
183 | 
184 |     if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
185 |           grid.z <= std::numeric_limits<uint16_t>::max())) {
186 | 
187 |       return Status::kErrorInvalidProblem;
188 |     }
189 | 
190 |     return Status::kSuccess;
191 |   }
192 | 
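/*
 * Worked example (illustrative): with SplitKMode::kParallel,
 * split_k_slices = 4, a 4-byte ElementAccumulator, and a 1024x1024 output
 * tensor, get_workspace_size() below reserves
 * 4 bytes * (1024 * 1024) elements * 4 partials = 16 MB, which a separate
 * reduction kernel must fold into the final tensor. With kSerial it is only
 * one int per output tile, used as a semaphore.
 */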
193 |   /// Gets the workspace size
194 |   static size_t get_workspace_size(Arguments const &args) {
195 | 
196 |     size_t workspace_bytes = 0;
197 | 
198 |     // Determine grid shape
199 |     ThreadblockSwizzle threadblock_swizzle;
200 | 
201 |     cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
202 |       kConvolutionalOperator,
203 |       args.problem_size,
204 |       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
205 |       args.problem_size.split_k_slices);
206 | 
207 |     if(args.split_k_mode == SplitKMode::kParallel) {
208 | 
209 |       // Split-K parallel: CTAs in the k-dimension write partial results to a temporary workspace.
210 |       // The user needs to call a reduction operator to obtain the final output tensor.
211 |       workspace_bytes =
212 |         sizeof(ElementAccumulator) *
213 |         size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) *
214 |         size_t(grid_tiled_shape.k());
215 |     }
216 | 
217 |     else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) {
218 | 
219 |       // Split-K serial: The user workspace is used to store a semaphore and serialize writing the
220 |       // final reduced output to the user's output tensor.
221 |       workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
222 |     }
223 | 
224 |     return workspace_bytes;
225 |   }
226 | 
227 |   /// Initializes GEMM state from arguments.
228 |   Status initialize(
229 |     Arguments const &args,
230 |     void *workspace = nullptr,
231 |     cudaStream_t stream = nullptr) {
232 | 
233 |     if (args.problem_size.split_k_slices > 1) {
234 | 
235 |       if (!workspace) {
236 |         return Status::kErrorWorkspaceNull;
237 |       }
238 | 
239 |       cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
240 | 
241 |       if (status != cudaSuccess) {
242 |         return Status::kErrorInternal;
243 |       }
244 |     }
245 | 
246 |     // initialize the params structure from the arguments
247 |     params_ = typename UnderlyingKernel::Params(
248 |       args,
249 |       static_cast<int *>(workspace)
250 |     );
251 | 
252 |     int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
253 | 
254 |     if (smem_size >= (48 << 10)) {
255 |       cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<UnderlyingKernel>,
256 |                                                 cudaFuncAttributeMaxDynamicSharedMemorySize,
257 |                                                 smem_size);
258 | 
259 |       if (result != cudaSuccess) {
260 |         return Status::kErrorInternal;
261 |       }
262 |     }
263 | 
264 |     return Status::kSuccess;
265 |   }
266 | 
267 |   /// Updates GEMM state from the arguments.
268 |   Status update(Arguments const &args, void *workspace = nullptr) {
269 | 
270 |     // update the params structure from the arguments
271 |     params_.ptr_A = args.ref_A.data();
272 |     params_.ptr_B = args.ref_B.data();
273 |     params_.ptr_C = args.ref_C.data();
274 |     params_.ptr_D = args.ref_D.data();
275 |     params_.output_op = args.output_op;
276 |     params_.semaphore = static_cast<int *>(workspace);
277 |     params_.custage = args.custage;
278 | 
279 |     return Status::kSuccess;
280 |   }
281 | 
282 |   /// Runs the kernel using initialized state.
283 |   Status run(cudaStream_t stream = nullptr) {
284 | 
285 |     ThreadblockSwizzle threadblock_swizzle;
286 | 
287 |     dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
288 |     dim3 block(32 * kWarpCount, 1, 1);
289 | 
290 |     int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage));
291 | 
292 |     cutlass::Kernel<UnderlyingKernel><<<grid, block, smem_size, stream>>>(params_);
293 | 
294 |     cudaError_t result = cudaGetLastError();
295 | 
296 |     return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
297 |   }
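/*
 * Illustrative note: for a 128x128x32 threadblock tile with half-precision
 * operands and kStages = 2, SharedStorage is on the order of
 * 2 stages * (128*32 + 32*128) elements * 2 bytes = 32 KB, which fits the
 * default 48 KB limit; deeper multistage pipelines exceed it, which is why
 * initialize() raises cudaFuncAttributeMaxDynamicSharedMemorySize before
 * run() launches with smem_size bytes of dynamic shared memory.
 */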
298 | 
299 | 
300 |   /// Runs the kernel using initialized state.
301 |   Status operator()(cudaStream_t stream = nullptr) {
302 |     return run(stream);
303 |   }
304 | 
305 |   /// Initializes state from arguments, then runs the kernel.
306 |   Status operator()(
307 |     Arguments const &args,
308 |     void *workspace = nullptr,
309 |     cudaStream_t stream = nullptr) {
310 | 
311 |     Status status = initialize(args, workspace, stream);
312 | 
313 |     if (status == Status::kSuccess) {
314 |       status = run(stream);
315 |     }
316 | 
317 |     return status;
318 |   }
319 | };
320 | 
321 | /////////////////////////////////////////////////////////////////////////////////////////////////
322 | 
323 | }
324 | }
325 | }
326 | 
327 | /////////////////////////////////////////////////////////////////////////////////////////////////
328 | 
--------------------------------------------------------------------------------
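Taken together, can_implement, get_workspace_size, initialize, and run give the usual CUTLASS-style driver flow. A minimal sketch of a caller follows; it is illustrative only -- the concrete kernel type and the construction of Arguments are not shown by this header, and `runConv` is a hypothetical helper name:

```cpp
// Illustrative sketch, not part of the repository. ConvKernel stands in for
// a concrete cusyncdefault_conv2d_fprop kernel type.
#include <cuda_runtime.h>
#include "cutlass/conv/device/cusyncimplicit_gemm_convolution.h"

template <typename ConvKernel>
cutlass::Status runConv(
    typename cutlass::conv::device::CuSyncImplicitGemmConvolution<ConvKernel>::Arguments const &args,
    cudaStream_t stream) {
  using Conv = cutlass::conv::device::CuSyncImplicitGemmConvolution<ConvKernel>;

  // Reject unsupported problem shapes before allocating anything.
  cutlass::Status status = Conv::can_implement(args);
  if (status != cutlass::Status::kSuccess) return status;

  // Split-k problems need a workspace; others return 0 bytes.
  void *workspace = nullptr;
  size_t bytes = Conv::get_workspace_size(args);
  if (bytes > 0 && cudaMalloc(&workspace, bytes) != cudaSuccess)
    return cutlass::Status::kErrorInternal;

  Conv conv_op;
  status = conv_op.initialize(args, workspace, stream);
  if (status == cutlass::Status::kSuccess)
    status = conv_op.run(stream);  // equivalently: conv_op(stream)

  // run() is asynchronous; wait before releasing the workspace it may use.
  cudaStreamSynchronize(stream);
  cudaFree(workspace);
  return status;
}
```

The second operator() overload collapses initialize-then-run into a single call when no error handling between the two steps is needed.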
/src/include/cutlass/cusync-cutlass/include/cutlass/gemm/threadblock/cusync_threadblock_swizzle.h:
--------------------------------------------------------------------------------
1 | /////////////////////////////////////////////////////////////////////////////////////////////////
2 | 
3 | #include <cutlass/gemm/threadblock/threadblock_swizzle.h>
4 | 
5 | namespace cutlass {
6 | namespace gemm {
7 | namespace threadblock {
8 | 
9 | /////////////////////////////////////////////////////////////////////////////////////////////////
10 | 
11 | /// Adds two methods over GemmHorizontalThreadblockSwizzle
12 | struct CuSyncGemmHorizontalThreadblockSwizzle : public GemmHorizontalThreadblockSwizzle {
13 |   /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
14 |   /// block_idx are already transposed in the kernel grid,
15 |   /// so x is the column dim and y is the row dim
16 |   CUTLASS_DEVICE
17 |   GemmCoord get_tile_offset(int log_tile) const {
18 |     int block_idx_x = RematerializeBlockIdxX();
19 |     int block_idx_y = RematerializeBlockIdxY();
20 |     int block_idx_z = RematerializeBlockIdxZ();
21 | 
22 |     return GemmCoord{(block_idx_y >> log_tile),  //
23 |                      (block_idx_x << log_tile) + ((block_idx_y) & ((1 << (log_tile)) - 1)),
24 |                      block_idx_z};
25 |   }
26 | 
27 |   CUTLASS_DEVICE
28 |   GemmCoord get_tile_offset(int log_tile, int block_idx_x, int block_idx_y, int block_idx_z) const {
29 |     return GemmCoord{(block_idx_y >> log_tile),  //
30 |                      (block_idx_x << log_tile) + ((block_idx_y) & ((1 << (log_tile)) - 1)),
31 |                      block_idx_z};
32 |   }
33 | 
34 |   CUTLASS_HOST_DEVICE
35 |   GemmCoord get_tiled_shape(
36 |     GemmCoord problem_size,
37 |     GemmCoord tile_size,
38 |     int split_k_slices) const {
39 |     return GemmHorizontalThreadblockSwizzle::get_tiled_shape(problem_size, tile_size, split_k_slices);
40 |   }
41 |   CUTLASS_HOST_DEVICE
42 |   GemmCoord get_tiled_shape(
43 |     cutlass::conv::Operator conv_operator,
44 |     cutlass::conv::Conv2dProblemSize const &problem_size,
45 |     GemmCoord tile_size,
46 |     int split_k_slices) const {
47 | 
48 |     gemm::GemmCoord implicit_gemm_problem_size =
49 |       cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size);
50 | 
51 |     return GemmHorizontalThreadblockSwizzle::get_tiled_shape(
52 |       implicit_gemm_problem_size, tile_size, split_k_slices);
53 |   }
54 | };
55 | 
56 | template <int N = 1>
57 | struct CuSyncGemmIdentityThreadblockSwizzle : public GemmIdentityThreadblockSwizzle<N> {
58 |   /// get_tile_offset based on custom block indices
59 |   CUTLASS_DEVICE
60 |   GemmCoord get_tile_offset(int log_tile, int block_idx_x, int block_idx_y, int block_idx_z) const {
61 |     return GemmCoord{(block_idx_x >> log_tile),  //
62 |                      (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)),
63 |                      block_idx_z};
64 |   }
65 | };
66 | 
67 | /////////////////////////////////////////////////////////////////////////////////////////////////
68 | 
69 | }  // namespace threadblock
70 | }  // namespace gemm
71 | }  // namespace cutlass
--------------------------------------------------------------------------------
/src/ml-bench/README.md:
--------------------------------------------------------------------------------
1 | Evaluation
2 | ----------------
3 | 
4 | This directory contains instructions to run the evaluation and reproduce the results.
5 | 
6 | ## Docker Image
7 | 
8 | The Dockerfile with all prerequisites is in the repository root. Create and start the docker container using
9 | 
10 | ```
11 | docker build -t cusync-cgo-24 .
12 | docker run -it --gpus all cusync-cgo-24
13 | cd /cusync
14 | ```
15 | 
16 | ## Prerequisites (Native Execution)
17 | 
18 | ### Linux Installation
19 | We recommend Ubuntu 22.04 as the Linux OS. We have not tested our artifact with
20 | any other OS, but we believe Ubuntu 20.04 and 23.04 should
21 | also work.
22 | 
23 | ### Install Dependencies
24 | Execute the following commands to install dependencies.
25 | 
26 | ```
27 | sudo apt update
28 | sudo apt install gcc linux-headers-$(uname -r) make g++ git python3 wget unzip python3-pip build-essential cmake
29 | ```
30 | 
31 | We use CUDA 12.2 in our experiments.
32 | 
33 | Install PyTorch using pip.
34 | ```
35 | sudo pip3 install torch torchvision torchaudio
36 | ```
37 | 
38 | ## Run Experiments
39 | 
40 | ### MLP and Attention Results
41 | 
42 | The following commands run all experiments and gather the results. The last argument of `eval_mlp.py` is the target GPU (`a100` or `v100`);
43 | `allreduce_times.py` takes the model hidden size (12288 for GPT-3).
44 | 
45 | ```
46 | cd transformer
47 | python3 eval_mlp.py mlp gpt3 a100
48 | python3 eval_mlp.py attention gpt3 a100
49 | python3 eval_mlp.py mlp llama a100
50 | python3 eval_mlp.py attention llama a100
51 | python3 allreduce_times.py 12288
52 | ```
53 | 
54 | ### Conv2D Results
55 | The following commands run all experiments and gather the results.
56 | 
57 | ```
58 | cd volta_conv2d
59 | python3 eval_resnet.py resnet
60 | python3 eval_resnet.py vgg
61 | ```
62 | 
63 | ### Generate Plots
64 | ```
65 | cd plots
66 | make all -j
67 | ```
--------------------------------------------------------------------------------
/src/ml-bench/common.mk:
--------------------------------------------------------------------------------
1 | NVCC=/usr/local/cuda/bin/nvcc -std=c++17
2 | ROOT=../../../
3 | CUSYNC=$(ROOT)/src/include
4 | CUSYNC_SRC=$(ROOT)/src/
5 | NV_CUTLASS=$(ROOT)/src/include/cutlass/nvidia-cutlass
6 | CUSYNC_CUTLASS=$(ROOT)/src/include/cutlass/cusync-cutlass
7 | BUILD=build
--------------------------------------------------------------------------------
/src/ml-bench/plots/Makefile:
--------------------------------------------------------------------------------
1 | TRANSFORMER_RESULTS=../transformer/results/
2 | CONV_RESULTS=../volta_conv2d/results/
3 | 
4 | .SECONDEXPANSION:
5 | 
6 | mlp-gpt3-v100: $(TRANSFORMER_RESULTS)/$$@.csv
7 | 	python3 plotGPT.py $(TRANSFORMER_RESULTS)/$@.csv $@.png
8 | 
9 | mlp-gpt3-a100: $(TRANSFORMER_RESULTS)/$$@.csv
10 | 	python3 plotGPT.py $(TRANSFORMER_RESULTS)/$@.csv $@.png
11 | 
12 | mlp-llama-v100: $(TRANSFORMER_RESULTS)/$$@.csv
13 | 	python3 plotGPT.py $(TRANSFORMER_RESULTS)/$@.csv $@.png
14 | 
15 | mlp-llama-a100: $(TRANSFORMER_RESULTS)/$$@.csv
16 | 	python3 plotGPT.py $(TRANSFORMER_RESULTS)/$@.csv $@.png
17 | 
18 | all: mlp-gpt3-v100 mlp-gpt3-a100 mlp-llama-v100 mlp-llama-a100
19 | 
20 | clean:
21 | 	rm -rf *.pdf ; rm -rf *.png
22 | 
--------------------------------------------------------------------------------
/src/ml-bench/plots/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | colors = ["#66c2a5", "#fc8d62", "#8da0cb","#e78ac3"] 5 | -------------------------------------------------------------------------------- /src/ml-bench/plots/mlp-gpt3-a100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/cusync/345b60a137f89716047ce55590d5aed5f9adebb0/src/ml-bench/plots/mlp-gpt3-a100.png -------------------------------------------------------------------------------- /src/ml-bench/plots/mlp-gpt3-v100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/cusync/345b60a137f89716047ce55590d5aed5f9adebb0/src/ml-bench/plots/mlp-gpt3-v100.png -------------------------------------------------------------------------------- /src/ml-bench/plots/mlp-llama-a100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/cusync/345b60a137f89716047ce55590d5aed5f9adebb0/src/ml-bench/plots/mlp-llama-a100.png -------------------------------------------------------------------------------- /src/ml-bench/plots/mlp-llama-v100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/cusync/345b60a137f89716047ce55590d5aed5f9adebb0/src/ml-bench/plots/mlp-llama-v100.png -------------------------------------------------------------------------------- /src/ml-bench/plots/plotGPT.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import sys 5 | import csv 6 | from common import * 7 | import math 8 | import matplotlib.ticker as mtick 9 | 10 | csv_file = sys.argv[1] 11 | pdf_name = sys.argv[2] 12 | 13 | only_one_h = True 14 | attention_or_mlp = "attention" if ("attention" in csv_file) else "mlp" 15 | model = "gpt3" if "gpt3" in csv_file else "llama" 16 | gpu = "a100" if "a100" in csv_file else ("v100" if "v100" in csv_file else "") 17 | 18 | only_streamk = False 19 | if len(sys.argv) > 3 and sys.argv[3] == "only_streamk": 20 | only_streamk = True 21 | import math 22 | import csv 23 | mInd = 0 24 | seqInd = 1 25 | hInd = 2 26 | syncTypeInd = 3 27 | streamkInd = 4 28 | torchInd = 4 29 | baselineInd = 4 30 | stdevBaselineInd = 5 31 | # matmul1Ind = 6 32 | # matmul2Ind = 7 33 | # maxtbsInd = 8 34 | # matmul1TbsInd = 9 35 | # matmul2TbsInd = 10 36 | overlapInd = 8 37 | stdevOverlapInd = 9 38 | 39 | def load_csv(csv_file): 40 | data = [] 41 | with open(csv_file, 'r') as f: 42 | csv_reader = csv.reader(f,delimiter='&') 43 | for i, row in enumerate(csv_reader): 44 | row_new = [] 45 | for e in row: 46 | row_new.append(e.strip()) 47 | row = row_new 48 | data += [row] 49 | 50 | return data 51 | 52 | data = load_csv(csv_file) 53 | 54 | import matplotlib.pyplot as plt 55 | import numpy as np 56 | if attention_or_mlp == "attention": 57 | width = 0.3 58 | else: 59 | width = 0.4 60 | 61 | # fig = plt.subplots(figsize =(10, 7)) 62 | m = [] 63 | h = [] 64 | torchT = [] 65 | baseline = [] 66 | stdevBaseline = [] 67 | matmul1 = [] 68 | softmax = [] 69 | matmul2 = [] 70 | maxtbs = [] 71 | matmul1Tbs = [] 72 | matmul2Tbs = [] 73 | rowOverlap = [] 74 | stdevRowOverlap = [] 75 | tileOverlap = [] 76 | stdevTileOverlap = [] 77 | stridedTileOverlap = [] 78 | stdevStridedTileOverlap = [] 79 | maxSpeedup = [] 80 | analyticalOverlapTimes = [] 81 | streamK = [] 82 | stdevstreamk = [] 83 | 84 | rowIdx = 0 85 | 86 | #Time t is in microseconds 87 | def flops(model, m, t): 88 | t = t/1e6 89 | if model == "llama": 90 | H = 8192 91 | FFN = ((H+128-1)//128)*128 92 | return (2*(m * 2 * FFN * H + m*FFN*H)/t)/1e12 93 | elif model == "gpt3": 94 | H = 12288 95 | FFN = H/2 96 | return (2*(m * FFN * H + m*FFN*H)/t)/1e12 97 | 98 | def flops_for_all_rows(model, batches, times): 99 | flops_list = [] 100 | for m,t in zip(batches, times): 101 | flops_list += [flops(model, m, t)] 102 | return flops_list 103 | 104 | while rowIdx < len(data): 105 | # print(rowIdx) 106 | row = data[rowIdx] 107 | i = 0 108 | while rowIdx < len(data) and i < (4 if attention_or_mlp == 'mlp' else 6): 109 | row = data[rowIdx] 110 | if row[syncTypeInd] == 'streamk': 111 | streamK += [float(row[streamkInd])] 112 | elif row[syncTypeInd] == 'rowsync': 113 | rowOverlap += [float(row[overlapInd])] 114 | elif row[syncTypeInd] == 'baseline': 115 | m += [int(row[mInd])] 116 | baseline += [float(row[baselineInd])] 117 | elif row[syncTypeInd] == 'tilesync': 118 | tileOverlap += [float(row[overlapInd])] 119 | elif row[syncTypeInd] == 'stridedsync': 120 | stridedTileOverlap += [float(row[overlapInd])] 121 | elif row[syncTypeInd] == 'torch': 122 | torchT += [float(row[torchInd])] 123 | rowIdx += 1 124 | i += 1 125 | 126 | # baseline = flops_for_all_rows(model, m, baseline) 127 | # streamK = flops_for_all_rows(model, m, streamK) 128 | # rowOverlap = flops_for_all_rows(model, m, rowOverlap) 129 | # tileOverlap = flops_for_all_rows(model, m, tileOverlap) 130 | 131 | if __name__ == "__main__": 132 | # secFactor = 1e3 if (secs == "ms") else 1e6 133 | torchT = np.array(torchT) 134 | baseline = 
np.array(baseline) 135 | ind = np.arange(len(baseline)) 136 | matmul1 = np.array(matmul1) 137 | matmul2 = np.array(matmul2) 138 | softmax = np.array(softmax) 139 | stdevBaseline = np.array(stdevBaseline) 140 | rowOverlap = np.array(rowOverlap) 141 | stdevRowOverlap = np.array(stdevRowOverlap) 142 | tileOverlap = np.array(tileOverlap) 143 | streamK = np.array(streamK) 144 | stdevTileOverlap = np.array(stdevTileOverlap) 145 | analyticalOverlapTimes = np.array(analyticalOverlapTimes) 146 | 147 | cutlassSpeedup = (torchT - baseline)/torchT*100 148 | cusync = np.minimum(rowOverlap, tileOverlap) 149 | cusyncSpeedup = (torchT - cusync)/torchT*100 150 | cusyncOverCUTLASS = (baseline - cusync)/baseline*100 151 | if gpu == "a100": 152 | streamKSpeedup = (torchT - streamK)/torchT*100 153 | else: 154 | streamKSpeedup = np.array([0]) 155 | 156 | cusyncSpeedup = np.clip(cusyncSpeedup, -5, 45) 157 | cutlassSpeedup = np.clip(cutlassSpeedup, -5, 45) 158 | streamKSpeedup = np.clip(streamKSpeedup, -5, 45) 159 | cusyncOverCUTLASS = np.clip(cusyncOverCUTLASS, -5, 45) 160 | 161 | # analyticalSpeedup = baseline/analyticalOverlapTimes 162 | fig, ax2 = plt.subplots(1,1,sharex=True) 163 | p0 = ax2.plot(ind, cutlassSpeedup, 'o', color=colors[0]) 164 | p1 = ax2.plot(ind, cusyncOverCUTLASS, marker='+', color=colors[1]) 165 | p2 = ax2.plot(ind, cusyncSpeedup, 'x', color=colors[2]) 166 | if gpu == "a100": 167 | p3 = ax2.plot(ind, streamKSpeedup, 's',color=colors[3]) 168 | 169 | # if attention_or_mlp == "attention": 170 | # stridedTileSpeedup = (baseline - stridedTileOverlap)/baseline * 100 171 | # p3 = ax2.plot(ind, stridedTileSpeedup,'v',color=colors[2]) 172 | # print(stridedTileSpeedup) 173 | # for i, f in enumerate(np.maximum(np.maximum(rowSpeedup, tileSpeedup), stridedTileSpeedup)): 174 | # ax2.text(i, f+1, "%.0f"%round(f, 0), color = 'black', ha = 'center', rotation=0) 175 | # else: 176 | # for i, f in enumerate(cusyncSpeedup): 177 | # ax2.text(i*2+2, f+1, "%.0f"%round(f,0), color = 'black', ha = 'center', rotation=0) 178 | 179 | # p4 = ax2.plot(ind, streamKSpeedup, 'x',color=colors[3]) 180 | 181 | # p3 = ax2.plot(list(range(0, len(data)//2)), analyticalSpeedup) 182 | 183 | # for bar1, d in zip(p1, cusyncSpeedup): 184 | # ax2.text(bar1.get_x()+bar1.get_width()/2-0.05, bar1.get_height()+0.5, "%.0f"%(round(d,0)), 185 | # color = 'black', ha = 'center', va = 'center', rotation=0) 186 | 187 | # for bar1, speedup in zip(p3, fastkronspeedup): 188 | # ax2.text(bar1.get_x()+bar1.get_width()/2+0.04, bar1.get_height()+0.05, r"%.2f$\times$"%(1/speedup), color = 'black', ha = 'center', va = 'center', rotation=0, fontsize='large') 189 | # if only_one_h and attention_or_mlp == True: 190 | # plt.ylim(0.6, 1.3) 191 | # plt.yticks([0.6+0.1*i for i in range(0, 7)]) 192 | # else: 193 | # ax2.margins(0.02) 194 | max_speedup = max([np.amax(cusyncSpeedup), np.amax(streamKSpeedup), np.amax(cusyncOverCUTLASS)]) 195 | print(max_speedup) 196 | max_speedup = int(((max_speedup+10-1)//10)*10) 197 | print(max_speedup) 198 | plt.ylim(-5, max_speedup) 199 | plt.yticks(ticks=[-5+5*i for i in range(0, max_speedup//5+1)], 200 | labels=["%d%%"%(-5+5*i) for i in range(0, max_speedup//5 + 1)]) 201 | # ax2.yaxis.set_major_formatter(mtick.PercentFormatter(decimals=None)) 202 | # ax2.set_yticklabels(["%d%%"%(-5+5*i) for i in range(0, 9)]) 203 | # plt.yticks(["%d%(-5+5*i) for i in range(0, 7)]) 204 | # plt.xlim(-1,ind[-1]+1) 205 | # plt.title('Contribution by the teams') 206 | plt.axhline(0, color='black', ls='dotted') 207 | # plt.yticks(np.arange(0, 1.25, 
0.25)) 208 | if attention_or_mlp == "mlp": 209 | xt = list(m) 210 | plt.xticks(ind, xt, rotation=90) 211 | 212 | plt.ylabel('Percentage Improvement of X/Y') 213 | # ax2.get_yaxis().set_label_coords(-0.17,0.4) 214 | plt.xlabel("Number of Tokens in %s MLP on A100"%(model.upper())) 215 | # ax2.get_xaxis().set_label_coords(0.45,-0.4) 216 | if gpu == "a100": 217 | labels = (p0[0], p3[0], p2[0], p1[0]) 218 | legends = ('CUTLASS/PyTorch', 'StreamK/PyTorch', 'CuSync/PyTorch', 'CuSync/CUTLASS') 219 | else: 220 | labels = (p0[0], p2[0], p1[0]) 221 | legends = ('CUTLASS/PyTorch', 'CuSync/PyTorch', 'CuSync/CUTLASS') 222 | 223 | plt.legend(labels, legends, 224 | loc='upper left', bbox_to_anchor=(-0.1, 1.20), 225 | ncol=2,columnspacing=1,handlelength=1.7) 226 | else: 227 | ax2.get_yaxis().set_visible(False) 228 | ax2.get_xaxis().set_label_coords(0.45,-0.4) 229 | xt = list((2**i for i in range(0, len(ind)))) 230 | if "attention" in csv_file: 231 | xt = ["512, 0", "1024, 0", "2048, 0", "1, 512", "2, 512", "4, 512", "1, 1024", "2, 1024", "4, 1024", "1, 2048", "2, 2048", "4, 2048"] 232 | plt.xticks(ind, xt, rotation=90) 233 | if attention_or_mlp == "attention" and model == "gpt3": 234 | plt.legend((p1[0], p2[0], p3[0], p4[0]), 235 | ('RowSync', 'TileSync+WRT', 'StridedTileSync+WRT', 'StreamK'), 236 | loc='upper left', bbox_to_anchor=(-0.01, 1.16), 237 | ncol=4,columnspacing=1,handlelength=1.7) 238 | plt.xlabel("B$\\times$S, S'") 239 | 240 | plt.rcParams["font.family"] = "libertine" 241 | #FIGURES_DIR = "./" 242 | fig = plt.gcf() 243 | fig.subplots_adjust(bottom=0.1) 244 | # if only_one_h: 245 | # else: 246 | # fig.set_size_inches(8.5, 2.5) 247 | # if attention_or_mlp == "mlp" and model == "gpt3": 248 | fig.set_size_inches(4, 3) 249 | # else: 250 | # fig.set_size_inches(3.2, 2.4) 251 | # ax.set_xticks([]) 252 | FIGURES_DIR = "./" 253 | fig.savefig(FIGURES_DIR+pdf_name,bbox_inches='tight',pad_inches=0) 254 | #plt.show() 255 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/Makefile: -------------------------------------------------------------------------------- 1 | include ../common.mk 2 | 3 | ARCH_FLAGS=-gencode=arch=compute_70,code=[sm_70,compute_70] -gencode=arch=compute_80,code=[sm_80,compute_80] 4 | INCLUDES=-I$(NV_CUTLASS)/include -I$(NV_CUTLASS)/examples/common -I$(NV_CUTLASS)/tools/util/include -I$(CUSYNC_CUTLASS)/include/ -I$(CUSYNC) -I. 
5 | 
6 | DEFINES=-DCUTLASS_ENABLE_CUBLAS=1 -DCUTLASS_NAMESPACE=cutlass -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_TEST_LEVEL=0 -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1 -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0
7 | CUSYNC_SRC_FILES=$(CUSYNC_SRC)/cusync.cu
8 | 
9 | $(BUILD)/streamk: streamk.cu $(CUSYNC_SRC_FILES)
10 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3
11 | 
12 | $(BUILD)/streamk-eval: $(BUILD)/streamk-eval.cu $(CUSYNC_SRC_FILES)
13 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DEVAL_TILE_SIZES
14 | 
15 | $(BUILD)/mlp-batchedrow: mlp.cu common.h $(CUSYNC_SRC_FILES)
16 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DBATCHEDROW
17 | 
18 | $(BUILD)/mlp-tilebatchsync: mlp.cu common.h $(CUSYNC_SRC_FILES)
19 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DTILEBATCH
20 | 
21 | $(BUILD)/mlp-gpt3-rowsync: mlp.cu common.h $(CUSYNC_SRC_FILES)
22 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DROWSYNC -DMLP_GPT3
23 | 
24 | $(BUILD)/mlp-gpt3-tilesync: mlp.cu common.h $(CUSYNC_SRC_FILES)
25 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DTILESYNC -DMLP_GPT3
26 | 
27 | $(BUILD)/mlp-llama-rowsync: mlp.cu common.h $(CUSYNC_SRC_FILES)
28 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DROWSYNC -DMLP_LLAMA
29 | 
30 | $(BUILD)/mlp-llama-tilesync: mlp.cu common.h $(CUSYNC_SRC_FILES)
31 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DTILESYNC -DMLP_LLAMA
32 | 
33 | $(BUILD)/libmlp.so: mlp-lib.cu common.h $(CUSYNC_SRC_FILES)
34 | 	$(NVCC) $(DEFINES) $(INCLUDES) --shared -Xcompiler -m64,-fPIC,-Wconversion,-fno-strict-aliasing $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DROWSYNC
35 | 
36 | $(BUILD)/mlp-tilesync: mlp.cu common.h $(CUSYNC_SRC_FILES)
37 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DTILESYNC
38 | 
39 | $(BUILD)/mlp-eval-baseline: $(BUILD)/mlp-eval-baseline.cu common.h $(CUSYNC_SRC_FILES)
40 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3 -DROWSYNC
41 | 
42 | $(BUILD)/mlp-eval-rowsync: $(BUILD)/mlp-eval-rowsync.cu common.h $(CUSYNC_SRC_FILES)
43 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3
44 | 
45 | $(BUILD)/mlp-eval-tilesync: $(BUILD)/mlp-eval-tilesync.cu common.h $(CUSYNC_SRC_FILES)
46 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3
47 | 
48 | $(BUILD)/mlp-eval: $(BUILD)/mlp-eval.cu common.h $(CUSYNC_SRC_FILES)
49 | 	$(NVCC) $(DEFINES) $(INCLUDES) $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -O3
50 | 
51 | $(BUILD)/attention-rowsync: attention.cu $(CUSYNC_SRC_FILES)
52 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DROWSYNC
53 | 
54 | $(BUILD)/attention-tilesync: attention.cu $(CUSYNC_SRC_FILES)
55 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DTILESYNC
56 | 
57 | $(BUILD)/attention-gpt3-stridedsync: attention.cu $(CUSYNC_SRC_FILES)
58 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DSTRIDEDSYNC -DGPT3
59 | 
60 | $(BUILD)/attention-llama-stridedsync: attention.cu $(CUSYNC_SRC_FILES)
61 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DSTRIDEDSYNC -DLLaMA
62 | 
63 | $(BUILD)/attention-eval-baseline: $(BUILD)/attention-eval-baseline.cu $(CUSYNC_SRC_FILES)
64 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DTILESYNC
65 | 
66 | $(BUILD)/attention-eval-rowsync: $(BUILD)/attention-eval-rowsync.cu $(CUSYNC_SRC_FILES)
67 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DROWSYNC
68 | 
69 | $(BUILD)/attention-eval-tilesync: $(BUILD)/attention-eval-tilesync.cu $(CUSYNC_SRC_FILES)
70 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $@ -Xptxas -v -lcublas -Xcompiler=-fopenmp -DTILESYNC
71 | 
72 | $(BUILD)/attention-gpt3-eval-stridedsync: $(BUILD)/attention-eval-stridedsync.cu $(CUSYNC_SRC_FILES)
73 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $(BUILD)/attention-eval-stridedsync -Xptxas -v -lcublas -Xcompiler=-fopenmp -DSTRIDEDSYNC -DGPT3
74 | 
75 | $(BUILD)/attention-llama-eval-stridedsync: $(BUILD)/attention-eval-stridedsync.cu $(CUSYNC_SRC_FILES)
76 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $(BUILD)/attention-eval-stridedsync -Xptxas -v -lcublas -Xcompiler=-fopenmp -DSTRIDEDSYNC -DLLaMA
77 | 
78 | $(BUILD)/attention-gpt3-eval: $(BUILD)/attention-eval.cu $(CUSYNC_SRC_FILES)
79 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $(BUILD)/attention-eval -Xptxas -v -lcublas -Xcompiler=-fopenmp -DGPT3
80 | 
81 | $(BUILD)/attention-llama-eval: $(BUILD)/attention-eval.cu $(CUSYNC_SRC_FILES)
82 | 	$(NVCC) $(DEFINES) $(INCLUDES) -O3 $(ARCH_FLAGS) -DNDEBUG $< $(CUSYNC_SRC_FILES) -o $(BUILD)/attention-eval -Xptxas -v -lcublas -Xcompiler=-fopenmp -DLLaMA
83 | 
84 | clean:
85 | 	rm -f $(BUILD)/mlp-* $(BUILD)/attention-*
86 | 
--------------------------------------------------------------------------------
/src/ml-bench/transformer/allreduce_times.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import torch
4 | import torch.distributed as dist
5 | import torch.multiprocessing as mp
6 | import time
7 | import sys
8 | 
9 | H = int(sys.argv[1])
10 | 
11 | def run(rank, size):
12 |     """Time all-reduce of b*H half-precision elements across all ranks."""
13 |     for b in [1,2,4,8,16,32,64,128,256,512,1024,2048]:
14 |         inT = torch.ones(b*H,dtype=torch.half).cuda(rank)
15 | 
16 |         for i in range(10):
17 |             dist.all_reduce(inT)
18 |         torch.cuda.synchronize()
19 |         dist.barrier()
20 | 
21 |         s = time.time()
22 |         for i in range(100):
23 |             dist.all_reduce(inT)
24 |         torch.cuda.synchronize()
25 |         e = time.time()
26 | 
27 |         if rank == 0:
28 |             print(f"{b} & {H} & {((e - s)/100.)*1000}")
29 | 
30 | def init_process(rank, size, fn, backend='nccl'):
31 |     """ Initialize the distributed environment. """
32 |     os.environ['MASTER_ADDR'] = '127.0.0.1'
33 |     os.environ['MASTER_PORT'] = '29500'
34 |     dist.init_process_group(backend, rank=rank, world_size=size)
35 |     fn(rank, size)
36 | 
37 | if __name__ == "__main__":
38 |     size = 8
39 |     processes = []
40 |     mp.set_start_method("spawn")
41 |     for rank in range(size):
42 |         p = mp.Process(target=init_process, args=(rank, size, run))
43 |         p.start()
44 |         processes.append(p)
45 | 
46 |     for p in processes:
47 |         p.join()
--------------------------------------------------------------------------------
/src/ml-bench/transformer/common.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | //
4 | #include <cuda_runtime.h>
5 | #include <cublas_v2.h>
6 | #include <curand_kernel.h>
7 | 
8 | #include "cutlass/cutlass.h"
9 | #include "cutlass/gemm/device/cusyncgemm.h"
10 | #include "cutlass/gemm/device/gemm.h"
11 | #include "cutlass/gemm/threadblock/cusync_threadblock_swizzle.h"
12 | 
13 | #include "cutlass/util/host_tensor.h"
14 | #include "cutlass/util/reference/device/gemm.h"
15 | #include "cutlass/util/reference/host/tensor_compare.h"
16 | #include "cutlass/util/reference/host/tensor_copy.h"
17 | #include "cutlass/util/reference/host/tensor_fill.h"
18 | #include "cutlass/util/tensor_view_io.h"
19 | #include "helper.h"
20 | #include <assert.h>
21 | 
22 | #include <stdio.h>
23 | #include <stdlib.h>
24 | #include <sys/time.h>
25 | 
26 | #define DIVUP(x, y) (((x) + (y) - 1)/(y))
27 | 
28 | static double convertTimeValToDouble(struct timeval _time) {
29 |   return ((double)_time.tv_sec)*1e6 + ((double)_time.tv_usec);
30 | }
31 | 
32 | static struct timeval getTimeOfDay () {
33 |   struct timeval _time;
34 | 
35 |   if (gettimeofday (&_time, NULL) == -1) {
36 |     fprintf (stderr, "gettimeofday returned -1\n");
37 |     perror ("");
38 |     abort ();
39 |   }
40 | 
41 |   return _time;
42 | }
43 | 
44 | static double timeInMicroSeconds() {
45 |   return convertTimeValToDouble(getTimeOfDay());
46 | }
47 | 
48 | static double getCurrentTime() {
49 |   return timeInMicroSeconds();
50 | }
51 | 
52 | #define CUBLASCHECK(cmd) do {                      \
53 |   cublasStatus_t e = cmd;                          \
54 |   if (e != CUBLAS_STATUS_SUCCESS) {                \
55 |     printf("Failed: CUBLAS error %s: %d '%d'\n",   \
56 |            __FILE__, __LINE__, cmd);               \
57 |     assert(false);                                 \
58 |   }                                                \
59 | } while(0)
60 | 
61 | template<typename T, typename AT>
62 | __global__ void ref_cudamatmul(uint32_t M, uint32_t N, uint32_t K,
63 |                                T* A, T* B, T* C) {
64 |   int ROW = blockIdx.y*blockDim.y+threadIdx.y;
65 |   int COL = blockIdx.x*blockDim.x+threadIdx.x;
66 | 
67 |   if (ROW < M && COL < N) {
68 |     AT tmpSum = (AT)0.0f;
69 |     // each thread computes one element of the block sub-matrix
70 |     for (uint32_t i = 0; i < K; i++) {
71 |       tmpSum += (AT)(A[ROW * K + i]) * (AT)(B[i * N + COL]);
72 |     }
73 | 
74 |     C[ROW * N + COL] = (T)tmpSum;
75 |   }
76 | }
77 | 
78 | template<typename T, typename AT>
79 | void ref_matmul(uint32_t M, uint32_t N, uint32_t K, T* mat1, T* mat2, T* host_res) {
80 |   T* dev_refC = NULL;
81 |   CUDA_CHECK(cudaMalloc(&dev_refC, sizeof(T)*M*N));
82 |   dim3 block = {32, 32, 1};
83 |   dim3 grid = {N/block.x + 1, M/block.y + 1, 1};
84 |   ref_cudamatmul<T, AT><<<grid, block>>>(M, N, K, mat1, mat2, dev_refC);
85 |   CUDA_CHECK(cudaDeviceSynchronize());
86 |   CUDA_CHECK(cudaMemcpy(host_res, dev_refC, sizeof(T)*M*N, cudaMemcpyDeviceToHost));
87 | }
88 | 
89 | template<typename T, typename AT>
90 | void ref_cpumatmul(uint32_t M, uint32_t N, uint32_t K, T* mat1, T* mat2, T* res)
91 | {
92 |   uint32_t i, j, k;
93 |   for (i = 0; i < M; i++) {
94 |     #pragma omp parallel for
95 |     for (j = 0; j < N; j++) {
96 |       AT accum = 0;
97 |       for (k = 0; k < K; k++)
98 |         accum += ((float)mat1[i*K + k]) * ((float)mat2[k*N + j]);
99 |       res[i*N + j] = T(accum);
100 |     }
101 |   }
102 | }
103 | 
104 | template<typename T>
105 | bool equals(size_t size, T* mat1, T* mat2, float err) {
106 |   bool eq = true;
107 |   for (size_t i = 0; i < size; i++) {
108 |     float e1 = (float)mat1[i];
109 |     float e2 = (float)mat2[i];
110 | 
111 |     float v = err;
112 |     bool ret = true;
113 |     if (abs(e1) < v && abs(e2) < v) {
114 | 
115 |       ret = true;
116 |     } else if (abs(e1) < v) {
117 |       ret = false;
118 |     } else if (abs(e2) < v) {
119 |       ret = false;
120 |     } else {
121 |       float err = abs(abs(e1) - abs(e2))/max(abs(e1), abs(e2));
122 |       if (err <= v) {
123 |         ret = true;
124 |       } else {
125 |         printf("243: %f , %f at %lu, %f\n", e1, e2, i, err);
126 |         ret = false;
127 |       }
128 |     }
129 | 
130 |     if (ret == false) {
131 |       // printf("%f != %f at %lu\n", e1, e2, i);
132 |       eq = false;
133 |     }
134 |   }
135 |   return eq;
136 | 
137 | }
138 | 
139 | template<typename T>
140 | __global__ void printKernel(size_t sz, T* data) {
141 |   if (threadIdx.x == 0) {
142 |     for (size_t i = 65536; i < sz; i++) {
143 |       printf("%f at %lu \n", (float)data[i], i);
144 |     }
145 |   }
146 | }
147 | 
148 | template<typename T>
149 | void memset_value(T*f, T v, size_t nelems)
150 | {
151 |   T* h_buff = (T*)malloc(sizeof(T)*nelems);
152 |   assert(h_buff != nullptr);
153 |   for (uint64_t i = 0; i < nelems; i++) {
154 |     h_buff[i] = v;
155 |   }
156 | 
157 |   CUDA_CHECK(cudaMemcpy(f, h_buff, sizeof(T)*nelems, cudaMemcpyHostToDevice));
158 |   free(h_buff);
159 | }
160 | 
161 | template<typename T>
162 | void memset_random2(T*f, T v1, T v2, size_t nelems)
163 | {
164 |   // T* h_buff = (T*)malloc(sizeof(T)*nelems);
165 |   assert(f != nullptr);
166 |   for (uint64_t i = 0; i < nelems; i++) {
167 |     if (rand()%2 == 0)
168 |       f[i] = v1;
169 |     else
170 |       f[i] = v2;
171 |     // printf("%f\n", (float)f[i]);
172 |   }
173 | 
174 |   // CUDA_CHECK(cudaMemcpy(f, h_buff, sizeof(T)*nelems, cudaMemcpyHostToDevice));
175 |   // free(h_buff);
176 | }
177 | 
178 | template<typename T>
179 | void memset_random(T*f, int numVals, T* values, size_t nelems)
180 | {
181 |   // T* h_buff = (T*)malloc(sizeof(T)*nelems);
182 |   assert(f != nullptr);
183 |   for (uint64_t i = 0; i < nelems; i++) {
184 |     f[i] = values[rand()%numVals];
185 |   }
186 | 
187 |   // CUDA_CHECK(cudaMemcpy(f, h_buff, sizeof(T)*nelems, cudaMemcpyHostToDevice));
188 |   // free(h_buff);
189 | }
190 | 
191 | __global__ void init_curand_states(curandState* states, size_t num_states)
192 | {
193 |   int thread_id = blockIdx.x*blockDim.x + threadIdx.x;
194 |   if (thread_id < num_states)
195 |     curand_init(thread_id, threadIdx.x, 0, &states[thread_id]);
196 | }
197 | 
198 | int run(int argc, char* arg[]);
199 | int main(int argc, char* argv[]) {
200 | 
201 |   // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
202 |   //
203 |   // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
204 | if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { 205 | std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; 206 | 207 | // Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op. 208 | return 0; 209 | } 210 | else { 211 | return run(argc, argv); 212 | } 213 | } 214 | 215 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/eval_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import subprocess 5 | import re 6 | import sys 7 | import os 8 | import tile_sizes_db 9 | import time 10 | 11 | attention_or_mlp = sys.argv[1].lower() 12 | model = sys.argv[2].lower() 13 | arch = sys.argv[3].lower() 14 | 15 | assert attention_or_mlp in ["attention", "mlp"] 16 | assert arch.lower() in ["v100", "a100"] 17 | 18 | baselineTimes = {} 19 | cublasTimes = {} 20 | overlappedTimes = {} 21 | minimumTimes = {} 22 | speedup = {} 23 | maxspeedup = {} 24 | import json 25 | from statistics import stdev 26 | 27 | def getAllTimes(s, START, END): 28 | '''Parse output of binaries to obtain list of times 29 | ''' 30 | alltimes = {} 31 | assert START in s 32 | assert END in s 33 | s = s[s.find(START):s.find(END)] 34 | s = s[s.find("\n"):] 35 | alljsons = [] 36 | for l in re.findall(r".+", s): 37 | j = json.loads(l) 38 | alljsons += [j] 39 | 40 | def sortkey(elem): 41 | return elem["Total"] 42 | 43 | alljsons.sort(key=sortkey) 44 | p = 0.9 45 | alljsons = alljsons[:int(len(alljsons)*0.9)] 46 | for j in alljsons: 47 | for k in j: 48 | if k not in alltimes: 49 | alltimes[k] = [] 50 | alltimes[k] += [float(j[k])] 51 | 52 | return alltimes 53 | 54 | def avg(l): 55 | return sum(l)/len(l) 56 | 57 | def slurp(path): 58 | with open(path, "r") as f: 59 | return f.read() 60 | 61 | def buildDir(f): 62 | return 'build/'+f 63 | 64 | if not os.path.exists(buildDir("")): 65 | os.mkdir(buildDir("")) 66 | 67 | def resultsDir(f): 68 | '''Results directory''' 69 | return 'results/'+f 70 | 71 | '''Make results directory if not exists''' 72 | if not os.path.exists(resultsDir("")): 73 | os.mkdir(resultsDir("")) 74 | 75 | def getStreamKTimes(output): 76 | runtime = re.findall(r'\s*Avg runtime: ([\d\.]+)', output) 77 | return float(runtime[0]) 78 | 79 | def genAndMakeStreamK(batchInfo, gemmidx): 80 | inFile = "streamk.cu" 81 | outFile = buildDir("streamk-eval.cu") 82 | tilesCode = """using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<%d, %d, %d>; 83 | using ShapeMMAWarp = cutlass::gemm::GemmShape<%d, %d, %d>;""" 84 | tileSize = batchInfo[syncPolicy]["TileSizes"] if "TileSizes" in batchInfo["baseline"] else batchInfo["TileSizes"] 85 | if len(tileSize) > 1: 86 | tilesCode = tilesCode % tuple(batchInfo["TileSizes"][gemmidx]) 87 | else: 88 | tilesCode = tilesCode % tuple(batchInfo["TileSizes"]) 89 | 90 | NumStages = batchInfo[syncPolicy]["NumStages"] if "NumStages" in batchInfo["baseline"] else batchInfo["NumStages"] 91 | if isinstance(NumStages, list): 92 | NumStages = NumStages[gemmidx] 93 | 94 | numStagesCode = "const uint NumStages = %d;\n" % NumStages 95 | tilesCode += numStagesCode 96 | 97 | if model == "gpt3" and attention_or_mlp == "mlp" and gemmidx == 0: 98 | tilesCode += "#define MLP_GPT3_GEMM1" 99 | 100 | fileContents = slurp(inFile) 101 | tilesCodeStart = fileContents.find("//") + len("//") 102 | 
tilesCodeEnd = fileContents.find("//") 103 | fileContents = fileContents[0:tilesCodeStart] + "\n" + tilesCode + "\n" + fileContents[tilesCodeEnd:] 104 | with open(outFile, "w") as f: 105 | f.write(fileContents) 106 | (s,o) = subprocess.getstatusoutput(f"rm -r {buildDir('streamk-eval')} ; make {buildDir('streamk-eval')}") 107 | if s != 0: 108 | print(o) 109 | sys.exit(0) 110 | 111 | def deleteFiles(syncPolicies, attention_or_mlp): 112 | command = "rm -f " 113 | for policy in syncPolicies: 114 | if attention_or_mlp == 'attention' and policy == 'stridedsync': 115 | command += buildDir("%s-%s-eval-%s "%(attention_or_mlp, model, policy)) 116 | else: 117 | command += buildDir("%s-eval-%s "%(attention_or_mlp, policy)) 118 | 119 | (s,o) = subprocess.getstatusoutput(command) 120 | 121 | if s != 0: 122 | print(o) 123 | sys.exit(0) 124 | 125 | def makeFiles(syncPolicies, attention_or_mlp): 126 | command = "make " 127 | for policy in syncPolicies: 128 | if attention_or_mlp == 'attention' and policy == 'stridedsync': 129 | command += buildDir("%s-%s-eval-%s "%(attention_or_mlp, model, policy)) 130 | else: 131 | command += buildDir("%s-eval-%s "%(attention_or_mlp, policy)) 132 | 133 | flags = "-j" 134 | command += flags 135 | (s,o) = subprocess.getstatusoutput(command) 136 | 137 | if s != 0: 138 | print(o) 139 | sys.exit(0) 140 | 141 | def genFiles(batchInfo, syncPolicy, attention_or_mlp): 142 | inMLPFile = "mlp.cu" if attention_or_mlp == "mlp" else "attention.cu" 143 | outMLPFile = buildDir(attention_or_mlp + "-eval-" + syncPolicy + ".cu") 144 | tilesTemplate = """using ShapeThreadBlock%d = cutlass::gemm::GemmShape<%d, %d, %d>; 145 | using ShapeWarp%d = cutlass::gemm::GemmShape<%d, %d, %d>;""" 146 | tilesCode = "" 147 | 148 | tileSize = batchInfo[syncPolicy]["TileSizes"] if "TileSizes" in batchInfo[syncPolicy] else batchInfo["TileSizes"] 149 | if len(tileSize) > 1: 150 | for i,tile in enumerate(tileSize): 151 | tilesCode += tilesTemplate % tuple([i+1] + tile[:3] + [i+1] + tile[3:]) 152 | tilesCode += "\n" 153 | else: 154 | tilesTemplate = """using ShapeThreadBlock = cutlass::gemm::GemmShape<%d, %d, %d>; 155 | using ShapeWarp = cutlass::gemm::GemmShape<%d, %d, %d>;""" 156 | tilesCode = tilesTemplate % tuple(tileSize[0]) 157 | 158 | NumStages = batchInfo[syncPolicy]["NumStages"] if "NumStages" in batchInfo[syncPolicy] else batchInfo["NumStages"] 159 | numStagesCode = "" 160 | NumStagesTemplate = "const uint NumStages%d = %d;\n" 161 | if isinstance(NumStages, list) and len(NumStages) > 1: 162 | for i,num in enumerate(NumStages): 163 | numStagesCode += NumStagesTemplate % tuple([i+1, int(num)]) 164 | else: 165 | numStagesCode = NumStagesTemplate %(1, int(NumStages)) + \ 166 | NumStagesTemplate %(2, int(NumStages)) 167 | tilesCode+=numStagesCode 168 | batchInfo = batchInfo["tilesync"] if syncPolicy == "stridedsync" or syncPolicy == 'baseline' else batchInfo[syncPolicy] 169 | if "SoftmaxRowTile" in batchInfo: 170 | tilesCode += "\nconst uint SoftmaxRowTile = %d;"%batchInfo["SoftmaxRowTile"] 171 | mlpFileContents = slurp(inMLPFile) 172 | tilesCodeStart = mlpFileContents.find("//") + len("//") 173 | tilesCodeEnd = mlpFileContents.find("//") 174 | mlpFileContents = mlpFileContents[0:tilesCodeStart] + "\n" + tilesCode + "\n" + mlpFileContents[tilesCodeEnd:] 175 | optimizationsStart = mlpFileContents.find("//") + len("//") 176 | optimizationsEnd = mlpFileContents.find("//") 177 | optimizationsCode = "" 178 | if model == "GPT3".lower(): 179 | optimizationsCode += f"#define {attention_or_mlp.upper()}_GPT3\n" 180 | 
elif model == "LLAMA".lower(): 181 | optimizationsCode += f"#define {attention_or_mlp.upper()}_LLAMA\n" 182 | 183 | if syncPolicy != 'baseline': 184 | if "AvoidCustomOrder" in batchInfo and batchInfo["AvoidCustomOrder"] == True: 185 | optimizationsCode += "#define AVOID_CUSTOM_ORDER"+"\n" 186 | else: 187 | optimizationsCode += "#undef AVOID_CUSTOM_ORDER"+"\n" 188 | if "AvoidWaitKernel" in batchInfo and batchInfo["AvoidWaitKernel"] == True: 189 | optimizationsCode += "#define AVOID_WAIT_KERNEL"+"\n" 190 | else: 191 | optimizationsCode += "#undef AVOID_WAIT_KERNEL"+"\n" 192 | if "ReorderTileLoads" in batchInfo and batchInfo["ReorderTileLoads"] == True: 193 | optimizationsCode += "#define REORDER_TILE_LOADS"+"\n" 194 | else: 195 | optimizationsCode += "#undef REORDER_TILE_LOADS"+"\n" 196 | if "NoAtomicAdd" in batchInfo and batchInfo["NoAtomicAdd"] == True: 197 | optimizationsCode += "#define NO_ATOMIC_ADD"+"\n" 198 | else: 199 | optimizationsCode += "#undef NO_ATOMIC_ADD"+"\n" 200 | 201 | optimizationsCode += "#define " + syncPolicy.upper() + "\n" 202 | optimizationsCode += "#define " + "EVAL_TILE_SIZES" + "\n" 203 | mlpFileContents = mlpFileContents[0:optimizationsStart] + "\n" + optimizationsCode + "\n" + mlpFileContents[optimizationsEnd:] 204 | if os.path.exists(outMLPFile): 205 | with open(outMLPFile, "r") as f: 206 | oldContents = f.read() 207 | if mlpFileContents == oldContents: 208 | return 209 | with open(outMLPFile, "w") as f: 210 | f.write(mlpFileContents) 211 | 212 | tiles_field_str = f"{model}_{attention_or_mlp}_{arch}" 213 | tiles = getattr(tile_sizes_db, tiles_field_str) 214 | 215 | if model.lower() == "GPT3".lower(): 216 | H = 12288 217 | FFN = int(4*H/8) 218 | elif model.lower() == "llama".lower(): 219 | H = 8192 220 | FFN = int(((8192/3+127)//128)*128)#int(2/3 * 4 * H/8) 221 | else: 222 | print ("No Hidden dim for ", model) 223 | sys.exit(0) 224 | 225 | policies = ['rowsync', 'tilesync', 'stridedsync'] 226 | if 'stridedsync' in policies and attention_or_mlp == 'mlp': 227 | policies.pop(policies.index('stridedsync')) 228 | 229 | deleteFiles(policies+['baseline'], attention_or_mlp) 230 | 231 | if attention_or_mlp == "mlp": 232 | cases = (([1,2,4,8,16,32,64,128,256]) if arch=='v100' else []) +\ 233 | [512+256*i for i in range(0, 11)] 234 | else: 235 | #cases = [(0,256), (0,512), (0, 1024), (0, 2048), (1024,1), (1024,4), (2048,1), (2048,4)] 236 | cases = [(512,1),(512,2), (512,4), (1024,1), (1024,2), (1024,4), (2048,1), (2048,2), (2048,4)] 237 | 238 | results_csv = "" 239 | 240 | for case in cases: 241 | if attention_or_mlp == "attention": 242 | m = case[1] 243 | seq = case[0] 244 | else: 245 | m = case 246 | seq = 0 247 | 248 | caseTiles = None 249 | if attention_or_mlp == "attention": 250 | caseTiles = tiles[seq][m] 251 | else: 252 | if m > 2048: 253 | caseTiles = tiles[4096] 254 | else: 255 | caseTiles = tiles[m] 256 | 257 | if True: 258 | if attention_or_mlp == "attention": 259 | (s, o) = subprocess.getstatusoutput(f"python3 torch-baselines/torchAttention.py {m} {int(H/8)} {H} {H}") 260 | else: 261 | (s, o) = subprocess.getstatusoutput(f"python3 torch-baselines/torchmlp.py {m} {model}") 262 | 263 | if s == -1: 264 | print("error " + o) 265 | else: 266 | ctime = o 267 | cublasTimes[m] = ctime 268 | 269 | result_row = f'{m} & {seq} & {H} & {"torch"} & {"%.2f"%float(ctime)}' 270 | print(result_row) 271 | results_csv += result_row + "\n" 272 | 273 | if arch == "a100": 274 | genAndMakeStreamK(caseTiles, 0) 275 | streamk_command = buildDir("streamk-eval") + f" --m={m} --alpha=1 
276 |         (s, o) = subprocess.getstatusoutput(streamk_command + f"--n={int(2*FFN if model=='llama' else FFN)} --k={H} " + f"--split={caseTiles['baseline']['split_ks'][0]}")
277 |         if s != 0:
278 |             print("StreamK Error")
279 |             print(o)
280 | 
281 |         firstGeMMStreamK = getStreamKTimes(o)
282 |         genAndMakeStreamK(caseTiles, 1)
283 |         (s, o) = subprocess.getstatusoutput(streamk_command + f"--n={H} --k={int(FFN)} " + f"--split={caseTiles['baseline']['split_ks'][1]}")
284 |         if s != 0:
285 |             print("StreamK Error")
286 |             print(o)
287 | 
288 |         secondGeMMStreamK = getStreamKTimes(o)
289 |         total = firstGeMMStreamK + secondGeMMStreamK
290 |         result_row = f'{m} & {seq} & {H} & {"streamk"} & {"%.2f"%(total*1000)} & {"%.2f"%(firstGeMMStreamK*1000)} & {"%.2f"%(secondGeMMStreamK*1000)}'
291 |         print(result_row)
292 |         results_csv += result_row + "\n"
293 | 
294 |     baselineDone = False
295 |     bTimeTotal = 0
296 | 
297 |     for syncPolicy in (policies+['baseline']):
298 |         genFiles(caseTiles, syncPolicy, attention_or_mlp)
299 | 
300 |     makeFiles(policies+['baseline'], attention_or_mlp)
301 | 
302 |     split_ks = caseTiles['baseline']['split_ks']
303 |     splitKArgs = " " + " ".join([f"--split-k{i+1} {split_ks[i]}" for i in range(len(split_ks))])
304 |     commandArgs = f" --batch {m} --check false --model {model.lower()}"
305 |     if attention_or_mlp == "attention":
306 |         commandArgs += f" --seqlen {(seq - m) if seq > m else seq}"
307 |     baselineCommand = buildDir(f"{attention_or_mlp}-eval-baseline") + commandArgs + splitKArgs + " --policy baseline"
308 |     (s, o) = subprocess.getstatusoutput(baselineCommand)
309 |     # print(o)
310 |     if "Invalid" in o:
311 |         pass
312 |     elif s != 0:
313 |         print("error " + o)
314 |     else:
315 |         # print(o)
316 |         baselinetimes = getAllTimes(o, 'START-BASELINE', 'END-BASELINE')
317 |         bTimeTotal = baselinetimes["Total"]
318 |         bTimeMatmul1 = baselinetimes["matmul1Time"]
319 |         bTimeMatmul2 = baselinetimes["matmul2Time"]
320 |         result_row = f'{m} & {seq} & {H} & baseline & {"%.2f"%avg(bTimeTotal)} & {"%.2f"%stdev(bTimeTotal)} & {"%.2f"%avg(bTimeMatmul1)} & {"%.2f"%avg(bTimeMatmul2)}'
321 |         results_csv += result_row + "\n"
322 |         print(result_row)
323 |         baselineDone = True
324 | 
325 |     for syncPolicy in policies:
326 |         split_ks = (caseTiles["tilesync"] if syncPolicy == "stridedsync" else caseTiles[syncPolicy])["split_ks"]
327 |         splitKArgs = " " + " ".join([f"--split-k{i+1} {split_ks[i]}" for i in range(len(split_ks))])
328 |         command = ""
329 |         # if attention_or_mlp == 'attention' and syncPolicy == 'stridedsync':
330 |         #     command += buildDir("%s-%s-eval-%s "%(attention_or_mlp, model, syncPolicy))
331 |         # else:
332 |         command += buildDir("%s-eval-%s "%(attention_or_mlp, syncPolicy))
333 |         command += commandArgs + splitKArgs + " --policy cusync"
334 |         (s, o) = subprocess.getstatusoutput(command)
335 | 
336 |         otime = -1
337 |         if "Invalid" in o:
338 |             pass
339 |         elif s != 0:
340 |             print("error " + o)
341 |         else:
342 |             overlaptimes = getAllTimes(o, 'START-OVERLAPPED', 'END-OVERLAPPED')
343 |             otime = overlaptimes["Total"]
344 | 
345 |         result_row = f'{m} & {seq} & {H} & {syncPolicy} & {"%.2f"%avg(bTimeTotal)} & {"%.2f"%stdev(bTimeTotal)} & {"%.2f"%avg(bTimeMatmul1)} & {"%.2f"%avg(bTimeMatmul2)} & {"%.2f"%avg(otime)} & {"%.2f"%stdev(otime)} & {"%.2f"%(100 - avg(otime)/avg(bTimeTotal)*100)}'
346 |         results_csv += result_row + "\n"
347 |         print(result_row)
348 |         time.sleep(5)
349 | 
350 | with open(os.path.join(resultsDir(""), f"{attention_or_mlp}-{model}-{arch}.csv"), "w") as f:
351 |     f.write(results_csv)
352 | 
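353 | 
354 | # For example, with a single TileSizes entry [256, 128, 32, 64, 64, 32], genFiles
355 | # splices this snippet between the sentinel comments in the .cu file:
356 | #   using ShapeThreadBlock = cutlass::gemm::GemmShape<256, 128, 32>;
357 | #   using ShapeWarp = cutlass::gemm::GemmShape<64, 64, 32>;
358 | 
359 | # A minimal read-back sketch for the "&"-separated rows written above, assuming
360 | # the result_row layout (batch & seqlen & hidden & policy & timing fields);
361 | # parseResultsFile is illustrative only and is not called elsewhere in this script.
362 | def parseResultsFile(path):
363 |     rows = []
364 |     with open(path) as f:
365 |         for line in f:
366 |             fields = [c.strip() for c in line.split("&")]
367 |             if len(fields) < 5:
368 |                 continue
369 |             rows.append({"batch": int(fields[0]), "seq": int(fields[1]),
370 |                          "hidden": int(fields[2]), "policy": fields[3],
371 |                          "times": [float(x) for x in fields[4:] if x not in ("", "-", "--")]})
372 |     return rows
373 | # e.g. parseResultsFile(os.path.join(resultsDir(""), f"{attention_or_mlp}-{model}-{arch}.csv"))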
-------------------------------------------------------------------------------- /src/ml-bench/transformer/results/allreduce_times-12288: -------------------------------------------------------------------------------- 1 | M & Time in microseconds 2 | 1 & 38 3 | 2 & 35 4 | 4 & 67 5 | 8 & 148 6 | 16 & 104 7 | 32 & 146 8 | 64 & 95 9 | 128 & 132 10 | 256 & 211 11 | 512 & 274 12 | 1024 & 502 13 | 2048 & 994 14 | 4096 & 1979 15 | 8192 & 3924 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/attention-results: -------------------------------------------------------------------------------- 1 | 256 & 6144 & Row-Sync & 184.31 & 248.81 & 1.55 & 141.15 & 67.44 & 2 & 144 & 96 & 252.93 & 1.24 & -1.65 2 | 256 & 6144 & Tile-Sync & 184.31 & 247.33 & 0.96 & 140.67 & 67.00 & 2 & 144 & 96 & 247.93 & 1.80 & -0.24 3 | 512 & 6144 & Row-Sync & 296.36 & 352.22 & 1.40 & 203.33 & 82.81 & 2 & 72 & 96 & 353.48 & 1.76 & -0.36 4 | 512 & 6144 & Tile-Sync & 296.36 & 352.81 & 1.69 & 203.70 & 83.04 & 2 & 72 & 96 & 359.67 & 2.79 & -1.94 5 | 1024 & 6144 & Row-Sync & 503.72 & 515.30 & 1.92 & 267.11 & 138.33 & 2 & 144 & 192 & 499.26 & 4.02 & 3.11 6 | 1024 & 6144 & Tile-Sync & 503.72 & 515.56 & 2.15 & 267.56 & 138.41 & 2 & 144 & 192 & 492.41 & 10.65 & 4.49 7 | 2048 & 6144 & Row-Sync & 959.01 & 929.93 & 11.62 & 511.04 & 218.59 & 2 & 288 & 384 & 939.41 & 23.25 & -1.02 8 | 2048 & 6144 & Tile-Sync & 959.01 & 942.15 & 9.99 & 516.78 & 222.19 & 2 & 288 & 384 & 978.59 & 17.18 & -3.87 9 | 1 & 8192 & Row-Sync & 102.15 & 176.22 & 1.60 & 108.67 & 51.48 & 3 & 48 & 32 & 179.74 & 0.90 & 4.25 10 | 1 & 8192 & Tile-Sync & 102.15 & 176.22 & 1.60 & 108.67 & 51.48 & 3 & 48 & 32 & 168.74 & 0.90 & 4.25 11 | 2 & 8192 & Row-Sync & 120.83 & 174.33 & 0.83 & 107.59 & 50.85 & 3 & 48 & 32 & 179.63 & 1.04 & 3.27 12 | 2 & 8192 & Tile-Sync & 120.83 & 174.33 & 0.83 & 107.59 & 50.85 & 3 & 48 & 32 & 168.63 & 1.04 & 3.27 13 | 4 & 8192 & Row-Sync & 121.60 & 176.07 & 1.62 & 108.63 & 51.26 & 3 & 48 & 32 & 179.41 & 1.05 & 3.79 14 | 4 & 8192 & Tile-Sync & 121.60 & 176.07 & 1.62 & 108.63 & 51.26 & 3 & 48 & 32 & 169.41 & 1.05 & 3.79 15 | 8 & 8192 & Row-Sync & 122.68 & 175.07 & 0.87 & 108.37 & 50.85 & 3 & 48 & 32 & 179.19 & 1.14 & 3.93 16 | 8 & 8192 & Tile-Sync & 122.68 & 175.07 & 0.87 & 108.37 & 50.85 & 3 & 48 & 32 & 168.19 & 1.14 & 3.93 17 | 16 & 8192 & Row-Sync & 124.97 & 177.15 & 1.03 & 109.04 & 51.56 & 3 & 48 & 32 & 179.30 & 1.23 & 3.30 18 | 16 & 8192 & Tile-Sync & 124.97 & 177.15 & 1.03 & 109.04 & 51.56 & 3 & 48 & 32 & 171.30 & 1.23 & 3.30 19 | 32 & 8192 & Row-Sync & 582.10 & 178.74 & 0.66 & 110.00 & 51.85 & 3 & 48 & 32 & 179.93 & 1.11 & 1.57 20 | 32 & 8192 & Tile-Sync & 582.10 & 178.74 & 0.66 & 110.00 & 51.85 & 3 & 48 & 32 & 175.93 & 1.11 & 1.57 21 | 64 & 8192 & Row-Sync & 124.98 & 187.22 & 1.63 & 113.85 & 52.33 & 3 & 48 & 32 & 188.81 & 1.71 & -0.85 22 | 64 & 8192 & Tile-Sync & 124.98 & 187.22 & 1.63 & 113.85 & 52.33 & 3 & 48 & 32 & 188.81 & 1.71 & -0.85 23 | 256 & 8192 & Row-Sync & 253.36 & 393.37 & 1.31 & 261.41 & 85.26 & 2 & 192 & 128 & 441.19 & 5.48 & -12.16 24 | 256 & 8192 & Tile-Sync & 253.36 & 395.22 & 2.04 & 262.00 & 85.67 & 2 & 192 & 128 & 430.15 & 7.55 & -8.84 25 | 512 & 8192 & Row-Sync & 464.36 & 509.00 & 3.11 & 326.67 & 101.37 & 2 & 96 & 128 & 508.07 & 2.72 & 0.18 26 | 512 & 8192 & Tile-Sync & 464.36 & 511.19 & 4.95 & 328.41 & 101.81 & 2 & 96 & 128 & 524.41 & 7.05 & -2.59 27 | 1024 & 8192 & Row-Sync & 895.86 & 875.81 & 3.77 & 544.07 & 192.41 & 2 & 192 & 256 & 872.93 & 23.50 & 0.33 28 | 
1024 & 8192 & Tile-Sync & 895.86 & 874.30 & 3.66 & 543.56 & 191.67 & 2 & 192 & 256 & 855.44 & 20.49 & 2.16 29 | 2048 & 8192 & Row-Sync & 1629.03 & 1553.41 & 4.11 & 946.74 & 348.04 & 2 & 384 & 512 & 1316.04 & 38.76 & 15.28 30 | 2048 & 8192 & Tile-Sync & 1629.03 & 1552.44 & 4.06 & 945.63 & 347.81 & 2 & 384 & 512 & 1428.63 & 39.32 & 7.98 31 | 1 & 12288 & Row-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 260.59 & 1.28 & 5.07 32 | 1 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 33 | 1 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 34 | 1 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 35 | 2 & 12288 & Row-Sync & 261.00 & 258.26 & 1.26 & 165.70 & 73.93 & 3 & 72 & 48 & 260.74 & 1.26 & 4.07 36 | 2 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 37 | 2 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 38 | 2 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 39 | 4 & 12288 & Row-Sync & 264.43 & 258.48 & 1.45 & 165.44 & 74.33 & 3 & 72 & 48 & 260.44 & 0.97 & 4.27 40 | 4 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 41 | 4 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 42 | 4 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 43 | 8 & 12288 & Row-Sync & 265.85 & 259.41 & 1.58 & 166.30 & 74.37 & 3 & 72 & 48 & 260.44 & 1.91 & 4.61 44 | 8 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 45 | 8 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 46 | 8 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 47 | 16 & 12288 & Row-Sync & 271.84 & 260.15 & 1.35 & 166.70 & 74.78 & 3 & 72 & 48 & 260.70 & 1.66 & 1.71 48 | 16 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 49 | 16 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 50 | 16 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 51 | 32 & 12288 & Row-Sync & 245.88 & 262.59 & 1.65 & 168.04 & 75.19 & 3 & 72 & 48 & 260.74 & 1.68 & 0.71 52 | 32 & 12288 & Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 245.59 & 1.28 & 5.07 53 | 32 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 248 & 1.28 & 54 | 32 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 230.59 & 1.28 & 5.07 55 | 64 & 12288 & Row-Sync & 243.70 & 302.04 & 46.02 & 189.56 & 85.93 & 3 & 72 & 48 & 302.81 & 2.18 & 5.04 56 | 64 & 12288 & Tile-Sync & 243.70 & 302.04 & 46.02 & 189.56 & 85.93 & 3 & 72 & 48 & 286.81 & 2.18 & 5.04 57 | 64 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 272 & 1.28 & 58 | 64 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 & 270.59 & 1.28 & 5.07 59 | 128 & 12288 & Row-Sync & 243.70 & 302.04 & 46.02 & 189.56 & 85.93 & 3 & 72 & 48 & 302.81 & 2.18 & 5.04 60 | 128 & 12288 & Tile-Sync & 243.70 & 302.04 & 46.02 & 189.56 & 85.93 & 3 & 72 & 48 & 286.81 & 2.18 & 5.04 61 | 128 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 272 & 1.28 & 62 | 128 & 12288 & Strided-Tile-Sync & 285.07 & 258.70 & 1.79 & 165.74 & 74.48 & 3 & 72 & 48 
& 270.59 & 1.28 & 5.07 63 | 256 & 12288 & Row-Sync & 504.10 & 699.63 & 2.47 & 440.74 & 197.07 & 2 & 288 & 192 & 682.30 & 10.87 & 2.48 64 | 256 & 12288 & Tile-Sync & 504.10 & 699.19 & 2.11 & 441.11 & 196.78 & 2 & 288 & 192 & 686.74 & 6.49 & 1.78 65 | 256 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 1142 & 1.28 & 66 | 256 & 12288 & Strided-Tile-Sync & 504.10 & 699.19 & 2.11 & 441.11 & 196.78 & 2 & 288 & 192 & 686.74 & 6.49 & 1.78 67 | 512 & 12288 & Row-Sync & 955.07 & 840.52 & 3.04 & 495.04 & 237.67 & 2 & 144 & 192 & 770.85 & 24.72 & 4.24 68 | 512 & 12288 & Tile-Sync & 955.07 & 839.59 & 2.83 & 494.26 & 237.59 & 2 & 144 & 192 & 860.48 & 8.27 & -2.49 69 | 512 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 880 & 1.28 & 70 | 512 & 12288 & Strided-Tile-Sync & 955.07 & 839.59 & 2.83 & 494.26 & 237.59 & 2 & 144 & 192 & 860.48 & 8.27 & -2.49 71 | 1024 & 12288 & Row-Sync & 1839.59 & 1541.37 & 19.74 & 947.26 & 393.96 & 2 & 288 & 384 & 1370.48 & 16.36 & 8.04 72 | 1024 & 12288 & Tile-Sync & 1839.59 & 1497.96 & 63.91 & 923.48 & 379.07 & 2 & 144 & 384 & 1468.78 & 10.74 & 1.95 73 | 1024 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 1420 & 1.28 & 74 | 1024 & 12288 & Strided-Tile-Sync & 1839.59 & 1497.96 & 63.91 & 923.48 & 379.07 & 2 & 144 & 384 & 1468.78 & 10.74 & 1.95 75 | 2048 & 12288 & Row-Sync & 3487.05 & 2781.11 & 111.35 & 1771.15 & 646.63 & 2 & 288 & 768 & 2595.22 & 20.65 & 6.18 76 | 2048 & 12288 & Tile-Sync & 3487.05 & 2616.26 & 7.05 & 1660.22 & 614.22 & 2 & 288 & 768 & 2799.00 & 15.52 & -6.98 77 | 2048 & 12288 & Stream-K & -- & -- & -- & -- & - & - & - & - & 2900 & 1.28 & 78 | 2048 & 12288 & Strided-Tile-Sync & 3487.05 & 2616.26 & 7.05 & 1660.22 & 614.22 & 2 & 288 & 768 & 2799.00 & 15.52 & -6.98 79 | 1 & 16384 & Row-Sync & 537.37 & 394.26 & 2.05 & 276.15 & 97.74 & 3 & 96 & 64 & 395.19 & 2.35 & 2.56 80 | 1 & 16384 & Tile-Sync & 537.37 & 394.26 & 2.05 & 276.15 & 97.74 & 3 & 96 & 64 & 384.19 & 2.35 & 2.56 81 | 2 & 16384 & Row-Sync & 348.35 & 394.74 & 2.46 & 276.19 & 98.22 & 3 & 96 & 64 & 395.56 & 1.87 & 2.83 82 | 2 & 16384 & Tile-Sync & 348.35 & 394.74 & 2.46 & 276.19 & 98.22 & 3 & 96 & 64 & 383.56 & 1.87 & 2.83 83 | 4 & 16384 & Row-Sync & 351.41 & 396.52 & 2.49 & 278.48 & 97.81 & 3 & 96 & 64 & 396.89 & 2.62 & 2.68 84 | 4 & 16384 & Tile-Sync & 351.41 & 396.52 & 2.49 & 278.48 & 97.81 & 3 & 96 & 64 & 385.89 & 2.62 & 2.68 85 | 8 & 16384 & Row-Sync & 351.99 & 399.04 & 1.29 & 280.37 & 97.96 & 3 & 96 & 64 & 399.22 & 2.42 & 2.71 86 | 8 & 16384 & Tile-Sync & 351.99 & 399.04 & 1.29 & 280.37 & 97.96 & 3 & 96 & 64 & 388.22 & 2.42 & 2.71 87 | 16 & 16384 & Row-Sync & 363.81 & 401.37 & 1.50 & 282.59 & 98.22 & 3 & 96 & 64 & 401.96 & 2.33 & 2.59 88 | 16 & 16384 & Tile-Sync & 363.81 & 401.37 & 1.50 & 282.59 & 98.22 & 3 & 96 & 64 & 390.96 & 2.33 & 2.59 89 | 32 & 16384 & Row-Sync & 356.26 & 406.67 & 1.92 & 285.96 & 99.22 & 3 & 96 & 64 & 406.33 & 1.96 & 2.54 90 | 32 & 16384 & Tile-Sync & 356.26 & 406.67 & 1.92 & 285.96 & 99.22 & 3 & 96 & 64 & 396.33 & 1.96 & 2.54 91 | 64 & 16384 & Row-Sync & 366.44 & 458.30 & 2.38 & 301.59 & 126.85 & 3 & 288 & 192 & 458.93 & 4.18 & -4.94 92 | 64 & 16384 & Tile-Sync & 366.44 & 458.30 & 2.38 & 301.59 & 126.85 & 3 & 288 & 192 & 450.93 & 4.18 & -4.94 93 | 256 & 16384 & Row-Sync & 824.85 & 1204.48 & 2.64 & 833.89 & 293.67 & 2 & 384 & 256 & 1157.63 & 5.70 & 3.89 94 | 256 & 16384 & Tile-Sync & 824.85 & 1205.33 & 2.88 & 833.85 & 293.89 & 2 & 384 & 256 & 1155.59 & 6.92 & 4.13 95 | 512 & 16384 & Row-Sync & 1564.33 & 1516.59 & 7.51 & 1041.63 & 334.81 & 2 & 192 & 256 
& 1429.41 & 52.69 & 5.75 96 | 512 & 16384 & Tile-Sync & 1564.33 & 1515.56 & 5.02 & 1040.81 & 335.41 & 2 & 192 & 256 & 1517.59 & 8.30 & -0.13 97 | 1024 & 16384 & Row-Sync & 3245.55 & 2590.70 & 100.54 & 1742.89 & 601.70 & 2 & 384 & 512 & 2291.89 & 10.84 & 11.53 98 | 1024 & 16384 & Tile-Sync & 3245.55 & 2674.04 & 15.14 & 1800.89 & 618.41 & 2 & 384 & 512 & 2411.11 & 17.02 & 9.83 99 | 2048 & 16384 & Row-Sync & 6134.47 & 4683.22 & 162.54 & 3161.96 & 1063.67 & 2 & 384 & 1024 & 4277.78 & 22.25 & 8.66 100 | 2048 & 16384 & Tile-Sync & 6134.47 & 4645.93 & 163.57 & 3139.52 & 1053.19 & 2 & 384 & 1024 & 4717.48 & 18.33 & -1.54 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/attention-results-gpt-3-cuda-12.2: -------------------------------------------------------------------------------- 1 | 1 & 12288 & baseline & 244.28 & 1.13 & 159.33 & 65.56 2 | 1 & 12288 & rowsync & 244.28 & 1.13 & 159.33 & 65.56 & 225.33 & 1.50 & 7.76 3 | 1 & 12288 & tilesync & 244.28 & 1.13 & 159.33 & 65.56 & 222.72 & 2.22 & 8.82 4 | 1 & 12288 & stridedsync & 244.28 & 1.13 & 159.33 & 65.56 & 222.11 & 2.42 & 9.07 5 | 2 & 12288 & baseline & 255.83 & 1.50 & 159.33 & 65.61 6 | 2 & 12288 & rowsync & 255.83 & 1.50 & 159.33 & 65.61 & 226.39 & 0.98 & 11.51 7 | 2 & 12288 & tilesync & 255.83 & 1.50 & 159.33 & 65.61 & 237.83 & 2.31 & 7.04 8 | 2 & 12288 & stridedsync & 255.83 & 1.50 & 159.33 & 65.61 & 238.61 & 2.09 & 6.73 9 | 4 & 12288 & baseline & 245.33 & 1.24 & 159.89 & 65.83 10 | 4 & 12288 & rowsync & 245.33 & 1.24 & 159.89 & 65.83 & 227.33 & 1.14 & 7.34 11 | 4 & 12288 & tilesync & 245.33 & 1.24 & 159.89 & 65.83 & 224.28 & 1.78 & 8.58 12 | 4 & 12288 & stridedsync & 245.33 & 1.24 & 159.89 & 65.83 & 223.11 & 2.59 & 9.06 13 | 8 & 12288 & baseline & 245.89 & 1.60 & 159.89 & 66.06 14 | 8 & 12288 & rowsync & 245.89 & 1.60 & 159.89 & 66.06 & 229.67 & 1.41 & 6.60 15 | 8 & 12288 & tilesync & 245.89 & 1.60 & 159.89 & 66.06 & 223.78 & 2.16 & 8.99 16 | 8 & 12288 & stridedsync & 245.89 & 1.60 & 159.89 & 66.06 & 224.94 & 2.31 & 8.52 17 | 16 & 12288 & baseline & 248.22 & 1.35 & 161.11 & 67.00 18 | 16 & 12288 & rowsync & 248.22 & 1.35 & 161.11 & 67.00 & 234.06 & 1.76 & 5.71 19 | 16 & 12288 & tilesync & 248.22 & 1.35 & 161.11 & 67.00 & 226.94 & 2.04 & 8.57 20 | 16 & 12288 & stridedsync & 248.22 & 1.35 & 161.11 & 67.00 & 226.28 & 1.60 & 8.84 21 | 32 & 12288 & baseline & 252.83 & 1.50 & 164.28 & 68.06 22 | 32 & 12288 & rowsync & 252.83 & 1.50 & 164.28 & 68.06 & 240.94 & 1.55 & 4.70 23 | 32 & 12288 & tilesync & 252.83 & 1.50 & 164.28 & 68.06 & 233.39 & 2.17 & 7.69 24 | 32 & 12288 & stridedsync & 252.83 & 1.50 & 164.28 & 68.06 & 231.33 & 1.57 & 8.50 25 | 64 & 12288 & baseline & 297.56 & 1.65 & 190.00 & 82.17 26 | 64 & 12288 & rowsync & 297.56 & 1.65 & 190.00 & 82.17 & 265.00 & 6.25 & 10.94 27 | 64 & 12288 & tilesync & 297.56 & 1.65 & 190.00 & 82.17 & 282.56 & 5.02 & 5.04 28 | 64 & 12288 & stridedsync & 297.56 & 1.65 & 190.00 & 82.17 & 286.00 & 4.69 & 3.88 29 | 128 & 12288 & baseline & 359.50 & 1.65 & 222.22 & 98.06 30 | 128 & 12288 & rowsync & 359.50 & 1.65 & 222.22 & 98.06 & 352.78 & 3.04 & 1.87 31 | 128 & 12288 & tilesync & 359.50 & 1.65 & 222.22 & 98.06 & 353.78 & 2.86 & 1.59 32 | 128 & 12288 & stridedsync & 359.50 & 1.65 & 222.22 & 98.06 & 350.17 & 2.81 & 2.60 33 | 256 & 12288 & baseline & 536.11 & 3.07 & 307.72 & 138.61 34 | 256 & 12288 & rowsync & 536.11 & 3.07 & 307.72 & 138.61 & 443.50 & 5.52 & 17.27 35 | 256 & 12288 & tilesync & 536.11 & 3.07 & 307.72 & 138.61 & 509.72 & 3.64 & 4.92 
36 | 256 & 12288 & stridedsync & 536.11 & 3.07 & 307.72 & 138.61 & 513.22 & 3.47 & 4.27 37 | 512 & 12288 & baseline & 825.94 & 2.78 & 499.94 & 223.72 38 | 512 & 12288 & rowsync & 825.94 & 2.78 & 499.94 & 223.72 & 726.33 & 9.17 & 12.06 39 | 512 & 12288 & tilesync & 825.94 & 2.78 & 499.94 & 223.72 & 787.00 & 4.06 & 4.72 40 | 512 & 12288 & stridedsync & 825.94 & 2.78 & 499.94 & 223.72 & 768.22 & 5.37 & 6.99 41 | 1024 & 12288 & baseline & 1478.72 & 15.71 & 915.56 & 377.17 42 | 1024 & 12288 & rowsync & 1478.72 & 15.71 & 915.56 & 377.17 & 1238.00 & 3.77 & 16.28 43 | 1024 & 12288 & tilesync & 1478.72 & 15.71 & 915.56 & 377.17 & 1400.39 & 24.42 & 5.30 44 | 1024 & 12288 & stridedsync & 1478.72 & 15.71 & 915.56 & 377.17 & 1257.33 & 14.71 & 14.97 45 | 2048 & 12288 & baseline & 2815.17 & 7.47 & 1813.83 & 644.89 46 | 2048 & 12288 & rowsync & 2815.17 & 7.47 & 1813.83 & 644.89 & 2543.39 & 79.25 & 9.65 47 | 2048 & 12288 & tilesync & 2815.17 & 7.47 & 1813.83 & 644.89 & 2558.67 & 61.36 & 9.11 48 | 2048 & 12288 & stridedsync & 2815.17 & 7.47 & 1813.83 & 644.89 & 2600.56 & 83.22 & 7.62 49 | 50 | 51 | 52 | 53 | 256 & 12288 & baseline & 581.67 & 7.72 & 354.78 & 44.00 54 | 256 & 12288 & rowsync & 581.67 & 7.72 & 354.78 & 44.00 & 570.44 & 11.00 & 1.93 55 | 256 & 12288 & tilesync & 581.67 & 7.72 & 354.78 & 44.00 & 586.17 & 13.74 & -0.77 56 | 512 & 12288 & baseline & 808.06 & 2.90 & 487.78 & 62.67 57 | 512 & 12288 & rowsync & 808.06 & 2.90 & 487.78 & 62.67 & 819.67 & 14.28 & -1.44 58 | 512 & 12288 & tilesync & 808.06 & 2.90 & 487.78 & 62.67 & 748.39 & 10.56 & 7.38 59 | 1024 & 12288 & baseline & 1360.17 & 12.42 & 870.33 & 75.17 60 | 1024 & 12288 & rowsync & 1360.17 & 12.42 & 870.33 & 75.17 & 1286.50 & 13.99 & 5.42 61 | 1024 & 12288 & tilesync & 1360.17 & 12.42 & 870.33 & 75.17 & 1195.83 & 5.57 & 12.08 62 | 2048 & 12288 & baseline & 2331.17 & 5.91 & 1482.78 & 149.78 63 | 2048 & 12288 & rowsync & 2331.17 & 5.91 & 1482.78 & 149.78 & 2338.39 & 31.34 & -0.31 64 | 2048 & 12288 & tilesync & 2331.17 & 5.91 & 1482.78 & 149.78 & 2315.72 & 27.28 & 0.66 65 | 1 & 12288 & baseline & 271.94 & 0.94 & 145.83 & 45.22 66 | 1 & 12288 & rowsync & 271.94 & 0.94 & 145.83 & 45.22 & 252.72 & 1.18 & 7.07 67 | 1 & 12288 & tilesync & 271.94 & 0.94 & 145.83 & 45.22 & 252.67 & 2.35 & 7.09 68 | 4 & 12288 & baseline & 271.78 & 1.59 & 146.22 & 44.67 69 | 4 & 12288 & rowsync & 271.78 & 1.59 & 146.22 & 44.67 & 253.61 & 1.29 & 6.68 70 | 4 & 12288 & tilesync & 271.78 & 1.59 & 146.22 & 44.67 & 252.44 & 2.09 & 7.11 71 | 1 & 12288 & baseline & 290.28 & 1.60 & 145.78 & 50.44 72 | 1 & 12288 & rowsync & 290.28 & 1.60 & 145.78 & 50.44 & 274.94 & 2.78 & 5.28 73 | 1 & 12288 & tilesync & 290.28 & 1.60 & 145.78 & 50.44 & 268.39 & 2.38 & 7.54 74 | 4 & 12288 & baseline & 290.28 & 1.71 & 145.67 & 50.72 75 | 4 & 12288 & rowsync & 290.28 & 1.71 & 145.67 & 50.72 & 274.50 & 3.24 & 5.44 76 | 4 & 12288 & tilesync & 290.28 & 1.71 & 145.67 & 50.72 & 268.22 & 2.60 & 7.60 77 | 78 | Attention branch 79 | 256 & 0& 12288 & baseline & 581.94 & 9.07 & 353.89 & 44.11 80 | 256 & 0& 12288 & rowsync & 581.94 & 9.07 & 353.89 & 44.11 & 591.67 & 11.59 & -1.67 81 | 256 & 0& 12288 & tilesync & 581.94 & 9.07 & 353.89 & 44.11 & 491.83 & 4.53 & 15.48 82 | 256 & 0& 12288 & stridedsync & 581.94 & 9.07 & 353.89 & 44.11 & 492.39 & 2.62 & 15.39 83 | 512 & 0& 12288 & baseline & 809.44 & 3.54 & 488.06 & 63.28 84 | 512 & 0& 12288 & rowsync & 809.44 & 3.54 & 488.06 & 63.28 & 809.06 & 11.19 & 0.05 85 | 512 & 0& 12288 & tilesync & 809.44 & 3.54 & 488.06 & 63.28 & 648.44 & 4.26 & 19.89 86 | 512 & 0& 
12288 & stridedsync & 809.44 & 3.54 & 488.06 & 63.28 & 713.44 & 4.33 & 11.86 87 | 1024 & 0& 12288 & baseline & 1364.11 & 14.04 & 869.28 & 75.33 88 | 1024 & 0& 12288 & rowsync & 1364.11 & 14.04 & 869.28 & 75.33 & 1334.00 & 4.50 & 2.21 89 | 1024 & 0& 12288 & tilesync & 1364.11 & 14.04 & 869.28 & 75.33 & 1110.33 & 3.69 & 18.60 90 | 1024 & 0& 12288 & stridedsync & 1364.11 & 14.04 & 869.28 & 75.33 & 1103.39 & 5.99 & 19.11 91 | 2048 & 0& 12288 & baseline & 2334.00 & 10.20 & 1485.89 & 148.44 92 | 2048 & 0& 12288 & rowsync & 2334.00 & 10.20 & 1485.89 & 148.44 & 2367.67 & 35.49 & -1.44 93 | 2048 & 0& 12288 & tilesync & 2334.00 & 10.20 & 1485.89 & 148.44 & 2335.11 & 27.27 & -0.05 94 | 2048 & 0& 12288 & stridedsync & 2334.00 & 10.20 & 1485.89 & 148.44 & 2305.56 & 29.16 & 1.22 95 | 96 | Updated Attention results 97 | 256 & 12288 & baseline & 585.50 & 10.13 & 354.44 & 44.72 98 | 256 & 12288 & rowsync & 585.50 & 10.13 & 354.44 & 44.72 & 568.61 & 13.29 & 2.88 99 | 256 & 12288 & tilesync & 585.50 & 10.13 & 354.44 & 44.72 & 585.94 & 11.49 & -0.08 100 | 256 & 12288 & stridedsync & 585.50 & 10.13 & 354.44 & 44.72 & 571.06 & 14.13 & 2.47 101 | 512 & 12288 & baseline & 698.11 & 2.08 & 418.17 & 43.33 102 | 512 & 12288 & rowsync & 698.11 & 2.08 & 418.17 & 43.33 & 680.61 & 2.57 & 2.51 103 | 512 & 12288 & tilesync & 698.11 & 2.08 & 418.17 & 43.33 & 711.61 & 3.22 & -1.93 104 | 512 & 12288 & stridedsync & 698.11 & 2.08 & 418.17 & 43.33 & 625.94 & 2.01 & 10.34 105 | 1024 & 12288 & baseline & 1268.44 & 2.06 & 777.67 & 91.67 106 | 1024 & 12288 & rowsync & 1268.44 & 2.06 & 777.67 & 91.67 & 1091.33 & 2.35 & 13.96 107 | 1024 & 12288 & tilesync & 1268.44 & 2.06 & 777.67 & 91.67 & 1130.00 & 3.20 & 10.91 108 | 1024 & 12288 & stridedsync & 1268.44 & 2.06 & 777.67 & 91.67 & 1081.44 & 2.18 & 14.74 109 | 2048 & 12288 & baseline & 2372.72 & 1.84 & 1369.39 & 258.11 110 | 2048 & 12288 & rowsync & 2372.72 & 1.84 & 1369.39 & 258.11 & 2035.83 & 2.33 & 14.20 111 | 2048 & 12288 & tilesync & 2372.72 & 1.84 & 1369.39 & 258.11 & 2036.61 & 6.16 & 14.17 112 | 2048 & 12288 & stridedsync & 2372.72 & 1.84 & 1369.39 & 258.11 & 1994.67 & 3.25 & 15.93 113 | 1 & 12288 & baseline & 259.28 & 0.96 & 142.17 & 40.39 114 | 1 & 12288 & rowsync & 259.28 & 0.96 & 142.17 & 40.39 & 240.78 & 1.48 & 7.14 115 | 1 & 12288 & tilesync & 259.28 & 0.96 & 142.17 & 40.39 & 253.22 & 3.19 & 2.34 116 | 1 & 12288 & stridedsync & 259.28 & 0.96 & 142.17 & 40.39 & 253.11 & 1.23 & 2.38 117 | 4 & 12288 & baseline & 271.33 & 1.14 & 145.61 & 44.72 118 | 4 & 12288 & rowsync & 271.33 & 1.14 & 145.61 & 44.72 & 253.33 & 1.88 & 6.63 119 | 4 & 12288 & tilesync & 271.33 & 1.14 & 145.61 & 44.72 & 254.50 & 2.73 & 6.20 120 | 4 & 12288 & stridedsync & 271.33 & 1.14 & 145.61 & 44.72 & 255.06 & 1.43 & 6.00 121 | 1 & 12288 & baseline & 291.61 & 1.54 & 146.56 & 50.78 122 | 1 & 12288 & rowsync & 291.61 & 1.54 & 146.56 & 50.78 & 274.17 & 3.49 & 5.98 123 | 1 & 12288 & tilesync & 291.61 & 1.54 & 146.56 & 50.78 & 268.94 & 1.80 & 7.77 124 | 1 & 12288 & stridedsync & 291.61 & 1.54 & 146.56 & 50.78 & 270.50 & 2.26 & 7.24 125 | 4 & 12288 & baseline & 291.61 & 1.20 & 146.06 & 50.83 126 | 4 & 12288 & rowsync & 291.61 & 1.20 & 146.06 & 50.83 & 276.28 & 2.16 & 5.26 127 | 4 & 12288 & tilesync & 291.61 & 1.20 & 146.06 & 50.83 & 270.17 & 3.01 & 7.35 128 | 4 & 12288 & stridedsync & 291.61 & 1.20 & 146.06 & 50.83 & 273.72 & 2.82 & 6.13 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/attention-results-gpt3: 
-------------------------------------------------------------------------------- 1 | 1 & 12288 & pytorch & 204.29 2 | 1 & 12288 & streamk & 308.69 & 275.70 & 584.39 3 | 1 & 12288 & baseline & 297.94 & 2.01 & 190.33 & 87.89 4 | 1 & 12288 & rowsync & 297.94 & 2.01 & 190.33 & 87.89 & 278.00 & 1.91 & 6.69 5 | 1 & 12288 & tilesync & 297.94 & 2.01 & 190.33 & 87.89 & 293.83 & 2.15 & 1.38 6 | 1 & 12288 & stridedsync & 297.94 & 2.01 & 190.33 & 87.89 & 289.22 & 2.46 & 2.93 7 | 2 & 12288 & pytorch & 261.25 8 | 2 & 12288 & streamk & 308.85 & 275.34 & 584.19 9 | 2 & 12288 & baseline & 307.94 & 1.47 & 189.83 & 87.67 10 | 2 & 12288 & rowsync & 307.94 & 1.47 & 189.83 & 87.67 & 278.61 & 2.62 & 9.53 11 | 2 & 12288 & tilesync & 307.94 & 1.47 & 189.83 & 87.67 & 305.56 & 2.01 & 0.78 12 | 2 & 12288 & stridedsync & 307.94 & 1.47 & 189.83 & 87.67 & 298.78 & 2.44 & 2.98 13 | 4 & 12288 & pytorch & 250.52 14 | 4 & 12288 & streamk & 284.43 & 253.50 & 537.93 15 | 4 & 12288 & baseline & 300.39 & 1.82 & 191.11 & 88.94 16 | 4 & 12288 & rowsync & 300.39 & 1.82 & 191.11 & 88.94 & 280.11 & 2.11 & 6.75 17 | 4 & 12288 & tilesync & 300.39 & 1.82 & 191.11 & 88.94 & 295.89 & 2.78 & 1.50 18 | 4 & 12288 & stridedsync & 300.39 & 1.82 & 191.11 & 88.94 & 286.67 & 2.17 & 4.57 19 | 8 & 12288 & pytorch & 254.89 20 | 8 & 12288 & streamk & 284.24 & 253.42 & 537.65 21 | 8 & 12288 & baseline & 300.83 & 1.82 & 191.22 & 88.72 22 | 8 & 12288 & rowsync & 300.83 & 1.82 & 191.22 & 88.72 & 279.89 & 2.37 & 6.96 23 | 8 & 12288 & tilesync & 300.83 & 1.82 & 191.22 & 88.72 & 295.72 & 1.67 & 1.70 24 | 8 & 12288 & stridedsync & 300.83 & 1.82 & 191.22 & 88.72 & 287.61 & 2.15 & 4.40 25 | 16 & 12288 & pytorch & 260.64 26 | 16 & 12288 & streamk & 284.84 & 254.27 & 539.10 27 | 16 & 12288 & baseline & 304.17 & 1.58 & 194.56 & 89.06 28 | 16 & 12288 & rowsync & 304.17 & 1.58 & 194.56 & 89.06 & 279.67 & 2.33 & 8.05 29 | 16 & 12288 & tilesync & 304.17 & 1.58 & 194.56 & 89.06 & 297.50 & 2.09 & 2.19 30 | 16 & 12288 & stridedsync & 304.17 & 1.58 & 194.56 & 89.06 & 290.83 & 1.69 & 4.38 31 | 32 & 12288 & pytorch & 237.41 32 | 32 & 12288 & streamk & 311.81 & 279.17 & 590.98 33 | 32 & 12288 & baseline & 312.61 & 1.85 & 202.50 & 89.33 34 | 32 & 12288 & rowsync & 312.61 & 1.85 & 202.50 & 89.33 & 282.72 & 2.16 & 9.56 35 | 32 & 12288 & tilesync & 312.61 & 1.85 & 202.50 & 89.33 & 302.56 & 1.85 & 3.22 36 | 32 & 12288 & stridedsync & 312.61 & 1.85 & 202.50 & 89.33 & 292.72 & 2.72 & 6.36 37 | 64 & 12288 & pytorch & 241.86 38 | 64 & 12288 & streamk & 315.38 & 282.61 & 598.00 39 | 64 & 12288 & baseline & 313.83 & 2.77 & 198.28 & 89.89 40 | 64 & 12288 & rowsync & 313.83 & 2.77 & 198.28 & 89.89 & 317.50 & 2.09 & -1.17 41 | 64 & 12288 & tilesync & 313.83 & 2.77 & 198.28 & 89.89 & 320.56 & 6.50 & -2.14 42 | 64 & 12288 & stridedsync & 313.83 & 2.77 & 198.28 & 89.89 & 321.11 & 2.89 & -2.32 43 | 128 & 12288 & pytorch & 293.47 44 | 128 & 12288 & streamk & 334.09 & 406.81 & 740.90 45 | 128 & 12288 & baseline & 403.72 & 1.60 & 250.17 & 112.72 46 | 128 & 12288 & rowsync & 403.72 & 1.60 & 250.17 & 112.72 & 428.50 & 3.45 & -6.14 47 | 128 & 12288 & tilesync & 403.72 & 1.60 & 250.17 & 112.72 & 437.78 & 2.07 & -8.44 48 | 128 & 12288 & stridedsync & 403.72 & 1.60 & 250.17 & 112.72 & 439.00 & 1.81 & -8.74 49 | 256 & 12288 & pytorch & 456.70 50 | 256 & 12288 & streamk & 527.37 & 563.95 & 1091.32 51 | 256 & 12288 & baseline & 574.22 & 1.06 & 334.33 & 150.17 52 | 256 & 12288 & rowsync & 574.22 & 1.06 & 334.33 & 150.17 & 578.11 & 5.33 & -0.68 53 | 256 & 12288 & tilesync & 574.22 & 1.06 & 334.33 & 
150.17 & 583.50 & 4.05 & -1.62 54 | 256 & 12288 & stridedsync & 574.22 & 1.06 & 334.33 & 150.17 & 588.44 & 3.45 & -2.48 55 | 512 & 12288 & pytorch & 859.76 56 | 512 & 12288 & streamk & 1048.88 & 945.76 & 1994.64 57 | 512 & 12288 & baseline & 900.11 & 17.58 & 546.50 & 249.72 58 | 512 & 12288 & rowsync & 900.11 & 17.58 & 546.50 & 249.72 & 850.00 & 15.00 & 5.57 59 | 512 & 12288 & tilesync & 900.11 & 17.58 & 546.50 & 249.72 & 905.83 & 20.53 & -0.64 60 | 512 & 12288 & stridedsync & 900.11 & 17.58 & 546.50 & 249.72 & 882.78 & 20.80 & 1.93 61 | 1024 & 12288 & pytorch & 1677.04 62 | 1024 & 12288 & streamk & 1738.14 & 1694.92 & 3433.06 63 | 1024 & 12288 & baseline & 1578.00 & 4.00 & 980.39 & 408.33 64 | 1024 & 12288 & rowsync & 1578.00 & 4.00 & 980.39 & 408.33 & 1535.17 & 12.32 & 2.71 65 | 1024 & 12288 & tilesync & 1578.00 & 4.00 & 980.39 & 408.33 & 1598.06 & 14.11 & -1.27 66 | 1024 & 12288 & stridedsync & 1578.00 & 4.00 & 980.39 & 408.33 & 1410.67 & 12.98 & 10.60 67 | 2048 & 12288 & pytorch & 3306.89 68 | 2048 & 12288 & streamk & 3540.68 & 3470.58 & 7011.26 69 | 2048 & 12288 & baseline & 2964.39 & 76.04 & 1937.11 & 669.06 70 | 2048 & 12288 & rowsync & 2964.39 & 76.04 & 1937.11 & 669.06 & 2744.44 & 31.55 & 7.42 71 | 2048 & 12288 & tilesync & 2964.39 & 76.04 & 1937.11 & 669.06 & 2859.67 & 21.12 & 3.53 72 | 2048 & 12288 & stridedsync & 2964.39 & 76.04 & 1937.11 & 669.06 & 2749.50 & 33.15 & 7.25 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/attention-results-llama: -------------------------------------------------------------------------------- 1 | 1 & 8192 & pytorch & 93.78 2 | 1 & 8192 & streamk & 206.62 & 186.83 & 393.44 3 | 1 & 8192 & baseline & 188.17 & 1.54 & 115.44 & 56.11 4 | 1 & 8192 & rowsync & 188.17 & 1.54 & 115.44 & 56.11 & 165.61 & 1.29 & 11.99 5 | 1 & 8192 & tilesync & 188.17 & 1.54 & 115.44 & 56.11 & 183.61 & 1.20 & 2.42 6 | 1 & 8192 & stridedsync & 188.17 & 1.54 & 115.44 & 56.11 & 179.72 & 0.96 & 4.49 7 | 2 & 8192 & pytorch & 121.61 8 | 2 & 8192 & streamk & 206.53 & 186.70 & 393.23 9 | 2 & 8192 & baseline & 188.11 & 1.68 & 115.72 & 56.17 10 | 2 & 8192 & rowsync & 188.11 & 1.68 & 115.72 & 56.17 & 165.83 & 0.86 & 11.84 11 | 2 & 8192 & tilesync & 188.11 & 1.68 & 115.72 & 56.17 & 184.78 & 1.26 & 1.77 12 | 2 & 8192 & stridedsync & 188.11 & 1.68 & 115.72 & 56.17 & 180.39 & 1.69 & 4.11 13 | 4 & 8192 & pytorch & 118.03 14 | 4 & 8192 & streamk & 206.37 & 186.33 & 392.70 15 | 4 & 8192 & baseline & 189.39 & 1.65 & 116.00 & 56.17 16 | 4 & 8192 & rowsync & 189.39 & 1.65 & 116.00 & 56.17 & 166.39 & 1.04 & 12.14 17 | 4 & 8192 & tilesync & 189.39 & 1.65 & 116.00 & 56.17 & 184.33 & 1.03 & 2.67 18 | 4 & 8192 & stridedsync & 189.39 & 1.65 & 116.00 & 56.17 & 181.06 & 1.16 & 4.40 19 | 8 & 8192 & pytorch & 122.72 20 | 8 & 8192 & streamk & 206.88 & 186.34 & 393.22 21 | 8 & 8192 & baseline & 189.61 & 1.24 & 115.83 & 56.94 22 | 8 & 8192 & rowsync & 189.61 & 1.24 & 115.83 & 56.94 & 167.06 & 1.63 & 11.90 23 | 8 & 8192 & tilesync & 189.61 & 1.24 & 115.83 & 56.94 & 185.06 & 1.06 & 2.40 24 | 8 & 8192 & stridedsync & 189.61 & 1.24 & 115.83 & 56.94 & 181.11 & 1.64 & 4.48 25 | 16 & 8192 & pytorch & 124.87 26 | 16 & 8192 & streamk & 206.83 & 186.52 & 393.35 27 | 16 & 8192 & baseline & 192.56 & 1.38 & 117.50 & 57.83 28 | 16 & 8192 & rowsync & 192.56 & 1.38 & 117.50 & 57.83 & 168.83 & 1.34 & 12.32 29 | 16 & 8192 & tilesync & 192.56 & 1.38 & 117.50 & 57.83 & 187.83 & 1.54 & 2.45 30 | 16 & 8192 & stridedsync & 192.56 & 1.38 & 117.50 & 57.83 & 182.56 & 
1.04 & 5.19 31 | 32 & 8192 & pytorch & 113.46 32 | 32 & 8192 & streamk & 206.83 & 186.92 & 393.75 33 | 32 & 8192 & baseline & 197.33 & 1.28 & 120.94 & 58.94 34 | 32 & 8192 & rowsync & 197.33 & 1.28 & 120.94 & 58.94 & 172.00 & 1.08 & 12.84 35 | 32 & 8192 & tilesync & 197.33 & 1.28 & 120.94 & 58.94 & 190.61 & 1.33 & 3.41 36 | 32 & 8192 & stridedsync & 197.33 & 1.28 & 120.94 & 58.94 & 184.00 & 1.41 & 6.76 37 | 64 & 8192 & pytorch & 119.50 38 | 64 & 8192 & streamk & 186.55 & 185.44 & 371.99 39 | 64 & 8192 & baseline & 210.89 & 1.28 & 117.44 & 66.56 40 | 64 & 8192 & rowsync & 210.89 & 1.28 & 117.44 & 66.56 & 185.11 & 0.90 & 12.22 41 | 64 & 8192 & tilesync & 210.89 & 1.28 & 117.44 & 66.56 & 184.78 & 1.83 & 12.38 42 | 64 & 8192 & stridedsync & 210.89 & 1.28 & 117.44 & 66.56 & 185.44 & 1.72 & 12.07 43 | 128 & 8192 & pytorch & 163.47 44 | 128 & 8192 & streamk & 187.49 & 151.65 & 339.15 45 | 128 & 8192 & baseline & 267.39 & 1.82 & 164.22 & 72.22 46 | 128 & 8192 & rowsync & 267.39 & 1.82 & 164.22 & 72.22 & 236.22 & 2.29 & 11.66 47 | 128 & 8192 & tilesync & 267.39 & 1.82 & 164.22 & 72.22 & 245.22 & 2.65 & 8.29 48 | 128 & 8192 & stridedsync & 267.39 & 1.82 & 164.22 & 72.22 & 246.56 & 2.01 & 7.79 49 | 256 & 8192 & pytorch & 247.74 50 | 256 & 8192 & streamk & 228.54 & 157.43 & 385.97 51 | 256 & 8192 & baseline & 363.78 & 1.99 & 217.22 & 101.72 52 | 256 & 8192 & rowsync & 363.78 & 1.99 & 217.22 & 101.72 & 318.33 & 4.43 & 12.49 53 | 256 & 8192 & tilesync & 363.78 & 1.99 & 217.22 & 101.72 & 328.11 & 7.65 & 9.80 54 | 256 & 8192 & stridedsync & 363.78 & 1.99 & 217.22 & 101.72 & 323.39 & 3.76 & 11.10 55 | 512 & 8192 & pytorch & 419.77 56 | 512 & 8192 & streamk & 425.30 & 325.94 & 751.24 57 | 512 & 8192 & baseline & 564.67 & 3.65 & 303.78 & 186.39 58 | 512 & 8192 & rowsync & 564.67 & 3.65 & 303.78 & 186.39 & 498.67 & 3.77 & 11.69 59 | 512 & 8192 & tilesync & 564.67 & 3.65 & 303.78 & 186.39 & 504.39 & 2.77 & 10.67 60 | 512 & 8192 & stridedsync & 564.67 & 3.65 & 303.78 & 186.39 & 501.94 & 3.78 & 11.11 61 | 1024 & 8192 & pytorch & 822.29 62 | 1024 & 8192 & streamk & 735.10 & 622.09 & 1357.19 63 | 1024 & 8192 & baseline & 932.72 & 3.20 & 547.00 & 253.44 64 | 1024 & 8192 & rowsync & 932.72 & 3.20 & 547.00 & 253.44 & 884.50 & 19.96 & 5.17 65 | 1024 & 8192 & tilesync & 932.72 & 3.20 & 547.00 & 253.44 & 897.94 & 10.30 & 3.73 66 | 1024 & 8192 & stridedsync & 932.72 & 3.20 & 547.00 & 253.44 & 878.89 & 17.69 & 5.77 67 | 2048 & 8192 & pytorch & 1486.02 68 | 2048 & 8192 & streamk & 1232.20 & 1185.89 & 2418.09 69 | 2048 & 8192 & baseline & 1653.00 & 4.78 & 962.44 & 446.22 70 | 2048 & 8192 & rowsync & 1653.00 & 4.78 & 962.44 & 446.22 & 1384.89 & 14.07 & 16.22 71 | 2048 & 8192 & tilesync & 1653.00 & 4.78 & 962.44 & 446.22 & 1408.44 & 11.72 & 14.79 72 | 2048 & 8192 & stridedsync & 1653.00 & 4.78 & 962.44 & 446.22 & 1395.78 & 9.57 & 15.56 73 | 74 | 75 | 76 | 77 | 78 | Attention results with updated cusync 79 | 256 & 8192 & rowsync & 318.78 & 1.40 & 167.28 & 52.39 & 351.78 & 4.40 & 1.35 80 | 256 & 8192 & tilesync & 318.78 & 1.40 & 167.28 & 52.39 & 326.94 & 1.63 & 2.56 81 | 256 & 8192 & stridedsync & 318.78 & 1.40 & 167.28 & 52.39 & 328.06 & 1.30 & 7.91 82 | 512 & 8192 & baseline & 408.11 & 2.00 & 222.61 & 52.72 83 | 512 & 8192 & rowsync & 408.11 & 2.00 & 222.61 & 52.72 & 418.44 & 2.87 & 2.53 84 | 512 & 8192 & tilesync & 408.11 & 2.00 & 222.61 & 52.72 & 418.28 & 2.14 & 5.49 85 | 512 & 8192 & stridedsync & 408.11 & 2.00 & 222.61 & 52.72 & 410.06 & 1.92 & 0.48 86 | 1024 & 8192 & baseline & 716.22 & 2.41 & 413.33 & 66.17 87 | 1024 & 
8192 & rowsync & 716.22 & 2.41 & 413.33 & 66.17 & 703.56 & 12.19 & 9.77 88 | 1024 & 8192 & tilesync & 716.22 & 2.41 & 413.33 & 66.17 & 701.72 & 7.65 & 8.02 89 | 1024 & 8192 & stridedsync & 716.22 & 2.41 & 413.33 & 66.17 & 697.94 & 8.08 & 10.55 90 | 2048 & 8192 & baseline & 1338.28 & 28.66 & 801.39 & 105.28 91 | 2048 & 8192 & rowsync & 1338.28 & 28.66 & 801.39 & 105.28 & 1178.22 & 12.59 & 11.96 92 | 2048 & 8192 & tilesync & 1338.28 & 28.66 & 801.39 & 105.28 & 1204.89 & 8.94 & 9.97 93 | 2048 & 8192 & stridedsync & 1338.28 & 28.66 & 801.39 & 105.28 & 1124.83 & 12.48 & 15.95 94 | 1 & 8192 & baseline & 166.67 & 1.03 & 81.56 & 27.44 95 | 1 & 8192 & rowsync & 166.67 & 1.03 & 81.56 & 27.44 & 147.83 & 2.15 & 11.30 96 | 1 & 8192 & tilesync & 166.67 & 1.03 & 81.56 & 27.44 & 149.22 & 2.16 & 10.47 97 | 1 & 8192 & stridedsync & 166.67 & 1.03 & 81.56 & 27.44 & 150.67 & 0.91 & 9.60 98 | 4 & 8192 & baseline & 176.06 & 1.39 & 85.00 & 29.17 99 | 4 & 8192 & rowsync & 176.06 & 1.39 & 85.00 & 29.17 & 156.89 & 1.32 & 10.89 100 | 4 & 8192 & tilesync & 176.06 & 1.39 & 85.00 & 29.17 & 159.61 & 1.91 & 9.34 101 | 4 & 8192 & stridedsync & 176.06 & 1.39 & 85.00 & 29.17 & 159.56 & 1.10 & 9.37 102 | 1 & 8192 & baseline & 184.83 & 0.99 & 83.44 & 30.94 103 | 1 & 8192 & rowsync & 184.83 & 0.99 & 83.44 & 30.94 & 166.94 & 1.76 & 9.68 104 | 1 & 8192 & tilesync & 184.83 & 0.99 & 83.44 & 30.94 & 168.94 & 1.76 & 8.60 105 | 1 & 8192 & stridedsync & 184.83 & 0.99 & 83.44 & 30.94 & 171.94 & 1.66 & 6.97 106 | 4 & 8192 & baseline & 184.94 & 1.00 & 84.33 & 30.56 107 | 4 & 8192 & rowsync & 184.94 & 1.00 & 84.33 & 30.56 & 167.50 & 1.69 & 9.43 108 | 4 & 8192 & tilesync & 184.94 & 1.00 & 84.33 & 30.56 & 171.44 & 1.82 & 7.30 109 | 4 & 8192 & stridedsync & 184.94 & 1.00 & 84.33 & 30.56 & 173.00 & 1.37 & 6.46 110 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/attention-stream-k-output: -------------------------------------------------------------------------------- 1 | 1 2 | 20 timing iterations of 1 x 4608 x 12288 matrix-matrix multiply 3 | 4 | Basic split-K GEMM with tile-splitting factor 4 5 | Avg runtime: 0.156845 ms 6 | GFLOPs: 722.027 7 | 8 | StreamK emulating Split-K GEMM with tile-splitting factor 4 9 | Avg runtime: 0.171227 ms 10 | GFLOPs: 661.38 11 | Speedup vs Basic-SplitK: 0.916 12 | 13 | 20 timing iterations of 1 x 12288 x 1536 matrix-matrix multiply 14 | 15 | Basic data-parallel GEMM 16 | Avg runtime: 0.067952 ms 17 | GFLOPs: 555.521 18 | 19 | StreamK GEMM with default load-balancing 20 | Avg runtime: 0.071456 ms 21 | GFLOPs: 528.279 22 | Speedup vs Basic-DP: 0.951 23 | 24 | StreamK emulating basic data-parallel GEMM 25 | Avg runtime: 0.0599296 ms 26 | GFLOPs: 629.885 27 | Speedup vs Basic-DP: 1.134 28 | 29 | Basic split-K GEMM with tile-splitting factor 2 30 | Avg runtime: 0.0685472 ms 31 | GFLOPs: 550.697 32 | 33 | StreamK emulating Split-K GEMM with tile-splitting factor 2 34 | Avg runtime: 0.0737888 ms 35 | GFLOPs: 511.578 36 | Speedup vs Basic-SplitK: 0.929 37 | 2 38 | 20 timing iterations of 2 x 4608 x 12288 matrix-matrix multiply 39 | 40 | Basic split-K GEMM with tile-splitting factor 4 41 | Avg runtime: 0.157358 ms 42 | GFLOPs: 1439.34 43 | 44 | StreamK emulating Split-K GEMM with tile-splitting factor 4 45 | Avg runtime: 0.171435 ms 46 | GFLOPs: 1321.15 47 | Speedup vs Basic-SplitK: 0.918 48 | 49 | 20 timing iterations of 2 x 12288 x 1536 matrix-matrix multiply 50 | 51 | Basic data-parallel GEMM 52 | Avg runtime: 0.0679008 ms 53 | GFLOPs: 1111.88 54 | 
55 | StreamK GEMM with default load-balancing 56 | Avg runtime: 0.0718288 ms 57 | GFLOPs: 1051.08 58 | Speedup vs Basic-DP: 0.945 59 | 60 | StreamK emulating basic data-parallel GEMM 61 | Avg runtime: 0.0600224 ms 62 | GFLOPs: 1257.82 63 | Speedup vs Basic-DP: 1.131 64 | 65 | Basic split-K GEMM with tile-splitting factor 2 66 | Avg runtime: 0.0686496 ms 67 | GFLOPs: 1099.75 68 | 69 | StreamK emulating Split-K GEMM with tile-splitting factor 2 70 | Avg runtime: 0.0745376 ms 71 | GFLOPs: 1012.88 72 | Speedup vs Basic-SplitK: 0.921 73 | 4 74 | 20 timing iterations of 4 x 4608 x 12288 matrix-matrix multiply 75 | 76 | Basic split-K GEMM with tile-splitting factor 4 77 | Avg runtime: 0.157414 ms 78 | GFLOPs: 2877.66 79 | 80 | StreamK emulating Split-K GEMM with tile-splitting factor 4 81 | Avg runtime: 0.171813 ms 82 | GFLOPs: 2636.5 83 | Speedup vs Basic-SplitK: 0.916 84 | 85 | 20 timing iterations of 4 x 12288 x 1536 matrix-matrix multiply 86 | 87 | Basic data-parallel GEMM 88 | Avg runtime: 0.0682352 ms 89 | GFLOPs: 2212.86 90 | 91 | StreamK GEMM with default load-balancing 92 | Avg runtime: 0.0718848 ms 93 | GFLOPs: 2100.51 94 | Speedup vs Basic-DP: 0.949 95 | 96 | StreamK emulating basic data-parallel GEMM 97 | Avg runtime: 0.0601536 ms 98 | GFLOPs: 2510.16 99 | Speedup vs Basic-DP: 1.134 100 | 101 | Basic split-K GEMM with tile-splitting factor 2 102 | Avg runtime: 0.069176 ms 103 | GFLOPs: 2182.76 104 | 105 | StreamK emulating Split-K GEMM with tile-splitting factor 2 106 | Avg runtime: 0.0743648 ms 107 | GFLOPs: 2030.46 108 | Speedup vs Basic-SplitK: 0.930 109 | 8 110 | 20 timing iterations of 8 x 4608 x 12288 matrix-matrix multiply 111 | 112 | Basic split-K GEMM with tile-splitting factor 4 113 | Avg runtime: 0.157418 ms 114 | GFLOPs: 5755.2 115 | 116 | StreamK emulating Split-K GEMM with tile-splitting factor 4 117 | Avg runtime: 0.171622 ms 118 | GFLOPs: 5278.85 119 | Speedup vs Basic-SplitK: 0.917 120 | 121 | 20 timing iterations of 8 x 12288 x 1536 matrix-matrix multiply 122 | 123 | Basic data-parallel GEMM 124 | Avg runtime: 0.0683248 ms 125 | GFLOPs: 4419.92 126 | 127 | StreamK GEMM with default load-balancing 128 | Avg runtime: 0.0721632 ms 129 | GFLOPs: 4184.82 130 | Speedup vs Basic-DP: 0.947 131 | 132 | StreamK emulating basic data-parallel GEMM 133 | Avg runtime: 0.0601952 ms 134 | GFLOPs: 5016.84 135 | Speedup vs Basic-DP: 1.135 136 | 137 | Basic split-K GEMM with tile-splitting factor 2 138 | Avg runtime: 0.0694816 ms 139 | GFLOPs: 4346.33 140 | 141 | StreamK emulating Split-K GEMM with tile-splitting factor 2 142 | Avg runtime: 0.0746464 ms 143 | GFLOPs: 4045.61 144 | Speedup vs Basic-SplitK: 0.931 145 | 16 146 | 20 timing iterations of 16 x 4608 x 12288 matrix-matrix multiply 147 | 148 | Basic split-K GEMM with tile-splitting factor 4 149 | Avg runtime: 0.157645 ms 150 | GFLOPs: 11493.8 151 | 152 | StreamK emulating Split-K GEMM with tile-splitting factor 4 153 | Avg runtime: 0.172155 ms 154 | GFLOPs: 10525 155 | Speedup vs Basic-SplitK: 0.916 156 | 157 | 20 timing iterations of 16 x 12288 x 1536 matrix-matrix multiply 158 | 159 | Basic data-parallel GEMM 160 | Avg runtime: 0.0683088 ms 161 | GFLOPs: 8841.9 162 | 163 | StreamK GEMM with default load-balancing 164 | Avg runtime: 0.0723744 ms 165 | GFLOPs: 8345.21 166 | Speedup vs Basic-DP: 0.944 167 | 168 | StreamK emulating basic data-parallel GEMM 169 | Avg runtime: 0.0606784 ms 170 | GFLOPs: 9953.79 171 | Speedup vs Basic-DP: 1.126 172 | 173 | Basic split-K GEMM with tile-splitting factor 2 174 | Avg runtime: 0.0700192 
ms 175 | GFLOPs: 8625.92 176 | 177 | StreamK emulating Split-K GEMM with tile-splitting factor 2 178 | Avg runtime: 0.0749888 ms 179 | GFLOPs: 8054.27 180 | Speedup vs Basic-SplitK: 0.934 181 | 32 182 | 20 timing iterations of 32 x 4608 x 12288 matrix-matrix multiply 183 | 184 | Basic split-K GEMM with tile-splitting factor 4 185 | Avg runtime: 0.159341 ms 186 | GFLOPs: 22742.9 187 | 188 | StreamK emulating Split-K GEMM with tile-splitting factor 4 189 | Avg runtime: 0.173766 ms 190 | GFLOPs: 20854.9 191 | Speedup vs Basic-SplitK: 0.917 192 | 193 | 20 timing iterations of 32 x 12288 x 1536 matrix-matrix multiply 194 | 195 | Basic data-parallel GEMM 196 | Avg runtime: 0.0690016 ms 197 | GFLOPs: 17506.3 198 | 199 | StreamK GEMM with default load-balancing 200 | Avg runtime: 0.0723408 ms 201 | GFLOPs: 16698.2 202 | Speedup vs Basic-DP: 0.954 203 | 204 | StreamK emulating basic data-parallel GEMM 205 | Avg runtime: 0.0617504 ms 206 | GFLOPs: 19562 207 | Speedup vs Basic-DP: 1.117 208 | 209 | Basic split-K GEMM with tile-splitting factor 2 210 | Avg runtime: 0.071152 ms 211 | GFLOPs: 16977.2 212 | 213 | StreamK emulating Split-K GEMM with tile-splitting factor 2 214 | Avg runtime: 0.0772352 ms 215 | GFLOPs: 15640 216 | Speedup vs Basic-SplitK: 0.921 217 | 64 218 | 20 timing iterations of 64 x 4608 x 12288 matrix-matrix multiply 219 | 220 | Basic split-K GEMM with tile-splitting factor 4 221 | Avg runtime: 0.173627 ms 222 | GFLOPs: 41743.2 223 | 224 | StreamK emulating Split-K GEMM with tile-splitting factor 4 225 | Avg runtime: 0.179035 ms 226 | GFLOPs: 40482.3 227 | Speedup vs Basic-SplitK: 0.970 228 | 229 | 20 timing iterations of 64 x 12288 x 1536 matrix-matrix multiply 230 | 231 | Basic data-parallel GEMM 232 | Avg runtime: 0.0715152 ms 233 | GFLOPs: 33781.9 234 | 235 | StreamK GEMM with default load-balancing 236 | Avg runtime: 0.0737152 ms 237 | GFLOPs: 32773.7 238 | Speedup vs Basic-DP: 0.970 239 | 240 | StreamK emulating basic data-parallel GEMM 241 | Avg runtime: 0.0648944 ms 242 | GFLOPs: 37228.5 243 | Speedup vs Basic-DP: 1.102 244 | 245 | Basic split-K GEMM with tile-splitting factor 2 246 | Avg runtime: 0.0741392 ms 247 | GFLOPs: 32586.3 248 | 249 | StreamK emulating Split-K GEMM with tile-splitting factor 2 250 | Avg runtime: 0.079552 ms 251 | GFLOPs: 30369.1 252 | Speedup vs Basic-SplitK: 0.932 253 | 128 254 | 20 timing iterations of 128 x 4608 x 12288 matrix-matrix multiply 255 | 256 | Basic split-K GEMM with tile-splitting factor 4 257 | Avg runtime: 0.190134 ms 258 | GFLOPs: 76238.3 259 | 260 | StreamK emulating Split-K GEMM with tile-splitting factor 4 261 | Avg runtime: 0.243024 ms 262 | GFLOPs: 59646.4 263 | Speedup vs Basic-SplitK: 0.782 264 | 265 | 20 timing iterations of 128 x 12288 x 1536 matrix-matrix multiply 266 | 267 | Basic data-parallel GEMM 268 | Avg runtime: 0.094184 ms 269 | GFLOPs: 51302.1 270 | 271 | StreamK GEMM with default load-balancing 272 | Avg runtime: 0.102627 ms 273 | GFLOPs: 47081.5 274 | Speedup vs Basic-DP: 0.918 275 | 276 | StreamK emulating basic data-parallel GEMM 277 | Avg runtime: 0.0931648 ms 278 | GFLOPs: 51863.3 279 | Speedup vs Basic-DP: 1.011 280 | 281 | Basic split-K GEMM with tile-splitting factor 2 282 | Avg runtime: 0.103139 ms 283 | GFLOPs: 46847.7 284 | 285 | StreamK emulating Split-K GEMM with tile-splitting factor 2 286 | Avg runtime: 0.131126 ms 287 | GFLOPs: 36848.7 288 | Speedup vs Basic-SplitK: 0.787 289 | 290 | 256 291 | 20 timing iterations of 256 x 4608 x 12288 matrix-matrix multiply 292 | 293 | Basic split-K GEMM with 
tile-splitting factor 4 294 | Avg runtime: 0.293355 ms 295 | GFLOPs: 98825.7 296 | 297 | StreamK emulating Split-K GEMM with tile-splitting factor 4 298 | Avg runtime: 9.40603 ms 299 | GFLOPs: 3082.17 300 | Speedup vs Basic-SplitK: 0.031 301 | 302 | 20 timing iterations of 256 x 12288 x 1536 matrix-matrix multiply 303 | 304 | Basic data-parallel GEMM 305 | Avg runtime: 0.14753 ms 306 | GFLOPs: 65503.3 307 | 308 | StreamK GEMM with default load-balancing 309 | Avg runtime: 2.88446 ms 310 | GFLOPs: 3350.25 311 | Speedup vs Basic-DP: 0.051 312 | 313 | StreamK emulating basic data-parallel GEMM 314 | Avg runtime: 0.142814 ms 315 | GFLOPs: 67666 316 | Speedup vs Basic-DP: 1.033 317 | 318 | Basic split-K GEMM with tile-splitting factor 2 319 | Avg runtime: 0.147571 ms 320 | GFLOPs: 65484.8 321 | 322 | StreamK emulating Split-K GEMM with tile-splitting factor 2 323 | Avg runtime: 3.35184 ms 324 | GFLOPs: 2883.1 325 | Speedup vs Basic-SplitK: 0.044 326 | 327 | 512 328 | 20 timing iterations of 512 x 4608 x 12288 matrix-matrix multiply 329 | 330 | Basic split-K GEMM with tile-splitting factor 2 331 | Avg runtime: 0.486846 ms 332 | GFLOPs: 119097 333 | 334 | StreamK emulating Split-K GEMM with tile-splitting factor 2 335 | Avg runtime: 16.9325 ms 336 | GFLOPs: 3424.31 337 | Speedup vs Basic-SplitK: 0.029 338 | 339 | 20 timing iterations of 512 x 12288 x 1536 matrix-matrix multiply 340 | 341 | Basic data-parallel GEMM 342 | Avg runtime: 0.229642 ms 343 | GFLOPs: 84163.1 344 | 345 | StreamK GEMM with default load-balancing 346 | Avg runtime: 5.91028 ms 347 | GFLOPs: 3270.12 348 | Speedup vs Basic-DP: 0.039 349 | 350 | StreamK emulating basic data-parallel GEMM 351 | Avg runtime: 0.219325 ms 352 | GFLOPs: 88122.1 353 | Speedup vs Basic-DP: 1.047 354 | 355 | Basic split-K GEMM with tile-splitting factor 2 356 | Avg runtime: 0.24607 ms 357 | GFLOPs: 78544 358 | 359 | StreamK emulating Split-K GEMM with tile-splitting factor 2 360 | Avg runtime: 5.94693 ms 361 | GFLOPs: 3249.97 362 | Speedup vs Basic-SplitK: 0.041 363 | 364 | 1024 365 | 20 timing iterations of 1024 x 4608 x 12288 matrix-matrix multiply 366 | 367 | Basic data-parallel GEMM 368 | Avg runtime: 0.918547 ms 369 | GFLOPs: 126247 370 | 371 | StreamK GEMM with default load-balancing 372 | Avg runtime: 36.0473 ms 373 | GFLOPs: 3216.99 374 | Speedup vs Basic-DP: 0.025 375 | 376 | StreamK emulating basic data-parallel GEMM 377 | Avg runtime: 0.848008 ms 378 | GFLOPs: 136749 379 | Speedup vs Basic-DP: 1.083 380 | 381 | Basic split-K GEMM with tile-splitting factor 2 382 | Avg runtime: 0.872634 ms 383 | GFLOPs: 132890 384 | 385 | StreamK emulating Split-K GEMM with tile-splitting factor 2 386 | Avg runtime: 35.1125 ms 387 | GFLOPs: 3302.65 388 | Speedup vs Basic-SplitK: 0.025 389 | 390 | 20 timing iterations of 1024 x 12288 x 1536 matrix-matrix multiply 391 | 392 | Basic data-parallel GEMM 393 | Avg runtime: 0.371226 ms 394 | GFLOPs: 104127 395 | 396 | StreamK GEMM with default load-balancing 397 | Avg runtime: 11.6917 ms 398 | GFLOPs: 3306.15 399 | Speedup vs Basic-DP: 0.032 400 | 401 | StreamK emulating basic data-parallel GEMM 402 | Avg runtime: 0.375458 ms 403 | GFLOPs: 102954 404 | Speedup vs Basic-DP: 0.989 405 | 406 | Basic split-K GEMM with tile-splitting factor 2 407 | Avg runtime: 0.408646 ms 408 | GFLOPs: 94592.1 409 | 410 | StreamK emulating Split-K GEMM with tile-splitting factor 2 411 | Avg runtime: 11.7896 ms 412 | GFLOPs: 3278.71 413 | Speedup vs Basic-SplitK: 0.035 414 | 2048 415 | 20 timing iterations of 2048 x 4608 x 12288 
matrix-matrix multiply 416 | 417 | Basic data-parallel GEMM 418 | Avg runtime: 1.68713 ms 419 | GFLOPs: 137469 420 | 421 | StreamK GEMM with default load-balancing 422 | Avg runtime: 74.6993 ms 423 | GFLOPs: 3104.82 424 | Speedup vs Basic-DP: 0.023 425 | 426 | StreamK emulating basic data-parallel GEMM 427 | Avg runtime: 1.71081 ms 428 | GFLOPs: 135566 429 | Speedup vs Basic-DP: 0.986 430 | 431 | Basic split-K GEMM with tile-splitting factor 2 432 | Avg runtime: 1.75267 ms 433 | GFLOPs: 132328 434 | 435 | StreamK emulating Split-K GEMM with tile-splitting factor 2 436 | Avg runtime: 74.1262 ms 437 | GFLOPs: 3128.83 438 | Speedup vs Basic-SplitK: 0.024 439 | 440 | 20 timing iterations of 2048 x 12288 x 1536 matrix-matrix multiply 441 | 442 | Basic data-parallel GEMM 443 | Avg runtime: 0.676182 ms 444 | GFLOPs: 114332 445 | 446 | StreamK GEMM with default load-balancing 447 | Avg runtime: 27.0026 ms 448 | GFLOPs: 2863.04 449 | Speedup vs Basic-DP: 0.025 450 | 451 | StreamK emulating basic data-parallel GEMM 452 | Avg runtime: 26.7472 ms 453 | GFLOPs: 2890.37 454 | Speedup vs Basic-DP: 0.025 455 | 456 | Basic split-K GEMM with tile-splitting factor 2 457 | Avg runtime: 0.83525 ms 458 | GFLOPs: 92558.5 459 | 460 | StreamK emulating Split-K GEMM with tile-splitting factor 2 461 | Avg runtime: 27.1901 ms 462 | GFLOPs: 2843.3 463 | Speedup vs Basic-SplitK: 0.031 464 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-gpt3-a100.csv: -------------------------------------------------------------------------------- 1 | 512 & 0 & 12288 & torch & 724.75 2 | 512 & 0 & 12288 & streamk & 718.44 & 345.46 & 372.97 3 | 512 & 0 & 12288 & baseline & 743.08 & 0.79 & 371.54 & 371.54 4 | 512 & 0 & 12288 & rowsync & 743.08 & 0.79 & 371.54 & 371.54 & 752.36 & 0.59 & -1.25 5 | 512 & 0 & 12288 & tilesync & 743.08 & 0.79 & 371.54 & 371.54 & 780.63 & 0.70 & -5.05 6 | 768 & 0 & 12288 & torch & 1061.34 7 | 768 & 0 & 12288 & streamk & 1051.60 & 541.13 & 510.47 8 | 768 & 0 & 12288 & baseline & 1098.70 & 0.96 & 545.17 & 553.53 9 | 768 & 0 & 12288 & rowsync & 1098.70 & 0.96 & 545.17 & 553.53 & 1089.48 & 1.38 & 0.84 10 | 768 & 0 & 12288 & tilesync & 1098.70 & 0.96 & 545.17 & 553.53 & 1140.05 & 1.79 & -3.76 11 | 1024 & 0 & 12288 & torch & 1459.14 12 | 1024 & 0 & 12288 & streamk & 1453.23 & 785.38 & 667.86 13 | 1024 & 0 & 12288 & baseline & 1516.71 & 0.72 & 786.60 & 730.11 14 | 1024 & 0 & 12288 & rowsync & 1516.71 & 0.72 & 786.60 & 730.11 & 1409.54 & 1.33 & 7.07 15 | 1024 & 0 & 12288 & tilesync & 1516.71 & 0.72 & 786.60 & 730.11 & 1540.21 & 1.05 & -1.55 16 | 1280 & 0 & 12288 & torch & 1797.79 17 | 1280 & 0 & 12288 & streamk & 1752.33 & 923.19 & 829.14 18 | 1280 & 0 & 12288 & baseline & 1831.65 & 1.21 & 920.80 & 910.85 19 | 1280 & 0 & 12288 & rowsync & 1831.65 & 1.21 & 920.80 & 910.85 & 1505.39 & 0.92 & 17.81 20 | 1280 & 0 & 12288 & tilesync & 1831.65 & 1.21 & 920.80 & 910.85 & 1653.19 & 1.66 & 9.74 21 | 1536 & 0 & 12288 & torch & 2111.18 22 | 1536 & 0 & 12288 & streamk & 2114.73 & 1127.45 & 987.28 23 | 1536 & 0 & 12288 & baseline & 2218.84 & 0.88 & 1128.05 & 1090.79 24 | 1536 & 0 & 12288 & rowsync & 2218.84 & 0.88 & 1128.05 & 1090.79 & 1721.00 & 1.11 & 22.44 25 | 1536 & 0 & 12288 & tilesync & 2218.84 & 0.88 & 1128.05 & 1090.79 & 1802.92 & 2.02 & 18.74 26 | 1792 & 0 & 12288 & torch & 2515.16 27 | 1792 & 0 & 12288 & streamk & 2458.90 & 1316.93 & 1141.97 28 | 1792 & 0 & 12288 & baseline & 2587.19 & 0.52 & 1319.14 & 1268.05 29 | 1792 & 0 & 12288 & rowsync & 
2587.19 & 0.52 & 1319.14 & 1268.05 & 1997.48 & 0.93 & 22.79 30 | 1792 & 0 & 12288 & tilesync & 2587.19 & 0.52 & 1319.14 & 1268.05 & 2080.54 & 0.83 & 19.58 31 | 2048 & 0 & 12288 & torch & 2828.45 32 | 2048 & 0 & 12288 & streamk & 2640.32 & 1340.61 & 1299.71 33 | 2048 & 0 & 12288 & baseline & 2893.82 & 0.79 & 1445.09 & 1448.73 34 | 2048 & 0 & 12288 & rowsync & 2893.82 & 0.79 & 1445.09 & 1448.73 & 2315.78 & 1.96 & 19.98 35 | 2048 & 0 & 12288 & tilesync & 2893.82 & 0.79 & 1445.09 & 1448.73 & 2644.31 & 11.26 & 8.62 36 | 2048 & 0 & 12288 & torch & 2829.39 37 | 2048 & 0 & 12288 & streamk & 2639.25 & 1339.94 & 1299.31 38 | 2048 & 0 & 12288 & baseline & 2894.22 & 0.71 & 1445.15 & 1449.07 39 | 2048 & 0 & 12288 & rowsync & 2894.22 & 0.71 & 1445.15 & 1449.07 & 2317.14 & 1.33 & 19.94 40 | 2048 & 0 & 12288 & tilesync & 2894.22 & 0.71 & 1445.15 & 1449.07 & 2649.77 & 14.84 & 8.45 41 | 2304 & 0 & 12288 & torch & 2858.94 42 | 2304 & 0 & 12288 & streamk & 2888.32 & 1443.69 & 1444.63 43 | 2304 & 0 & 12288 & baseline & 2901.33 & 1.16 & 1448.73 & 1452.60 44 | 2304 & 0 & 12288 & rowsync & 2901.33 & 1.16 & 1448.73 & 1452.60 & 2479.67 & 1.93 & 14.53 45 | 2304 & 0 & 12288 & tilesync & 2901.33 & 1.16 & 1448.73 & 1452.60 & 2689.37 & 1.40 & 7.31 46 | 2560 & 0 & 12288 & torch & 3553.63 47 | 2560 & 0 & 12288 & streamk & 3308.09 & 1640.29 & 1667.80 48 | 2560 & 0 & 12288 & baseline & 3027.68 & 158.23 & 1587.48 & 1440.20 49 | 2560 & 0 & 12288 & rowsync & 3027.68 & 158.23 & 1587.48 & 1440.20 & 2790.57 & 2.70 & 7.83 50 | 2560 & 0 & 12288 & tilesync & 3027.68 & 158.23 & 1587.48 & 1440.20 & 3084.34 & 1.77 & -1.87 51 | 2816 & 0 & 12288 & torch & 3558.92 52 | 2816 & 0 & 12288 & streamk & 3691.54 & 1884.34 & 1807.20 53 | 2816 & 0 & 12288 & baseline & 3619.21 & 0.80 & 1806.85 & 1812.37 54 | 2816 & 0 & 12288 & rowsync & 3619.21 & 0.80 & 1806.85 & 1812.37 & 3093.73 & 2.06 & 14.52 55 | 2816 & 0 & 12288 & tilesync & 3619.21 & 0.80 & 1806.85 & 1812.37 & 3382.61 & 1.40 & 6.54 56 | 3072 & 0 & 12288 & torch & 4138.26 57 | 3072 & 0 & 12288 & streamk & 3872.23 & 1934.76 & 1937.47 58 | 3072 & 0 & 12288 & baseline & 3428.81 & 0.80 & 1778.23 & 1650.57 59 | 3072 & 0 & 12288 & rowsync & 3428.81 & 0.80 & 1778.23 & 1650.57 & 3402.47 & 1.56 & 0.77 60 | 3072 & 0 & 12288 & tilesync & 3428.81 & 0.80 & 1778.23 & 1650.57 & 3706.48 & 2.58 & -8.10 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-gpt3-v100.csv: -------------------------------------------------------------------------------- 1 | 1 & 0 & 12288 & torch & 467.86 2 | 1 & 0 & 12288 & baseline & 390.03 & 1.21 & 202.30 & 187.74 3 | 1 & 0 & 12288 & rowsync & 390.03 & 1.21 & 202.30 & 187.74 & 387.01 & 1.34 & 0.77 4 | 1 & 0 & 12288 & tilesync & 390.03 & 1.21 & 202.30 & 187.74 & 374.10 & 1.35 & 4.09 5 | 2 & 0 & 12288 & torch & 385.15 6 | 2 & 0 & 12288 & baseline & 389.46 & 1.06 & 201.95 & 187.51 7 | 2 & 0 & 12288 & rowsync & 389.46 & 1.06 & 201.95 & 187.51 & 387.47 & 1.27 & 0.51 8 | 2 & 0 & 12288 & tilesync & 389.46 & 1.06 & 201.95 & 187.51 & 374.89 & 1.10 & 3.74 9 | 4 & 0 & 12288 & torch & 387.03 10 | 4 & 0 & 12288 & baseline & 392.14 & 1.99 & 203.78 & 188.37 11 | 4 & 0 & 12288 & rowsync & 392.14 & 1.99 & 203.78 & 188.37 & 387.86 & 1.63 & 1.09 12 | 4 & 0 & 12288 & tilesync & 392.14 & 1.99 & 203.78 & 188.37 & 375.65 & 1.51 & 4.21 13 | 8 & 0 & 12288 & torch & 388.80 14 | 8 & 0 & 12288 & baseline & 392.12 & 1.29 & 203.59 & 188.53 15 | 8 & 0 & 12288 & rowsync & 392.12 & 1.29 & 203.59 & 188.53 & 390.25 & 1.53 & 0.48 16 | 8 & 0 & 12288 & 
tilesync & 392.12 & 1.29 & 203.59 & 188.53 & 377.00 & 1.45 & 3.86 17 | 16 & 0 & 12288 & torch & 403.80 18 | 16 & 0 & 12288 & baseline & 395.20 & 1.59 & 205.24 & 189.96 19 | 16 & 0 & 12288 & rowsync & 395.20 & 1.59 & 205.24 & 189.96 & 392.41 & 1.74 & 0.71 20 | 16 & 0 & 12288 & tilesync & 395.20 & 1.59 & 205.24 & 189.96 & 379.27 & 1.86 & 4.03 21 | 32 & 0 & 12288 & torch & 417.94 22 | 32 & 0 & 12288 & baseline & 403.90 & 1.63 & 211.39 & 192.52 23 | 32 & 0 & 12288 & rowsync & 403.90 & 1.63 & 211.39 & 192.52 & 399.70 & 2.15 & 1.04 24 | 32 & 0 & 12288 & tilesync & 403.90 & 1.63 & 211.39 & 192.52 & 386.62 & 1.83 & 4.28 25 | 64 & 0 & 12288 & torch & 428.16 26 | 64 & 0 & 12288 & baseline & 425.98 & 2.16 & 227.10 & 198.88 27 | 64 & 0 & 12288 & rowsync & 425.98 & 2.16 & 227.10 & 198.88 & 422.12 & 1.66 & 0.91 28 | 64 & 0 & 12288 & tilesync & 425.98 & 2.16 & 227.10 & 198.88 & 416.93 & 3.02 & 2.12 29 | 128 & 0 & 12288 & torch & 562.41 30 | 128 & 0 & 12288 & baseline & 540.50 & 2.83 & 310.44 & 230.06 31 | 128 & 0 & 12288 & rowsync & 540.50 & 2.83 & 310.44 & 230.06 & 529.24 & 3.33 & 2.08 32 | 128 & 0 & 12288 & tilesync & 540.50 & 2.83 & 310.44 & 230.06 & 536.24 & 3.06 & 0.79 33 | 256 & 0 & 12288 & torch & 948.18 34 | 256 & 0 & 12288 & baseline & 853.79 & 1.88 & 443.11 & 410.68 35 | 256 & 0 & 12288 & rowsync & 853.79 & 1.88 & 443.11 & 410.68 & 852.60 & 1.92 & 0.14 36 | 256 & 0 & 12288 & tilesync & 853.79 & 1.88 & 443.11 & 410.68 & 806.63 & 5.68 & 5.52 37 | 512 & 0 & 12288 & torch & 1756.98 38 | 512 & 0 & 12288 & baseline & 1564.05 & 11.58 & 805.49 & 758.55 39 | 512 & 0 & 12288 & rowsync & 1564.05 & 11.58 & 805.49 & 758.55 & 1267.48 & 18.19 & 18.96 40 | 512 & 0 & 12288 & tilesync & 1564.05 & 11.58 & 805.49 & 758.55 & 1324.60 & 16.54 & 15.31 41 | 768 & 0 & 12288 & torch & 2692.21 42 | 768 & 0 & 12288 & baseline & 1593.92 & 5.97 & 781.14 & 812.78 43 | 768 & 0 & 12288 & rowsync & 1593.92 & 5.97 & 781.14 & 812.78 & 1550.05 & 61.92 & 2.75 44 | 768 & 0 & 12288 & tilesync & 1593.92 & 5.97 & 781.14 & 812.78 & 1655.58 & 14.20 & -3.87 45 | 1024 & 0 & 12288 & torch & 3364.33 46 | 1024 & 0 & 12288 & baseline & 2620.87 & 14.22 & 1446.18 & 1174.70 47 | 1024 & 0 & 12288 & rowsync & 2620.87 & 14.22 & 1446.18 & 1174.70 & 1985.65 & 25.77 & 24.24 48 | 1024 & 0 & 12288 & tilesync & 2620.87 & 14.22 & 1446.18 & 1174.70 & 2076.84 & 15.07 & 20.76 49 | 1280 & 0 & 12288 & torch & 3823.95 50 | 1280 & 0 & 12288 & baseline & 2696.53 & 13.20 & 1474.33 & 1222.20 51 | 1280 & 0 & 12288 & rowsync & 2696.53 & 13.20 & 1474.33 & 1222.20 & 2623.54 & 38.35 & 2.71 52 | 1280 & 0 & 12288 & tilesync & 2696.53 & 13.20 & 1474.33 & 1222.20 & 2647.50 & 37.11 & 1.82 53 | 1536 & 0 & 12288 & torch & 4856.49 54 | 1536 & 0 & 12288 & baseline & 3018.52 & 80.35 & 1462.49 & 1556.03 55 | 1536 & 0 & 12288 & rowsync & 3018.52 & 80.35 & 1462.49 & 1556.03 & 3045.94 & 85.45 & -0.91 56 | 1536 & 0 & 12288 & tilesync & 3018.52 & 80.35 & 1462.49 & 1556.03 & 3158.86 & 84.37 & -4.65 57 | 1792 & 0 & 12288 & torch & 5836.56 58 | 1792 & 0 & 12288 & baseline & 3970.22 & 76.14 & 2037.87 & 1932.34 59 | 1792 & 0 & 12288 & rowsync & 3970.22 & 76.14 & 2037.87 & 1932.34 & 3509.64 & 41.67 & 11.60 60 | 1792 & 0 & 12288 & tilesync & 3970.22 & 76.14 & 2037.87 & 1932.34 & 3572.28 & 98.11 & 10.02 61 | 2048 & 0 & 12288 & torch & 6256.99 62 | 2048 & 0 & 12288 & baseline & 3945.47 & 20.67 & 2015.69 & 1929.78 63 | 2048 & 0 & 12288 & rowsync & 3945.47 & 20.67 & 2015.69 & 1929.78 & 3968.68 & 24.50 & -0.59 64 | 2048 & 0 & 12288 & tilesync & 3945.47 & 20.67 & 2015.69 & 1929.78 & 4105.39 & 9.27 & 
-4.05 65 | 2304 & 0 & 12288 & torch & 7385.60 66 | 2304 & 0 & 12288 & baseline & 4321.45 & 20.84 & 2030.59 & 2290.86 67 | 2304 & 0 & 12288 & rowsync & 4321.45 & 20.84 & 2030.59 & 2290.86 & 4398.71 & 41.33 & -1.79 68 | 2304 & 0 & 12288 & tilesync & 4321.45 & 20.84 & 2030.59 & 2290.86 & 4584.84 & 22.18 & -6.09 69 | 2560 & 0 & 12288 & torch & 7461.47 70 | 2560 & 0 & 12288 & baseline & 4679.57 & 31.51 & 2209.00 & 2470.57 71 | 2560 & 0 & 12288 & rowsync & 4679.57 & 31.51 & 2209.00 & 2470.57 & 4476.47 & 48.93 & 4.34 72 | 2560 & 0 & 12288 & tilesync & 4679.57 & 31.51 & 2209.00 & 2470.57 & 4573.87 & 30.00 & 2.26 73 | 2816 & 0 & 12288 & torch & 8559.75 74 | 2816 & 0 & 12288 & baseline & 5709.59 & 146.59 & 2856.90 & 2852.69 75 | 2816 & 0 & 12288 & rowsync & 5709.59 & 146.59 & 2856.90 & 2852.69 & 5498.07 & 27.78 & 3.70 76 | 2816 & 0 & 12288 & tilesync & 5709.59 & 146.59 & 2856.90 & 2852.69 & 5531.01 & 193.07 & 3.13 77 | 3072 & 0 & 12288 & torch & 9335.36 78 | 3072 & 0 & 12288 & baseline & 6006.72 & 176.30 & 2825.05 & 3181.67 79 | 3072 & 0 & 12288 & rowsync & 6006.72 & 176.30 & 2825.05 & 3181.67 & 6097.52 & 148.36 & -1.51 80 | 3072 & 0 & 12288 & tilesync & 6006.72 & 176.30 & 2825.05 & 3181.67 & 5941.99 & 181.35 & 1.08 81 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-llama-a100.csv: -------------------------------------------------------------------------------- 1 | 512 & 0 & 8192 & torch & 400.67 2 | 512 & 0 & 8192 & streamk & 352.16 & 217.67 & 134.50 3 | 512 & 0 & 8192 & baseline & 408.46 & 1.16 & 248.43 & 160.03 4 | 512 & 0 & 8192 & rowsync & 408.46 & 1.16 & 248.43 & 160.03 & 407.21 & 19.69 & 0.31 5 | 512 & 0 & 8192 & tilesync & 408.46 & 1.16 & 248.43 & 160.03 & 413.13 & 25.31 & -1.14 6 | 768 & 0 & 8192 & torch & 635.80 7 | 768 & 0 & 8192 & streamk & 570.20 & 386.95 & 183.25 8 | 768 & 0 & 8192 & baseline & 564.68 & 0.80 & 380.64 & 184.04 9 | 768 & 0 & 8192 & rowsync & 564.68 & 0.80 & 380.64 & 184.04 & 526.90 & 0.72 & 6.69 10 | 768 & 0 & 8192 & tilesync & 564.68 & 0.80 & 380.64 & 184.04 & 549.66 & 0.97 & 2.66 11 | 1024 & 0 & 8192 & torch & 695.20 12 | 1024 & 0 & 8192 & streamk & 682.67 & 454.33 & 228.34 13 | 1024 & 0 & 8192 & baseline & 721.35 & 1.28 & 468.02 & 253.33 14 | 1024 & 0 & 8192 & rowsync & 721.35 & 1.28 & 468.02 & 253.33 & 688.24 & 3.31 & 4.59 15 | 1024 & 0 & 8192 & tilesync & 721.35 & 1.28 & 468.02 & 253.33 & 721.81 & 2.30 & -0.06 16 | 1280 & 0 & 8192 & torch & 829.83 17 | 1280 & 0 & 8192 & streamk & 847.23 & 571.96 & 275.27 18 | 1280 & 0 & 8192 & baseline & 848.10 & 1.14 & 569.63 & 278.47 19 | 1280 & 0 & 8192 & rowsync & 848.10 & 1.14 & 569.63 & 278.47 & 802.53 & 1.52 & 5.37 20 | 1280 & 0 & 8192 & tilesync & 848.10 & 1.14 & 569.63 & 278.47 & 837.86 & 1.74 & 1.21 21 | 1536 & 0 & 8192 & torch & 960.80 22 | 1536 & 0 & 8192 & streamk & 979.70 & 648.29 & 331.41 23 | 1536 & 0 & 8192 & baseline & 999.42 & 0.93 & 644.32 & 355.10 24 | 1536 & 0 & 8192 & rowsync & 999.42 & 0.93 & 644.32 & 355.10 & 972.57 & 1.43 & 2.69 25 | 1536 & 0 & 8192 & tilesync & 999.42 & 0.93 & 644.32 & 355.10 & 998.46 & 1.08 & 0.10 26 | 1792 & 0 & 8192 & torch & 1176.22 27 | 1792 & 0 & 8192 & streamk & 1164.36 & 785.89 & 378.46 28 | 1792 & 0 & 8192 & baseline & 1163.09 & 0.95 & 723.80 & 439.30 29 | 1792 & 0 & 8192 & rowsync & 1163.09 & 0.95 & 723.80 & 439.30 & 1044.08 & 1.00 & 10.23 30 | 1792 & 0 & 8192 & tilesync & 1163.09 & 0.95 & 723.80 & 439.30 & 1072.36 & 0.90 & 7.80 31 | 2048 & 0 & 8192 & torch & 1358.58 32 | 2048 & 0 & 8192 & streamk & 1316.51 
& 884.98 & 431.52 33 | 2048 & 0 & 8192 & baseline & 1360.78 & 2.13 & 917.73 & 443.05 34 | 2048 & 0 & 8192 & rowsync & 1360.78 & 2.13 & 917.73 & 443.05 & 1292.46 & 7.21 & 5.02 35 | 2048 & 0 & 8192 & tilesync & 1360.78 & 2.13 & 917.73 & 443.05 & 1200.13 & 5.26 & 11.81 36 | 2304 & 0 & 8192 & torch & 1663.80 37 | 2304 & 0 & 8192 & streamk & 1417.61 & 920.61 & 496.99 38 | 2304 & 0 & 8192 & baseline & 1534.92 & 1.19 & 992.20 & 542.72 39 | 2304 & 0 & 8192 & rowsync & 1534.92 & 1.19 & 992.20 & 542.72 & 1287.85 & 1.53 & 16.10 40 | 2304 & 0 & 8192 & tilesync & 1534.92 & 1.19 & 992.20 & 542.72 & 1352.87 & 1.66 & 11.86 41 | 2560 & 0 & 8192 & torch & 1604.40 42 | 2560 & 0 & 8192 & streamk & 1562.22 & 1019.20 & 543.02 43 | 2560 & 0 & 8192 & baseline & 1779.66 & 0.89 & 1233.29 & 546.36 44 | 2560 & 0 & 8192 & rowsync & 1779.66 & 0.89 & 1233.29 & 546.36 & 1697.91 & 2.60 & 4.59 45 | 2560 & 0 & 8192 & tilesync & 1779.66 & 0.89 & 1233.29 & 546.36 & 1580.20 & 3.57 & 11.21 46 | 2816 & 0 & 8192 & torch & 1706.54 47 | 2816 & 0 & 8192 & streamk & 1723.32 & 1120.49 & 602.83 48 | 2816 & 0 & 8192 & baseline & 1870.51 & 0.93 & 1236.99 & 633.51 49 | 2816 & 0 & 8192 & rowsync & 1870.51 & 0.93 & 1236.99 & 633.51 & 1533.10 & 1.90 & 18.04 50 | 2816 & 0 & 8192 & tilesync & 1870.51 & 0.93 & 1236.99 & 633.51 & 1635.95 & 1.76 & 12.54 51 | 3072 & 0 & 8192 & torch & 1841.50 52 | 3072 & 0 & 8192 & streamk & 1886.01 & 1232.88 & 653.13 53 | 3072 & 0 & 8192 & baseline & 1957.43 & 0.88 & 1239.44 & 717.99 54 | 3072 & 0 & 8192 & rowsync & 1957.43 & 0.88 & 1239.44 & 717.99 & 1623.44 & 1.37 & 17.06 55 | 3072 & 0 & 8192 & tilesync & 1957.43 & 0.88 & 1239.44 & 717.99 & 1707.18 & 2.23 & 12.78 56 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-llama-v100.csv: -------------------------------------------------------------------------------- 1 | 1 & 0 & 8192 & torch & 320.41 2 | 1 & 0 & 8192 & baseline & 193.64 & 1.49 & 123.33 & 70.31 3 | 1 & 0 & 8192 & rowsync & 193.64 & 1.49 & 123.33 & 70.31 & 193.14 & 1.40 & 0.26 4 | 1 & 0 & 8192 & tilesync & 193.64 & 1.49 & 123.33 & 70.31 & 184.94 & 1.23 & 4.49 5 | 2 & 0 & 8192 & torch & 2020.41 6 | 2 & 0 & 8192 & baseline & 193.71 & 1.23 & 123.16 & 70.55 7 | 2 & 0 & 8192 & rowsync & 193.71 & 1.23 & 123.16 & 70.55 & 193.99 & 1.23 & -0.15 8 | 2 & 0 & 8192 & tilesync & 193.71 & 1.23 & 123.16 & 70.55 & 185.85 & 1.57 & 4.06 9 | 4 & 0 & 8192 & torch & 610.28 10 | 4 & 0 & 8192 & baseline & 194.28 & 0.85 & 123.39 & 70.89 11 | 4 & 0 & 8192 & rowsync & 194.28 & 0.85 & 123.39 & 70.89 & 193.93 & 1.54 & 0.18 12 | 4 & 0 & 8192 & tilesync & 194.28 & 0.85 & 123.39 & 70.89 & 185.46 & 1.67 & 4.54 13 | 8 & 0 & 8192 & torch & 604.44 14 | 8 & 0 & 8192 & baseline & 196.20 & 1.49 & 124.18 & 72.02 15 | 8 & 0 & 8192 & rowsync & 196.20 & 1.49 & 124.18 & 72.02 & 195.98 & 1.79 & 0.11 16 | 8 & 0 & 8192 & tilesync & 196.20 & 1.49 & 124.18 & 72.02 & 187.23 & 1.76 & 4.57 17 | 16 & 0 & 8192 & torch & 631.38 18 | 16 & 0 & 8192 & baseline & 198.54 & 1.16 & 124.64 & 73.89 19 | 16 & 0 & 8192 & rowsync & 198.54 & 1.16 & 124.64 & 73.89 & 196.20 & 1.93 & 1.18 20 | 16 & 0 & 8192 & tilesync & 198.54 & 1.16 & 124.64 & 73.89 & 188.13 & 1.71 & 5.24 21 | 32 & 0 & 8192 & torch & 225.29 22 | 32 & 0 & 8192 & baseline & 204.45 & 0.86 & 128.68 & 75.77 23 | 32 & 0 & 8192 & rowsync & 204.45 & 0.86 & 128.68 & 75.77 & 203.27 & 1.62 & 0.58 24 | 32 & 0 & 8192 & tilesync & 204.45 & 0.86 & 128.68 & 75.77 & 193.88 & 1.80 & 5.17 25 | 64 & 0 & 8192 & torch & 236.09 26 | 64 & 0 & 8192 & baseline 
& 224.64 & 1.96 & 139.32 & 85.33 27 | 64 & 0 & 8192 & rowsync & 224.64 & 1.96 & 139.32 & 85.33 & 225.28 & 1.36 & -0.28 28 | 64 & 0 & 8192 & tilesync & 224.64 & 1.96 & 139.32 & 85.33 & 221.58 & 1.36 & 1.36 29 | 128 & 0 & 8192 & torch & 298.05 30 | 128 & 0 & 8192 & baseline & 292.07 & 0.90 & 189.43 & 102.64 31 | 128 & 0 & 8192 & rowsync & 292.07 & 0.90 & 189.43 & 102.64 & 313.85 & 1.33 & -7.46 32 | 128 & 0 & 8192 & tilesync & 292.07 & 0.90 & 189.43 & 102.64 & 319.72 & 6.47 & -9.47 33 | 256 & 0 & 8192 & torch & 485.51 34 | 256 & 0 & 8192 & baseline & 416.09 & 2.56 & 257.02 & 159.06 35 | 256 & 0 & 8192 & rowsync & 416.09 & 2.56 & 257.02 & 159.06 & 435.37 & 1.95 & -4.63 36 | 256 & 0 & 8192 & tilesync & 416.09 & 2.56 & 257.02 & 159.06 & 414.89 & 2.46 & 0.29 37 | 512 & 0 & 8192 & torch & 926.81 38 | 512 & 0 & 8192 & baseline & 575.25 & 1.55 & 388.90 & 186.35 39 | 512 & 0 & 8192 & rowsync & 575.25 & 1.55 & 388.90 & 186.35 & 723.06 & 2.21 & -25.70 40 | 512 & 0 & 8192 & tilesync & 575.25 & 1.55 & 388.90 & 186.35 & 676.80 & 5.01 & -17.65 41 | 768 & 0 & 8192 & torch & 1313.29 42 | 768 & 0 & 8192 & baseline & 912.21 & 2.13 & 548.80 & 363.40 43 | 768 & 0 & 8192 & rowsync & 912.21 & 2.13 & 548.80 & 363.40 & 888.26 & 1.32 & 2.62 44 | 768 & 0 & 8192 & tilesync & 912.21 & 2.13 & 548.80 & 363.40 & 873.23 & 2.43 & 4.27 45 | 1024 & 0 & 8192 & torch & 1741.04 46 | 1024 & 0 & 8192 & baseline & 1341.33 & 2.31 & 975.58 & 365.74 47 | 1024 & 0 & 8192 & rowsync & 1341.33 & 2.31 & 975.58 & 365.74 & 1183.06 & 2.90 & 11.80 48 | 1024 & 0 & 8192 & tilesync & 1341.33 & 2.31 & 975.58 & 365.74 & 1140.86 & 6.38 & 14.95 49 | 1280 & 0 & 8192 & torch & 1998.07 50 | 1280 & 0 & 8192 & baseline & 1351.62 & 1.90 & 982.29 & 369.33 51 | 1280 & 0 & 8192 & rowsync & 1351.62 & 1.90 & 982.29 & 369.33 & 1214.24 & 7.75 & 10.16 52 | 1280 & 0 & 8192 & tilesync & 1351.62 & 1.90 & 982.29 & 369.33 & 1229.37 & 7.24 & 9.04 53 | 1536 & 0 & 8192 & torch & 2423.13 54 | 1536 & 0 & 8192 & baseline & 1529.17 & 1.69 & 984.18 & 545.00 55 | 1536 & 0 & 8192 & rowsync & 1529.17 & 1.69 & 984.18 & 545.00 & 1402.47 & 11.31 & 8.29 56 | 1536 & 0 & 8192 & tilesync & 1529.17 & 1.69 & 984.18 & 545.00 & 1412.78 & 6.15 & 7.61 57 | 1792 & 0 & 8192 & torch & 2992.24 58 | 1792 & 0 & 8192 & baseline & 1532.24 & 2.11 & 985.19 & 547.05 59 | 1792 & 0 & 8192 & rowsync & 1532.24 & 2.11 & 985.19 & 547.05 & 1537.70 & 1.79 & -0.36 60 | 1792 & 0 & 8192 & tilesync & 1532.24 & 2.11 & 985.19 & 547.05 & 1578.16 & 1.42 & -3.00 61 | 2048 & 0 & 8192 & torch & 3257.92 62 | 2048 & 0 & 8192 & baseline & 2189.31 & 2.74 & 1465.06 & 724.25 63 | 2048 & 0 & 8192 & rowsync & 2189.31 & 2.74 & 1465.06 & 724.25 & 1730.61 & 2.36 & 20.95 64 | 2048 & 0 & 8192 & tilesync & 2189.31 & 2.74 & 1465.06 & 724.25 & 1782.62 & 1.46 & 18.58 65 | 2304 & 0 & 8192 & torch & 3450.59 66 | 2304 & 0 & 8192 & baseline & 2196.19 & 2.41 & 1469.27 & 726.92 67 | 2304 & 0 & 8192 & rowsync & 2196.19 & 2.41 & 1469.27 & 726.92 & 2093.96 & 11.50 & 4.65 68 | 2304 & 0 & 8192 & tilesync & 2196.19 & 2.41 & 1469.27 & 726.92 & 2114.28 & 8.22 & 3.73 69 | 2560 & 0 & 8192 & torch & 3748.13 70 | 2560 & 0 & 8192 & baseline & 2198.58 & 1.80 & 1470.23 & 728.35 71 | 2560 & 0 & 8192 & rowsync & 2198.58 & 1.80 & 1470.23 & 728.35 & 2236.30 & 3.03 & -1.72 72 | 2560 & 0 & 8192 & tilesync & 2198.58 & 1.80 & 1470.23 & 728.35 & 2281.31 & 3.17 & -3.76 73 | 2816 & 0 & 8192 & torch & 4282.37 74 | 2816 & 0 & 8192 & baseline & 2814.30 & 76.39 & 1923.13 & 891.16 75 | 2816 & 0 & 8192 & rowsync & 2814.30 & 76.39 & 1923.13 & 891.16 & 2386.93 & 4.69 & 15.19 76 
| 2816 & 0 & 8192 & tilesync & 2814.30 & 76.39 & 1923.13 & 891.16 & 2372.95 & 98.86 & 15.68 77 | 3072 & 0 & 8192 & torch & 4925.86 78 | 3072 & 0 & 8192 & baseline & 2632.65 & 48.27 & 1801.38 & 831.26 79 | 3072 & 0 & 8192 & rowsync & 2632.65 & 48.27 & 1801.38 & 831.26 & 2496.34 & 4.84 & 5.18 80 | 3072 & 0 & 8192 & tilesync & 2632.65 & 48.27 & 1801.38 & 831.26 & 2586.23 & 105.63 & 1.76 81 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-results-2: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | 1 & 12288 & rowsync & 391.56 & 404.83 & 1.92 & 309.98 & 296.11 & 206.33 & 198.39 & 379.33 & 1.91 & 6.30 3 | 1 & 12288 & tilesync & 391.56 & 403.56 & 2.41 & 309.98 & 296.11 & 205.17 & 198.28 & 378.33 & 1.57 & 6.25 4 | 2 & 12288 & rowsync & 387.24 & 405.61 & 2.45 & 308.25 & 300.32 & 207.06 & 198.56 & 378.50 & 1.38 & 6.68 5 | 2 & 12288 & tilesync & 387.24 & 405.89 & 1.88 & 308.25 & 300.32 & 206.78 & 199.11 & 379.06 & 1.98 & 6.61 6 | 4 & 12288 & rowsync & 389.45 & 405.67 & 2.66 & 310.25 & 298.69 & 206.67 & 199.00 & 379.11 & 1.32 & 6.55 7 | 4 & 12288 & tilesync & 389.45 & 405.39 & 2.75 & 310.25 & 298.69 & 206.83 & 198.44 & 378.72 & 1.74 & 6.58 8 | 8 & 12288 & rowsync & 392.29 & 407.06 & 4.05 & 308.59 & 302.55 & 207.50 & 199.56 & 380.78 & 1.40 & 6.46 9 | 8 & 12288 & tilesync & 392.29 & 406.39 & 2.15 & 308.59 & 302.55 & 206.89 & 199.50 & 380.33 & 1.57 & 6.41 10 | 16 & 12288 & rowsync & 406.46 & 405.44 & 1.98 & 310.45 & 297.59 & 206.17 & 199.28 & 467.17 & 204.76 & -15.22 11 | 16 & 12288 & tilesync & 406.46 & 406.89 & 1.64 & 310.45 & 297.59 & 207.17 & 199.67 & 382.39 & 2.00 & 6.02 12 | 32 & 12288 & rowsync & 404.59 & 408.28 & 3.04 & 312.92 & 308.68 & 208.39 & 199.89 & 386.44 & 1.50 & 5.35 13 | 32 & 12288 & tilesync & 404.59 & 409.83 & 2.48 & 312.92 & 308.68 & 208.67 & 201.17 & 385.56 & 1.79 & 5.92 14 | 64 & 12288 & rowsync & 416.71 & 428.39 & 2.12 & 316.43 & 312.96 & 219.56 & 208.83 & 409.00 & 1.68 & 4.53 15 | 64 & 12288 & tilesync & 416.71 & 429.61 & 3.36 & 316.43 & 312.96 & 220.89 & 208.72 & 396.67 & 1.81 & 7.67 16 | 128 & 12288 & rowsync & 565.04 & 512.44 & 5.31 & 280.22 & 310.39 & 270.22 & 242.22 & 469.22 & 3.99 & 8.43 17 | 128 & 12288 & tilesync & 565.04 & 511.22 & 3.90 & 280.22 & 310.39 & 270.83 & 240.33 & 443.50 & 4.66 & 13.25 18 | 256 & 12288 & rowsync & 931.54 & 908.11 & 24.85 & 574.54 & 567.06 & 452.44 & 455.67 & 832.00 & 4.21 & 8.38 19 | 256 & 12288 & tilesync & 931.54 & 912.06 & 27.05 & 574.54 & 567.06 & 452.83 & 459.17 & 808.67 & 3.41 & 11.34 20 | 512 & 12288 & rowsync & 1741.12 & 1626.22 & 3.70 & 1091.45 & 1046.99 & 832.89 & 793.22 & 1370.22 & 5.08 & 15.74 21 | 512 & 12288 & tilesync & 1741.12 & 1624.50 & 3.11 & 1091.45 & 1046.99 & 832.28 & 792.17 & 1369.89 & 9.74 & 15.67 22 | 1024 & 12288 & rowsync & 3384.12 & 2809.94 & 6.27 & 1914.42 & 1887.04 & 1478.61 & 1331.22 & 2389.89 & 7.44 & 14.95 23 | 1024 & 12288 & tilesync & 3384.12 & 2810.50 & 7.79 & 1914.42 & 1887.04 & 1478.33 & 1332.17 & 2274.72 & 9.21 & 19.06 24 | 2048 & 12288 & rowsync & 6461.14 & 4839.11 & 132.79 & 3763.56 & 3692.29 & 2612.44 & 2226.61 & 4452.28 & 15.62 & 7.99 25 | 2048 & 12288 & tilesync & 6461.14 & 4826.00 & 130.19 & 3763.56 & 3692.29 & 2603.67 & 2222.33 & 4486.22 & 12.70 & 7.04 26 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-results-gpt-3: -------------------------------------------------------------------------------- 
1 | 1 & 12288 & pytorch & 393.85 2 | 1 & 12288 & streamk & 309.09 & 292.03 & 601.12 3 | 1 & 12288 & baseline & 406.94 & 2.96 & 206.89 & 200.00 4 | 1 & 12288 & rowsync & 406.94 & 2.96 & 206.89 & 200.00 & 385.28 & 1.18 & 5.32 5 | 1 & 12288 & tilesync & 406.94 & 2.96 & 206.89 & 200.00 & 388.78 & 2.37 & 4.46 6 | 2 & 12288 & pytorch & 383.04 7 | 2 & 12288 & streamk & 290.30 & 277.96 & 568.26 8 | 2 & 12288 & baseline & 405.44 & 2.38 & 206.17 & 199.28 9 | 2 & 12288 & rowsync & 405.44 & 2.38 & 206.17 & 199.28 & 384.83 & 1.58 & 5.08 10 | 2 & 12288 & tilesync & 405.44 & 2.38 & 206.17 & 199.28 & 388.94 & 1.73 & 4.07 11 | 4 & 12288 & pytorch & 387.97 12 | 4 & 12288 & streamk & 309.26 & 290.89 & 600.15 13 | 4 & 12288 & baseline & 406.28 & 2.59 & 207.06 & 199.17 14 | 4 & 12288 & rowsync & 406.28 & 2.59 & 207.06 & 199.17 & 385.72 & 1.07 & 5.06 15 | 4 & 12288 & tilesync & 406.28 & 2.59 & 207.06 & 199.17 & 389.67 & 2.06 & 4.09 16 | 8 & 12288 & pytorch & 386.23 17 | 8 & 12288 & streamk & 281.70 & 306.02 & 587.72 18 | 8 & 12288 & baseline & 407.67 & 2.35 & 208.00 & 199.67 19 | 8 & 12288 & rowsync & 407.67 & 2.35 & 208.00 & 199.67 & 386.72 & 1.49 & 5.14 20 | 8 & 12288 & tilesync & 407.67 & 2.35 & 208.00 & 199.67 & 389.28 & 1.23 & 4.51 21 | 16 & 12288 & pytorch & 406.04 22 | 16 & 12288 & streamk & 309.68 & 290.05 & 599.73 23 | 16 & 12288 & baseline & 407.67 & 2.09 & 208.22 & 199.44 24 | 16 & 12288 & rowsync & 407.67 & 2.09 & 208.22 & 199.44 & 387.89 & 2.08 & 4.85 25 | 16 & 12288 & tilesync & 407.67 & 2.09 & 208.22 & 199.44 & 391.00 & 1.57 & 4.09 26 | 32 & 12288 & pytorch & 407.17 27 | 32 & 12288 & streamk & 309.35 & 298.65 & 607.99 28 | 32 & 12288 & baseline & 410.33 & 2.30 & 209.28 & 201.06 29 | 32 & 12288 & rowsync & 410.33 & 2.30 & 209.28 & 201.06 & 392.22 & 1.93 & 4.41 30 | 32 & 12288 & tilesync & 410.33 & 2.30 & 209.28 & 201.06 & 395.56 & 2.97 & 3.60 31 | 64 & 12288 & pytorch & 413.66 32 | 64 & 12288 & streamk & 316.56 & 308.33 & 624.89 33 | 64 & 12288 & baseline & 428.83 & 2.98 & 219.72 & 209.11 34 | 64 & 12288 & rowsync & 428.83 & 2.98 & 219.72 & 209.11 & 408.39 & 2.00 & 4.77 35 | 64 & 12288 & tilesync & 428.83 & 2.98 & 219.72 & 209.11 & 416.06 & 2.26 & 2.98 36 | 128 & 12288 & pytorch & 566.16 37 | 128 & 12288 & streamk & 280.18 & 316.51 & 596.69 38 | 128 & 12288 & baseline & 513.94 & 4.41 & 272.33 & 241.56 39 | 128 & 12288 & rowsync & 513.94 & 4.41 & 272.33 & 241.56 & 528.78 & 3.47 & -2.89 40 | 128 & 12288 & tilesync & 513.94 & 4.41 & 272.33 & 241.56 & 513.83 & 2.73 & 0.02 41 | 256 & 12288 & pytorch & 933.15 42 | 256 & 12288 & streamk & 574.46 & 572.41 & 1146.87 43 | 256 & 12288 & baseline & 926.00 & 33.12 & 454.89 & 471.11 44 | 256 & 12288 & rowsync & 926.00 & 33.12 & 454.89 & 471.11 & 991.22 & 7.15 & -7.04 45 | 256 & 12288 & tilesync & 926.00 & 33.12 & 454.89 & 471.11 & 821.72 & 7.18 & 5.75 46 | 512 & 12288 & pytorch & 1741.08 47 | 512 & 12288 & streamk & 1108.45 & 1038.17 & 2146.62 48 | 512 & 12288 & baseline & 1632.61 & 3.42 & 836.28 & 796.33 49 | 512 & 12288 & rowsync & 1632.61 & 3.42 & 836.28 & 796.33 & 1486.00 & 126.66 & 8.98 50 | 512 & 12288 & tilesync & 1632.61 & 3.42 & 836.28 & 796.33 & 1414.89 & 7.39 & 13.34 51 | 1024 & 12288 & pytorch & 3092.69 52 | 1024 & 12288 & streamk & 1914.77 & 1733.80 & 3648.57 53 | 1024 & 12288 & baseline & 2819.00 & 7.16 & 1483.50 & 1335.39 54 | 1024 & 12288 & rowsync & 2819.00 & 7.16 & 1483.50 & 1335.39 & 2292.78 & 41.76 & 18.67 55 | 1024 & 12288 & tilesync & 2819.00 & 7.16 & 1483.50 & 1335.39 & 2348.56 & 7.48 & 16.69 56 | 2048 & 12288 & pytorch & 6185.96 57 | 
2048 & 12288 & streamk & 3798.96 & 3464.26 & 7263.22 58 | 2048 & 12288 & baseline & 5138.56 & 11.93 & 2770.00 & 2368.44 59 | 2048 & 12288 & rowsync & 5138.56 & 11.93 & 2770.00 & 2368.44 & 4461.61 & 10.77 & 13.17 60 | 2048 & 12288 & tilesync & 5138.56 & 11.93 & 2770.00 & 2368.44 & 4673.78 & 7.05 & 9.04 -------------------------------------------------------------------------------- /src/ml-bench/transformer/results/mlp-results-in-paper: -------------------------------------------------------------------------------- 1 | 1 & 6144 & Row-Sync & 144.11 & 209.33 & 0.97 & 88.11 & 121.17 & 3 & 48 & 24 & 206.44 & 0.78 & 1.38 2 | 1 & 6144 & Tile-Sync & 144.11 & 211.11 & 1.75 & 89.28 & 121.78 & 3 & 48 & 24 & 195.89 & 3.29 & 7.21 3 | 2 & 6144 & Row-Sync & 128.16 & 210.89 & 1.71 & 89.06 & 121.72 & 3 & 48 & 24 & 207.67 & 0.91 & 1.53 4 | 2 & 6144 & Tile-Sync & 128.16 & 209.11 & 0.83 & 88.11 & 121.00 & 3 & 48 & 24 & 194.72 & 0.96 & 6.88 5 | 4 & 6144 & Row-Sync & 128.62 & 209.28 & 0.67 & 88.06 & 121.22 & 3 & 48 & 24 & 206.94 & 0.80 & 1.11 6 | 4 & 6144 & Tile-Sync & 128.62 & 210.00 & 0.84 & 88.56 & 121.33 & 3 & 48 & 24 & 195.11 & 0.83 & 7.09 7 | 8 & 6144 & Row-Sync & 131.65 & 210.78 & 1.31 & 89.00 & 121.78 & 3 & 48 & 24 & 208.67 & 0.97 & 1.00 8 | 8 & 6144 & Tile-Sync & 131.65 & 210.17 & 0.79 & 88.94 & 121.22 & 3 & 48 & 24 & 194.94 & 0.87 & 7.24 9 | 16 & 6144 & Row-Sync & 137.71 & 210.89 & 1.02 & 89.61 & 121.22 & 3 & 48 & 24 & 209.61 & 1.79 & 0.61 10 | 16 & 6144 & Tile-Sync & 137.71 & 210.83 & 1.95 & 89.56 & 121.22 & 3 & 48 & 24 & 197.06 & 1.51 & 6.53 11 | 32 & 6144 & Row-Sync & 140.54 & 211.28 & 0.89 & 89.44 & 121.67 & 3 & 48 & 24 & 213.11 & 0.96 & -0.87 12 | 32 & 6144 & Tile-Sync & 140.54 & 211.83 & 0.92 & 89.44 & 122.33 & 3 & 48 & 24 & 199.89 & 0.76 & 5.64 13 | 64 & 6144 & Row-Sync & 149.05 & 222.56 & 1.38 & 96.61 & 125.94 & 3 & 48 & 24 & 226.72 & 1.13 & -1.87 14 | 64 & 6144 & Tile-Sync & 149.05 & 222.56 & 0.86 & 96.78 & 125.72 & 3 & 48 & 24 & 215.89 & 0.76 & 3.00 15 | 128 & 6144 & Row-Sync & 211.70 & 308.61 & 1.81 & 189.58 & 119.03 & 2 & 288 & 144 & 336.83 & 3.28 & -9.14 16 | 128 & 6144 & Tile-Sync & 211.70 & 308.94 & 2.10 & 189.72 & 119.14 & 2 & 288 & 144 & 343.33 & 3.04 & -11.13 17 | 256 & 6144 & Row-Sync & 323.70 & 359.75 & 1.66 & 182.81 & 176.94 & 2 & 96 & 48 & 364.78 & 2.65 & -1.40 18 | 256 & 6144 & Tile-Sync & 323.31 & 358.94 & 1.33 & 182.03 & 176.92 & 2 & 96 & 48 & 370.67 & 4.96 & -3.27 19 | 512 & 6144 & Row-Sync & 549.65 & 460.08 & 2.22 & 205.11 & 254.89 & 2 & 144 & 96 & 459.81 & 2.03 & 0.06 20 | 512 & 6144 & Tile-Sync & 549.48 & 460.94 & 3.10 & 205.69 & 255.22 & 2 & 144 & 96 & 479.64 & 3.35 & -4.06 21 | 1024 & 6144 & Row-Sync & 998.82 & 799.81 & 1.98 & 380.39 & 419.33 & 2 & 288 & 192 & 741.97 & 3.96 & 7.23 22 | 1024 & 6144 & Tile-Sync & 998.20 & 800.33 & 2.75 & 380.81 & 419.47 & 2 & 288 & 192 & 816.75 & 3.61 & -2.05 23 | 2048 & 6144 & Row-Sync & 1862.21 & 1426.97 & 31.79 & 724.06 & 702.89 & 2 & 576 & 384 & 1327.50 & 25.20 & 6.97 24 | 2048 & 6144 & Tile-Sync & 1862.87 & 1421.86 & 30.91 & 724.17 & 697.67 & 2 & 576 & 384 & 1437.00 & 12.92 & -1.06 25 | 1 & 8192 & Row-Sync & 250.10 & 265.94 & 1.06 & 109.89 & 156.06 & 3 & 64 & 32 & 266.11 & 1.32 & -0.06 26 | 1 & 8192 & Tile-Sync & 250.10 & 266.33 & 0.97 & 110.06 & 156.28 & 3 & 64 & 32 & 251.78 & 1.17 & 5.47 27 | 2 & 8192 & Row-Sync & 205.22 & 265.94 & 0.94 & 109.67 & 156.28 & 3 & 64 & 32 & 267.28 & 1.18 & -0.50 28 | 2 & 8192 & Tile-Sync & 205.22 & 266.39 & 2.00 & 109.78 & 156.56 & 3 & 64 & 32 & 251.39 & 1.29 & 5.63 29 | 4 & 8192 & Row-Sync & 206.58 & 
265.61 & 1.20 & 109.61 & 156.00 & 3 & 64 & 32 & 267.17 & 1.34 & -0.59 30 | 4 & 8192 & Tile-Sync & 206.58 & 265.94 & 1.39 & 109.67 & 156.28 & 3 & 64 & 32 & 251.67 & 1.19 & 5.37 31 | 8 & 8192 & Row-Sync & 256.92 & 266.33 & 1.41 & 110.33 & 156.00 & 3 & 64 & 32 & 267.83 & 1.92 & -0.56 32 | 8 & 8192 & Tile-Sync & 256.92 & 266.56 & 1.38 & 110.50 & 156.06 & 3 & 64 & 32 & 252.44 & 0.98 & 5.29 33 | 16 & 8192 & Row-Sync & 261.73 & 266.94 & 0.87 & 110.72 & 156.22 & 3 & 64 & 32 & 269.33 & 1.19 & -0.89 34 | 16 & 8192 & Tile-Sync & 261.73 & 267.33 & 1.24 & 111.22 & 156.06 & 3 & 64 & 32 & 254.67 & 1.50 & 4.74 35 | 32 & 8192 & Row-Sync & 200.25 & 269.72 & 1.60 & 113.00 & 156.67 & 3 & 64 & 32 & 275.17 & 1.50 & -2.02 36 | 32 & 8192 & Tile-Sync & 200.25 & 269.83 & 1.58 & 112.94 & 156.89 & 3 & 64 & 32 & 258.67 & 1.28 & 4.14 37 | 64 & 8192 & Row-Sync & 205.59 & 278.89 & 1.91 & 117.33 & 161.50 & 3 & 64 & 32 & 290.11 & 0.96 & -4.02 38 | 64 & 8192 & Tile-Sync & 205.59 & 279.39 & 2.03 & 117.50 & 161.78 & 3 & 64 & 32 & 275.50 & 1.58 & 1.39 39 | 128 & 8192 & Row-Sync & 294.13 & 428.69 & 1.58 & 230.36 & 198.33 & 2 & 384 & 192 & 439.97 & 2.29 & -2.63 40 | 128 & 8192 & Tile-Sync & 294.13 & 429.19 & 1.33 & 229.89 & 199.25 & 2 & 384 & 192 & 478.67 & 6.78 & -11.53 41 | 256 & 8192 & Row-Sync & 466.67 & 424.86 & 1.46 & 208.61 & 216.11 & 2 & 128 & 64 & 430.11 & 2.93 & -24.77 42 | 256 & 8192 & Tile-Sync & 464.80 & 424.83 & 1.23 & 208.33 & 216.44 & 2 & 128 & 64 & 447.94 & 4.06 & -28.98 43 | 512 & 8192 & Row-Sync & 881.41 & 643.28 & 2.76 & 325.06 & 318.19 & 2 & 128 & 128 & 644.53 & 2.84 & -0.19 44 | 512 & 8192 & Tile-Sync & 880.63 & 642.92 & 1.81 & 324.58 & 318.33 & 2 & 128 & 128 & 660.47 & 2.29 & -2.73 45 | 1024 & 8192 & Row-Sync & 1691.38 & 1243.75 & 3.56 & 623.58 & 620.11 & 2 & 256 & 256 & 1192.03 & 40.11 & 4.16 46 | 1024 & 8192 & Tile-Sync & 1690.71 & 1244.11 & 2.23 & 623.08 & 620.97 & 2 & 256 & 256 & 1208.61 & 5.11 & 2.85 47 | 2048 & 8192 & Row-Sync & 3149.66 & 2347.19 & 3.57 & 1222.17 & 1124.97 & 2 & 512 & 512 & 2037.61 & 41.21 & 13.19 48 | 2048 & 8192 & Tile-Sync & 3133.95 & 2346.42 & 3.95 & 1221.75 & 1124.67 & 2 & 512 & 512 & 2235.81 & 84.23 & 4.71 49 | 1 & 12288 & Row-Sync & 469.68 & 440.67 & 1.33 & 199.89 & 240.78 & 3 & 72 & 48 & 433.56 & 1.25 & 1.61 50 | 1 & 12288 & Tile-Sync & 469.68 & 441.78 & 1.26 & 200.11 & 241.61 & 3 & 72 & 48 & 411.33 & 1.50 & 6.89 51 | 1 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 52 | 2 & 12288 & Row-Sync & 385.39 & 440.44 & 1.46 & 199.61 & 240.78 & 3 & 72 & 48 & 433.94 & 1.11 & 1.48 53 | 2 & 12288 & Tile-Sync & 385.39 & 440.33 & 1.57 & 199.67 & 240.67 & 3 & 72 & 48 & 411.50 & 1.10 & 6.55 54 | 2 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 55 | 4 & 12288 & Row-Sync & 387.68 & 442.06 & 1.35 & 200.00 & 241.94 & 3 & 72 & 48 & 434.83 & 1.15 & 1.63 56 | 4 & 12288 & Tile-Sync & 387.68 & 439.39 & 1.29 & 199.56 & 239.83 & 3 & 72 & 48 & 410.89 & 1.18 & 6.49 57 | 4 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 58 | 8 & 12288 & Row-Sync & 389.79 & 440.78 & 1.90 & 200.22 & 240.44 & 3 & 72 & 48 & 435.17 & 1.29 & 1.27 59 | 8 & 12288 & Tile-Sync & 389.79 & 440.50 & 1.25 & 199.61 & 240.83 & 3 & 72 & 48 & 412.61 & 1.46 & 6.33 60 | 8 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 61 | 16 & 12288 & Row-Sync & 405.78 & 442.56 & 1.65 & 201.17 & 241.39 & 3 & 72 & 48 & 436.11 & 1.37 & 1.46 62 | 16 & 12288 & Tile-Sync & 405.78 & 441.78 & 1.73 & 201.06 & 240.56 & 3 & 72 & 48 & 414.56 & 1.92 & 6.16 63 | 16 & 12288 & 
Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 64 | 32 & 12288 & Row-Sync & 413.16 & 444.28 & 1.53 & 202.67 & 241.61 & 3 & 72 & 48 & 443.56 & 1.58 & 0.16 65 | 32 & 12288 & Tile-Sync & 413.16 & 444.61 & 1.33 & 202.39 & 242.17 & 3 & 72 & 48 & 420.33 & 1.81 & 5.46 66 | 32 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 67 | 64 & 12288 & Row-Sync & 420.32 & 470.83 & 1.92 & 222.33 & 248.44 & 3 & 72 & 48 & 475.61 & 2.00 & -1.01 68 | 64 & 12288 & Tile-Sync & 420.32 & 472.89 & 2.22 & 223.78 & 249.06 & 3 & 72 & 48 & 450.78 & 1.52 & 4.68 69 | 64 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 445 & 0 & 70 | 128 & 12288 & Row-Sync & 561.22 & 780.31 & 2.88 & 402.64 & 334.64 & 2 & 576 & 288 & 780.64 & 2.64 & -1.67 71 | 128 & 12288 & Tile-Sync & 561.22 & 780.81 & 2.29 & 403.08 & 334.69 & 2 & 576 & 288 & 730.39 & 3.20 & -4.15 72 | 128 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 770 & 0 & 73 | 256 & 12288 & Row-Sync & 945.09 & 944.69 & 3.73 & 449.17 & 495.50 & 2 & 192 & 96 & 1000.11 & 17.12 & -5.87 74 | 256 & 12288 & Tile-Sync & 943.98 & 945.08 & 2.31 & 449.31 & 495.75 & 2 & 192 & 96 & 840.25 & 4.52 & 7.92 75 | 256 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 1100 & 0 & 76 | 512 & 12288 & Row-Sync & 1754.25 & 1649.61 & 15.23 & 834.58 & 814.97 & 2 & 192 & 192 & 1456.36 & 55.09 & 11.71 77 | 512 & 12288 & Tile-Sync & 1755.43 & 1645.97 & 3.18 & 833.78 & 812.14 & 2 & 192 & 192 & 1528.19 & 53.37 & 7.16 78 | 512 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 1640 & 0 & 79 | 1024 & 12288 & Row-Sync & 3394.56 & 2890.25 & 107.83 & 1520.78 & 1369.39 & 2 & 192 & 384 & 2208.94 & 24.44 & 23.57 80 | 1024 & 12288 & Tile-Sync & 3394.95 & 2832.03 & 112.64 & 1491.28 & 1340.72 & 2 & 192 & 384 & 2451.72 & 9.62 & 13.43 81 | 1024 & 12288 & Stream-K & -- & -- & -- & ---- & ---- &- &- & - & 2780 & 0 & 82 | 2048 & 12288 & Row-Sync & 6547.43 & 4772.97 & 218.63 & 2528.36 & 2244.58 & 2 & 384 & 768 & 4260.28 & 17.87 & 10.74 83 | 2048 & 12288 & Tile-Sync & 6779.08 & 4832.92 & 215.71 & 2578.44 & 2254.42 & 2 & 384 & 768 & 4944.72 & 20.54 & -2.31 84 | 2048 & 12288 & Stream-K & -- & -- & -- & ---- & ---- & 2 & 394 & 768 & 20600 & 0 85 | 1 & 16384 & Row-Sync & 1194.33 & 666.11 & 1.18 & 334.78 & 331.22 & 3 & 64 & 64 & 665.00 & 2.17 & 0.17 86 | 1 & 16384 & Tile-Sync & 1194.33 & 664.94 & 1.70 & 333.94 & 330.94 & 3 & 64 & 64 & 658.39 & 2.62 & 0.99 87 | 2 & 16384 & Row-Sync & 666.46 & 664.17 & 1.72 & 333.72 & 330.44 & 3 & 64 & 64 & 665.11 & 1.68 & -0.14 88 | 2 & 16384 & Tile-Sync & 666.46 & 665.61 & 2.17 & 334.67 & 330.89 & 3 & 64 & 64 & 658.72 & 2.44 & 1.03 89 | 4 & 16384 & Row-Sync & 667.21 & 664.61 & 1.50 & 334.06 & 330.44 & 3 & 64 & 64 & 665.50 & 2.75 & -0.13 90 | 4 & 16384 & Tile-Sync & 667.21 & 665.00 & 1.37 & 334.44 & 330.50 & 3 & 64 & 64 & 658.83 & 2.50 & 0.93 91 | 8 & 16384 & Row-Sync & 673.23 & 665.78 & 2.26 & 335.06 & 330.72 & 3 & 64 & 64 & 667.56 & 1.92 & -0.27 92 | 8 & 16384 & Tile-Sync & 673.23 & 667.22 & 3.06 & 336.11 & 331.00 & 3 & 64 & 64 & 659.06 & 2.10 & 1.22 93 | 16 & 16384 & Row-Sync & 688.78 & 668.94 & 1.63 & 336.50 & 332.39 & 3 & 64 & 64 & 671.56 & 3.11 & -0.39 94 | 16 & 16384 & Tile-Sync & 688.78 & 669.22 & 2.21 & 337.00 & 332.17 & 3 & 64 & 64 & 664.50 & 1.76 & 0.71 95 | 32 & 16384 & Row-Sync & 676.08 & 725.11 & 40.40 & 365.39 & 359.50 & 3 & 64 & 64 & 682.11 & 2.61 & 5.93 96 | 32 & 16384 & Tile-Sync & 676.08 & 673.56 & 2.62 & 339.22 & 334.28 & 3 & 64 & 64 & 672.17 & 2.28 & 0.21 97 | 64 & 16384 & Row-Sync & 701.51 & 686.83 & 1.98 & 
348.06 & 338.78 & 3 & 64 & 64 & 705.94 & 1.43 & -2.78 98 | 64 & 16384 & Tile-Sync & 701.51 & 687.11 & 2.72 & 347.67 & 339.44 & 3 & 64 & 64 & 688.11 & 2.19 & -0.15 99 | 128 & 16384 & Row-Sync & 834.95 & 1212.69 & 2.65 & 607.03 & 605.58 & 2 & 768 & 384 & 1224.53 & 2.92 & -0.98 100 | 128 & 16384 & Tile-Sync & 834.95 & 1212.72 & 3.30 & 607.28 & 605.39 & 2 & 768 & 384 & 1253.69 & 2.75 & -3.38 101 | 256 & 16384 & Row-Sync & 1560.54 & 1265.64 & 3.46 & 648.36 & 617.22 & 2 & 256 & 128 & 1319.92 & 3.14 & -4.27 102 | 256 & 16384 & Tile-Sync & 1560.25 & 1265.83 & 3.34 & 649.22 & 616.58 & 2 & 256 & 128 & 1185.33 & 3.57 & 6.35 103 | 512 & 16384 & Row-Sync & 3104.13 & 2388.61 & 40.89 & 1188.42 & 1200.19 & 2 & 256 & 256 & 2073.61 & 23.88 & 13.19 104 | 512 & 16384 & Tile-Sync & 3109.71 & 2360.50 & 77.97 & 1174.08 & 1186.42 & 2 & 256 & 256 & 2387.14 & 23.09 & -1.13 105 | 1024 & 16384 & Row-Sync & 5974.40 & 4549.56 & 16.03 & 2267.03 & 2282.53 & 2 & 256 & 512 & 3788.47 & 13.14 & 16.73 106 | 1024 & 16384 & Tile-Sync & 5628.02 & 4701.92 & 125.90 & 2344.97 & 2356.81 & 2 & 256 & 512 & 4322.92 & 23.04 & 8.06 107 | 2048 & 16384 & Row-Sync & 10651.81 & 7838.53 & 16.67 & 4015.33 & 3823.11 & 2 & 512 & 1024 & 7074.75 & 28.41 & 9.74 108 | 2048 & 16384 & Tile-Sync & 10973.74 & 7902.97 & 88.93 & 4054.11 & 3848.75 & 2 & 512 & 1024 & 8191.17 & 38.27 & -3.65 -------------------------------------------------------------------------------- /src/ml-bench/transformer/torch-baselines/cublasBaseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | 5 | M = int(sys.argv[1]) 6 | N = int(sys.argv[2]) 7 | K = int(sys.argv[3]) 8 | 9 | a = torch.ones((M, K), dtype=torch.half).cuda() 10 | b = torch.ones((K, N), dtype=torch.half).cuda() 11 | #c = torch.ones((M, N), dtype=torch.half).cuda() 12 | # d = torch.ones((N, L), dtype=torch.half).cuda() 13 | #e = torch.ones([M, L], dtype=torch.half).cuda() 14 | 15 | c = a@b 16 | print(c.dtype) 17 | for i in range(10): 18 | c = a@b 19 | # e = c@d 20 | torch.cuda.synchronize() 21 | 22 | epochs = 20 23 | start = time.time_ns() 24 | 25 | for i in range(epochs): 26 | c = a@b 27 | # e = c@d 28 | torch.cuda.synchronize() 29 | end = time.time_ns() 30 | 31 | print((end-start)/epochs/1e3) 32 | -------------------------------------------------------------------------------- /src/ml-bench/transformer/torch-baselines/torchAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | 5 | M = int(sys.argv[1]) 6 | N = int(sys.argv[2]) 7 | K = int(sys.argv[3]) 8 | L = int(sys.argv[4]) 9 | 10 | X = torch.ones((M, K), dtype=torch.half).cuda() 11 | QKV = torch.ones((K, N*3), dtype=torch.half).cuda() 12 | W2 = torch.ones((N, L), dtype=torch.half).cuda() 13 | 14 | for i in range(10): 15 | XQKV = X@QKV 16 | XQ = XQKV[:,0:N] 17 | XK = XQKV[:,N:2*N] 18 | XV = XQKV[:,2*N:3*N] 19 | 20 | XQDotXK = XQ*XK 21 | softmax = torch.softmax(XQDotXK, dim = 0) 22 | softmaxDotXV = softmax*XV 23 | dropout = torch.dropout(softmaxDotXV, 1.0, False) 24 | out = dropout@W2 25 | torch.cuda.synchronize() 26 | 27 | epochs = 20 28 | start = time.time_ns() 29 | 30 | for i in range(epochs): 31 | XQKV = X@QKV 32 | XQ = XQKV[:,0:N] 33 | XK = XQKV[:,N:2*N] 34 | XV = XQKV[:,2*N:3*N] 35 | 36 | # XQDotXK = XQ*XK 37 | # softmax = torch.softmax(XQDotXK, dim = 0) 38 | # softmaxDotXV = softmax*XV 39 | # dropout = torch.dropout(softmaxDotXV, 1.0, False) 40 | out = XQ@W2 41 | 
torch.cuda.synchronize() 42 | end = time.time_ns() 43 | 44 | print((end-start)/epochs/1e3)  # average time per iteration in microseconds -------------------------------------------------------------------------------- /src/ml-bench/transformer/torch-baselines/torchmlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | 5 | M = int(sys.argv[1]) 6 | model = sys.argv[2] 7 | assert model in ['gpt3', 'llama'] 8 | 9 | if model == 'gpt3': 10 | H = 12288 11 | X = torch.ones((M, H), dtype=torch.half).cuda() 12 | W1 = torch.ones((H, H//2), dtype=torch.half).cuda() 13 | W2 = torch.ones((H//2, H), dtype=torch.half).cuda() 14 | else: 15 | H = 8192 16 | H2 = ((H//3 + 127)//128)*128  # round H/3 up to the next multiple of 128 17 | X = torch.ones((M, H), dtype=torch.half).cuda() 18 | W1 = torch.ones((H, 2*H2), dtype=torch.half).cuda() 19 | W2 = torch.ones((H2, H), dtype=torch.half).cuda() 20 | XW1_ = torch.ones((M, H2), dtype=torch.half).cuda()  # (M, H2) stand-in for the first layer's output that feeds W2 21 | 22 | epochs = 20 23 | for i in range(epochs): 24 | XW1 = X@W1 25 | torch.cuda.synchronize() 26 | 27 | start = time.time_ns() 28 | 29 | if model == 'gpt3': 30 | for i in range(epochs): 31 | XW1 = X@W1 32 | out = XW1@W2 33 | torch.cuda.synchronize() 34 | elif model == 'llama': 35 | for i in range(epochs): 36 | XW1 = X@W1 37 | out = XW1_@W2 38 | torch.cuda.synchronize() 39 | end = time.time_ns() 40 | 41 | print((end-start)/epochs/1e3)  # average time per iteration in microseconds -------------------------------------------------------------------------------- /src/ml-bench/volta_conv2d/Makefile: -------------------------------------------------------------------------------- 1 | include ../common.mk 2 | 3 | ARCH_FLAGS=-gencode=arch=compute_70,code=[sm_70,compute_70] 4 | XCOMPILER=-Xcompiler=-fPIE -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing 5 | BUILD=build 6 | 7 | INCLUDES=-I$(NV_CUTLASS)/include -I$(NV_CUTLASS)/examples/common -I$(NV_CUTLASS)/tools/util/include -I$(NV_CUTLASS) -I$(CUSYNC_CUTLASS)/include -I$(CUSYNC) -I.
8 | 9 | DEFINES=-DCUTLASS_ENABLE_CUBLAS=1 -DCUTLASS_NAMESPACE=cutlass -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_TEST_LEVEL=0 -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1 -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0 10 | 11 | $(BUILD)/conv-rowsync: resnet.cu 12 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) $< -DROWSYNC -o $@ 13 | 14 | $(BUILD)/conv-tilesync: resnet.cu 15 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) $< -DTILESYNC -o $@ 16 | 17 | $(BUILD)/conv-eval-streamk: $(BUILD)/conv-eval-baseline.cu 18 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) -DROWSYNC -DSTREAM_K $< -o $@ 19 | 20 | $(BUILD)/conv-eval-baseline: $(BUILD)/conv-eval-baseline.cu 21 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) -DROWSYNC $< -o $@ 22 | 23 | $(BUILD)/conv-eval-rowsync: $(BUILD)/conv-eval-rowsync.cu 24 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) $< -o $@ 25 | 26 | $(BUILD)/conv-eval-tilesync: $(BUILD)/conv-eval-tilesync.cu 27 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) $< -o $@ 28 | 29 | #VGG 30 | $(BUILD)/vgg-rowsync: vgg.cu 31 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) ./vgg.cu -DROWSYNC -o $@ 32 | 33 | $(BUILD)/vgg-tilesync: vgg.cu 34 | $(NVCC) $(INCLUDES) $(DEFINES) -O3 -DNDEBUG $(XCOMPILER) $(ARCH_FLAGS) ./vgg.cu -DTILESYNC -o $@ 35 | -------------------------------------------------------------------------------- /src/ml-bench/volta_conv2d/resnet_results.csv: -------------------------------------------------------------------------------- 1 | 1 & 64 & baseline & 49.94 & 0.94 2 | 1 & 64 & rowsync & 49.94 & 0.94 & 40.22 & 0.43 & 19.47 3 | 1 & 64 & tilesync & 49.94 & 0.94 & 41.28 & 0.46 & 17.35 4 | 4 & 64 & baseline & 78.50 & 1.15 5 | 4 & 64 & rowsync & 78.50 & 1.15 & 63.94 & 0.54 & 18.54 6 | 4 & 64 & tilesync & 78.50 & 1.15 & 67.17 & 1.58 & 14.44 7 | 8 & 64 & baseline & 114.89 & 2.70 8 | 8 & 64 & rowsync & 114.89 & 2.70 & 98.50 & 0.86 & 14.26 9 | 8 & 64 & tilesync & 114.89 & 2.70 & 104.06 & 1.21 & 9.43 10 | 12 & 64 & baseline & 144.72 & 0.57 11 | 12 & 64 & rowsync & 144.72 & 0.57 & 138.44 & 1.62 & 4.34 12 | 12 & 64 & tilesync & 144.72 & 0.57 & 145.17 & 2.20 & -0.31 13 | 16 & 64 & baseline & 183.06 & 5.10 14 | 16 & 64 & rowsync & 183.06 & 5.10 & 175.89 & 1.41 & 3.92 15 | 16 & 64 & tilesync & 183.06 & 5.10 & 185.39 & 1.14 & -1.27 16 | 20 & 64 & baseline & 201.44 & 1.20 17 | 20 & 64 & rowsync & 201.44 & 1.20 & 184.72 & 2.22 & 8.30 18 | 20 & 64 & tilesync & 201.44 & 1.20 & 191.33 & 4.56 & 5.02 19 | 24 & 64 & baseline & 206.28 & 0.57 20 | 24 & 64 & rowsync & 206.28 & 0.57 & 203.67 & 0.97 & 1.27 21 | 24 & 64 & tilesync & 206.28 & 0.57 & 208.00 & 0.69 & -0.83 22 | 28 & 64 & baseline & 286.39 & 1.33 23 | 28 & 64 & rowsync & 286.39 & 1.33 & 252.44 & 0.98 & 11.85 24 | 28 & 64 & tilesync & 286.39 & 1.33 & 259.39 & 1.33 & 9.43 25 | 32 & 64 & baseline & 290.28 & 1.32 26 | 32 & 64 & rowsync & 290.28 & 1.32 & 261.11 & 1.28 & 10.05 27 | 32 & 64 & tilesync & 290.28 & 1.32 & 264.94 & 1.11 & 8.73 28 | 1 & 128 & baseline & 55.44 & 0.51 29 | 1 & 128 & rowsync & 55.44 & 0.51 & 44.39 & 0.50 & 19.94 30 | 1 & 128 & tilesync & 55.44 & 0.51 & 46.33 & 0.49 & 16.43 31 | 4 & 128 & baseline & 91.11 & 1.57 32 | 4 & 128 & rowsync & 91.11 & 1.57 & 69.94 & 1.35 & 23.23 33 | 4 & 128 & tilesync & 91.11 & 1.57 & 71.94 & 1.80 & 21.04 34 | 8 & 128 & baseline & 108.00 & 0.34 35 | 8 & 128 & rowsync & 108.00 & 0.34 & 
113.44 & 1.42 & -5.04 36 | 8 & 128 & tilesync & 108.00 & 0.34 & 117.22 & 1.06 & -8.54 37 | 12 & 128 & baseline & 149.00 & 0.69 38 | 12 & 128 & rowsync & 149.00 & 0.69 & 121.67 & 1.41 & 18.34 39 | 12 & 128 & tilesync & 149.00 & 0.69 & 126.33 & 1.28 & 15.21 40 | 16 & 128 & baseline & 273.50 & 1.20 41 | 16 & 128 & rowsync & 273.50 & 1.20 & 265.44 & 0.92 & 2.95 42 | 16 & 128 & tilesync & 273.50 & 1.20 & 272.83 & 1.38 & 0.24 43 | 20 & 128 & baseline & 191.89 & 1.18 44 | 20 & 128 & rowsync & 191.89 & 1.18 & 176.44 & 1.25 & 8.05 45 | 20 & 128 & tilesync & 191.89 & 1.18 & 186.72 & 1.49 & 2.69 46 | 24 & 128 & baseline & 193.56 & 0.62 47 | 24 & 128 & rowsync & 193.56 & 0.62 & 179.22 & 2.46 & 7.41 48 | 24 & 128 & tilesync & 193.56 & 0.62 & 190.06 & 1.00 & 1.81 49 | 28 & 128 & baseline & 323.06 & 8.51 50 | 28 & 128 & rowsync & 323.06 & 8.51 & 254.28 & 3.23 & 21.29 51 | 28 & 128 & tilesync & 323.06 & 8.51 & 268.94 & 1.92 & 16.75 52 | 32 & 128 & baseline & 337.17 & 3.38 53 | 32 & 128 & rowsync & 337.17 & 3.38 & 264.11 & 3.36 & 21.67 54 | 32 & 128 & tilesync & 337.17 & 3.38 & 275.28 & 2.61 & 18.36 55 | 1 & 256 & baseline & 63.56 & 0.62 56 | 1 & 256 & rowsync & 63.56 & 0.62 & 53.11 & 0.68 & 16.43 57 | 1 & 256 & tilesync & 63.56 & 0.62 & 55.33 & 0.77 & 12.94 58 | 4 & 256 & baseline & 97.00 & 0.84 59 | 4 & 256 & rowsync & 97.00 & 0.84 & 77.39 & 1.79 & 20.22 60 | 4 & 256 & tilesync & 97.00 & 0.84 & 74.11 & 0.68 & 23.60 61 | 8 & 256 & baseline & 132.00 & 0.91 62 | 8 & 256 & rowsync & 132.00 & 0.91 & 121.72 & 0.89 & 7.79 63 | 8 & 256 & tilesync & 132.00 & 0.91 & 122.00 & 0.91 & 7.58 64 | 12 & 256 & baseline & 147.44 & 1.04 65 | 12 & 256 & rowsync & 147.44 & 1.04 & 141.78 & 1.63 & 3.84 66 | 12 & 256 & tilesync & 147.44 & 1.04 & 142.94 & 1.11 & 3.05 67 | 16 & 256 & baseline & 198.61 & 0.70 68 | 16 & 256 & rowsync & 198.61 & 0.70 & 188.11 & 0.96 & 5.29 69 | 16 & 256 & tilesync & 198.61 & 0.70 & 192.06 & 1.43 & 3.30 70 | 20 & 256 & baseline & 172.06 & 0.87 71 | 20 & 256 & rowsync & 172.06 & 0.87 & 213.06 & 0.80 & -23.83 72 | 20 & 256 & tilesync & 172.06 & 0.87 & 212.83 & 0.92 & -23.70 73 | 24 & 256 & baseline & 175.61 & 1.09 74 | 24 & 256 & rowsync & 175.61 & 1.09 & 216.06 & 1.39 & -23.03 75 | 24 & 256 & tilesync & 175.61 & 1.09 & 217.39 & 0.98 & -23.79 76 | 28 & 256 & baseline & 244.11 & 1.32 77 | 28 & 256 & rowsync & 244.11 & 1.32 & 224.67 & 1.53 & 7.97 78 | 28 & 256 & tilesync & 244.11 & 1.32 & 227.28 & 1.18 & 6.90 79 | 32 & 256 & baseline & 251.17 & 1.20 80 | 32 & 256 & rowsync & 251.17 & 1.20 & 228.50 & 0.92 & 9.02 81 | 32 & 256 & tilesync & 251.17 & 1.20 & 229.83 & 1.10 & 8.49 82 | 1 & 512 & baseline & 92.39 & 0.70 83 | 1 & 512 & rowsync & 92.39 & 0.70 & 83.22 & 0.94 & 9.92 84 | 1 & 512 & tilesync & 92.39 & 0.70 & 82.94 & 1.47 & 10.22 85 | 4 & 512 & baseline & 111.17 & 0.86 86 | 4 & 512 & rowsync & 111.17 & 0.86 & 101.89 & 0.83 & 8.35 87 | 4 & 512 & tilesync & 111.17 & 0.86 & 99.06 & 0.87 & 10.89 88 | 8 & 512 & baseline & 147.22 & 1.06 89 | 8 & 512 & rowsync & 147.22 & 1.06 & 138.28 & 1.27 & 6.08 90 | 8 & 512 & tilesync & 147.22 & 1.06 & 143.00 & 1.53 & 2.87 91 | 12 & 512 & baseline & 170.83 & 1.20 92 | 12 & 512 & rowsync & 170.83 & 1.20 & 164.72 & 0.67 & 3.58 93 | 12 & 512 & tilesync & 170.83 & 1.20 & 165.56 & 1.25 & 3.09 94 | 16 & 512 & baseline & 184.33 & 0.97 95 | 16 & 512 & rowsync & 184.33 & 0.97 & 206.72 & 1.07 & -12.15 96 | 16 & 512 & tilesync & 184.33 & 0.97 & 208.06 & 1.16 & -12.87 97 | 20 & 512 & baseline & 235.72 & 1.18 98 | 20 & 512 & rowsync & 235.72 & 1.18 & 222.72 & 2.11 & 5.51 99 | 20 & 512 
& tilesync & 235.72 & 1.18 & 231.00 & 1.85 & 2.00 100 | 24 & 512 & baseline & 293.22 & 9.38 101 | 24 & 512 & rowsync & 293.22 & 9.38 & 300.33 & 4.49 & -2.43 102 | 24 & 512 & tilesync & 293.22 & 9.38 & 295.89 & 3.01 & -0.91 103 | 28 & 512 & baseline & 335.67 & 8.02 104 | 28 & 512 & rowsync & 335.67 & 8.02 & 323.28 & 1.90 & 3.69 105 | 28 & 512 & tilesync & 335.67 & 8.02 & 323.00 & 1.33 & 3.77 106 | 32 & 512 & baseline & 350.17 & 1.20 107 | 32 & 512 & rowsync & 350.17 & 1.20 & 339.72 & 1.27 & 2.98 108 | 32 & 512 & tilesync & 350.17 & 1.20 & 342.11 & 1.13 & 2.30 109 | -------------------------------------------------------------------------------- /src/ml-bench/volta_conv2d/torchconv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import sys 4 | import time 5 | 6 | N = int(sys.argv[1]) 7 | H = 224 #int(sys.argv[2]) 8 | W = 224 #int(sys.argv[3]) 9 | C = int(sys.argv[2]) 10 | K = 1#int(sys.argv[5]) 11 | R = 1#int(sys.argv[6]) 12 | S = 1#int(sys.argv[7]) 13 | 14 | imgs = torch.ones((N, 3, H, W), dtype=torch.float16) 15 | imgs = imgs.cuda() 16 | conv1 = torch.nn.Conv2d(3, 64, 7, stride=2,padding=3, dtype=torch.float16).cuda() 17 | conv2 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1).cuda() 18 | conv3 = torch.nn.Conv2d(64, 64, 3, stride=1,padding=1, dtype=torch.float16).cuda() 19 | conv4 = torch.nn.Conv2d(64, 128, 3, stride=2,padding=1, dtype=torch.float16).cuda() 20 | conv5 = torch.nn.Conv2d(128, 128, 3, stride=1,padding=1, dtype=torch.float16).cuda() 21 | conv6 = torch.nn.Conv2d(128, 256, 3, stride=2,padding=1, dtype=torch.float16).cuda() 22 | conv7 = torch.nn.Conv2d(256, 256, 3, stride=1,padding=1, dtype=torch.float16).cuda() 23 | conv8 = torch.nn.Conv2d(256, 512, 3, stride=2,padding=1, dtype=torch.float16).cuda() 24 | conv9 = torch.nn.Conv2d(512, 512, 3, stride=1,padding=1, dtype=torch.float16).cuda() 25 | 26 | # conv3 = torch.nn.Conv2d(64, 64, 3, stride=1,padding=1) 27 | 28 | conv1_o = conv1(imgs) 29 | conv2_o = conv2(conv1_o) 30 | 31 | conv5_in = conv4(conv2_o) 32 | # print("Input shape for 128, 3x3", conv5_in.shape) 33 | 34 | conv7_in = conv6(conv5_in) 35 | # print("Input shape for 256, 3x3", conv7_in.shape) 36 | 37 | conv9_in = conv8(conv7_in) 38 | # print("Input shape for 512, 3x3", conv9_in.shape) 39 | 40 | def conv64x64_3(input): 41 | conv3_o = conv3(input) 42 | conv3_o = conv3(conv3_o) 43 | # conv3_o = conv3(conv3_o) 44 | 45 | def conv128x128_3(input): 46 | conv5_o = conv5(input) 47 | conv5_o = conv5(conv5_o) 48 | # conv5_o = conv5(conv5_o) 49 | 50 | def conv256x256_3(input): 51 | conv7_o = conv7(input) 52 | conv7_o = conv7(conv7_o) 53 | # conv7_o = conv7(conv7_o) 54 | 55 | def conv512x512_3(input): 56 | conv9_o = conv9(input) 57 | conv9_o = conv9(conv9_o) 58 | # conv9_o = conv9(conv9_o) 59 | 60 | 61 | for i in range(10): 62 | conv64x64_3(conv2_o) 63 | 64 | torch.cuda.synchronize() 65 | 66 | def execute(f, input): 67 | epochs = 20 68 | start = time.time_ns() 69 | 70 | for i in range(epochs): 71 | f(input) 72 | 73 | torch.cuda.synchronize() 74 | end = time.time_ns() 75 | return (end-start)/epochs/1e3 76 | 77 | if C == 64: 78 | print(execute(conv64x64_3, conv2_o)) 79 | elif C == 128: 80 | print(execute(conv128x128_3, conv5_in)) 81 | elif C == 256: 82 | print(execute(conv256x256_3, conv7_in)) 83 | elif C == 512: 84 | print(execute(conv512x512_3, conv9_in)) 85 | else: 86 | print("Invalid C=", C) -------------------------------------------------------------------------------- 
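Usage note for torchconv2d.py above (inferred from the script itself; the exact invocation is an assumption): it takes the batch size N and the channel count C (64, 128, 256, or 512) on the command line, warms up, then times two stacked CxC 3x3 convolutions over 20 iterations and prints the average latency in microseconds. For example, python torchconv2d.py 32 256 times the 256-channel pair at batch size 32, apparently matching the first two columns of resnet_results.csv above.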
/src/ml-bench/volta_conv2d/vgg-results-cuda-12.2: -------------------------------------------------------------------------------- 1 | 1 & 256 & baseline & 122.67 & 0.59 2 | 1 & 256 & rowsync & 122.67 & 0.59 & 97.11 & 1.18 & 20.83 3 | 1 & 256 & tilesync & 122.67 & 0.59 & 106.17 & 0.92 & 13.45 4 | 4 & 256 & baseline & 193.44 & 0.86 5 | 4 & 256 & rowsync & 193.44 & 0.86 & 152.17 & 2.15 & 21.34 6 | 4 & 256 & tilesync & 193.44 & 0.86 & 141.50 & 2.01 & 26.85 7 | 8 & 256 & baseline & 241.83 & 1.34 8 | 8 & 256 & rowsync & 241.83 & 1.34 & 207.94 & 1.86 & 14.01 9 | 8 & 256 & tilesync & 241.83 & 1.34 & 210.67 & 2.68 & 12.89 10 | 12 & 256 & baseline & 279.33 & 1.33 11 | 12 & 256 & rowsync & 279.33 & 1.33 & 252.44 & 1.25 & 9.63 12 | 12 & 256 & tilesync & 279.33 & 1.33 & 257.28 & 2.16 & 7.90 13 | 16 & 256 & baseline & 394.83 & 0.86 14 | 16 & 256 & rowsync & 394.83 & 0.86 & 360.00 & 1.14 & 8.82 15 | 16 & 256 & tilesync & 394.83 & 0.86 & 362.17 & 1.29 & 8.27 16 | 20 & 256 & baseline & 335.33 & 1.14 17 | 20 & 256 & rowsync & 335.33 & 1.14 & 458.56 & 4.51 & -36.75 18 | 20 & 256 & tilesync & 335.33 & 1.14 & 465.61 & 5.24 & -38.85 19 | 24 & 256 & baseline & 340.17 & 1.34 20 | 24 & 256 & rowsync & 340.17 & 1.34 & 469.39 & 5.78 & -37.99 21 | 24 & 256 & tilesync & 340.17 & 1.34 & 473.72 & 4.38 & -39.26 22 | 28 & 256 & baseline & 467.44 & 3.07 23 | 28 & 256 & rowsync & 467.44 & 3.07 & 363.00 & 1.78 & 22.34 24 | 28 & 256 & tilesync & 467.44 & 3.07 & 365.89 & 1.41 & 21.73 25 | 32 & 256 & baseline & 477.61 & 1.58 26 | 32 & 256 & rowsync & 477.61 & 1.58 & 423.28 & 1.96 & 11.38 27 | 32 & 256 & tilesync & 477.61 & 1.58 & 424.44 & 2.01 & 11.13 28 | 1 & 512 & baseline & 168.67 & 0.84 29 | 1 & 512 & rowsync & 168.67 & 0.84 & 152.22 & 1.48 & 9.75 30 | 1 & 512 & tilesync & 168.67 & 0.84 & 149.50 & 2.23 & 11.36 31 | 4 & 512 & baseline & 218.22 & 0.73 32 | 4 & 512 & rowsync & 218.22 & 0.73 & 195.33 & 1.19 & 10.49 33 | 4 & 512 & tilesync & 218.22 & 0.73 & 189.44 & 2.41 & 13.19 34 | 8 & 512 & baseline & 284.78 & 1.80 35 | 8 & 512 & rowsync & 284.78 & 1.80 & 260.06 & 1.47 & 8.68 36 | 8 & 512 & tilesync & 284.78 & 1.80 & 268.11 & 2.68 & 5.85 37 | 12 & 512 & baseline & 299.06 & 1.16 38 | 12 & 512 & rowsync & 299.06 & 1.16 & 286.28 & 1.23 & 4.27 39 | 12 & 512 & tilesync & 299.06 & 1.16 & 287.06 & 1.83 & 4.01 40 | 16 & 512 & baseline & 369.61 & 2.15 41 | 16 & 512 & rowsync & 369.61 & 2.15 & 416.11 & 1.75 & -12.58 42 | 16 & 512 & tilesync & 369.61 & 2.15 & 413.83 & 4.71 & -11.96 43 | 20 & 512 & baseline & 457.89 & 2.78 44 | 20 & 512 & rowsync & 457.89 & 2.78 & 415.06 & 2.94 & 9.35 45 | 20 & 512 & tilesync & 457.89 & 2.78 & 432.17 & 2.62 & 5.62 46 | 24 & 512 & baseline & 603.67 & 11.37 47 | 24 & 512 & rowsync & 603.67 & 11.37 & 551.94 & 3.69 & 8.57 48 | 24 & 512 & tilesync & 603.67 & 11.37 & 551.44 & 3.48 & 8.65 49 | 28 & 512 & baseline & 682.50 & 5.83 50 | 28 & 512 & rowsync & 682.50 & 5.83 & 627.56 & 3.75 & 8.05 51 | 28 & 512 & tilesync & 682.50 & 5.83 & 619.00 & 2.89 & 9.30 52 | 32 & 512 & baseline & 696.39 & 1.20 53 | 32 & 512 & rowsync & 696.39 & 1.20 & 662.67 & 1.37 & 4.84 54 | 32 & 512 & tilesync & 696.39 & 1.20 & 662.89 & 1.53 & 4.81 -------------------------------------------------------------------------------- /tests/cusync-test.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #pragma once 6 | 7 | namespace cusync { 8 | class CuSyncTest { 9 | private: 10 | uint numStages; 11 | char* semValidArray; 12 | 13 | public: 14 | CuSyncTest(int numStages_) : 
numStages(numStages_) { 15 | CUDA_CHECK(cudaMalloc(&semValidArray, numStages * sizeof(char))); 16 | CUDA_CHECK(cudaMemset(semValidArray, 1, numStages * sizeof(char))); 17 | } 18 | 19 | template<typename CuStage> 20 | __device__ 21 | void setSemValue(uint stageIdx, dim3 tile, CuStage& custage) { 22 | if (!custage.isConsumer()) return; 23 | if (threadIdx.x == 0) { 24 | char eq = (char)((custage.expectedWaitValue(tile) * custage.iter) == custage.waitSemValue(tile)); 25 | semValidArray[stageIdx] = (bool) eq && ((bool) semValidArray[stageIdx]); 26 | } 27 | } 28 | 29 | bool allSemsCorrect() { 30 | char* hostSemValids = new char[numStages]; 31 | 32 | CUDA_CHECK(cudaMemcpy(hostSemValids, semValidArray, numStages * sizeof(char), cudaMemcpyDeviceToHost)); 33 | 34 | bool eq = true; 35 | for (uint i = 0; i < numStages; i++) { 36 | eq = eq && (bool)hostSemValids[i]; 37 | } 38 | 39 | delete[] hostSemValids; 40 | return eq; 41 | } 42 | 43 | ~CuSyncTest() {} 44 | }; 45 | } -------------------------------------------------------------------------------- /tests/simple-test.cu: -------------------------------------------------------------------------------- 1 | #include "cusync-test.h" 2 | 3 | #include "gtest/gtest.h" 4 | 5 | typedef uint ElemType; 6 | 7 | using namespace cusync; 8 | 9 | /* 10 | * This kernel copies elements of the in array to the out array. 11 | * Each thread block waits until the status of its tile (thread block) has reached 12 | * the expected value, performs the copy, and then posts the status of its tile. 13 | */ 14 | template<typename CuStage> 15 | __global__ 16 | void kernel(CuSyncTest cutest, CuStage custage, int idx, ElemType* in, ElemType* out) { 17 | dim3 tile = blockIdx; 18 | __shared__ int tileSh[3]; 19 | tile = custage.tile((dim3*)&tileSh[0]); 20 | custage.wait(tile); 21 | 22 | uint linearid = threadIdx.x + blockIdx.x * blockDim.x; 23 | out[linearid] = in[linearid]; 24 | 25 | custage.post(tile); 26 | cutest.setSemValue(idx, tile, custage); 27 | } 28 | 29 | /* 30 | * The test runs two kernels: the producer copies the source array into a second array 31 | * and the consumer copies the second array into a third, synchronized with the given policy. 32 | * Finally, it checks that both copies and the semaphore values equal the expected values.
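 *
 * The launch pattern in run() below: the producer kernel is launched on prod_stream,
 * prod.invokeWaitKernel(cons_stream) enqueues a wait on the producer's semaphores in the
 * consumer's stream so the consumer kernel does not run before the producer has posted,
 * and then the consumer kernel is launched on cons_stream. incrementIter() advances the
 * iteration count, so the expected semaphore value (expectedWaitValue(tile) * iter, as
 * checked in CuSyncTest::setSemValue) scales with the number of completed iterations.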
33 | */ 34 | template<typename ProdCuStage, typename ConsCuStage, typename SyncPolicy> 35 | bool run(int iters) { 36 | ElemType* array1, *array2, *array3; 37 | size_t size = 1 << 20; 38 | 39 | //Allocate three arrays 40 | CUDA_CHECK(cudaMalloc(&array1, size * sizeof(ElemType))); 41 | CUDA_CHECK(cudaMalloc(&array2, size * sizeof(ElemType))); 42 | CUDA_CHECK(cudaMalloc(&array3, size * sizeof(ElemType))); 43 | 44 | ElemType* hostarray = new ElemType[size]; 45 | //Initialize input array 46 | for (uint i = 0; i < size; i++) { 47 | hostarray[i] = i; 48 | } 49 | 50 | CUDA_CHECK(cudaMemcpy(array1, hostarray, size * sizeof(ElemType), cudaMemcpyHostToDevice)); 51 | 52 | CUDA_CHECK(cudaMemset(array2, 0, size * sizeof(ElemType))); 53 | CUDA_CHECK(cudaMemset(array3, 0, size * sizeof(ElemType))); 54 | 55 | cudaStream_t prod_stream, cons_stream; 56 | CUDA_CHECK(cudaStreamCreateWithFlags(&cons_stream, cudaStreamNonBlocking)); 57 | CUDA_CHECK(cudaStreamCreateWithFlags(&prod_stream, cudaStreamNonBlocking)); 58 | 59 | dim3 threads(128, 1, 1); 60 | dim3 grid(size/threads.x, 1, 1); 61 | 62 | //Expected value of each semaphore is 1 63 | SyncPolicy sync; 64 | ProdCuStage prod(grid, threads, NoSync(), sync); 65 | ConsCuStage cons(grid, threads, sync, NoSync()); 66 | CuSync::setProducerConsumerPair(prod, cons); 67 | 68 | CuSyncTest cutest(1); 69 | 70 | //Invoke both kernels 71 | int i = 0; 72 | while (i < iters) { 73 | kernel<<<grid, threads, 0, prod_stream>>>(cutest, prod, -1, array1, array2); 74 | prod.invokeWaitKernel(cons_stream); 75 | kernel<<<grid, threads, 0, cons_stream>>>(cutest, cons, 0, array2, array3); 76 | 77 | CUDA_CHECK(cudaDeviceSynchronize()); 78 | prod.incrementIter(); 79 | cons.incrementIter(); 80 | i++; 81 | } 82 | 83 | //Check that copies to array2 and array3 are correct 84 | CUDA_CHECK(cudaMemcpy(hostarray, array2, size * sizeof(ElemType), cudaMemcpyDeviceToHost)); 85 | bool eq = true; 86 | for (uint i = 0; i < size; i++) { 87 | eq = eq && (hostarray[i] == i); 88 | } 89 | 90 | CUDA_CHECK(cudaMemcpy(hostarray, array3, size * sizeof(ElemType), cudaMemcpyDeviceToHost)); 91 | for (uint i = 0; i < size; i++) { 92 | eq = eq && (hostarray[i] == i); 93 | } 94 | 95 | //Check that value of each semaphore is equal to the expected value 96 | eq = eq && cutest.allSemsCorrect(); 97 | 98 | //Cleanup 99 | delete[] hostarray; 100 | CUDA_CHECK(cudaFree(array1)); 101 | CUDA_CHECK(cudaFree(array2)); 102 | CUDA_CHECK(cudaFree(array3)); 103 | 104 | return eq; 105 | } 106 | 107 | TEST(SimpleTest_TileSync, NoOpts) { 108 | using Sync = TileSync; 109 | using ProdCuStage = CuStage; 110 | using ConsCuStage = CuStage; 111 | bool result = run<ProdCuStage, ConsCuStage, Sync>(1); 112 | EXPECT_TRUE(result); 113 | } 114 | 115 | TEST(SimpleTest_TileSync_MultiIters, NoOpts) { 116 | using Sync = TileSync; 117 | using ProdCuStage = CuStage; 118 | using ConsCuStage = CuStage; 119 | bool result = run<ProdCuStage, ConsCuStage, Sync>(2); 120 | EXPECT_TRUE(result); 121 | } 122 | 123 | TEST(SimpleTest_TileSync, NoAtomicAdd) { 124 | using Sync = TileSync; 125 | using ProdCuStage = CuStage; 126 | using ConsCuStage = CuStage; 127 | 128 | bool result = run<ProdCuStage, ConsCuStage, Sync>(1); 129 | EXPECT_TRUE(result); 130 | } 131 | 132 | TEST(SimpleTest_TileSync, AvoidCustomOrder) { 133 | using Sync = TileSync; 134 | using ProdCuStage = CuStage; 135 | using ConsCuStage = CuStage; 136 | 137 | bool result = run<ProdCuStage, ConsCuStage, Sync>(1); 138 | EXPECT_TRUE(result); 139 | } --------------------------------------------------------------------------------
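A note on the result rows above (the column names are inferred from the data, not taken from the repository): in the rowsync/tilesync rows of files such as mlp-gpt3-a100.csv and mlp-llama-a100.csv, the fields are batch size, a constant 0, hidden size, method, baseline time, baseline stdev, the two per-GEMM times (which sum to roughly the baseline), the CuSync time, its stdev, and the percentage improvement, computed as (baseline - cusync) / baseline * 100. All times appear to be microseconds, consistent with the torch baseline scripts above printing microseconds. A minimal Python sketch that re-derives the last column under these assumptions:

# Decode an '&'-separated rowsync/tilesync result row; column layout as inferred above.
def parse_row(line):
    parts = [p.strip() for p in line.split('&')]
    batch, hidden, method = int(parts[0]), int(parts[2]), parts[3]
    nums = [float(p) for p in parts[4:]]   # 7 numeric fields in rowsync/tilesync rows
    baseline, cusync = nums[0], nums[4]
    improvement = (baseline - cusync) / baseline * 100
    return batch, hidden, method, improvement

row = "16 & 0 & 12288 & rowsync & 395.20 & 1.59 & 205.24 & 189.96 & 392.41 & 1.74 & 0.71"
print(parse_row(row))  # (16, 12288, 'rowsync', ~0.71), matching the recorded last column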