├── LICENSE ├── README.md ├── README_zh.md ├── TUTORIAL.md ├── TUTORIAL_zh.md ├── config ├── MAPPOConfig.m ├── PPOConfig.m ├── default_acmotor_config.m ├── default_cartpole_config.m ├── default_dcmotor_config.m └── default_doublependulum_config.m ├── core ├── ContinuousActorNetwork.m ├── CriticNetwork.m ├── DiscreteActorNetwork.m ├── MAPPOAgent.m └── PPOAgent.m ├── environments ├── ACMotorEnv.m ├── CartPoleEnv.m ├── DCMotorEnv.m ├── DoublePendulumEnv.m └── Environment.m ├── examples ├── test_acmotor.m ├── test_cartpole.m ├── test_dcmotor.m ├── test_doublependulum.m ├── train_acmotor.m ├── train_cartpole.m ├── train_dcmotor.m └── train_doublependulum.m └── utils └── Logger.m /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 X-Embodied 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 Matlab PPO Reinforcement Learning Framework 2 | 3 | [![Matlab_PPO](https://img.shields.io/badge/Matlab_PPO-v1.0.0-blueviolet)](https://github.com/AIResearcherHZ/Matlab_PPO) 4 | [![MATLAB](https://img.shields.io/badge/MATLAB-R2019b%2B-blue.svg)](https://www.mathworks.com/products/matlab.html) 5 | [![Deep Learning Toolbox](https://img.shields.io/badge/Deep%20Learning%20Toolbox-Required-green.svg)](https://www.mathworks.com/products/deep-learning.html) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 7 | 8 | A MATLAB-based reinforcement learning framework featuring Proximal Policy Optimization (PPO) algorithm and its multi-agent extension (MAPPO), with GPU acceleration and parallel computing support, suitable for control system research and engineering applications. 
9 | 10 | > 📚 For detailed algorithm principles, implementation details, and advanced features, please refer to [TUTORIAL.md](TUTORIAL.md) 11 | 12 | ## ✨ Inspiration 13 | 14 | This project was inspired by several core principles: 15 | 16 | - 🎯 Provide an efficient and user-friendly reinforcement learning framework for control system research 17 | - 🔄 Apply advanced PPO algorithms to practical engineering control problems 18 | - 🤝 Promote multi-agent system applications in industrial control 19 | - 📊 Offer intuitive performance evaluation and visualization tools 20 | 21 | ## 🚀 Tech Stack 22 | 23 | ### Core Framework 24 | - **MATLAB R2019b+**: Primary development environment 25 | - **Deep Learning Toolbox**: Neural network construction 26 | - **Parallel Computing Toolbox**: Parallel data collection 27 | - **GPU Computing**: CUDA acceleration support 28 | 29 | ### Algorithm Implementation 30 | - **PPO**: Clip-based policy optimization 31 | - **MAPPO**: Multi-agent collaborative learning 32 | - **Actor-Critic**: Dual network architecture 33 | - **GAE**: Generalized Advantage Estimation 34 | 35 | ### Environment Models 36 | - **Classic Control**: Cart-pole system 37 | - **Motor Control**: DC/AC motor systems 38 | - **Multi-agent**: Double pendulum system 39 | 40 | ## 🌟 Key Features 41 | 42 | ### PPO (Proximal Policy Optimization) 43 | - Policy gradient-based reinforcement learning algorithm 44 | - Uses clipped objective function to limit policy update magnitude 45 | - Supports continuous and discrete action spaces 46 | - Provides stable and efficient training process 47 | 48 | ### MAPPO (Multi-Agent PPO) 49 | - Multi-agent extension of PPO 50 | - Adopts "Centralized Training with Decentralized Execution" architecture 51 | - Provides solutions for tasks requiring agent cooperation 52 | 53 | ## 🎮 Supported Environments 54 | 55 | The framework includes four control scenarios: 56 | 57 | 1. **CartPole** 58 | - Classic control problem 59 | - Balance a vertical pole by moving the cart left or right 60 | - Discrete action space example 61 | 62 | 2. **DC Motor Control** 63 | - Adjust motor speed to track target velocity 64 | - Continuous action space example 65 | - Includes electrical and mechanical system dynamics 66 | 67 | 3. **AC Motor FOC Control** 68 | - Implement Field-Oriented Control for AC induction motors 69 | - Advanced industrial control example 70 | - Involves complex coordinate transformations and dynamics 71 | 72 | 4. **Double Pendulum System** 73 | - Multi-agent cooperative control example 74 | - Demonstrates MAPPO algorithm advantages 75 | - Two connected pendulums requiring coordinated control 76 | 77 | ## 🚀 Quick Start 78 | 79 | ### Requirements 80 | - MATLAB R2019b or higher 81 | - Deep Learning Toolbox 82 | - Parallel Computing Toolbox (optional, for parallel data collection) 83 | - GPU support (optional but recommended for training acceleration) 84 | 85 | ### Installation Steps 86 | 87 | 1. **Download the Project** 88 | ```bash 89 | git clone https://github.com/AIResearcherHZ/Matlab_PPO.git 90 | cd Matlab_PPO 91 | ``` 92 | 93 | ### Usage Steps 94 | 95 | #### 1. Run Examples 96 | ```matlab 97 | % Run CartPole training example 98 | train_cartpole 99 | 100 | % Run DC motor control example 101 | train_dcmotor 102 | 103 | % Run AC motor FOC control example 104 | train_acmotor 105 | 106 | % Run double pendulum MAPPO example 107 | train_doublependulum 108 | ``` 109 | 110 | #### 2. 
Test Training Results 111 | ```matlab 112 | % Test CartPole 113 | test_cartpole 114 | 115 | % Test DC motor control 116 | test_dcmotor 117 | 118 | % Test AC motor FOC control 119 | test_acmotor 120 | 121 | % Test double pendulum system 122 | test_doublependulum 123 | ``` 124 | 125 | #### 2. Custom Configuration 126 | 127 | ```matlab 128 | % Create and modify configuration 129 | config = PPOConfig(); 130 | config.gamma = 0.99; % Discount factor 131 | config.epsilon = 0.2; % Clipping parameter 132 | config.actorLearningRate = 3e-4; % Policy network learning rate 133 | config.useGPU = true; % Enable GPU acceleration 134 | 135 | % Train with custom configuration 136 | agent = PPOAgent(env, config); 137 | agent.train(); 138 | ``` 139 | 140 | #### 3. Save and Load Models 141 | 142 | ```matlab 143 | % Save trained model 144 | agent.save('my_trained_model.mat'); 145 | 146 | % Load model 147 | agent = PPOAgent.load('my_trained_model.mat', env); 148 | ``` 149 | 150 | ## 📁 Directory Structure 151 | 152 | ``` 153 | Matlab_PPO/ 154 | ├── core/ # Core algorithm implementation 155 | ├── environments/ # Environment implementation 156 | ├── config/ # Configuration files 157 | ├── utils/ # Utility functions 158 | ├── examples/ # Example scripts 159 | └── logs/ # Logs and model save directory 160 | ``` 161 | 162 | ## 📚 Documentation 163 | 164 | - [TUTORIAL.md](TUTORIAL.md) - Detailed algorithm tutorial and implementation details 165 | - API documentation available through MATLAB's `help` command, e.g.: 166 | ```matlab 167 | help PPOAgent 168 | help Environment 169 | help DCMotorEnv 170 | ``` 171 | 172 | ## 🔮 Future Plans 173 | 174 | 1. **Algorithm Enhancement** 175 | - Implement more advanced PPO variants 176 | - Add other popular reinforcement learning algorithms 177 | - Optimize multi-agent training strategies 178 | 179 | 2. **Environment Extension** 180 | - Add more industrial control scenarios 181 | - Support custom environment interfaces 182 | - Add simulation environment visualization 183 | 184 | 3. **Performance Optimization** 185 | - Further improve GPU utilization 186 | - Optimize parallel training mechanisms 187 | - Enhance data collection efficiency 188 | 189 | ## 🤝 Contributing 190 | 191 | We welcome improvements and code contributions through Issues and Pull Requests. We especially welcome contributions in: 192 | 193 | - 🐛 Bug fixes and issue reports 194 | - ✨ New features and improvements 195 | - 📝 Documentation improvements and translations 196 | - 🎯 New application scenario examples 197 | 198 | ## 📖 Citation 199 | 200 | If you use this framework in your research, please cite: 201 | 202 | ```bibtex 203 | @misc{matlab_ppo, 204 | author = {}, 205 | title = {Matlab PPO: A Reinforcement Learning Framework for Control Systems}, 206 | year = {2025}, 207 | publisher = {GitHub}, 208 | url = {https://github.com/X-Embodied/Matlab_PPO} 209 | } 210 | ``` 211 | 212 | ## 📄 License 213 | 214 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file 215 | 216 | ## 🙏 Acknowledgments 217 | 218 | Thanks to all researchers and developers who have contributed to this project! 
219 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # 🤖 Matlab PPO 强化学习框架 2 | 3 | [![Matlab_PPO](https://img.shields.io/badge/Matlab_PPO-v1.0.0-blueviolet)](https://github.com/AIResearcherHZ/Matlab_PPO) 4 | [![MATLAB](https://img.shields.io/badge/MATLAB-R2019b%2B-blue.svg)](https://www.mathworks.com/products/matlab.html) 5 | [![Deep Learning Toolbox](https://img.shields.io/badge/Deep%20Learning%20Toolbox-Required-green.svg)](https://www.mathworks.com/products/deep-learning.html) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 7 | 8 | 一个基于MATLAB的强化学习框架,实现了近端策略优化(PPO)算法及其多智能体扩展(MAPPO),支持GPU加速和并行计算,适用于控制系统研究和工程应用。 9 | 10 | > 📚 有关算法原理、实现细节和高级功能的详细说明,请参阅 [TUTORIAL_zh.md](TUTORIAL_zh.md) 11 | 12 | ## ✨ 项目理念 13 | 14 | 本项目基于以下核心理念: 15 | 16 | - 🎯 为控制系统研究提供高效且用户友好的强化学习框架 17 | - 🔄 将先进的PPO算法应用于实际工程控制问题 18 | - 🤝 推广多智能体系统在工业控制中的应用 19 | - 📊 提供直观的性能评估和可视化工具 20 | 21 | ## 🚀 技术栈 22 | 23 | ### 核心框架 24 | - **MATLAB R2019b+**: 主要开发环境 25 | - **Deep Learning Toolbox**: 神经网络构建 26 | - **Parallel Computing Toolbox**: 并行数据收集 27 | - **GPU Computing**: CUDA加速支持 28 | 29 | ### 算法实现 30 | - **PPO**: 基于裁剪的策略优化 31 | - **MAPPO**: 多智能体协同学习 32 | - **Actor-Critic**: 双网络架构 33 | - **GAE**: 广义优势估计 34 | 35 | ### 环境模型 36 | - **经典控制**: 倒立摆系统 37 | - **电机控制**: 直流/交流电机系统 38 | - **多智能体**: 双摆系统 39 | 40 | ## 🌟 主要特性 41 | 42 | ### PPO (近端策略优化) 43 | - 基于策略梯度的强化学习算法 44 | - 使用裁剪目标函数限制策略更新幅度 45 | - 支持连续和离散动作空间 46 | - 提供稳定高效的训练过程 47 | 48 | ### MAPPO (多智能体PPO) 49 | - PPO的多智能体扩展 50 | - 采用"集中训练,分散执行"架构 51 | - 为需要智能体协作的任务提供解决方案 52 | 53 | ## 🎮 支持的环境 54 | 55 | 框架包含四个控制场景: 56 | 57 | 1. **倒立摆** 58 | - 经典控制问题 59 | - 通过左右移动小车平衡垂直杆 60 | - 离散动作空间示例 61 | 62 | 2. **直流电机控制** 63 | - 调节电机速度以跟踪目标速度 64 | - 连续动作空间示例 65 | - 包含电气和机械系统动态 66 | 67 | 3. **交流电机FOC控制** 68 | - 实现交流感应电机的磁场定向控制 69 | - 高级工业控制示例 70 | - 涉及复杂坐标变换和动态 71 | 72 | 4. **双摆系统** 73 | - 多智能体协同控制示例 74 | - 展示MAPPO算法优势 75 | - 需要协调控制的两个连接摆 76 | 77 | ## 🚀 快速入门 78 | 79 | ### 安装要求 80 | - MATLAB R2019b 或更高版本 81 | - Deep Learning Toolbox 82 | - Parallel Computing Toolbox (可选,用于并行数据收集) 83 | - GPU 支持 (可选,但推荐用于加速训练) 84 | 85 | ### 安装步骤 86 | 87 | 1. **下载项目** 88 | ```bash 89 | git clone https://github.com/AIResearcherHZ/Matlab_PPO.git 90 | cd Matlab_PPO 91 | ``` 92 | 93 | ### 使用步骤 94 | 95 | #### 1. 运行示例 96 | ```matlab 97 | % 运行倒立摆训练示例 98 | train_cartpole 99 | 100 | % 运行直流电机控制示例 101 | train_dcmotor 102 | 103 | % 运行交流电机FOC控制示例 104 | train_acmotor 105 | 106 | % 运行双摆系统MAPPO示例 107 | train_doublependulum 108 | ``` 109 | 110 | #### 2. 测试训练结果 111 | ```matlab 112 | % 测试倒立摆 113 | test_cartpole 114 | 115 | % 测试直流电机控制 116 | test_dcmotor 117 | 118 | % 测试交流电机FOC控制 119 | test_acmotor 120 | 121 | % 测试双摆系统 122 | test_doublependulum 123 | ``` 124 | 125 | #### 2. 自定义配置 126 | 127 | ```matlab 128 | % 创建和修改配置 129 | config = PPOConfig(); 130 | config.gamma = 0.99; % 折扣因子 131 | config.epsilon = 0.2; % 裁剪参数 132 | config.actorLearningRate = 3e-4; % 策略网络学习率 133 | config.useGPU = true; % 启用GPU加速 134 | 135 | % 使用自定义配置训练 136 | agent = PPOAgent(env, config); 137 | agent.train(); 138 | ``` 139 | 140 | #### 3. 
保存和加载模型 141 | 142 | ```matlab 143 | % 保存训练好的模型 144 | agent.save('my_trained_model.mat'); 145 | 146 | % 加载模型 147 | agent = PPOAgent.load('my_trained_model.mat', env); 148 | ``` 149 | 150 | ## 📁 目录结构 151 | 152 | ``` 153 | Matlab_PPO/ 154 | ├── core/ # 核心算法实现 155 | ├── environments/ # 环境实现 156 | ├── config/ # 配置文件 157 | ├── utils/ # 工具函数 158 | ├── examples/ # 示例脚本 159 | └── logs/ # 日志和模型保存目录 160 | ``` 161 | 162 | ## 📚 文档 163 | 164 | - [TUTORIAL.md](TUTORIAL.md) - 详细的算法教程和实现细节 165 | - API文档可通过MATLAB的`help`命令获取,例如: 166 | ```matlab 167 | help PPOAgent 168 | help Environment 169 | help DCMotorEnv 170 | ``` 171 | 172 | ## 🔮 未来规划 173 | 174 | 1. **算法增强** 175 | - 实现更多先进的PPO变体 176 | - 添加其他流行的强化学习算法 177 | - 优化多智能体训练策略 178 | 179 | 2. **环境扩展** 180 | - 增加更多工业控制场景 181 | - 支持自定义环境接口 182 | - 添加仿真环境可视化 183 | 184 | 3. **性能优化** 185 | - 进一步提升GPU利用率 186 | - 优化并行训练机制 187 | - 改进数据收集效率 188 | 189 | ## 🤝 贡献 190 | 191 | 欢迎通过Issue和Pull Request提交改进建议和贡献代码。我们特别欢迎以下方面的贡献: 192 | 193 | - 🐛 Bug修复和问题报告 194 | - ✨ 新功能和改进建议 195 | - 📝 文档完善和翻译 196 | - 🎯 新的应用场景示例 197 | 198 | ## 📖 引用 199 | 200 | 如果您在研究中使用了本框架,请引用以下文献: 201 | 202 | ```bibtex 203 | @misc{matlab_ppo, 204 | author = {}, 205 | title = {Matlab PPO: A Reinforcement Learning Framework for Control Systems}, 206 | year = {2025}, 207 | publisher = {GitHub}, 208 | url = {https://github.com/X-Embodied/Matlab_PPO} 209 | } 210 | ``` 211 | 212 | ## 📄 许可证 213 | 214 | 本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件 215 | 216 | ## 🙏 致谢 217 | 218 | 感谢所有为本项目做出贡献的研究者和开发者! 219 | -------------------------------------------------------------------------------- /TUTORIAL.md: -------------------------------------------------------------------------------- 1 | # Matlab PPO Reinforcement Learning Framework Detailed Tutorial 2 | 3 | > For Chinese version, please refer to [TUTORIAL_zh.md](TUTORIAL_zh.md) 4 | 5 | # Matlab PPO Reinforcement Learning Framework Detailed Tutorial 6 | 7 | This tutorial provides an in-depth explanation of the Matlab PPO reinforcement learning framework, including algorithm principles, implementation details, environment model descriptions and extension guidelines. 8 | 9 | ## Table of Contents 10 | 11 | 1. [Algorithm Details](#algorithm-details) 12 | - [PPO Algorithm Principles](#ppo-algorithm-principles) 13 | - [MAPPO Algorithm Extension](#mappo-algorithm-extension) 14 | 2. [Environment Models](#environment-models) 15 | - [CartPole](#cartpole) 16 | - [DC Motor Control](#dc-motor-control) 17 | - [AC Induction Motor FOC Control](#ac-induction-motor-foc-control) 18 | - [Double Pendulum System](#double-pendulum-system) 19 | 3. [Framework Implementation Details](#framework-implementation-details) 20 | - [Core Classes](#core-classes) 21 | - [Network Architecture](#network-architecture) 22 | - [Training Process](#training-process) 23 | 4. [Advanced Application Guide](#advanced-application-guide) 24 | - [Hyperparameter Tuning](#hyperparameter-tuning) 25 | - [Custom Environment Development](#custom-environment-development) 26 | - [Multi-Agent System Design](#multi-agent-system-design) 27 | 5. [Performance Optimization and Debugging](#performance-optimization-and-debugging) 28 | - [GPU Acceleration](#gpu-acceleration) 29 | - [Parallel Computing](#parallel-computing) 30 | - [Troubleshooting](#troubleshooting) 31 | 32 | ## Algorithm Details 33 | 34 | ### PPO Algorithm Principles 35 | 36 | PPO (Proximal Policy Optimization) is a policy gradient method that balances sample efficiency and implementation simplicity. 
It introduces a clipping term during policy updates to limit the difference between old and new policies, preventing excessively large policy updates. 37 | 38 | #### Core Mathematical Principles 39 | 40 | The PPO objective function is: 41 | 42 | $$L^{CLIP}(\theta) = \hat{E}_t\left[\min(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t)\right]$$ 43 | 44 | Where: 45 | - $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ represents the probability ratio between old and new policies 46 | - $\hat{A}_t$ is the advantage function estimate 47 | - $\epsilon$ is the clipping parameter, typically set to 0.2 48 | 49 | #### Algorithm Steps 50 | 51 | 1. **Sampling**: Collect a batch of trajectory data using current policy $\pi_{\theta_{old}}$ 52 | 2. **Advantage Estimation**: Calculate advantage estimates $\hat{A}_t$ for each state-action pair 53 | 3. **Policy Update**: Maximize the clipped objective through multiple mini-batch gradient ascent steps 54 | 4. **Value Function Update**: Update the value function to better estimate returns 55 | 56 | #### Generalized Advantage Estimation (GAE) 57 | 58 | PPO typically uses GAE to calculate the advantage function, balancing bias and variance: 59 | 60 | $$\hat{A}_t = \delta_t + (\gamma\lambda)\delta_{t+1} + ... + (\gamma\lambda)^{T-t+1}\delta_{T-1}$$ 61 | 62 | Where: 63 | - $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the temporal difference error 64 | - $\gamma$ is the discount factor 65 | - $\lambda$ is the GAE parameter, controlling the bias-variance trade-off 66 | 67 | ### MAPPO Algorithm Extension 68 | 69 | MAPPO (Multi-Agent PPO) is an extension of PPO for multi-agent environments. It adopts the CTDE (Centralized Training with Decentralized Execution) framework, allowing the use of global information during training while only using local observations during execution. 70 | 71 | #### Core Design 72 | 73 | 1. **Independent Policy Networks**: Each agent $i$ has its own policy network $\pi_i(a_i|o_i)$, making decisions based on local observations $o_i$ 74 | 2. **Centralized Value Network**: Uses a centralized critic network $V(s)$ to evaluate global state value 75 | 3. **Coordinated Training**: Considers interactions between agents to optimize joint policy returns 76 | 77 | #### Differences from Single-Agent PPO 78 | 79 | - **Observation-State Separation**: Each agent only receives local observations, but the value network can use global state 80 | - **Credit Assignment**: Needs to solve the credit assignment problem in multi-agent settings, determining each agent's contribution to global returns 81 | - **Collaborative Policy**: Learns cooperative strategies through implicit coordination between agents 82 | 83 | #### Application Scenarios 84 | 85 | MAPPO is particularly suitable for: 86 | - Naturally distributed control problems (e.g., multi-joint robots) 87 | - Tasks requiring cooperation (e.g., multi-agent collaborative control) 88 | - Tasks requiring specialized agent roles (e.g., controllers with different functions) 89 | 90 | ## Environment Models 91 | 92 | ### CartPole 93 | 94 | CartPole is a classic control problem and a standard test environment for reinforcement learning beginners. 95 | 96 | #### Physical Model 97 | 98 | The CartPole system consists of a horizontally movable cart and a rigid pendulum connected to the cart. 
The system dynamics can be described by the following differential equations: 99 | 100 | $$\ddot{x} = \frac{F + m l \sin\theta (\dot{\theta})^2 - m g \cos\theta \sin\theta}{M + m \sin^2\theta}$$ 101 | 102 | $$\ddot{\theta} = \frac{g \sin\theta - \ddot{x}\cos\theta}{l}$$ 103 | 104 | Where: 105 | - $x$ is the cart position 106 | - $\theta$ is the pole angle (relative to vertical upward direction) 107 | - $F$ is the force applied to the cart 108 | - $M$ is the cart mass 109 | - $m$ is the pole mass 110 | - $l$ is the pole half-length 111 | - $g$ is the gravitational acceleration 112 | 113 | #### Reinforcement Learning Setup 114 | 115 | - **State Space**: $[x, \dot{x}, \theta, \dot{\theta}]$ 116 | - **Action Space**: Discrete actions {move left, move right} 117 | - **Reward**: +1 for each timestep until termination 118 | - **Termination Conditions**: Pole angle exceeds 15 degrees or cart position deviates more than 2.4 units from center 119 | 120 | In our implementation, `CartPoleEnv.m` provides complete environment simulation including dynamics update, reward calculation and visualization functions. 121 | 122 | ### DC Motor Control 123 | 124 | DC motor control is a fundamental industrial control problem involving dynamic models of electrical and mechanical systems. 125 | 126 | #### Motor Model 127 | 128 | The dynamic equations of a DC motor are as follows: 129 | 130 | **Electrical Equation**: 131 | $$L\frac{di}{dt} = v - Ri - K_e\omega$$ 132 | 133 | **Mechanical Equation**: 134 | $$J\frac{d\omega}{dt} = K_ti - B\omega - T_L$$ 135 | 136 | Where: 137 | - $v$ is the applied voltage 138 | - $i$ is the motor current 139 | - $\omega$ is the motor angular velocity 140 | - $T_L$ is the load torque 141 | - $L$ is the inductance 142 | - $R$ is the resistance 143 | - $K_t$ is the torque constant 144 | - $K_e$ is the back EMF constant 145 | - $J$ is the moment of inertia 146 | - $B$ is the friction coefficient 147 | 148 | #### Reinforcement Learning Setup 149 | 150 | - **State Space**: $[\omega, i, \omega_{ref}, \omega_{ref} - \omega]$ 151 | - **Action Space**: Continuous actions representing applied voltage $v \in [-V_{max}, V_{max}]$ 152 | - **Reward**: Combines speed tracking error, control signal magnitude and energy consumption 153 | - **Termination Conditions**: Maximum steps reached or excessive speed error 154 | 155 | `DCMotorEnv.m` implements this environment model, including discretized dynamic equation solving, step response testing and load disturbance simulation. 156 | 157 | ### AC Induction Motor FOC Control 158 | 159 | AC induction motor control is a more complex problem. This framework implements learning strategies for AC motor control based on Field-Oriented Control (FOC). 160 | 161 | #### FOC Principles 162 | 163 | FOC transforms AC motor control into a problem similar to DC motor control through coordinate transformation: 164 | 165 | 1. **Three-phase to Two-phase Transformation**: Converts three-phase currents/voltages ($i_a$, $i_b$, $i_c$) to two-phase stationary coordinate system ($i_\alpha$, $i_\beta$) 166 | 2. **Stationary to Rotating Coordinates**: Transforms stationary coordinates to rotating coordinates ($i_d$, $i_q$) synchronized with rotor magnetic field 167 | 3. 
**Control in d-q Coordinates**: In the rotating coordinate system, $i_d$ controls flux and $i_q$ controls torque 168 | 169 | #### AC Motor Model 170 | 171 | In the d-q rotating coordinate system, the voltage equations of the induction motor are: 172 | 173 | $$v_d = R_si_d + \frac{d\lambda_d}{dt} - \omega_e\lambda_q$$ 174 | $$v_q = R_si_q + \frac{d\lambda_q}{dt} + \omega_e\lambda_d$$ 175 | 176 | Flux linkage equations: 177 | $$\lambda_d = L_di_d + L_mi_{dr}$$ 178 | $$\lambda_q = L_qi_q + L_mi_{qr}$$ 179 | 180 | #### Reinforcement Learning Setup 181 | 182 | - **State Space**: $[\omega_r, i_d, i_q, \omega_{ref}, \lambda_d, \lambda_q]$ 183 | - **Action Space**: Continuous actions representing d-q axis voltages $v_d$ and $v_q$ 184 | - **Reward**: Combines speed tracking error, flux linkage stability, current limits and energy efficiency 185 | 186 | ### Double Pendulum System 187 | 188 | The double pendulum system is a classic multi-agent collaboration problem, suitable for solving with MAPPO algorithm. 189 | 190 | #### System Model 191 | 192 | The double pendulum system consists of two connected rigid pendulum rods, each controlled by an independent actuator. The system's Lagrangian equation is: 193 | 194 | $$M(q)\ddot{q} + C(q,\dot{q})\dot{q} + G(q) = \tau$$ 195 | 196 | Where: 197 | - $q = [\theta_1, \theta_2]^T$ is the pendulum angle vector 198 | - $M(q)$ is the mass matrix 199 | - $C(q,\dot{q})$ represents Coriolis and centrifugal terms 200 | - $G(q)$ is the gravity term 201 | - $\tau = [\tau_1, \tau_2]^T$ is the torque applied by actuators 202 | 203 | #### Multi-Agent Reinforcement Learning Setup 204 | 205 | - **Agent Observations**: Each agent observes its own pendulum state and limited information about adjacent pendulums 206 | - **Action Space**: Each agent controls an actuator torque $\tau_i$ 207 | - **Global State**: Complete dynamic state of the entire system 208 | - **Joint Reward**: Shared reward function based on system stability 209 | 210 | `DoublePendulumEnv.m` implements this environment and provides multi-agent interfaces to support MAPPO algorithm training. 211 | 212 | ## Framework Implementation Details 213 | 214 | ### Core Classes 215 | 216 | #### PPOAgent Class 217 | 218 | `PPOAgent.m` is the core implementation of single-agent PPO algorithm, with main methods including: 219 | 220 | - **collectTrajectories**: Collects agent-environment interaction trajectories 221 | - **computeGAE**: Calculates advantage values using Generalized Advantage Estimation 222 | - **updatePolicy**: Updates policy and value networks based on collected data 223 | - **train**: Executes complete training loop 224 | - **getAction**: Gets action for given observation 225 | 226 | Key attributes include policy network, value network and various training parameters. 
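For reference, the backward recursion that `computeGAE` implements (matching the GAE formula given earlier) can be sketched as follows. This is an illustrative sketch only — the function name, argument layout, and variable names are assumptions and may differ from the actual implementation in `PPOAgent.m`:

```matlab
function [advantages, returns] = computeGAESketch(rewards, values, dones, gamma, lambda)
    % rewards, dones : 1 x T arrays for a single trajectory
    % values         : 1 x (T+1) critic estimates, including the bootstrap value V(s_{T+1})
    T = numel(rewards);
    advantages = zeros(1, T);
    gae = 0;
    for t = T:-1:1
        % Temporal-difference error: delta_t = r_t + gamma*V(s_{t+1}) - V(s_t)
        delta = rewards(t) + gamma * values(t+1) * (1 - dones(t)) - values(t);
        % GAE recursion: A_t = delta_t + gamma*lambda*A_{t+1} (cut off at episode ends)
        gae = delta + gamma * lambda * (1 - dones(t)) * gae;
        advantages(t) = gae;
    end
    % Returns are used as regression targets when updating the value network
    returns = advantages + values(1:T);
end
```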
227 | 228 | #### MAPPOAgent Class 229 | 230 | `MAPPOAgent.m` extends PPO algorithm to multi-agent scenarios, with main features: 231 | 232 | - Manages policy networks (Actors) for multiple agents 233 | - Uses centralized critic network to estimate joint value function 234 | - Coordinates trajectory collection and policy updates for multiple agents 235 | 236 | #### ActorNetwork and CriticNetwork 237 | 238 | - **ContinuousActorNetwork.m**: Implements policy network for continuous action spaces 239 | - **DiscreteActorNetwork.m**: Implements policy network for discrete action spaces 240 | - **CriticNetwork.m**: Implements state value estimation network 241 | 242 | These network classes encapsulate network structure, forward propagation and gradient calculation functions. 243 | 244 | ### Network Architecture Design 245 | 246 | #### Policy Network 247 | 248 | The policy network (Actor) typically consists of the following components: 249 | 250 | ``` 251 | Observation Input -> Fully Connected Layer -> ReLU -> Fully Connected Layer -> ReLU -> Output Layer 252 | ``` 253 | 254 | For continuous action spaces, the output layer generates action mean and standard deviation; for discrete action spaces, the output layer generates action probability distribution. 255 | 256 | #### Value Network 257 | 258 | The value network (Critic) has a similar structure: 259 | 260 | ``` 261 | State Input -> Fully Connected Layer -> ReLU -> Fully Connected Layer -> ReLU -> Scalar Output (State Value) 262 | ``` 263 | 264 | For MAPPO, the value network receives joint observations as input to evaluate global state value. 265 | 266 | ### Training Process Description 267 | 268 | The typical PPO training process is as follows: 269 | 270 | 1. **Network Initialization**: Create policy and value networks 271 | 2. **Training Loop**: 272 | - Collect trajectory data through environment interaction 273 | - Calculate returns and advantage estimates 274 | - Update policy network (multiple mini-batch updates) 275 | - Update value network 276 | - Record and visualize training data 277 | 3. **Model Saving**: Save trained network parameters to file 278 | 279 | The MAPPO training process is similar but requires coordination of data collection and policy updates for multiple agents. 280 | 281 | ## Advanced Application Guide 282 | 283 | ### Hyperparameter Tuning 284 | 285 | Reinforcement learning algorithms are highly sensitive to hyperparameters. Here are key hyperparameters and tuning recommendations: 286 | 287 | - **PPO clipping parameter (epsilon)**: Typically set to 0.1-0.3, controls policy update magnitude 288 | - **Discount factor (gamma)**: Typically set to 0.95-0.99, controls importance of future rewards 289 | - **GAE parameter (lambda)**: Typically set to 0.9-0.99, controls bias-variance tradeoff 290 | - **Learning rate**: For policy and value networks, usually in range 1e-4 to 1e-3 291 | - **Network size**: Hidden layer size and count, depends on task complexity 292 | 293 | For complex environments, grid search or Bayesian optimization is recommended to find optimal hyperparameter combinations. 294 | 295 | ### Custom Environment Development 296 | 297 | Creating custom environments is an important way to extend the framework's functionality. Here are the basic steps for developing a new environment: 298 | 299 | 1. **Inherit Base Class**: The new environment should inherit from the `Environment` base class 300 | 2. 
**Implement Required Methods**: 301 | - `reset()`: Reset the environment to initial state and return initial observation 302 | - `step(action)`: Execute action, update environment state and return (next observation, reward, done, info) 303 | - `render()`: Optional visualization method 304 | 305 | #### Custom Environment Example 306 | 307 | ```matlab 308 | classdef MyCustomEnv < Environment 309 | properties 310 | % Environment state variables 311 | state 312 | 313 | % Environment parameters 314 | param1 315 | param2 316 | end 317 | 318 | methods 319 | function obj = MyCustomEnv(config) 320 | % Initialize environment parameters 321 | obj.param1 = config.param1Value; 322 | obj.param2 = config.param2Value; 323 | 324 | % Define observation and action space dimensions 325 | obj.observationDimension = 4; 326 | obj.continuousAction = true; % Use continuous action space 327 | obj.actionDimension = 2; 328 | end 329 | 330 | function observation = reset(obj) 331 | % Reset environment state 332 | obj.state = [0; 0; 0; 0]; 333 | 334 | % Return initial observation 335 | observation = obj.state; 336 | end 337 | 338 | function [nextObs, reward, done, info] = step(obj, action) 339 | % Validate action 340 | action = min(max(action, -1), 1); % Clip action to [-1,1] 341 | 342 | % Update environment state 343 | % ... Implement state transition equations ... 344 | 345 | % Calculate reward 346 | reward = calculateReward(obj, action); 347 | 348 | % Check termination condition 349 | done = checkTermination(obj); 350 | 351 | % Return results 352 | nextObs = obj.state; 353 | info = struct(); % Can contain additional information 354 | end 355 | 356 | function render(obj) 357 | % Implement visualization 358 | figure(1); 359 | % ... Draw environment state ... 360 | drawnow; 361 | end 362 | 363 | function reward = calculateReward(obj, action) 364 | % Custom reward function 365 | % ... Calculate reward ... 366 | end 367 | 368 | function done = checkTermination(obj) 369 | % Check termination conditions 370 | % ... Determine if episode should end ... 371 | end 372 | end 373 | end 374 | ``` 375 | 376 | ### Multi-Agent System Design 377 | 378 | Designing multi-agent systems requires consideration of the following key points: 379 | 380 | 1. **Environment Interface**: The multi-agent environment should provide interfaces supporting multiple agents 381 | - Implement `getNumAgents()` method to return agent count 382 | - `step(actions)` should receive joint actions from all agents 383 | - `reset()` should return initial observations for each agent 384 | 385 | 2. **Observation and State Design**: 386 | - Clearly distinguish between local observations (visible to each agent) and global state (visible to centralized critic) 387 | - Define observation function `getObservation(agentIdx)` to get specified agent's observation 388 | 389 | 3. 
**Reward Design**: 390 | - Shared reward: All agents receive same reward to promote cooperation 391 | - Individual reward: Each agent has its own reward function, which may lead to competition 392 | - Hybrid reward: Combine shared and individual rewards to balance cooperation and specific task objectives 393 | 394 | ## Performance Optimization and Debugging 395 | 396 | ### GPU Acceleration 397 | 398 | This framework supports MATLAB's GPU acceleration to significantly improve training speed: 399 | 400 | ```matlab 401 | % Enable GPU in configuration 402 | config = PPOConfig(); 403 | config.useGPU = true; 404 | 405 | % Ensure network parameters are on GPU 406 | net = dlnetwork(netLayers); 407 | if config.useGPU && canUseGPU() 408 | net = dlupdate(@gpuArray, net); 409 | end 410 | ``` 411 | 412 | Using GPU acceleration requires installing compatible CUDA and GPU computing toolboxes. For large networks, GPU acceleration can improve training speed by 5-10x. 413 | 414 | ### Parallel Computing 415 | 416 | This framework utilizes MATLAB's Parallel Computing Toolbox for parallelizing data collection: 417 | 418 | ```matlab 419 | % Enable parallel computing in configuration 420 | config = PPOConfig(); 421 | config.useParallel = true; 422 | config.numWorkers = 4; % Number of parallel workers 423 | 424 | % Parallel trajectory collection 425 | if config.useParallel 426 | parfor i = 1:numTrajectories 427 | % ... Parallel collect trajectories ... 428 | end 429 | else 430 | for i = 1:numTrajectories 431 | % ... Serial collect trajectories ... 432 | end 433 | end 434 | ``` 435 | 436 | Parallel computing is particularly suitable for trajectory collection as different trajectories have no dependencies. For complex environments and large numbers of trajectories, near-linear speedup can be achieved. 437 | 438 | ### Troubleshooting 439 | 440 | #### 1. Training Not Converging 441 | 442 | Possible causes and solutions: 443 | - **Learning rate too high**: Try reducing learning rate 444 | - **Network structure unsuitable**: Increase network capacity or adjust structure 445 | - **Poor reward design**: Redesign more instructive reward function 446 | - **Inaccurate advantage estimation**: Adjust GAE parameter (lambda) 447 | 448 | #### 2. Unstable Training Process 449 | 450 | Possible causes and solutions: 451 | - **Batch size too small**: Increase batch size to reduce gradient estimation variance 452 | - **Inappropriate clipping parameter**: Adjust epsilon value (0.1-0.3) 453 | - **Too many update steps**: Reduce the number of updates per batch of data 454 | 455 | #### 3. Multi-Agent Coordination Issues 456 | 457 | Possible causes and solutions: 458 | - **Poor observation space design**: Ensure agents have sufficient information for collaboration 459 | - **Shared reward ratio problem**: Adjust the ratio between shared and individual rewards 460 | - **Insufficient value network capacity**: Increase the capacity of the centralized critic network 461 | 462 | #### 4. Performance Analysis Tools 463 | 464 | Using MATLAB's built-in performance analysis tools: 465 | ```matlab 466 | % Enable code profiling 467 | profile on 468 | 469 | % Run the code to be analyzed 470 | agent.train(env, 10); 471 | 472 | % View the profiling report 473 | profile viewer 474 | ``` 475 | 476 | This helps identify bottlenecks in the code and optimize critical sections. 
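Beyond the profiler, a quick wall-clock check of the environment step helps determine whether simulation or network updates dominate training time. The sketch below is illustrative; the constructor arguments and the action value are placeholders that may need to be adapted to the environment being measured:

```matlab
% Rough timing of a single environment step (constructor arguments are environment-specific)
env = CartPoleEnv();
obs = env.reset();

numSteps = 1000;
tic;
for k = 1:numSteps
    [obs, ~, done, ~] = env.step(1);   % a fixed action is sufficient for a timing estimate
    if done
        obs = env.reset();
    end
end
fprintf('Average env.step time: %.3f ms\n', 1e3 * toc / numSteps);
```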
477 | 478 | ## Summary 479 | 480 | The Matlab PPO framework provides powerful and flexible infrastructure for solving various control problems, from basic inverted pendulums to complex multi-agent systems. By deeply understanding the algorithm principles, implementation details and environment models, users can fully utilize the framework's capabilities and extend/optimize it according to their needs. 481 | 482 | Whether for research or engineering applications, this framework provides the necessary tools and flexibility to address challenges in modern control systems. As reinforcement learning continues to evolve, we will keep updating and improving this framework to maintain its practicality and advancement. 483 | -------------------------------------------------------------------------------- /TUTORIAL_zh.md: -------------------------------------------------------------------------------- 1 | # Matlab PPO 强化学习框架详细教程 2 | 3 | 本教程提供了对Matlab PPO强化学习框架的深入解释,包括算法原理、实现细节、环境模型说明和扩展指南。 4 | 5 | > 英文版本请参考 [TUTORIAL.md](TUTORIAL.md) 6 | 7 | ## 目录 8 | 9 | 1. [算法详解](#算法详解) 10 | - [PPO算法原理](#PPO算法原理) 11 | - [MAPPO算法扩展](#MAPPO算法扩展) 12 | 2. [环境模型详解](#环境模型详解) 13 | - [倒立摆(CartPole)](#倒立摆) 14 | - [直流电机控制](#直流电机控制) 15 | - [交流感应电机FOC控制](#交流感应电机FOC控制) 16 | - [双倒立摆系统](#双倒立摆系统) 17 | 3. [框架实现细节](#框架实现细节) 18 | - [核心类说明](#核心类说明) 19 | - [网络结构设计](#网络结构设计) 20 | - [训练流程说明](#训练流程说明) 21 | 4. [高级应用指南](#高级应用指南) 22 | - [超参数调优](#超参数调优) 23 | - [自定义环境开发](#自定义环境开发) 24 | - [多智能体系统设计](#多智能体系统设计) 25 | 5. [性能优化与调试](#性能优化与调试) 26 | - [GPU加速](#GPU加速) 27 | - [并行计算](#并行计算) 28 | - [常见问题排查](#常见问题排查) 29 | 30 | ## 算法详解 31 | 32 | ### PPO算法原理 33 | 34 | PPO(近端策略优化)算法是一种策略梯度方法,旨在平衡样本效率和实现简单性。它通过在策略更新时引入一个裁剪项来限制新旧策略之间的差异,从而防止过大的策略更新。 35 | 36 | #### 核心数学原理 37 | 38 | PPO的目标函数为: 39 | 40 | $$L^{CLIP}(\theta) = \hat{E}_t\left[\min(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t)\right]$$ 41 | 42 | 其中: 43 | - $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ 表示新旧策略之间的概率比 44 | - $\hat{A}_t$ 是优势函数估计 45 | - $\epsilon$ 是裁剪参数,通常设为0.2 46 | 47 | #### 算法步骤 48 | 49 | 1. **采样**:使用当前策略 $\pi_{\theta_{old}}$ 收集一批轨迹数据 50 | 2. **优势估计**:计算每个状态-动作对的优势估计 $\hat{A}_t$ 51 | 3. **策略更新**:通过多次小批量梯度上升最大化裁剪目标 52 | 4. **价值函数更新**:更新价值函数以更好地估计回报 53 | 54 | #### 广义优势估计(GAE) 55 | 56 | PPO中通常使用GAE来计算优势函数,平衡偏差和方差: 57 | 58 | $$\hat{A}_t = \delta_t + (\gamma\lambda)\delta_{t+1} + ... + (\gamma\lambda)^{T-t+1}\delta_{T-1}$$ 59 | 60 | 其中: 61 | - $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ 是时间差分误差 62 | - $\gamma$ 是折扣因子 63 | - $\lambda$ 是GAE参数,控制偏差-方差权衡 64 | 65 | ### MAPPO算法扩展 66 | 67 | MAPPO(多智能体PPO)是PPO的扩展,适用于多智能体环境。它采用CTDE(集中训练,分散执行)框架,允许在训练阶段利用全局信息,而在执行阶段仅使用局部观察。 68 | 69 | #### 核心设计 70 | 71 | 1. **独立策略网络**:每个智能体 $i$ 有自己的策略网络 $\pi_i(a_i|o_i)$,基于局部观察 $o_i$ 做决策 72 | 2. **中央价值网络**:使用中央评论家网络 $V(s)$ 评估全局状态价值 73 | 3. 
**协同训练**:考虑智能体间的交互,优化联合策略的回报 74 | 75 | #### 与单智能体PPO的区别 76 | 77 | - **观察与状态分离**:每个智能体只获得局部观察,但价值网络可使用全局状态 78 | - **信用分配**:需要解决多智能体下的信用分配问题,即确定每个智能体对全局回报的贡献 79 | - **协作策略**:通过智能体间的隐式协调学习协作策略 80 | 81 | #### 应用场景 82 | 83 | MAPPO特别适合以下场景: 84 | - 自然分布式的控制问题(如多关节机器人) 85 | - 需要协作的任务(如多智能体协同控制) 86 | - 任务需要专业化的智能体角色(如不同功能的控制器) 87 | 88 | ## 环境模型详解 89 | 90 | ### 倒立摆 91 | 92 | 倒立摆(CartPole)是一个经典的控制问题,也是强化学习入门的标准测试环境。 93 | 94 | #### 物理模型 95 | 96 | 倒立摆系统由一个可水平移动的小车和一个连接在小车上的刚性摆杆组成。系统动力学可由以下微分方程描述: 97 | 98 | $$\ddot{x} = \frac{F + m l \sin\theta (\dot{\theta})^2 - m g \cos\theta \sin\theta}{M + m \sin^2\theta}$$ 99 | 100 | $$\ddot{\theta} = \frac{g \sin\theta - \ddot{x}\cos\theta}{l}$$ 101 | 102 | 其中: 103 | - $x$ 是小车位置 104 | - $\theta$ 是摆杆角度(相对于垂直向上方向) 105 | - $F$ 是施加于小车的力 106 | - $M$ 是小车质量 107 | - $m$ 是摆杆质量 108 | - $l$ 是摆杆半长 109 | - $g$ 是重力加速度 110 | 111 | #### 强化学习设置 112 | 113 | - **状态空间**:$[x, \dot{x}, \theta, \dot{\theta}]$ 114 | - **动作空间**:离散动作 {左移, 右移} 115 | - **奖励**:每个时间步+1,直到终止条件 116 | - **终止条件**:摆杆角度超过15度或小车位置偏离中心点超过2.4个单位 117 | 118 | 在我们的实现中,`CartPoleEnv.m` 提供了完整的环境模拟,包括动力学更新、奖励计算和可视化功能。 119 | 120 | ### 直流电机控制 121 | 122 | 直流电机控制是一个基础的工业控制问题,涉及电气和机械系统的动态模型。 123 | 124 | #### 电机模型 125 | 126 | 直流电机的动态方程如下: 127 | 128 | **电气方程**: 129 | $$L\frac{di}{dt} = v - Ri - K_e\omega$$ 130 | 131 | **机械方程**: 132 | $$J\frac{d\omega}{dt} = K_ti - B\omega - T_L$$ 133 | 134 | 其中: 135 | - $v$ 是施加的电压 136 | - $i$ 是电机电流 137 | - $\omega$ 是电机角速度 138 | - $T_L$ 是负载转矩 139 | - $L$ 是电感 140 | - $R$ 是电阻 141 | - $K_t$ 是转矩常数 142 | - $K_e$ 是反电动势常数 143 | - $J$ 是转动惯量 144 | - $B$ 是摩擦系数 145 | 146 | #### 强化学习设置 147 | 148 | - **状态空间**:$[\omega, i, \omega_{ref}, \omega_{ref} - \omega]$ 149 | - **动作空间**:连续动作,代表施加的电压 $v \in [-V_{max}, V_{max}]$ 150 | - **奖励**:结合速度跟踪误差、控制信号幅度和能量消耗 151 | - **终止条件**:达到最大步数或速度误差过大 152 | 153 | `DCMotorEnv.m` 实现了这一环境模型,包括离散化的动态方程求解、阶跃响应测试和负载扰动模拟。 154 | 155 | ### 交流感应电机FOC控制 156 | 157 | 交流感应电机控制是一个更复杂的问题,本框架实现了基于磁场定向控制(FOC)的交流电机控制策略学习。 158 | 159 | #### FOC原理 160 | 161 | FOC通过坐标变换,将交流电机的控制转化为类似于直流电机的控制问题: 162 | 163 | 1. **三相到两相变换**:将三相电流/电压($i_a$, $i_b$, $i_c$)转换为两相静止坐标系($i_\alpha$, $i_\beta$) 164 | 2. **静止坐标到旋转坐标**:将静止坐标系转换为与转子磁场同步的旋转坐标系($i_d$, $i_q$) 165 | 3. 
**在d-q坐标系控制**:在旋转坐标系中,$i_d$控制磁通,$i_q$控制转矩 166 | 167 | #### 交流电机模型 168 | 169 | 在d-q旋转坐标系下,感应电机的电压方程为: 170 | 171 | $$v_d = R_si_d + \frac{d\lambda_d}{dt} - \omega_e\lambda_q$$ 172 | $$v_q = R_si_q + \frac{d\lambda_q}{dt} + \omega_e\lambda_d$$ 173 | 174 | 磁链方程: 175 | $$\lambda_d = L_di_d + L_mi_{dr}$$ 176 | $$\lambda_q = L_qi_q + L_mi_{qr}$$ 177 | 178 | #### 强化学习设置 179 | 180 | - **状态空间**:$[\omega_r, i_d, i_q, \omega_{ref}, \lambda_d, \lambda_q]$ 181 | - **动作空间**:连续动作,代表d-q轴电压 $v_d$ 和 $v_q$ 182 | - **奖励**:结合速度跟踪误差、磁链稳定性、电流限制和能量效率 183 | 184 | ### 双倒立摆系统 185 | 186 | 双倒立摆系统是一个典型的多智能体协作问题,适合用MAPPO算法解决。 187 | 188 | #### 系统模型 189 | 190 | 双倒立摆系统由两个相连的刚性摆杆组成,每个摆杆通过一个独立执行器控制。系统的拉格朗日方程为: 191 | 192 | $$M(q)\ddot{q} + C(q,\dot{q})\dot{q} + G(q) = \tau$$ 193 | 194 | 其中: 195 | - $q = [\theta_1, \theta_2]^T$ 是摆杆角度向量 196 | - $M(q)$ 是质量矩阵 197 | - $C(q,\dot{q})$ 是科里奥利力和离心力项 198 | - $G(q)$ 是重力项 199 | - $\tau = [\tau_1, \tau_2]^T$ 是执行器施加的转矩 200 | 201 | #### 多智能体强化学习设置 202 | 203 | - **每个智能体的观察**:每个智能体观察自己摆杆的状态以及有限的相邻摆杆信息 204 | - **动作空间**:每个智能体控制一个执行器的转矩 $\tau_i$ 205 | - **全局状态**:整个系统的完整动力学状态 206 | - **联合奖励**:基于系统稳定性的共享奖励函数 207 | 208 | `DoublePendulumEnv.m` 实现了这一环境,并提供了多智能体接口以支持MAPPO算法的训练。 209 | 210 | ## 框架实现细节 211 | 212 | ### 核心类说明 213 | 214 | #### PPOAgent类 215 | 216 | `PPOAgent.m` 是单智能体PPO算法的核心实现,主要方法包括: 217 | 218 | - **collectTrajectories**:收集智能体与环境交互的轨迹 219 | - **computeGAE**:使用广义优势估计计算优势值 220 | - **updatePolicy**:根据收集的数据更新策略和价值网络 221 | - **train**:执行完整的训练循环 222 | - **getAction**:获取给定观察下的动作 223 | 224 | 关键属性包括策略网络、价值网络和各种训练参数。 225 | 226 | #### MAPPOAgent类 227 | 228 | `MAPPOAgent.m` 扩展了PPO算法到多智能体情境,其主要特点: 229 | 230 | - 管理多个智能体的策略网络(Actors) 231 | - 使用中央评论家网络(Critic)估计联合价值函数 232 | - 协调多个智能体的轨迹收集和策略更新 233 | 234 | #### ActorNetwork与CriticNetwork 235 | 236 | - **ContinuousActorNetwork.m**:实现连续动作空间的策略网络 237 | - **DiscreteActorNetwork.m**:实现离散动作空间的策略网络 238 | - **CriticNetwork.m**:实现状态价值估计网络 239 | 240 | 这些网络类封装了网络结构、前向传播和梯度计算等功能。 241 | 242 | ### 网络结构设计 243 | 244 | #### 策略网络 245 | 246 | 策略网络(Actor)通常由以下部分组成: 247 | 248 | ``` 249 | 观察输入 -> 全连接层 -> ReLU -> 全连接层 -> ReLU -> 输出层 250 | ``` 251 | 252 | 对于连续动作空间,输出层生成动作均值和标准差;对于离散动作空间,输出层生成动作概率分布。 253 | 254 | #### 价值网络 255 | 256 | 价值网络(Critic)结构类似: 257 | 258 | ``` 259 | 状态输入 -> 全连接层 -> ReLU -> 全连接层 -> ReLU -> 标量输出(状态价值) 260 | ``` 261 | 262 | 对于MAPPO,价值网络接收联合观察作为输入,评估全局状态价值。 263 | 264 | ### 训练流程说明 265 | 266 | 以下是PPO训练的典型流程: 267 | 268 | 1. **初始化网络**:创建策略和价值网络 269 | 2. **循环训练**: 270 | - 收集与环境交互的轨迹数据 271 | - 计算回报和优势估计 272 | - 更新策略网络(多次小批量更新) 273 | - 更新价值网络 274 | - 记录和可视化训练数据 275 | 3. **保存模型**:将训练好的网络参数保存到文件 276 | 277 | MAPPO训练流程类似,但需要协调多个智能体的数据收集和策略更新。 278 | 279 | ## 高级应用指南 280 | 281 | ### 超参数调优 282 | 283 | 强化学习算法对超参数非常敏感,以下是关键超参数及其调整建议: 284 | 285 | - **PPO裁剪参数(epsilon)**:通常设为0.1-0.3,控制策略更新幅度 286 | - **折扣因子(gamma)**:通常设为0.95-0.99,控制未来奖励的重要性 287 | - **GAE参数(lambda)**:通常设为0.9-0.99,控制偏差-方差权衡 288 | - **学习率**:策略和价值网络的学习率,通常在1e-4到1e-3范围内 289 | - **网络规模**:隐藏层大小和层数,取决于任务复杂度 290 | 291 | 对于复杂环境,建议使用网格搜索或贝叶斯优化来找到最佳超参数组合。 292 | 293 | ### 自定义环境开发 294 | 295 | 创建自定义环境是扩展框架功能的重要方式。以下是开发新环境的基本步骤: 296 | 297 | 1. **继承基类**:新环境应继承自`Environment`基类 298 | 2. 
**实现必要方法**: 299 | - `reset()`:重置环境到初始状态并返回初始观察 300 | - `step(action)`:执行动作,更新环境状态并返回(下一个观察,奖励,是否结束,信息) 301 | - `render()`:可选的可视化方法 302 | 303 | #### 自定义环境示例 304 | 305 | ```matlab 306 | classdef MyCustomEnv < Environment 307 | properties 308 | % 环境状态变量 309 | state 310 | 311 | % 环境参数 312 | param1 313 | param2 314 | end 315 | 316 | methods 317 | function obj = MyCustomEnv(config) 318 | % 初始化环境参数 319 | obj.param1 = config.param1Value; 320 | obj.param2 = config.param2Value; 321 | 322 | % 定义观察和动作空间维度 323 | obj.observationDimension = 4; 324 | obj.continuousAction = true; % 使用连续动作空间 325 | obj.actionDimension = 2; 326 | end 327 | 328 | function observation = reset(obj) 329 | % 重置环境状态 330 | obj.state = [0; 0; 0; 0]; 331 | 332 | % 返回初始观察 333 | observation = obj.state; 334 | end 335 | 336 | function [nextObs, reward, done, info] = step(obj, action) 337 | % 验证动作 338 | action = min(max(action, -1), 1); % 裁剪动作到[-1,1] 339 | 340 | % 更新环境状态 341 | % ... 实现状态转移方程 ... 342 | 343 | % 计算奖励 344 | reward = calculateReward(obj, action); 345 | 346 | % 检查是否结束 347 | done = checkTermination(obj); 348 | 349 | % 返回结果 350 | nextObs = obj.state; 351 | info = struct(); % 可包含额外信息 352 | end 353 | 354 | function render(obj) 355 | % 实现可视化 356 | figure(1); 357 | % ... 绘制环境状态 ... 358 | drawnow; 359 | end 360 | 361 | function reward = calculateReward(obj, action) 362 | % 自定义奖励函数 363 | % ... 计算奖励 ... 364 | end 365 | 366 | function done = checkTermination(obj) 367 | % 检查终止条件 368 | % ... 判断是否结束 ... 369 | end 370 | end 371 | end 372 | ``` 373 | 374 | ### 多智能体系统设计 375 | 376 | 设计多智能体系统需要考虑以下关键点: 377 | 378 | 1. **环境接口**:多智能体环境应提供支持多个智能体的接口 379 | - 提供`getNumAgents()`方法返回智能体数量 380 | - `step(actions)`接收所有智能体的联合动作 381 | - `reset()`返回每个智能体的初始观察 382 | 383 | 2. **观察和状态设计**: 384 | - 明确区分局部观察(每个智能体可见)和全局状态(中央评论家可见) 385 | - 定义观察函数`getObservation(agentIdx)`获取指定智能体的观察 386 | 387 | 3. **奖励设计**: 388 | - 共享奖励:所有智能体获得相同奖励,促进协作 389 | - 个体奖励:每个智能体有自己的奖励函数,可能导致竞争 390 | - 混合奖励:结合共享和个体奖励,平衡协作与特定任务目标 391 | 392 | ## 性能优化与调试 393 | 394 | ### GPU加速 395 | 396 | 本框架支持MATLAB的GPU加速功能,显著提升训练速度: 397 | 398 | ```matlab 399 | % 在配置中启用GPU 400 | config = PPOConfig(); 401 | config.useGPU = true; 402 | 403 | % 确保网络参数在GPU上 404 | net = dlnetwork(netLayers); 405 | if config.useGPU && canUseGPU() 406 | net = dlupdate(@gpuArray, net); 407 | end 408 | ``` 409 | 410 | 使用GPU加速需要安装兼容的CUDA和GPU计算工具箱。对于大型网络,GPU加速可提升5-10倍的训练速度。 411 | 412 | ### 并行计算 413 | 414 | 本框架利用MATLAB的并行计算工具箱进行数据收集的并行化: 415 | 416 | ```matlab 417 | % 在配置中启用并行计算 418 | config = PPOConfig(); 419 | config.useParallel = true; 420 | config.numWorkers = 4; % 并行工作进程数 421 | 422 | % 并行收集轨迹 423 | if config.useParallel 424 | parfor i = 1:numTrajectories 425 | % ... 并行收集轨迹 ... 426 | end 427 | else 428 | for i = 1:numTrajectories 429 | % ... 串行收集轨迹 ... 430 | end 431 | end 432 | ``` 433 | 434 | 并行计算特别适合轨迹收集阶段,因为不同轨迹之间没有依赖关系。对于复杂环境和大量轨迹,可以实现接近线性的加速。 435 | 436 | ### 常见问题排查 437 | 438 | #### 1. 训练不收敛 439 | 440 | 可能的原因和解决方案: 441 | - **学习率过高**:尝试降低学习率 442 | - **网络结构不适合**:增加网络容量或调整结构 443 | - **奖励设计不合理**:重新设计更有指导性的奖励函数 444 | - **优势估计不准确**:调整GAE参数(lambda) 445 | 446 | #### 2. 训练过程不稳定 447 | 448 | 可能的原因和解决方案: 449 | - **批量大小太小**:增加批量大小以减少梯度估计的方差 450 | - **裁剪参数设置不当**:调整epsilon值(0.1-0.3) 451 | - **更新步数过多**:减少每批数据的更新次数 452 | 453 | #### 3. 多智能体协调问题 454 | 455 | 可能的原因和解决方案: 456 | - **观察空间设计不合理**:确保智能体有足够信息进行协作 457 | - **共享奖励比例问题**:调整共享与个体奖励的比例 458 | - **价值网络容量不足**:增加中央评论家网络的容量 459 | 460 | #### 4. 
性能分析工具 461 | 462 | 使用MATLAB内置的性能分析工具: 463 | ```matlab 464 | % 启用代码分析 465 | profile on 466 | 467 | % 运行要分析的代码 468 | agent.train(env, 10); 469 | 470 | % 查看分析报告 471 | profile viewer 472 | ``` 473 | 474 | 这有助于识别代码中的瓶颈并优化关键部分。 475 | 476 | ## 总结 477 | 478 | Matlab PPO框架提供了强大而灵活的基础设施,用于解决各种控制问题,从基础的倒立摆到复杂的多智能体系统。通过深入理解算法原理、实现细节和环境模型,用户可以充分利用本框架的能力,并根据自己的需求进行扩展和优化。 479 | 480 | 无论是研究目的还是工程应用,本框架都提供了必要的工具和灵活性,以应对现代控制系统的挑战。随着强化学习领域的不断发展,我们也将持续更新和改进本框架,以保持其实用性和先进性。 481 | -------------------------------------------------------------------------------- /config/MAPPOConfig.m: -------------------------------------------------------------------------------- 1 | classdef MAPPOConfig < handle 2 | % MAPPOConfig 多智能体PPO的配置类 3 | % 管理多智能体PPO算法的各种参数和配置 4 | 5 | properties 6 | % 环境配置 7 | envName % 环境名称 8 | numAgents % 智能体数量 9 | 10 | % 网络配置 11 | actorLayerSizes % 策略网络隐藏层大小 12 | criticLayerSizes % 中央价值网络隐藏层大小 13 | 14 | % 算法超参数 15 | gamma % 折扣因子 16 | lambda % GAE参数 17 | epsilon % PPO裁剪参数 18 | entropyCoef % 熵正则化系数 19 | vfCoef % 价值函数系数 20 | maxGradNorm % 梯度裁剪阈值 21 | 22 | % 优化器配置 23 | actorLearningRate % 策略网络学习率 24 | criticLearningRate % 价值网络学习率 25 | momentum % 动量 26 | 27 | % 训练配置 28 | numIterations % 训练迭代次数 29 | batchSize % 每次更新的批次大小 30 | epochsPerIter % 每次迭代的训练轮数 31 | trajectoryLen % 轨迹长度 32 | numTrajectories % 每次迭代收集的轨迹数量 33 | 34 | % 硬件配置 35 | useGPU % 是否使用GPU 36 | 37 | % 日志配置 38 | logDir % 日志保存目录 39 | evalFreq % 评估频率(迭代次数) 40 | numEvalEpisodes % 评估时的回合数 41 | saveModelFreq % 保存模型频率(迭代次数) 42 | end 43 | 44 | methods 45 | function obj = MAPPOConfig() 46 | % 构造函数:使用默认值初始化 47 | 48 | % 环境配置 49 | obj.envName = ''; 50 | obj.numAgents = 2; 51 | 52 | % 网络配置 53 | obj.actorLayerSizes = [64, 64]; 54 | obj.criticLayerSizes = [128, 128]; 55 | 56 | % 算法超参数 57 | obj.gamma = 0.99; 58 | obj.lambda = 0.95; 59 | obj.epsilon = 0.2; 60 | obj.entropyCoef = 0.01; 61 | obj.vfCoef = 0.5; 62 | obj.maxGradNorm = 0.5; 63 | 64 | % 优化器配置 65 | obj.actorLearningRate = 3e-4; 66 | obj.criticLearningRate = 1e-3; 67 | obj.momentum = 0.9; 68 | 69 | % 训练配置 70 | obj.numIterations = 100; 71 | obj.batchSize = 64; 72 | obj.epochsPerIter = 4; 73 | obj.trajectoryLen = 200; 74 | obj.numTrajectories = 10; 75 | 76 | % 硬件配置 77 | obj.useGPU = true; 78 | 79 | % 日志配置 80 | obj.logDir = 'logs/mappo'; 81 | obj.evalFreq = 10; 82 | obj.numEvalEpisodes = 5; 83 | obj.saveModelFreq = 20; 84 | end 85 | 86 | function params = toStruct(obj) 87 | % 将配置转换为结构体,用于保存 88 | params = struct(); 89 | 90 | props = properties(obj); 91 | for i = 1:length(props) 92 | propName = props{i}; 93 | params.(propName) = obj.(propName); 94 | end 95 | end 96 | 97 | function save(obj, filepath) 98 | % 保存配置到文件 99 | config = obj.toStruct(); 100 | save(filepath, '-struct', 'config'); 101 | fprintf('配置已保存到: %s\n', filepath); 102 | end 103 | 104 | function obj = loadFromFile(obj, filepath) 105 | % 从文件加载配置 106 | if ~exist(filepath, 'file') 107 | error('配置文件不存在: %s', filepath); 108 | end 109 | 110 | % 加载配置 111 | config = load(filepath); 112 | 113 | % 更新对象属性 114 | props = fieldnames(config); 115 | for i = 1:length(props) 116 | propName = props{i}; 117 | if isprop(obj, propName) 118 | obj.(propName) = config.(propName); 119 | end 120 | end 121 | 122 | fprintf('配置已从 %s 加载\n', filepath); 123 | end 124 | end 125 | end 126 | -------------------------------------------------------------------------------- /config/PPOConfig.m: -------------------------------------------------------------------------------- 1 | classdef PPOConfig < handle 2 | % PPOConfig PPO算法配置类 3 | % 用于管理和加载PPO算法的所有配置参数 4 | 5 | properties 6 | % 环境配置 7 
| envName % 环境名称 8 | 9 | % 网络配置 10 | actorLayerSizes % Actor网络隐藏层大小 11 | criticLayerSizes % Critic网络隐藏层大小 12 | 13 | % 算法超参数 14 | gamma % 折扣因子 15 | lambda % GAE参数 16 | epsilon % PPO裁剪参数 17 | entropyCoef % 熵正则化系数 18 | vfCoef % 价值函数系数 19 | maxGradNorm % 梯度裁剪 20 | 21 | % 优化器配置 22 | actorLearningRate % Actor学习率 23 | criticLearningRate % Critic学习率 24 | momentum % 动量 25 | 26 | % 训练配置 27 | numIterations % 训练迭代次数 28 | batchSize % 批次大小 29 | epochsPerIter % 每次迭代的训练轮数 30 | trajectoryLen % 轨迹长度 31 | numTrajectories % 每次迭代收集的轨迹数量 32 | 33 | % 硬件配置 34 | useGPU % 是否使用GPU 35 | 36 | % 日志配置 37 | logDir % 日志保存目录 38 | evalFreq % 评估频率 (迭代次数) 39 | numEvalEpisodes % 评估时的回合数 40 | saveModelFreq % 保存模型频率 (迭代次数) 41 | end 42 | 43 | methods 44 | function obj = PPOConfig() 45 | % 构造函数:设置默认配置 46 | 47 | % 环境配置 48 | obj.envName = 'CartPoleEnv'; 49 | 50 | % 网络配置 51 | obj.actorLayerSizes = [64, 64]; 52 | obj.criticLayerSizes = [64, 64]; 53 | 54 | % 算法超参数 55 | obj.gamma = 0.99; 56 | obj.lambda = 0.95; 57 | obj.epsilon = 0.2; 58 | obj.entropyCoef = 0.01; 59 | obj.vfCoef = 0.5; 60 | obj.maxGradNorm = 0.5; 61 | 62 | % 优化器配置 63 | obj.actorLearningRate = 3e-4; 64 | obj.criticLearningRate = 3e-4; 65 | obj.momentum = 0.9; 66 | 67 | % 训练配置 68 | obj.numIterations = 100; 69 | obj.batchSize = 128; 70 | obj.epochsPerIter = 4; 71 | obj.trajectoryLen = 200; 72 | obj.numTrajectories = 10; 73 | 74 | % 硬件配置 75 | obj.useGPU = true; 76 | 77 | % 日志配置 78 | obj.logDir = 'logs'; 79 | obj.evalFreq = 10; 80 | obj.numEvalEpisodes = 10; 81 | obj.saveModelFreq = 20; 82 | end 83 | 84 | function config = loadFromFile(obj, filePath) 85 | % 从文件加载配置 86 | % filePath - 配置文件路径 87 | % 返回填充了配置的对象 88 | 89 | % 检查文件是否存在 90 | if ~exist(filePath, 'file') 91 | error('配置文件不存在: %s', filePath); 92 | end 93 | 94 | % 加载配置 95 | configData = load(filePath); 96 | 97 | % 获取所有属性名 98 | propNames = properties(obj); 99 | 100 | % 遍历所有属性并设置值 101 | for i = 1:length(propNames) 102 | propName = propNames{i}; 103 | 104 | % 如果配置文件中有对应的字段,则设置该属性 105 | if isfield(configData, propName) 106 | obj.(propName) = configData.(propName); 107 | end 108 | end 109 | 110 | config = obj; 111 | end 112 | 113 | function saveToFile(obj, filePath) 114 | % 将配置保存到文件 115 | % filePath - 保存路径 116 | 117 | % 确保目录存在 118 | [filePath, ~, ~] = fileparts(filePath); 119 | if ~exist(filePath, 'dir') 120 | mkdir(filePath); 121 | end 122 | 123 | % 将对象转换为结构体 124 | configStruct = obj.toStruct(); 125 | 126 | % 保存到文件 127 | save(filePath, '-struct', 'configStruct'); 128 | fprintf('配置已保存到: %s\n', filePath); 129 | end 130 | 131 | function configStruct = toStruct(obj) 132 | % 将配置对象转换为结构体 133 | 134 | configStruct = struct(); 135 | propNames = properties(obj); 136 | 137 | for i = 1:length(propNames) 138 | propName = propNames{i}; 139 | configStruct.(propName) = obj.(propName); 140 | end 141 | end 142 | 143 | function varargout = subsref(obj, s) 144 | % 重载下标引用操作,支持点符号和括号引用 145 | 146 | switch s(1).type 147 | case '.' 148 | % 获取属性值 149 | if length(s) == 1 150 | % 直接访问属性 151 | varargout{1} = obj.(s(1).subs); 152 | else 153 | % 级联访问 154 | [varargout{1:nargout}] = subsref(obj.(s(1).subs), s(2:end)); 155 | end 156 | case '()' 157 | % 括号引用,调用父类的处理 158 | [varargout{1:nargout}] = builtin('subsref', obj, s); 159 | case '{}' 160 | % 花括号引用,调用父类的处理 161 | [varargout{1:nargout}] = builtin('subsref', obj, s); 162 | end 163 | end 164 | 165 | function obj = subsasgn(obj, s, val) 166 | % 重载下标赋值操作,支持点符号和括号赋值 167 | 168 | switch s(1).type 169 | case '.' 
170 | % 设置属性值 171 | if length(s) == 1 172 | % 直接设置属性 173 | obj.(s(1).subs) = val; 174 | else 175 | % 级联设置 176 | obj.(s(1).subs) = subsasgn(obj.(s(1).subs), s(2:end), val); 177 | end 178 | case '()' 179 | % 括号赋值,调用父类的处理 180 | obj = builtin('subsasgn', obj, s, val); 181 | case '{}' 182 | % 花括号赋值,调用父类的处理 183 | obj = builtin('subsasgn', obj, s, val); 184 | end 185 | end 186 | 187 | function disp(obj) 188 | % 显示对象信息 189 | 190 | fprintf('PPO配置:\n'); 191 | propNames = properties(obj); 192 | 193 | for i = 1:length(propNames) 194 | propName = propNames{i}; 195 | propValue = obj.(propName); 196 | 197 | % 根据属性值类型格式化输出 198 | if ischar(propValue) 199 | fprintf(' %s: %s\n', propName, propValue); 200 | elseif isnumeric(propValue) && isscalar(propValue) 201 | fprintf(' %s: %.6g\n', propName, propValue); 202 | elseif isnumeric(propValue) && ~isscalar(propValue) 203 | fprintf(' %s: [', propName); 204 | fprintf('%.6g ', propValue); 205 | fprintf(']\n'); 206 | elseif islogical(propValue) 207 | if propValue 208 | fprintf(' %s: true\n', propName); 209 | else 210 | fprintf(' %s: false\n', propName); 211 | end 212 | else 213 | fprintf(' %s: %s\n', propName, class(propValue)); 214 | end 215 | end 216 | end 217 | end 218 | end 219 | -------------------------------------------------------------------------------- /config/default_acmotor_config.m: -------------------------------------------------------------------------------- 1 | % default_acmotor_config.m 2 | % 交流感应电机环境的默认PPO配置文件 3 | 4 | % 创建配置对象 5 | config = PPOConfig(); 6 | 7 | % 环境配置 8 | config.envName = 'ACMotorEnv'; 9 | 10 | % 网络配置 - 对于复杂的交流电机控制,使用更深更宽的网络 11 | config.actorLayerSizes = [256, 256, 128, 64]; % Actor网络隐藏层大小 12 | config.criticLayerSizes = [256, 256, 128, 64]; % Critic网络隐藏层大小 13 | 14 | % 算法超参数 15 | config.gamma = 0.99; % 折扣因子 16 | config.lambda = 0.95; % GAE参数 17 | config.epsilon = 0.1; % PPO裁剪参数(减小以更加保守地学习) 18 | config.entropyCoef = 0.001; % 熵正则化系数(交流电机控制更精确,更低的熵系数) 19 | config.vfCoef = 0.5; % 价值函数系数 20 | config.maxGradNorm = 0.5; % 梯度裁剪 21 | 22 | % 优化器配置 23 | config.actorLearningRate = 5e-5; % Actor学习率(交流电机更复杂,使用更小的学习率) 24 | config.criticLearningRate = 5e-5; % Critic学习率 25 | config.momentum = 0.9; % 动量 26 | 27 | % 训练配置 28 | config.numIterations = 400; % 训练迭代次数(复杂系统需要更多迭代) 29 | config.batchSize = 256; % 批次大小 30 | config.epochsPerIter = 15; % 每次迭代的训练轮数(增加以提高样本利用效率) 31 | config.trajectoryLen = 500; % 轨迹长度(交流电机采样率高,需要更长轨迹) 32 | config.numTrajectories = 32; % 每次迭代收集的轨迹数量(增加以提高稳定性) 33 | 34 | % 硬件配置 35 | config.useGPU = true; % 是否使用GPU 36 | 37 | % 日志配置 38 | config.logDir = 'logs/acmotor'; % 日志保存目录 39 | config.evalFreq = 10; % 评估频率 (迭代次数) 40 | config.numEvalEpisodes = 5; % 评估时的回合数 41 | config.saveModelFreq = 40; % 保存模型频率 (迭代次数) 42 | 43 | % 保存配置 44 | saveDir = 'config'; 45 | if ~exist(saveDir, 'dir') 46 | mkdir(saveDir); 47 | end 48 | 49 | save('config/acmotor_config.mat', '-struct', 'config.toStruct()'); 50 | fprintf('交流感应电机环境的默认配置已保存到: config/acmotor_config.mat\n'); 51 | -------------------------------------------------------------------------------- /config/default_cartpole_config.m: -------------------------------------------------------------------------------- 1 | % default_cartpole_config.m 2 | % 倒立摆环境的默认PPO配置文件 3 | 4 | % 创建配置对象 5 | config = PPOConfig(); 6 | 7 | % 环境配置 8 | config.envName = 'CartPoleEnv'; 9 | 10 | % 网络配置 11 | config.actorLayerSizes = [64, 64]; % Actor网络隐藏层大小 12 | config.criticLayerSizes = [64, 64]; % Critic网络隐藏层大小 13 | 14 | % 算法超参数 15 | config.gamma = 0.99; % 折扣因子 16 | config.lambda = 0.95; % GAE参数 17 | config.epsilon = 0.2; % 
PPO裁剪参数 18 | config.entropyCoef = 0.01; % 熵正则化系数 19 | config.vfCoef = 0.5; % 价值函数系数 20 | config.maxGradNorm = 0.5; % 梯度裁剪 21 | 22 | % 优化器配置 23 | config.actorLearningRate = 3e-4; % Actor学习率 24 | config.criticLearningRate = 3e-4; % Critic学习率 25 | config.momentum = 0.9; % 动量 26 | 27 | % 训练配置 28 | config.numIterations = 100; % 训练迭代次数 29 | config.batchSize = 64; % 批次大小 30 | config.epochsPerIter = 4; % 每次迭代的训练轮数 31 | config.trajectoryLen = 200; % 轨迹长度 32 | config.numTrajectories = 10; % 每次迭代收集的轨迹数量 33 | 34 | % 硬件配置 35 | config.useGPU = true; % 是否使用GPU 36 | 37 | % 日志配置 38 | config.logDir = 'logs/cartpole'; % 日志保存目录 39 | config.evalFreq = 5; % 评估频率 (迭代次数) 40 | config.numEvalEpisodes = 10; % 评估时的回合数 41 | config.saveModelFreq = 10; % 保存模型频率 (迭代次数) 42 | 43 | % 保存配置 44 | saveDir = 'config'; 45 | if ~exist(saveDir, 'dir') 46 | mkdir(saveDir); 47 | end 48 | 49 | save('config/cartpole_config.mat', '-struct', 'config.toStruct()'); 50 | fprintf('倒立摆环境的默认配置已保存到: config/cartpole_config.mat\n'); 51 | -------------------------------------------------------------------------------- /config/default_dcmotor_config.m: -------------------------------------------------------------------------------- 1 | % default_dcmotor_config.m 2 | % 直流电机环境的默认PPO配置文件 3 | 4 | % 创建配置对象 5 | config = PPOConfig(); 6 | 7 | % 环境配置 8 | config.envName = 'DCMotorEnv'; 9 | 10 | % 网络配置 - 使用更深的网络处理连续控制问题 11 | config.actorLayerSizes = [128, 128, 64]; % Actor网络隐藏层大小 12 | config.criticLayerSizes = [128, 128, 64]; % Critic网络隐藏层大小 13 | 14 | % 算法超参数 15 | config.gamma = 0.99; % 折扣因子 16 | config.lambda = 0.95; % GAE参数 17 | config.epsilon = 0.2; % PPO裁剪参数 18 | config.entropyCoef = 0.005; % 熵正则化系数(由于连续控制需要更精确,降低熵系数) 19 | config.vfCoef = 0.5; % 价值函数系数 20 | config.maxGradNorm = 0.5; % 梯度裁剪 21 | 22 | % 优化器配置 23 | config.actorLearningRate = 1e-4; % Actor学习率(对连续控制使用较小的学习率) 24 | config.criticLearningRate = 1e-4; % Critic学习率 25 | config.momentum = 0.9; % 动量 26 | 27 | % 训练配置 28 | config.numIterations = 200; % 训练迭代次数 29 | config.batchSize = 128; % 批次大小 30 | config.epochsPerIter = 10; % 每次迭代的训练轮数(增加以提高样本利用效率) 31 | config.trajectoryLen = 250; % 轨迹长度 32 | config.numTrajectories = 16; % 每次迭代收集的轨迹数量 33 | 34 | % 硬件配置 35 | config.useGPU = true; % 是否使用GPU 36 | 37 | % 日志配置 38 | config.logDir = 'logs/dcmotor'; % 日志保存目录 39 | config.evalFreq = 5; % 评估频率 (迭代次数) 40 | config.numEvalEpisodes = 10; % 评估时的回合数 41 | config.saveModelFreq = 20; % 保存模型频率 (迭代次数) 42 | 43 | % 保存配置 44 | saveDir = 'config'; 45 | if ~exist(saveDir, 'dir') 46 | mkdir(saveDir); 47 | end 48 | 49 | save('config/dcmotor_config.mat', '-struct', 'config.toStruct()'); 50 | fprintf('直流电机环境的默认配置已保存到: config/dcmotor_config.mat\n'); 51 | -------------------------------------------------------------------------------- /config/default_doublependulum_config.m: -------------------------------------------------------------------------------- 1 | function config = default_doublependulum_config() 2 | % default_doublependulum_config 双倒立摆环境的默认MAPPO配置 3 | 4 | % 创建MAPPO配置对象 5 | config = MAPPOConfig(); 6 | 7 | % 设置环境 8 | config.envName = 'DoublePendulumEnv'; 9 | config.numAgents = 2; % 两个智能体,各控制一个摆杆 10 | 11 | % 网络配置 - 使用较大的网络以处理复杂的动力学 12 | config.actorLayerSizes = [128, 64]; 13 | config.criticLayerSizes = [256, 128]; 14 | 15 | % 算法超参数 16 | config.gamma = 0.99; % 折扣因子 17 | config.lambda = 0.95; % GAE参数 18 | config.epsilon = 0.2; % PPO裁剪参数 19 | config.entropyCoef = 0.01; % 熵正则化系数,鼓励探索 20 | config.vfCoef = 0.5; % 价值函数系数 21 | config.maxGradNorm = 0.5; % 梯度裁剪阈值 22 | 23 | % 优化器配置 24 | config.actorLearningRate = 3e-4; 25 | 
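% In the centralized-training / decentralized-execution setup the critic is
% evaluated on the joint observation of both agents (see
% DoublePendulumEnv.getJointObservation), which is why it gets the larger
% network above and, below, a higher learning rate than the actors
% (a common MAPPO choice rather than a strict requirement).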
config.criticLearningRate = 1e-3; 26 | config.momentum = 0.9; 27 | 28 | % 训练配置 29 | config.numIterations = 200; % 训练迭代次数 30 | config.batchSize = 64; % 批次大小 31 | config.epochsPerIter = 4; % 每次迭代的训练轮数 32 | config.trajectoryLen = 200; % 轨迹长度 33 | config.numTrajectories = 20; % 每次迭代收集的轨迹数量 34 | 35 | % 硬件配置 36 | config.useGPU = true; 37 | 38 | % 日志配置 39 | config.logDir = 'logs/doublependulum'; 40 | config.evalFreq = 10; % 每10次迭代评估一次 41 | config.numEvalEpisodes = 5; % 每次评估5个回合 42 | config.saveModelFreq = 20; % 每20次迭代保存一次模型 43 | end 44 | -------------------------------------------------------------------------------- /core/ContinuousActorNetwork.m: -------------------------------------------------------------------------------- 1 | classdef ContinuousActorNetwork < handle 2 | % ContinuousActorNetwork 用于连续动作空间的策略网络 3 | % 实现了具有高斯分布输出的策略网络 4 | 5 | properties 6 | layerSizes % 隐藏层大小 7 | learnables % 可学习参数 8 | useGPU % 是否使用GPU 9 | end 10 | 11 | methods 12 | function obj = ContinuousActorNetwork(inputSize, outputSize, layerSizes) 13 | % 构造函数:初始化连续动作空间的Actor网络 14 | % inputSize - 输入维度(观察空间大小) 15 | % outputSize - 输出维度(动作空间大小) 16 | % layerSizes - 隐藏层大小的数组 17 | 18 | obj.layerSizes = layerSizes; 19 | obj.useGPU = false; 20 | 21 | % 初始化网络参数 22 | obj.learnables = struct(); 23 | 24 | % 初始化第一个隐藏层 25 | layerIdx = 1; 26 | obj.learnables.fc1w = dlarray(initializeGlorot(layerSizes(1), inputSize), 'U'); 27 | obj.learnables.fc1b = dlarray(zeros(layerSizes(1), 1), 'U'); 28 | 29 | % 初始化中间隐藏层 30 | for i = 2:length(layerSizes) 31 | obj.learnables.(sprintf('fc%dw', i)) = dlarray(initializeGlorot(layerSizes(i), layerSizes(i-1)), 'U'); 32 | obj.learnables.(sprintf('fc%db', i)) = dlarray(zeros(layerSizes(i), 1), 'U'); 33 | end 34 | 35 | % 初始化均值输出层 36 | obj.learnables.meanw = dlarray(initializeGlorot(outputSize, layerSizes(end)), 'U'); 37 | obj.learnables.meanb = dlarray(zeros(outputSize, 1), 'U'); 38 | 39 | % 初始化方差输出层 (log of std) 40 | obj.learnables.logstdw = dlarray(initializeGlorot(outputSize, layerSizes(end)), 'U'); 41 | obj.learnables.logstdb = dlarray(zeros(outputSize, 1), 'U'); 42 | end 43 | 44 | function [action, logProb, mean, std] = sampleAction(obj, observation) 45 | % 采样一个动作并计算其对数概率 46 | % observation - 当前的观察状态 47 | % action - 采样的动作 48 | % logProb - 动作的对数概率 49 | 50 | % 前向传播,获取动作分布参数 51 | [mean, logstd] = obj.forward(observation); 52 | std = exp(logstd); 53 | 54 | % 从正态分布采样 55 | if obj.useGPU 56 | noise = dlarray(gpuArray(randn(size(mean))), 'CB'); 57 | else 58 | noise = dlarray(randn(size(mean)), 'CB'); 59 | end 60 | 61 | action = mean + std .* noise; 62 | 63 | % 计算对数概率 64 | logProb = -0.5 * sum((action - mean).^2 ./ (std.^2 + 1e-8), 1) - ... 65 | sum(logstd, 1) - ... 66 | 0.5 * size(mean, 1) * log(2 * pi); 67 | end 68 | 69 | function [logProb, entropy] = evaluateActions(obj, observation, action) 70 | % 评估动作的对数概率和熵 71 | % observation - 观察状态 72 | % action - 执行的动作 73 | % logProb - 动作的对数概率 74 | % entropy - 策略的熵 75 | 76 | % 前向传播,获取动作分布参数 77 | [mean, logstd] = obj.forward(observation); 78 | std = exp(logstd); 79 | 80 | % 计算对数概率 81 | logProb = -0.5 * sum((action - mean).^2 ./ (std.^2 + 1e-8), 1) - ... 82 | sum(logstd, 1) - ... 
83 | 0.5 * size(mean, 1) * log(2 * pi); 84 | 85 | % 计算熵 86 | entropy = sum(logstd + 0.5 * log(2 * pi * exp(1)), 1); 87 | end 88 | 89 | function [mean, logstd] = forward(obj, observation) 90 | % 前向传播,计算动作分布参数 91 | % observation - 观察状态 92 | % mean - 动作均值 93 | % logstd - 动作标准差的对数 94 | 95 | % 第一个隐藏层 96 | x = fullyconnect(observation, obj.learnables.fc1w, obj.learnables.fc1b); 97 | x = tanh(x); 98 | 99 | % 中间隐藏层 100 | for i = 2:length(obj.layerSizes) 101 | x = fullyconnect(x, obj.learnables.(sprintf('fc%dw', i)), obj.learnables.(sprintf('fc%db', i))); 102 | x = tanh(x); 103 | end 104 | 105 | % 均值输出层 106 | mean = fullyconnect(x, obj.learnables.meanw, obj.learnables.meanb); 107 | 108 | % 方差输出层 (log of std) 109 | logstd = fullyconnect(x, obj.learnables.logstdw, obj.learnables.logstdb); 110 | 111 | % 限制logstd范围,防止数值不稳定 112 | logstd = max(min(logstd, 2), -20); 113 | end 114 | 115 | function action = getMeanAction(obj, observation) 116 | % 获取确定性动作(均值) 117 | % observation - 观察状态 118 | % action - 确定性动作 119 | 120 | % 前向传播,获取动作分布的均值 121 | [mean, ~] = obj.forward(observation); 122 | action = mean; 123 | end 124 | 125 | function toGPU(obj) 126 | % 将网络参数转移到GPU 127 | obj.useGPU = true; 128 | 129 | % 将所有参数移至GPU 130 | fnames = fieldnames(obj.learnables); 131 | for i = 1:length(fnames) 132 | obj.learnables.(fnames{i}) = gpuArray(obj.learnables.(fnames{i})); 133 | end 134 | end 135 | 136 | function toCPU(obj) 137 | % 将网络参数转移回CPU 138 | obj.useGPU = false; 139 | 140 | % 将所有参数移回CPU 141 | fnames = fieldnames(obj.learnables); 142 | for i = 1:length(fnames) 143 | obj.learnables.(fnames{i}) = gather(obj.learnables.(fnames{i})); 144 | end 145 | end 146 | 147 | function params = getParameters(obj) 148 | % 获取网络参数 149 | params = struct2cell(obj.learnables); 150 | end 151 | 152 | function setParameters(obj, params) 153 | % 设置网络参数 154 | fnames = fieldnames(obj.learnables); 155 | for i = 1:length(fnames) 156 | obj.learnables.(fnames{i}) = params{i}; 157 | end 158 | end 159 | end 160 | end 161 | 162 | function W = initializeGlorot(numOut, numIn) 163 | % Glorot/Xavier初始化 164 | stddev = sqrt(2 / (numIn + numOut)); 165 | W = stddev * randn(numOut, numIn); 166 | end 167 | -------------------------------------------------------------------------------- /core/CriticNetwork.m: -------------------------------------------------------------------------------- 1 | classdef CriticNetwork < handle 2 | % CriticNetwork 价值网络 3 | % 实现了用于估计状态价值的网络 4 | 5 | properties 6 | layerSizes % 隐藏层大小 7 | learnables % 可学习参数 8 | useGPU % 是否使用GPU 9 | end 10 | 11 | methods 12 | function obj = CriticNetwork(inputSize, layerSizes) 13 | % 构造函数:初始化价值网络 14 | % inputSize - 输入维度(观察空间大小) 15 | % layerSizes - 隐藏层大小的数组 16 | 17 | obj.layerSizes = layerSizes; 18 | obj.useGPU = false; 19 | 20 | % 初始化网络参数 21 | obj.learnables = struct(); 22 | 23 | % 初始化第一个隐藏层 24 | obj.learnables.fc1w = dlarray(initializeGlorot(layerSizes(1), inputSize), 'U'); 25 | obj.learnables.fc1b = dlarray(zeros(layerSizes(1), 1), 'U'); 26 | 27 | % 初始化中间隐藏层 28 | for i = 2:length(layerSizes) 29 | obj.learnables.(sprintf('fc%dw', i)) = dlarray(initializeGlorot(layerSizes(i), layerSizes(i-1)), 'U'); 30 | obj.learnables.(sprintf('fc%db', i)) = dlarray(zeros(layerSizes(i), 1), 'U'); 31 | end 32 | 33 | % 初始化输出层(价值) 34 | obj.learnables.outw = dlarray(initializeGlorot(1, layerSizes(end)), 'U'); 35 | obj.learnables.outb = dlarray(zeros(1, 1), 'U'); 36 | end 37 | 38 | function value = getValue(obj, observation) 39 | % 获取状态价值估计 40 | % observation - 观察状态 41 | % value - 状态价值估计 42 | 43 | % 前向传播 44 | x = 
fullyconnect(observation, obj.learnables.fc1w, obj.learnables.fc1b); 45 | x = tanh(x); 46 | 47 | % 中间隐藏层 48 | for i = 2:length(obj.layerSizes) 49 | x = fullyconnect(x, obj.learnables.(sprintf('fc%dw', i)), obj.learnables.(sprintf('fc%db', i))); 50 | x = tanh(x); 51 | end 52 | 53 | % 输出层 54 | value = fullyconnect(x, obj.learnables.outw, obj.learnables.outb); 55 | end 56 | 57 | function toGPU(obj) 58 | % 将网络参数转移到GPU 59 | obj.useGPU = true; 60 | 61 | % 将所有参数移至GPU 62 | fnames = fieldnames(obj.learnables); 63 | for i = 1:length(fnames) 64 | obj.learnables.(fnames{i}) = gpuArray(obj.learnables.(fnames{i})); 65 | end 66 | end 67 | 68 | function toCPU(obj) 69 | % 将网络参数转移回CPU 70 | obj.useGPU = false; 71 | 72 | % 将所有参数移回CPU 73 | fnames = fieldnames(obj.learnables); 74 | for i = 1:length(fnames) 75 | obj.learnables.(fnames{i}) = gather(obj.learnables.(fnames{i})); 76 | end 77 | end 78 | 79 | function params = getParameters(obj) 80 | % 获取网络参数 81 | params = struct2cell(obj.learnables); 82 | end 83 | 84 | function setParameters(obj, params) 85 | % 设置网络参数 86 | fnames = fieldnames(obj.learnables); 87 | for i = 1:length(fnames) 88 | obj.learnables.(fnames{i}) = params{i}; 89 | end 90 | end 91 | end 92 | end 93 | 94 | function W = initializeGlorot(numOut, numIn) 95 | % Glorot/Xavier初始化 96 | stddev = sqrt(2 / (numIn + numOut)); 97 | W = stddev * randn(numOut, numIn); 98 | end 99 | -------------------------------------------------------------------------------- /core/DiscreteActorNetwork.m: -------------------------------------------------------------------------------- 1 | classdef DiscreteActorNetwork < handle 2 | % DiscreteActorNetwork 用于离散动作空间的策略网络 3 | % 实现了具有Categorical分布输出的策略网络 4 | 5 | properties 6 | layerSizes % 隐藏层大小 7 | learnables % 可学习参数 8 | useGPU % 是否使用GPU 9 | end 10 | 11 | methods 12 | function obj = DiscreteActorNetwork(inputSize, outputSize, layerSizes) 13 | % 构造函数:初始化离散动作空间的Actor网络 14 | % inputSize - 输入维度(观察空间大小) 15 | % outputSize - 输出维度(动作空间大小,即动作数量) 16 | % layerSizes - 隐藏层大小的数组 17 | 18 | obj.layerSizes = layerSizes; 19 | obj.useGPU = false; 20 | 21 | % 初始化网络参数 22 | obj.learnables = struct(); 23 | 24 | % 初始化第一个隐藏层 25 | layerIdx = 1; 26 | obj.learnables.fc1w = dlarray(initializeGlorot(layerSizes(1), inputSize), 'U'); 27 | obj.learnables.fc1b = dlarray(zeros(layerSizes(1), 1), 'U'); 28 | 29 | % 初始化中间隐藏层 30 | for i = 2:length(layerSizes) 31 | obj.learnables.(sprintf('fc%dw', i)) = dlarray(initializeGlorot(layerSizes(i), layerSizes(i-1)), 'U'); 32 | obj.learnables.(sprintf('fc%db', i)) = dlarray(zeros(layerSizes(i), 1), 'U'); 33 | end 34 | 35 | % 初始化输出层(logits) 36 | obj.learnables.outw = dlarray(initializeGlorot(outputSize, layerSizes(end)), 'U'); 37 | obj.learnables.outb = dlarray(zeros(outputSize, 1), 'U'); 38 | end 39 | 40 | function [action, logProb, probs] = sampleAction(obj, observation) 41 | % 采样一个动作并计算其对数概率 42 | % observation - 当前的观察状态 43 | % action - 采样的动作(one-hot编码) 44 | % logProb - 动作的对数概率 45 | % probs - 各动作的概率 46 | 47 | % 前向传播,获取动作概率 48 | logits = obj.forward(observation); 49 | probs = softmax(logits); 50 | 51 | % 从类别分布中采样 52 | cumProbs = cumsum(extractdata(probs), 1); 53 | 54 | if obj.useGPU 55 | r = dlarray(gpuArray(rand(1, size(probs, 2))), 'CB'); 56 | else 57 | r = dlarray(rand(1, size(probs, 2)), 'CB'); 58 | end 59 | 60 | % 初始化动作 61 | actionSize = size(probs, 1); 62 | action = zeros(size(probs), 'like', probs); 63 | 64 | % 根据采样结果设置one-hot动作 65 | for i = 1:size(probs, 2) 66 | for j = 1:actionSize 67 | if r(i) <= cumProbs(j, i) 68 | action(j, i) = 1; 69 | break; 70 | end 71 | end 
71 | end
72 | end 73 | 74 | % 计算对数概率 75 | logProb = log(sum(probs .* action, 1) + 1e-10); 76 | end 77 | 78 | function [logProb, entropy] = evaluateActions(obj, observation, action) 79 | % 评估动作的对数概率和熵 80 | % observation - 观察状态 81 | % action - 执行的动作(one-hot编码) 82 | % logProb - 动作的对数概率 83 | % entropy - 策略的熵 84 | 85 | % 前向传播,获取动作概率 86 | logits = obj.forward(observation); 87 | probs = softmax(logits); 88 | 89 | % 计算对数概率 90 | logProb = log(sum(probs .* action, 1) + 1e-10); 91 | 92 | % 计算熵 93 | entropy = -sum(probs .* log(probs + 1e-10), 1); 94 | end 95 | 96 | function logits = forward(obj, observation) 97 | % 前向传播,计算动作分布参数 98 | % observation - 观察状态 99 | % logits - 未经softmax的输出 100 | 101 | % 第一个隐藏层 102 | x = fullyconnect(observation, obj.learnables.fc1w, obj.learnables.fc1b); 103 | x = relu(x); 104 | 105 | % 中间隐藏层 106 | for i = 2:length(obj.layerSizes) 107 | x = fullyconnect(x, obj.learnables.(sprintf('fc%dw', i)), obj.learnables.(sprintf('fc%db', i))); 108 | x = relu(x); 109 | end 110 | 111 | % 输出层 112 | logits = fullyconnect(x, obj.learnables.outw, obj.learnables.outb); 113 | end 114 | 115 | function action = getBestAction(obj, observation) 116 | % 获取确定性动作(最高概率) 117 | % observation - 观察状态 118 | % action - 确定性动作(one-hot编码) 119 | 120 | % 前向传播,获取动作概率 121 | logits = obj.forward(observation); 122 | probs = softmax(logits); 123 | 124 | % 选择最高概率的动作 125 | [~, idx] = max(extractdata(probs), [], 1); 126 | 127 | % 创建one-hot动作 128 | action = zeros(size(probs), 'like', probs); 129 | for i = 1:size(probs, 2) 130 | action(idx(i), i) = 1; 131 | end 132 | end 133 | 134 | function toGPU(obj) 135 | % 将网络参数转移到GPU 136 | obj.useGPU = true; 137 | 138 | % 将所有参数移至GPU 139 | fnames = fieldnames(obj.learnables); 140 | for i = 1:length(fnames) 141 | obj.learnables.(fnames{i}) = gpuArray(obj.learnables.(fnames{i})); 142 | end 143 | end 144 | 145 | function toCPU(obj) 146 | % 将网络参数转移回CPU 147 | obj.useGPU = false; 148 | 149 | % 将所有参数移回CPU 150 | fnames = fieldnames(obj.learnables); 151 | for i = 1:length(fnames) 152 | obj.learnables.(fnames{i}) = gather(obj.learnables.(fnames{i})); 153 | end 154 | end 155 | 156 | function params = getParameters(obj) 157 | % 获取网络参数 158 | params = struct2cell(obj.learnables); 159 | end 160 | 161 | function setParameters(obj, params) 162 | % 设置网络参数 163 | fnames = fieldnames(obj.learnables); 164 | for i = 1:length(fnames) 165 | obj.learnables.(fnames{i}) = params{i}; 166 | end 167 | end 168 | end 169 | end 170 | 171 | function W = initializeGlorot(numOut, numIn) 172 | % Glorot/Xavier初始化 173 | stddev = sqrt(2 / (numIn + numOut)); 174 | W = stddev * randn(numOut, numIn); 175 | end 176 | -------------------------------------------------------------------------------- /core/PPOAgent.m: -------------------------------------------------------------------------------- 1 | classdef PPOAgent < handle 2 | % PPOAgent 基于近端策略优化(PPO)算法的强化学习代理 3 | % 这个类实现了PPO算法,支持GPU并行化训练 4 | 5 | properties 6 | % 环境相关 7 | envName % 环境名称 8 | env % 环境对象 9 | obsSize % 观察空间大小 10 | actionSize % 动作空间大小 11 | isDiscrete % 是否为离散动作空间 12 | 13 | % 策略网络 14 | actorNet % 策略网络(Actor) 15 | criticNet % 价值网络(Critic) 16 | actorOptimizer % Actor优化器 17 | criticOptimizer % Critic优化器 18 | 19 | % 超参数 20 | gamma % 折扣因子 21 | lambda % GAE参数 22 | epsilon % PPO裁剪参数 23 | entropyCoef % 熵正则化系数 24 | vfCoef % 价值函数系数 25 | maxGradNorm % 梯度裁剪 26 | 27 | % 训练参数 28 | batchSize % 批次大小 29 | epochsPerIter % 每次迭代的训练轮数 30 | trajectoryLen % 轨迹长度 31 | 32 | % GPU加速 33 | useGPU % 是否使用GPU 34 | gpuDevice % GPU设备 35 | 36 | % 记录与可视化 37 | logger % 日志记录器 38 | end 39 | 40 | methods 41 | 
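% Usage sketch (illustrative only, not part of the original scripts): a
% PPOAgent is typically driven by a training script roughly as follows,
% assuming a PPOConfig object has already been prepared:
%
%   config = PPOConfig();                  % defaults target CartPoleEnv
%   agent  = PPOAgent(config);             % builds env, actor/critic nets, optimizers
%   agent.train(config.numIterations);     % collect trajectories -> GAE -> PPO updates
%   stats  = agent.evaluate(config.numEvalEpisodes);
%   agent.saveModel(fullfile(config.logDir, 'final_model.mat'));  % example file name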
function obj = PPOAgent(config) 42 | % 构造函数:初始化PPO代理 43 | % config - 包含所有配置参数的结构体 44 | 45 | % 设置环境 46 | obj.envName = config.envName; 47 | obj.env = feval(config.envName); 48 | obj.obsSize = obj.env.observationSize; 49 | obj.actionSize = obj.env.actionSize; 50 | obj.isDiscrete = obj.env.isDiscrete; 51 | 52 | % 设置超参数 53 | obj.gamma = config.gamma; 54 | obj.lambda = config.lambda; 55 | obj.epsilon = config.epsilon; 56 | obj.entropyCoef = config.entropyCoef; 57 | obj.vfCoef = config.vfCoef; 58 | obj.maxGradNorm = config.maxGradNorm; 59 | 60 | % 设置训练参数 61 | obj.batchSize = config.batchSize; 62 | obj.epochsPerIter = config.epochsPerIter; 63 | obj.trajectoryLen = config.trajectoryLen; 64 | 65 | % 设置GPU加速 66 | obj.useGPU = config.useGPU; 67 | if obj.useGPU 68 | if gpuDeviceCount > 0 69 | obj.gpuDevice = gpuDevice(1); 70 | fprintf('使用GPU: %s\n', obj.gpuDevice.Name); 71 | else 72 | warning('未检测到GPU, 将使用CPU训练'); 73 | obj.useGPU = false; 74 | end 75 | end 76 | 77 | % 初始化策略网络与价值网络 78 | obj.initNetworks(config); 79 | 80 | % 初始化优化器 81 | obj.initOptimizers(config); 82 | 83 | % 初始化日志记录 84 | obj.logger = Logger(config.logDir, obj.envName); 85 | end 86 | 87 | function initNetworks(obj, config) 88 | % 初始化策略网络与价值网络 89 | % 策略网络(Actor) 90 | if obj.isDiscrete 91 | obj.actorNet = DiscreteActorNetwork(obj.obsSize, obj.actionSize, config.actorLayerSizes); 92 | else 93 | obj.actorNet = ContinuousActorNetwork(obj.obsSize, obj.actionSize, config.actorLayerSizes); 94 | end 95 | 96 | % 价值网络(Critic) 97 | obj.criticNet = CriticNetwork(obj.obsSize, config.criticLayerSizes); 98 | 99 | % 如果使用GPU,则将网络迁移到GPU 100 | if obj.useGPU 101 | obj.actorNet.toGPU(); 102 | obj.criticNet.toGPU(); 103 | end 104 | end 105 | 106 | function initOptimizers(obj, config) 107 | % 初始化优化器 108 | obj.actorOptimizer = dlupdate.sgdm(config.actorLearningRate, config.momentum); 109 | obj.criticOptimizer = dlupdate.sgdm(config.criticLearningRate, config.momentum); 110 | end 111 | 112 | function train(obj, numIterations) 113 | % 训练PPO代理 114 | % numIterations - 训练迭代次数 115 | 116 | fprintf('开始训练 %s 环境的PPO代理\n', obj.envName); 117 | 118 | for iter = 1:numIterations 119 | fprintf('迭代 %d/%d\n', iter, numIterations); 120 | 121 | % 收集轨迹数据 122 | fprintf('收集轨迹...\n'); 123 | trajectories = obj.collectTrajectories(); 124 | 125 | % 计算优势函数和回报 126 | fprintf('计算优势函数和回报...\n'); 127 | trajectories = obj.computeAdvantagesAndReturns(trajectories); 128 | 129 | % PPO更新 130 | fprintf('执行PPO更新...\n'); 131 | metrics = obj.updatePolicy(trajectories); 132 | 133 | % 记录本次迭代的性能指标 134 | obj.logger.logIteration(iter, metrics); 135 | 136 | % 每10次迭代保存模型 137 | if mod(iter, 10) == 0 138 | obj.saveModel(fullfile(obj.logger.logDir, ['model_iter_', num2str(iter), '.mat'])); 139 | end 140 | end 141 | 142 | fprintf('训练完成\n'); 143 | end 144 | 145 | function trajectories = collectTrajectories(obj) 146 | % 收集训练轨迹 147 | % 返回一个轨迹结构体数组 148 | 149 | numTrajectories = ceil(obj.batchSize / obj.trajectoryLen); 150 | trajectories(numTrajectories) = struct(); 151 | 152 | parfor i = 1:numTrajectories 153 | % 使用并行计算加速数据收集 154 | trajectories(i) = obj.collectSingleTrajectory(); 155 | end 156 | end 157 | 158 | function trajectory = collectSingleTrajectory(obj) 159 | % 收集单条轨迹 160 | 161 | % 初始化轨迹存储 162 | trajectory = struct(); 163 | trajectory.observations = cell(obj.trajectoryLen, 1); 164 | trajectory.actions = cell(obj.trajectoryLen, 1); 165 | trajectory.rewards = zeros(obj.trajectoryLen, 1); 166 | trajectory.dones = false(obj.trajectoryLen, 1); 167 | trajectory.values = zeros(obj.trajectoryLen, 1); 168 | 
trajectory.logProbs = zeros(obj.trajectoryLen, 1); 169 | 170 | % 重置环境 171 | obs = obj.env.reset(); 172 | 173 | % 逐步收集数据 174 | for t = 1:obj.trajectoryLen 175 | % 转换为dlarray并根据需要迁移到GPU 176 | if obj.useGPU 177 | dlObs = dlarray(single(obs), 'CB'); 178 | dlObs = gpuArray(dlObs); 179 | else 180 | dlObs = dlarray(single(obs), 'CB'); 181 | end 182 | 183 | % 通过策略网络获取动作 184 | [action, logProb] = obj.actorNet.sampleAction(dlObs); 185 | 186 | % 获取价值估计 187 | value = obj.criticNet.getValue(dlObs); 188 | 189 | % 转换为CPU并提取数值 190 | if obj.useGPU 191 | action = gather(extractdata(action)); 192 | logProb = gather(extractdata(logProb)); 193 | value = gather(extractdata(value)); 194 | else 195 | action = extractdata(action); 196 | logProb = extractdata(logProb); 197 | value = extractdata(value); 198 | end 199 | 200 | % 在环境中执行动作 201 | [nextObs, reward, done, ~] = obj.env.step(action); 202 | 203 | % 存储当前步骤数据 204 | trajectory.observations{t} = obs; 205 | trajectory.actions{t} = action; 206 | trajectory.rewards(t) = reward; 207 | trajectory.dones(t) = done; 208 | trajectory.values(t) = value; 209 | trajectory.logProbs(t) = logProb; 210 | 211 | % 更新观察 212 | obs = nextObs; 213 | 214 | % 如果回合结束,重置环境 215 | if done && t < obj.trajectoryLen 216 | obs = obj.env.reset(); 217 | end 218 | end 219 | end 220 | 221 | function trajectories = computeAdvantagesAndReturns(obj, trajectories) 222 | % 计算广义优势估计(GAE)和折扣回报 223 | 224 | numTrajectories = length(trajectories); 225 | 226 | % 使用并行计算加速 227 | parfor i = 1:numTrajectories 228 | trajectory = trajectories(i); 229 | T = obj.trajectoryLen; 230 | 231 | % 初始化优势函数和回报 232 | advantages = zeros(T, 1); 233 | returns = zeros(T, 1); 234 | 235 | % 获取最终观察的价值估计 236 | lastObs = trajectory.observations{T}; 237 | if ~trajectory.dones(T) 238 | if obj.useGPU 239 | dlObs = dlarray(single(lastObs), 'CB'); 240 | dlObs = gpuArray(dlObs); 241 | lastValue = gather(extractdata(obj.criticNet.getValue(dlObs))); 242 | else 243 | dlObs = dlarray(single(lastObs), 'CB'); 244 | lastValue = extractdata(obj.criticNet.getValue(dlObs)); 245 | end 246 | else 247 | lastValue = 0; 248 | end 249 | 250 | % GAE计算 251 | gae = 0; 252 | for t = T:-1:1 253 | if t == T 254 | nextValue = lastValue; 255 | nextNonTerminal = 1 - trajectory.dones(T); 256 | else 257 | nextValue = trajectory.values(t+1); 258 | nextNonTerminal = 1 - trajectory.dones(t); 259 | end 260 | 261 | delta = trajectory.rewards(t) + obj.gamma * nextValue * nextNonTerminal - trajectory.values(t); 262 | gae = delta + obj.gamma * obj.lambda * nextNonTerminal * gae; 263 | advantages(t) = gae; 264 | returns(t) = advantages(t) + trajectory.values(t); 265 | end 266 | 267 | % 存储计算结果 268 | trajectories(i).advantages = advantages; 269 | trajectories(i).returns = returns; 270 | end 271 | end 272 | 273 | function metrics = updatePolicy(obj, trajectories) 274 | % 使用PPO算法更新策略 275 | 276 | % 合并所有轨迹数据 277 | [observations, actions, oldLogProbs, returns, advantages] = obj.prepareTrainingData(trajectories); 278 | 279 | % 标准化优势函数 280 | advantages = (advantages - mean(advantages)) / (std(advantages) + 1e-8); 281 | 282 | % 数据集大小 283 | datasetSize = size(observations, 2); 284 | 285 | % 初始化指标 286 | actorLosses = []; 287 | criticLosses = []; 288 | entropyLosses = []; 289 | totalLosses = []; 290 | 291 | % 多轮训练 292 | for epoch = 1:obj.epochsPerIter 293 | % 随机打乱数据 294 | idx = randperm(datasetSize); 295 | 296 | % 分批处理 297 | for i = 1:obj.batchSize:datasetSize 298 | endIdx = min(i + obj.batchSize - 1, datasetSize); 299 | batchIdx = idx(i:endIdx); 300 | 301 | % 提取小批量数据 302 | batchObs = 
observations(:, batchIdx); 303 | batchActions = actions(:, batchIdx); 304 | batchOldLogProbs = oldLogProbs(batchIdx); 305 | batchReturns = returns(batchIdx); 306 | batchAdvantages = advantages(batchIdx); 307 | 308 | % 转换为dlarray并根据需要迁移到GPU 309 | if obj.useGPU 310 | dlObs = dlarray(batchObs, 'CB'); 311 | dlObs = gpuArray(dlObs); 312 | 313 | dlActions = dlarray(batchActions, 'CB'); 314 | dlActions = gpuArray(dlActions); 315 | 316 | dlOldLogProbs = dlarray(batchOldLogProbs, 'CB'); 317 | dlOldLogProbs = gpuArray(dlOldLogProbs); 318 | 319 | dlReturns = dlarray(batchReturns, 'CB'); 320 | dlReturns = gpuArray(dlReturns); 321 | 322 | dlAdvantages = dlarray(batchAdvantages, 'CB'); 323 | dlAdvantages = gpuArray(dlAdvantages); 324 | else 325 | dlObs = dlarray(batchObs, 'CB'); 326 | dlActions = dlarray(batchActions, 'CB'); 327 | dlOldLogProbs = dlarray(batchOldLogProbs, 'CB'); 328 | dlReturns = dlarray(batchReturns, 'CB'); 329 | dlAdvantages = dlarray(batchAdvantages, 'CB'); 330 | end 331 | 332 | % 计算梯度并更新策略和价值网络 333 | [actorGradients, criticGradients, actorLoss, criticLoss, entropyLoss, totalLoss] = ... 334 | dlfeval(@obj.computeGradients, dlObs, dlActions, dlOldLogProbs, dlReturns, dlAdvantages); 335 | 336 | % 更新Actor网络 337 | obj.actorNet.learnables = obj.actorOptimizer.updateLearnables(obj.actorNet.learnables, actorGradients, obj.maxGradNorm); 338 | 339 | % 更新Critic网络 340 | obj.criticNet.learnables = obj.criticOptimizer.updateLearnables(obj.criticNet.learnables, criticGradients, obj.maxGradNorm); 341 | 342 | % 记录损失值 343 | actorLosses(end+1) = actorLoss; 344 | criticLosses(end+1) = criticLoss; 345 | entropyLosses(end+1) = entropyLoss; 346 | totalLosses(end+1) = totalLoss; 347 | end 348 | end 349 | 350 | % 返回本次更新的指标 351 | metrics = struct(); 352 | metrics.actorLoss = mean(actorLosses); 353 | metrics.criticLoss = mean(criticLosses); 354 | metrics.entropyLoss = mean(entropyLosses); 355 | metrics.totalLoss = mean(totalLosses); 356 | end 357 | 358 | function [actorGradients, criticGradients, actorLoss, criticLoss, entropyLoss, totalLoss] = ... 359 | computeGradients(obj, observations, actions, oldLogProbs, returns, advantages) 360 | % 计算PPO损失和梯度 361 | 362 | % 前向传播 363 | [actorLossValue, criticLossValue, entropyLossValue, totalLossValue] = ... 364 | obj.computeLosses(observations, actions, oldLogProbs, returns, advantages); 365 | 366 | % 计算梯度 367 | [actorGradients, criticGradients] = dlgradient(totalLossValue, obj.actorNet.learnables, obj.criticNet.learnables); 368 | 369 | % 返回标量损失值 370 | actorLoss = extractdata(actorLossValue); 371 | criticLoss = extractdata(criticLossValue); 372 | entropyLoss = extractdata(entropyLossValue); 373 | totalLoss = extractdata(totalLossValue); 374 | end 375 | 376 | function [actorLoss, criticLoss, entropyLoss, totalLoss] = ... 
377 | computeLosses(obj, observations, actions, oldLogProbs, returns, advantages) 378 | % 计算PPO损失函数 379 | 380 | % 获取当前策略的动作概率和价值 381 | [logProbs, entropy] = obj.actorNet.evaluateActions(observations, actions); 382 | values = obj.criticNet.getValue(observations); 383 | 384 | % 计算比率 385 | ratio = exp(logProbs - oldLogProbs); 386 | 387 | % 裁剪比率 388 | clippedRatio = min(max(ratio, 1 - obj.epsilon), 1 + obj.epsilon); 389 | 390 | % Actor损失 (策略损失) 391 | surrogateLoss1 = ratio .* advantages; 392 | surrogateLoss2 = clippedRatio .* advantages; 393 | actorLoss = -mean(min(surrogateLoss1, surrogateLoss2)); 394 | 395 | % Critic损失 (价值损失) 396 | criticLoss = mean((values - returns).^2); 397 | 398 | % 熵损失 (用于鼓励探索) 399 | entropyLoss = -mean(entropy); 400 | 401 | % 总损失 402 | totalLoss = actorLoss + obj.vfCoef * criticLoss + obj.entropyCoef * entropyLoss; 403 | end 404 | 405 | function [observations, actions, logProbs, returns, advantages] = prepareTrainingData(obj, trajectories) 406 | % 将轨迹数据转换为训练所需的格式 407 | 408 | numTrajectories = length(trajectories); 409 | totalSteps = numTrajectories * obj.trajectoryLen; 410 | 411 | % 预分配存储空间 412 | observations = zeros(obj.obsSize, totalSteps); 413 | actions = zeros(obj.actionSize, totalSteps); 414 | logProbs = zeros(totalSteps, 1); 415 | returns = zeros(totalSteps, 1); 416 | advantages = zeros(totalSteps, 1); 417 | 418 | % 填充数据 419 | stepIdx = 1; 420 | for i = 1:numTrajectories 421 | for t = 1:obj.trajectoryLen 422 | observations(:, stepIdx) = trajectories(i).observations{t}; 423 | actions(:, stepIdx) = trajectories(i).actions{t}; 424 | logProbs(stepIdx) = trajectories(i).logProbs(t); 425 | returns(stepIdx) = trajectories(i).returns(t); 426 | advantages(stepIdx) = trajectories(i).advantages(t); 427 | 428 | stepIdx = stepIdx + 1; 429 | end 430 | end 431 | end 432 | 433 | function saveModel(obj, filePath) 434 | % 保存模型 435 | actorParams = obj.actorNet.getParameters(); 436 | criticParams = obj.criticNet.getParameters(); 437 | 438 | % 如果参数在GPU上,移回CPU 439 | if obj.useGPU 440 | actorParams = cellfun(@gather, actorParams, 'UniformOutput', false); 441 | criticParams = cellfun(@gather, criticParams, 'UniformOutput', false); 442 | end 443 | 444 | % 保存模型参数和配置 445 | save(filePath, 'actorParams', 'criticParams'); 446 | fprintf('模型已保存到: %s\n', filePath); 447 | end 448 | 449 | function loadModel(obj, filePath) 450 | % 加载模型 451 | load(filePath, 'actorParams', 'criticParams'); 452 | 453 | % 设置模型参数 454 | if obj.useGPU 455 | actorParams = cellfun(@gpuArray, actorParams, 'UniformOutput', false); 456 | criticParams = cellfun(@gpuArray, criticParams, 'UniformOutput', false); 457 | end 458 | 459 | obj.actorNet.setParameters(actorParams); 460 | obj.criticNet.setParameters(criticParams); 461 | 462 | fprintf('模型已加载自: %s\n', filePath); 463 | end 464 | 465 | function result = evaluate(obj, numEpisodes) 466 | % 评估训练好的代理 467 | % numEpisodes - 评估的回合数 468 | 469 | returns = zeros(numEpisodes, 1); 470 | lengths = zeros(numEpisodes, 1); 471 | 472 | for i = 1:numEpisodes 473 | obs = obj.env.reset(); 474 | done = false; 475 | episodeReturn = 0; 476 | episodeLength = 0; 477 | 478 | while ~done 479 | % 转换为dlarray并根据需要迁移到GPU 480 | if obj.useGPU 481 | dlObs = dlarray(single(obs), 'CB'); 482 | dlObs = gpuArray(dlObs); 483 | else 484 | dlObs = dlarray(single(obs), 'CB'); 485 | end 486 | 487 | % 采样动作 488 | [action, ~] = obj.actorNet.sampleAction(dlObs); 489 | 490 | % 转换为CPU并提取数值 491 | if obj.useGPU 492 | action = gather(extractdata(action)); 493 | else 494 | action = extractdata(action); 495 | end 496 | 497 | % 
执行动作 498 | [obs, reward, done, ~] = obj.env.step(action); 499 | 500 | episodeReturn = episodeReturn + reward; 501 | episodeLength = episodeLength + 1; 502 | end 503 | 504 | returns(i) = episodeReturn; 505 | lengths(i) = episodeLength; 506 | end 507 | 508 | % 统计返回值 509 | result = struct(); 510 | result.meanReturn = mean(returns); 511 | result.stdReturn = std(returns); 512 | result.minReturn = min(returns); 513 | result.maxReturn = max(returns); 514 | result.meanLength = mean(lengths); 515 | end 516 | end 517 | end 518 | -------------------------------------------------------------------------------- /environments/ACMotorEnv.m: -------------------------------------------------------------------------------- 1 | classdef ACMotorEnv < Environment 2 | % ACMotorEnv 交流感应电机控制环境 3 | % 使用FOC(磁场定向控制)方法模拟三相异步电机的控制系统 4 | 5 | properties 6 | % 环境规格 7 | observationSize = 6 % 状态空间维度:[速度误差, id, iq, d轴误差, q轴误差, 负载转矩估计] 8 | actionSize = 2 % 动作空间维度:[Vd, Vq] - FOC控制的d轴和q轴电压 9 | isDiscrete = false % 连续动作空间 10 | 11 | % 电机物理参数 12 | Rs = 2.0 % 定子电阻 (ohm) 13 | Rr = 2.0 % 转子电阻 (ohm) 14 | Ls = 0.2 % 定子自感 (H) 15 | Lr = 0.2 % 转子自感 (H) 16 | Lm = 0.15 % 互感 (H) 17 | J = 0.02 % 转动惯量 (kg.m^2) 18 | p = 2 % 极对数 19 | B = 0.005 % 阻尼系数 (N.m.s) 20 | 21 | % 系统参数 22 | dt = 0.001 % 时间步长 (s),FOC控制需要较小的时间步长 23 | fs = 1000 % 采样频率 (Hz) 24 | maxVoltage = 400.0 % 最大允许电压 (V) 25 | maxCurrent = 15.0 % 最大允许电流 (A) 26 | maxSpeed = 157.0 % 最大允许角速度 (rad/s) ~ 1500rpm 27 | nominalSpeed = 150 % 额定角速度 (rad/s) 28 | nominalTorque = 10 % 额定转矩 (N.m) 29 | 30 | % 目标 31 | targetSpeed % 目标速度 (rad/s) 32 | 33 | % 当前状态 34 | speed % 实际转速 (rad/s) 35 | position % 实际位置 (rad) 36 | id % d轴电流 (A) 37 | iq % q轴电流 (A) 38 | psi_d % d轴磁链 (Wb) 39 | psi_q % q轴磁链 (Wb) 40 | Te % 电磁转矩 (N.m) 41 | Tl % 负载转矩 (N.m) 42 | 43 | % 控制参数(磁场定向控制FOC需要) 44 | id_ref % d轴电流参考值 (A), 通常为常数 45 | iq_ref % q轴电流参考值 (A), 由速度控制环生成 46 | 47 | % 回合信息 48 | steps = 0 % 当前回合步数 49 | maxSteps = 5000 % 最大步数(由于采样频率高,需要更多步数) 50 | 51 | % 速度控制PI参数 52 | Kp_speed = 0.5 53 | Ki_speed = 5.0 54 | speed_error_integral = 0 55 | 56 | % 负载模拟 57 | loadChangeTime % 负载突变的时间 58 | loadProfile % 负载随时间的变化 59 | 60 | % 可视化 61 | renderFig % 图形句柄 62 | renderAx % 坐标轴句柄 63 | plotHandles % 图形元素句柄 64 | 65 | % 数据记录 66 | historyData % 用于记录系统状态的历史数据 67 | end 68 | 69 | methods 70 | function obj = ACMotorEnv() 71 | % 构造函数:初始化交流电机环境 72 | 73 | % 初始化状态 74 | obj.speed = 0; 75 | obj.position = 0; 76 | obj.id = 0; 77 | obj.iq = 0; 78 | obj.psi_d = 0; 79 | obj.psi_q = 0; 80 | obj.Te = 0; 81 | obj.Tl = 0; 82 | 83 | % 设置d轴电流参考值(磁通控制) 84 | obj.id_ref = 3.0; % 常量,用于建立磁场 85 | 86 | % 随机目标速度 87 | obj.resetTarget(); 88 | 89 | % 初始化负载变化 90 | obj.initializeLoadProfile(); 91 | 92 | % 初始化历史数据记录 93 | obj.historyData = struct(... 94 | 'time', [], ... 95 | 'speed', [], ... 96 | 'targetSpeed', [], ... 97 | 'id', [], ... 98 | 'iq', [], ... 99 | 'Te', [], ... 100 | 'Tl', [], ... 101 | 'Vd', [], ... 102 | 'Vq', [] ... 
103 | ); 104 | 105 | % 随机种子 106 | rng('shuffle'); 107 | end 108 | 109 | function resetTarget(obj) 110 | % 设置新的随机目标速度 111 | obj.targetSpeed = (0.5 + 0.5 * rand()) * obj.nominalSpeed; % 50%-100%额定速度 112 | end 113 | 114 | function initializeLoadProfile(obj) 115 | % 初始化负载变化,模拟工业场景的负载突变 116 | % 设置负载突变的时间点 117 | obj.loadChangeTime = [1000, 2000, 3000, 4000]; % 在这些步骤改变负载 118 | 119 | % 设置负载变化的配置文件 120 | obj.loadProfile = [ 121 | 0.2 * obj.nominalTorque; % 初始负载 - 轻载 122 | 0.8 * obj.nominalTorque; % 第一次变化 - 重载 123 | 0.4 * obj.nominalTorque; % 第二次变化 - 中等负载 124 | 0.9 * obj.nominalTorque; % 第三次变化 - 接近满载 125 | 0.3 * obj.nominalTorque; % 最终负载 - 轻载 126 | ]; 127 | end 128 | 129 | function observation = reset(obj) 130 | % 重置环境到初始状态 131 | % 返回初始观察 132 | 133 | % 重置电机状态 134 | obj.speed = 0; 135 | obj.position = 0; 136 | obj.id = 0; 137 | obj.iq = 0; 138 | obj.psi_d = 0; 139 | obj.psi_q = 0; 140 | obj.Te = 0; 141 | obj.Tl = obj.loadProfile(1); % 初始负载 142 | 143 | % 重置控制变量 144 | obj.iq_ref = 0; 145 | obj.speed_error_integral = 0; 146 | 147 | % 重置步数 148 | obj.steps = 0; 149 | 150 | % 重置目标速度 151 | obj.resetTarget(); 152 | 153 | % 重置历史数据 154 | obj.historyData = struct(... 155 | 'time', [], ... 156 | 'speed', [], ... 157 | 'targetSpeed', [], ... 158 | 'id', [], ... 159 | 'iq', [], ... 160 | 'Te', [], ... 161 | 'Tl', [], ... 162 | 'Vd', [], ... 163 | 'Vq', [] ... 164 | ); 165 | 166 | % 返回观察 167 | observation = obj.getObservation(); 168 | end 169 | 170 | function observation = getObservation(obj) 171 | % 获取当前观察(状态处理和归一化) 172 | 173 | % 计算速度误差 174 | speed_error = obj.targetSpeed - obj.speed; 175 | 176 | % 计算电流误差 177 | id_error = obj.id_ref - obj.id; 178 | iq_error = obj.iq_ref - obj.iq; 179 | 180 | % 组合观察(归一化) 181 | observation = [ 182 | speed_error / obj.maxSpeed; % 归一化速度误差 183 | obj.id / obj.maxCurrent; % 归一化d轴电流 184 | obj.iq / obj.maxCurrent; % 归一化q轴电流 185 | id_error / obj.maxCurrent; % 归一化d轴电流误差 186 | iq_error / obj.maxCurrent; % 归一化q轴电流误差 187 | obj.Tl / obj.nominalTorque; % 归一化负载转矩 188 | ]; 189 | end 190 | 191 | function [nextObs, reward, done, info] = step(obj, action) 192 | % 执行动作并返回新的状态 193 | % action - 要执行的动作(d轴和q轴电压,范围[-1, 1]) 194 | % 返回值: 195 | % nextObs - 新的观察 196 | % reward - 获得的奖励 197 | % done - 是否回合结束 198 | % info - 附加信息 199 | 200 | % 将动作范围限制在[-1, 1] 201 | action = max(-1, min(1, action)); 202 | 203 | % 将动作转换为d轴和q轴电压 204 | Vd = action(1) * obj.maxVoltage; 205 | Vq = action(2) * obj.maxVoltage; 206 | 207 | % 更新负载(模拟负载变化) 208 | obj.updateLoad(); 209 | 210 | % 更新PI控制器计算iq_ref(速度环) 211 | % 在实际系统中这通常由PI控制器完成,这里我们让PPO学习这个映射 212 | speed_error = obj.targetSpeed - obj.speed; 213 | obj.speed_error_integral = obj.speed_error_integral + speed_error * obj.dt; 214 | obj.iq_ref = obj.Kp_speed * speed_error + obj.Ki_speed * obj.speed_error_integral; 215 | obj.iq_ref = max(-obj.maxCurrent, min(obj.maxCurrent, obj.iq_ref)); 216 | 217 | % 解包当前状态 218 | id = obj.id; 219 | iq = obj.iq; 220 | psi_d = obj.psi_d; 221 | psi_q = obj.psi_q; 222 | speed = obj.speed; 223 | Tl = obj.Tl; 224 | 225 | % 计算电磁转矩 226 | Te = 1.5 * obj.p * obj.Lm / obj.Lr * (iq * psi_d - id * psi_q); 227 | 228 | % 机械系统方程 - 转速和位置更新 229 | % dω/dt = (Te - Tl - B*ω) / J 230 | speed_new = speed + obj.dt * ((Te - Tl - obj.B * speed) / obj.J); 231 | position_new = obj.position + obj.dt * speed_new; 232 | 233 | % 电气系统方程 - FOC模型 234 | % 模拟交流电机FOC控制下的动态模型 235 | sigma = 1 - obj.Lm^2 / (obj.Ls * obj.Lr); 236 | Tr = obj.Lr / obj.Rr; % 转子时间常数 237 | 238 | % 转子磁通动态方程 239 | psi_d_new = psi_d + obj.dt * ((-psi_d + obj.Lm * id) / Tr - obj.p * speed * psi_q); 240 | psi_q_new = 
psi_q + obj.dt * ((-psi_q + obj.Lm * iq) / Tr + obj.p * speed * psi_d); 241 | 242 | % 定子电流动态方程 243 | id_new = id + obj.dt * ((Vd - obj.Rs * id + obj.p * speed * sigma * obj.Ls * iq) / (sigma * obj.Ls)); 244 | iq_new = iq + obj.dt * ((Vq - obj.Rs * iq - obj.p * speed * (sigma * obj.Ls * id + obj.Lm/obj.Lr * psi_d)) / (sigma * obj.Ls)); 245 | 246 | % 限制电流 247 | id_new = max(-obj.maxCurrent, min(obj.maxCurrent, id_new)); 248 | iq_new = max(-obj.maxCurrent, min(obj.maxCurrent, iq_new)); 249 | 250 | % 更新状态 251 | obj.id = id_new; 252 | obj.iq = iq_new; 253 | obj.psi_d = psi_d_new; 254 | obj.psi_q = psi_q_new; 255 | obj.speed = speed_new; 256 | obj.position = position_new; 257 | obj.Te = Te; 258 | 259 | % 更新步数 260 | obj.steps = obj.steps + 1; 261 | 262 | % 记录历史数据 263 | obj.updateHistory(Vd, Vq); 264 | 265 | % 计算奖励 266 | reward = obj.calculateReward(speed_error, id_new, iq_new, Te, Vd, Vq); 267 | 268 | % 判断是否结束 269 | done = obj.steps >= obj.maxSteps; 270 | 271 | % 返回结果 272 | nextObs = obj.getObservation(); 273 | info = struct(... 274 | 'steps', obj.steps, ... 275 | 'targetSpeed', obj.targetSpeed, ... 276 | 'speed', obj.speed, ... 277 | 'Te', obj.Te, ... 278 | 'Tl', obj.Tl, ... 279 | 'id', obj.id, ... 280 | 'iq', obj.iq ... 281 | ); 282 | end 283 | 284 | function updateLoad(obj) 285 | % 更新负载转矩(模拟负载变化) 286 | for i = 1:length(obj.loadChangeTime) 287 | if obj.steps == obj.loadChangeTime(i) 288 | obj.Tl = obj.loadProfile(i+1); 289 | break; 290 | end 291 | end 292 | end 293 | 294 | function reward = calculateReward(obj, speed_error, id, iq, Te, Vd, Vq) 295 | % 计算奖励函数 296 | 297 | % 速度误差奖励(负的平方误差) 298 | speed_reward = -0.5 * (speed_error/obj.maxSpeed)^2; 299 | 300 | % d轴电流跟踪奖励(保持磁通) 301 | id_error = obj.id_ref - id; 302 | id_reward = -0.3 * (id_error/obj.maxCurrent)^2; 303 | 304 | % q轴电流奖励 - 转矩生成 305 | iq_reward = -0.3 * ((obj.iq_ref - iq)/obj.maxCurrent)^2; 306 | 307 | % 能量效率奖励(避免过高电压和电流) 308 | power_reward = -0.05 * ((Vd/obj.maxVoltage)^2 + (Vq/obj.maxVoltage)^2); 309 | 310 | % 综合奖励函数 311 | reward = speed_reward + id_reward + iq_reward + power_reward; 312 | end 313 | 314 | function updateHistory(obj, Vd, Vq) 315 | % 更新历史数据,用于绘图 316 | obj.historyData.time(end+1) = obj.steps * obj.dt; 317 | obj.historyData.speed(end+1) = obj.speed; 318 | obj.historyData.targetSpeed(end+1) = obj.targetSpeed; 319 | obj.historyData.id(end+1) = obj.id; 320 | obj.historyData.iq(end+1) = obj.iq; 321 | obj.historyData.Te(end+1) = obj.Te; 322 | obj.historyData.Tl(end+1) = obj.Tl; 323 | obj.historyData.Vd(end+1) = Vd; 324 | obj.historyData.Vq(end+1) = Vq; 325 | end 326 | 327 | function render(obj) 328 | % 渲染当前环境状态 329 | 330 | % 如果没有图形,创建一个 331 | if isempty(obj.renderFig) || ~isvalid(obj.renderFig) 332 | obj.renderFig = figure('Name', '交流感应电机FOC控制', 'Position', [100, 100, 1200, 800]); 333 | 334 | % 创建四个子图 335 | % 1. 速度响应 336 | subplot(2, 2, 1); 337 | hold on; 338 | speedPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 339 | targetSpeedPlot = plot(0, 0, 'r--', 'LineWidth', 1.5); 340 | title('速度响应'); 341 | xlabel('时间 (s)'); 342 | ylabel('速度 (rad/s)'); 343 | legend('实际速度', '目标速度'); 344 | grid on; 345 | 346 | % 2. 电流响应 347 | subplot(2, 2, 2); 348 | hold on; 349 | idPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 350 | iqPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 351 | idRefPlot = plot(0, 0, 'b--', 'LineWidth', 1); 352 | iqRefPlot = plot(0, 0, 'r--', 'LineWidth', 1); 353 | title('d-q轴电流'); 354 | xlabel('时间 (s)'); 355 | ylabel('电流 (A)'); 356 | legend('id', 'iq', 'id-ref', 'iq-ref'); 357 | grid on; 358 | 359 | % 3. 
电压控制输入 360 | subplot(2, 2, 3); 361 | hold on; 362 | vdPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 363 | vqPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 364 | title('d-q轴电压'); 365 | xlabel('时间 (s)'); 366 | ylabel('电压 (V)'); 367 | legend('Vd', 'Vq'); 368 | grid on; 369 | 370 | % 4. 转矩和负载 371 | subplot(2, 2, 4); 372 | hold on; 373 | tePlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 374 | tlPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 375 | title('转矩'); 376 | xlabel('时间 (s)'); 377 | ylabel('转矩 (N.m)'); 378 | legend('电磁转矩', '负载转矩'); 379 | grid on; 380 | 381 | % 存储图形句柄 382 | obj.plotHandles = struct(... 383 | 'speedPlot', speedPlot, ... 384 | 'targetSpeedPlot', targetSpeedPlot, ... 385 | 'idPlot', idPlot, ... 386 | 'iqPlot', iqPlot, ... 387 | 'idRefPlot', idRefPlot, ... 388 | 'iqRefPlot', iqRefPlot, ... 389 | 'vdPlot', vdPlot, ... 390 | 'vqPlot', vqPlot, ... 391 | 'tePlot', tePlot, ... 392 | 'tlPlot', tlPlot ... 393 | ); 394 | end 395 | 396 | % 更新图形 397 | % 获取历史数据 398 | time = obj.historyData.time; 399 | 400 | % 只显示最后windowSize个点,避免图形过于拥挤 401 | windowSize = 1000; % 显示1秒的数据 402 | if length(time) > windowSize 403 | startIdx = length(time) - windowSize + 1; 404 | else 405 | startIdx = 1; 406 | end 407 | 408 | % 更新速度响应图 409 | obj.plotHandles.speedPlot.XData = time(startIdx:end); 410 | obj.plotHandles.speedPlot.YData = obj.historyData.speed(startIdx:end); 411 | obj.plotHandles.targetSpeedPlot.XData = time(startIdx:end); 412 | obj.plotHandles.targetSpeedPlot.YData = obj.historyData.targetSpeed(startIdx:end); 413 | 414 | % 更新电流图 415 | obj.plotHandles.idPlot.XData = time(startIdx:end); 416 | obj.plotHandles.idPlot.YData = obj.historyData.id(startIdx:end); 417 | obj.plotHandles.iqPlot.XData = time(startIdx:end); 418 | obj.plotHandles.iqPlot.YData = obj.historyData.iq(startIdx:end); 419 | 420 | % 更新参考电流线 421 | obj.plotHandles.idRefPlot.XData = time(startIdx:end); 422 | obj.plotHandles.idRefPlot.YData = ones(size(time(startIdx:end))) * obj.id_ref; 423 | obj.plotHandles.iqRefPlot.XData = time(startIdx:end); 424 | obj.plotHandles.iqRefPlot.YData = ones(size(time(startIdx:end))) * obj.iq_ref; 425 | 426 | % 更新电压图 427 | obj.plotHandles.vdPlot.XData = time(startIdx:end); 428 | obj.plotHandles.vdPlot.YData = obj.historyData.Vd(startIdx:end); 429 | obj.plotHandles.vqPlot.XData = time(startIdx:end); 430 | obj.plotHandles.vqPlot.YData = obj.historyData.Vq(startIdx:end); 431 | 432 | % 更新转矩图 433 | obj.plotHandles.tePlot.XData = time(startIdx:end); 434 | obj.plotHandles.tePlot.YData = obj.historyData.Te(startIdx:end); 435 | obj.plotHandles.tlPlot.XData = time(startIdx:end); 436 | obj.plotHandles.tlPlot.YData = obj.historyData.Tl(startIdx:end); 437 | 438 | % 调整所有子图的X轴范围 439 | for i = 1:4 440 | subplot(2, 2, i); 441 | if ~isempty(time) 442 | if length(time) > windowSize 443 | xlim([time(end) - (windowSize-1)*obj.dt, time(end)]); 444 | else 445 | xlim([0, max(time(end), 0.1)]); 446 | end 447 | end 448 | end 449 | 450 | % 速度图Y轴 451 | subplot(2, 2, 1); 452 | ylim([0, max(max(obj.historyData.targetSpeed) * 1.2, obj.maxSpeed * 0.5)]); 453 | 454 | % 电流图Y轴 455 | subplot(2, 2, 2); 456 | maxCurrent = max(max(abs(obj.historyData.id)), max(abs(obj.historyData.iq))); 457 | ylim([-maxCurrent * 1.2, maxCurrent * 1.2]); 458 | 459 | % 电压图Y轴 460 | subplot(2, 2, 3); 461 | maxVoltage = max(max(abs(obj.historyData.Vd)), max(abs(obj.historyData.Vq))); 462 | ylim([-maxVoltage * 1.2, maxVoltage * 1.2]); 463 | 464 | % 转矩图Y轴 465 | subplot(2, 2, 4); 466 | maxTorque = max(max(obj.historyData.Te), max(obj.historyData.Tl)); 467 | ylim([0, maxTorque * 1.5]); 
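% (the axis limits above assume render() is called only after at least one
% step(), so that historyData already holds data when max() is evaluated)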
468 | 469 | % 刷新图形 470 | drawnow; 471 | end 472 | end 473 | end 474 | -------------------------------------------------------------------------------- /environments/CartPoleEnv.m: -------------------------------------------------------------------------------- 1 | classdef CartPoleEnv < Environment 2 | % CartPoleEnv 倒立摆环境 3 | % 基于经典控制问题的倒立摆环境实现 4 | 5 | properties 6 | % 环境规格 7 | observationSize = 4 % 状态空间维度:[x, x_dot, theta, theta_dot] 8 | actionSize = 1 % 动作空间维度:左右力 9 | isDiscrete = true % 是否为离散动作空间 10 | 11 | % 物理参数 12 | gravity = 9.8 % 重力加速度 (m/s^2) 13 | massCart = 1.0 % 小车质量 (kg) 14 | massPole = 0.1 % 杆质量 (kg) 15 | totalMass % 总质量 (kg) 16 | length = 0.5 % 杆长的一半 (m) 17 | poleMassLength % 杆质量长度 18 | forceMag = 10.0 % 施加在小车上的力大小 (N) 19 | tau = 0.02 % 时间步长 (s) 20 | 21 | % 界限 22 | xThreshold = 2.4 % 小车位置阈值 23 | thetaThreshold = 12 % 杆角度阈值 (度) 24 | 25 | % 当前状态 26 | state % [x, x_dot, theta, theta_dot] 27 | 28 | % 回合信息 29 | steps = 0 % 当前回合步数 30 | maxSteps = 500 % 最大步数 31 | 32 | % 可视化 33 | renderFig % 图形句柄 34 | renderAx % 坐标轴句柄 35 | cartWidth = 0.5 % 小车宽度 36 | cartHeight = 0.3 % 小车高度 37 | poleWidth = 0.1 % 杆宽度 38 | end 39 | 40 | methods 41 | function obj = CartPoleEnv() 42 | % 构造函数:初始化倒立摆环境 43 | 44 | % 计算物理参数 45 | obj.totalMass = obj.massCart + obj.massPole; 46 | obj.poleMassLength = obj.massPole * obj.length; 47 | 48 | % 初始化状态 49 | obj.state = zeros(4, 1); 50 | 51 | % 随机种子 52 | rng('shuffle'); 53 | end 54 | 55 | function observation = reset(obj) 56 | % 重置环境到初始状态 57 | % 返回初始观察 58 | 59 | % 随机初始化状态,±0.05范围内的小扰动 60 | obj.state = 0.1 * rand(4, 1) - 0.05; 61 | obj.steps = 0; 62 | 63 | % 返回观察 64 | observation = obj.state; 65 | end 66 | 67 | function [nextObs, reward, done, info] = step(obj, action) 68 | % 执行动作并返回新的状态 69 | % action - 要执行的动作 (0: 左推,1: 右推) 70 | % 返回值: 71 | % nextObs - 新的观察 72 | % reward - 获得的奖励 73 | % done - 是否回合结束 74 | % info - 附加信息 75 | 76 | % 验证动作 77 | assert(isscalar(action) && (action == 0 || action == 1), '动作必须是0或1'); 78 | 79 | % 将动作转换为力 80 | force = (action * 2 - 1) * obj.forceMag; % 0 -> -10, 1 -> 10 81 | 82 | % 解包状态 83 | x = obj.state(1); 84 | xDot = obj.state(2); 85 | theta = obj.state(3); 86 | thetaDot = obj.state(4); 87 | 88 | % 角度需要从弧度转换为度进行检查 89 | thetaDeg = rad2deg(theta); 90 | 91 | % 检查是否超出界限 92 | done = abs(x) > obj.xThreshold || ... 93 | abs(thetaDeg) > obj.thetaThreshold || ... 94 | obj.steps >= obj.maxSteps; 95 | 96 | % 如果回合未结束,计算下一个状态 97 | if ~done 98 | % 物理计算 99 | cosTheta = cos(theta); 100 | sinTheta = sin(theta); 101 | 102 | % 计算加速度 103 | temp = (force + obj.poleMassLength * thetaDot^2 * sinTheta) / obj.totalMass; 104 | thetaAcc = (obj.gravity * sinTheta - cosTheta * temp) / ... 
105 | (obj.length * (4.0/3.0 - obj.massPole * cosTheta^2 / obj.totalMass)); 106 | xAcc = temp - obj.poleMassLength * thetaAcc * cosTheta / obj.totalMass; 107 | 108 | % 欧拉积分更新 109 | x = x + obj.tau * xDot; 110 | xDot = xDot + obj.tau * xAcc; 111 | theta = theta + obj.tau * thetaDot; 112 | thetaDot = thetaDot + obj.tau * thetaAcc; 113 | 114 | % 更新状态 115 | obj.state = [x; xDot; theta; thetaDot]; 116 | obj.steps = obj.steps + 1; 117 | end 118 | 119 | % 设置奖励 120 | if done && obj.steps < obj.maxSteps 121 | % 如果因为超出界限而结束(不是因为达到最大步数),给予惩罚 122 | reward = 0; 123 | else 124 | % 存活奖励 125 | reward = 1.0; 126 | end 127 | 128 | % 返回结果 129 | nextObs = obj.state; 130 | info = struct('steps', obj.steps); 131 | end 132 | 133 | function render(obj) 134 | % 渲染当前环境状态 135 | 136 | % 如果没有图形,创建一个 137 | if isempty(obj.renderFig) || ~isvalid(obj.renderFig) 138 | obj.renderFig = figure('Name', '倒立摆', 'Position', [100, 100, 800, 400]); 139 | obj.renderAx = axes('XLim', [-obj.xThreshold - 1, obj.xThreshold + 1], ... 140 | 'YLim', [-1, 2]); 141 | title('倒立摆'); 142 | xlabel('位置'); 143 | ylabel('高度'); 144 | hold(obj.renderAx, 'on'); 145 | grid(obj.renderAx, 'on'); 146 | end 147 | 148 | % 清除当前轴 149 | cla(obj.renderAx); 150 | 151 | % 获取当前状态 152 | x = obj.state(1); 153 | theta = obj.state(3); 154 | 155 | % 计算杆的端点 156 | poleX = [x, x + 2 * obj.length * sin(theta)]; 157 | poleY = [0, 2 * obj.length * cos(theta)]; 158 | 159 | % 绘制小车 160 | cartX = [x - obj.cartWidth/2, x + obj.cartWidth/2, x + obj.cartWidth/2, x - obj.cartWidth/2]; 161 | cartY = [0, 0, obj.cartHeight, obj.cartHeight]; 162 | fill(obj.renderAx, cartX, cartY, 'b'); 163 | 164 | % 绘制杆 165 | line(obj.renderAx, poleX, poleY, 'Color', 'r', 'LineWidth', 3); 166 | 167 | % 绘制地面 168 | line(obj.renderAx, [-obj.xThreshold - 1, obj.xThreshold + 1], [0, 0], 'Color', 'k', 'LineWidth', 2); 169 | 170 | % 更新图形 171 | drawnow; 172 | end 173 | end 174 | end 175 | -------------------------------------------------------------------------------- /environments/DCMotorEnv.m: -------------------------------------------------------------------------------- 1 | classdef DCMotorEnv < Environment 2 | % DCMotorEnv 直流电机控制环境 3 | % 模拟直流电机系统的动态特性和控制问题 4 | 5 | properties 6 | % 环境规格 7 | observationSize = 3 % 状态空间维度:[角度, 角速度, 电流] 8 | actionSize = 1 % 动作空间维度:施加电压 9 | isDiscrete = false % 连续动作空间 10 | 11 | % 电机物理参数 12 | J = 0.01 % 转动惯量 (kg.m^2) 13 | b = 0.1 % 阻尼系数 (N.m.s) 14 | K = 0.01 % 电机常数 (N.m/A 或 V.s/rad) 15 | R = 1.0 % 电阻 (ohm) 16 | L = 0.5 % 电感 (H) 17 | 18 | % 系统参数 19 | dt = 0.01 % 时间步长 (s) 20 | maxVoltage = 24.0 % 最大允许电压 (V) 21 | maxCurrent = 10.0 % 最大允许电流 (A) 22 | maxSpeed = 50.0 % 最大允许角速度 (rad/s) 23 | 24 | % 目标 25 | targetAngle % 目标角度 (rad) 26 | 27 | % 当前状态 28 | state % [角度, 角速度, 电流] 29 | 30 | % 回合信息 31 | steps = 0 % 当前回合步数 32 | maxSteps = 500 % 最大步数 33 | 34 | % 可视化 35 | renderFig % 图形句柄 36 | renderAx % 坐标轴句柄 37 | plotHandles % 图形元素句柄 38 | end 39 | 40 | methods 41 | function obj = DCMotorEnv() 42 | % 构造函数:初始化电机环境 43 | 44 | % 初始化状态 45 | obj.state = zeros(3, 1); 46 | 47 | % 随机目标角度 (0到2π) 48 | obj.resetTarget(); 49 | 50 | % 随机种子 51 | rng('shuffle'); 52 | end 53 | 54 | function resetTarget(obj) 55 | % 设置新的随机目标角度 56 | obj.targetAngle = 2 * pi * rand(); 57 | end 58 | 59 | function observation = reset(obj) 60 | % 重置环境到初始状态 61 | % 返回初始观察 62 | 63 | % 随机初始角度 (0到2π) 64 | initialAngle = 2 * pi * rand(); 65 | 66 | % 初始状态:[角度, 角速度, 电流] 67 | obj.state = [initialAngle; 0; 0]; 68 | 69 | % 重置步数 70 | obj.steps = 0; 71 | 72 | % 重置目标 73 | obj.resetTarget(); 74 | 75 | % 返回观察 76 | observation = obj.getObservation(); 77 
| end 78 | 79 | function observation = getObservation(obj) 80 | % 获取当前观察(可以包括状态处理) 81 | 82 | % 标准化角度差异(在-π到π之间) 83 | angleDiff = obj.state(1) - obj.targetAngle; 84 | angleDiff = atan2(sin(angleDiff), cos(angleDiff)); 85 | 86 | % 组合观察:[角度差, 角速度, 电流] 87 | observation = [angleDiff; obj.state(2)/obj.maxSpeed; obj.state(3)/obj.maxCurrent]; 88 | end 89 | 90 | function [nextObs, reward, done, info] = step(obj, action) 91 | % 执行动作并返回新的状态 92 | % action - 要执行的动作(施加电压,范围[-1, 1]) 93 | % 返回值: 94 | % nextObs - 新的观察 95 | % reward - 获得的奖励 96 | % done - 是否回合结束 97 | % info - 附加信息 98 | 99 | % 验证动作(在电压限制内) 100 | action = max(-1, min(1, action)); % 裁剪到[-1, 1] 101 | 102 | % 将动作转换为电压 103 | voltage = action * obj.maxVoltage; 104 | 105 | % 解包状态 106 | angle = obj.state(1); 107 | angularVelocity = obj.state(2); 108 | current = obj.state(3); 109 | 110 | % 物理模型 (电机动力学) 111 | % dθ/dt = ω 112 | % dω/dt = (K*i - b*ω) / J 113 | % di/dt = (V - R*i - K*ω) / L 114 | 115 | % 欧拉方法更新状态 116 | angleNext = angle + obj.dt * angularVelocity; 117 | angularVelocityNext = angularVelocity + obj.dt * ((obj.K * current - obj.b * angularVelocity) / obj.J); 118 | currentNext = current + obj.dt * ((voltage - obj.R * current - obj.K * angularVelocity) / obj.L); 119 | 120 | % 限制电流 121 | currentNext = max(-obj.maxCurrent, min(obj.maxCurrent, currentNext)); 122 | 123 | % 限制角速度 124 | angularVelocityNext = max(-obj.maxSpeed, min(obj.maxSpeed, angularVelocityNext)); 125 | 126 | % 角度归一化到[0, 2π) 127 | angleNext = mod(angleNext, 2 * pi); 128 | 129 | % 更新状态 130 | obj.state = [angleNext; angularVelocityNext; currentNext]; 131 | obj.steps = obj.steps + 1; 132 | 133 | % 计算与目标的角度差 134 | angleDiff = angleNext - obj.targetAngle; 135 | angleDiff = atan2(sin(angleDiff), cos(angleDiff)); % 归一化到[-π, π] 136 | 137 | % 计算与目标的距离(角度误差) 138 | distance = abs(angleDiff); 139 | 140 | % 判断是否完成 141 | angleThreshold = 0.05; % 角度误差阈值(弧度) 142 | speedThreshold = 0.1; % 速度阈值(rad/s) 143 | 144 | targetReached = distance < angleThreshold && abs(angularVelocityNext) < speedThreshold; 145 | timeout = obj.steps >= obj.maxSteps; 146 | 147 | done = targetReached || timeout; 148 | 149 | % 计算奖励 150 | % 角度误差奖励(接近目标给更高奖励) 151 | angleReward = -distance^2; 152 | 153 | % 速度惩罚(过高的速度给予惩罚) 154 | speedPenalty = -0.01 * (angularVelocityNext^2); 155 | 156 | % 电压使用惩罚(鼓励使用较小的控制信号) 157 | voltagePenalty = -0.01 * (voltage^2 / (obj.maxVoltage^2)); 158 | 159 | % 电流惩罚(避免过高电流) 160 | currentPenalty = -0.01 * (currentNext^2 / (obj.maxCurrent^2)); 161 | 162 | % 目标达成奖励 163 | successReward = targetReached ? 10.0 : 0.0; 164 | 165 | % 总奖励 166 | reward = angleReward + speedPenalty + voltagePenalty + currentPenalty + successReward; 167 | 168 | % 返回结果 169 | nextObs = obj.getObservation(); 170 | info = struct('steps', obj.steps, 'targetAngle', obj.targetAngle, 'distance', distance); 171 | end 172 | 173 | function render(obj) 174 | % 渲染当前环境状态 175 | 176 | % 如果没有图形,创建一个 177 | if isempty(obj.renderFig) || ~isvalid(obj.renderFig) 178 | obj.renderFig = figure('Name', '直流电机控制', 'Position', [100, 100, 1000, 600]); 179 | 180 | % 创建两个子图 181 | subplot(2, 1, 1); 182 | obj.renderAx = gca; 183 | title('电机角度'); 184 | hold(obj.renderAx, 'on'); 185 | grid(obj.renderAx, 'on'); 186 | 187 | % 绘制目标角度 188 | targetLine = line(obj.renderAx, [0, cos(obj.targetAngle)], [0, sin(obj.targetAngle)], ... 189 | 'Color', 'g', 'LineWidth', 2, 'LineStyle', '--'); 190 | 191 | % 绘制电机轴 192 | motorLine = line(obj.renderAx, [0, cos(obj.state(1))], [0, sin(obj.state(1))], ... 
193 | 'Color', 'b', 'LineWidth', 3); 194 | 195 | % 绘制电机本体 196 | t = linspace(0, 2*pi, 100); 197 | motorBody = fill(0.2*cos(t), 0.2*sin(t), 'r'); 198 | 199 | % 设置图形属性 200 | axis(obj.renderAx, 'equal'); 201 | xlim(obj.renderAx, [-1.5, 1.5]); 202 | ylim(obj.renderAx, [-1.5, 1.5]); 203 | 204 | % 存储图形句柄 205 | obj.plotHandles = struct('targetLine', targetLine, 'motorLine', motorLine, 'motorBody', motorBody); 206 | 207 | % 创建下方子图显示状态信息 208 | subplot(2, 1, 2); 209 | infoAx = gca; 210 | hold(infoAx, 'on'); 211 | grid(infoAx, 'on'); 212 | 213 | % 时间序列数据 214 | timeData = zeros(1, 100); 215 | angleData = zeros(1, 100); 216 | speedData = zeros(1, 100); 217 | currentData = zeros(1, 100); 218 | 219 | % 绘制时间序列 220 | anglePlot = plot(infoAx, timeData, angleData, 'b-', 'LineWidth', 1.5); 221 | speedPlot = plot(infoAx, timeData, speedData, 'r-', 'LineWidth', 1.5); 222 | currentPlot = plot(infoAx, timeData, currentData, 'g-', 'LineWidth', 1.5); 223 | 224 | % 设置图形属性 225 | title(infoAx, '系统状态'); 226 | xlabel(infoAx, '时间步'); 227 | ylabel(infoAx, '状态值'); 228 | legend(infoAx, {'角度差', '角速度', '电流'}, 'Location', 'best'); 229 | xlim(infoAx, [0, 100]); 230 | ylim(infoAx, [-1.5, 1.5]); 231 | 232 | % 存储时间序列信息 233 | obj.plotHandles.timeData = timeData; 234 | obj.plotHandles.angleData = angleData; 235 | obj.plotHandles.speedData = speedData; 236 | obj.plotHandles.currentData = currentData; 237 | obj.plotHandles.anglePlot = anglePlot; 238 | obj.plotHandles.speedPlot = speedPlot; 239 | obj.plotHandles.currentPlot = currentPlot; 240 | obj.plotHandles.infoAx = infoAx; 241 | end 242 | 243 | % 更新顶部电机图 244 | subplot(2, 1, 1); 245 | 246 | % 更新目标线 247 | obj.plotHandles.targetLine.XData = [0, cos(obj.targetAngle)]; 248 | obj.plotHandles.targetLine.YData = [0, sin(obj.targetAngle)]; 249 | 250 | % 更新电机轴线 251 | obj.plotHandles.motorLine.XData = [0, cos(obj.state(1))]; 252 | obj.plotHandles.motorLine.YData = [0, sin(obj.state(1))]; 253 | 254 | % 更新下方状态图 255 | subplot(2, 1, 2); 256 | 257 | % 获取当前观察 258 | obs = obj.getObservation(); 259 | 260 | % 更新时间序列数据 261 | obj.plotHandles.timeData = [obj.plotHandles.timeData(2:end), obj.steps]; 262 | obj.plotHandles.angleData = [obj.plotHandles.angleData(2:end), obs(1)]; 263 | obj.plotHandles.speedData = [obj.plotHandles.speedData(2:end), obs(2)]; 264 | obj.plotHandles.currentData = [obj.plotHandles.currentData(2:end), obs(3)]; 265 | 266 | % 更新图形 267 | obj.plotHandles.anglePlot.XData = obj.plotHandles.timeData; 268 | obj.plotHandles.anglePlot.YData = obj.plotHandles.angleData; 269 | obj.plotHandles.speedPlot.XData = obj.plotHandles.timeData; 270 | obj.plotHandles.speedPlot.YData = obj.plotHandles.speedData; 271 | obj.plotHandles.currentPlot.XData = obj.plotHandles.timeData; 272 | obj.plotHandles.currentPlot.YData = obj.plotHandles.currentData; 273 | 274 | % 调整X轴范围以始终显示最近的100个时间步 275 | if obj.steps > 100 276 | xlim(obj.plotHandles.infoAx, [obj.steps-100, obj.steps]); 277 | end 278 | 279 | % 刷新图形 280 | drawnow; 281 | end 282 | end 283 | end 284 | -------------------------------------------------------------------------------- /environments/DoublePendulumEnv.m: -------------------------------------------------------------------------------- 1 | classdef DoublePendulumEnv < handle 2 | % DoublePendulumEnv 双倒立摆环境 3 | % 一个需要多智能体协作的控制问题 4 | % 两个智能体各自控制一个摆杆,需要协同工作将摆杆保持在倒立位置 5 | 6 | properties 7 | % 环境参数 8 | g = 9.81 % 重力加速度 (m/s^2) 9 | l1 = 1.0 % 第一摆杆长度 (m) 10 | l2 = 1.0 % 第二摆杆长度 (m) 11 | m1 = 1.0 % 第一摆杆质量 (kg) 12 | m2 = 1.0 % 第二摆杆质量 (kg) 13 | maxTorque = 10.0 % 最大扭矩 (N·m) 14 | dt = 0.05 % 时间步长 (s) 15 | 
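% Angle convention: theta = 0 is the hanging-down position and theta = pi is
% the upright target used by targetTheta1/targetTheta2 below (see render(),
% where each link's y-coordinate is computed as -l*cos(theta)).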
maxSteps = 200 % 最大步数 16 | 17 | % 状态变量 18 | theta1 % 第一摆杆角度 19 | theta2 % 第二摆杆角度 20 | dtheta1 % 第一摆杆角速度 21 | dtheta2 % 第二摆杆角速度 22 | 23 | % 目标状态 24 | targetTheta1 = pi % 目标角度(倒立位置) 25 | targetTheta2 = pi % 目标角度(倒立位置) 26 | 27 | % 环境状态 28 | stepCount % 当前步数 29 | renderFig % 绘图句柄 30 | renderAx % 坐标轴句柄 31 | pendulum1Line % 第一摆杆线 32 | pendulum2Line % 第二摆杆线 33 | 34 | % 观察空间和动作空间维度 35 | obsSize % 每个智能体的观察空间大小 36 | actSize % 每个智能体的动作空间大小(属性名与下方的actionSize方法区分,避免同名冲突) 37 | end 38 | 39 | methods 40 | function obj = DoublePendulumEnv() 41 | % 构造函数 42 | % 初始化观察空间和动作空间维度 43 | obj.obsSize = [4, 4]; % 每个智能体的观察空间维度 44 | obj.actSize = [1, 1]; % 每个智能体的动作空间维度 45 | end 46 | 47 | function [agentObs, jointObs] = reset(obj) 48 | % 重置环境到初始状态 49 | 50 | % 随机初始位置(接近但不完全在倒立位置) 51 | obj.theta1 = pi + (rand - 0.5) * 0.6; % 近似倒立位置 52 | obj.theta2 = pi + (rand - 0.5) * 0.6; % 近似倒立位置 53 | obj.dtheta1 = (rand - 0.5) * 0.2; % 小的随机初始速度 54 | obj.dtheta2 = (rand - 0.5) * 0.2; % 小的随机初始速度 55 | 56 | % 重置步数计数器 57 | obj.stepCount = 0; 58 | 59 | % 返回观察 60 | agentObs = obj.getAgentObservations(); 61 | jointObs = obj.getJointObservation(); 62 | end 63 | 64 | function [nextAgentObs, nextJointObs, reward, done, info] = step(obj, actions) 65 | % 执行动作并返回下一个状态、奖励和完成标志 66 | 67 | % 限制动作在有效范围内 68 | torque1 = min(max(actions{1}, -obj.maxTorque), obj.maxTorque); 69 | torque2 = min(max(actions{2}, -obj.maxTorque), obj.maxTorque); 70 | 71 | % 使用动力学方程更新状态 72 | [obj.theta1, obj.theta2, obj.dtheta1, obj.dtheta2] = ... 73 | obj.dynamics(obj.theta1, obj.theta2, obj.dtheta1, obj.dtheta2, torque1, torque2); 74 | 75 | % 归一化角度到 (-pi, pi](用atan2实现,避免依赖Mapping Toolbox的wrapToPi) 76 | obj.theta1 = atan2(sin(obj.theta1), cos(obj.theta1)); 77 | obj.theta2 = atan2(sin(obj.theta2), cos(obj.theta2)); 78 | 79 | % 计算奖励 80 | reward = obj.calculateReward(); 81 | 82 | % 更新步数 83 | obj.stepCount = obj.stepCount + 1; 84 | 85 | % 检查是否完成 86 | done = obj.stepCount >= obj.maxSteps; 87 | 88 | % 获取观察 89 | nextAgentObs = obj.getAgentObservations(); 90 | nextJointObs = obj.getJointObservation(); 91 | 92 | % 准备额外信息 93 | info = struct(); 94 | info.theta1 = obj.theta1; 95 | info.theta2 = obj.theta2; 96 | info.dtheta1 = obj.dtheta1; 97 | info.dtheta2 = obj.dtheta2; 98 | end 99 | 100 | function [theta1_next, theta2_next, dtheta1_next, dtheta2_next] = dynamics(obj, theta1, theta2, dtheta1, dtheta2, torque1, torque2) 101 | % 双倒立摆的动力学方程 102 | % 使用欧拉法进行积分 103 | 104 | % 参数简写 105 | g = obj.g; 106 | m1 = obj.m1; 107 | m2 = obj.m2; 108 | l1 = obj.l1; 109 | l2 = obj.l2; 110 | dt = obj.dt; 111 | 112 | % 计算动力学 113 | % 简化的双摆动力学方程(实际实现中应该使用完整的拉格朗日方程) 114 | 115 | % 计算辅助变量 116 | delta = theta2 - theta1; 117 | sdelta = sin(delta); 118 | cdelta = cos(delta); 119 | 120 | % 计算分母项 121 | den1 = (m1 + m2) * l1 - m2 * l1 * cdelta * cdelta; 122 | den2 = (l2 / l1) * den1; 123 | 124 | % 计算角加速度(跨行表达式需用...续行) 125 | ddtheta1 = ((m2 * l2 * dtheta2^2 * sdelta) - (m2 * g * sin(theta2) * cdelta) + ... 126 | (m2 * l1 * dtheta1^2 * sdelta * cdelta) + ... 127 | (torque1 + (m1 + m2) * g * sin(theta1))) / den1; 128 | 129 | ddtheta2 = ((-l1 / l2) * dtheta1^2 * sdelta - (g / l2) * sin(theta2) + ... 130 | (torque2 / (m2 * l2)) - (cdelta * ddtheta1)) / (1 - cdelta^2 / (den2)); 131 | 132 | % 使用欧拉法更新 133 | dtheta1_next = dtheta1 + ddtheta1 * dt; 134 | dtheta2_next = dtheta2 + ddtheta2 * dt; 135 | theta1_next = theta1 + dtheta1_next * dt; 136 | theta2_next = theta2 + dtheta2_next * dt; 137 | end 138 | 139 | function reward = calculateReward(obj) 140 | % 计算奖励值 141 | 142 | % 角度偏差,使用余弦相似度来衡量接近目标的程度 143 | angle_error1 = 1 - cos(obj.theta1 - obj.targetTheta1); 144 | angle_error2 = 1 - cos(obj.theta2 - obj.targetTheta2); 145 | 146 | % 速度惩罚
velocity_penalty1 = 0.1 * obj.dtheta1^2; 148 | velocity_penalty2 = 0.1 * obj.dtheta2^2; 149 | 150 | % 总奖励 151 | reward = -(angle_error1 + angle_error2 + velocity_penalty1 + velocity_penalty2); 152 | end 153 | 154 | function agentObservations = getAgentObservations(obj) 155 | % 获取每个智能体的观察 156 | 157 | % 智能体1的观察:自己的角度和角速度,以及对方的状态 158 | obs1 = [ 159 | sin(obj.theta1); 160 | cos(obj.theta1); 161 | obj.dtheta1 / 10.0; % 归一化 162 | sin(obj.theta2 - obj.theta1) % 相对角度信息 163 | ]; 164 | 165 | % 智能体2的观察:自己的角度和角速度,以及对方的状态 166 | obs2 = [ 167 | sin(obj.theta2); 168 | cos(obj.theta2); 169 | obj.dtheta2 / 10.0; % 归一化 170 | sin(obj.theta2 - obj.theta1) % 相对角度信息 171 | ]; 172 | 173 | % 包装为cell数组 174 | agentObservations = {obs1, obs2}; 175 | end 176 | 177 | function jointObs = getJointObservation(obj) 178 | % 获取联合观察(用于中央评论家) 179 | 180 | jointObs = [ 181 | sin(obj.theta1); 182 | cos(obj.theta1); 183 | obj.dtheta1 / 10.0; 184 | sin(obj.theta2); 185 | cos(obj.theta2); 186 | obj.dtheta2 / 10.0; 187 | sin(obj.theta2 - obj.theta1); 188 | cos(obj.theta2 - obj.theta1) 189 | ]; 190 | end 191 | 192 | function render(obj) 193 | % 可视化双倒立摆系统 194 | 195 | % 第一次调用时创建图形 196 | if isempty(obj.renderFig) || ~isvalid(obj.renderFig) 197 | obj.renderFig = figure('Name', '双倒立摆', 'NumberTitle', 'off'); 198 | obj.renderAx = axes('XLim', [-2.5, 2.5], 'YLim', [-2.5, 2.5]); 199 | hold(obj.renderAx, 'on'); 200 | grid(obj.renderAx, 'on'); 201 | axis(obj.renderAx, 'equal'); 202 | xlabel(obj.renderAx, 'X'); 203 | ylabel(obj.renderAx, 'Y'); 204 | title(obj.renderAx, '双倒立摆系统'); 205 | 206 | % 创建摆杆线条对象 207 | obj.pendulum1Line = line(obj.renderAx, [0, 0], [0, 0], 'LineWidth', 3, 'Color', 'blue'); 208 | obj.pendulum2Line = line(obj.renderAx, [0, 0], [0, 0], 'LineWidth', 3, 'Color', 'red'); 209 | end 210 | 211 | % 计算摆杆端点坐标 212 | x0 = 0; 213 | y0 = 0; 214 | x1 = x0 + obj.l1 * sin(obj.theta1); 215 | y1 = y0 - obj.l1 * cos(obj.theta1); 216 | x2 = x1 + obj.l2 * sin(obj.theta2); 217 | y2 = y1 - obj.l2 * cos(obj.theta2); 218 | 219 | % 更新线条位置 220 | set(obj.pendulum1Line, 'XData', [x0, x1], 'YData', [y0, y1]); 221 | set(obj.pendulum2Line, 'XData', [x1, x2], 'YData', [y1, y2]); 222 | 223 | % 更新标题显示当前步数 224 | title(obj.renderAx, sprintf('双倒立摆系统 - 步数: %d', obj.stepCount)); 225 | 226 | % 刷新图形 227 | drawnow; 228 | end 229 | 230 | function result = isDiscreteAction(obj, agentIdx) 231 | % 判断动作空间是否离散 232 | % 本环境使用连续动作空间 233 | result = false; 234 | end 235 | 236 | function size = observationSize(obj, agentIdx) 237 | % 返回指定智能体的观察空间维度 238 | size = obj.obsSize(agentIdx); 239 | end 240 | 241 | function size = actionSize(obj, agentIdx) 242 | % 返回指定智能体的动作空间维度 243 | size = obj.actSize(agentIdx); 244 | end 245 | 246 | function close(obj) 247 | % 关闭环境和释放资源 248 | if ~isempty(obj.renderFig) && isvalid(obj.renderFig) 249 | close(obj.renderFig); 250 | end 251 | end 252 | end 253 | end 254 | -------------------------------------------------------------------------------- /environments/Environment.m: -------------------------------------------------------------------------------- 1 | classdef (Abstract) Environment < handle 2 | % Environment 强化学习环境的基类 3 | % 所有环境都应该继承这个基类并实现其方法 4 | 5 | properties (Abstract) 6 | observationSize % 观察空间大小 7 | actionSize % 动作空间大小 8 | isDiscrete % 是否为离散动作空间 9 | end 10 | 11 | methods (Abstract) 12 | % 重置环境到初始状态 13 | % 返回初始观察 14 | obs = reset(obj) 15 | 16 | % 执行动作并返回新的状态 17 | % action - 要执行的动作 18 | % 返回值: 19 | % nextObs - 新的观察 20 | % reward - 获得的奖励 21 | % done - 是否回合结束 22 | % info - 附加信息(可选) 23 | [nextObs, reward, done, info] = step(obj, action)
25 | % 渲染环境(可选) 26 | render(obj) 27 | end 28 | 29 | methods 30 | function validateAction(obj, action) 31 | % 验证动作是否合法 32 | % action - 要验证的动作 33 | 34 | assert(length(action) == obj.actionSize, ... 35 | '动作维度不匹配:期望 %d,实际 %d', obj.actionSize, length(action)); 36 | end 37 | 38 | function obs = normalizeObservation(obj, obs) 39 | % 标准化观察(子类可以重写此方法) 40 | % obs - 原始观察 41 | % 返回值:标准化后的观察 42 | end 43 | 44 | function action = normalizeAction(obj, action) 45 | % 标准化动作(子类可以重写此方法) 46 | % action - 原始动作 47 | % 返回值:标准化后的动作 48 | end 49 | 50 | function actionValues = discreteToBox(obj, action) 51 | % 将离散动作转换为连续值(对于具有离散动作空间的环境) 52 | % action - 离散动作索引 53 | % 返回值:连续动作值 54 | end 55 | 56 | function actionIndex = boxToDiscrete(obj, action) 57 | % 将连续动作值转换为离散动作索引(对于具有离散动作空间的环境) 58 | % action - 连续动作值 59 | % 返回值:离散动作索引 60 | end 61 | 62 | function seed(obj, seedValue) 63 | % 设置随机种子以便结果可复现(子类可以重写此方法) 64 | % seedValue - 随机种子值 65 | rng(seedValue); 66 | end 67 | 68 | function info = getEnvInfo(obj) 69 | % 获取环境信息 70 | % 返回值:包含环境信息的结构体 71 | 72 | info = struct(); 73 | info.observationSize = obj.observationSize; 74 | info.actionSize = obj.actionSize; 75 | info.isDiscrete = obj.isDiscrete; 76 | end 77 | end 78 | end 79 | -------------------------------------------------------------------------------- /examples/test_acmotor.m: -------------------------------------------------------------------------------- 1 | % test_acmotor.m 2 | % 测试已训练的交流感应电机FOC控制系统PPO模型 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 模型路径 12 | modelPath = '../logs/acmotor/model_final.mat'; 13 | if ~exist(modelPath, 'file') 14 | % 如果找不到最终模型,尝试找到目录中的任意模型 15 | logDir = '../logs/acmotor'; 16 | files = dir(fullfile(logDir, 'model_*.mat')); 17 | if ~isempty(files) 18 | [~, idx] = max([files.datenum]); 19 | modelPath = fullfile(logDir, files(idx).name); 20 | fprintf('未找到最终模型,将使用最新模型: %s\n', modelPath); 21 | else 22 | error('找不到任何训练好的模型,请先运行train_acmotor.m训练模型'); 23 | end 24 | end 25 | 26 | % 加载配置 27 | config = PPOConfig(); 28 | config.envName = 'ACMotorEnv'; 29 | config.actorLayerSizes = [256, 256, 128, 64]; 30 | config.criticLayerSizes = [256, 256, 128, 64]; 31 | config.useGPU = true; % 根据需要设置是否使用GPU 32 | 33 | % 创建PPO代理 34 | agent = PPOAgent(config); 35 | 36 | % 加载模型 37 | fprintf('加载模型: %s\n', modelPath); 38 | agent.loadModel(modelPath); 39 | 40 | % 测试参数 41 | numEpisodes = 3; % 测试回合数 42 | renderTest = true; % 是否可视化测试过程 43 | testDuration = 5000; % 测试步数(由于交流电机采样率高) 44 | saveResults = true; % 是否保存测试结果 45 | testScenarios = { 46 | '正常负载运行', % 标准工作负载 47 | '突加负载响应', % 负载突增 48 | '速度阶跃响应' % 速度突变 49 | }; 50 | 51 | % 创建结果目录 52 | resultsDir = '../results/acmotor'; 53 | if saveResults && ~exist(resultsDir, 'dir') 54 | mkdir(resultsDir); 55 | end 56 | 57 | % 测试已训练的模型 58 | fprintf('开始测试,将执行%d个测试场景...\n', length(testScenarios)); 59 | 60 | % 创建结果结构 61 | testResults = struct(); 62 | 63 | for scenarioIdx = 1:length(testScenarios) 64 | scenarioName = testScenarios{scenarioIdx}; 65 | fprintf('\n开始场景 %d: %s\n', scenarioIdx, scenarioName); 66 | 67 | % 创建环境 68 | env = ACMotorEnv(); 69 | 70 | % 为不同场景设置不同的参数 71 | switch scenarioIdx 72 | case 1 % 正常负载运行 73 | % 使用默认参数 74 | env.loadProfile = [ 75 | 0.3 * env.nominalTorque; % 30%额定负载 76 | 0.3 * env.nominalTorque; 77 | 0.3 * env.nominalTorque; 78 | 0.3 * env.nominalTorque; 79 | 0.3 * env.nominalTorque; 80 | ]; 81 | targetSpeed = 0.8 * env.nominalSpeed; % 80%额定速度 82 | env.targetSpeed = targetSpeed; 83 | 84 | case 2 % 突加负载响应 85 | % 设置突加负载测试 
86 | env.loadProfile = [ 87 | 0.2 * env.nominalTorque; % 初始轻载 88 | 0.8 * env.nominalTorque; % 突加到80%负载 89 | 0.8 * env.nominalTorque; 90 | 0.8 * env.nominalTorque; 91 | 0.2 * env.nominalTorque; % 恢复到轻载 92 | ]; 93 | env.loadChangeTime = [1000, 4000]; % 在这些步骤改变负载 94 | targetSpeed = 0.7 * env.nominalSpeed; % 70%额定速度 95 | env.targetSpeed = targetSpeed; 96 | 97 | case 3 % 速度阶跃响应 98 | % 设置速度阶跃测试 99 | env.loadProfile = [ 100 | 0.4 * env.nominalTorque; % 40%额定负载 101 | 0.4 * env.nominalTorque; 102 | 0.4 * env.nominalTorque; 103 | 0.4 * env.nominalTorque; 104 | 0.4 * env.nominalTorque; 105 | ]; 106 | % 初始速度设置 107 | initialSpeed = 0.4 * env.nominalSpeed; 108 | targetSpeed = 0.9 * env.nominalSpeed; % 将在测试中改变目标 109 | env.targetSpeed = initialSpeed; 110 | % 我们将在第1000步改变速度设定值 111 | speedChangeStep = 1000; 112 | end 113 | 114 | % 重置环境 115 | obs = env.reset(); 116 | 117 | % 初始化性能指标 118 | totalReward = 0; 119 | speedErrors = []; 120 | 121 | % 创建记录数据的结构 122 | recordData = struct(); 123 | recordData.time = []; 124 | recordData.speed = []; 125 | recordData.targetSpeed = []; 126 | recordData.id = []; 127 | recordData.iq = []; 128 | recordData.Te = []; 129 | recordData.Tl = []; 130 | recordData.Vd = []; 131 | recordData.Vq = []; 132 | recordData.reward = []; 133 | 134 | % 开始测试循环 135 | for step = 1:testDuration 136 | % 如果是速度阶跃测试,在指定步骤改变目标速度 137 | if scenarioIdx == 3 && step == speedChangeStep 138 | env.targetSpeed = targetSpeed; 139 | fprintf('在步骤 %d 将目标速度从 %.1f rad/s 改变到 %.1f rad/s\n', ... 140 | step, initialSpeed, targetSpeed); 141 | end 142 | 143 | % 转换为dlarray并根据需要迁移到GPU 144 | if agent.useGPU 145 | dlObs = dlarray(single(obs), 'CB'); 146 | dlObs = gpuArray(dlObs); 147 | else 148 | dlObs = dlarray(single(obs), 'CB'); 149 | end 150 | 151 | % 使用确定性策略(使用均值) 152 | action = agent.actorNet.getMeanAction(dlObs); 153 | 154 | % 转换为CPU并提取数值 155 | if agent.useGPU 156 | action = gather(extractdata(action)); 157 | else 158 | action = extractdata(action); 159 | end 160 | 161 | % 执行动作 162 | [obs, reward, done, info] = env.step(action); 163 | totalReward = totalReward + reward; 164 | 165 | % 记录性能指标 166 | speedError = abs(info.targetSpeed - info.speed) / env.maxSpeed; 167 | speedErrors(end+1) = speedError; 168 | 169 | % 记录数据用于分析 170 | recordData.time(end+1) = step * env.dt; 171 | recordData.speed(end+1) = info.speed; 172 | recordData.targetSpeed(end+1) = info.targetSpeed; 173 | recordData.id(end+1) = info.id; 174 | recordData.iq(end+1) = info.iq; 175 | recordData.Te(end+1) = info.Te; 176 | recordData.Tl(end+1) = info.Tl; 177 | recordData.Vd(end+1) = action(1) * env.maxVoltage; 178 | recordData.Vq(end+1) = action(2) * env.maxVoltage; 179 | recordData.reward(end+1) = reward; 180 | 181 | % 渲染环境(如需要) 182 | if renderTest && mod(step, 10) == 0 183 | env.render(); 184 | pause(0.001); % 降低渲染频率以避免过多图形更新 185 | end 186 | end 187 | 188 | % 计算性能指标 189 | avgSpeedError = mean(speedErrors); 190 | maxSpeedError = max(speedErrors); 191 | steadyStateError = mean(speedErrors(end-500:end)); % 最后500点的稳态误差 192 | 193 | % 计算上升时间和调节时间(仅对速度阶跃响应场景) 194 | if scenarioIdx == 3 195 | % 找到速度阶跃后的数据 196 | stepIndex = find(recordData.time >= speedChangeStep * env.dt, 1); 197 | postStepSpeed = recordData.speed(stepIndex:end); 198 | postStepTime = recordData.time(stepIndex:end) - recordData.time(stepIndex); 199 | 200 | % 计算上升时间(从10%到90%的响应时间) 201 | speedChange = targetSpeed - initialSpeed; 202 | tenPercent = initialSpeed + 0.1 * speedChange; 203 | ninetyPercent = initialSpeed + 0.9 * speedChange; 204 | 205 | tenPercentIndex = find(postStepSpeed >= tenPercent, 1); 206 |
ninetyPercentIndex = find(postStepSpeed >= ninetyPercent, 1); 207 | 208 | if ~isempty(tenPercentIndex) && ~isempty(ninetyPercentIndex) 209 | riseTime = postStepTime(ninetyPercentIndex) - postStepTime(tenPercentIndex); 210 | else 211 | riseTime = NaN; 212 | end 213 | 214 | % 计算调节时间(达到并维持在最终值±5%之内) 215 | fivePercent = 0.05 * speedChange; 216 | steadyBand = [targetSpeed - fivePercent, targetSpeed + fivePercent]; 217 | 218 | for i = 1:length(postStepSpeed) 219 | if postStepSpeed(i) >= steadyBand(1) && postStepSpeed(i) <= steadyBand(2) 220 | % 检查是否之后的所有点都在稳态带内 221 | if all(postStepSpeed(i:end) >= steadyBand(1) & postStepSpeed(i:end) <= steadyBand(2)) 222 | settlingTime = postStepTime(i); 223 | break; 224 | end 225 | end 226 | end 227 | 228 | if ~exist('settlingTime', 'var') 229 | settlingTime = NaN; 230 | end 231 | else 232 | riseTime = NaN; 233 | settlingTime = NaN; 234 | end 235 | 236 | % 打印性能指标 237 | fprintf('场景 %d: %s 完成\n', scenarioIdx, scenarioName); 238 | fprintf(' 总奖励: %.2f\n', totalReward); 239 | fprintf(' 平均速度误差: %.4f\n', avgSpeedError); 240 | fprintf(' 最大速度误差: %.4f\n', maxSpeedError); 241 | fprintf(' 稳态误差: %.4f\n', steadyStateError); 242 | 243 | if scenarioIdx == 3 244 | fprintf(' 上升时间 (10%%-90%%): %.3f s\n', riseTime); 245 | fprintf(' 调节时间 (±5%%): %.3f s\n', settlingTime); 246 | end 247 | 248 | % 保存性能指标 249 | testResults.(sprintf('scenario%d', scenarioIdx)).name = scenarioName; 250 | testResults.(sprintf('scenario%d', scenarioIdx)).totalReward = totalReward; 251 | testResults.(sprintf('scenario%d', scenarioIdx)).avgSpeedError = avgSpeedError; 252 | testResults.(sprintf('scenario%d', scenarioIdx)).maxSpeedError = maxSpeedError; 253 | testResults.(sprintf('scenario%d', scenarioIdx)).steadyStateError = steadyStateError; 254 | testResults.(sprintf('scenario%d', scenarioIdx)).riseTime = riseTime; 255 | testResults.(sprintf('scenario%d', scenarioIdx)).settlingTime = settlingTime; 256 | testResults.(sprintf('scenario%d', scenarioIdx)).data = recordData; 257 | 258 | % 绘制并保存测试结果图 259 | if saveResults 260 | % 创建图形 261 | fig = figure('Name', ['交流电机控制 - ', scenarioName], 'Position', [100, 100, 1200, 900]); 262 | 263 | % 1. 速度响应 264 | subplot(3, 2, 1); 265 | plot(recordData.time, recordData.speed, 'b-', 'LineWidth', 1.5); 266 | hold on; 267 | plot(recordData.time, recordData.targetSpeed, 'r--', 'LineWidth', 1.5); 268 | title('速度响应'); 269 | xlabel('时间 (s)'); 270 | ylabel('速度 (rad/s)'); 271 | legend('实际速度', '目标速度', 'Location', 'best'); 272 | grid on; 273 | 274 | % 2. 速度误差 275 | subplot(3, 2, 2); 276 | speedErr = abs(recordData.speed - recordData.targetSpeed); 277 | plot(recordData.time, speedErr, 'b-', 'LineWidth', 1.5); 278 | title('速度误差'); 279 | xlabel('时间 (s)'); 280 | ylabel('误差 (rad/s)'); 281 | grid on; 282 | 283 | % 3. d-q轴电流 284 | subplot(3, 2, 3); 285 | plot(recordData.time, recordData.id, 'b-', 'LineWidth', 1.5); 286 | hold on; 287 | plot(recordData.time, recordData.iq, 'r-', 'LineWidth', 1.5); 288 | title('d-q轴电流'); 289 | xlabel('时间 (s)'); 290 | ylabel('电流 (A)'); 291 | legend('id', 'iq', 'Location', 'best'); 292 | grid on; 293 | 294 | % 4. d-q轴电压(控制信号) 295 | subplot(3, 2, 4); 296 | plot(recordData.time, recordData.Vd, 'b-', 'LineWidth', 1.5); 297 | hold on; 298 | plot(recordData.time, recordData.Vq, 'r-', 'LineWidth', 1.5); 299 | title('d-q轴电压(控制信号)'); 300 | xlabel('时间 (s)'); 301 | ylabel('电压 (V)'); 302 | legend('Vd', 'Vq', 'Location', 'best'); 303 | grid on; 304 | 305 | % 5. 
转矩 306 | subplot(3, 2, 5); 307 | plot(recordData.time, recordData.Te, 'b-', 'LineWidth', 1.5); 308 | hold on; 309 | plot(recordData.time, recordData.Tl, 'r-', 'LineWidth', 1.5); 310 | title('转矩'); 311 | xlabel('时间 (s)'); 312 | ylabel('转矩 (N·m)'); 313 | legend('电磁转矩', '负载转矩', 'Location', 'best'); 314 | grid on; 315 | 316 | % 6. 瞬时奖励 317 | subplot(3, 2, 6); 318 | plot(recordData.time, recordData.reward, 'k-', 'LineWidth', 1.5); 319 | title('瞬时奖励'); 320 | xlabel('时间 (s)'); 321 | ylabel('奖励'); 322 | grid on; 323 | 324 | % 保存图形 325 | if saveResults 326 | saveName = fullfile(resultsDir, sprintf('scenario%d_%s.fig', scenarioIdx, strrep(scenarioName, ' ', '_'))); 327 | saveas(fig, saveName); 328 | saveName = fullfile(resultsDir, sprintf('scenario%d_%s.png', scenarioIdx, strrep(scenarioName, ' ', '_'))); 329 | saveas(fig, saveName); 330 | fprintf(' 已保存测试结果图到: %s\n', saveName); 331 | end 332 | end 333 | end 334 | 335 | % 汇总比较所有场景 336 | if saveResults && length(testScenarios) > 1 337 | % 创建表格比较所有场景 338 | scenarioNames = {}; 339 | avgErrors = []; 340 | maxErrors = []; 341 | steadyErrors = []; 342 | riseTimes = []; 343 | settlingTimes = []; 344 | totalRewards = []; 345 | 346 | for i = 1:length(testScenarios) 347 | scenarioNames{i} = testResults.(sprintf('scenario%d', i)).name; 348 | avgErrors(i) = testResults.(sprintf('scenario%d', i)).avgSpeedError; 349 | maxErrors(i) = testResults.(sprintf('scenario%d', i)).maxSpeedError; 350 | steadyErrors(i) = testResults.(sprintf('scenario%d', i)).steadyStateError; 351 | riseTimes(i) = testResults.(sprintf('scenario%d', i)).riseTime; 352 | settlingTimes(i) = testResults.(sprintf('scenario%d', i)).settlingTime; 353 | totalRewards(i) = testResults.(sprintf('scenario%d', i)).totalReward; 354 | end 355 | 356 | % 创建比较表格 357 | comparisonTable = table(scenarioNames', avgErrors', maxErrors', steadyErrors', ... 358 | riseTimes', settlingTimes', totalRewards', ... 359 | 'VariableNames', {'场景', '平均速度误差', '最大速度误差', '稳态误差', ... 360 | '上升时间', '调节时间', '总奖励'}); 361 | 362 | % 显示表格 363 | disp('所有场景性能对比:'); 364 | disp(comparisonTable); 365 | 366 | % 保存性能比较结果 367 | save(fullfile(resultsDir, 'test_results.mat'), 'testResults', 'comparisonTable'); 368 | fprintf('测试结果已保存到: %s\n', fullfile(resultsDir, 'test_results.mat')); 369 | 370 | % 绘制所有场景的速度响应比较图 371 | figure('Name', '所有场景速度响应比较', 'Position', [100, 100, 1200, 600]); 372 | 373 | % 速度响应 374 | subplot(1, 2, 1); 375 | hold on; 376 | colors = {'b', 'r', 'g', 'm', 'c'}; 377 | for i = 1:length(testScenarios) 378 | plot(testResults.(sprintf('scenario%d', i)).data.time, ... 379 | testResults.(sprintf('scenario%d', i)).data.speed, ... 380 | [colors{mod(i-1, length(colors))+1}, '-'], 'LineWidth', 1.5); 381 | end 382 | title('所有场景速度响应'); 383 | xlabel('时间 (s)'); 384 | ylabel('速度 (rad/s)'); 385 | legend(scenarioNames, 'Location', 'best'); 386 | grid on; 387 | 388 | % 速度误差 389 | subplot(1, 2, 2); 390 | hold on; 391 | for i = 1:length(testScenarios) 392 | speedErr = abs(testResults.(sprintf('scenario%d', i)).data.speed - ... 393 | testResults.(sprintf('scenario%d', i)).data.targetSpeed); 394 | plot(testResults.(sprintf('scenario%d', i)).data.time, ... 395 | speedErr, ... 
396 | [colors{mod(i-1, length(colors))+1}, '-'], 'LineWidth', 1.5); 397 | end 398 | title('所有场景速度误差'); 399 | xlabel('时间 (s)'); 400 | ylabel('误差 (rad/s)'); 401 | legend(scenarioNames, 'Location', 'best'); 402 | grid on; 403 | 404 | % 保存比较图 405 | saveName = fullfile(resultsDir, 'scenarios_comparison.png'); 406 | saveas(gcf, saveName); 407 | fprintf('场景比较图已保存到: %s\n', saveName); 408 | end 409 | 410 | fprintf('\n测试完成!\n'); 411 | -------------------------------------------------------------------------------- /examples/test_cartpole.m: -------------------------------------------------------------------------------- 1 | % test_cartpole.m 2 | % 测试已训练的倒立摆PPO模型 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 模型路径 12 | modelPath = '../logs/cartpole/model_iter_100.mat'; 13 | 14 | % 加载配置 15 | config = PPOConfig(); 16 | config.envName = 'CartPoleEnv'; 17 | config.actorLayerSizes = [64, 64]; 18 | config.criticLayerSizes = [64, 64]; 19 | config.useGPU = true; % 根据需要设置是否使用GPU 20 | 21 | % 创建PPO代理 22 | agent = PPOAgent(config); 23 | 24 | % 加载模型 25 | fprintf('加载模型: %s\n', modelPath); 26 | agent.loadModel(modelPath); 27 | 28 | % 测试参数 29 | numEpisodes = 10; % 测试回合数 30 | renderTest = true; % 是否可视化测试过程 31 | 32 | % 测试已训练的模型 33 | fprintf('开始测试,共%d回合...\n', numEpisodes); 34 | totalReward = 0; 35 | totalSteps = 0; 36 | 37 | % 创建环境 38 | env = feval(config.envName); 39 | 40 | for episode = 1:numEpisodes 41 | % 重置环境 42 | obs = env.reset(); 43 | episodeReward = 0; 44 | steps = 0; 45 | done = false; 46 | 47 | while ~done 48 | % 转换为dlarray并根据需要迁移到GPU 49 | if agent.useGPU 50 | dlObs = dlarray(single(obs), 'CB'); 51 | dlObs = gpuArray(dlObs); 52 | else 53 | dlObs = dlarray(single(obs), 'CB'); 54 | end 55 | 56 | % 使用确定性策略(无探索) 57 | if isa(agent.actorNet, 'DiscreteActorNetwork') 58 | action = agent.actorNet.getBestAction(dlObs); 59 | else 60 | action = agent.actorNet.getMeanAction(dlObs); 61 | end 62 | 63 | % 转换为CPU并提取数值 64 | if agent.useGPU 65 | action = gather(extractdata(action)); 66 | else 67 | action = extractdata(action); 68 | end 69 | 70 | % 如果是离散动作,转换为索引 71 | if env.isDiscrete 72 | [~, actionIdx] = max(action); 73 | action = actionIdx - 1; % 转为0-索引 74 | end 75 | 76 | % 执行动作 77 | [obs, reward, done, ~] = env.step(action); 78 | 79 | % 更新统计 80 | episodeReward = episodeReward + reward; 81 | steps = steps + 1; 82 | 83 | % 如果需要渲染 84 | if renderTest 85 | env.render(); 86 | pause(0.01); % 控制渲染速度 87 | end 88 | end 89 | 90 | % 更新总统计 91 | totalReward = totalReward + episodeReward; 92 | totalSteps = totalSteps + steps; 93 | 94 | fprintf('回合 %d: 奖励 = %.2f, 步数 = %d\n', episode, episodeReward, steps); 95 | end 96 | 97 | % 打印测试结果摘要 98 | fprintf('\n测试结果摘要:\n'); 99 | fprintf('平均奖励: %.2f\n', totalReward / numEpisodes); 100 | fprintf('平均步数: %.2f\n', totalSteps / numEpisodes); 101 | fprintf('测试完成\n'); 102 | -------------------------------------------------------------------------------- /examples/test_dcmotor.m: -------------------------------------------------------------------------------- 1 | % test_dcmotor.m 2 | % 测试已训练的直流电机控制系统PPO模型 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 模型路径 12 | modelPath = '../logs/dcmotor/model_iter_200.mat'; 13 | if ~exist(modelPath, 'file') 14 | % 如果找不到指定迭代次数的模型,尝试找到目录中的任意模型 15 | logDir = '../logs/dcmotor'; 16 | files = dir(fullfile(logDir, 'model_iter_*.mat')); 17 | if ~isempty(files) 18 | [~, 
idx] = max([files.datenum]); 19 | modelPath = fullfile(logDir, files(idx).name); 20 | fprintf('未找到指定模型,将使用最新模型: %s\n', modelPath); 21 | else 22 | error('找不到任何训练好的模型'); 23 | end 24 | end 25 | 26 | % 加载配置 27 | config = PPOConfig(); 28 | config.envName = 'DCMotorEnv'; 29 | config.actorLayerSizes = [128, 128, 64]; 30 | config.criticLayerSizes = [128, 128, 64]; 31 | config.useGPU = true; % 根据需要设置是否使用GPU 32 | 33 | % 创建PPO代理 34 | agent = PPOAgent(config); 35 | 36 | % 加载模型 37 | fprintf('加载模型: %s\n', modelPath); 38 | agent.loadModel(modelPath); 39 | 40 | % 测试参数 41 | numEpisodes = 5; % 测试回合数 42 | renderTest = true; % 是否可视化测试过程 43 | changingTarget = true; % 是否在回合中改变目标角度 44 | 45 | % 测试已训练的模型 46 | fprintf('开始测试,共%d回合...\n', numEpisodes); 47 | totalReward = 0; 48 | totalSteps = 0; 49 | successCount = 0; 50 | 51 | % 创建环境 52 | env = feval(config.envName); 53 | 54 | % 创建记录数据的结构 55 | recordData = struct(); 56 | recordData.episodes = cell(numEpisodes, 1); 57 | 58 | for episode = 1:numEpisodes 59 | % 重置环境 60 | obs = env.reset(); 61 | episodeReward = 0; 62 | steps = 0; 63 | done = false; 64 | 65 | % 记录当前回合数据 66 | episodeData = struct(); 67 | episodeData.time = []; 68 | episodeData.angles = []; 69 | episodeData.targets = []; 70 | episodeData.angleDiffs = []; 71 | episodeData.speeds = []; 72 | episodeData.currents = []; 73 | episodeData.actions = []; 74 | episodeData.rewards = []; 75 | 76 | % 设定固定目标改变时间点(如果changingTarget为true) 77 | targetChangePoints = [100, 200, 300]; 78 | nextTargetChange = 1; 79 | 80 | while ~done 81 | % 转换为dlarray并根据需要迁移到GPU 82 | if agent.useGPU 83 | dlObs = dlarray(single(obs), 'CB'); 84 | dlObs = gpuArray(dlObs); 85 | else 86 | dlObs = dlarray(single(obs), 'CB'); 87 | end 88 | 89 | % 使用确定性策略(使用均值,无探索) 90 | action = agent.actorNet.getMeanAction(dlObs); 91 | 92 | % 转换为CPU并提取数值 93 | if agent.useGPU 94 | action = gather(extractdata(action)); 95 | else 96 | action = extractdata(action); 97 | end 98 | 99 | % 执行动作 100 | [obs, reward, done, info] = env.step(action); 101 | 102 | % 更新统计 103 | episodeReward = episodeReward + reward; 104 | steps = steps + 1; 105 | 106 | % 记录数据 107 | episodeData.time(end+1) = steps; 108 | episodeData.angles(end+1) = env.state(1); 109 | episodeData.targets(end+1) = env.targetAngle; 110 | episodeData.angleDiffs(end+1) = obs(1); % 观察中的角度差 111 | episodeData.speeds(end+1) = env.state(2); 112 | episodeData.currents(end+1) = env.state(3); 113 | episodeData.actions(end+1) = action; 114 | episodeData.rewards(end+1) = reward; 115 | 116 | % 如果需要随机改变目标角度 117 | if changingTarget && nextTargetChange <= length(targetChangePoints) && steps == targetChangePoints(nextTargetChange) 118 | env.resetTarget(); 119 | fprintf('回合 %d: 在步骤 %d 改变目标角度为 %.2f rad\n', episode, steps, env.targetAngle); 120 | nextTargetChange = nextTargetChange + 1; 121 | end 122 | 123 | % 如果需要渲染 124 | if renderTest 125 | env.render(); 126 | pause(0.01); % 控制渲染速度 127 | end 128 | end 129 | 130 | % 保存回合数据 131 | recordData.episodes{episode} = episodeData; 132 | 133 | % 判断是否成功完成任务 134 | finalDistance = info.distance; 135 | if finalDistance < 0.1 % 使用较小的阈值判断是否达到目标 136 | successCount = successCount + 1; 137 | end 138 | 139 | % 更新总统计 140 | totalReward = totalReward + episodeReward; 141 | totalSteps = totalSteps + steps; 142 | 143 | fprintf('回合 %d: 奖励 = %.2f, 步数 = %d, 最终角度误差 = %.4f rad\n', ... 
144 | episode, episodeReward, steps, finalDistance); 145 | end 146 | 147 | % 打印测试结果摘要 148 | fprintf('\n测试结果摘要:\n'); 149 | fprintf('平均奖励: %.2f\n', totalReward / numEpisodes); 150 | fprintf('平均步数: %.2f\n', totalSteps / numEpisodes); 151 | fprintf('成功率: %.1f%%\n', successCount / numEpisodes * 100); 152 | 153 | % 绘制测试结果汇总图表 154 | figure('Name', '直流电机控制测试结果', 'Position', [100, 100, 1000, 800]); 155 | 156 | % 1. 角度跟踪性能 157 | subplot(2, 2, 1); 158 | hold on; 159 | for i = 1:numEpisodes 160 | plot(recordData.episodes{i}.time, recordData.episodes{i}.angleDiffs, 'LineWidth', 1.5); 161 | end 162 | title('角度误差'); 163 | xlabel('时间步'); 164 | ylabel('角度误差 (rad)'); 165 | grid on; 166 | legend(arrayfun(@(x) sprintf('回合 %d', x), 1:numEpisodes, 'UniformOutput', false)); 167 | 168 | % 2. 控制信号 169 | subplot(2, 2, 2); 170 | hold on; 171 | for i = 1:numEpisodes 172 | plot(recordData.episodes{i}.time, recordData.episodes{i}.actions, 'LineWidth', 1.5); 173 | end 174 | title('控制信号 (归一化电压)'); 175 | xlabel('时间步'); 176 | ylabel('控制信号 [-1,1]'); 177 | grid on; 178 | 179 | % 3. 角速度 180 | subplot(2, 2, 3); 181 | hold on; 182 | for i = 1:numEpisodes 183 | plot(recordData.episodes{i}.time, recordData.episodes{i}.speeds, 'LineWidth', 1.5); 184 | end 185 | title('角速度'); 186 | xlabel('时间步'); 187 | ylabel('角速度 (rad/s)'); 188 | grid on; 189 | 190 | % 4. 电流 191 | subplot(2, 2, 4); 192 | hold on; 193 | for i = 1:numEpisodes 194 | plot(recordData.episodes{i}.time, recordData.episodes{i}.currents, 'LineWidth', 1.5); 195 | end 196 | title('电机电流'); 197 | xlabel('时间步'); 198 | ylabel('电流 (A)'); 199 | grid on; 200 | 201 | % 如果想要显示更详细的单回合性能,可以单独绘制 202 | if numEpisodes > 0 203 | bestEpisode = 1; % 可以根据奖励选择最佳回合 204 | 205 | figure('Name', sprintf('直流电机控制 - 回合 %d 详细分析', bestEpisode), 'Position', [100, 100, 1000, 600]); 206 | 207 | % 1. 角度跟踪 208 | subplot(3, 1, 1); 209 | plot(recordData.episodes{bestEpisode}.time, recordData.episodes{bestEpisode}.angles, 'b-', 'LineWidth', 1.5); 210 | hold on; 211 | plot(recordData.episodes{bestEpisode}.time, recordData.episodes{bestEpisode}.targets, 'r--', 'LineWidth', 1.5); 212 | title('角度跟踪'); 213 | xlabel('时间步'); 214 | ylabel('角度 (rad)'); 215 | legend('实际角度', '目标角度'); 216 | grid on; 217 | 218 | % 2. 角速度 219 | subplot(3, 1, 2); 220 | plot(recordData.episodes{bestEpisode}.time, recordData.episodes{bestEpisode}.speeds, 'g-', 'LineWidth', 1.5); 221 | title('角速度'); 222 | xlabel('时间步'); 223 | ylabel('角速度 (rad/s)'); 224 | grid on; 225 | 226 | % 3. 
控制信号与奖励 227 | subplot(3, 1, 3); 228 | yyaxis left; 229 | plot(recordData.episodes{bestEpisode}.time, recordData.episodes{bestEpisode}.actions, 'm-', 'LineWidth', 1.5); 230 | ylabel('控制信号 [-1,1]'); 231 | 232 | yyaxis right; 233 | plot(recordData.episodes{bestEpisode}.time, recordData.episodes{bestEpisode}.rewards, 'k-', 'LineWidth', 1); 234 | ylabel('奖励'); 235 | 236 | title('控制信号与奖励'); 237 | xlabel('时间步'); 238 | grid on; 239 | legend('控制信号', '奖励'); 240 | end 241 | 242 | fprintf('测试完成,分析图表已生成\n'); 243 | -------------------------------------------------------------------------------- /examples/test_doublependulum.m: -------------------------------------------------------------------------------- 1 | %% 测试MAPPO在双倒立摆问题上的性能 2 | % 本脚本用于评估训练好的MAPPO智能体在双倒立摆环境中的表现 3 | % 比较不同控制策略(如单智能体、多智能体)的性能差异 4 | 5 | clc; 6 | clear; 7 | close all; 8 | 9 | % 将必要的路径添加到MATLAB路径 10 | addpath('../core'); 11 | addpath('../environments'); 12 | addpath('../config'); 13 | addpath('../utils'); 14 | 15 | % 加载配置 16 | config = default_doublependulum_config(); 17 | 18 | % 创建MAPPO智能体 19 | fprintf('初始化MAPPO智能体...\n'); 20 | mappoAgent = MAPPOAgent(config); 21 | 22 | % 加载训练好的模型 23 | modelPath = fullfile(config.logDir, 'model_final.mat'); 24 | if exist(modelPath, 'file') 25 | fprintf('加载已训练的模型: %s\n', modelPath); 26 | mappoAgent.loadModel(modelPath); 27 | else 28 | error('找不到已训练的模型,请先运行train_doublependulum.m脚本'); 29 | end 30 | 31 | % 创建图形窗口以便记录结果 32 | figure('Name', '双倒立摆测试结果', 'Position', [100, 100, 1200, 600]); 33 | 34 | % 测试1:协作MAPPO控制 35 | fprintf('\n测试1:使用MAPPO多智能体协作控制\n'); 36 | numEpisodes = 3; 37 | maxSteps = 200; 38 | 39 | % 存储测试数据 40 | theta1_history = zeros(maxSteps, numEpisodes); 41 | theta2_history = zeros(maxSteps, numEpisodes); 42 | reward_history = zeros(maxSteps, numEpisodes); 43 | 44 | for ep = 1:numEpisodes 45 | fprintf('回合 %d/%d\n', ep, numEpisodes); 46 | 47 | % 重置环境 48 | env = DoublePendulumEnv(); 49 | [agentObs, ~] = env.reset(); 50 | 51 | % 存储轨迹信息 52 | ep_theta1 = zeros(maxSteps, 1); 53 | ep_theta2 = zeros(maxSteps, 1); 54 | ep_reward = zeros(maxSteps, 1); 55 | 56 | % 运行回合 57 | for t = 1:maxSteps 58 | % 为每个智能体选择动作 59 | actions = cell(config.numAgents, 1); 60 | 61 | for i = 1:config.numAgents 62 | if config.useGPU 63 | dlObs = dlarray(single(agentObs{i}), 'CB'); 64 | dlObs = gpuArray(dlObs); 65 | else 66 | dlObs = dlarray(single(agentObs{i}), 'CB'); 67 | end 68 | 69 | % 使用确定性策略(均值) 70 | action = mappoAgent.actorNets{i}.getMeanAction(dlObs); 71 | 72 | if config.useGPU 73 | action = gather(extractdata(action)); 74 | else 75 | action = extractdata(action); 76 | end 77 | 78 | actions{i} = action; 79 | end 80 | 81 | % 执行动作 82 | [agentObs, ~, reward, done, info] = env.step(actions); 83 | 84 | % 存储状态信息 85 | ep_theta1(t) = info.theta1; 86 | ep_theta2(t) = info.theta2; 87 | ep_reward(t) = reward; 88 | 89 | % 渲染 90 | env.render(); 91 | pause(0.01); % 减缓速度以便观察 92 | 93 | % 如果回合结束则提前停止 94 | if done 95 | break; 96 | end 97 | end 98 | 99 | % 存储轨迹数据 100 | theta1_history(:, ep) = ep_theta1; 101 | theta2_history(:, ep) = ep_theta2; 102 | reward_history(:, ep) = ep_reward; 103 | 104 | % 关闭环境 105 | env.close(); 106 | end 107 | 108 | % 绘制结果 109 | subplot(2, 2, 1); 110 | plot(theta1_history - pi); 111 | hold on; 112 | plot([0, maxSteps], [0, 0], 'k--'); 113 | hold off; 114 | title('MAPPO控制 - 摆杆1角度偏差'); 115 | xlabel('步数'); 116 | ylabel('角度偏差 (rad)'); 117 | grid on; 118 | legend('回合1', '回合2', '回合3', '目标'); 119 | 120 | subplot(2, 2, 2); 121 | plot(theta2_history - pi); 122 | hold on; 123 | plot([0, maxSteps], [0, 0], 'k--'); 124 | hold 
off; 125 | title('MAPPO控制 - 摆杆2角度偏差'); 126 | xlabel('步数'); 127 | ylabel('角度偏差 (rad)'); 128 | grid on; 129 | legend('回合1', '回合2', '回合3', '目标'); 130 | 131 | subplot(2, 2, 3); 132 | plot(cumsum(reward_history)); 133 | title('MAPPO控制 - 累积奖励'); 134 | xlabel('步数'); 135 | ylabel('累积奖励'); 136 | grid on; 137 | legend('回合1', '回合2', '回合3'); 138 | 139 | % 测试2:演示为什么单一智能体PPO无法有效解决此问题 140 | fprintf('\n测试2:演示单一智能体的局限性\n'); 141 | 142 | % 创建一个简化版的单一控制器场景 143 | % 注意:这里我们使用单一控制器(对两个摆杆使用相同的力矩) 144 | % 这是为了模拟单一PPO智能体的行为,并展示它的局限性 145 | 146 | subplot(2, 2, 4); 147 | hold on; 148 | 149 | % 创建环境 150 | env = DoublePendulumEnv(); 151 | [agentObs, ~] = env.reset(); 152 | 153 | % 记录数据 154 | single_theta1 = zeros(maxSteps, 1); 155 | single_theta2 = zeros(maxSteps, 1); 156 | 157 | % 运行单一控制场景(两个摆杆使用相同的力矩) 158 | for t = 1:maxSteps 159 | % 获取第一个智能体的动作决策 160 | if config.useGPU 161 | dlObs = dlarray(single(agentObs{1}), 'CB'); 162 | dlObs = gpuArray(dlObs); 163 | else 164 | dlObs = dlarray(single(agentObs{1}), 'CB'); 165 | end 166 | 167 | action1 = mappoAgent.actorNets{1}.getMeanAction(dlObs); 168 | 169 | if config.useGPU 170 | action1 = gather(extractdata(action1)); 171 | else 172 | action1 = extractdata(action1); 173 | end 174 | 175 | % 对两个摆杆使用相同的控制动作(模拟单一PPO控制) 176 | actions = {action1, action1}; 177 | 178 | % 执行动作 179 | [agentObs, ~, ~, done, info] = env.step(actions); 180 | 181 | % 记录数据 182 | single_theta1(t) = info.theta1; 183 | single_theta2(t) = info.theta2; 184 | 185 | % 如果回合结束则提前停止 186 | if done 187 | break; 188 | end 189 | end 190 | 191 | % 绘制单一控制的结果 192 | plot(single_theta1 - pi, 'r-', 'LineWidth', 2); 193 | plot(single_theta2 - pi, 'b-', 'LineWidth', 2); 194 | plot([0, maxSteps], [0, 0], 'k--'); 195 | title('单一控制策略(模拟单一PPO)'); 196 | xlabel('步数'); 197 | ylabel('角度偏差 (rad)'); 198 | grid on; 199 | legend('摆杆1', '摆杆2', '目标'); 200 | hold off; 201 | 202 | % 打印分析结果 203 | fprintf('\n分析结果:\n'); 204 | fprintf('1. MAPPO多智能体方法:\n'); 205 | fprintf(' - 能够协调两个控制器分别控制两个摆杆\n'); 206 | fprintf(' - 实现了有效的倒立位置稳定控制\n'); 207 | fprintf(' - 每个智能体专注于自己的摆杆控制,但同时考虑另一个摆杆的状态\n\n'); 208 | 209 | fprintf('2. 
单一控制策略(模拟单一PPO):\n'); 210 | fprintf(' - 无法同时有效控制两个摆杆\n'); 211 | fprintf(' - 由于两个摆杆动力学特性不同,使用相同的控制信号会导致不稳定\n'); 212 | fprintf(' - 这证明了为什么这个问题需要多智能体方法而不能用单一PPO解决\n\n'); 213 | 214 | fprintf('结论: 双倒立摆问题是一个多智能体合作任务的典型例子,\n'); 215 | fprintf(' MAPPO能够有效地解决这个单一PPO无法解决的问题。\n'); 216 | 217 | % 保存结果图 218 | saveas(gcf, fullfile(config.logDir, 'doublependulum_test_results.png')); 219 | fprintf('测试结果图已保存到: %s\n', fullfile(config.logDir, 'doublependulum_test_results.png')); 220 | -------------------------------------------------------------------------------- /examples/train_acmotor.m: -------------------------------------------------------------------------------- 1 | % train_acmotor.m 2 | % 训练交流感应电机控制系统的PPO代理 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 创建日志目录 12 | logDir = '../logs/acmotor'; 13 | if ~exist(logDir, 'dir') 14 | mkdir(logDir); 15 | end 16 | 17 | % 加载配置 18 | config = PPOConfig(); 19 | 20 | % 环境配置 21 | config.envName = 'ACMotorEnv'; 22 | 23 | % 网络配置 - 复杂的交流电机系统需要更深更宽的网络 24 | config.actorLayerSizes = [256, 256, 128, 64]; 25 | config.criticLayerSizes = [256, 256, 128, 64]; 26 | 27 | % 算法超参数 28 | config.gamma = 0.99; 29 | config.lambda = 0.95; 30 | config.epsilon = 0.1; 31 | config.entropyCoef = 0.001; 32 | config.vfCoef = 0.5; 33 | config.maxGradNorm = 0.5; 34 | 35 | % 优化器配置 36 | config.actorLearningRate = 5e-5; 37 | config.criticLearningRate = 5e-5; 38 | config.momentum = 0.9; 39 | 40 | % 训练配置 41 | config.numIterations = 400; 42 | config.batchSize = 256; 43 | config.epochsPerIter = 15; 44 | config.trajectoryLen = 500; 45 | config.numTrajectories = 32; 46 | 47 | % 硬件配置 48 | config.useGPU = true; 49 | 50 | % 日志配置 51 | config.logDir = logDir; 52 | config.evalFreq = 10; 53 | config.numEvalEpisodes = 5; 54 | config.saveModelFreq = 40; 55 | 56 | % 创建PPO代理 57 | fprintf('正在初始化PPO代理...\n'); 58 | agent = PPOAgent(config); 59 | 60 | % 创建Logger实例 61 | logger = Logger(config.logDir); 62 | 63 | % 训练代理 64 | fprintf('开始训练交流感应电机控制系统...\n'); 65 | fprintf('训练过程中会模拟工业环境下的负载突变场景\n'); 66 | 67 | % 训练循环 68 | for iter = 1:config.numIterations 69 | fprintf('迭代 %d/%d\n', iter, config.numIterations); 70 | 71 | % 收集轨迹 72 | fprintf(' 收集轨迹数据...\n'); 73 | trajectories = agent.collectTrajectories(config.trajectoryLen, config.numTrajectories); 74 | 75 | % 计算优势和回报 76 | fprintf(' 计算优势估计...\n'); 77 | trajectories = agent.computeAdvantagesAndReturns(trajectories); 78 | 79 | % 更新策略 80 | fprintf(' 更新策略...\n'); 81 | [actorLoss, criticLoss, entropy] = agent.updatePolicy(trajectories, config.epochsPerIter, config.batchSize); 82 | 83 | % 记录训练信息 84 | meanReturn = mean([trajectories.return]); 85 | meanLength = mean([trajectories.length]); 86 | 87 | % 打印当前迭代的指标 88 | fprintf(' 平均回报: %.2f\n', meanReturn); 89 | fprintf(' 平均长度: %.2f\n', meanLength); 90 | fprintf(' Actor损失: %.4f\n', actorLoss); 91 | fprintf(' Critic损失: %.4f\n', criticLoss); 92 | fprintf(' 策略熵: %.4f\n', entropy); 93 | 94 | % 将训练指标记录到logger 95 | logger.logTrainingMetrics(iter, meanReturn, meanLength, actorLoss, criticLoss, entropy); 96 | 97 | % 评估当前策略 98 | if mod(iter, config.evalFreq) == 0 99 | fprintf(' 评估当前策略...\n'); 100 | evalResult = agent.evaluate(config.numEvalEpisodes); 101 | fprintf(' 评估平均回报: %.2f ± %.2f\n', evalResult.meanReturn, evalResult.stdReturn); 102 | 103 | % 记录评估指标 104 | logger.logEvaluationMetrics(iter, evalResult.meanReturn, evalResult.stdReturn, evalResult.minReturn, evalResult.maxReturn); 105 | 106 | % 可视化速度跟踪性能 107 | visualizePerformance(agent, 
config.envName); 108 | end 109 | 110 | % 保存模型 111 | if mod(iter, config.saveModelFreq) == 0 112 | modelPath = fullfile(config.logDir, sprintf('model_iter_%d.mat', iter)); 113 | fprintf(' 保存模型到: %s\n', modelPath); 114 | agent.saveModel(modelPath); 115 | end 116 | end 117 | 118 | % 训练完成后保存最终模型 119 | finalModelPath = fullfile(config.logDir, 'model_final.mat'); 120 | fprintf('训练完成,保存最终模型到: %s\n', finalModelPath); 121 | agent.saveModel(finalModelPath); 122 | 123 | % 最终评估 124 | fprintf('进行最终评估...\n'); 125 | evalResult = agent.evaluate(20); 126 | 127 | % 显示评估结果 128 | fprintf('最终评估结果:\n'); 129 | fprintf(' 平均回报: %.2f ± %.2f\n', evalResult.meanReturn, evalResult.stdReturn); 130 | fprintf(' 最小回报: %.2f\n', evalResult.minReturn); 131 | fprintf(' 最大回报: %.2f\n', evalResult.maxReturn); 132 | fprintf(' 平均回合长度: %.2f\n', evalResult.meanLength); 133 | 134 | % 绘制训练曲线 135 | logger.plotTrainingCurves(); 136 | 137 | % 可视化控制性能的函数 138 | function visualizePerformance(agent, envName) 139 | % 创建环境 140 | env = feval(envName); 141 | 142 | % 重置环境 143 | obs = env.reset(); 144 | done = false; 145 | 146 | % 创建图形 147 | figure('Name', '交流电机控制性能评估', 'Position', [100, 100, 1000, 800]); 148 | 149 | % 速度响应子图 150 | subplot(2, 2, 1); 151 | speedPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 152 | hold on; 153 | targetPlot = plot(0, 0, 'r--', 'LineWidth', 1.5); 154 | title('速度响应'); 155 | xlabel('时间 (s)'); 156 | ylabel('速度 (rad/s)'); 157 | legend('实际速度', '目标速度'); 158 | grid on; 159 | 160 | % 电流响应子图 161 | subplot(2, 2, 2); 162 | idPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 163 | hold on; 164 | iqPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 165 | title('d-q轴电流'); 166 | xlabel('时间 (s)'); 167 | ylabel('电流 (A)'); 168 | legend('id', 'iq'); 169 | grid on; 170 | 171 | % 控制信号子图 172 | subplot(2, 2, 3); 173 | VdPlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 174 | hold on; 175 | VqPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 176 | title('控制信号 (d-q轴电压)'); 177 | xlabel('时间 (s)'); 178 | ylabel('电压 (V)'); 179 | legend('Vd', 'Vq'); 180 | grid on; 181 | 182 | % 转矩子图 183 | subplot(2, 2, 4); 184 | tePlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 185 | hold on; 186 | tlPlot = plot(0, 0, 'r-', 'LineWidth', 1.5); 187 | title('转矩'); 188 | xlabel('时间 (s)'); 189 | ylabel('转矩 (N·m)'); 190 | legend('电磁转矩', '负载转矩'); 191 | grid on; 192 | 193 | % 数据记录 194 | timeData = []; 195 | speedData = []; 196 | targetSpeedData = []; 197 | idData = []; 198 | iqData = []; 199 | VdData = []; 200 | VqData = []; 201 | TeData = []; 202 | TlData = []; 203 | 204 | % 评估最多500步 205 | maxSteps = 500; 206 | step = 0; 207 | 208 | % 运行一个回合 209 | while ~done && step < maxSteps 210 | % 转换为dlarray并根据需要迁移到GPU 211 | if agent.useGPU 212 | dlObs = dlarray(single(obs), 'CB'); 213 | dlObs = gpuArray(dlObs); 214 | else 215 | dlObs = dlarray(single(obs), 'CB'); 216 | end 217 | 218 | % 使用确定性策略(使用均值) 219 | action = agent.actorNet.getMeanAction(dlObs); 220 | 221 | % 转换为CPU并提取数值 222 | if agent.useGPU 223 | action = gather(extractdata(action)); 224 | else 225 | action = extractdata(action); 226 | end 227 | 228 | % 执行动作 229 | [obs, reward, done, info] = env.step(action); 230 | 231 | % 记录数据 232 | step = step + 1; 233 | timeData(end+1) = step * env.dt; 234 | speedData(end+1) = info.speed; 235 | targetSpeedData(end+1) = info.targetSpeed; 236 | idData(end+1) = info.id; 237 | iqData(end+1) = info.iq; 238 | VdData(end+1) = action(1) * env.maxVoltage; 239 | VqData(end+1) = action(2) * env.maxVoltage; 240 | TeData(end+1) = info.Te; 241 | TlData(end+1) = info.Tl; 242 | 243 | % 每10步更新图形 244 | if mod(step, 10) == 0 245 | % 更新速度图 246 
| speedPlot.XData = timeData; 247 | speedPlot.YData = speedData; 248 | targetPlot.XData = timeData; 249 | targetPlot.YData = targetSpeedData; 250 | 251 | % 更新电流图 252 | idPlot.XData = timeData; 253 | idPlot.YData = idData; 254 | iqPlot.XData = timeData; 255 | iqPlot.YData = iqData; 256 | 257 | % 更新控制信号图 258 | VdPlot.XData = timeData; 259 | VdPlot.YData = VdData; 260 | VqPlot.XData = timeData; 261 | VqPlot.YData = VqData; 262 | 263 | % 更新转矩图 264 | tePlot.XData = timeData; 265 | tePlot.YData = TeData; 266 | tlPlot.XData = timeData; 267 | tlPlot.YData = TlData; 268 | 269 | % 调整坐标轴 270 | for i = 1:4 271 | subplot(2, 2, i); 272 | xlim([0, max(timeData)]); 273 | end 274 | 275 | % 刷新图形 276 | drawnow; 277 | end 278 | end 279 | 280 | % 保存评估图 281 | saveas(gcf, fullfile(agent.config.logDir, 'performance_evaluation.png')); 282 | 283 | % 关闭图形 284 | pause(1); 285 | close; 286 | end 287 | -------------------------------------------------------------------------------- /examples/train_cartpole.m: -------------------------------------------------------------------------------- 1 | % train_cartpole.m 2 | % 训练倒立摆环境示例脚本 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 创建日志目录 12 | logDir = '../logs/cartpole'; 13 | if ~exist(logDir, 'dir') 14 | mkdir(logDir); 15 | end 16 | 17 | % 加载配置 18 | config = PPOConfig(); 19 | 20 | % 环境配置 21 | config.envName = 'CartPoleEnv'; 22 | 23 | % 网络配置 24 | config.actorLayerSizes = [64, 64]; 25 | config.criticLayerSizes = [64, 64]; 26 | 27 | % 算法超参数 28 | config.gamma = 0.99; 29 | config.lambda = 0.95; 30 | config.epsilon = 0.2; 31 | config.entropyCoef = 0.01; 32 | config.vfCoef = 0.5; 33 | config.maxGradNorm = 0.5; 34 | 35 | % 优化器配置 36 | config.actorLearningRate = 3e-4; 37 | config.criticLearningRate = 3e-4; 38 | config.momentum = 0.9; 39 | 40 | % 训练配置 41 | config.numIterations = 100; 42 | config.batchSize = 64; 43 | config.epochsPerIter = 4; 44 | config.trajectoryLen = 200; 45 | 46 | % 硬件配置 47 | config.useGPU = true; 48 | 49 | % 日志配置 50 | config.logDir = logDir; 51 | config.evalFreq = 5; 52 | config.numEvalEpisodes = 10; 53 | config.saveModelFreq = 10; 54 | 55 | % 创建PPO代理 56 | agent = PPOAgent(config); 57 | 58 | % 训练代理 59 | fprintf('开始训练倒立摆环境...\n'); 60 | agent.train(config.numIterations); 61 | 62 | % 训练完成后评估 63 | fprintf('训练完成,开始评估...\n'); 64 | evalResult = agent.evaluate(20); 65 | 66 | % 显示评估结果 67 | fprintf('评估结果:\n'); 68 | fprintf(' 平均回报: %.2f ± %.2f\n', evalResult.meanReturn, evalResult.stdReturn); 69 | fprintf(' 最小回报: %.2f\n', evalResult.minReturn); 70 | fprintf(' 最大回报: %.2f\n', evalResult.maxReturn); 71 | fprintf(' 平均回合长度: %.2f\n', evalResult.meanLength); 72 | 73 | % 可视化一个回合 74 | fprintf('可视化一个回合...\n'); 75 | env = feval(config.envName); 76 | obs = env.reset(); 77 | done = false; 78 | 79 | while ~done 80 | % 转换为dlarray并根据需要迁移到GPU 81 | if agent.useGPU 82 | dlObs = dlarray(single(obs), 'CB'); 83 | dlObs = gpuArray(dlObs); 84 | else 85 | dlObs = dlarray(single(obs), 'CB'); 86 | end 87 | 88 | % 采样动作 89 | [action, ~] = agent.actorNet.sampleAction(dlObs); 90 | 91 | % 转换为CPU并提取数值 92 | if agent.useGPU 93 | action = gather(extractdata(action)); 94 | else 95 | action = extractdata(action); 96 | end 97 | 98 | % 执行动作 99 | [obs, reward, done, ~] = env.step(action); 100 | 101 | % 渲染环境 102 | env.render(); 103 | pause(0.01); 104 | end 105 | 106 | fprintf('演示完成\n'); 107 | -------------------------------------------------------------------------------- /examples/train_dcmotor.m: 
-------------------------------------------------------------------------------- 1 | % train_dcmotor.m 2 | % 训练直流电机控制系统的PPO代理 3 | 4 | % 添加路径 5 | addpath('../'); 6 | addpath('../core'); 7 | addpath('../environments'); 8 | addpath('../config'); 9 | addpath('../utils'); 10 | 11 | % 创建日志目录 12 | logDir = '../logs/dcmotor'; 13 | if ~exist(logDir, 'dir') 14 | mkdir(logDir); 15 | end 16 | 17 | % 加载配置 18 | config = PPOConfig(); 19 | 20 | % 环境配置 21 | config.envName = 'DCMotorEnv'; 22 | 23 | % 网络配置 - 连续控制问题推荐使用更深的网络 24 | config.actorLayerSizes = [128, 128, 64]; 25 | config.criticLayerSizes = [128, 128, 64]; 26 | 27 | % 算法超参数 28 | config.gamma = 0.99; 29 | config.lambda = 0.95; 30 | config.epsilon = 0.2; 31 | config.entropyCoef = 0.005; % 连续控制问题通常使用较小的熵系数 32 | config.vfCoef = 0.5; 33 | config.maxGradNorm = 0.5; 34 | 35 | % 优化器配置 36 | config.actorLearningRate = 1e-4; % 连续控制问题通常使用较小的学习率 37 | config.criticLearningRate = 1e-4; 38 | config.momentum = 0.9; 39 | 40 | % 训练配置 41 | config.numIterations = 200; 42 | config.batchSize = 128; 43 | config.epochsPerIter = 10; 44 | config.trajectoryLen = 250; 45 | 46 | % 硬件配置 47 | config.useGPU = true; 48 | 49 | % 日志配置 50 | config.logDir = logDir; 51 | config.evalFreq = 5; 52 | config.numEvalEpisodes = 10; 53 | config.saveModelFreq = 20; 54 | 55 | % 创建PPO代理 56 | fprintf('正在初始化PPO代理...\n'); 57 | agent = PPOAgent(config); 58 | 59 | % 训练代理 60 | fprintf('开始训练直流电机控制系统...\n'); 61 | agent.train(config.numIterations); 62 | 63 | % 训练完成后评估 64 | fprintf('训练完成,开始评估...\n'); 65 | evalResult = agent.evaluate(20); 66 | 67 | % 显示评估结果 68 | fprintf('评估结果:\n'); 69 | fprintf(' 平均回报: %.2f ± %.2f\n', evalResult.meanReturn, evalResult.stdReturn); 70 | fprintf(' 最小回报: %.2f\n', evalResult.minReturn); 71 | fprintf(' 最大回报: %.2f\n', evalResult.maxReturn); 72 | fprintf(' 平均回合长度: %.2f\n', evalResult.meanLength); 73 | 74 | % 可视化一个回合 75 | fprintf('可视化一个回合...\n'); 76 | env = feval(config.envName); 77 | obs = env.reset(); 78 | done = false; 79 | totalReward = 0; 80 | 81 | figure('Name', '直流电机控制测试', 'Position', [100, 100, 800, 600]); 82 | subplot(3, 1, 1); 83 | anglePlot = plot(0, 0, 'b-', 'LineWidth', 1.5); 84 | hold on; 85 | targetPlot = plot(0, 0, 'r--', 'LineWidth', 1.5); 86 | title('角度跟踪'); 87 | xlabel('时间步'); 88 | ylabel('角度 (rad)'); 89 | legend('实际角度', '目标角度'); 90 | grid on; 91 | 92 | subplot(3, 1, 2); 93 | speedPlot = plot(0, 0, 'g-', 'LineWidth', 1.5); 94 | title('角速度'); 95 | xlabel('时间步'); 96 | ylabel('角速度 (rad/s)'); 97 | grid on; 98 | 99 | subplot(3, 1, 3); 100 | actionPlot = plot(0, 0, 'm-', 'LineWidth', 1.5); 101 | title('控制信号'); 102 | xlabel('时间步'); 103 | ylabel('电压比例 [-1,1]'); 104 | grid on; 105 | 106 | % 记录数据 107 | timeSteps = []; 108 | angles = []; 109 | targets = []; 110 | speeds = []; 111 | actions = []; 112 | 113 | step = 0; 114 | while ~done 115 | % 转换为dlarray并根据需要迁移到GPU 116 | if agent.useGPU 117 | dlObs = dlarray(single(obs), 'CB'); 118 | dlObs = gpuArray(dlObs); 119 | else 120 | dlObs = dlarray(single(obs), 'CB'); 121 | end 122 | 123 | % 采样动作 124 | [action, ~] = agent.actorNet.sampleAction(dlObs); 125 | 126 | % 转换为CPU并提取数值 127 | if agent.useGPU 128 | action = gather(extractdata(action)); 129 | else 130 | action = extractdata(action); 131 | end 132 | 133 | % 执行动作 134 | [obs, reward, done, info] = env.step(action); 135 | totalReward = totalReward + reward; 136 | 137 | % 记录数据 138 | step = step + 1; 139 | timeSteps(end+1) = step; 140 | angles(end+1) = env.state(1); 141 | targets(end+1) = info.targetAngle; 142 | speeds(end+1) = env.state(2); 143 | actions(end+1) = action; 144 | 145 | % 
更新图形 146 | anglePlot.XData = timeSteps; 147 | anglePlot.YData = angles; 148 | targetPlot.XData = timeSteps; 149 | targetPlot.YData = targets; 150 | speedPlot.XData = timeSteps; 151 | speedPlot.YData = speeds; 152 | actionPlot.XData = timeSteps; 153 | actionPlot.YData = actions; 154 | 155 | % 调整X轴范围 156 | for i = 1:3 157 | subplot(3, 1, i); 158 | xlim([1, max(1, step)]); 159 | end 160 | subplot(3, 1, 1); 161 | ylim([0, 2*pi]); 162 | 163 | % 渲染环境 164 | env.render(); 165 | 166 | pause(0.01); 167 | end 168 | 169 | fprintf('演示完成,总奖励: %.2f\n', totalReward); 170 | -------------------------------------------------------------------------------- /examples/train_doublependulum.m: -------------------------------------------------------------------------------- 1 | %% 使用MAPPO训练双倒立摆控制器 2 | % 本脚本展示了如何使用MAPPO算法训练多智能体来协同控制双倒立摆系统 3 | % 这是一个只能用多智能体方法而不能用单一智能体方法解决的问题示例 4 | 5 | clc; 6 | clear; 7 | close all; 8 | 9 | % 将必要的路径添加到MATLAB路径 10 | addpath('../core'); 11 | addpath('../environments'); 12 | addpath('../config'); 13 | addpath('../utils'); 14 | 15 | % 加载配置 16 | config = default_doublependulum_config(); 17 | 18 | % 创建日志目录 19 | if ~exist(config.logDir, 'dir') 20 | mkdir(config.logDir); 21 | end 22 | 23 | % 创建MAPPO智能体 24 | fprintf('初始化MAPPO智能体...\n'); 25 | mappoAgent = MAPPOAgent(config); 26 | 27 | % 训练智能体 28 | fprintf('开始训练,总迭代次数: %d\n', config.numIterations); 29 | mappoAgent.train(config.numIterations); 30 | 31 | % 训练完成 32 | fprintf('训练完成!最终模型已保存到 %s\n', fullfile(config.logDir, 'model_final.mat')); 33 | 34 | % 训练后评估和可视化 35 | fprintf('评估训练后的智能体性能...\n'); 36 | numTestEpisodes = 5; 37 | renderEpisodes = true; 38 | 39 | % 评估并可视化 40 | testResults = mappoAgent.evaluate(numTestEpisodes, renderEpisodes); 41 | 42 | fprintf('测试结果:\n'); 43 | fprintf(' 平均回报: %.2f ± %.2f\n', testResults.meanReturn, testResults.stdReturn); 44 | fprintf(' 最小回报: %.2f\n', testResults.minReturn); 45 | fprintf(' 最大回报: %.2f\n', testResults.maxReturn); 46 | fprintf(' 平均长度: %.2f\n', testResults.meanLength); 47 | 48 | % 绘制训练曲线 49 | mappoAgent.logger.plotTrainingCurves(); 50 | 51 | fprintf('训练和评估完成!\n'); 52 | fprintf('MAPPO成功解决了双倒立摆问题,这是一个需要多智能体协作的任务\n'); 53 | fprintf('单一智能体PPO无法有效解决此问题,因为它需要两个控制器协同工作\n'); 54 | -------------------------------------------------------------------------------- /utils/Logger.m: -------------------------------------------------------------------------------- 1 | classdef Logger < handle 2 | % Logger 训练日志记录器 3 | % 用于记录和可视化训练过程中的各种指标 4 | 5 | properties 6 | logDir % 日志保存目录 7 | envName % 环境名称 8 | trainStats % 训练统计数据 9 | evalStats % 评估统计数据 10 | saveFigs % 是否保存图表 11 | end 12 | 13 | methods 14 | function obj = Logger(logDir, envName) 15 | % 构造函数:初始化日志记录器 16 | % logDir - 日志保存目录 17 | % envName - 环境名称 18 | 19 | % 设置日志目录和环境名称 20 | obj.logDir = logDir; 21 | obj.envName = envName; 22 | 23 | % 创建日志目录(如果不存在) 24 | if ~exist(logDir, 'dir') 25 | mkdir(logDir); 26 | fprintf('创建日志目录: %s\n', logDir); 27 | end 28 | 29 | % 初始化统计数据 30 | obj.trainStats = struct(); 31 | obj.trainStats.iterations = []; 32 | obj.trainStats.actorLoss = []; 33 | obj.trainStats.criticLoss = []; 34 | obj.trainStats.entropyLoss = []; 35 | obj.trainStats.totalLoss = []; 36 | 37 | obj.evalStats = struct(); 38 | obj.evalStats.iterations = []; 39 | obj.evalStats.returns = []; 40 | obj.evalStats.lengths = []; 41 | 42 | % 默认保存图表 43 | obj.saveFigs = true; 44 | end 45 | 46 | function logIteration(obj, iteration, metrics) 47 | % 记录训练迭代的指标 48 | % iteration - 当前迭代次数 49 | % metrics - 包含各种指标的结构体 50 | 51 | % 添加到训练统计数据 52 | obj.trainStats.iterations(end+1) = iteration; 53 | 
obj.trainStats.actorLoss(end+1) = metrics.actorLoss; 54 | obj.trainStats.criticLoss(end+1) = metrics.criticLoss; 55 | obj.trainStats.entropyLoss(end+1) = metrics.entropyLoss; 56 | obj.trainStats.totalLoss(end+1) = metrics.totalLoss; 57 | 58 | % 打印当前迭代的指标 59 | fprintf('迭代 %d: Actor损失 = %.4f, Critic损失 = %.4f, 熵损失 = %.4f, 总损失 = %.4f\n', ... 60 | iteration, metrics.actorLoss, metrics.criticLoss, metrics.entropyLoss, metrics.totalLoss); 61 | 62 | % 每10次迭代绘制并保存训练曲线 63 | if mod(iteration, 10) == 0 64 | obj.plotTrainingCurves(); 65 | end 66 | end 67 | 68 | function logEvaluation(obj, iteration, evalResult) 69 | % 记录评估结果 70 | % iteration - 当前迭代次数 71 | % evalResult - 评估结果 72 | 73 | % 添加到评估统计数据 74 | obj.evalStats.iterations(end+1) = iteration; 75 | obj.evalStats.returns(end+1) = evalResult.meanReturn; 76 | obj.evalStats.lengths(end+1) = evalResult.meanLength; 77 | 78 | % 打印评估结果 79 | fprintf('评估 (迭代 %d): 平均回报 = %.2f ± %.2f, 最小 = %.2f, 最大 = %.2f, 平均长度 = %.2f\n', ... 80 | iteration, evalResult.meanReturn, evalResult.stdReturn, ... 81 | evalResult.minReturn, evalResult.maxReturn, evalResult.meanLength); 82 | 83 | % 绘制并保存评估曲线 84 | obj.plotEvaluationCurves(); 85 | end 86 | 87 | function plotTrainingCurves(obj) 88 | % 绘制训练曲线 89 | 90 | % 创建图形 91 | figure('Name', ['训练曲线 - ', obj.envName], 'Position', [100, 100, 1200, 800]); 92 | 93 | % 绘制Actor损失 94 | subplot(2, 2, 1); 95 | plot(obj.trainStats.iterations, obj.trainStats.actorLoss, 'b-', 'LineWidth', 1.5); 96 | title('Actor损失'); 97 | xlabel('迭代次数'); 98 | ylabel('损失值'); 99 | grid on; 100 | 101 | % 绘制Critic损失 102 | subplot(2, 2, 2); 103 | plot(obj.trainStats.iterations, obj.trainStats.criticLoss, 'r-', 'LineWidth', 1.5); 104 | title('Critic损失'); 105 | xlabel('迭代次数'); 106 | ylabel('损失值'); 107 | grid on; 108 | 109 | % 绘制熵损失 110 | subplot(2, 2, 3); 111 | plot(obj.trainStats.iterations, obj.trainStats.entropyLoss, 'g-', 'LineWidth', 1.5); 112 | title('熵损失'); 113 | xlabel('迭代次数'); 114 | ylabel('损失值'); 115 | grid on; 116 | 117 | % 绘制总损失 118 | subplot(2, 2, 4); 119 | plot(obj.trainStats.iterations, obj.trainStats.totalLoss, 'm-', 'LineWidth', 1.5); 120 | title('总损失'); 121 | xlabel('迭代次数'); 122 | ylabel('损失值'); 123 | grid on; 124 | 125 | % 调整图形布局 126 | sgtitle(['训练曲线 - ', obj.envName], 'FontSize', 16); 127 | 128 | % 保存图形 129 | if obj.saveFigs 130 | saveas(gcf, fullfile(obj.logDir, 'training_curves.png')); 131 | end 132 | end 133 | 134 | function plotEvaluationCurves(obj) 135 | % 绘制评估曲线 136 | 137 | % 如果没有评估数据,直接返回 138 | if isempty(obj.evalStats.iterations) 139 | return; 140 | end 141 | 142 | % 创建图形 143 | figure('Name', ['评估曲线 - ', obj.envName], 'Position', [100, 100, 1000, 500]); 144 | 145 | % 绘制平均回报 146 | subplot(1, 2, 1); 147 | plot(obj.evalStats.iterations, obj.evalStats.returns, 'b-o', 'LineWidth', 1.5); 148 | title('评估平均回报'); 149 | xlabel('迭代次数'); 150 | ylabel('平均回报'); 151 | grid on; 152 | 153 | % 绘制平均长度 154 | subplot(1, 2, 2); 155 | plot(obj.evalStats.iterations, obj.evalStats.lengths, 'r-o', 'LineWidth', 1.5); 156 | title('评估平均长度'); 157 | xlabel('迭代次数'); 158 | ylabel('平均长度'); 159 | grid on; 160 | 161 | % 调整图形布局 162 | sgtitle(['评估曲线 - ', obj.envName], 'FontSize', 16); 163 | 164 | % 保存图形 165 | if obj.saveFigs 166 | saveas(gcf, fullfile(obj.logDir, 'evaluation_curves.png')); 167 | end 168 | end 169 | 170 | function saveTrainingData(obj) 171 | % 保存训练数据 172 | trainData = obj.trainStats; 173 | evalData = obj.evalStats; 174 | save(fullfile(obj.logDir, 'training_data.mat'), 'trainData', 'evalData'); 175 | fprintf('训练数据已保存到: %s\n', fullfile(obj.logDir, 
'training_data.mat')); 176 | end 177 | end 178 | end 179 | --------------------------------------------------------------------------------
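For reference, `Logger` can also be driven directly, outside of `PPOAgent` or `MAPPOAgent`. The snippet below is a minimal, hypothetical usage sketch based only on the methods defined in `utils/Logger.m` above; the log directory name and every loss/return value are placeholders, not output produced by this framework.

```matlab
% Minimal standalone usage sketch for utils/Logger.m (illustrative only).
% The loss and return values below are placeholders, not real training output.
addpath('../utils');

logger = Logger('../logs/demo', 'CartPoleEnv');   % (logDir, envName)

for iter = 1:20
    % In a real run these numbers come from the agent's policy update.
    metrics = struct('actorLoss',   0.10 / iter, ...
                     'criticLoss',  0.50 / iter, ...
                     'entropyLoss', 0.01, ...
                     'totalLoss',   0.60 / iter);
    logger.logIteration(iter, metrics);   % prints metrics; every 10 iterations it also plots curves

    if mod(iter, 5) == 0
        % Same fields that the evaluation results use in the scripts above.
        evalResult = struct('meanReturn', 100 + iter, 'stdReturn', 10, ...
                            'minReturn', 80, 'maxReturn', 130, 'meanLength', 150);
        logger.logEvaluation(iter, evalResult);
    end
end

logger.saveTrainingData();   % writes training_data.mat into the log directory
```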