"""Batch generators that feed simulated OFDM receptions to the detector.

Each generated sample is the feature vector returned by ``ofdm_simulate``
and each label is the first ``n_output`` transmitted payload bits.
"""
# The star import supplies np, payloadBits_per_OFDM, n_output and
# ofdm_simulate; other modules star-import this one in turn, so it is kept.
from Global_parameters import *

channel_train = np.load('channel_train.npy')
train_size = channel_train.shape[0]
channel_test = np.load('channel_test.npy')
test_size = channel_test.shape[0]


def _batch_gen(channels, bs, SNRdb):
    """Yield endless ``(samples, labels)`` batches drawn from *channels*.

    Args:
        channels: array whose first axis indexes channel responses.
        bs: number of samples per batch.
        SNRdb: signal-to-noise ratio (dB) passed on to ``ofdm_simulate``.

    Yields:
        A tuple ``(samples, labels)`` of arrays with ``bs`` rows each.
    """
    n_channels = channels.shape[0]
    while True:
        # Channels are drawn with replacement, a fresh draw for every batch.
        index = np.random.choice(n_channels, size=bs)
        input_samples = []
        input_labels = []
        for H in channels[index]:
            bits = np.random.binomial(
                n=1, p=0.5, size=(payloadBits_per_OFDM,))
            signal_output, _ = ofdm_simulate(bits, H, SNRdb)
            # Only the first n_output bits are used as the label.
            input_labels.append(bits[:n_output])
            input_samples.append(signal_output)
        yield (np.asarray(input_samples), np.asarray(input_labels))


def training_gen(bs, SNRdb=20):
    """Yield training batches of size *bs* using the training channels."""
    yield from _batch_gen(channel_train, bs, SNRdb)


def validation_gen(bs, SNRdb=20):
    """Yield validation batches of size *bs* using the held-out test channels.

    The previous implementation sampled ``channel_train`` here, so the
    validation metric was computed on channels the model had trained on and
    ``channel_test`` was never used.
    """
    yield from _batch_gen(channel_test, bs, SNRdb)
"""Train the dense OFDM bit detector at one SNR, then sweep SNR to log BER."""
# The star imports supply Input, Model, callbacks and the layer classes
# (keras), and np, the generators and the network sizes (generations).
from tensorflow.python.keras import *
from tensorflow.python.keras.layers import *
from generations import *
import tensorflow as tf
import scipy.io as sio

TRAIN_SNR_DB = 25
WEIGHTS_PATH = './temp_trained_25.h5'


def bit_err(y_true, y_pred):
    """Return the fraction of bits whose hard decision differs from the label.

    Both tensors are thresholded at 0.5 via ``sign(x - 0.5)`` and compared
    element-wise; the result is one minus the mean agreement.
    """
    decided = tf.sign(y_pred - 0.5)
    expected = tf.cast(tf.sign(y_true - 0.5), tf.float32)
    # tf.cast replaces tf.to_float, which no longer exists in TensorFlow 2.x.
    agreement = tf.cast(tf.equal(decided, expected), tf.float32)
    return 1 - tf.reduce_mean(tf.reduce_mean(agreement, 1))


input_bits = Input(shape=(payloadBits_per_OFDM * 2,))
temp = BatchNormalization()(input_bits)
# Feed the normalised input into the first Dense layer; previously the layer
# was applied to input_bits, which silently dropped the normalisation above.
temp = Dense(n_hidden_1, activation='relu')(temp)
temp = BatchNormalization()(temp)
temp = Dense(n_hidden_2, activation='relu')(temp)
temp = BatchNormalization()(temp)
temp = Dense(n_hidden_3, activation='relu')(temp)
temp = BatchNormalization()(temp)
out_put = Dense(n_output, activation='sigmoid')(temp)
model = Model(input_bits, out_put)
model.compile(optimizer='adam', loss='mse', metrics=[bit_err])
model.summary()

# Keep only the weights that achieve the lowest validation bit error.
checkpoint = callbacks.ModelCheckpoint(
    WEIGHTS_PATH,
    monitor='val_bit_err',
    verbose=0,
    save_best_only=True,
    mode='min',
    save_weights_only=True)
model.fit_generator(
    training_gen(1000, TRAIN_SNR_DB),
    steps_per_epoch=50,
    epochs=10000,
    validation_data=validation_gen(1000, TRAIN_SNR_DB),
    validation_steps=1,
    callbacks=[checkpoint],
    verbose=2)

# Reload the best checkpoint and measure the bit error rate at 5..25 dB.
model.load_weights(WEIGHTS_PATH)
BER = []
for SNR in range(5, 30, 5):
    y = model.evaluate(
        validation_gen(10000, SNR),
        steps=1
    )
    BER.append(y[1])  # y is [loss, bit_err]
    print(y)
print(BER)
BER_matlab = np.array(BER)
sio.savemat('BER.mat', {'BER': BER_matlab})
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DNN_detection_via_keras 2 | This is the simplest implementation of **Power of Deep Learning for Channel Estimation and Signal Detection in OFDM Systems** using Keras. I tried my best to simplify the code so that everyone can follow it easily. The original TensorFlow version of the code can be found [here](https://github.com/haoyye/OFDM_DNN). Compared with other frameworks (e.g., **tensorflow, pytorch**, **MXNet** and so on), this **keras-version** is the simplest realization. 3 | 4 | # Some reference 5 | In response to many readers' comments, I have written a short blog post about this paper, which may help Chinese researchers understand its main idea; 6 | you can find the blog at [blog address](https://blog.csdn.net/weixin_39274659/article/details/107748483) 7 | 8 | # First 9 | Some common problems are answered in the issues; hopefully they can help you. Besides, if this work helps you, please kindly star or fork the repo to support me. 10 | 11 | 12 | # Requirement 13 | tensorflow-gpu >= 1.12.0 14 | The code was written before the release of TensorFlow 2.0. 15 | 16 | # data sets 17 | I have uploaded the required data sets to [BaiduYun Drive](https://pan.baidu.com/s/16_hVoPErs4dV3LXtwPU-4w) 18 | 19 | password: **1234** 20 | 21 | As some readers requested, I have also provided a [download url](https://drive.google.com/drive/folders/1pwjEzmLZIybk3SWNAwo6hmzmUnd5Sgsf?usp=sharing) on Google Drive. 22 | 23 | These files were generated by saving the numpy arrays loaded from the originally provided .txt files. 24 | 25 | Then, move channel_train.npy and channel_test.npy directly into the current directory. Namely, the paths are 26 | './channel_train.npy' and './channel_test.npy'. 
27 | 28 | The original datasets are provided at https://github.com/haoyye/OFDM_DNN as .txt files, which take a long time to load. Therefore, I saved enough samples as .npy files, so that the training sets can be loaded easily and the file size is reduced. 29 | 30 | # How to use 31 | After downloading and moving the data sets, just run main.py directly. 32 | 33 | # Some evaluation 34 | Since this repo is just a reproduction, I follow the original idea of the authors: generate random initial bits, simulate the channel by loading data from the .npy files, and then build the neural network to recover the bits from the received signal. 35 | 36 | **I know some readers want to directly apply the detection neural network to replace their traditional receiver, for comparisons and so on. This is easy to do with this code. In brief, the code for generating data is not needed. You can just save the original bits and the received signal of your own system as a .mat file (if you use Matlab) or a .npy file. Then, load the data in Python and use the .fit function, where the original bits are the labels and the received signal is exactly the input of the network. You do not even need to simulate the channel (as you did that in your previous work, and only the received signal is required).** 37 | 38 | Sorry for my English. If you have any problems, please contact me via email. 39 | Hopefully this is helpful to you and, if possible, please star or fork this repo to support it. 
40 | -------------------------------------------------------------------------------- /Pilot_64: -------------------------------------------------------------------------------- 1 | 0.000000000000000000e+00 2 | 0.000000000000000000e+00 3 | 0.000000000000000000e+00 4 | 0.000000000000000000e+00 5 | 0.000000000000000000e+00 6 | 0.000000000000000000e+00 7 | 0.000000000000000000e+00 8 | 0.000000000000000000e+00 9 | 0.000000000000000000e+00 10 | 0.000000000000000000e+00 11 | 0.000000000000000000e+00 12 | 0.000000000000000000e+00 13 | 1.000000000000000000e+00 14 | 1.000000000000000000e+00 15 | 0.000000000000000000e+00 16 | 0.000000000000000000e+00 17 | 0.000000000000000000e+00 18 | 0.000000000000000000e+00 19 | 1.000000000000000000e+00 20 | 0.000000000000000000e+00 21 | 0.000000000000000000e+00 22 | 1.000000000000000000e+00 23 | 0.000000000000000000e+00 24 | 0.000000000000000000e+00 25 | 1.000000000000000000e+00 26 | 1.000000000000000000e+00 27 | 0.000000000000000000e+00 28 | 1.000000000000000000e+00 29 | 0.000000000000000000e+00 30 | 1.000000000000000000e+00 31 | 1.000000000000000000e+00 32 | 1.000000000000000000e+00 33 | 0.000000000000000000e+00 34 | 0.000000000000000000e+00 35 | 1.000000000000000000e+00 36 | 0.000000000000000000e+00 37 | 1.000000000000000000e+00 38 | 0.000000000000000000e+00 39 | 1.000000000000000000e+00 40 | 0.000000000000000000e+00 41 | 0.000000000000000000e+00 42 | 0.000000000000000000e+00 43 | 0.000000000000000000e+00 44 | 1.000000000000000000e+00 45 | 1.000000000000000000e+00 46 | 1.000000000000000000e+00 47 | 1.000000000000000000e+00 48 | 1.000000000000000000e+00 49 | 1.000000000000000000e+00 50 | 0.000000000000000000e+00 51 | 0.000000000000000000e+00 52 | 1.000000000000000000e+00 53 | 0.000000000000000000e+00 54 | 0.000000000000000000e+00 55 | 0.000000000000000000e+00 56 | 1.000000000000000000e+00 57 | 1.000000000000000000e+00 58 | 0.000000000000000000e+00 59 | 1.000000000000000000e+00 60 | 0.000000000000000000e+00 61 | 1.000000000000000000e+00 
62 | 0.000000000000000000e+00 63 | 0.000000000000000000e+00 64 | 0.000000000000000000e+00 65 | 1.000000000000000000e+00 66 | 0.000000000000000000e+00 67 | 1.000000000000000000e+00 68 | 1.000000000000000000e+00 69 | 0.000000000000000000e+00 70 | 0.000000000000000000e+00 71 | 1.000000000000000000e+00 72 | 1.000000000000000000e+00 73 | 1.000000000000000000e+00 74 | 0.000000000000000000e+00 75 | 0.000000000000000000e+00 76 | 0.000000000000000000e+00 77 | 1.000000000000000000e+00 78 | 1.000000000000000000e+00 79 | 1.000000000000000000e+00 80 | 0.000000000000000000e+00 81 | 0.000000000000000000e+00 82 | 1.000000000000000000e+00 83 | 1.000000000000000000e+00 84 | 1.000000000000000000e+00 85 | 0.000000000000000000e+00 86 | 0.000000000000000000e+00 87 | 1.000000000000000000e+00 88 | 1.000000000000000000e+00 89 | 0.000000000000000000e+00 90 | 0.000000000000000000e+00 91 | 1.000000000000000000e+00 92 | 0.000000000000000000e+00 93 | 0.000000000000000000e+00 94 | 0.000000000000000000e+00 95 | 1.000000000000000000e+00 96 | 1.000000000000000000e+00 97 | 1.000000000000000000e+00 98 | 0.000000000000000000e+00 99 | 1.000000000000000000e+00 100 | 0.000000000000000000e+00 101 | 0.000000000000000000e+00 102 | 0.000000000000000000e+00 103 | 1.000000000000000000e+00 104 | 0.000000000000000000e+00 105 | 1.000000000000000000e+00 106 | 1.000000000000000000e+00 107 | 1.000000000000000000e+00 108 | 1.000000000000000000e+00 109 | 0.000000000000000000e+00 110 | 0.000000000000000000e+00 111 | 0.000000000000000000e+00 112 | 0.000000000000000000e+00 113 | 0.000000000000000000e+00 114 | 0.000000000000000000e+00 115 | 1.000000000000000000e+00 116 | 0.000000000000000000e+00 117 | 0.000000000000000000e+00 118 | 1.000000000000000000e+00 119 | 1.000000000000000000e+00 120 | 0.000000000000000000e+00 121 | 1.000000000000000000e+00 122 | 0.000000000000000000e+00 123 | 1.000000000000000000e+00 124 | 1.000000000000000000e+00 125 | 0.000000000000000000e+00 126 | 1.000000000000000000e+00 127 | 
"""Shared OFDM parameters, network sizes and link-simulation helpers.

Importing this module loads the pilot bit sequence from ``Pilot_<P>`` in the
current directory, or generates and saves a new random one if the file does
not exist.
"""
import os

import numpy as np

K = 64  # number of subcarriers per OFDM symbol
CP = K // 4  # cyclic-prefix length in samples
P = 64  # number of pilot subcarriers
allCarriers = np.arange(K)  # indices of all subcarriers ([0, 1, ... K-1])

if P < K:
    pilotCarriers = allCarriers[::K // P]  # Pilots is every (K/P)th carrier.
    dataCarriers = np.delete(allCarriers, pilotCarriers)
else:  # K = P
    pilotCarriers = allCarriers
    # An explicit integer index array (instead of a bare list) keeps fancy
    # indexing well-defined when there are no data carriers.
    dataCarriers = np.array([], dtype=int)

mu = 2  # bits per modulated symbol
payloadBits_per_OFDM = K * mu
SNRdb = 25
H_folder_train = '../H_dataset/Train/'
H_folder_test = '../H_dataset/Test/'
n_hidden_1 = 500  # width of the first hidden layer
n_hidden_2 = 250  # width of the second hidden layer
n_hidden_3 = 120  # width of the third hidden layer
n_output = 16  # every 16 bit are predicted by a model


def Modulation(bits):
    """Map a flat 0/1 sequence to complex symbols, ``mu`` bits per symbol.

    The first bit of each group sets the real part and the second the
    imaginary part (0 -> -1, 1 -> +1), so the mapping is only valid for
    ``mu == 2``.
    """
    bits = np.asarray(bits)
    # Integer division avoids the float round-trip of int(len(bits) / mu).
    bit_r = bits.reshape((bits.size // mu, mu))
    # This is just for QAM modulation
    return (2 * bit_r[:, 0] - 1) + 1j * (2 * bit_r[:, 1] - 1)


def OFDM_symbol(Data, pilot_flag):
    """Build one frequency-domain OFDM symbol from pilots and *Data*.

    ``pilot_flag`` is accepted for interface compatibility but is not used.
    """
    symbol = np.zeros(K, dtype=complex)  # the overall K subcarriers
    symbol[pilotCarriers] = pilotValue  # allocate the pilot subcarriers
    symbol[dataCarriers] = Data  # allocate the data subcarriers
    return symbol


def IDFT(OFDM_data):
    """Return the inverse FFT of *OFDM_data*."""
    return np.fft.ifft(OFDM_data)


def addCP(OFDM_time):
    """Prepend the last ``CP`` samples of *OFDM_time* as a cyclic prefix."""
    cp = OFDM_time[-CP:]  # take the last CP samples ...
    return np.hstack([cp, OFDM_time])  # ... and add them to the beginning


def channel(signal, channelResponse, SNRdb):
    """Convolve *signal* with *channelResponse* and add complex Gaussian noise.

    The noise variance is set from the mean power of the convolved signal so
    that the result has the requested ``SNRdb``.
    """
    convolved = np.convolve(signal, channelResponse)
    signal_power = np.mean(np.abs(convolved) ** 2)
    sigma2 = signal_power * 10 ** (-SNRdb / 10)
    # Half of the noise power goes to the real part and half to the imaginary.
    noise = np.sqrt(sigma2 / 2) * (
        np.random.randn(*convolved.shape)
        + 1j * np.random.randn(*convolved.shape))
    return convolved + noise


def removeCP(signal):
    """Drop the cyclic prefix and keep the following ``K`` samples."""
    return signal[CP:(CP + K)]


def DFT(OFDM_RX):
    """Return the FFT of *OFDM_RX*."""
    return np.fft.fft(OFDM_RX)


def _through_channel(freq_symbol, channelResponse, SNRdb):
    """Send one frequency-domain symbol over the channel and demodulate it."""
    tx_time = addCP(IDFT(freq_symbol))
    rx_time = channel(tx_time, channelResponse, SNRdb)
    return DFT(removeCP(rx_time))


def ofdm_simulate(codeword, channelResponse, SNRdb):
    """Simulate a pilot symbol followed by a payload symbol over one channel.

    Args:
        codeword: flat 0/1 array with ``K * mu`` payload bits.
        channelResponse: channel impulse response taps.
        SNRdb: signal-to-noise ratio in dB for both transmissions.

    Returns:
        A tuple ``(features, abs(channelResponse))`` where ``features`` is the
        length ``4 * K`` concatenation of the real and imaginary parts of the
        received pilot symbol followed by those of the received payload
        symbol.
    """
    # Pilot symbol: known pilots plus random filler on any data carriers.
    filler_bits = np.random.binomial(n=1, p=0.5, size=(mu * (K - P),))
    pilot_symbol = np.zeros(K, dtype=complex)
    pilot_symbol[pilotCarriers] = pilotValue
    pilot_symbol[dataCarriers] = Modulation(filler_bits)
    rx_pilot = _through_channel(pilot_symbol, channelResponse, SNRdb)

    # ----- target inputs ---
    # Payload symbol: the codeword occupies every subcarrier.
    payload_symbol = np.zeros(K, dtype=complex)
    payload_symbol[np.arange(K)] = Modulation(codeword)
    rx_payload = _through_channel(payload_symbol, channelResponse, SNRdb)

    features = np.concatenate((
        np.real(rx_pilot), np.imag(rx_pilot),
        np.real(rx_payload), np.imag(rx_payload)))
    return features, abs(channelResponse)


Pilot_file_name = 'Pilot_' + str(P)
if os.path.isfile(Pilot_file_name):
    print('Load Training Pilots txt')
    # load file
    bits = np.loadtxt(Pilot_file_name, delimiter=',')
else:
    # write file
    # One symbol per pilot carrier needs P * mu bits; the previous K * mu
    # only matched when P == K and broke pilot allocation for P < K.
    bits = np.random.binomial(n=1, p=0.5, size=(P * mu,))
    np.savetxt(Pilot_file_name, bits, delimiter=',')

if np.size(bits) != P * mu:
    # Fail early with a clear message instead of a shape error at first use.
    raise ValueError(
        'Pilot file %r holds %d bits, expected %d'
        % (Pilot_file_name, np.size(bits), P * mu))

pilotValue = Modulation(bits)