├── .github └── workflows │ └── build.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md └── operators ├── Cargo.toml ├── build.rs └── src ├── .clang-format ├── add ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── add.cuh │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ └── mod.rs ├── add_rows ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── add_rows.cuh │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ └── mod.rs ├── all_reduce ├── args.rs ├── common_cpu.rs ├── infini.rs ├── mod.rs └── nccl.rs ├── attention ├── args.rs ├── common_cpu.rs ├── cuda.rs ├── infini.rs ├── mod.rs ├── opencl.rs └── operator.rs ├── attention_kv_cached ├── args.rs ├── common_cpu.rs ├── cuda.rs ├── infini.rs ├── mod.rs ├── opencl.rs └── operator.rs ├── broadcast ├── args.rs ├── common_cpu │ └── mod.rs ├── mod.rs └── nccl │ └── mod.rs ├── common ├── blob.rs ├── calculator.rs ├── diversity.rs ├── error.rs ├── maybe_dyn.rs ├── mod.rs ├── pool.rs ├── tensor.rs ├── unsigned.rs └── workspace.rs ├── conv ├── args.rs ├── common_cpu.rs ├── cuda.rs ├── im2col.rs ├── infini.rs ├── mod.rs └── opencl.rs ├── fuesd_softmax ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── fused_softmax.cuh │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ ├── fused_softmax.cl │ └── mod.rs ├── gelu ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── gelu.cuh │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ └── mod.rs ├── handle ├── common_cpu │ ├── inproc_node.rs │ └── mod.rs ├── cuda │ ├── alloc.rs │ ├── cxx │ │ ├── export.h │ │ ├── iluvatar.lua │ │ ├── nv.lua │ │ └── test_compile_8.0 │ │ │ └── test_compile.cu │ ├── library.rs │ ├── mod.rs │ ├── module.rs │ └── nccl.rs ├── infini │ ├── ccl.rs │ └── mod.rs ├── mod.rs └── opencl │ └── mod.rs ├── layer_norm ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── layer_norm.cuh │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ └── mod.rs ├── lib.rs ├── mat_mul ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ └── mod.rs ├── infini │ └── mod.rs ├── mod.rs └── opencl │ ├── mat_mul.cl │ └── mod.rs ├── random_sample ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── ffi.rs │ ├── mod.rs │ └── sample.cuh ├── infini │ └── mod.rs ├── kv_pair.rs ├── mod.rs └── opencl │ ├── mod.rs │ └── random_sample.cl ├── rearrange ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── mod.rs │ └── rearrange.cuh ├── infini │ └── mod.rs ├── mod.rs └── opencl │ ├── mod.rs │ └── rearrange.cl ├── rms_norm ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── mod.rs │ └── rms_norm.cuh ├── infini │ └── mod.rs ├── mod.rs └── opencl │ ├── mod.rs │ └── rms_norm.cl ├── rope ├── args.rs ├── common_cpu │ └── mod.rs ├── cuda │ ├── mod.rs │ └── rope.cuh ├── infini │ └── mod.rs ├── mod.rs └── opencl │ ├── mod.rs │ └── rope.cl └── swiglu ├── args.rs ├── common_cpu └── mod.rs ├── cuda ├── mod.rs └── swiglu.cuh ├── infini └── mod.rs ├── mod.rs └── opencl ├── mod.rs └── swiglu.cl /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | # rust-clippy is a tool that runs a bunch of lints to catch common 6 | # mistakes in your Rust code and help improve your Rust code. 
7 | # More details at https://github.com/rust-lang/rust-clippy 8 | # and https://rust-lang.github.io/rust-clippy/ 9 | 10 | name: CI 11 | 12 | on: 13 | pull_request: 14 | push: 15 | paths-ignore: 16 | - '**.md' 17 | - 'LICENSE' 18 | 19 | jobs: 20 | rust-clippy-analyze: 21 | name: Run rust-clippy analyzing 22 | runs-on: ubuntu-latest 23 | permissions: 24 | security-events: write 25 | steps: 26 | - name: Checkout code 27 | uses: actions/checkout@v4 28 | 29 | - name: Check format 30 | run: cargo fmt --check 31 | 32 | - name: Update to latest deps 33 | run: cargo update 34 | 35 | - name: cuda-toolkit 36 | uses: Jimver/cuda-toolkit@v0.2.18 37 | with: 38 | method: 'network' 39 | 40 | - name: Install xmake 41 | uses: xmake-io/github-action-setup-xmake@v1 42 | with: 43 | xmake-version: latest 44 | 45 | - name: Run test 46 | run: cargo test --release 47 | 48 | - name: Install required cargo 49 | run: cargo install clippy-sarif sarif-fmt 50 | 51 | - name: Run rust-clippy 52 | run: 53 | cargo clippy 54 | --all-features 55 | --message-format=json | clippy-sarif | tee rust-clippy-results.sarif | sarif-fmt 56 | continue-on-error: true 57 | 58 | - name: Upload analysis results to GitHub 59 | uses: github/codeql-action/upload-sarif@v3 60 | with: 61 | sarif_file: rust-clippy-results.sarif 62 | wait-for-processing: true 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["operators"] 3 | resolver = "2" 4 | 5 | [workspace.dependencies] 6 | clrt = { git = "https://github.com/InfiniTensor/clrt", rev = "984ac7a" } 7 | search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "984ac7a" } 8 | 9 | infini-rt = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" } 10 | infini-op = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" } 11 | infini-ccl = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" } 12 | search-infini-tools = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" } 13 | 14 | cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" } 15 | cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" } 16 | nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" } 17 | search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" } 18 | search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "f3ffbcc" } 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright © 2024 YdrMaster 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be 
included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Operator library with multi-hardware support 2 | 3 | [![CI](https://github.com/YdrMaster/operators-rs/actions/workflows/build.yml/badge.svg?branch=main)](https://github.com/YdrMaster/operators-rs/actions) 4 | ![GitHub repo size](https://img.shields.io/github/repo-size/YdrMaster/operators-rs) 5 | ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/YdrMaster/operators-rs) 6 | [![GitHub Issues](https://img.shields.io/github/issues/YdrMaster/operators-rs)](https://github.com/YdrMaster/operators-rs/issues) 7 | [![GitHub Pull Requests](https://img.shields.io/github/issues-pr/YdrMaster/operators-rs)](https://github.com/YdrMaster/operators-rs/pulls) 8 | ![GitHub contributors](https://img.shields.io/github/contributors/YdrMaster/operators-rs) 9 | ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/YdrMaster/operators-rs) 10 | -------------------------------------------------------------------------------- /operators/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "operators" 3 | version = "0.0.0" 4 | edition = "2021" 5 | authors = ["YdrMaster "] 6 | 7 | [features] 8 | default = ["common-cpu", "opencl", "infini", "nvidia-gpu", "iluvatar-gpu"] 9 | common-cpu = ["gemm"] 10 | opencl = ["clrt"] 11 | infini = ["infini-rt", "infini-op", "infini-ccl"] 12 | nvidia-gpu = ["cuda", "cublas", "nccl", "fslock", "libloading"] 13 | iluvatar-gpu = ["cuda", "cublas", "fslock", "libloading"] 14 | 15 | [dependencies] 16 | digit-layout = "0.2" 17 | ndarray-layout = "0.1" 18 | rayon = "1.10" 19 | lru = "0.12" 20 | num-traits = "0.2" 21 | itertools = "0.14" 22 | half = "2.4" 23 | log = "0.4" 24 | 25 | gemm = { version = "0.18", optional = true } 26 | 27 | clrt = { workspace = true, optional = true } 28 | 29 | infini-rt = { workspace = true, optional = true } 30 | infini-op = { workspace = true, optional = true } 31 | infini-ccl = { workspace = true, optional = true } 32 | 33 | cuda = { workspace = true, optional = true } 34 | cublas = { workspace = true, optional = true } 35 | nccl = { workspace = true, optional = true } 36 | fslock = { version = "0.2", optional = true } 37 | libloading = { version = "0.8", optional = true } 38 | 39 | [build-dependencies] 40 | build-script-cfg = "0.0" 41 | search-cl-tools.workspace = true 42 | search-infini-tools.workspace = true 43 | search-cuda-tools.workspace = true 44 | search-corex-tools.workspace = true 45 | 46 | [dev-dependencies] 47 | gemm = "0.18" 48 | rand = "0.9" 49 | -------------------------------------------------------------------------------- /operators/build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | use build_script_cfg::Cfg; 3 | use search_cl_tools::find_opencl; 4 | use search_corex_tools::find_corex; 5 | use 
search_cuda_tools::{find_cuda_root, find_nccl_root}; 6 | use search_infini_tools::{find_infini_ccl, find_infini_op, find_infini_rt}; 7 | 8 | let cpu = Cfg::new("use_cpu"); 9 | let cl = Cfg::new("use_cl"); 10 | let infini = Cfg::new("use_infini"); 11 | let cuda = Cfg::new("use_cuda"); 12 | let nvidia = Cfg::new("use_nvidia"); 13 | let nccl = Cfg::new("use_nccl"); 14 | let iluvatar = Cfg::new("use_iluvatar"); 15 | 16 | if cfg!(feature = "common-cpu") { 17 | cpu.define() 18 | } 19 | if cfg!(feature = "opencl") && find_opencl().is_some() { 20 | cl.define() 21 | } 22 | if cfg!(feature = "infini") 23 | && find_infini_rt().is_some() 24 | && find_infini_op().is_some() 25 | && find_infini_ccl().is_some() 26 | { 27 | infini.define() 28 | } 29 | let use_nvidia = cfg!(feature = "nvidia-gpu") && find_cuda_root().is_some(); 30 | let use_iluvatar = cfg!(feature = "iluvatar-gpu") && find_corex().is_some(); 31 | if use_nvidia { 32 | nvidia.define(); 33 | if find_nccl_root().is_some() { 34 | nccl.define() 35 | } 36 | } 37 | if use_iluvatar { 38 | iluvatar.define() 39 | } 40 | if use_nvidia || use_iluvatar { 41 | cuda.define() 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /operators/src/.clang-format: -------------------------------------------------------------------------------- 1 | # Generated from CLion C/C++ Code Style settings 2 | BasedOnStyle: LLVM 3 | AccessModifierOffset: -4 4 | AlignAfterOpenBracket: Align 5 | # AlignConsecutiveAssignments: None 6 | AlignOperands: Align 7 | AllowAllArgumentsOnNextLine: false 8 | AllowAllConstructorInitializersOnNextLine: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: Always 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: All 13 | AllowShortIfStatementsOnASingleLine: Always 14 | AllowShortLambdasOnASingleLine: All 15 | AllowShortLoopsOnASingleLine: true 16 | AlwaysBreakAfterReturnType: None 17 | AlwaysBreakTemplateDeclarations: No 18 | BreakBeforeBraces: Custom 19 | BraceWrapping: 20 | AfterCaseLabel: false 21 | AfterClass: false 22 | AfterControlStatement: Never 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterUnion: false 27 | BeforeCatch: false 28 | BeforeElse: false 29 | IndentBraces: false 30 | SplitEmptyFunction: false 31 | SplitEmptyRecord: true 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeTernaryOperators: true 34 | BreakConstructorInitializers: BeforeColon 35 | BreakInheritanceList: BeforeColon 36 | ColumnLimit: 0 37 | CompactNamespaces: true 38 | ContinuationIndentWidth: 4 39 | IndentCaseLabels: true 40 | IndentPPDirectives: None 41 | IndentWidth: 4 42 | KeepEmptyLinesAtTheStartOfBlocks: true 43 | MaxEmptyLinesToKeep: 2 44 | NamespaceIndentation: All 45 | ObjCSpaceAfterProperty: false 46 | ObjCSpaceBeforeProtocolList: true 47 | PointerAlignment: Right 48 | ReflowComments: false 49 | SpaceAfterCStyleCast: true 50 | SpaceAfterLogicalNot: false 51 | SpaceAfterTemplateKeyword: false 52 | SpaceBeforeAssignmentOperators: true 53 | SpaceBeforeCpp11BracedList: false 54 | SpaceBeforeCtorInitializerColon: true 55 | SpaceBeforeInheritanceColon: true 56 | SpaceBeforeParens: ControlStatements 57 | SpaceBeforeRangeBasedForLoopColon: true 58 | SpaceInEmptyParentheses: false 59 | SpacesBeforeTrailingComments: 0 60 | SpacesInAngles: false 61 | SpacesInCStyleCastParentheses: false 62 | SpacesInContainerLiterals: false 63 | SpacesInParentheses: false 64 | SpacesInSquareBrackets: false 65 | 
TabWidth: 4 66 | UseTab: Never 67 | -------------------------------------------------------------------------------- /operators/src/add/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Scheme, Add, Args}; 2 | use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use digit_layout::types as ty; 4 | use half::f16; 5 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 6 | 7 | pub struct Operator; 8 | 9 | impl Add for Operator {} 10 | 11 | impl crate::Operator for Operator { 12 | type Hardware = Cpu; 13 | type TopoNode = Cpu; 14 | type Args = Args; 15 | 16 | #[inline] 17 | fn new(_node: &Self::TopoNode) -> Self { 18 | Self 19 | } 20 | #[inline] 21 | fn scheme( 22 | &mut self, 23 | _args: &Self::Args, 24 | _max_workspace_size: usize, 25 | ) -> Result { 26 | Ok(0) 27 | } 28 | 29 | fn launch( 30 | &self, 31 | args: &Self::Args, 32 | _workspace: &mut [ByteOf], 33 | _queue_alloc: &QA, 34 | ) -> Result<(), LaunchError> 35 | where 36 | QA: QueueAlloc, 37 | { 38 | let scheme = Scheme::new(args)?; 39 | let c = args.c_base as isize; 40 | let a = args.a_base as isize; 41 | let b = args.b_base as isize; 42 | let idx_strides = scheme.idx_strides(); 43 | let c_strides = scheme.c_strides(); 44 | let a_strides = scheme.a_strides(); 45 | let b_strides = scheme.b_strides(); 46 | (0..scheme.count() as isize) 47 | .into_par_iter() 48 | .for_each(|mut rem| { 49 | let mut c = c; 50 | let mut a = a; 51 | let mut b = b; 52 | for (i, &s) in idx_strides.iter().enumerate() { 53 | let k = rem / s; 54 | c += k * c_strides[i]; 55 | a += k * a_strides[i]; 56 | b += k * b_strides[i]; 57 | rem %= s; 58 | } 59 | match scheme.dt() { 60 | ty::F16 => add::(c, a, b), 61 | ty::F32 => add::(c, a, b), 62 | ty::F64 => add::(c, a, b), 63 | _ => todo!(), 64 | } 65 | }); 66 | Ok(()) 67 | } 68 | } 69 | 70 | fn add>(c: isize, a: isize, b: isize) { 71 | let c = c as *mut T; 72 | let a = a as *const T; 73 | let b = b as *const T; 74 | unsafe { *c = a.read() + b.read() } 75 | } 76 | -------------------------------------------------------------------------------- /operators/src/add/cuda/add.cuh: -------------------------------------------------------------------------------- 1 | template 2 | static __device__ void _add( 3 | Tdata *__restrict__ c, 4 | Tdata const *__restrict__ a, 5 | Tdata const *__restrict__ b) { 6 | auto const idx = blockIdx.x * blockDim.x + threadIdx.x; 7 | c[idx] = a[idx] + b[idx]; 8 | } 9 | -------------------------------------------------------------------------------- /operators/src/add/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Add, Args}; 2 | use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl Add for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = Device; 10 | type TopoNode = Device; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/add/mod.rs: 
-------------------------------------------------------------------------------- 1 | //! c = a + b 2 | 3 | #[cfg(any(use_cpu, test))] 4 | pub mod common_cpu; 5 | #[cfg(use_cuda)] 6 | pub mod cuda; 7 | #[cfg(use_infini)] 8 | pub mod infini; 9 | #[cfg(use_cl)] 10 | pub mod opencl; 11 | 12 | mod args; 13 | pub use args::Args; 14 | 15 | crate::op_trait!(Add); 16 | -------------------------------------------------------------------------------- /operators/src/add/opencl/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Add, Args}; 2 | use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl Add for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = ClDevice; 10 | type TopoNode = ClDevice; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/add_rows/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | type_not_support, 3 | utils::{dim_distinct, rank_error, type_distinct}, 4 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 5 | }; 6 | use digit_layout::{DigitLayout, LayoutContent::Unsigned}; 7 | use std::ptr::{null, null_mut}; 8 | 9 | #[derive(Clone)] 10 | pub struct Args { 11 | pub dst_layout: TensorLayout, 12 | pub dst_base: MutPtr, 13 | pub src_layout: TensorLayout, 14 | pub src_base: ConstPtr, 15 | pub idx_layout: TensorLayout, 16 | pub idx_base: ConstPtr, 17 | } 18 | 19 | impl Args { 20 | pub fn new_null( 21 | dst_layout: TensorLayout, 22 | src_layout: TensorLayout, 23 | idx_layout: TensorLayout, 24 | ) -> Self { 25 | Self { 26 | dst_layout, 27 | dst_base: null_mut(), 28 | src_layout, 29 | src_base: null(), 30 | idx_layout, 31 | idx_base: null(), 32 | } 33 | } 34 | } 35 | 36 | #[derive(Clone, Debug)] 37 | pub(super) struct Meta { 38 | pub dt: DigitLayout, 39 | pub dt_idx: DigitLayout, 40 | pub batch: MaybeDyn, 41 | pub m: MaybeDyn, 42 | pub n: MaybeDyn, 43 | pub k: MaybeDyn, 44 | } 45 | 46 | impl Args { 47 | pub(super) fn meta(&self) -> Result { 48 | let Self { 49 | dst_layout: dst, 50 | src_layout: src, 51 | idx_layout: idx, 52 | .. 53 | } = self; 54 | 55 | let dt = type_distinct(&[dst.dt(), src.dt()])?; 56 | let dt_idx = idx.dt(); 57 | if !matches!(dt_idx.decode(), Unsigned { .. 
}) { 58 | return Err(type_not_support(format!( 59 | "data type {dt_idx} is not supported, must be unsigned integers" 60 | ))); 61 | } 62 | 63 | let &[batch, m, n] = dst.shape() else { 64 | return Err(rank_error("dst", 3, dst.ndim())); 65 | }; 66 | let &[k, n_] = src.shape() else { 67 | return Err(rank_error("src", 2, src.ndim())); 68 | }; 69 | let &[batch_, m_] = idx.shape() else { 70 | return Err(rank_error("idx", 2, idx.ndim())); 71 | }; 72 | 73 | Ok(Meta { 74 | dt, 75 | dt_idx, 76 | batch: dim_distinct(&[batch, batch_])?, 77 | m: dim_distinct(&[m, m_])?, 78 | n: dim_distinct(&[n, n_])?, 79 | k, 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /operators/src/add_rows/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, AddRows, Args}; 2 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError, Unsigned}; 3 | use digit_layout::types as ty; 4 | use half::f16; 5 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 6 | use std::ops::AddAssign; 7 | 8 | pub struct Operator; 9 | 10 | impl AddRows for Operator {} 11 | 12 | impl crate::Operator for Operator { 13 | type Hardware = Cpu; 14 | type TopoNode = Cpu; 15 | type Args = Args; 16 | 17 | fn new(_node: &Self::TopoNode) -> Self { 18 | Self 19 | } 20 | 21 | fn scheme( 22 | &mut self, 23 | _args: &Self::Args, 24 | _max_workspace_size: usize, 25 | ) -> Result { 26 | Ok(0) 27 | } 28 | 29 | fn launch( 30 | &self, 31 | args: &Self::Args, 32 | _workspace: &mut [ByteOf], 33 | _queue_alloc: &QA, 34 | ) -> Result<(), LaunchError> 35 | where 36 | QA: QueueAlloc, 37 | { 38 | let Meta { 39 | dt, 40 | dt_idx, 41 | batch: b, 42 | m, 43 | n, 44 | k, 45 | } = args.meta()?; 46 | let Args { 47 | dst_layout, 48 | dst_base, 49 | src_layout, 50 | src_base, 51 | idx_layout, 52 | idx_base, 53 | } = args; 54 | 55 | let &[bsd, msd, nsd] = dst_layout.strides() else { 56 | unreachable!() 57 | }; 58 | let &[kss, nss] = src_layout.strides() else { 59 | unreachable!() 60 | }; 61 | let &[bsi, msi] = idx_layout.strides() else { 62 | unreachable!() 63 | }; 64 | 65 | get_static! { 66 | b m n k 67 | bsd msd nsd 68 | bsi msi nss kss 69 | } 70 | 71 | let dst = *dst_base as usize; 72 | let src = *src_base as usize; 73 | let idx = *idx_base as usize; 74 | 75 | macro_rules! 
calculate { 76 | ($t:ty, $i:ty) => { 77 | (0..b * m).into_par_iter().for_each(|bm| { 78 | Scheme::<$t, $i> { 79 | dst: dst as _, 80 | src: src as _, 81 | idx: idx as _, 82 | m, 83 | n, 84 | k, 85 | bsd, 86 | msd, 87 | nsd, 88 | kss, 89 | nss, 90 | bsi, 91 | msi, 92 | } 93 | .calculate(bm) 94 | }) 95 | }; 96 | } 97 | 98 | match (dt, dt_idx) { 99 | (ty::F16, ty::U32) => calculate!(f16, u32), 100 | (ty::F32, ty::U32) => calculate!(f32, u32), 101 | (ty::F64, ty::U32) => calculate!(f64, u32), 102 | (ty::F16, ty::U64) => calculate!(f16, u64), 103 | (ty::F32, ty::U64) => calculate!(f32, u64), 104 | (ty::F64, ty::U64) => calculate!(f64, u64), 105 | (_, _) => todo!(), 106 | } 107 | Ok(()) 108 | } 109 | } 110 | 111 | struct Scheme { 112 | dst: *mut T, 113 | src: *const T, 114 | idx: *const I, 115 | m: usize, 116 | n: usize, 117 | k: usize, 118 | bsd: isize, 119 | msd: isize, 120 | nsd: isize, 121 | kss: isize, 122 | nss: isize, 123 | bsi: isize, 124 | msi: isize, 125 | } 126 | 127 | impl Scheme 128 | where 129 | T: AddAssign + Copy, 130 | I: Unsigned + Copy, 131 | { 132 | fn calculate(&self, bm: usize) { 133 | let b = (bm / self.m) as isize; 134 | let m = (bm % self.m) as isize; 135 | let dst = unsafe { self.dst.byte_offset(b * self.bsd + m * self.msd) }; 136 | let idx = unsafe { *self.idx.byte_offset(b * self.bsi + m * self.msi) }.val(); 137 | assert!(idx < self.k); 138 | 139 | let src = unsafe { self.src.byte_offset(idx as isize * self.kss) }; 140 | for i in 0..self.n as isize { 141 | unsafe { 142 | let dst = dst.byte_offset(i * self.nsd); 143 | let src = src.byte_offset(i * self.nss); 144 | *dst += *src; 145 | } 146 | } 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /operators/src/add_rows/cuda/add_rows.cuh: -------------------------------------------------------------------------------- 1 | template 2 | static __device__ void add_rows( 3 | Tdata *__restrict__ dst, 4 | Tdata const *__restrict__ src, 5 | Tidx const *__restrict__ i, 6 | int const stride_d_b, 7 | int const stride_d_m, 8 | int const stride_s, 9 | int const stride_i) { 10 | auto idx_n = blockIdx.x * blockDim.x + threadIdx.x; 11 | auto idst = blockIdx.z * stride_d_b + blockIdx.y * stride_d_m + idx_n; 12 | auto isrc = i[blockIdx.z * stride_i + blockIdx.y] * stride_s + idx_n; 13 | dst[idst] += src[isrc]; 14 | } 15 | -------------------------------------------------------------------------------- /operators/src/add_rows/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{AddRows, Args}; 2 | use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl AddRows for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = Device; 10 | type TopoNode = Device; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/add_rows/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
dst += src[i] 2 | 3 | #[cfg(any(use_cpu, test))] 4 | pub mod common_cpu; 5 | #[cfg(use_cuda)] 6 | pub mod cuda; 7 | #[cfg(use_infini)] 8 | pub mod infini; 9 | #[cfg(use_cl)] 10 | pub mod opencl; 11 | 12 | mod args; 13 | pub use args::Args; 14 | 15 | crate::op_trait!(AddRows); 16 | -------------------------------------------------------------------------------- /operators/src/add_rows/opencl/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{AddRows, Args}; 2 | use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl AddRows for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = ClDevice; 10 | type TopoNode = ClDevice; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/all_reduce/args.rs: -------------------------------------------------------------------------------- 1 | use super::ReduceOp; 2 | use crate::{ 3 | dyn_not_support, rearrange, shape_mismatch, strides_not_support, utils::type_distinct, 4 | Hardware, MaybeDyn, SchemeError, 5 | }; 6 | use digit_layout::DigitLayout; 7 | use ndarray_layout::ArrayLayout; 8 | 9 | pub struct Args { 10 | pub pair: rearrange::Args, 11 | pub op: ReduceOp, 12 | } 13 | 14 | impl AsRef> for Args { 15 | #[inline] 16 | fn as_ref(&self) -> &rearrange::Args { 17 | &self.pair 18 | } 19 | } 20 | 21 | pub(super) struct Meta { 22 | pub dt: DigitLayout, 23 | pub size: usize, 24 | } 25 | 26 | impl Args { 27 | pub(super) fn meta(&self) -> Result { 28 | let Self { 29 | pair: 30 | rearrange::Args { 31 | dst_layout, 32 | src_layout, 33 | .. 34 | }, 35 | .. 36 | } = self; 37 | 38 | let dt = type_distinct(&[dst_layout.dt(), src_layout.dt()])?; 39 | 40 | let Some(shape) = MaybeDyn::get_all(dst_layout.shape()) else { 41 | return Err(dyn_not_support("")); 42 | }; 43 | let Some(strides) = MaybeDyn::get_all(dst_layout.strides()) else { 44 | return Err(dyn_not_support("")); 45 | }; 46 | let dst = ArrayLayout::<2>::new(shape, strides, 0); 47 | let &[dst] = dst 48 | .merge_be(0, dst.ndim()) 49 | .ok_or(strides_not_support(""))? 50 | .shape() 51 | else { 52 | unreachable!() 53 | }; 54 | 55 | let Some(shape) = MaybeDyn::get_all(src_layout.shape()) else { 56 | return Err(dyn_not_support("")); 57 | }; 58 | let Some(strides) = MaybeDyn::get_all(src_layout.strides()) else { 59 | return Err(dyn_not_support("")); 60 | }; 61 | let src = ArrayLayout::<2>::new(shape, strides, 0); 62 | let &[src] = src 63 | .merge_be(0, src.ndim()) 64 | .ok_or(strides_not_support(""))? 
65 | .shape() 66 | else { 67 | unreachable!() 68 | }; 69 | 70 | if dst != src { 71 | return Err(shape_mismatch("")); 72 | } 73 | 74 | Ok(Meta { dt, size: dst }) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /operators/src/all_reduce/common_cpu.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, AllReduce, Args, ReduceOp}; 2 | use crate::{ 3 | broadcast::{self, common_cpu::Operator as Broadcast}, 4 | common_cpu::{Cpu, InprocNode}, 5 | rearrange, ByteOf, LaunchError, QueueAlloc, SchemeError, TopoNode, 6 | }; 7 | use digit_layout::DigitLayout; 8 | use half::{bf16, f16}; 9 | use std::{ 10 | iter::zip, 11 | ops::AddAssign, 12 | slice::{from_raw_parts, from_raw_parts_mut}, 13 | }; 14 | 15 | pub struct Operator { 16 | node: InprocNode, 17 | broadcast: Broadcast, 18 | } 19 | 20 | impl AllReduce> for Operator {} 21 | 22 | impl crate::Operator for Operator { 23 | type Hardware = Cpu; 24 | type TopoNode = InprocNode; 25 | type Args = Args; 26 | 27 | fn new(node: &Self::TopoNode) -> Self { 28 | assert!(node.group_size().is_power_of_two()); 29 | Self { 30 | node: node.clone(), 31 | broadcast: Broadcast::new(node), 32 | } 33 | } 34 | 35 | fn scheme( 36 | &mut self, 37 | _args: &Self::Args, 38 | _max_workspace_size: usize, 39 | ) -> Result { 40 | Ok(0) 41 | } 42 | 43 | fn launch( 44 | &self, 45 | args: &Self::Args, 46 | workspace: &mut [ByteOf], 47 | queue_alloc: &QA, 48 | ) -> Result<(), LaunchError> 49 | where 50 | QA: QueueAlloc, 51 | { 52 | let rank = self.node.rank(); 53 | let group_size = self.node.group_size(); 54 | if group_size == 1 { 55 | return Ok(()); 56 | } 57 | 58 | let Meta { dt, size } = args.meta()?; 59 | let &Args { 60 | pair: rearrange::Args { 61 | dst_base, src_base, .. 62 | }, 63 | op, 64 | } = args; 65 | 66 | let mut ptr = src_base; 67 | let mut i = rank; 68 | let mut root = rank; 69 | let mut stride = 1; 70 | let guard = self.node.wait(); 71 | while stride < group_size { 72 | if i % 2 == 0 { 73 | root += stride; 74 | self.node.send(root, ptr as _); 75 | 76 | i /= 2; 77 | stride *= 2; 78 | while stride < group_size { 79 | if i % 2 == 0 { 80 | root += stride; 81 | } 82 | i /= 2; 83 | stride *= 2; 84 | } 85 | 86 | break; 87 | } else { 88 | reduce(dt, op, size, dst_base, self.node.recv() as _); 89 | ptr = dst_base; 90 | 91 | i /= 2; 92 | stride *= 2; 93 | } 94 | } 95 | drop(guard); 96 | self.broadcast.launch( 97 | &broadcast::Args { 98 | pair: args.pair.clone(), 99 | root, 100 | }, 101 | workspace, 102 | queue_alloc, 103 | ) 104 | } 105 | } 106 | 107 | fn reduce(dt: DigitLayout, op: ReduceOp, len: usize, buf: *mut u8, src: *const u8) { 108 | match op { 109 | ReduceOp::Sum => { 110 | macro_rules! sum { 111 | ($( $dt:ident => $ty:ty )+ ) => { 112 | match dt { 113 | $( digit_layout::types::$dt => sum::<$ty>(len, buf, src), )+ 114 | _ => todo!(), 115 | } 116 | }; 117 | } 118 | sum! 
{ 119 | U8 => u8 120 | I8 => i8 121 | U16 => u16 122 | I16 => i16 123 | F16 => f16 124 | BF16 => bf16 125 | U32 => u32 126 | I32 => i32 127 | F32 => f32 128 | U64 => u64 129 | I64 => i64 130 | F64 => f64 131 | U128 => u128 132 | I128 => i128 133 | } 134 | } 135 | ReduceOp::Prod | ReduceOp::Min | ReduceOp::Max | ReduceOp::Mean => todo!(), 136 | } 137 | } 138 | 139 | fn sum(len: usize, buf: *mut u8, src: *const u8) { 140 | let dst = unsafe { from_raw_parts_mut(buf.cast::(), len) }; 141 | let src = unsafe { from_raw_parts(src.cast::(), len) }; 142 | for (dst, src) in zip(dst, src) { 143 | *dst += src.clone(); 144 | } 145 | } 146 | 147 | #[test] 148 | fn test_comm() { 149 | use crate::{common_cpu::ThisThread, Operator as _, TensorLayout}; 150 | use digit_layout::types::U32; 151 | 152 | InprocNode::new(4) 153 | .into_iter() 154 | .map(|node| { 155 | std::thread::spawn(move || { 156 | let mut buf = [node.rank() as u32; 8]; 157 | let op = Operator::new(&node); 158 | op.launch( 159 | &Args { 160 | pair: rearrange::Args { 161 | dst_layout: TensorLayout::new_contiguous(U32, &[8]), 162 | dst_base: buf.as_mut_ptr().cast(), 163 | src_layout: TensorLayout::new_contiguous(U32, &[8]), 164 | src_base: buf.as_ptr().cast(), 165 | }, 166 | op: ReduceOp::Sum, 167 | }, 168 | &mut [], 169 | &ThisThread, 170 | ) 171 | .unwrap(); 172 | buf 173 | }) 174 | }) 175 | .collect::>() 176 | .into_iter() 177 | .for_each(|h| assert_eq!(h.join().unwrap(), [0 + 1 + 2 + 3; 8])); 178 | } 179 | -------------------------------------------------------------------------------- /operators/src/all_reduce/infini.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, AllReduce, Args, ReduceOp}; 2 | use crate::{ 3 | infini::{Device, InfiniNode}, 4 | rearrange::{self, infini::Operator as Rearrange}, 5 | ByteOf, LaunchError, QueueAlloc, SchemeError, 6 | }; 7 | use digit_layout::types as ty; 8 | use infini_ccl::bindings::InfiniDataType_t; 9 | use std::{ 10 | slice::{from_raw_parts, from_raw_parts_mut}, 11 | sync::Arc, 12 | }; 13 | 14 | pub enum Operator { 15 | Rearrange(Rearrange), 16 | Comm(Arc), 17 | } 18 | 19 | impl AllReduce for Operator {} 20 | 21 | impl crate::Operator for Operator { 22 | type Hardware = Device; 23 | type TopoNode = InfiniNode; 24 | type Args = Args; 25 | 26 | fn new(node: &Self::TopoNode) -> Self { 27 | match node.comm.as_ref() { 28 | Some(comm) => Self::Comm(comm.clone()), 29 | None => Self::Rearrange(Rearrange::new(&node.device)), 30 | } 31 | } 32 | 33 | fn scheme( 34 | &mut self, 35 | _args: &Self::Args, 36 | _max_workspace_size: usize, 37 | ) -> Result { 38 | Ok(0) 39 | } 40 | 41 | fn launch( 42 | &self, 43 | args: &Self::Args, 44 | workspace: &mut [ByteOf], 45 | queue_alloc: &QA, 46 | ) -> Result<(), LaunchError> 47 | where 48 | QA: QueueAlloc, 49 | { 50 | match self { 51 | Self::Rearrange(rearrange) => rearrange.launch(&args.pair, workspace, queue_alloc), 52 | Self::Comm(comm) => { 53 | let Meta { dt, size } = args.meta()?; 54 | let &Args { 55 | pair: 56 | rearrange::Args { 57 | dst_base, src_base, .. 58 | }, 59 | op, 60 | .. 
61 | } = args; 62 | 63 | assert_eq!(op, ReduceOp::Sum); 64 | let len = dt.nbytes() * size; 65 | 66 | comm.allreduce_sum( 67 | unsafe { from_raw_parts_mut(dst_base, len) }, 68 | unsafe { from_raw_parts(src_base, len) }, 69 | match dt { 70 | ty::F16 => InfiniDataType_t::INFINI_F16, 71 | ty::F32 => InfiniDataType_t::INFINI_F32, 72 | _ => todo!(), 73 | }, 74 | queue_alloc.queue(), 75 | ); 76 | Ok(()) 77 | } 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /operators/src/all_reduce/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_infini)] 4 | pub mod infini; 5 | #[cfg(use_nccl)] 6 | pub mod nccl; 7 | 8 | mod args; 9 | pub use args::Args; 10 | 11 | crate::comm_trait!(AllReduce); 12 | crate::non_comm!(NonAllReduce impl AllReduce); 13 | 14 | #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] 15 | #[repr(u8)] 16 | pub enum ReduceOp { 17 | Sum, 18 | Prod, 19 | Min, 20 | Max, 21 | Mean, 22 | } 23 | -------------------------------------------------------------------------------- /operators/src/all_reduce/nccl.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, AllReduce, Args, ReduceOp}; 2 | use crate::{ 3 | cuda::{Gpu, NcclNode}, 4 | rearrange, ByteOf, LaunchError, QueueAlloc, SchemeError, 5 | }; 6 | use std::{ 7 | slice::{from_raw_parts, from_raw_parts_mut}, 8 | sync::Arc, 9 | }; 10 | 11 | pub struct Operator { 12 | nccl: Arc, 13 | } 14 | 15 | impl AllReduce for Operator {} 16 | 17 | impl crate::Operator for Operator { 18 | type Hardware = Gpu; 19 | type TopoNode = NcclNode; 20 | type Args = Args; 21 | 22 | fn new(node: &Self::TopoNode) -> Self { 23 | Self { 24 | nccl: node.nccl.clone(), 25 | } 26 | } 27 | 28 | fn scheme( 29 | &mut self, 30 | _args: &Self::Args, 31 | _max_workspace_size: usize, 32 | ) -> Result { 33 | Ok(0) 34 | } 35 | 36 | fn launch( 37 | &self, 38 | args: &Self::Args, 39 | _workspace: &mut [ByteOf], 40 | queue_alloc: &QA, 41 | ) -> Result<(), LaunchError> 42 | where 43 | QA: QueueAlloc, 44 | { 45 | let Meta { dt, size } = args.meta()?; 46 | let &Args { 47 | pair: rearrange::Args { 48 | dst_base, src_base, .. 49 | }, 50 | op, 51 | .. 
52 | } = args; 53 | 54 | let len = dt.nbytes() * size; 55 | self.nccl.all_reduce( 56 | unsafe { from_raw_parts_mut(dst_base, len) }, 57 | Some(unsafe { from_raw_parts(src_base, len) }), 58 | dt, 59 | convert_enum(op), 60 | queue_alloc.queue(), 61 | ); 62 | Ok(()) 63 | } 64 | } 65 | 66 | #[inline(always)] 67 | fn convert_enum(op: ReduceOp) -> nccl::ReduceType { 68 | use nccl::ReduceType::*; 69 | match op { 70 | ReduceOp::Sum => ncclSum, 71 | ReduceOp::Prod => ncclProd, 72 | ReduceOp::Min => ncclMin, 73 | ReduceOp::Max => ncclMax, 74 | ReduceOp::Mean => ncclAvg, 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /operators/src/attention/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | dyn_, 3 | fuesd_softmax::AttnMask, 4 | utils::{dim_distinct, rank_error, type_distinct}, 5 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 6 | }; 7 | use digit_layout::DigitLayout; 8 | use std::ptr::{null, null_mut}; 9 | 10 | pub struct Args { 11 | pub q_layout: TensorLayout, 12 | pub q_base: MutPtr, 13 | 14 | pub k_layout: TensorLayout, 15 | pub k_base: ConstPtr, 16 | 17 | pub v_layout: TensorLayout, 18 | pub v_base: ConstPtr, 19 | 20 | pub o_layout: TensorLayout, 21 | pub o_base: MutPtr, 22 | 23 | pub mask: AttnMask, 24 | } 25 | 26 | pub(super) struct Meta { 27 | pub dt: DigitLayout, 28 | pub nh: MaybeDyn, 29 | pub nkvh: MaybeDyn, 30 | pub seq: MaybeDyn, 31 | pub att: MaybeDyn, 32 | pub dh: MaybeDyn, 33 | } 34 | 35 | impl Args { 36 | pub(crate) fn new_null( 37 | mask: AttnMask, 38 | dt: DigitLayout, 39 | nh: MaybeDyn, 40 | nkvh: MaybeDyn, 41 | seq: MaybeDyn, 42 | att: MaybeDyn, 43 | dh: MaybeDyn, 44 | ) -> Self { 45 | let qo_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]); 46 | let kv_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]); 47 | Self { 48 | q_layout: qo_layout.clone(), 49 | q_base: null_mut(), 50 | k_layout: kv_layout.clone(), 51 | k_base: null(), 52 | v_layout: kv_layout, 53 | v_base: null(), 54 | o_layout: qo_layout, 55 | o_base: null_mut(), 56 | mask, 57 | } 58 | } 59 | 60 | pub(super) fn meta(&self) -> Result { 61 | let Self { 62 | q_layout, 63 | k_layout, 64 | v_layout, 65 | o_layout, 66 | .. 
67 | } = self; 68 | 69 | let &[nh_q, seq_q, dh_q] = q_layout.shape() else { 70 | return Err(rank_error("q", 3, q_layout.ndim())); 71 | }; 72 | let &[nkvh_k, att_k, dh_k] = k_layout.shape() else { 73 | return Err(rank_error("k", 3, k_layout.ndim())); 74 | }; 75 | let &[nkvh_v, att_v, dh_v] = v_layout.shape() else { 76 | return Err(rank_error("v", 3, v_layout.ndim())); 77 | }; 78 | let &[nh_o, seq_o, dh_o] = o_layout.shape() else { 79 | return Err(rank_error("o", 3, o_layout.ndim())); 80 | }; 81 | 82 | Ok(Meta { 83 | dt: type_distinct(&[q_layout.dt(), k_layout.dt(), v_layout.dt(), o_layout.dt()])?, 84 | nh: dim_distinct(&[nh_q, nh_o])?, 85 | nkvh: dim_distinct(&[nkvh_k, nkvh_v])?, 86 | seq: dim_distinct(&[seq_q, seq_o])?, 87 | att: dim_distinct(&[att_k, att_v])?, 88 | dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?, 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /operators/src/attention/common_cpu.rs: -------------------------------------------------------------------------------- 1 | impl_op!(common_cpu, Cpu); 2 | -------------------------------------------------------------------------------- /operators/src/attention/cuda.rs: -------------------------------------------------------------------------------- 1 | impl_op!(cuda, Gpu); 2 | 3 | #[cfg(test)] 4 | mod test { 5 | use super::{super::Args, Operator}; 6 | use crate::{cuda::Gpu, ByteOf, Hardware, Operator as _, TensorLayout}; 7 | use digit_layout::{types as ty, DigitLayout}; 8 | 9 | fn dyn_args(dt: DigitLayout, nh: usize, seq: usize, att: usize) -> Args { 10 | use crate::dyn_; 11 | Args::new_null( 12 | crate::fuesd_softmax::AttnMask::Causal, 13 | dt, 14 | nh.into(), 15 | dyn_(), 16 | seq.into(), 17 | att.into(), 18 | dyn_(), 19 | ) 20 | } 21 | 22 | fn args( 23 | dt: DigitLayout, 24 | nh: usize, 25 | nkvh: usize, 26 | seq: usize, 27 | att: usize, 28 | dh: usize, 29 | q_base: *mut ByteOf, 30 | k_base: *const ByteOf, 31 | v_base: *const ByteOf, 32 | o_base: *mut ByteOf, 33 | ) -> Args { 34 | Args { 35 | q_layout: TensorLayout::new_contiguous(dt, &[nh, seq, dh]), 36 | k_layout: TensorLayout::new_contiguous(dt, &[nkvh, att, dh]), 37 | v_layout: TensorLayout::new_contiguous(dt, &[nkvh, att, dh]), 38 | o_layout: TensorLayout::new_contiguous(dt, &[nh, seq, dh]), 39 | q_base, 40 | k_base, 41 | v_base, 42 | o_base, 43 | mask: crate::fuesd_softmax::AttnMask::Causal, 44 | } 45 | } 46 | 47 | #[test] 48 | fn test_compile() { 49 | let Some(gpu) = Gpu::init() else { 50 | return; 51 | }; 52 | println!("{}", gpu.0.device().info()); 53 | 54 | let mut op = Operator::new(&gpu); 55 | let workspace = op.scheme(&dyn_args(ty::F16, 32, 7, 127), usize::MAX); 56 | println!("workspace: {workspace:?}"); 57 | } 58 | 59 | #[test] 60 | fn test_compute() { 61 | use super::super::common_cpu::Operator as RefOp; 62 | use crate::{ 63 | common_cpu::{Cpu, ThisThread}, 64 | cuda::cast_load, 65 | test_utils::{Diff, ErrorCollector}, 66 | }; 67 | use cuda::memcpy_d2h; 68 | use half::f16; 69 | use rand::Rng; 70 | 71 | let Some(gpu) = Gpu::init() else { 72 | return; 73 | }; 74 | 75 | let nh = 32; 76 | let nkvh = 4; 77 | let seq = 7; 78 | let att = 127; 79 | let dh = 64; 80 | 81 | let cpu_op = RefOp::new(&Cpu); 82 | let gpu_op = Operator::new(&gpu); 83 | 84 | let mut q = vec![0.0f64; nh * seq * dh]; 85 | let mut k = vec![0.0f64; nkvh * att * dh]; 86 | let mut v = vec![0.0f64; nkvh * att * dh]; 87 | let o = vec![0.0f64; nh * seq * dh]; 88 | rand::rng().fill(&mut q[..]); 89 | rand::rng().fill(&mut k[..]); 90 | 
rand::rng().fill(&mut v[..]); 91 | let k = k; 92 | let v = v; 93 | 94 | let o_ans = gpu.apply(|ctx| { 95 | let stream = ctx.stream(); 96 | #[cfg(use_nvidia)] 97 | let rt = &stream; 98 | #[cfg(use_iluvatar)] 99 | let rt = ctx; 100 | let mut q = cast_load(&q, f16::from_f64, &stream); 101 | let k = cast_load(&k, f16::from_f64, &stream); 102 | let v = cast_load(&v, f16::from_f64, &stream); 103 | let mut o = rt.malloc::(o.len()); 104 | gpu_op 105 | .launch( 106 | &args( 107 | ty::F16, 108 | nh, 109 | nkvh, 110 | seq, 111 | att, 112 | dh, 113 | q.as_mut_ptr(), 114 | k.as_ptr(), 115 | v.as_ptr(), 116 | o.as_mut_ptr(), 117 | ), 118 | &mut [], 119 | &stream, 120 | ) 121 | .unwrap(); 122 | 123 | let mut host = vec![f16::ZERO; nh * seq * dh]; 124 | memcpy_d2h(&mut host, &o); 125 | host 126 | }); 127 | 128 | let mut o_ref = o; 129 | cpu_op 130 | .launch( 131 | &args( 132 | ty::F64, 133 | nh, 134 | nkvh, 135 | seq, 136 | att, 137 | dh, 138 | q.as_mut_ptr().cast(), 139 | k.as_ptr().cast(), 140 | v.as_ptr().cast(), 141 | o_ref.as_mut_ptr().cast(), 142 | ), 143 | &mut [], 144 | &ThisThread, 145 | ) 146 | .unwrap(); 147 | 148 | let diff = o_ref 149 | .into_iter() 150 | .zip(o_ans) 151 | .map(|(a, b)| Diff::new(a, b.to_f64())) 152 | .collect::>(); 153 | 154 | let mut ec = ErrorCollector::new(f16::EPSILON.to_f64(), 1e-3); 155 | diff.into_iter().for_each(|diff| ec.push(diff)); 156 | println!("{ec}"); 157 | 158 | let (out, count) = ec.summary(); 159 | assert!(out * 1000 <= count); 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /operators/src/attention/infini.rs: -------------------------------------------------------------------------------- 1 | impl_op!(infini, Device); 2 | -------------------------------------------------------------------------------- /operators/src/attention/mod.rs: -------------------------------------------------------------------------------- 1 | mod args; 2 | mod operator; 3 | 4 | pub use args::Args; 5 | 6 | crate::op_trait!(Attention); 7 | 8 | macro_rules! 
impl_op { 9 | ($dev:ident, $proc:ident) => { 10 | pub type Operator = super::operator::Operator< 11 | crate::$dev::$proc, 12 | crate::mat_mul::$dev::Operator, 13 | crate::fuesd_softmax::$dev::Operator, 14 | crate::rearrange::$dev::Operator, 15 | >; 16 | }; 17 | } 18 | 19 | #[cfg(any(use_cpu, test))] 20 | pub mod common_cpu; 21 | #[cfg(use_cuda)] 22 | pub mod cuda; 23 | #[cfg(use_infini)] 24 | pub mod infini; 25 | #[cfg(use_cl)] 26 | pub mod opencl; 27 | -------------------------------------------------------------------------------- /operators/src/attention/opencl.rs: -------------------------------------------------------------------------------- 1 | impl_op!(opencl, ClDevice); 2 | -------------------------------------------------------------------------------- /operators/src/attention_kv_cached/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | fuesd_softmax::AttnMask, 3 | utils::{dim_distinct, rank_error, type_distinct}, 4 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 5 | }; 6 | use digit_layout::DigitLayout; 7 | 8 | pub struct Args { 9 | pub q_layout: TensorLayout, 10 | pub q_base: MutPtr, 11 | 12 | pub k_layout: TensorLayout, 13 | pub k_base: ConstPtr, 14 | 15 | pub v_layout: TensorLayout, 16 | pub v_base: ConstPtr, 17 | 18 | pub o_layout: TensorLayout, 19 | pub o_base: MutPtr, 20 | 21 | pub k_cache_layout: TensorLayout, 22 | pub k_cache_base: MutPtr, 23 | 24 | pub v_cache_layout: TensorLayout, 25 | pub v_cache_base: MutPtr, 26 | 27 | pub mask: AttnMask, 28 | pub pos: MaybeDyn, 29 | } 30 | 31 | pub(super) struct Meta { 32 | pub dt: DigitLayout, 33 | pub nh: MaybeDyn, 34 | pub nkvh: MaybeDyn, 35 | pub dh: MaybeDyn, 36 | 37 | pub seq: MaybeDyn, 38 | } 39 | 40 | impl Args { 41 | #[allow(clippy::too_many_arguments)] 42 | pub fn new_null( 43 | q_layout: TensorLayout, 44 | k_layout: TensorLayout, 45 | v_layout: TensorLayout, 46 | o_layout: TensorLayout, 47 | k_cache_layout: TensorLayout, 48 | v_cache_layout: TensorLayout, 49 | mask: AttnMask, 50 | pos: MaybeDyn, 51 | ) -> Self { 52 | use std::ptr::{null, null_mut}; 53 | Self { 54 | q_layout, 55 | q_base: null_mut(), 56 | k_layout, 57 | k_base: null(), 58 | v_layout, 59 | v_base: null(), 60 | o_layout, 61 | o_base: null_mut(), 62 | k_cache_layout, 63 | k_cache_base: null_mut(), 64 | v_cache_layout, 65 | v_cache_base: null_mut(), 66 | mask, 67 | pos, 68 | } 69 | } 70 | 71 | pub(super) fn meta(&self) -> Result { 72 | let Self { 73 | q_layout, 74 | k_layout, 75 | v_layout, 76 | o_layout, 77 | k_cache_layout, 78 | v_cache_layout, 79 | .. 
80 | } = self; 81 | 82 | let &[nh_q, seq_q, dh_q] = q_layout.shape() else { 83 | return Err(rank_error("q", 3, q_layout.ndim())); 84 | }; 85 | let &[nkvh_k, seq_k, dh_k] = k_layout.shape() else { 86 | return Err(rank_error("k", 3, k_layout.ndim())); 87 | }; 88 | let &[nkvh_v, seq_v, dh_v] = v_layout.shape() else { 89 | return Err(rank_error("v", 3, v_layout.ndim())); 90 | }; 91 | let &[nh_o, seq_o, dh_o] = o_layout.shape() else { 92 | return Err(rank_error("o", 3, o_layout.ndim())); 93 | }; 94 | let &[nkvh_kc, _buf, dh_kc] = k_cache_layout.shape() else { 95 | return Err(rank_error("k_cache", 3, k_cache_layout.ndim())); 96 | }; 97 | let &[nkvh_vc, _buf, dh_vc] = v_cache_layout.shape() else { 98 | return Err(rank_error("v_cache", 3, v_cache_layout.ndim())); 99 | }; 100 | 101 | Ok(Meta { 102 | dt: type_distinct(&[ 103 | q_layout.dt(), 104 | k_layout.dt(), 105 | v_layout.dt(), 106 | o_layout.dt(), 107 | k_cache_layout.dt(), 108 | v_cache_layout.dt(), 109 | ])?, 110 | nh: dim_distinct(&[nh_q, nh_o])?, 111 | nkvh: dim_distinct(&[nkvh_k, nkvh_v, nkvh_kc, nkvh_vc])?, 112 | dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o, dh_kc, dh_vc])?, 113 | seq: dim_distinct(&[seq_q, seq_k, seq_v, seq_o])?, 114 | }) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /operators/src/attention_kv_cached/common_cpu.rs: -------------------------------------------------------------------------------- 1 | impl_op!(common_cpu, Cpu); 2 | -------------------------------------------------------------------------------- /operators/src/attention_kv_cached/infini.rs: -------------------------------------------------------------------------------- 1 | impl_op!(infini, Device); 2 | -------------------------------------------------------------------------------- /operators/src/attention_kv_cached/mod.rs: -------------------------------------------------------------------------------- 1 | mod args; 2 | mod operator; 3 | 4 | pub use args::Args; 5 | 6 | crate::op_trait!(AttnKVCached); 7 | 8 | macro_rules! impl_op { 9 | ($dev:ident, $proc:ident) => { 10 | pub type Operator = super::operator::Operator< 11 | crate::$dev::$proc, 12 | crate::rearrange::$dev::Operator, 13 | crate::attention::$dev::Operator, 14 | >; 15 | }; 16 | } 17 | 18 | #[cfg(any(use_cpu, test))] 19 | pub mod common_cpu; 20 | #[cfg(use_cuda)] 21 | pub mod cuda; 22 | #[cfg(use_infini)] 23 | pub mod infini; 24 | #[cfg(use_cl)] 25 | pub mod opencl; 26 | -------------------------------------------------------------------------------- /operators/src/attention_kv_cached/opencl.rs: -------------------------------------------------------------------------------- 1 | impl_op!(opencl, ClDevice); 2 | -------------------------------------------------------------------------------- /operators/src/broadcast/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | dyn_not_support, rearrange, shape_mismatch, strides_not_support, utils::type_distinct, 3 | Hardware, MaybeDyn, SchemeError, 4 | }; 5 | use ndarray_layout::ArrayLayout; 6 | 7 | pub struct Args { 8 | pub pair: rearrange::Args, 9 | pub root: usize, 10 | } 11 | 12 | impl AsRef> for Args { 13 | #[inline] 14 | fn as_ref(&self) -> &rearrange::Args { 15 | &self.pair 16 | } 17 | } 18 | 19 | pub(super) struct Meta { 20 | pub size: usize, 21 | } 22 | 23 | impl Args { 24 | pub(super) fn meta(&self) -> Result { 25 | let Self { 26 | pair: 27 | rearrange::Args { 28 | dst_layout, 29 | src_layout, 30 | .. 31 | }, 32 | .. 
33 | } = self; 34 | 35 | let dt = type_distinct(&[dst_layout.dt(), src_layout.dt()])?; 36 | 37 | let Some(shape) = MaybeDyn::get_all(dst_layout.shape()) else { 38 | return Err(dyn_not_support("")); 39 | }; 40 | let Some(strides) = MaybeDyn::get_all(dst_layout.strides()) else { 41 | return Err(dyn_not_support("")); 42 | }; 43 | let dst = ArrayLayout::<2>::new(shape, strides, 0); 44 | let &[dst] = dst 45 | .merge_be(0, dst.ndim()) 46 | .ok_or(strides_not_support(""))? 47 | .shape() 48 | else { 49 | unreachable!() 50 | }; 51 | 52 | let Some(shape) = MaybeDyn::get_all(src_layout.shape()) else { 53 | return Err(dyn_not_support("")); 54 | }; 55 | let Some(strides) = MaybeDyn::get_all(src_layout.strides()) else { 56 | return Err(dyn_not_support("")); 57 | }; 58 | let src = ArrayLayout::<2>::new(shape, strides, 0); 59 | let &[src] = src 60 | .merge_be(0, src.ndim()) 61 | .ok_or(strides_not_support(""))? 62 | .shape() 63 | else { 64 | unreachable!() 65 | }; 66 | 67 | if dst != src { 68 | return Err(shape_mismatch("")); 69 | } 70 | 71 | Ok(Meta { 72 | size: dst * dt.nbytes(), 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /operators/src/broadcast/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Broadcast}; 2 | use crate::{ 3 | common_cpu::{Cpu, InprocNode}, 4 | rearrange, ByteOf, LaunchError, QueueAlloc, SchemeError, TopoNode, 5 | }; 6 | use std::ptr::{addr_eq, copy, copy_nonoverlapping}; 7 | 8 | pub struct Operator(InprocNode); 9 | 10 | impl Broadcast> for Operator {} 11 | 12 | impl crate::Operator for Operator { 13 | type Hardware = Cpu; 14 | type TopoNode = InprocNode; 15 | type Args = Args; 16 | 17 | fn new(node: &Self::TopoNode) -> Self { 18 | assert!(node.group_size().is_power_of_two()); 19 | Self(node.clone()) 20 | } 21 | 22 | fn scheme( 23 | &mut self, 24 | _args: &Self::Args, 25 | _max_workspace_size: usize, 26 | ) -> Result { 27 | Ok(0) 28 | } 29 | 30 | fn launch( 31 | &self, 32 | args: &Self::Args, 33 | _workspace: &mut [ByteOf], 34 | _queue_alloc: &QA, 35 | ) -> Result<(), LaunchError> 36 | where 37 | QA: QueueAlloc, 38 | { 39 | let rank = self.0.rank(); 40 | let group_size = self.0.group_size(); 41 | 42 | let Meta { size } = args.meta()?; 43 | let &Args { 44 | pair: rearrange::Args { 45 | dst_base, src_base, .. 46 | }, 47 | root, 48 | .. 
49 | } = args; 50 | 51 | let _guard = self.0.wait(); 52 | if rank == root { 53 | for i in 0..group_size { 54 | if i != rank { 55 | self.0.send(i, src_base as _) 56 | } 57 | } 58 | if !addr_eq(dst_base, src_base) { 59 | unsafe { copy(src_base, dst_base, size) } 60 | } 61 | for _ in 0..group_size - 1 { 62 | assert_eq!(self.0.recv(), usize::MAX) 63 | } 64 | } else { 65 | unsafe { copy_nonoverlapping(self.0.recv() as _, dst_base, size) } 66 | self.0.send(root, usize::MAX) 67 | } 68 | Ok(()) 69 | } 70 | } 71 | 72 | #[test] 73 | fn test_comm() { 74 | use crate::{common_cpu::ThisThread, Operator as _, TensorLayout}; 75 | use digit_layout::types::U32; 76 | 77 | InprocNode::new(4) 78 | .into_iter() 79 | .map(|node| { 80 | std::thread::spawn(move || { 81 | let mut buf = [node.rank() as u32; 8]; 82 | let op = Operator::new(&node); 83 | op.launch( 84 | &Args { 85 | pair: rearrange::Args { 86 | dst_layout: TensorLayout::new_contiguous(U32, &[8]), 87 | dst_base: buf.as_mut_ptr().cast(), 88 | src_layout: TensorLayout::new_contiguous(U32, &[8]), 89 | src_base: buf.as_ptr().cast(), 90 | }, 91 | root: 1, 92 | }, 93 | &mut [], 94 | &ThisThread, 95 | ) 96 | .unwrap(); 97 | buf 98 | }) 99 | }) 100 | .collect::>() 101 | .into_iter() 102 | .for_each(|h| assert_eq!(h.join().unwrap(), [1; 8])) 103 | } 104 | -------------------------------------------------------------------------------- /operators/src/broadcast/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_nccl)] 4 | pub mod nccl; 5 | 6 | mod args; 7 | pub use args::Args; 8 | 9 | crate::comm_trait!(Broadcast); 10 | crate::non_comm!(NonBroadcast impl Broadcast); 11 | -------------------------------------------------------------------------------- /operators/src/broadcast/nccl/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Broadcast}; 2 | use crate::{ 3 | cuda::{Gpu, NcclNode}, 4 | rearrange, ByteOf, LaunchError, QueueAlloc, SchemeError, 5 | }; 6 | use std::{ 7 | slice::{from_raw_parts, from_raw_parts_mut}, 8 | sync::Arc, 9 | }; 10 | 11 | pub struct Operator { 12 | nccl: Arc, 13 | } 14 | 15 | impl Broadcast for Operator {} 16 | 17 | impl crate::Operator for Operator { 18 | type Hardware = Gpu; 19 | type TopoNode = NcclNode; 20 | type Args = Args; 21 | 22 | fn new(node: &Self::TopoNode) -> Self { 23 | Self { 24 | nccl: node.nccl.clone(), 25 | } 26 | } 27 | 28 | fn scheme( 29 | &mut self, 30 | _args: &Self::Args, 31 | _max_workspace_size: usize, 32 | ) -> Result { 33 | Ok(0) 34 | } 35 | 36 | fn launch( 37 | &self, 38 | args: &Self::Args, 39 | _workspace: &mut [ByteOf], 40 | queue_alloc: &QA, 41 | ) -> Result<(), LaunchError> 42 | where 43 | QA: QueueAlloc, 44 | { 45 | let Meta { size } = args.meta()?; 46 | let &Args { 47 | pair: rearrange::Args { 48 | dst_base, src_base, .. 49 | }, 50 | root, 51 | .. 
52 | } = args; 53 | self.nccl.broadcast( 54 | unsafe { from_raw_parts_mut(dst_base, size) }, 55 | Some(unsafe { from_raw_parts(src_base, size) }), 56 | root as _, 57 | queue_alloc.queue(), 58 | ); 59 | Ok(()) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /operators/src/common/blob.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | alloc::{alloc, dealloc, Layout}, 3 | ops::{Deref, DerefMut}, 4 | ptr::NonNull, 5 | slice::{from_raw_parts, from_raw_parts_mut}, 6 | }; 7 | 8 | pub struct Blob { 9 | ptr: NonNull, 10 | len: usize, 11 | } 12 | 13 | impl Blob { 14 | #[inline] 15 | pub fn new(size: usize) -> Self { 16 | Self { 17 | ptr: NonNull::new(unsafe { alloc(layout(size)) }).unwrap(), 18 | len: size, 19 | } 20 | } 21 | } 22 | 23 | impl Drop for Blob { 24 | #[inline] 25 | fn drop(&mut self) { 26 | let &mut Blob { ptr, len } = self; 27 | unsafe { dealloc(ptr.as_ptr(), layout(len)) } 28 | } 29 | } 30 | 31 | #[inline(always)] 32 | const fn layout(size: usize) -> Layout { 33 | unsafe { Layout::from_size_align_unchecked(size, align_of::()) } 34 | } 35 | 36 | impl Deref for Blob { 37 | type Target = [u8]; 38 | #[inline] 39 | fn deref(&self) -> &[u8] { 40 | unsafe { from_raw_parts(self.ptr.as_ptr(), self.len) } 41 | } 42 | } 43 | 44 | impl DerefMut for Blob { 45 | #[inline] 46 | fn deref_mut(&mut self) -> &mut [u8] { 47 | unsafe { from_raw_parts_mut(self.ptr.as_ptr(), self.len) } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /operators/src/common/calculator.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cmp::Ordering, 3 | collections::{BTreeSet, HashMap}, 4 | ops::Range, 5 | }; 6 | 7 | #[derive(Debug)] 8 | pub struct OffsetCalculator { 9 | alignment: usize, 10 | free_list: BTreeSet, 11 | heads: HashMap, 12 | tails: HashMap, 13 | } 14 | 15 | impl OffsetCalculator { 16 | pub fn new(alignment: usize) -> Self { 17 | Self { 18 | alignment, 19 | free_list: BTreeSet::new(), 20 | heads: HashMap::new(), 21 | tails: HashMap::new(), 22 | } 23 | } 24 | 25 | pub fn put(&mut self, range: &Range) { 26 | let len = range.len().div_ceil(self.alignment) * self.alignment; 27 | if len == 0 { 28 | return; 29 | } 30 | 31 | let mut head = range.start; 32 | let mut tail = head + len; 33 | if let Some(len_) = self.tails.remove(&head) { 34 | head -= len_; 35 | assert!(self.free_list.remove(&Area { 36 | off: head, 37 | len: len_, 38 | })); 39 | assert_eq!(self.heads.remove(&head), Some(len_)); 40 | } 41 | if let Some(len_) = self.heads.remove(&tail) { 42 | assert!(self.free_list.remove(&Area { 43 | off: tail, 44 | len: len_, 45 | })); 46 | tail += len_; 47 | assert_eq!(self.tails.remove(&tail), Some(len_)); 48 | } 49 | 50 | self.insert_area(Area { 51 | off: head, 52 | len: tail - head, 53 | }) 54 | } 55 | 56 | pub fn take(&mut self, expect: usize) -> Option> { 57 | let len = expect.div_ceil(self.alignment) * self.alignment; 58 | if len == 0 { 59 | return Some(usize::MAX..usize::MAX); 60 | } 61 | 62 | let &free = self.free_list.range(Area { off: 0, len }..).next()?; 63 | 64 | let head = free.off; 65 | let tail = free.off + free.len; 66 | 67 | self.free_list.remove(&free); 68 | self.heads.remove(&head); 69 | self.tails.remove(&tail); 70 | 71 | if free.len > len { 72 | self.insert_area(Area { 73 | off: free.off + len, 74 | len: free.len - len, 75 | }) 76 | } 77 | 78 | Some(head..head + expect) 79 | } 80 | 81 | fn 
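    // Registers `area` in all three indices: the (len, off)-ordered set that `take`
    // searches, plus the `heads`/`tails` maps that let `put` coalesce a returned range
    // with its free neighbours. A minimal usage sketch (hypothetical sizes, for
    // illustration only):
    //
    //     let mut calc = OffsetCalculator::new(256);
    //     calc.put(&(0..1024));             // one free area [0, 1024)
    //     let a = calc.take(300).unwrap();  // -> 0..300 (512 bytes reserved after alignment)
    //     let b = calc.take(512).unwrap();  // -> 512..1024
    //     calc.put(&a);                     // the 512-byte slot becomes available again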
insert_area(&mut self, area: Area) { 82 | self.free_list.insert(area); 83 | self.heads.insert(area.off, area.len); 84 | self.tails.insert(area.off + area.len, area.len); 85 | } 86 | } 87 | 88 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 89 | struct Area { 90 | off: usize, 91 | len: usize, 92 | } 93 | 94 | impl PartialOrd for Area { 95 | fn partial_cmp(&self, other: &Self) -> Option { 96 | Some(self.cmp(other)) 97 | } 98 | } 99 | 100 | impl Ord for Area { 101 | fn cmp(&self, other: &Self) -> Ordering { 102 | match self.len.cmp(&other.len) { 103 | Ordering::Equal => self.off.cmp(&other.off), 104 | ord => ord, 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /operators/src/common/diversity.rs: -------------------------------------------------------------------------------- 1 | use lru::LruCache; 2 | use std::{hash::Hash, num::NonZeroUsize, sync::Mutex}; 3 | 4 | #[derive(Clone, Debug)] 5 | pub struct SchemeCacheSize { 6 | pub low: usize, 7 | pub medium: usize, 8 | pub high: usize, 9 | } 10 | 11 | impl Default for SchemeCacheSize { 12 | fn default() -> Self { 13 | Self { 14 | low: 4, 15 | medium: 16, 16 | high: 64, 17 | } 18 | } 19 | } 20 | 21 | #[derive(Clone, Copy, Debug)] 22 | pub enum SchemeDiversity { 23 | Low, 24 | Medium, 25 | High, 26 | } 27 | 28 | impl SchemeCacheSize { 29 | pub fn new_cache(&self, level: SchemeDiversity) -> Mutex> { 30 | let size = match level { 31 | SchemeDiversity::Low => self.low, 32 | SchemeDiversity::Medium => self.medium, 33 | SchemeDiversity::High => self.high, 34 | }; 35 | Mutex::new(LruCache::new(NonZeroUsize::new(size).unwrap())) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /operators/src/common/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 2 | pub enum SchemeErrorKind { 3 | TypeNotSupport, 4 | TypeMismatch, 5 | RankNotSupport, 6 | RankMismatch, 7 | ShapeNotSupport, 8 | ShapeMismatch, 9 | StridesNotSupport, 10 | ArgsNotSupport, 11 | DynamicNotSupport, 12 | } 13 | 14 | #[derive(Clone, Debug)] 15 | pub struct SchemeError { 16 | pub kind: SchemeErrorKind, 17 | pub info: String, 18 | } 19 | 20 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 21 | pub enum LaunchErrorKind { 22 | Scheme(SchemeErrorKind), 23 | ExecutionFailed, 24 | } 25 | 26 | #[derive(Clone, Debug)] 27 | pub struct LaunchError { 28 | pub kind: LaunchErrorKind, 29 | pub info: String, 30 | } 31 | 32 | impl From for LaunchError { 33 | fn from(SchemeError { kind, info }: SchemeError) -> Self { 34 | Self { 35 | kind: LaunchErrorKind::Scheme(kind), 36 | info, 37 | } 38 | } 39 | } 40 | 41 | pub(super) mod functions { 42 | use super::{LaunchError, LaunchErrorKind::*, SchemeError, SchemeErrorKind::*}; 43 | 44 | macro_rules! 
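    // Generates one free-function constructor per error kind, so call sites can write
    // e.g. `strides_not_support("...")` instead of spelling out the struct literal.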
builder { 45 | ($ty:ident: $name:ident $kind:expr) => { 46 | #[inline] 47 | pub fn $name(info: impl Into) -> $ty { 48 | $ty { 49 | kind: $kind, 50 | info: info.into(), 51 | } 52 | } 53 | }; 54 | } 55 | 56 | builder!(SchemeError: type_not_support TypeNotSupport ); 57 | builder!(SchemeError: type_mismatch TypeMismatch ); 58 | builder!(SchemeError: rank_mismatch RankMismatch ); 59 | builder!(SchemeError: rank_not_support RankNotSupport ); 60 | builder!(SchemeError: shape_not_support ShapeNotSupport ); 61 | builder!(SchemeError: shape_mismatch ShapeMismatch ); 62 | builder!(SchemeError: strides_not_support StridesNotSupport); 63 | builder!(SchemeError: args_not_support ArgsNotSupport ); 64 | builder!(SchemeError: dyn_not_support DynamicNotSupport); 65 | 66 | builder!(LaunchError: execution_failed ExecutionFailed ); 67 | } 68 | -------------------------------------------------------------------------------- /operators/src/common/maybe_dyn.rs: -------------------------------------------------------------------------------- 1 | pub trait DynVal { 2 | fn default_dyn() -> Self; 3 | fn is_dynamic(&self) -> bool; 4 | } 5 | 6 | impl DynVal for isize { 7 | #[inline] 8 | fn default_dyn() -> Self { 9 | Self::MAX 10 | } 11 | #[inline] 12 | fn is_dynamic(&self) -> bool { 13 | *self == Self::MAX 14 | } 15 | } 16 | 17 | impl DynVal for usize { 18 | #[inline] 19 | fn default_dyn() -> Self { 20 | Self::MAX 21 | } 22 | #[inline] 23 | fn is_dynamic(&self) -> bool { 24 | *self == Self::MAX 25 | } 26 | } 27 | 28 | impl DynVal for f32 { 29 | #[inline] 30 | fn default_dyn() -> Self { 31 | Self::INFINITY 32 | } 33 | #[inline] 34 | fn is_dynamic(&self) -> bool { 35 | self.is_infinite() && self.is_sign_positive() 36 | } 37 | } 38 | 39 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 40 | #[repr(transparent)] 41 | pub struct MaybeDyn(pub T); 42 | 43 | impl From for MaybeDyn { 44 | #[inline] 45 | fn from(value: T) -> Self { 46 | Self(value) 47 | } 48 | } 49 | 50 | impl MaybeDyn { 51 | #[inline] 52 | pub fn dynamic() -> Self { 53 | Self(T::default_dyn()) 54 | } 55 | #[inline] 56 | pub fn is_dynamic(&self) -> bool { 57 | self.0.is_dynamic() 58 | } 59 | #[inline] 60 | pub fn get_static(&self) -> Option<&T> { 61 | if !self.is_dynamic() { 62 | Some(&self.0) 63 | } else { 64 | None 65 | } 66 | } 67 | } 68 | 69 | #[inline(always)] 70 | pub fn dyn_() -> MaybeDyn { 71 | MaybeDyn::dynamic() 72 | } 73 | 74 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 75 | pub enum MergeError { 76 | EmptyIter, 77 | NotMatch, 78 | } 79 | 80 | impl MaybeDyn { 81 | pub fn merge<'a>(iter: impl IntoIterator) -> Result<&'a Self, MergeError> { 82 | let mut iter = iter.into_iter(); 83 | let mut acc = iter.next().ok_or(MergeError::EmptyIter)?; 84 | for it in iter { 85 | if it.is_dynamic() { 86 | // Nothing to do 87 | } else if acc.is_dynamic() { 88 | acc = it; 89 | } else if acc.0 != it.0 { 90 | return Err(MergeError::NotMatch); 91 | } 92 | } 93 | Ok(acc) 94 | } 95 | 96 | pub fn get_all(slice: &[Self]) -> Option<&[T]> { 97 | if slice.iter().any(|arg| arg.is_dynamic()) { 98 | None 99 | } else { 100 | Some(unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), slice.len()) }) 101 | } 102 | } 103 | } 104 | 105 | #[inline] 106 | pub(crate) fn static_from(arg: &MaybeDyn) -> Result<&T, SchemeError> { 107 | arg.get_static().ok_or_else(|| dyn_not_support("")) 108 | } 109 | 110 | macro_rules! 
get_static { 111 | ($($name:ident)*) => { 112 | $( let $name = *$crate::static_from(&$name)?; )* 113 | }; 114 | } 115 | 116 | pub(crate) use get_static; 117 | 118 | use super::{dyn_not_support, SchemeError}; 119 | -------------------------------------------------------------------------------- /operators/src/common/mod.rs: -------------------------------------------------------------------------------- 1 | mod blob; 2 | mod calculator; 3 | mod diversity; 4 | mod error; 5 | mod maybe_dyn; 6 | mod pool; 7 | mod tensor; 8 | mod unsigned; 9 | mod workspace; 10 | 11 | pub use blob::Blob; 12 | pub use calculator::OffsetCalculator; 13 | pub use error::{functions::*, LaunchError, LaunchErrorKind, SchemeError, SchemeErrorKind}; 14 | pub use maybe_dyn::{dyn_, DynVal, MaybeDyn}; 15 | pub use pool::Pool; 16 | pub use tensor::TensorLayout; 17 | pub use unsigned::Unsigned; 18 | pub use workspace::Workspace; 19 | 20 | pub(crate) use diversity::{SchemeCacheSize, SchemeDiversity}; 21 | pub(crate) use maybe_dyn::{get_static, static_from}; 22 | pub(crate) use workspace::WorkspaceCollector; 23 | 24 | pub mod utils { 25 | use super::{rank_not_support, shape_mismatch, type_mismatch, MaybeDyn, SchemeError}; 26 | use digit_layout::DigitLayout; 27 | 28 | #[cfg(any(use_cuda, use_cl))] 29 | #[inline] 30 | pub(crate) const fn gcd(mut a: usize, mut b: usize) -> usize { 31 | while b != 0 { 32 | let rem = a % b; 33 | a = b; 34 | b = rem; 35 | } 36 | a 37 | } 38 | 39 | #[inline] 40 | pub(crate) fn type_distinct(pairs: &[DigitLayout]) -> Result { 41 | let [dt, tail @ ..] = pairs else { 42 | unreachable!("pairs empty"); 43 | }; 44 | if tail.iter().all(|it| it == dt) { 45 | Ok(*dt) 46 | } else { 47 | Err(type_mismatch(format!("{pairs:?} are not distinct"))) 48 | } 49 | } 50 | 51 | #[inline] 52 | pub(crate) fn rank_error(arg: &str, expected: usize, actual: usize) -> SchemeError { 53 | rank_not_support(format!("{arg}.ndim = {actual}, {expected} expected")) 54 | } 55 | 56 | #[inline] 57 | pub(crate) fn dim_distinct(args: &[MaybeDyn]) -> Result, SchemeError> { 58 | MaybeDyn::merge(args) 59 | .copied() 60 | .map_err(|_| shape_mismatch(format!("{args:?} are not distinct"))) 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | #[allow(dead_code)] 66 | pub(crate) mod test_utils { 67 | use std::fmt; 68 | 69 | pub struct Diff { 70 | pub abs: f64, 71 | pub rel: f64, 72 | } 73 | 74 | impl Diff { 75 | pub fn new(a: f64, b: f64) -> Self { 76 | let abs = (a - b).abs(); 77 | let rel = abs / (a.abs() + b.abs() + f64::EPSILON); 78 | Self { abs, rel } 79 | } 80 | } 81 | 82 | pub struct ErrorCollector { 83 | threshold: Diff, 84 | max_diff: Diff, 85 | outliers: Vec, 86 | count: usize, 87 | } 88 | 89 | impl ErrorCollector { 90 | pub fn new(abs: f64, rel: f64) -> Self { 91 | Self { 92 | threshold: Diff { abs, rel }, 93 | max_diff: Diff { abs: 0., rel: 0. 
}, 94 | outliers: vec![], 95 | count: 0, 96 | } 97 | } 98 | 99 | pub fn push(&mut self, diff: Diff) { 100 | self.max_diff.abs = f64::max(self.max_diff.abs, diff.abs); 101 | self.max_diff.rel = f64::max(self.max_diff.rel, diff.rel); 102 | 103 | if diff.abs > self.threshold.abs && diff.rel > self.threshold.rel { 104 | self.outliers.push(self.count); 105 | } 106 | 107 | self.count += 1; 108 | } 109 | 110 | pub fn summary(self) -> (usize, usize) { 111 | (self.outliers.len(), self.count) 112 | } 113 | 114 | pub fn outliers(&self) -> &[usize] { 115 | &self.outliers 116 | } 117 | } 118 | 119 | impl fmt::Display for ErrorCollector { 120 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 121 | write!( 122 | f, 123 | "abs: {:.3e}, rel: {:.3e}, outliers: {}/{}", 124 | self.max_diff.abs, 125 | self.max_diff.rel, 126 | self.outliers.len(), 127 | self.count, 128 | ) 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /operators/src/common/pool.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | alloc::{alloc, dealloc, Layout}, 3 | ptr::null_mut, 4 | sync::atomic::{ 5 | AtomicPtr, 6 | Ordering::{Acquire, Release}, 7 | }, 8 | }; 9 | 10 | #[repr(transparent)] 11 | pub struct Pool(AtomicPtr>); 12 | 13 | struct Item { 14 | value: T, 15 | next: *mut Item, 16 | } 17 | 18 | impl Default for Pool { 19 | #[inline] 20 | fn default() -> Self { 21 | Self::new() 22 | } 23 | } 24 | 25 | impl Pool { 26 | #[inline] 27 | pub fn new() -> Self { 28 | Self(AtomicPtr::new(null_mut())) 29 | } 30 | 31 | #[inline] 32 | fn update(&self, current: *mut Item, new: *mut Item) -> Option<*mut Item> { 33 | self.0 34 | .compare_exchange_weak(current, new, Release, Acquire) 35 | .err() 36 | } 37 | 38 | pub fn push(&self, value: T) { 39 | let item = unsafe { alloc(Layout::new::>()) } as *mut Item; 40 | unsafe { 41 | item.write(Item { 42 | value, 43 | next: self.0.load(Acquire), 44 | }) 45 | }; 46 | while let Some(current) = self.update(unsafe { (*item).next }, item) { 47 | unsafe { (*item).next = current }; 48 | } 49 | } 50 | 51 | pub fn pop(&self) -> Option { 52 | let mut item = self.0.load(Acquire); 53 | while !item.is_null() { 54 | if let Some(current) = self.update(item, unsafe { (*item).next }) { 55 | item = current; 56 | } else { 57 | break; 58 | } 59 | } 60 | 61 | if item.is_null() { 62 | None 63 | } else { 64 | let Item { value, .. 
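                // the node has already been unlinked by the CAS loop above, so its value
                // can be moved out and the raw allocation released immediately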
} = unsafe { item.read() }; 65 | unsafe { dealloc(item as _, Layout::new::>()) }; 66 | Some(value) 67 | } 68 | } 69 | } 70 | 71 | impl Drop for Pool { 72 | fn drop(&mut self) { 73 | while self.pop().is_some() {} 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /operators/src/common/tensor.rs: -------------------------------------------------------------------------------- 1 | use crate::MaybeDyn; 2 | use digit_layout::DigitLayout; 3 | use ndarray_layout::ArrayLayout; 4 | use std::{ 5 | alloc::{alloc, dealloc, Layout}, 6 | ptr::{copy_nonoverlapping, NonNull}, 7 | slice::from_raw_parts, 8 | }; 9 | 10 | /// | field | type | 11 | /// |:--------:|:-------------:| 12 | /// | dt | DigitLayout | 13 | /// | ndim | u64 | 14 | /// | shape | [usize; ndim] | 15 | /// | strides | [isize; ndim] | 16 | #[repr(transparent)] 17 | pub struct TensorLayout(NonNull); 18 | 19 | impl TensorLayout { 20 | pub fn new_dyn( 21 | dt: DigitLayout, 22 | shape: &[MaybeDyn], 23 | strides: &[MaybeDyn], 24 | ) -> Self { 25 | let shape: &[usize] = unsafe { std::mem::transmute(shape) }; 26 | let strides: &[isize] = unsafe { std::mem::transmute(strides) }; 27 | Self::new(dt, shape, strides) 28 | } 29 | 30 | pub fn new(dt: DigitLayout, shape: &[usize], strides: &[isize]) -> Self { 31 | assert_eq!(shape.len(), strides.len()); 32 | 33 | unsafe { 34 | let ptr = alloc(Self::layout(shape.len())); 35 | 36 | let cursor: *mut DigitLayout = ptr.cast(); 37 | cursor.write(dt); 38 | let cursor: *mut u64 = cursor.add(1).cast(); 39 | cursor.write(shape.len() as _); 40 | let cursor: *mut usize = cursor.add(1).cast(); 41 | copy_nonoverlapping(shape.as_ptr(), cursor, shape.len()); 42 | let cursor: *mut isize = cursor.add(shape.len()).cast(); 43 | copy_nonoverlapping(strides.as_ptr(), cursor, strides.len()); 44 | 45 | Self(NonNull::new_unchecked(ptr as _)) 46 | } 47 | } 48 | 49 | pub fn new_contiguous(dt: DigitLayout, shape: &[usize]) -> Self { 50 | let mut strides = shape 51 | .iter() 52 | .rev() 53 | .scan(dt.nbytes() as isize, |mul, &d| { 54 | let stride = *mul; 55 | *mul *= d as isize; 56 | Some(stride) 57 | }) 58 | .collect::>(); 59 | strides.reverse(); 60 | Self::new(dt, shape, &strides) 61 | } 62 | 63 | #[inline] 64 | pub fn from_arr(dt: DigitLayout, arr: &ArrayLayout) -> Self { 65 | Self::new(dt, arr.shape(), arr.strides()) 66 | } 67 | 68 | #[inline] 69 | pub fn dt(&self) -> DigitLayout { 70 | let ptr = self.0.cast(); 71 | unsafe { *ptr.as_ref() } 72 | } 73 | 74 | #[inline] 75 | pub fn ndim(&self) -> usize { 76 | let ptr = self.0.cast::().as_ptr(); 77 | unsafe { *ptr.add(1) as _ } 78 | } 79 | 80 | #[inline] 81 | pub fn shape(&self) -> &[MaybeDyn] { 82 | let ptr = self.0.cast::>().as_ptr(); 83 | let len = self.ndim(); 84 | unsafe { from_raw_parts(ptr.add(2), len) } 85 | } 86 | 87 | #[inline] 88 | pub fn strides(&self) -> &[MaybeDyn] { 89 | let ptr = self.0.cast::>().as_ptr(); 90 | let len = self.ndim(); 91 | unsafe { from_raw_parts(ptr.add(2 + len), len) } 92 | } 93 | 94 | #[inline(always)] 95 | fn layout(ndim: usize) -> Layout { 96 | Layout::array::(2 + ndim * 2).unwrap() 97 | } 98 | } 99 | 100 | impl Clone for TensorLayout { 101 | #[inline] 102 | fn clone(&self) -> Self { 103 | let layout = Self::layout(self.ndim()); 104 | let src = self.0.cast::().as_ptr(); 105 | unsafe { 106 | let dst = alloc(layout); 107 | copy_nonoverlapping(src, dst, layout.size()); 108 | Self(NonNull::new_unchecked(dst as _)) 109 | } 110 | } 111 | } 112 | 113 | impl Drop for TensorLayout { 114 | #[inline] 115 | fn 
drop(&mut self) { 116 | let ptr = self.0.cast().as_ptr(); 117 | let layout = Self::layout(self.ndim()); 118 | unsafe { dealloc(ptr, layout) } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /operators/src/common/unsigned.rs: -------------------------------------------------------------------------------- 1 | pub trait Unsigned { 2 | fn from(u: usize) -> Self; 3 | fn val(&self) -> usize; 4 | } 5 | 6 | macro_rules! impl_idx { 7 | ($( $ty:ty )+) => { 8 | $( 9 | impl Unsigned for $ty { 10 | #[inline] 11 | fn from(u: usize) -> Self { 12 | u as _ 13 | } 14 | 15 | #[inline] 16 | fn val(&self) -> usize { 17 | *self as _ 18 | } 19 | } 20 | )+ 21 | }; 22 | } 23 | 24 | impl_idx! { u8 u16 u32 u64 u128 usize } 25 | -------------------------------------------------------------------------------- /operators/src/common/workspace.rs: -------------------------------------------------------------------------------- 1 | use crate::{ByteOf, QueueAlloc}; 2 | use std::{ 3 | mem::ManuallyDrop, 4 | ops::{Deref, DerefMut}, 5 | }; 6 | 7 | pub enum Workspace<'a, QA: QueueAlloc> { 8 | Ext(&'a mut [ByteOf]), 9 | Int(ManuallyDrop, &'a QA), 10 | } 11 | 12 | impl<'a, QA: QueueAlloc> Workspace<'a, QA> { 13 | #[inline] 14 | pub fn new(queue_alloc: &'a QA, ext: &'a mut [ByteOf], size: usize) -> Self { 15 | if ext.len() >= size { 16 | Self::Ext(ext) 17 | } else { 18 | let dev_mem = queue_alloc.alloc(size); 19 | Self::Int(ManuallyDrop::new(dev_mem), queue_alloc) 20 | } 21 | } 22 | } 23 | 24 | impl Deref for Workspace<'_, QA> { 25 | type Target = [ByteOf]; 26 | #[inline] 27 | fn deref(&self) -> &Self::Target { 28 | match self { 29 | Self::Ext(ext) => ext, 30 | Self::Int(dev_mem, _) => dev_mem, 31 | } 32 | } 33 | } 34 | 35 | impl DerefMut for Workspace<'_, QA> { 36 | #[inline] 37 | fn deref_mut(&mut self) -> &mut Self::Target { 38 | match self { 39 | Self::Ext(ext) => ext, 40 | Self::Int(dev_mem, _) => dev_mem, 41 | } 42 | } 43 | } 44 | 45 | impl Drop for Workspace<'_, QA> { 46 | fn drop(&mut self) { 47 | match self { 48 | Self::Ext(_) => {} 49 | Self::Int(dev_mem, qa) => qa.free(unsafe { ManuallyDrop::take(dev_mem) }), 50 | } 51 | } 52 | } 53 | 54 | pub(crate) struct WorkspaceCollector { 55 | base: Vec, 56 | sub: usize, 57 | } 58 | 59 | impl WorkspaceCollector { 60 | #[inline] 61 | pub fn new() -> Self { 62 | Self { 63 | base: Vec::with_capacity(2), 64 | sub: 0, 65 | } 66 | } 67 | 68 | #[inline] 69 | pub fn push_base(&mut self, base: usize) { 70 | self.base.push(base) 71 | } 72 | 73 | #[inline] 74 | pub fn push_sub(&mut self, sub: usize) { 75 | self.sub = self.sub.max(sub) 76 | } 77 | 78 | pub fn cauculate(mut self, max_workspace_size: usize) -> usize { 79 | self.base.push(self.sub); 80 | let mut ans = 0; 81 | for s in self.base { 82 | if ans + s <= max_workspace_size { 83 | ans += s; 84 | } else { 85 | return ans; 86 | } 87 | } 88 | ans 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /operators/src/conv/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | utils::{dim_distinct, rank_error, type_distinct}, 3 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 4 | }; 5 | use digit_layout::DigitLayout; 6 | 7 | pub struct Args { 8 | pub y_layout: TensorLayout, 9 | pub y_base: MutPtr, 10 | pub x_layout: TensorLayout, 11 | pub x_base: ConstPtr, 12 | pub w_layout: TensorLayout, 13 | pub w_base: ConstPtr, 14 | pub b_layout: TensorLayout, 15 | pub b_base: 
ConstPtr, 16 | pub strides: [usize; 2], 17 | pub dilations: [usize; 2], 18 | pub pads: [usize; 4], 19 | } 20 | 21 | pub(crate) struct Meta { 22 | pub dt: DigitLayout, 23 | pub n: MaybeDyn, 24 | pub m: MaybeDyn, 25 | pub c: MaybeDyn, 26 | pub h: MaybeDyn, 27 | pub w: MaybeDyn, 28 | pub hy: MaybeDyn, 29 | pub wy: MaybeDyn, 30 | pub hk: MaybeDyn, 31 | pub wk: MaybeDyn, 32 | } 33 | 34 | impl Args { 35 | pub(super) fn meta(&self) -> Result { 36 | let Self { 37 | y_layout, 38 | x_layout, 39 | w_layout, 40 | b_layout, 41 | .. 42 | } = self; 43 | 44 | let &[ny, my, hy, wy] = y_layout.shape() else { 45 | return Err(rank_error("y", 4, y_layout.ndim())); 46 | }; 47 | let &[n, c, h, w] = x_layout.shape() else { 48 | return Err(rank_error("x", 4, x_layout.ndim())); 49 | }; 50 | let &[m, ck, hk, wk] = w_layout.shape() else { 51 | return Err(rank_error("w", 4, w_layout.ndim())); 52 | }; 53 | let &[mb] = b_layout.shape() else { 54 | return Err(rank_error("b", 1, b_layout.ndim())); 55 | }; 56 | 57 | Ok(Meta { 58 | dt: type_distinct(&[y_layout.dt(), x_layout.dt(), w_layout.dt(), b_layout.dt()])?, 59 | n: dim_distinct(&[n, ny])?, 60 | m: dim_distinct(&[m, my, mb])?, 61 | c: dim_distinct(&[c, ck])?, 62 | h, 63 | w, 64 | hy, 65 | wy, 66 | hk, 67 | wk, 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /operators/src/conv/common_cpu.rs: -------------------------------------------------------------------------------- 1 | im2col!(common_cpu, Cpu); 2 | -------------------------------------------------------------------------------- /operators/src/conv/cuda.rs: -------------------------------------------------------------------------------- 1 | im2col!(cuda, Gpu); 2 | -------------------------------------------------------------------------------- /operators/src/conv/infini.rs: -------------------------------------------------------------------------------- 1 | im2col!(infini, Device); 2 | -------------------------------------------------------------------------------- /operators/src/conv/mod.rs: -------------------------------------------------------------------------------- 1 | mod args; 2 | mod im2col; 3 | 4 | pub use args::Args; 5 | 6 | crate::op_trait!(Conv); 7 | 8 | macro_rules! 
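// Each backend module (common_cpu / cuda / infini / opencl) expands this to a
// `ConvIm2Col` type alias that wires its own `rearrange` and `mat_mul` operators
// into the generic im2col convolution operator.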
im2col { 9 | ($dev:ident, $proc:ident) => { 10 | pub type ConvIm2Col = super::im2col::Operator< 11 | crate::$dev::$proc, 12 | crate::rearrange::$dev::Operator, 13 | crate::mat_mul::$dev::Operator, 14 | >; 15 | }; 16 | } 17 | 18 | #[cfg(any(use_cpu, test))] 19 | pub mod common_cpu; 20 | #[cfg(use_cuda)] 21 | pub mod cuda; 22 | #[cfg(use_infini)] 23 | pub mod infini; 24 | #[cfg(use_cl)] 25 | pub mod opencl; 26 | -------------------------------------------------------------------------------- /operators/src/conv/opencl.rs: -------------------------------------------------------------------------------- 1 | im2col!(opencl, ClDevice); 2 | -------------------------------------------------------------------------------- /operators/src/fuesd_softmax/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{rank_not_support, Hardware, MutPtr, SchemeError, TensorLayout}; 2 | use digit_layout::DigitLayout; 3 | use std::ptr::null_mut; 4 | 5 | pub struct Args { 6 | pub att_mask: AttnMask, 7 | pub att_layout: TensorLayout, 8 | pub att_base: MutPtr, 9 | } 10 | 11 | #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] 12 | #[repr(u8)] 13 | pub enum AttnMask { 14 | None, 15 | Causal, 16 | } 17 | 18 | pub(super) struct Meta { 19 | pub dt: DigitLayout, 20 | } 21 | 22 | impl Args { 23 | pub fn new_null(att_mask: AttnMask, att_layout: TensorLayout) -> Self { 24 | Self { 25 | att_mask, 26 | att_layout, 27 | att_base: null_mut(), 28 | } 29 | } 30 | 31 | pub(super) fn meta(&self) -> Result { 32 | let dt = self.att_layout.dt(); 33 | if self.att_layout.ndim() != 3 { 34 | return Err(rank_not_support("")); 35 | } 36 | Ok(Meta { dt }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /operators/src/fuesd_softmax/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{ 2 | args::{AttnMask, Meta}, 3 | Args, FusedSoftmax, 4 | }; 5 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError}; 6 | use half::f16; 7 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 8 | 9 | pub struct Operator; 10 | 11 | impl FusedSoftmax for Operator {} 12 | 13 | impl crate::Operator for Operator { 14 | type Hardware = Cpu; 15 | type TopoNode = Cpu; 16 | type Args = Args; 17 | 18 | #[inline] 19 | fn new(_node: &Self::TopoNode) -> Self { 20 | Self 21 | } 22 | 23 | fn scheme( 24 | &mut self, 25 | args: &Self::Args, 26 | _max_workspace_size: usize, 27 | ) -> Result { 28 | let _meta = args.meta()?; 29 | Ok(0) 30 | } 31 | 32 | fn launch( 33 | &self, 34 | args: &Self::Args, 35 | _workspace: &mut [ByteOf], 36 | _queue_alloc: &QA, 37 | ) -> Result<(), LaunchError> 38 | where 39 | QA: QueueAlloc, 40 | { 41 | let Meta { dt } = args.meta()?; 42 | let Args { 43 | att_mask, 44 | att_layout, 45 | att_base, 46 | } = args; 47 | let &[nh, seq_len, att_len] = att_layout.shape() else { 48 | unreachable!() 49 | }; 50 | let &[sh, ss, sa] = att_layout.strides() else { 51 | unreachable!() 52 | }; 53 | 54 | get_static! { 55 | nh seq_len att_len 56 | sh ss sa 57 | } 58 | 59 | macro_rules! 
calculate { 60 | ($ty:ty) => { 61 | Scheme::<$ty> { 62 | nh, 63 | seq_len, 64 | att_len, 65 | sh, 66 | ss, 67 | sa, 68 | att_base: att_base.cast(), 69 | } 70 | .calculate(*att_mask) 71 | }; 72 | } 73 | 74 | use digit_layout::types as ty; 75 | match dt { 76 | ty::F16 => calculate!(f16), 77 | ty::F32 => calculate!(f32), 78 | ty::F64 => calculate!(f64), 79 | _ => todo!(), 80 | } 81 | Ok(()) 82 | } 83 | } 84 | 85 | struct Scheme { 86 | nh: usize, 87 | seq_len: usize, 88 | att_len: usize, 89 | sh: isize, 90 | ss: isize, 91 | sa: isize, 92 | att_base: *mut T, 93 | } 94 | 95 | unsafe impl Send for Scheme {} 96 | unsafe impl Sync for Scheme {} 97 | 98 | impl Scheme { 99 | fn loop_(&self, mask: AttnMask, f: impl Sync + Fn(isize, *mut T)) { 100 | let nh = self.nh as isize; 101 | let seq_len = self.seq_len as isize; 102 | let att_len = self.att_len as isize; 103 | 104 | (0..nh * seq_len).into_par_iter().for_each(|i| { 105 | let j = i / seq_len; 106 | let k = i % seq_len; 107 | let att = unsafe { self.att_base.byte_offset(j * self.sh + k * self.ss) }; 108 | let causal = match mask { 109 | AttnMask::None => att_len, 110 | AttnMask::Causal => att_len - seq_len + k + 1, 111 | }; 112 | f(causal, att) 113 | }); 114 | } 115 | } 116 | 117 | impl Scheme { 118 | fn calculate(&self, mask: AttnMask) { 119 | let att_len = self.att_len as isize; 120 | self.loop_(mask, |causal, att| { 121 | let att = |k| unsafe { &mut *att.byte_offset(k * self.sa) }; 122 | 123 | let max = (0..causal) 124 | .map(att) 125 | .max_by(|a, b| a.total_cmp(b)) 126 | .unwrap() 127 | .to_f32(); 128 | 129 | let div = (0..causal) 130 | .map(att) 131 | .map(|x| { 132 | let exp = (x.to_f32() - max).exp(); 133 | *x = f16::from_f32(exp); 134 | exp 135 | }) 136 | .sum::() 137 | .recip(); 138 | 139 | (0..causal) 140 | .map(att) 141 | .for_each(|x| *x = f16::from_f32(x.to_f32() * div)); 142 | (causal..att_len).map(att).for_each(|x| *x = f16::ZERO); 143 | }); 144 | } 145 | } 146 | 147 | impl Scheme { 148 | fn calculate(&self, mask: AttnMask) { 149 | let att_len = self.att_len as isize; 150 | self.loop_(mask, |causal, att| { 151 | let att = |k| unsafe { &mut *att.byte_offset(k * self.sa) }; 152 | 153 | let max = *(0..causal).map(att).max_by(|a, b| a.total_cmp(b)).unwrap(); 154 | 155 | let div = (0..causal) 156 | .map(att) 157 | .map(|x| { 158 | let exp = (*x - max).exp(); 159 | *x = exp; 160 | exp 161 | }) 162 | .sum::() 163 | .recip(); 164 | 165 | (0..causal).map(att).for_each(|x| *x *= div); 166 | (causal..att_len).map(att).for_each(|x| *x = 0.); 167 | }); 168 | } 169 | } 170 | 171 | impl Scheme { 172 | fn calculate(&self, mask: AttnMask) { 173 | let att_len = self.att_len as isize; 174 | self.loop_(mask, |causal, att| { 175 | let att = |k| unsafe { &mut *att.byte_offset(k * self.sa) }; 176 | 177 | let max = *(0..causal).map(att).max_by(|a, b| a.total_cmp(b)).unwrap(); 178 | 179 | let div = (0..causal) 180 | .map(att) 181 | .map(|x| { 182 | let exp = (*x - max).exp(); 183 | *x = exp; 184 | exp 185 | }) 186 | .sum::() 187 | .recip(); 188 | 189 | (0..causal).map(att).for_each(|x| *x *= div); 190 | (causal..att_len).map(att).for_each(|x| *x = 0.); 191 | }); 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /operators/src/fuesd_softmax/cuda/fused_softmax.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | struct AttentionNonMask { 4 | __forceinline__ __device__ bool 5 | operator()(int tok_id, int seq_len, 6 | int pos_id, int att_len) { 7 | return 
true; 8 | } 9 | }; 10 | 11 | struct AttentionCausalMask { 12 | __forceinline__ __device__ bool 13 | operator()(int tok_id, int seq_len, 14 | int pos_id, int att_len) { 15 | // tok_id ↓ |<---att_len--->| 16 | // 0 | * * ... * | 17 | // 1 | * * ... * * | 18 | // 2 | * * ... * * * | 19 | // seq_len: 3 |---------------| 20 | return att_len + tok_id >= pos_id + seq_len; 21 | } 22 | }; 23 | 24 | template 25 | static __device__ void block_padding( 26 | Tdata *__restrict__ att, 27 | Tmask mask, 28 | unsigned int const tok_id, 29 | unsigned int const seq_len) { 30 | 31 | auto att_idx = threadIdx.x, att_len = blockDim.x; 32 | auto thread_data = mask(tok_id, seq_len, att_idx, att_len) 33 | ? float(att[att_idx]) 34 | : -__FLT_MAX__; 35 | 36 | using BlockOp = cub::BlockReduce; 37 | __shared__ typename BlockOp::TempStorage temp_storage; 38 | auto block_op = BlockOp(temp_storage); 39 | 40 | __shared__ float max; 41 | { 42 | auto acc = block_op.Reduce(thread_data, cub::Max(), att_len); 43 | if (threadIdx.x == 0) { max = acc; } 44 | } 45 | __syncthreads(); 46 | 47 | __shared__ float mean; 48 | { 49 | auto acc = block_op.Sum(thread_data = expf(thread_data - max), att_len); 50 | if (threadIdx.x == 0) { mean = fdividef(1, acc); } 51 | } 52 | __syncthreads(); 53 | 54 | att[att_idx] = Tdata(thread_data * mean); 55 | } 56 | 57 | template 58 | static __device__ void block_folding( 59 | Tdata *__restrict__ att, 60 | Tmask mask, 61 | unsigned int const tok_id, 62 | unsigned int const seq_len, 63 | unsigned int const att_len) { 64 | // num items per thread 65 | auto local = (att_len + blockDim.x - 1) / blockDim.x; 66 | // shared memory for thread data 67 | // local ↓ |<----blockDim.x---->| 68 | // | T0 | T1 | ... | TN | 69 | // | T0 | T1 | ... | TN | 70 | // 每个线程纵向使用以避免 bank conflict 71 | extern __shared__ float data_[]; 72 | 73 | auto thread_data = data_ + threadIdx.x; 74 | auto thread_offset = threadIdx.x * local; 75 | att += thread_offset; 76 | 77 | float thread_max = -__FLT_MAX__; 78 | for (unsigned int i = 0; i < local; ++i) { 79 | auto att_idx = thread_offset + i; 80 | auto val = att_idx < att_len && mask(tok_id, seq_len, att_idx, att_len) 81 | ? 
float(att[i]) 82 | : -__FLT_MAX__; 83 | thread_data[i * blockDim.x] = val; 84 | thread_max = cub::Max()(thread_max, val); 85 | } 86 | 87 | using BlockOp = cub::BlockReduce; 88 | __shared__ typename BlockOp::TempStorage temp_storage; 89 | auto block_op = BlockOp(temp_storage); 90 | 91 | __shared__ float max; 92 | { 93 | auto acc = block_op.Reduce(thread_max, cub::Max()); 94 | if (threadIdx.x == 0) { max = acc; } 95 | } 96 | __syncthreads(); 97 | 98 | __shared__ float mean; 99 | { 100 | float thread_sum = 0; 101 | for (unsigned int i = 0; i < local; ++i) { 102 | auto &val = thread_data[i * blockDim.x]; 103 | thread_sum += (val = expf(val - max)); 104 | } 105 | auto acc = block_op.Sum(thread_sum); 106 | if (threadIdx.x == 0) { mean = fdividef(1, acc); } 107 | } 108 | __syncthreads(); 109 | 110 | for (unsigned int i = 0; i < local; ++i) { 111 | if (auto att_idx = thread_offset + i; att_idx < att_len) { 112 | att[i] = Tdata(thread_data[i * blockDim.x] * mean); 113 | } 114 | } 115 | } 116 | 117 | // assert BLOCK_SIZE >= blockDim.x 118 | template 119 | static __forceinline__ __device__ void padding( 120 | Tdata *__restrict__ att, 121 | Tmask mask, 122 | int const stride_z, 123 | int const stride_y, 124 | int const stride_x) { 125 | auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y + blockIdx.z * stride_z, 126 | tok_id = blockIdx.x, 127 | seq_len = gridDim.x; 128 | block_padding(att + offset, mask, tok_id, seq_len); 129 | } 130 | 131 | template 132 | static __forceinline__ __device__ void folding( 133 | Tdata *__restrict__ att, 134 | Tmask mask, 135 | unsigned int const att_len, 136 | int const stride_z, 137 | int const stride_y, 138 | int const stride_x) { 139 | auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y + blockIdx.z * stride_z, 140 | tok_id = blockIdx.x, 141 | seq_len = gridDim.x; 142 | block_folding(att + offset, mask, tok_id, seq_len, att_len); 143 | } 144 | -------------------------------------------------------------------------------- /operators/src/fuesd_softmax/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::{Args, AttnMask}; 12 | 13 | crate::op_trait!(FusedSoftmax); 14 | -------------------------------------------------------------------------------- /operators/src/fuesd_softmax/opencl/fused_softmax.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Tval 5 | #define Tval float 6 | #endif 7 | 8 | #ifndef ITEMS_THREAD 9 | #define ITEMS_THREAD 8 10 | #endif 11 | 12 | #ifndef MASK 13 | #define MASK causal_mask 14 | #endif 15 | 16 | typedef unsigned int Tidx; 17 | 18 | bool causal_mask(Tidx tok_id, Tidx seq_len, 19 | Tidx pos_id, Tidx att_len) { 20 | // tok_id ↓ |<---att_len--->| 21 | // 0 | * * ... * | 22 | // 1 | * * ... * * | 23 | // 2 | * * ... 
* * * | 24 | // seq_len: 3 |---------------| 25 | return att_len + tok_id >= pos_id + seq_len; 26 | } 27 | 28 | kernel void softmax_register( 29 | global Tval *att_, 30 | Tidx const seq_len, 31 | Tidx const att_len, 32 | int const head_stride, 33 | int const tok_stride) { 34 | 35 | Tidx const 36 | head_idx = get_group_id(1), 37 | tok_id = get_group_id(0), 38 | l_idx = get_local_id(0), 39 | l_len = get_local_size(0); 40 | 41 | global Tval *att = att_ + head_idx * head_stride + tok_id * tok_stride; 42 | 43 | float 44 | data[ITEMS_THREAD], 45 | max_ = -FLT_MAX, 46 | sum_ = 0; 47 | 48 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { 49 | data[i] = causal_mask(tok_id, seq_len, idx, att_len) ? att[idx] : -FLT_MAX; 50 | max_ = fmax(max_, data[i]); 51 | } 52 | 53 | max_ = work_group_reduce_max(max_); 54 | 55 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { 56 | data[i] = exp(data[i] - max_); 57 | sum_ += data[i]; 58 | } 59 | 60 | barrier(CLK_LOCAL_MEM_FENCE); 61 | float const k = 1 / work_group_reduce_add(sum_); 62 | 63 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) 64 | att[idx] = data[i] * k; 65 | } 66 | 67 | kernel void softmax_global( 68 | global Tval *att_, 69 | Tidx const seq_len, 70 | Tidx const att_len, 71 | int const head_stride, 72 | int const tok_stride) { 73 | 74 | Tidx const 75 | head_idx = get_group_id(1), 76 | tok_id = get_group_id(0), 77 | l_idx = get_local_id(0), 78 | l_len = get_local_size(0); 79 | 80 | global Tval *att = att_ + head_idx * head_stride + tok_id * tok_stride; 81 | 82 | float 83 | max_ = -FLT_MAX, 84 | sum_ = 0; 85 | 86 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { 87 | float const data = causal_mask(tok_id, seq_len, idx, att_len) ? att[idx] : -FLT_MAX; 88 | max_ = fmax(max_, data); 89 | } 90 | 91 | max_ = work_group_reduce_max(max_); 92 | 93 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { 94 | float const data = exp(att[idx] - max_); 95 | att[idx] = data; 96 | sum_ += data; 97 | } 98 | 99 | barrier(CLK_LOCAL_MEM_FENCE); 100 | float const k = 1 / work_group_reduce_add(sum_); 101 | 102 | for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) 103 | att[idx] *= k; 104 | } 105 | -------------------------------------------------------------------------------- /operators/src/gelu/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{utils::rank_error, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout}; 2 | use digit_layout::DigitLayout; 3 | 4 | pub struct Args { 5 | pub layout: TensorLayout, 6 | pub base: MutPtr, 7 | } 8 | 9 | pub(super) struct Meta { 10 | pub dt: DigitLayout, 11 | pub n: MaybeDyn, 12 | pub d: MaybeDyn, 13 | } 14 | 15 | impl Args { 16 | pub fn new_layout(layout: TensorLayout) -> Self { 17 | use std::ptr::null_mut; 18 | Self { 19 | layout, 20 | base: null_mut(), 21 | } 22 | } 23 | 24 | pub(super) fn meta(&self) -> Result { 25 | let Self { layout, .. 
} = self; 26 | 27 | let &[n, d] = layout.shape() else { 28 | return Err(rank_error("layout", 2, layout.ndim())); 29 | }; 30 | 31 | Ok(Meta { 32 | dt: layout.dt(), 33 | n, 34 | d, 35 | }) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /operators/src/gelu/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Gelu}; 2 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use half::f16; 4 | 5 | pub struct Operator; 6 | 7 | impl Gelu for Operator {} 8 | 9 | impl crate::Operator for Operator { 10 | type Hardware = Cpu; 11 | type TopoNode = Cpu; 12 | type Args = Args; 13 | 14 | fn new(_node: &Self::TopoNode) -> Self { 15 | Self 16 | } 17 | 18 | fn scheme( 19 | &mut self, 20 | args: &Self::Args, 21 | _max_workspace_size: usize, 22 | ) -> Result { 23 | let _meta = args.meta()?; 24 | Ok(0) 25 | } 26 | 27 | fn launch( 28 | &self, 29 | args: &Self::Args, 30 | _workspace: &mut [ByteOf], 31 | _queue_alloc: &QA, 32 | ) -> Result<(), LaunchError> 33 | where 34 | QA: QueueAlloc, 35 | { 36 | let Meta { dt, n, d } = args.meta()?; 37 | let Args { layout, base } = args; 38 | let &[sn, sd] = layout.strides() else { 39 | unreachable!() 40 | }; 41 | 42 | get_static! { 43 | n d 44 | sn sd 45 | } 46 | 47 | macro_rules! calculate { 48 | ($ty:ty) => { 49 | Scheme::<$ty> { 50 | n, 51 | d, 52 | sn, 53 | sd, 54 | base: base.cast(), 55 | } 56 | .calculate() 57 | }; 58 | } 59 | 60 | use digit_layout::types as ty; 61 | match dt { 62 | ty::F16 => calculate!(f16), 63 | ty::F32 => calculate!(f32), 64 | ty::F64 => calculate!(f64), 65 | _ => todo!(), 66 | } 67 | Ok(()) 68 | } 69 | } 70 | 71 | struct Scheme { 72 | n: usize, 73 | d: usize, 74 | sn: isize, 75 | sd: isize, 76 | base: *mut T, 77 | } 78 | 79 | unsafe impl Send for Scheme {} 80 | unsafe impl Sync for Scheme {} 81 | 82 | impl Scheme { 83 | fn loop_(&self, f: impl Sync + Fn(T) -> T) { 84 | for i in 0..self.n as isize { 85 | (0..self.d as isize).for_each(|j| { 86 | let data = unsafe { &mut *self.base.byte_offset(i * self.sn + j * self.sd) }; 87 | *data = f(*data); 88 | }) 89 | } 90 | } 91 | } 92 | 93 | impl Scheme { 94 | #[inline] 95 | fn calculate(&self) { 96 | self.loop_(|base| f16::from_f32(gelu_f32(base.to_f32()))) 97 | } 98 | } 99 | 100 | impl Scheme { 101 | #[inline] 102 | fn calculate(&self) { 103 | self.loop_(gelu_f32) 104 | } 105 | } 106 | 107 | impl Scheme { 108 | #[inline] 109 | fn calculate(&self) { 110 | self.loop_(gelu_f64) 111 | } 112 | } 113 | 114 | #[inline(always)] 115 | fn gelu_f32(x: f32) -> f32 { 116 | use std::f32::consts::FRAC_2_PI; 117 | 0.5 * x * (1. + (FRAC_2_PI.sqrt() * (x + 0.044715 * x.powi(3))).tanh()) 118 | } 119 | 120 | #[inline(always)] 121 | fn gelu_f64(x: f64) -> f64 { 122 | use std::f64::consts::FRAC_2_PI; 123 | 0.5 * x * (1. 
+ (FRAC_2_PI.sqrt() * (x + 0.044715 * x.powi(3))).tanh()) 124 | } 125 | -------------------------------------------------------------------------------- /operators/src/gelu/cuda/gelu.cuh: -------------------------------------------------------------------------------- 1 | #ifndef M_SQRT1_2 2 | #define M_SQRT1_2 .707106781186547524401f 3 | #endif 4 | template 5 | static __device__ void gelu( 6 | Tdata *__restrict__ data) { 7 | auto i = blockIdx.x * blockDim.x + threadIdx.x; 8 | auto x = float(data[i]); 9 | data[i] = Tdata(0.5f * x * (1.0f + erf(x * M_SQRT1_2))); 10 | } -------------------------------------------------------------------------------- /operators/src/gelu/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Args, Gelu}; 2 | use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use infini_op::Handle; 4 | use std::sync::Arc; 5 | 6 | #[repr(transparent)] 7 | pub struct Operator(Arc); 8 | 9 | impl Gelu for Operator {} 10 | 11 | impl crate::Operator for Operator { 12 | type Hardware = Device; 13 | type TopoNode = Device; 14 | type Args = Args; 15 | 16 | fn new(_node: &Self::TopoNode) -> Self { 17 | todo!() 18 | } 19 | 20 | fn scheme( 21 | &mut self, 22 | _args: &Self::Args, 23 | _max_workspace_size: usize, 24 | ) -> Result { 25 | todo!() 26 | } 27 | 28 | fn launch( 29 | &self, 30 | _args: &Self::Args, 31 | _workspace: &mut [ByteOf], 32 | _queue_alloc: &QA, 33 | ) -> Result<(), LaunchError> 34 | where 35 | QA: QueueAlloc, 36 | { 37 | todo!() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /operators/src/gelu/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(Gelu); 14 | -------------------------------------------------------------------------------- /operators/src/gelu/opencl/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Args, Gelu}; 2 | use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl Gelu for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = ClDevice; 10 | type TopoNode = ClDevice; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/handle/common_cpu/inproc_node.rs: -------------------------------------------------------------------------------- 1 | use super::Cpu; 2 | use crate::TopoNode; 3 | use std::sync::{ 4 | atomic::{AtomicUsize, Ordering::Relaxed}, 5 | mpsc::{channel, Receiver, Sender}, 6 | Arc, Condvar, Mutex, 7 | }; 8 | 9 | pub struct InprocNode { 10 | rank: usize, 11 | senders: Box<[Sender]>, 12 | receiver: Arc>>, 13 | notifier: Arc, 14 | counter: Arc, 15 | } 16 | 17 | impl Clone for InprocNode { 18 | fn clone(&self) -> Self 
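    // a clone is the *same* logical participant: it keeps the same rank and shares the
    // same channels, barrier notifier and round counter (via `Sender::clone` and `Arc`)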
{ 19 | Self { 20 | rank: self.rank, 21 | senders: self.senders.clone(), 22 | receiver: self.receiver.clone(), 23 | notifier: self.notifier.clone(), 24 | counter: self.counter.clone(), 25 | } 26 | } 27 | } 28 | 29 | impl InprocNode { 30 | pub fn new(n: usize) -> Vec> { 31 | let mut senders = Vec::with_capacity(n); 32 | let mut receivers = Vec::with_capacity(n); 33 | for _ in 0..n { 34 | let (sender, receiver) = channel(); 35 | senders.push(sender); 36 | receivers.push(Arc::new(Mutex::new(receiver))); 37 | } 38 | let senders: Box<[Sender]> = senders.into(); 39 | let notifier = Arc::new(Notifier::new()); 40 | 41 | receivers 42 | .into_iter() 43 | .enumerate() 44 | .map(|(rank, receiver)| InprocNode { 45 | rank, 46 | senders: senders.clone(), 47 | receiver, 48 | notifier: notifier.clone(), 49 | counter: Arc::new(AtomicUsize::new(0)), 50 | }) 51 | .collect() 52 | } 53 | 54 | #[inline] 55 | pub(crate) fn send(&self, i: usize, msg: T) { 56 | self.senders[i].send(msg).unwrap(); 57 | } 58 | 59 | #[inline] 60 | pub(crate) fn recv(&self) -> T { 61 | self.receiver.lock().unwrap().recv().unwrap() 62 | } 63 | 64 | #[must_use] 65 | #[inline] 66 | pub(crate) fn wait(&self) -> Guard { 67 | let count = self.counter.fetch_add(1, Relaxed) * self.group_size(); 68 | self.notifier.wait(count) 69 | } 70 | } 71 | 72 | impl TopoNode for InprocNode { 73 | #[inline] 74 | fn processor(&self) -> &Cpu { 75 | &Cpu 76 | } 77 | #[inline] 78 | fn rank(&self) -> usize { 79 | self.rank 80 | } 81 | #[inline] 82 | fn group_size(&self) -> usize { 83 | self.senders.len() 84 | } 85 | } 86 | 87 | #[repr(transparent)] 88 | pub struct Guard<'a>(&'a Notifier); 89 | 90 | impl Drop for Guard<'_> { 91 | #[inline] 92 | fn drop(&mut self) { 93 | self.0.notify(); 94 | } 95 | } 96 | 97 | struct Notifier { 98 | lock: Mutex, 99 | cond: Condvar, 100 | } 101 | 102 | impl Notifier { 103 | fn new() -> Self { 104 | Self { 105 | lock: Mutex::new(0), 106 | cond: Condvar::new(), 107 | } 108 | } 109 | 110 | fn wait(&self, count: usize) -> Guard { 111 | let _guard = self 112 | .cond 113 | .wait_while(self.lock.lock().unwrap(), |current| *current < count) 114 | .unwrap(); 115 | Guard(self) 116 | } 117 | 118 | fn notify(&self) { 119 | *self.lock.lock().unwrap() += 1; 120 | self.cond.notify_all(); 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /operators/src/handle/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | mod inproc_node; 2 | 3 | use crate::{Alloc, Blob, Hardware, QueueAlloc, QueueOf}; 4 | 5 | pub use inproc_node::InprocNode; 6 | 7 | #[derive(Clone, Copy, Debug)] 8 | pub struct Cpu; 9 | 10 | #[derive(Clone, Copy, Debug)] 11 | pub struct ThisThread; 12 | 13 | impl Hardware for Cpu { 14 | type Byte = u8; 15 | type Queue<'ctx> = ThisThread; 16 | } 17 | 18 | impl Alloc for T { 19 | #[inline] 20 | fn alloc(&self, size: usize) -> Blob { 21 | Blob::new(size) 22 | } 23 | 24 | #[inline] 25 | fn free(&self, _mem: Blob) {} 26 | } 27 | 28 | impl QueueAlloc for ThisThread { 29 | type Hardware = Cpu; 30 | type DevMem = Blob; 31 | #[inline] 32 | fn queue(&self) -> &QueueOf { 33 | self 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/alloc.rs: -------------------------------------------------------------------------------- 1 | use super::Gpu; 2 | use crate::{Alloc, OffsetCalculator, QueueAlloc, QueueOf}; 3 | use cuda::{CurrentCtx, DevByte, DevMem, Stream}; 4 | use std::{ 5 | 
cell::RefCell, 6 | ops::{Deref, DerefMut, Range}, 7 | rc::Rc, 8 | }; 9 | 10 | pub struct StreamMemPool<'ctx> { 11 | stream: Stream<'ctx>, 12 | mem_pool: Rc>>, 13 | } 14 | 15 | pub struct MemPoolBlob<'ctx> { 16 | mem_pool: Rc>>, 17 | range: Range, 18 | } 19 | 20 | struct MemPool<'ctx> { 21 | pool: Vec>, 22 | recorder: OffsetCalculator, 23 | } 24 | 25 | impl Deref for MemPoolBlob<'_> { 26 | type Target = [DevByte]; 27 | #[inline] 28 | fn deref(&self) -> &Self::Target { 29 | unsafe { std::slice::from_raw_parts(self.range.start as _, self.range.len()) } 30 | } 31 | } 32 | 33 | impl DerefMut for MemPoolBlob<'_> { 34 | #[inline] 35 | fn deref_mut(&mut self) -> &mut Self::Target { 36 | unsafe { std::slice::from_raw_parts_mut(self.range.start as _, self.range.len()) } 37 | } 38 | } 39 | 40 | impl<'ctx> StreamMemPool<'ctx> { 41 | pub fn new(stream: Stream<'ctx>) -> Self { 42 | let alignment = stream.ctx().dev().alignment(); 43 | Self { 44 | stream, 45 | mem_pool: Rc::new(RefCell::new(MemPool { 46 | pool: Vec::new(), 47 | recorder: OffsetCalculator::new(if alignment == 0 { 256 } else { alignment }), 48 | })), 49 | } 50 | } 51 | 52 | pub fn put(&self, size: usize) { 53 | let blob = self.stream.ctx().malloc::(size); 54 | let area = blob.as_ptr_range(); 55 | let mut mem_pool = self.mem_pool.borrow_mut(); 56 | mem_pool.pool.push(blob); 57 | mem_pool.recorder.put(&(area.start as _..area.end as _)); 58 | } 59 | } 60 | 61 | impl<'ctx> Alloc> for StreamMemPool<'ctx> { 62 | #[inline] 63 | fn alloc(&self, size: usize) -> MemPoolBlob<'ctx> { 64 | let range = self 65 | .mem_pool 66 | .borrow_mut() 67 | .recorder 68 | .take(size) 69 | .expect("out of memory"); 70 | MemPoolBlob { 71 | mem_pool: self.mem_pool.clone(), 72 | range, 73 | } 74 | } 75 | 76 | #[inline] 77 | fn free(&self, mem: MemPoolBlob<'ctx>) { 78 | assert!(Rc::ptr_eq(&self.mem_pool, &mem.mem_pool)); 79 | self.mem_pool.borrow_mut().recorder.put(&mem.range) 80 | } 81 | } 82 | 83 | impl<'ctx> QueueAlloc for StreamMemPool<'ctx> { 84 | type Hardware = Gpu; 85 | type DevMem = MemPoolBlob<'ctx>; 86 | #[inline] 87 | fn queue(&self) -> &QueueOf { 88 | &self.stream 89 | } 90 | } 91 | 92 | impl<'ctx> Alloc> for &'ctx CurrentCtx { 93 | #[inline] 94 | fn alloc(&self, size: usize) -> DevMem<'ctx> { 95 | self.malloc::(size) 96 | } 97 | 98 | #[inline] 99 | fn free(&self, _mem: DevMem<'ctx>) {} 100 | } 101 | 102 | #[cfg(use_nvidia)] 103 | impl<'ctx> Alloc> for Stream<'ctx> { 104 | #[inline] 105 | fn alloc(&self, size: usize) -> DevMem<'ctx> { 106 | self.malloc::(size) 107 | } 108 | 109 | #[inline] 110 | fn free(&self, mem: DevMem<'ctx>) { 111 | mem.drop_on(self) 112 | } 113 | } 114 | 115 | #[cfg(use_nvidia)] 116 | impl<'ctx> QueueAlloc for Stream<'ctx> { 117 | type Hardware = Gpu; 118 | type DevMem = DevMem<'ctx>; 119 | #[inline] 120 | fn queue(&self) -> &QueueOf { 121 | self 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/cxx/export.h: -------------------------------------------------------------------------------- 1 | #ifndef __EXPORT_H__ 2 | #define __EXPORT_H__ 3 | 4 | #if defined(_WIN32) 5 | #define __export __declspec(dllexport) 6 | #elif defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 3)) 7 | #define __export __attribute__((visibility("default"))) 8 | #else 9 | #define __export 10 | #endif 11 | 12 | #ifdef __cplusplus 13 | #define __C extern "C" 14 | #else 15 | #define __C 16 | #endif 17 | 18 | #endif// __EXPORT_H__ 19 | 
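// Usage sketch (mirrors test_compile.cu below): a function intended to be loaded from
// the compiled shared library is declared as
//     __C __export const char *hello_world();
// `__C` forces unmangled C linkage and `__export` marks the symbol as externally
// visible on both MSVC (`dllexport`) and GCC/Clang (default visibility).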
-------------------------------------------------------------------------------- /operators/src/handle/cuda/cxx/iluvatar.lua: -------------------------------------------------------------------------------- 1 | toolchain("iluvatar.toolchain") 2 | set_toolset("cc" , "clang" ) 3 | set_toolset("cxx" , "clang++") 4 | set_toolset("cu" , "clang++") 5 | set_toolset("culd", "clang++") 6 | set_toolset("cu-ccbin", "$(env CXX)", "$(env CC)") 7 | toolchain_end() 8 | rule("iluvatar.env") 9 | add_deps("cuda.env", {order = true}) 10 | after_load(function (target) 11 | local old = target:get("syslinks") 12 | local new = {} 13 | 14 | for _, link in ipairs(old) do 15 | if link ~= "cudadevrt" then 16 | table.insert(new, link) 17 | end 18 | end 19 | 20 | if #old > #new then 21 | target:set("syslinks", new) 22 | local log = "cudadevrt removed, syslinks = { " 23 | for _, link in ipairs(new) do 24 | log = log .. link .. ", " 25 | end 26 | log = log:sub(0, -3) .. " }" 27 | print(log) 28 | end 29 | end) 30 | rule_end() 31 | 32 | 33 | target("lib") 34 | set_kind("shared") 35 | set_optimize("aggressive") 36 | set_languages("cxx17") 37 | add_files("src.cu") 38 | -- 如果配置了 Iluvatar,则按照 Iluvatar 的方式编译 39 | set_toolchains("iluvatar.toolchain") 40 | add_rules("iluvatar.env") 41 | set_values("cuda.rdc", false) 42 | add_links("cudart") -- 首选动态链接 cudart 以免链接 cudart_static 43 | target_end() 44 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/cxx/nv.lua: -------------------------------------------------------------------------------- 1 | target("lib") 2 | set_kind("shared") 3 | set_toolchains("cuda") 4 | set_optimize("aggressive") 5 | 6 | if is_plat("windows") then 7 | -- See 8 | add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") 9 | add_defines("_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH") 10 | end 11 | 12 | set_languages("cxx17") 13 | add_files("src.cu") 14 | target_end() 15 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/cxx/test_compile_8.0/test_compile.cu: -------------------------------------------------------------------------------- 1 | #include "../export.h" 2 | 3 | __C __export const char *hello_world() { 4 | return "Hello, world!"; 5 | } 6 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/module.rs: -------------------------------------------------------------------------------- 1 | use super::{Handle, Key}; 2 | use cuda::{ 3 | bindings::nvrtcResult, ContextResource, ContextSpore, CurrentCtx, Dim3, KernelFn, ModuleSpore, 4 | Ptx, Stream, 5 | }; 6 | use log::warn; 7 | use std::{ 8 | collections::{hash_map::Entry::Occupied, HashMap}, 9 | ffi::{c_void, CStr}, 10 | ptr::addr_eq, 11 | sync::{Arc, OnceLock, RwLock}, 12 | }; 13 | 14 | pub(crate) struct ModuleBox { 15 | handle: Arc, 16 | key: Key, 17 | module: Option, 18 | } 19 | 20 | impl ModuleBox { 21 | pub(super) fn share(handle: Arc, key: Key, code: impl FnOnce() -> String) -> Arc { 22 | let ptx = cache_ptx(&key, code).unwrap(); 23 | let module = handle.context.apply(|ctx| ctx.load(&ptx).sporulate()); 24 | Arc::new(Self { 25 | handle, 26 | key, 27 | module: Some(module), 28 | }) 29 | } 30 | 31 | pub fn load<'ctx>(&'ctx self, name: impl AsRef, ctx: &'ctx CurrentCtx) -> KernelFn<'ctx> { 32 | self.module 33 | .as_ref() 34 | .unwrap() 35 | .sprout_ref(ctx) 36 | .get_kernel(name) 37 | } 38 | 39 | pub fn launch( 40 | &self, 41 | name: impl AsRef, 42 | 
grid_dims: impl Into, 43 | block_dims: impl Into, 44 | params: *const *const c_void, 45 | shared_mem: usize, 46 | stream: &Stream, 47 | ) { 48 | self.load(name, stream.ctx()).launch( 49 | grid_dims, 50 | block_dims, 51 | params, 52 | shared_mem, 53 | Some(stream), 54 | ) 55 | } 56 | } 57 | 58 | impl Drop for ModuleBox { 59 | #[inline] 60 | fn drop(&mut self) { 61 | if let Occupied(entry) = self.handle.modules.write().unwrap().entry(self.key.clone()) { 62 | if addr_eq(entry.get().as_ptr(), self as *const _) { 63 | entry.remove(); 64 | } 65 | } 66 | if let Some(module) = self.module.take() { 67 | self.handle.context.apply(|ctx| drop(module.sprout(ctx))); 68 | } 69 | } 70 | } 71 | 72 | fn cache_ptx(key: &Key, code: impl FnOnce() -> String) -> Result, (nvrtcResult, String)> { 73 | static CACHE: OnceLock>>> = OnceLock::new(); 74 | let cache = CACHE.get_or_init(Default::default); 75 | 76 | if let Some(ptx) = cache.read().unwrap().get(key) { 77 | return Ok(ptx.clone()); 78 | } 79 | let (ptx, log) = Ptx::compile(code(), key.1); 80 | match ptx { 81 | Ok(ptx) => { 82 | if !log.is_empty() { 83 | warn!("{log}"); 84 | } 85 | 86 | let ptx = Arc::new(ptx); 87 | let _ = cache.write().unwrap().insert(key.clone(), ptx.clone()); 88 | Ok(ptx) 89 | } 90 | Err(e) => Err((e, log)), 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /operators/src/handle/cuda/nccl.rs: -------------------------------------------------------------------------------- 1 | use super::{Config, Gpu}; 2 | use crate::TopoNode; 3 | use nccl::Communicator; 4 | use std::sync::Arc; 5 | 6 | pub struct NcclNode { 7 | gpu: Gpu, 8 | pub(crate) nccl: Arc, 9 | } 10 | 11 | impl NcclNode { 12 | pub fn new(comm: Communicator, config: Config) -> Self { 13 | Self { 14 | gpu: Gpu::new(comm.device().retain_primary(), config), 15 | nccl: Arc::new(comm), 16 | } 17 | } 18 | } 19 | 20 | impl TopoNode for NcclNode { 21 | #[inline] 22 | fn processor(&self) -> &Gpu { 23 | &self.gpu 24 | } 25 | #[inline] 26 | fn rank(&self) -> usize { 27 | self.nccl.rank() 28 | } 29 | #[inline] 30 | fn group_size(&self) -> usize { 31 | self.nccl.count() 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /operators/src/handle/infini/ccl.rs: -------------------------------------------------------------------------------- 1 | use super::Device; 2 | use crate::TopoNode; 3 | use infini_ccl::{bindings::DeviceType, Comm}; 4 | use std::{os::raw::c_uint, sync::Arc}; 5 | 6 | pub struct InfiniNode { 7 | rank: usize, 8 | group_size: usize, 9 | pub(crate) device: Device, 10 | pub(crate) comm: Option>, 11 | } 12 | 13 | impl InfiniNode { 14 | pub fn cpu(n: usize) -> Vec { 15 | let indices = (0..n as _).collect::>(); 16 | Self::new(&indices, DeviceType::DEVICE_CPU) 17 | } 18 | 19 | pub fn nv_gpu(indices: &[c_uint]) -> Vec { 20 | Self::new(indices, DeviceType::DEVICE_NVIDIA) 21 | } 22 | 23 | pub fn cambricon_mlu(indices: &[c_uint]) -> Vec { 24 | Self::new(indices, DeviceType::DEVICE_CAMBRICON) 25 | } 26 | 27 | pub fn ascend_npu(indices: &[c_uint]) -> Vec { 28 | Self::new(indices, DeviceType::DEVICE_ASCEND) 29 | } 30 | 31 | fn new(indices: &[c_uint], ty: DeviceType) -> Vec { 32 | let confused: infini_rt::DeviceType = unsafe { std::mem::transmute(ty) }; 33 | if let &[id] = indices { 34 | vec![Self { 35 | rank: 0, 36 | group_size: 1, 37 | device: Device::new(confused, id as _), 38 | comm: None, 39 | }] 40 | } else { 41 | Comm::init_all(ty, indices) 42 | .into_iter() 43 | .zip(indices) 44 | .enumerate() 
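                // one communicator per requested device: a node's rank is its position
                // in `indices`, and every node records the full group size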
45 | .map(|(idx, (comm, &id))| Self { 46 | rank: idx, 47 | group_size: indices.len(), 48 | device: Device::new(confused, id as _), 49 | comm: Some(Arc::new(comm)), 50 | }) 51 | .collect() 52 | } 53 | } 54 | } 55 | 56 | impl TopoNode for InfiniNode { 57 | #[inline] 58 | fn processor(&self) -> &Device { 59 | &self.device 60 | } 61 | #[inline] 62 | fn rank(&self) -> usize { 63 | self.rank 64 | } 65 | #[inline] 66 | fn group_size(&self) -> usize { 67 | self.group_size 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /operators/src/handle/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::{Alloc, Hardware, QueueAlloc, QueueOf}; 2 | use infini_rt::{DevBlob, DevByte, DeviceType, Stream}; 3 | use std::{ops::Deref, sync::Arc}; 4 | 5 | mod ccl; 6 | pub use ccl::InfiniNode; 7 | 8 | #[derive(Clone)] 9 | pub struct Device { 10 | device: infini_rt::Device, 11 | handle: Arc, 12 | } 13 | 14 | impl Device { 15 | #[inline] 16 | pub fn cpu() -> Self { 17 | Self::new(infini_rt::DEVICE_CPU, 0) 18 | } 19 | 20 | #[inline] 21 | pub fn nv_gpu(id: usize) -> Self { 22 | Self::new(infini_rt::DEVICE_NVIDIA, id) 23 | } 24 | 25 | #[inline] 26 | pub fn cambricon_mlu(id: usize) -> Self { 27 | Self::new(infini_rt::DEVICE_CAMBRICON, id) 28 | } 29 | 30 | #[inline] 31 | pub fn ascend_npu(id: usize) -> Self { 32 | Self::new(infini_rt::DEVICE_ASCEND, id) 33 | } 34 | 35 | fn new(ty: infini_rt::DeviceType, id: usize) -> Self { 36 | use infini_op::bindings::Device as Ty; 37 | Self { 38 | device: infini_rt::Device { ty, id: id as _ }, 39 | handle: Arc::new(infini_op::Handle::new( 40 | match ty { 41 | infini_rt::DEVICE_CPU => Ty::DevCpu, 42 | infini_rt::DEVICE_NVIDIA => Ty::DevNvGpu, 43 | infini_rt::DEVICE_CAMBRICON => Ty::DevCambriconMlu, 44 | infini_rt::DEVICE_ASCEND => Ty::DevAscendNpu, 45 | _ => unreachable!("unknown device type"), 46 | }, 47 | id as _, 48 | )), 49 | } 50 | } 51 | 52 | #[inline] 53 | pub(crate) fn device_type(&self) -> DeviceType { 54 | self.device.ty 55 | } 56 | 57 | #[inline] 58 | pub(crate) fn handle(&self) -> &Arc { 59 | &self.handle 60 | } 61 | } 62 | 63 | impl Deref for Device { 64 | type Target = infini_rt::Device; 65 | #[inline] 66 | fn deref(&self) -> &Self::Target { 67 | &self.device 68 | } 69 | } 70 | 71 | impl Hardware for Device { 72 | type Byte = DevByte; 73 | type Queue<'ctx> = Stream; 74 | } 75 | 76 | impl Alloc for Device { 77 | #[inline] 78 | fn alloc(&self, size: usize) -> DevBlob { 79 | self.device.malloc::(size) 80 | } 81 | 82 | #[inline] 83 | fn free(&self, _mem: DevBlob) {} 84 | } 85 | 86 | impl Alloc for Stream { 87 | #[inline] 88 | fn alloc(&self, size: usize) -> DevBlob { 89 | self.malloc::(size) 90 | } 91 | 92 | #[inline] 93 | fn free(&self, mem: DevBlob) { 94 | self.free(mem) 95 | } 96 | } 97 | 98 | impl QueueAlloc for Stream { 99 | type Hardware = Device; 100 | type DevMem = DevBlob; 101 | #[inline] 102 | fn queue(&self) -> &QueueOf { 103 | self 104 | } 105 | } 106 | 107 | /// 并行转换类型并异步拷贝到显存。 108 | #[cfg(test)] 109 | pub(crate) fn cast_load<'ctx, T, U, F>(val: &[T], f: F, stream: &Stream) -> DevBlob 110 | where 111 | T: Sync + Copy, 112 | U: Send + Copy, 113 | F: Sync + Fn(T) -> U, 114 | { 115 | let mut host = stream.get_device().malloc_host::(val.len()); 116 | let host = unsafe { std::slice::from_raw_parts_mut(host.as_mut_ptr().cast(), val.len()) }; 117 | host.into_iter().zip(val).for_each(|(y, x)| *y = f(*x)); 118 | let ans = stream.from_host(host); 119 | 
stream.synchronize(); 120 | ans 121 | } 122 | -------------------------------------------------------------------------------- /operators/src/handle/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | -------------------------------------------------------------------------------- /operators/src/handle/opencl/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::{Alloc, Hardware, Pool, QueueAlloc, QueueOf, SchemeCacheSize, SchemeDiversity}; 2 | use clrt::{BuildError, CommandQueue, Context, Kernel, Program, SvmBlob, SvmByte}; 3 | use lru::LruCache; 4 | use std::{ 5 | collections::HashMap, 6 | ffi::{CStr, CString}, 7 | fmt, 8 | hash::Hash, 9 | sync::Mutex, 10 | }; 11 | 12 | pub struct ClDevice { 13 | ctx: Context, 14 | cache_size: SchemeCacheSize, 15 | } 16 | 17 | impl Hardware for ClDevice { 18 | type Byte = SvmByte; 19 | type Queue<'ctx> = CommandQueue; 20 | } 21 | 22 | impl ClDevice { 23 | #[inline] 24 | pub fn new(context: Context, cache_size: SchemeCacheSize) -> Self { 25 | Self { 26 | ctx: context, 27 | cache_size, 28 | } 29 | } 30 | 31 | #[inline] 32 | pub(crate) fn context(&self) -> &Context { 33 | &self.ctx 34 | } 35 | 36 | #[inline] 37 | pub fn new_cache(&self, level: SchemeDiversity) -> Mutex> { 38 | self.cache_size.new_cache(level) 39 | } 40 | } 41 | 42 | impl Alloc for Context { 43 | #[inline] 44 | fn alloc(&self, size: usize) -> SvmBlob { 45 | self.malloc::(size) 46 | } 47 | 48 | #[inline] 49 | fn free(&self, _mem: SvmBlob) {} 50 | } 51 | 52 | impl Alloc for CommandQueue { 53 | #[inline] 54 | fn alloc(&self, size: usize) -> SvmBlob { 55 | self.ctx().malloc::(size) 56 | } 57 | 58 | #[inline] 59 | fn free(&self, mem: SvmBlob) { 60 | self.free(mem, None) 61 | } 62 | } 63 | 64 | impl QueueAlloc for CommandQueue { 65 | type Hardware = ClDevice; 66 | type DevMem = SvmBlob; 67 | #[inline] 68 | fn queue(&self) -> &QueueOf { 69 | self 70 | } 71 | } 72 | 73 | pub(crate) struct KernelCache { 74 | program: Program, 75 | kernels: HashMap>, 76 | } 77 | 78 | pub(crate) const CL2_0: &CStr = c"-cl-std=CL2.0"; 79 | 80 | pub struct CodeGen { 81 | code: &'static str, 82 | defines: Vec<(&'static str, String)>, 83 | } 84 | 85 | impl CodeGen { 86 | pub fn new(code: &'static str) -> Self { 87 | Self { 88 | code, 89 | defines: Default::default(), 90 | } 91 | } 92 | 93 | pub fn define(&mut self, name: &'static str, value: impl ToString) -> &mut Self { 94 | self.defines.push((name, value.to_string())); 95 | self 96 | } 97 | } 98 | 99 | impl fmt::Display for CodeGen { 100 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 101 | for (name, value) in &self.defines { 102 | writeln!(f, "#define {} {}", name, value)? 
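// Each entry is rendered as its own "#define NAME VALUE" line ahead of the
// kernel template, so kernels that guard their defaults with `#ifndef`
// (e.g. `Tval`, `GROUP_SIZE`) pick up the injected value instead.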
103 | } 104 | write!(f, "{}", self.code) 105 | } 106 | } 107 | 108 | impl KernelCache { 109 | pub fn new(ctx: &Context, src: &str, opts: &CStr) -> Self { 110 | let program = match ctx.build_from_source(src, opts) { 111 | Ok(program) => program, 112 | Err(BuildError::BuildFailed(log)) => { 113 | println!("{log}"); 114 | panic!("Failed to build cl kernels") 115 | } 116 | Err(BuildError::Others(err)) => { 117 | panic!("Failed to build cl kernels with error {err}") 118 | } 119 | }; 120 | let kernels = program 121 | .kernels() 122 | .into_iter() 123 | .map(|k| { 124 | let name = k.name(); 125 | let pool = Pool::new(); 126 | pool.push(k); 127 | (name, pool) 128 | }) 129 | .collect(); 130 | Self { program, kernels } 131 | } 132 | 133 | pub fn take(&self, name: &str) -> Option { 134 | self.kernels 135 | .get(name)? 136 | .pop() 137 | .or_else(|| self.program.get_kernel(CString::new(name).unwrap())) 138 | } 139 | 140 | pub fn put(&self, name: &str, kernel: Kernel) { 141 | self.kernels.get(name).unwrap().push(kernel) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /operators/src/layer_norm/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | utils::{dim_distinct, rank_error, type_distinct}, 3 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 4 | }; 5 | use digit_layout::DigitLayout; 6 | 7 | pub struct Args { 8 | pub y_layout: TensorLayout, 9 | pub y_base: MutPtr, 10 | pub x_layout: TensorLayout, 11 | pub x_base: ConstPtr, 12 | pub scale_layout: TensorLayout, 13 | pub scale_base: ConstPtr, 14 | pub bias_layout: TensorLayout, 15 | pub bias_base: ConstPtr, 16 | pub epsilon: f32, 17 | } 18 | 19 | pub(super) struct Meta { 20 | pub dt_a: DigitLayout, 21 | pub dt_w: DigitLayout, 22 | pub n: MaybeDyn, 23 | pub d: MaybeDyn, 24 | } 25 | 26 | impl Args { 27 | pub(super) fn meta(&self) -> Result { 28 | let Self { 29 | y_layout: y, 30 | x_layout: x, 31 | scale_layout: scale, 32 | bias_layout: bias, 33 | .. 
34 | } = self; 35 | 36 | let &[ny, dy] = y.shape() else { 37 | return Err(rank_error("y", 2, y.ndim())); 38 | }; 39 | let &[nx, dx] = x.shape() else { 40 | return Err(rank_error("x", 2, x.ndim())); 41 | }; 42 | let &[ds] = scale.shape() else { 43 | return Err(rank_error("scale", 1, scale.ndim())); 44 | }; 45 | let &[db] = bias.shape() else { 46 | return Err(rank_error("bias", 1, bias.ndim())); 47 | }; 48 | 49 | Ok(Meta { 50 | dt_a: type_distinct(&[y.dt(), x.dt()])?, 51 | dt_w: type_distinct(&[scale.dt(), bias.dt()])?, 52 | n: dim_distinct(&[ny, nx])?, 53 | d: dim_distinct(&[dy, dx, ds, db])?, 54 | }) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /operators/src/layer_norm/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, LayerNorm}; 2 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use half::f16; 4 | use num_traits::{real::Real, NumCast, ToPrimitive}; 5 | use std::ops::AddAssign; 6 | 7 | pub struct Operator; 8 | 9 | impl LayerNorm for Operator {} 10 | 11 | impl crate::Operator for Operator { 12 | type Hardware = Cpu; 13 | type TopoNode = Cpu; 14 | type Args = Args; 15 | 16 | fn new(_node: &Self::TopoNode) -> Self { 17 | Self 18 | } 19 | 20 | fn scheme( 21 | &mut self, 22 | args: &Self::Args, 23 | _max_workspace_size: usize, 24 | ) -> Result { 25 | let _meta = args.meta()?; 26 | Ok(0) 27 | } 28 | 29 | fn launch( 30 | &self, 31 | args: &Self::Args, 32 | _workspace: &mut [ByteOf], 33 | _queue_alloc: &QA, 34 | ) -> Result<(), LaunchError> 35 | where 36 | QA: QueueAlloc, 37 | { 38 | let Meta { dt_w, dt_a, n, d } = args.meta()?; 39 | let Args { 40 | y_layout, 41 | y_base, 42 | x_layout, 43 | x_base, 44 | scale_layout, 45 | scale_base, 46 | bias_layout, 47 | bias_base, 48 | epsilon, 49 | } = args; 50 | let &[nsy, dsy] = y_layout.strides() else { 51 | unreachable!() 52 | }; 53 | let &[nsx, dsx] = x_layout.strides() else { 54 | unreachable!() 55 | }; 56 | let &[dss] = scale_layout.strides() else { 57 | unreachable!() 58 | }; 59 | let &[dsb] = bias_layout.strides() else { 60 | unreachable!() 61 | }; 62 | 63 | get_static! { 64 | n d 65 | nsy dsy 66 | nsx dsx 67 | dss 68 | dsb 69 | } 70 | 71 | macro_rules! 
calculate { 72 | ($eps:expr; $w:ty, $a:ty) => { 73 | Scheme { 74 | n, 75 | d, 76 | nsy, 77 | dsy, 78 | nsx, 79 | dsx, 80 | dss, 81 | dsb, 82 | epsilon: $eps, 83 | y: y_base.cast::<$a>(), 84 | x: x_base.cast::<$a>(), 85 | s: scale_base.cast::<$w>(), 86 | b: bias_base.cast::<$w>(), 87 | } 88 | .calculate() 89 | }; 90 | } 91 | 92 | use digit_layout::types as ty; 93 | match (dt_w, dt_a) { 94 | (ty::F16, ty::F16) => calculate!(*epsilon ; f16, f16), 95 | (ty::F32, ty::F16) => calculate!(*epsilon ; f32, f16), 96 | (ty::F32, ty::F32) => calculate!(*epsilon ; f32, f32), 97 | (ty::F64, ty::F64) => calculate!(*epsilon as f64; f64, f64), 98 | (_, _) => todo!(), 99 | } 100 | 101 | Ok(()) 102 | } 103 | } 104 | 105 | struct Scheme { 106 | n: usize, 107 | d: usize, 108 | nsy: isize, 109 | dsy: isize, 110 | nsx: isize, 111 | dsx: isize, 112 | dss: isize, 113 | dsb: isize, 114 | epsilon: X, 115 | y: *mut A, 116 | x: *const A, 117 | s: *const W, 118 | b: *const W, 119 | } 120 | 121 | impl Scheme 122 | where 123 | X: Real + AddAssign, 124 | W: Real, 125 | A: Real, 126 | { 127 | fn calculate(self) { 128 | for i in 0..self.n as isize { 129 | let mut sum = X::zero(); 130 | let mut sum2 = X::zero(); 131 | for j in 0..self.d as isize { 132 | let x: X = get(self.x, i * self.nsx + j * self.dsx); 133 | sum += x; 134 | sum2 += x * x; 135 | } 136 | let n = X::from(self.d).unwrap(); 137 | let e = sum / n; 138 | let e2 = sum2 / n; 139 | let std = (e2 - e * e).sqrt(); 140 | let k = (std + self.epsilon).recip(); 141 | 142 | for j in 0..self.d as isize { 143 | let y = unsafe { &mut *self.y.byte_offset(i * self.nsy + j * self.dsy) }; 144 | let x: X = get(self.x, i * self.nsx + j * self.dsx); 145 | let s: X = get(self.s, j * self.dss); 146 | let b: X = get(self.b, j * self.dsb); 147 | 148 | *y = A::from((x - e).mul_add(s * k, b)).unwrap(); 149 | } 150 | } 151 | } 152 | } 153 | 154 | #[inline] 155 | fn get(ptr: *const T, offset: isize) -> X { 156 | X::from(unsafe { ptr.byte_offset(offset).read() }).unwrap() 157 | } 158 | -------------------------------------------------------------------------------- /operators/src/layer_norm/cuda/layer_norm.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | struct SumPair { 6 | float average; 7 | float variance; 8 | 9 | __device__ SumPair operator+(const SumPair &other) const { 10 | return SumPair{this->average + other.average, this->variance + other.variance}; 11 | } 12 | }; 13 | template 14 | static __device__ void padding( 15 | Ta *__restrict__ y_, 16 | int const stride_y, 17 | Ta const *__restrict__ x_, 18 | int const stride_x, 19 | Tw const *__restrict__ s_, 20 | Tw const *__restrict__ b_, 21 | float const epsilon) { 22 | auto y = y_ + blockIdx.x * stride_y + threadIdx.x; 23 | float const 24 | x = x_[blockIdx.x * stride_x + threadIdx.x], 25 | s = s_[threadIdx.x], 26 | b = b_[threadIdx.x]; 27 | 28 | using BlockOp = cub::BlockReduce; 29 | __shared__ typename BlockOp::TempStorage temp_storge; 30 | SumPair tmp = {x, x * x}; 31 | SumPair sum_pair = BlockOp(temp_storge).Reduce(tmp, cub::Sum()); 32 | __shared__ float average, variance; 33 | if (threadIdx.x == 0) { 34 | average = sum_pair.average / float(BLOCK_SIZE); 35 | variance = __frcp_rn(sqrtf(sum_pair.variance / float(BLOCK_SIZE) - powf(average, 2.0)) + epsilon); 36 | } 37 | __syncthreads(); 38 | 39 | *y = Ta((x - average) * variance * s + b); 40 | } 41 | 42 | template 43 | static __device__ void folding( 44 | Ta *__restrict__ y_, 45 | int const stride_y, 46 
| Ta const *__restrict__ x_, 47 | int const stride_x, 48 | Tw const *__restrict__ s_, 49 | Tw const *__restrict__ b_, 50 | float const epsilon, 51 | unsigned int const items_size) { 52 | y_ += blockIdx.x * stride_y; 53 | x_ += blockIdx.x * stride_x; 54 | 55 | float data[NUM_ITEMS_THREAD], scale[NUM_ITEMS_THREAD], bias[NUM_ITEMS_THREAD]; 56 | { 57 | using BlockOp = cub::BlockLoad; 58 | __shared__ typename BlockOp::TempStorage temp_storage; 59 | BlockOp(temp_storage).Load(x_, data, items_size, 0.f); 60 | BlockOp(temp_storage).Load(s_, scale, items_size, 0.f); 61 | BlockOp(temp_storage).Load(b_, bias, items_size, 0.f); 62 | } 63 | 64 | float sum_average = 0, sum_variance = 0; 65 | #pragma unroll 66 | for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) { 67 | sum_average += data[i]; 68 | sum_variance += data[i] * data[i]; 69 | } 70 | 71 | SumPair tmp_sum = {sum_average, sum_variance}; 72 | using BlockOp = cub::BlockReduce; 73 | __shared__ typename BlockOp::TempStorage temp_storge; 74 | SumPair sum_pair = BlockOp(temp_storge).Reduce(tmp_sum, cub::Sum()); 75 | 76 | __shared__ float average, variance; 77 | if (threadIdx.x == 0) { 78 | average = sum_pair.average / float(items_size); 79 | variance = __frcp_rn(sqrtf(sum_pair.variance / float(items_size) - powf(average, 2.0)) + epsilon); 80 | } 81 | __syncthreads(); 82 | 83 | #pragma unroll 84 | for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) { 85 | data[i] = (data[i] - average) * variance * scale[i] + bias[i]; 86 | } 87 | 88 | { 89 | using BlockOp = cub::BlockStore; 90 | __shared__ typename BlockOp::TempStorage temp_storage; 91 | BlockOp(temp_storage).Store(y_, data, items_size); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /operators/src/layer_norm/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Args, LayerNorm}; 2 | use crate::{infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl LayerNorm for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = Device; 10 | type TopoNode = Device; 11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/layer_norm/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(LayerNorm); 14 | -------------------------------------------------------------------------------- /operators/src/layer_norm/opencl/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{Args, LayerNorm}; 2 | use crate::{opencl::ClDevice, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl LayerNorm for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = ClDevice; 10 | type TopoNode = ClDevice; 
11 | type Args = Args; 12 | 13 | fn new(_node: &Self::TopoNode) -> Self { 14 | todo!() 15 | } 16 | 17 | fn scheme( 18 | &mut self, 19 | _args: &Self::Args, 20 | _max_workspace_size: usize, 21 | ) -> Result { 22 | todo!() 23 | } 24 | 25 | fn launch( 26 | &self, 27 | _args: &Self::Args, 28 | _workspace: &mut [ByteOf], 29 | _queue_alloc: &QA, 30 | ) -> Result<(), LaunchError> 31 | where 32 | QA: QueueAlloc, 33 | { 34 | todo!() 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /operators/src/lib.rs: -------------------------------------------------------------------------------- 1 | // #![deny(warnings)] 2 | 3 | mod common; 4 | mod handle; 5 | 6 | pub mod add; 7 | pub mod add_rows; 8 | pub mod all_reduce; 9 | pub mod attention; 10 | pub mod attention_kv_cached; 11 | pub mod broadcast; 12 | pub mod conv; 13 | pub mod fuesd_softmax; 14 | pub mod gelu; 15 | pub mod layer_norm; 16 | pub mod mat_mul; 17 | pub mod random_sample; 18 | pub mod rearrange; 19 | pub mod rms_norm; 20 | pub mod rope; 21 | pub mod swiglu; 22 | 23 | pub use common::*; 24 | 25 | #[cfg(any(use_cpu, test))] 26 | pub use handle::common_cpu; 27 | 28 | #[cfg(use_cl)] 29 | pub use handle::opencl; 30 | #[cfg(use_cl)] 31 | pub extern crate clrt; 32 | 33 | #[cfg(use_infini)] 34 | pub use handle::infini; 35 | #[cfg(use_infini)] 36 | pub extern crate infini_rt; 37 | 38 | #[cfg(use_cuda)] 39 | pub mod cuda { 40 | pub use crate::handle::cuda::*; 41 | pub use ::cuda::*; 42 | } 43 | #[cfg(use_cuda)] 44 | pub extern crate cublas; 45 | #[cfg(use_nccl)] 46 | pub extern crate nccl; 47 | 48 | use rearrange::Rearrange; 49 | use std::{marker::PhantomData, ops::DerefMut, ptr::addr_eq}; 50 | 51 | /// 算力硬件抽象。 52 | /// 53 | /// 约定硬件如何存储和运行。 54 | /// 这个特质应该由管理硬件的基本单元的映射类型实现,通常是**硬件上下文**。 55 | pub trait Hardware { 56 | /// 硬件的存储单元类型。 57 | type Byte; 58 | /// 硬件的任务队列类型。 59 | type Queue<'ctx>; 60 | } 61 | 62 | pub trait TopoNode { 63 | fn processor(&self) -> &H; 64 | fn rank(&self) -> usize; 65 | fn group_size(&self) -> usize; 66 | } 67 | 68 | impl TopoNode for H { 69 | #[inline] 70 | fn processor(&self) -> &H { 71 | self 72 | } 73 | #[inline] 74 | fn rank(&self) -> usize { 75 | 0 76 | } 77 | #[inline] 78 | fn group_size(&self) -> usize { 79 | 1 80 | } 81 | } 82 | 83 | pub type ByteOf = ::Byte; 84 | pub type QueueOf<'ctx, H> = ::Queue<'ctx>; 85 | pub type ArgsOf = ::Args; 86 | pub(crate) type MutPtr = *mut ::Byte; 87 | pub(crate) type ConstPtr = *const ::Byte; 88 | 89 | pub trait Alloc { 90 | fn alloc(&self, size: usize) -> M; 91 | fn free(&self, mem: M); 92 | } 93 | 94 | /// 绑定到队列的分配器。 95 | pub trait QueueAlloc: Alloc { 96 | /// 队列分配器对应的硬件。 97 | type Hardware: Hardware; 98 | /// 分配器分配和回收的对象,表示对某块存储区域的所有权。 99 | type DevMem: DerefMut]>; 100 | /// 分配器对应的队列。 101 | fn queue(&self) -> &QueueOf; 102 | } 103 | 104 | /// 算子。 105 | pub trait Operator { 106 | /// 执行算子的硬件。 107 | type Hardware: Hardware; 108 | /// 算子对应的通信拓扑节点。 109 | type TopoNode: TopoNode; 110 | /// 算子的参数类型。 111 | type Args; 112 | 113 | /// 在指定拓扑节点上创建算子实例。 114 | fn new(node: &Self::TopoNode) -> Self; 115 | 116 | /// 规划执行方案。 117 | /// 118 | /// 通过向算子实例提供尽可能详细的参数来尽量确定算子执行方案。 119 | /// 通过允许参数中标量值、张量形状、张量步长和张量基址的动态性([ArgVal] 或 [null](std::ptr::null))来尽可能复用算子实例。 120 | /// 121 | /// 另外,需要传入一个最大工作空间容量。工作空间是与硬件存储单元相同类型的存储区域,供算子执行过程中使用。 122 | /// 规划执行方案时,将尽可能尝试计算一个满足最大工作空间容量的工作空间需求,作为返回值。 123 | /// 124 | /// 算子的返回值将保证不大于最大工作空间容量。如果算子还需要更多空间,可能产生运行时分配。 125 | /// 126 | /// 由于参数提供可能不全,有时无法计算出具体的工作空间需求,算子将返回 0 作为工作空间需求,并在执行时再计算实际的需求。 127 | fn scheme( 128 | &mut 
self, 129 | args: &Self::Args, 130 | max_workspace_size: usize, 131 | ) -> Result; 132 | 133 | /// 发射算子到任务队列。 134 | /// 135 | /// 如果算子实际需要的工作空间大于通过参数提供的工作空间,将通过流分配器分配和释放工作空间。 136 | fn launch( 137 | &self, 138 | args: &Self::Args, 139 | workspace: &mut [ByteOf], 140 | queue_alloc: &QA, 141 | ) -> Result<(), LaunchError> 142 | where 143 | QA: QueueAlloc; 144 | } 145 | 146 | macro_rules! op_trait { 147 | ($name:ident $($body:item)*) => { 148 | pub trait $name: 149 | $crate::Operator< 150 | Hardware = H, 151 | TopoNode = H, 152 | Args = Args, 153 | >{$($body)*} 154 | }; 155 | } 156 | 157 | macro_rules! comm_trait { 158 | ($name:ident $($body:item)*) => { 159 | pub trait $name>: 160 | $crate::Operator< 161 | Hardware = H, 162 | TopoNode = N, 163 | Args = Args, 164 | >{$($body)*} 165 | }; 166 | } 167 | 168 | macro_rules! non_comm { 169 | ($name:ident impl $trait:ident) => { 170 | pub type $name = crate::NonComm>; 171 | impl $trait for $name 172 | where 173 | H: crate::Hardware, 174 | R: crate::rearrange::Rearrange, 175 | { 176 | } 177 | }; 178 | } 179 | 180 | pub(crate) use {comm_trait, non_comm, op_trait}; 181 | 182 | #[repr(transparent)] 183 | pub struct NonComm(R, PhantomData<(H, A)>); 184 | 185 | impl Operator for NonComm 186 | where 187 | H: Hardware, 188 | R: Rearrange, 189 | A: AsRef>, 190 | { 191 | type Hardware = H; 192 | type TopoNode = H; 193 | type Args = A; 194 | 195 | #[inline] 196 | fn new(node: &Self::TopoNode) -> Self { 197 | Self(R::new(node), PhantomData) 198 | } 199 | 200 | #[inline] 201 | fn scheme( 202 | &mut self, 203 | args: &Self::Args, 204 | max_workspace_size: usize, 205 | ) -> Result { 206 | self.0.scheme(args.as_ref(), max_workspace_size) 207 | } 208 | 209 | #[inline] 210 | fn launch( 211 | &self, 212 | args: &Self::Args, 213 | workspace: &mut [ByteOf], 214 | queue_alloc: &QA, 215 | ) -> Result<(), crate::LaunchError> 216 | where 217 | QA: QueueAlloc, 218 | { 219 | let args = args.as_ref(); 220 | if !addr_eq(args.dst_base, args.src_base) { 221 | self.0.launch(args, workspace, queue_alloc)? 
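// The wrapped rearrange only runs when `dst_base` and `src_base` differ;
// when they alias there is nothing to move, so this single-node stand-in
// for a communication op degenerates into a no-op.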
222 | } 223 | Ok(()) 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /operators/src/mat_mul/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | dyn_not_support, rank_not_support, shape_mismatch, shape_not_support, strides_not_support, 3 | utils::type_distinct, ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 4 | }; 5 | use digit_layout::DigitLayout; 6 | use std::{ 7 | mem::swap, 8 | ptr::{null, null_mut}, 9 | }; 10 | 11 | pub struct Args { 12 | pub c_layout: TensorLayout, 13 | pub c_base: MutPtr, 14 | pub beta: f32, 15 | pub a_layout: TensorLayout, 16 | pub a_base: ConstPtr, 17 | pub b_layout: TensorLayout, 18 | pub b_base: ConstPtr, 19 | pub alpha: f32, 20 | } 21 | 22 | #[derive(Clone, PartialEq, Eq, Debug)] 23 | pub(super) struct SchemeLayout { 24 | pub dt: DigitLayout, 25 | pub ab_swap: bool, 26 | pub a_trans: bool, 27 | pub b_trans: bool, 28 | 29 | pub batch: usize, 30 | pub m: usize, 31 | pub n: usize, 32 | pub k: usize, 33 | 34 | pub c_stride: isize, 35 | pub c_ld: isize, 36 | 37 | pub a_stride: isize, 38 | pub a_ld: isize, 39 | 40 | pub b_stride: isize, 41 | pub b_ld: isize, 42 | } 43 | 44 | impl Args { 45 | pub fn new_null( 46 | c_layout: TensorLayout, 47 | beta: f32, 48 | a_layout: TensorLayout, 49 | b_layout: TensorLayout, 50 | alpha: f32, 51 | ) -> Self { 52 | Self { 53 | c_layout, 54 | c_base: null_mut(), 55 | beta, 56 | a_layout, 57 | a_base: null(), 58 | b_layout, 59 | b_base: null(), 60 | alpha, 61 | } 62 | } 63 | 64 | pub(super) fn layout(&self) -> Result { 65 | let Self { 66 | c_layout, 67 | a_layout, 68 | b_layout, 69 | .. 70 | } = self; 71 | 72 | // 确认矩阵结构匹配 73 | let mut c = Matrix::try_from(&self.c_layout)?; 74 | let mut a = Matrix::try_from(&self.a_layout)?; 75 | let mut b = Matrix::try_from(&self.b_layout)?; 76 | if c.r != a.r || c.c != b.c || a.c != b.r { 77 | return Err(shape_mismatch("Inconsistent matrix shapes")); 78 | } 79 | // 确认批处理结构匹配 80 | let batch = c.batch; 81 | if !a.match_batch(batch) || !b.match_batch(batch) { 82 | return Err(shape_mismatch("Inconsistent batch sizes")); 83 | } 84 | // 确认 c 列优先 85 | let ab_swap = if c.rs == 1 && c.cs != 1 { 86 | // Nothing to do 87 | false 88 | } else if c.cs == 1 { 89 | // cT = bT.aT 90 | c.transpose(); 91 | a.transpose(); 92 | b.transpose(); 93 | swap(&mut a, &mut b); 94 | true 95 | } else { 96 | return Err(strides_not_support("Matrix is not contiguous")); 97 | }; 98 | 99 | let (a_ld, a_trans) = a.ld_trans()?; 100 | let (b_ld, b_trans) = b.ld_trans()?; 101 | Ok(SchemeLayout { 102 | dt: type_distinct(&[c_layout.dt(), a_layout.dt(), b_layout.dt()])?, 103 | ab_swap, 104 | a_trans, 105 | b_trans, 106 | 107 | batch, 108 | m: c.r, 109 | n: c.c, 110 | k: a.c, 111 | 112 | c_stride: c.stride, 113 | c_ld: c.cs, 114 | 115 | a_stride: a.stride, 116 | a_ld, 117 | 118 | b_stride: b.stride, 119 | b_ld, 120 | }) 121 | } 122 | } 123 | 124 | #[derive(Clone, Debug)] 125 | struct Matrix { 126 | batch: usize, 127 | stride: isize, 128 | r: usize, 129 | c: usize, 130 | rs: isize, 131 | cs: isize, 132 | } 133 | 134 | impl TryFrom<&TensorLayout> for Matrix { 135 | type Error = SchemeError; 136 | 137 | fn try_from(tensor: &TensorLayout) -> Result { 138 | let Some(shape) = MaybeDyn::get_all(tensor.shape()) else { 139 | return Err(dyn_not_support("")); 140 | }; 141 | let Some(strides) = MaybeDyn::get_all(tensor.strides()) else { 142 | return Err(dyn_not_support("")); 143 | }; 144 | 145 | let [batch @ .., r, c] = shape 
else { 146 | return Err(rank_not_support("Matrix must have rank 2 or more")); 147 | }; 148 | let [stride @ .., rs, cs] = strides else { 149 | unreachable!(); 150 | }; 151 | let unit = tensor.dt().nbytes() as isize; 152 | let (batch, stride) = match batch { 153 | [] | [1] => { 154 | assert!(matches!(stride, [] | [_])); 155 | (1, 0) 156 | } 157 | &[batch] => { 158 | let &[stride] = stride else { unreachable!() }; 159 | (batch, stride / unit) 160 | } 161 | _ => return Err(shape_not_support("Higher-rank tensors not supported")), 162 | }; 163 | Ok(Self { 164 | batch, 165 | stride, 166 | r: *r, 167 | c: *c, 168 | rs: rs / unit, 169 | cs: cs / unit, 170 | }) 171 | } 172 | } 173 | 174 | impl Matrix { 175 | #[inline(always)] 176 | fn match_batch(&self, batch: usize) -> bool { 177 | self.batch == 1 || self.batch == batch 178 | } 179 | #[inline(always)] 180 | fn ld_trans(&mut self) -> Result<(isize, bool), SchemeError> { 181 | match (self.rs, self.cs) { 182 | (1, cs) => Ok((cs, false)), 183 | (rs, 1) => Ok((rs, true)), 184 | (_, _) => Err(strides_not_support("Matrix is not contiguous")), 185 | } 186 | } 187 | #[inline(always)] 188 | fn transpose(&mut self) { 189 | swap(&mut self.r, &mut self.c); 190 | swap(&mut self.rs, &mut self.cs); 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /operators/src/mat_mul/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::SchemeLayout, Args, MatMul}; 2 | use crate::{common_cpu::Cpu, type_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | 4 | pub struct Operator; 5 | 6 | impl MatMul for Operator {} 7 | 8 | impl crate::Operator for Operator { 9 | type Hardware = Cpu; 10 | type TopoNode = Cpu; 11 | type Args = Args; 12 | 13 | #[inline] 14 | fn new(_node: &Self::TopoNode) -> Self { 15 | Self 16 | } 17 | 18 | fn scheme( 19 | &mut self, 20 | _args: &Self::Args, 21 | _max_workspace_size: usize, 22 | ) -> Result { 23 | Ok(0) 24 | } 25 | 26 | fn launch( 27 | &self, 28 | args: &Self::Args, 29 | _workspace: &mut [ByteOf], 30 | _queue_alloc: &QA, 31 | ) -> Result<(), LaunchError> 32 | where 33 | QA: QueueAlloc, 34 | { 35 | let SchemeLayout { 36 | dt, 37 | ab_swap, 38 | a_trans, 39 | b_trans, 40 | batch, 41 | m, 42 | n, 43 | k, 44 | c_stride, 45 | c_ld, 46 | a_stride, 47 | a_ld, 48 | b_stride, 49 | b_ld, 50 | } = args.layout()?; 51 | let &Args { 52 | c_base, 53 | beta, 54 | a_base, 55 | b_base, 56 | alpha, 57 | .. 58 | } = args; 59 | 60 | let c = c_base as usize; 61 | let [a, b] = if ab_swap { 62 | [b_base, a_base] 63 | } else { 64 | [a_base, b_base] 65 | } 66 | .map(|ptr| ptr as usize); 67 | let (lhs_cs, lhs_rs) = if a_trans { (1, a_ld) } else { (a_ld, 1) }; 68 | let (rhs_cs, rhs_rs) = if b_trans { (1, b_ld) } else { (b_ld, 1) }; 69 | 70 | macro_rules! 
gemm { 71 | ($ty:ty; $alpha:expr, $beta:expr) => { 72 | (0..batch as isize).for_each(|i| unsafe { 73 | gemm::gemm( 74 | m, 75 | n, 76 | k, 77 | (c as *mut $ty).offset(i * c_stride), 78 | c_ld, 79 | 1, 80 | beta != 0., 81 | (a as *const $ty).offset(i * a_stride), 82 | lhs_cs, 83 | lhs_rs, 84 | (b as *const $ty).offset(i * b_stride), 85 | rhs_cs, 86 | rhs_rs, 87 | $beta, 88 | $alpha, 89 | false, 90 | false, 91 | false, 92 | gemm::Parallelism::Rayon(0), 93 | ) 94 | }) 95 | }; 96 | } 97 | 98 | use digit_layout::types as ty; 99 | use gemm::f16; 100 | match dt { 101 | ty::F16 => gemm!(f16; f16::from_f32(alpha), f16::from_f32(beta)), 102 | ty::F32 => gemm!(f32; alpha, beta), 103 | ty::F64 => gemm!(f64; alpha as _, beta as _), 104 | _ => Err(type_not_support(format!("Unsupported {dt}")))?, 105 | } 106 | Ok(()) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /operators/src/mat_mul/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(MatMul); 14 | -------------------------------------------------------------------------------- /operators/src/mat_mul/opencl/mat_mul.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Tval 5 | #define Tval float 6 | #endif 7 | 8 | #ifdef USE_HALF 9 | #define MUL(valueA, valueB) (float) (valueA * valueB) 10 | #define SCAL(beta, p, alpha, value) (half)(beta * (float) (*p) + alpha * value) 11 | #else 12 | #define MUL(valueA, valueB) valueA *valueB 13 | #define SCAL(beta, p, alpha, value) beta *(*p) + alpha *value 14 | #endif 15 | 16 | __kernel void general_gemm(__global Tval *A, __global Tval *B, __global Tval *C, 17 | int as, int ars, int acs, int bs, int brs, int bcs, 18 | int cs, int crs, int ccs, int batch, 19 | int M, int N, int K, float alpha, float beta) { 20 | int g_idx = get_global_id(0); 21 | int g_idy = get_global_id(1); 22 | int row_id = g_idy / N; 23 | int col_id = g_idy % N; 24 | 25 | Tval valueA = 0.0f; 26 | Tval valueB = 0.0f; 27 | float value = 0.0f; 28 | 29 | for (int i = 0; i < K; i++) { 30 | valueA = *(A + g_idx * as + row_id * ars + i * acs); 31 | valueB = *(B + g_idx * bs + i * brs + col_id * bcs); 32 | value += MUL(valueA, valueB); 33 | } 34 | 35 | __global Tval *p = C + g_idx * cs + row_id * crs + col_id * ccs; 36 | *p = SCAL(beta, p, alpha, value); 37 | } 38 | -------------------------------------------------------------------------------- /operators/src/random_sample/args.rs: -------------------------------------------------------------------------------- 1 | use super::KVPair; 2 | use crate::{ 3 | type_not_support, utils::rank_error, ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, 4 | TensorLayout, 5 | }; 6 | use digit_layout::{types as ty, DigitLayout}; 7 | use std::ptr::{null, null_mut}; 8 | 9 | pub struct Args { 10 | pub kv_pair: TensorLayout, 11 | pub kv_pair_base: MutPtr, 12 | pub logits: TensorLayout, 13 | pub logits_base: ConstPtr, 14 | pub indices: TensorLayout, 15 | pub indices_base: ConstPtr, 16 | pub config: SampleArgs, 17 | pub seed: f32, 18 | } 19 | 20 | #[derive(Clone, Copy, Debug)] 21 | pub struct SampleArgs { 22 | pub(super) temperature: f32, 23 | pub(super) 
top_p: f32, 24 | pub(super) top_k: usize, 25 | } 26 | 27 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 28 | pub enum SampleArgsError { 29 | NegativeTemperature, 30 | NonPositiveTop, 31 | } 32 | 33 | impl Args { 34 | pub fn layout(dt: DigitLayout, n: usize) -> Self { 35 | Args { 36 | kv_pair: TensorLayout::new(KVPair::<()>::LAYOUT, &[], &[]), 37 | kv_pair_base: null_mut(), 38 | logits: TensorLayout::new(dt, &[n], &[dt.nbytes() as _]), 39 | logits_base: null(), 40 | indices: TensorLayout::new(ty::U32, &[n], &[ty::U32.nbytes() as _]), 41 | indices_base: null(), 42 | config: SampleArgs { 43 | temperature: 0., 44 | top_p: 0., 45 | top_k: usize::MAX, 46 | }, 47 | seed: 0., 48 | } 49 | } 50 | } 51 | 52 | impl Default for SampleArgs { 53 | #[inline] 54 | fn default() -> Self { 55 | Self::ARG_MAX 56 | } 57 | } 58 | 59 | impl SampleArgs { 60 | pub const ARG_MAX: Self = Self { 61 | temperature: 0., 62 | top_p: 1., 63 | top_k: usize::MAX, 64 | }; 65 | 66 | pub fn new(temperature: f32, top_p: f32, top_k: usize) -> Result { 67 | if temperature < 0. { 68 | return Err(SampleArgsError::NegativeTemperature); 69 | } 70 | if top_k == 0 || top_p <= 0. { 71 | return Err(SampleArgsError::NonPositiveTop); 72 | } 73 | Ok(Self { 74 | temperature, 75 | top_p: f32::min(top_p, 1.), 76 | top_k, 77 | }) 78 | } 79 | 80 | #[inline] 81 | pub fn is_argmax(&self) -> bool { 82 | self.temperature == 0. || self.top_k == 1 83 | } 84 | } 85 | 86 | #[derive(PartialEq, Eq, Debug)] 87 | pub(super) struct Meta { 88 | pub dt: DigitLayout, 89 | pub n: MaybeDyn, 90 | } 91 | 92 | impl Args { 93 | pub(super) fn meta(&self) -> Result { 94 | let Self { 95 | kv_pair, 96 | logits, 97 | indices, 98 | .. 99 | } = self; 100 | 101 | if kv_pair.dt() != KVPair::<()>::LAYOUT { 102 | return Err(type_not_support("output must be KVpair")); 103 | } 104 | 105 | let dt_p = logits.dt(); 106 | if dt_p.nbytes() > size_of::() { 107 | return Err(type_not_support("element too large")); 108 | } 109 | if indices.dt() != ty::U32 { 110 | return Err(type_not_support("indices must be u32")); 111 | } 112 | let &[n] = self.logits.shape() else { 113 | return Err(rank_error("logits", 1, self.logits.ndim())); 114 | }; 115 | let &[_] = self.indices.shape() else { 116 | return Err(rank_error("indices", 1, self.indices.ndim())); 117 | }; 118 | 119 | Ok(Meta { dt: dt_p, n }) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /operators/src/random_sample/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Indices, KVPair, RandomSample, SampleArgs}; 2 | use crate::{ 3 | common_cpu::Cpu, get_static, strides_not_support, type_not_support, ByteOf, LaunchError, 4 | QueueAlloc, SchemeError, 5 | }; 6 | use half::f16; 7 | use num_traits::Float; 8 | use std::{cmp::Ordering::Equal, slice::from_raw_parts}; 9 | 10 | pub struct Operator; 11 | 12 | impl RandomSample for Operator { 13 | fn build_indices(_n: usize, queue_alloc: &QA) -> Indices 14 | where 15 | QA: QueueAlloc, 16 | { 17 | Indices { 18 | n: 0, 19 | mem: queue_alloc.alloc(0), 20 | } 21 | } 22 | } 23 | 24 | impl crate::Operator for Operator { 25 | type Hardware = Cpu; 26 | type TopoNode = Cpu; 27 | type Args = Args; 28 | 29 | fn new(_node: &Self::TopoNode) -> Self { 30 | Self 31 | } 32 | 33 | fn scheme( 34 | &mut self, 35 | _args: &Self::Args, 36 | _max_workspace_size: usize, 37 | ) -> Result { 38 | Ok(0) 39 | } 40 | 41 | fn launch( 42 | &self, 43 | args: &Self::Args, 44 | _workspace: &mut 
[ByteOf], 45 | _queue_alloc: &QA, 46 | ) -> Result<(), LaunchError> 47 | where 48 | QA: QueueAlloc, 49 | { 50 | let Meta { dt, n } = args.meta()?; 51 | let &[s] = args.logits.strides() else { 52 | unreachable!() 53 | }; 54 | if s.get_static().copied() != Some(dt.nbytes() as isize) { 55 | return Err(strides_not_support("").into()); 56 | } 57 | 58 | get_static!(n); 59 | let Args { 60 | kv_pair_base, 61 | logits_base, 62 | config, 63 | seed, 64 | .. 65 | } = args; 66 | 67 | use digit_layout::types as ty; 68 | let kv = if config.is_argmax() { 69 | macro_rules! argmax { 70 | ($ty:ty) => { 71 | argmax::<$ty>(*logits_base, n).into_raw() 72 | }; 73 | } 74 | match dt { 75 | ty::F16 => argmax!(f16), 76 | ty::F32 => argmax!(f32), 77 | e => return Err(type_not_support(format!("{e} not support")).into()), 78 | } 79 | } else { 80 | let &SampleArgs { 81 | temperature, 82 | top_p, 83 | top_k, 84 | } = config; 85 | macro_rules! random { 86 | ($ty:ty) => { 87 | random::<$ty>(*logits_base, n, temperature, top_p, top_k, *seed).into_raw() 88 | }; 89 | } 90 | match dt { 91 | ty::F16 => random!(f16), 92 | ty::F32 => random!(f32), 93 | e => return Err(type_not_support(format!("{e} not support")).into()), 94 | } 95 | }; 96 | unsafe { kv_pair_base.cast::>().write(kv) }; 97 | 98 | Ok(()) 99 | } 100 | } 101 | 102 | fn argmax(ptr: *const u8, len: usize) -> KVPair { 103 | let (key, val) = unsafe { from_raw_parts(ptr.cast::(), len) } 104 | .iter() 105 | .enumerate() 106 | .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Equal)) 107 | .unwrap(); 108 | KVPair::new(key as _, *val) 109 | } 110 | 111 | fn random( 112 | ptr: *const u8, 113 | len: usize, 114 | t: f32, 115 | top_p: f32, 116 | top_k: usize, 117 | seed: f32, 118 | ) -> KVPair { 119 | // sort 120 | let ptr = ptr as usize; 121 | let mut logits = (0..len) 122 | .map(|idx| { 123 | KVPair::new( 124 | idx as _, 125 | unsafe { &*(ptr as *const T).add(idx) }.to_f32().unwrap(), 126 | ) 127 | }) 128 | .collect::>(); 129 | logits.sort_unstable(); 130 | let max = logits[0].val(); 131 | logits[0].set_val(1.); 132 | // softmax & sum 133 | for i in 1..logits.len() { 134 | let softmax = logits[i - 1].val() + ((logits[i].val() - max) / t).exp(); 135 | logits[i].set_val(softmax); 136 | } 137 | // topk & topp & random 138 | let pk = logits[top_k.min(logits.len()) - 1].val(); 139 | let pp = logits[logits.len() - 1].val() * top_p; 140 | let plimit = seed * f32::min(pk, pp); 141 | // sample 142 | let ans = *logits.iter().find(|p| p.val() >= plimit).unwrap(); 143 | KVPair::new(ans.idx() as _, T::from(ans.val()).unwrap()) 144 | } 145 | -------------------------------------------------------------------------------- /operators/src/random_sample/cuda/ffi.rs: -------------------------------------------------------------------------------- 1 | use crate::{random_sample::SampleArgs, LaunchError}; 2 | use cuda::{bindings::CUstream, AsRaw, DevByte, Stream}; 3 | use libloading::Library; 4 | 5 | type WorkspaceFunc = unsafe extern "C" fn( 6 | *mut usize, // argmax 7 | *mut usize, // random_sample 8 | usize, // n 9 | ) -> i32; 10 | 11 | type ArgMaxFunc = unsafe extern "C" fn( 12 | *mut DevByte, // - kv_pair 13 | *const DevByte, // logits 14 | usize, // n 15 | *mut DevByte, // - workspace_ptr 16 | usize, // workspace_len 17 | CUstream, // stream 18 | ) -> i32; 19 | 20 | type SampleFunc = unsafe extern "C" fn( 21 | *mut DevByte, // - kv_pair 22 | *const DevByte, // logits 23 | *const DevByte, // indices 24 | usize, // n 25 | f32, // - seed 26 | f32, // temperature 27 | f32, // topp 28 | usize, // 
topk 29 | *mut DevByte, // - workspace_ptr 30 | usize, // workspace_len 31 | CUstream, // stream 32 | ) -> i32; 33 | 34 | macro_rules! extern_c { 35 | ($ty:ty; $lib:expr, $name:expr; $($args:expr),* $(,)?) => {{ 36 | let result = unsafe { $lib.get::<$ty>($name.as_bytes()).unwrap()( $( $args ),* ) }; 37 | if result == ::cuda::bindings::CUresult::CUDA_SUCCESS as _ { 38 | Ok(()) 39 | } else { 40 | Err($crate::execution_failed(format!( 41 | "{} failed with cuda error code {result}", 42 | $name 43 | ))) 44 | } 45 | }}; 46 | } 47 | 48 | pub(super) fn workspace_size( 49 | lib: &Library, 50 | name: &str, 51 | n: usize, 52 | ) -> Result<(usize, usize), LaunchError> { 53 | let mut argmax_size = 0; 54 | let mut sample_size = 0; 55 | extern_c!(WorkspaceFunc; lib, name; &mut argmax_size, &mut sample_size, n)?; 56 | Ok((argmax_size, sample_size)) 57 | } 58 | 59 | pub(super) fn argmax( 60 | lib: &Library, 61 | name: &str, 62 | kv_pair: *mut DevByte, 63 | logits: *const DevByte, 64 | n: usize, 65 | workspace: &mut [DevByte], 66 | stream: &Stream, 67 | ) -> Result<(), LaunchError> { 68 | extern_c! { ArgMaxFunc; 69 | lib, name; 70 | 71 | kv_pair, 72 | logits, 73 | n, 74 | 75 | workspace.as_mut_ptr(), 76 | workspace.len(), 77 | stream.as_raw(), 78 | } 79 | } 80 | 81 | #[allow(clippy::too_many_arguments)] 82 | pub(super) fn sample( 83 | lib: &Library, 84 | name: &str, 85 | kv_pair: *mut DevByte, 86 | logits: *const DevByte, 87 | indices: *const DevByte, 88 | n: usize, 89 | config: SampleArgs, 90 | seed: f32, 91 | workspace: &mut [DevByte], 92 | stream: &Stream, 93 | ) -> Result<(), LaunchError> { 94 | extern_c! { SampleFunc; 95 | lib, name; 96 | 97 | kv_pair, 98 | logits, 99 | indices, 100 | n, 101 | 102 | seed, 103 | config.temperature, 104 | config.top_p, 105 | config.top_k, 106 | 107 | workspace.as_mut_ptr(), 108 | workspace.len(), 109 | stream.as_raw(), 110 | } 111 | } 112 | 113 | pub(super) fn format_code( 114 | dt: &str, 115 | workspace_name: &str, 116 | argmax_name: &str, 117 | sample_name: &str, 118 | ) -> String { 119 | use crate::cuda::{EXPORT, EXPORT_H}; 120 | const CODE: &str = include_str!("sample.cuh"); 121 | 122 | format!( 123 | r#" 124 | {EXPORT_H} 125 | {CODE} 126 | 127 | {EXPORT}cudaError {workspace_name}( 128 | size_t *argmax, 129 | size_t *random_sample, 130 | size_t n 131 | ) {{ 132 | return calculate_workspace_size<{dt}>(argmax, random_sample, n); 133 | }} 134 | 135 | {EXPORT}cudaError {argmax_name}( 136 | cub::KeyValuePair *kv_pair, 137 | {dt} const *logits, 138 | size_t n, 139 | 140 | void *workspace_ptr, 141 | size_t workspace_len, 142 | cudaStream_t stream 143 | ) {{ 144 | return arg_max( 145 | kv_pair, 146 | logits, 147 | n, 148 | 149 | workspace_ptr, 150 | workspace_len, 151 | stream); 152 | }} 153 | 154 | {EXPORT}cudaError {sample_name}( 155 | cub::KeyValuePair *kv_pair, 156 | {dt} const *logits, 157 | unsigned int const *indices, 158 | size_t n, 159 | 160 | float random, 161 | float temperature, 162 | float topp, 163 | size_t topk, 164 | 165 | void *workspace_ptr, 166 | size_t workspace_len, 167 | cudaStream_t stream 168 | ) {{ 169 | return random_sample( 170 | kv_pair, 171 | logits, 172 | indices, 173 | n, 174 | 175 | random, 176 | temperature, 177 | topp, 178 | topk, 179 | 180 | workspace_ptr, 181 | workspace_len, 182 | stream); 183 | }} 184 | "# 185 | ) 186 | } 187 | -------------------------------------------------------------------------------- /operators/src/random_sample/kv_pair.rs: -------------------------------------------------------------------------------- 1 | use 
digit_layout::layout; 2 | use std::{ 3 | cmp::Ordering::{self, Equal}, 4 | marker::PhantomData, 5 | mem::{align_of, size_of}, 6 | }; 7 | 8 | #[derive(Clone, Copy, Debug)] 9 | #[repr(C)] 10 | pub struct KVPair { 11 | idx: u32, 12 | val: u32, 13 | _phantom: PhantomData, 14 | } 15 | 16 | impl KVPair { 17 | layout!(LAYOUT u(32); 2); 18 | 19 | pub fn new(idx: u32, val: T) -> Self { 20 | const { assert!(size_of::() <= size_of::()) } 21 | const { assert!(align_of::() <= align_of::()) } 22 | 23 | let mut val_bytes = 0; 24 | let ptr = std::ptr::from_mut(&mut val_bytes).cast::(); 25 | unsafe { ptr.write(val) }; 26 | 27 | Self { 28 | idx, 29 | val: val_bytes, 30 | _phantom: PhantomData, 31 | } 32 | } 33 | 34 | #[inline] 35 | pub fn into_raw(self) -> KVPair<()> { 36 | KVPair { 37 | idx: self.idx, 38 | val: self.val, 39 | _phantom: PhantomData, 40 | } 41 | } 42 | 43 | #[inline] 44 | pub const fn idx(&self) -> usize { 45 | self.idx as _ 46 | } 47 | 48 | #[inline] 49 | pub const fn val(&self) -> T { 50 | let bytes = self.val.to_ne_bytes(); 51 | unsafe { bytes.as_ptr().cast::().read() } 52 | } 53 | } 54 | 55 | impl KVPair { 56 | #[inline] 57 | pub fn set_val(&mut self, val: f32) { 58 | self.val = val.to_bits(); 59 | } 60 | } 61 | 62 | impl PartialEq for KVPair { 63 | fn eq(&self, other: &Self) -> bool { 64 | self.cmp(other) == Equal 65 | } 66 | } 67 | impl Eq for KVPair {} 68 | impl PartialOrd for KVPair { 69 | fn partial_cmp(&self, other: &Self) -> Option { 70 | Some(self.cmp(other)) 71 | } 72 | } 73 | impl Ord for KVPair { 74 | fn cmp(&self, other: &Self) -> Ordering { 75 | match self.val().partial_cmp(&other.val()) { 76 | Some(Equal) => match self.idx.cmp(&other.idx) { 77 | Equal => Equal, 78 | ord => ord, 79 | }, 80 | Some(ord) => ord.reverse(), 81 | None => Equal, 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /operators/src/random_sample/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | mod kv_pair; 12 | 13 | pub use args::{Args, SampleArgs, SampleArgsError}; 14 | pub use kv_pair::KVPair; 15 | 16 | crate::op_trait! 
{ RandomSample 17 | fn build_indices(n: usize, queue_alloc: &QA) -> Indices 18 | where QA: crate::QueueAlloc; 19 | } 20 | 21 | pub struct Indices { 22 | pub n: usize, 23 | pub mem: Mem, 24 | } 25 | -------------------------------------------------------------------------------- /operators/src/random_sample/opencl/random_sample.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Tval 5 | #define Tval float 6 | #endif 7 | 8 | // assert: GROUP_SIZE is power of 2 9 | #ifndef GROUP_SIZE 10 | #define GROUP_SIZE 512 11 | #endif 12 | 13 | typedef unsigned int Tidx; 14 | 15 | typedef struct { 16 | Tidx idx; 17 | Tval val; 18 | } KVPair; 19 | 20 | KVPair group_argmax(local KVPair *data, KVPair reg) { 21 | Tidx const idx = get_local_id(0), 22 | len = get_local_size(0); 23 | 24 | data[idx] = reg; 25 | barrier(CLK_LOCAL_MEM_FENCE); 26 | 27 | for (Tidx stride = len >> 1; stride; stride >>= 1) { 28 | if (idx < stride) { 29 | local KVPair 30 | *a = data + idx, 31 | *b = data + idx + stride; 32 | if (b->val > a->val) *a = *b; 33 | } 34 | barrier(CLK_LOCAL_MEM_FENCE); 35 | } 36 | 37 | return data[0]; 38 | } 39 | 40 | kernel void argmax_build_pairs( 41 | global Tval const *input, 42 | global KVPair *output, 43 | Tidx const n, 44 | float init) { 45 | 46 | Tidx const 47 | g_idx = get_global_id(0), 48 | g_len = get_global_size(0), 49 | l_idx = get_local_id(0); 50 | 51 | // register: 每个线程可能处理多个数据,汇总到寄存器中 52 | // NOTICE 为保证线程利用率,每个线程应该处理至少 2 个数据 53 | KVPair reg = {-1, (Tval) init}; 54 | for (Tidx i = g_idx; i < n; i += g_len) { 55 | Tval const val = input[i]; 56 | if (val > reg.val) reg = (KVPair) {i, val}; 57 | } 58 | 59 | // local memory: 每个工作组在工作组内存中实现规约 60 | local KVPair kv_pairs[GROUP_SIZE]; 61 | reg = group_argmax(kv_pairs, reg); 62 | 63 | // 最终结果写回 global 64 | if (l_idx == 0) output[g_idx / GROUP_SIZE] = reg; 65 | } 66 | 67 | kernel void argmax_reduce( 68 | global KVPair const *pairs, 69 | global KVPair *output, 70 | Tidx const n, 71 | float init) { 72 | 73 | Tidx const 74 | g_idx = get_global_id(0), 75 | g_len = get_global_size(0), 76 | l_idx = get_local_id(0); 77 | 78 | // register: 每个线程可能处理多个数据,汇总到寄存器中 79 | // NOTICE 为保证线程利用率,每个线程应该处理至少 2 个数据 80 | KVPair reg = {-1, (Tval) init}; 81 | for (Tidx i = g_idx; i < n; i += g_len) { 82 | KVPair const pair = pairs[i]; 83 | if (pair.val > reg.val) reg = pair; 84 | } 85 | 86 | // local memory: 每个工作组在工作组内存中实现规约 87 | local KVPair kv_pairs[GROUP_SIZE]; 88 | reg = group_argmax(kv_pairs, reg); 89 | 90 | // 最终结果写回 global 91 | if (l_idx == 0) *output = reg; 92 | } 93 | -------------------------------------------------------------------------------- /operators/src/rearrange/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Scheme, Args, Rearrange}; 2 | use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 4 | 5 | pub struct Operator; 6 | 7 | impl Rearrange for Operator {} 8 | 9 | impl crate::Operator for Operator { 10 | type Hardware = Cpu; 11 | type TopoNode = Cpu; 12 | type Args = Args; 13 | 14 | fn new(_node: &Self::TopoNode) -> Self { 15 | Self 16 | } 17 | 18 | fn scheme( 19 | &mut self, 20 | _args: &Self::Args, 21 | _max_workspace_size: usize, 22 | ) -> Result { 23 | Ok(0) 24 | } 25 | 26 | fn launch( 27 | &self, 28 | args: &Self::Args, 29 | _workspace: &mut [ByteOf], 30 | 
_queue_alloc: &QA, 31 | ) -> Result<(), LaunchError> 32 | where 33 | QA: QueueAlloc, 34 | { 35 | let scheme = Scheme::new(args)?; 36 | let unit = scheme.unit(); 37 | if scheme.count() == 1 { 38 | unsafe { std::ptr::copy_nonoverlapping(args.src_base, args.dst_base, unit) }; 39 | } else { 40 | let dst = args.dst_base as isize; 41 | let src = args.src_base as isize; 42 | let idx_strides = scheme.idx_strides(); 43 | let dst_strides = scheme.dst_strides(); 44 | let src_strides = scheme.src_strides(); 45 | (0..scheme.count() as isize) 46 | .into_par_iter() 47 | .for_each(|mut rem| { 48 | let mut dst = dst; 49 | let mut src = src; 50 | for (i, &s) in idx_strides.iter().enumerate() { 51 | let k = rem / s; 52 | dst += k * dst_strides[i]; 53 | src += k * src_strides[i]; 54 | rem %= s; 55 | } 56 | unsafe { std::ptr::copy_nonoverlapping::(src as _, dst as _, unit) }; 57 | }); 58 | } 59 | Ok(()) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /operators/src/rearrange/cuda/rearrange.cuh: -------------------------------------------------------------------------------- 1 | template 2 | static __device__ void rearrange( 3 | void *__restrict__ dst, 4 | int const rsa, 5 | int const csa, 6 | void const *__restrict__ src, 7 | int const rsb, 8 | int const csb, 9 | unsigned int const ncols) { 10 | 11 | auto row = blockIdx.y, 12 | col = blockIdx.x * blockDim.y + threadIdx.y; 13 | if (col >= ncols) return; 14 | 15 | auto thread = threadIdx.x, 16 | warp_size = blockDim.x; 17 | auto i = (row * rsa + col * csa) * warp_size + thread; 18 | auto j = (row * rsb + col * csb) * warp_size + thread; 19 | // printf("%d %d %d %d: row = %d, col = %d, nrows = %d, ncols = %d, rsa = %d, rsb = %d, csa = %d, csb = %d, warp_size = %d, thread = %d, i = %d, j = %d\n", 20 | // blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x, row, col, gridDim.y, ncols, rsa, rsb, csa, csb, warp_size, thread, i, j); 21 | 22 | reinterpret_cast(dst)[i] = reinterpret_cast(src)[j]; 23 | } 24 | -------------------------------------------------------------------------------- /operators/src/rearrange/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(Rearrange); 14 | -------------------------------------------------------------------------------- /operators/src/rearrange/opencl/rearrange.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | __kernel void rearrange( 5 | __global unsigned int *dst, 6 | unsigned int rsa, 7 | unsigned int csa, 8 | __global unsigned int *src, 9 | unsigned int rsb, 10 | unsigned int csb, 11 | unsigned int ncols, 12 | unsigned int unit) { 13 | 14 | int g_id = get_global_id(0); 15 | int group_id = g_id / unit; 16 | int l_id = g_id % unit; 17 | 18 | int rows = group_id / ncols; 19 | int cols = group_id % ncols; 20 | 21 | int i = (rows * rsa + cols * csa) * unit + l_id; 22 | int j = (rows * rsb + cols * csb) * unit + l_id; 23 | dst[i] = src[j]; 24 | } 25 | -------------------------------------------------------------------------------- /operators/src/rms_norm/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | 
utils::{dim_distinct, rank_error, type_distinct}, 3 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 4 | }; 5 | use digit_layout::DigitLayout; 6 | 7 | pub struct Args { 8 | pub y_layout: TensorLayout, 9 | pub y_base: MutPtr, 10 | pub x_layout: TensorLayout, 11 | pub x_base: ConstPtr, 12 | pub w_layout: TensorLayout, 13 | pub w_base: ConstPtr, 14 | pub epsilon: f32, 15 | } 16 | 17 | pub(super) struct Meta { 18 | pub dt_a: DigitLayout, 19 | pub dt_w: DigitLayout, 20 | pub n: MaybeDyn, 21 | pub d: MaybeDyn, 22 | } 23 | 24 | impl Args { 25 | pub(super) fn meta(&self) -> Result { 26 | let Self { 27 | y_layout, 28 | x_layout, 29 | w_layout, 30 | .. 31 | } = self; 32 | 33 | let &[ny, dy] = y_layout.shape() else { 34 | return Err(rank_error("y", 2, y_layout.ndim())); 35 | }; 36 | let &[nx, dx] = x_layout.shape() else { 37 | return Err(rank_error("x", 2, x_layout.ndim())); 38 | }; 39 | let &[dw] = w_layout.shape() else { 40 | return Err(rank_error("w", 1, w_layout.ndim())); 41 | }; 42 | 43 | Ok(Meta { 44 | dt_a: type_distinct(&[y_layout.dt(), x_layout.dt()])?, 45 | dt_w: w_layout.dt(), 46 | n: dim_distinct(&[ny, nx])?, 47 | d: dim_distinct(&[dy, dx, dw])?, 48 | }) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /operators/src/rms_norm/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, RmsNorm}; 2 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use half::f16; 4 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 5 | 6 | pub struct Operator; 7 | 8 | impl RmsNorm for Operator {} 9 | 10 | impl crate::Operator for Operator { 11 | type Hardware = Cpu; 12 | type TopoNode = Cpu; 13 | type Args = Args; 14 | 15 | fn new(_node: &Self::TopoNode) -> Self { 16 | Self 17 | } 18 | 19 | fn scheme( 20 | &mut self, 21 | args: &Self::Args, 22 | _max_workspace_size: usize, 23 | ) -> Result { 24 | let _meta = args.meta()?; 25 | Ok(0) 26 | } 27 | 28 | fn launch( 29 | &self, 30 | args: &Self::Args, 31 | _workspace: &mut [ByteOf], 32 | _queue_alloc: &QA, 33 | ) -> Result<(), LaunchError> 34 | where 35 | QA: QueueAlloc, 36 | { 37 | let Meta { dt_w, dt_a, n, d } = args.meta()?; 38 | let Args { 39 | y_layout, 40 | y_base, 41 | x_layout, 42 | x_base, 43 | w_layout, 44 | w_base, 45 | epsilon, 46 | } = args; 47 | let &[nsy, dsy] = y_layout.strides() else { 48 | unreachable!() 49 | }; 50 | let &[nsx, dsx] = x_layout.strides() else { 51 | unreachable!() 52 | }; 53 | let &[dsw] = w_layout.strides() else { 54 | unreachable!() 55 | }; 56 | 57 | get_static! { 58 | n d 59 | nsy dsy 60 | nsx dsx 61 | dsw 62 | } 63 | 64 | macro_rules! 
calculate { 65 | ($w:ty, $a:ty) => { 66 | Scheme::<$w, $a> { 67 | n, 68 | d, 69 | nsy, 70 | dsy, 71 | nsx, 72 | dsx, 73 | dsw, 74 | epsilon: *epsilon, 75 | y: y_base.cast(), 76 | x: x_base.cast(), 77 | w: w_base.cast(), 78 | } 79 | .calculate() 80 | }; 81 | } 82 | 83 | use digit_layout::types as ty; 84 | match (dt_w, dt_a) { 85 | (ty::F16, ty::F16) => calculate!(f16, f16), 86 | (ty::F32, ty::F16) => calculate!(f32, f16), 87 | (ty::F32, ty::F32) => calculate!(f32, f32), 88 | (ty::F64, ty::F64) => calculate!(f64, f64), 89 | (_, _) => todo!(), 90 | } 91 | 92 | Ok(()) 93 | } 94 | } 95 | 96 | struct Scheme { 97 | n: usize, 98 | d: usize, 99 | nsy: isize, 100 | dsy: isize, 101 | nsx: isize, 102 | dsx: isize, 103 | dsw: isize, 104 | epsilon: f32, 105 | y: *mut A, 106 | x: *const A, 107 | w: *const W, 108 | } 109 | 110 | unsafe impl Send for Scheme {} 111 | unsafe impl Sync for Scheme {} 112 | 113 | impl Scheme { 114 | #[inline] 115 | unsafe fn y_ptr(&self, i: isize, j: isize) -> *mut A { 116 | self.y.byte_offset(i * self.nsy + j * self.dsy) 117 | } 118 | #[inline] 119 | unsafe fn x_ptr(&self, i: isize, j: isize) -> *const A { 120 | self.x.byte_offset(i * self.nsx + j * self.dsx) 121 | } 122 | #[inline] 123 | unsafe fn w_ptr(&self, j: isize) -> *const W { 124 | self.w.byte_offset(j * self.dsw) 125 | } 126 | } 127 | 128 | macro_rules! impl_k { 129 | ($ty:ty) => { 130 | fn k(&self, i: isize) -> $ty { 131 | let sum = (0..self.d as isize) 132 | .map(|j| unsafe { self.x(i, j) }.powi(2)) 133 | .sum::<$ty>(); 134 | (sum / (self.d as $ty) + self.epsilon as $ty).sqrt().recip() 135 | } 136 | }; 137 | } 138 | 139 | impl Scheme { 140 | impl_k!(f32); 141 | 142 | #[inline] 143 | unsafe fn y(&self, i: isize, j: isize, val: f32) { 144 | self.y_ptr(i, j).write(f16::from_f32(val)) 145 | } 146 | #[inline] 147 | unsafe fn x(&self, i: isize, j: isize) -> f32 { 148 | self.x_ptr(i, j).read().to_f32() 149 | } 150 | } 151 | impl Scheme { 152 | impl_k!(f32); 153 | 154 | #[inline] 155 | unsafe fn y(&self, i: isize, j: isize, val: f32) { 156 | self.y_ptr(i, j).write(val) 157 | } 158 | #[inline] 159 | unsafe fn x(&self, i: isize, j: isize) -> f32 { 160 | self.x_ptr(i, j).read() 161 | } 162 | } 163 | impl Scheme { 164 | impl_k!(f64); 165 | 166 | #[inline] 167 | unsafe fn y(&self, i: isize, j: isize, val: f64) { 168 | self.y_ptr(i, j).write(val) 169 | } 170 | #[inline] 171 | unsafe fn x(&self, i: isize, j: isize) -> f64 { 172 | self.x_ptr(i, j).read() 173 | } 174 | } 175 | 176 | impl Scheme { 177 | #[inline] 178 | unsafe fn w(&self, j: isize) -> f32 { 179 | self.w_ptr(j).read().to_f32() 180 | } 181 | } 182 | impl Scheme { 183 | #[inline] 184 | unsafe fn w(&self, j: isize) -> f32 { 185 | self.w_ptr(j).read() 186 | } 187 | } 188 | impl Scheme { 189 | #[inline] 190 | unsafe fn w(&self, j: isize) -> f64 { 191 | self.w_ptr(j).read() 192 | } 193 | } 194 | 195 | macro_rules! 
impl_scheme { 196 | ($w:ty, $a:ty) => { 197 | impl Scheme<$w, $a> { 198 | fn calculate(self) { 199 | for i in 0..self.n as isize { 200 | let k = self.k(i); 201 | (0..self.d as isize) 202 | .into_par_iter() 203 | .for_each(|j| unsafe { self.y(i, j, k * self.w(j) * self.x(i, j)) }); 204 | } 205 | } 206 | } 207 | }; 208 | } 209 | 210 | impl_scheme!(f16, f16); 211 | impl_scheme!(f32, f16); 212 | impl_scheme!(f32, f32); 213 | impl_scheme!(f64, f64); 214 | -------------------------------------------------------------------------------- /operators/src/rms_norm/cuda/rms_norm.cuh: -------------------------------------------------------------------------------- 1 | #include <cub/block/block_load.cuh> 2 | #include <cub/block/block_reduce.cuh> 3 | #include <cub/block/block_store.cuh> 4 | 5 | // assert BLOCK_SIZE >= blockDim.x 6 | template<class Ta, class Tw, unsigned int BLOCK_SIZE> 7 | static __device__ void padding( 8 | Ta *__restrict__ y_, 9 | int const stride_y, 10 | Ta const *__restrict__ x_, 11 | int const stride_x, 12 | Tw const *__restrict__ w_, 13 | float const epsilon) { 14 | auto y = y_ + blockIdx.x * stride_y + threadIdx.x; 15 | float const 16 | x = x_[blockIdx.x * stride_x + threadIdx.x], 17 | w = w_[threadIdx.x]; 18 | 19 | using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>; 20 | __shared__ typename BlockOp::TempStorage temp_storage; 21 | auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum()); 22 | 23 | __shared__ float rms; 24 | if (threadIdx.x == 0) { 25 | rms = rsqrtf(acc / float(blockDim.x) + epsilon); 26 | } 27 | __syncthreads(); 28 | 29 | *y = Ta(rms * x * w); 30 | } 31 | 32 | template<class Ta, class Tw, unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD> 33 | static __device__ void folding( 34 | Ta *__restrict__ y, 35 | int const stride_y, 36 | Ta const *__restrict__ x, 37 | int const stride_x, 38 | Tw const *__restrict__ w, 39 | float const epsilon, 40 | unsigned int const items_size) { 41 | y += blockIdx.x * stride_y; 42 | x += blockIdx.x * stride_x; 43 | 44 | float data[NUM_ITEMS_THREAD], weight[NUM_ITEMS_THREAD]; 45 | { 46 | using BlockOp = cub::BlockLoad<float, BLOCK_SIZE, NUM_ITEMS_THREAD>; 47 | __shared__ typename BlockOp::TempStorage temp_storage; 48 | BlockOp(temp_storage).Load(x, data, items_size, 0.f); 49 | BlockOp(temp_storage).Load(w, weight, items_size, 0.f); 50 | } 51 | 52 | float squared = 0; 53 | #pragma unroll 54 | for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) { 55 | squared += data[i] * data[i]; 56 | } 57 | 58 | float acc; 59 | { 60 | using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>; 61 | __shared__ typename BlockOp::TempStorage temp_storage; 62 | acc = BlockOp(temp_storage).Reduce(squared, cub::Sum()); 63 | } 64 | 65 | __shared__ float rms; 66 | if (threadIdx.x == 0) { 67 | rms = rsqrtf(acc / float(items_size) + epsilon); 68 | } 69 | __syncthreads(); 70 | 71 | #pragma unroll 72 | for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) { 73 | data[i] = rms * data[i] * weight[i]; 74 | } 75 | 76 | { 77 | using BlockOp = cub::BlockStore<float, BLOCK_SIZE, NUM_ITEMS_THREAD>; 78 | __shared__ typename BlockOp::TempStorage temp_storage; 79 | BlockOp(temp_storage).Store(y, data, items_size); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /operators/src/rms_norm/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(RmsNorm); 14 | -------------------------------------------------------------------------------- /operators/src/rms_norm/opencl/rms_norm.cl: -------------------------------------------------------------------------------- 1 |
#define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Ta 5 | #define Ta float 6 | #endif 7 | 8 | #ifndef Tw 9 | #define Tw float 10 | #endif 11 | 12 | #ifndef ITEMS_THREAD 13 | #define ITEMS_THREAD 1 14 | #endif 15 | 16 | typedef unsigned int Tidx; 17 | 18 | kernel void rms_norm( 19 | global Ta *y_, 20 | int const y_stride, 21 | global Ta const *x_, 22 | int const x_stride, 23 | global Tw const *w, 24 | float const epsilon, 25 | Tidx const d) { 26 | 27 | Tidx g_idx = get_group_id(0), 28 | l_idx = get_local_id(0), 29 | l_len = get_local_size(0); 30 | global Ta 31 | *y = y_ + g_idx * y_stride; 32 | global Ta const 33 | *x = x_ + g_idx * x_stride; 34 | 35 | float 36 | val_x[ITEMS_THREAD], 37 | val_w[ITEMS_THREAD], 38 | squared = 0; 39 | for (Tidx i = 0, idx = l_idx; idx < d; ++i, idx += l_len) { 40 | val_x[i] = x[idx]; 41 | val_w[i] = w[idx]; 42 | squared += val_x[i] * val_x[i]; 43 | } 44 | 45 | float rms = native_rsqrt(work_group_reduce_add(squared) / d + epsilon); 46 | 47 | for (Tidx i = 0, idx = l_idx; idx < d; ++i, idx += l_len) 48 | y[idx] = rms * val_x[i] * val_w[i]; 49 | } 50 | -------------------------------------------------------------------------------- /operators/src/rope/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | type_not_support, 3 | utils::{dim_distinct, rank_error}, 4 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 5 | }; 6 | use digit_layout::DigitLayout; 7 | 8 | pub struct Args<H: Hardware> { 9 | pub t_layout: TensorLayout, 10 | pub t_base: MutPtr<H>, 11 | pub p_layout: TensorLayout, 12 | pub p_base: ConstPtr<H>, 13 | pub sin_layout: TensorLayout, 14 | pub sin_base: ConstPtr<H>, 15 | pub cos_layout: TensorLayout, 16 | pub cos_base: ConstPtr<H>, 17 | pub theta: f32, 18 | } 19 | 20 | pub(super) struct Meta { 21 | pub dt_t: DigitLayout, 22 | pub dt_p: DigitLayout, 23 | pub nt: MaybeDyn<usize>, 24 | #[allow(dead_code)] 25 | pub dh: MaybeDyn<usize>, 26 | } 27 | 28 | impl<H: Hardware> Args<H> { 29 | pub(super) fn meta(&self) -> Result<Meta, SchemeError> { 30 | let Self { 31 | t_layout, 32 | p_layout, 33 | sin_layout, 34 | cos_layout, 35 | .. 36 | } = self; 37 | 38 | let &[nt, _, dh] = t_layout.shape() else { 39 | return Err(rank_error("t", 3, t_layout.ndim())); 40 | }; 41 | let &[np] = p_layout.shape() else { 42 | return Err(rank_error("p", 1, p_layout.ndim())); 43 | }; 44 | let &[_, dh_sin] = sin_layout.shape() else { 45 | return Err(rank_error("sin", 2, sin_layout.ndim())); 46 | }; 47 | let &[_, dh_cos] = cos_layout.shape() else { 48 | return Err(rank_error("cos", 2, cos_layout.ndim())); 49 | }; 50 | 51 | let dt_t = t_layout.dt(); 52 | let dt_p = p_layout.dt(); 53 | use digit_layout::LayoutContent::{Real, Unsigned}; 54 | // tokens must be floating-point numbers 55 | if !matches!(dt_t.decode(), Real { exponent: 1.., .. },) { 56 | return Err(type_not_support(format!( 57 | "data type {dt_t} is not supported, must be floating-point numbers", 58 | ))); 59 | } 60 | // positions must be unsigned integers 61 | if !matches!(dt_p.decode(), Unsigned { ..
}) { 62 | return Err(type_not_support(format!( 63 | "data type {dt_p} is not supported, must be unsigned integers" 64 | ))); 65 | } 66 | Ok(Meta { 67 | dt_t, 68 | dt_p, 69 | nt: dim_distinct(&[nt, np])?, 70 | dh: dim_distinct(&[dh, dh_sin, dh_cos])?, 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /operators/src/rope/cuda/rope.cuh: -------------------------------------------------------------------------------- 1 | #include <cuda_fp16.h> 2 | 3 | template<class Tp> 4 | static __device__ void padding( 5 | half2 *__restrict__ t, 6 | int const stride_token, 7 | int const stride_head, 8 | Tp const *__restrict__ pos, 9 | float const theta) { 10 | 11 | auto const 12 | // nt = gridDim.y, 13 | // nh_h = gridDim.x, 14 | nh_l = blockDim.y, 15 | dh = blockDim.x, 16 | 17 | it = blockIdx.y, // token index 18 | ih_h = blockIdx.x, // head index (high) 19 | ih_l = threadIdx.y, // head index (low) 20 | ih = ih_h * nh_l + ih_l,// head index 21 | i = threadIdx.x; // element index 22 | 23 | t += it * stride_token + ih * stride_head + i; 24 | float a = t->x, b = t->y, sin, cos; 25 | sincosf(float(pos[it]) / powf(theta, float(i) / float(dh)), &sin, &cos); 26 | *t = half2(a * cos - b * sin, a * sin + b * cos); 27 | } 28 | -------------------------------------------------------------------------------- /operators/src/rope/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait! { Rope 14 | /// Build the sincos table ([2, n, dh]). 15 | fn build_sincos<QA>(dt: digit_layout::DigitLayout, nctx: usize, dh: usize, queue_alloc: &QA) -> SinCosTable<QA::DevMem> 16 | where QA: crate::QueueAlloc<Hardware = Self::Hardware>; 17 | /// Build the position vector ([nt]) for a batch of requests. 18 | fn build_pos<I, QA>(dt: digit_layout::DigitLayout, nt: usize, iter: I, queue_alloc: &QA) -> QA::DevMem 19 | where I: IntoIterator<Item = Seq>, 20 | QA: crate::QueueAlloc<Hardware = Self::Hardware>; 21 | } 22 | 23 | pub struct Seq { 24 | pub pos: usize, 25 | pub len: usize, 26 | } 27 | 28 | pub struct SinCosTable<Mem> { 29 | pub nctx: usize, 30 | pub mem: Mem, 31 | } 32 | 33 | trait PosTy { 34 | fn from_usize(p: usize) -> Self; 35 | } 36 | 37 | impl PosTy for u32 { 38 | fn from_usize(p: usize) -> Self { 39 | p as _ 40 | } 41 | } 42 | 43 | impl PosTy for u64 { 44 | fn from_usize(p: usize) -> Self { 45 | p as _ 46 | } 47 | } 48 | 49 | fn fill_pos<T, I>(ptr: *mut T, len: usize, iter: I) 50 | where 51 | T: PosTy, 52 | I: IntoIterator<Item = Seq>, 53 | { 54 | iter.into_iter() 55 | .flat_map(|seq| seq.pos..seq.pos + seq.len) 56 | .zip(unsafe { std::slice::from_raw_parts_mut(ptr, len) }) 57 | .for_each(|(pos, out)| *out = T::from_usize(pos)) 58 | } 59 | -------------------------------------------------------------------------------- /operators/src/rope/opencl/rope.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 200 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Tval 5 | #define Tval float2 6 | #endif 7 | 8 | #ifndef Tpos 9 | #define Tpos unsigned int 10 | #endif 11 | 12 | #ifdef USE_HALF 13 | #define LOAD_DATA(ptr) vload_half2(0, (__global half *) ptr) 14 | #define STORE_DATA(ptr, val) vstore_half2(val, 0, (__global half *) ptr) 15 | #else 16 | #define LOAD_DATA(ptr) (*ptr) 17 | #define STORE_DATA(ptr, val) (*ptr = val) 18 | #endif 19 | 20 | typedef unsigned int Tidx; 21 | 22 | __kernel void rope( 23 |
__global Tval *t, 24 | int const stride_token, 25 | int const stride_head, 26 | __global Tpos const *pos, 27 | float const theta) { 28 | 29 | Tidx nh_l = get_local_size(0), 30 | dh = get_local_size(1), 31 | it = get_group_id(0), 32 | ih_h = get_group_id(1), 33 | ih_l = get_local_id(0), 34 | ih = ih_h * nh_l + ih_l, 35 | i = get_local_id(1); 36 | 37 | __global Tval *t2 = t + it * stride_token + ih * stride_head + i; 38 | 39 | float2 data = LOAD_DATA(t2); 40 | float angle = (float) (pos[it]) / pow(theta, (float) i / (float) dh); 41 | float sin_val = native_sin(angle); 42 | float cos_val = native_cos(angle); 43 | 44 | float2 result; 45 | result.x = data.x * cos_val - data.y * sin_val; 46 | result.y = data.x * sin_val + data.y * cos_val; 47 | STORE_DATA(t2, result); 48 | } 49 | -------------------------------------------------------------------------------- /operators/src/swiglu/args.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | utils::{dim_distinct, rank_error, type_distinct}, 3 | ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout, 4 | }; 5 | use digit_layout::DigitLayout; 6 | 7 | pub struct Args<H: Hardware> { 8 | pub gate_layout: TensorLayout, 9 | pub gate_base: MutPtr<H>, 10 | pub up_layout: TensorLayout, 11 | pub up_base: ConstPtr<H>, 12 | } 13 | 14 | pub(super) struct Meta { 15 | pub dt: DigitLayout, 16 | pub n: MaybeDyn<usize>, 17 | pub d: MaybeDyn<usize>, 18 | } 19 | 20 | impl<H: Hardware> Args<H> { 21 | pub fn new_layout(gate_layout: TensorLayout, up_layout: TensorLayout) -> Self { 22 | use std::ptr::{null, null_mut}; 23 | Self { 24 | gate_layout, 25 | gate_base: null_mut(), 26 | up_layout, 27 | up_base: null(), 28 | } 29 | } 30 | 31 | pub(super) fn meta(&self) -> Result<Meta, SchemeError> { 32 | let Self { 33 | gate_layout, 34 | up_layout, 35 | .. 36 | } = self; 37 | 38 | let &[gn, gd] = gate_layout.shape() else { 39 | return Err(rank_error("gate", 2, gate_layout.ndim())); 40 | }; 41 | let &[un, ud] = up_layout.shape() else { 42 | return Err(rank_error("up", 2, up_layout.ndim())); 43 | }; 44 | 45 | Ok(Meta { 46 | dt: type_distinct(&[gate_layout.dt(), up_layout.dt()])?, 47 | n: dim_distinct(&[gn, un])?, 48 | d: dim_distinct(&[gd, ud])?, 49 | }) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /operators/src/swiglu/common_cpu/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Swiglu}; 2 | use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use half::f16; 4 | 5 | pub struct Operator; 6 | 7 | impl Swiglu for Operator {} 8 | 9 | impl crate::Operator for Operator { 10 | type Hardware = Cpu; 11 | type TopoNode = Cpu; 12 | type Args = Args<Cpu>; 13 | 14 | fn new(_node: &Self::TopoNode) -> Self { 15 | Self 16 | } 17 | 18 | fn scheme( 19 | &mut self, 20 | args: &Self::Args, 21 | _max_workspace_size: usize, 22 | ) -> Result<usize, SchemeError> { 23 | let _meta = args.meta()?; 24 | Ok(0) 25 | } 26 | 27 | fn launch<QA>( 28 | &self, 29 | args: &Self::Args, 30 | _workspace: &mut [ByteOf<Self::Hardware>], 31 | _queue_alloc: &QA, 32 | ) -> Result<(), LaunchError> 33 | where 34 | QA: QueueAlloc<Hardware = Self::Hardware>, 35 | { 36 | let Meta { dt, n, d } = args.meta()?; 37 | let Args { 38 | gate_layout, 39 | gate_base, 40 | up_layout, 41 | up_base, 42 | } = args; 43 | let &[sgn, sgd] = gate_layout.strides() else { 44 | unreachable!() 45 | }; 46 | let &[sun, sud] = up_layout.strides() else { 47 | unreachable!() 48 | }; 49 | 50 | get_static! { 51 | n d 52 | sgn sgd 53 | sun sud 54 | } 55 | 56 | macro_rules!
calculate { 57 | ($ty:ty) => { 58 | Scheme::<$ty> { 59 | n, 60 | d, 61 | sgn, 62 | sgd, 63 | sun, 64 | sud, 65 | gate_base: gate_base.cast(), 66 | up_base: up_base.cast(), 67 | } 68 | .calculate() 69 | }; 70 | } 71 | 72 | use digit_layout::types as ty; 73 | match dt { 74 | ty::F16 => calculate!(f16), 75 | ty::F32 => calculate!(f32), 76 | ty::F64 => calculate!(f64), 77 | _ => todo!(), 78 | } 79 | Ok(()) 80 | } 81 | } 82 | 83 | struct Scheme<T> { 84 | n: usize, 85 | d: usize, 86 | sgn: isize, 87 | sgd: isize, 88 | sun: isize, 89 | sud: isize, 90 | gate_base: *mut T, 91 | up_base: *const T, 92 | } 93 | 94 | unsafe impl<T> Send for Scheme<T> {} 95 | unsafe impl<T> Sync for Scheme<T> {} 96 | 97 | impl<T: Copy> Scheme<T> { 98 | fn loop_(&self, f: impl Sync + Fn(T, T) -> T) { 99 | for i in 0..self.n as isize { 100 | (0..self.d as isize).for_each(|j| { 101 | let gate = unsafe { &mut *self.gate_base.byte_offset(i * self.sgn + j * self.sgd) }; 102 | let up = unsafe { *self.up_base.byte_offset(i * self.sun + j * self.sud) }; 103 | *gate = f(*gate, up); 104 | }) 105 | } 106 | } 107 | } 108 | 109 | impl Scheme<f16> { 110 | #[inline] 111 | fn calculate(&self) { 112 | self.loop_(|gate, up| { 113 | let a = gate.to_f32(); 114 | let b = up.to_f32(); 115 | f16::from_f32(a * sigmoid_f32(a) * b) 116 | }) 117 | } 118 | } 119 | 120 | impl Scheme<f32> { 121 | #[inline] 122 | fn calculate(&self) { 123 | self.loop_(|gate, up| gate * sigmoid_f32(gate) * up) 124 | } 125 | } 126 | 127 | impl Scheme<f64> { 128 | #[inline] 129 | fn calculate(&self) { 130 | self.loop_(|gate, up| gate * sigmoid_f64(gate) * up) 131 | } 132 | } 133 | 134 | #[inline(always)] 135 | fn sigmoid_f32(x: f32) -> f32 { 136 | 1. / (1. + (-x).exp()) 137 | } 138 | 139 | #[inline(always)] 140 | fn sigmoid_f64(x: f64) -> f64 { 141 | 1. / (1. + (-x).exp()) 142 | } 143 | -------------------------------------------------------------------------------- /operators/src/swiglu/cuda/swiglu.cuh: -------------------------------------------------------------------------------- 1 | static __forceinline__ __device__ float sigmoid(float x) { 2 | return fdividef(1, 1 + expf(-x)); 3 | } 4 | 5 | template<class Tdata> 6 | static __device__ void swiglu( 7 | Tdata *__restrict__ gate_, 8 | int const stride_gate, 9 | Tdata const *__restrict__ up_, 10 | int const stride_up) { 11 | auto k = blockIdx.x * blockDim.x + threadIdx.x, 12 | i = blockIdx.y * stride_gate + k, 13 | j = blockIdx.y * stride_up + k; 14 | auto x = float(gate_[i]), 15 | y = float(up_[j]); 16 | gate_[i] = Tdata(x * sigmoid(x) * y); 17 | } 18 | -------------------------------------------------------------------------------- /operators/src/swiglu/infini/mod.rs: -------------------------------------------------------------------------------- 1 | use super::{args::Meta, Args, Swiglu}; 2 | use crate::{get_static, infini::Device, ByteOf, LaunchError, QueueAlloc, SchemeError}; 3 | use infini_op::{infiniop, AsRaw, Descriptor, Handle}; 4 | use std::sync::Arc; 5 | 6 | pub struct Operator(Arc<Handle>); 7 | 8 | impl Swiglu for Operator {} 9 | 10 | impl crate::Operator for Operator { 11 | type Hardware = Device; 12 | type TopoNode = Device; 13 | type Args = Args<Device>; 14 | 15 | #[inline] 16 | fn new(node: &Self::TopoNode) -> Self { 17 | Self(node.handle().clone()) 18 | } 19 | 20 | #[inline] 21 | fn scheme( 22 | &mut self, 23 | _args: &Self::Args, 24 | _max_workspace_size: usize, 25 | ) -> Result<usize, SchemeError> { 26 | Ok(0) 27 | } 28 | 29 | fn launch<QA>( 30 | &self, 31 | args: &Self::Args, 32 | _workspace: &mut [ByteOf<Self::Hardware>], 33 | queue_alloc: &QA, 34 | ) -> Result<(), LaunchError> 35 | where 36 | QA:
QueueAlloc<Hardware = Self::Hardware>, 37 | { 38 | let Meta { dt, n, d } = args.meta()?; 39 | let Args { 40 | gate_layout, 41 | gate_base, 42 | up_layout, 43 | up_base, 44 | } = args; 45 | let &[gns, gds] = gate_layout.strides() else { 46 | unreachable!() 47 | }; 48 | let &[uns, uds] = up_layout.strides() else { 49 | unreachable!() 50 | }; 51 | 52 | get_static! { 53 | n d 54 | gns gds 55 | uns uds 56 | } 57 | 58 | let gate = infini_op::Tensor::new(dt, [n, d], [gns, gds]); 59 | let up = infini_op::Tensor::new(dt, [n, d], [uns, uds]); 60 | 61 | let descriptor = Descriptor::new( 62 | |ptr| { 63 | infiniop!(infiniopCreateSwiGLUDescriptor( 64 | self.0.as_raw(), 65 | ptr, 66 | gate.as_raw(), 67 | up.as_raw(), 68 | gate.as_raw(), 69 | )) 70 | }, 71 | infini_op::bindings::infiniopDestroySwiGLUDescriptor, 72 | ); 73 | infiniop!(infiniopSwiGLU( 74 | descriptor.as_raw(), 75 | gate_base.cast(), 76 | up_base.cast(), 77 | gate_base.cast(), 78 | queue_alloc.queue().as_void_ptr(), 79 | )); 80 | Ok(()) 81 | } 82 | } 83 | 84 | #[cfg(test)] 85 | mod test { 86 | use super::{Args, Device, Operator}; 87 | use crate::{dyn_, Hardware, Operator as _, TensorLayout}; 88 | use digit_layout::{ 89 | types::{F16, F64}, 90 | DigitLayout, 91 | }; 92 | 93 | fn dyn_args<H: Hardware>(dt: DigitLayout) -> Args<H> { 94 | use std::ptr::{null, null_mut}; 95 | let layout = TensorLayout::new_dyn(dt, &[dyn_(); 2], &[dyn_(); 2]); 96 | Args { 97 | gate_layout: layout.clone(), 98 | gate_base: null_mut(), 99 | up_layout: layout, 100 | up_base: null(), 101 | } 102 | } 103 | 104 | fn args<H: Hardware>( 105 | dt: DigitLayout, 106 | n: usize, 107 | d: usize, 108 | gate_base: *mut H::Byte, 109 | up_base: *const H::Byte, 110 | ) -> Args<H> { 111 | let layout = TensorLayout::new_contiguous(dt, &[n, d]); 112 | Args { 113 | gate_layout: layout.clone(), 114 | gate_base, 115 | up_layout: layout, 116 | up_base, 117 | } 118 | } 119 | 120 | #[test] 121 | fn test_compute() { 122 | use super::super::common_cpu::Operator as RefOp; 123 | use crate::{ 124 | common_cpu::{Cpu, ThisThread}, 125 | infini::cast_load, 126 | test_utils::{Diff, ErrorCollector}, 127 | }; 128 | use half::f16; 129 | use rand::Rng; 130 | 131 | let n = 5632; 132 | let d = 2048; 133 | 134 | infini_rt::init(infini_rt::DEVICE_CPU); 135 | let dev = Device::cpu(); 136 | 137 | let mut cpu_op = RefOp::new(&Cpu); 138 | let mut dev_op = Operator::new(&dev); 139 | cpu_op.scheme(&dyn_args(F64), 0).unwrap(); 140 | dev_op.scheme(&dyn_args(F16), 0).unwrap(); 141 | 142 | let mut gate = vec![0.0f64; n * d]; 143 | let mut up = vec![0.0f64; n * d]; 144 | rand::rng().fill(&mut gate[..]); 145 | rand::rng().fill(&mut up[..]); 146 | let up = up; 147 | 148 | let gate_ans = { 149 | let stream = dev.stream(); 150 | let mut gate = cast_load(&gate, f16::from_f64, &stream); 151 | let up = cast_load(&up, f16::from_f64, &stream); 152 | dev_op 153 | .launch( 154 | &args(F16, n, d, gate.as_mut_ptr().cast(), up.as_ptr().cast()), 155 | &mut [], 156 | &stream, 157 | ) 158 | .unwrap(); 159 | let mut host = vec![f16::ZERO; n * d]; 160 | dev.memcpy_d2h(&mut host, &gate); 161 | host 162 | }; 163 | 164 | let mut gate_ref = gate; 165 | cpu_op 166 | .launch( 167 | &args(F64, n, d, gate_ref.as_mut_ptr().cast(), up.as_ptr().cast()), 168 | &mut [], 169 | &ThisThread, 170 | ) 171 | .unwrap(); 172 | 173 | let diff = gate_ref 174 | .into_iter() 175 | .zip(gate_ans) 176 | .map(|(a, b)| Diff::new(a, b.to_f64())) 177 | .collect::<Vec<_>>(); 178 | 179 | let mut ec = ErrorCollector::new(f16::EPSILON.to_f64(), 0.); 180 | diff.into_iter().for_each(|diff| ec.push(diff)); 181 | println!("{ec}"); 182 | 183
| let (out, count) = ec.summary(); 184 | assert!(out * 1000 <= count); 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /operators/src/swiglu/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(use_cpu, test))] 2 | pub mod common_cpu; 3 | #[cfg(use_cuda)] 4 | pub mod cuda; 5 | #[cfg(use_infini)] 6 | pub mod infini; 7 | #[cfg(use_cl)] 8 | pub mod opencl; 9 | 10 | mod args; 11 | pub use args::Args; 12 | 13 | crate::op_trait!(Swiglu); 14 | -------------------------------------------------------------------------------- /operators/src/swiglu/opencl/swiglu.cl: -------------------------------------------------------------------------------- 1 | #define CL_TARGET_OPENCL_VERSION 300 2 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 3 | 4 | #ifndef Tval 5 | #define Tval float 6 | #endif 7 | 8 | typedef unsigned int Tidx; 9 | 10 | __kernel void swiglu( 11 | __global Tval *gate, 12 | int const stride_gate, 13 | __global Tval *up, 14 | int const stride_up) { 15 | 16 | Tidx g_idx = get_global_id(0); 17 | Tidx g_idy = get_global_id(1); 18 | 19 | Tidx i = g_idx * stride_gate + g_idy; 20 | Tidx j = g_idx * stride_up + g_idy; 21 | 22 | Tval x = gate[i]; 23 | Tval y = up[j]; 24 | 25 | Tval sig = 1.0f / (1.0f + exp(-x)); 26 | gate[i] = x * sig * y; 27 | } 28 | --------------------------------------------------------------------------------
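For orientation, a minimal usage sketch follows (it is not a file from this repository): it drives the SwiGLU operator on the CPU backend over two contiguous [n, d] f32 buffers, following the Args / scheme / launch pattern of the tests above. The `operators::common_cpu` and `operators::swiglu` paths are assumed from the crate-internal imports shown in those tests and may differ in the published API.

// Illustrative sketch only. Assumes `common_cpu` and `swiglu` are public modules
// of the `operators` crate, as suggested by the crate-internal `use` statements above.
use digit_layout::types::F32;
use operators::{
    common_cpu::{Cpu, ThisThread},
    swiglu::{common_cpu::Operator, Args},
    Operator as _, TensorLayout,
};

fn main() {
    let (n, d) = (4, 8);
    let mut gate = vec![1.0f32; n * d];
    let up = vec![0.5f32; n * d];

    // Row-major [n, d] layouts for both operands.
    let layout = TensorLayout::new_contiguous(F32, &[n, d]);
    let args = Args {
        gate_layout: layout.clone(),
        gate_base: gate.as_mut_ptr().cast(),
        up_layout: layout,
        up_base: up.as_ptr().cast(),
    };

    let mut op = Operator::new(&Cpu);
    op.scheme(&args, 0).unwrap(); // validate the layouts; no workspace is needed
    op.launch(&args, &mut [], &ThisThread).unwrap();
    // Each element of `gate` now holds gate * sigmoid(gate) * up, computed in place.
    println!("gate[0] = {}", gate[0]);
}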