├── .gitignore
├── README.md
├── cfg.py
├── main.py
├── model.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

H_dataset/
backup/
log_train/
channel_response_set_test.npy
channel_response_set_train.npy
Pilot_8
Pilot_64

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# OFDM_DNN_Pytorch

Reproduction of the paper "Power of Deep Learning for Channel Estimation and Signal Detection in OFDM Systems" with the PyTorch framework.

### usage
1. Download the channel dataset from https://github.com/haoyye/OFDM_DNN
2. Unzip and move the dataset into the project directory
3. Run the training script:
```
python main.py
```
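
### dataset layout
`main.py` reads the channel documents as `./H_dataset/1.txt` ... `./H_dataset/400.txt` (documents 1-300 for training, 301-400 for testing). Each line of a document holds one channel impulse response: the first half of the numbers are the real parts, the second half the imaginary parts. The expected layout is:
```
OFDM_DNN_Pytorch/
├── H_dataset/
│   ├── 1.txt
│   ├── ...
│   └── 400.txt
└── main.py
```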

### reference
https://github.com/haoyye/OFDM_DNN
--------------------------------------------------------------------------------

/cfg.py:
--------------------------------------------------------------------------------
import os
import numpy as np

K = 64        # total number of subcarriers
CP = K // 4   # cyclic prefix length
P = 64        # number of pilot carriers

mu = 2        # bits per QPSK symbol
payloadBits_per_OFDM = K * mu
SNRdb = 20

pilot_bits = None
pilot_file_name = 'Pilot_' + str(P)
if os.path.isfile(pilot_file_name):
    print('load pilot txt')
    pilot_bits = np.loadtxt(pilot_file_name, delimiter=',')
else:
    pilot_bits = np.random.binomial(n=1, p=0.5, size=(K*mu,))
    # np.savetxt needs the array as its second argument; it was missing here
    np.savetxt(pilot_file_name, pilot_bits, delimiter=',')
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

from utils import *
from model import NeuralNetwork

log_dir = 'log_train'
if os.path.exists(log_dir):
    shutil.rmtree(log_dir)
writer = SummaryWriter(logdir=log_dir, flush_secs=10)

model_path = 'backup/model.pth'
os.makedirs('backup', exist_ok=True)  # torch.save fails if the checkpoint directory is missing

# Get cpu or gpu device for training.
# Note: CUDA_VISIBLE_DEVICES only takes effect if set before the first CUDA
# call, so it is best exported in the shell environment instead.
if torch.cuda.is_available():
    device = "cuda"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
    device = "cpu"
print("[Using {} device]".format(device))


# training parameters configuration
training_epochs = 200
learning_rate_current = 0.001

test_step = 5

is_resume_model = False
is_resume_data = True

# define model, loss function and optimizer
model = NeuralNetwork().to(device)

if is_resume_model:
    model.load_state_dict(torch.load(model_path))
    print('load model successfully!')

# loss_fn = nn.CrossEntropyLoss()
loss_fn = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_current)

# Channel conditions dataset
channel_response_set_train = []
channel_response_set_test = []

if is_resume_data:
    channel_response_set_train = np.load('channel_response_set_train.npy')
    channel_response_set_test = np.load('channel_response_set_test.npy')
    print('[channel data loaded successfully!]')
else:
    # The H information set
    H_folder_train = './H_dataset/'
    H_folder_test = './H_dataset/'
    train_idx_low = 1
    train_idx_high = 301
    test_idx_low = 301
    test_idx_high = 401

    # Saving channel conditions to a large matrix. Each line of a document
    # holds one impulse response: first half real parts, second half
    # imaginary parts.
    for train_idx in range(train_idx_low, train_idx_high):
        print(f"Processing the {train_idx}th document")
        H_file = H_folder_train + str(train_idx) + '.txt'
        with open(H_file) as f:
            for line in f:
                numbers_str = line.split()
                numbers_float = [float(x) for x in numbers_str]
                h_response = np.asarray(numbers_float[0:int(len(numbers_float)/2)]) + 1j*np.asarray(
                    numbers_float[int(len(numbers_float)/2):len(numbers_float)])
                channel_response_set_train.append(h_response)

    for test_idx in range(test_idx_low, test_idx_high):
        print(f"Processing the {test_idx}th document")
        H_file = H_folder_test + str(test_idx) + '.txt'
        with open(H_file) as f:
            for line in f:
                numbers_str = line.split()
                numbers_float = [float(x) for x in numbers_str]
                h_response = np.asarray(numbers_float[0:int(len(numbers_float)/2)]) + 1j*np.asarray(
                    numbers_float[int(len(numbers_float)/2):len(numbers_float)])
                channel_response_set_test.append(h_response)

    np.save('channel_response_set_train.npy', channel_response_set_train)
    np.save('channel_response_set_test.npy', channel_response_set_test)
    print('channel data saved successfully!')
print('length of training channel response', len(channel_response_set_train),
      'length of testing channel response', len(channel_response_set_test))
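
# Note on the sample format (see utils.ofdm_simulate): each network input is a
# 256-dim real vector [Re(rx_pilot), Im(rx_pilot), Re(rx_data), Im(rx_data)]
# with K = 64 received samples per part, and each label is the first 16
# payload bits -- one bit group, matching the 256 -> 600 -> 300 -> 16 network
# in model.py.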

def train(epoch):
    total_loss = 0

    train_channel_set_size = len(channel_response_set_train)
    train_channel_set_idx = np.arange(train_channel_set_size)
    np.random.shuffle(train_channel_set_idx)

    batch_size = 500                                       # number of samples per batch
    total_batch = int(train_channel_set_size/batch_size)   # number of batches per epoch

    model.train()

    for i in range(total_batch):
        input_samples = []
        input_labels = []

        for j in range(batch_size):
            bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM, ))
            channel_response = channel_response_set_train[train_channel_set_idx[i * batch_size + j]]
            signal_output, para = ofdm_simulate(bits, channel_response, SNRdb)
            input_labels.append(bits[0:16])   # the network estimates the first 16 payload bits
            input_samples.append(signal_output)

        batch_x = torch.from_numpy(np.array(input_samples))
        batch_y = torch.from_numpy(np.array(input_labels))

        data, target = batch_x.to(device).float(), batch_y.to(device)

        # Compute prediction error
        pred = model(data)
        pred = pred.to(torch.float32)
        target = target.to(torch.float32)
        loss = loss_fn(pred, target)
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f"{epoch} \t {i+1}/{total_batch} \t loss: {loss.item():>7f}")
    avg_loss = total_loss/total_batch
    writer.add_scalar(tag='avg_loss', scalar_value=avg_loss, global_step=epoch)
    print(f"epoch: {epoch} avg_loss: {avg_loss:>7f}")
    return avg_loss


# Arbitrary large initial value, kept at module level so it persists across
# test() calls; a function-local best_ser would be reset every epoch and the
# model would be saved unconditionally.
best_ser = 10


def test(epoch):
    global best_ser

    test_channel_set_size = len(channel_response_set_test)
    batch_size = 300                                      # number of samples per batch
    total_batch = int(test_channel_set_size/batch_size)   # number of batches per epoch

    model.eval()

    total_ser = 0

    with torch.no_grad():
        for i in range(total_batch):
            input_samples = []
            input_labels = []

            for j in range(batch_size):
                bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM, ))
                channel_response = channel_response_set_test[i * batch_size + j]
                signal_output, para = ofdm_simulate(bits, channel_response, SNRdb)
                input_labels.append(bits[0:16])
                input_samples.append(signal_output)

            batch_x = torch.from_numpy(np.array(input_samples))
            batch_y = torch.from_numpy(np.array(input_labels))

            data, target = batch_x.to(device).float(), batch_y.to(device)

            # Compute prediction error
            pred = model(data)

            # map sigmoid outputs and 0/1 labels to +/-1 before comparing
            pred = torch.sign(pred - 0.5)
            target = torch.sign(target - 0.5)

            total_ser += torch.mean((pred != target).float()).item()
    avg_ser = total_ser / total_batch
    print(f"\n\ntest result --> avg_ser: {avg_ser:>7f} \n\n")
    writer.add_scalar(tag='avg_ser', scalar_value=avg_ser, global_step=epoch)

    if avg_ser < best_ser:
        torch.save(model.state_dict(), model_path)
        best_ser = avg_ser
        print(f'save best model in >>> {model_path}')
    return avg_ser


loss_final = 0
ser_final = 0
for t in range(1, training_epochs + 1):
    loss_final = train(t)
    if t % test_step == 0:
        ser_final = test(t)

writer.close()

print("final loss:", loss_final)
print("final ser:", ser_final)
--------------------------------------------------------------------------------

/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(256, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 16)
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return torch.sigmoid(logits)


if __name__ == '__main__':
    x = torch.randn((1, 256))
    model = NeuralNetwork()
    print(model)
    output = model(x)
    print(output)
    print(output.shape)
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from cfg import *
from numpy.random import randn


def Modulation(bits):
    # map bit pairs to QPSK symbols: (b0, b1) -> (2*b0 - 1) + 1j*(2*b1 - 1)
    bit_r = bits.reshape((int(len(bits)/mu), mu))
    return (2*bit_r[:, 0]-1) + 1j*(2*bit_r[:, 1]-1)


def DFT(OFDM_RX):
    return np.fft.fft(OFDM_RX)


def IDFT(OFDM_data):
    return np.fft.ifft(OFDM_data)


def addCP(OFDM_time):
    cp = OFDM_time[-CP:]   # copy the last CP samples and prepend them
    return np.hstack([cp, OFDM_time])


def removeCP(signal):
    return signal[CP:(CP+K)]


def channel(signal, channelResponse, SNRdb):
    convolved = np.convolve(signal, channelResponse)
    signal_power = np.mean(abs(convolved)**2)

    # additive white Gaussian noise at the requested SNR
    sigma2 = signal_power * 10**(-SNRdb/10)
    noise = np.sqrt(sigma2/2) * (randn(*convolved.shape) + 1j * randn(*convolved.shape))

    return convolved + noise


def equalize(OFDM_demod, Hest):
    return OFDM_demod / Hest


def PS(bits):
    return bits.reshape((-1,))


def ofdm_simulate(codeword, channelResponse, SNRdb):
    # pilot OFDM symbol: all K subcarriers carry pilots (P == K here)
    OFDM_data = np.zeros(K, dtype=complex)
    pilotValue = Modulation(pilot_bits)
    OFDM_data[np.arange(K)] = pilotValue
    OFDM_time = IDFT(OFDM_data)
    OFDM_withCP = addCP(OFDM_time)
    OFDM_TX = OFDM_withCP
    OFDM_RX = channel(OFDM_TX, channelResponse, SNRdb)
    OFDM_RX_noCP = removeCP(OFDM_RX)

    # data OFDM symbol carrying the QPSK-modulated codeword
    symbol = np.zeros(K, dtype=complex)
    codeword_qpsk = Modulation(codeword)
    symbol[np.arange(K)] = codeword_qpsk
    OFDM_data_codeword = symbol
    OFDM_time_codeword = IDFT(OFDM_data_codeword)
    OFDM_withCP_codeword = addCP(OFDM_time_codeword)
    OFDM_RX_codeword = channel(OFDM_withCP_codeword, channelResponse, SNRdb)
    OFDM_RX_noCP_codeword = removeCP(OFDM_RX_codeword)

    # network input: real/imag parts of the received pilot and data symbols
    # (4*K = 256 values), plus the channel magnitude for reference
    return np.concatenate((np.concatenate((np.real(OFDM_RX_noCP), np.imag(OFDM_RX_noCP))),
                           np.concatenate((np.real(OFDM_RX_noCP_codeword), np.imag(OFDM_RX_noCP_codeword))))), \
        abs(channelResponse)
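
# Minimal smoke test (a sketch, not part of the original code): generate one
# random payload and a hypothetical two-tap channel, then check the shapes
# that main.py and model.py rely on.
if __name__ == '__main__':
    demo_bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
    demo_channel = np.array([1.0 + 0.0j, 0.3 - 0.2j])   # assumed example impulse response
    features, h_abs = ofdm_simulate(demo_bits, demo_channel, SNRdb)
    print('feature vector length:', len(features))      # expected: 4*K = 256
    print('channel magnitude taps:', h_abs)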
--------------------------------------------------------------------------------