├── .cargo └── config.toml ├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .idea ├── .gitignore ├── modules.xml ├── vcs.xml └── wg-math.iml ├── Cargo.toml ├── LICENSE-APACHE.txt ├── LICENSE-MIT.txt ├── README.md ├── assets └── gguf │ └── dummy.gguf └── crates ├── wgcore-derive ├── Cargo.toml └── src │ └── lib.rs ├── wgcore ├── CHANGELOG.md ├── Cargo.toml ├── README.md ├── examples │ ├── buffer_readback.rs │ ├── compose.rs │ ├── compose_dependency.wgsl │ ├── compose_kernel.wgsl │ ├── encase.rs │ ├── encase.wgsl │ ├── hot_reloading.rs │ ├── hot_reloading.wgsl │ ├── overwrite.rs │ ├── overwritten_dependency.wgsl │ ├── timestamp_queries.rs │ └── timestamp_queries.wgsl └── src │ ├── composer.rs │ ├── gpu.rs │ ├── hot_reloading.rs │ ├── kernel.rs │ ├── lib.rs │ ├── shader.rs │ ├── shapes.rs │ ├── tensor.rs │ ├── timestamps.rs │ └── utils.rs ├── wgebra ├── CHANGELOG.md ├── Cargo.toml ├── README.md └── src │ ├── geometry │ ├── cholesky.rs │ ├── cholesky.wgsl │ ├── eig2.rs │ ├── eig2.wgsl │ ├── eig3.rs │ ├── eig3.wgsl │ ├── eig4.rs │ ├── eig4.wgsl │ ├── inv.rs │ ├── inv.wgsl │ ├── lu.rs │ ├── lu.wgsl │ ├── mod.rs │ ├── qr2.rs │ ├── qr2.wgsl │ ├── qr3.rs │ ├── qr3.wgsl │ ├── qr4.rs │ ├── qr4.wgsl │ ├── quat.rs │ ├── quat.wgsl │ ├── rot2.rs │ ├── rot2.wgsl │ ├── sim2.rs │ ├── sim2.wgsl │ ├── sim3.rs │ ├── sim3.wgsl │ ├── svd2.rs │ ├── svd2.wgsl │ ├── svd3.rs │ └── svd3.wgsl │ ├── lib.rs │ ├── linalg │ ├── gemm.rs │ ├── gemm.wgsl │ ├── gemv.rs │ ├── gemv.wgsl │ ├── mod.rs │ ├── op_assign.rs │ ├── op_assign.wgsl │ ├── reduce.rs │ ├── reduce.wgsl │ ├── shape.rs │ └── shape.wgsl │ └── utils │ ├── min_max.rs │ ├── min_max.wgsl │ ├── mod.rs │ ├── trig.rs │ └── trig.wgsl ├── wgparry ├── CHANGELOG.md ├── README.md ├── crates │ ├── wgparry2d │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src │ └── wgparry3d │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src └── src │ ├── ball.rs │ ├── ball.wgsl │ ├── capsule.rs │ ├── capsule.wgsl │ ├── cone.rs │ ├── cone.wgsl │ ├── contact.rs │ 
├── contact.wgsl │ ├── cuboid.rs │ ├── cuboid.wgsl │ ├── cylinder.rs │ ├── cylinder.wgsl │ ├── lib.rs │ ├── projection.rs │ ├── projection.wgsl │ ├── ray.rs │ ├── ray.wgsl │ ├── segment.rs │ ├── segment.wgsl │ ├── shape.rs │ ├── shape.wgsl │ ├── shape_fake_cone.wgsl │ ├── shape_fake_cylinder.wgsl │ ├── triangle.rs │ └── triangle.wgsl └── wgrapier ├── CHANGELOG.md ├── README.md ├── crates ├── wgrapier2d │ ├── Cargo.toml │ ├── README.md │ └── src └── wgrapier3d │ ├── Cargo.toml │ ├── README.md │ └── src ├── examples ├── gravity.rs └── gravity.wgsl └── src ├── dynamics ├── body.rs ├── body.wgsl ├── integrate.rs ├── integrate.wgsl └── mod.rs └── lib.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.wasm32-unknown-unknown] 2 | runner = "wasm-server-runner" -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | RUSTFLAGS: --deny warnings 12 | RUSTDOCFLAGS: --deny warnings 13 | 14 | jobs: 15 | # Run clippy lints. 
16 | clippy: 17 | name: Clippy 18 | runs-on: ubuntu-latest 19 | timeout-minutes: 30 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Rust toolchain 25 | uses: dtolnay/rust-toolchain@stable 26 | with: 27 | components: clippy 28 | 29 | - name: Install dependencies 30 | run: sudo apt-get update; sudo apt-get install --no-install-recommends libasound2-dev libudev-dev libwayland-dev libxkbcommon-dev 31 | 32 | - name: Populate target directory from cache 33 | uses: Leafwing-Studios/cargo-cache@v2 34 | with: 35 | sweep-cache: true 36 | 37 | - name: Run clippy lints 38 | run: cargo clippy --locked --workspace --all-targets --all-features -- --deny warnings 39 | 40 | # Check formatting. 41 | format: 42 | name: Format 43 | runs-on: ubuntu-latest 44 | timeout-minutes: 30 45 | steps: 46 | - name: Checkout repository 47 | uses: actions/checkout@v4 48 | 49 | - name: Install Rust toolchain 50 | uses: dtolnay/rust-toolchain@stable 51 | with: 52 | components: rustfmt 53 | 54 | - name: Run cargo fmt 55 | run: cargo fmt --all -- --check 56 | 57 | # Check documentation. 
58 | doc: 59 | name: Docs 60 | runs-on: ubuntu-latest 61 | timeout-minutes: 30 62 | steps: 63 | - name: Checkout repository 64 | uses: actions/checkout@v4 65 | 66 | - name: Install Rust toolchain 67 | uses: dtolnay/rust-toolchain@stable 68 | 69 | - name: Install dependencies 70 | run: sudo apt-get update; sudo apt-get install --no-install-recommends libasound2-dev libudev-dev libwayland-dev libxkbcommon-dev 71 | 72 | - name: Populate target directory from cache 73 | uses: Leafwing-Studios/cargo-cache@v2 74 | with: 75 | sweep-cache: true 76 | 77 | - name: Check documentation 78 | run: cargo doc --locked --workspace --all-features --document-private-items --no-deps -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | .lock 17 | dist 18 | assets 19 | 20 | *_bg.wasm 21 | website 22 | 23 | # JetBrain IDEs 24 | # 25 | # 26 | # 27 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 28 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 29 | # and can be added to the global gitignore or merged into this file. For a more nuclear 30 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
31 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 32 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 33 | 34 | # User-specific stuff 35 | .idea/**/workspace.xml 36 | .idea/**/tasks.xml 37 | .idea/**/usage.statistics.xml 38 | .idea/**/dictionaries 39 | .idea/**/shelf 40 | 41 | # AWS User-specific 42 | .idea/**/aws.xml 43 | 44 | # Generated files 45 | .idea/**/contentModel.xml 46 | 47 | # Sensitive or high-churn files 48 | .idea/**/dataSources/ 49 | .idea/**/dataSources.ids 50 | .idea/**/dataSources.local.xml 51 | .idea/**/sqlDataSources.xml 52 | .idea/**/dynamic.xml 53 | .idea/**/uiDesigner.xml 54 | .idea/**/dbnavigator.xml 55 | 56 | # Gradle 57 | .idea/**/gradle.xml 58 | .idea/**/libraries 59 | 60 | # Gradle and Maven with auto-import 61 | # When using Gradle or Maven with auto-import, you should exclude module files, 62 | # since they will be recreated, and may cause churn. Uncomment if using 63 | # auto-import. 
64 | # .idea/artifacts 65 | # .idea/compiler.xml 66 | # .idea/jarRepositories.xml 67 | # .idea/modules.xml 68 | # .idea/*.iml 69 | # .idea/modules 70 | # *.iml 71 | # *.ipr 72 | 73 | # CMake 74 | cmake-build-*/ 75 | 76 | # Mongo Explorer plugin 77 | .idea/**/mongoSettings.xml 78 | 79 | # File-based project format 80 | *.iws 81 | 82 | # IntelliJ 83 | out/ 84 | 85 | # mpeltonen/sbt-idea plugin 86 | .idea_modules/ 87 | 88 | # JIRA plugin 89 | atlassian-ide-plugin.xml 90 | 91 | # Cursive Clojure plugin 92 | .idea/replstate.xml 93 | 94 | # SonarLint plugin 95 | .idea/sonarlint/ 96 | 97 | # Crashlytics plugin (for Android Studio and IntelliJ) 98 | com_crashlytics_export_strings.xml 99 | crashlytics.properties 100 | crashlytics-build.properties 101 | fabric.properties 102 | 103 | # Editor-based Rest Client 104 | .idea/httpRequests 105 | 106 | # Android studio 3.1+ serialized cache file 107 | .idea/caches/build_file_checksums.ser 108 | 109 | .DS_store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/wg-math.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 
13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "crates/wgcore", "crates/wgebra", 4 | "crates/wgparry/crates/wgparry2d", "crates/wgparry/crates/wgparry3d", 5 | "crates/wgrapier/crates/wgrapier2d", "crates/wgrapier/crates/wgrapier3d" 6 | ] 7 | resolver = "2" 8 | 9 | [workspace.dependencies] 10 | nalgebra = { version = "0.33.1", features = ["convert-bytemuck"] } 11 | parry2d = { version = "0.18", features = ["bytemuck", "encase"] } 12 | parry3d = { version = "0.18", features = ["bytemuck", "encase"] } 13 | wgpu = { version = "24", features = ["naga-ir"] } 14 | bytemuck = { version = "1", features = ["derive", "extern_crate_std"] } 15 | anyhow = "1" 16 | async-channel = "2" 17 | naga_oil = "0.17" 18 | thiserror = "1" 19 | 20 | encase = { version = "0.10.0", features = ["nalgebra"] } 21 | 22 | [workspace.lints] 23 | rust.unexpected_cfgs = { level = "warn", check-cfg = [ 24 | 'cfg(feature, values("dim2", "dim3"))' 25 | ] } 26 | 27 | [profile.release] 28 | opt-level = 'z' 29 | 30 | [patch.crates-io] 31 | parry3d = { git = "https://github.com/dimforge/parry", branch = "encase" } 32 | parry2d = { git = "https://github.com/dimforge/parry", branch = "encase" } 33 | encase = { git = "https://github.com/sebcrozet/encase", branch = "nalgebra-points" } 34 | -------------------------------------------------------------------------------- /LICENSE-MIT.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wgmath − GPU scientific computing on every platform 2 | 3 |

4 | crates.io 5 |

6 |

7 | 8 | 9 | 10 |

11 | 12 | ----- 13 | 14 | **wgmath** is a set of [Rust](https://www.rust-lang.org/) libraries exposing 15 | re-usable [WebGPU](https://www.w3.org/TR/WGSL/) shaders for scientific computing including: 16 | 17 | - The [**wgcore** crate](https://github.com/dimforge/wgmath/tree/main/crates/wgcore), a centerpiece of the **wgmath** 18 | ecosystem, exposes a set of proc-macros to facilitate sharing and composing shaders across Rust libraries. 19 | - Linear algebra with the [**wgebra** crate](https://github.com/dimforge/wgmath/tree/main/crates/wgebra). 20 | - AI (Large Language Models) with the [**wgml** crate](https://github.com/dimforge/wgml/tree/main). 21 | - Collision-detection with the 22 | [**wgparry2d** and **wgparry3d**](https://github.com/dimforge/wgmath/tree/main/crates/wgparry) crates (still very 23 | WIP). 24 | - Rigid-body physics with the 25 | [**wgrapier2d** and **wgrapier3d**](https://github.com/dimforge/wgmath/tree/main/crates/wgrapier) crates (still very 26 | WIP). 27 | 28 | By targeting WebGPU, these libraries run on most GPUs, including on mobile and on the web. **wgmath** aims to promote open and 29 | cross-platform GPU computing for scientific applications, a field currently strongly dominated by proprietary 30 | solutions (like CUDA). 31 | 32 | ⚠️ All these libraries are still under heavy development and might be lacking some important features. Contributions 33 | are welcome! 
34 | 35 | ---- 36 | 37 | **See the readme of each individual crate (in the `crates` directory) for additional details.** 38 | 39 | ---- 40 | -------------------------------------------------------------------------------- /assets/gguf/dummy.gguf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dimforge/wgmath/95538845080ef8680cf9906f1949623e30be3495/assets/gguf/dummy.gguf -------------------------------------------------------------------------------- /crates/wgcore-derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgcore-derive" 3 | authors = ["Sébastien Crozet "] 4 | description = "Proc-macro for composable WGSL shaders." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [lib] 12 | name = "wgcore_derive" 13 | path = "src/lib.rs" 14 | proc-macro = true 15 | 16 | [dependencies] 17 | syn = "2.0.77" 18 | quote = "1.0.37" 19 | proc-macro2 = "1.0.86" 20 | darling = "0.20.10" -------------------------------------------------------------------------------- /crates/wgcore/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Unreleased 2 | 3 | ### Added 4 | 5 | - Add `Shader::shader_module` to generate and return the shader’s `ShaderModule`. 6 | 7 | ### Changed 8 | 9 | - Rename `Shader::set_absolute_path` to `Shader::set_wgsl_path`. 10 | - Rename `Shader::absolute_path` to `Shader::wgsl_path`. 11 | - Workgroup memory automatic zeroing is now **disabled** by default due to its significant 12 | performance impact. 13 | 14 | ## v0.2.2 15 | 16 | ### Fixed 17 | 18 | - Fix crash in `HotReloadState` when targeting wasm. 19 | 20 | ## v0.2.1 21 | 22 | ### Fixed 23 | 24 | - Fix build when targeting wasm. 
25 | 26 | ## v0.2.0 27 | 28 | ### Added 29 | 30 | - Add support for hot-reloading, see [#1](https://github.com/dimforge/wgmath/pull/1). This includes breaking changes to 31 | the `Shader` trait. 32 | - Add support for shader overwriting, see [#1](https://github.com/dimforge/wgmath/pull/1). 33 | -------------------------------------------------------------------------------- /crates/wgcore/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgcore" 3 | authors = ["Sébastien Crozet "] 4 | description = "Utilities and abstractions for composable WGSL shaders." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | readme = "README.md" 8 | version = "0.2.2" 9 | edition = "2021" 10 | license = "MIT OR Apache-2.0" 11 | 12 | [features] 13 | derive = ["wgcore-derive"] 14 | 15 | [dependencies] 16 | nalgebra = { workspace = true } 17 | wgpu = { workspace = true, features = ["wgsl"] } 18 | bytemuck = { workspace = true } 19 | anyhow = { workspace = true } 20 | async-channel = { workspace = true } 21 | naga_oil = { workspace = true } 22 | encase = { workspace = true } 23 | 24 | wgcore-derive = { version = "0.2", path = "../wgcore-derive", optional = true } 25 | 26 | dashmap = "5" 27 | notify = { version = "7" } # , optional = true } 28 | 29 | # For test_shader_compilation 30 | paste = "1" 31 | 32 | [dev-dependencies] 33 | nalgebra = { version = "0.33", features = ["rand"] } 34 | futures-test = "0.3" 35 | serial_test = "3" 36 | approx = "0.5" 37 | async-std = { version = "1", features = ["attributes"] } 38 | -------------------------------------------------------------------------------- /crates/wgcore/README.md: -------------------------------------------------------------------------------- 1 | # wgcore − utilities and abstractions for composable WGSL shaders 2 | 3 | **wgcore** provides simple abstractions over shaders and gpu resources based on `wgpu`. 
It aims to: 4 | 5 | - Expose thin wrappers that are as unsurprising as possible. We do not rely on complex compiler 6 | magic like bitcode generation in frameworks like `cust` and `rust-gpu`. 7 | - Provide a proc-macro (through the `wgcore-derive` crate) that simplifies shader reuse across 8 | crates with very low boilerplate. 9 | - No ownership of the gpu device and queue. While `wgcore` does expose a utility struct 10 | [`gpu::GpuInstance`] to initialize the compute unit, it is completely optional. All the features 11 | of `wgcore` remain usable if the gpu device and queue are already owned by, e.g., a game engine. 12 | 13 | ## Shader composition 14 | 15 | #### Basic usage 16 | 17 | Currently, **wgcore** relies on [naga-oil](https://github.com/bevyengine/naga_oil) for shader 18 | composition. Though we are keeping an eye on the ongoing [WESL](https://github.com/wgsl-tooling-wg) 19 | effort for an alternative to `naga-oil`. 20 | 21 | The main value added over `naga-oil` is the `wgcore::Shader` trait and proc-macro. This lets you 22 | declare composable shaders very concisely. For example, if the WGSL sources are at the path 23 | `./shader_source.wgsl` relative to the `.rs` source file, all that’s needed for it to be composable 24 | is to `derive` the `Shader` trait: 25 | 26 | ```rust ignore 27 | #[derive(Shader)] 28 | #[shader(src = "shader_source.wgsl")] 29 | struct MyShader1; 30 | ``` 31 | 32 | Then it becomes immediately importable (assuming the `.wgsl` source itself contains a 33 | `#define_import_path` statement) from another shader with the `shader(derive)` attribute: 34 | 35 | ```rust ignore 36 | #[derive(Shader)] 37 | #[shader( 38 | derive(MyShader1), // This shader depends on the `MyShader1` shader. 39 | src = "kernel.wgsl", // Shader source code, will be embedded in the exe with `include_str!`. 
40 | )] 41 | struct MyShader2; 42 | ``` 43 | 44 | Finally, if we want to use these shaders from another one which contains a kernel entry-point, 45 | it is possible to declare `ComputePipeline` fields on the struct deriving `Shader`: 46 | 47 | ```rust ignore 48 | #[derive(Shader)] 49 | #[shader( 50 | derive(MyShader1, MyShader2), 51 | src = "kernel.wgsl", 52 | )] 53 | struct MyKernel { 54 | // Note that the field name has to match the kernel entry-point’s name. 55 | main: ComputePipeline, 56 | } 57 | ``` 58 | 59 | This will automatically generate the necessary boilerplate for creating the compute pipeline 60 | from a device: `MyKernel::from_device(device)`. 61 | 62 | #### Some customization 63 | 64 | The `Shader` proc-macro allows some customizations of the imported shaders: 65 | 66 | - `src_fn = "function_name"`: allows the input sources to be modified by an arbitrary string 67 | transformation function before being compiled as a naga module. This enables any custom 68 | preprocessor to run before naga-oil. 69 | - `shader_defs = "function_name"`: allows the declaration of shader definitions that can then be 70 | used in the shader in, e.g., `#ifdef MY_SHADER_DEF` statements (as well as `#if` statements and 71 | anything supported by `naga-oil`’s shader definitions feature). 72 | - `composable = false`: specifies that the shader does not export any reusable symbols to other 73 | shaders. In particular, this **must** be specified if the shader sources don’t contain any 74 | `#define_import_path` statement. 
75 | 76 | ```rust ignore 77 | #[derive(Shader)] 78 | #[shader( 79 | derive(MyShader1, MyShader2), 80 | src = "kernel.wgsl", 81 | src_fn = "substitute_aliases", 82 | shader_defs = "dim_shader_defs", 83 | composable = false 84 | )] 85 | struct MyKernel { 86 | main: ComputePipeline, 87 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/buffer_readback.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::DVector; 2 | use wgcore::gpu::GpuInstance; 3 | use wgcore::tensor::GpuVector; 4 | use wgpu::BufferUsages; 5 | 6 | #[async_std::main] 7 | async fn main() -> anyhow::Result<()> { 8 | // Initialize the gpu device and its queue. 9 | // 10 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 11 | // You are free to initialize them independently if more control is needed, or reuse the ones 12 | // that were already created/owned by e.g., a game engine. 13 | let gpu = GpuInstance::new().await?; 14 | 15 | // Create the buffers. 16 | const LEN: u32 = 10; 17 | let buffer_data = DVector::from_fn(LEN as usize, |i, _| i as u32); 18 | let buffer = GpuVector::init( 19 | gpu.device(), 20 | &buffer_data, 21 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 22 | ); 23 | let staging = GpuVector::uninit( 24 | gpu.device(), 25 | LEN, 26 | BufferUsages::COPY_DST | BufferUsages::MAP_READ, 27 | ); 28 | 29 | // Queue the operation. 30 | // Encode & submit the operation to the gpu. 31 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 32 | // Copy the result to the staging buffer. 
33 | staging.copy_from(&mut encoder, &buffer); 34 | gpu.queue().submit(Some(encoder.finish())); 35 | 36 | let read = DVector::from(staging.read(gpu.device()).await?); 37 | assert_eq!(buffer_data, read); 38 | println!("Buffer copy & read succeeded!"); 39 | println!("Original: {:?}", buffer_data); 40 | println!("Readback: {:?}", read); 41 | 42 | Ok(()) 43 | } 44 | -------------------------------------------------------------------------------- /crates/wgcore/examples/compose.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "derive"))] 2 | std::compile_error!( 3 | r#" 4 | ############################################################### 5 | ## The `derive` feature must be enabled to run this example. ## 6 | ############################################################### 7 | "# 8 | ); 9 | 10 | use nalgebra::DVector; 11 | use std::fmt::Debug; 12 | use wgcore::gpu::GpuInstance; 13 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 14 | use wgcore::tensor::GpuVector; 15 | use wgcore::Shader; 16 | use wgpu::{BufferUsages, ComputePipeline}; 17 | 18 | // Declare our shader module that contains our composable functions. 19 | // Note that we don’t build any compute pipeline from this wgsl file. 20 | #[derive(Shader)] 21 | #[shader( 22 | src = "compose_dependency.wgsl" // Shader source code, will be embedded in the exe with `include_str!` 23 | )] 24 | struct Composable; 25 | 26 | #[derive(Shader)] 27 | #[shader( 28 | derive(Composable), // This shader depends on the `Composable` shader. 29 | src = "compose_kernel.wgsl", // Shader source code, will be embedded in the exe with `include_str!`. 30 | composable = false // This shader doesn’t export any symbols reusable from other wgsl shaders. 31 | )] 32 | struct WgKernel { 33 | // This ComputePipeline field indicates that the Shader macro needs to generate the boilerplate 34 | // for loading the compute pipeline in `WgKernel::from_device`. 
35 | main: ComputePipeline, 36 | } 37 | 38 | #[derive(Copy, Clone, PartialEq, Default, bytemuck::Pod, bytemuck::Zeroable)] 39 | #[repr(C)] 40 | pub struct MyStruct { 41 | value: f32, 42 | } 43 | 44 | // Optional: makes the debug output more concise. 45 | impl Debug for MyStruct { 46 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 47 | write!(f, "{}", self.value) 48 | } 49 | } 50 | 51 | #[async_std::main] 52 | async fn main() -> anyhow::Result<()> { 53 | // Initialize the gpu device and its queue. 54 | // 55 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 56 | // You are free to initialize them independently if more control is needed, or reuse the ones 57 | // that were already created/owned by e.g., a game engine. 58 | let gpu = GpuInstance::new().await?; 59 | 60 | // Load and compile our kernel. The `from_device` function was generated by the `Shader` derive. 61 | // Note that its dependency to `Composable` is automatically resolved by the `Shader` derive 62 | // too. 63 | let kernel = WgKernel::from_device(gpu.device())?; 64 | println!("######################################"); 65 | println!("###### Composed shader sources: ######"); 66 | println!("######################################"); 67 | println!("{}", WgKernel::flat_wgsl()?); 68 | 69 | // Now, let’s actually run our kernel. 70 | let result = run_kernel(&gpu, &kernel).await; 71 | println!("Result: {:?}", result); 72 | 73 | Ok(()) 74 | } 75 | 76 | async fn run_kernel(gpu: &GpuInstance, kernel: &WgKernel) -> Vec { 77 | // Create the buffers. 
78 | const LEN: u32 = 10; 79 | let a_data = DVector::from_fn(LEN as usize, |i, _| MyStruct { value: i as f32 }); 80 | let b_data = DVector::from_fn(LEN as usize, |i, _| MyStruct { 81 | value: i as f32 * 10.0, 82 | }); 83 | let a_buf = GpuVector::init( 84 | gpu.device(), 85 | &a_data, 86 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 87 | ); 88 | let b_buf = GpuVector::init(gpu.device(), &b_data, BufferUsages::STORAGE); 89 | let staging = GpuVector::uninit( 90 | gpu.device(), 91 | LEN, 92 | BufferUsages::COPY_DST | BufferUsages::MAP_READ, 93 | ); 94 | 95 | // Encode & submit the operation to the gpu. 96 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 97 | let mut pass = encoder.compute_pass("test", None); 98 | KernelDispatch::new(gpu.device(), &mut pass, &kernel.main) 99 | .bind0([a_buf.buffer(), b_buf.buffer()]) 100 | .dispatch(LEN.div_ceil(64)); 101 | drop(pass); 102 | 103 | // Copy the result to the staging buffer. 104 | staging.copy_from(&mut encoder, &a_buf); 105 | gpu.queue().submit(Some(encoder.finish())); 106 | 107 | // Read the result back from the gpu. 
108 | staging 109 | .read(gpu.device()) 110 | .await 111 | .expect("Failed to read result from the GPU.") 112 | } 113 | -------------------------------------------------------------------------------- /crates/wgcore/examples/compose_dependency.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path composable::module 2 | 3 | struct MyStruct { 4 | value: f32, 5 | } 6 | 7 | fn shared_function(a: MyStruct, b: MyStruct) -> MyStruct { 8 | return MyStruct(a.value + b.value); 9 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/compose_kernel.wgsl: -------------------------------------------------------------------------------- 1 | #import composable::module as Dependency 2 | 3 | @group(0) @binding(0) 4 | var a: array; 5 | @group(0) @binding(1) 6 | var b: array; 7 | 8 | @compute @workgroup_size(64, 1, 1) 9 | fn main(@builtin(global_invocation_id) invocation_id: vec3) { 10 | let i = invocation_id.x; 11 | if i < arrayLength(&a) { 12 | a[i] = Dependency::shared_function(a[i], b[i]); 13 | } 14 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/encase.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "derive"))] 2 | std::compile_error!( 3 | r#" 4 | ############################################################### 5 | ## The `derive` feature must be enabled to run this example. 
## 6 | ############################################################### 7 | "# 8 | ); 9 | 10 | use nalgebra::Vector4; 11 | use wgcore::gpu::GpuInstance; 12 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 13 | use wgcore::tensor::GpuVector; 14 | use wgcore::Shader; 15 | use wgpu::{BufferUsages, ComputePipeline}; 16 | 17 | #[derive(Copy, Clone, PartialEq, Debug, Default, bytemuck::Pod, bytemuck::Zeroable)] 18 | #[repr(C)] 19 | pub struct BytemuckStruct { 20 | value: f32, 21 | } 22 | 23 | #[derive(Copy, Clone, PartialEq, Debug, Default, encase::ShaderType)] 24 | #[repr(C)] 25 | pub struct EncaseStruct { 26 | value: f32, 27 | // This implies some internal padding, so we can’t rely on bytemuck. 28 | // Encase will handle that properly. 29 | value2: Vector4, 30 | } 31 | 32 | #[derive(Shader)] 33 | #[shader(src = "encase.wgsl", composable = false)] 34 | struct ShaderEncase { 35 | main: ComputePipeline, 36 | } 37 | 38 | #[async_std::main] 39 | async fn main() -> anyhow::Result<()> { 40 | // Initialize the gpu device and its queue. 41 | // 42 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 43 | // You are free to initialize them independently if more control is needed, or reuse the ones 44 | // that were already created/owned by e.g., a game engine. 45 | let gpu = GpuInstance::new().await?; 46 | 47 | // Load and compile our kernel. The `from_device` function was generated by the `Shader` derive. 48 | // Note that its dependency to `Composable` is automatically resolved by the `Shader` derive 49 | // too. 50 | let kernel = ShaderEncase::from_device(gpu.device())?; 51 | 52 | // Create the buffers. 
53 | const LEN: u32 = 1000; 54 | let a_data = (0..LEN) 55 | .map(|x| EncaseStruct { 56 | value: x as f32, 57 | value2: Vector4::repeat(x as f32 * 10.0), 58 | }) 59 | .collect::>(); 60 | let b_data = (0..LEN) 61 | .map(|x| BytemuckStruct { value: x as f32 }) 62 | .collect::>(); 63 | // Call `encase` instead of `init` because `EncaseStruct` isn’t `Pod`. 64 | // The `encase` function has a bit of overhead so bytemuck should be preferred whenever possible. 65 | let a_buf = GpuVector::encase(gpu.device(), &a_data, BufferUsages::STORAGE); 66 | let b_buf = GpuVector::init(gpu.device(), &b_data, BufferUsages::STORAGE); 67 | 68 | // Encode & submit the operation to the gpu. 69 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 70 | let mut pass = encoder.compute_pass("test", None); 71 | KernelDispatch::new(gpu.device(), &mut pass, &kernel.main) 72 | .bind0([a_buf.buffer(), b_buf.buffer()]) 73 | .dispatch(LEN.div_ceil(64)); 74 | drop(pass); 75 | gpu.queue().submit(Some(encoder.finish())); 76 | 77 | Ok(()) 78 | } 79 | -------------------------------------------------------------------------------- /crates/wgcore/examples/encase.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var a: array; 3 | @group(0) @binding(1) 4 | var b: array; 5 | 6 | struct BytemuckStruct { 7 | value: f32, 8 | } 9 | 10 | struct EncaseStruct { 11 | value: f32, 12 | value2: vec4 13 | } 14 | 15 | @compute @workgroup_size(64, 1, 1) 16 | fn main(@builtin(global_invocation_id) invocation_id: vec3) { 17 | let i = invocation_id.x; 18 | if i < arrayLength(&a) { 19 | a[i].value += b[i].value; 20 | a[i].value2 += vec4(b[i].value); 21 | } 22 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/hot_reloading.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "derive"))] 2 | std::compile_error!( 3 | r#" 4 | 
############################################################### 5 | ## The `derive` feature must be enabled to run this example. ## 6 | ############################################################### 7 | "# 8 | ); 9 | 10 | use wgcore::gpu::GpuInstance; 11 | use wgcore::hot_reloading::HotReloadState; 12 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 13 | use wgcore::tensor::GpuScalar; 14 | use wgcore::Shader; 15 | use wgpu::{BufferUsages, ComputePipeline}; 16 | 17 | #[derive(Shader)] 18 | #[shader(src = "hot_reloading.wgsl", composable = false)] 19 | struct ShaderHotReloading { 20 | main: ComputePipeline, 21 | } 22 | 23 | #[async_std::main] 24 | async fn main() -> anyhow::Result<()> { 25 | // Initialize the gpu device and its queue. 26 | // 27 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 28 | // You are free to initialize them independently if more control is needed, or reuse the ones 29 | // that were already created/owned by e.g., a game engine. 30 | let gpu = GpuInstance::new().await?; 31 | 32 | // Load and compile our kernel. The `from_device` function was generated by the `Shader` derive. 33 | // Note that its dependency to `Composable` is automatically resolved by the `Shader` derive 34 | // too. 35 | let mut kernel = ShaderHotReloading::from_device(gpu.device())?; 36 | 37 | // Create the buffers. 38 | let buffer = GpuScalar::init( 39 | gpu.device(), 40 | 0u32, 41 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 42 | ); 43 | let staging = GpuScalar::init( 44 | gpu.device(), 45 | 0u32, 46 | BufferUsages::COPY_DST | BufferUsages::MAP_READ, 47 | ); 48 | 49 | // Init hot-reloading. 50 | let mut hot_reload = HotReloadState::new()?; 51 | ShaderHotReloading::watch_sources(&mut hot_reload)?; 52 | 53 | // Queue the operation. 
54 | println!("#############################"); 55 | println!("Edit the file `hot_reloading.wgsl`.\nThe updated result will be printed below whenever a change is detected."); 56 | println!("#############################"); 57 | 58 | for loop_id in 0.. { 59 | // Detect & apply changes. 60 | hot_reload.update_changes(); 61 | match kernel.reload_if_changed(gpu.device(), &hot_reload) { 62 | Ok(changed) => { 63 | if changed || loop_id == 0 { 64 | // We detected a change (or this is the first loop). 65 | // Encode & submit the operation to the gpu. 66 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 67 | // Run our kernel. 68 | let mut pass = encoder.compute_pass("test", None); 69 | KernelDispatch::new(gpu.device(), &mut pass, &kernel.main) 70 | .bind0([buffer.buffer()]) 71 | .dispatch(1); 72 | drop(pass); 73 | 74 | // Copy the result to the staging buffer. 75 | staging.copy_from(&mut encoder, &buffer); 76 | gpu.queue().submit(Some(encoder.finish())); 77 | 78 | let result_read = staging.read(gpu.device()).await.unwrap(); 79 | println!("Current result value: {}", result_read[0]); 80 | } 81 | } 82 | Err(e) => { 83 | // Hot-reloading failed, likely due to a syntax error in the shader. 84 | println!("Hot reloading error: {:?}", e); 85 | } 86 | } 87 | } 88 | 89 | Ok(()) 90 | } 91 | -------------------------------------------------------------------------------- /crates/wgcore/examples/hot_reloading.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var a: u32; 3 | 4 | @compute @workgroup_size(1, 1, 1) 5 | fn main(@builtin(global_invocation_id) invocation_id: vec3) { 6 | a = 1u; // Change this value and save the file while running the `hot_reloading` example. 
7 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/overwrite.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "derive"))] 2 | std::compile_error!( 3 | r#" 4 | ############################################################### 5 | ## The `derive` feature must be enabled to run this example. ## 6 | ############################################################### 7 | "# 8 | ); 9 | 10 | use nalgebra::DVector; 11 | use std::fmt::Debug; 12 | use wgcore::gpu::GpuInstance; 13 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 14 | use wgcore::tensor::GpuVector; 15 | use wgcore::Shader; 16 | use wgpu::{BufferUsages, ComputePipeline}; 17 | 18 | // Declare our shader module that contains our composable functions. 19 | // Note that we don’t build any compute pipeline from this wgsl file. 20 | #[derive(Shader)] 21 | #[shader( 22 | src = "compose_dependency.wgsl" // Shader source code, will be embedded in the exe with `include_str!` 23 | )] 24 | struct Composable; 25 | 26 | #[derive(Shader)] 27 | #[shader( 28 | derive(Composable), // This shader depends on the `Composable` shader. 29 | src = "compose_kernel.wgsl", // Shader source code, will be embedded in the exe with `include_str!`. 30 | composable = false // This shader doesn’t export any symbols reusable from other wgsl shaders. 31 | )] 32 | struct WgKernel { 33 | // This ComputePipeline field indicates that the Shader macro needs to generate the boilerplate 34 | // for loading the compute pipeline in `WgKernel::from_device`. 35 | main: ComputePipeline, 36 | } 37 | 38 | #[derive(Copy, Clone, PartialEq, Default, bytemuck::Pod, bytemuck::Zeroable)] 39 | #[repr(C)] 40 | pub struct MyStruct { 41 | value: f32, 42 | } 43 | 44 | // Optional: makes the debug output more concise. 
45 | impl Debug for MyStruct { 46 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 47 | write!(f, "{}", self.value) 48 | } 49 | } 50 | 51 | #[async_std::main] 52 | async fn main() -> anyhow::Result<()> { 53 | // Initialize the gpu device and its queue. 54 | // 55 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 56 | // You are free to initialize them independently if more control is needed, or reuse the ones 57 | // that were already created/owned by e.g., a game engine. 58 | let gpu = GpuInstance::new().await?; 59 | 60 | // Load and compile our kernel. The `from_device` function was generated by the `Shader` derive. 61 | // Note that its dependency to `Composable` is automatically resolved by the `Shader` derive 62 | // too. 63 | let kernel_before_overwrite = WgKernel::from_device(gpu.device())?; 64 | // Run the original shader. 65 | let result_before_overwrite = run_kernel(&gpu, &kernel_before_overwrite).await; 66 | 67 | // Overwrite the sources of the dependency module. 68 | // Since we are running this with `cargo run --example`, the path is relative to the 69 | // `target/debug` folder. 70 | Composable::set_wgsl_path("../../crates/wgcore/examples/overwritten_dependency.wgsl"); 71 | // Recompile our kernel. 72 | let kernel_after_overwrite = WgKernel::from_device(gpu.device())?; 73 | // Run the modified kernel. 74 | let result_after_overwrite = run_kernel(&gpu, &kernel_after_overwrite).await; 75 | 76 | println!("Result before overwrite: {:?}", result_before_overwrite); 77 | println!("Result after overwrite: {:?}", result_after_overwrite); 78 | 79 | Ok(()) 80 | } 81 | 82 | async fn run_kernel(gpu: &GpuInstance, kernel: &WgKernel) -> Vec { 83 | // Create the buffers. 
84 | const LEN: u32 = 10; 85 | let a_data = DVector::from_fn(LEN as usize, |i, _| MyStruct { value: i as f32 }); 86 | let b_data = DVector::from_fn(LEN as usize, |i, _| MyStruct { 87 | value: i as f32 * 10.0, 88 | }); 89 | let a_buf = GpuVector::init( 90 | gpu.device(), 91 | &a_data, 92 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 93 | ); 94 | let b_buf = GpuVector::init(gpu.device(), &b_data, BufferUsages::STORAGE); 95 | let staging = GpuVector::uninit( 96 | gpu.device(), 97 | LEN, 98 | BufferUsages::COPY_DST | BufferUsages::MAP_READ, 99 | ); 100 | 101 | // Encode & submit the operation to the gpu. 102 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 103 | let mut pass = encoder.compute_pass("test", None); 104 | KernelDispatch::new(gpu.device(), &mut pass, &kernel.main) 105 | .bind0([a_buf.buffer(), b_buf.buffer()]) 106 | .dispatch(LEN.div_ceil(64)); 107 | drop(pass); 108 | // Copy the result to the staging buffer. 109 | staging.copy_from(&mut encoder, &a_buf); 110 | gpu.queue().submit(Some(encoder.finish())); 111 | 112 | // Read the result back from the gpu. 113 | staging 114 | .read(gpu.device()) 115 | .await 116 | .expect("Failed to read result from the GPU.") 117 | } 118 | -------------------------------------------------------------------------------- /crates/wgcore/examples/overwritten_dependency.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path composable::module 2 | 3 | struct MyStruct { 4 | value: f32, 5 | } 6 | 7 | fn shared_function(a: MyStruct, b: MyStruct) -> MyStruct { 8 | // Same as compose_dependency.wgsl but with a subtraction instead of an addition. 
9 | return MyStruct(a.value - b.value); 10 | } -------------------------------------------------------------------------------- /crates/wgcore/examples/timestamp_queries.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "derive"))] 2 | std::compile_error!( 3 | r#" 4 | ############################################################### 5 | ## The `derive` feature must be enabled to run this example. ## 6 | ############################################################### 7 | "# 8 | ); 9 | 10 | use wgcore::gpu::GpuInstance; 11 | use wgcore::hot_reloading::HotReloadState; 12 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 13 | use wgcore::tensor::GpuVector; 14 | use wgcore::timestamps::GpuTimestamps; 15 | use wgcore::Shader; 16 | use wgpu::{BufferUsages, ComputePipeline}; 17 | 18 | #[derive(Shader)] 19 | #[shader(src = "timestamp_queries.wgsl", composable = false)] 20 | struct ShaderTimestampQueries { 21 | main: ComputePipeline, 22 | } 23 | 24 | #[async_std::main] 25 | async fn main() -> anyhow::Result<()> { 26 | // Initialize the gpu device and its queue. 27 | // 28 | // Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources. 29 | // You are free to initialize them independently if more control is needed, or reuse the ones 30 | // that were already created/owned by e.g., a game engine. 31 | let gpu = GpuInstance::new().await?; 32 | 33 | // Load and compile our kernel. The `from_device` function was generated by the `Shader` derive. 34 | // Note that its dependency to `Composable` is automatically resolved by the `Shader` derive 35 | // too. 36 | let mut kernel = ShaderTimestampQueries::from_device(gpu.device())?; 37 | 38 | // Create the buffers. 39 | const LEN: u32 = 2_000_000; 40 | let buffer = GpuVector::init( 41 | gpu.device(), 42 | vec![0u32; LEN as usize], 43 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 44 | ); 45 | 46 | // Init hot-reloading. 
47 | // We are setting up hot-reloading so that we can change some elements in the shader 48 | // (like the iteration count) and see how that affects performance live. 49 | let mut hot_reload = HotReloadState::new()?; 50 | ShaderTimestampQueries::watch_sources(&mut hot_reload)?; 51 | 52 | // Init timestamp queries. 53 | // To measure the time of one kernel, we need two timestamps (one for when it starts and one for 54 | // when it stopped). 55 | let mut timestamps = GpuTimestamps::new(gpu.device(), 2); 56 | 57 | // Queue the operation. 58 | println!("#############################"); 59 | println!("Edit the file `timestamp_queries.wgsl` (for example by multiplying or dividing NUM_ITERS by 10).\nThe updated runtime will be printed below whenever a change is detected."); 60 | println!("#############################"); 61 | 62 | for loop_id in 0.. { 63 | // Detect & apply changes. 64 | hot_reload.update_changes(); 65 | match kernel.reload_if_changed(gpu.device(), &hot_reload) { 66 | Ok(changed) => { 67 | if changed || loop_id == 0 { 68 | // Clear the timestamps to reuse in the next loop. 69 | timestamps.clear(); 70 | // We detected a change (or this is the first loop, where we always run once to print a baseline). 71 | 72 | // Encode & submit the operation to the gpu. 73 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 74 | // Declare a compute pass with timestamps enabled. 75 | let mut pass = 76 | encoder.compute_pass("timestamp_queries_test", Some(&mut timestamps)); 77 | // Dispatch our kernel. 78 | KernelDispatch::new(gpu.device(), &mut pass, &kernel.main) 79 | .bind0([buffer.buffer()]) 80 | .dispatch(LEN.div_ceil(64)); 81 | drop(pass); 82 | // Resolve the timestamp queries. 83 | timestamps.resolve(&mut encoder); 84 | gpu.queue().submit(Some(encoder.finish())); 85 | 86 | // Read and print the kernel’s runtime. 
87 | let timestamps_read = timestamps.wait_for_results_ms(gpu.device(), gpu.queue()); 88 | println!( 89 | "Current run time: {}ms", 90 | timestamps_read[1] - timestamps_read[0] 91 | ); 92 | } 93 | } 94 | Err(e) => { 95 | // Hot-reloading failed, likely due to a syntax error in the shader. 96 | println!("Hot reloading error: {:?}", e); 97 | } 98 | } 99 | } 100 | 101 | Ok(()) 102 | } 103 | -------------------------------------------------------------------------------- /crates/wgcore/examples/timestamp_queries.wgsl: -------------------------------------------------------------------------------- 1 | @group(0) @binding(0) 2 | var a: array; 3 | 4 | @compute @workgroup_size(64, 1, 1) 5 | fn main(@builtin(global_invocation_id) invocation_id: vec3) { 6 | let i = invocation_id.x; 7 | if i < arrayLength(&a) { 8 | const NUM_ITERS: u32 = 10000u; 9 | for (var k = 0u; k < NUM_ITERS; k++) { 10 | a[i] = collatz_iterations(a[i] * 7919); 11 | } 12 | } 13 | } 14 | 15 | // This is taken from the wgpu "hello_compute" example: 16 | // https://github.com/gfx-rs/wgpu/blob/6f5014f0a3441bcbc3eb4223aee454b95904b087/examples/src/hello_compute/shader.wgsl 17 | // (Apache 2 / MIT license) 18 | // 19 | // The Collatz Conjecture states that for any integer n: 20 | // If n is even, n = n/2 21 | // If n is odd, n = 3n+1 22 | // And repeat this process for each new n, you will always eventually reach 1. 23 | // Though the conjecture has not been proven, no counterexample has ever been found. 24 | // This function returns how many times this recurrence needs to be applied to reach 1. 25 | fn collatz_iterations(n_base: u32) -> u32{ 26 | var n: u32 = n_base; 27 | var i: u32 = 0u; 28 | loop { 29 | if (n <= 1u) { 30 | break; 31 | } 32 | if (n % 2u == 0u) { 33 | n = n / 2u; 34 | } 35 | else { 36 | // Overflow? (i.e. 3*n + 1 > 0xffffffffu?) 
37 | if (n >= 1431655765u) { // 0x55555555u 38 | return 4294967295u; // 0xffffffffu 39 | } 40 | 41 | n = 3u * n + 1u; 42 | } 43 | i = i + 1u; 44 | } 45 | return i; 46 | } -------------------------------------------------------------------------------- /crates/wgcore/src/composer.rs: -------------------------------------------------------------------------------- 1 | //! Extensions over naga-oil’s Composer. 2 | 3 | use naga_oil::compose::preprocess::Preprocessor; 4 | use naga_oil::compose::{ 5 | ComposableModuleDefinition, ComposableModuleDescriptor, Composer, ComposerError, ErrSource, 6 | }; 7 | 8 | /// An extension trait for the naga-oil `Composer` to work around some of its limitations. 9 | pub trait ComposerExt { 10 | /// Adds a composable module to `self` only if it hasn’t been added yet. 11 | /// 12 | /// Currently, `naga-oil` behaves strangely (some symbols stop resolving) if the same module is 13 | /// added twice. This function checks if the module has already been added. If it was already 14 | /// added, then `self` is left unchanged and `Ok(None)` is returned. 15 | fn add_composable_module_once( 16 | &mut self, 17 | desc: ComposableModuleDescriptor<'_>, 18 | ) -> Result, ComposerError>; 19 | } 20 | 21 | impl ComposerExt for Composer { 22 | fn add_composable_module_once( 23 | &mut self, 24 | desc: ComposableModuleDescriptor<'_>, 25 | ) -> Result, ComposerError> { 26 | let prep = Preprocessor::default(); 27 | // TODO: not sure if allow_defines should be `true` or `false`. 28 | let meta = prep 29 | .get_preprocessor_metadata(desc.source, false) 30 | .map_err(|inner| ComposerError { 31 | inner, 32 | source: ErrSource::Constructing { 33 | path: desc.file_path.to_string(), 34 | source: desc.source.to_string(), 35 | offset: 0, 36 | }, 37 | })?; 38 | 39 | if let Some(name) = &meta.name { 40 | if self.contains_module(name) { 41 | // Module already exists, don’t insert it. 
42 | return Ok(None); 43 | } 44 | } 45 | 46 | self.add_composable_module(desc).map(Some) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /crates/wgcore/src/gpu.rs: -------------------------------------------------------------------------------- 1 | //! Utility struct to initialize a gpu device. 2 | 3 | use std::sync::Arc; 4 | use wgpu::{Adapter, Device, Instance, Queue}; 5 | 6 | /// Helper struct to initialize a device and its queue. 7 | pub struct GpuInstance { 8 | _instance: Instance, // TODO: do we have to keep this around? 9 | _adapter: Adapter, // TODO: do we have to keep this around? 10 | device: Arc, 11 | queue: Queue, 12 | } 13 | 14 | impl GpuInstance { 15 | /// Initializes a wgpu instance, requests a high-performance adapter, and creates the device (with timestamp queries enabled) and its queue. 16 | pub async fn new() -> anyhow::Result { 17 | let instance = wgpu::Instance::default(); 18 | let adapter = instance 19 | .request_adapter(&wgpu::RequestAdapterOptions { 20 | power_preference: wgpu::PowerPreference::HighPerformance, 21 | ..Default::default() 22 | }) 23 | .await 24 | .ok_or_else(|| anyhow::anyhow!("Failed to initialize gpu adapter."))?; 25 | let (device, queue) = adapter 26 | .request_device( 27 | &wgpu::DeviceDescriptor { 28 | label: None, 29 | required_features: wgpu::Features::TIMESTAMP_QUERY, 30 | required_limits: wgpu::Limits { 31 | max_buffer_size: 1_000_000_000, 32 | max_storage_buffer_binding_size: 1_000_000_000, 33 | ..Default::default() 34 | }, 35 | memory_hints: Default::default(), 36 | }, 37 | None, 38 | ) 39 | .await 40 | .map_err(|e| anyhow::anyhow!("{:?}", e))?; 41 | 42 | Ok(Self { 43 | _instance: instance, 44 | _adapter: adapter, 45 | device: Arc::new(device), 46 | queue, 47 | }) 48 | } 49 | 50 | /// The `wgpu` device. 51 | pub fn device(&self) -> &Device { 52 | &self.device 53 | } 54 | 55 | /// The shared `wgpu` device. 56 | pub fn device_arc(&self) -> Arc { 57 | self.device.clone() 58 | } 59 | 60 | /// The `wgpu` queue. 
61 | pub fn queue(&self) -> &Queue { 62 | &self.queue 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /crates/wgcore/src/hot_reloading.rs: -------------------------------------------------------------------------------- 1 | //! Utility to detect changed files for shader hot-reloading. 2 | 3 | use async_channel::Receiver; 4 | use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; 5 | use std::collections::HashMap; 6 | use std::path::{Path, PathBuf}; 7 | 8 | #[cfg(doc)] 9 | use crate::Shader; 10 | 11 | /// State for tracking file changes. 12 | pub struct HotReloadState { 13 | #[cfg(not(target_family = "wasm"))] 14 | watcher: RecommendedWatcher, 15 | rcv: Receiver>, 16 | file_changed: HashMap, 17 | } 18 | 19 | impl HotReloadState { 20 | /// Initializes the file-tracking context. 21 | /// 22 | /// To register a shader for change-tracking call [`Shader::watch_sources`] once with the state 23 | /// returned by this function. 24 | /// To register a file for change-tracking, call [`HotReloadState::watch_file`]. 25 | pub fn new() -> notify::Result { 26 | let (snd, rcv) = async_channel::unbounded(); 27 | Ok(Self { 28 | #[cfg(not(target_family = "wasm"))] 29 | watcher: notify::recommended_watcher(move |msg| { 30 | // TODO: does hot-reloading make sense on wasm anyway? 31 | let _ = snd.send_blocking(msg); 32 | })?, 33 | rcv, 34 | file_changed: Default::default(), 35 | }) 36 | } 37 | 38 | /// Saves in `self` the set of watched files that changed since the last time this function 39 | /// was called. 40 | /// 41 | /// Once this call completes, the [`Self::file_changed`] method can be used to check if a 42 | /// particular file (assuming it was added to the watch list with [`Self::watch_file`]) has 43 | /// changed since the last time [`Self::update_changes`] was called. 
44 | pub fn update_changes(&mut self) { 45 | for changed in self.file_changed.values_mut() { 46 | *changed = false; 47 | } 48 | 49 | while let Ok(event) = self.rcv.try_recv() { 50 | if let Ok(event) = event { 51 | if event.need_rescan() || matches!(event.kind, EventKind::Modify(_)) { 52 | for path in event.paths { 53 | self.file_changed.insert(path, true); 54 | } 55 | } 56 | } 57 | } 58 | } 59 | 60 | /// Registers a file for change-tracking. Does nothing if the file is already registered. 61 | pub fn watch_file(&mut self, path: &Path) -> notify::Result<()> { 62 | #[cfg(not(target_family = "wasm"))] 63 | if !self.file_changed.contains_key(path) { 64 | self.watcher.watch(path, RecursiveMode::NonRecursive)?; 65 | // NOTE: this won’t insert if the watch failed. 66 | self.file_changed.insert(path.to_path_buf(), false); 67 | } 68 | 69 | Ok(()) 70 | } 71 | 72 | /// Checks if a change of the specified file was detected by the last call to [`Self::update_changes`]. Returns `false` for files that are not tracked. 73 | pub fn file_changed(&self, path: &Path) -> bool { 74 | self.file_changed.get(path).copied().unwrap_or_default() 75 | } 76 | 77 | /// Gets the list of files watched for hot-reloading, as well as their last known change status. 78 | pub fn watched_files(&self) -> impl Iterator { 79 | self.file_changed 80 | .iter() 81 | .map(|(path, changed)| (path, *changed)) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /crates/wgcore/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | // #![warn(missing_docs)] 3 | #![allow(clippy::result_large_err)] 4 | 5 | pub mod composer; 6 | pub mod gpu; 7 | pub mod hot_reloading; 8 | pub mod kernel; 9 | pub mod shader; 10 | pub mod shapes; 11 | pub mod tensor; 12 | pub mod timestamps; 13 | pub mod utils; 14 | 15 | pub use bytemuck::Pod; 16 | 17 | pub use shader::{Shader, ShaderRegistry}; 18 | #[cfg(feature = "derive")] 19 | pub use wgcore_derive::*; 20 | 21 | /// Third-party modules re-exports. 
22 | pub mod re_exports { 23 | pub use bytemuck; 24 | pub use encase; 25 | pub use naga_oil::{ 26 | self, 27 | compose::{ComposableModuleDescriptor, Composer, ComposerError, NagaModuleDescriptor}, 28 | }; 29 | pub use notify; 30 | pub use paste; 31 | pub use wgpu::{self, Device}; 32 | } 33 | 34 | /// A macro that declares a test that will check compilation of the shader identified by the given 35 | /// struct implementing `Shader`. Optional extra arguments: the identifier of the `wgcore` crate (defaults to `wgcore`), then the shader definitions (defaults to `Default::default()`). 36 | #[macro_export] 37 | macro_rules! test_shader_compilation { 38 | ($ty: ident) => { 39 | wgcore::test_shader_compilation!($ty, wgcore); 40 | }; 41 | 42 | ($ty: ident, $wgcore: ident) => { 43 | $wgcore::test_shader_compilation!($ty, $wgcore, Default::default()); // Forward the caller-provided crate identifier instead of hard-coding `wgcore`. 44 | }; 45 | 46 | ($ty: ident, $wgcore: ident, $shader_defs: expr) => { 47 | $wgcore::re_exports::paste::paste! { 48 | #[cfg(test)] 49 | mod [] { 50 | use super::$ty; 51 | use naga_oil::compose::NagaModuleDescriptor; 52 | use $wgcore::Shader; 53 | use $wgcore::gpu::GpuInstance; 54 | use $wgcore::utils; 55 | 56 | #[futures_test::test] 57 | #[serial_test::serial] 58 | async fn shader_compiles() { 59 | // Add a dumb entry point for testing. 60 | let src = format!( 61 | "{} 62 | @compute @workgroup_size(1, 1, 1) 63 | fn macro_generated_test(@builtin(global_invocation_id) invocation_id: vec3) {{}} 64 | ", 65 | $ty::src() 66 | ); 67 | let gpu = GpuInstance::new().await.unwrap(); 68 | let module = $ty::composer() 69 | .unwrap() 70 | .make_naga_module(NagaModuleDescriptor { 71 | source: &src, 72 | file_path: $ty::FILE_PATH, 73 | shader_defs: $shader_defs, 74 | ..Default::default() 75 | }) 76 | .unwrap(); 77 | let _ = utils::load_module(gpu.device(), "macro_generated_test", module); 78 | } 79 | } 80 | } 81 | }; 82 | } 83 | -------------------------------------------------------------------------------- /crates/wgcore/src/shapes.rs: -------------------------------------------------------------------------------- 1 | //! Tensor shape definition. 
2 | 3 | use crate::tensor::MatrixOrdering; 4 | use dashmap::DashMap; 5 | use std::sync::{Arc, Mutex}; 6 | use wgpu::util::{BufferInitDescriptor, DeviceExt}; 7 | use wgpu::{Buffer, BufferUsages, Device, Queue}; 8 | 9 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, bytemuck::Pod, bytemuck::Zeroable)] 10 | #[repr(C)] 11 | /// The shape of a matrix view over a GPU tensor. 12 | pub struct ViewShape { 13 | /// The tensor view’s number of rows, columns, and matrices. 14 | pub size: [u32; 3], 15 | /// The view’s column stride (number of elements between two columns). 16 | pub stride: u32, 17 | /// The view’s matrix stride (number of elements between two matrices in the tensor). 18 | pub stride_mat: u32, 19 | /// Index of the first element of the view on the underlying buffer. 20 | pub offset: u32, 21 | } 22 | 23 | impl ViewShape { 24 | /// Converts the shape `self` for a buffer `&[f32]` to a buffer `&[vec4f]`. 25 | pub fn f32_to_vec4(self) -> Self { 26 | let size = if Ordering::is_column_major() { 27 | [self.size[0] / 4, self.size[1], self.size[2]] 28 | } else { 29 | [self.size[0], self.size[1] / 4, self.size[2]] 30 | }; 31 | 32 | Self { 33 | size, 34 | stride: self.stride / 4, 35 | stride_mat: self.stride_mat / 4, 36 | offset: self.offset / 4, 37 | } 38 | } 39 | } 40 | 41 | /// A map between a `ViewShape` and an uniform storage `Buffer` containing its value on the gpu. 42 | /// 43 | /// Ideally, we should use push-constants for view shapes. Unfortunately, push-constants is an 44 | /// optional extension, so we have to emulate them with uniforms for maximum portability. 45 | #[derive(Default)] 46 | pub struct ViewShapeBuffers { 47 | // TODO: once we switch to wgpu 14, we can store a `Buffer` directly instead of 48 | // `Arc` (they will be clonable), and we can also store the `Device` 49 | // here to simplify `self.get` and the kernel dispatch apis. 
50 | buffers: DashMap>, 51 | tmp_buffers: DashMap>, 52 | recycled: Mutex>>, 53 | } 54 | 55 | impl ViewShapeBuffers { 56 | /// Creates an empty map. 57 | pub fn new() -> Self { 58 | Self { 59 | buffers: DashMap::new(), 60 | tmp_buffers: DashMap::new(), 61 | recycled: Mutex::new(vec![]), 62 | } 63 | } 64 | /// Removes every temporary buffer from the map and keeps them in the recycling pool for reuse by `put_tmp`. 65 | pub fn clear_tmp(&self) { 66 | let mut recycled = self.recycled.lock().unwrap(); 67 | self.tmp_buffers.retain(|_, buffer| { 68 | recycled.push(buffer.clone()); 69 | false 70 | }) 71 | } 72 | /// Registers a temporary uniform buffer holding `shape`, unless a buffer for `shape` already exists. A recycled buffer is overwritten when one is available; otherwise a new one is allocated. 73 | pub fn put_tmp(&self, device: &Device, queue: &Queue, shape: ViewShape) { 74 | if self.contains(shape) { 75 | return; 76 | } 77 | 78 | let mut recycled = self.recycled.lock().unwrap(); 79 | let buffer = if let Some(buffer) = recycled.pop() { 80 | queue.write_buffer(&buffer, 0, bytemuck::cast_slice(&[shape])); 81 | buffer 82 | } else { 83 | drop(recycled); 84 | Self::make_buffer( 85 | device, 86 | shape, 87 | BufferUsages::UNIFORM | BufferUsages::COPY_DST, 88 | ) 89 | }; 90 | 91 | self.tmp_buffers.insert(shape, buffer); 92 | } 93 | /// Allocates a gpu buffer with the given `usage`, initialized with the bytes of `shape`. 94 | fn make_buffer(device: &Device, shape: ViewShape, usage: BufferUsages) -> Arc { 95 | Arc::new(device.create_buffer_init(&BufferInitDescriptor { 96 | label: None, 97 | contents: bytemuck::cast_slice(&[shape]), 98 | usage, 99 | })) 100 | } 101 | /// Returns `true` if a permanent or temporary buffer is already registered for `shape`. 102 | pub fn contains(&self, shape: ViewShape) -> bool { 103 | self.buffers.contains_key(&shape) || self.tmp_buffers.contains_key(&shape) 104 | } 105 | 106 | /// Gets or inserts the gpu uniform storage `Buffer` containing the value of `shape`. 
107 | pub fn get(&self, device: &Device, shape: ViewShape) -> Arc { 108 | if let Some(buffer) = self.tmp_buffers.get(&shape) { 109 | return buffer.value().clone(); 110 | } 111 | 112 | self.buffers 113 | .entry(shape) 114 | .or_insert_with(|| Self::make_buffer(device, shape, BufferUsages::UNIFORM)) 115 | .clone() 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /crates/wgcore/src/utils.rs: -------------------------------------------------------------------------------- 1 | //! Utilities for creating a ComputePipeline from source or from a naga module. 2 | 3 | use wgpu::naga::Module; 4 | use wgpu::{ 5 | ComputePipeline, ComputePipelineDescriptor, Device, PipelineCompilationOptions, 6 | ShaderRuntimeChecks, 7 | }; 8 | 9 | /// Creates a compute pipeline from the shader sources `content` and the name of its `entry_point`. 10 | pub fn load_shader(device: &Device, entry_point: &str, content: &str) -> ComputePipeline { 11 | let shader = unsafe { 12 | device.create_shader_module_trusted( 13 | wgpu::ShaderModuleDescriptor { 14 | label: None, 15 | source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(content)), 16 | }, 17 | ShaderRuntimeChecks::unchecked(), 18 | ) 19 | }; 20 | device.create_compute_pipeline(&ComputePipelineDescriptor { 21 | label: Some(entry_point), 22 | layout: None, 23 | module: &shader, 24 | entry_point: Some(entry_point), 25 | compilation_options: Default::default(), 26 | cache: None, 27 | }) 28 | } 29 | 30 | /// Creates a compute pipeline from the shader `module` and the name of its `entry_point`. 
31 | pub fn load_module(device: &Device, entry_point: &str, module: Module) -> ComputePipeline { 32 | let shader = unsafe { 33 | device.create_shader_module_trusted( 34 | wgpu::ShaderModuleDescriptor { 35 | label: None, 36 | source: wgpu::ShaderSource::Naga(std::borrow::Cow::Owned(module)), 37 | }, 38 | ShaderRuntimeChecks::unchecked(), 39 | ) 40 | }; 41 | device.create_compute_pipeline(&ComputePipelineDescriptor { 42 | label: Some(entry_point), 43 | layout: None, 44 | module: &shader, 45 | entry_point: Some(entry_point), 46 | compilation_options: PipelineCompilationOptions { 47 | zero_initialize_workgroup_memory: false, 48 | ..Default::default() 49 | }, 50 | cache: None, 51 | }) 52 | } 53 | 54 | /// Converts a naga module to its WGSL string representation. Panics if the module fails validation or cannot be written as WGSL. 55 | pub fn naga_module_to_wgsl(module: &Module) -> String { 56 | use wgpu::naga; 57 | 58 | let mut validator = 59 | naga::valid::Validator::new(naga::valid::ValidationFlags::all(), Default::default()); 60 | let info = validator.validate(module).unwrap(); 61 | 62 | naga::back::wgsl::write_string(module, &info, naga::back::wgsl::WriterFlags::EXPLICIT_TYPES) 63 | .unwrap() 64 | } 65 | -------------------------------------------------------------------------------- /crates/wgebra/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Unreleased 2 | 3 | ### Modified 4 | 5 | - Replaced all lazy kernel invocation queueing by dispatches directly. 6 | 7 | ### Added 8 | 9 | - Added implementation of matrix decompositions (LU, QR, Cholesky, Eigendecomposition) 10 | for `mat2x2`, `mat3x3`, and `mat4x4` on the GPU. 11 | 12 | ## v0.2.0 13 | 14 | ### Modified 15 | 16 | - Update to `wgcore` v0.2.0. 
17 | -------------------------------------------------------------------------------- /crates/wgebra/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgebra" 3 | authors = ["Sébastien Crozet "] 4 | description = "Composable WGSL shaders for linear algebra." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [dependencies] 12 | wgpu = { workspace = true } 13 | bytemuck = { workspace = true } 14 | naga_oil = { workspace = true } 15 | nalgebra = { workspace = true } 16 | encase = { workspace = true, features = ["nalgebra"] } 17 | 18 | wgcore = { version = "0.2", path = "../wgcore", features = ["derive"] } 19 | 20 | [dev-dependencies] 21 | nalgebra = { version = "0.33", features = ["rand"] } 22 | futures-test = "0.3" 23 | serial_test = "3" 24 | approx = "0.5" -------------------------------------------------------------------------------- /crates/wgebra/README.md: -------------------------------------------------------------------------------- 1 | # wgebra − composable WGSL shaders for linear algebra 2 | 3 |

4 | crates.io 5 |

6 | 7 | ---- 8 | 9 | The goal of **wgebra** is essentially to be "[**nalgebra**](https://nalgebra.rs) on the gpu". It aims (but it isn’t there 10 | yet) to expose linear algebra operations (including BLAS-like and LAPACK-like operations) as well as geometric types 11 | (quaternions, similarities, etc.) as composable WGSL shaders and kernels. 12 | 13 | ## Reusable shader functions 14 | 15 | **wgebra** exposes various reusable WGSL shaders to be composed with your own. This exposes various functionalities 16 | that are not covered by the mathematical functions included in the WebGPU standard: 17 | 18 | - Low-dimensional matrix decompositions: 19 | - Inverse, Cholesky, LU, QR, Symmetric Eigendecomposition, for 2x2, 3x3, and 4x4 matrices. 20 | - Singular Value Decomposition for 2x2 and 3x3 matrices. 21 | - Geometric transformations: 22 | - Quaternions (for 3D rotations). 23 | - Compact 2D rotation representation. 24 | - 2D and 3D similarities (rotations + translation + uniform scale). 25 | 26 | ## Kernels 27 | 28 | **wgebra** exposes kernels for running common linear-algebra operations on vectors, matrices, and 3-tensors. In 29 | particular: 30 | 31 | - The product of two matrices: `Gemm` (including both `m1 * m2` and `transpose(m1) * m2`). Supports 3-tensors. 32 | - The product of a matrix and a vector: `Gemv` (including both `m * v` and `transpose(m) * v`). Supports 3-tensors. 33 | - Componentwise binary operations between two vectors (addition, subtraction, product, division, assignment). 34 | - Reduction on a single vector (sum, product, min, max, squared norm). 35 | 36 | ## Using the library 37 | 38 | To access the features of **wgebra** in your own Rust project, add the dependency to your `Cargo.toml`: 39 | 40 | ```toml 41 | [dependencies] 42 | wgebra = "0.2.0" # NOTE: set the version number to the latest. 43 | ``` 44 | 45 | Then shaders can be composed with your code, and kernels can be dispatched.
For additional information, refer to 46 | the [user-guide](https://wgmath.rs/docs/). 47 | 48 | ## Running tests 49 | 50 | Tests can be run the same way as usual: 51 | 52 | ```sh 53 | cargo test 54 | ``` 55 | 56 | ## Benchmarks 57 | 58 | There is currently no benchmark in the `wgebra` repository itself. However, some benchmarks of the matrix multiplication 59 | kernels can be run from [wgml-bench](https://github.com/dimforge/wgml/tree/main/crates/wgml-bench). -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/cholesky.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path IMPORT_PATH 2 | 3 | // NOTE: this depends on a preprocessor substituting the following macros: 4 | // - DIM: the matrix dimension (e.g. `2` for 2x2 matrices). 5 | // - MAT: the matrix type (e.g. `mat2x2` for a 2x2 matrix). 6 | // - IMPORT_PATH: the `define_import_path` path. 7 | 8 | /// Computes the Cholesky decomposition of the given matrix. 9 | /// 10 | /// The decomposition’s result is stored in the lower-triangular part of the output matrix. 11 | /// 12 | /// For additional information on the Cholesky decomposition, see the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack/#cholesky-decomposition) 13 | /// documentation. 14 | fn cholesky(x: MAT) -> MAT { 15 | var m = x; 16 | 17 | // PERF: consider unrolling the loops? 
18 | for (var j = 0u; j < DIM; j++) { 19 | for (var k = 0u; k < j; k++) { 20 | let factor = -m[k][j]; 21 | 22 | for (var l = j; l < DIM; l++) { 23 | m[j][l] += factor * m[k][l]; 24 | } 25 | } 26 | 27 | let denom = sqrt(m[j][j]); 28 | m[j][j] = denom; 29 | 30 | for (var l = j + 1u; l < DIM; l++) { 31 | m[j][l] /= denom; 32 | } 33 | } 34 | 35 | return m; 36 | } -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/eig2.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::{Matrix2, Vector2}; 2 | use wgcore::Shader; 3 | #[cfg(test)] 4 | use { 5 | crate::utils::WgTrig, 6 | naga_oil::compose::NagaModuleDescriptor, 7 | wgpu::{ComputePipeline, Device}, 8 | }; 9 | 10 | #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] 11 | #[repr(C)] 12 | /// GPU representation of a symmetric 2x2 matrix eigendecomposition. 13 | /// 14 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack/#eigendecomposition-of-a-hermitian-matrix) 15 | /// documentation for details on the eigendecomposition 16 | pub struct GpuSymmetricEigen2 { 17 | /// Eigenvectors of the matrix. 18 | pub eigenvectors: Matrix2, 19 | /// Eigenvalues of the matrix. 20 | pub eigenvalues: Vector2, 21 | } 22 | 23 | #[derive(Shader)] 24 | #[shader(src = "eig2.wgsl")] 25 | /// Shader for computing the eigendecomposition of symmetric 2x2 matrices. 
26 | pub struct WgSymmetricEigen2; 27 | 28 | impl WgSymmetricEigen2 { 29 | #[doc(hidden)] 30 | #[cfg(test)] 31 | pub fn tests(device: &Device) -> ComputePipeline { 32 | let test_kernel = r#" 33 | @group(0) @binding(0) 34 | var in: array>; 35 | @group(0) @binding(1) 36 | var out: array; 37 | 38 | @compute @workgroup_size(1, 1, 1) 39 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 40 | let i = invocation_id.x; 41 | out[i] = symmetric_eigen(in[i]); 42 | } 43 | "#; 44 | 45 | let src = format!("{}\n{}", Self::src(), test_kernel); 46 | let module = WgTrig::composer() 47 | .unwrap() 48 | .make_naga_module(NagaModuleDescriptor { 49 | source: &src, 50 | file_path: Self::FILE_PATH, 51 | ..Default::default() 52 | }) 53 | .unwrap(); 54 | wgcore::utils::load_module(device, "test", module) 55 | } 56 | } 57 | 58 | #[cfg(test)] 59 | mod test { 60 | use super::GpuSymmetricEigen2; 61 | use approx::assert_relative_eq; 62 | use nalgebra::{DVector, Matrix2}; 63 | use wgcore::gpu::GpuInstance; 64 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 65 | use wgcore::tensor::GpuVector; 66 | use wgpu::BufferUsages; 67 | 68 | #[futures_test::test] 69 | #[serial_test::serial] 70 | async fn gpu_eig2() { 71 | let gpu = GpuInstance::new().await.unwrap(); 72 | let svd = super::WgSymmetricEigen2::tests(gpu.device()); 73 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 74 | 75 | const LEN: usize = 345; 76 | let mut matrices: DVector> = DVector::new_random(LEN); 77 | // matrices[0] = Matrix2::zeros(); // The zero matrix can cause issues on some platforms (like macos) with unspecified atan2 on (0, 0). 78 | // matrices[1] = Matrix2::identity(); // The identity matrix can cause issues on some platforms. 79 | for mat in matrices.iter_mut() { 80 | *mat = mat.transpose() * *mat; // Make it symmetric. 
81 | } 82 | 83 | let inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 84 | let result: GpuVector = GpuVector::uninit( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 88 | ); 89 | let staging: GpuVector = GpuVector::uninit( 90 | gpu.device(), 91 | matrices.len() as u32, 92 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 93 | ); 94 | 95 | // Dispatch the test. 96 | let mut pass = encoder.compute_pass("test", None); 97 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 98 | .bind0([inputs.buffer(), result.buffer()]) 99 | .dispatch(matrices.len() as u32); 100 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 101 | 102 | staging.copy_from(&mut encoder, &result); 103 | gpu.queue().submit(Some(encoder.finish())); 104 | 105 | // Check the result is correct. 106 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 107 | 108 | for (m, eigen) in matrices.iter().zip(gpu_result.iter()) { 109 | let reconstructed = eigen.eigenvectors 110 | * Matrix2::from_diagonal(&eigen.eigenvalues) 111 | * eigen.eigenvectors.transpose(); 112 | assert_relative_eq!(*m, reconstructed, epsilon = 1.0e-4); 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/eig2.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::eig2 2 | 3 | // The eigendecomposition of a symmetric 2x2 matrix. 4 | /// 5 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack/#eigendecomposition-of-a-hermitian-matrix) 6 | /// documentation for details on the eigendecomposition. 7 | struct SymmetricEigen { 8 | /// Eigenvectors of the matrix. 9 | eigenvectors: mat2x2, 10 | /// Eigenvalues of the matrix. 11 | eigenvalues: vec2, 12 | }; 13 | 14 | // Computes the eigendecomposition of a symmetric 2x2 matrix. 
15 | fn symmetric_eigen(m: mat2x2) -> SymmetricEigen { 16 | let a = m[0].x; 17 | let c = m[0].y; 18 | let b = m[1].y; 19 | 20 | if c == 0.0 { 21 | return SymmetricEigen( 22 | mat2x2(vec2(1.0, 0.0), vec2(0.0, 1.0)), 23 | vec2(a, b) 24 | ); 25 | } 26 | 27 | let ab = a - b; 28 | let sigma = sqrt(4.0 * c * c + ab * ab); 29 | let eigenvalues = vec2( 30 | (a + b + sigma) / 2.0, 31 | (a + b - sigma) / 2.0 32 | ); 33 | let eigv1 = vec2((a - b + sigma) / (2.0 * c), 1.0); 34 | let eigv2 = vec2((a - b - sigma) / (2.0 * c), 1.0); 35 | 36 | let eigenvectors = mat2x2(eigv1 / length(eigv1), eigv2 / length(eigv2)); 37 | 38 | return SymmetricEigen(eigenvectors, eigenvalues); 39 | } 40 | 41 | fn eigenvalues(m: mat2x2) -> vec2 { 42 | let a = m[0].x; 43 | let c = m[0].y; 44 | let b = m[1].y; 45 | 46 | if c == 0.0 { 47 | return vec2(a, b); 48 | } 49 | 50 | let ab = a - b; 51 | let sigma = sqrt(4.0 * c * c + ab * ab); 52 | return vec2( 53 | (a + b + sigma) / 2.0, 54 | (a + b - sigma) / 2.0 55 | ); 56 | } 57 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/eig3.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::WgMinMax; 2 | use crate::{WgRot2, WgSymmetricEigen2}; 3 | use nalgebra::{Matrix3, Vector3}; 4 | use wgcore::{test_shader_compilation, Shader}; 5 | #[cfg(test)] 6 | use { 7 | naga_oil::compose::NagaModuleDescriptor, 8 | wgpu::{ComputePipeline, Device}, 9 | }; 10 | 11 | #[derive(Copy, Clone, Debug, encase::ShaderType)] 12 | #[repr(C)] 13 | /// GPU representation of a symmetric 3x3 matrix eigendecomposition. 14 | /// 15 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack/#eigendecomposition-of-a-hermitian-matrix) 16 | /// documentation for details on the eigendecomposition 17 | pub struct GpuSymmetricEigen3 { 18 | /// Eigenvectors of the matrix. 19 | pub eigenvectors: Matrix3, 20 | /// Eigenvalues of the matrix. 
21 | pub eigenvalues: Vector3, 22 | } 23 | 24 | #[derive(Shader)] 25 | #[shader(derive(WgMinMax, WgSymmetricEigen2, WgRot2), src = "eig3.wgsl")] 26 | /// Shader for computing the eigendecomposition of symmetric 3x3 matrices. 27 | pub struct WgSymmetricEigen3; 28 | 29 | test_shader_compilation!(WgSymmetricEigen3); 30 | 31 | impl WgSymmetricEigen3 { 32 | #[doc(hidden)] 33 | #[cfg(test)] 34 | pub fn tests(device: &Device) -> ComputePipeline { 35 | let test_kernel = r#" 36 | @group(0) @binding(0) 37 | var in: array>; 38 | @group(0) @binding(1) 39 | var out: array; 40 | 41 | @compute @workgroup_size(1, 1, 1) 42 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 43 | let i = invocation_id.x; 44 | out[i] = symmetric_eigen(in[i]); 45 | } 46 | "#; 47 | 48 | let src = format!("{}\n{}", Self::src(), test_kernel); 49 | let module = WgSymmetricEigen3::composer() 50 | .unwrap() 51 | .make_naga_module(NagaModuleDescriptor { 52 | source: &src, 53 | file_path: Self::FILE_PATH, 54 | ..Default::default() 55 | }) 56 | .unwrap(); 57 | wgcore::utils::load_module(device, "test", module) 58 | } 59 | } 60 | 61 | #[cfg(test)] 62 | mod test { 63 | use super::GpuSymmetricEigen3; 64 | use approx::{assert_relative_eq, relative_eq}; 65 | use nalgebra::{DVector, Matrix3}; 66 | use wgcore::gpu::GpuInstance; 67 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 68 | use wgcore::tensor::GpuVector; 69 | use wgpu::BufferUsages; 70 | 71 | #[futures_test::test] 72 | #[serial_test::serial] 73 | async fn gpu_eig3() { 74 | let gpu = GpuInstance::new().await.unwrap(); 75 | let svd = super::WgSymmetricEigen3::tests(gpu.device()); 76 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 77 | 78 | const LEN: usize = 345; 79 | let mut matrices: DVector> = DVector::new_random(LEN); 80 | for mat in matrices.iter_mut() { 81 | *mat = mat.transpose() * *mat; // Make it symmetric. 
82 | } 83 | 84 | let inputs = GpuVector::encase(gpu.device(), &matrices, BufferUsages::STORAGE); 85 | let result: GpuVector = GpuVector::uninit_encased( 86 | gpu.device(), 87 | matrices.len() as u32, 88 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 89 | ); 90 | let staging: GpuVector = GpuVector::uninit_encased( 91 | gpu.device(), 92 | matrices.len() as u32, 93 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 94 | ); 95 | 96 | // Dispatch the test. 97 | let mut pass = encoder.compute_pass("test", None); 98 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 99 | .bind0([inputs.buffer(), result.buffer()]) 100 | .dispatch(matrices.len() as u32); 101 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 102 | 103 | staging.copy_from_encased(&mut encoder, &result); 104 | gpu.queue().submit(Some(encoder.finish())); 105 | 106 | // Check the result is correct. 107 | let gpu_result = staging.read_encased(gpu.device()).await.unwrap(); 108 | let mut allowed_fails = 0; 109 | 110 | for (m, eigen) in matrices.iter().zip(gpu_result.iter()) { 111 | println!("eig: (gpu) {:?}", eigen); 112 | println!("eig (na): {:?}", m.symmetric_eigen()); 113 | 114 | let reconstructed = eigen.eigenvectors 115 | * Matrix3::from_diagonal(&eigen.eigenvalues) 116 | * eigen.eigenvectors.transpose(); 117 | println!("reconstructed: {:?}", m.symmetric_eigen().recompose()); 118 | 119 | // NOTE: we allow about 2% of the decompositions to fail, to account for occasionally 120 | // bad random matrices that will fail the test due to an unsuitable epsilon. 121 | // Ideally this percentage should be kept as low as possible, but likely not 122 | // removable entirely. 
123 | if allowed_fails == matrices.len() * 2 / 100 { 124 | assert_relative_eq!(*m, reconstructed, epsilon = 1.0e-4); 125 | } else if !relative_eq!(*m, reconstructed, epsilon = 1.0e-4) { 126 | allowed_fails += 1; 127 | } 128 | } 129 | 130 | println!("Num fails: {}/{}", allowed_fails, matrices.len()); 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/eig4.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::WgMinMax; 2 | use crate::{WgRot2, WgSymmetricEigen2}; 3 | use nalgebra::{Matrix4, Vector4}; 4 | use wgcore::{test_shader_compilation, Shader}; 5 | #[cfg(test)] 6 | use { 7 | naga_oil::compose::NagaModuleDescriptor, 8 | wgpu::{ComputePipeline, Device}, 9 | }; 10 | 11 | #[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] 12 | #[repr(C)] 13 | /// GPU representation of a symmetric 4x4 matrix eigendecomposition. 14 | /// 15 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack/#eigendecomposition-of-a-hermitian-matrix) 16 | /// documentation for details on the eigendecomposition 17 | pub struct GpuSymmetricEigen4 { 18 | /// Eigenvectors of the matrix. 19 | pub eigenvectors: Matrix4, 20 | /// Eigenvalues of the matrix. 21 | pub eigenvalues: Vector4, 22 | } 23 | 24 | #[derive(Shader)] 25 | #[shader(derive(WgMinMax, WgSymmetricEigen2, WgRot2), src = "eig4.wgsl")] 26 | /// Shader for computing the eigendecomposition of symmetric 4x4 matrices. 
27 | pub struct WgSymmetricEigen4; 28 | 29 | test_shader_compilation!(WgSymmetricEigen4); 30 | 31 | impl WgSymmetricEigen4 { 32 | #[doc(hidden)] 33 | #[cfg(test)] 34 | pub fn tests(device: &Device) -> ComputePipeline { 35 | let test_kernel = r#" 36 | @group(0) @binding(0) 37 | var in: array>; 38 | @group(0) @binding(1) 39 | var out: array; 40 | 41 | @compute @workgroup_size(1, 1, 1) 42 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 43 | let i = invocation_id.x; 44 | out[i] = symmetric_eigen(in[i]); 45 | } 46 | "#; 47 | 48 | let src = format!("{}\n{}", Self::src(), test_kernel); 49 | let module = WgSymmetricEigen4::composer() 50 | .unwrap() 51 | .make_naga_module(NagaModuleDescriptor { 52 | source: &src, 53 | file_path: Self::FILE_PATH, 54 | ..Default::default() 55 | }) 56 | .unwrap(); 57 | wgcore::utils::load_module(device, "test", module) 58 | } 59 | } 60 | 61 | #[cfg(test)] 62 | mod test { 63 | use super::GpuSymmetricEigen4; 64 | use approx::{assert_relative_eq, relative_eq}; 65 | use nalgebra::{DVector, Matrix4}; 66 | use wgcore::gpu::GpuInstance; 67 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 68 | use wgcore::tensor::GpuVector; 69 | use wgpu::BufferUsages; 70 | 71 | #[futures_test::test] 72 | #[serial_test::serial] 73 | async fn gpu_eig4() { 74 | let gpu = GpuInstance::new().await.unwrap(); 75 | let svd = super::WgSymmetricEigen4::tests(gpu.device()); 76 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 77 | 78 | const LEN: usize = 345; 79 | let mut matrices: DVector> = DVector::new_random(LEN); 80 | for mat in matrices.iter_mut() { 81 | *mat = mat.transpose() * *mat; // Make it symmetric. 
82 | } 83 | 84 | let inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 85 | let result: GpuVector = GpuVector::uninit( 86 | gpu.device(), 87 | matrices.len() as u32, 88 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 89 | ); 90 | let staging: GpuVector = GpuVector::uninit( 91 | gpu.device(), 92 | matrices.len() as u32, 93 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 94 | ); 95 | 96 | // Dispatch the test. 97 | let mut pass = encoder.compute_pass("test", None); 98 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 99 | .bind0([inputs.buffer(), result.buffer()]) 100 | .dispatch(matrices.len() as u32); 101 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 102 | 103 | staging.copy_from(&mut encoder, &result); 104 | gpu.queue().submit(Some(encoder.finish())); 105 | 106 | // Check the result is correct. 107 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 108 | let mut allowed_fails = 0; 109 | 110 | for (m, eigen) in matrices.iter().zip(gpu_result.iter()) { 111 | println!("eig: (gpu) {:?}", eigen); 112 | println!("eig (na): {:?}", m.symmetric_eigen()); 113 | 114 | let reconstructed = eigen.eigenvectors 115 | * Matrix4::from_diagonal(&eigen.eigenvalues) 116 | * eigen.eigenvectors.transpose(); 117 | println!("reconstructed: {:?}", m.symmetric_eigen().recompose()); 118 | 119 | // NOTE: we allow about 2% of the decompositions to fail, to account for occasionally 120 | // bad random matrices that will fail the test due to an unsuitable epsilon. 121 | // Ideally this percentage should be kept as low as possible, but likely not 122 | // removable entirely. 
123 | if allowed_fails == matrices.len() * 2 / 100 { 124 | assert_relative_eq!(*m, reconstructed, epsilon = 1.0e-4); 125 | } else if !relative_eq!(*m, reconstructed, epsilon = 1.0e-4) { 126 | allowed_fails += 1; 127 | } 128 | } 129 | 130 | println!("Num fails: {}/{}", allowed_fails, matrices.len()); 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/inv.rs: -------------------------------------------------------------------------------- 1 | use wgcore::Shader; 2 | 3 | #[derive(Shader)] 4 | #[shader(src = "inv.wgsl")] 5 | /// Shader exposing small matrix inverses. 6 | pub struct WgInv; 7 | 8 | wgcore::test_shader_compilation!(WgInv); 9 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/inv.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::inv 2 | 3 | // These inverse functions were copied from https://github.com/gfx-rs/wgpu/tree/trunk/naga/src/back/wgsl/polyfill/inverse (MIT/Apache 2 license) 4 | 5 | /// The inverse of a 2x2 matrix. 6 | /// 7 | /// Returns an invalid result if the matrix is not invertible. 8 | fn inv2(m: mat2x2) -> mat2x2 { 9 | var adj: mat2x2; 10 | adj[0][0] = m[1][1]; 11 | adj[0][1] = -m[0][1]; 12 | adj[1][0] = -m[1][0]; 13 | adj[1][1] = m[0][0]; 14 | 15 | let det: f32 = m[0][0] * m[1][1] - m[1][0] * m[0][1]; 16 | return adj * (1 / det); 17 | } 18 | 19 | 20 | /// The inverse of a 3x3 matrix. 21 | /// 22 | /// Returns an invalid result if the matrix is not invertible.
23 | fn inv3(m: mat3x3) -> mat3x3 { 24 | var adj: mat3x3; 25 | 26 | adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]); 27 | adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]); 28 | adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]); 29 | adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]); 30 | adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]); 31 | adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]); 32 | adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]); 33 | adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]); 34 | adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]); 35 | 36 | let det: f32 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) 37 | - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) 38 | + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])); 39 | 40 | return adj * (1 / det); 41 | } 42 | 43 | 44 | /// The inverse of a 4x4 matrix. 45 | /// 46 | /// Returns an invalid result if the matrix is not invertible. 47 | fn inv4(m: mat4x4) -> mat4x4 { 48 | let sub_factor00: f32 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; 49 | let sub_factor01: f32 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; 50 | let sub_factor02: f32 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; 51 | let sub_factor03: f32 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; 52 | let sub_factor04: f32 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; 53 | let sub_factor05: f32 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; 54 | let sub_factor06: f32 = m[1][2] * m[3][3] - m[3][2] * m[1][3]; 55 | let sub_factor07: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; 56 | let sub_factor08: f32 = m[1][1] * m[3][2] - m[3][1] * m[1][2]; 57 | let sub_factor09: f32 = m[1][0] * m[3][3] - m[3][0] * m[1][3]; 58 | let sub_factor10: f32 = m[1][0] * m[3][2] - m[3][0] * m[1][2]; 59 | let sub_factor11: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3]; 60 | let sub_factor12: f32 = m[1][0] * m[3][1] - m[3][0] * m[1][1]; 61 | let sub_factor13: f32 = m[1][2] * m[2][3] - m[2][2] * m[1][3]; 62 | let sub_factor14: f32 = m[1][1] * m[2][3] - m[2][1] *
m[1][3]; 63 | let sub_factor15: f32 = m[1][1] * m[2][2] - m[2][1] * m[1][2]; 64 | let sub_factor16: f32 = m[1][0] * m[2][3] - m[2][0] * m[1][3]; 65 | let sub_factor17: f32 = m[1][0] * m[2][2] - m[2][0] * m[1][2]; 66 | let sub_factor18: f32 = m[1][0] * m[2][1] - m[2][0] * m[1][1]; 67 | 68 | var adj: mat4x4; 69 | adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02); 70 | adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04); 71 | adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05); 72 | adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05); 73 | adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02); 74 | adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04); 75 | adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05); 76 | adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05); 77 | adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08); 78 | adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10); 79 | adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12); 80 | adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12); 81 | adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15); 82 | adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17); 83 | adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18); 84 | adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18); 85 | 86 | let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]); 87 | 88 | return adj * (1 / det); 89 | }
-------------------------------------------------------------------------------- /crates/wgebra/src/geometry/lu.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path IMPORT_PATH 2 | 3 | // NOTE: this shader depends on a preprocessor substituting the following macros: 4 | // - NROWS: the matrix’s number of rows. 5 | // - NCOLS: the matrix’s number of columns. 6 | // - PERM: the vector type for the permutation sequence. 7 | // Must be a u32 vector of dimension min(NROWS, NCOLS). 8 | // - MAT: the matrix type (e.g. `mat2x2` for a 2x2 matrix). 9 | // - IMPORT_PATH: the `define_import_path` path. 10 | 11 | /// Structure describing a permutation sequence applied by the LU decomposition. 12 | struct Permutations { 13 | /// First permutation indices (row `ia[i]` is permuted with row `ib[i]`). 14 | ia: PERM, 15 | /// Second permutation indices (row `ia[i]` is permuted with row `ib[i]`). 16 | ib: PERM, 17 | /// The number of permutations in `self`. Only the first `len` elements of 18 | /// [`Self::ia`] and [`Self::ib`] need to be taken into account. 19 | len: u32 20 | } 21 | 22 | /// GPU representation of a matrix LU decomposition (with partial pivoting). 23 | /// 24 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#lu-with-partial-or-full-pivoting) documentation 25 | /// for details on the LU decomposition. 26 | struct LU { 27 | /// The LU decomposition where both lower and upper-triangular matrices are stored 28 | /// in the same matrix. In particular the diagonal full of `1` of the lower-triangular 29 | /// matrix isn’t stored explicitly. 30 | lu: MAT, 31 | /// The row permutations applied during the decomposition. 32 | p: Permutations 33 | } 34 | 35 | /// Computes the LU decomposition of the matrix.
36 | fn lu(x: MAT) -> LU { 37 | let min_nrows_ncols = min(NROWS, NCOLS); 38 | var p = Permutations(); 39 | var lu = x; 40 | 41 | for (var i = 0u; i < min_nrows_ncols; i++) { 42 | // Find the pivot index (maximum absolute value on the 43 | // column i, on rows [i, NROWS]. 44 | var piv = i; 45 | var piv_val = abs(lu[i][i]); 46 | for (var r = i + 1u; r < NROWS; r++) { 47 | let abs_val = abs(lu[i][r]); 48 | if abs_val > piv_val { 49 | piv = r; 50 | piv_val = abs_val; 51 | } 52 | } 53 | 54 | if piv_val == 0.0 { 55 | // No non-zero entries on this column. 56 | continue; 57 | } 58 | 59 | // NOTE: read the diagonal element, not `piv_val` since 60 | // the latter involve an absolute value. 61 | let diag = lu[i][piv]; 62 | 63 | if piv != i { 64 | p.ia[p.len] = i; 65 | p.ib[p.len] = piv; 66 | p.len++; 67 | 68 | for (var k = 0u; k < i; k++) { 69 | let mki = lu[k][i]; 70 | lu[k][i] = lu[k][piv]; 71 | lu[k][piv] = mki; 72 | } 73 | 74 | gauss_step_swap(&lu, diag, i, piv); 75 | } else { 76 | gauss_step(&lu, diag, i); 77 | } 78 | } 79 | 80 | return LU(lu, p); 81 | } 82 | 83 | /// Executes one step of gaussian elimination on the i-th row and column of `m`. The diagonal 84 | /// element `m[(i, i)]` is provided as argument. 85 | fn gauss_step(m: ptr, diag: f32, i: u32) 86 | { 87 | let inv_diag = 1.0 / diag; 88 | 89 | for (var r = i + 1u; r < NROWS; r++) { 90 | (*m)[i][r] *= inv_diag; 91 | } 92 | 93 | for (var c = i + 1u; c < NCOLS; c++) { 94 | let pivot = (*m)[c][i]; 95 | 96 | for (var r = i + 1u; r < NROWS; r++) { 97 | (*m)[c][r] -= pivot * (*m)[i][r]; 98 | } 99 | } 100 | } 101 | 102 | /// Swaps the rows `i` with the row `piv` and executes one step of gaussian elimination on the i-th 103 | /// row and column of `m`. The diagonal element `m[(i, i)]` is provided as argument. 
104 | fn gauss_step_swap( 105 | m: ptr, 106 | diag: f32, 107 | i: u32, 108 | piv: u32, 109 | ) 110 | { 111 | let inv_diag = 1.0 / diag; 112 | 113 | let mii = (*m)[i][i]; 114 | (*m)[i][i] = (*m)[i][piv]; 115 | (*m)[i][piv] = mii; 116 | 117 | for (var r = i + 1u; r < NROWS; r++) { 118 | (*m)[i][r] *= inv_diag; 119 | } 120 | 121 | for (var c = i + 1u; c < NCOLS; c++) { 122 | let mci = (*m)[c][i]; 123 | (*m)[c][i] = (*m)[c][piv]; 124 | (*m)[c][piv] = mci; 125 | 126 | let pivot = (*m)[c][i]; 127 | 128 | for (var r = i + 1u; r < NROWS; r++) { 129 | (*m)[c][r] -= pivot * (*m)[i][r]; 130 | } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/mod.rs: -------------------------------------------------------------------------------- 1 | //! Geometric transformations. 2 | 3 | pub use cholesky::*; 4 | pub use eig2::*; 5 | pub use eig3::*; 6 | pub use eig4::*; 7 | pub use inv::*; 8 | pub use lu::*; 9 | pub use qr2::*; 10 | pub use qr3::*; 11 | pub use qr4::*; 12 | pub use quat::*; 13 | pub use rot2::*; 14 | pub use sim2::*; 15 | pub use sim3::*; 16 | pub use svd2::*; 17 | pub use svd3::*; 18 | 19 | mod cholesky; 20 | mod eig2; 21 | mod eig3; 22 | mod eig4; 23 | mod inv; 24 | mod lu; 25 | mod qr2; 26 | mod qr3; 27 | mod qr4; 28 | mod quat; 29 | mod rot2; 30 | mod sim2; 31 | mod sim3; 32 | mod svd2; 33 | mod svd3; 34 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr2.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::Matrix2; 2 | use wgcore::{test_shader_compilation, Shader}; 3 | #[cfg(test)] 4 | use { 5 | naga_oil::compose::NagaModuleDescriptor, 6 | wgpu::{ComputePipeline, Device}, 7 | }; 8 | 9 | #[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] 10 | #[repr(C)] 11 | /// GPU representation of a 2x2 matrix QR decomposition. 
12 | /// 13 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) 14 | /// documentation for details on the QR decomposition. 15 | pub struct GpuQR2 { 16 | /// The QR decomposition’s 2x2 unitary matrix. 17 | pub q: Matrix2, 18 | /// The QR decomposition’s 2x2 upper-triangular matrix. 19 | pub r: Matrix2, 20 | } 21 | 22 | #[derive(Shader)] 23 | #[shader(src = "qr2.wgsl")] 24 | /// Shader for computing the QR decomposition of 2x2 matrices. 25 | pub struct WgQR2; 26 | 27 | test_shader_compilation!(WgQR2); 28 | 29 | impl WgQR2 { 30 | #[doc(hidden)] 31 | #[cfg(test)] 32 | pub fn tests(device: &Device) -> ComputePipeline { 33 | let test_kernel = r#" 34 | @group(0) @binding(0) 35 | var in: array>; 36 | @group(0) @binding(1) 37 | var out: array; 38 | 39 | @compute @workgroup_size(1, 1, 1) 40 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 41 | let i = invocation_id.x; 42 | out[i] = qr(in[i]); 43 | } 44 | "#; 45 | 46 | let src = format!("{}\n{}", Self::src(), test_kernel); 47 | let module = WgQR2::composer() 48 | .unwrap() 49 | .make_naga_module(NagaModuleDescriptor { 50 | source: &src, 51 | file_path: Self::FILE_PATH, 52 | ..Default::default() 53 | }) 54 | .unwrap(); 55 | wgcore::utils::load_module(device, "test", module) 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod test { 61 | use super::GpuQR2; 62 | use approx::{assert_relative_eq, relative_eq}; 63 | use nalgebra::{DVector, Matrix2}; 64 | use wgcore::gpu::GpuInstance; 65 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 66 | use wgcore::tensor::GpuVector; 67 | use wgpu::BufferUsages; 68 | 69 | #[futures_test::test] 70 | #[serial_test::serial] 71 | async fn gpu_qr2() { 72 | let gpu = GpuInstance::new().await.unwrap(); 73 | let svd = super::WgQR2::tests(gpu.device()); 74 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 75 | 76 | const LEN: usize = 345; 77 | let matrices: DVector> = DVector::new_random(LEN); 78 | let
inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 79 | let result: GpuVector = GpuVector::uninit( 80 | gpu.device(), 81 | matrices.len() as u32, 82 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 83 | ); 84 | let staging: GpuVector = GpuVector::uninit( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 88 | ); 89 | 90 | // Dispatch the test. 91 | let mut pass = encoder.compute_pass("test", None); 92 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 93 | .bind0([inputs.buffer(), result.buffer()]) 94 | .dispatch(matrices.len() as u32); 95 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 96 | 97 | staging.copy_from(&mut encoder, &result); 98 | gpu.queue().submit(Some(encoder.finish())); 99 | 100 | // Check the result is correct. 101 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 102 | let mut allowed_fails = 0; 103 | 104 | for (m, qr) in matrices.iter().zip(gpu_result.iter()) { 105 | let qr_na = m.qr(); 106 | 107 | // NOTE: we allow about 1% of the decompositions to fail, to account for occasionally 108 | // bad random matrices that will fail the test due to an unsuitable epsilon. 109 | // Ideally this percentage should be kept as low as possible, but likely not 110 | // removable entirely. 
111 | if allowed_fails == matrices.len() * 2 / 100 { 112 | assert_relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4); 113 | assert_relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4); 114 | } else if !relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4) 115 | || !relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4) 116 | { 117 | allowed_fails += 1; 118 | } 119 | } 120 | 121 | println!("Num fails: {}/{}", allowed_fails, matrices.len()); 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr2.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::qr2 2 | 3 | /// The QR decomposition of a 2x2 matrix. 4 | /// 5 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) 6 | /// documentation for details on the QR decomposition. 7 | struct QR { 8 | /// The QR decomposition’s 2x2 unitary matrix. 9 | q: mat2x2, 10 | /// The QR decomposition’s 2x2 upper-triangular matrix. 11 | r: mat2x2 12 | } 13 | 14 | /// Computes the QR decomposition of a 2x2 matrix. 15 | fn qr(x: mat2x2) -> QR { 16 | const DIM = 2; 17 | var m = x; 18 | var diag = vec2(); 19 | 20 | // Apply householder reflections. 21 | for (var i = 0; i < 2; i++) { 22 | // Ported from householder::reflection_axis_mut 23 | // The axis (or `column`) is `m[i.., i]`. 
24 | var axis_sq_norm = 0.0; 25 | for (var r = i; r < DIM; r++) { 26 | axis_sq_norm += m[i][r] * m[i][r]; 27 | } 28 | 29 | let axis_norm = sqrt(axis_sq_norm); 30 | let modulus = abs(m[i][i]); 31 | let sgn = select(-1.0, 1.0, m[i][i] >= 0.0); // Signum convention of the ported code: zero maps to +1 (WGSL `sign(0.0)` is 0.0 and would cancel the reflection). 32 | var signed_norm = sgn * axis_norm; 33 | let factor = (axis_sq_norm + modulus * axis_norm) * 2.0; 34 | m[i][i] += signed_norm; 35 | 36 | if factor != 0.0 { 37 | let factor_sqrt = sqrt(factor); 38 | var norm = 0.0; 39 | for (var r = i; r < DIM; r++) { 40 | m[i][r] /= factor_sqrt; 41 | norm += m[i][r] * m[i][r]; 42 | } 43 | 44 | norm = sqrt(norm); 45 | 46 | // Renormalization (see nalgebra’s doc of `householder::reflection_axis_mut`). 47 | for (var r = i; r < DIM; r++) { 48 | m[i][r] /= norm; 49 | } 50 | 51 | diag[i] = -signed_norm; 52 | } else { 53 | diag[i] = signed_norm; 54 | } 55 | 56 | // Apply the reflection. 57 | if factor != 0.0 { 58 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 59 | let sgn = sign(diag[i]); 60 | for (var c = i; c < DIM; c++) { 61 | let m_two = -2.0 * sgn; 62 | var factor = 0.0; 63 | for (var r = i; r < DIM; r++) { 64 | factor += m[i][r] * m[c][r]; 65 | } 66 | for (var r = i; r < DIM; r++) { 67 | m[c][r] = m_two * factor * m[i][r] + m[c][r] * sgn; 68 | } 69 | } 70 | } 71 | } 72 | 73 | // Initialize q from m (see QR::q() in nalgebra). 74 | var q = mat2x2( 75 | vec2(1.0, 0.0), 76 | vec2(0.0, 1.0), 77 | ); 78 | for (var i = DIM - 1; i >= 0; i--) { 79 | // axis := m[i.., i] 80 | // res_rows := q[i.., i..] 
81 | let sgn = select(-1.0, 1.0, diag[i] >= 0.0); // Signum convention: zero maps to +1 (WGSL `sign(0.0)` is 0.0, which would zero out this block of `q`). 82 | 83 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 84 | for (var c = i; c < DIM; c++) { 85 | let m_two = -2.0 * sgn; 86 | var factor = 0.0; 87 | for (var r = i; r < DIM; r++) { 88 | factor += m[i][r] * q[c][r]; 89 | } 90 | for (var r = i; r < DIM; r++) { 91 | q[c][r] = m_two * factor * m[i][r] + q[c][r] * sgn; 92 | } 93 | } 94 | 95 | if i == 0 { 96 | break; 97 | } 98 | } 99 | 100 | // Fill the lower triangle of `m` and set its diagonal to get `r`. 101 | let r = mat2x2( 102 | vec2(abs(diag.x), 0.0), 103 | vec2(m[1][0], abs(diag.y)), 104 | ); 105 | 106 | return QR(q, r); 107 | } 108 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr3.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::Matrix3; 2 | use wgcore::{test_shader_compilation, Shader}; 3 | #[cfg(test)] 4 | use { 5 | naga_oil::compose::NagaModuleDescriptor, 6 | wgpu::{ComputePipeline, Device}, 7 | }; 8 | 9 | #[derive(Copy, Clone, Debug, encase::ShaderType)] 10 | #[repr(C)] 11 | /// GPU representation of a 3x3 matrix QR decomposition. 12 | /// 13 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) 14 | /// documentation for details on the QR decomposition. 15 | pub struct GpuQR3 { 16 | /// The QR decomposition’s 3x3 unitary matrix. 17 | pub q: Matrix3, 18 | /// The QR decomposition’s 3x3 upper-triangular matrix. 19 | pub r: Matrix3, 20 | } 21 | 22 | #[derive(Shader)] 23 | #[shader(src = "qr3.wgsl")] 24 | /// Shader for computing the QR decomposition of 3x3 matrices. 
25 | pub struct WgQR3; 26 | 27 | test_shader_compilation!(WgQR3); 28 | 29 | impl WgQR3 { 30 | #[doc(hidden)] 31 | #[cfg(test)] 32 | pub fn tests(device: &Device) -> ComputePipeline { 33 | let test_kernel = r#" 34 | @group(0) @binding(0) 35 | var in: array>; 36 | @group(0) @binding(1) 37 | var out: array; 38 | 39 | @compute @workgroup_size(1, 1, 1) 40 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 41 | let i = invocation_id.x; 42 | out[i] = qr(in[i]); 43 | } 44 | "#; 45 | 46 | let src = format!("{}\n{}", Self::src(), test_kernel); 47 | let module = WgQR3::composer() 48 | .unwrap() 49 | .make_naga_module(NagaModuleDescriptor { 50 | source: &src, 51 | file_path: Self::FILE_PATH, 52 | ..Default::default() 53 | }) 54 | .unwrap(); 55 | wgcore::utils::load_module(device, "test", module) 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod test { 61 | use super::GpuQR3; 62 | use approx::{assert_relative_eq, relative_eq}; 63 | use nalgebra::{DVector, Matrix3}; 64 | use wgcore::gpu::GpuInstance; 65 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 66 | use wgcore::tensor::GpuVector; 67 | use wgpu::BufferUsages; 68 | 69 | #[futures_test::test] 70 | #[serial_test::serial] 71 | async fn gpu_qr3() { 72 | let gpu = GpuInstance::new().await.unwrap(); 73 | let svd = super::WgQR3::tests(gpu.device()); 74 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 75 | 76 | const LEN: usize = 345; 77 | let matrices: DVector> = DVector::new_random(LEN); 78 | let inputs = GpuVector::encase(gpu.device(), &matrices, BufferUsages::STORAGE); 79 | let result: GpuVector = GpuVector::uninit_encased( 80 | gpu.device(), 81 | matrices.len() as u32, 82 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 83 | ); 84 | let staging: GpuVector = GpuVector::uninit_encased( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 88 | ); 89 | 90 | // Dispatch the test. 
91 | let mut pass = encoder.compute_pass("test", None); 92 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 93 | .bind0([inputs.buffer(), result.buffer()]) 94 | .dispatch(matrices.len() as u32); 95 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 96 | 97 | staging.copy_from_encased(&mut encoder, &result); 98 | gpu.queue().submit(Some(encoder.finish())); 99 | 100 | // Check the result is correct. 101 | let gpu_result = staging.read_encased(gpu.device()).await.unwrap(); 102 | let mut allowed_fails = 0; 103 | 104 | for (m, qr) in matrices.iter().zip(gpu_result.iter()) { 105 | let qr_na = m.qr(); 106 | 107 | // NOTE: we allow about 1% of the decompositions to fail, to account for occasionally 108 | // bad random matrices that will fail the test due to an unsuitable epsilon. 109 | // Ideally this percentage should be kept as low as possible, but likely not 110 | // removable entirely. 111 | if allowed_fails == matrices.len() * 2 / 100 { 112 | assert_relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4); 113 | assert_relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4); 114 | } else if !relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4) 115 | || !relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4) 116 | { 117 | allowed_fails += 1; 118 | } 119 | } 120 | 121 | println!("Num fails: {}/{}", allowed_fails, matrices.len()); 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr3.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::qr3 2 | 3 | /// The QR decomposition of a 3x3 matrix. 4 | /// 5 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) 6 | /// documentation for details on the QR decomposition. 7 | struct QR { 8 | /// The QR decomposition’s 3x3 unitary matrix. 9 | q: mat3x3, 10 | /// The QR decomposition’s 3x3 upper-triangular matrix. 
11 | r: mat3x3 12 | } 13 | 14 | /// Computes the QR decomposition of a 3x3 matrix. 15 | fn qr(x: mat3x3) -> QR { 16 | const DIM = 3; 17 | var m = x; 18 | var diag = vec3(); 19 | 20 | // Apply householder reflections. 21 | for (var i = 0; i < 3; i++) { 22 | // Ported from householder::reflection_axis_mut 23 | // The axis (or `column`) is `m[i.., i]`. 24 | var axis_sq_norm = 0.0; 25 | for (var r = i; r < DIM; r++) { 26 | axis_sq_norm += m[i][r] * m[i][r]; 27 | } 28 | 29 | let axis_norm = sqrt(axis_sq_norm); 30 | let modulus = abs(m[i][i]); 31 | let sgn = select(-1.0, 1.0, m[i][i] >= 0.0); // Signum convention of the ported code: zero maps to +1 (WGSL `sign(0.0)` is 0.0 and would cancel the reflection). 32 | var signed_norm = sgn * axis_norm; 33 | let factor = (axis_sq_norm + modulus * axis_norm) * 2.0; 34 | m[i][i] += signed_norm; 35 | 36 | if factor != 0.0 { 37 | let factor_sqrt = sqrt(factor); 38 | var norm = 0.0; 39 | for (var r = i; r < DIM; r++) { 40 | m[i][r] /= factor_sqrt; 41 | norm += m[i][r] * m[i][r]; 42 | } 43 | 44 | norm = sqrt(norm); 45 | 46 | // Renormalization (see nalgebra’s doc of `householder::reflection_axis_mut`). 47 | for (var r = i; r < DIM; r++) { 48 | m[i][r] /= norm; 49 | } 50 | 51 | diag[i] = -signed_norm; 52 | } else { 53 | diag[i] = signed_norm; 54 | } 55 | 56 | // Apply the reflection. 57 | if factor != 0.0 { 58 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 59 | let sgn = sign(diag[i]); 60 | for (var c = i; c < DIM; c++) { 61 | let m_two = -2.0 * sgn; 62 | var factor = 0.0; 63 | for (var r = i; r < DIM; r++) { 64 | factor += m[i][r] * m[c][r]; 65 | } 66 | for (var r = i; r < DIM; r++) { 67 | m[c][r] = m_two * factor * m[i][r] + m[c][r] * sgn; 68 | } 69 | } 70 | } 71 | } 72 | 73 | // Initialize q from m (see QR::q() in nalgebra). 74 | var q = mat3x3( 75 | vec3(1.0, 0.0, 0.0), 76 | vec3(0.0, 1.0, 0.0), 77 | vec3(0.0, 0.0, 1.0), 78 | ); 79 | for (var i = DIM - 1; i >= 0; i--) { 80 | // axis := m[i.., i] 81 | // res_rows := q[i.., i..] 
82 | let sgn = select(-1.0, 1.0, diag[i] >= 0.0); // Signum convention: zero maps to +1 (WGSL `sign(0.0)` is 0.0, which would zero out this block of `q`). 83 | 84 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 85 | for (var c = i; c < DIM; c++) { 86 | let m_two = -2.0 * sgn; 87 | var factor = 0.0; 88 | for (var r = i; r < DIM; r++) { 89 | factor += m[i][r] * q[c][r]; 90 | } 91 | for (var r = i; r < DIM; r++) { 92 | q[c][r] = m_two * factor * m[i][r] + q[c][r] * sgn; 93 | } 94 | } 95 | 96 | if i == 0 { 97 | break; 98 | } 99 | } 100 | 101 | // Fill the lower triangle of `m` and set its diagonal to get `r`. 102 | let r = mat3x3( 103 | vec3(abs(diag.x), 0.0, 0.0), 104 | vec3(m[1][0], abs(diag.y), 0.0), 105 | vec3(m[2][0], m[2][1], abs(diag.z)), 106 | ); 107 | 108 | return QR(q, r); 109 | } 110 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr4.rs: -------------------------------------------------------------------------------- 1 | use nalgebra::Matrix4; 2 | use wgcore::{test_shader_compilation, Shader}; 3 | #[cfg(test)] 4 | use { 5 | naga_oil::compose::NagaModuleDescriptor, 6 | wgpu::{ComputePipeline, Device}, 7 | }; 8 | 9 | #[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] 10 | #[repr(C)] 11 | /// GPU representation of a 4x4 matrix QR decomposition. 12 | /// 13 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) documentation 14 | /// for details on the QR decomposition. 15 | pub struct GpuQR4 { 16 | /// The QR decomposition’s 4x4 unitary matrix. 17 | pub q: Matrix4, 18 | /// The QR decomposition’s 4x4 upper-triangular matrix. 19 | pub r: Matrix4, 20 | } 21 | 22 | #[derive(Shader)] 23 | #[shader(src = "qr4.wgsl")] 24 | /// Shader for computing the QR decomposition of 4x4 matrices. 
25 | pub struct WgQR4; 26 | 27 | test_shader_compilation!(WgQR4); 28 | 29 | impl WgQR4 { 30 | #[doc(hidden)] 31 | #[cfg(test)] 32 | pub fn tests(device: &Device) -> ComputePipeline { 33 | let test_kernel = r#" 34 | @group(0) @binding(0) 35 | var in: array>; 36 | @group(0) @binding(1) 37 | var out: array; 38 | 39 | @compute @workgroup_size(1, 1, 1) 40 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 41 | let i = invocation_id.x; 42 | out[i] = qr(in[i]); 43 | } 44 | "#; 45 | 46 | let src = format!("{}\n{}", Self::src(), test_kernel); 47 | let module = WgQR4::composer() 48 | .unwrap() 49 | .make_naga_module(NagaModuleDescriptor { 50 | source: &src, 51 | file_path: Self::FILE_PATH, 52 | ..Default::default() 53 | }) 54 | .unwrap(); 55 | wgcore::utils::load_module(device, "test", module) 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod test { 61 | use super::GpuQR4; 62 | use approx::{assert_relative_eq, relative_eq}; 63 | use nalgebra::{DVector, Matrix4}; 64 | use wgcore::gpu::GpuInstance; 65 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 66 | use wgcore::tensor::GpuVector; 67 | use wgpu::BufferUsages; 68 | 69 | #[futures_test::test] 70 | #[serial_test::serial] 71 | async fn gpu_qr4() { 72 | let gpu = GpuInstance::new().await.unwrap(); 73 | let svd = super::WgQR4::tests(gpu.device()); 74 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 75 | 76 | const LEN: usize = 345; 77 | let matrices: DVector> = DVector::new_random(LEN); 78 | let inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 79 | let result: GpuVector = GpuVector::uninit( 80 | gpu.device(), 81 | matrices.len() as u32, 82 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 83 | ); 84 | let staging: GpuVector = GpuVector::uninit( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 88 | ); 89 | 90 | // Dispatch the test. 
91 | let mut pass = encoder.compute_pass("test", None); 92 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 93 | .bind0([inputs.buffer(), result.buffer()]) 94 | .dispatch(matrices.len() as u32); 95 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 96 | 97 | staging.copy_from(&mut encoder, &result); 98 | gpu.queue().submit(Some(encoder.finish())); 99 | 100 | // Check the result is correct. 101 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 102 | let mut allowed_fails = 0; 103 | 104 | for (m, qr) in matrices.iter().zip(gpu_result.iter()) { 105 | let qr_na = m.qr(); 106 | 107 | // NOTE: we allow about 1% of the decompositions to fail, to account for occasionally 108 | // bad random matrices that will fail the test due to an unsuitable epsilon. 109 | // Ideally this percentage should be kept as low as possible, but likely not 110 | // removable entirely. 111 | if allowed_fails == matrices.len() * 2 / 100 { 112 | assert_relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4); 113 | assert_relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4); 114 | } else if !relative_eq!(qr_na.q(), qr.q, epsilon = 1.0e-4) 115 | || !relative_eq!(qr_na.r(), qr.r, epsilon = 1.0e-4) 116 | { 117 | allowed_fails += 1; 118 | } 119 | } 120 | 121 | println!("Num fails: {}/{}", allowed_fails, matrices.len()); 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/qr4.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::qr4 2 | 3 | /// The QR decomposition of a 4x4 matrix. 4 | /// 5 | /// See the [nalgebra](https://nalgebra.rs/docs/user_guide/decompositions_and_lapack#qr) documentation 6 | /// for details on the QR decomposition. 7 | struct QR { 8 | /// The QR decomposition’s 4x4 unitary matrix. 9 | q: mat4x4, 10 | /// The QR decomposition’s 4x4 upper-triangular matrix. 
11 | r: mat4x4 12 | } 13 | 14 | /// Computes the QR decomposition of a 4x4 matrix. 15 | fn qr(x: mat4x4) -> QR { 16 | const DIM = 4; 17 | var m = x; 18 | var diag = vec4(); 19 | 20 | // Apply householder reflections. 21 | for (var i = 0; i < 4; i++) { 22 | // Ported from householder::reflection_axis_mut 23 | // The axis (or `column`) is `m[i.., i]`. 24 | var axis_sq_norm = 0.0; 25 | for (var r = i; r < DIM; r++) { 26 | axis_sq_norm += m[i][r] * m[i][r]; 27 | } 28 | 29 | let axis_norm = sqrt(axis_sq_norm); 30 | let modulus = abs(m[i][i]); 31 | let sgn = select(-1.0, 1.0, m[i][i] >= 0.0); // Signum convention of the ported code: zero maps to +1 (WGSL `sign(0.0)` is 0.0 and would cancel the reflection). 32 | var signed_norm = sgn * axis_norm; 33 | let factor = (axis_sq_norm + modulus * axis_norm) * 2.0; 34 | m[i][i] += signed_norm; 35 | 36 | if factor != 0.0 { 37 | let factor_sqrt = sqrt(factor); 38 | var norm = 0.0; 39 | for (var r = i; r < DIM; r++) { 40 | m[i][r] /= factor_sqrt; 41 | norm += m[i][r] * m[i][r]; 42 | } 43 | 44 | norm = sqrt(norm); 45 | 46 | // Renormalization (see nalgebra’s doc of `householder::reflection_axis_mut`). 47 | for (var r = i; r < DIM; r++) { 48 | m[i][r] /= norm; 49 | } 50 | 51 | diag[i] = -signed_norm; 52 | } else { 53 | diag[i] = signed_norm; 54 | } 55 | 56 | // Apply the reflection. 57 | if factor != 0.0 { 58 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 59 | let sgn = sign(diag[i]); 60 | for (var c = i; c < DIM; c++) { 61 | let m_two = -2.0 * sgn; 62 | var factor = 0.0; 63 | for (var r = i; r < DIM; r++) { 64 | factor += m[i][r] * m[c][r]; 65 | } 66 | for (var r = i; r < DIM; r++) { 67 | m[c][r] = m_two * factor * m[i][r] + m[c][r] * sgn; 68 | } 69 | } 70 | } 71 | } 72 | 73 | // Initialize q from m (see QR::q() in nalgebra). 74 | var q = mat4x4( 75 | vec4(1.0, 0.0, 0.0, 0.0), 76 | vec4(0.0, 1.0, 0.0, 0.0), 77 | vec4(0.0, 0.0, 1.0, 0.0), 78 | vec4(0.0, 0.0, 0.0, 1.0), 79 | ); 80 | for (var i = DIM - 1; i >= 0; i--) { 81 | // axis := m[i.., i] 82 | // res_rows := q[i.., i..] 
83 | let sgn = select(-1.0, 1.0, diag[i] >= 0.0); // Signum convention: zero maps to +1 (WGSL `sign(0.0)` is 0.0, which would zero out this block of `q`). 84 | 85 | // refl.reflect_with_sign(&mut res_rows, signs[i].clone().signum()); 86 | for (var c = i; c < DIM; c++) { 87 | let m_two = -2.0 * sgn; 88 | var factor = 0.0; 89 | for (var r = i; r < DIM; r++) { 90 | factor += m[i][r] * q[c][r]; 91 | } 92 | for (var r = i; r < DIM; r++) { 93 | q[c][r] = m_two * factor * m[i][r] + q[c][r] * sgn; 94 | } 95 | } 96 | 97 | if i == 0 { 98 | break; 99 | } 100 | } 101 | 102 | // Fill the lower triangle of `m` and set its diagonal to get `r`. 103 | let r = mat4x4( 104 | vec4(abs(diag.x), 0.0, 0.0, 0.0), 105 | vec4(m[1][0], abs(diag.y), 0.0, 0.0), 106 | vec4(m[2][0], m[2][1], abs(diag.z), 0.0), 107 | vec4(m[3][0], m[3][1], m[3][2], abs(diag.w)) 108 | ); 109 | 110 | return QR(q, r); 111 | } 112 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/quat.rs: -------------------------------------------------------------------------------- 1 | use wgcore::Shader; 2 | 3 | // NOTE: interesting perf. comparison between quaternions and matrices: 4 | // https://tech.metail.com/performance-quaternions-gpu/ 5 | 6 | #[derive(Shader)] 7 | #[shader(src = "quat.wgsl")] 8 | /// Shader exposing a quaternion type and operations for representing 3D rotations. 9 | pub struct WgQuat; 10 | 11 | wgcore::test_shader_compilation!(WgQuat); 12 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/quat.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::quat 2 | 3 | /// A unit quaternion representing a rotation. 4 | struct Quat { 5 | /// The quaternion’s coordinates (i, j, k, w). 6 | coords: vec4, 7 | } 8 | 9 | /// The quaternion representing an identity rotation. 10 | fn identity() -> Quat { 11 | return Quat(vec4(0.0, 0.0, 0.0, 1.0)); 12 | } 13 | 14 | /// Convert an axis-angle (represented as the axis multiplied by the angle) to 15 | /// a quaternion. 
16 | fn fromScaledAxis(axisangle: vec3) -> Quat { 17 | let angle = length(axisangle); 18 | // A zero angle has no defined axis: return the identity instead of dividing by zero below. 19 | 20 | if angle == 0.0 { 21 | return identity(); 22 | } else { 23 | let hs = sin(angle / 2.0); 24 | let hc = cos(angle / 2.0); 25 | let axis = axisangle / angle; 26 | return Quat(vec4(axis * hs, hc)); 27 | } 28 | } 29 | 30 | /// Converts this quaternion to a 3x3 rotation matrix. 31 | fn toMatrix(quat: Quat) -> mat3x3 { 32 | let i = quat.coords.x; 33 | let j = quat.coords.y; 34 | let k = quat.coords.z; 35 | let w = quat.coords.w; 36 | 37 | let ww = w * w; 38 | let ii = i * i; 39 | let jj = j * j; 40 | let kk = k * k; 41 | let ij = i * j * 2.0; 42 | let wk = w * k * 2.0; 43 | let wj = w * j * 2.0; 44 | let ik = i * k * 2.0; 45 | let jk = j * k * 2.0; 46 | let wi = w * i * 2.0; 47 | 48 | return mat3x3( 49 | vec3(ww + ii - jj - kk, wk + ij, ik - wj), 50 | vec3(ij - wk, ww - ii + jj - kk, wi + jk), 51 | vec3(wj + ik, jk - wi, ww - ii - jj + kk), 52 | ); 53 | } 54 | 55 | /// Normalizes this quaternion again using a first-order Taylor approximation. 56 | /// This is useful when repeated computations might cause a drift in the norm 57 | /// because of float inaccuracies. 58 | fn renormalizeFast(q: Quat) -> Quat { 59 | let sq_norm = dot(q.coords, q.coords); 60 | return Quat(q.coords * (0.5 * (3.0 - sq_norm))); 61 | } 62 | 63 | /// The inverse (conjugate) of a unit quaternion. 64 | fn inv(q: Quat) -> Quat { 65 | return Quat(vec4(-q.coords.xyz, q.coords.w)); 66 | } 67 | 68 | /// Multiplies two quaternions (combines their rotations). 69 | fn mul(lhs: Quat, rhs: Quat) -> Quat { 70 | let scalar = lhs.coords.w * rhs.coords.w - dot(lhs.coords.xyz, rhs.coords.xyz); 71 | let v = cross(lhs.coords.xyz, rhs.coords.xyz) + lhs.coords.w * rhs.coords.xyz + rhs.coords.w * lhs.coords.xyz; 72 | return Quat(vec4(v, scalar)); 73 | } 74 | 75 | /// Multiplies a quaternion by a vector (rotates the vector). 
76 | fn mulVec(q: Quat, v: vec3) -> vec3 { 77 | let t = cross(q.coords.xyz, v) * 2.0; 78 | let c = cross(q.coords.xyz, t); 79 | return t * q.coords.w + c + v; 80 | } 81 | 82 | /// Multiplies a quaternion’s inverse by a vector (inverse-rotates the vector). 83 | fn invMulVec(q: Quat, v: vec3) -> vec3 { 84 | let t = cross(q.coords.xyz, v) * 2.0; 85 | let c = cross(q.coords.xyz, t); 86 | return t * -q.coords.w + c + v; 87 | } 88 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/rot2.rs: -------------------------------------------------------------------------------- 1 | use wgcore::Shader; 2 | 3 | #[derive(Shader)] 4 | #[shader(src = "rot2.wgsl")] 5 | /// Shader exposing a 2D rotation type and operations. 6 | pub struct WgRot2; 7 | 8 | wgcore::test_shader_compilation!(WgRot2); 9 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/rot2.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::rot2 2 | 3 | 4 | /// Compact representation of a 2D rotation. 5 | struct Rot2 { 6 | cos_sin: vec2 7 | } 8 | 9 | /// Returns `true` if `rot` isn’t zero. 10 | /// 11 | /// Fallible functions that return a Rot2 will generally return zero 12 | /// as the value to indicate failure. 13 | fn is_valid(rot: Rot2) -> bool { 14 | return rot.cos_sin.x != 0.0 || rot.cos_sin.y != 0.0; 15 | } 16 | 17 | /// Initializes a 2D rotation from an angle (radians). 18 | fn fromAngle(angle: f32) -> Rot2 { 19 | return Rot2(vec2(cos(angle), sin(angle))); 20 | } 21 | 22 | /// Computes the rotation `R` required such that the `y` component of `R * v` is zero. 23 | /// 24 | /// Returns `Rot2()` (i.e. Rot2 filled with zeros) if no rotation is needed (i.e. if `v.y == 0`). Otherwise, this returns 25 | /// the rotation `R` such that `R * v = [ |v|, 0.0 ]^t` where `|v|` is the norm of `v`. 
26 | fn cancel_y(v: vec2) -> Rot2 { 27 | if v.y != 0.0 { 28 | let r = select(-1.0, 1.0, v.x >= 0.0) / length(v); // Zero `v.x` counts as positive: WGSL `sign(0.0)` is 0.0 and would yield the zero (invalid) rotation. 29 | let cos_sin = vec2(v.x, -v.y) * r; 30 | return Rot2(cos_sin); 31 | } else { 32 | return Rot2(); 33 | } 34 | } 35 | 36 | /// The 2D rotation representing an identity rotation (cosine 1, sine 0). 37 | fn identity() -> Rot2 { 38 | return Rot2(vec2(1.0, 0.0)); 39 | } 40 | /// Converts this 2D rotation to a 2x2 rotation matrix. 41 | fn toMatrix(r: Rot2) -> mat2x2 { 42 | return mat2x2( 43 | vec2(r.cos_sin.x, r.cos_sin.y), 44 | vec2(-r.cos_sin.y, r.cos_sin.x) 45 | ); 46 | } 47 | 48 | /// The inverse of a 2d rotation. 49 | fn inv(r: Rot2) -> Rot2 { 50 | return Rot2(vec2(r.cos_sin.x, -r.cos_sin.y)); 51 | } 52 | 53 | /// Multiplication of two 2D rotations. 54 | fn mul(lhs: Rot2, rhs: Rot2) -> Rot2 { 55 | let new_cos = lhs.cos_sin.x * rhs.cos_sin.x - lhs.cos_sin.y * rhs.cos_sin.y; 56 | let new_sin = lhs.cos_sin.y * rhs.cos_sin.x + lhs.cos_sin.x * rhs.cos_sin.y; 57 | return Rot2(vec2(new_cos, new_sin)); 58 | } 59 | 60 | /// Multiplies a 2D rotation by a vector (rotates the vector). 61 | fn mulVec(r: Rot2, v: vec2) -> vec2 { 62 | return vec2(r.cos_sin.x * v.x - r.cos_sin.y * v.y, r.cos_sin.y * v.x + r.cos_sin.x * v.y); 63 | } 64 | 65 | /// Multiplies the inverse of a 2D rotation by a vector (applies inverse rotation to the vector). 66 | fn invMulVec(r: Rot2, v: vec2) -> vec2 { 67 | return vec2(r.cos_sin.x * v.x + r.cos_sin.y * v.y, -r.cos_sin.y * v.x + r.cos_sin.x * v.y); 68 | } 69 | 70 | // Apply the rotation to rows i and i + 1 to the given 3x3 matrix. 71 | fn rotate_rows3(rot: Rot2, m: ptr>, i: u32) { 72 | for (var r = 0; r < 3; r++) { 73 | let v = vec2((*m)[i][r], (*m)[i + 1][r]); 74 | let rv = invMulVec(rot, v); 75 | (*m)[i][r] = rv.x; 76 | (*m)[i + 1][r] = rv.y; 77 | } 78 | } 79 | 80 | // Apply the rotation to rows i and i + 1 to the given 4x4 matrix. 
81 | fn rotate_rows4(rot: Rot2, m: ptr>, i: u32) { 82 | for (var r = 0; r < 4; r++) { 83 | let v = vec2((*m)[i][r], (*m)[i + 1][r]); 84 | let rv = invMulVec(rot, v); 85 | (*m)[i][r] = rv.x; 86 | (*m)[i + 1][r] = rv.y; 87 | } 88 | } -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/sim2.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::sim2 2 | #import wgebra::rot2 as Rot 3 | 4 | 5 | /// A 2D similarity representing a uniform scale, followed by a rotation, followed by a translation. 6 | struct Sim2 { 7 | /// The similarity’s rotational part. 8 | rotation: Rot::Rot2, 9 | /// The similarity’s translational part. 10 | translation: vec2, 11 | /// The similarity’s scaling part. 12 | scale: f32, 13 | } 14 | 15 | /// The identity similarity. 16 | fn identity() -> Sim2 { 17 | return Sim2(Rot::identity(), vec2(0.0f), 1.0f); 18 | } 19 | 20 | /// Multiplies two similarities. 21 | fn mul(lhs: Sim2, rhs: Sim2) -> Sim2 { 22 | let rotation = Rot::mul(lhs.rotation, rhs.rotation); 23 | let translation = lhs.translation + Rot::mulVec(lhs.rotation, rhs.translation) * lhs.scale; 24 | return Sim2(rotation, translation, lhs.scale * rhs.scale); 25 | } 26 | 27 | /// Inverts a similarity. 28 | fn inv(sim: Sim2) -> Sim2 { 29 | let scale = 1.0f / sim.scale; 30 | let rotation = Rot::inv(sim.rotation); 31 | let translation = Rot::mulVec(rotation, -sim.translation) * scale; 32 | return Sim2(rotation, translation, scale); 33 | } 34 | 35 | /// Multiplies a similarity and a point (scales, rotates then translates the point). 36 | fn mulPt(sim: Sim2, pt: vec2) -> vec2 { 37 | return Rot::mulVec(sim.rotation, pt * sim.scale) + sim.translation; 38 | } 39 | 40 | /// Multiplies the inverse of a similarity and a point (inv-translates, inv-rotates, then inv-scales the point). 
41 | fn invMulPt(sim: Sim2, pt: vec2) -> vec2 { 42 | return Rot::invMulVec(sim.rotation, (pt - sim.translation)) / sim.scale; 43 | } 44 | 45 | /// Multiplies a similarity and a vector (scales and rotates the vector; the translation is ignored). 46 | fn mulVec(sim: Sim2, vec: vec2) -> vec2 { 47 | return Rot::mulVec(sim.rotation, vec) * sim.scale; 48 | } 49 | 50 | /// Multiplies the inverse of a similarity and a vector (inv-rotates then inv-scales the vector; the translation is ignored). 51 | fn invMulVec(sim: Sim2, vec: vec2) -> vec2 { 52 | return Rot::invMulVec(sim.rotation, vec) / sim.scale; 53 | } 54 | 55 | /// Multiplies the inverse of a similarity and a unit vector. 56 | /// 57 | /// This is similar to `invMulVec` but the scaling part of the similarity is ignored to preserve the vector’s unit size. 58 | fn invMulUnitVec(sim: Sim2, vec: vec2) -> vec2 { 59 | return Rot::invMulVec(sim.rotation, vec); 60 | } 61 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/sim3.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::sim3 2 | #import wgebra::quat as Rot 3 | 4 | 5 | /// A 3D similarity representing a uniform scale, followed by a rotation, followed by a translation. 6 | struct Sim3 { 7 | /// The similarity’s rotational part. 8 | rotation: Rot::Quat, 9 | /// The similarity’s translational (xyz) and scaling (w) part. 10 | translation_scale: vec4 11 | } 12 | 13 | /// The identity similarity. 14 | fn identity() -> Sim3 { 15 | return Sim3(Rot::identity(), vec4(0.0f, 0.0f, 0.0f, 1.0f)); 16 | } 17 | 18 | /// Multiplies two similarities. 
19 | fn mul(lhs: Sim3, rhs: Sim3) -> Sim3 { 20 | let rotation = Rot::mul(lhs.rotation, rhs.rotation); 21 | let translation = lhs.translation_scale.xyz + Rot::mulVec(lhs.rotation, rhs.translation_scale.xyz) * lhs.translation_scale.w; 22 | return Sim3(rotation, vec4(translation, lhs.translation_scale.w * rhs.translation_scale.w)); 23 | } 24 | 25 | /// Inverts a similarity. 26 | fn inv(sim: Sim3) -> Sim3 { 27 | let scale = 1.0f / sim.translation_scale.w; 28 | let rotation = Rot::inv(sim.rotation); 29 | let translation = Rot::mulVec(rotation, -sim.translation_scale.xyz) * scale; 30 | return Sim3(rotation, vec4(translation, scale)); 31 | } 32 | 33 | /// Multiplies a similarity and a point (scales, rotates then translates the point). 34 | fn mulPt(sim: Sim3, pt: vec3) -> vec3 { 35 | return Rot::mulVec(sim.rotation, pt * sim.translation_scale.w) + sim.translation_scale.xyz; 36 | } 37 | 38 | /// Multiplies the inverse of a similarity and a point (inv-translates, inv-rotates, then inv-scales the point). 39 | fn invMulPt(sim: Sim3, pt: vec3) -> vec3 { 40 | return Rot::invMulVec(sim.rotation, (pt - sim.translation_scale.xyz)) / sim.translation_scale.w; 41 | } 42 | 43 | /// Multiplies a similarity and a vector (scales and rotates the vector; the translation is ignored). 44 | fn mulVec(sim: Sim3, vec: vec3) -> vec3 { 45 | return Rot::mulVec(sim.rotation, vec) * sim.translation_scale.w; 46 | } 47 | 48 | /// Multiplies the inverse of a similarity and a vector (inv-rotates then inv-scales the point; the translation is ignored). 49 | fn invMulVec(sim: Sim3, vec: vec3) -> vec3 { 50 | return Rot::invMulVec(sim.rotation, vec) / sim.translation_scale.w; 51 | } 52 | 53 | /// Multiplies the inverse of a similarity and a unit vector. 54 | /// 55 | /// This is similar to `invMulVec` but the scaling part of the similarity is ignored to preserve the vector’s unit size. 
56 | fn invMulUnitVec(sim: Sim3, vec: vec3) -> vec3 { 57 | return Rot::invMulVec(sim.rotation, vec); 58 | } 59 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/svd2.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::WgTrig; 2 | use nalgebra::{Matrix2, Vector2}; 3 | use wgcore::Shader; 4 | #[cfg(test)] 5 | use { 6 | naga_oil::compose::NagaModuleDescriptor, 7 | wgpu::{ComputePipeline, Device}, 8 | }; 9 | 10 | #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] 11 | #[repr(C)] 12 | /// GPU representation of a 2x2 matrix SVD. 13 | pub struct GpuSvd2 { 14 | /// First orthogonal matrix of the SVD. 15 | pub u: Matrix2, 16 | /// Singular values. 17 | pub s: Vector2, 18 | /// Second orthogonal matrix of the SVD. 19 | pub vt: Matrix2, 20 | } 21 | 22 | #[derive(Shader)] 23 | #[shader(derive(WgTrig), src = "svd2.wgsl")] 24 | /// Shader for computing the Singular Value Decomposition of 2x2 matrices. 
25 | pub struct WgSvd2; 26 | 27 | impl WgSvd2 { 28 | #[cfg(test)] 29 | #[doc(hidden)] 30 | pub fn tests(device: &Device) -> ComputePipeline { 31 | let test_kernel = r#" 32 | @group(0) @binding(0) 33 | var in: array>; 34 | @group(0) @binding(1) 35 | var out: array; 36 | 37 | @compute @workgroup_size(1, 1, 1) 38 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 39 | let i = invocation_id.x; 40 | out[i] = svd(in[i]); 41 | } 42 | "#; 43 | 44 | let src = format!("{}\n{}", Self::src(), test_kernel); 45 | let module = WgTrig::composer() 46 | .unwrap() 47 | .make_naga_module(NagaModuleDescriptor { 48 | source: &src, 49 | file_path: Self::FILE_PATH, 50 | ..Default::default() 51 | }) 52 | .unwrap(); 53 | wgcore::utils::load_module(device, "test", module) 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod test { 59 | use super::GpuSvd2; 60 | use approx::assert_relative_eq; 61 | use nalgebra::{DVector, Matrix2}; 62 | use wgcore::gpu::GpuInstance; 63 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 64 | use wgcore::tensor::GpuVector; 65 | use wgpu::BufferUsages; 66 | 67 | #[futures_test::test] 68 | #[serial_test::serial] 69 | async fn gpu_svd2() { 70 | let gpu = GpuInstance::new().await.unwrap(); 71 | let svd = super::WgSvd2::tests(gpu.device()); 72 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 73 | 74 | const LEN: usize = 345; 75 | let mut matrices: DVector> = DVector::new_random(LEN); 76 | matrices[0] = Matrix2::zeros(); // The zero matrix can cause issues on some platforms (like macos) with unspecified atan2 on (0, 0). 77 | matrices[1] = Matrix2::identity(); // The identity matrix can cause issues on some platforms. 
78 | let inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 79 | let result: GpuVector = GpuVector::uninit( 80 | gpu.device(), 81 | matrices.len() as u32, 82 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 83 | ); 84 | let staging: GpuVector = GpuVector::uninit( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 88 | ); 89 | 90 | // Dispatch the test. 91 | let mut pass = encoder.compute_pass("test", None); 92 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 93 | .bind0([inputs.buffer(), result.buffer()]) 94 | .dispatch(matrices.len() as u32); 95 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 96 | 97 | staging.copy_from(&mut encoder, &result); 98 | gpu.queue().submit(Some(encoder.finish())); 99 | 100 | // Check the result is correct. 101 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 102 | 103 | for (m, svd) in matrices.iter().zip(gpu_result.iter()) { 104 | let reconstructed = svd.u * Matrix2::from_diagonal(&svd.s) * svd.vt; 105 | assert_relative_eq!(*m, reconstructed, epsilon = 1.0e-4); 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/svd2.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::svd2 2 | #import wgebra::trig as Trig 3 | 4 | // The SVD of a 2x2 matrix. 5 | struct Svd { 6 | U: mat2x2, 7 | S: vec2, 8 | Vt: mat2x2, 9 | }; 10 | 11 | // Computes the SVD of a 2x2 matrix. 12 | fn svd(m: mat2x2) -> Svd { 13 | let e = (m[0].x + m[1].y) * 0.5; 14 | let f = (m[0].x - m[1].y) * 0.5; 15 | let g = (m[0].y + m[1].x) * 0.5; 16 | let h = (m[0].y - m[1].x) * 0.5; 17 | let q = sqrt(e * e + h * h); 18 | let r = sqrt(f * f + g * g); 19 | 20 | // Note that the singular values are always sorted because sx >= sy 21 | // because q >= 0 and r >= 0. 
22 | let sx = q + r; 23 | let sy = q - r; 24 | let sy_sign = select(1.0, -1.0, sy < 0.0); 25 | let singular_values = vec2(sx, sy * sy_sign); 26 | 27 | let a1 = Trig::stable_atan2(g, f); 28 | let a2 = Trig::stable_atan2(h, e); 29 | let theta = (a2 - a1) * 0.5; 30 | let phi = (a2 + a1) * 0.5; 31 | let st = sin(theta); 32 | let ct = cos(theta); 33 | let sp = sin(phi); 34 | let cp = cos(phi); 35 | 36 | let u = mat2x2(vec2(cp, sp), vec2(-sp, cp)); 37 | let v_t = mat2x2(vec2(ct, st * sy_sign), vec2(-st, ct * sy_sign)); 38 | 39 | return Svd(u, singular_values, v_t); 40 | } 41 | 42 | // Rebuilds the matrix this svd is the decomposition of. 43 | fn recompose(svd: Svd) -> mat2x2 { 44 | let U_S = mat2x2(svd.U[0] * svd.S.x, svd.U[1] * svd.S.y); 45 | return U_S * svd.Vt; 46 | } 47 | -------------------------------------------------------------------------------- /crates/wgebra/src/geometry/svd3.rs: -------------------------------------------------------------------------------- 1 | use crate::WgQuat; 2 | use nalgebra::{Matrix4x3, Vector4}; 3 | use wgcore::Shader; 4 | #[cfg(test)] 5 | use { 6 | naga_oil::compose::NagaModuleDescriptor, 7 | wgpu::{ComputePipeline, Device}, 8 | }; 9 | 10 | #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] 11 | #[repr(C)] 12 | /// A 3D SVD as represented on the gpu, with padding (every fourth rows 13 | /// can be ignored). 14 | // TODO: switch to encase? 15 | pub struct GpuSvd3 { 16 | /// First orthogonal matrix of the SVD. 17 | u: Matrix4x3, 18 | /// Singular values. 19 | s: Vector4, 20 | /// Second orthogonal matrix of the SVD. 21 | vt: Matrix4x3, 22 | } 23 | 24 | #[derive(Shader)] 25 | #[shader(derive(WgQuat), src = "svd3.wgsl")] 26 | /// Shader for computing the Singular Value Decomposition of 3x3 matrices. 
27 | pub struct WgSvd3; 28 | 29 | impl WgSvd3 { 30 | #[cfg(test)] 31 | #[doc(hidden)] 32 | pub fn tests(device: &Device) -> ComputePipeline { 33 | let test_kernel = r#" 34 | @group(0) @binding(0) 35 | var in: array>; 36 | @group(0) @binding(1) 37 | var out: array; 38 | 39 | @compute @workgroup_size(1, 1, 1) 40 | fn test(@builtin(global_invocation_id) invocation_id: vec3) { 41 | let i = invocation_id.x; 42 | out[i] = svd(in[i]); 43 | } 44 | "#; 45 | 46 | let src = format!("{}\n{}", Self::src(), test_kernel); 47 | let module = WgQuat::composer() 48 | .unwrap() 49 | .make_naga_module(NagaModuleDescriptor { 50 | source: &src, 51 | file_path: Self::FILE_PATH, 52 | ..Default::default() 53 | }) 54 | .unwrap(); 55 | wgcore::utils::load_module(device, "test", module) 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod test { 61 | use super::GpuSvd3; 62 | use approx::assert_relative_eq; 63 | use nalgebra::{DVector, Matrix3, Matrix4x3}; 64 | use wgcore::gpu::GpuInstance; 65 | use wgcore::kernel::{CommandEncoderExt, KernelDispatch}; 66 | use wgcore::tensor::GpuVector; 67 | use wgpu::BufferUsages; 68 | 69 | #[futures_test::test] 70 | #[serial_test::serial] 71 | async fn gpu_svd3() { 72 | let gpu = GpuInstance::new().await.unwrap(); 73 | let svd = super::WgSvd3::tests(gpu.device()); 74 | let mut encoder = gpu.device().create_command_encoder(&Default::default()); 75 | 76 | const LEN: usize = 345; 77 | let matrices: DVector> = DVector::new_random(LEN); 78 | let inputs = GpuVector::init(gpu.device(), &matrices, BufferUsages::STORAGE); 79 | let result: GpuVector = GpuVector::uninit( 80 | gpu.device(), 81 | matrices.len() as u32, 82 | BufferUsages::STORAGE | BufferUsages::COPY_SRC, 83 | ); 84 | let staging: GpuVector = GpuVector::uninit( 85 | gpu.device(), 86 | matrices.len() as u32, 87 | BufferUsages::MAP_READ | BufferUsages::COPY_DST, 88 | ); 89 | 90 | // Dispatch the test. 
91 | let mut pass = encoder.compute_pass("test", None); 92 | KernelDispatch::new(gpu.device(), &mut pass, &svd) 93 | .bind0([inputs.buffer(), result.buffer()]) 94 | .dispatch(matrices.len() as u32); 95 | drop(pass); // Ensure the pass is ended before the encoder is borrowed again. 96 | 97 | // Submit. 98 | staging.copy_from(&mut encoder, &result); 99 | gpu.queue().submit(Some(encoder.finish())); 100 | 101 | // Check the result is correct. 102 | let gpu_result = staging.read(gpu.device()).await.unwrap(); 103 | 104 | for (m, svd) in matrices.iter().zip(gpu_result.iter()) { 105 | let m = m.fixed_rows::<3>(0).into_owned(); 106 | let reconstructed = svd.u.fixed_rows::<3>(0).into_owned() 107 | * Matrix3::from_diagonal(&svd.s.fixed_rows::<3>(0)) 108 | * svd.vt.fixed_rows::<3>(0).into_owned(); 109 | assert_relative_eq!(m, reconstructed, epsilon = 1.0e-4); 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /crates/wgebra/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | #![allow(clippy::too_many_arguments)] 3 | #![allow(clippy::result_large_err)] 4 | #![warn(missing_docs)] 5 | 6 | pub use geometry::*; 7 | pub use linalg::*; 8 | 9 | pub mod geometry; 10 | pub mod linalg; 11 | pub mod utils; 12 | -------------------------------------------------------------------------------- /crates/wgebra/src/linalg/gemv.wgsl: -------------------------------------------------------------------------------- 1 | #import wgblas::shape as Shape 2 | 3 | @group(0) @binding(0) 4 | var shape_out: Shape::Shape; 5 | @group(0) @binding(1) 6 | var shape_m: Shape::Shape; 7 | @group(0) @binding(2) 8 | var shape_v: Shape::Shape; 9 | @group(0) @binding(3) 10 | var out: array>; 11 | @group(0) @binding(4) 12 | var m: array>; 13 | @group(0) @binding(5) 14 | var v: array>; 15 | 16 | // NOTE: gemv_tr_fast is quite a bit (15%) faster with a workgroup size of 8. 
// Number of invocations per workgroup. Must be a power of two: the tree reduction in
// `reduce_sketch` halves the stride at every step.
const WORKGROUP_SIZE: u32 = 32;

// Per-workgroup scratch buffer holding one partial sum per invocation.
var<workgroup> sketch: array<vec4<f32>, WORKGROUP_SIZE>;

// One step of the tree reduction: the first `stride` invocations accumulate the value stored
// `stride` slots further, then every invocation waits on the barrier.
fn reduce_sum(index: u32, stride: u32) {
    if index < stride {
        sketch[index] += sketch[index + stride];
    }
    workgroupBarrier();
}

// Reduces the `WORKGROUP_SIZE` partial sums stored in `sketch` into `sketch[0]`.
//
// The strides are derived from `WORKGROUP_SIZE` (size/2, size/4, ..., 1) so the reduction stays
// correct when the constant is tuned. Must be called by every invocation of the workgroup since
// each step contains a barrier.
fn reduce_sketch(index: u32) {
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride /= 2u) {
        reduce_sum(index, stride);
    }
}

// Matrix-vector product where each workgroup computes a single `vec4` of the output: the
// invocations of the workgroup split the columns between them, then their partial sums are
// tree-reduced through `sketch`.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn gemv_fast(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let shape_m = Shape::with_vec4_elts(shape_m);
    let shape_v = Shape::with_vec4_elts(shape_v);
    let shape_out = Shape::with_vec4_elts(shape_out);

    var sum = vec4(0.0);

    // Each invocation handles a block of 4 consecutive columns per iteration.
    for (var j = 0u; j < shape_m.ncols; j += 4u * WORKGROUP_SIZE) {
        let ia = Shape::it(shape_m, workgroup_id.x, j + local_id.x * 4u, workgroup_id.z);
        let ib = ia + shape_m.stride;
        let ic = ib + shape_m.stride;
        let id = ic + shape_m.stride;
        let submat = mat4x4(m[ia], m[ib], m[ic], m[id]);

        let iv = Shape::it(shape_v, j / 4u + local_id.x, workgroup_id.y, workgroup_id.z);
        sum += submat * v[iv];
    }

    sketch[local_id.x] = sum;

    workgroupBarrier();

    reduce_sketch(local_id.x);

    // Only the first invocation writes the reduced value.
    if local_id.x == 0u {
        let i_out = Shape::it(shape_out, workgroup_id.x, workgroup_id.y, workgroup_id.z);
        out[i_out] = sketch[0];
    }
}

// Matrix-vector product where each invocation computes one `vec4` of the output by walking
// over all the columns, 4 at a time.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn gemv(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    let shape_m = Shape::with_vec4_elts(shape_m);
    let shape_v = Shape::with_vec4_elts(shape_v);
    let shape_out = Shape::with_vec4_elts(shape_out);

    if invocation_id.x < shape_m.nrows {
        var sum = vec4(0.0);

        for (var j = 0u; j < shape_m.ncols; j += 4u) {
            let ia = Shape::it(shape_m, invocation_id.x, j, invocation_id.z);
            let ib = ia + shape_m.stride;
            let ic = ib + shape_m.stride;
            let id = ic + shape_m.stride;
            let submat = mat4x4(m[ia], m[ib], m[ic], m[id]);

            let iv = Shape::it(shape_v, j / 4u, invocation_id.y, invocation_id.z);
            sum += submat * v[iv];
        }

        let i_out = Shape::it(shape_out, invocation_id.x, invocation_id.y, invocation_id.z);
        out[i_out] = sum;
    }
}

// Transposed matrix-vector product where each invocation computes one `vec4` of the output by
// walking over all the (vec4-packed) rows.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn gemv_tr(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    let shape_m = Shape::with_vec4_elts(shape_m);
    let shape_v = Shape::with_vec4_elts(shape_v);
    let shape_out = Shape::with_vec4_elts(shape_out);

    if invocation_id.x < (shape_m.ncols + 3u) / 4 {
        var sum = vec4(0.0);

        for (var j = 0u; j < shape_m.nrows; j++) {
            let ia = Shape::it(shape_m, j, invocation_id.x * 4u, invocation_id.z);
            let ib = ia + shape_m.stride;
            let ic = ib + shape_m.stride;
            let id = ic + shape_m.stride;
            let submat = mat4x4(m[ia], m[ib], m[ic], m[id]);

            let iv = Shape::it(shape_v, j, invocation_id.y, invocation_id.z);
            sum += transpose(submat) * v[iv];
        }

        let i_out = Shape::it(shape_out, invocation_id.x, invocation_id.y, invocation_id.z);
        out[i_out] = sum;
    }
}

// Transposed matrix-vector product where each workgroup computes a single `vec4` of the output:
// the invocations of the workgroup split the (vec4-packed) rows between them, then their partial
// sums are tree-reduced through `sketch`.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn gemv_tr_fast(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let shape_m = Shape::with_vec4_elts(shape_m);
    let shape_v = Shape::with_vec4_elts(shape_v);
    let shape_out = Shape::with_vec4_elts(shape_out);

    var sum = vec4(0.0);

    for (var j = 0u; j < shape_m.nrows; j += WORKGROUP_SIZE) {
        let ia = Shape::it(shape_m, j + local_id.x, workgroup_id.x * 4u, workgroup_id.z);
        let ib = ia + shape_m.stride;
        let ic = ib + shape_m.stride;
        let id = ic + shape_m.stride;
        let submat = mat4x4(m[ia], m[ib], m[ic], m[id]);

        let iv = Shape::it(shape_v, j + local_id.x, workgroup_id.y, workgroup_id.z);
        sum += transpose(submat) * v[iv];
    }

    sketch[local_id.x] = sum;

    workgroupBarrier();

    reduce_sketch(local_id.x);

    // Only the first invocation writes the reduced value.
    if local_id.x == 0u {
        let i_out = Shape::it(shape_out, workgroup_id.x, workgroup_id.y, workgroup_id.z);
        out[i_out] = sketch[0];
    }
}
--------------------------------------------------------------------------------
/crates/wgebra/src/linalg/mod.rs:
--------------------------------------------------------------------------------
//! Fundamental linear-algebra matrix/vector operations.

mod gemm;
mod gemv;
mod op_assign;
mod reduce;
mod shape;

pub use gemm::{Gemm, GemmVariant};
pub use gemv::{Gemv, GemvVariant};
pub use op_assign::{OpAssign, OpAssignVariant};
pub use reduce::{Reduce, ReduceOp};
pub use shape::{row_major_shader_defs, Shape};

--------------------------------------------------------------------------------
/crates/wgebra/src/linalg/op_assign.wgsl:
--------------------------------------------------------------------------------
#import wgblas::shape as Shape

// NOTE(review): the address spaces below were lost by the text extraction and are reconstructed
// from usage (`a` is written by `main`, `b` is only read) — confirm against the bind group
// layout expected by the host code.
@group(0) @binding(0)
var<uniform> shape_a: Shape::Shape;
@group(0) @binding(1)
var<uniform> shape_b: Shape::Shape;
@group(0) @binding(2)
var<storage, read_write> a: array<f32>;
@group(0) @binding(3)
var<storage, read> b: array<f32>;

const WORKGROUP_SIZE: u32 = 64;

fn add_f32(a: f32, b: f32) -> f32 {
    return a + b;
}

fn sub_f32(a: f32, b: f32) -> f32 {
    return a - b;
}

fn mul_f32(a: f32, b: f32) -> f32 {
    return a * b;
}

fn div_f32(a: f32, b: f32) -> f32 {
    return a / b;
}
29 | 30 | fn placeholder(a: f32, b: f32) -> f32 { 31 | return a + b; 32 | } 33 | 34 | // TODO: will the read of a be optimized-out by the shader compiler 35 | // or do we need to write a dedicated kernel for this? 36 | fn copy_f32(a: f32, b: f32) -> f32 { 37 | return b; 38 | } 39 | 40 | @compute @workgroup_size(WORKGROUP_SIZE, 1, 1) 41 | fn main(@builtin(global_invocation_id) invocation_id: vec3) { 42 | if invocation_id.x < shape_a.nrows { 43 | let ia = Shape::iv(shape_a, invocation_id.x); 44 | let ib = Shape::iv(shape_b, invocation_id.x); 45 | a[ia] = placeholder(a[ia], b[ib]); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /crates/wgebra/src/linalg/reduce.wgsl: -------------------------------------------------------------------------------- 1 | #import wgblas::shape as Shape; 2 | 3 | @group(0) @binding(0) 4 | var shape: Shape::Shape; 5 | @group(0) @binding(1) 6 | var input: array; 7 | @group(0) @binding(2) 8 | var output: f32; 9 | 10 | const WORKGROUP_SIZE: u32 = 128; 11 | 12 | fn reduce_sum_f32(acc: f32, x: f32) -> f32 { 13 | return acc + x; 14 | } 15 | 16 | fn reduce_prod_f32(acc: f32, x: f32) -> f32 { 17 | return acc * x; 18 | } 19 | 20 | fn reduce_min_f32(acc: f32, x: f32) -> f32 { 21 | return min(acc, x); 22 | } 23 | 24 | fn reduce_max_f32(acc: f32, x: f32) -> f32 { 25 | return max(acc, x); 26 | } 27 | 28 | fn reduce_sqnorm_f32(acc: f32, x: f32) -> f32 { 29 | return acc + x * x; 30 | } 31 | 32 | fn init_zero() -> f32 { 33 | return 0.0; 34 | } 35 | 36 | fn init_one() -> f32 { 37 | return 1.0; 38 | } 39 | 40 | fn init_max_f32() -> f32 { 41 | return 3.40282347E+38; 42 | } 43 | 44 | fn init_min_f32() -> f32 { 45 | return -3.40282347E+38; 46 | } 47 | 48 | fn init_placeholder() -> f32 { 49 | return 0.0; 50 | } 51 | 52 | fn reduce_placeholder(acc: f32, x: f32) -> f32 { 53 | return acc + x; 54 | } 55 | fn workspace_placeholder(acc: f32, x: f32) -> f32 { 56 | return acc + x; 57 | } 58 | 59 | var workspace: array; 

// One step of the tree reduction: the first `stride` threads combine their workspace slot with
// the one `stride` slots further, then every thread waits on the barrier.
fn reduce(thread_id: u32, stride: u32) {
    if thread_id < stride {
        workspace[thread_id] = reduce_placeholder(workspace[thread_id], workspace[thread_id + stride]);
    }
    workgroupBarrier();
}

// Reduces the whole input vector to a single value.
//
// `thread_id` must be the index of the calling thread inside its workgroup (it indexes
// `workspace`). Each thread first accumulates the elements `thread_id`, `thread_id + WORKGROUP_SIZE`,
// ... into its own workspace slot, then the slots are tree-reduced into `workspace[0]`.
fn run_reduction(thread_id: u32) -> f32 {
    workspace[thread_id] = init_placeholder();

    for (var i = thread_id; i < shape.nrows; i += WORKGROUP_SIZE) {
        let val_i = input[Shape::iv(shape, i)];
        workspace[thread_id] = workspace_placeholder(workspace[thread_id], val_i);
    }

    workgroupBarrier();

    // Strides are derived from `WORKGROUP_SIZE` (which must be a power of two) so the reduction
    // stays correct if the constant is changed.
    for (var stride = WORKGROUP_SIZE / 2u; stride > 0u; stride /= 2u) {
        reduce(thread_id, stride);
    }

    return workspace[0];
}

// Entry point. The workgroup-local index is used (rather than the global one) because it
// indexes the `workspace` array, which would be accessed out of bounds by any extra workgroup.
@compute @workgroup_size(WORKGROUP_SIZE, 1, 1)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>) {
    let result = run_reduction(local_id.x);

    if (local_id.x == 0u) {
        output = result;
    }
}

--------------------------------------------------------------------------------
/crates/wgebra/src/linalg/shape.rs:
--------------------------------------------------------------------------------
use naga_oil::compose::ShaderDefValue;
use std::collections::HashMap;
use wgcore::Shader;

#[derive(Shader)]
#[shader(src = "shape.wgsl")]
/// A shader for handling matrix/vector indexing based on their shape of type
/// [`wgcore::shapes::ViewShape`].
pub struct Shape;

/// Shader definitions setting the `ROW_MAJOR` boolean macro for shaders supporting conditional
/// compilation for switching row-major and column-major matrix handling.
13 | pub fn row_major_shader_defs() -> HashMap { 14 | [("ROW_MAJOR".to_string(), ShaderDefValue::Bool(true))].into() 15 | } 16 | -------------------------------------------------------------------------------- /crates/wgebra/src/linalg/shape.wgsl: -------------------------------------------------------------------------------- 1 | // Module comment 2 | // And a second line 3 | 4 | #define_import_path wgblas::shape 5 | 6 | // The shape of a matrix. 7 | // 8 | // If the `ROW_MAJOR` constant is defined then this represents the shape of a row-major matrix. 9 | // Otherwise, it represents the shape of a column-major matrix. 10 | struct Shape { 11 | // The number of rows in each matrix of the tensor. 12 | nrows: u32, 13 | // The number of columns in each matrix of the tensor. 14 | ncols: u32, 15 | // The number of matrices in the tensor. 16 | nmats: u32, 17 | // The number of elements separating two elements along the non-major dimension 18 | // of each matrix of the tensor. 19 | // 20 | // If the matrix is row-major (`ROW_MAJOR` is defined) then this is the number of elements in memory 21 | // between two consecutive elements from the same column. 22 | // 23 | // If the matrix is column-major (`ROW_MAJOR` is undefined) then this is the number of elements in memory 24 | // between two consecutive elements from the same row. 25 | // 26 | // Note that the stride along the other dimension is always assumed to be 1. 27 | stride: u32, 28 | // The number of elements separating two elements along the "matrix" direction (i.e. the third 29 | // tensor direction). This is independent from the matrix ordering (row-major vs. column-major). 30 | stride_mat: u32, 31 | // Index of the first element of the tensor. 32 | offset: u32, 33 | } 34 | 35 | // Index of the `i-th` element of a vector. 
36 | fn iv(view: Shape, i: u32) -> u32 { 37 | return view.offset + i; 38 | } 39 | 40 | fn div_ceil4(a: u32) -> u32 { 41 | return (a + 3u) / 4u; 42 | } 43 | 44 | // Index of the element at row `i`, column `j` of the matrix `t` in this tensor. 45 | fn it(view: Shape, i: u32, j: u32, t: u32) -> u32 { 46 | return t * view.stride_mat + im(view, i, j); 47 | } 48 | 49 | #ifdef ROW_MAJOR 50 | // Index of the element at row `i` and column `j` of a row-major matrix. 51 | fn im(view: Shape, i: u32, j: u32) -> u32 { 52 | return view.offset + i * view.stride + j; 53 | } 54 | 55 | fn with_vec4_elts(shape: Shape) -> Shape { 56 | return Shape(shape.nrows, div_ceil4(shape.ncols), shape.nmats, div_ceil4(shape.stride), div_ceil4(shape.stride_mat), shape.offset / 4u); 57 | } 58 | #else 59 | // Index of the element at row `i` and column `j` of a column-major matrix. 60 | fn im(view: Shape, i: u32, j: u32) -> u32 { 61 | return view.offset + i + j * view.stride; 62 | } 63 | 64 | fn with_vec4_elts(shape: Shape) -> Shape { 65 | return Shape(div_ceil4(shape.nrows), shape.ncols, shape.nmats, div_ceil4(shape.stride), div_ceil4(shape.stride_mat), shape.offset / 4u); 66 | } 67 | #endif 68 | -------------------------------------------------------------------------------- /crates/wgebra/src/utils/min_max.rs: -------------------------------------------------------------------------------- 1 | use wgcore::Shader; 2 | 3 | /// Helper shader functions for calculating the min/max elements of a vector or matrix. 4 | #[derive(Shader)] 5 | #[shader(src = "min_max.wgsl")] 6 | pub struct WgMinMax; 7 | -------------------------------------------------------------------------------- /crates/wgebra/src/utils/min_max.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgebra::min_max 2 | 3 | /// Computes the maximum value accross all elements of the given 2D vector. 
fn max2(v: vec2<f32>) -> f32 {
    return max(v.x, v.y);
}

/// Computes the maximum **absolute** value across all elements of the given 2x2 matrix.
fn amax2x2(m: mat2x2<f32>) -> f32 {
    // Component-wise maximum of the absolute columns, then maximum of its components.
    let vm = max(abs(m[0]), abs(m[1]));
    return max(vm.x, vm.y);
}

/// Computes the maximum value across all elements of the given 2x2 matrix.
fn max2x2(m: mat2x2<f32>) -> f32 {
    let vm = max(m[0], m[1]);
    return max(vm.x, vm.y);
}

/// Computes the maximum value across all elements of the given 3D vector.
fn max3(v: vec3<f32>) -> f32 {
    return max(v.x, max(v.y, v.z));
}

/// Computes the maximum **absolute** value across all elements of the given 3x3 matrix.
fn amax3x3(m: mat3x3<f32>) -> f32 {
    let vm = max(abs(m[0]), max(abs(m[1]), abs(m[2])));
    return max(vm.x, max(vm.y, vm.z));
}

/// Computes the maximum value across all elements of the given 3x3 matrix.
fn max3x3(m: mat3x3<f32>) -> f32 {
    let vm = max(m[0], max(m[1], m[2]));
    return max(vm.x, max(vm.y, vm.z));
}

/// Computes the maximum value across all elements of the given 4D vector.
fn max4(v: vec4<f32>) -> f32 {
    return max(v.x, max(v.y, max(v.z, v.w)));
}

/// Computes the maximum **absolute** value across all elements of the given 4x4 matrix.
fn amax4x4(m: mat4x4<f32>) -> f32 {
    let vm = max(abs(m[0]), max(abs(m[1]), max(abs(m[2]), abs(m[3]))));
    return max(vm.x, max(vm.y, max(vm.z, vm.w)));
}

/// Computes the maximum value across all elements of the given 4x4 matrix.
fn max4x4(m: mat4x4<f32>) -> f32 {
    let vm = max(m[0], max(m[1], max(m[2], m[3])));
    return max(vm.x, max(vm.y, max(vm.z, vm.w)));
}
--------------------------------------------------------------------------------
/crates/wgebra/src/utils/mod.rs:
--------------------------------------------------------------------------------
//! Utilities to address some platform-dependent differences
//! (e.g. for some trigonometric functions).

pub use self::min_max::WgMinMax;
pub use self::trig::WgTrig;

mod min_max;
mod trig;

--------------------------------------------------------------------------------
/crates/wgebra/src/utils/trig.rs:
--------------------------------------------------------------------------------
use wgcore::Shader;

/// Alternative implementations of some trigonometric functions on the gpu.
///
/// Some platforms (Metal in particular) have implementations of some trigonometric functions
/// that are not numerically stable. This is the case for example for `atan2` and `tanh` that
/// may occasionally lead to NaNs. This shader exposes alternative implementations for numerically
/// stable versions of these functions to ensure good behavior across all platforms.
#[derive(Shader)]
#[shader(src = "trig.wgsl")]
pub struct WgTrig;

--------------------------------------------------------------------------------
/crates/wgebra/src/utils/trig.wgsl:
--------------------------------------------------------------------------------
#define_import_path wgebra::trig

/// The value of pi.
const PI: f32 = 3.14159265358979323846264338327950288;

/// A numerically stable implementation of tanh.
///
/// Metal’s implementation of tanh returns NaN for large values.
/// This function is more numerically stable and should be used as a
/// drop-in replacement.
// Inspired from https://github.com/apache/tvm/pull/16438 (Apache 2.0 license).
fn stable_tanh(x: f32) -> f32 {
    // Two algebraically equivalent formulations; each one only involves a decaying exponential
    // on the side of the real line where it is selected.
    let exp_neg2x = exp(-2.0 * x);
    let exp_pos2x = exp(2.0 * x);
    let tanh_pos = (1.0 - exp_neg2x) / (1.0 + exp_neg2x);
    let tanh_neg = (exp_pos2x - 1.0) / (exp_pos2x + 1.0);
    return select(tanh_neg, tanh_pos, x >= 0.0);
}

/// In some platforms, atan2 has unusable edge cases, e.g., returning NaN when y = 0 and x = 0.
///
/// This is for example the case in Metal/MSL: https://github.com/gfx-rs/wgpu/issues/4319
/// So we need to implement it ourselves to ensure svd always returns reasonable results on some
/// edge cases like the identity.
///
/// Returns the angle of the point `(x, y)` in `[-PI, PI]`:
/// - `x > 0`: `atan(y / x)`.
/// - `x < 0`: `atan(y / x) + PI` if `y >= 0`, `atan(y / x) - PI` otherwise.
/// - `x == 0`: `PI / 2` if `y > 0`, `-PI / 2` if `y < 0`, and `0` for the undefined `(0, 0)` case.
fn stable_atan2(y: f32, x: f32) -> f32 {
    if x > 0.0 {
        return atan(y / x);
    }
    if x < 0.0 {
        // `y == 0` lies on the negative x axis, whose angle is PI.
        return atan(y / x) + select(-PI, PI, y >= 0.0);
    }

    // Here `x == 0`: the point is on the y axis and `y / x` must not be evaluated.
    if y > 0.0 {
        return PI * 0.5;
    }
    if y < 0.0 {
        return -PI * 0.5;
    }

    // Force the remaining unbounded case (0, 0) to 0.
    return 0.0;
}
--------------------------------------------------------------------------------
/crates/wgparry/CHANGELOG.md:
--------------------------------------------------------------------------------
## v0.2.0

### Modified

- Update to `wgcore` v0.2.0.

--------------------------------------------------------------------------------
/crates/wgparry/README.md:
--------------------------------------------------------------------------------
# wgparry: cross-platform GPU collision-detection

**/!\ This library is still under heavy development and is still missing many features.**

The goal of **wgparry** is essentially to be "**parry** on the gpu". It aims (but it isn’t there yet) to expose
geometric operations (collision-detection, ray-casting, point-projection, etc.) as composable WGSL shaders and kernels.
7 | -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry2d/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgparry2d" 3 | authors = ["Sébastien Crozet "] 4 | description = "Cross-platform 2D GPU collision detection and geometry." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [lib] 12 | name = "wgparry2d" 13 | path = "../../src/lib.rs" 14 | required-features = ["dim2"] 15 | 16 | [lints] 17 | rust.unexpected_cfgs = { level = "warn", check-cfg = [ 18 | 'cfg(feature, values("dim3"))', 19 | ] } 20 | 21 | [features] 22 | default = ["dim2"] 23 | dim2 = [] 24 | 25 | [dependencies] 26 | nalgebra = { workspace = true } 27 | wgpu = { workspace = true } 28 | naga_oil = { workspace = true } 29 | bytemuck = { workspace = true } 30 | encase = { workspace = true } 31 | parry2d = { workspace = true } 32 | 33 | wgcore = { version = "0.2", path = "../../../wgcore" } 34 | wgebra = { version = "0.2", path = "../../../wgebra" } 35 | 36 | [dev-dependencies] 37 | nalgebra = { version = "0.33", features = ["rand"] } 38 | futures-test = "0.3" 39 | serial_test = "3" 40 | approx = "0.5" 41 | -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry2d/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry2d/src: -------------------------------------------------------------------------------- 1 | ../../src -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry3d/Cargo.toml: -------------------------------------------------------------------------------- 1 | 
[package] 2 | name = "wgparry3d" 3 | authors = ["Sébastien Crozet "] 4 | description = "Cross-platform 3D GPU collision-detection and geometry." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [lib] 12 | name = "wgparry3d" 13 | path = "../../src/lib.rs" 14 | required-features = ["dim3"] 15 | 16 | [lints] 17 | rust.unexpected_cfgs = { level = "warn", check-cfg = [ 18 | 'cfg(feature, values("dim2"))', 19 | ] } 20 | 21 | [features] 22 | default = ["dim3"] 23 | dim3 = [] 24 | 25 | [dependencies] 26 | nalgebra = { workspace = true } 27 | wgpu = { workspace = true } 28 | naga_oil = { workspace = true } 29 | bytemuck = { workspace = true } 30 | encase = { workspace = true } 31 | parry3d = { workspace = true } 32 | 33 | wgcore = { version = "0.2", path = "../../../wgcore" } 34 | wgebra = { version = "0.2", path = "../../../wgebra" } 35 | 36 | [dev-dependencies] 37 | nalgebra = { version = "0.33", features = ["rand"] } 38 | futures-test = "0.3" 39 | serial_test = "3" 40 | approx = "0.5" 41 | -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry3d/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /crates/wgparry/crates/wgparry3d/src: -------------------------------------------------------------------------------- 1 | ../../src -------------------------------------------------------------------------------- /crates/wgparry/src/ball.rs: -------------------------------------------------------------------------------- 1 | //! The ball shape. 
2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::{dim_shader_defs, substitute_aliases}; 6 | use wgcore::Shader; 7 | use wgebra::{WgSim2, WgSim3}; 8 | 9 | #[derive(Shader)] 10 | #[shader( 11 | derive(WgSim3, WgSim2, WgRay, WgProjection), 12 | src = "ball.wgsl", 13 | src_fn = "substitute_aliases", 14 | shader_defs = "dim_shader_defs" 15 | )] 16 | /// Shader defining the ball shape as well as its ray-casting and point-projection functions. 17 | pub struct WgBall; 18 | 19 | #[cfg(test)] 20 | mod test { 21 | use super::WgBall; 22 | #[cfg(feature = "dim2")] 23 | use parry2d::shape::Ball; 24 | #[cfg(feature = "dim3")] 25 | use parry3d::shape::Ball; 26 | use wgcore::tensor::GpuVector; 27 | 28 | #[futures_test::test] 29 | #[serial_test::serial] 30 | async fn gpu_ball() { 31 | crate::projection::test_utils::test_point_projection::( 32 | "Ball", 33 | Ball::new(0.5), 34 | |device, shapes, usages| GpuVector::init(device, shapes, usages).into_inner(), 35 | ) 36 | .await; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /crates/wgparry/src/ball.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | 9 | #define_import_path wgparry::ball 10 | 11 | 12 | /// A ball, defined by its radius. 13 | struct Ball { 14 | /// The ball’s radius. 15 | radius: f32, 16 | } 17 | 18 | /* 19 | /// Casts a ray on a ball. 20 | /// 21 | /// Returns a negative value if there is no hit. 22 | /// If there is a hit, the result is a scalar `t >= 0` such that the hit point is equal to `ray.origin + t * ray.dir`. 23 | fn castLocalRay(ball: Ball, ray: Ray::Ray, maxTimeOfImpact: f32) -> f32 { 24 | // Ray origin relative to the ball’s center. It’s the origin itself since it’s in the ball’s local frame. 
25 | let dcenter = ray.origin; 26 | let a = dot(ray.dir, ray.dir); 27 | let b = dot(dcenter, ray.dir); 28 | let c = dot(dcenter, dcenter) - ball.radius * ball.radius; 29 | let delta = b * b - a * c; 30 | let t = -b - sqrt(delta); 31 | 32 | if (c > 0.0 && (b > 0.0 || a == 0.0)) || delta < 0.0 || t > maxTimeOfImpact * a { 33 | // No hit. 34 | return -1.0; 35 | } else if a == 0.0 { 36 | // Dir is zero but the ray started inside the ball. 37 | return 0.0; 38 | } else { 39 | // Hit. If t <= 0, the origin is inside the ball. 40 | return max(t / a, 0.0); 41 | } 42 | } 43 | 44 | /// Casts a ray on a transformed ball. 45 | /// 46 | /// Returns a negative value if there is no hit. 47 | /// If there is a hit, the result is a scalar `t >= 0` such that the hit point is equal to `ray.origin + t * ray.dir`. 48 | fn castRay(ball: Ball, pose: Transform, ray: Ray::Ray, maxTimeOfImpact: f32) -> f32 { 49 | let localRay = Ray::Ray(Pose::invMulPt(pose, ray.origin), Pose::invMulVec(pose, ray.dir)); 50 | return castLocalRay(ball, localRay, maxTimeOfImpact); 51 | } 52 | */ 53 | 54 | /// Projects a point on a ball. 55 | /// 56 | /// If the point is inside the ball, the point itself is returned. 57 | fn projectLocalPoint(ball: Ball, pt: Vector) -> Vector { 58 | let dist = length(pt); 59 | 60 | if dist >= ball.radius { 61 | // The point is outside the ball. 62 | return pt * (ball.radius / dist); 63 | } else { 64 | // The point is inside the ball. 65 | return pt; 66 | } 67 | } 68 | 69 | /// Projects a point on a transformed ball. 70 | /// 71 | /// If the point is inside the ball, the point itself is returned. 72 | fn projectPoint(ball: Ball, pose: Transform, pt: Vector) -> Vector { 73 | let localPt = Pose::invMulPt(pose, pt); 74 | return Pose::mulPt(pose, projectLocalPoint(ball, localPt)); 75 | } 76 | 77 | 78 | /// Projects a point on the boundary of a ball. 
79 | fn projectLocalPointOnBoundary(ball: Ball, pt: Vector) -> Proj::ProjectionResult { 80 | let dist = length(pt); 81 | #if DIM == 2 82 | let fallback = vec2(0.0, ball.radius); 83 | #else 84 | let fallback = vec3(0.0, ball.radius, 0.0); 85 | #endif 86 | 87 | let projected_point = 88 | select(fallback, pt * (ball.radius / dist), dist != 0.0); 89 | let is_inside = dist <= ball.radius; 90 | 91 | return Proj::ProjectionResult(projected_point, is_inside); 92 | } 93 | 94 | /// Projects a point on the boundary of a transformed ball. 95 | /// 96 | /// If the point is inside of the ball, it will be projected on its boundary but 97 | /// `ProjectionResult::is_inside` will be set to `true`. 98 | fn projectPointOnBoundary(ball: Ball, pose: Transform, pt: Vector) -> Proj::ProjectionResult { 99 | let local_pt = Pose::invMulPt(pose, pt); 100 | var result = projectLocalPointOnBoundary(ball, local_pt); 101 | result.point = Pose::mulPt(pose, result.point); 102 | return result; 103 | } 104 | -------------------------------------------------------------------------------- /crates/wgparry/src/capsule.rs: -------------------------------------------------------------------------------- 1 | //! The capsule shape. 2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::segment::WgSegment; 6 | use crate::{dim_shader_defs, substitute_aliases}; 7 | use wgcore::Shader; 8 | use wgebra::{WgSim2, WgSim3}; 9 | 10 | #[derive(Shader)] 11 | #[shader( 12 | derive(WgSim3, WgSim2, WgRay, WgProjection, WgSegment), 13 | src = "capsule.wgsl", 14 | src_fn = "substitute_aliases", 15 | shader_defs = "dim_shader_defs" 16 | )] 17 | /// Shader defining the capsule shape as well as its point-projection functions. 
18 | pub struct WgCapsule; 19 | 20 | #[cfg(test)] 21 | mod test { 22 | use super::WgCapsule; 23 | use parry::shape::Capsule; 24 | use wgcore::tensor::GpuVector; 25 | 26 | #[futures_test::test] 27 | #[serial_test::serial] 28 | async fn gpu_capsule() { 29 | crate::projection::test_utils::test_point_projection::( 30 | "Capsule", 31 | Capsule::new_y(1.0, 0.5), 32 | |device, shapes, usages| GpuVector::encase(device, shapes, usages).into_inner(), 33 | ) 34 | .await; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /crates/wgparry/src/capsule.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | #import wgparry::segment as Seg 9 | 10 | #define_import_path wgparry::capsule 11 | 12 | /// A capsule, defined by its principal axis (a segment) and its radius. 13 | struct Capsule { 14 | /// The capsule’s principal axis. 15 | segment: Seg::Segment, 16 | /// The capsule’s radius. 17 | radius: f32, 18 | } 19 | /// Builds two vectors that complete `v` into an orthonormal basis (assumes `v` is normalized — TODO confirm). 20 | fn orthonormal_basis3(v: vec3<f32>) -> array<vec3<f32>, 2> { 21 | // NOTE: not using `sign` because we don’t want the 0.0 case to return 0.0. 22 | let sign = select(-1.0, 1.0, v.z >= 0.0); 23 | let a = -1.0 / (sign + v.z); 24 | let b = v.x * v.y * a; 25 | 26 | return array( 27 | vec3( 28 | 1.0 + sign * v.x * v.x * a, 29 | sign * b, 30 | -sign * v.x, 31 | ), 32 | vec3(b, sign + v.y * v.y * a, -v.y), 33 | ); 34 | } 35 | /// Returns a vector orthogonal to `v`: `v` rotated by a quarter turn in 2D, the first basis vector of `orthonormal_basis3` in 3D. 36 | fn any_orthogonal_vector(v: Vector) -> Vector { 37 | #if DIM == 2 38 | return vec2(v.y, -v.x); 39 | #else 40 | return orthonormal_basis3(v)[0]; 41 | #endif 42 | } 43 | 44 | /// Projects a point on a capsule. 45 | /// 46 | /// If the point is inside the capsule, the point itself is returned. 
47 | fn projectLocalPoint(capsule: Capsule, pt: Vector) -> Vector { 48 | let proj_on_axis = Seg::projectLocalPoint(capsule.segment, pt); 49 | let dproj = pt - proj_on_axis; 50 | let dist_to_axis = length(dproj); 51 | 52 | // PERF: call `select` instead? 53 | if dist_to_axis > capsule.radius { 54 | return proj_on_axis + dproj * (capsule.radius / dist_to_axis); 55 | } else { 56 | return pt; 57 | } 58 | } 59 | 60 | /// Projects a point on a transformed capsule. 61 | /// 62 | /// If the point is inside the capsule, the point itself is returned. 63 | fn projectPoint(capsule: Capsule, pose: Transform, pt: Vector) -> Vector { 64 | let localPt = Pose::invMulPt(pose, pt); 65 | return Pose::mulPt(pose, projectLocalPoint(capsule, localPt)); 66 | } 67 | 68 | 69 | /// Projects a point on the boundary of a capsule. 70 | fn projectLocalPointOnBoundary(capsule: Capsule, pt: Vector) -> Proj::ProjectionResult { 71 | let proj_on_axis = Seg::projectLocalPoint(capsule.segment, pt); 72 | let dproj = pt - proj_on_axis; 73 | let dist_to_axis = length(dproj); 74 | 75 | if dist_to_axis > 0.0 { 76 | let is_inside = dist_to_axis <= capsule.radius; 77 | return Proj::ProjectionResult(proj_on_axis + dproj * (capsule.radius / dist_to_axis), is_inside); 78 | } else { 79 | // Very rare occurrence: the point lies on the capsule’s axis exactly. 80 | // Pick an arbitrary direction orthogonal to the principal axis. A zero-length axis (a == b) has no direction, so an arbitrary unit vector stands in for it to avoid a 0/0 division. 81 | let axis_seg = capsule.segment.b - capsule.segment.a; 82 | let axis_len = length(axis_seg); 83 | let proj_dir = any_orthogonal_vector(select(normalize(Vector(1.0)), axis_seg / axis_len, axis_len > 0.0)); 84 | return Proj::ProjectionResult(proj_on_axis + proj_dir * capsule.radius, true); 85 | } 86 | } 87 | 88 | /// Projects a point on the boundary of a transformed capsule. 89 | /// 90 | /// If the point is inside of the capsule, it will be projected on its boundary but 91 | /// `ProjectionResult::is_inside` will be set to `true`. 
92 | fn projectPointOnBoundary(capsule: Capsule, pose: Transform, pt: Vector) -> Proj::ProjectionResult { 93 | let local_pt = Pose::invMulPt(pose, pt); 94 | var result = projectLocalPointOnBoundary(capsule, local_pt); 95 | result.point = Pose::mulPt(pose, result.point); 96 | return result; 97 | } 98 | -------------------------------------------------------------------------------- /crates/wgparry/src/cone.rs: -------------------------------------------------------------------------------- 1 | //! The cone shape. 2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::segment::WgSegment; 6 | use crate::{dim_shader_defs, substitute_aliases}; 7 | use wgcore::Shader; 8 | use wgebra::{WgSim2, WgSim3}; 9 | 10 | #[derive(Shader)] 11 | #[shader( 12 | derive(WgSim3, WgSim2, WgRay, WgProjection, WgSegment), 13 | src = "cone.wgsl", 14 | src_fn = "substitute_aliases", 15 | shader_defs = "dim_shader_defs" 16 | )] 17 | /// Shader defining the cone shape as well as its ray-casting and point-projection functions. 18 | pub struct WgCone; 19 | 20 | #[cfg(test)] 21 | mod test { 22 | use super::WgCone; 23 | use parry::shape::Cone; 24 | use wgcore::tensor::GpuVector; 25 | 26 | #[futures_test::test] 27 | #[serial_test::serial] 28 | async fn gpu_cone() { 29 | crate::projection::test_utils::test_point_projection::( 30 | "Cone", 31 | Cone::new(1.0, 0.5), 32 | |device, shapes, usages| GpuVector::init(device, shapes, usages).into_inner(), 33 | ) 34 | .await; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /crates/wgparry/src/cone.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | #import wgparry::segment as Seg 9 | 10 | #define_import_path wgparry::cone 11 | 12 | /// A cone, defined by its radius. 
13 | struct Cone { 14 | /// The cone’s half-height along its Y axis: the apex is at `y = half_height`, the basis at `y = -half_height`. 15 | half_height: f32, 16 | /// The radius of the cone’s basis. 17 | radius: f32, 18 | } 19 | 20 | /// Projects a point on a cone. 21 | /// 22 | /// If the point is inside the cone, the point itself is returned. 23 | fn projectLocalPoint(cone: Cone, pt: Vector) -> Vector { 24 | // Project on the basis. 25 | let planar_dist_from_basis_center = length(pt.xz); 26 | let dir_from_basis_center = select( 27 | vec2(1.0, 0.0), 28 | pt.xz / planar_dist_from_basis_center, 29 | planar_dist_from_basis_center > 0.0 30 | ); 31 | 32 | let projection_on_basis = vec3(pt.x, -cone.half_height, pt.z); 33 | 34 | if pt.y < -cone.half_height && planar_dist_from_basis_center <= cone.radius { 35 | // The projection is on the basis. 36 | return projection_on_basis; 37 | } 38 | 39 | // Project on the basis circle. 40 | let proj2d = dir_from_basis_center * cone.radius; 41 | let projection_on_basis_circle = vec3(proj2d[0], -cone.half_height, proj2d[1]); 42 | 43 | // Project on the conic side. 44 | // TODO: we could solve this in 2D using the plane passing through the cone axis and the conic_side_segment to save some computation. 45 | let apex_point = vec3(0.0, cone.half_height, 0.0); 46 | let conic_side_segment = Seg::Segment(apex_point, projection_on_basis_circle); 47 | let conic_side_segment_dir = conic_side_segment.b - conic_side_segment.a; 48 | let proj = Seg::projectLocalPoint(conic_side_segment, pt); 49 | 50 | let apex_to_basis_center = vec3(0.0, -2.0 * cone.half_height, 0.0); 51 | 52 | // Now determine if the point is inside of the cone. 53 | if pt.y >= -cone.half_height 54 | && pt.y <= cone.half_height 55 | && dot( 56 | cross(conic_side_segment_dir, pt - apex_point), 57 | cross(conic_side_segment_dir, apex_to_basis_center) 58 | ) >= 0.0 59 | { 60 | // We are inside of the cone. 61 | return pt; 62 | } else { 63 | // We are outside of the cone, return the computed segment projection. 
64 | return proj; 65 | } 66 | } 67 | 68 | /// Projects a point on a transformed cone. 69 | /// 70 | /// If the point is inside the cone, the point itself is returned. 71 | fn projectPoint(cone: Cone, pose: Transform, pt: Vector) -> Vector { 72 | let localPt = Pose::invMulPt(pose, pt); 73 | return Pose::mulPt(pose, projectLocalPoint(cone, localPt)); 74 | } 75 | 76 | 77 | /// Projects a point on the boundary of a cone. 78 | fn projectLocalPointOnBoundary(cone: Cone, pt: Vector) -> Proj::ProjectionResult { 79 | // Project on the basis. 80 | let planar_dist_from_basis_center = length(pt.xz); 81 | let dir_from_basis_center = select( 82 | vec2(1.0, 0.0), 83 | pt.xz / planar_dist_from_basis_center, 84 | planar_dist_from_basis_center > 0.0 85 | ); 86 | 87 | let projection_on_basis = vec3(pt.x, -cone.half_height, pt.z); 88 | 89 | if pt.y < -cone.half_height && planar_dist_from_basis_center <= cone.radius { 90 | // The projection is on the basis. 91 | return Proj::ProjectionResult(projection_on_basis, false); 92 | } 93 | 94 | // Project on the basis circle. 95 | let proj2d = dir_from_basis_center * cone.radius; 96 | let projection_on_basis_circle = vec3(proj2d[0], -cone.half_height, proj2d[1]); 97 | 98 | // Project on the conic side. 99 | // TODO: we could solve this in 2D using the plane passing through the cone axis and the conic_side_segment to save some computation. 100 | let apex_point = vec3(0.0, cone.half_height, 0.0); 101 | let conic_side_segment = Seg::Segment(apex_point, projection_on_basis_circle); 102 | let conic_side_segment_dir = conic_side_segment.b - conic_side_segment.a; 103 | let proj = Seg::projectLocalPoint(conic_side_segment, pt); 104 | 105 | let apex_to_basis_center = vec3(0.0, -2.0 * cone.half_height, 0.0); 106 | 107 | // Now determine if the point is inside of the cone. 
108 | if pt.y >= -cone.half_height 109 | && pt.y <= cone.half_height 110 | && dot( 111 | cross(conic_side_segment_dir, pt - apex_point), 112 | cross(conic_side_segment_dir, apex_to_basis_center) 113 | ) >= 0.0 114 | { 115 | // We are inside of the cone, so the correct projection is 116 | // either on the basis of the cone, or on the conic side. 117 | let pt_to_proj = proj - pt; 118 | let pt_to_basis_proj = projection_on_basis - pt; 119 | if dot(pt_to_proj, pt_to_proj) > dot(pt_to_basis_proj, pt_to_basis_proj) { 120 | return Proj::ProjectionResult(projection_on_basis, true); 121 | } else { 122 | return Proj::ProjectionResult(proj, true); 123 | } 124 | } else { 125 | // We are outside of the cone, return the computed segment projection as-is. 126 | return Proj::ProjectionResult(proj, false); 127 | } 128 | } 129 | 130 | /// Project a point of a transformed cone’s boundary. 131 | /// 132 | /// If the point is inside of the box, it will be projected on its boundary but 133 | /// `ProjectionResult::is_inside` will be set to `true`. 
134 | fn projectPointOnBoundary(cone: Cone, pose: Transform, pt: Vector) -> Proj::ProjectionResult { 135 | let local_pt = Pose::invMulPt(pose, pt); 136 | var result = projectLocalPointOnBoundary(cone, local_pt); 137 | result.point = Pose::mulPt(pose, result.point); 138 | return result; 139 | } 140 | -------------------------------------------------------------------------------- /crates/wgparry/src/contact.rs: -------------------------------------------------------------------------------- 1 | use crate::ball::WgBall; 2 | use crate::projection::WgProjection; 3 | use wgcore::Shader; 4 | use wgebra::WgSim3; 5 | 6 | #[derive(Shader)] 7 | #[shader(derive(WgSim3, WgBall, WgProjection), src = "contact.wgsl")] 8 | /// Shader defining the contact structure as well as the ball-ball contact computation. 9 | pub struct WgContact; 10 | 11 | wgcore::test_shader_compilation!(WgContact); 12 | -------------------------------------------------------------------------------- /crates/wgparry/src/contact.wgsl: -------------------------------------------------------------------------------- 1 | #import wgebra::sim3 as Pose 2 | #import wgparry::ball as Ball 3 | 4 | #define_import_path wgparry::contact 5 | 6 | /// A pair of contact points between two shapes. 7 | struct Contact { 8 | /// The contact point on the first shape. 9 | point1: vec3<f32>, 10 | /// The contact point on the second shape. 11 | point2: vec3<f32>, 12 | /// The first shape’s normal at its contact point. 13 | normal1: vec3<f32>, 14 | /// The second shape’s normal at its contact point. 15 | normal2: vec3<f32>, 16 | /// The distance between the two contact points. 17 | dist: f32, 18 | } 19 | 20 | 21 | /// Computes the contact between two balls. 
22 | fn ballBall(pose12: Pose::Sim3, ball1: Ball::Ball, ball2: Ball::Ball) -> Contact { 23 | let r1 = ball1.radius; 24 | let r2 = ball2.radius; 25 | let center2_1 = pose12.translation_scale.xyz; 26 | let distance = length(center2_1); 27 | let sum_radius = r1 + r2; 28 | 29 | var normal1 = vec3(1.0, 0.0, 0.0); 30 | 31 | if distance != 0.0 { 32 | normal1 = center2_1 / distance; 33 | } 34 | 35 | let normal2 = -Pose::invMulUnitVec(pose12, normal1); 36 | let point1 = normal1 * r1; 37 | let point2 = normal2 * r2; 38 | 39 | return Contact( 40 | point1, 41 | point2, 42 | normal1, 43 | normal2, 44 | distance - sum_radius, 45 | ); 46 | } -------------------------------------------------------------------------------- /crates/wgparry/src/cuboid.rs: -------------------------------------------------------------------------------- 1 | //! The cuboid shape. 2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::{dim_shader_defs, substitute_aliases}; 6 | use wgcore::Shader; 7 | use wgebra::{WgSim2, WgSim3}; 8 | 9 | #[derive(Shader)] 10 | #[shader( 11 | derive(WgSim3, WgSim2, WgRay, WgProjection), 12 | src = "cuboid.wgsl", 13 | src_fn = "substitute_aliases", 14 | shader_defs = "dim_shader_defs" 15 | )] 16 | /// Shader defining the Cuboid shape as well as its ray-casting and point-projection functions. 
17 | pub struct WgCuboid; 18 | 19 | #[cfg(test)] 20 | mod test { 21 | use super::WgCuboid; 22 | use na::vector; 23 | #[cfg(feature = "dim2")] 24 | use parry2d::shape::Cuboid; 25 | #[cfg(feature = "dim3")] 26 | use parry3d::shape::Cuboid; 27 | use wgcore::tensor::GpuVector; 28 | 29 | #[futures_test::test] 30 | #[serial_test::serial] 31 | async fn gpu_cuboid() { 32 | crate::projection::test_utils::test_point_projection::( 33 | "Cuboid", 34 | #[cfg(feature = "dim2")] 35 | Cuboid::new(vector![1.0, 2.0]), 36 | #[cfg(feature = "dim3")] 37 | Cuboid::new(vector![1.0, 2.0, 3.0]), 38 | |device, shapes, usages| GpuVector::encase(device, shapes, usages).into_inner(), 39 | ) 40 | .await; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /crates/wgparry/src/cuboid.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | 9 | #define_import_path wgparry::cuboid 10 | 11 | /// The result of a point projection. NOTE(review): same fields as `Proj::ProjectionResult`; confirm whether this local copy is still needed. 12 | struct ProjectionResult { 13 | /// The point’s projection on the shape. 14 | /// This can be equal to the original point if the point was inside 15 | /// of the shape and the projection function doesn’t always project 16 | /// on the boundary. 17 | point: Vector, 18 | /// Is the point inside of the shape? 19 | is_inside: bool, 20 | } 21 | 22 | 23 | /// A box, defined by its half-extents (half-length along each dimension). 24 | struct Cuboid { 25 | halfExtents: Vector 26 | } 27 | 28 | /// Projects a point on a box. 29 | /// 30 | /// If the point is inside the box, the point itself is returned. 
31 | fn projectLocalPoint(box: Cuboid, pt: Vector) -> Vector { 32 | let mins = -box.halfExtents; 33 | let maxs = box.halfExtents; 34 | 35 | let mins_pt = mins - pt; // -hext - pt 36 | let pt_maxs = pt - maxs; // pt - hext 37 | let shift = max(mins_pt, Vector(0.0)) - max(pt_maxs, Vector(0.0)); 38 | 39 | return pt + shift; 40 | } 41 | 42 | /// Projects a point on a transformed box. 43 | /// 44 | /// If the point is inside the box, the point itself is returned. 45 | fn projectPoint(box: Cuboid, pose: Transform, pt: Vector) -> Vector { 46 | let local_pt = Pose::invMulPt(pose, pt); 47 | return Pose::mulPt(pose, projectLocalPoint(box, local_pt)); 48 | } 49 | 50 | /// Projects a point on the boundary of a box. 51 | fn projectLocalPointOnBoundary(box: Cuboid, pt: Vector) -> Proj::ProjectionResult { 52 | let out_proj = projectLocalPoint(box, pt); 53 | 54 | // Projection if the point is inside the box. 55 | let pt_sgn_with_zero = sign(pt); 56 | // This is the sign of pt, or -1 for components that were zero. 57 | // This bias is arbitrary (we could have picked +1), but we picked it so 58 | // it matches the bias that’s in parry. 59 | let pt_sgn = pt_sgn_with_zero + (abs(pt_sgn_with_zero) - Vector(1.0)); 60 | let diff = box.halfExtents - pt_sgn * pt; 61 | 62 | #if DIM == 2 63 | let pick_x = diff.x <= diff.y; 64 | let shift_x = Vector(diff.x * pt_sgn.x, 0.0); 65 | let shift_y = Vector(0.0, diff.y * pt_sgn.y); 66 | let pen_shift = select(shift_y, shift_x, pick_x); 67 | #else 68 | let pick_x = diff.x <= diff.y && diff.x <= diff.z; 69 | let pick_y = diff.y <= diff.x && diff.y <= diff.z; 70 | let shift_x = Vector(diff.x * pt_sgn.x, 0.0, 0.0); 71 | let shift_y = Vector(0.0, diff.y * pt_sgn.y, 0.0); 72 | let shift_z = Vector(0.0, 0.0, diff.z * pt_sgn.z); 73 | let pen_shift = select(select(shift_z, shift_y, pick_y), shift_x, pick_x); 74 | #endif 75 | let in_proj = pt + pen_shift; 76 | 77 | // Select between the inside and outside projections: the point is inside when the clamping projection left it unchanged. 
78 | let is_inside = all(pt == out_proj); 79 | return Proj::ProjectionResult(select(out_proj, in_proj, is_inside), is_inside); 80 | } 81 | 82 | /// Projects a point on the boundary of a transformed box. 83 | /// 84 | /// If the point is inside of the box, it will be projected on its boundary but 85 | /// `ProjectionResult::is_inside` will be set to `true`. 86 | fn projectPointOnBoundary(box: Cuboid, pose: Transform, pt: Vector) -> Proj::ProjectionResult { 87 | let local_pt = Pose::invMulPt(pose, pt); 88 | var result = projectLocalPointOnBoundary(box, local_pt); 89 | result.point = Pose::mulPt(pose, result.point); 90 | return result; 91 | } 92 | 93 | 94 | // FIXME: ray.wgsl needs to support 2d/3d before these implementations can be uncommented. 95 | ///* 96 | // * Ray casting. 97 | // */ 98 | ///// Casts a ray on a box. 99 | ///// 100 | ///// Returns a negative value if there is no hit. 101 | ///// If there is a hit, the result is a scalar `t >= 0` such that the hit point is equal to `ray.origin + t * ray.dir`. 102 | //fn castLocalRay(box: Cuboid, ray: Ray::Ray, maxTimeOfImpact: f32) -> f32 { 103 | // let mins = -box.halfExtents; 104 | // let maxs = box.halfExtents; 105 | // let inter1 = (mins - ray.origin) / ray.dir; 106 | // let inter2 = (maxs - ray.origin) / ray.dir; 107 | // 108 | // let vtmin = min(inter1, inter2); 109 | // let vtmax = max(inter1, inter2); 110 | // 111 | //#if DIM == 2 112 | // let tmin = max(max(vtmin.x, vtmin.y), 0.0); 113 | // let tmax = min(min(vtmax.x, vtmax.y), maxTimeOfImpact); 114 | //#else 115 | // let tmin = max(max(max(vtmin.x, vtmin.y), vtmin.z), 0.0); 116 | // let tmax = min(min(min(vtmax.x, vtmax.y), vtmax.z), maxTimeOfImpact); 117 | //#endif 118 | // 119 | // if tmin > tmax || tmax < 0.0 { 120 | // return -1.0; 121 | // } else { 122 | // return tmin; 123 | // } 124 | //} 125 | // 126 | ///// Casts a ray on a transformed box. 127 | ///// 128 | ///// Returns a negative value if there is no hit. 
129 | ///// If there is a hit, the result is a scalar `t >= 0` such that the hit point is equal to `ray.origin + t * ray.dir`. 130 | //fn castRay(box: Cuboid, pose: Transform, ray: Ray::Ray, maxTimeOfImpact: f32) -> f32 { 131 | // let localRay = Ray::Ray(Pose::invMulPt(pose, ray.origin), Pose::invMulVec(pose, ray.dir)); 132 | // return castLocalRay(box, localRay, maxTimeOfImpact); 133 | //} 134 | -------------------------------------------------------------------------------- /crates/wgparry/src/cylinder.rs: -------------------------------------------------------------------------------- 1 | //! The cylinder shape. 2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::{dim_shader_defs, substitute_aliases}; 6 | use wgcore::Shader; 7 | use wgebra::{WgSim2, WgSim3}; 8 | 9 | #[derive(Shader)] 10 | #[shader( 11 | derive(WgSim3, WgSim2, WgRay, WgProjection), 12 | src = "cylinder.wgsl", 13 | src_fn = "substitute_aliases", 14 | shader_defs = "dim_shader_defs" 15 | )] 16 | /// Shader defining the cylinder shape as well as its ray-casting and point-projection functions. 
17 | pub struct WgCylinder; 18 | 19 | #[cfg(test)] 20 | mod test { 21 | use super::WgCylinder; 22 | use parry::shape::Cylinder; 23 | use wgcore::tensor::GpuVector; 24 | 25 | #[futures_test::test] 26 | #[serial_test::serial] 27 | async fn gpu_cylinder() { 28 | crate::projection::test_utils::test_point_projection::( 29 | "Cylinder", 30 | Cylinder::new(1.0, 0.5), 31 | |device, shapes, usages| GpuVector::init(device, shapes, usages).into_inner(), 32 | ) 33 | .await; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /crates/wgparry/src/cylinder.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | 9 | #define_import_path wgparry::cylinder 10 | 11 | /// A cylinder, defined by its half-height and its radius. 12 | struct Cylinder { 13 | /// The cylinder’s half-height along its Y axis: it spans `y` in `[-half_height, half_height]`. 14 | half_height: f32, 15 | /// The cylinder’s radius. 16 | radius: f32, 17 | } 18 | 19 | /// Projects a point on a cylinder. 20 | /// 21 | /// If the point is inside the cylinder, the point itself is returned. 22 | fn projectLocalPoint(cylinder: Cylinder, pt: Vector) -> Vector { 23 | 24 | // Project on the basis. 25 | let planar_dist_from_basis_center = length(pt.xz); 26 | let dir_from_basis_center = select( 27 | vec2(1.0, 0.0), 28 | pt.xz / planar_dist_from_basis_center, 29 | planar_dist_from_basis_center > 0.0 30 | ); 31 | 32 | let proj2d = dir_from_basis_center * cylinder.radius; 33 | 34 | // PERF: reduce branching 35 | if pt.y >= -cylinder.half_height 36 | && pt.y <= cylinder.half_height 37 | && planar_dist_from_basis_center <= cylinder.radius 38 | { 39 | return pt; 40 | } else { 41 | // The point is outside of the cylinder. 
42 | if pt.y > cylinder.half_height { 43 | if planar_dist_from_basis_center <= cylinder.radius { 44 | return vec3(pt.x, cylinder.half_height, pt.z); 45 | } else { 46 | return vec3(proj2d[0], cylinder.half_height, proj2d[1]); 47 | } 48 | } else if pt.y < -cylinder.half_height { 49 | // Project on the bottom plane or the bottom circle. 50 | if planar_dist_from_basis_center <= cylinder.radius { 51 | return vec3(pt.x, -cylinder.half_height, pt.z); 52 | } else { 53 | return vec3(proj2d[0], -cylinder.half_height, proj2d[1]); 54 | } 55 | } else { 56 | // Project on the side. 57 | return vec3(proj2d[0], pt.y, proj2d[1]); 58 | } 59 | } 60 | } 61 | 62 | /// Projects a point on a transformed cylinder. 63 | /// 64 | /// If the point is inside the cylinder, the point itself is returned. 65 | fn projectPoint(cylinder: Cylinder, pose: Transform, pt: Vector) -> Vector { 66 | let localPt = Pose::invMulPt(pose, pt); 67 | return Pose::mulPt(pose, projectLocalPoint(cylinder, localPt)); 68 | } 69 | 70 | 71 | /// Projects a point on the boundary of a cylinder. 72 | fn projectLocalPointOnBoundary(cylinder: Cylinder, pt: Vector) -> Proj::ProjectionResult { 73 | // Project on the basis. 74 | let planar_dist_from_basis_center = length(pt.xz); 75 | let dir_from_basis_center = select( 76 | vec2(1.0, 0.0), 77 | pt.xz / planar_dist_from_basis_center, 78 | planar_dist_from_basis_center > 0.0 79 | ); 80 | 81 | let proj2d = dir_from_basis_center * cylinder.radius; 82 | 83 | // PERF: reduce branching 84 | if pt.y >= -cylinder.half_height 85 | && pt.y <= cylinder.half_height 86 | && planar_dist_from_basis_center <= cylinder.radius 87 | { 88 | // The point is inside of the cylinder. 
89 | let dist_to_top = cylinder.half_height - pt.y; 90 | let dist_to_bottom = pt.y - (-cylinder.half_height); 91 | let dist_to_side = cylinder.radius - planar_dist_from_basis_center; 92 | 93 | if dist_to_top < dist_to_bottom && dist_to_top < dist_to_side { 94 | let projection_on_top = vec3(pt.x, cylinder.half_height, pt.z); 95 | return Proj::ProjectionResult(projection_on_top, true); 96 | } else if dist_to_bottom < dist_to_top && dist_to_bottom < dist_to_side { 97 | let projection_on_bottom = 98 | vec3(pt.x, -cylinder.half_height, pt.z); 99 | return Proj::ProjectionResult(projection_on_bottom, true); 100 | } else { 101 | let projection_on_side = vec3(proj2d[0], pt.y, proj2d[1]); 102 | return Proj::ProjectionResult(projection_on_side, true); 103 | } 104 | } else { 105 | // The point is outside of the cylinder. 106 | if pt.y > cylinder.half_height { 107 | if planar_dist_from_basis_center <= cylinder.radius { 108 | let projection_on_top = vec3(pt.x, cylinder.half_height, pt.z); 109 | return Proj::ProjectionResult(projection_on_top, false); 110 | } else { 111 | let projection_on_top_circle = 112 | vec3(proj2d[0], cylinder.half_height, proj2d[1]); 113 | return Proj::ProjectionResult(projection_on_top_circle, false); 114 | } 115 | } else if pt.y < -cylinder.half_height { 116 | // Project on the bottom plane or the bottom circle. 117 | if planar_dist_from_basis_center <= cylinder.radius { 118 | let projection_on_bottom = 119 | vec3(pt.x, -cylinder.half_height, pt.z); 120 | return Proj::ProjectionResult(projection_on_bottom, false); 121 | } else { 122 | let projection_on_bottom_circle = 123 | vec3(proj2d[0], -cylinder.half_height, proj2d[1]); 124 | return Proj::ProjectionResult(projection_on_bottom_circle, false); 125 | } 126 | } else { 127 | // Project on the side. 
128 | let projection_on_side = vec3(proj2d[0], pt.y, proj2d[1]); 129 | return Proj::ProjectionResult(projection_on_side, false); 130 | } 131 | } 132 | } 133 | 134 | /// Projects a point on the boundary of a transformed cylinder. 135 | /// 136 | /// If the point is inside of the cylinder, it will be projected on its boundary but 137 | /// `ProjectionResult::is_inside` will be set to `true`. 138 | fn projectPointOnBoundary(cylinder: Cylinder, pose: Transform, pt: Vector) -> Proj::ProjectionResult { 139 | let local_pt = Pose::invMulPt(pose, pt); 140 | var result = projectLocalPointOnBoundary(cylinder, local_pt); 141 | result.point = Pose::mulPt(pose, result.point); 142 | return result; 143 | } 144 | -------------------------------------------------------------------------------- /crates/wgparry/src/projection.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgparry::projection 2 | 3 | /// The result of a point projection. 4 | struct ProjectionResult { 5 | /// The point’s projection on the shape. 6 | /// This can be equal to the original point if the point was inside 7 | /// of the shape and the projection function doesn’t always project 8 | /// on the boundary. 9 | point: Vector, 10 | /// Is the point inside of the shape? 11 | is_inside: bool, 12 | } 13 | -------------------------------------------------------------------------------- /crates/wgparry/src/ray.rs: -------------------------------------------------------------------------------- 1 | //! The ray structure. 2 | 3 | use wgcore::Shader; 4 | 5 | #[derive(Shader)] 6 | #[shader(src = "ray.wgsl")] 7 | /// Shader defining the wgsl ray structure for ray-casting. 
8 | pub struct WgRay; 9 | 10 | wgcore::test_shader_compilation!(WgRay); 11 | -------------------------------------------------------------------------------- /crates/wgparry/src/ray.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgparry::ray 2 | 3 | struct Ray { 4 | origin: vec3, 5 | dir: vec3, 6 | } 7 | 8 | /// The point on the ray at the given parameter `t`, i.e. `origin + dir * t` (`dir` is used as-is, without normalization). 9 | fn ptAt(ray: Ray, t: f32) -> vec3 { 10 | return ray.origin + ray.dir * t; 11 | } -------------------------------------------------------------------------------- /crates/wgparry/src/segment.rs: -------------------------------------------------------------------------------- 1 | //! The segment shape. 2 | 3 | use crate::projection::WgProjection; 4 | use crate::ray::WgRay; 5 | use crate::{dim_shader_defs, substitute_aliases}; 6 | use wgcore::Shader; 7 | use wgebra::{WgSim2, WgSim3}; 8 | 9 | #[derive(Shader)] 10 | #[shader( 11 | derive(WgSim3, WgSim2, WgRay, WgProjection), 12 | src = "segment.wgsl", 13 | src_fn = "substitute_aliases", 14 | shader_defs = "dim_shader_defs" 15 | )] 16 | /// Shader defining the segment shape and its point-projection function (no ray-casting function is defined in `segment.wgsl` yet). 
17 | pub struct WgSegment; 18 | 19 | // TODO: 20 | // #[cfg(test)] 21 | // mod test { 22 | // use super::WgSegment; 23 | // use parry::shape::Segment; 24 | // use wgcore::tensor::GpuVector; 25 | // 26 | // #[futures_test::test] 27 | // #[serial_test::serial] 28 | // async fn gpu_segment() { 29 | // crate::projection::test_utils::test_point_projection::( 30 | // "Segment", 31 | // Segment::new(1.0, 0.5), 32 | // |device, shapes, usages| GpuVector::encase(device, shapes, usages).into_inner(), 33 | // ) 34 | // .await; 35 | // } 36 | // } 37 | -------------------------------------------------------------------------------- /crates/wgparry/src/segment.wgsl: -------------------------------------------------------------------------------- 1 | #if DIM == 2 2 | #import wgebra::sim2 as Pose 3 | #else 4 | #import wgebra::sim3 as Pose 5 | #endif 6 | #import wgparry::ray as Ray 7 | #import wgparry::projection as Proj 8 | 9 | #define_import_path wgparry::segment 10 | 11 | struct Segment { 12 | a: Vector, 13 | b: Vector, 14 | } 15 | 16 | // TODO: implement the other projection functions. `projectLocalPoint` below returns the point of the segment `[a, b]` that is closest to `pt`. 17 | fn projectLocalPoint(seg: Segment, pt: Vector) -> Vector { 18 | let ab = seg.b - seg.a; 19 | let ap = pt - seg.a; 20 | let ab_ap = dot(ab, ap); 21 | let sqnab = dot(ab, ab); 22 | 23 | // PERF: would it be faster to do a bunch of `select` instead of `if`? 24 | if ab_ap <= 0.0 { 25 | // Voronoï region of vertex 'a'. 26 | return seg.a; 27 | } else if ab_ap >= sqnab { 28 | // Voronoï region of vertex 'b'. 29 | return seg.b; 30 | } else { 31 | // Voronoï region of the segment interior (here 0 < ab_ap < sqnab, so the division below is well-defined). 32 | let u = ab_ap / sqnab; 33 | return seg.a + ab * u; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /crates/wgparry/src/shape_fake_cone.wgsl: -------------------------------------------------------------------------------- 1 | // This module only exists as a workaround for some naga-oil weirdness. 
2 | 3 | #define_import_path wgparry::cone 4 | -------------------------------------------------------------------------------- /crates/wgparry/src/shape_fake_cylinder.wgsl: -------------------------------------------------------------------------------- 1 | // This module only exists as a workaround for some naga-oil weirdness. 2 | 3 | #define_import_path wgparry::cylinder 4 | -------------------------------------------------------------------------------- /crates/wgparry/src/triangle.rs: -------------------------------------------------------------------------------- 1 | //! The triangle shape. 2 | 3 | use crate::substitute_aliases; 4 | use wgcore::Shader; 5 | 6 | #[derive(Shader)] 7 | #[shader(src = "triangle.wgsl", src_fn = "substitute_aliases")] 8 | /// Shader defining the triangle shape (`triangle.wgsl` currently only declares the `Triangle` struct). 9 | pub struct WgTriangle; 10 | -------------------------------------------------------------------------------- /crates/wgparry/src/triangle.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgparry::triangle 2 | 3 | struct Triangle { 4 | a: Vector, 5 | b: Vector, 6 | c: Vector, 7 | } 8 | -------------------------------------------------------------------------------- /crates/wgrapier/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.2.0 2 | 3 | ### Modified 4 | 5 | - Update to `wgcore` v0.2.0. 6 | -------------------------------------------------------------------------------- /crates/wgrapier/README.md: -------------------------------------------------------------------------------- 1 | # wgrapier: cross-platform GPU physics simulation 2 | 3 | **/!\ This library is still under heavy development and is still missing many features.** 4 | 5 | The goal of **wgrapier** is essentially to be "**rapier** on the GPU". It aims (but it isn’t there yet) to be a 6 | GPU rigid-body physics engine. 
7 | -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier2d/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgrapier2d" 3 | authors = ["Sébastien Crozet "] 4 | description = "Cross-platform 2D rigid-body physics." 5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [lib] 12 | name = "wgrapier2d" 13 | path = "../../src/lib.rs" 14 | required-features = ["dim2"] 15 | 16 | [lints] 17 | rust.unexpected_cfgs = { level = "warn", check-cfg = [ 18 | 'cfg(feature, values("dim3"))', 19 | ] } 20 | 21 | [features] 22 | default = ["dim2"] 23 | dim2 = [] 24 | 25 | [dependencies] 26 | nalgebra = { workspace = true } 27 | wgpu = { workspace = true } 28 | naga_oil = { workspace = true } 29 | bytemuck = { workspace = true } 30 | encase = { workspace = true } 31 | 32 | wgcore = { version = "0.2", path = "../../../wgcore" } 33 | wgebra = { version = "0.2", path = "../../../wgebra" } 34 | wgparry2d = { version = "0.2", path = "../../../wgparry/crates/wgparry2d" } 35 | rapier2d = "0.23" # TODO: should be behind a feature? 
36 | num-traits = "0.2" 37 | 38 | [dev-dependencies] 39 | nalgebra = { version = "0.33", features = ["rand"] } 40 | futures-test = "0.3" 41 | serial_test = "3" 42 | approx = "0.5" 43 | async-std = { version = "1", features = ["attributes"] } 44 | #bevy = { version = "0.14", features = ["shader_format_glsl", "shader_format_spirv"], } 45 | #bevy_panorbit_camera = "0.19.1" -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier2d/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier2d/src: -------------------------------------------------------------------------------- 1 | ../../src -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier3d/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wgrapier3d" 3 | authors = ["Sébastien Crozet "] 4 | description = "Cross-platform 3D rigid-body physics." 
5 | homepage = "https://wgmath.rs" 6 | repository = "https://github.com/dimforge/wgmath" 7 | version = "0.2.0" 8 | edition = "2021" 9 | license = "MIT OR Apache-2.0" 10 | 11 | [lib] 12 | name = "wgrapier3d" 13 | path = "../../src/lib.rs" 14 | required-features = ["dim3"] 15 | 16 | [lints] 17 | rust.unexpected_cfgs = { level = "warn", check-cfg = [ 18 | 'cfg(feature, values("dim2"))', 19 | ] } 20 | 21 | [features] 22 | default = ["dim3"] 23 | dim3 = [] 24 | 25 | [dependencies] 26 | nalgebra = { workspace = true } 27 | wgpu = { workspace = true } 28 | naga_oil = { workspace = true } 29 | bytemuck = { workspace = true } 30 | encase = { workspace = true } 31 | 32 | wgcore = { version = "0.2", path = "../../../wgcore" } 33 | wgebra = { version = "0.2", path = "../../../wgebra" } 34 | wgparry3d = { version = "0.2", path = "../../../wgparry/crates/wgparry3d" } 35 | rapier3d = "0.23"# TODO: should be behind a feature? 36 | num-traits = "0.2" 37 | 38 | [dev-dependencies] 39 | nalgebra = { version = "0.33", features = ["rand"] } 40 | futures-test = "0.3" 41 | serial_test = "3" 42 | approx = "0.5" 43 | async-std = { version = "1", features = ["attributes"] } 44 | #bevy = { version = "0.14", features = ["shader_format_glsl", "shader_format_spirv"] } 45 | #bevy_panorbit_camera = "0.19.1" -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier3d/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /crates/wgrapier/crates/wgrapier3d/src: -------------------------------------------------------------------------------- 1 | ../../src -------------------------------------------------------------------------------- /crates/wgrapier/examples/gravity.wgsl: -------------------------------------------------------------------------------- 1 | #import wgrapier::body as Body; 2 | #import 
wgebra::sim3 as Pose; 3 | 4 | @group(0) @binding(0) 5 | var mprops: array; 6 | @group(0) @binding(1) 7 | var local_mprops: array; 8 | @group(0) @binding(2) 9 | var poses: array; 10 | @group(0) @binding(3) 11 | var vels: array; 12 | 13 | const WORKGROUP_SIZE: u32 = 64; 14 | /// Example kernel: applies gravity to each body’s velocity, integrates its pose, then refreshes its world-space mass-properties. 15 | @compute @workgroup_size(WORKGROUP_SIZE, 1, 1) 16 | fn main(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) { 17 | let gravity = Body::Force(vec3(0.0, -9.81, 0.0), vec3(0.0)); 18 | let num_threads = num_workgroups.x * WORKGROUP_SIZE; // The loop index only uses `invocation_id.x`, so the stride must only count threads along x (as in `integrate.wgsl`). 19 | for (var i = invocation_id.x; i < arrayLength(&poses); i += num_threads) { 20 | let new_vels = Body::integrateForces(mprops[i], vels[i], gravity, 0.016); 21 | let new_pose = Body::integrateVelocity(poses[i], new_vels, local_mprops[i].com, 0.016); 22 | let new_mprops = Body::updateMprops(new_pose, local_mprops[i]); 23 | 24 | mprops[i] = new_mprops; 25 | vels[i] = new_vels; 26 | poses[i] = new_pose; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /crates/wgrapier/src/dynamics/body.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgrapier::body 2 | 3 | #if DIM == 2 4 | #import wgebra::sim2 as Pose 5 | #import wgebra::rot2 as Rot 6 | #else 7 | #import wgebra::sim3 as Pose 8 | #import wgebra::quat as Rot 9 | #endif 10 | 11 | 12 | /// The mass-properties of a rigid-body. 13 | /// Note that the mass-properties may be expressed either in the rigid-body’s local-space or in world-space, 14 | /// depending on its provenance. Usually, the world-space and local-space mass-properties will be stored in 15 | /// two separate buffers. 
16 | struct MassProperties { 17 | // TODO: a representation with Quaternion & vec3 (for frame & principal inertia) would be much more compact and make 18 | // this struct have the size of a mat4x4 19 | #if DIM == 2 20 | /// The rigid-body’s inverse inertia tensor (a scalar in 2D). 21 | inv_inertia: f32, 22 | #else 23 | inv_inertia: mat3x3, // The rigid-body’s inverse inertia tensor (a 3x3 matrix in 3D). 24 | #endif 25 | /// The rigid-body’s inverse mass along each coordinate axis. 26 | /// 27 | /// Storing one value per axis lets the user set the inverse mass to 0 along chosen axes. 28 | /// Setting it to zero locks the linear motion along the corresponding world-space axis. 29 | inv_mass: Vector, 30 | /// The rigid-body’s center of mass. 31 | com: Vector, 32 | } 33 | 34 | /// An impulse (linear and angular/torque). 35 | struct Impulse { 36 | /// A linear impulse. 37 | linear: Vector, 38 | /// An angular impulse (torque impulse). 39 | angular: AngVector, 40 | } 41 | 42 | /// A force and torque. 43 | struct Force { 44 | /// A linear force. 45 | linear: Vector, 46 | /// An angular force (torque). 47 | angular: AngVector, 48 | } 49 | 50 | /// A linear and angular velocity. 51 | struct Velocity { 52 | /// The linear (translational) part of the velocity. 53 | linear: Vector, 54 | /// The angular (rotational) part of the velocity. 55 | angular: AngVector, 56 | } 57 | 58 | /// A rigid-body pose and its velocity. 59 | struct RigidBodyState { 60 | /// The rigid-body’s pose (translation, rotation, uniform scale). 61 | pose: Transform, 62 | /// The rigid-body’s velocity (translational and rotational). 63 | velocity: Velocity, 64 | } 65 | 66 | /// Computes new velocities after applying the given impulse (scaled by the inverse mass / inverse inertia). 
67 | fn applyImpulse(mprops: MassProperties, velocity: Velocity, imp: Impulse) -> Velocity { 68 | let acc_lin = mprops.inv_mass * imp.linear; 69 | let acc_ang = mprops.inv_inertia * imp.angular; 70 | return Velocity(velocity.linear + acc_lin, velocity.angular + acc_ang); 71 | } 72 | 73 | 74 | /// Computes new velocities after integrating forces by a timestep equal to `dt`. 75 | fn integrateForces(mprops: MassProperties, velocity: Velocity, force: Force, dt: f32) -> Velocity { 76 | let acc_lin = mprops.inv_mass * force.linear; 77 | let acc_ang = mprops.inv_inertia * force.angular; 78 | return Velocity(velocity.linear + acc_lin * dt, velocity.angular + acc_ang * dt); 79 | } 80 | 81 | #if DIM == 2 82 | /// Computes a new pose after integrating velocities by a timestep equal to `dt`. The rotation increment is applied about the center of mass `Pose::mulPt(pose, local_com)`. 83 | fn integrateVelocity(pose: Transform, vels: Velocity, local_com: Vector, dt: f32) -> Transform { 84 | let init_com = Pose::mulPt(pose, local_com); 85 | let init_tra = pose.translation; 86 | let init_scale = pose.scale; 87 | 88 | let delta_ang = Rot::fromAngle(vels.angular * dt); 89 | let delta_lin = vels.linear * dt; 90 | // NOTE(review): `init_tra - init_com` is computed from already-transformed points; confirm the extra `* init_scale` below is intended when the scale is not 1. 91 | let new_translation = 92 | init_com + Rot::mulVec(delta_ang, (init_tra - init_com)) * init_scale + delta_lin; 93 | let new_rotation = Rot::mul(delta_ang, pose.rotation); 94 | 95 | return Transform(new_rotation, new_translation, init_scale); 96 | } 97 | 98 | /// Computes the new world-space mass-properties based on the local-space mass-properties and its transform. Only the center of mass is transformed; `inv_inertia` and `inv_mass` are copied unchanged. 99 | fn updateMprops(pose: Transform, local_mprops: MassProperties) -> MassProperties { 100 | let world_com = Pose::mulPt(pose, local_mprops.com); 101 | return MassProperties(local_mprops.inv_inertia, local_mprops.inv_mass, world_com); 102 | } 103 | 104 | /// Computes the linear velocity at a given point. 
105 | fn velocity_at_point(center_of_mass: Vector, vels: Velocity, point: Vector) -> Vector { 106 | let lever_arm = point - center_of_mass; 107 | return vels.linear + vels.angular * vec2(-lever_arm.y, lever_arm.x); // Scalar angular velocity times the lever arm rotated by 90°. 108 | } 109 | #else 110 | /// Computes a new pose after integrating velocities by a timestep equal to `dt`. The rotation increment is applied about the center of mass `Pose::mulPt(pose, local_com)`. 111 | fn integrateVelocity(pose: Transform, vels: Velocity, local_com: Vector, dt: f32) -> Transform { 112 | let init_com = Pose::mulPt(pose, local_com); 113 | let init_tra = pose.translation_scale.xyz; 114 | let init_scale = pose.translation_scale.w; 115 | 116 | let delta_ang = Rot::fromScaledAxis(vels.angular * dt); 117 | let delta_lin = vels.linear * dt; 118 | // NOTE(review): `init_tra - init_com` is computed from already-transformed points; confirm the extra `* init_scale` below is intended when the scale is not 1. 119 | let new_translation = 120 | init_com + Rot::mulVec(delta_ang, (init_tra - init_com)) * init_scale + delta_lin; 121 | let new_rotation = Rot::renormalizeFast(Rot::mul(delta_ang, pose.rotation)); 122 | 123 | return Transform(new_rotation, vec4(new_translation, init_scale)); 124 | } 125 | 126 | /// Computes the new world-space mass-properties based on the local-space mass-properties and its transform. 127 | fn updateMprops(pose: Transform, local_mprops: MassProperties) -> MassProperties { 128 | let world_com = Pose::mulPt(pose, local_mprops.com); 129 | let rot_mat = Rot::toMatrix(pose.rotation); 130 | let world_inv_inertia = rot_mat * local_mprops.inv_inertia * transpose(rot_mat); 131 | // The inverse inertia is re-expressed with the pose’s rotation matrix; `inv_mass` is copied unchanged. 132 | return MassProperties(world_inv_inertia, local_mprops.inv_mass, world_com); 133 | } 134 | 135 | /// Computes the linear velocity at a given point, i.e. `linear + cross(angular, point - com)`. 136 | fn velocity_at_point(com: Vector, vels: Velocity, point: Vector) -> Vector { 137 | return vels.linear + cross(vels.angular, point - com); 138 | } 139 | #endif -------------------------------------------------------------------------------- /crates/wgrapier/src/dynamics/integrate.rs: -------------------------------------------------------------------------------- 1 | //! Force and velocity integration. 
2 | 3 | use crate::dynamics::body::{GpuBodySet, WgBody}; 4 | use wgcore::kernel::KernelDispatch; 5 | use wgcore::Shader; 6 | use wgparry::{dim_shader_defs, substitute_aliases}; 7 | use wgpu::{ComputePass, ComputePipeline, Device}; 8 | 9 | #[derive(Shader)] 10 | #[shader( 11 | derive(WgBody), 12 | src = "integrate.wgsl", 13 | src_fn = "substitute_aliases", 14 | shader_defs = "dim_shader_defs" 15 | )] 16 | /// Shader exposing the velocity-integration compute kernel defined in `integrate.wgsl`. 17 | pub struct WgIntegrate { 18 | /// Compute shader integrating the velocity of every rigid-body into its pose and updating its world-space mass-properties. 19 | pub integrate: ComputePipeline, 20 | } 21 | 22 | impl WgIntegrate { 23 | const WORKGROUP_SIZE: u32 = 64; // Mirrors `WORKGROUP_SIZE` in `integrate.wgsl`. 24 | 25 | /// Dispatch an invocation of [`WgIntegrate::integrate`] over every rigid-body in the given [`GpuBodySet`], 26 | /// using `bodies.len()` divided by `WORKGROUP_SIZE` (rounded up) workgroups. 27 | pub fn dispatch(&self, device: &Device, pass: &mut ComputePass, bodies: &GpuBodySet) { 28 | KernelDispatch::new(device, pass, &self.integrate) 29 | .bind0([ // Same order as `@binding(0)` to `@binding(3)` in `integrate.wgsl`. 30 | bodies.mprops.buffer(), 31 | bodies.local_mprops.buffer(), 32 | bodies.poses.buffer(), 33 | bodies.vels.buffer(), 34 | ]) 35 | .dispatch(bodies.len().div_ceil(Self::WORKGROUP_SIZE)); 36 | } 37 | } 38 | 39 | wgcore::test_shader_compilation!(WgIntegrate, wgcore, wgparry::dim_shader_defs()); 40 | -------------------------------------------------------------------------------- /crates/wgrapier/src/dynamics/integrate.wgsl: -------------------------------------------------------------------------------- 1 | #define_import_path wgrapier::integrate 2 | 3 | #import wgrapier::body as Body; 4 | 5 | #if DIM == 2 6 | #import wgebra::sim2 as Pose; 7 | #else 8 | #import wgebra::sim3 as Pose; 9 | #endif 10 | 11 | @group(0) @binding(0) 12 | var mprops: array; 13 | @group(0) @binding(1) 14 | var local_mprops: array; 15 | #if DIM == 2 16 | @group(0) @binding(2) 17 | var poses: array; 18 | #else 19 | @group(0) @binding(2) 20 | var poses: array; 21 | #endif 22 | 
@group(0) @binding(3) 23 | var vels: array; 24 | 25 | const WORKGROUP_SIZE: u32 = 64; 26 | /// Integrates each body’s velocity into its pose, then recomputes its world-space mass-properties. 27 | @compute @workgroup_size(WORKGROUP_SIZE, 1, 1) 28 | fn integrate(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) { 29 | let num_threads = num_workgroups.x * WORKGROUP_SIZE; 30 | for (var i = invocation_id.x; i < arrayLength(&poses); i += num_threads) { 31 | // TODO: get dt from somewhere. NOTE(review): the hard-coded 0.0016 differs from the 0.016 used in `examples/gravity.wgsl`; confirm which timestep is intended. 32 | let new_pose = Body::integrateVelocity(poses[i], vels[i], local_mprops[i].com, 0.0016); 33 | let new_mprops = Body::updateMprops(new_pose, local_mprops[i]); 34 | // `vels` is only read here: this kernel does not integrate forces. 35 | mprops[i] = new_mprops; 36 | poses[i] = new_pose; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /crates/wgrapier/src/dynamics/mod.rs: -------------------------------------------------------------------------------- 1 | //! Rigid-body dynamics (forces, velocities, etc.) 2 | 3 | pub use body::{BodyDesc, GpuBodySet, GpuForce, GpuMassProperties, GpuVelocity, WgBody}; 4 | pub use integrate::WgIntegrate; 5 | 6 | pub mod body; 7 | pub mod integrate; 8 | -------------------------------------------------------------------------------- /crates/wgrapier/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | // #![warn(missing_docs)] 3 | 4 | #[cfg(feature = "dim2")] 5 | pub extern crate rapier2d as rapier; 6 | #[cfg(feature = "dim3")] 7 | pub extern crate rapier3d as rapier; 8 | #[cfg(feature = "dim2")] 9 | pub extern crate wgparry2d as wgparry; 10 | #[cfg(feature = "dim3")] 11 | pub extern crate wgparry3d as wgparry; 12 | 13 | pub mod dynamics; 14 | --------------------------------------------------------------------------------