├── documents ├── multigpu.md ├── contributing.md ├── autotuning.md └── python-api.md ├── requirements.txt ├── tests ├── benchmarks │ ├── requirements.txt │ ├── single-a100-flops.png │ ├── single-x86-flops.png │ ├── pykronecker-ex.py │ ├── process_fullresults.py │ ├── CMakeLists.txt │ ├── torch_kron.py │ └── profile_kernels.py ├── src │ ├── general-tests-all.cpp │ ├── select-backend.h │ ├── distinct-shapes.cpp │ ├── old │ │ ├── odd-shapes.cpp │ │ ├── non-square-tuner-tests.cpp │ │ ├── non-square-TT-tests.cpp │ │ ├── tuner-tests.cpp │ │ ├── TT-tests.cpp │ │ ├── fusion-tests.cpp │ │ └── no-fusion-tests.cpp │ ├── multi-cuda-distinct-shapes.cpp │ ├── multi-cuda-tuner-tests.cpp │ ├── multi-cuda-tests-kernel_decl.inc │ ├── multi-cuda-no-fusion-non-square-tests.cpp │ ├── multi-cuda-no-fusion-tests.cpp │ └── general-tests-TT.cpp ├── x86 │ ├── memcheck.sh │ └── CMakeLists.txt ├── update-configs.py ├── CMakeLists.txt ├── cuda │ ├── single-cuda-kernel-decls.in │ ├── CMakeLists.txt │ └── old │ │ └── CMakeLists.txt ├── run-tests.py └── python │ ├── test_wheels.py │ ├── test_numpy.py │ └── test_torch.py ├── src ├── kernels │ ├── hip_kernel_info.h │ ├── gpu_kmmkernel.cu │ ├── cuda │ │ ├── utils.cuh │ │ ├── mma.cuh │ │ └── register-loads.cuh │ ├── cpu │ │ ├── mma.h │ │ ├── memory-store.h │ │ └── tensor.h │ ├── cpu_kmmkernel.h │ ├── cuda_kmmkernel.h │ ├── get_batched_data.h │ ├── best-kernels │ │ ├── kmm-v100-kernels │ │ ├── kmm-a100-kernels │ │ ├── x86-avx-kernels │ │ └── kmm-x86-avx-kernels │ ├── kmmkernel.cpp │ ├── hw_details.h │ ├── gpu_kmmkernel.h │ └── kernel_opt.h ├── kmm │ ├── matrix.cpp │ ├── coord.h │ └── stackarray.h ├── handle │ ├── op.h │ ├── distrib_handle.cpp │ └── handle.inline.h ├── best_performing_kernels.in ├── env │ ├── env.h │ └── env.cpp ├── utils │ ├── logger.h │ ├── thread_pool.h │ └── utils.h ├── kernel_db │ ├── kernel_db.inline.h │ ├── hip_kernel_db.h │ └── hip_kernel_db.hip ├── config.h ├── optimized_kernel_map.py ├── print_best_kernel_for_shapes.py └── autotuner │ └── autotuner.h ├── Dockerfile ├── example ├── Makefile ├── x86-example.cpp └── cuda-example.cu ├── packaging ├── manylinux_docker_build.sh ├── any_build.sh └── wheels.py ├── pyfastkron ├── __init__.py └── fastkronhandle.py ├── .gitmodules ├── libFastKron.pc.in ├── include └── fastkronMg.h ├── pyproject.toml ├── LICENSE.txt ├── setup.py └── README.md /documents/multigpu.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy >= 1.0 2 | torch >=1.10 3 | torchvision >=0.1 4 | -------------------------------------------------------------------------------- /tests/benchmarks/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch 2 | torchvision 3 | torchaudio 4 | gpytorch -------------------------------------------------------------------------------- /src/kernels/hip_kernel_info.h: -------------------------------------------------------------------------------- 1 | #include "gpu_kmmkernel.h" 2 | 3 | typedef GPUKMMKernel HIPKernel; -------------------------------------------------------------------------------- /tests/src/general-tests-all.cpp: -------------------------------------------------------------------------------- 1 | #include "general-tests-NN.cpp" 2 | #include "general-tests-TT.cpp" 
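Note on semantics (not a repository file): the operation implemented throughout this repository multiplies a dense matrix X of shape M x (P1*...*PN) with a chain of Kronecker factors F1 ⊗ F2 ⊗ ... ⊗ FN, where each Fi is Pi x Qi, without ever materializing the Kronecker product. The naive C++ reference below, written only as a reading aid for small shapes, spells out that operation; it mirrors the `baseline` function in tests/benchmarks/torch_kron.py, and the function and variable names are illustrative only.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference only: build the full Kronecker product of NumFacs row-major factors
// (factor i is Ps[i] x Qs[i]), then multiply X (M x prod(Ps)) by it.
// This costs prod(Ps) * prod(Qs) memory, so it is usable only as a correctness
// baseline for small shapes.
std::vector<float> kronBaseline(uint32_t M, uint32_t NumFacs,
                                const uint32_t Ps[], const uint32_t Qs[],
                                const std::vector<float>& X,
                                const std::vector<std::vector<float>>& Fs) {
  std::vector<float> kron = Fs[0];
  uint32_t rows = Ps[0], cols = Qs[0];
  for (uint32_t i = 1; i < NumFacs; i++) {
    std::vector<float> next((size_t)rows * Ps[i] * cols * Qs[i]);
    for (uint32_t r = 0; r < rows; r++)
      for (uint32_t c = 0; c < cols; c++)
        for (uint32_t p = 0; p < Ps[i]; p++)
          for (uint32_t q = 0; q < Qs[i]; q++)
            next[(size_t)(r*Ps[i] + p) * (cols*Qs[i]) + (c*Qs[i] + q)] =
              kron[(size_t)r*cols + c] * Fs[i][(size_t)p*Qs[i] + q];
    rows *= Ps[i]; cols *= Qs[i];
    kron = std::move(next);
  }
  // Z = X * (F1 kron F2 kron ... kron FN); X is M x rows, Z is M x cols.
  std::vector<float> Z((size_t)M * cols, 0.0f);
  for (uint32_t m = 0; m < M; m++)
    for (uint32_t k = 0; k < rows; k++)
      for (uint32_t n = 0; n < cols; n++)
        Z[(size_t)m*cols + n] += X[(size_t)m*rows + k] * kron[(size_t)k*cols + n];
  return Z;
}

int main() {
  // Two 2x2 identity factors: their Kronecker product is the 4x4 identity, so Z == X.
  uint32_t Ps[2] = {2, 2}, Qs[2] = {2, 2};
  std::vector<std::vector<float>> Fs = {{1,0,0,1}, {1,0,0,1}};
  std::vector<float> X = {1,2,3,4};   // M = 1, K = 2*2 = 4
  std::vector<float> Z = kronBaseline(1, 2, Ps, Qs, X, Fs);
  for (int i = 0; i < 4; i++) assert(Z[i] == X[i]);
  return 0;
}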
-------------------------------------------------------------------------------- /documents/contributing.md: -------------------------------------------------------------------------------- 1 | Create wheels package 2 | 3 | Running tests 4 | 5 | Use flake8 6 | 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM nvcr.io/nvidia/pytorch:23.11-py3 3 | RUN mkdir /fastkron/ 4 | COPY . /fastkron/ 5 | -------------------------------------------------------------------------------- /tests/src/select-backend.h: -------------------------------------------------------------------------------- 1 | #include <fastkron.h> 2 | 3 | fastKronBackend getTestBackend() { 4 | #if defined(TEST_BACKEND_CUDA) 5 | return fastKronBackend_CUDA; 6 | #elif defined(TEST_BACKEND_HIP) 7 | return fastKronBackend_HIP; 8 | #else 9 | return fastKronBackend_X86; 10 | #endif 11 | } -------------------------------------------------------------------------------- /tests/benchmarks/single-a100-flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhijangda/fastkron/HEAD/tests/benchmarks/single-a100-flops.png -------------------------------------------------------------------------------- /tests/benchmarks/single-x86-flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhijangda/fastkron/HEAD/tests/benchmarks/single-x86-flops.png -------------------------------------------------------------------------------- /tests/x86/memcheck.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | OMP_NUM_THREADS=1 valgrind --leak-check=full --show-leak-kinds=all --xml=yes --xml-file=l.xml $1 4 | -------------------------------------------------------------------------------- /example/Makefile: -------------------------------------------------------------------------------- 1 | cuda-example: cuda-example.cu 2 | nvcc $< -L ../build/ -lFastKron -I ../include -o $@ -g 3 | 4 | x86-example: x86-example.cpp 5 | g++ $< -L ../build/ -lFastKron -I ../include -o $@ -g 6 | -------------------------------------------------------------------------------- /src/kmm/matrix.cpp: -------------------------------------------------------------------------------- 1 | #include <functional> 2 | 3 | #include "kmm/matrix.h" 4 | 5 | std::size_t std::hash<Factor>::operator()(const Factor& m) const { 6 | return hash<uint32_t>()(m.p()) ^ hash<uint32_t>()(m.q()); 7 | } -------------------------------------------------------------------------------- /packaging/manylinux_docker_build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export PATH="/opt/python/$1-$1/bin:$PATH" 4 | cd $2 5 | /opt/python/$1-$1/bin/pip install setuptools-scm 6 | rm -rf build/ 7 | python setup.py bdist_wheel -------------------------------------------------------------------------------- /pyfastkron/__init__.py: -------------------------------------------------------------------------------- 1 | # Read version number as written by setuptools_scm 2 | try: 3 | from pyfastkron.version import version as __version__ # @manual 4 | except Exception: # pragma: no cover 5 | __version__ = "Unknown" -------------------------------------------------------------------------------- /packaging/any_build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export PATH="/opt/python/cp311-cp311/bin:$PATH" 4 | /opt/python/cp311-cp311/bin/pip install setuptools-scm 5 | cd $1 6 | rm -rf 
build/ 7 | BUILD_ANY_WHEEL=1 python3.11 setup.py bdist_wheel -------------------------------------------------------------------------------- /src/handle/op.h: -------------------------------------------------------------------------------- 1 | #include "fastkron.h" 2 | #include "config.h" 3 | 4 | #pragma once 5 | 6 | std::string fastKronOpToStr(const fastKronOp& op); 7 | std::ostream& operator<<(std::ostream& os, const fastKronOp& op); 8 | fastKronOp swapFastKronOp(fastKronOp op); -------------------------------------------------------------------------------- /src/best_performing_kernels.in: -------------------------------------------------------------------------------- 1 | 256,16,16,16,1,4096,1,16,2 2 | 128,16,16,16,1,2048,1,16,1 3 | 128,8,8,8,1,1024,1,8,2 4 | 512,8,8,8,1,4096,1,8,3 5 | 128,8,8,8,1,1024,1,8,1 6 | 256,32,32,32,1,4096,1,16,1 7 | 256,32,32,32,1,8192,1,32,2 8 | 256,32,32,32,1,8192,1,32,2 -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/googletest"] 2 | path = tests/googletest 3 | url = https://github.com/google/googletest.git 4 | [submodule "tests/benchmarks/AnyOption"] 5 | path = tests/benchmarks/AnyOption 6 | url = https://github.com/hackorama/AnyOption.git 7 | [submodule "pybind11"] 8 | path = pybind11 9 | url = ../../pybind/pybind11 10 | branch = stable 11 | -------------------------------------------------------------------------------- /tests/benchmarks/pykronecker-ex.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | from pykronecker import KroneckerProduct 4 | 5 | As = [] 6 | for i in range(0,10): 7 | As.append(jnp.asarray(np.random.normal(size=(4,4)))) 8 | x = jnp.asarray(np.random.normal(size=(320, 4**10))) 9 | 10 | KP = KroneckerProduct(As) 11 | for i in range(10): 12 | y = x @ KP 13 | -------------------------------------------------------------------------------- /src/env/env.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | enum DistComm { 4 | DistCommNone = 0, 5 | P2P, 6 | NCCL, 7 | }; 8 | 9 | 10 | std::ostream& operator<<(std::ostream &out, DistComm comm); 11 | 12 | enum LogLevel { 13 | Nothing = 0, 14 | Info = 1, 15 | Debug = 2 16 | }; 17 | 18 | namespace env { 19 | DistComm getDistComm(); 20 | LogLevel getLogLevel(); 21 | bool getUseTune(); 22 | } -------------------------------------------------------------------------------- /tests/update-configs.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | f = open(sys.argv[1], 'r+') 5 | s = f.read() 6 | f.close() 7 | newS = "" 8 | for d in re.findall(r'.+', s): 9 | (th, q, p, tileQ, tileM, k, rp, rc, fused, elem, rowModTileIsZero, kEqVar, dist) = d.split(',') 10 | newS += ",".join((th, q, p, tileQ, k, tileM, fused, dist, rp, rc, elem, rowModTileIsZero, kEqVar)) + "\n" 11 | 12 | print(newS) -------------------------------------------------------------------------------- /libFastKron.pc.in: -------------------------------------------------------------------------------- 1 | prefix="@CMAKE_INSTALL_PREFIX@" 2 | exec_prefix="${prefix}" 3 | libdir="${prefix}/lib" 4 | includedir="${prefix}/include" 5 | 6 | Name: @PROJECT_NAME@ 7 | Description: @CMAKE_PROJECT_DESCRIPTION@ 8 | URL: @CMAKE_PROJECT_HOMEPAGE_URL@ 9 | Version: @PROJECT_VERSION@ 10 | Requires: 
@pc_req_public@ 11 | Requires.private: @pc_req_private@ 12 | Cflags: -I"${includedir}" 13 | Libs: -L"${libdir}" -lFastKron -------------------------------------------------------------------------------- /tests/benchmarks/process_fullresults.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | 4 | def slurp(filepath): 5 | with open(filepath, 'r') as f: 6 | return f.read() 7 | 8 | def process(filepath): 9 | for line in re.findall('.+', slurp(filepath)): 10 | split = line.split('&') 11 | if len(split) == 5 and float(split[4].strip()) <= 1 and float(split[4].strip()) >= 0 and float(split[1].strip()) != 1.00: 12 | print(split) 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | process(sys.argv[1]) -------------------------------------------------------------------------------- /tests/src/distinct-shapes.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | TEST(EXPAND(TEST_BACKEND,DistinctShapes), Case1) { 5 | uint KP_MAT_N[] = {16,32,8,32}; 6 | uint KP_MAT_K[] = {8,8,16,8}; 7 | uint N = 1; 8 | uint K = 1; 9 | for (uint i = 0; i < (uint)4; i++) { 10 | N *= KP_MAT_N[i]; 11 | K *= KP_MAT_K[i]; 12 | } 13 | bool b = run(128, N, K, 4, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1, 0, false, 1, 1, 1, 1, true, false, true, getTestBackend(), false); 14 | EXPECT_TRUE(b); 15 | } -------------------------------------------------------------------------------- /src/kernels/gpu_kmmkernel.cu: -------------------------------------------------------------------------------- 1 | #include "kernels/gpu_kmmkernel.h" 2 | 3 | std::string GPUKMMKernel::str() const { 4 | std::stringstream info; 5 | info << backend() << "_" << arch() << "_" 6 | << getNumThreads() << "_" << KMMKernel::str(); 7 | return info.str(); 8 | } 9 | 10 | bool GPUKMMKernel::canCompute(KMMProblem problem, const HardwareDetails* hw, bool p2p, 11 | KernelBatchType::Ty probBatchType, 12 | bool exactFuse) { 13 | return KMMKernel::canCompute(problem, hw, p2p, probBatchType, exactFuse); 14 | } -------------------------------------------------------------------------------- /tests/src/old/odd-shapes.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define NON_SQUARE(M, Facs, P, Q, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,OddShape), Type##_##M##x##Facs##x##P##x##Q##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= Q;\ 12 | K *= P;\ 13 | KP_MAT_K[i] = P;\ 14 | KP_MAT_N[i] = Q;\ 15 | }\ 16 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1, 0, false, 1, 1, 1, 1, true, false, true, getTestBackend(), false);\ 17 | EXPECT_TRUE(b);\ 18 | } 19 | 20 | NON_SQUARE(12, 2, 31, 16, float) 21 | NON_SQUARE(8, 2, 16, 31, float) 22 | NON_SQUARE(6, 4, 31, 31, float) -------------------------------------------------------------------------------- /src/kernels/cuda/utils.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #pragma once 5 | 6 | #define MIN(x,y) (((x) < (y)) ? (x) : (y)) 7 | #define MAX(x,y) (((x) > (y)) ? 
(x) : (y)) 8 | #define DIVUP(x, y) (((x) + (y) - 1)/((y))) 9 | 10 | __host__ __device__ constexpr uint power(const uint x, const uint y) { 11 | uint result = 1; 12 | for (uint i = 0; i < y; i++) { 13 | result = result * x; 14 | } 15 | return result; 16 | } 17 | 18 | __device__ __forceinline__ bool isfirstIdx(dim3 idx) {return idx.x == 0 && idx.y == 0 & idx.z == 0;} 19 | 20 | template 21 | __device__ __forceinline__ 22 | size_t nonAlignedElems(const ElemT* ptr, uint vecElems) { 23 | return (reinterpret_cast(ptr)/sizeof(ElemT)) % vecElems; 24 | } -------------------------------------------------------------------------------- /tests/src/multi-cuda-distinct-shapes.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define MULTI_GPU_DISTINCT_SHAPES_TEST(M, GM, GK, LocalKrons) \ 5 | TEST(MultiGPUDistinctShapesTest, GM##_##GK##_##LocalKrons##_) {\ 6 | uint KP_MAT_N[] = {16,8,8,8};\ 7 | uint KP_MAT_K[] = {8,32,16,32};\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)4; i++) {\ 11 | N *= KP_MAT_N[i];\ 12 | K *= KP_MAT_K[i];\ 13 | }\ 14 | bool b = run(FastKronMMType::MKM, M, N, K, 4, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1,1,1,1, 1.0f, 0.0f, 1, 0, false, GM, GK, GM*GK, LocalKrons, true, false, true, fastKronBackend_CUDA, false, false);\ 15 | EXPECT_TRUE(b);\ 16 | } 17 | 18 | MULTI_GPU_DISTINCT_SHAPES_TEST(256, 1, 2, 1); 19 | MULTI_GPU_DISTINCT_SHAPES_TEST(256, 2, 2, 1); -------------------------------------------------------------------------------- /tests/src/old/non-square-tuner-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define NON_SQUARE(M, Facs, P, Q, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,NonSquare), Type##_##M##x##Facs##x##P##x##Q##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= Q;\ 12 | K *= P;\ 13 | KP_MAT_K[i] = P;\ 14 | KP_MAT_N[i] = Q;\ 15 | }\ 16 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1, 0, false, 1, 1, 1, 1, true, false, true, getTestBackend(), false);\ 17 | EXPECT_TRUE(b);\ 18 | } 19 | 20 | NON_SQUARE(11, 4, 8, 16, float) 21 | NON_SQUARE(12, 5, 8, 16, float) 22 | 23 | NON_SQUARE(12, 3, 128, 32, float) 24 | 25 | NON_SQUARE(11, 3, 32, 16, float) 26 | NON_SQUARE(12, 4, 32, 16, float) -------------------------------------------------------------------------------- /tests/src/multi-cuda-tuner-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define MULTI_GPU_TUNER_TEST(M, Facs, FacSize, GM, GK, KronBatch, Type) \ 5 | TEST(MultiGpuTuner, Type##_##M##x##Facs##x##FacSize##x##GM##x##GK##x##KronBatch##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | bool b = run(FastKronMMType::MKM, M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N,1,1,1,1, 1.0f, 0.0f, 0, 0, false, GM, GK, GM*GK, KronBatch, true, true, true, fastKronBackend_CUDA, false, false);\ 16 | EXPECT_TRUE(b);\ 17 | } 18 | 19 | MULTI_GPU_TUNER_TEST(128, 5, 16, 2, 1, 5, float); 20 | MULTI_GPU_TUNER_TEST(128, 4, 16, 2, 2, 3, float); 
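The size constants used by the tests above and by the kernel declarations that follow come from repeated factor multiplication, which is what the constexpr power() helper in src/kernels/cuda/utils.cuh computes. The short illustration below is not a repository file; the helper is copied here without its __host__ __device__ qualifiers so it compiles as plain C++, and it only shows where the value 1048576 comes from.

#include <cstdint>

// Same computation as power() in src/kernels/cuda/utils.cuh.
constexpr uint32_t power(const uint32_t x, const uint32_t y) {
  uint32_t result = 1;
  for (uint32_t i = 0; i < y; i++) result = result * x;
  return result;
}

// A problem with Facs factors of shape P x Q maps an M x P^Facs input to an M x Q^Facs output.
static_assert(power(16, 5) == 1048576, "K for MULTI_GPU_TUNER_TEST(128, 5, 16, ...)");
static_assert(power(32, 4) == 1048576, "MIN_K/MAX_K with factor size 32 in multi-cuda-tests-kernel_decl.inc");

int main() { return 0; }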
-------------------------------------------------------------------------------- /tests/src/old/non-square-TT-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define NON_SQUARE_TT(M, Facs, P, Q, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,NonSquareTT), Type##_##M##x##Facs##x##P##x##Q##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= Q;\ 12 | K *= P;\ 13 | KP_MAT_K[i] = P;\ 14 | KP_MAT_N[i] = Q;\ 15 | }\ 16 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_T, fastKronOp_T, 1, 0, false, 1, 1, 1, 1, true, false, true, getTestBackend(), false);\ 17 | EXPECT_TRUE(b);\ 18 | } 19 | 20 | NON_SQUARE_TT(11, 4, 8, 16, float) 21 | NON_SQUARE_TT(12, 5, 8, 16, float) 22 | 23 | NON_SQUARE_TT(12, 3, 128, 32, float) 24 | 25 | NON_SQUARE_TT(11, 3, 32, 16, float) 26 | NON_SQUARE_TT(12, 4, 32, 16, float) -------------------------------------------------------------------------------- /tests/src/multi-cuda-tests-kernel_decl.inc: -------------------------------------------------------------------------------- 1 | #define MAX_K 1048576 2 | #define MIN_K 1048576 3 | #define MIN_KP_K 32 4 | #define MAX_KP_K 32 5 | #define KERNEL_DECL(T, VecT, ElemType) \ 6 | KMMKernel{(void*)kronGemmKernel,128, 64, 64, 64, 2, 4096, 2, 16, 1, ElemType, 1, 0, 0},\ 7 | KMMKernel{(void*)kronGemmKernel,128, 64, 64, 64, 2, 4096, 2, 16, 1, ElemType, 1, 0, 1},\ 8 | KMMKernel{(void*)kronGemmKernel,128, 128, 128, 128, 1, 8192, 2, 32, 1, ElemType, 1, 0, 1},\ 9 | KMMKernel{(void*)kronGemmKernel,128, 128, 128, 128, 1, 8192, 2, 32, 1, ElemType, 1, 0, 0} -------------------------------------------------------------------------------- /tests/src/old/tuner-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define NO_FUSION_TUNER_TEST(M, Facs, FacSize, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,Tuner), Type##_##M##x##Facs##x##FacSize##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, \ 16 | 1, 0, false, 1, 1, 1, 1, true, true, true, fastKronBackend_CUDA, false);\ 17 | EXPECT_TRUE(b);\ 18 | } 19 | 20 | NO_FUSION_TUNER_TEST(512, 4, 16, float) 21 | 22 | NO_FUSION_TUNER_TEST(512, 3, 64, float) 23 | 24 | // SINGLE_GPU_NO_FUSION_TUNER_TEST(1, 3, 32, float) 25 | 26 | // SINGLE_GPU_NO_FUSION_TUNER_TEST(4, 2, 64, float) -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package (Python3) 2 | set(GEN_TUNER_KERNELS_PY ${SRC}/gen_tuner_kernels.py) 3 | 4 | ADD_SUBDIRECTORY(googletest EXCLUDE_FROM_ALL) 5 | enable_testing() 6 | include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) 7 | set(TEST_SRC ${CMAKE_CURRENT_SOURCE_DIR}/src) 8 | 9 | set(TESTS_INCLUDES ${CMAKE_CURRENT_SOURCE_DIR} ${SRC}) 10 | set(TESTS_LIBS GTest::gtest_main FastKron) 11 | if (ENABLE_MULTI_GPU AND ENABLE_CUDA) 12 | set(TESTS_LIBS ${TESTS_LIBS} nccl) 13 | endif() 14 | 15 | add_subdirectory(benchmarks) 16 | 17 | if (ENABLE_CUDA) 18 | add_subdirectory(cuda) 19 | 
add_custom_target(run-cuda-tests 20 | COMMAND single-gpu-cuda-all 21 | DEPENDS single-gpu-cuda-all) 22 | endif() 23 | 24 | if (ENABLE_X86) 25 | add_subdirectory(x86) 26 | add_custom_target(run-x86-tests 27 | COMMAND x86-cpu-all 28 | DEPENDS x86-cpu-all) 29 | endif() -------------------------------------------------------------------------------- /src/utils/logger.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "env/env.h" 4 | 5 | #pragma once 6 | 7 | class Logger { 8 | private: 9 | LogLevel level; 10 | 11 | public: 12 | Logger(LogLevel level) : level(level) { 13 | if (valid()) 14 | std::cout << "[FastKron] "; 15 | } 16 | 17 | template 18 | Logger& operator<< (const T &x) { 19 | if (valid()) { 20 | std::cout << x; 21 | } 22 | 23 | return *this; 24 | } 25 | 26 | Logger& operator<< (std::ostream& (*f)(std::ostream &)) { 27 | if (valid()) 28 | f(std::cout); 29 | return *this; 30 | } 31 | 32 | Logger& operator<< (std::ostream& (*f)(std::ios &)) { 33 | if (valid()) 34 | f(std::cout); 35 | return *this; 36 | } 37 | 38 | Logger& operator<< (std::ostream& (*f)(std::ios_base &)) { 39 | if (valid()) 40 | f(std::cout); 41 | return *this; 42 | } 43 | 44 | bool valid() {return level <= env::getLogLevel();} 45 | }; -------------------------------------------------------------------------------- /include/fastkronMg.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | #pragma once 7 | 8 | extern "C" { 9 | //backends is a bitwise OR 10 | fastKronError fastKronMgInitCUDA(fastKronHandle handle, void *streams, int gpus, int gpusInM = -1, int gpusInK = -1, int gpuLocalKrons = -1); 11 | 12 | //TODO: modify such that the results are always written to the supplied result pointer 13 | fastKronError fastKronMgSGEMM(fastKronHandle handle, const uint32_t NumKronMats, void* x[], void* kronMats[], void* result[], 14 | uint32_t M, uint32_t N, uint32_t K, uint32_t KronMatCols[], uint32_t KronMatRows[], 15 | void* temp1[], void* temp2[], void* stream); 16 | fastKronError fastKronMgAllocX(fastKronHandle handle, void* dX[], void* hX, uint32_t M, uint32_t K); 17 | fastKronError fastKronMgGatherY(fastKronHandle handle, void* dY[], void* hY, uint32_t M, uint32_t K, uint32_t NumKronMats, uint32_t KronMatCols[], uint32_t KronMatRows[]); 18 | } -------------------------------------------------------------------------------- /example/x86-example.cpp: -------------------------------------------------------------------------------- 1 | //x86-example.cpp 2 | #include 3 | 4 | #include 5 | 6 | int main() { 7 | //Define Problem Sizes 8 | uint32_t N = 5; 9 | uint32_t M = 16; 10 | uint32_t Ps[5] = {8,8,8,8,8}, Qs[5] = {8,8,8,8,8}; 11 | 12 | //Allocate inputs and output 13 | float* x, *fs[N], *z; 14 | x = new float[M * (int)powf(Ps[0], N)]; 15 | for (int i = 0; i < N; i++) fs[i] = new float[Ps[0]*Qs[0]]; 16 | z = new float[M * (int)powf(Qs[0], N)]; 17 | 18 | //Initialize FastKron with all backends (CUDA and x86 by default) 19 | fastKronHandle handle; 20 | fastKronInitAllBackends(&handle); 21 | 22 | //Get Temporary size and allocate temporary 23 | size_t tempSize, resultSize; 24 | gekmmSizes(handle, M, N, Ps, Qs, &resultSize, &tempSize); 25 | 26 | float* temp; 27 | temp = new float[tempSize]; 28 | 29 | //Do KronMatmul 30 | sgekmm(handle, fastKronBackend_X86, N, Qs, Ps, M, 31 | (float**)fs, fastKronOp_N, x, fastKronOp_N, z, 1, 0, nullptr, 32 | temp, nullptr); 33 | 34 | //Destroy FastKron 35 | 
fastKronDestroy(handle); 36 | } -------------------------------------------------------------------------------- /src/kmm/coord.h: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | 3 | // template 4 | // class Coord { 5 | // uint32_t val[Dim]; 6 | 7 | // protected: 8 | // Coord(uint32_t... args) { 9 | // for (int i = 0; i < Dim; i++) { 10 | // this->val[i] = val[i]; 11 | // } 12 | // } 13 | 14 | // uint32_t value(uint32_t i) const {return val[i];} 15 | // }; 16 | 17 | class Coord2D { 18 | uint32_t val[2]; 19 | 20 | public: 21 | CUDA_DEVICE_HOST 22 | Coord2D(uint32_t i, uint32_t j) { 23 | val[0] = i; 24 | val[1] = j; 25 | } 26 | 27 | CUDA_DEVICE_HOST 28 | uint32_t i() const {return val[0];} 29 | CUDA_DEVICE_HOST 30 | uint32_t j() const {return val[1];} 31 | }; 32 | 33 | class Coord3D { 34 | uint32_t val[3]; 35 | 36 | public: 37 | CUDA_DEVICE_HOST 38 | Coord3D(uint32_t i, uint32_t j, uint32_t k) { 39 | val[0] = i; 40 | val[1] = j; 41 | val[2] = k; 42 | } 43 | 44 | CUDA_DEVICE_HOST 45 | uint32_t i() const {return val[0];} 46 | CUDA_DEVICE_HOST 47 | uint32_t j() const {return val[1];} 48 | CUDA_DEVICE_HOST 49 | uint32_t k() const {return val[2];} 50 | }; -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "setuptools-scm>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name="pyfastkron" 7 | authors=[{name="Abhinav Jangda", email="abhijangda@gmail.com"}] 8 | maintainers = [{name="Abhinav Jangda", email="abhijangda@gmail.com"}] 9 | description="A library for efficient matrix and kronecker product matrix multiplication on parallel hardware" 10 | dynamic = ["version"] 11 | requires-python= ">= 3.9" 12 | license = {file="LICENSE.txt"} 13 | readme = "README.md" 14 | keywords = ["kronecker product", "cuda", "gpu", "kronecker matrix multiplication"] 15 | dependencies = [ 16 | "numpy", 17 | "torch", 18 | "torchvision" 19 | ] 20 | 21 | [project.urls] 22 | Homepage = "https://github.com/abhijangda/fastkron" 23 | Documentation = "https://github.com/abhijangda/fastkron" 24 | Repository = "https://github.com/abhijangda/fastkron" 25 | 26 | [tool.pytest.ini_options] 27 | minversion = "6.0" 28 | testpaths = [ 29 | "tests/python", 30 | ] 31 | 32 | [tool.setuptools_scm] 33 | local_scheme = "node-and-date" 34 | write_to = "./pyfastkron/version.py" -------------------------------------------------------------------------------- /tests/src/old/TT-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define TT_TEST(M, Facs, FacSize, Type) \ 5 | TEST(EXPAND(TEST_BACKEND, TT), Type##_##M##x##Facs##x##FacSize##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_T, fastKronOp_T, 1, 0, false, 1, 1, 1, 1, true, true, true, getTestBackend(), false);\ 16 | EXPECT_TRUE(b);\ 17 | } 18 | 19 | //FacSize 8 20 | TT_TEST(11, 4, 8, float); 21 | TT_TEST(11, 5, 8, float); 22 | TT_TEST(11, 8, 8, float); 23 | 24 | //FacSize 16 25 | TT_TEST(11, 2, 16, float); 26 | TT_TEST(11, 5, 16, float); 27 | // TT_TEST(11, 6, 16, float); 28 | 29 | 
//FacSize 32 30 | TT_TEST(11, 2, 32, float); 31 | TT_TEST(11, 3, 32, float); 32 | 33 | //FacSize 64 34 | TT_TEST(11, 2, 64, float); 35 | TT_TEST(11, 3, 64, float); 36 | 37 | //FacSize 128 38 | TT_TEST(11, 2, 128, float); 39 | TT_TEST(11, 3, 128, float); -------------------------------------------------------------------------------- /tests/src/multi-cuda-no-fusion-non-square-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define MULTI_GPU_NON_SQUARE_TEST(M, Facs, P, Q, GM, GK, KronBatch, Type) \ 5 | TEST(MultiGpuNonSquare, Type##_##M##x##Facs##x##P##x##Q##_##GM##x##GK##x##KronBatch##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= Q;\ 12 | K *= P;\ 13 | KP_MAT_K[i] = P;\ 14 | KP_MAT_N[i] = Q;\ 15 | }\ 16 | bool b = run(FastKronMMType::MKM, M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1,1,1,1,1.0f, 0.0f, 1, 0, false, GM, GK, GM*GK, KronBatch, true, false, true, fastKronBackend_CUDA, false, false);\ 17 | EXPECT_TRUE(b);\ 18 | } 19 | 20 | MULTI_GPU_NON_SQUARE_TEST(18, 5, 8, 32, 2, 1, 5, float); 21 | MULTI_GPU_NON_SQUARE_TEST(14, 5, 8, 32, 1, 2, 3, float); 22 | MULTI_GPU_NON_SQUARE_TEST(18, 5, 8, 32, 2, 2, 3, float); 23 | 24 | MULTI_GPU_NON_SQUARE_TEST(18, 4, 64, 16, 2, 1, 4, float); 25 | MULTI_GPU_NON_SQUARE_TEST(14, 4, 64, 16, 1, 2, 3, float); 26 | MULTI_GPU_NON_SQUARE_TEST(18, 4, 64, 16, 2, 2, 3, float); -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Abhinav Jangda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/kernel_db/kernel_db.inline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template 4 | void KernelDatabase::loadKernels(SubClassKernel* kernels, uint32_t numKernels) { 5 | //Load kernels into compiledKernels map 6 | for (uint i = 0; i < numKernels; i++) { 7 | SubClassKernel& info = kernels[i]; 8 | DbKey key {info.getMaxFactor(), info.getOpX(), info.getOpF(), info.getBatchType()}; 9 | auto iter = compiledKernels.find(key); 10 | if (iter == compiledKernels.end()) { 11 | compiledKernels.emplace(std::make_pair(key, std::vector())); 12 | } 13 | compiledKernels.at(key).push_back(&info); 14 | } 15 | 16 | if (false && Logger(LogLevel::Debug).valid()) { 17 | //Print loaded kernels 18 | uint numKernelsLoaded = 0; 19 | Logger(LogLevel::Debug) << "Loading compiled kernels" << std::endl; 20 | for (auto iter : compiledKernels) { 21 | for (auto kernel : iter.second) { 22 | Logger(LogLevel::Debug) << kernel->str() << std::endl; 23 | } 24 | numKernelsLoaded += iter.second.size(); 25 | } 26 | Logger(LogLevel::Debug) << "Number of kernels loaded: " << numKernelsLoaded << std::endl; 27 | } 28 | } -------------------------------------------------------------------------------- /tests/src/multi-cuda-no-fusion-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define MULTI_GPU_NO_FUSION_TEST(M, Facs, FacSize, GM, GK, KronBatch, Type) \ 5 | TEST(MultiGpuNoFusion, Type##_##M##x##Facs##x##FacSize##x##GM##x##GK##x##KronBatch##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | int devices = 0;\ 16 | CUDACHECK(cudaGetDeviceCount(&devices));\ 17 | if (devices < GM * GK) {EXPECT_TRUE(true); return;}\ 18 | bool b = run(FastKronMMType::MKM, M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1,1,1,1, 1.0f, 0.0f, 0, 0, false, GM, GK, GM*GK, KronBatch, true, false, true, fastKronBackend_CUDA, false, false);\ 19 | EXPECT_TRUE(b);\ 20 | } 21 | 22 | MULTI_GPU_NO_FUSION_TEST(20, 4, 64, 2, 1, 4, float); 23 | MULTI_GPU_NO_FUSION_TEST(18, 4, 64, 1, 2, 3, float); 24 | MULTI_GPU_NO_FUSION_TEST(18, 4, 64, 2, 2, 2, float); 25 | 26 | MULTI_GPU_NO_FUSION_TEST(12, 3, 128, 1, 4, 2, float); 27 | MULTI_GPU_NO_FUSION_TEST(8, 4, 128, 2, 4, 3, float); -------------------------------------------------------------------------------- /tests/x86/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # X86 tests 2 | set(X86_TEST_DEFINE -D TEST_BACKEND_X86) 3 | 4 | add_custom_target(gen-x86-kernels 5 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -backend x86 -archs sisd avx avx512 -same-factors 2 128,128 -same-factors 2 64,64 -same-factors 3 32,32 -same-factors 5 16,16 -same-factors 7 8,8 -same-factors 10 4,4 -same-factors 20 2,2 -opX N T -opF N T -types float double -match-configs-file ${SRC}/kernels/best-kernels/x86-avx-kernels) 6 | 7 | add_executable(x86-cpu-NN ${TEST_SRC}/general-tests-NN.cpp) 8 | target_include_directories(x86-cpu-NN PRIVATE ${TESTS_INCLUDES}) 9 | target_link_libraries(x86-cpu-NN ${TESTS_LIBS}) 10 | target_compile_definitions(x86-cpu-NN PRIVATE ${X86_TEST_DEFINE}) 11 | 12 | add_executable(x86-cpu-TT ${TEST_SRC}/general-tests-TT.cpp) 13 | 
target_include_directories(x86-cpu-TT PRIVATE ${TESTS_INCLUDES}) 14 | target_link_libraries(x86-cpu-TT ${TESTS_LIBS}) 15 | target_compile_definitions(x86-cpu-TT PRIVATE ${X86_TEST_DEFINE}) 16 | 17 | add_executable(x86-cpu-all ${TEST_SRC}/general-tests-all.cpp) 18 | target_include_directories(x86-cpu-all PRIVATE ${TESTS_INCLUDES}) 19 | target_link_libraries(x86-cpu-all ${TESTS_LIBS}) 20 | target_compile_definitions(x86-cpu-all PRIVATE ${X86_TEST_DEFINE}) -------------------------------------------------------------------------------- /example/cuda-example.cu: -------------------------------------------------------------------------------- 1 | //cuda-example.cu 2 | #include 3 | #include 4 | 5 | int main() { 6 | //Define Problem Sizes 7 | uint32_t N = 5; 8 | uint32_t M = 1024; 9 | uint32_t Ps[5] = {8,8,8,8,8}, Qs[5] = {8,8,8,8,8}; 10 | 11 | //Allocate inputs and output 12 | float* x, *fs[N], *z; 13 | cudaMalloc(&x, M * (int)powf(Ps[0], N) * sizeof(float)); 14 | for (int i = 0; i < N; i++) cudaMalloc(&fs[i], Ps[0]*Qs[0] * sizeof(float)); 15 | cudaMalloc(&z, M * (int)powf(Qs[0], N) * sizeof(float)); 16 | 17 | //Initialize FastKron with CUDA 18 | fastKronHandle handle; 19 | fastKronInit(&handle, fastKronBackend_CUDA); 20 | 21 | //Initialize FastKron's CUDA with stream 22 | cudaStream_t stream; 23 | cudaStreamCreate(&stream); 24 | fastKronInitCUDA(handle, (void*)&stream); 25 | 26 | //Get Temporary size and allocate temporary 27 | size_t tempSize, resultSize; 28 | gekmmSizes(handle, M, N, Ps, Qs, &resultSize, &tempSize); 29 | 30 | float* temp; 31 | cudaMalloc(&temp, tempSize * sizeof(float)); 32 | 33 | //Do KronMatmul using the tuned kernel 34 | 35 | sgemkm(handle, fastKronBackend_CUDA, M, N, Ps, Qs, 36 | x, fastKronOp_N, fs, fastKronOp_N, z, 1, 0, nullptr, 37 | temp, nullptr); 38 | 39 | //Destroy FastKron 40 | fastKronDestroy(handle); 41 | } -------------------------------------------------------------------------------- /src/handle/distrib_handle.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "handle/handle.h" 5 | 6 | fastKronError distributedKronMatmul(FastKronHandle&, const uint, void*[], void* [], void* [], 7 | uint , uint , uint , uint [], uint [], void** , void** , 8 | void* ) { 9 | std::cout << "Not implemented" << std::endl; 10 | assert(false); 11 | return fastKronSuccess; 12 | } 13 | 14 | fastKronError FastKronHandle::allocDistributedX(void* [], void* , uint , uint ) { 15 | std::cout << "Not implemented" << std::endl; 16 | return fastKronSuccess; 17 | } 18 | 19 | fastKronError FastKronHandle::gatherDistributedY(void* [], void* , uint , uint , uint , uint [], uint []) { 20 | //TODO: Make FastKronError type 21 | std::cout << "Not implemented" << std::endl; 22 | return fastKronSuccess; 23 | } 24 | 25 | fastKronError FastKronHandle::distributedsgekmm(const uint NumKronMats, float* x[], float* kronMats[], float* result[], 26 | uint M, uint N, uint K, uint KronMatCols[], uint KronMatRows[], float** temp1, float** temp2, 27 | void* streams) { 28 | return distributedKronMatmul(*this, NumKronMats, (void**)x, (void**)kronMats, (void**)result, M, N, K, 29 | KronMatCols, KronMatRows, (void**)temp1, (void**)temp2, streams); 30 | } 31 | -------------------------------------------------------------------------------- /src/handle/handle.inline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | inline bool FastKronHandle::hasBackend(fastKronBackend backend) { 4 | 
return (backends & backend); 5 | } 6 | 7 | inline KernelDatabase* FastKronHandle::getKernelDb(fastKronBackend backend) { 8 | switch (backend) { 9 | case fastKronBackend_X86: 10 | #ifdef ENABLE_X86 11 | return &x86Kernels; 12 | #endif 13 | case fastKronBackend_CUDA: 14 | #ifdef ENABLE_CUDA 15 | return &cudaKernels; 16 | #endif 17 | case fastKronBackend_HIP: 18 | #ifdef ENABLE_HIP 19 | return &hipKernels; 20 | #endif 21 | default: 22 | return nullptr; 23 | } 24 | } 25 | 26 | inline std::vector FastKronHandle::getAllKernelDbs() { 27 | std::vector out; 28 | if (hasBackend(fastKronBackend_X86)) { 29 | out.push_back(getKernelDb(fastKronBackend_X86)); 30 | } else if (hasBackend(fastKronBackend_CUDA)) { 31 | out.push_back(getKernelDb(fastKronBackend_CUDA)); 32 | } else if (hasBackend(fastKronBackend_HIP)) { 33 | out.push_back(getKernelDb(fastKronBackend_HIP)); 34 | } 35 | return out; 36 | } 37 | 38 | inline void FastKronHandle::setOptions(uint32_t options) {this->options = options;} 39 | 40 | inline bool FastKronHandle::canTune() { 41 | return (options & fastKronOptionsTune) == 42 | fastKronOptionsTune; 43 | } 44 | 45 | inline bool FastKronHandle::getUseFusion() { 46 | return (options & fastKronOptionsUseFusion) == 47 | fastKronOptionsUseFusion; 48 | } -------------------------------------------------------------------------------- /src/config.h: -------------------------------------------------------------------------------- 1 | #if defined(__NVCC__) || defined(__CUDACC__) || defined(__HIPCC__) 2 | #if !defined(__forceinline__) 3 | #define __forceinline__ inline 4 | #endif 5 | 6 | #define CUDA_HOST __host__ 7 | #define CUDA_DEVICE __device__ __forceinline__ 8 | #define CUDA_DEVICE_HOST CUDA_HOST CUDA_DEVICE 9 | #else 10 | #define CUDA_HOST 11 | #define CUDA_DEVICE inline 12 | #define CUDA_DEVICE_HOST CUDA_HOST CUDA_DEVICE 13 | #endif 14 | 15 | #define PRAGMA(X) _Pragma(#X) 16 | 17 | #if defined(__clang__) 18 | #define CXX_PRAGMA_PUSH_OPTIONS _Pragma("") 19 | #define CXX_PRAGMA_O3 20 | #define CXX_PRAGMA_ARCH_SISD _Pragma("clang attribute push (__attribute__((target(\"arch=x86-64-v2\"))), apply_to=function)") 21 | #define CXX_PRAGMA_ARCH_AVX _Pragma("clang attribute push (__attribute__((target(\"arch=x86-64-v3\"))), apply_to=function)") 22 | #define CXX_PRAGMA_ARCH_AVX512 _Pragma("clang attribute push (__attribute__((target(\"arch=x86-64-v4\"))), apply_to=function)") 23 | #define CXX_PRAGMA_POP_OPTIONS _Pragma("clang attribute pop") 24 | 25 | #elif defined(__GNUC__) || defined(__GNUG__) 26 | #define CXX_PRAGMA_PUSH_OPTIONS _Pragma("GCC push_options") 27 | #define CXX_PRAGMA_O3 _Pragma("GCC optimization(\"O3\")") 28 | #define CXX_PRAGMA_ARCH_SISD _Pragma("GCC target(\"arch=x86-64-v2\")") 29 | #define CXX_PRAGMA_ARCH_AVX _Pragma("GCC target(\"arch=x86-64-v3\")") 30 | #define CXX_PRAGMA_ARCH_AVX512 _Pragma("GCC target(\"arch=x86-64-v4\")") 31 | #define CXX_PRAGMA_POP_OPTIONS _Pragma("GCC pop_options") 32 | #endif -------------------------------------------------------------------------------- /src/kernels/cpu/mma.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template 6 | static CUDA_DEVICE_HOST 7 | void loadYInterim(uint32_t tileP, const YElem& y, 8 | const FCache& Fch, YInterim& Ych, YRegisters& YReg) { 9 | if (tileP == 0) { 10 | YReg.zero(); 11 | } else { 12 | //TODO: For OpY=fastKronOp_T YReg.apply should have last loop in m 13 | YReg.apply([&](X86VecT& e, const uint32_t ym, const uint32_t yk, const uint32_t yq) { 14 | 
e.load(&Ych.at(y.m() + ym * YReg.mvec(), y.q() + yq, y.k()/Fch.p() + yk * YReg.kvec())); 15 | }); 16 | } 17 | } 18 | 19 | template 22 | static CUDA_DEVICE_HOST 23 | void mma(uint32_t /*tileP*/, const YElem& y, 24 | const XCache& Xch, const FCache& Fch, 25 | YInterim& /*Ych*/, YRegisters& YReg) { 26 | const fastKronOp Layout = YRegisters::layout(); 27 | 28 | for (uint32_t p = 0; p < Fch.p(); p++) { 29 | XRegisters XReg; 30 | FRegisters FReg; 31 | XReg.apply([&](X86VecT& e, const uint32_t em, const uint32_t ek, const uint32_t ep) { 32 | e.load(&Xch.at(y.m() + em*YReg.mvec(), y.k()/Fch.p() + ek*YReg.kvec(), p + ep)); 33 | }); 34 | 35 | FReg.apply([&](X86VecT& e, const uint32_t ep, const uint32_t eq) { 36 | e.broadcast(&Fch.at(p + ep, y.q() + eq)); 37 | }); 38 | 39 | YReg.apply([&](X86VecT& e, const uint32_t ym, const uint32_t yk, const uint32_t yq) { 40 | e.fmadd(XReg.at(ym, yk, 0), FReg.at(0, yq)); 41 | }); 42 | } 43 | } -------------------------------------------------------------------------------- /tests/benchmarks/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(AnyOption EXCLUDE_FROM_ALL) 2 | 3 | if (ENABLE_CUDA) 4 | add_executable(benchmark_cuda benchmark.cpp) 5 | set_source_files_properties(benchmark.cpp PROPERTIES LANGUAGE CUDA) 6 | add_dependencies(benchmark_cuda anyoption) 7 | target_compile_definitions(benchmark_cuda PUBLIC TEST_BACKEND_CUDA) 8 | target_include_directories(benchmark_cuda PRIVATE AnyOption ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CMAKE_CURRENT_SOURCE_DIR}/../ ${SRC}/) 9 | if (ENABLE_MULTI_GPU) 10 | target_link_libraries(benchmark_cuda PRIVATE FastKron anyoption nccl) 11 | else() 12 | target_link_libraries(benchmark_cuda PRIVATE FastKron anyoption) 13 | endif() 14 | if(OpenMP_CXX_FOUND) 15 | target_link_libraries(benchmark_cuda PRIVATE OpenMP::OpenMP_CXX) 16 | endif() 17 | endif() 18 | 19 | if (ENABLE_X86) 20 | add_executable(benchmark_x86 benchmark.cpp) 21 | add_dependencies(benchmark_x86 anyoption) 22 | target_compile_definitions(benchmark_x86 PUBLIC TEST_BACKEND_X86) 23 | target_include_directories(benchmark_x86 PRIVATE AnyOption ${CMAKE_CURRENT_SOURCE_DIR}/../ ${SRC}/) 24 | target_link_libraries(benchmark_x86 PRIVATE FastKron anyoption) 25 | if(OpenMP_CXX_FOUND) 26 | target_link_libraries(benchmark_x86 PRIVATE OpenMP::OpenMP_CXX) 27 | endif() 28 | endif() 29 | 30 | if (ENABLE_HIP) 31 | add_executable(benchmark_hip benchmark.cpp) 32 | set_source_files_properties(benchmark.cpp PROPERTIES LANGUAGE HIP) 33 | add_dependencies(benchmark_hip anyoption) 34 | target_compile_definitions(benchmark_hip PUBLIC TEST_BACKEND_HIP) 35 | target_include_directories(benchmark_hip PRIVATE AnyOption ${HIP_INCLUDE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/../ ${SRC}/) 36 | target_link_libraries(benchmark_hip PRIVATE FastKron anyoption) 37 | if(OpenMP_CXX_FOUND) 38 | target_link_libraries(benchmark_hip PRIVATE OpenMP::OpenMP_CXX) 39 | endif() 40 | endif() -------------------------------------------------------------------------------- /tests/src/old/fusion-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define FUSION_TEST(M, Facs, FacSize, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,Fusion), Type##_##M##x##Facs##x##FacSize##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = 
KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1, 0, false, 1, 1, 1, 1, true, true, true, getTestBackend(), false);\ 16 | EXPECT_TRUE(b);\ 17 | } 18 | 19 | //FacSize 2 20 | // FUSION_TEST(11, 7, 2, float); 21 | // FUSION_TEST(11, 8, 2, float); 22 | FUSION_TEST(11, 10, 2, float); 23 | FUSION_TEST(11, 15, 2, float); 24 | FUSION_TEST(11, 20, 2, float); 25 | 26 | //FacSize 4 27 | // FUSION_TEST(11, 4, 4, float); 28 | // FUSION_TEST(11, 6, 4, float); 29 | FUSION_TEST(11, 8, 4, float); 30 | FUSION_TEST(11, 9, 4, float); 31 | FUSION_TEST(11, 10, 4, float); 32 | 33 | //FacSize 8 34 | // FUSION_TEST(11, 4, 8, float); 35 | // FUSION_TEST(11, 5, 8, float); 36 | FUSION_TEST(11, 6, 8, float); 37 | FUSION_TEST(11, 7, 8, float); 38 | FUSION_TEST(11, 8, 8, float); 39 | 40 | //FacSize 16 41 | FUSION_TEST(11, 2, 16, float); 42 | FUSION_TEST(11, 3, 16, float); 43 | FUSION_TEST(11, 4, 16, float); 44 | // FUSION_TEST(11, 5, 16, float); 45 | // FUSION_TEST(11, 6, 16, float); 46 | 47 | //FacSize 32 48 | FUSION_TEST(11, 2, 32, float); 49 | FUSION_TEST(11, 3, 32, float); 50 | // FUSION_TEST(11, 4, 32, float); 51 | // FUSION_TEST(11, 5, 32, float); 52 | 53 | FUSION_TEST(12, 3, 32, float); 54 | 55 | //FacSize 64 56 | FUSION_TEST(11, 2, 64, float); 57 | // FUSION_TEST(11, 3, 64, float); 58 | // FUSION_TEST(11, 4, 64, float); 59 | 60 | //FacSize 128 61 | FUSION_TEST(11, 2, 128, float); 62 | // FUSION_TEST(11, 3, 128, float); 63 | 64 | // //FacSize 256 65 | // SINGLE_GPU_FUSION_TEST(11, 2, 128, float); 66 | // SINGLE_GPU_FUSION_TEST(11, 3, 128, float); -------------------------------------------------------------------------------- /tests/src/old/no-fusion-tests.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define NO_FUSION_TEST(M, Facs, FacSize, Type) \ 5 | TEST(EXPAND(TEST_BACKEND,NoFusion), Type##_##M##x##Facs##x##FacSize##_) { \ 6 | uint KP_MAT_N[Facs];\ 7 | uint KP_MAT_K[Facs];\ 8 | uint N = 1;\ 9 | uint K = 1;\ 10 | for (uint i = 0; i < (uint)Facs; i++) {\ 11 | N *= FacSize;\ 12 | K *= FacSize;\ 13 | KP_MAT_K[i] = KP_MAT_N[i] = FacSize;\ 14 | }\ 15 | bool b = run(M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_N, fastKronOp_N, 1, 0, false, 1, 1, 1, 1, true, false, true, getTestBackend( ), false);\ 16 | EXPECT_TRUE(b);\ 17 | } 18 | 19 | //FacSize 2 20 | // NO_FUSION_TEST(1, 7, 2, float); 21 | // NO_FUSION_TEST(1, 8, 2, float); 22 | NO_FUSION_TEST(11, 10, 2, float); 23 | NO_FUSION_TEST(11, 15, 2, float); 24 | NO_FUSION_TEST(11, 20, 2, float); 25 | 26 | // //FacSize 4 27 | // NO_FUSION_TEST(1, 4, 4, float); 28 | // NO_FUSION_TEST(1, 6, 4, float); 29 | NO_FUSION_TEST(11, 8, 4, float); 30 | NO_FUSION_TEST(11, 9, 4, float); 31 | NO_FUSION_TEST(11, 10, 4, float); 32 | 33 | // //FacSize 8 34 | // NO_FUSION_TEST(11, 4, 8, float); 35 | // NO_FUSION_TEST(11, 5, 8, float); 36 | NO_FUSION_TEST(11, 6, 8, float); 37 | NO_FUSION_TEST(11, 7, 8, float); 38 | NO_FUSION_TEST(11, 8, 8, float); 39 | 40 | // //FacSize 16 41 | // NO_FUSION_TEST(11, 2, 16, float); 42 | // NO_FUSION_TEST(11, 3, 16, float); 43 | NO_FUSION_TEST(11, 4, 16, float); 44 | NO_FUSION_TEST(11, 5, 16, float); 45 | // NO_FUSION_TEST(11, 6, 16, float); 46 | 47 | // //FacSize 32 48 | // NO_FUSION_TEST(11, 2, 32, float); 49 | NO_FUSION_TEST(11, 3, 32, float); 50 | NO_FUSION_TEST(11, 4, 32, float); 51 | // NO_FUSION_TEST(11, 5, 32, float); 52 | 53 | // // NO_FUSION_TEST(12, 3, 32, float); 54 | 55 | // 
//FacSize 64 56 | NO_FUSION_TEST(11, 2, 64, float); 57 | NO_FUSION_TEST(11, 3, 64, float); 58 | // NO_FUSION_TEST(11, 4, 64, float); 59 | 60 | // // //FacSize 128 61 | NO_FUSION_TEST(11, 2, 128, float); 62 | NO_FUSION_TEST(11, 3, 128, float); 63 | 64 | // // //FacSize 256 65 | // // NO_FUSION_TEST(1, 2, 128, float); 66 | // // NO_FUSION_TEST(1, 3, 128, float); -------------------------------------------------------------------------------- /tests/benchmarks/torch_kron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def initmat(m, n): 4 | return torch.randint(0, 10, (m,n)) #randn(m,n) 5 | 6 | def baseline(input, kronmats): 7 | outputKron = kronmats[0] 8 | for m in kronmats[1:]: 9 | outputKron = torch.kron(outputKron, m) 10 | return torch.matmul(input, outputKron) 11 | 12 | def matmulkron(input, kronmats): 13 | output = input 14 | shape = input.shape 15 | 16 | for i,k in enumerate(reversed(kronmats)): 17 | newinput = output.reshape(shape[0] * (shape[1]//k.shape[0]), k.shape[0]) 18 | output = torch.matmul(newinput, k) 19 | output = output.view(shape[0], (shape[1]//k.shape[0]), k.shape[0]) 20 | output = output.transpose(1,2) 21 | # if i == 1: 22 | #print(output.shape) 23 | return output.reshape(shape) 24 | 25 | def contraction(input, kronmats): 26 | output = input 27 | shape = input.shape 28 | n = len(kronmats) 29 | output = output.reshape([shape[0],] + [kronmats[0].shape[0] for i in range(len(kronmats))]) 30 | 31 | for i,k in enumerate(reversed(kronmats)): 32 | print(i, output.shape) 33 | output = torch.tensordot(output, k, dims=([n],[0])) 34 | for j in range(n, 1,-1): 35 | output = output.transpose(j, j-1) 36 | return output.reshape(shape) 37 | 38 | def fackler2019toms(input, kronmats): 39 | #https://dl-acm-org.silk.library.umass.edu/doi/pdf/10.1145/3291041 40 | output = input 41 | shape = input.shape 42 | n = len(kronmats) 43 | output = output.reshape([kronmats[0].shape[0] for i in range(len(kronmats))] + [shape[0],]) 44 | kronrows = kronmats[0].shape[0] 45 | for i,k in enumerate(reversed(kronmats)): 46 | print(i, output.shape) 47 | output = output.mT @ k.T 48 | output = output.reshape([kronmats[0].shape[0] for j in range(len(kronmats) - i - 1)] + [shape[0],] + [kronmats[0].shape[0] for j in range(i+1)]) 49 | 50 | return output.reshape(shape) 51 | 52 | #[5,5], [4,4], [3,3] 53 | #[N, 3, 4, 5] x [5,5] = [N,3,4,5] 54 | #[N,3,5,4] x [4,4] = [N,3,5,4] 55 | #[N,4,5,3]x[3,3] = [N,4,5,3] 56 | 57 | if __name__ == "__main__": 58 | npoints = 3 59 | twoPower = 4 60 | input = initmat(npoints, twoPower**npoints) 61 | kronmats = [] 62 | for s in range(npoints): 63 | kronmats += [initmat(twoPower,twoPower)] 64 | # print(kronmats[0]) 65 | # print(kronmats[1]) 66 | b = baseline(input, kronmats) 67 | 68 | o = matmulkron(input, kronmats) 69 | # o = fackler2019toms(input, kronmats) 70 | # o = contraction(input, kronmats) 71 | 72 | print ((b == o)) 73 | print ((b == o).all()) -------------------------------------------------------------------------------- /src/kernels/cuda/mma.cuh: -------------------------------------------------------------------------------- 1 | template 2 | CUDA_DEVICE 3 | void slicedMMA(XReg& Xr, FReg& Fr, YReg& Yr) { 4 | //Matrix Multiply Accumulate 5 | #pragma unroll 6 | for (uint j = 0; j < Yr.q(); j++) 7 | #pragma unroll 8 | for (uint m = 0; m < Yr.m(); m++) 9 | #pragma unroll 10 | for (uint i = 0; i < Yr.k(); i++) 11 | #pragma unroll 12 | for (uint p = 0; p < Xr.p(); p++) { 13 | Yr.add(m, i, j, Xr.at(m, i, p) * Fr.at(p, j)); 
14 | } 15 | } 16 | 17 | template 19 | CUDA_DEVICE 20 | void mainMMA(uint32_t m, XShared& Xsh, FShared& Fsh, YReg& Yr, XReg& Xr, FReg& Fr, const YElem& yElem) { 21 | //Load shared memory Xsh to registers Xr 22 | if (Xsh.layout() == fastKronOp_N) { 23 | #pragma unroll 24 | for (uint rm = 0; rm < Yr.m(); rm++) { 25 | // if (rm < m) { 26 | #pragma unroll 27 | for (uint rk = 0; rk < Xr.k(); rk++) { 28 | uint shXk = yElem.k() + rk; 29 | uint shift = (yElem.k() / Yr.k()); 30 | 31 | #pragma unroll 32 | for (uint p = 0; p < Xr.p(); p++) { 33 | //TODO: bring shift calculation in Xsh.at 34 | //TODO: use the actual type not float 35 | auto temp = Xsh.at(yElem.m() + rm, shXk * Xr.p() + (p + shift)%Xr.p()); 36 | Xr.set(rm, rk, p, temp); 37 | // } 38 | }}} 39 | } else { 40 | #pragma unroll 41 | for (uint rk = 0; rk < Xr.k(); rk++) { 42 | uint shXk = yElem.k() + rk; 43 | uint shift = 0;//(yElem.k() / Yr.k()); 44 | 45 | #pragma unroll 46 | for (uint p = 0; p < Xr.p(); p++) { 47 | #pragma unroll 48 | for (uint rm = 0; rm < Yr.m(); rm++) { 49 | //TODO: bring shift calculation in Xsh.at 50 | auto temp = Xsh.at((yElem.m() + rm + shift)/*%Xsh.m()*/, shXk * Xr.p() + p); 51 | Xr.set(rm, rk, p, temp); 52 | // } 53 | }}} 54 | } 55 | 56 | if (Fsh.layout() == fastKronOp_N) { 57 | #pragma unroll 58 | for (uint rq = 0; rq < Yr.q(); rq++) { 59 | uint shFcol = yElem.q() + rq; 60 | #pragma unroll 61 | for (uint p = 0; p < Xr.p(); p++) { 62 | Fr.set(p, rq, Fsh.at(p, shFcol)); 63 | }} 64 | } else if (Fsh.layout() == fastKronOp_T) { 65 | uint32_t qe = yElem.q(); 66 | #pragma unroll 67 | for (uint rq = 0; rq < Yr.q(); rq++) { 68 | uint32_t shFcol = qe + rq; 69 | #pragma unroll 70 | for (uint p = 0; p < Xr.p(); p++) { 71 | if (true) {//Padding 72 | Fr.set(p, rq, (&Fsh.at(0,0))[shFcol + p*(Fsh.q() + 1)]); 73 | } 74 | }} 75 | } 76 | 77 | slicedMMA(Xr, Fr, Yr); 78 | } -------------------------------------------------------------------------------- /src/kernel_db/hip_kernel_db.h: -------------------------------------------------------------------------------- 1 | #include "kernel_db.h" 2 | 3 | #include "kmm/kmmalgo.h" 4 | #include "kernels/kmmkernel.h" 5 | #include "kernels/params.h" 6 | #include "kernel_db/kernel_db.h" 7 | #include "env/env.h" 8 | #include "utils/thread_pool.h" 9 | 10 | class HIPKernelDatabase : public KernelDatabase { 11 | public: 12 | std::vector streams; 13 | // uint numGPUs_; 14 | // uint gpusInM_; 15 | // uint gpusInK_; 16 | // uint perGPUKronBatch_; 17 | // bool isDistributed_; 18 | // DistComm distComm_; 19 | // std::vector ncclComms; 20 | // pthread_barrier_t* barriers_; 21 | // thread_pool* threads_; 22 | 23 | public: 24 | HIPKernelDatabase(); 25 | ~HIPKernelDatabase() {} 26 | 27 | fastKronError init(void* ptrToStream){ 28 | streams.clear(); 29 | streams.push_back(ptrToStream); 30 | return fastKronSuccess; 31 | } 32 | 33 | void free() { 34 | streams.clear(); 35 | // if (isDistributed_) { 36 | // for (uint g = 0; g < gpusInM_; g++) { 37 | // int s = pthread_barrier_destroy(&barriers_[g]); 38 | // PTHREAD_BARRIER_CHECK(s); 39 | // } 40 | 41 | // delete threads_; 42 | // delete barriers_; 43 | 44 | // if (distComm_ == DistComm::NCCL) { 45 | // for (int i=0; i 2 | 3 | #include 4 | 5 | #include "env/env.h" 6 | #include "utils/logger.h" 7 | 8 | namespace env { 9 | #define ENV_FASTKRON(x) "FASTKRON_" x; 10 | 11 | static char COMM[] = ENV_FASTKRON("COMM"); 12 | static char LOGLEVEL[] = ENV_FASTKRON("LOG"); 13 | static char USETUNE[] = ENV_FASTKRON("TUNE"); 14 | 15 | char* strupr(char* str) { 16 | char *s = str; 
17 | while (*s) { 18 | *s = toupper((unsigned char) *s); 19 | s++; 20 | } 21 | return s; 22 | } 23 | 24 | /** 25 | * intEnvToBool() - Convert an integer environment value to boolean. 26 | * @env: The environment variable. 27 | * @defaultBool: Default boolean value to return when env var is not defined. 28 | * 29 | * Return - True if value of env is 1, False if value of env is 0, and default 30 | * if env is not defined 31 | */ 32 | bool intEnvToBool(char* env, bool defaultBool) { 33 | char* val = getenv(env); 34 | if (val == nullptr) return defaultBool; 35 | if (strcmp(val, "0") == 0) return false; 36 | if (strcmp(val, "1") == 0) return true; 37 | Logger(LogLevel::Info) << "Invalid " << env << "=" << val << std::endl; 38 | return defaultBool; 39 | } 40 | 41 | /** 42 | * getDistComm() - Get DistComm value from environment value of COMM 43 | */ 44 | DistComm getDistComm() { 45 | char* val = getenv(COMM); 46 | if (val == nullptr) return DistComm::DistCommNone; 47 | strupr(val); 48 | if (strcmp(val, "P2P") == 0) return DistComm::P2P; 49 | if (strcmp(val, "NCCL") == 0) return DistComm::NCCL; 50 | Logger(LogLevel::Info) << "Invalid " << COMM << "=" << val << std::endl; 51 | return DistComm::DistCommNone; 52 | } 53 | 54 | /** 55 | * getLogLevel() - Get LogLevel value from environment value of LOGLEVEL 56 | */ 57 | LogLevel getLogLevel() { 58 | char *val = getenv(LOGLEVEL); 59 | if (val == nullptr) return LogLevel::Nothing; 60 | strupr(val); 61 | if (strcmp(val, "INFO") == 0) return LogLevel::Info; 62 | if (strcmp(val, "DEBUG") == 0) return LogLevel::Debug; 63 | Logger(LogLevel::Info) << "Invalid " << LOGLEVEL << "=" << val << std::endl; 64 | return LogLevel::Nothing; 65 | } 66 | 67 | /** 68 | * getUseTune() - Get UseTune value from environment value of USETUNE 69 | */ 70 | bool getUseTune() { 71 | char *val = getenv(USETUNE); 72 | if (val == nullptr) return false; 73 | if (strcmp(val, "0") == 0) return false; 74 | if (strcmp(val, "1") == 0) return true; 75 | Logger(LogLevel::Info) << "Invalid " << USETUNE << " = " << val << std::endl; 76 | return false; 77 | } 78 | } 79 | 80 | std::ostream& operator<<(std::ostream &out, DistComm comm) { 81 | switch (comm) { 82 | case DistComm::DistCommNone: 83 | out << "CommNone"; 84 | break; 85 | case DistComm::P2P: 86 | out << "P2P"; 87 | break; 88 | case DistComm::NCCL: 89 | out << "NCCL"; 90 | break; 91 | } 92 | return out; 93 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io, os, re 2 | import pathlib 3 | import sys 4 | 5 | from setuptools import setup, Extension 6 | from setuptools.command.build_ext import build_ext as build_ext_orig 7 | 8 | 9 | class CMakeExtension(Extension): 10 | 11 | def __init__(self, name): 12 | # don't invoke the original build_ext for this special extension 13 | super().__init__(name, sources=[]) 14 | 15 | 16 | class build_ext(build_ext_orig): 17 | 18 | def run(self): 19 | for ext in self.extensions: 20 | self.build_cmake(ext) 21 | super().run() 22 | 23 | def build_cmake(self, ext): 24 | cwd = pathlib.Path().absolute() 25 | 26 | # these dirs will be created in build_py, so if you don't have 27 | # any python sources to bundle, the dirs will be missing 28 | build_temp = pathlib.Path(self.build_temp) 29 | build_temp.mkdir(parents=True, exist_ok=True) 30 | extdir = pathlib.Path(self.get_ext_fullpath(ext.name)) 31 | extdir.parent.mkdir(parents=True, exist_ok=True) 32 | # example of cmake args 33 | config = 
'Debug' if self.debug else 'Release' 34 | cmake_args = [ 35 | '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()), 36 | '-DCMAKE_BUILD_TYPE=' + config, 37 | '-DPYMODULE=ON', 38 | ] 39 | if 'X86' in ext.name: 40 | cmake_args += ['-DENABLE_CUDA=OFF', '-DENABLE_X86=ON'] 41 | elif 'CUDA' in ext.name: 42 | cmake_args += ['-DENABLE_CUDA=ON', '-DENABLE_X86=OFF'] 43 | 44 | cmake_args += [f'-DPYTHON_EXECUTABLE={sys.executable}'] 45 | 46 | # example of build args 47 | build_args = [ 48 | '--config ' + config, 49 | '-j' 50 | ] 51 | os.chdir(str(build_temp)) 52 | self.spawn(['cmake', str(cwd)] + cmake_args) 53 | if not self.dry_run: 54 | self.spawn(['cmake', '--build', '.'] + build_args) 55 | # Troubleshooting: if fail on line above then delete all possible 56 | # temporary CMake files including "CMakeCache.txt" in top level dir. 57 | os.chdir(str(cwd)) 58 | 59 | def find_version(*file_paths): 60 | try: 61 | with io.open(os.path.join(os.path.dirname(__file__), *file_paths), encoding="utf8") as fp: 62 | version_file = fp.read() 63 | version_match = re.search(r"^__version__ = version = ['\"]([^'\"]*)['\"]", version_file, re.M) 64 | return version_match.group(1) 65 | except Exception: 66 | return None 67 | 68 | 69 | if "BUILD_ANY_WHEEL" in os.environ: 70 | setup( 71 | packages=['pyfastkron'], 72 | version=find_version("pyfastkron", "version.py"), 73 | ) 74 | else: 75 | setup( 76 | packages=['pyfastkron'], 77 | ext_modules=[CMakeExtension('pyfastkron.FastKronX86'), 78 | CMakeExtension('pyfastkron.FastKronCUDA')], 79 | version=find_version("pyfastkron", "version.py"), 80 | cmdclass={ 81 | 'build_ext': build_ext, 82 | } 83 | ) 84 | -------------------------------------------------------------------------------- /src/kernels/cpu_kmmkernel.h: -------------------------------------------------------------------------------- 1 | #include "kernels/kmmkernel.h" 2 | 3 | #pragma once 4 | 5 | /** 6 | * CPUKMMKernel - A subclass for KMMKernels running on CPU 7 | */ 8 | struct CPUKMMKernel : public KMMKernel { 9 | public: 10 | CPUKMMKernel() {} 11 | CPUKMMKernel(void* kernelInvoker, FastKronType elemType, 12 | Factor f, Factor tileF, Matrix tileX, uint fusedFacs, bool P2PStore, 13 | uint regM, uint regK, uint regQ, uint optLevel, 14 | fastKronOp opX, fastKronOp opF, FastKronMMType mmType, 15 | KernelBatchType::Ty kernelBatchType) : 16 | KMMKernel(kernelInvoker, elemType, f, tileF, tileX, 17 | fusedFacs, P2PStore, regM, regK, regQ, 18 | optLevel, opX, opF, mmType, kernelBatchType) {} 19 | }; 20 | 21 | /** 22 | * X86KMMKernel - A subclass for KMMKernels running on an X86 CPU 23 | * This class contains a member to determine the SIMD architecture of the kernel. 24 | */ 25 | struct X86KMMKernel : public CPUKMMKernel { 26 | /** 27 | * @simd: The SIMD architecture of the kernel either AVX256, AVX512, or SISD. 
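 *        (Note: a kernel built for a narrower SIMD width is assumed to run on any
 *        CPU that supports a wider width, e.g. an AVX256 kernel can run on an AVX512
 *        machine; canCompute() below encodes this by comparing SIMD values with <=.)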
28 | */ 29 | private: 30 | X86SIMD simd; 31 | 32 | public: 33 | X86KMMKernel() {} 34 | X86KMMKernel(X86SIMD simd, void* kernelInvoker, FastKronType elemType, 35 | Factor f, Factor tileF, Matrix tileX, uint fusedFacs, bool P2PStore, 36 | uint regM, uint regK, uint regQ, uint optLevel, 37 | fastKronOp opX, fastKronOp opF, FastKronMMType mmType, 38 | KernelBatchType::Ty kernelBatchType) : 39 | CPUKMMKernel(kernelInvoker, elemType, f, tileF, tileX, fusedFacs, 40 | P2PStore, regM, regK, regQ, optLevel, opX, opF, 41 | mmType, kernelBatchType), 42 | simd(simd) {} 43 | 44 | X86SIMD getSIMD() {return simd;} 45 | 46 | /** 47 | * canCompute - Overrides the method of KMMKernel and checks if simd of this kernel 48 | * can run on the given hardware. 49 | */ 50 | virtual bool canCompute(KMMProblem problem, const HardwareDetails* hw, 51 | bool p2p, KernelBatchType::Ty probBatchType, bool exactFuse = true) { 52 | if (CPUKMMKernel::canCompute(problem, hw, p2p, probBatchType, exactFuse)) { 53 | //A CPU with higher SIMD width (say AVX512) always support a lower 54 | //SIMD width (say AVX256) 55 | return getSIMD() <= ((X86ArchDetails*)hw)->simd; 56 | } 57 | return false; 58 | } 59 | 60 | /** 61 | * backend - Overrides the method of KMMKernel and always return X86. 62 | */ 63 | virtual std::string backend() const {return "X86";} 64 | 65 | /** 66 | * arch - Overrides the method of KMMKernel and return SIMD architecture. 67 | */ 68 | virtual std::string arch() const {return x86simdToStr(simd);} 69 | 70 | /** 71 | * str - Overrides the method of KMMKernel. 72 | */ 73 | virtual std::string str() const { 74 | std::stringstream info; 75 | info << backend() << "_" << arch() << "_" << KMMKernel::str(); 76 | return info.str(); 77 | } 78 | }; -------------------------------------------------------------------------------- /src/kernels/cuda_kmmkernel.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "kernels/gpu_kmmkernel.h" 5 | 6 | #pragma once 7 | 8 | struct CUDAKMMKernel : public GPUKMMKernel { 9 | SMArch sm; 10 | CUDAKMMKernel() {} 11 | CUDAKMMKernel(SMArch sm, void* kernelInvoker, FastKronType elemType, 12 | Factor f, Factor tileF, Matrix tileX, uint fusedFacs, bool P2PStore, 13 | uint regM, uint regK, uint regQ, uint optLevel, 14 | fastKronOp opX, fastKronOp opF, FastKronMMType mmType, KernelBatchType::Ty kernelBatchType, 15 | void*(*getKernel)(), uint NumThreads, 16 | uint alignX, uint alignF) : 17 | GPUKMMKernel(kernelInvoker, elemType, f, tileF, tileX, 18 | fusedFacs, P2PStore, regM, regK, regQ, 19 | optLevel, opX, opF, mmType, kernelBatchType, getKernel, 20 | NumThreads, alignX, alignF), 21 | sm(sm) {} 22 | 23 | /*** Functions to get/set information for CUDA Kernel ***/ 24 | /** 25 | * getPTXVersion() - Return PTX Version of the kernel as XXYY. 26 | */ 27 | uint32_t getPTXVersion() const { 28 | cudaFuncAttributes attr; 29 | CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); 30 | return attr.ptxVersion; 31 | } 32 | 33 | /** 34 | * getLocalSize() - Return local memory size in bytes. 35 | */ 36 | uint32_t getLocalSize() const { 37 | cudaFuncAttributes attr; 38 | CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); 39 | return attr.localSizeBytes; 40 | } 41 | 42 | /** 43 | * getNumRegs() - Return number of registers per thread. 
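 *                 (getPTXVersion(), getLocalSize(), and getNumRegs() each query
 *                 cudaFuncGetAttributes() on the kernel pointer and do not cache the
 *                 result, so they are best kept out of per-launch hot paths.)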
44 | */
45 | uint32_t getNumRegs() const {
46 | cudaFuncAttributes attr;
47 | CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel));
48 | return attr.numRegs;
49 | }
50 |
51 | /**
52 | * setSharedMemAttr() - Set MaxDynamicSharedMemorySize attribute of the kernel
53 | * if shared memory is more than 48KB.
54 | */
55 | cudaError_t setSharedMemAttr() {
56 | cudaError_t err = cudaSuccess;
57 | if (getMaxSharedMemSize() >= (48 << 10)) {
58 | err = cudaFuncSetAttribute(kernel,
59 | cudaFuncAttributeMaxDynamicSharedMemorySize,
60 | getMaxSharedMemSize());
61 | }
62 |
63 | return err;
64 | }
65 | /*********************************************************/
66 |
67 | /**
68 | * canCompute() - Overrides the method of GPUKMMKernel and additionally checks that the kernel's SM architecture matches the hardware.
69 | */
70 | virtual bool canCompute(KMMProblem problem, const HardwareDetails* hw,
71 | bool p2p, KernelBatchType::Ty probBatchType, bool exactFuse = true) {
72 | if (GPUKMMKernel::canCompute(problem, hw, p2p, probBatchType, exactFuse)) {
73 | return ((CUDAArchDetails*)hw)->smArch == sm;
74 | }
75 | return false;
76 | }
77 |
78 | /**
79 | * backend() - Returns CUDA as backend.
80 | */
81 | virtual std::string backend() const {
82 | return "cuda";
83 | }
84 |
85 | /**
86 | * arch() - Returns SM string.
87 | */
88 | virtual std::string arch() const {
89 | return smArchToStr(sm);
90 | }
91 | };
--------------------------------------------------------------------------------
/documents/autotuning.md:
--------------------------------------------------------------------------------
1 | # Kernel Tuning for a GeKMM Problem
2 |
3 | The space of all valid x86/CUDA/HIP/ARM kernels for a given computation can contain thousands of kernels.
4 | Hence, it is not practical to build kernels for every possible problem size and ship them all as part of FastKron.
5 | Instead, FastKron contains a small set of efficient pre-selected kernels and, at runtime, selects the fastest series of kernels to compute the problem.
6 |
7 | FastKron provides three modes to select the best kernel series:
8 | * *Online Selection* uses an algorithm that selects the fastest series of kernels. The algorithm balances cache size (shared memory) against parallelism. However, it does not always find the fastest series.
9 | * *Fast Tuning* runs all pre-selected kernels for the problem and selects the fastest kernel series. Tuning is done only once per problem; the selected kernel series is reused for subsequent executions of the same problem size, amortizing the cost of tuning.
10 | * *Full Tuning* generates all valid kernels for a given problem size, builds the FastKron library, runs all kernels to find the fastest kernel series, and uses these kernels for subsequent executions of the same problem size. As with Fast Tuning, Full Tuning is done only once per problem size, thus amortizing the cost of tuning.
11 |
12 | Let's see how to use all three modes.
13 |
14 | #### Online Selection
15 |
16 | By default, FastKron uses the Online Selection algorithm to select the fastest kernel series.
17 |
18 | #### Fast Tuning
19 |
20 | Setting `fastKronOptionsTune` as an option using `fastKronSetOptions` enables Fast Tuning.
21 | This option must be set before calling any of the `*gekmm` functions.
22 |
23 | #### Full Tuning
24 |
25 | Suppose the GeKMM problem is:
26 |
27 | $Z = \alpha ~ op(X) \times \left (op(F^1) \otimes op(F^2) \otimes \dots \otimes op(F^N) \right) + \beta Y$
28 |
29 | where
30 | * $op$ is the no-transpose or transpose operation on a matrix.
31 | * each $op(F^i)$ is a row-major matrix of size $P^i \times Q^i$.
32 | * $F^i \otimes F^j$ is the Kronecker product of two matrices.
33 | * $op(X)$ is a row-major matrix of size $M \times \left(P^1 \cdot P^2 \cdot P^3 \dots P^N \right)$.
34 | * $Y$ and $Z$ are row-major matrices of size $M \times \left(Q^1 \cdot Q^2 \cdot Q^3 \dots Q^N \right)$.
35 | * $\alpha$ and $\beta$ are scalars.
36 |
37 | The first step is to generate all valid kernels for the problem sizes using `src/gen_tuner_kernels.py`.
38 |
39 | ```
40 | python ../src/gen_tuner_kernels.py -backend <backend> -archs <archs> -distinct-factors N P1,Q1 P2,Q2 P3,Q3 ... -types <types> -opX <N|T> -opF <N|T> -opt-levels 3
41 | ```
42 |
43 | For example, to generate CUDA kernels for the Ampere architecture (SM80+) with opX = N, opF = T, N = 3, and all Ps = Qs = 8:
44 |
45 | ```
46 | python ../src/gen_tuner_kernels.py -backend cuda -archs ampere -distinct-factors 3 8,8 8,8 8,8 -types float -opX N -opF T -opt-levels 3
47 | ```
48 |
49 | The next step is to run CMake with `FULL_TUNE=ON`, enable the required backend (switching off the others), and run make.
50 |
51 | ```
52 | mkdir build/
53 | cd build/
54 | cmake .. -DFULL_TUNE=ON -DENABLE_<BACKEND>=ON
55 | make -j
56 | ```
57 |
58 | The final step is to use `fastKronSetOptions()` to set `fastKronOptionsTune`.
--------------------------------------------------------------------------------
/packaging/wheels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 |
5 | #pyproject.toml
6 | #[global]
7 | #index-url = "https://download.pytorch.org/whl/cu121"
8 |
9 | docker_create_container = "docker run -d -v $(pwd):/fastkron --name fastkron_build -it sameli/manylinux_2_28_x86_64_cuda_12.3:latest"
10 | docker_kill_container = "docker kill fastkron_build"
11 | docker_rm_container = "docker rm fastkron_build"
12 | docker_exec = f"docker exec fastkron_build"
13 | docker_remove_gcc_12 = f"{docker_exec} yum remove gcc-toolset-12* -y"
14 | docker_install_gcc_11 = f"{docker_exec} yum install gcc-toolset-11* -y"
15 | docker_install_git = f"{docker_exec} yum install git -y"
16 | docker_git_add_safe_dir = f"{docker_exec} git config --global --add safe.directory /fastkron"
17 | gcc_11_path = "PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH"
18 |
19 | host_fk_dir = os.getcwd()
20 | bdist_dir = "dist"
21 | docker_fk_dir = "/fastkron/"
22 | docker_packaging = os.path.join(docker_fk_dir, "packaging")
23 | docker_bdist_dir = os.path.join(docker_fk_dir, bdist_dir)
24 |
25 | def run_command(command):
26 | print("Running ", command, " in directory ", os.getcwd())
27 | (s, o) = subprocess.getstatusoutput(command)
28 | if s != 0:
29 | print (f"Error running {command}\n", o)
30 | assert False
31 | return s, o
32 |
33 | def build_wheel(python_version):
34 | (s, o) = run_command(f"{docker_exec} sh {docker_packaging}/manylinux_docker_build.sh cp{python_version} {docker_fk_dir}")
35 |
36 | def build_any_wheel():
37 | (s, o) = run_command(f"{docker_exec} sh {docker_packaging}/any_build.sh {docker_fk_dir}")
38 |
39 | def test_wheel(python_version):
40 | python_dir = f"/opt/python/cp{python_version}-cp{python_version}/bin/"
41 | pip = os.path.join(python_dir, "pip")
42 | python = os.path.join(python_dir, "python")
43 | for f in os.listdir(os.path.join(host_fk_dir, bdist_dir)):
44 | if f"cp{python_version}-manylinux_2_28_x86_64.whl" in f:
45 | (s, o) = run_command(f"{docker_exec} {pip} install {docker_bdist_dir}/{f}")
46 |
47 | (s, o) = run_command(f"{docker_exec} {python} 
{docker_fk_dir}/tests/python/test_wheels.py") 48 | 49 | def audit_wheel(python_version): 50 | for f in os.listdir(os.path.join(host_fk_dir, bdist_dir)): 51 | if f"cp{python_version}-linux_x86_64.whl" in f: 52 | (s, o) = run_command(f"{docker_exec} auditwheel repair {docker_bdist_dir}/{f} -w {docker_bdist_dir}/") 53 | 54 | if __name__ == "__main__": 55 | import argparse 56 | parser = argparse.ArgumentParser(description = "Build Python Wheels") 57 | parser.add_argument('-python-version', required=True, type=str, nargs="+") 58 | 59 | args = parser.parse_args() 60 | print("Create container") 61 | 62 | run_command(docker_create_container) 63 | run_command(docker_install_git) 64 | run_command(docker_git_add_safe_dir) 65 | 66 | print(f"Building for Python versions: {args.python_version}") 67 | for py in args.python_version: 68 | print(f"Building for Python {py}") 69 | build_wheel(py) 70 | 71 | build_any_wheel() 72 | 73 | print(f"Auditing wheels") 74 | for py in args.python_version: 75 | audit_wheel(py) 76 | 77 | print(f"Test wheels") 78 | for py in args.python_version: 79 | test_wheel(py) 80 | 81 | run_command(docker_kill_container) 82 | run_command(docker_rm_container) -------------------------------------------------------------------------------- /tests/run-tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import os 3 | import shutil 4 | import subprocess 5 | import sys 6 | 7 | def execute(command): 8 | print(f"Executing {command}") 9 | (s, o) = subprocess.getstatusoutput(command) 10 | if s != 0: 11 | print(f'Error in executing "{command}"') 12 | print(o) 13 | assert False 14 | return o 15 | 16 | backend = sys.argv[1].lower() 17 | single_or_multi = sys.argv[2].lower() 18 | 19 | all_backends = ['cuda', 'x86']# 'hip', 'arm'] 20 | 21 | assert backend in (all_backends + ['all']) 22 | assert single_or_multi in ['single', 'multi', 'all'] 23 | 24 | if backend == 'all': 25 | backends = all_backends 26 | else: 27 | backends = [backend] 28 | 29 | if single_or_multi == 'all': 30 | single_or_multi = ['single', 'multi'] 31 | else: 32 | single_or_multi = [single_or_multi] 33 | 34 | test_cases = {k : {'single':{}, 'multi': {}} for k in all_backends} 35 | 36 | test_cases['cuda']['single'] = {'gen-single-gpu-kernels' : ['single-gpu-cuda-NN', 'single-gpu-cuda-TT']} 37 | test_cases['cuda']['multi'] = { 38 | 'gen-multi-cuda-tests-kernel' : ['FASTKRON_COMM=NCCL multi-cuda-no-fusion-tests', 39 | 'FASTKRON_COMM=P2P multi-cuda-no-fusion-tests'], 40 | 'gen-multi-cuda-tuner-kernels' : ['multi-cuda-tuner-tests'], 41 | 'gen-multi-cuda-no-fusion-non-square-tests-kernel' : ['FASTKRON_COMM=P2P multi-cuda-no-fusion-non-square-tests', 42 | 'FASTKRON_COMM=NCCL multi-cuda-no-fusion-non-square-tests'], 43 | 'gen-multi-cuda-distinct-shapes' : ['FASTKRON_COMM=P2P multi-cuda-distinct-shapes', 44 | 'FASTKRON_COMM=NCCL multi-cuda-distinct-shapes'] 45 | } 46 | 47 | test_cases['x86']['single'] = {'gen-x86-kernels' : ['x86-cpu-NN', 'x86-cpu-TT']} 48 | 49 | if os.path.exists("build/"): 50 | shutil.rmtree("build/") 51 | 52 | if not os.path.exists("build/"): 53 | os.mkdir("build/") 54 | 55 | os.chdir("build/") 56 | cmake = "-DCMAKE_BUILD_TYPE=Release " 57 | for b in backends: 58 | cmake += f"-DENABLE_{b.upper()}=ON " 59 | 60 | if 'single' in single_or_multi: 61 | execute(f'cmake .. 
{cmake}') 62 | 63 | for mode in single_or_multi: 64 | if mode == 'single': 65 | for backend in backends: 66 | for case in test_cases[backend][mode]: 67 | execute(f'make {case}') 68 | 69 | execute(f'make -j') 70 | 71 | for backend in backends: 72 | for case in test_cases[backend][mode]: 73 | for run in test_cases[backend][mode][case]: 74 | output = execute(f'make {run if " " not in run else run.split(" ")[1]} -j') 75 | output = execute((f"TUNE=0 tests/{backend}/"+run) if ' ' not in run else run.replace(' ', f' tests/{backend}/')) 76 | if 'FAILED' in output: 77 | print(output) 78 | 79 | if 'multi' in single_or_multi: 80 | execute(f'cmake .. {cmake} -DFULL_TUNE=ON -DENABLE_MULTI_GPU=ON') 81 | for case in test_cases['cuda'][mode]: 82 | execute(f'make {case}') 83 | execute(f'make -j') 84 | for run in test_cases['cuda'][mode][case]: 85 | output = execute(f'make {run if " " not in run else run.split(" ")[1]} -j') 86 | output = execute((f"tests/{backend}/"+run) if ' ' not in run else run.replace(' ', f' tests/{backend}/')) 87 | if 'FAILED' in output: 88 | print(output) -------------------------------------------------------------------------------- /src/kmm/stackarray.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "config.h" 4 | 5 | #pragma once 6 | 7 | /** 8 | * StackArray - This class defines an array that is stored on the stack for a type and size. 9 | */ 10 | template 11 | class StackArray { 12 | public: 13 | /** 14 | * @array: The storage buffer of type @T with size @MaxSize. 15 | */ 16 | T array[MaxSize]; 17 | /** 18 | * @n: Length of elements filled in the array. 19 | */ 20 | uint32_t n; 21 | 22 | StackArray() { 23 | for (uint32_t i = 0; i < MaxSize; i++) { 24 | array[i] = T(); 25 | } 26 | } 27 | 28 | public: 29 | StackArray(const T* ptrs, uint32_t n) : n(n) { 30 | if (ptrs) { 31 | for (uint32_t i = 0; i < n; i++) { 32 | array[i] = ptrs[i]; 33 | } 34 | } 35 | 36 | for (uint32_t i = n; i < MaxSize; i++) { 37 | array[i] = T(); 38 | } 39 | } 40 | 41 | StackArray(std::initializer_list initList) : n(initList.size()) { 42 | int len = 0; 43 | for (auto elem : initList) { 44 | array[len++] = elem; 45 | } 46 | 47 | for (uint32_t i = n; i < MaxSize; i++) { 48 | array[i] = T(); 49 | } 50 | } 51 | 52 | CUDA_DEVICE_HOST 53 | T& operator[](int index) { 54 | assert (index < n && index >= 0); 55 | return array[index]; 56 | } 57 | 58 | CUDA_DEVICE_HOST 59 | T& operator[](uint32_t index) { 60 | assert (index < n); 61 | return array[index]; 62 | } 63 | 64 | StackArray sub(uint32_t start, uint32_t length) const { 65 | assert(length <= n); 66 | T ptrs[length]; 67 | for (uint32_t i = 0; i < length; i++) { 68 | ptrs[i] = array[i + start]; 69 | } 70 | 71 | return StackArray(ptrs, length); 72 | } 73 | 74 | void push_front(const T& elem) { 75 | assert(n < MaxSize); 76 | for (int i = n; i >= 1; i--) { 77 | array[i] = array[i-1]; 78 | } 79 | 80 | array[0] = elem; 81 | n++; 82 | } 83 | 84 | void push_back(const T& elem) { 85 | assert(n < MaxSize); 86 | array[n] = elem; 87 | n++; 88 | } 89 | 90 | CUDA_DEVICE_HOST 91 | uint32_t len() const {return n;} 92 | 93 | template 94 | StackArray slice(uint32_t start) const { 95 | assert(SliceSize <= n); 96 | T ptrs[SliceSize]; 97 | for (uint32_t i = 0; i < SliceSize; i++) { 98 | ptrs[i] = array[i + start]; 99 | } 100 | 101 | return StackArray(ptrs, SliceSize); 102 | } 103 | 104 | template 105 | StackArray sliceOrEmpty(uint32_t start) const { 106 | T ptrs[SliceSize]; 107 | for (uint32_t i = 0; i < SliceSize && i < 
n; i++) { 108 | ptrs[i] = array[i + start]; 109 | } 110 | 111 | for (uint32_t i = n; i < SliceSize; i++) { 112 | ptrs[i] = T(); 113 | } 114 | 115 | return StackArray(ptrs, SliceSize); 116 | } 117 | 118 | StackArray slice(uint32_t start, uint32_t size) const { 119 | assert(size <= n); 120 | T ptrs[size]; 121 | for (uint32_t i = 0; i < size; i++) { 122 | ptrs[i] = array[i + start]; 123 | } 124 | 125 | return StackArray(ptrs, size); 126 | } 127 | 128 | StackArray(const StackArray& x) : StackArray (&x.array[0], x.len()) {} 129 | }; -------------------------------------------------------------------------------- /src/optimized_kernel_map.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import re 4 | import sys 5 | 6 | def run_command(command): 7 | (s, o) = subprocess.getstatusoutput(command) 8 | if s != 0: 9 | print (f"Running {command}\n", o) 10 | return o 11 | 12 | def run(m, n, p, q, opX, opF, backend, fuse, tune): 13 | o = run_command(f'TUNE={tune} ../build/tests/benchmarks/benchmark_cuda -m {m} -n {n} -p {p} -q {q} -r {10} -w {10} -t float --tune --backend {backend} {"--fuse" if fuse else ""}') 14 | o = o[o.find('Minimum Time'):] 15 | 16 | kernelSeries = re.findall(r'\s*\[(\d+), (\d+)\] = (\d+) (.+) runs', o) 17 | allKernelsExec = [] 18 | for kernelExec in kernelSeries: 19 | start,end,k,kernel = kernelExec 20 | allKernelsExec += [(int(end) - int(start) + 1, int(k), kernel)] 21 | 22 | allKernelsExec = list(set(allKernelsExec)) 23 | gflops = re.findall(r'GFLOPS: (\d+\.\d+)', o) 24 | #print(f"{m}x{p**n}*({p}x{q}^{n})",allKernelsExec, gflops) 25 | 26 | return f"{m}x{p**n}*({p}x{q}^{n})", allKernelsExec, gflops #["{"+f'Matrix({m},{p**n}), \"{kernels[0]}\"'+"},"] 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('-backend' , type=str) 31 | parser.add_argument('-opX' , required=True , type=str) 32 | parser.add_argument('-opF' , required=True , type=str) 33 | 34 | args = parser.parse_args() 35 | 36 | assert args.opX in ["N", "T"] 37 | assert args.opF in ["N", "T"] 38 | run_command(f'python ./gen_tuner_kernels.py -backend {args.backend} -same-factors 3 128,128 -same-factors 3 64,64 -same-factors 4 32,32 -same-factors 5 16,16 -same-factors 6 8,8 -same-factors 10 4,4 -same-factors 20 2,2 -opX N -opF N -match-configs-file kernels/best-kernels/a100-kernels') 39 | run_command(f'cd ../build/ && make benchmark_{args.backend} -j') 40 | 41 | shapeToKernel = {} 42 | 43 | for p in [2,4,8,16,32,64,128,256]: 44 | for q in [2,4,8,16,32,64,128,256]: 45 | for n in range(1,20): 46 | for m in [1,4,16,64,256,1024]: 47 | if m*(p**n) > 2*1024*1024*1024 or m*(q**n) > 2*1024*1024*1024: 48 | continue 49 | if p <= 32 and q <= 32: 50 | continue 51 | tuned = run(m, n, p, q, args.opX, args.opF, args.backend, True, 1) 52 | selected = run(m, n, p, q, args.opX, args.opF, args.backend, True, 0) 53 | print("Tuned", tuned) 54 | print("Selected", selected) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | # for kernelExec in allKernelsExec: 67 | # key = f"Factor({p},{q}),{kernelExec[0]}" 68 | # if key not in shapeToKernel: 69 | # shapeToKernel[key] = [] 70 | # if len(shapeToKernel[key]) > 0 and (shapeToKernel[key][-1][2] == kernelExec[2] and shapeToKernel[key][-1][1] <= kernelExec[1] and shapeToKernel[key][-1][0] <= m): 71 | # continue 72 | # shapeToKernel[key] += [(m, kernelExec[1], kernelExec[2])] 73 | 74 | # maplines = "" 75 | # indent = 1 76 | 77 | # for k,vs in shapeToKernel.items(): 78 
| # maplines += " " * indent + "{\n" 79 | # indent += 1 80 | # maplines += " " * indent + "{"+k+"}," + " {\n" 81 | # indent += 1 82 | # for v in vs: 83 | # maplines += " " * indent + "{" + f'Matrix({v[0]}, {v[1]}), "{v[2]}"' + "}" + ",\n" 84 | # indent -= 1 85 | # maplines += " " * indent + "}\n" 86 | # indent -= 1 87 | # maplines += " " * indent + "},\n" 88 | 89 | 90 | # print(maplines) -------------------------------------------------------------------------------- /tests/python/test_wheels.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import numpy as np 3 | 4 | import pyfastkron.fastkronnumpy as fk 5 | 6 | def product(values): 7 | return reduce((lambda a, b: a * b), values) 8 | 9 | def transpose(m): 10 | axis = tuple(range(len(m.shape[:-2]))) + \ 11 | (len(m.shape) - 1, len(m.shape) - 2) 12 | return m.transpose(axis) 13 | 14 | def reference(mmtype, x, fs): 15 | batchKron = fs[0].shape[:-2] 16 | if len(batchKron) == 0: 17 | outputKron = fs[0] 18 | for m in fs[1:]: 19 | outputKron = np.kron(outputKron, m) 20 | else: 21 | batchDims = product(batchKron) 22 | for i,f in enumerate(fs): 23 | fs[i] = fs[i].reshape((batchDims,) + f.shape[-2:]) 24 | 25 | output = fs[0] 26 | for f in fs[1:]: 27 | prev = output 28 | output = np.ndarray(shape=(batchDims, prev.shape[-2] * f.shape[-2], prev.shape[-1] * f.shape[-1]), 29 | dtype=f.dtype) 30 | for b in range(batchDims): 31 | output[b:] = np.kron(prev[b,:], f[b,:]) 32 | outputKron = output.reshape(batchKron + (output.shape[-2], output.shape[-1])) 33 | 34 | if mmtype == "mkm": 35 | return np.matmul(x, outputKron) 36 | elif mmtype == "kmm": 37 | return np.matmul(outputKron, x) 38 | 39 | def run(mmtype, m, n, p, q, dtype, device, trX, trF, 40 | high=5, batchDimX=[], batchDimFPre=[], batchDimZ=[]): 41 | #Using integer values instead of real numbers because 42 | #floating point is not associative 43 | if mmtype == "mkm": 44 | xshape = [m, p**n] if not trX else [p**n, m] 45 | elif mmtype == "kmm": 46 | xshape = [p**n, m] if not trX else [m, p**n] 47 | 48 | if m == 1: 49 | if trX: 50 | xshape = [xshape[0],] 51 | else: 52 | xshape = [xshape[1],] 53 | 54 | xshape = list(batchDimX) + xshape 55 | 56 | if mmtype == "mkm": 57 | fshape = [p, q] if not trF else [q, p] 58 | elif mmtype == "kmm": 59 | fshape = [q, p] if not trF else [p, q] 60 | 61 | if q == 1: 62 | if trF: 63 | fshape = [fshape[1],] 64 | else: 65 | fshape = [fshape[0],] 66 | 67 | fshape = list(batchDimFPre) + fshape 68 | 69 | zshape = list(batchDimZ) 70 | if mmtype == "mkm": 71 | zshape += [m,q**n] 72 | elif mmtype == "kmm": 73 | zshape += [q**n,m] 74 | 75 | x = np.random.randint(0, high=high,size=xshape).astype(dtype) 76 | fs = [np.random.randint(0, high=high,size=fshape).astype(dtype)\ 77 | for i in range(n)] 78 | z = np.random.randint(0, high=high,size=zshape).astype(dtype) 79 | if trX: 80 | x = transpose(x) 81 | if trF: 82 | fs = [transpose(f) for f in fs] 83 | 84 | alpha = 3.0 85 | beta = 0.0 86 | 87 | if mmtype == "mkm": 88 | y = fk.gemkm(x, fs, alpha, beta, z) 89 | elif mmtype == "kmm": 90 | y = fk.gekmm(fs, x, alpha, beta, z) 91 | 92 | ref = alpha * reference(mmtype, x, fs) + beta * z 93 | val = np.isclose(y, ref, rtol=1e-04).all().item() 94 | print(52) 95 | assert val 96 | 97 | def device_tests(device): 98 | for mmtype in ["mkm", "kmm"]: 99 | run(mmtype, 8, 3, 8, 8, np.float32, device, False, False) 100 | run(mmtype, 16, 3, 32, 32, np.float32, device, False, False) 101 | run(mmtype, 16, 2, 128, 127, np.float32, device, 
True, False) 102 | 103 | run(mmtype, 10, 5, 6, 6, np.float32, device, True, False) 104 | 105 | # #double 106 | run(mmtype, 11, 10, 3, 3, np.double, device, False, True) 107 | run(mmtype, 200, 2, 32, 32, np.double, device, True, True) 108 | 109 | #float16 110 | run(mmtype, 102, 4, 8, 8, np.float16, device, False, False, high=2) 111 | 112 | def test_cpu(): 113 | device_tests("cpu") 114 | 115 | if __name__ == "__main__": 116 | test_cpu() -------------------------------------------------------------------------------- /src/print_best_kernel_for_shapes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import re 4 | import sys 5 | 6 | def run_command(command): 7 | (s, o) = subprocess.getstatusoutput(command) 8 | if s != 0: 9 | print (f"Running {command}\n", o) 10 | return o 11 | 12 | def tune(mmtype, ms, n, p, q, opX, opF, fuse, backend, elemtype): 13 | # run_command(f'python gen_tuner_kernels.py -backend {backend} -same-factors {n} {p},{q} -opX {opX} -opF {opF}') 14 | # run_command(f'cd ../build/ && make benchmark_{backend} -j') 15 | # for m in ms: 16 | o = run_command("FASTKRON_LOG=DEBUG " + ("OMP_NUM_THREADS=64 taskset -c 0-64" if backend == "x86" else "") + f'../build/tests/benchmarks/benchmark_{backend} -m {m} -n {n} -p {p} -q {q} -r {10} -w {10} -t {elemtype} --tune --backend {backend} {"--fuse" if fuse else ""} -a 1 -b 0 --gemmtype {mmtype}') 17 | o = o[o.find('Minimum Time'):] 18 | kernels = re.findall(r'\d+\s(.+)\sruns\sfor', o) 19 | kernels = set(kernels) 20 | gflops = re.findall(r'GFLOPS: (\d+\.\d+)', o) 21 | print(f"{m}x{p**n}*({p}x{q}^{n})",kernels, gflops) 22 | 23 | def parse_same_factors(case): 24 | n = int(case[0]) 25 | assert len(case[1:]) == 1 26 | split = case[1].split(',') 27 | p = int(split[0]) #[int(split[0]) for i in range(n)] 28 | q = int(split[1]) #[int(split[1]) for i in range(n)] 29 | 30 | # k = compute_k(ps, qs) 31 | return (n, p, q) 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('-distinct-factors' , required=False, type=str, action='append', nargs="+") 36 | parser.add_argument('-same-factors' , required=False, type=str, action='append', nargs="+") 37 | parser.add_argument('-m' , required=True, type=int, action='append', nargs="+") 38 | parser.add_argument('-mmtype' , required=True, type=str) 39 | parser.add_argument('-opX' , required=True , type=str) 40 | parser.add_argument('-opF' , required=True , type=str) 41 | parser.add_argument('-num-kernels' , required=False, type=int, default=10000) 42 | parser.add_argument('-backend' , required=True, type=str) 43 | parser.add_argument('-arch' , required=True, type=str) 44 | args = parser.parse_args() 45 | parsed_cases = [] 46 | 47 | if args.same_factors is not None: 48 | for case in args.same_factors: 49 | try: 50 | parsed_cases += [parse_same_factors(case)] 51 | except Exception as e: 52 | print(f"Invalid case: {case}") 53 | print(e) 54 | sys.exit(0) 55 | if args.backend is None or args.backend.lower() not in ['cuda', 'x86', 'hip', 'arm']: 56 | print(f"Invalid backend: {args.backend}") 57 | sys.exit(0) 58 | 59 | print("Print and tune kernels for ", parsed_cases) 60 | assert args.opX in ["N", "T"] 61 | assert args.opF in ["N", "T"] 62 | assert args.mmtype in ["kmm", "mkm"] 63 | elemType = "double" 64 | 65 | # run_command(f'python ./gen_tuner_kernels.py -backend {args.backend} -archs {args.arch} -same-factors 3 128,128 -same-factors 3 64,64 -same-factors 4 32,32 -same-factors 5 16,16 
-same-factors 6 8,8 -same-factors 10 4,4 -same-factors 20 2,2 -opX N -opF N -mm-type {args.mmtype} -types {elemType} -opt-levels 3') 66 | # run_command(f'cd ../build/ && make benchmark_{args.backend} -j') 67 | 68 | for p in [2,4,8,16,32, 64, 128]: 69 | for q in [2,4,8,16,32, 64, 128]: 70 | if p != q: 71 | continue 72 | for n in range(1,13): 73 | for m in [2,4,16,64,256]: 74 | if m*(p**n) > 1024*1024*1024 or m*(q**n) > 1024*1024*1024 or p**n < 64 or q**n < 64: 75 | continue 76 | # if p <= 32 and q <= 32: 77 | # continue 78 | if p <= 32: 79 | tune(args.mmtype, m, n, p, q, args.opX, args.opF, False, args.backend, elemType) 80 | tune(args.mmtype, m, n, p, q, args.opX, args.opF, True, args.backend, elemType) -------------------------------------------------------------------------------- /src/utils/thread_pool.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #pragma once 9 | 10 | template 11 | class thread_pool { 12 | public: 13 | struct task { 14 | void(*f)(task_args); 15 | task_args args; 16 | 17 | task(void(*f_)(task_args), task_args args_) : 18 | f(f_), args(args_) {} 19 | 20 | task() {} 21 | 22 | volatile task& operator= (const task& x) volatile { 23 | f = x.f; 24 | args = x.args; 25 | return *this; 26 | } 27 | 28 | volatile task& operator= (const volatile task& x) volatile { 29 | f = x.f; 30 | args = x.args; 31 | return *this; 32 | } 33 | }; 34 | 35 | private: 36 | std::vector threads; 37 | volatile task* tasks; 38 | volatile bool* done; 39 | 40 | uint32_t num_threads; 41 | std::vector wait_mutexes; 42 | std::vector tasks_mutexes; 43 | std::vector waiting_vars; 44 | 45 | struct thread_args { 46 | uint32_t thread_id; 47 | thread_pool* pool; 48 | }; 49 | 50 | static void thread_func(thread_args args) { 51 | volatile thread_pool* volatile_pool = ((volatile thread_pool*)args.pool); 52 | while(volatile_pool->is_running()) { 53 | args.pool->wait_for_task(args.thread_id); 54 | if (volatile_pool->is_running()) 55 | args.pool->run_task(args.thread_id); 56 | volatile_pool->thread_done(args.thread_id); 57 | } 58 | } 59 | 60 | bool running; 61 | 62 | public: 63 | thread_pool(): num_threads(0) {} 64 | 65 | thread_pool(uint32_t num_threads_) : num_threads(num_threads_) { 66 | init(num_threads_); 67 | } 68 | 69 | void init(uint32_t num_threads_) { 70 | num_threads = num_threads_; 71 | running = true; 72 | tasks = new task[num_threads]; 73 | done = new bool[num_threads]; 74 | wait_mutexes = std::vector(num_threads); 75 | tasks_mutexes = std::vector(num_threads); 76 | waiting_vars = std::vector(num_threads); 77 | for (uint32_t t = 0; t < num_threads; t++) { 78 | threads.push_back(std::thread(thread_pool::thread_func, thread_args{t, this})); 79 | done[t] = false; 80 | } 81 | } 82 | 83 | bool is_running() const volatile {return running;} 84 | 85 | void end() { 86 | running = false; 87 | notify_all(); 88 | for (uint32_t t = 0; t < num_threads; t++) { 89 | threads[t].join(); 90 | } 91 | } 92 | 93 | void run_task(uint32_t id) { 94 | std::unique_lock tlk(tasks_mutexes[id]); 95 | volatile task* t = &tasks[id]; 96 | t->f(t->args); 97 | tlk.unlock(); 98 | } 99 | 100 | void wait_for_task(uint32_t id) { 101 | std::unique_lock lk(wait_mutexes[id]); 102 | waiting_vars[id].wait(lk); 103 | lk.unlock(); 104 | } 105 | 106 | void notify_all() { 107 | for (uint32_t i = 0; i < num_threads; i++) { 108 | done[i] = false; 109 | std::unique_lock lk(wait_mutexes[i]); 110 | waiting_vars[i].notify_all(); 111 | 
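      // The worker bound to this slot has been notified; done[i] was cleared above,
      // so a subsequent join_tasks() will spin until the worker calls thread_done().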
lk.unlock(); 112 | } 113 | } 114 | 115 | void execute_tasks(task* tasks_) { 116 | for (uint32_t i = 0; i < num_threads; i++) { 117 | std::unique_lock tlk(tasks_mutexes[i]); 118 | tasks[i] = tasks_[i]; 119 | tlk.unlock(); 120 | } 121 | notify_all(); 122 | } 123 | 124 | void thread_done(uint32_t thread_id) volatile { 125 | done[thread_id] = true; 126 | } 127 | 128 | void join_tasks() { 129 | for (uint32_t t = 0; t < num_threads; t++) { 130 | volatile bool* d = (volatile bool*) &done[t]; 131 | while (*d != true); 132 | } 133 | } 134 | 135 | ~thread_pool() { 136 | end(); 137 | join_tasks(); 138 | delete tasks; 139 | } 140 | }; 141 | -------------------------------------------------------------------------------- /src/kernels/get_batched_data.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /** 4 | * GetBatchedData - A template struct to obtain common parameters for 5 | * single or strided batched problems. The struct is specialized 6 | * for both batch types. 7 | */ 8 | template 10 | struct GetBatchedData { 11 | /** 12 | * getBatchCount - Get number of batches in the problem. 13 | */ 14 | CUDA_DEVICE_HOST 15 | uint getBatchCount(const KernelParams& params) const; 16 | 17 | /** 18 | * getXBatch - Get input X for a batch. 19 | */ 20 | CUDA_DEVICE_HOST 21 | Matrix getXBatch(const KernelParams& params, int batch) const; 22 | 23 | /** 24 | * getYBatch - Get output Y for a batch. 25 | */ 26 | CUDA_DEVICE_HOST 27 | Matrix getYBatch(const KernelParams& params, int batch) const; 28 | 29 | /** 30 | * getFBatch - Get factor at an index for a batch. 31 | */ 32 | CUDA_DEVICE_HOST 33 | Factor getFBatch(const KernelParams& params, int fidx, int batch) const; 34 | 35 | /** 36 | * getZBatch - Get input Z for a batch. 37 | */ 38 | CUDA_DEVICE_HOST 39 | Matrix getZBatch(const EpilogueParams& params, const Matrix& Y, int batch) const; 40 | 41 | CUDA_DEVICE_HOST 42 | Matrix getIntermediateBatch(const FusedParams& params, int fac, uint32_t batch) const; 43 | }; 44 | 45 | /** 46 | * GetBatchedData specliazed for single batched problem. 47 | */ 48 | template 49 | struct GetBatchedData { 50 | uint getBatchCount(const KernelParams& /*params*/) const {return 1;} 51 | 52 | CUDA_DEVICE_HOST 53 | Matrix getXBatch(const KernelParams& params, int /*batch*/) const { 54 | return params.problem.x(); 55 | } 56 | 57 | CUDA_DEVICE_HOST 58 | Matrix getYBatch(const KernelParams& params, int /*batch*/) const { 59 | return params.problem.y(); 60 | } 61 | 62 | CUDA_DEVICE_HOST 63 | Factor getFBatch(const KernelParams& params, int fidx, int /*batch*/) const { 64 | return params.problem.f(fidx); 65 | } 66 | 67 | CUDA_DEVICE_HOST 68 | Matrix getZBatch(const EpilogueParams& params, const Matrix& Y, int /*batch*/) const { 69 | return Matrix(Y.m(), Y.n(), (void*)params.template z()); 70 | } 71 | 72 | CUDA_DEVICE_HOST 73 | Matrix getIntermediateBatch(const FusedParams& params, int fac, uint32_t batch) const { 74 | return params.intermediates[fac]; 75 | } 76 | }; 77 | 78 | /** 79 | * GetBatchedData specliazed for strided batched problem. 
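 * Each accessor below selects the requested batch of X, Y, Z, or a factor from the
 * strided-batched problem. An illustrative (hypothetical) caller loop:
 *   GetBatchedData<...> batched;                    // template arguments omitted here
 *   for (uint b = 0; b < batched.getBatchCount(params); b++) {
 *     Matrix Xb = batched.getXBatch(params, b);     // batch b of X
 *     Matrix Yb = batched.getYBatch(params, b);     // batch b of Y
 *     // ... compute Yb from Xb and the per-batch factors ...
 *   }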
80 | */ 81 | template 82 | struct GetBatchedData { 83 | uint getBatchCount(const KernelParams& params) const {return params.problem.batchCount();} 84 | 85 | CUDA_DEVICE_HOST 86 | Matrix getXBatch(const KernelParams& params, int batch) const { 87 | return params.problem.x().template batch(batch); 88 | } 89 | 90 | CUDA_DEVICE_HOST 91 | Matrix getYBatch(const KernelParams& params, int batch) const { 92 | return params.problem.y().template batch(batch); 93 | } 94 | 95 | CUDA_DEVICE_HOST 96 | Factor getFBatch(const KernelParams& params, int fidx, int batch) const { 97 | return params.problem.f(fidx).template batch(batch); 98 | } 99 | 100 | CUDA_DEVICE_HOST 101 | Matrix getZBatch(const EpilogueParams& params, const Matrix& /*Y*/, int batch) const { 102 | return params.getZ().template batch(batch); 103 | } 104 | 105 | CUDA_DEVICE_HOST 106 | Matrix getIntermediateBatch(const FusedParams& params, int fac, uint32_t batch) const { 107 | return params.intermediates[fac].template batch(batch); 108 | } 109 | }; -------------------------------------------------------------------------------- /src/kernels/cpu/memory-store.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template 7 | static CUDA_DEVICE_HOST 8 | void store(const KernelParams& /*params*/, const FusedParams& fusedParams, const EpilogueParams& /*epilogueParams*/, 9 | X86VecT beta, 10 | uint32_t fac, uint32_t /*batch*/, 11 | uint32_t tileM, uint32_t tileK, uint32_t tileP, uint32_t tileQ, 12 | const YElem& y, 13 | const Factor& F, Matrix& Y, Matrix& FusedIntermediate, Matrix& Z, 14 | FCache& Fch, TileX& XTile, YInterim& Ych, YRegisters& YReg) { 15 | bool storeY = false; 16 | if (fac > 0 || (Fch.p() <= F.p() && tileP < F.p() - Fch.p())) { 17 | YReg.apply([&](X86VecT& e, const uint32_t rm, const uint32_t rk, const uint32_t rq) { 18 | e.store(&Ych.at(y.m() + rm * YReg.mvec(), y.q() + rq, y.k()/Fch.p() + rk * YReg.kvec())); 19 | }); 20 | } else { 21 | storeY = true; 22 | } 23 | 24 | bool storeFusedIntermediate = (fac > 0 && fusedParams.intermediates[fac].data() != nullptr); 25 | 26 | if (storeFusedIntermediate || storeY) { 27 | Matrix output = (storeFusedIntermediate) ? 
FusedIntermediate : Y; 28 | 29 | YReg.apply([&](X86VecT& e, const uint32_t rm, const uint32_t rk, const uint32_t rq) { 30 | constexpr bool kQMultipleOfTileQ = KernelOptimizations::IsQMultipleOfTileQ(OptLevel); 31 | constexpr bool kKMultipleOfTileK = KernelOptimizations::IsKMultipleOfTileK(OptLevel); 32 | constexpr bool kMMultipleOfTileM = KernelOptimizations::IsMMultipleOfTileM(OptLevel); 33 | 34 | uint32_t slice = y.k()/Fch.p() + rk * YReg.kvec(); 35 | 36 | if (!kKMultipleOfTileK && slice >= XTile.cols/F.p()) return; 37 | if (!kQMultipleOfTileQ && tileQ + y.q() + rq >= F.q()) return; 38 | 39 | const uint32_t XTileSlices = XTile.tileCols()/F.p(); 40 | const uint32_t XSlices = Y.n()/F.q(); 41 | uint32_t yN; 42 | 43 | if (fusedParams.NumFused > 1) { 44 | uint32_t xshCol = (rq + y.q()) * XTileSlices + rk*YReg.kvec() + y.k()/Fch.p(); 45 | //Scale shared mem slice idx to global mem idx 46 | uint32_t glSlice = (xshCol/XTileSlices)*XSlices; 47 | //Scale shared fused slice to global mem 48 | uint32_t sliceElem = ((xshCol%XTileSlices)/fusedParams.XShFusedSlices[fac])*fusedParams.XglFusedSlices[fac]; 49 | //Elem idx in Fused Slice 50 | uint32_t elem = (tileK/XTile.tileCols()) * fusedParams.XShFusedSlices[fac] + 51 | xshCol%fusedParams.XShFusedSlices[fac]; 52 | yN = glSlice + sliceElem + elem; 53 | } else { 54 | yN = (y.q() + rq) * XSlices + 55 | (tileK/XTile.tileCols()) * XTileSlices + 56 | slice; 57 | if (Fch.q() < F.q()) { 58 | yN += tileQ * XSlices; 59 | } 60 | } 61 | 62 | if (kMMultipleOfTileM || y.m() + rm*YReg.mvec() < XTile.m()) { 63 | uint32_t numElems; 64 | if (YReg.layout() == fastKronOp_N) { 65 | uint32_t slices = (kKMultipleOfTileK && 66 | XTile.tileCols() % YReg.kvec() == 0) ? 67 | YReg.kvec() : (XTile.cols/F.p() - slice); 68 | slices = MIN(YReg.kvec(), slices); 69 | numElems = slices; 70 | } else { 71 | numElems = kMMultipleOfTileM ? 
YReg.mvec() : XTile.m() - (y.m() + rm*YReg.mvec()); 72 | numElems = MIN(YReg.mvec(), numElems); 73 | } 74 | if (storeY && (EpilogueKindVal & EpilogueKind::Beta) == EpilogueKind::Beta) { 75 | X86VecT z; 76 | z.load(Z.data(tileM + y.m() + rm*YReg.mvec(), yN, YReg.layout()), numElems); 77 | e.fmadd(beta, z); 78 | } 79 | uint32_t yM = tileM + y.m() + rm*YReg.mvec(); 80 | e.store(output.data(yM, yN, YReg.layout()), numElems); 81 | }}); 82 | } 83 | } -------------------------------------------------------------------------------- /tests/benchmarks/profile_kernels.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import re 3 | import os 4 | import sys 5 | import copy 6 | 7 | ENV_VAR = "PYTHONPATH=/home/parasail/.local/lib/python3.8/site-packages/:$PYTHONPATH LD_LIBRARY_PATH=%s:"%(os.getcwd()) 8 | NVPROF_BIN="/usr/local/cuda/bin/nvprof" 9 | RequiredMetrics = ["shared_load_transactions_per_request", "gld_transactions_per_request", "global_load_requests"] 10 | NVPROF_FLAGS="--metrics "+",".join(RequiredMetrics) 11 | 12 | NVPROF_COMMAND=" ".join([NVPROF_BIN, NVPROF_FLAGS]) 13 | 14 | npoints = 320 15 | cases = {4:(7, 10), 8:(3,6), 16:(2,5), 32: (2,4)} 16 | 17 | FASTKRON_BIN = "./kron" 18 | FASTKRON_FLAGS = "-b 320 -f %d -s %d -t float -r 1 -w 0" 19 | GPYTORCH_BIN = "python3 kronecker-model.py" 20 | GPYTORCH_FLAGS = "320 %d %d 1" 21 | MetricValues = { 22 | "fastkron": {m : 0 for m in RequiredMetrics}, 23 | "cublas": {m : 0 for m in RequiredMetrics}, 24 | "transpose": {m : 0 for m in RequiredMetrics}, 25 | } 26 | metric_values = [] 27 | 28 | def parseMetricLine(line): 29 | # return re.findall(r'(\d+)\s+([\_\w\d\-]+)\s+([\w+\s\d\(\)\-]+)\s+(.+?)\s+(.+?)\s+(.+)',line) 30 | line = line.split(" ") 31 | new_line = [] 32 | for m in line: 33 | if m != "": 34 | new_line += [m.strip()] 35 | return new_line 36 | 37 | def parseMetric(metrics_value, nvprofOutput): 38 | o = nvprofOutput[nvprofOutput.find("Metric result:"):] 39 | lines = list(re.findall(r'.+', o)) 40 | l = 0 41 | while l < len(lines): 42 | line = lines[l] 43 | if "Kernel:" in line: 44 | kernel_name = "" 45 | if "kronGemmKernel" in line: 46 | kernel_name = "fastkron" 47 | elif "gemm" in line: 48 | kernel_name = "cublas" 49 | elif "at" in line and "native" in line and "elementwise_kernel" in line: 50 | kernel_name = "transpose" 51 | else: 52 | print(line) 53 | sys.exit(0) 54 | print(kernel_name) 55 | l += 1 56 | line = lines[l] 57 | while l < len(lines) and "Kernel:" not in lines[l]: 58 | line = lines[l] 59 | parsed = parseMetricLine(line) 60 | if parsed[1] in RequiredMetrics: 61 | if parsed[1]=="global_load_requests": 62 | metrics_value[kernel_name][parsed[1]] += float(parsed[5])*int(parsed[0]) 63 | else: 64 | metrics_value[kernel_name][parsed[1]] = max(float(parsed[5]), metrics_value[kernel_name][parsed[1]]) 65 | l += 1 66 | if l < len(lines) and "Kernel:" not in lines[l]: 67 | l += 1 68 | 69 | for g in cases: 70 | for d in range(cases[g][0], cases[g][1]+1): 71 | metric_value = copy.deepcopy(MetricValues) 72 | metric_values += [(g,d,metric_value)] 73 | if True: 74 | fastkron_command = FASTKRON_BIN + " " + FASTKRON_FLAGS%(d, g) 75 | command = "sudo " + ENV_VAR + " " + NVPROF_COMMAND + " " + fastkron_command 76 | print(command) 77 | (s, o) = subprocess.getstatusoutput(command) 78 | if s != 0: 79 | print(o) 80 | else: 81 | parseMetric(metric_value, o) 82 | 83 | if True: 84 | gpytorch_command = GPYTORCH_BIN + " " + GPYTORCH_FLAGS%(d, g) 85 | command = "sudo " + ENV_VAR + " " + NVPROF_COMMAND + 
" " + gpytorch_command 86 | print(command) 87 | (s, o) = subprocess.getstatusoutput(command) 88 | if s != 0: 89 | print(o) 90 | else: 91 | parseMetric(metric_value, o) 92 | 93 | firstRow = [] 94 | print(metric_values) 95 | kernels = list(MetricValues.keys()) 96 | metrics = RequiredMetrics 97 | for kernel in kernels: 98 | for metric in RequiredMetrics: 99 | firstRow += [kernel+"-"+metric] 100 | 101 | print("&".join(["g", "d"] + firstRow)) 102 | 103 | for metric_value in metric_values: 104 | row = [] 105 | for kernel in kernels: 106 | for metric in RequiredMetrics: 107 | row += [str(metric_value[2][kernel][metric])] 108 | print("&".join([str(metric_value[0]), str(metric_value[1])]+row)) -------------------------------------------------------------------------------- /src/kernels/best-kernels/kmm-v100-kernels: -------------------------------------------------------------------------------- 1 | kmm_cuda_volta_256_f_128x128_32x128_1_32x256_2x2x8_**_*_*_* 2 | kmm_cuda_volta_256_f_128x128_32x128_1_16x512_1x4x8_**_*_*_* 3 | kmm_cuda_volta_128_f_128x128_32x128_1_2x4096_2x1x32_**_*_*_* 4 | kmm_cuda_volta_128_f_128x128_32x64_1_2x256_1x1x2_**_*_*_* 5 | 6 | kmm_cuda_volta_128_f_64x64_32x64_1_32x128_4x2x4_**_*_*_* 7 | kmm_cuda_volta_128_f_64x64_32x64_1_16x128_2x2x4_**_*_*_* 8 | kmm_cuda_volta_64_f_64x64_32x64_1_2x1024_2x1x16_**_*_*_* 9 | kmm_cuda_volta_128_f_64x64_32x64_1_2x128_1x1x2_**_*_*_* 10 | 11 | kmm_cuda_volta_256_f_32x32_32x32_1_32x128_4x1x4_**_strided_*_* 12 | kmm_cuda_volta_128_f_32x32_32x32_1_16x128_1x4x4_**_strided_*_* 13 | kmm_cuda_volta_128_f_32x32_32x32_2_2x4096_2x1x32_**_strided_*_* 14 | kmm_cuda_volta_128_f_32x32_32x32_1_2x4096_1x2x32_**_strided_*_* 15 | kmm_cuda_volta_64_f_32x32_32x32_1_2x128_2x2x1_**_strided_*_* 16 | 17 | kmm_cuda_volta_256_f_16x16_16x16_2_16x256_1x2x8_**_strided_*_2|3 18 | kmm_cuda_volta_256_f_16x16_16x16_1_16x256_1x2x8_**_strided_*_2|3 19 | kmm_cuda_volta_256_f_16x16_16x16_2_2x4096_2x1x16_**_strided_*_2|3 20 | kmm_cuda_volta_256_f_16x16_16x16_1_2x4096_2x1x16_**_strided_*_2|3 21 | kmm_cuda_volta_64_f_16x16_16x16_2_2x256_1x2x4_**_strided_*_2|3 22 | kmm_cuda_volta_64_f_16x16_16x16_1_2x256_1x2x4_**_strided_*_2|3 23 | 24 | kmm_cuda_volta_512_f_8x8_8x8_3_2x4096_2x1x8_**_strided_*_2|3 25 | kmm_cuda_volta_512_f_8x8_8x8_1_2x4096_2x1x8_**_strided_*_2|3 26 | kmm_cuda_volta_512_f_8x8_8x8_3_16x512_1x2x8_**_strided_*_2|3 27 | kmm_cuda_volta_512_f_8x8_8x8_1_16x512_1x2x8_**_strided_*_2|3 28 | kmm_cuda_volta_128_f_8x8_8x8_1_2x128_1x1x2_**_strided_*_2|3 29 | kmm_cuda_volta_128_f_8x8_8x8_2_2x128_2x1x1_**_strided_*_2|3 30 | 31 | kmm_cuda_volta_512_f_4x4_4x4_4_16x256_1x2x4_**_strided_*_2|3 32 | kmm_cuda_volta_256_f_4x4_4x4_1_16x128_1x2x4_**_strided_*_2|3 33 | kmm_cuda_volta_256_f_4x4_4x4_3_2x256_1x1x2_**_strided_*_2|3 34 | kmm_cuda_volta_256_f_4x4_4x4_1_2x256_1x1x2_**_strided_*_2|3 35 | kmm_cuda_volta_64_f_4x4_4x4_1_2x64_1x2x1_**_strided_*_2|3 36 | kmm_cuda_volta_64_f_4x4_4x4_3_2x64_1x2x1_**_strided_*_2|3 37 | 38 | kmm_cuda_volta_128_f_2x2_2x2_6_16x64_1x4x2_**_strided_*_2|3 39 | kmm_cuda_volta_128_f_2x2_2x2_1_16x64_1x4x2_**_strided_*_2|3 40 | kmm_cuda_volta_64_f_2x2_2x2_6_2x64_1x1x2_**_strided_*_2|3 41 | kmm_cuda_volta_64_f_2x2_2x2_1_2x64_1x1x2_**_strided_*_2|3 42 | 43 | kmm_cuda_volta_128_d_128x128_16x64_1_16x512_2x2x8_**_*_*_* 44 | kmm_cuda_volta_64_d_128x128_16x128_1_2x1024_1x4x8_**_*_*_* 45 | kmm_cuda_volta_128_d_128x128_32x64_1_2x512_1x4x1_**_*_*_* 46 | 47 | kmm_cuda_volta_256_d_64x64_32x64_1_16x128_1x2x4_**_*_*_* 48 | kmm_cuda_volta_128_d_64x64_32x64_1_2x2048_1x1x32_**_*_*_* 49 | 
kmm_cuda_volta_64_d_64x64_32x16_1_2x256_1x1x2_**_*_*_* 50 | 51 | kmm_cuda_volta_128_d_32x32_32x32_2_2x2048_1x1x32_**_strided_*_* 52 | kmm_cuda_volta_128_d_32x32_32x32_1_2x2048_1x1x32_**_strided_*_* 53 | kmm_cuda_volta_512_d_32x32_32x32_1_16x128_1x2x2_**_strided_*_* 54 | 55 | kmm_cuda_volta_512_d_16x16_16x16_2_16x256_1x2x4_**_strided_*_2|3 56 | kmm_cuda_volta_256_d_16x16_16x16_1_16x32_1x1x2_**_strided_*_2|3 57 | kmm_cuda_volta_128_d_16x16_16x16_2_2x1024_1x1x16_**_strided_*_2|3 58 | kmm_cuda_volta_128_d_16x16_16x16_1_2x1024_1x1x16_**_strided_*_2|3 59 | kmm_cuda_volta_128_d_16x16_16x16_2_2x256_1x1x4_**_strided_*_2|3 60 | kmm_cuda_volta_128_d_16x16_16x16_1_2x256_1x1x4_**_strided_*_2|3 61 | 62 | kmm_cuda_volta_256_d_8x8_8x8_2_16x128_1x2x4_**_strided_*_2|3 63 | kmm_cuda_volta_256_d_8x8_8x8_1_16x128_1x2x4_**_strided_*_2|3 64 | kmm_cuda_volta_256_d_8x8_8x8_3_2x1024_1x1x8_**_strided_*_2|3 65 | kmm_cuda_volta_64_d_8x8_8x8_2_2x512_1x2x8_**_strided_*_2|3 66 | kmm_cuda_volta_64_d_8x8_8x8_1_2x512_1x2x8_**_strided_*_2|3 67 | kmm_cuda_volta_64_d_8x8_8x8_2_2x64_1x1x2_**_strided_*_2|3 68 | kmm_cuda_volta_64_d_8x8_8x4_1_2x64_1x1x1_**_strided_*_2|3 69 | 70 | kmm_cuda_volta_1024_d_4x4_4x4_4_16x256_1x2x2_**_strided_*_2|3 71 | kmm_cuda_volta_1024_d_4x4_4x4_1_16x256_1x2x2_**_strided_*_2|3 72 | kmm_cuda_volta_256_d_4x4_4x4_4_2x512_1x1x4_**_strided_*_2|3 73 | kmm_cuda_volta_256_d_4x4_4x4_1_2x512_1x1x4_**_strided_*_2|3 74 | kmm_cuda_volta_64_d_4x4_4x4_3_2x64_1x1x2_**_strided_*_2|3 75 | kmm_cuda_volta_64_d_4x4_4x4_1_2x64_1x1x2_**_strided_*_2|3 76 | 77 | kmm_cuda_volta_128_d_2x2_2x2_6_16x64_1x4x2_**_strided_*_2|3 78 | kmm_cuda_volta_128_d_2x2_2x2_1_16x64_1x4x2_**_strided_*_2|3 79 | 80 | kmm_cuda_volta_64_d_2x2_2x2_6_2x64_1x2x1_**_strided_*_2|3 81 | kmm_cuda_volta_64_d_2x2_2x2_1_2x64_1x2x1_**_strided_*_2|3 -------------------------------------------------------------------------------- /src/utils/utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #pragma once 9 | 10 | #define CUDA_LAST_ERROR do { \ 11 | cudaError_t e = cudaGetLastError(); \ 12 | if (e != cudaSuccess) { \ 13 | printf("Failed: Cuda error %s:%d '%s'\n", \ 14 | __FILE__,__LINE__,cudaGetErrorString(e)); \ 15 | exit(EXIT_FAILURE); \ 16 | } \ 17 | } while (0) \ 18 | 19 | #define CUDA_CHECK(cmd) do { \ 20 | cudaError_t e = cmd; \ 21 | if(e != cudaSuccess and \ 22 | e != cudaErrorPeerAccessAlreadyEnabled) { \ 23 | printf("Failed: Cuda error %s:%d '%s'\n", \ 24 | __FILE__,__LINE__,cudaGetErrorString(e)); \ 25 | exit(EXIT_FAILURE); \ 26 | } \ 27 | } while(0) \ 28 | 29 | #define HIP_CHECK(cmd) do { \ 30 | hipError_t e = cmd; \ 31 | if(e != hipSuccess and \ 32 | e != hipErrorPeerAccessAlreadyEnabled) { \ 33 | printf("Failed: HIP error %s:%d '%s'\n", \ 34 | __FILE__,__LINE__,hipGetErrorString(e)); \ 35 | exit(EXIT_FAILURE); \ 36 | } \ 37 | } while(0) \ 38 | 39 | #define NCCLCHECK(cmd) do { \ 40 | ncclResult_t r = cmd; \ 41 | if (r!= ncclSuccess) { \ 42 | printf("Failed, NCCL error %s:%d '%s'\n", \ 43 | __FILE__,__LINE__,ncclGetErrorString(r)); \ 44 | exit(EXIT_FAILURE); \ 45 | } \ 46 | } while(0) 47 | 48 | #define PTHREAD_BARRIER_CHECK(x) do { \ 49 | if (x != 0 && \ 50 | x != PTHREAD_BARRIER_SERIAL_THREAD) { \ 51 | printf("Failed: pthread barrier error %s:%d\n", \ 52 | __FILE__,__LINE__); \ 53 | exit(EXIT_FAILURE); \ 54 | } \ 55 | } while (0) \ 56 | 57 | #define MIN(x,y) (((x) < (y)) ? (x) : (y)) 58 | #define MAX(x,y) (((x) > (y)) ? 
(x) : (y)) 59 | #define DIVUP(x,y) (((x) + (y) - 1)/((y))) 60 | #define ROUNDUP(x,y) (DIVUP(x,y)*(y)) 61 | #define ROUNDDOWN(x,y) (x/y)*y 62 | 63 | #define CUDA_WARP_SIZE 32U 64 | #define NULL_CHECK(x) if ((x) == nullptr) return fastKronInvalidArgument; 65 | 66 | // static constexpr int log2(uint n) {return 31 - __builtin_clz(n);} 67 | // static constexpr int log2(int n) {return 31 - __builtin_clz(n);} 68 | 69 | static inline double convertTimeValToDouble(struct timeval _time) { 70 | return ((double)_time.tv_sec)*1e6 + ((double)_time.tv_usec); 71 | } 72 | 73 | static inline struct timeval getTimeOfDay () { 74 | struct timeval _time; 75 | 76 | if (gettimeofday (&_time, NULL) == -1) { 77 | fprintf (stderr, "gettimeofday returned -1\n"); 78 | perror (""); 79 | abort (); 80 | } 81 | 82 | return _time; 83 | } 84 | 85 | static inline double getCurrTime() { 86 | return convertTimeValToDouble(getTimeOfDay()); 87 | } 88 | 89 | // static inline int ilog2(uint x) { 90 | // return sizeof(uint32_t) * CHAR_BIT - __builtin_clz(x) - 1; 91 | // } 92 | 93 | static inline bool isPowerOf2(uint x) { 94 | return (x & (x - 1)) == 0; 95 | } 96 | 97 | static inline int ffs(uint x) { 98 | for (int i = 0; i < 32; i++) { 99 | if (((x >> i) & 1) == 1) return i; 100 | } 101 | return -1; 102 | } 103 | 104 | template 105 | static inline void memset(T* ptr, size_t nelem, T val) { 106 | for (uint32_t i = 0; i < nelem; i++) 107 | ptr[i] = val; 108 | } 109 | 110 | static inline void parallelCopy(char* trash1, char* trash2, uint32_t sz) { 111 | #pragma omp parallel for 112 | for (uint32_t i = 0; i < sz; i++) { 113 | trash1[i] = trash2[i]; 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /tests/src/general-tests-TT.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "testBase.h" 3 | 4 | #define GENERAL_TEST_TT(MMType, M, MinFacs, MaxFacs, P, Q, Type, Tune, IsForward, BatchZ, BatchX, BatchF, BatchY) \ 5 | TEST(EXPAND(TEST_BACKEND,Fusion), MMType##_##Type##_##M##x##MinFacs##_##MaxFacs##_##P##x##Q##_##Tune##_##IsForward##_##BatchZ##x##BatchX##x##BatchF##x##BatchY##_##TT) { \ 6 | bool result = true;\ 7 | for (uint Facs = MinFacs; Facs <= MaxFacs; Facs++) {\ 8 | uint KP_MAT_N[Facs];\ 9 | uint KP_MAT_K[Facs];\ 10 | uint N = 1;\ 11 | uint K = 1;\ 12 | for (uint i = 0; i < (uint)Facs; i++) {\ 13 | N *= Q;\ 14 | K *= P;\ 15 | KP_MAT_K[i] = P;\ 16 | KP_MAT_N[i] = Q;\ 17 | }\ 18 | Type alpha = IsForward ? 1.0f : 2.0f;\ 19 | Type beta = IsForward ? 
0.0f : 1.0f;\ 20 | result = result and run(MMType, M, N, K, Facs, KP_MAT_N, KP_MAT_K, fastKronOp_T, fastKronOp_T, BatchZ,BatchX,BatchF,BatchY,alpha, beta, 1, 0, false, 1, 1, 1, 1, true, true, Tune, getTestBackend(), false, false);\ 21 | if (!result) abort();\ 22 | }\ 23 | EXPECT_TRUE(result);\ 24 | } 25 | 26 | #define CONTIGUOUS_TEST_TT(MinN, MaxN, P, Q, Type, Tune, IsForward) \ 27 | GENERAL_TEST_TT(MKM, 16, MinN, MaxN, P, Q, Type, Tune, IsForward, 1, 1, 1, 1); \ 28 | GENERAL_TEST_TT(MKM, 13, MinN, MaxN, P, Q, Type, Tune, IsForward, 1, 1, 1, 1); \ 29 | GENERAL_TEST_TT(KMM, 1, MinN, MaxN, P, Q, Type, Tune, IsForward, 1, 1, 1, 1); \ 30 | GENERAL_TEST_TT(KMM, 17, MinN, MaxN, P, Q, Type, Tune, IsForward, 1, 1, 1, 1); \ 31 | 32 | #define STRIDED_BATCHED_TEST_TT(MinN, MaxN, P, Q, Type, Tune, IsForward, BatchZ, BatchX, BatchF, BatchY) \ 33 | GENERAL_TEST_TT(MKM, 16, MinN, MaxN, P, Q, Type, Tune, IsForward, BatchZ, BatchX, BatchF, BatchY); \ 34 | GENERAL_TEST_TT(KMM, 3, MinN, MaxN, P, Q, Type, Tune, IsForward, BatchZ, BatchX, BatchF, BatchY); 35 | 36 | CONTIGUOUS_TEST_TT(1, 10, 2, 1, float, false, false); 37 | CONTIGUOUS_TEST_TT(1, 9, 1, 6, float, false, false); 38 | CONTIGUOUS_TEST_TT(1, 10, 2, 2, float, false, false); 39 | CONTIGUOUS_TEST_TT(1, 6, 3, 3, float, false, false); 40 | CONTIGUOUS_TEST_TT(1, 6, 4, 4, float, false, false); 41 | CONTIGUOUS_TEST_TT(1, 5, 5, 5, float, false, false); 42 | CONTIGUOUS_TEST_TT(1, 5, 6, 6, float, false, false); 43 | CONTIGUOUS_TEST_TT(1, 5, 8, 8, float, false, false); 44 | CONTIGUOUS_TEST_TT(1, 5, 12, 12, float, false, false); 45 | CONTIGUOUS_TEST_TT(1, 5, 16, 16, float, false, false); 46 | CONTIGUOUS_TEST_TT(1, 5, 24, 24, float, false, false); 47 | CONTIGUOUS_TEST_TT(1, 4, 31, 31, float, false, false); 48 | CONTIGUOUS_TEST_TT(1, 4, 32, 32, float, false, false); 49 | CONTIGUOUS_TEST_TT(1, 3, 50, 50, float, false, false); 50 | CONTIGUOUS_TEST_TT(1, 3, 55, 55, float, false, false); 51 | CONTIGUOUS_TEST_TT(1, 3, 62, 62, float, false, false); 52 | CONTIGUOUS_TEST_TT(1, 3, 64, 64, float, false, false); 53 | CONTIGUOUS_TEST_TT(1, 3, 127, 127, float, false, false); 54 | CONTIGUOUS_TEST_TT(1, 3, 128, 128, float, false, false); 55 | CONTIGUOUS_TEST_TT(1, 3, 129, 129, float, false, false); 56 | CONTIGUOUS_TEST_TT(1, 2, 255, 255, float, false, false); 57 | CONTIGUOUS_TEST_TT(1, 2, 297, 297, float, false, false); 58 | CONTIGUOUS_TEST_TT(1, 2, 384, 384, float, false, false); 59 | CONTIGUOUS_TEST_TT(1, 1, 505, 505, float, false, false); 60 | CONTIGUOUS_TEST_TT(1, 1, 512, 512, float, false, false); 61 | CONTIGUOUS_TEST_TT(1, 1, 739, 739, float, false, false); 62 | CONTIGUOUS_TEST_TT(1, 1, 1024, 1024, float, false, false); 63 | 64 | CONTIGUOUS_TEST_TT(1, 5, 8, 2, float, false, false); 65 | CONTIGUOUS_TEST_TT(1, 3, 31, 63, float, false, false); 66 | CONTIGUOUS_TEST_TT(1, 3, 63, 31, float, false, false); 67 | CONTIGUOUS_TEST_TT(1, 2, 297, 127, float, false, false); 68 | CONTIGUOUS_TEST_TT(1, 2, 127, 297, float, false, false); 69 | CONTIGUOUS_TEST_TT(1, 2, 936, 505, float, false, false); 70 | 71 | CONTIGUOUS_TEST_TT(1, 3, 128, 128, float, true, false); 72 | CONTIGUOUS_TEST_TT(3, 4, 32, 32, float, true, false); 73 | CONTIGUOUS_TEST_TT(3, 5, 18, 8, float, true, false); 74 | 75 | CONTIGUOUS_TEST_TT(1, 4, 8, 32, float, true, true); 76 | 77 | STRIDED_BATCHED_TEST_TT(1, 3, 128, 128, float, true, false, 2, 2, 2, 2); 78 | STRIDED_BATCHED_TEST_TT(3, 4, 32, 32, float, true, false, 2, 1, 2, 2); 79 | STRIDED_BATCHED_TEST_TT(1, 3, 64, 64, float, false, false, 2, 2, 2, 2); 80 | 
STRIDED_BATCHED_TEST_TT(1, 5, 12, 12, float, true, false, 2, 2, 1, 1); 81 | STRIDED_BATCHED_TEST_TT(3, 4, 5, 5, float, false, false, 2, 1, 2, 2); 82 | 83 | STRIDED_BATCHED_TEST_TT(1, 3, 128, 128, float, true, true, 2, 2, 2, 2); 84 | -------------------------------------------------------------------------------- /src/kernels/best-kernels/kmm-a100-kernels: -------------------------------------------------------------------------------- 1 | kmm_cuda_ampere_128_f_128x128_32x64_1_32x256_4x2x4_**_*_*_* 2 | kmm_cuda_ampere_128_f_128x128_32x64_1_16x512_4x1x8_**_*_*_* 3 | kmm_cuda_ampere_128_f_128x128_32x64_1_2x8192_2x1x32_**_*_*_* 4 | kmm_cuda_ampere_128_f_128x128_32x64_1_2x256_1x2x1_**_*_*_* 5 | kmm_cuda_ampere_64_f_128x128_32x64_1_8x256_4x2x2_T*_*_*_* 6 | 7 | kmm_cuda_ampere_128_f_64x64_32x64_1_32x128_4x2x4_**_*_*_* 8 | kmm_cuda_ampere_128_f_64x64_32x64_1_16x128_2x2x4_**_*_*_* 9 | kmm_cuda_ampere_64_f_64x64_32x64_1_2x1024_1x4x8_**_*_*_* 10 | kmm_cuda_ampere_64_f_64x64_32x32_1_2x128_1x2x1_**_*_*_* 11 | kmm_cuda_ampere_64_f_64x64_32x64_1_8x128_4x2x2_T*_*_*_* 12 | 13 | kmm_cuda_ampere_256_f_32x32_32x32_1_32x128_4x1x4_N*_strided_*_* 14 | kmm_cuda_ampere_128_f_32x32_32x32_1_32x128_4x2x4_T*_strided_*_* 15 | kmm_cuda_ampere_128_f_32x32_32x32_1_16x128_1x4x4_**_strided_*_* 16 | kmm_cuda_ampere_64_f_32x32_32x32_1_2x2048_2x1x32_**_strided_*_* 17 | kmm_cuda_ampere_128_f_32x32_32x32_1_2x128_2x1x1_**_strided_*_* 18 | kmm_cuda_ampere_256_f_32x32_32x32_1_8x512_1x4x4_**_strided_*_* 19 | 20 | kmm_cuda_ampere_128_f_16x16_16x16_2_16x256_2x1x16_N*_strided_*_2|3 21 | kmm_cuda_ampere_128_f_16x16_16x16_2_2x4096_2x2x16_**_strided_*_2|3 22 | kmm_cuda_ampere_128_f_16x16_16x16_1_16x256_2x1x16_N*_strided_*_2|3 23 | kmm_cuda_ampere_128_f_16x16_16x16_1_2x128_1x1x2_**_strided_*_2|3 24 | kmm_cuda_ampere_128_f_16x16_16x16_2_32x256_2x4x8_T*_strided_*_2|3 25 | kmm_cuda_ampere_64_f_16x16_16x16_1_8x32_1x2x2_T*_strided_*_2|3 26 | 27 | kmm_cuda_ampere_256_f_8x8_8x8_2_64x64_1x2x8_T*_strided_*_2|3 28 | kmm_cuda_ampere_256_f_8x8_8x8_1_64x64_1x2x8_T*_strided_*_2|3 29 | kmm_cuda_ampere_512_f_8x8_8x8_3_16x512_1x2x8_**_strided_*_2|3 30 | kmm_cuda_ampere_512_f_8x8_8x8_2_16x256_1x1x8_**_strided_*_2|3 31 | kmm_cuda_ampere_128_f_8x8_8x8_1_16x64_2x1x4_**_strided_*_2|3 32 | kmm_cuda_ampere_128_f_8x8_8x8_2_8x64_1x2x2_T*_strided_*_2|3 33 | kmm_cuda_ampere_128_f_8x8_8x8_1_8x64_1x2x2_T*_strided_*_2|3 34 | kmm_cuda_ampere_512_f_8x8_8x8_3_2x4096_2x1x8_**_strided_*_2|3 35 | kmm_cuda_ampere_128_f_8x8_8x8_2_2x1024_2x1x8_**_strided_*_2|3 36 | kmm_cuda_ampere_128_f_8x8_8x8_1_2x1024_2x1x8_**_strided_*_2|3 37 | kmm_cuda_ampere_128_f_8x8_8x8_1_2x128_1x2x1_**_strided_*_2|3 38 | 39 | kmm_cuda_ampere_256_f_4x4_4x4_4_16x256_1x4x4_**_strided_*_2|3 40 | kmm_cuda_ampere_256_f_4x4_4x4_1_16x64_1x1x4_**_strided_*_2|3 41 | kmm_cuda_ampere_256_f_4x4_4x4_4_2x256_1x1x2_**_strided_*_2|3 42 | kmm_cuda_ampere_128_f_4x4_4x2_1_2x128_1x1x1_**_strided_*_2|3 43 | 44 | kmm_cuda_ampere_256_f_2x2_2x2_6_16x64_2x1x2_**_strided_*_2|3 45 | kmm_cuda_ampere_256_f_2x2_2x2_1_16x32_1x1x2_**_strided_*_2|3 46 | kmm_cuda_ampere_128_f_2x2_2x2_6_2x64_1x1x1_**_strided_*_2|3 47 | kmm_cuda_ampere_64_f_2x2_2x2_1_2x64_2x1x1_**_strided_*_2|3 48 | 49 | kmm_cuda_ampere_128_d_128x128_32x64_1_16x512_1x2x16_**_*_*_* 50 | kmm_cuda_ampere_128_d_128x128_32x64_1_2x4096_1x2x16_**_*_*_* 51 | kmm_cuda_ampere_128_d_128x128_32x64_1_2x256_1x2x1_**_*_*_* 52 | 53 | kmm_cuda_ampere_256_d_64x64_32x64_1_16x256_1x2x8_**_*_*_* 54 | kmm_cuda_ampere_128_d_64x64_32x64_1_2x2048_2x1x16_**_*_*_* 55 | 
kmm_cuda_ampere_128_d_64x64_32x32_1_2x128_1x1x1_**_*_*_* 56 | 57 | kmm_cuda_ampere_256_d_32x32_32x32_1_16x128_1x2x4_**_strided_*_* 58 | kmm_cuda_ampere_128_d_32x32_32x32_2_2x2048_1x1x32_**_strided_*_* 59 | kmm_cuda_ampere_64_d_32x32_32x32_1_2x64_1x2x1_**_strided_*_* 60 | 61 | kmm_cuda_ampere_256_d_16x16_16x16_2_16x256_1x4x4_**_strided_*_2|3 62 | kmm_cuda_ampere_256_d_16x16_16x16_1_16x256_1x4x4_**_strided_*_2|3 63 | kmm_cuda_ampere_128_d_16x16_16x16_2_2x2048_2x1x16_**_strided_*_2|3 64 | kmm_cuda_ampere_64_d_16x16_16x16_1_2x64_1x2x1_**_strided_*_2|3 65 | 66 | kmm_cuda_ampere_512_d_8x8_8x8_2_16x128_1x1x4_**_strided_*_2|3 67 | kmm_cuda_ampere_512_d_8x8_8x8_1_16x128_1x1x4_**_strided_*_2|3 68 | kmm_cuda_ampere_256_d_8x8_8x8_3_2x2048_2x1x8_**_strided_*_2|3 69 | kmm_cuda_ampere_128_d_8x8_8x8_2_2x1024_2x1x8_**_strided_*_2|3 70 | kmm_cuda_ampere_64_d_8x8_8x8_1_2x256_1x1x8_**_strided_*_2|3 71 | kmm_cuda_ampere_64_d_8x8_8x8_1_2x64_1x2x1_**_strided_*_2|3 72 | kmm_cuda_ampere_64_d_8x8_8x8_2_2x64_1x2x1_**_strided_*_2|3 73 | 74 | kmm_cuda_ampere_512_d_4x4_4x4_4_16x256_1x4x2_**_strided_*_2|3 75 | kmm_cuda_ampere_256_d_4x4_4x4_4_2x512_1x1x4_**_strided_*_2|3 76 | kmm_cuda_ampere_64_d_4x4_4x4_3_2x64_1x1x2_**_strided_*_2|3 77 | kmm_cuda_ampere_64_d_4x4_4x4_1_2x64_1x1x2_**_strided_*_2|3 78 | 79 | kmm_cuda_ampere_128_d_2x2_2x2_6_16x64_1x4x2_**_strided_*_2|3 80 | kmm_cuda_ampere_128_d_2x2_2x2_1_16x64_1x4x2_**_strided_*_2|3 81 | kmm_cuda_ampere_64_d_2x2_2x2_5_2x64_1x1x2_**_strided_*_2|3 82 | kmm_cuda_ampere_64_d_2x2_2x2_1_2x64_1x1x2_**_strided_*_2|3 -------------------------------------------------------------------------------- /tests/cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CUDA tests 2 | set(CUDA_TEST_DEFINE -D TEST_BACKEND_CUDA) 3 | set(CUDA_TESTS_INCLUDES ${TESTS_INCLUDES}) 4 | 5 | if (ENABLE_CUDA) 6 | set_source_files_properties(${TEST_SRC}/general-tests-NN.cpp PROPERTIES LANGUAGE CUDA) 7 | set_source_files_properties(${TEST_SRC}/general-tests-TT.cpp PROPERTIES LANGUAGE CUDA) 8 | set_source_files_properties(${TEST_SRC}/general-tests-all.cpp PROPERTIES LANGUAGE CUDA) 9 | set_source_files_properties(${TEST_SRC}/multi-cuda-no-fusion-tests.cpp PROPERTIES LANGUAGE CUDA) 10 | set_source_files_properties(${TEST_SRC}/multi-cuda-tuner-tests.cpp PROPERTIES LANGUAGE CUDA) 11 | set_source_files_properties(${TEST_SRC}/multi-cuda-no-fusion-non-square-tests.cpp PROPERTIES LANGUAGE CUDA) 12 | set_source_files_properties(${TEST_SRC}/multi-cuda-distinct-shapes.cpp PROPERTIES LANGUAGE CUDA) 13 | endif() 14 | 15 | add_custom_target(gen-single-gpu-kernels 16 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -mm-type mkm kmm -backend cuda -archs ampere volta -same-factors 2 128,128 -same-factors 2 64,64 -same-factors 3 32,32 -same-factors 5 16,16 -same-factors 7 8,8 -same-factors 10 4,4 -same-factors 20 2,2 -opX N T -opF N T -types float double -match-configs-file ${SRC}/kernels/best-kernels/a100-kernels ${SRC}/kernels/best-kernels/kmm-a100-kernels) 17 | 18 | add_executable(single-gpu-cuda-NN ${TEST_SRC}/general-tests-NN.cpp) 19 | target_include_directories(single-gpu-cuda-NN PRIVATE ${TESTS_INCLUDES}) 20 | target_link_libraries(single-gpu-cuda-NN ${TESTS_LIBS}) 21 | target_compile_definitions(single-gpu-cuda-NN PRIVATE ${CUDA_TEST_DEFINE}) 22 | 23 | add_executable(single-gpu-cuda-TT ${TEST_SRC}/general-tests-TT.cpp) 24 | target_include_directories(single-gpu-cuda-TT PRIVATE ${TESTS_INCLUDES}) 25 | target_link_libraries(single-gpu-cuda-TT ${TESTS_LIBS}) 26 | 
target_compile_definitions(single-gpu-cuda-TT PRIVATE ${CUDA_TEST_DEFINE}) 27 | 28 | add_executable(single-gpu-cuda-all ${TEST_SRC}/general-tests-all.cpp) 29 | target_include_directories(single-gpu-cuda-all PRIVATE ${TESTS_INCLUDES}) 30 | target_link_libraries(single-gpu-cuda-all ${TESTS_LIBS}) 31 | target_compile_definitions(single-gpu-cuda-all PRIVATE ${CUDA_TEST_DEFINE}) 32 | 33 | if (ENABLE_MULTI_GPU AND ENABLE_CUDA) 34 | add_custom_target(gen-multi-cuda-tests-kernel 35 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -mm-type mkm -backend cuda -archs ampere volta -same-factors 4 64,64 -same-factors 4 128,128 -types float -dist-kernels -opX N -opF N -match-configs mkm_cuda_ampere_128_f_64x64_32x64_1_1x4096_1x4x8 cuda_ampere_128_f_128x128_32x128_1_1x8192_1x4x16 mkm_cuda_volta_256_f_128x128_32x128_1_2x4096_1x2x16 cuda_volta_256_f_64x64_32x64_1_2x2048_2x1x8 36 | ) 37 | 38 | add_executable(multi-cuda-no-fusion-tests ${TEST_SRC}/multi-cuda-no-fusion-tests.cpp) 39 | 40 | target_include_directories(multi-cuda-no-fusion-tests PRIVATE ${TESTS_INCLUDES}) 41 | target_link_libraries(multi-cuda-no-fusion-tests ${TESTS_LIBS}) 42 | target_compile_definitions(multi-cuda-no-fusion-tests PRIVATE ${CUDA_TEST_DEFINE}) 43 | 44 | add_executable(multi-cuda-tuner-tests ${TEST_SRC}/multi-cuda-tuner-tests.cpp) 45 | add_custom_target(gen-multi-cuda-tuner-kernels 46 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -mm-type mkm -backend cuda -archs ampere -types float -same-factors 5 16,16 -dist-kernels -opX N -opF N -num-kernels 50 47 | ) 48 | target_include_directories(multi-cuda-tuner-tests PRIVATE ${TESTS_INCLUDES}) 49 | target_link_libraries(multi-cuda-tuner-tests ${TESTS_LIBS}) 50 | target_compile_definitions(multi-cuda-tuner-tests PRIVATE ${CUDA_TEST_DEFINE}) 51 | 52 | add_executable(multi-cuda-no-fusion-non-square-tests ${TEST_SRC}/multi-cuda-no-fusion-non-square-tests.cpp) 53 | add_custom_target(gen-multi-cuda-no-fusion-non-square-tests-kernel 54 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -mm-type mkm -backend cuda -archs ampere -types float -same-factors 5 8,32 -same-factors 4 64,16 -dist-kernels -opX N -opF N -num-kernels 50 55 | ) 56 | target_include_directories(multi-cuda-no-fusion-non-square-tests PRIVATE ${TESTS_INCLUDES}) 57 | target_link_libraries(multi-cuda-no-fusion-non-square-tests ${TESTS_LIBS}) 58 | target_compile_definitions(multi-cuda-no-fusion-non-square-tests PRIVATE ${CUDA_TEST_DEFINE}) 59 | 60 | add_executable(multi-cuda-distinct-shapes ${TEST_SRC}/multi-cuda-distinct-shapes.cpp) 61 | add_custom_target(gen-multi-cuda-distinct-shapes 62 | COMMAND python3 ${GEN_TUNER_KERNELS_PY} -mm-type mkm -backend cuda -archs ampere -types float -distinct-factors 3 8,16 32,8 16,8 -dist-kernels -opX N -opF N -num-kernels 50 63 | ) 64 | target_include_directories(multi-cuda-distinct-shapes PRIVATE ${TESTS_INCLUDES}) 65 | target_link_libraries(multi-cuda-distinct-shapes ${TESTS_LIBS}) 66 | target_compile_definitions(multi-cuda-distinct-shapes PRIVATE ${CUDA_TEST_DEFINE}) 67 | endif() -------------------------------------------------------------------------------- /src/kernels/kmmkernel.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/kmmkernel.h" 2 | 3 | size_t KMMKernel::getMaxTotalTileSize() const { 4 | Matrix Xsh = Matrix(tileX.m(), (tileX.n()/f.p())*tileF.p()); 5 | //TODO: make this tileF.size() + Xsh.size() 6 | return (tileF.numel() + Xsh.numel())*sizeOfFastKronType(elemType); 7 | } 8 | 9 | Matrix KMMKernel::getMaxTileY() const { 10 | return 
Matrix(tileX.m(), (tileX.n()/f.p()) * tileF.q()); 11 | } 12 | 13 | Factor KMMKernel::getTileF(KMMProblem problem) const { 14 | Factor f_ = problem.f(0); 15 | return Factor(MIN(tileF.p(), f_.p()), MIN(tileF.q(), f_.q())); 16 | } 17 | 18 | Matrix KMMKernel::getTileX(KMMProblem problem) const { 19 | Factor f_ = problem.f(0); 20 | 21 | uint32_t kernelTileSlices = tileX.n()/f.p(); 22 | uint32_t problemTileSlices = problem.x().n()/f_.p(); 23 | 24 | uint32_t slices = 0; 25 | 26 | if (problemTileSlices >= kernelTileSlices) { 27 | slices = kernelTileSlices; 28 | } else { 29 | slices = MAX(1, MIN(tileX.n()/f_.p(), kernelTileSlices)); 30 | slices = MIN(problemTileSlices, slices); 31 | } 32 | 33 | return Matrix(tileX.m(), slices * f_.p()); 34 | } 35 | 36 | size_t KMMKernel::getTotalTileSize(KMMProblem problem) const { 37 | Matrix tileX_ = getTileX(problem); 38 | Factor f_ = problem.f(0); 39 | 40 | //Pad Xsh to TileP 41 | //Pad Fsh to TileP x TileQ 42 | Matrix Xsh = Matrix(tileX_.m(), 43 | (tileX_.n()/f_.p()) * tileF.p()); 44 | return (tileF.numel() + Xsh.numel())*sizeOfFastKronType(elemType); 45 | } 46 | 47 | size_t KMMKernel::getNumThreads(KMMProblem problem) const { 48 | Matrix tileX_ = getTileX(problem); 49 | Factor tileF_ = getTileF(problem); 50 | 51 | return DIVUP(problem.k(), tileX_.n()) * 52 | DIVUP(problem.f(0).q(), tileF_.q()) * 53 | DIVUP(problem.m(), tileX_.m()); 54 | } 55 | 56 | bool KMMKernel::isOptValid(KMMProblem problem, KernelOptimizations::Optimization opt) const { 57 | using Opts = KernelOptimizations::Optimization; 58 | switch (opt) { 59 | case Opts::None: 60 | return true; 61 | case Opts::XshSlicesSame: 62 | return getTileX(problem).n()/problem.f(0).p() == tileX.n()/f.p(); 63 | case Opts::QMultipleOfTileQ: 64 | return problem.f(0).q() % tileF.q() == 0; 65 | case Opts::PMultipleOfTileP: 66 | return problem.f(0).p() % tileF.p() == 0; 67 | case Opts::KMultipleOfTileK: 68 | return problem.k() % getTileX(problem).n() == 0; 69 | case Opts::MMultipleOfTileM: 70 | return problem.m() % getTileX(problem).m() == 0; 71 | case Opts::QLeTileQ: 72 | return problem.f(0).q() <= f.q(); 73 | case Opts::TileKSame: 74 | return getTileX(problem).n() == tileX.n(); 75 | case Opts::FactorShapeSame: 76 | return f.p() == problem.f(0).p() && f.q() == problem.f(0).q(); 77 | 78 | default: 79 | return false; 80 | } 81 | 82 | return false; 83 | } 84 | 85 | bool KMMKernel::canCompute(KMMProblem problem, const HardwareDetails*, 86 | bool p2p, KernelBatchType::Ty probBatchType, 87 | bool exactFuse) { 88 | using Opts = KernelOptimizations::Optimization; 89 | 90 | bool ret = problem.mmtype() == mmType && problem.type() == elemType && 91 | problem.opFs() == opF && problem.opX() == opX && 92 | P2PStore == p2p && ((exactFuse && problem.n() <= fusedFacs) || !exactFuse); 93 | 94 | if (!ret) return false; 95 | 96 | ret = kernelBatchType == probBatchType || 97 | //A strided batched kernel can compute single batch problem 98 | (kernelBatchType == KernelBatchType::StridedBatched && 99 | probBatchType == KernelBatchType::Normal); 100 | 101 | if (!ret) return false; 102 | 103 | bool followsAllOpts = true; 104 | uint lg = 0; 105 | for (Opts opt = Opts(lg); opt < Opts::NumOptimizations; opt = Opts(1 << lg), ++lg) { 106 | if ((KernelOptimizations::getOptimizations(optLevel) & opt) == opt) { 107 | followsAllOpts = followsAllOpts && isOptValid(problem, opt); 108 | }} 109 | 110 | return followsAllOpts; 111 | } 112 | 113 | static std::ostream& operator<<(std::ostream& os, const KernelBatchType::Ty& b) { 114 | switch (b) { 115 | case 
KernelBatchType::Normal: 116 | os << "cont"; 117 | break; 118 | case KernelBatchType::StridedBatched: 119 | os << "strided"; 120 | break; 121 | case KernelBatchType::Batch: 122 | os << "batched"; 123 | break; 124 | } 125 | return os; 126 | } 127 | 128 | std::string KMMKernel::str() const { 129 | std::stringstream info; 130 | info << strOfFastKronMMType(mmType) 131 | << "_" << strOfFastKronType(elemType) 132 | << "_" << f << "_" << tileF <<"_" << fusedFacs 133 | << "_" << tileX << "_" << regM << "x" << regK << "x" << regQ 134 | << "_" << opX << opF << "_" << kernelBatchType << "_" << P2PStore << "_" << optLevel; 135 | return info.str(); 136 | } -------------------------------------------------------------------------------- /tests/cuda/old/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CUDA tests 2 | set(CUDA_TEST_DEFINE -D TEST_BACKEND_CUDA) 3 | set(CUDA_TESTS_INCLUDES ${TESTS_INCLUDES}) 4 | 5 | set_source_files_properties(${TEST_SRC}/no-fusion-tests.cpp PROPERTIES LANGUAGE CUDA) 6 | set_source_files_properties(${TEST_SRC}/fusion-tests.cpp PROPERTIES LANGUAGE CUDA) 7 | set_source_files_properties(${TEST_SRC}/TT-tests.cpp PROPERTIES LANGUAGE CUDA) 8 | set_source_files_properties(${TEST_SRC}/tuner-tests.cpp PROPERTIES LANGUAGE CUDA) 9 | set_source_files_properties(${TEST_SRC}/non-square-tuner-tests.cpp PROPERTIES LANGUAGE CUDA) 10 | set_source_files_properties(${TEST_SRC}/non-square-TT-tests.cpp PROPERTIES LANGUAGE CUDA) 11 | set_source_files_properties(${TEST_SRC}/distinct-shapes.cpp PROPERTIES LANGUAGE CUDA) 12 | set_source_files_properties(${TEST_SRC}/odd-shapes.cpp PROPERTIES LANGUAGE CUDA) 13 | set_source_files_properties(${TEST_SRC}/no-fusion-tests.cpp PROPERTIES LANGUAGE CUDA) 14 | 15 | add_executable(single-cuda-no-fusion-tests ${TEST_SRC}/no-fusion-tests.cpp) 16 | add_executable(single-cuda-fusion-tests ${TEST_SRC}/fusion-tests.cpp) 17 | add_custom_target(gen-single-cuda-kernels 18 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -same-factors 20 2,2 -same-factors 10 4,4 -same-factors 8 8,8 -same-factors 6 16,16 -same-factors 5 32,32 -same-factors 4 64,64 -same-factors 3 128,128 -opX N -opF N -match-configs-file ${CMAKE_CURRENT_SOURCE_DIR}/single-cuda-kernel-decls.in 19 | ) 20 | target_include_directories(single-cuda-no-fusion-tests PRIVATE ${TESTS_INCLUDES}) 21 | target_link_libraries(single-cuda-no-fusion-tests ${TESTS_LIBS}) 22 | target_compile_definitions(single-cuda-no-fusion-tests PRIVATE ${CUDA_TEST_DEFINE}) 23 | 24 | target_include_directories(single-cuda-fusion-tests PRIVATE ${TESTS_INCLUDES}) 25 | target_link_libraries(single-cuda-fusion-tests ${TESTS_LIBS}) 26 | target_compile_definitions(single-cuda-fusion-tests PRIVATE ${CUDA_TEST_DEFINE}) 27 | 28 | add_executable(single-cuda-TT-tests ${TEST_SRC}/TT-tests.cpp) 29 | add_custom_target(gen-single-cuda-TT-kernels 30 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -same-factors 8 8,8 -same-factors 6 16,16 -same-factors 5 32,32 -same-factors 4 64,64 -same-factors 3 128,128 -opX T -opF T -match-configs-file ${CMAKE_CURRENT_SOURCE_DIR}/single-cuda-kernel-decls.in 31 | ) 32 | target_include_directories(single-cuda-TT-tests PRIVATE ${TESTS_INCLUDES}) 33 | target_link_libraries(single-cuda-TT-tests ${TESTS_LIBS}) 34 | target_compile_definitions(single-cuda-TT-tests PRIVATE ${CUDA_TEST_DEFINE}) 35 | 36 | add_executable(single-cuda-tuner-tests ${TEST_SRC}/tuner-tests.cpp) 37 | add_custom_target(gen-tuner-kernels 38 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend 
cuda -same-factors 4 16,16 -same-factors 3 64,64 -opX N -opF N 39 | ) 40 | target_include_directories(single-cuda-tuner-tests PRIVATE ${TESTS_INCLUDES}) 41 | target_link_libraries(single-cuda-tuner-tests ${TESTS_LIBS}) 42 | target_compile_definitions(single-cuda-tuner-tests PRIVATE ${CUDA_TEST_DEFINE}) 43 | 44 | add_executable(single-cuda-non-square-tests ${TEST_SRC}/non-square-tuner-tests.cpp) 45 | add_custom_target(gen-non-square-kernels 46 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -same-factors 4 8,16 -same-factors 3 32,16 -same-factors 3 128,32 -opX N -opF N 47 | ) 48 | target_include_directories(single-cuda-non-square-tests PRIVATE ${TESTS_INCLUDES}) 49 | target_link_libraries(single-cuda-non-square-tests ${TESTS_LIBS}) 50 | target_compile_definitions(single-cuda-non-square-tests PRIVATE ${CUDA_TEST_DEFINE}) 51 | 52 | add_executable(single-cuda-non-square-TT-tests ${TEST_SRC}/non-square-TT-tests.cpp) 53 | add_custom_target(gen-single-cuda-non-square-TT-kernels 54 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -same-factors 4 8,16 -same-factors 3 32,16 -same-factors 3 128,32 -opX T -opF T 55 | ) 56 | target_include_directories(single-cuda-non-square-TT-tests PRIVATE ${TESTS_INCLUDES}) 57 | target_link_libraries(single-cuda-non-square-TT-tests ${TESTS_LIBS}) 58 | target_compile_definitions(single-cuda-non-square-TT-tests PRIVATE ${CUDA_TEST_DEFINE}) 59 | 60 | add_executable(single-cuda-distinct-shapes ${TEST_SRC}/distinct-shapes.cpp) 61 | add_custom_target(gen-single-cuda-distinct-shapes 62 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -distinct-factors 3 8,16 16,8 8,32 -opX N -opF N 63 | ) 64 | target_include_directories(single-cuda-distinct-shapes PRIVATE ${TESTS_INCLUDES}) 65 | target_link_libraries(single-cuda-distinct-shapes ${TESTS_LIBS}) 66 | target_compile_definitions(single-cuda-distinct-shapes PRIVATE ${CUDA_TEST_DEFINE}) 67 | 68 | add_executable(single-cuda-odd-shapes ${TEST_SRC}/odd-shapes.cpp) 69 | add_custom_target(gen-single-cuda-odd-shapes 70 | COMMAND python ${GEN_TUNER_KERNELS_PY} -backend cuda -same-factors 2 31,16 -same-factors 2 16,31 -same-factors 4 31,31 -opX N -opF N 71 | ) 72 | target_include_directories(single-cuda-odd-shapes PRIVATE ${TESTS_INCLUDES}) 73 | target_link_libraries(single-cuda-odd-shapes ${TESTS_LIBS}) 74 | target_compile_definitions(single-cuda-odd-shapes PRIVATE ${CUDA_TEST_DEFINE}) -------------------------------------------------------------------------------- /src/kernels/hw_details.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #pragma once 5 | 6 | class HardwareDetails { 7 | public: 8 | virtual ~HardwareDetails() {} 9 | }; 10 | 11 | enum SMArch { 12 | SMArchNone, 13 | ampere, 14 | volta, 15 | maxwell, 16 | }; 17 | 18 | static inline std::string smArchToStr(SMArch arch) { 19 | switch (arch) { 20 | case SMArch::volta: 21 | return "volta"; 22 | case SMArch::ampere: 23 | return "ampere"; 24 | case SMArch::maxwell: 25 | return "maxwell"; 26 | default: 27 | return ""; 28 | } 29 | 30 | return ""; 31 | } 32 | 33 | static inline SMArch computeCapabilityToSMArch(uint major, uint minor) { 34 | uint32_t c = major * 10 + minor; 35 | if (c >= 80 && c < 90) { 36 | return SMArch::ampere; 37 | } else if (c >= 70 && c < 80) { 38 | return SMArch::volta; 39 | } else if (c >= 60 && c < 70) { 40 | return SMArch::maxwell; 41 | } 42 | return SMArch::SMArchNone; 43 | } 44 | 45 | class CUDAArchDetails : public HardwareDetails { 46 | public: 47 | uint32_t numSMs; 48 | 
uint32_t maxBlocksPerSM; 49 | uint32_t maxThreadsPerBlock; 50 | uint32_t maxThreadsPerSM; 51 | uint32_t regsPerSM; 52 | uint32_t maxRegsPerThread; 53 | uint32_t sharedMemPerSM; 54 | uint32_t sharedMemPerBlock; 55 | std::string name; 56 | uint32_t computeMajor; 57 | uint32_t computeMinor; 58 | uint32_t warpSize; 59 | SMArch smArch; 60 | 61 | 62 | // CUDAArchDetail(uint32_t numSMs, uint32_t maxBlocksPerSM, uint32_t maxThreadsPerBlock, 63 | // uint32_t maxThreadsPerSM, uint32_t regsPerSM, uint32_t sharedMemPerSM) : 64 | // numSMs(numSMs), maxBlocksPerSM(maxBlocksPerSM), 65 | // maxThreadsPerBlock(maxThreadsPerBlock), 66 | // maxThreadsPerSM(maxThreadsPerSM), 67 | // regsPerSM(regsPerSM), sharedMemPerSM(sharedMemPerSM) {} 68 | CUDAArchDetails(int dev); 69 | 70 | friend std::ostream& operator<<(std::ostream &out, const CUDAArchDetails& detail) { 71 | std::string indent = " "; 72 | out << detail.name << std::endl << 73 | indent << "Compute Capability : " << (detail.computeMajor*10 + detail.computeMinor) << std::endl << 74 | indent << "SMs : " << detail.numSMs << std::endl << 75 | indent << "Max Blocks per SM : " << detail.maxBlocksPerSM << std::endl << 76 | indent << "Max Threads per SM : " << detail.maxThreadsPerSM << std::endl << 77 | indent << "Registers Per SM : " << detail.regsPerSM << std::endl << 78 | indent << "Shared Memory per SM : " << detail.sharedMemPerSM << std::endl<< 79 | indent << "Shared Memory Per Block : " << detail.sharedMemPerBlock << std::endl << 80 | indent << "Warp Size : " << detail.warpSize << std::endl 81 | ; 82 | return out; 83 | } 84 | 85 | virtual ~CUDAArchDetails() {} 86 | }; 87 | 88 | enum X86SIMD { 89 | SISD, 90 | AVX, 91 | AVX512 92 | }; 93 | 94 | static std::string x86simdToStr(X86SIMD simd) { 95 | switch(simd) { 96 | case SISD: 97 | return "SISD"; 98 | case AVX: 99 | return "AVX"; 100 | case AVX512: 101 | return "AVX512"; 102 | } 103 | return ""; 104 | } 105 | 106 | class CPUArchDetails : public HardwareDetails { 107 | public: 108 | std::string vendor; 109 | std::string model; 110 | //Cache sizes in KB 111 | uint32_t l1Size; 112 | uint32_t l2Size; 113 | uint32_t l3Size; 114 | uint32_t sockets; 115 | uint32_t cores; 116 | 117 | CPUArchDetails(std::string vendor, std::string model, uint32_t l1Size, uint32_t l2Size, uint32_t l3Size, uint32_t sockets, uint32_t cores) : 118 | vendor(vendor), model(model), l1Size(l1Size), l2Size(l2Size), l3Size(l3Size), sockets(sockets), cores(cores) 119 | {} 120 | uint32_t totalL3Size() {return l3Size * sockets;} 121 | }; 122 | 123 | class X86ArchDetails : public CPUArchDetails { 124 | public: 125 | X86SIMD simd; 126 | 127 | X86ArchDetails(std::string vendor, std::string model, uint32_t l1Size, uint32_t l2Size, uint32_t l3Size, uint32_t sockets, uint32_t cores, X86SIMD simd) : 128 | CPUArchDetails(vendor, model, l1Size, l2Size, l3Size, sockets, cores), simd(simd) 129 | {} 130 | 131 | friend std::ostream& operator<<(std::ostream& out, const X86ArchDetails& detail) { 132 | std::string indent = " "; 133 | out << indent << "Vendor : " << detail.vendor << std::endl 134 | << indent << "Model : " << detail.model << std::endl 135 | << indent << "L1 Cache Size : " << detail.l1Size << " KB" << std::endl 136 | << indent << "L2 Cache Size : " << detail.l2Size << " KB" << std::endl 137 | << indent << "L3 Cache Size : " << detail.l3Size << " KB" << std::endl 138 | << indent << "Cores : " << detail.cores << std::endl 139 | << indent << "Sockets : " << detail.sockets << std::endl 140 | << indent << "SIMD Type : " << x86simdToStr(detail.simd) << 
std::endl; 141 | return out; 142 | } 143 | }; -------------------------------------------------------------------------------- /tests/python/test_numpy.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import numpy as np 3 | 4 | import pyfastkron.fastkronnumpy as fk 5 | 6 | def product(values): 7 | return reduce((lambda a, b: a * b), values) 8 | 9 | def transpose(m): 10 | axis = tuple(range(len(m.shape[:-2]))) + \ 11 | (len(m.shape) - 1, len(m.shape) - 2) 12 | return m.transpose(axis) 13 | 14 | def reference(mmtype, x, fs): 15 | batchKron = fs[0].shape[:-2] 16 | if len(batchKron) == 0: 17 | outputKron = fs[0] 18 | for m in fs[1:]: 19 | outputKron = np.kron(outputKron, m) 20 | else: 21 | batchDims = product(batchKron) 22 | for i,f in enumerate(fs): 23 | fs[i] = fs[i].reshape((batchDims,) + f.shape[-2:]) 24 | 25 | output = fs[0] 26 | for f in fs[1:]: 27 | prev = output 28 | output = np.ndarray(shape=(batchDims, prev.shape[-2] * f.shape[-2], prev.shape[-1] * f.shape[-1]), 29 | dtype=f.dtype) 30 | for b in range(batchDims): 31 | output[b:] = np.kron(prev[b,:], f[b,:]) 32 | outputKron = output.reshape(batchKron + (output.shape[-2], output.shape[-1])) 33 | 34 | if mmtype == "mkm": 35 | return np.matmul(x, outputKron) 36 | elif mmtype == "kmm": 37 | return np.matmul(outputKron, x) 38 | 39 | def run(mmtype, m, n, p, q, dtype, device, trX, trF, 40 | high=5, batchDimX=[], batchDimFPre=[], batchDimZ=[]): 41 | #Using integer values instead of real numbers because 42 | #floating point is not associative 43 | if mmtype == "mkm": 44 | xshape = [m, p**n] if not trX else [p**n, m] 45 | elif mmtype == "kmm": 46 | xshape = [p**n, m] if not trX else [m, p**n] 47 | 48 | if m == 1: 49 | if trX: 50 | xshape = [xshape[0],] 51 | else: 52 | xshape = [xshape[1],] 53 | 54 | xshape = list(batchDimX) + xshape 55 | 56 | if mmtype == "mkm": 57 | fshape = [p, q] if not trF else [q, p] 58 | elif mmtype == "kmm": 59 | fshape = [q, p] if not trF else [p, q] 60 | 61 | if q == 1: 62 | if trF: 63 | fshape = [fshape[1],] 64 | else: 65 | fshape = [fshape[0],] 66 | 67 | fshape = list(batchDimFPre) + fshape 68 | 69 | zshape = list(batchDimZ) 70 | if mmtype == "mkm": 71 | zshape += [m,q**n] 72 | elif mmtype == "kmm": 73 | zshape += [q**n,m] 74 | 75 | x = np.random.randint(0, high=high,size=xshape).astype(dtype) 76 | fs = [np.random.randint(0, high=high,size=fshape).astype(dtype)\ 77 | for i in range(n)] 78 | z = np.random.randint(0, high=high,size=zshape).astype(dtype) 79 | if trX: 80 | x = transpose(x) 81 | if trF: 82 | fs = [transpose(f) for f in fs] 83 | 84 | alpha = 3.0 85 | beta = 0.0 86 | 87 | if mmtype == "mkm": 88 | y = fk.gemkm(x, fs, alpha, beta, z) 89 | elif mmtype == "kmm": 90 | y = fk.gekmm(fs, x, alpha, beta, z) 91 | 92 | ref = alpha * reference(mmtype, x, fs) + beta * z 93 | val = np.isclose(y, ref, rtol=1e-04).all().item() 94 | print(52) 95 | assert val 96 | 97 | def device_tests(device): 98 | for mmtype in ["mkm", "kmm"]: 99 | run(mmtype, 8, 3, 8, 8, np.float32, device, False, False) 100 | run(mmtype, 16, 2, 128, 128, np.float32, device, False, False) 101 | 102 | run(mmtype, 10, 5, 6, 6, np.float32, device, True, False) 103 | 104 | run(mmtype, 32, 5, 8, 8, np.float32, device, False, False, batchDimX=[2,], batchDimFPre=[], batchDimZ=[2,]) 105 | run(mmtype, 16, 5, 8, 8, np.float32, device, False, False, batchDimX=[2,3], batchDimFPre=[2,3]) 106 | run(mmtype, 8, 5, 8, 8, np.float32, device, False, False, batchDimX=[2,1,], batchDimFPre=[3,]) 107 | 
run(mmtype, 2, 5, 8, 8, np.float32, device, False, False, batchDimX=[2,1,], batchDimFPre=[2,4,]) 108 | run(mmtype, 32, 4, 8, 8, np.float32, device, False, False, batchDimX=[3,3,1,], batchDimFPre=[3,1,4,]) 109 | run(mmtype, 24, 4, 8, 8, np.float32, device, False, False, batchDimX=[2,], batchDimFPre=[3,2,]) 110 | 111 | run(mmtype, 128, 5, 8, 8, np.float32, device, False, False, batchDimX=[2,], batchDimFPre=[3,2,], batchDimZ=[3,1]) 112 | 113 | run(mmtype, 16, 5, 8, 8, np.float32, device, True, True, batchDimX=[2,], batchDimFPre=[]) 114 | run(mmtype, 32, 5, 8, 8, np.float32, device, True, True, batchDimX=[2,1,], batchDimFPre=[3,]) 115 | run(mmtype, 13, 5, 8, 8, np.float32, device, True, True, batchDimX=[2,1,], batchDimFPre=[2,4,]) 116 | run(mmtype, 29, 5, 8, 8, np.float32, device, True, True, batchDimX=[2,], batchDimFPre=[3,2,]) 117 | 118 | # #double 119 | run(mmtype, 11, 10, 3, 3, np.double, device, False, True) 120 | run(mmtype, 200, 2, 32, 32, np.double, device, True, True) 121 | 122 | run(mmtype, 128, 5, 8, 8, np.double, device, True, True, batchDimX=[2,1,], batchDimFPre=[2,4,]) 123 | 124 | #float16 125 | run(mmtype, 102, 4, 8, 8, np.float16, device, False, False, high=2) 126 | run(mmtype, 102, 4, 8, 8, np.float16, device, False, False, high=2, batchDimX=[2,], batchDimFPre=[]) 127 | run(mmtype, 102, 4, 8, 8, np.float16, device, False, False, high=2, batchDimX=[2,1,], batchDimFPre=[3,]) 128 | run(mmtype, 10, 3, 16, 8, np.float16, device, True, False, high=2) 129 | 130 | def test_cpu(): 131 | device_tests("cpu") 132 | 133 | if __name__ == "__main__": 134 | test_cpu() -------------------------------------------------------------------------------- /src/kernels/cuda/register-loads.cuh: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | 3 | #pragma once 4 | 5 | /***float loads***/ 6 | CUDA_DEVICE 7 | void ldGlobalVec(const float4* ptr, float regs[4]) { 8 | #if defined(__NVCC__) || defined(__CUDACC__) 9 | asm volatile ("ld.ca.global.v4.f32 {%0, %1, %2, %3}, [%4];" : 10 | "=f"(regs[0]), "=f"(regs[1]), "=f"(regs[2]), "=f"(regs[3]) : "l"(ptr)); 11 | #elif defined(__HIPCC__) 12 | float4 f4 = *(float4*)ptr; 13 | regs[0] = f4.x; regs[1] = f4.y; regs[2] = f4.z; regs[3] = f4.w; 14 | #endif 15 | } 16 | 17 | CUDA_DEVICE 18 | void ldGlobalVec(const float2* ptr, float regs[2]) { 19 | #if defined(__NVCC__) || defined(__CUDACC__) 20 | asm volatile ("ld.ca.global.v2.f32 {%0, %1}, [%2];" : 21 | "=f"(regs[0]), "=f"(regs[1]) : "l"(ptr)); 22 | #elif defined(__HIPCC__) 23 | float2 f2 = *(float2*)ptr; 24 | regs[0] = f2.x; regs[1] = f2.y; 25 | #endif 26 | } 27 | 28 | CUDA_DEVICE 29 | void ldGlobalVec(const float* ptr, float regs[1]) { 30 | #if defined(__NVCC__) || defined(__CUDACC__) 31 | asm volatile ("ld.ca.global.f32 {%0}, [%1];" : 32 | "=f"(regs[0]) : "l"(ptr)); 33 | #elif defined(__HIPCC__) 34 | regs[0] = *ptr; 35 | #endif 36 | } 37 | 38 | CUDA_DEVICE 39 | void ldGlobalVec(const float* ptr, float* regs, uint len) { 40 | switch(len) { 41 | case 1: 42 | #if defined(__NVCC__) || defined(__CUDACC__) 43 | asm volatile ("ld.ca.global.f32 {%0}, [%1];" : 44 | "=f"(regs[0]) : "l"(ptr)); 45 | #elif defined(__HIPCC__) 46 | regs[0] = *ptr; 47 | #endif 48 | break; 49 | case 2: 50 | #if defined(__NVCC__) || defined(__CUDACC__) 51 | asm volatile ("ld.ca.global.v2.f32 {%0, %1}, [%2];" : 52 | "=f"(regs[0]), "=f"(regs[1]) : "l"(ptr)); 53 | #elif defined(__HIPCC__) 54 | { 55 | float2 f2 = *(float2*)ptr; 56 | regs[0] = f2.x; regs[1] = f2.y; 57 | } 58 | #endif 59 | break; 60 
| case 4: 61 | #if defined(__NVCC__) || defined(__CUDACC__) 62 | asm volatile ("ld.ca.global.v4.f32 {%0, %1, %2, %3}, [%4];" : 63 | "=f"(regs[0]), "=f"(regs[1]), "=f"(regs[2]), "=f"(regs[3]) : "l"(ptr)); 64 | #elif defined(__HIPCC__) 65 | { 66 | float4 f4 = *(float4*)ptr; 67 | regs[0] = f4.x; regs[1] = f4.y; regs[2] = f4.z; regs[3] = f4.w; 68 | } 69 | #endif 70 | break; 71 | } 72 | } 73 | 74 | CUDA_DEVICE 75 | void ldGlobalVec(const int* ptr, int* regs, uint len) { 76 | switch(len) { 77 | case 1: 78 | #if defined(__NVCC__) || defined(__CUDACC__) 79 | asm volatile ("ld.ca.global.s32 {%0}, [%1];" : 80 | "=r"(regs[0]) : "l"(ptr)); 81 | #elif defined(__HIPCC__) 82 | regs[0] = *ptr; 83 | #endif 84 | break; 85 | case 2: 86 | #if defined(__NVCC__) || defined(__CUDACC__) 87 | asm volatile ("ld.ca.global.v2.s32 {%0, %1}, [%2];" : 88 | "=r"(regs[0]), "=r"(regs[1]) : "l"(ptr)); 89 | #elif defined(__HIPCC__) 90 | { 91 | int2 f2 = *(int2*)ptr; 92 | regs[0] = f2.x; regs[1] = f2.y; 93 | } 94 | #endif 95 | break; 96 | case 4: 97 | #if defined(__NVCC__) || defined(__CUDACC__) 98 | asm volatile ("ld.ca.global.v4.s32 {%0, %1, %2, %3}, [%4];" : 99 | "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3]) : "l"(ptr)); 100 | #elif defined(__HIPCC__) 101 | { 102 | int4 f4 = *(int4*)ptr; 103 | regs[0] = f4.x; regs[1] = f4.y; regs[2] = f4.z; regs[3] = f4.w; 104 | } 105 | #endif 106 | break; 107 | } 108 | } 109 | 110 | CUDA_DEVICE 111 | void ldGlobalVec(const double* ptr, double* regs, uint len) { 112 | switch(len) { 113 | case 1: 114 | #if defined(__NVCC__) || defined(__CUDACC__) 115 | asm volatile ("ld.ca.global.f64 {%0}, [%1];" : 116 | "=d"(regs[0]) : "l"(ptr)); 117 | #elif defined(__HIPCC__) 118 | regs[0] = *ptr; 119 | #endif 120 | break; 121 | case 2: 122 | #if defined(__NVCC__) || defined(__CUDACC__) 123 | asm volatile ("ld.ca.global.v2.f64 {%0, %1}, [%2];" : 124 | "=d"(regs[0]), "=d"(regs[1]) : "l"(ptr)); 125 | #elif defined(__HIPCC__) 126 | { 127 | double2 f2 = *(double2*)ptr; 128 | regs[0] = f2.x; regs[1] = f2.y; 129 | } 130 | #endif 131 | break; 132 | case 4: 133 | // #if defined(__NVCC__) || defined(__CUDACC__) 134 | // asm volatile ("ld.ca.global.v4.f64 {%0, %1, %2, %3}, [%4];" : 135 | // "=d"(regs[0]), "=d"(regs[1]), "=d"(regs[2]), "=d"(regs[3]) : "l"(ptr)); 136 | // #elif defined(__HIPCC__) 137 | // { 138 | // int4 f4 = *(int4*)ptr; 139 | // regs[0] = f4.x; regs[1] = f4.y; regs[2] = f4.z; regs[3] = f4.w; 140 | // } 141 | // #endif 142 | break; 143 | } 144 | } 145 | 146 | //int loads 147 | CUDA_DEVICE 148 | void ldGlobalVec(const int* ptr, int4& vec) { 149 | vec = *(int4*)ptr; 150 | } 151 | 152 | //double loads 153 | CUDA_DEVICE 154 | void ldGlobalVec(const double* ptr, double4& vec) { 155 | vec = *(double4*)ptr; 156 | } 157 | 158 | CUDA_DEVICE 159 | void sharedStore(float* ptr, float val) { 160 | #if defined(__NVCC__) || defined(__CUDACC__) 161 | asm volatile ("st.shared.f32 [%0], {%1};\n" :: "l"(ptr), "f"(val)); 162 | #elif defined(__HIPCC__) 163 | *ptr = val; 164 | #endif 165 | } -------------------------------------------------------------------------------- /documents/python-api.md: -------------------------------------------------------------------------------- 1 | # PyFastKron API 2 | 3 | PyFastKron provides modules for NumPy and PyTorch. The NumPy module, `FastKronNumpy`, supports the x86 backend. The PyTorch module, `FastKronTorch`, supports both the x86 and CUDA backends.
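A minimal quick-start sketch contrasting the two modules is shown below. The per-module APIs, argument order, and broadcasting rules are documented in the sections that follow; selecting the CUDA backend by placing the PyTorch tensors on a CUDA device is an assumption of this sketch.

```
import numpy as np
import torch
import pyfastkron.fastkronnumpy as fknp   # x86 (CPU) backend
import pyfastkron.fastkrontorch as fkt    # x86 and CUDA backends

# GeMKM on the CPU backend with NumPy arrays
x_np = np.ones((16, 8**3), dtype=np.float32)
fs_np = [np.ones((8, 8), dtype=np.float32) for _ in range(3)]
y_np = fknp.gemkm(x_np, fs_np)

# The same GeMKM on the CUDA backend with PyTorch tensors
x_t = torch.ones((16, 8**3), dtype=torch.float32, device="cuda")
fs_t = [torch.ones((8, 8), dtype=torch.float32, device="cuda") for _ in range(3)]
y_t = fkt.gemkm(x_t, fs_t)
```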
4 | 5 | ## FastKronNumpy 6 | 7 | Import modules as: 8 | 9 | ``` 10 | import numpy as np 11 | import pyfastkron.fastkronnumpy 12 | ``` 13 | 14 | Functions: 15 | ``` 16 | def gekmm(fs : List[np.ndarray], x : np.ndarray, 17 | alpha : float = 1.0, beta : float = 0.0, 18 | y : Optional[np.ndarray] = None) -> np.ndarray 19 | 20 | def gemkm(x : np.ndarray, fs : List[np.ndarray], 21 | alpha : float = 1.0, beta : float = 0.0, 22 | y : Optional[np.ndarray] = None) 23 | ``` 24 | Perform Generalized Kronecker Matrix-Matrix Multiplication (GeKMM): 25 | 26 | $Z = \alpha ~ \left( F^1 \otimes F^2 \otimes \dots \otimes F^N \right) \times X + \beta Y$ 27 | 28 | or Generalized Matrix Kronecker-Matrix Multiplication (GeMKM): 29 | 30 | $Z = \alpha ~ X \times \left( F^1 \otimes F^2 \otimes \dots \otimes F^N \right) + \beta Y$ 31 | 32 | Both functions support NumPy dimension broadcasting semantics. 33 | 34 | * **Parameters** 35 | * `x` is a NumPy array 36 | * `fs` is a list of NumPy arrays 37 | * `alpha` and `beta` are constants 38 | * `y` is a NumPy array 39 | 40 | * **Returns** 41 | Returns the output as a NumPy array. 42 | 43 | ### Example 44 | 45 | ``` 46 | import pyfastkron.fastkronnumpy as fk 47 | 48 | x = np.ones((10, 8**5), dtype=np.float32) 49 | fs = [np.ones((8,8), dtype=np.float32) for i in range(5)] 50 | y = fk.gemkm(x, fs) 51 | ``` 52 | 53 | ## FastKronTorch 54 | 55 | Import modules as: 56 | 57 | ``` 58 | import torch 59 | import pyfastkron.fastkrontorch 60 | ``` 61 | 62 | Functions: 63 | 64 | ``` 65 | def gekmm(fs : List[torch.Tensor], x : torch.Tensor, 66 | alpha : float = 1.0, beta : float = 0.0, 67 | y : Optional[torch.Tensor] = None) -> torch.Tensor 68 | 69 | def gemkm(x : torch.Tensor, fs : List[torch.Tensor], 70 | alpha : float = 1.0, beta : float = 0.0, 71 | y : Optional[torch.Tensor] = None) 72 | ``` 73 | Perform Generalized Kronecker Matrix-Matrix Multiplication (GeKMM): 74 | 75 | $Z = \alpha ~ \left( F^1 \otimes F^2 \otimes \dots \otimes F^N \right) \times X + \beta Y$ 76 | 77 | or Generalized Matrix Kronecker-Matrix Multiplication (GeMKM): 78 | 79 | $Z = \alpha ~ X \times \left( F^1 \otimes F^2 \otimes \dots \otimes F^N \right) + \beta Y$ 80 | 81 | Both functions support NumPy dimension broadcasting semantics (see below for more information). 82 | 83 | Both functions are implemented as a torch.autograd.Function and hence support gradient computation. 84 | 85 | * **Parameters** 86 | * `x` is a torch tensor 87 | * `fs` is a list of torch tensors 88 | * `alpha` and `beta` are constants 89 | * `y` is a torch tensor 90 | 91 | * **Returns** 92 | Returns the output as a torch tensor. 93 | 94 | ### Example 95 | 96 | ``` 97 | import pyfastkron.fastkrontorch as fk 98 | 99 | x = torch.ones((10, 8**5), dtype=torch.float32) 100 | fs = [torch.ones((8,8), dtype=torch.float32) for i in range(5)] 101 | y = fk.gemkm(x, fs) 102 | ``` 103 | 104 | ## GeMKM Broadcast Semantics 105 | TODO Relook at these 106 | 107 | The behavior depends on the dimensionality of the tensors as follows: 108 | 109 | * If $X$ and all $F$s are 1-dimensional, the dot product (scalar) is returned. 110 | * If $X$ and at least one $F$ are 2-dimensional, the matrix kronecker-matrix product is returned. 111 | * If $X$ is 1-dimensional and at least one $F$ is 2-dimensional, then a 1 is prepended to the dimensions of $X$, a 1 is appended to the dimensions of $F$, and after MKM the added dimensions are removed. 112 | * If $X$ is 2-dimensional and all $F$s are 1-dimensional, the matrix kronecker-vector product is returned.
113 | * If $X$ and at least one $F$ are at least 1-dimensional and at least one of $X$ or the $F$s is N-dimensional (where N > 2), then a batched MKM is returned. If $X$ is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched MKM and removed after. If any of the $F$s is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched MKM and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). However, all $F$s must have the same batch dimensions. 114 | 115 | ## GeKMM Broadcast Semantics 116 | 117 | The behavior depends on the dimensionality of the tensors as follows: 118 | 119 | * If all $F$s and $X$ are 1-dimensional, the dot product (scalar) is returned. 120 | * If at least one $F$ and $X$ are 2-dimensional, the kronecker-matrix matrix product is returned. 121 | * If at least one $F$ is 2-dimensional and $X$ is 1-dimensional, then a 1 is prepended to the dimensions of $F$, a 1 is appended to the dimensions of $X$, and after KMM the added dimensions are removed. 122 | * If all $F$s are 1-dimensional and $X$ is 2-dimensional, then a 1 is prepended to the dimensions of each $F$, the kronecker-matrix matrix product is returned, and the added dimensions are removed. 123 | * If at least one $F$ and $X$ are at least 1-dimensional and at least one of $X$ or the $F$s is N-dimensional (where N > 2), then a batched KMM is returned. If $X$ is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched KMM and removed after. If any of the $F$s is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched KMM and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). However, all $F$s must have the same batch dimensions. -------------------------------------------------------------------------------- /src/kernels/gpu_kmmkernel.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "kmmkernel.h" 5 | 6 | #pragma once 7 | 8 | /** 9 | * GPUKMMKernel - A subclass for kernels running on GPUs. 10 | * This class must be subclassed for CUDA or HIP kernels. 11 | */ 12 | struct GPUKMMKernel : public KMMKernel { 13 | protected: 14 | /** 15 | * @kernel: Pointer to the kernel. 16 | */ 17 | void* kernel; 18 | 19 | /** 20 | * @numThreads: Number of threads per threadblock. 21 | */ 22 | uint numThreads; 23 | 24 | /** 25 | * @alignX: Alignment of pointer of X. 26 | */ 27 | uint alignX; 28 | 29 | /** 30 | * @alignF: Alignment of pointer of F. 31 | */ 32 | uint alignF; 33 | 34 | public: 35 | GPUKMMKernel() {} 36 | GPUKMMKernel(void* kernelInvoker, FastKronType elemType, 37 | Factor f, Factor tileF, Matrix tileX, uint fusedFacs, bool P2PStore, 38 | uint regM, uint regK, uint regQ, uint optLevel, 39 | fastKronOp opX, fastKronOp opF, FastKronMMType mmType, KernelBatchType::Ty kernelBatchType, 40 | void*(*getKernel)(), uint NumThreads, 41 | uint alignX, uint alignF) : 42 | KMMKernel(kernelInvoker, elemType, f, tileF, tileX, 43 | fusedFacs, P2PStore, regM, regK, regQ, 44 | optLevel, opX, opF, mmType, kernelBatchType), 45 | numThreads(NumThreads), kernel(getKernel()), 46 | alignX(alignX), alignF(alignF) {} 47 | 48 | /** 49 | * Getter for members 50 | */ 51 | void* getKernel() const {return kernel;} 52 | uint getNumThreads() const {return numThreads;} 53 | uint getAlignmentX() const {return alignX;} 54 | uint getAlignmentF() const {return alignF;} 55 | 56 | /** 57 | * grid() - Returns grid size of the kernel for a problem.
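 * A sketch of the mapping (see the grid() definitions at the end of this header):
 * for MKM with opX == fastKronOp_N (and KMM with opX == fastKronOp_T) the grid is
 * (DIVUP(K, TileK) * DIVUP(Q, TileQ), DIVUP(M, TileM), batchCount); otherwise the
 * x and y dimensions are swapped.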
58 | */ 59 | template 60 | dim3 grid(const KMMProblem& problem, int batchCount) const; 61 | 62 | template 63 | dim3 grid(const KMMProblemT& problem) const; 64 | 65 | template 66 | dim3 grid(const KMMProblemStridedBatchedT& problem) const; 67 | 68 | /** 69 | * block() - Returns block size of the kernel for a problem. 70 | */ 71 | dim3 block() const { 72 | return dim3{getNumThreads(), 1, 1}; 73 | } 74 | 75 | /** 76 | * getNumBlocks() - Returns number of blocks of the kernel for a problem. 77 | */ 78 | uint32_t getNumBlocks(KMMProblem problem) const { 79 | dim3 g = grid(problem); 80 | return g.x*g.y*g.z; 81 | } 82 | uint32_t getNumBlocks(KMMProblemStridedBatched problem) const { 83 | dim3 g = grid(problem); 84 | return g.x*g.y*g.z; 85 | } 86 | 87 | size_t getSharedMemPaddingSize() const { 88 | return ((mmType == FastKronMMType::KMM) ? tileF.p() * sizeOfFastKronType(elemType) : 0) + 89 | ((mmType == FastKronMMType::MKM && opF == fastKronOp_T) ? 90 | tileF.p() * sizeOfFastKronType(elemType) : 0); 91 | } 92 | 93 | /** 94 | * getMaxSharedMemSize() - Returns the maximum shared memory size of the kernel. 95 | * Effectively this is the maximum total tile size 96 | */ 97 | size_t getMaxSharedMemSize() const { 98 | return getMaxTotalTileSize() + getSharedMemPaddingSize(); 99 | } 100 | 101 | /** 102 | * getSharedMemSize() - Returns the shared memory size for the kernel. 103 | */ 104 | size_t getSharedMemSize(KMMProblem problem) const { 105 | //TODO: Shouldn't this be MIN? because getTotalTileSize < getMaxTotalTileSize 106 | return MAX(getTotalTileSize(problem) + getSharedMemPaddingSize(), getMaxSharedMemSize()); 107 | } 108 | size_t getSharedMemSize(KMMProblemStridedBatched problem) const { 109 | return getSharedMemSize(problem.batchProblem(0)); 110 | } 111 | 112 | /** 113 | * canCompute() - Overriding the method of KMMKernel. 114 | */ 115 | virtual bool canCompute(KMMProblem problem, const HardwareDetails* hw, bool p2p, 116 | KernelBatchType::Ty probBatchType, bool exactFuse = true); 117 | 118 | /** 119 | * str() - Overriding the method of KMMKernel. Adds NumThreads extra to the kernel string. 
120 | */ 121 | virtual std::string str() const; 122 | }; 123 | 124 | template 125 | dim3 GPUKMMKernel::grid(const KMMProblem& problem, int batchCount) const { 126 | Matrix tileX = getTileX(problem); 127 | Factor tileF = getTileF(problem); 128 | bool isNOrT = true; //true for N and false for T 129 | if (problem.mmtype() == FastKronMMType::MKM) { 130 | isNOrT = (problem.opX() == fastKronOp_N); 131 | } else { 132 | isNOrT = (problem.opX() == fastKronOp_T); 133 | } 134 | 135 | if (isNOrT) { 136 | return dim3(DIVUP(problem.k(), tileX.n()) * DIVUP(problem.f(0).q(), tileF.q()), 137 | DIVUP(problem.m(), tileX.m()), 138 | batchCount); 139 | } else { 140 | return dim3(DIVUP(problem.m(), tileX.m()), 141 | DIVUP(problem.k(), tileX.n()) * DIVUP(problem.f(0).q(), tileF.q()), 142 | batchCount); 143 | } 144 | } 145 | 146 | template 147 | dim3 GPUKMMKernel::grid(const KMMProblemT& problem) const { 148 | return grid(problem, 1); 149 | } 150 | 151 | template 152 | dim3 GPUKMMKernel::grid(const KMMProblemStridedBatchedT& problem) const { 153 | return grid(problem.batchProblem(0), problem.batchCount()); 154 | } -------------------------------------------------------------------------------- /pyfastkron/fastkronhandle.py: -------------------------------------------------------------------------------- 1 | class FastKronHandle: 2 | def hasBackend(backends, enumBackend): 3 | return (backends & int(enumBackend)) == int(enumBackend) 4 | 5 | def __init__(self, backend, libFastKron): 6 | self.libFastKron = libFastKron 7 | self.backends = libFastKron.backends() 8 | self.handle = libFastKron.init() 9 | self.backend = None 10 | 11 | if backend.lower() == 'x86': 12 | if FastKronHandle.hasBackend(self.backends, 13 | libFastKron.Backend.X86): 14 | libFastKron.initX86(self.handle) 15 | self.backend = libFastKron.Backend.X86 16 | elif backend.lower() == 'cuda': 17 | if FastKronHandle.hasBackend(self.backends, 18 | libFastKron.Backend.CUDA): 19 | import torch 20 | streams = [torch.cuda.current_stream().cuda_stream] 21 | libFastKron.initCUDA(self.handle, streams) 22 | self.backend = libFastKron.Backend.CUDA 23 | else: 24 | assert "Invalid backend", backend 25 | 26 | def __del__(self): 27 | if self.handle is not None: 28 | self.libFastKron.destroy(self.handle) 29 | self.handle = self.backends = self.x86 = self.cuda = None 30 | 31 | def version(self): 32 | return self.libFastKron.version() 33 | 34 | def backend(self, device_type): 35 | if device_type == "cpu": 36 | return self.libFastKron.Backend.X86 37 | if device_type == "cuda": 38 | return self.libFastKron.Backend.CUDA 39 | raise RuntimeError(f"Invalid device {device_type}") 40 | 41 | def gekmmSizes(self, xshape, ps, qs): 42 | return self.libFastKron.gekmmSizes(self.handle, xshape[0], 43 | len(ps), ps, qs) 44 | 45 | def gekmmSizesForward(self, xshape, ps, qs): 46 | return self.libFastKron.gekmmSizesForward(self.handle, xshape[0], 47 | len(ps), ps, qs) 48 | 49 | def xgemkm(self, fn, m, n, ps, qs, x, fs, z, alpha, beta, y, 50 | temp1, temp2, trX=False, trF=False): 51 | fn(self.handle, self.backend, m, n, ps, qs, 52 | x, self.fastKronOp(trX), 53 | fs, self.fastKronOp(trF), 54 | z, 55 | alpha, beta, y, 56 | temp1, temp2) 57 | 58 | def fastKronOp(self, tr): 59 | return self.libFastKron.Op.N if not tr else self.libFastKron.Op.T 60 | 61 | # TODO: Change argument order according to cublas API see comment in pywapper.cpp 62 | def xgemkmStridedBatched(self, fn, m, n, ps, qs, x, strideX, fs, strideFs, 63 | batchCount, z, strideZ, alpha, beta, y, strideY, 64 | temp1, temp2, trX=False, 
trF=False): 65 | fn(self.handle, self.backend, m, n, ps, qs, 66 | x, self.fastKronOp(trX), strideX, 67 | fs, self.fastKronOp(trF), strideFs, 68 | z, strideZ, 69 | alpha, beta, batchCount, y, strideY, 70 | temp1, temp2) 71 | 72 | def xgekmm(self, fn, m, n, ps, qs, x, fs, z, alpha, beta, y, 73 | temp1, temp2, trX=False, trF=False): 74 | fn(self.handle, self.backend, n, qs, ps, m, 75 | fs, self.fastKronOp(trF), 76 | x, self.fastKronOp(trX), 77 | z, 78 | alpha, beta, y, 79 | temp1, temp2) 80 | 81 | def xgekmmStridedBatched(self, fn, m, n, ps, qs, x, strideX, fs, strideFs, 82 | batchCount, z, strideZ, alpha, beta, y, strideY, 83 | temp1, temp2, trX=False, trF=False): 84 | fn(self.handle, self.backend, n, qs, ps, m, 85 | fs, self.fastKronOp(trF), strideFs, 86 | x, self.fastKronOp(trX), strideX, 87 | z, strideZ, 88 | alpha, beta, batchCount, y, strideY, 89 | temp1, temp2) 90 | 91 | def xmkmForward(self, fn, m, n, ps, qs, x, fs, z, 92 | intermediates, trX=False, trF=False): 93 | fn(self.handle, self.backend, m, n, ps, qs, 94 | x, self.fastKronOp(trX), fs, self.fastKronOp(trF), 95 | z, intermediates) 96 | 97 | def xmkmForwardStridedBatched(self, fn, m, n, ps, qs, x, strideX, 98 | fs, strideFs, batchCount, z, strideZ, 99 | intermediates, strideIntermediates, 100 | trX=False, trF=False): 101 | fn(self.handle, self.backend, m, n, ps, qs, 102 | x, self.fastKronOp(trX), strideX, 103 | fs, self.fastKronOp(trF), strideFs, 104 | z, strideZ, batchCount, intermediates, strideIntermediates) 105 | 106 | def xkmmForward(self, fn, m, n, ps, qs, x, fs, z, 107 | intermediates, trX=False, trF=False): 108 | fn(self.handle, self.backend, n, qs, ps, m, 109 | fs, self.fastKronOp(trF), x, self.fastKronOp(trX), z, intermediates) 110 | 111 | def xkmmForwardStridedBatched(self, fn, m, n, ps, qs, x, strideX, 112 | fs, strideFs, batchCount, z, strideZ, 113 | intermediates, strideIntermediates, 114 | trX=False, trF=False): 115 | fn(self.handle, self.backend, n, qs, ps, m, 116 | fs, self.fastKronOp(trF), strideFs, x, self.fastKronOp(trX), 117 | strideX, z, strideZ, batchCount, intermediates, 118 | strideIntermediates) 119 | -------------------------------------------------------------------------------- /src/kernels/cpu/tensor.h: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/fixed-shape-tensor.cuh" 2 | 3 | #pragma once 4 | 5 | //TODO: Think about this 6 | template 8 | class SliceCPU { 9 | public: 10 | const Matrix parent; 11 | //TODO: Create Coord2D 12 | uint32_t startrow; 13 | uint32_t startcol; 14 | uint32_t tileRows_; 15 | uint32_t tileCols_; 16 | uint32_t rows; 17 | uint32_t cols; 18 | uint32_t P; 19 | T* ptr; 20 | 21 | public: 22 | CUDA_DEVICE_HOST 23 | SliceCPU(uint32_t startrow, uint32_t startcol, uint32_t paramTileK, uint32_t P, Matrix parent) : 24 | parent(parent), startrow(startrow), startcol(startcol), 25 | tileRows_(OptTileX::M()), 26 | P(P), 27 | ptr(parent.data(startrow, startcol, OptTileX::Op())) { 28 | tileCols_ = isTileKSame ? OptTileX::N() : paramTileK; 29 | rows = (tileRows_ == 1) ? 1 : MIN(tileRows_, parent.m() - startrow); 30 | cols = isKMultipleOfTileK ? 
tileCols() : MIN(tileCols(), parent.n() - startcol); 31 | } 32 | 33 | CUDA_DEVICE_HOST 34 | const T* data(uint32_t row, uint32_t slice, uint32_t elem) const { 35 | //TODO: get common parts out 36 | if (OptTileX::Op() == fastKronOp_N) { 37 | uint32_t idx = row * parent.n(); 38 | idx += slice*P + elem; 39 | return &ptr[idx]; 40 | } else if (OptTileX::Op() == fastKronOp_T) { 41 | uint32_t idx = slice*P + elem; 42 | idx = idx * parent.m() + row; 43 | return &ptr[idx]; 44 | } 45 | 46 | return nullptr; 47 | } 48 | 49 | CUDA_DEVICE_HOST 50 | const T* data(uint32_t idx) const { 51 | return &ptr[idx]; 52 | } 53 | 54 | CUDA_DEVICE_HOST 55 | uint32_t m() const {return rows;} 56 | CUDA_DEVICE_HOST 57 | uint32_t n() const {return cols;} 58 | CUDA_DEVICE_HOST 59 | uint32_t numel() const {return rows * cols;} 60 | CUDA_DEVICE_HOST 61 | uint32_t tileRows() const {return tileRows_;} 62 | CUDA_DEVICE_HOST 63 | uint32_t tileCols() const {return isTileKSame ? OptTileX::N() : tileCols_;} 64 | }; 65 | 66 | template 68 | class TransposedDirectShared3D : public AbstractFixedShapeTensor2D { 69 | using Base = AbstractFixedShapeTensor2D; 70 | T* data; 71 | 72 | public: 73 | CUDA_DEVICE_HOST 74 | TransposedDirectShared3D(T* data) : data(data) {} 75 | 76 | CUDA_DEVICE_HOST 77 | constexpr fastKronOp layout() const {return Layout;} 78 | 79 | CUDA_DEVICE_HOST 80 | //TODO: Make this Coord1D 81 | void store(uint32_t eIdx, uint32_t num, const T* elems) { 82 | #pragma unroll 83 | for (uint ve = 0; ve < num; ve++) { 84 | uint idx = eIdx + ve; 85 | Base::set(data, idx, elems[ve]); 86 | } 87 | } 88 | 89 | CUDA_DEVICE_HOST 90 | //TODO: Make this Coord2D 91 | void store(uint32_t row, uint32_t col, uint32_t num, const T* elems) { 92 | #pragma unroll 93 | for (uint ve = 0; ve < num; ve++) { 94 | uint32_t idx = row * Base::shape(1) + col + ve; 95 | Base::set(data, idx, elems[ve]); 96 | } 97 | } 98 | 99 | CUDA_DEVICE_HOST 100 | T& at(uint32_t row, uint32_t slice, uint32_t elem) { 101 | return Base::at(data, row, elem * slices() + slice); 102 | } 103 | 104 | CUDA_DEVICE_HOST 105 | const T& at(uint32_t row, uint32_t slice, uint32_t elem) const { 106 | return Base::at(data, row, elem * slices() + slice); 107 | } 108 | 109 | void zero(uint32_t startRow, uint32_t startSlice, uint32_t startElem, uint32_t endRow, uint32_t endSlice, uint32_t endElem) { 110 | if (layout() == fastKronOp_N) { 111 | for (uint32_t r = startRow; r < endRow; r++) { 112 | for (uint32_t e = startElem; e < endElem; e++) { 113 | for (uint32_t c = startSlice; c < endSlice; c++) { 114 | at(r, c, e) = 0.0f; 115 | } 116 | } 117 | } 118 | } else { 119 | for (uint32_t e = startElem; e < endElem; e++) { 120 | for (uint32_t c = startSlice; c < endSlice; c++) { 121 | for (uint32_t r = startRow; r < endRow; r++) { 122 | at(r, c, e) = 0.0f; 123 | } 124 | } 125 | } 126 | } 127 | } 128 | 129 | CUDA_DEVICE_HOST 130 | uint32_t numel() const {return m() * n();} 131 | 132 | CUDA_DEVICE_HOST 133 | uint32_t slices() const {return OptTileX::N()/OptF::P();} 134 | 135 | CUDA_DEVICE_HOST 136 | uint32_t m() const {return OptTileX::M();} 137 | CUDA_DEVICE_HOST 138 | uint32_t n() const {return OptTileX::N();} 139 | CUDA_DEVICE_HOST 140 | uint32_t p() const {return OptTileF::P();} 141 | }; 142 | 143 | template 144 | class YInterim : public AbstractFixedShapeTensor3D { 145 | using Base = AbstractFixedShapeTensor3D; 146 | T* data; 147 | 148 | public: 149 | CUDA_DEVICE_HOST 150 | YInterim(T* data) : data(data) {} 151 | 152 | CUDA_DEVICE_HOST 153 | uint32_t m() const {return OptTileX::M();} 154 | 
CUDA_DEVICE_HOST 155 | uint32_t slices() const {return OptTileX::N()/OptF::P();} 156 | CUDA_DEVICE_HOST 157 | uint32_t q() const {return OptTileF::Q();} 158 | 159 | CUDA_DEVICE_HOST 160 | T& at(const uint32_t m, const uint32_t q, const uint32_t slice) { 161 | if (OpY == fastKronOp_N) 162 | return Base::at(data, m, q, slice); 163 | else if (OpY == fastKronOp_T) { 164 | return Base::at(data, q * this->slices() * this->m() + slice * this->m() + m); 165 | } 166 | } 167 | }; -------------------------------------------------------------------------------- /src/kernel_db/hip_kernel_db.hip: -------------------------------------------------------------------------------- 1 | #include "kernel_db/hip_kernel_db.h" 2 | 3 | #include 4 | 5 | #include "kernels/hip_kernel_info.h" 6 | 7 | #ifdef ENABLE_HIP 8 | #include "kernels/hip/kron-kernels/kernel_decl.inc" 9 | #endif 10 | 11 | HIPKernel AllHIPKernels[] = { 12 | #ifdef ENABLE_HIP 13 | ALL_HIP_KERNELS 14 | #endif 15 | }; 16 | 17 | HIPKernelDatabase::HIPKernelDatabase() : KernelDatabase() { 18 | loadKernels(AllHIPKernels, sizeof(AllHIPKernels)/sizeof(HIPKernel)); 19 | } 20 | 21 | //Launch hip kernels 22 | template 23 | fastKronError invoke(HIPKernel& kernelInfo, const uint kronIndex, 24 | KMMProblem problem, 25 | DistributedParams distParams, 26 | EpilogueParams epilogueParams, 27 | KernelMode execMode, 28 | hipStream_t stream) { 29 | hipError_t status; 30 | 31 | //Create the grid and thread block 32 | KernelParams params (problem, kronIndex, execMode); 33 | FusedParams fusedParams (problem, kernelInfo.tileX.n()); 34 | //Call kernel 35 | typedef void (*KronMatmulKernelTy)(KernelParams, FusedParams, 36 | DistributedParams, EpilogueParams, dim3, dim3, uint32_t, hipStream_t); 37 | KronMatmulKernelTy(kernelInfo.invokerFunc)(params, fusedParams, distParams, 38 | epilogueParams, kernelInfo.grid(problem), 39 | kernelInfo.block(), kernelInfo.sharedMemSize(), stream); 40 | status = hipGetLastError(); 41 | HIP_CHECK(status); 42 | 43 | return fastKronSuccess; 44 | } 45 | 46 | fastKronError HIPKernelDatabase::invokeKernel(KMMKernel* kernel, const uint kronIndex, 47 | KMMProblem problem, EpilogueParams epilogueParams, 48 | KernelMode execMode) { 49 | DistributedParams distParams; 50 | hipStream_t stream = *(hipStream_t*)streams[0]; 51 | HIPKernel& hipKernel = dynamic_cast(*kernel); 52 | 53 | switch(problem.n()) { 54 | case 1: 55 | return invoke<1>(hipKernel, kronIndex, problem, 56 | distParams, epilogueParams, execMode, stream); 57 | case 2: 58 | return invoke<2>(hipKernel, kronIndex, problem, 59 | distParams, epilogueParams, execMode, stream); 60 | case 3: 61 | return invoke<3>(hipKernel, kronIndex, problem, 62 | distParams, epilogueParams, execMode, stream); 63 | case 4: 64 | return invoke<4>(hipKernel, kronIndex, problem, 65 | distParams, epilogueParams, execMode, stream); 66 | case 5: 67 | return invoke<5>(hipKernel, kronIndex, problem, 68 | distParams, epilogueParams, execMode, stream); 69 | case 6: 70 | return invoke<6>(hipKernel, kronIndex, problem, 71 | distParams, epilogueParams, execMode, stream); 72 | default: 73 | std::cout << "Invalid number of fused kernels" << std::endl; 74 | return fastKronKernelNotFound; 75 | } 76 | } 77 | 78 | fastKronError HIPKernelDatabase::timeKernel(KMMKernel* kernel, const uint factorIdx, 79 | KMMProblem problem, DistributedParams distParams, 80 | EpilogueParams epilogueParams, 81 | KernelMode execMode, 82 | bool distP2PStore, 83 | int warmups, int runs, 84 | float& runtime) { 85 | hipStream_t stream = *(hipStream_t*)streams[0]; 
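  // Timing protocol: synchronize the stream, launch `warmups` untimed iterations,
  // then bracket the next `runs` launches with HIP events and report the average per-run time.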
86 |   HIP_CHECK(hipStreamSynchronize(stream));
87 |   hipEvent_t startEvent, endEvent;
88 |   HIP_CHECK(hipEventCreate(&startEvent));
89 |   HIP_CHECK(hipEventCreate(&endEvent));
90 |   fastKronError status = fastKronSuccess; //initialize so the check below is well-defined when warmups + runs == 0
91 |   for (int r = 0; r < warmups + runs; r++) {
92 |     if (r == warmups) HIP_CHECK(hipEventRecord(startEvent, stream));
93 |     if (distP2PStore) {
94 |       status = invokeP2PStoreKernel(kernel, factorIdx, problem,
95 |                                     distParams, epilogueParams, execMode);
96 |     } else {
97 |       status = invokeKernel(kernel, factorIdx, problem,
98 |                             epilogueParams, execMode);
99 |     }
100 |   }
101 |   HIP_CHECK(hipEventRecord(endEvent, stream));
102 |   HIP_CHECK(hipEventSynchronize(endEvent));
103 |   if (status != fastKronSuccess) {
104 |     HIP_CHECK(hipEventDestroy(startEvent));
105 |     HIP_CHECK(hipEventDestroy(endEvent));
106 |     return status;
107 |   }
108 |   HIP_CHECK(hipEventElapsedTime(&runtime, startEvent, endEvent));
109 |   runtime = runtime/runs;
110 |   HIP_CHECK(hipEventDestroy(startEvent));
111 |   HIP_CHECK(hipEventDestroy(endEvent));
112 |   return fastKronSuccess;
113 | }
114 | 
115 | fastKronError HIPKernelDatabase::initTune() {
116 |   HIP_CHECK(hipSetDevice(0));
117 |   return fastKronSuccess;
118 | }
119 | 
120 | fastKronError HIPKernelDatabase::procMalloc(uint32_t proc, size_t size, void*& ptr){
121 |   HIP_CHECK(hipSetDevice(proc));
122 |   HIP_CHECK(hipMalloc(&ptr, size));
123 |   HIP_CHECK(hipMemset(ptr, 1, size));
124 | 
125 |   return fastKronSuccess;
126 | }
127 | 
128 | fastKronError HIPKernelDatabase::procMemset(uint32_t proc, Matrix& m, float val) {
129 |   HIP_CHECK(hipSetDevice(proc));
130 |   float* host = new float[m.numel()];
131 |   for (uint32_t i = 0; i < m.numel(); i++) host[i] = val; //fill element-wise; memset writes bytes and cannot set float values
132 |   HIP_CHECK(hipMemcpy(m.data(), host, m.numel()*sizeof(float), hipMemcpyHostToDevice));
133 |   delete[] host;
134 | 
135 |   return fastKronSuccess;
136 | }
137 | 
138 | fastKronError HIPKernelDatabase::procFree(uint32_t proc, void* ptr) {
139 |   HIP_CHECK(hipSetDevice(proc));
140 |   HIP_CHECK(hipFree(ptr));
141 | 
142 |   return fastKronSuccess;
143 | }
--------------------------------------------------------------------------------
/src/kernels/best-kernels/x86-avx-kernels:
--------------------------------------------------------------------------------
1 | mkm_x86_sisd_f_128x128_32x128_1_1x16384_1x1x1_**_strided_*_0|1|2
2 | mkm_x86_sisd_d_128x128_16x128_1_1x16384_1x1x1_**_strided_*_0|1|2
3 | 
4 | mkm_x86_avx_f_128x128_32x128_1_1x16384_1x16x4_N*
5 | mkm_x86_avx_f_128x128_64x64_1_1x8192_1x16x4_N*
6 | mkm_x86_avx_f_128x128_32x8_1_1x4096_1x16x4_N*
7 | mkm_x86_avx_f_128x128_32x128_1_8x2048_1x16x4_T*
8 | 
9 | mkm_x86_avx_f_64x64_64x64_1_1x4096_1x16x4_N*
10 | mkm_x86_avx_f_64x64_32x32_1_1x2048_1x16x4_N*
11 | mkm_x86_avx_f_64x64_32x2_1_1x1024_1x16x2_N*
12 | mkm_x86_avx_f_64x64_64x64_1_8x1024_1x16x4_T*
13 | 
14 | mkm_x86_avx_f_32x32_32x32_2_1x8192_1x16x4_N*
15 | mkm_x86_avx_f_32x32_32x32_1_1x8192_1x16x4_N*
16 | mkm_x86_avx_f_32x32_32x32_1_1x1024_1x16x4_N*
17 | mkm_x86_avx_f_32x32_32x32_1_8x1024_1x16x4_T*
18 | 
19 | mkm_x86_avx_f_16x16_16x16_2_1x2048_1x16x4_N*
20 | mkm_x86_avx_f_16x16_16x16_1_1x2048_1x16x4_N*
21 | mkm_x86_avx_f_16x16_16x16_1_1x256_1x16x4_N*
22 | mkm_x86_avx_f_16x16_16x16_1_8x256_1x16x4_T*
23 | 
24 | mkm_x86_avx_f_8x8_8x8_3_1x4096_1x16x4_N*
25 | mkm_x86_avx_f_8x8_8x8_1_1x4096_1x16x4_N*
26 | mkm_x86_avx_f_8x8_8x8_1_8x512_1x16x4_T*
27 | 
28 | mkm_x86_avx_f_8x8_8x8_2_1x512_1x16x4_N*
29 | mkm_x86_avx_f_8x8_8x8_1_1x512_1x16x4_N*
30 | mkm_x86_avx_f_8x8_8x8_1_1x64_1x8x4_N*
31 | 
32 | mkm_x86_avx_f_4x4_4x4_3_1x512_1x16x4_N*
33 | mkm_x86_avx_f_4x4_4x4_2_1x256_1x16x4_N*
34 | 
mkm_x86_avx_f_4x4_4x4_1_1x64_1x16x2_N* 35 | mkm_x86_avx_f_4x4_4x4_1_8x64_1x16x2_T* 36 | 37 | mkm_x86_avx_f_2x2_2x2_5_1x256_1x32x2_N* 38 | mkm_x86_avx_f_2x2_2x2_4_1x128_1x32x2_N* 39 | mkm_x86_avx_f_2x2_2x2_3_1x64_1x16x1_N* 40 | 41 | mkm_x86_avx_f_2x2_2x2_1_8x64_1x16x1_T* 42 | 43 | mkm_x86_avx_f_2x2_2x2_1_1x256_1x32x1_N* 44 | mkm_x86_avx_f_2x2_2x2_1_1x64_1x16x1_N* 45 | mkm_x86_avx_f_2x2_2x2_1_1x16_1x8x2_N* 46 | 47 | mkm_x86_avx_d_128x128_16x128_1_1x16384_1x8x4_N* 48 | mkm_x86_avx_d_128x128_32x64_1_1x8192_1x8x4_N* 49 | mkm_x86_avx_d_128x128_32x8_1_1x4096_1x8x4_N* 50 | mkm_x86_avx_d_128x128_16x128_1_8x2048_1x8x4_T* 51 | 52 | mkm_x86_avx_d_64x64_32x64_1_1x4096_1x8x4_N* 53 | mkm_x86_avx_d_64x64_32x32_1_1x2048_1x8x4_N* 54 | mkm_x86_avx_d_64x64_32x2_1_1x1024_1x8x2_N* 55 | X86_AVX_mkm_d_64x64_32x64_1_8x512_1x8x4_T* 56 | 57 | mkm_x86_avx_d_32x32_32x32_2_1x4096_1x8x4_N* 58 | mkm_x86_avx_d_32x32_16x32_1_1x4096_1x8x4_N* 59 | mkm_x86_avx_d_32x32_16x32_1_8x512_1x8x4_T* 60 | mkm_x86_avx_d_32x32_32x32_1_1x1024_1x8x4_N* 61 | 62 | mkm_x86_avx_d_16x16_16x16_2_1x2048_1x8x4_N* 63 | mkm_x86_avx_d_16x16_16x16_1_1x2048_1x8x4_N* 64 | mkm_x86_avx_d_16x16_16x16_1_1x256_1x8x4_N* 65 | mkm_x86_avx_d_16x16_16x16_1_8x256_1x8x4_T* 66 | 67 | mkm_x86_avx_d_8x8_8x8_3_1x8192_1x8x4_N* 68 | mkm_x86_avx_d_8x8_8x8_2_1x4096_1x8x4_N* 69 | mkm_x86_avx_d_8x8_8x8_1_1x4096_1x8x4_N* 70 | mkm_x86_avx_d_8x8_8x8_1_8x256_1x8x4_T* 71 | 72 | mkm_x86_avx_d_8x8_8x8_2_1x512_1x8x4_N* 73 | mkm_x86_avx_d_8x8_8x8_1_1x512_1x8x4_N* 74 | mkm_x86_avx_d_8x8_8x8_1_1x64_1x8x4_N* 75 | 76 | mkm_x86_avx_d_4x4_4x4_3_1x256_1x8x4_N* 77 | mkm_x86_avx_d_4x4_4x4_1_1x64_1x8x4_N* 78 | mkm_x86_avx_d_4x4_4x4_1_8x64_1x8x4_T* 79 | 80 | mkm_x86_avx_d_2x2_2x2_5_1x256_1x16x2_N* 81 | mkm_x86_avx_d_2x2_2x2_4_1x128_1x16x2_N* 82 | mkm_x86_avx_d_2x2_2x2_3_1x64_1x8x1_N* 83 | mkm_x86_avx_d_2x2_2x2_2_1x32_1x8x1_N* 84 | 85 | mkm_x86_avx_d_2x2_2x2_1_1x256_1x16x1_N* 86 | mkm_x86_avx_d_2x2_2x2_1_1x64_1x8x1_N* 87 | mkm_x86_avx_d_2x2_2x2_1_1x16_1x4x2_N* 88 | mkm_x86_avx_d_2x2_2x2_1_8x64_1x8x1_T* 89 | 90 | mkm_x86_avx512_f_128x128_32x128_1_1x16384_1x64x4_N* 91 | mkm_x86_avx512_f_128x128_64x64_1_1x8192_1x64x4_N* 92 | mkm_x86_avx512_f_128x128_32x8_1_1x4096_1x32x4_N* 93 | mkm_x86_avx512_f_128x128_16x128_1_16x8192_1x64x4_T* 94 | 95 | mkm_x86_avx512_f_64x64_64x64_1_1x4096_1x64x4_N* 96 | mkm_x86_avx512_f_64x64_32x32_1_1x2048_1x32x4_N* 97 | mkm_x86_avx512_f_64x64_32x2_1_1x1024_1x16x2_N* 98 | mkm_x86_avx512_f_64x64_64x64_1_16x2048_1x32x8_T* 99 | 100 | mkm_x86_avx512_f_32x32_32x32_1_1x8192_1x64x4_N* 101 | mkm_x86_avx512_f_32x32_32x32_1_1x1024_1x32x4_N* 102 | mkm_x86_avx512_f_32x32_32x32_1_16x1024_1x32x8_T* 103 | 104 | mkm_x86_avx512_f_16x16_16x16_1_1x2048_1x64x4_N* 105 | mkm_x86_avx512_f_16x16_16x16_1_1x256_1x16x4_N* 106 | mkm_x86_avx512_f_16x16_16x16_1_16x256_1x16x4_T* 107 | 108 | mkm_x86_avx512_f_8x8_8x8_2_1x4096_1x64x4_N* 109 | mkm_x86_avx512_f_8x8_8x8_1_1x4096_1x64x4_N* 110 | mkm_x86_avx512_f_8x8_8x8_1_1x512_1x64x4_N* 111 | 112 | mkm_x86_avx512_f_8x8_8x8_1_16x512_1x64x4_T* 113 | 114 | mkm_x86_avx512_f_4x4_4x4_1_1x256_1x64x4_N* 115 | mkm_x86_avx512_f_4x4_4x4_1_16x128_1x32x4_T* 116 | 117 | mkm_x86_avx512_f_2x2_2x2_1_1x256_1x64x1_N* 118 | mkm_x86_avx512_f_2x2_2x2_1_1x64_1x32x1_N* 119 | 120 | mkm_x86_avx512_f_2x2_2x2_1_16x64_1x32x2_T* 121 | 122 | mkm_x86_avx512_d_128x128_16x128_1_1x16384_1x32x4_N* 123 | mkm_x86_avx512_d_128x128_32x64_1_1x8192_1x32x4_N* 124 | mkm_x86_avx512_d_128x128_32x8_1_1x4096_1x32x4_N* 125 | 126 | mkm_x86_avx512_d_128x128_16x128_1_8x4096_1x32x4_T* 127 | 128 | 
mkm_x86_avx512_d_64x64_32x64_1_1x4096_1x32x4_N* 129 | mkm_x86_avx512_d_64x64_32x32_1_1x2048_1x32x4_N* 130 | mkm_x86_avx512_d_64x64_32x2_1_1x1024_1x16x2_N* 131 | mkm_x86_avx512_d_64x64_32x64_1_8x2048_1x32x4_T* 132 | 133 | mkm_x86_avx512_d_32x32_16x32_1_1x4096_1x32x4_N* 134 | mkm_x86_avx512_d_32x32_32x32_1_1x1024_1x32x4_N* 135 | mkm_x86_avx512_d_32x32_16x32_1_8x1024_1x32x4_T* 136 | 137 | mkm_x86_avx512_d_16x16_16x16_1_1x2048_1x32x4_N* 138 | mkm_x86_avx512_d_16x16_16x16_1_1x256_1x16x4_N* 139 | mkm_x86_avx512_d_16x16_16x16_1_8x512_1x32x4_N* 140 | 141 | mkm_x86_avx512_d_8x8_8x8_2_1x4096_1x32x4_N* 142 | mkm_x86_avx512_d_8x8_8x8_1_1x4096_1x32x4_N* 143 | mkm_x86_avx512_d_8x8_8x4_1_1x512_1x32x4_N* 144 | mkm_x86_avx512_d_8x8_8x8_1_1x64_1x8x4_N* 145 | 146 | mkm_x86_avx512_d_8x8_8x8_1_8x512_1x32x4_T* 147 | 148 | mkm_x86_avx512_d_4x4_4x4_1_1x64_1x16x4_N* 149 | mkm_x86_avx512_d_4x4_4x2_1_1x64_1x16x2_N* 150 | 151 | mkm_x86_avx512_d_4x4_4x4_1_8x64_1x16x4_T* 152 | 153 | mkm_x86_avx512_d_2x2_2x2_1_1x256_1x32x1_N* 154 | mkm_x86_avx512_d_2x2_2x2_1_1x64_1x32x1_N* 155 | mkm_x86_avx512_d_2x2_2x2_1_1x16_1x8x2_N* 156 | 157 | mkm_x86_avx512_d_2x2_2x2_1_8x64_1x32x1_T* -------------------------------------------------------------------------------- /src/autotuner/autotuner.h: -------------------------------------------------------------------------------- 1 | #include "kmm/kmmalgo.h" 2 | 3 | #include "kernel_db/kernel_db.h" 4 | #include "kernels/kmmkernel.h" 5 | 6 | #include 7 | #include 8 | 9 | #pragma once 10 | 11 | /** 12 | * TunedKernelsMap - maps a KMMProblem to tuned kernel with its execution time. 13 | */ 14 | template 15 | class TunedKernelsMap { 16 | /** 17 | * @ProblemToKernels: a map of KMMProblem to a pair of kernel and its run time in milliseconds. 18 | */ 19 | using ProblemToKernels = std::unordered_map>; 20 | 21 | /** 22 | * @kernels: the map of KMMProblem to single gpu/cpu kernels. 23 | * @p2pKernels: the map of KMMProblem to kernels storing output using P2P. 24 | */ 25 | ProblemToKernels kernels; 26 | ProblemToKernels p2pKernels; 27 | 28 | /** 29 | * getKernel - get kernel of a problem from a map. 30 | */ 31 | typename ProblemToKernels::const_iterator getKernel(const ProblemToKernels& map, const KMMProblemT& problem) { 32 | return map.find(problem); 33 | } 34 | 35 | public: 36 | TunedKernelsMap() {} 37 | 38 | /** 39 | * add() - add or update a kernel-time pair for a problem. 40 | * @problem: The KMMProblem to add kernel-time pair for. 41 | * @p2p: True if the problem requires P2P stores for storing output otherwise false. 42 | * @kernelAndTime: The pair of kernel and its runtime. 43 | */ 44 | void add(const KMMProblemT& problem, bool p2p, std::pair kernelAndtime) { 45 | if (p2p) { 46 | p2pKernels.emplace(std::make_pair(problem, kernelAndtime)); 47 | } else { 48 | kernels.emplace(std::make_pair(problem, kernelAndtime)); 49 | } 50 | } 51 | 52 | /** 53 | * hasKernel() - Determine if there is a kernel for a problem 54 | * @problem: The KMMProblem to find kernel-time pair for 55 | * @p2p: True if the problem requires P2P stores for storing output otherwise false. 56 | */ 57 | bool hasKernel(const KMMProblemT& problem, bool p2p) { 58 | return (p2p) ? getKernel(p2pKernels, problem) != p2pKernels.end(): 59 | getKernel(kernels, problem) != kernels.end(); 60 | } 61 | 62 | /** 63 | * getKernel() - Get the kernel for a problem 64 | * @problem: The KMMProblem to find kernel for 65 | * @p2p: True if the problem requires P2P stores for storing output otherwise false. 
66 |    */
67 |   KMMKernel* getKernel(const KMMProblemT& problem, bool p2p) {
68 |     return (p2p) ? getKernel(p2pKernels, problem)->second.first :
69 |                    getKernel(kernels, problem)->second.first;
70 |   }
71 | 
72 |   /**
73 |    * getKernelTime() - Get the kernel-time for a problem
74 |    * @problem: The KMMProblem to find kernel for
75 |    * @p2p: True if the problem requires P2P stores for storing output otherwise false.
76 |    */
77 |   float getKernelTime(const KMMProblemT& problem, bool p2p) {
78 |     return (p2p) ? getKernel(p2pKernels, problem)->second.second :
79 |                    getKernel(kernels, problem)->second.second;
80 |   }
81 | };
82 | 
83 | /**
84 |  * Forward declaration of FastKronHandle for Autotuner
85 |  */
86 | class FastKronHandle;
87 | 
88 | /**
89 |  * Autotuner - goes through all valid kernels and finds a kernel series with least execution time
90 |  * for a KMMProblem
91 |  */
92 | class Autotuner {
93 |   /**
94 |    * @fastKron: The parent FastKron handle
95 |    */
96 |   FastKronHandle& fastKron;
97 | 
98 |   /**
99 |    * @tunedKernelsMap: A map of tuned kernels and KMMProblems
100 |    */
101 |   TunedKernelsMap tunedKernelsMap;
102 |   /**
103 |    * @tunedKernelsMapStridedBatched: A map of tuned kernels and KMMProblemStridedBatched
104 |    */
105 |   TunedKernelsMap tunedKernelsMapStridedBatched;
106 | 
107 |   /**
108 |    * @tunedProblemCache: A cache of already tuned full KMMProblems.
109 |    * Maps each KMMProblem to tuned kernel series for each backend.
110 |    */
111 |   std::unordered_map> tunedProblemCache;
112 |   /**
113 |    * @tunedProblemCacheStridedBatched: A cache of already tuned full KMMProblemStridedBatcheds.
114 |    * Maps each KMMProblemStridedBatched to tuned kernel series for each backend.
115 |    */
116 |   std::unordered_map> tunedProblemCacheStridedBatched;
117 | 
118 |   /**
119 |    * tune() - Tune kernels for all subproblems in the KMMProblem.
120 |    * @problem: The base KMMProblem.
121 |    * @kernelDb: KernelDatabase containing kernels.
122 |    * @isDistributed: If the KMMProblem is computed using distributed GPUs.
123 |    * @distParams: Distributed parameters if needed.
124 | */ 125 | template 126 | fastKronError tune(KMMProblem problem, TunedKernelsMap& tunedKernelsMap, 127 | KernelDatabase* kernelDb, bool isDistributed, 128 | DistributedParams distParams); 129 | 130 | public: 131 | TunedKernelsSeries distribTunedKernelSeries; 132 | 133 | Autotuner(FastKronHandle& fastKron); 134 | 135 | /** 136 | * tune() - Find the best performing kernel series for a KMMProblem on a backend 137 | * @problem: KMMProblem 138 | * @backend: fastKronBackend containing kernels 139 | * @retKernelSeries: [OUT] the tuned kernel series 140 | */ 141 | fastKronError tune(KMMProblem problem, 142 | const fastKronBackend backend, 143 | TunedKernelsSeries& retKernelSeries); 144 | 145 | /** 146 | * tune() - Find the best performing kernel series for a KMMProblemStridedBatched on a backend 147 | * @problem: KMMProblemStridedBatched 148 | * @backend: fastKronBackend containing kernels 149 | * @retKernelSeries: [OUT] the tuned kernel series 150 | */ 151 | fastKronError tune(KMMProblemStridedBatched problem, 152 | const fastKronBackend backend, 153 | TunedKernelsSeries& retKernelSeries); 154 | }; -------------------------------------------------------------------------------- /src/kernels/best-kernels/kmm-x86-avx-kernels: -------------------------------------------------------------------------------- 1 | kmm_x86_sisd_f_128x128_32x128_1_8x256_1x1x1 2 | kmm_x86_sisd_d_128x128_32x128_1_8x256_1x1x1 3 | 4 | kmm_x86_avx_f_128x128_32x128_1_32x512_16x4x1_N*_*_*_* 5 | kmm_x86_avx_f_128x128_32x128_1_16x1024_16x4x1 6 | kmm_x86_avx_f_128x128_32x128_1_8x4096_8x1x8 7 | kmm_x86_avx_f_128x128_32x32_1_8x256_8x1x4 8 | 9 | kmm_x86_avx_f_64x64_32x64_1_16x256_16x1x4 10 | kmm_x86_avx_f_64x64_32x64_1_8x512_8x1x4 11 | kmm_x86_avx_f_64x64_32x16_1_16x256_16x2x1 12 | kmm_x86_avx_f_64x64_32x32_1_8x64_8x1x4 13 | 14 | kmm_x86_avx_f_32x32_32x32_1_16x256_16x1x2 15 | kmm_x86_avx_f_32x32_32x32_1_64x256_8x1x4 16 | kmm_x86_avx_f_32x32_32x32_1_8x512_8x1x2 17 | kmm_x86_avx_f_32x32_32x32_1_8x32_8x1x8 18 | 19 | 20 | kmm_x86_avx_f_16x16_16x16_2_32x256_16x2x2 21 | kmm_x86_avx_f_16x16_16x16_1_32x256_16x2x2 22 | kmm_x86_avx_f_16x16_16x16_2_16x256_16x1x2 23 | kmm_x86_avx_f_16x16_16x16_1_16x256_16x1x2 24 | kmm_x86_avx_f_16x16_16x16_2_8x256_8x1x4 25 | kmm_x86_avx_f_16x16_16x16_1_8x256_8x1x4 26 | 27 | kmm_x86_avx_f_8x8_8x8_3_8x512_8x1x4 28 | kmm_x86_avx_f_8x8_8x8_1_8x512_8x1x4 29 | kmm_x86_avx_f_8x8_8x8_2_8x64_8x2x1 30 | kmm_x86_avx_f_8x8_8x8_1_8x64_8x2x1 31 | 32 | 33 | kmm_x86_avx_f_4x4_4x4_3_32x64_8x4x1 34 | kmm_x86_avx_f_4x4_4x4_1_32x64_8x4x1 35 | 36 | kmm_x86_avx_f_4x4_4x4_3_16x128_16x1x4 37 | kmm_x86_avx_f_4x4_4x4_1_16x128_16x1x4 38 | 39 | kmm_x86_avx_f_4x4_4x4_4_8x256_8x1x1 40 | kmm_x86_avx_f_4x4_4x4_1_8x256_8x1x1 41 | 42 | kmm_x86_avx_f_4x4_4x4_5_8x1024_8x1x4 43 | 44 | kmm_x86_avx_f_4x4_4x4_1_16x16_8x2x4 45 | kmm_x86_avx_f_4x4_4x4_1_8x16_8x1x4 46 | 47 | 48 | kmm_x86_avx_f_2x2_2x2_5_32x32_8x2x2 49 | kmm_x86_avx_f_2x2_2x2_1_32x32_8x2x2 50 | 51 | kmm_x86_avx_f_2x2_2x2_6_16x64_8x2x2 52 | kmm_x86_avx_f_2x2_2x2_1_16x64_8x2x2 53 | 54 | kmm_x86_avx_f_2x2_2x2_6_8x64_8x4x2 55 | kmm_x86_avx_f_2x2_2x2_1_8x64_8x4x2 56 | 57 | 58 | kmm_x86_avx_d_128x128_16x128_1_32x256_8x2x2_N*_*_*_* 59 | kmm_x86_avx_d_128x128_16x128_1_16x256_8x2x2 60 | kmm_x86_avx_d_128x128_16x32_1_8x256_4x2x4 61 | 62 | kmm_x86_avx_d_64x64_16x32_1_32x128_8x2x1 63 | kmm_x86_avx_d_64x64_16x32_1_16x256_4x4x1 64 | kmm_x86_avx_d_64x64_16x4_1_8x128_4x2x4 65 | kmm_x86_avx_d_64x64_16x8_1_8x64_4x1x2 66 | 67 | kmm_x86_avx_d_32x32_16x32_1_32x64_8x2x2 68 | 
kmm_x86_avx_d_32x32_16x32_1_16x64_8x2x2 69 | kmm_x86_avx_d_32x32_16x32_1_8x64_8x2x1 70 | 71 | kmm_x86_avx_d_16x16_16x16_2_32x256_8x2x2 72 | kmm_x86_avx_d_16x16_16x16_1_32x256_8x2x2 73 | kmm_x86_avx_d_16x16_16x16_2_8x256_8x2x1 74 | kmm_x86_avx_d_16x16_16x16_1_8x256_8x2x1 75 | kmm_x86_avx_d_16x16_16x16_2_16x256_8x2x1 76 | kmm_x86_avx_d_16x16_16x16_1_16x256_8x2x1 77 | kmm_x86_avx_d_16x16_16x16_1_8x128_8x2x1 78 | 79 | kmm_x86_avx_d_8x8_8x8_3_32x512_4x2x2 80 | kmm_x86_avx_d_8x8_8x8_1_32x512_4x2x2 81 | kmm_x86_avx_d_8x8_8x8_2_32x64_4x4x1 82 | kmm_x86_avx_d_8x8_8x8_1_32x64_4x4x1 83 | kmm_x86_avx_d_8x8_8x8_3_16x512_4x4x1 84 | kmm_x86_avx_d_8x8_8x8_1_16x512_4x4x1 85 | kmm_x86_avx_d_8x8_8x8_3_8x512_8x1x2 86 | kmm_x86_avx_d_8x8_8x8_1_8x512_8x1x2 87 | kmm_x86_avx_d_8x8_8x8_2_8x64_4x1x2 88 | kmm_x86_avx_d_8x8_8x8_1_8x64_4x1x2 89 | 90 | kmm_x86_avx_d_4x4_4x4_3_16x128_8x2x2 91 | kmm_x86_avx_d_4x4_4x4_1_16x128_8x2x2 92 | kmm_x86_avx_d_4x4_4x4_4_8x512_8x2x2 93 | kmm_x86_avx_d_4x4_4x4_1_8x512_8x2x2 94 | kmm_x86_avx_d_4x4_4x4_3_8x64_4x2x2 95 | kmm_x86_avx_d_4x4_4x4_1_8x64_4x2x2 96 | kmm_x86_avx_d_4x4_4x4_2_8x16_4x2x1 97 | kmm_x86_avx_d_4x4_4x4_1_8x16_4x2x1 98 | 99 | kmm_x86_avx_d_2x2_2x2_6_16x64_8x2x2 100 | kmm_x86_avx_d_2x2_2x2_1_16x64_8x2x2 101 | 102 | kmm_x86_avx_d_2x2_2x2_6_8x64_4x1x2 103 | kmm_x86_avx_d_2x2_2x2_1_8x64_4x1x2 104 | 105 | kmm_x86_avx512_f_128x128_32x128_1_32x256_32x2x4 106 | kmm_x86_avx512_f_128x128_32x128_1_16x512_16x4x2 107 | kmm_x86_avx512_f_128x128_32x32_1_16x512_16x4x2 108 | 109 | kmm_x86_avx512_f_64x64_64x64_1_64x64_64x1x4 110 | kmm_x86_avx512_f_64x64_32x64_1_32x128_32x2x4 111 | kmm_x86_avx512_f_64x64_32x32_1_16x256_16x4x2 112 | 113 | kmm_x86_avx512_f_32x32_32x32_1_64x128_64x1x4 114 | kmm_x86_avx512_f_32x32_32x32_1_32x64_32x2x4 115 | kmm_x86_avx512_f_32x32_32x32_1_16x64_16x2x4 116 | 117 | kmm_x86_avx512_f_16x16_16x16_2_32x256_32x1x4 118 | kmm_x86_avx512_f_16x16_16x16_1_32x256_32x1x4 119 | kmm_x86_avx512_f_16x16_16x16_2_16x256_16x1x4 120 | kmm_x86_avx512_f_16x16_16x16_1_16x256_16x1x4 121 | kmm_x86_avx512_f_16x16_16x16_1_16x64_16x4x4 122 | 123 | kmm_x86_avx512_f_8x8_8x8_3_16x512_16x1x4 124 | kmm_x86_avx512_f_8x8_8x8_1_16x512_16x1x4 125 | kmm_x86_avx512_f_8x8_8x8_2_16x64_16x4x4 126 | kmm_x86_avx512_f_8x8_8x8_1_16x64_16x4x4 127 | 128 | kmm_x86_avx512_f_4x4_4x4_3_16x256_16x4x4 129 | kmm_x86_avx512_f_4x4_4x4_1_16x256_16x4x4 130 | kmm_x86_avx512_f_4x4_4x4_3_32x256_32x2x4 131 | kmm_x86_avx512_f_4x4_4x4_1_32x256_32x2x4 132 | 133 | kmm_x86_avx512_f_2x2_2x2_4_32x64_32x4x2 134 | kmm_x86_avx512_f_2x2_2x2_1_32x64_32x4x2 135 | kmm_x86_avx512_f_2x2_2x2_4_16x64_16x4x2 136 | kmm_x86_avx512_f_2x2_2x2_1_16x64_16x4x2 137 | 138 | kmm_x86_avx512_d_128x128_16x128_1_32x512_16x2x4 139 | kmm_x86_avx512_d_128x128_16x128_1_16x512_16x2x4 140 | kmm_x86_avx512_d_128x128_16x32_1_16x128_16x1x4 141 | 142 | kmm_x86_avx512_d_64x64_16x64_1_32x128_16x2x4 143 | kmm_x86_avx512_d_64x64_16x32_1_16x256_16x2x4 144 | kmm_x86_avx512_d_64x64_16x8_1_8x128_8x2x8 145 | 146 | kmm_x86_avx512_d_32x32_16x32_1_32x64_16x2x4 147 | kmm_x86_avx512_d_32x32_16x32_1_16x64_16x2x4 148 | kmm_x86_avx512_d_32x32_16x32_1_8x64_8x2x8 149 | 150 | kmm_x86_avx512_d_16x16_16x16_2_32x256_16x2x4 151 | kmm_x86_avx512_d_16x16_16x16_1_32x256_16x2x4 152 | kmm_x86_avx512_d_16x16_16x16_2_16x256_16x2x4 153 | kmm_x86_avx512_d_16x16_16x16_1_16x256_16x2x4 154 | kmm_x86_avx512_d_16x16_16x16_2_8x256_8x2x8 155 | kmm_x86_avx512_d_16x16_16x16_1_8x256_8x2x8 156 | kmm_x86_avx512_d_16x16_16x16_1_8x128_8x2x8 157 | 158 | kmm_x86_avx512_d_8x8_8x8_3_32x512_32x2x2 159 | 
kmm_x86_avx512_d_8x8_8x8_1_32x512_32x2x2
160 | kmm_x86_avx512_d_8x8_8x8_3_16x512_16x2x2
161 | kmm_x86_avx512_d_8x8_8x8_1_16x512_16x2x2
162 | kmm_x86_avx512_d_8x8_8x8_2_16x64_16x1x4
163 | kmm_x86_avx512_d_8x8_8x8_1_16x64_16x1x4
164 | 
165 | kmm_x86_avx512_d_4x4_4x4_3_16x128_16x2x2
166 | kmm_x86_avx512_d_4x4_4x4_1_16x128_16x2x2
167 | kmm_x86_avx512_d_4x4_4x4_3_8x128_8x2x2
168 | kmm_x86_avx512_d_4x4_4x4_1_8x128_8x2x2
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastKron
2 | 
3 | FastKron is a fast library for computing *Generalized Matrix Kronecker-Matrix Multiplication (GeMKM)* and *Generalized Kronecker-Matrix Matrix Multiplication (GeKMM)* on NVIDIA GPUs and X86 CPUs.
4 | FastKron contains specialized algorithms and implementations of GeMKM and GeKMM rather than using existing linear algebra operations.
5 | FastKron avoids extra transposes and adds more optimizations including fusion of multiple kernels.
6 | Therefore, FastKron performs orders of magnitude better than baseline GPyTorch, NVIDIA cuTensor, and HPTT.
7 | FastKron provides a C++ library and a Python library with Numpy and PyTorch autograd functions.
8 | FastKron provides fast implementations for the float and double data types, while the Numpy/PyTorch functions use the shuffle algorithm for other types.
9 | 
10 | For more details, see [Fast Kronecker Matrix-Matrix Multiplication on GPUs](https://dl.acm.org/doi/abs/10.1145/3627535.3638489).
11 | 
12 | # Performance
13 | We compare FastKron's GeMKM and GeKMM with the existing shuffle algorithm in GPyTorch, based on PyTorch 2.5.1.
14 | The table below shows the range of speedups on different hardware and data types.
15 | 
16 | ### GeMKM
17 | 
18 | | Hardware | Float | Double |
19 | |----------|----------|--------|
20 | | AMD 64-Core CPU with AVX| 9.3-45x| 5.8-21x|
21 | | AMD 64-Core CPU with AVX512| 9.7-38x| 6.3-21x|
22 | | NVIDIA A100 80 GB| 1.5-9.5x| 1.1-9.5x|
23 | | NVIDIA V100 16 GB| 2.5-10x| 1.9-11x|
24 | 
25 | ### GeKMM
26 | 
27 | | Hardware | Float | Double |
28 | |----------|----------|--------|
29 | | AMD 64-Core CPU with AVX| 2.7-13.7x| 1.5-7x|
30 | | AMD 64-Core CPU with AVX512| 2.2-14x| 2-7x|
31 | | NVIDIA A100 80 GB| 1.3-4.6x| 0.9-4.5x|
32 | | NVIDIA V100 16 GB| 1.4-6.4x| 2-7.8x|
33 | 
34 | For more information, see [documents/performance.md](https://github.com/abhijangda/FastKron/blob/main/documents/performance.md)
35 | 
36 | # Hardware and OS Support
37 | | | Linux | WSL2 | Windows | Mac |
38 | |----------|----------|----------|-------|-----|
39 | | x86 | ✅ | ✅ | 🐍 | 🐍 |
40 | | ARM | 🐍 | 🐍 | 🐍 | 🐍 |
41 | | AVX256 | ✅ | ✅ | 🐍 | 🐍 |
42 | | AVX512 | ✅ | ✅ | 🐍 | 🐍 |
43 | | SM50+ CUDA cores | ✅ | ✅ | 🐍 | 🐍 |
44 | | SM80+ Tensor cores | ❌ | ❌ | 🐍 | 🐍 |
45 | | AMD RoCM | 🐍 | 🐍 | 🐍 | 🐍 |
46 | 
47 | ✅ FastKron supports optimized implementations for AVX256 and AVX512 CPUs and NVIDIA GPUs.\
48 | ❌ Tensor cores for double are not supported.\
49 | 🐍 Supported in the Python module. x86 CPUs older than GLIBC x86-64-v2, ARM CPUs, AMD GPUs, Windows, and Mac OS are not supported in the C++ API, but PyFastKron *falls back* to the shuffle algorithm in Numpy or PyTorch.
50 | 
51 | The future roadmap, in order of priority, is: Windows, SM80+ Double Tensor cores, AMD GPUs, ARM CPUs.
52 | 
53 | # Example
54 | The directory `example/` includes examples of using FastKron's CUDA and x86 backends from both C++ and Python.
55 | Before using an example, follow the instructions below to build FastKron.
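A minimal GeMKM sketch with the PyTorch frontend (the call shape mirrors `tests/python/test_torch.py`; the scripts under `example/` and `documents/python-api.md` are the complete, authoritative versions):

```python
import torch
import pyfastkron.fastkrontorch as fk

# x: [M, P1*P2*P3], three P x Q factors, z: [M, Q1*Q2*Q3]
x  = torch.randn(16, 8 * 8 * 8)
fs = tuple(torch.randn(8, 8) for _ in range(3))
z  = torch.zeros(16, 8 * 8 * 8)

# GeMKM: y = alpha * x @ (fs[0] ⊗ fs[1] ⊗ fs[2]) + beta * z
y = fk.gemkm(x, fs, 1.0, 1.0, z)
```

`fk.gekmm(fs, x, alpha, beta, z)` is the GeKMM counterpart.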
56 | 
57 | # Installation
58 | 
59 | PyFastKron can be installed using pip.
60 | 
61 | ```pip install pyfastkron```
62 | 
63 | PyFastKron's CUDA backend is built with CUDA 12.3 but is compatible with CUDA 11.8 and above.
64 | 
65 | # Build
66 | Build the C++ library, libFastKron.so, to use with C++ programs, or the Python library, PyFastKron, to use with PyTorch or Numpy programs.
67 | 
68 | ### Required Pre-requisites
69 | On Ubuntu:
70 | ```
71 | sudo apt update && sudo apt install gcc linux-headers-$(uname -r) make g++ git python3-dev wget unzip python3-pip build-essential devscripts debhelper fakeroot intel-mkl cmake
72 | ```
73 | 
74 | ### CUDA Pre-requisite
75 | Install CUDA 11+ from https://developer.nvidia.com/cuda/ .
76 | 
77 | ### Clone repository
78 | Clone the repository with submodules using
79 | ```
80 | git clone --recurse-submodules https://github.com/abhijangda/fastkron.git
81 | ```
82 | 
83 | If the repository is already cloned and only the submodules are needed, use
84 | ```
85 | git submodule update --init --recursive
86 | ```
87 | 
88 | ### libFastKron
89 | Build FastKron as a C++ library using the commands below:
90 | ```
91 | mkdir build/
92 | cd build/
93 | cmake ..
94 | make -j
95 | ```
96 | 
97 | To install, run
98 | ```make install```
99 | 
100 | By default both the x86 and CUDA backends are built. Use the CMake option `-DENABLE_CUDA=OFF` to disable the CUDA backend or `-DENABLE_X86=OFF` to disable the x86 backend.
101 | 
102 | Run X86 CPU tests using
103 | ```
104 | make run-x86-tests
105 | ```
106 | 
107 | Run CUDA tests using
108 | ```
109 | make run-cuda-tests
110 | ```
111 | 
112 | ### PyFastKron
113 | Install PyFastKron using pip:
114 | 
115 | ```
116 | pip install .
117 | ```
118 | 
119 | Run tests using
120 | ```
121 | pytest
122 | ```
123 | 
124 | # Documentation
125 | 
126 | C++ API: [documents/cpp-api.md](https://github.com/abhijangda/FastKron/blob/main/documents/cpp-api.md)\
127 | Python API: [documents/python-api.md](https://github.com/abhijangda/FastKron/blob/main/documents/python-api.md)\
128 | Kernel Tuning: [documents/autotuning.md](https://github.com/abhijangda/FastKron/blob/main/documents/autotuning.md)\
129 | Performance: [documents/performance.md](https://github.com/abhijangda/FastKron/blob/main/documents/performance.md)\
130 | Multi-GPU: [documents/multigpu.md](https://github.com/abhijangda/FastKron/blob/main/documents/multigpu.md)\
131 | Contributing: [documents/contributing.md](https://github.com/abhijangda/FastKron/blob/main/documents/contributing.md)
132 | 
133 | # Citation
134 | 
135 | ```
136 | @inproceedings{10.1145/3627535.3638489,
137 | author = {Jangda, Abhinav and Yadav, Mohit},
138 | title = {Fast Kronecker Matrix-Matrix Multiplication on GPUs},
139 | year = {2024},
140 | isbn = {9798400704352},
141 | publisher = {Association for Computing Machinery},
142 | address = {New York, NY, USA},
143 | url = {https://doi.org/10.1145/3627535.3638489},
144 | doi = {10.1145/3627535.3638489},
145 | booktitle = {Proceedings of the 29th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming},
146 | pages = {390–403},
147 | numpages = {14},
148 | keywords = {graphics processing units, CUDA, kronecker product, linear algebra},
149 | location = {Edinburgh, United Kingdom},
150 | series = {PPoPP '24}
151 | }
152 | ```
--------------------------------------------------------------------------------
/tests/python/test_torch.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import torch
3 | import torch.autograd
4 | 
5 | 
import pyfastkron.fastkrontorch as fk 6 | 7 | def product(values): 8 | return reduce((lambda a, b: a * b), values) 9 | 10 | def transpose(m): 11 | axis = tuple(range(len(m.shape[:-2]))) + \ 12 | (len(m.shape) - 1, len(m.shape) - 2) 13 | return m.mT #torch.transpose(m, -2, -1) 14 | 15 | def reference(mmtype, x, fs, device): 16 | batchKron = fs[0].shape[:-2] 17 | if len(batchKron) == 0: 18 | outputKron = fs[0] 19 | for m in fs[1:]: 20 | outputKron = torch.kron(outputKron, m) 21 | else: 22 | batchDims = product(batchKron) 23 | fs = [f.reshape((batchDims,) + f.shape[-2:]) for f in fs] 24 | 25 | output = fs[0] 26 | for f in fs[1:]: 27 | prev = output 28 | s = (batchDims, prev.shape[-2] * f.shape[-2], prev.shape[-1] * f.shape[-1]) 29 | output = torch.zeros((batchDims, prev.shape[-2] * f.shape[-2], prev.shape[-1] * f.shape[-1]), 30 | dtype=f.dtype).to(device) 31 | for b in range(batchDims): 32 | output[b,:,:] = torch.kron(prev.contiguous()[b,:,:], f.contiguous()[b,:,:]) 33 | outputKron = output.reshape(batchKron + (output.shape[-2], output.shape[-1])) 34 | 35 | if mmtype == "mkm": 36 | return torch.matmul(x, outputKron) 37 | elif mmtype == "kmm": 38 | return torch.matmul(outputKron, x) 39 | 40 | def run(mmtype, m, n, ps, qs, dtype, device, trX, trF, 41 | high=5, batchDimX=[], batchDimFPre=[], batchDimZ=[], 42 | gradcheck=False): 43 | 44 | if type(ps) is int: 45 | ps = [ps] 46 | 47 | if len(ps) == 1: 48 | ps = [ps[0]]*n 49 | 50 | if type(qs) is int: 51 | qs = [qs] 52 | 53 | if len(qs) == 1: 54 | qs = [qs[0]]*n 55 | 56 | #Using integer values instead of real numbers because 57 | #floating point is not associative 58 | if mmtype == "mkm": 59 | xshape = [m, product(ps)] if not trX else [product(ps), m] 60 | elif mmtype == "kmm": 61 | xshape = [product(ps), m] if not trX else [m, product(ps)] 62 | 63 | xshape = list(batchDimX) + xshape 64 | 65 | if mmtype == "mkm": 66 | fshape = [[ps[i], qs[i]] if not trF else [qs[i], ps[i]] for i in range(n)] 67 | elif mmtype == "kmm": 68 | fshape = [[qs[i], ps[i]] if not trF else [ps[i], qs[i]] for i in range(n)] 69 | 70 | fshape = [list(batchDimFPre) + fshape[i] for i in range(n)] 71 | 72 | zshape = list(batchDimZ) 73 | 74 | if mmtype == "mkm": 75 | zshape += [m,product(qs)] 76 | elif mmtype == "kmm": 77 | zshape += [product(qs),m] 78 | 79 | x = torch.randint(0, high=high,size=xshape, dtype=dtype).to(device) 80 | fs = [torch.randint(0, high=high,size=fshape[i], dtype=dtype).to(device)\ 81 | for i in range(n)] 82 | z = torch.randint(0,high=high, size=zshape, dtype=dtype).to(device) 83 | 84 | if trX: 85 | x = transpose(x) 86 | if trF: 87 | fs = [transpose(f) for f in fs] 88 | 89 | fs = tuple(fs) 90 | 91 | if not gradcheck: 92 | alpha = 3.0 93 | beta = 1.0 94 | if mmtype == "mkm": 95 | y = fk.gemkm(x, fs, alpha, beta, z) 96 | elif mmtype == "kmm": 97 | y = fk.gekmm(fs, x, alpha, beta, z) 98 | 99 | if x.device.type == "cuda": 100 | torch.cuda.synchronize() 101 | ref = alpha * reference(mmtype, x, fs, device) 102 | if z != None: 103 | ref += beta * z 104 | 105 | val = torch.isclose(y, ref).all().item() 106 | print(101) 107 | assert val 108 | else: 109 | x.requires_grad = True 110 | for f in fs: 111 | f.requires_grad = True 112 | if mmtype == "kmm": 113 | torch.autograd.gradcheck(fk.KMM.apply, (x,*fs), eps=1e-5, atol=1e-4) 114 | elif mmtype == "mkm": 115 | torch.autograd.gradcheck(fk.MKM.apply, (x,*fs), eps=1e-5, atol=1e-4) 116 | print(116) 117 | 118 | def device_tests(device): 119 | with torch.no_grad(): 120 | for mmtype in ["mkm", "kmm"]: 121 | run(mmtype, 16, 5, 8, 8, 
torch.float32, device, False, False) 122 | run(mmtype, 10, 5, 6, 6, torch.float32, device, True, False) 123 | 124 | run(mmtype, 16, 5, 8, 8, torch.float32, device, False, False, batchDimX=[2,], batchDimFPre=[], batchDimZ=[2,]) 125 | run(mmtype, 32, 5, 8, 8, torch.float32, device, False, False, batchDimX=[2,3], batchDimFPre=[2,3]) 126 | run(mmtype, 8, 5, 8, 8, torch.float32, device, False, False, batchDimX=[2,1,], batchDimFPre=[3,]) 127 | run(mmtype, 2, 5, 8, 8, torch.float32, device, False, False, batchDimX=[2,1,], batchDimFPre=[2,4,]) 128 | run(mmtype, 32, 4, 8, 8, torch.float32, device, False, False, batchDimX=[3,3,1,], batchDimFPre=[3,1,4,]) 129 | run(mmtype, 24, 4, 8, 8, torch.float32, device, False, False, batchDimX=[2,], batchDimFPre=[3,2,]) 130 | 131 | run(mmtype, 16, 4, 8, 8, torch.float32, device, False, False, batchDimX=[2,], batchDimFPre=[3,2,], batchDimZ=[3,1]) 132 | 133 | run(mmtype, 16, 4, 16, 8, torch.float32, device, True, True, batchDimX=[2,], batchDimFPre=[]) 134 | run(mmtype, 32, 5, 8, 8, torch.float32, device, True, True, batchDimX=[2,1,], batchDimFPre=[3,]) 135 | run(mmtype, 13, 5, 8, 8, torch.float32, device, True, True, batchDimX=[2,1,], batchDimFPre=[2,4,]) 136 | run(mmtype, 19, 3, 8, 32, torch.float32, device, True, True, batchDimX=[2,], batchDimFPre=[3,2,]) 137 | 138 | #double 139 | run(mmtype, 11, 10, 3, 3, torch.double, device, False, True) 140 | run(mmtype, 200, 2, 32, 32, torch.double, device, True, True) 141 | 142 | run(mmtype, 128, 5, 8, 8, torch.double, device, True, True, batchDimX=[2,1,], batchDimFPre=[2,4,]) 143 | 144 | #float16 145 | run(mmtype, 102, 4, 8, 8, torch.float16, device, False, False, high=2) 146 | run(mmtype, 102, 4, 8, 8, torch.float16, device, False, False, high=2, batchDimX=[2,], batchDimFPre=[]) 147 | run(mmtype, 102, 4, 8, 8, torch.float16, device, False, False, high=2, batchDimX=[2,1,], batchDimFPre=[3,]) 148 | run(mmtype, 10, 3, 16, 8, torch.float16, device, True, False, high=2) 149 | 150 | for mmtype in ["mkm", "kmm"]: 151 | run(mmtype, 5, 4, 6, 6, torch.double, device, False, True, batchDimX=[1,], batchDimFPre=[2,], gradcheck=True) 152 | run(mmtype, 5, 4, 4, 6, torch.double, device, True, True, batchDimX=[1,], batchDimFPre=[2,], gradcheck=True) 153 | 154 | def test_cuda(): 155 | if torch.cuda.is_available(): 156 | device_tests("cuda") 157 | 158 | def test_cpu(): 159 | device_tests("cpu") 160 | 161 | if __name__ == "__main__": 162 | test_cuda() 163 | test_cpu() -------------------------------------------------------------------------------- /src/kernels/kernel_opt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /** 4 | * A KMMKernel can be compiled with one or more of the following. 5 | * Each optimization optimizes the kernel for a case of problem shapes. 6 | * There are 4 optimization levels 0 to 3. Higher opt level has more 7 | * optimizations and each level add extra optimizations over the prev level. 8 | */ 9 | struct KernelOptimizations { 10 | /** 11 | * An enum of each Optimization 12 | */ 13 | enum Optimization { 14 | //No optimization, i.e. a general kernel. 15 | None = 0, 16 | //No. of slices of tile of X cols is same as the slices of kernel's tileK. 
17 | XshSlicesSame = 1 << 0, 18 | //The problem Q is a multiple of kernel's TileF.q() 19 | QMultipleOfTileQ = 1 << 1, 20 | //The problem P is a multiple of kernel's TileF.p() 21 | PMultipleOfTileP = 1 << 2, 22 | //The problem's X.cols is a multiple of kernel's TileX.n() 23 | KMultipleOfTileK = 1 << 3, 24 | //The problem's X.rows is a multiple of kernel's TileX.m() 25 | MMultipleOfTileM = 1 << 4, 26 | //The problem Q is less than kernel's TileF.q() 27 | QLeTileQ = 1 << 5, 28 | //Kernel is invoked with same TileK as the template TileK 29 | TileKSame = 1 << 6, 30 | //Problem's factor has same shape as kernel's MaxF 31 | FactorShapeSame = 1 << 7, 32 | //Number of Optimizations 33 | NumOptimizations = 1 << 8 34 | }; 35 | 36 | /** 37 | * OptLevel0() - Return a bitwise OR of optimizations at level 0. 38 | * At level 0, there are no optimization and a kernel 39 | * will run for any X and F shape. 40 | */ 41 | CUDA_DEVICE_HOST 42 | static constexpr uint OptLevel0() { 43 | return Optimization::None; 44 | } 45 | 46 | /** 47 | * OptLevel1() - Return a bitwise OR of optimizations at level 1. 48 | * At level 1, slices of tile of X cols are same as kernel's tile slices. 49 | */ 50 | CUDA_DEVICE_HOST 51 | static constexpr uint OptLevel1() { 52 | return OptLevel0() | 53 | Optimization::XshSlicesSame 54 | ; 55 | } 56 | 57 | /** 58 | * OptLevel2() - Return a bitwise OR of optimizations at level 2. 59 | * At level 2, problem's K, Q, and P must be multiple of kernel's 60 | * TileK, TileQ and TileP respectively. 61 | */ 62 | CUDA_DEVICE_HOST 63 | static constexpr uint OptLevel2() { 64 | return OptLevel1() | 65 | Optimization::KMultipleOfTileK | 66 | Optimization::QMultipleOfTileQ | 67 | Optimization::PMultipleOfTileP 68 | ; 69 | } 70 | 71 | /** 72 | * OptLevel3() - Return a bitwise OR of optimizations at level 3. 73 | * At level 3, problem's factor has same shape as kernel's, 74 | * kernel is invoked with same TileK as kernel's template, and 75 | * problem's M is multiple of TileM 76 | */ 77 | CUDA_DEVICE_HOST 78 | static constexpr uint OptLevel3() { 79 | return OptLevel2() | 80 | Optimization::FactorShapeSame | 81 | Optimization::TileKSame | 82 | Optimization::MMultipleOfTileM 83 | ; 84 | } 85 | 86 | /** 87 | * MaxOptLevel() - Return maximum optimization level, i.e. 3. 88 | */ 89 | CUDA_DEVICE_HOST 90 | static constexpr uint MaxOptLevel() { 91 | return 3; 92 | } 93 | 94 | /** 95 | * getOptimizations() - Return bitwise OR of optimizations at given level. 96 | * @optLevel: Optimization level. 97 | */ 98 | CUDA_DEVICE_HOST 99 | static constexpr uint getOptimizations(uint optLevel) { 100 | switch(optLevel) { 101 | case 0: return OptLevel0(); 102 | case 1: return OptLevel1(); 103 | case 2: return OptLevel2(); 104 | case 3: return OptLevel3(); 105 | default: 106 | return 0; 107 | } 108 | } 109 | 110 | /** 111 | * isEnabled() - Return true if an optimization is enabled in an optimization level. 112 | * @optLevel: Optimization level. 113 | * @specl: Optimization. 
114 | */ 115 | CUDA_DEVICE_HOST 116 | static constexpr bool isEnabled(uint optLevel, Optimization specl) { 117 | return (getOptimizations(optLevel) & specl) == specl; 118 | } 119 | 120 | CUDA_DEVICE_HOST 121 | static constexpr bool IsXshSlicesSame(uint optLevel) { 122 | return isEnabled(optLevel, Optimization::XshSlicesSame); 123 | } 124 | 125 | CUDA_DEVICE_HOST 126 | static constexpr bool IsQMultipleOfTileQ(uint optLevel) { 127 | return isEnabled(optLevel, Optimization::QMultipleOfTileQ); 128 | } 129 | 130 | CUDA_DEVICE_HOST 131 | static constexpr bool IsPMultipleOfTileP(uint optLevel) { 132 | return isEnabled(optLevel, Optimization::PMultipleOfTileP); 133 | } 134 | 135 | CUDA_DEVICE_HOST 136 | static constexpr bool IsKMultipleOfTileK(uint optLevel) { 137 | return isEnabled(optLevel, Optimization::KMultipleOfTileK); 138 | } 139 | 140 | CUDA_DEVICE_HOST 141 | static constexpr bool IsMMultipleOfTileM(uint optLevel) { 142 | return isEnabled(optLevel, Optimization::MMultipleOfTileM); 143 | } 144 | 145 | CUDA_DEVICE_HOST 146 | static constexpr bool IsQLeTileQ (uint optLevel) { 147 | return isEnabled(optLevel, Optimization::QLeTileQ); 148 | } 149 | 150 | CUDA_DEVICE_HOST 151 | static constexpr bool IsTileKSame (uint optLevel) { 152 | return isEnabled(optLevel, Optimization::TileKSame); 153 | } 154 | 155 | CUDA_DEVICE_HOST 156 | static constexpr bool IsFactorShapeSame (uint optLevel) { 157 | return isEnabled(optLevel, Optimization::FactorShapeSame); 158 | } 159 | }; 160 | 161 | template 162 | CUDA_DEVICE_HOST uint32_t getXshSlices(const KernelParams& params) { 163 | constexpr bool kFactorShapeSame = KernelOptimizations::IsFactorShapeSame(OptLevel); 164 | if (kFactorShapeSame) { 165 | return kTileK/kP; 166 | } else { 167 | return params.XshSlices; 168 | } 169 | } 170 | 171 | 172 | template 173 | CUDA_DEVICE_HOST uint32_t getXSlices(const Matrix& Y, const KernelParams& params) { 174 | //# of slices for a row. Same as X.n()/P but use Y.n()/Q to reduce 175 | //number of loads as store also requires reading Y.n() 176 | constexpr bool kFactorShapeSame = KernelOptimizations::IsFactorShapeSame(OptLevel); 177 | if (kFactorShapeSame) { 178 | return Y.n()/kQ; 179 | } else { 180 | return params.XSlices; 181 | } 182 | } 183 | 184 | template 185 | CUDA_DEVICE_HOST uint32_t getQThreads(uint XshSlices) { 186 | if (kXshSlicesSame) return XshSlices/RegK; 187 | return DIVUP(XshSlices, RegK); 188 | } 189 | 190 | template 191 | CUDA_DEVICE_HOST uint32_t getQByTileQ(uint Q) { 192 | if (kQLeTileQ) { 193 | return 1; 194 | } 195 | return DIVUP(Q, TileQ); 196 | } 197 | 198 | template 199 | CUDA_DEVICE_HOST uint32_t getXTileK(KernelParams& params) { 200 | constexpr bool kTileKSame = KernelOptimizations::IsTileKSame(OptLevel); 201 | if (kTileKSame) return kTileK; 202 | return params.tileX.n(); 203 | } 204 | 205 | template 206 | CUDA_DEVICE constexpr fastKronOp swapFastKronOp() { 207 | if (Op == fastKronOp_N) return fastKronOp_T; 208 | if (Op == fastKronOp_T) return fastKronOp_N; 209 | } --------------------------------------------------------------------------------
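To make the optimization-level composition in `src/kernels/kernel_opt.h` concrete, the following is a small standalone sketch of the same bitmask scheme (illustrative only; the flag names and level contents mirror `KernelOptimizations` above, but this snippet is not part of the library sources and needs C++14 for the `constexpr` function):

```cpp
#include <cstdint>

// Each level ORs new flags on top of the previous one, so a higher
// level implies every optimization of the lower levels.
enum Opt : uint32_t {
  None             = 0,
  XshSlicesSame    = 1u << 0,
  QMultipleOfTileQ = 1u << 1,
  PMultipleOfTileP = 1u << 2,
  KMultipleOfTileK = 1u << 3,
  MMultipleOfTileM = 1u << 4,
  QLeTileQ         = 1u << 5,
  TileKSame        = 1u << 6,
  FactorShapeSame  = 1u << 7,
};

constexpr uint32_t level(uint32_t l) {
  uint32_t opts = None;                                                        // level 0: fully general kernel
  if (l >= 1) opts |= XshSlicesSame;                                           // level 1
  if (l >= 2) opts |= KMultipleOfTileK | QMultipleOfTileQ | PMultipleOfTileP;  // level 2
  if (l >= 3) opts |= FactorShapeSame | TileKSame | MMultipleOfTileM;          // level 3
  return opts;
}

constexpr bool enabled(uint32_t l, Opt o) { return (level(l) & o) == o; }

// A level-2 kernel assumes K, Q, and P divide its tile sizes, but not a fixed factor shape.
static_assert(enabled(2, KMultipleOfTileK), "level 2 requires K % TileK == 0");
static_assert(!enabled(2, FactorShapeSame), "factor shape is only fixed at level 3");
// Each level is a superset of the previous one.
static_assert((level(3) & level(2)) == level(2), "levels are cumulative");

int main() { return 0; }
```

A level-0 kernel therefore runs for any X and F shape, while a level-3 kernel additionally assumes the factor shape, the template TileK, and that M is a multiple of TileM, so it is only selected when the problem matches those constraints exactly.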