├── .github └── workflows │ └── actions.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── build.bat ├── js └── wasm_nn.js ├── pic └── wasm_nn_6arm.png ├── run.bat ├── src ├── data.rs ├── lib.rs └── nn.rs └── wasm_nn.html /.github/workflows/actions.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: windows-latest 8 | steps: 9 | - uses: actions/checkout@v1 10 | - name: "WASM build" 11 | run: | 12 | rustup target add wasm32-unknown-unknown 13 | cargo build --target=wasm32-unknown-unknown 14 | - name: "WASM test" 15 | run: | 16 | rustup target add wasm32-unknown-unknown 17 | cargo test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | *.wasm 3 | Cargo.lock -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wasm_nn" 3 | version = "0.1.0" 4 | authors = ["Donough Liu "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | [lib] 9 | crate-type = ["cdylib"] 10 | 11 | [dependencies] 12 | ndarray = "0.13.1" 13 | ndarray-rand = "0.11.0" 14 | rand = "0.7.3" 15 | rand_core = "0.5.1" 16 | once_cell = "1.4.0" 17 | #console_error_panic_hook = "0.1.6" but it seems to use the wasm_bindgen which is not good as this is a minimal dependency project -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 2 | Version 2, December 2004 3 | 4 | Copyright (C) 2004 Sam Hocevar 5 | 6 | Everyone is permitted to copy and distribute verbatim or 
modified 7 | copies of this license document, and changing it is allowed as long 8 | as the name is changed. 9 | 10 | DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE 11 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 12 | 13 | 0. You just DO WHAT THE FUCK YOU WANT TO. 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Rust + WebAssembly + Neural Network 2 | 3 | ![demo pic](./pic/wasm_nn_6arm.png) 4 | 5 | Try to run Neural Network on web browser. 6 | 7 | **Attention: you need a http server to run locally. Because cross-origin requests are not supported for the file protocol scheme.** 8 | 9 | Demo page: https://ldm0.com/wasm_nn.html -------------------------------------------------------------------------------- /build.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | cargo build --release --target=wasm32-unknown-unknown 3 | copy target\wasm32-unknown-unknown\release\wasm_nn.wasm .\wasm\wasm_nn.wasm -------------------------------------------------------------------------------- /js/wasm_nn.js: -------------------------------------------------------------------------------- 1 | // Ratio: [0., 1.) 
2 | function hue(ratio) { 3 | let rgb = ""; 4 | let hue = 6 * ratio; 5 | let integer_part = Math.floor(hue); 6 | let fractal_part = Math.round((hue - integer_part) * 255); 7 | switch (integer_part) { 8 | case 0: rgb = "#" + "FF" + zero_padding(fractal_part) + "00"; break; 9 | case 1: rgb = "#" + zero_padding(255 - fractal_part) + "FF" + "00"; break; 10 | case 2: rgb = "#" + "00" + "FF" + zero_padding(fractal_part); break; 11 | case 3: rgb = "#" + "00" + zero_padding(255 - fractal_part) + "FF"; break; 12 | case 4: rgb = "#" + zero_padding(fractal_part) + "00" + "FF"; break; 13 | case 5: rgb = "#" + "FF" + "00" + zero_padding(255 - fractal_part); break; 14 | } 15 | return rgb; 16 | } 17 | 18 | // Number to zero-padding hex string 19 | function zero_padding(num) { 20 | let result = Math.round(num).toString(16); 21 | if (result.length < 2) 22 | result = "0" + result; 23 | return result; 24 | } 25 | 26 | async function main(){ 27 | const canvas_width = 256; 28 | const canvas_height = 256; 29 | const canvas_buffer_size = canvas_width * canvas_height * 4; 30 | const data_span_radius = 1.; 31 | 32 | const canvas = document.getElementById("main_canvas"); 33 | const canvas_context = canvas.getContext("2d"); 34 | 35 | const control_button = document.getElementById("control_button"); 36 | const loss_reveal = document.getElementById("loss_reveal"); 37 | const input_data_spin_span = document.getElementById("input_data_spin_span"); 38 | const input_data_num = document.getElementById("input_data_num"); 39 | const input_num_classes = document.getElementById("input_num_classes"); 40 | const input_data_gen_rand_max = document.getElementById("input_data_gen_rand_max"); 41 | const input_network_gen_rand_max = document.getElementById("input_network_gen_rand_max"); 42 | const input_fc_size = document.getElementById("input_fc_size"); 43 | const input_descent_rate = document.getElementById("input_descent_rate"); 44 | const input_regular_rate = document.getElementById("input_regular_rate"); 
45 | 46 | function get_settings() { 47 | let settings = [ 48 | parseFloat(input_data_spin_span.value), 49 | parseInt(input_data_num.value), 50 | parseInt(input_num_classes.value), 51 | parseFloat(input_data_gen_rand_max.value), 52 | parseFloat(input_network_gen_rand_max.value), 53 | parseInt(input_fc_size.value), 54 | parseFloat(input_descent_rate.value), 55 | parseFloat(input_regular_rate.value), 56 | ]; 57 | return settings; 58 | } 59 | 60 | function envs() { 61 | // For debug 62 | function log_u64(x) { 63 | console.log(x); 64 | } 65 | 66 | function draw_point(x, y, label) { 67 | canvas_context.beginPath(); 68 | canvas_context.arc(x, y, 2, 0, 2 * Math.PI); 69 | canvas_context.fillStyle = hue(label) + "7f"; 70 | canvas_context.fill(); 71 | } 72 | let env = { 73 | log_u64, 74 | draw_point, 75 | }; 76 | return env; 77 | } 78 | 79 | const kernel_stream = await fetch("../wasm/wasm_nn.wasm"); 80 | const kernel = await WebAssembly.instantiateStreaming(kernel_stream, { env: envs()}); 81 | 82 | const {alloc: kernel_alloc, free: kernel_free} = kernel.instance.exports; 83 | const { 84 | init: kernel_init, 85 | train: kernel_train, 86 | draw_prediction: kernel_draw_prediction, 87 | draw_points: kernel_draw_points 88 | } = kernel.instance.exports; 89 | const {memory} = kernel.instance.exports; 90 | 91 | // Alloc graphic buffer 92 | // Should not freed because you don't know when the drawing completes 93 | // Maybe not completed forever... 
94 | //kernel_free(canvas_buffer_ptr, buffer_size); 95 | const canvas_buffer_ptr = kernel_alloc(canvas_buffer_size); 96 | 97 | function draw_frame() { 98 | // multiply 1.1 for spadding 99 | kernel_draw_prediction(canvas_buffer_ptr, canvas_width, canvas_height, data_span_radius * 2); 100 | const canvas_buffer_array = new Uint8ClampedArray(memory.buffer, canvas_buffer_ptr, canvas_buffer_size); 101 | const canvas_image_data = new ImageData(canvas_buffer_array, canvas_width, canvas_height) 102 | canvas_context.putImageData(canvas_image_data, 0, 0); 103 | 104 | kernel_draw_points(canvas_width, canvas_height, data_span_radius * 2); 105 | } 106 | 107 | function nninit(settings) { 108 | // Gen data for training. Check source code of kernel for parameter meaning 109 | kernel_init( 110 | data_span_radius, 111 | settings[0], 112 | settings[1], 113 | settings[2], 114 | settings[3], 115 | settings[4], 116 | settings[5], 117 | settings[6], 118 | settings[7], 119 | ); 120 | // draw a fram to avoid blank canvas 121 | draw_frame(); 122 | } 123 | 124 | nninit(get_settings()); 125 | 126 | { 127 | let run = false; 128 | 129 | { 130 | let counter = 0; 131 | function nnloop() { 132 | if (run) { 133 | let loss = kernel_train(); 134 | if (counter >= 10) { 135 | counter = 0; 136 | loss_reveal.innerText = "loss: " + loss; 137 | window.requestAnimationFrame(draw_frame); 138 | } 139 | setTimeout(nnloop, 0); 140 | ++counter; 141 | } 142 | } 143 | } 144 | 145 | function nnstart() { 146 | run = true; 147 | nnloop(); 148 | } 149 | 150 | function nnstop() { 151 | run = false; 152 | } 153 | } 154 | 155 | { 156 | let run = false; 157 | let current_settings = get_settings(); 158 | 159 | control_button.onclick = () => { 160 | if (run) { 161 | run = false; 162 | control_button.innerText = "run"; 163 | nnstop(); 164 | } else { 165 | run = true; 166 | control_button.innerText = "stop"; 167 | let new_settings = get_settings(); 168 | if (JSON.stringify(current_settings) !== JSON.stringify(new_settings)) { 
169 | current_settings = new_settings; 170 | nninit(current_settings); 171 | } 172 | nnstart(); 173 | } 174 | } 175 | } 176 | } 177 | 178 | main(); -------------------------------------------------------------------------------- /pic/wasm_nn_6arm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ldm0/wasm_nn/da6149a8e1d4c22a1c1496058401ea44aeadfc4d/pic/wasm_nn_6arm.png -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | :: just use a http server because wasm module cannot be get with file request 2 | emrun . -------------------------------------------------------------------------------- /src/data.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use ndarray::{stack, Array, Array1, Array2, Axis}; // for matrices 3 | 4 | use ndarray_rand::rand_distr::StandardNormal; // for randomness 5 | use ndarray_rand::RandomExt; // for randomness 6 | use rand::rngs::SmallRng; 7 | use rand::SeedableRng; // for from_seed // for randomness 8 | 9 | use std::f32::consts::PI; // for math functions 10 | 11 | /// point data with labels 12 | #[derive(Default)] 13 | pub struct Data { 14 | pub points: Array2, // points position 15 | pub labels: Array1, // points labels 16 | } 17 | 18 | impl Data { 19 | // num_sample: num of data for each label class 20 | // radius: radius of the circle of data points position 21 | // span: each data arm ratate span 22 | pub fn init( 23 | &mut self, 24 | num_classes: u32, 25 | num_samples: u32, 26 | radius: f32, 27 | span: f32, 28 | rand_max: f32, 29 | ) { 30 | // For array creating convenience 31 | let num_classes = num_classes as usize; 32 | let num_samples = num_samples as usize; 33 | 34 | let num_data = num_classes * num_samples; 35 | self.points = Array::zeros((num_data, 2)); 36 | self.labels = 
Array::zeros(num_data); 37 | for i in 0..num_classes { 38 | let rho = Array::linspace(0f32, radius, num_samples); 39 | let begin = i as f32 * (2f32 * PI / num_classes as f32); 40 | 41 | let seed = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; 42 | let mut rng = SmallRng::from_seed(seed); 43 | let theta = Array::linspace(begin, begin - span, num_samples) 44 | // will be changed later to use span to generate randomness to avoid points flickering 45 | + Array::::random_using(num_samples, StandardNormal, &mut rng) * rand_max; 46 | 47 | let xs = (theta.mapv(f32::sin) * &rho) 48 | .into_shape((num_samples, 1)) 49 | .unwrap(); 50 | let ys = (theta.mapv(f32::cos) * &rho) 51 | .into_shape((num_samples, 1)) 52 | .unwrap(); 53 | let mut class_points = self 54 | .points 55 | .slice_mut(s![i * num_samples..(i + 1) * num_samples, ..]); 56 | class_points.assign(&stack![Axis(1), xs, ys]); 57 | let mut class_labels = self 58 | .labels 59 | .slice_mut(s![i * num_samples..(i + 1) * num_samples]); 60 | class_labels.fill(i as u32); 61 | // or: 62 | //class_labels.assign(&(Array::ones(num_samples) * i)); 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod data; 2 | mod nn; 3 | 4 | use std::mem; 5 | use std::slice; 6 | //use std::os::raw::{/*c_double, c_int, */c_void}; // for js functions imports 7 | use once_cell::sync::Lazy; 8 | use std::sync::Mutex; // for lazy_static // for global variables 9 | 10 | use ndarray::prelude::*; 11 | use ndarray::{array, Array, Array1, Array3, Axis, Zip}; 12 | 13 | use data::Data; 14 | use nn::Network; 15 | 16 | #[derive(Default)] 17 | struct MetaData { 18 | fc_size: u32, 19 | num_classes: u32, 20 | descent_rate: f32, 21 | regular_rate: f32, 22 | } 23 | 24 | #[derive(Default)] 25 | struct CriticalSection(MetaData, Data, Network); 26 | 27 | // Imported js functions 28 | extern "C" { 29 | // 
for debug 30 | fn log_u64(num: u32); 31 | // for data pointer draw 32 | // x,y: the offset from upper left corner 33 | // label: a fractal which represents the position current label is in total 34 | // position range 35 | fn draw_point(x: u32, y: u32, label_ratio: f32); 36 | } 37 | 38 | static DATA: Lazy> = Lazy::new(|| Mutex::default()); 39 | 40 | #[no_mangle] 41 | // This function returns the offset of the allocated buffer in wasm memory 42 | pub fn alloc(size: u32) -> *mut u8 { 43 | let mut buffer: Vec = Vec::with_capacity(size as usize); 44 | let buffer_ptr = buffer.as_mut_ptr(); 45 | mem::forget(buffer); 46 | buffer_ptr 47 | } 48 | 49 | #[no_mangle] 50 | pub fn free(buffer_ptr: *mut u8, size: u32) { 51 | let _ = unsafe { Vec::from_raw_parts(buffer_ptr, 0, size as usize) }; 52 | } 53 | 54 | #[no_mangle] 55 | pub fn init( 56 | data_radius: f32, 57 | data_spin_span: f32, 58 | data_num: u32, 59 | num_classes: u32, 60 | data_gen_rand_max: f32, 61 | network_gen_rand_max: f32, 62 | fc_size: u32, 63 | descent_rate: f32, 64 | regular_rate: f32, 65 | ) { 66 | // Thanks rust compiler :-/ 67 | let ref mut tmp = *DATA.lock().unwrap(); 68 | let CriticalSection(metadata, data, network) = tmp; 69 | 70 | metadata.fc_size = fc_size; 71 | metadata.num_classes = num_classes; 72 | metadata.descent_rate = descent_rate; 73 | metadata.regular_rate = regular_rate; 74 | 75 | // Num of each data class is the same 76 | data.init( 77 | num_classes, 78 | data_num / num_classes, 79 | data_radius, 80 | data_spin_span, 81 | data_gen_rand_max, 82 | ); 83 | 84 | // Input of this network is two dimension points 85 | // output label is sparsed num_classes integer 86 | const PLANE_DIMENSION: u32 = 2; 87 | network.init(PLANE_DIMENSION, fc_size, num_classes, network_gen_rand_max); 88 | } 89 | 90 | #[no_mangle] 91 | pub fn train() -> f32 { 92 | let ref mut tmp = *DATA.lock().unwrap(); 93 | // Jesus, thats magic 94 | let CriticalSection(ref metadata, ref data, ref mut network) = *tmp; 95 | 96 | let 
regular_rate = metadata.regular_rate; 97 | let descent_rate = metadata.descent_rate; 98 | 99 | let (fc_layer, softmax) = network.forward_propagation(&data.points); 100 | let (dw1, db1, dw2, db2) = network.back_propagation( 101 | &data.points, 102 | &fc_layer, 103 | &softmax, 104 | &data.labels, 105 | regular_rate, 106 | ); 107 | let loss = network.loss(&softmax, &data.labels, regular_rate); 108 | network.descent(&dw1, &db1, &dw2, &db2, descent_rate); 109 | 110 | let (data_loss, regular_loss) = loss; 111 | data_loss + regular_loss 112 | } 113 | 114 | // Plot classified backgroud to canvas 115 | // span_least The least span of area should be drawn to canvas(because usually the canvas is not square) 116 | #[no_mangle] 117 | pub fn draw_prediction(canvas: *mut u8, width: u32, height: u32, span_least: f32) { 118 | // assert!(span_least > 0f32); 119 | let width = width as usize; 120 | let height = height as usize; 121 | 122 | // `data` will be used to draw data points 123 | let ref tmp = *DATA.lock().unwrap(); 124 | let CriticalSection(metadata, _, network) = tmp; 125 | 126 | let num_classes = metadata.num_classes as usize; 127 | 128 | let r: Array1 = Array::linspace(0f32, 200f32, num_classes); 129 | let g: Array1 = Array::linspace(0f32, 240f32, num_classes); 130 | let b: Array1 = Array::linspace(0f32, 255f32, num_classes); 131 | 132 | let span_per_pixel = span_least / width.min(height) as f32; 133 | let span_height = height as f32 * span_per_pixel; 134 | let span_width = width as f32 * span_per_pixel; 135 | 136 | let width_max = span_width / 2f32; 137 | let width_min = -span_width / 2f32; 138 | let height_max = span_height / 2f32; 139 | let height_min = -span_height / 2f32; 140 | 141 | let x_axis: Array1 = Array::linspace(width_min, width_max, width); 142 | let y_axis: Array1 = Array::linspace(height_min, height_max, height); 143 | 144 | // coordination 145 | let mut grid: Array3 = Array::zeros((height, width, 2)); 146 | for y in 0..height { 147 | for x in 0..width { 
148 | let coord = array![x_axis[[x]], y_axis[[y]]]; 149 | let mut slice = grid.slice_mut(s![y, x, ..]); 150 | slice.assign(&coord); 151 | } 152 | } 153 | 154 | let xys = grid.into_shape((height * width, 2)).unwrap(); 155 | let (_, softmax) = network.forward_propagation(&xys); 156 | let mut labels: Array1 = Array::zeros(height * width); 157 | for (y, row) in softmax.axis_iter(Axis(0)).enumerate() { 158 | let mut maxx = 0 as usize; 159 | let mut max = row[[0]]; 160 | for (x, col) in row.iter().enumerate() { 161 | if *col > max { 162 | maxx = x; 163 | max = *col; 164 | } 165 | } 166 | labels[[y]] = maxx; 167 | } 168 | let grid_label = labels.into_shape((height, width)).unwrap(); 169 | 170 | let canvas_size = width * height * 4; 171 | let canvas: &mut [u8] = unsafe { slice::from_raw_parts_mut(canvas, canvas_size) }; 172 | for y in 0..height { 173 | for x in 0..width { 174 | // assume rgba 175 | canvas[4 * (y * width + x) + 0] = r[[grid_label[[y, x]]]] as u8; 176 | canvas[4 * (y * width + x) + 1] = g[[grid_label[[y, x]]]] as u8; 177 | canvas[4 * (y * width + x) + 2] = b[[grid_label[[y, x]]]] as u8; 178 | canvas[4 * (y * width + x) + 3] = 0xFF as u8; 179 | } 180 | } 181 | } 182 | 183 | // check parameters for function below which draws predictions 184 | #[no_mangle] 185 | pub fn draw_points(width: u32, height: u32, span_least: f32) { 186 | let ref tmp = *DATA.lock().unwrap(); 187 | let CriticalSection(metadata, data, _) = tmp; 188 | let num_classes = metadata.num_classes as f32; 189 | 190 | let pixel_per_span = width.min(height) as f32 / span_least; 191 | let labels = &data.labels; 192 | let points = &data.points; 193 | let points_x = points.index_axis(Axis(1), 0); 194 | let points_y = points.index_axis(Axis(1), 1); 195 | Zip::from(labels) 196 | .and(points_x) 197 | .and(points_y) 198 | .apply(|&label, &x, &y| { 199 | // Assume data position is limited in: 200 | // [-data_radius - data_rand_max, data_radius + data_rand_max] 201 | let x = (x * pixel_per_span) as i64 + 
width as i64 / 2; 202 | let y = (y * pixel_per_span) as i64 + height as i64 / 2; 203 | 204 | // if points can show in canvas 205 | if !(x >= width as i64 || x < 0 || y >= height as i64 || y < 0) { 206 | // floor 207 | let x = x as u32; 208 | let y = y as u32; 209 | let label_ratio = label as f32 / num_classes; 210 | unsafe { 211 | draw_point(x, y, label_ratio); 212 | } 213 | } 214 | }); 215 | } 216 | 217 | #[cfg(test)] 218 | mod kernel_test { 219 | use super::*; 220 | 221 | static POINT_DRAW_TIMES: Lazy> = Lazy::new(|| Mutex::new(0)); 222 | 223 | // Override the extern functions 224 | #[no_mangle] 225 | extern "C" fn draw_point(_: u32, _: u32, _: f32) { 226 | *POINT_DRAW_TIMES.lock().unwrap() += 1; 227 | } 228 | 229 | use std::f32::consts::PI; // for math functions 230 | 231 | const DATA_GEN_RADIUS: f32 = 1f32; 232 | const SPIN_SPAN: f32 = PI; 233 | const NUM_CLASSES: u32 = 3; 234 | const DATA_NUM: u32 = 300; 235 | const FC_SIZE: u32 = 100; 236 | const REGULAR_RATE: f32 = 0.001f32; 237 | const DESCENT_RATE: f32 = 1f32; 238 | const DATA_GEN_RAND_MAX: f32 = 0.25f32; 239 | const NETWORK_GEN_RAND_MAX: f32 = 0.1f32; 240 | 241 | #[test] 242 | fn test_all() { 243 | init( 244 | DATA_GEN_RADIUS, 245 | SPIN_SPAN, 246 | DATA_NUM, 247 | NUM_CLASSES, 248 | DATA_GEN_RAND_MAX, 249 | NETWORK_GEN_RAND_MAX, 250 | FC_SIZE, 251 | DESCENT_RATE, 252 | REGULAR_RATE, 253 | ); 254 | let loss_before: f32 = train(); 255 | for _ in 0..50 { 256 | let loss = train(); 257 | assert!(loss < loss_before * 1.1f32); 258 | } 259 | } 260 | 261 | #[test] 262 | fn test_buffer_allocation() { 263 | let buffer = alloc(114514); 264 | free(buffer, 114514); 265 | } 266 | 267 | #[test] 268 | fn test_draw_prediction() { 269 | init( 270 | DATA_GEN_RADIUS, 271 | SPIN_SPAN, 272 | DATA_NUM, 273 | NUM_CLASSES, 274 | DATA_GEN_RAND_MAX, 275 | NETWORK_GEN_RAND_MAX, 276 | FC_SIZE, 277 | DESCENT_RATE, 278 | REGULAR_RATE, 279 | ); 280 | let width = 100; 281 | let height = 100; 282 | let buffer = alloc(width * height * 4); 
283 | draw_prediction(buffer, width, height, 2f32); 284 | free(buffer, width * height * 4); 285 | } 286 | 287 | #[test] 288 | fn test_draw_points() { 289 | // Because cargo test is default multi-thread, put them together to avoid data_racing 290 | 291 | // span_least * 1.1 for padding 292 | 293 | init( 294 | DATA_GEN_RADIUS, 295 | SPIN_SPAN, 296 | DATA_NUM, 297 | NUM_CLASSES, 298 | DATA_GEN_RAND_MAX, 299 | NETWORK_GEN_RAND_MAX, 300 | FC_SIZE, 301 | DESCENT_RATE, 302 | REGULAR_RATE, 303 | ); 304 | 305 | // test small resolution drawing 306 | *POINT_DRAW_TIMES.lock().unwrap() = 0; 307 | draw_points(1, 1, DATA_GEN_RADIUS * 2f32 * 1.1f32); 308 | assert_eq!(DATA_NUM, *POINT_DRAW_TIMES.lock().unwrap()); 309 | 310 | // test tall screen drawing 311 | *POINT_DRAW_TIMES.lock().unwrap() = 0; 312 | draw_points(1, 100, DATA_GEN_RADIUS * 2f32 * 1.1f32); 313 | assert_eq!(DATA_NUM, *POINT_DRAW_TIMES.lock().unwrap()); 314 | 315 | // test flat screen drawing (width > height; was a duplicate of the tall-screen case above) 316 | *POINT_DRAW_TIMES.lock().unwrap() = 0; 317 | draw_points(100, 1, DATA_GEN_RADIUS * 2f32 * 1.1f32); 318 | assert_eq!(DATA_NUM, *POINT_DRAW_TIMES.lock().unwrap()); 319 | 320 | // test square screen drawing 321 | *POINT_DRAW_TIMES.lock().unwrap() = 0; 322 | draw_points(100, 100, DATA_GEN_RADIUS * 2f32 * 1.1f32); 323 | assert_eq!(DATA_NUM, *POINT_DRAW_TIMES.lock().unwrap()); 324 | 325 | // test huge screen drawing 326 | *POINT_DRAW_TIMES.lock().unwrap() = 0; 327 | draw_points(10000000, 1000000, DATA_GEN_RADIUS * 2f32 * 1.1f32); 328 | assert_eq!(DATA_NUM, *POINT_DRAW_TIMES.lock().unwrap()); 329 | } 330 | } 331 | -------------------------------------------------------------------------------- /src/nn.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, Array1, Array2, Axis, Zip}; // for matrices 2 | 3 | use ndarray_rand::rand_distr::StandardNormal; // for randomness 4 | use ndarray_rand::RandomExt; // for randomness 5 | use rand::rngs::SmallRng; 6 | use rand::SeedableRng; 
// for from_seed // for randomness 7 | 8 | /** 9 | * single layer neural network 10 | */ 11 | #[derive(Default)] 12 | pub struct Network { 13 | pub w1: Array2, 14 | pub b1: Array2, 15 | pub w2: Array2, 16 | pub b2: Array2, 17 | } 18 | 19 | impl Network { 20 | pub fn init(&mut self, input_size: u32, fc_size: u32, output_size: u32, rand_max: f32) { 21 | let input_size = input_size as usize; 22 | let fc_size = fc_size as usize; 23 | let output_size = output_size as usize; 24 | // according to rand::rngs/mod.rs line 121 25 | let seed = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; 26 | let mut rng = SmallRng::from_seed(seed); 27 | *self = Network { 28 | w1: Array::random_using((input_size, fc_size), StandardNormal, &mut rng) * rand_max, 29 | w2: Array::random_using((fc_size, output_size), StandardNormal, &mut rng) * rand_max, 30 | b1: Array::random_using((1, fc_size), StandardNormal, &mut rng) * rand_max, 31 | b2: Array::random_using((1, output_size), StandardNormal, &mut rng) * rand_max, 32 | /* random version but commented because strange behaviour of random in wasm leads to panic 33 | w1: Array::ones((input_size, fc_size)) * rand_max, 34 | w2: Array::ones((fc_size, output_size)) * rand_max, 35 | b1: Array::ones((1, fc_size)) * rand_max, 36 | b2: Array::ones((1, output_size)) * rand_max, 37 | */ 38 | } 39 | } 40 | 41 | pub fn descent( 42 | &mut self, 43 | dw1: &Array2, 44 | db1: &Array2, 45 | dw2: &Array2, 46 | db2: &Array2, 47 | descent_rate: f32, 48 | ) { 49 | let rate = descent_rate; 50 | self.w1 -= &(rate * dw1); 51 | self.b1 -= &(rate * db1); 52 | self.w2 -= &(rate * dw2); 53 | self.b2 -= &(rate * db2); 54 | } 55 | 56 | pub fn forward_propagation(&self, points: &Array2) -> (Array2, Array2) { 57 | let act1 = &points.dot(&self.w1) + &self.b1; 58 | let fc_layer = act1.mapv(|x| x.max(0f32)); // relu process 59 | let act2 = &fc_layer.dot(&self.w2) + &self.b2; 60 | let scores = act2; 61 | let exp_scores = scores.mapv(f32::exp); 62 | let softmax = 
&exp_scores / &exp_scores.sum_axis(Axis(1)).insert_axis(Axis(1)); 63 | // println!("{:#?}", softmax); 64 | (fc_layer, softmax) 65 | } 66 | 67 | pub fn loss( 68 | &self, 69 | softmax: &Array2, 70 | labels: &Array1, 71 | regular_rate: f32, 72 | ) -> (f32, f32) { 73 | let num_data = softmax.nrows(); 74 | let mut probs_correct: Array1 = Array::zeros(num_data); 75 | Zip::from(&mut probs_correct) 76 | .and(softmax.genrows()) 77 | .and(labels) 78 | .apply(|prob_correct, prob, &label| { 79 | *prob_correct = prob[label as usize]; 80 | }); 81 | let infos = probs_correct.mapv(|x| -f32::ln(x)); 82 | //println!("{:#?}", &probs_correct); 83 | let data_loss = infos.mean().unwrap(); 84 | let regular_loss = 85 | 0.5f32 * regular_rate * ((&self.w1 * &self.w1).sum() + (&self.w2 * &self.w2).sum()); 86 | //println!("data loss: {} regular loss: {}", data_loss, regular_loss); 87 | (data_loss, regular_loss) 88 | } 89 | 90 | pub fn back_propagation( 91 | &self, 92 | points: &Array2, 93 | fc_layer: &Array2, 94 | softmax: &Array2, 95 | labels: &Array1, 96 | regular_rate: f32, 97 | ) -> (Array2, Array2, Array2, Array2) { 98 | let num_data = softmax.nrows(); 99 | let mut dscores = softmax.clone(); 100 | for (i, mut dscore) in dscores.axis_iter_mut(Axis(0)).enumerate() { 101 | dscore[[labels[[i]] as usize]] -= 1f32; 102 | } 103 | dscores /= num_data as f32; 104 | let dact2 = dscores; 105 | let dfc_layer = dact2.dot(&self.w2.t()); 106 | let mut dact1 = dfc_layer.clone(); 107 | Zip::from(&mut dact1).and(fc_layer).apply(|act1, &fc| { 108 | if fc == 0f32 { 109 | *act1 = 0f32; 110 | } 111 | }); 112 | 113 | let dw2 = fc_layer.t().dot(&dact2) + regular_rate * &self.w2; 114 | let db2 = dact2.sum_axis(Axis(0)).insert_axis(Axis(0)); 115 | let dw1 = points.t().dot(&dact1) + regular_rate * &self.w1; 116 | let db1 = dact1.sum_axis(Axis(0)).insert_axis(Axis(0)); 117 | (dw1, db1, dw2, db2) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /wasm_nn.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | wasm with rust for single layer neural netowrk 6 | 7 | 8 | 9 |
10 | Single layer of fully connected neural network
11 | Powered by WebAssembly, Rust and ♥
12 | WebAssembly Module may need some time to load. Please be patient. :-)
13 | Source Code: https://github.com/ldm0/wasm_nn 14 |
15 |
16 | main_canvas 17 |
18 |
19 | 20 |
21 |
22 | loss: NaN 23 |
24 | 29 |
30 | arm rotate angle 31 |
32 |
33 | overall number of data points 34 |
35 |
36 | num of label 37 |
38 |
39 | maximum data random offset 40 |
41 |
42 | maximum network weight random offset 43 |
44 |
45 | size of full connected layer 46 |
47 |
48 | network training descent rate 49 |
50 |
51 | network regular rate 52 |
53 | 54 | --------------------------------------------------------------------------------