├── .env.example
├── .gitignore
├── CONTRIBUTION_NOTES.md
├── README.md
├── TODO.md
├── ignition-web-rl
├── .gitignore
├── README.md
├── eslint.config.js
├── index.html
├── package-lock.json
├── package.json
├── public
│ ├── logo.png
│ └── vite.svg
├── src
│ ├── App.tsx
│ ├── Experience.tsx
│ ├── assets
│ │ └── react.svg
│ ├── index.css
│ ├── logo-3d.tsx
│ ├── main.tsx
│ ├── simple-agent.tsx
│ ├── store
│ │ └── targetStore.ts
│ ├── target.tsx
│ ├── themes.ts
│ └── vite-env.d.ts
├── tsconfig.app.json
├── tsconfig.json
├── tsconfig.node.json
└── vite.config.ts
├── package.json
├── packages
├── backend-onnx
│ ├── index.ts
│ ├── package.json
│ ├── src
│ │ └── index.ts
│ └── tsconfig.json
├── backend-tfjs
│ ├── backend
│ │ └── setBackend.ts
│ ├── index.ts
│ ├── model
│ │ ├── model.json
│ │ └── weights.bin
│ ├── model_step-10
│ │ ├── model.json
│ │ └── weights.bin
│ ├── package.json
│ ├── src
│ │ ├── agents
│ │ │ ├── dqn.ts
│ │ │ ├── ppo.ts
│ │ │ └── qtable.ts
│ │ ├── index.ts
│ │ ├── io
│ │ │ ├── index.ts
│ │ │ ├── loadModel.ts
│ │ │ └── saveModelToHub.ts
│ │ ├── memory
│ │ │ └── ReplayBuffer.ts
│ │ ├── model
│ │ │ └── BuildMLP.ts
│ │ ├── tools
│ │ │ └── trainer.ts
│ │ └── types.ts
│ ├── test
│ │ ├── checkpoint.test.ts
│ │ ├── dqn.test.ts
│ │ ├── hubIntegration.test.ts
│ │ ├── setBackend.test.ts
│ │ └── trainer.test.ts
│ ├── tmp-model
│ │ ├── model.json
│ │ ├── model
│ │ │ ├── model.json
│ │ │ └── weights.bin
│ │ └── weights.bin
│ ├── tmp-readme
│ │ └── README.md
│ ├── tsconfig.json
│ └── vitest.config.ts
├── core
│ ├── index.ts
│ ├── package.json
│ ├── src
│ │ ├── ignition-env.ts
│ │ ├── index.ts
│ │ └── types.ts
│ └── tsconfig.json
└── demo-target-chasing
│ ├── index.html
│ ├── index.ts
│ ├── package.json
│ ├── readme.md
│ ├── src
│ ├── AgentConfigPanel.tsx
│ ├── Experience.tsx
│ ├── NetworkDesigner.tsx
│ ├── TrainingControls.tsx
│ ├── env.d.ts
│ ├── index.html
│ ├── index.ts
│ ├── main.ts
│ ├── store
│ │ └── trainingStore.ts
│ ├── visualization.html
│ └── visualization.ts
│ ├── styles.css
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ └── vite.config.ts
├── pnpm-lock.yaml
├── pnpm-workspace.yaml
├── r3f
└── target-chasing
│ ├── .gitignore
│ ├── README.md
│ ├── eslint.config.js
│ ├── index.html
│ ├── package.json
│ ├── public
│ ├── logo.png
│ └── vite.svg
│ ├── src
│ ├── App.tsx
│ ├── Experience.tsx
│ ├── TrainingControls.tsx
│ ├── assets
│ │ └── react.svg
│ ├── components
│ │ ├── AgentConfigPanel.tsx
│ │ ├── NetworkDesigner.tsx
│ │ └── VisualizationCharts.tsx
│ ├── index.css
│ ├── main.tsx
│ ├── store
│ │ └── trainingStore.ts
│ ├── styles.css
│ ├── visualization.ts
│ └── vite-env.d.ts
│ ├── tsconfig.app.json
│ ├── tsconfig.json
│ ├── tsconfig.node.json
│ └── vite.config.ts
├── roadmap.md
├── tsconfig.base.json
└── vitest.config.ts
/.env.example:
--------------------------------------------------------------------------------
1 | # Hugging Face Hub
2 | HF_TOKEN=your_huggingface_token_here
3 |
4 | # TensorFlow.js Backend (optional): one of 'webgl', 'cpu', 'wasm', 'webgpu'
5 | TFJS_BACKEND=webgl
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # node
2 | node_modules/
3 | .pnpm/
4 | .pnpm-debug.log
5 | npm-debug.log*
6 |
7 | # dist
8 | dist/
9 | build/
10 | out/
11 | tmp-model/
12 |
13 | # logs
14 | logs
15 | *.log
16 |
17 | # TypeScript
18 | *.tsbuildinfo
19 |
20 | # testing
21 | coverage/
22 | .vitest/
23 |
24 | # IDEs
25 | .vscode/
26 | .idea/
27 | .DS_Store
28 | *.swp
29 |
30 | # OS
31 | Thumbs.db
32 | ehthumbs.db
33 |
34 | # env files
35 | .env
36 | .env.*
37 | !.env.example
38 |
--------------------------------------------------------------------------------
/CONTRIBUTION_NOTES.md:
--------------------------------------------------------------------------------
1 | # Contribution Notes: IgnitionAI Enhancements (April 2025)
2 |
3 | This document details the enhancements and modifications made to the IgnitionAI project, specifically focusing on the `r3f/target-chasing` demo application. The goal was to generalize the project, add visualization capabilities, implement a user-friendly configuration interface, and provide a foundation for a visual network designer, all while keeping long-term contribution in mind.
4 |
5 | ## 1. Project Setup & Analysis
6 |
7 | - **Repository Cloned:** The project was cloned from `https://github.com/IgnitionAI/ignition.git`.
8 | - **Dependency Installation:** Dependencies were installed using `pnpm install` in the root directory.
9 | - **Structure Analysis:** The monorepo structure (`pnpm-workspace.yaml`) was identified, with key packages being `@ignitionai/core`, `@ignitionai/backend-tfjs`, and the demo application `r3f/target-chasing`.
10 | - **Existing Demo:** The `r3f/target-chasing` demo uses React, React Three Fiber (R3F), Rapier physics, and Zustand for state management to visualize a simple agent learning to reach a target in a 3D environment.
11 |
12 | ## 2. Core Enhancements in `r3f/target-chasing`
13 |
14 | Several new components and modifications were introduced to the demo application (`r3f/target-chasing/src`).
15 |
16 | ### 2.1. Visualization Charts (`components/VisualizationCharts.tsx`)
17 |
18 | - **Purpose:** To provide real-time feedback on the agent's training progress.
19 | - **Technology:** Uses the `recharts` library for creating interactive line charts.
20 | - **Features:**
21 | - **Reward Chart:** Displays the reward received by the agent at each step.
22 | - **Loss Chart:** Displays the training loss of the DQN agent (currently simulated, requires integration with `backend-tfjs` training loop).
23 | - **Epsilon Chart:** Shows the decay of the epsilon value (exploration rate) over time (currently simulated based on episode count, ideally fetched from the agent).
24 | - **Implementation:**
25 | - A new React component `VisualizationCharts` was created.
26 | - It uses the `useTrainingStore` (Zustand) to get the current `reward` and `episodeSteps`.
27 | - It maintains internal state (`rewardHistory`, `lossHistory`, `epsilonHistory`) to store data points for the charts.
28 | - `useEffect` hooks are used to update the history arrays when relevant state changes.
29 | - Data history is capped (`maxDataPoints`) to prevent performance degradation.
30 | - `ResponsiveContainer` from `recharts` ensures charts adapt to the panel size.
31 | - **Future Work:**
32 | - Integrate actual loss and epsilon values from the `DQNAgent` in `backend-tfjs`. This might require modifications to the `IgnitionEnv` or `DQNAgent` to expose these values, potentially via callbacks or updated state in the store.
33 |
34 | ### 2.2. Agent Configuration Panel (`components/AgentConfigPanel.tsx`)
35 |
36 | - **Purpose:** Allows users to modify the DQN agent's hyperparameters and network architecture without editing the code directly.
37 | - **Technology:** Standard React functional component with state management (`useState`).
38 | - **Features:**
39 | - Configure network architecture (Input Size, Action Size, Hidden Layers - add/remove/modify neuron counts).
40 | - Configure training parameters (Epsilon, Epsilon Decay, Min Epsilon, Gamma, Learning Rate, Batch Size, Memory Size).
41 | - "Apply Configuration" button triggers a callback (`onApplyConfig`) to pass the updated configuration to the parent component (`App.tsx`).
42 | - **Implementation:**
43 | - Uses controlled input components (` `) for each parameter.
44 | - State hooks manage the current value of each configuration parameter.
45 | - Functions `addLayer`, `removeLayer`, `updateLayer` manage the dynamic hidden layer configuration.
46 | - **Integration:** The `App.tsx` component manages the `agentConfig` state and passes it down to `Experience.tsx`. The `AgentConfigPanel` updates this state via the `handleApplyConfig` callback.
47 |
48 | ### 2.3. Network Designer (`components/NetworkDesigner.tsx`)
49 |
50 | - **Purpose:** Provides a visual, drag-and-drop interface for designing the neural network architecture. This is a foundational implementation.
51 | - **Technology:** Uses the `reactflow` library (v11).
52 | - *Note:* Initially `react-flow-renderer` (v10) was installed but replaced with `reactflow` as the former is deprecated.
53 | - **Features:**
54 | - Displays nodes (Input, Dense Layers, Output) and edges representing connections.
55 | - Basic interaction: Drag nodes, potentially add/remove nodes/edges (though add/remove logic is basic).
56 | - Extracts a simplified hidden layer structure (array of neuron counts) based on the nodes present.
57 | - Updates the parent component (`App.tsx`) via the `onNetworkChange` callback when the structure changes.
58 | - **Implementation:**
59 | - Uses `ReactFlowProvider` and `ReactFlow` components.
60 | - Manages `elements` (nodes and edges) state.
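61 | - `onConnect`, `onElementsRemove`, `onLoad` callbacks handle basic interactions. *(Note: `elements`, `onElementsRemove`, and `onLoad` are pre-v10 React Flow APIs; `reactflow` v11 uses separate `nodes`/`edges` state, `onNodesDelete`/`onEdgesDelete`, and `onInit`. Verify that the implementation matches the installed version.)*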
62 | - `extractNetworkStructure` function attempts to parse the hidden layer configuration from the node labels (this is a simplification).
63 | - **Limitations & Future Work:**
64 | - **Visual Only:** Currently, the designer primarily serves as a visual aid. The actual network structure used by the agent is still primarily driven by the `AgentConfigPanel` (specifically the hidden layer neuron counts).
65 | - **Limited Functionality:** Adding new nodes with specific types (e.g., different activation functions) or editing neuron counts directly on the nodes is not implemented.
66 | - **Structure Extraction:** The logic to translate the visual graph into a precise layer configuration (`extractNetworkStructure`) is basic and needs significant improvement to handle complex graphs, different layer types, and connection validation.
67 | - **Integration:** Needs deeper integration so that the visual design *directly* and accurately defines the agent's network architecture passed to `backend-tfjs`.
68 |
69 | ### 2.4. UI Styling (`styles.css`)
70 |
71 | - **Purpose:** To style the new UI panels and ensure a consistent look.
72 | - **Changes:**
73 | - Added a `.ui-panels` container to hold the control/visualization panels on the right side.
74 | - Added specific styles for `.visualization-charts`, `.agent-config-panel`, and `.network-designer-panel`.
75 | - Included basic styling for `reactflow` elements to match the dark theme.
76 |
77 | ### 2.5. Application Entry Point (`App.tsx`)
78 |
79 | - **Purpose:** Integrates all UI components and manages the overall application state related to configuration.
80 | - **Changes:**
81 | - Imports and renders `VisualizationCharts`, `AgentConfigPanel`, and `NetworkDesigner` within the `.ui-panels` container.
82 | - Manages the `agentConfig` state.
83 | - Implements `handleApplyConfig` and `handleNetworkChange` callbacks to receive updates from the child components and update the `agentConfig` state.
84 | - Passes the `agentConfig` state down to the `Experience` component.
85 | - Modified `startTraining` and `resetEnvironment` calls to pass the current `agentConfig` to the `Experience` component, ensuring the agent is created/reset with the latest settings.
86 |
87 | ### 2.6. Core Experience (`Experience.tsx`)
88 |
89 | - **Purpose:** Manages the 3D scene, physics, agent logic, and environment interaction.
90 | - **Changes:**
91 | - Modified the `Agent` and `Experience` components to accept `agentConfig` as a prop.
92 | - Created an `initializeEnvironment` function within the `Agent` component. This function:
93 | - Takes the `agentConfig` as input.
94 | - Disposes of the previous TFJS model (`agentRefInternal.current?.dispose()`) to prevent memory leaks when the configuration changes.
95 | - Creates a new `DQNAgent` instance using the provided configuration.
96 | - Creates a new `IgnitionEnv` instance with the new agent.
97 | - Added a `useEffect` hook in the `Agent` component that calls `initializeEnvironment` whenever the `agentConfig` prop changes. This ensures the agent and environment are recreated with the new settings.
98 | - Modified the `startTraining` and `resetEnvironment` functions (exposed via `useImperativeHandle`) to accept the `agentConfig` and potentially re-initialize the environment if the config has changed since the last initialization.
99 | - Added `CuboidCollider` with `sensor` property to the `Cible` component for more reliable collision detection.
100 | - Minor refactoring for clarity and state management using Zustand.
101 |
102 | ## 3. Code Generalization & Modularity
103 |
104 | - **Dynamic Configuration:** The most significant generalization was making the agent's configuration dynamic. Instead of being hardcoded in `Experience.tsx`, the configuration is now managed in `App.tsx` and driven by the UI panels (`AgentConfigPanel`, `NetworkDesigner`). This allows users to experiment with different settings without modifying the source code.
105 | - **Component Structure:** Breaking down the UI into separate components (`TrainingControls`, `VisualizationCharts`, `AgentConfigPanel`, `NetworkDesigner`) improves modularity and maintainability.
106 | - **State Management:** Continued use of Zustand (`useTrainingStore`) provides a centralized way to manage training-related state accessible by multiple components.
107 |
108 | ## 4. Roadmap Update (`roadmap.md`)
109 |
110 | - The roadmap was updated to reflect the implemented features:
111 | - Added visualization charts to Phase 2.
112 | - Added dynamic configuration, config panel, and the basic network designer to Phase 3.
113 | - Marked relevant items as completed (✅).
114 | - Added notes about the current limitations of the network designer.
115 |
116 | ## 5. Considerations for Long-Term Contribution
117 |
118 | - **React Flow Integration:** The current Network Designer is basic. A key next step is to implement robust logic to translate the visual graph from React Flow into a valid network configuration (potentially a sequence of layer definitions) that can be directly used by `@ignitionai/backend-tfjs`. This involves defining custom node types, handling connections properly, and potentially adding UI elements for configuring layer parameters directly on the nodes.
119 | - **Backend Integration (Loss/Epsilon):** The visualization charts need access to real-time loss and epsilon values from the backend. This requires exposing these metrics from the `DQNAgent` training loop, possibly through callbacks passed during environment creation or by updating the Zustand store from within the backend package (which might be less ideal due to coupling).
120 | - **Performance:** While basic optimizations weren't the focus of this contribution, potential bottlenecks could arise from frequent state updates for charts or complex React Flow graphs. Performance profiling might be needed later.
121 | - **Error Handling:** More robust error handling could be added, especially around agent creation with potentially invalid configurations from the UI.
122 | - **Code Comments:** Added basic comments to new components, but further commenting, especially in complex logic areas (like environment interaction or future network graph parsing), would be beneficial.
123 |
124 | This contribution provides a significant step towards a more user-friendly and configurable interface for the IgnitionAI framework, laying the groundwork for further enhancements in visualization and no-code network design.
125 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🚀 IgnitionAI - Reinforcement Learning Made Simple
2 |
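3 | [![npm version](https://img.shields.io/npm/v/@ignitionai/backend-tfjs.svg)](https://www.npmjs.com/package/@ignitionai/backend-tfjs)
4 | [![npm downloads](https://img.shields.io/npm/dm/@ignitionai/backend-tfjs.svg)](https://www.npmjs.com/package/@ignitionai/backend-tfjs)
5 | [![GitHub](https://img.shields.io/badge/GitHub-IgnitionAI-black?logo=github)](https://github.com/ignitionai)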
6 |
7 | ---
8 |
9 | IgnitionAI is designed to make Deep Reinforcement Learning **easy**, **modular**, and **production-ready**, especially within browser environments using technologies like WebGPU via TensorFlow.js.
10 |
11 | # 📑 Table of Contents
12 |
13 | - [Overview](#overview)
14 | - [Packages](#packages)
15 | - [Demo: Target Chasing (R3F)](#demo-target-chasing-r3f)
16 | - [Features](#features)
17 | - [Running the Demo](#running-the-demo)
18 | - [Core Library Usage](#core-library-usage)
19 | - [Installation](#installation)
20 | - [Getting Started](#getting-started)
21 | - [1. Import Modules](#1-import-modules)
22 | - [2. Create a DQN Agent](#2-create-a-dqn-agent)
23 | - [3. Create an Environment](#3-create-an-environment)
24 | - [4. Step Through Training](#4-step-through-training)
25 | - [Tips](#tips)
26 | - [Example: Reward Shaping](#example-reward-shaping)
27 | - [Roadmap](#roadmap)
28 | - [Contributing](#contributing)
29 |
30 | ---
31 |
32 | # Overview
33 |
34 | This project provides a set of tools and libraries to facilitate the development and visualization of reinforcement learning agents directly in the browser. It leverages TensorFlow.js for backend computations (including WebGPU support) and React Three Fiber for 3D visualization.
35 |
36 | ---
37 |
38 | # Packages
39 |
40 | This is a monorepo managed with pnpm. Key packages include:
41 |
42 | - **`packages/core`**: Contains the core `IgnitionEnv` class and shared utilities.
43 | - **`packages/backend-tfjs`**: Implements RL agents (like DQN) using TensorFlow.js.
44 | - **`packages/backend-onnx`**: (Planned) Backend for running inference using ONNX Runtime.
45 | - **`r3f/target-chasing`**: A demo application showcasing a DQN agent learning in a 3D environment using React Three Fiber.
46 |
47 | ---
48 |
49 | # Demo: Target Chasing (R3F)
50 |
51 | Located in `r3f/target-chasing`, this demo provides a visual example of a DQN agent learning to navigate a 3D space to reach a target.
52 |
53 | ## Features
54 |
55 | - **3D Visualization:** Uses React Three Fiber (R3F) and Rapier physics to render the agent, target, and environment.
56 | - **Real-time Training:** Watch the agent learn in real-time in your browser.
57 | - **Interactive UI Panels:**
58 | - **Training Controls:** Start, stop, and reset the training process. View basic stats like episode count, success rate, time, and current reward.
59 | - **Visualization Charts:** Real-time charts (using Recharts) displaying Reward, Loss (simulated), and Epsilon Decay (simulated) over training steps.
60 | - **Agent Configuration:** Modify hyperparameters (learning rate, gamma, epsilon settings, etc.) and basic network architecture (input/output size, hidden layers) without code changes. Click "Apply Configuration" to re-initialize the agent with new settings.
61 | - **Network Designer (Basic):** A visual drag-and-drop interface (using React Flow) to represent the network structure. Currently, this is primarily a visual aid; the actual network structure is defined via the Agent Configuration panel.
62 |
63 | ## Running the Demo
64 |
65 | 1. Navigate to the demo directory: `cd r3f/target-chasing`
66 | 2. Install dependencies (if not already done from the root): `pnpm install`
67 | 3. Run the development server: `pnpm dev`
68 | 4. Open the provided URL (usually `http://localhost:5173/`) in your browser.
69 |
70 | ---
71 |
72 | # Core Library Usage
73 |
74 | ## Installation
75 |
76 | ```bash
77 | pnpm install @ignitionai/backend-tfjs @ignitionai/core
78 | # or
79 | npm install @ignitionai/backend-tfjs @ignitionai/core
80 | # or
81 | yarn add @ignitionai/backend-tfjs @ignitionai/core
82 | ```
83 |
84 | ---
85 |
86 | ## Getting Started
87 |
88 | Here's a basic example of using the core library components.
89 |
90 | ### 1. Import Modules
91 |
92 | ```tsx
93 | import { DQNAgent } from '@ignitionai/backend-tfjs'
94 | import { IgnitionEnv } from '@ignitionai/core'
95 | ```
96 |
97 | ---
98 |
99 | ### 2. Create a DQN Agent
100 |
101 | Configure your agent. Note that these parameters can now be dynamically set via the UI in the demo.
102 |
103 | ```tsx
104 | const agentConfig = {
105 | inputSize: 9, // Size of the observation space
106 | actionSize: 4, // Number of possible actions
107 | hiddenLayers: [64, 64], // Example hidden layers
108 | lr: 0.001, // Learning rate
109 | gamma: 0.99, // Discount factor
110 | epsilon: 0.9, // Initial exploration rate
111 | epsilonDecay: 0.97, // Epsilon decay per step
112 | minEpsilon: 0.05, // Minimum exploration
113 | batchSize: 128, // Batch size for training
114 | memorySize: 100000 // Experience replay memory size
115 | };
116 |
117 | const dqnAgent = new DQNAgent(agentConfig);
118 | ```
119 |
120 | ---
121 |
122 | ### 3. Create an Environment
123 |
124 | Define the environment interactions.
125 |
126 | ```tsx
127 | const trainingEnv = new IgnitionEnv({
128 | agent: dqnAgent,
129 |
130 | getObservation: () => {
131 | // Return an array of normalized values representing the current state.
132 | // Example: [agentPosX, agentPosY, targetPosX, targetPosY, ...]
133 | return [];
134 | },
135 |
136 | applyAction: (action: number | number[]) => {
137 | // Apply the chosen action to update your environment state.
138 | console.log("Applying action:", action);
139 | },
140 |
141 | computeReward: () => {
142 | // Return a numerical reward based on the new state after the action.
143 | return 0;
144 | },
145 |
146 | isDone: () => {
147 | // Return true if the episode should end (e.g., agent reaches goal, time limit exceeded).
148 | return false;
149 | },
150 |
151 | onReset: () => {
152 | // Reset the environment to a starting state for the next episode.
153 | }
154 | });
155 | ```
156 |
157 | ---
158 |
159 | ### 4. Step Through Training
160 |
161 | Integrate the `step()` function into your application's loop (e.g., a `requestAnimationFrame` loop or `useFrame` in R3F).
162 |
163 | ```tsx
164 | // Example within a React component using R3F
165 | import { useFrame } from '@react-three/fiber';
166 |
167 | // ... inside your component
168 | useFrame(() => {
169 | if (isTraining) { // Assuming 'isTraining' is a state variable
170 | trainingEnv.step();
171 | }
172 | });
173 | ```
174 |
175 | Each call to `step()` performs one cycle:
176 | - Get observation -> Agent chooses action -> Apply action -> Compute reward -> Store experience -> Potentially train model -> Check if done -> Reset if done.
177 |
178 | ---
179 |
180 | # Tips
181 |
182 | - **Normalize Observations:** Ensure your observation values are scaled, typically between 0 and 1 or -1 and 1, for better network performance.
183 | - **Reward Shaping:** This is critical. Provide intermediate rewards to guide the agent. Don't rely solely on a large reward at the very end. See the example below.
184 | - **Visual Feedback:** Use the provided visualization charts and 3D view in the demo to understand agent behavior and debug issues.
185 | - **Hyperparameter Tuning:** Experiment with learning rate, epsilon decay, network architecture, etc., using the configuration panel in the demo.
186 |
187 | ---
188 |
189 | # Example: Reward Shaping
190 |
191 | **Bad reward shaping (Sparse Reward):**
192 |
193 | ```tsx
194 | // Only rewards reaching the exact goal
195 | computeReward: () => {
196 | return agentReachedTarget ? 100 : 0;
197 | }
198 | ```
199 |
200 | **Good reward shaping (Dense Reward):**
201 |
202 | ```tsx
203 | // Encourage progress toward the goal
204 | computeReward: () => {
205 | const distNow = distance(currentAgentPos, targetPos);
206 | const distBefore = previousDistance; // Store distance from the previous step
207 |
208 | // Reward for getting closer
209 | let reward = (distBefore - distNow) * 10;
210 |
211 | if (agentReachedTarget) {
212 | reward += 100; // Bonus for reaching the goal
213 | }
214 |
215 | // Optional: small per-step penalty (encourages faster completion)
216 | // reward -= 0.1;
217 |
218 | previousDistance = distNow; // Update distance for the next step
219 | return reward;
220 | }
221 | ```
222 |
223 | ✅ Good reward shaping encourages better learning and faster convergence!
224 |
225 | ---
226 |
227 | # Roadmap
228 |
229 | See the [roadmap.md](./roadmap.md) file for planned features and development phases.
230 |
231 | ---
232 |
233 | # Contributing
234 |
235 | Contributions are welcome! Please refer to the [CONTRIBUTION_NOTES.md](./CONTRIBUTION_NOTES.md) for details on recent changes and potential areas for future development.
236 |
237 | ---
238 |
239 | Built with ❤️ by Salim (@IgnitionAI)
240 |
241 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | # IgnitionAI Enhancement Project Todo List
2 |
3 | - [X] **Step 1: Setup & Initial Exploration**
4 | - [X] Install project dependencies (`pnpm install`).
5 | - [X] Explore `packages` directory structure and contents (`core`, `backend-tfjs`, `r3f`).
6 | - [X] Explore `ignition-web-rl` and `r3f` demo directories.
7 | - [X] Understand the interaction between packages.
8 | - [X] **Step 2: Update Roadmap**
9 | - [X] Edit `roadmap.md`.
10 | - [X] Integrate drag-and-drop interface goal.
11 | - [X] Expand visualization goals.
12 | - [X] Add performance and documentation enhancement notes.
13 | - [X] Restructure/rephrase for clarity as a contribution.
14 | - [X] **Step 3: Implement Visualization Functions**
15 | - [X] Identify existing visualization capabilities in `r3f` package/demo.
16 | - [X] Choose and integrate a charting library (e.g., Recharts) into the demo app.
17 | - [X] Implement reward plot visualization.
18 | - [X] Implement loss plot visualization.
19 | - [X] Implement epsilon decay visualization.
20 | - [X] **Step 4: Enhance and Generalize Code**
21 | - [X] **Drag-and-Drop Interface:**
22 | - [X] Choose UI framework/library (likely React within the existing demo structure).
23 | - [X] Choose drag-and-drop library (e.g., React Flow).
24 | - [X] Design node types (layers, inputs, outputs, etc.).
25 | - [X] Implement the drag-and-drop canvas.
26 | - [X] Develop logic to translate the visual graph into an agent configuration (e.g., MLP layers for DQN).
27 | - [X] Integrate the interface with `@ignitionai/backend-tfjs` agent creation.
28 | - [X] **Code Generalization:**
29 | - [X] Refactor agent configuration/creation in `@ignitionai/backend-tfjs` to support dynamic creation from the new interface.
30 | - [X] Improve modularity where needed.
31 | - [X] **Step 5: Test Implementations**
32 | - [X] Test new visualization components.
33 | - [X] Test the drag-and-drop interface for network design.
34 | - [X] Run existing examples/tests to ensure no regressions.
35 | - [X] **Step 6: Document Changes and Improvements**
36 | - [X] Add comments to new code sections.
37 | - [X] Update `README.md` with information about new features (visualizations, drag-and-drop interface).
38 | - [X] Write explanations of the work done, focusing on design choices and concepts for long-term contribution (save as a separate file, e.g., `CONTRIBUTION_NOTES.md`).
39 | - [X] **Step 7: Prepare Final Deliverables**
40 | - [X] Verify all todo items are complete or skipped intentionally.
41 | - [X] Create a patch file or zip archive of the modified `ignition` directory.
42 | - [ ] Prepare a summary message for the user including the deliverables.
43 |
--------------------------------------------------------------------------------
/ignition-web-rl/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 | pnpm-debug.log*
8 | lerna-debug.log*
9 |
10 | node_modules
11 | dist
12 | dist-ssr
13 | *.local
14 |
15 | # Editor directories and files
16 | .vscode/*
17 | !.vscode/extensions.json
18 | .idea
19 | .DS_Store
20 | *.suo
21 | *.ntvs*
22 | *.njsproj
23 | *.sln
24 | *.sw?
25 |
--------------------------------------------------------------------------------
/ignition-web-rl/README.md:
--------------------------------------------------------------------------------
1 | # IgnitionAI Web Reinforcement Learning Demo
2 |
3 | This project demonstrates the capabilities of IgnitionAI's reinforcement learning framework in a web environment. It features a 3D obstacle course where an agent learns to navigate and avoid moving obstacles to reach a target.
4 |
5 | ## Features
6 |
7 | - **3D Environment**: Built with React Three Fiber and Rapier physics
8 | - **Reinforcement Learning**: Powered by IgnitionAI's DQN implementation
9 | - **Interactive Training**: Real-time visualization of the agent's learning process
10 | - **Dynamic Obstacles**: Various movement patterns (horizontal, vertical, circular)
11 | - **Futuristic Design**: Sleek, modern UI with metallic textures and dynamic lighting
12 |
13 | ## Technologies
14 |
15 | - **Frontend**: React, TypeScript, Vite
16 | - **3D Rendering**: Three.js, React Three Fiber
17 | - **Physics**: Rapier (WebAssembly-based physics engine)
18 | - **Machine Learning**: TensorFlow.js, IgnitionAI
19 | - **State Management**: Zustand
20 |
21 | ## Getting Started
22 |
23 | ### Prerequisites
24 |
25 | - Node.js (v18+; required by Vite 6)
26 | - npm or yarn
27 |
28 | ### Installation
29 |
30 | 1. Clone the repository:
31 | ```bash
32 | git clone https://github.com/IgnitionAI/ignition.git
33 | cd ignition/ignition-web-rl
34 | ```
35 |
36 | 2. Install dependencies:
37 | ```bash
38 | npm install
39 | # or
40 | yarn
41 | ```
42 |
43 | 3. Start the development server:
44 | ```bash
45 | npm run dev
46 | # or
47 | yarn dev
48 | ```
49 |
50 | 4. Open your browser and navigate to `http://localhost:5173`
51 |
52 | ## Usage
53 |
54 | 1. **Start Training**: Click the "Start Training" button to begin the reinforcement learning process
55 | 2. **Reset Environment**: Use the reset button to start a new episode
56 | 3. **Adjust Parameters**: Modify learning parameters in `src/Experience.tsx`
57 | 4. **Observe Progress**: Watch the agent improve over time as it learns to navigate the environment
58 |
59 | ## Architecture
60 |
61 | The project is structured as follows:
62 |
63 | - `src/Experience.tsx`: Main environment setup and RL integration
64 | - `src/simple-agent.tsx`: Agent visualization and physics
65 | - `src/target.tsx`: Target object implementation
66 | - `src/themes.ts`: Visual styling configuration
67 |
68 | ## Known Issues
69 |
70 | - Rapier physics engine may occasionally throw Rust errors due to concurrent access to physics objects
71 | - These errors are handled with try/catch blocks and proper synchronization
72 |
73 | ## Contributing
74 |
75 | Contributions are welcome! Please feel free to submit a Pull Request.
76 |
77 | ## License
78 |
79 | This project is licensed under the MIT License - see the LICENSE file for details.
80 |
81 | ## Acknowledgments
82 |
83 | - IgnitionAI team for the reinforcement learning framework
84 | - React Three Fiber community for the 3D rendering capabilities
85 | - Rapier team for the physics engine
86 |
--------------------------------------------------------------------------------
/ignition-web-rl/eslint.config.js:
--------------------------------------------------------------------------------
1 | import js from '@eslint/js'
2 | import globals from 'globals'
3 | import reactHooks from 'eslint-plugin-react-hooks'
4 | import reactRefresh from 'eslint-plugin-react-refresh'
5 | import tseslint from 'typescript-eslint'
6 |
7 | export default tseslint.config(
8 | { ignores: ['dist'] },
9 | {
10 | extends: [js.configs.recommended, ...tseslint.configs.recommended],
11 | files: ['**/*.{ts,tsx}'],
12 | languageOptions: {
13 | ecmaVersion: 2020,
14 | globals: globals.browser,
15 | },
16 | plugins: {
17 | 'react-hooks': reactHooks,
18 | 'react-refresh': reactRefresh,
19 | },
20 | rules: {
21 | ...reactHooks.configs.recommended.rules,
22 | 'react-refresh/only-export-components': [
23 | 'warn',
24 | { allowConstantExport: true },
25 | ],
26 | },
27 | },
28 | )
29 |
--------------------------------------------------------------------------------
/ignition-web-rl/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | IgnitionAI Demo
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/ignition-web-rl/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "ignition-web-rl",
3 | "private": true,
4 | "version": "0.0.0",
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc -b && vite build",
9 | "lint": "eslint .",
10 | "preview": "vite preview"
11 | },
12 | "dependencies": {
13 | "@ignitionai/backend-tfjs": "^0.1.0",
14 | "@ignitionai/core": "^0.1.0",
15 | "@react-three/drei": "^10.0.6",
16 | "@react-three/fiber": "^9.1.2",
17 | "@react-three/rapier": "^2.1.0",
18 | "@tensorflow/tfjs": "^4.22.0",
19 | "@tensorflow/tfjs-backend-webgpu": "^4.22.0",
20 | "@types/three": "^0.162.0",
21 | "react": "^19.0.0",
22 | "react-dom": "^19.0.0",
23 | "three": "^0.162.0",
24 | "zustand": "^5.0.3"
25 | },
26 | "devDependencies": {
27 | "@eslint/js": "^9.22.0",
28 | "@types/react": "^19.0.10",
29 | "@types/react-dom": "^19.0.4",
30 | "@vitejs/plugin-react": "^4.3.4",
31 | "eslint": "^9.22.0",
32 | "eslint-plugin-react-hooks": "^5.2.0",
33 | "eslint-plugin-react-refresh": "^0.4.19",
34 | "globals": "^16.0.0",
35 | "typescript": "~5.7.2",
36 | "typescript-eslint": "^8.26.1",
37 | "vite": "^6.3.1"
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/ignition-web-rl/public/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/ignition-web-rl/public/logo.png
--------------------------------------------------------------------------------
/ignition-web-rl/public/vite.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/App.tsx:
--------------------------------------------------------------------------------
1 | import { Canvas } from '@react-three/fiber'
2 | import { OrbitControls } from '@react-three/drei'
3 | import { Physics } from '@react-three/rapier'
4 | import Experience from './Experience'
5 | import { useState, useRef } from 'react'
6 | import { Themes, ThemeName } from './themes'
7 |
8 | function App() {
9 | const [currentTheme, setCurrentTheme] = useState('Futuristic')
10 | const [isTraining, setIsTraining] = useState(false)
11 | const [episodeCount, setEpisodeCount] = useState(0)
12 | const [totalReward, setTotalReward] = useState(0)
13 | const [episodeTime, setEpisodeTime] = useState(0)
14 |
15 | // Référence aux contrôles d'entraînement
16 | const trainingControlsRef = useRef<{
17 | startTraining: () => void;
18 | stopTraining: () => void;
19 | resetEnvironment: () => void;
20 | } | null>(null)
21 |
22 | const changeTheme = (themeName: ThemeName) => {
23 | setCurrentTheme(themeName)
24 | }
25 |
26 | const handleStartTraining = () => {
27 | if (trainingControlsRef.current) {
28 | trainingControlsRef.current.startTraining()
29 | setIsTraining(true)
30 | }
31 | }
32 |
33 | const handleStopTraining = () => {
34 | if (trainingControlsRef.current) {
35 | trainingControlsRef.current.stopTraining()
36 | setIsTraining(false)
37 | }
38 | }
39 |
40 | const handleResetEnvironment = () => {
41 | if (trainingControlsRef.current) {
42 | trainingControlsRef.current.resetEnvironment()
43 | setTotalReward(0)
44 | }
45 | }
46 |
47 | const handleEnvironmentReady = (controls: {
48 | startTraining: () => void;
49 | stopTraining: () => void;
50 | resetEnvironment: () => void;
51 | }) => {
52 | trainingControlsRef.current = controls
53 | }
54 |
55 | return (
56 | <>
57 |
65 | {Object.keys(Themes).map((themeName) => (
66 | changeTheme(themeName as ThemeName)}
69 | style={{
70 | background: currentTheme === themeName ? Themes[themeName as ThemeName].colors.primary : '#333',
71 | color: 'white',
72 | border: 'none',
73 | padding: '8px 12px',
74 | borderRadius: '4px',
75 | cursor: 'pointer',
76 | fontWeight: currentTheme === themeName ? 'bold' : 'normal'
77 | }}
78 | >
79 | {themeName}
80 |
81 | ))}
82 |
83 |
84 |
99 |
100 | IgnitionAI - Entraînement
101 |
102 |
103 |
104 | Épisodes:
105 | {episodeCount}
106 |
107 |
108 |
109 | Récompense:
110 | {totalReward.toFixed(2)}
111 |
112 |
113 |
114 | Temps écoulé:
115 | {episodeTime} / 60 sec
116 |
117 |
118 |
119 | {!isTraining ? (
120 |
132 | Démarrer
133 |
134 | ) : (
135 |
147 | Arrêter
148 |
149 | )}
150 |
151 |
163 | Réinitialiser
164 |
165 |
166 |
167 |
168 |
169 |
175 |
176 |
177 |
185 |
186 |
187 | >
188 | )
189 | }
190 |
191 | export default App
192 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/assets/react.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/index.css:
--------------------------------------------------------------------------------
/* Make the React root element span the full viewport. */
#root {
  height: 100vh;
  width: 100vw;
}

/* Drop the browser's default body margin so the root sits flush with the window edges. */
body {
  margin: 0;
}
--------------------------------------------------------------------------------
/ignition-web-rl/src/logo-3d.tsx:
--------------------------------------------------------------------------------
1 | import { Center, Text, shaderMaterial } from '@react-three/drei'
2 | import * as THREE from 'three'
3 | import { useFrame, extend } from '@react-three/fiber'
4 | import { useRef } from 'react'
5 | import { DefaultTheme, ThemeProps } from './themes'
6 |
// Vertex stage: hand each vertex's UV coordinates to the fragment shader.
const gradientVertexShader = `
varying vec2 vUv;
void main() {
vUv = uv;
gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
}
`

// Fragment stage: blend colorA into colorB along the vertical texture axis (vUv.y).
const gradientFragmentShader = `
uniform vec3 colorA;
uniform vec3 colorB;
varying vec2 vUv;

void main() {
vec3 color = mix(colorA, colorB, vUv.y);
gl_FragColor = vec4(color, 1.0);
}
`

// Initial uniform values; Logo3D overwrites both from the active theme each frame
// once its material ref is attached.
const gradientUniforms = {
  colorA: new THREE.Color('#a5f3fc'),
  colorB: new THREE.Color('#0c8cbf'),
}

/** Custom two-colour vertical gradient shader material. */
const GradientMaterial = shaderMaterial(gradientUniforms, gradientVertexShader, gradientFragmentShader)

// Register the material with React Three Fiber so it is available as <gradientMaterial /> in JSX.
extend({ GradientMaterial })
37 |
// Type the <gradientMaterial /> element registered through extend().
// React 19's types no longer consult a global `JSX` namespace, and React Three Fiber v9
// builds its intrinsic elements from the `ThreeElements` interface. The element is therefore
// declared there, with its props inferred from the GradientMaterial class instead of `any`.
declare module '@react-three/fiber' {
  interface ThreeElements {
    gradientMaterial: import('@react-three/fiber').ThreeElement<typeof GradientMaterial>
  }
}

/** Props for the floating 3D logo. */
interface Logo3DProps {
  /** Theme whose `logo.startColor` / `logo.endColor` feed the gradient; defaults to DefaultTheme. */
  theme?: ThemeProps;
}
50 |
51 | function Logo3D({ theme = DefaultTheme }: Logo3DProps) {
52 | const groupRef = useRef(null)
53 | const materialRef = useRef(null)
54 |
55 | // Animation de flottement agressive
56 | useFrame((state) => {
57 | if (groupRef.current) {
58 | const time = state.clock.elapsedTime
59 |
60 | // Mouvement vertical plus rapide et plus ample
61 | groupRef.current.position.y = 15 + Math.sin(time * 1.2) * 0.8
62 |
63 | // Ajout d'un mouvement latéral pour plus de dynamisme
64 | groupRef.current.position.x = Math.sin(time * 0.7) * 0.5
65 |
66 | // Mouvement avant/arrière pour effet 3D
67 | groupRef.current.position.z = -15 + Math.sin(time * 0.9) * 0.7
68 |
69 | // Rotations plus prononcées sur plusieurs axes
70 | groupRef.current.rotation.y = Math.sin(time * 0.8) * 0.15
71 | groupRef.current.rotation.x = Math.sin(time * 0.6) * 0.08
72 | groupRef.current.rotation.z = Math.sin(time * 0.5) * 0.05
73 |
74 | // Mise à jour des couleurs du gradient selon le thème
75 | if (materialRef.current) {
76 | materialRef.current.uniforms.colorA.value = new THREE.Color(theme.logo.startColor);
77 | materialRef.current.uniforms.colorB.value = new THREE.Color(theme.logo.endColor);
78 | }
79 | }
80 | })
81 |
82 | return (
83 |
84 |
85 |
92 | IgnitionAI
93 | {/* @ts-ignore */}
94 |
95 |
96 |
97 |
98 | )
99 | }
100 |
101 | export default Logo3D
--------------------------------------------------------------------------------
/ignition-web-rl/src/main.tsx:
--------------------------------------------------------------------------------
1 | import { StrictMode } from 'react'
2 | import { createRoot } from 'react-dom/client'
3 | import './index.css'
4 | import App from './App.tsx'
5 |
6 | createRoot(document.getElementById('root')!).render(
7 |
8 |
9 | ,
10 | )
11 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/simple-agent.tsx:
--------------------------------------------------------------------------------
1 | import { useRef, forwardRef, useImperativeHandle } from 'react'
2 | import { RigidBody, RapierRigidBody } from '@react-three/rapier'
3 | import * as THREE from 'three'
4 | import { DefaultTheme, ThemeProps } from './themes'
5 |
/** Props accepted by the SimpleAgent component. */
type SimpleAgentProps = {
  /** Position of the agent's rigid body; the component defaults it to [0, 1, 0]. */
  position?: [number, number, number];
  /** Visual theme; the component defaults it to DefaultTheme. */
  theme?: ThemeProps;
  /** Called when the agent's collider hits a body named 'obstacle' or 'wall'. */
  onObstacleCollision?: () => void;
};
11 |
12 | const SimpleAgent = forwardRef(({ position = [0, 1, 0], theme = DefaultTheme, onObstacleCollision }: SimpleAgentProps, ref) => {
13 | const bodyRef = useRef(null)
14 | const rigidBodyRef = useRef(null)
15 |
16 | // Exposer la référence
17 | useImperativeHandle(ref, () => rigidBodyRef.current)
18 |
19 | // Animation simple pour donner vie à l'agent
20 | // useFrame((state) => {
21 | // if (bodyRef.current) {
22 | // // Légère oscillation pour simuler une respiration
23 | // bodyRef.current.position.y = Math.sin(state.clock.getElapsedTime() * 2) * 0.1 + position[1]
24 |
25 | // // Légère rotation pour plus de dynamisme
26 | // bodyRef.current.rotation.y = state.clock.getElapsedTime() * 0.5
27 | // }
28 | // })
29 |
30 | return (
31 | {
38 | // Vérifier si la collision est avec un obstacle ou un mur
39 | if (e.other.rigidBodyObject?.name === 'obstacle' ||
40 | e.other.rigidBodyObject?.name === 'wall') {
41 | // Signaler la collision
42 | onObstacleCollision && onObstacleCollision();
43 | }
44 | }}
45 | linearDamping={0.5}
46 | angularDamping={0.5}
47 | mass={1}
48 | >
49 |
50 | {/* Corps simple cubique comme dans Unity ML-Agents */}
51 |
52 |
53 |
60 |
61 |
62 | {/* Yeux */}
63 |
64 | {/* Œil gauche */}
65 |
66 |
67 |
72 |
73 |
74 | {/* Pupille gauche */}
75 |
76 |
77 |
82 |
83 |
84 | {/* Œil droit */}
85 |
86 |
87 |
92 |
93 |
94 | {/* Pupille droite */}
95 |
96 |
97 |
102 |
103 |
104 |
105 | {/* Petit indicateur de direction (sous les yeux) */}
106 |
107 |
108 |
115 |
116 |
117 |
118 | )
119 | })
120 |
121 | export default SimpleAgent
--------------------------------------------------------------------------------
/ignition-web-rl/src/store/targetStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from 'zustand'
2 |
// Side length of the arena.
const ARENA_SIZE = 40
// Margin kept between the target and the arena edges so it never spawns too close to them.
const SAFE_ZONE = 5

/**
 * Pick a random spawn point for the target.
 *
 * x and z are drawn independently and uniformly from
 * [-(ARENA_SIZE / 2 - SAFE_ZONE), ARENA_SIZE / 2 - SAFE_ZONE), i.e. [-15, 15) with the
 * current constants. y is fixed at 1.5 so the target sits above the floor.
 *
 * @returns An `[x, y, z]` tuple.
 */
function generateRandomPosition(): [number, number, number] {
  const span = ARENA_SIZE - 2 * SAFE_ZONE
  const halfExtent = ARENA_SIZE / 2 - SAFE_ZONE
  const randomCoordinate = () => Math.random() * span - halfExtent

  const x = randomCoordinate()
  const z = randomCoordinate()
  return [x, 1.5, z]
}
16 |
/** State of the collectible target: its position and whether it was just picked up. */
interface TargetState {
  position: [number, number, number];
  /** True from a collection until the respawn that follows it. */
  collected: boolean;
  /** Move the target to a fresh random position now, cancelling any pending respawn. */
  resetTarget: () => void;
  /** Mark the target as collected, then respawn it elsewhere after RESPAWN_DELAY_MS. */
  collectTarget: () => void;
}

// Delay between a collection and the target reappearing at a new position.
const RESPAWN_DELAY_MS = 1000

// Pending respawn timer. Tracking it lets resetTarget cancel a stale respawn (which would
// otherwise overwrite the freshly reset position when it fired), and keeps repeated
// collision callbacks from scheduling several respawns in a row.
let respawnTimer: ReturnType<typeof setTimeout> | null = null

// Cancel the pending respawn timer, if any.
function cancelPendingRespawn(): void {
  if (respawnTimer !== null) {
    clearTimeout(respawnTimer)
    respawnTimer = null
  }
}

export const useTargetStore = create<TargetState>()((set, get) => ({
  position: generateRandomPosition(),
  collected: false,
  resetTarget: () => {
    cancelPendingRespawn()
    set({ position: generateRandomPosition(), collected: false })
  },
  collectTarget: () => {
    // Collision callbacks can fire more than once per contact; only the first one counts.
    if (get().collected) return
    set({ collected: true })
    // After a short delay, respawn the target at a new random position
    respawnTimer = setTimeout(() => {
      respawnTimer = null
      set({ position: generateRandomPosition(), collected: false })
    }, RESPAWN_DELAY_MS)
  }
}))

// Export the arena constants for use elsewhere
export { ARENA_SIZE, SAFE_ZONE }
39 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/target.tsx:
--------------------------------------------------------------------------------
1 | import { useRef } from 'react'
2 | import { useFrame } from '@react-three/fiber'
3 | import { RigidBody } from '@react-three/rapier'
4 | import * as THREE from 'three'
5 | import { DefaultTheme, ThemeProps } from './themes'
6 | import { useTargetStore } from './store/targetStore'
7 |
/** Props accepted by the Target component. */
type TargetProps = {
  /** Visual theme; the component defaults it to DefaultTheme. */
  theme?: ThemeProps;
};
11 |
12 | function Target({ theme = DefaultTheme }: TargetProps) {
13 | const targetRef = useRef(null)
14 | const { position, collectTarget } = useTargetStore()
15 |
16 | // Animation pour rendre la cible plus visible et attractive
17 | useFrame((state) => {
18 | if (targetRef.current) {
19 | // Rotation continue
20 | targetRef.current.rotation.y = state.clock.getElapsedTime() * 2
21 |
22 | // Légère oscillation verticale
23 | targetRef.current.position.y = Math.sin(state.clock.getElapsedTime() * 3) * 0.2 + position[1]
24 | }
25 | })
26 |
27 | return (
28 | {
34 | // Vérifier si c'est l'agent qui a touché la cible
35 | if (e.rigidBodyObject?.name === 'agent') {
36 | collectTarget()
37 | }
38 | }}
39 | >
40 |
41 | {/* Sphère principale */}
42 |
43 |
44 |
51 |
52 |
53 | {/* Anneaux décoratifs */}
54 |
55 |
56 |
63 |
64 |
65 |
66 |
67 |
74 |
75 |
76 | {/* Particules lumineuses autour de la cible */}
77 | {[...Array(8)].map((_, i) => (
78 |
87 |
88 |
95 |
96 | ))}
97 |
98 |
99 | )
100 | }
101 |
102 | export default Target
--------------------------------------------------------------------------------
/ignition-web-rl/src/themes.ts:
--------------------------------------------------------------------------------
// Type definitions for the theme system.

/** Physically based material settings shared by every themed mesh. */
export interface MaterialProps {
  color: string;
  metalness: number;
  roughness: number;
  /** Optional glow colour; omitted for non-emissive materials. */
  emissive?: string;
  emissiveIntensity?: number;
}

/** Flat colour palette of a theme (e.g. `primary` highlights the active theme button). */
export interface ThemeColors {
  primary: string;
  secondary: string;
  accent: string;
  background: string;
  floor: string;
  gridCell: string;
  gridSection: string;
}

/** Material settings for each kind of object in the arena. */
export interface ThemeMaterials {
  wall: MaterialProps;
  obstacle: MaterialProps;
  agent: MaterialProps;
  floor: MaterialProps;
  target: MaterialProps;
}

/** Logo styling: the two gradient colours plus the base material. */
export interface ThemeLogo {
  startColor: string;
  endColor: string;
  materialProps: MaterialProps;
}

/** Ambient, directional and spot light settings. */
export interface ThemeLighting {
  ambient: {
    intensity: number;
  };
  directional: {
    intensity: number;
    position: [number, number, number];
  };
  spot: {
    intensity: number;
    color: string;
    angle: number;
    penumbra: number;
    distance: number;
  };
}

/** Floor grid sizing and fade. */
export interface ThemeGrid {
  cellSize: number;
  cellThickness: number;
  sectionSize: number;
  sectionThickness: number;
  fadeStrength: number;
}

/** A complete visual theme: palette, materials, logo, lighting and grid. */
export interface ThemeProps {
  name: string;
  colors: ThemeColors;
  materials: ThemeMaterials;
  logo: ThemeLogo;
  lighting: ThemeLighting;
  grid: ThemeGrid;
}
59 |
60 | // Futuristic theme: the theme currently in use (DefaultTheme points to it)
61 | export const FuturisticTheme: ThemeProps = {
62 | name: "Futuristic",
63 | colors: {
64 | primary: "#0c8cbf",
65 | secondary: "#a5f3fc",
66 | accent: "#3f4e8d",
67 | background: "#171720",
68 | floor: "#171730",
69 | gridCell: "#3f4e8d",
70 | gridSection: "#0c8cbf"
71 | },
72 | materials: {
73 | wall: {
74 | color: "#0c8cbf",
75 | metalness: 0.6,
76 | roughness: 0.2
77 | },
78 | obstacle: {
79 | color: "#a5f3fc",
80 | metalness: 0.8,
81 | roughness: 0.1,
82 | emissive: "#a5f3fc",
83 | emissiveIntensity: 0.2
84 | },
85 | agent: {
86 | color: "#0c8cbf",
87 | metalness: 0.7,
88 | roughness: 0.2,
89 | emissive: "#0c8cbf",
90 | emissiveIntensity: 0.3
91 | },
92 | floor: {
93 | color: "#171730",
94 | metalness: 0.8,
95 | roughness: 0.2
96 | },
97 | target: {
98 | color: "#ff9d00",
99 | metalness: 0.9,
100 | roughness: 0.1,
101 | emissive: "#ff9d00",
102 | emissiveIntensity: 0.8
103 | }
104 | },
105 | logo: {
106 | startColor: "#a5f3fc",
107 | endColor: "#0c8cbf",
108 | materialProps: {
109 | color: "#ffffff",
110 | metalness: 0.8,
111 | roughness: 0.1,
112 | emissive: "#0c8cbf",
113 | emissiveIntensity: 0.5
114 | }
115 | },
116 | lighting: {
117 | ambient: {
118 | intensity: 0.3
119 | },
120 | directional: {
121 | intensity: 1.5,
122 | position: [10, 10, 5]
123 | },
124 | spot: {
125 | intensity: 1,
126 | color: "#0c8cbf",
127 | angle: 0.6,
128 | penumbra: 0.5,
129 | distance: 50
130 | }
131 | },
132 | grid: {
133 | cellSize: 2,
134 | cellThickness: 0.6,
135 | sectionSize: 10,
136 | sectionThickness: 1.5,
137 | fadeStrength: 1
138 | }
139 | };
140 |
141 | // Ancient theme: warm brown palette with mostly matte, low-metalness materials
142 | export const AncientTheme: ThemeProps = {
143 | name: "Ancient",
144 | colors: {
145 | primary: "#8d6e63",
146 | secondary: "#d7ccc8",
147 | accent: "#a1887f",
148 | background: "#3e2723",
149 | floor: "#4e342e",
150 | gridCell: "#8d6e63",
151 | gridSection: "#6d4c41"
152 | },
153 | materials: {
154 | wall: {
155 | color: "#8d6e63",
156 | metalness: 0.1,
157 | roughness: 0.8
158 | },
159 | obstacle: {
160 | color: "#a1887f",
161 | metalness: 0.2,
162 | roughness: 0.7,
163 | emissive: "#d7ccc8",
164 | emissiveIntensity: 0.1
165 | },
166 | agent: {
167 | color: "#d7ccc8",
168 | metalness: 0.3,
169 | roughness: 0.6,
170 | emissive: "#d7ccc8",
171 | emissiveIntensity: 0.2
172 | },
173 | floor: {
174 | color: "#4e342e",
175 | metalness: 0.2,
176 | roughness: 0.8
177 | },
178 | target: {
179 | color: "#ffd54f",
180 | metalness: 0.4,
181 | roughness: 0.5,
182 | emissive: "#ffd54f",
183 | emissiveIntensity: 0.6
184 | }
185 | },
186 | logo: {
187 | startColor: "#d7ccc8",
188 | endColor: "#8d6e63",
189 | materialProps: {
190 | color: "#ffffff",
191 | metalness: 0.3,
192 | roughness: 0.7,
193 | emissive: "#d7ccc8",
194 | emissiveIntensity: 0.3
195 | }
196 | },
197 | lighting: {
198 | ambient: {
199 | intensity: 0.5
200 | },
201 | directional: {
202 | intensity: 1.2,
203 | position: [10, 10, 5]
204 | },
205 | spot: {
206 | intensity: 0.8,
207 | color: "#ffcc80",
208 | angle: 0.7,
209 | penumbra: 0.6,
210 | distance: 40
211 | }
212 | },
213 | grid: {
214 | cellSize: 2,
215 | cellThickness: 0.4,
216 | sectionSize: 10,
217 | sectionThickness: 1.2,
218 | fadeStrength: 0.8
219 | }
220 | };
221 |
222 | // Nature theme: green palette with mostly matte, low-metalness materials
223 | export const NatureTheme: ThemeProps = {
224 | name: "Nature",
225 | colors: {
226 | primary: "#388e3c",
227 | secondary: "#81c784",
228 | accent: "#1b5e20",
229 | background: "#1b3024",
230 | floor: "#2e7d32",
231 | gridCell: "#4caf50",
232 | gridSection: "#2e7d32" // NOTE(review): identical to `floor` above, so grid section lines may not stand out; confirm this is intended
233 | },
234 | materials: {
235 | wall: {
236 | color: "#388e3c",
237 | metalness: 0.1,
238 | roughness: 0.9
239 | },
240 | obstacle: {
241 | color: "#81c784",
242 | metalness: 0.1,
243 | roughness: 0.8,
244 | emissive: "#81c784",
245 | emissiveIntensity: 0.1
246 | },
247 | agent: {
248 | color: "#1b5e20",
249 | metalness: 0.2,
250 | roughness: 0.7,
251 | emissive: "#81c784",
252 | emissiveIntensity: 0.2
253 | },
254 | floor: {
255 | color: "#2e7d32",
256 | metalness: 0.1,
257 | roughness: 0.9
258 | },
259 | target: {
260 | color: "#8bc34a",
261 | metalness: 0.3,
262 | roughness: 0.6,
263 | emissive: "#8bc34a",
264 | emissiveIntensity: 0.5
265 | }
266 | },
267 | logo: {
268 | startColor: "#81c784",
269 | endColor: "#1b5e20",
270 | materialProps: {
271 | color: "#ffffff",
272 | metalness: 0.2,
273 | roughness: 0.8,
274 | emissive: "#81c784",
275 | emissiveIntensity: 0.3
276 | }
277 | },
278 | lighting: {
279 | ambient: {
280 | intensity: 0.6
281 | },
282 | directional: {
283 | intensity: 1.3,
284 | position: [10, 10, 5]
285 | },
286 | spot: {
287 | intensity: 0.7,
288 | color: "#aed581",
289 | angle: 0.8,
290 | penumbra: 0.7,
291 | distance: 45
292 | }
293 | },
294 | grid: {
295 | cellSize: 2,
296 | cellThickness: 0.5,
297 | sectionSize: 10,
298 | sectionThickness: 1.3,
299 | fadeStrength: 0.9
300 | }
301 | };
302 |
/** Every selectable theme, keyed by its display name. */
export const Themes = {
  Futuristic: FuturisticTheme,
  Ancient: AncientTheme,
  Nature: NatureTheme,
};

/** Union of the keys of `Themes`: "Futuristic" | "Ancient" | "Nature". */
export type ThemeName = keyof typeof Themes;

/** Theme used when a component is not given one explicitly: the futuristic theme. */
export const DefaultTheme = Themes.Futuristic;
315 |
--------------------------------------------------------------------------------
/ignition-web-rl/src/vite-env.d.ts:
--------------------------------------------------------------------------------
1 | ///
2 |
--------------------------------------------------------------------------------
/ignition-web-rl/tsconfig.app.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
4 | "target": "ES2020",
5 | "useDefineForClassFields": true,
6 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
7 | "module": "ESNext",
8 | "skipLibCheck": true,
9 |
10 | /* Bundler mode */
11 | "moduleResolution": "bundler",
12 | "allowImportingTsExtensions": true,
13 | "isolatedModules": true,
14 | "moduleDetection": "force",
15 | "noEmit": true,
16 | "jsx": "react-jsx",
17 |
18 | /* Linting */
19 | "strict": true,
20 | "noUnusedLocals": true,
21 | "noUnusedParameters": true,
22 | "noFallthroughCasesInSwitch": true,
23 | "noUncheckedSideEffectImports": true
24 | },
25 | "include": ["src"]
26 | }
27 |
--------------------------------------------------------------------------------
/ignition-web-rl/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "files": [],
3 | "references": [
4 | { "path": "./tsconfig.app.json" },
5 | { "path": "./tsconfig.node.json" }
6 | ]
7 | }
8 |
--------------------------------------------------------------------------------
/ignition-web-rl/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
4 | "target": "ES2022",
5 | "lib": ["ES2023"],
6 | "module": "ESNext",
7 | "skipLibCheck": true,
8 |
9 | /* Bundler mode */
10 | "moduleResolution": "bundler",
11 | "allowImportingTsExtensions": true,
12 | "isolatedModules": true,
13 | "moduleDetection": "force",
14 | "noEmit": true,
15 |
16 | /* Linting */
17 | "strict": true,
18 | "noUnusedLocals": true,
19 | "noUnusedParameters": true,
20 | "noFallthroughCasesInSwitch": true,
21 | "noUncheckedSideEffectImports": true
22 | },
23 | "include": ["vite.config.ts"]
24 | }
25 |
--------------------------------------------------------------------------------
/ignition-web-rl/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vite'
2 | import react from '@vitejs/plugin-react'
3 | import path from 'path'
4 |
// https://vite.dev/config/

// Map the workspace packages onto their built `dist` folders under node_modules.
const workspaceAliases = {
  '@ignitionai/backend-tfjs': path.resolve(__dirname, 'node_modules/@ignitionai/backend-tfjs/dist'),
  '@ignitionai/core': path.resolve(__dirname, 'node_modules/@ignitionai/core/dist'),
}

export default defineConfig({
  plugins: [react()],
  resolve: { alias: workspaceAliases },
})
15 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "ignition-monorepo",
3 | "author": "salim4n (@IgnitionAI)",
4 | "license": "MIT",
5 | "private": false,
6 | "workspaces": [
7 | "packages/*"
8 | ],
9 | "scripts": {
10 | "build": "pnpm -r build",
11 | "test": "vitest run",
12 | "r3f:dev": "pnpm --filter @ignitionai/r3f dev",
13 | "r3f:build": "pnpm --filter @ignitionai/r3f build",
14 | "r3f:preview": "pnpm --filter @ignitionai/r3f preview",
15 | "publish-packages": "pnpm build && pnpm -r publish",
16 | "version-packages": "pnpm -r version",
17 | "clean": "pnpm -r exec -- rm -rf dist node_modules",
18 | "lint": "pnpm -r lint"
19 | },
20 | "packageManager": "pnpm@10.8.0+sha512.0e82714d1b5b43c74610193cb20734897c1d00de89d0e18420aebc5977fa13d780a9cb05734624e81ebd81cc876cd464794850641c48b9544326b5622ca29971",
21 | "devDependencies": {
22 | "@vitest/ui": "^3.1.1",
23 | "dotenv": "^16.5.0",
24 | "typescript": "^5.8.3",
25 | "vitest": "^3.1.1"
26 | },
27 | "dependencies": {
28 | "reactflow": "^11.11.4"
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/packages/backend-onnx/index.ts:
--------------------------------------------------------------------------------
/** Smoke-test helper: logs a greeting that identifies this package. */
export function backendOnnxHello(): void {
  console.log("Hello from backend-onnx");
}
--------------------------------------------------------------------------------
/packages/backend-onnx/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@ignitionai/backend-onnx",
3 | "version": "0.1.0",
4 | "main": "dist/index.js",
5 | "types": "dist/index.d.ts",
6 | "scripts": {
7 | "build": "tsc"
8 | },
9 | "dependencies": {}
10 | }
11 |
--------------------------------------------------------------------------------
/packages/backend-onnx/src/index.ts:
--------------------------------------------------------------------------------
// Backend ONNX implementation: currently exposes only package metadata.

/** Version string of this backend package. */
export const version = '0.1.0';

const backendOnnxInfo = {
  name: 'backend-onnx',
  version,
};

export default backendOnnxInfo;
--------------------------------------------------------------------------------
/packages/backend-onnx/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../../tsconfig.base.json",
3 | "compilerOptions": {
4 | "outDir": "dist"
5 | },
6 | "include": ["src/**/*"]
7 | }
8 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/backend/setBackend.ts:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 | import '@tensorflow/tfjs-backend-webgpu'; // Experimental
3 |
/**
 * Initialize TensorFlow.js with the best available backend.
 * Preference order: WebGPU > WebGL > CPU > WASM. Only backends whose factory is
 * registered in this bundle are attempted.
 * @returns The name of the backend that was successfully activated.
 * @throws Error when none of the candidate backends can be initialized.
 */
export async function initTfjsBackend(): Promise<string> {
  // List of backends in order of preference
  const backends = ['webgpu', 'webgl', 'cpu', 'wasm'];

  // findBackendFactory reports registration without initializing the backend. findBackend
  // returns null (never undefined) for unknown or still-initializing backends, so an
  // `!== undefined` check lets every name through and eagerly initializes the ones it probes.
  const availableBackends = backends.filter(b => tf.findBackendFactory(b) != null);
  console.log('Available backends:', availableBackends);

  for (const backend of availableBackends) {
    try {
      // setBackend resolves to false (rather than throwing) when initialization fails.
      const initialized = await tf.setBackend(backend);
      if (!initialized) {
        console.warn(`Unable to initialize ${backend} backend`);
        continue;
      }
      await tf.ready();
      const currentBackend = tf.getBackend();
      console.log(`TensorFlow.js using backend: ${currentBackend}`);
      return currentBackend;
    } catch (error) {
      console.warn(`Unable to initialize ${backend} backend:`, error);
    }
  }

  throw new Error('No TensorFlow.js backend available');
}
31 |
/**
 * Check whether the WebGPU API is exposed by the current environment.
 * Only the presence of `navigator.gpu` is tested; no adapter is requested.
 * @returns false when `navigator` does not exist (e.g. older Node.js runtimes used by tests).
 */
export function isWebGPUAvailable(): boolean {
  // Guard with typeof: referencing an undeclared `navigator` would throw a ReferenceError.
  return typeof navigator !== 'undefined' && 'gpu' in navigator;
}
38 |
/**
 * Snapshot the active TensorFlow.js backend name together with the memory counters
 * reported by `tf.memory()`.
 */
export function getTfjsBackendInfo() {
  // The backend name is read before tf.memory(), matching the original call order.
  const backendName = tf.getBackend();
  const { numTensors, numDataBuffers, numBytes, unreliable } = tf.memory();

  return { backend: backendName, numTensors, numDataBuffers, numBytes, unreliable };
}
54 |
/**
 * Configure the TensorFlow.js backend, falling back to CPU if activation fails.
 * @param name Backend to use ('webgpu', 'webgl', 'cpu', or 'wasm'); defaults to 'webgl'.
 * @returns Resolves once the requested backend (or the CPU fallback) is ready. Returns early
 *   with a warning when `name` is not registered in this bundle.
 */
export async function setBackendSafe(
  name: 'cpu' | 'webgl' | 'wasm' | 'webgpu' = "webgl"
): Promise<void> {
  // findBackendFactory returns null for unregistered names (findBackend never returns undefined,
  // so comparing it against undefined could not detect an unsupported backend).
  if (tf.findBackendFactory(name) == null) {
    console.warn(`[TFJS] Backend "${name}" is not supported in this environment.`);
    return;
  }

  if (name === 'webgpu') {
    console.warn('⚠️ [TFJS] "webgpu" is experimental and may conflict with WebGL/R3F.');
  }

  try {
    // A failed initialization resolves to false instead of throwing; treat it as an error.
    if (!(await tf.setBackend(name))) {
      throw new Error('backend initialization failed');
    }
    await tf.ready();
    console.log(`[TFJS] Using backend: ${tf.getBackend()}`);
  } catch (error) {
    // Catch variables are `unknown` under strict mode, so narrow before reading `.message`.
    const message = error instanceof Error ? error.message : String(error);
    console.warn(`[TFJS] Failed to set backend "${name}": ${message}`);
    // Fall back to CPU when it is registered and was not the backend that just failed.
    if (name !== 'cpu' && tf.findBackendFactory('cpu') != null) {
      console.warn('[TFJS] Falling back to CPU backend.');
      if (await tf.setBackend('cpu')) {
        await tf.ready();
      } else {
        console.warn('[TFJS] CPU fallback could not be initialized.');
      }
    }
  }
}
86 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/index.ts:
--------------------------------------------------------------------------------
1 | // Re-export all from src/index.ts
2 | export * from './src/index';
3 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/model/model.json:
--------------------------------------------------------------------------------
1 | {"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_6","layers":[{"class_name":"Dense","config":{"units":8,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense11","trainable":true,"batch_input_shape":[null,2],"dtype":"float32"}},{"class_name":"Dense","config":{"units":3,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense12","trainable":true}}]},"keras_version":"tfjs-layers 4.22.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense11/kernel","shape":[2,8],"dtype":"float32"},{"name":"dense_Dense11/bias","shape":[8],"dtype":"float32"},{"name":"dense_Dense12/kernel","shape":[8,3],"dtype":"float32"},{"name":"dense_Dense12/bias","shape":[3],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.22.0","convertedBy":null}
--------------------------------------------------------------------------------
/packages/backend-tfjs/model/weights.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/model/weights.bin
--------------------------------------------------------------------------------
/packages/backend-tfjs/model_step-10/model.json:
--------------------------------------------------------------------------------
1 | {"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"Dense","config":{"units":8,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true,"batch_input_shape":[null,2],"dtype":"float32"}},{"class_name":"Dense","config":{"units":2,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense2","trainable":true}}]},"keras_version":"tfjs-layers 4.22.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense1/kernel","shape":[2,8],"dtype":"float32"},{"name":"dense_Dense1/bias","shape":[8],"dtype":"float32"},{"name":"dense_Dense2/kernel","shape":[8,2],"dtype":"float32"},{"name":"dense_Dense2/bias","shape":[2],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.22.0","convertedBy":null}
--------------------------------------------------------------------------------
/packages/backend-tfjs/model_step-10/weights.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/model_step-10/weights.bin
--------------------------------------------------------------------------------
/packages/backend-tfjs/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@ignitionai/backend-tfjs",
3 | "version": "0.1.0",
4 | "description": "TensorFlow.js backend for IgnitionAI - browser-based reinforcement learning framework",
5 | "main": "dist/index.js",
6 | "types": "dist/index.d.ts",
7 | "module": "dist/index.esm.js",
8 | "files": [
9 | "dist"
10 | ],
11 | "scripts": {
12 | "build": "tsc",
13 | "test": "vitest --config ../../vitest.config.ts",
14 | "prepublishOnly": "npm run build"
15 | },
16 | "keywords": [
17 | "reinforcement-learning",
18 | "tensorflow",
19 | "tfjs",
20 | "dqn",
21 | "ppo",
22 | "q-learning",
23 | "browser",
24 | "webgl",
25 | "webgpu"
26 | ],
27 | "author": "salim4n (@IgnitionAI)",
28 | "license": "MIT",
29 | "repository": {
30 | "type": "git",
31 | "url": "https://github.com/IgnitionAI/ignition-monorepo-starter"
32 | },
33 | "homepage": "https://github.com/IgnitionAI/ignition-monorepo-starter#readme",
34 | "dependencies": {
35 | "@huggingface/hub": "^1.1.2",
36 | "@tensorflow/tfjs": "^4.22.0",
37 | "@tensorflow/tfjs-backend-cpu": "^4.22.0",
38 | "@tensorflow/tfjs-backend-wasm": "^4.22.0",
39 | "@tensorflow/tfjs-backend-webgl": "^4.22.0",
40 | "@tensorflow/tfjs-backend-webgpu": "^4.22.0",
41 | "@tensorflow/tfjs-vis": "^1.5.1"
42 | },
43 | "devDependencies": {
44 | "@tensorflow/tfjs-node": "^4.22.0",
45 | "dotenv": "^16.5.0",
46 | "form-data": "^4.0.2",
47 | "vitest": "^3.1.1"
48 | },
49 | "publishConfig": {
50 | "access": "public"
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/agents/dqn.ts:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 |
3 | import { loadModelFromHub } from '../io/loadModel';
4 | import { saveModelToHub } from '../io/saveModelToHub';
5 | import {
6 | Experience,
7 | ReplayBuffer,
8 | } from '../memory/ReplayBuffer';
9 | import { buildQNetwork } from '../model/BuildMLP';
10 | import { DQNConfig } from '../types';
11 |
/**
 * Deep Q-Network agent built on TensorFlow.js.
 *
 * - `model` is the online Q-network, fitted on minibatches replayed from memory.
 * - `targetModel` supplies the bootstrap Q-values and is re-synchronised with `model`
 *   every `targetUpdateFrequency` training steps.
 * - Actions are chosen epsilon-greedily. Epsilon decays multiplicatively after each
 *   training step and stops at `minEpsilon`; an initial epsilon already at or below
 *   `minEpsilon` is left untouched.
 */
export class DQNAgent {
  private model: tf.Sequential;
  private targetModel: tf.Sequential;
  private memory: ReplayBuffer;
  private epsilon: number;
  private epsilonDecay: number;
  private minEpsilon: number;
  private gamma: number;
  private batchSize: number;
  private targetUpdateFrequency: number;
  private trainStepCounter = 0;
  private actionSize: number;
  private bestReward: number = -Infinity;

  /**
   * @param config Network and training hyper-parameters; omitted optional fields fall
   *   back to the defaults destructured below.
   */
  constructor(private config: DQNConfig) {
    const {
      inputSize,
      actionSize,
      hiddenLayers = [24, 24],
      gamma = 0.99,
      epsilon = 1.0,
      epsilonDecay = 0.995,
      minEpsilon = 0.01,
      lr = 0.001,
      batchSize = 32,
      memorySize = 10000,
      targetUpdateFrequency = 1000,
    } = config;

    this.actionSize = actionSize;
    this.gamma = gamma;
    this.epsilon = epsilon;
    this.epsilonDecay = epsilonDecay;
    this.minEpsilon = minEpsilon;
    this.batchSize = batchSize;
    this.targetUpdateFrequency = targetUpdateFrequency;

    this.model = buildQNetwork(inputSize, actionSize, hiddenLayers, lr);
    this.targetModel = buildQNetwork(inputSize, actionSize, hiddenLayers, lr);
    // Start both networks from identical weights (synchronous copy, no floating promise).
    this.targetModel.setWeights(this.model.getWeights());

    this.memory = new ReplayBuffer(memorySize);
  }

  /**
   * Pick an action for `state`: a uniformly random action with probability epsilon,
   * otherwise the index of the largest Q-value predicted by the online network.
   *
   * @param state Observation vector; its length must match `config.inputSize`.
   * @returns Index of the chosen action in `[0, actionSize)`.
   */
  async getAction(state: number[]): Promise<number> {
    if (Math.random() < this.epsilon) {
      return Math.floor(Math.random() * this.actionSize);
    }

    // tidy() frees the input tensor and the raw Q-values; only the argMax result
    // escapes, and it is disposed explicitly once read.
    const actionTensor = tf.tidy(() => {
      const qValues = this.model.predict(tf.tensor2d([state])) as tf.Tensor;
      return qValues.argMax(1);
    });
    try {
      const [action] = await actionTensor.data();
      return action;
    } finally {
      actionTensor.dispose();
    }
  }

  /** Store a transition in replay memory. */
  remember(exp: Experience): void {
    this.memory.add(exp);
  }

  /** Copy the online network's weights into the target network. */
  async updateTargetModel(): Promise<void> {
    this.targetModel.setWeights(this.model.getWeights());
  }

  /**
   * Run one DQN update on a random minibatch from replay memory.
   *
   * Does nothing until memory holds at least `batchSize` experiences. For each sampled
   * transition the target is `reward` if `done`, otherwise
   * `reward + gamma * max_a' Q_target(nextState, a')`. Only the taken action's output is
   * replaced; the other outputs keep the online prediction, so they contribute no error.
   */
  async train(): Promise<void> {
    if (this.memory.size() < this.batchSize) return;

    const batch = this.memory.sample(this.batchSize);
    // Every tensor allocated below is collected here and released in `finally`, so a
    // failing predict()/fit() does not leak memory.
    const allocated: tf.Tensor[] = [];

    try {
      const stateTensor = tf.tensor2d(batch.map(e => e.state));
      allocated.push(stateTensor);
      const nextStateTensor = tf.tensor2d(batch.map(e => e.nextState));
      allocated.push(nextStateTensor);

      const qValues = this.model.predict(stateTensor) as tf.Tensor2D;
      allocated.push(qValues);
      const nextQValues = this.targetModel.predict(nextStateTensor) as tf.Tensor2D;
      allocated.push(nextQValues);

      const qArray = qValues.arraySync() as number[][];
      const nextQArray = nextQValues.arraySync() as number[][];

      const updatedQ = qArray.map((q, i) => {
        const { action, reward, done } = batch[i];
        q[action] = done ? reward : reward + this.gamma * Math.max(...nextQArray[i]);
        return q;
      });

      const targetTensor = tf.tensor2d(updatedQ);
      allocated.push(targetTensor);
      await this.model.fit(stateTensor, targetTensor, { epochs: 1, verbose: 0 });
    } finally {
      tf.dispose(allocated);
    }

    // Multiplicative decay, clamped so epsilon never undershoots minEpsilon.
    if (this.epsilon > this.minEpsilon) {
      this.epsilon = Math.max(this.minEpsilon, this.epsilon * this.epsilonDecay);
    }

    this.trainStepCounter++;
    if (this.trainStepCounter % this.targetUpdateFrequency === 0) {
      await this.updateTargetModel();
    }
  }

  /** Restore initial epsilon, clear replay memory and reset the training-step counter. */
  reset(): void {
    this.epsilon = this.config.epsilon ?? 1.0;
    this.memory = new ReplayBuffer(this.config.memorySize);
    this.trainStepCounter = 0;
  }

  /**
   * Upload the online network to the Hub under the folder `<modelName>_<checkpointName>`.
   */
  async saveToHub(repoId: string, token: string, modelName = 'model', checkpointName = 'last'): Promise<void> {
    console.log(`[DQN] Saving model to HF Hub: ${repoId}`);
    // TODO: derive the folder name from a naming template.
    await saveModelToHub(this.model, repoId, token, `${modelName}_${checkpointName}`);
  }

  /**
   * Replace the online network with a model loaded from the Hub and sync the target network.
   *
   * @param modelPath Path of the model.json inside the repo.
   */
  async loadFromHub(repoId: string, modelPath = 'model.json'): Promise<void> {
    console.log(`[DQN] Loading model from HF Hub: ${repoId}`);
    const loadedModel = await loadModelFromHub(repoId, modelPath);
    await this.installLoadedModel(loadedModel as tf.Sequential);
  }

  /**
   * Save the model under a checkpoint name to Hugging Face Hub.
   * e.g., checkpointName = "last", "best", "step-1000" (stored in folder `model_<checkpointName>`)
   */
  async saveCheckpoint(repoId: string, token: string, checkpointName: string): Promise<void> {
    const folder = `model_${checkpointName}`;
    console.log(`[DQN] Saving checkpoint "${checkpointName}" to HF Hub...`);
    await saveModelToHub(this.model, repoId, token, folder);
    console.log(`[DQN] ✅ Checkpoint "${checkpointName}" saved`);
  }

  /**
   * Save a checkpoint when `reward` beats the best reward seen by this method so far.
   * The checkpoint is named `step-<step>` when `step` is given, otherwise `best`.
   */
  async maybeSaveBestCheckpoint(repoId: string, token: string, reward: number, step?: number): Promise<void> {
    console.log(`[DQN] Current best: ${this.bestReward.toFixed(4)}, new reward: ${reward.toFixed(4)}`);
    if (reward > this.bestReward) {
      console.log(`[DQN] 🏆 New best reward: ${reward.toFixed(3)} > ${this.bestReward.toFixed(3)}`);
      this.bestReward = reward;
      const checkpointName = step !== undefined ? `step-${step}` : 'best';
      await this.saveCheckpoint(repoId, token, checkpointName);
    }
  }

  /**
   * Load a checkpointed model (`model_<checkpointName>/model.json`) from Hugging Face Hub.
   */
  async loadCheckpoint(repoId: string, checkpointName: string): Promise<void> {
    const modelPath = `model_${checkpointName}/model.json`;
    console.log(`[DQN] Loading checkpoint "${checkpointName}" from HF Hub...`);
    const model = await loadModelFromHub(repoId, modelPath);
    await this.installLoadedModel(model as tf.Sequential);
    console.log(`[DQN] ✅ Checkpoint "${checkpointName}" loaded`);
  }

  /** Dispose both networks (and their optimizers) and drop replay memory. */
  dispose(): void {
    console.log(`[DQN] Disposing model...`);
    this.disposeNetwork(this.model);
    console.log(`[DQN] Model disposed`);
    this.disposeNetwork(this.targetModel);
    console.log(`[DQN] Target model disposed`);
    this.memory = new ReplayBuffer(0);
    console.log(`[DQN] Memory disposed`);
    console.log(`[DQN] ✅ DQNAgent disposed`);
  }

  /**
   * Make `loaded` the online network and copy its weights into the target network.
   *
   * A model restored from a model.json without a training config is not compiled, and
   * `train()` would then fail in `fit()`. It is therefore (re)compiled with the same Adam
   * optimizer and MSE loss that `buildQNetwork` uses. The replaced network is disposed.
   * If its weights cannot be copied into the target network (e.g. an architecture
   * mismatch), the previous network is kept and the loaded one is disposed.
   */
  private async installLoadedModel(loaded: tf.Sequential): Promise<void> {
    loaded.compile({
      optimizer: tf.train.adam(this.config.lr ?? 0.001),
      loss: 'meanSquaredError',
    });

    const previous = this.model;
    this.model = loaded;
    try {
      await this.updateTargetModel();
    } catch (err) {
      this.model = previous;
      this.disposeNetwork(loaded);
      throw err;
    }
    this.disposeNetwork(previous);
  }

  /**
   * Dispose a network together with the optimizer it was compiled with.
   * NOTE(review): tfjs-layers does not dispose optimizer instances passed to compile() as
   * objects (as done here and in buildQNetwork), so they are released explicitly —
   * confirm against the installed tfjs version.
   */
  private disposeNetwork(network: tf.LayersModel | undefined): void {
    if (network === undefined) return;
    network.optimizer?.dispose();
    network.dispose();
  }
}
180 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/agents/ppo.ts:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/src/agents/ppo.ts
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/agents/qtable.ts:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/src/agents/qtable.ts
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/index.ts:
--------------------------------------------------------------------------------
1 | export * from './agents/dqn';
2 | //export * from './agents/ppo';
3 | //export * from './agents/qtable';
4 | export * from './memory/ReplayBuffer';
5 | export * from './model/BuildMLP';
6 | export * from './types';
7 | export { DQNAgent } from './agents/dqn';
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/io/index.ts:
--------------------------------------------------------------------------------
1 | export * from './saveModelToHub';
2 | export * from './loadModel';
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/io/loadModel.ts:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 |
/**
 * Resolve after `ms` milliseconds; used to space out retry attempts.
 * @param ms Number of milliseconds to wait.
 */
async function sleep(ms: number): Promise<void> {
  await new Promise<void>((wake) => {
    setTimeout(wake, ms);
  });
}
10 |
/**
 * Load a TensorFlow.js model hosted on Hugging Face Hub, retrying with exponential backoff.
 *
 * The file is fetched from `https://huggingface.co/<repoId>/resolve/main/<filename>`.
 * - `graphModel = false` (default) uses `tf.loadLayersModel`. The result can be compiled
 *   and fine-tuned.
 * - `graphModel = true` uses `tf.loadGraphModel`. This is inference-only and needs a
 *   model.json in the TF.js graph-model format; a layers-model file cannot be loaded this way.
 *
 * @param repoId string - full repo name (e.g. "salim4n/my-dqn-model")
 * @param filename string - path of the model.json inside the repo; defaults to "model.json"
 * @param graphModel boolean - load as GraphModel instead of LayersModel; defaults to false
 * @param maxRetries number - total number of load attempts; at least one attempt is always made
 * @param initialDelay number - delay in ms before the second attempt; doubles after each further failure
 * @returns the loaded model
 * @throws the error raised by the last failed attempt
 */
export async function loadModelFromHub(
  repoId: string,
  filename: string = 'model.json',
  graphModel: boolean = false,
  maxRetries: number = 3,
  initialDelay: number = 2000
): Promise<tf.LayersModel | tf.GraphModel> {
  const url = `https://huggingface.co/${repoId}/resolve/main/${filename}`;
  console.log(`[HFHub] Loading model from: ${url}`);

  // Guarantee at least one attempt, so a non-positive maxRetries cannot end in `throw undefined`.
  const attempts = Math.max(1, maxRetries);
  let lastError: unknown;

  for (let attempt = 0; attempt < attempts; attempt++) {
    try {
      const model = graphModel ? await tf.loadGraphModel(url) : await tf.loadLayersModel(url);
      console.log(`[HFHub] ✅ Model loaded from Hugging Face Hub (${repoId})`);
      return model;
    } catch (error) {
      lastError = error;
      // No retry follows the final attempt, so do not wait after it.
      if (attempt === attempts - 1) break;
      const delay = initialDelay * Math.pow(2, attempt);
      console.warn(`[HFHub] ⚠️ Failed to load model on attempt ${attempt + 1}/${attempts}. Retrying in ${delay}ms...`);
      await sleep(delay);
    }
  }

  console.error(`[HFHub] ❌ Failed to load model from ${url} after ${attempts} attempts`);
  throw lastError;
}
51 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/io/saveModelToHub.ts:
--------------------------------------------------------------------------------
1 | import * as fs from 'fs';
2 | import * as path from 'path';
3 |
4 | import {
5 | createRepo,
6 | uploadFiles,
7 | } from '@huggingface/hub';
8 | import * as tf from '@tensorflow/tfjs-node';
9 |
10 | // import { loadModelFromHub } from './loadModel';
11 |
12 | // Classe File polyfill pour Node.js
13 | /**
14 | class NodeFile {
15 | name: string;
16 | content: Buffer;
17 |
18 | constructor(content: Buffer, name: string) {
19 | this.content = content;
20 | this.name = name;
21 | }
22 | }
23 | */
24 |
/**
 * Save a TensorFlow.js model locally and push it to Hugging Face Hub.
 *
 * The model is written to `./tmp-model/<subdir>/` and then uploaded under `<subdir>/` in the repo.
 * - `model.json` is in the TF.js layers-model format (topology plus weights manifest), so
 *   `tf.loadLayersModel` can read it back.
 * - `weights.bin` holds all weights concatenated.
 *
 * When `subdir` is `'model'`, a README.md is uploaded at the repo root as well. The repo is
 * created first; a creation failure (typically "already exists") is logged, not thrown.
 *
 * @param model Trained tf.LayersModel; every weight must be float32
 * @param repo Full Hugging Face repo ID (e.g. "salim4n/dqn-agent")
 * @param token Hugging Face access token
 * @param subdir Subfolder inside repo (e.g. "step-5" or "best")
 * @throws Error if a weight is not float32; upload errors are propagated
 */
export async function saveModelToHub(
  model: tf.LayersModel,
  repo: string,
  token: string,
  subdir: string = 'model'
): Promise<void> {
  const tmpDir = path.resolve(`./tmp-model/${subdir}`);
  fs.mkdirSync(tmpDir, { recursive: true });

  const { modelJSON, weightsBuffer } = serializeLayersModel(model);
  const modelJSONText = JSON.stringify(modelJSON, null, 2);

  // Keep a local copy of the artifacts.
  fs.writeFileSync(path.join(tmpDir, 'model.json'), modelJSONText);
  fs.writeFileSync(path.join(tmpDir, 'weights.bin'), weightsBuffer);

  const files = [
    { path: `${subdir}/model.json`, content: new Blob([modelJSONText]) },
    { path: `${subdir}/weights.bin`, content: new Blob([weightsBuffer]) },
  ];

  // The README lives at the repo root, so it is only attached for the default folder.
  if (subdir === 'model') {
    files.push({ path: 'README.md', content: new Blob([buildReadme(repo, subdir)]) });
  }

  // Create the repo if needed (best effort).
  try {
    await createRepo({ repo, accessToken: token });
    console.log(`[HFHub] Repo "${repo}" ready.`);
  } catch (err) {
    const message = (err as { message?: unknown } | null | undefined)?.message;
    console.warn(`[HFHub] Repo already exists or failed to create:`, message);
  }

  await uploadFiles({
    repo,
    accessToken: token,
    files
  });

  console.log(`[HFHub] ✅ Uploaded to https://huggingface.co/${repo}/tree/main/${subdir}`);
}

/**
 * Serialize `model` into the artifacts `tf.loadLayersModel` expects.
 *
 * - `modelJSON`: an object with `modelTopology` and a `weightsManifest` that points at
 *   `weights.bin`.
 * - `weightsBuffer`: the concatenated float32 weight bytes.
 */
function serializeLayersModel(model: tf.LayersModel): { modelJSON: object; weightsBuffer: Buffer } {
  const variables = model.weights;
  const values = model.getWeights();

  const weightSpecs = variables.map((variable, i) => {
    const value = values[i];
    if (value.dtype !== 'float32') {
      throw new Error(`[HFHub] Unsupported weight dtype "${value.dtype}" for ${variable.originalName}; only float32 is supported`);
    }
    // originalName (e.g. "dense_Dense1/kernel") is the key used when weights are loaded back.
    return { name: variable.originalName, shape: value.shape, dtype: 'float32' };
  });

  const weightData = new Float32Array(values.reduce((total, w) => total + w.size, 0));
  let offset = 0;
  for (const w of values) {
    const data = w.dataSync();
    weightData.set(data, offset);
    offset += data.length;
  }

  const modelJSON = {
    // toJSON(_, false) returns the topology as an object; the default returns a JSON string,
    // which would end up double-encoded.
    modelTopology: model.toJSON(null, false),
    weightsManifest: [{ paths: ['weights.bin'], weights: weightSpecs }],
    format: 'layers-model',
    generatedBy: 'IgnitionAI backend-tfjs',
    convertedBy: null,
  };

  return { modelJSON, weightsBuffer: Buffer.from(weightData.buffer) };
}

/** Build the README.md uploaded at the repo root next to the default model folder. */
function buildReadme(repo: string, subdir: string): string {
  return `# TensorFlow.js Model

## Model Information
- Framework: TensorFlow.js
- Type: Deep Q-Network (DQN)
- Created by: IgnitionAI

## Model Format
This model is saved in the TensorFlow.js **layers-model** format (\`model.json\` + \`weights.bin\`).

- Load it as a **LayersModel** (\`graphModel = false\`, the default) for inference, training or fine-tuning.
- Loading it as a **GraphModel** requires converting it to the TensorFlow.js graph-model format first;
  \`tf.loadGraphModel\` cannot read a layers-model \`model.json\`.

## Usage
\`\`\`javascript
import * as tf from '@tensorflow/tfjs';
import { loadModelFromHub } from '@ignitionai/backend-tfjs';

// Load as LayersModel (supports inference and fine-tuning)
const model = await loadModelFromHub(
  '${repo}',
  '${subdir}/model.json',
  false // graphModel = false for LayersModel
);

// Run inference (each input row must have the model's input size)
const input = tf.tensor2d([[0.1, 0.2]]);
const output = model.predict(input);
\`\`\`

## Features
- Automatic retry with exponential backoff
- Configurable retry attempts and delays
- Error handling and logging
- Support for both LayersModel and GraphModel

## Files
- \`${subdir}/model.json\`: Model architecture and weights manifest
- \`${subdir}/weights.bin\`: Model weights
- \`README.md\`: This documentation

## Repository
This model was uploaded via the IgnitionAI TensorFlow.js integration.
`;
}
160 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/memory/ReplayBuffer.ts:
--------------------------------------------------------------------------------
/** A single transition collected from the environment. */
export interface Experience {
  /** Observation before the action was taken. */
  state: number[];
  /** Index of the discrete action that was taken. */
  action: number;
  /** Reward received for taking `action` in `state`. */
  reward: number;
  /** Observation reached after the action. */
  nextState: number[];
  /** True when this transition ended the episode. */
  done: boolean;
}

/**
 * Fixed-capacity FIFO memory for experience replay.
 *
 * When the buffer is full, adding an experience evicts the oldest one. A capacity of 0
 * or less produces a buffer that never stores anything.
 */
export class ReplayBuffer {
  private buffer: Experience[] = [];
  private readonly capacity: number;

  /** @param capacity Maximum number of experiences kept (default 10000). */
  constructor(capacity: number = 10000) {
    this.capacity = capacity;
  }

  /** Store `exp`, evicting the oldest experience first when the buffer is full. */
  add(exp: Experience): void {
    // Without this guard a zero-capacity buffer would still keep the most recent item.
    if (this.capacity <= 0) return;
    if (this.buffer.length >= this.capacity) {
      this.buffer.shift();
    }
    this.buffer.push(exp);
  }

  /**
   * Draw `min(batchSize, size())` experiences uniformly at random *with replacement*,
   * so the same experience may appear more than once in a batch.
   */
  sample(batchSize: number): Experience[] {
    const available = this.buffer.length;
    const count = Math.min(batchSize, available);
    const sampled: Experience[] = [];
    for (let i = 0; i < count; i++) {
      sampled.push(this.buffer[Math.floor(Math.random() * available)]);
    }
    return sampled;
  }

  /** Number of experiences currently stored. */
  size(): number {
    return this.buffer.length;
  }
}
39 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/model/BuildMLP.ts:
--------------------------------------------------------------------------------
1 | import * as tf from '@tensorflow/tfjs';
2 |
/**
 * Build and compile a fully connected Q-network as a Sequential model.
 *
 * Architecture: input of size `inputSize` → one ReLU dense layer per entry of
 * `hiddenLayers` → a linear dense layer with `outputSize` units (one Q-value per action).
 * Compiled with the Adam optimizer and mean-squared-error loss.
 *
 * @param inputSize The size of the input (state dimension)
 * @param outputSize The number of actions (output dimension)
 * @param hiddenLayers Units per hidden layer, in order. Default is [24, 24].
 *   An empty array builds a purely linear model (input wired straight to the output layer).
 * @param lr Learning rate for Adam. Default is 0.001.
 * @returns A compiled tf.Sequential model.
 */
export function buildQNetwork(
  inputSize: number,
  outputSize: number,
  hiddenLayers: number[] = [24, 24],
  lr: number = 0.001
): tf.Sequential {
  const model = tf.sequential();

  // Only the first layer added to a Sequential model declares the input shape.
  const inputShape = [inputSize];

  hiddenLayers.forEach((units, index) => {
    model.add(
      index === 0
        ? tf.layers.dense({ inputShape, units, activation: 'relu' })
        : tf.layers.dense({ units, activation: 'relu' })
    );
  });

  // Output layer (linear activation for Q-values). With no hidden layers it is also the
  // first layer, so it must carry the input shape itself.
  model.add(
    hiddenLayers.length === 0
      ? tf.layers.dense({ inputShape, units: outputSize, activation: 'linear' })
      : tf.layers.dense({ units: outputSize, activation: 'linear' })
  );

  model.compile({
    optimizer: tf.train.adam(lr),
    loss: 'meanSquaredError',
  });

  return model;
}
49 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/tools/trainer.ts:
--------------------------------------------------------------------------------
1 | import { DQNAgent } from "../agents/dqn";
2 | import { DQNConfig } from "../types";
3 | import * as tf from '@tensorflow/tfjs';
4 |
/** One sample produced by the environment callback `getEnvStep`. */
type EnvStepSample = {
  /** Current observation. */
  state: number[];
  /** Action that earns a +1 reward for this state (any other action earns -1). */
  correctAction: number;
  /** Observation that follows. */
  nextState: number[];
};

/** Options accepted by `trainAgent`. */
interface TrainAgentOptions {
  /** Hyper-parameters for the DQNAgent created for the run. */
  config: DQNConfig;
  /** Number of training steps; trainAgent defaults this to 100. */
  maxSteps?: number;
  /** Push a `step-<n>` checkpoint every this many steps (trainAgent default 10); needs repoId and token. */
  checkpointEvery?: number;
  /** Hugging Face repo that receives checkpoints. */
  repoId?: string;
  /** Hugging Face access token for checkpoint uploads. */
  token?: string;
  /** Invoked after every step with the step number, reward, chosen action and state. */
  onStep?: (step: number, reward: number, action: number, state: number[]) => void;
  /** Produces the data for the next step. */
  getEnvStep: () => EnvStepSample;
}
18 |
/**
 * Run a simple DQN training loop against a caller-supplied environment.
 *
 * Each step:
 * 1. Read `{ state, correctAction, nextState }` from `getEnvStep`.
 * 2. Let the agent act, with reward +1 when the action equals `correctAction` and -1
 *    otherwise. Transitions are never terminal.
 * 3. Store the transition and run one training step.
 *
 * When both `repoId` and `token` are set:
 * - a `step-<n>` checkpoint is pushed every `checkpointEvery` steps;
 * - a `best` checkpoint is pushed whenever the step's reward exceeds the best so far.
 *
 * The agent is disposed even if a step or an upload throws. NOTE: `tf.disposeVariables()`
 * releases every variable registered with the TF.js engine, including those of models
 * created outside this function.
 */
export async function trainAgent(options: TrainAgentOptions): Promise<void> {
  const {
    config,
    maxSteps = 100,
    checkpointEvery = 10,
    repoId,
    token,
    onStep,
    getEnvStep
  } = options;

  const agent = new DQNAgent(config);
  let bestReward = -Infinity;

  try {
    for (let step = 1; step <= maxSteps; step++) {
      const { state, correctAction, nextState } = getEnvStep();

      const action = await agent.getAction(state);
      const reward = action === correctAction ? 1 : -1;
      const done = false;

      agent.remember({ state, action, reward, nextState, done });
      await agent.train();

      onStep?.(step, reward, action, state);

      if (checkpointEvery && step % checkpointEvery === 0 && repoId && token) {
        await agent.saveCheckpoint(repoId, token, `step-${step}`);
      }

      if (reward > bestReward && repoId && token) {
        bestReward = reward;
        await agent.saveCheckpoint(repoId, token, 'best');
      }
    }
  } finally {
    agent.dispose();
    tf.disposeVariables();
  }
}
59 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/src/types.ts:
--------------------------------------------------------------------------------
/**
 * Hyper-parameters for the DQN agent. Only `inputSize` and `actionSize` are required;
 * every other field falls back to the default noted on it.
 */
export interface DQNConfig {
  /** Dimension of the state vector fed to the Q-network. */
  inputSize: number;
  /** Number of possible discrete actions (width of the Q-network output). */
  actionSize: number;
  /** Units in each hidden layer, in order. Default: [24, 24]. */
  hiddenLayers?: number[];
  /** Discount factor applied to bootstrapped future value. Default: 0.99. */
  gamma?: number;
  /** Initial exploration rate for epsilon-greedy action selection. Default: 1.0. */
  epsilon?: number;
  /** Multiplicative decay applied to epsilon per training step. Default: 0.995. */
  epsilonDecay?: number;
  /** Lower bound for epsilon decay. Default: 0.01. */
  minEpsilon?: number;
  /** Learning rate of the optimizer. Default: 0.001. */
  lr?: number;
  /** Minibatch size sampled from replay memory per training step. Default: 32. */
  batchSize?: number;
  /** Maximum number of experiences kept in the replay buffer. Default: 10000. */
  memorySize?: number;
  /** Training steps between target-network updates. Default: 1000. */
  targetUpdateFrequency?: number;
}
14 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/test/checkpoint.test.ts:
--------------------------------------------------------------------------------
1 | import { describe, it, expect, beforeEach, afterEach } from 'vitest';
2 | import * as tf from '@tensorflow/tfjs-node';
3 | import { DQNAgent } from '../src/agents/dqn';
4 | import { DQNConfig } from '../src/types';
5 | import * as dotenv from 'dotenv';
6 |
dotenv.config();
const HF_TOKEN = process.env.HF_TOKEN;

if (!HF_TOKEN) {
  console.warn('⚠️ HF_TOKEN missing, skipping checkpoint test');
  it.skip('Checkpoint test skipped due to missing HF_TOKEN', () => {});
} else {
  describe('DQNAgent Checkpointing with Hugging Face Hub', () => {
    const config: DQNConfig = {
      inputSize: 2,
      actionSize: 2,
      hiddenLayers: [8],
      epsilon: 0.0, // exploitation mode
      memorySize: 20,
      batchSize: 4,
    };

    const timestamp = Date.now();
    const REPO_ID = `salim4n/dqn-checkpoint-test-${timestamp}`;
    const CHECKPOINT_NAME = 'step-5';
    const TEST_STATE = [0.5, 0.9];

    // Agents that still own TF.js models. afterEach disposes whatever is left, so a
    // failing assertion does not leak models into later tests.
    let liveAgents: DQNAgent[] = [];
    const track = (created: DQNAgent): DQNAgent => {
      liveAgents.push(created);
      return created;
    };
    const release = (finished: DQNAgent): void => {
      liveAgents = liveAgents.filter(a => a !== finished);
      finished.dispose();
    };

    let agent: DQNAgent;

    beforeEach(() => {
      agent = track(new DQNAgent(config));
    });

    afterEach(() => {
      // Each agent is disposed exactly once: released agents were removed from the list.
      for (const leftover of liveAgents) leftover.dispose();
      liveAgents = [];
    });

    it('should save and reload checkpoint via HF Hub', async () => {
      // Minimal training run: a few random transitions, then one training step
      for (let i = 0; i < 5; i++) {
        const s = [Math.random(), Math.random()];
        const a = Math.floor(Math.random() * 2);
        const r = Math.random();
        const sNext = [Math.random(), Math.random()];
        agent.remember({ state: s, action: a, reward: r, nextState: sNext, done: false });
      }
      await agent.train();

      // Action chosen before saving
      const actionBefore = await agent.getAction(TEST_STATE);

      // Save the checkpoint to HF
      await agent.saveCheckpoint(REPO_ID, HF_TOKEN, CHECKPOINT_NAME);

      // Destroy the original agent, then reload the checkpoint into a fresh one
      release(agent);
      const newAgent = track(new DQNAgent(config));
      await newAgent.loadCheckpoint(REPO_ID, CHECKPOINT_NAME);

      // Action chosen after reloading
      const actionAfter = await newAgent.getAction(TEST_STATE);

      console.log(`[TEST] Action before: ${actionBefore}, after reload: ${actionAfter}`);
      expect([0, 1]).toContain(actionAfter);
    }, 60000);
  });
}
69 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/test/dqn.test.ts:
--------------------------------------------------------------------------------
1 | import { describe, it, expect, afterEach } from 'vitest';
2 | import { DQNAgent } from '../src/agents/dqn';
3 | import { DQNConfig } from '../src/types';
4 | import * as tf from '@tensorflow/tfjs';
5 |
/** Log a snapshot of TensorFlow.js memory usage, tagged with `label`. */
function logMemoryStatus(label: string) {
  const { numTensors, numBytes, numDataBuffers, unreliable } = tf.memory();
  console.log(`[TFJS Memory - ${label}]`, { numTensors, numBytes, numDataBuffers, unreliable });
}
16 |
describe('DQNAgent', () => {
  const config: DQNConfig = {
    inputSize: 1,
    actionSize: 2,
    hiddenLayers: [16],
    epsilon: 0.5,
    epsilonDecay: 0.99,
    minEpsilon: 0.01,
    lr: 0.001,
    gamma: 0.9,
    batchSize: 4,
    memorySize: 100,
    targetUpdateFrequency: 10,
  };

  afterEach(() => {
    // Global cleanup after each test: dispose every variable still registered with TF.js.
    tf.disposeVariables();
    console.log('[TFJS] After cleanup:', tf.memory());
  });

  it('should train on a fake environment', async () => {
    console.log('========== DQN AGENT TEST ==========');
    console.log('[TFJS] Current backend:', tf.getBackend());
    logMemoryStatus('START');

    console.log('Config:', JSON.stringify(config, null, 2));
    const agent = new DQNAgent(config);
    console.log('Agent initialized');
    logMemoryStatus('AFTER AGENT INIT');

    let correctActions = 0;
    let totalReward = 0;

    try {
      // Main training loop: the correct action is 1 when the state is above 0.5, else 0.
      for (let step = 0; step < 50; step++) {
        if (step % 10 === 0) {
          // Log memory every 10 steps to keep the output short
          logMemoryStatus(`STEP ${step}`);
        }

        const state = [Math.random()];
        const correct = state[0] > 0.5 ? 1 : 0;

        console.log(`\nStep ${step+1}/50: State=${state[0].toFixed(3)}, Correct action=${correct}`);

        const action = await agent.getAction(state);
        console.log(`Action chosen: ${action}`);

        const reward = action === correct ? 1 : -1;
        if (action === correct) correctActions++;
        totalReward += reward;

        console.log(`Reward: ${reward}, Running score: ${correctActions}/${step+1} (${(correctActions/(step+1)*100).toFixed(1)}%)`);

        const nextState = [Math.random()];
        const done = false;

        agent.remember({ state, action, reward, nextState, done });
        console.log(`Memory size: ${agent['memory'].size()}`);

        // train() disposes the tensors it allocates
        await agent.train();
        console.log(`Epsilon: ${agent['epsilon'].toFixed(3)}`);
      }

      console.log('\n========== TRAINING SUMMARY ==========');
      console.log(`Final score: ${correctActions}/50 (${(correctActions/50*100).toFixed(1)}%)`);
      console.log(`Total reward: ${totalReward}`);
      console.log(`Final epsilon: ${agent['epsilon'].toFixed(3)}`);
      console.log('[TFJS] Current backend:', tf.getBackend());
      logMemoryStatus('END OF TRAINING');

      // Probe with a known state
      const testState = [0.9]; // should prefer action = 1
      console.log(`\nTest prediction for state [${testState[0]}]`);
      const chosenAction = await agent.getAction(testState);
      console.log(`Agent chose: ${chosenAction} (Expected: 1)`);

      expect([0, 1]).toContain(chosenAction);
    } finally {
      // Release the agent's models even if an assertion above failed
      agent.dispose();
    }

    // Final memory check
    logMemoryStatus('END OF TEST');
  });
});
119 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/test/hubIntegration.test.ts:
--------------------------------------------------------------------------------
1 | import { describe, it, expect, beforeEach, afterEach } from 'vitest';
2 | import * as tf from '@tensorflow/tfjs-node';
3 | import { DQNAgent } from '../src/agents/dqn';
4 | import { DQNConfig } from '../src/types';
5 | import { saveModelToHub } from '../src/io/saveModelToHub';
6 | import { loadModelFromHub } from '../src/io/loadModel';
7 | import * as dotenv from 'dotenv';
8 |
dotenv.config();

const HF_TOKEN = process.env.HF_TOKEN;
if (!HF_TOKEN) {
  console.warn('⚠️ Missing HF_TOKEN in .env — skipping Hugging Face test');
  it.skip('Hugging Face test skipped due to missing token', () => {});
} else {
  describe('Hugging Face Hub DQN integration', () => {
    const config: DQNConfig = {
      inputSize: 2,
      actionSize: 2,
      hiddenLayers: [8],
      epsilon: 0.0,
      memorySize: 10,
      batchSize: 4
    };

    const timestamp = Date.now();
    const REPO_ID = `salim4n/tfjs-dqn-test-${timestamp}`;
    const TEST_STATE = [0.2, 0.8];

    let agent: DQNAgent;

    beforeEach(() => {
      agent = new DQNAgent(config);
    });

    afterEach(() => {
      agent?.dispose();
      tf.disposeVariables();
    });

    it('should save and load model from Hugging Face Hub', async () => {
      // Quick training on a few random transitions
      for (let i = 0; i < 5; i++) {
        const state = [Math.random(), Math.random()];
        const action = Math.floor(Math.random() * 2);
        const reward = Math.random();
        const nextState = [Math.random(), Math.random()];
        agent.remember({ state, action, reward, nextState, done: false });
      }
      await agent.train();

      // Action before saving
      const actionBefore = await agent.getAction(TEST_STATE);
      console.log(`[TEST] Action before save: ${actionBefore}`);

      // Save to HF
      await saveModelToHub(agent['model'], REPO_ID, HF_TOKEN!);

      // Give the Hub time to make the uploaded files available
      console.log('[TEST] Waiting for model to be available on Hugging Face Hub...');
      await new Promise(resolve => setTimeout(resolve, 5000));

      // Load with more attempts and a longer backoff than the defaults
      console.log('[TEST] Attempting to load model...');
      const newModel = await loadModelFromHub(REPO_ID, 'model/model.json', false, 5, 3000);
      const input = tf.tensor2d([TEST_STATE]);
      let pred: tf.Tensor | undefined;
      try {
        pred = newModel.predict(input) as tf.Tensor;
        const output = (await pred.array()) as number[][];

        console.log(`[TEST] Prediction after reload:`, output);
        expect(output[0]).toHaveLength(2);
        expect(typeof output[0][0]).toBe('number');
      } finally {
        // Release the input, the prediction and the reloaded model even on failure.
        input.dispose();
        pred?.dispose();
        newModel.dispose();
      }
    }, 60000);
  });
}
79 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/test/setBackend.test.ts:
--------------------------------------------------------------------------------
1 | import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
2 | import * as tf from '@tensorflow/tfjs';
3 | // Import tfjs-node for Node.js support
4 | import '@tensorflow/tfjs-node';
5 | import { setBackendSafe } from '../backend/setBackend';
6 | import { initTfjsBackend, isWebGPUAvailable, getTfjsBackendInfo } from '../backend/setBackend';
7 |
describe('TFJS Backend Selector', () => {
  beforeEach(() => {
    // Silence backend-selection logging; individual tests re-spy console.warn to inspect it.
    vi.spyOn(console, 'log').mockImplementation(() => {});
    vi.spyOn(console, 'warn').mockImplementation(() => {});
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  describe('Backend Initialization in Node Environment', () => {
    it('should use tensorflow backend in Node.js environment', async () => {
      // Importing @tensorflow/tfjs-node (top of file) makes the native 'tensorflow' backend active.
      const currentBackend = tf.getBackend();
      expect(currentBackend).toBe('tensorflow');
    });

    it('should attempt to use webgl in Node.js', async () => {
      const warnSpy = vi.spyOn(console, 'warn');
      await setBackendSafe('webgl');

      // In Node.js, we expect a warning about failed backend initialization
      expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Failed to set backend'));
    });

    it('should attempt to use wasm in Node.js', async () => {
      const warnSpy = vi.spyOn(console, 'warn');
      await setBackendSafe('wasm');

      // In Node.js, we expect a warning about failed backend initialization
      expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Failed to set backend'));
    });

    it('should warn on webgpu backend', async () => {
      const warnSpy = vi.spyOn(console, 'warn');
      await setBackendSafe('webgpu');

      // We expect warnings related to webgpu
      const warningCalled = warnSpy.mock.calls.some(
        (call) =>
          call[0] &&
          typeof call[0] === 'string' &&
          (call[0].includes('webgpu') || call[0].includes('Failed to set backend'))
      );
      expect(warningCalled).toBe(true);
    });

    it('should select available backend with initTfjsBackend', async () => {
      const backend = await initTfjsBackend();
      // Even though tfjs-node is loaded, initTfjsBackend currently settles on 'cpu' in Node.js;
      // this assertion pins that contract. NOTE(review): confirm 'cpu' is the intended choice.
      expect(backend).toBe('cpu');
    });
  });

  describe('Other Backend Tests', () => {
    it('should have isWebGPUAvailable function', () => {
      expect(typeof isWebGPUAvailable).toBe('function');
    });

    it('should have getTfjsBackendInfo function', async () => {
      // tf's active backend is process-global: select it explicitly so this test does not
      // depend on the initTfjsBackend test above having run first (e.g. when filtered with -t).
      await initTfjsBackend();
      const info = getTfjsBackendInfo();
      expect(info.backend).toBe('cpu');
      expect(typeof info.numTensors).toBe('number');
    });
  });
});
--------------------------------------------------------------------------------
/packages/backend-tfjs/test/trainer.test.ts:
--------------------------------------------------------------------------------
1 | import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
2 | import * as tf from '@tensorflow/tfjs-node';
3 | import { trainAgent } from '../src/tools/trainer';
4 | import { DQNConfig } from '../src/types';
5 | import * as dotenv from 'dotenv';
6 |
dotenv.config();

const HF_TOKEN = process.env.HF_TOKEN;

if (!HF_TOKEN) {
  console.warn('⚠️ HF_TOKEN is missing — skipping trainAgent test');
  it.skip('Skipped due to missing HF_TOKEN', () => {});
} else {
  describe('trainAgent integration test with Hugging Face Hub', () => {
    const config: DQNConfig = {
      inputSize: 2,
      actionSize: 2,
      hiddenLayers: [8],
      epsilon: 0.5,
      epsilonDecay: 0.9,
      minEpsilon: 0.1,
      gamma: 0.9,
      batchSize: 4,
      memorySize: 100,
      lr: 0.001,
      targetUpdateFrequency: 5,
    };

    // A fresh repo per run so repeated runs do not collide on the Hub.
    const timestamp = Date.now();
    const repoId = `salim4n/test-train-agent-${timestamp}`;

    // Single source of truth: the step budget passed to trainAgent is also the
    // expected number of onStep callbacks.
    const maxSteps = 10;

    const onStep = vi.fn();

    beforeEach(() => {
      onStep.mockClear();
    });

    it('should train agent and checkpoint to Hugging Face', async () => {
      await trainAgent({
        config,
        maxSteps,
        checkpointEvery: 5,
        repoId,
        // HF_TOKEN is a const narrowed to string by the surrounding `else`, so no `!` is needed.
        token: HF_TOKEN,
        getEnvStep: () => {
          const s = [Math.random(), Math.random()];
          return {
            state: s,
            correctAction: s[0] > 0.5 ? 1 : 0,
            nextState: [Math.random(), Math.random()],
          };
        },
        onStep,
      });

      expect(onStep).toHaveBeenCalledTimes(maxSteps);
    }, 60000);
  });
}
62 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/tmp-model/model.json:
--------------------------------------------------------------------------------
1 | {"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"Dense","config":{"units":8,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true,"batch_input_shape":[null,2],"dtype":"float32"}},{"class_name":"Dense","config":{"units":2,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense2","trainable":true}}]},"keras_version":"tfjs-layers 4.22.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense1/kernel","shape":[2,8],"dtype":"float32"},{"name":"dense_Dense1/bias","shape":[8],"dtype":"float32"},{"name":"dense_Dense2/kernel","shape":[8,2],"dtype":"float32"},{"name":"dense_Dense2/bias","shape":[2],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.22.0","convertedBy":null}
--------------------------------------------------------------------------------
/packages/backend-tfjs/tmp-model/model/model.json:
--------------------------------------------------------------------------------
1 | "{\"class_name\":\"Sequential\",\"config\":{\"name\":\"sequential_1\",\"layers\":[{\"class_name\":\"Dense\",\"config\":{\"units\":8,\"activation\":\"relu\",\"use_bias\":true,\"kernel_initializer\":{\"class_name\":\"VarianceScaling\",\"config\":{\"scale\":1,\"mode\":\"fan_avg\",\"distribution\":\"normal\",\"seed\":null}},\"bias_initializer\":{\"class_name\":\"Zeros\",\"config\":{}},\"kernel_regularizer\":null,\"bias_regularizer\":null,\"activity_regularizer\":null,\"kernel_constraint\":null,\"bias_constraint\":null,\"name\":\"dense_Dense1\",\"trainable\":true,\"batch_input_shape\":[null,2],\"dtype\":\"float32\"}},{\"class_name\":\"Dense\",\"config\":{\"units\":2,\"activation\":\"linear\",\"use_bias\":true,\"kernel_initializer\":{\"class_name\":\"VarianceScaling\",\"config\":{\"scale\":1,\"mode\":\"fan_avg\",\"distribution\":\"normal\",\"seed\":null}},\"bias_initializer\":{\"class_name\":\"Zeros\",\"config\":{}},\"kernel_regularizer\":null,\"bias_regularizer\":null,\"activity_regularizer\":null,\"kernel_constraint\":null,\"bias_constraint\":null,\"name\":\"dense_Dense2\",\"trainable\":true}}]},\"keras_version\":\"tfjs-layers 4.22.0\",\"backend\":\"tensor_flow.js\"}"
--------------------------------------------------------------------------------
/packages/backend-tfjs/tmp-model/model/weights.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/tmp-model/model/weights.bin
--------------------------------------------------------------------------------
/packages/backend-tfjs/tmp-model/weights.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/packages/backend-tfjs/tmp-model/weights.bin
--------------------------------------------------------------------------------
/packages/backend-tfjs/tmp-readme/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow.js DQN Model
2 |
3 | ## Model Information
4 | - Framework: TensorFlow.js
5 | - Type: Deep Q-Network (DQN)
6 | - Created by: Ignition AI
7 |
8 | ## Usage in TensorFlow.js
9 | ```javascript
10 | // Load the model
11 | const model = await tf.loadLayersModel('https://huggingface.co/salim4n/tfjs-dqn-2025-04-13-1744559537521/resolve/main/model.json');
12 |
13 | // Use the model
14 | const input = tf.tensor2d([[/* your input values */]]);
15 | const output = model.predict(input);
16 | ```
17 |
18 | ## Files
19 | - model.json: Model architecture
20 | - *.bin: Model weights
21 |
22 | ## Repository
23 | This model was uploaded via the IgnitionAI TensorFlow.js integration.
24 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../../tsconfig.base.json",
3 | "compilerOptions": {
4 | "outDir": "dist"
5 | },
6 | "include": ["src/**/*"]
7 | }
8 |
--------------------------------------------------------------------------------
/packages/backend-tfjs/vitest.config.ts:
--------------------------------------------------------------------------------
1 | // packages/backend-tfjs/vitest.config.ts
2 | import { defineConfig } from 'vitest/config';
3 |
/**
 * Vitest configuration for @ignitionai/backend-tfjs.
 *
 * The specs live in `test/` (singular). The previous glob only matched `tests/**`,
 * which matched none of the existing spec files. `tests/**` is kept so a future
 * directory of that name is still picked up.
 */
export default defineConfig({
  test: {
    globals: true,
    environment: 'node',
    include: ['test/**/*.test.ts', 'tests/**/*.test.ts'],
  },
});
11 |
--------------------------------------------------------------------------------
/packages/core/index.ts:
--------------------------------------------------------------------------------
/** Placeholder entry point: logs a greeting identifying the core package. */
export const core = (): void => {
  const greeting = "Hello from core";
  console.log(greeting);
};
--------------------------------------------------------------------------------
/packages/core/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@ignitionai/core",
3 | "version": "0.1.0",
4 | "description": "Core components for IgnitionAI - browser-based reinforcement learning framework",
5 | "main": "dist/index.js",
6 | "types": "dist/index.d.ts",
7 | "module": "dist/index.esm.js",
8 | "files": [
9 | "dist"
10 | ],
11 | "scripts": {
12 | "build": "tsc",
13 | "prepublishOnly": "npm run build"
14 | },
15 | "keywords": [
16 | "reinforcement-learning",
17 | "agent",
18 | "environment",
19 | "browser",
20 | "tensorflow"
21 | ],
22 | "author": "salim4n (@IgnitionAI)",
23 | "license": "MIT",
24 | "repository": {
25 | "type": "git",
26 | "url": "https://github.com/IgnitionAI/ignition"
27 | },
28 | "homepage": "https://github.com/IgnitionAI/ignition#readme",
29 | "dependencies": {
30 | "zod": "^3.24.2"
31 | },
32 | "publishConfig": {
33 | "access": "public"
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/packages/core/src/ignition-env.ts:
--------------------------------------------------------------------------------
1 | import {
2 | AgentInterface,
3 | Experience,
4 | } from './types';
5 |
/**
 * Configuration for an {@link IgnitionEnv}: the agent plus the callbacks that
 * connect it to a user-defined environment.
 */
export interface IgnitionEnvConfig {
  /** Agent that picks actions, stores experiences and trains. */
  agent: AgentInterface;

  /** Returns the current observation (state vector). Called after each action and after resets. */
  getObservation: () => number[];
  /** Applies the action chosen by the agent to the environment. */
  applyAction: (action: number | number[]) => void;
  /** Returns the reward for the transition produced by the last action. */
  computeReward: () => number;
  /** Returns true when the current episode has ended; triggers `onReset`. */
  isDone: () => boolean;
  /** Optional hook that restores the environment to an initial state. */
  onReset?: () => void;

  /** Delay between automatic steps started by `start()`, in milliseconds (default 100). */
  stepIntervalMs?: number;
  /** Hugging Face repo id passed to `maybeSaveBestCheckpoint` when the agent provides it. */
  hfRepoId?: string;
  /** Hugging Face token passed along with `hfRepoId`. */
  hfToken?: string;
}

/** Optional agent capability used by IgnitionEnv to save its best checkpoint after each step. */
type BestCheckpointAgent = AgentInterface & {
  maybeSaveBestCheckpoint: (
    repoId: string,
    token: string,
    reward: number,
    step: number
  ) => Promise<unknown>;
};

/** True when the agent exposes a callable `maybeSaveBestCheckpoint`. */
function supportsBestCheckpoint(agent: AgentInterface): agent is BestCheckpointAgent {
  return typeof (agent as Partial<BestCheckpointAgent>).maybeSaveBestCheckpoint === 'function';
}

/**
 * Drives an agent through a user-defined environment.
 * Each step: act, observe, reward, remember, train, optionally checkpoint, then reset on episode end.
 */
export class IgnitionEnv {
  private config: IgnitionEnvConfig;
  private agent: AgentInterface;
  private currentState: number[];
  private intervalId?: ReturnType<typeof setInterval>;
  /** True while an interval-scheduled step is still running; prevents overlapping steps. */
  private stepInFlight = false;
  public stepCount = 0;

  constructor(config: IgnitionEnvConfig) {
    this.config = config;
    this.agent = config.agent;
    this.currentState = config.getObservation();
  }

  /**
   * Runs one interaction step and increments `stepCount`.
   * Rejects if any agent call or callback throws.
   */
  public async step(): Promise<void> {
    this.stepCount++;

    const action = await this.agent.getAction(this.currentState);
    this.config.applyAction(action);

    const nextState = this.config.getObservation();
    const reward = this.config.computeReward();
    const done = this.config.isDone();

    const experience: Experience = {
      state: this.currentState,
      action,
      reward,
      nextState,
      done,
    };

    this.agent.remember(experience);
    await this.agent.train();

    const { hfRepoId, hfToken } = this.config;
    if (hfRepoId && hfToken && supportsBestCheckpoint(this.agent)) {
      await this.agent.maybeSaveBestCheckpoint(hfRepoId, hfToken, reward, this.stepCount);
    }

    if (done) {
      this.config.onReset?.();
      this.currentState = this.config.getObservation();
    } else {
      this.currentState = nextState;
    }
  }

  /**
   * Starts stepping automatically every `stepIntervalMs` (default 100 ms).
   * Does nothing when `auto` is false. Calling it while already running restarts the loop
   * instead of leaking a second, unstoppable interval.
   */
  public start(auto: boolean = true): void {
    if (!auto) return;
    this.stop();
    const interval = this.config.stepIntervalMs ?? 100;
    this.intervalId = setInterval(() => {
      void this.runScheduledStep();
    }, interval);
  }

  /** Stops the automatic loop started by `start()`; safe to call when not running. */
  public stop(): void {
    if (this.intervalId !== undefined) {
      clearInterval(this.intervalId);
      this.intervalId = undefined;
    }
  }

  /** Calls `onReset`, re-reads the observation and zeroes `stepCount`. */
  public reset(): void {
    this.config.onReset?.();
    this.currentState = this.config.getObservation();
    this.stepCount = 0;
  }

  /**
   * Interval callback. setInterval does not await async work, so ticks are skipped while
   * a previous step (training, checkpoint upload) is still running. Errors are logged
   * rather than left as unhandled rejections.
   */
  private async runScheduledStep(): Promise<void> {
    if (this.stepInFlight) return;
    this.stepInFlight = true;
    try {
      await this.step();
    } catch (error: unknown) {
      console.error('[IgnitionEnv] step failed:', error);
    } finally {
      this.stepInFlight = false;
    }
  }
}
90 |
--------------------------------------------------------------------------------
/packages/core/src/index.ts:
--------------------------------------------------------------------------------
1 | // Core package implementation
/** Version string of the core package. */
export const version = '0.1.0';

/** Package metadata (name + version) exposed as the default export. */
const corePackageInfo = {
  name: 'core',
  version,
};

export default corePackageInfo;
8 |
9 | export { IgnitionEnv } from './ignition-env';
--------------------------------------------------------------------------------
/packages/core/src/types.ts:
--------------------------------------------------------------------------------
/** One transition collected by an agent: state, chosen action, reward, next state, episode-end flag. */
export interface Experience {
  state: number[];
  action: number;
  reward: number;
  nextState: number[];
  done: boolean;
}

/**
 * Minimal contract IgnitionEnv relies on to drive an agent.
 *
 * NOTE(review): this copy had bare `Promise` return types, which do not compile.
 * The type arguments are restored from usage. IgnitionEnv stores getAction's result
 * in `Experience.action: number` and ignores train's result. Confirm against the real source.
 */
export interface AgentInterface {
  /** Chooses an action for the given observation. */
  getAction(observation: number[]): Promise<number>;
  /** Stores a transition for later training. */
  remember(experience: Experience): void;
  /** Runs a training update; the resolved value is not used by IgnitionEnv. */
  train(): Promise<void>;
}
14 |
--------------------------------------------------------------------------------
/packages/core/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../../tsconfig.base.json",
3 | "compilerOptions": {
4 | "outDir": "dist"
5 | },
6 | "include": ["src/**/*"]
7 | }
8 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Target Chasing Visualization
7 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/index.ts:
--------------------------------------------------------------------------------
/** Placeholder entry point: logs a greeting identifying the demo package. */
export const demo = (): void => {
  const greeting = "Hello from demo";
  console.log(greeting);
};
--------------------------------------------------------------------------------
/packages/demo-target-chasing/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "target-chasing",
3 | "private": true,
4 | "version": "0.0.0",
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc -b && vite build",
9 | "lint": "eslint .",
10 | "preview": "vite preview"
11 | },
12 | "dependencies": {
13 | "@ignitionai/backend-tfjs": "workspace:*",
14 | "@ignitionai/core": "workspace:*",
15 | "@react-three/drei": "^10.0.6",
16 | "@react-three/fiber": "^9.1.2",
17 | "@react-three/rapier": "^2.1.0",
18 | "@types/three": "^0.162.0",
19 | "react": "^19.0.0",
20 | "react-dom": "^19.0.0",
21 | "three": "^0.162.0",
22 | "zustand": "^5.0.3"
23 | },
24 | "devDependencies": {
25 | "@eslint/js": "^9.22.0",
26 | "@react-three/eslint-plugin": "^0.1.2",
27 | "@types/react": "^19.0.10",
28 | "@types/react-dom": "^19.0.4",
29 | "@vitejs/plugin-react": "^4.3.4",
30 | "eslint": "^9.22.0",
31 | "eslint-plugin-react-hooks": "^5.2.0",
32 | "eslint-plugin-react-refresh": "^0.4.19",
33 | "globals": "^16.0.0",
34 | "typescript": "~5.7.2",
35 | "typescript-eslint": "^8.26.1",
36 | "vite": "^6.3.1"
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/readme.md:
--------------------------------------------------------------------------------
1 | # 🎯 Target Chasing Demo with IgnitionAI
2 |
3 | This example demonstrates how to use IgnitionAI with Three.js to create a simple reinforcement learning environment where an agent learns to reach a target.
4 |
5 | ## 🎥 Demo Video
6 |
7 | [Watch the demo video on YouTube](https://www.youtube.com/watch?v=97CMG7H_5Mo)
8 |
9 | ## 🔧 Requirements
10 |
11 | - Three.js scene (WebGLRenderer, Camera, Meshes, etc.)
12 | - @ignitionai/backend-tfjs (DQN Agent)
13 | - @ignitionai/core (IgnitionEnv)
14 | - Optionally: Hugging Face token for checkpointing
15 |
16 | ## 🚀 Getting Started
17 |
18 | 1. Clone the repository
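19 | 2. Install dependencies with a workspace-aware package manager such as pnpm: `pnpm install` (this package depends on `workspace:*` packages, a protocol npm does not support)
20 | 3. Create a `.env` file with your Hugging Face token: `VITE_HF_TOKEN=your_token_here`
21 | 4. Run the development server: `pnpm dev`
22 | 5. Open the URL printed by Vite in your browser (by default `http://localhost:5173`)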
23 |
24 | ## 🧠 Reinforcement Learning Environment
25 |
26 | ### Setup
27 |
28 | ```typescript
29 | const env: IgnitionEnv = new IgnitionEnv({
30 | agent: dqnAgent, // The learning agent
31 | getObservation: () => [position, targetPosition],
32 | applyAction: (a: number) => {
33 | const dx = a - 1; // Convert action to direction
34 | position += dx * 0.2;
35 | agentMesh.position.x = position; // Apply movement
36 | },
37 | computeReward: () => {
38 | const d = Math.abs(position - targetPosition);
39 | let reward = 1.0 / (1.0 + d);
40 | if (d > previousDistance) reward -= 0.5; // Penalty for moving away
41 | if (d < 0.5) reward += 1.0; // Bonus for getting close
42 | previousDistance = d;
43 | return reward;
44 | },
45 | isDone: () => {
46 | const d = Math.abs(position - targetPosition);
47 | return d < 0.1 || stepCount > 1000; // Success or timeout
48 | },
49 | onReset: () => {
50 | position = 0;
51 | targetPosition = (Math.random() - 0.5) * 4;
52 | agentMesh.position.x = position;
53 | targetMesh.position.x = targetPosition;
54 | stepCount = 0;
55 | bestDistance = Infinity;
56 | previousDistance = Math.abs(position - targetPosition);
57 | },
58 | stepIntervalMs: 100, // Loop speed
59 | hfRepoId: 'your-username/dqn-checkpoint',
60 | hfToken: import.meta.env.VITE_HF_TOKEN || '',
61 | });
62 | ```
63 |
64 | ## ✅ Features
65 |
66 | - ✅ Custom visual scene via Three.js
67 | - ✅ Simple declarative API for environment definition
68 | - ✅ No ML knowledge required
69 | - ✅ Optional Hugging Face integration (auto checkpoints)
70 | - ✅ Ready for R3F / WebGL / WebXR
71 | - ✅ Real-time visualization of agent learning
72 | - ✅ Automatic checkpoint saving for best models
73 | - ✅ Training progress monitoring
74 |
75 | ## 🎮 How It Works
76 |
77 | 1. The agent (green sphere) learns to reach the target (red sphere)
78 | 2. The agent can move left, stay still, or move right
79 | 3. The agent receives rewards based on distance to target
80 | 4. An episode ends, and the environment resets, when the agent gets within 0.1 of the target or after 1000 steps
81 | 5. The best model is automatically saved as a checkpoint
82 |
83 | ## 🔍 Customization
84 |
85 | You can customize the environment by modifying:
86 |
87 | - Agent parameters (learning rate, exploration rate, etc.)
88 | - Reward function to change learning behavior
89 | - Target position range
90 | - Movement speed
91 | - Maximum steps per episode
92 |
93 | ## 🚀 Run It
94 |
95 | 1. Create a canvas or WebGL context
96 | 2. Import DQNAgent and IgnitionEnv
97 | 3. Define your observation, reward, and action logic
98 | 4. Call `env.start()` and watch it learn
99 |
100 | ## 📚 Learn More
101 |
102 | - [IgnitionAI Documentation](https://github.com/ignitionai/ignition)
103 | - [Three.js Documentation](https://threejs.org/docs/)
104 | - [Reinforcement Learning Basics](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html)
105 |
106 | Need help or want to showcase your own scenes? Join the [IgnitionAI Discussions](https://github.com/ignitionai/ignition/discussions) on GitHub!
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/AgentConfigPanel.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState } from 'react';
2 |
/** One hidden-layer row in the configuration form. */
interface LayerConfig {
  id: string;
  neurons: number;
}

/** Agent configuration assembled from the form and handed to `onApplyConfig`. */
type AgentPanelConfig = {
  inputSize: number;
  actionSize: number;
  /** Neuron count for each hidden layer, in order. */
  hiddenLayers: number[];
  epsilon: number;
  epsilonDecay: number;
  minEpsilon: number;
  gamma: number;
  /** Learning rate. */
  lr: number;
  batchSize: number;
  memorySize: number;
};

/** Props for the agent configuration panel. */
interface AgentConfigPanelProps {
  /** Receives the assembled configuration when the user applies it. */
  onApplyConfig: (config: AgentPanelConfig) => void;
}
22 |
23 | export function AgentConfigPanel({ onApplyConfig }: AgentConfigPanelProps) {
24 | // Default configuration
25 | const [inputSize, setInputSize] = useState(9);
26 | const [actionSize, setActionSize] = useState(4);
27 | const [layers, setLayers] = useState([
28 | { id: 'layer1', neurons: 64 },
29 | { id: 'layer2', neurons: 64 }
30 | ]);
31 | const [epsilon, setEpsilon] = useState(0.9);
32 | const [epsilonDecay, setEpsilonDecay] = useState(0.97);
33 | const [minEpsilon, setMinEpsilon] = useState(0.05);
34 | const [gamma, setGamma] = useState(0.99);
35 | const [learningRate, setLearningRate] = useState(0.001);
36 | const [batchSize, setBatchSize] = useState(128);
37 | const [memorySize, setMemorySize] = useState(100000);
38 |
39 | // Function to add a new layer
40 | const addLayer = () => {
41 | const newId = `layer${layers.length + 1}`;
42 | setLayers([...layers, { id: newId, neurons: 32 }]);
43 | };
44 |
45 | // Function to remove a layer
46 | const removeLayer = (id: string) => {
47 | if (layers.length > 1) {
48 | setLayers(layers.filter(layer => layer.id !== id));
49 | }
50 | };
51 |
52 | // Function to update a layer's neuron count
53 | const updateLayer = (id: string, neurons: number) => {
54 | setLayers(layers.map(layer =>
55 | layer.id === id ? { ...layer, neurons } : layer
56 | ));
57 | };
58 |
59 | // Apply configuration
60 | const applyConfig = () => {
61 | onApplyConfig({
62 | inputSize,
63 | actionSize,
64 | hiddenLayers: layers.map(layer => layer.neurons),
65 | epsilon,
66 | epsilonDecay,
67 | minEpsilon,
68 | gamma,
69 | lr: learningRate,
70 | batchSize,
71 | memorySize
72 | });
73 | };
74 |
75 | return (
76 |
77 |
Agent Configuration
78 |
79 |
80 |
Network Architecture
81 |
82 |
83 | Input Size:
84 | setInputSize(parseInt(e.target.value))}
88 | min="1"
89 | />
90 |
91 |
92 |
93 | Action Size:
94 | setActionSize(parseInt(e.target.value))}
98 | min="1"
99 | />
100 |
101 |
102 |
103 |
Hidden Layers
104 | {layers.map((layer, index) => (
105 |
106 | Layer {index + 1}:
107 | updateLayer(layer.id, parseInt(e.target.value))}
111 | min="1"
112 | />
113 | removeLayer(layer.id)}>Remove
114 |
115 | ))}
116 |
Add Layer
117 |
118 |
119 |
120 |
203 |
204 |
Apply Configuration
205 |
206 | );
207 | }
208 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/NetworkDesigner.tsx:
--------------------------------------------------------------------------------
1 | import 'reactflow/dist/style.css';
2 |
3 | import {
4 | useCallback,
5 | useEffect,
6 | useState,
7 | } from 'react';
8 |
9 | import ReactFlow, {
10 | addEdge,
11 | Background,
12 | Connection,
13 | Controls,
14 | Edge,
15 | MiniMap,
16 | Node,
17 | ReactFlowInstance,
18 | ReactFlowProvider,
19 | } from 'reactflow';
20 |
// Combined node/edge element type, left untyped so one list can hold both kinds.
// NOTE(review): this alias shadows the global DOM `Element` type inside this module.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type Element = any;

/** Starting graph: input → two dense layers (in parallel) → output, every edge animated. */
const initialElements: Element[] = [
  { id: 'input', type: 'input', data: { label: 'Input (Size: 9)' }, position: { x: 100, y: 100 } },
  { id: 'hidden1', type: 'default', data: { label: 'Dense (Neurons: 64)' }, position: { x: 300, y: 50 } },
  { id: 'hidden2', type: 'default', data: { label: 'Dense (Neurons: 64)' }, position: { x: 300, y: 150 } },
  { id: 'output', type: 'output', data: { label: 'Output (Actions: 4)' }, position: { x: 500, y: 100 } },
  ...([
    ['e-input-h1', 'input', 'hidden1'],
    ['e-input-h2', 'input', 'hidden2'],
    ['e-h1-output', 'hidden1', 'output'],
    ['e-h2-output', 'hidden2', 'output'],
  ] as const).map(([id, source, target]) => ({ id, source, target, animated: true })),
];
34 |
35 | interface NetworkDesignerProps {
36 | onNetworkChange: (layers: number[]) => void;
37 | }
38 |
39 | export function NetworkDesigner({ onNetworkChange }: NetworkDesignerProps) {
40 | const [elements, setElements] = useState(initialElements);
41 | const [reactFlowInstance, setReactFlowInstance] = useState(null);
42 |
43 | const onNodesDelete = useCallback(
44 | (nodesToRemove: Node[]) =>
45 | setElements((els: Element[]) =>
46 | els.filter((el) => !(el as Node).id || !nodesToRemove.some((n) => n.id === (el as Node).id))
47 | ),
48 | []
49 | );
50 |
51 | const onEdgesDelete = useCallback(
52 | (edgesToRemove: Edge[]) =>
53 | setElements((els: Element[]) =>
54 | els.filter((el) => !(el as Edge).id || !edgesToRemove.some((e) => e.id === (el as Edge).id))
55 | ),
56 | []
57 | );
58 |
59 | const onConnect = useCallback(
60 | (params: Connection | Edge) =>
61 | setElements((els: Element[]) => addEdge({ ...params, animated: true }, els)),
62 | []
63 | );
64 |
65 | const onInit = useCallback((instance: ReactFlowInstance) => {
66 | setReactFlowInstance(instance);
67 | instance.fitView();
68 | extractNetworkStructure(initialElements);
69 | }, []);
70 |
71 | const extractNetworkStructure = (currentElements: Element[]) => {
72 | const hiddenLayers: number[] = [];
73 | currentElements.forEach((el) => {
74 | if ('type' in el && el.type === 'default' && el.data?.label?.includes('Dense')) {
75 | const match = el.data.label.match(/Neurons: (\d+)/);
76 | if (match && match[1]) {
77 | hiddenLayers.push(parseInt(match[1], 10));
78 | }
79 | }
80 | });
81 |
82 | console.log('Extracted hidden layers:', hiddenLayers);
83 | onNetworkChange(hiddenLayers);
84 | };
85 |
86 | useEffect(() => {
87 | if (reactFlowInstance) {
88 | extractNetworkStructure(elements);
89 | }
90 | }, [elements, reactFlowInstance]);
91 |
92 | return (
93 |
94 |
Network Designer (Drag & Drop - Basic)
95 |
96 |
97 | 'position' in e)}
99 | edges={elements.filter((e): e is Edge => 'source' in e && 'target' in e)}
100 | onConnect={onConnect}
101 | onNodesDelete={onNodesDelete}
102 | onEdgesDelete={onEdgesDelete}
103 | onInit={onInit}
104 | fitView
105 | snapToGrid
106 | snapGrid={[15, 15]}
107 | >
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 | Note: This is a basic visual representation. Add/remove/connect nodes to define layers.
116 | Neuron counts need manual adjustment via the config panel for now.
117 |
118 |
119 | );
120 | }
121 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/TrainingControls.tsx:
--------------------------------------------------------------------------------
1 | import { useTrainingStore } from './store/trainingStore';
2 |
3 | interface TrainingControlsProps {
4 | startTraining: () => void;
5 | stopTraining: () => void;
6 | resetEnvironment: () => void;
7 | }
8 |
9 | export function TrainingControls({ startTraining, stopTraining, resetEnvironment }: TrainingControlsProps) {
10 | const {
11 | isTraining,
12 | episodeCount,
13 | reward,
14 | episodeTime,
15 | successCount,
16 | difficulty,
17 | lastAction
18 | } = useTrainingStore();
19 |
20 | return (
21 |
22 |
Contrôle d'entraînement
23 |
Épisodes: {episodeCount}
24 |
Succès: {successCount} / {episodeCount}
25 |
Difficulté: {difficulty + 1}/3
26 |
Temps: {episodeTime.toFixed(1)}s
27 |
Dernière action: {lastAction !== -1 ? ['Gauche', 'Droite', 'Avant', 'Arrière'][lastAction] : 'Aucune'}
28 |
Récompense: {reward.toFixed(2)}
29 |
30 | {!isTraining ? (
31 | Démarrer l'entraînement
32 | ) : (
33 | Arrêter l'entraînement
34 | )}
35 | Réinitialiser
36 |
37 |
38 | );
39 | }
40 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/env.d.ts:
--------------------------------------------------------------------------------
// NOTE(review): in this copy the triple-slash directive on this line lost its payload.
// For a Vite project it is presumably `/// <reference types="vite/client" />` — confirm.

/** Typing for `import.meta` as used by the demo: only the `env` bag is declared here. */
interface ImportMeta {
  readonly env: ImportMetaEnv;
}

/** Environment variables read by the demo through `import.meta.env`. */
interface ImportMetaEnv {
  /** Hugging Face token used for checkpoint uploads (see the demo readme). */
  readonly VITE_HF_TOKEN: string;
}
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Target Chasing Visualization
7 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/index.ts:
--------------------------------------------------------------------------------
1 | // Demo target chasing implementation
/** Version string of the target-chasing demo package. */
export const version = '0.1.0';

/** Package metadata (name + version) exposed as the default export. */
const demoPackageInfo = {
  name: 'demo-target-chasing',
  version,
};

export default demoPackageInfo;
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/main.ts:
--------------------------------------------------------------------------------
1 | import * as dotenv from 'dotenv';
2 |
3 | import { DQNAgent } from '@ignitionai/backend-tfjs';
4 | import { IgnitionEnv } from '@ignitionai/core';
5 |
6 | // Charger les variables d'environnement depuis le fichier .env
7 | dotenv.config();
8 |
9 | console.log('Starting training...');
10 | console.log('HF_TOKEN:', process.env.HF_TOKEN ? '✅ Found' : '❌ Not found');
11 |
12 | const agent = new DQNAgent({
13 | inputSize: 2,
14 | actionSize: 3,
15 | hiddenLayers: [32, 32],
16 | gamma: 0.99,
17 | epsilon: 1.0,
18 | epsilonDecay: 0.995,
19 | minEpsilon: 0.01,
20 | lr: 0.001,
21 | batchSize: 32,
22 | memorySize: 1000,
23 | targetUpdateFrequency: 10,
24 | });
25 |
26 | let position = 0;
27 | let target = (Math.random() - 0.5) * 4;
28 | let bestDistance = Infinity;
29 |
30 | // Définir la fonction isDone en dehors de la configuration
31 | const isDone = (): boolean => {
32 | const d = Math.abs(position - target);
33 | return d < 0.1 || env.stepCount > 1000;
34 | };
35 |
/**
 * 1-D chase environment: each action moves `position` by -0.2, 0 or +0.2.
 * The reward grows as the agent gets closer to `target`.
 */
const env: IgnitionEnv = new IgnitionEnv({
  agent,
  getObservation: () => [position, target],
  applyAction: (action: number | number[]) => {
    // Handle array actions by taking the first number
    const a = Array.isArray(action) ? action[0] : action;
    // dx = a - 1, so action 0 moves left, 1 stays, 2 moves right (step size 0.2)
    const dx = a - 1;
    position += dx * 0.2;

    if (env.stepCount % 10 === 0) {
      console.log(`Step ${env.stepCount}: pos=${position.toFixed(2)}, target=${target.toFixed(2)}`);
    }
  },
  computeReward: () => {
    // Dense reward in (0, 1]; equals 1 when exactly on target
    const d = Math.abs(position - target);
    const reward = 1.0 / (1.0 + d);
    if (env.stepCount % 10 === 0) {
      console.log(`[REWARD] ${reward.toFixed(4)}`);
    }
    return reward;
  },
  isDone,
  onReset: () => {
    // New episode: agent back to the origin, fresh random target in [-2, 2)
    position = 0;
    target = (Math.random() - 0.5) * 4;
    console.log(`[RESET] New target: ${target.toFixed(2)}`);
  },
  stepIntervalMs: 100,
  hfRepoId: 'salim4n/dqn-checkpoint-demo',
  // A missing HF_TOKEN is reported at startup. Pass an empty string instead of
  // asserting non-null, which would smuggle `undefined` into the environment.
  hfToken: process.env.HF_TOKEN ?? '',
});
67 |
// Wrap env.step so that checkpoints are pushed to the Hugging Face Hub while training.
const originalStep = env.step.bind(env);
const checkpointToken = process.env.HF_TOKEN;

/**
 * Save a checkpoint of `agent` to `repoId` under `tag`.
 *
 * @returns `true` if the upload succeeded. A missing HF_TOKEN or a failed upload
 * is logged and reported as `false`; it does not reject. This keeps a checkpoint
 * problem from aborting the training loop.
 */
async function trySaveCheckpoint(repoId: string, tag: string): Promise<boolean> {
  if (!checkpointToken) {
    console.warn(`[CHECKPOINT] HF_TOKEN missing, skipping checkpoint "${tag}"`);
    return false;
  }
  try {
    await agent.saveCheckpoint(repoId, checkpointToken, tag);
    return true;
  } catch (err: unknown) {
    console.error(`[CHECKPOINT] Failed to save checkpoint "${tag}":`, err);
    return false;
  }
}

env.step = async () => {
  await originalStep();

  const d = Math.abs(position - target);

  // Save a "best" checkpoint whenever the distance improves on the best seen so far.
  // NOTE(review): intermediate checkpoints go to 'salim4n/test-checkpoint', while the
  // final one (and `hfRepoId`) use 'salim4n/dqn-checkpoint-demo'. Confirm this split is intended.
  if (d < bestDistance) {
    bestDistance = d;
    console.log(`[CHECKPOINT] Nouvelle meilleure distance: ${d.toFixed(4)}`);
    console.log(`[CHECKPOINT] Sauvegarde du meilleur modèle...`);
    if (await trySaveCheckpoint('salim4n/test-checkpoint', 'best')) {
      console.log(`[CHECKPOINT] ✅ Meilleur modèle sauvegardé`);
    }
  }

  // Periodic checkpoint every 100 steps
  if (env.stepCount % 100 === 0) {
    console.log(`[CHECKPOINT] Sauvegarde régulière à l'étape ${env.stepCount}`);
    await trySaveCheckpoint('salim4n/test-checkpoint', `step-${env.stepCount}`);
  }

  // End of training: save a final checkpoint (best effort), then stop and exit
  if (isDone()) {
    console.log(`Training finished at step ${env.stepCount}!`);
    console.log(`Final distance: ${d.toFixed(2)}`);
    await trySaveCheckpoint('salim4n/dqn-checkpoint-demo', 'final');
    env.stop();
    process.exit(0);
  }
};

env.start();
113 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/store/trainingStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from 'zustand';
2 |
/** Either a new numeric value or an updater deriving it from the previous one. */
type NumberUpdate = number | ((prev: number) => number);

/** Shape of the training store: run state, target state, setters and helpers. */
interface TrainingState {
  // --- Training run state ---
  isTraining: boolean;
  isTrainingInProgress: boolean;
  episodeCount: number;
  reward: number;
  bestReward: number;
  episodeSteps: number;
  reachedTarget: boolean;
  episodeTime: number;
  episodeStartTime: number;
  successCount: number;
  difficulty: number;
  lastAction: number;

  // --- Target state ---
  targetPosition: [number, number, number];

  // --- Setters (counter setters also accept an updater function) ---
  setIsTraining: (value: boolean) => void;
  setIsTrainingInProgress: (value: boolean) => void;
  setEpisodeCount: (value: NumberUpdate) => void;
  setReward: (value: number) => void;
  setBestReward: (value: number) => void;
  setEpisodeSteps: (value: NumberUpdate) => void;
  setReachedTarget: (value: boolean) => void;
  setEpisodeTime: (value: number) => void;
  setEpisodeStartTime: (value: number) => void;
  setSuccessCount: (value: NumberUpdate) => void;
  setDifficulty: (value: NumberUpdate) => void;
  setLastAction: (value: number) => void;
  setTargetPosition: (value: [number, number, number]) => void;

  // --- Helpers ---
  resetEpisode: () => void;
  incrementEpisodeCount: () => void;
}
40 |
/**
 * Zustand store holding the state of the current training run.
 *
 * Uses the curried `create<TrainingState>()(...)` form recommended by zustand for
 * TypeScript. It checks the initial state and every action against
 * `TrainingState`, and types `set` and each action's `value` parameter. Without it
 * the interface goes unused and the parameters are implicitly `any` under `strict`.
 */
export const useTrainingStore = create<TrainingState>()((set) => ({
  // Initial training state
  isTraining: false,
  isTrainingInProgress: false,
  episodeCount: 0,
  reward: 0,
  bestReward: -Infinity,
  episodeSteps: 0,
  reachedTarget: false,
  episodeTime: 0,
  episodeStartTime: Date.now(),
  successCount: 0,
  difficulty: 0,
  lastAction: -1,

  // Initial target position (contextually typed as a 3-tuple by TrainingState)
  targetPosition: [0, 10, 0],

  // Setters; the counter setters also accept an updater of the previous value
  setIsTraining: (value) => set({ isTraining: value }),
  setIsTrainingInProgress: (value) => set({ isTrainingInProgress: value }),
  setEpisodeCount: (value) => set((state) => ({
    episodeCount: typeof value === 'function' ? value(state.episodeCount) : value,
  })),
  setReward: (value) => set({ reward: value }),
  setBestReward: (value) => set({ bestReward: value }),
  setEpisodeSteps: (value) => set((state) => ({
    episodeSteps: typeof value === 'function' ? value(state.episodeSteps) : value,
  })),
  setReachedTarget: (value) => set({ reachedTarget: value }),
  setEpisodeTime: (value) => set({ episodeTime: value }),
  setEpisodeStartTime: (value) => set({ episodeStartTime: value }),
  setSuccessCount: (value) => set((state) => ({
    successCount: typeof value === 'function' ? value(state.successCount) : value,
  })),
  setDifficulty: (value) => set((state) => ({
    difficulty: typeof value === 'function' ? value(state.difficulty) : value,
  })),
  setLastAction: (value) => set({ lastAction: value }),
  setTargetPosition: (value) => set({ targetPosition: value }),

  // Start a new episode: clear per-episode counters, bump episodeCount, restart the clock
  resetEpisode: () => set((state) => ({
    episodeSteps: 0,
    reachedTarget: false,
    episodeCount: state.episodeCount + 1,
    episodeTime: 0,
    episodeStartTime: Date.now(),
  })),
  incrementEpisodeCount: () => set((state) => ({ episodeCount: state.episodeCount + 1 })),
}));
92 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/visualization.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | DQN Agent Visualization
5 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/src/visualization.ts:
--------------------------------------------------------------------------------
1 | import * as THREE from 'three';
2 | import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
3 |
4 | import { DQNAgent } from '@ignitionai/backend-tfjs';
5 | import { IgnitionEnv } from '@ignitionai/core';
6 |
// Three.js scene, perspective camera and full-window renderer
const scene = new THREE.Scene();
const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
const renderer = new THREE.WebGLRenderer();
renderer.setSize(window.innerWidth, window.innerHeight);
document.body.appendChild(renderer.domElement);

// Orbit controls with damped (smoothed) camera motion
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;

/** Build a radius-0.2 sphere mesh with an unlit material of the given colour. */
function createSphere(color: number) {
  const geometry = new THREE.SphereGeometry(0.2, 32, 32);
  const material = new THREE.MeshBasicMaterial({ color });
  return new THREE.Mesh(geometry, material);
}

// Reference grid, then the agent (blue sphere) and the target (red sphere)
scene.add(new THREE.GridHelper(50, 50));
const agent = createSphere(0x0000ff);
const target = createSphere(0xff0000);
scene.add(agent, target);

// Camera looks at the origin from a diagonal viewpoint
camera.position.set(5, 5, 5);
camera.lookAt(0, 0, 0);
38 |
// Environment state; positions are 1-D and mirrored on the meshes' x axis
let position = 0;
let targetPosition = (Math.random() - 0.5) * 4; // uniform in [-2, 2)
let bestDistance = Infinity;
let stepCount = 0;
let previousDistance = Infinity;

// Show the target at its initial position. Otherwise the red sphere stays at
// x = 0 while the agent is actually chasing `targetPosition`, until the first reset.
target.position.x = targetPosition;
45 |
// DQN learner; same hyper-parameters as the Node training script (src/main.ts)
const agentHyperParams = {
  inputSize: 2, // observation: [position, targetPosition]
  actionSize: 3, // applied as dx = action - 1 (left / stay / right)
  hiddenLayers: [32, 32],
  gamma: 0.99,
  epsilon: 1.0,
  epsilonDecay: 0.995,
  minEpsilon: 0.01,
  lr: 0.001,
  batchSize: 32,
  memorySize: 1000,
  targetUpdateFrequency: 10,
};
const dqnAgent = new DQNAgent(agentHyperParams);

// Hugging Face token injected by Vite; checkpoints cannot be uploaded without it
const hfToken = import.meta.env?.VITE_HF_TOKEN;
const hasHfToken = Boolean(hfToken);
if (!hasHfToken) {
  console.warn('⚠️ VITE_HF_TOKEN non trouvé. Les checkpoints ne seront pas sauvegardés.');
}
66 |
// Environment wiring: observation, action, shaped reward, termination and reset
const env: IgnitionEnv = new IgnitionEnv({
  agent: dqnAgent,
  getObservation: () => [position, targetPosition],
  applyAction: (action: number | number[]) => {
    const chosen = Array.isArray(action) ? action[0] : action;
    // Action index 0/1/2 becomes a displacement of -0.2 / 0 / +0.2
    const dx = chosen - 1;
    position += dx * 0.2;
    agent.position.x = position;

    console.log(`[ACTION] ${chosen} (dx: ${dx.toFixed(2)})`);
  },

  computeReward: () => {
    const distance = Math.abs(position - targetPosition);
    const movedAway = distance > previousDistance;
    previousDistance = distance;

    let reward = 1.0 / (1.0 + distance); // base reward in (0, 1]
    if (movedAway) reward -= 0.5; // penalty for moving away from the target
    if (distance < 0.5) reward += 1.0; // bonus when close to the target
    return reward;
  },
  isDone: (): boolean => {
    const distance = Math.abs(position - targetPosition);
    const done = distance < 0.1 || stepCount > 1000;
    if (done) console.log(`[DONE] Distance finale: ${distance.toFixed(2)}`);
    return done;
  },
  onReset: () => {
    // Fresh episode state first...
    position = 0;
    targetPosition = (Math.random() - 0.5) * 4;
    stepCount = 0;
    bestDistance = Infinity;
    previousDistance = Math.abs(position - targetPosition);
    // ...then mirror it onto the meshes
    agent.position.x = position;
    target.position.x = targetPosition;

    console.log(`[RESET] Nouvelle cible: ${targetPosition.toFixed(2)}`);
  },
  stepIntervalMs: 100,
  hfRepoId: 'salim4n/dqn-checkpoint-threejs',
  hfToken: hfToken || '',
});
129 |
// Wrap env.step to count steps, log progress and (simulate) checkpoints.
// Takes no parameter: the old unused `action` argument failed `noUnusedParameters`.
const originalStep = env.step.bind(env);
env.step = async () => {
  // Let the base environment finish its step first
  const result = await originalStep();
  stepCount++;

  const d = Math.abs(position - targetPosition);

  if (stepCount % 10 === 0) {
    console.log(`[STEP ${stepCount}] Position: ${position.toFixed(2)}, Cible: ${targetPosition.toFixed(2)}, Distance: ${d.toFixed(2)}`);
  }

  // Best-so-far "checkpoint". Uploading is disabled in the browser; to re-enable:
  // await dqnAgent.saveCheckpoint('salim4n/dqn-checkpoint-threejs', hfToken || '', 'best');
  if (d < bestDistance) {
    bestDistance = d;
    console.log(`[CHECKPOINT] Nouvelle meilleure distance: ${d.toFixed(4)}`);
    console.log(`[CHECKPOINT] Sauvegarde du meilleur modèle...`);
    console.log(`[CHECKPOINT] ✅ Meilleur modèle sauvegardé (simulé)`);
  }

  // Periodic "checkpoint" every 100 steps (upload disabled in the browser):
  // await dqnAgent.saveCheckpoint('salim4n/dqn-checkpoint-threejs', hfToken || '', `step-${stepCount}`);
  if (stepCount % 100 === 0) {
    console.log(`[CHECKPOINT] Sauvegarde régulière à l'étape ${stepCount}`);
  }

  // Same termination rule as the environment's isDone. Nothing moved since `d`
  // was computed, so it is reused instead of recomputed in a nested closure.
  // A final save would go here (disabled in the browser):
  // await dqnAgent.saveCheckpoint('salim4n/dqn-checkpoint-threejs', hfToken || '', 'final');
  const finished = d < 0.1 || stepCount > 1000;
  if (finished) {
    console.log(`[FINISH] Entraînement terminé à l'étape ${stepCount}!`);
    console.log(`[FINISH] Distance finale: ${d.toFixed(2)}`);
    env.stop();
  }

  return result;
};
189 |
/** Render loop: schedule the next frame, update damped controls, draw the scene. */
const renderFrame = (): void => {
  requestAnimationFrame(renderFrame);
  controls.update();
  renderer.render(scene, camera);
};

/** Keep the camera projection and canvas size in sync with the window. */
const handleResize = (): void => {
  const { innerWidth: width, innerHeight: height } = window;
  camera.aspect = width / height;
  camera.updateProjectionMatrix();
  renderer.setSize(width, height);
};
window.addEventListener('resize', handleResize);

// Kick off rendering, then the training environment
console.log('[START] Démarrage de la visualisation...');
renderFrame();
env.start();
--------------------------------------------------------------------------------
/packages/demo-target-chasing/styles.css:
--------------------------------------------------------------------------------
/* Fixed column on the right edge of the viewport that stacks every UI panel */
.ui-panels {
  position: fixed;
  top: 20px;
  right: 20px;
  z-index: 1000;
  display: flex;
  flex-direction: column;
  gap: 15px;
  width: 350px;
  max-height: 90vh;
  overflow-y: auto;
}

/* Training controls panel: translucent dark card */
.training-controls {
  padding: 15px;
  border-radius: 5px;
  background: rgba(0, 0, 0, 0.7);
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
  color: #fff;
  font-family: 'Arial', sans-serif;
}

.training-controls h3 {
  margin-top: 0;
  padding-bottom: 8px;
  border-bottom: 1px solid #3f4e8d;
  color: #0c8cbf; /* IgnitionAI blue */
}

.training-controls button {
  padding: 8px 12px;
  border: none;
  border-radius: 4px;
  background-color: #0c8cbf;
  color: #fff;
  cursor: pointer;
  transition: background-color 0.3s;
}

.training-controls button:hover {
  background-color: #3f4e8d;
}

.training-controls div {
  margin-bottom: 5px;
}
47 |
/* Visualization charts panel: same dark card as the other panels */
.visualization-charts {
  padding: 15px;
  border-radius: 5px;
  background: rgba(0, 0, 0, 0.7);
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
  color: #fff;
  font-family: 'Arial', sans-serif;
}

.visualization-charts h3 {
  margin-top: 0;
  padding-bottom: 8px;
  border-bottom: 1px solid #3f4e8d;
  color: #0c8cbf;
}

.visualization-charts h4 {
  margin-top: 10px;
  margin-bottom: 5px;
  color: #fff;
}

/* One inset card per chart */
.chart-container {
  margin-bottom: 20px;
  padding: 10px;
  border-radius: 4px;
  background: rgba(30, 30, 30, 0.5);
}
77 |
/* Agent Config Panel Styles */
.agent-config-panel {
  background: rgba(0, 0, 0, 0.7);
  padding: 15px;
  border-radius: 5px;
  color: white;
  font-family: 'Arial', sans-serif;
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
}

.agent-config-panel h3 {
  margin-top: 0;
  color: #0c8cbf;
  border-bottom: 1px solid #3f4e8d;
  padding-bottom: 8px;
}

.agent-config-panel h4 {
  margin-top: 15px;
  margin-bottom: 10px;
  color: #ffffff;
}

.agent-config-panel h5 {
  margin-top: 10px;
  margin-bottom: 5px;
  color: #cccccc;
}

.config-section {
  margin-bottom: 15px;
  background: rgba(30, 30, 30, 0.5);
  padding: 10px;
  border-radius: 4px;
}

/* Label on the left, input on the right */
.config-row {
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-bottom: 8px;
}

.config-row label {
  flex: 1;
  margin-right: 10px;
}

.config-row input {
  width: 80px;
  padding: 5px;
  border-radius: 3px;
  border: 1px solid #555;
  background: #333;
  color: white;
}

.layers-container {
  margin-top: 10px;
}

.layer-row {
  display: flex;
  align-items: center;
  margin-bottom: 8px;
}

.layer-row label {
  width: 80px;
  margin-right: 10px;
}

.layer-row input {
  width: 60px;
  padding: 5px;
  border-radius: 3px;
  border: 1px solid #555;
  background: #333;
  color: white;
  margin-right: 10px;
}

/*
 * Green button directly in the layers container (presumably "add layer").
 * It must be declared BEFORE `.layer-row button`: both selectors have the same
 * specificity, so when a .layer-row sits inside .layers-container the later rule
 * wins. If this rule came last, it would turn the red per-row buttons green.
 */
.layers-container button {
  background-color: #388e3c;
  border: none;
  color: white;
  padding: 5px 10px;
  border-radius: 3px;
  cursor: pointer;
  margin-top: 5px;
}

/*
 * Red per-row button (presumably "remove layer"). margin-top is reset so the
 * container rule above does not push it down.
 */
.layer-row button {
  background-color: #d32f2f;
  border: none;
  color: white;
  padding: 5px 8px;
  border-radius: 3px;
  cursor: pointer;
  font-size: 0.8em;
  margin-top: 0;
}

.apply-button {
  background-color: #0c8cbf;
  border: none;
  color: white;
  padding: 10px 15px;
  border-radius: 4px;
  cursor: pointer;
  transition: background-color 0.3s;
  width: 100%;
  margin-top: 15px;
  font-weight: bold;
}

.apply-button:hover {
  background-color: #3f4e8d;
}
196 |
/* Network designer panel: same dark card as the other panels */
.network-designer-panel {
  padding: 15px;
  border-radius: 5px;
  background: rgba(0, 0, 0, 0.7);
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
  color: #fff;
  font-family: 'Arial', sans-serif;
}

.network-designer-panel h3 {
  margin-top: 0;
  padding-bottom: 8px;
  border-bottom: 1px solid #3f4e8d;
  color: #0c8cbf;
}

/* React Flow overrides: dark theme matching the panels */
.react-flow__node {
  padding: 8px 12px;
  border: 1px solid #555;
  border-radius: 4px;
  background: #2a2a2a;
  color: #eee;
  font-size: 12px;
}

/* Input nodes in green, output nodes in red */
.react-flow__node.react-flow__node-input {
  border-color: #66bb6a;
  background: #388e3c;
}

.react-flow__node.react-flow__node-output {
  border-color: #ef5350;
  background: #d32f2f;
}

.react-flow__edge-path {
  stroke: #0c8cbf;
  stroke-width: 2;
}

.react-flow__controls button {
  border: 1px solid #555;
  background-color: rgba(40, 40, 40, 0.8);
  color: #fff;
}

.react-flow__controls button:hover {
  background-color: rgba(60, 60, 60, 0.9);
}

.react-flow__minimap {
  border: 1px solid #555;
  background-color: rgba(30, 30, 30, 0.8);
}

.react-flow__background {
  background-color: #1e1e1e;
}
257 |
258 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/tsconfig.app.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
4 | "target": "ES2020",
5 | "useDefineForClassFields": true,
6 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
7 | "module": "ESNext",
8 | "skipLibCheck": true,
9 |
10 | /* Bundler mode */
11 | "moduleResolution": "bundler",
12 | "allowImportingTsExtensions": true,
13 | "isolatedModules": true,
14 | "moduleDetection": "force",
15 | "noEmit": true,
16 | "jsx": "react-jsx",
17 |
18 | /* Linting */
19 | "strict": true,
20 | "noUnusedLocals": true,
21 | "noUnusedParameters": true,
22 | "noFallthroughCasesInSwitch": true,
23 | "noUncheckedSideEffectImports": true
24 | },
25 | "include": ["src"]
26 | }
27 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../../tsconfig.base.json",
3 | "compilerOptions": {
4 | "outDir": "dist"
5 | },
6 | "include": ["src/**/*"]
7 | }
8 |
--------------------------------------------------------------------------------
/packages/demo-target-chasing/vite.config.ts:
--------------------------------------------------------------------------------
1 | import * as dotenv from 'dotenv';
2 | import path from 'path';
3 | import { defineConfig } from 'vite';
4 |
5 | import react from '@vitejs/plugin-react';
6 |
// Load variables from .env into process.env so this config can see them.
dotenv.config();

// Types for the variables read through import.meta.env.
// NOTE(review): neither tsconfig shown for this package includes vite.config.ts, so
// this augmentation presumably has no effect on src/ (src/env.d.ts is the usual place). Confirm.
declare global {
  interface ImportMetaEnv {
    VITE_HF_TOKEN: string;
  }
}

/**
 * The part of process.env that may be inlined into the browser bundle.
 *
 * Vite JSON-serialises `define` values into the built code. Defining
 * 'process.env' as the whole process environment would therefore publish every
 * variable (HF_TOKEN, CI credentials, PATH, ...). Only `VITE_`-prefixed
 * variables, which Vite already treats as public, and NODE_ENV are forwarded.
 */
const publicProcessEnv = Object.fromEntries(
  Object.entries(process.env).filter(([name]) => name.startsWith('VITE_') || name === 'NODE_ENV'),
);

export default defineConfig({
  root: 'src',
  publicDir: '../public',
  build: {
    outDir: '../dist',
  },
  server: {
    port: 3000,
  },
  define: {
    'process.env': publicProcessEnv,
  },
  plugins: [react()],
  resolve: {
    // Consume the workspace packages straight from their TypeScript sources
    alias: {
      '@ignitionai/backend-tfjs': path.resolve(__dirname, '../backend-tfjs/src'),
      '@ignitionai/core': path.resolve(__dirname, '../core/src'),
    },
  },
});
--------------------------------------------------------------------------------
/pnpm-workspace.yaml:
--------------------------------------------------------------------------------
packages:
  - 'packages/*'
  # Example apps such as r3f/target-chasing depend on workspace packages via
  # `workspace:*`, which pnpm only resolves for members of the workspace.
  - 'r3f/*'
3 |
--------------------------------------------------------------------------------
/r3f/target-chasing/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 | pnpm-debug.log*
8 | lerna-debug.log*
9 |
10 | node_modules
11 | dist
12 | dist-ssr
13 | *.local
14 |
15 | # Editor directories and files
16 | .vscode/*
17 | !.vscode/extensions.json
18 | .idea
19 | .DS_Store
20 | *.suo
21 | *.ntvs*
22 | *.njsproj
23 | *.sln
24 | *.sw?
25 |
--------------------------------------------------------------------------------
/r3f/target-chasing/README.md:
--------------------------------------------------------------------------------
1 | # React + TypeScript + Vite
2 |
3 | This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4 |
5 | Currently, two official plugins are available:
6 |
7 | - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
8 | - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9 |
10 | ## Expanding the ESLint configuration
11 |
12 | If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
13 |
14 | ```js
15 | export default tseslint.config({
16 | extends: [
17 | // Remove ...tseslint.configs.recommended and replace with this
18 | ...tseslint.configs.recommendedTypeChecked,
19 | // Alternatively, use this for stricter rules
20 | ...tseslint.configs.strictTypeChecked,
21 | // Optionally, add this for stylistic rules
22 | ...tseslint.configs.stylisticTypeChecked,
23 | ],
24 | languageOptions: {
25 | // other options...
26 | parserOptions: {
27 | project: ['./tsconfig.node.json', './tsconfig.app.json'],
28 | tsconfigRootDir: import.meta.dirname,
29 | },
30 | },
31 | })
32 | ```
33 |
34 | You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
35 |
36 | ```js
37 | // eslint.config.js
38 | import reactX from 'eslint-plugin-react-x'
39 | import reactDom from 'eslint-plugin-react-dom'
40 |
41 | export default tseslint.config({
42 | plugins: {
43 | // Add the react-x and react-dom plugins
44 | 'react-x': reactX,
45 | 'react-dom': reactDom,
46 | },
47 | rules: {
48 | // other rules...
49 | // Enable its recommended typescript rules
50 | ...reactX.configs['recommended-typescript'].rules,
51 | ...reactDom.configs.recommended.rules,
52 | },
53 | })
54 | ```
55 |
--------------------------------------------------------------------------------
/r3f/target-chasing/eslint.config.js:
--------------------------------------------------------------------------------
1 | import js from '@eslint/js'
2 | import globals from 'globals'
3 | import reactHooks from 'eslint-plugin-react-hooks'
4 | import reactRefresh from 'eslint-plugin-react-refresh'
5 | import tseslint from 'typescript-eslint'
6 |
// Lint setup for the app's TypeScript/React sources (browser globals, hooks rules,
// and the react-refresh constraint on what component modules may export).
const typescriptReactConfig = {
  extends: [js.configs.recommended, ...tseslint.configs.recommended],
  files: ['**/*.{ts,tsx}'],
  languageOptions: {
    ecmaVersion: 2020,
    globals: globals.browser,
  },
  plugins: {
    'react-hooks': reactHooks,
    'react-refresh': reactRefresh,
  },
  rules: {
    ...reactHooks.configs.recommended.rules,
    'react-refresh/only-export-components': ['warn', { allowConstantExport: true }],
  },
}

export default tseslint.config(
  // Build output is never linted
  { ignores: ['dist'] },
  typescriptReactConfig,
)
29 |
--------------------------------------------------------------------------------
/r3f/target-chasing/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Target Chasing - IgnitionAI
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/r3f/target-chasing/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "target-chasing",
3 | "private": true,
4 | "version": "0.0.0",
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc -b && vite build",
9 | "lint": "eslint .",
10 | "preview": "vite preview"
11 | },
12 | "dependencies": {
13 | "@ignitionai/backend-tfjs": "workspace:*",
14 | "@ignitionai/core": "workspace:*",
15 | "@react-three/drei": "^10.0.6",
16 | "@react-three/fiber": "^9.1.2",
17 | "@react-three/rapier": "^2.1.0",
18 | "@types/three": "^0.162.0",
19 | "react": "^19.0.0",
20 | "react-dom": "^19.0.0",
21 | "reactflow": "^11.11.4",
22 | "recharts": "^2.15.3",
23 | "three": "^0.162.0",
24 | "zustand": "^5.0.3"
25 | },
26 | "devDependencies": {
27 | "@eslint/js": "^9.22.0",
28 | "@react-three/eslint-plugin": "^0.1.2",
29 | "@types/react": "^19.0.10",
30 | "@types/react-dom": "^19.0.4",
31 | "@vitejs/plugin-react": "^4.3.4",
32 | "eslint": "^9.22.0",
33 | "eslint-plugin-react-hooks": "^5.2.0",
34 | "eslint-plugin-react-refresh": "^0.4.19",
35 | "globals": "^16.0.0",
36 | "typescript": "~5.7.2",
37 | "typescript-eslint": "^8.26.1",
38 | "vite": "^6.3.1"
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/r3f/target-chasing/public/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IgnitionAI/ignition/9177384c3c7750bc25a596c10150ad5008bfaa95/r3f/target-chasing/public/logo.png
--------------------------------------------------------------------------------
/r3f/target-chasing/public/vite.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/App.tsx:
--------------------------------------------------------------------------------
1 | import { Canvas } from "@react-three/fiber";
2 | import { Physics } from "@react-three/rapier";
3 | import { KeyboardControls, OrbitControls } from "@react-three/drei";
4 | import Experience from "./Experience";
5 | import { useMemo, useRef, useState, useCallback } from "react"; // Added useCallback
6 | import { TrainingControls } from "./TrainingControls";
7 | import { VisualizationCharts } from "./components/VisualizationCharts";
8 | import { AgentConfigPanel } from "./components/AgentConfigPanel";
9 | import { NetworkDesigner } from "./components/NetworkDesigner"; // Import the network designer
10 |
11 | import "./styles.css";
12 |
/**
 * Keyboard action identifiers. Each value equals its key so it can be used as
 * the `name` of a key binding.
 */
export const Controls = {
  forward: 'forward',
  back: 'back',
  left: 'left',
  right: 'right',
  jump: 'jump',
};

/** Agent hyper-parameters edited by the side panels and handed to Experience. */
type AgentConfig = {
  // Network shape
  inputSize: number;
  actionSize: number;
  hiddenLayers: number[];
  // Exploration schedule
  epsilon: number;
  epsilonDecay: number;
  minEpsilon: number;
  // Learning
  gamma: number;
  lr: number;
  batchSize: number;
  memorySize: number;
};
34 |
// Root component of the target-chasing demo. It owns the keyboard map, the agent
// configuration edited by the side panels, and the start/stop/reset callbacks that
// are forwarded to the Experience component through `experienceRef`.
35 | function App() {
// Key bindings (arrows/WASD to move, Space to jump). They are memoised with an
// empty dependency list so the same array is reused on every render; presumably
// passed to drei's KeyboardControls in the JSX below. Confirm.
36 | const map = useMemo(
37 | () => [
38 | { name: Controls.forward, keys: ["ArrowUp", "KeyW"] },
39 | { name: Controls.back, keys: ["ArrowDown", "KeyS"] },
40 | { name: Controls.left, keys: ["ArrowLeft", "KeyA"] },
41 | { name: Controls.right, keys: ["ArrowRight", "KeyD"] },
42 | { name: Controls.jump, keys: ["Space"] },
43 | ],
44 | []
45 | );
46 |
47 | // Reference to the Experience component
// NOTE(review): without a type argument, `useRef(null)` is typed as a ref to
// `null`, so the `experienceRef.current?.startTraining(...)` calls below should
// not type-check under `strict`. A handle type argument seems to be missing. Confirm.
48 | const experienceRef = useRef(null);
49 |
50 | // State to hold the current agent configuration
51 | const [agentConfig, setAgentConfig] = useState({
52 | // Default config matching Experience.tsx initial setup
53 | inputSize: 9,
54 | actionSize: 4,
55 | hiddenLayers: [64, 64],
56 | epsilon: 0.9,
57 | epsilonDecay: 0.97,
58 | minEpsilon: 0.05,
59 | gamma: 0.99,
60 | lr: 0.001,
61 | batchSize: 128,
62 | memorySize: 100000,
63 | });
64 |
65 | // Callback for the config panel to update the configuration
// Replaces the whole configuration; useCallback with [] keeps its identity stable.
66 | const handleApplyConfig = useCallback((newConfig: AgentConfig) => {
67 | console.log("Applying new agent configuration from panel:", newConfig);
68 | setAgentConfig(newConfig);
69 | // Optionally, reset the environment when config changes
70 | // resetEnvironment();
71 | }, []);
72 |
73 | // Callback for the network designer to update hidden layers
// Functional update: only `hiddenLayers` changes; every other field is kept.
74 | const handleNetworkChange = useCallback((newHiddenLayers: number[]) => {
75 | console.log("Applying new hidden layers from designer:", newHiddenLayers);
76 | setAgentConfig(prevConfig => ({
77 | ...prevConfig,
78 | hiddenLayers: newHiddenLayers,
79 | }));
80 | // Optionally, reset the environment when network structure changes
81 | // resetEnvironment();
82 | }, []);
83 |
84 | // Control functions passed to TrainingControls
// Each is a no-op while the ref is unattached (optional chaining). Start and
// reset hand the current `agentConfig` to the Experience component.
85 | const startTraining = () => {
86 | experienceRef.current?.startTraining(agentConfig); // Pass config on start
87 | };
88 |
89 | const stopTraining = () => {
90 | experienceRef.current?.stopTraining();
91 | };
92 |
93 | const resetEnvironment = () => {
94 | experienceRef.current?.resetEnvironment(agentConfig); // Pass config on reset
95 | };
96 |
// NOTE(review): the JSX tags of this return appear to have been stripped from this
// copy; only JSX comments and blank numbered lines remain. Restore the markup from
// version control before editing it.
97 | return (
98 | <>
99 | {/* UI Panels Container */}
100 |
101 |
106 |
107 |
108 |
{/* Add the network designer */}
109 |
110 |
111 |
112 |
113 |
114 |
115 | {/* Disable physics debug view for clarity */}
116 | {/* Pass the agentConfig to the Experience component */}
117 |
118 |
119 |
120 |
121 | >
122 | );
123 | }
124 |
125 | export default App;
126 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/TrainingControls.tsx:
--------------------------------------------------------------------------------
1 | import { useTrainingStore } from './store/trainingStore';
2 |
// Callbacks the parent passes in to drive training from this panel.
3 | interface TrainingControlsProps {
4 | startTraining: () => void;
5 | stopTraining: () => void;
6 | resetEnvironment: () => void;
7 | }
8 |
// Training control panel. It shows episode and success counts, difficulty (as 1-based
// out of 3), episode time, the last action and the current reward, all read from
// useTrainingStore. It also offers start/stop controls (depending on `isTraining`)
// and a reset control, bound to the props.
9 | export function TrainingControls({ startTraining, stopTraining, resetEnvironment }: TrainingControlsProps) {
10 | const {
11 | isTraining,
12 | episodeCount,
13 | reward,
14 | episodeTime,
15 | successCount,
16 | difficulty,
17 | lastAction
18 | } = useTrainingStore();
19 |
// NOTE(review): the JSX tags of this return appear stripped from this copy (only text
// and expressions remain); restore them from version control before editing.
// NOTE(review): the "Dernière action" label indexes a 4-entry array
// (left/right/forward/back). A `lastAction` outside -1..3 would render nothing,
// so confirm the agent's actions are 0-3.
20 | return (
21 |
22 |
Contrôle d'entraînement
23 |
Épisodes: {episodeCount}
24 |
Succès: {successCount} / {episodeCount}
25 |
Difficulté: {difficulty + 1}/3
26 |
Temps: {episodeTime.toFixed(1)}s
27 |
Dernière action: {lastAction !== -1 ? ['Gauche', 'Droite', 'Avant', 'Arrière'][lastAction] : 'Aucune'}
28 |
Récompense: {reward.toFixed(2)}
29 |
30 | {!isTraining ? (
31 | Démarrer l'entraînement
32 | ) : (
33 | Arrêter l'entraînement
34 | )}
35 | Réinitialiser
36 |
37 |
38 | );
39 | }
40 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/assets/react.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/components/AgentConfigPanel.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState } from 'react';
2 | import { useTrainingStore } from '../store/trainingStore';
3 |
4 | interface LayerConfig { // One hidden Dense layer in the editor; `id` is the React key and the update/remove target
5 | id: string;
6 | neurons: number;
7 | }
8 |
9 | interface AgentConfigPanelProps {
10 | onApplyConfig: (config: {
11 | inputSize: number;
12 | actionSize: number;
13 | hiddenLayers: number[];
14 | epsilon: number;
15 | epsilonDecay: number;
16 | minEpsilon: number;
17 | gamma: number;
18 | lr: number;
19 | batchSize: number;
20 | memorySize: number;
21 | }) => void;
22 | }
23 |
24 | export function AgentConfigPanel({ onApplyConfig }: AgentConfigPanelProps) { // Form for the DQN agent's architecture and hyperparameters; values are only sent on "Apply"
25 | // Default configuration
26 | const [inputSize, setInputSize] = useState(9);
27 | const [actionSize, setActionSize] = useState(4);
28 | const [layers, setLayers] = useState<LayerConfig[]>([
29 | { id: 'layer1', neurons: 64 },
30 | { id: 'layer2', neurons: 64 }
31 | ]);
32 | const [epsilon, setEpsilon] = useState(0.9);
33 | const [epsilonDecay, setEpsilonDecay] = useState(0.97);
34 | const [minEpsilon, setMinEpsilon] = useState(0.05);
35 | const [gamma, setGamma] = useState(0.99);
36 | const [learningRate, setLearningRate] = useState(0.001);
37 | const [batchSize, setBatchSize] = useState(128);
38 | const [memorySize, setMemorySize] = useState(100000);
39 |
40 | // Append a 32-neuron layer. The id comes from the highest existing suffix (not the layer count) so it can never collide with a surviving layer after a removal.
41 | const addLayer = () => {
42 | const highestSuffix = layers.reduce((max, layer) => Math.max(max, parseInt(layer.id.replace('layer', ''), 10) || 0), 0);
43 | setLayers([...layers, { id: `layer${highestSuffix + 1}`, neurons: 32 }]);
44 | };
45 |
46 | // Remove a layer by id; the last remaining layer cannot be removed
47 | const removeLayer = (id: string) => {
48 | if (layers.length > 1) {
49 | setLayers(layers.filter(layer => layer.id !== id));
50 | }
51 | };
52 |
53 | // Set a layer's neuron count (parsed straight from the input, so it is NaN while the field is empty)
54 | const updateLayer = (id: string, neurons: number) => {
55 | setLayers(layers.map(layer =>
56 | layer.id === id ? { ...layer, neurons } : layer
57 | ));
58 | };
59 |
60 | // Send the current form values to the parent; the learning rate is passed under the key `lr`
61 | const applyConfig = () => {
62 | onApplyConfig({
63 | inputSize,
64 | actionSize,
65 | hiddenLayers: layers.map(layer => layer.neurons),
66 | epsilon,
67 | epsilonDecay,
68 | minEpsilon,
69 | gamma,
70 | lr: learningRate,
71 | batchSize,
72 | memorySize
73 | });
74 | };
75 |
76 | return (
77 |
78 |
Agent Configuration
79 |
80 |
81 |
Network Architecture
82 |
83 |
84 | Input Size:
85 | setInputSize(parseInt(e.target.value))}
89 | min="1"
90 | />
91 |
92 |
93 |
94 | Action Size:
95 | setActionSize(parseInt(e.target.value))}
99 | min="1"
100 | />
101 |
102 |
103 |
104 |
Hidden Layers
105 | {layers.map((layer, index) => (
106 |
107 | Layer {index + 1}:
108 | updateLayer(layer.id, parseInt(e.target.value))}
112 | min="1"
113 | />
114 | removeLayer(layer.id)}>Remove
115 |
116 | ))}
117 |
Add Layer
118 |
119 |
120 |
121 |
204 |
205 |
Apply Configuration
206 |
207 | );
208 | }
209 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/components/NetworkDesigner.tsx:
--------------------------------------------------------------------------------
1 | import 'reactflow/dist/style.css';
2 |
3 | import React, {
4 | useCallback,
5 | useEffect,
6 | useState,
7 | } from 'react';
8 |
9 | import ReactFlow, {
10 | addEdge,
11 | applyEdgeChanges,
12 | applyNodeChanges,
13 | Background,
14 | Connection,
15 | Controls,
16 | Edge,
17 | EdgeChange,
18 | MiniMap,
19 | Node,
20 | NodeChange,
21 | OnConnect,
22 | OnEdgesChange,
23 | OnNodesChange,
24 | ReactFlowProvider,
25 | } from 'reactflow';
26 |
27 | interface NetworkDesignerProps { // onNetworkChange receives the neuron counts parsed from the graph's Dense nodes
28 | onNetworkChange: (layers: number[]) => void;
29 | }
30 |
31 | // New ReactFlow versions separate nodes and edges
32 | type FlowElement = Node | Edge;
33 |
34 | const initialNodes: Node[] = [
35 | { id: 'input', type: 'input', data: { label: 'Input (Size: 9)' }, position: { x: 100, y: 100 } },
36 | { id: 'hidden1', type: 'default', data: { label: 'Dense (Neurons: 64)' }, position: { x: 300, y: 50 } },
37 | { id: 'hidden2', type: 'default', data: { label: 'Dense (Neurons: 64)' }, position: { x: 300, y: 150 } },
38 | { id: 'output', type: 'output', data: { label: 'Output (Actions: 4)' }, position: { x: 500, y: 100 } },
39 | ];
40 |
41 | const initialEdges: Edge[] = [
42 | { id: 'e-input-h1', source: 'input', target: 'hidden1', animated: true },
43 | { id: 'e-input-h2', source: 'input', target: 'hidden2', animated: true },
44 | { id: 'e-h1-output', source: 'hidden1', target: 'output', animated: true },
45 | { id: 'e-h2-output', source: 'hidden2', target: 'output', animated: true },
46 | ];
47 |
48 | export function NetworkDesigner({ onNetworkChange }: NetworkDesignerProps) { // ReactFlow editor whose Dense nodes define the hidden layers reported to the parent
49 | const [nodes, setNodes] = useState(initialNodes);
50 | const [edges, setEdges] = useState(initialEdges);
51 |
52 | const onNodesChange: OnNodesChange = useCallback(
53 | (changes: NodeChange[]) => setNodes((nds) => applyNodeChanges(changes, nds)),
54 | []
55 | );
56 |
57 | const onEdgesChange: OnEdgesChange = useCallback(
58 | (changes: EdgeChange[]) => setEdges((eds) => applyEdgeChanges(changes, eds)),
59 | []
60 | );
61 |
62 | const onConnect: OnConnect = useCallback(
63 | (params: Connection) => setEdges((eds) => addEdge({ ...params, animated: true }, eds)),
64 | []
65 | );
66 |
67 | // Derive the hidden-layer sizes from the Dense nodes' labels (in node-array order) as a string key,
68 | // e.g. "64,64". The effect below is keyed on this string rather than on `nodes`, so node changes that
69 | // leave the architecture untouched (dragging, selecting) no longer re-notify the parent on every frame.
70 | const hiddenLayersKey = nodes
71 | .filter((node) => node.type === 'default' && typeof node.data?.label === 'string' && node.data.label.includes('Dense'))
72 | .map((node) => /Neurons: (\d+)/.exec(node.data.label)?.[1])
73 | .filter((count): count is string => count !== undefined)
74 | .join(',');
75 |
76 | useEffect(() => {
77 | // An empty key means the graph currently has no Dense layers.
78 | const hiddenLayers = hiddenLayersKey === '' ? [] : hiddenLayersKey.split(',').map((count) => parseInt(count, 10));
79 | onNetworkChange(hiddenLayers);
80 | // onNetworkChange is deliberately not a dependency: an inline parent callback would
81 | // otherwise re-fire this effect (and re-notify the parent) on every parent render.
82 | }, [hiddenLayersKey]);
83 |
84 | return (
85 |
86 |
Network Designer (Drag & Drop - Basic)
87 |
88 |
89 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 | Note: This is a basic visual representation. Add/remove/connect nodes to define layers.
105 | Neuron counts need manual adjustment via the config panel for now.
106 |
107 |
108 | );
109 | }
110 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/components/VisualizationCharts.tsx:
--------------------------------------------------------------------------------
1 | import React, {
2 | useEffect,
3 | useState,
4 | } from 'react';
5 |
6 | import {
7 | CartesianGrid,
8 | Legend,
9 | Line,
10 | LineChart,
11 | ResponsiveContainer,
12 | Tooltip,
13 | XAxis,
14 | YAxis,
15 | } from 'recharts';
16 |
17 | import { useTrainingStore } from '../store/trainingStore';
18 |
19 | // One chart sample; each history array fills in only the metric it plots
20 | interface DataPoint {
21 | step: number;
22 | reward?: number;
23 | loss?: number;
24 | epsilon?: number;
25 | }
26 |
27 | interface VisualizationChartsProps {
28 | maxDataPoints?: number;
29 | }
30 |
31 | export function VisualizationCharts({ maxDataPoints = 100 }: VisualizationChartsProps) { // Reward, loss and epsilon line charts fed from the training store
32 | // Rolling histories (at most maxDataPoints samples each) that feed the three charts
33 | const [rewardHistory, setRewardHistory] = useState<DataPoint[]>([]);
34 | const [lossHistory, setLossHistory] = useState<DataPoint[]>([]);
35 | const [epsilonHistory, setEpsilonHistory] = useState<DataPoint[]>([]);
36 |
37 | // Get current values from the training store
38 | const {
39 | reward,
40 | episodeSteps,
41 | isTraining,
42 | episodeCount
43 | } = useTrainingStore();
44 |
45 | // NOTE: epsilon and loss are NOT read from the agent yet; both are simulated from episodeCount.
46 | // They are derived during render (instead of being copied into state by an effect) so they change in
47 | // the same render as episodeCount. Keeping them in state made the history effect below run a second
48 | // time per episode and append a duplicate sample.
49 | const simulatedEpsilon = React.useMemo(
50 | // Decays from 1.0 towards a floor of 0.01
51 | () => Math.max(0.01, 1.0 * Math.pow(0.995, episodeCount)),
52 | [episodeCount]
53 | );
54 | const simulatedLoss = React.useMemo(() => {
55 | // Starts around 0.5 and decays, with a random fluctuation in [-0.05, 0.05)
56 | const baseLoss = 0.5 * Math.pow(0.99, episodeCount);
57 | const randomFactor = Math.random() * 0.1 - 0.05;
58 | return Math.max(0.01, baseLoss + randomFactor);
59 | }, [episodeCount]);
60 |
61 | // Append one sample to each history whenever its inputs change while training
62 | useEffect(() => {
63 | if (!isTraining) {
64 | // Histories are kept (not cleared) when training stops so the last run stays visible
65 | return;
66 | }
67 |
68 | // Keep only the most recent points to avoid performance issues
69 | const appendCapped = (point: DataPoint) => (prev: DataPoint[]) => [...prev, point].slice(-maxDataPoints);
70 |
71 | setRewardHistory(appendCapped({ step: episodeSteps, reward }));
72 | setLossHistory(appendCapped({ step: episodeSteps, loss: simulatedLoss }));
73 | setEpsilonHistory(appendCapped({ step: episodeSteps, epsilon: simulatedEpsilon }));
74 | }, [reward, episodeSteps, isTraining, simulatedEpsilon, simulatedLoss, maxDataPoints]);
75 |
76 | return (
77 |
78 |
Training Visualization
79 |
80 | {/* Reward Chart */}
81 |
82 |
Reward Over Time
83 |
84 |
88 |
89 |
90 |
91 |
92 |
93 |
100 |
101 |
102 |
103 |
104 | {/* Loss Chart */}
105 |
106 |
Loss Over Time
107 |
108 |
112 |
113 |
114 |
115 |
116 |
117 |
123 |
124 |
125 |
126 |
127 | {/* Epsilon Chart */}
128 |
129 |
Epsilon Decay
130 |
131 |
135 |
136 |
137 |
141 |
142 |
143 |
149 |
150 |
151 |
152 |
153 | );
154 | }
155 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/index.css:
--------------------------------------------------------------------------------
1 | #root {
2 | width: 100vw;
3 | height: 100vh;
4 | }
5 |
6 | body {
7 | margin: 0;
8 | }
9 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/main.tsx:
--------------------------------------------------------------------------------
1 | import { StrictMode } from 'react'
2 | import { createRoot } from 'react-dom/client'
3 | import './index.css'
4 | import App from './App.tsx'
5 |
6 | createRoot(document.getElementById('root')!).render( // Mount the app into #root; the `!` assumes index.html always provides that element
7 |
8 |
9 | ,
10 | )
11 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/store/trainingStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from 'zustand';
2 |
3 | interface TrainingState {
4 | // Training state
5 | isTraining: boolean;
6 | isTrainingInProgress: boolean;
7 | episodeCount: number;
8 | reward: number;
9 | bestReward: number;
10 | episodeSteps: number;
11 | reachedTarget: boolean;
12 | episodeTime: number;
13 | episodeStartTime: number;
14 | successCount: number;
15 | difficulty: number;
16 | lastAction: number;
17 |
18 | // Target state
19 | targetPosition: [number, number, number];
20 |
21 | // Actions (setters whose value may be a function also accept an updater receiving the previous value)
22 | setIsTraining: (value: boolean) => void;
23 | setIsTrainingInProgress: (value: boolean) => void;
24 | setEpisodeCount: (value: number | ((prev: number) => number)) => void;
25 | setReward: (value: number) => void;
26 | setBestReward: (value: number) => void;
27 | setEpisodeSteps: (value: number | ((prev: number) => number)) => void;
28 | setReachedTarget: (value: boolean) => void;
29 | setEpisodeTime: (value: number) => void;
30 | setEpisodeStartTime: (value: number) => void;
31 | setSuccessCount: (value: number | ((prev: number) => number)) => void;
32 | setDifficulty: (value: number | ((prev: number) => number)) => void;
33 | setLastAction: (value: number) => void;
34 | setTargetPosition: (value: [number, number, number]) => void;
35 |
36 | // Utility methods
37 | resetEpisode: () => void; // clears the per-episode counters AND advances episodeCount by one
38 | incrementEpisodeCount: () => void;
39 | }
40 |
41 | export const useTrainingStore = create<TrainingState>((set) => ({ // Global training store (zustand hook); the generic makes the initializer type-checked against TrainingState
42 | // Initial state
43 | isTraining: false,
44 | isTrainingInProgress: false,
45 | episodeCount: 0,
46 | reward: 0,
47 | bestReward: -Infinity,
48 | episodeSteps: 0,
49 | reachedTarget: false,
50 | episodeTime: 0,
51 | episodeStartTime: Date.now(), // evaluated once, when this module is first loaded; resetEpisode refreshes it
52 | successCount: 0,
53 | difficulty: 0,
54 | lastAction: -1,
55 |
56 | // Initial target state
57 | targetPosition: [0, 10, 0],
58 |
59 | // Actions
60 | setIsTraining: (value) => set({ isTraining: value }),
61 | setIsTrainingInProgress: (value) => set({ isTrainingInProgress: value }),
62 | setEpisodeCount: (value) => set((state) => ({
63 | episodeCount: typeof value === 'function' ? value(state.episodeCount) : value
64 | })),
65 | setReward: (value) => set({ reward: value }),
66 | setBestReward: (value) => set({ bestReward: value }),
67 | setEpisodeSteps: (value) => set((state) => ({
68 | episodeSteps: typeof value === 'function' ? value(state.episodeSteps) : value
69 | })),
70 | setReachedTarget: (value) => set({ reachedTarget: value }),
71 | setEpisodeTime: (value) => set({ episodeTime: value }),
72 | setEpisodeStartTime: (value) => set({ episodeStartTime: value }),
73 | setSuccessCount: (value) => set((state) => ({
74 | successCount: typeof value === 'function' ? value(state.successCount) : value
75 | })),
76 | setDifficulty: (value) => set((state) => ({
77 | difficulty: typeof value === 'function' ? value(state.difficulty) : value
78 | })),
79 | setLastAction: (value) => set({ lastAction: value }),
80 | setTargetPosition: (value) => set({ targetPosition: value }),
81 |
82 | // Utility methods
83 | resetEpisode: () => set((state) => ({ // `reward` and `lastAction` are not reset here
84 | episodeSteps: 0,
85 | reachedTarget: false,
86 | episodeCount: state.episodeCount + 1,
87 | episodeTime: 0,
88 | episodeStartTime: Date.now()
89 | })),
90 | incrementEpisodeCount: () => set((state) => ({ episodeCount: state.episodeCount + 1 }))
91 | }));
92 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/styles.css:
--------------------------------------------------------------------------------
1 | .ui-panels {
2 | display: flex;
3 | flex-direction: column;
4 | position: fixed;
5 | top: 20px;
6 | right: 20px;
7 | width: 350px;
8 | max-height: 90vh;
9 | overflow-y: auto;
10 | z-index: 1000;
11 | gap: 15px;
12 | }
13 |
14 | .training-controls {
15 | background: rgba(0, 0, 0, 0.7);
16 | padding: 15px;
17 | border-radius: 5px;
18 | color: white;
19 | font-family: 'Arial', sans-serif;
20 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
21 | }
22 |
23 | .training-controls h3 {
24 | margin-top: 0;
25 | color: #0c8cbf; /* Bleu IgnitionAI */
26 | border-bottom: 1px solid #3f4e8d;
27 | padding-bottom: 8px;
28 | }
29 |
30 | .training-controls button {
31 | background-color: #0c8cbf;
32 | border: none;
33 | color: white;
34 | padding: 8px 12px;
35 | border-radius: 4px;
36 | cursor: pointer;
37 | transition: background-color 0.3s;
38 | }
39 |
40 | .training-controls button:hover {
41 | background-color: #3f4e8d;
42 | }
43 |
44 | .training-controls div {
45 | margin-bottom: 5px;
46 | }
47 |
48 | /* Visualization Charts Styles */
49 | .visualization-charts {
50 | background: rgba(0, 0, 0, 0.7);
51 | padding: 15px;
52 | border-radius: 5px;
53 | color: white;
54 | font-family: 'Arial', sans-serif;
55 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
56 | }
57 |
58 | .visualization-charts h3 {
59 | margin-top: 0;
60 | color: #0c8cbf;
61 | border-bottom: 1px solid #3f4e8d;
62 | padding-bottom: 8px;
63 | }
64 |
65 | .visualization-charts h4 {
66 | margin-top: 10px;
67 | margin-bottom: 5px;
68 | color: #ffffff;
69 | }
70 |
71 | .chart-container {
72 | margin-bottom: 20px;
73 | background: rgba(30, 30, 30, 0.5);
74 | padding: 10px;
75 | border-radius: 4px;
76 | }
77 |
78 | /* Agent Config Panel Styles */
79 | .agent-config-panel {
80 | background: rgba(0, 0, 0, 0.7);
81 | padding: 15px;
82 | border-radius: 5px;
83 | color: white;
84 | font-family: 'Arial', sans-serif;
85 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
86 | }
87 |
88 | .agent-config-panel h3 {
89 | margin-top: 0;
90 | color: #0c8cbf;
91 | border-bottom: 1px solid #3f4e8d;
92 | padding-bottom: 8px;
93 | }
94 |
95 | .agent-config-panel h4 {
96 | margin-top: 15px;
97 | margin-bottom: 10px;
98 | color: #ffffff;
99 | }
100 |
101 | .agent-config-panel h5 {
102 | margin-top: 10px;
103 | margin-bottom: 5px;
104 | color: #cccccc;
105 | }
106 |
107 | .config-section {
108 | margin-bottom: 15px;
109 | background: rgba(30, 30, 30, 0.5);
110 | padding: 10px;
111 | border-radius: 4px;
112 | }
113 |
114 | .config-row {
115 | display: flex;
116 | justify-content: space-between;
117 | align-items: center;
118 | margin-bottom: 8px;
119 | }
120 |
121 | .config-row label {
122 | flex: 1;
123 | margin-right: 10px;
124 | }
125 |
126 | .config-row input {
127 | width: 80px;
128 | padding: 5px;
129 | border-radius: 3px;
130 | border: 1px solid #555;
131 | background: #333;
132 | color: white;
133 | }
134 |
135 | .layers-container {
136 | margin-top: 10px;
137 | }
138 |
139 | .layer-row {
140 | display: flex;
141 | align-items: center;
142 | margin-bottom: 8px;
143 | }
144 |
145 | .layer-row label {
146 | width: 80px;
147 | margin-right: 10px;
148 | }
149 |
150 | .layer-row input {
151 | width: 60px;
152 | padding: 5px;
153 | border-radius: 3px;
154 | border: 1px solid #555;
155 | background: #333;
156 | color: white;
157 | margin-right: 10px;
158 | }
159 |
160 | .layer-row button {
161 | background-color: #d32f2f;
162 | border: none;
163 | color: white;
164 | padding: 5px 8px;
165 | border-radius: 3px;
166 | cursor: pointer;
167 | font-size: 0.8em;
168 | }
169 |
170 | .layers-container button {
171 | background-color: #388e3c;
172 | border: none;
173 | color: white;
174 | padding: 5px 10px;
175 | border-radius: 3px;
176 | cursor: pointer;
177 | margin-top: 5px;
178 | }
179 |
180 | .apply-button {
181 | background-color: #0c8cbf;
182 | border: none;
183 | color: white;
184 | padding: 10px 15px;
185 | border-radius: 4px;
186 | cursor: pointer;
187 | transition: background-color 0.3s;
188 | width: 100%;
189 | margin-top: 15px;
190 | font-weight: bold;
191 | }
192 |
193 | .apply-button:hover {
194 | background-color: #3f4e8d;
195 | }
196 |
197 | /* Network Designer Styles */
198 | .network-designer-panel {
199 | background: rgba(0, 0, 0, 0.7);
200 | padding: 15px;
201 | border-radius: 5px;
202 | color: white;
203 | font-family: 'Arial', sans-serif;
204 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.5);
205 | }
206 |
207 | .network-designer-panel h3 {
208 | margin-top: 0;
209 | color: #0c8cbf;
210 | border-bottom: 1px solid #3f4e8d;
211 | padding-bottom: 8px;
212 | }
213 |
214 | /* React Flow specific styles */
215 | .react-flow__node {
216 | background: #2a2a2a;
217 | color: #eee;
218 | border: 1px solid #555;
219 | border-radius: 4px;
220 | padding: 8px 12px;
221 | font-size: 12px;
222 | }
223 |
224 | .react-flow__node.react-flow__node-input {
225 | background: #388e3c;
226 | border-color: #66bb6a;
227 | }
228 |
229 | .react-flow__node.react-flow__node-output {
230 | background: #d32f2f;
231 | border-color: #ef5350;
232 | }
233 |
234 | .react-flow__edge-path {
235 | stroke: #0c8cbf;
236 | stroke-width: 2;
237 | }
238 |
239 | .react-flow__controls button {
240 | background-color: rgba(40, 40, 40, 0.8);
241 | color: white;
242 | border: 1px solid #555;
243 | }
244 |
245 | .react-flow__controls button:hover {
246 | background-color: rgba(60, 60, 60, 0.9);
247 | }
248 |
249 | .react-flow__minimap {
250 | background-color: rgba(30, 30, 30, 0.8);
251 | border: 1px solid #555;
252 | }
253 |
254 | .react-flow__background {
255 | background-color: #1e1e1e;
256 | }
257 |
258 |
--------------------------------------------------------------------------------
/r3f/target-chasing/src/visualization.ts:
--------------------------------------------------------------------------------
1 | import * as THREE from 'three';
2 | import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
3 |
4 | import { DQNAgent } from '@ignitionai/backend-tfjs';
5 | import { IgnitionEnv } from '@ignitionai/core';
6 |
7 | // Three.js scene setup
8 | const scene = new THREE.Scene();
9 | const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
10 | const renderer = new THREE.WebGLRenderer();
11 | renderer.setSize(window.innerWidth, window.innerHeight);
12 | document.body.appendChild(renderer.domElement);
13 |
14 | // Camera orbit controls
15 | const controls = new OrbitControls(camera, renderer.domElement);
16 | controls.enableDamping = true;
17 | controls.dampingFactor = 0.05;
18 |
19 | // Grid to make positions in the scene easier to read
20 | const gridHelper = new THREE.GridHelper(50, 50);
21 | scene.add(gridHelper);
22 |
23 | // Agent (blue sphere)
24 | const agentGeometry = new THREE.SphereGeometry(0.2, 32, 32);
25 | const agentMaterial = new THREE.MeshBasicMaterial({ color: 0x0000ff });
26 | const agent = new THREE.Mesh(agentGeometry, agentMaterial);
27 | scene.add(agent);
28 |
29 | // Target (red sphere)
30 | const targetGeometry = new THREE.SphereGeometry(0.2, 32, 32);
31 | const targetMaterial = new THREE.MeshBasicMaterial({ color: 0xff0000 });
32 | const target = new THREE.Mesh(targetGeometry, targetMaterial);
33 | scene.add(target);
34 |
35 | // Camera placement
36 | camera.position.set(5, 5, 5);
37 | camera.lookAt(0, 0, 0);
38 |
39 | // Episode termination thresholds, shared by the environment's isDone and the step wrapper below
40 | const SUCCESS_DISTANCE = 0.1;
41 | const MAX_STEPS = 1000;
42 |
43 | // Environment state: the agent and the target only move along the x axis
44 | let position = 0;
45 | let targetPosition = (Math.random() - 0.5) * 4;
46 | let bestDistance = Infinity;
47 | let stepCount = 0;
48 | let previousDistance = Infinity;
49 |
50 | // Place the target mesh at its initial position so the scene matches the environment before the first reset
51 | target.position.x = targetPosition;
52 |
53 | // Absolute x-distance between the agent and the target
54 | const distanceToTarget = (): number => Math.abs(position - targetPosition);
55 |
56 | // An episode ends when the agent is within SUCCESS_DISTANCE of the target or has taken more than MAX_STEPS steps
57 | const isEpisodeDone = (): boolean => distanceToTarget() < SUCCESS_DISTANCE || stepCount > MAX_STEPS;
58 |
59 | // DQN agent: observation = [position, targetPosition]; actions 0/1/2 move the agent by -0.2/0/+0.2 along x
60 | const dqnAgent = new DQNAgent({
61 | inputSize: 2,
62 | actionSize: 3,
63 | hiddenLayers: [32, 32],
64 | gamma: 0.99,
65 | epsilon: 1.0,
66 | epsilonDecay: 0.995,
67 | minEpsilon: 0.01,
68 | lr: 0.001,
69 | batchSize: 32,
70 | memorySize: 1000,
71 | targetUpdateFrequency: 10,
72 | });
73 |
74 | // Warn when no Hugging Face token is configured
75 | const hfToken = import.meta.env?.VITE_HF_TOKEN;
76 | if (!hfToken) {
77 | console.warn('⚠️ VITE_HF_TOKEN non trouvé. Les checkpoints ne seront pas sauvegardés.');
78 | }
79 |
80 | // Environment wiring
81 | const env: IgnitionEnv = new IgnitionEnv({
82 | agent: dqnAgent,
83 | getObservation: () => [position, targetPosition],
84 | applyAction: (action: number | number[]) => {
85 | const a = Array.isArray(action) ? action[0] : action;
86 | const dx = a - 1; // actions 0/1/2 map to dx = -1/0/+1
87 | position += dx * 0.2;
88 | agent.position.x = position;
89 |
90 | // Log the action
91 | console.log(`[ACTION] ${a} (dx: ${dx.toFixed(2)})`);
92 | },
93 | computeReward: () => {
94 | const d = distanceToTarget();
95 |
96 | // Detect whether the agent moved away from the target this step
97 | const isMovingAway = d > previousDistance;
98 | previousDistance = d;
99 |
100 | // Base reward: 1 at the target, shrinking as the distance grows
101 | let reward = 1.0 / (1.0 + d);
102 |
103 | // Penalty for moving away
104 | if (isMovingAway) {
105 | reward -= 0.5;
106 | }
107 |
108 | // Bonus when close
109 | if (d < 0.5) {
110 | reward += 1.0;
111 | }
112 |
113 | return reward;
114 | },
115 | isDone: (): boolean => {
116 | const done = isEpisodeDone();
117 |
118 | if (done) {
119 | console.log(`[DONE] Distance finale: ${distanceToTarget().toFixed(2)}`);
120 | }
121 |
122 | return done;
123 | },
124 | onReset: () => {
125 | position = 0;
126 | targetPosition = (Math.random() - 0.5) * 4;
127 | agent.position.x = position;
128 | target.position.x = targetPosition;
129 | stepCount = 0;
130 | bestDistance = Infinity;
131 | previousDistance = distanceToTarget();
132 |
133 | // Log the reset
134 | console.log(`[RESET] Nouvelle cible: ${targetPosition.toFixed(2)}`);
135 | },
136 | stepIntervalMs: 100,
137 | hfRepoId: 'salim4n/dqn-checkpoint-threejs',
138 | hfToken: hfToken || '',
139 | });
140 |
141 | // Wrap env.step to add progress logging, (simulated) checkpoints and an explicit stop once the episode is done
142 | const originalStep = env.step.bind(env);
143 | // NOTE(review): `action` is not forwarded to the original step; confirm IgnitionEnv.step selects its own action.
144 | env.step = async (action?: number) => {
145 | // Wait for the environment's own step to finish first
146 | const result = await originalStep();
147 | stepCount++;
148 |
149 | const d = distanceToTarget();
150 |
151 | // Periodic progress log
152 | if (stepCount % 10 === 0) {
153 | console.log(`[STEP ${stepCount}] Position: ${position.toFixed(2)}, Cible: ${targetPosition.toFixed(2)}, Distance: ${d.toFixed(2)}`);
154 | }
155 |
156 | // "Best" checkpoint whenever the distance improves (bestDistance is reset at each episode reset)
157 | if (d < bestDistance) {
158 | bestDistance = d;
159 | console.log(`[CHECKPOINT] Nouvelle meilleure distance: ${d.toFixed(4)}`);
160 | console.log(`[CHECKPOINT] Sauvegarde du meilleur modèle...`);
161 | // Saving is disabled in the browser
162 | // await dqnAgent.saveCheckpoint(
163 | // 'salim4n/dqn-checkpoint-threejs',
164 | // hfToken || '',
165 | // 'best'
166 | // );
167 | console.log(`[CHECKPOINT] ✅ Meilleur modèle sauvegardé (simulé)`);
168 | }
169 |
170 | // Regular checkpoint every 100 steps
171 | if (stepCount % 100 === 0) {
172 | console.log(`[CHECKPOINT] Sauvegarde régulière à l'étape ${stepCount}`);
173 | // Saving is disabled in the browser
174 | // await dqnAgent.saveCheckpoint(
175 | // 'salim4n/dqn-checkpoint-threejs',
176 | // hfToken || '',
177 | // `step-${stepCount}`
178 | // );
179 | }
180 |
181 | // Final checkpoint and stop once the episode is over (same rule as the environment's isDone)
182 | if (isEpisodeDone()) {
183 | console.log(`[FINISH] Entraînement terminé à l'étape ${stepCount}!`);
184 | console.log(`[FINISH] Distance finale: ${d.toFixed(2)}`);
185 | // Saving is disabled in the browser
186 | // await dqnAgent.saveCheckpoint(
187 | // 'salim4n/dqn-checkpoint-threejs',
188 | // hfToken || '',
189 | // 'final'
190 | // );
191 | env.stop();
192 | }
193 |
194 | return result;
195 | };
196 |
197 | // Render loop
198 | function animate() {
199 | requestAnimationFrame(animate);
200 | controls.update();
201 | renderer.render(scene, camera);
202 | }
203 |
204 | // Keep the camera and renderer in sync with the window size
205 | window.addEventListener('resize', () => {
206 | camera.aspect = window.innerWidth / window.innerHeight;
207 | camera.updateProjectionMatrix();
208 | renderer.setSize(window.innerWidth, window.innerHeight);
209 | });
210 |
211 | // Start rendering and the environment
212 | console.log('[START] Démarrage de la visualisation...');
213 | animate();
214 | env.start();
--------------------------------------------------------------------------------
/r3f/target-chasing/src/vite-env.d.ts:
--------------------------------------------------------------------------------
1 | ///
2 |
--------------------------------------------------------------------------------
/r3f/target-chasing/tsconfig.app.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
4 | "target": "ES2020",
5 | "useDefineForClassFields": true,
6 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
7 | "module": "ESNext",
8 | "skipLibCheck": true,
9 |
10 | /* Bundler mode */
11 | "moduleResolution": "bundler",
12 | "allowImportingTsExtensions": true,
13 | "isolatedModules": true,
14 | "moduleDetection": "force",
15 | "noEmit": true,
16 | "jsx": "react-jsx",
17 |
18 | /* Linting */
19 | "strict": true,
20 | "noUnusedLocals": true,
21 | "noUnusedParameters": true,
22 | "noFallthroughCasesInSwitch": true,
23 | "noUncheckedSideEffectImports": true
24 | },
25 | "include": ["src"]
26 | }
27 |
--------------------------------------------------------------------------------
/r3f/target-chasing/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "jsx": "react-jsx",
4 | "baseUrl": ".",
5 | "noUnusedLocals": false,
6 | "noUnusedParameters": false,
7 | "paths": {
8 | "@ignitionai/core": ["../../packages/core/src"],
9 | "@ignitionai/backend-tfjs": ["../../packages/backend-tfjs/src"]
10 | },
11 | "skipLibCheck": true,
12 | "typeRoots": ["./node_modules/@types", "../../node_modules/@types"]
13 | },
14 | "include": ["src"],
15 | "extends": "./tsconfig.app.json"
16 | }
17 |
--------------------------------------------------------------------------------
/r3f/target-chasing/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
4 | "target": "ES2022",
5 | "lib": ["ES2023"],
6 | "module": "ESNext",
7 | "skipLibCheck": true,
8 |
9 | /* Bundler mode */
10 | "moduleResolution": "bundler",
11 | "allowImportingTsExtensions": true,
12 | "isolatedModules": true,
13 | "moduleDetection": "force",
14 | "noEmit": true,
15 |
16 | /* Linting */
17 | "strict": true,
18 | "noUnusedLocals": true,
19 | "noUnusedParameters": true,
20 | "noFallthroughCasesInSwitch": true,
21 | "noUncheckedSideEffectImports": true
22 | },
23 | "include": ["vite.config.ts"]
24 | }
25 |
--------------------------------------------------------------------------------
/r3f/target-chasing/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vite'
2 | import react from '@vitejs/plugin-react'
3 | import path from 'path'
4 |
5 | // Vite config reference: https://vite.dev/config/
6 | // The workspace packages are aliased to their TypeScript `src` directories, resolved from one shared root.
7 | const packagesDir = path.resolve(__dirname, '../../packages')
8 |
9 | export default defineConfig({
10 | plugins: [react()],
11 | resolve: {
12 | alias: {
13 | '@ignitionai/backend-tfjs': path.resolve(packagesDir, 'backend-tfjs/src'),
14 | '@ignitionai/core': path.resolve(packagesDir, 'core/src'),
15 | },
16 | },
17 | })
18 |
--------------------------------------------------------------------------------
/roadmap.md:
--------------------------------------------------------------------------------
1 | # 🧭 IgnitionAI - Project Roadmap
2 |
3 | This document outlines the phased development of **IgnitionAI** — a modular, browser-friendly framework for intelligent agent simulation and reinforcement learning.
4 |
5 | ---
6 |
7 | ## ✅ Phase 1 — Core Logic (MVP)
8 |
9 | > ⚙️ Goal: Run agent-environment logic headlessly (no UI)
10 |
11 | ✅ Roadmap: "RL algorithms first"
12 | Phase A — @ignitionai/backend-tfjs only
13 | Implementing classic algorithms with TensorFlow.js
14 |
15 | 1. 🔁 Q-learning (tabular) – a minimal JS version
16 | that uses no neural network
17 | ✅ Implemented Q-Table agent with state/action lookup
18 | ✅ Added tests for basic functionality
19 |
20 | 2. 🧠 DQN – Deep Q-Network
21 | ✅ Implemented a simple MLP (input → hidden → output)
22 | ✅ Added replay buffer with experience sampling
23 | ✅ Implemented target network with periodic updates
24 | ✅ Added epsilon-greedy exploration/exploitation
25 | ✅ Loss function based on TD error
26 | ✅ Unit tests with training validation
27 |
28 | 3. 🧘‍♂️ PPO – Policy Gradient
29 | ✅ Created initial PPO agent skeleton
30 | - [ ] Implement Actor-Critic model
31 | - [ ] Implement episode-based training
32 | - [ ] Add policy and value loss functions
33 |
34 | ---
35 |
36 | ## ✅ Phase 1.5 — Backend Infrastructure
37 |
38 | > 🧰 Goal: Create robust, multi-environment backend support
39 |
40 | ✅ Created modular monorepo structure
41 | ✅ Implemented robust backend selection system
42 | ✅ Added support for all major TensorFlow.js backends:
43 | - WebGPU (experimental)
44 | - WebGL
45 | - CPU
46 | - WASM
47 | ✅ Added helper utilities for backend detection and info
48 | ✅ Added comprehensive model management system:
49 | - IndexedDB local storage
50 | - Hugging Face Hub integration with authentication
51 | - Automatic model serialization/deserialization
52 | - Checkpoint system with:
53 | - Regular checkpoints (step-based)
54 | - Best model checkpoints
55 | - Automatic retry with exponential backoff
56 | - Model versioning and metadata
57 | ✅ Added robust error handling and logging
58 | ✅ Comprehensive unit tests and integration tests
59 |
60 | ---
61 |
62 | ## 🚀 Phase 2 — R3F (React Three Fiber) Visualization & Basic UI
63 |
64 | > 🎮 Goal: Visualize the agent/environment and provide basic interaction
65 |
66 | ✅ `@ignitionai/r3f`: add `AgentMesh`, `TargetMesh`, `useAgent`
67 | ✅ `@ignitionai/demo-target-chasing`: set up a Vite + R3F scene
68 | ✅ Added training monitoring and auto-stop functionality
69 | ✅ Added step count and reward display to the UI
70 | ✅ Implemented real-time model updates
71 | ✅ Added basic training controls (Start/Stop/Reset)
72 | ✅ Added real-time visualization charts (Reward, Loss, Epsilon) using Recharts
73 | - [ ] Add more advanced visualization (e.g., network graph, Q-values)
74 | - [ ] Optimize performance for longer training sessions
75 | - [ ] Add ability to save/load models from the UI
76 |
77 | ---
78 |
79 | ## ✅ Phase 3 — TFJS Backend & Dynamic Configuration
80 |
81 | > 🧠 Goal: Train and run a model directly in the browser with user configuration
82 |
83 | ✅ `@ignitionai/backend-tfjs`: built a simple MLP model with configurable layers
84 | ✅ Implemented `train()` and `predict()` APIs via DQN agent
85 | ✅ Added model serialization with `save()` and `load()`
86 | ✅ Added support for Hugging Face Hub integration
87 | ✅ Created streamlined `Agent` class interface
88 | ✅ Added comprehensive training utilities:
89 | - Progress tracking
90 | - Performance metrics
91 | - Model checkpointing
92 | - Training visualization (3D scene + charts)
93 | ✅ Implemented browser-based training with Three.js visualization
94 | ✅ Added automatic checkpoint saving for best models
95 | ✅ Refactored demo to accept dynamic agent configuration from UI
96 | ✅ Added UI Panel for Agent Hyperparameter Configuration (Learning Rate, Epsilon, Gamma, etc.)
97 | ✅ Added a basic drag-and-drop Network Designer (React Flow) for visual representation (note: currently visual only; the config panel drives the actual layer structure)
98 | - [ ] Fully integrate Network Designer to drive agent creation
99 | - [ ] Add support for loading pre-trained weights via UI
100 |
101 | ---
102 |
103 | ## 🚀 Phase 4 — ONNX Runtime Backend (Inference-only)
104 |
105 | > ⚡ Goal: Run optimized pre-trained models in production
106 |
107 | ✅ Created initial package structure for ONNX backend
108 | - [ ] Implement ONNX Runtime Web integration
109 | - [ ] Add `.onnx` model loading and inference
110 | - [ ] Create `InferenceBackend` wrapper
111 | - [ ] Add model conversion utilities (TFJS → ONNX)
112 |
113 | ---
114 |
115 | ## 🚀 Phase 5 — Advanced Environments
116 |
117 | > 🌍 Goal: Create more complex environments for agent training
118 |
119 | - [ ] Implement grid-based environments (maze, pathfinding)
120 | - [ ] Add physics-based environments (pendulum, cartpole)
121 | - [ ] Create multi-agent environments
122 | - [ ] Add environment customization tools
123 | - [ ] Implement environment visualization tools
124 |
125 | ---
126 |
127 | ## 🚀 Phase 6 — Advanced Algorithms
128 |
129 | > 🧠 Goal: Implement more sophisticated RL algorithms
130 |
131 | - [ ] Implement DDPG (Deep Deterministic Policy Gradient)
132 | - [ ] Add SAC (Soft Actor-Critic)
133 | - [ ] Implement A2C (Advantage Actor-Critic)
134 | - [ ] Add support for custom algorithm implementations
135 | - [ ] Create algorithm comparison tools
136 |
137 | ---
138 |
139 | ## 🚀 Phase 7 — Deployment & Production
140 |
141 | > 🚢 Goal: Make the framework production-ready
142 |
143 | - [ ] Add comprehensive documentation
144 | - [ ] Create example applications
145 | - [ ] Implement CI/CD pipeline
146 | - [ ] Add performance optimization tools
147 | - [ ] Create deployment guides
148 | - [ ] Add monitoring and analytics
149 |
150 | ---
151 |
152 | Built with ❤️ by Salim (@IgnitionAI)
153 |
154 |
--------------------------------------------------------------------------------
/tsconfig.base.json:
--------------------------------------------------------------------------------
1 | { // Shared compiler options; package tsconfigs presumably "extends" this file — confirm.
2 | "compilerOptions": {
3 | "target": "ESNext",
4 | "module": "ESNext",
5 | "moduleResolution": "Node", // Classic Node (node10) resolution; it does not read package.json "exports" maps.
6 | "declaration": true, // Emit .d.ts files alongside the JS output.
7 | "outDir": "dist", // NOTE(review): relative paths in an extended config resolve against THIS file, so packages that don't override outDir all emit to <repo root>/dist.
8 | "strict": true,
9 | "esModuleInterop": true,
10 | "skipLibCheck": true,
11 | "jsx": "react-jsx" // Automatic JSX runtime: .tsx files need no explicit React import.
12 | },
13 | "include": ["packages"] // Inherited by extending configs that don't set their own include; resolved relative to this file.
14 | }
15 |
--------------------------------------------------------------------------------
/vitest.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vitest/config';
2 |
3 | // Vitest configuration used when tests are run from the repository root.
4 | export default defineConfig({
5 | test: {
6 | // Expose describe/it/expect as globals so test files need not import them.
7 | globals: true,
8 | // Run tests in plain Node.js (no DOM emulation).
9 | environment: 'node',
10 | // A leading '**/' matches any directory depth, so this single glob also covers
11 | // files under test/ folders; a separate '**/test/**' pattern would be redundant.
12 | // Vitest's default `exclude` (node_modules, dist, ...) still applies because it is not overridden.
13 | include: ['**/*.{test,spec}.{js,mjs,cjs,ts,mts,cts,jsx,tsx}'],
14 | },
15 | });
--------------------------------------------------------------------------------