├── .gitignore
├── src
    ├── lib.rs
    ├── ast
    │   ├── lex.rs
    │   └── mod.rs
    ├── compiler.rs
    └── vm.rs
├── kernels
    ├── add_simple.cu
    ├── copy.cu
    ├── add.cu
    ├── times_two.cu
    ├── fncall.cu
    ├── gemm.cu
    ├── transpose.cu
    ├── add_simple.ptx
    ├── copy.ptx
    ├── times_two.ptx
    ├── add.ptx
    ├── fncall.ptx
    ├── transpose.ptx
    └── gemm.ptx
├── Makefile
├── Cargo.toml
├── LICENSE-MIT
├── examples
    └── times_two.rs
├── README.md
├── Cargo.lock
├── tests
    └── test.rs
└── LICENSE-APACHE
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | .vscode
3 | flamegraph.svg
4 | perf.data
5 | out.txt
6 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | mod compiler;
2 | mod vm;
3 | mod ast;
4 |
5 | pub use vm::{Context, Argument, LaunchParams};
6 |
--------------------------------------------------------------------------------
/kernels/add_simple.cu:
--------------------------------------------------------------------------------
1 | __global__ void add_simple(float* a, float* b, float* c) {
2 |     size_t i = threadIdx.x;
3 |     c[i] = a[i] + b[i];
4 | }
5 |
--------------------------------------------------------------------------------
/kernels/copy.cu:
--------------------------------------------------------------------------------
1 | __global__ void copy(float* a, float* b, size_t n) {
2 |     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
3 |     if (i < n) {
4 |         b[i] = a[i];
5 |     }
6 | }
7 |
--------------------------------------------------------------------------------
/kernels/add.cu:
--------------------------------------------------------------------------------
1 | __global__ void add(float* a, float* b, float* c, size_t n) {
2 |     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
3 |     if (i < n) {
4 |         c[i] = a[i] + b[i];
5 |     }
6 | }
7 |
--------------------------------------------------------------------------------
/kernels/times_two.cu:
--------------------------------------------------------------------------------
1 | __global__ void times_two(float* a, float* b, size_t n) {
2 |     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
3 |     if (i < n) {
4 |         b[i] = 2 * a[i];
5 |     }
6 | }
7 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | NVCC=nvcc -ptx -arch sm_89
2 |
3 | SOURCES=$(wildcard kernels/*.cu)
4 | PTX=$(SOURCES:.cu=.ptx)
5 |
6 | all: $(PTX)
7 |
8 | %.ptx: %.cu
9 | 	$(NVCC) -o $@ $<
10 |
11 | clean:
12 | 	rm -f $(PTX)
--------------------------------------------------------------------------------
/kernels/fncall.cu:
--------------------------------------------------------------------------------
1 | __device__ __noinline__ float add_op(float a, float b) {
2 |     return a + b;
3 | }
4 |
5 |
6 | __global__ void add(float* a, float* b, float* c, size_t n) {
7 |     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
8 |     if (i < n) {
9 |         c[i] = add_op(a[i], b[i]);
10 |     }
11 | }
12 |
--------------------------------------------------------------------------------
/kernels/gemm.cu:
--------------------------------------------------------------------------------
1 | __global__ void gemm(float* a, float* b, float* c, size_t m, size_t k, size_t n) {
2 |     size_t i = blockIdx.x * blockDim.x + threadIdx.x;
3 |     size_t j = blockIdx.y * blockDim.y + threadIdx.y;
4 |
5 |     if (i < m && j < n) {
6 |         float sum = 0.0f;
7 |         for (size_t l = 0; l < k; 
++l) { 8 | sum += a[i * k + l] * b[l * n + j]; 9 | } 10 | c[i * n + j] = sum; 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ptoxide" 3 | version = "0.1.0" 4 | edition = "2021" 5 | description = "A virtual machine to execute CUDA PTX without a GPU" 6 | license = "MIT OR Apache-2.0" 7 | repository = "https://github.com/gvilums/ptoxide" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | bytemuck = "1.14.0" 13 | logos = "0.13.0" 14 | thiserror = "1.0.50" 15 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 
24 | -------------------------------------------------------------------------------- /kernels/transpose.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define BLOCK_SIZE 32 5 | 6 | __global__ void transpose(float *input, float *output, size_t N) { 7 | 8 | __shared__ float sharedMemory [BLOCK_SIZE] [BLOCK_SIZE]; 9 | 10 | // global index 11 | int i = threadIdx.x + blockIdx.x * blockDim.x; 12 | int j = threadIdx.y + blockIdx.y * blockDim.y; 13 | 14 | // transposed global memory index 15 | int ti = threadIdx.x + blockIdx.y * blockDim.x; 16 | int tj = threadIdx.y + blockIdx.x * blockDim.y; 17 | 18 | // local index 19 | int local_i = threadIdx.x; 20 | int local_j = threadIdx.y; 21 | 22 | if (i < N && j < N) { 23 | // reading from global memory in coalesed manner and performing tanspose in shared memory 24 | int index = j * N + i; 25 | sharedMemory[local_i][local_j] = input[index]; 26 | } else { 27 | sharedMemory[local_i][local_j] = 0.0; 28 | } 29 | 30 | __syncthreads(); 31 | 32 | if (ti < N && tj < N) { 33 | // writing into global memory in coalesed fashion via transposed data in shared memory 34 | int transposedIndex = tj * N + ti; 35 | output[transposedIndex] = sharedMemory[local_j][local_i]; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /kernels/add_simple.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z10add_simplePfS_S_ 14 | 15 | .visible .entry _Z10add_simplePfS_S_( 16 | .param .u64 _Z10add_simplePfS_S__param_0, 17 | .param .u64 _Z10add_simplePfS_S__param_1, 18 | .param .u64 _Z10add_simplePfS_S__param_2 19 | ) 20 | { 21 | .reg .f32 %f<4>; 22 | .reg .b32 %r<2>; 23 | .reg .b64 %rd<11>; 24 | 25 | 26 | ld.param.u64 %rd1, [_Z10add_simplePfS_S__param_0]; 27 | ld.param.u64 %rd2, [_Z10add_simplePfS_S__param_1]; 28 | ld.param.u64 %rd3, [_Z10add_simplePfS_S__param_2]; 29 | cvta.to.global.u64 %rd4, %rd3; 30 | cvta.to.global.u64 %rd5, %rd2; 31 | cvta.to.global.u64 %rd6, %rd1; 32 | mov.u32 %r1, %tid.x; 33 | mul.wide.u32 %rd7, %r1, 4; 34 | add.s64 %rd8, %rd6, %rd7; 35 | ld.global.f32 %f1, [%rd8]; 36 | add.s64 %rd9, %rd5, %rd7; 37 | ld.global.f32 %f2, [%rd9]; 38 | add.f32 %f3, %f1, %f2; 39 | add.s64 %rd10, %rd4, %rd7; 40 | st.global.f32 [%rd10], %f3; 41 | ret; 42 | 43 | } 44 | 45 | -------------------------------------------------------------------------------- /kernels/copy.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z4copyPfS_m 14 | 15 | .visible .entry _Z4copyPfS_m( 16 | .param .u64 _Z4copyPfS_m_param_0, 17 | .param .u64 _Z4copyPfS_m_param_1, 18 | .param .u64 _Z4copyPfS_m_param_2 19 | ) 20 | { 21 | .reg .pred %p<2>; 22 | .reg .f32 %f<2>; 23 | .reg .b32 %r<5>; 24 | .reg .b64 %rd<10>; 25 | 26 | 27 | ld.param.u64 %rd2, [_Z4copyPfS_m_param_0]; 28 | ld.param.u64 %rd3, [_Z4copyPfS_m_param_1]; 29 | ld.param.u64 %rd4, [_Z4copyPfS_m_param_2]; 30 | mov.u32 %r1, %ctaid.x; 31 | mov.u32 %r2, 
%ntid.x; 32 | mov.u32 %r3, %tid.x; 33 | mad.lo.s32 %r4, %r1, %r2, %r3; 34 | cvt.u64.u32 %rd1, %r4; 35 | setp.ge.u64 %p1, %rd1, %rd4; 36 | @%p1 bra $L__BB0_2; 37 | 38 | cvta.to.global.u64 %rd5, %rd2; 39 | shl.b64 %rd6, %rd1, 2; 40 | add.s64 %rd7, %rd5, %rd6; 41 | ld.global.f32 %f1, [%rd7]; 42 | cvta.to.global.u64 %rd8, %rd3; 43 | add.s64 %rd9, %rd8, %rd6; 44 | st.global.f32 [%rd9], %f1; 45 | 46 | $L__BB0_2: 47 | ret; 48 | 49 | } 50 | 51 | -------------------------------------------------------------------------------- /kernels/times_two.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z9times_twoPfS_m 14 | 15 | .visible .entry _Z9times_twoPfS_m( 16 | .param .u64 _Z9times_twoPfS_m_param_0, 17 | .param .u64 _Z9times_twoPfS_m_param_1, 18 | .param .u64 _Z9times_twoPfS_m_param_2 19 | ) 20 | { 21 | .reg .pred %p<2>; 22 | .reg .f32 %f<3>; 23 | .reg .b32 %r<5>; 24 | .reg .b64 %rd<10>; 25 | 26 | 27 | ld.param.u64 %rd2, [_Z9times_twoPfS_m_param_0]; 28 | ld.param.u64 %rd3, [_Z9times_twoPfS_m_param_1]; 29 | ld.param.u64 %rd4, [_Z9times_twoPfS_m_param_2]; 30 | mov.u32 %r1, %ctaid.x; 31 | mov.u32 %r2, %ntid.x; 32 | mov.u32 %r3, %tid.x; 33 | mad.lo.s32 %r4, %r1, %r2, %r3; 34 | cvt.u64.u32 %rd1, %r4; 35 | setp.ge.u64 %p1, %rd1, %rd4; 36 | @%p1 bra $L__BB0_2; 37 | 38 | cvta.to.global.u64 %rd5, %rd2; 39 | shl.b64 %rd6, %rd1, 2; 40 | add.s64 %rd7, %rd5, %rd6; 41 | ld.global.f32 %f1, [%rd7]; 42 | add.f32 %f2, %f1, %f1; 43 | cvta.to.global.u64 %rd8, %rd3; 44 | add.s64 %rd9, %rd8, %rd6; 45 | st.global.f32 [%rd9], %f2; 46 | 47 | $L__BB0_2: 48 | ret; 49 | 50 | } 51 | 52 | -------------------------------------------------------------------------------- /kernels/add.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z3addPfS_S_m 14 | 15 | .visible .entry _Z3addPfS_S_m( 16 | .param .u64 _Z3addPfS_S_m_param_0, 17 | .param .u64 _Z3addPfS_S_m_param_1, 18 | .param .u64 _Z3addPfS_S_m_param_2, 19 | .param .u64 _Z3addPfS_S_m_param_3 20 | ) 21 | { 22 | .reg .pred %p<2>; 23 | .reg .f32 %f<4>; 24 | .reg .b32 %r<5>; 25 | .reg .b64 %rd<13>; 26 | 27 | 28 | ld.param.u64 %rd2, [_Z3addPfS_S_m_param_0]; 29 | ld.param.u64 %rd3, [_Z3addPfS_S_m_param_1]; 30 | ld.param.u64 %rd4, [_Z3addPfS_S_m_param_2]; 31 | ld.param.u64 %rd5, [_Z3addPfS_S_m_param_3]; 32 | mov.u32 %r1, %ctaid.x; 33 | mov.u32 %r2, %ntid.x; 34 | mov.u32 %r3, %tid.x; 35 | mad.lo.s32 %r4, %r1, %r2, %r3; 36 | cvt.u64.u32 %rd1, %r4; 37 | setp.ge.u64 %p1, %rd1, %rd5; 38 | @%p1 bra $L__BB0_2; 39 | 40 | cvta.to.global.u64 %rd6, %rd2; 41 | shl.b64 %rd7, %rd1, 2; 42 | add.s64 %rd8, %rd6, %rd7; 43 | cvta.to.global.u64 %rd9, %rd3; 44 | add.s64 %rd10, %rd9, %rd7; 45 | ld.global.f32 %f1, [%rd10]; 46 | ld.global.f32 %f2, [%rd8]; 47 | add.f32 %f3, %f2, %f1; 48 | cvta.to.global.u64 %rd11, %rd4; 49 | add.s64 %rd12, %rd11, %rd7; 50 | st.global.f32 [%rd12], %f3; 51 | 52 | $L__BB0_2: 53 | ret; 54 | 55 | } 56 | 57 | -------------------------------------------------------------------------------- 
/examples/times_two.rs: -------------------------------------------------------------------------------- 1 | use ptoxide::{Context, Argument, LaunchParams}; 2 | 3 | fn main() { 4 | let a: Vec = vec![1., 2., 3., 4., 5.]; 5 | let mut b: Vec = vec![0.; a.len()]; 6 | 7 | let n = a.len(); 8 | 9 | let mut ctx = Context::new_with_module(KERNEL).expect("compile kernel"); 10 | 11 | const BLOCK_SIZE: u32 = 256; 12 | let grid_size = (n as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE; 13 | 14 | let da = ctx.alloc(n); 15 | let db = ctx.alloc(n); 16 | 17 | ctx.write(da, &a); 18 | ctx.run( 19 | LaunchParams::func_id(0) 20 | .grid1d(grid_size) 21 | .block1d(BLOCK_SIZE), 22 | &[ 23 | Argument::ptr(da), 24 | Argument::ptr(db), 25 | Argument::U64(n as u64), 26 | ], 27 | ).expect("execute kernel"); 28 | 29 | ctx.read(db, &mut b); 30 | // prints [2.0, 4.0, 6.0, 8.0, 10.0] 31 | println!("{:?}", b); 32 | } 33 | 34 | const KERNEL: &'static str = r#" 35 | .version 8.3 36 | .target sm_89 37 | .address_size 64 38 | 39 | .visible .entry times_two( 40 | .param .u64 a, 41 | .param .u64 b, 42 | .param .u64 n 43 | ) 44 | { 45 | .reg .pred %p<2>; 46 | .reg .f32 %f<3>; 47 | .reg .b32 %r<5>; 48 | .reg .b64 %rd<10>; 49 | 50 | 51 | ld.param.u64 %rd2, [a]; 52 | ld.param.u64 %rd3, [b]; 53 | ld.param.u64 %rd4, [n]; 54 | mov.u32 %r1, %ctaid.x; 55 | mov.u32 %r2, %ntid.x; 56 | mov.u32 %r3, %tid.x; 57 | mad.lo.s32 %r4, %r1, %r2, %r3; 58 | cvt.u64.u32 %rd1, %r4; 59 | setp.ge.u64 %p1, %rd1, %rd4; 60 | @%p1 bra $L__BB0_2; 61 | 62 | cvta.to.global.u64 %rd5, %rd2; 63 | shl.b64 %rd6, %rd1, 2; 64 | add.s64 %rd7, %rd5, %rd6; 65 | ld.global.f32 %f1, [%rd7]; 66 | add.f32 %f2, %f1, %f1; 67 | cvta.to.global.u64 %rd8, %rd3; 68 | add.s64 %rd9, %rd8, %rd6; 69 | st.global.f32 [%rd9], %f2; 70 | 71 | $L__BB0_2: 72 | ret; 73 | } 74 | "#; -------------------------------------------------------------------------------- /kernels/fncall.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | 14 | .func (.param .b32 func_retval0) _Z6add_opff( 15 | .param .b32 _Z6add_opff_param_0, 16 | .param .b32 _Z6add_opff_param_1 17 | ) 18 | { 19 | .reg .f32 %f<4>; 20 | 21 | 22 | ld.param.f32 %f1, [_Z6add_opff_param_0]; 23 | ld.param.f32 %f2, [_Z6add_opff_param_1]; 24 | add.f32 %f3, %f1, %f2; 25 | st.param.f32 [func_retval0+0], %f3; 26 | ret; 27 | 28 | } 29 | // .globl _Z3addPfS_S_m 30 | .visible .entry _Z3addPfS_S_m( 31 | .param .u64 _Z3addPfS_S_m_param_0, 32 | .param .u64 _Z3addPfS_S_m_param_1, 33 | .param .u64 _Z3addPfS_S_m_param_2, 34 | .param .u64 _Z3addPfS_S_m_param_3 35 | ) 36 | { 37 | .reg .pred %p<2>; 38 | .reg .f32 %f<4>; 39 | .reg .b32 %r<5>; 40 | .reg .b64 %rd<13>; 41 | 42 | 43 | ld.param.u64 %rd2, [_Z3addPfS_S_m_param_0]; 44 | ld.param.u64 %rd3, [_Z3addPfS_S_m_param_1]; 45 | ld.param.u64 %rd4, [_Z3addPfS_S_m_param_2]; 46 | ld.param.u64 %rd5, [_Z3addPfS_S_m_param_3]; 47 | mov.u32 %r1, %ctaid.x; 48 | mov.u32 %r2, %ntid.x; 49 | mov.u32 %r3, %tid.x; 50 | mad.lo.s32 %r4, %r1, %r2, %r3; 51 | cvt.u64.u32 %rd1, %r4; 52 | setp.ge.u64 %p1, %rd1, %rd5; 53 | @%p1 bra $L__BB1_2; 54 | 55 | cvta.to.global.u64 %rd6, %rd2; 56 | shl.b64 %rd7, %rd1, 2; 57 | add.s64 %rd8, %rd6, %rd7; 58 | cvta.to.global.u64 %rd9, %rd3; 59 | add.s64 %rd10, %rd9, %rd7; 60 | ld.global.f32 %f1, [%rd10]; 61 | ld.global.f32 %f2, 
[%rd8]; 62 | { // callseq 0, 0 63 | .reg .b32 temp_param_reg; 64 | .param .b32 param0; 65 | st.param.f32 [param0+0], %f2; 66 | .param .b32 param1; 67 | st.param.f32 [param1+0], %f1; 68 | .param .b32 retval0; 69 | call.uni (retval0), 70 | _Z6add_opff, 71 | ( 72 | param0, 73 | param1 74 | ); 75 | ld.param.f32 %f3, [retval0+0]; 76 | } // callseq 0 77 | cvta.to.global.u64 %rd11, %rd4; 78 | add.s64 %rd12, %rd11, %rd7; 79 | st.global.f32 [%rd12], %f3; 80 | 81 | $L__BB1_2: 82 | ret; 83 | 84 | } 85 | 86 | -------------------------------------------------------------------------------- /kernels/transpose.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z9transposePfS_m 14 | // _ZZ9transposePfS_mE12sharedMemory has been demoted 15 | 16 | .visible .entry _Z9transposePfS_m( 17 | .param .u64 _Z9transposePfS_m_param_0, 18 | .param .u64 _Z9transposePfS_m_param_1, 19 | .param .u64 _Z9transposePfS_m_param_2 20 | ) 21 | { 22 | .reg .pred %p<7>; 23 | .reg .f32 %f<3>; 24 | .reg .b32 %r<26>; 25 | .reg .b64 %rd<14>; 26 | // demoted variable 27 | .shared .align 4 .b8 _ZZ9transposePfS_mE12sharedMemory[4096]; 28 | 29 | ld.param.u64 %rd1, [_Z9transposePfS_m_param_0]; 30 | ld.param.u64 %rd2, [_Z9transposePfS_m_param_1]; 31 | ld.param.u64 %rd3, [_Z9transposePfS_m_param_2]; 32 | mov.u32 %r8, %ntid.x; 33 | mov.u32 %r9, %ctaid.x; 34 | mov.u32 %r1, %tid.x; 35 | mad.lo.s32 %r2, %r9, %r8, %r1; 36 | mov.u32 %r10, %ntid.y; 37 | mov.u32 %r11, %ctaid.y; 38 | mov.u32 %r3, %tid.y; 39 | mad.lo.s32 %r4, %r11, %r10, %r3; 40 | mad.lo.s32 %r5, %r11, %r8, %r1; 41 | mad.lo.s32 %r6, %r9, %r10, %r3; 42 | cvt.s64.s32 %rd4, %r2; 43 | setp.lt.u64 %p1, %rd4, %rd3; 44 | cvt.s64.s32 %rd5, %r4; 45 | setp.lt.u64 %p2, %rd5, %rd3; 46 | and.pred %p3, %p1, %p2; 47 | shl.b32 %r12, %r1, 7; 48 | mov.u32 %r13, _ZZ9transposePfS_mE12sharedMemory; 49 | add.s32 %r14, %r13, %r12; 50 | shl.b32 %r15, %r3, 2; 51 | add.s32 %r7, %r14, %r15; 52 | @%p3 bra $L__BB0_2; 53 | bra.uni $L__BB0_1; 54 | 55 | $L__BB0_2: 56 | cvta.to.global.u64 %rd6, %rd1; 57 | cvt.u32.u64 %r17, %rd3; 58 | mad.lo.s32 %r18, %r4, %r17, %r2; 59 | mul.wide.s32 %rd7, %r18, 4; 60 | add.s64 %rd8, %rd6, %rd7; 61 | ld.global.f32 %f1, [%rd8]; 62 | st.shared.f32 [%r7], %f1; 63 | bra.uni $L__BB0_3; 64 | 65 | $L__BB0_1: 66 | mov.u32 %r16, 0; 67 | st.shared.u32 [%r7], %r16; 68 | 69 | $L__BB0_3: 70 | bar.sync 0; 71 | cvt.s64.s32 %rd9, %r5; 72 | setp.ge.u64 %p4, %rd9, %rd3; 73 | cvt.s64.s32 %rd10, %r6; 74 | setp.ge.u64 %p5, %rd10, %rd3; 75 | or.pred %p6, %p4, %p5; 76 | @%p6 bra $L__BB0_5; 77 | 78 | cvt.u32.u64 %r19, %rd3; 79 | mad.lo.s32 %r20, %r6, %r19, %r5; 80 | shl.b32 %r21, %r3, 7; 81 | add.s32 %r23, %r13, %r21; 82 | shl.b32 %r24, %r1, 2; 83 | add.s32 %r25, %r23, %r24; 84 | ld.shared.f32 %f2, [%r25]; 85 | cvta.to.global.u64 %rd11, %rd2; 86 | mul.wide.s32 %rd12, %r20, 4; 87 | add.s64 %rd13, %rd11, %rd12; 88 | st.global.f32 [%rd13], %f2; 89 | 90 | $L__BB0_5: 91 | ret; 92 | 93 | } 94 | 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ptoxide 2 | 3 | `ptoxide` is a crate that allows NVIDIA CUDA PTX code to be executed on any machine. 
4 | It was created as a project to learn more about the CUDA execution model.
5 |
6 | Kernels are executed by compiling them to a custom bytecode format,
7 | which is then executed inside a virtual machine.
8 |
9 | To see how the library works in practice, check out the [example below](#example),
10 | and take a look at the integration tests in the [tests](/tests) directory.
11 |
12 | Try running `cargo run --example times_two` to see it in action!
13 |
14 | ## Supported Features
15 | `ptoxide` supports most fundamental PTX features, such as:
16 | - Global, shared, and local (stack) memory
17 | - (Recursive) function calls
18 | - Thread synchronization using barriers
19 | - Various arithmetic operations on integers and floating point values
20 | - One-, two-, and three-dimensional thread grids and blocks
21 |
22 | These features are sufficient to execute the kernels found in the [kernels](/kernels) directory,
23 | such as simple vector operations, matrix multiplication,
24 | and matrix transposition using a shared buffer.
25 |
26 | However, many features and instructions are still missing, and you will probably encounter `todo!`s
27 | and parsing errors when attempting to execute more complex programs.
28 | Pull requests to implement missing features are always greatly appreciated!
29 |
30 | ## Internals
31 | The code of the library itself is not yet well-documented. However, here is a general overview of the main
32 | modules comprising `ptoxide`:
33 | - The [`ast`](/src/ast/mod.rs) module implements the logic for parsing PTX programs.
34 | - The [`vm`](/src/vm.rs) module defines a bytecode format and implements the virtual machine to execute it.
35 | - The [`compiler`](/src/compiler.rs) module implements a simple single-pass compiler to translate a PTX program given as an AST to bytecode.
36 |
37 | ## Example
38 | The following code snippet shows how to invoke a kernel to scale a vector of floats by a factor of 2.
39 | Check out the [full example](/examples/times_two.rs) in the [examples directory](/examples/),
40 | or run it with `cargo run --example times_two`.
41 |
42 | ```rust
43 | use ptoxide::{Context, Argument, LaunchParams};
44 |
45 | fn times_two(kernel: &str) {
46 |     let a: Vec<f32> = vec![1., 2., 3., 4., 5.];
47 |     let mut b: Vec<f32> = vec![0.; a.len()];
48 |
49 |     let n = a.len();
50 |
51 |     let mut ctx = Context::new_with_module(kernel).expect("compile kernel");
52 |
53 |     const BLOCK_SIZE: u32 = 256;
54 |     let grid_size = (n as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE;
55 |
56 |     let da = ctx.alloc(n);
57 |     let db = ctx.alloc(n);
58 |
59 |     ctx.write(da, &a);
60 |     ctx.run(
61 |         LaunchParams::func_id(0)
62 |             .grid1d(grid_size)
63 |             .block1d(BLOCK_SIZE),
64 |         &[
65 |             Argument::ptr(da),
66 |             Argument::ptr(db),
67 |             Argument::U64(n as u64),
68 |         ],
69 |     ).expect("execute kernel");
70 |
71 |     ctx.read(db, &mut b);
72 |     // prints [2.0, 4.0, 6.0, 8.0, 10.0]
73 |     println!("{:?}", b);
74 | }
75 | ```
76 |
77 | ## Reading PTX
78 | To learn more about the PTX ISA, check out NVIDIA's [documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html).
79 |
80 | ## License
81 | `ptoxide` is dual-licensed under the Apache License version 2.0 and the MIT license, at your option.
82 |
--------------------------------------------------------------------------------
/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3 4 | 5 | [[package]] 6 | name = "beef" 7 | version = "0.5.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" 10 | 11 | [[package]] 12 | name = "bytemuck" 13 | version = "1.14.0" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" 16 | 17 | [[package]] 18 | name = "fnv" 19 | version = "1.0.7" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" 22 | 23 | [[package]] 24 | name = "logos" 25 | version = "0.13.0" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "c000ca4d908ff18ac99b93a062cb8958d331c3220719c52e77cb19cc6ac5d2c1" 28 | dependencies = [ 29 | "logos-derive", 30 | ] 31 | 32 | [[package]] 33 | name = "logos-codegen" 34 | version = "0.13.0" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "dc487311295e0002e452025d6b580b77bb17286de87b57138f3b5db711cded68" 37 | dependencies = [ 38 | "beef", 39 | "fnv", 40 | "proc-macro2", 41 | "quote", 42 | "regex-syntax", 43 | "syn", 44 | ] 45 | 46 | [[package]] 47 | name = "logos-derive" 48 | version = "0.13.0" 49 | source = "registry+https://github.com/rust-lang/crates.io-index" 50 | checksum = "dbfc0d229f1f42d790440136d941afd806bc9e949e2bcb8faa813b0f00d1267e" 51 | dependencies = [ 52 | "logos-codegen", 53 | ] 54 | 55 | [[package]] 56 | name = "proc-macro2" 57 | version = "1.0.69" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" 60 | dependencies = [ 61 | "unicode-ident", 62 | ] 63 | 64 | [[package]] 65 | name = "ptoxide" 66 | version = "0.1.0" 67 | dependencies = [ 68 | "bytemuck", 69 | "logos", 70 | "thiserror", 71 | ] 72 | 73 | [[package]] 74 | name = "quote" 75 | version = "1.0.33" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" 78 | dependencies = [ 79 | "proc-macro2", 80 | ] 81 | 82 | [[package]] 83 | name = "regex-syntax" 84 | version = "0.6.29" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" 87 | 88 | [[package]] 89 | name = "syn" 90 | version = "2.0.38" 91 | source = "registry+https://github.com/rust-lang/crates.io-index" 92 | checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" 93 | dependencies = [ 94 | "proc-macro2", 95 | "quote", 96 | "unicode-ident", 97 | ] 98 | 99 | [[package]] 100 | name = "thiserror" 101 | version = "1.0.50" 102 | source = "registry+https://github.com/rust-lang/crates.io-index" 103 | checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" 104 | dependencies = [ 105 | "thiserror-impl", 106 | ] 107 | 108 | [[package]] 109 | name = "thiserror-impl" 110 | version = "1.0.50" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" 113 | dependencies = [ 114 | "proc-macro2", 115 | "quote", 116 | "syn", 117 | ] 118 | 119 | [[package]] 120 | name = "unicode-ident" 121 | version = "1.0.12" 122 | source = "registry+https://github.com/rust-lang/crates.io-index" 123 | checksum = 
"3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 124 | -------------------------------------------------------------------------------- /kernels/gemm.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-33281558 5 | // Cuda compilation tools, release 12.3, V12.3.52 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 8.3 10 | .target sm_89 11 | .address_size 64 12 | 13 | // .globl _Z4gemmPfS_S_mmm 14 | 15 | .visible .entry _Z4gemmPfS_S_mmm( 16 | .param .u64 _Z4gemmPfS_S_mmm_param_0, 17 | .param .u64 _Z4gemmPfS_S_mmm_param_1, 18 | .param .u64 _Z4gemmPfS_S_mmm_param_2, 19 | .param .u64 _Z4gemmPfS_S_mmm_param_3, 20 | .param .u64 _Z4gemmPfS_S_mmm_param_4, 21 | .param .u64 _Z4gemmPfS_S_mmm_param_5 22 | ) 23 | { 24 | .reg .pred %p<9>; 25 | .reg .f32 %f<30>; 26 | .reg .b32 %r<9>; 27 | .reg .b64 %rd<61>; 28 | 29 | 30 | ld.param.u64 %rd31, [_Z4gemmPfS_S_mmm_param_0]; 31 | ld.param.u64 %rd32, [_Z4gemmPfS_S_mmm_param_1]; 32 | ld.param.u64 %rd28, [_Z4gemmPfS_S_mmm_param_2]; 33 | ld.param.u64 %rd33, [_Z4gemmPfS_S_mmm_param_3]; 34 | ld.param.u64 %rd29, [_Z4gemmPfS_S_mmm_param_4]; 35 | ld.param.u64 %rd30, [_Z4gemmPfS_S_mmm_param_5]; 36 | cvta.to.global.u64 %rd1, %rd32; 37 | cvta.to.global.u64 %rd2, %rd31; 38 | mov.u32 %r1, %ntid.x; 39 | mov.u32 %r2, %ctaid.x; 40 | mov.u32 %r3, %tid.x; 41 | mad.lo.s32 %r4, %r2, %r1, %r3; 42 | cvt.u64.u32 %rd3, %r4; 43 | mov.u32 %r5, %ntid.y; 44 | mov.u32 %r6, %ctaid.y; 45 | mov.u32 %r7, %tid.y; 46 | mad.lo.s32 %r8, %r6, %r5, %r7; 47 | cvt.u64.u32 %rd4, %r8; 48 | setp.ge.u64 %p1, %rd3, %rd33; 49 | setp.ge.u64 %p2, %rd4, %rd30; 50 | or.pred %p3, %p1, %p2; 51 | @%p3 bra $L__BB0_9; 52 | 53 | setp.eq.s64 %p4, %rd29, 0; 54 | mov.f32 %f29, 0f00000000; 55 | @%p4 bra $L__BB0_8; 56 | 57 | mul.lo.s64 %rd5, %rd3, %rd29; 58 | and.b64 %rd6, %rd29, 3; 59 | add.s64 %rd35, %rd29, -1; 60 | setp.lt.u64 %p5, %rd35, 3; 61 | mov.f32 %f29, 0f00000000; 62 | mov.u64 %rd57, 0; 63 | @%p5 bra $L__BB0_5; 64 | 65 | sub.s64 %rd7, %rd6, %rd29; 66 | shl.b64 %rd37, %rd4, 2; 67 | add.s64 %rd55, %rd1, %rd37; 68 | shl.b64 %rd38, %rd5, 2; 69 | add.s64 %rd39, %rd2, %rd38; 70 | add.s64 %rd54, %rd39, 8; 71 | shl.b64 %rd10, %rd30, 2; 72 | mov.f32 %f29, 0f00000000; 73 | mov.u64 %rd57, 0; 74 | 75 | $L__BB0_4: 76 | ld.global.f32 %f12, [%rd55]; 77 | ld.global.f32 %f13, [%rd54+-8]; 78 | fma.rn.f32 %f14, %f13, %f12, %f29; 79 | add.s64 %rd40, %rd55, %rd10; 80 | ld.global.f32 %f15, [%rd40]; 81 | ld.global.f32 %f16, [%rd54+-4]; 82 | fma.rn.f32 %f17, %f16, %f15, %f14; 83 | add.s64 %rd41, %rd40, %rd10; 84 | ld.global.f32 %f18, [%rd41]; 85 | ld.global.f32 %f19, [%rd54]; 86 | fma.rn.f32 %f20, %f19, %f18, %f17; 87 | add.s64 %rd42, %rd41, %rd10; 88 | add.s64 %rd55, %rd42, %rd10; 89 | ld.global.f32 %f21, [%rd42]; 90 | ld.global.f32 %f22, [%rd54+4]; 91 | fma.rn.f32 %f29, %f22, %f21, %f20; 92 | add.s64 %rd57, %rd57, 4; 93 | add.s64 %rd43, %rd7, %rd57; 94 | add.s64 %rd54, %rd54, 16; 95 | setp.ne.s64 %p6, %rd43, 0; 96 | @%p6 bra $L__BB0_4; 97 | 98 | $L__BB0_5: 99 | setp.eq.s64 %p7, %rd6, 0; 100 | @%p7 bra $L__BB0_8; 101 | 102 | mul.lo.s64 %rd44, %rd57, %rd30; 103 | add.s64 %rd45, %rd44, %rd4; 104 | shl.b64 %rd46, %rd45, 2; 105 | add.s64 %rd60, %rd1, %rd46; 106 | shl.b64 %rd19, %rd30, 2; 107 | add.s64 %rd47, %rd57, %rd5; 108 | shl.b64 %rd48, %rd47, 2; 109 | add.s64 %rd59, %rd2, %rd48; 110 | neg.s64 %rd58, %rd6; 111 | 112 | $L__BB0_7: 113 | .pragma "nounroll"; 114 | ld.global.f32 %f23, [%rd60]; 115 | 
ld.global.f32 %f24, [%rd59]; 116 | fma.rn.f32 %f29, %f24, %f23, %f29; 117 | add.s64 %rd60, %rd60, %rd19; 118 | add.s64 %rd59, %rd59, 4; 119 | add.s64 %rd58, %rd58, 1; 120 | setp.ne.s64 %p8, %rd58, 0; 121 | @%p8 bra $L__BB0_7; 122 | 123 | $L__BB0_8: 124 | mul.lo.s64 %rd49, %rd3, %rd30; 125 | add.s64 %rd50, %rd49, %rd4; 126 | cvta.to.global.u64 %rd51, %rd28; 127 | shl.b64 %rd52, %rd50, 2; 128 | add.s64 %rd53, %rd51, %rd52; 129 | st.global.f32 [%rd53], %f29; 130 | 131 | $L__BB0_9: 132 | ret; 133 | 134 | } 135 | 136 | -------------------------------------------------------------------------------- /tests/test.rs: -------------------------------------------------------------------------------- 1 | use ptoxide::{Argument, Context, LaunchParams}; 2 | 3 | const ADD: &'static str = include_str!("../kernels/add.ptx"); 4 | const ADD_SIMPLE: &'static str = include_str!("../kernels/add_simple.ptx"); 5 | const FNCALL: &'static str = include_str!("../kernels/fncall.ptx"); 6 | const GEMM: &'static str = include_str!("../kernels/gemm.ptx"); 7 | const TRANSPOSE: &'static str = include_str!("../kernels/transpose.ptx"); 8 | 9 | #[test] 10 | fn add_simple() { 11 | let mut ctx = Context::new_with_module(ADD_SIMPLE).unwrap(); 12 | 13 | const N: usize = 10; 14 | 15 | let a = ctx.alloc(N); 16 | let b = ctx.alloc(N); 17 | let c = ctx.alloc(N); 18 | 19 | let data_a = vec![1f32; N]; 20 | let data_b = vec![2f32; N]; 21 | 22 | ctx.write(a, &data_a); 23 | ctx.write(b, &data_b); 24 | 25 | ctx.run( 26 | LaunchParams::func_id(0).grid1d(1).block1d(N as u32), 27 | &[Argument::ptr(a), Argument::ptr(b), Argument::ptr(c)], 28 | ) 29 | .unwrap(); 30 | 31 | let mut res = vec![0f32; N]; 32 | ctx.read(c, &mut res); 33 | 34 | res.iter().for_each(|v| assert_eq!(*v, 3f32)); 35 | } 36 | 37 | #[test] 38 | fn add() { 39 | let mut ctx = Context::new_with_module(ADD).unwrap(); 40 | 41 | const N: usize = 1000; 42 | const BLOCK_SIZE: u32 = 256; 43 | const GRID_SIZE: u32 = (N as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE; 44 | 45 | let a = ctx.alloc(N); 46 | let b = ctx.alloc(N); 47 | let c = ctx.alloc(N); 48 | 49 | let data_a = vec![1f32; N]; 50 | let data_b = vec![2f32; N]; 51 | ctx.write(a, &data_a); 52 | ctx.write(b, &data_b); 53 | 54 | ctx.run( 55 | LaunchParams::func_id(0) 56 | .grid1d(GRID_SIZE) 57 | .block1d(BLOCK_SIZE), 58 | &[ 59 | Argument::ptr(a), 60 | Argument::ptr(b), 61 | Argument::ptr(c), 62 | Argument::U64(N as u64), 63 | ], 64 | ) 65 | .unwrap(); 66 | 67 | let mut res = vec![0f32; N]; 68 | ctx.read(c, &mut res); 69 | 70 | res.iter().for_each(|v| assert_eq!(*v, 3f32)); 71 | } 72 | 73 | #[test] 74 | fn fncall() { 75 | let mut ctx = Context::new_with_module(FNCALL).unwrap(); 76 | 77 | const N: usize = 1000; 78 | const BLOCK_SIZE: u32 = 256; 79 | const GRID_SIZE: u32 = (N as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE; 80 | 81 | let a = ctx.alloc(N); 82 | let b = ctx.alloc(N); 83 | let c = ctx.alloc(N); 84 | 85 | let data_a = vec![1f32; N]; 86 | let data_b = vec![2f32; N]; 87 | ctx.write(a, &data_a); 88 | ctx.write(b, &data_b); 89 | 90 | ctx.run( 91 | LaunchParams::func_id(1) // in this case id 0 is the helper fn 92 | .grid1d(GRID_SIZE) 93 | .block1d(BLOCK_SIZE), 94 | &[ 95 | Argument::ptr(a), 96 | Argument::ptr(b), 97 | Argument::ptr(c), 98 | Argument::U64(N as u64), 99 | ], 100 | ) 101 | .unwrap(); 102 | 103 | let mut res = vec![0f32; N]; 104 | ctx.read(c, &mut res); 105 | 106 | res.iter().for_each(|v| assert_eq!(*v, 3f32)); 107 | } 108 | 109 | fn run_transpose(ctx: &mut Context, n: usize) { 110 | const BLOCK_SIZE: u32 = 32; 111 | let 
grid_size: u32 = (n as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE; 112 | 113 | let a = ctx.alloc(n * n); 114 | let b = ctx.alloc(n * n); 115 | 116 | let mut data_a = vec![0f32; n * n]; 117 | for x in 0..n { 118 | for y in 0..n { 119 | data_a[x * n + y] = (x * n + y) as f32; 120 | } 121 | } 122 | ctx.write(a, &data_a); 123 | 124 | ctx.run( 125 | LaunchParams::func_id(0) 126 | .grid2d(grid_size, grid_size) 127 | .block2d(BLOCK_SIZE, BLOCK_SIZE), 128 | &[Argument::ptr(a), Argument::ptr(b), Argument::U64(n as u64)], 129 | ) 130 | .unwrap(); 131 | 132 | let mut res = vec![0f32; n * n]; 133 | ctx.read(b, &mut res); 134 | 135 | for x in 0..n { 136 | for y in 0..n { 137 | assert_eq!(res[x * n + y], data_a[y * n + x]); 138 | } 139 | } 140 | } 141 | 142 | #[test] 143 | fn transpose() { 144 | let mut ctx = Context::new_with_module(TRANSPOSE).unwrap(); 145 | for i in 1..100 { 146 | run_transpose(&mut ctx, i); 147 | ctx.reset_mem(); 148 | } 149 | } 150 | 151 | fn run_gemm(ctx: &mut Context, m: usize, k: usize, n: usize) { 152 | // todo test non-even alignment 153 | let block_size = 1;//32; 154 | let grid_x = (m as u32 + block_size - 1) / block_size; 155 | let grid_y = (n as u32 + block_size - 1) / block_size; 156 | 157 | let a = ctx.alloc(m * k); 158 | let b = ctx.alloc(k * n); 159 | let c = ctx.alloc(m * n); 160 | 161 | let data_a = vec![1f32; m * k]; 162 | let data_b = vec![1f32; k * n]; 163 | 164 | ctx.write(a, &data_a); 165 | ctx.write(b, &data_b); 166 | 167 | ctx.run( 168 | LaunchParams::func_id(0) 169 | .grid2d(grid_x, grid_y) 170 | .block2d(block_size, block_size), 171 | &[ 172 | Argument::ptr(a), 173 | Argument::ptr(b), 174 | Argument::ptr(c), 175 | Argument::U64(m as u64), 176 | Argument::U64(k as u64), 177 | Argument::U64(n as u64), 178 | ], 179 | ) 180 | .unwrap(); 181 | 182 | let mut res = vec![0f32; m * n]; 183 | ctx.read(c, &mut res); 184 | 185 | for val in res { 186 | assert_eq!(val, k as f32); 187 | } 188 | } 189 | 190 | #[test] 191 | fn gemm() { 192 | let mut ctx = Context::new_with_module(GEMM).unwrap(); 193 | 194 | let sizes = [ 195 | (32, 32, 32), 196 | (3, 2, 1), 197 | (1, 20, 1), 198 | (2, 24, 1), 199 | (123, 54, 10), 200 | (20, 40, 33), 201 | ]; 202 | 203 | for (m, k, n) in sizes.into_iter() { 204 | run_gemm(&mut ctx, m, k, n); 205 | ctx.reset_mem(); 206 | } 207 | } 208 | 209 | 210 | #[test] 211 | fn times_two() { 212 | let a: Vec = vec![1., 2., 3., 4., 5.]; 213 | let mut b: Vec = vec![0.; a.len()]; 214 | 215 | let n = a.len(); 216 | 217 | let kernel = std::fs::read_to_string("kernels/times_two.ptx").expect("read kernel file"); 218 | let mut ctx = Context::new_with_module(&kernel).expect("compile kernel"); 219 | 220 | const BLOCK_SIZE: u32 = 256; 221 | let grid_size = (n as u32 + BLOCK_SIZE - 1) / BLOCK_SIZE; 222 | 223 | let da = ctx.alloc(n); 224 | let db = ctx.alloc(n); 225 | 226 | ctx.write(da, &a); 227 | ctx.run( 228 | LaunchParams::func_id(0) 229 | .grid1d(grid_size) 230 | .block1d(BLOCK_SIZE), 231 | &[ 232 | Argument::ptr(da), 233 | Argument::ptr(db), 234 | Argument::U64(n as u64), 235 | ], 236 | ).expect("execute kernel"); 237 | 238 | ctx.read(db, &mut b); 239 | for (x, y) in a.into_iter().zip(b) { 240 | assert_eq!(2. 
* x, y); 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /src/ast/lex.rs: -------------------------------------------------------------------------------- 1 | use logos::{Logos, Lexer}; 2 | 3 | fn lex_reg_multiplicity<'a>(lex: &mut Lexer<'a, Token<'a>>) -> Result { 4 | let mut s = lex.slice(); 5 | s = &s[1..s.len() - 1]; 6 | s.parse().map_err(|_| LexError::ParseRegMultiplicity) 7 | } 8 | 9 | fn lex_version_number<'a>(lex: &mut Lexer<'a, Token<'a>>) -> Result<(u32, u32), LexError> { 10 | let num_str = lex.slice().split_whitespace().nth(1).ok_or(LexError::ParseVersionNumber)?; 11 | let Some((major_str, minor_str)) = num_str.split_once('.') else { 12 | return Err(LexError::ParseVersionNumber); 13 | }; 14 | let major = major_str 15 | .parse() 16 | .map_err(|_| LexError::ParseVersionNumber)?; 17 | let minor = minor_str 18 | .parse() 19 | .map_err(|_| LexError::ParseVersionNumber)?; 20 | Ok(( major, minor )) 21 | } 22 | 23 | fn lex_float32_constant<'a>(lex: &mut Lexer<'a, Token<'a>>) -> Result { 24 | let Some(vals) = lex.slice().as_bytes().get(2..) else { 25 | return Err(LexError::ParseFloatConst); 26 | }; 27 | let mut val = 0u32; 28 | for c in vals { 29 | val <<= 4; 30 | val |= match c { 31 | b'0'..=b'9' => c - b'0', 32 | b'a'..=b'f' => c - b'a' + 10, 33 | b'A'..=b'F' => c - b'A' + 10, 34 | _ => return Err(LexError::ParseFloatConst), 35 | } as u32; 36 | } 37 | Ok(f32::from_bits(val)) 38 | } 39 | 40 | fn lex_float64_constant<'a>(lex: &mut Lexer<'a, Token<'a>>) -> Result { 41 | let Some(vals) = lex.slice().as_bytes().get(2..) else { 42 | return Err(LexError::ParseFloatConst); 43 | }; 44 | let mut val = 0u64; 45 | for c in vals { 46 | val <<= 4; 47 | val |= match c { 48 | b'0'..=b'9' => c - b'0', 49 | b'a'..=b'f' => c - b'a' + 10, 50 | b'A'..=b'F' => c - b'A' + 10, 51 | _ => return Err(LexError::ParseFloatConst), 52 | } as u64; 53 | } 54 | Ok(f64::from_bits(val)) 55 | } 56 | 57 | #[derive(Clone, Copy, Debug, PartialEq, Default)] 58 | pub enum LexError { 59 | ParseFloatConst, 60 | ParseRegMultiplicity, 61 | ParseVersionNumber, 62 | #[default] 63 | Unknown, 64 | } 65 | 66 | impl std::fmt::Display for LexError { 67 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 68 | write!(f, "{:?}", self) 69 | } 70 | } 71 | impl std::error::Error for LexError {} 72 | 73 | #[derive(Logos, Debug, PartialEq, Clone, Copy)] 74 | #[logos(skip r"[ \t\n\f]+")] // Ignore this regex pattern between tokens 75 | #[logos(error = LexError)] 76 | pub enum Token<'a> { 77 | #[token(".address_size")] 78 | AddressSize, 79 | #[token(".explicitcluster")] 80 | Explicitcluster, 81 | #[token(".maxnreg")] 82 | Maxnreg, 83 | #[token(".section")] 84 | Section, 85 | #[token(".alias")] 86 | Alias, 87 | #[token(".extern")] 88 | Extern, 89 | #[token(".maxntid")] 90 | Maxntid, 91 | #[token(".shared")] 92 | Shared, 93 | #[token(".align")] 94 | Align, 95 | #[token(".file")] 96 | File, 97 | #[token(".minnctapersm")] 98 | Minnctapersm, 99 | #[token(".sreg")] 100 | Sreg, 101 | #[token(".branchtargets")] 102 | Branchtargets, 103 | #[token(".func")] 104 | Func, 105 | #[token(".noreturn")] 106 | Noreturn, 107 | #[token(".target")] 108 | Target, 109 | #[token(".callprototype")] 110 | Callprototype, 111 | #[token(".global")] 112 | Global, 113 | #[token(".param")] 114 | Param, 115 | #[token(".tex")] 116 | Tex, 117 | #[token(".calltargets")] 118 | Calltargets, 119 | #[token(".loc")] 120 | Loc, 121 | #[token(".pragma")] 122 | Pragma, 123 | // we parse this as a single 
token to avoid ambiguity with float constants 124 | #[regex(r".version[ \t\f\n]+\d+\.\d+", lex_version_number)] 125 | Version((u32, u32)), 126 | #[token(".common")] 127 | Common, 128 | #[token(".local")] 129 | Local, 130 | #[token(".reg")] 131 | Reg, 132 | #[token(".visible")] 133 | Visible, 134 | #[token(".const")] 135 | Const, 136 | #[token(".maxclusterrank")] 137 | Maxclusterrank, 138 | #[token(".reqnctapercluster")] 139 | Reqnctapercluster, 140 | #[token(".weak")] 141 | Weak, 142 | #[token(".entry")] 143 | Entry, 144 | #[token(".maxnctapersm")] 145 | Maxnctapersm, 146 | #[token(".reqntid")] 147 | Reqntid, 148 | 149 | #[token(".b128")] 150 | Bit128, 151 | #[token(".b64")] 152 | Bit64, 153 | #[token(".b32")] 154 | Bit32, 155 | #[token(".b16")] 156 | Bit16, 157 | #[token(".b8")] 158 | Bit8, 159 | #[token(".u64")] 160 | Unsigned64, 161 | #[token(".u32")] 162 | Unsigned32, 163 | #[token(".u16")] 164 | Unsigned16, 165 | #[token(".u8")] 166 | Unsigned8, 167 | #[token(".s64")] 168 | Signed64, 169 | #[token(".s32")] 170 | Signed32, 171 | #[token(".s16")] 172 | Signed16, 173 | #[token(".s8")] 174 | Signed8, 175 | #[token(".f64")] 176 | Float64, 177 | #[token(".f32")] 178 | Float32, 179 | #[token(".f16x2")] 180 | Float16x2, 181 | #[token(".f16")] 182 | Float16, 183 | #[token(".pred")] 184 | Predicate, 185 | 186 | #[token(".v2")] 187 | V2, 188 | #[token(".v4")] 189 | V4, 190 | 191 | #[token("abs")] 192 | Abs, 193 | #[token("discard")] 194 | Discard, 195 | #[token("min")] 196 | Min, 197 | #[token("shf")] 198 | Shf, 199 | #[token("vadd")] 200 | Vadd, 201 | #[token("activemask")] 202 | Activemask, 203 | #[token("div")] 204 | Div, 205 | #[token("mma")] 206 | Mma, 207 | #[token("shfl")] 208 | Shfl, 209 | #[token("vadd2")] 210 | Vadd2, 211 | #[token("add")] 212 | Add, 213 | #[token("dp2a")] 214 | Dp2A, 215 | #[token("mov")] 216 | Mov, 217 | #[token("shl")] 218 | Shl, 219 | #[token("vadd4")] 220 | Vadd4, 221 | #[token("addc")] 222 | Addc, 223 | #[token("dp4a")] 224 | Dp4A, 225 | #[token("movmatrix")] 226 | Movmatrix, 227 | #[token("shr")] 228 | Shr, 229 | #[token("vavrg2")] 230 | Vavrg2, 231 | #[token("alloca")] 232 | Alloca, 233 | #[token("elect")] 234 | Elect, 235 | #[token("mul")] 236 | Mul, 237 | #[token("sin")] 238 | Sin, 239 | #[token("vavrg4")] 240 | Vavrg4, 241 | #[token("and")] 242 | And, 243 | #[token("ex2")] 244 | Ex2, 245 | #[token("mul24")] 246 | Mul24, 247 | #[token("slct")] 248 | Slct, 249 | #[token("vmad")] 250 | Vmad, 251 | #[token("applypriority")] 252 | Applypriority, 253 | #[token("exit")] 254 | Exit, 255 | #[token("multimem")] 256 | Multimem, 257 | #[token("sqrt")] 258 | Sqrt, 259 | #[token("vmax")] 260 | Vmax, 261 | #[token("atom")] 262 | Atom, 263 | #[token("fence")] 264 | Fence, 265 | #[token("nanosleep")] 266 | Nanosleep, 267 | #[token("st")] 268 | St, 269 | #[token("vmax2")] 270 | Vmax2, 271 | #[token("bar")] 272 | Bar, 273 | #[token("fma")] 274 | Fma, 275 | #[token("neg")] 276 | Neg, 277 | #[token("stackrestore")] 278 | Stackrestore, 279 | #[token("vmax4")] 280 | Vmax4, 281 | #[token("barrier")] 282 | Barrier, 283 | #[token("fns")] 284 | Fns, 285 | #[token("not")] 286 | Not, 287 | #[token("stacksave")] 288 | Stacksave, 289 | #[token("vmin")] 290 | Vmin, 291 | #[token("bfe")] 292 | Bfe, 293 | #[token("getctarank")] 294 | Getctarank, 295 | #[token("or")] 296 | Or, 297 | #[token("stmatrix")] 298 | Stmatrix, 299 | #[token("vmin2")] 300 | Vmin2, 301 | #[token("bfi")] 302 | Bfi, 303 | #[token("griddepcontrol")] 304 | Griddepcontrol, 305 | #[token("pmevent")] 306 | Pmevent, 307 | 
#[token("sub")] 308 | Sub, 309 | #[token("vmin4")] 310 | Vmin4, 311 | #[token("bfind")] 312 | Bfind, 313 | #[token("isspacep")] 314 | Isspacep, 315 | #[token("popc")] 316 | Popc, 317 | #[token("subc")] 318 | Subc, 319 | #[token("vote")] 320 | Vote, 321 | #[token("bmsk")] 322 | Bmsk, 323 | #[token("istypep")] 324 | Istypep, 325 | #[token("prefetch")] 326 | Prefetch, 327 | #[token("suld")] 328 | Suld, 329 | #[token("vset")] 330 | Vset, 331 | #[token("bra")] 332 | Bra, 333 | #[token("ld")] 334 | Ld, 335 | #[token("prefetchu")] 336 | Prefetchu, 337 | #[token("suq")] 338 | Suq, 339 | #[token("vset2")] 340 | Vset2, 341 | #[token("brev")] 342 | Brev, 343 | #[token("ldmatrix")] 344 | Ldmatrix, 345 | #[token("prmt")] 346 | Prmt, 347 | #[token("sured")] 348 | Sured, 349 | #[token("vset4")] 350 | Vset4, 351 | #[token("brkpt")] 352 | Brkpt, 353 | #[token("ldu")] 354 | Ldu, 355 | #[token("rcp")] 356 | Rcp, 357 | #[token("sust")] 358 | Sust, 359 | #[token("vshl")] 360 | Vshl, 361 | #[token("brx")] 362 | Brx, 363 | #[token("lg2")] 364 | Lg2, 365 | #[token("red")] 366 | Red, 367 | #[token("szext")] 368 | Szext, 369 | #[token("vshr")] 370 | Vshr, 371 | #[token("call")] 372 | Call, 373 | #[token("lop3")] 374 | Lop3, 375 | #[token("redux")] 376 | Redux, 377 | #[token("tanh")] 378 | Tanh, 379 | #[token("vsub")] 380 | Vsub, 381 | #[token("clz")] 382 | Clz, 383 | #[token("mad")] 384 | Mad, 385 | #[token("rem")] 386 | Rem, 387 | #[token("testp")] 388 | Testp, 389 | #[token("vsub2")] 390 | Vsub2, 391 | #[token("cnot")] 392 | Cnot, 393 | #[token("mad24")] 394 | Mad24, 395 | #[token("ret")] 396 | Ret, 397 | #[token("tex")] 398 | InsTex, 399 | #[token("vsub4")] 400 | Vsub4, 401 | #[token("copysign")] 402 | Copysign, 403 | #[token("madc")] 404 | Madc, 405 | #[token("rsqrt")] 406 | Rsqrt, 407 | #[token("tld4")] 408 | Tld4, 409 | #[token("wgmma")] 410 | Wgmma, 411 | #[token("cos")] 412 | Cos, 413 | #[token("mapa")] 414 | Mapa, 415 | #[token("sad")] 416 | Sad, 417 | #[token("trap")] 418 | Trap, 419 | #[token("wmma")] 420 | Wmma, 421 | #[token("cp")] 422 | Cp, 423 | #[token("match")] 424 | Match, 425 | #[token("selp")] 426 | Selp, 427 | #[token("txq")] 428 | Txq, 429 | #[token("xor")] 430 | Xor, 431 | #[token("createpolicy")] 432 | Createpolicy, 433 | #[token("max")] 434 | Max, 435 | #[token("set")] 436 | Set, 437 | #[token("vabsdiff")] 438 | Vabsdiff, 439 | #[token("cvt")] 440 | Cvt, 441 | #[token("mbarrier")] 442 | Mbarrier, 443 | #[token("setmaxnreg")] 444 | Setmaxnreg, 445 | #[token("vabsdiff2")] 446 | Vabsdiff2, 447 | #[token("cvta")] 448 | Cvta, 449 | #[token("membar")] 450 | Membar, 451 | #[token("setp")] 452 | Setp, 453 | #[token("vabsdiff4")] 454 | Vabsdiff4, 455 | 456 | #[token(".cta")] 457 | Cta, 458 | 459 | #[token(".sync")] 460 | Sync, 461 | 462 | #[token(".to")] 463 | To, 464 | 465 | #[token(".rn")] 466 | Rn, 467 | #[token(".rz")] 468 | Rz, 469 | #[token(".rm")] 470 | Rm, 471 | #[token(".rp")] 472 | Rp, 473 | 474 | #[token(".lo")] 475 | Low, 476 | #[token(".hi")] 477 | High, 478 | #[token(".wide")] 479 | Wide, 480 | 481 | #[token(".eq")] 482 | Eq, 483 | #[token(".ne")] 484 | Ne, 485 | #[token(".lt")] 486 | Lt, 487 | #[token(".le")] 488 | Le, 489 | #[token(".gt")] 490 | Gt, 491 | #[token(".ge")] 492 | Ge, 493 | 494 | #[token(".uni")] 495 | Uniform, 496 | 497 | #[token("%tid")] 498 | ThreadId, 499 | #[token("%tid.x")] 500 | ThreadIdX, 501 | #[token("%tid.y")] 502 | ThreadIdY, 503 | #[token("%tid.z")] 504 | ThreadIdZ, 505 | 506 | #[token("%ntid")] 507 | NumThreads, 508 | #[token("%ntid.x")] 509 | NumThreadsX, 
510 | #[token("%ntid.y")] 511 | NumThreadsY, 512 | #[token("%ntid.z")] 513 | NumThreadsZ, 514 | 515 | #[token("%ctaid")] 516 | CtaId, 517 | #[token("%ctaid.x")] 518 | CtaIdX, 519 | #[token("%ctaid.y")] 520 | CtaIdY, 521 | #[token("%ctaid.z")] 522 | CtaIdZ, 523 | 524 | #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice())] 525 | Identifier(&'a str), 526 | 527 | #[regex(r"-?[0-9]+", |lex| lex.slice().parse().ok(), priority=2)] 528 | IntegerConst(i64), 529 | // todo make sure this token does not conflict with others 530 | // #[regex(r"[-+]?[0-9]*\.([0-9]+([eE][-+]?[0-9]+)?)", |lex| lex.slice().parse().ok())] 531 | #[regex(r"0[dD][0-9a-fA-F]{16}", lex_float64_constant)] 532 | Float64Const(f64), 533 | #[regex(r"0[fF][0-9a-fA-F]{8}", lex_float32_constant)] 534 | Float32Const(f32), 535 | 536 | #[regex(r"<\s*\+?\d+\s*>", lex_reg_multiplicity)] 537 | RegMultiplicity(u32), 538 | 539 | #[token("{")] 540 | LeftBrace, 541 | #[token("}")] 542 | RightBrace, 543 | #[token("(")] 544 | LeftParen, 545 | #[token(")")] 546 | RightParen, 547 | #[token("[")] 548 | LeftBracket, 549 | #[token("]")] 550 | RightBracket, 551 | #[token("@")] 552 | At, 553 | #[token("!")] 554 | Bang, 555 | #[token("+")] 556 | Plus, 557 | #[token(";")] 558 | Semicolon, 559 | #[token(":")] 560 | Colon, 561 | #[token(",")] 562 | Comma, 563 | 564 | #[regex(r#""[^"]*""#, |lex| lex.slice())] 565 | StringLiteral(&'a str), 566 | #[regex(r"\d+\.\d+", lex_version_number)] 567 | VersionNumber((u32, u32)), 568 | 569 | #[regex(r"//.*", logos::skip)] 570 | Skip, 571 | } 572 | 573 | impl<'a> Token<'a> { 574 | pub fn is_directive(&self) -> bool { 575 | matches!( 576 | self, 577 | Token::Version(_) 578 | | Token::Target 579 | | Token::AddressSize 580 | | Token::Visible 581 | | Token::Entry 582 | | Token::Func 583 | | Token::Param 584 | | Token::Reg 585 | | Token::Global 586 | | Token::Local 587 | | Token::Shared 588 | | Token::Const 589 | | Token::Align 590 | | Token::Pragma 591 | ) 592 | } 593 | } 594 | 595 | impl<'a> std::fmt::Display for Token<'a> { 596 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 597 | write!(f, "{:?}", self) 598 | } 599 | } 600 | -------------------------------------------------------------------------------- /src/compiler.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::ast; 4 | use crate::vm; 5 | 6 | #[derive(Debug)] 7 | pub struct CompiledModule { 8 | pub instructions: Vec<vm::Instruction>, 9 | pub func_descriptors: Vec<vm::FuncFrameDesc>, 10 | // pub global_vars: Vec, 11 | pub symbol_map: HashMap<String, vm::Symbol>, 12 | } 13 | 14 | #[derive(thiserror::Error, Debug)] 15 | pub enum CompilationError { 16 | #[error("undefined symbol: {0:?}")] 17 | UndefinedSymbol(String), 18 | #[error("undefined label: {0:?}")] 19 | UndefinedLabel(String), 20 | #[error("invalid state space")] 21 | InvalidStateSpace, 22 | #[error("invalid operand: {0:?}")] 23 | InvalidOperand(ast::Operand), 24 | #[error("missing operand")] 25 | MissingOperand, 26 | #[error("invalid immediate type")] 27 | InvalidImmediateType, 28 | #[error("invalid register type: {0:?}")] 29 | InvalidRegisterType(vm::RegOperand), 30 | } 31 | 32 | #[derive(Clone, Debug)] 33 | struct BasicBlock { 34 | label: Option<String>, 35 | instructions: Vec<ast::Instruction>, 36 | } 37 | 38 | #[derive(Clone, Copy, Debug)] 39 | enum Variable { 40 | Register(vm::GenericReg), 41 | Absolute(usize), 42 | Stack(isize), 43 | } 44 | struct VariableMap(HashMap<String, Variable>); 45 | 46 | impl VariableMap { 47 | pub fn new() -> Self { 48 | Self(HashMap::new()) 49 |
} 50 | 51 | pub fn insert(&mut self, ident: String, var: Variable) { 52 | self.0.insert(ident, var); 53 | } 54 | 55 | pub fn get(&self, ident: &str) -> Option<&Variable> { 56 | self.0.get(ident) 57 | } 58 | 59 | pub fn get_reg(&self, ident: &str) -> Result<vm::GenericReg, CompilationError> { 60 | match self.0.get(ident) { 61 | Some(Variable::Register(reg)) => Ok(*reg), 62 | Some(_) => Err(CompilationError::InvalidStateSpace), 63 | _ => Err(CompilationError::UndefinedSymbol(ident.to_string())), 64 | } 65 | } 66 | } 67 | 68 | impl CompiledModule { 69 | fn compile_directive_toplevel( 70 | &mut self, 71 | directive: ast::Directive, 72 | ) -> Result<(), CompilationError> { 73 | use ast::Directive; 74 | match directive { 75 | Directive::Version(_) => Ok(()), 76 | Directive::Target(_) => Ok(()), 77 | Directive::Pragma(_) => Ok(()), 78 | Directive::VarDecl(_) => todo!(), 79 | Directive::AddressSize(a) => match a { 80 | ast::AddressSize::Adr64 => Ok(()), 81 | _ => todo!(), 82 | }, 83 | Directive::Function(f) => self.compile_function(f), 84 | } 85 | } 86 | 87 | fn compile_function(&mut self, func: ast::Function) -> Result<(), CompilationError> { 88 | let iptr = vm::IPtr(self.instructions.len()); 89 | let mut state = FuncCodegenState::new(self); 90 | state.compile_ast(func)?; 91 | let (ident, mut frame_desc, instructions) = state.finalize()?; 92 | frame_desc.iptr = iptr; 93 | self.instructions.extend(instructions); 94 | let desc = vm::Symbol::Function(self.func_descriptors.len()); 95 | self.func_descriptors.push(frame_desc); 96 | self.symbol_map.insert(ident, desc); 97 | Ok(()) 98 | } 99 | } 100 | 101 | fn resolve_state_space(st: ast::StateSpace) -> Result<vm::StateSpace, CompilationError> { 102 | use ast::StateSpace::*; 103 | match st { 104 | Global | Constant => Ok(vm::StateSpace::Global), 105 | Shared => Ok(vm::StateSpace::Shared), 106 | Local | Parameter => Ok(vm::StateSpace::Stack), 107 | Register => Err(CompilationError::InvalidStateSpace), 108 | } 109 | } 110 | 111 | fn get_ops<const N: usize>(ops: Vec<ast::Operand>) -> Result<[ast::Operand; N], CompilationError> { 112 | if ops.len() != N { 113 | return Err(CompilationError::MissingOperand); 114 | } 115 | const VAL: ast::Operand = ast::Operand::Variable(String::new()); 116 | let mut arr = [VAL; N]; 117 | for (i, op) in ops.into_iter().enumerate() { 118 | arr[i] = op; 119 | } 120 | Ok(arr) 121 | } 122 | 123 | struct FuncCodegenState<'a> { 124 | parent: &'a CompiledModule, 125 | ident: String, 126 | instructions: Vec<vm::Instruction>, 127 | var_map: VariableMap, 128 | num_regs: usize, 129 | num_args: usize, 130 | stack_size: usize, 131 | shared_size: usize, 132 | label_map: HashMap<String, usize>, 133 | jump_map: Vec<String>, 134 | } 135 | 136 | impl<'a> FuncCodegenState<'a> { 137 | pub fn new(parent: &'a CompiledModule) -> Self { 138 | Self { 139 | parent, 140 | ident: String::new(), 141 | instructions: Vec::new(), 142 | var_map: VariableMap::new(), 143 | num_regs: 0, 144 | num_args: 0, 145 | stack_size: 0, 146 | shared_size: 0, 147 | label_map: HashMap::new(), 148 | jump_map: Vec::new(), 149 | } 150 | } 151 | 152 | fn alloc_reg(&mut self) -> vm::GenericReg { 153 | let idx = self.num_args + self.num_regs; 154 | self.num_regs += 1; 155 | vm::GenericReg(idx) 156 | } 157 | 158 | fn declare_var(&mut self, decl: ast::VarDecl) -> Result<(), CompilationError> { 159 | let vmty: vm::Type = decl.ty.into(); 160 | use ast::StateSpace; 161 | if let ast::Type::Pred = decl.ty { 162 | // predicates can only exist in the reg state space 163 | if decl.state_space != StateSpace::Register { 164 | todo!() 165 | } 166 | } 167 | if let StateSpace::Register = decl.state_space { 168 | if
!decl.array_bounds.is_empty() { 169 | todo!("array bounds not supported on register variables") 170 | } 171 | let reg = self.alloc_reg(); 172 | self.var_map.insert(decl.ident, Variable::Register(reg)); 173 | return Ok(()); 174 | } 175 | 176 | let count = decl.array_bounds.iter().product::(); 177 | let size = vmty.size() * count as usize; 178 | let align = vmty.alignment(); 179 | assert!(align.count_ones() == 1); 180 | // align to required alignment 181 | 182 | match decl.state_space { 183 | StateSpace::Shared => { 184 | self.shared_size = (self.shared_size + align - 1) & !(align - 1); 185 | let loc = Variable::Absolute(self.shared_size); 186 | self.shared_size += size; 187 | self.var_map.insert(decl.ident, loc); 188 | } 189 | StateSpace::Local | StateSpace::Parameter => { 190 | self.stack_size = (self.stack_size + align - 1) & !(align - 1); 191 | let loc = Variable::Stack(self.stack_size as isize); 192 | self.stack_size += size; 193 | self.var_map.insert(decl.ident, loc); 194 | } 195 | StateSpace::Global => todo!(), 196 | StateSpace::Constant => todo!(), 197 | StateSpace::Register => unreachable!(), 198 | } 199 | Ok(()) 200 | } 201 | 202 | fn handle_vars(&mut self, vars: Vec) -> Result<(), CompilationError> { 203 | for decl in vars { 204 | if let Some(mult) = decl.multiplicity { 205 | if !decl.array_bounds.is_empty() { 206 | todo!("array bounds not supported on parametrized variables") 207 | } 208 | for i in 0..mult { 209 | let mut decl = decl.clone(); 210 | decl.ident.push_str(&i.to_string()); 211 | decl.multiplicity = None; 212 | self.declare_var(decl)?; 213 | } 214 | } else { 215 | self.declare_var(decl)?; 216 | } 217 | } 218 | Ok(()) 219 | } 220 | 221 | fn handle_params( 222 | &mut self, 223 | retval: Option, 224 | mut params: Vec, 225 | ) -> Result<(), CompilationError> { 226 | if let Some(p) = retval { 227 | params.push(p); 228 | } 229 | self.num_args = params.len(); 230 | for (idx, param) in params.into_iter().enumerate() { 231 | if let ast::Type::Pred = param.ty { 232 | // this should raise an error as predicates can only exist in the reg state space 233 | todo!() 234 | } 235 | // for now, only handle parameters in the .param state space 236 | 237 | self.var_map 238 | .insert(param.ident, Variable::Register(vm::GenericReg(idx))); 239 | } 240 | Ok(()) 241 | } 242 | 243 | fn construct_immediate( 244 | &mut self, 245 | _ty: vm::Type, 246 | imm: ast::Immediate, 247 | ) -> Result { 248 | let vmconst = match imm { 249 | ast::Immediate::Float32(v) => vm::Constant::F32(v), 250 | ast::Immediate::Float64(v) => vm::Constant::F64(v), 251 | ast::Immediate::Int64(v) => vm::Constant::S64(v), 252 | ast::Immediate::UInt64(v) => vm::Constant::U64(v), 253 | }; 254 | let opref = self.alloc_reg(); 255 | self.instructions 256 | .push(vm::Instruction::Const(opref, vmconst)); 257 | Ok(opref) 258 | } 259 | 260 | fn get_src_reg( 261 | &mut self, 262 | ty: vm::Type, 263 | op: &ast::Operand, 264 | ) -> Result { 265 | use ast::Operand; 266 | match op { 267 | Operand::Variable(ident) => self.var_map.get_reg(ident).map(|r| r.into()), 268 | Operand::Immediate(imm) => self.construct_immediate(ty, *imm).map(|r| r.into()), 269 | Operand::SpecialReg(special) => Ok((*special).into()), 270 | op @ Operand::Address(_) => Err(CompilationError::InvalidOperand(op.clone())), 271 | } 272 | } 273 | 274 | fn get_dst_reg( 275 | &mut self, 276 | _ty: vm::Type, 277 | op: &ast::Operand, 278 | ) -> Result { 279 | use ast::Operand; 280 | match op { 281 | Operand::Variable(ident) => self.var_map.get_reg(ident), 282 | _ => 
Err(CompilationError::InvalidOperand(op.clone())), 283 | } 284 | } 285 | 286 | fn reg_dst_1src( 287 | &mut self, 288 | ty: vm::Type, 289 | ops: &[ast::Operand], 290 | ) -> Result<(vm::GenericReg, vm::RegOperand), CompilationError> { 291 | let [dst, src] = ops else { todo!() }; 292 | let dst_reg = self.get_dst_reg(ty, dst)?; 293 | let src_reg = self.get_src_reg(ty, src)?; 294 | Ok((dst_reg, src_reg)) 295 | } 296 | 297 | fn reg_dst_2src( 298 | &mut self, 299 | ty: vm::Type, 300 | ops: &[ast::Operand], 301 | ) -> Result<(vm::GenericReg, vm::RegOperand, vm::RegOperand), CompilationError> { 302 | let [dst, lhs_op, rhs_op] = ops else { todo!() }; 303 | let dst_reg = self.get_dst_reg(ty, dst)?; 304 | let lhs_reg = self.get_src_reg(ty, lhs_op)?; 305 | let rhs_reg = self.get_src_reg(ty, rhs_op)?; 306 | Ok((dst_reg, lhs_reg, rhs_reg)) 307 | } 308 | 309 | fn reg_dst_3src( 310 | &mut self, 311 | ty: vm::Type, 312 | ops: &[ast::Operand], 313 | ) -> Result< 314 | ( 315 | vm::GenericReg, 316 | vm::RegOperand, 317 | vm::RegOperand, 318 | vm::RegOperand, 319 | ), 320 | CompilationError, 321 | > { 322 | let [dst, src1_op, src2_op, src3_op] = ops else { 323 | todo!() 324 | }; 325 | let dst_reg = self.get_dst_reg(ty, dst)?; 326 | let src1_reg = self.get_src_reg(ty, src1_op)?; 327 | let src2_reg = self.get_src_reg(ty, src2_op)?; 328 | let src3_reg = self.get_src_reg(ty, src3_op)?; 329 | Ok((dst_reg, src1_reg, src2_reg, src3_reg)) 330 | } 331 | 332 | fn resolve_addr_operand( 333 | &mut self, 334 | operand: &ast::AddressOperand, 335 | ) -> Result { 336 | use ast::AddressOperand; 337 | Ok(match operand { 338 | AddressOperand::Address(ident) => { 339 | match self 340 | .var_map 341 | .get(ident) 342 | .cloned() 343 | .ok_or_else(|| CompilationError::UndefinedSymbol(ident.to_string()))? 344 | { 345 | Variable::Register(reg) => reg.into(), 346 | Variable::Absolute(addr) => self 347 | .construct_immediate(vm::Type::U64, ast::Immediate::UInt64(addr as u64))? 348 | .into(), 349 | Variable::Stack(addr) => { 350 | let dst = self.construct_immediate( 351 | vm::Type::S64, 352 | ast::Immediate::Int64(addr as i64), 353 | )?; 354 | self.instructions.push(vm::Instruction::Add( 355 | vm::Type::S64, 356 | dst, 357 | dst.into(), 358 | ast::SpecialReg::StackPtr.into(), 359 | )); 360 | dst.into() 361 | } 362 | } 363 | } 364 | AddressOperand::AddressOffset(ident, offset) => { 365 | match self 366 | .var_map 367 | .get(ident) 368 | .cloned() 369 | .ok_or_else(|| CompilationError::UndefinedSymbol(ident.to_string()))? 370 | { 371 | Variable::Register(reg) => { 372 | let dst = self 373 | .construct_immediate(vm::Type::S64, ast::Immediate::Int64(*offset))?; 374 | self.instructions.push(vm::Instruction::Add( 375 | vm::Type::S64, 376 | dst, 377 | dst.into(), 378 | reg.into(), 379 | )); 380 | dst.into() 381 | } 382 | Variable::Absolute(addr) => self 383 | .construct_immediate( 384 | vm::Type::U64, 385 | ast::Immediate::UInt64(addr as u64 + *offset as u64), 386 | )? 
387 | .into(), 388 | Variable::Stack(addr) => { 389 | let dst = self.construct_immediate( 390 | vm::Type::S64, 391 | ast::Immediate::Int64(addr as i64 + *offset), 392 | )?; 393 | self.instructions.push(vm::Instruction::Add( 394 | vm::Type::S64, 395 | dst, 396 | dst.into(), 397 | ast::SpecialReg::StackPtr.into(), 398 | )); 399 | dst.into() 400 | } 401 | } 402 | } 403 | AddressOperand::AddressOffsetVar(_, _) => todo!(), 404 | AddressOperand::ArrayIndex(_, _) => todo!(), 405 | }) 406 | } 407 | 408 | fn handle_instruction(&mut self, instr: ast::Instruction) -> Result<(), CompilationError> { 409 | use ast::Operand; 410 | use ast::Operation; 411 | 412 | if let Some(guard) = instr.guard { 413 | let (ident, expected) = match guard { 414 | ast::Guard::Normal(s) => (s, false), 415 | ast::Guard::Negated(s) => (s, true), 416 | }; 417 | let guard_reg = self.var_map.get_reg(&ident)?; 418 | self.instructions 419 | .push(vm::Instruction::SkipIf(guard_reg.into(), expected)); 420 | } 421 | 422 | match instr.specifier { 423 | Operation::Load(st, ty) => { 424 | let [dst, src] = get_ops(instr.operands)?; 425 | let Operand::Variable(ident) = dst else { 426 | return Err(CompilationError::InvalidOperand(dst)); 427 | }; 428 | let Operand::Address(addr_op) = src else { 429 | return Err(CompilationError::InvalidOperand(src)); 430 | }; 431 | let dst_reg = self.var_map.get_reg(&ident)?; 432 | let src_op = self.resolve_addr_operand(&addr_op)?; 433 | self.instructions.push(vm::Instruction::Load( 434 | ty.into(), 435 | resolve_state_space(st)?, 436 | dst_reg, 437 | src_op, 438 | )) 439 | } 440 | Operation::Store(st, ty) => { 441 | let [dst, src] = get_ops(instr.operands)?; 442 | let Operand::Address(addr_op) = dst else { 443 | return Err(CompilationError::InvalidOperand(dst)); 444 | }; 445 | let Operand::Variable(ident) = src else { 446 | return Err(CompilationError::InvalidOperand(src)); 447 | }; 448 | let src_reg = self.var_map.get_reg(&ident)?; 449 | let dst_op = self.resolve_addr_operand(&addr_op)?; 450 | self.instructions.push(vm::Instruction::Store( 451 | ty.into(), 452 | resolve_state_space(st)?, 453 | src_reg.into(), 454 | dst_op, 455 | )) 456 | } 457 | Operation::Move(ty) => { 458 | let ty = ty.into(); 459 | let [dst, src] = get_ops(instr.operands)?; 460 | let dst_reg = self.get_dst_reg(ty, &dst)?; 461 | let src_reg = 462 | match src { 463 | Operand::Variable(ident) => { 464 | match self.var_map.get(&ident).cloned().ok_or_else(|| { 465 | CompilationError::UndefinedSymbol(ident.to_string()) 466 | })? 
{ 467 | Variable::Register(reg) => reg.into(), 468 | // this is an LEA operation, not just a normal mov 469 | Variable::Stack(offset) => { 470 | let imm = self.construct_immediate( 471 | vm::Type::U64, 472 | ast::Immediate::UInt64(offset as u64), 473 | )?; 474 | self.instructions.push(vm::Instruction::Add( 475 | vm::Type::S64, 476 | dst_reg, 477 | imm.into(), 478 | ast::SpecialReg::StackPtr.into(), 479 | )); 480 | return Ok(()); 481 | } 482 | Variable::Absolute(addr) => { 483 | self.instructions.push(vm::Instruction::Const( 484 | dst_reg, 485 | vm::Constant::U64(addr as u64), 486 | )); 487 | return Ok(()); 488 | } 489 | } 490 | } 491 | Operand::Immediate(imm) => self.construct_immediate(ty, imm)?.into(), 492 | Operand::SpecialReg(special) => special.into(), 493 | op @ Operand::Address(_) => { 494 | return Err(CompilationError::InvalidOperand(op.clone())) 495 | } 496 | }; 497 | self.instructions 498 | .push(vm::Instruction::Move(ty, dst_reg, src_reg)); 499 | } 500 | Operation::Add(ty) => { 501 | let ty = ty.into(); 502 | let (dst_reg, lhs_reg, rhs_reg) = 503 | self.reg_dst_2src(ty, instr.operands.as_slice())?; 504 | self.instructions 505 | .push(vm::Instruction::Add(ty, dst_reg, lhs_reg, rhs_reg)); 506 | } 507 | Operation::Multiply(mode, ty) => { 508 | let ty = ty.into(); 509 | let (dst_reg, lhs_reg, rhs_reg) = 510 | self.reg_dst_2src(ty, instr.operands.as_slice())?; 511 | self.instructions 512 | .push(vm::Instruction::Mul(ty, mode, dst_reg, lhs_reg, rhs_reg)); 513 | } 514 | Operation::MultiplyAdd(mode, ty) => { 515 | let ty = ty.into(); 516 | let (dst, a, b, c) = self.reg_dst_3src(ty, &instr.operands)?; 517 | let tmp = self.alloc_reg(); 518 | self.instructions 519 | .push(vm::Instruction::Mul(ty, mode, tmp, a, b)); 520 | self.instructions 521 | .push(vm::Instruction::Add(ty, dst, tmp.into(), c)); 522 | } 523 | Operation::Sub(ty) => { 524 | let ty = ty.into(); 525 | let (dst, a, b) = self.reg_dst_2src(ty, &instr.operands)?; 526 | self.instructions.push(vm::Instruction::Sub(ty, dst, a, b)) 527 | } 528 | Operation::Or(ty) => { 529 | let ty = ty.into(); 530 | let (dst, a, b) = self.reg_dst_2src(ty, &instr.operands)?; 531 | self.instructions.push(vm::Instruction::Or(ty, dst, a, b)) 532 | } 533 | Operation::And(ty) => { 534 | let ty = ty.into(); 535 | let (dst, a, b) = self.reg_dst_2src(ty, &instr.operands)?; 536 | self.instructions.push(vm::Instruction::And(ty, dst, a, b)) 537 | } 538 | Operation::Not(ty) => { 539 | let ty = ty.into(); 540 | let (dst, src) = self.reg_dst_1src(ty, &instr.operands)?; 541 | self.instructions.push(vm::Instruction::Not(ty, dst, src)) 542 | } 543 | Operation::FusedMulAdd(_, ty) => { 544 | let ty = ty.into(); 545 | let (dst, a, b, c) = self.reg_dst_3src(ty, &instr.operands)?; 546 | let tmp = self.alloc_reg(); 547 | self.instructions 548 | .push(vm::Instruction::Mul(ty, ast::MulMode::Low, tmp, a, b)); 549 | self.instructions 550 | .push(vm::Instruction::Add(ty, dst, tmp.into(), c)); 551 | } 552 | Operation::Negate(ty) => { 553 | let ty = ty.into(); 554 | let (dst, src) = self.reg_dst_1src(ty, &instr.operands)?; 555 | self.instructions.push(vm::Instruction::Neg(ty, dst, src)); 556 | } 557 | Operation::Convert { from, to } => { 558 | let from = from.into(); 559 | let to = to.into(); 560 | let (dst, src) = self.reg_dst_1src(from, &instr.operands)?; 561 | self.instructions.push(vm::Instruction::Convert { 562 | dst_type: to, 563 | src_type: from, 564 | dst, 565 | src, 566 | }); 567 | } 568 | Operation::ConvertAddress(_ty, _st) => todo!(), 569 | 
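// NOTE (added): in PTX, `cvta.<space>` converts a space-relative address to a generic one and `cvta.to.<space>` converts a generic address into the given space; this backend currently lowers the `.to` form to a plain register move, as the TODO in the arm below explains.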
Operation::ConvertAddressTo(ty, _st) => { 570 | let ty = ty.into(); 571 | // TODO handle different state spaces 572 | // for now, just move the address register into the destination register 573 | let (dst, src) = self.reg_dst_1src(ty, &instr.operands)?; 574 | self.instructions.push(vm::Instruction::Move(ty, dst, src)); 575 | } 576 | Operation::SetPredicate(pred, ty) => { 577 | let ty = ty.into(); 578 | let (dst, a, b) = self.reg_dst_2src(ty, instr.operands.as_slice())?; 579 | self.instructions 580 | .push(vm::Instruction::SetPredicate(ty, pred, dst, a, b)); 581 | } 582 | Operation::ShiftLeft(ty) => { 583 | let ty = ty.into(); 584 | let (dst_reg, lhs_reg, rhs_reg) = 585 | self.reg_dst_2src(ty, instr.operands.as_slice())?; 586 | self.instructions 587 | .push(vm::Instruction::ShiftLeft(ty, dst_reg, lhs_reg, rhs_reg)); 588 | } 589 | Operation::Call { 590 | uniform: _, 591 | ident, 592 | ret_param, 593 | mut params, 594 | } => { 595 | let Some(vm::Symbol::Function(descriptor)) = self.parent.symbol_map.get(&ident) 596 | else { 597 | return Err(CompilationError::UndefinedSymbol(ident.clone())); 598 | }; 599 | 600 | if let Some(ret_param) = ret_param { 601 | params.push(ret_param); 602 | } 603 | for param in &params { 604 | match self 605 | .var_map 606 | .get(param) 607 | .cloned() 608 | .ok_or_else(|| CompilationError::UndefinedSymbol(param.to_string()))? 609 | { 610 | Variable::Register(reg) => { 611 | self.instructions.push(vm::Instruction::PushArg(reg.into())); 612 | } 613 | Variable::Stack(offset) => { 614 | let imm = self.construct_immediate( 615 | vm::Type::S64, 616 | ast::Immediate::Int64(offset as i64), 617 | )?; 618 | self.instructions.push(vm::Instruction::Add( 619 | vm::Type::S64, 620 | imm, 621 | imm.into(), 622 | ast::SpecialReg::StackPtr.into(), 623 | )); 624 | self.instructions.push(vm::Instruction::PushArg(imm.into())); 625 | } 626 | Variable::Absolute(addr) => { 627 | let imm = self.construct_immediate( 628 | vm::Type::U64, 629 | ast::Immediate::UInt64(addr as u64), 630 | )?; 631 | self.instructions.push(vm::Instruction::PushArg(imm.into())); 632 | } 633 | } 634 | } 635 | self.instructions.push(vm::Instruction::Call(*descriptor)); 636 | for param in &params { 637 | match self.var_map.get(param).cloned() { 638 | Some(Variable::Register(reg)) => { 639 | self.instructions.push(vm::Instruction::PopArg(Some(reg))); 640 | } 641 | Some(_) => { 642 | self.instructions.push(vm::Instruction::PopArg(None)); 643 | } 644 | None => return Err(CompilationError::UndefinedSymbol(param.clone())), 645 | } 646 | } 647 | } 648 | Operation::BarrierSync => match instr.operands.as_slice() { 649 | [idx] => { 650 | let src_reg = self.get_src_reg(vm::Type::U32, idx)?; 651 | self.instructions.push(vm::Instruction::BarrierSync { 652 | idx: src_reg, 653 | cnt: None, 654 | }) 655 | } 656 | [_idx, _cnt] => { 657 | todo!() 658 | } 659 | _ => todo!(), 660 | }, 661 | Operation::Branch => { 662 | let [Operand::Variable(ident)] = instr.operands.as_slice() else { 663 | todo!() 664 | }; 665 | let jump_idx = self.jump_map.len(); 666 | self.jump_map.push(ident.clone()); 667 | self.instructions.push(vm::Instruction::Jump { 668 | offset: jump_idx as isize, 669 | }); 670 | } 671 | Operation::Return => self.instructions.push(vm::Instruction::Return), 672 | }; 673 | Ok(()) 674 | } 675 | 676 | fn handle_basic_block(&mut self, block: BasicBlock) -> Result<(), CompilationError> { 677 | if let Some(label) = block.label { 678 | self.label_map.insert(label, self.instructions.len()); 679 | } 680 | for instr in block.instructions { 681 |
self.handle_instruction(instr)?; 682 | } 683 | Ok(()) 684 | } 685 | 686 | pub fn compile_ast(&mut self, func: ast::Function) -> Result<(), CompilationError> { 687 | self.ident = func.ident; 688 | let ast::Statement::Grouping(mut body) = *func.body else { 689 | todo!() 690 | }; 691 | 692 | let mut block = BasicBlock { 693 | label: None, 694 | instructions: Vec::new(), 695 | }; 696 | let mut bblocks = Vec::new(); 697 | let mut var_decls = Vec::new(); 698 | 699 | body.reverse(); 700 | while let Some(statement) = body.pop() { 701 | use ast::{Directive, Statement}; 702 | 703 | match statement { 704 | Statement::Directive(Directive::VarDecl(v)) => var_decls.push(v), 705 | Statement::Instruction(i) => block.instructions.push(i), 706 | Statement::Label(ident) => { 707 | let mut block2 = BasicBlock { 708 | label: Some(ident), 709 | instructions: Vec::new(), 710 | }; 711 | std::mem::swap(&mut block, &mut block2); 712 | bblocks.push(block2); 713 | } 714 | // TODO: groupings usually interact with the scope of variables, 715 | // which is completely ignored here. Should be fixed in the future 716 | Statement::Grouping(mut inner) => { 717 | inner.reverse(); 718 | body.extend(inner); 719 | } 720 | // ignore miscelaneous directives 721 | Statement::Directive(_) => {} 722 | } 723 | } 724 | 725 | bblocks.push(block); 726 | 727 | self.handle_params(func.return_param, func.params)?; 728 | self.handle_vars(var_decls)?; 729 | for block in bblocks { 730 | self.handle_basic_block(block)?; 731 | } 732 | 733 | Ok(()) 734 | } 735 | 736 | pub fn finalize( 737 | mut self, 738 | ) -> Result<(String, vm::FuncFrameDesc, Vec), CompilationError> { 739 | // resolve jump targets 740 | for (idx, instr) in self.instructions.iter_mut().enumerate() { 741 | if let vm::Instruction::Jump { offset } = instr { 742 | let jump_map_idx = *offset as usize; 743 | let target_label = &self.jump_map[jump_map_idx]; 744 | if let Some(target_label) = self.label_map.get(target_label) { 745 | *offset = *target_label as isize - idx as isize; 746 | } else { 747 | return Err(CompilationError::UndefinedLabel(target_label.clone())); 748 | } 749 | } 750 | } 751 | 752 | // align stack size to 16 bytes 753 | self.stack_size = (self.stack_size + 15) & !15; 754 | let frame_desc = vm::FuncFrameDesc { 755 | iptr: vm::IPtr(self.instructions.len()), 756 | frame_size: self.stack_size, 757 | shared_size: self.shared_size, 758 | num_regs: self.num_regs, 759 | num_args: self.num_args, 760 | }; 761 | Ok((self.ident, frame_desc, self.instructions)) 762 | } 763 | } 764 | 765 | pub fn compile(module: ast::Module) -> Result { 766 | let mut cmod = CompiledModule { 767 | instructions: Vec::new(), 768 | func_descriptors: Vec::new(), 769 | // global_vars: Vec::new(), 770 | symbol_map: HashMap::new(), 771 | }; 772 | for directive in module.0 { 773 | cmod.compile_directive_toplevel(directive)?; 774 | } 775 | Ok(cmod) 776 | } 777 | 778 | #[cfg(test)] 779 | mod test { 780 | use super::*; 781 | 782 | #[test] 783 | fn compile_add_simple() { 784 | let contents = std::fs::read_to_string("kernels/add_simple.ptx").unwrap(); 785 | let module = crate::ast::parse_program(&contents).unwrap(); 786 | let _ = compile(module).unwrap(); 787 | } 788 | 789 | #[test] 790 | fn compile_add() { 791 | let contents = std::fs::read_to_string("kernels/add.ptx").unwrap(); 792 | let module = crate::ast::parse_program(&contents).unwrap(); 793 | let _ = compile(module).unwrap(); 794 | } 795 | 796 | #[test] 797 | fn compile_transpose() { 798 | let contents = 
std::fs::read_to_string("kernels/transpose.ptx").unwrap(); 799 | let module = crate::ast::parse_program(&contents).unwrap(); 800 | let _ = compile(module).unwrap(); 801 | } 802 | 803 | #[test] 804 | fn compile_gemm() { 805 | let contents = std::fs::read_to_string("kernels/gemm.ptx").unwrap(); 806 | let module = crate::ast::parse_program(&contents).unwrap(); 807 | let _ = compile(module).unwrap(); 808 | } 809 | 810 | #[test] 811 | fn compile_fncall() { 812 | let contents = std::fs::read_to_string("kernels/fncall.ptx").unwrap(); 813 | let module = crate::ast::parse_program(&contents).unwrap(); 814 | let _ = compile(module).unwrap(); 815 | } 816 | } 817 | -------------------------------------------------------------------------------- /src/ast/mod.rs: -------------------------------------------------------------------------------- 1 | mod lex; 2 | 3 | use std::ops::Range; 4 | 5 | use lex::LexError; 6 | use lex::Token; 7 | use thiserror::Error; 8 | 9 | #[derive(Debug, Clone, Copy)] 10 | pub struct SourceLocation { 11 | byte: usize, 12 | line: usize, 13 | col: usize, 14 | } 15 | 16 | impl std::fmt::Display for SourceLocation { 17 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 18 | write!(f, "@{}:{} (byte {})", self.line, self.col, self.byte) 19 | } 20 | } 21 | 22 | #[derive(Error, Debug)] 23 | pub enum ParseErr { 24 | #[error("Unexpected token \"{:?}\"", .0)] 25 | UnexpectedToken(String, SourceLocation), 26 | #[error("Unexpected end of file")] 27 | UnexpectedEof, 28 | #[error("Lex error \"{:?}\" at {:?}", .0, .1)] 29 | LexError(LexError, SourceLocation), 30 | #[error("Unknown token \"{:?}\" at {:?}", .0, .1)] 31 | UnknownToken(String, SourceLocation), 32 | } 33 | 34 | type ParseResult = Result; 35 | 36 | type Ident = String; 37 | 38 | #[derive(Debug, Clone, Copy, PartialEq)] 39 | pub struct Version { 40 | major: u32, 41 | minor: u32, 42 | } 43 | 44 | #[derive(Clone, Debug)] 45 | pub struct Pragma(String); 46 | 47 | #[derive(Debug)] 48 | pub enum AddressSize { 49 | Adr32, 50 | Adr64, 51 | Other, 52 | } 53 | 54 | #[derive(Debug)] 55 | pub struct Module(pub Vec); 56 | 57 | #[derive(Debug)] 58 | pub struct Function { 59 | pub ident: Ident, 60 | pub visible: bool, 61 | pub entry: bool, 62 | pub noreturn: bool, 63 | pub return_param: Option, 64 | pub params: Vec, 65 | pub body: Box, 66 | } 67 | 68 | #[derive(Debug)] 69 | pub struct FunctionParam { 70 | pub ident: Ident, 71 | pub ty: Type, 72 | pub alignment: Option, 73 | pub array_bounds: Vec, 74 | } 75 | 76 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 77 | pub enum StateSpace { 78 | Global, 79 | Local, 80 | Shared, 81 | Register, 82 | Constant, 83 | Parameter, 84 | } 85 | 86 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 87 | pub enum Type { 88 | B128, 89 | B64, 90 | B32, 91 | B16, 92 | B8, 93 | U64, 94 | U32, 95 | U16, 96 | U8, 97 | S64, 98 | S32, 99 | S16, 100 | S8, 101 | F64, 102 | F32, 103 | F16x2, 104 | F16, 105 | Pred, 106 | } 107 | 108 | #[derive(Debug, Clone, Copy)] 109 | pub enum Vector { 110 | V2, 111 | V4, 112 | } 113 | 114 | #[derive(Debug, Clone, Copy)] 115 | pub enum SpecialReg { 116 | StackPtr, 117 | ThreadId, 118 | ThreadIdX, 119 | ThreadIdY, 120 | ThreadIdZ, 121 | NumThread, 122 | NumThreadX, 123 | NumThreadY, 124 | NumThreadZ, 125 | CtaId, 126 | CtaIdX, 127 | CtaIdY, 128 | CtaIdZ, 129 | NumCta, 130 | NumCtaX, 131 | NumCtaY, 132 | NumCtaZ, 133 | } 134 | 135 | impl From for Operand { 136 | fn from(value: SpecialReg) -> Self { 137 | Operand::SpecialReg(value) 138 | } 139 | } 140 | 141 | #[derive(Debug, 
Clone)] 142 | pub struct VarDecl { 143 | pub state_space: StateSpace, 144 | pub ty: Type, 145 | pub vector: Option<Vector>, 146 | pub ident: Ident, 147 | pub alignment: Option<u32>, 148 | pub array_bounds: Vec<u32>, 149 | pub multiplicity: Option<u32>, 150 | } 151 | 152 | #[derive(Debug, Clone)] 153 | pub enum AddressOperand { 154 | Address(Ident), 155 | AddressOffset(Ident, i64), 156 | AddressOffsetVar(Ident, Ident), 157 | ArrayIndex(Ident, usize), 158 | } 159 | 160 | impl AddressOperand { 161 | pub fn get_ident(&self) -> &Ident { 162 | match self { 163 | AddressOperand::Address(ident) => ident, 164 | AddressOperand::AddressOffset(ident, _) => ident, 165 | AddressOperand::AddressOffsetVar(ident, _) => ident, 166 | AddressOperand::ArrayIndex(ident, _) => ident, 167 | } 168 | } 169 | } 170 | 171 | #[derive(Debug, Clone)] 172 | pub enum Operand { 173 | SpecialReg(SpecialReg), 174 | Variable(Ident), 175 | Immediate(Immediate), 176 | Address(AddressOperand), 177 | } 178 | 179 | #[derive(Debug, Clone, Copy)] 180 | pub enum Immediate { 181 | Float32(f32), 182 | Float64(f64), 183 | Int64(i64), 184 | UInt64(u64), 185 | } 186 | 187 | #[derive(Debug, Clone)] 188 | pub enum Guard { 189 | Normal(Ident), 190 | Negated(Ident), 191 | } 192 | 193 | #[derive(Debug)] 194 | pub enum Directive { 195 | VarDecl(VarDecl), 196 | Version(Version), 197 | Target(String), 198 | AddressSize(AddressSize), 199 | Function(Function), 200 | Pragma(Pragma), 201 | } 202 | 203 | #[derive(Debug, Clone)] 204 | pub struct Instruction { 205 | pub guard: Option<Guard>, 206 | pub specifier: Operation, 207 | pub operands: Vec<Operand>, 208 | } 209 | 210 | #[derive(Debug)] 211 | pub enum Statement { 212 | Directive(Directive), 213 | Instruction(Instruction), 214 | Grouping(Vec<Statement>), 215 | Label(Ident), 216 | } 217 | 218 | #[derive(Debug, Clone, Copy)] 219 | pub enum PredicateOp { 220 | LessThan, 221 | LessThanEqual, 222 | GreaterThan, 223 | GreaterThanEqual, 224 | Equal, 225 | NotEqual, 226 | } 227 | 228 | #[derive(Debug, Clone, Copy)] 229 | pub enum MulMode { 230 | Low, 231 | High, 232 | Wide, 233 | } 234 | 235 | #[derive(Debug, Clone, Copy)] 236 | pub enum RoundingMode { 237 | NearestEvent, 238 | Zero, 239 | NegInf, 240 | PosInf, 241 | } 242 | 243 | #[derive(Debug, Clone)] 244 | pub enum Operation { 245 | Load(StateSpace, Type), 246 | Store(StateSpace, Type), 247 | Move(Type), 248 | Add(Type), 249 | Sub(Type), 250 | Or(Type), 251 | And(Type), 252 | Not(Type), 253 | FusedMulAdd(RoundingMode, Type), 254 | Negate(Type), 255 | Multiply(MulMode, Type), 256 | MultiplyAdd(MulMode, Type), 257 | Convert { 258 | from: Type, 259 | to: Type, 260 | }, 261 | ConvertAddress(Type, StateSpace), 262 | ConvertAddressTo(Type, StateSpace), 263 | SetPredicate(PredicateOp, Type), 264 | ShiftLeft(Type), 265 | Call { 266 | uniform: bool, 267 | ident: Ident, 268 | ret_param: Option<Ident>, 269 | params: Vec<Ident>, 270 | }, 271 | BarrierSync, 272 | Branch, 273 | Return, 274 | } 275 | 276 | type TokenPos<'a> = Range<usize>; 277 | 278 | struct Parser<'a> { 279 | src: &'a str, 280 | inner: std::iter::Peekable<logos::SpannedIter<'a, Token<'a>>>, 281 | } 282 | 283 | impl<'a> Parser<'a> { 284 | pub fn new(src: &'a str) -> Self { 285 | use logos::Logos; 286 | Self { 287 | src, 288 | inner: Token::lexer(src).spanned().peekable(), 289 | } 290 | } 291 | 292 | fn locate(&self, span: Range<usize>) -> SourceLocation { 293 | let text = self.src.as_bytes(); 294 | 295 | let mut line = 1; 296 | let mut col = 0; 297 | 298 | let end = span.end.min(text.len()); 299 | 300 | for &c in &text[..end] { 301 | match c { 302 | b'\n' => { 303 | line += 1; 304 | col = 0; 305 | }, 306 |
b'\t' => { 307 | col = (col / 4) * 4 + 4; 308 | } 309 | _ => col += 1, 310 | } 311 | } 312 | 313 | SourceLocation { 314 | byte: span.start, 315 | line, 316 | col, 317 | } 318 | } 319 | 320 | fn unexpected(&self, (token, pos): (Token, TokenPos)) -> ParseErr { 321 | ParseErr::UnexpectedToken(token.to_string(), self.locate(pos)) 322 | } 323 | 324 | fn get(&mut self) -> ParseResult, TokenPos)>> { 325 | match self.inner.peek().cloned() { 326 | Some((Ok(tok), pos)) => Ok(Some((tok, pos))), 327 | Some((Err(LexError::Unknown), pos)) => Err(ParseErr::UnknownToken( 328 | self.src[pos.clone()].to_string(), 329 | self.locate(pos), 330 | )), 331 | Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))), 332 | None => Ok(None), 333 | } 334 | } 335 | 336 | fn must_get(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> { 337 | self.get()?.ok_or(ParseErr::UnexpectedEof) 338 | } 339 | 340 | fn skip(&mut self) { 341 | self.inner.next(); 342 | } 343 | 344 | fn consume(&mut self, token: Token) -> ParseResult<()> { 345 | let head = self.must_get()?; 346 | if head.0 == token { 347 | self.skip(); 348 | Ok(()) 349 | } else { 350 | Err(self.unexpected(head)) 351 | } 352 | } 353 | 354 | fn consume_match(&mut self, token: Token) -> ParseResult { 355 | let Some(head) = self.get()? else { 356 | return Ok(false); 357 | }; 358 | if head.0 == token { 359 | self.skip(); 360 | Ok(true) 361 | } else { 362 | Ok(false) 363 | } 364 | } 365 | 366 | fn pop(&mut self) -> ParseResult, TokenPos)>> { 367 | match self.inner.next() { 368 | Some((Ok(tok), pos)) => Ok(Some((tok, pos))), 369 | Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))), 370 | None => Ok(None), 371 | } 372 | } 373 | 374 | fn must_pop(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> { 375 | self.pop()?.ok_or(ParseErr::UnexpectedEof) 376 | } 377 | 378 | fn parse_pragma(&mut self) -> ParseResult { 379 | self.consume(Token::Pragma)?; 380 | let t = self.must_pop()?; 381 | match t.0 { 382 | Token::StringLiteral(s) => { 383 | self.consume(Token::Semicolon)?; 384 | Ok(Pragma(s.to_string())) 385 | } 386 | _ => Err(self.unexpected(t)), 387 | } 388 | } 389 | 390 | fn parse_version(&mut self) -> ParseResult { 391 | let t = self.must_pop()?; 392 | match t.0 { 393 | Token::Version((major, minor)) => Ok(Version { major, minor }), 394 | _ => Err(self.unexpected(t)), 395 | } 396 | } 397 | 398 | fn parse_target(&mut self) -> ParseResult { 399 | self.consume(Token::Target)?; 400 | let t = self.must_pop()?; 401 | match t.0 { 402 | Token::Identifier(target) => Ok(target.to_string()), 403 | _ => Err(self.unexpected(t)), 404 | } 405 | } 406 | 407 | fn parse_address_size(&mut self) -> ParseResult { 408 | self.consume(Token::AddressSize)?; 409 | let t = self.must_pop()?; 410 | let Token::IntegerConst(size) = t.0 else { 411 | return Err(self.unexpected(t)); 412 | }; 413 | match size { 414 | 32 => Ok(AddressSize::Adr32), 415 | 64 => Ok(AddressSize::Adr64), 416 | _ => Ok(AddressSize::Other), 417 | } 418 | } 419 | 420 | fn parse_module(&mut self) -> ParseResult { 421 | let mut directives = Vec::new(); 422 | while self.get()?.is_some() { 423 | match self.parse_directive() { 424 | Ok(directive) => { 425 | directives.push(directive); 426 | } 427 | Err(e) => return Err(e), 428 | } 429 | } 430 | Ok(Module(directives)) 431 | } 432 | 433 | fn parse_array_bounds(&mut self) -> ParseResult> { 434 | let mut bounds = Vec::new(); 435 | loop { 436 | match self.get()? 
{ 437 | Some((Token::LeftBracket, _)) => self.skip(), 438 | _ => break Ok(bounds), 439 | } 440 | let t = self.must_pop()?; 441 | let Token::IntegerConst(bound) = t.0 else { 442 | return Err(self.unexpected(t)); 443 | }; 444 | self.consume(Token::RightBracket)?; 445 | // todo clean up raw casts 446 | bounds.push(bound as u32); 447 | } 448 | } 449 | 450 | fn parse_state_space(&mut self) -> ParseResult { 451 | let t = self.must_pop()?; 452 | match t.0 { 453 | Token::Global => Ok(StateSpace::Global), 454 | Token::Local => Ok(StateSpace::Local), 455 | Token::Shared => Ok(StateSpace::Shared), 456 | Token::Reg => Ok(StateSpace::Register), 457 | Token::Param => Ok(StateSpace::Parameter), 458 | Token::Const => Ok(StateSpace::Constant), 459 | _ => Err(self.unexpected(t)), 460 | } 461 | } 462 | 463 | fn parse_alignment(&mut self) -> ParseResult { 464 | self.consume(Token::Align)?; 465 | let t = self.must_pop()?; 466 | let alignment = match t.0 { 467 | Token::IntegerConst(i) => i as u32, 468 | _ => return Err(self.unexpected(t)), 469 | }; 470 | Ok(alignment) 471 | } 472 | 473 | fn parse_type(&mut self) -> ParseResult { 474 | let t = self.must_pop()?; 475 | let ty = match t.0 { 476 | Token::Bit8 => Type::B8, 477 | Token::Bit16 => Type::B16, 478 | Token::Bit32 => Type::B32, 479 | Token::Bit64 => Type::B64, 480 | Token::Bit128 => Type::B128, 481 | Token::Unsigned8 => Type::U8, 482 | Token::Unsigned16 => Type::U16, 483 | Token::Unsigned32 => Type::U32, 484 | Token::Unsigned64 => Type::U64, 485 | Token::Signed8 => Type::S8, 486 | Token::Signed16 => Type::S16, 487 | Token::Signed32 => Type::S32, 488 | Token::Signed64 => Type::S64, 489 | Token::Float16 => Type::F16, 490 | Token::Float16x2 => Type::F16x2, 491 | Token::Float32 => Type::F32, 492 | Token::Float64 => Type::F64, 493 | Token::Predicate => Type::Pred, 494 | _ => return Err(self.unexpected(t)), 495 | }; 496 | Ok(ty) 497 | } 498 | 499 | fn parse_rounding_mode(&mut self) -> ParseResult { 500 | let t = self.must_pop()?; 501 | let mode = match t.0 { 502 | Token::Rn => RoundingMode::NearestEvent, 503 | Token::Rz => RoundingMode::Zero, 504 | Token::Rm => RoundingMode::NegInf, 505 | Token::Rp => RoundingMode::PosInf, 506 | _ => return Err(self.unexpected(t)), 507 | }; 508 | Ok(mode) 509 | } 510 | 511 | fn parse_mul_mode(&mut self) -> ParseResult { 512 | let t = self.must_pop()?; 513 | let mode = match t.0 { 514 | Token::Low => MulMode::Low, 515 | Token::High => MulMode::High, 516 | Token::Wide => MulMode::Wide, 517 | _ => return Err(self.unexpected(t)), 518 | }; 519 | Ok(mode) 520 | } 521 | 522 | fn parse_variable(&mut self) -> ParseResult { 523 | let state_space = self.parse_state_space()?; 524 | 525 | let t = self.get()?; 526 | let alignment = if let Some((Token::Align, _)) = t { 527 | Some(self.parse_alignment()?) 
528 | } else { 529 | None 530 | }; 531 | 532 | let t = self.get()?; 533 | let vector = match t { 534 | Some((Token::V2, _)) => { 535 | self.skip(); 536 | Some(Vector::V2) 537 | } 538 | Some((Token::V4, _)) => { 539 | self.skip(); 540 | Some(Vector::V4) 541 | } 542 | _ => None, 543 | }; 544 | 545 | let ty = self.parse_type()?; 546 | 547 | let t = self.must_pop()?; 548 | let ident = match t.0 { 549 | Token::Identifier(s) => s.to_string(), 550 | _ => return Err(self.unexpected(t)), 551 | }; 552 | 553 | let t = self.must_get()?; 554 | let multiplicity = match t.0 { 555 | Token::RegMultiplicity(m) => { 556 | self.skip(); 557 | Some(m) 558 | } 559 | _ => None, 560 | }; 561 | 562 | let array_bounds = self.parse_array_bounds()?; 563 | 564 | self.consume(Token::Semicolon)?; 565 | 566 | Ok(VarDecl { 567 | state_space, 568 | ty, 569 | vector, 570 | alignment, 571 | array_bounds, 572 | ident: ident.to_string(), 573 | multiplicity, 574 | }) 575 | } 576 | 577 | fn parse_guard(&mut self) -> ParseResult { 578 | self.consume(Token::At)?; 579 | let t = self.must_pop()?; 580 | let guard = match t.0 { 581 | Token::Identifier(s) => Guard::Normal(s.to_string()), 582 | Token::Bang => { 583 | let t = self.must_pop()?; 584 | let ident = match t.0 { 585 | Token::Identifier(s) => s, 586 | _ => return Err(self.unexpected(t)), 587 | }; 588 | Guard::Negated(ident.to_string()) 589 | } 590 | _ => return Err(self.unexpected(t)), 591 | }; 592 | Ok(guard) 593 | } 594 | 595 | fn parse_predicate(&mut self) -> ParseResult { 596 | let t = self.must_pop()?; 597 | let pred = match t.0 { 598 | Token::Ge => PredicateOp::GreaterThanEqual, 599 | Token::Gt => PredicateOp::GreaterThan, 600 | Token::Le => PredicateOp::LessThanEqual, 601 | Token::Lt => PredicateOp::LessThan, 602 | Token::Eq => PredicateOp::Equal, 603 | Token::Ne => PredicateOp::NotEqual, 604 | _ => return Err(self.unexpected(t)), 605 | }; 606 | Ok(pred) 607 | } 608 | 609 | fn parse_operation(&mut self) -> ParseResult { 610 | let t = self.must_pop()?; 611 | match t.0 { 612 | Token::Ld => { 613 | let state_space = self.parse_state_space()?; 614 | let ty = self.parse_type()?; 615 | Ok(Operation::Load(state_space, ty)) 616 | } 617 | Token::St => { 618 | let state_space = self.parse_state_space()?; 619 | let ty = self.parse_type()?; 620 | Ok(Operation::Store(state_space, ty)) 621 | } 622 | Token::Mov => { 623 | let ty = self.parse_type()?; 624 | Ok(Operation::Move(ty)) 625 | } 626 | Token::Add => { 627 | let ty = self.parse_type()?; 628 | Ok(Operation::Add(ty)) 629 | } 630 | Token::Sub => { 631 | let ty = self.parse_type()?; 632 | Ok(Operation::Sub(ty)) 633 | } 634 | Token::Or => { 635 | let ty = self.parse_type()?; 636 | Ok(Operation::Or(ty)) 637 | } 638 | Token::And => { 639 | let ty = self.parse_type()?; 640 | Ok(Operation::And(ty)) 641 | } 642 | Token::Not => { 643 | let ty = self.parse_type()?; 644 | Ok(Operation::Not(ty)) 645 | } 646 | Token::Mul => { 647 | let mode = self.parse_mul_mode()?; 648 | let ty = self.parse_type()?; 649 | Ok(Operation::Multiply(mode, ty)) 650 | } 651 | Token::Mad => { 652 | let mode = self.parse_mul_mode()?; 653 | let ty = self.parse_type()?; 654 | Ok(Operation::MultiplyAdd(mode, ty)) 655 | } 656 | Token::Fma => { 657 | let mode = self.parse_rounding_mode()?; 658 | let ty = self.parse_type()?; 659 | Ok(Operation::FusedMulAdd(mode, ty)) 660 | } 661 | Token::Neg => { 662 | let ty = self.parse_type()?; 663 | Ok(Operation::Negate(ty)) 664 | } 665 | Token::Cvt => { 666 | let to = self.parse_type()?; 667 | let from = self.parse_type()?; 668 | 
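// NOTE (added): PTX writes the conversion as `cvt.dtype.atype d, a;`, so the destination type is parsed before the source type above.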
Ok(Operation::Convert { to, from }) 669 | } 670 | Token::Call => { 671 | let uniform = self.consume_match(Token::Uniform)?; 672 | let ret_param = if let Token::LeftParen = self.must_get()?.0 { 673 | self.skip(); 674 | let t = self.must_pop()?; 675 | let ident = match t.0 { 676 | Token::Identifier(s) => s.to_string(), 677 | _ => return Err(self.unexpected(t)), 678 | }; 679 | self.consume(Token::RightParen)?; 680 | self.consume(Token::Comma)?; 681 | Some(ident) 682 | } else { 683 | None 684 | }; 685 | let t = self.must_pop()?; 686 | let ident = match t.0 { 687 | Token::Identifier(s) => s.to_string(), 688 | _ => return Err(self.unexpected(t)), 689 | }; 690 | self.consume(Token::Comma)?; 691 | let mut params = Vec::new(); 692 | if let Token::LeftParen = self.must_get()?.0 { 693 | self.skip(); 694 | loop { 695 | let t = self.must_pop()?; 696 | let ident = match t.0 { 697 | Token::Identifier(s) => s.to_string(), 698 | _ => return Err(self.unexpected(t)), 699 | }; 700 | params.push(ident); 701 | let t = self.must_pop()?; 702 | match t.0 { 703 | Token::RightParen => break, 704 | Token::Comma => {} 705 | _ => return Err(self.unexpected(t)), 706 | } 707 | } 708 | }; 709 | 710 | Ok(Operation::Call { 711 | uniform, 712 | ident: ident.to_string(), 713 | ret_param, 714 | params, 715 | }) 716 | } 717 | Token::Cvta => match self.must_get()?.0 { 718 | Token::To => { 719 | self.skip(); 720 | let state_space = self.parse_state_space()?; 721 | let ty = self.parse_type()?; 722 | Ok(Operation::ConvertAddressTo(ty, state_space)) 723 | } 724 | _ => { 725 | let state_space = self.parse_state_space()?; 726 | let ty = self.parse_type()?; 727 | Ok(Operation::ConvertAddress(ty, state_space)) 728 | } 729 | }, 730 | Token::Setp => { 731 | let pred = self.parse_predicate()?; 732 | let ty = self.parse_type()?; 733 | Ok(Operation::SetPredicate(pred, ty)) 734 | } 735 | Token::Shl => { 736 | let ty = self.parse_type()?; 737 | Ok(Operation::ShiftLeft(ty)) 738 | } 739 | Token::Bra => { 740 | self.consume_match(Token::Uniform)?; 741 | Ok(Operation::Branch) 742 | } 743 | Token::Ret => Ok(Operation::Return), 744 | Token::Bar => { 745 | // cta token is meaningless 746 | self.consume_match(Token::Cta)?; 747 | self.consume(Token::Sync)?; 748 | Ok(Operation::BarrierSync) 749 | } 750 | _ => Err(self.unexpected(t)), 751 | } 752 | } 753 | 754 | fn parse_operand(&mut self) -> ParseResult { 755 | let t = self.must_pop()?; 756 | let operand = match t.0 { 757 | Token::ThreadId => SpecialReg::ThreadId.into(), 758 | Token::ThreadIdX => SpecialReg::ThreadIdX.into(), 759 | Token::ThreadIdY => SpecialReg::ThreadIdY.into(), 760 | Token::ThreadIdZ => SpecialReg::ThreadIdZ.into(), 761 | Token::NumThreads => SpecialReg::NumThread.into(), 762 | Token::NumThreadsX => SpecialReg::NumThreadX.into(), 763 | Token::NumThreadsY => SpecialReg::NumThreadY.into(), 764 | Token::NumThreadsZ => SpecialReg::NumThreadZ.into(), 765 | Token::CtaId => SpecialReg::CtaId.into(), 766 | Token::CtaIdX => SpecialReg::CtaIdX.into(), 767 | Token::CtaIdY => SpecialReg::CtaIdY.into(), 768 | Token::CtaIdZ => SpecialReg::CtaIdZ.into(), 769 | Token::IntegerConst(i) => Operand::Immediate(Immediate::Int64(i)), 770 | Token::Float64Const(f) => Operand::Immediate(Immediate::Float64(f)), 771 | Token::Float32Const(f) => Operand::Immediate(Immediate::Float32(f)), 772 | Token::Identifier(s) => { 773 | let t = self.get()?; 774 | if let Some((Token::LeftBracket, _)) = t { 775 | todo!("array syntax in operands") 776 | } else { 777 | Operand::Variable(s.to_string()) 778 | } 779 | } 780 | 
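// NOTE (added): a leading `[` introduces an address operand; the arm below accepts `[sym]`, `[sym+imm]`, and `[sym+var]` and maps them onto the AddressOperand variants.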
Token::LeftBracket => { 781 | let t = self.must_pop()?; 782 | let Token::Identifier(s) = t.0 else { 783 | return Err(self.unexpected(t)); 784 | }; 785 | let ident = s.to_string(); 786 | 787 | let t = self.must_get()?; 788 | let res = if let Token::Plus = t.0 { 789 | self.skip(); 790 | let t = self.must_pop()?; 791 | match t.0 { 792 | Token::IntegerConst(i) => { 793 | Operand::Address(AddressOperand::AddressOffset(ident, i)) 794 | } 795 | Token::Identifier(s) => { 796 | Operand::Address(AddressOperand::AddressOffsetVar(ident, s.to_string())) 797 | } 798 | _ => return Err(self.unexpected(t)), 799 | } 800 | } else { 801 | Operand::Address(AddressOperand::Address(ident)) 802 | }; 803 | self.consume(Token::RightBracket)?; 804 | res 805 | } 806 | _ => return Err(self.unexpected(t)), 807 | }; 808 | Ok(operand) 809 | } 810 | 811 | fn parse_operands(&mut self) -> ParseResult> { 812 | let mut operands = Vec::new(); 813 | loop { 814 | let t = self.must_get()?; 815 | match t.0 { 816 | Token::Semicolon => { 817 | self.skip(); 818 | break Ok(operands); 819 | } 820 | Token::Comma => self.skip(), 821 | _ => {} 822 | } 823 | let op = self.parse_operand()?; 824 | operands.push(op); 825 | } 826 | } 827 | 828 | fn parse_grouping(&mut self) -> ParseResult> { 829 | self.consume(Token::LeftBrace)?; // Consume the left brace 830 | let mut statements = Vec::new(); 831 | loop { 832 | let t = self.must_get()?; 833 | if let Token::RightBrace = t.0 { 834 | self.skip(); 835 | break Ok(statements); 836 | } 837 | statements.push(self.parse_statement()?); 838 | } 839 | } 840 | 841 | fn parse_directive(&mut self) -> ParseResult { 842 | let t = self.must_get()?; 843 | let res = match t.0 { 844 | Token::Version(_) => { 845 | let version = self.parse_version()?; 846 | Directive::Version(version) 847 | } 848 | Token::Target => { 849 | let target = self.parse_target()?; 850 | Directive::Target(target) 851 | } 852 | Token::AddressSize => { 853 | let addr_size = self.parse_address_size()?; 854 | Directive::AddressSize(addr_size) 855 | } 856 | Token::Func | Token::Visible | Token::Entry => { 857 | let function = self.parse_function()?; 858 | Directive::Function(function) 859 | } 860 | Token::Pragma => { 861 | let pragma = self.parse_pragma()?; 862 | Directive::Pragma(pragma) 863 | } 864 | _ => { 865 | let var = self.parse_variable()?; 866 | Directive::VarDecl(var) 867 | } 868 | }; 869 | Ok(res) 870 | } 871 | 872 | fn parse_instruction(&mut self) -> ParseResult { 873 | let t = self.must_get()?; 874 | let guard = if let Token::At = t.0 { 875 | Some(self.parse_guard()?) 
876 | } else { 877 | None 878 | }; 879 | 880 | let specifier = self.parse_operation()?; 881 | let operands = self.parse_operands()?; 882 | 883 | Ok(Instruction { 884 | guard, 885 | specifier, 886 | operands, 887 | }) 888 | } 889 | 890 | fn parse_statement(&mut self) -> ParseResult { 891 | let t = self.must_get()?; 892 | match t.0 { 893 | Token::LeftBrace => { 894 | let grouping = self.parse_grouping()?; 895 | Ok(Statement::Grouping(grouping)) 896 | } 897 | t if t.is_directive() => { 898 | let dir = self.parse_directive()?; 899 | Ok(Statement::Directive(dir)) 900 | } 901 | Token::Identifier(i) => { 902 | let i = i.to_string(); 903 | self.skip(); 904 | self.consume(Token::Colon)?; 905 | Ok(Statement::Label(i.to_string())) 906 | } 907 | _ => { 908 | let instr = self.parse_instruction()?; 909 | Ok(Statement::Instruction(instr)) 910 | } 911 | } 912 | } 913 | 914 | fn parse_function_param(&mut self) -> ParseResult { 915 | self.consume(Token::Param)?; // Consume the param keyword 916 | 917 | let alignment = None; // todo parse alignment in function param 918 | 919 | let ty = self.parse_type()?; 920 | let ident = loop { 921 | let t = self.must_pop()?; 922 | if let Token::Identifier(s) = t.0 { 923 | break s.to_string(); 924 | } 925 | }; 926 | 927 | let array_bounds = self.parse_array_bounds()?; 928 | 929 | Ok(FunctionParam { 930 | alignment, 931 | ident: ident.to_string(), 932 | ty, 933 | array_bounds, 934 | }) 935 | } 936 | 937 | fn parse_function_params(&mut self) -> ParseResult> { 938 | // if there is no left parenthesis, there are no parameters 939 | if !self.consume_match(Token::LeftParen)? { 940 | return Ok(Vec::new()); 941 | } 942 | // if we immediately see a right parenthesis, there are no parameters 943 | if self.consume_match(Token::RightParen)? { 944 | return Ok(Vec::new()); 945 | } 946 | 947 | let mut params = Vec::new(); 948 | loop { 949 | params.push(self.parse_function_param()?); 950 | let t = self.must_pop()?; 951 | match t.0 { 952 | Token::Comma => {} 953 | Token::RightParen => break Ok(params), 954 | _ => return Err(self.unexpected(t)), 955 | } 956 | } 957 | } 958 | 959 | fn parse_return_param(&mut self) -> ParseResult> { 960 | let t = self.must_get()?; 961 | if let Token::LeftParen = t.0 { 962 | self.skip(); 963 | } else { 964 | return Ok(None); 965 | } 966 | let param = self.parse_function_param()?; 967 | self.consume(Token::RightParen)?; 968 | Ok(Some(param)) 969 | } 970 | 971 | fn parse_function(&mut self) -> ParseResult { 972 | let visible = if let Token::Visible = self.must_get()?.0 { 973 | self.skip(); 974 | true 975 | } else { 976 | false 977 | }; 978 | let t = self.must_pop()?; 979 | let entry = match t.0 { 980 | Token::Entry => true, 981 | Token::Func => false, 982 | _ => return Err(self.unexpected(t)), 983 | }; 984 | 985 | let return_param = self.parse_return_param()?; 986 | 987 | let t = self.must_pop()?; 988 | let ident = match t.0 { 989 | Token::Identifier(s) => s.to_string(), 990 | _ => return Err(self.unexpected(t)), 991 | }; 992 | 993 | let noreturn = if let Token::Noreturn = self.must_get()?.0 { 994 | self.skip(); 995 | true 996 | } else { 997 | false 998 | }; 999 | 1000 | let params = self.parse_function_params()?; 1001 | let body = self.parse_statement()?; 1002 | 1003 | Ok(Function { 1004 | ident: ident.to_string(), 1005 | visible, 1006 | entry, 1007 | return_param, 1008 | noreturn, 1009 | params, 1010 | body: Box::new(body), 1011 | }) 1012 | } 1013 | } 1014 | 1015 | pub fn parse_program(src: &str) -> Result { 1016 | Parser::new(src).parse_module() 1017 | } 
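// ---------------------------------------------------------------------------
// Usage sketch (added for illustration; not part of the original source).
// It shows how a caller might drive `parse_program` over one of the PTX files
// generated by the Makefile and inspect the resulting `Module`. The module
// name `usage_sketch`, the test name, and the exact count asserted are
// assumptions; kernel names are left unchecked because nvcc mangles the C++
// kernel symbols in the generated PTX.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn add_ptx_declares_one_entry_kernel() {
        let src = std::fs::read_to_string("kernels/add.ptx").unwrap();
        let module = parse_program(&src).unwrap();
        // Count the `.entry` (kernel) functions declared at module scope.
        let entries = module
            .0
            .iter()
            .filter(|d| matches!(d, Directive::Function(f) if f.entry))
            .count();
        // kernels/add.cu defines a single __global__ kernel, so one .entry is expected.
        assert_eq!(entries, 1);
    }
}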
1018 | 1019 | #[cfg(test)] 1020 | mod test { 1021 | use super::*; 1022 | 1023 | #[test] 1024 | fn test_parse_add() { 1025 | let contents = std::fs::read_to_string("kernels/add.ptx").unwrap(); 1026 | let _ = parse_program(&contents).unwrap(); 1027 | } 1028 | 1029 | #[test] 1030 | fn test_parse_transpose() { 1031 | let contents = std::fs::read_to_string("kernels/transpose.ptx").unwrap(); 1032 | let _ = parse_program(&contents).unwrap(); 1033 | } 1034 | 1035 | #[test] 1036 | fn test_parse_add_simple() { 1037 | let contents = std::fs::read_to_string("kernels/add_simple.ptx").unwrap(); 1038 | let _ = parse_program(&contents).unwrap(); 1039 | } 1040 | 1041 | #[test] 1042 | fn test_parse_fncall() { 1043 | let contents = std::fs::read_to_string("kernels/fncall.ptx").unwrap(); 1044 | let _ = parse_program(&contents).unwrap(); 1045 | } 1046 | 1047 | #[test] 1048 | fn test_parse_gemm() { 1049 | let contents = std::fs::read_to_string("kernels/gemm.ptx").unwrap(); 1050 | let _ = parse_program(&contents).unwrap(); 1051 | } 1052 | } 1053 | -------------------------------------------------------------------------------- /src/vm.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::ast::MulMode; 4 | use crate::ast::PredicateOp; 5 | use crate::ast::SpecialReg; 6 | 7 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 8 | pub enum Type { 9 | U128, 10 | U64, 11 | U32, 12 | U16, 13 | U8, 14 | S64, 15 | S32, 16 | S16, 17 | S8, 18 | F64, 19 | F32, 20 | F16x2, 21 | F16, 22 | Pred, 23 | } 24 | 25 | impl Type { 26 | pub fn size(&self) -> usize { 27 | use Type::*; 28 | match self { 29 | U128 => 16, 30 | U64 | S64 | F64 => 8, 31 | U32 | S32 | F32 | F16x2 => 4, 32 | U16 | S16 | F16 => 2, 33 | U8 | S8 => 1, 34 | Pred => 1, 35 | } 36 | } 37 | 38 | pub fn alignment(&self) -> usize { 39 | self.size() 40 | } 41 | } 42 | 43 | impl From for Type { 44 | fn from(value: crate::ast::Type) -> Self { 45 | use crate::ast; 46 | match value { 47 | ast::Type::B128 => Type::U128, 48 | ast::Type::B64 => Type::U64, 49 | ast::Type::B32 => Type::U32, 50 | ast::Type::B16 => Type::U16, 51 | ast::Type::B8 => Type::U8, 52 | ast::Type::U64 => Type::U64, 53 | ast::Type::U32 => Type::U32, 54 | ast::Type::U16 => Type::U16, 55 | ast::Type::U8 => Type::U8, 56 | ast::Type::S64 => Type::S64, 57 | ast::Type::S32 => Type::S32, 58 | ast::Type::S16 => Type::S16, 59 | ast::Type::S8 => Type::S8, 60 | ast::Type::F64 => Type::F64, 61 | ast::Type::F32 => Type::F32, 62 | ast::Type::F16x2 => Type::F16x2, 63 | ast::Type::F16 => Type::F16, 64 | ast::Type::Pred => Type::Pred, 65 | } 66 | } 67 | } 68 | 69 | #[derive(Debug, Clone, Copy)] 70 | pub enum Symbol { 71 | Function(usize), 72 | #[allow(dead_code)] 73 | Variable(usize), 74 | } 75 | 76 | #[derive(Clone, Copy, Debug)] 77 | pub enum StateSpace { 78 | // includes ptx global and const 79 | Global, 80 | // includes ptx local and param 81 | Stack, 82 | // includes ptx shared 83 | Shared, 84 | } 85 | 86 | #[derive(Clone, Copy, Debug)] 87 | pub enum Constant { 88 | U64(u64), 89 | S64(i64), 90 | F64(f64), 91 | F32(f32), 92 | } 93 | 94 | #[derive(Clone, Copy, Debug)] 95 | pub struct GenericReg(pub usize); 96 | 97 | #[derive(Clone, Copy, Debug)] 98 | pub enum RegOperand { 99 | Generic(GenericReg), 100 | Special(SpecialReg), 101 | } 102 | 103 | impl From for RegOperand { 104 | fn from(value: GenericReg) -> Self { 105 | RegOperand::Generic(value) 106 | } 107 | } 108 | 109 | impl From for RegOperand { 110 | fn from(value: SpecialReg) -> Self { 
111 | RegOperand::Special(value) 112 | } 113 | } 114 | 115 | #[derive(Clone, Copy, Debug)] 116 | pub enum Instruction { 117 | Load(Type, StateSpace, GenericReg, RegOperand), 118 | Store(Type, StateSpace, RegOperand, RegOperand), 119 | Move(Type, GenericReg, RegOperand), 120 | Const(GenericReg, Constant), 121 | Convert { 122 | dst_type: Type, 123 | src_type: Type, 124 | dst: GenericReg, 125 | src: RegOperand, 126 | }, 127 | 128 | Add(Type, GenericReg, RegOperand, RegOperand), 129 | Sub(Type, GenericReg, RegOperand, RegOperand), 130 | Or(Type, GenericReg, RegOperand, RegOperand), 131 | And(Type, GenericReg, RegOperand, RegOperand), 132 | Not(Type, GenericReg, RegOperand), 133 | ShiftLeft(Type, GenericReg, RegOperand, RegOperand), 134 | Mul(Type, MulMode, GenericReg, RegOperand, RegOperand), 135 | Neg(Type, GenericReg, RegOperand), 136 | 137 | PushArg(RegOperand), 138 | PopArg(Option<GenericReg>), 139 | Call(usize), 140 | 141 | BarrierSync { 142 | idx: RegOperand, 143 | cnt: Option<RegOperand>, 144 | }, 145 | BarrierArrive { 146 | idx: RegOperand, 147 | cnt: Option<RegOperand>, 148 | }, 149 | 150 | Jump { 151 | offset: isize, 152 | }, 153 | SetPredicate(Type, PredicateOp, GenericReg, RegOperand, RegOperand), 154 | SkipIf(RegOperand, bool), 155 | 156 | Return, 157 | } 158 | 159 | #[derive(Clone, Copy, Debug)] 160 | pub struct IPtr(pub usize); 161 | 162 | #[derive(Clone, Debug)] 163 | struct FrameMeta { 164 | return_addr: IPtr, 165 | num_args: usize, 166 | reg_base: usize, 167 | stack_base: usize, 168 | } 169 | 170 | #[derive(Debug, Clone)] 171 | struct ThreadState { 172 | iptr: IPtr, 173 | regs: Vec<u128>, 174 | stack_data: Vec<u8>, 175 | stack_frames: Vec<FrameMeta>, 176 | nctaid: (u32, u32, u32), 177 | ctaid: (u32, u32, u32), 178 | ntid: (u32, u32, u32), 179 | tid: (u32, u32, u32), 180 | } 181 | 182 | #[derive(Clone, Copy, Debug)] 183 | pub struct FuncFrameDesc { 184 | pub iptr: IPtr, 185 | pub frame_size: usize, 186 | pub num_args: usize, 187 | pub num_regs: usize, 188 | pub shared_size: usize, 189 | } 190 | 191 | #[derive(Debug, Default)] 192 | pub struct Context { 193 | global_mem: Vec<u8>, 194 | instructions: Vec<Instruction>, 195 | descriptors: Vec<FuncFrameDesc>, 196 | symbol_map: HashMap<String, Symbol>, 197 | } 198 | 199 | macro_rules! byte_reg_funcs { 200 | ($($get:ident, $set:ident, $helper_fn:ident, $helper_type:ty, $n:expr);* $(;)?) => { 201 | $( 202 | fn $get(&self, reg: RegOperand) -> [u8; $n] { 203 | self.$helper_fn(reg).to_ne_bytes() 204 | } 205 | 206 | fn $set(&mut self, reg: GenericReg, value: [u8; $n]) { 207 | self.set(reg, <$helper_type>::from_ne_bytes(value) as u128); 208 | } 209 | )* 210 | }; 211 | } 212 | 213 | macro_rules! int_getters { 214 | ($($get:ident, $t2:ty);* $(;)?) => { 215 | $( 216 | fn $get(&self, reg: RegOperand) -> $t2 { 217 | self.get(reg) as $t2 218 | } 219 | )* 220 | }; 221 | } 222 | 223 | macro_rules! int_setters { 224 | ($($set:ident, $t2:ty);* $(;)?)
233 | impl ThreadState {
234 |     fn new(
235 |         nctaid: (u32, u32, u32),
236 |         ctaid: (u32, u32, u32),
237 |         ntid: (u32, u32, u32),
238 |         tid: (u32, u32, u32),
239 |     ) -> ThreadState {
240 |         ThreadState {
241 |             iptr: IPtr(0),
242 |             regs: Vec::new(),
243 |             stack_data: Vec::new(),
244 |             stack_frames: Vec::new(),
245 |             nctaid,
246 |             ctaid,
247 |             ntid,
248 |             tid,
249 |         }
250 |     }
251 | 
252 |     fn get(&self, reg: RegOperand) -> u128 {
253 |         match reg {
254 |             RegOperand::Generic(reg) => {
255 |                 let meta = self.stack_frames.last().unwrap();
256 |                 let idx = meta.reg_base - meta.num_args + reg.0;
257 |                 self.regs[idx]
258 |             }
259 |             RegOperand::Special(reg) => self.get_special(reg),
260 |         }
261 |     }
262 | 
263 |     fn set(&mut self, reg: GenericReg, val: u128) {
264 |         let meta = self.stack_frames.last().unwrap();
265 |         let idx = meta.reg_base - meta.num_args + reg.0;
266 |         self.regs[idx] = val;
267 |     }
268 | 
269 |     fn get_pred(&self, reg: RegOperand) -> bool {
270 |         self.get(reg) != 0
271 |     }
272 | 
273 |     fn get_f32(&self, reg: RegOperand) -> f32 {
274 |         f32::from_bits(self.get_u32(reg))
275 |     }
276 | 
277 |     fn set_f32(&mut self, reg: GenericReg, val: f32) {
278 |         self.set_u32(reg, val.to_bits())
279 |     }
280 | 
281 |     fn get_f64(&self, reg: RegOperand) -> f64 {
282 |         f64::from_bits(self.get_u64(reg))
283 |     }
284 | 
285 |     fn set_f64(&mut self, reg: GenericReg, val: f64) {
286 |         self.set_u64(reg, val.to_bits())
287 |     }
288 | 
289 |     byte_reg_funcs!(
290 |         // get_b8, set_b8, get_u8, u8, 1;
291 |         // get_b16, set_b16, get_u16, u16, 2;
292 |         // get_b32, set_b32, get_u32, u32, 4;
293 |         // get_b64, set_b64, get_u64, u64, 8;
294 |         get_b128, set_b128, get_u128, u128, 16;
295 |     );
296 | 
297 |     int_getters!(
298 |         // get_u8, u8;
299 |         get_u16, u16;
300 |         get_u32, u32;
301 |         get_u64, u64;
302 |         get_u128, u128;
303 | 
304 |         // get_i8, i8;
305 |         // get_i16, i16;
306 |         get_i32, i32;
307 |         get_i64, i64;
308 |         // get_i128, i128;
309 |     );
310 |     int_setters!(
311 |         set_pred, bool;
312 |         // set_u8, u8;
313 |         set_u16, u16;
314 |         set_u32, u32;
315 |         set_u64, u64;
316 |         // set_u128, u128;
317 | 
318 |         // set_i8, i8;
319 |         // set_i16, i16;
320 |         set_i32, i32;
321 |         set_i64, i64;
322 |         // set_i128, i128;
323 |     );
324 | 
325 |     fn iptr_fetch_incr(&mut self) -> IPtr {
326 |         let ret = self.iptr;
327 |         self.iptr.0 += 1;
328 |         ret
329 |     }
330 | 
331 |     fn frame_teardown(&mut self) {
332 |         let meta = self.stack_frames.pop().unwrap();
333 |         self.stack_data.truncate(meta.stack_base);
334 |         self.regs.resize(meta.reg_base + meta.num_args, 0);
335 |         self.iptr = meta.return_addr;
336 |     }
337 | 
338 |     fn frame_setup(&mut self, desc: FuncFrameDesc) {
339 |         self.stack_frames.push(FrameMeta {
340 |             return_addr: self.iptr,
341 |             num_args: desc.num_args,
342 |             reg_base: self.regs.len(),
343 |             stack_base: self.stack_data.len(),
344 |         });
345 |         self.stack_data
346 |             .resize(self.stack_data.len() + desc.frame_size, 0);
347 |         // args were already set up by caller
348 |         self.regs.resize(self.regs.len() + desc.num_regs, 0);
349 |         self.iptr = desc.iptr;
350 |     }
351 | 
352 |     fn read_reg_unsigned(&self, reg: RegOperand) -> VmResult<usize> {
353 |         Ok(self.get(reg) as usize)
354 |     }
355 | 
356 |     fn get_stack_ptr(&self) -> usize {
357 |         self.stack_frames.last().unwrap().stack_base
358 |     }
359 | 
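    // Special registers mirror the PTX %tid / %ntid / %ctaid / %nctaid vectors
    // plus a virtual stack pointer; reading a whole vector at once (e.g.
    // SpecialReg::ThreadId) is not implemented yet and hits todo!().
    // For example, in a launch with block_dim = (128, 1, 1), a thread with
    // tid = (5, 0, 0) reads ThreadIdX as 5 and NumThreadX as 128.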
360 |     fn get_special(&self, s: SpecialReg) -> u128 {
361 |         match s {
362 |             SpecialReg::StackPtr => self.get_stack_ptr() as u128,
363 |             SpecialReg::ThreadId => todo!(),
364 |             SpecialReg::ThreadIdX => self.tid.0 as u128,
365 |             SpecialReg::ThreadIdY => self.tid.1 as u128,
366 |             SpecialReg::ThreadIdZ => self.tid.2 as u128,
367 |             SpecialReg::NumThread => todo!(),
368 |             SpecialReg::NumThreadX => self.ntid.0 as u128,
369 |             SpecialReg::NumThreadY => self.ntid.1 as u128,
370 |             SpecialReg::NumThreadZ => self.ntid.2 as u128,
371 |             SpecialReg::CtaId => todo!(),
372 |             SpecialReg::CtaIdX => self.ctaid.0 as u128,
373 |             SpecialReg::CtaIdY => self.ctaid.1 as u128,
374 |             SpecialReg::CtaIdZ => self.ctaid.2 as u128,
375 |             SpecialReg::NumCta => todo!(),
376 |             SpecialReg::NumCtaX => self.nctaid.0 as u128,
377 |             SpecialReg::NumCtaY => self.nctaid.1 as u128,
378 |             SpecialReg::NumCtaZ => self.nctaid.2 as u128,
379 |         }
380 |     }
381 | }
382 | 
383 | #[derive(thiserror::Error, Debug)]
384 | pub enum VmError {
385 |     #[error("invalid address operand register {0:?}")]
386 |     InvalidAddressOperandRegister(RegOperand),
387 |     #[error("Invalid register operand {:?} for instruction {:?}", .1, .0)]
388 |     InvalidOperand(Instruction, RegOperand),
389 |     #[error("Parameter data did not match descriptor")]
390 |     ParamDataSizeMismatch,
391 |     #[error("Invalid function id {0}")]
392 |     InvalidFunctionId(usize),
393 |     #[error("Invalid function name {0}")]
394 |     InvalidFunctionName(String),
395 |     #[error("Slice size mismatch")]
396 |     SliceSizeMismatch(#[from] std::array::TryFromSliceError),
397 |     #[error("Parse error: {0:?}")]
398 |     ParseError(#[from] crate::ast::ParseErr),
399 |     #[error("Compile error: {0:?}")]
400 |     CompileError(#[from] crate::compiler::CompilationError),
401 | }
402 | 
403 | type VmResult<T> = Result<T, VmError>;
404 | 
405 | #[derive(Clone, Copy, Debug)]
406 | pub struct DevicePointer<T>(u64, std::marker::PhantomData<T>);
407 | 
408 | #[derive(Clone, Copy, Debug)]
409 | pub struct RawDevicePointer(u64);
410 | 
411 | impl<T> From<RawDevicePointer> for DevicePointer<T> {
412 |     fn from(value: RawDevicePointer) -> Self {
413 |         Self(value.0, std::marker::PhantomData)
414 |     }
415 | }
416 | 
417 | impl<T> From<DevicePointer<T>> for RawDevicePointer {
418 |     fn from(value: DevicePointer<T>) -> Self {
419 |         Self(value.0)
420 |     }
421 | }
422 | 
423 | pub enum Argument<'a> {
424 |     Ptr(RawDevicePointer),
425 |     U64(u64),
426 |     U32(u32),
427 |     Bytes(&'a [u8]),
428 | }
429 | 
430 | impl<'a> Argument<'a> {
431 |     pub fn ptr<T>(ptr: DevicePointer<T>) -> Self {
432 |         Self::Ptr(ptr.into())
433 |     }
434 | }
435 | 
436 | #[derive(Clone, Copy, Debug)]
437 | enum FuncIdent<'a> {
438 |     Name(&'a str),
439 |     Id(usize),
440 | }
441 | 
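/// Grid/block configuration for a kernel launch, built with a small fluent
/// API. A minimal sketch of typical use (the kernel name and dimensions are
/// illustrative only, and `ctx`/`args` are assumed to be an already loaded
/// `Context` and its kernel arguments):
///
/// ```ignore
/// // Launch "my_kernel" on a 16-CTA grid with 64 threads per CTA.
/// let params = LaunchParams::func("my_kernel").grid1d(16).block1d(64);
/// ctx.run(params, &args)?;
/// ```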
442 | #[derive(Clone, Copy, Debug)]
443 | pub struct LaunchParams<'a> {
444 |     func_id: FuncIdent<'a>,
445 |     grid_dim: (u32, u32, u32),
446 |     block_dim: (u32, u32, u32),
447 | }
448 | 
449 | enum ThreadResult {
450 |     Continue,
451 |     Sync(usize, Option<usize>),
452 |     Arrive(usize, Option<usize>),
453 |     Exit,
454 | }
455 | 
456 | impl<'a> LaunchParams<'a> {
457 |     pub fn func(name: &'a str) -> LaunchParams<'a> {
458 |         LaunchParams {
459 |             func_id: FuncIdent::Name(name),
460 |             grid_dim: (1, 1, 1),
461 |             block_dim: (1, 1, 1),
462 |         }
463 |     }
464 | 
465 |     pub fn func_id(id: usize) -> LaunchParams<'a> {
466 |         LaunchParams {
467 |             func_id: FuncIdent::Id(id),
468 |             grid_dim: (1, 1, 1),
469 |             block_dim: (1, 1, 1),
470 |         }
471 |     }
472 | 
473 |     pub fn grid1d(mut self, x: u32) -> LaunchParams<'a> {
474 |         self.grid_dim = (x, 1, 1);
475 |         self
476 |     }
477 | 
478 |     pub fn grid2d(mut self, x: u32, y: u32) -> LaunchParams<'a> {
479 |         self.grid_dim = (x, y, 1);
480 |         self
481 |     }
482 | 
483 |     pub fn block1d(mut self, x: u32) -> LaunchParams<'a> {
484 |         self.block_dim = (x, 1, 1);
485 |         self
486 |     }
487 | 
488 |     pub fn block2d(mut self, x: u32, y: u32) -> LaunchParams<'a> {
489 |         self.block_dim = (x, y, 1);
490 |         self
491 |     }
492 | }
493 | 
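// Cooperative barrier bookkeeping for one CTA: a thread that executes a
// barrier sync is parked in `blocked` until `target` threads have arrived on
// the same barrier index, at which point the whole group is handed back to
// the runnable pool (see `Barriers::block` and the scheduling loop in
// `run_cta`).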
494 | #[derive(Clone, Debug)]
495 | struct Barrier {
496 |     target: usize,
497 |     arrived: usize,
498 |     blocked: Vec<ThreadState>,
499 | }
500 | 
501 | struct Barriers {
502 |     barriers: Vec<Option<Barrier>>,
503 | }
504 | 
505 | impl Barriers {
506 |     pub fn new() -> Self {
507 |         Barriers {
508 |             barriers: Vec::new(),
509 |         }
510 |     }
511 | 
512 |     pub fn arrive(&mut self, _idx: usize, _target: usize) -> VmResult<Vec<ThreadState>> {
513 |         todo!()
514 |     }
515 | 
516 |     pub fn block(
517 |         &mut self,
518 |         idx: usize,
519 |         target: usize,
520 |         thread: ThreadState,
521 |     ) -> VmResult<Vec<ThreadState>> {
522 |         self.assert_size(idx);
523 |         if let Some(ref mut barr) = self.barriers[idx] {
524 |             barr.blocked.push(thread);
525 |             barr.arrived += 1;
526 |             if barr.arrived == barr.target {
527 |                 let barr = self.barriers[idx].take().unwrap();
528 |                 return Ok(barr.blocked);
529 |             }
530 |         } else {
531 |             self.barriers[idx] = Some(Barrier {
532 |                 target,
533 |                 arrived: 1,
534 |                 blocked: vec![thread],
535 |             });
536 |         }
537 |         Ok(Vec::new())
538 |     }
539 | 
540 |     fn assert_size(&mut self, idx: usize) {
541 |         if idx >= self.barriers.len() {
542 |             self.barriers.resize(idx + 1, None);
543 |         }
544 |     }
545 | }
546 | 
547 | macro_rules! binary_op {
548 |     ($threadop:expr, $tyop:expr, $dstop:expr, $aop:expr, $bop:expr;
549 |      $($target_ty:pat, $op:ident, $getter:ident, $setter:ident);*$(;)?) => {
550 |         match ($tyop) {
551 |             $(
552 |                 $target_ty => {
553 |                     let val = $threadop.$getter($aop).$op($threadop.$getter($bop));
554 |                     $threadop.$setter($dstop, val);
555 |                 }
556 |             )*
557 |             _ => todo!()
558 |         }
559 |     };
560 | }
561 | 
562 | macro_rules! binary_op2 {
563 |     ($op:path, $threadop:expr, $tyop:expr, $dstop:expr, $aop:expr, $bop:expr;
564 |      $($target_ty:pat, $getter:ident, $setter:ident);*$(;)?) => {
565 |         match ($tyop) {
566 |             $(
567 |                 $target_ty => {
568 |                     let val = $op($threadop.$getter($aop), $threadop.$getter($bop));
569 |                     $threadop.$setter($dstop, val);
570 |                 }
571 |             )*
572 |             _ => todo!()
573 |         }
574 |     };
575 | }
576 | 
577 | macro_rules! unary_op {
578 |     ($threadop:expr, $tyop:expr, $dstop:expr, $srcop:expr;
579 |      $($target_ty:pat, $op:ident, $getter:ident, $setter:ident);*$(;)?) => {
580 |         match ($tyop) {
581 |             $(
582 |                 $target_ty => {
583 |                     let val = $threadop.$getter($srcop).$op();
584 |                     $threadop.$setter($dstop, val);
585 |                 }
586 |             )*
587 |             _ => todo!()
588 |         }
589 |     };
590 | }
591 | 
592 | macro_rules! comparison_op {
593 |     ($threadop:expr, $tyop:expr, $op:expr, $dstop:expr, $aop:expr, $bop:expr;
594 |      $($target_ty:pat, $getter:ident);*$(;)?) => {
595 |         match ($tyop) {
596 |             $(
597 |                 $target_ty => {
598 |                     let a = $threadop.$getter($aop);
599 |                     let b = $threadop.$getter($bop);
600 |                     let value = match $op {
601 |                         PredicateOp::LessThan => a < b,
602 |                         PredicateOp::LessThanEqual => a <= b,
603 |                         PredicateOp::Equal => a == b,
604 |                         PredicateOp::NotEqual => a != b,
605 |                         PredicateOp::GreaterThan => a > b,
606 |                         PredicateOp::GreaterThanEqual => a >= b,
607 |                     };
608 |                     $threadop.set_pred($dstop, value);
609 |                 }
610 |             )*
611 |             _ => todo!()
612 |         }
613 |     };
614 | }
615 | 
616 | impl Context {
617 |     fn fetch_instr(&self, iptr: IPtr) -> Instruction {
618 |         self.instructions[iptr.0]
619 |     }
620 | 
621 |     #[cfg(test)]
622 |     fn new_raw(program: Vec<Instruction>, descriptors: Vec<FuncFrameDesc>) -> Context {
623 |         Context {
624 |             global_mem: Vec::new(),
625 |             instructions: program,
626 |             descriptors,
627 |             symbol_map: HashMap::new(),
628 |         }
629 |     }
630 | 
631 |     pub fn new() -> Context {
632 |         Context {
633 |             global_mem: Vec::new(),
634 |             instructions: Vec::new(),
635 |             descriptors: Vec::new(),
636 |             symbol_map: HashMap::new(),
637 |         }
638 |     }
639 | 
640 |     pub fn new_with_module(module: &str) -> VmResult<Self> {
641 |         let module = crate::ast::parse_program(module)?;
642 |         let compiled = crate::compiler::compile(module)?;
643 |         Ok(Self {
644 |             global_mem: Vec::new(),
645 |             instructions: compiled.instructions,
646 |             descriptors: compiled.func_descriptors,
647 |             symbol_map: compiled.symbol_map,
648 |         })
649 |     }
650 | 
651 |     pub fn load(&mut self, module: &str) -> VmResult<()> {
652 |         let module = crate::ast::parse_program(module)?;
653 |         let _compiled = crate::compiler::compile(module).unwrap();
654 |         todo!()
655 |     }
656 | 
657 |     pub fn alloc_raw(&mut self, size: usize, align: usize) -> RawDevicePointer {
658 |         // Calculate the next aligned position
659 |         let aligned_ptr = (self.global_mem.len() + align - 1) & !(align - 1);
660 |         // Resize the vector to ensure the space is allocated
661 |         self.global_mem.resize(aligned_ptr + size, 0);
662 |         // Return the device pointer to the aligned address
663 |         RawDevicePointer(aligned_ptr as u64)
664 |     }
665 | 
666 |     pub fn write_raw(&mut self, ptr: RawDevicePointer, offset: usize, data: &[u8]) {
667 |         let begin = ptr.0 as usize + offset;
668 |         let end = begin + data.len();
669 |         self.global_mem[begin..end].copy_from_slice(data);
670 |     }
671 | 
672 |     pub fn read_raw(&mut self, ptr: RawDevicePointer, offset: usize, data: &mut [u8]) {
673 |         let begin = ptr.0 as usize + offset;
674 |         let end = begin + data.len();
675 |         data.copy_from_slice(&self.global_mem[begin..end]);
676 |     }
677 | 
678 |     pub fn alloc<T>(&mut self, count: usize) -> DevicePointer<T> {
679 |         self.alloc_raw(count * std::mem::size_of::<T>(), std::mem::align_of::<T>())
680 |             .into()
681 |     }
682 | 
683 |     pub fn read<T>(&mut self, src: DevicePointer<T>, dst: &mut [T])
684 |     where
685 |         T: bytemuck::NoUninit + bytemuck::AnyBitPattern,
686 |     {
687 |         self.read_raw(src.into(), 0, bytemuck::cast_slice_mut(dst));
688 |     }
689 | 
690 |     pub fn write<T>(&mut self, dst: DevicePointer<T>, src: &[T])
691 |     where
692 |         T: bytemuck::NoUninit + bytemuck::AnyBitPattern,
693 |     {
694 |         self.write_raw(dst.into(), 0, bytemuck::cast_slice(src));
695 |     }
696 | 
697 |     pub fn reset_mem(&mut self) {
698 |         self.global_mem.clear();
699 |     }
700 | 
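    // Executes one CTA (thread block). Threads are interpreted one at a time:
    // each thread runs until it exits or blocks on a barrier, and threads
    // released by a completed barrier are pushed back onto the runnable list.
    // This is cooperative scheduling, not a model of warp-level lockstep
    // execution.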
701 |     fn run_cta(
702 |         &mut self,
703 |         nctaid: (u32, u32, u32),
704 |         ctaid: (u32, u32, u32),
705 |         ntid: (u32, u32, u32),
706 |         desc: FuncFrameDesc,
707 |         init_stack: &[u8],
708 |         init_regs: &[u128],
709 |     ) -> VmResult<()> {
710 |         let mut shared_mem = vec![0u8; desc.shared_size];
711 | 
712 |         let mut runnable = Vec::new();
713 |         for x in 0..ntid.0 {
714 |             for y in 0..ntid.1 {
715 |                 for z in 0..ntid.2 {
716 |                     let mut state = ThreadState::new(nctaid, ctaid, ntid, (x, y, z));
717 |                     state.stack_data.extend_from_slice(init_stack);
718 |                     state.regs.extend_from_slice(init_regs);
719 |                     state.frame_setup(desc);
720 |                     runnable.push(state);
721 |                 }
722 |             }
723 |         }
724 |         let cta_size = (ntid.0 * ntid.1 * ntid.2) as usize;
725 | 
726 |         let mut barriers = Barriers::new();
727 | 
728 |         while let Some(mut state) = runnable.pop() {
729 |             loop {
730 |                 match self.step_thread(&mut state, &mut shared_mem)? {
731 |                     ThreadResult::Continue => continue,
732 |                     ThreadResult::Arrive(idx, cnt) => {
733 |                         let cnt = cnt.unwrap_or(cta_size);
734 |                         runnable.extend(barriers.arrive(idx, cnt)?);
735 |                         continue;
736 |                     }
737 |                     ThreadResult::Sync(idx, cnt) => {
738 |                         let cnt = cnt.unwrap_or(cta_size);
739 |                         runnable.extend(barriers.block(idx, cnt, state)?);
740 |                         break;
741 |                     }
742 |                     ThreadResult::Exit => break,
743 |                 }
744 |             }
745 |         }
746 |         Ok(())
747 |     }
748 | 
749 |     fn step_thread(
750 |         &mut self,
751 |         thread: &mut ThreadState,
752 |         shared_mem: &mut [u8],
753 |     ) -> VmResult<ThreadResult> {
754 |         let inst = self.fetch_instr(thread.iptr_fetch_incr());
755 |         match inst {
756 |             Instruction::Load(ty, space, dst, addr) => {
757 |                 let addr = thread.read_reg_unsigned(addr)?;
758 |                 let mut buf = [0u8; 16];
759 |                 let data = match space {
760 |                     StateSpace::Global => self.global_mem.as_slice(),
761 |                     StateSpace::Stack => thread.stack_data.as_slice(),
762 |                     StateSpace::Shared => shared_mem,
763 |                 };
764 |                 let size = ty.size();
765 |                 buf[..size].copy_from_slice(&data[addr..addr + size]);
766 |                 thread.set_b128(dst, buf);
767 |             }
768 |             Instruction::Store(ty, space, src, addr) => {
769 |                 let addr = thread.read_reg_unsigned(addr)?;
770 |                 let buf = thread.get_b128(src);
771 |                 let data = match space {
772 |                     StateSpace::Global => self.global_mem.as_mut_slice(),
773 |                     StateSpace::Stack => thread.stack_data.as_mut_slice(),
774 |                     StateSpace::Shared => shared_mem,
775 |                 };
776 |                 let size = ty.size();
777 |                 data[addr..addr + size].copy_from_slice(&buf[..size]);
778 |             }
779 |             Instruction::Add(ty, dst, a, b) => {
780 |                 use std::ops::Add;
781 |                 binary_op! {
782 |                     thread, ty, dst, a, b;
783 |                     Type::U64, add, get_u64, set_u64;
784 |                     Type::S64, add, get_i64, set_i64;
785 |                     Type::U32, add, get_u32, set_u32;
786 |                     Type::S32, add, get_i32, set_i32;
787 |                     Type::F64, add, get_f64, set_f64;
788 |                     Type::F32, add, get_f32, set_f32;
789 |                 };
790 |             }
791 |             Instruction::Sub(ty, dst, a, b) => {
792 |                 binary_op2! {
793 |                     std::ops::Sub::sub, thread, ty, dst, a, b;
794 |                     Type::U64, get_u64, set_u64;
795 |                     Type::S64, get_i64, set_i64;
796 |                     Type::U32, get_u32, set_u32;
797 |                     Type::S32, get_i32, set_i32;
798 |                     Type::F64, get_f64, set_f64;
799 |                     Type::F32, get_f32, set_f32;
800 |                 };
801 |             }
802 |             Instruction::Mul(ty, mode, dst, a, b) => {
803 |                 use std::ops::Mul;
804 |                 match mode {
805 |                     MulMode::Low => {
806 |                         binary_op!
{ 807 | thread, ty, dst, a, b; 808 | Type::U64, mul, get_u64, set_u64; 809 | Type::S64, mul, get_i64, set_i64; 810 | Type::U32, mul, get_u32, set_u32; 811 | Type::S32, mul, get_i32, set_i32; 812 | Type::F64, mul, get_f64, set_f64; 813 | Type::F32, mul, get_f32, set_f32; 814 | } 815 | } 816 | MulMode::High => todo!(), 817 | MulMode::Wide => match ty { 818 | Type::U32 => { 819 | thread 820 | .set_u64(dst, thread.get_u32(a) as u64 * thread.get_u32(b) as u64); 821 | } 822 | Type::S32 => { 823 | thread 824 | .set_i64(dst, thread.get_i32(a) as i64 * thread.get_i32(b) as i64); 825 | } 826 | _ => todo!(), 827 | }, 828 | } 829 | } 830 | Instruction::Or(ty, dst, a, b) => { 831 | use std::ops::BitOr; 832 | binary_op! { 833 | thread, ty, dst, a, b; 834 | Type::Pred, bitor, get_pred, set_pred; 835 | Type::U64 | Type::S64, bitor, get_u64, set_u64; 836 | Type::U32 | Type::S32, bitor, get_u32, set_u32; 837 | }; 838 | } 839 | Instruction::And(ty, dst, a, b) => { 840 | use std::ops::BitAnd; 841 | binary_op! { 842 | thread, ty, dst, a, b; 843 | Type::Pred, bitand, get_pred, set_pred; 844 | Type::U64 | Type::S64, bitand, get_u64, set_u64; 845 | Type::U32 | Type::S32, bitand, get_u32, set_u32; 846 | }; 847 | } 848 | Instruction::Neg(ty, dst, src) => { 849 | use std::ops::Neg; 850 | unary_op! { 851 | thread, ty, dst, src; 852 | Type::S64, neg, get_i64, set_i64; 853 | Type::S32, neg, get_i32, set_i32; 854 | Type::F32, neg, get_f32, set_f32; 855 | }; 856 | } 857 | Instruction::Not(ty, dst, src) => { 858 | use std::ops::Not; 859 | unary_op! { 860 | thread, ty, dst, src; 861 | Type::Pred, not, get_pred, set_pred; 862 | Type::U64, not, get_u64, set_u64; 863 | Type::U32, not, get_u32, set_u32; 864 | } 865 | } 866 | Instruction::ShiftLeft(ty, dst, a, b) => { 867 | use std::ops::Shl; 868 | binary_op! { 869 | thread, ty, dst, a, b; 870 | Type::U64, shl, get_u64, set_u64; 871 | Type::U32, shl, get_u32, set_u32; 872 | Type::U16, shl, get_u16, set_u16; 873 | }; 874 | } 875 | Instruction::Convert { 876 | dst_type, 877 | src_type, 878 | dst, 879 | src, 880 | } => match (dst_type, src_type) { 881 | // todo in reality most of these are no-ops 882 | (Type::U64, Type::U32) => { 883 | thread.set_u64(dst, thread.get_u32(src) as u64); 884 | } 885 | (Type::S64, Type::S32) => { 886 | thread.set_i64(dst, thread.get_i32(src) as i64); 887 | } 888 | (Type::U32, Type::U64) => { 889 | thread.set_u32(dst, thread.get_u64(src) as u32); 890 | } 891 | (Type::S32, Type::S64) => { 892 | thread.set_i32(dst, thread.get_i64(src) as i32); 893 | } 894 | _ => todo!(), 895 | }, 896 | Instruction::Move(_, dst, src) => { 897 | thread.set(dst, thread.get(src)); 898 | } 899 | Instruction::Const(dst, value) => match value { 900 | Constant::U64(value) => thread.set_u64(dst, value), 901 | Constant::S64(value) => thread.set_i64(dst, value), 902 | Constant::F32(value) => thread.set_f32(dst, value), 903 | Constant::F64(value) => thread.set_f64(dst, value), 904 | }, 905 | Instruction::SetPredicate(ty, op, dst, a, b) => { 906 | comparison_op! 
{ 907 | thread, ty, op, dst, a, b; 908 | Type::U64, get_u64; 909 | Type::S64, get_i64; 910 | Type::U32, get_u32; 911 | Type::S32, get_i32; 912 | Type::F32, get_f32; 913 | }; 914 | } 915 | Instruction::BarrierSync { idx, cnt } => { 916 | let idx = thread.get_u32(idx) as usize; 917 | let cnt = cnt.map(|r| thread.get_u64(r) as usize); 918 | return Ok(ThreadResult::Sync(idx, cnt)); 919 | } 920 | Instruction::BarrierArrive { idx, cnt } => { 921 | let idx = thread.get_u32(idx) as usize; 922 | let cnt = cnt.map(|r| thread.get_u64(r) as usize); 923 | return Ok(ThreadResult::Arrive(idx, cnt)); 924 | } 925 | Instruction::Jump { offset } => { 926 | thread.iptr.0 = (thread.iptr.0 as isize + offset - 1) as usize; 927 | } 928 | Instruction::SkipIf(cond, expected) => { 929 | if thread.get_pred(cond) == expected { 930 | thread.iptr.0 += 1; 931 | } 932 | } 933 | Instruction::Return => thread.frame_teardown(), 934 | Instruction::PushArg(reg) => { 935 | let val = thread.get(reg); 936 | thread.regs.push(val); 937 | } 938 | Instruction::PopArg(reg) => { 939 | let val = thread.regs.pop().unwrap(); 940 | if let Some(reg) = reg { 941 | thread.set(reg, val); 942 | } 943 | } 944 | Instruction::Call(desc_idx) => { 945 | let desc = self.descriptors[desc_idx]; 946 | thread.frame_setup(desc); 947 | } 948 | } 949 | if thread.stack_frames.is_empty() { 950 | Ok(ThreadResult::Exit) 951 | } else { 952 | Ok(ThreadResult::Continue) 953 | } 954 | } 955 | 956 | pub fn run(&mut self, params: LaunchParams, args: &[Argument]) -> VmResult<()> { 957 | let desc = match params.func_id { 958 | FuncIdent::Name(s) => { 959 | let Some(Symbol::Function(i)) = self.symbol_map.get(s) else { 960 | return Err(VmError::InvalidFunctionName(s.to_string())); 961 | }; 962 | self.descriptors[*i] 963 | } 964 | FuncIdent::Id(i) => *self 965 | .descriptors 966 | .get(i) 967 | .ok_or(VmError::InvalidFunctionId(i))?, 968 | }; 969 | if args.len() != desc.num_args { 970 | return Err(VmError::ParamDataSizeMismatch); 971 | } 972 | 973 | let mut init_stack = Vec::new(); 974 | let mut init_regs = Vec::new(); 975 | for arg in args { 976 | init_regs.push(init_stack.len() as u128); 977 | match arg { 978 | Argument::Ptr(ptr) => { 979 | let ptr_bytes = ptr.0.to_ne_bytes(); 980 | init_stack.extend_from_slice(&ptr_bytes); 981 | } 982 | Argument::U64(v) => { 983 | let v_bytes = v.to_ne_bytes(); 984 | init_stack.extend_from_slice(&v_bytes); 985 | } 986 | Argument::U32(v) => { 987 | let v_bytes = v.to_ne_bytes(); 988 | init_stack.extend_from_slice(&v_bytes); 989 | } 990 | Argument::Bytes(v) => { 991 | init_stack.extend_from_slice(v); 992 | } 993 | } 994 | } 995 | for x in 0..params.grid_dim.0 { 996 | for y in 0..params.grid_dim.1 { 997 | for z in 0..params.grid_dim.2 { 998 | self.run_cta( 999 | params.grid_dim, 1000 | (x, y, z), 1001 | params.block_dim, 1002 | desc, 1003 | &init_stack, 1004 | &init_regs, 1005 | )?; 1006 | } 1007 | } 1008 | } 1009 | Ok(()) 1010 | } 1011 | } 1012 | 1013 | #[cfg(test)] 1014 | mod test { 1015 | use super::*; 1016 | 1017 | #[test] 1018 | fn simple() { 1019 | let prog = vec![ 1020 | Instruction::Load( 1021 | Type::U64, 1022 | StateSpace::Stack, 1023 | GenericReg(0), 1024 | GenericReg(0).into(), 1025 | ), 1026 | Instruction::Load( 1027 | Type::U64, 1028 | StateSpace::Stack, 1029 | GenericReg(1), 1030 | GenericReg(1).into(), 1031 | ), 1032 | Instruction::Load( 1033 | Type::U64, 1034 | StateSpace::Stack, 1035 | GenericReg(2), 1036 | GenericReg(2).into(), 1037 | ), 1038 | Instruction::Load( 1039 | Type::U64, 1040 | StateSpace::Global, 1041 | 
                GenericReg(0),
1042 |                 GenericReg(0).into(),
1043 |             ),
1044 |             Instruction::Load(
1045 |                 Type::U64,
1046 |                 StateSpace::Global,
1047 |                 GenericReg(1),
1048 |                 GenericReg(1).into(),
1049 |             ),
1050 |             // add values
1051 |             Instruction::Add(
1052 |                 Type::U64,
1053 |                 GenericReg(0),
1054 |                 GenericReg(0).into(),
1055 |                 GenericReg(1).into(),
1056 |             ),
1057 |             // store result
1058 |             Instruction::Store(
1059 |                 Type::U64,
1060 |                 StateSpace::Global,
1061 |                 GenericReg(0).into(),
1062 |                 GenericReg(2).into(),
1063 |             ),
1064 |             Instruction::Return,
1065 |         ];
1066 |         let desc = vec![FuncFrameDesc {
1067 |             iptr: IPtr(0),
1068 |             frame_size: 0,
1069 |             shared_size: 0,
1070 |             num_args: 3,
1071 |             num_regs: 0,
1072 |         }];
1073 |         const ALIGN: usize = std::mem::align_of::<u64>();
1074 |         let mut ctx = Context::new_raw(prog, desc);
1075 |         let a = ctx.alloc_raw(8, ALIGN);
1076 |         let b = ctx.alloc_raw(8, ALIGN);
1077 |         let c = ctx.alloc_raw(8, ALIGN);
1078 |         ctx.write_raw(a, 0, &1u64.to_ne_bytes());
1079 |         ctx.write_raw(b, 0, &2u64.to_ne_bytes());
1080 |         ctx.run(
1081 |             LaunchParams::func_id(0).grid1d(1).block1d(1),
1082 |             &[Argument::Ptr(a), Argument::Ptr(b), Argument::Ptr(c)],
1083 |         )
1084 |         .unwrap();
1085 |         let mut res = [0u8; 8];
1086 |         ctx.read_raw(c, 0, &mut res);
1087 |         assert_eq!(u64::from_ne_bytes(res), 3);
1088 |     }
1089 | 
1090 |     #[test]
1091 |     fn multiple_threads() {
1092 |         let prog = vec![
1093 |             Instruction::Load(
1094 |                 Type::U64,
1095 |                 StateSpace::Stack,
1096 |                 GenericReg(0),
1097 |                 GenericReg(0).into(),
1098 |             ),
1099 |             Instruction::Load(
1100 |                 Type::U64,
1101 |                 StateSpace::Stack,
1102 |                 GenericReg(1),
1103 |                 GenericReg(1).into(),
1104 |             ),
1105 |             Instruction::Load(
1106 |                 Type::U64,
1107 |                 StateSpace::Stack,
1108 |                 GenericReg(2),
1109 |                 GenericReg(2).into(),
1110 |             ),
1111 |             Instruction::Const(GenericReg(3), Constant::U64(8)),
1112 |             Instruction::Mul(
1113 |                 Type::U64,
1114 |                 MulMode::Low,
1115 |                 GenericReg(3),
1116 |                 GenericReg(3).into(),
1117 |                 SpecialReg::ThreadIdX.into(),
1118 |             ),
1119 |             Instruction::Add(
1120 |                 Type::U64,
1121 |                 GenericReg(0),
1122 |                 GenericReg(0).into(),
1123 |                 GenericReg(3).into(),
1124 |             ),
1125 |             Instruction::Add(
1126 |                 Type::U64,
1127 |                 GenericReg(1),
1128 |                 GenericReg(1).into(),
1129 |                 GenericReg(3).into(),
1130 |             ),
1131 |             Instruction::Add(
1132 |                 Type::U64,
1133 |                 GenericReg(2),
1134 |                 GenericReg(2).into(),
1135 |                 GenericReg(3).into(),
1136 |             ),
1137 |             Instruction::Load(
1138 |                 Type::U64,
1139 |                 StateSpace::Global,
1140 |                 GenericReg(0),
1141 |                 GenericReg(0).into(),
1142 |             ),
1143 |             Instruction::Load(
1144 |                 Type::U64,
1145 |                 StateSpace::Global,
1146 |                 GenericReg(1),
1147 |                 GenericReg(1).into(),
1148 |             ),
1149 |             // add values
1150 |             Instruction::Add(
1151 |                 Type::U64,
1152 |                 GenericReg(0),
1153 |                 GenericReg(0).into(),
1154 |                 GenericReg(1).into(),
1155 |             ),
1156 |             // store result
1157 |             Instruction::Store(
1158 |                 Type::U64,
1159 |                 StateSpace::Global,
1160 |                 GenericReg(0).into(),
1161 |                 GenericReg(2).into(),
1162 |             ),
1163 |             Instruction::Return,
1164 |         ];
1165 |         let desc = vec![FuncFrameDesc {
1166 |             iptr: IPtr(0),
1167 |             frame_size: 0,
1168 |             shared_size: 0,
1169 |             num_args: 3,
1170 |             num_regs: 1,
1171 |         }];
1172 | 
1173 |         const ALIGN: usize = std::mem::align_of::<u64>();
1174 |         const N: usize = 10;
1175 | 
1176 |         let mut ctx = Context::new_raw(prog, desc);
1177 |         let a = ctx.alloc_raw(8 * N, ALIGN);
1178 |         let b = ctx.alloc_raw(8 * N, ALIGN);
1179 |         let c = ctx.alloc_raw(8 * N, ALIGN);
1180 | 
1181 |         let data_a = vec![1u64; N];
1182 |         let data_b = vec![2u64; N];
1183 |         ctx.write_raw(a, 0,
bytemuck::cast_slice(&data_a)); 1184 | ctx.write_raw(b, 0, bytemuck::cast_slice(&data_b)); 1185 | 1186 | ctx.run( 1187 | LaunchParams::func_id(0).grid1d(1).block1d(N as u32), 1188 | &[Argument::Ptr(a), Argument::Ptr(b), Argument::Ptr(c)], 1189 | ) 1190 | .unwrap(); 1191 | 1192 | let mut res = vec![0u64; N]; 1193 | ctx.read_raw(c, 0, bytemuck::cast_slice_mut(&mut res)); 1194 | 1195 | res.iter().for_each(|v| assert_eq!(*v, 3)); 1196 | } 1197 | } 1198 | --------------------------------------------------------------------------------