├── LICENSE
├── README.md
├── requirements.txt
└── src
    ├── .idea
    │   ├── .gitignore
    │   ├── MarsCodeWorkspaceAppSettings.xml
    │   ├── inspectionProfiles
    │   │   ├── Project_Default.xml
    │   │   └── profiles_settings.xml
    │   ├── misc.xml
    │   ├── modules.xml
    │   └── src.iml
    ├── config
    │   └── default.yml
    ├── deep_learning
    │   ├── __pycache__
    │   │   ├── models.cpython-312.pyc
    │   │   └── models.cpython-39.pyc
    │   ├── dl_estimator.py
    │   └── models.py
    ├── main.py
    ├── traditional
    │   ├── __pycache__
    │   │   ├── estimators.cpython-312.pyc
    │   │   └── estimators.cpython-39.pyc
    │   ├── estimators.py
    │   └── ls_estimation.py
    └── utils
        ├── __pycache__
        │   ├── config.cpython-312.pyc
        │   ├── config.cpython-39.pyc
        │   ├── data_generator.cpython-312.pyc
        │   ├── data_generator.cpython-39.pyc
        │   ├── preprocessing.cpython-312.pyc
        │   ├── preprocessing.cpython-39.pyc
        │   ├── trainer.cpython-312.pyc
        │   └── trainer.cpython-39.pyc
        ├── config.py
        ├── data_generator.py
        ├── preprocessing.py
        └── trainer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 修明
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Comparative Study of Channel Estimation Methods
2 |
3 | This project implements a complete framework for comparing channel estimation methods, covering the implementation, evaluation, and performance comparison of both traditional and deep learning approaches.
4 |
5 | ## Project Overview
6 |
7 | The project compares the performance of the following channel estimation methods:
8 |
9 | ### Traditional Methods
10 | - LS (Least Squares) estimation
11 | - LMMSE (Linear Minimum Mean Square Error) estimation
12 | - ML (Maximum Likelihood) estimation
13 |
14 | ### Deep Learning Methods
15 | - CNN (Convolutional Neural Network)
16 | - RNN (Recurrent Neural Network)
17 | - LSTM (Long Short-Term Memory)
18 | - GRU (Gated Recurrent Unit)
19 | - Hybrid (combined CNN-LSTM model)
20 |
21 | ## Project Structure
22 |
23 | ```
24 | project/
25 | ├── src/                       # Source code
26 | │   ├── main.py                # Main entry point
27 | │   ├── config/                # Configuration files
28 | │   │   └── default.yml        # Default configuration
29 | │   ├── traditional/           # Traditional methods
30 | │   │   └── estimators.py      # Traditional estimators
31 | │   ├── deep_learning/         # Deep learning methods
32 | │   │   └── models.py          # Deep learning models
33 | │   └── utils/                 # Utilities
34 | │       ├── config.py          # Configuration loader
35 | │       ├── data_generator.py  # Data generator
36 | │       ├── preprocessing.py   # Data preprocessing
37 | │       └── trainer.py         # Model trainer
38 | ├── results/                   # Experiment results
39 | │   ├── logs/                  # Log files
40 | │   ├── models/                # Saved models
41 | │   └── plots/                 # Result figures
42 | └── README.md                  # Project documentation
43 | ```
44 |
45 | ## Features
46 |
47 | 1. **Data Handling**
48 |    - Multiple channel models (Rayleigh, Rician, ...)
49 |    - Flexible data generation and preprocessing
50 |    - Data augmentation
51 |    - Automatic dataset splitting (train/validation/test)
52 |
53 | 2. **Model Implementations**
54 |    - Optimized matrix computations for the traditional methods
55 |    - Several deep learning architectures
56 |    - GPU acceleration
57 |    - Model saving and loading
58 |
59 | 3. **Training and Evaluation**
60 |    - Automated training pipeline
61 |    - Early stopping
62 |    - Adaptive learning-rate scheduling
63 |    - Evaluation with multiple metrics
64 |
65 | 4. **Visualization and Analysis**
66 |    - Training-process visualization
67 |    - Performance comparison charts
68 |    - Detailed logging
69 |    - Result export and saving
70 |
71 | ## Configuration
72 |
73 | ### Experiment Configuration
74 | ```yaml
75 | experiment:
76 |   name: "channel_estimation"        # Experiment name
77 |   seed: 42                          # Random seed
78 |   save_dir: "results/models"        # Directory for saved models
79 |   plot_dir: "results/plots"         # Directory for plots
80 |   use_cuda: true                    # Whether to use the GPU
81 |
82 | channel:
83 |   n_tx: 4                           # Number of transmit antennas
84 |   n_rx: 4                           # Number of receive antennas
85 |   n_pilot: 16                       # Pilot length
86 |   snr_db: 10                        # Signal-to-noise ratio (dB)
87 |   n_samples: 10000                  # Number of samples
88 |   type: "rayleigh"                  # Channel type
89 |   rician_k: 1.0                     # Rician K-factor (Rician channel only)
90 |
91 | data:
92 |   preprocessing:
93 |     normalization: "z-score"        # Normalization method ('z-score' or 'min-max')
94 |     remove_outliers: true           # Whether to remove outliers
95 |     outlier_threshold: 3.0          # Outlier threshold
96 |
97 |   augmentation:
98 |     enabled: true                   # Whether to enable data augmentation
99 |     methods: ["noise", "phase_shift"] # Augmentation methods
100 |     noise_std: 0.01                 # Noise standard deviation
101 |     phase_shift_range: [-0.1, 0.1]  # Phase shift range
102 |
103 |   split:
104 |     train_ratio: 0.7                # Training set ratio
105 |     val_ratio: 0.15                 # Validation set ratio
106 |
107 |   loader:
108 |     batch_size: 128                 # Batch size
109 |     num_workers: 4                  # Number of data-loading workers
110 |     pin_memory: true                # Whether to pin memory
111 |
112 | models:
113 |   common:
114 |     learning_rate: 0.001            # Learning rate
115 |
116 |   cnn:
117 |     channels: [64, 128, 256]        # CNN channel sizes
118 |     kernel_size: 3                  # Convolution kernel size
119 |
120 |   rnn:
121 |     hidden_size: 256                # Hidden state size
122 |     num_layers: 2                   # Number of layers
123 |
124 |   lstm:
125 |     hidden_size: 256
126 |     num_layers: 2
127 |
128 |   gru:
129 |     hidden_size: 256
130 |     num_layers: 2
131 |
132 |   hybrid:
133 |     rnn_hidden_size: 256            # Hidden size of the RNN part
134 |     cnn_channels: [64, 128]         # Channel sizes of the CNN part
135 |
136 | training:
137 |   epochs: 100                       # Number of training epochs
138 |   early_stopping:
139 |     patience: 10                    # Early-stopping patience
140 | ```
141 |
142 | ### Traditional Method Parameters
143 | ```yaml
144 | traditional:
145 |   ls:
146 |     regularization: true            # Whether to use regularization
147 |     lambda: 0.01                    # Regularization parameter
148 |
149 |   lmmse:
150 |     adaptive_snr: true              # Whether to estimate the SNR adaptively
151 |     correlation_method: "sample"    # Correlation matrix method ('sample' or 'theoretical')
152 |
153 |   ml:
154 |     max_iter: 100                   # Maximum number of iterations
155 |     tol: 1e-6                       # Convergence tolerance
156 |     learning_rate: 0.01             # Learning rate
157 | ```
158 |
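The repository also ships a small configuration loader (`Config` in `src/utils/config.py`) with dotted-key access. A minimal sketch of reading the values above with it (paths assume the snippet is run from the repository root with `src/` on the import path, as is the case when running `python src/main.py`):

```python
from utils.config import Config

# Load the default configuration shipped with the repository
cfg = Config("src/config/default.yml")

# Dotted keys walk the nested YAML; the second argument is a fallback default
n_tx = cfg.get("channel.n_tx", 4)
lr = cfg.get("models.common.learning_rate", 1e-3)
use_reg = cfg.get("traditional.ls.regularization", False)
print(n_tx, lr, use_reg)
```
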
159 | ## Usage
160 |
161 | 1. **Environment Setup**
162 | ```bash
163 | # Create a virtual environment
164 | python -m venv venv
165 | source venv/bin/activate  # Linux/Mac
166 | # or
167 | .\venv\Scripts\activate   # Windows
168 |
169 | # Install dependencies
170 | pip install -r requirements.txt
171 | ```
172 |
173 | 2. **Running an Experiment**
174 | ```bash
175 | # Use the default configuration
176 | python src/main.py
177 |
178 | # Specify a configuration file
179 | python src/main.py --config path/to/config.yml
180 |
181 | # Specify the device
182 | python src/main.py --device cuda
183 |
184 | # Set the random seed
185 | python src/main.py --seed 42
186 | ```
187 |
188 | 3. **Inspecting Results**
189 |    - Experiment logs: `results/logs/`
190 |    - Trained models: `results/models/`
191 |    - Performance comparison plots: `results/plots/`
192 |    - Experiment results (JSON): `results/results_*.json`
193 |
194 | ## Performance Metrics
195 |
196 | The following metrics are used to evaluate estimator performance (a condensed sketch of how they are computed follows this list):
197 |
198 | 1. **MSE (Mean Squared Error)**
199 |    - Average squared difference between the estimated and the true channel
200 |    - Lower is better
201 |
202 | 2. **NMSE (Normalized Mean Squared Error)**
203 |    - MSE normalized by the channel power
204 |    - Removes the effect of channel-power differences
205 |
206 | 3. **BER (Bit Error Rate)**
207 |    - Bit error rate of the communication system under the given channel estimate
208 |    - A direct indicator of actual link performance
209 |
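These metrics are implemented in `calculate_performance_metrics` in `src/traditional/estimators.py`; a condensed sketch of that computation (QPSK is assumed for the BER term):

```python
import numpy as np
from scipy.special import erf

def metrics(H_true: np.ndarray, H_est: np.ndarray, eps: float = 1e-10) -> dict:
    mse = np.mean(np.abs(H_true - H_est) ** 2)            # mean squared error
    nmse = mse / np.mean(np.abs(H_true) ** 2)              # normalized by channel power
    # per-coefficient effective SNR, mapped to a QPSK bit error rate
    snr_eff = np.abs(H_true) ** 2 / (np.abs(H_true - H_est) ** 2 + eps)
    ber = 0.5 * np.mean(1 - erf(np.sqrt(snr_eff / 2)))
    return {"mse": float(mse), "nmse": float(nmse), "ber": float(ber)}
```
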
210 | ## Experiment Results
211 |
212 | Each run produces the following visualizations:
213 |
214 | 1. **Overall Performance Comparison**
215 |    - MSE bar chart covering all methods
216 |    - A direct view of each method's estimation accuracy
217 |
218 | 2. **Training Process Analysis**
219 |    - Training-loss curves for every deep learning model
220 |    - Learning-rate schedule curves
221 |    - Helps in understanding the training dynamics
222 |
223 | 3. **Multi-Metric Comparison**
224 |    - Radar charts for the traditional and the deep learning methods
225 |    - A combined view of performance across several metrics
226 |
227 | 4. **Error Distribution Analysis**
228 |    - Box plots of each method's MSE
229 |    - Shows the statistical properties of the estimation error
230 |
231 | ## Notes
232 |
233 | 1. **Hardware Requirements**
234 |    - A GPU is recommended for training
235 |    - At least 8 GB of RAM
236 |    - Storage: roughly 1 GB (depending on the experiment scale)
237 |
238 | 2. **Data Handling**
239 |    - Choose the preprocessing normalization method carefully
240 |    - Set the outlier threshold sensibly
241 |    - Adjust the augmentation parameters to your needs
242 |
243 | 3. **Training**
244 |    - Tune the batch size and learning rate as needed
245 |    - Pay attention to the early-stopping settings
246 |    - Monitor training to avoid overfitting
247 |
248 | 4. **Result Analysis**
249 |    - Consider several performance metrics together
250 |    - Check the behaviour under different channel conditions
251 |    - Take computational complexity and real-time requirements into account
252 |
253 | # Author Information
254 |
255 | - Author: 修明
256 | - Email: lzmpt@qq.com
257 |
258 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.21.0
2 | torch>=1.9.0
3 | matplotlib>=3.4.0
4 | pyyaml>=5.4.0
5 | tqdm>=4.61.0
6 | tensorboard>=2.6.0
7 | scikit-learn>=0.24.0
8 | pandas>=1.3.0
9 | seaborn>=0.11.0
10 | rich>=10.0.0
11 | pytest>=6.2.0       # for testing
12 | black>=21.5b2       # for code formatting
13 | flake8>=3.9.0       # for linting
14 | mypy>=0.910         # for type checking
--------------------------------------------------------------------------------
/src/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/src/.idea/MarsCodeWorkspaceAppSettings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/src/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/src/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/src/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/src/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/src/.idea/src.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/src/config/default.yml:
--------------------------------------------------------------------------------
1 | # 信道估计实验配置文件
2 |
3 | # 实验基本设置
4 | experiment:
5 | # 实验名称,用于保存结果和日志
6 | name: 'channel_estimation'
7 | # 实验描述
8 | description: '信道估计方法比较研究'
9 | # 随机种子,用于复现实验结果
10 | seed: 42
11 | # 是否使用GPU
12 | use_cuda: true
13 | # 结果保存路径
14 | save_dir: 'results'
15 | # 结果图保存路径
16 | plot_dir: 'plots'
17 | # 是否启用详细日志
18 | verbose: true
19 |
20 | # 信道参数设置
21 | channel:
22 | # 发射天线数量
23 | n_tx: 4
24 | # 接收天线数量
25 | n_rx: 4
26 |   # 导频符号长度
27 | n_pilot: 16
28 | # 信噪比(dB)
29 | snr_db: 10
30 | # 采样数量
31 | n_samples: 10000
32 | # 信道类型:'rayleigh' 或 'rician'
33 | type: 'rayleigh'
34 | # Rician K因子(仅在type='rician'时有效)
35 | rician_k: 1.0
36 |
37 | # 数据处理设置
38 | data:
39 | # 数据预处理
40 | preprocessing:
41 | # 归一化方法:'z-score' 或 'min-max'
42 | normalization: 'z-score'
43 | # 是否移除异常值
44 | remove_outliers: true
45 | # 异常值阈值(标准差的倍数)
46 | outlier_threshold: 3.0
47 |
48 | # 数据增强
49 | augmentation:
50 | # 是否启用数据增强
51 | enabled: true
52 | # 增强方法列表
53 | methods: ['noise', 'phase_shift', 'magnitude_scale']
54 | # 高斯噪声标准差
55 | noise_std: 0.1
56 | # 相位偏移范围(弧度)
57 | phase_shift_range: [-0.1, 0.1]
58 | # 幅度缩放范围
59 | magnitude_scale_range: [0.9, 1.1]
60 |
61 | # 数据集划分
62 | split:
63 | # 训练集比例
64 | train_ratio: 0.7
65 | # 验证集比例
66 | val_ratio: 0.15
67 | # 测试集比例(自动计算为1-train_ratio-val_ratio)
68 | test_ratio: 0.15
69 |
70 | # 数据加载
71 | loader:
72 | # 批次大小
73 | batch_size: 128
74 | # 是否打乱训练数据
75 | shuffle: true
76 | # 数据加载线程数
77 | num_workers: 4
78 | # 是否将数据固定在内存中
79 | pin_memory: true
80 |
81 | # 传统估计器设置
82 | traditional:
83 | # LS估计器参数
84 | ls:
85 | # 是否使用正则化
86 | regularization: true
87 | # 正则化参数
88 | lambda: 0.01
89 |
90 | # LMMSE估计器参数
91 | lmmse:
92 | # 是否使用自适应SNR估计
93 | adaptive_snr: true
94 | # 信道相关矩阵估计方法:'sample' 或 'theoretical'
95 | correlation_method: 'sample'
96 |
97 | # ML估计器参数
98 | ml:
99 | # 最大迭代次数
100 | max_iter: 100
101 | # 收敛阈值
102 | tol: 1e-6
103 | # 学习率
104 | learning_rate: 0.01
105 |
106 | # 深度学习模型设置
107 | models:
108 | # 通用设置
109 | common:
110 | # 学习率
111 | learning_rate: 0.001
112 | # 权重衰减
113 | weight_decay: 0.0001
114 | # Dropout率
115 | dropout: 0.3
116 | # 批归一化动量
117 | bn_momentum: 0.1
118 |
119 | # RNN模型参数
120 | rnn:
121 | # 隐藏层大小
122 | hidden_size: 256
123 | # 层数
124 | num_layers: 2
125 | # RNN类型:'vanilla', 'lstm', 'gru'
126 | rnn_type: 'vanilla'
127 |
128 | # LSTM模型参数
129 | lstm:
130 | # 隐藏层大小
131 | hidden_size: 256
132 | # 层数
133 | num_layers: 2
134 | # 是否使用双向LSTM
135 | bidirectional: true
136 |
137 | # GRU模型参数
138 | gru:
139 | # 隐藏层大小
140 | hidden_size: 256
141 | # 层数
142 | num_layers: 2
143 | # 是否使用双向GRU
144 | bidirectional: true
145 |
146 | # 混合模型参数
147 | hybrid:
148 | # RNN部分隐藏层大小
149 | rnn_hidden_size: 256
150 | # 全连接层大小
151 | fc_sizes: [512, 256]
152 |
153 | # 训练设置
154 | training:
155 | # 训练轮数
156 | epochs: 100
157 | # 早停设置
158 | early_stopping:
159 | # 是否启用早停
160 | enabled: true
161 | # 早停耐心值
162 | patience: 10
163 | # 最小改善阈值
164 | min_delta: 1e-4
165 |
166 | # 学习率调度器
167 | scheduler:
168 | # 调度器类型:'reduce_on_plateau', 'cosine', 'step'
169 | type: 'reduce_on_plateau'
170 | # 降低学习率的因子
171 | factor: 0.1
172 | # 降低学习率前的耐心值
173 | patience: 5
174 | # 最小学习率
175 | min_lr: 1e-6
176 |
177 | # 模型保存
178 | checkpointing:
179 | # 是否启用检查点保存
180 | enabled: true
181 | # 保存频率(每N个epoch)
182 | save_freq: 10
183 | # 是否只保存最佳模型
184 | save_best_only: false
185 | # 最大保存检查点数量
186 | max_to_keep: 5
187 |
188 | # 评估设置
189 | evaluation:
190 | # 评估指标列表
191 | metrics: ['mse', 'nmse', 'ber']
192 | # 是否计算置信区间
193 | confidence_interval: true
194 | # 置信水平
195 | confidence_level: 0.95
196 | # 是否保存预测结果
197 | save_predictions: true
198 |
199 | # TensorBoard设置
200 | tensorboard:
201 | # 是否启用TensorBoard
202 | enabled: true
203 | # 更新频率(每N个batch)
204 | update_freq: 10
205 | # 是否记录梯度直方图
206 | histogram_freq: 1
207 | # 是否记录模型图
208 | write_graph: true
209 | # 是否记录配置文件
210 | write_config: true
--------------------------------------------------------------------------------
/src/deep_learning/__pycache__/models.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/deep_learning/__pycache__/models.cpython-312.pyc
--------------------------------------------------------------------------------
/src/deep_learning/__pycache__/models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/deep_learning/__pycache__/models.cpython-39.pyc
--------------------------------------------------------------------------------
/src/deep_learning/dl_estimator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 |
6 | class DLChannelEstimator(nn.Module):
7 | """基于深度学习的信道估计器"""
8 |
9 | def __init__(self, input_size, hidden_size=128):
10 | super(DLChannelEstimator, self).__init__()
11 |
12 | self.network = nn.Sequential(
13 | nn.Linear(input_size, hidden_size),
14 | nn.ReLU(),
15 | nn.BatchNorm1d(hidden_size),
16 | nn.Dropout(0.3),
17 |
18 | nn.Linear(hidden_size, hidden_size),
19 | nn.ReLU(),
20 | nn.BatchNorm1d(hidden_size),
21 | nn.Dropout(0.3),
22 |
23 | nn.Linear(hidden_size, input_size)
24 | )
25 |
26 | self.criterion = nn.MSELoss()
27 | self.optimizer = None
28 |
29 | def forward(self, x):
30 | """前向传播"""
31 | return self.network(x)
32 |
33 | def train_model(self, train_loader, epochs=100, learning_rate=0.001):
34 | """
35 | 训练模型
36 |
37 | 参数:
38 | train_loader: 训练数据加载器
39 | epochs: 训练轮数
40 | learning_rate: 学习率
41 | """
42 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
43 |
44 | for epoch in range(epochs):
45 | total_loss = 0
46 | for batch_X, batch_Y in train_loader:
47 | self.optimizer.zero_grad()
48 |
49 | outputs = self(batch_X)
50 | loss = self.criterion(outputs, batch_Y)
51 |
52 | loss.backward()
53 | self.optimizer.step()
54 |
55 | total_loss += loss.item()
56 |
57 | if (epoch + 1) % 10 == 0:
58 | print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}')
59 |
60 | def estimate(self, X):
61 | """
62 | 使用训练好的模型进行信道估计
63 |
64 | 参数:
65 | X: 输入信号
66 |
67 | 返回:
68 | 估计的信道响应
69 | """
70 | self.eval()
71 | with torch.no_grad():
72 | return self(X)
73 |
74 | def get_mse(self, y_true, y_pred):
75 | """计算均方误差"""
76 | return self.criterion(y_pred, y_true).item()
--------------------------------------------------------------------------------
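A minimal usage sketch for the `DLChannelEstimator` defined above (run with `src/` on the import path; the shapes and random data are purely illustrative, and complex-valued channels would first be split into real and imaginary parts, as `src/main.py` does):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
from deep_learning.dl_estimator import DLChannelEstimator

# Illustrative real-valued features/targets: 1000 samples, 128 features each
X = torch.randn(1000, 128)
H = torch.randn(1000, 128)
loader = DataLoader(TensorDataset(X, H), batch_size=64, shuffle=True)

model = DLChannelEstimator(input_size=128, hidden_size=128)
model.train_model(loader, epochs=20, learning_rate=1e-3)  # prints the loss every 10 epochs

H_hat = model.estimate(X[:10])           # inference in eval mode, no gradients
print(model.get_mse(H[:10], H_hat))      # MSE between targets and estimates
```
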
/src/deep_learning/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | class CNNEstimator(nn.Module):
6 | """基于CNN的信道估计器"""
7 | def __init__(self, input_size, n_rx=4, n_tx=8):
8 | super(CNNEstimator, self).__init__()
9 |
10 | self.n_features = input_size
11 | self.n_rx = n_rx
12 | self.n_tx = n_tx
13 |
14 | self.cnn = nn.Sequential(
15 | # 将输入重塑为 (batch_size, 1, -1),作为1D CNN的输入
16 | nn.Conv1d(1, 64, kernel_size=3, padding=1),
17 | nn.ReLU(),
18 | nn.BatchNorm1d(64),
19 |
20 | nn.Conv1d(64, 128, kernel_size=3, padding=1),
21 | nn.ReLU(),
22 | nn.BatchNorm1d(128),
23 |
24 | nn.Conv1d(128, 64, kernel_size=3, padding=1),
25 | nn.ReLU(),
26 | nn.BatchNorm1d(64),
27 |
28 | nn.Conv1d(64, 1, kernel_size=3, padding=1)
29 | )
30 |
31 | # 全连接层
32 | self.fc = nn.Sequential(
33 | nn.Linear(self.n_features, 512),
34 | nn.ReLU(),
35 | nn.Dropout(0.3),
36 | nn.Linear(512, n_rx * n_tx)
37 | )
38 |
39 | def forward(self, x):
40 | batch_size = x.size(0)
41 | # 重塑为1D CNN输入
42 | x = x.view(batch_size, 1, -1)
43 | x = self.cnn(x)
44 | x = x.view(batch_size, -1)
45 | x = self.fc(x)
46 | # 重塑输出为 [batch_size, n_rx, n_tx]
47 | return x.view(batch_size, self.n_rx, self.n_tx)
48 |
49 | class RNNEstimator(nn.Module):
50 | """基于RNN的信道估计器"""
51 | def __init__(self, input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8):
52 | super(RNNEstimator, self).__init__()
53 |
54 | self.n_features = input_size
55 | self.seq_length = 8 # 将输入序列分成8个时间步
56 | self.feature_size = self.n_features // self.seq_length
57 | self.n_rx = n_rx
58 | self.n_tx = n_tx
59 |
60 | self.rnn = nn.RNN(self.feature_size, hidden_size, num_layers, batch_first=True)
61 | self.fc = nn.Sequential(
62 | nn.Linear(hidden_size, 512),
63 | nn.ReLU(),
64 | nn.Dropout(0.3),
65 | nn.Linear(512, n_rx * n_tx)
66 | )
67 |
68 | def forward(self, x):
69 | batch_size = x.size(0)
70 | # 重塑为序列数据
71 | x = x.view(batch_size, self.seq_length, -1)
72 | out, _ = self.rnn(x)
73 | x = self.fc(out[:, -1, :])
74 | # 重塑输出为 [batch_size, n_rx, n_tx]
75 | return x.view(batch_size, self.n_rx, self.n_tx)
76 |
77 | class LSTMEstimator(nn.Module):
78 | """基于LSTM的信道估计器"""
79 | def __init__(self, input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8):
80 | super(LSTMEstimator, self).__init__()
81 |
82 | self.n_features = input_size
83 | self.seq_length = 8
84 | self.feature_size = self.n_features // self.seq_length
85 | self.n_rx = n_rx
86 | self.n_tx = n_tx
87 |
88 | self.lstm = nn.LSTM(self.feature_size, hidden_size, num_layers, batch_first=True)
89 | self.fc = nn.Sequential(
90 | nn.Linear(hidden_size, 512),
91 | nn.ReLU(),
92 | nn.Dropout(0.3),
93 | nn.Linear(512, n_rx * n_tx)
94 | )
95 |
96 | def forward(self, x):
97 | batch_size = x.size(0)
98 | x = x.view(batch_size, self.seq_length, -1)
99 | out, (_, _) = self.lstm(x)
100 | x = self.fc(out[:, -1, :])
101 | # 重塑输出为 [batch_size, n_rx, n_tx]
102 | return x.view(batch_size, self.n_rx, self.n_tx)
103 |
104 | class GRUEstimator(nn.Module):
105 | """基于GRU的信道估计器"""
106 | def __init__(self, input_size, hidden_size=256, num_layers=2, n_rx=4, n_tx=8):
107 | super(GRUEstimator, self).__init__()
108 |
109 | self.n_features = input_size
110 | self.seq_length = 8
111 | self.feature_size = self.n_features // self.seq_length
112 | self.n_rx = n_rx
113 | self.n_tx = n_tx
114 |
115 | self.gru = nn.GRU(self.feature_size, hidden_size, num_layers, batch_first=True)
116 | self.fc = nn.Sequential(
117 | nn.Linear(hidden_size, 512),
118 | nn.ReLU(),
119 | nn.Dropout(0.3),
120 | nn.Linear(512, n_rx * n_tx)
121 | )
122 |
123 | def forward(self, x):
124 | batch_size = x.size(0)
125 | x = x.view(batch_size, self.seq_length, -1)
126 | out, _ = self.gru(x)
127 | x = self.fc(out[:, -1, :])
128 | # 重塑输出为 [batch_size, n_rx, n_tx]
129 | return x.view(batch_size, self.n_rx, self.n_tx)
130 |
131 | class HybridEstimator(nn.Module):
132 | """混合深度学习估计器(CNN+LSTM)"""
133 | def __init__(self, input_size, hidden_size=256, n_rx=4, n_tx=8):
134 | super(HybridEstimator, self).__init__()
135 |
136 | self.n_features = input_size
137 | self.seq_length = 8
138 | self.feature_size = self.n_features // self.seq_length
139 | self.n_rx = n_rx
140 | self.n_tx = n_tx
141 |
142 | # 1D CNN特征提取
143 | self.cnn = nn.Sequential(
144 | nn.Conv1d(1, 64, kernel_size=3, padding=1),
145 | nn.ReLU(),
146 | nn.BatchNorm1d(64),
147 | nn.Conv1d(64, 32, kernel_size=3, padding=1),
148 | nn.ReLU(),
149 | nn.BatchNorm1d(32)
150 | )
151 |
152 | # LSTM处理时序特征
153 | self.lstm = nn.LSTM(32 * self.feature_size, hidden_size, batch_first=True)
154 |
155 | # 全连接层
156 | self.fc = nn.Sequential(
157 | nn.Linear(hidden_size, 512),
158 | nn.ReLU(),
159 | nn.Dropout(0.3),
160 | nn.Linear(512, n_rx * n_tx)
161 | )
162 |
163 | def forward(self, x):
164 | batch_size = x.size(0)
165 |
166 | # CNN特征提取
167 | x = x.view(batch_size, 1, -1)
168 | cnn_out = self.cnn(x)
169 |
170 | # 重塑以适应LSTM
171 | lstm_in = cnn_out.view(batch_size, self.seq_length, -1)
172 |
173 | # LSTM处理
174 | lstm_out, _ = self.lstm(lstm_in)
175 |
176 | # 全连接层
177 | x = self.fc(lstm_out[:, -1, :])
178 | # 重塑输出为 [batch_size, n_rx, n_tx]
179 | return x.view(batch_size, self.n_rx, self.n_tx)
--------------------------------------------------------------------------------
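A quick shape check for the estimators above, using the sizes that `src/main.py` produces under the default configuration (n_tx = 4, n_rx = 4, n_pilot = 16, with real and imaginary parts concatenated along the last axis):

```python
import torch
from deep_learning.models import CNNEstimator, LSTMEstimator, HybridEstimator

batch, n_rx, n_tx, n_pilot = 32, 4, 4, 16
input_size = n_tx * n_pilot * 2                 # 128 real-valued features per sample

# Pilot tensor as built in main.py: real/imag concatenated -> (batch, n_tx, 2 * n_pilot)
x = torch.randn(batch, n_tx, 2 * n_pilot)

for model in (CNNEstimator(input_size),
              LSTMEstimator(input_size, hidden_size=256, num_layers=2),
              HybridEstimator(input_size, hidden_size=256)):
    out = model(x)
    # Each estimator returns (batch, n_rx, 2 * n_tx): the target H with real/imag stacked
    print(type(model).__name__, tuple(out.shape))   # -> (32, 4, 8)
```
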
/src/main.py:
--------------------------------------------------------------------------------
1 | """
2 | 信道估计方法比较研究
3 |
4 | 本模块是项目的主入口,实现了完整的信道估计实验流程,包括:
5 | 1. 数据生成与预处理
6 | 2. 传统方法评估(LS、LMMSE、ML)
7 | 3. 深度学习方法评估(CNN、RNN、LSTM、GRU、Hybrid)
8 | 4. 结果可视化与保存
9 |
10 | 主要功能:
11 | - 配置管理:支持通过YAML文件配置实验参数
12 | - 数据处理:生成、预处理、增强和划分数据集
13 | - 模型训练:支持多种深度学习模型的训练和评估
14 | - 结果分析:生成多种可视化图表进行性能对比
15 | - 日志记录:详细记录实验过程和结果
16 |
17 | 作者: lzm lzmpt@qq.com
18 | 日期: 2025-03-07
19 | """
20 |
21 | import numpy as np
22 | import torch
23 | from torch.utils.data import TensorDataset, DataLoader
24 | import matplotlib.pyplot as plt
25 | from datetime import datetime
26 | import os
27 | from pathlib import Path
28 | import time
29 | import argparse
30 | import random
31 | from typing import Dict
32 | import yaml
33 | import json
34 | from typing import Any
35 | import warnings
36 | import sys
37 | import logging
38 | from contextlib import redirect_stdout
39 |
40 | # 过滤掉特定的警告
41 | warnings.filterwarnings('ignore', category=UserWarning, module='torch.optim.lr_scheduler')
42 |
43 | from traditional.estimators import LSEstimator, LMMSEEstimator, MLEstimator, calculate_performance_metrics
44 | from deep_learning.models import CNNEstimator, RNNEstimator, LSTMEstimator, GRUEstimator, HybridEstimator
45 | from utils.data_generator import ChannelDataGenerator
46 | from utils.preprocessing import ChannelPreprocessor, augment_data
47 | from utils.trainer import ChannelEstimatorTrainer
48 | from utils.config import Config
49 |
50 | class TeeLogger:
51 | """
52 | 同时将输出写入到控制台和文件的日志记录器
53 |
54 | 该类实现了一个双向输出流,可以:
55 | 1. 将所有输出同时发送到终端和日志文件
56 | 2. 实时刷新输出,确保日志及时记录
57 | 3. 支持标准输出流的所有基本操作
58 |
59 | 参数:
60 | filename (str): 日志文件的路径
61 | """
62 | def __init__(self, filename):
63 | self.terminal = sys.stdout
64 | self.log_file = open(filename, 'w', encoding='utf-8')
65 |
66 | def write(self, message):
67 | """写入消息到终端和文件"""
68 | self.terminal.write(message)
69 | self.log_file.write(message)
70 | self.log_file.flush()
71 |
72 | def flush(self):
73 | """刷新输出缓冲区"""
74 | self.terminal.flush()
75 | self.log_file.flush()
76 |
77 | def setup_logging(save_dir: Path, timestamp: str):
78 | """
79 | 设置日志记录系统
80 |
81 | 创建日志目录并初始化日志记录器,支持:
82 | 1. 创建时间戳命名的日志文件
83 | 2. 同时输出到控制台和文件
84 | 3. 自动创建所需目录结构
85 |
86 | 参数:
87 | save_dir (Path): 保存目录的路径
88 | timestamp (str): 时间戳字符串
89 |
90 | 返回:
91 | TeeLogger: 配置好的日志记录器
92 | """
93 | log_dir = save_dir / 'logs'
94 | log_dir.mkdir(parents=True, exist_ok=True)
95 | log_file = log_dir / f'experiment_{timestamp}.log'
96 | return TeeLogger(str(log_file))
97 |
98 | def set_seed(seed):
99 | """
100 | 设置随机种子以确保实验可重复性
101 |
102 | 统一设置所有相关库的随机种子,包括:
103 | 1. Python random模块
104 | 2. NumPy
105 | 3. PyTorch CPU
106 | 4. PyTorch GPU(如果可用)
107 | 5. CUDA后端(如果可用)
108 |
109 | 参数:
110 | seed (int): 随机种子值
111 | """
112 | random.seed(seed)
113 | np.random.seed(seed)
114 | torch.manual_seed(seed)
115 | if torch.cuda.is_available():
116 | torch.cuda.manual_seed(seed)
117 | torch.cuda.manual_seed_all(seed)
118 | torch.backends.cudnn.deterministic = True
119 | torch.backends.cudnn.benchmark = False
120 |
121 | def print_section_header(title):
122 | """
123 | 打印带有时间戳的分节标题
124 |
125 | 创建醒目的分节标题,包括:
126 | 1. 分隔线
127 | 2. 当前时间戳
128 | 3. 节标题
129 |
130 | 参数:
131 | title (str): 节标题文本
132 | """
133 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
134 | print(f"\n{'='*80}\n{timestamp} - {title}\n{'='*80}")
135 |
136 | def print_progress(message, indent=0):
137 | """
138 | 打印带有时间戳的进度信息
139 |
140 | 格式化输出进度信息,包括:
141 | 1. 时间戳
142 | 2. 缩进级别
143 | 3. 具体消息
144 |
145 | 参数:
146 | message (str): 要打印的消息
147 | indent (int): 缩进级别(每级2个空格)
148 | """
149 | timestamp = datetime.now().strftime("%H:%M:%S")
150 | indent_str = " " * indent
151 | print(f"[{timestamp}] {indent_str}{message}")
152 |
153 | def evaluate_traditional_methods(H: np.ndarray, X: np.ndarray, Y: np.ndarray, snr: float, cfg: dict) -> Dict[str, Dict[str, float]]:
154 | """
155 | 评估传统信道估计方法的性能
156 |
157 | 实现了三种传统估计方法的评估:
158 | 1. LS(最小二乘)估计
159 | 2. LMMSE(线性最小均方误差)估计
160 | 3. ML(最大似然)估计
161 |
162 | 对每种方法:
163 | - 初始化估计器
164 | - 执行信道估计
165 | - 计算性能指标
166 | - 记录执行时间
167 | - 输出详细结果
168 |
169 | 参数:
170 | H (np.ndarray): 真实信道矩阵,shape=(n_samples, n_rx, n_tx)
171 | X (np.ndarray): 导频符号,shape=(n_samples, n_tx, n_pilot)
172 | Y (np.ndarray): 接收信号,shape=(n_samples, n_rx, n_pilot)
173 | snr (float): 信噪比(线性,非dB)
174 | cfg (dict): 配置字典,包含各估计器的参数
175 |
176 | 返回:
177 | Dict[str, Dict[str, float]]: 包含各估计器性能指标的嵌套字典
178 | """
179 | results = {}
180 |
181 | # LS估计
182 | print_progress("开始 LS 估计...", 1)
183 | start_time = time.time()
184 | ls_estimator = LSEstimator(
185 | regularization=cfg['traditional']['ls']['regularization'],
186 | lambda_=cfg['traditional']['ls']['lambda']
187 | )
188 | H_est_ls = ls_estimator.estimate(Y, X)
189 | results['ls'] = calculate_performance_metrics(H, H_est_ls)
190 | print_progress(f"LS 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1)
191 | print_progress("LS 性能指标:", 2)
192 | for metric, value in results['ls'].items():
193 | print_progress(f"{metric}: {value:.6f}", 3)
194 |
195 | # LMMSE估计
196 | print_progress("\n开始 LMMSE 估计...", 1)
197 | start_time = time.time()
198 | lmmse_estimator = LMMSEEstimator(
199 | snr=snr,
200 | adaptive_snr=cfg['traditional']['lmmse']['adaptive_snr'],
201 | correlation_method=cfg['traditional']['lmmse']['correlation_method']
202 | )
203 | H_est_lmmse = lmmse_estimator.estimate(Y, X)
204 | results['lmmse'] = calculate_performance_metrics(H, H_est_lmmse)
205 | print_progress(f"LMMSE 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1)
206 | print_progress("LMMSE 性能指标:", 2)
207 | for metric, value in results['lmmse'].items():
208 | print_progress(f"{metric}: {value:.6f}", 3)
209 |
210 | # ML估计
211 | print_progress("\n开始 ML 估计...", 1)
212 | start_time = time.time()
213 | ml_estimator = MLEstimator(
214 | max_iter=cfg['traditional']['ml']['max_iter'],
215 | tol=cfg['traditional']['ml']['tol'],
216 | learning_rate=cfg['traditional']['ml']['learning_rate']
217 | )
218 | H_est_ml = ml_estimator.estimate(Y, X)
219 | results['ml'] = calculate_performance_metrics(H, H_est_ml)
220 | print_progress(f"ML 估计完成 (耗时: {time.time()-start_time:.2f}s)", 1)
221 | print_progress("ML 性能指标:", 2)
222 | for metric, value in results['ml'].items():
223 | print_progress(f"{metric}: {value:.6f}", 3)
224 |
225 | return results
226 |
227 | def evaluate_dl_methods(train_loader, val_loader, test_loader, input_size, cfg):
228 | """
229 | 评估深度学习方法的性能
230 |
231 | 实现了五种深度学习模型的训练和评估:
232 | 1. CNN(卷积神经网络)
233 | 2. RNN(循环神经网络)
234 | 3. LSTM(长短期记忆网络)
235 | 4. GRU(门控循环单元)
236 | 5. Hybrid(混合CNN-LSTM模型)
237 |
238 | 对每个模型:
239 | - 初始化模型结构
240 | - 配置训练器
241 | - 执行训练过程
242 | - 加载最佳模型
243 | - 在测试集上评估
244 | - 记录性能指标
245 |
246 | 参数:
247 | train_loader (DataLoader): 训练数据加载器
248 | val_loader (DataLoader): 验证数据加载器
249 | test_loader (DataLoader): 测试数据加载器
250 | input_size (int): 输入特征维度
251 | cfg (dict): 配置字典,包含模型和训练参数
252 |
253 | 返回:
254 | dict: 包含各模型训练结果和性能指标的字典
255 | """
256 | device = torch.device('cuda' if cfg['experiment']['use_cuda'] and torch.cuda.is_available() else 'cpu')
257 | results = {}
258 |
259 | # 获取通用模型设置
260 | common_cfg = cfg['models']['common']
261 |
262 | # 定义要评估的模型
263 | models = {
264 | 'CNN': CNNEstimator(input_size=input_size),
265 | 'RNN': RNNEstimator(
266 | input_size=input_size,
267 | hidden_size=cfg['models']['rnn']['hidden_size'],
268 | num_layers=cfg['models']['rnn']['num_layers']
269 | ),
270 | 'LSTM': LSTMEstimator(
271 | input_size=input_size,
272 | hidden_size=cfg['models']['lstm']['hidden_size'],
273 | num_layers=cfg['models']['lstm']['num_layers']
274 | ),
275 | 'GRU': GRUEstimator(
276 | input_size=input_size,
277 | hidden_size=cfg['models']['gru']['hidden_size'],
278 | num_layers=cfg['models']['gru']['num_layers']
279 | ),
280 | 'Hybrid': HybridEstimator(
281 | input_size=input_size,
282 | hidden_size=cfg['models']['hybrid']['rnn_hidden_size']
283 | )
284 | }
285 |
286 | print_section_header("评估深度学习方法")
287 |
288 | # 评估每个模型
289 | for name, model in models.items():
290 | print_progress(f"\n开始训练 {name} 模型...", 1)
291 | print_progress(f"模型结构:", 2)
292 | print_progress(str(model), 3)
293 |
294 | # 创建训练器
295 | trainer = ChannelEstimatorTrainer(
296 | model=model,
297 | train_loader=train_loader,
298 | val_loader=val_loader,
299 | save_dir=cfg['experiment']['save_dir'],
300 | project_name=f"{cfg['experiment']['name']}_{name.lower()}",
301 | device=device,
302 | learning_rate=common_cfg['learning_rate']
303 | )
304 |
305 | # 训练模型
306 | start_time = time.time()
307 | trainer.train(
308 | epochs=cfg['training']['epochs'],
309 | early_stopping_patience=cfg['training']['early_stopping']['patience']
310 | )
311 | training_time = time.time() - start_time
312 | print_progress(f"{name} 模型训练完成 (耗时: {training_time:.2f}s)", 1)
313 |
314 | # 加载最佳模型进行测试
315 | trainer.load_model('best.pt')
316 | model = trainer.model
317 |
318 | # 测试评估
319 | print_progress(f"开始评估 {name} 模型...", 1)
320 | model.eval()
321 | all_preds = []
322 | all_true = []
323 |
324 | test_start_time = time.time()
325 | with torch.no_grad():
326 | for batch_X, batch_Y in test_loader:
327 | batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
328 | outputs = model(batch_X)
329 | all_preds.append(outputs.cpu().numpy())
330 | all_true.append(batch_Y.cpu().numpy())
331 |
332 | all_preds = np.concatenate(all_preds, axis=0)
333 | all_true = np.concatenate(all_true, axis=0)
334 |
335 | metrics = calculate_performance_metrics(all_true, all_preds)
336 | results[name] = {
337 | 'metrics': metrics,
338 | 'model': model,
339 | 'trainer': trainer
340 | }
341 |
342 | print_progress(f"{name} 模型评估完成 (耗时: {time.time()-test_start_time:.2f}s)", 1)
343 | print_progress(f"{name} 模型性能指标:", 2)
344 | for metric_name, value in metrics.items():
345 | print_progress(f"{metric_name}: {value:.6f}", 3)
346 |
347 | return results
348 |
349 | def plot_results(traditional_results, dl_results, cfg):
350 | """
351 | 绘制实验结果比较图表
352 |
353 | 生成四种类型的可视化图表:
354 | 1. 整体性能对比图:所有方法的MSE柱状图
355 | 2. 训练历史曲线:每个深度学习模型的训练过程
356 | 3. 性能指标雷达图:传统方法和深度学习方法的多指标对比
357 | 4. 误差分布箱线图:所有方法的MSE分布对比
358 |
359 | 参数:
360 | traditional_results (dict): 传统方法的评估结果
361 | dl_results (dict): 深度学习方法的评估结果
362 | cfg (dict): 配置字典,包含绘图相关参数
363 | """
364 | print_section_header("绘制结果比较图")
365 |
366 | # 创建保存目录
367 | save_dir = Path(cfg['experiment']['plot_dir'])
368 | save_dir.mkdir(parents=True, exist_ok=True)
369 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
370 |
371 | print_progress("生成性能对比图...", 1)
372 |
373 | # 1. 整体MSE对比图
374 | plt.figure(figsize=(12, 6))
375 |
376 | # 绘制传统方法结果
377 | x_traditional = np.arange(len(traditional_results))
378 | mse_traditional = [results['mse'] for results in traditional_results.values()]
379 | plt.bar(x_traditional, mse_traditional, alpha=0.8, label='Traditional Methods')
380 | plt.xticks(x_traditional, traditional_results.keys())
381 |
382 | # 绘制深度学习方法结果
383 | x_dl = np.arange(len(dl_results)) + len(traditional_results)
384 | mse_dl = [results['metrics']['mse'] for results in dl_results.values()]
385 | plt.bar(x_dl, mse_dl, alpha=0.8, label='Deep Learning Methods')
386 | plt.xticks(np.concatenate([x_traditional, x_dl]),
387 | list(traditional_results.keys()) + list(dl_results.keys()),
388 | rotation=45)
389 |
390 | plt.ylabel('MSE')
391 | plt.title('所有方法性能对比')
392 | plt.legend()
393 | plt.tight_layout()
394 |
395 | save_path = save_dir / f'overall_comparison_{timestamp}.png'
396 | plt.savefig(save_path)
397 | plt.close()
398 |
399 | print_progress(f"整体对比图已保存至: {save_path}", 1)
400 |
401 | # 2. 每个深度学习模型的训练历史
402 | print_progress("生成训练历史图...", 1)
403 | for name, result in dl_results.items():
404 | trainer = result['trainer']
405 | history = trainer.get_history()
406 |
407 | plt.figure(figsize=(15, 5))
408 |
409 | # 训练损失和验证损失
410 | plt.subplot(1, 2, 1)
411 | plt.plot(history['train_loss'], label='训练损失')
412 | plt.plot(history['val_loss'], label='验证损失')
413 | plt.xlabel('Epoch')
414 | plt.ylabel('Loss')
415 | plt.title(f'{name} 模型训练历史')
416 | plt.legend()
417 | plt.grid(True)
418 |
419 | # 学习率变化
420 | plt.subplot(1, 2, 2)
421 | plt.plot(history['learning_rate'], label='学习率')
422 | plt.xlabel('Epoch')
423 | plt.ylabel('Learning Rate')
424 | plt.title('学习率变化')
425 | plt.legend()
426 | plt.grid(True)
427 |
428 | plt.tight_layout()
429 | save_path = save_dir / f'{name.lower()}_training_history_{timestamp}.png'
430 | plt.savefig(save_path)
431 | plt.close()
432 |
433 | print_progress(f"{name} 模型训练历史图已保存至: {save_path}", 2)
434 |
435 | # 3. 性能指标雷达图
436 | print_progress("生成性能指标雷达图...", 1)
437 | metrics = ['mse', 'nmse', 'ber'] # 根据实际指标调整
438 |
439 | # 传统方法雷达图
440 | plt.figure(figsize=(10, 10))
441 | angles = np.linspace(0, 2*np.pi, len(metrics), endpoint=False)
442 | angles = np.concatenate((angles, [angles[0]])) # 闭合图形
443 |
444 | ax = plt.subplot(111, polar=True)
445 | for method, results in traditional_results.items():
446 | values = [results[metric] for metric in metrics]
447 | values = np.concatenate((values, [values[0]])) # 闭合图形
448 | ax.plot(angles, values, 'o-', linewidth=2, label=method)
449 | ax.fill(angles, values, alpha=0.25)
450 |
451 | ax.set_xticks(angles[:-1])
452 | ax.set_xticklabels(metrics)
453 | plt.title('传统方法性能指标对比')
454 | plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
455 |
456 | save_path = save_dir / f'traditional_metrics_radar_{timestamp}.png'
457 | plt.savefig(save_path)
458 | plt.close()
459 |
460 | print_progress(f"传统方法性能指标雷达图已保存至: {save_path}", 2)
461 |
462 | # 深度学习方法雷达图
463 | plt.figure(figsize=(10, 10))
464 | ax = plt.subplot(111, polar=True)
465 | for name, result in dl_results.items():
466 | values = [result['metrics'][metric] for metric in metrics]
467 | values = np.concatenate((values, [values[0]])) # 闭合图形
468 | ax.plot(angles, values, 'o-', linewidth=2, label=name)
469 | ax.fill(angles, values, alpha=0.25)
470 |
471 | ax.set_xticks(angles[:-1])
472 | ax.set_xticklabels(metrics)
473 | plt.title('深度学习方法性能指标对比')
474 | plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
475 |
476 | save_path = save_dir / f'dl_metrics_radar_{timestamp}.png'
477 | plt.savefig(save_path)
478 | plt.close()
479 |
480 | print_progress(f"深度学习方法性能指标雷达图已保存至: {save_path}", 2)
481 |
482 | # 4. 箱线图比较
483 | print_progress("生成性能指标箱线图...", 1)
484 | plt.figure(figsize=(15, 6))
485 |
486 | # 合并所有方法的结果
487 | all_methods = {}
488 | all_methods.update({k: {'mse': [v['mse']]} for k, v in traditional_results.items()})
489 | all_methods.update({k: {'mse': [v['metrics']['mse']]} for k, v in dl_results.items()})
490 |
491 | # 创建箱线图数据
492 | labels = []
493 | mse_data = []
494 | for method, data in all_methods.items():
495 | labels.append(method)
496 | mse_data.append(data['mse'])
497 |
498 | plt.boxplot(mse_data, labels=labels)
499 | plt.xticks(rotation=45)
500 | plt.ylabel('MSE')
501 | plt.title('所有方法MSE分布对比')
502 | plt.grid(True)
503 |
504 | plt.tight_layout()
505 | save_path = save_dir / f'mse_boxplot_{timestamp}.png'
506 | plt.savefig(save_path)
507 | plt.close()
508 |
509 | print_progress(f"性能指标箱线图已保存至: {save_path}", 2)
510 |
511 | print_progress("所有结果图表已生成完成", 1)
512 |
513 | def parse_args():
514 | """
515 | 解析命令行参数
516 |
517 | 支持的参数:
518 | 1. --config: 配置文件路径
519 | 2. --device: 运行设备(cpu/cuda)
520 | 3. --seed: 随机种子
521 |
522 | 返回:
523 | argparse.Namespace: 解析后的参数对象
524 | """
525 | parser = argparse.ArgumentParser(description='信道估计方法比较研究')
526 | parser.add_argument('--config', type=str, default='src/config/default.yml',
527 | help='配置文件路径')
528 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'],
529 | help='运行设备')
530 | parser.add_argument('--seed', type=int, help='随机种子')
531 | return parser.parse_args()
532 |
533 | def load_config(config_path: str) -> dict:
534 | """
535 | 加载配置文件
536 |
537 | 功能:
538 | 1. 读取YAML配置文件
539 | 2. 验证必要配置项
540 | 3. 检查配置完整性
541 |
542 | 参数:
543 | config_path (str): 配置文件路径
544 |
545 | 返回:
546 | dict: 配置字典
547 |
548 | 异常:
549 | ValueError: 当缺少必要的配置项时抛出
550 | """
551 | with open(config_path, 'r', encoding='utf-8') as f:
552 | config = yaml.safe_load(f)
553 |
554 | # 验证必要的配置项
555 | required_configs = [
556 | 'experiment.name',
557 | 'channel.n_tx',
558 | 'channel.n_rx',
559 | 'channel.n_pilot',
560 | 'channel.snr_db',
561 | 'channel.n_samples',
562 | 'channel.type',
563 | 'traditional.ls.regularization',
564 | 'traditional.ls.lambda',
565 | 'traditional.lmmse.adaptive_snr',
566 | 'traditional.lmmse.correlation_method',
567 | 'traditional.ml.max_iter',
568 | 'traditional.ml.tol'
569 | ]
570 |
571 |     for required_key in required_configs:
572 |         value = get_config_value(config, required_key)
573 |         if value is None:
574 |             raise ValueError(f"配置文件缺少必要项: {required_key}")
575 |
576 | return config
577 |
578 | def get_config_value(config: dict, path: str) -> Any:
579 | """
580 | 从嵌套字典中获取值
581 |
582 | 使用点号分隔的路径从嵌套字典中获取值
583 | 例如:'models.cnn.channels' -> config['models']['cnn']['channels']
584 |
585 | 参数:
586 | config (dict): 配置字典
587 | path (str): 以点分隔的配置路径
588 |
589 | 返回:
590 | Any: 配置值,如果路径不存在则返回None
591 | """
592 | keys = path.split('.')
593 | value = config
594 | for key in keys:
595 | if not isinstance(value, dict) or key not in value:
596 | return None
597 | value = value[key]
598 | return value
599 |
600 | def main():
601 |     """主函数"""
602 |     # 解析命令行参数(支持 --config / --device / --seed,见 parse_args)
603 |     args = parse_args()
604 |
605 |     # 加载配置
606 |     cfg = load_config(args.config)
607 |     # 命令行 --device 优先于配置文件中的 use_cuda 设置
608 |     if args.device is not None:
609 |         cfg['experiment']['use_cuda'] = (args.device == 'cuda')
610 |
611 | # 创建保存目录
612 | save_dir = Path(cfg['experiment']['save_dir'])
613 | plot_dir = Path(cfg['experiment']['plot_dir'])
614 | save_dir.mkdir(parents=True, exist_ok=True)
615 | plot_dir.mkdir(parents=True, exist_ok=True)
616 |
617 | # 设置时间戳
618 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
619 |
620 | # 设置日志记录
621 | logger = setup_logging(save_dir, timestamp)
622 | sys.stdout = logger
623 |
624 |     # 设置随机种子(命令行 --seed 优先于配置文件中的 seed)
625 |     seed = args.seed if args.seed is not None else cfg['experiment']['seed']
626 |     cfg['experiment']['seed'] = seed
627 |
628 |     # set_seed 会统一设置 random、numpy、torch(含 CUDA)的随机种子,
629 |     # 并固定 cudnn 的确定性行为,保证实验可复现
630 |     set_seed(seed)
631 |
632 | try:
633 | print_section_header("开始信道估计实验")
634 |
635 | # 打印实验参数
636 | print_progress("实验参数配置:", 1)
637 | print_progress(f"发射天线数: {cfg['channel']['n_tx']}", 2)
638 | print_progress(f"接收天线数: {cfg['channel']['n_rx']}", 2)
639 | print_progress(f"导频长度: {cfg['channel']['n_pilot']}", 2)
640 | print_progress(f"信噪比: {cfg['channel']['snr_db']} dB", 2)
641 | print_progress(f"样本数量: {cfg['channel']['n_samples']}", 2)
642 | print_progress(f"批次大小: {cfg['data']['loader']['batch_size']}", 2)
643 | print_progress(f"运行设备: {'cuda' if torch.cuda.is_available() and cfg['experiment']['use_cuda'] else 'cpu'}", 2)
644 |
645 | # 生成数据
646 | print_progress("\n初始化数据生成器...", 1)
647 | start_time = time.time()
648 | data_generator = ChannelDataGenerator(
649 | n_tx=cfg['channel']['n_tx'],
650 | n_rx=cfg['channel']['n_rx'],
651 | n_pilot=cfg['channel']['n_pilot'],
652 | channel_type=cfg['channel']['type'],
653 | rician_k=cfg['channel'].get('rician_k', 1.0)
654 | )
655 |
656 | print_progress("开始生成数据...", 1)
657 | H, X, Y = data_generator.generate(
658 | n_samples=cfg['channel']['n_samples'],
659 | snr_db=cfg['channel']['snr_db']
660 | )
661 |
662 | print_progress(f"信道矩阵 H 形状: {H.shape}", 2)
663 | print_progress(f"导频符号 X 形状: {X.shape}", 2)
664 | print_progress(f"接收信号 Y 形状: {Y.shape}", 2)
665 | print_progress(f"数据生成完成 (耗时: {time.time()-start_time:.2f}s)", 1)
666 |
667 | # 数据预处理
668 | print_progress("\n开始数据预处理...", 1)
669 | start_time = time.time()
670 | preprocessor = ChannelPreprocessor(
671 | normalization=cfg['data']['preprocessing']['normalization'],
672 | remove_outliers=cfg['data']['preprocessing']['remove_outliers'],
673 | outlier_threshold=cfg['data']['preprocessing']['outlier_threshold']
674 | )
675 |
676 | H = preprocessor.fit_transform(H)
677 | Y = preprocessor.transform(Y)
678 |
679 | print_progress(f"数据预处理完成 (耗时: {time.time()-start_time:.2f}s)", 1)
680 |
681 | # 数据增强
682 | if cfg['data']['augmentation']['enabled']:
683 | print_progress("\n开始数据增强...", 1)
684 | start_time = time.time()
685 |
686 | H = augment_data(
687 | H,
688 | methods=cfg['data']['augmentation']['methods'],
689 | noise_std=cfg['data']['augmentation']['noise_std'],
690 | phase_shift_range=cfg['data']['augmentation']['phase_shift_range'],
691 | magnitude_scale_range=cfg['data']['augmentation']['magnitude_scale_range']
692 | )
693 |
694 | print_progress("增强后数据形状:", 2)
695 | print_progress(f"H: {H.shape}", 3)
696 | print_progress(f"X: {X.shape}", 3)
697 | print_progress(f"Y: {Y.shape}", 3)
698 | print_progress(f"数据增强完成 (耗时: {time.time()-start_time:.2f}s)", 1)
699 |
700 | # 评估传统方法
701 | print_section_header("评估传统估计方法")
702 | traditional_results = evaluate_traditional_methods(
703 | H=H,
704 | X=X,
705 | Y=Y,
706 | snr=10**(cfg['channel']['snr_db']/10),
707 | cfg=cfg
708 | )
709 |
710 | # 数据集划分
711 | print_section_header("准备深度学习数据集")
712 | n_samples = H.shape[0]
713 | train_ratio = cfg['data']['split']['train_ratio']
714 | val_ratio = cfg['data']['split']['val_ratio']
715 |
716 | # 计算样本数量
717 | n_train = int(n_samples * train_ratio)
718 | n_val = int(n_samples * val_ratio)
719 | n_test = n_samples - n_train - n_val
720 |
721 | print_progress("数据集划分:", 1)
722 | print_progress(f"训练集: {n_train} 样本", 2)
723 | print_progress(f"验证集: {n_val} 样本", 2)
724 | print_progress(f"测试集: {n_test} 样本", 2)
725 |
726 | # 转换为PyTorch张量
727 | # 将复数数据分离为实部和虚部
728 | X_real = torch.from_numpy(X.real).float()
729 | X_imag = torch.from_numpy(X.imag).float()
730 | H_real = torch.from_numpy(H.real).float()
731 | H_imag = torch.from_numpy(H.imag).float()
732 |
733 | # 在最后一维拼接实部和虚部
734 | X_tensor = torch.cat([X_real, X_imag], dim=-1)
735 | H_tensor = torch.cat([H_real, H_imag], dim=-1)
736 |
737 | # 创建数据集
738 | dataset = TensorDataset(X_tensor, H_tensor)
739 |
740 | # 划分数据集
741 | train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
742 | dataset, [n_train, n_val, n_test]
743 | )
744 |
745 | # 创建数据加载器
746 | train_loader = DataLoader(
747 | train_dataset,
748 | batch_size=cfg['data']['loader']['batch_size'],
749 | shuffle=True,
750 | num_workers=cfg['data']['loader']['num_workers'],
751 | pin_memory=cfg['data']['loader']['pin_memory']
752 | )
753 |
754 | val_loader = DataLoader(
755 | val_dataset,
756 | batch_size=cfg['data']['loader']['batch_size'],
757 | shuffle=False,
758 | num_workers=cfg['data']['loader']['num_workers'],
759 | pin_memory=cfg['data']['loader']['pin_memory']
760 | )
761 |
762 | test_loader = DataLoader(
763 | test_dataset,
764 | batch_size=cfg['data']['loader']['batch_size'],
765 | shuffle=False,
766 | num_workers=cfg['data']['loader']['num_workers'],
767 | pin_memory=cfg['data']['loader']['pin_memory']
768 | )
769 |
770 | # 计算输入大小(考虑实部和虚部)
771 | input_size = X.shape[1] * X.shape[2] * 2 # n_tx * n_pilot * 2 (实部和虚部)
772 |
773 | # 训练和评估深度学习模型
774 | dl_results = evaluate_dl_methods(
775 | train_loader=train_loader,
776 | val_loader=val_loader,
777 | test_loader=test_loader,
778 | input_size=input_size,
779 | cfg=cfg
780 | )
781 |
782 | # 保存结果
783 | results = {
784 | 'config': cfg,
785 | 'traditional_results': traditional_results,
786 | 'dl_results': {
787 | name: {
788 | 'metrics': result['metrics']
789 | } for name, result in dl_results.items()
790 | }
791 | }
792 |
793 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
794 | result_file = save_dir / f"results_{timestamp}.json"
795 | with open(result_file, 'w', encoding='utf-8') as f:
796 | json.dump(results, f, indent=4)
797 |
798 | # 绘制结果
799 | plot_results(traditional_results, dl_results, cfg)
800 |
801 | print("\n实验完成!结果已保存至:", result_file)
802 |
803 | except Exception as e:
804 | print(f"\n实验出错!错误信息:{str(e)}")
805 | raise e
806 |
807 | finally:
808 | # 恢复标准输出
809 | sys.stdout = sys.__stdout__
810 | # 关闭日志文件
811 | logger.log_file.close()
812 |
813 | if __name__ == "__main__":
814 | main()
--------------------------------------------------------------------------------
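The tensor packing in `main()` above is the one step that is easy to miss: the complex-valued arrays are split into real and imaginary parts and concatenated along the last axis before being handed to the models. A standalone sketch of just that step (the shapes follow the default configuration; the random arrays stand in for the output of `ChannelDataGenerator.generate()`):

```python
import numpy as np
import torch

n_samples, n_rx, n_tx, n_pilot = 10000, 4, 4, 16

# Stand-ins for the complex-valued H and X returned by the data generator
H = (np.random.randn(n_samples, n_rx, n_tx) + 1j * np.random.randn(n_samples, n_rx, n_tx)) / np.sqrt(2)
X = (np.random.randn(n_samples, n_tx, n_pilot) + 1j * np.random.randn(n_samples, n_tx, n_pilot)) / np.sqrt(2)

# Real and imaginary parts concatenated on the last axis, exactly as in main()
X_tensor = torch.cat([torch.from_numpy(X.real).float(),
                      torch.from_numpy(X.imag).float()], dim=-1)   # (10000, 4, 32)
H_tensor = torch.cat([torch.from_numpy(H.real).float(),
                      torch.from_numpy(H.imag).float()], dim=-1)   # (10000, 4, 8)

input_size = X.shape[1] * X.shape[2] * 2   # 4 * 16 * 2 = 128, passed to every model
print(X_tensor.shape, H_tensor.shape, input_size)
```
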
/src/traditional/__pycache__/estimators.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/traditional/__pycache__/estimators.cpython-312.pyc
--------------------------------------------------------------------------------
/src/traditional/__pycache__/estimators.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/traditional/__pycache__/estimators.cpython-39.pyc
--------------------------------------------------------------------------------
/src/traditional/estimators.py:
--------------------------------------------------------------------------------
1 | """
2 | 传统信道估计方法模块
3 | 包含LS、LMMSE和ML估计器
4 | """
5 |
6 | import numpy as np
7 | from scipy.special import erf
8 | from typing import Dict, Union, Optional
9 |
10 | class LSEstimator:
11 | """最小二乘(LS)估计器"""
12 |
13 | def __init__(self, regularization: bool = False, lambda_: float = 0.01):
14 | """
15 | 初始化LS估计器
16 |
17 | 参数:
18 | regularization: 是否使用正则化
19 | lambda_: 正则化参数
20 | """
21 | self.regularization = regularization
22 | self.lambda_ = lambda_
23 |
24 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray:
25 | """
26 | 执行LS估计
27 |
28 | 参数:
29 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot)
30 | X: 导频符号,shape=(n_samples, n_tx, n_pilot)
31 |
32 | 返回:
33 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx)
34 | """
35 | # 转置X以便进行批量矩阵运算
36 | X_H = np.conjugate(np.transpose(X, (0, 2, 1))) # (n_samples, n_pilot, n_tx)
37 |
38 | # 计算XX^H
39 | XX_H = np.matmul(X, X_H) # (n_samples, n_tx, n_tx)
40 |
41 | if self.regularization:
42 | # 添加正则化项
43 | n_tx = X.shape[1]
44 | reg_term = self.lambda_ * np.eye(n_tx)[np.newaxis, :, :]
45 | XX_H = XX_H + reg_term
46 |
47 | # 使用SVD进行稳定的矩阵求逆
48 | H_est_list = []
49 | for i in range(Y.shape[0]):
50 | try:
51 | # 尝试直接求逆
52 | XX_H_inv = np.linalg.inv(XX_H[i])
53 | except np.linalg.LinAlgError:
54 | # 如果矩阵接近奇异,使用伪逆
55 | XX_H_inv = np.linalg.pinv(XX_H[i])
56 |
57 | # 计算信道估计
58 | H_est_i = np.matmul(Y[i], np.matmul(X_H[i], XX_H_inv))
59 | H_est_list.append(H_est_i)
60 |
61 | H_est = np.stack(H_est_list, axis=0)
62 | return H_est
63 |
64 | class LMMSEEstimator:
65 | """线性最小均方误差(LMMSE)估计器"""
66 |
67 | def __init__(self, snr: float, adaptive_snr: bool = False,
68 | correlation_method: str = 'sample'):
69 | """
70 | 初始化LMMSE估计器
71 |
72 | 参数:
73 | snr: 信噪比(线性,非dB)
74 | adaptive_snr: 是否使用自适应SNR估计
75 | correlation_method: 信道相关矩阵估计方法,'sample' 或 'theoretical'
76 | """
77 | self.snr = snr
78 | self.adaptive_snr = adaptive_snr
79 | self.correlation_method = correlation_method
80 |
81 | if correlation_method not in ['sample', 'theoretical']:
82 | raise ValueError("correlation_method必须是'sample'或'theoretical'")
83 |
84 | def _estimate_correlation(self, H_ls: np.ndarray) -> np.ndarray:
85 | """
86 | 估计信道相关矩阵
87 |
88 | 参数:
89 | H_ls: LS估计的信道矩阵
90 |
91 | 返回:
92 | R_H: 信道相关矩阵
93 | """
94 | if self.correlation_method == 'sample':
95 | # 使用样本相关矩阵
96 | H_ls_flat = H_ls.reshape(H_ls.shape[0], -1)
97 | R_H = np.matmul(H_ls_flat.T.conj(), H_ls_flat) / H_ls.shape[0]
98 | # 添加小的正则化项以提高数值稳定性
99 | epsilon = 1e-10
100 | R_H = R_H + epsilon * np.eye(R_H.shape[0])
101 | else:
102 | # 使用理论相关矩阵(假设指数衰减)
103 | n_rx, n_tx = H_ls.shape[1:3]
104 | R_H = np.zeros((n_rx * n_tx, n_rx * n_tx), dtype=complex)
105 | rho = 0.7 # 相关系数
106 | for i in range(n_rx * n_tx):
107 | for j in range(n_rx * n_tx):
108 | R_H[i, j] = rho ** abs(i - j)
109 |
110 | return R_H
111 |
112 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray:
113 | """
114 | 执行LMMSE估计
115 |
116 | 参数:
117 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot)
118 | X: 导频符号,shape=(n_samples, n_tx, n_pilot)
119 |
120 | 返回:
121 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx)
122 | """
123 | # 首先进行LS估计
124 | ls_estimator = LSEstimator(regularization=True, lambda_=0.01) # 使用正则化的LS估计
125 | H_ls = ls_estimator.estimate(Y, X)
126 |
127 | # 估计信道相关矩阵
128 | R_H = self._estimate_correlation(H_ls)
129 |
130 | if self.adaptive_snr:
131 | # 使用样本方差估计噪声功率
132 | error = Y - np.matmul(H_ls, X)
133 | noise_power = np.mean(np.abs(error) ** 2)
134 | signal_power = np.mean(np.abs(Y) ** 2)
135 | snr = max(signal_power / noise_power, 1.0) # 确保SNR不小于1
136 | else:
137 | snr = self.snr
138 |
139 | # LMMSE估计
140 | n_samples = Y.shape[0]
141 | H_est = np.zeros_like(H_ls)
142 |
143 | for i in range(n_samples):
144 | h_ls = H_ls[i].flatten()
145 | # 使用更稳定的求解方法
146 | try:
147 | # 使用Cholesky分解求解线性方程组
148 | A = R_H + np.eye(len(h_ls)) / snr
149 | L = np.linalg.cholesky(A)
150 | h_est = h_ls.copy()
151 | # 解线性方程组 A @ x = R_H @ h_ls
152 | y = np.linalg.solve(L, R_H @ h_ls)
153 | h_est = np.linalg.solve(L.T.conj(), y)
154 | except np.linalg.LinAlgError:
155 | # 如果Cholesky分解失败,使用伪逆
156 | h_est = np.matmul(R_H, np.linalg.pinv(R_H + np.eye(len(h_ls)) / snr)) @ h_ls
157 |
158 | H_est[i] = h_est.reshape(H_ls.shape[1:])
159 |
160 | return H_est
161 |
162 | class MLEstimator:
163 | """最大似然(ML)估计器"""
164 |
165 | def __init__(self, max_iter: int = 100, tol: float = 1e-6, learning_rate: float = 0.01):
166 | """
167 | 初始化ML估计器
168 |
169 | 参数:
170 | max_iter: 最大迭代次数
171 | tol: 收敛阈值
172 | learning_rate: 初始学习率
173 | """
174 | self.max_iter = max_iter
175 | self.tol = float(tol)
176 | self.learning_rate = learning_rate
177 | self.min_lr = 1e-6 # 最小学习率
178 | self.lr_decay = 0.95 # 学习率衰减因子
179 |
180 | def _calculate_loss(self, H_est: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float:
181 | """计算损失函数(负对数似然)"""
182 | error = Y - np.matmul(H_est, X)
183 | return np.mean(np.abs(error) ** 2)
184 |
185 | def estimate(self, Y: np.ndarray, X: np.ndarray) -> np.ndarray:
186 | """
187 | 执行ML估计
188 |
189 | 参数:
190 | Y: 接收信号,shape=(n_samples, n_rx, n_pilot)
191 | X: 导频符号,shape=(n_samples, n_tx, n_pilot)
192 |
193 | 返回:
194 | H_est: 估计的信道矩阵,shape=(n_samples, n_rx, n_tx)
195 | """
196 | # 使用正则化的LS估计作为初始值
197 | ls_estimator = LSEstimator(regularization=True, lambda_=0.01)
198 | H_est = ls_estimator.estimate(Y, X)
199 |
200 | # 迭代优化
201 | X_H = np.conjugate(np.transpose(X, (0, 2, 1))) # (n_samples, n_pilot, n_tx)
202 | lr = self.learning_rate
203 | best_H = H_est.copy()
204 | best_loss = self._calculate_loss(H_est, X, Y)
205 |
206 | for iter_idx in range(self.max_iter):
207 | H_prev = H_est.copy()
208 | prev_loss = self._calculate_loss(H_prev, X, Y)
209 |
210 | # 计算误差
211 | error = Y - np.matmul(H_est, X) # (n_samples, n_rx, n_pilot)
212 |
213 | # 计算梯度
214 | gradient = -np.matmul(error, X_H) # 负梯度方向
215 |
216 | # 梯度缩放
217 | grad_norm = np.maximum(np.sqrt(np.mean(np.abs(gradient) ** 2)), 1e-8)
218 | gradient = gradient / grad_norm
219 |
220 | # 更新估计
221 | H_est = H_est - lr * gradient
222 |
223 | # 计算当前损失
224 | current_loss = self._calculate_loss(H_est, X, Y)
225 |
226 | # 如果损失增加,回退并降低学习率
227 | if current_loss > prev_loss:
228 | H_est = H_prev
229 | lr = max(lr * self.lr_decay, self.min_lr)
230 | continue
231 |
232 | # 更新最佳结果
233 | if current_loss < best_loss:
234 | best_H = H_est.copy()
235 | best_loss = current_loss
236 |
237 | # 检查收敛
238 | diff = float(np.mean(np.abs(H_est - H_prev) ** 2))
239 | if diff < self.tol:
240 | break
241 |
242 | # 定期降低学习率
243 | if iter_idx > 0 and iter_idx % 10 == 0:
244 | lr = max(lr * self.lr_decay, self.min_lr)
245 |
246 | return best_H
247 |
248 | def calculate_performance_metrics(H_true: np.ndarray, H_est: np.ndarray) -> Dict[str, float]:
249 | """
250 | 计算性能指标
251 |
252 | 参数:
253 | H_true: 真实信道矩阵
254 | H_est: 估计的信道矩阵
255 |
256 | 返回:
257 | 包含各种性能指标的字典
258 | """
259 | # 计算MSE
260 | mse = np.mean(np.abs(H_true - H_est) ** 2)
261 |
262 | # 计算NMSE
263 | nmse = mse / np.mean(np.abs(H_true) ** 2)
264 |
265 | # 计算BER(假设QPSK调制)
266 | def qpsk_ber(h_true, h_est):
267 | # 添加小的常数以避免除零
268 | epsilon = 1e-10
269 | snr_eff = np.abs(h_true) ** 2 / (np.abs(h_true - h_est) ** 2 + epsilon)
270 | return 0.5 * np.mean(1 - erf(np.sqrt(snr_eff/2)))
271 |
272 | ber = qpsk_ber(H_true, H_est)
273 |
274 | return {
275 | 'mse': float(mse),
276 | 'nmse': float(nmse),
277 | 'ber': float(ber)
278 | }
--------------------------------------------------------------------------------
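A self-contained sketch of the estimators above on synthetic data (run with `src/` on the import path; a Rayleigh channel, random pilots, and AWGN at roughly the configured SNR are generated inline, so the project's data generator is not needed):

```python
import numpy as np
from traditional.estimators import LSEstimator, LMMSEEstimator, calculate_performance_metrics

rng = np.random.default_rng(42)
n_samples, n_rx, n_tx, n_pilot, snr_db = 200, 4, 4, 16, 10

# Rayleigh channel, unit-power random pilots, additive white Gaussian noise
H = (rng.standard_normal((n_samples, n_rx, n_tx)) + 1j * rng.standard_normal((n_samples, n_rx, n_tx))) / np.sqrt(2)
X = (rng.standard_normal((n_samples, n_tx, n_pilot)) + 1j * rng.standard_normal((n_samples, n_tx, n_pilot))) / np.sqrt(2)
noise_std = np.sqrt(n_tx / (2 * 10 ** (snr_db / 10)))
N = noise_std * (rng.standard_normal((n_samples, n_rx, n_pilot)) + 1j * rng.standard_normal((n_samples, n_rx, n_pilot)))
Y = np.matmul(H, X) + N

H_ls = LSEstimator(regularization=True, lambda_=0.01).estimate(Y, X)
H_lmmse = LMMSEEstimator(snr=10 ** (snr_db / 10), adaptive_snr=True).estimate(Y, X)

print("LS   :", calculate_performance_metrics(H, H_ls))
print("LMMSE:", calculate_performance_metrics(H, H_lmmse))
```

`MLEstimator` is used the same way (`MLEstimator(max_iter=100, tol=1e-6, learning_rate=0.01).estimate(Y, X)`); it iteratively refines the regularized LS solution.
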
/src/traditional/ls_estimation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 | class LSChannelEstimator:
5 | """最小二乘(LS)信道估计器"""
6 |
7 | def __init__(self):
8 | self.H_est = None # 估计的信道矩阵
9 |
10 | def estimate(self, Y, X):
11 | """
12 | 使用最小二乘方法估计信道
13 |
14 | 参数:
15 | Y: 接收信号矩阵
16 | X: 发送信号矩阵(导频)
17 |
18 | 返回:
19 | H_est: 估计的信道矩阵
20 | """
21 | # 使用最小二乘方法估计信道
22 | # H = Y * X^H * (X * X^H)^(-1)
23 | X_H = np.conjugate(X).T
24 | self.H_est = Y @ X_H @ linalg.inv(X @ X_H)
25 | return self.H_est
26 |
27 | def get_mse(self, H_true):
28 | """
29 | 计算均方误差
30 |
31 | 参数:
32 | H_true: 真实信道矩阵
33 |
34 | 返回:
35 | mse: 均方误差
36 | """
37 | if self.H_est is None:
38 | raise ValueError("请先进行信道估计")
39 |
40 | mse = np.mean(np.abs(H_true - self.H_est) ** 2)
41 | return mse
--------------------------------------------------------------------------------
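Unlike the batched `LSEstimator` in `estimators.py`, this `LSChannelEstimator` operates on a single matrix pair. A small sketch on synthetic shapes:

```python
import numpy as np
from traditional.ls_estimation import LSChannelEstimator

rng = np.random.default_rng(0)
n_rx, n_tx, n_pilot = 4, 4, 16
H = (rng.standard_normal((n_rx, n_tx)) + 1j * rng.standard_normal((n_rx, n_tx))) / np.sqrt(2)
X = (rng.standard_normal((n_tx, n_pilot)) + 1j * rng.standard_normal((n_tx, n_pilot))) / np.sqrt(2)
Y = H @ X + 0.05 * (rng.standard_normal((n_rx, n_pilot)) + 1j * rng.standard_normal((n_rx, n_pilot)))

estimator = LSChannelEstimator()
H_hat = estimator.estimate(Y, X)   # H_est = Y X^H (X X^H)^(-1)
print(estimator.get_mse(H))        # MSE against the true channel
```
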
/src/utils/__pycache__/config.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/config.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/config.cpython-39.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/data_generator.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/data_generator.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/data_generator.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/data_generator.cpython-39.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/preprocessing.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/preprocessing.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/preprocessing.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/preprocessing.cpython-39.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/trainer.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/trainer.cpython-312.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ygxiuming/Channel-Estimation-Methods-Comparison-Framework/1eb33f692de658677eb737832eb690139f903891/src/utils/__pycache__/trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | 配置加载器模块
3 | 用于加载和管理YAML配置文件
4 | """
5 |
6 | import os
7 | import yaml
8 | from pathlib import Path
9 | from typing import Any, Dict, Optional
10 |
11 | class Config:
12 | """配置加载器类"""
13 |
14 | def __init__(self, config_path: str = "config/default.yml"):
15 | """
16 | 初始化配置加载器
17 |
18 | 参数:
19 | config_path: 配置文件路径,默认为'config/default.yml'
20 | """
21 | self.config_path = Path(config_path)
22 | if not self.config_path.exists():
23 | raise FileNotFoundError(f"配置文件未找到: {config_path}")
24 |
25 | with open(self.config_path, 'r', encoding='utf-8') as f:
26 | self.config = yaml.safe_load(f)
27 |
28 | def get(self, key: str, default: Any = None) -> Any:
29 | """
30 | 获取配置值
31 |
32 | 参数:
33 | key: 配置键,使用点号分隔层级,如'models.cnn.channels'
34 | default: 默认值,当键不存在时返回
35 |
36 | 返回:
37 | 配置值
38 | """
39 | keys = key.split('.')
40 | value = self.config
41 |
42 | for k in keys:
43 | if isinstance(value, dict):
44 | value = value.get(k)
45 | if value is None:
46 | return default
47 | else:
48 | return default
49 |
50 | return value
51 |
52 | def set(self, key: str, value: Any) -> None:
53 | """
54 | 设置配置值
55 |
56 | 参数:
57 | key: 配置键,使用点号分隔层级
58 | value: 要设置的值
59 | """
60 | keys = key.split('.')
61 | config = self.config
62 |
63 | for k in keys[:-1]:
64 | if k not in config:
65 | config[k] = {}
66 | config = config[k]
67 |
68 | config[keys[-1]] = value
69 |
70 | def save(self, save_path: Optional[str] = None) -> None:
71 | """
72 | 保存配置到文件
73 |
74 | 参数:
75 | save_path: 保存路径,默认为原配置文件路径
76 | """
77 | save_path = Path(save_path) if save_path else self.config_path
78 | save_path.parent.mkdir(parents=True, exist_ok=True)
79 |
80 | with open(save_path, 'w', encoding='utf-8') as f:
81 | yaml.safe_dump(self.config, f, allow_unicode=True, sort_keys=False)
82 |
83 | def update_from_args(self, args: Dict[str, Any]) -> None:
84 | """
85 | 从命令行参数更新配置
86 |
87 | 参数:
88 | args: 命令行参数字典
89 | """
90 | for key, value in args.items():
91 | if value is not None: # 只更新非None的值
92 | self.set(key, value)
93 |
94 | def __getitem__(self, key: str) -> Any:
95 | """使配置对象可以使用字典方式访问"""
96 | return self.get(key)
97 |
98 | def __setitem__(self, key: str, value: Any) -> None:
99 | """使配置对象可以使用字典方式设置值"""
100 | self.set(key, value)
101 |
102 | def __str__(self) -> str:
103 | """返回配置的字符串表示"""
104 | return yaml.safe_dump(self.config, allow_unicode=True, sort_keys=False)
105 |
106 | @property
107 | def experiment(self) -> Dict:
108 | """获取实验配置"""
109 | return self.config.get('experiment', {})
110 |
111 | @property
112 | def channel(self) -> Dict:
113 | """获取信道配置"""
114 | return self.config.get('channel', {})
115 |
116 | @property
117 | def data(self) -> Dict:
118 | """获取数据处理配置"""
119 | return self.config.get('data', {})
120 |
121 | @property
122 | def models(self) -> Dict:
123 | """获取模型配置"""
124 | return self.config.get('models', {})
125 |
126 | @property
127 | def training(self) -> Dict:
128 | """获取训练配置"""
129 | return self.config.get('training', {})
130 |
131 | @property
132 | def evaluation(self) -> Dict:
133 | """获取评估配置"""
134 | return self.config.get('evaluation', {})
135 |
136 | @property
137 | def tensorboard(self) -> Dict:
138 | """获取TensorBoard配置"""
139 | return self.config.get('tensorboard', {})
--------------------------------------------------------------------------------
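
A brief usage sketch for `Config`. The dotted keys shown (`channel.snr_db`, `training.learning_rate`, `training.batch_size`) are assumptions for illustration and may not exist in `default.yml`, which is why `get()` is given a fallback; the paths assume the working directory is `src/`.

```python
from utils.config import Config

cfg = Config("config/default.yml")

snr_db = cfg.get("channel.snr_db", 10)        # dotted access with a fallback default
cfg.set("training.learning_rate", 1e-3)       # create or overwrite a nested key
cfg["training.batch_size"] = 64               # dict-style access maps to get()/set()

print(cfg.experiment)                         # property access to a top-level section
cfg.save("config/experiment_override.yml")    # persist the modified configuration
```
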
/src/utils/data_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | 信道数据生成器模块
3 | 用于生成MIMO信道数据、导频符号和接收信号
4 | """
5 |
6 | import numpy as np
7 | from typing import Tuple
8 |
9 | class ChannelDataGenerator:
10 | """信道数据生成器类"""
11 |
12 | def __init__(self, n_tx: int, n_rx: int, n_pilot: int,
13 | channel_type: str = 'rayleigh', rician_k: float = 1.0):
14 | """
15 | 初始化信道数据生成器
16 |
17 | 参数:
18 | n_tx: 发射天线数量
19 | n_rx: 接收天线数量
20 | n_pilot: 导频符号长度
21 | channel_type: 信道类型,'rayleigh' 或 'rician'
22 | rician_k: Rician K因子,仅在channel_type='rician'时有效
23 | """
24 | self.n_tx = n_tx
25 | self.n_rx = n_rx
26 | self.n_pilot = n_pilot
27 | self.channel_type = channel_type.lower()
28 | self.rician_k = rician_k
29 |
30 | if self.channel_type not in ['rayleigh', 'rician']:
31 | raise ValueError("信道类型必须是 'rayleigh' 或 'rician'")
32 |
33 | if self.n_pilot < self.n_tx:
34 | raise ValueError("导频长度必须大于等于发射天线数量")
35 |
36 | def generate_channel(self, n_samples: int) -> np.ndarray:
37 | """
38 | 生成MIMO信道矩阵
39 |
40 | 参数:
41 | n_samples: 样本数量
42 |
43 | 返回:
44 | shape=(n_samples, n_rx, n_tx)的信道矩阵
45 | """
46 | # 生成随机复高斯分量
47 | h_random = (np.random.normal(0, 1/np.sqrt(2), (n_samples, self.n_rx, self.n_tx)) +
48 | 1j * np.random.normal(0, 1/np.sqrt(2), (n_samples, self.n_rx, self.n_tx)))
49 |
50 | if self.channel_type == 'rayleigh':
51 | return h_random
52 | else: # Rician信道
53 |             # Deterministic LOS component with unit-magnitude entries, so the configured K-factor keeps its usual per-element meaning
54 |             h_los = np.ones((n_samples, self.n_rx, self.n_tx))
55 |
56 | # 组合LOS分量和随机分量
57 | k = self.rician_k
58 | h_rician = np.sqrt(k/(k+1)) * h_los + np.sqrt(1/(k+1)) * h_random
59 |
60 | return h_rician
61 |
62 | def generate_pilot_symbols(self, n_samples: int) -> np.ndarray:
63 | """
64 | 生成导频符号
65 |
66 | 参数:
67 | n_samples: 样本数量
68 |
69 | 返回:
70 | shape=(n_samples, n_tx, n_pilot)的导频符号矩阵
71 | """
72 | # 使用QPSK调制生成导频
73 | pilot_real = np.random.choice([-1, 1], size=(n_samples, self.n_tx, self.n_pilot))
74 | pilot_imag = np.random.choice([-1, 1], size=(n_samples, self.n_tx, self.n_pilot))
75 | pilot = (pilot_real + 1j * pilot_imag) / np.sqrt(2)
76 |
77 | # 归一化功率
78 | pilot = pilot / np.sqrt(self.n_tx)
79 |
80 | return pilot
81 |
82 | def generate_received_signal(self, H: np.ndarray, X: np.ndarray, snr_db: float) -> np.ndarray:
83 | """
84 | 生成接收信号
85 |
86 | 参数:
87 | H: shape=(n_samples, n_rx, n_tx)的信道矩阵
88 | X: shape=(n_samples, n_tx, n_pilot)的发送信号
89 | snr_db: 信噪比(dB)
90 |
91 | 返回:
92 | shape=(n_samples, n_rx, n_pilot)的接收信号
93 | """
94 | # 计算信噪比
95 | snr = 10 ** (snr_db / 10)
96 |
97 | # 计算信号功率
98 | signal_power = np.mean(np.abs(H @ X) ** 2)
99 | noise_power = signal_power / snr
100 |
101 |         # Generate complex AWGN with shape (n_samples, n_rx, n_pilot)
102 |         noise_shape = (X.shape[0], self.n_rx, X.shape[-1])
103 |         noise = np.sqrt(noise_power/2) * (np.random.normal(size=noise_shape) + 1j * np.random.normal(size=noise_shape))
104 |
105 | # 生成接收信号
106 | Y = np.matmul(H, X) + noise
107 |
108 | return Y
109 |
110 | def generate(self, n_samples: int, snr_db: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
111 | """
112 | 生成完整的训练数据集
113 |
114 | 参数:
115 | n_samples: 样本数量
116 | snr_db: 信噪比(dB)
117 |
118 | 返回:
119 | (H, X, Y)元组,分别是信道矩阵、导频符号和接收信号
120 | """
121 | # 生成信道矩阵
122 | H = self.generate_channel(n_samples)
123 |
124 | # 生成导频符号
125 | X = self.generate_pilot_symbols(n_samples)
126 |
127 | # 生成接收信号
128 | Y = self.generate_received_signal(H, X, snr_db)
129 |
130 | return H, X, Y
--------------------------------------------------------------------------------
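
An end-to-end sketch for `ChannelDataGenerator`, followed by a batched LS estimate written directly in NumPy as a quick sanity check on the generated data. The 4x2 MIMO setup, pilot length 8, Rician K of 3 and 15 dB SNR are illustrative values, and the import path assumes the working directory is `src/`.

```python
import numpy as np
from utils.data_generator import ChannelDataGenerator

gen = ChannelDataGenerator(n_tx=2, n_rx=4, n_pilot=8,
                           channel_type='rician', rician_k=3.0)
H, X, Y = gen.generate(n_samples=1000, snr_db=15.0)
print(H.shape, X.shape, Y.shape)   # (1000, 4, 2) (1000, 2, 8) (1000, 4, 8)

# Batched LS estimate: np.linalg.inv broadcasts over the sample dimension
X_H = np.conjugate(X).transpose(0, 2, 1)
H_ls = Y @ X_H @ np.linalg.inv(X @ X_H)
nmse = np.mean(np.abs(H - H_ls) ** 2) / np.mean(np.abs(H) ** 2)
print("batch LS NMSE:", nmse)
```
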
/src/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | 信道数据预处理模块
3 | 用于数据归一化和异常值处理
4 | """
5 |
6 | import numpy as np
7 | from typing import Optional, List, Tuple
8 |
9 |
10 | class ChannelPreprocessor:
11 | """信道数据预处理器"""
12 |
13 | def __init__(self, normalization: str = 'z-score',
14 | remove_outliers: bool = False,
15 | outlier_threshold: float = 3.0):
16 | """
17 | 初始化预处理器
18 |
19 | 参数:
20 | normalization: 归一化方法,'z-score' 或 'min-max'
21 | remove_outliers: 是否移除异常值
22 | outlier_threshold: 异常值阈值(标准差的倍数)
23 | """
24 | self.normalization = normalization.lower()
25 | self.remove_outliers = remove_outliers
26 | self.outlier_threshold = outlier_threshold
27 |
28 | if self.normalization not in ['z-score', 'min-max']:
29 | raise ValueError("归一化方法必须是 'z-score' 或 'min-max'")
30 |
31 | # 存储归一化参数
32 | self.mean_real = None
33 | self.std_real = None
34 | self.mean_imag = None
35 | self.std_imag = None
36 | self.min_real = None
37 | self.max_real = None
38 | self.min_imag = None
39 | self.max_imag = None
40 |
41 | def _remove_outliers(self, data: np.ndarray) -> np.ndarray:
42 | """
43 | 移除异常值
44 |
45 | 参数:
46 | data: 输入数据
47 |
48 | 返回:
49 | 处理后的数据
50 | """
51 | if not self.remove_outliers:
52 | return data
53 |
54 |         # Work on copies of the real/imaginary parts: np.real/np.imag return
55 |         # views, so in-place edits would otherwise corrupt the caller's array
56 |         real_part = np.real(data).copy()
57 |         imag_part = np.imag(data).copy()
58 |
59 |         # Per-part mean and standard deviation
60 |         mean_real = np.mean(real_part)
61 |         std_real = np.std(real_part)
62 |         mean_imag = np.mean(imag_part)
63 |         std_imag = np.std(imag_part)
64 |
65 |         # Masks of values within the threshold (in units of standard deviations)
66 |         real_mask = np.abs(real_part - mean_real) <= self.outlier_threshold * std_real
67 |         imag_mask = np.abs(imag_part - mean_imag) <= self.outlier_threshold * std_imag
68 |
69 |         # Replace outliers with the mean of the corresponding part
70 |         real_part[~real_mask] = mean_real
71 |         imag_part[~imag_mask] = mean_imag
72 |
73 |         return real_part + 1j * imag_part
74 |
75 | def fit(self, data: np.ndarray) -> None:
76 | """
77 | 计算归一化参数
78 |
79 | 参数:
80 | data: 输入数据
81 | """
82 | # 首先移除异常值
83 | data = self._remove_outliers(data)
84 |
85 | # 分离实部和虚部
86 | real_part = np.real(data)
87 | imag_part = np.imag(data)
88 |
89 | if self.normalization == 'z-score':
90 | self.mean_real = np.mean(real_part)
91 | self.std_real = np.std(real_part)
92 | self.mean_imag = np.mean(imag_part)
93 | self.std_imag = np.std(imag_part)
94 | else: # min-max归一化
95 | self.min_real = np.min(real_part)
96 | self.max_real = np.max(real_part)
97 | self.min_imag = np.min(imag_part)
98 | self.max_imag = np.max(imag_part)
99 |
100 | def transform(self, data: np.ndarray) -> np.ndarray:
101 | """
102 | 应用归一化
103 |
104 | 参数:
105 | data: 输入数据
106 |
107 | 返回:
108 | 归一化后的数据
109 | """
110 | # 首先移除异常值
111 | data = self._remove_outliers(data)
112 |
113 | # 分离实部和虚部
114 | real_part = np.real(data)
115 | imag_part = np.imag(data)
116 |
117 | if self.normalization == 'z-score':
118 | if self.mean_real is None or self.std_real is None:
119 | raise ValueError("请先调用fit方法计算归一化参数")
120 |
121 | real_norm = (real_part - self.mean_real) / self.std_real
122 | imag_norm = (imag_part - self.mean_imag) / self.std_imag
123 | else: # min-max归一化
124 | if self.min_real is None or self.max_real is None:
125 | raise ValueError("请先调用fit方法计算归一化参数")
126 |
127 | real_norm = (real_part - self.min_real) / (self.max_real - self.min_real)
128 | imag_norm = (imag_part - self.min_imag) / (self.max_imag - self.min_imag)
129 |
130 | return real_norm + 1j * imag_norm
131 |
132 | def fit_transform(self, data: np.ndarray) -> np.ndarray:
133 | """
134 | 计算归一化参数并应用归一化
135 |
136 | 参数:
137 | data: 输入数据
138 |
139 | 返回:
140 | 归一化后的数据
141 | """
142 | self.fit(data)
143 | return self.transform(data)
144 |
145 | def inverse_transform(self, data: np.ndarray) -> np.ndarray:
146 | """
147 | 反归一化
148 |
149 | 参数:
150 | data: 归一化后的数据
151 |
152 | 返回:
153 | 原始尺度的数据
154 | """
155 | real_part = np.real(data)
156 | imag_part = np.imag(data)
157 |
158 | if self.normalization == 'z-score':
159 | if self.mean_real is None or self.std_real is None:
160 | raise ValueError("请先调用fit方法计算归一化参数")
161 |
162 | real_orig = real_part * self.std_real + self.mean_real
163 | imag_orig = imag_part * self.std_imag + self.mean_imag
164 | else: # min-max归一化
165 | if self.min_real is None or self.max_real is None:
166 | raise ValueError("请先调用fit方法计算归一化参数")
167 |
168 | real_orig = real_part * (self.max_real - self.min_real) + self.min_real
169 | imag_orig = imag_part * (self.max_imag - self.min_imag) + self.min_imag
170 |
171 | return real_orig + 1j * imag_orig
172 |
173 | def augment_data(data: np.ndarray,
174 | methods: Optional[List[str]] = None,
175 | noise_std: float = 0.01,
176 | phase_shift_range: Tuple[float, float] = (-0.1, 0.1),
177 | magnitude_scale_range: Tuple[float, float] = (0.9, 1.1)) -> np.ndarray:
178 | """
179 | 数据增强
180 |
181 | 参数:
182 | data: 输入数据
183 | methods: 增强方法列表,可选['noise', 'phase_shift', 'magnitude_scale']
184 | noise_std: 高斯噪声标准差
185 | phase_shift_range: 相位偏移范围(弧度)
186 | magnitude_scale_range: 幅度缩放范围
187 |
188 | 返回:
189 | 增强后的数据
190 | """
191 | if methods is None:
192 | methods = ['noise', 'phase_shift', 'magnitude_scale']
193 |
194 | augmented_data = data.copy()
195 |
196 | for method in methods:
197 | if method == 'noise':
198 | # 添加复高斯噪声
199 | noise = (np.random.normal(0, noise_std, data.shape) +
200 | 1j * np.random.normal(0, noise_std, data.shape))
201 | augmented_data += noise
202 |
203 | elif method == 'phase_shift':
204 | # 随机相位偏移
205 | phase_shift = np.random.uniform(phase_shift_range[0],
206 | phase_shift_range[1],
207 | data.shape)
208 | augmented_data *= np.exp(1j * phase_shift)
209 |
210 | elif method == 'magnitude_scale':
211 | # 随机幅度缩放
212 | scale = np.random.uniform(magnitude_scale_range[0],
213 | magnitude_scale_range[1],
214 | data.shape)
215 | augmented_data *= scale
216 |
217 | return augmented_data
--------------------------------------------------------------------------------
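
A quick sketch of the preprocessing utilities: a z-score round trip with `ChannelPreprocessor` (outlier removal left off so the inverse transform is exact) and one call to `augment_data`. The shapes, seed and 5% noise level are arbitrary, and the import path assumes the working directory is `src/`.

```python
import numpy as np
from utils.preprocessing import ChannelPreprocessor, augment_data

rng = np.random.default_rng(0)
H = (rng.standard_normal((500, 4, 2)) + 1j * rng.standard_normal((500, 4, 2))) / np.sqrt(2)

pre = ChannelPreprocessor(normalization='z-score')   # remove_outliers defaults to False
H_norm = pre.fit_transform(H)                        # fit the statistics, then normalise
H_back = pre.inverse_transform(H_norm)               # map back to the original scale
print("round-trip error:", np.max(np.abs(H - H_back)))

# Augmentation: additive complex noise plus a small random phase rotation
H_aug = augment_data(H_norm, methods=['noise', 'phase_shift'], noise_std=0.05)
print(H_aug.shape)
```
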
/src/utils/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | 信道估计训练器模块
3 |
4 | 此模块实现了用于训练深度学习模型的训练器类。训练器提供了完整的训练流程管理,
5 | 包括训练循环、验证评估、模型保存、早停机制等功能。
6 |
7 | 主要特性:
8 | - 支持GPU训练
9 | - 实时进度显示
10 | - TensorBoard可视化
11 | - 自动保存最佳模型
12 | - 训练状态恢复
13 | - 早停机制
14 | - 学习率自适应调整
15 |
16 | 作者: lzm
17 | 日期: 2025-03-07
18 | """
19 |
20 | import os
21 | import time
22 | from pathlib import Path
23 | import torch
24 | from torch.utils.tensorboard import SummaryWriter
25 | from tqdm import tqdm
26 | import sys
27 |
28 | class ChannelEstimatorTrainer:
29 | """
30 | 信道估计模型训练器
31 |
32 | 该类实现了完整的模型训练流程,包括:
33 | 1. 训练和验证循环
34 | 2. 损失计算和优化
35 | 3. 进度监控和可视化
36 | 4. 模型保存和加载
37 | 5. 训练状态管理
38 |
39 | 参数:
40 | model (nn.Module): 要训练的模型
41 | train_loader (DataLoader): 训练数据加载器
42 | val_loader (DataLoader): 验证数据加载器
43 | save_dir (str): 模型和日志保存目录
44 | project_name (str): 项目名称,用于创建子目录
45 | device (str): 训练设备('cuda'或'cpu')
46 | learning_rate (float): 初始学习率
47 | """
48 |
49 | def __init__(
50 | self,
51 | model,
52 | train_loader,
53 | val_loader,
54 | save_dir='runs/train',
55 | project_name='channel_estimation',
56 | device='cuda',
57 | learning_rate=0.001
58 | ):
59 | # 初始化模型和数据加载器
60 | self.model = model.to(device)
61 | self.train_loader = train_loader
62 | self.val_loader = val_loader
63 | self.device = device
64 | self.learning_rate = learning_rate
65 |
66 | # 创建保存目录
67 | self.save_dir = Path(save_dir) / project_name
68 | self.weights_dir = self.save_dir / 'weights'
69 | self.weights_dir.mkdir(parents=True, exist_ok=True)
70 |
71 | # 设置TensorBoard
72 | self.writer = SummaryWriter(str(self.save_dir / 'tensorboard'))
73 |
74 | # 初始化优化器和损失函数
75 | self.criterion = torch.nn.MSELoss()
76 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
77 |
78 | # 设置学习率调度器,当验证损失不再下降时降低学习率
79 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
80 | self.optimizer, 'min', patience=5, verbose=True
81 | )
82 |
83 | # 初始化训练状态
84 | self.best_val_loss = float('inf') # 最佳验证损失
85 | self.current_epoch = 0 # 当前训练轮数
86 | self.history = { # 训练历史记录
87 | 'train_loss': [], # 训练损失历史
88 | 'val_loss': [], # 验证损失历史
89 | 'learning_rate': [] # 学习率历史
90 | }
91 |
92 | def train_epoch(self):
93 | """
94 | 训练一个epoch
95 |
96 | 该方法实现了单个训练epoch的完整流程:
97 | 1. 遍历训练数据批次
98 | 2. 前向传播计算损失
99 | 3. 反向传播更新参数
100 | 4. 记录训练状态
101 | 5. 更新进度显示
102 |
103 | 返回:
104 | float: 该epoch的平均训练损失
105 | """
106 | self.model.train() # 设置为训练模式
107 | total_loss = 0
108 |
109 | # 创建进度条
110 | pbar = tqdm(self.train_loader,
111 | desc=f'Epoch {self.current_epoch + 1:<3d}',
112 | leave=True, # 保持显示
113 | position=1, # 位置在总进度条下方
114 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}',
115 | ncols=120)
116 |
117 | # 遍历训练数据批次
118 | for batch_idx, (X, y) in enumerate(pbar):
119 | # 将数据移到指定设备
120 | X, y = X.to(self.device), y.to(self.device)
121 |
122 | # 前向传播
123 | self.optimizer.zero_grad() # 清除梯度
124 | outputs = self.model(X) # 模型预测
125 | loss = self.criterion(outputs, y) # 计算损失
126 |
127 | # 反向传播
128 | loss.backward() # 计算梯度
129 | self.optimizer.step() # 更新参数
130 |
131 | # 更新损失统计
132 | total_loss += loss.item()
133 | avg_loss = total_loss / (batch_idx + 1)
134 |
135 | # 更新进度条显示
136 | pbar.set_postfix({'loss': f'{avg_loss:.6f}'}, refresh=True)
137 |
138 | # 记录到TensorBoard
139 | step = self.current_epoch * len(self.train_loader) + batch_idx
140 | self.writer.add_scalar('train/batch_loss', loss.item(), step)
141 |
142 | pbar.close()
143 | return total_loss / len(self.train_loader)
144 |
145 | def validate(self):
146 | """
147 | 验证模型性能
148 |
149 | 该方法在验证集上评估模型性能:
150 | 1. 不计算梯度
151 | 2. 遍历验证数据批次
152 | 3. 计算验证损失
153 | 4. 更新进度显示
154 |
155 | 返回:
156 | float: 验证集上的平均损失
157 | """
158 | self.model.eval() # 设置为评估模式
159 | total_loss = 0
160 | val_loader_len = len(self.val_loader)
161 |
162 | with torch.no_grad(): # 不计算梯度
163 | # 创建进度条
164 | pbar = tqdm(self.val_loader,
165 | desc='Validating',
166 | leave=True, # 保持显示
167 | position=1, # 位置在总进度条下方
168 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}',
169 | ncols=120)
170 |
171 | # 遍历验证数据批次
172 | for batch_idx, (X, y) in enumerate(pbar):
173 | # 将数据移到指定设备
174 | X, y = X.to(self.device), y.to(self.device)
175 | outputs = self.model(X) # 模型预测
176 | loss = self.criterion(outputs, y) # 计算损失
177 | total_loss += loss.item()
178 |
179 | # 更新进度条显示
180 | avg_loss = total_loss / (batch_idx + 1)
181 | pbar.set_postfix({'loss': f'{avg_loss:.6f}'}, refresh=True)
182 |
183 | pbar.close()
184 |
185 | return total_loss / val_loader_len
186 |
187 | def train(self, epochs=100, early_stopping_patience=10):
188 | """
189 | 完整的训练流程
190 |
191 | 该方法实现了完整的模型训练流程:
192 | 1. 多轮训练循环
193 | 2. 定期验证评估
194 | 3. 模型保存
195 | 4. 早停机制
196 | 5. 进度监控
197 |
198 | 参数:
199 | epochs (int): 训练轮数
200 | early_stopping_patience (int): 早停耐心值,验证损失多少轮未改善时停止训练
201 | """
202 | early_stopping_counter = 0
203 | start_time = time.time()
204 |
205 | # 打印训练配置
206 | print("\n训练配置:")
207 | print(f" Epochs: {epochs}")
208 | print(f" Batch Size: {self.train_loader.batch_size}")
209 | print(f" Learning Rate: {self.learning_rate}")
210 | print(f" Device: {self.device}")
211 | print(f" Early Stopping Patience: {early_stopping_patience}")
212 | print()
213 |
214 | # 创建总进度条
215 | pbar = tqdm(range(epochs),
216 | desc='Progress',
217 | leave=True, # 保持显示
218 | position=0, # 位置在最上方
219 | bar_format='{desc:<12} {percentage:3.0f}%|{bar:50}{r_bar}',
220 | ncols=120)
221 |
222 | try:
223 | # 训练循环
224 | for epoch in pbar:
225 | self.current_epoch = epoch
226 |
227 | # 训练一个epoch
228 | train_loss = self.train_epoch()
229 |
230 | # 验证评估
231 | val_loss = self.validate()
232 |
233 | # 更新学习率
234 | current_lr = self.optimizer.param_groups[0]['lr']
235 | self.scheduler.step(val_loss) # 根据验证损失调整学习率
236 |
237 | # 记录训练历史
238 | self.history['train_loss'].append(train_loss)
239 | self.history['val_loss'].append(val_loss)
240 | self.history['learning_rate'].append(current_lr)
241 |
242 | # 记录到TensorBoard
243 | self.writer.add_scalar('train/epoch_loss', train_loss, epoch)
244 | self.writer.add_scalar('val/epoch_loss', val_loss, epoch)
245 | self.writer.add_scalar('train/lr', current_lr, epoch)
246 |
247 | # 更新总进度条显示
248 | pbar.set_postfix({
249 | 'train': f'{train_loss:.6f}',
250 | 'val': f'{val_loss:.6f}',
251 | 'lr': f'{current_lr:.2e}'
252 | }, refresh=True)
253 |
254 | # 保存最佳模型
255 | if val_loss < self.best_val_loss:
256 | self.best_val_loss = val_loss
257 | self.save_model('best.pt')
258 | early_stopping_counter = 0 # 重置早停计数器
259 | else:
260 | early_stopping_counter += 1
261 |
262 | # 定期保存模型
263 | if (epoch + 1) % 10 == 0:
264 | self.save_model(f'epoch_{epoch + 1}.pt')
265 |
266 | # 早停检查
267 | if early_stopping_counter >= early_stopping_patience:
268 | print(f'\nEarly stopping triggered at epoch {epoch + 1}')
269 | break
270 |
271 | # 在每个epoch结束时打印空行
272 | print()
273 |
274 | except KeyboardInterrupt:
275 | print('\nTraining interrupted by user')
276 |
277 | finally:
278 | # 保存最后的模型和清理资源
279 | self.save_model('last.pt')
280 | self.writer.close()
281 |
282 | # 打印训练结果统计
283 | elapsed_time = time.time() - start_time
284 | print("\n训练完成:")
285 | print(f" 训练轮数: {self.current_epoch + 1}/{epochs}")
286 | print(f" 最佳验证损失: {self.best_val_loss:.6f}")
287 | print(f" 训练时间: {elapsed_time/3600:.2f}h")
288 | print(f" 保存路径: {self.weights_dir}")
289 | print()
290 |
291 | def save_model(self, filename):
292 | """
293 | 保存模型状态
294 |
295 | 保存完整的训练状态,包括:
296 | - 模型参数
297 | - 优化器状态
298 | - 学习率调度器状态
299 | - 训练历史
300 | - 最佳验证损失
301 |
302 | 参数:
303 | filename (str): 保存的文件名
304 | """
305 | save_path = self.weights_dir / filename
306 | torch.save({
307 | 'epoch': self.current_epoch,
308 | 'model_state_dict': self.model.state_dict(),
309 | 'optimizer_state_dict': self.optimizer.state_dict(),
310 | 'scheduler_state_dict': self.scheduler.state_dict(),
311 | 'best_val_loss': self.best_val_loss,
312 | 'history': self.history
313 | }, str(save_path))
314 |
315 | def load_model(self, filename):
316 | """
317 | 加载模型状态
318 |
319 | 加载完整的训练状态,包括:
320 | - 模型参数
321 | - 优化器状态
322 | - 学习率调度器状态
323 | - 训练历史
324 | - 最佳验证损失
325 |
326 | 参数:
327 | filename (str): 要加载的文件名
328 |
329 | 返回:
330 | bool: 是否成功加载模型
331 | """
332 | load_path = self.weights_dir / filename
333 | if load_path.exists():
334 |             checkpoint = torch.load(str(load_path), map_location=self.device)
335 | self.model.load_state_dict(checkpoint['model_state_dict'])
336 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
337 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
338 | self.current_epoch = checkpoint['epoch']
339 | self.best_val_loss = checkpoint['best_val_loss']
340 | if 'history' in checkpoint:
341 | self.history = checkpoint['history']
342 | return True
343 | return False
344 |
345 | def get_history(self):
346 | """
347 | 获取训练历史
348 |
349 | 返回:
350 | dict: 包含训练损失、验证损失和学习率的历史记录
351 | """
352 | return self.history
--------------------------------------------------------------------------------
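
A wiring sketch for `ChannelEstimatorTrainer` using a toy fully-connected model and synthetic real-valued tensors. The model, tensor shapes and hyper-parameters are placeholders rather than the project's actual pipeline; the import path assumes the working directory is `src/`, and `runs/train/demo/` will be created for checkpoints and TensorBoard logs.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from utils.trainer import ChannelEstimatorTrainer

# Synthetic regression data: 64-dim inputs mapped to 16-dim targets
X = torch.randn(2048, 64)
y = torch.randn(2048, 16)
train_loader = DataLoader(TensorDataset(X[:1600], y[:1600]), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(X[1600:], y[1600:]), batch_size=64)

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 16))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainer = ChannelEstimatorTrainer(model, train_loader, val_loader,
                                  save_dir='runs/train', project_name='demo',
                                  device=device, learning_rate=1e-3)
trainer.train(epochs=5, early_stopping_patience=3)
print(trainer.get_history()['val_loss'])
```

Since the trainer already handles checkpointing, early stopping and TensorBoard logging internally, the calling code only has to supply the model and the two data loaders.
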