├── LICENSE ├── README.md ├── requirements.txt └── src ├── .idea ├── .gitignore ├── MarsCodeWorkspaceAppSettings.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── src.iml ├── config └── default.yml ├── deep_learning ├── __pycache__ │ ├── models.cpython-312.pyc │ └── models.cpython-39.pyc ├── dl_estimator.py └── models.py ├── main.py ├── traditional ├── __pycache__ │ ├── estimators.cpython-312.pyc │ └── estimators.cpython-39.pyc ├── estimators.py └── ls_estimation.py └── utils ├── __pycache__ ├── config.cpython-312.pyc ├── config.cpython-39.pyc ├── data_generator.cpython-312.pyc ├── data_generator.cpython-39.pyc ├── preprocessing.cpython-312.pyc ├── preprocessing.cpython-39.pyc ├── trainer.cpython-312.pyc └── trainer.cpython-39.pyc ├── config.py ├── data_generator.py ├── preprocessing.py └── trainer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 修明 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 信道估计方法比较研究 2 | 3 | 本项目实现了一个完整的信道估计方法比较研究框架,包括传统方法和深度学习方法的实现、评估和性能对比。 4 | 5 | ## 项目概述 6 | 7 | 本项目旨在比较不同信道估计方法的性能,包括: 8 | 9 | ### 传统方法 10 | - LS(最小二乘)估计 11 | - LMMSE(线性最小均方误差)估计 12 | - ML(最大似然)估计 13 | 14 | ### 深度学习方法 15 | - CNN(卷积神经网络) 16 | - RNN(循环神经网络) 17 | - LSTM(长短期记忆网络) 18 | - GRU(门控循环单元) 19 | - Hybrid(混合CNN-LSTM模型) 20 | 21 | ## 项目结构 22 | 23 | ``` 24 | project/ 25 | ├── src/ # 源代码目录 26 | │ ├── main.py # 主程序入口 27 | │ ├── config/ # 配置文件目录 28 | │ │ └── default.yml # 默认配置文件 29 | │ ├── traditional/ # 传统方法实现 30 | │ │ └── estimators.py # 传统估计器实现 31 | │ ├── deep_learning/ # 深度学习方法实现 32 | │ │ └── models.py # 深度学习模型实现 33 | │ └── utils/ # 工具函数 34 | │ ├── config.py # 配置加载器 35 | │ ├── data_generator.py # 数据生成器 36 | │ ├── preprocessing.py # 数据预处理 37 | │ └── trainer.py # 模型训练器 38 | ├── results/ # 实验结果目录 39 | │ ├── logs/ # 日志文件 40 | │ ├── models/ # 保存的模型 41 | │ └── plots/ # 结果图表 42 | └── README.md # 项目说明文档 43 | ``` 44 | 45 | ## 功能特性 46 | 47 | 1. **数据处理** 48 | - 支持多种信道模型(Rayleigh、Rician等) 49 | - 灵活的数据生成和预处理 50 | - 数据增强支持 51 | - 自动数据集划分(训练/验证/测试) 52 | 53 | 2. **模型实现** 54 | - 传统方法的矩阵运算优化 55 | - 多种深度学习架构 56 | - GPU加速支持 57 | - 模型保存和加载 58 | 59 | 3. **训练与评估** 60 | - 自动化训练流程 61 | - 早停机制 62 | - 学习率自适应调整 63 | - 多指标性能评估 64 | 65 | 4. 
**可视化与分析** 66 | - 训练过程可视化 67 | - 性能对比图表 68 | - 详细的日志记录 69 | - 结果导出和保存 70 | 71 | ## 配置说明 72 | 73 | ### 实验配置 74 | ```yaml 75 | experiment: 76 | name: "channel_estimation" # 实验名称 77 | seed: 42 # 随机种子 78 | save_dir: "results/models" # 模型保存目录 79 | plot_dir: "results/plots" # 图表保存目录 80 | use_cuda: true # 是否使用GPU 81 | 82 | channel: 83 | n_tx: 4 # 发射天线数 84 | n_rx: 4 # 接收天线数 85 | n_pilot: 16 # 导频长度 86 | snr_db: 10 # 信噪比(dB) 87 | n_samples: 10000 # 样本数量 88 | type: "rayleigh" # 信道类型 89 | rician_k: 1.0 # Rician K因子(仅Rician信道) 90 | 91 | data: 92 | preprocessing: 93 | normalization: "standard" # 标准化方法 94 | remove_outliers: true # 是否移除异常值 95 | outlier_threshold: 3.0 # 异常值阈值 96 | 97 | augmentation: 98 | enabled: true # 是否启用数据增强 99 | methods: ["noise", "phase"] # 增强方法 100 | noise_std: 0.01 # 噪声标准差 101 | phase_shift_range: [-0.1, 0.1] # 相位偏移范围 102 | 103 | split: 104 | train_ratio: 0.7 # 训练集比例 105 | val_ratio: 0.15 # 验证集比例 106 | 107 | loader: 108 | batch_size: 128 # 批次大小 109 | num_workers: 4 # 数据加载线程数 110 | pin_memory: true # 是否固定内存 111 | 112 | models: 113 | common: 114 | learning_rate: 0.001 # 学习率 115 | 116 | cnn: 117 | channels: [64, 128, 256] # CNN通道数 118 | kernel_size: 3 # 卷积核大小 119 | 120 | rnn: 121 | hidden_size: 256 # 隐藏层大小 122 | num_layers: 2 # 层数 123 | 124 | lstm: 125 | hidden_size: 256 126 | num_layers: 2 127 | 128 | gru: 129 | hidden_size: 256 130 | num_layers: 2 131 | 132 | hybrid: 133 | rnn_hidden_size: 256 # RNN部分隐藏层大小 134 | cnn_channels: [64, 128] # CNN部分通道数 135 | 136 | training: 137 | epochs: 100 # 训练轮数 138 | early_stopping: 139 | patience: 10 # 早停耐心值 140 | ``` 141 | 142 | ### 传统方法参数 143 | ```yaml 144 | traditional: 145 | ls: 146 | regularization: true # 是否使用正则化 147 | lambda: 0.01 # 正则化参数 148 | 149 | lmmse: 150 | adaptive_snr: true # 是否自适应SNR 151 | correlation_method: "empirical" # 相关矩阵计算方法 152 | 153 | ml: 154 | max_iter: 100 # 最大迭代次数 155 | tol: 1e-6 # 收敛阈值 156 | learning_rate: 0.01 # 学习率 157 | ``` 158 | 159 | ## 使用说明 160 | 161 | 1. **环境配置** 162 | ```bash 163 | # 创建虚拟环境 164 | python -m venv venv 165 | source venv/bin/activate # Linux/Mac 166 | # 或 167 | .\venv\Scripts\activate # Windows 168 | 169 | # 安装依赖 170 | pip install -r requirements.txt 171 | ``` 172 | 173 | 2. **运行实验** 174 | ```bash 175 | # 使用默认配置 176 | python src/main.py 177 | 178 | # 指定配置文件 179 | python src/main.py --config path/to/config.yml 180 | 181 | # 指定设备 182 | python src/main.py --device cuda 183 | 184 | # 设置随机种子 185 | python src/main.py --seed 42 186 | ``` 187 | 188 | 3. **查看结果** 189 | - 实验日志:`results/logs/` 190 | - 训练好的模型:`results/models/` 191 | - 性能对比图表:`results/plots/` 192 | - 实验结果JSON:`results/results_*.json` 193 | 194 | ## 性能指标 195 | 196 | 项目使用以下指标评估估计器性能: 197 | 198 | 1. **MSE (均方误差)** 199 | - 衡量估计值与真实值的平均平方差 200 | - 越小越好 201 | 202 | 2. **NMSE (归一化均方误差)** 203 | - 考虑信道功率归一化后的MSE 204 | - 消除信道功率差异的影响 205 | 206 | 3. **BER (误比特率)** 207 | - 在给定信道估计下的通信系统误比特率 208 | - 实际通信性能的直接指标 209 | 210 | ## 实验结果 211 | 212 | 实验会生成以下可视化结果: 213 | 214 | 1. **整体性能对比** 215 | - 所有方法的MSE柱状图对比 216 | - 直观展示各方法的估计精度 217 | 218 | 2. **训练过程分析** 219 | - 每个深度学习模型的训练损失曲线 220 | - 学习率变化曲线 221 | - 帮助理解模型训练动态 222 | 223 | 3. **多维度性能对比** 224 | - 传统方法和深度学习方法的雷达图 225 | - 从多个指标综合评估性能 226 | 227 | 4. **误差分布分析** 228 | - 各方法MSE的箱线图 229 | - 展示估计误差的统计特性 230 | 231 | ## 注意事项 232 | 233 | 1. **硬件要求** 234 | - 建议使用GPU进行训练 235 | - 至少8GB内存 236 | - 存储空间:约1GB(取决于实验规模) 237 | 238 | 2. **数据处理** 239 | - 注意数据预处理的标准化方法选择 240 | - 合理设置异常值阈值 241 | - 根据实际需求调整数据增强参数 242 | 243 | 3. **训练优化** 244 | - 适当调整批次大小和学习率 245 | - 注意早停参数的设置 246 | - 监控训练过程避免过拟合 247 | 248 | 4. 
**结果分析** 249 | - 综合考虑多个性能指标 250 | - 注意不同信道条件下的表现 251 | - 考虑计算复杂度和实时性要求 252 | 253 | # 作者信息 254 | 255 | - 作者:修明 256 | - 邮箱:lzmpt@qq.com 257 | 258 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.21.0 2 | torch>=1.9.0 3 | matplotlib>=3.4.0 4 | pyyaml>=5.4.0 5 | tqdm>=4.61.0 6 | tensorboard>=2.6.0 7 | scikit-learn>=0.24.0 8 | pandas>=1.3.0 9 | seaborn>=0.11.0 10 | rich>=10.0.0 11 | pytest>=6.2.0 # 用于测试 12 | black>=21.5b2 # 用于代码格式化 13 | flake8>=3.9.0 # 用于代码检查 14 | mypy>=0.910 # 用于类型检查 -------------------------------------------------------------------------------- /src/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /src/.idea/MarsCodeWorkspaceAppSettings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /src/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 18 | -------------------------------------------------------------------------------- /src/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /src/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/.idea/src.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /src/config/default.yml: -------------------------------------------------------------------------------- 1 | # 信道估计实验配置文件 2 | 3 | # 实验基本设置 4 | experiment: 5 | # 实验名称,用于保存结果和日志 6 | name: 'channel_estimation' 7 | # 实验描述 8 | description: '信道估计方法比较研究' 9 | # 随机种子,用于复现实验结果 10 | seed: 42 11 | # 是否使用GPU 12 | use_cuda: true 13 | # 结果保存路径 14 | save_dir: 'results' 15 | # 结果图保存路径 16 | plot_dir: 'plots' 17 | # 是否启用详细日志 18 | verbose: true 19 | 20 | # 信道参数设置 21 | channel: 22 | # 发射天线数量 23 | n_tx: 4 24 | # 接收天线数量 25 | n_rx: 4 26 | # 导频符号长度(等于发射天线数) 27 | n_pilot: 16 28 | # 信噪比(dB) 29 | snr_db: 10 30 | # 采样数量 31 | n_samples: 10000 32 | # 信道类型:'rayleigh' 或 'rician' 33 | type: 'rayleigh' 34 | # Rician K因子(仅在type='rician'时有效) 35 | rician_k: 1.0 36 | 37 | # 数据处理设置 38 | data: 39 | # 数据预处理 40 | preprocessing: 41 | # 归一化方法:'z-score' 或 'min-max' 42 | normalization: 'z-score' 43 | # 是否移除异常值 44 | remove_outliers: true 45 | # 异常值阈值(标准差的倍数) 46 | outlier_threshold: 3.0 47 | 48 | # 数据增强 49 | augmentation: 50 | # 是否启用数据增强 51 | enabled: true 52 | # 增强方法列表 53 | methods: ['noise', 'phase_shift', 'magnitude_scale'] 54 | # 高斯噪声标准差 55 | noise_std: 0.1 56 | # 
相位偏移范围(弧度) 57 | phase_shift_range: [-0.1, 0.1] 58 | # 幅度缩放范围 59 | magnitude_scale_range: [0.9, 1.1] 60 | 61 | # 数据集划分 62 | split: 63 | # 训练集比例 64 | train_ratio: 0.7 65 | # 验证集比例 66 | val_ratio: 0.15 67 | # 测试集比例(自动计算为1-train_ratio-val_ratio) 68 | test_ratio: 0.15 69 | 70 | # 数据加载 71 | loader: 72 | # 批次大小 73 | batch_size: 128 74 | # 是否打乱训练数据 75 | shuffle: true 76 | # 数据加载线程数 77 | num_workers: 4 78 | # 是否将数据固定在内存中 79 | pin_memory: true 80 | 81 | # 传统估计器设置 82 | traditional: 83 | # LS估计器参数 84 | ls: 85 | # 是否使用正则化 86 | regularization: true 87 | # 正则化参数 88 | lambda: 0.01 89 | 90 | # LMMSE估计器参数 91 | lmmse: 92 | # 是否使用自适应SNR估计 93 | adaptive_snr: true 94 | # 信道相关矩阵估计方法:'sample' 或 'theoretical' 95 | correlation_method: 'sample' 96 | 97 | # ML估计器参数 98 | ml: 99 | # 最大迭代次数 100 | max_iter: 100 101 | # 收敛阈值 102 | tol: 1e-6 103 | # 学习率 104 | learning_rate: 0.01 105 | 106 | # 深度学习模型设置 107 | models: 108 | # 通用设置 109 | common: 110 | # 学习率 111 | learning_rate: 0.001 112 | # 权重衰减 113 | weight_decay: 0.0001 114 | # Dropout率 115 | dropout: 0.3 116 | # 批归一化动量 117 | bn_momentum: 0.1 118 | 119 | # RNN模型参数 120 | rnn: 121 | # 隐藏层大小 122 | hidden_size: 256 123 | # 层数 124 | num_layers: 2 125 | # RNN类型:'vanilla', 'lstm', 'gru' 126 | rnn_type: 'vanilla' 127 | 128 | # LSTM模型参数 129 | lstm: 130 | # 隐藏层大小 131 | hidden_size: 256 132 | # 层数 133 | num_layers: 2 134 | # 是否使用双向LSTM 135 | bidirectional: true 136 | 137 | # GRU模型参数 138 | gru: 139 | # 隐藏层大小 140 | hidden_size: 256 141 | # 层数 142 | num_layers: 2 143 | # 是否使用双向GRU 144 | bidirectional: true 145 | 146 | # 混合模型参数 147 | hybrid: 148 | # RNN部分隐藏层大小 149 | rnn_hidden_size: 256 150 | # 全连接层大小 151 | fc_sizes: [512, 256] 152 | 153 | # 训练设置 154 | training: 155 | # 训练轮数 156 | epochs: 100 157 | # 早停设置 158 | early_stopping: 159 | # 是否启用早停 160 | enabled: true 161 | # 早停耐心值 162 | patience: 10 163 | # 最小改善阈值 164 | min_delta: 1e-4 165 | 166 | # 学习率调度器 167 | scheduler: 168 | # 调度器类型:'reduce_on_plateau', 'cosine', 'step' 169 | type: 'reduce_on_plateau' 170 | # 降低学习率的因子 171 | factor: 0.1 172 | # 降低学习率前的耐心值 173 | patience: 5 174 | # 最小学习率 175 | min_lr: 1e-6 176 | 177 | # 模型保存 178 | checkpointing: 179 | # 是否启用检查点保存 180 | enabled: true 181 | # 保存频率(每N个epoch) 182 | save_freq: 10 183 | # 是否只保存最佳模型 184 | save_best_only: false 185 | # 最大保存检查点数量 186 | max_to_keep: 5 187 | 188 | # 评估设置 189 | evaluation: 190 | # 评估指标列表 191 | metrics: ['mse', 'nmse', 'ber'] 192 | # 是否计算置信区间 193 | confidence_interval: true 194 | # 置信水平 195 | confidence_level: 0.95 196 | # 是否保存预测结果 197 | save_predictions: true 198 | 199 | # TensorBoard设置 200 | tensorboard: 201 | # 是否启用TensorBoard 202 | enabled: true 203 | # 更新频率(每N个batch) 204 | update_freq: 10 205 | # 是否记录梯度直方图 206 | histogram_freq: 1 207 | # 是否记录模型图 208 | write_graph: true 209 | # 是否记录配置文件 210 | write_config: true -------------------------------------------------------------------------------- /src/deep_learning/__pycache__/models.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/deep_learning/__pycache__/models.cpython-312.pyc -------------------------------------------------------------------------------- /src/deep_learning/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/deep_learning/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /src/deep_learning/dl_estimator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | 6 | class DLChannelEstimator(nn.Module): 7 | """基于深度学习的信道估计器""" 8 | 9 | def __init__(self, input_size, hidden_size=128): 10 | super(DLChannelEstimator, self).__init__() 11 | 12 | self.network = nn.Sequential( 13 | nn.Linear(input_size, hidden_size), 14 | nn.ReLU(), 15 | nn.BatchNorm1d(hidden_size), 16 | nn.Dropout(0.3), 17 | 18 | nn.Linear(hidden_size, hidden_size), 19 | nn.ReLU(), 20 | nn.BatchNorm1d(hidden_size), 21 | nn.Dropout(0.3), 22 | 23 | nn.Linear(hidden_size, input_size) 24 | ) 25 | 26 | self.criterion = nn.MSELoss() 27 | self.optimizer = None 28 | 29 | def forward(self, x): 30 | """前向传播""" 31 | return self.network(x) 32 | 33 | def train_model(self, train_loader, epochs=100, learning_rate=0.001): 34 | """ 35 | 训练模型 36 | 37 | 参数: 38 | train_loader: 训练数据加载器 39 | epochs: 训练轮数 40 | learning_rate: 学习率 41 | """ 42 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) 43 | 44 | for epoch in range(epochs): 45 | total_loss = 0 46 | for batch_X, batch_Y in train_loader: 47 | self.optimizer.zero_grad() 48 | 49 | outputs = self(batch_X) 50 | loss = self.criterion(outputs, batch_Y) 51 | 52 | loss.backward() 53 | self.optimizer.step() 54 | 55 | total_loss += loss.item() 56 | 57 | if (epoch + 1) % 10 == 0: 58 | print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}') 59 | 60 | def estimate(self, X): 61 | """ 62 | 使用训练好的模型进行信道估计 63 | 64 | 参数: 65 | X: 输入信号 66 | 67 | 返回: 68 | 估计的信道响应 69 | """ 70 | self.eval() 71 | with torch.no_grad(): 72 | return self(X) 73 | 74 | def get_mse(self, y_true, y_pred): 75 | """计算均方误差""" 76 | return self.criterion(y_pred, y_true).item() -------------------------------------------------------------------------------- /src/deep_learning/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class CNNEstimator(nn.Module): 6 | """基于CNN的信道估计器""" 7 | def __init__(self, input_size, n_rx=4, n_tx=8): 8 | super(CNNEstimator, self).__init__() 9 | 10 | self.n_features = input_size 11 | self.n_rx = n_rx 12 | self.n_tx = n_tx 13 | 14 | self.cnn = nn.Sequential( 15 | # 将输入重塑为 (batch_size, 1, -1),作为1D CNN的输入 16 | nn.Conv1d(1, 64, kernel_size=3, padding=1), 17 | nn.ReLU(), 18 | nn.BatchNorm1d(64), 19 | 20 | nn.Conv1d(64, 128, kernel_size=3, padding=1), 21 | nn.ReLU(), 22 | nn.BatchNorm1d(128), 23 | 24 | nn.Conv1d(128, 64, kernel_size=3, padding=1), 25 | nn.ReLU(), 26 | nn.BatchNorm1d(64), 27 | 28 | nn.Conv1d(64, 1, kernel_size=3, padding=1) 29 | ) 30 | 31 | # 全连接层 32 | self.fc = nn.Sequential( 33 | nn.Linear(self.n_features, 512), 34 | nn.ReLU(), 35 | nn.Dropout(0.3), 36 | nn.Linear(512, n_rx * n_tx) 37 | ) 38 | 39 | def forward(self, x): 40 | batch_size = x.size(0) 41 | # 重塑为1D CNN输入 42 | x = x.view(batch_size, 1, -1) 43 | x = self.cnn(x) 44 | x = x.view(batch_size, -1) 45 | x = self.fc(x) 46 | # 重塑输出为 [batch_size, n_rx, n_tx] 47 | return x.view(batch_size, self.n_rx, self.n_tx) 48 | 49 | class RNNEstimator(nn.Module): 50 | """基于RNN的信道估计器""" 51 | def __init__(self, 
input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8): 52 | super(RNNEstimator, self).__init__() 53 | 54 | self.n_features = input_size 55 | self.seq_length = 8 # 将输入序列分成8个时间步 56 | self.feature_size = self.n_features // self.seq_length 57 | self.n_rx = n_rx 58 | self.n_tx = n_tx 59 | 60 | self.rnn = nn.RNN(self.feature_size, hidden_size, num_layers, batch_first=True) 61 | self.fc = nn.Sequential( 62 | nn.Linear(hidden_size, 512), 63 | nn.ReLU(), 64 | nn.Dropout(0.3), 65 | nn.Linear(512, n_rx * n_tx) 66 | ) 67 | 68 | def forward(self, x): 69 | batch_size = x.size(0) 70 | # 重塑为序列数据 71 | x = x.view(batch_size, self.seq_length, -1) 72 | out, _ = self.rnn(x) 73 | x = self.fc(out[:, -1, :]) 74 | # 重塑输出为 [batch_size, n_rx, n_tx] 75 | return x.view(batch_size, self.n_rx, self.n_tx) 76 | 77 | class LSTMEstimator(nn.Module): 78 | """基于LSTM的信道估计器""" 79 | def __init__(self, input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8): 80 | super(LSTMEstimator, self).__init__() 81 | 82 | self.n_features = input_size 83 | self.seq_length = 8 84 | self.feature_size = self.n_features // self.seq_length 85 | self.n_rx = n_rx 86 | self.n_tx = n_tx 87 | 88 | self.lstm = nn.LSTM(self.feature_size, hidden_size, num_layers, batch_first=True) 89 | self.fc = nn.Sequential( 90 | nn.Linear(hidden_size, 512), 91 | nn.ReLU(), 92 | nn.Dropout(0.3), 93 | nn.Linear(512, n_rx * n_tx) 94 | ) 95 | 96 | def forward(self, x): 97 | batch_size = x.size(0) 98 | x = x.view(batch_size, self.seq_length, -1) 99 | out, (_, _) = self.lstm(x) 100 | x = self.fc(out[:, -1, :]) 101 | # 重塑输出为 [batch_size, n_rx, n_tx] 102 | return x.view(batch_size, self.n_rx, self.n_tx) 103 | 104 | class GRUEstimator(nn.Module): 105 | """基于GRU的信道估计器""" 106 | def __init__(self, input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8): 107 | super(GRUEstimator, self).__init__() 108 | 109 | self.n_features = input_size 110 | self.seq_length = 8 111 | self.feature_size = self.n_features // self.seq_length 112 | self.n_rx = n_rx 113 | self.n_tx = n_tx 114 | 115 | self.gru = nn.GRU(self.feature_size, hidden_size, num_layers, batch_first=True) 116 | self.fc = nn.Sequential( 117 | nn.Linear(hidden_size, 512), 118 | nn.ReLU(), 119 | nn.Dropout(0.3), 120 | nn.Linear(512, n_rx * n_tx) 121 | ) 122 | 123 | def forward(self, x): 124 | batch_size = x.size(0) 125 | x = x.view(batch_size, self.seq_length, -1) 126 | out, _ = self.gru(x) 127 | x = self.fc(out[:, -1, :]) 128 | # 重塑输出为 [batch_size, n_rx, n_tx] 129 | return x.view(batch_size, self.n_rx, self.n_tx) 130 | 131 | class HybridEstimator(nn.Module): 132 | """混合深度学习估计器(CNN+LSTM)""" 133 | def __init__(self, input_size, hidden_size=256, n_rx=4, n_tx=8): 134 | super(HybridEstimator, self).__init__() 135 | 136 | self.n_features = input_size 137 | self.seq_length = 8 138 | self.feature_size = self.n_features // self.seq_length 139 | self.n_rx = n_rx 140 | self.n_tx = n_tx 141 | 142 | # 1D CNN特征提取 143 | self.cnn = nn.Sequential( 144 | nn.Conv1d(1, 64, kernel_size=3, padding=1), 145 | nn.ReLU(), 146 | nn.BatchNorm1d(64), 147 | nn.Conv1d(64, 32, kernel_size=3, padding=1), 148 | nn.ReLU(), 149 | nn.BatchNorm1d(32) 150 | ) 151 | 152 | # LSTM处理时序特征 153 | self.lstm = nn.LSTM(32 * self.feature_size, hidden_size, batch_first=True) 154 | 155 | # 全连接层 156 | self.fc = nn.Sequential( 157 | nn.Linear(hidden_size, 512), 158 | nn.ReLU(), 159 | nn.Dropout(0.3), 160 | nn.Linear(512, n_rx * n_tx) 161 | ) 162 | 163 | def forward(self, x): 164 | batch_size = x.size(0) 165 | 166 | # CNN特征提取 167 | x = x.view(batch_size, 1, -1) 168 | cnn_out 
= self.cnn(x) 169 | 170 | # 重塑以适应LSTM 171 | lstm_in = cnn_out.view(batch_size, self.seq_length, -1) 172 | 173 | # LSTM处理 174 | lstm_out, _ = self.lstm(lstm_in) 175 | 176 | # 全连接层 177 | x = self.fc(lstm_out[:, -1, :]) 178 | # 重塑输出为 [batch_size, n_rx, n_tx] 179 | return x.view(batch_size, self.n_rx, self.n_tx) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | 信道估计方法比较研究 3 | 4 | 本模块是项目的主入口,实现了完整的信道估计实验流程,包括: 5 | 1. 数据生成与预处理 6 | 2. 传统方法评估(LS、LMMSE、ML) 7 | 3. 深度学习方法评估(CNN、RNN、LSTM、GRU、Hybrid) 8 | 4. 结果可视化与保存 9 | 10 | 主要功能: 11 | - 配置管理:支持通过YAML文件配置实验参数 12 | - 数据处理:生成、预处理、增强和划分数据集 13 | - 模型训练:支持多种深度学习模型的训练和评估 14 | - 结果分析:生成多种可视化图表进行性能对比 15 | - 日志记录:详细记录实验过程和结果 16 | 17 | 作者: lzm lzmpt@qq.com 18 | 日期: 2025-03-07 19 | """ 20 | 21 | import numpy as np 22 | import torch 23 | from torch.utils.data import TensorDataset, DataLoader 24 | import matplotlib.pyplot as plt 25 | from datetime import datetime 26 | import os 27 | from pathlib import Path 28 | import time 29 | import argparse 30 | import random 31 | from typing import Dict 32 | import yaml 33 | import json 34 | from typing import Any 35 | import warnings 36 | import sys 37 | import logging 38 | from contextlib import redirect_stdout 39 | 40 | # 过滤掉特定的警告 41 | warnings.filterwarnings('ignore', category=UserWarning, module='torch.optim.lr_scheduler') 42 | 43 | from traditional.estimators import LSEstimator, LMMSEEstimator, MLEstimator, calculate_performance_metrics 44 | from deep_learning.models import CNNEstimator, RNNEstimator, LSTMEstimator, GRUEstimator, HybridEstimator 45 | from utils.data_generator import ChannelDataGenerator 46 | from utils.preprocessing import ChannelPreprocessor, augment_data 47 | from utils.trainer import ChannelEstimatorTrainer 48 | from utils.config import Config 49 | 50 | class TeeLogger: 51 | """ 52 | 同时将输出写入到控制台和文件的日志记录器 53 | 54 | 该类实现了一个双向输出流,可以: 55 | 1. 将所有输出同时发送到终端和日志文件 56 | 2. 实时刷新输出,确保日志及时记录 57 | 3. 支持标准输出流的所有基本操作 58 | 59 | 参数: 60 | filename (str): 日志文件的路径 61 | """ 62 | def __init__(self, filename): 63 | self.terminal = sys.stdout 64 | self.log_file = open(filename, 'w', encoding='utf-8') 65 | 66 | def write(self, message): 67 | """写入消息到终端和文件""" 68 | self.terminal.write(message) 69 | self.log_file.write(message) 70 | self.log_file.flush() 71 | 72 | def flush(self): 73 | """刷新输出缓冲区""" 74 | self.terminal.flush() 75 | self.log_file.flush() 76 | 77 | def setup_logging(save_dir: Path, timestamp: str): 78 | """ 79 | 设置日志记录系统 80 | 81 | 创建日志目录并初始化日志记录器,支持: 82 | 1. 创建时间戳命名的日志文件 83 | 2. 同时输出到控制台和文件 84 | 3. 自动创建所需目录结构 85 | 86 | 参数: 87 | save_dir (Path): 保存目录的路径 88 | timestamp (str): 时间戳字符串 89 | 90 | 返回: 91 | TeeLogger: 配置好的日志记录器 92 | """ 93 | log_dir = save_dir / 'logs' 94 | log_dir.mkdir(parents=True, exist_ok=True) 95 | log_file = log_dir / f'experiment_{timestamp}.log' 96 | return TeeLogger(str(log_file)) 97 | 98 | def set_seed(seed): 99 | """ 100 | 设置随机种子以确保实验可重复性 101 | 102 | 统一设置所有相关库的随机种子,包括: 103 | 1. Python random模块 104 | 2. NumPy 105 | 3. PyTorch CPU 106 | 4. PyTorch GPU(如果可用) 107 | 5. 
CUDA后端(如果可用) 108 | 109 | 参数: 110 | seed (int): 随机种子值 111 | """ 112 | random.seed(seed) 113 | np.random.seed(seed) 114 | torch.manual_seed(seed) 115 | if torch.cuda.is_available(): 116 | torch.cuda.manual_seed(seed) 117 | torch.cuda.manual_seed_all(seed) 118 | torch.backends.cudnn.deterministic = True 119 | torch.backends.cudnn.benchmark = False 120 | 121 | def print_section_header(title): 122 | """ 123 | 打印带有时间戳的分节标题 124 | 125 | 创建醒目的分节标题,包括: 126 | 1. 分隔线 127 | 2. 当前时间戳 128 | 3. 节标题 129 | 130 | 参数: 131 | title (str): 节标题文本 132 | """ 133 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 134 | print(f"\n{'='*80}\n{timestamp} - {title}\n{'='*80}") 135 | 136 | def print_progress(message, indent=0): 137 | """ 138 | 打印带有时间戳的进度信息 139 | 140 | 格式化输出进度信息,包括: 141 | 1. 时间戳 142 | 2. 缩进级别 143 | 3. 具体消息 144 | 145 | 参数: 146 | message (str): 要打印的消息 147 | indent (int): 缩进级别(每级2个空格) 148 | """ 149 | timestamp = datetime.now().strftime("%H:%M:%S") 150 | indent_str = " " * indent 151 | print(f"[{timestamp}] {indent_str}{message}") 152 | 153 | def evaluate_traditional_methods(H: np.ndarray, X: np.ndarray, Y: np.ndarray, snr: float, cfg: dict) -> Dict[str, Dict[str, float]]: 154 | """ 155 | 评估传统信道估计方法的性能 156 | 157 | 实现了三种传统估计方法的评估: 158 | 1. LS(最小二乘)估计 159 | 2. LMMSE(线性最小均方误差)估计 160 | 3. ML(最大似然)估计 161 | 162 | 对每种方法: 163 | - 初始化估计器 164 | - 执行信道估计 165 | - 计算性能指标 166 | - 记录执行时间 167 | - 输出详细结果 168 | 169 | 参数: 170 | H (np.ndarray): 真实信道矩阵,shape=(n_samples, n_rx, n_tx) 171 | X (np.ndarray): 导频符号,shape=(n_samples, n_tx, n_pilot) 172 | Y (np.ndarray): 接收信号,shape=(n_samples, n_rx, n_pilot) 173 | snr (float): 信噪比(线性,非dB) 174 | cfg (dict): 配置字典,包含各估计器的参数 175 | 176 | 返回: 177 | Dict[str, Dict[str, float]]: 包含各估计器性能指标的嵌套字典 178 | """ 179 | results = {} 180 | 181 | # LS估计 182 | print_progress("开始 LS 估计...", 1) 183 | start_time = time.time() 184 | ls_estimator = LSEstimator( 185 | regularization=cfg['traditional']['ls']['regularization'], 186 | lambda_=cfg['traditional']['ls']['lambda'] 187 | ) 188 | H_est_ls = ls_estimator.estimate(Y, X) 189 | results['ls'] = calculate_performance_metrics(H, H_est_ls) 190 | print_progress(f"LS 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1) 191 | print_progress("LS 性能指标:", 2) 192 | for metric, value in results['ls'].items(): 193 | print_progress(f"{metric}: {value:.6f}", 3) 194 | 195 | # LMMSE估计 196 | print_progress("\n开始 LMMSE 估计...", 1) 197 | start_time = time.time() 198 | lmmse_estimator = LMMSEEstimator( 199 | snr=snr, 200 | adaptive_snr=cfg['traditional']['lmmse']['adaptive_snr'], 201 | correlation_method=cfg['traditional']['lmmse']['correlation_method'] 202 | ) 203 | H_est_lmmse = lmmse_estimator.estimate(Y, X) 204 | results['lmmse'] = calculate_performance_metrics(H, H_est_lmmse) 205 | print_progress(f"LMMSE 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1) 206 | print_progress("LMMSE 性能指标:", 2) 207 | for metric, value in results['lmmse'].items(): 208 | print_progress(f"{metric}: {value:.6f}", 3) 209 | 210 | # ML估计 211 | print_progress("\n开始 ML 估计...", 1) 212 | start_time = time.time() 213 | ml_estimator = MLEstimator( 214 | max_iter=cfg['traditional']['ml']['max_iter'], 215 | tol=cfg['traditional']['ml']['tol'], 216 | learning_rate=cfg['traditional']['ml']['learning_rate'] 217 | ) 218 | H_est_ml = ml_estimator.estimate(Y, X) 219 | results['ml'] = calculate_performance_metrics(H, H_est_ml) 220 | print_progress(f"ML 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1) 221 | print_progress("ML 性能指标:", 2) 222 | for metric, value in results['ml'].items(): 223 | print_progress(f"{metric}: 
{value:.6f}", 3) 224 | 225 | return results 226 | 227 | def evaluate_dl_methods(train_loader, val_loader, test_loader, input_size, cfg): 228 | """ 229 | 评估深度学习方法的性能 230 | 231 | 实现了五种深度学习模型的训练和评估: 232 | 1. CNN(卷积神经网络) 233 | 2. RNN(循环神经网络) 234 | 3. LSTM(长短期记忆网络) 235 | 4. GRU(门控循环单元) 236 | 5. Hybrid(混合CNN-LSTM模型) 237 | 238 | 对每个模型: 239 | - 初始化模型结构 240 | - 配置训练器 241 | - 执行训练过程 242 | - 加载最佳模型 243 | - 在测试集上评估 244 | - 记录性能指标 245 | 246 | 参数: 247 | train_loader (DataLoader): 训练数据加载器 248 | val_loader (DataLoader): 验证数据加载器 249 | test_loader (DataLoader): 测试数据加载器 250 | input_size (int): 输入特征维度 251 | cfg (dict): 配置字典,包含模型和训练参数 252 | 253 | 返回: 254 | dict: 包含各模型训练结果和性能指标的字典 255 | """ 256 | device = torch.device('cuda' if cfg['experiment']['use_cuda'] and torch.cuda.is_available() else 'cpu') 257 | results = {} 258 | 259 | # 获取通用模型设置 260 | common_cfg = cfg['models']['common'] 261 | 262 | # 定义要评估的模型 263 | models = { 264 | 'CNN': CNNEstimator(input_size=input_size), 265 | 'RNN': RNNEstimator( 266 | input_size=input_size, 267 | hidden_size=cfg['models']['rnn']['hidden_size'], 268 | num_layers=cfg['models']['rnn']['num_layers'] 269 | ), 270 | 'LSTM': LSTMEstimator( 271 | input_size=input_size, 272 | hidden_size=cfg['models']['lstm']['hidden_size'], 273 | num_layers=cfg['models']['lstm']['num_layers'] 274 | ), 275 | 'GRU': GRUEstimator( 276 | input_size=input_size, 277 | hidden_size=cfg['models']['gru']['hidden_size'], 278 | num_layers=cfg['models']['gru']['num_layers'] 279 | ), 280 | 'Hybrid': HybridEstimator( 281 | input_size=input_size, 282 | hidden_size=cfg['models']['hybrid']['rnn_hidden_size'] 283 | ) 284 | } 285 | 286 | print_section_header("评估深度学习方法") 287 | 288 | # 评估每个模型 289 | for name, model in models.items(): 290 | print_progress(f"\n开始训练 {name} 模型...", 1) 291 | print_progress(f"模型结构:", 2) 292 | print_progress(str(model), 3) 293 | 294 | # 创建训练器 295 | trainer = ChannelEstimatorTrainer( 296 | model=model, 297 | train_loader=train_loader, 298 | val_loader=val_loader, 299 | save_dir=cfg['experiment']['save_dir'], 300 | project_name=f"{cfg['experiment']['name']}_{name.lower()}", 301 | device=device, 302 | learning_rate=common_cfg['learning_rate'] 303 | ) 304 | 305 | # 训练模型 306 | start_time = time.time() 307 | trainer.train( 308 | epochs=cfg['training']['epochs'], 309 | early_stopping_patience=cfg['training']['early_stopping']['patience'] 310 | ) 311 | training_time = time.time() - start_time 312 | print_progress(f"{name} 模型训练完成 (耗时: {training_time:.2f}s)", 1) 313 | 314 | # 加载最佳模型进行测试 315 | trainer.load_model('best.pt') 316 | model = trainer.model 317 | 318 | # 测试评估 319 | print_progress(f"开始评估 {name} 模型...", 1) 320 | model.eval() 321 | all_preds = [] 322 | all_true = [] 323 | 324 | test_start_time = time.time() 325 | with torch.no_grad(): 326 | for batch_X, batch_Y in test_loader: 327 | batch_X, batch_Y = batch_X.to(device), batch_Y.to(device) 328 | outputs = model(batch_X) 329 | all_preds.append(outputs.cpu().numpy()) 330 | all_true.append(batch_Y.cpu().numpy()) 331 | 332 | all_preds = np.concatenate(all_preds, axis=0) 333 | all_true = np.concatenate(all_true, axis=0) 334 | 335 | metrics = calculate_performance_metrics(all_true, all_preds) 336 | results[name] = { 337 | 'metrics': metrics, 338 | 'model': model, 339 | 'trainer': trainer 340 | } 341 | 342 | print_progress(f"{name} 模型评估完成 (耗时: {time.time()-test_start_time:.2f}s)", 1) 343 | print_progress(f"{name} 模型性能指标:", 2) 344 | for metric_name, value in metrics.items(): 345 | print_progress(f"{metric_name}: {value:.6f}", 3) 346 | 347 | return results 348 
| 349 | def plot_results(traditional_results, dl_results, cfg): 350 | """ 351 | 绘制实验结果比较图表 352 | 353 | 生成四种类型的可视化图表: 354 | 1. 整体性能对比图:所有方法的MSE柱状图 355 | 2. 训练历史曲线:每个深度学习模型的训练过程 356 | 3. 性能指标雷达图:传统方法和深度学习方法的多指标对比 357 | 4. 误差分布箱线图:所有方法的MSE分布对比 358 | 359 | 参数: 360 | traditional_results (dict): 传统方法的评估结果 361 | dl_results (dict): 深度学习方法的评估结果 362 | cfg (dict): 配置字典,包含绘图相关参数 363 | """ 364 | print_section_header("绘制结果比较图") 365 | 366 | # 创建保存目录 367 | save_dir = Path(cfg['experiment']['plot_dir']) 368 | save_dir.mkdir(parents=True, exist_ok=True) 369 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 370 | 371 | print_progress("生成性能对比图...", 1) 372 | 373 | # 1. 整体MSE对比图 374 | plt.figure(figsize=(12, 6)) 375 | 376 | # 绘制传统方法结果 377 | x_traditional = np.arange(len(traditional_results)) 378 | mse_traditional = [results['mse'] for results in traditional_results.values()] 379 | plt.bar(x_traditional, mse_traditional, alpha=0.8, label='Traditional Methods') 380 | plt.xticks(x_traditional, traditional_results.keys()) 381 | 382 | # 绘制深度学习方法结果 383 | x_dl = np.arange(len(dl_results)) + len(traditional_results) 384 | mse_dl = [results['metrics']['mse'] for results in dl_results.values()] 385 | plt.bar(x_dl, mse_dl, alpha=0.8, label='Deep Learning Methods') 386 | plt.xticks(np.concatenate([x_traditional, x_dl]), 387 | list(traditional_results.keys()) + list(dl_results.keys()), 388 | rotation=45) 389 | 390 | plt.ylabel('MSE') 391 | plt.title('所有方法性能对比') 392 | plt.legend() 393 | plt.tight_layout() 394 | 395 | save_path = save_dir / f'overall_comparison_{timestamp}.png' 396 | plt.savefig(save_path) 397 | plt.close() 398 | 399 | print_progress(f"整体对比图已保存至: {save_path}", 1) 400 | 401 | # 2. 每个深度学习模型的训练历史 402 | print_progress("生成训练历史图...", 1) 403 | for name, result in dl_results.items(): 404 | trainer = result['trainer'] 405 | history = trainer.get_history() 406 | 407 | plt.figure(figsize=(15, 5)) 408 | 409 | # 训练损失和验证损失 410 | plt.subplot(1, 2, 1) 411 | plt.plot(history['train_loss'], label='训练损失') 412 | plt.plot(history['val_loss'], label='验证损失') 413 | plt.xlabel('Epoch') 414 | plt.ylabel('Loss') 415 | plt.title(f'{name} 模型训练历史') 416 | plt.legend() 417 | plt.grid(True) 418 | 419 | # 学习率变化 420 | plt.subplot(1, 2, 2) 421 | plt.plot(history['learning_rate'], label='学习率') 422 | plt.xlabel('Epoch') 423 | plt.ylabel('Learning Rate') 424 | plt.title('学习率变化') 425 | plt.legend() 426 | plt.grid(True) 427 | 428 | plt.tight_layout() 429 | save_path = save_dir / f'{name.lower()}_training_history_{timestamp}.png' 430 | plt.savefig(save_path) 431 | plt.close() 432 | 433 | print_progress(f"{name} 模型训练历史图已保存至: {save_path}", 2) 434 | 435 | # 3. 
性能指标雷达图 436 | print_progress("生成性能指标雷达图...", 1) 437 | metrics = ['mse', 'nmse', 'ber'] # 根据实际指标调整 438 | 439 | # 传统方法雷达图 440 | plt.figure(figsize=(10, 10)) 441 | angles = np.linspace(0, 2*np.pi, len(metrics), endpoint=False) 442 | angles = np.concatenate((angles, [angles[0]])) # 闭合图形 443 | 444 | ax = plt.subplot(111, polar=True) 445 | for method, results in traditional_results.items(): 446 | values = [results[metric] for metric in metrics] 447 | values = np.concatenate((values, [values[0]])) # 闭合图形 448 | ax.plot(angles, values, 'o-', linewidth=2, label=method) 449 | ax.fill(angles, values, alpha=0.25) 450 | 451 | ax.set_xticks(angles[:-1]) 452 | ax.set_xticklabels(metrics) 453 | plt.title('传统方法性能指标对比') 454 | plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1)) 455 | 456 | save_path = save_dir / f'traditional_metrics_radar_{timestamp}.png' 457 | plt.savefig(save_path) 458 | plt.close() 459 | 460 | print_progress(f"传统方法性能指标雷达图已保存至: {save_path}", 2) 461 | 462 | # 深度学习方法雷达图 463 | plt.figure(figsize=(10, 10)) 464 | ax = plt.subplot(111, polar=True) 465 | for name, result in dl_results.items(): 466 | values = [result['metrics'][metric] for metric in metrics] 467 | values = np.concatenate((values, [values[0]])) # 闭合图形 468 | ax.plot(angles, values, 'o-', linewidth=2, label=name) 469 | ax.fill(angles, values, alpha=0.25) 470 | 471 | ax.set_xticks(angles[:-1]) 472 | ax.set_xticklabels(metrics) 473 | plt.title('深度学习方法性能指标对比') 474 | plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1)) 475 | 476 | save_path = save_dir / f'dl_metrics_radar_{timestamp}.png' 477 | plt.savefig(save_path) 478 | plt.close() 479 | 480 | print_progress(f"深度学习方法性能指标雷达图已保存至: {save_path}", 2) 481 | 482 | # 4. 箱线图比较 483 | print_progress("生成性能指标箱线图...", 1) 484 | plt.figure(figsize=(15, 6)) 485 | 486 | # 合并所有方法的结果 487 | all_methods = {} 488 | all_methods.update({k: {'mse': [v['mse']]} for k, v in traditional_results.items()}) 489 | all_methods.update({k: {'mse': [v['metrics']['mse']]} for k, v in dl_results.items()}) 490 | 491 | # 创建箱线图数据 492 | labels = [] 493 | mse_data = [] 494 | for method, data in all_methods.items(): 495 | labels.append(method) 496 | mse_data.append(data['mse']) 497 | 498 | plt.boxplot(mse_data, labels=labels) 499 | plt.xticks(rotation=45) 500 | plt.ylabel('MSE') 501 | plt.title('所有方法MSE分布对比') 502 | plt.grid(True) 503 | 504 | plt.tight_layout() 505 | save_path = save_dir / f'mse_boxplot_{timestamp}.png' 506 | plt.savefig(save_path) 507 | plt.close() 508 | 509 | print_progress(f"性能指标箱线图已保存至: {save_path}", 2) 510 | 511 | print_progress("所有结果图表已生成完成", 1) 512 | 513 | def parse_args(): 514 | """ 515 | 解析命令行参数 516 | 517 | 支持的参数: 518 | 1. --config: 配置文件路径 519 | 2. --device: 运行设备(cpu/cuda) 520 | 3. --seed: 随机种子 521 | 522 | 返回: 523 | argparse.Namespace: 解析后的参数对象 524 | """ 525 | parser = argparse.ArgumentParser(description='信道估计方法比较研究') 526 | parser.add_argument('--config', type=str, default='src/config/default.yml', 527 | help='配置文件路径') 528 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], 529 | help='运行设备') 530 | parser.add_argument('--seed', type=int, help='随机种子') 531 | return parser.parse_args() 532 | 533 | def load_config(config_path: str) -> dict: 534 | """ 535 | 加载配置文件 536 | 537 | 功能: 538 | 1. 读取YAML配置文件 539 | 2. 验证必要配置项 540 | 3. 
检查配置完整性 541 | 542 | 参数: 543 | config_path (str): 配置文件路径 544 | 545 | 返回: 546 | dict: 配置字典 547 | 548 | 异常: 549 | ValueError: 当缺少必要的配置项时抛出 550 | """ 551 | with open(config_path, 'r', encoding='utf-8') as f: 552 | config = yaml.safe_load(f) 553 | 554 | # 验证必要的配置项 555 | required_configs = [ 556 | 'experiment.name', 557 | 'channel.n_tx', 558 | 'channel.n_rx', 559 | 'channel.n_pilot', 560 | 'channel.snr_db', 561 | 'channel.n_samples', 562 | 'channel.type', 563 | 'traditional.ls.regularization', 564 | 'traditional.ls.lambda', 565 | 'traditional.lmmse.adaptive_snr', 566 | 'traditional.lmmse.correlation_method', 567 | 'traditional.ml.max_iter', 568 | 'traditional.ml.tol' 569 | ] 570 | 571 | for config_path in required_configs: 572 | value = get_config_value(config, config_path) 573 | if value is None: 574 | raise ValueError(f"配置文件缺少必要项: {config_path}") 575 | 576 | return config 577 | 578 | def get_config_value(config: dict, path: str) -> Any: 579 | """ 580 | 从嵌套字典中获取值 581 | 582 | 使用点号分隔的路径从嵌套字典中获取值 583 | 例如:'models.cnn.channels' -> config['models']['cnn']['channels'] 584 | 585 | 参数: 586 | config (dict): 配置字典 587 | path (str): 以点分隔的配置路径 588 | 589 | 返回: 590 | Any: 配置值,如果路径不存在则返回None 591 | """ 592 | keys = path.split('.') 593 | value = config 594 | for key in keys: 595 | if not isinstance(value, dict) or key not in value: 596 | return None 597 | value = value[key] 598 | return value 599 | 600 | def main(): 601 | """主函数""" 602 | # 解析命令行参数 603 | parser = argparse.ArgumentParser(description='信道估计实验') 604 | parser.add_argument('--config', type=str, default='src/config/default.yml', 605 | help='配置文件路径') 606 | args = parser.parse_args() 607 | 608 | # 加载配置 609 | cfg = load_config(args.config) 610 | 611 | # 创建保存目录 612 | save_dir = Path(cfg['experiment']['save_dir']) 613 | plot_dir = Path(cfg['experiment']['plot_dir']) 614 | save_dir.mkdir(parents=True, exist_ok=True) 615 | plot_dir.mkdir(parents=True, exist_ok=True) 616 | 617 | # 设置时间戳 618 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 619 | 620 | # 设置日志记录 621 | logger = setup_logging(save_dir, timestamp) 622 | sys.stdout = logger 623 | 624 | # 设置随机种子 625 | seed = cfg['experiment']['seed'] 626 | random.seed(seed) 627 | np.random.seed(seed) 628 | torch.manual_seed(seed) 629 | if torch.cuda.is_available() and cfg['experiment']['use_cuda']: 630 | torch.cuda.manual_seed(seed) 631 | 632 | try: 633 | print_section_header("开始信道估计实验") 634 | 635 | # 打印实验参数 636 | print_progress("实验参数配置:", 1) 637 | print_progress(f"发射天线数: {cfg['channel']['n_tx']}", 2) 638 | print_progress(f"接收天线数: {cfg['channel']['n_rx']}", 2) 639 | print_progress(f"导频长度: {cfg['channel']['n_pilot']}", 2) 640 | print_progress(f"信噪比: {cfg['channel']['snr_db']} dB", 2) 641 | print_progress(f"样本数量: {cfg['channel']['n_samples']}", 2) 642 | print_progress(f"批次大小: {cfg['data']['loader']['batch_size']}", 2) 643 | print_progress(f"运行设备: {'cuda' if torch.cuda.is_available() and cfg['experiment']['use_cuda'] else 'cpu'}", 2) 644 | 645 | # 生成数据 646 | print_progress("\n初始化数据生成器...", 1) 647 | start_time = time.time() 648 | data_generator = ChannelDataGenerator( 649 | n_tx=cfg['channel']['n_tx'], 650 | n_rx=cfg['channel']['n_rx'], 651 | n_pilot=cfg['channel']['n_pilot'], 652 | channel_type=cfg['channel']['type'], 653 | rician_k=cfg['channel'].get('rician_k', 1.0) 654 | ) 655 | 656 | print_progress("开始生成数据...", 1) 657 | H, X, Y = data_generator.generate( 658 | n_samples=cfg['channel']['n_samples'], 659 | snr_db=cfg['channel']['snr_db'] 660 | ) 661 | 662 | print_progress(f"信道矩阵 H 形状: {H.shape}", 2) 663 | 
print_progress(f"导频符号 X 形状: {X.shape}", 2) 664 | print_progress(f"接收信号 Y 形状: {Y.shape}", 2) 665 | print_progress(f"数据生成完成 (耗时: {time.time()-start_time:.2f}s)", 1) 666 | 667 | # 数据预处理 668 | print_progress("\n开始数据预处理...", 1) 669 | start_time = time.time() 670 | preprocessor = ChannelPreprocessor( 671 | normalization=cfg['data']['preprocessing']['normalization'], 672 | remove_outliers=cfg['data']['preprocessing']['remove_outliers'], 673 | outlier_threshold=cfg['data']['preprocessing']['outlier_threshold'] 674 | ) 675 | 676 | H = preprocessor.fit_transform(H) 677 | Y = preprocessor.transform(Y) 678 | 679 | print_progress(f"数据预处理完成 (耗时: {time.time()-start_time:.2f}s)", 1) 680 | 681 | # 数据增强 682 | if cfg['data']['augmentation']['enabled']: 683 | print_progress("\n开始数据增强...", 1) 684 | start_time = time.time() 685 | 686 | H = augment_data( 687 | H, 688 | methods=cfg['data']['augmentation']['methods'], 689 | noise_std=cfg['data']['augmentation']['noise_std'], 690 | phase_shift_range=cfg['data']['augmentation']['phase_shift_range'], 691 | magnitude_scale_range=cfg['data']['augmentation']['magnitude_scale_range'] 692 | ) 693 | 694 | print_progress("增强后数据形状:", 2) 695 | print_progress(f"H: {H.shape}", 3) 696 | print_progress(f"X: {X.shape}", 3) 697 | print_progress(f"Y: {Y.shape}", 3) 698 | print_progress(f"数据增强完成 (耗时: {time.time()-start_time:.2f}s)", 1) 699 | 700 | # 评估传统方法 701 | print_section_header("评估传统估计方法") 702 | traditional_results = evaluate_traditional_methods( 703 | H=H, 704 | X=X, 705 | Y=Y, 706 | snr=10**(cfg['channel']['snr_db']/10), 707 | cfg=cfg 708 | ) 709 | 710 | # 数据集划分 711 | print_section_header("准备深度学习数据集") 712 | n_samples = H.shape[0] 713 | train_ratio = cfg['data']['split']['train_ratio'] 714 | val_ratio = cfg['data']['split']['val_ratio'] 715 | 716 | # 计算样本数量 717 | n_train = int(n_samples * train_ratio) 718 | n_val = int(n_samples * val_ratio) 719 | n_test = n_samples - n_train - n_val 720 | 721 | print_progress("数据集划分:", 1) 722 | print_progress(f"训练集: {n_train} 样本", 2) 723 | print_progress(f"验证集: {n_val} 样本", 2) 724 | print_progress(f"测试集: {n_test} 样本", 2) 725 | 726 | # 转换为PyTorch张量 727 | # 将复数数据分离为实部和虚部 728 | X_real = torch.from_numpy(X.real).float() 729 | X_imag = torch.from_numpy(X.imag).float() 730 | H_real = torch.from_numpy(H.real).float() 731 | H_imag = torch.from_numpy(H.imag).float() 732 | 733 | # 在最后一维拼接实部和虚部 734 | X_tensor = torch.cat([X_real, X_imag], dim=-1) 735 | H_tensor = torch.cat([H_real, H_imag], dim=-1) 736 | 737 | # 创建数据集 738 | dataset = TensorDataset(X_tensor, H_tensor) 739 | 740 | # 划分数据集 741 | train_dataset, val_dataset, test_dataset = torch.utils.data.random_split( 742 | dataset, [n_train, n_val, n_test] 743 | ) 744 | 745 | # 创建数据加载器 746 | train_loader = DataLoader( 747 | train_dataset, 748 | batch_size=cfg['data']['loader']['batch_size'], 749 | shuffle=True, 750 | num_workers=cfg['data']['loader']['num_workers'], 751 | pin_memory=cfg['data']['loader']['pin_memory'] 752 | ) 753 | 754 | val_loader = DataLoader( 755 | val_dataset, 756 | batch_size=cfg['data']['loader']['batch_size'], 757 | shuffle=False, 758 | num_workers=cfg['data']['loader']['num_workers'], 759 | pin_memory=cfg['data']['loader']['pin_memory'] 760 | ) 761 | 762 | test_loader = DataLoader( 763 | test_dataset, 764 | batch_size=cfg['data']['loader']['batch_size'], 765 | shuffle=False, 766 | num_workers=cfg['data']['loader']['num_workers'], 767 | pin_memory=cfg['data']['loader']['pin_memory'] 768 | ) 769 | 770 | # 计算输入大小(考虑实部和虚部) 771 | input_size = X.shape[1] * X.shape[2] * 2 # n_tx * n_pilot 
* 2 (实部和虚部) 772 | 773 | # 训练和评估深度学习模型 774 | dl_results = evaluate_dl_methods( 775 | train_loader=train_loader, 776 | val_loader=val_loader, 777 | test_loader=test_loader, 778 | input_size=input_size, 779 | cfg=cfg 780 | ) 781 | 782 | # 保存结果 783 | results = { 784 | 'config': cfg, 785 | 'traditional_results': traditional_results, 786 | 'dl_results': { 787 | name: { 788 | 'metrics': result['metrics'] 789 | } for name, result in dl_results.items() 790 | } 791 | } 792 | 793 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 794 | result_file = save_dir / f"results_{timestamp}.json" 795 | with open(result_file, 'w', encoding='utf-8') as f: 796 | json.dump(results, f, indent=4) 797 | 798 | # 绘制结果 799 | plot_results(traditional_results, dl_results, cfg) 800 | 801 | print("\n实验完成!结果已保存至:", result_file) 802 | 803 | except Exception as e: 804 | print(f"\n实验出错!错误信息:{str(e)}") 805 | raise e 806 | 807 | finally: 808 | # 恢复标准输出 809 | sys.stdout = sys.__stdout__ 810 | # 关闭日志文件 811 | logger.log_file.close() 812 | 813 | if __name__ == "__main__": 814 | main() -------------------------------------------------------------------------------- /src/traditional/__pycache__/estimators.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/traditional/__pycache__/estimators.cpython-312.pyc -------------------------------------------------------------------------------- /src/traditional/__pycache__/estimators.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/traditional/__pycache__/estimators.cpython-39.pyc -------------------------------------------------------------------------------- /src/traditional/estimators.py: -------------------------------------------------------------------------------- 1 | """ 2 | 传统信道估计方法模块 3 | 包含LS、LMMSE和ML估计器 4 | """ 5 | 6 | import numpy as np 7 | from scipy.special import erf 8 | from typing import Dict, Union, Optional 9 | 10 | class LSEstimator: 11 | """最小二乘(LS)估计器""" 12 | 13 | def __init__(self, regularization: bool = False, lambda_: float = 0.01): 14 | """ 15 | 初始化LS估计器 16 | 17 | 参数: 18 | regularization: 是否使用正则化 19 | lambda_: 正则化参数 20 | """ 21 | self.regularization = regularization 22 | self.lambda_ = lambda_ 23 | 24 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray: 25 | """ 26 | 执行LS估计 27 | 28 | 参数: 29 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot) 30 | X: 导频符号,shape=(n_samples, n_tx, n_pilot) 31 | 32 | 返回: 33 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx) 34 | """ 35 | # 转置X以便进行批量矩阵运算 36 | X_H = np.conjugate(np.transpose(X, (0, 2, 1))) # (n_samples, n_pilot, n_tx) 37 | 38 | # 计算XX^H 39 | XX_H = np.matmul(X, X_H) # (n_samples, n_tx, n_tx) 40 | 41 | if self.regularization: 42 | # 添加正则化项 43 | n_tx = X.shape[1] 44 | reg_term = self.lambda_ * np.eye(n_tx)[np.newaxis, :, :] 45 | XX_H = XX_H + reg_term 46 | 47 | # 使用SVD进行稳定的矩阵求逆 48 | H_est_list = [] 49 | for i in range(Y.shape[0]): 50 | try: 51 | # 尝试直接求逆 52 | XX_H_inv = np.linalg.inv(XX_H[i]) 53 | except np.linalg.LinAlgError: 54 | # 如果矩阵接近奇异,使用伪逆 55 | XX_H_inv = np.linalg.pinv(XX_H[i]) 56 | 57 | # 计算信道估计 58 | H_est_i = np.matmul(Y[i], np.matmul(X_H[i], XX_H_inv)) 59 | H_est_list.append(H_est_i) 60 | 61 | H_est = np.stack(H_est_list, axis=0) 62 | return H_est 
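# ---------------------------------------------------------------------------
# 用法示意(编辑者补充的假设性示例,_demo_ls_estimate 为虚构名称,并非项目
# 原有接口):在无噪声模型 Y = H X 下,LS 解 H_ls = Y X^H (X X^H)^(-1)
# 可以精确恢复 H;存在噪声时估计误差随 SNR 提高而下降。下面的小函数演示
# LSEstimator 的调用方式。
def _demo_ls_estimate(n_samples: int = 4, n_rx: int = 4, n_tx: int = 4,
                      n_pilot: int = 16) -> float:
    """生成随机信道与导频,返回LS估计的MSE(无噪声情形下应接近0)。"""
    rng = np.random.default_rng(0)
    H = (rng.standard_normal((n_samples, n_rx, n_tx))
         + 1j * rng.standard_normal((n_samples, n_rx, n_tx))) / np.sqrt(2)
    X = (rng.standard_normal((n_samples, n_tx, n_pilot))
         + 1j * rng.standard_normal((n_samples, n_tx, n_pilot))) / np.sqrt(2)
    Y = np.matmul(H, X)  # 接收信号(此处省略加性噪声)
    H_est = LSEstimator(regularization=False).estimate(Y, X)
    return float(np.mean(np.abs(H - H_est) ** 2))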
63 | 64 | class LMMSEEstimator: 65 | """线性最小均方误差(LMMSE)估计器""" 66 | 67 | def __init__(self, snr: float, adaptive_snr: bool = False, 68 | correlation_method: str = 'sample'): 69 | """ 70 | 初始化LMMSE估计器 71 | 72 | 参数: 73 | snr: 信噪比(线性,非dB) 74 | adaptive_snr: 是否使用自适应SNR估计 75 | correlation_method: 信道相关矩阵估计方法,'sample' 或 'theoretical' 76 | """ 77 | self.snr = snr 78 | self.adaptive_snr = adaptive_snr 79 | self.correlation_method = correlation_method 80 | 81 | if correlation_method not in ['sample', 'theoretical']: 82 | raise ValueError("correlation_method必须是'sample'或'theoretical'") 83 | 84 | def _estimate_correlation(self, H_ls: np.ndarray) -> np.ndarray: 85 | """ 86 | 估计信道相关矩阵 87 | 88 | 参数: 89 | H_ls: LS估计的信道矩阵 90 | 91 | 返回: 92 | R_H: 信道相关矩阵 93 | """ 94 | if self.correlation_method == 'sample': 95 | # 使用样本相关矩阵 96 | H_ls_flat = H_ls.reshape(H_ls.shape[0], -1) 97 | R_H = np.matmul(H_ls_flat.T.conj(), H_ls_flat) / H_ls.shape[0] 98 | # 添加小的正则化项以提高数值稳定性 99 | epsilon = 1e-10 100 | R_H = R_H + epsilon * np.eye(R_H.shape[0]) 101 | else: 102 | # 使用理论相关矩阵(假设指数衰减) 103 | n_rx, n_tx = H_ls.shape[1:3] 104 | R_H = np.zeros((n_rx * n_tx, n_rx * n_tx), dtype=complex) 105 | rho = 0.7 # 相关系数 106 | for i in range(n_rx * n_tx): 107 | for j in range(n_rx * n_tx): 108 | R_H[i, j] = rho ** abs(i - j) 109 | 110 | return R_H 111 | 112 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray: 113 | """ 114 | 执行LMMSE估计 115 | 116 | 参数: 117 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot) 118 | X: 导频符号,shape=(n_samples, n_tx, n_pilot) 119 | 120 | 返回: 121 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx) 122 | """ 123 | # 首先进行LS估计 124 | ls_estimator = LSEstimator(regularization=True, lambda_=0.01) # 使用正则化的LS估计 125 | H_ls = ls_estimator.estimate(Y, X) 126 | 127 | # 估计信道相关矩阵 128 | R_H = self._estimate_correlation(H_ls) 129 | 130 | if self.adaptive_snr: 131 | # 使用样本方差估计噪声功率 132 | error = Y - np.matmul(H_ls, X) 133 | noise_power = np.mean(np.abs(error) ** 2) 134 | signal_power = np.mean(np.abs(Y) ** 2) 135 | snr = max(signal_power / noise_power, 1.0) # 确保SNR不小于1 136 | else: 137 | snr = self.snr 138 | 139 | # LMMSE估计 140 | n_samples = Y.shape[0] 141 | H_est = np.zeros_like(H_ls) 142 | 143 | for i in range(n_samples): 144 | h_ls = H_ls[i].flatten() 145 | # 使用更稳定的求解方法 146 | try: 147 | # 使用Cholesky分解求解线性方程组 148 | A = R_H + np.eye(len(h_ls)) / snr 149 | L = np.linalg.cholesky(A) 150 | h_est = h_ls.copy() 151 | # 解线性方程组 A @ x = R_H @ h_ls 152 | y = np.linalg.solve(L, R_H @ h_ls) 153 | h_est = np.linalg.solve(L.T.conj(), y) 154 | except np.linalg.LinAlgError: 155 | # 如果Cholesky分解失败,使用伪逆 156 | h_est = np.matmul(R_H, np.linalg.pinv(R_H + np.eye(len(h_ls)) / snr)) @ h_ls 157 | 158 | H_est[i] = h_est.reshape(H_ls.shape[1:]) 159 | 160 | return H_est 161 | 162 | class MLEstimator: 163 | """最大似然(ML)估计器""" 164 | 165 | def __init__(self, max_iter: int = 100, tol: float = 1e-6, learning_rate: float = 0.01): 166 | """ 167 | 初始化ML估计器 168 | 169 | 参数: 170 | max_iter: 最大迭代次数 171 | tol: 收敛阈值 172 | learning_rate: 初始学习率 173 | """ 174 | self.max_iter = max_iter 175 | self.tol = float(tol) 176 | self.learning_rate = learning_rate 177 | self.min_lr = 1e-6 # 最小学习率 178 | self.lr_decay = 0.95 # 学习率衰减因子 179 | 180 | def _calculate_loss(self, H_est: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float: 181 | """计算损失函数(负对数似然)""" 182 | error = Y - np.matmul(H_est, X) 183 | return np.mean(np.abs(error) ** 2) 184 | 185 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray: 186 | """ 187 | 执行ML估计 188 | 189 | 参数: 190 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot) 191 
| X: 导频符号,shape=(n_samples, n_tx, n_pilot) 192 | 193 | 返回: 194 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx) 195 | """ 196 | # 使用正则化的LS估计作为初始值 197 | ls_estimator = LSEstimator(regularization=True, lambda_=0.01) 198 | H_est = ls_estimator.estimate(Y, X) 199 | 200 | # 迭代优化 201 | X_H = np.conjugate(np.transpose(X, (0, 2, 1))) # (n_samples, n_pilot, n_tx) 202 | lr = self.learning_rate 203 | best_H = H_est.copy() 204 | best_loss = self._calculate_loss(H_est, X, Y) 205 | 206 | for iter_idx in range(self.max_iter): 207 | H_prev = H_est.copy() 208 | prev_loss = self._calculate_loss(H_prev, X, Y) 209 | 210 | # 计算误差 211 | error = Y - np.matmul(H_est, X) # (n_samples, n_rx, n_pilot) 212 | 213 | # 计算梯度 214 | gradient = -np.matmul(error, X_H) # 负梯度方向 215 | 216 | # 梯度缩放 217 | grad_norm = np.maximum(np.sqrt(np.mean(np.abs(gradient) ** 2)), 1e-8) 218 | gradient = gradient / grad_norm 219 | 220 | # 更新估计 221 | H_est = H_est - lr * gradient 222 | 223 | # 计算当前损失 224 | current_loss = self._calculate_loss(H_est, X, Y) 225 | 226 | # 如果损失增加,回退并降低学习率 227 | if current_loss > prev_loss: 228 | H_est = H_prev 229 | lr = max(lr * self.lr_decay, self.min_lr) 230 | continue 231 | 232 | # 更新最佳结果 233 | if current_loss < best_loss: 234 | best_H = H_est.copy() 235 | best_loss = current_loss 236 | 237 | # 检查收敛 238 | diff = float(np.mean(np.abs(H_est - H_prev) ** 2)) 239 | if diff < self.tol: 240 | break 241 | 242 | # 定期降低学习率 243 | if iter_idx > 0 and iter_idx % 10 == 0: 244 | lr = max(lr * self.lr_decay, self.min_lr) 245 | 246 | return best_H 247 | 248 | def calculate_performance_metrics(H_true: np.ndarray, H_est: np.ndarray) -> Dict[str, float]: 249 | """ 250 | 计算性能指标 251 | 252 | 参数: 253 | H_true: 真实信道矩阵 254 | H_est: 估计的信道矩阵 255 | 256 | 返回: 257 | 包含各种性能指标的字典 258 | """ 259 | # 计算MSE 260 | mse = np.mean(np.abs(H_true - H_est) ** 2) 261 | 262 | # 计算NMSE 263 | nmse = mse / np.mean(np.abs(H_true) ** 2) 264 | 265 | # 计算BER(假设QPSK调制) 266 | def qpsk_ber(h_true, h_est): 267 | # 添加小的常数以避免除零 268 | epsilon = 1e-10 269 | snr_eff = np.abs(h_true) ** 2 / (np.abs(h_true - h_est) ** 2 + epsilon) 270 | return 0.5 * np.mean(1 - erf(np.sqrt(snr_eff/2))) 271 | 272 | ber = qpsk_ber(H_true, H_est) 273 | 274 | return { 275 | 'mse': float(mse), 276 | 'nmse': float(nmse), 277 | 'ber': float(ber) 278 | } -------------------------------------------------------------------------------- /src/traditional/ls_estimation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | class LSChannelEstimator: 5 | """最小二乘(LS)信道估计器""" 6 | 7 | def __init__(self): 8 | self.H_est = None # 估计的信道矩阵 9 | 10 | def estimate(self, Y, X): 11 | """ 12 | 使用最小二乘方法估计信道 13 | 14 | 参数: 15 | Y: 接收信号矩阵 16 | X: 发送信号矩阵(导频) 17 | 18 | 返回: 19 | H_est: 估计的信道矩阵 20 | """ 21 | # 使用最小二乘方法估计信道 22 | # H = Y * X^H * (X * X^H)^(-1) 23 | X_H = np.conjugate(X).T 24 | self.H_est = Y @ X_H @ linalg.inv(X @ X_H) 25 | return self.H_est 26 | 27 | def get_mse(self, H_true): 28 | """ 29 | 计算均方误差 30 | 31 | 参数: 32 | H_true: 真实信道矩阵 33 | 34 | 返回: 35 | mse: 均方误差 36 | """ 37 | if self.H_est is None: 38 | raise ValueError("请先进行信道估计") 39 | 40 | mse = np.mean(np.abs(H_true - self.H_est) ** 2) 41 | return mse -------------------------------------------------------------------------------- /src/utils/__pycache__/config.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/config.cpython-312.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/data_generator.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/data_generator.cpython-312.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/data_generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/data_generator.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/preprocessing.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/preprocessing.cpython-312.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/preprocessing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/preprocessing.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/trainer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/trainer.cpython-312.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/trainer.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置加载器模块 3 | 用于加载和管理YAML配置文件 4 | """ 5 | 6 | import os 7 | import yaml 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional 10 | 11 | class Config: 12 | """配置加载器类""" 13 | 14 | def __init__(self, config_path: str = "config/default.yml"): 15 | """ 16 | 初始化配置加载器 17 | 18 | 参数: 19 | config_path: 配置文件路径,默认为'config/default.yml' 20 | """ 21 | self.config_path = Path(config_path) 22 | if not 
self.config_path.exists(): 23 | raise FileNotFoundError(f"配置文件未找到: {config_path}") 24 | 25 | with open(self.config_path, 'r', encoding='utf-8') as f: 26 | self.config = yaml.safe_load(f) 27 | 28 | def get(self, key: str, default: Any = None) -> Any: 29 | """ 30 | 获取配置值 31 | 32 | 参数: 33 | key: 配置键,使用点号分隔层级,如'models.cnn.channels' 34 | default: 默认值,当键不存在时返回 35 | 36 | 返回: 37 | 配置值 38 | """ 39 | keys = key.split('.') 40 | value = self.config 41 | 42 | for k in keys: 43 | if isinstance(value, dict): 44 | value = value.get(k) 45 | if value is None: 46 | return default 47 | else: 48 | return default 49 | 50 | return value 51 | 52 | def set(self, key: str, value: Any) -> None: 53 | """ 54 | 设置配置值 55 | 56 | 参数: 57 | key: 配置键,使用点号分隔层级 58 | value: 要设置的值 59 | """ 60 | keys = key.split('.') 61 | config = self.config 62 | 63 | for k in keys[:-1]: 64 | if k not in config: 65 | config[k] = {} 66 | config = config[k] 67 | 68 | config[keys[-1]] = value 69 | 70 | def save(self, save_path: Optional[str] = None) -> None: 71 | """ 72 | 保存配置到文件 73 | 74 | 参数: 75 | save_path: 保存路径,默认为原配置文件路径 76 | """ 77 | save_path = Path(save_path) if save_path else self.config_path 78 | save_path.parent.mkdir(parents=True, exist_ok=True) 79 | 80 | with open(save_path, 'w', encoding='utf-8') as f: 81 | yaml.safe_dump(self.config, f, allow_unicode=True, sort_keys=False) 82 | 83 | def update_from_args(self, args: Dict[str, Any]) -> None: 84 | """ 85 | 从命令行参数更新配置 86 | 87 | 参数: 88 | args: 命令行参数字典 89 | """ 90 | for key, value in args.items(): 91 | if value is not None: # 只更新非None的值 92 | self.set(key, value) 93 | 94 | def __getitem__(self, key: str) -> Any: 95 | """使配置对象可以使用字典方式访问""" 96 | return self.get(key) 97 | 98 | def __setitem__(self, key: str, value: Any) -> None: 99 | """使配置对象可以使用字典方式设置值""" 100 | self.set(key, value) 101 | 102 | def __str__(self) -> str: 103 | """返回配置的字符串表示""" 104 | return yaml.safe_dump(self.config, allow_unicode=True, sort_keys=False) 105 | 106 | @property 107 | def experiment(self) -> Dict: 108 | """获取实验配置""" 109 | return self.config.get('experiment', {}) 110 | 111 | @property 112 | def channel(self) -> Dict: 113 | """获取信道配置""" 114 | return self.config.get('channel', {}) 115 | 116 | @property 117 | def data(self) -> Dict: 118 | """获取数据处理配置""" 119 | return self.config.get('data', {}) 120 | 121 | @property 122 | def models(self) -> Dict: 123 | """获取模型配置""" 124 | return self.config.get('models', {}) 125 | 126 | @property 127 | def training(self) -> Dict: 128 | """获取训练配置""" 129 | return self.config.get('training', {}) 130 | 131 | @property 132 | def evaluation(self) -> Dict: 133 | """获取评估配置""" 134 | return self.config.get('evaluation', {}) 135 | 136 | @property 137 | def tensorboard(self) -> Dict: 138 | """获取TensorBoard配置""" 139 | return self.config.get('tensorboard', {}) -------------------------------------------------------------------------------- /src/utils/data_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | 信道数据生成器模块 3 | 用于生成MIMO信道数据、导频符号和接收信号 4 | """ 5 | 6 | import numpy as np 7 | from typing import Tuple 8 | 9 | class ChannelDataGenerator: 10 | """信道数据生成器类""" 11 | 12 | def __init__(self, n_tx: int, n_rx: int, n_pilot: int, 13 | channel_type: str = 'rayleigh', rician_k: float = 1.0): 14 | """ 15 | 初始化信道数据生成器 16 | 17 | 参数: 18 | n_tx: 发射天线数量 19 | n_rx: 接收天线数量 20 | n_pilot: 导频符号长度 21 | channel_type: 信道类型,'rayleigh' 或 'rician' 22 | rician_k: Rician K因子,仅在channel_type='rician'时有效 23 | """ 24 | self.n_tx = n_tx 25 | self.n_rx = n_rx 26 | self.n_pilot = 
n_pilot 27 | self.channel_type = channel_type.lower() 28 | self.rician_k = rician_k 29 | 30 | if self.channel_type not in ['rayleigh', 'rician']: 31 | raise ValueError("信道类型必须是 'rayleigh' 或 'rician'") 32 | 33 | if self.n_pilot < self.n_tx: 34 | raise ValueError("导频长度必须大于等于发射天线数量") 35 | 36 | def generate_channel(self, n_samples: int) -> np.ndarray: 37 | """ 38 | 生成MIMO信道矩阵 39 | 40 | 参数: 41 | n_samples: 样本数量 42 | 43 | 返回: 44 | shape=(n_samples, n_rx, n_tx)的信道矩阵 45 | """ 46 | # 生成随机复高斯分量 47 | h_random = (np.random.normal(0, 1/np.sqrt(2), (n_samples, self.n_rx, self.n_tx)) + 48 | 1j * np.random.normal(0, 1/np.sqrt(2), (n_samples, self.n_rx, self.n_tx))) 49 | 50 | if self.channel_type == 'rayleigh': 51 | return h_random 52 | else: # Rician信道 53 | # 生成确定性分量(LOS分量) 54 | h_los = np.ones((n_samples, self.n_rx, self.n_tx)) / np.sqrt(self.n_tx * self.n_rx) 55 | 56 | # 组合LOS分量和随机分量 57 | k = self.rician_k 58 | h_rician = np.sqrt(k/(k+1)) * h_los + np.sqrt(1/(k+1)) * h_random 59 | 60 | return h_rician 61 | 62 | def generate_pilot_symbols(self, n_samples: int) -> np.ndarray: 63 | """ 64 | 生成导频符号 65 | 66 | 参数: 67 | n_samples: 样本数量 68 | 69 | 返回: 70 | shape=(n_samples, n_tx, n_pilot)的导频符号矩阵 71 | """ 72 | # 使用QPSK调制生成导频 73 | pilot_real = np.random.choice([-1, 1], size=(n_samples, self.n_tx, self.n_pilot)) 74 | pilot_imag = np.random.choice([-1, 1], size=(n_samples, self.n_tx, self.n_pilot)) 75 | pilot = (pilot_real + 1j * pilot_imag) / np.sqrt(2) 76 | 77 | # 归一化功率 78 | pilot = pilot / np.sqrt(self.n_tx) 79 | 80 | return pilot 81 | 82 | def generate_received_signal(self, H: np.ndarray, X: np.ndarray, snr_db: float) -> np.ndarray: 83 | """ 84 | 生成接收信号 85 | 86 | 参数: 87 | H: shape=(n_samples, n_rx, n_tx)的信道矩阵 88 | X: shape=(n_samples, n_tx, n_pilot)的发送信号 89 | snr_db: 信噪比(dB) 90 | 91 | 返回: 92 | shape=(n_samples, n_rx, n_pilot)的接收信号 93 | """ 94 | # 计算信噪比 95 | snr = 10 ** (snr_db / 10) 96 | 97 | # 计算信号功率 98 | signal_power = np.mean(np.abs(H @ X) ** 2) 99 | noise_power = signal_power / snr 100 | 101 | # 生成噪声 102 | noise = (np.random.normal(0, np.sqrt(noise_power/2), X.shape[0:1] + (self.n_rx, X.shape[-1])) + 103 | 1j * np.random.normal(0, np.sqrt(noise_power/2), X.shape[0:1] + (self.n_rx, X.shape[-1]))) 104 | 105 | # 生成接收信号 106 | Y = np.matmul(H, X) + noise 107 | 108 | return Y 109 | 110 | def generate(self, n_samples: int, snr_db: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 111 | """ 112 | 生成完整的训练数据集 113 | 114 | 参数: 115 | n_samples: 样本数量 116 | snr_db: 信噪比(dB) 117 | 118 | 返回: 119 | (H, X, Y)元组,分别是信道矩阵、导频符号和接收信号 120 | """ 121 | # 生成信道矩阵 122 | H = self.generate_channel(n_samples) 123 | 124 | # 生成导频符号 125 | X = self.generate_pilot_symbols(n_samples) 126 | 127 | # 生成接收信号 128 | Y = self.generate_received_signal(H, X, snr_db) 129 | 130 | return H, X, Y -------------------------------------------------------------------------------- /src/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | 信道数据预处理模块 3 | 用于数据归一化和异常值处理 4 | """ 5 | 6 | import numpy as np 7 | from typing import Optional, Union, List, Tuple 8 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 9 | 10 | class ChannelPreprocessor: 11 | """信道数据预处理器""" 12 | 13 | def __init__(self, normalization: str = 'z-score', 14 | remove_outliers: bool = False, 15 | outlier_threshold: float = 3.0): 16 | """ 17 | 初始化预处理器 18 | 19 | 参数: 20 | normalization: 归一化方法,'z-score' 或 'min-max' 21 | remove_outliers: 是否移除异常值 22 | outlier_threshold: 异常值阈值(标准差的倍数) 23 | """ 24 | self.normalization = 
normalization.lower() 25 | self.remove_outliers = remove_outliers 26 | self.outlier_threshold = outlier_threshold 27 | 28 | if self.normalization not in ['z-score', 'min-max']: 29 | raise ValueError("归一化方法必须是 'z-score' 或 'min-max'") 30 | 31 | # 存储归一化参数 32 | self.mean_real = None 33 | self.std_real = None 34 | self.mean_imag = None 35 | self.std_imag = None 36 | self.min_real = None 37 | self.max_real = None 38 | self.min_imag = None 39 | self.max_imag = None 40 | 41 | def _remove_outliers(self, data: np.ndarray) -> np.ndarray: 42 | """ 43 | 移除异常值 44 | 45 | 参数: 46 | data: 输入数据 47 | 48 | 返回: 49 | 处理后的数据 50 | """ 51 | if not self.remove_outliers: 52 | return data 53 | 54 | # 分别处理实部和虚部 55 | real_part = np.real(data) 56 | imag_part = np.imag(data) 57 | 58 | # 计算均值和标准差 59 | mean_real = np.mean(real_part) 60 | std_real = np.std(real_part) 61 | mean_imag = np.mean(imag_part) 62 | std_imag = np.std(imag_part) 63 | 64 | # 创建掩码 65 | real_mask = np.abs(real_part - mean_real) <= self.outlier_threshold * std_real 66 | imag_mask = np.abs(imag_part - mean_imag) <= self.outlier_threshold * std_imag 67 | mask = real_mask & imag_mask 68 | 69 | # 将异常值替换为均值 70 | real_part[~real_mask] = mean_real 71 | imag_part[~imag_mask] = mean_imag 72 | 73 | return real_part + 1j * imag_part 74 | 75 | def fit(self, data: np.ndarray) -> None: 76 | """ 77 | 计算归一化参数 78 | 79 | 参数: 80 | data: 输入数据 81 | """ 82 | # 首先移除异常值 83 | data = self._remove_outliers(data) 84 | 85 | # 分离实部和虚部 86 | real_part = np.real(data) 87 | imag_part = np.imag(data) 88 | 89 | if self.normalization == 'z-score': 90 | self.mean_real = np.mean(real_part) 91 | self.std_real = np.std(real_part) 92 | self.mean_imag = np.mean(imag_part) 93 | self.std_imag = np.std(imag_part) 94 | else: # min-max归一化 95 | self.min_real = np.min(real_part) 96 | self.max_real = np.max(real_part) 97 | self.min_imag = np.min(imag_part) 98 | self.max_imag = np.max(imag_part) 99 | 100 | def transform(self, data: np.ndarray) -> np.ndarray: 101 | """ 102 | 应用归一化 103 | 104 | 参数: 105 | data: 输入数据 106 | 107 | 返回: 108 | 归一化后的数据 109 | """ 110 | # 首先移除异常值 111 | data = self._remove_outliers(data) 112 | 113 | # 分离实部和虚部 114 | real_part = np.real(data) 115 | imag_part = np.imag(data) 116 | 117 | if self.normalization == 'z-score': 118 | if self.mean_real is None or self.std_real is None: 119 | raise ValueError("请先调用fit方法计算归一化参数") 120 | 121 | real_norm = (real_part - self.mean_real) / self.std_real 122 | imag_norm = (imag_part - self.mean_imag) / self.std_imag 123 | else: # min-max归一化 124 | if self.min_real is None or self.max_real is None: 125 | raise ValueError("请先调用fit方法计算归一化参数") 126 | 127 | real_norm = (real_part - self.min_real) / (self.max_real - self.min_real) 128 | imag_norm = (imag_part - self.min_imag) / (self.max_imag - self.min_imag) 129 | 130 | return real_norm + 1j * imag_norm 131 | 132 | def fit_transform(self, data: np.ndarray) -> np.ndarray: 133 | """ 134 | 计算归一化参数并应用归一化 135 | 136 | 参数: 137 | data: 输入数据 138 | 139 | 返回: 140 | 归一化后的数据 141 | """ 142 | self.fit(data) 143 | return self.transform(data) 144 | 145 | def inverse_transform(self, data: np.ndarray) -> np.ndarray: 146 | """ 147 | 反归一化 148 | 149 | 参数: 150 | data: 归一化后的数据 151 | 152 | 返回: 153 | 原始尺度的数据 154 | """ 155 | real_part = np.real(data) 156 | imag_part = np.imag(data) 157 | 158 | if self.normalization == 'z-score': 159 | if self.mean_real is None or self.std_real is None: 160 | raise ValueError("请先调用fit方法计算归一化参数") 161 | 162 | real_orig = real_part * self.std_real + self.mean_real 163 | imag_orig = imag_part * self.std_imag + 
self.mean_imag 164 | else: # min-max归一化 165 | if self.min_real is None or self.max_real is None: 166 | raise ValueError("请先调用fit方法计算归一化参数") 167 | 168 | real_orig = real_part * (self.max_real - self.min_real) + self.min_real 169 | imag_orig = imag_part * (self.max_imag - self.min_imag) + self.min_imag 170 | 171 | return real_orig + 1j * imag_orig 172 | 173 | def augment_data(data: np.ndarray, 174 | methods: Optional[List[str]] = None, 175 | noise_std: float = 0.01, 176 | phase_shift_range: Tuple[float, float] = (-0.1, 0.1), 177 | magnitude_scale_range: Tuple[float, float] = (0.9, 1.1)) -> np.ndarray: 178 | """ 179 | 数据增强 180 | 181 | 参数: 182 | data: 输入数据 183 | methods: 增强方法列表,可选['noise', 'phase_shift', 'magnitude_scale'] 184 | noise_std: 高斯噪声标准差 185 | phase_shift_range: 相位偏移范围(弧度) 186 | magnitude_scale_range: 幅度缩放范围 187 | 188 | 返回: 189 | 增强后的数据 190 | """ 191 | if methods is None: 192 | methods = ['noise', 'phase_shift', 'magnitude_scale'] 193 | 194 | augmented_data = data.copy() 195 | 196 | for method in methods: 197 | if method == 'noise': 198 | # 添加复高斯噪声 199 | noise = (np.random.normal(0, noise_std, data.shape) + 200 | 1j * np.random.normal(0, noise_std, data.shape)) 201 | augmented_data += noise 202 | 203 | elif method == 'phase_shift': 204 | # 随机相位偏移 205 | phase_shift = np.random.uniform(phase_shift_range[0], 206 | phase_shift_range[1], 207 | data.shape) 208 | augmented_data *= np.exp(1j * phase_shift) 209 | 210 | elif method == 'magnitude_scale': 211 | # 随机幅度缩放 212 | scale = np.random.uniform(magnitude_scale_range[0], 213 | magnitude_scale_range[1], 214 | data.shape) 215 | augmented_data *= scale 216 | 217 | return augmented_data -------------------------------------------------------------------------------- /src/utils/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | 信道估计训练器模块 3 | 4 | 此模块实现了用于训练深度学习模型的训练器类。训练器提供了完整的训练流程管理, 5 | 包括训练循环、验证评估、模型保存、早停机制等功能。 6 | 7 | 主要特性: 8 | - 支持GPU训练 9 | - 实时进度显示 10 | - TensorBoard可视化 11 | - 自动保存最佳模型 12 | - 训练状态恢复 13 | - 早停机制 14 | - 学习率自适应调整 15 | 16 | 作者: lzm 17 | 日期: 2025-03-07 18 | """ 19 | 20 | import os 21 | import time 22 | from pathlib import Path 23 | import torch 24 | from torch.utils.tensorboard import SummaryWriter 25 | from tqdm import tqdm 26 | import sys 27 | 28 | class ChannelEstimatorTrainer: 29 | """ 30 | 信道估计模型训练器 31 | 32 | 该类实现了完整的模型训练流程,包括: 33 | 1. 训练和验证循环 34 | 2. 损失计算和优化 35 | 3. 进度监控和可视化 36 | 4. 模型保存和加载 37 | 5. 
训练状态管理 38 | 39 | 参数: 40 | model (nn.Module): 要训练的模型 41 | train_loader (DataLoader): 训练数据加载器 42 | val_loader (DataLoader): 验证数据加载器 43 | save_dir (str): 模型和日志保存目录 44 | project_name (str): 项目名称,用于创建子目录 45 | device (str): 训练设备('cuda'或'cpu') 46 | learning_rate (float): 初始学习率 47 | """ 48 | 49 | def __init__( 50 | self, 51 | model, 52 | train_loader, 53 | val_loader, 54 | save_dir='runs/train', 55 | project_name='channel_estimation', 56 | device='cuda', 57 | learning_rate=0.001 58 | ): 59 | # 初始化模型和数据加载器 60 | self.model = model.to(device) 61 | self.train_loader = train_loader 62 | self.val_loader = val_loader 63 | self.device = device 64 | self.learning_rate = learning_rate 65 | 66 | # 创建保存目录 67 | self.save_dir = Path(save_dir) / project_name 68 | self.weights_dir = self.save_dir / 'weights' 69 | self.weights_dir.mkdir(parents=True, exist_ok=True) 70 | 71 | # 设置TensorBoard 72 | self.writer = SummaryWriter(str(self.save_dir / 'tensorboard')) 73 | 74 | # 初始化优化器和损失函数 75 | self.criterion = torch.nn.MSELoss() 76 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) 77 | 78 | # 设置学习率调度器,当验证损失不再下降时降低学习率 79 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 80 | self.optimizer, 'min', patience=5, verbose=True 81 | ) 82 | 83 | # 初始化训练状态 84 | self.best_val_loss = float('inf') # 最佳验证损失 85 | self.current_epoch = 0 # 当前训练轮数 86 | self.history = { # 训练历史记录 87 | 'train_loss': [], # 训练损失历史 88 | 'val_loss': [], # 验证损失历史 89 | 'learning_rate': [] # 学习率历史 90 | } 91 | 92 | def train_epoch(self): 93 | """ 94 | 训练一个epoch 95 | 96 | 该方法实现了单个训练epoch的完整流程: 97 | 1. 遍历训练数据批次 98 | 2. 前向传播计算损失 99 | 3. 反向传播更新参数 100 | 4. 记录训练状态 101 | 5. 更新进度显示 102 | 103 | 返回: 104 | float: 该epoch的平均训练损失 105 | """ 106 | self.model.train() # 设置为训练模式 107 | total_loss = 0 108 | 109 | # 创建进度条 110 | pbar = tqdm(self.train_loader, 111 | desc=f'Epoch {self.current_epoch + 1:<3d}', 112 | leave=True, # 保持显示 113 | position=1, # 位置在总进度条下方 114 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}', 115 | ncols=120) 116 | 117 | # 遍历训练数据批次 118 | for batch_idx, (X, y) in enumerate(pbar): 119 | # 将数据移到指定设备 120 | X, y = X.to(self.device), y.to(self.device) 121 | 122 | # 前向传播 123 | self.optimizer.zero_grad() # 清除梯度 124 | outputs = self.model(X) # 模型预测 125 | loss = self.criterion(outputs, y) # 计算损失 126 | 127 | # 反向传播 128 | loss.backward() # 计算梯度 129 | self.optimizer.step() # 更新参数 130 | 131 | # 更新损失统计 132 | total_loss += loss.item() 133 | avg_loss = total_loss / (batch_idx + 1) 134 | 135 | # 更新进度条显示 136 | pbar.set_postfix({'loss': f'{avg_loss:.6f}'}, refresh=True) 137 | 138 | # 记录到TensorBoard 139 | step = self.current_epoch * len(self.train_loader) + batch_idx 140 | self.writer.add_scalar('train/batch_loss', loss.item(), step) 141 | 142 | pbar.close() 143 | return total_loss / len(self.train_loader) 144 | 145 | def validate(self): 146 | """ 147 | 验证模型性能 148 | 149 | 该方法在验证集上评估模型性能: 150 | 1. 不计算梯度 151 | 2. 遍历验证数据批次 152 | 3. 计算验证损失 153 | 4. 
更新进度显示 154 | 155 | 返回: 156 | float: 验证集上的平均损失 157 | """ 158 | self.model.eval() # 设置为评估模式 159 | total_loss = 0 160 | val_loader_len = len(self.val_loader) 161 | 162 | with torch.no_grad(): # 不计算梯度 163 | # 创建进度条 164 | pbar = tqdm(self.val_loader, 165 | desc='Validating', 166 | leave=True, # 保持显示 167 | position=1, # 位置在总进度条下方 168 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}', 169 | ncols=120) 170 | 171 | # 遍历验证数据批次 172 | for batch_idx, (X, y) in enumerate(pbar): 173 | # 将数据移到指定设备 174 | X, y = X.to(self.device), y.to(self.device) 175 | outputs = self.model(X) # 模型预测 176 | loss = self.criterion(outputs, y) # 计算损失 177 | total_loss += loss.item() 178 | 179 | # 更新进度条显示 180 | avg_loss = total_loss / (batch_idx + 1) 181 | pbar.set_postfix({'loss': f'{avg_loss:.6f}'}, refresh=True) 182 | 183 | pbar.close() 184 | 185 | return total_loss / val_loader_len 186 | 187 | def train(self, epochs=100, early_stopping_patience=10): 188 | """ 189 | 完整的训练流程 190 | 191 | 该方法实现了完整的模型训练流程: 192 | 1. 多轮训练循环 193 | 2. 定期验证评估 194 | 3. 模型保存 195 | 4. 早停机制 196 | 5. 进度监控 197 | 198 | 参数: 199 | epochs (int): 训练轮数 200 | early_stopping_patience (int): 早停耐心值,验证损失多少轮未改善时停止训练 201 | """ 202 | early_stopping_counter = 0 203 | start_time = time.time() 204 | 205 | # 打印训练配置 206 | print("\n训练配置:") 207 | print(f" Epochs: {epochs}") 208 | print(f" Batch Size: {self.train_loader.batch_size}") 209 | print(f" Learning Rate: {self.learning_rate}") 210 | print(f" Device: {self.device}") 211 | print(f" Early Stopping Patience: {early_stopping_patience}") 212 | print() 213 | 214 | # 创建总进度条 215 | pbar = tqdm(range(epochs), 216 | desc='Progress', 217 | leave=True, # 保持显示 218 | position=0, # 位置在最上方 219 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}', 220 | ncols=120) 221 | 222 | try: 223 | # 训练循环 224 | for epoch in pbar: 225 | self.current_epoch = epoch 226 | 227 | # 训练一个epoch 228 | train_loss = self.train_epoch() 229 | 230 | # 验证评估 231 | val_loss = self.validate() 232 | 233 | # 更新学习率 234 | current_lr = self.optimizer.param_groups[0]['lr'] 235 | self.scheduler.step(val_loss) # 根据验证损失调整学习率 236 | 237 | # 记录训练历史 238 | self.history['train_loss'].append(train_loss) 239 | self.history['val_loss'].append(val_loss) 240 | self.history['learning_rate'].append(current_lr) 241 | 242 | # 记录到TensorBoard 243 | self.writer.add_scalar('train/epoch_loss', train_loss, epoch) 244 | self.writer.add_scalar('val/epoch_loss', val_loss, epoch) 245 | self.writer.add_scalar('train/lr', current_lr, epoch) 246 | 247 | # 更新总进度条显示 248 | pbar.set_postfix({ 249 | 'train': f'{train_loss:.6f}', 250 | 'val': f'{val_loss:.6f}', 251 | 'lr': f'{current_lr:.2e}' 252 | }, refresh=True) 253 | 254 | # 保存最佳模型 255 | if val_loss < self.best_val_loss: 256 | self.best_val_loss = val_loss 257 | self.save_model('best.pt') 258 | early_stopping_counter = 0 # 重置早停计数器 259 | else: 260 | early_stopping_counter += 1 261 | 262 | # 定期保存模型 263 | if (epoch + 1) % 10 == 0: 264 | self.save_model(f'epoch_{epoch + 1}.pt') 265 | 266 | # 早停检查 267 | if early_stopping_counter >= early_stopping_patience: 268 | print(f'\nEarly stopping triggered at epoch {epoch + 1}') 269 | break 270 | 271 | # 在每个epoch结束时打印空行 272 | print() 273 | 274 | except KeyboardInterrupt: 275 | print('\nTraining interrupted by user') 276 | 277 | finally: 278 | # 保存最后的模型和清理资源 279 | self.save_model('last.pt') 280 | self.writer.close() 281 | 282 | # 打印训练结果统计 283 | elapsed_time = time.time() - start_time 284 | print("\n训练完成:") 285 | print(f" 训练轮数: {self.current_epoch + 1}/{epochs}") 286 | print(f" 最佳验证损失: 
{self.best_val_loss:.6f}") 287 | print(f" 训练时间: {elapsed_time/3600:.2f}h") 288 | print(f" 保存路径: {self.weights_dir}") 289 | print() 290 | 291 | def save_model(self, filename): 292 | """ 293 | 保存模型状态 294 | 295 | 保存完整的训练状态,包括: 296 | - 模型参数 297 | - 优化器状态 298 | - 学习率调度器状态 299 | - 训练历史 300 | - 最佳验证损失 301 | 302 | 参数: 303 | filename (str): 保存的文件名 304 | """ 305 | save_path = self.weights_dir / filename 306 | torch.save({ 307 | 'epoch': self.current_epoch, 308 | 'model_state_dict': self.model.state_dict(), 309 | 'optimizer_state_dict': self.optimizer.state_dict(), 310 | 'scheduler_state_dict': self.scheduler.state_dict(), 311 | 'best_val_loss': self.best_val_loss, 312 | 'history': self.history 313 | }, str(save_path)) 314 | 315 | def load_model(self, filename): 316 | """ 317 | 加载模型状态 318 | 319 | 加载完整的训练状态,包括: 320 | - 模型参数 321 | - 优化器状态 322 | - 学习率调度器状态 323 | - 训练历史 324 | - 最佳验证损失 325 | 326 | 参数: 327 | filename (str): 要加载的文件名 328 | 329 | 返回: 330 | bool: 是否成功加载模型 331 | """ 332 | load_path = self.weights_dir / filename 333 | if load_path.exists(): 334 | checkpoint = torch.load(str(load_path), map_location=self.device) # 按训练器当前设备加载,避免GPU上保存的权重在仅CPU环境下加载失败 335 | self.model.load_state_dict(checkpoint['model_state_dict']) 336 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 337 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 338 | self.current_epoch = checkpoint['epoch'] 339 | self.best_val_loss = checkpoint['best_val_loss'] 340 | if 'history' in checkpoint: 341 | self.history = checkpoint['history'] 342 | return True 343 | return False 344 | 345 | def get_history(self): 346 | """ 347 | 获取训练历史 348 | 349 | 返回: 350 | dict: 包含训练损失、验证损失和学习率的历史记录 351 | """ 352 | return self.history --------------------------------------------------------------------------------
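附注:为更直观地说明上述各模块的配合方式,下面给出一个最小化的端到端示例草图(完整流程请以 src/main.py 为准)。示例假设在仓库根目录下运行,并将 src 加入模块搜索路径;其中的天线数、导频长度、SNR、样本数等参数仅为演示取值,所用类与函数的签名均来自上文源码。

```python
import sys
sys.path.insert(0, "src")  # 假设在仓库根目录运行,将 src 加入模块搜索路径

import numpy as np

from utils.data_generator import ChannelDataGenerator
from traditional.ls_estimation import LSChannelEstimator
from traditional.estimators import calculate_performance_metrics

# 1. 生成 4x4 MIMO、导频长度 16、SNR = 10 dB 的瑞利信道数据
generator = ChannelDataGenerator(n_tx=4, n_rx=4, n_pilot=16, channel_type="rayleigh")
H, X, Y = generator.generate(n_samples=100, snr_db=10)

# 2. LSChannelEstimator 以单个样本的二维矩阵为输入,这里逐样本估计后再堆叠
estimator = LSChannelEstimator()
H_est = np.stack([estimator.estimate(Y[i], X[i]) for i in range(H.shape[0])])

# 3. 计算 MSE / NMSE / BER 等性能指标
metrics = calculate_performance_metrics(H, H_est)
print(metrics)  # 形如 {'mse': ..., 'nmse': ..., 'ber': ...}
```

深度学习方法的流程与此类似:先构建训练/验证数据加载器并交给 ChannelEstimatorTrainer,调用 train(epochs, early_stopping_patience) 完成训练,再用训练好的模型在测试集上给出信道估计,最后同样通过 calculate_performance_metrics 与传统方法进行对比。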