├── Readme.md ├── __pycache__ ├── model.cpython-35.pyc └── settings.cpython-35.pyc ├── model_save_10 ├── DetNetmodel.cpkt.meta ├── DetNetmodel.cpkt.index ├── DetNetmodel.cpkt.data-00000-of-00001 ├── events.out.tfevents.1553240081.DESKTOP-OLT95JR ├── events.out.tfevents.1553737434.DESKTOP-OLT95JR ├── events.out.tfevents.1553737477.DESKTOP-OLT95JR └── checkpoint ├── model_save_20 ├── DetNetmodel.cpkt.meta ├── DetNetmodel.cpkt.index ├── DetNetmodel.cpkt.data-00000-of-00001 ├── events.out.tfevents.1553240072.DESKTOP-20KBE6K ├── events.out.tfevents.1553241096.DESKTOP-20KBE6K ├── events.out.tfevents.1553241140.DESKTOP-20KBE6K ├── events.out.tfevents.1553241187.DESKTOP-20KBE6K └── checkpoint ├── model_save_30 ├── DetNetmodel.cpkt.meta ├── DetNetmodel.cpkt.index ├── DetNetmodel.cpkt.data-00000-of-00001 ├── events.out.tfevents.1553651245.DESKTOP-OLT95JR ├── events.out.tfevents.1553240273.xiaoliang-X11DAi-N └── checkpoint ├── model_save_100 ├── DetNetmodel.cpkt.index ├── DetNetmodel.cpkt.meta ├── DetNetmodel.cpkt.data-00000-of-00001 ├── events.out.tfevents.1553654009.xiaoliang-X11DAi-N └── checkpoint ├── settings.py ├── ber_plot.py ├── DetNet_test.py ├── code_test.py ├── DetNet_train.py └── model.py /Readme.md: -------------------------------------------------------------------------------- 1 | # DeepMIMODetection 2 | 基于深度学习的MIMI信号检测网络 3 | -------------------------------------------------------------------------------- /__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/settings.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/__pycache__/settings.cpython-35.pyc -------------------------------------------------------------------------------- /model_save_10/DetNetmodel.cpkt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/DetNetmodel.cpkt.meta -------------------------------------------------------------------------------- /model_save_20/DetNetmodel.cpkt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/DetNetmodel.cpkt.meta -------------------------------------------------------------------------------- /model_save_30/DetNetmodel.cpkt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_30/DetNetmodel.cpkt.meta -------------------------------------------------------------------------------- /model_save_10/DetNetmodel.cpkt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/DetNetmodel.cpkt.index -------------------------------------------------------------------------------- /model_save_100/DetNetmodel.cpkt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_100/DetNetmodel.cpkt.index -------------------------------------------------------------------------------- /model_save_100/DetNetmodel.cpkt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_100/DetNetmodel.cpkt.meta -------------------------------------------------------------------------------- /model_save_20/DetNetmodel.cpkt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/DetNetmodel.cpkt.index -------------------------------------------------------------------------------- /model_save_30/DetNetmodel.cpkt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_30/DetNetmodel.cpkt.index -------------------------------------------------------------------------------- /model_save_10/DetNetmodel.cpkt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/DetNetmodel.cpkt.data-00000-of-00001 -------------------------------------------------------------------------------- /model_save_20/DetNetmodel.cpkt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/DetNetmodel.cpkt.data-00000-of-00001 -------------------------------------------------------------------------------- /model_save_30/DetNetmodel.cpkt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_30/DetNetmodel.cpkt.data-00000-of-00001 -------------------------------------------------------------------------------- /model_save_100/DetNetmodel.cpkt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_100/DetNetmodel.cpkt.data-00000-of-00001 -------------------------------------------------------------------------------- /model_save_10/events.out.tfevents.1553240081.DESKTOP-OLT95JR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/events.out.tfevents.1553240081.DESKTOP-OLT95JR -------------------------------------------------------------------------------- /model_save_10/events.out.tfevents.1553737434.DESKTOP-OLT95JR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/events.out.tfevents.1553737434.DESKTOP-OLT95JR -------------------------------------------------------------------------------- /model_save_10/events.out.tfevents.1553737477.DESKTOP-OLT95JR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_10/events.out.tfevents.1553737477.DESKTOP-OLT95JR -------------------------------------------------------------------------------- /model_save_20/events.out.tfevents.1553240072.DESKTOP-20KBE6K: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/events.out.tfevents.1553240072.DESKTOP-20KBE6K -------------------------------------------------------------------------------- /model_save_20/events.out.tfevents.1553241096.DESKTOP-20KBE6K: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/events.out.tfevents.1553241096.DESKTOP-20KBE6K -------------------------------------------------------------------------------- /model_save_20/events.out.tfevents.1553241140.DESKTOP-20KBE6K: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/events.out.tfevents.1553241140.DESKTOP-20KBE6K -------------------------------------------------------------------------------- /model_save_20/events.out.tfevents.1553241187.DESKTOP-20KBE6K: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_20/events.out.tfevents.1553241187.DESKTOP-20KBE6K -------------------------------------------------------------------------------- /model_save_30/events.out.tfevents.1553651245.DESKTOP-OLT95JR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_30/events.out.tfevents.1553651245.DESKTOP-OLT95JR -------------------------------------------------------------------------------- /model_save_100/events.out.tfevents.1553654009.xiaoliang-X11DAi-N: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_100/events.out.tfevents.1553654009.xiaoliang-X11DAi-N -------------------------------------------------------------------------------- /model_save_30/events.out.tfevents.1553240273.xiaoliang-X11DAi-N: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hello-zhanghao/DeepMIMODetection/HEAD/model_save_30/events.out.tfevents.1553240273.xiaoliang-X11DAi-N -------------------------------------------------------------------------------- /model_save_100/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/home/haozhang/DeepMIMODetection_V3.3/model_save_100/DetNetmodel.cpkt" 2 | all_model_checkpoint_paths: "/home/haozhang/DeepMIMODetection_V3.3/model_save_100/DetNetmodel.cpkt" 3 | -------------------------------------------------------------------------------- /model_save_10/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.4.1/model_save_10/DetNetmodel.cpkt" 2 | all_model_checkpoint_paths: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.4.1/model_save_10/DetNetmodel.cpkt" 3 | -------------------------------------------------------------------------------- /model_save_20/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.3/model_save_20/DetNetmodel.cpkt" 2 | all_model_checkpoint_paths: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.3/model_save_20/DetNetmodel.cpkt" 3 | -------------------------------------------------------------------------------- /model_save_30/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.3/model_save_30/DetNetmodel.cpkt" 2 | all_model_checkpoint_paths: "C:\\Users\\zfzdr\\Desktop\\DeepMIMODetection_V3.3/model_save_30/DetNetmodel.cpkt" 3 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Mar 22 13:21:11 2019 4 | 5 | @author: zfzdr 6 | """ 7 | import os 8 | class Settings(): 9 | def __init__(self): 10 | self.Nr = 16 11 | self.Nt = 8 12 | self.mod_Name = 'qpsk' 13 | self.M = 4 14 | self.batchsize = 2000 15 | self.maxEpochs = 30000 16 | self.Learning_rate = 0.001 17 | self.L_size = 10 18 | self.savedir = os.getcwd() + '/model_save_%d/'%self.L_size 19 | self.batchsize_test = 3000 20 | self.maxEpochs_test = 200 21 | self.SNR_st = 1 22 | self.SNR_ed = 16 -------------------------------------------------------------------------------- /ber_plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 7 15:42:59 2019 4 | 5 | @author: zfzdr 6 | """ 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | #t = np.arange(1,16,1) 10 | #plt.plot(t, ber_DetNet[0], label='ber_DetNet') 11 | #plt.plot(t, ber_ZF[0], label='ber_ZF') 12 | #plt.yscale('log') 13 | #plt.legend() 14 | #plt.grid(True) 15 | 16 | ##layer = 10 17 | #ber_DetNet_10 = [0.09139271, 0.06937187, 0.05122896, 0.03583354, 0.02316281, 0.01518677, 18 | # 0.00859656, 0.00550948, 0.00318521, 0.00180083, 0.00140073, 0.00085437, 19 | # 0.00056792, 0.0005126, 0.00025667] 20 | #ber_ZF_10 = [1.23497812e-01, 9.89660417e-02, 7.57281250e-02, 5.39903125e-02, 21 | # 3.70536458e-02, 2.41823958e-02, 1.32711458e-02, 7.92520833e-03, 22 | # 3.94031250e-03, 1.64312500e-03, 6.23437500e-04, 2.64062500e-04, 23 | # 7.62500000e-05, 2.90625000e-05, 2.60416667e-06] 24 | ##ZF_detector time is:540.4367575645447 25 | ##DetNet_detector time is:557.672048330307 26 | ##Total time is: 1339.3493297100067 27 | ## 28 | ##layer = 20 29 | #ber_DetNet_20 = [8.91815625e-02, 6.51067708e-02, 4.54131250e-02, 2.73276042e-02, 30 | # 1.53232292e-02, 8.01062500e-03, 3.68093750e-03, 1.52062500e-03, 31 | # 6.85520833e-04, 2.50625000e-04, 1.00416667e-04, 4.89583333e-05, 32 | # 2.09375000e-05, 6.97916667e-06, 3.64583333e-06] 33 | #ber_ZF_20 = [1.25135521e-01, 9.81607292e-02, 7.69308333e-02, 5.38016667e-02, 34 | # 3.67066667e-02, 2.39352083e-02, 1.37806250e-02, 7.58062500e-03, 35 | # 3.83552083e-03, 1.71583333e-03, 6.51562500e-04, 2.64687500e-04, 36 | # 7.39583333e-05, 1.60416667e-05, 4.37500000e-06] 37 | ##ZF_detector time is:580.1622507572174 38 | ##DetNet_detector time is:618.585266828537 39 | ##Total time is: 1457.339527130127 40 | # 41 | ##layer=30 42 | #ber_DetNet_30 = [9.29106250e-02, 6.85543750e-02, 4.75781250e-02, 3.03684375e-02, 43 | # 1.75051042e-02, 9.55927083e-03, 4.37552083e-03, 1.78520833e-03, 44 | # 8.18020833e-04, 2.80208333e-04, 9.31250000e-05, 3.09375000e-05, 45 | # 1.20833333e-05, 4.37500000e-06, 2.91666667e-06,] 46 | #ber_ZF_30 = [1.24593646e-01, 9.80235417e-02, 7.54879167e-02, 5.41806250e-02, 47 | # 3.64765625e-02, 2.38037500e-02, 1.41215625e-02, 7.23135417e-03, 48 | # 3.97687500e-03, 1.75166667e-03, 6.48229167e-04, 2.10729167e-04, 49 | # 8.35416667e-05, 1.20833333e-05, 8.64583333e-06] 50 | ##ZF_detector time is:612.1121361255646 51 | ##DetNet_detector time is:681.0219402313232 52 | ##Total time is: 1550.7440061569214 53 | 54 | t =np.arange(SNR_st+6, SNR_ed-2, 1) 55 | plt.plot(t, ber_ZF[6:13], label='ber_ZF') 56 | plt.plot(t, ber_DetNet_10[6:13], label='ber_DetNet_10') 57 | plt.plot(t, ber_DetNet_20[6:13], label='ber_DetNet_20') 58 | plt.plot(t, ber_DetNet_30[6:13], label='ber_DetNet_30') 59 | plt.plot(t, ber_DetNet_100[6:13], label='ber_DetNet_100') 60 | plt.yscale('log') 61 | plt.legend() 62 | plt.grid(True) -------------------------------------------------------------------------------- /DetNet_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Mar 22 10:42:14 2019 4 | 5 | @author: zfzdr 6 | """ 7 | #*********************************************************************** 8 | #*********************************************************************** 9 | 10 | import tensorflow as tf 11 | import time as tm 12 | import datetime 13 | 14 | #自定义模块 15 | from model import * 16 | from settings import Settings 17 | 18 | # Define the parameter of the mode 19 | DetNet_settings = Settings() 20 | Nr = DetNet_settings.Nr #number of receive antennas 21 | Nt = DetNet_settings.Nt #number of transmit antenns 22 | mod_Name = DetNet_settings.mod_Name 23 | batchsize_test = DetNet_settings.batchsize_test 24 | maxEpochs_test = DetNet_settings.maxEpochs_test 25 | M = DetNet_settings.M 26 | SNR_st = DetNet_settings.SNR_st 27 | SNR_ed = DetNet_settings.SNR_ed 28 | 29 | 30 | numbits_perSymbol = int(np.log2(M)) 31 | source_length = batchsize_test * maxEpochs_test * Nt * numbits_perSymbol 32 | 33 | loc1 = os.getcwd() + '/model_save_10/DetNetmodel.cpkt' 34 | loc2 = os.getcwd() + '/model_save_20/DetNetmodel.cpkt' 35 | loc3 = os.getcwd() + '/model_save_30/DetNetmodel.cpkt' 36 | loc4 = os.getcwd() + '/model_save_100/DetNetmodel.cpkt' 37 | 38 | predict1 = Predict(loc1) 39 | predict2 = Predict(loc2) 40 | predict3 = Predict(loc3) 41 | predict4 = Predict(loc4) 42 | 43 | time_ZF = 0 44 | time_DetNet_10 = 0 45 | time_DetNet_20 = 0 46 | time_DetNet_30 = 0 47 | time_DetNet_100 = 0 48 | 49 | N = SNR_ed - SNR_st 50 | ber_ZF = np.zeros(N) 51 | ber_DetNet_10 = np.zeros(N) 52 | ber_DetNet_20 = np.zeros(N) 53 | ber_DetNet_30 = np.zeros(N) 54 | ber_DetNet_100 = np.zeros(N) 55 | 56 | time_st = tm.time() 57 | for SNR in np.arange(SNR_st, SNR_ed, 1): 58 | now_time = datetime.datetime.now().strftime('%X') 59 | print(SNR, now_time) 60 | 61 | biterrorsum_ZF = 0 62 | biterrorsum_DetNet_10 = 0 63 | biterrorsum_DetNet_20 = 0 64 | biterrorsum_DetNet_30 = 0 65 | biterrorsum_DetNet_100 = 0 66 | for i in range(maxEpochs_test): 67 | source_test, x_test, y_test, H_test = generate_testdata( 68 | batchsize_test, Nt, Nr, mod_Name, SNR) 69 | 70 | time1 = tm.time() 71 | x_ZF = predict1.ZF_detector(y_test, H_test) 72 | time2 = tm.time() 73 | time_ZF = time_ZF + time2 - time1 74 | 75 | time1 = tm.time() 76 | x_DetNet_10 = predict1.DetNet_detector(y_test, H_test) 77 | time2 = tm.time() 78 | time_DetNet_10 = time_DetNet_10 + time2 - time1 79 | 80 | time1 = tm.time() 81 | x_DetNet_20 = predict2.DetNet_detector(y_test, H_test) 82 | time2 = tm.time() 83 | time_DetNet_20 = time_DetNet_20 + time2 - time1 84 | 85 | time1 = tm.time() 86 | x_DetNet_30 = predict3.DetNet_detector(y_test, H_test) 87 | time2 = tm.time() 88 | time_DetNet_30 = time_DetNet_30 + time2 - time1 89 | 90 | time1 = tm.time() 91 | x_DetNet_100 = predict4.DetNet_detector(y_test, H_test) 92 | time2 = tm.time() 93 | time_DetNet_100 = time_DetNet_100 + time2 - time1 94 | 95 | 96 | biterrorsum_ZF = biterrorsum_ZF + bit_error( 97 | source_test, x_ZF, Nt, batchsize_test, mod_Name) 98 | biterrorsum_DetNet_10 = biterrorsum_DetNet_10 + bit_error( 99 | source_test, x_DetNet_10, Nt, batchsize_test, mod_Name) 100 | biterrorsum_DetNet_20 = biterrorsum_DetNet_20 + bit_error( 101 | source_test, x_DetNet_20, Nt, batchsize_test, mod_Name) 102 | biterrorsum_DetNet_30 = biterrorsum_DetNet_30 + bit_error( 103 | source_test, x_DetNet_30, Nt, batchsize_test, mod_Name) 104 | biterrorsum_DetNet_100 = biterrorsum_DetNet_100 + bit_error( 105 | source_test, x_DetNet_100, Nt, batchsize_test, mod_Name) 106 | 107 | 108 | #计算不同误码率下ZF检测和DetNet检测的性能 109 | ber_ZF[SNR-SNR_st] = biterrorsum_ZF / source_length 110 | ber_DetNet_10[SNR-SNR_st] = biterrorsum_DetNet_10 / source_length 111 | ber_DetNet_20[SNR-SNR_st] = biterrorsum_DetNet_20 / source_length 112 | ber_DetNet_30[SNR-SNR_st] = biterrorsum_DetNet_30 / source_length 113 | ber_DetNet_100[SNR-SNR_st] = biterrorsum_DetNet_100 / source_length 114 | 115 | time_ed = tm.time() 116 | 117 | print('ber_ZF:', ber_ZF) 118 | print('ber_DetNet_10:', ber_DetNet_10) 119 | print('ber_DetNet_20:', ber_DetNet_20) 120 | print('ber_DetNet_30:', ber_DetNet_30) 121 | print('ber_DetNet_100:', ber_DetNet_100) 122 | 123 | 124 | print('ZF_detector time is:'+str(time_ZF)) 125 | print('DetNet_10_detector time is:'+str(time_DetNet_10)) 126 | print('DetNet_20_detector time is:'+str(time_DetNet_20)) 127 | print('DetNet_30_detector time is:'+str(time_DetNet_30)) 128 | print('DetNet_100_detector time is:'+str(time_DetNet_100)) 129 | 130 | print("Total time is: "+str(time_ed-time_st)) 131 | 132 | 133 | -------------------------------------------------------------------------------- /code_test.py: -------------------------------------------------------------------------------- 1 | #from tensorflow.python.tools.inspect_checkpoint import \ 2 | #print_tensors_in_checkpoint_file as pt 3 | # 4 | #from tensorflow.python import pywrap_tensorflow 5 | # 6 | #savedir = os.getcwd() + '/model_save_%d/'%5 7 | #savefile = savedir+'/DetNetmodel.cpkt' 8 | #pt(savedir+'DetNetmodel.cpkt', None, True, True) 9 | ##reader = pywrap_tensorflow.NewCheckpointReader(savefile) 10 | ## 11 | ##var_to_shape_map = reader.get_variable_to_shape_map() 12 | ## 13 | ##for key in var_to_shape_map: 14 | ## 15 | ## print("tensor_name: ", DetNet_V3.2g 16 | import tensorflow as tf 17 | import time as tm 18 | import datetime 19 | import os 20 | 21 | #自定义模块 22 | from model import * 23 | from settings import Settings 24 | 25 | 26 | # Define the parameter of the mode 27 | DetNet_settings = Settings() 28 | Nr = DetNet_settings.Nr #number of receive antennas 29 | Nt = DetNet_settings.Nt #number of transmit antenns 30 | mod_Name = DetNet_settings.mod_Name 31 | batchsize_test = DetNet_settings.batchsize_test 32 | maxEpochs_test = DetNet_settings.maxEpochs_test 33 | #savedir = DetNet_settings.savedir 34 | #savedir = savedir.replace('\\','/') 35 | #savedir = r'C:\Users\zfzdr\Desktop\DeepMIMODetection_V3.3\model_save_30\\' 36 | M = DetNet_settings.M 37 | 38 | 39 | numbits_perSymbol = int(np.log2(M)) 40 | source_length = batchsize_test * maxEpochs_test * Nt * numbits_perSymbol 41 | 42 | 43 | c1 = tf.constant(1) 44 | print(c1.graph) 45 | g1 = tf.Graph() 46 | g2 = tf.Graph() 47 | g3 = tf.Graph() 48 | print(g1) 49 | print(g2) 50 | print(g3) 51 | sess = tf.Session(graph=g1) 52 | loc1 = os.getcwd() + '/model_save_10/DetNetmodel.cpkt' 53 | loc2 = os.getcwd() + '/model_save_20/DetNetmodel.cpkt' 54 | predict1 = Predict(loc1) 55 | predict2 = Predict(loc2) 56 | time_ZF = 0 57 | time_DetNet = 0 58 | ber_DetNet_10 = np.zeros((1,15)) 59 | ber_DetNet_20 = np.zeros((1,15)) 60 | ber_ZF = np.zeros((1,15)) 61 | time_st = tm.time() 62 | for SNR in range(1,16): 63 | biterrorsum_DetNet_10 = 0 64 | biterrorsum_DetNet_20 = 0 65 | biterrorsum_ZF = 0 66 | for i in range(maxEpochs_test): 67 | source_test, x_test, y_test, H_test = generate_testdata( 68 | batchsize_test, Nt, Nr, mod_Name, SNR) 69 | # #计算ZF检测的误码率和时间复杂度 70 | # time_ZF_st = tm.time() 71 | # x_ZF = sess.run(ZF_detector,feed_dict={y:y_test, H:H_test}) 72 | # biterrorsum_ZF = biterrorsum_ZF + bit_error( 73 | # source_test, x_ZF, Nt, batchsize_test, mod_Name) 74 | # time_ZF_ed = tm.time() 75 | # time_ZF = time_ZF + time_ZF_ed - time_ZF_st 76 | # #计算DetNet检测的误码率和时间复杂度 77 | 78 | time_DetNet_st = tm.time() 79 | x_DetNet_10 = predict1.DetNet_detector(y_test, H_test) 80 | x_DetNet_20 = predict2.DetNet_detector(y_test, H_test) 81 | x_ZF = predict1.ZF_detector(y_test, H_test) 82 | 83 | biterrorsum_DetNet_10 = biterrorsum_DetNet_10 + bit_error( 84 | source_test, x_DetNet_10, Nt, batchsize_test, mod_Name) 85 | biterrorsum_DetNet_20 = biterrorsum_DetNet_20 + bit_error( 86 | source_test, x_DetNet_20, Nt, batchsize_test, mod_Name) 87 | biterrorsum_ZF = biterrorsum_ZF + bit_error( 88 | source_test, x_ZF, Nt, batchsize_test, mod_Name) 89 | 90 | time_DetNet_ed = tm.time() 91 | time_DetNet = time_DetNet + time_DetNet_ed - time_DetNet_st 92 | # 计算不同误码率下ZF检测和DetNet检测的性能 93 | ber_DetNet_10[:, SNR-1] = biterrorsum_DetNet_10 / source_length 94 | ber_DetNet_20[:, SNR-1] = biterrorsum_DetNet_20 / source_length 95 | ber_ZF[:, SNR-1] = biterrorsum_ZF / source_length 96 | 97 | print('ber_DetNet_10:', ber_DetNet_10[0]) 98 | print('ber_DetNet_20:', ber_DetNet_20[0]) 99 | print('ber_ZF:', ber_ZF) 100 | 101 | time_ed = tm.time() 102 | print("Total time is: "+str(time_ed-time_st)) 103 | 104 | 105 | #with g1.as_default(): 106 | # saver = tf.train.import_meta_graph(loc + '.meta') 107 | # saver.restore(sess, loc) 108 | # time_ZF = 0 109 | # time_DetNet = 0 110 | # ber_DetNet_10 = np.zeros((1,15)) 111 | # ber_ZF_10 = np.zeros((1,15)) 112 | # for SNR in range(1,16): 113 | # biterrorsum_DetNet = 0 114 | # biterrorsum_ZF = 0 115 | # for i in range(maxEpochs_test): 116 | # source_test, x_test, y_test, H_test = generate_testdata( 117 | # batchsize_test, Nt, Nr, mod_Name, SNR) 118 | # #计算ZF检测的误码率和时间复杂度 119 | # time_ZF_st = tm.time() 120 | # x_ZF = sess.run(ZF_detector,feed_dict={y:y_test, H:H_test}) 121 | # biterrorsum_ZF = biterrorsum_ZF + bit_error( 122 | # source_test, x_ZF, Nt, batchsize_test, mod_Name) 123 | # time_ZF_ed = tm.time() 124 | # time_ZF = time_ZF + time_ZF_ed - time_ZF_st 125 | # #计算DetNet检测的误码率和时间复杂度 126 | # time_DetNet_st = tm.time() 127 | # x_DetNet = sess.run(DetNet_detector,feed_dict={y:y_test,H:H_test}) 128 | # biterrorsum_DetNet = biterrorsum_DetNet + bit_error( 129 | # source_test, x_DetNet, Nt, batchsize_test, mod_Name) 130 | # time_DetNet_ed = tm.time() 131 | # time_DetNet = time_DetNet + time_DetNet_ed - time_DetNet_st 132 | # #计算不同误码率下ZF检测和DetNet检测的性能 133 | # ber_DetNet_10[:, SNR-1] = biterrorsum_DetNet / source_length 134 | # ber_ZF_20[:, SNR-1] = biterrorsum_ZF/source_length 135 | # 136 | # 137 | # 138 | # 139 | #time_ed = tm.time() 140 | #print("Total time is: "+str(time_ed-time_st)) 141 | ##with tf.Session as sess: 142 | ## print(g1) 143 | ## print(g2) 144 | # print(g3) 145 | -------------------------------------------------------------------------------- /DetNet_train.py: -------------------------------------------------------------------------------- 1 | ### -*- coding: utf-8 -*- 2 | ##""" 3 | ##Created on Wed Feb 27 11:05:05 2019 4 | ## 5 | ##@author: zfzdr 6 | ##""" 7 | #*********************************************************************** 8 | #*********************************************************************** 9 | 10 | 11 | #from __future__ import division 12 | import tensorflow as tf 13 | #import numpy as np 14 | #import math as ms 15 | import time as tm 16 | #import scipy.io as scio 17 | import datetime 18 | #自定义模块 19 | from model import * 20 | from settings import Settings 21 | 22 | 23 | # Define the parameter of the mode 24 | DetNet_settings = Settings() 25 | Nr = DetNet_settings.Nr #number of receive antennas 26 | Nt = DetNet_settings.Nt #number of transmit antenns 27 | mod_Name = DetNet_settings.mod_Name 28 | batchsize = DetNet_settings.batchsize 29 | maxEpochs = DetNet_settings.maxEpochs 30 | Learning_rate = DetNet_settings.Learning_rate 31 | L_size = DetNet_settings.L_size 32 | savedir = DetNet_settings.savedir 33 | 34 | # Define the Deep MIMO Detection model of the system 35 | tf.reset_default_graph() 36 | # 1 Define the placeholder 37 | x = tf.placeholder(tf.float32, shape=(2 * Nt, None), name = 'transmit') 38 | y = tf.placeholder(tf.float32, shape=(2 * Nr, None), name = 'receiver') 39 | H = tf.placeholder(tf.float32, shape=(2*Nr, 2*Nt), name ='channel') 40 | 41 | # 2 Define the parameters of the network 42 | para_dic = {} 43 | for i in range(L_size): 44 | para_dic['w{}'.format(i+1)] = tf.Variable(tf.truncated_normal( 45 | [2 * Nt, 6 * Nt], stddev=1, seed=1, dtype='float32')) 46 | 47 | para_dic['b{}'.format(i+1)] = tf.Variable(tf.constant(0.001, 48 | shape=[2*Nt, 1], dtype='float32')) 49 | 50 | para_dic['t{}'.format(i+1)]=tf.Variable(tf.constant(0.5, dtype='float32')) 51 | 52 | 53 | # 3 Define the process of the forward propagation 54 | output_dic = {} 55 | for k in range(L_size): 56 | if k == 0: 57 | w = para_dic['w{}'.format(k+1)] 58 | b = para_dic['b{}'.format(k+1)] 59 | t = para_dic['t{}'.format(k+1)] 60 | Stack1 = tf.matmul(tf.transpose(H), y) 61 | s = tf.shape(Stack1) 62 | Stack2 = tf.zeros([2 * Nt, s[1]], dtype='float32') 63 | Stack3 = tf.matmul(tf.matmul(tf.transpose(H), H), Stack2) 64 | Stack = tf.concat([Stack1, Stack2, Stack3], 0) 65 | Z = tf.matmul(w, Stack) + b 66 | x_EST = -1 + tf.nn.relu(Z + t)/abs(t) - tf.nn.relu(Z - t)/abs(t) 67 | output_dic['x_EST_{}'.format(k+1)] = x_EST 68 | else: 69 | w = para_dic['w{}'.format(k+1)] 70 | b = para_dic['b{}'.format(k+1)] 71 | t = para_dic['t{}'.format(k+1)] 72 | Stack1 = tf.matmul(tf.transpose(H), y) 73 | Stack2 = output_dic['x_EST_{}'.format(k)] 74 | Stack3 = tf.matmul(tf.matmul(tf.transpose(H), H), Stack2) 75 | Stack = tf.concat([Stack1, Stack2, Stack3], 0) 76 | Z = tf.matmul(w, Stack) + b 77 | x_EST = -1 + tf.nn.relu(Z + t)/abs(t) - tf.nn.relu(Z - t)/abs(t) 78 | output_dic['x_EST_{}'.format(k+1)] = x_EST 79 | #在使用模型时,我们通常需要知道该模型的输入输出,所以记得将输出添加到一个集合中 80 | tf.add_to_collection('DetNet_detector',output_dic['x_EST_{}'.format(k+1)] ) 81 | DetNet_detector = tf.add(output_dic['x_EST_{}'.format(k+1)], 0, 82 | name='DetNet_detector') 83 | 84 | # 4 Define the loss function and backpropagation algorithm 85 | xwave_part1 = tf.matrix_inverse(tf.matmul(tf.transpose(H), H)) 86 | xwave_part2 = tf.matmul(xwave_part1, tf.transpose(H)) 87 | ZF_detector = tf.matmul(xwave_part2, y, name='ZF_detector') 88 | 89 | loss = 0 90 | for m in range(L_size): 91 | loss = loss + tf.log(m+1.0) * tf.reduce_sum( 92 | tf.squared_difference(x,output_dic['x_EST_{}'.format(m+1)])) \ 93 | / tf.reduce_sum(tf.squared_difference(x,ZF_detector)) 94 | tf.summary.scalar('loss',loss) 95 | with tf.name_scope('train'): 96 | train = tf.train.AdamOptimizer( 97 | learning_rate = Learning_rate).minimize(loss=loss) 98 | 99 | # 5 save the model 100 | saver = tf.train.Saver() 101 | 102 | # 6 Create a session to run TensorFlow to train the model 103 | init = tf.global_variables_initializer() 104 | with tf.Session() as sess: 105 | sess.run(init) 106 | merged_summary_op = tf.summary.merge_all() 107 | summary_writer = tf.summary.FileWriter(savedir,sess.graph) 108 | saver.restore(sess,savedir+'DetNetmodel.cpkt') 109 | time_st =tm.time() 110 | now_time = datetime.datetime.now().strftime('%X') 111 | print("Start time is: "+now_time) 112 | for epoch in range(maxEpochs): 113 | source, x_Feed, y_Feed, H_Feed = generate_traindata( 114 | batchsize, Nt, Nr, mod_Name) 115 | sess.run(train,feed_dict={x:x_Feed, y:y_Feed, H:H_Feed}) 116 | if epoch%10==0: 117 | Loss_Value=sess.run(loss,feed_dict={x:x_Feed, y:y_Feed, H:H_Feed}) 118 | now_time = datetime.datetime.now().strftime('%X') 119 | print() 120 | print("After %d training step(s), Current time is: "%epoch+now_time) 121 | summary_str = sess.run(merged_summary_op, feed_dict={ 122 | x:x_Feed, y:y_Feed, H:H_Feed}) 123 | summary_writer.add_summary(summary_str, epoch) 124 | x_DetNet = sess.run(DetNet_detector,feed_dict={ 125 | y:y_Feed, H:H_Feed}) 126 | x_ZF = sess.run(ZF_detector,feed_dict={y:y_Feed, H:H_Feed}) 127 | testbiterrorsum10 = bit_error( 128 | source, x_DetNet, Nt, batchsize, mod_Name) 129 | biterrorsum_ZF = bit_error( 130 | source, x_ZF, Nt, batchsize, mod_Name) 131 | ber_ZF = biterrorsum_ZF / batchsize / Nt / 2 132 | ber_DetNet = testbiterrorsum10 / batchsize / Nt / 2 133 | print("Loss=%g Acc_ZF=%g Acc_DetNet=%g"%( 134 | Loss_Value, ber_ZF,ber_DetNet)) 135 | if epoch%50==0: 136 | saver.save(sess,savedir+'DetNetmodel.cpkt') 137 | time_ed = tm.time() 138 | print("Total time is: "+str(time_ed-time_st)) 139 | 140 | 141 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Mar 22 12:56:54 2019 4 | 5 | @author: zfzdr 6 | """ 7 | import numpy as np 8 | import tensorflow as tf 9 | #modulation and demodulation:bpsk/qpsk 10 | table_bpsk = np.array([1 + 0j, -1 + 0j]) #modType=2 11 | table_qpsk = np.array( 12 | [-0.707107 + 0.707107j, -0.707107 - 0.707107j, 13 | 0.707107 + 0.707107j, 0.707107 - 0.707107j]) # modType=4 14 | mod_Table = {'bpsk': table_bpsk, 'qpsk': table_qpsk} 15 | 16 | #modulation function 17 | #Input 18 | #sourceSeq: binary sequence, the times of log2(modulation order) 19 | #mod_Name:modulation method:bpsk/qpsk 20 | #Output 21 | #modSeq = modulated signal,dtype='complex',shape=(num_Symbol,1) 22 | def modulation(sourceSeq, mod_Name): 23 | mod = mod_Table[mod_Name] 24 | mod_Type = int(mod.size) # Calcuate the Modulation order 25 | # Calcuate the Symbol number of sourceSeq 26 | num_Symbol = int(sourceSeq.size * 2 / mod_Type) 27 | mod_Seq = np.zeros((num_Symbol, 1), dtype='complex128') 28 | if mod_Type == 2: 29 | for i in range(num_Symbol): 30 | index = sourceSeq[i][0] 31 | mod_Seq[i][0] = mod[index] 32 | if mod_Type == 4: 33 | for i in range(num_Symbol): 34 | index = sourceSeq[i * 2][0] * 2 + sourceSeq[i *2 + 1][0] 35 | mod_Seq[i][0] = mod[index] 36 | return mod_Seq 37 | 38 | #demodulation 39 | #Input 40 | #receiveSeq:received signal,dtype='complex',shape=(num_Symbol,1) 41 | #mod_Name:modulation method:bpsk/qpsk 42 | #Output 43 | #demod_Seq:demodulated signal,dtype='int',shape=(num_Symbol*2,1) 44 | def demodulation(receiveSeq, mod_Name): 45 | mod_Type = int(mod_Table[mod_Name].size) 46 | num_Symbol = int(receiveSeq.size) 47 | if mod_Type == 2: 48 | demod_Seq = ((receiveSeq.real < 0) * 1) 49 | if mod_Type == 4: 50 | demod_Seq = np.zeros((num_Symbol * 2, 1)) 51 | for i in range((num_Symbol)): 52 | demod_Seq[i * 2][0] = (receiveSeq[i].real > 0) * 1 53 | demod_Seq[i * 2 + 1][0] = (receiveSeq[i].imag < 0) * 1 54 | return demod_Seq 55 | 56 | # Define function to generate training data 57 | # Input 58 | # batchsize:num of symbols transmitted per antenna in a frame 59 | # Nt:transmit antennas 60 | # Nr:receive antennas 61 | # mod_Name:modulation method:bpsk/qpsk 62 | # output 63 | # source:base data,binary data 64 | # x_real:modulated signal and transfer real number,dtype='float32', 65 | # y_real:received signal and transfer real number,dtype='float32' 66 | # H_real:channel matrix(real nummber) 67 | def generate_traindata(batchsize, Nt, Nr, mod_Name): 68 | mod_Type = int(mod_Table[mod_Name].size) 69 | #H_complex are randomly generated,and H_complex are 归一化的 70 | H_complex = (np.random.randn(Nr,Nt) + 1j*np.random.randn(Nr,Nt))/np.sqrt(2) 71 | H_complex = (np.sqrt(Nr)) / np.sqrt(np.trace(np.dot( 72 | np.transpose(np.conjugate(H_complex)),H_complex))) * H_complex 73 | H_part1 = np.concatenate((np.real(H_complex), np.imag(H_complex))) 74 | H_part2 = np.concatenate((-np.imag(H_complex), np.real(H_complex))) 75 | H_real = np.concatenate((H_part1,H_part2), axis=1).astype('float32') 76 | numbits_perSymbol = np.log2(mod_Type).astype(int) 77 | # generate 0,1 bits 78 | source = np.random.randint(0, 2, ((batchsize * Nt) * numbits_perSymbol, 1)) 79 | # BPSK/QPSK modulation 80 | x_complex = modulation(source, mod_Name) 81 | x_complex = np.reshape(x_complex, (Nt, batchsize)) 82 | # Noise Vector with independent,zero mean Gaussian variables of variance 1 83 | w = np.zeros((Nr, batchsize), dtype='complex') 84 | for m in range(batchsize): 85 | SNR = np.random.uniform(8, 16, 1) # 8dB-13dB uniform distribution 86 | sigma = np.sqrt(1 / (10 ** (SNR /10))) 87 | wpart = sigma / np.sqrt(2) * np.random.randn(Nr) + \ 88 | 1j * sigma / np.sqrt(2) * np.random.randn(Nr) 89 | w[:, m] = wpart 90 | y_complex = np.dot(H_complex, x_complex) + w 91 | x_real = np.concatenate((np.real(x_complex), np.imag(x_complex)), axis=0) 92 | y_real = np.concatenate((np.real(y_complex), np.imag(y_complex)), axis=0) 93 | return source, x_real, y_real, H_real 94 | 95 | # Define function to generate test data 96 | # Input 97 | # batchsize:num of symbols transmitted per antenna in a frame 98 | # Nt:transmit antennas 99 | # Nr:receive antennas 100 | # mod_Name:modulation method:bpsk/qpsk 101 | # SNR:specify the SNR to generate different receive signal 102 | # output 103 | # source:base data,binary data 104 | # x_real:modulated signal and transfer real number,dtype='float32', 105 | # y_real:received signal and transfer real number,dtype='float32' 106 | # H_real:channel matrix(real nummber) 107 | def generate_testdata(batchsize, Nt, Nr, mod_Name,SNR): 108 | mod_Type = int(mod_Table[mod_Name].size) 109 | #H are randomly generated 110 | H_complex = (np.random.randn(Nr,Nt) + 1j*np.random.randn(Nr,Nt))/np.sqrt(2) 111 | H_complex = (np.sqrt(Nr)) / np.sqrt(np.trace(np.dot( 112 | np.transpose(np.conjugate(H_complex)),H_complex))) * H_complex 113 | H_part1 = np.concatenate((np.real(H_complex), np.imag(H_complex))) 114 | H_part2 = np.concatenate((-np.imag(H_complex), np.real(H_complex))) 115 | H_real = np.concatenate((H_part1,H_part2), axis=1).astype('float32') 116 | numbits_perSymbol = np.log2(mod_Type).astype(int) 117 | # generate 0,1 bits 118 | source = np.random.randint(0, 2, ((batchsize * Nt) * numbits_perSymbol, 1)) 119 | # BPSK/QPSK modulation 120 | x_complex = modulation(source, mod_Name) 121 | x_complex = np.reshape(x_complex, (Nt, batchsize)) 122 | # Noise Vector with independent,zero mean Gaussian variables of variance 1 123 | w = np.zeros((Nr, batchsize), dtype='complex') 124 | for m in range(batchsize): 125 | sigma = np.sqrt(1 / (10 ** (SNR /10))) 126 | wpart = sigma / np.sqrt(2) * np.random.randn(Nr) + \ 127 | 1j * sigma / np.sqrt(2) * np.random.randn(Nr) 128 | w[:, m] = wpart 129 | y_complex = np.dot(H_complex, x_complex) + w 130 | x_real = np.concatenate((np.real(x_complex), np.imag(x_complex)), axis=0) 131 | y_real = np.concatenate((np.real(y_complex), np.imag(y_complex)), axis=0) 132 | return source, x_real, y_real, H_real 133 | 134 | # Calculate the number of error bit between source and x_EST 135 | def bit_error(source, x_EST, Nt, numSymbol_perChannel, mod_Name): 136 | x_EST_complex = x_EST[0:Nt, :] + 1j*x_EST[Nt:2*Nt, :] 137 | x_EST_complex_seq = np.reshape(x_EST_complex, (numSymbol_perChannel*Nt, 1)) 138 | x_ESTbit_seq = demodulation(x_EST_complex_seq, mod_Name) 139 | num_biterror = np.sum(abs(x_ESTbit_seq - source)) 140 | return num_biterror 141 | 142 | ##python创建目录文件夹 143 | #def mkdir(dir): 144 | # isExists = os.path.exists(dir) 145 | # if not isExists: 146 | # os.mkdir(dir) 147 | # print(dir+'创建成功') 148 | # return True 149 | # else: 150 | # print(dir+'目录已存在') 151 | # return false 152 | 153 | 154 | 155 | ##载入参数 156 | #saver = tf.train.import_meta_graph(savedir+'DetNetmodel.cpkt.meta') 157 | #graph = tf.get_default_graph() 158 | #y = graph.get_tensor_by_name(name='receiver:0') 159 | #H = graph.get_tensor_by_name(name='channel:0') 160 | #ZF_detector = graph.get_tensor_by_name(name='ZF_detector:0') 161 | #DetNet_detector = graph.get_tensor_by_name(name='DetNet_detector:0') 162 | class Predict(): 163 | def __init__(self, loc): 164 | self.graph=tf.Graph()#为每个类(实例)单独创建一个graph 165 | with self.graph.as_default(): 166 | self.saver=tf.train.import_meta_graph(loc+'.meta')#创建恢复器 167 | #注意!恢复器必须要在新创建的图里面生成,否则会出错。 168 | self.sess=tf.Session(graph=self.graph)#创建新的sess 169 | with self.sess.as_default(): 170 | with self.graph.as_default(): 171 | self.saver.restore(self.sess,loc)#从恢复点恢复参数 172 | 173 | def DetNet_detector(self, y_test, H_test): 174 | return self.sess.run('DetNet_detector:0',feed_dict={'receiver:0':y_test,'channel:0':H_test}) 175 | 176 | def ZF_detector(self, y_test, H_test): 177 | return self.sess.run('ZF_detector:0',feed_dict={'receiver:0':y_test,'channel:0':H_test}) 178 | --------------------------------------------------------------------------------