├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build.rs ├── examples ├── enumerate.rs ├── matrix_mul │ ├── main.rs │ ├── matrixMul_kernel.cu │ └── matrixMul_kernel.fatbin └── matrix_mul_jit │ ├── main.rs │ ├── matrixMul_kernel.cu │ └── matrixMul_kernel.ptx └── src ├── context.rs ├── device.rs ├── dim3.rs ├── error.rs ├── func.rs ├── future.rs ├── init.rs ├── kernel_params.rs ├── lib.rs ├── mem.rs ├── module.rs ├── stream.rs ├── sys.rs └── version.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | [[package]] 4 | name = "cfg-if" 5 | version = "1.0.0" 6 | source = "registry+https://github.com/rust-lang/crates.io-index" 7 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 8 | 9 | [[package]] 10 | name = "cuda-oxide" 11 | version = "0.4.0" 12 | dependencies = [ 13 | "num_enum", 14 | "rand", 15 | ] 16 | 17 | [[package]] 18 | name = "derivative" 19 | version = "2.2.0" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" 22 | dependencies = [ 23 | "proc-macro2", 24 | "quote", 25 | "syn", 26 | ] 27 | 28 | [[package]] 29 | name = "getrandom" 30 | version = "0.2.3" 31 | source = "registry+https://github.com/rust-lang/crates.io-index" 32 | checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" 33 | dependencies = [ 34 | "cfg-if", 35 | "libc", 36 | "wasi", 37 | ] 38 | 39 | [[package]] 40 | name = "libc" 41 | version = "0.2.95" 42 | source = "registry+https://github.com/rust-lang/crates.io-index" 43 | checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36" 
44 | 45 | [[package]] 46 | name = "num_enum" 47 | version = "0.5.1" 48 | source = "registry+https://github.com/rust-lang/crates.io-index" 49 | checksum = "226b45a5c2ac4dd696ed30fa6b94b057ad909c7b7fc2e0d0808192bced894066" 50 | dependencies = [ 51 | "derivative", 52 | "num_enum_derive", 53 | ] 54 | 55 | [[package]] 56 | name = "num_enum_derive" 57 | version = "0.5.1" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | checksum = "1c0fd9eba1d5db0994a239e09c1be402d35622277e35468ba891aa5e3188ce7e" 60 | dependencies = [ 61 | "proc-macro-crate", 62 | "proc-macro2", 63 | "quote", 64 | "syn", 65 | ] 66 | 67 | [[package]] 68 | name = "ppv-lite86" 69 | version = "0.2.10" 70 | source = "registry+https://github.com/rust-lang/crates.io-index" 71 | checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" 72 | 73 | [[package]] 74 | name = "proc-macro-crate" 75 | version = "0.1.5" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "1d6ea3c4595b96363c13943497db34af4460fb474a95c43f4446ad341b8c9785" 78 | dependencies = [ 79 | "toml", 80 | ] 81 | 82 | [[package]] 83 | name = "proc-macro2" 84 | version = "1.0.27" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" 87 | dependencies = [ 88 | "unicode-xid", 89 | ] 90 | 91 | [[package]] 92 | name = "quote" 93 | version = "1.0.9" 94 | source = "registry+https://github.com/rust-lang/crates.io-index" 95 | checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" 96 | dependencies = [ 97 | "proc-macro2", 98 | ] 99 | 100 | [[package]] 101 | name = "rand" 102 | version = "0.8.3" 103 | source = "registry+https://github.com/rust-lang/crates.io-index" 104 | checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e" 105 | dependencies = [ 106 | "libc", 107 | "rand_chacha", 108 | "rand_core", 109 | "rand_hc", 110 | ] 111 | 112 | 
[[package]] 113 | name = "rand_chacha" 114 | version = "0.3.0" 115 | source = "registry+https://github.com/rust-lang/crates.io-index" 116 | checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d" 117 | dependencies = [ 118 | "ppv-lite86", 119 | "rand_core", 120 | ] 121 | 122 | [[package]] 123 | name = "rand_core" 124 | version = "0.6.2" 125 | source = "registry+https://github.com/rust-lang/crates.io-index" 126 | checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7" 127 | dependencies = [ 128 | "getrandom", 129 | ] 130 | 131 | [[package]] 132 | name = "rand_hc" 133 | version = "0.3.0" 134 | source = "registry+https://github.com/rust-lang/crates.io-index" 135 | checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73" 136 | dependencies = [ 137 | "rand_core", 138 | ] 139 | 140 | [[package]] 141 | name = "serde" 142 | version = "1.0.126" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03" 145 | 146 | [[package]] 147 | name = "syn" 148 | version = "1.0.72" 149 | source = "registry+https://github.com/rust-lang/crates.io-index" 150 | checksum = "a1e8cdbefb79a9a5a65e0db8b47b723ee907b7c7f8496c76a1770b5c310bab82" 151 | dependencies = [ 152 | "proc-macro2", 153 | "quote", 154 | "unicode-xid", 155 | ] 156 | 157 | [[package]] 158 | name = "toml" 159 | version = "0.5.8" 160 | source = "registry+https://github.com/rust-lang/crates.io-index" 161 | checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" 162 | dependencies = [ 163 | "serde", 164 | ] 165 | 166 | [[package]] 167 | name = "unicode-xid" 168 | version = "0.2.2" 169 | source = "registry+https://github.com/rust-lang/crates.io-index" 170 | checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" 171 | 172 | [[package]] 173 | name = "wasi" 174 | version = "0.10.2+wasi-snapshot-preview1" 175 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 176 | checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" 177 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cuda-oxide" 3 | version = "0.4.0" 4 | authors = ["Maxwell Bruce "] 5 | edition = "2018" 6 | license = "GPL-3.0-or-later" 7 | repository = "https://github.com/Protryon/cuda-oxide" 8 | description = "cuda-oxide provides a high-level, rusty wrapper over CUDA. It provides the best safety one can get when working with hardware." 9 | keywords = [ "cuda", "gpu", "parallel" ] 10 | 11 | [dependencies] 12 | num_enum = "0.5" 13 | 14 | [dev-dependencies] 15 | rand = "0.8" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # cuda-oxide 3 | 4 | `cuda-oxide` is a *safe* wrapper for [CUDA](https://en.wikipedia.org/wiki/CUDA). With `cuda-oxide` you can execute and coordinate CUDA kernels. 5 | 6 | ## Safety Philosophy 7 | 8 | `cuda-oxide` *does not* offer any safety on the GPU-side of writing CUDA code. It doesn't compile Rust to PTX. `cuda-oxide` offers general CPU-level safety working with the CUDA library and best-availability safety for working with GPU buffers and objects. 
9 | 10 | Examples of things currently considered safe: 11 | * Reading from an uninitialized GPU buffer into host memory 12 | * Some invalid `libcuda` operations that will cause `libcuda` to stop accepting any API calls 13 | * Setting various attributes that can have side effects for an entire device 14 | * Writing to read-only device memory 15 | 16 | ## Supported Features 17 | * Device Management 18 | * Context Management 19 | * Module Management 20 | * JIT compilation of Modules 21 | * Stream Management 22 | * Kernel Execution 23 | * Device Memory read/write 24 | 25 | ## Unsupported Features 26 | * Memory Pools 27 | * Unified Addressing 28 | * Events & Stream Events 29 | * Stream State Polling 30 | * Stream Graph Capturing 31 | * Stream Batch Memory Operations 32 | * External Memory 33 | * Multi-device helper (possible already, but not made easy) 34 | * Graphs 35 | * Textures & Surfaces 36 | * OpenGL/VDPAU/EGL Interoperability 37 | 38 | ## Examples 39 | 40 | See the `examples` directory for usage examples. 
-------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | println!("cargo:rustc-link-lib=dylib=cuda"); 3 | let cuda_path = std::env::var("CUDA_LIB_PATH") 4 | .unwrap_or_else(|_| "/usr/local/cuda-11.3/lib64/".to_string()); 5 | println!("cargo:rustc-link-search=native={}", cuda_path); 6 | } 7 | -------------------------------------------------------------------------------- /examples/enumerate.rs: -------------------------------------------------------------------------------- 1 | use cuda_oxide::*; 2 | 3 | fn main() { 4 | Cuda::init().unwrap(); 5 | let v = Cuda::version().unwrap(); 6 | println!("{:?}", v); 7 | for device in Cuda::list_devices().unwrap() { 8 | println!("name: {}", device.name().unwrap()); 9 | println!("uuid: {}", device.uuid().unwrap()); 10 | println!("memory size: {}", device.memory_size().unwrap()); 11 | println!( 12 | "clock rate: {}", 13 | device 14 | .get_attribute(DeviceAttribute::MemoryClockRate) 15 | .unwrap() 16 | ); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /examples/matrix_mul/main.rs: -------------------------------------------------------------------------------- 1 | use cuda_oxide::*; 2 | use rand::{thread_rng, Rng}; 3 | 4 | const BLOCK_SIZE: u32 = 32; 5 | const A_WIDTH: usize = BLOCK_SIZE as usize * 40; 6 | const A_HEIGHT: usize = BLOCK_SIZE as usize * 60; 7 | const B_WIDTH: usize = BLOCK_SIZE as usize * 40; 8 | const B_HEIGHT: usize = BLOCK_SIZE as usize * 40; 9 | const C_WIDTH: usize = B_WIDTH; 10 | const C_HEIGHT: usize = A_HEIGHT; 11 | 12 | fn matrix_bytes(input: &[f64]) -> &[u8] { 13 | unsafe { std::slice::from_raw_parts(input.as_ptr() as *const u8, input.len() * 8) } 14 | } 15 | 16 | fn bytes_matrix(input: &[u8]) -> &[f64] { 17 | unsafe { std::slice::from_raw_parts(input.as_ptr() as *const f64, input.len() / 8) } 18 | } 19 | 20 | fn main() 
{ 21 | Cuda::init().unwrap(); 22 | let v = Cuda::version().unwrap(); 23 | println!("Using CUDA {}.{}", v.major, v.minor); 24 | let device = Cuda::list_devices().unwrap(); 25 | let device = device.first().unwrap(); 26 | println!("using device: {}", device.name().unwrap()); 27 | 28 | // normally this would be built by a build script, but examples in cargo don't seem to support this 29 | // nvcc matrixMul_kernel.cu -o matrixMul_kernel.fatbin -fatbin 30 | let kernel = include_bytes!("./matrixMul_kernel.fatbin"); 31 | let mut context = Context::new(device).unwrap(); 32 | let handle = context.enter().unwrap(); 33 | let module = Module::load(&handle, &kernel[..]).unwrap(); 34 | let function = module.get_function("matrixMul_bs32_64bit").unwrap(); 35 | 36 | let mut mat_a = vec![0.0; A_WIDTH * A_HEIGHT]; 37 | let mut mat_b = vec![0.0; B_WIDTH * B_HEIGHT]; 38 | 39 | for i in 0..mat_a.len() { 40 | mat_a[i] = thread_rng().gen_range(0.0..1.0); 41 | } 42 | for i in 0..mat_b.len() { 43 | mat_b[i] = thread_rng().gen_range(0.0..1.0); 44 | } 45 | 46 | let device_mat_a = DeviceBox::new(&handle, matrix_bytes(&mat_a[..])).unwrap(); 47 | let device_mat_b = DeviceBox::new(&handle, matrix_bytes(&mat_b[..])).unwrap(); 48 | 49 | let output = DeviceBox::alloc(&handle, C_WIDTH as u64 * C_HEIGHT as u64 * 8).unwrap(); 50 | 51 | handle.context().synchronize().unwrap(); 52 | 53 | let rea = device_mat_a.load().unwrap(); 54 | assert_eq!(&rea[..], matrix_bytes(&mat_a[..])); 55 | 56 | let mut stream = Stream::new(&handle).unwrap(); 57 | unsafe { 58 | stream.launch( 59 | &function, 60 | (C_WIDTH as u32 / BLOCK_SIZE, C_HEIGHT as u32 / BLOCK_SIZE), 61 | (BLOCK_SIZE, BLOCK_SIZE), 62 | 2 * BLOCK_SIZE * BLOCK_SIZE * 8, 63 | ( 64 | &output, 65 | &device_mat_a, 66 | &device_mat_b, 67 | A_WIDTH as usize, 68 | B_WIDTH as usize, 69 | ), 70 | ) 71 | } 72 | .unwrap(); 73 | 74 | stream.callback(|| println!("done")).unwrap(); 75 | 76 | stream.sync().unwrap(); 77 | 78 | let output = output.load().unwrap(); 79 | let 
output = bytes_matrix(&output[..]); 80 | println!("{:?}", output); 81 | } 82 | -------------------------------------------------------------------------------- /examples/matrix_mul/matrixMul_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2017 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 9 | * 10 | */ 11 | 12 | /* Matrix multiplication: C = A * B. 13 | * Device code. 14 | */ 15 | 16 | #ifndef _MATRIXMUL_KERNEL_H_ 17 | #define _MATRIXMUL_KERNEL_H_ 18 | 19 | #include 20 | 21 | #define AS(i, j) As[i][j] 22 | #define BS(i, j) Bs[i][j] 23 | 24 | //////////////////////////////////////////////////////////////////////////////// 25 | //! Matrix multiplication on the device: C = A * B 26 | //! 
wA is A's width and wB is B's width 27 | //////////////////////////////////////////////////////////////////////////////// 28 | template 29 | __device__ void matrixMul(double *C, double *A, double *B, size_type wA, size_type wB) { 30 | // Block index 31 | size_type bx = blockIdx.x; 32 | size_type by = blockIdx.y; 33 | 34 | // Thread index 35 | size_type tx = threadIdx.x; 36 | size_type ty = threadIdx.y; 37 | 38 | // Index of the first sub-matrix of A processed by the block 39 | size_type aBegin = wA * block_size * by; 40 | 41 | // Index of the last sub-matrix of A processed by the block 42 | size_type aEnd = aBegin + wA - 1; 43 | 44 | // Step size used to iterate through the sub-matrices of A 45 | size_type aStep = block_size; 46 | 47 | // Index of the first sub-matrix of B processed by the block 48 | size_type bBegin = block_size * bx; 49 | 50 | // Step size used to iterate through the sub-matrices of B 51 | size_type bStep = block_size * wB; 52 | 53 | // Csub is used to store the element of the block sub-matrix 54 | // that is computed by the thread 55 | double Csub = 0; 56 | 57 | // Loop over all the sub-matrices of A and B 58 | // required to compute the block sub-matrix 59 | for (size_type a = aBegin, b = bBegin; a <= aEnd; a += aStep, b += bStep) { 60 | // Declaration of the shared memory array As used to 61 | // store the sub-matrix of A 62 | __shared__ double As[block_size][block_size]; 63 | 64 | // Declaration of the shared memory array Bs used to 65 | // store the sub-matrix of B 66 | __shared__ double Bs[block_size][block_size]; 67 | 68 | // Load the matrices from device memory 69 | // to shared memory; each thread loads 70 | // one element of each matrix 71 | AS(ty, tx) = A[a + wA * ty + tx]; 72 | BS(ty, tx) = B[b + wB * ty + tx]; 73 | 74 | // Synchronize to make sure the matrices are loaded 75 | __syncthreads(); 76 | 77 | // Multiply the two matrices together; 78 | // each thread computes one element 79 | // of the block sub-matrix 80 | #pragma unroll 
81 | 82 | for (size_type k = 0; k < block_size; ++k) Csub += AS(ty, k) * BS(k, tx); 83 | 84 | // Synchronize to make sure that the preceding 85 | // computation is done before loading two new 86 | // sub-matrices of A and B in the next iteration 87 | __syncthreads(); 88 | } 89 | 90 | // Write the block sub-matrix to device memory; 91 | // each thread writes one element 92 | size_type c = wB * block_size * by + block_size * bx; 93 | C[c + wB * ty + tx] = Csub; 94 | } 95 | 96 | // C wrappers around our template kernel 97 | extern "C" __global__ void matrixMul_bs16_64bit(double *C, double *A, double *B, 98 | size_t wA, size_t wB) { 99 | matrixMul<16, size_t>(C, A, B, wA, wB); 100 | } 101 | extern "C" __global__ void matrixMul_bs32_64bit(double *C, double *A, double *B, 102 | size_t wA, size_t wB) { 103 | matrixMul<32, size_t>(C, A, B, wA, wB); 104 | } 105 | 106 | #endif // #ifndef _MATRIXMUL_KERNEL_H_ 107 | -------------------------------------------------------------------------------- /examples/matrix_mul/matrixMul_kernel.fatbin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Protryon/cuda-oxide/021a53a99fa783817ddfb501a2873e2be954e5bb/examples/matrix_mul/matrixMul_kernel.fatbin -------------------------------------------------------------------------------- /examples/matrix_mul_jit/main.rs: -------------------------------------------------------------------------------- 1 | use cuda_oxide::*; 2 | use rand::{thread_rng, Rng}; 3 | 4 | const BLOCK_SIZE: u32 = 32; 5 | const A_WIDTH: usize = BLOCK_SIZE as usize * 40; 6 | const A_HEIGHT: usize = BLOCK_SIZE as usize * 60; 7 | const B_WIDTH: usize = BLOCK_SIZE as usize * 40; 8 | const B_HEIGHT: usize = BLOCK_SIZE as usize * 40; 9 | const C_WIDTH: usize = B_WIDTH; 10 | const C_HEIGHT: usize = A_HEIGHT; 11 | 12 | fn matrix_bytes(input: &[f64]) -> &[u8] { 13 | unsafe { std::slice::from_raw_parts(input.as_ptr() as *const u8, input.len() * 8) } 14 | } 15 | 16 | 
fn bytes_matrix(input: &[u8]) -> &[f64] { 17 | unsafe { std::slice::from_raw_parts(input.as_ptr() as *const f64, input.len() / 8) } 18 | } 19 | 20 | fn main() { 21 | Cuda::init().unwrap(); 22 | let v = Cuda::version().unwrap(); 23 | println!("Using CUDA {}", v); 24 | let device = Cuda::list_devices().unwrap(); 25 | let device = device.first().unwrap(); 26 | println!("using device: {}", device.name().unwrap()); 27 | let device_compute = device.compute_capability().unwrap(); 28 | println!("cuda device compute capability = {}", device_compute); 29 | 30 | let mut context = Context::new(device).unwrap(); 31 | 32 | let handle = context.enter().unwrap(); 33 | 34 | // normally this would be built by a build script, but examples in cargo don't seem to support this 35 | // nvcc matrixMul_kernel.cu -ptx 36 | let kernel = include_bytes!("./matrixMul_kernel.ptx"); 37 | let linked_kernel = Linker::new(&handle, device_compute, LinkerOptions::default()) 38 | .unwrap() 39 | .add("matrixMul_kernel.ptx", LinkerInputType::Ptx, &kernel[..]) 40 | .unwrap(); 41 | let module = linked_kernel.build_module().unwrap(); 42 | 43 | let function = module.get_function("matrixMul_bs32_64bit").unwrap(); 44 | 45 | let mut mat_a = vec![0.0; A_WIDTH * A_HEIGHT]; 46 | let mut mat_b = vec![0.0; B_WIDTH * B_HEIGHT]; 47 | 48 | for i in 0..mat_a.len() { 49 | mat_a[i] = thread_rng().gen_range(0.0..1.0); 50 | } 51 | for i in 0..mat_b.len() { 52 | mat_b[i] = thread_rng().gen_range(0.0..1.0); 53 | } 54 | 55 | let device_mat_a = DeviceBox::new(&handle, matrix_bytes(&mat_a[..])).unwrap(); 56 | let device_mat_b = DeviceBox::new(&handle, matrix_bytes(&mat_b[..])).unwrap(); 57 | 58 | let output = DeviceBox::alloc(&handle, C_WIDTH as u64 * C_HEIGHT as u64 * 8).unwrap(); 59 | 60 | handle.context().synchronize().unwrap(); 61 | 62 | let rea = device_mat_a.load().unwrap(); 63 | assert_eq!(&rea[..], matrix_bytes(&mat_a[..])); 64 | 65 | let mut stream = Stream::new(&handle).unwrap(); 66 | unsafe { 67 | stream.launch( 68 | 
&function, 69 | (C_WIDTH as u32 / BLOCK_SIZE, C_HEIGHT as u32 / BLOCK_SIZE), 70 | (BLOCK_SIZE, BLOCK_SIZE), 71 | 2 * BLOCK_SIZE * BLOCK_SIZE * 8, 72 | ( 73 | &output, 74 | &device_mat_a, 75 | &device_mat_b, 76 | A_WIDTH as usize, 77 | B_WIDTH as usize, 78 | ), 79 | ) 80 | } 81 | .unwrap(); 82 | 83 | stream.callback(|| println!("done")).unwrap(); 84 | 85 | stream.sync().unwrap(); 86 | 87 | let output = output.load().unwrap(); 88 | let output = bytes_matrix(&output[..]); 89 | println!("{:?}", output); 90 | } 91 | -------------------------------------------------------------------------------- /examples/matrix_mul_jit/matrixMul_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2017 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 9 | * 10 | */ 11 | 12 | /* Matrix multiplication: C = A * B. 13 | * Device code. 14 | */ 15 | 16 | #ifndef _MATRIXMUL_KERNEL_H_ 17 | #define _MATRIXMUL_KERNEL_H_ 18 | 19 | #include 20 | 21 | #define AS(i, j) As[i][j] 22 | #define BS(i, j) Bs[i][j] 23 | 24 | //////////////////////////////////////////////////////////////////////////////// 25 | //! Matrix multiplication on the device: C = A * B 26 | //! 
wA is A's width and wB is B's width 27 | //////////////////////////////////////////////////////////////////////////////// 28 | template 29 | __device__ void matrixMul(double *C, double *A, double *B, size_type wA, size_type wB) { 30 | // Block index 31 | size_type bx = blockIdx.x; 32 | size_type by = blockIdx.y; 33 | 34 | // Thread index 35 | size_type tx = threadIdx.x; 36 | size_type ty = threadIdx.y; 37 | 38 | // Index of the first sub-matrix of A processed by the block 39 | size_type aBegin = wA * block_size * by; 40 | 41 | // Index of the last sub-matrix of A processed by the block 42 | size_type aEnd = aBegin + wA - 1; 43 | 44 | // Step size used to iterate through the sub-matrices of A 45 | size_type aStep = block_size; 46 | 47 | // Index of the first sub-matrix of B processed by the block 48 | size_type bBegin = block_size * bx; 49 | 50 | // Step size used to iterate through the sub-matrices of B 51 | size_type bStep = block_size * wB; 52 | 53 | // Csub is used to store the element of the block sub-matrix 54 | // that is computed by the thread 55 | double Csub = 0; 56 | 57 | // Loop over all the sub-matrices of A and B 58 | // required to compute the block sub-matrix 59 | for (size_type a = aBegin, b = bBegin; a <= aEnd; a += aStep, b += bStep) { 60 | // Declaration of the shared memory array As used to 61 | // store the sub-matrix of A 62 | __shared__ double As[block_size][block_size]; 63 | 64 | // Declaration of the shared memory array Bs used to 65 | // store the sub-matrix of B 66 | __shared__ double Bs[block_size][block_size]; 67 | 68 | // Load the matrices from device memory 69 | // to shared memory; each thread loads 70 | // one element of each matrix 71 | AS(ty, tx) = A[a + wA * ty + tx]; 72 | BS(ty, tx) = B[b + wB * ty + tx]; 73 | 74 | // Synchronize to make sure the matrices are loaded 75 | __syncthreads(); 76 | 77 | // Multiply the two matrices together; 78 | // each thread computes one element 79 | // of the block sub-matrix 80 | #pragma unroll 
81 | 82 | for (size_type k = 0; k < block_size; ++k) Csub += AS(ty, k) * BS(k, tx); 83 | 84 | // Synchronize to make sure that the preceding 85 | // computation is done before loading two new 86 | // sub-matrices of A and B in the next iteration 87 | __syncthreads(); 88 | } 89 | 90 | // Write the block sub-matrix to device memory; 91 | // each thread writes one element 92 | size_type c = wB * block_size * by + block_size * bx; 93 | C[c + wB * ty + tx] = Csub; 94 | } 95 | 96 | // C wrappers around our template kernel 97 | extern "C" __global__ void matrixMul_bs16_64bit(double *C, double *A, double *B, 98 | size_t wA, size_t wB) { 99 | matrixMul<16, size_t>(C, A, B, wA, wB); 100 | } 101 | extern "C" __global__ void matrixMul_bs32_64bit(double *C, double *A, double *B, 102 | size_t wA, size_t wB) { 103 | matrixMul<32, size_t>(C, A, B, wA, wB); 104 | } 105 | 106 | #endif // #ifndef _MATRIXMUL_KERNEL_H_ 107 | -------------------------------------------------------------------------------- /examples/matrix_mul_jit/matrixMul_kernel.ptx: -------------------------------------------------------------------------------- 1 | // 2 | // Generated by NVIDIA NVVM Compiler 3 | // 4 | // Compiler Build ID: CL-29920130 5 | // Cuda compilation tools, release 11.3, V11.3.109 6 | // Based on NVVM 7.0.1 7 | // 8 | 9 | .version 7.3 10 | .target sm_86 11 | .address_size 64 12 | 13 | // .globl matrixMul_bs16_64bit 14 | // _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2As has been demoted 15 | // _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2Bs has been demoted 16 | // _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2As has been demoted 17 | // _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2Bs has been demoted 18 | 19 | .visible .entry matrixMul_bs16_64bit( 20 | .param .u64 matrixMul_bs16_64bit_param_0, 21 | .param .u64 matrixMul_bs16_64bit_param_1, 22 | .param .u64 matrixMul_bs16_64bit_param_2, 23 | .param .u64 matrixMul_bs16_64bit_param_3, 24 | .param .u64 matrixMul_bs16_64bit_param_4 25 | ) 26 | { 27 | .reg .pred %p<3>; 28 | 
.reg .b32 %r<16>; 29 | .reg .f64 %fd<57>; 30 | .reg .b64 %rd<41>; 31 | // demoted variable 32 | .shared .align 8 .b8 _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2As[2048]; 33 | // demoted variable 34 | .shared .align 8 .b8 _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2Bs[2048]; 35 | 36 | ld.param.u64 %rd18, [matrixMul_bs16_64bit_param_0]; 37 | ld.param.u64 %rd19, [matrixMul_bs16_64bit_param_1]; 38 | ld.param.u64 %rd20, [matrixMul_bs16_64bit_param_2]; 39 | ld.param.u64 %rd21, [matrixMul_bs16_64bit_param_3]; 40 | ld.param.u64 %rd22, [matrixMul_bs16_64bit_param_4]; 41 | mov.u32 %r5, %ctaid.x; 42 | mov.u32 %r6, %ctaid.y; 43 | cvt.u64.u32 %rd1, %r6; 44 | mov.u32 %r7, %tid.x; 45 | cvt.u64.u32 %rd2, %r7; 46 | mov.u32 %r8, %tid.y; 47 | cvt.u64.u32 %rd3, %r8; 48 | mul.lo.s64 %rd23, %rd21, %rd1; 49 | shl.b64 %rd40, %rd23, 4; 50 | add.s64 %rd24, %rd21, -1; 51 | add.s64 %rd5, %rd24, %rd40; 52 | setp.lt.u64 %p1, %rd5, %rd24; 53 | mul.wide.u32 %rd6, %r5, 16; 54 | shl.b64 %rd7, %rd22, 4; 55 | mul.lo.s64 %rd8, %rd3, %rd22; 56 | mov.f64 %fd56, 0d0000000000000000; 57 | @%p1 bra $L__BB0_3; 58 | 59 | cvt.u32.u64 %r9, %rd3; 60 | mul.lo.s64 %rd25, %rd3, %rd21; 61 | add.s64 %rd9, %rd25, %rd2; 62 | cvt.u32.u64 %r10, %rd2; 63 | shl.b32 %r11, %r9, 7; 64 | mov.u32 %r12, _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2As; 65 | add.s32 %r3, %r12, %r11; 66 | shl.b32 %r13, %r10, 3; 67 | add.s32 %r1, %r3, %r13; 68 | add.s64 %rd10, %rd8, %rd2; 69 | mov.u32 %r14, _ZZ9matrixMulILi16EmEvPdS0_S0_T0_S1_E2Bs; 70 | add.s32 %r15, %r14, %r11; 71 | add.s32 %r2, %r15, %r13; 72 | add.s32 %r4, %r14, %r13; 73 | cvta.to.global.u64 %rd12, %rd19; 74 | cvta.to.global.u64 %rd13, %rd20; 75 | mov.f64 %fd56, 0d0000000000000000; 76 | mov.u64 %rd39, %rd6; 77 | 78 | $L__BB0_2: 79 | add.s64 %rd26, %rd9, %rd40; 80 | shl.b64 %rd27, %rd26, 3; 81 | add.s64 %rd28, %rd12, %rd27; 82 | ld.global.f64 %fd6, [%rd28]; 83 | st.shared.f64 [%r1], %fd6; 84 | add.s64 %rd29, %rd10, %rd39; 85 | shl.b64 %rd30, %rd29, 3; 86 | add.s64 %rd31, %rd13, %rd30; 87 | 
ld.global.f64 %fd7, [%rd31]; 88 | st.shared.f64 [%r2], %fd7; 89 | bar.sync 0; 90 | ld.shared.f64 %fd8, [%r4]; 91 | ld.shared.f64 %fd9, [%r3]; 92 | fma.rn.f64 %fd10, %fd9, %fd8, %fd56; 93 | ld.shared.f64 %fd11, [%r4+128]; 94 | ld.shared.f64 %fd12, [%r3+8]; 95 | fma.rn.f64 %fd13, %fd12, %fd11, %fd10; 96 | ld.shared.f64 %fd14, [%r4+256]; 97 | ld.shared.f64 %fd15, [%r3+16]; 98 | fma.rn.f64 %fd16, %fd15, %fd14, %fd13; 99 | ld.shared.f64 %fd17, [%r4+384]; 100 | ld.shared.f64 %fd18, [%r3+24]; 101 | fma.rn.f64 %fd19, %fd18, %fd17, %fd16; 102 | ld.shared.f64 %fd20, [%r4+512]; 103 | ld.shared.f64 %fd21, [%r3+32]; 104 | fma.rn.f64 %fd22, %fd21, %fd20, %fd19; 105 | ld.shared.f64 %fd23, [%r4+640]; 106 | ld.shared.f64 %fd24, [%r3+40]; 107 | fma.rn.f64 %fd25, %fd24, %fd23, %fd22; 108 | ld.shared.f64 %fd26, [%r4+768]; 109 | ld.shared.f64 %fd27, [%r3+48]; 110 | fma.rn.f64 %fd28, %fd27, %fd26, %fd25; 111 | ld.shared.f64 %fd29, [%r4+896]; 112 | ld.shared.f64 %fd30, [%r3+56]; 113 | fma.rn.f64 %fd31, %fd30, %fd29, %fd28; 114 | ld.shared.f64 %fd32, [%r4+1024]; 115 | ld.shared.f64 %fd33, [%r3+64]; 116 | fma.rn.f64 %fd34, %fd33, %fd32, %fd31; 117 | ld.shared.f64 %fd35, [%r4+1152]; 118 | ld.shared.f64 %fd36, [%r3+72]; 119 | fma.rn.f64 %fd37, %fd36, %fd35, %fd34; 120 | ld.shared.f64 %fd38, [%r4+1280]; 121 | ld.shared.f64 %fd39, [%r3+80]; 122 | fma.rn.f64 %fd40, %fd39, %fd38, %fd37; 123 | ld.shared.f64 %fd41, [%r4+1408]; 124 | ld.shared.f64 %fd42, [%r3+88]; 125 | fma.rn.f64 %fd43, %fd42, %fd41, %fd40; 126 | ld.shared.f64 %fd44, [%r4+1536]; 127 | ld.shared.f64 %fd45, [%r3+96]; 128 | fma.rn.f64 %fd46, %fd45, %fd44, %fd43; 129 | ld.shared.f64 %fd47, [%r4+1664]; 130 | ld.shared.f64 %fd48, [%r3+104]; 131 | fma.rn.f64 %fd49, %fd48, %fd47, %fd46; 132 | ld.shared.f64 %fd50, [%r4+1792]; 133 | ld.shared.f64 %fd51, [%r3+112]; 134 | fma.rn.f64 %fd52, %fd51, %fd50, %fd49; 135 | ld.shared.f64 %fd53, [%r4+1920]; 136 | ld.shared.f64 %fd54, [%r3+120]; 137 | fma.rn.f64 %fd56, %fd54, %fd53, %fd52; 138 | 
bar.sync 0; 139 | add.s64 %rd39, %rd39, %rd7; 140 | add.s64 %rd40, %rd40, 16; 141 | setp.le.u64 %p2, %rd40, %rd5; 142 | @%p2 bra $L__BB0_2; 143 | 144 | $L__BB0_3: 145 | add.s64 %rd32, %rd6, %rd2; 146 | add.s64 %rd33, %rd32, %rd8; 147 | mul.lo.s64 %rd34, %rd7, %rd1; 148 | add.s64 %rd35, %rd33, %rd34; 149 | cvta.to.global.u64 %rd36, %rd18; 150 | shl.b64 %rd37, %rd35, 3; 151 | add.s64 %rd38, %rd36, %rd37; 152 | st.global.f64 [%rd38], %fd56; 153 | ret; 154 | 155 | } 156 | // .globl matrixMul_bs32_64bit 157 | .visible .entry matrixMul_bs32_64bit( 158 | .param .u64 matrixMul_bs32_64bit_param_0, 159 | .param .u64 matrixMul_bs32_64bit_param_1, 160 | .param .u64 matrixMul_bs32_64bit_param_2, 161 | .param .u64 matrixMul_bs32_64bit_param_3, 162 | .param .u64 matrixMul_bs32_64bit_param_4 163 | ) 164 | { 165 | .reg .pred %p<3>; 166 | .reg .b32 %r<16>; 167 | .reg .f64 %fd<105>; 168 | .reg .b64 %rd<41>; 169 | // demoted variable 170 | .shared .align 8 .b8 _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2As[8192]; 171 | // demoted variable 172 | .shared .align 8 .b8 _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2Bs[8192]; 173 | 174 | ld.param.u64 %rd18, [matrixMul_bs32_64bit_param_0]; 175 | ld.param.u64 %rd19, [matrixMul_bs32_64bit_param_1]; 176 | ld.param.u64 %rd20, [matrixMul_bs32_64bit_param_2]; 177 | ld.param.u64 %rd21, [matrixMul_bs32_64bit_param_3]; 178 | ld.param.u64 %rd22, [matrixMul_bs32_64bit_param_4]; 179 | mov.u32 %r5, %ctaid.x; 180 | mov.u32 %r6, %ctaid.y; 181 | cvt.u64.u32 %rd1, %r6; 182 | mov.u32 %r7, %tid.x; 183 | cvt.u64.u32 %rd2, %r7; 184 | mov.u32 %r8, %tid.y; 185 | cvt.u64.u32 %rd3, %r8; 186 | mul.lo.s64 %rd23, %rd21, %rd1; 187 | shl.b64 %rd40, %rd23, 5; 188 | add.s64 %rd24, %rd21, -1; 189 | add.s64 %rd5, %rd24, %rd40; 190 | setp.lt.u64 %p1, %rd5, %rd24; 191 | mul.wide.u32 %rd6, %r5, 32; 192 | shl.b64 %rd7, %rd22, 5; 193 | mul.lo.s64 %rd8, %rd3, %rd22; 194 | mov.f64 %fd104, 0d0000000000000000; 195 | @%p1 bra $L__BB1_3; 196 | 197 | cvt.u32.u64 %r9, %rd3; 198 | mul.lo.s64 %rd25, 
%rd3, %rd21; 199 | add.s64 %rd9, %rd25, %rd2; 200 | cvt.u32.u64 %r10, %rd2; 201 | shl.b32 %r11, %r9, 8; 202 | mov.u32 %r12, _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2As; 203 | add.s32 %r3, %r12, %r11; 204 | shl.b32 %r13, %r10, 3; 205 | add.s32 %r1, %r3, %r13; 206 | add.s64 %rd10, %rd8, %rd2; 207 | mov.u32 %r14, _ZZ9matrixMulILi32EmEvPdS0_S0_T0_S1_E2Bs; 208 | add.s32 %r15, %r14, %r11; 209 | add.s32 %r2, %r15, %r13; 210 | add.s32 %r4, %r14, %r13; 211 | cvta.to.global.u64 %rd12, %rd19; 212 | cvta.to.global.u64 %rd13, %rd20; 213 | mov.f64 %fd104, 0d0000000000000000; 214 | mov.u64 %rd39, %rd6; 215 | 216 | $L__BB1_2: 217 | add.s64 %rd26, %rd9, %rd40; 218 | shl.b64 %rd27, %rd26, 3; 219 | add.s64 %rd28, %rd12, %rd27; 220 | ld.global.f64 %fd6, [%rd28]; 221 | st.shared.f64 [%r1], %fd6; 222 | add.s64 %rd29, %rd10, %rd39; 223 | shl.b64 %rd30, %rd29, 3; 224 | add.s64 %rd31, %rd13, %rd30; 225 | ld.global.f64 %fd7, [%rd31]; 226 | st.shared.f64 [%r2], %fd7; 227 | bar.sync 0; 228 | ld.shared.f64 %fd8, [%r4]; 229 | ld.shared.f64 %fd9, [%r3]; 230 | fma.rn.f64 %fd10, %fd9, %fd8, %fd104; 231 | ld.shared.f64 %fd11, [%r4+256]; 232 | ld.shared.f64 %fd12, [%r3+8]; 233 | fma.rn.f64 %fd13, %fd12, %fd11, %fd10; 234 | ld.shared.f64 %fd14, [%r4+512]; 235 | ld.shared.f64 %fd15, [%r3+16]; 236 | fma.rn.f64 %fd16, %fd15, %fd14, %fd13; 237 | ld.shared.f64 %fd17, [%r4+768]; 238 | ld.shared.f64 %fd18, [%r3+24]; 239 | fma.rn.f64 %fd19, %fd18, %fd17, %fd16; 240 | ld.shared.f64 %fd20, [%r4+1024]; 241 | ld.shared.f64 %fd21, [%r3+32]; 242 | fma.rn.f64 %fd22, %fd21, %fd20, %fd19; 243 | ld.shared.f64 %fd23, [%r4+1280]; 244 | ld.shared.f64 %fd24, [%r3+40]; 245 | fma.rn.f64 %fd25, %fd24, %fd23, %fd22; 246 | ld.shared.f64 %fd26, [%r4+1536]; 247 | ld.shared.f64 %fd27, [%r3+48]; 248 | fma.rn.f64 %fd28, %fd27, %fd26, %fd25; 249 | ld.shared.f64 %fd29, [%r4+1792]; 250 | ld.shared.f64 %fd30, [%r3+56]; 251 | fma.rn.f64 %fd31, %fd30, %fd29, %fd28; 252 | ld.shared.f64 %fd32, [%r4+2048]; 253 | ld.shared.f64 %fd33, [%r3+64]; 
254 | fma.rn.f64 %fd34, %fd33, %fd32, %fd31; 255 | ld.shared.f64 %fd35, [%r4+2304]; 256 | ld.shared.f64 %fd36, [%r3+72]; 257 | fma.rn.f64 %fd37, %fd36, %fd35, %fd34; 258 | ld.shared.f64 %fd38, [%r4+2560]; 259 | ld.shared.f64 %fd39, [%r3+80]; 260 | fma.rn.f64 %fd40, %fd39, %fd38, %fd37; 261 | ld.shared.f64 %fd41, [%r4+2816]; 262 | ld.shared.f64 %fd42, [%r3+88]; 263 | fma.rn.f64 %fd43, %fd42, %fd41, %fd40; 264 | ld.shared.f64 %fd44, [%r4+3072]; 265 | ld.shared.f64 %fd45, [%r3+96]; 266 | fma.rn.f64 %fd46, %fd45, %fd44, %fd43; 267 | ld.shared.f64 %fd47, [%r4+3328]; 268 | ld.shared.f64 %fd48, [%r3+104]; 269 | fma.rn.f64 %fd49, %fd48, %fd47, %fd46; 270 | ld.shared.f64 %fd50, [%r4+3584]; 271 | ld.shared.f64 %fd51, [%r3+112]; 272 | fma.rn.f64 %fd52, %fd51, %fd50, %fd49; 273 | ld.shared.f64 %fd53, [%r4+3840]; 274 | ld.shared.f64 %fd54, [%r3+120]; 275 | fma.rn.f64 %fd55, %fd54, %fd53, %fd52; 276 | ld.shared.f64 %fd56, [%r4+4096]; 277 | ld.shared.f64 %fd57, [%r3+128]; 278 | fma.rn.f64 %fd58, %fd57, %fd56, %fd55; 279 | ld.shared.f64 %fd59, [%r4+4352]; 280 | ld.shared.f64 %fd60, [%r3+136]; 281 | fma.rn.f64 %fd61, %fd60, %fd59, %fd58; 282 | ld.shared.f64 %fd62, [%r4+4608]; 283 | ld.shared.f64 %fd63, [%r3+144]; 284 | fma.rn.f64 %fd64, %fd63, %fd62, %fd61; 285 | ld.shared.f64 %fd65, [%r4+4864]; 286 | ld.shared.f64 %fd66, [%r3+152]; 287 | fma.rn.f64 %fd67, %fd66, %fd65, %fd64; 288 | ld.shared.f64 %fd68, [%r4+5120]; 289 | ld.shared.f64 %fd69, [%r3+160]; 290 | fma.rn.f64 %fd70, %fd69, %fd68, %fd67; 291 | ld.shared.f64 %fd71, [%r4+5376]; 292 | ld.shared.f64 %fd72, [%r3+168]; 293 | fma.rn.f64 %fd73, %fd72, %fd71, %fd70; 294 | ld.shared.f64 %fd74, [%r4+5632]; 295 | ld.shared.f64 %fd75, [%r3+176]; 296 | fma.rn.f64 %fd76, %fd75, %fd74, %fd73; 297 | ld.shared.f64 %fd77, [%r4+5888]; 298 | ld.shared.f64 %fd78, [%r3+184]; 299 | fma.rn.f64 %fd79, %fd78, %fd77, %fd76; 300 | ld.shared.f64 %fd80, [%r4+6144]; 301 | ld.shared.f64 %fd81, [%r3+192]; 302 | fma.rn.f64 %fd82, %fd81, %fd80, %fd79; 303 | 
ld.shared.f64 %fd83, [%r4+6400]; 304 | ld.shared.f64 %fd84, [%r3+200]; 305 | fma.rn.f64 %fd85, %fd84, %fd83, %fd82; 306 | ld.shared.f64 %fd86, [%r4+6656]; 307 | ld.shared.f64 %fd87, [%r3+208]; 308 | fma.rn.f64 %fd88, %fd87, %fd86, %fd85; 309 | ld.shared.f64 %fd89, [%r4+6912]; 310 | ld.shared.f64 %fd90, [%r3+216]; 311 | fma.rn.f64 %fd91, %fd90, %fd89, %fd88; 312 | ld.shared.f64 %fd92, [%r4+7168]; 313 | ld.shared.f64 %fd93, [%r3+224]; 314 | fma.rn.f64 %fd94, %fd93, %fd92, %fd91; 315 | ld.shared.f64 %fd95, [%r4+7424]; 316 | ld.shared.f64 %fd96, [%r3+232]; 317 | fma.rn.f64 %fd97, %fd96, %fd95, %fd94; 318 | ld.shared.f64 %fd98, [%r4+7680]; 319 | ld.shared.f64 %fd99, [%r3+240]; 320 | fma.rn.f64 %fd100, %fd99, %fd98, %fd97; 321 | ld.shared.f64 %fd101, [%r4+7936]; 322 | ld.shared.f64 %fd102, [%r3+248]; 323 | fma.rn.f64 %fd104, %fd102, %fd101, %fd100; 324 | bar.sync 0; 325 | add.s64 %rd39, %rd39, %rd7; 326 | add.s64 %rd40, %rd40, 32; 327 | setp.le.u64 %p2, %rd40, %rd5; 328 | @%p2 bra $L__BB1_2; 329 | 330 | $L__BB1_3: 331 | add.s64 %rd32, %rd6, %rd2; 332 | add.s64 %rd33, %rd32, %rd8; 333 | mul.lo.s64 %rd34, %rd7, %rd1; 334 | add.s64 %rd35, %rd33, %rd34; 335 | cvta.to.global.u64 %rd36, %rd18; 336 | shl.b64 %rd37, %rd35, 3; 337 | add.s64 %rd38, %rd36, %rd37; 338 | st.global.f64 [%rd38], %fd104; 339 | ret; 340 | 341 | } 342 | 343 | -------------------------------------------------------------------------------- /src/context.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use num_enum::TryFromPrimitive; 3 | use std::{ptr::null_mut, rc::Rc}; 4 | 5 | /// A CUDA application context. 
6 | /// To start interacting with a device, you want to [`Context::enter`] 7 | #[derive(Debug)] 8 | pub struct Context { 9 | pub(crate) inner: *mut sys::CUctx_st, 10 | } 11 | 12 | impl Context { 13 | /// Creates a new [`Context`] for a given [`Device`] 14 | pub fn new(device: &Device) -> CudaResult { 15 | let mut inner = null_mut(); 16 | cuda_error(unsafe { 17 | sys::cuCtxCreate_v2( 18 | &mut inner as *mut _, 19 | sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC, 20 | device.handle, 21 | ) 22 | })?; 23 | Ok(Context { inner }) 24 | } 25 | 26 | /// Gets the API version of the [`Context`]. 27 | /// This is not the compute capability of the device and probably not what you are looking for. See [`Device::compute_capability`] 28 | pub fn version(&self) -> CudaResult { 29 | let mut out = 0u32; 30 | cuda_error(unsafe { sys::cuCtxGetApiVersion(self.inner, &mut out as *mut u32) })?; 31 | Ok(out.into()) 32 | } 33 | 34 | /// Synchronize a [`Context`], running all active handles to completion 35 | pub fn synchronize(&self) -> CudaResult<()> { 36 | cuda_error(unsafe { sys::cuCtxSynchronize() }) 37 | } 38 | 39 | /// Set a CUDA context limit 40 | pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> { 41 | cuda_error(unsafe { sys::cuCtxSetLimit(limit as u32, value as sys::size_t) }) 42 | } 43 | 44 | /// Get a CUDA context limit 45 | pub fn get_limit(&self, limit: LimitType) -> CudaResult { 46 | let mut out: sys::size_t = 0; 47 | cuda_error(unsafe { sys::cuCtxGetLimit(&mut out as *mut sys::size_t, limit as u32) })?; 48 | Ok(out as u64) 49 | } 50 | 51 | /// Enter a [`Context`], consuming a mutable reference to the context, and allowing thread-local operations to happen. 
52 | pub fn enter<'a>(&'a mut self) -> CudaResult>> { 53 | cuda_error(unsafe { sys::cuCtxSetCurrent(self.inner) })?; 54 | Ok(Rc::new(Handle { 55 | context: self, 56 | // async_stream_pool: RefCell::new(vec![]), 57 | })) 58 | } 59 | } 60 | 61 | impl Drop for Context { 62 | fn drop(&mut self) { 63 | if let Err(e) = cuda_error(unsafe { sys::cuCtxDestroy_v2(self.inner) }) { 64 | eprintln!("CUDA: failed to destroy cuda context: {:?}", e); 65 | } 66 | } 67 | } 68 | 69 | /// A CUDA [`Context`] handle for executing thread-local operations. 70 | pub struct Handle<'a> { 71 | pub(crate) context: &'a mut Context, 72 | // async_stream_pool: RefCell>>, 73 | } 74 | 75 | impl<'a> Handle<'a> { 76 | /// Get an immutable reference to the source context. 77 | pub fn context(&self) -> &Context { 78 | &self.context 79 | } 80 | 81 | // pub(crate) fn get_async_stream(self: &Rc>) -> CudaResult> { 82 | // let mut pool = self.async_stream_pool.borrow_mut(); 83 | // if pool.is_empty() { 84 | // Stream::new(self) 85 | // } else { 86 | // Ok(pool.pop().unwrap()) 87 | // } 88 | // } 89 | 90 | // pub(crate) fn reset_async_stream(self: &Rc>, stream: Stream<'a>) { 91 | // let mut pool = self.async_stream_pool.borrow_mut(); 92 | // pool.push(stream); 93 | // } 94 | } 95 | 96 | impl<'a> Drop for Handle<'a> { 97 | fn drop(&mut self) { 98 | if let Err(e) = cuda_error(unsafe { sys::cuCtxSetCurrent(null_mut()) }) { 99 | eprintln!("CUDA: error dropping context handle: {:?}", e); 100 | } 101 | } 102 | } 103 | 104 | /// Context limit types 105 | #[derive(Clone, Copy, Debug, TryFromPrimitive)] 106 | #[repr(u32)] 107 | pub enum LimitType { 108 | /// GPU thread stack size 109 | StackSize = 0x00, 110 | /// GPU printf FIFO size 111 | PrintfFifoSize = 0x01, 112 | /// GPU malloc heap size 113 | MallocHeapSize = 0x02, 114 | /// GPU device runtime launch synchronize depth 115 | DevRuntimeSyncDepth = 0x03, 116 | /// GPU device runtime pending launch count 117 | DevRuntimePendingLaunchCount = 0x04, 118 | /// A value 
between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint 119 | MaxL2FetchGranularity = 0x05, 120 | /// A size in bytes for L2 persisting lines cache size 121 | PersistingL2CacheSize = 0x06, 122 | } 123 | -------------------------------------------------------------------------------- /src/device.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use num_enum::TryFromPrimitive; 3 | 4 | /// A reference to a CUDA-enabled device 5 | pub struct Device { 6 | pub(crate) handle: i32, 7 | } 8 | 9 | /// Type of native array format 10 | #[derive(Clone, Copy, Debug, TryFromPrimitive)] 11 | #[repr(u32)] 12 | pub enum CudaArrayFormat { 13 | UnsignedInt8 = 0x01, 14 | UnsignedInt16 = 0x02, 15 | UnsignedInt32 = 0x03, 16 | SignedInt8 = 0x08, 17 | SignedInt16 = 0x09, 18 | SignedInt32 = 0x0a, 19 | Half = 0x10, 20 | Float = 0x20, 21 | Nv12 = 0xb0, 22 | } 23 | 24 | impl Device { 25 | /// Fetches a human-readable name from the device 26 | pub fn name(&self) -> CudaResult { 27 | let mut buf = [0u8; 256]; 28 | cuda_error(unsafe { sys::cuDeviceGetName(buf.as_mut_ptr() as *mut i8, 256, self.handle) })?; 29 | Ok( 30 | String::from_utf8_lossy(&buf[..buf.iter().position(|x| *x == 0).unwrap_or(0)]) 31 | .into_owned(), 32 | ) 33 | } 34 | 35 | /// Gets a UUID from the device 36 | pub fn uuid(&self) -> CudaResult { 37 | let mut out = 0u128; 38 | cuda_error(unsafe { sys::cuDeviceGetUuid(&mut out as *mut u128 as *mut _, self.handle) })?; 39 | Ok(out) 40 | } 41 | 42 | /// Gets the total available memory size of the device, in bytes 43 | pub fn memory_size(&self) -> CudaResult { 44 | let mut memory_size = 0usize; 45 | cuda_error(unsafe { 46 | sys::cuDeviceTotalMem_v2(&mut memory_size as *mut usize as *mut _, self.handle) 47 | })?; 48 | Ok(memory_size) 49 | } 50 | 51 | /// Gets a current attribute value for the device 52 | pub fn get_attribute(&self, attribute: DeviceAttribute) -> CudaResult { 53 | let mut out 
= 0i32; 54 | cuda_error(unsafe { 55 | sys::cuDeviceGetAttribute(&mut out as *mut i32, attribute as u32, self.handle) 56 | })?; 57 | Ok(out) 58 | } 59 | 60 | /// Gets the compute capability of the device 61 | pub fn compute_capability(&self) -> CudaResult { 62 | Ok(CudaVersion { 63 | major: self.get_attribute(DeviceAttribute::ComputeCapabilityMajor)? as u32, 64 | minor: self.get_attribute(DeviceAttribute::ComputeCapabilityMinor)? as u32, 65 | }) 66 | } 67 | 68 | /// Calculates the linear max width of 1D textures for a given native array format 69 | pub fn get_texture_1d_linear_max_width( 70 | &self, 71 | format: CudaArrayFormat, 72 | channels: u32, 73 | ) -> CudaResult { 74 | let mut out = 0usize; 75 | cuda_error(unsafe { 76 | sys::cuDeviceGetTexture1DLinearMaxWidth( 77 | &mut out as *mut usize as *mut _, 78 | format as u32, 79 | channels, 80 | self.handle, 81 | ) 82 | })?; 83 | Ok(out) 84 | } 85 | } 86 | 87 | impl Cuda { 88 | /// List all CUDA-enabled devices on the host 89 | pub fn list_devices() -> CudaResult> { 90 | let mut count = 0i32; 91 | cuda_error(unsafe { sys::cuDeviceGetCount(&mut count as *mut i32) })?; 92 | let mut out = Vec::with_capacity(count as usize); 93 | for i in 0..count { 94 | let mut device = Device { handle: 0 }; 95 | cuda_error(unsafe { sys::cuDeviceGet(&mut device.handle as *mut i32, i) })?; 96 | out.push(device); 97 | } 98 | Ok(out) 99 | } 100 | } 101 | 102 | /// A [`Device`]-specific attribute type 103 | #[derive(Clone, Copy, Debug, TryFromPrimitive)] 104 | #[repr(u32)] 105 | pub enum DeviceAttribute { 106 | MaxThreadsPerBlock = 1, 107 | MaxBlockDimX = 2, 108 | MaxBlockDimY = 3, 109 | MaxBlockDimZ = 4, 110 | MaxGridDimX = 5, 111 | MaxGridDimY = 6, 112 | MaxGridDimZ = 7, 113 | SharedMemoryPerBlock = 8, 114 | TotalConstantMemory = 9, 115 | WarpSize = 10, 116 | MaxPitch = 11, 117 | RegistersPerBlock = 12, 118 | ClockRate = 13, 119 | TextureAlignment = 14, 120 | GpuOverlap = 15, 121 | MultiprocessorCount = 16, 122 | KernelExecTimeout = 17, 
123 | Integrated = 18, 124 | CanMapHostMemory = 19, 125 | ComputeMode = 20, 126 | MaximumTexture1dWidth = 21, 127 | MaximumTexture2dWidth = 22, 128 | MaximumTexture2dHeight = 23, 129 | MaximumTexture3dWidth = 24, 130 | MaximumTexture3dHeight = 25, 131 | MaximumTexture3dDepth = 26, 132 | MaximumTexture2dArrayWidth = 27, 133 | MaximumTexture2dArrayHeight = 28, 134 | MaximumTexture2dArrayNumslices = 29, 135 | SurfaceAlignment = 30, 136 | ConcurrentKernels = 31, 137 | EccEnabled = 32, 138 | PciBusId = 33, 139 | PciDeviceId = 34, 140 | TccDriver = 35, 141 | MemoryClockRate = 36, 142 | GlobalMemoryBusWidth = 37, 143 | L2CacheSize = 38, 144 | MaxThreadsPerMultiprocessor = 39, 145 | AsyncEngineCount = 40, 146 | UnifiedAddressing = 41, 147 | MaximumTexture1dLayeredWidth = 42, 148 | MaximumTexture1dLayeredLayers = 43, 149 | CanTex2dGather = 44, 150 | MaximumTexture2dGatherWidth = 45, 151 | MaximumTexture2dGatherHeight = 46, 152 | MaximumTexture3dWidthAlternate = 47, 153 | MaximumTexture3dHeightAlternate = 48, 154 | MaximumTexture3dDepthAlternate = 49, 155 | PciDomainId = 50, 156 | TexturePitchAlignment = 51, 157 | MaximumTexturecubemapWidth = 52, 158 | MaximumTexturecubemapLayeredWidth = 53, 159 | MaximumTexturecubemapLayeredLayers = 54, 160 | MaximumSurface1dWidth = 55, 161 | MaximumSurface2dWidth = 56, 162 | MaximumSurface2dHeight = 57, 163 | MaximumSurface3dWidth = 58, 164 | MaximumSurface3dHeight = 59, 165 | MaximumSurface3dDepth = 60, 166 | MaximumSurface1dLayeredWidth = 61, 167 | MaximumSurface1dLayeredLayers = 62, 168 | MaximumSurface2dLayeredWidth = 63, 169 | MaximumSurface2dLayeredHeight = 64, 170 | MaximumSurface2dLayeredLayers = 65, 171 | MaximumSurfacecubemapWidth = 66, 172 | MaximumSurfacecubemapLayeredWidth = 67, 173 | MaximumSurfacecubemapLayeredLayers = 68, 174 | MaximumTexture1dLinearWidth = 69, 175 | MaximumTexture2dLinearWidth = 70, 176 | MaximumTexture2dLinearHeight = 71, 177 | MaximumTexture2dLinearPitch = 72, 178 | MaximumTexture2dMipmappedWidth = 73, 179
| MaximumTexture2dMipmappedHeight = 74, 180 | ComputeCapabilityMajor = 75, 181 | ComputeCapabilityMinor = 76, 182 | MaximumTexture1dMipmappedWidth = 77, 183 | StreamPrioritiesSupported = 78, 184 | GlobalL1CacheSupported = 79, 185 | LocalL1CacheSupported = 80, 186 | MaxSharedMemoryPerMultiprocessor = 81, 187 | MaxRegistersPerMultiprocessor = 82, 188 | ManagedMemory = 83, 189 | MultiGpuBoard = 84, 190 | MultiGpuBoardGroupId = 85, 191 | HostNativeAtomicSupported = 86, 192 | SingleToDoublePrecisionPerfRatio = 87, 193 | PageableMemoryAccess = 88, 194 | ConcurrentManagedAccess = 89, 195 | ComputePreemptionSupported = 90, 196 | CanUseHostPointerForRegisteredMem = 91, 197 | CanUseStreamMemOps = 92, 198 | CanUse64BitStreamMemOps = 93, 199 | CanUseStreamWaitValueNor = 94, 200 | CooperativeLaunch = 95, 201 | CooperativeMultiDeviceLaunch = 96, 202 | MaxSharedMemoryPerBlockOptin = 97, 203 | CanFlushRemoteWrites = 98, 204 | HostRegisterSupported = 99, 205 | PageableMemoryAccessUsesHostPageTables = 100, 206 | DirectManagedMemAccessFromHost = 101, 207 | VirtualMemoryManagementSupported = 102, 208 | HandleTypePosixFileDescriptorSupported = 103, 209 | HandleTypeWin32HandleSupported = 104, 210 | HandleTypeWin32KmtHandleSupported = 105, 211 | MaxBlocksPerMultiprocessor = 106, 212 | GenericCompressionSupported = 107, 213 | MaxPersistingL2CacheSize = 108, 214 | MaxAccessPolicyWindowSize = 109, 215 | GpuDirectRdmaWithCudaVmmSupported = 110, 216 | ReservedSharedMemoryPerBlock = 111, 217 | SparseCudaArraySupported = 112, 218 | ReadOnlyHostRegisterSupported = 113, 219 | TimelineSemaphoreInteropSupported = 114, 220 | MemoryPoolsSupported = 115, 221 | GpuDirectRdmaSupported = 116, 222 | GpuDirectRdmaFlushWritesOptions = 117, 223 | GpuDirectRdmaWritesOrdering = 118, 224 | MempoolSupportedHandleTypes = 119, 225 | } 226 | -------------------------------------------------------------------------------- /src/dim3.rs: -------------------------------------------------------------------------------- 
1 | use std::ops::{Deref, DerefMut}; 2 | 3 | /// A dimensional value equivalent to a 3-tuple of u32 4 | pub struct Dim3(pub (u32, u32, u32)); 5 | 6 | impl Deref for Dim3 { 7 | type Target = (u32, u32, u32); 8 | 9 | fn deref(&self) -> &Self::Target { 10 | &self.0 11 | } 12 | } 13 | 14 | impl DerefMut for Dim3 { 15 | fn deref_mut(&mut self) -> &mut Self::Target { 16 | &mut self.0 17 | } 18 | } 19 | 20 | impl Into<(u32, u32, u32)> for Dim3 { 21 | fn into(self) -> (u32, u32, u32) { 22 | self.0 23 | } 24 | } 25 | 26 | impl From<(u32, u32, u32)> for Dim3 { 27 | fn from(inner: (u32, u32, u32)) -> Self { 28 | Self(inner) 29 | } 30 | } 31 | 32 | impl From<(u32, u32)> for Dim3 { 33 | fn from((x, y): (u32, u32)) -> Self { 34 | Self((x, y, 1)) 35 | } 36 | } 37 | 38 | impl From<(u32,)> for Dim3 { 39 | fn from((x,): (u32,)) -> Self { 40 | Self((x, 1, 1)) 41 | } 42 | } 43 | 44 | impl From for Dim3 { 45 | fn from(x: u32) -> Self { 46 | Self((x, 1, 1)) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use num_enum::TryFromPrimitive; 2 | use std::{ 3 | error::Error, 4 | fmt::{self, Debug}, 5 | }; 6 | 7 | /// A device-sourced or `libcuda`-sourced error code 8 | #[derive(Debug, Copy, Clone, TryFromPrimitive)] 9 | #[repr(u32)] 10 | pub enum ErrorCode { 11 | #[doc = "The API call returned with no errors. 
In the case of query calls, this"] 12 | #[doc = "also means that the operation being queried is complete (see"] 13 | #[doc = "::cuEventQuery() and ::cuStreamQuery())."] 14 | Success = 0, 15 | #[doc = "This indicates that one or more of the parameters passed to the API call"] 16 | #[doc = "is not within an acceptable range of values."] 17 | InvalidValue = 1, 18 | #[doc = "The API call failed because it was unable to allocate enough memory to"] 19 | #[doc = "perform the requested operation."] 20 | OutOfMemory = 2, 21 | #[doc = "This indicates that the CUDA driver has not been initialized with"] 22 | #[doc = "::cuInit() or that initialization has failed."] 23 | NotInitialized = 3, 24 | #[doc = "This indicates that the CUDA driver is in the process of shutting down."] 25 | Deinitialized = 4, 26 | #[doc = "This indicates profiler is not initialized for this run. This can"] 27 | #[doc = "happen when the application is running with external profiling tools"] 28 | #[doc = "like visual profiler."] 29 | ProfilerDisabled = 5, 30 | #[doc = "\\deprecated"] 31 | #[doc = "This error return is deprecated as of CUDA 5.0. It is no longer an error"] 32 | #[doc = "to attempt to enable/disable the profiling via ::cuProfilerStart or"] 33 | #[doc = "::cuProfilerStop without initialization."] 34 | ProfilerNotInitialized = 6, 35 | #[doc = "\\deprecated"] 36 | #[doc = "This error return is deprecated as of CUDA 5.0. It is no longer an error"] 37 | #[doc = "to call cuProfilerStart() when profiling is already enabled."] 38 | ProfilerAlreadyStarted = 7, 39 | #[doc = "\\deprecated"] 40 | #[doc = "This error return is deprecated as of CUDA 5.0. It is no longer an error"] 41 | #[doc = "to call cuProfilerStop() when profiling is already disabled."] 42 | ProfilerAlreadyStopped = 8, 43 | #[doc = "This indicates that the CUDA driver that the application has loaded is a"] 44 | #[doc = "stub library. 
Applications that run with the stub rather than a real"] 45 | #[doc = "driver loaded will result in CUDA API returning this error."] 46 | StubLibrary = 34, 47 | #[doc = "This indicates that no CUDA-capable devices were detected by the installed"] 48 | #[doc = "CUDA driver."] 49 | NoDevice = 100, 50 | #[doc = "This indicates that the device ordinal supplied by the user does not"] 51 | #[doc = "correspond to a valid CUDA device."] 52 | InvalidDevice = 101, 53 | #[doc = "This error indicates that the Grid license is not applied."] 54 | DeviceNotLicensed = 102, 55 | #[doc = "This indicates that the device kernel image is invalid. This can also"] 56 | #[doc = "indicate an invalid CUDA module."] 57 | InvalidImage = 200, 58 | #[doc = "This most frequently indicates that there is no context bound to the"] 59 | #[doc = "current thread. This can also be returned if the context passed to an"] 60 | #[doc = "API call is not a valid handle (such as a context that has had"] 61 | #[doc = "::cuCtxDestroy() invoked on it). This can also be returned if a user"] 62 | #[doc = "mixes different API versions (i.e. 3010 context with 3020 API calls)."] 63 | #[doc = "See ::cuCtxGetApiVersion() for more details."] 64 | InvalidContext = 201, 65 | #[doc = "This indicated that the context being supplied as a parameter to the"] 66 | #[doc = "API call was already the active context."] 67 | #[doc = "\\deprecated"] 68 | #[doc = "This error return is deprecated as of CUDA 3.2. 
It is no longer an"] 69 | #[doc = "error to attempt to push the active context via ::cuCtxPushCurrent()."] 70 | ContextAlreadyCurrent = 202, 71 | #[doc = "This indicates that a map or register operation has failed."] 72 | MapFailed = 205, 73 | #[doc = "This indicates that an unmap or unregister operation has failed."] 74 | UnmapFailed = 206, 75 | #[doc = "This indicates that the specified array is currently mapped and thus"] 76 | #[doc = "cannot be destroyed."] 77 | ArrayIsMapped = 207, 78 | #[doc = "This indicates that the resource is already mapped."] 79 | AlreadyMapped = 208, 80 | #[doc = "This indicates that there is no kernel image available that is suitable"] 81 | #[doc = "for the device. This can occur when a user specifies code generation"] 82 | #[doc = "options for a particular CUDA source file that do not include the"] 83 | #[doc = "corresponding device configuration."] 84 | NoBinaryForGpu = 209, 85 | #[doc = "This indicates that a resource has already been acquired."] 86 | AlreadyAcquired = 210, 87 | #[doc = "This indicates that a resource is not mapped."] 88 | NotMapped = 211, 89 | #[doc = "This indicates that a mapped resource is not available for access as an"] 90 | #[doc = "array."] 91 | NotMappedAsArray = 212, 92 | #[doc = "This indicates that a mapped resource is not available for access as a"] 93 | #[doc = "pointer."] 94 | NotMappedAsPointer = 213, 95 | #[doc = "This indicates that an uncorrectable ECC error was detected during"] 96 | #[doc = "execution."] 97 | EccUncorrectable = 214, 98 | #[doc = "This indicates that the ::CUlimit passed to the API call is not"] 99 | #[doc = "supported by the active device."] 100 | UnsupportedLimit = 215, 101 | #[doc = "This indicates that the ::CUcontext passed to the API call can"] 102 | #[doc = "only be bound to a single CPU thread at a time but is already"] 103 | #[doc = "bound to a CPU thread."] 104 | ContextAlreadyInUse = 216, 105 | #[doc = "This indicates that peer access is not supported across the 
given"] 106 | #[doc = "devices."] 107 | PeerAccessUnsupported = 217, 108 | #[doc = "This indicates that a PTX JIT compilation failed."] 109 | InvalidPtx = 218, 110 | #[doc = "This indicates an error with OpenGL or DirectX context."] 111 | InvalidGraphicsContext = 219, 112 | #[doc = "This indicates that an uncorrectable NVLink error was detected during the"] 113 | #[doc = "execution."] 114 | NvlinkUncorrectable = 220, 115 | #[doc = "This indicates that the PTX JIT compiler library was not found."] 116 | JitCompilerNotFound = 221, 117 | #[doc = "This indicates that the provided PTX was compiled with an unsupported toolchain."] 118 | UnsupportedPtxVersion = 222, 119 | #[doc = "This indicates that the PTX JIT compilation was disabled."] 120 | JitCompilationDisabled = 223, 121 | #[doc = "This indicates that the device kernel source is invalid."] 122 | InvalidSource = 300, 123 | #[doc = "This indicates that the file specified was not found."] 124 | FileNotFound = 301, 125 | #[doc = "This indicates that a link to a shared object failed to resolve."] 126 | SharedObjectSymbolNotFound = 302, 127 | #[doc = "This indicates that initialization of a shared object failed."] 128 | SharedObjectInitFailed = 303, 129 | #[doc = "This indicates that an OS call failed."] 130 | OperatingSystem = 304, 131 | #[doc = "This indicates that a resource handle passed to the API call was not"] 132 | #[doc = "valid. Resource handles are opaque types like ::CUstream and ::CUevent."] 133 | InvalidHandle = 400, 134 | #[doc = "This indicates that a resource required by the API call is not in a"] 135 | #[doc = "valid state to perform the requested operation."] 136 | IllegalState = 401, 137 | #[doc = "This indicates that a named symbol was not found. 
Examples of symbols"] 138 | #[doc = "are global/constant variable names, driver function names, texture names,"] 139 | #[doc = "and surface names."] 140 | NotFound = 500, 141 | #[doc = "This indicates that asynchronous operations issued previously have not"] 142 | #[doc = "completed yet. This result is not actually an error, but must be indicated"] 143 | #[doc = "differently than ::CUDA_SUCCESS (which indicates completion). Calls that"] 144 | #[doc = "may return this value include ::cuEventQuery() and ::cuStreamQuery()."] 145 | NotReady = 600, 146 | #[doc = "While executing a kernel, the device encountered a"] 147 | #[doc = "load or store instruction on an invalid memory address."] 148 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 149 | #[doc = "will return the same error. To continue using CUDA, the process must be terminated"] 150 | #[doc = "and relaunched."] 151 | IllegalAddress = 700, 152 | #[doc = "This indicates that a launch did not occur because it did not have"] 153 | #[doc = "appropriate resources. This error usually indicates that the user has"] 154 | #[doc = "attempted to pass too many arguments to the device kernel, or the"] 155 | #[doc = "kernel launch specifies too many threads for the kernel's register"] 156 | #[doc = "count. Passing arguments of the wrong size (i.e. a 64-bit pointer"] 157 | #[doc = "when a 32-bit int is expected) is equivalent to passing too many"] 158 | #[doc = "arguments and can also result in this error."] 159 | LaunchOutOfResources = 701, 160 | #[doc = "This indicates that the device kernel took too long to execute. This can"] 161 | #[doc = "only occur if timeouts are enabled - see the device attribute"] 162 | #[doc = "::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT for more information."] 163 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 164 | #[doc = "will return the same error. 
To continue using CUDA, the process must be terminated"] 165 | #[doc = "and relaunched."] 166 | LaunchTimeout = 702, 167 | #[doc = "This error indicates a kernel launch that uses an incompatible texturing"] 168 | #[doc = "mode."] 169 | LaunchIncompatibleTexturing = 703, 170 | #[doc = "This error indicates that a call to ::cuCtxEnablePeerAccess() is"] 171 | #[doc = "trying to re-enable peer access to a context which has already"] 172 | #[doc = "had peer access to it enabled."] 173 | PeerAccessAlreadyEnabled = 704, 174 | #[doc = "This error indicates that ::cuCtxDisablePeerAccess() is"] 175 | #[doc = "trying to disable peer access which has not been enabled yet"] 176 | #[doc = "via ::cuCtxEnablePeerAccess()."] 177 | PeerAccessNotEnabled = 705, 178 | #[doc = "This error indicates that the primary context for the specified device"] 179 | #[doc = "has already been initialized."] 180 | PrimaryContextActive = 708, 181 | #[doc = "This error indicates that the context current to the calling thread"] 182 | #[doc = "has been destroyed using ::cuCtxDestroy, or is a primary context which"] 183 | #[doc = "has not yet been initialized."] 184 | ContextIsDestroyed = 709, 185 | #[doc = "A device-side assert triggered during kernel execution. The context"] 186 | #[doc = "cannot be used anymore, and must be destroyed. 
All existing device"] 187 | #[doc = "memory allocations from this context are invalid and must be"] 188 | #[doc = "reconstructed if the program is to continue using CUDA."] 189 | Assert = 710, 190 | #[doc = "This error indicates that the hardware resources required to enable"] 191 | #[doc = "peer access have been exhausted for one or more of the devices"] 192 | #[doc = "passed to ::cuCtxEnablePeerAccess()."] 193 | TooManyPeers = 711, 194 | #[doc = "This error indicates that the memory range passed to ::cuMemHostRegister()"] 195 | #[doc = "has already been registered."] 196 | HostMemoryAlreadyRegistered = 712, 197 | #[doc = "This error indicates that the pointer passed to ::cuMemHostUnregister()"] 198 | #[doc = "does not correspond to any currently registered memory region."] 199 | HostMemoryNotRegistered = 713, 200 | #[doc = "While executing a kernel, the device encountered a stack error."] 201 | #[doc = "This can be due to stack corruption or exceeding the stack size limit."] 202 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 203 | #[doc = "will return the same error. To continue using CUDA, the process must be terminated"] 204 | #[doc = "and relaunched."] 205 | HardwareStackError = 714, 206 | #[doc = "While executing a kernel, the device encountered an illegal instruction."] 207 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 208 | #[doc = "will return the same error. To continue using CUDA, the process must be terminated"] 209 | #[doc = "and relaunched."] 210 | IllegalInstruction = 715, 211 | #[doc = "While executing a kernel, the device encountered a load or store instruction"] 212 | #[doc = "on a memory address which is not aligned."] 213 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 214 | #[doc = "will return the same error. 
To continue using CUDA, the process must be terminated"] 215 | #[doc = "and relaunched."] 216 | MisalignedAddress = 716, 217 | #[doc = "While executing a kernel, the device encountered an instruction"] 218 | #[doc = "which can only operate on memory locations in certain address spaces"] 219 | #[doc = "(global, shared, or local), but was supplied a memory address not"] 220 | #[doc = "belonging to an allowed address space."] 221 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 222 | #[doc = "will return the same error. To continue using CUDA, the process must be terminated"] 223 | #[doc = "and relaunched."] 224 | InvalidAddressSpace = 717, 225 | #[doc = "While executing a kernel, the device program counter wrapped its address space."] 226 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 227 | #[doc = "will return the same error. To continue using CUDA, the process must be terminated"] 228 | #[doc = "and relaunched."] 229 | InvalidPc = 718, 230 | #[doc = "An exception occurred on the device while executing a kernel. Common"] 231 | #[doc = "causes include dereferencing an invalid device pointer and accessing"] 232 | #[doc = "out of bounds shared memory. Less common cases can be system specific - more"] 233 | #[doc = "information about these cases can be found in the system specific user guide."] 234 | #[doc = "This leaves the process in an inconsistent state and any further CUDA work"] 235 | #[doc = "will return the same error. 
To continue using CUDA, the process must be terminated"] 236 | #[doc = "and relaunched."] 237 | LaunchFailed = 719, 238 | #[doc = "This error indicates that the number of blocks launched per grid for a kernel that was"] 239 | #[doc = "launched via either ::cuLaunchCooperativeKernel or ::cuLaunchCooperativeKernelMultiDevice"] 240 | #[doc = "exceeds the maximum number of blocks as allowed by ::cuOccupancyMaxActiveBlocksPerMultiprocessor"] 241 | #[doc = "or ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags times the number of multiprocessors"] 242 | #[doc = "as specified by the device attribute ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT."] 243 | CooperativeLaunchTooLarge = 720, 244 | #[doc = "This error indicates that the attempted operation is not permitted."] 245 | NotPermitted = 800, 246 | #[doc = "This error indicates that the attempted operation is not supported"] 247 | #[doc = "on the current system or device."] 248 | NotSupported = 801, 249 | #[doc = "This error indicates that the system is not yet ready to start any CUDA"] 250 | #[doc = "work. To continue using CUDA, verify the system configuration is in a"] 251 | #[doc = "valid state and all required driver daemons are actively running."] 252 | #[doc = "More information about this error can be found in the system specific"] 253 | #[doc = "user guide."] 254 | SystemNotReady = 802, 255 | #[doc = "This error indicates that there is a mismatch between the versions of"] 256 | #[doc = "the display driver and the CUDA driver. 
Refer to the compatibility documentation"] 257 | #[doc = "for supported versions."] 258 | SystemDriverMismatch = 803, 259 | #[doc = "This error indicates that the system was upgraded to run with forward compatibility"] 260 | #[doc = "but the visible hardware detected by CUDA does not support this configuration."] 261 | #[doc = "Refer to the compatibility documentation for the supported hardware matrix or ensure"] 262 | #[doc = "that only supported hardware is visible during initialization via the CUDA_VISIBLE_DEVICES"] 263 | #[doc = "environment variable."] 264 | CompatNotSupportedOnDevice = 804, 265 | #[doc = "This error indicates that the operation is not permitted when"] 266 | #[doc = "the stream is capturing."] 267 | StreamCaptureUnsupported = 900, 268 | #[doc = "This error indicates that the current capture sequence on the stream"] 269 | #[doc = "has been invalidated due to a previous error."] 270 | StreamCaptureInvalidated = 901, 271 | #[doc = "This error indicates that the operation would have resulted in a merge"] 272 | #[doc = "of two independent capture sequences."] 273 | StreamCaptureMerge = 902, 274 | #[doc = "This error indicates that the capture was not initiated in this stream."] 275 | StreamCaptureUnmatched = 903, 276 | #[doc = "This error indicates that the capture sequence contains a fork that was"] 277 | #[doc = "not joined to the primary stream."] 278 | StreamCaptureUnjoined = 904, 279 | #[doc = "This error indicates that a dependency would have been created which"] 280 | #[doc = "crosses the capture sequence boundary. 
Only implicit in-stream ordering"] 281 | #[doc = "dependencies are allowed to cross the boundary."] 282 | StreamCaptureIsolation = 905, 283 | #[doc = "This error indicates a disallowed implicit dependency on a current capture"] 284 | #[doc = "sequence from cudaStreamLegacy."] 285 | StreamCaptureImplicit = 906, 286 | #[doc = "This error indicates that the operation is not permitted on an event which"] 287 | #[doc = "was last recorded in a capturing stream."] 288 | CapturedEvent = 907, 289 | #[doc = "A stream capture sequence not initiated with the ::CU_STREAM_CAPTURE_MODE_RELAXED"] 290 | #[doc = "argument to ::cuStreamBeginCapture was passed to ::cuStreamEndCapture in a"] 291 | #[doc = "different thread."] 292 | StreamCaptureWrongThread = 908, 293 | #[doc = "This error indicates that the timeout specified for the wait operation has lapsed."] 294 | Timeout = 909, 295 | #[doc = "This error indicates that the graph update was not performed because it included"] 296 | #[doc = "changes which violated constraints specific to instantiated graph update."] 297 | GraphExecUpdateFailure = 910, 298 | #[doc = "This indicates that an unknown internal error has occurred."] 299 | Unknown = 999, 300 | } 301 | 302 | impl fmt::Display for ErrorCode { 303 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 304 | ::fmt(self, f) 305 | } 306 | } 307 | 308 | impl Error for ErrorCode {} 309 | 310 | pub type CudaResult = Result; 311 | 312 | pub(crate) fn cuda_error(input: u32) -> CudaResult<()> { 313 | if input == 0 { 314 | Ok(()) 315 | } else { 316 | Err(ErrorCode::try_from_primitive(input).unwrap_or(ErrorCode::Unknown)) 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /src/func.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use num_enum::TryFromPrimitive; 3 | 4 | /// A [`Function`]-specific attribute type 5 | #[derive(Debug, Copy, Clone, TryFromPrimitive)] 6 | 
#[repr(u32)]
pub enum FunctionAttribute {
    /// Maximum number of threads per block, beyond which launching the function
    /// fails. Depends on both the function and the device it is currently loaded on.
    MaxThreadsPerBlock = 0,
    /// Bytes of statically-allocated shared memory this function requires
    /// (excludes dynamically-allocated shared memory requested at runtime).
    SharedSizeBytes = 1,
    /// Bytes of user-allocated constant memory required by this function.
    ConstSizeBytes = 2,
    /// Bytes of local memory used by each thread of this function.
    LocalSizeBytes = 3,
    /// Number of registers used by each thread of this function.
    NumRegs = 4,
    /// PTX virtual architecture version the function was compiled for, encoded
    /// as major * 10 + minor (PTX 1.3 -> 13). May be the undefined value 0 for
    /// cubins compiled prior to CUDA 3.0.
    PtxVersion = 5,
    /// Binary architecture version the function was compiled for, encoded as
    /// major * 10 + minor. Legacy cubins without a properly-encoded binary
    /// architecture version report 10.
    BinaryVersion = 6,
    /// Whether the function has been compiled with the user-specified option
    /// "-Xptxas --dlcm=ca" set.
    CacheModeCa = 7,
    /// Maximum bytes of dynamically-allocated shared memory usable by this
    /// function. Launches requesting more than this fail. See cuFuncSetAttribute
    MaxDynamicSharedSizeBytes = 8,
    /// On devices where the L1 cache and shared memory use the same hardware
    /// resources, the preferred shared memory carveout, in percent of total
    /// shared memory. Refer to CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR.
    /// This is only a hint; the driver may choose a different ratio if required
    /// to execute the function. See cuFuncSetAttribute
    PreferredSharedMemoryCarveout = 9,
}

/// A [`Function`] cache config
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u32)]
pub enum FuncCache {
    /// no preference for shared memory or L1 (default)
    PreferNone = 0x00,
    /// prefer larger shared memory and smaller L1 cache
    PreferShared = 0x01,
    /// prefer larger L1 cache and smaller shared memory
    PreferL1 = 0x02,
    /// prefer equal sized L1 cache and shared memory
    PreferEqual = 0x03,
}

/// A [`Function`] shared memory config
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u32)]
pub enum FuncSharedConfig {
    /// set default shared memory bank size
    DefaultBankSize = 0x00,
    /// set shared memory bank width to four bytes
    FourByteBankSize = 0x01,
    /// set shared memory bank width to eight bytes
    EightByteBankSize = 0x02,
}

/// Represents an individual callable Kernel loaded from a [`Module`]
pub struct Function<'a, 'b> {
    pub(crate) module: &'b Module<'a>,
    pub(crate) inner: *mut sys::CUfunc_st,
}

impl<'a, 'b> Function<'a, 'b> {
    /// Returns a module handle.
    pub fn module(&self) -> &'b Module<'a> {
        self.module
    }

    /// Returns information about a function.
pub fn get_attribute(&self, attribute: FunctionAttribute) -> CudaResult<i32> {
    let mut out = 0i32;
    cuda_error(unsafe {
        sys::cuFuncGetAttribute(&mut out as *mut i32, attribute as u32, self.inner)
    })?;
    Ok(out)
}

/// Sets information about a function.
pub fn set_attribute(&mut self, attribute: FunctionAttribute, value: i32) -> CudaResult<()> {
    cuda_error(unsafe { sys::cuFuncSetAttribute(self.inner, attribute as u32, value) })
}

/// Sets the preferred cache configuration for a device function.
pub fn set_cache_config(&mut self, func_cache: FuncCache) -> CudaResult<()> {
    cuda_error(unsafe { sys::cuFuncSetCacheConfig(self.inner, func_cache as u32) })
}

/// Sets the shared memory configuration for a device function.
pub fn set_shared_mem_config(&mut self, config: FuncSharedConfig) -> CudaResult<()> {
    cuda_error(unsafe { sys::cuFuncSetSharedMemConfig(self.inner, config as u32) })
}
} // end impl Function
-------------------------------------------------------------------------------- /src/future.rs: --------------------------------------------------------------------------------
use std::{
    future::Future,
    pin::Pin,
    rc::Rc,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use crate::*;

/// Internal completion state shared between the future and the stream callback.
enum InteriorState<T> {
    /// The stream callback has not fired yet.
    Waiting,
    /// The callback fired and left a value to hand to the poller.
    Some(T),
    /// The value was already handed out by a previous poll.
    Taken,
}

/// A future resolved by a CUDA stream callback.
pub struct CudaFuture<'a, T> {
    interior: Arc<Mutex<InteriorState<CudaResult<T>>>>,
    active_stream: Option<Stream<'a>>,
    handle: Rc<Handle<'a>>,
}

impl<'a> Unpin for CudaFuture<'a, ()> {}

impl<'a> Future for CudaFuture<'a, ()> {
    type Output = CudaResult<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut interior = self.interior.lock().unwrap();
        match &mut *interior {
            InteriorState::Taken => panic!("over polled CudaFuture (do you need to fuse?)"),
            x @ InteriorState::Some(_) => {
                // Move the result out, leaving `Taken` behind so a later poll panics
                // instead of silently returning a stale value.
                let mut state = InteriorState::Taken;
                std::mem::swap(x, &mut state);
                if let InteriorState::Some(x) = state {
                    Poll::Ready(x)
                } else {
                    unimplemented!()
                }
            }
            InteriorState::Waiting => {
                let waker = cx.waker().clone();
                // Release the lock before registering the callback, which locks again.
                drop(interior);
                let new_self = self.interior.clone();
                match self.active_stream.as_mut().unwrap().callback(move || {
                    let mut inner = new_self.lock().unwrap();
                    if matches!(&*inner, InteriorState::Waiting) {
                        *inner = InteriorState::Some(Ok(()));
                    }
                    waker.wake()
                }) {
                    Ok(()) => Poll::Pending,
                    Err(e) => Poll::Ready(Err(e)),
                }
            }
        }
    }
}

impl<'a> CudaFuture<'a, ()> {
    /// Creates a future that completes when `stream`'s pending work finishes.
    pub(crate) fn new(handle: Rc<Handle<'a>>, stream: Stream<'a>) -> Self {
        CudaFuture {
            interior: Arc::new(Mutex::new(InteriorState::Waiting)),
            active_stream: Some(stream),
            handle,
        }
    }
}

impl<'a, T> Drop for CudaFuture<'a, T> {
    fn drop(&mut self) {
        // Return the borrowed async stream to the handle's pool.
        self.handle
            .reset_async_stream(self.active_stream.take().unwrap());
    }
}
-------------------------------------------------------------------------------- /src/init.rs: --------------------------------------------------------------------------------
use std::sync::atomic::{AtomicBool, Ordering};

use crate::{error::*, sys, Cuda};

static CHECK_INIT: AtomicBool = AtomicBool::new(false);
impl Cuda {
    /// Initialize the CUDA library. Can be called repeatedly at no cost.
pub fn init() -> CudaResult<()> {
    // Fast path once a previous call has succeeded. Two racing threads may
    // both reach cuInit(0); presumably that is harmless since the call is
    // intended to be repeatable — TODO confirm against the driver docs.
    if CHECK_INIT.load(Ordering::SeqCst) {
        return Ok(());
    }
    cuda_error(unsafe { sys::cuInit(0) })?;
    CHECK_INIT.store(true, Ordering::SeqCst);
    Ok(())
}
} // end impl Cuda
-------------------------------------------------------------------------------- /src/kernel_params.rs: --------------------------------------------------------------------------------
use crate::{DeviceBox, DevicePtr};

/// Some data able to represent one or more kernel parameters
pub trait KernelParameters {
    /// Appends this value's raw little-endian byte representation(s) to `out`,
    /// one `Vec<u8>` per kernel parameter.
    fn params(&self, out: &mut Vec<Vec<u8>>);
}

impl KernelParameters for u8 {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        out.push(vec![*self]);
    }
}

/// Implements [`KernelParameters`] for primitives that expose `to_le_bytes`.
macro_rules! le_bytes_params {
    ($($t:ty),+ $(,)?) => {
        $(
            impl KernelParameters for $t {
                fn params(&self, out: &mut Vec<Vec<u8>>) {
                    out.push(self.to_le_bytes().to_vec());
                }
            }
        )+
    };
}

le_bytes_params!(u16, u32, u64, usize, i8, i16, i32, i64, f32, f64);

/// WARNING: this is unsafe!
impl<'a> KernelParameters for DevicePtr<'a> {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        out.push(self.inner.to_le_bytes().to_vec());
    }
}

/// WARNING: this is unsafe!
impl<'a, 'b> KernelParameters for &'b DeviceBox<'a> {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        out.push(self.inner.inner.to_le_bytes().to_vec());
    }
}

impl KernelParameters for &[u8] {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        out.push(self.to_vec());
    }
}

impl KernelParameters for Vec<u8> {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        out.push(self.clone());
    }
}

impl<T: KernelParameters, const N: usize> KernelParameters for [T; N] {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        for x in self {
            x.params(out);
        }
    }
}

impl<T: KernelParameters> KernelParameters for Box<T> {
    fn params(&self, out: &mut Vec<Vec<u8>>) {
        (&**self).params(out);
    }
}

impl KernelParameters for () {
    fn params(&self, _out: &mut Vec<Vec<u8>>) {}
}

macro_rules! tuple_impls {
    ($($len:expr => ($($n:tt $name:ident)+))+) => {
        $(
            impl<$($name: KernelParameters),+> KernelParameters for ($($name,)+) {
                fn params(&self, out: &mut Vec<Vec<u8>>) {
                    $(
                        $name::params(&self.$n, out);
                    )+
                }
            }
        )+
    }
}

tuple_impls! {
    1 => (0 T0)
    2 => (0 T0 1 T1)
    3 => (0 T0 1 T1 2 T2)
    4 => (0 T0 1 T1 2 T2 3 T3)
    5 => (0 T0 1 T1 2 T2 3 T3 4 T4)
    6 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5)
    7 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6)
    8 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7)
    9 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8)
    10 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9)
    11 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10)
    12 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11)
    13 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12)
    14 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13)
    15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14)
    16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15)
}
-------------------------------------------------------------------------------- /src/lib.rs: --------------------------------------------------------------------------------
#![allow(clippy::missing_safety_doc)]

#[allow(
    non_upper_case_globals,
    non_snake_case,
    improper_ctypes,
    non_camel_case_types
)]
#[doc(hidden)]
pub mod sys;

pub mod context;
pub mod device;
pub mod dim3;
pub mod error;
pub mod func;
// pub mod future;
pub mod init;
pub mod kernel_params;
pub mod mem;
pub mod module;
pub mod stream;
pub mod version;

pub struct Cuda;

pub use context::*;
pub use device::*;
pub use dim3::*;
pub(crate) use error::cuda_error;
pub use error::{CudaResult, ErrorCode};
pub use func::*;
// pub use future::*;
pub use kernel_params::*;
pub use mem::*;
pub use module::*;
pub use stream::*;
pub use version::*;
-------------------------------------------------------------------------------- /src/mem.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | ops::{Deref, DerefMut}, 3 | pin::Pin, 4 | rc::Rc, 5 | }; 6 | 7 | use crate::*; 8 | 9 | /// A slice into the device memory. 10 | #[derive(Clone)] 11 | pub struct DevicePtr<'a> { 12 | pub(crate) handle: Rc>, 13 | pub(crate) inner: u64, 14 | pub(crate) len: u64, 15 | } 16 | 17 | impl<'a> DevicePtr<'a> { 18 | pub fn as_raw(&self) -> u64 { 19 | self.inner 20 | } 21 | 22 | pub unsafe fn from_raw_parts(handle: Rc>, ptr: u64, len: u64) -> Self { 23 | Self { 24 | handle, 25 | inner: ptr, 26 | len, 27 | } 28 | } 29 | 30 | /// Synchronously copies data from `self` to `target`. Panics if length is not equal. 31 | pub fn copy_to<'b>(&self, target: &DevicePtr<'b>) -> CudaResult<()> { 32 | if self.len > target.len { 33 | panic!("overflow in DevicePtr::copy_to"); 34 | } else if self.len < target.len { 35 | panic!("underflow in DevicePtr::copy_to"); 36 | } 37 | 38 | if std::ptr::eq(self.handle.context, target.handle.context) { 39 | cuda_error(unsafe { sys::cuMemcpy(target.inner, self.inner, self.len as sys::size_t) }) 40 | } else { 41 | cuda_error(unsafe { 42 | sys::cuMemcpyPeer( 43 | target.inner, 44 | target.handle.context.inner, 45 | self.inner, 46 | self.handle.context.inner, 47 | self.len as sys::size_t, 48 | ) 49 | }) 50 | } 51 | } 52 | 53 | /// Asynchronously copies data from `self` to `target`. Panics if length is not equal. 
54 | pub fn copy_to_stream<'b, 'c: 'b + 'a>( 55 | &self, 56 | target: &DevicePtr<'b>, 57 | stream: &mut Stream<'c>, 58 | ) -> CudaResult<()> 59 | where 60 | 'a: 'b, 61 | { 62 | if self.len > target.len { 63 | panic!("overflow in DevicePtr::copy_to"); 64 | } else if self.len < target.len { 65 | panic!("underflow in DevicePtr::copy_to"); 66 | } 67 | 68 | if std::ptr::eq(self.handle.context, target.handle.context) { 69 | cuda_error(unsafe { 70 | sys::cuMemcpyAsync( 71 | target.inner, 72 | self.inner, 73 | self.len as sys::size_t, 74 | stream.inner, 75 | ) 76 | }) 77 | } else { 78 | cuda_error(unsafe { 79 | sys::cuMemcpyPeerAsync( 80 | target.inner, 81 | target.handle.context.inner, 82 | self.inner, 83 | self.handle.context.inner, 84 | self.len as sys::size_t, 85 | stream.inner, 86 | ) 87 | }) 88 | } 89 | } 90 | 91 | // pub fn copy_to_async<'b>(&self, target: &DevicePtr<'b>) -> CudaResult> 92 | // where 93 | // 'a: 'b, 94 | // { 95 | // let mut stream = self.handle.get_async_stream()?; 96 | // unsafe { self.copy_to_stream(target, &mut stream) }?; 97 | // Ok(CudaFuture::new(self.handle.clone(), stream)) 98 | // } 99 | 100 | /// Synchronously copies data from `source` to `self`. Panics if length is not equal. 101 | pub fn copy_from<'b>(&self, source: &DevicePtr<'b>) -> CudaResult<()> { 102 | source.copy_to(self) 103 | } 104 | 105 | /// Asynchronously copies data from `source` to `self`. Panics if length is not equal. 
106 | pub fn copy_from_stream<'b: 'a, 'c: 'a + 'b>( 107 | &self, 108 | source: &DevicePtr<'b>, 109 | stream: &mut Stream<'c>, 110 | ) -> CudaResult<()> { 111 | source.copy_to_stream(self, stream) 112 | } 113 | 114 | /// Gets a subslice of this slice from `[from:to]` 115 | pub fn subslice(&self, from: u64, to: u64) -> Self { 116 | if from > self.len || from > to || to > self.len { 117 | panic!("overflow in DevicePtr::subslice"); 118 | } 119 | Self { 120 | handle: self.handle.clone(), 121 | inner: self.inner + from, 122 | len: to - from, 123 | } 124 | } 125 | 126 | /// Gets the length of this slice 127 | pub fn len(&self) -> u64 { 128 | self.len 129 | } 130 | 131 | /// Check if the slice's length is 0 132 | pub fn is_empty(&self) -> bool { 133 | self.len == 0 134 | } 135 | 136 | /// Synchronously loads the data from this slice into a local buffer 137 | pub fn load(&self) -> CudaResult> { 138 | let mut buf = Vec::with_capacity(self.len as usize); 139 | cuda_error(unsafe { 140 | sys::cuMemcpyDtoH_v2( 141 | buf.as_mut_ptr() as *mut _, 142 | self.inner, 143 | self.len as sys::size_t, 144 | ) 145 | })?; 146 | unsafe { buf.set_len(self.len as usize) }; 147 | Ok(buf) 148 | } 149 | 150 | /// Asynchronously loads the data from this slice into a local buffer. 151 | /// The contents of the buffer are undefined until `stream.sync` is called. 152 | /// The output must not be dropped until the stream is synced. 153 | pub unsafe fn load_stream(&self, stream: &mut Stream<'a>) -> CudaResult> { 154 | let mut buf = Vec::with_capacity(self.len as usize); 155 | cuda_error(sys::cuMemcpyDtoHAsync_v2( 156 | buf.as_mut_ptr() as *mut _, 157 | self.inner, 158 | self.len as sys::size_t, 159 | stream.inner, 160 | ))?; 161 | buf.set_len(self.len as usize); 162 | Ok(buf) 163 | } 164 | 165 | /// Synchronously stores host data from `data` to `self`. 
166 | pub fn store(&self, data: &[u8]) -> CudaResult<()> { 167 | if data.len() > self.len as usize { 168 | panic!("overflow in DevicePtr::store"); 169 | } else if data.len() < self.len as usize { 170 | panic!("underflow in DevicePtr::store"); 171 | } 172 | cuda_error(unsafe { 173 | sys::cuMemcpyHtoD_v2( 174 | self.inner, 175 | data.as_ptr() as *const _, 176 | self.len as sys::size_t, 177 | ) 178 | })?; 179 | Ok(()) 180 | } 181 | 182 | /// Asynchronously stores host data from `data` to `self`. 183 | /// The `data` must not be dropped or mutated until `stream.sync` is called. 184 | pub fn store_stream<'b>(&self, data: &'b [u8], stream: &'b mut Stream<'a>) -> CudaResult<()> { 185 | if data.len() > self.len as usize { 186 | panic!("overflow in DevicePtr::store"); 187 | } else if data.len() < self.len as usize { 188 | panic!("underflow in DevicePtr::store"); 189 | } 190 | cuda_error(unsafe { 191 | sys::cuMemcpyHtoDAsync_v2( 192 | self.inner, 193 | data.as_ptr() as *const _, 194 | self.len as sys::size_t, 195 | stream.inner, 196 | ) 197 | })?; 198 | Ok(()) 199 | } 200 | 201 | /// Asynchronously stores host data from `data` to `self`. 202 | /// `data` will be dropped once the [`Stream`] is synced or dropped. 
203 | pub fn store_stream_buf(&self, data: Vec, stream: &mut Stream<'a>) -> CudaResult<()> { 204 | if data.len() > self.len as usize { 205 | panic!("overflow in DevicePtr::store"); 206 | } else if data.len() < self.len as usize { 207 | panic!("underflow in DevicePtr::store"); 208 | } 209 | let data: Pin> = data.into_boxed_slice().into(); 210 | stream.pending_stores.push(data); 211 | cuda_error(unsafe { 212 | sys::cuMemcpyHtoDAsync_v2( 213 | self.inner, 214 | stream.pending_stores.last().unwrap().as_ptr() as *const _, 215 | self.len as sys::size_t, 216 | stream.inner, 217 | ) 218 | })?; 219 | Ok(()) 220 | } 221 | 222 | /// Synchronously set the contents of `self` to `data` repeated to fill length 223 | pub fn memset_d8(&self, data: u8) -> CudaResult<()> { 224 | cuda_error(unsafe { sys::cuMemsetD8_v2(self.inner, data, self.len as sys::size_t) }) 225 | } 226 | 227 | /// Asynchronously set the contents of `self` to `data` repeated to fill length 228 | pub fn memset_d8_stream(&self, data: u8, stream: &mut Stream<'a>) -> CudaResult<()> { 229 | cuda_error(unsafe { 230 | sys::cuMemsetD8Async(self.inner, data, self.len as sys::size_t, stream.inner) 231 | }) 232 | } 233 | 234 | /// Synchronously set the contents of `self` to `data` repeated to fill length. 235 | /// Panics if [`Self::len`] is not a multiple of 2. 236 | pub fn memset_d16(&self, data: u16) -> CudaResult<()> { 237 | if self.len % 2 != 0 { 238 | panic!("alignment failure in DevicePtr::memset_d16"); 239 | } 240 | cuda_error(unsafe { sys::cuMemsetD16_v2(self.inner, data, self.len as sys::size_t / 2) }) 241 | } 242 | 243 | /// Asynchronously set the contents of `self` to `data` repeated to fill length. 244 | /// Panics if [`Self::len`] is not a multiple of 2. 
245 | pub fn memset_d16_stream(&self, data: u16, stream: &mut Stream<'a>) -> CudaResult<()> { 246 | if self.len % 2 != 0 { 247 | panic!("alignment failure in DevicePtr::memset_d16_stream"); 248 | } 249 | cuda_error(unsafe { 250 | sys::cuMemsetD16Async(self.inner, data, self.len as sys::size_t / 2, stream.inner) 251 | }) 252 | } 253 | 254 | /// Synchronously set the contents of `self` to `data` repeated to fill length. 255 | /// Panics if [`Self::len`] is not a multiple of 4. 256 | pub fn memset_d32(&self, data: u32) -> CudaResult<()> { 257 | if self.len % 4 != 0 { 258 | panic!("alignment failure in DevicePtr::memset_d32"); 259 | } 260 | cuda_error(unsafe { sys::cuMemsetD32_v2(self.inner, data, self.len as sys::size_t / 4) }) 261 | } 262 | 263 | /// Asynchronously set the contents of `self` to `data` repeated to fill length. 264 | /// Panics if [`Self::len`] is not a multiple of 4. 265 | pub fn memset_d32_stream(&self, data: u32, stream: &mut Stream<'a>) -> CudaResult<()> { 266 | if self.len % 4 != 0 { 267 | panic!("alignment failure in DevicePtr::memset_d32_stream"); 268 | } 269 | cuda_error(unsafe { 270 | sys::cuMemsetD32Async(self.inner, data, self.len as sys::size_t / 4, stream.inner) 271 | }) 272 | } 273 | 274 | /// Gets a reference to the owning handle 275 | pub fn handle(&self) -> &Rc> { 276 | &self.handle 277 | } 278 | } 279 | 280 | /// An owned device-allocated buffer 281 | pub struct DeviceBox<'a> { 282 | pub(crate) inner: DevicePtr<'a>, 283 | } 284 | 285 | impl<'a> DeviceBox<'a> { 286 | /// Allocate an uninitialized buffer of size `size` on the device 287 | pub fn alloc(handle: &Rc>, size: u64) -> CudaResult { 288 | let mut out = 0u64; 289 | cuda_error(unsafe { sys::cuMemAlloc_v2(&mut out as *mut u64, size as sys::size_t) })?; 290 | Ok(DeviceBox { 291 | inner: DevicePtr { 292 | handle: handle.clone(), 293 | inner: out, 294 | len: size, 295 | }, 296 | }) 297 | } 298 | 299 | /// Allocate a new initialized buffer on the device matching the size and content 
of `input`. 300 | pub fn new(handle: &Rc>, input: &[u8]) -> CudaResult { 301 | let buf = Self::alloc(handle, input.len() as u64)?; 302 | buf.store(input)?; 303 | Ok(buf) 304 | } 305 | 306 | /// Allocates a new uninitialized buffer on the device, then asynchronously fills it with `input`. 307 | /// `input` must not be dropped or mutated until `stream.sync` is called. 308 | /// Does not allocate the memory asynchronously. 309 | pub fn new_stream<'b>( 310 | handle: &Rc>, 311 | input: &'b [u8], 312 | stream: &'b mut Stream<'a>, 313 | ) -> CudaResult { 314 | let buf = Self::alloc(handle, input.len() as u64)?; 315 | buf.store_stream(input, stream)?; 316 | Ok(buf) 317 | } 318 | 319 | /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`. 320 | /// `input` will be dropped when the stream is synced or dropped. 321 | /// Does not allocate the memory asynchronously. 322 | pub fn new_stream_buf( 323 | handle: &Rc>, 324 | input: Vec, 325 | stream: &mut Stream<'a>, 326 | ) -> CudaResult { 327 | let buf = Self::alloc(handle, input.len() as u64)?; 328 | buf.store_stream_buf(input, stream)?; 329 | Ok(buf) 330 | } 331 | 332 | /// Allocates a new initialized buffer on the device matching the size and content of `input`. 333 | /// Note that memory is directly copied, so [`T`] must be [`Sized`] should not contain any pointers, references, unsized types, or other non-FFI safe types. 334 | pub fn new_ffi(handle: &Rc>, input: &[T]) -> CudaResult { 335 | let raw = unsafe { 336 | std::slice::from_raw_parts( 337 | input.as_ptr() as *const u8, 338 | input.len() * std::mem::size_of::(), 339 | ) 340 | }; 341 | let buf = Self::alloc(handle, raw.len() as u64)?; 342 | buf.store(raw)?; 343 | Ok(buf) 344 | } 345 | 346 | /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`. 
347 | /// Note that memory is directly copied, so [`T`] must be [`Sized`] *should* not contain any pointers, references, unsized types, or other non-FFI safe types. 348 | /// `input` must not be dropped or mutated until `stream.sync` is called. 349 | /// Does not allocate the memory asynchronously. 350 | pub fn new_ffi_stream<'b, T>( 351 | handle: &Rc>, 352 | input: &'b [T], 353 | stream: &'b mut Stream<'a>, 354 | ) -> CudaResult { 355 | let raw = unsafe { 356 | std::slice::from_raw_parts( 357 | input.as_ptr() as *const u8, 358 | input.len() * std::mem::size_of::(), 359 | ) 360 | }; 361 | let buf = Self::alloc(handle, raw.len() as u64)?; 362 | buf.store_stream(raw, stream)?; 363 | Ok(buf) 364 | } 365 | 366 | /// Allocates a new uninitialized buffer on the device, then synchronously fills it with `input`. 367 | /// Note that memory is directly copied, so [`T`] must be [`Sized`] *should* not contain any pointers, references, unsized types, or other non-FFI safe types. 368 | /// `input` will be dropped when the stream is synced or dropped. 369 | /// Does not allocate the memory asynchronously. 370 | pub fn new_ffi_stream_buf<'b, T>( 371 | handle: &Rc>, 372 | mut input: Vec, 373 | stream: &'b mut Stream<'a>, 374 | ) -> CudaResult { 375 | let raw = unsafe { 376 | Vec::from_raw_parts( 377 | input.as_mut_ptr() as *mut u8, 378 | input.len() * std::mem::size_of::(), 379 | input.capacity() * std::mem::size_of::(), 380 | ) 381 | }; 382 | std::mem::forget(input); 383 | let buf = Self::alloc(handle, raw.len() as u64)?; 384 | buf.store_stream_buf(raw, stream)?; 385 | Ok(buf) 386 | } 387 | 388 | /// Leaks the DeviceBox, similar to [`Box::leak`]. 389 | pub fn leak(self) { 390 | std::mem::forget(self); 391 | } 392 | 393 | /// Constructs a [`DeviceBox`] from a device pointer. 
/// # Safety
/// `raw` must point at an allocation that this `DeviceBox` may own and free.
pub unsafe fn from_raw(raw: DevicePtr<'a>) -> Self {
    Self { inner: raw }
}
} // end impl DeviceBox

impl<'a> Drop for DeviceBox<'a> {
    fn drop(&mut self) {
        // Freeing can fail; report to stderr rather than panicking inside drop.
        if let Err(e) = cuda_error(unsafe { sys::cuMemFree_v2(self.inner.inner) }) {
            eprintln!("CUDA: failed freeing device buffer: {:?}", e);
        }
    }
}

impl<'a> AsRef<DevicePtr<'a>> for DeviceBox<'a> {
    fn as_ref(&self) -> &DevicePtr<'a> {
        &self.inner
    }
}

impl<'a> Deref for DeviceBox<'a> {
    type Target = DevicePtr<'a>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<'a> DerefMut for DeviceBox<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
-------------------------------------------------------------------------------- /src/module.rs: --------------------------------------------------------------------------------
use std::{
    borrow::Cow,
    ffi::{c_void, CString},
    ptr::null_mut,
    rc::Rc,
};

use crate::*;

// Debug must not be derived, see comment on info_buf
/// A CUDA JIT linker context, used to compile device-specific kernels from PTX assembly or link together several precompiled binaries
pub struct Linker<'a> {
    inner: *mut sys::CUlinkState_st,
    // both info_buf and errors_buf contain uninitialized memory! they should
    // always be NUL terminated strings
    info_buf: Vec<u8>,
    errors_buf: Vec<u8>,
    handle: Rc<Handle<'a>>,
}

/// The type of input to the linker
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LinkerInputType {
    Cubin,
    Ptx,
    Fatbin,
}

/// Linker options for CUDA, can generally just be defaulted.
28 | #[derive(Clone, Copy, Debug)] 29 | pub struct LinkerOptions { 30 | /// Add debug symbols to emitted binary 31 | pub debug_info: bool, 32 | /// Collect INFO logs from CUDA build/link, up to 16 MB, then emit to STDOUT 33 | pub log_info: bool, 34 | /// Collect ERROR logs from CUDA build/link, up to 16 MB, then emit to STDOUT 35 | pub log_errors: bool, 36 | /// Increase log verbosity 37 | pub verbose_logs: bool, 38 | } 39 | 40 | impl Default for LinkerOptions { 41 | fn default() -> Self { 42 | LinkerOptions { 43 | debug_info: false, 44 | log_info: true, 45 | log_errors: true, 46 | verbose_logs: false, 47 | } 48 | } 49 | } 50 | 51 | impl<'a> Linker<'a> { 52 | /// Creates a new [`Linker`] for the given context handle, compute capability, and linker options. 53 | pub fn new( 54 | handle: &Rc>, 55 | compute_capability: CudaVersion, 56 | options: LinkerOptions, 57 | ) -> CudaResult { 58 | let mut linker = Linker { 59 | inner: null_mut(), 60 | info_buf: if options.log_info { 61 | let mut buf = Vec::with_capacity(16 * 1024 * 1024); 62 | buf.push(0); 63 | unsafe { buf.set_len(buf.capacity()) }; 64 | buf 65 | } else { 66 | vec![] 67 | }, 68 | errors_buf: if options.log_errors { 69 | let mut buf = Vec::with_capacity(16 * 1024 * 1024); 70 | buf.push(0); 71 | unsafe { buf.set_len(buf.capacity()) }; 72 | buf 73 | } else { 74 | vec![] 75 | }, 76 | handle: handle.clone(), 77 | }; 78 | let log_verbose = if options.verbose_logs { 1u32 } else { 0u32 }; 79 | let debug_info = if options.debug_info { 1u32 } else { 0u32 }; 80 | 81 | let mut options = [ 82 | sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER, 83 | sys::CUjit_option_enum_CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, 84 | sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER, 85 | sys::CUjit_option_enum_CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, 86 | sys::CUjit_option_enum_CU_JIT_TARGET, 87 | sys::CUjit_option_enum_CU_JIT_LOG_VERBOSE, 88 | sys::CUjit_option_enum_CU_JIT_GENERATE_DEBUG_INFO, 89 | ]; 90 | let target = match (compute_capability.major, 
compute_capability.minor) { 91 | (2, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_20, 92 | (2, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_21, 93 | (3, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_30, 94 | (3, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_32, 95 | (3, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_35, 96 | (3, 7) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_37, 97 | (5, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_50, 98 | (5, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_52, 99 | (5, 3) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_53, 100 | (6, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_60, 101 | (6, 1) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_61, 102 | (6, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_62, 103 | (7, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_70, 104 | (7, 2) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_72, 105 | (7, 5) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_75, 106 | (8, 0) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_80, 107 | (8, 6) => sys::CUjit_target_enum_CU_TARGET_COMPUTE_86, 108 | (_, _) => return Err(ErrorCode::UnsupportedPtxVersion), 109 | }; 110 | 111 | let mut values = [ 112 | linker.info_buf.as_mut_ptr() as *mut c_void, 113 | linker.info_buf.len() as u32 as u64 as *mut c_void, 114 | linker.errors_buf.as_mut_ptr() as *mut c_void, 115 | linker.errors_buf.len() as u32 as u64 as *mut c_void, 116 | target as u64 as *mut c_void, 117 | log_verbose as u64 as *mut c_void, 118 | debug_info as u64 as *mut c_void, 119 | ]; 120 | cuda_error(unsafe { 121 | sys::cuLinkCreate_v2( 122 | options.len() as u32, 123 | options.as_mut_ptr(), 124 | values.as_mut_ptr(), 125 | &mut linker.inner as *mut _, 126 | ) 127 | })?; 128 | Ok(linker) 129 | } 130 | 131 | fn emit_logs(&self) { 132 | let info_string = self.info_buf.iter().position(|x| *x == 0); 133 | if let Some(info_string) = info_string { 134 | let info_string = String::from_utf8_lossy(&self.info_buf[..info_string]); 135 | if !info_string.is_empty() { 136 | 
info_string.split('\n').for_each(|line| { 137 | println!("[CUDA INFO] {}", line); 138 | }); 139 | } 140 | } 141 | let error_string = self.errors_buf.iter().position(|x| *x == 0); 142 | if let Some(error_string) = error_string { 143 | let error_string = String::from_utf8_lossy(&self.errors_buf[..error_string]); 144 | if !error_string.is_empty() { 145 | error_string.split('\n').for_each(|line| { 146 | println!("[CUDA ERROR] {}", line); 147 | }); 148 | } 149 | } 150 | } 151 | 152 | /// Add an input file to the linker context. `name` is only used for logs 153 | pub fn add(self, name: &str, format: LinkerInputType, in_data: &[u8]) -> CudaResult { 154 | let mut data = Cow::Borrowed(in_data); 155 | if format == LinkerInputType::Ptx { 156 | let mut new_data = Vec::with_capacity(in_data.len() + 1); 157 | new_data.extend_from_slice(in_data); 158 | new_data.push(0); 159 | data = Cow::Owned(new_data) 160 | } 161 | 162 | let format = match format { 163 | LinkerInputType::Cubin => sys::CUjitInputType_enum_CU_JIT_INPUT_CUBIN, 164 | LinkerInputType::Ptx => sys::CUjitInputType_enum_CU_JIT_INPUT_PTX, 165 | LinkerInputType::Fatbin => sys::CUjitInputType_enum_CU_JIT_INPUT_FATBINARY, 166 | }; 167 | let name = CString::new(name).unwrap(); 168 | 169 | let out = cuda_error(unsafe { 170 | sys::cuLinkAddData_v2( 171 | self.inner, 172 | format, 173 | data.as_ptr() as *mut u8 as *mut c_void, 174 | data.len() as sys::size_t, 175 | name.as_ptr(), 176 | 0, 177 | null_mut(), 178 | null_mut(), 179 | ) 180 | }); 181 | 182 | if let Err(e) = out { 183 | self.emit_logs(); 184 | return Err(e); 185 | } 186 | Ok(self) 187 | } 188 | 189 | /// Emit the cubin assembly binary. 
You probably want [`Linker::build_module`] 190 | pub fn build(&self) -> CudaResult<&[u8]> { 191 | let mut cubin_out: *mut c_void = null_mut(); 192 | let mut size_out: sys::size_t = 0; 193 | let out = cuda_error(unsafe { 194 | sys::cuLinkComplete( 195 | self.inner, 196 | &mut cubin_out as *mut *mut c_void, 197 | &mut size_out as *mut sys::size_t, 198 | ) 199 | }); 200 | self.emit_logs(); 201 | if let Err(e) = out { 202 | return Err(e); 203 | } 204 | Ok(unsafe { std::slice::from_raw_parts(cubin_out as *const u8, size_out as usize) }) 205 | } 206 | 207 | /// Build a CUDA module from this [`Linker`]. 208 | pub fn build_module(&self) -> CudaResult> { 209 | let built = self.build()?; 210 | Module::load(&self.handle, built) 211 | } 212 | } 213 | 214 | impl<'a> Drop for Linker<'a> { 215 | fn drop(&mut self) { 216 | if let Err(e) = cuda_error(unsafe { sys::cuLinkDestroy(self.inner) }) { 217 | eprintln!("CUDA: failed to destroy cuda linker state: {:?}", e); 218 | } 219 | } 220 | } 221 | 222 | /// A loaded CUDA module 223 | pub struct Module<'a> { 224 | handle: Rc>, 225 | inner: *mut sys::CUmod_st, 226 | } 227 | 228 | impl<'a> Module<'a> { 229 | /// Takes a raw CUDA kernel image and loads the corresponding module module into the current context. 230 | /// The pointer can be a cubin or PTX or fatbin file as a NULL-terminated text string 231 | pub fn load(handle: &Rc>, module: &[u8]) -> CudaResult { 232 | let mut inner = null_mut(); 233 | cuda_error(unsafe { 234 | sys::cuModuleLoadData(&mut inner as *mut _, module.as_ptr() as *const _) 235 | })?; 236 | Ok(Module { 237 | inner, 238 | handle: handle.clone(), 239 | }) 240 | } 241 | 242 | /// Same as [`Module::load`] but uses `fatCubin` format. 
    pub fn load_fatcubin(handle: &Rc<Handle<'a>>, module: &[u8]) -> CudaResult<Self> {
        let mut inner = null_mut();
        cuda_error(unsafe {
            sys::cuModuleLoadFatBinary(&mut inner as *mut _, module.as_ptr() as *const _)
        })?;
        Ok(Module {
            inner,
            handle: handle.clone(),
        })
    }

    /// Retrieve a reference to a defined CUDA kernel within the module.
    /// The returned [`Function`] borrows this module (`'b`).
    pub fn get_function<'b>(&'b self, name: &str) -> CudaResult<Function<'a, 'b>> {
        let mut inner = null_mut();
        let name = CString::new(name).unwrap();
        cuda_error(unsafe {
            sys::cuModuleGetFunction(&mut inner as *mut _, self.inner, name.as_ptr())
        })?;
        Ok(Function {
            module: self,
            inner,
        })
    }

    /// Get a pointer to a global variable defined by a CUDA module.
    pub fn get_global<'b: 'a>(&'b self, name: &str) -> CudaResult<DevicePtr<'a>> {
        // Start with an empty pointer; CUDA fills in the address and length.
        let mut out = DevicePtr {
            handle: self.handle.clone(),
            inner: 0,
            len: 0,
        };
        let name = CString::new(name).unwrap();
        cuda_error(unsafe {
            sys::cuModuleGetGlobal_v2(
                &mut out.inner,
                &mut out.len as *mut u64 as *mut _,
                self.inner,
                name.as_ptr(),
            )
        })?;
        Ok(out)
    }

    // pub fn get_surface(&self, name: &str) {

    // }

    // pub fn get_texture(&self, name: &str) {

    // }
}

impl<'a> Drop for Module<'a> {
    fn drop(&mut self) {
        // Unload the module; log rather than panic inside drop.
        if let Err(e) = cuda_error(unsafe { sys::cuModuleUnload(self.inner) }) {
            eprintln!("CUDA: failed to destroy cuda module: {:?}", e);
        }
    }
}
--------------------------------------------------------------------------------
/src/stream.rs:
--------------------------------------------------------------------------------
use num_enum::TryFromPrimitive;
use std::{ffi::c_void, marker::PhantomData, pin::Pin, ptr::null_mut, rc::Rc};

use crate::*;

/// A stream of asynchronous operations operating in
/// a [`Context`]
pub struct Stream<'a> {
    pub(crate) inner: *mut sys::CUstream_st,
    // Host-side buffers that must stay alive until the stream has drained;
    // cleared in `sync()`.
    // NOTE(review): element type reconstructed — confirm against the mem.rs
    // code that pushes into this vector.
    pub(crate) pending_stores: Vec<Pin<Box<[u8]>>>,
    _p: PhantomData<&'a ()>,
}

/// Wait comparison type for waiting on some condition in [`Stream::wait_32`]/etc
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u32)]
pub enum WaitValueMode {
    /// Wait until (int32_t)(*addr - value) >= 0 (or int64_t for 64 bit values). Note this is a cyclic comparison which ignores wraparound. (Default behavior.)
    Geq = 0x0,
    /// Wait until *addr == value.
    Eq = 0x1,
    /// Wait until (*addr & value) != 0.
    And = 0x2,
    /// Wait until ~(*addr | value) != 0. Support for this operation can be queried with cuDeviceGetAttribute() and CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR.
    Nor = 0x3,
}

/// Trampoline handed to `cuLaunchHostFunc`: reclaims the boxed closure leaked
/// by [`Stream::callback`] and invokes it (called at most once).
unsafe extern "C" fn host_callback(arg: *mut std::ffi::c_void) {
    let closure: Box<Box<dyn FnOnce()>> = Box::from_raw(arg as *mut _);
    closure();
}

impl<'a> Stream<'a> {
    /// Creates a new stream for a handle.
    /// The stream is created non-blocking with respect to the default stream.
    pub fn new(_handle: &Rc<Handle<'a>>) -> CudaResult<Self> {
        let mut out = null_mut();
        cuda_error(unsafe {
            sys::cuStreamCreate(
                &mut out as *mut _,
                sys::CUstream_flags_enum_CU_STREAM_NON_BLOCKING,
            )
        })?;
        Ok(Self {
            inner: out,
            pending_stores: vec![],
            _p: PhantomData,
        })
    }

    /// Drives all pending tasks on the stream to completion.
    /// Also releases any host buffers held alive for in-flight stores.
    pub fn sync(&mut self) -> CudaResult<()> {
        cuda_error(unsafe { sys::cuStreamSynchronize(self.inner) })?;
        self.pending_stores.clear();
        Ok(())
    }

    /// Returns `Ok(true)` if the stream has finished processing all queued tasks.
    pub fn is_synced(&self) -> CudaResult<bool> {
        // cuStreamQuery's NotReady is not an error here, just "still working".
        match cuda_error(unsafe { sys::cuStreamQuery(self.inner) }) {
            Ok(()) => Ok(true),
            Err(ErrorCode::NotReady) => Ok(false),
            Err(e) => Err(e),
        }
    }

    /// Wait for a 4-byte value in a specific location to compare to `value` by `mode`.
    ///
    /// # Panics
    /// Panics if `addr` is smaller than 4 bytes.
    pub fn wait_32<'b>(
        &'b mut self,
        addr: &'b DevicePtr<'a>,
        value: u32,
        mode: WaitValueMode,
        flush: bool,
    ) -> CudaResult<()> {
        if addr.len < 4 {
            panic!("overflow in Stream::wait_32");
        }
        // Bit 30 is the CU_STREAM_WAIT_VALUE_FLUSH flag, OR'd into the mode.
        let flush = if flush { 1u32 << 30 } else { 0 };
        cuda_error(unsafe {
            sys::cuStreamWaitValue32(self.inner, addr.inner, value, mode as u32 | flush)
        })
    }

    /// Wait for a 8-byte value in a specific location to compare to `value` by `mode`.
    ///
    /// # Panics
    /// Panics if `addr` is smaller than 8 bytes.
    pub fn wait_64<'b>(
        &mut self,
        addr: &'b DevicePtr<'a>,
        value: u64,
        mode: WaitValueMode,
        flush: bool,
    ) -> CudaResult<()> {
        if addr.len < 8 {
            panic!("overflow in Stream::wait_64");
        }
        let flush = if flush { 1u32 << 30 } else { 0 };
        cuda_error(unsafe {
            sys::cuStreamWaitValue64(self.inner, addr.inner, value, mode as u32 | flush)
        })
    }

    /// Writes a 4-byte value to device memory asynchronously.
    ///
    /// # Panics
    /// Panics if `addr` is smaller than 4 bytes.
    pub fn write_32<'b>(
        &'b mut self,
        addr: &'b DevicePtr<'a>,
        value: u32,
        no_memory_barrier: bool,
    ) -> CudaResult<()> {
        if addr.len < 4 {
            panic!("overflow in Stream::write_32");
        }
        let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
        cuda_error(unsafe {
            sys::cuStreamWriteValue32(self.inner, addr.inner, value, no_memory_barrier)
        })
    }

    /// Writes a 8-byte value to device memory asynchronously.
    ///
    /// # Panics
    /// Panics if `addr` is smaller than 8 bytes.
    pub fn write_64<'b>(
        &'b mut self,
        addr: &'b DevicePtr<'a>,
        value: u64,
        no_memory_barrier: bool,
    ) -> CudaResult<()> {
        if addr.len < 8 {
            panic!("overflow in Stream::write_64");
        }
        let no_memory_barrier = if no_memory_barrier { 1u32 } else { 0 };
        cuda_error(unsafe {
            sys::cuStreamWriteValue64(self.inner, addr.inner, value, no_memory_barrier)
        })
    }

    /// Calls a callback closure function `callback` once all prior tasks in the Stream have been driven to completion.
    /// Note that it is a memory leak to drop the stream before this callback is called.
    /// The callback is not guaranteed to be called if the stream errors out.
    /// Also note that it is erroneous in `libcuda` to make any calls to `libcuda` from this callback.
    /// The callback is called from a CUDA internal thread, however this is an implementation detail of `libcuda` and not guaranteed.
    // NOTE(review): the generic bound was lost in extraction; `FnOnce() + 'static`
    // matches the `Box<Box<dyn FnOnce()>>` reclaimed in `host_callback` — confirm.
    pub fn callback<F: FnOnce() + 'static>(&mut self, callback: F) -> CudaResult<()> {
        // Double-box so the fat trait-object pointer fits in the thin
        // `*mut c_void` that cuLaunchHostFunc carries; `host_callback` re-boxes
        // and frees it exactly once.
        let callback: Box<Box<dyn FnOnce()>> = Box::new(Box::new(callback));
        cuda_error(unsafe {
            sys::cuLaunchHostFunc(
                self.inner,
                Some(host_callback),
                Box::leak(callback) as *mut _ as *mut _,
            )
        })
    }

    /// Launch a CUDA kernel on this [`Stream`] with the given `grid_dim` grid dimensions, `block_dim` block dimensions, `shared_mem_size` allocated shared memory pool, and `parameters` kernel parameters.
    /// It is undefined behavior to pass in `parameters` that do not conform to the passes CUDA kernel. If the argument count is wrong, CUDA will generally throw an error.
    /// If your `parameters` is accurate to the kernel definition, then this function is otherwise safe.
150 | pub unsafe fn launch<'b, D1: Into, D2: Into, K: KernelParameters>( 151 | &mut self, 152 | f: &Function<'a, 'b>, 153 | grid_dim: D1, 154 | block_dim: D2, 155 | shared_mem_size: u32, 156 | parameters: K, 157 | ) -> CudaResult<()> { 158 | let grid_dim = grid_dim.into().0; 159 | let block_dim = block_dim.into().0; 160 | let mut kernel_params = vec![]; 161 | parameters.params(&mut kernel_params); 162 | let mut new_kernel_params = Vec::with_capacity(kernel_params.len()); 163 | for param in &kernel_params { 164 | new_kernel_params.push(param.as_ptr() as *mut c_void); 165 | } 166 | cuda_error(sys::cuLaunchKernel( 167 | f.inner, 168 | grid_dim.0, 169 | grid_dim.1, 170 | grid_dim.2, 171 | block_dim.0, 172 | block_dim.1, 173 | block_dim.2, 174 | shared_mem_size, 175 | self.inner, 176 | new_kernel_params.as_mut_ptr(), 177 | null_mut(), 178 | )) 179 | } 180 | } 181 | 182 | impl<'a> Drop for Stream<'a> { 183 | fn drop(&mut self) { 184 | if let Err(e) = cuda_error(unsafe { sys::cuStreamDestroy_v2(self.inner) }) { 185 | eprintln!("CUDA: failed to drop stream: {:?}", e); 186 | } 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /src/version.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | use crate::{ 4 | error::{cuda_error, CudaResult}, 5 | sys, Cuda, 6 | }; 7 | 8 | /// A CUDA device or API version 9 | #[derive(Clone, Copy, Debug)] 10 | pub struct CudaVersion { 11 | pub major: u32, 12 | pub minor: u32, 13 | } 14 | 15 | impl From for CudaVersion { 16 | fn from(version: u32) -> Self { 17 | CudaVersion { 18 | major: version as u32 / 1000, 19 | minor: (version as u32 % 1000) / 10, 20 | } 21 | } 22 | } 23 | 24 | impl Into<(u32, u32)> for CudaVersion { 25 | fn into(self) -> (u32, u32) { 26 | (self.major, self.minor) 27 | } 28 | } 29 | 30 | impl From<(u32, u32)> for CudaVersion { 31 | fn from(other: (u32, u32)) -> Self { 32 | CudaVersion { 33 | major: other.0, 34 | 
minor: other.1, 35 | } 36 | } 37 | } 38 | 39 | impl fmt::Display for CudaVersion { 40 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 41 | write!(f, "{}.{}", self.major, self.minor) 42 | } 43 | } 44 | 45 | impl Cuda { 46 | /// Gets the local driver version (not to be confused with device compute capability) 47 | pub fn version() -> CudaResult { 48 | let mut version = 0i32; 49 | cuda_error(unsafe { sys::cuDriverGetVersion(&mut version as *mut i32) })?; 50 | Ok((version as u32).into()) 51 | } 52 | } 53 | --------------------------------------------------------------------------------