├── Readme.md ├── get_data.py ├── magfft_lstm_classify.py └── rml_lstm_classify.py /Readme.md: -------------------------------------------------------------------------------- 1 | # Repository for LSTM based modulation classifier. 2 | Paper under review. 3 | 4 | ## Prerequisites 5 | Tensorflow 6 | 7 | Tflearn 8 | 9 | 10 | ## Files 11 | 12 | get_data.py -- Script for retrieving data from Electrosense 13 | 14 | magfft_lstm_classify.py -- Magnitude FFT based classification 15 | 16 | rml_lstm_classify.py -- Amplitude-phase classification model 17 | -------------------------------------------------------------------------------- /get_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import requests 4 | import json 5 | import time,sys 6 | 7 | from requests.auth import HTTPBasicAuth 8 | from collections import OrderedDict 9 | from urllib import urlencode 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import optparse 14 | import random 15 | import getpass 16 | #import initExample ## Add path to library (just for examples; you do not need this) 17 | import pyqtgraph as pg 18 | from pyqtgraph.Qt import QtCore, QtGui 19 | import numpy as np 20 | #import tflearn 21 | from numpy import linalg as la 22 | 23 | parser = optparse.OptionParser("usage: %prog -u [-p -r -t -f ]") 24 | parser.add_option("-u", "--user", dest="username", 25 | type="string", 26 | help="API username") 27 | parser.add_option("-p", "--pass", dest="password", 28 | type="string", help="API password") 29 | 30 | parser.add_option("-r", "--range", dest="frange", 31 | type="string", help="frequency range separated by commas") 32 | 33 | parser.add_option("-t", "--tresol", dest="tresol", 34 | type="string", help="time resolution") 35 | 36 | parser.add_option("-f", "--fresol", dest="fresol", 37 | type="string", help="frequency resolution") 38 | 39 | (options, args) = parser.parse_args() 40 | if not options.username: 41 | parser.error("Username not specified") 42 | 43 | if not options.password: 44 | options.password = getpass.getpass('Password:') 45 | 46 | # Electrosense API Credentials 47 | username=options.username 48 | password=options.password 49 | 50 | # Electrosense API 51 | MAIN_URI ='https://test.electrosense.org/api' 52 | SENSOR_LIST = MAIN_URI + '/sensor/list/' 53 | SENSOR_AGGREGATED = MAIN_URI + "/spectrum/aggregated" 54 | 55 | r = requests.get(SENSOR_LIST, auth=HTTPBasicAuth(username, password)) 56 | 57 | if r.status_code != 200: 58 | print r.content 59 | exit(-1) 60 | 61 | slist_json = json.loads(r.content) 62 | 63 | senlist={} 64 | status=[" (off)", " (on)"] 65 | 66 | for i, sensor in enumerate(slist_json): 67 | print "[%d] %s (%d) - Sensing: %s" % (i, sensor['name'], sensor['serial'], sensor['sensing']) 68 | senlist[sensor['name']+status[int(sensor['sensing'])]]=i 69 | 70 | print "" 71 | pos = int( raw_input("Please enter the sensor: ")) 72 | 73 | print "" 74 | print " %s (%d) - %s" % (slist_json[pos]['name'], slist_json[pos]['serial'], slist_json[pos]['sensing']) 75 | 76 | 77 | # Ask for 5 minutes of aggregatd spectrum data 78 | 79 | def get_spectrum_data (sensor_id, timeBegin, timeEnd, aggFreq, aggTime, minfreq, maxfreq): 80 | 81 | params = OrderedDict([('sensor', sensor_id), 82 | ('timeBegin', timeBegin), 83 | ('timeEnd', timeEnd), 84 | ('freqMin', int(minfreq)), 85 | ('freqMax', int(maxfreq)), 86 | ('aggFreq', aggFreq), 87 | ('aggTime', aggTime), 88 | ('aggFun','AVG')]) 89 | 90 | 91 | r = requests.get(SENSOR_AGGREGATED, auth=HTTPBasicAuth(username, password), params=urlencode(params)) 92 | 93 | 94 | if r.status_code == 200: 95 | return json.loads(r.content) 96 | else: 97 | print "Response: %d" % (r.status_code) 98 | return None 99 | 100 | sp1 = None 101 | sp2 = None 102 | sp3 = None 103 | 104 | epoch_time = int(time.time()) 105 | timeBegin = epoch_time - (3600*24*2) 106 | #timeEnd = timeBegin + (3600*20*2) 107 | timeEnd = timeBegin + (60*4) 108 | if not options.fresol: 109 | freqresol = int(100e3) 110 | else: 111 | freqresol = int(float(options.fresol)) 112 | 113 | if not options.tresol: 114 | tresol = int(60) 115 | else: 116 | tresol = int(float(options.tresol)) 117 | 118 | if not options.frange: 119 | minfreq = 50e6 120 | maxfreq = 1500e6 121 | else: 122 | minfreq = int(float(options.frange.split(",")[0])) 123 | maxfreq = int(float(options.frange.split(",")[1])) 124 | 125 | senid = slist_json[pos]['serial'] 126 | response = get_spectrum_data (slist_json[pos]['serial'], timeBegin, timeEnd, freqresol, tresol, minfreq, maxfreq) 127 | data=np.array(response['values']) 128 | print "Data:",data.shape 129 | 130 | 131 | 132 | 133 | # Interpret image data as row-major instead of col-major 134 | pg.setConfigOptions(imageAxisOrder='row-major') 135 | 136 | pg.mkQApp() 137 | #pg.setConfigOption('background','w') 138 | #pg.setConfigOption('foreground','k') 139 | tab = pg.QtGui.QTabWidget() 140 | tab.show() 141 | grid = QtGui.QGridLayout() 142 | qwid = QtGui.QWidget() 143 | qwid.setLayout(grid) 144 | split1 = QtGui.QSplitter() 145 | grid.addWidget(split1) 146 | win = pg.GraphicsLayoutWidget() 147 | win.setWindowTitle('pyqtgraph example: Image Analysis') 148 | scroll= tab.addTab(qwid,"Spectrum") 149 | split1.addWidget(win) 150 | 151 | # A plot area (ViewBox + axes) for displaying the image 152 | p1 = win.addPlot() 153 | 154 | # Item for displaying image data 155 | img = pg.ImageItem() 156 | p1.addItem(img) 157 | 158 | # Custom ROI for selecting an image region 159 | bpos = [100,1] 160 | roi = pg.ROI(bpos, [data.shape[1]/10, data.shape[0]/3]) 161 | roi.addScaleHandle([0.5, 1], [0.5, 0.5]) 162 | roi.addScaleHandle([0, 0.5], [0.5, 0.5]) 163 | p1.addItem(roi) 164 | roi.setZValue(10) # make sure ROI is drawn above image 165 | 166 | # Isocurve drawing 167 | iso = pg.IsocurveItem(level=0.8, pen='g') 168 | iso.setParentItem(img) 169 | iso.setZValue(5) 170 | 171 | # Contrast/color control 172 | hist = pg.HistogramLUTItem() 173 | hist.gradient.loadPreset("flame") 174 | hist.setImageItem(img) 175 | win.addItem(hist) 176 | 177 | # Draggable line for setting isocurve level 178 | isoLine = pg.InfiniteLine(angle=0, movable=True, pen='g') 179 | hist.vb.addItem(isoLine) 180 | hist.vb.setMouseEnabled(y=False) # makes user interaction a little easier 181 | isoLine.setValue(0.8) 182 | isoLine.setZValue(1000) # bring iso line above contrast controls 183 | 184 | # Another plot area for displaying ROI data 185 | win.nextRow() 186 | p2 = win.addPlot(colspan=2) 187 | p2.setMaximumHeight(250) 188 | win.resize(800, 800) 189 | save = QtGui.QPushButton('Save') 190 | classif = QtGui.QPushButton('Classify') 191 | fdata = QtGui.QPushButton('Fetch') 192 | cb = pg.ComboBox() 193 | cb.setItems(senlist) 194 | cb.setValue(pos) 195 | stresol = pg.SpinBox(value=tresol, step=1, bounds=[0, None]) 196 | sfresol = pg.SpinBox(value=freqresol, step=freqresol, bounds=[0, None]) 197 | sminfreq = pg.SpinBox(value=minfreq, step=freqresol, bounds=[0, None]) 198 | smaxfreq = pg.SpinBox(value=maxfreq, step=freqresol, bounds=[0, None]) 199 | stbegin= pg.SpinBox(value=3600*24*2, step=1, bounds=[0, None]) 200 | sduration = pg.SpinBox(value=4, step=1, bounds=[0, None]) 201 | slabel= QtGui.QLabel("Sensor") 202 | tlabel= QtGui.QLabel("Time resolution (s)") 203 | flabel= QtGui.QLabel("Frequency resolution (Hz)") 204 | minflabel= QtGui.QLabel("Min freq (Hz)") 205 | maxflabel= QtGui.QLabel("Max freq (Hz)") 206 | dlabel= QtGui.QLabel("Data duration (s)") 207 | tblabel= QtGui.QLabel("Begin time (s): Current time-") 208 | grid2 = QtGui.QGridLayout() 209 | grid3 = QtGui.QGridLayout() 210 | qwid2 = QtGui.QWidget() 211 | qwid3 = QtGui.QWidget() 212 | qwid2.setLayout(grid2) 213 | qwid3.setLayout(grid3) 214 | grid2.addWidget(save,0,0) 215 | grid2.addWidget(classif,0,1) 216 | grid2.addWidget(fdata,0,2) 217 | split1.addWidget(qwid2) 218 | split1.addWidget(qwid3) 219 | grid3.addWidget(slabel,0,0) 220 | grid3.addWidget(cb,0,1) 221 | grid3.addWidget(tlabel,0,2) 222 | grid3.addWidget(stresol,0,3) 223 | grid3.addWidget(flabel,0,4) 224 | grid3.addWidget(sfresol,0,5) 225 | grid3.addWidget(minflabel,1,0) 226 | grid3.addWidget(sminfreq,1,1) 227 | grid3.addWidget(maxflabel,1,2) 228 | grid3.addWidget(smaxfreq,1,3) 229 | grid3.addWidget(dlabel,1,4) 230 | grid3.addWidget(sduration,1,5) 231 | grid3.addWidget(tblabel,1,6) 232 | grid3.addWidget(stbegin,1,7) 233 | split1.setOrientation(0); 234 | win.show() 235 | 236 | saveData="" 237 | 238 | def savepath(): 239 | global saveData 240 | fileName = QtGui.QFileDialog.getSaveFileName() 241 | if fileName: 242 | outfile = fileName[0] 243 | np.save(outfile, saveData) 244 | print "File saved:",outfile 245 | 246 | def updsensor(val): 247 | global slist_json, senid 248 | senid = slist_json[cb.value()]['serial'] 249 | 250 | def updtresol(val): 251 | global tresol 252 | tresol = int(float(val.value())) 253 | 254 | def updfresol(val): 255 | global freqresol 256 | freqresol = int(float(val.value())) 257 | 258 | def updminfreq(val): 259 | global minfreq 260 | minfreq = int(float(val.value())) 261 | 262 | def updmaxfreq(val): 263 | global maxfreq 264 | maxfreq = int(float(val.value())) 265 | 266 | def updduration(val): 267 | global duration, timeBegin, timeEnd 268 | duration = int(float(val.value())) 269 | timeEnd = timeBegin + 60*duration 270 | 271 | def updtbegin(val): 272 | global timeBegin 273 | epoch_time = int(time.time()) 274 | timeBegin = epoch_time - int(float(val.value())) 275 | 276 | cb.currentIndexChanged.connect(updsensor) 277 | stresol.sigValueChanged.connect(updtresol) 278 | sfresol.sigValueChanged.connect(updfresol) 279 | sminfreq.sigValueChanged.connect(updminfreq) 280 | smaxfreq.sigValueChanged.connect(updmaxfreq) 281 | sduration.sigValueChanged.connect(updduration) 282 | stbegin.sigValueChanged.connect(updtbegin) 283 | save.clicked.connect(savepath) 284 | img.setImage(data) 285 | hist.setLevels(data.min(), data.max()) 286 | hist.autoHistogramRange() 287 | 288 | # zoom to fit imageo 289 | p1.autoRange() 290 | 291 | text = pg.TextItem(html='
result
', anchor=(-0.3,0.5), angle=45, border='w', fill=(0, 0, 255, 100)) 292 | #text = pg.TextItem("test", anchor=(0.5, -1.0)) 293 | 294 | ''' 295 | #deeplearn model 296 | nsamples=1024 297 | labels = ["dvb", "radar", "gsm", "tetra","wfm", "lte"] 298 | network = tflearn.input_data(shape=[None, nsamples, 1],name="inp") 299 | network = tflearn.lstm(network, 128, dynamic=True) 300 | network = tflearn.fully_connected(network, len(labels), activation='softmax',name="out") 301 | network = tflearn.regression(network, optimizer='adam', 302 | loss='categorical_crossentropy', 303 | learning_rate=0.001) 304 | model = tflearn.DNN(network,tensorboard_verbose=2) 305 | model.load('lstm_tech_classify_gpu.tfl') 306 | ''' 307 | 308 | def lnorm(X_train): 309 | print "Pad:", X_train.shape 310 | for i in range(X_train.shape[0]): 311 | X_train[i,:] = X_train[i,:]/la.norm(X_train[i,:],2) 312 | return X_train 313 | 314 | ''' 315 | def classify(): 316 | global text,nsamples 317 | #mods = ['OFDM','AM','FM'] 318 | if saveData.shape[1] < nsamples: 319 | res = np.zeros((saveData.shape[0],nsamples)) 320 | #append zeros 321 | res[:,:saveData.shape[1]] = saveData 322 | else: 323 | res=saveData[:,:nsamples] 324 | res = lnorm(res) 325 | res = np.reshape(res,(-1,nsamples,1)) 326 | pred = np.array(model.predict(res)) 327 | print pred 328 | print np.argmax(pred, axis=1) 329 | counts = np.bincount(np.argmax(pred,axis=1)) 330 | mod = np.argmax(counts) 331 | ht = '
'+labels[mod]+'
' 332 | text = pg.TextItem(html=ht, anchor=(-0.3,0.5), angle=45, border='w', fill=(0, 0, 255, 100)) 333 | updatePlot() 334 | p2.addItem(text) 335 | ''' 336 | 337 | 338 | def fetch(): 339 | global senid, timeBegin, timeEnd, freqresol, tresol, minfreq, maxfreq, data, img, hist 340 | #with pg.ProgressDialog("Generating test.hdf5...", 0, 100, cancelText=None, wait=0) as dlg: 341 | try: 342 | response = get_spectrum_data(senid, timeBegin, timeEnd, freqresol, tresol, minfreq, maxfreq) 343 | data=np.array(response['values']) 344 | hist.setLevels(data.min(), data.max()) 345 | img.setImage(data) 346 | hist.setImageItem(img) 347 | print "Data fetched" 348 | updatePlot() 349 | except Exception as e: 350 | print str(e) 351 | 352 | 353 | #classif.clicked.connect(classify) 354 | fdata.clicked.connect(fetch) 355 | 356 | # Callbacks for handling user interaction 357 | def updatePlot(): 358 | global img, roi, data, p2, saveData, minfreq, freqresol,text 359 | selected = roi.getArrayRegion(data, img) 360 | saveData = selected 361 | print "Selected shape:", np.shape(selected) 362 | startfreq= minfreq+int(roi.pos()[0]*freqresol) 363 | stopfreq= startfreq+int(selected.shape[1]*freqresol) 364 | x = np.arange(startfreq,stopfreq,freqresol) 365 | xdict = dict(enumerate(x)) 366 | mval = selected.mean(axis=0) 367 | p2.plot(x,mval, clear=True) 368 | text.setPos(x[np.argmax(mval)],mval.max()) 369 | 370 | roi.sigRegionChanged.connect(updatePlot) 371 | updatePlot() 372 | 373 | def updateIsocurve(): 374 | global isoLine, iso 375 | iso.setLevel(isoLine.value()) 376 | 377 | isoLine.sigDragged.connect(updateIsocurve) 378 | 379 | 380 | ## Start Qt event loop unless running in interactive mode or using pyside. 381 | if __name__ == '__main__': 382 | import sys 383 | if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'): 384 | QtGui.QApplication.instance().exec_() 385 | -------------------------------------------------------------------------------- /magfft_lstm_classify.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tflearn 3 | import numpy as np 4 | import scipy.fftpack as spfft 5 | import tensorflow as tf 6 | import scipy.io as sio 7 | from sklearn.metrics import confusion_matrix 8 | import matplotlib.pyplot as plt 9 | from scipy.signal import blackman 10 | import sys 11 | import operator 12 | from collections import OrderedDict 13 | from numpy import linalg as la 14 | import os 15 | 16 | 17 | ''' 18 | Directory containing dump files with names "_.npy" 19 | e.g. lte_1.npy, lte_2.npy, gsm_1.npy... 20 | This can be easily generated using the labeling tool over the api 21 | ''' 22 | mydir="../tech_dumps/" 23 | files = [] 24 | for file in os.listdir(mydir): 25 | if file.endswith(".npy"): 26 | files.append(os.path.join(mydir, file)) 27 | 28 | print files 29 | labels={} 30 | lfiles={} 31 | count=0 32 | for f in files: 33 | fname = f.split("/")[-1] 34 | if not labels.has_key(fname.split("_")[0]): 35 | labels[fname.split("_")[0]]=count 36 | lfiles[fname.split("_")[0]]=[] 37 | count+=1 38 | 39 | lfiles[fname.split("_")[0]].append(f) 40 | 41 | 42 | labels = OrderedDict(sorted(labels.items(), key=operator.itemgetter(1))) 43 | 44 | 45 | print labels 46 | print lfiles 47 | 48 | num_labels = len(labels) 49 | nsamples = 0 50 | 51 | for f in files: 52 | dta = np.load(f) 53 | if nsamples < dta.shape[1]: 54 | nsamples = dta.shape[1] 55 | 56 | 57 | datatype = "float32" 58 | train_data = np.zeros(nsamples,dtype=datatype) 59 | test_data = np.zeros(nsamples,dtype=datatype) 60 | train_labels = np.zeros(num_labels) 61 | test_labels = np.zeros(num_labels) 62 | 63 | 64 | def setup_data(): 65 | global train_data, train_labels, valid_data, valid_labels, test_data, test_labels 66 | for key in labels.keys(): 67 | print("--"*50) 68 | for f in lfiles[key]: 69 | dta = np.load(f) 70 | res = np.zeros((dta.shape[0],nsamples)) 71 | #append zeros 72 | res[:,:dta.shape[1]] = dta 73 | train_cnt = dta.shape[0]/2 74 | test_cnt = dta.shape[0]/2 75 | train_data = np.vstack((train_data,res[0:train_cnt])) 76 | dummy_labels = np.zeros((train_cnt, len(labels))) 77 | dummy_labels[:, labels[key]] = 1 78 | train_labels = np.vstack((train_labels,dummy_labels)) 79 | print("Training data: Generation done for:", key) 80 | test_data = np.vstack((test_data,res[train_cnt:train_cnt+test_cnt])) 81 | dummy_labels = np.zeros((test_cnt, len(labels))) 82 | dummy_labels[:, labels[key]] = 1 83 | test_labels = np.vstack((test_labels,dummy_labels)) 84 | print("Testing data: Generation done for:", key) 85 | train_data = np.delete(train_data,0,0) 86 | test_data = np.delete(test_data,0,0) 87 | train_labels = np.delete(train_labels,0,0) 88 | test_labels = np.delete(test_labels,0,0) 89 | 90 | 91 | setup_data() 92 | Y_train = train_labels 93 | Y_test = test_labels 94 | 95 | 96 | print train_data.shape 97 | print test_data.shape 98 | 99 | def lnorm(X_train): 100 | print "Pad:", X_train.shape 101 | for i in range(X_train.shape[0]): 102 | X_train[i,:] = X_train[i,:]/la.norm(X_train[i,:],2) 103 | return X_train 104 | 105 | train_data = lnorm(train_data) 106 | test_data = lnorm(test_data) 107 | 108 | #out0 = (out0-np.mean(out0))/np.std(out0) 109 | 110 | X_train = np.reshape(train_data,(-1,nsamples,1)) 111 | X_test = np.reshape(test_data,(-1,nsamples,1)) 112 | 113 | def getFontColor(value): 114 | if np.isnan(value): 115 | return "black" 116 | elif value < 0.2: 117 | return "black" 118 | else: 119 | return "white" 120 | 121 | def getConfusionMatrixPlot(true_labels, predicted_labels): 122 | # Compute confusion matrix 123 | cm = confusion_matrix(true_labels, predicted_labels) 124 | cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 125 | cm_norm = np.nan_to_num(cm_norm) 126 | cm = np.round(cm_norm,2) 127 | print(cm) 128 | 129 | # create figure 130 | fig = plt.figure() 131 | plt.clf() 132 | ax = fig.add_subplot(111) 133 | ax.set_aspect(1) 134 | ax.set_xlabel('Predicted label') 135 | ax.set_ylabel('True label') 136 | res = ax.imshow(cm, cmap=plt.cm.binary, 137 | interpolation='nearest', vmin=0, vmax=1) 138 | 139 | # add color bar 140 | plt.colorbar(res) 141 | 142 | # annotate confusion entries 143 | width = len(cm) 144 | height = len(cm[0]) 145 | 146 | for x in xrange(width): 147 | for y in xrange(height): 148 | ax.annotate(str(cm[x][y]), xy=(y, x), horizontalalignment='center', 149 | verticalalignment='center', color=getFontColor(cm[x][y])) 150 | 151 | # add genres as ticks 152 | alphabet = labels.keys() 153 | plt.xticks(range(width), alphabet[:width], rotation=30) 154 | plt.yticks(range(height), alphabet[:height]) 155 | return plt 156 | 157 | class MonitorCallback(tflearn.callbacks.Callback): 158 | def __init__(self, model): 159 | self.model = model 160 | self.accuracy = 0.0 161 | 162 | def on_epoch_end(self, training_state): 163 | print "accuracy1:", training_state.global_acc 164 | print "accuracy2:", training_state.val_acc 165 | if self.accuracy