├── README.md
├── benchmark.py
├── openai_gemm_pytorch.py
└── test.py

/README.md:
--------------------------------------------------------------------------------
openai-gemm.pytorch
===================

PyTorch bindings for openai-gemm.

## Installation

Clone the original openai-gemm repository and add it to your `PYTHONPATH`, then install pycuda:

```
pip install pycuda
```

and follow the installation instructions on the PyTorch website to install PyTorch itself.

No `neon` installation is needed.

## Usage

The library defines a `matmul` function similar to the one in the original openai-gemm that works with neon, except that it takes `torch.cuda.FloatTensor` or `torch.cuda.HalfTensor` matrices as A, B and C.
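A minimal usage sketch (it assumes a CUDA-capable GPU and an openai-gemm checkout on `PYTHONPATH`; the sizes are arbitrary):

```
import torch
from openai_gemm_pytorch import matmul

A = torch.randn(128, 256).cuda()
B = torch.randn(256, 64).cuda()
C = torch.zeros(128, 64).cuda()

# C = alpha * A . B + beta * C, computed in place on the GPU
matmul(A, B, C)
matmul(A, B, C, alpha=2.0, beta=1.0)

# Transposed operands are passed as .t() views, e.g. A.T . B:
At = torch.randn(256, 128).cuda()
matmul(At.t(), B, C)
```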
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import pycuda.autoinit
import numpy as np
import pycuda.driver as drv
import torch

from openai_gemm_pytorch import matmul

print(drv.Context.get_current().get_device().name())

config = (
    # m, n, k, AT, BT (row order)
    (   16, 1760, 1760, False, False),
    (   32, 1760, 1760, False, False),
    (   64, 1760, 1760, False, False),
    (  128, 1760, 1760, False, False),
    ( 7000, 1760, 1760, False, False),
    (   16, 2048, 2048, False, False),
    (   32, 2048, 2048, False, False),
    (   64, 2048, 2048, False, False),
    (  128, 2048, 2048, False, False),
    ( 7000, 2048, 2048, False, False),
    (   16, 2560, 2560, False, False),
    (   32, 2560, 2560, False, False),
    (   64, 2560, 2560, False, False),
    (  128, 2560, 2560, False, False),
    ( 7000, 2560, 2560, False, False),
    (   16, 4096, 4096, False, False),
    (   32, 4096, 4096, False, False),
    (   64, 4096, 4096, False, False),
    (  128, 4096, 4096, False, False),
    ( 7000, 4096, 4096, False, False),
    (   16, 1760, 1760, False, True),
    (   32, 1760, 1760, False, True),
    (   64, 1760, 1760, False, True),
    (  128, 1760, 1760, False, True),
    ( 7000, 1760, 1760, False, True),
    (   16, 2048, 2048, False, True),
    (   32, 2048, 2048, False, True),
    (   64, 2048, 2048, False, True),
    (  128, 2048, 2048, False, True),
    ( 7000, 2048, 2048, False, True),
    (   16, 2560, 2560, False, True),
    (   32, 2560, 2560, False, True),
    (   64, 2560, 2560, False, True),
    (  128, 2560, 2560, False, True),
    ( 7000, 2560, 2560, False, True),
    (   16, 4096, 4096, False, True),
    (   32, 4096, 4096, False, True),
    (   64, 4096, 4096, False, True),
    (  128, 4096, 4096, False, True),
    ( 7000, 4096, 4096, False, True),
    ( 7133, 1760, 1760, True , False),
    ( 7133, 2048, 2048, True , False),
    ( 7133, 2560, 2560, True , False),
    ( 7133, 4096, 4096, True , False),
    ( 9124, 5124, 1760, False, False),
    ( 9124, 5124, 2048, False, False),
    ( 9124, 5124, 2560, False, False),
    ( 9124, 5124, 4096, False, False),
    ( 9124, 5124, 1760, False, True),
    ( 9124, 5124, 2048, False, True),
    ( 9124, 5124, 2560, False, True),
    ( 9124, 5124, 4096, False, True),
    ( 8457,   35, 1760, False, False),
    ( 8457,   35, 2048, False, False),
    ( 8457,   35, 2560, False, False),
    ( 8457,   35, 4096, False, False),
    ( 8457,   35, 1760, False, True),
    ( 8457,   35, 2048, False, True),
    ( 8457,   35, 2560, False, True),
    ( 8457,   35, 4096, False, True),
    (   16, 7680, 2560, False, False),
    (   32, 7680, 2560, False, False),
    (   64, 7680, 2560, False, False),
    (  128, 7680, 2560, False, False),
    (   16, 7680, 2560, False, True),
    (   32, 7680, 2560, False, True),
    (   64, 7680, 2560, False, True),
    (  128, 7680, 2560, False, True),
    (   16, 3072, 1024, False, False),
    (   32, 3072, 1024, False, False),
    (   64, 3072, 1024, False, False),
    (  128, 3072, 1024, False, False),
    (   16, 3072, 1024, False, True),
    (   32, 3072, 1024, False, True),
    (   64, 3072, 1024, False, True),
    (  128, 3072, 1024, False, True),
    ( 7435, 3072, 1024, True , False),
    ( 5481, 7680, 2560, True , False),

    # (60000,   32,   32, True , False),
    # (60000,  256,  256, True , False),

    # ( 4096, 4096,   32, True , False),
    # ( 3456, 3456,   32, True , False),
    # (  896,  896,   32, True , False),
)

print("|     M|     N|     K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16|")
print("|------|------|------|---|---------|---------|--------|---------|---------|--------|")

for m, n, k, at, bt in config:

    dimA = (k, m) if at else (m, k)
    dimB = (n, k) if bt else (k, n)
    dimC = (m, n)

    opA = 'T' if at else 'N'
    opB = 'T' if bt else 'N'
    op = opA + opB

    dtype_data = list()

    for dtype in ('torch.cuda.FloatTensor', 'torch.cuda.HalfTensor'):  # np.float32, np.float16

        A = torch.randn(dimA).type(dtype)
        B = torch.randn(dimB).type(dtype)
        C = torch.randn(dimC).type(dtype)

        if at: A = A.t()
        if bt: B = B.t()

        # bench=True returns timing data for every available tile size plus cublas
        data = matmul(A, B, C, bench=True)

        cublas = data.pop()
        openai = sorted(data)[0]

        text = "%9.0f|%9.0f|%8.1f" % (openai[1], cublas[1], openai[1] / cublas[1])

        dtype_data.append(text)

    print("|%6d|%6d|%6d|%3s|%s|" % (m, n, k, op, "|".join(dtype_data)))
--------------------------------------------------------------------------------
/openai_gemm_pytorch.py:
--------------------------------------------------------------------------------
import pycuda.autoinit
import torch
from openai_gemm import _get_gemm_kernel, _get_bench_data


def is_transposed(t):
    # A 2D CUDA tensor is treated as transposed when its rows are not
    # contiguous in memory (e.g. it is the view produced by .t()).
    return not (t.stride(1) == 1 and t.stride(0) != 0)


def matmul(A, B, C, alpha=1.0, beta=0.0, stream=None, bench=False):
    """
    C = alpha * A   . B   + beta * C
    C = alpha * A.T . B   + beta * C
    C = alpha * A   . B.T + beta * C
    C = alpha * A.T . B.T + beta * C

    bench: return benchmark data for all available tiles + cublas
    """

    if isinstance(C, torch.cuda.FloatTensor):
        prefix = "s"
    elif isinstance(C, torch.cuda.HalfTensor):
        prefix = "h"
    else:
        raise TypeError("Only floating point dot currently supported.")

    # (m,n) = (m,k) . (k,n)
    m = A.size(0)
    n = B.size(1)
    k = A.size(1)
    assert m == C.size(0)
    assert n == C.size(1)
    assert k == B.size(0)

    # Extract the operations and contiguous dimension sizes (cda, cdb, cdc).
    # These match the corresponding shape dimensions unless the
    # non-contiguous dimension has been sliced.
    # One dimension must be contiguous (DRAM efficiency demands this).
    # Note that the strides here are counted in elements, not in bytes
    # as they would be in numpy.
    # A transpose (.t()) on a torch tensor swaps the shape and strides;
    # the underlying data is unchanged.
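    # Example: a contiguous 5x4 FloatTensor has shape (5, 4) and strides (4, 1);
    # after .t() it is a (4, 5) view with strides (1, 4), so is_transposed()
    # returns True and its leading (contiguous) dimension is stride(1) == 4.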
    if is_transposed(A):
        opA = 'T'
        cda = A.stride(1)
        assert A.stride(0) == 1
    else:
        opA = 'N'
        cda = A.stride(0)
        assert A.stride(1) == 1

    if is_transposed(B):
        opB = 'T'
        cdb = B.stride(1)
        assert B.stride(0) == 1
    else:
        opB = 'N'
        cdb = B.stride(0)
        assert B.stride(1) == 1

    cdc = C.stride(0)
    assert C.stride(1) == 1

    op = opA + opB

    # get and autotune the kernel selection
    kernel, params, dynamic_shared = _get_gemm_kernel(prefix, op, cda, cdb, cdc, m, n, k)

    # bind dynamic params
    params[2:8] = (stream, C.data_ptr(), A.data_ptr(), B.data_ptr(), alpha, beta)

    # call the kernel
    kernel.prepared_async_call(*params, shared_size=dynamic_shared)

    # unbind dynamic params
    params[2:8] = (None,) * 6

    # return benchmark data if requested
    if bench:
        return _get_bench_data()[(prefix, op, cda, cdb, cdc, m, n, k)]

    return C
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import unittest
import torch
from openai_gemm_pytorch import matmul


class TestMatMul(unittest.TestCase):
    def testNN(self):
        a = torch.randn(5, 4).cuda()
        b = torch.randn(4, 7).cuda()

        c = torch.Tensor(5, 7).cuda()
        matmul(a, b, c)

        self.assertLess((c - a.mm(b)).abs().max(), 1e-6)

    def testNT(self):
        a = torch.randn(5, 4).cuda()
        b = torch.randn(7, 4).cuda().t()

        c = torch.Tensor(5, 7).cuda()
        matmul(a, b, c)

        self.assertLess((c - a.mm(b)).abs().max(), 1e-6)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------