├── rocm_attr.lock ├── include ├── rocprofiler_v2.h ├── rocsolver.h ├── rocblas.h ├── rocrand.h ├── rocsparse.h ├── rocfft.h ├── hip.h ├── miopen.h ├── rocprofiler.h ├── activity.h └── rocwmma.h ├── src ├── rocsolver │ └── mod.rs ├── hip │ ├── examples │ │ ├── saxpy │ │ │ ├── rocm_attr.lock │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ │ └── main.rs │ │ ├── rust_kernel │ │ │ ├── rocm_attr.lock │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ │ └── main.rs │ │ ├── rust_kernel_async │ │ │ ├── rocm_attr.lock │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ │ └── main.rs │ │ ├── vector_add │ │ │ ├── .gitignore │ │ │ ├── vector_add.hsaco │ │ │ ├── kernel.hip │ │ │ ├── Cargo.toml │ │ │ ├── build.sh │ │ │ ├── README.md │ │ │ └── main.rs │ │ └── sort │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ └── main.rs │ ├── memory_ext │ │ ├── mod.rs │ │ └── sorting.rs │ ├── mod.rs │ ├── kernel.rs │ ├── error.rs │ ├── ffi.rs │ ├── event.rs │ ├── module.rs │ ├── device.rs │ ├── utils.rs │ └── stream.rs ├── miopen │ ├── examples │ │ ├── multi_tensor │ │ │ ├── rocm_attr.lock │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ │ ├── kernels.rs │ │ │ │ ├── data.rs │ │ │ │ ├── iris.csv │ │ │ │ ├── layer.rs │ │ │ │ └── main.rs │ │ └── basic │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ └── main.rs │ ├── mod.rs │ ├── error.rs │ ├── handle.rs │ ├── ctc_loss.rs │ ├── lrn.rs │ ├── reduce.rs │ └── mha.rs ├── rocarray │ └── error.rs ├── rocprofiler │ ├── mod.rs │ └── error.rs ├── rocblas │ ├── examples │ │ └── basic │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ └── main.rs │ ├── macros.rs │ ├── mod.rs │ ├── handle.rs │ └── error.rs ├── rocrand │ ├── examples │ │ └── normal │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ └── main.rs │ ├── error.rs │ ├── mod.rs │ └── utils.rs ├── rocmsmi │ └── mod.rs ├── rocsparse │ ├── mod.rs │ ├── vector.rs │ ├── matrix.rs │ ├── handle.rs │ ├── error.rs │ ├── descriptor.rs │ └── conversion.rs ├── lib.rs └── rocfft │ ├── mod.rs │ ├── cache.rs │ ├── ffi.rs │ └── execution.rs ├── .gitignore ├── .idea ├── misc.xml ├── 
vcs.xml ├── .gitignore ├── modules.xml └── rocm-rs.iml ├── .github ├── workflows │ └── rust.yml └── FUNDING.yml ├── Cargo.toml ├── LICENSE └── README.md /rocm_attr.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/rocprofiler_v2.h: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/rocsolver/mod.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/hip/examples/saxpy/rocm_attr.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel/rocm_attr.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/rocm_attr.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel_async/rocm_attr.lock: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/hip/examples/vector_add/.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock -------------------------------------------------------------------------------- /src/rocarray/error.rs: -------------------------------------------------------------------------------- 1 | pub enum Error { 2 | InvalidOperation(String) 3 | } 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/target 3 | **/kernel_sources 4 | **/bindings.rs 5 | *.code-workspace -------------------------------------------------------------------------------- /src/hip/examples/vector_add/vector_add.hsaco: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RustNSparks/rocm-rs/HEAD/src/hip/examples/vector_add/vector_add.hsaco -------------------------------------------------------------------------------- /src/rocprofiler/mod.rs: -------------------------------------------------------------------------------- 1 | #[allow(warnings)] 2 | pub mod bindings; 3 | pub mod error; 4 | pub mod types; 5 | pub mod context; 6 | pub mod profiler; -------------------------------------------------------------------------------- /include/rocsolver.h: -------------------------------------------------------------------------------- 1 | #ifndef ROCSOLVER_WRAPPER_H 2 | #define ROCSOLVER_WRAPPER_H 3 | 4 | #include 5 | 6 | #endif // ROCSOLVER_WRAPPER_H -------------------------------------------------------------------------------- /src/rocblas/examples/basic/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "basic" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = {path = "../../../.."} 8 | -------------------------------------------------------------------------------- /src/rocrand/examples/normal/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../.." 
} -------------------------------------------------------------------------------- /include/rocblas.h: -------------------------------------------------------------------------------- 1 | #ifndef ROCBLAS_WRAPPER_H 2 | #define ROCBLAS_WRAPPER_H 3 | 4 | // Include only the main header 5 | #include 6 | 7 | #endif // ROCBLAS_WRAPPER_H -------------------------------------------------------------------------------- /include/rocrand.h: -------------------------------------------------------------------------------- 1 | #ifndef ROCRAND_WRAPPER_H 2 | #define ROCRAND_WRAPPER_H 3 | 4 | // Include only the main header 5 | #include 6 | 7 | #endif // ROCRAND_WRAPPER_H -------------------------------------------------------------------------------- /src/miopen/examples/basic/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "basic" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../.." , features = ["miopen"]} -------------------------------------------------------------------------------- /src/hip/examples/saxpy/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "saxpy" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../../", features = ["macros"] } 8 | -------------------------------------------------------------------------------- /src/hip/examples/sort/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "sort" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../../", features = ["macros"] } 8 | -------------------------------------------------------------------------------- /include/rocsparse.h: -------------------------------------------------------------------------------- 1 | #ifndef ROCSPARSE_WRAPPER_H 2 | #define ROCSPARSE_WRAPPER_H 
3 | 4 | // Main rocsparse header is sufficient 5 | #include 6 | 7 | #endif // ROCSPARSE_WRAPPER_H -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust_kernel" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm_kernel_macros = "0.3.0" 8 | rocm-rs = { path = "../../../.." } -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel_async/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust_kernel_async" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../..", features = ["macros"]} -------------------------------------------------------------------------------- /include/rocfft.h: -------------------------------------------------------------------------------- 1 | #ifndef ROCFFT_WRAPPER_H 2 | #define ROCFFT_WRAPPER_H 3 | 4 | // The main rocfft header already includes all necessary dependencies 5 | #include 6 | 7 | #endif // ROCFFT_WRAPPER_H -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | 
/dataSources.local.xml 9 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "multi_tensor" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | csv = "1.4.0" 8 | rocm-rs = { path = "../../../.." , features = ["miopen"]} -------------------------------------------------------------------------------- /include/hip.h: -------------------------------------------------------------------------------- 1 | #ifndef HIP_WRAPPER_H 2 | #define HIP_WRAPPER_H 3 | 4 | #ifndef __HIP_PLATFORM_AMD__ 5 | #define __HIP_PLATFORM_AMD__ 1 6 | #endif 7 | 8 | #include 9 | 10 | #endif // HIP_WRAPPER_H -------------------------------------------------------------------------------- /include/miopen.h: -------------------------------------------------------------------------------- 1 | #ifndef MIOPEN_WRAPPER_H 2 | #define MIOPEN_WRAPPER_H 3 | 4 | #ifndef __HIP_PLATFORM_AMD__ 5 | #define __HIP_PLATFORM_AMD__ 1 6 | #endif 7 | 8 | #include 9 | 10 | #endif // MIOPEN_WRAPPER_H -------------------------------------------------------------------------------- /src/hip/examples/vector_add/kernel.hip: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | 4 | extern "C" __global__ void vector_add(const float* a, const float* b, float* c, unsigned int n) { 5 | int i = blockDim.x * blockIdx.x + threadIdx.x; 6 | if (i < n) { 7 | c[i] = a[i] + b[i]; 8 | } 9 | } -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/hip/examples/vector_add/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "vector_add_example" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | rocm-rs = { path = "../../../.." } # Adjust this path to point to your rocm_rs crate 8 | 9 | [[bin]] 10 | name = "vector_add_example" 11 | path = "main.rs" -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: format code 20 | run: cargo fmt 21 | -------------------------------------------------------------------------------- /include/rocprofiler.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef ROCPROFILER_WRAPPER_H 3 | #define ROCPROFILER_WRAPPER_H 4 | 5 | #ifndef __HIP_PLATFORM_AMD__ 6 | #define __HIP_PLATFORM_AMD__ 1 7 | #endif 8 | 9 | // Include the main ROCProfiler header 10 | // This should bring in all the necessary HSA dependencies 11 | #include 12 | #include "activity.h" 13 | 14 | #endif // ROCPROFILER_WRAPPER_H -------------------------------------------------------------------------------- /src/rocrand/examples/normal/src/main.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::rocrand::utils::generate_normal_f32; 2 | 3 | fn main() -> Result<(), Box> { 4 | let size = 32; 5 | let mut host = vec![0f32; size]; 6 | 7 | let device = generate_normal_f32(size, 0.5, 0.5, None)?; 8 | device.copy_to_host(&mut host)?; 9 | 10 | println!("{:?}", host); 11 | 12 | Ok(()) 13 | } 14 | -------------------------------------------------------------------------------- 
/src/rocmsmi/mod.rs: -------------------------------------------------------------------------------- 1 | pub use rocm_smi_lib as rocmsmi; 2 | pub use rocmsmi::*; 3 | 4 | #[cfg(test)] 5 | mod test { 6 | use crate::rocmsmi::{RocmSmi, *}; 7 | 8 | #[test] 9 | fn rocm_smi_test() -> Result<(), rocmsmi::RocmErr> { 10 | let mut rocm_smi = RocmSmi::init()?; 11 | 12 | let _ = rocm_smi.get_device_identifiers(0).unwrap(); 13 | Ok(()) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/rocsparse/mod.rs: -------------------------------------------------------------------------------- 1 | //! Bindings for rocsparse 2 | //! Auto-generated - do not modify 3 | #[allow(warnings)] 4 | pub mod bindings; 5 | pub mod conversion; 6 | pub mod descriptor; 7 | pub mod error; 8 | pub mod handle; 9 | pub mod matrix; 10 | mod pruning; 11 | pub mod vector; 12 | 13 | // Re-export all bindings 14 | pub use bindings::*; 15 | 16 | // Import dependencies 17 | pub use crate::hip::*; 18 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate core; 2 | pub mod error; 3 | pub mod hip; 4 | #[cfg(feature = "miopen")] 5 | pub mod miopen; 6 | pub mod rocblas; 7 | pub mod rocfft; 8 | pub mod rocrand; 9 | #[cfg(feature = "rocsolver")] 10 | pub mod rocsolver; 11 | 12 | #[cfg(feature = "rocm_smi")] 13 | pub mod rocmsmi; 14 | // mod rocprofiler; 15 | pub mod rocarray; 16 | pub mod rocsparse; 17 | 18 | #[cfg(feature = "macros")] 19 | pub use rocm_kernel_macros; 20 | -------------------------------------------------------------------------------- /src/hip/examples/vector_add/build.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | # Build script for the vector addition example 4 | 5 | set -e # Exit on error 6 | 7 | # Directory containing this script 8 | SCRIPT_DIR="$( 
cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 9 | 10 | echo Script directory: $SCRIPT_DIR 11 | 12 | # Compile the kernel to a binary file 13 | hipcc -fno-gpu-rdc -fPIC --genco -O3 -o "$SCRIPT_DIR/vector_add.hsaco" "$SCRIPT_DIR/kernel.hip" 14 | 15 | echo "Kernel compiled successfully to vector_add.hsaco" -------------------------------------------------------------------------------- /.idea/rocm-rs.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /include/activity.h: -------------------------------------------------------------------------------- 1 | #ifndef _SRC_CORE_ACTIVITY_H 2 | #define _SRC_CORE_ACTIVITY_H 3 | 4 | #include "rocprofiler.h" 5 | 6 | #include 7 | 8 | // HSA EVT ID enumeration 9 | enum hsa_evt_id_t { 10 | HSA_EVT_ID_ALLOCATE = ROCPROFILER_HSA_CB_ID_ALLOCATE, 11 | HSA_EVT_ID_DEVICE = ROCPROFILER_HSA_CB_ID_DEVICE, 12 | HSA_EVT_ID_MEMCOPY = ROCPROFILER_HSA_CB_ID_MEMCOPY, 13 | HSA_EVT_ID_SUBMIT = ROCPROFILER_HSA_CB_ID_SUBMIT, 14 | HSA_EVT_ID_KSYMBOL = ROCPROFILER_HSA_CB_ID_KSYMBOL, 15 | HSA_EVT_ID_CODEOBJ = ROCPROFILER_HSA_CB_ID_CODEOBJ, 16 | HSA_EVT_ID_NUMBER 17 | }; 18 | 19 | // HSA EVT callback data type 20 | typedef rocprofiler_hsa_callback_data_t hsa_evt_data_t; 21 | 22 | #endif // _SRC_CORE_ACTIVITY_H 23 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rocm-rs" 3 | version = "0.5.0" 4 | edition = "2024" 5 | description = "Rust bindings for AMD ROCm libraries" 6 | license = "MIT" 7 | repository = "https://github.com/RustNSparks/rocm-rs" 8 | documentation = "https://docs.rs/rocm-rs" 9 | readme = "README.md" 10 | keywords = ["gpu", "rocm", "amd", "hpc", "bindings"] 11 | categories = ["api-bindings", "external-ffi-bindings"] 12 | 
exclude = ["**/kernel_sources"] 13 | [lib] 14 | doctest = false 15 | 16 | 17 | [dependencies] 18 | rocm_smi_lib = { version = "0.3.1", optional = true } 19 | rocm_kernel_macros = {version = "0.4.2", optional = true} 20 | paste = "1.0.15" 21 | 22 | [build-dependencies] 23 | bindgen = "0.71.1" 24 | 25 | [features] 26 | default = ["macros", "miopen"] 27 | rocm_smi = ["dep:rocm_smi_lib"] 28 | miopen = [] 29 | rocprofiler = [] 30 | rocsolver = [] 31 | macros=["dep:rocm_kernel_macros"] 32 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: radudiaconu0 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /src/rocblas/examples/basic/src/main.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::rocblas::scal; 2 | use rocm_rs::{hip::*, rocblas}; 3 | use std::error::Error; 4 | 5 | // this example shows in-place scaling of a matrix's elements (scal) with rocblas 6 | fn main() -> 
std::result::Result<(), Box> { 7 | // Initialize rocBLAS handle 8 | let handle = rocblas::Handle::new()?; 9 | 10 | // Matrix dimensions 11 | let m = 2; // rows 12 | let n = 3; // columns 13 | 14 | // Host data (column-major order) 15 | let mut h_a: [f32; 6] = [ 16 | 1.0, 4.0, // Column 0 17 | 2.0, 5.0, // Column 1 18 | 3.0, 6.0, // Column 2 19 | ]; 20 | 21 | // Device memory pointers 22 | let mut d_a = DeviceMemory::::new(m * n)?; 23 | 24 | d_a.copy_from_host(&h_a)?; 25 | 26 | 27 | let alpha: f32 = 2.0; 28 | // Perform A = alpha * A in place, treating the m*n elements as one vector 29 | scal(&handle, (n*m) as i32, &alpha, &d_a, 1)?; 30 | 31 | 32 | // Copy result back to host 33 | d_a.copy_to_host(&mut h_a)?; 34 | 35 | println!("Result: {:?}", h_a); 36 | 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Diaconu Radu-Mihai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/src/kernels.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::rocm_kernel_macros::{amdgpu_global, amdgpu_kernel_finalize, amdgpu_kernel_init}; 2 | 3 | amdgpu_kernel_init!(); 4 | 5 | #[amdgpu_global] 6 | fn linear_transform( 7 | input: *const f32, 8 | weights: *const f32, 9 | bias: *const f32, 10 | output: *mut f32, 11 | input_size: usize, 12 | output_size: usize, 13 | ) { 14 | let idx = workgroup_id_x() as usize; 15 | 16 | if idx < output_size { 17 | unsafe { 18 | let mut sum = *bias.add(idx); 19 | let offset = idx * input_size; 20 | for i in 0..input_size { 21 | sum += *weights.add(offset + i) * *input.add(i); 22 | } 23 | 24 | *output.add(idx) = sum; 25 | } 26 | } 27 | } 28 | 29 | #[amdgpu_global] 30 | fn gradient(predicted: *const f32, target: *const f32, grad_output: *mut f32, size: usize) { 31 | let idx = workgroup_id_x() as usize; 32 | 33 | if idx < size { 34 | unsafe { 35 | *grad_output.add(idx) = *predicted.add(idx) - *target.add(idx); 36 | } 37 | } 38 | } 39 | 40 | pub const KERNEL: &[u8] = include_bytes!(amdgpu_kernel_finalize!()); 41 | -------------------------------------------------------------------------------- /include/rocwmma.h: -------------------------------------------------------------------------------- 1 | // include/rocwmma.h 2 | // Wrapper include file for ROCWmma bindings with workarounds for AMD intrinsics 3 | 4 | #ifndef ROCWMMA_BINDINGS_WRAPPER 5 | #define ROCWMMA_BINDINGS_WRAPPER 6 | 7 | // Define missing AMD GPU intrinsics for bindgen 8 | #define __builtin_amdgcn_ds_bpermute(a, b) (0) 9 | 
#define __builtin_amdgcn_ds_permute(a, b) (0) 10 | #define __builtin_amdgcn_mov_dpp(a, b, c, d, e) (0) 11 | #define __builtin_amdgcn_uicmp(a, b, c) (0) 12 | #define __builtin_amdgcn_mbcnt_lo(a, b) (0) 13 | #define __builtin_amdgcn_mbcnt_hi(a, b) (0) 14 | 15 | // Define missing architecture-specific macros 16 | #define __AMDGCN_WAVEFRONT_SIZE 64 17 | 18 | // ADDED: Define __hip_internal as an empty namespace BEFORE including headers 19 | // This might satisfy the parser when it encounters __hip_internal::something 20 | namespace __hip_internal {} 21 | 22 | // Headers added via --include in build.rs (string, type_traits) are processed first by clang 23 | 24 | // Now include the actual rocwmma headers 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | #endif // ROCWMMA_BINDINGS_WRAPPER -------------------------------------------------------------------------------- /src/rocfft/mod.rs: -------------------------------------------------------------------------------- 1 | // src/rocfft/mod.rs 2 | 3 | //! Bindings for rocfft 4 | //! 
Auto-generated - do not modify 5 | #[allow(warnings)] 6 | pub mod bindings; 7 | pub mod cache; 8 | pub mod description; 9 | pub mod error; 10 | pub mod execution; 11 | pub mod ffi; 12 | pub mod field; 13 | pub mod plan; 14 | 15 | // Add the new utility modules 16 | pub mod examples; 17 | pub mod utils; 18 | 19 | // Re-export all bindings 20 | pub use bindings::*; 21 | 22 | /// Initialize rocFFT library 23 | pub fn setup() -> error::Result<()> { 24 | unsafe { error::check_error(bindings::rocfft_setup()) } 25 | } 26 | 27 | /// Cleanup rocFFT library 28 | pub fn cleanup() -> error::Result<()> { 29 | unsafe { error::check_error(bindings::rocfft_cleanup()) } 30 | } 31 | 32 | /// Get the rocFFT version string 33 | pub fn get_version() -> error::Result { 34 | let mut buffer = vec![0u8; 100]; 35 | unsafe { 36 | error::check_error(bindings::rocfft_get_version_string( 37 | buffer.as_mut_ptr() as *mut i8, 38 | buffer.len(), 39 | ))?; 40 | 41 | // Find the null terminator 42 | let len = buffer.iter().position(|&c| c == 0).unwrap_or(buffer.len()); 43 | buffer.truncate(len); 44 | 45 | Ok(String::from_utf8_lossy(&buffer).to_string()) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/hip/examples/sort/src/main.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::hip::{self, DeviceMemory, memory_ext::MemoryExt}; 2 | 3 | fn main() -> Result<(), hip::Error> { 4 | let arr: Vec = vec![ 5 | 87, 23, 56, 12, 91, 45, 78, 34, 67, 5, 99, 31, 64, 29, 76, 18, 50, 82, 37, 93, 15, 41, 60, 6 | 27, 72, 11, 48, 80, 33, 66, 22, 55, 77, 10, 44, 88, 3, 39, 70, 25, 58, 9, 43, 75, 20, 53, 7 | 85, 30, 63, 17, 51, 84, 28, 61, 14, 47, 79, 2, 35, 68, 19, 52, 81, 26, 59, 92, 13, 46, 71, 8 | 24, 57, 90, 32, 65, 8, 40, 73, 16, 49, 83, 36, 69, 1, 38, 74, 21, 54, 86, 4, 42, 7, 62, 95, 9 | 31, 64, 98, 12, 45, 78, 0, 10 | ]; 11 | 12 | let mut host_sorted = arr.clone(); 13 | host_sorted.sort(); 14 | 15 | let mut 
device_arr = DeviceMemory::new(arr.len())?; 16 | 17 | device_arr.copy_from_host(&arr)?; 18 | 19 | device_arr.sort()?; 20 | 21 | let mut gpu_sroted_ascending = vec![0; arr.len()]; 22 | device_arr.copy_to_host(&mut gpu_sroted_ascending)?; 23 | 24 | assert_eq!(host_sorted, gpu_sroted_ascending); 25 | println!("Sorted ascending: {:?}", gpu_sroted_ascending); 26 | 27 | host_sorted.reverse(); 28 | 29 | device_arr.copy_from_host(&arr)?; 30 | 31 | device_arr.sort_desc()?; 32 | 33 | let mut gpu_sroted_descending = vec![0; arr.len()]; 34 | device_arr.copy_to_host(&mut gpu_sroted_descending)?; 35 | 36 | assert_eq!(host_sorted, gpu_sroted_descending); 37 | println!("Sorted descending: {:?}", gpu_sroted_descending); 38 | 39 | Ok(()) 40 | } 41 | -------------------------------------------------------------------------------- /src/hip/memory_ext/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod sorting; 2 | 3 | use crate::hip::memory_ext::sorting::GPUSortAllowed; 4 | use crate::hip::{DeviceMemory, Result, Stream}; 5 | 6 | pub trait MemoryExt { 7 | fn sort(&mut self) -> Result<()>; 8 | fn sort_desc(&mut self) -> Result<()>; 9 | fn sort_async(&mut self, stream: &Stream) -> Result<()>; 10 | fn sort_desc_async(&mut self, stream: &Stream) -> Result<()>; 11 | fn check_sorted(&self) -> Result; 12 | fn check_sorted_async(&self, stream: &Stream) -> Result; 13 | } 14 | 15 | impl MemoryExt for DeviceMemory 16 | where 17 | T: GPUSortAllowed, 18 | { 19 | fn sort(&mut self) -> Result<()> { 20 | let stream = Stream::new()?; 21 | self.sort_async(&stream)?; 22 | stream.synchronize()?; 23 | Ok(()) 24 | } 25 | 26 | fn sort_desc(&mut self) -> Result<()> { 27 | let stream = Stream::new()?; 28 | self.sort_desc_async(&stream)?; 29 | stream.synchronize()?; 30 | Ok(()) 31 | } 32 | 33 | fn sort_async(&mut self, stream: &Stream) -> Result<()> { 34 | sorting::sort(self, stream, true) 35 | } 36 | 37 | fn sort_desc_async(&mut self, stream: &Stream) -> 
Result<()> { 38 | sorting::sort(self, stream, false) 39 | } 40 | 41 | fn check_sorted(&self) -> Result { 42 | sorting::check_sorted(self, None) 43 | } 44 | 45 | fn check_sorted_async(&self, stream: &Stream) -> Result { 46 | sorting::check_sorted(self, Some(stream)) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/rocblas/macros.rs: -------------------------------------------------------------------------------- 1 | /// impl helper macro for rocblas functions 2 | #[macro_export] 3 | macro_rules! impl_rocblas_func { 4 | ($trait_name:ident, $fn_type:ident, {$( $t:ty => $func:path ),* $(,)?}) => { 5 | $( 6 | impl $trait_name for $t { 7 | fn func() -> $fn_type { 8 | $func 9 | } 10 | } 11 | )* 12 | }; 13 | } 14 | #[macro_export] 15 | macro_rules! impl_rocblas_func_inner { 16 | ($func:expr, $($arg:expr),+ $(,)?) => {{ 17 | let status = unsafe { $func($($arg),+) }; 18 | if status != ffi::rocblas_status__rocblas_status_success { 19 | return Err(Error::new(status)); 20 | } 21 | Ok(()) 22 | }}; 23 | } 24 | #[macro_export] 25 | macro_rules! impl_rocblas_traits { 26 | ( 27 | $trait_name:ident, 28 | $fn_type:ident, 29 | $ffi_map:tt, 30 | $method_name:ident, 31 | ($($arg:ident : $arg_ty:ty),+ $(,)?), 32 | ($($fn_arg:ty),+ $(,)?), 33 | ($($call_arg:expr),+ $(,)?) 34 | ) => { 35 | type $fn_type = unsafe extern "C" fn($($fn_arg),+) -> u32; 36 | 37 | pub trait $trait_name { 38 | fn func() -> $fn_type; 39 | 40 | unsafe fn $method_name( 41 | $($arg: $arg_ty),+ 42 | ) -> Result<()> { 43 | impl_rocblas_func_inner!( 44 | Self::func(), 45 | $($call_arg),+ 46 | ) 47 | } 48 | } 49 | 50 | impl_rocblas_func!($trait_name, $fn_type, $ffi_map); 51 | }; 52 | } -------------------------------------------------------------------------------- /src/rocsparse/vector.rs: -------------------------------------------------------------------------------- 1 | //! 
Sparse vector types 2 | 3 | use crate::rocsparse::descriptor::IndexBase; 4 | use crate::rocsparse::error::{Result, status_to_result}; 5 | use crate::rocsparse::{ 6 | rocsparse_create_spvec_descr, rocsparse_datatype, rocsparse_destroy_spvec_descr, 7 | rocsparse_indextype, rocsparse_spvec_descr, 8 | }; 9 | use std::ffi::c_void; 10 | use std::marker::PhantomData; 11 | use std::mem::MaybeUninit; 12 | 13 | /// Sparse vectors 14 | pub struct SparseVector { 15 | pub(crate) inner: rocsparse_spvec_descr, 16 | _phantom: PhantomData, 17 | } 18 | 19 | impl SparseVector { 20 | /// Create a new sparse vector 21 | pub unsafe fn new( 22 | size: i64, 23 | nnz: i64, 24 | indices: *mut c_void, 25 | values: *mut c_void, 26 | idx_type: rocsparse_indextype, 27 | idx_base: IndexBase, 28 | data_type: rocsparse_datatype, 29 | ) -> Result { 30 | let mut descr = MaybeUninit::uninit(); 31 | let status = unsafe { 32 | rocsparse_create_spvec_descr( 33 | descr.as_mut_ptr(), 34 | size, 35 | nnz, 36 | indices, 37 | values, 38 | idx_type, 39 | idx_base.into(), 40 | data_type, 41 | ) 42 | }; 43 | status_to_result(status)?; 44 | let descr = unsafe { descr.assume_init() }; 45 | Ok(Self { 46 | inner: descr, 47 | _phantom: PhantomData, 48 | }) 49 | } 50 | } 51 | 52 | impl Drop for SparseVector { 53 | fn drop(&mut self) { 54 | unsafe { 55 | // Ignore error on drop 56 | let _ = rocsparse_destroy_spvec_descr(self.inner); 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/src/data.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeSet; 2 | 3 | 4 | pub fn prepare_data() -> (Vec>, Vec>, Vec) { 5 | let data = read_iris(); 6 | 7 | let x_data = data 8 | .iter() 9 | .map(|(features, _)| features.clone()) 10 | .collect::>>(); 11 | 12 | let y_data = data 13 | .iter() 14 | .map(|(_, label)| label.clone()) 15 | .collect::>(); 16 | 17 | let (y_target, class_labels) 
= one_hot_encode(&y_data); 18 | 19 | (x_data, y_target, class_labels) 20 | } 21 | 22 | fn read_iris() -> Vec<(Vec, String)> { 23 | let mut rdr = csv::Reader::from_path("src/iris.csv").expect("Cannot open iris.csv"); 24 | let mut data = Vec::new(); 25 | 26 | for result in rdr.records() { 27 | let record = result.expect("Error reading record"); 28 | let features: Vec = record 29 | .iter() 30 | .take(4) 31 | .map(|s| s.parse::().expect("Error parsing feature")) 32 | .collect(); 33 | let label = record.get(4).unwrap().to_string(); 34 | data.push((features, label)); 35 | } 36 | 37 | data 38 | } 39 | 40 | fn one_hot_encode(labels: &[String]) -> (Vec>, Vec) { 41 | let unique_labels: Vec = labels 42 | .iter() 43 | .cloned() 44 | .collect::>() 45 | .into_iter() 46 | .collect(); 47 | 48 | let mut one_hot = Vec::with_capacity(labels.len()); 49 | 50 | for label in labels { 51 | let mut encoding = vec![0.0; unique_labels.len()]; 52 | if let Some(pos) = unique_labels.iter().position(|l| l == label) { 53 | encoding[pos] = 1.0; 54 | } 55 | one_hot.push(encoding); 56 | } 57 | 58 | (one_hot, unique_labels) 59 | } 60 | -------------------------------------------------------------------------------- /src/hip/examples/vector_add/README.md: -------------------------------------------------------------------------------- 1 | # Vector Addition Example 2 | 3 | This example demonstrates how to use the rocm-rs library to perform a simple vector addition operation on an AMD GPU using HIP. 4 | 5 | ## Components 6 | 7 | - `kernel.hip`: Contains the HIP kernel for vector addition 8 | - `build.sh`: Script to compile the kernel to a binary file 9 | - `main.rs`: Rust application that loads the kernel and executes it 10 | 11 | ## Building and Running 12 | 13 | ### Step 1: Compile the HIP kernel 14 | 15 | First, make sure you have ROCm installed and configured properly on your system. 
16 | 17 | Then run the build script to compile the kernel: 18 | 19 | ```bash 20 | chmod +x build.sh 21 | ./build.sh 22 | ``` 23 | 24 | This will create a file named `vector_add.hsaco` which contains the compiled kernel. 25 | 26 | ### Step 2: Build and run the Rust application 27 | 28 | ```bash 29 | cargo build --release 30 | cp vector_add.hsaco target/release/ 31 | cargo run --release 32 | ``` 33 | 34 | You can specify the vector size as a command-line argument: 35 | 36 | ```bash 37 | cargo run --release -- 10000000 38 | ``` 39 | 40 | ## What this example demonstrates 41 | 42 | 1. **Loading Modules**: How to load a precompiled HIP kernel 43 | 2. **Memory Management**: Allocating and copying memory between host and device 44 | 3. **Kernel Execution**: Setting up and launching a kernel with proper parameters 45 | 4. **Performance Measurement**: Using the Timer API to measure performance of different operations 46 | 5. **Error Handling**: Using the unified error handling system 47 | 48 | ## Troubleshooting 49 | 50 | - If the kernel file isn't found, make sure you've run the build script and copied the .hsaco file to the same directory as the executable 51 | - Check that your ROCm installation is working correctly 52 | - Verify that you have a compatible AMD GPU 53 | 54 | ## Expected Output 55 | 56 | The example prints timing information for: 57 | - Host-to-device memory transfer 58 | - Kernel execution 59 | - Device-to-host memory transfer 60 | 61 | It also verifies the results by comparing them with a CPU computation and prints sample values from the beginning and end of the vector. 
-------------------------------------------------------------------------------- /src/hip/examples/saxpy/src/main.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::{ 2 | hip::{kernel::AsKernelArg, *}, 3 | kernel_args, 4 | rocm_kernel_macros::*, 5 | }; 6 | 7 | const LEN: usize = 1024; 8 | 9 | // initializing rust gpu kernel 10 | amdgpu_kernel_init!(); 11 | 12 | // saxpy 13 | // x = ax+y 14 | #[amdgpu_global] 15 | fn saxpy(a: u32, x_arr: *mut u32, y_arr: *const u32) { 16 | // retriving data from buffere by workitem 17 | let x = read_by_workgroup_id_x(x_arr); 18 | let y = read_by_workgroup_id_x(y_arr); 19 | 20 | // writing data back 21 | write_by_workitem_id_x(x_arr, a * x + y); 22 | } 23 | 24 | // compiling gpu kernel and embedding kernel code inside host executable 25 | const KERNEL: &[u8] = include_bytes!(amdgpu_kernel_finalize!()); 26 | 27 | fn main() -> Result<()> { 28 | // setting up device 29 | let device = Device::new(0)?; 30 | device.set_current()?; 31 | 32 | // loading gpu kerenel (runs in runtime!) 
33 | 34 | let module = Module::load_data(KERNEL)?; 35 | 36 | // acquiring function handle from gpu kernel 37 | let function = module.get_function("saxpy")?; 38 | 39 | // preparing host side buffers 40 | let mut x_host: Vec = vec![0; LEN]; 41 | let mut y_host: Vec = vec![0; LEN]; 42 | 43 | // x => 0,1,2...LEN 44 | // x => 0,2,4...LEN 45 | for i in 0..LEN { 46 | x_host[i] = i as u32; 47 | y_host[i] = (i * 2) as u32; 48 | } 49 | 50 | // preparing gpu side buffers 51 | let mut x = DeviceMemory::::new(LEN)?; 52 | let mut y = DeviceMemory::::new(LEN)?; 53 | 54 | x.copy_from_host(&x_host)?; 55 | y.copy_from_host(&y_host)?; 56 | let a = 10; 57 | 58 | // providing arguments for kernel 59 | let kernel_args = kernel_args!(a, x, y); 60 | 61 | // setting up launch args 62 | let grid_dim = Dim3 { x: 2, y: 1, z: 1 }; 63 | let block_dim = Dim3 { 64 | x: (LEN / 2) as u32, 65 | y: 1, 66 | z: 1, 67 | }; 68 | 69 | function.launch(grid_dim, block_dim, 0, None, kernel_args)?; 70 | 71 | // retriving computed data 72 | x.copy_to_host(&mut x_host)?; 73 | 74 | println!("Output: {:?}", &x_host[..256]); 75 | 76 | Ok(()) 77 | } 78 | -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use rocm_kernel_macros::{amdgpu_kernel_attr, amdgpu_kernel_finalize, amdgpu_kernel_init}; 4 | use rocm_rs::{ 5 | hip::{kernel::AsKernelArg, *}, 6 | kernel_args, 7 | }; 8 | 9 | const LEN: usize = 1024; 10 | 11 | // initializing rust gpu kernel 12 | amdgpu_kernel_init!(); 13 | 14 | // marking code that will be coppied to gpu kernel 15 | #[amdgpu_kernel_attr] 16 | fn kernel(input: *const u32, output: *mut u32) { 17 | // retriving data from buffere by workitem 18 | let num = read_by_workitem_id_x(input); 19 | 20 | // writing data back 21 | write_by_workitem_id_x(output, num * 3); 22 | } 23 | 24 | // compiling gpu kernel 25 | const 
AMDGPU_KERNEL_BINARY_PATH: &str = amdgpu_kernel_finalize!(); 26 | 27 | fn main() -> Result<()> { 28 | // setting up device 29 | let device = Device::new(0)?; 30 | device.set_current()?; 31 | 32 | // loading gpu kerenel (runs in runtime!) 33 | let kernel_path = PathBuf::from(AMDGPU_KERNEL_BINARY_PATH); 34 | assert!(kernel_path.exists()); 35 | 36 | let module = Module::load(kernel_path)?; 37 | 38 | // acquiring function handle from gpu kernel 39 | let function = module.get_function("kernel")?; 40 | 41 | // preparing host side buffers 42 | let mut in_host: Vec = vec![0; LEN]; 43 | let mut out_host: Vec = vec![0; LEN]; 44 | 45 | for i in 0..LEN { 46 | in_host[i] = i as u32; 47 | } 48 | 49 | // preparing gpu side buffers 50 | let mut input = DeviceMemory::::new(LEN)?; 51 | let output = DeviceMemory::::new(LEN)?; 52 | 53 | input.copy_from_host(&in_host)?; 54 | 55 | // providing arguments for kernel 56 | let kernel_args = kernel_args!(input, output); 57 | 58 | // setting up launch args 59 | let grid_dim = Dim3 { x: 2, y: 1, z: 1 }; 60 | let block_dim = Dim3 { 61 | x: (LEN / 2) as u32, 62 | y: 1, 63 | z: 1, 64 | }; 65 | 66 | function.launch(grid_dim, block_dim, 0, None, &mut kernel_args.clone())?; 67 | 68 | // retriving computed data 69 | output.copy_to_host(&mut out_host)?; 70 | 71 | println!("Output: {:?}", &out_host[..256]); 72 | 73 | Ok(()) 74 | } 75 | -------------------------------------------------------------------------------- /src/miopen/mod.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/mod.rs 2 | 3 | // Private modules 4 | pub mod activation; 5 | pub mod batchnorm; 6 | pub mod convolution; 7 | pub mod dropout; 8 | pub mod error; 9 | pub mod fusion; 10 | pub mod handle; 11 | pub mod lrn; 12 | pub mod mha; 13 | pub mod pooling; 14 | pub mod reduce; 15 | pub mod rnn; 16 | pub mod softmax; 17 | pub mod tensor; 18 | 19 | // We need to make this public for the rest of the crate 20 | // but don't necessarily want 
to expose it to users 21 | #[allow(warnings)] 22 | pub(crate) mod bindings; 23 | 24 | // Public re-export of FFI for internal use 25 | pub mod ctc_loss; 26 | pub mod ffi; 27 | 28 | // Re-export the main components for the public API 29 | pub use activation::{ActivationDescriptor, ActivationMode}; 30 | pub use batchnorm::BatchNormMode; 31 | pub use convolution::{ 32 | ConvBwdDataAlgorithm, ConvBwdWeightsAlgorithm, ConvFwdAlgorithm, ConvolutionDescriptor, 33 | ConvolutionMode, ConvolutionPerf, convolution_backward_data, convolution_backward_weights, 34 | convolution_forward, find_convolution_forward_algorithm, 35 | }; 36 | pub use dropout::{DropoutDescriptor, RNGType}; 37 | pub use error::{Error, Result}; 38 | pub use fusion::{FusionDirection, FusionOpDescriptor, FusionPlanDescriptor, OperatorArgs}; 39 | pub use handle::Handle; 40 | pub use lrn::{LRNDescriptor, LRNMode}; 41 | pub use pooling::{PoolingDescriptor, PoolingMode, PoolingWorkspaceIndexMode}; 42 | pub use reduce::{ 43 | IndicesType, NanPropagation, ReduceTensorDescriptor, ReduceTensorIndices, ReduceTensorOp, 44 | }; 45 | pub use rnn::{RNNAlgo, RNNBiasMode, RNNDescriptor, RNNDirectionMode, RNNInputMode, RNNMode}; 46 | pub use softmax::{ 47 | SoftmaxAlgorithm, SoftmaxDescriptor, SoftmaxMode, softmax_backward, softmax_backward_v2, 48 | softmax_forward, softmax_forward_v2, 49 | }; 50 | pub use tensor::{DataType, SeqTensorDescriptor, TensorDescriptor, TensorLayout}; 51 | 52 | // New components 53 | pub use mha::{MhaDescriptor, MhaMask, TensorArgumentId, mha_mask, tensor_argument_id}; 54 | 55 | /// Get MIOpen version information 56 | pub fn get_version() -> Result<(usize, usize, usize)> { 57 | let mut major = 0; 58 | let mut minor = 0; 59 | let mut patch = 0; 60 | 61 | let status = unsafe { ffi::miopenGetVersion(&mut major, &mut minor, &mut patch) }; 62 | 63 | Error::from_miopen_status_with_value(status, (major, minor, patch)) 64 | } 65 | 
-------------------------------------------------------------------------------- /src/hip/mod.rs: -------------------------------------------------------------------------------- 1 | // src/hip/mod.rs 2 | 3 | // Private modules 4 | pub mod device; 5 | pub mod error; 6 | pub mod event; 7 | pub mod kernel; 8 | pub mod memory; 9 | pub mod module; 10 | pub mod stream; 11 | pub mod utils; 12 | 13 | // We need to make this public for the rest of the crate 14 | // but don't necessarily want to expose it to users 15 | #[allow(warnings)] 16 | pub mod bindings; 17 | 18 | // Public re-export of FFI for internal use 19 | pub mod ffi; 20 | #[cfg(feature = "macros")] 21 | pub mod memory_ext; 22 | 23 | // Re-export the main components for the public API 24 | pub use device::{Device, DeviceProperties, get_device_count, get_device_properties}; 25 | pub use error::{Error, Result}; 26 | pub use event::{Event, Timer, event_flags}; 27 | pub use kernel::{Function, stream_to_rocrand}; 28 | pub use memory::{DeviceMemory, MemoryInfo, PinnedMemory, memory_info}; 29 | pub use module::{Module, compile_and_load, load_module, load_module_data}; 30 | pub use stream::{Stream, stream_flags}; 31 | pub use utils::{ 32 | Dim3, Version, calculate_grid_1d, calculate_grid_2d, calculate_grid_3d, is_hip_available, print_devices_info, 33 | }; 34 | 35 | /// Get the number of devices 36 | pub fn device_count() -> Result { 37 | device::get_device_count() 38 | } 39 | 40 | /// Initialize the HIP runtime 41 | pub fn init() -> Result<()> { 42 | let error = unsafe { ffi::hipInit(0) }; 43 | Error::from_hip_error(error) 44 | } 45 | 46 | /// Get the HIP driver version 47 | pub fn driver_version() -> Result { 48 | let mut version = 0; 49 | let error = unsafe { ffi::hipDriverGetVersion(&mut version) }; 50 | error::Error::from_hip_error_with_value(error, version) 51 | } 52 | 53 | /// Get the HIP runtime version 54 | pub fn runtime_version() -> Result { 55 | let mut version = 0; 56 | let error = unsafe { 
ffi::hipRuntimeGetVersion(&mut version) }; 57 | error::Error::from_hip_error_with_value(error, version) 58 | } 59 | 60 | /// Get the last error that occurred 61 | pub fn get_last_error() -> Error { 62 | Error::new(unsafe { ffi::hipGetLastError() }) 63 | } 64 | 65 | /// Synchronize the current device 66 | pub fn device_synchronize() -> Result<()> { 67 | let error = unsafe { ffi::hipDeviceSynchronize() }; 68 | Error::from_hip_error(error) 69 | } 70 | 71 | /// Reset the current device 72 | pub fn device_reset() -> Result<()> { 73 | let error = unsafe { ffi::hipDeviceReset() }; 74 | Error::from_hip_error(error) 75 | } 76 | -------------------------------------------------------------------------------- /src/hip/examples/rust_kernel_async/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use rocm_rs::{hip::{kernel::AsKernelArg, *}, rocm_kernel_macros::*}; 4 | 5 | const LEN: usize = 1024; 6 | 7 | // initializing rust gpu kernel 8 | amdgpu_kernel_init!(); 9 | 10 | // marking code that will be coppied to gpu kernel 11 | #[amdgpu_global] 12 | fn kernel(input: *const u32, output: *mut u32) { 13 | // retriving data from buffere by workitem 14 | let num = read_by_workitem_id_x(input); 15 | 16 | // writing data back 17 | write_by_workitem_id_x(output, num * 3); 18 | } 19 | 20 | // compiling gpu kernel 21 | const AMDGPU_KERNEL_BINARY_PATH: &str = amdgpu_kernel_finalize!(); 22 | 23 | fn main() -> Result<()> { 24 | // setting up device 25 | let device = Device::current()?; 26 | 27 | // Create a stream for async operations 28 | let stream = device.get_stream()?; 29 | 30 | // adding callback that will be triggered at the end of stream 31 | stream.add_callback(|| println!("callback"))?; 32 | 33 | // loading gpu kerenel (runs in runtime!) 
34 | let kernel_path = PathBuf::from(AMDGPU_KERNEL_BINARY_PATH); 35 | assert!(kernel_path.exists()); 36 | 37 | let module = Module::load(kernel_path)?; 38 | 39 | // acquiring function handle from gpu kernel 40 | let function = module.get_function("kernel")?; 41 | 42 | // preparing host side buffers 43 | let mut in_host: Vec = vec![0; LEN]; 44 | let out_host: Vec = vec![0; LEN]; 45 | 46 | for i in 0..LEN { 47 | in_host[i] = i as u32; 48 | } 49 | 50 | // preparing gpu side buffers 51 | let input = DeviceMemory::::new(LEN)?; 52 | let output = DeviceMemory::::new(LEN)?; 53 | 54 | // Copy data from host to device 55 | input.copy_from_host_async(in_host, &stream)?; 56 | 57 | // providing arguments for kernel 58 | let kernel_args = [input.as_kernel_arg(), output.as_kernel_arg()]; 59 | 60 | // setting up launch args 61 | let grid_dim = Dim3 { x: 2, y: 1, z: 1 }; 62 | let block_dim = Dim3 { 63 | x: (LEN / 2) as u32, 64 | y: 1, 65 | z: 1, 66 | }; 67 | 68 | function.launch(grid_dim, block_dim, 0, Some(&stream), &mut kernel_args.clone())?; 69 | 70 | // retriving computed data 71 | let pending = output.copy_to_host_async(out_host, &stream)?; 72 | 73 | // synchronizing memory (awaiting for copy to finish) 74 | let out_host = stream.synchronize_memory(pending)?; 75 | println!("Output: {:?}", &out_host[..256]); 76 | 77 | Ok(()) 78 | } 79 | -------------------------------------------------------------------------------- /src/rocblas/mod.rs: -------------------------------------------------------------------------------- 1 | // src/rocblas/mod.rs 2 | 3 | // Private modules 4 | pub mod error; 5 | pub mod handle; 6 | pub mod level1; 7 | pub mod level2; 8 | pub mod level3; 9 | pub mod types; 10 | pub mod utils; 11 | pub(crate) mod macros; 12 | // We need to make this public for the rest of the crate 13 | // but don't necessarily want to expose it to users 14 | #[allow(warnings)] 15 | pub(crate) mod bindings; 16 | 17 | // Public re-export of FFI for internal use 18 | mod async_ops; 19 | 
pub mod ffi; 20 | 21 | // Re-export the main components for the public API 22 | pub use error::{Error, Result}; 23 | pub use handle::Handle; 24 | pub use level1::{ 25 | amax, 26 | amax_batched, 27 | amax_strided_batched, 28 | amin, 29 | amin_batched, 30 | amin_strided_batched, 31 | asum, 32 | // batched variants 33 | asum_batched, 34 | // strided batched variants 35 | asum_strided_batched, 36 | axpy, 37 | axpy_batched, 38 | axpy_strided_batched, 39 | copy, 40 | copy_batched, 41 | copy_strided_batched, 42 | dot, 43 | dot_batched, 44 | dot_strided_batched, 45 | dotc, 46 | dotc_batched, 47 | dotc_strided_batched, 48 | dotu, 49 | dotu_batched, 50 | dotu_strided_batched, 51 | nrm2, 52 | nrm2_batched, 53 | nrm2_strided_batched, 54 | rot, 55 | rot_batched, 56 | rot_strided_batched, 57 | rotg, 58 | rotg_batched, 59 | rotg_strided_batched, 60 | rotm, 61 | rotm_batched, 62 | rotm_strided_batched, 63 | rotmg, 64 | rotmg_batched, 65 | rotmg_strided_batched, 66 | scal, 67 | scal_batched, 68 | scal_strided_batched, 69 | swap, 70 | swap_batched, 71 | swap_strided_batched, 72 | }; 73 | pub use level2::{ 74 | gbmv, 75 | // batched variants 76 | gbmv_batched, 77 | // strided batched variants 78 | gbmv_strided_batched, 79 | gemv, 80 | gemv_batched, 81 | gemv_strided_batched, 82 | hbmv, 83 | hbmv_batched, 84 | hbmv_strided_batched, 85 | }; 86 | pub use level3::{gemm, gemm_batched, gemm_strided_batched}; 87 | pub use types::{ 88 | rocblas_bfloat16, rocblas_datatype, rocblas_diagonal, rocblas_double_complex, rocblas_fill, 89 | rocblas_float_complex, rocblas_half, rocblas_operation, rocblas_side, 90 | }; 91 | pub use utils::{ 92 | AtomicsMode, GemmAlgo, GemmFlags, LayerMode, MathMode, PerformanceMetric, PointerMode, 93 | get_atomics_mode, get_math_mode, get_performance_metric, get_pointer_mode, set_atomics_mode, 94 | set_math_mode, set_performance_metric, set_pointer_mode, 95 | }; 96 | 97 | /// Create a RocBLAS handle 98 | pub fn create_handle() -> Result { 99 | Handle::new() 100 | } 101 
| 102 | /// Initialize RocBLAS 103 | /// 104 | /// Note: In most cases, explicit initialization is not required 105 | /// as handle creation will initialize the library 106 | pub fn init() -> Result<()> { 107 | // Creating and immediately dropping a handle 108 | // will initialize rocBLAS and free resources 109 | let _ = create_handle()?; 110 | Ok(()) 111 | } 112 | -------------------------------------------------------------------------------- /src/miopen/error.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/error.rs 2 | 3 | use crate::miopen::ffi; 4 | use std::error::Error as StdError; 5 | use std::fmt; 6 | 7 | /// Error type for MIOpen operations 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | pub struct Error { 10 | code: ffi::miopenStatus_t, 11 | } 12 | 13 | /// Result type for MIOpen operations 14 | pub type Result = std::result::Result; 15 | 16 | impl Error { 17 | /// Create a new error from a MIOpen status code 18 | pub fn new(code: ffi::miopenStatus_t) -> Self { 19 | Self { code } 20 | } 21 | 22 | /// Convert a MIOpen status code to a Result 23 | pub fn from_miopen_status(status: ffi::miopenStatus_t) -> Result 24 | where 25 | T: Default, 26 | { 27 | if status == ffi::miopenStatus_t_miopenStatusSuccess { 28 | Ok(T::default()) 29 | } else { 30 | Err(Error::new(status)) 31 | } 32 | } 33 | 34 | /// Convert a MIOpen status code to a Result with a specific value 35 | pub fn from_miopen_status_with_value(status: ffi::miopenStatus_t, value: T) -> Result { 36 | if status == ffi::miopenStatus_t_miopenStatusSuccess { 37 | Ok(value) 38 | } else { 39 | Err(Error::new(status)) 40 | } 41 | } 42 | 43 | /// Returns true if the status code represents success 44 | pub fn is_success(&self) -> bool { 45 | self.code == ffi::miopenStatus_t_miopenStatusSuccess 46 | } 47 | 48 | /// Get the raw status code 49 | pub fn code(&self) -> ffi::miopenStatus_t { 50 | self.code 51 | } 52 | 53 | /// Returns the error description as 
a string 54 | pub fn description(&self) -> &'static str { 55 | unsafe { 56 | let desc_ptr = ffi::miopenGetErrorString(self.code); 57 | if desc_ptr.is_null() { 58 | "Unknown error" 59 | } else { 60 | // This is safe because miopenGetErrorString returns a static string 61 | std::ffi::CStr::from_ptr(desc_ptr) 62 | .to_str() 63 | .unwrap_or("Invalid error string") 64 | } 65 | } 66 | } 67 | } 68 | 69 | impl fmt::Display for Error { 70 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 71 | write!(f, "MIOpen error {}: {}", self.code, self.description()) 72 | } 73 | } 74 | 75 | impl StdError for Error {} 76 | 77 | // Define error conversion functions for common MIOpen error codes 78 | impl Error { 79 | pub fn is_not_initialized(&self) -> bool { 80 | self.code == ffi::miopenStatus_t_miopenStatusNotInitialized 81 | } 82 | 83 | pub fn is_invalid_value(&self) -> bool { 84 | self.code == ffi::miopenStatus_t_miopenStatusInvalidValue 85 | } 86 | 87 | pub fn is_bad_param(&self) -> bool { 88 | self.code == ffi::miopenStatus_t_miopenStatusBadParm 89 | } 90 | 91 | pub fn is_alloc_failed(&self) -> bool { 92 | self.code == ffi::miopenStatus_t_miopenStatusAllocFailed 93 | } 94 | 95 | pub fn is_internal_error(&self) -> bool { 96 | self.code == ffi::miopenStatus_t_miopenStatusInternalError 97 | } 98 | 99 | pub fn is_not_implemented(&self) -> bool { 100 | self.code == ffi::miopenStatus_t_miopenStatusNotImplemented 101 | } 102 | 103 | pub fn is_unsupported_op(&self) -> bool { 104 | self.code == ffi::miopenStatus_t_miopenStatusUnsupportedOp 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/hip/kernel.rs: -------------------------------------------------------------------------------- 1 | // src/hip/kernel.rs 2 | // 3 | // Kernel launching functions for HIP 4 | 5 | use crate::hip::Stream; 6 | use crate::hip::error::{Error, Result}; 7 | use crate::hip::ffi; 8 | use crate::hip::memory::KernelArg; 9 | use 
crate::hip::utils::Dim3; 10 | use std::ffi::{CString, c_void}; 11 | use std::ptr; 12 | 13 | /// A wrapper around a HIP function (kernel) 14 | pub struct Function { 15 | function: ffi::hipFunction_t, 16 | } 17 | 18 | impl Function { 19 | /// Create a new function from a module and function name 20 | pub unsafe fn new(module: ffi::hipModule_t, name: &str) -> Result { 21 | let func_name = CString::new(name).unwrap(); 22 | let mut function = ptr::null_mut(); 23 | 24 | let error = unsafe { ffi::hipModuleGetFunction(&mut function, module, func_name.as_ptr()) }; 25 | 26 | if error != ffi::hipError_t_hipSuccess { 27 | return Err(Error::new(error)); 28 | } 29 | 30 | Ok(Self { function }) 31 | } 32 | 33 | /// Launch the kernel with the given parameters 34 | pub fn launch( 35 | &self, 36 | grid_dim: Dim3, 37 | block_dim: Dim3, 38 | shared_mem_bytes: u32, 39 | stream: Option<&Stream>, 40 | kernel_params: &mut [*mut c_void], 41 | ) -> Result<()> { 42 | let stream_ptr = match stream { 43 | Some(s) => s.as_raw(), 44 | None => ptr::null_mut(), 45 | }; 46 | 47 | let error = unsafe { 48 | ffi::hipModuleLaunchKernel( 49 | self.function, 50 | grid_dim.x, 51 | grid_dim.y, 52 | grid_dim.z, 53 | block_dim.x, 54 | block_dim.y, 55 | block_dim.z, 56 | shared_mem_bytes, 57 | stream_ptr, 58 | kernel_params.as_mut_ptr(), 59 | ptr::null_mut(), // extra 60 | ) 61 | }; 62 | 63 | if error != ffi::hipError_t_hipSuccess { 64 | return Err(Error::new(error)); 65 | } 66 | 67 | Ok(()) 68 | } 69 | 70 | /// Get the raw function handle 71 | pub fn as_raw(&self) -> ffi::hipFunction_t { 72 | self.function 73 | } 74 | 75 | // Creates Function from raw function ponter 76 | pub unsafe fn from_raw(function: ffi::hipFunction_t) -> Self { 77 | Self { function } 78 | } 79 | } 80 | 81 | /// A trait for types that can be passed as kernel arguments 82 | pub trait AsKernelArg { 83 | /// Get a pointer to the argument value 84 | fn as_kernel_arg(&self) -> KernelArg; 85 | } 86 | 87 | // Implement KernelArg for common 
types 88 | macro_rules! impl_kernel_arg { 89 | ($($t:ty),*) => { 90 | $( 91 | impl AsKernelArg for $t { 92 | fn as_kernel_arg(&self) -> KernelArg { 93 | self as *const $t as *mut c_void 94 | } 95 | } 96 | )* 97 | }; 98 | } 99 | 100 | impl_kernel_arg!( 101 | usize, isize, i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, bool 102 | ); 103 | 104 | #[macro_export] 105 | macro_rules! kernel_args { 106 | ($($i:expr),*) => { 107 | &mut [$($i.as_kernel_arg()),*] 108 | }; 109 | } 110 | 111 | /// Helper function to convert a Stream reference to the rocrand stream type 112 | pub fn stream_to_rocrand(stream: &Stream) -> crate::rocrand::bindings::hipStream_t { 113 | // Safe cast because both represent the same underlying HIP stream 114 | stream.as_raw() as crate::rocrand::bindings::hipStream_t 115 | } 116 | -------------------------------------------------------------------------------- /src/rocrand/error.rs: -------------------------------------------------------------------------------- 1 | // src/rocrand/error.rs 2 | 3 | use crate::rocrand::bindings; 4 | use std::fmt; 5 | 6 | /// rocRAND error types 7 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 8 | pub enum Error { 9 | /// Header file and linked library version do not match 10 | VersionMismatch, 11 | /// Generator was not created using rocrand_create_generator 12 | NotCreated, 13 | /// Memory allocation failed during execution 14 | AllocationFailed, 15 | /// Generator type is wrong 16 | TypeError, 17 | /// Argument out of range 18 | OutOfRange, 19 | /// Requested size is not a multiple of quasirandom generator's dimension, 20 | /// or requested size is not even, or pointer is misaligned 21 | LengthNotMultiple, 22 | /// GPU does not have double precision 23 | DoublePrecisionRequired, 24 | /// Kernel launch failure 25 | LaunchFailure, 26 | /// Internal library error 27 | InternalError, 28 | /// Unknown error 29 | Unknown(u32), 30 | } 31 | 32 | /// Specialized Result type for rocrand operations 33 | pub type Result = 
std::result::Result; 34 | 35 | impl Error { 36 | /// Convert a rocrand status code to a Result 37 | pub(crate) fn from_status(status: u32) -> Result<()> { 38 | match status { 39 | bindings::rocrand_status_ROCRAND_STATUS_SUCCESS => Ok(()), 40 | bindings::rocrand_status_ROCRAND_STATUS_VERSION_MISMATCH => Err(Error::VersionMismatch), 41 | bindings::rocrand_status_ROCRAND_STATUS_NOT_CREATED => Err(Error::NotCreated), 42 | bindings::rocrand_status_ROCRAND_STATUS_ALLOCATION_FAILED => { 43 | Err(Error::AllocationFailed) 44 | } 45 | bindings::rocrand_status_ROCRAND_STATUS_TYPE_ERROR => Err(Error::TypeError), 46 | bindings::rocrand_status_ROCRAND_STATUS_OUT_OF_RANGE => Err(Error::OutOfRange), 47 | bindings::rocrand_status_ROCRAND_STATUS_LENGTH_NOT_MULTIPLE => { 48 | Err(Error::LengthNotMultiple) 49 | } 50 | bindings::rocrand_status_ROCRAND_STATUS_DOUBLE_PRECISION_REQUIRED => { 51 | Err(Error::DoublePrecisionRequired) 52 | } 53 | bindings::rocrand_status_ROCRAND_STATUS_LAUNCH_FAILURE => Err(Error::LaunchFailure), 54 | bindings::rocrand_status_ROCRAND_STATUS_INTERNAL_ERROR => Err(Error::InternalError), 55 | other => Err(Error::Unknown(other)), 56 | } 57 | } 58 | } 59 | 60 | impl fmt::Display for Error { 61 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 62 | match self { 63 | Error::VersionMismatch => { 64 | write!(f, "Header file and linked library version do not match") 65 | } 66 | Error::NotCreated => write!( 67 | f, 68 | "Generator was not created using rocrand_create_generator" 69 | ), 70 | Error::AllocationFailed => write!(f, "Memory allocation failed during execution"), 71 | Error::TypeError => write!(f, "Generator type is wrong"), 72 | Error::OutOfRange => write!(f, "Argument out of range"), 73 | Error::LengthNotMultiple => write!( 74 | f, 75 | "Length not multiple of dimension or other alignment issue" 76 | ), 77 | Error::DoublePrecisionRequired => write!(f, "GPU does not have double precision"), 78 | Error::LaunchFailure => write!(f, "Kernel launch 
failure"), 79 | Error::InternalError => write!(f, "Internal library error"), 80 | Error::Unknown(code) => write!(f, "Unknown error (code: {})", code), 81 | } 82 | } 83 | } 84 | 85 | impl std::error::Error for Error {} 86 | -------------------------------------------------------------------------------- /src/rocsparse/matrix.rs: -------------------------------------------------------------------------------- 1 | //! Sparse matrix types and formats 2 | 3 | use crate::rocsparse::descriptor::IndexBase; 4 | use crate::rocsparse::error::{Result, status_to_result}; 5 | use crate::rocsparse::{ 6 | rocsparse_create_hyb_mat, rocsparse_create_mat_info, rocsparse_destroy_hyb_mat, 7 | rocsparse_destroy_mat_info, rocsparse_destroy_spmat_descr, rocsparse_hyb_mat, 8 | rocsparse_hyb_partition_, rocsparse_hyb_partition__rocsparse_hyb_partition_auto, 9 | rocsparse_hyb_partition__rocsparse_hyb_partition_max, 10 | rocsparse_hyb_partition__rocsparse_hyb_partition_user, rocsparse_mat_info, 11 | rocsparse_spmat_descr, 12 | }; 13 | use std::marker::PhantomData; 14 | use std::mem::MaybeUninit; 15 | 16 | /// HYB matrix partitioning type 17 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 18 | pub enum HybPartition { 19 | /// Automatically decide on ELL nnz per row 20 | Auto, 21 | /// User given ELL nnz per row 22 | User, 23 | /// Max ELL nnz per row, no COO part 24 | Max, 25 | } 26 | 27 | impl From for rocsparse_hyb_partition_ { 28 | fn from(partition: HybPartition) -> Self { 29 | match partition { 30 | HybPartition::Auto => rocsparse_hyb_partition__rocsparse_hyb_partition_auto, 31 | HybPartition::User => rocsparse_hyb_partition__rocsparse_hyb_partition_user, 32 | HybPartition::Max => rocsparse_hyb_partition__rocsparse_hyb_partition_max, 33 | } 34 | } 35 | } 36 | 37 | /// Hybrid matrix format (ELL + COO) 38 | pub struct HybMatrix { 39 | pub(crate) inner: rocsparse_hyb_mat, 40 | } 41 | 42 | impl HybMatrix { 43 | /// Create a new HYB matrix 44 | pub fn new() -> Result { 45 | let mut hyb = 
MaybeUninit::uninit(); 46 | let status = unsafe { rocsparse_create_hyb_mat(hyb.as_mut_ptr()) }; 47 | status_to_result(status)?; 48 | let hyb = unsafe { hyb.assume_init() }; 49 | Ok(Self { inner: hyb }) 50 | } 51 | } 52 | 53 | impl Drop for HybMatrix { 54 | fn drop(&mut self) { 55 | unsafe { 56 | // Ignore error on drop 57 | let _ = rocsparse_destroy_hyb_mat(self.inner); 58 | } 59 | } 60 | } 61 | 62 | /// Matrix info structure 63 | pub struct MatrixInfo { 64 | pub(crate) inner: rocsparse_mat_info, 65 | } 66 | 67 | impl MatrixInfo { 68 | /// Create a new matrix info 69 | pub fn new() -> Result { 70 | let mut info = MaybeUninit::uninit(); 71 | let status = unsafe { rocsparse_create_mat_info(info.as_mut_ptr()) }; 72 | status_to_result(status)?; 73 | let info = unsafe { info.assume_init() }; 74 | Ok(Self { inner: info }) 75 | } 76 | } 77 | 78 | impl Drop for MatrixInfo { 79 | fn drop(&mut self) { 80 | unsafe { 81 | // Ignore error on drop 82 | let _ = rocsparse_destroy_mat_info(self.inner); 83 | } 84 | } 85 | } 86 | 87 | /// Sparse matrix representation 88 | pub struct SparseMatrix { 89 | pub(crate) inner: rocsparse_spmat_descr, 90 | _phantom: PhantomData, 91 | } 92 | 93 | impl Drop for SparseMatrix { 94 | fn drop(&mut self) { 95 | unsafe { 96 | // Ignore error on drop 97 | let _ = rocsparse_destroy_spmat_descr(self.inner); 98 | } 99 | } 100 | } 101 | 102 | /// CSR (Compressed Sparse Row) matrix format helper 103 | pub struct CsrMatrix { 104 | /// Number of rows 105 | pub rows: i32, 106 | /// Number of columns 107 | pub cols: i32, 108 | /// Row pointers 109 | pub row_ptr: Vec, 110 | /// Column indices 111 | pub col_ind: Vec, 112 | /// Values 113 | pub values: Vec, 114 | /// Index base (zero or one) 115 | pub index_base: IndexBase, 116 | } 117 | -------------------------------------------------------------------------------- /src/rocprofiler/error.rs: -------------------------------------------------------------------------------- 1 | // src/rocprofiler/error.rs 2 | 3 | 
use crate::hip; 4 | use std::fmt; 5 | use std::error::Error as StdError; 6 | use std::ffi::CStr; 7 | 8 | use super::bindings; 9 | 10 | /// Error type for ROCProfiler operations 11 | #[derive(Debug, Clone, Copy)] 12 | pub struct Error { 13 | status: u32, // Using hsa_status_t 14 | } 15 | 16 | /// Result type for ROCProfiler operations 17 | pub type Result = std::result::Result; 18 | 19 | impl Error { 20 | /// Create a new error from an HSA status code 21 | pub fn new(status: u32) -> Self { 22 | Self { status } 23 | } 24 | 25 | /// Returns true if the error code represents success 26 | pub fn is_success(&self) -> bool { 27 | self.status == bindings::hsa_status_t_HSA_STATUS_SUCCESS 28 | } 29 | 30 | /// Get the raw error code 31 | pub fn code(&self) -> u32 { 32 | self.status 33 | } 34 | 35 | /// Convert an HSA status code to a Result 36 | pub fn from_hsa_status(status: u32) -> Result 37 | where 38 | T: Default, 39 | { 40 | if status == bindings::hsa_status_t_HSA_STATUS_SUCCESS { 41 | Ok(T::default()) 42 | } else { 43 | Err(Error::new(status)) 44 | } 45 | } 46 | 47 | /// Convert an HSA status code to a Result with a specific value 48 | pub fn from_hsa_status_with_value(status: u32, value: T) -> Result { 49 | if status == bindings::hsa_status_t_HSA_STATUS_SUCCESS { 50 | Ok(value) 51 | } else { 52 | Err(Error::new(status)) 53 | } 54 | } 55 | 56 | /// Returns the error description as a string 57 | pub fn description(&self) -> &'static str { 58 | match self.status { 59 | bindings::hsa_status_t_HSA_STATUS_SUCCESS => "Success", 60 | bindings::hsa_status_t_HSA_STATUS_ERROR => "Generic error", 61 | bindings::hsa_status_t_HSA_STATUS_ERROR_INVALID_ARGUMENT => "Invalid argument", 62 | bindings::hsa_status_t_HSA_STATUS_ERROR_OUT_OF_RESOURCES => "Out of resources", 63 | bindings::hsa_status_t_HSA_STATUS_ERROR_NOT_INITIALIZED => "Not initialized", 64 | bindings::hsa_status_t_HSA_STATUS_ERROR_INVALID_AGENT => "Invalid agent", 65 | bindings::hsa_status_t_HSA_STATUS_ERROR_INVALID_REGION 
=> "Invalid region", 66 | _ => unsafe { 67 | // Try to get the actual error string from ROCProfiler 68 | let mut error_str_ptr = std::ptr::null(); 69 | if bindings::rocprofiler_error_string(&mut error_str_ptr) == bindings::hsa_status_t_HSA_STATUS_SUCCESS && !error_str_ptr.is_null() { 70 | let c_str = CStr::from_ptr(error_str_ptr); 71 | match c_str.to_str() { 72 | Ok(s) => { 73 | // This is not ideal as we're returning a slice that might not live long enough, 74 | // but ROCProfiler documentation suggests the string is static 75 | s 76 | } 77 | Err(_) => "Unknown error (invalid UTF-8 in error string)", 78 | } 79 | } else { 80 | "Unknown error" 81 | } 82 | } 83 | } 84 | } 85 | } 86 | 87 | impl fmt::Display for Error { 88 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 89 | write!(f, "ROCProfiler error {}: {}", self.status, self.description()) 90 | } 91 | } 92 | 93 | impl StdError for Error {} 94 | 95 | // Automatic conversion from HIP errors 96 | impl From for Error { 97 | fn from(error: hip::Error) -> Self { 98 | // Map HIP errors to a generic HSA error 99 | Error::new(bindings::hsa_status_t_HSA_STATUS_ERROR) 100 | } 101 | } -------------------------------------------------------------------------------- /src/rocsparse/handle.rs: -------------------------------------------------------------------------------- 1 | //! 
ROCsparse library context handle 2 | 3 | use crate::rocsparse::error::{Result, status_to_result}; 4 | use crate::rocsparse::{ 5 | ihipStream_t, rocsparse_create_handle, rocsparse_destroy_handle, rocsparse_get_pointer_mode, 6 | rocsparse_get_stream, rocsparse_get_version, rocsparse_handle, rocsparse_pointer_mode_, 7 | rocsparse_pointer_mode__rocsparse_pointer_mode_device, 8 | rocsparse_pointer_mode__rocsparse_pointer_mode_host, rocsparse_set_pointer_mode, 9 | rocsparse_set_stream, 10 | }; 11 | use std::mem::MaybeUninit; 12 | 13 | /// ROCsparse library context 14 | pub struct Handle { 15 | pub(crate) inner: rocsparse_handle, 16 | } 17 | 18 | impl Handle { 19 | /// Create a new ROCsparse handle 20 | pub fn new() -> Result { 21 | let mut handle = MaybeUninit::uninit(); 22 | let status = unsafe { rocsparse_create_handle(handle.as_mut_ptr()) }; 23 | status_to_result(status)?; 24 | let handle = unsafe { handle.assume_init() }; 25 | Ok(Self { inner: handle }) 26 | } 27 | 28 | /// Set the stream for the handle 29 | pub unsafe fn set_stream(&self, stream: *mut ihipStream_t) -> Result<()> { 30 | let status = unsafe { rocsparse_set_stream(self.inner, stream) }; 31 | status_to_result(status) 32 | } 33 | 34 | /// Get the current stream 35 | pub fn get_stream(&self) -> Result<*mut ihipStream_t> { 36 | let mut stream = MaybeUninit::uninit(); 37 | let status = unsafe { rocsparse_get_stream(self.inner, stream.as_mut_ptr()) }; 38 | status_to_result(status)?; 39 | Ok(unsafe { stream.assume_init() }) 40 | } 41 | 42 | /// Set pointer mode 43 | pub fn set_pointer_mode(&self, mode: PointerMode) -> Result<()> { 44 | let status = unsafe { rocsparse_set_pointer_mode(self.inner, mode.into()) }; 45 | status_to_result(status) 46 | } 47 | 48 | /// Get pointer mode 49 | pub fn get_pointer_mode(&self) -> Result { 50 | let mut mode = MaybeUninit::uninit(); 51 | let status = unsafe { rocsparse_get_pointer_mode(self.inner, mode.as_mut_ptr()) }; 52 | status_to_result(status)?; 53 | Ok(unsafe { 
PointerMode::from_raw(mode.assume_init()) }) 54 | } 55 | 56 | /// Get ROCsparse version 57 | pub fn get_version(&self) -> Result<(u32, u32, u32)> { 58 | let mut version = MaybeUninit::uninit(); 59 | let status = unsafe { rocsparse_get_version(self.inner, version.as_mut_ptr()) }; 60 | status_to_result(status)?; 61 | let version = unsafe { version.assume_init() }; 62 | let patch = version % 100; 63 | let minor = (version / 100) % 1000; 64 | let major = version / 100000; 65 | Ok((major as u32, minor as u32, patch as u32)) 66 | } 67 | } 68 | 69 | impl Drop for Handle { 70 | fn drop(&mut self) { 71 | unsafe { 72 | // Ignore error on drop 73 | let _ = rocsparse_destroy_handle(self.inner); 74 | } 75 | } 76 | } 77 | 78 | /// Pointer mode for ROCsparse functions 79 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 80 | pub enum PointerMode { 81 | /// Scalar pointers are in host memory 82 | Host, 83 | /// Scalar pointers are in device memory 84 | Device, 85 | } 86 | 87 | impl PointerMode { 88 | pub(crate) fn from_raw(raw: rocsparse_pointer_mode_) -> Self { 89 | match raw { 90 | rocsparse_pointer_mode__rocsparse_pointer_mode_device => PointerMode::Device, 91 | _ => PointerMode::Host, 92 | } 93 | } 94 | } 95 | 96 | impl From for rocsparse_pointer_mode_ { 97 | fn from(mode: PointerMode) -> Self { 98 | match mode { 99 | PointerMode::Host => rocsparse_pointer_mode__rocsparse_pointer_mode_host, 100 | PointerMode::Device => rocsparse_pointer_mode__rocsparse_pointer_mode_device, 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/hip/error.rs: -------------------------------------------------------------------------------- 1 | // src/hip/error.rs 2 | 3 | use crate::hip::ffi; 4 | use std::error::Error as StdError; 5 | use std::fmt; 6 | 7 | /// Error type for HIP operations 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | pub struct Error { 10 | code: ffi::hipError_t, 11 | } 12 | 13 | /// Result type for HIP operations 
14 | pub type Result = std::result::Result; 15 | 16 | impl Error { 17 | /// Create a new error from a HIP error code 18 | pub fn new(code: ffi::hipError_t) -> Self { 19 | Self { code } 20 | } 21 | 22 | /// Convert a HIP error code to a Result 23 | pub fn from_hip_error(error: ffi::hipError_t) -> Result 24 | where 25 | T: Default, 26 | { 27 | if error == ffi::hipError_t_hipSuccess { 28 | Ok(T::default()) 29 | } else { 30 | Err(Error::new(error)) 31 | } 32 | } 33 | 34 | /// Convert a HIP error code to a Result with a specific value 35 | pub fn from_hip_error_with_value(error: ffi::hipError_t, value: T) -> Result { 36 | if error == ffi::hipError_t_hipSuccess { 37 | Ok(value) 38 | } else { 39 | Err(Error::new(error)) 40 | } 41 | } 42 | 43 | /// Returns true if the error code represents success 44 | pub fn is_success(&self) -> bool { 45 | self.code == ffi::hipError_t_hipSuccess 46 | } 47 | 48 | /// Get the raw error code 49 | pub fn code(&self) -> ffi::hipError_t { 50 | self.code 51 | } 52 | 53 | /// Returns the error name as a string 54 | pub fn name(&self) -> &'static str { 55 | unsafe { 56 | let name_ptr = ffi::hipGetErrorName(self.code); 57 | if name_ptr.is_null() { 58 | "Unknown error" 59 | } else { 60 | // This is safe because hipGetErrorName returns a static string 61 | std::ffi::CStr::from_ptr(name_ptr) 62 | .to_str() 63 | .unwrap_or("Invalid error string") 64 | } 65 | } 66 | } 67 | 68 | /// Returns the error description as a string 69 | pub fn description(&self) -> &'static str { 70 | unsafe { 71 | let desc_ptr = ffi::hipGetErrorString(self.code); 72 | if desc_ptr.is_null() { 73 | "Unknown error" 74 | } else { 75 | // This is safe because hipGetErrorString returns a static string 76 | std::ffi::CStr::from_ptr(desc_ptr) 77 | .to_str() 78 | .unwrap_or("Invalid error string") 79 | } 80 | } 81 | } 82 | } 83 | 84 | impl fmt::Display for Error { 85 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 86 | write!( 87 | f, 88 | "HIP error {}: {} - {}", 89 | 
self.code, 90 | self.name(), 91 | self.description() 92 | ) 93 | } 94 | } 95 | 96 | impl StdError for Error {} 97 | 98 | // Define error conversion functions for common HIP error codes 99 | impl Error { 100 | pub fn is_invalid_value(&self) -> bool { 101 | self.code == ffi::hipError_t_hipErrorInvalidValue 102 | } 103 | 104 | pub fn is_out_of_memory(&self) -> bool { 105 | self.code == ffi::hipError_t_hipErrorOutOfMemory 106 | || self.code == ffi::hipError_t_hipErrorMemoryAllocation 107 | } 108 | 109 | pub fn is_not_initialized(&self) -> bool { 110 | self.code == ffi::hipError_t_hipErrorNotInitialized 111 | } 112 | 113 | pub fn is_invalid_device(&self) -> bool { 114 | self.code == ffi::hipError_t_hipErrorInvalidDevice 115 | } 116 | 117 | pub fn is_invalid_context(&self) -> bool { 118 | self.code == ffi::hipError_t_hipErrorInvalidContext 119 | } 120 | 121 | pub fn is_not_ready(&self) -> bool { 122 | self.code == ffi::hipError_t_hipErrorNotReady 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/rocsparse/error.rs: -------------------------------------------------------------------------------- 1 | //! 
Error types for ROCsparse operations 2 | 3 | use crate::rocsparse::rocsparse_status; 4 | 5 | use super::bindings; 6 | 7 | /// Error type for ROCsparse operations 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | pub enum Error { 10 | InvalidHandle, 11 | NotImplemented, 12 | InvalidPointer, 13 | InvalidSize, 14 | MemoryError, 15 | InternalError, 16 | InvalidValue, 17 | ArchMismatch, 18 | ZeroPivot, 19 | NotInitialized, 20 | TypeMismatch, 21 | RequiresSortedStorage, 22 | ThrownException, 23 | Continue, // This is not an error but part of the status enum 24 | Unknown(i32), 25 | } 26 | 27 | impl std::fmt::Display for Error { 28 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 29 | match self { 30 | Error::InvalidHandle => write!(f, "ROCsparse: handle not initialized, invalid or null"), 31 | Error::NotImplemented => write!(f, "ROCsparse: function is not implemented"), 32 | Error::InvalidPointer => write!(f, "ROCsparse: invalid pointer parameter"), 33 | Error::InvalidSize => write!(f, "ROCsparse: invalid size parameter"), 34 | Error::MemoryError => write!(f, "ROCsparse: failed memory allocation, copy, dealloc"), 35 | Error::InternalError => write!(f, "ROCsparse: other internal library failure"), 36 | Error::InvalidValue => write!(f, "ROCsparse: invalid value parameter"), 37 | Error::ArchMismatch => write!(f, "ROCsparse: device arch is not supported"), 38 | Error::ZeroPivot => write!(f, "ROCsparse: encountered zero pivot"), 39 | Error::NotInitialized => write!(f, "ROCsparse: descriptor has not been initialized"), 40 | Error::TypeMismatch => write!(f, "ROCsparse: index types do not match"), 41 | Error::RequiresSortedStorage => write!(f, "ROCsparse: sorted storage required"), 42 | Error::ThrownException => write!(f, "ROCsparse: exception being thrown"), 43 | Error::Continue => write!(f, "ROCsparse: nothing preventing function to proceed"), 44 | Error::Unknown(code) => write!(f, "ROCsparse: unknown error code: {}", code), 45 | } 46 | } 47 | } 48 | 49 
| impl std::error::Error for Error {} 50 | 51 | /// Alias for Result with ROCsparse error 52 | pub type Result = std::result::Result; 53 | 54 | /// Convert low-level status to Result 55 | pub(crate) fn status_to_result(status: rocsparse_status) -> Result<()> { 56 | match status { 57 | bindings::rocsparse_status__rocsparse_status_success => Ok(()), 58 | bindings::rocsparse_status__rocsparse_status_invalid_handle => Err(Error::InvalidHandle), 59 | bindings::rocsparse_status__rocsparse_status_not_implemented => Err(Error::NotImplemented), 60 | bindings::rocsparse_status__rocsparse_status_invalid_pointer => Err(Error::InvalidPointer), 61 | bindings::rocsparse_status__rocsparse_status_invalid_size => Err(Error::InvalidSize), 62 | bindings::rocsparse_status__rocsparse_status_memory_error => Err(Error::MemoryError), 63 | bindings::rocsparse_status__rocsparse_status_internal_error => Err(Error::InternalError), 64 | bindings::rocsparse_status__rocsparse_status_invalid_value => Err(Error::InvalidValue), 65 | bindings::rocsparse_status__rocsparse_status_arch_mismatch => Err(Error::ArchMismatch), 66 | bindings::rocsparse_status__rocsparse_status_zero_pivot => Err(Error::ZeroPivot), 67 | bindings::rocsparse_status__rocsparse_status_not_initialized => Err(Error::NotInitialized), 68 | bindings::rocsparse_status__rocsparse_status_type_mismatch => Err(Error::TypeMismatch), 69 | bindings::rocsparse_status__rocsparse_status_requires_sorted_storage => { 70 | Err(Error::RequiresSortedStorage) 71 | } 72 | bindings::rocsparse_status__rocsparse_status_thrown_exception => { 73 | Err(Error::ThrownException) 74 | } 75 | bindings::rocsparse_status__rocsparse_status_continue => Err(Error::Continue), 76 | _ => Err(Error::Unknown(status as i32)), 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/miopen/handle.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/handle.rs 2 | 3 | use 
crate::hip::{Stream, bindings::hipStream_t}; 4 | use crate::miopen::error::{Error, Result}; 5 | use crate::miopen::ffi; 6 | use std::ptr; 7 | 8 | /// Safe wrapper for MIOpen handle 9 | pub struct Handle { 10 | handle: ffi::miopenHandle_t, 11 | } 12 | 13 | impl Handle { 14 | /// Create a new MIOpen handle 15 | pub fn new() -> Result { 16 | let mut handle = ptr::null_mut(); 17 | let status = unsafe { ffi::miopenCreate(&mut handle) }; 18 | 19 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 20 | return Err(Error::new(status)); 21 | } 22 | 23 | Ok(Self { handle }) 24 | } 25 | 26 | /// Create a new MIOpen handle with a stream 27 | pub fn with_stream(stream: &Stream) -> Result { 28 | let mut handle = ptr::null_mut(); 29 | let status = unsafe { 30 | ffi::miopenCreateWithStream( 31 | &mut handle, 32 | stream.as_raw() as crate::miopen::bindings::miopenAcceleratorQueue_t, 33 | ) 34 | }; 35 | 36 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 37 | return Err(Error::new(status)); 38 | } 39 | 40 | Ok(Self { handle }) 41 | } 42 | 43 | /// Set the stream for this handle 44 | pub fn set_stream(&self, stream: &Stream) -> Result<()> { 45 | let status = unsafe { 46 | ffi::miopenSetStream( 47 | self.handle, 48 | stream.as_raw() as crate::miopen::bindings::miopenAcceleratorQueue_t, 49 | ) 50 | }; 51 | 52 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 53 | return Err(Error::new(status)); 54 | } 55 | 56 | Ok(()) 57 | } 58 | 59 | /// Get the current stream for this handle 60 | pub fn get_stream(&self) -> Result { 61 | let mut stream_id = ptr::null_mut(); 62 | let status = unsafe { ffi::miopenGetStream(self.handle, &mut stream_id) }; 63 | 64 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 65 | return Err(Error::new(status)); 66 | } 67 | 68 | // Create a stream from the raw pointer 69 | Ok(Stream::from_raw(stream_id as hipStream_t)) 70 | } 71 | 72 | /// Enable or disable profiling 73 | pub fn enable_profiling(&self, enable: bool) -> Result<()> { 74 | let 
status = unsafe { ffi::miopenEnableProfiling(self.handle, enable) }; 75 | 76 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 77 | return Err(Error::new(status)); 78 | } 79 | 80 | Ok(()) 81 | } 82 | 83 | /// Get the timing of the last kernel executed 84 | pub fn get_kernel_time(&self) -> Result { 85 | let mut time = 0.0; 86 | let status = unsafe { ffi::miopenGetKernelTime(self.handle, &mut time) }; 87 | 88 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 89 | return Err(Error::new(status)); 90 | } 91 | 92 | Ok(time) 93 | } 94 | 95 | /// Set a custom allocator for MIOpen 96 | pub unsafe fn set_allocator( 97 | &self, 98 | allocator: ffi::miopenAllocatorFunction, 99 | deallocator: ffi::miopenDeallocatorFunction, 100 | context: *mut ::std::os::raw::c_void, 101 | ) -> Result<()> { 102 | let status = 103 | unsafe { ffi::miopenSetAllocator(self.handle, allocator, deallocator, context) }; 104 | 105 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 106 | return Err(Error::new(status)); 107 | } 108 | 109 | Ok(()) 110 | } 111 | 112 | /// Get the raw handle 113 | pub fn as_raw(&self) -> ffi::miopenHandle_t { 114 | self.handle 115 | } 116 | } 117 | 118 | impl Drop for Handle { 119 | fn drop(&mut self) { 120 | if !self.handle.is_null() { 121 | unsafe { 122 | let _ = ffi::miopenDestroy(self.handle); 123 | }; 124 | self.handle = ptr::null_mut(); 125 | } 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /src/rocfft/cache.rs: -------------------------------------------------------------------------------- 1 | /*! 2 | # Kernel cache management for rocFFT 3 | 4 | This module provides functions to serialize and deserialize the rocFFT 5 | compiled kernel cache, allowing kernel caches to be saved and loaded 6 | between application runs. 
7 | */ 8 | 9 | use crate::rocfft::bindings; 10 | use crate::rocfft::error::{Error, Result, check_error}; 11 | use std::ptr; 12 | use std::slice; 13 | 14 | /// A buffer containing serialized kernel cache data 15 | pub struct CacheBuffer { 16 | ptr: *mut std::ffi::c_void, 17 | len: usize, 18 | } 19 | 20 | impl CacheBuffer { 21 | /// Get a slice of the buffer contents 22 | pub fn as_slice(&self) -> &[u8] { 23 | if self.ptr.is_null() || self.len == 0 { 24 | &[] 25 | } else { 26 | unsafe { slice::from_raw_parts(self.ptr as *const u8, self.len) } 27 | } 28 | } 29 | 30 | /// Get the length of the buffer in bytes 31 | pub fn len(&self) -> usize { 32 | self.len 33 | } 34 | 35 | /// Check if the buffer is empty 36 | pub fn is_empty(&self) -> bool { 37 | self.len == 0 || self.ptr.is_null() 38 | } 39 | } 40 | 41 | impl Drop for CacheBuffer { 42 | fn drop(&mut self) { 43 | if !self.ptr.is_null() { 44 | unsafe { 45 | bindings::rocfft_cache_buffer_free(self.ptr); 46 | } 47 | self.ptr = ptr::null_mut(); 48 | self.len = 0; 49 | } 50 | } 51 | } 52 | 53 | /// Serialize the current compiled kernel cache into a buffer 54 | /// 55 | /// This function captures the current state of the rocFFT compiled kernel cache 56 | /// and serializes it into a buffer that can be saved and later deserialized. 57 | /// This can significantly improve startup performance for applications that 58 | /// use the same FFT configurations repeatedly. 
59 | /// 60 | /// # Returns 61 | /// 62 | /// A result containing the serialized cache buffer 63 | /// 64 | /// # Example 65 | /// 66 | /// ```no_run 67 | /// use crate::rocfft; 68 | /// use std::fs::File; 69 | /// use std::io::Write; 70 | /// 71 | /// fn main() -> Result<(), Box> { 72 | /// rocfft::setup()?; 73 | /// 74 | /// // After running some transforms, serialize the cache 75 | /// let buffer = rocfft::cache::serialize()?; 76 | /// 77 | /// // Save to a file 78 | /// let mut file = File::create("rocfft_cache.bin")?; 79 | /// file.write_all(buffer.as_slice())?; 80 | /// 81 | /// rocfft::cleanup()?; 82 | /// Ok(()) 83 | /// } 84 | /// ``` 85 | pub fn serialize() -> Result { 86 | let mut ptr: *mut std::ffi::c_void = ptr::null_mut(); 87 | let mut len: usize = 0; 88 | 89 | unsafe { 90 | check_error(bindings::rocfft_cache_serialize(&mut ptr, &mut len))?; 91 | } 92 | 93 | Ok(CacheBuffer { ptr, len }) 94 | } 95 | 96 | /// Deserialize a buffer into the compiled kernel cache 97 | /// 98 | /// This function loads a previously serialized kernel cache into the rocFFT 99 | /// runtime, which can avoid recompilation of kernels and improve startup 100 | /// performance. 
101 | /// 102 | /// # Arguments 103 | /// 104 | /// * `data` - Slice containing the serialized cache data 105 | /// 106 | /// # Returns 107 | /// 108 | /// A result indicating success or an error 109 | /// 110 | /// # Example 111 | /// 112 | /// ```no_run 113 | /// use crate::rocfft; 114 | /// use std::fs::File; 115 | /// use std::io::Read; 116 | /// 117 | /// fn main() -> Result<(), Box> { 118 | /// // Load from a file before initializing rocFFT 119 | /// let mut file = File::open("rocfft_cache.bin")?; 120 | /// let mut data = Vec::new(); 121 | /// file.read_to_end(&mut data)?; 122 | /// 123 | /// rocfft::setup()?; 124 | /// 125 | /// // Deserialize into the cache 126 | /// rocfft::cache::deserialize(&data)?; 127 | /// 128 | /// // Now use rocFFT with precompiled kernels 129 | /// 130 | /// rocfft::cleanup()?; 131 | /// Ok(()) 132 | /// } 133 | /// ``` 134 | pub fn deserialize(data: &[u8]) -> Result<()> { 135 | if data.is_empty() { 136 | return Ok(()); 137 | } 138 | 139 | unsafe { 140 | check_error(bindings::rocfft_cache_deserialize( 141 | data.as_ptr() as *const std::ffi::c_void, 142 | data.len(), 143 | )) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /src/rocrand/mod.rs: -------------------------------------------------------------------------------- 1 | // src/rocrand/mod.rs 2 | // 3 | // Module definition for rocrand 4 | 5 | // Re-export the raw bindings for advanced usage 6 | #[allow(warnings)] 7 | pub mod bindings; 8 | 9 | // Import submodules 10 | pub mod distribution; 11 | pub mod error; 12 | pub mod generator; 13 | pub mod utils; 14 | 15 | // Re-export public items 16 | pub use distribution::{Discrete, LogNormal, Normal, Poisson, Uniform}; 17 | pub use error::{Error, Result}; 18 | pub use generator::{Generator, PseudoRng, QuasiRng}; 19 | 20 | /// Convenient re-exports of random number generator types 21 | pub mod rng_type { 22 | use super::bindings; 23 | 24 | pub const PSEUDO_DEFAULT: u32 = 
bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_DEFAULT; 25 | pub const XORWOW: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_XORWOW; 26 | pub const MRG32K3A: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_MRG32K3A; 27 | pub const MTGP32: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_MTGP32; 28 | pub const PHILOX4_32_10: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_PHILOX4_32_10; 29 | pub const MRG31K3P: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_MRG31K3P; 30 | pub const LFSR113: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_LFSR113; 31 | pub const MT19937: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_MT19937; 32 | pub const THREEFRY2_32_20: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_THREEFRY2_32_20; 33 | pub const THREEFRY2_64_20: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_THREEFRY2_64_20; 34 | pub const THREEFRY4_32_20: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_THREEFRY4_32_20; 35 | pub const THREEFRY4_64_20: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_PSEUDO_THREEFRY4_64_20; 36 | 37 | pub const QUASI_DEFAULT: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_QUASI_DEFAULT; 38 | pub const SOBOL32: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_QUASI_SOBOL32; 39 | pub const SCRAMBLED_SOBOL32: u32 = 40 | bindings::rocrand_rng_type_ROCRAND_RNG_QUASI_SCRAMBLED_SOBOL32; 41 | pub const SOBOL64: u32 = bindings::rocrand_rng_type_ROCRAND_RNG_QUASI_SOBOL64; 42 | pub const SCRAMBLED_SOBOL64: u32 = 43 | bindings::rocrand_rng_type_ROCRAND_RNG_QUASI_SCRAMBLED_SOBOL64; 44 | } 45 | 46 | /// Convenient re-exports of ordering constants 47 | pub mod ordering { 48 | use super::bindings; 49 | 50 | pub const PSEUDO_BEST: u32 = bindings::rocrand_ordering_ROCRAND_ORDERING_PSEUDO_BEST; 51 | pub const PSEUDO_DEFAULT: u32 = bindings::rocrand_ordering_ROCRAND_ORDERING_PSEUDO_DEFAULT; 52 | pub const PSEUDO_SEEDED: u32 = bindings::rocrand_ordering_ROCRAND_ORDERING_PSEUDO_SEEDED; 53 | pub const PSEUDO_LEGACY: u32 = 
bindings::rocrand_ordering_ROCRAND_ORDERING_PSEUDO_LEGACY; 54 | pub const PSEUDO_DYNAMIC: u32 = bindings::rocrand_ordering_ROCRAND_ORDERING_PSEUDO_DYNAMIC; 55 | pub const QUASI_DEFAULT: u32 = bindings::rocrand_ordering_ROCRAND_ORDERING_QUASI_DEFAULT; 56 | } 57 | 58 | /// Re-export direction vector constants 59 | pub mod direction_vector_set { 60 | use super::bindings; 61 | 62 | pub const VECTORS_32_JOEKUO6: u32 = 63 | bindings::rocrand_direction_vector_set_ROCRAND_DIRECTION_VECTORS_32_JOEKUO6; 64 | pub const SCRAMBLED_VECTORS_32_JOEKUO6: u32 = 65 | bindings::rocrand_direction_vector_set_ROCRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6; 66 | pub const VECTORS_64_JOEKUO6: u32 = 67 | bindings::rocrand_direction_vector_set_ROCRAND_DIRECTION_VECTORS_64_JOEKUO6; 68 | pub const SCRAMBLED_VECTORS_64_JOEKUO6: u32 = 69 | bindings::rocrand_direction_vector_set_ROCRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6; 70 | } 71 | 72 | /// Creates the default pseudo-random number generator 73 | pub fn default_generator() -> Result { 74 | PseudoRng::new(rng_type::PSEUDO_DEFAULT) 75 | } 76 | 77 | /// Creates a XORWOW pseudo-random number generator 78 | pub fn xorwow_generator() -> Result { 79 | PseudoRng::new(rng_type::XORWOW) 80 | } 81 | 82 | /// Creates a Sobol32 quasi-random number generator with the specified dimensions 83 | pub fn sobol32_generator(dimensions: u32) -> Result { 84 | let mut rng = QuasiRng::new(rng_type::SOBOL32)?; 85 | rng.set_dimensions(dimensions)?; 86 | Ok(rng) 87 | } 88 | 89 | /// Gets the rocRAND library version 90 | pub fn get_version() -> Result { 91 | // Use fully qualified syntax to call the trait function 92 | ::get_version() 93 | } 94 | -------------------------------------------------------------------------------- /src/hip/ffi.rs: -------------------------------------------------------------------------------- 1 | // src/hip/ffi.rs 2 | // 3 | // FFI bindings for the HIP API 4 | // This file re-exports the necessary symbols from the auto-generated bindings 5 
| 6 | // We assume there's a bindings module that was auto-generated 7 | // using bindgen or similar tool 8 | use crate::hip::bindings; 9 | 10 | // Re-export the necessary types, constants, and functions 11 | 12 | // Error type and constants 13 | pub use bindings::hipError_t; 14 | pub use bindings::hipError_t_hipErrorInvalidContext; 15 | pub use bindings::hipError_t_hipErrorInvalidDevice; 16 | pub use bindings::hipError_t_hipErrorInvalidValue; 17 | pub use bindings::hipError_t_hipErrorMemoryAllocation; 18 | pub use bindings::hipError_t_hipErrorNotInitialized; 19 | pub use bindings::hipError_t_hipErrorNotReady; 20 | pub use bindings::hipError_t_hipErrorOutOfMemory; 21 | pub use bindings::hipError_t_hipSuccess; 22 | 23 | // Device handle and operations 24 | pub use bindings::hipDevice_t; 25 | pub use bindings::hipDeviceProp_tR0600; 26 | pub use bindings::hipDeviceReset; 27 | pub use bindings::hipDeviceSynchronize; 28 | pub use bindings::hipDriverGetVersion; 29 | pub use bindings::hipGetDevice; 30 | pub use bindings::hipGetDeviceCount; 31 | pub use bindings::hipGetDevicePropertiesR0600; 32 | pub use bindings::hipGetErrorName; 33 | pub use bindings::hipGetErrorString; 34 | pub use bindings::hipGetLastError; 35 | pub use bindings::hipInit; 36 | pub use bindings::hipRuntimeGetVersion; 37 | pub use bindings::hipSetDevice; 38 | 39 | // Memory management 40 | pub use bindings::hipFree; 41 | pub use bindings::hipHostFree; 42 | pub use bindings::hipHostGetDevicePointer; 43 | pub use bindings::hipHostMalloc; 44 | pub use bindings::hipMalloc; 45 | pub use bindings::hipMemGetInfo; 46 | pub use bindings::hipMemcpy; 47 | pub use bindings::hipMemcpyAsync; 48 | pub use bindings::hipMemset; 49 | 50 | // Memory copy kinds 51 | pub use bindings::hipMemcpyKind_hipMemcpyDefault; 52 | pub use bindings::hipMemcpyKind_hipMemcpyDeviceToDevice; 53 | pub use bindings::hipMemcpyKind_hipMemcpyDeviceToHost; 54 | pub use bindings::hipMemcpyKind_hipMemcpyHostToDevice; 55 | pub use 
bindings::hipMemcpyKind_hipMemcpyHostToHost; 56 | 57 | // Host malloc flags 58 | pub use bindings::hipHostMallocCoherent; 59 | pub use bindings::hipHostMallocDefault; 60 | pub use bindings::hipHostMallocMapped; 61 | pub use bindings::hipHostMallocNonCoherent; 62 | pub use bindings::hipHostMallocNumaUser; 63 | pub use bindings::hipHostMallocPortable; 64 | pub use bindings::hipHostMallocWriteCombined; 65 | 66 | // Stream operations 67 | pub use bindings::hipDeviceGetStreamPriorityRange; 68 | pub use bindings::hipStream_t; 69 | pub use bindings::hipStreamAddCallback; 70 | pub use bindings::hipStreamCreate; 71 | pub use bindings::hipStreamCreateWithFlags; 72 | pub use bindings::hipStreamCreateWithPriority; 73 | pub use bindings::hipStreamDestroy; 74 | pub use bindings::hipStreamGetDevice; 75 | pub use bindings::hipStreamGetFlags; 76 | pub use bindings::hipStreamGetPriority; 77 | pub use bindings::hipStreamQuery; 78 | pub use bindings::hipStreamSynchronize; 79 | pub use bindings::hipStreamWaitEvent; 80 | 81 | // Event operations 82 | pub use bindings::hipEvent_t; 83 | pub use bindings::hipEventCreate; 84 | pub use bindings::hipEventCreateWithFlags; 85 | pub use bindings::hipEventDestroy; 86 | pub use bindings::hipEventElapsedTime; 87 | pub use bindings::hipEventQuery; 88 | pub use bindings::hipEventRecord; 89 | pub use bindings::hipEventSynchronize; 90 | 91 | // Kernel launching 92 | pub use bindings::dim3; 93 | pub use bindings::hipFunction_t; 94 | pub use bindings::hipLaunchKernel; 95 | pub use bindings::hipModuleGetFunction; 96 | pub use bindings::hipModuleLaunchKernel; 97 | 98 | // Texture and surface references 99 | pub use bindings::hipCreateSurfaceObject; 100 | pub use bindings::hipCreateTextureObject; 101 | pub use bindings::hipDestroySurfaceObject; 102 | pub use bindings::hipDestroyTextureObject; 103 | pub use bindings::hipSurfaceObject_t; 104 | pub use bindings::hipTextureObject_t; 105 | 106 | pub use bindings::hipJitOption; 107 | pub use 
bindings::hipModule_t; 108 | pub use bindings::hipModuleGetGlobal; 109 | pub use bindings::hipModuleLoad; 110 | pub use bindings::hipModuleLoadData; 111 | pub use bindings::hipModuleLoadDataEx; 112 | pub use bindings::hipModuleUnload; 113 | 114 | // Other useful constants and types as needed for your implementation 115 | // Add more imports as required by your wrapper implementation 116 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/src/iris.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,species 2 | 5.1,3.5,1.4,0.2,setosa 3 | 4.9,3.0,1.4,0.2,setosa 4 | 4.7,3.2,1.3,0.2,setosa 5 | 4.6,3.1,1.5,0.2,setosa 6 | 5.0,3.6,1.4,0.2,setosa 7 | 5.4,3.9,1.7,0.4,setosa 8 | 4.6,3.4,1.4,0.3,setosa 9 | 5.0,3.4,1.5,0.2,setosa 10 | 4.4,2.9,1.4,0.2,setosa 11 | 4.9,3.1,1.5,0.1,setosa 12 | 5.4,3.7,1.5,0.2,setosa 13 | 4.8,3.4,1.6,0.2,setosa 14 | 4.8,3.0,1.4,0.1,setosa 15 | 4.3,3.0,1.1,0.1,setosa 16 | 5.8,4.0,1.2,0.2,setosa 17 | 5.7,4.4,1.5,0.4,setosa 18 | 5.4,3.9,1.3,0.4,setosa 19 | 5.1,3.5,1.4,0.3,setosa 20 | 5.7,3.8,1.7,0.3,setosa 21 | 5.1,3.8,1.5,0.3,setosa 22 | 5.4,3.4,1.7,0.2,setosa 23 | 5.1,3.7,1.5,0.4,setosa 24 | 4.6,3.6,1.0,0.2,setosa 25 | 5.1,3.3,1.7,0.5,setosa 26 | 4.8,3.4,1.9,0.2,setosa 27 | 5.0,3.0,1.6,0.2,setosa 28 | 5.0,3.4,1.6,0.4,setosa 29 | 5.2,3.5,1.5,0.2,setosa 30 | 5.2,3.4,1.4,0.2,setosa 31 | 4.7,3.2,1.6,0.2,setosa 32 | 4.8,3.1,1.6,0.2,setosa 33 | 5.4,3.4,1.5,0.4,setosa 34 | 5.2,4.1,1.5,0.1,setosa 35 | 5.5,4.2,1.4,0.2,setosa 36 | 4.9,3.1,1.5,0.1,setosa 37 | 5.0,3.2,1.2,0.2,setosa 38 | 5.5,3.5,1.3,0.2,setosa 39 | 4.9,3.1,1.5,0.1,setosa 40 | 4.4,3.0,1.3,0.2,setosa 41 | 5.1,3.4,1.5,0.2,setosa 42 | 5.0,3.5,1.3,0.3,setosa 43 | 4.5,2.3,1.3,0.3,setosa 44 | 4.4,3.2,1.3,0.2,setosa 45 | 5.0,3.5,1.6,0.6,setosa 46 | 5.1,3.8,1.9,0.4,setosa 47 | 4.8,3.0,1.4,0.3,setosa 48 | 5.1,3.8,1.6,0.2,setosa 49 | 4.6,3.2,1.4,0.2,setosa 
50 | 5.3,3.7,1.5,0.2,setosa 51 | 5.0,3.3,1.4,0.2,setosa 52 | 7.0,3.2,4.7,1.4,versicolor 53 | 6.4,3.2,4.5,1.5,versicolor 54 | 6.9,3.1,4.9,1.5,versicolor 55 | 5.5,2.3,4.0,1.3,versicolor 56 | 6.5,2.8,4.6,1.5,versicolor 57 | 5.7,2.8,4.5,1.3,versicolor 58 | 6.3,3.3,4.7,1.6,versicolor 59 | 4.9,2.4,3.3,1.0,versicolor 60 | 6.6,2.9,4.6,1.3,versicolor 61 | 5.2,2.7,3.9,1.4,versicolor 62 | 5.0,2.0,3.5,1.0,versicolor 63 | 5.9,3.0,4.2,1.5,versicolor 64 | 6.0,2.2,4.0,1.0,versicolor 65 | 6.1,2.9,4.7,1.4,versicolor 66 | 5.6,2.9,3.6,1.3,versicolor 67 | 6.7,3.1,4.4,1.4,versicolor 68 | 5.6,3.0,4.5,1.5,versicolor 69 | 5.8,2.7,4.1,1.0,versicolor 70 | 6.2,2.2,4.5,1.5,versicolor 71 | 5.6,2.5,3.9,1.1,versicolor 72 | 5.9,3.2,4.8,1.8,versicolor 73 | 6.1,2.8,4.0,1.3,versicolor 74 | 6.3,2.5,4.9,1.5,versicolor 75 | 6.1,2.8,4.7,1.2,versicolor 76 | 6.4,2.9,4.3,1.3,versicolor 77 | 6.6,3.0,4.4,1.4,versicolor 78 | 6.8,2.8,4.8,1.4,versicolor 79 | 6.7,3.0,5.0,1.7,versicolor 80 | 6.0,2.9,4.5,1.5,versicolor 81 | 5.7,2.6,3.5,1.0,versicolor 82 | 5.5,2.4,3.8,1.1,versicolor 83 | 5.5,2.4,3.7,1.0,versicolor 84 | 5.8,2.7,3.9,1.2,versicolor 85 | 6.0,2.7,5.1,1.6,versicolor 86 | 5.4,3.0,4.5,1.5,versicolor 87 | 6.0,3.4,4.5,1.6,versicolor 88 | 6.7,3.1,4.7,1.5,versicolor 89 | 6.3,2.3,4.4,1.3,versicolor 90 | 5.6,3.0,4.1,1.3,versicolor 91 | 5.5,2.5,4.0,1.3,versicolor 92 | 5.5,2.6,4.4,1.2,versicolor 93 | 6.1,3.0,4.6,1.4,versicolor 94 | 5.8,2.6,4.0,1.2,versicolor 95 | 5.0,2.3,3.3,1.0,versicolor 96 | 5.6,2.7,4.2,1.3,versicolor 97 | 5.7,3.0,4.2,1.2,versicolor 98 | 5.7,2.9,4.2,1.3,versicolor 99 | 6.2,2.9,4.3,1.3,versicolor 100 | 5.1,2.5,3.0,1.1,versicolor 101 | 5.7,2.8,4.1,1.3,versicolor 102 | 6.3,3.3,6.0,2.5,virginica 103 | 5.8,2.7,5.1,1.9,virginica 104 | 7.1,3.0,5.9,2.1,virginica 105 | 6.3,2.9,5.6,1.8,virginica 106 | 6.5,3.0,5.8,2.2,virginica 107 | 7.6,3.0,6.6,2.1,virginica 108 | 4.9,2.5,4.5,1.7,virginica 109 | 7.3,2.9,6.3,1.8,virginica 110 | 6.7,2.5,5.8,1.8,virginica 111 | 7.2,3.6,6.1,2.5,virginica 112 | 
6.5,3.2,5.1,2.0,virginica 113 | 6.4,2.7,5.3,1.9,virginica 114 | 6.8,3.0,5.5,2.1,virginica 115 | 5.7,2.5,5.0,2.0,virginica 116 | 5.8,2.8,5.1,2.4,virginica 117 | 6.4,3.2,5.3,2.3,virginica 118 | 6.5,3.0,5.5,1.8,virginica 119 | 7.7,3.8,6.7,2.2,virginica 120 | 7.7,2.6,6.9,2.3,virginica 121 | 6.0,2.2,5.0,1.5,virginica 122 | 6.9,3.2,5.7,2.3,virginica 123 | 5.6,2.8,4.9,2.0,virginica 124 | 7.7,2.8,6.7,2.0,virginica 125 | 6.3,2.7,4.9,1.8,virginica 126 | 6.7,3.3,5.7,2.1,virginica 127 | 7.2,3.2,6.0,1.8,virginica 128 | 6.2,2.8,4.8,1.8,virginica 129 | 6.1,3.0,4.9,1.8,virginica 130 | 6.4,2.8,5.6,2.1,virginica 131 | 7.2,3.0,5.8,1.6,virginica 132 | 7.4,2.8,6.1,1.9,virginica 133 | 7.9,3.8,6.4,2.0,virginica 134 | 6.4,2.8,5.6,2.2,virginica 135 | 6.3,2.8,5.1,1.5,virginica 136 | 6.1,2.6,5.6,1.4,virginica 137 | 7.7,3.0,6.1,2.3,virginica 138 | 6.3,3.4,5.6,2.4,virginica 139 | 6.4,3.1,5.5,1.8,virginica 140 | 6.0,3.0,4.8,1.8,virginica 141 | 6.9,3.1,5.4,2.1,virginica 142 | 6.7,3.1,5.6,2.4,virginica 143 | 6.9,3.1,5.1,2.3,virginica 144 | 5.8,2.7,5.1,1.9,virginica 145 | 6.8,3.2,5.9,2.3,virginica 146 | 6.7,3.3,5.7,2.5,virginica 147 | 6.7,3.0,5.2,2.3,virginica 148 | 6.3,2.5,5.0,1.9,virginica 149 | 6.5,3.0,5.2,2.0,virginica 150 | 6.2,3.4,5.4,2.3,virginica 151 | 5.9,3.0,5.1,1.8,virginica -------------------------------------------------------------------------------- /src/rocrand/utils.rs: -------------------------------------------------------------------------------- 1 | // src/rocrand/utils.rs 2 | // 3 | // Utility functions for easier use of the rocrand library 4 | 5 | use crate::error::Result; 6 | use crate::hip::DeviceMemory; 7 | use crate::rocrand::{ 8 | Generator, LogNormal, Normal, Poisson, PseudoRng, QuasiRng, Uniform, rng_type, 9 | }; // Using our unified error type 10 | 11 | macro_rules! generate_uniform_rand_func { 12 | ($fn_name: ident, $data_type:ty, $generato_fn:ident, $rng_type:ident) => { 13 | paste::paste! 
{ 14 | #[doc = "Generate random " $data_type " values on device"] 15 | pub fn $fn_name( 16 | count: usize, 17 | seed: Option, 18 | ) -> Result> { 19 | // Create a generator 20 | let mut generator = PseudoRng::new(rng_type::$rng_type)?; 21 | // Set seed if provided 22 | if let Some(seed_value) = seed { 23 | generator.set_seed(seed_value)?; 24 | } 25 | // Initialize the generator 26 | generator.initialize()?; 27 | // Allocate device memory 28 | let mut device_output = DeviceMemory::<$data_type>::new(count)?; 29 | 30 | // Generate the random numbers 31 | generator.$generato_fn(&mut device_output)?; 32 | 33 | Ok(device_output) 34 | } 35 | } 36 | }; 37 | } 38 | 39 | generate_uniform_rand_func!(generate_uniform_f32, f32, generate_uniform, XORWOW); 40 | generate_uniform_rand_func!(generate_uniform_f64, f64, generate_uniform_double, XORWOW); 41 | generate_uniform_rand_func!(generate_u32, u32, generate_u32, XORWOW); 42 | 43 | macro_rules! generate_normal_rand_func { 44 | ($fn_name: ident, $data_type:ty, $rng_type:ident, $dist:expr) => { 45 | paste::paste! 
{ 46 | #[doc = "Generate normally distributed random " $data_type " values with specified mean and standard deviation"] 47 | pub fn $fn_name( 48 | count: usize, 49 | mean: f32, 50 | stddev: f32, 51 | seed: Option, 52 | ) -> Result> { 53 | // Create a generator 54 | let mut generator = PseudoRng::new(rng_type::$rng_type)?; 55 | 56 | // Set seed if provided 57 | if let Some(seed_value) = seed { 58 | generator.set_seed(seed_value)?; 59 | } 60 | 61 | // Initialize the generator 62 | generator.initialize()?; 63 | 64 | // Create a normal distribution 65 | let dist = $dist(mean, stddev); 66 | 67 | // Allocate device memory 68 | let mut device_output = DeviceMemory::::new(count)?; 69 | 70 | // Generate the random numbers 71 | dist.generate(&mut generator, &mut device_output)?; 72 | 73 | Ok(device_output) 74 | } 75 | } 76 | }; 77 | } 78 | 79 | generate_normal_rand_func!(generate_normal_f32, f32, PHILOX4_32_10, Normal::new); 80 | generate_normal_rand_func!(generate_log_normal_f32, f32, PHILOX4_32_10, LogNormal::new); 81 | 82 | /// Generate Poisson-distributed random u32 values with specified lambda 83 | pub fn generate_poisson(count: usize, lambda: f64, seed: Option) -> Result> { 84 | // Create a generator 85 | let mut generator = PseudoRng::new(rng_type::MTGP32)?; 86 | 87 | // Set seed if provided 88 | if let Some(seed_value) = seed { 89 | generator.set_seed(seed_value)?; 90 | } 91 | 92 | // Initialize the generator 93 | generator.initialize()?; 94 | 95 | // Create a poisson distribution 96 | let poisson_dist = Poisson::new(lambda); 97 | 98 | // Allocate device memory 99 | let mut device_output = DeviceMemory::::new(count)?; 100 | 101 | // Generate the random numbers 102 | poisson_dist.generate(&mut generator, &mut device_output)?; 103 | 104 | Ok(device_output) 105 | } 106 | 107 | /// Generate quasirandom sequence of f32 values with specified dimensions 108 | pub fn generate_quasi_f32(count: usize, dimensions: u32) -> Result> { 109 | // Create a quasi-random generator 110 | 
let mut generator = QuasiRng::new(rng_type::SOBOL32)?; 111 | 112 | // Set dimensions 113 | generator.set_dimensions(dimensions)?; 114 | 115 | // Initialize the generator 116 | generator.initialize()?; 117 | 118 | // Allocate device memory 119 | let mut device_output = DeviceMemory::::new(count)?; 120 | 121 | // Generate the random numbers 122 | Uniform::generate_quasi(&mut generator, &mut device_output)?; 123 | 124 | Ok(device_output) 125 | } -------------------------------------------------------------------------------- /src/rocfft/ffi.rs: -------------------------------------------------------------------------------- 1 | // FFI module for rocFFT 2 | // This file re-exports the necessary symbols from the auto-generated bindings 3 | 4 | // Import the raw bindings from the auto-generated module 5 | use crate::rocfft::bindings; 6 | 7 | // Re-export the necessary types, constants, and functions 8 | 9 | // Types 10 | pub use bindings::{ 11 | rocfft_array_type as rocfft_array_type_t_alias, 12 | rocfft_brick, 13 | 14 | rocfft_comm_type as rocfft_comm_type_t_alias, 15 | rocfft_execution_info, 16 | rocfft_field, 17 | // Handle types 18 | rocfft_plan, 19 | rocfft_plan_description, 20 | rocfft_precision as rocfft_precision_t_alias, 21 | rocfft_result_placement as rocfft_result_placement_t_alias, 22 | // Complex types 23 | rocfft_status as rocfft_status_t_alias, 24 | // Status type 25 | rocfft_status_e as rocfft_status_t, 26 | 27 | rocfft_transform_type as rocfft_transform_type_t_alias, 28 | }; 29 | 30 | // Status constants 31 | pub use bindings::{ 32 | rocfft_status_e_rocfft_status_failure as STATUS_FAILURE, 33 | rocfft_status_e_rocfft_status_invalid_arg_value as STATUS_INVALID_ARG_VALUE, 34 | rocfft_status_e_rocfft_status_invalid_array_type as STATUS_INVALID_ARRAY_TYPE, 35 | rocfft_status_e_rocfft_status_invalid_dimensions as STATUS_INVALID_DIMENSIONS, 36 | rocfft_status_e_rocfft_status_invalid_distance as STATUS_INVALID_DISTANCE, 37 | 
rocfft_status_e_rocfft_status_invalid_offset as STATUS_INVALID_OFFSET, 38 | rocfft_status_e_rocfft_status_invalid_strides as STATUS_INVALID_STRIDES, 39 | rocfft_status_e_rocfft_status_invalid_work_buffer as STATUS_INVALID_WORK_BUFFER, 40 | rocfft_status_e_rocfft_status_success as STATUS_SUCCESS, 41 | }; 42 | 43 | // Transform type constants 44 | pub use bindings::{ 45 | rocfft_transform_type_e_rocfft_transform_type_complex_forward as TRANSFORM_TYPE_COMPLEX_FORWARD, 46 | rocfft_transform_type_e_rocfft_transform_type_complex_inverse as TRANSFORM_TYPE_COMPLEX_INVERSE, 47 | rocfft_transform_type_e_rocfft_transform_type_real_forward as TRANSFORM_TYPE_REAL_FORWARD, 48 | rocfft_transform_type_e_rocfft_transform_type_real_inverse as TRANSFORM_TYPE_REAL_INVERSE, 49 | }; 50 | 51 | // Precision constants 52 | pub use bindings::{ 53 | rocfft_precision_e_rocfft_precision_double as PRECISION_DOUBLE, 54 | rocfft_precision_e_rocfft_precision_half as PRECISION_HALF, 55 | rocfft_precision_e_rocfft_precision_single as PRECISION_SINGLE, 56 | }; 57 | 58 | // Placement constants 59 | pub use bindings::{ 60 | rocfft_result_placement_e_rocfft_placement_inplace as PLACEMENT_INPLACE, 61 | rocfft_result_placement_e_rocfft_placement_notinplace as PLACEMENT_NOTINPLACE, 62 | }; 63 | 64 | // Array type constants 65 | pub use bindings::{ 66 | rocfft_array_type_e_rocfft_array_type_complex_interleaved as ARRAY_TYPE_COMPLEX_INTERLEAVED, 67 | rocfft_array_type_e_rocfft_array_type_complex_planar as ARRAY_TYPE_COMPLEX_PLANAR, 68 | rocfft_array_type_e_rocfft_array_type_hermitian_interleaved as ARRAY_TYPE_HERMITIAN_INTERLEAVED, 69 | rocfft_array_type_e_rocfft_array_type_hermitian_planar as ARRAY_TYPE_HERMITIAN_PLANAR, 70 | rocfft_array_type_e_rocfft_array_type_real as ARRAY_TYPE_REAL, 71 | rocfft_array_type_e_rocfft_array_type_unset as ARRAY_TYPE_UNSET, 72 | }; 73 | 74 | // Communicator type constants 75 | pub use bindings::{ 76 | rocfft_comm_type_e_rocfft_comm_mpi as COMM_TYPE_MPI, 77 | 
rocfft_comm_type_e_rocfft_comm_none as COMM_TYPE_NONE, 78 | }; 79 | 80 | // Function re-exports 81 | 82 | // Library setup/cleanup 83 | pub use bindings::{rocfft_cleanup, rocfft_get_version_string, rocfft_setup}; 84 | 85 | // Plan creation/destruction 86 | pub use bindings::{ 87 | rocfft_plan_create, rocfft_plan_destroy, rocfft_plan_get_print, 88 | rocfft_plan_get_work_buffer_size, 89 | }; 90 | 91 | // Plan description 92 | pub use bindings::{ 93 | rocfft_plan_description_create, rocfft_plan_description_destroy, 94 | rocfft_plan_description_set_comm, rocfft_plan_description_set_data_layout, 95 | rocfft_plan_description_set_scale_factor, 96 | }; 97 | 98 | // Execution 99 | pub use bindings::{ 100 | rocfft_execute, rocfft_execution_info_create, rocfft_execution_info_destroy, 101 | rocfft_execution_info_set_load_callback, rocfft_execution_info_set_store_callback, 102 | rocfft_execution_info_set_stream, rocfft_execution_info_set_work_buffer, 103 | }; 104 | 105 | // Field/Brick (distributed computation) 106 | pub use bindings::{ 107 | rocfft_brick_create, rocfft_brick_destroy, rocfft_field_add_brick, rocfft_field_create, 108 | rocfft_field_destroy, rocfft_plan_description_add_infield, 109 | rocfft_plan_description_add_outfield, 110 | }; 111 | 112 | // Cache management 113 | pub use bindings::{rocfft_cache_buffer_free, rocfft_cache_deserialize, rocfft_cache_serialize}; 114 | -------------------------------------------------------------------------------- /src/hip/event.rs: -------------------------------------------------------------------------------- 1 | // src/hip/event.rs 2 | 3 | use crate::hip::Stream; 4 | use crate::hip::error::{Error, Result}; 5 | use crate::hip::ffi; 6 | use std::ptr; 7 | 8 | /// Safe wrapper for HIP events 9 | pub struct Event { 10 | event: ffi::hipEvent_t, 11 | } 12 | 13 | impl Event { 14 | /// Create a new event with default flags 15 | pub fn new() -> Result { 16 | let mut event = ptr::null_mut(); 17 | let error = unsafe { 
ffi::hipEventCreate(&mut event) }; 18 | 19 | if error != ffi::hipError_t_hipSuccess { 20 | return Err(Error::new(error)); 21 | } 22 | 23 | Ok(Self { event }) 24 | } 25 | 26 | /// Create a new event with specific flags 27 | pub fn with_flags(flags: u32) -> Result { 28 | let mut event = ptr::null_mut(); 29 | let error = unsafe { ffi::hipEventCreateWithFlags(&mut event, flags) }; 30 | 31 | if error != ffi::hipError_t_hipSuccess { 32 | return Err(Error::new(error)); 33 | } 34 | 35 | Ok(Self { event }) 36 | } 37 | 38 | /// Record an event in a stream 39 | pub fn record(&self, stream: &Stream) -> Result<()> { 40 | let error = unsafe { ffi::hipEventRecord(self.event, stream.as_raw()) }; 41 | 42 | if error != ffi::hipError_t_hipSuccess { 43 | return Err(Error::new(error)); 44 | } 45 | 46 | Ok(()) 47 | } 48 | 49 | /// Synchronize on the event (wait for it to complete) 50 | pub fn synchronize(&self) -> Result<()> { 51 | let error = unsafe { ffi::hipEventSynchronize(self.event) }; 52 | 53 | if error != ffi::hipError_t_hipSuccess { 54 | return Err(Error::new(error)); 55 | } 56 | 57 | Ok(()) 58 | } 59 | 60 | /// Query if the event has completed 61 | pub fn query(&self) -> Result<()> { 62 | let error = unsafe { ffi::hipEventQuery(self.event) }; 63 | 64 | if error == ffi::hipError_t_hipSuccess { 65 | Ok(()) 66 | } else if error == ffi::hipError_t_hipErrorNotReady { 67 | // Not ready means the event is still pending rather than failed, but it is still returned as Err wrapping hipErrorNotReady 68 | Err(Error::new(error)) 69 | } else { 70 | Err(Error::new(error)) 71 | } 72 | } 73 | 74 | /// Calculate elapsed time between this event and another in milliseconds 75 | pub fn elapsed_time(&self, end: &Event) -> Result { 76 | let mut time = 0.0; 77 | let error = unsafe { ffi::hipEventElapsedTime(&mut time, self.event, end.event) }; 78 | 79 | if error != ffi::hipError_t_hipSuccess { 80 | return Err(Error::new(error)); 81 | } 82 | 83 | Ok(time) 84 | } 85 | 86 | /// Get the raw event handle 87 | pub fn as_raw(&self) -> ffi::hipEvent_t { 88 | self.event 89 | } 90 | }
91 | 92 | impl Drop for Event { 93 | fn drop(&mut self) { 94 | if !self.event.is_null() { 95 | unsafe { 96 | let _ = ffi::hipEventDestroy(self.event); 97 | // We cannot handle errors in drop, so just ignore the result 98 | }; 99 | self.event = ptr::null_mut(); 100 | } 101 | } 102 | } 103 | 104 | /// Constants for event creation flags 105 | pub mod event_flags { 106 | /// Default event creation flag 107 | pub const DEFAULT: u32 = 0; 108 | 109 | /// Event uses blocking synchronization 110 | pub const BLOCKING_SYNC: u32 = 1; 111 | 112 | /// Event will not record timing data 113 | pub const DISABLE_TIMING: u32 = 2; 114 | 115 | /// Event is suitable for interprocess use 116 | pub const INTERPROCESS: u32 = 4; 117 | } 118 | 119 | /// Helper struct to measure elapsed time 120 | pub struct Timer { 121 | start: Event, 122 | stop: Event, 123 | } 124 | 125 | impl Timer { 126 | /// Create a new timer 127 | pub fn new() -> Result { 128 | // Both events are created with default flags (DISABLE_TIMING not set) so that they record timing data 129 | let start = Event::new()?; 130 | let stop = Event::new()?; 131 | 132 | Ok(Self { start, stop }) 133 | } 134 | 135 | /// Start the timer by recording the start event 136 | pub fn start(&self, stream: &Stream) -> Result<()> { 137 | self.start.record(stream) 138 | } 139 | 140 | /// Stop the timer by recording the stop event 141 | pub fn stop(&self, stream: &Stream) -> Result<()> { 142 | self.stop.record(stream) 143 | } 144 | 145 | /// Get the elapsed time in milliseconds 146 | /// Note: This will synchronize the stop event if it has not completed yet 147 | pub fn elapsed_time(&self) -> Result { 148 | // Make sure the stop event has completed 149 | self.stop.synchronize()?; 150 | 151 | // Calculate the elapsed time 152 | self.start.elapsed_time(&self.stop) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/miopen/examples/basic/src/main.rs: -------------------------------------------------------------------------------- 1 | use
rocm_rs::{ 2 | error::Result, 3 | hip::DeviceMemory, 4 | miopen::{self, ActivationDescriptor, ActivationMode, DataType, TensorDescriptor}, 5 | }; 6 | 7 | fn main() -> Result<()> { 8 | // ----------------------- 9 | // 1. MIOpen handle 10 | // ----------------------- 11 | let miopen = miopen::Handle::new()?; 12 | 13 | // ----------------------- 14 | // 2. Training data 15 | // ----------------------- 16 | let input_len = 8usize; 17 | 18 | let x_host = vec![-10., -2., -1., 0., 1., 2., 3., 10.]; 19 | let y_target = vec![0.0, 0., 0., 0.5, 1., 1., 1., 1.]; 20 | let mut y_pred = vec![0f32; input_len]; 21 | let mut dl_dy = vec![0f32; input_len]; 22 | 23 | // ----------------------- 24 | // 3. Allocate GPU buffers 25 | // ----------------------- 26 | let mut d_linear = DeviceMemory::::new(input_len)?; // stores wx+b 27 | let mut d_y = DeviceMemory::::new(input_len)?; // activation (logistic) output 28 | let mut d_dy = DeviceMemory::::new(input_len)?; // gradient input 29 | let mut d_dx = DeviceMemory::::new(input_len)?; // gradient output 30 | 31 | // ----------------------- 32 | // 4. Tensor descriptor 33 | // ----------------------- 34 | let tensor = TensorDescriptor::new_4d(DataType::MiopenFloat, 1, 1, 1, input_len as i32)?; 35 | 36 | // ----------------------- 37 | // 5. Logistic (sigmoid) activation 38 | // ----------------------- 39 | let activation = 40 | ActivationDescriptor::with_mode(ActivationMode::MiopenActivationLOGISTIC, 0.0, 0.0, 0.0)?; 41 | 42 | let alpha = 1f32; 43 | let beta = 0f32; 44 | 45 | // ----------------------- 46 | // 6. Parameters of our 1-neuron model 47 | // ----------------------- 48 | let mut w: f32 = 0.1; 49 | let mut b: f32 = 0.0; 50 | let lr: f32 = 0.01; 51 | 52 | // ----------------------- 53 | // 7.
Training loop 54 | // ----------------------- 55 | for epoch in 0..200 { 56 | // ---- forward linear: wx + b ---- 57 | for i in 0..input_len { 58 | y_pred[i] = w * x_host[i] + b; 59 | } 60 | 61 | d_linear.copy_from_host(&y_pred)?; 62 | 63 | // ---- MIOpen forward: sigmoid(wx+b) ---- 64 | 65 | activation.forward( 66 | &miopen, &alpha, &tensor, &d_linear, &beta, &tensor, &mut d_y, 67 | )?; 68 | 69 | // bring prediction back 70 | d_y.copy_to_host(&mut y_pred)?; 71 | 72 | // ---- compute dL/dy = 2*(y_pred - y_target) ---- 73 | let mut loss = 0.0; 74 | for i in 0..input_len { 75 | let err = y_pred[i] - y_target[i]; 76 | loss += err * err; 77 | dl_dy[i] = 2.0 * err; 78 | } 79 | 80 | d_dy.copy_from_host(&dl_dy)?; 81 | 82 | // ---- MIOpen backward: dL/dx = sigmoid'(x)*dL/dy ---- 83 | activation.backward( 84 | &miopen, &alpha, &tensor, &d_y, // y from forward 85 | &tensor, &d_dy, // dL/dy 86 | &tensor, &d_linear, // x before activation 87 | &beta, &tensor, &mut d_dx, // output: dL/dx 88 | )?; 89 | 90 | // get dL/dx back 91 | let mut dl_dx = vec![0f32; input_len]; 92 | d_dx.copy_to_host(&mut dl_dx)?; 93 | 94 | // ---- compute gradients for w and b ---- 95 | let grad_w: f32 = dl_dx.iter().zip(x_host.iter()).map(|(dx, x)| dx * x).sum(); 96 | 97 | let grad_b: f32 = dl_dx.iter().sum(); 98 | 99 | // ---- gradient descent ---- 100 | w -= lr * grad_w; 101 | b -= lr * grad_b; 102 | 103 | if epoch % 10 == 0 { 104 | println!( 105 | "epoch {:3} loss={:.4} w={:.3} b={:.3}", 106 | epoch, loss, w, b 107 | ); 108 | } 109 | } 110 | 111 | println!("\n=== INFERENCE PHASE ==="); 112 | 113 | let test_inputs = vec![-2.0, -1.0, 0., 1.0, 5.0]; 114 | let mut test_linear = vec![0f32; test_inputs.len()]; 115 | let mut test_output = vec![0f32; test_inputs.len()]; 116 | 117 | // ----------------------- 118 | // 8. Linear forward 119 | // ----------------------- 120 | for i in 0..test_inputs.len() { 121 | test_linear[i] = w * test_inputs[i] + b; 122 | } 123 | 124 | // ----------------------- 125 | // 9.
Copy to GPU 126 | // ----------------------- 127 | d_linear.copy_from_host(&test_linear)?; 128 | 129 | // ----------------------- 130 | // 10. Inference 131 | // ----------------------- 132 | activation.forward( 133 | &miopen, &alpha, &tensor, &d_linear, &beta, &tensor, &mut d_y, 134 | )?; 135 | 136 | d_y.copy_to_host(&mut test_output)?; 137 | 138 | // ----------------------- 139 | // 11. Results 140 | // ----------------------- 141 | for (x, y) in test_inputs.iter().zip(test_output.iter()) { 142 | println!("input = {:>5} output = {}", x, y); 143 | } 144 | 145 | Ok(()) 146 | } 147 | -------------------------------------------------------------------------------- /src/rocsparse/descriptor.rs: -------------------------------------------------------------------------------- 1 | //! Matrix descriptor types and enums 2 | 3 | use crate::rocsparse::error::*; 4 | use crate::rocsparse::{ 5 | rocsparse_create_mat_descr, rocsparse_destroy_mat_descr, rocsparse_direction_, 6 | rocsparse_direction__rocsparse_direction_column, rocsparse_direction__rocsparse_direction_row, 7 | rocsparse_get_mat_index_base, rocsparse_get_mat_type, rocsparse_index_base_, 8 | rocsparse_index_base__rocsparse_index_base_one, 9 | rocsparse_index_base__rocsparse_index_base_zero, rocsparse_mat_descr, rocsparse_matrix_type_, 10 | rocsparse_matrix_type__rocsparse_matrix_type_general, 11 | rocsparse_matrix_type__rocsparse_matrix_type_hermitian, 12 | rocsparse_matrix_type__rocsparse_matrix_type_symmetric, 13 | rocsparse_matrix_type__rocsparse_matrix_type_triangular, rocsparse_set_mat_index_base, 14 | rocsparse_set_mat_type, 15 | }; 16 | use std::mem::MaybeUninit; 17 | 18 | /// Matrix storage format 19 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 20 | pub enum MatrixType { 21 | /// General matrix 22 | General, 23 | /// Symmetric matrix 24 | Symmetric, 25 | /// Hermitian matrix 26 | Hermitian, 27 | /// Triangular matrix 28 | Triangular, 29 | } 30 | 31 | impl From for rocsparse_matrix_type_ { 32 | fn from(ty: 
MatrixType) -> Self { 33 | match ty { 34 | MatrixType::General => rocsparse_matrix_type__rocsparse_matrix_type_general, 35 | MatrixType::Symmetric => rocsparse_matrix_type__rocsparse_matrix_type_symmetric, 36 | MatrixType::Hermitian => rocsparse_matrix_type__rocsparse_matrix_type_hermitian, 37 | MatrixType::Triangular => rocsparse_matrix_type__rocsparse_matrix_type_triangular, 38 | } 39 | } 40 | } 41 | 42 | /// Index base for sparse matrices 43 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 44 | pub enum IndexBase { 45 | /// Zero-based indexing 46 | Zero, 47 | /// One-based indexing 48 | One, 49 | } 50 | 51 | impl From for rocsparse_index_base_ { 52 | fn from(base: IndexBase) -> Self { 53 | match base { 54 | IndexBase::Zero => rocsparse_index_base__rocsparse_index_base_zero, 55 | IndexBase::One => rocsparse_index_base__rocsparse_index_base_one, 56 | } 57 | } 58 | } 59 | 60 | /// Direction for block storage formats 61 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 62 | pub enum Direction { 63 | /// Parse matrix by rows 64 | Row, 65 | /// Parse matrix by columns 66 | Column, 67 | } 68 | 69 | impl From for rocsparse_direction_ { 70 | fn from(dir: Direction) -> Self { 71 | match dir { 72 | Direction::Row => rocsparse_direction__rocsparse_direction_row, 73 | Direction::Column => rocsparse_direction__rocsparse_direction_column, 74 | } 75 | } 76 | } 77 | 78 | /// Matrix descriptor for sparse matrices 79 | pub struct MatrixDescriptor { 80 | pub(crate) inner: rocsparse_mat_descr, 81 | } 82 | 83 | impl MatrixDescriptor { 84 | /// Create a new matrix descriptor 85 | pub fn new() -> Result { 86 | let mut descr = MaybeUninit::uninit(); 87 | let status = unsafe { rocsparse_create_mat_descr(descr.as_mut_ptr()) }; 88 | status_to_result(status)?; 89 | let descr = unsafe { descr.assume_init() }; 90 | Ok(Self { inner: descr }) 91 | } 92 | 93 | /// Set the index base 94 | pub fn set_index_base(&self, base: IndexBase) -> Result<()> { 95 | let status = unsafe { 
rocsparse_set_mat_index_base(self.inner, base.into()) }; 96 | status_to_result(status) 97 | } 98 | 99 | /// Get the index base 100 | pub fn get_index_base(&self) -> IndexBase { 101 | let base = unsafe { rocsparse_get_mat_index_base(self.inner) }; 102 | if base == rocsparse_index_base__rocsparse_index_base_one { 103 | IndexBase::One 104 | } else { 105 | IndexBase::Zero 106 | } 107 | } 108 | 109 | /// Set the matrix type 110 | pub fn set_matrix_type(&self, ty: MatrixType) -> Result<()> { 111 | let status = unsafe { rocsparse_set_mat_type(self.inner, ty.into()) }; 112 | status_to_result(status) 113 | } 114 | 115 | /// Get the matrix type 116 | pub fn get_matrix_type(&self) -> MatrixType { 117 | let ty = unsafe { rocsparse_get_mat_type(self.inner) }; 118 | match ty { 119 | rocsparse_matrix_type__rocsparse_matrix_type_symmetric => MatrixType::Symmetric, 120 | rocsparse_matrix_type__rocsparse_matrix_type_hermitian => MatrixType::Hermitian, 121 | rocsparse_matrix_type__rocsparse_matrix_type_triangular => MatrixType::Triangular, 122 | _ => MatrixType::General, 123 | } 124 | } 125 | } 126 | 127 | impl Drop for MatrixDescriptor { 128 | fn drop(&mut self) { 129 | unsafe { 130 | // Ignore error on drop 131 | let _ = rocsparse_destroy_mat_descr(self.inner); 132 | } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/src/layer.rs: -------------------------------------------------------------------------------- 1 | use std::rc::Rc; 2 | 3 | use rocm_rs::{error::Result, hip::{DeviceMemory, Dim3, Function, Module, kernel::AsKernelArg}, kernel_args, miopen::{self, ActivationDescriptor, ActivationMode, DataType, TensorDescriptor}}; 4 | 5 | 6 | const ALPHA: f32 = 1.0; 7 | const BETA: f32 = 0.0; 8 | 9 | pub struct Layer { 10 | tensor_desc: TensorDescriptor, 11 | activation_desc: ActivationDescriptor, 12 | device_act: DeviceMemory, 13 | device_grad_pre: DeviceMemory, 14 | pub(crate) device_grad_act: 
DeviceMemory, 15 | grad_pre: Vec, 16 | input_grad: Vec, 17 | input_size: usize, 18 | output_size: usize, 19 | device_weights: DeviceMemory, 20 | device_bias: DeviceMemory, 21 | device_output: DeviceMemory, 22 | _module: Rc, 23 | function: Function, 24 | } 25 | 26 | impl Layer { 27 | pub fn new( 28 | output_size: usize, 29 | input_size: usize, 30 | activation_mode: ActivationMode, 31 | module: Rc, 32 | ) -> Result { 33 | let function = module.get_function("linear_transform")?; 34 | Ok(Self { 35 | tensor_desc: TensorDescriptor::new_4d( 36 | DataType::MiopenFloat, 37 | 1, 38 | output_size as i32, 39 | 1, 40 | 1, 41 | )?, 42 | activation_desc: ActivationDescriptor::with_mode(activation_mode, 0.0, 0.0, 0.0)?, 43 | device_act: DeviceMemory::new(output_size)?, 44 | device_grad_pre: DeviceMemory::new(output_size)?, 45 | device_grad_act: DeviceMemory::new(output_size)?, 46 | grad_pre: vec![0.0; output_size], 47 | input_grad: vec![0.0; input_size], 48 | input_size, 49 | output_size, 50 | device_weights: DeviceMemory::new(output_size * input_size)?, 51 | device_bias: DeviceMemory::new(output_size)?, 52 | device_output: DeviceMemory::new(output_size)?, 53 | _module: module, 54 | function, 55 | }) 56 | } 57 | 58 | pub fn input_grad(&self) -> &[f32] { 59 | &self.input_grad 60 | } 61 | 62 | pub fn forward( 63 | &mut self, 64 | handle: &miopen::Handle, 65 | input: &DeviceMemory, 66 | weights: &[f32], 67 | bias: &[f32], 68 | ) -> Result<&DeviceMemory> { 69 | self.device_weights.copy_from_host(weights)?; 70 | self.device_bias.copy_from_host(bias)?; 71 | 72 | let args = kernel_args!( 73 | input, 74 | self.device_weights, 75 | self.device_bias, 76 | self.device_output, 77 | self.input_size, 78 | self.output_size 79 | ); 80 | 81 | self.function.launch( 82 | Dim3::new_1d(self.output_size as u32), 83 | Dim3::new_1d(1), 84 | 0, 85 | None, 86 | args, 87 | )?; 88 | 89 | self.activation_desc.forward( 90 | handle, 91 | &ALPHA, 92 | &self.tensor_desc, 93 | &self.device_output, 94 | &BETA, 95 | 
&self.tensor_desc, 96 | &mut self.device_act, 97 | )?; 98 | 99 | Ok(&self.device_act) 100 | } 101 | 102 | pub fn backward( 103 | &mut self, 104 | handle: &miopen::Handle, 105 | prev_activations: &DeviceMemory, 106 | weights: &mut [f32], 107 | bias: &mut [f32], 108 | learning_rate: f32, 109 | ) -> Result<()> { 110 | self.activation_desc.backward( 111 | handle, 112 | &ALPHA, 113 | &self.tensor_desc, 114 | &self.device_act, 115 | &self.tensor_desc, 116 | &self.device_grad_act, 117 | &self.tensor_desc, 118 | &self.device_output, 119 | &BETA, 120 | &self.tensor_desc, 121 | &mut self.device_grad_pre, 122 | )?; 123 | self.device_grad_pre.copy_to_host(&mut self.grad_pre)?; 124 | 125 | let input_size = self.input_size; 126 | 127 | let prev_activations = { 128 | let mut vec = vec![0.0; input_size]; 129 | prev_activations.copy_to_host(&mut vec)?; 130 | vec 131 | }; 132 | 133 | for (i, grad_in) in self.input_grad.iter_mut().enumerate() { 134 | let mut sum = 0.0; 135 | for (o, &grad_out) in self.grad_pre.iter().enumerate() { 136 | sum += weights[o * input_size + i] * grad_out; 137 | } 138 | *grad_in = sum; 139 | } 140 | 141 | for (o, &grad) in self.grad_pre.iter().enumerate() { 142 | let start = o * self.input_size; 143 | for input_idx in 0..self.input_size { 144 | weights[start + input_idx] -= learning_rate * grad * prev_activations[input_idx]; 145 | } 146 | bias[o] -= learning_rate * grad; 147 | } 148 | 149 | Ok(()) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/miopen/ctc_loss.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/ctc_loss.rs 2 | 3 | use crate::miopen::error::{Error, Result}; 4 | use crate::miopen::ffi; 5 | use crate::miopen::handle::Handle; 6 | use crate::miopen::tensor::TensorDescriptor; 7 | use std::os::raw::c_void; 8 | use std::ptr; 9 | 10 | /// CTC Loss algorithm 11 | pub type CTCLossAlgo = ffi::miopenCTCLossAlgo_t; 12 | 13 | /// Safe wrapper for 
MIOpen CTC Loss descriptor 14 | pub struct CTCLossDescriptor { 15 | desc: ffi::miopenCTCLossDescriptor_t, 16 | } 17 | 18 | impl CTCLossDescriptor { 19 | /// Create a new CTC Loss descriptor 20 | pub fn new() -> Result { 21 | let mut desc = ptr::null_mut(); 22 | let status = unsafe { ffi::miopenCreateCTCLossDescriptor(&mut desc) }; 23 | 24 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 25 | return Err(Error::new(status)); 26 | } 27 | 28 | Ok(Self { desc }) 29 | } 30 | 31 | /// Set the CTC Loss descriptor 32 | pub fn set( 33 | &mut self, 34 | data_type: ffi::miopenDataType_t, 35 | blank_label_id: i32, 36 | apply_softmax_layer: bool, 37 | ) -> Result<()> { 38 | let status = unsafe { 39 | ffi::miopenSetCTCLossDescriptor( 40 | self.desc, 41 | data_type, 42 | blank_label_id, 43 | apply_softmax_layer, 44 | ) 45 | }; 46 | 47 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 48 | return Err(Error::new(status)); 49 | } 50 | 51 | Ok(()) 52 | } 53 | 54 | /// Get the CTC Loss descriptor details 55 | pub fn get(&self) -> Result<(ffi::miopenDataType_t, i32, bool)> { 56 | let mut data_type = 0; 57 | let mut blank_label_id = 0; 58 | let mut apply_softmax_layer = false; 59 | 60 | let status = unsafe { 61 | ffi::miopenGetCTCLossDescriptor( 62 | self.desc, 63 | &mut data_type, 64 | &mut blank_label_id, 65 | &mut apply_softmax_layer, 66 | ) 67 | }; 68 | 69 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 70 | return Err(Error::new(status)); 71 | } 72 | 73 | Ok((data_type, blank_label_id, apply_softmax_layer)) 74 | } 75 | 76 | /// Get the raw descriptor 77 | pub fn as_raw(&self) -> ffi::miopenCTCLossDescriptor_t { 78 | self.desc 79 | } 80 | } 81 | 82 | impl Drop for CTCLossDescriptor { 83 | fn drop(&mut self) { 84 | if !self.desc.is_null() { 85 | unsafe { 86 | let _ = ffi::miopenDestroyCTCLossDescriptor(self.desc); 87 | // We cannot handle errors in drop, so just ignore the result 88 | }; 89 | self.desc = ptr::null_mut(); 90 | } 91 | } 92 | } 93 | 94 | /// Get 
the workspace size required for CTC Loss operations 95 | pub fn get_ctc_loss_workspace_size( 96 | handle: &Handle, 97 | probs_desc: &TensorDescriptor, 98 | gradients_desc: &TensorDescriptor, 99 | labels: &[i32], 100 | label_lengths: &[i32], 101 | input_lengths: &[i32], 102 | algo: CTCLossAlgo, 103 | ctc_loss_desc: &CTCLossDescriptor, 104 | ) -> Result { 105 | let mut workspace_size = 0; 106 | 107 | let status = unsafe { 108 | ffi::miopenGetCTCLossWorkspaceSize( 109 | handle.as_raw(), 110 | probs_desc.as_raw(), 111 | gradients_desc.as_raw(), 112 | labels.as_ptr(), 113 | label_lengths.as_ptr(), 114 | input_lengths.as_ptr(), 115 | algo, 116 | ctc_loss_desc.as_raw(), 117 | &mut workspace_size, 118 | ) 119 | }; 120 | 121 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 122 | return Err(Error::new(status)); 123 | } 124 | 125 | Ok(workspace_size) 126 | } 127 | 128 | /// Execute CTC Loss forward and gradient computation 129 | pub unsafe fn ctc_loss( 130 | handle: &Handle, 131 | probs_desc: &TensorDescriptor, 132 | probs: *const c_void, 133 | labels: &[i32], 134 | label_lengths: &[i32], 135 | input_lengths: &[i32], 136 | losses: *mut c_void, 137 | gradients_desc: &TensorDescriptor, 138 | gradients: *mut c_void, 139 | algo: CTCLossAlgo, 140 | ctc_loss_desc: &CTCLossDescriptor, 141 | workspace: *mut c_void, 142 | workspace_size: usize, 143 | ) -> Result<()> { 144 | let status = unsafe { 145 | ffi::miopenCTCLoss( 146 | handle.as_raw(), 147 | probs_desc.as_raw(), 148 | probs, 149 | labels.as_ptr(), 150 | label_lengths.as_ptr(), 151 | input_lengths.as_ptr(), 152 | losses, 153 | gradients_desc.as_raw(), 154 | gradients, 155 | algo, 156 | ctc_loss_desc.as_raw(), 157 | workspace, 158 | workspace_size, 159 | ) 160 | }; 161 | 162 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 163 | return Err(Error::new(status)); 164 | } 165 | 166 | Ok(()) 167 | } 168 | -------------------------------------------------------------------------------- /src/rocsparse/conversion.rs: 
-------------------------------------------------------------------------------- 1 | //! Matrix format conversion utilities 2 | 3 | use crate::rocsparse::descriptor::{IndexBase, MatrixDescriptor}; 4 | use crate::rocsparse::error::status_to_result; 5 | use crate::rocsparse::error::*; 6 | use crate::rocsparse::handle::Handle; 7 | use crate::rocsparse::{ 8 | rocsparse_action__rocsparse_action_numeric, rocsparse_action__rocsparse_action_symbolic, 9 | rocsparse_create_identity_permutation, rocsparse_csr2csc_buffer_size, rocsparse_csrsort, 10 | rocsparse_csrsort_buffer_size, rocsparse_scsr2csc, 11 | }; 12 | use std::ffi::c_void; 13 | 14 | /// Convert CSR to CSC (Compressed Sparse Column) format 15 | pub fn csr_to_csc( 16 | handle: &Handle, 17 | m: i32, 18 | n: i32, 19 | nnz: i32, 20 | csr_val: &[T], 21 | csr_row_ptr: &[i32], 22 | csr_col_ind: &[i32], 23 | csc_val: &mut [T], 24 | csc_row_ind: &mut [i32], 25 | csc_col_ptr: &mut [i32], 26 | copy_values: bool, 27 | idx_base: IndexBase, 28 | ) -> crate::rocsparse::error::Result<()> { 29 | // Get required buffer size 30 | let mut buffer_size = 0; 31 | let status = unsafe { 32 | rocsparse_csr2csc_buffer_size( 33 | handle.inner, 34 | m, 35 | n, 36 | nnz, 37 | csr_row_ptr.as_ptr(), 38 | csr_col_ind.as_ptr(), 39 | if copy_values { 40 | rocsparse_action__rocsparse_action_numeric 41 | } else { 42 | rocsparse_action__rocsparse_action_symbolic 43 | }, 44 | &mut buffer_size, 45 | ) 46 | }; 47 | status_to_result(status)?; 48 | 49 | // Allocate temporary buffer 50 | let mut temp_buffer = vec![0u8; buffer_size]; 51 | 52 | // Perform conversion based on type 53 | let status = convert_csr_to_csc( 54 | handle, 55 | m, 56 | n, 57 | nnz, 58 | csr_val, 59 | csr_row_ptr, 60 | csr_col_ind, 61 | csc_val, 62 | csc_row_ind, 63 | csc_col_ptr, 64 | copy_values, 65 | idx_base, 66 | temp_buffer.as_mut_ptr() as *mut c_void, 67 | ); 68 | 69 | status 70 | } 71 | 72 | // Implementation for specific types 73 | fn convert_csr_to_csc( 74 | handle: &Handle, 75 
| m: i32, 76 | n: i32, 77 | nnz: i32, 78 | csr_val: &[T], 79 | csr_row_ptr: &[i32], 80 | csr_col_ind: &[i32], 81 | csc_val: &mut [T], 82 | csc_row_ind: &mut [i32], 83 | csc_col_ptr: &mut [i32], 84 | copy_values: bool, 85 | idx_base: IndexBase, 86 | temp_buffer: *mut c_void, 87 | ) -> Result<()> { 88 | // This would need to be implemented for each supported type (f32, f64, complex, etc.) 89 | // For simplicity, I'm showing the f32 case only 90 | 91 | if std::any::TypeId::of::() == std::any::TypeId::of::() { 92 | let status = unsafe { 93 | rocsparse_scsr2csc( 94 | handle.inner, 95 | m, 96 | n, 97 | nnz, 98 | csr_val.as_ptr() as *const f32, 99 | csr_row_ptr.as_ptr(), 100 | csr_col_ind.as_ptr(), 101 | csc_val.as_mut_ptr() as *mut f32, 102 | csc_row_ind.as_mut_ptr(), 103 | csc_col_ptr.as_mut_ptr(), 104 | if copy_values { 105 | rocsparse_action__rocsparse_action_numeric 106 | } else { 107 | rocsparse_action__rocsparse_action_symbolic 108 | }, 109 | idx_base.into(), 110 | temp_buffer, 111 | ) 112 | }; 113 | status_to_result(status) 114 | } else { 115 | Err(Error::NotImplemented) 116 | } 117 | } 118 | 119 | /// Create an identity permutation vector 120 | pub fn create_identity_permutation(handle: &Handle, n: i32, p: &mut [i32]) -> Result<()> { 121 | let status = unsafe { rocsparse_create_identity_permutation(handle.inner, n, p.as_mut_ptr()) }; 122 | status_to_result(status) 123 | } 124 | 125 | /// Sort a sparse CSR matrix 126 | pub fn csr_sort( 127 | handle: &Handle, 128 | m: i32, 129 | n: i32, 130 | nnz: i32, 131 | descr: &MatrixDescriptor, 132 | csr_row_ptr: &[i32], 133 | csr_col_ind: &mut [i32], 134 | perm: Option<&mut [i32]>, 135 | ) -> Result<()> { 136 | // Get required buffer size 137 | let mut buffer_size = 0; 138 | let status = unsafe { 139 | rocsparse_csrsort_buffer_size( 140 | handle.inner, 141 | m, 142 | n, 143 | nnz, 144 | csr_row_ptr.as_ptr(), 145 | csr_col_ind.as_ptr(), 146 | &mut buffer_size, 147 | ) 148 | }; 149 | status_to_result(status)?; 150 | 151 | // 
Allocate temporary buffer 152 | let mut temp_buffer = vec![0u8; buffer_size]; 153 | 154 | // Perform sort 155 | let status = unsafe { 156 | rocsparse_csrsort( 157 | handle.inner, 158 | m, 159 | n, 160 | nnz, 161 | descr.inner, 162 | csr_row_ptr.as_ptr(), 163 | csr_col_ind.as_mut_ptr(), 164 | perm.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()), 165 | temp_buffer.as_mut_ptr() as *mut c_void, 166 | ) 167 | }; 168 | 169 | status_to_result(status) 170 | } 171 | -------------------------------------------------------------------------------- /src/hip/module.rs: -------------------------------------------------------------------------------- 1 | // src/hip/module.rs 2 | // 3 | // Module loading and management for HIP 4 | 5 | use crate::hip::error::{Error, Result}; 6 | use crate::hip::ffi; 7 | use crate::hip::kernel::Function; 8 | use std::ffi::{CString, c_void}; 9 | use std::fs; 10 | use std::path::Path; 11 | use std::ptr; 12 | 13 | /// A wrapper around a HIP module 14 | pub struct Module { 15 | module: ffi::hipModule_t, 16 | } 17 | 18 | impl Module { 19 | /// Load a module from a file 20 | pub fn load>(path: P) -> Result { 21 | let path_str = path.as_ref().to_string_lossy(); 22 | let path_cstr = CString::new(path_str.as_bytes()).unwrap(); 23 | 24 | let mut module = ptr::null_mut(); 25 | let error = unsafe { ffi::hipModuleLoad(&mut module, path_cstr.as_ptr()) }; 26 | 27 | if error != ffi::hipError_t_hipSuccess { 28 | return Err(Error::new(error)); 29 | } 30 | 31 | Ok(Self { module }) 32 | } 33 | 34 | /// Load a module from a code object containing PTX code 35 | pub fn load_data(data: impl AsRef<[u8]>) -> Result { 36 | let mut module = ptr::null_mut(); 37 | let error = 38 | unsafe { ffi::hipModuleLoadData(&mut module, data.as_ref().as_ptr() as *const c_void) }; 39 | 40 | if error != ffi::hipError_t_hipSuccess { 41 | return Err(Error::new(error)); 42 | } 43 | 44 | Ok(Self { module }) 45 | } 46 | 47 | /// Load a module from a code object containing PTX code with options 48 
| pub unsafe fn load_with_options( 49 | data: impl AsRef<[u8]>, 50 | num_options: u32, 51 | options: *mut ffi::hipJitOption, 52 | option_values: *mut *mut c_void, 53 | ) -> Result { 54 | let mut module = ptr::null_mut(); 55 | let error = unsafe { 56 | ffi::hipModuleLoadDataEx( 57 | &mut module, 58 | data.as_ref().as_ptr() as *const c_void, 59 | num_options, 60 | options, 61 | option_values, 62 | ) 63 | }; 64 | 65 | if error != ffi::hipError_t_hipSuccess { 66 | return Err(Error::new(error)); 67 | } 68 | 69 | Ok(Self { module }) 70 | } 71 | 72 | /// Get a function from the module 73 | pub fn get_function(&self, name: &str) -> Result { 74 | unsafe { Function::new(self.module, name) } 75 | } 76 | 77 | /// Get a global variable from the module 78 | pub fn get_global(&self, name: &str) -> Result<*mut T> { 79 | let name_cstr = CString::new(name).unwrap(); 80 | 81 | let mut dev_ptr = ptr::null_mut(); 82 | let mut size = 0usize; 83 | 84 | let error = unsafe { 85 | ffi::hipModuleGetGlobal(&mut dev_ptr, &mut size, self.module, name_cstr.as_ptr()) 86 | }; 87 | 88 | if error != ffi::hipError_t_hipSuccess { 89 | return Err(Error::new(error)); 90 | } 91 | 92 | if size < std::mem::size_of::() { 93 | return Err(Error::new(ffi::hipError_t_hipErrorInvalidValue)); 94 | } 95 | 96 | Ok(dev_ptr as *mut T) 97 | } 98 | 99 | /// Get the raw module handle 100 | pub fn as_raw(&self) -> ffi::hipModule_t { 101 | self.module 102 | } 103 | } 104 | 105 | impl Drop for Module { 106 | fn drop(&mut self) { 107 | if !self.module.is_null() { 108 | unsafe { 109 | let _ = ffi::hipModuleUnload(self.module); 110 | // We cannot handle errors in drop, so just ignore the result 111 | } 112 | self.module = ptr::null_mut(); 113 | } 114 | } 115 | } 116 | 117 | /// Helper function to load a module from a file 118 | pub fn load_module>(path: P) -> Result { 119 | Module::load(path) 120 | } 121 | 122 | /// Helper function to load a module from data 123 | pub fn load_module_data(data: &str) -> Result { 124 | 
Module::load_data(data) 125 | } 126 | 127 | /// Helper function to compile and load HIP code 128 | pub fn compile_and_load(source: &str, options: &[String]) -> Result { 129 | // This is a placeholder for a function that would: 130 | // 1. Save the source to a temporary file 131 | // 2. Run hipcc to compile it 132 | // 3. Load the resulting binary 133 | // 134 | // A real implementation would depend on your build system 135 | // and how you want to handle compilation. 136 | // 137 | // For now, let's just show how it might work: 138 | use std::env::temp_dir; 139 | use std::process::Command; 140 | 141 | let temp_src_path = temp_dir().join("temp_kernel.cpp"); 142 | let temp_bin_path = temp_dir().join("temp_kernel.hsaco"); 143 | 144 | fs::write(&temp_src_path, source) 145 | .map_err(|_| Error::new(ffi::hipError_t_hipErrorInvalidValue))?; 146 | 147 | let mut cmd = Command::new("hipcc"); 148 | cmd.arg("--genco"); 149 | 150 | for opt in options { 151 | cmd.arg(opt); 152 | } 153 | 154 | cmd.arg("-o").arg(&temp_bin_path).arg(&temp_src_path); 155 | 156 | let status = cmd 157 | .status() 158 | .map_err(|_| Error::new(ffi::hipError_t_hipErrorInvalidValue))?; 159 | 160 | if !status.success() { 161 | return Err(Error::new(ffi::hipError_t_hipErrorInvalidValue)); 162 | } 163 | 164 | Module::load(temp_bin_path) 165 | } 166 | -------------------------------------------------------------------------------- /src/hip/device.rs: -------------------------------------------------------------------------------- 1 | // src/hip/device.rs 2 | 3 | use crate::hip::error::{Error, Result}; 4 | use crate::hip::{Stream, ffi}; 5 | use std::ffi::CStr; 6 | 7 | /// Get the number of available devices 8 | pub fn get_device_count() -> Result { 9 | let mut count = 0; 10 | let error = unsafe { ffi::hipGetDeviceCount(&mut count) }; 11 | Error::from_hip_error_with_value(error, count) 12 | } 13 | 14 | /// Device properties 15 | #[derive(Debug, Clone)] 16 | pub struct DeviceProperties { 17 | pub name: String, 
18 | pub total_global_mem: usize, 19 | pub shared_mem_per_block: usize, 20 | pub regs_per_block: i32, 21 | pub warp_size: i32, 22 | pub max_threads_per_block: i32, 23 | pub max_threads_dim: [i32; 3], 24 | pub max_grid_size: [i32; 3], 25 | pub clock_rate: i32, 26 | pub memory_clock_rate: i32, 27 | pub memory_bus_width: i32, 28 | pub total_const_mem: usize, 29 | pub major: i32, 30 | pub minor: i32, 31 | pub multi_processor_count: i32, 32 | pub l2_cache_size: i32, 33 | pub max_threads_per_multiprocessor: i32, 34 | pub compute_mode: i32, 35 | pub integrated: i32, 36 | pub can_map_host_memory: i32, 37 | } 38 | 39 | /// Get device properties for a given device 40 | pub fn get_device_properties(device_id: i32) -> Result { 41 | let mut props = unsafe { std::mem::zeroed::() }; 42 | let error = unsafe { ffi::hipGetDevicePropertiesR0600(&mut props, device_id) }; 43 | 44 | if error != ffi::hipError_t_hipSuccess { 45 | return Err(Error::new(error)); 46 | } 47 | 48 | let name = unsafe { 49 | let name_ptr = props.name.as_ptr() as *const i8; 50 | CStr::from_ptr(name_ptr).to_string_lossy().into_owned() 51 | }; 52 | 53 | Ok(DeviceProperties { 54 | name, 55 | total_global_mem: props.totalGlobalMem, 56 | shared_mem_per_block: props.sharedMemPerBlock, 57 | regs_per_block: props.regsPerBlock, 58 | warp_size: props.warpSize, 59 | max_threads_per_block: props.maxThreadsPerBlock, 60 | max_threads_dim: props.maxThreadsDim, 61 | max_grid_size: props.maxGridSize, 62 | clock_rate: props.clockRate, 63 | memory_clock_rate: props.memoryClockRate, 64 | memory_bus_width: props.memoryBusWidth, 65 | total_const_mem: props.totalConstMem, 66 | major: props.major, 67 | minor: props.minor, 68 | multi_processor_count: props.multiProcessorCount, 69 | l2_cache_size: props.l2CacheSize, 70 | max_threads_per_multiprocessor: props.maxThreadsPerMultiProcessor, 71 | compute_mode: props.computeMode, 72 | integrated: props.integrated, 73 | can_map_host_memory: props.canMapHostMemory, 74 | }) 75 | } 76 | 77 | /// A 
wrapper for HIP device operations 78 | #[derive(Debug, Clone)] 79 | pub struct Device { 80 | id: i32, 81 | } 82 | 83 | impl Device { 84 | /// Creates a new device with the given ID 85 | pub fn new(id: i32) -> Result { 86 | let count = get_device_count()?; 87 | if id < 0 || id >= count { 88 | return Err(Error::new(ffi::hipError_t_hipErrorInvalidDevice)); 89 | } 90 | Ok(Self { id }) 91 | } 92 | 93 | /// Get the current device 94 | pub fn current() -> Result { 95 | let mut device_id = 0; 96 | let error = unsafe { ffi::hipGetDevice(&mut device_id) }; 97 | if error != ffi::hipError_t_hipSuccess { 98 | return Err(Error::new(error)); 99 | } 100 | Ok(Self { id: device_id }) 101 | } 102 | 103 | /// Get the device ID 104 | pub fn id(&self) -> i32 { 105 | self.id 106 | } 107 | 108 | /// Set this device as the current device 109 | pub fn set_current(&self) -> Result<()> { 110 | let error = unsafe { ffi::hipSetDevice(self.id) }; 111 | Error::from_hip_error(error) 112 | } 113 | 114 | /// Synchronize this device 115 | pub fn synchronize(&self) -> Result<()> { 116 | // Save current device 117 | let current_device = Self::current()?; 118 | 119 | // Set this device as current 120 | self.set_current()?; 121 | 122 | // Synchronize 123 | let error = unsafe { ffi::hipDeviceSynchronize() }; 124 | 125 | // Restore previous device 126 | current_device.set_current()?; 127 | 128 | Error::from_hip_error(error) 129 | } 130 | 131 | /// Reset this device 132 | pub unsafe fn reset(&self) -> Result<()> { 133 | // Save current device 134 | let current_device = Self::current()?; 135 | 136 | // Set this device as current 137 | self.set_current()?; 138 | 139 | // Reset 140 | let error = unsafe { ffi::hipDeviceReset() }; 141 | 142 | // Restore previous device 143 | current_device.set_current()?; 144 | 145 | Error::from_hip_error(error) 146 | } 147 | 148 | /// Get the properties of this device 149 | pub fn properties(&self) -> Result { 150 | get_device_properties(self.id) 151 | } 152 | 153 | pub fn 
get_stream(&self) -> Result { 154 | Stream::new() 155 | } 156 | pub fn get_stream_with_flags(&self, flags: u32) -> Result { 157 | Stream::with_flags(flags) 158 | } 159 | pub fn get_stream_with_priority(&self, flags: u32, priority: i32) -> Result { 160 | Stream::with_priority(flags, priority) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /src/miopen/lrn.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/lrn.rs 2 | 3 | use crate::miopen::error::{Error, Result}; 4 | use crate::miopen::ffi; 5 | use crate::miopen::handle::Handle; 6 | use crate::miopen::tensor::TensorDescriptor; 7 | use std::os::raw::c_void; 8 | use std::ptr; 9 | 10 | /// LRN mode type 11 | pub type LRNMode = ffi::miopenLRNMode_t; 12 | 13 | /// Safe wrapper for MIOpen LRN descriptor 14 | pub struct LRNDescriptor { 15 | desc: ffi::miopenLRNDescriptor_t, 16 | } 17 | 18 | impl LRNDescriptor { 19 | /// Create a new LRN descriptor 20 | pub fn new() -> Result { 21 | let mut desc = ptr::null_mut(); 22 | let status = unsafe { ffi::miopenCreateLRNDescriptor(&mut desc) }; 23 | 24 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 25 | return Err(Error::new(status)); 26 | } 27 | 28 | Ok(Self { desc }) 29 | } 30 | 31 | /// Set the LRN descriptor details 32 | pub fn set( 33 | &mut self, 34 | mode: LRNMode, 35 | lrn_n: u32, 36 | lrn_alpha: f64, 37 | lrn_beta: f64, 38 | lrn_k: f64, 39 | ) -> Result<()> { 40 | let status = unsafe { 41 | ffi::miopenSetLRNDescriptor(self.desc, mode, lrn_n, lrn_alpha, lrn_beta, lrn_k) 42 | }; 43 | 44 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 45 | return Err(Error::new(status)); 46 | } 47 | 48 | Ok(()) 49 | } 50 | 51 | /// Get the LRN descriptor details 52 | pub fn get(&self) -> Result<(LRNMode, u32, f64, f64, f64)> { 53 | let mut mode = 0; 54 | let mut lrn_n = 0; 55 | let mut lrn_alpha = 0.0; 56 | let mut lrn_beta = 0.0; 57 | let mut lrn_k = 0.0; 58 | 59 | let 
status = unsafe { 60 | ffi::miopenGetLRNDescriptor( 61 | self.desc, 62 | &mut mode, 63 | &mut lrn_n, 64 | &mut lrn_alpha, 65 | &mut lrn_beta, 66 | &mut lrn_k, 67 | ) 68 | }; 69 | 70 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 71 | return Err(Error::new(status)); 72 | } 73 | 74 | Ok((mode, lrn_n, lrn_alpha, lrn_beta, lrn_k)) 75 | } 76 | 77 | /// Get the workspace size required for LRN operations 78 | pub fn get_workspace_size(y_desc: &TensorDescriptor) -> Result { 79 | let mut workspace_size = 0; 80 | 81 | let status = 82 | unsafe { ffi::miopenLRNGetWorkSpaceSize(y_desc.as_raw(), &mut workspace_size) }; 83 | 84 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 85 | return Err(Error::new(status)); 86 | } 87 | 88 | Ok(workspace_size) 89 | } 90 | 91 | /// Execute a forward LRN operation 92 | pub unsafe fn forward( 93 | &self, 94 | handle: &Handle, 95 | alpha: &[u8], 96 | x_desc: &TensorDescriptor, 97 | x: *const c_void, 98 | beta: &[u8], 99 | y_desc: &TensorDescriptor, 100 | y: *mut c_void, 101 | do_backward: bool, 102 | workspace: *mut c_void, 103 | ) -> Result<()> { 104 | let status = unsafe { 105 | ffi::miopenLRNForward( 106 | handle.as_raw(), 107 | self.desc, 108 | alpha.as_ptr() as *const c_void, 109 | x_desc.as_raw(), 110 | x, 111 | beta.as_ptr() as *const c_void, 112 | y_desc.as_raw(), 113 | y, 114 | do_backward, 115 | workspace, 116 | ) 117 | }; 118 | 119 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 120 | return Err(Error::new(status)); 121 | } 122 | 123 | Ok(()) 124 | } 125 | 126 | /// Execute a backward LRN operation 127 | pub unsafe fn backward( 128 | &self, 129 | handle: &Handle, 130 | alpha: &[u8], 131 | y_desc: &TensorDescriptor, 132 | y: *const c_void, 133 | dy_desc: &TensorDescriptor, 134 | dy: *const c_void, 135 | x_desc: &TensorDescriptor, 136 | x: *const c_void, 137 | beta: &[u8], 138 | dx_desc: &TensorDescriptor, 139 | dx: *mut c_void, 140 | workspace: *const c_void, 141 | ) -> Result<()> { 142 | let status = unsafe { 
143 | ffi::miopenLRNBackward( 144 | handle.as_raw(), 145 | self.desc, 146 | alpha.as_ptr() as *const c_void, 147 | y_desc.as_raw(), 148 | y, 149 | dy_desc.as_raw(), 150 | dy, 151 | x_desc.as_raw(), 152 | x, 153 | beta.as_ptr() as *const c_void, 154 | dx_desc.as_raw(), 155 | dx, 156 | workspace, 157 | ) 158 | }; 159 | 160 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 161 | return Err(Error::new(status)); 162 | } 163 | 164 | Ok(()) 165 | } 166 | 167 | /// Get the raw descriptor 168 | pub fn as_raw(&self) -> ffi::miopenLRNDescriptor_t { 169 | self.desc 170 | } 171 | } 172 | 173 | impl Drop for LRNDescriptor { 174 | fn drop(&mut self) { 175 | if !self.desc.is_null() { 176 | unsafe { 177 | let _ = ffi::miopenDestroyLRNDescriptor(self.desc); 178 | // We cannot handle errors in drop, so just ignore the result 179 | }; 180 | self.desc = ptr::null_mut(); 181 | } 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/rocfft/execution.rs: -------------------------------------------------------------------------------- 1 | use crate::rocfft::error::{Error, Result, check_error}; 2 | use crate::rocfft::ffi; 3 | use std::marker::PhantomData; 4 | use std::ptr; 5 | 6 | /// Additional execution parameters for a transform 7 | /// 8 | /// This structure can control: 9 | /// - Work buffers 10 | /// - Execution streams (HIP/ROCm streams) 11 | /// - Load/store callbacks 12 | pub struct ExecutionInfo { 13 | handle: ffi::rocfft_execution_info, 14 | _marker: PhantomData<*mut ()>, // Mark as !Send and !Sync 15 | } 16 | 17 | impl ExecutionInfo { 18 | /// Create a new execution info object 19 | /// 20 | /// # Returns 21 | /// 22 | /// A result containing the newly created execution info or an error 23 | pub fn new() -> Result { 24 | let mut handle: ffi::rocfft_execution_info = ptr::null_mut(); 25 | 26 | unsafe { 27 | check_error(ffi::rocfft_execution_info_create(&mut handle))?; 28 | } 29 | 30 | Ok(ExecutionInfo { 31 | handle, 32 | 
_marker: PhantomData, 33 | }) 34 | } 35 | 36 | /// Set a work buffer for the transform 37 | /// 38 | /// # Arguments 39 | /// 40 | /// * `buffer` - Pointer to work buffer (GPU memory) 41 | /// * `size_in_bytes` - Size of work buffer in bytes 42 | /// 43 | /// # Returns 44 | /// 45 | /// A result indicating success or an error 46 | /// 47 | /// # Note 48 | /// 49 | /// If you need to know how large the work buffer should be, call 50 | /// `Plan::get_work_buffer_size()`. 51 | pub unsafe fn set_work_buffer( 52 | &mut self, 53 | buffer: *mut std::ffi::c_void, 54 | size_in_bytes: usize, 55 | ) -> Result<()> { 56 | if self.handle.is_null() { 57 | return Err(Error::ObjectDestroyed); 58 | } 59 | 60 | if buffer.is_null() && size_in_bytes > 0 { 61 | return Err(Error::NullPointer); 62 | } 63 | 64 | unsafe { 65 | check_error(ffi::rocfft_execution_info_set_work_buffer( 66 | self.handle, 67 | buffer, 68 | size_in_bytes, 69 | )) 70 | } 71 | } 72 | 73 | /// Set a ROCm/HIP stream for the transform execution 74 | /// 75 | /// # Arguments 76 | /// 77 | /// * `stream` - HIP stream to use (hipStream_t) 78 | /// 79 | /// # Returns 80 | /// 81 | /// A result indicating success or an error 82 | pub unsafe fn set_stream(&mut self, stream: *mut std::ffi::c_void) -> Result<()> { 83 | if self.handle.is_null() { 84 | return Err(Error::ObjectDestroyed); 85 | } 86 | 87 | unsafe { check_error(ffi::rocfft_execution_info_set_stream(self.handle, stream)) } 88 | } 89 | 90 | /// Set a load callback for the transform (experimental) 91 | /// 92 | /// # Arguments 93 | /// 94 | /// * `callbacks` - Array of callback function pointers 95 | /// * `user_data` - Array of user data pointers passed to callbacks 96 | /// * `shared_mem_bytes` - Amount of shared memory for the callback 97 | /// 98 | /// # Returns 99 | /// 100 | /// A result indicating success or an error 101 | /// 102 | /// # Note 103 | /// 104 | /// This is an experimental feature in rocFFT. 
105 | pub fn set_load_callback( 106 | &mut self, 107 | callbacks: &mut [*mut std::ffi::c_void], 108 | user_data: &mut [*mut std::ffi::c_void], 109 | shared_mem_bytes: usize, 110 | ) -> Result<()> { 111 | if self.handle.is_null() { 112 | return Err(Error::ObjectDestroyed); 113 | } 114 | 115 | unsafe { 116 | check_error(ffi::rocfft_execution_info_set_load_callback( 117 | self.handle, 118 | callbacks.as_mut_ptr(), 119 | user_data.as_mut_ptr(), 120 | shared_mem_bytes, 121 | )) 122 | } 123 | } 124 | 125 | /// Set a store callback for the transform (experimental) 126 | /// 127 | /// # Arguments 128 | /// 129 | /// * `callbacks` - Array of callback function pointers 130 | /// * `user_data` - Array of user data pointers passed to callbacks 131 | /// * `shared_mem_bytes` - Amount of shared memory for the callback 132 | /// 133 | /// # Returns 134 | /// 135 | /// A result indicating success or an error 136 | /// 137 | /// # Note 138 | /// 139 | /// This is an experimental feature in rocFFT. 140 | pub fn set_store_callback( 141 | &mut self, 142 | callbacks: &mut [*mut std::ffi::c_void], 143 | user_data: &mut [*mut std::ffi::c_void], 144 | shared_mem_bytes: usize, 145 | ) -> Result<()> { 146 | if self.handle.is_null() { 147 | return Err(Error::ObjectDestroyed); 148 | } 149 | 150 | unsafe { 151 | check_error(ffi::rocfft_execution_info_set_store_callback( 152 | self.handle, 153 | callbacks.as_mut_ptr(), 154 | user_data.as_mut_ptr(), 155 | shared_mem_bytes, 156 | )) 157 | } 158 | } 159 | 160 | /// Get the internal handle (for use in other rocFFT functions) 161 | pub(crate) fn as_ptr(&self) -> ffi::rocfft_execution_info { 162 | self.handle 163 | } 164 | } 165 | 166 | impl Drop for ExecutionInfo { 167 | fn drop(&mut self) { 168 | if !self.handle.is_null() { 169 | unsafe { 170 | ffi::rocfft_execution_info_destroy(self.handle); 171 | } 172 | self.handle = ptr::null_mut(); 173 | } 174 | } 175 | } 176 | 177 | // Prevent sending an execution info between threads as it's not guaranteed 
to be thread-safe 178 | -------------------------------------------------------------------------------- /src/miopen/examples/multi_tensor/src/main.rs: -------------------------------------------------------------------------------- 1 | mod data; 2 | mod kernels; 3 | pub mod layer; 4 | 5 | use crate::{kernels::KERNEL, layer::Layer}; 6 | use rocm_rs::{ 7 | error::Result, 8 | hip::{DeviceMemory, Dim3, Module, kernel::AsKernelArg}, 9 | kernel_args, 10 | miopen::{self, ActivationMode}, 11 | }; 12 | use std::rc::Rc; 13 | 14 | use crate::data::prepare_data; 15 | 16 | const HIDDEN_SIZE: usize = 8; 17 | const LEARNING_RATE: f32 = 0.01; 18 | const EPOCHS: usize = 200; 19 | 20 | fn main() -> Result<()> { 21 | let (x_data, y_target, class_labels) = prepare_data(); 22 | 23 | let handle = miopen::Handle::new()?; 24 | 25 | let input_size = x_data[0].len(); 26 | let hidden_size = HIDDEN_SIZE; 27 | let output_size = class_labels.len(); 28 | let learning_rate = LEARNING_RATE; 29 | let epochs = EPOCHS; 30 | 31 | let module = Rc::new(Module::load_data(KERNEL)?); 32 | 33 | let mut hidden_layer = Layer::new( 34 | hidden_size, 35 | input_size, 36 | ActivationMode::MiopenActivationLOGISTIC, 37 | module.clone(), 38 | )?; 39 | 40 | let mut output_layer = Layer::new( 41 | output_size, 42 | hidden_size, 43 | ActivationMode::MiopenActivationSOFTRELU, 44 | module.clone(), 45 | )?; 46 | 47 | let mut weights_input_hidden = init_weights(hidden_size, input_size); 48 | let mut bias_hidden = vec![0.0; hidden_size]; 49 | let mut weights_hidden_output = init_weights(output_size, hidden_size); 50 | let mut bias_output = vec![0.0; output_size]; 51 | 52 | let gradient_func = module.get_function("gradient")?; 53 | let mut target_device = DeviceMemory::new(output_size)?; 54 | let mut x_sample_dev = DeviceMemory::new(input_size)?; 55 | 56 | for epoch in 0..epochs { 57 | for (x_sample, target) in x_data.iter().zip(y_target.iter()) { 58 | target_device.copy_from_host(target)?; 59 | 
x_sample_dev.copy_from_host(x_sample)?; 60 | 61 | let hidden_activation = hidden_layer.forward( 62 | &handle, 63 | &x_sample_dev, 64 | &weights_input_hidden, 65 | &bias_hidden, 66 | )?; 67 | 68 | let prediction = output_layer.forward( 69 | &handle, 70 | &hidden_activation, 71 | &weights_hidden_output, 72 | &bias_output, 73 | )?; 74 | 75 | gradient_func.launch( 76 | Dim3::new_1d(output_size as u32), 77 | Dim3::new_1d(1), 78 | 0, 79 | None, 80 | kernel_args!( 81 | prediction, 82 | &target_device, 83 | &output_layer.device_grad_act, 84 | output_size 85 | ), 86 | )?; 87 | 88 | output_layer.backward( 89 | &handle, 90 | &hidden_activation, 91 | &mut weights_hidden_output, 92 | &mut bias_output, 93 | learning_rate, 94 | )?; 95 | 96 | hidden_layer 97 | .device_grad_act 98 | .copy_from_host(output_layer.input_grad())?; 99 | 100 | hidden_layer.backward( 101 | &handle, 102 | &x_sample_dev, 103 | &mut weights_input_hidden, 104 | &mut bias_hidden, 105 | learning_rate, 106 | )?; 107 | } 108 | 109 | if epoch % 10 == 0 { 110 | println!("Epoch {epoch}"); 111 | } 112 | } 113 | 114 | println!("Inference after training:"); 115 | 116 | let inference_samples = vec![ 117 | (vec![5.1, 3.5, 1.4, 0.2], "setosa"), 118 | (vec![7.0, 3.2, 4.7, 1.4], "versicolor"), 119 | (vec![6.0, 2.2, 5.0, 1.5], "virginica"), 120 | ]; 121 | 122 | for (features, expected_label) in inference_samples { 123 | let mut features_dev = DeviceMemory::new(features.len())?; 124 | features_dev.copy_from_host(&features)?; 125 | 126 | let hidden_activation = 127 | hidden_layer.forward(&handle, &features_dev, &weights_input_hidden, &bias_hidden)?; 128 | 129 | let prediction = output_layer.forward( 130 | &handle, 131 | &hidden_activation, 132 | &weights_hidden_output, 133 | &bias_output, 134 | )?; 135 | 136 | let prediction = { 137 | let mut vec = vec![0.0; output_size]; 138 | prediction.copy_to_host(&mut vec)?; 139 | vec 140 | }; 141 | 142 | let predicted_idx = prediction 143 | .iter() 144 | .enumerate() 145 | .max_by(|a, b| 
a.1.total_cmp(b.1)) 146 | .map(|(idx, _)| idx) 147 | .unwrap(); 148 | 149 | println!( 150 | "Expected: {expected_label}, Predicted: {}, Probabilities: {:?}", 151 | class_labels[predicted_idx], prediction 152 | ); 153 | } 154 | 155 | Ok(()) 156 | } 157 | 158 | /// Scaling factor used to generate deterministic but non-uniform initial 159 | /// weights. The value 0.37 is approximately 1/e and is chosen to spread 160 | /// the input to `sin` across different phases while keeping the magnitude 161 | /// below 1.0. This constant can be adjusted to change the initialization 162 | /// without altering the overall scheme. 163 | const WEIGHT_INIT_SEED_SCALE: f32 = 0.37; 164 | 165 | fn init_weights(rows: usize, cols: usize) -> Vec { 166 | let mut weights = Vec::with_capacity(rows * cols); 167 | for row in 0..rows { 168 | for col in 0..cols { 169 | let seed = (row * cols + col) as f32 * WEIGHT_INIT_SEED_SCALE; 170 | weights.push(seed.sin() * 0.1); 171 | } 172 | } 173 | weights 174 | } 175 | -------------------------------------------------------------------------------- /src/rocblas/handle.rs: -------------------------------------------------------------------------------- 1 | // src/rocblas/handle.rs 2 | 3 | use crate::hip::Stream; 4 | use crate::rocblas::error::{Error, Result}; 5 | use crate::rocblas::ffi; 6 | use std::ptr; 7 | 8 | /// Safe wrapper for RocBLAS handle 9 | pub struct Handle { 10 | handle: ffi::rocblas_handle, 11 | } 12 | 13 | impl Handle { 14 | /// Create a new RocBLAS handle 15 | pub fn new() -> Result { 16 | let mut handle = ptr::null_mut(); 17 | let error = unsafe { ffi::rocblas_create_handle(&mut handle) }; 18 | 19 | if error != ffi::rocblas_status__rocblas_status_success { 20 | return Err(Error::new(error)); 21 | } 22 | 23 | Ok(Self { handle }) 24 | } 25 | 26 | /// Set the stream for this handle 27 | pub fn set_stream(&self, stream: &Stream) -> Result<()> { 28 | // Use a type cast to convert between the two hipStream_t types 29 | let hip_stream_ptr = 
stream.as_raw(); 30 | // Cast to the expected type for rocblas 31 | let rocblas_stream_ptr = hip_stream_ptr as ffi::hipStream_t; 32 | 33 | let error = unsafe { ffi::rocblas_set_stream(self.handle, rocblas_stream_ptr) }; 34 | 35 | if error != ffi::rocblas_status__rocblas_status_success { 36 | return Err(Error::new(error)); 37 | } 38 | 39 | Ok(()) 40 | } 41 | /// Get the stream associated with this handle 42 | pub fn get_stream(&self) -> Result { 43 | let mut stream_ptr = ptr::null_mut(); 44 | let error = unsafe { ffi::rocblas_get_stream(self.handle, &mut stream_ptr) }; 45 | 46 | if error != ffi::rocblas_status__rocblas_status_success { 47 | return Err(Error::new(error)); 48 | } 49 | 50 | // Cast back to hip::ffi::hipStream_t 51 | let hip_stream_ptr = stream_ptr as crate::hip::ffi::hipStream_t; 52 | 53 | // Create a Stream from the raw pointer 54 | // This doesn't take ownership of the stream, just wraps the pointer 55 | Ok(Stream::from_raw(hip_stream_ptr)) 56 | } 57 | 58 | /// Set the pointer mode for this handle 59 | pub fn set_pointer_mode(&self, mode: ffi::rocblas_pointer_mode) -> Result<()> { 60 | let error = unsafe { ffi::rocblas_set_pointer_mode(self.handle, mode) }; 61 | 62 | if error != ffi::rocblas_status__rocblas_status_success { 63 | return Err(Error::new(error)); 64 | } 65 | 66 | Ok(()) 67 | } 68 | 69 | /// Get the pointer mode for this handle 70 | pub fn get_pointer_mode(&self) -> Result { 71 | let mut mode = ffi::rocblas_pointer_mode__rocblas_pointer_mode_host; 72 | let error = unsafe { ffi::rocblas_get_pointer_mode(self.handle, &mut mode) }; 73 | 74 | if error != ffi::rocblas_status__rocblas_status_success { 75 | return Err(Error::new(error)); 76 | } 77 | 78 | Ok(mode) 79 | } 80 | 81 | /// Set the atomics mode for this handle 82 | pub fn set_atomics_mode(&self, mode: ffi::rocblas_atomics_mode) -> Result<()> { 83 | let error = unsafe { ffi::rocblas_set_atomics_mode(self.handle, mode) }; 84 | 85 | if error != ffi::rocblas_status__rocblas_status_success 
{ 86 | return Err(Error::new(error)); 87 | } 88 | 89 | Ok(()) 90 | } 91 | 92 | /// Get the atomics mode for this handle 93 | pub fn get_atomics_mode(&self) -> Result { 94 | let mut mode = ffi::rocblas_atomics_mode__rocblas_atomics_allowed; 95 | let error = unsafe { ffi::rocblas_get_atomics_mode(self.handle, &mut mode) }; 96 | 97 | if error != ffi::rocblas_status__rocblas_status_success { 98 | return Err(Error::new(error)); 99 | } 100 | 101 | Ok(mode) 102 | } 103 | 104 | /// Set the performance metric for this handle 105 | pub fn set_performance_metric(&self, metric: ffi::rocblas_performance_metric) -> Result<()> { 106 | let error = unsafe { ffi::rocblas_set_performance_metric(self.handle, metric) }; 107 | 108 | if error != ffi::rocblas_status__rocblas_status_success { 109 | return Err(Error::new(error)); 110 | } 111 | 112 | Ok(()) 113 | } 114 | 115 | /// Get the performance metric for this handle 116 | pub fn get_performance_metric(&self) -> Result { 117 | let mut metric = ffi::rocblas_performance_metric__rocblas_default_performance_metric; 118 | let error = unsafe { ffi::rocblas_get_performance_metric(self.handle, &mut metric) }; 119 | 120 | if error != ffi::rocblas_status__rocblas_status_success { 121 | return Err(Error::new(error)); 122 | } 123 | 124 | Ok(metric) 125 | } 126 | 127 | /// Set the math mode for this handle 128 | pub fn set_math_mode(&self, mode: ffi::rocblas_math_mode) -> Result<()> { 129 | let error = unsafe { ffi::rocblas_set_math_mode(self.handle, mode) }; 130 | 131 | if error != ffi::rocblas_status__rocblas_status_success { 132 | return Err(Error::new(error)); 133 | } 134 | 135 | Ok(()) 136 | } 137 | 138 | /// Get the math mode for this handle 139 | pub fn get_math_mode(&self) -> Result { 140 | let mut mode = ffi::rocblas_math_mode__rocblas_default_math; 141 | let error = unsafe { ffi::rocblas_get_math_mode(self.handle, &mut mode) }; 142 | 143 | if error != ffi::rocblas_status__rocblas_status_success { 144 | return Err(Error::new(error)); 145 | 
} 146 | 147 | Ok(mode) 148 | } 149 | 150 | /// Get the raw handle 151 | pub fn as_raw(&self) -> ffi::rocblas_handle { 152 | self.handle 153 | } 154 | } 155 | 156 | impl Drop for Handle { 157 | fn drop(&mut self) { 158 | if !self.handle.is_null() { 159 | unsafe { 160 | let _ = ffi::rocblas_destroy_handle(self.handle); 161 | // We cannot handle errors in drop, so just ignore the result 162 | } 163 | self.handle = ptr::null_mut(); 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/hip/examples/vector_add/main.rs: -------------------------------------------------------------------------------- 1 | use rocm_rs::error::Result; 2 | use rocm_rs::hip::{calculate_grid_1d, Device, DeviceMemory, Dim3, Module, Stream, Timer}; 3 | use std::env; 4 | use std::path::PathBuf; 5 | use std::time::Instant; 6 | 7 | fn main() -> Result<()> { 8 | // Initialize device 9 | println!("Initializing device..."); 10 | let device = Device::new(0)?; 11 | device.set_current()?; 12 | 13 | // Print device info 14 | let props = device.properties()?; 15 | println!("Using device: {}", props.name); 16 | println!("Compute capability: {}.{}", props.major, props.minor); 17 | println!("Multiprocessor count: {}", props.multi_processor_count); 18 | 19 | // Load the precompiled kernel module 20 | let kernel_path = PathBuf::from("vector_add.hsaco"); 21 | 22 | if !kernel_path.exists() { 23 | println!( 24 | "Error: Could not find kernel file: {}", 25 | kernel_path.display() 26 | ); 27 | println!( 28 | "Error: Could not find kernel file: {}", 29 | kernel_path.display() 30 | ); 31 | println!("Make sure to run the build.sh script first to compile the kernel."); 32 | return Ok(()); 33 | } 34 | 35 | println!("Loading kernel module from: {}", kernel_path.display()); 36 | let module = Module::load(kernel_path)?; 37 | 38 | // Get the function handle 39 | let function = module.get_function("vector_add")?; 40 | 41 | // Create a stream for async operations 42 | 
let stream = Stream::new()?; 43 | 44 | // Get the test size from command line or use default 45 | let args: Vec = env::args().collect(); 46 | let n = if args.len() > 1 { 47 | args[1].parse::().unwrap_or(1_000_000) 48 | } else { 49 | 1_000_000 50 | }; 51 | 52 | println!("Vector size: {}", n); 53 | 54 | // Prepare host data 55 | println!("Preparing host data..."); 56 | let a: Vec = (0..n).map(|i| i as f32).collect(); 57 | let b: Vec = (0..n).map(|i| (2.0 * i as f32)).collect(); 58 | let c = vec![0.0f32; n]; 59 | 60 | // Allocate device memory 61 | println!("Allocating device memory..."); 62 | let d_a = DeviceMemory::::new(n)?; 63 | let d_b = DeviceMemory::::new(n)?; 64 | let d_c = DeviceMemory::::new(n)?; 65 | 66 | // Create a timer 67 | println!("Creating timer..."); 68 | let timer = Timer::new()?; 69 | 70 | // Start timing host-to-device transfer 71 | timer.start(&stream)?; 72 | 73 | // Copy data from host to device 74 | println!("Copying host data to device..."); 75 | d_a.copy_from_host_async(a.clone(), &stream)?; 76 | d_b.copy_from_host_async(b.clone(), &stream)?; 77 | 78 | // Stop timing host-to-device transfer 79 | timer.stop(&stream)?; 80 | let h2d_time = timer.elapsed_time()?; 81 | 82 | // Set up kernel launch parameters 83 | let block_size = 256; 84 | let grid_dim = calculate_grid_1d(n as u32, block_size); 85 | let block_dim = Dim3::new_1d(block_size); 86 | 87 | println!( 88 | "Launching kernel with grid={}, block={}", 89 | grid_dim.x, block_dim.x 90 | ); 91 | 92 | // Prepare kernel arguments 93 | let n_u32 = n as u32; 94 | let kernel_args = [ 95 | d_a.as_kernel_arg(), 96 | d_b.as_kernel_arg(), 97 | d_c.as_kernel_arg(), 98 | &n_u32 as *const _ as *mut std::ffi::c_void, 99 | ]; 100 | 101 | // Start timing kernel execution 102 | timer.start(&stream)?; 103 | 104 | // Launch the kernel 105 | function.launch( 106 | grid_dim, 107 | block_dim, 108 | 0, // shared memory bytes 109 | Some(&stream), 110 | &mut kernel_args.clone(), 111 | )?; 112 | 113 | // Stop timing 
kernel execution 114 | timer.stop(&stream)?; 115 | let kernel_time = timer.elapsed_time()?; 116 | 117 | // Start timing device-to-host transfer 118 | timer.start(&stream)?; 119 | 120 | // Copy results back to host 121 | let pending = d_c.copy_to_host_async(c, &stream)?; 122 | 123 | // Synchronize the stream to ensure all operations are complete 124 | let c = stream.synchronize_memory(pending)?; 125 | 126 | // Stop timing device-to-host transfer 127 | timer.stop(&stream)?; 128 | let d2h_time = timer.elapsed_time()?; 129 | 130 | // Print timing information 131 | println!("Host to Device Transfer: {:.3} ms", h2d_time); 132 | println!("Kernel Execution: {:.3} ms", kernel_time); 133 | println!("Device to Host Transfer: {:.3} ms", d2h_time); 134 | println!( 135 | "Total GPU Time: {:.3} ms", 136 | h2d_time + kernel_time + d2h_time 137 | ); 138 | 139 | // Verify results 140 | println!("Verifying results..."); 141 | let cpu_start = Instant::now(); 142 | 143 | let mut all_correct = true; 144 | for i in 0..n { 145 | let expected = a[i] + b[i]; 146 | let actual = c[i]; 147 | if (expected - actual).abs() > 1e-5 { 148 | println!( 149 | "Error at index {}: expected {}, got {}", 150 | i, expected, actual 151 | ); 152 | all_correct = false; 153 | if i > 10 { 154 | println!("Stopping verification after 10 errors..."); 155 | break; 156 | } 157 | } 158 | } 159 | 160 | let cpu_elapsed = cpu_start.elapsed(); 161 | println!( 162 | "CPU verification time: {:.3} ms", 163 | cpu_elapsed.as_secs_f32() * 1000.0 164 | ); 165 | 166 | if all_correct { 167 | println!("All results are correct!"); 168 | } else { 169 | println!("Some errors were found in the results."); 170 | } 171 | 172 | // Print a few results 173 | if n > 5 { 174 | println!("First 5 results:"); 175 | for i in 0..5 { 176 | println!("c[{}] = {} + {} = {}", i, a[i], b[i], c[i]); 177 | } 178 | 179 | println!("Last 5 results:"); 180 | for i in n - 5..n { 181 | println!("c[{}] = {} + {} = {}", i, a[i], b[i], c[i]); 182 | } 183 | } 184 | 
185 | println!("Example completed successfully!"); 186 | Ok(()) 187 | } 188 | -------------------------------------------------------------------------------- /src/miopen/reduce.rs: -------------------------------------------------------------------------------- 1 | // src/miopen/reduce.rs 2 | 3 | use crate::miopen::error::{Error, Result}; 4 | use crate::miopen::ffi; 5 | use crate::miopen::handle::Handle; 6 | use crate::miopen::tensor::TensorDescriptor; 7 | use std::os::raw::c_void; 8 | use std::ptr; 9 | 10 | /// Reduction tensor operation 11 | pub type ReduceTensorOp = ffi::miopenReduceTensorOp_t; 12 | 13 | /// NaN propagation mode 14 | pub type NanPropagation = ffi::miopenNanPropagation_t; 15 | 16 | /// Reduction tensor indices 17 | pub type ReduceTensorIndices = ffi::miopenReduceTensorIndices_t; 18 | 19 | /// Indices type 20 | pub type IndicesType = ffi::miopenIndicesType_t; 21 | 22 | /// Safe wrapper for MIOpen reduce tensor descriptor 23 | pub struct ReduceTensorDescriptor { 24 | desc: ffi::miopenReduceTensorDescriptor_t, 25 | } 26 | 27 | impl ReduceTensorDescriptor { 28 | /// Create a new reduce tensor descriptor 29 | pub fn new() -> Result { 30 | let mut desc = ptr::null_mut(); 31 | let status = unsafe { ffi::miopenCreateReduceTensorDescriptor(&mut desc) }; 32 | 33 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 34 | return Err(Error::new(status)); 35 | } 36 | 37 | Ok(Self { desc }) 38 | } 39 | 40 | /// Set the reduce tensor descriptor 41 | pub fn set( 42 | &mut self, 43 | reduce_op: ReduceTensorOp, 44 | comp_type: ffi::miopenDataType_t, 45 | nan_opt: NanPropagation, 46 | indices: ReduceTensorIndices, 47 | indices_type: IndicesType, 48 | ) -> Result<()> { 49 | let status = unsafe { 50 | ffi::miopenSetReduceTensorDescriptor( 51 | self.desc, 52 | reduce_op, 53 | comp_type, 54 | nan_opt, 55 | indices, 56 | indices_type, 57 | ) 58 | }; 59 | 60 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 61 | return Err(Error::new(status)); 62 | } 63 | 64 | 
Ok(()) 65 | } 66 | 67 | /// Get the reduce tensor descriptor details 68 | pub fn get( 69 | &self, 70 | ) -> Result<( 71 | ReduceTensorOp, 72 | ffi::miopenDataType_t, 73 | NanPropagation, 74 | ReduceTensorIndices, 75 | IndicesType, 76 | )> { 77 | let mut reduce_op = 0; 78 | let mut comp_type = 0; 79 | let mut nan_opt = 0; 80 | let mut indices = 0; 81 | let mut indices_type = 0; 82 | 83 | let status = unsafe { 84 | ffi::miopenGetReduceTensorDescriptor( 85 | self.desc, 86 | &mut reduce_op, 87 | &mut comp_type, 88 | &mut nan_opt, 89 | &mut indices, 90 | &mut indices_type, 91 | ) 92 | }; 93 | 94 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 95 | return Err(Error::new(status)); 96 | } 97 | 98 | Ok((reduce_op, comp_type, nan_opt, indices, indices_type)) 99 | } 100 | 101 | /// Get the raw descriptor 102 | pub fn as_raw(&self) -> ffi::miopenReduceTensorDescriptor_t { 103 | self.desc 104 | } 105 | } 106 | 107 | impl Drop for ReduceTensorDescriptor { 108 | fn drop(&mut self) { 109 | if !self.desc.is_null() { 110 | unsafe { 111 | let _ = ffi::miopenDestroyReduceTensorDescriptor(self.desc); 112 | // We cannot handle errors in drop, so just ignore the result 113 | }; 114 | self.desc = ptr::null_mut(); 115 | } 116 | } 117 | } 118 | 119 | /// Get the size required for reduction indices 120 | pub fn get_reduction_indices_size( 121 | handle: &Handle, 122 | reduce_desc: &ReduceTensorDescriptor, 123 | a_desc: &TensorDescriptor, 124 | c_desc: &TensorDescriptor, 125 | ) -> Result { 126 | let mut size_in_bytes = 0; 127 | 128 | let status = unsafe { 129 | ffi::miopenGetReductionIndicesSize( 130 | handle.as_raw(), 131 | reduce_desc.as_raw(), 132 | a_desc.as_raw(), 133 | c_desc.as_raw(), 134 | &mut size_in_bytes, 135 | ) 136 | }; 137 | 138 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 139 | return Err(Error::new(status)); 140 | } 141 | 142 | Ok(size_in_bytes) 143 | } 144 | 145 | /// Get the workspace size required for reduction 146 | pub fn 
get_reduction_workspace_size( 147 | handle: &Handle, 148 | reduce_desc: &ReduceTensorDescriptor, 149 | a_desc: &TensorDescriptor, 150 | c_desc: &TensorDescriptor, 151 | ) -> Result { 152 | let mut size_in_bytes = 0; 153 | 154 | let status = unsafe { 155 | ffi::miopenGetReductionWorkspaceSize( 156 | handle.as_raw(), 157 | reduce_desc.as_raw(), 158 | a_desc.as_raw(), 159 | c_desc.as_raw(), 160 | &mut size_in_bytes, 161 | ) 162 | }; 163 | 164 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 165 | return Err(Error::new(status)); 166 | } 167 | 168 | Ok(size_in_bytes) 169 | } 170 | 171 | /// Execute a reduction operation 172 | pub unsafe fn reduce_tensor( 173 | handle: &Handle, 174 | reduce_desc: &ReduceTensorDescriptor, 175 | indices: *mut c_void, 176 | indices_size: usize, 177 | workspace: *mut c_void, 178 | workspace_size: usize, 179 | alpha: &[u8], 180 | a_desc: &TensorDescriptor, 181 | a: *const c_void, 182 | beta: &[u8], 183 | c_desc: &TensorDescriptor, 184 | c: *mut c_void, 185 | ) -> Result<()> { 186 | let status = unsafe { 187 | ffi::miopenReduceTensor( 188 | handle.as_raw(), 189 | reduce_desc.as_raw(), 190 | indices, 191 | indices_size, 192 | workspace, 193 | workspace_size, 194 | alpha.as_ptr() as *const c_void, 195 | a_desc.as_raw(), 196 | a, 197 | beta.as_ptr() as *const c_void, 198 | c_desc.as_raw(), 199 | c, 200 | ) 201 | }; 202 | 203 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 204 | return Err(Error::new(status)); 205 | } 206 | 207 | Ok(()) 208 | } 209 | -------------------------------------------------------------------------------- /src/hip/memory_ext/sorting.rs: -------------------------------------------------------------------------------- 1 | use crate::hip::kernel::AsKernelArg; 2 | use rocm_kernel_macros::{ 3 | amdgpu_device, amdgpu_global, amdgpu_kernel_finalize, amdgpu_kernel_init, 4 | }; 5 | 6 | amdgpu_kernel_init!(path: __build_in_kernels_sorting); 7 | 8 | #[amdgpu_device(__build_in_kernels_sorting)] 9 | use 
core::{cmp::PartialOrd, ptr::swap}; 10 | 11 | use crate::{ 12 | hip::{DeviceMemory, Dim3, Module, Stream, error::Result}, 13 | kernel_args, 14 | }; 15 | 16 | #[amdgpu_device(__build_in_kernels_sorting)] 17 | fn sort_odd_inner(arr: *mut T, ascending: bool) { 18 | let id_x = workgroup_id_x() as usize; 19 | 20 | let fst_index = id_x * 2 + 1; 21 | let sec_index = fst_index + 1; 22 | 23 | let fst = unsafe { *arr.add(fst_index) }; 24 | let sec = unsafe { *arr.add(sec_index) }; 25 | 26 | if (ascending && fst > sec) || (!ascending && fst < sec) { 27 | unsafe { 28 | swap(arr.add(fst_index), arr.add(sec_index)); 29 | } 30 | } 31 | } 32 | 33 | #[amdgpu_device(__build_in_kernels_sorting)] 34 | fn sort_even_inner(arr: *mut T, ascending: bool) { 35 | let id_x = workgroup_id_x() as usize; 36 | 37 | let fst_index = id_x * 2; 38 | let sec_index = fst_index + 1; 39 | 40 | let fst = unsafe { *arr.add(fst_index) }; 41 | let sec = unsafe { *arr.add(sec_index) }; 42 | 43 | if (ascending && fst > sec) || (!ascending && fst < sec) { 44 | unsafe { 45 | swap(arr.add(fst_index), arr.add(sec_index)); 46 | } 47 | } 48 | } 49 | 50 | #[amdgpu_device(__build_in_kernels_sorting)] 51 | fn check_sorted_inner(arr: *mut T, target: *mut bool, size: usize) { 52 | let id_x = workgroup_id_x() as usize; 53 | 54 | if (id_x >= size) { 55 | return; 56 | } 57 | 58 | let fst = unsafe { *arr.add(id_x) }; 59 | let sec = unsafe { *arr.add(id_x + 1) }; 60 | 61 | if (fst <= sec) { 62 | unsafe { *target.add(id_x) = true } 63 | } else { 64 | unsafe { *target.add(id_x) = false } 65 | } 66 | } 67 | 68 | macro_rules! sort_fns { 69 | ($t:ty) => { 70 | paste::paste! 
{ 71 | #[amdgpu_global(__build_in_kernels_sorting)] 72 | fn [](arr: *mut $t, ascending: bool) { 73 | sort_odd_inner::<$t>(arr, ascending) 74 | } 75 | 76 | #[amdgpu_global(__build_in_kernels_sorting)] 77 | fn [](arr: *mut $t, ascending: bool) { 78 | sort_even_inner::<$t>(arr, ascending) 79 | } 80 | 81 | #[amdgpu_global(__build_in_kernels_sorting)] 82 | fn [](arr: *mut $t, target: *mut bool, size: usize) { 83 | check_sorted_inner::<$t>(arr, target, size) 84 | } 85 | } 86 | }; 87 | } 88 | 89 | pub trait GPUSortAllowed {} 90 | 91 | macro_rules! impl_gpu_sort_allowed { 92 | ($($t:ty),+) => { 93 | $( 94 | impl GPUSortAllowed for $t {} 95 | sort_fns!($t); 96 | )* 97 | }; 98 | } 99 | 100 | impl_gpu_sort_allowed!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64); 101 | 102 | pub(crate) const SORTING_KERNEL: &[u8] = 103 | include_bytes!(amdgpu_kernel_finalize!(__build_in_kernels_sorting)); 104 | 105 | pub(crate) fn sort(mem: &mut DeviceMemory, stream: &Stream, ascending: bool) -> Result<()> { 106 | let module = Module::load_data(SORTING_KERNEL)?; 107 | 108 | let sort_odd = 109 | module.get_function(&(String::from("sort_odd_") + std::any::type_name::()))?; 110 | let sort_even = 111 | module.get_function(&(String::from("sort_even_") + std::any::type_name::()))?; 112 | 113 | let count = mem.count() as u32; 114 | 115 | let args = kernel_args!(mem, ascending); 116 | 117 | let grid_dim_even = Dim3::new_1d(count / 2); 118 | let grid_dim_odd = Dim3::new_1d((count - 1) / 2); 119 | 120 | for _ in 0..count / 2 { 121 | sort_even.launch(grid_dim_even, Dim3::new_1d(1), 0, Some(stream), args)?; 122 | sort_odd.launch(grid_dim_odd, Dim3::new_1d(1), 0, Some(stream), args)?; 123 | } 124 | 125 | Ok(()) 126 | } 127 | 128 | /// Tis function synchronizes stream 129 | /// 130 | /// This function will return an error if memory size is zero. 
131 | pub(crate) fn check_sorted(mem: &DeviceMemory, stream: Option<&Stream>) -> Result { 132 | let module = Module::load_data(SORTING_KERNEL)?; 133 | 134 | let check_sorted = 135 | module.get_function(&(String::from("check_sorted_") + std::any::type_name::()))?; 136 | 137 | let count = mem.count(); 138 | 139 | let target = DeviceMemory::::new(count - 1)?; 140 | 141 | let args = kernel_args!(mem, target, count); 142 | 143 | check_sorted.launch( 144 | Dim3::new_1d(count as u32 - 1), 145 | Dim3::new_1d(1), 146 | 0, 147 | stream, 148 | args, 149 | )?; 150 | let mut host = vec![false; count - 1]; 151 | if let Some(stream) = stream { 152 | let pending = target.copy_to_host_async(host, stream)?; 153 | host = stream.synchronize_memory(pending)?; 154 | } else { 155 | target.copy_to_host(&mut host)?; 156 | } 157 | Ok(host.iter().all(|x| *x)) 158 | } 159 | 160 | #[cfg(test)] 161 | mod test { 162 | use crate::{ 163 | error::Result, 164 | hip::{ 165 | Device, DeviceMemory, 166 | memory_ext::sorting::check_sorted, 167 | }, 168 | }; 169 | 170 | #[test] 171 | fn is_sorted() -> Result<()> { 172 | let device = Device::current()?; 173 | 174 | let stream = device.get_stream()?; 175 | 176 | let arr: Vec = vec![1, 2, 3, 4, 5, 6, 7, 8]; 177 | 178 | let mem = DeviceMemory::new(arr.len())?; 179 | mem.copy_from_host_async(arr, &stream)?; 180 | 181 | assert!(check_sorted(&mem, Some(&stream))?); 182 | 183 | Ok(()) 184 | } 185 | 186 | #[test] 187 | fn is_not_sorted() -> Result<()> { 188 | let device = Device::current()?; 189 | 190 | let stream = device.get_stream()?; 191 | 192 | let arr: Vec = vec![1, 3, 2, 4, 5, 6, 8, 7]; 193 | 194 | let mem = DeviceMemory::new(arr.len())?; 195 | mem.copy_from_host_async(arr, &stream)?; 196 | 197 | assert!(!check_sorted(&mem, Some(&stream))?); 198 | 199 | Ok(()) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /src/miopen/mha.rs: 
-------------------------------------------------------------------------------- 1 | // src/miopen/mha.rs 2 | 3 | use crate::miopen::error::{Error, Result}; 4 | use crate::miopen::ffi; 5 | use std::ptr; 6 | 7 | /// MHA mask mode 8 | pub type MhaMask = ffi::miopenMhaMask_t; 9 | 10 | /// Constants for MHA mask modes 11 | pub mod mha_mask { 12 | use crate::miopen::ffi; 13 | 14 | /// No mask for MHA 15 | pub const NONE: super::MhaMask = ffi::miopenMhaMask_t_miopenMhaMaskNone; 16 | 17 | /// Causal mask for MHA 18 | pub const CAUSAL: super::MhaMask = ffi::miopenMhaMask_t_miopenMhaMaskCausal; 19 | } 20 | 21 | /// Safe wrapper for MIOpen MHA descriptor 22 | pub struct MhaDescriptor { 23 | desc: ffi::miopenMhaDescriptor_t, 24 | } 25 | 26 | impl MhaDescriptor { 27 | /// Create a new MHA descriptor 28 | pub fn new() -> Result { 29 | let mut desc = ptr::null_mut(); 30 | let status = unsafe { ffi::miopenCreateMhaDescriptor(&mut desc) }; 31 | 32 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 33 | return Err(Error::new(status)); 34 | } 35 | 36 | Ok(Self { desc }) 37 | } 38 | 39 | /// Set the MHA descriptor parameters 40 | pub fn set(&mut self, scale: f32) -> Result<()> { 41 | let status = unsafe { ffi::miopenSetMhaDescriptor(self.desc, scale) }; 42 | 43 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 44 | return Err(Error::new(status)); 45 | } 46 | 47 | Ok(()) 48 | } 49 | 50 | /// Get the MHA descriptor parameters 51 | pub fn get(&self) -> Result { 52 | let mut scale = 0.0f32; 53 | let status = unsafe { ffi::miopenGetMhaDescriptor(self.desc, &mut scale) }; 54 | 55 | if status != ffi::miopenStatus_t_miopenStatusSuccess { 56 | return Err(Error::new(status)); 57 | } 58 | 59 | Ok(scale) 60 | } 61 | 62 | /// Get the raw descriptor 63 | pub fn as_raw(&self) -> ffi::miopenMhaDescriptor_t { 64 | self.desc 65 | } 66 | } 67 | 68 | impl Drop for MhaDescriptor { 69 | fn drop(&mut self) { 70 | if !self.desc.is_null() { 71 | // No explicit destroy function in the API, 
assuming it's managed by the MIOpen context 72 | } 73 | } 74 | } 75 | 76 | /// Identifiers for tensor arguments of MHA problems 77 | pub type TensorArgumentId = ffi::miopenTensorArgumentId_t; 78 | 79 | /// Constants for tensor argument IDs 80 | pub mod tensor_argument_id { 81 | use crate::miopen::ffi; 82 | 83 | // MHA tensor arguments 84 | pub const MHA_K: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaK; 85 | pub const MHA_Q: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaQ; 86 | pub const MHA_V: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaV; 87 | pub const MHA_O: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaO; 88 | pub const MHA_MASK: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaMask; 89 | pub const MHA_BIAS: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaBias; 90 | 91 | // Scale/descale tensors 92 | pub const MHA_DESCALE_K: super::TensorArgumentId = 93 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleK; 94 | pub const MHA_DESCALE_Q: super::TensorArgumentId = 95 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleQ; 96 | pub const MHA_DESCALE_V: super::TensorArgumentId = 97 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleV; 98 | pub const MHA_DESCALE_S: super::TensorArgumentId = 99 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleS; 100 | pub const MHA_SCALE_S: super::TensorArgumentId = 101 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleS; 102 | pub const MHA_SCALE_O: super::TensorArgumentId = 103 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleO; 104 | 105 | // Dropout related tensors 106 | pub const MHA_DROPOUT_PROBABILITY: super::TensorArgumentId = 107 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDropoutProbability; 108 | pub const MHA_DROPOUT_SEED: super::TensorArgumentId = 109 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDropoutSeed; 110 | pub const MHA_DROPOUT_OFFSET: 
super::TensorArgumentId = 111 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDropoutOffset; 112 | 113 | // Other MHA tensors 114 | pub const MHA_AMAX_O: super::TensorArgumentId = 115 | ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxO; 116 | pub const MHA_AMAX_S: super::TensorArgumentId = 117 | ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxS; 118 | pub const MHA_M: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaM; 119 | pub const MHA_Z_INV: super::TensorArgumentId = 120 | ffi::miopenTensorArgumentId_t_miopenTensorMhaZInv; 121 | 122 | // Backward tensors 123 | pub const MHA_DO: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaDO; 124 | pub const MHA_DESCALE_O: super::TensorArgumentId = 125 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleO; 126 | pub const MHA_DESCALE_DO: super::TensorArgumentId = 127 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleDO; 128 | pub const MHA_DESCALE_DS: super::TensorArgumentId = 129 | ffi::miopenTensorArgumentId_t_miopenTensorMhaDescaleDS; 130 | pub const MHA_SCALE_DS: super::TensorArgumentId = 131 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleDS; 132 | pub const MHA_SCALE_DQ: super::TensorArgumentId = 133 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleDQ; 134 | pub const MHA_SCALE_DK: super::TensorArgumentId = 135 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleDK; 136 | pub const MHA_SCALE_DV: super::TensorArgumentId = 137 | ffi::miopenTensorArgumentId_t_miopenTensorMhaScaleDV; 138 | pub const MHA_DQ: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaDQ; 139 | pub const MHA_DK: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaDK; 140 | pub const MHA_DV: super::TensorArgumentId = ffi::miopenTensorArgumentId_t_miopenTensorMhaDV; 141 | pub const MHA_AMAX_DQ: super::TensorArgumentId = 142 | ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxDQ; 143 | pub const MHA_AMAX_DK: super::TensorArgumentId = 144 | 
ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxDK; 145 | pub const MHA_AMAX_DV: super::TensorArgumentId = 146 | ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxDV; 147 | pub const MHA_AMAX_DS: super::TensorArgumentId = 148 | ffi::miopenTensorArgumentId_t_miopenTensorMhaAmaxDS; 149 | } 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rocm-rs: Safe Rust wrappers for AMD ROCm Libraries 2 | 3 | This project provides Rust bindings for AMD's ROCm (Radeon Open Compute) libraries, allowing Rust developers to leverage AMD GPUs for high-performance computing. 4 | 5 | ## Current Status 6 | 7 | **Note: This project is in early development.** 8 | 9 | Currently implemented: 10 | - ✅ rocFFT - Fast Fourier Transform library (raw bindings + safe wrappers) 11 | - ✅ HIP - Heterogeneous-Compute Interface for Portability (raw bindings + safe wrappers) 12 | - ✅ rocBLAS - Basic Linear Algebra Subprograms (raw bindings + safe wrappers) 13 | - ✅ MIOpen - Deep learning primitives (raw bindings + safe wrappers) 14 | - ✅ rocRAND - Random number generation (raw bindings + safe wrappers) 15 | - ✅ rocSOLVER - Linear system solvers (raw bindings only) 16 | - ✅ rocSPARSE - Sparse linear algebra (raw bindings only) 17 | - ✅ ROCArray - GPU array struct with an API similar to Vec (to be deprecated in favor of DeviceMemoryExt) 18 | - ✅ rocmsmi - system management interface (refer to [rocm_smi_lib](https://github.com/PTFOPlayer/rocm_smi_lib_rs)) 19 | - ✅ rocm_kernel_macros - macros for writing GPU kernels in Rust (refer to [rocm_kernel_macros](https://github.com/RustNSparks/rocm_kernel_macros)) 20 | 21 | The project provides raw FFI bindings for every library listed above, with safe Rust wrappers currently available for rocFFT, HIP, rocBLAS, MIOpen and rocRAND. Safe wrappers for the remaining libraries are planned for future development. 
22 | 23 | ## Prerequisites 24 | 25 | - AMD ROCm installed (version 6.3 or later recommended. It may work on older versions, but this has not been tested) 26 | - Ubuntu 24.04 / Fedora 42 27 | - Rust toolchain (1.65.0 or later recommended) 28 | - A compatible AMD GPU 29 | 30 | ## Installation 31 | 32 | Add this to your `Cargo.toml`: 33 | 34 | ```toml 35 | [dependencies] 36 | rocm-rs = "4.2" 37 | ``` 38 | 39 | ## Usage 40 | 41 | First, ensure that the ROCm libraries are in your library path or set the `ROCM_PATH` environment variable. 42 | 43 | ### Writing your own kernels with Rust 44 | 45 | ```rust 46 | use std::path::PathBuf; 47 | 48 | use rocm_kernel_macros::{amdgpu_kernel_attr, amdgpu_kernel_finalize, amdgpu_kernel_init}; 49 | use rocm_rs::hip::*; 50 | 51 | const LEN: usize = 1024; 52 | 53 | // initializing rust gpu kernel 54 | amdgpu_kernel_init!(); 55 | 56 | // marking code that will be copied to gpu kernel 57 | #[amdgpu_kernel_attr] 58 | fn kernel(input: *const u32, output: *mut u32) { 59 | // retrieving data from the buffer by work-item 60 | let num = read_by_workitem_id_x(input); 61 | 62 | // writing data back 63 | write_by_workitem_id_x(output, num * 3); 64 | } 65 | 66 | // compiling gpu kernel 67 | const AMDGPU_KERNEL_BINARY_PATH: &str = amdgpu_kernel_finalize!(); 68 | 69 | fn main() -> Result<()> { 70 | // setting up device 71 | let device = Device::new(0)?; 72 | device.set_current()?; 73 | 74 | // loading gpu kernel (runs at runtime!) 
75 | let kernel_path = PathBuf::from(AMDGPU_KERNEL_BINARY_PATH); 76 | assert!(kernel_path.exists()); 77 | 78 | let module = Module::load(kernel_path)?; 79 | 80 | // acquiring function handle from gpu kernel 81 | let function = module.get_function("kernel")?; 82 | 83 | // preparing host side buffers 84 | let mut in_host: Vec<u32> = vec![0; LEN]; 85 | let mut out_host: Vec<u32> = vec![0; LEN]; 86 | 87 | for i in 0..LEN { 88 | in_host[i] = i as u32; 89 | } 90 | 91 | // preparing gpu side buffers 92 | let mut input = DeviceMemory::<u32>::new(LEN)?; 93 | let output = DeviceMemory::<u32>::new(LEN)?; 94 | 95 | input.copy_from_host(&in_host)?; 96 | 97 | // providing arguments for kernel 98 | let kernel_args = [input.as_kernel_arg(), output.as_kernel_arg()]; 99 | 100 | // setting up launch args 101 | let grid_dim = Dim3 { x: 2, y: 1, z: 1 }; 102 | let block_dim = Dim3 { 103 | x: (LEN / 2) as u32, 104 | y: 1, 105 | z: 1, 106 | }; 107 | 108 | function.launch(grid_dim, block_dim, 0, None, &mut kernel_args.clone())?; 109 | 110 | // retrieving computed data 111 | output.copy_to_host(&mut out_host)?; 112 | 113 | println!("Output: {:?}", &out_host[..256]); 114 | 115 | Ok(()) 116 | } 117 | 118 | 119 | ``` 120 | 121 | ### Using rocFFT with safe wrappers: 122 | 123 | ```rust 124 | use rocm_rs::rocfft::{self, plan, execution, field}; 125 | 126 | fn main() { 127 | // Initialize the rocFFT library 128 | // Use the safe wrappers for rocFFT 129 | let plan = plan::Plan::new(/* parameters */); 130 | let field = field::Field::new(/* parameters */); 131 | let execution = execution::Execution::new(/* parameters */); 132 | 133 | // Perform FFT operations 134 | // ... 
135 | } 136 | ``` 137 | 138 | ### Using other libraries with raw bindings: 139 | 140 | ```rust 141 | use rocm_rs::hip::*; 142 | 143 | fn main() { 144 | unsafe { 145 | // Example of using HIP raw bindings 146 | let mut device_count = 0; 147 | hipGetDeviceCount(&mut device_count); 148 | println!("Found {} HIP devices", device_count); 149 | 150 | // Use other raw bindings as needed 151 | // ... 152 | } 153 | } 154 | ``` 155 | 156 | ## Building from Source 157 | 158 | **Important**: When building from source, you need to run `cargo build` first to generate the binding files before you can use the library or run tests. 159 | 160 | ```bash 161 | # Clone the repository 162 | git clone https://github.com/RustNSparks/rocm-rs 163 | cd rocm-rs 164 | 165 | # Set the ROCm path if not in the default location 166 | export ROCM_PATH=/opt/rocm 167 | 168 | # Build the project (generates bindings) 169 | cargo build 170 | ``` 171 | 172 | ## Feature flags 173 | 174 | - rocm_smi - enables bindings and wrappers for rocm_smi_lib 175 | 176 | ## Examples 177 | - hip 178 | - vector_add - example containing kernel written in C++ launched with rocm-rs 179 | - rust_kernel - example containing kernel written in Rust using macros 180 | - rust_kernel_async - example containing kernel written in Rust, using stream to manage memory asynchronously 181 | - saxpy - X = aX+Y 182 | - rand 183 | - normal - generating random numbers with normal distribution 184 | 185 | ## Contributing 186 | 187 | Contributions are welcome! Please feel free to submit a Pull Request. 188 | 189 | When contributing: 190 | 1. Run `cargo build` first to generate the bindings 191 | 2. Add tests for new functionality 192 | 3. Update documentation as needed 193 | 194 | ## License 195 | 196 | This project is licensed under the MIT License - see the LICENSE file for details. 
197 | 198 | ## Acknowledgments 199 | 200 | - AMD for developing and maintaining ROCm 201 | - The Rust community for bindgen and other tools used in this project 202 | -------------------------------------------------------------------------------- /src/hip/utils.rs: -------------------------------------------------------------------------------- 1 | // src/hip/utils.rs 2 | 3 | use crate::hip::error::Result; 4 | use crate::hip::{self, Device, ffi}; 5 | 6 | /// Get a description of all devices in the system 7 | pub fn print_devices_info() -> Result { 8 | let count = hip::get_device_count()?; 9 | let mut output = String::new(); 10 | 11 | output.push_str(&format!("Found {} HIP device(s)\n", count)); 12 | 13 | for i in 0..count { 14 | let device = Device::new(i)?; 15 | let props = device.properties()?; 16 | 17 | output.push_str(&format!("\nDevice {}: {}\n", i, props.name)); 18 | output.push_str(&format!( 19 | " Compute capability: {}.{}\n", 20 | props.major, props.minor 21 | )); 22 | output.push_str(&format!( 23 | " Total memory: {} MB\n", 24 | props.total_global_mem / (1024 * 1024) 25 | )); 26 | output.push_str(&format!(" Clock rate: {} MHz\n", props.clock_rate / 1000)); 27 | output.push_str(&format!( 28 | " Multi-processor count: {}\n", 29 | props.multi_processor_count 30 | )); 31 | output.push_str(&format!( 32 | " Max threads per block: {}\n", 33 | props.max_threads_per_block 34 | )); 35 | output.push_str(&format!( 36 | " Max threads per multiprocessor: {}\n", 37 | props.max_threads_per_multiprocessor 38 | )); 39 | output.push_str(&format!(" Warp size: {}\n", props.warp_size)); 40 | output.push_str(&format!( 41 | " Max dimensions of a grid: [{}, {}, {}]\n", 42 | props.max_grid_size[0], props.max_grid_size[1], props.max_grid_size[2] 43 | )); 44 | output.push_str(&format!( 45 | " Max dimensions of a block: [{}, {}, {}]\n", 46 | props.max_threads_dim[0], props.max_threads_dim[1], props.max_threads_dim[2] 47 | )); 48 | output.push_str(&format!( 49 | " Shared memory per 
block: {} KB\n", 50 | props.shared_mem_per_block / 1024 51 | )); 52 | output.push_str(&format!( 53 | " Registers per block: {}\n", 54 | props.regs_per_block 55 | )); 56 | output.push_str(&format!( 57 | " L2 cache size: {} KB\n", 58 | props.l2_cache_size / 1024 59 | )); 60 | output.push_str(&format!( 61 | " Memory clock rate: {} MHz\n", 62 | props.memory_clock_rate / 1000 63 | )); 64 | output.push_str(&format!( 65 | " Memory bus width: {} bits\n", 66 | props.memory_bus_width 67 | )); 68 | output.push_str(&format!(" Integrated: {}\n", props.integrated)); 69 | output.push_str(&format!( 70 | " Can map host memory: {}\n", 71 | props.can_map_host_memory 72 | )); 73 | } 74 | 75 | Ok(output) 76 | } 77 | 78 | /// Wrapper for HIP version information 79 | pub struct Version { 80 | pub major: i32, 81 | pub minor: i32, 82 | pub patch: i32, 83 | } 84 | 85 | impl Version { 86 | /// Get the HIP driver version 87 | pub fn driver() -> Result { 88 | let version = hip::driver_version()?; 89 | 90 | // HIP versions are encoded as (10000*major + 100*minor + patch) 91 | let major = version / 10000; 92 | let minor = (version % 10000) / 100; 93 | let patch = version % 100; 94 | 95 | Ok(Self { 96 | major, 97 | minor, 98 | patch, 99 | }) 100 | } 101 | 102 | /// Get the HIP runtime version 103 | pub fn runtime() -> Result { 104 | let version = hip::runtime_version()?; 105 | 106 | // HIP versions are encoded as (10000*major + 100*minor + patch) 107 | let major = version / 10000; 108 | let minor = (version % 10000) / 100; 109 | let patch = version % 100; 110 | 111 | Ok(Self { 112 | major, 113 | minor, 114 | patch, 115 | }) 116 | } 117 | } 118 | 119 | impl std::fmt::Display for Version { 120 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 121 | write!(f, "{}.{}.{}", self.major, self.minor, self.patch) 122 | } 123 | } 124 | 125 | /// Convenient struct for 3D dimensions 126 | #[derive(Debug, Clone, Copy)] 127 | pub struct Dim3 { 128 | pub x: u32, 129 | pub y: u32, 130 | pub 
z: u32, 131 | } 132 | 133 | impl Dim3 { 134 | /// Create a new 1D dimension 135 | pub fn new_1d(x: u32) -> Self { 136 | Self { x, y: 1, z: 1 } 137 | } 138 | 139 | /// Create a new 2D dimension 140 | pub fn new_2d(x: u32, y: u32) -> Self { 141 | Self { x, y, z: 1 } 142 | } 143 | 144 | /// Create a new 3D dimension 145 | pub fn new_3d(x: u32, y: u32, z: u32) -> Self { 146 | Self { x, y, z } 147 | } 148 | 149 | /// Convert to the native HIP dim3 structure 150 | pub fn to_native(&self) -> ffi::dim3 { 151 | ffi::dim3 { 152 | x: self.x, 153 | y: self.y, 154 | z: self.z, 155 | } 156 | } 157 | } 158 | 159 | impl From for Dim3 { 160 | fn from(x: u32) -> Self { 161 | Self::new_1d(x) 162 | } 163 | } 164 | 165 | impl From<(u32, u32)> for Dim3 { 166 | fn from((x, y): (u32, u32)) -> Self { 167 | Self::new_2d(x, y) 168 | } 169 | } 170 | 171 | impl From<(u32, u32, u32)> for Dim3 { 172 | fn from((x, y, z): (u32, u32, u32)) -> Self { 173 | Self::new_3d(x, y, z) 174 | } 175 | } 176 | 177 | /// Calculate optimal grid dimensions for a 1D problem 178 | pub fn calculate_grid_1d(total_elements: u32, block_size: u32) -> Dim3 { 179 | let grid_size = (total_elements + block_size - 1) / block_size; 180 | Dim3::new_1d(grid_size) 181 | } 182 | 183 | /// Calculate optimal grid dimensions for a 2D problem 184 | pub fn calculate_grid_2d(width: u32, height: u32, block_x: u32, block_y: u32) -> Dim3 { 185 | let grid_x = (width + block_x - 1) / block_x; 186 | let grid_y = (height + block_y - 1) / block_y; 187 | Dim3::new_2d(grid_x, grid_y) 188 | } 189 | 190 | /// Calculate optimal grid dimensions for a 3D problem 191 | pub fn calculate_grid_3d( 192 | width: u32, 193 | height: u32, 194 | depth: u32, 195 | block_x: u32, 196 | block_y: u32, 197 | block_z: u32, 198 | ) -> Dim3 { 199 | let grid_x = (width + block_x - 1) / block_x; 200 | let grid_y = (height + block_y - 1) / block_y; 201 | let grid_z = (depth + block_z - 1) / block_z; 202 | Dim3::new_3d(grid_x, grid_y, grid_z) 203 | } 204 | 205 | /// Helper 
function to determine if HIP is available 206 | pub fn is_hip_available() -> bool { 207 | match hip::device_count() { 208 | Ok(count) => count > 0, 209 | Err(_) => false, 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /src/hip/stream.rs: -------------------------------------------------------------------------------- 1 | // src/hip/stream.rs 2 | 3 | use crate::hip; 4 | use crate::hip::error::{Error, Result}; 5 | use crate::hip::event::Event; 6 | use crate::hip::ffi; 7 | use std::{panic, ptr}; 8 | 9 | use super::memory::SynchronizeCopies; 10 | 11 | /// Safe wrapper for HIP streams 12 | #[derive(Clone, Debug)] 13 | pub struct Stream { 14 | pub(crate) stream: hip::ffi::hipStream_t, 15 | } 16 | 17 | impl Stream { 18 | /// Create a new stream 19 | pub(crate) fn new() -> Result { 20 | let mut stream = ptr::null_mut(); 21 | let error = unsafe { ffi::hipStreamCreate(&mut stream) }; 22 | 23 | if error != ffi::hipError_t_hipSuccess { 24 | return Err(Error::new(error)); 25 | } 26 | 27 | Ok(Self { stream }) 28 | } 29 | 30 | /// Create a new stream with specific flags 31 | pub(crate) fn with_flags(flags: u32) -> Result { 32 | let mut stream = ptr::null_mut(); 33 | let error = unsafe { ffi::hipStreamCreateWithFlags(&mut stream, flags) }; 34 | 35 | if error != ffi::hipError_t_hipSuccess { 36 | return Err(Error::new(error)); 37 | } 38 | 39 | Ok(Self { stream }) 40 | } 41 | 42 | /// Create a new stream with priority 43 | pub(crate) fn with_priority(flags: u32, priority: i32) -> Result { 44 | let mut stream = ptr::null_mut(); 45 | let error = unsafe { ffi::hipStreamCreateWithPriority(&mut stream, flags, priority) }; 46 | 47 | if error != ffi::hipError_t_hipSuccess { 48 | return Err(Error::new(error)); 49 | } 50 | 51 | Ok(Self { stream }) 52 | } 53 | 54 | /// Wait for a stream to complete 55 | pub fn synchronize(&self) -> Result<()> { 56 | let error = unsafe { ffi::hipStreamSynchronize(self.stream) }; 57 | 58 | if error != 
ffi::hipError_t_hipSuccess { 59 | return Err(Error::new(error)); 60 | } 61 | 62 | Ok(()) 63 | } 64 | 65 | pub fn synchronize_memory(&self, copies: T) -> Result { 66 | Self::synchronize(&self)?; 67 | Ok(unsafe { copies.finalize() }) 68 | } 69 | 70 | /// Query if all operations in the stream have completed 71 | pub fn query(&self) -> Result<()> { 72 | let error = unsafe { ffi::hipStreamQuery(self.stream) }; 73 | 74 | if error == ffi::hipError_t_hipSuccess { 75 | Ok(()) 76 | } else if error == ffi::hipError_t_hipErrorNotReady { 77 | // Not ready isn't a true error in this context 78 | Err(Error::new(error)) 79 | } else { 80 | Err(Error::new(error)) 81 | } 82 | } 83 | 84 | /// Wait on an event 85 | pub fn wait_event(&self, event: &Event, flags: u32) -> Result<()> { 86 | let error = unsafe { ffi::hipStreamWaitEvent(self.stream, event.as_raw(), flags) }; 87 | 88 | if error != ffi::hipError_t_hipSuccess { 89 | return Err(Error::new(error)); 90 | } 91 | 92 | Ok(()) 93 | } 94 | 95 | /// Add a callback to be executed when the stream completes 96 | pub fn add_callback(&self, callback: F) -> Result<()> 97 | where 98 | F: FnOnce() + Send + 'static, 99 | { 100 | type Callback = dyn FnOnce() + Send + 'static; 101 | 102 | let boxed: Box>> = Box::new(Some(Box::new(callback))); 103 | 104 | let ptr = Box::into_raw(boxed) as *mut std::ffi::c_void; 105 | 106 | // The C callback function that will be called by HIP 107 | unsafe extern "C" fn helper_callback( 108 | _stream: ffi::hipStream_t, 109 | _status: ffi::hipError_t, 110 | user_data: *mut std::ffi::c_void, 111 | ) { 112 | let callback_box = unsafe { Box::from_raw(user_data as *mut Option>) }; 113 | 114 | if let Some(callback) = *callback_box { 115 | let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| callback())); 116 | } 117 | } 118 | 119 | let error = 120 | unsafe { ffi::hipStreamAddCallback(self.stream, Some(helper_callback), ptr, 0) }; 121 | 122 | if error != ffi::hipError_t_hipSuccess { 123 | unsafe { 
drop(Box::from_raw(ptr)) } 124 | return Err(Error::new(error)); 125 | } 126 | 127 | Ok(()) 128 | } 129 | 130 | /// Get the raw stream handle 131 | pub fn as_raw(&self) -> ffi::hipStream_t { 132 | self.stream 133 | } 134 | 135 | /// Get the stream priority range 136 | pub fn priority_range() -> Result<(i32, i32)> { 137 | let mut least_priority = 0; 138 | let mut greatest_priority = 0; 139 | 140 | let error = unsafe { 141 | ffi::hipDeviceGetStreamPriorityRange(&mut least_priority, &mut greatest_priority) 142 | }; 143 | 144 | if error != ffi::hipError_t_hipSuccess { 145 | return Err(Error::new(error)); 146 | } 147 | 148 | Ok((least_priority, greatest_priority)) 149 | } 150 | 151 | /// Get the priority of this stream 152 | pub fn get_priority(&self) -> Result { 153 | let mut priority = 0; 154 | 155 | let error = unsafe { ffi::hipStreamGetPriority(self.stream, &mut priority) }; 156 | 157 | if error != ffi::hipError_t_hipSuccess { 158 | return Err(Error::new(error)); 159 | } 160 | 161 | Ok(priority) 162 | } 163 | 164 | /// Get the flags of this stream 165 | pub fn get_flags(&self) -> Result { 166 | let mut flags = 0; 167 | 168 | let error = unsafe { ffi::hipStreamGetFlags(self.stream, &mut flags) }; 169 | 170 | if error != ffi::hipError_t_hipSuccess { 171 | return Err(Error::new(error)); 172 | } 173 | 174 | Ok(flags) 175 | } 176 | 177 | /// Get the device associated with this stream 178 | pub fn get_device(&self) -> Result { 179 | let mut device = 0; 180 | 181 | let error = unsafe { ffi::hipStreamGetDevice(self.stream, &mut device) }; 182 | 183 | if error != ffi::hipError_t_hipSuccess { 184 | return Err(Error::new(error)); 185 | } 186 | 187 | Ok(device) 188 | } 189 | pub fn from_raw(stream: ffi::hipStream_t) -> Self { 190 | Self { stream } 191 | } 192 | } 193 | 194 | impl Drop for Stream { 195 | fn drop(&mut self) { 196 | if !self.stream.is_null() { 197 | unsafe { 198 | let _ = ffi::hipStreamDestroy(self.stream); 199 | // We cannot handle errors in drop, so just ignore 
the result 200 | }; 201 | self.stream = ptr::null_mut(); 202 | } 203 | } 204 | } 205 | 206 | /// Constants for stream creation flags 207 | pub mod stream_flags { 208 | /// Default stream creation flag (synchronizing) 209 | pub const DEFAULT: u32 = 0; 210 | 211 | /// Non-blocking stream that doesn't synchronize with the NULL stream 212 | pub const NON_BLOCKING: u32 = 1; 213 | } 214 | -------------------------------------------------------------------------------- /src/rocblas/error.rs: -------------------------------------------------------------------------------- 1 | // src/rocblas/error.rs 2 | 3 | use crate::rocblas::ffi; 4 | use std::error::Error as StdError; 5 | use std::fmt; 6 | 7 | /// Error type for RocBLAS operations 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | pub struct Error { 10 | code: ffi::rocblas_status, 11 | } 12 | 13 | /// Result type for RocBLAS operations 14 | pub type Result = std::result::Result; 15 | 16 | impl Error { 17 | /// Create a new error from a RocBLAS error code 18 | pub fn new(code: ffi::rocblas_status) -> Self { 19 | Self { code } 20 | } 21 | 22 | /// Convert a RocBLAS error code to a Result 23 | pub fn from_rocblas_error(error: ffi::rocblas_status) -> Result 24 | where 25 | T: Default, 26 | { 27 | if error == ffi::rocblas_status__rocblas_status_success { 28 | Ok(T::default()) 29 | } else { 30 | Err(Error::new(error)) 31 | } 32 | } 33 | 34 | /// Convert a RocBLAS error code to a Result with a specific value 35 | pub fn from_rocblas_error_with_value(error: ffi::rocblas_status, value: T) -> Result { 36 | if error == ffi::rocblas_status__rocblas_status_success { 37 | Ok(value) 38 | } else { 39 | Err(Error::new(error)) 40 | } 41 | } 42 | 43 | /// Returns true if the error code represents success 44 | pub fn is_success(&self) -> bool { 45 | self.code == ffi::rocblas_status__rocblas_status_success 46 | } 47 | 48 | /// Get the raw error code 49 | pub fn code(&self) -> ffi::rocblas_status { 50 | self.code 51 | } 52 | 53 | /// Get 
the name of the error code 54 | pub fn name(&self) -> &'static str { 55 | match self.code { 56 | ffi::rocblas_status__rocblas_status_success => "rocblas_status_success", 57 | ffi::rocblas_status__rocblas_status_invalid_handle => "rocblas_status_invalid_handle", 58 | ffi::rocblas_status__rocblas_status_not_implemented => "rocblas_status_not_implemented", 59 | ffi::rocblas_status__rocblas_status_invalid_pointer => "rocblas_status_invalid_pointer", 60 | ffi::rocblas_status__rocblas_status_invalid_size => "rocblas_status_invalid_size", 61 | ffi::rocblas_status__rocblas_status_memory_error => "rocblas_status_memory_error", 62 | ffi::rocblas_status__rocblas_status_internal_error => "rocblas_status_internal_error", 63 | ffi::rocblas_status__rocblas_status_perf_degraded => "rocblas_status_perf_degraded", 64 | ffi::rocblas_status__rocblas_status_size_query_mismatch => { 65 | "rocblas_status_size_query_mismatch" 66 | } 67 | ffi::rocblas_status__rocblas_status_size_increased => "rocblas_status_size_increased", 68 | ffi::rocblas_status__rocblas_status_size_unchanged => "rocblas_status_size_unchanged", 69 | ffi::rocblas_status__rocblas_status_invalid_value => "rocblas_status_invalid_value", 70 | ffi::rocblas_status__rocblas_status_continue => "rocblas_status_continue", 71 | ffi::rocblas_status__rocblas_status_check_numerics_fail => { 72 | "rocblas_status_check_numerics_fail" 73 | } 74 | ffi::rocblas_status__rocblas_status_excluded_from_build => { 75 | "rocblas_status_excluded_from_build" 76 | } 77 | ffi::rocblas_status__rocblas_status_arch_mismatch => "rocblas_status_arch_mismatch", 78 | _ => "Unknown rocblas_status code", 79 | } 80 | } 81 | 82 | /// Get the description of the error code 83 | pub fn description(&self) -> &'static str { 84 | match self.code { 85 | ffi::rocblas_status__rocblas_status_success => "Success", 86 | ffi::rocblas_status__rocblas_status_invalid_handle => { 87 | "Handle not initialized, invalid, or null" 88 | } 89 | 
ffi::rocblas_status__rocblas_status_not_implemented => "Function is not implemented", 90 | ffi::rocblas_status__rocblas_status_invalid_pointer => "Invalid pointer argument", 91 | ffi::rocblas_status__rocblas_status_invalid_size => "Invalid size argument", 92 | ffi::rocblas_status__rocblas_status_memory_error => { 93 | "Failed internal memory allocation, copy, or dealloc" 94 | } 95 | ffi::rocblas_status__rocblas_status_internal_error => "Other internal library failure", 96 | ffi::rocblas_status__rocblas_status_perf_degraded => { 97 | "Performance degraded due to low device memory" 98 | } 99 | ffi::rocblas_status__rocblas_status_size_query_mismatch => { 100 | "Unmatched start/stop size query" 101 | } 102 | ffi::rocblas_status__rocblas_status_size_increased => { 103 | "Queried device memory size increased" 104 | } 105 | ffi::rocblas_status__rocblas_status_size_unchanged => { 106 | "Queried device memory size unchanged" 107 | } 108 | ffi::rocblas_status__rocblas_status_invalid_value => "Passed argument not valid", 109 | ffi::rocblas_status__rocblas_status_continue => { 110 | "Nothing preventing function to proceed" 111 | } 112 | ffi::rocblas_status__rocblas_status_check_numerics_fail => "Check numerics failure", 113 | ffi::rocblas_status__rocblas_status_excluded_from_build => { 114 | "Feature excluded from build" 115 | } 116 | ffi::rocblas_status__rocblas_status_arch_mismatch => "Architecture mismatch", 117 | _ => "Unknown error", 118 | } 119 | } 120 | } 121 | 122 | impl fmt::Display for Error { 123 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 124 | write!( 125 | f, 126 | "RocBLAS error {}: {} - {}", 127 | self.code, 128 | self.name(), 129 | self.description() 130 | ) 131 | } 132 | } 133 | 134 | impl StdError for Error {} 135 | 136 | // Define error conversion functions for common RocBLAS error codes 137 | impl Error { 138 | pub fn is_invalid_handle(&self) -> bool { 139 | self.code == ffi::rocblas_status__rocblas_status_invalid_handle 140 | } 141 | 142 
| pub fn is_not_implemented(&self) -> bool { 143 | self.code == ffi::rocblas_status__rocblas_status_not_implemented 144 | } 145 | 146 | pub fn is_invalid_pointer(&self) -> bool { 147 | self.code == ffi::rocblas_status__rocblas_status_invalid_pointer 148 | } 149 | 150 | pub fn is_invalid_size(&self) -> bool { 151 | self.code == ffi::rocblas_status__rocblas_status_invalid_size 152 | } 153 | 154 | pub fn is_memory_error(&self) -> bool { 155 | self.code == ffi::rocblas_status__rocblas_status_memory_error 156 | } 157 | 158 | pub fn is_internal_error(&self) -> bool { 159 | self.code == ffi::rocblas_status__rocblas_status_internal_error 160 | } 161 | 162 | pub fn is_invalid_value(&self) -> bool { 163 | self.code == ffi::rocblas_status__rocblas_status_invalid_value 164 | } 165 | } 166 | --------------------------------------------------------------------------------