├── .gitignore ├── README.md └── learn_ptx ├── __init__.py ├── context.py ├── elemwise.py ├── fps.py ├── kernels ├── elemwise_sqrt.ptx ├── fps_block.ptx ├── fps_block_v2.ptx ├── matmul_big_blocks.ptx ├── matmul_big_blocks_v2.ptx ├── matmul_big_blocks_v3.ptx ├── matmul_big_blocks_v4.ptx ├── matmul_big_blocks_v5.ptx ├── matmul_inner_loop.ptx ├── matmul_simple_block_v1.ptx ├── matmul_simple_block_v2.ptx ├── matmul_simple_block_v3.ptx ├── matmul_simple_block_v4.ptx ├── matmul_wmma_v1.ptx ├── matmul_wmma_v10.ptx ├── matmul_wmma_v2.ptx ├── matmul_wmma_v3.ptx ├── matmul_wmma_v4.ptx ├── matmul_wmma_v5.ptx ├── matmul_wmma_v6.ptx ├── matmul_wmma_v7.ptx ├── matmul_wmma_v8.ptx ├── matmul_wmma_v9.ptx ├── reduction_all_max_naive.ptx ├── reduction_all_max_naive_opt.ptx ├── reduction_all_max_naive_opt_flexible.ptx ├── reduction_all_max_naive_opt_flexible_novec.ptx ├── reduction_all_max_naive_opt_flexible_sin.ptx ├── reduction_all_max_naive_opt_flexible_sin_cpasync.ptx ├── reduction_all_max_naive_opt_flexible_widevec.ptx ├── reduction_all_max_naive_opt_novec.ptx ├── reduction_bool_naive.ptx ├── reduction_bool_warp.ptx ├── reduction_bool_warp_vec.ptx ├── reduction_trans_bool_blocked.ptx ├── reduction_trans_bool_naive.ptx ├── sort_bitonic_block.ptx ├── sort_bitonic_block_v2.ptx ├── sort_bitonic_global.ptx ├── sort_bitonic_global_v2.ptx ├── sort_bitonic_warp.ptx ├── sort_bitonic_warp_v2.ptx ├── sort_bitonic_warp_v3.ptx └── sort_merge_global.ptx ├── matmul.py ├── reduction.py └── sort.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # learn-ptx 2 | 3 | This is a collection of hand-coded PTX kernels that I'm writing while learning low-level CUDA programming. 
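Each kernel has a small Python driver built on `pycuda`. Assuming `pycuda`, `numpy`, and a CUDA-capable GPU are available, an example can be run as a module (e.g. `python -m learn_ptx.elemwise`) or from a sketch like this:

```python
# Minimal run of one example (assumes pycuda + numpy and a CUDA GPU).
from learn_ptx.elemwise import sqrt_example

sqrt_example()  # loads kernels/elemwise_sqrt.ptx and compares the result against np.sqrt
```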
4 | -------------------------------------------------------------------------------- /learn_ptx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/learn-ptx/be2e59877d6fce9cd43c149ebfd00b6dc52162ac/learn_ptx/__init__.py -------------------------------------------------------------------------------- /learn_ptx/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | from functools import lru_cache 4 | from typing import Callable, Sequence, Union 5 | 6 | import numpy as np 7 | import pycuda.autoinit 8 | from pycuda.compiler import DynamicModule 9 | from pycuda.driver import ( 10 | Context, 11 | DeviceAllocation, 12 | Event, 13 | from_device, 14 | jit_input_type, 15 | to_device, 16 | ) 17 | 18 | KERNEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kernels") 19 | 20 | 21 | @lru_cache(maxsize=8) 22 | def compile_function(code_filename: str, function_name: str) -> Callable: 23 | module = DynamicModule() 24 | with open(os.path.join(KERNEL_DIR, code_filename), "rb") as f: 25 | module.add_data(f.read(), jit_input_type.PTX, name="kernel.ptx") 26 | module.link() 27 | return module.get_function(function_name) 28 | 29 | 30 | def numpy_to_gpu(arr: np.ndarray) -> DeviceAllocation: 31 | return to_device(arr) 32 | 33 | 34 | def gpu_to_numpy( 35 | allocation: DeviceAllocation, shape: Sequence[int], dtype: Union[np.dtype, str] 36 | ) -> np.ndarray: 37 | return from_device(allocation, shape, dtype) 38 | 39 | 40 | def sync(): 41 | Context.synchronize() 42 | 43 | 44 | @contextmanager 45 | def measure_time() -> Callable[[], float]: 46 | start = Event() 47 | end = Event() 48 | 49 | start.record() 50 | start.synchronize() 51 | 52 | def delay_fn() -> float: 53 | return start.time_till(end) / 1000 54 | 55 | yield delay_fn 56 | 57 | end.record() 58 | end.synchronize() 59 | -------------------------------------------------------------------------------- /learn_ptx/elemwise.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import numpy as np 4 | 5 | from .context import compile_function, gpu_to_numpy, numpy_to_gpu, sync 6 | 7 | 8 | def sqrt_example(): 9 | fn = compile_function("elemwise_sqrt.ptx", "sqrtElements") 10 | inputs = np.abs(np.random.normal(size=[10000]).astype(np.float32)) 11 | input_buf = numpy_to_gpu(inputs) 12 | output_buf = numpy_to_gpu(inputs) 13 | block_size = 1024 14 | fn( 15 | input_buf, 16 | output_buf, 17 | np.int32(len(inputs) - 10), 18 | grid=(ceil(len(inputs) / block_size), 1, 1), 19 | block=(block_size, 1, 1), 20 | ) 21 | sync() 22 | results = gpu_to_numpy(output_buf, inputs.shape, inputs.dtype) 23 | expected = np.sqrt(inputs) 24 | print( 25 | f"maximum absolute error of sqrt is {np.abs(results[:-10] - expected[:-10]).max()}" 26 | ) 27 | print( 28 | f"maximum absolute error of masked is {np.abs(results[-10:] - inputs[-10:]).max()}" 29 | ) 30 | 31 | 32 | if __name__ == "__main__": 33 | sqrt_example() 34 | -------------------------------------------------------------------------------- /learn_ptx/fps.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | from .context import compile_function, gpu_to_numpy, measure_time, numpy_to_gpu, sync 6 | 7 | 8 | def fps_block(): 9 | fn = compile_function("fps_block.ptx", "farthestPointSampleBlock") 10 | inputs = 
np.random.normal(size=[2**16, 3]).astype(np.float32) 11 | outputs = np.zeros([4096, 3]).astype(np.float32) 12 | tmp = np.zeros([len(inputs)]).astype(np.float32) 13 | input_buf = numpy_to_gpu(inputs) 14 | output_buf = numpy_to_gpu(outputs) 15 | tmp_buf = numpy_to_gpu(tmp) 16 | with measure_time() as timer: 17 | fn( 18 | input_buf, 19 | tmp_buf, 20 | output_buf, 21 | np.int64(len(inputs)), 22 | np.int64(len(outputs)), 23 | grid=(1, 1, 1), 24 | block=(1024, 1, 1), 25 | ) 26 | sync() 27 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 28 | print(f"took {timer():.05f} seconds on GPU") 29 | t1 = time.time() 30 | expected = fps_on_cpu(inputs, len(outputs)) 31 | t2 = time.time() 32 | print(f"took {(t2 - t1):.05f} seconds on CPU") 33 | print(f"maximum absolute error is {np.abs(results - expected).max()}") 34 | print("results", results[:4]) 35 | print("expected", expected[:4]) 36 | 37 | 38 | def fps_block_v2(): 39 | fn = compile_function("fps_block_v2.ptx", "farthestPointSampleBlockV2") 40 | inputs = np.random.normal(size=[2**16, 3]).astype(np.float32) 41 | outputs = np.zeros([4096, 3]).astype(np.float32) 42 | tmp = np.zeros([len(inputs)]).astype(np.float32) 43 | input_buf = numpy_to_gpu(inputs) 44 | output_buf = numpy_to_gpu(outputs) 45 | tmp_buf = numpy_to_gpu(tmp) 46 | with measure_time() as timer: 47 | fn( 48 | input_buf, 49 | tmp_buf, 50 | output_buf, 51 | np.int64(len(inputs)), 52 | np.int64(len(outputs)), 53 | grid=(1, 1, 1), 54 | block=(1024, 1, 1), 55 | ) 56 | sync() 57 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 58 | print(f"took {timer():.05f} seconds on GPU") 59 | t1 = time.time() 60 | expected = fps_on_cpu(inputs, len(outputs)) 61 | t2 = time.time() 62 | print(f"took {(t2 - t1):.05f} seconds on CPU") 63 | print(f"maximum absolute error is {np.abs(results - expected).max()}") 64 | print("results", results[:4]) 65 | print("expected", expected[:4]) 66 | 67 | 68 | def fps_on_cpu(points: np.ndarray, n: int) -> np.ndarray: 69 | results = np.zeros([n, 3], dtype=points.dtype) 70 | results[0] = points[0] 71 | dists = ((points - points[0]) ** 2).sum(-1) 72 | dists[0] = -1 73 | for i in range(1, n): 74 | idx = np.argmax(dists) 75 | point = points[idx] 76 | dists = np.minimum(dists, ((points - point) ** 2).sum(-1)) 77 | dists[idx] = -1 78 | results[i] = point 79 | return results 80 | 81 | 82 | if __name__ == "__main__": 83 | fps_block_v2() 84 | -------------------------------------------------------------------------------- /learn_ptx/kernels/elemwise_sqrt.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry sqrtElements( 6 | .param .u64 inputPtr, 7 | .param .u64 outputPtr, 8 | .param .u32 n 9 | ) { 10 | .reg .pred %p1; 11 | .reg .u64 %addr; 12 | .reg .u32 %tmp<2>; 13 | .reg .u64 %offset; 14 | .reg .f32 %val; 15 | 16 | // Compute the offset as ctaid.x*ntid.x + tid.x 17 | mov.u32 %tmp0, %ctaid.x; 18 | mov.u32 %tmp1, %ntid.x; 19 | mul.lo.u32 %tmp0, %tmp0, %tmp1; 20 | mov.u32 %tmp1, %tid.x; 21 | add.u32 %tmp1, %tmp0, %tmp1; 22 | cvt.u64.u32 %offset, %tmp1; 23 | mul.lo.u64 %offset, %offset, 4; 24 | 25 | // Mask out out-of-bounds accesses. 26 | ld.param.u32 %tmp0, [n]; 27 | setp.lt.u32 %p1, %tmp1, %tmp0; 28 | 29 | // Load the value from memory. 30 | ld.param.u64 %addr, [inputPtr]; 31 | add.u64 %addr, %addr, %offset; 32 | @%p1 ld.global.f32 %val, [%addr]; 33 | 34 | // Element-wise operation itself. 
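// sqrt.approx.f32 is the fast approximate square root; sqrt.rn.f32 would give
// the correctly rounded IEEE result at higher cost.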
35 | @%p1 sqrt.approx.f32 %val, %val; 36 | 37 | // Store back the output. 38 | ld.param.u64 %addr, [outputPtr]; 39 | add.u64 %addr, %addr, %offset; 40 | @%p1 st.global.f32 [%addr], %val; 41 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/fps_block_v2.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Perform farthest point sampling with a single block. 6 | // Takes buffer of points, temporary buffer of distances, 7 | // and output pointer. 8 | // 9 | // Similar to fps_block.ptx, except uses fewer memory 10 | // accesses and is overall less redundant. 11 | 12 | .visible .entry farthestPointSampleBlockV2 ( 13 | .param .u64 ptrIn, 14 | .param .u64 tmpBuffer, 15 | .param .u64 ptrOut, 16 | .param .u64 inCount, 17 | .param .u64 outCount 18 | ) { 19 | .reg .pred %p0; 20 | 21 | // Arguments 22 | .reg .u64 %ptrIn; 23 | .reg .u64 %tmpBuffer; 24 | .reg .u64 %ptrOut; 25 | .reg .u64 %inCount; 26 | .reg .u64 %outCount; 27 | .reg .u64 %tidX; 28 | .reg .u64 %blockSize; 29 | 30 | // Buffers for communicating indices across warps. 31 | .shared .align 4 .f32 largestDist[32]; 32 | .shared .align 8 .u64 largestIndex[32]; 33 | 34 | // Load arguments and thread properties. 35 | ld.param.u64 %ptrIn, [ptrIn]; 36 | ld.param.u64 %tmpBuffer, [tmpBuffer]; 37 | ld.param.u64 %ptrOut, [ptrOut]; 38 | ld.param.u64 %inCount, [inCount]; 39 | ld.param.u64 %outCount, [outCount]; 40 | 41 | cvt.u64.u32 %tidX, %tid.x; 42 | cvt.u64.u32 %blockSize, %ntid.x; 43 | 44 | // Initialize the distance buffer to infinity. 45 | { 46 | .reg .u64 %i; 47 | .reg .u64 %offset; 48 | .reg .u64 %addr; 49 | .reg .f32 %inf; 50 | 51 | div.approx.f32 %inf, 1.0, 0.0; 52 | 53 | mov.u64 %i, 0; 54 | init_loop: 55 | add.u64 %offset, %i, %tidX; 56 | setp.lt.u64 %p0, %offset, %inCount; 57 | mad.lo.u64 %addr, %offset, 4, %tmpBuffer; 58 | @%p0 st.global.f32 [%addr], %inf; 59 | 60 | add.u64 %i, %i, %blockSize; 61 | setp.lt.u64 %p0, %i, %inCount; 62 | @%p0 bra.uni init_loop; 63 | init_loop_end: 64 | } 65 | 66 | // Make the first distance negative so we never pick this point again. 67 | setp.eq.u64 %p0, %tidX, 0; 68 | @%p0 st.global.f32 [%tmpBuffer], -1.0; 69 | 70 | // Loop until we have selected enough points. 71 | { 72 | .reg .u64 %i; 73 | .reg .u64 %nextIndex; 74 | mov.u64 %nextIndex, 0; 75 | mov.u64 %i, 0; 76 | loop_start: 77 | // Both this and %nextIndex will be updated as we compute 78 | // new distances so we can find the next point. 79 | .reg .f32 %localMaxDist; 80 | mov.f32 %localMaxDist, -1.0; 81 | { 82 | // Read the point on all ranks. 83 | .reg .f32 %nextPointX; 84 | .reg .f32 %nextPointY; 85 | .reg .f32 %nextPointZ; 86 | 87 | { 88 | .reg .u64 %nextPointAddr; 89 | mad.lo.u64 %nextPointAddr, %nextIndex, 12, %ptrIn; 90 | ldu.global.f32 %nextPointX, [%nextPointAddr]; 91 | ldu.global.f32 %nextPointY, [%nextPointAddr+4]; 92 | ldu.global.f32 %nextPointZ, [%nextPointAddr+8]; 93 | } 94 | 95 | // Write output from the first thread. 96 | { 97 | .reg .u64 %nextOutput; 98 | mad.lo.u64 %nextOutput, %i, 12, %ptrOut; 99 | setp.eq.u64 %p0, %tidX, 0; 100 | @%p0 st.global.f32 [%nextOutput], %nextPointX; 101 | @%p0 st.global.f32 [%nextOutput+4], %nextPointY; 102 | @%p0 st.global.f32 [%nextOutput+8], %nextPointZ; 103 | } 104 | 105 | // Compute new distances and take the minimum. 106 | // Also update %nextIndex and %localMaxDist. 
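// Each thread strides through the points in steps of blockSize, so the 32
// threads of a warp read 32 consecutive points per iteration.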
107 | { 108 | .reg .pred %p1; 109 | .reg .u64 %j; 110 | .reg .f32 %tmpPointX; 111 | .reg .f32 %tmpPointY; 112 | .reg .f32 %tmpPointZ; 113 | .reg .u64 %curIn; 114 | .reg .u64 %curOut; 115 | .reg .u64 %strideIn; 116 | .reg .u64 %strideOut; 117 | .reg .f32 %ftmp<3>; 118 | .reg .u64 %prevIndex; 119 | .reg .u64 %localOffset; 120 | 121 | // We may modify %nextIndex as we find the 122 | // new maximum distance. 123 | mov.u64 %prevIndex, %nextIndex; 124 | 125 | mad.lo.u64 %curIn, %tidX, 12, %ptrIn; 126 | mad.lo.u64 %curOut, %tidX, 4, %tmpBuffer; 127 | mul.lo.u64 %strideIn, %blockSize, 12; 128 | shl.b64 %strideOut, %blockSize, 2; 129 | 130 | mov.u64 %j, 0; 131 | update_distances_loop: 132 | add.u64 %localOffset, %j, %tidX; 133 | setp.lt.u64 %p0, %localOffset, %inCount; 134 | setp.eq.u64 %p1, %localOffset, %prevIndex; 135 | 136 | @%p0 ld.global.f32 %tmpPointX, [%curIn]; 137 | @%p0 ld.global.f32 %tmpPointY, [%curIn+4]; 138 | @%p0 ld.global.f32 %tmpPointZ, [%curIn+8]; 139 | 140 | // Squared euclidean distance. 141 | sub.f32 %ftmp0, %nextPointX, %tmpPointX; 142 | sub.f32 %ftmp1, %nextPointY, %tmpPointY; 143 | sub.f32 %ftmp2, %nextPointZ, %tmpPointZ; 144 | mul.f32 %ftmp0, %ftmp0, %ftmp0; 145 | fma.rn.f32 %ftmp0, %ftmp1, %ftmp1, %ftmp0; 146 | fma.rn.f32 %ftmp0, %ftmp2, %ftmp2, %ftmp0; 147 | @%p0 ld.global.f32 %ftmp1, [%curOut]; 148 | @%p1 mov.f32 %ftmp1, -1.0; // store -1 at the last used point 149 | min.f32 %ftmp2, %ftmp0, %ftmp1; 150 | @%p0 st.global.f32 [%curOut], %ftmp2; 151 | 152 | // Update the distance/index for the next point. 153 | setp.gt.f32 %p1, %ftmp2, %localMaxDist; 154 | @%p1 mov.f32 %localMaxDist, %ftmp2; 155 | @%p1 mov.u64 %nextIndex, %localOffset; 156 | 157 | add.u64 %curIn, %curIn, %strideIn; 158 | add.u64 %curOut, %curOut, %strideOut; 159 | add.u64 %j, %j, %blockSize; 160 | setp.lt.u64 %p0, %j, %inCount; 161 | @%p0 bra.uni update_distances_loop; 162 | update_distances_loop_end: 163 | } 164 | } 165 | 166 | // Find the maximum distance across the entire block, 167 | // to figure out the next point to choose. 168 | { 169 | // Reduce across this warp. 170 | { 171 | .reg .u32 %xorMask; 172 | .reg .f32 %otherMaxDist; 173 | .reg .u32 %otherIndex<2>; 174 | .reg .u64 %otherIndex; 175 | mov.u32 %xorMask, 1; 176 | reduction_loop: 177 | shfl.sync.bfly.b32 %otherMaxDist, %localMaxDist, %xorMask, 0x1f, 0xffffffff; 178 | mov.b64 {%otherIndex1, %otherIndex0}, %nextIndex; 179 | shfl.sync.bfly.b32 %otherIndex0, %otherIndex0, %xorMask, 0x1f, 0xffffffff; 180 | shfl.sync.bfly.b32 %otherIndex1, %otherIndex1, %xorMask, 0x1f, 0xffffffff; 181 | mov.b64 %otherIndex, {%otherIndex1, %otherIndex0}; 182 | 183 | // Keep other value if it's greater or if it has a lower 184 | // index and is equal. 185 | setp.eq.f32 %p0, %localMaxDist, %otherMaxDist; 186 | setp.lt.and.u64 %p0, %otherIndex, %nextIndex, %p0; 187 | setp.gt.or.f32 %p0, %otherMaxDist, %localMaxDist, %p0; 188 | 189 | @%p0 mov.u64 %nextIndex, %otherIndex; 190 | @%p0 mov.f32 %localMaxDist, %otherMaxDist; 191 | shl.b32 %xorMask, %xorMask, 1; 192 | setp.lt.u32 %p0, %xorMask, 32; 193 | @%p0 bra.uni reduction_loop; 194 | reduction_loop_end: 195 | } 196 | 197 | // Write each warp's maximum to shared memory. 198 | .reg .u32 %sharedAddr; 199 | .reg .u32 %warpIndex; 200 | .reg .u32 %threadInWarp; 201 | cvt.u32.u64 %warpIndex, %tidX; 202 | and.b32 %threadInWarp, %warpIndex, 0x1f; 203 | shr.b32 %warpIndex, %warpIndex, 5; 204 | 205 | // Write one output from the first thread of each warp. 
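// largestDist/largestIndex hold one slot per warp; a 1024-thread block has at
// most 32 warps, matching the [32] shared declarations above.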
206 | setp.eq.u32 %p0, %threadInWarp, 0; 207 | mov.u32 %sharedAddr, largestDist; 208 | mad.lo.u32 %sharedAddr, %warpIndex, 4, %sharedAddr; 209 | @%p0 st.shared.f32 [%sharedAddr], %localMaxDist; 210 | mov.u32 %sharedAddr, largestIndex; 211 | mad.lo.u32 %sharedAddr, %warpIndex, 8, %sharedAddr; 212 | @%p0 st.shared.u64 [%sharedAddr], %nextIndex; 213 | 214 | // Make sure all writes are finished. 215 | bar.sync 0; 216 | 217 | // Read the entire shared buffer on each warp. 218 | // Each warp will now have an independent copy of the 219 | // exact same values to reduce. 220 | mov.u32 %sharedAddr, largestDist; 221 | mad.lo.u32 %sharedAddr, %threadInWarp, 4, %sharedAddr; 222 | ld.shared.f32 %localMaxDist, [%sharedAddr]; 223 | mov.u32 %sharedAddr, largestIndex; 224 | mad.lo.u32 %sharedAddr, %threadInWarp, 8, %sharedAddr; 225 | ld.shared.u64 %nextIndex, [%sharedAddr]; 226 | 227 | // Perform another reduction across the warp. 228 | // Exact copy of the above reduction loop. 229 | { 230 | .reg .u32 %xorMask; 231 | .reg .f32 %otherMaxDist; 232 | .reg .u32 %otherIndex<2>; 233 | .reg .u64 %otherIndex; 234 | mov.u32 %xorMask, 1; 235 | reduction_loop_1: 236 | shfl.sync.bfly.b32 %otherMaxDist, %localMaxDist, %xorMask, 0x1f, 0xffffffff; 237 | mov.b64 {%otherIndex1, %otherIndex0}, %nextIndex; 238 | shfl.sync.bfly.b32 %otherIndex0, %otherIndex0, %xorMask, 0x1f, 0xffffffff; 239 | shfl.sync.bfly.b32 %otherIndex1, %otherIndex1, %xorMask, 0x1f, 0xffffffff; 240 | mov.b64 %otherIndex, {%otherIndex1, %otherIndex0}; 241 | 242 | // Keep other value if it's greater or if it has a lower 243 | // index and is equal. 244 | setp.eq.f32 %p0, %localMaxDist, %otherMaxDist; 245 | setp.lt.and.u64 %p0, %otherIndex, %nextIndex, %p0; 246 | setp.gt.or.f32 %p0, %otherMaxDist, %localMaxDist, %p0; 247 | 248 | @%p0 mov.u64 %nextIndex, %otherIndex; 249 | @%p0 mov.f32 %localMaxDist, %otherMaxDist; 250 | shl.b32 %xorMask, %xorMask, 1; 251 | setp.lt.u32 %p0, %xorMask, 32; 252 | @%p0 bra.uni reduction_loop_1; 253 | reduction_loop_1_end: 254 | } 255 | } 256 | 257 | // Make sure all writes are visible past this point. 258 | bar.sync 0; 259 | membar.cta; 260 | 261 | add.u64 %i, %i, 1; 262 | setp.lt.u64 %p0, %i, %outCount; 263 | @%p0 bra.uni loop_start; 264 | loop_end: 265 | } 266 | 267 | ret; 268 | } 269 | -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_inner_loop.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry simpleMatmul ( 6 | .param .u64 ptrA, 7 | .param .u64 ptrB, 8 | .param .u64 ptrOut, 9 | .param .u32 numBlocks 10 | ) { 11 | .reg .pred %p0; 12 | .reg .u64 %dtmp<2>; 13 | .reg .u32 %stmp<2>; 14 | 15 | // Attributes of the thread/CTA. 16 | .reg .u32 %blockSize; 17 | .reg .u32 %tidX; 18 | .reg .u32 %tidY; 19 | 20 | .reg .u64 %offsetX; 21 | .reg .u64 %offsetY; 22 | .reg .u64 %stride; 23 | .reg .u32 %i; 24 | .reg .u32 %numIters; 25 | .reg .u32 %numBlocks; 26 | .reg .u64 %ptrA; 27 | .reg .u64 %ptrB; 28 | .reg .u64 %ptrOut; 29 | .reg .f32 %acc; 30 | .reg .f32 %val<2>; 31 | 32 | ld.param.u32 %numBlocks, [numBlocks]; 33 | ld.param.u64 %ptrA, [ptrA]; 34 | ld.param.u64 %ptrB, [ptrB]; 35 | ld.param.u64 %ptrOut, [ptrOut]; 36 | 37 | mov.u32 %blockSize, %ntid.x; 38 | mov.u32 %tidX, %tid.x; 39 | mov.u32 %tidY, %tid.y; 40 | 41 | // For computing offsetX, offsetY, and stride, we use 42 | // %dtmp0 to store a 64-bit version of %blockSize. 
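// (A, B, and the output are all square with side blockSize*numBlocks, so a
// single stride value serves every pointer computed below.)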
43 | cvt.u64.u32 %dtmp0, %blockSize; // %dtmp0 = %blockSize 44 | 45 | // Compute offsets in the output matrix. 46 | // offsetX = ctaid.x * ntid.x = ctaid.x * blockSize 47 | cvt.u64.u32 %offsetX, %ctaid.x; 48 | mul.lo.u64 %offsetX, %offsetX, %dtmp0; 49 | // offsetY = ctaid.y * ntid.y = ctaid.y * blockSize 50 | cvt.u64.u32 %offsetY, %ctaid.y; 51 | mul.lo.u64 %offsetY, %offsetY, %dtmp0; 52 | 53 | // Stride is blockSize * numBlocks; 54 | cvt.u64.u32 %stride, %numBlocks; 55 | mul.lo.u64 %stride, %stride, %dtmp0; 56 | 57 | // We will accumulate into this register. 58 | mov.f32 %acc, 0.0; 59 | 60 | // We will always read from A in order starting at 61 | // stride*(offsetY+tid.y) and going forward one element 62 | // per inner-loop iteration. 63 | cvt.u64.u32 %dtmp0, %tidY; 64 | add.u64 %dtmp0, %dtmp0, %offsetY; 65 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 66 | shl.b64 %dtmp0, %dtmp0, 2; 67 | add.u64 %ptrA, %ptrA, %dtmp0; 68 | 69 | // We will calculate our block offset in B in %ptrB as 70 | // (offsetX + tid.x, i*ntid.y+j) 71 | // = offsetX + tid.x + (i*blockSize+j)*stride 72 | cvt.u64.u32 %dtmp1, %tidX; 73 | add.u64 %dtmp0, %dtmp1, %offsetX; 74 | shl.b64 %dtmp0, %dtmp0, 2; 75 | add.u64 %ptrB, %ptrB, %dtmp0; 76 | 77 | // Set %dtmp1 to stride in B. 78 | // Stride in ptrB is stride*4 79 | shl.b64 %dtmp1, %stride, 2; 80 | 81 | mul.lo.u32 %numIters, %blockSize, %numBlocks; 82 | mov.u32 %i, 0; 83 | loop_start: 84 | // We are iterating through the entire row of A sequentially. 85 | ld.global.f32 %val0, [%ptrA]; 86 | add.u64 %ptrA, %ptrA, 4; 87 | 88 | // Iterate through the entire column of B sequentially. 89 | ld.global.f32 %val1, [%ptrB]; 90 | add.u64 %ptrB, %ptrB, %dtmp1; 91 | 92 | // This will be optimized to a fused operation. 93 | mul.f32 %val1, %val0, %val1; 94 | add.f32 %acc, %acc, %val1; 95 | 96 | // i += 1; loop while i < numBlocks 97 | add.u32 %i, %i, 1; 98 | setp.lt.u32 %p0, %i, %numIters; 99 | @%p0 bra loop_start; 100 | 101 | loop_end: 102 | // Write back to output memory. 
103 | 104 | // Output address is offsetX+tid.x + stride*(offsetY+tid.y) 105 | cvt.u64.u32 %dtmp0, %tidY; 106 | add.u64 %dtmp0, %dtmp0, %offsetY; 107 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 108 | cvt.u64.u32 %dtmp1, %tidX; 109 | add.u64 %dtmp1, %dtmp1, %offsetX; 110 | add.u64 %dtmp0, %dtmp0, %dtmp1; 111 | shl.b64 %dtmp0, %dtmp0, 2; 112 | add.u64 %dtmp0, %dtmp0, %ptrOut; 113 | 114 | st.global.f32 [%dtmp0], %acc; 115 | } 116 | -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_simple_block_v1.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry blockedMatmul ( 6 | .param .u64 ptrA, 7 | .param .u64 ptrB, 8 | .param .u64 ptrOut, 9 | .param .u32 numBlocks 10 | ) { 11 | .reg .pred %p0; 12 | .reg .u64 %tmp<2>; 13 | .reg .u32 %halfTmp<2>; 14 | .reg .u32 %localOffset; 15 | .reg .u64 %offsetX; 16 | .reg .u64 %offsetY; 17 | .reg .u64 %stride; 18 | .reg .u32 %i; 19 | .reg .u32 %j; 20 | .reg .u32 %numBlocks; 21 | .reg .u64 %ptrA; 22 | .reg .u64 %ptrB; 23 | .reg .u64 %ptrOut; 24 | .reg .f32 %acc; 25 | .reg .f32 %val<2>; 26 | .shared .align 4 .f32 loadedA[1024]; // should be ntid.x*ntid.y 27 | .shared .align 4 .f32 loadedB[1024]; // should be ntid.x*ntid.y 28 | 29 | ld.param.u32 %numBlocks, [numBlocks]; 30 | ld.param.u64 %ptrA, [ptrA]; 31 | ld.param.u64 %ptrB, [ptrB]; 32 | ld.param.u64 %ptrOut, [ptrOut]; 33 | 34 | // Local offset is (tid.y*ntid.x + tid.x) * sizeof(float32) 35 | mov.u32 %localOffset, %tid.y; 36 | mov.u32 %halfTmp0, %ntid.x; 37 | mul.lo.u32 %localOffset, %localOffset, %halfTmp0; 38 | mov.u32 %halfTmp0, %tid.x; 39 | add.u32 %localOffset, %localOffset, %halfTmp0; 40 | mul.lo.u32 %localOffset, %localOffset, 4; 41 | 42 | // Compute offsets in the output matrix. 43 | // offsetX = ctaid.x * ntid.x 44 | cvt.u64.u32 %offsetX, %ctaid.x; 45 | cvt.u64.u32 %tmp0, %ntid.x; 46 | mul.lo.u64 %offsetX, %offsetX, %tmp0; 47 | // offsetY = ctaid.y * ntid.y 48 | cvt.u64.u32 %offsetY, %ctaid.y; 49 | cvt.u64.u32 %tmp0, %ntid.y; 50 | mul.lo.u64 %offsetY, %offsetY, %tmp0; 51 | 52 | // Stride is ntid.x * numBlocks 53 | cvt.u64.u32 %stride, %ntid.x; 54 | cvt.u64.u32 %tmp0, %numBlocks; 55 | mul.lo.u64 %stride, %stride, %tmp0; 56 | 57 | // Zero out our local portion of the output. 58 | // mov.u32 %halfTmp0, output; 59 | // add.u32 %halfTmp0, %halfTmp0, %localOffset; 60 | // mov.f32 %val0, 0.0; 61 | // st.shared.f32 [%halfTmp0], %val0; 62 | mov.f32 %acc, 0.0; 63 | 64 | mov.u32 %i, 0; 65 | loop_start: 66 | // Don't write into memory until other threads are 67 | // caught up, to avoid races. 
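// (bar.sync 0 is a CTA-wide barrier, the PTX equivalent of __syncthreads().)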
68 | bar.sync 0; 69 | 70 | // Our block offset in A is (i*ntid.x + tid.x, offsetY + tid.y) 71 | cvt.u64.u32 %tmp0, %i; 72 | cvt.u64.u32 %tmp1, %ntid.x; 73 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 74 | cvt.u64.u32 %tmp1, %tid.x; 75 | add.u64 %tmp0, %tmp0, %tmp1; 76 | cvt.u64.u32 %tmp1, %tid.y; 77 | add.u64 %tmp1, %tmp1, %offsetY; 78 | // Compute pointer as &ptrA[y*stride+x] 79 | mul.lo.u64 %tmp1, %tmp1, %stride; 80 | add.u64 %tmp0, %tmp0, %tmp1; 81 | mul.lo.u64 %tmp0, %tmp0, 4; 82 | add.u64 %tmp0, %tmp0, %ptrA; 83 | // Output pointer 84 | mov.u32 %halfTmp0, loadedA; 85 | add.u32 %halfTmp0, %halfTmp0, %localOffset; 86 | // Copy to local memory 87 | ld.global.f32 %val0, [%tmp0]; 88 | st.shared.f32 [%halfTmp0], %val0; 89 | 90 | // Our block offset in B is (offsetX + tid.x, i*ntid.y + tid.y) 91 | cvt.u64.u32 %tmp0, %i; 92 | cvt.u64.u32 %tmp1, %ntid.y; 93 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 94 | cvt.u64.u32 %tmp1, %tid.y; 95 | add.u64 %tmp0, %tmp0, %tmp1; 96 | cvt.u64.u32 %tmp1, %tid.x; 97 | add.u64 %tmp1, %tmp1, %offsetX; 98 | // Compute global offset as &ptrB[y*stride+x] 99 | mul.lo.u64 %tmp0, %tmp0, %stride; 100 | add.u64 %tmp0, %tmp0, %tmp1; 101 | mul.lo.u64 %tmp0, %tmp0, 4; 102 | add.u64 %tmp0, %tmp0, %ptrB; 103 | // Output pointer 104 | mov.u32 %halfTmp0, loadedB; 105 | add.u32 %halfTmp0, %halfTmp0, %localOffset; 106 | // Copy to local memory 107 | ld.global.f32 %val0, [%tmp0]; 108 | st.shared.f32 [%halfTmp0], %val0; 109 | 110 | bar.sync 0; 111 | 112 | mov.u32 %j, 0; 113 | inner_loop_start: 114 | // Offset in loadedA is j + tid.y*ntid.x 115 | mov.u32 %halfTmp0, %ntid.x; 116 | mov.u32 %halfTmp1, %tid.y; 117 | mul.lo.u32 %halfTmp1, %halfTmp1, %halfTmp0; 118 | add.u32 %halfTmp1, %halfTmp1, %j; 119 | mul.lo.u32 %halfTmp1, %halfTmp1, 4; 120 | mov.u32 %halfTmp0, loadedA; 121 | add.u32 %halfTmp0, %halfTmp0, %halfTmp1; 122 | ld.shared.f32 %val0, [%halfTmp0]; 123 | 124 | // Offset in loadedB is tid.x + j*ntid.x 125 | mov.u32 %halfTmp1, %ntid.x; 126 | mul.lo.u32 %halfTmp1, %halfTmp1, %j; 127 | mov.u32 %halfTmp0, %tid.x; 128 | add.u32 %halfTmp1, %halfTmp1, %halfTmp0; 129 | mul.lo.u32 %halfTmp1, %halfTmp1, 4; 130 | mov.u32 %halfTmp0, loadedB; 131 | add.u32 %halfTmp0, %halfTmp0, %halfTmp1; 132 | ld.shared.f32 %val1, [%halfTmp0]; 133 | 134 | // Can be optimized to fused operation. 135 | mul.f32 %val1, %val0, %val1; 136 | add.f32 %acc, %acc, %val1; 137 | 138 | // j += 1; loop while j < ntid.x 139 | mov.u32 %halfTmp0, %ntid.x; 140 | add.u32 %j, %j, 1; 141 | setp.lt.u32 %p0, %j, %halfTmp0; 142 | @%p0 bra inner_loop_start; 143 | 144 | inner_loop_end: 145 | // i += 1; loop while i < numBlocks 146 | add.u32 %i, %i, 1; 147 | setp.lt.u32 %p0, %i, %numBlocks; 148 | @%p0 bra loop_start; 149 | 150 | loop_end: 151 | // Write back to output memory. 
152 | 153 | // Output address is offsetX+tid.x + stride*(offsetY+tid.y) 154 | cvt.u64.u32 %tmp0, %tid.y; 155 | add.u64 %tmp0, %tmp0, %offsetY; 156 | mul.lo.u64 %tmp0, %tmp0, %stride; 157 | cvt.u64.u32 %tmp1, %tid.x; 158 | add.u64 %tmp1, %tmp1, %offsetX; 159 | add.u64 %tmp0, %tmp0, %tmp1; 160 | mul.lo.u64 %tmp0, %tmp0, 4; 161 | add.u64 %tmp0, %tmp0, %ptrOut; 162 | 163 | st.global.f32 [%tmp0], %acc; 164 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_simple_block_v2.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry blockedMatmulV2 ( 6 | .param .u64 ptrA, 7 | .param .u64 ptrB, 8 | .param .u64 ptrOut, 9 | .param .u32 numBlocks 10 | ) { 11 | .reg .pred %p0; 12 | .reg .u64 %dtmp<2>; 13 | .reg .u32 %stmp<3>; 14 | 15 | // Offset in loadedA / loadedB that we write to. 16 | .reg .u32 %loadOffset; 17 | 18 | // Attributes of the thread/CTA. 19 | .reg .u32 %blockSize; 20 | .reg .u32 %tidX; 21 | .reg .u32 %tidY; 22 | 23 | .reg .u64 %offsetX; 24 | .reg .u64 %offsetY; 25 | .reg .u64 %stride; 26 | .reg .u32 %i; 27 | .reg .u32 %j; 28 | .reg .u32 %numBlocks; 29 | .reg .u64 %ptrA; 30 | .reg .u64 %ptrB; 31 | .reg .u64 %ptrOut; 32 | .reg .f32 %acc; 33 | .reg .f32 %val<2>; 34 | .shared .align 4 .f32 loadedA[1024]; // should be at least blockSize^2 35 | .shared .align 4 .f32 loadedB[1024]; // should be at least blockSize^2 36 | 37 | ld.param.u32 %numBlocks, [numBlocks]; 38 | ld.param.u64 %ptrA, [ptrA]; 39 | ld.param.u64 %ptrB, [ptrB]; 40 | ld.param.u64 %ptrOut, [ptrOut]; 41 | 42 | mov.u32 %blockSize, %ntid.x; 43 | mov.u32 %tidX, %tid.x; 44 | mov.u32 %tidY, %tid.y; 45 | 46 | // Local offset is (tid.y*blockSize + tid.x) * sizeof(float32) 47 | mul.lo.u32 %loadOffset, %tidY, %blockSize; 48 | add.u32 %loadOffset, %loadOffset, %tidX; 49 | shl.b32 %loadOffset, %loadOffset, 2; 50 | 51 | // For computing offsetX, offsetY, and stride, we use 52 | // %dtmp0 to store a 64-bit version of %blockSize. 53 | cvt.u64.u32 %dtmp0, %blockSize; // %dtmp0 = %blockSize 54 | 55 | // Compute offsets in the output matrix. 56 | // offsetX = ctaid.x * ntid.x = ctaid.x * blockSize 57 | cvt.u64.u32 %offsetX, %ctaid.x; 58 | mul.lo.u64 %offsetX, %offsetX, %dtmp0; 59 | // offsetY = ctaid.y * ntid.y = ctaid.y * blockSize 60 | cvt.u64.u32 %offsetY, %ctaid.y; 61 | mul.lo.u64 %offsetY, %offsetY, %dtmp0; 62 | 63 | // Stride is blockSize * numBlocks; 64 | cvt.u64.u32 %stride, %numBlocks; 65 | mul.lo.u64 %stride, %stride, %dtmp0; 66 | 67 | // We will accumulate into this register. 
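// Each thread owns exactly one element of the blockSize x blockSize output
// tile, so a single f32 accumulator is enough.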
68 | mov.f32 %acc, 0.0; 69 | 70 | // We will calculate block offset in A in %ptrA as 71 | // (i*ntid.x + tid.x, offsetY + tid.y) 72 | // = i*ntid.x + tid.x + stride*(offsetY+tid.y) 73 | cvt.u64.u32 %dtmp0, %tidY; 74 | add.u64 %dtmp0, %dtmp0, %offsetY; 75 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 76 | cvt.u64.u32 %dtmp1, %tidX; 77 | add.u64 %dtmp0, %dtmp0, %dtmp1; 78 | shl.b64 %dtmp0, %dtmp0, 2; 79 | add.u64 %ptrA, %ptrA, %dtmp0; 80 | 81 | // We will calculate our block offset in B in %ptrB as 82 | // (offsetX + tid.x, i*ntid.y + tid.y) 83 | // = offsetX + tid.x + i*stride*blockSize + stride*tid.y 84 | cvt.u64.u32 %dtmp0, %tidY; 85 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 86 | cvt.u64.u32 %dtmp1, %tidX; 87 | add.u64 %dtmp0, %dtmp0, %dtmp1; 88 | add.u64 %dtmp0, %dtmp0, %offsetX; 89 | shl.b64 %dtmp0, %dtmp0, 2; 90 | add.u64 %ptrB, %ptrB, %dtmp0; 91 | 92 | // Set %dtmp0 and %dtmp1 to strides in A and B, respectively. 93 | // Stride in ptrA is blockSize*4 94 | cvt.u64.u32 %dtmp0, %blockSize; 95 | shl.b64 %dtmp0, %dtmp0, 2; 96 | // Stride in ptrB is stride*blockSize*4 97 | mul.lo.u64 %dtmp1, %dtmp0, %stride; 98 | 99 | mov.u32 %i, 0; 100 | loop_start: 101 | // Don't write into memory until other threads are 102 | // caught up, to avoid races. 103 | bar.sync 0; 104 | 105 | // Read our entry from A into shared memory. 106 | mov.u32 %stmp0, loadedA; 107 | add.u32 %stmp0, %stmp0, %loadOffset; 108 | // Copy to local memory 109 | ld.global.f32 %val0, [%ptrA]; 110 | st.shared.f32 [%stmp0], %val0; 111 | add.u64 %ptrA, %ptrA, %dtmp0; 112 | 113 | // Read our entry from B into shared memory. 114 | mov.u32 %stmp0, loadedB; 115 | add.u32 %stmp0, %stmp0, %loadOffset; 116 | // Copy to local memory 117 | ld.global.f32 %val0, [%ptrB]; 118 | st.shared.f32 [%stmp0], %val0; 119 | add.u64 %ptrB, %ptrB, %dtmp1; 120 | 121 | bar.sync 0; 122 | 123 | // %stmp0 will be address in A. 124 | // It will be &loadedA[j + tid.y*ntid.x], starting at j=0 125 | mul.lo.u32 %stmp0, %tidY, %blockSize; 126 | shl.b32 %stmp0, %stmp0, 2; 127 | mov.u32 %stmp1, loadedA; 128 | add.u32 %stmp0, %stmp0, %stmp1; 129 | 130 | // %stmp1 will be address in B. 131 | // It will be &loadedB[tid.x + j*ntid.x] starting at j=0 132 | mov.u32 %stmp1, loadedB; 133 | shl.b32 %stmp2, %tidX, 2; 134 | add.u32 %stmp1, %stmp1, %stmp2; 135 | shl.b32 %stmp2, %blockSize, 2; 136 | 137 | mov.u32 %j, 0; 138 | inner_loop_start: 139 | // Offset in loadedA is j + tid.y*ntid.x 140 | ld.shared.f32 %val0, [%stmp0]; 141 | add.u32 %stmp0, %stmp0, 4; 142 | 143 | // Offset in loadedB is tid.x + j*ntid.x 144 | ld.shared.f32 %val1, [%stmp1]; 145 | add.u32 %stmp1, %stmp1, %stmp2; 146 | 147 | // Can be optimized to fused operation. 148 | mul.f32 %val1, %val0, %val1; 149 | add.f32 %acc, %acc, %val1; 150 | 151 | // j += 1; loop while j < ntid.x 152 | add.u32 %j, %j, 1; 153 | setp.lt.u32 %p0, %j, %blockSize; 154 | @%p0 bra inner_loop_start; 155 | 156 | inner_loop_end: 157 | // i += 1; loop while i < numBlocks 158 | add.u32 %i, %i, 1; 159 | setp.lt.u32 %p0, %i, %numBlocks; 160 | @%p0 bra loop_start; 161 | 162 | loop_end: 163 | // Write back to output memory. 
164 | 165 | // Output address is offsetX+tid.x + stride*(offsetY+tid.y) 166 | cvt.u64.u32 %dtmp0, %tidY; 167 | add.u64 %dtmp0, %dtmp0, %offsetY; 168 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 169 | cvt.u64.u32 %dtmp1, %tidX; 170 | add.u64 %dtmp1, %dtmp1, %offsetX; 171 | add.u64 %dtmp0, %dtmp0, %dtmp1; 172 | shl.b64 %dtmp0, %dtmp0, 2; 173 | add.u64 %dtmp0, %dtmp0, %ptrOut; 174 | 175 | st.global.f32 [%dtmp0], %acc; 176 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_simple_block_v3.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry blockedMatmulV3 ( 6 | .param .u64 ptrA, 7 | .param .u64 ptrB, 8 | .param .u64 ptrOut, 9 | .param .u32 numBlocks 10 | ) { 11 | .reg .pred %p0; 12 | .reg .u64 %dtmp<2>; 13 | .reg .u32 %stmp<2>; 14 | 15 | // Offset in loadedA / loadedB that we write to. 16 | .reg .u32 %loadOffset; 17 | 18 | // Attributes of the thread/CTA. 19 | .reg .u32 %blockSize; 20 | .reg .u32 %tidX; 21 | .reg .u32 %tidY; 22 | 23 | .reg .u64 %offsetX; 24 | .reg .u64 %offsetY; 25 | .reg .u64 %stride; 26 | .reg .u32 %i; 27 | .reg .u32 %j; 28 | .reg .u32 %numBlocks; 29 | .reg .u64 %ptrA; 30 | .reg .u64 %ptrB; 31 | .reg .u64 %ptrOut; 32 | .reg .f32 %acc; 33 | .reg .f32 %val<2>; 34 | .reg .f32 %localA; 35 | .shared .align 4 .f32 loadedA[1024]; // should be at least blockSize^2 36 | .shared .align 4 .f32 loadedB[1024]; // should be at least blockSize^2 37 | 38 | ld.param.u32 %numBlocks, [numBlocks]; 39 | ld.param.u64 %ptrA, [ptrA]; 40 | ld.param.u64 %ptrB, [ptrB]; 41 | ld.param.u64 %ptrOut, [ptrOut]; 42 | 43 | mov.u32 %blockSize, %ntid.x; 44 | mov.u32 %tidX, %tid.x; 45 | mov.u32 %tidY, %tid.y; 46 | 47 | // Local offset is (tid.y*blockSize + tid.x) * sizeof(float32) 48 | mul.lo.u32 %loadOffset, %tidY, %blockSize; 49 | add.u32 %loadOffset, %loadOffset, %tidX; 50 | shl.b32 %loadOffset, %loadOffset, 2; 51 | 52 | // For computing offsetX, offsetY, and stride, we use 53 | // %dtmp0 to store a 64-bit version of %blockSize. 54 | cvt.u64.u32 %dtmp0, %blockSize; // %dtmp0 = %blockSize 55 | 56 | // Compute offsets in the output matrix. 57 | // offsetX = ctaid.x * ntid.x = ctaid.x * blockSize 58 | cvt.u64.u32 %offsetX, %ctaid.x; 59 | mul.lo.u64 %offsetX, %offsetX, %dtmp0; 60 | // offsetY = ctaid.y * ntid.y = ctaid.y * blockSize 61 | cvt.u64.u32 %offsetY, %ctaid.y; 62 | mul.lo.u64 %offsetY, %offsetY, %dtmp0; 63 | 64 | // Stride is blockSize * numBlocks; 65 | cvt.u64.u32 %stride, %numBlocks; 66 | mul.lo.u64 %stride, %stride, %dtmp0; 67 | 68 | // We will accumulate into this register. 
69 | mov.f32 %acc, 0.0; 70 | 71 | // We will calculate block offset in A in %ptrA as 72 | // (i*ntid.x + tid.x, offsetY + tid.y) 73 | // = i*ntid.x + tid.x + stride*(offsetY+tid.y) 74 | cvt.u64.u32 %dtmp0, %tidY; 75 | add.u64 %dtmp0, %dtmp0, %offsetY; 76 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 77 | cvt.u64.u32 %dtmp1, %tidX; 78 | add.u64 %dtmp0, %dtmp0, %dtmp1; 79 | shl.b64 %dtmp0, %dtmp0, 2; 80 | add.u64 %ptrA, %ptrA, %dtmp0; 81 | 82 | // We will calculate our block offset in B in %ptrB as 83 | // (offsetX + tid.x, i*ntid.y + tid.y) 84 | // = offsetX + tid.x + i*stride*blockSize + stride*tid.y 85 | cvt.u64.u32 %dtmp0, %tidY; 86 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 87 | cvt.u64.u32 %dtmp1, %tidX; 88 | add.u64 %dtmp0, %dtmp0, %dtmp1; 89 | add.u64 %dtmp0, %dtmp0, %offsetX; 90 | shl.b64 %dtmp0, %dtmp0, 2; 91 | add.u64 %ptrB, %ptrB, %dtmp0; 92 | 93 | // Set %dtmp0 and %dtmp1 to strides in A and B, respectively. 94 | // Stride in ptrA is blockSize*4 95 | cvt.u64.u32 %dtmp0, %blockSize; 96 | shl.b64 %dtmp0, %dtmp0, 2; 97 | // Stride in ptrB is stride*blockSize*4 98 | mul.lo.u64 %dtmp1, %dtmp0, %stride; 99 | 100 | mov.u32 %i, 0; 101 | loop_start: 102 | // Don't write into memory until other threads are 103 | // caught up, to avoid races. 104 | bar.sync 0; 105 | 106 | // Read our entry from A into shared memory. 107 | mov.u32 %stmp0, loadedA; 108 | add.u32 %stmp0, %stmp0, %loadOffset; 109 | // Copy to local memory 110 | ld.global.f32 %val0, [%ptrA]; 111 | st.shared.f32 [%stmp0], %val0; 112 | add.u64 %ptrA, %ptrA, %dtmp0; 113 | 114 | // Read our entry from B into shared memory. 115 | mov.u32 %stmp0, loadedB; 116 | add.u32 %stmp0, %stmp0, %loadOffset; 117 | // Copy to local memory 118 | ld.global.f32 %val0, [%ptrB]; 119 | st.shared.f32 [%stmp0], %val0; 120 | add.u64 %ptrB, %ptrB, %dtmp1; 121 | 122 | bar.sync 0; 123 | 124 | // This doesn't seem to help, but it should in theory. 125 | add.u32 %i, %i, 1; 126 | setp.lt.u32 %p0, %i, %numBlocks; 127 | @%p0 prefetch.global.L1 [%ptrA]; 128 | @%p0 prefetch.global.L1 [%ptrB]; 129 | 130 | // We will load each entry into a different thread 131 | // in the warp, under the assumption that the block 132 | // size is exactly the warp size. 133 | // We load &loadedA[tid.x + tid.y*ntid.x] into our register. 134 | mul.lo.u32 %stmp0, %tidY, %blockSize; 135 | add.u32 %stmp0, %stmp0, %tidX; 136 | shl.b32 %stmp0, %stmp0, 2; 137 | mov.u32 %stmp1, loadedA; 138 | add.u32 %stmp0, %stmp0, %stmp1; 139 | ld.shared.f32 %localA, [%stmp0]; 140 | 141 | // %stmp0 will be address in B. 142 | // It will be &loadedB[tid.x + j*ntid.x] starting at j=0 143 | mov.u32 %stmp0, loadedB; 144 | shl.b32 %stmp1, %tidX, 2; 145 | add.u32 %stmp0, %stmp0, %stmp1; 146 | shl.b32 %stmp1, %blockSize, 2; 147 | 148 | mov.u32 %j, 0; 149 | inner_loop_start: 150 | shfl.sync.idx.b32 %val0, %localA, %j, 0x1f, 0xffffffff; 151 | 152 | // Offset in loadedB is tid.x + j*ntid.x 153 | ld.shared.f32 %val1, [%stmp0]; 154 | add.u32 %stmp0, %stmp0, %stmp1; 155 | 156 | // Can be optimized to fused operation. 157 | mul.f32 %val1, %val0, %val1; 158 | add.f32 %acc, %acc, %val1; 159 | 160 | // j += 1; loop while j < ntid.x 161 | add.u32 %j, %j, 1; 162 | setp.lt.u32 %p0, %j, %blockSize; 163 | @%p0 bra inner_loop_start; 164 | 165 | inner_loop_end: 166 | // loop while i < numBlocks 167 | setp.lt.u32 %p0, %i, %numBlocks; 168 | @%p0 bra loop_start; 169 | 170 | loop_end: 171 | // Write back to output memory. 
172 | 173 | // Output address is offsetX+tid.x + stride*(offsetY+tid.y) 174 | cvt.u64.u32 %dtmp0, %tidY; 175 | add.u64 %dtmp0, %dtmp0, %offsetY; 176 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 177 | cvt.u64.u32 %dtmp1, %tidX; 178 | add.u64 %dtmp1, %dtmp1, %offsetX; 179 | add.u64 %dtmp0, %dtmp0, %dtmp1; 180 | shl.b64 %dtmp0, %dtmp0, 2; 181 | add.u64 %dtmp0, %dtmp0, %ptrOut; 182 | 183 | st.global.f32 [%dtmp0], %acc; 184 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_simple_block_v4.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry blockedMatmulV4 ( 6 | .param .u64 ptrA, 7 | .param .u64 ptrB, 8 | .param .u64 ptrOut, 9 | .param .u32 numBlocks 10 | ) { 11 | .reg .pred %p0; 12 | .reg .u64 %dtmp<2>; 13 | .reg .u32 %stmp<2>; 14 | 15 | // Offset in loadedA / loadedB that we write to. 16 | .reg .u32 %loadOffset; 17 | 18 | // Attributes of the thread/CTA. 19 | .reg .u32 %tidX; 20 | .reg .u32 %tidY; 21 | 22 | .reg .u64 %offsetX; 23 | .reg .u64 %offsetY; 24 | .reg .u64 %stride; 25 | .reg .u32 %i; 26 | .reg .u32 %j; 27 | .reg .u32 %numBlocks; 28 | .reg .u64 %ptrA; 29 | .reg .u64 %ptrB; 30 | .reg .u64 %ptrOut; 31 | .reg .f32 %acc; 32 | .reg .f32 %ftmp; 33 | .reg .f32 %localA; 34 | .reg .f32 %localB<32>; 35 | .shared .align 4 .f32 loadedA[1024]; 36 | .shared .align 4 .f32 loadedB[1024]; 37 | 38 | ld.param.u32 %numBlocks, [numBlocks]; 39 | ld.param.u64 %ptrA, [ptrA]; 40 | ld.param.u64 %ptrB, [ptrB]; 41 | ld.param.u64 %ptrOut, [ptrOut]; 42 | 43 | mov.u32 %tidX, %tid.x; 44 | mov.u32 %tidY, %tid.y; 45 | 46 | // Local offset is (tid.y*blockSize + tid.x) * sizeof(float32) 47 | mul.lo.u32 %loadOffset, %tidY, 32; 48 | add.u32 %loadOffset, %loadOffset, %tidX; 49 | shl.b32 %loadOffset, %loadOffset, 2; 50 | 51 | // Compute offsets in the output matrix. 52 | // offsetX = ctaid.x * ntid.x = ctaid.x * blockSize 53 | cvt.u64.u32 %offsetX, %ctaid.x; 54 | mul.lo.u64 %offsetX, %offsetX, 32; 55 | // offsetY = ctaid.y * ntid.y = ctaid.y * blockSize 56 | cvt.u64.u32 %offsetY, %ctaid.y; 57 | mul.lo.u64 %offsetY, %offsetY, 32; 58 | 59 | // Stride is blockSize * numBlocks; 60 | cvt.u64.u32 %stride, %numBlocks; 61 | mul.lo.u64 %stride, %stride, 32; 62 | 63 | // We will accumulate into this register. 64 | mov.f32 %acc, 0.0; 65 | 66 | // We will calculate block offset in A in %ptrA as 67 | // (i*ntid.x + tid.x, offsetY + tid.y) 68 | // = i*ntid.x + tid.x + stride*(offsetY+tid.y) 69 | cvt.u64.u32 %dtmp0, %tidY; 70 | add.u64 %dtmp0, %dtmp0, %offsetY; 71 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 72 | cvt.u64.u32 %dtmp1, %tidX; 73 | add.u64 %dtmp0, %dtmp0, %dtmp1; 74 | shl.b64 %dtmp0, %dtmp0, 2; 75 | add.u64 %ptrA, %ptrA, %dtmp0; 76 | 77 | // We will calculate our block offset in B in %ptrB as 78 | // (offsetX + tid.x, i*ntid.y + tid.y) 79 | // = offsetX + tid.x + i*stride*blockSize + stride*tid.y 80 | cvt.u64.u32 %dtmp0, %tidY; 81 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 82 | cvt.u64.u32 %dtmp1, %tidX; 83 | add.u64 %dtmp0, %dtmp0, %dtmp1; 84 | add.u64 %dtmp0, %dtmp0, %offsetX; 85 | shl.b64 %dtmp0, %dtmp0, 2; 86 | add.u64 %ptrB, %ptrB, %dtmp0; 87 | 88 | // Stride in ptrB is stride*blockSize*4 89 | mul.lo.u64 %dtmp0, %stride, 128; 90 | 91 | mov.u32 %i, 0; 92 | loop_start: 93 | // Don't write into memory until other threads are 94 | // caught up, to avoid races. 95 | bar.sync 0; 96 | 97 | // Read our entry from A into shared memory. 
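// V4 hard-codes blockSize = 32, so the strides below are immediate constants
// (128 bytes = 32 floats per row).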
98 | mov.u32 %stmp0, loadedA; 99 | add.u32 %stmp0, %stmp0, %loadOffset; 100 | // Copy to local memory 101 | ld.global.f32 %ftmp, [%ptrA]; 102 | st.shared.f32 [%stmp0], %ftmp; 103 | add.u64 %ptrA, %ptrA, 128; 104 | 105 | // Read our entry from B into shared memory. 106 | mov.u32 %stmp0, loadedB; 107 | add.u32 %stmp0, %stmp0, %loadOffset; 108 | // Copy to local memory 109 | ld.global.f32 %ftmp, [%ptrB]; 110 | st.shared.f32 [%stmp0], %ftmp; 111 | add.u64 %ptrB, %ptrB, %dtmp0; 112 | 113 | bar.sync 0; 114 | 115 | // This doesn't seem to help, but it should in theory. 116 | add.u32 %i, %i, 1; 117 | setp.lt.u32 %p0, %i, %numBlocks; 118 | @%p0 prefetch.global.L1 [%ptrA]; 119 | @%p0 prefetch.global.L1 [%ptrB]; 120 | 121 | // We will load each entry into a different thread 122 | // in the warp, under the assumption that the block 123 | // size is exactly the warp size. 124 | // We load &loadedA[tid.x + tid.y*ntid.x] into our register. 125 | mul.lo.u32 %stmp0, %tidY, 32; 126 | add.u32 %stmp0, %stmp0, %tidX; 127 | shl.b32 %stmp0, %stmp0, 2; 128 | mov.u32 %stmp1, loadedA; 129 | add.u32 %stmp0, %stmp0, %stmp1; 130 | ld.shared.f32 %localA, [%stmp0]; 131 | 132 | // %stmp0 will be address in B. 133 | // It will be &loadedB[tid.x + j*blockSize] starting at j=0 134 | mov.u32 %stmp0, loadedB; 135 | shl.b32 %stmp1, %tidX, 2; 136 | add.u32 %stmp0, %stmp0, %stmp1; 137 | 138 | // Fetch into registers. 139 | // 140 | // for i in range(32): 141 | // print(f"ld.shared.f32 %localB{i}, [%stmp0+{i*128}];") 142 | // 143 | ld.shared.f32 %localB0, [%stmp0+0]; 144 | ld.shared.f32 %localB1, [%stmp0+128]; 145 | ld.shared.f32 %localB2, [%stmp0+256]; 146 | ld.shared.f32 %localB3, [%stmp0+384]; 147 | ld.shared.f32 %localB4, [%stmp0+512]; 148 | ld.shared.f32 %localB5, [%stmp0+640]; 149 | ld.shared.f32 %localB6, [%stmp0+768]; 150 | ld.shared.f32 %localB7, [%stmp0+896]; 151 | ld.shared.f32 %localB8, [%stmp0+1024]; 152 | ld.shared.f32 %localB9, [%stmp0+1152]; 153 | ld.shared.f32 %localB10, [%stmp0+1280]; 154 | ld.shared.f32 %localB11, [%stmp0+1408]; 155 | ld.shared.f32 %localB12, [%stmp0+1536]; 156 | ld.shared.f32 %localB13, [%stmp0+1664]; 157 | ld.shared.f32 %localB14, [%stmp0+1792]; 158 | ld.shared.f32 %localB15, [%stmp0+1920]; 159 | ld.shared.f32 %localB16, [%stmp0+2048]; 160 | ld.shared.f32 %localB17, [%stmp0+2176]; 161 | ld.shared.f32 %localB18, [%stmp0+2304]; 162 | ld.shared.f32 %localB19, [%stmp0+2432]; 163 | ld.shared.f32 %localB20, [%stmp0+2560]; 164 | ld.shared.f32 %localB21, [%stmp0+2688]; 165 | ld.shared.f32 %localB22, [%stmp0+2816]; 166 | ld.shared.f32 %localB23, [%stmp0+2944]; 167 | ld.shared.f32 %localB24, [%stmp0+3072]; 168 | ld.shared.f32 %localB25, [%stmp0+3200]; 169 | ld.shared.f32 %localB26, [%stmp0+3328]; 170 | ld.shared.f32 %localB27, [%stmp0+3456]; 171 | ld.shared.f32 %localB28, [%stmp0+3584]; 172 | ld.shared.f32 %localB29, [%stmp0+3712]; 173 | ld.shared.f32 %localB30, [%stmp0+3840]; 174 | ld.shared.f32 %localB31, [%stmp0+3968]; 175 | 176 | // Perform the local dot product in-register. 
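// shfl.sync.idx.b32 broadcasts lane i's %localA to every lane of the warp, so
// the whole row of A is reused without another trip through shared memory.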
177 | // 178 | // for i in range(32): 179 | // print(f"shfl.sync.idx.b32 %ftmp, %localA, {i}, 0x1f, 0xffffffff;") 180 | // print(f"fma.rn.f32 %acc, %ftmp, %localB{i}, %acc;") 181 | // 182 | shfl.sync.idx.b32 %ftmp, %localA, 0, 0x1f, 0xffffffff; 183 | fma.rn.f32 %acc, %ftmp, %localB0, %acc; 184 | shfl.sync.idx.b32 %ftmp, %localA, 1, 0x1f, 0xffffffff; 185 | fma.rn.f32 %acc, %ftmp, %localB1, %acc; 186 | shfl.sync.idx.b32 %ftmp, %localA, 2, 0x1f, 0xffffffff; 187 | fma.rn.f32 %acc, %ftmp, %localB2, %acc; 188 | shfl.sync.idx.b32 %ftmp, %localA, 3, 0x1f, 0xffffffff; 189 | fma.rn.f32 %acc, %ftmp, %localB3, %acc; 190 | shfl.sync.idx.b32 %ftmp, %localA, 4, 0x1f, 0xffffffff; 191 | fma.rn.f32 %acc, %ftmp, %localB4, %acc; 192 | shfl.sync.idx.b32 %ftmp, %localA, 5, 0x1f, 0xffffffff; 193 | fma.rn.f32 %acc, %ftmp, %localB5, %acc; 194 | shfl.sync.idx.b32 %ftmp, %localA, 6, 0x1f, 0xffffffff; 195 | fma.rn.f32 %acc, %ftmp, %localB6, %acc; 196 | shfl.sync.idx.b32 %ftmp, %localA, 7, 0x1f, 0xffffffff; 197 | fma.rn.f32 %acc, %ftmp, %localB7, %acc; 198 | shfl.sync.idx.b32 %ftmp, %localA, 8, 0x1f, 0xffffffff; 199 | fma.rn.f32 %acc, %ftmp, %localB8, %acc; 200 | shfl.sync.idx.b32 %ftmp, %localA, 9, 0x1f, 0xffffffff; 201 | fma.rn.f32 %acc, %ftmp, %localB9, %acc; 202 | shfl.sync.idx.b32 %ftmp, %localA, 10, 0x1f, 0xffffffff; 203 | fma.rn.f32 %acc, %ftmp, %localB10, %acc; 204 | shfl.sync.idx.b32 %ftmp, %localA, 11, 0x1f, 0xffffffff; 205 | fma.rn.f32 %acc, %ftmp, %localB11, %acc; 206 | shfl.sync.idx.b32 %ftmp, %localA, 12, 0x1f, 0xffffffff; 207 | fma.rn.f32 %acc, %ftmp, %localB12, %acc; 208 | shfl.sync.idx.b32 %ftmp, %localA, 13, 0x1f, 0xffffffff; 209 | fma.rn.f32 %acc, %ftmp, %localB13, %acc; 210 | shfl.sync.idx.b32 %ftmp, %localA, 14, 0x1f, 0xffffffff; 211 | fma.rn.f32 %acc, %ftmp, %localB14, %acc; 212 | shfl.sync.idx.b32 %ftmp, %localA, 15, 0x1f, 0xffffffff; 213 | fma.rn.f32 %acc, %ftmp, %localB15, %acc; 214 | shfl.sync.idx.b32 %ftmp, %localA, 16, 0x1f, 0xffffffff; 215 | fma.rn.f32 %acc, %ftmp, %localB16, %acc; 216 | shfl.sync.idx.b32 %ftmp, %localA, 17, 0x1f, 0xffffffff; 217 | fma.rn.f32 %acc, %ftmp, %localB17, %acc; 218 | shfl.sync.idx.b32 %ftmp, %localA, 18, 0x1f, 0xffffffff; 219 | fma.rn.f32 %acc, %ftmp, %localB18, %acc; 220 | shfl.sync.idx.b32 %ftmp, %localA, 19, 0x1f, 0xffffffff; 221 | fma.rn.f32 %acc, %ftmp, %localB19, %acc; 222 | shfl.sync.idx.b32 %ftmp, %localA, 20, 0x1f, 0xffffffff; 223 | fma.rn.f32 %acc, %ftmp, %localB20, %acc; 224 | shfl.sync.idx.b32 %ftmp, %localA, 21, 0x1f, 0xffffffff; 225 | fma.rn.f32 %acc, %ftmp, %localB21, %acc; 226 | shfl.sync.idx.b32 %ftmp, %localA, 22, 0x1f, 0xffffffff; 227 | fma.rn.f32 %acc, %ftmp, %localB22, %acc; 228 | shfl.sync.idx.b32 %ftmp, %localA, 23, 0x1f, 0xffffffff; 229 | fma.rn.f32 %acc, %ftmp, %localB23, %acc; 230 | shfl.sync.idx.b32 %ftmp, %localA, 24, 0x1f, 0xffffffff; 231 | fma.rn.f32 %acc, %ftmp, %localB24, %acc; 232 | shfl.sync.idx.b32 %ftmp, %localA, 25, 0x1f, 0xffffffff; 233 | fma.rn.f32 %acc, %ftmp, %localB25, %acc; 234 | shfl.sync.idx.b32 %ftmp, %localA, 26, 0x1f, 0xffffffff; 235 | fma.rn.f32 %acc, %ftmp, %localB26, %acc; 236 | shfl.sync.idx.b32 %ftmp, %localA, 27, 0x1f, 0xffffffff; 237 | fma.rn.f32 %acc, %ftmp, %localB27, %acc; 238 | shfl.sync.idx.b32 %ftmp, %localA, 28, 0x1f, 0xffffffff; 239 | fma.rn.f32 %acc, %ftmp, %localB28, %acc; 240 | shfl.sync.idx.b32 %ftmp, %localA, 29, 0x1f, 0xffffffff; 241 | fma.rn.f32 %acc, %ftmp, %localB29, %acc; 242 | shfl.sync.idx.b32 %ftmp, %localA, 30, 0x1f, 0xffffffff; 243 | fma.rn.f32 %acc, %ftmp, %localB30, %acc; 244 | 
shfl.sync.idx.b32 %ftmp, %localA, 31, 0x1f, 0xffffffff; 245 | fma.rn.f32 %acc, %ftmp, %localB31, %acc; 246 | 247 | @%p0 bra loop_start; 248 | 249 | loop_end: 250 | // Write back to output memory. 251 | 252 | // Output address is offsetX+tid.x + stride*(offsetY+tid.y) 253 | cvt.u64.u32 %dtmp0, %tidY; 254 | add.u64 %dtmp0, %dtmp0, %offsetY; 255 | mul.lo.u64 %dtmp0, %dtmp0, %stride; 256 | cvt.u64.u32 %dtmp1, %tidX; 257 | add.u64 %dtmp1, %dtmp1, %offsetX; 258 | add.u64 %dtmp0, %dtmp0, %dtmp1; 259 | shl.b64 %dtmp0, %dtmp0, 2; 260 | add.u64 %dtmp0, %dtmp0, %ptrOut; 261 | 262 | st.global.f32 [%dtmp0], %acc; 263 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_wmma_v1.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_80 // needed for wmma instruction 3 | .address_size 64 4 | 5 | // A matrix multiplication that uses warp-level communication. 6 | // Multiplies 32x32 blocks using four warps (128 threads). 7 | // Each warp produces a 16x16 chunk and writes the output at the end. 8 | // Does not use shared memory, instead relies on L1/L2 cache. 9 | 10 | .visible .entry wmmaMatmulV1 ( 11 | .param .u64 ptrA, 12 | .param .u64 ptrB, 13 | .param .u64 ptrOut, 14 | .param .u32 numBlocks 15 | ) { 16 | .reg .pred %p0; 17 | 18 | // Attributes of the thread/CTA. 19 | .reg .u32 %tidX; 20 | .reg .u32 %tidY; 21 | .reg .u64 %ctaX; 22 | .reg .u64 %ctaY; 23 | 24 | // Arguments 25 | .reg .u64 %ptrA; 26 | .reg .u64 %ptrB; 27 | .reg .u64 %ptrOut; 28 | .reg .u32 %numBlocks; 29 | 30 | ld.param.u64 %ptrA, [ptrA]; 31 | ld.param.u64 %ptrB, [ptrB]; 32 | ld.param.u64 %ptrOut, [ptrOut]; 33 | ld.param.u32 %numBlocks, [numBlocks]; 34 | 35 | mov.u32 %tidX, %tid.x; // index in warp (0-32) 36 | mov.u32 %tidY, %tid.y; // warp index in block (0-4) 37 | cvt.u64.u32 %ctaX, %ctaid.x; // column of output 38 | cvt.u64.u32 %ctaY, %ctaid.y; // row of output 39 | 40 | // Accumulation registers are stored as 8 floats per thread. 41 | .reg .f32 %out<8>; 42 | mov.f32 %out0, 0.0; 43 | mov.f32 %out1, 0.0; 44 | mov.f32 %out2, 0.0; 45 | mov.f32 %out3, 0.0; 46 | mov.f32 %out4, 0.0; 47 | mov.f32 %out5, 0.0; 48 | mov.f32 %out6, 0.0; 49 | mov.f32 %out7, 0.0; 50 | 51 | // The row-wise stride of the matrices, measured in tf32's. 52 | .reg .u32 %stride32; 53 | .reg .u64 %stride; 54 | shl.b32 %stride32, %numBlocks, 5; 55 | cvt.u64.u32 %stride, %stride32; 56 | 57 | // We will use pointerInA to point to the top-left corner of our warp's 58 | // block in A. Both warp 0 and 1 will have the same pointer. 59 | // We will advance this by 8 every time we load some values and do a matmul. 60 | .reg .u64 %pointerInA; 61 | { 62 | .reg .u64 %tmp; 63 | shl.b64 %tmp, %stride, 7; // 4 bytes per float * 32 rows 64 | mul.lo.u64 %tmp, %tmp, %ctaY; 65 | add.u64 %pointerInA, %ptrA, %tmp; 66 | 67 | // Add row offset for second half of the block. 68 | cvt.u64.u32 %tmp, %tidY; 69 | and.b64 %tmp, %tmp, 2; // 2 if second row of block, 0 if first 70 | mul.lo.u64 %tmp, %tmp, %stride; 71 | shl.b64 %tmp, %tmp, 5; // Offset of thread is (16 rows)*stride*(4 bytes) = stride << (4 + 2) 72 | add.u64 %pointerInA, %pointerInA, %tmp; 73 | } 74 | 75 | // pointerInB is like pointerInA, except that we advance it by row rather than 76 | // by column. 
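// Each warp accumulates one 16x16 f32 tile from m16n16k8 tf32 fragments; every
// inner k-step of 8 advances A by 8 floats and B by 8 full rows (strideB).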
77 | .reg .u64 %pointerInB; 78 | .reg .u64 %strideB; 79 | { 80 | .reg .u64 %tmp; 81 | shl.b64 %tmp, %ctaX, 7; // 4 bytes per float * 32 columns 82 | add.u64 %pointerInB, %ptrB, %tmp; 83 | 84 | // Add column offset for relevant parts of the block. 85 | cvt.u64.u32 %tmp, %tidY; 86 | and.b64 %tmp, %tmp, 1; // 1 if second column of block, 0 if first 87 | shl.b64 %tmp, %tmp, 6; // 16 floats * 4 bytes 88 | add.u64 %pointerInB, %pointerInB, %tmp; 89 | 90 | shl.b64 %strideB, %stride, 5; // 4 bytes * stride * 8 rows 91 | } 92 | 93 | .reg .u32 %remainingIters; 94 | mov.u32 %remainingIters, %numBlocks; 95 | 96 | outer_loop: 97 | setp.gt.u32 %p0, %remainingIters, 0; 98 | @!%p0 bra outer_loop_end; 99 | sub.u32 %remainingIters, %remainingIters, 1; 100 | 101 | { 102 | .reg .u32 %i; 103 | mov.u32 %i, 0; 104 | inner_loop: 105 | .reg .b32 %a<4>; 106 | .reg .b32 %b<4>; 107 | wmma.load.a.sync.aligned.row.m16n16k8.global.tf32 {%a0, %a1, %a2, %a3}, [%pointerInA], %stride32; 108 | wmma.load.b.sync.aligned.row.m16n16k8.global.tf32 {%b0, %b1, %b2, %b3}, [%pointerInB], %stride32; 109 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 110 | add.u64 %pointerInA, %pointerInA, 32; // 8 floats * 4 bytes 111 | add.u64 %pointerInB, %pointerInB, %strideB; 112 | add.u32 %i, %i, 1; 113 | setp.eq.u32 %p0, %i, 4; 114 | @!%p0 bra inner_loop; 115 | inner_loop_end: 116 | } 117 | 118 | bra outer_loop; 119 | outer_loop_end: 120 | 121 | { 122 | .reg .u64 %outColumn; 123 | .reg .u64 %outOffset; 124 | .reg .u64 %tmp; 125 | 126 | shl.b64 %outColumn, %ctaX, 7; // 32 floats * 4 bytes 127 | cvt.u64.u32 %tmp, %tidY; 128 | and.b64 %tmp, %tmp, 1; // 1 if second column of block, 0 if first 129 | shl.b64 %tmp, %tmp, 6; // 16 floats * 4 bytes 130 | add.u64 %outColumn, %outColumn, %tmp; 131 | 132 | shl.b64 %outOffset, %stride, 7; // turn into a row offset (4 bytes), times 32 rows 133 | mul.lo.u64 %outOffset, %outOffset, %ctaY; 134 | cvt.u64.u32 %tmp, %tidY; 135 | 136 | // Offset for bottom half. 137 | and.b64 %tmp, %tmp, 2; // 2 if second row of block, 0 if first 138 | mul.lo.u64 %tmp, %tmp, %stride; 139 | shl.b64 %tmp, %tmp, 5; // for second row: 16 * stride * 4 bytes (already was 2, not 1) 140 | add.u64 %outOffset, %outOffset, %tmp; 141 | 142 | add.u64 %outOffset, %outOffset, %outColumn; 143 | add.u64 %ptrOut, %ptrOut, %outOffset; 144 | 145 | // Copy to %ptrOut. 146 | wmma.store.d.sync.aligned.m16n16k16.global.row.f32 [%ptrOut], {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, %stride32; 147 | } 148 | 149 | ret; 150 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_wmma_v2.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_80 // needed for wmma instruction 3 | .address_size 64 4 | 5 | // This is like matmul_wmma_v1.ptx, but we use shared memory to reduce 6 | // loads from global memory. 7 | 8 | .visible .entry wmmaMatmulV2 ( 9 | .param .u64 ptrA, 10 | .param .u64 ptrB, 11 | .param .u64 ptrOut, 12 | .param .u32 numBlocks 13 | ) { 14 | .reg .pred %p0; 15 | 16 | // Attributes of the thread/CTA. 17 | .reg .u32 %tidX; 18 | .reg .u32 %tidY; 19 | .reg .u64 %ctaX; 20 | .reg .u64 %ctaY; 21 | 22 | // Arguments 23 | .reg .u64 %ptrA; 24 | .reg .u64 %ptrB; 25 | .reg .u64 %ptrOut; 26 | .reg .u32 %numBlocks; 27 | 28 | // Cache for block operands. 
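// Each shared buffer holds one full 32x32 tf32 tile (1024 floats = 4 KiB).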
29 | .shared .align 4 .f32 sharedA[1024]; 30 | .shared .align 4 .f32 sharedB[1024]; 31 | 32 | ld.param.u64 %ptrA, [ptrA]; 33 | ld.param.u64 %ptrB, [ptrB]; 34 | ld.param.u64 %ptrOut, [ptrOut]; 35 | ld.param.u32 %numBlocks, [numBlocks]; 36 | 37 | mov.u32 %tidX, %tid.x; // index in warp (0-32) 38 | mov.u32 %tidY, %tid.y; // warp index in block (0-4) 39 | cvt.u64.u32 %ctaX, %ctaid.x; // column of output 40 | cvt.u64.u32 %ctaY, %ctaid.y; // row of output 41 | 42 | // Accumulation registers are stored as 8 floats per thread. 43 | .reg .f32 %out<8>; 44 | mov.f32 %out0, 0.0; 45 | mov.f32 %out1, 0.0; 46 | mov.f32 %out2, 0.0; 47 | mov.f32 %out3, 0.0; 48 | mov.f32 %out4, 0.0; 49 | mov.f32 %out5, 0.0; 50 | mov.f32 %out6, 0.0; 51 | mov.f32 %out7, 0.0; 52 | 53 | // The row-wise stride of the matrices, measured in tf32's. 54 | .reg .u32 %stride32; 55 | .reg .u64 %stride; 56 | shl.b32 %stride32, %numBlocks, 5; 57 | cvt.u64.u32 %stride, %stride32; 58 | 59 | // This is used to increment by 4 rows at a time while loading. 60 | .reg .u64 %loadStride; 61 | shl.b64 %loadStride, %stride, 4; // 4 bytes * 4 rows 62 | 63 | // We will use pointerInA to point to a thread-specific part of ptrA, 64 | // which we increment as we load blocks. 65 | // We set loadPointerInSharedA to a pointer where we copy things into 66 | // when loading shared memory. 67 | // The other argument, %pointerInSharedA, never changes. 68 | .reg .u64 %pointerInA; 69 | .reg .u32 %loadPointerInSharedA; 70 | .reg .u32 %pointerInSharedA; 71 | { 72 | .reg .u64 %tmp; 73 | .reg .u32 %stmp; 74 | 75 | shl.b64 %tmp, %stride, 7; // 4 bytes per float * 32 rows 76 | mul.lo.u64 %tmp, %tmp, %ctaY; 77 | add.u64 %pointerInA, %ptrA, %tmp; 78 | cvt.u64.u32 %tmp, %tidX; 79 | shl.b64 %tmp, %tmp, 2; // tidX * (4 bytes) 80 | add.u64 %pointerInA, %pointerInA, %tmp; 81 | cvt.u64.u32 %tmp, %tidY; 82 | mul.lo.u64 %tmp, %tmp, %stride; 83 | shl.b64 %tmp, %tmp, 2; // multiply tidY * stride by 4 bytes per row. 84 | add.u64 %pointerInA, %pointerInA, %tmp; 85 | 86 | mov.u32 %loadPointerInSharedA, sharedA; 87 | shl.b32 %stmp, %tidX, 2; 88 | add.u32 %loadPointerInSharedA, %loadPointerInSharedA, %stmp; 89 | shl.b32 %stmp, %tidY, 7; // 32*(4 bytes) 90 | add.u32 %loadPointerInSharedA, %loadPointerInSharedA, %stmp; 91 | 92 | // pointerInSharedA depends only on which output row we are doing. 93 | mov.u32 %pointerInSharedA, sharedA; 94 | and.b32 %stmp, %tidY, 2; 95 | shl.b32 %stmp, %stmp, 10; // Shift down by 32*16*(4 bytes) / factor of 2 96 | add.u32 %pointerInSharedA, %pointerInSharedA, %stmp; 97 | } 98 | 99 | // pointerInB is like pointerInA, except that we advance it by row rather than 100 | // by column. 
101 | .reg .u64 %pointerInB; 102 | .reg .u32 %loadPointerInSharedB; 103 | .reg .u32 %pointerInSharedB; 104 | { 105 | .reg .u32 %stmp; 106 | .reg .u64 %tmp<2>; 107 | 108 | shl.b64 %tmp0, %ctaX, 7; // 4 bytes per float * 32 columns 109 | add.u64 %pointerInB, %ptrB, %tmp0; 110 | cvt.u64.u32 %tmp1, %tidX; 111 | shl.b64 %tmp0, %tmp1, 2; // 4 bytes per float 112 | add.u64 %pointerInB, %pointerInB, %tmp0; 113 | shl.b64 %tmp0, %stride, 2; // stride * 4 bytes per float 114 | cvt.u64.u32 %tmp1, %tidY; 115 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 116 | add.u64 %pointerInB, %pointerInB, %tmp0; 117 | 118 | mov.u32 %loadPointerInSharedB, sharedB; 119 | shl.b32 %stmp, %tidX, 2; 120 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 121 | shl.b32 %stmp, %tidY, 7; // 32*(4 bytes) 122 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 123 | 124 | // pointerInSharedB depends only on which output column we are doing. 125 | mov.u32 %pointerInSharedB, sharedB; 126 | and.b32 %stmp, %tidY, 1; // 1 if second column of block, 0 if first 127 | shl.b32 %stmp, %stmp, 6; // 16 floats * 4 bytes 128 | add.u32 %pointerInSharedB, %pointerInSharedB, %stmp; 129 | } 130 | 131 | .reg .u32 %remainingIters; 132 | mov.u32 %remainingIters, %numBlocks; 133 | 134 | outer_loop: 135 | setp.gt.u32 %p0, %remainingIters, 0; 136 | @!%p0 bra outer_loop_end; 137 | sub.u32 %remainingIters, %remainingIters, 1; 138 | 139 | // Load matrix A into shared memory. 140 | { 141 | .reg .u32 %i; 142 | .reg .f32 %ftmp; 143 | .reg .u64 %tmp; 144 | 145 | ld.global.f32 %ftmp, [%pointerInA]; 146 | st.shared.f32 [%loadPointerInSharedA], %ftmp; 147 | 148 | add.u64 %tmp, %pointerInA, %loadStride; 149 | ld.global.f32 %ftmp, [%tmp]; 150 | st.shared.f32 [%loadPointerInSharedA+512], %ftmp; 151 | 152 | add.u64 %tmp, %tmp, %loadStride; 153 | ld.global.f32 %ftmp, [%tmp]; 154 | st.shared.f32 [%loadPointerInSharedA+1024], %ftmp; 155 | 156 | add.u64 %tmp, %tmp, %loadStride; 157 | ld.global.f32 %ftmp, [%tmp]; 158 | st.shared.f32 [%loadPointerInSharedA+1536], %ftmp; 159 | 160 | add.u64 %tmp, %tmp, %loadStride; 161 | ld.global.f32 %ftmp, [%tmp]; 162 | st.shared.f32 [%loadPointerInSharedA+2048], %ftmp; 163 | 164 | add.u64 %tmp, %tmp, %loadStride; 165 | ld.global.f32 %ftmp, [%tmp]; 166 | st.shared.f32 [%loadPointerInSharedA+2560], %ftmp; 167 | 168 | add.u64 %tmp, %tmp, %loadStride; 169 | ld.global.f32 %ftmp, [%tmp]; 170 | st.shared.f32 [%loadPointerInSharedA+3072], %ftmp; 171 | 172 | add.u64 %tmp, %tmp, %loadStride; 173 | ld.global.f32 %ftmp, [%tmp]; 174 | st.shared.f32 [%loadPointerInSharedA+3584], %ftmp; 175 | 176 | // Advance to the right 32 floats. 177 | add.u64 %pointerInA, %pointerInA, 128; 178 | } 179 | 180 | // Load matrix B into shared memory. 
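    // The B tile is staged the same way as the A tile above: each of the 128
    // threads copies 8 floats, one every 4 rows (%loadStride is 4 rows' worth
    // of bytes), and successive shared-memory offsets of 512 bytes correspond
    // to 4 rows of 32 floats.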
181 | { 182 | .reg .u32 %i; 183 | .reg .f32 %ftmp; 184 | .reg .u64 %tmp; 185 | 186 | ld.global.f32 %ftmp, [%pointerInB]; 187 | st.shared.f32 [%loadPointerInSharedB], %ftmp; 188 | 189 | add.u64 %tmp, %pointerInB, %loadStride; 190 | ld.global.f32 %ftmp, [%tmp]; 191 | st.shared.f32 [%loadPointerInSharedB+512], %ftmp; 192 | 193 | add.u64 %tmp, %tmp, %loadStride; 194 | ld.global.f32 %ftmp, [%tmp]; 195 | st.shared.f32 [%loadPointerInSharedB+1024], %ftmp; 196 | add.u64 %tmp, %tmp, %loadStride; 197 | ld.global.f32 %ftmp, [%tmp]; 198 | st.shared.f32 [%loadPointerInSharedB+1536], %ftmp; 199 | 200 | add.u64 %tmp, %tmp, %loadStride; 201 | ld.global.f32 %ftmp, [%tmp]; 202 | st.shared.f32 [%loadPointerInSharedB+2048], %ftmp; 203 | 204 | add.u64 %tmp, %tmp, %loadStride; 205 | ld.global.f32 %ftmp, [%tmp]; 206 | st.shared.f32 [%loadPointerInSharedB+2560], %ftmp; 207 | 208 | add.u64 %tmp, %tmp, %loadStride; 209 | ld.global.f32 %ftmp, [%tmp]; 210 | st.shared.f32 [%loadPointerInSharedB+3072], %ftmp; 211 | 212 | add.u64 %tmp, %tmp, %loadStride; 213 | ld.global.f32 %ftmp, [%tmp]; 214 | st.shared.f32 [%loadPointerInSharedB+3584], %ftmp; 215 | 216 | add.u64 %pointerInB, %tmp, %loadStride; 217 | } 218 | 219 | bar.sync 0; 220 | 221 | { 222 | .reg .b32 %a<4>; 223 | .reg .b32 %b<4>; 224 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA], 32; 225 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB], 32; 226 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 227 | 228 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+32], 32; 229 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+1024], 32; 230 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 231 | 232 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+64], 32; 233 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+2048], 32; 234 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 235 | 236 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+96], 32; 237 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+3072], 32; 238 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 239 | } 240 | 241 | bar.sync 0; 242 | 243 | bra outer_loop; 244 | outer_loop_end: 245 | 246 | { 247 | .reg .u64 %outColumn; 248 | .reg .u64 %outOffset; 249 | .reg .u64 %tmp; 250 | 251 | shl.b64 %outColumn, %ctaX, 7; // 32 floats * 4 bytes 252 | cvt.u64.u32 %tmp, %tidY; 253 | and.b64 %tmp, %tmp, 1; // 1 if second column of block, 0 if first 254 | shl.b64 %tmp, %tmp, 6; // 16 floats * 4 bytes 255 | add.u64 %outColumn, %outColumn, %tmp; 256 | 257 | shl.b64 %outOffset, %stride, 7; // turn into a row offset (4 bytes), 
times 32 rows 258 | mul.lo.u64 %outOffset, %outOffset, %ctaY; 259 | cvt.u64.u32 %tmp, %tidY; 260 | 261 | // Offset for bottom half. 262 | and.b64 %tmp, %tmp, 2; // 2 if second row of block, 0 if first 263 | mul.lo.u64 %tmp, %tmp, %stride; 264 | shl.b64 %tmp, %tmp, 5; // for second row: 16 * stride * 4 bytes (already was 2, not 1) 265 | add.u64 %outOffset, %outOffset, %tmp; 266 | 267 | add.u64 %outOffset, %outOffset, %outColumn; 268 | add.u64 %ptrOut, %ptrOut, %outOffset; 269 | 270 | // Copy to %ptrOut. 271 | wmma.store.d.sync.aligned.m16n16k16.global.row.f32 [%ptrOut], {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, %stride32; 272 | } 273 | 274 | ret; 275 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_wmma_v3.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_80 // needed for wmma instruction 3 | .address_size 64 4 | 5 | // This is like matmul_wmma_v2.ptx, but we use a different layout for matrix 6 | // A in shared memory to avoid shared-memory bank conflicts. 7 | 8 | .visible .entry wmmaMatmulV3 ( 9 | .param .u64 ptrA, 10 | .param .u64 ptrB, 11 | .param .u64 ptrOut, 12 | .param .u32 numBlocks 13 | ) { 14 | .reg .pred %p0; 15 | 16 | // Attributes of the thread/CTA. 17 | .reg .u32 %tidX; 18 | .reg .u32 %tidY; 19 | .reg .u64 %ctaX; 20 | .reg .u64 %ctaY; 21 | 22 | // Arguments 23 | .reg .u64 %ptrA; 24 | .reg .u64 %ptrB; 25 | .reg .u64 %ptrOut; 26 | .reg .u32 %numBlocks; 27 | 28 | // Cache for block operands. 29 | .shared .align 4 .f32 sharedA[1024]; 30 | .shared .align 4 .f32 sharedB[1024]; 31 | 32 | ld.param.u64 %ptrA, [ptrA]; 33 | ld.param.u64 %ptrB, [ptrB]; 34 | ld.param.u64 %ptrOut, [ptrOut]; 35 | ld.param.u32 %numBlocks, [numBlocks]; 36 | 37 | mov.u32 %tidX, %tid.x; // index in warp (0-32) 38 | mov.u32 %tidY, %tid.y; // warp index in block (0-4) 39 | cvt.u64.u32 %ctaX, %ctaid.x; // column of output 40 | cvt.u64.u32 %ctaY, %ctaid.y; // row of output 41 | 42 | // Accumulation registers are stored as 8 floats per thread. 43 | .reg .f32 %out<8>; 44 | mov.f32 %out0, 0.0; 45 | mov.f32 %out1, 0.0; 46 | mov.f32 %out2, 0.0; 47 | mov.f32 %out3, 0.0; 48 | mov.f32 %out4, 0.0; 49 | mov.f32 %out5, 0.0; 50 | mov.f32 %out6, 0.0; 51 | mov.f32 %out7, 0.0; 52 | 53 | // The row-wise stride of the matrices, measured in tf32's. 54 | .reg .u32 %stride32; 55 | .reg .u64 %stride; 56 | shl.b32 %stride32, %numBlocks, 5; 57 | cvt.u64.u32 %stride, %stride32; 58 | 59 | // This is used to increment by 4 rows at a time while loading. 60 | .reg .u64 %loadStride; 61 | .reg .u64 %loadAStride; 62 | shl.b64 %loadStride, %stride, 4; // 4 bytes * 4 rows 63 | shl.b64 %loadAStride, %loadStride, 2; // 16 rows at a time 64 | 65 | // We will use pointerInA to point to a thread-specific part of ptrA, 66 | // which we increment as we load blocks. 67 | // We set loadPointerInSharedA to a pointer where we copy things into 68 | // when loading shared memory. 69 | // The other argument, %pointerInSharedA, never changes. 
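    // Layout change relative to v2: each 16x8 sub-tile of A's 32x32 block is
    // stored contiguously in shared memory (128 floats = 512 bytes per
    // sub-tile, row-major with leading dimension 8), so wmma.load.a can use
    // stride 8. This avoids the shared-memory bank conflicts that motivated
    // this variant: with v2's 32-float row stride, same-column elements of
    // different rows map to the same bank.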
70 | .reg .u64 %pointerInA; 71 | .reg .u32 %loadPointerInSharedA; 72 | .reg .u32 %pointerInSharedA; 73 | { 74 | .reg .u64 %tmp<2>; 75 | .reg .u32 %stmp; 76 | 77 | shl.b64 %tmp0, %stride, 7; // 4 bytes per float * 32 rows 78 | mul.lo.u64 %tmp0, %tmp0, %ctaY; // ctaY*(32 rows)*(4 bytes) 79 | add.u64 %pointerInA, %ptrA, %tmp0; 80 | cvt.u64.u32 %tmp0, %tidX; 81 | mov.u64 %tmp1, %tmp0; 82 | and.b64 %tmp0, %tmp0, 7; // tidX % 8 gives our X offset 83 | shl.b64 %tmp0, %tmp0, 2; // (tidX % 8) * (4 bytes) 84 | add.u64 %pointerInA, %pointerInA, %tmp0; 85 | cvt.u64.u32 %tmp0, %tidY; 86 | shl.b64 %tmp0, %tmp0, 2; // tidY*4 87 | shr.b64 %tmp1, %tmp1, 3; // tidX//8 88 | add.u64 %tmp0, %tmp0, %tmp1; // (tidY*4 + tidX//8) 89 | mul.lo.u64 %tmp0, %tmp0, %stride; 90 | shl.b64 %tmp0, %tmp0, 2; // multiply (4*tidY + tidX//8) * stride by 4 bytes per row 91 | add.u64 %pointerInA, %pointerInA, %tmp0; 92 | 93 | // We only care whether we are working on the top of bottom half of A. 94 | // In the bottom case, we skip the first four 16x8 matrices. 95 | and.b32 %stmp, %tidY, 2; 96 | shl.b32 %stmp, %stmp, 10; // (16*8 floats)*(4 matrices)*(4 bytes) / (2 from tidY and) 97 | mov.u32 %pointerInSharedA, sharedA; 98 | add.u32 %pointerInSharedA, %pointerInSharedA, %stmp; 99 | 100 | // Each group of four consecutive rows are loaded by a warp in four load 101 | // instructions, such that they can then be rearranged so that the destination 102 | // matrices are consecutive in shared memory. 103 | mov.u32 %loadPointerInSharedA, sharedA; 104 | shl.b32 %stmp, %tidY, 5; // tidY*(32 floats) 105 | add.u32 %stmp, %stmp, %tidX; 106 | shl.b32 %stmp, %stmp, 2; // *= 4 bytes 107 | mov.u32 %loadPointerInSharedA, sharedA; 108 | add.u32 %loadPointerInSharedA, %loadPointerInSharedA, %stmp; // &sharedA[(tidY*32 + tidX)] 109 | } 110 | 111 | // pointerInB is like pointerInA, except that we advance it by row rather than 112 | // by column. 113 | .reg .u64 %pointerInB; 114 | .reg .u32 %loadPointerInSharedB; 115 | .reg .u32 %pointerInSharedB; 116 | { 117 | .reg .u32 %stmp; 118 | .reg .u64 %tmp<2>; 119 | 120 | shl.b64 %tmp0, %ctaX, 7; // 4 bytes per float * 32 columns 121 | add.u64 %pointerInB, %ptrB, %tmp0; 122 | cvt.u64.u32 %tmp1, %tidX; 123 | shl.b64 %tmp0, %tmp1, 2; // 4 bytes per float 124 | add.u64 %pointerInB, %pointerInB, %tmp0; 125 | shl.b64 %tmp0, %stride, 2; // stride * 4 bytes per float 126 | cvt.u64.u32 %tmp1, %tidY; 127 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 128 | add.u64 %pointerInB, %pointerInB, %tmp0; 129 | 130 | mov.u32 %loadPointerInSharedB, sharedB; 131 | shl.b32 %stmp, %tidX, 2; 132 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 133 | shl.b32 %stmp, %tidY, 7; // 32*(4 bytes) 134 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 135 | 136 | // pointerInSharedB depends only on which output column we are doing. 137 | mov.u32 %pointerInSharedB, sharedB; 138 | and.b32 %stmp, %tidY, 1; // 1 if second column of block, 0 if first 139 | shl.b32 %stmp, %stmp, 6; // 16 floats * 4 bytes 140 | add.u32 %pointerInSharedB, %pointerInSharedB, %stmp; 141 | } 142 | 143 | .reg .u32 %remainingIters; 144 | mov.u32 %remainingIters, %numBlocks; 145 | 146 | outer_loop: 147 | setp.gt.u32 %p0, %remainingIters, 0; 148 | @!%p0 bra outer_loop_end; 149 | sub.u32 %remainingIters, %remainingIters, 1; 150 | 151 | // Load matrix A into shared memory. 152 | { 153 | .reg .f32 %ftmp<4>; 154 | .reg .u64 %tmp; 155 | 156 | // We load four matrices at once. 
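    // Offsets +0, +32, +64, +96 bytes are columns +0, +8, +16, +24 of this
    // thread's row, and the four stores scatter them into the four contiguous
    // 16x8 sub-tiles at +0, +512, +1024, +1536 bytes.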
157 | ld.global.f32 %ftmp0, [%pointerInA]; 158 | ld.global.f32 %ftmp1, [%pointerInA+32]; 159 | ld.global.f32 %ftmp2, [%pointerInA+64]; 160 | ld.global.f32 %ftmp3, [%pointerInA+96]; 161 | 162 | // Add size of one matrix each time: 16*8*(4 bytes) 163 | st.shared.f32 [%loadPointerInSharedA], %ftmp0; 164 | st.shared.f32 [%loadPointerInSharedA+512], %ftmp1; 165 | st.shared.f32 [%loadPointerInSharedA+1024], %ftmp2; 166 | st.shared.f32 [%loadPointerInSharedA+1536], %ftmp3; 167 | 168 | // Do the same thing for the bottom 16x32 chunk of A. 169 | add.u64 %tmp, %pointerInA, %loadAStride; 170 | ld.global.f32 %ftmp0, [%tmp]; 171 | ld.global.f32 %ftmp1, [%tmp+32]; 172 | ld.global.f32 %ftmp2, [%tmp+64]; 173 | ld.global.f32 %ftmp3, [%tmp+96]; 174 | 175 | // Add size of one matrix each time: 16*8*(4 bytes) 176 | st.shared.f32 [%loadPointerInSharedA+2048], %ftmp0; 177 | st.shared.f32 [%loadPointerInSharedA+2560], %ftmp1; 178 | st.shared.f32 [%loadPointerInSharedA+3072], %ftmp2; 179 | st.shared.f32 [%loadPointerInSharedA+3584], %ftmp3; 180 | 181 | // Advance to the right 32 floats (columns). 182 | add.u64 %pointerInA, %pointerInA, 128; 183 | } 184 | 185 | // Load matrix B into shared memory. 186 | { 187 | .reg .u32 %i; 188 | .reg .f32 %ftmp; 189 | .reg .u64 %tmp; 190 | 191 | ld.global.f32 %ftmp, [%pointerInB]; 192 | st.shared.f32 [%loadPointerInSharedB], %ftmp; 193 | 194 | add.u64 %tmp, %pointerInB, %loadStride; 195 | ld.global.f32 %ftmp, [%tmp]; 196 | st.shared.f32 [%loadPointerInSharedB+512], %ftmp; 197 | 198 | add.u64 %tmp, %tmp, %loadStride; 199 | ld.global.f32 %ftmp, [%tmp]; 200 | st.shared.f32 [%loadPointerInSharedB+1024], %ftmp; 201 | add.u64 %tmp, %tmp, %loadStride; 202 | ld.global.f32 %ftmp, [%tmp]; 203 | st.shared.f32 [%loadPointerInSharedB+1536], %ftmp; 204 | 205 | add.u64 %tmp, %tmp, %loadStride; 206 | ld.global.f32 %ftmp, [%tmp]; 207 | st.shared.f32 [%loadPointerInSharedB+2048], %ftmp; 208 | 209 | add.u64 %tmp, %tmp, %loadStride; 210 | ld.global.f32 %ftmp, [%tmp]; 211 | st.shared.f32 [%loadPointerInSharedB+2560], %ftmp; 212 | 213 | add.u64 %tmp, %tmp, %loadStride; 214 | ld.global.f32 %ftmp, [%tmp]; 215 | st.shared.f32 [%loadPointerInSharedB+3072], %ftmp; 216 | 217 | add.u64 %tmp, %tmp, %loadStride; 218 | ld.global.f32 %ftmp, [%tmp]; 219 | st.shared.f32 [%loadPointerInSharedB+3584], %ftmp; 220 | 221 | add.u64 %pointerInB, %tmp, %loadStride; 222 | } 223 | 224 | bar.sync 0; 225 | 226 | { 227 | .reg .b32 %a<4>; 228 | .reg .b32 %b<4>; 229 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA], 8; 230 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB], 32; 231 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 232 | 233 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+512], 8; 234 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+1024], 32; 235 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 236 | 237 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+1024], 8; 238 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, 
[%pointerInSharedB+2048], 32; 239 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 240 | 241 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+1536], 8; 242 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+3072], 32; 243 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 244 | } 245 | 246 | bar.sync 0; 247 | 248 | bra outer_loop; 249 | outer_loop_end: 250 | 251 | { 252 | .reg .u64 %outColumn; 253 | .reg .u64 %outOffset; 254 | .reg .u64 %tmp; 255 | 256 | shl.b64 %outColumn, %ctaX, 7; // 32 floats * 4 bytes 257 | cvt.u64.u32 %tmp, %tidY; 258 | and.b64 %tmp, %tmp, 1; // 1 if second column of block, 0 if first 259 | shl.b64 %tmp, %tmp, 6; // 16 floats * 4 bytes 260 | add.u64 %outColumn, %outColumn, %tmp; 261 | 262 | shl.b64 %outOffset, %stride, 7; // turn into a row offset (4 bytes), times 32 rows 263 | mul.lo.u64 %outOffset, %outOffset, %ctaY; 264 | cvt.u64.u32 %tmp, %tidY; 265 | 266 | // Offset for bottom half. 267 | and.b64 %tmp, %tmp, 2; // 2 if second row of block, 0 if first 268 | mul.lo.u64 %tmp, %tmp, %stride; 269 | shl.b64 %tmp, %tmp, 5; // for second row: 16 * stride * 4 bytes (already was 2, not 1) 270 | add.u64 %outOffset, %outOffset, %tmp; 271 | 272 | add.u64 %outOffset, %outOffset, %outColumn; 273 | add.u64 %ptrOut, %ptrOut, %outOffset; 274 | 275 | // Copy to %ptrOut. 276 | wmma.store.d.sync.aligned.m16n16k16.global.row.f32 [%ptrOut], {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, %stride32; 277 | } 278 | 279 | ret; 280 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/matmul_wmma_v4.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_80 // needed for wmma instruction 3 | .address_size 64 4 | 5 | // This is like matmul_wmma_v3.ptx, but with a layout for matrix B in 6 | // shared memory that avoids bank conflicts. 7 | 8 | .visible .entry wmmaMatmulV4 ( 9 | .param .u64 ptrA, 10 | .param .u64 ptrB, 11 | .param .u64 ptrOut, 12 | .param .u32 numBlocks 13 | ) { 14 | .reg .pred %p0; 15 | 16 | // Attributes of the thread/CTA. 17 | .reg .u32 %tidX; 18 | .reg .u32 %tidY; 19 | .reg .u64 %ctaX; 20 | .reg .u64 %ctaY; 21 | 22 | // Arguments 23 | .reg .u64 %ptrA; 24 | .reg .u64 %ptrB; 25 | .reg .u64 %ptrOut; 26 | .reg .u32 %numBlocks; 27 | 28 | // Cache for block operands. 29 | .shared .align 4 .f32 sharedA[1024]; 30 | .shared .align 4 .f32 sharedB[1024]; 31 | 32 | ld.param.u64 %ptrA, [ptrA]; 33 | ld.param.u64 %ptrB, [ptrB]; 34 | ld.param.u64 %ptrOut, [ptrOut]; 35 | ld.param.u32 %numBlocks, [numBlocks]; 36 | 37 | mov.u32 %tidX, %tid.x; // index in warp (0-32) 38 | mov.u32 %tidY, %tid.y; // warp index in block (0-4) 39 | cvt.u64.u32 %ctaX, %ctaid.x; // column of output 40 | cvt.u64.u32 %ctaY, %ctaid.y; // row of output 41 | 42 | // Accumulation registers are stored as 8 floats per thread. 
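    // A 16x16 f32 accumulator fragment holds 256 values; spread over the 32
    // lanes of a warp, that is 8 registers per thread, which is why %out<8>
    // is declared and zeroed below.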
43 | .reg .f32 %out<8>; 44 | mov.f32 %out0, 0.0; 45 | mov.f32 %out1, 0.0; 46 | mov.f32 %out2, 0.0; 47 | mov.f32 %out3, 0.0; 48 | mov.f32 %out4, 0.0; 49 | mov.f32 %out5, 0.0; 50 | mov.f32 %out6, 0.0; 51 | mov.f32 %out7, 0.0; 52 | 53 | // The row-wise stride of the matrices, measured in tf32's. 54 | .reg .u32 %stride32; 55 | .reg .u64 %stride; 56 | shl.b32 %stride32, %numBlocks, 5; 57 | cvt.u64.u32 %stride, %stride32; 58 | 59 | // This is used to increment by 4 rows at a time while loading. 60 | .reg .u64 %loadAStride; 61 | .reg .u64 %loadBStride; 62 | { 63 | shl.b64 %loadAStride, %stride, 6; // 16 rows at a time 64 | shl.b64 %loadBStride, %stride, 5; // 8 rows at a time 65 | } 66 | 67 | // We will use pointerInA to point to a thread-specific part of ptrA, 68 | // which we increment as we load blocks. 69 | // We set loadPointerInSharedA to a pointer where we copy things into 70 | // when loading shared memory. 71 | // The other argument, %pointerInSharedA, never changes. 72 | .reg .u64 %pointerInA; 73 | .reg .u32 %loadPointerInSharedA; 74 | .reg .u32 %pointerInSharedA; 75 | { 76 | .reg .u64 %tmp<2>; 77 | .reg .u32 %stmp; 78 | 79 | shl.b64 %tmp0, %stride, 7; // 4 bytes per float * 32 rows 80 | mul.lo.u64 %tmp0, %tmp0, %ctaY; // ctaY*(32 rows)*(4 bytes) 81 | add.u64 %pointerInA, %ptrA, %tmp0; 82 | cvt.u64.u32 %tmp0, %tidX; 83 | mov.u64 %tmp1, %tmp0; 84 | and.b64 %tmp0, %tmp0, 7; // tidX % 8 gives our X offset 85 | shl.b64 %tmp0, %tmp0, 2; // (tidX % 8) * (4 bytes) 86 | add.u64 %pointerInA, %pointerInA, %tmp0; 87 | cvt.u64.u32 %tmp0, %tidY; 88 | shl.b64 %tmp0, %tmp0, 2; // tidY*4 89 | shr.b64 %tmp1, %tmp1, 3; // tidX//8 90 | add.u64 %tmp0, %tmp0, %tmp1; // (tidY*4 + tidX//8) 91 | mul.lo.u64 %tmp0, %tmp0, %stride; 92 | shl.b64 %tmp0, %tmp0, 2; // multiply (4*tidY + tidX//8) * stride by 4 bytes per row 93 | add.u64 %pointerInA, %pointerInA, %tmp0; 94 | 95 | // We only care whether we are working on the top of bottom half of A. 96 | // In the bottom case, we skip the first four 16x8 matrices. 97 | and.b32 %stmp, %tidY, 2; 98 | shl.b32 %stmp, %stmp, 10; // (16*8 floats)*(4 matrices)*(4 bytes) / (2 from tidY and) 99 | mov.u32 %pointerInSharedA, sharedA; 100 | add.u32 %pointerInSharedA, %pointerInSharedA, %stmp; 101 | 102 | // Each group of four consecutive rows are loaded by a warp in four load 103 | // instructions, such that they can then be rearranged so that the destination 104 | // matrices are consecutive in shared memory. 105 | mov.u32 %loadPointerInSharedA, sharedA; 106 | shl.b32 %stmp, %tidY, 5; // tidY*(32 floats) 107 | add.u32 %stmp, %stmp, %tidX; 108 | shl.b32 %stmp, %stmp, 2; // *= 4 bytes 109 | mov.u32 %loadPointerInSharedA, sharedA; 110 | add.u32 %loadPointerInSharedA, %loadPointerInSharedA, %stmp; // &sharedA[(tidY*32 + tidX)] 111 | } 112 | 113 | // Each warp loads two rows of B at a time, and dumps each 8x16 sub-block of 114 | // B into its own sub-matrix. 115 | // The left 32x16 column is contiguous in shared memory, and then the right 116 | // 32x16 column follows it. 
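    // In other words, the 32x32 B tile becomes eight 8x16 row-major sub-tiles
    // with a leading dimension of 16 floats: four covering columns 0-15 (the
    // first 2048 bytes of sharedB), then four covering columns 16-31. The
    // wmma.load.b below can then use stride 16 rather than 32.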
117 | .reg .u64 %pointerInB; 118 | .reg .u32 %loadPointerInSharedB; 119 | .reg .u32 %pointerInSharedB; 120 | { 121 | .reg .u32 %stmp; 122 | .reg .u64 %tmp<2>; 123 | 124 | shl.b64 %tmp0, %ctaX, 7; // 4 bytes per float * 32 columns 125 | add.u64 %pointerInB, %ptrB, %tmp0; 126 | cvt.u64.u32 %tmp1, %tidX; 127 | and.b64 %tmp0, %tmp1, 15; // tidX % 16 128 | shl.b64 %tmp0, %tmp0, 2; // 4 bytes per float 129 | add.u64 %pointerInB, %pointerInB, %tmp0; // pointerInB += (tidX % 16) * (4 bytes) 130 | shl.b64 %tmp0, %stride, 2; // stride * 4 bytes per float 131 | shr.b64 %tmp1, %tmp1, 4; // tidX // 16 132 | mul.lo.u64 %tmp1, %tmp1, %tmp0; 133 | add.u64 %pointerInB, %pointerInB, %tmp1; // pointerInB += (tidX // 16) * (bytes/row) 134 | cvt.u64.u32 %tmp1, %tidY; 135 | shl.b64 %tmp1, %tmp1, 1; // tidY*2 136 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 137 | add.u64 %pointerInB, %pointerInB, %tmp0; // pointerInB += tidY * 2 * (bytes/row) 138 | 139 | mov.u32 %loadPointerInSharedB, sharedB; 140 | shl.b32 %stmp, %tidX, 2; // tidX*4 bytes 141 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 142 | shl.b32 %stmp, %tidY, 7; // tidY * 32 * (4 bytes) 143 | add.u32 %loadPointerInSharedB, %loadPointerInSharedB, %stmp; 144 | 145 | // pointerInSharedB depends on which output column we are doing. 146 | mov.u32 %pointerInSharedB, sharedB; 147 | and.b32 %stmp, %tidY, 1; // 1 if second column of block, 0 if first 148 | shl.b32 %stmp, %stmp, 11; // 4 matrices * (16 * 8) * 4 bytes 149 | add.u32 %pointerInSharedB, %pointerInSharedB, %stmp; 150 | } 151 | 152 | .reg .u32 %remainingIters; 153 | mov.u32 %remainingIters, %numBlocks; 154 | 155 | outer_loop: 156 | setp.gt.u32 %p0, %remainingIters, 0; 157 | @!%p0 bra outer_loop_end; 158 | sub.u32 %remainingIters, %remainingIters, 1; 159 | 160 | // Load matrix A into shared memory. 161 | { 162 | .reg .f32 %ftmp<4>; 163 | .reg .u64 %tmp; 164 | 165 | // We load four matrices at once. 166 | ld.global.f32 %ftmp0, [%pointerInA]; 167 | ld.global.f32 %ftmp1, [%pointerInA+32]; 168 | ld.global.f32 %ftmp2, [%pointerInA+64]; 169 | ld.global.f32 %ftmp3, [%pointerInA+96]; 170 | 171 | // Add size of one matrix each time: 16*8*(4 bytes) 172 | st.shared.f32 [%loadPointerInSharedA], %ftmp0; 173 | st.shared.f32 [%loadPointerInSharedA+512], %ftmp1; 174 | st.shared.f32 [%loadPointerInSharedA+1024], %ftmp2; 175 | st.shared.f32 [%loadPointerInSharedA+1536], %ftmp3; 176 | 177 | // Do the same thing for the bottom 16x32 chunk of A. 178 | add.u64 %tmp, %pointerInA, %loadAStride; 179 | ld.global.f32 %ftmp0, [%tmp]; 180 | ld.global.f32 %ftmp1, [%tmp+32]; 181 | ld.global.f32 %ftmp2, [%tmp+64]; 182 | ld.global.f32 %ftmp3, [%tmp+96]; 183 | 184 | // Add size of one matrix each time: 16*8*(4 bytes) 185 | st.shared.f32 [%loadPointerInSharedA+2048], %ftmp0; 186 | st.shared.f32 [%loadPointerInSharedA+2560], %ftmp1; 187 | st.shared.f32 [%loadPointerInSharedA+3072], %ftmp2; 188 | st.shared.f32 [%loadPointerInSharedA+3584], %ftmp3; 189 | 190 | // Advance to the right 32 floats (columns). 191 | add.u64 %pointerInA, %pointerInA, 128; 192 | } 193 | 194 | // Load matrix B into shared memory. 195 | { 196 | .reg .f32 %ftmp<2>; 197 | 198 | ld.global.f32 %ftmp0, [%pointerInB]; 199 | ld.global.f32 %ftmp1, [%pointerInB+64]; // offset by 16 columns 200 | st.shared.f32 [%loadPointerInSharedB], %ftmp0; 201 | st.shared.f32 [%loadPointerInSharedB+2048], %ftmp1; // (4 matrices) * (16 * 8) * (4 bytes) 202 | 203 | // Repeat while going down rows in B. 
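    // Each hop below is %loadBStride = 8 rows, so this thread's four load
    // pairs cover rows r, r+8, r+16 and r+24 of the tile, where
    // r = tidY*2 + tidX/16.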
204 | add.u64 %pointerInB, %pointerInB, %loadBStride; 205 | ld.global.f32 %ftmp0, [%pointerInB]; 206 | ld.global.f32 %ftmp1, [%pointerInB+64]; 207 | st.shared.f32 [%loadPointerInSharedB+512], %ftmp0; 208 | st.shared.f32 [%loadPointerInSharedB+2560], %ftmp1; 209 | 210 | add.u64 %pointerInB, %pointerInB, %loadBStride; 211 | ld.global.f32 %ftmp0, [%pointerInB]; 212 | ld.global.f32 %ftmp1, [%pointerInB+64]; 213 | st.shared.f32 [%loadPointerInSharedB+1024], %ftmp0; 214 | st.shared.f32 [%loadPointerInSharedB+3072], %ftmp1; 215 | 216 | add.u64 %pointerInB, %pointerInB, %loadBStride; 217 | ld.global.f32 %ftmp0, [%pointerInB]; 218 | ld.global.f32 %ftmp1, [%pointerInB+64]; 219 | st.shared.f32 [%loadPointerInSharedB+1536], %ftmp0; 220 | st.shared.f32 [%loadPointerInSharedB+3584], %ftmp1; 221 | 222 | add.u64 %pointerInB, %pointerInB, %loadBStride; 223 | } 224 | 225 | bar.sync 0; 226 | 227 | { 228 | .reg .b32 %a<4>; 229 | .reg .b32 %b<4>; 230 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA], 8; 231 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB], 16; 232 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 233 | 234 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+512], 8; 235 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+512], 16; 236 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 237 | 238 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+1024], 8; 239 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+1024], 16; 240 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 241 | 242 | wmma.load.a.sync.aligned.row.m16n16k8.shared.tf32 {%a0, %a1, %a2, %a3}, [%pointerInSharedA+1536], 8; 243 | wmma.load.b.sync.aligned.row.m16n16k8.shared.tf32 {%b0, %b1, %b2, %b3}, [%pointerInSharedB+1536], 16; 244 | wmma.mma.sync.aligned.row.row.m16n16k8.f32.tf32.tf32.f32 {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, {%a0, %a1, %a2, %a3}, {%b0, %b1, %b2, %b3}, {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}; 245 | } 246 | 247 | bar.sync 0; 248 | 249 | bra outer_loop; 250 | outer_loop_end: 251 | 252 | { 253 | .reg .u64 %outColumn; 254 | .reg .u64 %outOffset; 255 | .reg .u64 %tmp; 256 | 257 | shl.b64 %outColumn, %ctaX, 7; // 32 floats * 4 bytes 258 | cvt.u64.u32 %tmp, %tidY; 259 | and.b64 %tmp, %tmp, 1; // 1 if second column of block, 0 if first 260 | shl.b64 %tmp, %tmp, 6; // 16 floats * 4 bytes 261 | add.u64 %outColumn, %outColumn, %tmp; 262 | 263 | shl.b64 %outOffset, %stride, 7; // turn into a row offset (4 bytes), times 32 rows 264 | mul.lo.u64 %outOffset, %outOffset, %ctaY; 265 | cvt.u64.u32 %tmp, %tidY; 266 | 267 | // Offset for bottom half. 
268 | and.b64 %tmp, %tmp, 2; // 2 if second row of block, 0 if first 269 | mul.lo.u64 %tmp, %tmp, %stride; 270 | shl.b64 %tmp, %tmp, 5; // for second row: 16 * stride * 4 bytes (already was 2, not 1) 271 | add.u64 %outOffset, %outOffset, %tmp; 272 | 273 | add.u64 %outOffset, %outOffset, %outColumn; 274 | add.u64 %ptrOut, %ptrOut, %outOffset; 275 | 276 | // Copy to %ptrOut. 277 | wmma.store.d.sync.aligned.m16n16k16.global.row.f32 [%ptrOut], {%out0, %out1, %out2, %out3, %out4, %out5, %out6, %out7}, %stride32; 278 | } 279 | 280 | ret; 281 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionAllMaxNaive ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .u64 numBlocks 9 | ) { 10 | .reg .pred %p0; 11 | 12 | // Arguments 13 | .reg .u64 %ptrIn; 14 | .reg .u64 %ptrOut; 15 | .reg .u64 %numBlocks; 16 | 17 | .reg .u64 %i; 18 | .reg .f32 %curMax; 19 | .reg .u64 %tmp0; 20 | .reg .u32 %stmp<2>; 21 | .reg .f32 %ftmp; 22 | 23 | .shared .align 4 .f32 results[32]; 24 | 25 | // Load arguments. 26 | ld.param.u64 %ptrIn, [ptrIn]; 27 | ld.param.u64 %ptrOut, [ptrOut]; 28 | ld.param.u64 %numBlocks, [numBlocks]; 29 | 30 | cvt.u64.u32 %tmp0, %tid.x; 31 | shl.b64 %tmp0, %tmp0, 2; 32 | add.u64 %ptrIn, %ptrIn, %tmp0; 33 | 34 | // Base condition: use our output. 35 | ld.global.f32 %curMax, [%ptrIn]; 36 | 37 | // Skip loop if each block only reads one value. 38 | setp.lt.u64 %p0, %numBlocks, 2; 39 | @%p0 bra loop_end; 40 | 41 | mov.u64 %i, 1; 42 | loop_start: 43 | add.u64 %ptrIn, %ptrIn, 4096; // block size * 4 44 | ld.global.f32 %ftmp, [%ptrIn]; 45 | max.f32 %curMax, %curMax, %ftmp; 46 | add.u64 %i, %i, 1; 47 | setp.lt.u64 %p0, %i, %numBlocks; 48 | @%p0 bra loop_start; 49 | loop_end: 50 | // Synchronize on warp using a hypercube. 51 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 52 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 53 | max.f32 %curMax, %curMax, %ftmp; 54 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 55 | max.f32 %curMax, %curMax, %ftmp; 56 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 57 | max.f32 %curMax, %curMax, %ftmp; 58 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 59 | max.f32 %curMax, %curMax, %ftmp; 60 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 61 | max.f32 %curMax, %curMax, %ftmp; 62 | 63 | // Our warp writes to results[tid.x//32]. 64 | mov.u32 %stmp0, results; 65 | mov.u32 %stmp1, %tid.x; 66 | shr.b32 %stmp1, %stmp1, 5; 67 | shl.b32 %stmp1, %stmp1, 2; 68 | add.u32 %stmp0, %stmp0, %stmp1; 69 | // Only write from rank 0 of warp. 70 | mov.u32 %stmp1, %tid.x; 71 | and.b32 %stmp1, %stmp1, 31; 72 | setp.eq.u32 %p0, %stmp1, 0; 73 | @%p0 st.shared.f32 [%stmp0], %curMax; 74 | bar.sync 0; 75 | 76 | // Exit on all but first warp. 77 | mov.u32 %stmp1, %tid.x; 78 | and.b32 %stmp1, %stmp1, 992; // 1024 ^ 31 79 | setp.eq.u32 %p0, %stmp1, 0; 80 | @!%p0 ret; 81 | 82 | // Reduce the shared memory from the first warp. 
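    // Each lane of the surviving warp reloads one per-warp partial max from
    // results[] and repeats the butterfly, so lane 0 ends up with the
    // block-wide maximum and writes it to %ptrOut.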
83 | mov.u32 %stmp0, results; 84 | mov.u32 %stmp1, %tid.x; 85 | and.b32 %stmp1, %stmp1, 31; 86 | shl.b32 %stmp1, %stmp1, 2; 87 | add.u32 %stmp0, %stmp0, %stmp1; 88 | 89 | ld.shared.f32 %curMax, [%stmp0]; 90 | 91 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 92 | max.f32 %curMax, %curMax, %ftmp; 93 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 94 | max.f32 %curMax, %curMax, %ftmp; 95 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 96 | max.f32 %curMax, %curMax, %ftmp; 97 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 98 | max.f32 %curMax, %curMax, %ftmp; 99 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 100 | max.f32 %curMax, %curMax, %ftmp; 101 | 102 | setp.eq.u32 %p0, %stmp1, 0; 103 | @%p0 st.global.f32 [%ptrOut], %curMax; 104 | 105 | ret; 106 | } 107 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionAllMaxNaiveOpt ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .u64 numBlocks 9 | ) { 10 | .reg .pred %p0; 11 | 12 | // Arguments 13 | .reg .u64 %ptrIn; 14 | .reg .u64 %ptrOut; 15 | .reg .u64 %numBlocks; 16 | 17 | .reg .u64 %i; 18 | .reg .f32 %curMax; 19 | .reg .u64 %tmp0; 20 | .reg .u32 %stmp<2>; 21 | .reg .f32 %ftmp; 22 | .reg .v4 .f32 %ftmpVec<2>; 23 | 24 | .shared .align 4 .f32 results[32]; 25 | 26 | // Load arguments. 27 | ld.param.u64 %ptrIn, [ptrIn]; 28 | ld.param.u64 %ptrOut, [ptrOut]; 29 | ld.param.u64 %numBlocks, [numBlocks]; 30 | 31 | // We support multiple blocks in the case where multiple 32 | // outputs are being written. 33 | // Input is offset 1024*4*numBlocks*ctaid.x, output offset by 4*ctaid.x 34 | cvt.u64.u32 %tmp0, %ctaid.x; 35 | shl.b64 %tmp0, %tmp0, 2; 36 | add.u64 %ptrOut, %ptrOut, %tmp0; 37 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 38 | shl.b64 %tmp0, %tmp0, 10; 39 | add.u64 %ptrIn, %ptrIn, %tmp0; 40 | 41 | // Each rank is offset by 16 bytes. 42 | cvt.u64.u32 %tmp0, %tid.x; 43 | shl.b64 %tmp0, %tmp0, 4; 44 | add.u64 %ptrIn, %ptrIn, %tmp0; 45 | 46 | // Base condition: use our output. 47 | ld.global.f32 %curMax, [%ptrIn]; 48 | 49 | mov.u64 %i, 0; 50 | loop_start: 51 | ld.global.v4.f32 %ftmpVec0, [%ptrIn]; 52 | ld.global.v4.f32 %ftmpVec1, [%ptrIn+16384]; 53 | add.u64 %ptrIn, %ptrIn, 32768; 54 | max.f32 %curMax, %curMax, %ftmpVec0.w; 55 | max.f32 %curMax, %curMax, %ftmpVec0.x; 56 | max.f32 %curMax, %curMax, %ftmpVec0.y; 57 | max.f32 %curMax, %curMax, %ftmpVec0.z; 58 | max.f32 %curMax, %curMax, %ftmpVec1.w; 59 | max.f32 %curMax, %curMax, %ftmpVec1.x; 60 | max.f32 %curMax, %curMax, %ftmpVec1.y; 61 | max.f32 %curMax, %curMax, %ftmpVec1.z; 62 | add.u64 %i, %i, 8; 63 | setp.lt.u64 %p0, %i, %numBlocks; 64 | @%p0 bra loop_start; 65 | loop_end: 66 | 67 | // Synchronize on warp using a hypercube. 
68 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 69 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 70 | max.f32 %curMax, %curMax, %ftmp; 71 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 72 | max.f32 %curMax, %curMax, %ftmp; 73 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 74 | max.f32 %curMax, %curMax, %ftmp; 75 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 76 | max.f32 %curMax, %curMax, %ftmp; 77 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 78 | max.f32 %curMax, %curMax, %ftmp; 79 | 80 | // Our warp writes to results[tid.x//32]. 81 | mov.u32 %stmp0, results; 82 | mov.u32 %stmp1, %tid.x; 83 | shr.b32 %stmp1, %stmp1, 5; 84 | shl.b32 %stmp1, %stmp1, 2; 85 | add.u32 %stmp0, %stmp0, %stmp1; 86 | // Only write from rank 0 of warp. 87 | mov.u32 %stmp1, %tid.x; 88 | and.b32 %stmp1, %stmp1, 31; 89 | setp.eq.u32 %p0, %stmp1, 0; 90 | @%p0 st.shared.f32 [%stmp0], %curMax; 91 | bar.sync 0; 92 | 93 | // Exit on all but first warp. 94 | mov.u32 %stmp1, %tid.x; 95 | and.b32 %stmp1, %stmp1, 992; // 1024 ^ 31 96 | setp.eq.u32 %p0, %stmp1, 0; 97 | @!%p0 ret; 98 | 99 | // Reduce the shared memory from the first warp. 100 | mov.u32 %stmp0, results; 101 | mov.u32 %stmp1, %tid.x; 102 | and.b32 %stmp1, %stmp1, 31; 103 | shl.b32 %stmp1, %stmp1, 2; 104 | add.u32 %stmp0, %stmp0, %stmp1; 105 | 106 | ld.shared.f32 %curMax, [%stmp0]; 107 | 108 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 109 | max.f32 %curMax, %curMax, %ftmp; 110 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 111 | max.f32 %curMax, %curMax, %ftmp; 112 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 113 | max.f32 %curMax, %curMax, %ftmp; 114 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 115 | max.f32 %curMax, %curMax, %ftmp; 116 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 117 | max.f32 %curMax, %curMax, %ftmp; 118 | 119 | setp.eq.u32 %p0, %stmp1, 0; 120 | @%p0 st.global.f32 [%ptrOut], %curMax; 121 | 122 | ret; 123 | } 124 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_flexible.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Similar to reduction_all_max_naive_opt.ptx, except that 6 | // we make use of multi-dimensional blocks where some warps 7 | // in the block can be inactive. This allows us to experiment 8 | // with different layouts for a fixed number of threads. 9 | // 10 | // In particular, block size should be (32, N, 1024/(32*N)), 11 | // in which case N warps will be active and the rest will be 12 | // stuck in a barrier to prevent other threads from occupying 13 | // the SM. 14 | 15 | .visible .entry reductionAllMaxNaiveOptFlexible ( 16 | .param .u64 ptrIn, 17 | .param .u64 ptrOut, 18 | .param .u64 numBlocks 19 | ) { 20 | .reg .pred %p0; 21 | 22 | // Arguments 23 | .reg .u64 %ptrIn; 24 | .reg .u64 %ptrOut; 25 | .reg .u64 %numBlocks; 26 | 27 | .reg .u64 %i; 28 | .reg .u64 %tmp<2>; 29 | .reg .u32 %stmp<2>; 30 | .reg .u64 %blockSize; 31 | .reg .f32 %curMax; 32 | .reg .f32 %ftmp; 33 | .reg .v4 .f32 %ftmpVec<2>; 34 | 35 | .shared .align 4 .f32 results[32]; 36 | 37 | // Load arguments. 
38 | ld.param.u64 %ptrIn, [ptrIn]; 39 | ld.param.u64 %ptrOut, [ptrOut]; 40 | ld.param.u64 %numBlocks, [numBlocks]; 41 | 42 | // We might not do any work from certain threads of this block, 43 | // for experimentation purposes. 44 | // In particular, we do work from tid.z == 0. 45 | mov.u32 %stmp0, %tid.z; 46 | setp.eq.u32 %p0, %stmp0, 0; 47 | @!%p0 bra end_of_block; 48 | 49 | // blockSize = ntid.x * ntid.y (ignore ntid.z) 50 | mov.u32 %stmp0, %ntid.x; 51 | mov.u32 %stmp1, %ntid.y; 52 | mul.wide.u32 %blockSize, %stmp0, %stmp1; 53 | 54 | // Input is offset ctaid.x*4*blockSize*numBlocks, output offset by 4*ctaid.x 55 | cvt.u64.u32 %tmp0, %ctaid.x; 56 | shl.b64 %tmp0, %tmp0, 2; 57 | add.u64 %ptrOut, %ptrOut, %tmp0; 58 | mul.lo.u64 %tmp0, %tmp0, %blockSize; 59 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 60 | add.u64 %ptrIn, %ptrIn, %tmp0; 61 | 62 | // Each rank is offset by 16 bytes. 63 | cvt.u64.u32 %tmp0, %tid.x; 64 | cvt.u64.u32 %tmp1, %tid.y; 65 | shl.b64 %tmp1, %tmp1, 5; 66 | add.u64 %tmp0, %tmp0, %tmp1; 67 | shl.b64 %tmp0, %tmp0, 4; 68 | add.u64 %ptrIn, %ptrIn, %tmp0; 69 | 70 | // Base condition: use our output. 71 | ld.global.f32 %curMax, [%ptrIn]; 72 | 73 | // Stride is blockSize*16 bytes. 74 | shl.b64 %tmp0, %blockSize, 4; 75 | 76 | mov.u64 %i, 0; 77 | loop_start: 78 | ld.global.v4.f32 %ftmpVec0, [%ptrIn]; 79 | add.u64 %ptrIn, %ptrIn, %tmp0; 80 | ld.global.v4.f32 %ftmpVec1, [%ptrIn]; 81 | add.u64 %ptrIn, %ptrIn, %tmp0; 82 | max.f32 %curMax, %curMax, %ftmpVec0.w; 83 | max.f32 %curMax, %curMax, %ftmpVec0.x; 84 | max.f32 %curMax, %curMax, %ftmpVec0.y; 85 | max.f32 %curMax, %curMax, %ftmpVec0.z; 86 | max.f32 %curMax, %curMax, %ftmpVec1.w; 87 | max.f32 %curMax, %curMax, %ftmpVec1.x; 88 | max.f32 %curMax, %curMax, %ftmpVec1.y; 89 | max.f32 %curMax, %curMax, %ftmpVec1.z; 90 | add.u64 %i, %i, 8; 91 | setp.lt.u64 %p0, %i, %numBlocks; 92 | @%p0 bra loop_start; 93 | loop_end: 94 | 95 | // Synchronize on warp using a hypercube. 96 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 97 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 98 | max.f32 %curMax, %curMax, %ftmp; 99 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 100 | max.f32 %curMax, %curMax, %ftmp; 101 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 102 | max.f32 %curMax, %curMax, %ftmp; 103 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 104 | max.f32 %curMax, %curMax, %ftmp; 105 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 106 | max.f32 %curMax, %curMax, %ftmp; 107 | 108 | // Our warp writes to results[tid.y]. 109 | mov.u32 %stmp0, results; 110 | mov.u32 %stmp1, %tid.y; 111 | shl.b32 %stmp1, %stmp1, 2; 112 | add.u32 %stmp0, %stmp0, %stmp1; 113 | // Only write from rank 0 of warp. 114 | mov.u32 %stmp1, %tid.x; 115 | setp.eq.u32 %p0, %stmp1, 0; 116 | @%p0 st.shared.f32 [%stmp0], %curMax; 117 | 118 | // Wait for all threads to write to shmem 119 | cvt.u32.u64 %stmp0, %blockSize; 120 | bar.sync 0, %stmp0; 121 | 122 | // Exit on all but first warp, where we do final reduction. 123 | mov.u32 %stmp1, %tid.y; 124 | setp.eq.u32 %p0, %stmp1, 0; 125 | @!%p0 bra end_of_block; 126 | 127 | // Reduce the shared memory from the first warp. 
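    // Only the first ntid.y entries of results[] were written, hence the
    // tid.x < ntid.y guard on the load below. Lanes that skip the load keep
    // warp 0's own partial max (which equals results[0]), so including them
    // in the butterfly does not change the result.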
128 | mov.u32 %stmp1, %tid.x; 129 | mov.u32 %stmp0, %ntid.y; 130 | setp.lt.u32 %p0, %stmp1, %stmp0; // only reduce when tid.x < ntid.y 131 | shl.b32 %stmp1, %stmp1, 2; 132 | mov.u32 %stmp0, results; 133 | add.u32 %stmp0, %stmp0, %stmp1; 134 | @%p0 ld.shared.f32 %curMax, [%stmp0]; 135 | 136 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 137 | max.f32 %curMax, %curMax, %ftmp; 138 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 139 | max.f32 %curMax, %curMax, %ftmp; 140 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 141 | max.f32 %curMax, %curMax, %ftmp; 142 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 143 | max.f32 %curMax, %curMax, %ftmp; 144 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 145 | max.f32 %curMax, %curMax, %ftmp; 146 | 147 | setp.eq.u32 %p0, %stmp1, 0; 148 | @%p0 st.global.f32 [%ptrOut], %curMax; 149 | 150 | end_of_block: 151 | // Synchronize across all warps to make sure the block keeps 152 | // the SM busy and unable to schedule anything other blocks. 153 | bar.sync 1; 154 | 155 | ret; 156 | } 157 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_flexible_novec.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Similar to reduction_all_max_naive_opt_flexible, but without vector 6 | // loads and a simpler striding structure. 7 | 8 | .visible .entry reductionAllMaxNaiveOptFlexibleNovec ( 9 | .param .u64 ptrIn, 10 | .param .u64 ptrOut, 11 | .param .u64 numBlocks 12 | ) { 13 | .reg .pred %p0; 14 | 15 | // Arguments 16 | .reg .u64 %ptrIn; 17 | .reg .u64 %ptrOut; 18 | .reg .u64 %numBlocks; 19 | 20 | .reg .u64 %i; 21 | .reg .u64 %tmp<2>; 22 | .reg .u32 %stmp<2>; 23 | .reg .u64 %blockSize; 24 | .reg .f32 %curMax; 25 | .reg .f32 %ftmp; 26 | .reg .f32 %loaded<8>; 27 | 28 | .shared .align 4 .f32 results[32]; 29 | 30 | // Load arguments. 31 | ld.param.u64 %ptrIn, [ptrIn]; 32 | ld.param.u64 %ptrOut, [ptrOut]; 33 | ld.param.u64 %numBlocks, [numBlocks]; 34 | 35 | // We might not do any work from certain threads of this block, 36 | // for experimentation purposes. 37 | // In particular, we do work from tid.z == 0. 38 | mov.u32 %stmp0, %tid.z; 39 | setp.eq.u32 %p0, %stmp0, 0; 40 | @!%p0 bra end_of_block; 41 | 42 | // blockSize = ntid.x * ntid.y (ignore ntid.z) 43 | mov.u32 %stmp0, %ntid.x; 44 | mov.u32 %stmp1, %ntid.y; 45 | mul.wide.u32 %blockSize, %stmp0, %stmp1; 46 | 47 | // Input is offset ctaid.x*4*blockSize*numBlocks, output offset by 4*ctaid.x 48 | cvt.u64.u32 %tmp0, %ctaid.x; 49 | shl.b64 %tmp0, %tmp0, 2; 50 | add.u64 %ptrOut, %ptrOut, %tmp0; 51 | mul.lo.u64 %tmp0, %tmp0, %blockSize; 52 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 53 | add.u64 %ptrIn, %ptrIn, %tmp0; 54 | 55 | // Each rank is offset by 4 bytes. 56 | cvt.u64.u32 %tmp0, %tid.x; 57 | cvt.u64.u32 %tmp1, %tid.y; 58 | shl.b64 %tmp1, %tmp1, 5; 59 | add.u64 %tmp0, %tmp0, %tmp1; 60 | shl.b64 %tmp0, %tmp0, 2; 61 | add.u64 %ptrIn, %ptrIn, %tmp0; 62 | 63 | // Base condition: use our output. 64 | ld.global.f32 %curMax, [%ptrIn]; 65 | 66 | // Stride is blockSize*4 bytes. 
67 | shl.b64 %tmp0, %blockSize, 2; 68 | 69 | mov.u64 %i, 0; 70 | loop_start: 71 | ld.global.f32 %loaded0, [%ptrIn]; 72 | add.u64 %ptrIn, %ptrIn, %tmp0; 73 | ld.global.f32 %loaded1, [%ptrIn]; 74 | add.u64 %ptrIn, %ptrIn, %tmp0; 75 | ld.global.f32 %loaded2, [%ptrIn]; 76 | add.u64 %ptrIn, %ptrIn, %tmp0; 77 | ld.global.f32 %loaded3, [%ptrIn]; 78 | add.u64 %ptrIn, %ptrIn, %tmp0; 79 | ld.global.f32 %loaded4, [%ptrIn]; 80 | add.u64 %ptrIn, %ptrIn, %tmp0; 81 | ld.global.f32 %loaded5, [%ptrIn]; 82 | add.u64 %ptrIn, %ptrIn, %tmp0; 83 | ld.global.f32 %loaded6, [%ptrIn]; 84 | add.u64 %ptrIn, %ptrIn, %tmp0; 85 | ld.global.f32 %loaded7, [%ptrIn]; 86 | add.u64 %ptrIn, %ptrIn, %tmp0; 87 | max.f32 %curMax, %curMax, %loaded0; 88 | max.f32 %curMax, %curMax, %loaded1; 89 | max.f32 %curMax, %curMax, %loaded2; 90 | max.f32 %curMax, %curMax, %loaded3; 91 | max.f32 %curMax, %curMax, %loaded4; 92 | max.f32 %curMax, %curMax, %loaded5; 93 | max.f32 %curMax, %curMax, %loaded6; 94 | max.f32 %curMax, %curMax, %loaded7; 95 | add.u64 %i, %i, 8; 96 | setp.lt.u64 %p0, %i, %numBlocks; 97 | @%p0 bra loop_start; 98 | loop_end: 99 | 100 | // Synchronize on warp using a hypercube. 101 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 102 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 103 | max.f32 %curMax, %curMax, %ftmp; 104 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 105 | max.f32 %curMax, %curMax, %ftmp; 106 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 107 | max.f32 %curMax, %curMax, %ftmp; 108 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 109 | max.f32 %curMax, %curMax, %ftmp; 110 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 111 | max.f32 %curMax, %curMax, %ftmp; 112 | 113 | // Our warp writes to results[tid.y]. 114 | mov.u32 %stmp0, results; 115 | mov.u32 %stmp1, %tid.y; 116 | shl.b32 %stmp1, %stmp1, 2; 117 | add.u32 %stmp0, %stmp0, %stmp1; 118 | // Only write from rank 0 of warp. 119 | mov.u32 %stmp1, %tid.x; 120 | setp.eq.u32 %p0, %stmp1, 0; 121 | @%p0 st.shared.f32 [%stmp0], %curMax; 122 | 123 | // Wait for all threads to write to shmem 124 | cvt.u32.u64 %stmp0, %blockSize; 125 | bar.sync 0, %stmp0; 126 | 127 | // Exit on all but first warp, where we do final reduction. 128 | mov.u32 %stmp1, %tid.y; 129 | setp.eq.u32 %p0, %stmp1, 0; 130 | @!%p0 bra end_of_block; 131 | 132 | // Reduce the shared memory from the first warp. 133 | mov.u32 %stmp1, %tid.x; 134 | mov.u32 %stmp0, %ntid.y; 135 | setp.lt.u32 %p0, %stmp1, %stmp0; // only reduce when tid.x < ntid.y 136 | shl.b32 %stmp1, %stmp1, 2; 137 | mov.u32 %stmp0, results; 138 | add.u32 %stmp0, %stmp0, %stmp1; 139 | @%p0 ld.shared.f32 %curMax, [%stmp0]; 140 | 141 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 142 | max.f32 %curMax, %curMax, %ftmp; 143 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 144 | max.f32 %curMax, %curMax, %ftmp; 145 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 146 | max.f32 %curMax, %curMax, %ftmp; 147 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 148 | max.f32 %curMax, %curMax, %ftmp; 149 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 150 | max.f32 %curMax, %curMax, %ftmp; 151 | 152 | setp.eq.u32 %p0, %stmp1, 0; 153 | @%p0 st.global.f32 [%ptrOut], %curMax; 154 | 155 | end_of_block: 156 | // Synchronize across all warps to make sure the block keeps 157 | // the SM busy and unable to schedule anything other blocks. 
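    // Barrier 1 is used here (rather than barrier 0) so that this final
    // block-wide barrier does not mix with the partial bar.sync 0, %blockSize
    // used by the active warps above.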
158 | bar.sync 1; 159 | 160 | ret; 161 | } 162 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_flexible_sin.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Like reduction_all_max_naive_opt_flexible, but 6 | // applies sin() to every input. 7 | 8 | .visible .entry reductionAllMaxNaiveOptFlexibleSin ( 9 | .param .u64 ptrIn, 10 | .param .u64 ptrOut, 11 | .param .u64 numBlocks 12 | ) { 13 | .reg .pred %p0; 14 | 15 | // Arguments 16 | .reg .u64 %ptrIn; 17 | .reg .u64 %ptrOut; 18 | .reg .u64 %numBlocks; 19 | 20 | .reg .u64 %i; 21 | .reg .u64 %tmp<2>; 22 | .reg .u32 %stmp<2>; 23 | .reg .u64 %blockSize; 24 | .reg .f32 %curMax; 25 | .reg .f32 %ftmp; 26 | .reg .v4 .f32 %ftmpVec<2>; 27 | 28 | .shared .align 4 .f32 results[32]; 29 | 30 | // Load arguments. 31 | ld.param.u64 %ptrIn, [ptrIn]; 32 | ld.param.u64 %ptrOut, [ptrOut]; 33 | ld.param.u64 %numBlocks, [numBlocks]; 34 | 35 | // We might not do any work from certain threads of this block, 36 | // for experimentation purposes. 37 | // In particular, we do work from tid.z == 0. 38 | mov.u32 %stmp0, %tid.z; 39 | setp.eq.u32 %p0, %stmp0, 0; 40 | @!%p0 bra end_of_block; 41 | 42 | // blockSize = ntid.x * ntid.y (ignore ntid.z) 43 | mov.u32 %stmp0, %ntid.x; 44 | mov.u32 %stmp1, %ntid.y; 45 | mul.wide.u32 %blockSize, %stmp0, %stmp1; 46 | 47 | // Input is offset ctaid.x*4*blockSize*numBlocks, output offset by 4*ctaid.x 48 | cvt.u64.u32 %tmp0, %ctaid.x; 49 | shl.b64 %tmp0, %tmp0, 2; 50 | add.u64 %ptrOut, %ptrOut, %tmp0; 51 | mul.lo.u64 %tmp0, %tmp0, %blockSize; 52 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 53 | add.u64 %ptrIn, %ptrIn, %tmp0; 54 | 55 | // Each rank is offset by 16 bytes. 56 | cvt.u64.u32 %tmp0, %tid.x; 57 | cvt.u64.u32 %tmp1, %tid.y; 58 | shl.b64 %tmp1, %tmp1, 5; 59 | add.u64 %tmp0, %tmp0, %tmp1; 60 | shl.b64 %tmp0, %tmp0, 4; 61 | add.u64 %ptrIn, %ptrIn, %tmp0; 62 | 63 | // Base condition: use our output. 64 | ld.global.f32 %curMax, [%ptrIn]; 65 | sin.approx.ftz.f32 %curMax, %curMax; 66 | 67 | // Stride is blockSize*16 bytes. 68 | shl.b64 %tmp0, %blockSize, 4; 69 | 70 | mov.u64 %i, 0; 71 | loop_start: 72 | ld.global.v4.f32 %ftmpVec0, [%ptrIn]; 73 | add.u64 %ptrIn, %ptrIn, %tmp0; 74 | ld.global.v4.f32 %ftmpVec1, [%ptrIn]; 75 | add.u64 %ptrIn, %ptrIn, %tmp0; 76 | sin.approx.ftz.f32 %ftmpVec0.w, %ftmpVec0.w; 77 | sin.approx.ftz.f32 %ftmpVec0.x, %ftmpVec0.x; 78 | sin.approx.ftz.f32 %ftmpVec0.y, %ftmpVec0.y; 79 | sin.approx.ftz.f32 %ftmpVec0.z, %ftmpVec0.z; 80 | sin.approx.ftz.f32 %ftmpVec1.w, %ftmpVec1.w; 81 | sin.approx.ftz.f32 %ftmpVec1.x, %ftmpVec1.x; 82 | sin.approx.ftz.f32 %ftmpVec1.y, %ftmpVec1.y; 83 | sin.approx.ftz.f32 %ftmpVec1.z, %ftmpVec1.z; 84 | max.f32 %curMax, %curMax, %ftmpVec0.w; 85 | max.f32 %curMax, %curMax, %ftmpVec0.x; 86 | max.f32 %curMax, %curMax, %ftmpVec0.y; 87 | max.f32 %curMax, %curMax, %ftmpVec0.z; 88 | max.f32 %curMax, %curMax, %ftmpVec1.w; 89 | max.f32 %curMax, %curMax, %ftmpVec1.x; 90 | max.f32 %curMax, %curMax, %ftmpVec1.y; 91 | max.f32 %curMax, %curMax, %ftmpVec1.z; 92 | add.u64 %i, %i, 8; 93 | setp.lt.u64 %p0, %i, %numBlocks; 94 | @%p0 bra loop_start; 95 | loop_end: 96 | 97 | // Synchronize on warp using a hypercube. 
98 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 99 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 100 | max.f32 %curMax, %curMax, %ftmp; 101 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 102 | max.f32 %curMax, %curMax, %ftmp; 103 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 104 | max.f32 %curMax, %curMax, %ftmp; 105 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 106 | max.f32 %curMax, %curMax, %ftmp; 107 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 108 | max.f32 %curMax, %curMax, %ftmp; 109 | 110 | // Our warp writes to results[tid.y]. 111 | mov.u32 %stmp0, results; 112 | mov.u32 %stmp1, %tid.y; 113 | shl.b32 %stmp1, %stmp1, 2; 114 | add.u32 %stmp0, %stmp0, %stmp1; 115 | // Only write from rank 0 of warp. 116 | mov.u32 %stmp1, %tid.x; 117 | setp.eq.u32 %p0, %stmp1, 0; 118 | @%p0 st.shared.f32 [%stmp0], %curMax; 119 | 120 | // Wait for all threads to write to shmem 121 | cvt.u32.u64 %stmp0, %blockSize; 122 | bar.sync 0, %stmp0; 123 | 124 | // Exit on all but first warp, where we do final reduction. 125 | mov.u32 %stmp1, %tid.y; 126 | setp.eq.u32 %p0, %stmp1, 0; 127 | @!%p0 bra end_of_block; 128 | 129 | // Reduce the shared memory from the first warp. 130 | mov.u32 %stmp1, %tid.x; 131 | mov.u32 %stmp0, %ntid.y; 132 | setp.lt.u32 %p0, %stmp1, %stmp0; // only reduce when tid.x < ntid.y 133 | shl.b32 %stmp1, %stmp1, 2; 134 | mov.u32 %stmp0, results; 135 | add.u32 %stmp0, %stmp0, %stmp1; 136 | @%p0 ld.shared.f32 %curMax, [%stmp0]; 137 | 138 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 139 | max.f32 %curMax, %curMax, %ftmp; 140 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 141 | max.f32 %curMax, %curMax, %ftmp; 142 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 143 | max.f32 %curMax, %curMax, %ftmp; 144 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 145 | max.f32 %curMax, %curMax, %ftmp; 146 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 147 | max.f32 %curMax, %curMax, %ftmp; 148 | 149 | setp.eq.u32 %p0, %stmp1, 0; 150 | @%p0 st.global.f32 [%ptrOut], %curMax; 151 | 152 | end_of_block: 153 | // Synchronize across all warps to make sure the block keeps 154 | // the SM busy and unable to schedule anything other blocks. 155 | bar.sync 1; 156 | 157 | ret; 158 | } 159 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_flexible_sin_cpasync.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_80 // enough for cp.async 3 | .address_size 64 4 | 5 | // Like reduction_all_max_naive_opt_flexible_sin.ptx, but requires A100s 6 | // and uses cp.async. 7 | 8 | .visible .entry reductionAllMaxNaiveOptFlexibleSin ( 9 | .param .u64 ptrIn, 10 | .param .u64 ptrOut, 11 | .param .u64 numBlocks 12 | ) { 13 | .reg .pred %p0; 14 | 15 | // Arguments 16 | .reg .u64 %ptrIn; 17 | .reg .u64 %ptrOut; 18 | .reg .u64 %numBlocks; 19 | 20 | .reg .u64 %i; 21 | .reg .u64 %tmp<2>; 22 | .reg .u32 %stmp<2>; 23 | .reg .u64 %blockSize; 24 | .reg .f32 %curMax; 25 | .reg .f32 %ftmp; 26 | .reg .v4 .f32 %ftmpVec<2>; 27 | .reg .u32 %curCopyBuffer; 28 | .reg .u32 %otherCopyBuffer; 29 | 30 | .shared .align 16 .f32 copyBuffer1[4096]; 31 | .shared .align 16 .f32 copyBuffer2[4096]; 32 | .shared .align 4 .f32 results[32]; 33 | 34 | // Load arguments. 
35 | ld.param.u64 %ptrIn, [ptrIn]; 36 | ld.param.u64 %ptrOut, [ptrOut]; 37 | ld.param.u64 %numBlocks, [numBlocks]; 38 | 39 | // We might not do any work from certain threads of this block, 40 | // for experimentation purposes. 41 | // In particular, we do work from tid.z == 0. 42 | mov.u32 %stmp0, %tid.z; 43 | setp.eq.u32 %p0, %stmp0, 0; 44 | @!%p0 bra end_of_block; 45 | 46 | // blockSize = ntid.x * ntid.y (ignore ntid.z) 47 | mov.u32 %stmp0, %ntid.x; 48 | mov.u32 %stmp1, %ntid.y; 49 | mul.wide.u32 %blockSize, %stmp0, %stmp1; 50 | 51 | // Input is offset ctaid.x*4*blockSize*numBlocks, output offset by 4*ctaid.x 52 | cvt.u64.u32 %tmp0, %ctaid.x; 53 | shl.b64 %tmp0, %tmp0, 2; 54 | add.u64 %ptrOut, %ptrOut, %tmp0; 55 | mul.lo.u64 %tmp0, %tmp0, %blockSize; 56 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 57 | add.u64 %ptrIn, %ptrIn, %tmp0; 58 | 59 | // Each rank is offset by 16 bytes. 60 | cvt.u64.u32 %tmp0, %tid.x; 61 | cvt.u64.u32 %tmp1, %tid.y; 62 | shl.b64 %tmp1, %tmp1, 5; 63 | add.u64 %tmp0, %tmp0, %tmp1; 64 | shl.b64 %tmp0, %tmp0, 4; 65 | add.u64 %ptrIn, %ptrIn, %tmp0; 66 | 67 | // Base condition: use minimum value 68 | mov.f32 %curMax, -1.0; 69 | 70 | // Stride is blockSize*16 bytes. 71 | shl.b64 %tmp0, %blockSize, 4; 72 | 73 | // We copy into slots of 16 bytes in our copy buffers. 74 | mov.u32 %stmp0, %tid.y; 75 | mov.u32 %stmp1, %tid.x; 76 | shl.b32 %stmp0, %stmp0, 5; 77 | add.u32 %stmp0, %stmp0, %stmp1; 78 | shl.b32 %stmp0, %stmp0, 4; 79 | // Copy buffers will alternate back and forth, but each 80 | // thread will always be responsible for the same part 81 | // of each. 82 | mov.u32 %curCopyBuffer, copyBuffer1; 83 | mov.u32 %otherCopyBuffer, copyBuffer2; 84 | add.u32 %curCopyBuffer, %curCopyBuffer, %stmp0; 85 | add.u32 %otherCopyBuffer, %otherCopyBuffer, %stmp0; 86 | 87 | // Initiate first copy. 88 | cp.async.ca.shared.global [%curCopyBuffer], [%ptrIn], 16; 89 | 90 | mov.u64 %i, 0; 91 | loop_start: 92 | // Wait for all copies to be complete. 93 | cp.async.wait_all; 94 | cvt.u32.u64 %stmp0, %blockSize; 95 | bar.sync 0, %stmp0; 96 | 97 | // Swap copy buffers, so we will always be reading 98 | // from %otherCopyBuffer during the reduction. 99 | mov.u32 %stmp0, %curCopyBuffer; 100 | mov.u32 %curCopyBuffer, %otherCopyBuffer; 101 | mov.u32 %otherCopyBuffer, %stmp0; 102 | 103 | add.u64 %i, %i, 4; 104 | setp.lt.u64 %p0, %i, %numBlocks; 105 | @!%p0 bra skip_copy; 106 | 107 | // Copy the next region in the background. 108 | add.u64 %ptrIn, %ptrIn, %tmp0; 109 | cp.async.ca.shared.global [%curCopyBuffer], [%ptrIn], 16; 110 | 111 | skip_copy: 112 | 113 | // For now we ignore bank conflicts, but this is just about 114 | // the worst access pattern. 115 | ld.shared.v4.f32 %ftmpVec0, [%otherCopyBuffer]; 116 | sin.approx.ftz.f32 %ftmpVec0.w, %ftmpVec0.w; 117 | sin.approx.ftz.f32 %ftmpVec0.x, %ftmpVec0.x; 118 | sin.approx.ftz.f32 %ftmpVec0.y, %ftmpVec0.y; 119 | sin.approx.ftz.f32 %ftmpVec0.z, %ftmpVec0.z; 120 | max.f32 %curMax, %curMax, %ftmpVec0.w; 121 | max.f32 %curMax, %curMax, %ftmpVec0.x; 122 | max.f32 %curMax, %curMax, %ftmpVec0.y; 123 | max.f32 %curMax, %curMax, %ftmpVec0.z; 124 | 125 | @%p0 bra loop_start; 126 | loop_end: 127 | 128 | // Synchronize on warp using a hypercube. 
129 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 130 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 131 | max.f32 %curMax, %curMax, %ftmp; 132 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 133 | max.f32 %curMax, %curMax, %ftmp; 134 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 135 | max.f32 %curMax, %curMax, %ftmp; 136 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 137 | max.f32 %curMax, %curMax, %ftmp; 138 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 139 | max.f32 %curMax, %curMax, %ftmp; 140 | 141 | // Our warp writes to results[tid.y]. 142 | mov.u32 %stmp0, results; 143 | mov.u32 %stmp1, %tid.y; 144 | shl.b32 %stmp1, %stmp1, 2; 145 | add.u32 %stmp0, %stmp0, %stmp1; 146 | // Only write from rank 0 of warp. 147 | mov.u32 %stmp1, %tid.x; 148 | setp.eq.u32 %p0, %stmp1, 0; 149 | @%p0 st.shared.f32 [%stmp0], %curMax; 150 | 151 | // Wait for all threads to write to shmem 152 | cvt.u32.u64 %stmp0, %blockSize; 153 | bar.sync 0, %stmp0; 154 | 155 | // Exit on all but first warp, where we do final reduction. 156 | mov.u32 %stmp1, %tid.y; 157 | setp.eq.u32 %p0, %stmp1, 0; 158 | @!%p0 bra end_of_block; 159 | 160 | // Reduce the shared memory from the first warp. 161 | mov.u32 %stmp1, %tid.x; 162 | mov.u32 %stmp0, %ntid.y; 163 | setp.lt.u32 %p0, %stmp1, %stmp0; // only reduce when tid.x < ntid.y 164 | shl.b32 %stmp1, %stmp1, 2; 165 | mov.u32 %stmp0, results; 166 | add.u32 %stmp0, %stmp0, %stmp1; 167 | @%p0 ld.shared.f32 %curMax, [%stmp0]; 168 | 169 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 170 | max.f32 %curMax, %curMax, %ftmp; 171 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 172 | max.f32 %curMax, %curMax, %ftmp; 173 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 174 | max.f32 %curMax, %curMax, %ftmp; 175 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 176 | max.f32 %curMax, %curMax, %ftmp; 177 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 178 | max.f32 %curMax, %curMax, %ftmp; 179 | 180 | setp.eq.u32 %p0, %stmp1, 0; 181 | @%p0 st.global.f32 [%ptrOut], %curMax; 182 | 183 | end_of_block: 184 | // Synchronize across all warps to make sure the block keeps 185 | // the SM busy and unable to schedule anything other blocks. 186 | bar.sync 1; 187 | 188 | ret; 189 | } -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_flexible_widevec.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Similar to reduction_all_max_naive_opt_flexible, but with more 6 | // dense vector loads (sequential pairs of loads). 7 | 8 | .visible .entry reductionAllMaxNaiveOptFlexibleWidevec ( 9 | .param .u64 ptrIn, 10 | .param .u64 ptrOut, 11 | .param .u64 numBlocks 12 | ) { 13 | .reg .pred %p0; 14 | 15 | // Arguments 16 | .reg .u64 %ptrIn; 17 | .reg .u64 %ptrOut; 18 | .reg .u64 %numBlocks; 19 | 20 | .reg .u64 %i; 21 | .reg .u64 %tmp<2>; 22 | .reg .u32 %stmp<2>; 23 | .reg .u64 %blockSize; 24 | .reg .f32 %curMax; 25 | .reg .f32 %ftmp; 26 | .reg .v4 .f32 %ftmpVec<2>; 27 | 28 | .shared .align 4 .f32 results[32]; 29 | 30 | // Load arguments. 31 | ld.param.u64 %ptrIn, [ptrIn]; 32 | ld.param.u64 %ptrOut, [ptrOut]; 33 | ld.param.u64 %numBlocks, [numBlocks]; 34 | 35 | // We might not do any work from certain threads of this block, 36 | // for experimentation purposes. 
37 | // In particular, we do work from tid.z == 0. 38 | mov.u32 %stmp0, %tid.z; 39 | setp.eq.u32 %p0, %stmp0, 0; 40 | @!%p0 bra end_of_block; 41 | 42 | // blockSize = ntid.x * ntid.y (ignore ntid.z) 43 | mov.u32 %stmp0, %ntid.x; 44 | mov.u32 %stmp1, %ntid.y; 45 | mul.wide.u32 %blockSize, %stmp0, %stmp1; 46 | 47 | // Input is offset ctaid.x*4*blockSize*numBlocks, output offset by 4*ctaid.x 48 | cvt.u64.u32 %tmp0, %ctaid.x; 49 | shl.b64 %tmp0, %tmp0, 2; 50 | add.u64 %ptrOut, %ptrOut, %tmp0; 51 | mul.lo.u64 %tmp0, %tmp0, %blockSize; 52 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 53 | add.u64 %ptrIn, %ptrIn, %tmp0; 54 | 55 | // Each rank is offset by 32 bytes. 56 | cvt.u64.u32 %tmp0, %tid.x; 57 | cvt.u64.u32 %tmp1, %tid.y; 58 | shl.b64 %tmp1, %tmp1, 5; 59 | add.u64 %tmp0, %tmp0, %tmp1; 60 | shl.b64 %tmp0, %tmp0, 5; 61 | add.u64 %ptrIn, %ptrIn, %tmp0; 62 | 63 | // Base condition: use our output. 64 | ld.global.f32 %curMax, [%ptrIn]; 65 | 66 | // Stride is blockSize*32 bytes. 67 | shl.b64 %tmp0, %blockSize, 5; 68 | 69 | mov.u64 %i, 0; 70 | loop_start: 71 | ld.global.v4.f32 %ftmpVec0, [%ptrIn]; 72 | ld.global.v4.f32 %ftmpVec1, [%ptrIn+16]; 73 | add.u64 %ptrIn, %ptrIn, %tmp0; 74 | max.f32 %curMax, %curMax, %ftmpVec0.w; 75 | max.f32 %curMax, %curMax, %ftmpVec0.x; 76 | max.f32 %curMax, %curMax, %ftmpVec0.y; 77 | max.f32 %curMax, %curMax, %ftmpVec0.z; 78 | max.f32 %curMax, %curMax, %ftmpVec1.w; 79 | max.f32 %curMax, %curMax, %ftmpVec1.x; 80 | max.f32 %curMax, %curMax, %ftmpVec1.y; 81 | max.f32 %curMax, %curMax, %ftmpVec1.z; 82 | add.u64 %i, %i, 8; 83 | setp.lt.u64 %p0, %i, %numBlocks; 84 | @%p0 bra loop_start; 85 | loop_end: 86 | 87 | // Synchronize on warp using a hypercube. 88 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 89 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 90 | max.f32 %curMax, %curMax, %ftmp; 91 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 92 | max.f32 %curMax, %curMax, %ftmp; 93 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 94 | max.f32 %curMax, %curMax, %ftmp; 95 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 96 | max.f32 %curMax, %curMax, %ftmp; 97 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 98 | max.f32 %curMax, %curMax, %ftmp; 99 | 100 | // Our warp writes to results[tid.y]. 101 | mov.u32 %stmp0, results; 102 | mov.u32 %stmp1, %tid.y; 103 | shl.b32 %stmp1, %stmp1, 2; 104 | add.u32 %stmp0, %stmp0, %stmp1; 105 | // Only write from rank 0 of warp. 106 | mov.u32 %stmp1, %tid.x; 107 | setp.eq.u32 %p0, %stmp1, 0; 108 | @%p0 st.shared.f32 [%stmp0], %curMax; 109 | 110 | // Wait for all threads to write to shmem 111 | cvt.u32.u64 %stmp0, %blockSize; 112 | bar.sync 0, %stmp0; 113 | 114 | // Exit on all but first warp, where we do final reduction. 115 | mov.u32 %stmp1, %tid.y; 116 | setp.eq.u32 %p0, %stmp1, 0; 117 | @!%p0 bra end_of_block; 118 | 119 | // Reduce the shared memory from the first warp. 
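The repeated shfl.sync.bfly / max pairs in these kernels form a butterfly (hypercube) all-reduce: after log2(32) = 5 rounds every lane holds the warp maximum, and the per-block reduction below reuses the same pattern over the per-warp results. A small NumPy model of the pattern (a sketch for reasoning about it, not generated from the PTX):

import numpy as np

def warp_butterfly_max(lane_values: np.ndarray) -> np.ndarray:
    # lane_values holds one float per lane (length 32). Each round pairs
    # lane t with lane t ^ offset, exactly like shfl.sync.bfly, and keeps
    # the larger value; after 5 rounds all 32 lanes agree on the maximum.
    vals = np.asarray(lane_values, dtype=np.float32).copy()
    lanes = np.arange(32)
    for offset in (1, 2, 4, 8, 16):
        vals = np.maximum(vals, vals[lanes ^ offset])
    return vals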
120 | mov.u32 %stmp1, %tid.x; 121 | mov.u32 %stmp0, %ntid.y; 122 | setp.lt.u32 %p0, %stmp1, %stmp0; // only reduce when tid.x < ntid.y 123 | shl.b32 %stmp1, %stmp1, 2; 124 | mov.u32 %stmp0, results; 125 | add.u32 %stmp0, %stmp0, %stmp1; 126 | @%p0 ld.shared.f32 %curMax, [%stmp0]; 127 | 128 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 129 | max.f32 %curMax, %curMax, %ftmp; 130 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 131 | max.f32 %curMax, %curMax, %ftmp; 132 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 133 | max.f32 %curMax, %curMax, %ftmp; 134 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 135 | max.f32 %curMax, %curMax, %ftmp; 136 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 137 | max.f32 %curMax, %curMax, %ftmp; 138 | 139 | setp.eq.u32 %p0, %stmp1, 0; 140 | @%p0 st.global.f32 [%ptrOut], %curMax; 141 | 142 | end_of_block: 143 | // Synchronize across all warps to make sure the block keeps 144 | // the SM busy and unable to schedule anything other blocks. 145 | bar.sync 1; 146 | 147 | ret; 148 | } 149 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_all_max_naive_opt_novec.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionAllMaxNaiveOptNoVec ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .u64 numBlocks 9 | ) { 10 | .reg .pred %p0; 11 | 12 | // Arguments 13 | .reg .u64 %ptrIn; 14 | .reg .u64 %ptrOut; 15 | .reg .u64 %numBlocks; 16 | 17 | .reg .u64 %i; 18 | .reg .f32 %curMax; 19 | .reg .u64 %tmp0; 20 | .reg .u32 %stmp<2>; 21 | .reg .f32 %ftmp; 22 | .reg .f32 %ftmpLoaded<8>; 23 | 24 | .shared .align 4 .f32 results[32]; 25 | 26 | // Load arguments. 27 | ld.param.u64 %ptrIn, [ptrIn]; 28 | ld.param.u64 %ptrOut, [ptrOut]; 29 | ld.param.u64 %numBlocks, [numBlocks]; 30 | 31 | // We support multiple blocks in the case where multiple 32 | // outputs are being written. 33 | // Input is offset 1024*4*numBlocks*ctaid.x, output offset by 4*ctaid.x 34 | cvt.u64.u32 %tmp0, %ctaid.x; 35 | shl.b64 %tmp0, %tmp0, 2; 36 | add.u64 %ptrOut, %ptrOut, %tmp0; 37 | mul.lo.u64 %tmp0, %tmp0, %numBlocks; 38 | shl.b64 %tmp0, %tmp0, 10; 39 | add.u64 %ptrIn, %ptrIn, %tmp0; 40 | 41 | // Each rank is offset by 16 bytes. 42 | cvt.u64.u32 %tmp0, %tid.x; 43 | shl.b64 %tmp0, %tmp0, 4; 44 | add.u64 %ptrIn, %ptrIn, %tmp0; 45 | 46 | // Base condition: use our output. 
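The loop below reads eight scalars per thread per iteration: a 16-byte slot at the thread's base offset, the matching slot 16384 bytes later (1024 threads times 16 bytes), then a 32768-byte jump. A small helper that lists the float indices one thread visits, assuming the 1024-thread block implied by the hard-coded offsets (illustrative only, not part of the repo):

def novec_thread_indices(tid_x: int, num_iters: int):
    # Float indices one thread of the loop below touches: a 4-float slot at
    # the base, the matching slot 4096 floats (16384 bytes) later, then a
    # jump of 8192 floats (32768 bytes) per iteration.
    base = tid_x * 4
    indices = []
    for it in range(num_iters):
        start = base + it * 8192
        indices += list(range(start, start + 4))            # bytes +0..+12
        indices += list(range(start + 4096, start + 4100))  # bytes +16384..+16396
    return indices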
47 | ld.global.f32 %curMax, [%ptrIn]; 48 | 49 | mov.u64 %i, 0; 50 | loop_start: 51 | ld.global.f32 %ftmpLoaded0, [%ptrIn]; 52 | ld.global.f32 %ftmpLoaded1, [%ptrIn+4]; 53 | ld.global.f32 %ftmpLoaded2, [%ptrIn+8]; 54 | ld.global.f32 %ftmpLoaded3, [%ptrIn+12]; 55 | ld.global.f32 %ftmpLoaded4, [%ptrIn+16384]; 56 | ld.global.f32 %ftmpLoaded5, [%ptrIn+16388]; 57 | ld.global.f32 %ftmpLoaded6, [%ptrIn+16392]; 58 | ld.global.f32 %ftmpLoaded7, [%ptrIn+16396]; 59 | add.u64 %ptrIn, %ptrIn, 32768; 60 | max.f32 %curMax, %curMax, %ftmpLoaded0; 61 | max.f32 %curMax, %curMax, %ftmpLoaded1; 62 | max.f32 %curMax, %curMax, %ftmpLoaded2; 63 | max.f32 %curMax, %curMax, %ftmpLoaded3; 64 | max.f32 %curMax, %curMax, %ftmpLoaded4; 65 | max.f32 %curMax, %curMax, %ftmpLoaded5; 66 | max.f32 %curMax, %curMax, %ftmpLoaded6; 67 | max.f32 %curMax, %curMax, %ftmpLoaded7; 68 | add.u64 %i, %i, 8; 69 | setp.lt.u64 %p0, %i, %numBlocks; 70 | @%p0 bra loop_start; 71 | loop_end: 72 | 73 | // Synchronize on warp using a hypercube. 74 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 75 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 76 | max.f32 %curMax, %curMax, %ftmp; 77 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 78 | max.f32 %curMax, %curMax, %ftmp; 79 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 80 | max.f32 %curMax, %curMax, %ftmp; 81 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 82 | max.f32 %curMax, %curMax, %ftmp; 83 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 84 | max.f32 %curMax, %curMax, %ftmp; 85 | 86 | // Our warp writes to results[tid.x//32]. 87 | mov.u32 %stmp0, results; 88 | mov.u32 %stmp1, %tid.x; 89 | shr.b32 %stmp1, %stmp1, 5; 90 | shl.b32 %stmp1, %stmp1, 2; 91 | add.u32 %stmp0, %stmp0, %stmp1; 92 | // Only write from rank 0 of warp. 93 | mov.u32 %stmp1, %tid.x; 94 | and.b32 %stmp1, %stmp1, 31; 95 | setp.eq.u32 %p0, %stmp1, 0; 96 | @%p0 st.shared.f32 [%stmp0], %curMax; 97 | bar.sync 0; 98 | 99 | // Exit on all but first warp. 100 | mov.u32 %stmp1, %tid.x; 101 | and.b32 %stmp1, %stmp1, 992; // 1024 ^ 31 102 | setp.eq.u32 %p0, %stmp1, 0; 103 | @!%p0 ret; 104 | 105 | // Reduce the shared memory from the first warp. 
106 | mov.u32 %stmp0, results; 107 | mov.u32 %stmp1, %tid.x; 108 | and.b32 %stmp1, %stmp1, 31; 109 | shl.b32 %stmp1, %stmp1, 2; 110 | add.u32 %stmp0, %stmp0, %stmp1; 111 | 112 | ld.shared.f32 %curMax, [%stmp0]; 113 | 114 | shfl.sync.bfly.b32 %ftmp, %curMax, 1, 0x1f, 0xffffffff; 115 | max.f32 %curMax, %curMax, %ftmp; 116 | shfl.sync.bfly.b32 %ftmp, %curMax, 2, 0x1f, 0xffffffff; 117 | max.f32 %curMax, %curMax, %ftmp; 118 | shfl.sync.bfly.b32 %ftmp, %curMax, 4, 0x1f, 0xffffffff; 119 | max.f32 %curMax, %curMax, %ftmp; 120 | shfl.sync.bfly.b32 %ftmp, %curMax, 8, 0x1f, 0xffffffff; 121 | max.f32 %curMax, %curMax, %ftmp; 122 | shfl.sync.bfly.b32 %ftmp, %curMax, 16, 0x1f, 0xffffffff; 123 | max.f32 %curMax, %curMax, %ftmp; 124 | 125 | setp.eq.u32 %p0, %stmp1, 0; 126 | @%p0 st.global.f32 [%ptrOut], %curMax; 127 | 128 | ret; 129 | } 130 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_bool_naive.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionBoolNaive ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .f32 threshold, 9 | .param .u64 numColumns 10 | ) { 11 | .reg .pred %p0; 12 | 13 | // Arguments 14 | .reg .u64 %ptrIn; 15 | .reg .u64 %ptrOut; 16 | .reg .f32 %threshold; 17 | .reg .u64 %numColumns; 18 | 19 | .reg .u64 %i; 20 | .reg .u16 %accumulation; 21 | .reg .u64 %tmp0; 22 | .reg .f32 %ftmp; 23 | 24 | // Load arguments. 25 | ld.param.u64 %ptrIn, [ptrIn]; 26 | ld.param.u64 %ptrOut, [ptrOut]; 27 | ld.param.f32 %threshold, [threshold]; 28 | ld.param.u64 %numColumns, [numColumns]; 29 | 30 | cvt.u64.u32 %tmp0, %ctaid.x; 31 | add.u64 %ptrOut, %ptrOut, %tmp0; 32 | shl.b64 %tmp0, %tmp0, 2; 33 | mul.lo.u64 %tmp0, %tmp0, %numColumns; 34 | add.u64 %ptrIn, %ptrIn, %tmp0; 35 | 36 | mov.u64 %i, 0; 37 | mov.u16 %accumulation, 0; 38 | 39 | loop_start: 40 | ld.global.f32 %ftmp, [%ptrIn]; 41 | setp.lt.f32 %p0, %ftmp, %threshold; 42 | @%p0 mov.u16 %accumulation, 1; 43 | add.u64 %ptrIn, %ptrIn, 4; 44 | add.u64 %i, %i, 1; 45 | setp.lt.u64 %p0, %i, %numColumns; 46 | @%p0 bra loop_start; 47 | loop_end: 48 | 49 | st.global.u8 [%ptrOut], %accumulation; 50 | ret; 51 | } 52 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_bool_warp.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionBoolWarp ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .f32 threshold, 9 | .param .u64 numColumns 10 | ) { 11 | .reg .pred %p0; 12 | 13 | // Arguments 14 | .reg .u64 %ptrIn; 15 | .reg .u64 %ptrOut; 16 | .reg .f32 %threshold; 17 | .reg .u64 %numColumns; 18 | 19 | .reg .u64 %i; 20 | .reg .u32 %accumulation; 21 | .reg .u32 %commResult; 22 | .reg .u64 %tmp0; 23 | .reg .f32 %ftmp; 24 | 25 | // Load arguments. 
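reductionBoolWarp below spreads one row across 32 lanes (lane k tests columns k, k+32, k+64, ...) and then ORs the per-lane flags together with the same butterfly shuffle pattern used for the max reductions above. A NumPy sketch of what one warp computes, assuming the row length is a multiple of 32 (illustrative, not driver code):

import numpy as np

def warp_any_below(row: np.ndarray, threshold: float) -> bool:
    # Lane k of the kernel below scans row[k::32]; the per-lane flags are
    # then OR-reduced across the warp with shfl.sync.bfly.
    lane_flags = (row.reshape(-1, 32) < threshold).any(axis=0)
    return bool(lane_flags.any())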
26 | ld.param.u64 %ptrIn, [ptrIn]; 27 | ld.param.u64 %ptrOut, [ptrOut]; 28 | ld.param.f32 %threshold, [threshold]; 29 | ld.param.u64 %numColumns, [numColumns]; 30 | 31 | cvt.u64.u32 %tmp0, %ctaid.x; 32 | add.u64 %ptrOut, %ptrOut, %tmp0; 33 | shl.b64 %tmp0, %tmp0, 2; 34 | mul.lo.u64 %tmp0, %tmp0, %numColumns; 35 | add.u64 %ptrIn, %ptrIn, %tmp0; 36 | cvt.u64.u32 %tmp0, %tid.x; 37 | shl.b64 %tmp0, %tmp0, 2; // offset by 4*tid.x 38 | add.u64 %ptrIn, %ptrIn, %tmp0; 39 | 40 | mov.u64 %i, 0; 41 | mov.u32 %accumulation, 0; 42 | 43 | loop_start: 44 | ld.global.f32 %ftmp, [%ptrIn]; 45 | setp.lt.f32 %p0, %ftmp, %threshold; 46 | @%p0 mov.u32 %accumulation, 1; 47 | add.u64 %ptrIn, %ptrIn, 128; // stride of 32 floats = 128 bytes 48 | add.u64 %i, %i, 32; 49 | setp.lt.u64 %p0, %i, %numColumns; 50 | @%p0 bra loop_start; 51 | loop_end: 52 | 53 | // Synchronize across all ranks using a hypercube. 54 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 55 | shfl.sync.bfly.b32 %commResult, %accumulation, 1, 0x1f, 0xffffffff; 56 | or.b32 %accumulation, %accumulation, %commResult; 57 | shfl.sync.bfly.b32 %commResult, %accumulation, 2, 0x1f, 0xffffffff; 58 | or.b32 %accumulation, %accumulation, %commResult; 59 | shfl.sync.bfly.b32 %commResult, %accumulation, 4, 0x1f, 0xffffffff; 60 | or.b32 %accumulation, %accumulation, %commResult; 61 | shfl.sync.bfly.b32 %commResult, %accumulation, 8, 0x1f, 0xffffffff; 62 | or.b32 %accumulation, %accumulation, %commResult; 63 | shfl.sync.bfly.b32 %commResult, %accumulation, 16, 0x1f, 0xffffffff; 64 | or.b32 %accumulation, %accumulation, %commResult; 65 | 66 | // Only rank 0 will store the results. 67 | mov.u32 %commResult, %tid.x; 68 | setp.eq.u32 %p0, %commResult, 0; 69 | @%p0 st.global.u8 [%ptrOut], %accumulation; 70 | 71 | ret; 72 | } 73 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_bool_warp_vec.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionBoolWarpVec ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .f32 threshold, 9 | .param .u64 numColumns 10 | ) { 11 | .reg .pred %p0; 12 | 13 | // Arguments 14 | .reg .u64 %ptrIn; 15 | .reg .u64 %ptrOut; 16 | .reg .f32 %threshold; 17 | .reg .u64 %numColumns; 18 | 19 | .reg .u64 %i; 20 | .reg .u32 %accumulation; 21 | .reg .u32 %commResult; 22 | .reg .u64 %tmp0; 23 | .reg .v4 .f32 %ftmp; 24 | 25 | // Load arguments. 
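The vectorized variant below gives each lane a 16-byte slot (4 consecutive floats) and strides the whole warp by 512 bytes per iteration, instead of interleaving single floats. A NumPy view of which elements one lane tests, assuming the row length is a multiple of 128 (an illustrative sketch):

import numpy as np

def vec_lane_any_below(row: np.ndarray, threshold: float, lane: int) -> bool:
    # Lane `lane` of the kernel below loads floats [4*lane, 4*lane + 4) and
    # then advances 128 floats (512 bytes) per iteration.
    slots = row.reshape(-1, 32, 4)[:, lane, :]
    return bool((slots < threshold).any())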
26 | ld.param.u64 %ptrIn, [ptrIn]; 27 | ld.param.u64 %ptrOut, [ptrOut]; 28 | ld.param.f32 %threshold, [threshold]; 29 | ld.param.u64 %numColumns, [numColumns]; 30 | 31 | cvt.u64.u32 %tmp0, %ctaid.x; 32 | add.u64 %ptrOut, %ptrOut, %tmp0; 33 | shl.b64 %tmp0, %tmp0, 2; 34 | mul.lo.u64 %tmp0, %tmp0, %numColumns; 35 | add.u64 %ptrIn, %ptrIn, %tmp0; 36 | cvt.u64.u32 %tmp0, %tid.x; 37 | shl.b64 %tmp0, %tmp0, 4; // offset by 16*tid.x 38 | add.u64 %ptrIn, %ptrIn, %tmp0; 39 | 40 | mov.u64 %i, 0; 41 | mov.u32 %accumulation, 0; 42 | 43 | loop_start: 44 | ld.global.v4.f32 %ftmp, [%ptrIn]; 45 | setp.lt.f32 %p0, %ftmp.w, %threshold; 46 | @%p0 mov.u32 %accumulation, 1; 47 | setp.lt.f32 %p0, %ftmp.x, %threshold; 48 | @%p0 mov.u32 %accumulation, 1; 49 | setp.lt.f32 %p0, %ftmp.y, %threshold; 50 | @%p0 mov.u32 %accumulation, 1; 51 | setp.lt.f32 %p0, %ftmp.z, %threshold; 52 | @%p0 mov.u32 %accumulation, 1; 53 | add.u64 %ptrIn, %ptrIn, 512; // stride of 128 floats = 512 bytes 54 | add.u64 %i, %i, 128; 55 | setp.lt.u64 %p0, %i, %numColumns; 56 | @%p0 bra loop_start; 57 | loop_end: 58 | 59 | // Synchronize across all ranks using a hypercube. 60 | // https://en.wikipedia.org/wiki/Hypercube_(communication_pattern) 61 | shfl.sync.bfly.b32 %commResult, %accumulation, 1, 0x1f, 0xffffffff; 62 | or.b32 %accumulation, %accumulation, %commResult; 63 | shfl.sync.bfly.b32 %commResult, %accumulation, 2, 0x1f, 0xffffffff; 64 | or.b32 %accumulation, %accumulation, %commResult; 65 | shfl.sync.bfly.b32 %commResult, %accumulation, 4, 0x1f, 0xffffffff; 66 | or.b32 %accumulation, %accumulation, %commResult; 67 | shfl.sync.bfly.b32 %commResult, %accumulation, 8, 0x1f, 0xffffffff; 68 | or.b32 %accumulation, %accumulation, %commResult; 69 | shfl.sync.bfly.b32 %commResult, %accumulation, 16, 0x1f, 0xffffffff; 70 | or.b32 %accumulation, %accumulation, %commResult; 71 | 72 | // Only rank 0 will store the results. 73 | mov.u32 %commResult, %tid.x; 74 | setp.eq.u32 %p0, %commResult, 0; 75 | @%p0 st.global.u8 [%ptrOut], %accumulation; 76 | 77 | ret; 78 | } 79 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_trans_bool_blocked.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionTransBoolBlocked ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .f32 threshold, 9 | .param .u64 numRows 10 | ) { 11 | .reg .pred %p0; 12 | 13 | // Arguments 14 | .reg .u64 %ptrIn; 15 | .reg .u64 %ptrOut; 16 | .reg .f32 %threshold; 17 | .reg .u64 %numRows; 18 | 19 | .reg .u64 %i; 20 | .reg .u16 %accumulation; 21 | .reg .u64 %tmp<2>; 22 | .reg .f32 %ftmp; 23 | 24 | // Load arguments. 
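The blocked transposed reduction below assigns column tid.x + ctaid.x*ntid.x to each thread and strides down the matrix by ntid.x*nctaid.x floats, so one warp's loads hit 32 adjacent columns of the same row (coalesced). A sketch of the index math, assuming the grid exactly covers the row width (illustrative only):

import numpy as np

def column_any_below(matrix: np.ndarray, threshold: float,
                     tid_x: int, ctaid_x: int, ntid_x: int, nctaid_x: int) -> bool:
    # One thread owns one column and walks down it row by row; neighbouring
    # threads own neighbouring columns.
    col = tid_x + ctaid_x * ntid_x
    stride = ntid_x * nctaid_x          # assumed equal to matrix.shape[1]
    values = matrix.reshape(-1)[col::stride]
    return bool((values < threshold).any())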
25 | ld.param.u64 %ptrIn, [ptrIn]; 26 | ld.param.u64 %ptrOut, [ptrOut]; 27 | ld.param.f32 %threshold, [threshold]; 28 | ld.param.u64 %numRows, [numRows]; 29 | 30 | // Output index is tid.x + ctaid.x*ntid.x 31 | cvt.u64.u32 %tmp0, %ctaid.x; 32 | cvt.u64.u32 %tmp1, %ntid.x; 33 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 34 | cvt.u64.u32 %tmp1, %tid.x; 35 | add.u64 %tmp0, %tmp0, %tmp1; 36 | add.u64 %ptrOut, %ptrOut, %tmp0; 37 | shl.b64 %tmp0, %tmp0, 2; 38 | add.u64 %ptrIn, %ptrIn, %tmp0; 39 | 40 | // Stride is 4*ntid.x*nctaid.x 41 | cvt.u64.u32 %tmp0, %nctaid.x; 42 | cvt.u64.u32 %tmp1, %ntid.x; 43 | mul.lo.u64 %tmp0, %tmp0, %tmp1; 44 | shl.b64 %tmp0, %tmp0, 2; 45 | 46 | mov.u64 %i, 0; 47 | mov.u16 %accumulation, 0; 48 | 49 | loop_start: 50 | ld.global.f32 %ftmp, [%ptrIn]; 51 | setp.lt.f32 %p0, %ftmp, %threshold; 52 | @%p0 mov.u16 %accumulation, 1; 53 | add.u64 %ptrIn, %ptrIn, %tmp0; 54 | add.u64 %i, %i, 1; 55 | setp.lt.u64 %p0, %i, %numRows; 56 | @%p0 bra loop_start; 57 | loop_end: 58 | 59 | st.global.u8 [%ptrOut], %accumulation; 60 | ret; 61 | } 62 | -------------------------------------------------------------------------------- /learn_ptx/kernels/reduction_trans_bool_naive.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry reductionTransBoolNaive ( 6 | .param .u64 ptrIn, 7 | .param .u64 ptrOut, 8 | .param .f32 threshold, 9 | .param .u64 numRows 10 | ) { 11 | .reg .pred %p0; 12 | 13 | // Arguments 14 | .reg .u64 %ptrIn; 15 | .reg .u64 %ptrOut; 16 | .reg .f32 %threshold; 17 | .reg .u64 %numRows; 18 | 19 | .reg .u64 %i; 20 | .reg .u16 %accumulation; 21 | .reg .u64 %tmp0; 22 | .reg .f32 %ftmp; 23 | 24 | // Load arguments. 25 | ld.param.u64 %ptrIn, [ptrIn]; 26 | ld.param.u64 %ptrOut, [ptrOut]; 27 | ld.param.f32 %threshold, [threshold]; 28 | ld.param.u64 %numRows, [numRows]; 29 | 30 | cvt.u64.u32 %tmp0, %ctaid.x; 31 | add.u64 %ptrOut, %ptrOut, %tmp0; 32 | shl.b64 %tmp0, %tmp0, 2; 33 | add.u64 %ptrIn, %ptrIn, %tmp0; 34 | 35 | // Stride is 4*nctaid.x 36 | cvt.u64.u32 %tmp0, %nctaid.x; 37 | shl.b64 %tmp0, %tmp0, 2; 38 | 39 | mov.u64 %i, 0; 40 | mov.u16 %accumulation, 0; 41 | 42 | loop_start: 43 | ld.global.f32 %ftmp, [%ptrIn]; 44 | setp.lt.f32 %p0, %ftmp, %threshold; 45 | @%p0 mov.u16 %accumulation, 1; 46 | add.u64 %ptrIn, %ptrIn, %tmp0; 47 | add.u64 %i, %i, 1; 48 | setp.lt.u64 %p0, %i, %numRows; 49 | @%p0 bra loop_start; 50 | loop_end: 51 | 52 | st.global.u8 [%ptrOut], %accumulation; 53 | ret; 54 | } 55 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_bitonic_block.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry sortBitonicBlock ( 6 | .param .u64 ptr 7 | ) { 8 | .reg .pred %p0; 9 | .reg .pred %upperHalf; 10 | .reg .pred %reverse; 11 | 12 | // Arguments 13 | .reg .u64 %ptr; 14 | 15 | // Cached thread properties 16 | .reg .u32 %tidX; 17 | .reg .u32 %tidY; 18 | 19 | // Other variables. 20 | .reg .u64 %dtmp<2>; 21 | .reg .u32 %stmp<4>; 22 | .reg .u32 %i; 23 | .reg .u32 %j; 24 | .reg .f32 %val<3>; 25 | 26 | .shared .align 4 .f32 sortBuffer[1024]; 27 | 28 | // Load arguments and thread properties. 
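The sorting kernels below all implement pieces of the bitonic sorting network: warp-level passes via shfl.sync.bfly and block-level passes through shared memory. A plain NumPy reference of the full network is handy for checking them; this is a sketch of the textbook algorithm, not code from the repo, and assumes a power-of-two length:

import numpy as np

def bitonic_sort_reference(values: np.ndarray) -> np.ndarray:
    # Stage k merges bitonic runs of length k with compare/exchange strides
    # k/2, k/4, ..., 1; bit k of the index picks the run direction.
    x = np.asarray(values, dtype=np.float32).copy()
    n = len(x)
    idx = np.arange(n)
    k = 2
    while k <= n:
        j = k // 2
        while j >= 1:
            partner = idx ^ j
            ascending = (idx & k) == 0
            lo = np.minimum(x, x[partner])
            hi = np.maximum(x, x[partner])
            keep_lo = (idx < partner) == ascending
            x = np.where(keep_lo, lo, hi)
            j //= 2
        k *= 2
    return x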
29 | ld.param.u64 %ptr, [ptr]; 30 | mov.u32 %tidX, %tid.x; 31 | mov.u32 %tidY, %tid.y; 32 | 33 | cvt.u64.u32 %dtmp0, %ctaid.x; 34 | shl.b64 %dtmp0, %dtmp0, 10; 35 | cvt.u64.u32 %dtmp1, %tidY; 36 | shl.b64 %dtmp1, %dtmp1, 5; 37 | add.u64 %dtmp0, %dtmp0, %dtmp1; // (ctaid.x*1024 + tid.y*32) 38 | cvt.u64.u32 %dtmp1, %tidX; 39 | add.u64 %dtmp0, %dtmp0, %dtmp1; 40 | shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*1024 + tid.y*32 + tid.x) 41 | add.u64 %ptr, %ptr, %dtmp0; 42 | ld.global.f32 %val0, [%ptr]; 43 | 44 | mov.u32 %i, 0; 45 | loop_start: 46 | // Flip the order of every other group to keep the data 47 | // in bitonic order. 48 | shl.b32 %stmp0, 2, %i; 49 | shl.b32 %stmp1, %tidY, 5; 50 | add.u32 %stmp1, %stmp1, %tidX; 51 | and.b32 %stmp0, %stmp1, %stmp0; 52 | setp.ne.u32 %reverse, %stmp0, 0; 53 | 54 | mov.u32 %j, %i; 55 | inner_loop_start: 56 | // Our stride is 2^j; 57 | shl.b32 %stmp0, 1, %j; 58 | 59 | // Check if we are first, and then flip it based on %reverse. 60 | and.b32 %stmp1, %tidX, %stmp0; 61 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 62 | 63 | shfl.sync.bfly.b32 %val1, %val0, %stmp0, 0x1f, 0xffffffff; 64 | // Keep lower or higher value depending on circumstances. 65 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 66 | selp.f32 %val0, %val1, %val0, %p0; 67 | 68 | setp.ne.u32 %p0, %j, 0; 69 | sub.u32 %j, %j, 1; 70 | @%p0 bra inner_loop_start; 71 | inner_loop_end: 72 | 73 | add.u32 %i, %i, 1; 74 | setp.lt.u32 %p0, %i, 5; 75 | @%p0 bra loop_start; 76 | loop_end: 77 | 78 | // Our index in shared will be stmp0 = 4*(tidX + tidY*32) 79 | // We start by storing the values so that each 32-float block 80 | // is sorted. 81 | shl.b32 %stmp0, %tidY, 5; 82 | add.u32 %stmp0, %stmp0, %tidX; 83 | shl.b32 %stmp0, %stmp0, 2; 84 | mov.u32 %stmp1, sortBuffer; 85 | add.u32 %stmp1, %stmp1, %stmp0; 86 | st.shared.f32 [%stmp1], %val0; 87 | 88 | // Skip work in second half of all warps while 89 | // doing cross-shared-memory merging. 90 | setp.gt.u32 %upperHalf, %tidY, 15; 91 | 92 | // Sort from stride 32 to stride 512, by striding 93 | // tid.y and doing the resulting indexing logic. 94 | mov.u32 %i, 0; 95 | block_loop_start: 96 | // Merge across warps, avoiding bank conflicts by reading 97 | // consecutive values of at least 32 floats. 98 | // This logic reads and writes two values from each warp, 99 | // masking activity from the upper half of the block. 100 | mov.u32 %j, %i; 101 | inner_block_loop_start: 102 | // We will store a "virtual" tid.y in stmp2 by effectively 103 | // moving bit %j to the 16 position (most significant digit). 104 | // 105 | // If tid.y % (1<; 23 | .reg .u32 %stmp<4>; 24 | .reg .u32 %i; 25 | .reg .u32 %j; 26 | .reg .f32 %val<3>; 27 | .reg .u32 %rank; 28 | .reg .u32 %rankAnd1; 29 | .reg .u32 %rankAnd2; 30 | .reg .u32 %rankAnd4; 31 | .reg .u32 %rankAnd8; 32 | .reg .u32 %rankAnd16; 33 | .reg .u32 %rankAnd32; 34 | 35 | .shared .align 4 .f32 sortBuffer[1024]; 36 | 37 | // Load arguments and thread properties. 
38 | ld.param.u64 %ptr, [ptr]; 39 | mov.u32 %tidX, %tid.x; 40 | mov.u32 %tidY, %tid.y; 41 | 42 | shl.b32 %rank, %tidY, 5; 43 | add.u32 %rank, %rank, %tidX; 44 | and.b32 %rankAnd1, %rank, 1; 45 | and.b32 %rankAnd2, %rank, 2; 46 | and.b32 %rankAnd4, %rank, 4; 47 | and.b32 %rankAnd8, %rank, 8; 48 | and.b32 %rankAnd16, %rank, 16; 49 | and.b32 %rankAnd32, %rank, 32; 50 | 51 | cvt.u64.u32 %dtmp0, %ctaid.x; 52 | shl.b64 %dtmp0, %dtmp0, 10; 53 | cvt.u64.u32 %dtmp1, %tidY; 54 | shl.b64 %dtmp1, %dtmp1, 5; 55 | add.u64 %dtmp0, %dtmp0, %dtmp1; // (ctaid.x*1024 + tid.y*32) 56 | cvt.u64.u32 %dtmp1, %tidX; 57 | add.u64 %dtmp0, %dtmp0, %dtmp1; 58 | shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*1024 + tid.y*32 + tid.x) 59 | add.u64 %ptr, %ptr, %dtmp0; 60 | ld.global.f32 %val0, [%ptr]; 61 | 62 | // Sort within each warp. 63 | // i=0 64 | setp.ne.u32 %reverse, %rankAnd2, 0; 65 | // j=0 66 | setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse; 67 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 68 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 69 | selp.f32 %val0, %val1, %val0, %p0; 70 | 71 | // i=1 72 | setp.ne.u32 %reverse, %rankAnd4, 0; 73 | // j=1 74 | setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse; 75 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 76 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 77 | selp.f32 %val0, %val1, %val0, %p0; 78 | // j=0 79 | setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse; 80 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 81 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 82 | selp.f32 %val0, %val1, %val0, %p0; 83 | 84 | // i=2 85 | setp.ne.u32 %reverse, %rankAnd8, 0; 86 | // j=2 87 | setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse; 88 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 89 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 90 | selp.f32 %val0, %val1, %val0, %p0; 91 | // j=1 92 | setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse; 93 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 94 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 95 | selp.f32 %val0, %val1, %val0, %p0; 96 | // j=0 97 | setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse; 98 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 99 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 100 | selp.f32 %val0, %val1, %val0, %p0; 101 | 102 | // i=3 103 | setp.ne.u32 %reverse, %rankAnd16, 0; 104 | // j=3 105 | setp.eq.xor.u32 %p0, %rankAnd8, 0, %reverse; 106 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 107 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 108 | selp.f32 %val0, %val1, %val0, %p0; 109 | // j=2 110 | setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse; 111 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 112 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 113 | selp.f32 %val0, %val1, %val0, %p0; 114 | // j=1 115 | setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse; 116 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 117 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 118 | selp.f32 %val0, %val1, %val0, %p0; 119 | // j=0 120 | setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse; 121 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 122 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 123 | selp.f32 %val0, %val1, %val0, %p0; 124 | 125 | // i=4 126 | setp.ne.u32 %reverse, %rankAnd32, 0; 127 | // j=4 128 | setp.eq.xor.u32 %p0, %rankAnd16, 0, %reverse; 129 | shfl.sync.bfly.b32 %val1, %val0, 16, 0x1f, 0xffffffff; 130 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 131 | selp.f32 %val0, %val1, %val0, %p0; 132 | // j=3 133 | setp.eq.xor.u32 %p0, %rankAnd8, 0, %reverse; 134 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 135 | setp.lt.xor.f32 %p0, %val0, 
%val1, %p0; 136 | selp.f32 %val0, %val1, %val0, %p0; 137 | // j=2 138 | setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse; 139 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 140 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 141 | selp.f32 %val0, %val1, %val0, %p0; 142 | // j=1 143 | setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse; 144 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 145 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 146 | selp.f32 %val0, %val1, %val0, %p0; 147 | // j=0 148 | setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse; 149 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 150 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 151 | selp.f32 %val0, %val1, %val0, %p0; 152 | 153 | // Our index in shared will be stmp0 = 4*(tidX + tidY*32) 154 | // We start by storing the values so that each 32-float block 155 | // is sorted (or reversed, to make 64-float bitonic chunks). 156 | shl.b32 %stmp0, %rank, 2; 157 | mov.u32 %stmp1, sortBuffer; 158 | add.u32 %stmp1, %stmp1, %stmp0; 159 | st.shared.f32 [%stmp1], %val0; 160 | 161 | // Sort from stride 32 to stride 512, by striding 162 | // tid.y and doing the resulting indexing logic. 163 | mov.u32 %i, 0; 164 | block_loop_start: 165 | 166 | // We merge pairs from shared memory, so we only use half 167 | // of the block for this inner loop. 168 | // We still want to make sure all writes from all ranks 169 | // are finished writing to shared memory. 170 | bar.sync 0; 171 | setp.gt.u32 %p0, %tidY, 15; 172 | @%p0 bra inner_block_loop_end; 173 | 174 | // Merge across warps, avoiding bank conflicts by reading 175 | // consecutive values of at least 32 floats. 176 | // This logic reads and writes two values from each warp. 177 | mov.u32 %j, %i; 178 | inner_block_loop_start: 179 | // We will store a "virtual" tid.y in stmp2 by effectively 180 | // moving bit %j to the 16 position (most significant digit). 181 | // 182 | // If tid.y % (1<; 48 | cvt.u64.u32 %tidX, %tid.x; 49 | cvt.u64.u32 %ctaidX, %ctaid.x; 50 | cvt.u64.u32 %ntidX, %ntid.x; 51 | mul.lo.u64 %globalIdx, %ctaidX, %ntidX; 52 | add.u64 %globalIdx, %globalIdx, %tidX; 53 | shr.b64 %halfChunkSize, %chunkSize, 1; 54 | div.u64 %chunkIdx, %globalIdx, %halfChunkSize; 55 | rem.u64 %indexInChunk, %globalIdx, %halfChunkSize; 56 | 57 | mul.lo.u64 %tmp0, %chunkIdx, %chunkSize; 58 | add.u64 %tmp1, %tmp0, %indexInChunk; 59 | shl.b64 %tmp2, %tmp1, 2; 60 | add.u64 %ptrLower, %ptr, %tmp2; 61 | 62 | @%crossover sub.u64 %indexInChunk, %halfChunkSize, %indexInChunk; 63 | @%crossover sub.u64 %indexInChunk, %indexInChunk, 1; 64 | add.u64 %tmp3, %tmp0, %halfChunkSize; 65 | add.u64 %tmp4, %tmp3, %indexInChunk; 66 | shl.b64 %tmp5, %tmp4, 2; 67 | add.u64 %ptrUpper, %ptr, %tmp5; 68 | } 69 | 70 | ld.global.f32 %valA, [%ptrLower]; 71 | ld.global.f32 %valB, [%ptrUpper]; 72 | setp.gt.f32 %p0, %valA, %valB; 73 | { 74 | .reg .f32 %tmp; 75 | @%p0 mov.f32 %tmp, %valA; 76 | @%p0 mov.f32 %valA, %valB; 77 | @%p0 mov.f32 %valB, %tmp; 78 | } 79 | st.global.f32 [%ptrLower], %valA; 80 | st.global.f32 [%ptrUpper], %valB; 81 | 82 | ret; 83 | } 84 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_bitonic_global_v2.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Similar to sort_bitonic_global.ptx, but once called with 6 | // chunkSize <= ntid.x*2, will complete the merge in one 7 | // kernel call. 
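Each launch of the global kernels performs one compare/swap pass over every chunk; with the crossover flag set, the upper partner is mirrored so that two already-sorted halves are treated as one bitonic run. An illustrative Python sketch of a single pass, plus the halving loop that V2 below runs on its own once half a chunk fits in one block (assuming the first pass is launched with crossover set, as when merging two ascending runs):

def bitonic_global_step(arr, chunk_size, crossover):
    # One pass: each "thread" owns one index in the lower half of a chunk
    # and compare/swaps it with its partner in the upper half.
    half = chunk_size // 2
    for chunk_start in range(0, len(arr), chunk_size):
        for i in range(half):
            lower = chunk_start + i
            if crossover:
                upper = chunk_start + chunk_size - 1 - i
            else:
                upper = chunk_start + half + i
            if arr[lower] > arr[upper]:
                arr[lower], arr[upper] = arr[upper], arr[lower]

def bitonic_merge(arr, chunk_size):
    # What V2 does once chunk_size/2 <= ntid.x: keep halving the chunk,
    # clearing crossover after the first pass.
    crossover = True
    while chunk_size >= 2:
        bitonic_global_step(arr, chunk_size, crossover)
        crossover = False
        chunk_size //= 2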
8 | 9 | .visible .entry sortBitonicGlobalV2 ( 10 | .param .u64 ptr, 11 | .param .u64 chunkSize, 12 | .param .u32 crossover 13 | ) { 14 | .reg .pred %p0; 15 | 16 | // Arguments 17 | .reg .u64 %ptr; 18 | .reg .u64 %chunkSize; 19 | .reg .pred %crossover; 20 | .reg .pred %completeAtOnce; 21 | 22 | // Computed from thread index. 23 | .reg .u64 %globalIdx; 24 | 25 | // Addresses for two values to swap. 26 | .reg .u64 %ptrLower; 27 | .reg .u64 %ptrUpper; 28 | 29 | // Used for storing values. 30 | .reg .f32 %valA; 31 | .reg .f32 %valB; 32 | 33 | // Load arguments and thread properties. 34 | ld.param.u64 %ptr, [ptr]; 35 | ld.param.u64 %chunkSize, [chunkSize]; 36 | { 37 | .reg .u32 %tmp; 38 | ld.param.u32 %tmp, [crossover]; 39 | setp.ne.u32 %crossover, %tmp, 0; 40 | } 41 | 42 | { 43 | .reg .u64 %ntidX; 44 | .reg .u64 %tidX; 45 | .reg .u64 %ctaidX; 46 | cvt.u64.u32 %ntidX, %ntid.x; 47 | cvt.u64.u32 %tidX, %tid.x; 48 | cvt.u64.u32 %ctaidX, %ctaid.x; 49 | mul.lo.u64 %globalIdx, %ctaidX, %ntidX; 50 | add.u64 %globalIdx, %globalIdx, %tidX; 51 | 52 | .reg .u64 %tmp; 53 | shr.b64 %tmp, %chunkSize, 1; 54 | setp.le.u64 %completeAtOnce, %tmp, %ntidX; 55 | } 56 | 57 | loop_start: 58 | 59 | // Offset in buffer is based on global index and chunkSize. 60 | // Each chunkSize/2 threads handles one chunk, and each thread 61 | // swaps two values. 62 | { 63 | .reg .u64 %halfChunkSize; 64 | .reg .u64 %chunkIdx; 65 | .reg .u64 %indexInChunk; 66 | .reg .u64 %tmp<6>; 67 | shr.b64 %halfChunkSize, %chunkSize, 1; 68 | div.u64 %chunkIdx, %globalIdx, %halfChunkSize; 69 | rem.u64 %indexInChunk, %globalIdx, %halfChunkSize; 70 | 71 | mul.lo.u64 %tmp0, %chunkIdx, %chunkSize; 72 | add.u64 %tmp1, %tmp0, %indexInChunk; 73 | shl.b64 %tmp2, %tmp1, 2; 74 | add.u64 %ptrLower, %ptr, %tmp2; 75 | 76 | @%crossover sub.u64 %indexInChunk, %halfChunkSize, %indexInChunk; 77 | @%crossover sub.u64 %indexInChunk, %indexInChunk, 1; 78 | add.u64 %tmp3, %tmp0, %halfChunkSize; 79 | add.u64 %tmp4, %tmp3, %indexInChunk; 80 | shl.b64 %tmp5, %tmp4, 2; 81 | add.u64 %ptrUpper, %ptr, %tmp5; 82 | } 83 | 84 | ld.global.f32 %valA, [%ptrLower]; 85 | ld.global.f32 %valB, [%ptrUpper]; 86 | setp.gt.f32 %p0, %valA, %valB; 87 | { 88 | .reg .f32 %tmp; 89 | @%p0 mov.f32 %tmp, %valA; 90 | @%p0 mov.f32 %valA, %valB; 91 | @%p0 mov.f32 %valB, %tmp; 92 | } 93 | st.global.f32 [%ptrLower], %valA; 94 | st.global.f32 [%ptrUpper], %valB; 95 | 96 | setp.gt.and.u64 %p0, %chunkSize, 2, %completeAtOnce; 97 | @!%p0 ret; 98 | bar.sync 0; // Wait for chunk to be sorted 99 | membar.cta; // fence not supported yet on older GPUs. 100 | setp.eq.u64 %crossover, %chunkSize, 1; // set crossover=false 101 | shr.b64 %chunkSize, %chunkSize, 1; 102 | bra loop_start; 103 | } 104 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_bitonic_warp.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry sortBitonicWarp ( 6 | .param .u64 ptr 7 | ) { 8 | .reg .pred %p0; 9 | .reg .pred %reverse; 10 | 11 | // Arguments 12 | .reg .u64 %ptr; 13 | 14 | // Cached thread properties 15 | .reg .u32 %tidX; 16 | 17 | // Other variables. 18 | .reg .u64 %dtmp<2>; 19 | .reg .u32 %stmp<2>; 20 | .reg .u32 %i; 21 | .reg .u32 %j; 22 | .reg .f32 %val<2>; 23 | 24 | // Load arguments and thread properties. 
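sortBitonicWarp below sorts 32 floats entirely in registers: lane t repeatedly exchanges with lane t ^ (1 << j) via shfl.sync.bfly and keeps the low or high value depending on its position and on the per-stage reverse flag (tid.x & (2 << i)). A NumPy model of the same network, written as a sketch for reasoning about the predicate logic rather than generated from the PTX:

import numpy as np

def warp_bitonic_sort(lane_values: np.ndarray) -> np.ndarray:
    # One value per lane (length 32); returns the values each lane would
    # hold after the sort, i.e. the input in ascending order.
    x = np.asarray(lane_values, dtype=np.float32).copy()
    lanes = np.arange(32)
    for i in range(5):
        reverse = (lanes & (2 << i)) != 0          # flip run direction
        for j in range(i, -1, -1):
            partner = lanes ^ (1 << j)             # shfl.sync.bfly partner
            is_first = (lanes & (1 << j)) == 0
            keep_low = is_first ^ reverse
            other = x[partner]
            x = np.where(keep_low, np.minimum(x, other), np.maximum(x, other))
    return x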
25 | ld.param.u64 %ptr, [ptr]; 26 | mov.u32 %tidX, %tid.x; 27 | 28 | cvt.u64.u32 %dtmp0, %ctaid.x; 29 | shl.b64 %dtmp0, %dtmp0, 5; 30 | cvt.u64.u32 %dtmp1, %tidX; 31 | add.u64 %dtmp0, %dtmp0, %dtmp1; 32 | shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*32 + tid.x) 33 | add.u64 %ptr, %ptr, %dtmp0; 34 | ld.global.f32 %val0, [%ptr]; 35 | 36 | mov.u32 %i, 0; 37 | loop_start: 38 | // Flip the order of every other group to keep the data 39 | // in bitonic order. 40 | shl.b32 %stmp0, 2, %i; 41 | and.b32 %stmp0, %tidX, %stmp0; 42 | setp.ne.u32 %reverse, %stmp0, 0; 43 | 44 | mov.u32 %j, %i; 45 | inner_loop_start: 46 | // Our stride is 2^j; 47 | shl.b32 %stmp0, 1, %j; 48 | 49 | // Check if we are first, and then flip it based on %reverse. 50 | and.b32 %stmp1, %tidX, %stmp0; 51 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 52 | 53 | shfl.sync.bfly.b32 %val1, %val0, %stmp0, 0x1f, 0xffffffff; 54 | // Keep lower or higher value depending on circumstances. 55 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 56 | selp.f32 %val0, %val1, %val0, %p0; 57 | 58 | setp.ne.u32 %p0, %j, 0; 59 | sub.u32 %j, %j, 1; 60 | @%p0 bra inner_loop_start; 61 | inner_loop_end: 62 | 63 | add.u32 %i, %i, 1; 64 | setp.lt.u32 %p0, %i, 5; 65 | @%p0 bra loop_start; 66 | loop_end: 67 | 68 | st.global.f32 [%ptr], %val0; 69 | ret; 70 | } 71 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_bitonic_warp_v2.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Unrolled version of sort_bitonic_warp.ptx 6 | 7 | .visible .entry sortBitonicWarpV2 ( 8 | .param .u64 ptr 9 | ) { 10 | .reg .pred %p0; 11 | .reg .pred %reverse; 12 | 13 | // Arguments 14 | .reg .u64 %ptr; 15 | 16 | // Cached thread properties 17 | .reg .u32 %tidX; 18 | 19 | // Other variables. 20 | .reg .u64 %dtmp<2>; 21 | .reg .u32 %stmp<2>; 22 | .reg .f32 %val<2>; 23 | 24 | // Load arguments and thread properties. 
25 | ld.param.u64 %ptr, [ptr]; 26 | mov.u32 %tidX, %tid.x; 27 | 28 | cvt.u64.u32 %dtmp0, %ctaid.x; 29 | shl.b64 %dtmp0, %dtmp0, 5; 30 | cvt.u64.u32 %dtmp1, %tidX; 31 | add.u64 %dtmp0, %dtmp0, %dtmp1; 32 | shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*32 + tid.x) 33 | add.u64 %ptr, %ptr, %dtmp0; 34 | ld.global.f32 %val0, [%ptr]; 35 | 36 | /* 37 | 38 | for i in range(5): 39 | print("") 40 | print(f" // {i=}") 41 | print(f" and.b32 %stmp0, %tidX, {2 << i};") 42 | print(f" setp.ne.u32 %reverse, %stmp0, 0;") 43 | 44 | for j in range(i, -1, -1): 45 | print(f" // {j=}") 46 | print(f" and.b32 %stmp1, %tidX, {1 << j};") 47 | print(f" setp.eq.xor.u32 %p0, %stmp1, 0, %reverse;") 48 | print(f" shfl.sync.bfly.b32 %val1, %val0, {1 << j}, 0x1f, 0xffffffff;") 49 | print(f" setp.lt.xor.f32 %p0, %val0, %val1, %p0;") 50 | print(f" selp.f32 %val0, %val1, %val0, %p0;") 51 | 52 | */ 53 | 54 | loop_start: 55 | // i=0 56 | and.b32 %stmp0, %tidX, 2; 57 | setp.ne.u32 %reverse, %stmp0, 0; 58 | // j=0 59 | and.b32 %stmp1, %tidX, 1; 60 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 61 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 62 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 63 | selp.f32 %val0, %val1, %val0, %p0; 64 | 65 | // i=1 66 | and.b32 %stmp0, %tidX, 4; 67 | setp.ne.u32 %reverse, %stmp0, 0; 68 | // j=1 69 | and.b32 %stmp1, %tidX, 2; 70 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 71 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 72 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 73 | selp.f32 %val0, %val1, %val0, %p0; 74 | // j=0 75 | and.b32 %stmp1, %tidX, 1; 76 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 77 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 78 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 79 | selp.f32 %val0, %val1, %val0, %p0; 80 | 81 | // i=2 82 | and.b32 %stmp0, %tidX, 8; 83 | setp.ne.u32 %reverse, %stmp0, 0; 84 | // j=2 85 | and.b32 %stmp1, %tidX, 4; 86 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 87 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 88 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 89 | selp.f32 %val0, %val1, %val0, %p0; 90 | // j=1 91 | and.b32 %stmp1, %tidX, 2; 92 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 93 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 94 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 95 | selp.f32 %val0, %val1, %val0, %p0; 96 | // j=0 97 | and.b32 %stmp1, %tidX, 1; 98 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 99 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 100 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 101 | selp.f32 %val0, %val1, %val0, %p0; 102 | 103 | // i=3 104 | and.b32 %stmp0, %tidX, 16; 105 | setp.ne.u32 %reverse, %stmp0, 0; 106 | // j=3 107 | and.b32 %stmp1, %tidX, 8; 108 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 109 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 110 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 111 | selp.f32 %val0, %val1, %val0, %p0; 112 | // j=2 113 | and.b32 %stmp1, %tidX, 4; 114 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 115 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 116 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 117 | selp.f32 %val0, %val1, %val0, %p0; 118 | // j=1 119 | and.b32 %stmp1, %tidX, 2; 120 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 121 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 122 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 123 | selp.f32 %val0, %val1, %val0, %p0; 124 | // j=0 125 | and.b32 %stmp1, %tidX, 1; 126 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 127 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 128 | setp.lt.xor.f32 %p0, %val0, %val1, 
%p0; 129 | selp.f32 %val0, %val1, %val0, %p0; 130 | 131 | // i=4 132 | and.b32 %stmp0, %tidX, 32; 133 | setp.ne.u32 %reverse, %stmp0, 0; 134 | // j=4 135 | and.b32 %stmp1, %tidX, 16; 136 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 137 | shfl.sync.bfly.b32 %val1, %val0, 16, 0x1f, 0xffffffff; 138 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 139 | selp.f32 %val0, %val1, %val0, %p0; 140 | // j=3 141 | and.b32 %stmp1, %tidX, 8; 142 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 143 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 144 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 145 | selp.f32 %val0, %val1, %val0, %p0; 146 | // j=2 147 | and.b32 %stmp1, %tidX, 4; 148 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 149 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 150 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 151 | selp.f32 %val0, %val1, %val0, %p0; 152 | // j=1 153 | and.b32 %stmp1, %tidX, 2; 154 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 155 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 156 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 157 | selp.f32 %val0, %val1, %val0, %p0; 158 | // j=0 159 | and.b32 %stmp1, %tidX, 1; 160 | setp.eq.xor.u32 %p0, %stmp1, 0, %reverse; 161 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 162 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 163 | selp.f32 %val0, %val1, %val0, %p0; 164 | loop_end: 165 | 166 | st.global.f32 [%ptr], %val0; 167 | ret; 168 | } 169 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_bitonic_warp_v3.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | // Similar to sort_bitonic_warp_v2.py but with pre-computed 6 | // cached values of tidX & (1 << i), a larger block size, and 7 | // 8 | 9 | .visible .entry sortBitonicWarpV3 ( 10 | .param .u64 ptr, 11 | .param .u64 count, 12 | .param .u64 stride 13 | ) { 14 | .reg .pred %p0; 15 | .reg .pred %reverse; 16 | 17 | // Arguments 18 | .reg .u64 %ptr; 19 | .reg .u64 %stride; 20 | .reg .u64 %count; 21 | 22 | // Cached thread properties 23 | .reg .u32 %tidX; 24 | .reg .u32 %tidY; 25 | 26 | // Other variables. 27 | .reg .u64 %dtmp<2>; 28 | .reg .u32 %tidXAnd1; 29 | .reg .u32 %tidXAnd2; 30 | .reg .u32 %tidXAnd4; 31 | .reg .u32 %tidXAnd8; 32 | .reg .u32 %tidXAnd16; 33 | .reg .u32 %tidXAnd32; 34 | .reg .f32 %val<2>; 35 | 36 | // Load arguments and thread properties. 
37 | ld.param.u64 %ptr, [ptr]; 38 | ld.param.u64 %stride, [stride]; 39 | ld.param.u64 %count, [count]; 40 | mov.u32 %tidX, %tid.x; 41 | mov.u32 %tidY, %tid.y; 42 | 43 | cvt.u64.u32 %dtmp0, %ctaid.x; 44 | shl.b64 %dtmp0, %dtmp0, 8; 45 | cvt.u64.u32 %dtmp1, %tidY; 46 | shl.b64 %dtmp1, %dtmp1, 5; 47 | add.u64 %dtmp0, %dtmp0, %dtmp1; 48 | cvt.u64.u32 %dtmp1, %tidX; 49 | add.u64 %dtmp0, %dtmp0, %dtmp1; 50 | shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*256 + tid.y*32 + tid.x) 51 | add.u64 %ptr, %ptr, %dtmp0; 52 | 53 | and.b32 %tidXAnd1, %tidX, 1; 54 | and.b32 %tidXAnd2, %tidX, 2; 55 | and.b32 %tidXAnd4, %tidX, 4; 56 | and.b32 %tidXAnd8, %tidX, 8; 57 | and.b32 %tidXAnd16, %tidX, 16; 58 | mov.u32 %tidXAnd32, 0; 59 | 60 | outer_loop_start: 61 | ld.global.f32 %val0, [%ptr]; 62 | 63 | /* 64 | 65 | for i in range(5): 66 | print("") 67 | print(f" // {i=}") 68 | print(f" setp.ne.u32 %reverse, %tidXAnd{2 << i}, 0;") 69 | 70 | for j in range(i, -1, -1): 71 | print(f" // {j=}") 72 | print(f" setp.eq.xor.u32 %p0, %tidXAnd{1 << j}, 0, %reverse;") 73 | print(f" shfl.sync.bfly.b32 %val1, %val0, {1 << j}, 0x1f, 0xffffffff;") 74 | print(f" setp.lt.xor.f32 %p0, %val0, %val1, %p0;") 75 | print(f" selp.f32 %val0, %val1, %val0, %p0;") 76 | 77 | */ 78 | 79 | loop_start: 80 | 81 | // i=0 82 | setp.ne.u32 %reverse, %tidXAnd2, 0; 83 | // j=0 84 | setp.eq.xor.u32 %p0, %tidXAnd1, 0, %reverse; 85 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 86 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 87 | selp.f32 %val0, %val1, %val0, %p0; 88 | 89 | // i=1 90 | setp.ne.u32 %reverse, %tidXAnd4, 0; 91 | // j=1 92 | setp.eq.xor.u32 %p0, %tidXAnd2, 0, %reverse; 93 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 94 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 95 | selp.f32 %val0, %val1, %val0, %p0; 96 | // j=0 97 | setp.eq.xor.u32 %p0, %tidXAnd1, 0, %reverse; 98 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 99 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 100 | selp.f32 %val0, %val1, %val0, %p0; 101 | 102 | // i=2 103 | setp.ne.u32 %reverse, %tidXAnd8, 0; 104 | // j=2 105 | setp.eq.xor.u32 %p0, %tidXAnd4, 0, %reverse; 106 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 107 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 108 | selp.f32 %val0, %val1, %val0, %p0; 109 | // j=1 110 | setp.eq.xor.u32 %p0, %tidXAnd2, 0, %reverse; 111 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 112 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 113 | selp.f32 %val0, %val1, %val0, %p0; 114 | // j=0 115 | setp.eq.xor.u32 %p0, %tidXAnd1, 0, %reverse; 116 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 117 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 118 | selp.f32 %val0, %val1, %val0, %p0; 119 | 120 | // i=3 121 | setp.ne.u32 %reverse, %tidXAnd16, 0; 122 | // j=3 123 | setp.eq.xor.u32 %p0, %tidXAnd8, 0, %reverse; 124 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 125 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 126 | selp.f32 %val0, %val1, %val0, %p0; 127 | // j=2 128 | setp.eq.xor.u32 %p0, %tidXAnd4, 0, %reverse; 129 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 130 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 131 | selp.f32 %val0, %val1, %val0, %p0; 132 | // j=1 133 | setp.eq.xor.u32 %p0, %tidXAnd2, 0, %reverse; 134 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 135 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 136 | selp.f32 %val0, %val1, %val0, %p0; 137 | // j=0 138 | setp.eq.xor.u32 %p0, %tidXAnd1, 0, %reverse; 139 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 140 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 
141 | selp.f32 %val0, %val1, %val0, %p0; 142 | 143 | // i=4 144 | setp.ne.u32 %reverse, %tidXAnd32, 0; 145 | // j=4 146 | setp.eq.xor.u32 %p0, %tidXAnd16, 0, %reverse; 147 | shfl.sync.bfly.b32 %val1, %val0, 16, 0x1f, 0xffffffff; 148 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 149 | selp.f32 %val0, %val1, %val0, %p0; 150 | // j=3 151 | setp.eq.xor.u32 %p0, %tidXAnd8, 0, %reverse; 152 | shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff; 153 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 154 | selp.f32 %val0, %val1, %val0, %p0; 155 | // j=2 156 | setp.eq.xor.u32 %p0, %tidXAnd4, 0, %reverse; 157 | shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff; 158 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 159 | selp.f32 %val0, %val1, %val0, %p0; 160 | // j=1 161 | setp.eq.xor.u32 %p0, %tidXAnd2, 0, %reverse; 162 | shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff; 163 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 164 | selp.f32 %val0, %val1, %val0, %p0; 165 | // j=0 166 | setp.eq.xor.u32 %p0, %tidXAnd1, 0, %reverse; 167 | shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff; 168 | setp.lt.xor.f32 %p0, %val0, %val1, %p0; 169 | selp.f32 %val0, %val1, %val0, %p0; 170 | loop_end: 171 | 172 | st.global.f32 [%ptr], %val0; 173 | 174 | sub.u64 %count, %count, 1; 175 | add.u64 %ptr, %ptr, %stride; 176 | setp.ne.u64 %p0, %count, 0; 177 | @%p0 bra outer_loop_start; 178 | outer_loop_end: 179 | 180 | ret; 181 | } 182 | -------------------------------------------------------------------------------- /learn_ptx/kernels/sort_merge_global.ptx: -------------------------------------------------------------------------------- 1 | .version 7.0 2 | .target sm_50 // enough for my Titan X 3 | .address_size 64 4 | 5 | .visible .entry sortMergeGlobal ( 6 | .param .u64 ptr, 7 | .param .u64 ptrOut, 8 | .param .u64 chunkSize 9 | ) { 10 | .reg .pred %p<2>; 11 | 12 | // Arguments 13 | .reg .u64 %ptr; 14 | .reg .u64 %ptrOut; 15 | .reg .u64 %chunkSize; 16 | 17 | // Other variables. 18 | .reg .u64 %tidX; 19 | .reg .u64 %ptrA; 20 | .reg .u64 %ptrB; 21 | .reg .u32 %curWarpA; 22 | .reg .u32 %curWarpB; 23 | .reg .u64 %remainingA; 24 | .reg .u64 %remainingB; 25 | .reg .u64 %dtmp<2>; 26 | .reg .u64 %i; 27 | 28 | // Stored values per warp rank. 29 | .reg .f32 %valA; 30 | .reg .f32 %valB; 31 | .reg .f32 %valOut; 32 | 33 | // Shared across warp. 34 | .reg .f32 %curA; 35 | .reg .f32 %curB; 36 | 37 | // Load arguments and thread properties. 38 | ld.param.u64 %ptr, [ptr]; 39 | ld.param.u64 %ptrOut, [ptrOut]; 40 | ld.param.u64 %chunkSize, [chunkSize]; 41 | 42 | cvt.u64.u32 %tidX, %tid.x; 43 | 44 | // Compute offset of chunk in buffer as &ptr[(ctaid.x*ntid.y + tid.y)*chunkSize*2 + tid.x] 45 | cvt.u64.u32 %dtmp0, %ctaid.x; 46 | cvt.u64.u32 %dtmp1, %ntid.y; 47 | mul.lo.u64 %dtmp0, %dtmp0, %dtmp1; 48 | cvt.u64.u32 %dtmp1, %tid.y; 49 | add.u64 %dtmp0, %dtmp0, %dtmp1; 50 | mul.lo.u64 %dtmp0, %dtmp0, %chunkSize; 51 | shl.b64 %dtmp0, %dtmp0, 1; 52 | add.u64 %dtmp0, %dtmp0, %tidX; 53 | shl.b64 %dtmp0, %dtmp0, 2; // float -> byte offset 54 | 55 | add.u64 %ptrA, %ptr, %dtmp0; 56 | shl.b64 %dtmp1, %chunkSize, 2; 57 | add.u64 %ptrB, %ptrA, %dtmp1; 58 | add.u64 %ptrOut, %ptrOut, %dtmp0; 59 | 60 | mov.u64 %remainingA, %chunkSize; 61 | mov.u64 %remainingB, %chunkSize; 62 | mov.u32 %curWarpA, 0; 63 | mov.u32 %curWarpB, 0; 64 | 65 | // Load the first chunk. 
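The merge loop below is an ordinary two-pointer merge of two sorted chunks; the twist is that the warp keeps the next 32 elements of each input in registers (one per lane), broadcasts the current head with shfl.sync.idx, and flushes output 32 floats at a time. The scalar logic it implements, as an illustrative sketch:

def merge_two_sorted(a, b):
    # Reference for the warp-cooperative merge below; ties take the element
    # from the first chunk, matching the kernel's "curA > curB" test.
    out = []
    i = j = 0
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:
            out.append(a[i])
            i += 1
        else:
            out.append(b[j])
            j += 1
    out += a[i:]
    out += b[j:]
    return out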
66 | ld.global.f32 %valA, [%ptrA]; 67 | ld.global.f32 %valB, [%ptrB]; 68 | shfl.sync.idx.b32 %curA, %valA, 0, 0x1f, 0xffffffff; 69 | shfl.sync.idx.b32 %curB, %valB, 0, 0x1f, 0xffffffff; 70 | 71 | mov.u64 %i, 0; 72 | merge_loop: 73 | // Set %p0 if we are storing into our current rank. 74 | and.b64 %dtmp0, %i, 31; 75 | setp.eq.u64 %p0, %dtmp0, %tidX; 76 | 77 | setp.gt.f32 %p1, %curA, %curB; 78 | @%p1 bra selected_B; 79 | selected_A: 80 | @%p0 mov.f32 %valOut, %curA; 81 | sub.u64 %remainingA, %remainingA, 1; 82 | add.u32 %curWarpA, %curWarpA, 1; 83 | setp.lt.u32 %p0, %curWarpA, 32; 84 | @%p0 bra reload_A_done; 85 | reload_A: 86 | setp.eq.u64 %p0, %remainingA, 0; 87 | @%p0 bra done_selecting; 88 | mov.u32 %curWarpA, 0; 89 | add.u64 %ptrA, %ptrA, 128; // 32*4 90 | ld.global.f32 %valA, [%ptrA]; 91 | prefetch.global.L1 [%ptrA+128]; 92 | reload_A_done: 93 | shfl.sync.idx.b32 %curA, %valA, %curWarpA, 0x1f, 0xffffffff; 94 | bra done_selecting; 95 | selected_B: 96 | @%p0 mov.f32 %valOut, %curB; 97 | sub.u64 %remainingB, %remainingB, 1; 98 | add.u32 %curWarpB, %curWarpB, 1; 99 | setp.lt.u32 %p0, %curWarpB, 32; 100 | @%p0 bra reload_B_done; 101 | reload_B: 102 | setp.eq.u64 %p0, %remainingB, 0; 103 | @%p0 bra done_selecting; 104 | mov.u32 %curWarpB, 0; 105 | add.u64 %ptrB, %ptrB, 128; // 32*4 106 | ld.global.f32 %valB, [%ptrB]; 107 | prefetch.global.L1 [%ptrB+128]; 108 | reload_B_done: 109 | shfl.sync.idx.b32 %curB, %valB, %curWarpB, 0x1f, 0xffffffff; 110 | done_selecting: 111 | 112 | // Store all values every time the warp fills up. 113 | setp.eq.u64 %p0, %dtmp0, 31; 114 | @!%p0 bra store_warp_done; 115 | store_warp: 116 | st.global.f32 [%ptrOut], %valOut; 117 | add.u64 %ptrOut, %ptrOut, 128; 118 | store_warp_done: 119 | 120 | add.u64 %i, %i, 1; 121 | 122 | // Break out of main loop and enter specific 123 | // copy mode if one of the halves is depleted. 
124 | setp.eq.u64 %p0, %remainingA, 0; 125 | @%p0 bra copy_B_loop; 126 | setp.eq.u64 %p0, %remainingB, 0; 127 | @%p0 bra copy_A_loop; 128 | 129 | bra merge_loop; 130 | 131 | merge_loop_end: 132 | 133 | copy_A_loop: 134 | and.b64 %dtmp0, %i, 31; 135 | setp.eq.u64 %p0, %dtmp0, %tidX; 136 | @%p0 mov.f32 %valOut, %curA; 137 | 138 | setp.eq.u64 %p0, %dtmp0, 31; 139 | @!%p0 bra store_warp_A_done; 140 | store_warp_A: 141 | st.global.f32 [%ptrOut], %valOut; 142 | add.u64 %ptrOut, %ptrOut, 128; // 32*4 143 | store_warp_A_done: 144 | 145 | sub.u64 %remainingA, %remainingA, 1; 146 | add.u32 %curWarpA, %curWarpA, 1; 147 | setp.lt.u32 %p0, %curWarpA, 32; 148 | @%p0 bra reload_A_done_2; 149 | reload_A_2: 150 | setp.eq.u64 %p0, %remainingA, 0; 151 | @%p0 bra skip_shfl_A; 152 | mov.u32 %curWarpA, 0; 153 | add.u64 %ptrA, %ptrA, 128; // 32*4 154 | ld.global.f32 %valA, [%ptrA]; 155 | prefetch.global.L1 [%ptrA+128]; 156 | reload_A_done_2: 157 | shfl.sync.idx.b32 %curA, %valA, %curWarpA, 0x1f, 0xffffffff; 158 | skip_shfl_A: 159 | 160 | setp.eq.u64 %p0, %remainingA, 0; 161 | @%p0 bra copy_loops_end; 162 | add.u64 %i, %i, 1; 163 | bra copy_A_loop; 164 | 165 | copy_B_loop: 166 | and.b64 %dtmp0, %i, 31; 167 | setp.eq.u64 %p0, %dtmp0, %tidX; 168 | @%p0 mov.f32 %valOut, %curB; 169 | 170 | setp.eq.u64 %p0, %dtmp0, 31; 171 | @!%p0 bra store_warp_B_done; 172 | store_warp_B: 173 | st.global.f32 [%ptrOut], %valOut; 174 | add.u64 %ptrOut, %ptrOut, 128; // 32*4 175 | store_warp_B_done: 176 | 177 | sub.u64 %remainingB, %remainingB, 1; 178 | add.u32 %curWarpB, %curWarpB, 1; 179 | setp.lt.u32 %p0, %curWarpB, 32; 180 | @%p0 bra reload_B_done_2; 181 | reload_B_2: 182 | setp.eq.u64 %p0, %remainingB, 0; 183 | @%p0 bra skip_shfl_B; 184 | mov.u32 %curWarpB, 0; 185 | add.u64 %ptrB, %ptrB, 128; // 32*4 186 | ld.global.f32 %valB, [%ptrB]; 187 | prefetch.global.L1 [%ptrB+128]; 188 | reload_B_done_2: 189 | shfl.sync.idx.b32 %curB, %valB, %curWarpB, 0x1f, 0xffffffff; 190 | skip_shfl_B: 191 | 192 | setp.eq.u64 %p0, %remainingB, 0; 193 | @%p0 bra copy_loops_end; 194 | add.u64 %i, %i, 1; 195 | bra copy_B_loop; 196 | 197 | copy_loops_end: 198 | 199 | ret; 200 | } 201 | -------------------------------------------------------------------------------- /learn_ptx/matmul.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import numpy as np 4 | 5 | from .context import compile_function, gpu_to_numpy, measure_time, numpy_to_gpu, sync 6 | 7 | 8 | def matmul_simple_block_v1(): 9 | fn = compile_function("matmul_simple_block_v1.ptx", "blockedMatmul") 10 | evaluate_matmul_fn(fn) 11 | 12 | 13 | def matmul_simple_block_v2(): 14 | fn = compile_function("matmul_simple_block_v2.ptx", "blockedMatmulV2") 15 | evaluate_matmul_fn(fn) 16 | 17 | 18 | def matmul_simple_block_v3(): 19 | fn = compile_function("matmul_simple_block_v3.ptx", "blockedMatmulV3") 20 | evaluate_matmul_fn(fn) 21 | 22 | 23 | def matmul_simple_block_v4(): 24 | fn = compile_function("matmul_simple_block_v4.ptx", "blockedMatmulV4") 25 | evaluate_matmul_fn(fn) 26 | 27 | 28 | def matmul_inner_loop(): 29 | fn = compile_function("matmul_inner_loop.ptx", "simpleMatmul") 30 | evaluate_matmul_fn(fn) 31 | 32 | 33 | def matmul_big_blocks(): 34 | fn = compile_function("matmul_big_blocks.ptx", "bigBlocksMatmul") 35 | 36 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 37 | fn( 38 | A_buf, 39 | B_buf, 40 | out_buf, 41 | np.int32(A.shape[0] // 32), 42 | grid=( 43 | A.shape[0] // 32, 44 | 
A.shape[1] // 32, 45 | 1, 46 | ), 47 | block=(32, 8, 1), 48 | ) 49 | 50 | generic_eval_matmul(call_fn) 51 | 52 | 53 | def matmul_big_blocks_v2(): 54 | fn = compile_function("matmul_big_blocks_v2.ptx", "bigBlocksMatmulV2") 55 | 56 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 57 | fn( 58 | A_buf, 59 | B_buf, 60 | out_buf, 61 | np.int32(A.shape[0] // 32), 62 | grid=( 63 | A.shape[0] // 32, 64 | A.shape[1] // 32, 65 | 1, 66 | ), 67 | block=(32, 8, 1), 68 | ) 69 | 70 | generic_eval_matmul(call_fn) 71 | 72 | 73 | def matmul_big_blocks_v3(): 74 | fn = compile_function("matmul_big_blocks_v3.ptx", "bigBlocksMatmulV3") 75 | 76 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 77 | fn( 78 | A_buf, 79 | B_buf, 80 | out_buf, 81 | np.int32(A.shape[0] // 64), 82 | grid=( 83 | A.shape[0] // 64, 84 | A.shape[1] // 64, 85 | 1, 86 | ), 87 | block=(32, 8, 1), 88 | ) 89 | 90 | generic_eval_matmul(call_fn) 91 | 92 | 93 | def matmul_big_blocks_v4(): 94 | fn = compile_function("matmul_big_blocks_v4.ptx", "bigBlocksMatmulV4") 95 | 96 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 97 | fn( 98 | A_buf, 99 | B_buf, 100 | out_buf, 101 | np.int32(A.shape[0] // 64), 102 | grid=( 103 | A.shape[0] // 64, 104 | A.shape[1] // 64, 105 | 1, 106 | ), 107 | block=(32, 8, 1), 108 | ) 109 | 110 | generic_eval_matmul(call_fn) 111 | 112 | 113 | def matmul_big_blocks_v5(): 114 | fn = compile_function("matmul_big_blocks_v5.ptx", "bigBlocksMatmulV5") 115 | 116 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 117 | fn( 118 | A_buf, 119 | B_buf, 120 | out_buf, 121 | np.int32(A.shape[0] // 64), 122 | grid=( 123 | A.shape[0] // 64, 124 | A.shape[1] // 64, 125 | 1, 126 | ), 127 | block=(32, 8, 1), 128 | ) 129 | 130 | generic_eval_matmul(call_fn) 131 | 132 | 133 | def matmul_wmma_v1(): 134 | fn = compile_function("matmul_wmma_v1.ptx", "wmmaMatmulV1") 135 | 136 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 137 | fn( 138 | A_buf, 139 | B_buf, 140 | out_buf, 141 | np.int32(A.shape[0] // 32), 142 | grid=( 143 | A.shape[0] // 32, 144 | A.shape[1] // 32, 145 | 1, 146 | ), 147 | block=(32, 4, 1), 148 | ) 149 | 150 | generic_eval_matmul(call_fn) 151 | 152 | 153 | def matmul_wmma_v2(): 154 | fn = compile_function("matmul_wmma_v2.ptx", "wmmaMatmulV2") 155 | 156 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 157 | fn( 158 | A_buf, 159 | B_buf, 160 | out_buf, 161 | np.int32(A.shape[0] // 32), 162 | grid=( 163 | A.shape[0] // 32, 164 | A.shape[1] // 32, 165 | 1, 166 | ), 167 | block=(32, 4, 1), 168 | ) 169 | 170 | generic_eval_matmul(call_fn) 171 | 172 | 173 | def matmul_wmma_v3(): 174 | fn = compile_function("matmul_wmma_v3.ptx", "wmmaMatmulV3") 175 | 176 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 177 | fn( 178 | A_buf, 179 | B_buf, 180 | out_buf, 181 | np.int32(A.shape[0] // 32), 182 | grid=( 183 | A.shape[0] // 32, 184 | A.shape[1] // 32, 185 | 1, 186 | ), 187 | block=(32, 4, 1), 188 | ) 189 | 190 | generic_eval_matmul(call_fn) 191 | 192 | 193 | def matmul_wmma_v4(): 194 | fn = compile_function("matmul_wmma_v4.ptx", "wmmaMatmulV4") 195 | 196 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 197 | fn( 198 | A_buf, 199 | B_buf, 200 | out_buf, 201 | np.int32(A.shape[0] // 32), 202 | grid=( 203 | A.shape[0] // 32, 204 | A.shape[1] // 32, 205 | 1, 206 | ), 207 | block=(32, 
4, 1), 208 | ) 209 | 210 | generic_eval_matmul(call_fn) 211 | 212 | 213 | def matmul_wmma_v5(): 214 | fn = compile_function("matmul_wmma_v5.ptx", "wmmaMatmulV5") 215 | 216 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 217 | fn( 218 | A_buf, 219 | B_buf, 220 | out_buf, 221 | np.int32(A.shape[0] // 64), 222 | grid=( 223 | A.shape[0] // 64, 224 | A.shape[1] // 64, 225 | 1, 226 | ), 227 | block=(32, 16, 1), 228 | ) 229 | 230 | generic_eval_matmul(call_fn) 231 | 232 | 233 | def matmul_wmma_v6(): 234 | fn = compile_function("matmul_wmma_v6.ptx", "wmmaMatmulV6") 235 | 236 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 237 | fn( 238 | A_buf, 239 | B_buf, 240 | out_buf, 241 | np.int32(A.shape[0] // 64), 242 | grid=( 243 | A.shape[0] // 64, 244 | A.shape[1] // 64, 245 | 1, 246 | ), 247 | block=(32, 16, 1), 248 | ) 249 | 250 | generic_eval_matmul(call_fn) 251 | 252 | 253 | def matmul_wmma_v7(): 254 | fn = compile_function("matmul_wmma_v7.ptx", "wmmaMatmulV7") 255 | 256 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 257 | fn( 258 | A_buf, 259 | B_buf, 260 | out_buf, 261 | np.int32(A.shape[0] // 64), 262 | grid=( 263 | A.shape[0] // 64, 264 | A.shape[1] // 64, 265 | 1, 266 | ), 267 | block=(32, 16, 1), 268 | ) 269 | 270 | generic_eval_matmul(call_fn) 271 | 272 | 273 | def matmul_wmma_v8(): 274 | fn = compile_function("matmul_wmma_v8.ptx", "wmmaMatmulV8") 275 | 276 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 277 | fn( 278 | A_buf, 279 | B_buf, 280 | out_buf, 281 | np.int32(A.shape[0] // 64), 282 | grid=( 283 | A.shape[0] // 64, 284 | A.shape[1] // 64, 285 | 1, 286 | ), 287 | block=(32, 16, 1), 288 | ) 289 | 290 | generic_eval_matmul(call_fn) 291 | 292 | 293 | def evaluate_matmul_fn(fn: Callable): 294 | def call_fn(A: np.ndarray, B: np.ndarray, A_buf: Any, B_buf: Any, out_buf: Any): 295 | block_size = 32 296 | fn( 297 | A_buf, 298 | B_buf, 299 | out_buf, 300 | np.int32(A.shape[0] // block_size), 301 | grid=( 302 | A.shape[0] // block_size, 303 | A.shape[1] // block_size, 304 | 1, 305 | ), 306 | block=(block_size, block_size, 1), 307 | ) 308 | 309 | generic_eval_matmul(call_fn) 310 | 311 | 312 | def generic_eval_matmul(fn: Callable, block_mult: int = 1): 313 | size = 8192 314 | A = np.random.normal(size=[size, size]).astype(np.float32) 315 | B = np.random.normal(size=[size, size]).astype(np.float32) 316 | A_buf = numpy_to_gpu(A) 317 | B_buf = numpy_to_gpu(B) 318 | out_buf = numpy_to_gpu(A * 0) 319 | with measure_time() as timer: 320 | fn( 321 | A, 322 | B, 323 | A_buf, 324 | B_buf, 325 | out_buf, 326 | ) 327 | sync() 328 | results = gpu_to_numpy(out_buf, A.shape, A.dtype) 329 | expected = A @ B 330 | print(f"maximum absolute error of matmul is {np.abs(results - expected).max()}") 331 | print(f"time elapsed: {timer()}") 332 | 333 | 334 | if __name__ == "__main__": 335 | matmul_wmma_v5() 336 | -------------------------------------------------------------------------------- /learn_ptx/reduction.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import numpy as np 4 | 5 | from .context import compile_function, gpu_to_numpy, measure_time, numpy_to_gpu, sync 6 | 7 | 8 | def reduction_bool_naive(): 9 | fn = compile_function("reduction_bool_naive.ptx", "reductionBoolNaive") 10 | inputs = np.random.normal(size=[16384, 16384]).astype(np.float32) 11 | threshold = np.median(inputs.min(axis=-1)) 12 | 
outputs = np.zeros([inputs.shape[0]], dtype=np.uint8) 13 | input_buf = numpy_to_gpu(inputs) 14 | output_buf = numpy_to_gpu(outputs) 15 | with measure_time() as timer: 16 | fn( 17 | input_buf, 18 | output_buf, 19 | np.float32(threshold), 20 | np.int64(inputs.shape[1]), 21 | grid=(inputs.shape[0], 1, 1), 22 | block=(1, 1, 1), 23 | ) 24 | sync() 25 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 26 | expected = (inputs < threshold).any(axis=-1) 27 | print(f"took {timer()} seconds") 28 | print(f"disagreement frac {np.mean((expected != results).astype(np.float32))}") 29 | print(f"true frac {results.astype(np.float32).mean()}") 30 | 31 | 32 | def reduction_bool_warp(): 33 | fn = compile_function("reduction_bool_warp.ptx", "reductionBoolWarp") 34 | inputs = np.random.normal(size=[16384, 16384]).astype(np.float32) 35 | threshold = np.median(inputs.min(axis=-1)) 36 | outputs = np.zeros([inputs.shape[0]], dtype=np.uint8) 37 | input_buf = numpy_to_gpu(inputs) 38 | output_buf = numpy_to_gpu(outputs) 39 | with measure_time() as timer: 40 | fn( 41 | input_buf, 42 | output_buf, 43 | np.float32(threshold), 44 | np.int64(inputs.shape[1]), 45 | grid=(inputs.shape[0], 1, 1), 46 | block=(32, 1, 1), 47 | ) 48 | sync() 49 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 50 | expected = (inputs < threshold).any(axis=-1) 51 | print(f"took {timer()} seconds") 52 | print(f"disagreement frac {np.mean((expected != results).astype(np.float32))}") 53 | print(f"true frac {results.astype(np.float32).mean()}") 54 | 55 | 56 | def reduction_bool_warp_vec(): 57 | fn = compile_function("reduction_bool_warp_vec.ptx", "reductionBoolWarpVec") 58 | inputs = np.random.normal(size=[16384, 16384]).astype(np.float32) 59 | threshold = np.median(inputs.min(axis=-1)) 60 | outputs = np.zeros([inputs.shape[0]], dtype=np.uint8) 61 | input_buf = numpy_to_gpu(inputs) 62 | output_buf = numpy_to_gpu(outputs) 63 | with measure_time() as timer: 64 | fn( 65 | input_buf, 66 | output_buf, 67 | np.float32(threshold), 68 | np.int64(inputs.shape[1]), 69 | grid=(inputs.shape[0], 1, 1), 70 | block=(32, 1, 1), 71 | ) 72 | sync() 73 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 74 | expected = (inputs < threshold).any(axis=-1) 75 | print(f"took {timer()} seconds") 76 | print(f"disagreement frac {np.mean((expected != results).astype(np.float32))}") 77 | print(f"true frac {results.astype(np.float32).mean()}") 78 | 79 | 80 | def reduction_trans_bool_naive(): 81 | fn = compile_function("reduction_trans_bool_naive.ptx", "reductionTransBoolNaive") 82 | inputs = np.random.uniform(size=[16384, 16384]).astype(np.float32) 83 | threshold = np.median(inputs.min(axis=-1)) 84 | outputs = np.zeros([inputs.shape[1]], dtype=np.uint8) 85 | input_buf = numpy_to_gpu(inputs) 86 | output_buf = numpy_to_gpu(outputs) 87 | with measure_time() as timer: 88 | fn( 89 | input_buf, 90 | output_buf, 91 | np.float32(threshold), 92 | np.int64(inputs.shape[0]), 93 | grid=(inputs.shape[1], 1, 1), 94 | block=(1, 1, 1), 95 | ) 96 | sync() 97 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 98 | expected = (inputs < threshold).any(axis=0) 99 | print(f"took {timer()} seconds") 100 | print(f"disagreement frac {np.mean((expected != results).astype(np.float32))}") 101 | print(f"true frac {results.astype(np.float32).mean()}") 102 | 103 | 104 | def reduction_trans_bool_blocked(): 105 | fn = compile_function( 106 | "reduction_trans_bool_blocked.ptx", "reductionTransBoolBlocked" 107 | ) 108 | inputs = 
np.random.uniform(size=[16384, 16384]).astype(np.float32) 109 | threshold = np.median(inputs.min(axis=-1)) 110 | outputs = np.zeros([inputs.shape[1]], dtype=np.uint8) 111 | input_buf = numpy_to_gpu(inputs) 112 | output_buf = numpy_to_gpu(outputs) 113 | with measure_time() as timer: 114 | fn( 115 | input_buf, 116 | output_buf, 117 | np.float32(threshold), 118 | np.int64(inputs.shape[0]), 119 | grid=(inputs.shape[1] // 256, 1, 1), 120 | block=(256, 1, 1), 121 | ) 122 | sync() 123 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 124 | expected = (inputs < threshold).any(axis=0) 125 | print(f"took {timer()} seconds") 126 | print(f"disagreement frac {np.mean((expected != results).astype(np.float32))}") 127 | print(f"true frac {results.astype(np.float32).mean()}") 128 | 129 | 130 | def reduction_all_max_naive(): 131 | fn = compile_function("reduction_all_max_naive.ptx", "reductionAllMaxNaive") 132 | inputs = np.random.uniform(size=[16384**2]).astype(np.float32) 133 | outputs = np.zeros([1], dtype=np.float32) 134 | input_buf = numpy_to_gpu(inputs) 135 | output_buf = numpy_to_gpu(outputs) 136 | with measure_time() as timer: 137 | fn( 138 | input_buf, 139 | output_buf, 140 | np.int64(len(inputs) // 1024), 141 | grid=(1, 1, 1), 142 | block=(1024, 1, 1), 143 | ) 144 | sync() 145 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 146 | expected = np.max(inputs) 147 | print(f"took {timer()} seconds") 148 | assert results[0] == expected, f"{results[0]=} {expected=}" 149 | 150 | 151 | def reduction_all_max_naive_opt(): 152 | fn = compile_function("reduction_all_max_naive_opt.ptx", "reductionAllMaxNaiveOpt") 153 | inputs = np.random.uniform(size=[16384**2]).astype(np.float32) 154 | outputs = np.zeros([1], dtype=np.float32) 155 | input_buf = numpy_to_gpu(inputs) 156 | output_buf = numpy_to_gpu(outputs) 157 | with measure_time() as timer: 158 | fn( 159 | input_buf, 160 | output_buf, 161 | np.int64(len(inputs) // 1024), 162 | grid=(1, 1, 1), 163 | block=(1024, 1, 1), 164 | ) 165 | sync() 166 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 167 | expected = np.max(inputs) 168 | print(f"took {timer()} seconds") 169 | assert results[0] == expected, f"{results[0]=} {expected=}" 170 | 171 | 172 | def reduction_all_max_naive_opt_novec(): 173 | fn = compile_function( 174 | "reduction_all_max_naive_opt_novec.ptx", "reductionAllMaxNaiveOptNoVec" 175 | ) 176 | inputs = np.random.uniform(size=[16384**2]).astype(np.float32) 177 | outputs = np.zeros([1], dtype=np.float32) 178 | input_buf = numpy_to_gpu(inputs) 179 | output_buf = numpy_to_gpu(outputs) 180 | with measure_time() as timer: 181 | fn( 182 | input_buf, 183 | output_buf, 184 | np.int64(len(inputs) // 1024), 185 | grid=(1, 1, 1), 186 | block=(1024, 1, 1), 187 | ) 188 | sync() 189 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 190 | expected = np.max(inputs) 191 | print(f"took {timer()} seconds") 192 | assert results[0] == expected, f"{results[0]=} {expected=}" 193 | 194 | 195 | def reduction_all_max_multistep(): 196 | opt_reduce = compile_function( 197 | "reduction_all_max_naive_opt.ptx", "reductionAllMaxNaiveOpt" 198 | ) 199 | small_reduce = compile_function( 200 | "reduction_all_max_naive.ptx", "reductionAllMaxNaive" 201 | ) 202 | inputs = np.random.uniform(size=[16384**2]).astype(np.float32) 203 | inputs[16384 * 8192 + 1337] = 1.5 # hide a needle in the haystack 204 | outputs = np.zeros([1024], dtype=np.float32) 205 | input_buf = numpy_to_gpu(inputs) 206 | output_buf = 
numpy_to_gpu(outputs) 207 | with measure_time() as timer: 208 | opt_reduce( 209 | input_buf, 210 | output_buf, 211 | np.int64((len(inputs) // 1024) // 1024), 212 | grid=(1024, 1, 1), 213 | block=(1024, 1, 1), 214 | ) 215 | small_reduce( 216 | output_buf, 217 | output_buf, 218 | np.int64(1), 219 | grid=(1, 1, 1), 220 | block=(1024, 1, 1), 221 | ) 222 | sync() 223 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 224 | expected = np.max(inputs) 225 | print(f"took {timer()} seconds") 226 | assert results[0] == expected, f"{results[0]=} {expected=}" 227 | 228 | 229 | def reduction_all_max_flexible_multistep(): 230 | # opt_reduce = compile_function( 231 | # "reduction_all_max_naive_opt_flexible.ptx", "reductionAllMaxNaiveOptFlexible" 232 | # ) 233 | # opt_reduce = compile_function( 234 | # "reduction_all_max_naive_opt_flexible_novec.ptx", 235 | # "reductionAllMaxNaiveOptFlexibleNovec", 236 | # ) 237 | # opt_reduce = compile_function( 238 | # "reduction_all_max_naive_opt_flexible_widevec.ptx", 239 | # "reductionAllMaxNaiveOptFlexibleWidevec", 240 | # ) 241 | opt_reduce = compile_function( 242 | "reduction_all_max_naive_opt_flexible_sin.ptx", 243 | "reductionAllMaxNaiveOptFlexibleSin", 244 | ) 245 | small_reduce = compile_function( 246 | "reduction_all_max_naive.ptx", "reductionAllMaxNaive" 247 | ) 248 | inputs = np.random.uniform(size=[16384**2]).astype(np.float32) 249 | inputs[16384 * 8192 + 1337] = 1.5 # hide a needle in the haystack 250 | expected = np.max(np.sin(inputs)) 251 | outputs = np.zeros([1024], dtype=np.float32) 252 | input_buf = numpy_to_gpu(inputs) 253 | warp_values = [1, 2, 4, 8, 16, 32] 254 | block_values = [1, 2, 4, 8, 16, 32, 64] 255 | output_grid = np.zeros([len(warp_values), len(block_values)]) 256 | for i, n_warps in enumerate(warp_values): 257 | for j, n_blocks in enumerate(block_values): 258 | output_buf = numpy_to_gpu(outputs) 259 | with measure_time() as timer: 260 | opt_reduce( 261 | input_buf, 262 | output_buf, 263 | np.int64((len(inputs) // n_blocks) // (n_warps * 32)), 264 | grid=(n_blocks, 1, 1), 265 | block=(32, n_warps, 1024 // (32 * n_warps)), 266 | ) 267 | # Always reduce 1024 values, even though some of them 268 | # may be zero. 269 | small_reduce( 270 | output_buf, 271 | output_buf, 272 | np.int64(1), 273 | grid=(1, 1, 1), 274 | block=(1024, 1, 1), 275 | ) 276 | sync() 277 | results = gpu_to_numpy(output_buf, outputs.shape, outputs.dtype) 278 | rate = ((int(np.prod(inputs.shape)) * 4) / timer()) / (2**30) 279 | print(f"{n_warps=} {n_blocks=} GiB/s={rate}") 280 | assert np.allclose(results[0], expected), f"{results[0]=} {expected=}" 281 | output_grid[i, j] = rate 282 | rows = [["", *[f"{i} warps" for i in warp_values]]] 283 | for label, row in zip(block_values, output_grid.T): 284 | rows.append([f"{label} SMs", *[f"{x:.02f} GiB/s" for x in row]]) 285 | print("<table>") 286 | for row in rows: 287 | print("<tr>") 288 | for item in row: 289 | print(f"<td>{item}</td>") 290 | print("</tr>") 291 | print("</table>
") 292 | 293 | 294 | if __name__ == "__main__": 295 | reduction_all_max_flexible_multistep() 296 | -------------------------------------------------------------------------------- /learn_ptx/sort.py: -------------------------------------------------------------------------------- 1 | import time 2 | from math import ceil 3 | 4 | import numpy as np 5 | 6 | from .context import compile_function, gpu_to_numpy, measure_time, numpy_to_gpu, sync 7 | 8 | 9 | def sort_bitonic_warp(): 10 | fn = compile_function("sort_bitonic_warp.ptx", "sortBitonicWarp") 11 | inputs = np.random.normal(size=[16384 * 32, 32]).astype(np.float32) 12 | input_buf = numpy_to_gpu(inputs) 13 | with measure_time() as timer: 14 | fn( 15 | input_buf, 16 | grid=(inputs.shape[0], 1, 1), 17 | block=(32, 1, 1), 18 | ) 19 | sync() 20 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 21 | expected = np.sort(inputs, axis=-1) 22 | print(f"took {timer()} seconds") 23 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 24 | 25 | 26 | def sort_bitonic_warp_v2(): 27 | fn = compile_function("sort_bitonic_warp_v2.ptx", "sortBitonicWarpV2") 28 | inputs = np.random.normal(size=[16384 * 32, 32]).astype(np.float32) 29 | input_buf = numpy_to_gpu(inputs) 30 | with measure_time() as timer: 31 | fn( 32 | input_buf, 33 | grid=(inputs.shape[0], 1, 1), 34 | block=(32, 1, 1), 35 | ) 36 | sync() 37 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 38 | expected = np.sort(inputs, axis=-1) 39 | print(f"took {timer()} seconds") 40 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 41 | 42 | 43 | def sort_bitonic_warp_v3(): 44 | fn = compile_function("sort_bitonic_warp_v3.ptx", "sortBitonicWarpV3") 45 | inputs = np.random.normal(size=[16384 * 32, 32]).astype(np.float32) 46 | input_buf = numpy_to_gpu(inputs) 47 | loop_count = 8 # more work per thread 48 | loop_stride = (int(np.prod(inputs.shape)) * 4) // loop_count 49 | with measure_time() as timer: 50 | fn( 51 | input_buf, 52 | np.int64(loop_count), 53 | np.int64(loop_stride), 54 | grid=((inputs.shape[0] // 8) // loop_count, 1, 1), 55 | block=(32, 8, 1), 56 | ) 57 | sync() 58 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 59 | expected = np.sort(inputs, axis=-1) 60 | print(f"took {timer()} seconds") 61 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 62 | 63 | 64 | def sort_bitonic_block(): 65 | fn = compile_function("sort_bitonic_block.ptx", "sortBitonicBlock") 66 | inputs = np.random.normal(size=[16384, 1024]).astype(np.float32) 67 | input_buf = numpy_to_gpu(inputs) 68 | with measure_time() as timer: 69 | fn( 70 | input_buf, 71 | grid=(inputs.shape[0], 1, 1), 72 | block=(32, 32, 1), 73 | ) 74 | sync() 75 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 76 | expected = np.sort(inputs, axis=-1) 77 | print(f"took {timer()} seconds") 78 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 79 | 80 | 81 | def sort_bitonic_block_v2(): 82 | fn = compile_function("sort_bitonic_block_v2.ptx", "sortBitonicBlockV2") 83 | inputs = np.random.normal(size=[16384, 1024]).astype(np.float32) 84 | input_buf = numpy_to_gpu(inputs) 85 | with measure_time() as timer: 86 | fn( 87 | input_buf, 88 | grid=(inputs.shape[0], 1, 1), 89 | block=(32, 32, 1), 90 | ) 91 | sync() 92 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 93 | expected = np.sort(inputs, axis=-1) 94 | print(f"took {timer()} seconds") 95 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 96 | 97 | 98 
| def sort_merge_global(): 99 | warp_fn = compile_function("sort_bitonic_warp_v2.ptx", "sortBitonicWarpV2") 100 | global_fn = compile_function("sort_merge_global.ptx", "sortMergeGlobal") 101 | inputs = np.random.normal(size=[2**28]).astype(np.float32) 102 | tmp = np.zeros_like(inputs) 103 | input_buf = numpy_to_gpu(inputs) 104 | tmp_buf = numpy_to_gpu(tmp) 105 | num_el = int(np.prod(inputs.shape)) 106 | print("sorting on GPU...") 107 | with measure_time() as timer: 108 | # Sort per warp before merging. 109 | warp_fn( 110 | input_buf, 111 | grid=(num_el // 32, 1, 1), 112 | block=(32, 1, 1), 113 | ) 114 | n_sorted = 32 115 | while n_sorted < num_el: 116 | # Maximum of 8 warps per block, to maximize occupancy when possible. 117 | concurrency = min(num_el // (2 * n_sorted), 8) 118 | grid_size = num_el // (2 * n_sorted * concurrency) 119 | global_fn( 120 | input_buf, 121 | tmp_buf, 122 | np.int64(n_sorted), 123 | grid=(grid_size, 1, 1), 124 | block=(32, concurrency, 1), 125 | ) 126 | input_buf, tmp_buf = tmp_buf, input_buf 127 | n_sorted *= 2 128 | sync() 129 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 130 | print("sorting on CPU...") 131 | t1 = time.time() 132 | expected = np.sort(inputs, axis=-1) 133 | t2 = time.time() 134 | print(f"took {timer()} seconds on GPU and {t2 - t1} seconds on CPU") 135 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 136 | 137 | 138 | def sort_bitonic_global(): 139 | warp_fn = compile_function("sort_bitonic_warp_v2.ptx", "sortBitonicWarpV2") 140 | global_fn = compile_function("sort_bitonic_global.ptx", "sortBitonicGlobal") 141 | inputs = np.random.normal(size=[2**28]).astype(np.float32) 142 | tmp = np.zeros_like(inputs) 143 | input_buf = numpy_to_gpu(inputs) 144 | num_el = int(np.prod(inputs.shape)) 145 | print("sorting on GPU...") 146 | with measure_time() as timer: 147 | # Sort per warp before merging. 148 | warp_fn( 149 | input_buf, 150 | grid=(num_el // 32, 1, 1), 151 | block=(32, 1, 1), 152 | ) 153 | block_size = 64 154 | while block_size <= num_el: 155 | sub_block_size = block_size 156 | while sub_block_size > 1: 157 | global_fn( 158 | input_buf, 159 | np.int64(sub_block_size), 160 | np.int32(sub_block_size == block_size), 161 | grid=((len(inputs) // 256) // 2, 1, 1), 162 | block=(256, 1, 1), 163 | ) 164 | sub_block_size //= 2 165 | block_size *= 2 166 | sync() 167 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 168 | print("sorting on CPU...") 169 | t1 = time.time() 170 | expected = np.sort(inputs, axis=-1) 171 | t2 = time.time() 172 | print(f"took {timer()} seconds on GPU and {t2 - t1} seconds on CPU") 173 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 174 | 175 | 176 | def sort_bitonic_global_v2(): 177 | warp_fn = compile_function("sort_bitonic_warp_v2.ptx", "sortBitonicWarpV2") 178 | global_fn = compile_function("sort_bitonic_global_v2.ptx", "sortBitonicGlobalV2") 179 | inputs = np.random.normal(size=[2**28]).astype(np.float32) 180 | input_buf = numpy_to_gpu(inputs) 181 | num_el = int(np.prod(inputs.shape)) 182 | print("sorting on GPU...") 183 | with measure_time() as timer: 184 | # Sort per warp before merging. 
185 | warp_fn( 186 | input_buf, 187 | grid=(num_el // 32, 1, 1), 188 | block=(32, 1, 1), 189 | ) 190 | block_size = 64 191 | while block_size <= num_el: 192 | sub_block_size = block_size 193 | while sub_block_size > 1: 194 | global_fn( 195 | input_buf, 196 | np.int64(sub_block_size), 197 | np.int32(sub_block_size == block_size), 198 | grid=((len(inputs) // 256) // 2, 1, 1), 199 | block=(256, 1, 1), 200 | ) 201 | if sub_block_size <= 512: 202 | # Block should have completed the sort. 203 | break 204 | sub_block_size //= 2 205 | block_size *= 2 206 | sync() 207 | results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype) 208 | print("sorting on CPU...") 209 | t1 = time.time() 210 | expected = np.sort(inputs, axis=-1) 211 | t2 = time.time() 212 | print(f"took {timer()} seconds on GPU and {t2 - t1} seconds on CPU") 213 | assert np.allclose(results, expected), f"\n{results=}\n{expected=}" 214 | 215 | 216 | if __name__ == "__main__": 217 | sort_bitonic_global() 218 | sort_bitonic_global_v2() 219 | --------------------------------------------------------------------------------