├── .gitignore ├── examples ├── ferris.png └── basic.rs ├── Cargo.toml ├── LICENSE ├── README.md ├── shaders └── matching.wgsl └── src └── lib.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /examples/ferris.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urholaukkarinen/template-matching/HEAD/examples/ferris.png -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "template-matching" 3 | version = "0.2.0" 4 | edition = "2021" 5 | authors = ["Urho Laukkarinen "] 6 | 7 | description = "GPU-accelerated template matching" 8 | license = "MIT" 9 | readme = "README.md" 10 | repository = "https://github.com/urholaukkarinen/template-matching" 11 | homepage = "https://github.com/urholaukkarinen/template-matching" 12 | keywords = ["gpu", "image"] 13 | categories = ["computer-vision"] 14 | 15 | [dependencies] 16 | wgpu = "0.16" 17 | pollster = "0.3" 18 | bytemuck = { version = "1.13", features = ["derive"] } 19 | image = { version = "0.24", optional = true } 20 | futures-intrusive = "0.5" 21 | 22 | [dev-dependencies] 23 | image = "0.24" 24 | imageproc = "0.23" 25 | 26 | [features] 27 | default = ["image"] 28 | image = ["dep:image"] 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Urho Laukkarinen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the 
rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # template-matching 2 | 3 | [![Latest version](https://img.shields.io/crates/v/template-matching.svg)](https://crates.io/crates/template-matching) 4 | [![Documentation](https://docs.rs/template-matching/badge.svg)](https://docs.rs/template-matching) 5 | ![MIT](https://img.shields.io/badge/license-MIT-blue.svg) 6 | 7 | GPU-accelerated template matching library for Rust. The crate is designed as a faster alternative to [imageproc::template_matching](https://docs.rs/imageproc/latest/imageproc/template_matching/index.html). 
## Installation

```toml
[dependencies]
template-matching = { version = "0.2.0", features = ["image"] }
```

## Usage

```rust
use template_matching::{find_extremes, match_template, MatchTemplateMethod, TemplateMatcher};

fn main() {
    // Load images and convert them to f32 grayscale
    let input_image = image::load_from_memory(include_bytes!("input.png")).unwrap().to_luma32f();
    let template_image = image::load_from_memory(include_bytes!("template.png")).unwrap().to_luma32f();

    let result = match_template(&input_image, &template_image, MatchTemplateMethod::SumOfSquaredDifferences);

    // Or alternatively you can create the matcher first
    let mut matcher = TemplateMatcher::new();
    matcher.match_template(&input_image, &template_image, MatchTemplateMethod::SumOfSquaredDifferences);
    let result = matcher.wait_for_result().unwrap();

    // Calculate min & max values
    let extremes = find_extremes(&result);
}
```

--------------------------------------------------------------------------------
/examples/basic.rs:
--------------------------------------------------------------------------------
use std::time::Instant;

use image::{DynamicImage, GenericImageView};
use template_matching::{find_extremes, MatchTemplateMethod, TemplateMatcher};

/// Runs the same sum-of-squared-differences match with this crate and with
/// `imageproc`, printing the elapsed time and the extremes for each.
fn main() {
    let input_image = image::load_from_memory(include_bytes!("ferris.png")).unwrap();
    let input_luma8 = input_image.to_luma8();
    let input_luma32f = input_image.to_luma32f();

    // One matcher is reused for every template size.
    let mut matcher = TemplateMatcher::new();

    for i in 0..5 {
        // Templates are n x n crops of the input taken at offset (n, n).
        let n = 10 + i * 5;
        let template_image = DynamicImage::ImageRgba8(input_image.view(n, n, n, n).to_image());
        let template_luma8 = template_image.to_luma8();
        let template_luma32f = template_image.to_luma32f();

        // Start matching with GPU acceleration
        let time = Instant::now();
        matcher.match_template(
            &input_luma32f,
            &template_luma32f,
            MatchTemplateMethod::SumOfSquaredDifferences,
        );
        let matcher_start_elapsed = time.elapsed();

        // Start matching with imageproc
        let time = Instant::now();
        let result = imageproc::template_matching::match_template(
            &input_luma8,
            &template_luma8,
            imageproc::template_matching::MatchTemplateMethod::SumOfSquaredErrors,
        );
        println!(
            "imageproc::template_matching::match_template took {} ms",
            time.elapsed().as_millis()
        );
        let extremes = imageproc::template_matching::find_extremes(&result);
        println!("{:?}", extremes);

        // Get result from GPU accelerated matching.
        // The reported time is submission time plus the time spent waiting here.
        let time = Instant::now();
        let result = matcher.wait_for_result().unwrap();
        println!(
            "template_matching::match_template took {:.2} ms",
            (time.elapsed() + matcher_start_elapsed).as_micros() as f32 / 1000.0
        );

        let extremes = find_extremes(&result);
        println!("{:?}", extremes);
        println!();
    }
}

--------------------------------------------------------------------------------
/shaders/matching.wgsl:
--------------------------------------------------------------------------------
// Dimensions, in pixels, of the input image and of the template slid over it.
struct Uniforms {
    input_width: u32,
    input_height: u32,
    template_width: u32,
    template_height: u32,
};

// Row-major f32 pixels of the input image.
@group(0)
@binding(0)
var<storage, read> input_buf: array<f32>;

// Row-major f32 pixels of the template image.
@group(0)
@binding(1)
var<storage, read> template_buf: array<f32>;

// Row-major scores, (input_width - template_width + 1) values per row.
@group(0)
@binding(2)
var<storage, read_write> result_buf: array<f32>;

@group(0)
@binding(3)
var<uniform> uniforms: Uniforms;

// Sum of absolute differences: one invocation scores one template position (x, y).
@compute
@workgroup_size(16, 16, 1)
fn main_sad(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let x = global_id.x;
    let y = global_id.y;

    let input_width = uniforms.input_width;
    let template_width = uniforms.template_width;
    let template_height = uniforms.template_height;

    let result_width = input_width - template_width + 1u;
    let result_height = uniforms.input_height - template_height + 1u;

    // The dispatch is rounded up to whole 16x16 workgroups. Invocations past the
    // result grid must not write: their flattened index would alias a valid cell
    // of the next row and race with the invocation that owns it.
    if (x >= result_width || y >= result_height) {
        return;
    }

    var total_sum = 0.0;
    for (var i = 0u; i < template_width; i++) {
        for (var j = 0u; j < template_height; j++) {
            let input_idx = (y + j) * input_width + (x + i);
            let template_idx = j * template_width + i;

            total_sum += abs(input_buf[input_idx] - template_buf[template_idx]);
        }
    }

    result_buf[y * result_width + x] = total_sum;
}

// Sum of squared differences: one invocation scores one template position (x, y).
@compute
@workgroup_size(16, 16, 1)
fn main_ssd(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let x = global_id.x;
    let y = global_id.y;

    let input_width = uniforms.input_width;
    let template_width = uniforms.template_width;
    let template_height = uniforms.template_height;

    let result_width = input_width - template_width + 1u;
    let result_height = uniforms.input_height - template_height + 1u;

    // See main_sad: out-of-grid invocations must not touch result_buf.
    if (x >= result_width || y >= result_height) {
        return;
    }

    var total_sum = 0.0;
    for (var i = 0u; i < template_width; i++) {
        for (var j = 0u; j < template_height; j++) {
            let input_idx = (y + j) * input_width + (x + i);
            let template_idx = j * template_width + i;

            // Multiply instead of pow(): pow() is not defined for a negative base,
            // and the difference is negative whenever the template pixel is brighter.
            let diff = input_buf[input_idx] - template_buf[template_idx];
            total_sum += diff * diff;
        }
    }

    result_buf[y * result_width + x] = total_sum;
}

--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
//! GPU-accelerated template matching.
//!
Faster alternative to [imageproc::template_matching](https://docs.rs/imageproc/latest/imageproc/template_matching/index.html). 4 | 5 | #![deny(clippy::all)] 6 | #![allow(dead_code)] 7 | #![allow(unused_variables)] 8 | 9 | use std::{borrow::Cow, mem::size_of}; 10 | use wgpu::util::DeviceExt; 11 | 12 | #[derive(Copy, Clone, Debug, PartialEq)] 13 | pub enum MatchTemplateMethod { 14 | SumOfAbsoluteDifferences, 15 | SumOfSquaredDifferences, 16 | } 17 | 18 | /// Slides a template over the input and scores the match at each point using the requested method. 19 | /// 20 | /// This is a shorthand for: 21 | /// ```ignore 22 | /// let mut matcher = TemplateMatcher::new(); 23 | /// matcher.match_template(input, template, method); 24 | /// matcher.wait_for_result().unwrap() 25 | /// ``` 26 | /// You can use [find_extremes] to find minimum and maximum values, and their locations in the result image. 27 | pub fn match_template<'a>( 28 | input: impl Into>, 29 | template: impl Into>, 30 | method: MatchTemplateMethod, 31 | ) -> Image<'static> { 32 | let mut matcher = TemplateMatcher::new(); 33 | matcher.match_template(input, template, method); 34 | matcher.wait_for_result().unwrap() 35 | } 36 | 37 | /// Finds the smallest and largest values and their locations in an image. 
pub fn find_extremes(input: &Image<'_>) -> Extremes {
    let mut min_value = f32::MAX;
    let mut min_value_location = (0, 0);
    let mut max_value = f32::MIN;
    let mut max_value_location = (0, 0);

    // Index math is done in usize: `y * width + x` in u32 can overflow for large images.
    let width = input.width as usize;

    for y in 0..input.height {
        let row_start = y as usize * width;

        for x in 0..input.width {
            let value = input.data[row_start + x as usize];

            // Strict comparisons keep the first occurrence in row-major order.
            if value < min_value {
                min_value = value;
                min_value_location = (x, y);
            }

            if value > max_value {
                max_value = value;
                max_value_location = (x, y);
            }
        }
    }

    Extremes {
        min_value,
        max_value,
        min_value_location,
        max_value_location,
    }
}

/// A single-channel `f32` image stored row by row, either borrowed or owned.
#[derive(Clone, Debug)]
pub struct Image<'a> {
    /// Pixel values; pixel `(x, y)` is read from index `y * width + x`.
    pub data: Cow<'a, [f32]>,
    /// Width in pixels.
    pub width: u32,
    /// Height in pixels.
    pub height: u32,
}

impl<'a> Image<'a> {
    /// Wraps `data` (borrowed or owned) as an image of the given dimensions.
    ///
    /// The length of `data` is not checked against `width * height`.
    pub fn new(data: impl Into<Cow<'a, [f32]>>, width: u32, height: u32) -> Self {
        Self {
            data: data.into(),
            width,
            height,
        }
    }
}

/// Borrows the pixels of an `f32` luma [`image::ImageBuffer`] without copying them.
#[cfg(feature = "image")]
impl<'a> From<&'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>> for Image<'a> {
    fn from(img: &'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>) -> Self {
        Self {
            data: Cow::Borrowed(img),
            width: img.width(),
            height: img.height(),
        }
    }
}

/// Smallest and largest values of an image with their `(x, y)` locations,
/// as returned by [find_extremes].
#[derive(Copy, Clone, Debug)]
pub struct Extremes {
    pub min_value: f32,
    pub max_value: f32,
    pub min_value_location: (u32, u32),
    pub max_value_location: (u32, u32),
}

/// Host-side mirror of the `Uniforms` struct in `shaders/matching.wgsl`.
/// The field order must match the shader declaration.
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
struct ShaderUniforms {
    input_width: u32,
    input_height: u32,
    template_width: u32,
    template_height: u32,
}

/// Reusable GPU template matcher.
///
/// Holds the wgpu device, the compiled shader and the buffers from the previous
/// run so that repeated matches with unchanged image sizes reuse them.
pub struct TemplateMatcher {
    instance: wgpu::Instance,
    adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
    shader: wgpu::ShaderModule,
    bind_group_layout: wgpu::BindGroupLayout,
    pipeline_layout: wgpu::PipelineLayout,

    // Pipeline cache: rebuilt only when the requested method changes.
    last_pipeline: Option<wgpu::ComputePipeline>,
    last_method: Option<MatchTemplateMethod>,

    // Sizes used by the previous run, to detect when buffers must be recreated.
    last_input_size: (u32, u32),
    last_template_size: (u32, u32),
    last_result_size: (u32, u32),

    uniform_buffer: wgpu::Buffer,
    input_buffer: Option<wgpu::Buffer>,
    template_buffer: Option<wgpu::Buffer>,
    result_buffer: Option<wgpu::Buffer>,
    staging_buffer: Option<wgpu::Buffer>,
    bind_group: Option<wgpu::BindGroup>,

    // True between a submitted match and the collection of its result.
    matching_ongoing: bool,
}

impl Default for TemplateMatcher {
    fn default() -> Self {
        Self::new()
    }
}

impl TemplateMatcher {
    /// Creates a matcher on a high-performance adapter and compiles the matching shader.
    ///
    /// # Panics
    /// Panics if no adapter is found or the device request fails.
    pub fn new() -> Self {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            dx12_shader_compiler: Default::default(),
        });

        let adapter = pollster::block_on(async {
            instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .expect("Adapter request failed")
        });

        let (device, queue) = pollster::block_on(async {
            adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: None,
                        features: wgpu::Features::empty(),
                        limits: wgpu::Limits::default(),
                    },
                    None,
                )
                .await
                .expect("Device request failed")
        });

        let shader = device.create_shader_module(wgpu::include_wgsl!("../shaders/matching.wgsl"));

        // All four bindings are compute-visible buffers that differ only in
        // binding index and buffer type.
        let buffer_entry = |binding: u32, ty: wgpu::BufferBindingType| wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        };

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                // 0: input pixels, 1: template pixels, 2: result scores, 3: uniforms
                buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                buffer_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
                buffer_entry(3, wgpu::BufferBindingType::Uniform),
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("uniform_buffer"),
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            size: size_of::<ShaderUniforms>() as _,
            mapped_at_creation: false,
        });

        Self {
            instance,
            adapter,
            device,
            queue,
            shader,
            pipeline_layout,
            bind_group_layout,
            last_pipeline: None,
            last_method: None,
            last_input_size: (0, 0),
            last_template_size: (0, 0),
            last_result_size: (0, 0),
            uniform_buffer,
            input_buffer: None,
            template_buffer: None,
            result_buffer: None,
            staging_buffer: None,
            bind_group: None,
            matching_ongoing: false,
        }
    }

    /// Waits for the latest [match_template] execution and returns the result.
    /// Returns [None] if no matching was started.
///
/// Blocks until the GPU work has finished. If the staging buffer cannot be
/// mapped, an all-zero image of the expected result size is returned.
pub fn wait_for_result(&mut self) -> Option<Image<'static>> {
    if !self.matching_ongoing {
        return None;
    }
    self.matching_ongoing = false;

    let (result_width, result_height) = self.last_result_size;

    let staging_buffer = self
        .staging_buffer
        .as_ref()
        .expect("staging buffer exists whenever a match is ongoing");
    let buffer_slice = staging_buffer.slice(..);

    // The map callback fires from `poll`; its outcome is handed back through a channel.
    let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

    self.device.poll(wgpu::Maintain::Wait);

    let result: Vec<f32> = match pollster::block_on(receiver.receive()) {
        Some(Ok(())) => {
            let data = buffer_slice.get_mapped_range();
            let values = bytemuck::cast_slice(&data).to_vec();
            // The mapped view must be dropped before the buffer can be unmapped.
            drop(data);
            staging_buffer.unmap();
            values
        }
        _ => vec![0.0; result_width as usize * result_height as usize],
    };

    Some(Image::new(result, result_width, result_height))
}

/// Slides a template over the input and scores the match at each point using the requested method.
/// To get the result of the matching, call [wait_for_result].
///
/// The work is only submitted to the GPU here. An uncollected result from a
/// previous call is waited for and discarded first.
///
/// # Panics
/// Panics if the template is empty or larger than the input in either dimension.
pub fn match_template<'a>(
    &mut self,
    input: impl Into<Image<'a>>,
    template: impl Into<Image<'a>>,
    method: MatchTemplateMethod,
) {
    if self.matching_ongoing {
        // Discard previous result if not collected.
        let _ = self.wait_for_result();
    }

    let input = input.into();
    let template = template.into();

    // Without this check the result-size subtraction below underflows.
    assert!(
        template.width >= 1
            && template.height >= 1
            && template.width <= input.width
            && template.height <= input.height,
        "template ({}x{}) must be non-empty and fit inside the input ({}x{})",
        template.width,
        template.height,
        input.width,
        input.height,
    );

    if self.last_pipeline.is_none() || self.last_method != Some(method) {
        self.last_method = Some(method);

        let entry_point = match method {
            MatchTemplateMethod::SumOfAbsoluteDifferences => "main_sad",
            MatchTemplateMethod::SumOfSquaredDifferences => "main_ssd",
        };

        self.last_pipeline = Some(self.device.create_compute_pipeline(
            &wgpu::ComputePipelineDescriptor {
                label: None,
                layout: Some(&self.pipeline_layout),
                module: &self.shader,
                entry_point,
            },
        ));
    }

    let input_size = (input.width, input.height);
    let template_size = (template.width, template.height);

    let input_changed = self.input_buffer.is_none() || self.last_input_size != input_size;
    let template_changed =
        self.template_buffer.is_none() || self.last_template_size != template_size;
    let buffers_changed = input_changed || template_changed;

    Self::upload_storage_buffer(
        &self.device,
        &self.queue,
        &mut self.input_buffer,
        "input_buffer",
        &input.data,
        input_changed,
    );
    Self::upload_storage_buffer(
        &self.device,
        &self.queue,
        &mut self.template_buffer,
        "template_buffer",
        &template.data,
        template_changed,
    );
    self.last_input_size = input_size;
    self.last_template_size = template_size;

    if buffers_changed {
        // The uniforms carry both sizes, so they are refreshed when either one
        // changes. Previously only a template size change triggered the upload,
        // which left a stale input size in the shader.
        self.queue.write_buffer(
            &self.uniform_buffer,
            0,
            bytemuck::cast_slice(&[ShaderUniforms {
                input_width: input.width,
                input_height: input.height,
                template_width: template.width,
                template_height: template.height,
            }]),
        );
    }

    let result_width = input.width - template.width + 1;
    let result_height = input.height - template.height + 1;
    // Widen before multiplying: the pixel count can exceed u32::MAX.
    let result_buf_size =
        u64::from(result_width) * u64::from(result_height) * size_of::<f32>() as u64;

    if buffers_changed {
        self.last_result_size = (result_width, result_height);

        self.result_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("result_buffer"),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            size: result_buf_size,
            mapped_at_creation: false,
        }));

        self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging_buffer"),
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            size: result_buf_size,
            mapped_at_creation: false,
        }));

        // The bind group references the buffers themselves, so it is rebuilt
        // whenever any of them is recreated.
        self.bind_group = Some(self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: self.input_buffer.as_ref().unwrap().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: self.template_buffer.as_ref().unwrap().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: self.result_buffer.as_ref().unwrap().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: self.uniform_buffer.as_entire_binding(),
                },
            ],
        }));
    }

    let mut encoder = self
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("encoder"),
        });

    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("compute_pass"),
        });
        compute_pass.set_pipeline(self.last_pipeline.as_ref().unwrap());
        compute_pass.set_bind_group(0, self.bind_group.as_ref().unwrap(), &[]);
        // One invocation per result pixel, rounded up to whole 16x16 workgroups.
        compute_pass.dispatch_workgroups(
            (result_width as f32 / 16.0).ceil() as u32,
            (result_height as f32 / 16.0).ceil() as u32,
            1,
        );
    }

    // Copy the scores into the mappable staging buffer read by `wait_for_result`.
    encoder.copy_buffer_to_buffer(
        self.result_buffer.as_ref().unwrap(),
        0,
        self.staging_buffer.as_ref().unwrap(),
        0,
        result_buf_size,
    );

    self.queue.submit(std::iter::once(encoder.finish()));
    self.matching_ongoing = true;
}

/// Puts `contents` into the storage buffer held in `slot`: creates a new buffer
/// when `recreate` is set or the slot is empty, otherwise rewrites the existing one.
fn upload_storage_buffer(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    slot: &mut Option<wgpu::Buffer>,
    label: &str,
    contents: &[f32],
    recreate: bool,
) {
    if recreate || slot.is_none() {
        *slot = Some(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(contents),
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            }),
        );
    } else if let Some(buffer) = slot.as_ref() {
        queue.write_buffer(buffer, 0, bytemuck::cast_slice(contents));
    }
}
} // closes `impl TemplateMatcher`

--------------------------------------------------------------------------------