├── .gitignore ├── gui.png ├── .cargo └── config.toml ├── assets ├── road.png ├── Magero.ttf ├── agent.png ├── car-icon.png ├── end-point.png ├── enemy-red.png ├── flag-top.png ├── bound-truck.png ├── enemy-blue-1.png ├── enemy-blue-2.png ├── enemy-blue-3.png ├── enemy-red-2.png ├── enemy-truck.png ├── flag-bottom.png ├── enemy-yellow-1.png ├── enemy-yellow-2.png └── enemy-yellow-3.png ├── src ├── lib.cairo ├── lib.rs ├── resources.rs ├── configs.rs ├── rays.cairo ├── vehicle.cairo ├── population.rs ├── main.rs ├── model.cairo ├── model_test.cairo ├── enemy.rs ├── car.rs ├── dojo.rs ├── nn.rs ├── math.cairo ├── gui.rs ├── enemy.cairo └── racer.cairo ├── Scarb.toml ├── Cargo.toml ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | out 2 | target 3 | assets -------------------------------------------------------------------------------- /gui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/gui.png -------------------------------------------------------------------------------- /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.wasm32-unknown-unknown] 2 | runner = "wasm-server-runner" 3 | -------------------------------------------------------------------------------- /assets/road.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/road.png -------------------------------------------------------------------------------- /assets/Magero.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/Magero.ttf -------------------------------------------------------------------------------- /assets/agent.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/agent.png -------------------------------------------------------------------------------- /assets/car-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/car-icon.png -------------------------------------------------------------------------------- /assets/end-point.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/end-point.png -------------------------------------------------------------------------------- /assets/enemy-red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-red.png -------------------------------------------------------------------------------- /assets/flag-top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/flag-top.png -------------------------------------------------------------------------------- /assets/bound-truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/bound-truck.png -------------------------------------------------------------------------------- /assets/enemy-blue-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-blue-1.png -------------------------------------------------------------------------------- /assets/enemy-blue-2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-blue-2.png -------------------------------------------------------------------------------- /assets/enemy-blue-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-blue-3.png -------------------------------------------------------------------------------- /assets/enemy-red-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-red-2.png -------------------------------------------------------------------------------- /assets/enemy-truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-truck.png -------------------------------------------------------------------------------- /assets/flag-bottom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/flag-bottom.png -------------------------------------------------------------------------------- /assets/enemy-yellow-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-yellow-1.png -------------------------------------------------------------------------------- /assets/enemy-yellow-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-yellow-2.png -------------------------------------------------------------------------------- /assets/enemy-yellow-3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cartridge-gg/drive-ai/HEAD/assets/enemy-yellow-3.png -------------------------------------------------------------------------------- /src/lib.cairo: -------------------------------------------------------------------------------- 1 | mod enemy; 2 | 3 | mod math; 4 | mod model; 5 | mod racer; 6 | mod rays; 7 | mod vehicle; 8 | use vehicle::{Vehicle, VehicleTrait}; 9 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod car; 2 | pub mod configs; 3 | pub mod dojo; 4 | pub mod enemy; 5 | pub mod gui; 6 | pub mod nn; 7 | pub mod population; 8 | pub mod resources; 9 | 10 | pub use configs::*; 11 | pub use resources::*; 12 | -------------------------------------------------------------------------------- /Scarb.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "drive_ai" 3 | version = "0.1.0" 4 | cairo-version = "2.0.1" 5 | 6 | [cairo] 7 | sierra-replace-ids = true 8 | 9 | [dependencies] 10 | dojo = { git = "https://github.com/dojoengine/dojo" } 11 | cubit = { git = "https://github.com/tarrencev/cubit" } 12 | orion = { git = "https://github.com/danilowhk/orion" } 13 | 14 | [[target.dojo]] 15 | 16 | [tool.dojo.env] 17 | rpc_url = "http://localhost:5050/" 18 | 19 | # Default account for katana with seed = 0 20 | account_address = "0x03ee9e18edc71a6df30ac3aca2e0b02a198fbce19b7480a63a0d71cbd76652e0" 21 | private_key = "0x0300001800000000300000180000000000030000000000003006001800006600" 22 | -------------------------------------------------------------------------------- /src/resources.rs: -------------------------------------------------------------------------------- 1 | use bevy::prelude::*; 2 | 3 | #[derive(Resource, Default)] 4 | pub struct SimStats { 5 | pub num_cars_alive: usize, 6 | pub fitness: Vec, 7 | pub generation_count: u32, 8 | pub 
max_current_score: f32, 9 | } 10 | 11 | #[derive(Resource)] 12 | pub struct Settings { 13 | pub is_show_rays: bool, 14 | pub is_hide_rays_at_start: bool, 15 | pub start_next_generation: bool, 16 | pub restart_sim: bool, 17 | pub is_camera_follow: bool, 18 | } 19 | 20 | #[derive(Resource, Default)] 21 | pub struct BrainToDisplay(pub Vec>); 22 | 23 | #[derive(Resource)] 24 | pub struct MaxDistanceTravelled(pub f32); 25 | 26 | impl Default for Settings { 27 | fn default() -> Self { 28 | Self { 29 | is_show_rays: true, 30 | is_hide_rays_at_start: true, 31 | start_next_generation: false, 32 | restart_sim: false, 33 | is_camera_follow: true, 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "steering" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | bevy = { version = "0.10.1" } 10 | bevy-inspector-egui = "0.18.3" 11 | bevy_pancam = "0.8.0" 12 | bevy_prototype_debug_lines = "0.10.1" 13 | bevy_rapier2d = "0.21.0" 14 | bevy-tokio-tasks = "0.10" 15 | dojo-client = { git = "https://github.com/dojoengine/dojo", rev = "187a12e74ad1020d76a86a59315b55f9fb08891e" } 16 | eyre = "0.6" 17 | num = "0.4" 18 | rand = "0.8.5" 19 | starknet = "0.4.0" 20 | tokio = { version = "1", features = ["sync"] } 21 | url = "2.2.2" 22 | serde = { version = "1.0.130", features = ["derive"] } 23 | serde_json = "1.0.68" 24 | 25 | [workspace] 26 | resolver = "2" # Important! wgpu/Bevy needs this! 27 | 28 | # Enable a small amount of optimization in debug mode 29 | [profile.dev] 30 | opt-level = 1 31 | 32 | # Enable high optimizations for dependencies (incl. 
Bevy), but not for our code: 33 | [profile.dev.package."*"] 34 | opt-level = 3 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Bones-ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Drive AI 2 | 3 | Drive AI is an ambitious experiment pushing the boundaries of verifiable compute. The Road Fighter-inspired simulation environment is implemented with [Dojo](https://github.com/dojoengine/dojo), a provable game engine that enables zero-knowledge proofs to be generated attesting to the integrity of a computation.
4 | 5 | In the simulation, a car is controlled by a neural network and is tasked with navigating traffic in its environment. The car receives inputs from its sensors, passes them to its neural network, and outputs a control for the direction of the car. 6 | 7 | In this demo, a neural network is trained in the simulation environment off-chain. Once a model is defined, it can be exported and benchmarked in the provable simulation. All physics and neural network inference occur in real time, and zero-knowledge proofs of the computation are produced asynchronously. 8 | 9 | Built with [Dojo](https://github.com/dojoengine/dojo), [Rust](https://www.rust-lang.org/), and the [Bevy](https://bevyengine.org/) game engine. 10 | 11 | ![gui](/gui.png) 12 | 13 | ## Usage 14 | - Clone the repo 15 | ``` 16 | git clone git@github.com:cartridge-gg/drive-ai.git 17 | cd drive-ai 18 | ``` 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | - Run the simulation 29 | ``` 30 | cargo run 31 | ``` 32 | ## Configurations 33 | - The project config file is located at `src/configs.rs` 34 | 35 | ## Assets 36 | - [https://www.spriters-resource.com/nes/roadfighter/sheet/57232/](https://www.spriters-resource.com/nes/roadfighter/sheet/57232/) 37 | - Font - [https://code807.itch.io/magero](https://code807.itch.io/magero) 38 | 39 | ## Acknowledgements 40 | 41 | This game is based on the great work of the original Rust implementation found here: https://github.com/bones-ai/rust-drive-ai 42 | -------------------------------------------------------------------------------- /src/configs.rs: -------------------------------------------------------------------------------- 1 | use bevy::prelude::Color; 2 | 3 | /// Main 4 | pub const NUM_ROAD_TILES: u32 = 1; 5 | pub const ROAD_SPRITE_W: f32 = 160.0; 6 | pub const ROAD_SPRITE_H: f32 = 288.0; 7 | pub const NUM_ENEMY_CARS: u32 = 140; 8 | pub const SPRITE_SCALE_FACTOR: f32 = 6.0; 9 | pub const BACKGROUND_COLOR: Color = Color::BLACK; 10 | pub const WINDOW_WIDTH: f32 =
ROAD_SPRITE_W * SPRITE_SCALE_FACTOR; 11 | pub const WINDOW_HEIGHT: f32 = 1000.0; 12 | 13 | pub const ROAD_X_MIN: f32 = 238.0; // TODO: compute with SPRITE_SCALE_FACTOR 14 | pub const ROAD_X_MAX: f32 = 718.0; 15 | // TODO: subtract starting line (window / 2) 16 | pub const ROAD_W: f32 = ROAD_X_MAX - ROAD_X_MIN; 17 | // TODO: subtract goal position () 18 | pub const ROAD_H: f32 = WINDOW_HEIGHT * NUM_ROAD_TILES as f32; 19 | pub const DOJO_TO_BEVY_RATIO_X: f32 = ROAD_W / DOJO_GRID_WIDTH; 20 | pub const DOJO_TO_BEVY_RATIO_Y: f32 = ROAD_H / DOJO_GRID_HEIGHT; 21 | 22 | /// Car 23 | pub const NUM_AI_CARS: u32 = 1; 24 | pub const TURN_SPEED: f32 = 25.0; 25 | pub const CAR_THRUST: f32 = 5.0 * 100.0; 26 | pub const MAX_SPEED: f32 = 10.0 * 300.0; 27 | pub const FRICTION: f32 = 30.0 * 100.0; 28 | pub const MIN_SPEED_TO_STEER: f32 = 50.0; 29 | pub const NUM_RAY_CASTS: u32 = 8; 30 | pub const RAYCAST_SPREAD_ANGLE_DEG: f32 = 140.0; 31 | pub const RAYCAST_START_ANGLE_DEG: f32 = 20.0; 32 | pub const RAYCAST_MAX_TOI: f32 = 250.0; 33 | // pub const RAYCAST_THICKNESS: f32 = 0.3; 34 | 35 | /// NN 36 | pub const NUM_HIDDEN_NODES: usize = 15; 37 | pub const NUM_OUPUT_NODES: usize = 3; 38 | pub const NN_VIZ_NODE_RADIUS: f32 = 10.0; 39 | pub const NN_W_ACTIVATION_THRESHOLD: f64 = 0.3; 40 | pub const NN_S_ACTIVATION_THRESHOLD: f64 = 0.8; 41 | 42 | /// Others 43 | pub const FONT_RES_PATH: &str = "Magero.ttf"; 44 | 45 | /// Dojo 46 | pub const JSON_RPC_ENDPOINT: &str = "http://0.0.0.0:5050"; 47 | pub const ACCOUNT_ADDRESS: &str = 48 | "0x03ee9e18edc71a6df30ac3aca2e0b02a198fbce19b7480a63a0d71cbd76652e0"; // katana account 0 49 | pub const ACCOUNT_SECRET_KEY: &str = 50 | "0x0300001800000000300000180000000000030000000000003006001800006600"; 51 | pub const WORLD_ADDRESS: &str = "0x26065106fa319c3981618e7567480a50132f23932226a51c219ffb8e47daa84"; 52 | pub const DOJO_SYNC_INTERVAL: f32 = 0.1; 53 | pub const DOJO_GRID_WIDTH: f32 = 400.0; 54 | pub const DOJO_GRID_HEIGHT: f32 = 1000.0; 55 | pub const 
DOJO_ENEMIES_NB: u32 = 10; 56 | pub const MODEL_NAME: &str = "model"; 57 | -------------------------------------------------------------------------------- /src/rays.cairo: -------------------------------------------------------------------------------- 1 | use cubit::math::trig; 2 | use cubit::types::vec2::{Vec2, Vec2Trait}; 3 | use cubit::types::fixed::{Fixed, FixedTrait}; 4 | use array::{ArrayTrait, SpanTrait}; 5 | 6 | use drive_ai::math::{distance, intersects}; 7 | 8 | const DEG_90_IN_RADS: u128 = 28976077338029890953; 9 | const DEG_70_IN_RADS: u128 = 22536387234850959209; 10 | const DEG_50_IN_RADS: u128 = 16098473553126325695; 11 | const DEG_30_IN_RADS: u128 = 9658715196994321226; 12 | const DEG_10_IN_RADS: u128 = 3218956840862316756; 13 | 14 | const NUM_RAYS: usize = 5; 15 | const RAY_LENGTH: u128 = 2767011611056432742400; // 150 16 | 17 | #[derive(Serde, Drop)] 18 | struct Rays { 19 | segments: Span, 20 | } 21 | 22 | trait RaysTrait { 23 | fn new(position: Vec2, theta: Fixed) -> Rays; 24 | } 25 | 26 | impl RaysImpl of RaysTrait { 27 | fn new(position: Vec2, theta: Fixed) -> Rays { 28 | let ray_length = FixedTrait::new(RAY_LENGTH, false); 29 | 30 | let mut rays_theta = ArrayTrait::new(); 31 | // rays_theta.append(theta + FixedTrait::new(DEG_70_IN_RADS, true)); 32 | rays_theta.append(theta + FixedTrait::new(DEG_50_IN_RADS, true)); 33 | rays_theta.append(theta + FixedTrait::new(DEG_30_IN_RADS, true)); 34 | // rays_theta.append(theta + FixedTrait::new(DEG_10_IN_RADS, true)); 35 | rays_theta.append(theta); 36 | // rays_theta.append(theta + FixedTrait::new(DEG_10_IN_RADS, false)); 37 | rays_theta.append(theta + FixedTrait::new(DEG_30_IN_RADS, false)); 38 | rays_theta.append(theta + FixedTrait::new(DEG_50_IN_RADS, false)); 39 | // rays_theta.append(theta + FixedTrait::new(DEG_70_IN_RADS, false)); 40 | 41 | // TODO: Rays are semetric, we calculate half and invert 42 | let mut segments = ArrayTrait::new(); 43 | loop { 44 | match rays_theta.pop_front() { 45 | 
Option::Some(theta) => { 46 | // Endpoints of Ray 47 | // TODO: Rays are semetric, we calculate half and invert 48 | let cos_theta = trig::cos_fast(theta); 49 | let sin_theta = trig::sin_fast(theta); 50 | let delta1 = Vec2Trait::new(ray_length * sin_theta, ray_length * cos_theta); 51 | 52 | // TODO: We currently project out the center point? 53 | let q = position + delta1; 54 | 55 | segments.append(Ray { theta, cos_theta, sin_theta, p: position, q, }); 56 | }, 57 | Option::None(_) => { 58 | break (); 59 | } 60 | }; 61 | }; 62 | 63 | Rays { segments: segments.span() } 64 | } 65 | } 66 | 67 | #[derive(Serde, Drop)] 68 | struct Ray { 69 | theta: Fixed, 70 | cos_theta: Fixed, 71 | sin_theta: Fixed, 72 | p: Vec2, 73 | q: Vec2, 74 | } 75 | 76 | trait RayTrait { 77 | fn intersects(self: @Ray, p: Vec2, q: Vec2) -> bool; 78 | fn dist(self: @Ray, p: Vec2, q: Vec2) -> Fixed; 79 | } 80 | 81 | impl RayImpl of RayTrait { 82 | fn intersects(self: @Ray, p: Vec2, q: Vec2) -> bool { 83 | intersects(*self.p, *self.q, p, q) 84 | } 85 | fn dist(self: @Ray, p: Vec2, q: Vec2) -> Fixed { 86 | distance(*self.p, p, q, *self.cos_theta, *self.sin_theta) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/vehicle.cairo: -------------------------------------------------------------------------------- 1 | use array::ArrayTrait; 2 | use cubit::types::vec2::{Vec2, Vec2Trait}; 3 | use cubit::types::fixed::{Fixed, FixedTrait, FixedPrint, ONE_u128}; 4 | use drive_ai::racer::{CAR_HEIGHT, CAR_WIDTH}; 5 | use cubit::math::trig; 6 | use drive_ai::math; 7 | 8 | #[derive(Component, Serde, Drop, Copy)] 9 | struct Vehicle { 10 | // Current vehicle position 11 | position: Vec2, 12 | // Vehicle steer in radians -1/2π <= s <= 1/2π 13 | steer: Fixed, 14 | // Vehicle velocity 0 <= v <= 100 15 | speed: Fixed 16 | } 17 | 18 | impl VehicleSerdeLen of dojo::SerdeLen { 19 | #[inline(always)] 20 | fn len() -> usize { 21 | 8 22 | } 23 | } 24 | 25 | #[derive(Serde, 
Drop)] 26 | enum Direction { 27 | Straight: (), 28 | Left: (), 29 | Right: (), 30 | } 31 | 32 | #[derive(Serde, Drop)] 33 | struct Controls { 34 | steer: Direction, 35 | } 36 | 37 | // 10 degrees / pi/18 radians 38 | const TURN_STEP: u128 = 3219563738742341801; 39 | const HALF_PI: u128 = 28976077338029890953; 40 | 41 | trait VehicleTrait { 42 | fn control(ref self: Vehicle, controls: Controls) -> bool; 43 | fn drive(ref self: Vehicle); 44 | fn vertices(self: @Vehicle) -> Span; 45 | } 46 | 47 | impl VehicleImpl of VehicleTrait { 48 | fn control(ref self: Vehicle, controls: Controls) -> bool { 49 | let delta = match controls.steer { 50 | Direction::Straight(()) => FixedTrait::new(0, false), 51 | Direction::Left(()) => FixedTrait::new(TURN_STEP, true), 52 | Direction::Right(()) => FixedTrait::new(TURN_STEP, false), 53 | }; 54 | 55 | self.steer = self.steer + delta; 56 | 57 | (self.steer >= FixedTrait::new(HALF_PI, true) 58 | && self.steer <= FixedTrait::new(HALF_PI, false)) 59 | } 60 | 61 | fn drive(ref self: Vehicle) { 62 | // Velocity vector 63 | let x_comp = self.speed * trig::sin_fast(self.steer); 64 | let y_comp = self.speed * trig::cos_fast(self.steer); 65 | let v_0 = Vec2Trait::new(x_comp, y_comp); 66 | 67 | self.position = self.position + v_0; 68 | } 69 | 70 | fn vertices(self: @Vehicle) -> Span { 71 | math::vertices( 72 | *self.position, 73 | FixedTrait::new(CAR_WIDTH, false), 74 | FixedTrait::new(CAR_HEIGHT, false), 75 | *self.steer 76 | ) 77 | } 78 | } 79 | 80 | #[cfg(test)] 81 | mod tests { 82 | use debug::PrintTrait; 83 | use cubit::types::vec2::{Vec2, Vec2Trait}; 84 | use cubit::types::fixed::{Fixed, FixedTrait, FixedPrint}; 85 | use cubit::test::helpers::assert_precise; 86 | use array::SpanTrait; 87 | 88 | use super::{Vehicle, VehicleTrait, Controls, Direction, TURN_STEP}; 89 | 90 | const TEN: felt252 = 184467440737095516160; 91 | 92 | #[test] 93 | #[available_gas(2000000)] 94 | fn test_control() { 95 | let mut vehicle = Vehicle { 96 | position: 
Vec2Trait::new(FixedTrait::from_felt(TEN), FixedTrait::from_felt(TEN)), 97 | steer: FixedTrait::new(0_u128, false), 98 | speed: FixedTrait::from_felt(TEN) 99 | }; 100 | 101 | vehicle.control(Controls { steer: Direction::Left(()) }); 102 | assert(vehicle.steer == FixedTrait::new(TURN_STEP, true), 'invalid steer'); 103 | vehicle.control(Controls { steer: Direction::Left(()) }); 104 | assert(vehicle.steer == FixedTrait::new(2 * TURN_STEP, true), 'invalid steer'); 105 | vehicle.control(Controls { steer: Direction::Right(()) }); 106 | assert(vehicle.steer == FixedTrait::new(TURN_STEP, true), 'invalid steer'); 107 | vehicle.control(Controls { steer: Direction::Right(()) }); 108 | assert(vehicle.steer == FixedTrait::new(0, false), 'invalid steer'); 109 | } 110 | 111 | #[test] 112 | #[available_gas(20000000)] 113 | fn test_drive() { 114 | let mut vehicle = Vehicle { 115 | position: Vec2Trait::new(FixedTrait::from_felt(TEN), FixedTrait::from_felt(TEN)), 116 | steer: FixedTrait::new(0_u128, false), 117 | speed: FixedTrait::from_felt(TEN) 118 | }; 119 | 120 | vehicle.drive(); 121 | 122 | assert_precise(vehicle.position.x, TEN, 'invalid position x', Option::None(())); 123 | assert_precise( 124 | vehicle.position.y, 368934881474199059390, 'invalid position y', Option::None(()) 125 | ); 126 | 127 | vehicle.control(Controls { steer: Direction::Left(()) }); 128 | vehicle.drive(); 129 | 130 | // x: ~8.263527, y: ~29.84807913671 131 | assert_precise( 132 | vehicle.position.x, 152435159473296002840, 'invalid position x', Option::None(()) 133 | ); 134 | assert_precise( 135 | vehicle.position.y, 550599003738036609070, 'invalid position y', Option::None(()) 136 | ); 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/population.rs: -------------------------------------------------------------------------------- 1 | // use bevy::log; 2 | use bevy::prelude::*; 3 | // use rand::distributions::WeightedIndex; 4 | // use 
rand::prelude::Distribution; 5 | 6 | use crate::car::{Car, Fitness, Model}; 7 | // use crate::enemy::{spawn_bound_trucks, BoundControlTruck, Enemy}; 8 | // use crate::nn::Net; 9 | use crate::*; 10 | 11 | pub struct PopulationPlugin; 12 | 13 | impl Plugin for PopulationPlugin { 14 | fn build(&self, app: &mut App) { 15 | app.insert_resource(MaxDistanceTravelled(0.0)) 16 | // .add_startup_system(setup) 17 | .add_system(population_stats_system); 18 | // .add_systems((population_stats_system, generation_reset_system)); 19 | } 20 | } 21 | 22 | fn population_stats_system( 23 | // mut sim_stats: ResMut, 24 | mut max_distance_travelled: ResMut, 25 | // mut brain_on_display: ResMut, 26 | mut query: Query<(&Transform, &Model, &mut Fitness), With>, 27 | ) { 28 | let mut max_fitness = 0.0; 29 | // sim_stats.num_cars_alive = query.iter().len(); 30 | 31 | for (transform, _brain, mut fitness) in query.iter_mut() { 32 | fitness.0 = calc_fitness(transform); 33 | if fitness.0 > max_fitness { 34 | max_fitness = fitness.0; 35 | // brain_on_display.0 = brain.nn_outputs.clone(); 36 | // sim_stats.max_current_score = fitness.0; 37 | max_distance_travelled.0 = transform.translation.y; 38 | } 39 | } 40 | } 41 | 42 | // fn generation_reset_system( 43 | // mut commands: Commands, 44 | // asset_server: Res, 45 | // // mut settings: ResMut, 46 | // mut sim_stats: ResMut, 47 | // cars_query: Query<(Entity, &Model, &Fitness)>, 48 | // cars_count_query: Query>, 49 | // enemy_query: Query>, 50 | // bounds_truck_query: Query>, 51 | // ) { 52 | // let num_cars = cars_count_query.iter().count(); 53 | // if num_cars > 0 { 54 | // return; 55 | // } 56 | 57 | // bounds_truck_query.for_each(|t| commands.entity(t).despawn()); 58 | // enemy_query.for_each(|e| commands.entity(e).despawn()); 59 | 60 | // let mut fitnesses = Vec::new(); 61 | // let mut old_brains = Vec::new(); 62 | // for (e, brain, fitness) in cars_query.iter() { 63 | // fitnesses.push(fitness.0); 64 | // old_brains.push(brain.nn.clone()); 65 
| 66 | // commands.entity(e).despawn(); 67 | // } 68 | 69 | // // let (max_fitness, gene_pool) = create_gene_pool(fitnesses); 70 | // // let mut rng = rand::thread_rng(); 71 | // // let mut new_brains = Vec::new(); 72 | 73 | // // for _ in 0..NUM_AI_CARS { 74 | // // let brain_idx = gene_pool.sample(&mut rng); 75 | // // let mut rand_brain = old_brains[brain_idx].clone(); 76 | // // rand_brain.mutate(); 77 | // // new_brains.push(rand_brain); 78 | // // } 79 | 80 | // // update stats 81 | // sim_stats.generation_count += 1; 82 | // // sim_stats.fitness.push(max_fitness); 83 | 84 | // // respawn everything 85 | // // spawn_enemies(&mut commands, &asset_server); 86 | // // spawn_bound_trucks(&mut commands, &asset_server); 87 | // // spawn_cars( 88 | // // &mut commands, 89 | // // &asset_server, 90 | // // &mut settings, 91 | // // Some(new_brains), 92 | // // ); 93 | // } 94 | 95 | // fn spawn_cars( 96 | // commands: &mut Commands, 97 | // asset_server: &AssetServer, 98 | // settings: &mut Settings, 99 | // models: Option>, 100 | // ) { 101 | // let models = models.unwrap_or(Vec::new()); 102 | // let is_new_nn = models.is_empty() || settings.restart_sim; 103 | // settings.restart_sim = false; 104 | 105 | // for i in 0..NUM_AI_CARS { 106 | // let model_id = FieldElement::from_dec_str(&i.to_string()).unwrap(); 107 | 108 | // match is_new_nn { 109 | // true => commands.spawn(CarBundle::new(asset_server, model_id)), 110 | // false => commands.spawn(CarBundle::with_model( 111 | // asset_server, 112 | // &models.get(i as usize).unwrap(), 113 | // )), 114 | // }; 115 | // } 116 | // } 117 | 118 | // fn create_gene_pool(values: Vec) -> (f32, WeightedIndex) { 119 | // let mut max_fitness = 0.0; 120 | // let mut weights = Vec::new(); 121 | 122 | // for v in values.iter() { 123 | // if *v > max_fitness { 124 | // max_fitness = *v; 125 | // } 126 | // weights.push(*v); 127 | // } 128 | 129 | // ( 130 | // max_fitness, 131 | // WeightedIndex::new(&weights).expect("Failed to 
generate gene pool"), 132 | // ) 133 | // } 134 | 135 | fn calc_fitness(transform: &Transform) -> f32 { 136 | let y = transform.translation.y; 137 | if y <= 600.0 { 138 | return 0.1; 139 | } 140 | 141 | return transform.translation.y / 340.0; 142 | } 143 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use bevy::{math::vec3, prelude::*}; 2 | use bevy_inspector_egui::{bevy_egui::EguiPlugin, DefaultInspectorConfigPlugin}; 3 | use bevy_pancam::{PanCam, PanCamPlugin}; 4 | use bevy_rapier2d::prelude::*; 5 | use steering::{ 6 | car::{Car, CarPlugin}, 7 | configs::*, 8 | dojo::DojoPlugin, 9 | enemy::EnemyPlugin, 10 | population::PopulationPlugin, 11 | MaxDistanceTravelled, 12 | }; 13 | 14 | fn main() { 15 | App::new() 16 | .insert_resource(FixedTime::new_from_secs(0.25)) 17 | .add_plugins( 18 | DefaultPlugins 19 | .set(ImagePlugin::default_nearest()) 20 | .set(WindowPlugin { 21 | primary_window: Some(Window { 22 | resizable: false, 23 | focused: true, 24 | resolution: (WINDOW_WIDTH, WINDOW_HEIGHT).into(), 25 | ..default() 26 | }), 27 | ..default() 28 | }), 29 | ) 30 | .add_plugin(PanCamPlugin::default()) 31 | // .add_plugin(WorldInspectorPlugin::new().run_if(input_toggle_active(false, KeyCode::Tab))) // remove eguiplugin 32 | .add_plugin(DefaultInspectorConfigPlugin) // Requires egui plugin 33 | .add_plugin(EguiPlugin) 34 | .add_plugin(RapierPhysicsPlugin::::pixels_per_meter(100.0)) 35 | // .add_plugin(LogDiagnosticsPlugin::default()) 36 | // .add_plugin(FrameTimeDiagnosticsPlugin::default()) 37 | .add_plugin(CarPlugin) 38 | .add_plugin(EnemyPlugin) 39 | .add_plugin(PopulationPlugin) 40 | // .add_plugin(GuiPlugin) 41 | .add_plugin(DojoPlugin) 42 | .add_plugin(RapierDebugRenderPlugin::default()) 43 | .insert_resource(ClearColor(Color::rgb_u8(36, 36, 36))) 44 | .insert_resource(ClearColor(Color::WHITE)) 45 | // .insert_resource(Msaa::Off) 46 | 
.add_startup_system(setup) 47 | .add_system(bevy::window::close_on_esc) 48 | .add_system(camera_follow_system) 49 | // .add_system(settings_system) 50 | .run(); 51 | } 52 | 53 | fn setup( 54 | mut commands: Commands, 55 | asset_server: Res, 56 | mut rapier_config: ResMut, 57 | ) { 58 | rapier_config.gravity = Vec2::ZERO; 59 | 60 | commands 61 | .spawn(Camera2dBundle { 62 | transform: Transform::from_xyz(WINDOW_WIDTH / 2.0, WINDOW_HEIGHT / 2.0, 0.0), 63 | ..default() 64 | }) 65 | .insert(PanCam::default()); 66 | 67 | spawn_roads(&mut commands, &asset_server); 68 | // spawn_bound_trucks(&mut commands, &asset_server); 69 | } 70 | 71 | fn camera_follow_system( 72 | // settings: Res, 73 | max_distance_travelled: Res, 74 | mut cam_query: Query<(&Camera, &mut Transform), Without>, 75 | ) { 76 | let (_, mut cam_transform) = cam_query.get_single_mut().unwrap(); 77 | // if settings.is_camera_follow { 78 | cam_transform.translation = cam_transform.translation.lerp( 79 | vec3(cam_transform.translation.x, max_distance_travelled.0, 0.0), 80 | 0.05, 81 | ); 82 | // } 83 | } 84 | 85 | fn spawn_roads(commands: &mut Commands, asset_server: &AssetServer) { 86 | // Road 87 | // let rx = WINDOW_WIDTH / 2.0 - 30.0; 88 | let rx = ROAD_SPRITE_W / 2.0 * SPRITE_SCALE_FACTOR; 89 | let mut ry = ROAD_SPRITE_H / 2.0 * SPRITE_SCALE_FACTOR; 90 | for _ in 0..NUM_ROAD_TILES { 91 | commands.spawn(SpriteBundle { 92 | transform: Transform::from_xyz(rx, ry, -10.0) 93 | .with_scale(Vec3::splat(SPRITE_SCALE_FACTOR)), 94 | texture: asset_server.load("road.png"), 95 | ..default() 96 | }); 97 | ry += ROAD_SPRITE_H * SPRITE_SCALE_FACTOR; 98 | } 99 | let road_end_y = ry - ROAD_SPRITE_H * SPRITE_SCALE_FACTOR + 800.0; 100 | 101 | // end checker board 102 | commands.spawn(SpriteBundle { 103 | transform: Transform::from_xyz(rx, road_end_y - 50.0, -5.0) 104 | .with_scale(Vec3::splat(SPRITE_SCALE_FACTOR)), 105 | texture: asset_server.load("end-point.png"), 106 | ..default() 107 | }); 108 | 109 | // Road colliders 
110 | // left 111 | let ry = 5.0 * ROAD_SPRITE_H * SPRITE_SCALE_FACTOR; 112 | let rx_min = ROAD_SPRITE_W / 2.0 * SPRITE_SCALE_FACTOR + 238.0; 113 | commands.spawn(( 114 | SpriteBundle { 115 | transform: Transform::from_xyz(rx_min, ry, 0.0).with_scale(vec3(0.5, 0.5, 1.0)), 116 | ..default() 117 | }, 118 | RigidBody::Fixed, 119 | Collider::cuboid( 120 | 5.0, 121 | ROAD_SPRITE_H * SPRITE_SCALE_FACTOR * NUM_ROAD_TILES as f32 * 5.0, 122 | ), 123 | )); 124 | 125 | // right 126 | let rx_max = ROAD_SPRITE_W * SPRITE_SCALE_FACTOR + 248.0; 127 | commands.spawn(( 128 | SpriteBundle { 129 | transform: Transform::from_xyz(rx_max, ry, 0.0).with_scale(vec3(0.5, 0.5, 1.0)), 130 | ..default() 131 | }, 132 | RigidBody::Fixed, 133 | Collider::cuboid( 134 | 5.0, 135 | ROAD_SPRITE_H * SPRITE_SCALE_FACTOR * NUM_ROAD_TILES as f32 * 5.0, 136 | ), 137 | )); 138 | 139 | // top 140 | commands.spawn(( 141 | SpriteBundle { 142 | transform: Transform::from_xyz(600.0, road_end_y, 0.0).with_scale(vec3(0.5, 0.5, 1.0)), 143 | ..default() 144 | }, 145 | RigidBody::Fixed, 146 | Collider::cuboid(500.0 * SPRITE_SCALE_FACTOR, 10.0), 147 | )); 148 | } 149 | 150 | // fn settings_system( 151 | // mut commands: Commands, 152 | // mut settings: ResMut, 153 | // mut sim_stats: ResMut, 154 | // car_query: Query>, 155 | // ) { 156 | // if settings.start_next_generation { 157 | // settings.start_next_generation = false; 158 | // car_query.iter().for_each(|c| { 159 | // commands.entity(c).remove::(); 160 | // }); 161 | // } 162 | // if settings.restart_sim { 163 | // // force restart 164 | // car_query.iter().for_each(|c| { 165 | // commands.entity(c).remove::(); 166 | // }); 167 | // *sim_stats = SimStats::default(); 168 | // sim_stats.generation_count = 0; 169 | // } 170 | // } 171 | -------------------------------------------------------------------------------- /src/model.cairo: -------------------------------------------------------------------------------- 1 | #[system] 2 | mod model { 3 | use 
array::ArrayTrait; 4 | use dojo::world::Context; 5 | use drive_ai::racer::Sensors; 6 | use drive_ai::vehicle::{Controls, Direction}; 7 | use orion::operators::nn::core::NNTrait; 8 | use orion::operators::nn::implementations::impl_nn_i8::NN_i8; 9 | use orion::operators::tensor::core::{TensorTrait, ExtraParams, Tensor}; 10 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 11 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 12 | use orion::numbers::signed_integer::i8::i8; 13 | use orion::numbers::fixed_point::core::{FixedType, FixedTrait}; 14 | use orion::performance::core::PerfomanceTrait; 15 | use orion::performance::implementations::impl_performance_fp::Performance_fp_i8; 16 | use orion::operators::nn::functional::linear::linear_ft::linear_ft; 17 | use core::traits::Into; 18 | use orion::numbers::fixed_point::implementations::impl_8x23::{FP8x23Impl, ONE, FP8x23Mul}; 19 | fn execute(ctx: Context, sensors: Sensors) -> Controls { 20 | let prediction: usize = forward(sensors.rays); 21 | let steer: Direction = if prediction == 0 { 22 | Direction::Straight(()) 23 | } else if prediction == 1 { 24 | Direction::Left(()) 25 | } else if prediction == 2 { 26 | Direction::Right(()) 27 | } else { 28 | let mut panic_msg = ArrayTrait::new(); 29 | panic_msg.append('prediction must be < 3'); 30 | panic(panic_msg) 31 | }; 32 | Controls { steer: steer, } 33 | } 34 | fn forward(input: Tensor) -> usize { 35 | let w = fc1::fc1_weights(); 36 | let b = fc1::fc1_bias(); 37 | let mut shape = ArrayTrait::::new(); 38 | shape.append(1); 39 | let mut data = ArrayTrait::::new(); 40 | data.append(FixedTrait::new_unscaled(1, false)); 41 | let extra = Option::::None(()); 42 | let y_scale = TensorTrait::new(shape.span(), data.span(), extra); 43 | let mut shape = ArrayTrait::::new(); 44 | shape.append(1); 45 | let mut data = ArrayTrait::::new(); 46 | data.append(FixedTrait::new_unscaled(0, false)); 47 | let extra = Option::::None(()); 48 | let 
y_zero_point = TensorTrait::new(shape.span(), data.span(), extra); 49 | let x = linear_ft(input, w, b); 50 | *x.argmax(0, Option::None(()), Option::None(())).data[0] 51 | } 52 | mod fc0 { 53 | use array::ArrayTrait; 54 | use orion::operators::tensor::core::{TensorTrait, Tensor, ExtraParams}; 55 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 56 | use orion::numbers::fixed_point::core::FixedImpl; 57 | use orion::numbers::signed_integer::i8::i8; 58 | use orion::numbers::fixed_point::core::FixedType; 59 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 60 | fn fc0_weights() -> Tensor { 61 | let mut shape = ArrayTrait::::new(); 62 | shape.append(3); 63 | shape.append(3); 64 | let mut data = ArrayTrait::::new(); 65 | data.append(FixedType { mag: 7630779, sign: true }); 66 | data.append(FixedType { mag: 5105879, sign: false }); 67 | data.append(FixedType { mag: 1421683, sign: false }); 68 | data.append(FixedType { mag: 7810292, sign: true }); 69 | data.append(FixedType { mag: 4044435, sign: true }); 70 | data.append(FixedType { mag: 6316368, sign: false }); 71 | data.append(FixedType { mag: 1884714, sign: false }); 72 | data.append(FixedType { mag: 4829624, sign: false }); 73 | data.append(FixedType { mag: 5562430, sign: false }); 74 | let extra = Option::::None(()); 75 | TensorTrait::new(shape.span(), data.span(), extra) 76 | } 77 | fn fc0_bias() -> Tensor { 78 | let mut shape = ArrayTrait::::new(); 79 | shape.append(3); 80 | let mut data = ArrayTrait::::new(); 81 | data.append(FixedType { mag: 0, sign: false }); 82 | data.append(FixedType { mag: 0, sign: false }); 83 | data.append(FixedType { mag: 0, sign: false }); 84 | let extra = Option::::None(()); 85 | TensorTrait::new(shape.span(), data.span(), extra) 86 | } 87 | } 88 | mod fc1 { 89 | use array::ArrayTrait; 90 | use orion::operators::tensor::core::{TensorTrait, Tensor, ExtraParams}; 91 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 
92 | use orion::numbers::fixed_point::core::FixedImpl; 93 | use orion::numbers::signed_integer::i8::i8; 94 | use orion::numbers::fixed_point::core::FixedType; 95 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 96 | fn fc1_weights() -> Tensor { 97 | let mut shape = ArrayTrait::::new(); 98 | shape.append(1); 99 | shape.append(4); 100 | let mut data = ArrayTrait::::new(); 101 | data.append(FixedType { mag: 6383630, sign: false }); 102 | data.append(FixedType { mag: 8179303, sign: false }); 103 | data.append(FixedType { mag: 3419236, sign: false }); 104 | data.append(FixedType { mag: 5087765, sign: true }); 105 | let extra = Option::::None(()); 106 | TensorTrait::new(shape.span(), data.span(), extra) 107 | } 108 | fn fc1_bias() -> Tensor { 109 | let mut shape = ArrayTrait::::new(); 110 | shape.append(3); 111 | let mut data = ArrayTrait::::new(); 112 | data.append(FixedType { mag: 0, sign: false }); 113 | data.append(FixedType { mag: 0, sign: false }); 114 | data.append(FixedType { mag: 0, sign: false }); 115 | let extra = Option::::None(()); 116 | TensorTrait::new(shape.span(), data.span(), extra) 117 | } 118 | } 119 | } 120 | 121 | 122 | 123 | #[cfg(test)] 124 | mod tests { 125 | use core::result::ResultTrait; 126 | use core::serde::Serde; 127 | use clone::Clone; 128 | use array::{ArrayTrait, SpanTrait}; 129 | use traits::Into; 130 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 131 | 132 | use drive_ai::racer::Sensors; 133 | 134 | use dojo::world::{IWorldDispatcher, IWorldDispatcherTrait, world}; 135 | use dojo::test_utils::spawn_test_world; 136 | 137 | use orion::operators::tensor::core::{Tensor, TensorTrait, ExtraParams}; 138 | use orion::numbers::fixed_point::core::{FixedType, FixedTrait}; 139 | use orion::numbers::fixed_point::implementations::impl_16x16::FP16x16Impl; 140 | 141 | 142 | #[test] 143 | #[available_gas(30000000)] 144 | fn test_model() { 145 | let caller = 
starknet::contract_address_const::<0x0>(); 146 | 147 | // Get required component. 148 | let mut components = array::ArrayTrait::new(); 149 | // Get required system. 150 | let mut systems = array::ArrayTrait::new(); 151 | systems.append(super::model::TEST_CLASS_HASH); 152 | // Get test world. 153 | let world = spawn_test_world(components, systems); 154 | let sensors = create_sensors(); 155 | let control = world.execute('model'.into(), sensors.span()); 156 | 157 | // Expect prediction == 0: 158 | assert(*control[0] == 0, 'invalid prediction') 159 | } 160 | 161 | // Utils 162 | fn create_sensors() -> Array { 163 | let mut shape = ArrayTrait::::new(); 164 | shape.append(5); 165 | let mut data = ArrayTrait::::new(); 166 | data.append(FixedTrait::new_unscaled(1, false)); 167 | data.append(FixedTrait::new_unscaled(2, false)); 168 | data.append(FixedTrait::new_unscaled(3, false)); 169 | data.append(FixedTrait::new_unscaled(4, false)); 170 | data.append(FixedTrait::new_unscaled(5, false)); 171 | let extra = Option::::None(()); 172 | let rays = TensorTrait::new(shape.span(), data.span(), extra); 173 | 174 | let sensors: Sensors = Sensors { rays: rays }; 175 | let mut serialized = ArrayTrait::new(); 176 | sensors.serialize(ref serialized); 177 | serialized 178 | } 179 | } -------------------------------------------------------------------------------- /src/model_test.cairo: -------------------------------------------------------------------------------- 1 | #[system] 2 | mod model { 3 | use array::ArrayTrait; 4 | use dojo::world::Context; 5 | use drive_ai::racer::Sensors; 6 | use drive_ai::vehicle::{Controls, Direction}; 7 | use orion::operators::nn::core::NNTrait; 8 | use orion::operators::nn::implementations::impl_nn_i8::NN_i8; 9 | use orion::operators::tensor::core::{TensorTrait, ExtraParams, Tensor}; 10 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 11 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 12 | use 
orion::numbers::signed_integer::i8::i8; 13 | use orion::numbers::fixed_point::core::{FixedType, FixedTrait}; 14 | use orion::performance::core::PerfomanceTrait; 15 | use orion::performance::implementations::impl_performance_fp::Performance_fp_i8; 16 | use orion::operators::nn::functional::linear::linear_ft::linear_ft; 17 | use core::traits::Into; 18 | use orion::numbers::fixed_point::implementations::impl_8x23::{FP8x23Impl, ONE, FP8x23Mul}; 19 | 20 | fn execute(ctx: Context, sensors: Sensors) -> Controls { 21 | let prediction: usize = forward(sensors.rays); 22 | 23 | let steer: Direction = if prediction == 0 { 24 | Direction::Straight(()) 25 | } else if prediction == 1 { 26 | Direction::Left(()) 27 | } else if prediction == 2 { 28 | Direction::Right(()) 29 | } else { 30 | let mut panic_msg = ArrayTrait::new(); 31 | panic_msg.append('prediction must be < 3'); 32 | panic(panic_msg) 33 | }; 34 | 35 | Controls { steer: steer, } 36 | } 37 | 38 | fn forward(input: Tensor) -> usize { 39 | let w = fc1::fc1_weights(); 40 | let b = fc1::fc1_bias(); 41 | 42 | // YSCALE 43 | let mut shape = ArrayTrait::::new(); 44 | shape.append(1); 45 | let mut data = ArrayTrait::::new(); 46 | data.append(FixedTrait::new_unscaled(1, false)); 47 | let extra = Option::::None(()); 48 | let y_scale = TensorTrait::new(shape.span(), data.span(), extra); 49 | 50 | // ZEROPOINT 51 | let mut shape = ArrayTrait::::new(); 52 | shape.append(1); 53 | let mut data = ArrayTrait::::new(); 54 | data.append(FixedTrait::new_unscaled(0, false)); 55 | let extra = Option::::None(()); 56 | let y_zero_point = TensorTrait::new(shape.span(), data.span(), extra); 57 | 58 | let x = linear_ft(input, w, b); 59 | *x.argmax(0, Option::None(()), Option::None(())).data[0] 60 | } 61 | mod fc0 { 62 | use array::ArrayTrait; 63 | use orion::operators::tensor::core::{TensorTrait, Tensor, ExtraParams}; 64 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 65 | use 
orion::numbers::fixed_point::core::FixedImpl; 66 | use orion::numbers::signed_integer::i8::i8; 67 | use orion::numbers::fixed_point::core::FixedType; 68 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 69 | 70 | 71 | fn fc0_weights() -> Tensor { 72 | let mut shape = ArrayTrait::::new(); 73 | shape.append(3); 74 | shape.append(3); 75 | let mut data = ArrayTrait::::new(); 76 | data.append(FixedType { mag: 72220, sign: false }); 77 | data.append(FixedType { mag: 5831923, sign: false }); 78 | data.append(FixedType { mag: 1359772, sign: true }); 79 | data.append(FixedType { mag: 7155590, sign: false }); 80 | data.append(FixedType { mag: 1799402, sign: false }); 81 | data.append(FixedType { mag: 6663625, sign: false }); 82 | data.append(FixedType { mag: 1611532, sign: true }); 83 | data.append(FixedType { mag: 8176142, sign: true }); 84 | data.append(FixedType { mag: 87320, sign: true }); 85 | let extra = Option::::None(()); 86 | TensorTrait::new(shape.span(), data.span(), extra) 87 | } 88 | fn fc0_bias() -> Tensor { 89 | let mut shape = ArrayTrait::::new(); 90 | shape.append(3); 91 | let mut data = ArrayTrait::::new(); 92 | data.append(FixedType { mag: 0, sign: false }); 93 | data.append(FixedType { mag: 0, sign: false }); 94 | data.append(FixedType { mag: 0, sign: false }); 95 | let extra = Option::::None(()); 96 | TensorTrait::new(shape.span(), data.span(), extra) 97 | } 98 | } 99 | mod fc1 { 100 | use array::ArrayTrait; 101 | use orion::operators::tensor::core::{TensorTrait, Tensor, ExtraParams}; 102 | use orion::operators::tensor::implementations::impl_tensor_i8::Tensor_i8; 103 | use orion::numbers::fixed_point::core::FixedImpl; 104 | use orion::numbers::fixed_point::core::FixedType; 105 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 106 | 107 | 108 | fn fc1_weights() -> Tensor { 109 | let mut shape = ArrayTrait::::new(); 110 | shape.append(1); 111 | shape.append(4); 112 | let mut data = ArrayTrait::::new(); 
113 | data.append(FixedType { mag: 1178602, sign: true }); 114 | data.append(FixedType { mag: 4518026, sign: false }); 115 | data.append(FixedType { mag: 5475490, sign: true }); 116 | data.append(FixedType { mag: 5467001, sign: true }); 117 | let extra = Option::::None(()); 118 | TensorTrait::new(shape.span(), data.span(), extra) 119 | } 120 | fn fc1_bias() -> Tensor { 121 | let mut shape = ArrayTrait::::new(); 122 | shape.append(3); 123 | let mut data = ArrayTrait::::new(); 124 | data.append(FixedType { mag: 0, sign: false }); 125 | data.append(FixedType { mag: 0, sign: false }); 126 | data.append(FixedType { mag: 0, sign: false }); 127 | let extra = Option::::None(()); 128 | TensorTrait::new(shape.span(), data.span(), extra) 129 | } 130 | } 131 | } 132 | 133 | 134 | #[cfg(test)] 135 | mod tests { 136 | use core::result::ResultTrait; 137 | use core::serde::Serde; 138 | use clone::Clone; 139 | use array::{ArrayTrait, SpanTrait}; 140 | use traits::Into; 141 | use orion::operators::tensor::implementations::impl_tensor_fp::Tensor_fp; 142 | 143 | use drive_ai::racer::Sensors; 144 | 145 | use dojo::world::{IWorldDispatcher, IWorldDispatcherTrait, world}; 146 | use dojo::test_utils::spawn_test_world; 147 | 148 | use orion::operators::tensor::core::{Tensor, TensorTrait, ExtraParams}; 149 | use orion::numbers::fixed_point::core::{FixedType, FixedTrait}; 150 | use orion::numbers::fixed_point::implementations::impl_16x16::FP16x16Impl; 151 | 152 | 153 | #[test] 154 | #[available_gas(30000000)] 155 | fn test_model() { 156 | let caller = starknet::contract_address_const::<0x0>(); 157 | 158 | // Get required component. 159 | let mut components = array::ArrayTrait::new(); 160 | // Get required system. 161 | let mut systems = array::ArrayTrait::new(); 162 | systems.append(super::model::TEST_CLASS_HASH); 163 | // Get test world. 
164 | let world = spawn_test_world(components, systems); 165 | let sensors = create_sensors(); 166 | let control = world.execute('model'.into(), sensors.span()); 167 | 168 | // Expect prediction == 0: 169 | assert(*control[0] == 0, 'invalid prediction') 170 | } 171 | 172 | // Utils 173 | fn create_sensors() -> Array { 174 | let mut shape = ArrayTrait::::new(); 175 | shape.append(5); 176 | let mut data = ArrayTrait::::new(); 177 | data.append(FixedTrait::new_unscaled(1, false)); 178 | data.append(FixedTrait::new_unscaled(2, false)); 179 | data.append(FixedTrait::new_unscaled(3, false)); 180 | data.append(FixedTrait::new_unscaled(4, false)); 181 | data.append(FixedTrait::new_unscaled(5, false)); 182 | let extra = Option::::None(()); 183 | let rays = TensorTrait::new(shape.span(), data.span(), extra); 184 | 185 | let sensors: Sensors = Sensors { rays: rays }; 186 | let mut serialized = ArrayTrait::new(); 187 | sensors.serialize(ref serialized); 188 | serialized 189 | } 190 | } -------------------------------------------------------------------------------- /src/enemy.rs: -------------------------------------------------------------------------------- 1 | use crate::{configs::*, dojo::dojo_to_bevy_coordinate}; 2 | use bevy::{log, math::vec3, prelude::*}; 3 | use bevy_rapier2d::prelude::*; 4 | use rand::{thread_rng, Rng}; 5 | use starknet::core::types::FieldElement; 6 | 7 | pub struct EnemyPlugin; 8 | 9 | #[derive(Component, Reflect, Default)] 10 | pub struct Enemy { 11 | pub is_hit: bool, 12 | } 13 | 14 | #[derive(Component)] 15 | pub struct EnemyId(pub FieldElement); 16 | 17 | #[derive(Clone, Component, Reflect)] 18 | pub enum EnemyType { 19 | Simple, 20 | Horizontal(f32), 21 | Truck, 22 | } 23 | 24 | // #[derive(Component)] 25 | // pub struct BoundControlTruck; 26 | 27 | impl Plugin for EnemyPlugin { 28 | fn build(&self, app: &mut App) { 29 | app.add_event::() 30 | .add_event::() 31 | .add_systems((spawn_enemies, update_enemy)); 32 | // app.add_startup_system(setup) 
33 | // .add_system(update_enemies) 34 | // .add_system(bound_control_system); 35 | } 36 | } 37 | 38 | pub struct SpawnEnemies; 39 | 40 | fn spawn_enemies( 41 | mut events: EventReader, 42 | mut commands: Commands, 43 | asset_server: Res, 44 | ) { 45 | for _ in events.iter() { 46 | for id in 0..DOJO_ENEMIES_NB { 47 | let enemy_type = EnemyType::random(); 48 | let enemy_scale = match enemy_type { 49 | EnemyType::Truck => 3.0, 50 | _ => 2.5, 51 | }; 52 | let collider = match enemy_type { 53 | EnemyType::Truck => Collider::cuboid(6.0, 15.0), 54 | _ => Collider::cuboid(4.0, 8.0), 55 | }; 56 | 57 | commands.spawn(( 58 | SpriteBundle { 59 | // TODO: workaround: spawn outside of screen because we know all enermies are spawned but don't know their positions yet 60 | transform: Transform::from_xyz(0.0, 0.0, 0.0).with_scale(vec3( 61 | enemy_scale, 62 | enemy_scale, 63 | 1.0, 64 | )), 65 | texture: asset_server.load(enemy_type.get_sprite()), 66 | ..default() 67 | }, 68 | // RigidBody::Dynamic, 69 | Velocity::zero(), 70 | ColliderMassProperties::Mass(1.0), 71 | Friction::new(100.0), 72 | ActiveEvents::COLLISION_EVENTS, 73 | collider, 74 | Damping { 75 | angular_damping: 2.0, 76 | linear_damping: 2.0, 77 | }, 78 | Enemy { is_hit: false }, 79 | EnemyId(id.into()), 80 | enemy_type, 81 | )); 82 | } 83 | } 84 | } 85 | 86 | pub struct UpdateEnemy { 87 | pub position: Vec, 88 | pub enemy_id: FieldElement, 89 | } 90 | 91 | fn update_enemy( 92 | mut events: EventReader, 93 | mut query: Query<(&mut Transform, &EnemyId), With>, 94 | ) { 95 | for e in events.iter() { 96 | let (new_x, new_y) = dojo_to_bevy_coordinate( 97 | e.position[0].to_string().parse().unwrap(), 98 | e.position[1].to_string().parse().unwrap(), 99 | ); 100 | 101 | log::info!("Enermy Position ({}), x: {new_x}, y: {new_y}", e.enemy_id); 102 | 103 | for (mut transform, enemy_id_comp) in query.iter_mut() { 104 | if enemy_id_comp.0 == e.enemy_id { 105 | transform.translation.x = new_x; 106 | transform.translation.y = new_y; 
107 | } 108 | } 109 | } 110 | } 111 | 112 | // fn setup(mut commands: Commands, asset_server: Res) { 113 | // spawn_enemies(&mut commands, &asset_server); 114 | // } 115 | 116 | // fn update_enemies( 117 | // mut enemy_query: Query< 118 | // (&mut Transform, &mut Velocity, &mut Enemy, &mut EnemyType), 119 | // With, 120 | // >, 121 | // ) { 122 | // for (mut transform, mut velocity, mut enemy, mut enemy_type) in enemy_query.iter_mut() { 123 | // if enemy.is_hit { 124 | // continue; 125 | // } 126 | 127 | // velocity.linvel = vec2(0.0, 50.0); 128 | // enemy.is_hit = velocity.angvel != 0.0; 129 | 130 | // // horizontal motion 131 | // match enemy_type.as_mut() { 132 | // EnemyType::Horizontal(direction) => { 133 | // velocity.linvel += *direction * vec2(30.0, 0.0); 134 | 135 | // // direction update 136 | // // 738 -> 1180 is the road x dir 137 | // if transform.translation.x >= 1170.0 { 138 | // transform.translation.x = 1169.0; 139 | // *direction *= -1.0; 140 | // } else if transform.translation.x <= 742.0 { 141 | // transform.translation.x = 743.0; 142 | // *direction *= -1.0; 143 | // } 144 | // } 145 | // _ => {} 146 | // } 147 | // } 148 | // } 149 | 150 | // fn bound_control_system(mut query: Query<&mut Transform, With>) { 151 | // for mut transform in query.iter_mut() { 152 | // transform.translation.y += 1.0; 153 | // } 154 | // } 155 | 156 | // pub fn spawn_enemies(commands: &mut Commands, asset_server: &AssetServer) { 157 | // let mut enemy_y = 800.0; 158 | // for _ in 0..NUM_ENEMY_CARS { 159 | // let enemy_type = EnemyType::random(); 160 | // let enemy_scale = match enemy_type { 161 | // EnemyType::Truck => 3.0, 162 | // _ => 2.5, 163 | // }; 164 | // let collider = match enemy_type { 165 | // EnemyType::Truck => Collider::cuboid(6.0, 15.0), 166 | // _ => Collider::cuboid(4.0, 8.0), 167 | // }; 168 | // let mut rng = rand::thread_rng(); 169 | // let x = rng.gen_range(743.0..1169.0); 170 | // let y = enemy_y; 171 | // enemy_y += 200.0; 172 | // 
commands.spawn(( 173 | // SpriteBundle { 174 | // transform: Transform::from_xyz(x, y, 0.0).with_scale(vec3( 175 | // enemy_scale, 176 | // enemy_scale, 177 | // 1.0, 178 | // )), 179 | // texture: asset_server.load(enemy_type.get_sprite()), 180 | // ..default() 181 | // }, 182 | // RigidBody::Dynamic, 183 | // Velocity::zero(), 184 | // ColliderMassProperties::Mass(1.0), 185 | // Friction::new(100.0), 186 | // ActiveEvents::COLLISION_EVENTS, 187 | // collider, 188 | // Damping { 189 | // angular_damping: 2.0, 190 | // linear_damping: 2.0, 191 | // }, 192 | // Enemy { is_hit: false }, 193 | // enemy_type, 194 | // )); 195 | // } 196 | // } 197 | 198 | // pub fn spawn_bound_trucks(commands: &mut Commands, asset_server: &AssetServer) { 199 | // // Bound control trucks 200 | // let enemy_y = 100.0; 201 | // let mut enemy_x = 743.0; // upto 1169.0 202 | // for _ in 0..12 { 203 | // let enemy_type = EnemyType::Truck; 204 | // let enemy_scale = 3.0; 205 | // let collider = match enemy_type { 206 | // EnemyType::Truck => Collider::cuboid(6.0, 15.0), 207 | // _ => Collider::cuboid(4.0, 8.0), 208 | // }; 209 | // let x = enemy_x; 210 | // let y = enemy_y; 211 | 212 | // enemy_x += 40.0; 213 | // commands.spawn(( 214 | // SpriteBundle { 215 | // transform: Transform::from_xyz(x, y, 0.0).with_scale(vec3( 216 | // enemy_scale, 217 | // enemy_scale, 218 | // 1.0, 219 | // )), 220 | // texture: asset_server.load("bound-truck.png"), 221 | // ..default() 222 | // }, 223 | // RigidBody::Fixed, 224 | // ActiveEvents::COLLISION_EVENTS, 225 | // collider, 226 | // Damping { 227 | // angular_damping: 2.0, 228 | // linear_damping: 2.0, 229 | // }, 230 | // enemy_type, 231 | // BoundControlTruck, 232 | // )); 233 | // } 234 | // } 235 | 236 | impl EnemyType { 237 | pub fn random() -> Self { 238 | let all_vals = [Self::Horizontal(3.0), Self::Simple, Self::Truck]; 239 | let mut rng = thread_rng(); 240 | let index = rng.gen_range(0..all_vals.len()); 241 | 242 | all_vals[index].clone() 243 | 
} 244 | 245 | pub fn get_sprite(&self) -> &str { 246 | let mut rng = thread_rng(); 247 | match self { 248 | EnemyType::Simple => { 249 | let choices = ["enemy-blue-1.png", "enemy-yellow-1.png"]; 250 | return choices[rng.gen_range(0..choices.len())]; 251 | } 252 | EnemyType::Horizontal(_) => { 253 | let choices = [ 254 | "enemy-blue-2.png", 255 | "enemy-yellow-2.png", 256 | "enemy-yellow-3.png", 257 | ]; 258 | return choices[rng.gen_range(0..choices.len())]; 259 | } 260 | EnemyType::Truck => "enemy-truck.png", 261 | } 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /src/car.rs: -------------------------------------------------------------------------------- 1 | use crate::dojo::fixed_to_f32; 2 | use crate::*; 3 | use crate::{dojo::dojo_to_bevy_coordinate, nn::Net}; 4 | use bevy::{log, math::vec3, prelude::*}; 5 | use bevy_prototype_debug_lines::DebugLinesPlugin; 6 | use bevy_rapier2d::prelude::*; 7 | use starknet::core::{types::FieldElement, utils::cairo_short_string_to_felt}; 8 | 9 | pub struct CarPlugin; 10 | 11 | #[derive(Component)] 12 | pub struct Car; 13 | 14 | #[derive(Component)] 15 | pub struct Model { 16 | pub nn: Net, 17 | pub nn_outputs: Vec>, 18 | // ray_inputs: Vec, 19 | pub id: FieldElement, 20 | } 21 | 22 | // #[derive(Component, Reflect)] 23 | // struct TurnSpeed(f32); 24 | 25 | // #[derive(Component, Reflect)] 26 | // struct Steer(f32); 27 | 28 | // #[derive(Component, Reflect)] 29 | // struct Speed(f32); 30 | 31 | #[derive(Component)] 32 | pub struct Fitness(pub f32); 33 | 34 | // #[derive(Resource, Default)] 35 | // struct RayCastSensors(Vec<(f32, f32)>); 36 | 37 | // wasd controls 38 | // struct CarControls(bool, bool, bool, bool); 39 | 40 | #[derive(Bundle)] 41 | pub struct CarBundle { 42 | sprite_bundle: SpriteBundle, 43 | car: Car, 44 | fitness: Fitness, 45 | model: Model, 46 | // speed: Speed, 47 | velocity: Velocity, 48 | mass: ColliderMassProperties, 49 | rigid_body: RigidBody, 50 | 
collider: Collider, 51 | events: ActiveEvents, 52 | damping: Damping, 53 | // sleep: Sleeping, 54 | ccd: Ccd, 55 | collision_groups: CollisionGroups, 56 | } 57 | 58 | impl Plugin for CarPlugin { 59 | fn build(&self, app: &mut App) { 60 | app.add_plugin(DebugLinesPlugin::default()) 61 | .add_event::() 62 | .add_event::() 63 | // .register_type::() 64 | // .register_type::() 65 | // .insert_resource(RayCastSensors::default()) 66 | // .add_startup_system(setup) 67 | // .add_systems((car_render_system, spawn_cars)); 68 | .add_systems((spawn_car, update_car, collision_events_system)); 69 | // .add_system(sensors_system) 70 | // .add_system(car_nn_controlled_system.in_schedule(CoreSchedule::FixedUpdate)); 71 | } 72 | } 73 | 74 | pub struct SpawnCar; 75 | 76 | fn spawn_car( 77 | mut events: EventReader, 78 | asset_server: Res, 79 | mut commands: Commands, 80 | ) { 81 | for _ in events.iter() { 82 | let model_id = cairo_short_string_to_felt(configs::MODEL_NAME).unwrap(); 83 | commands.spawn(CarBundle::new(&asset_server, model_id)); 84 | } 85 | } 86 | 87 | pub struct UpdateCar { 88 | pub vehicle: Vec, 89 | } 90 | 91 | fn update_car( 92 | mut events: EventReader, 93 | mut query: Query<(&mut Transform, &Model), With>, 94 | ) { 95 | for e in events.iter() { 96 | if let Ok((mut transform, model)) = query.get_single_mut() { 97 | let (new_x, new_y) = 98 | dojo_to_bevy_coordinate(fixed_to_f32(e.vehicle[0]), fixed_to_f32(e.vehicle[2])); 99 | 100 | log::info!("Vehicle Position ({}), x: {new_x}, y: {new_y}", model.id); 101 | 102 | transform.translation.x = new_x; 103 | transform.translation.y = new_y; 104 | } 105 | } 106 | } 107 | 108 | // fn position_based_movement_system(controls: CarControls, transform: &mut Transform) { 109 | // let a_key = controls.1; 110 | // let d_key = controls.3; 111 | 112 | // let time_step = 1.0 / 60.0; 113 | // let mut rotation_factor = 0.0; 114 | 115 | // if a_key { 116 | // rotation_factor += 0.5; 117 | // } else if d_key { 118 | // rotation_factor -= 
0.5; 119 | // } 120 | 121 | // transform.rotate_z(rotation_factor * 5.0 * time_step); 122 | // } 123 | 124 | // fn setup(mut ray_cast_sensors: ResMut) { 125 | // // Pre compute the raycast directions 126 | // let angle_per_ray = RAYCAST_SPREAD_ANGLE_DEG / (NUM_RAY_CASTS as f32) + 1.0; 127 | // let mut current_angle = RAYCAST_START_ANGLE_DEG; 128 | // for _ in 0..NUM_RAY_CASTS { 129 | // let angle = current_angle * (PI / 180.0); 130 | // let x = angle.cos(); 131 | // let y = angle.sin(); 132 | // ray_cast_sensors.0.push((x, y)); 133 | 134 | // current_angle += angle_per_ray; 135 | // } 136 | // } 137 | 138 | fn collision_events_system( 139 | mut commands: Commands, 140 | mut collision_events: EventReader, 141 | ) { 142 | for collision_event in collision_events.iter() { 143 | match collision_event { 144 | CollisionEvent::Started(entity1, entity2, _) => { 145 | commands.entity(*entity2).remove::(); 146 | commands.entity(*entity1).remove::(); 147 | } 148 | _ => {} 149 | } 150 | } 151 | } 152 | 153 | // fn car_render_system(mut car_query: Query<&mut Transform, With>) { 154 | // for mut transform in car_query.iter_mut() { 155 | // let movement_direction = transform.rotation * Vec3::Y; 156 | // let movement_distance = 3.5; 157 | // let translation_delta = movement_direction * movement_distance; 158 | // transform.translation += translation_delta; 159 | // } 160 | // } 161 | 162 | // fn car_nn_controlled_system( 163 | // mut car_query: Query<(&mut Speed, &mut Model, &mut Transform), With>, 164 | // ) { 165 | // for (mut speed, mut model, mut transform) in car_query.iter_mut() { 166 | // if model.ray_inputs.is_empty() { 167 | // speed.0 = 0.0; 168 | // return; 169 | // } 170 | 171 | // model.nn_outputs = model.nn.predict(&model.ray_inputs); 172 | // let nn_out = model.nn_outputs[NUM_OUPUT_NODES - 1].clone(); 173 | // // nn_out = model.nn.predict(&model.ray_inputs).pop().unwrap(); 174 | 175 | // // let w_key = nn_out[0] >= NN_W_ACTIVATION_THRESHOLD; 176 | // let w_key = 
false; 177 | // // let s_key = nn_out[2] >= NN_S_ACTIVATION_THRESHOLD; 178 | // let s_key = false; 179 | // let mut a_key = false; 180 | // let mut d_key = false; 181 | 182 | // if nn_out[0] >= 0.5 { 183 | // a_key = true; 184 | // } else { 185 | // d_key = true; 186 | // } 187 | 188 | // position_based_movement_system(CarControls(w_key, a_key, s_key, d_key), &mut transform); 189 | // } 190 | // } 191 | 192 | // fn draw_ray_cast( 193 | // lines: &mut DebugLines, 194 | // settings: &Settings, 195 | // start: Vec3, 196 | // end: Vec3, 197 | // color: Color, 198 | // ) { 199 | // if !settings.is_show_rays { 200 | // return; 201 | // } 202 | 203 | // if start.y <= 700.0 && settings.is_hide_rays_at_start { 204 | // return; 205 | // } 206 | 207 | // lines.line_colored(start, end, 0.0, color); 208 | // } 209 | 210 | // fn sensors_system( 211 | // mut lines: ResMut, 212 | // settings: Res, 213 | // ray_cast_sensors: Res, 214 | // rapier_context: Res, 215 | // mut query: Query<(&Transform, &mut Model), With>, 216 | // ) { 217 | // for (transform, mut model) in query.iter_mut() { 218 | // let raycast_filter = CollisionGroups { 219 | // memberships: Group::GROUP_1, 220 | // filters: Group::GROUP_2, 221 | // }; 222 | // let filter = QueryFilter::default().groups(raycast_filter); 223 | // let ray_pos = transform.translation; 224 | // let mut nn_inputs = Vec::new(); 225 | 226 | // // Ray casts 227 | // let rot = transform.rotation.z; 228 | // for (mut x, mut y) in ray_cast_sensors.0.iter() { 229 | // (x, y) = rotate_point(x, y, rot); 230 | // let dest_vec = vec2(x, y); 231 | // let end_point = calculate_endpoint(ray_pos, dest_vec, RAYCAST_MAX_TOI); 232 | // draw_ray_cast(&mut lines, &settings, ray_pos, end_point, Color::RED); 233 | 234 | // let ray_pos_2d = vec2(ray_pos.x, ray_pos.y); 235 | // if let Some((_, toi)) = 236 | // rapier_context.cast_ray(ray_pos_2d, dest_vec, RAYCAST_MAX_TOI, false, filter) 237 | // { 238 | // // The first collider hit has the entity `entity` and it 
hit after 239 | // // the ray travelled a distance equal to `ray_dir * toi`. 240 | // let hit_point = ray_pos_2d + dest_vec * toi; 241 | // let hit_point = vec3(hit_point.x, hit_point.y, 0.0); 242 | 243 | // // Invalidate when hit length more than max toi 244 | // let dist_to_hit = ray_pos.distance(hit_point); 245 | // nn_inputs.push(dist_to_hit as f64 / RAYCAST_MAX_TOI as f64); 246 | // if dist_to_hit > RAYCAST_MAX_TOI { 247 | // continue; 248 | // } 249 | 250 | // draw_ray_cast(&mut lines, &settings, ray_pos, hit_point, Color::GREEN); 251 | // } else { 252 | // nn_inputs.push(1.0); 253 | // } 254 | // } 255 | 256 | // model.ray_inputs = nn_inputs; 257 | // } 258 | // } 259 | 260 | // fn calculate_endpoint(pos: Vec3, direction: Vec2, length: f32) -> Vec3 { 261 | // let dir = direction.normalize(); 262 | // vec3(pos[0] + dir[0] * length, pos[1] + dir[1] * length, 0.0) 263 | // } 264 | 265 | // fn rotate_point(x: f32, y: f32, angle_rad: f32) -> (f32, f32) { 266 | // // Calculate the distance from the origin 267 | // let r = (x * x + y * y).sqrt(); 268 | 269 | // // Calculate the current angle 270 | // let alpha = y.atan2(x); 271 | 272 | // // Add the rotation angle 273 | // let beta = alpha + angle_rad; 274 | 275 | // // Calculate the new coordinates 276 | // let x_prime = r * beta.cos(); 277 | // let y_prime = r * beta.sin(); 278 | 279 | // (x_prime, y_prime) 280 | // } 281 | 282 | impl CarBundle { 283 | pub fn new(asset_server: &AssetServer, model_id: FieldElement) -> Self { 284 | // let mut rng = rand::thread_rng(); 285 | // let rand_x = rng.gen_range(800.0..1100.0); 286 | 287 | Self { 288 | sprite_bundle: SpriteBundle { 289 | // TODO: make cordinate dynamic 290 | transform: Transform::from_xyz(298.0, 0.0, 0.0).with_scale(vec3(2.5, 2.5, 1.0)), 291 | texture: asset_server.load("agent.png"), 292 | ..default() 293 | }, 294 | car: Car, 295 | fitness: Fitness(0.0), 296 | model: Model { 297 | nn: Net::new(vec![ 298 | NUM_RAY_CASTS as usize, 299 | NUM_HIDDEN_NODES, 300 
| NUM_OUPUT_NODES, 301 | ]), 302 | // ray_inputs: Vec::new(), 303 | nn_outputs: Vec::new(), 304 | id: model_id, 305 | }, 306 | // speed: Speed(0.0), 307 | velocity: Velocity::zero(), 308 | mass: ColliderMassProperties::Mass(3000.0), 309 | rigid_body: RigidBody::Dynamic, 310 | collider: Collider::cuboid(5.0, 8.0), 311 | events: ActiveEvents::COLLISION_EVENTS, 312 | damping: Damping { 313 | angular_damping: 100.0, 314 | linear_damping: 100.0, 315 | }, 316 | // sleep: Sleeping::disabled(), 317 | ccd: Ccd::enabled(), 318 | collision_groups: CollisionGroups { 319 | memberships: Group::GROUP_1, 320 | filters: Group::GROUP_2, 321 | }, 322 | } 323 | } 324 | 325 | pub fn with_model(asset_server: &AssetServer, model: &Net) -> Self { 326 | // TODO: generate dojo id 327 | let mode_id = FieldElement::from_dec_str("0").unwrap(); 328 | 329 | let mut car = CarBundle::new(asset_server, mode_id); 330 | car.model.nn = model.clone(); 331 | car 332 | } 333 | } 334 | -------------------------------------------------------------------------------- /src/dojo.rs: -------------------------------------------------------------------------------- 1 | use crate::car::Car; 2 | use crate::car::Model; 3 | use crate::car::SpawnCar; 4 | use crate::car::UpdateCar; 5 | use crate::configs; 6 | use crate::enemy::SpawnEnemies; 7 | use crate::enemy::UpdateEnemy; 8 | use crate::ROAD_X_MIN; 9 | use bevy::ecs::system::SystemState; 10 | use bevy::log; 11 | use bevy::prelude::*; 12 | use bevy_rapier2d::prelude::*; 13 | use bevy_tokio_tasks::TaskContext; 14 | use bevy_tokio_tasks::{TokioTasksPlugin, TokioTasksRuntime}; 15 | use dojo_client::contract::world::WorldContract; 16 | use num::bigint::BigUint; 17 | use num::{FromPrimitive, ToPrimitive}; 18 | use rand::Rng; 19 | use starknet::accounts::SingleOwnerAccount; 20 | use starknet::core::types::{BlockId, BlockTag, FieldElement}; 21 | use starknet::core::utils::cairo_short_string_to_felt; 22 | use starknet::providers::jsonrpc::HttpTransport; 23 | use 
starknet::providers::JsonRpcClient; 24 | use starknet::signers::{LocalWallet, SigningKey}; 25 | use std::ops::Div; 26 | use std::str::FromStr; 27 | use std::sync::Arc; 28 | use tokio::sync::mpsc; 29 | use url::Url; 30 |
/// Returns a random felt holding a whole number in `0..200`, shifted left by 64 bits.
///
/// The low 64 bits (presumably the fractional part of the fixed-point format used
/// by the Cairo contracts; confirm against the on-chain type) are always zero.
pub fn rand_felt_fixed_point() -> FieldElement {
    let mut rng = rand::thread_rng();
    // `gen_range` samples uniformly; `gen::<u128>() % 200` carried a small modulo bias.
    let whole: u128 = rng.gen_range(0..200);
    (whole << 64).into()
}

/// Connection details shared by the systems that talk to the Dojo World contract.
#[derive(Resource)]
pub struct DojoEnv {
    /// The block ID to use for all contract calls.
    block_id: BlockId,
    /// The address of the world contract.
    world_address: FieldElement,
    /// The account to use for performing execution on the World contract.
    account: Arc<SingleOwnerAccount<JsonRpcClient<HttpTransport>, LocalWallet>>,
}

impl DojoEnv {
    /// Wraps `account` in an [`Arc`] and targets the latest block for calls.
    fn new(
        world_address: FieldElement,
        account: SingleOwnerAccount<JsonRpcClient<HttpTransport>, LocalWallet>,
    ) -> Self {
        Self {
            world_address,
            account: Arc::new(account),
            block_id: BlockId::Tag(BlockTag::Latest),
        }
    }
}

/// Bevy plugin that connects the game to a Dojo world on the `KATANA` chain.
pub struct DojoPlugin;

impl Plugin for DojoPlugin {
    /// Builds the signing account from `configs` and registers [`DojoEnv`] as a resource.
    /// Also schedules the Dojo startup systems and the periodic state-sync system.
    ///
    /// # Panics
    ///
    /// Panics if the RPC endpoint, account address, secret key or world address in
    /// `configs` cannot be parsed.
    fn build(&self, app: &mut App) {
        let url = Url::parse(configs::JSON_RPC_ENDPOINT)
            .expect("configs::JSON_RPC_ENDPOINT must be a valid URL");
        let account_address = FieldElement::from_str(configs::ACCOUNT_ADDRESS)
            .expect("configs::ACCOUNT_ADDRESS must be a valid field element");
        let secret_key = FieldElement::from_str(configs::ACCOUNT_SECRET_KEY)
            .expect("configs::ACCOUNT_SECRET_KEY must be a valid field element");
        let chain_id = cairo_short_string_to_felt("KATANA")
            .expect("\"KATANA\" is a valid Cairo short string");

        let account = SingleOwnerAccount::new(
            JsonRpcClient::new(HttpTransport::new(url)),
            LocalWallet::from_signing_key(SigningKey::from_secret_scalar(secret_key)),
            account_address,
            chain_id,
        );

        let world_address = FieldElement::from_str(configs::WORLD_ADDRESS)
            .expect("configs::WORLD_ADDRESS must be a valid field element");

        app.add_plugin(TokioTasksPlugin::default())
            .insert_resource(DojoEnv::new(world_address, account))
            .add_startup_systems((
                setup,
                spawn_racers_thread,
                drive_thread,
                update_vehicle_thread,
                update_enemies_thread,
            ))
            .add_system(sync_dojo_state);
    }
}

/// Spawns the entity carrying the [`DojoSyncTime`] timer queried by `sync_dojo_state`.
fn setup(mut commands: Commands) {
    commands.spawn(DojoSyncTime::from_seconds(configs::DOJO_SYNC_INTERVAL));
}

/// Repeating timer component that `sync_dojo_state` queries.
#[derive(Component)]
struct DojoSyncTime {
    timer: Timer,
}

impl DojoSyncTime {
    /// Creates a repeating timer with a period of `duration` seconds.
    fn from_seconds(duration: f32) -> Self {
        Self {
            timer: Timer::from_seconds(duration, TimerMode::Repeating),
        }
    }
}
105 | 106 | fn sync_dojo_state( 107 | mut dojo_sync_time: Query<&mut DojoSyncTime>, 108 | time: Res