├── README.md
├── RNN-based Decoder for Convolutional Codes.ipynb
└── coding.py
/README.md:
--------------------------------------------------------------------------------
1 | # On Recurrent Neural Networks for Sequence-based Processing in Communications
2 | ## In this notebook we show how to build a decoder for convolutional codes based on recurrent neural networks
3 | Accompanying code of paper ["On Recurrent Neural Networks for Sequence-based Processing in Communications" by Daniel Tandler, Sebastian Dörner, Sebastian Cammerer, Stephan ten Brink](https://ieeexplore.ieee.org/document/9048728)
4 |
5 | If you find this code helpful please cite this work using the following bibtex entry:
6 |
7 | ```tex
8 | @inproceedings{RNN-Conv-Decoding-Tandler2019,
9 | author = {Daniel Tandler and
10 | Sebastian D{\"{o}}rner and
11 | Sebastian Cammerer and
12 | Stephan ten Brink},
13 | booktitle = {2019 53rd Asilomar Conference on Signals, Systems, and Computers},
14 | title = {On Recurrent Neural Networks for Sequence-based Processing in Communications},
15 | year = {2019},
16 | pages = {537-543}
17 | }
18 | ```
19 |
20 |
21 | ## Installation/Setup
22 |
23 | An example of how to use the code is given in the Jupyter Notebook (.ipynb file). The coding.py file is only needed to generate arbitrary convolutional codes and is not required to run the notebook.
24 |
25 | You can run the notebook, including the code and short explanations, directly in Google Colab:
26 |
27 | [Run this Notebook in Google Colaboratory: Link to colab.google.com](https://colab.research.google.com/github/sdnr/RNN-Conv-Decoder/blob/master/RNN-based%20Decoder%20for%20Convolutional%20Codes.ipynb)
28 |
--------------------------------------------------------------------------------
/RNN-based Decoder for Convolutional Codes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Copy of RNN-based Decoder for Convolutional Codes.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.6.8"
27 | },
28 | "varInspector": {
29 | "cols": {
30 | "lenName": 16,
31 | "lenType": 16,
32 | "lenVar": 40
33 | },
34 | "kernels_config": {
35 | "python": {
36 | "delete_cmd_postfix": "",
37 | "delete_cmd_prefix": "del ",
38 | "library": "var_list.py",
39 | "varRefreshCmd": "print(var_dic_list())"
40 | },
41 | "r": {
42 | "delete_cmd_postfix": ") ",
43 | "delete_cmd_prefix": "rm(",
44 | "library": "var_list.r",
45 | "varRefreshCmd": "cat(var_dic_list()) "
46 | }
47 | },
48 | "oldHeight": 654.4,
49 | "position": {
50 | "height": "676px",
51 | "left": "714px",
52 | "right": "63px",
53 | "top": "-5px",
54 | "width": "800px"
55 | },
56 | "types_to_exclude": [
57 | "module",
58 | "function",
59 | "builtin_function_or_method",
60 | "instance",
61 | "_Feature"
62 | ],
63 | "varInspector_section_display": "block",
64 | "window_display": false
65 | }
66 | },
67 | "cells": [
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {
71 | "id": "xFHXWn4HnIA2"
72 | },
73 | "source": [
74 | "# On Recurrent Neural Networks for Sequence-based Processing in Communications\n",
75 | "## In this notebook we show how to build a decoder for convolutional codes based on recurrent neural networks\n",
76 | "Accompanying code of paper [\"On Recurrent Neural Networks for Sequence-based Processing in Communications\" by Daniel Tandler, Sebastian Dörner, Sebastian Cammerer, Stephan ten Brink](https://ieeexplore.ieee.org/document/9048728)\n",
77 | "\n",
78 | "If you find this code helpful please cite this work using the following bibtex entry:\n",
79 | "\n",
80 | "```tex\n",
81 | "@inproceedings{RNN-Conv-Decoding-Tandler2019,\n",
82 | " author = {Daniel Tandler and\n",
83 | " Sebastian D{\\\"{o}}rner and\n",
84 | " Sebastian Cammerer and\n",
85 | " Stephan ten Brink},\n",
86 | " booktitle = {2019 53rd Asilomar Conference on Signals, Systems, and Computers},\n",
87 | " title = {On Recurrent Neural Networks for Sequence-based Processing in Communications},\n",
88 | " year = {2019},\n",
89 | " pages = {537-543}\n",
90 | "}\n",
91 | "```"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "source": [
97 | "# downgrading numpy to not run into error with tf1's cudnngru\n",
98 | "!pip3 install numpy==1.19.2"
99 | ],
100 | "metadata": {
101 | "colab": {
102 | "base_uri": "https://localhost:8080/"
103 | },
104 | "id": "bCqo_Me1BAUe",
105 | "outputId": "8014e8e8-70c5-4c38-b37e-a4f88d4ea905"
106 | },
107 | "execution_count": 1,
108 | "outputs": [
109 | {
110 | "output_type": "stream",
111 | "name": "stdout",
112 | "text": [
113 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
114 | "Requirement already satisfied: numpy==1.19.2 in /usr/local/lib/python3.7/dist-packages (1.19.2)\n"
115 | ]
116 | }
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "metadata": {
122 | "id": "QEy0wTWmvk71",
123 | "colab": {
124 | "base_uri": "https://localhost:8080/"
125 | },
126 | "outputId": "e94acd5f-3a1c-4fea-b3d7-b69df3c3faf4"
127 | },
128 | "source": [
129 | "# magic command to use TF 1.X in Colaboratory when importing tensorflow\n",
130 | "%tensorflow_version 1.x \n",
131 | "import tensorflow as tf # imports the tensorflow library to the python kernel\n",
132 | "tf.logging.set_verbosity(tf.logging.ERROR) # sets the amount of debug information from TF (INFO, WARNING, ERROR)\n",
133 | "\n",
134 | "print(\"Using tensorflow version:\", tf.__version__)"
135 | ],
136 | "execution_count": 2,
137 | "outputs": [
138 | {
139 | "output_type": "stream",
140 | "name": "stdout",
141 | "text": [
142 | "TensorFlow 1.x selected.\n",
143 | "Using tensorflow version: 1.15.2\n"
144 | ]
145 | }
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "metadata": {
151 | "id": "o0V_CfUbnIBB",
152 | "colab": {
153 | "base_uri": "https://localhost:8080/"
154 | },
155 | "outputId": "72fc5df0-40ce-4bea-f21f-327b4b75c68d"
156 | },
157 | "source": [
158 | "import tensorflow as tf\n",
159 | "import numpy as np\n",
160 | "print(\"numpy version:\",np.__version__)\n",
161 | "import matplotlib.pyplot as plt"
162 | ],
163 | "execution_count": 3,
164 | "outputs": [
165 | {
166 | "output_type": "stream",
167 | "name": "stdout",
168 | "text": [
169 | "numpy version: 1.19.2\n"
170 | ]
171 | }
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "id": "Lwaxsc7YnIBY"
178 | },
179 | "source": [
180 | "# Code Setup\n",
181 | "\n",
182 | "We first set up a code class that holds all necessary parameters and provides functions to quickly generate large samples of encoded bits.\n",
183 | "\n",
184 | "For this notebook, only code examples for memory 1,2,4 and 6 are provided.\n",
185 | "\n",
186 | "To generate other convolutional codes, check out the accompanying coding.py, which uses [CommPy](https://github.com/veeresht/CommPy) to generate arbitrary codes."
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "metadata": {
192 | "id": "rz8aiownnIBc"
193 | },
194 | "source": [
195 | "class code:\n",
196 | " def __init__(self,m):\n",
197 | " self.m = m # Number of delay elements in the convolutional encoder\n",
198 | " self.tb_depth = 5*(self.m + 1) # Traceback depth of the decoder\n",
199 | " self.code_rate = 0.5\n",
200 | " if m == 1:\n",
201 | " self.d1 = 0o1\n",
202 | " self.d2 = 0o3\n",
203 | " self.impulse_response = np.array([0, 1, 1, 1])\n",
204 | " self.viterbi_reference = np.array([7.293792e-02,5.801720e-02,4.490250e-02,3.349593e-02,2.429049e-02,1.684274e-02,1.124068e-02,7.277303e-03,4.354604e-03,2.546695e-03,1.382015e-03,7.138968e-04])\n",
205 | " elif m == 2:\n",
206 | " self.d1 = 0o5\n",
207 | " self.d2 = 0o7\n",
208 | " self.impulse_response = np.array([1, 1, 0, 1, 1, 1])\n",
209 | " self.viterbi_reference = np.array([9.278817e-02,6.424232e-02,4.195904e-02,2.531590e-02,1.424276e-02,7.385386e-03,3.617080e-03,1.526589e-03,6.319029e-04,2.502278e-04,7.633503e-05,2.566724e-05])\n",
210 | " elif m == 4:\n",
211 | " self.d1 = 0o23\n",
212 | " self.d2 = 0o35\n",
213 | " self.impulse_response = np.array([1, 1, 0, 1, 0, 1, 1, 0, 1, 1])\n",
214 | " self.viterbi_reference = np.array([1.266374e-01,7.990744e-02,4.546113e-02,2.301058e-02,1.045569e-02,4.220632e-03,1.526512e-03,5.214676e-04,1.482288e-04,3.666830e-05,7.778123e-06,1.444509e-06])\n",
215 | " elif m == 6:\n",
216 | " self.d1 = 0o133\n",
217 | " self.d2 = 0o171\n",
218 | " self.impulse_response = np.array([1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1])\n",
219 | " self.viterbi_reference = np.array([1.547330e-01,8.593706e-02,3.985466e-02,1.544436e-02,5.221681e-03,1.378203e-03,3.501900e-04,8.042758e-05,1.676778e-05,2.989088e-06,3.444674e-07,np.NaN])\n",
220 | " else:\n",
221 | " print(\"Code not available!\")\n",
222 | " \n",
223 | " def zero_pad(self,u):\n",
224 | " return np.reshape(np.stack([u,np.zeros_like(u)],axis=1),(-1,))\n",
225 | " \n",
226 | " def encode_sequence(self,u,terminate=False):\n",
227 | " if terminate:\n",
228 | " return np.convolve(self.zero_pad(u),self.impulse_response,mode='full')[:-1] % 2\n",
229 | " else:\n",
230 | " return np.convolve(self.zero_pad(u),self.impulse_response,mode='full')[:len(u)*2] % 2\n",
231 | " \n",
232 | " def encode_batch(self,u,terminate=False):\n",
233 | " x0 = self.encode_sequence(u[0],terminate)\n",
234 | " x = np.empty((u.shape[0],x0.shape[0]),dtype=np.int8)\n",
235 | " x[0] = x0\n",
236 | " for i in range(len(u)-1):\n",
237 | " x[i+1] = self.encode_sequence(u[i+1],terminate)\n",
238 | " return x"
239 | ],
240 | "execution_count": 4,
241 | "outputs": []
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {
246 | "id": "h7CbqjGjnIBp"
247 | },
248 | "source": [
249 | "### Our SNR definition"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "metadata": {
255 | "id": "H4AoHWSOnIBu"
256 | },
257 | "source": [
258 | "def ebnodb2std(ebnodb, coderate=1):\n",
259 | " ebno = 10**(ebnodb/10)\n",
260 | " return (1/np.sqrt(2*coderate*ebno)).astype(np.float32)"
261 | ],
262 | "execution_count": 5,
263 | "outputs": []
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {
268 | "id": "q8NToLOvnIB6"
269 | },
270 | "source": [
271 | "### Choose which convolutional code you want to use"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "metadata": {
277 | "id": "MLNYe-pKnIB-"
278 | },
279 | "source": [
280 | "#code = code(m=1) # memory 1 rate 0.5 code with generator polynomials 0o1 and 0o3 (octal)\n",
281 | "#code = code(m=2) # memory 2 rate 0.5 code with generator polynomials 0o5 and 0o7 (octal)\n",
282 | "#code = code(m=4) # memory 4 rate 0.5 code with generator polynomials 0o23 and 0o35 (octal)\n",
283 | "code = code(m=6) # memory 6 rate 0.5 code with generator polynomials 0o133 and 0o171 (octal)"
284 | ],
285 | "execution_count": 6,
286 | "outputs": []
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {
291 | "id": "vHxCyjidnICJ"
292 | },
293 | "source": [
294 | "## Parameters"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "metadata": {
300 | "id": "sYiAZ51CnICN",
301 | "colab": {
302 | "base_uri": "https://localhost:8080/"
303 | },
304 | "outputId": "f38cf17c-cd15-4612-cf3d-b222bb72ea1b"
305 | },
306 | "source": [
307 | "model_name = \"%s%sm%s_Model\" % (oct(code.d1),oct(code.d2),code.m)\n",
308 | "saver_path = \"trained_models/\"+model_name\n",
309 | "\n",
310 | "gradient_depth = code.tb_depth\n",
311 | "additional_input = 0\n",
312 | "decision_offset = int(len(code.impulse_response)/2)\n",
313 | "sequence_length = 15\n",
314 | "\n",
315 | "rnn_layers = 3\n",
316 | "rnn_units_per_layer = 256\n",
317 | "dense_layers = [16]\n",
318 | "\n",
319 | "print(\"Code Rate:\", code.code_rate)\n",
320 | "print(\"RNN layers:\", rnn_layers)\n",
321 | "print(\"Units per layer:\", rnn_units_per_layer)\n",
322 | "print(\"Gradient depth:\", gradient_depth)\n",
323 | "print(\"ConvCode traceback length thump rule:\",code.tb_depth)"
324 | ],
325 | "execution_count": 7,
326 | "outputs": [
327 | {
328 | "output_type": "stream",
329 | "name": "stdout",
330 | "text": [
331 | "Code Rate: 0.5\n",
332 | "RNN layers: 3\n",
333 | "Units per layer: 256\n",
334 | "Gradient depth: 35\n",
335 | "ConvCode traceback length thump rule: 35\n"
336 | ]
337 | }
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "metadata": {
343 | "id": "pgCc5a_DnICZ"
344 | },
345 | "source": [
346 | "## Tensorflow Graph"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "metadata": {
352 | "id": "Y8vm-YG3nICc"
353 | },
354 | "source": [
355 | "graph = tf.Graph()\n",
356 | "with graph.as_default():\n",
357 | " \n",
358 | " # Encoded Sequence Input\n",
359 | " x = tf.placeholder(tf.float32,shape=[2*gradient_depth+sequence_length,None,2*(1+2*additional_input)],name=\"coded_sequence\")\n",
360 | " \n",
361 | " # Decoding\n",
362 | " multi_rnn_cell = tf.contrib.cudnn_rnn.CudnnGRU(rnn_layers,rnn_units_per_layer,direction='bidirectional')\n",
363 | " multi_rnn_cell.build(input_shape=[2*gradient_depth+sequence_length,None,(1+2*additional_input)*2])\n",
364 | "\n",
365 | " out,(new_state,) = multi_rnn_cell(x)\n",
366 | " \n",
367 | " out_sequence = out[gradient_depth:gradient_depth+sequence_length,:,:]\n",
368 | " \n",
369 | " # final dense layers:\n",
370 | " for size in dense_layers:\n",
371 | " out_sequence = tf.layers.dense(out_sequence,size,activation=tf.nn.relu)\n",
372 | " u_hat = tf.layers.dense(out_sequence,1,activation=tf.nn.sigmoid)\n",
373 | " \n",
374 | " u_hat = tf.squeeze(u_hat)\n",
375 | " \n",
376 | " u_hat_bits = tf.cast(tf.greater(u_hat,0.5),tf.int8)\n",
377 | " \n",
378 | " \n",
379 | " # Loss function\n",
380 | " u_label = tf.placeholder(tf.int8,shape=[sequence_length,None],name=\"uncoded_bits\")\n",
381 | " loss = tf.losses.log_loss(labels=u_label,predictions=u_hat)\n",
382 | " correct_predictions = tf.equal(u_hat_bits, u_label)\n",
383 | " ber = 1.0 - tf.reduce_mean(tf.cast(correct_predictions, tf.float32),axis=1)\n",
384 | "\n",
385 | "\n",
386 | " # Training\n",
387 | " lr = tf.placeholder(tf.float32, shape=[])\n",
388 | " optimizer = tf.train.RMSPropOptimizer(lr)\n",
389 | " step = optimizer.minimize(loss)\n",
390 | " \n",
391 | " # Init\n",
392 | " init = tf.global_variables_initializer()\n",
393 | " \n",
394 | " # Saver\n",
395 | " saver = tf.train.Saver()"
396 | ],
397 | "execution_count": 8,
398 | "outputs": []
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {
403 | "id": "wFhX5TfMnICo"
404 | },
405 | "source": [
406 | "### Let's print all trainable variables of the graph we just defined:\n",
407 | "Note that the special CudnnGRU layers generate a kind of \"sub\"-graph; their variables are therefore not shown here individually but only as a so-called opaque_kernel."
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "metadata": {
413 | "id": "TKjS78WEnICr",
414 | "colab": {
415 | "base_uri": "https://localhost:8080/"
416 | },
417 | "outputId": "4fed0087-504a-4df6-d58b-32f27ba907b4"
418 | },
419 | "source": [
420 | "def model_summary(for_graph): #from TensorFlow slim.model_analyzer.analyze_vars source\n",
421 | " print(\"{:60}{:21}{:14}{:>17}\".format('Name','Shape','Variables','Size'))\n",
422 | " print('{:-<112}'.format(''))\n",
423 | " total_size = 0\n",
424 | " total_bytes = 0\n",
425 | " for var in for_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):\n",
426 | " # if var.num_elements() is None or [] assume size 0.\n",
427 | " var_shape = var.get_shape()\n",
428 | " var_size = var.get_shape().num_elements() or 0\n",
429 | " var_bytes = var_size * var.dtype.size\n",
430 | " total_size += var_size\n",
431 | " total_bytes += var_bytes\n",
432 | " print(\"{:60}{:21}{:14}{:>11} bytes\".format(var.name, str(var_shape), str(var_size), var_bytes))\n",
433 | " print()\n",
434 | " print('\\033[1mTotal number of variables:\\t{}\\033[0m'.format(total_size))\n",
435 | " print('\\033[1mTotal bytes of variables:\\t{}\\033[0m'.format(total_bytes))\n",
436 | "\n",
437 | "model_summary(graph)"
438 | ],
439 | "execution_count": 9,
440 | "outputs": [
441 | {
442 | "output_type": "stream",
443 | "name": "stdout",
444 | "text": [
445 | "Name Shape Variables Size\n",
446 | "----------------------------------------------------------------------------------------------------------------\n",
447 | "cudnn_gru/opaque_kernel:0 0 0 bytes\n",
448 | "dense/kernel:0 (512, 16) 8192 32768 bytes\n",
449 | "dense/bias:0 (16,) 16 64 bytes\n",
450 | "dense_1/kernel:0 (16, 1) 16 64 bytes\n",
451 | "dense_1/bias:0 (1,) 1 4 bytes\n",
452 | "\n",
453 | "\u001b[1mTotal number of variables:\t8225\u001b[0m\n",
454 | "\u001b[1mTotal bytes of variables:\t32900\u001b[0m\n"
455 | ]
456 | }
457 | ]
458 | },
459 | {
460 | "cell_type": "markdown",
461 | "metadata": {
462 | "id": "LP8nLUmKnIC2"
463 | },
464 | "source": [
465 | "## Generator function\n",
466 | "Let's define a generator function that first generates a large dataset of paired bit sequences and encoded bit sequences.\n",
467 | "\n",
468 | "In a second step, this function slices those datasets into shorter snippets, which are then fed to the NN decoder."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "metadata": {
474 | "id": "Caj6X4x1nIC5"
475 | },
476 | "source": [
477 | "def generator(batch_size,iterations,gradient_depth,sequence_length,additional_input,decision_offset,sigma,apriori):\n",
478 | " offset = code.tb_depth + 3\n",
479 | " full_uncoded_sequences = np.random.randint(0,100,[batch_size,iterations+2*gradient_depth+sequence_length+offset+2*additional_input],dtype=np.int8)\n",
480 | " full_uncoded_sequences = np.less(full_uncoded_sequences,np.array([apriori*100],dtype=np.int8)).astype(np.int8)\n",
481 | " full_coded_sequences = code.encode_batch(full_uncoded_sequences)\n",
482 | " full_coded_sequences = np.reshape(full_coded_sequences,[batch_size,-1,int(1/code.code_rate)])\n",
483 | " \n",
484 | " # Feeding\n",
485 | " for i in range(iterations):\n",
486 | " encoded_sequences = full_coded_sequences[:,offset+i:offset+i+2*gradient_depth+sequence_length+2*additional_input,:]\n",
487 | " labels = full_uncoded_sequences[:,offset+i+gradient_depth+additional_input+decision_offset:offset+i+gradient_depth+additional_input+decision_offset+sequence_length]\n",
488 | " \n",
489 | " # BPSK Modulation\n",
490 | " modulated_sequences = (encoded_sequences.astype(np.float32) - 0.5) * 2.0\n",
491 | "\n",
492 | " # AWGN\n",
493 | " noise = np.random.normal(size=modulated_sequences.shape).astype(np.float32)\n",
494 | " noised_sequences = modulated_sequences + noise * sigma\n",
495 | " \n",
496 | " # Input Processing\n",
497 | " stack_array = []\n",
498 | " for k in range(2*gradient_depth+sequence_length):\n",
499 | " stack_array.append(noised_sequences[:,k:k+2*additional_input+1,:])\n",
500 | " input_x = np.stack(stack_array,axis=1)\n",
501 | " input_x = np.reshape(input_x,newshape=[batch_size,sequence_length+2*gradient_depth,(1+2*additional_input)*2])\n",
502 | "\n",
503 | " # Transpose dimensions 1 and 0 because CudnnGRU layers need [time,batch,input] feeding\n",
504 | " input_x = np.transpose(input_x,axes=[1,0,2])\n",
505 | " input_labels = np.transpose(labels,axes=[1,0])\n",
506 | " \n",
507 | " yield input_x,input_labels"
508 | ],
509 | "execution_count": 10,
510 | "outputs": []
511 | },
512 | {
513 | "cell_type": "markdown",
514 | "metadata": {
515 | "id": "uVX0fZaTnIDF"
516 | },
517 | "source": [
518 | "## Starting a tensorflow session\n",
519 | "We create a session for the previously defined graph and save the initial state of the graph in training_stage 0"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "metadata": {
525 | "id": "eyffvGZ9nIDH",
526 | "colab": {
527 | "base_uri": "https://localhost:8080/",
528 | "height": 35
529 | },
530 | "outputId": "b6850dad-3d33-4415-dea6-2095e8dcf36a"
531 | },
532 | "source": [
533 | "sess_config = tf.ConfigProto()\n",
534 | "#sess_config.gpu_options.per_process_gpu_memory_fraction = 0.3 # to limit the amount of GPU memory usage\n",
535 | "sess_config.gpu_options.allow_growth = True\n",
536 | "sess = tf.Session(graph=graph, config=sess_config)\n",
537 | "sess.run(init)\n",
538 | "\n",
539 | "trained_stages = 0\n",
540 | "saver.save(sess,saver_path,global_step=trained_stages)"
541 | ],
542 | "execution_count": 11,
543 | "outputs": [
544 | {
545 | "output_type": "execute_result",
546 | "data": {
547 | "text/plain": [
548 | "'trained_models/0o1330o171m6_Model-0'"
549 | ],
550 | "application/vnd.google.colaboratory.intrinsic+json": {
551 | "type": "string"
552 | }
553 | },
554 | "metadata": {},
555 | "execution_count": 11
556 | }
557 | ]
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "metadata": {
562 | "id": "sy0m9X_nnIDR"
563 | },
564 | "source": [
565 | "## Auxiliary functions"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "metadata": {
571 | "id": "3Ihe8ZHonIDU"
572 | },
573 | "source": [
574 | "# generates a python dictionary linking numpy feeds to tensorflow tensors\n",
575 | "def gen_feed_dict(x_feed, u_feed, lr_feed=1e-3):\n",
576 | " feed_dict = {\n",
577 | " x: x_feed,\n",
578 | " u_label: u_feed,\n",
579 | " lr: lr_feed,\n",
580 | " }\n",
581 | " return feed_dict\n",
582 | "\n",
583 | "# runs a single batch to predict u_hat and calculate the BER\n",
584 | "def test_step(x_feed, u_feed):\n",
585 | " return sess.run(ber,feed_dict=gen_feed_dict(x_feed,u_feed))\n",
586 | "\n",
587 | "# runs a monte carlo simulation of several test steps to get meaningful BERs\n",
588 | "def test(test_parameters, plot=False, plot_baseline=False, ber_at_time=int(sequence_length/2)):\n",
589 | " test_sigma = ebnodb2std(test_parameters['ebnodb'],code.code_rate)\n",
590 | " ber = np.zeros([len(test_parameters['ebnodb']),sequence_length])\n",
591 | " for i in range(len(test_sigma)):\n",
592 | " for x_feed,u_feed in generator(test_parameters['batch_size'],test_parameters['iterations'],gradient_depth,sequence_length,additional_input,decision_offset,test_sigma[i],0.5):\n",
593 | " curr_ber = test_step(x_feed, u_feed)\n",
594 | " ber[i] += curr_ber\n",
595 | " # logging\n",
596 | " print(\"SNR:\",test_parameters['ebnodb'][i])\n",
597 | " print(\"BER:\",ber[i]/test_parameters['iterations'])\n",
598 | " ber = ber/test_parameters['iterations']\n",
599 | " print(\"Final BER:\",ber)\n",
600 | " if (plot):\n",
601 | " plot_bler_vs_ebnodb(test_parameters['ebnodb'], ber[:,ber_at_time], plot_baseline)\n",
602 | " return ber\n",
603 | "\n",
604 | "# runs a single training step\n",
605 | "def train_step(x_feed,u_feed,lr_feed):\n",
606 | " return sess.run([step,loss,ber],feed_dict=gen_feed_dict(x_feed,u_feed,lr_feed))\n",
607 | " \n",
608 | "# runs a training set according to training_params\n",
609 | "def train(training_params):\n",
610 | " global trained_stages\n",
611 | " pl = training_params['learning']\n",
612 | " #early_stopping = training_params['early_stopping']\n",
613 | " for epoch in pl:\n",
614 | " # learning params\n",
615 | " batch_size = epoch[0]\n",
616 | " iterations = epoch[1]\n",
617 | " learning_rate = epoch[2]\n",
618 | " ebnodb = epoch[3]\n",
619 | " apriori = epoch[4]\n",
620 | " train_sigma = ebnodb2std(ebnodb,code.code_rate)\n",
621 | " # logging\n",
622 | " logging_interval = int(iterations/10)\n",
623 | " logging_it_counter = 0\n",
624 | " logging_interval_loss = 0.0\n",
625 | " logging_interval_ber = np.zeros([sequence_length])\n",
626 | " \n",
627 | " print(\"\\nTraining Epoch - Batch Size: %d, Iterations: %d, Learning Rate: %.4f, EbNodB %.1f (std: %.3f), P_apriori %.2f\" % (batch_size,iterations,learning_rate,ebnodb,train_sigma,apriori))\n",
628 | " # training\n",
629 | " for x_feed,u_feed in generator(batch_size,iterations,gradient_depth,sequence_length,additional_input,decision_offset,train_sigma,apriori):\n",
630 | " _,curr_loss,curr_ber = train_step(x_feed,u_feed,learning_rate)\n",
631 | " # logging\n",
632 | " logging_interval_loss += curr_loss\n",
633 | " logging_interval_ber += curr_ber\n",
634 | " logging_it_counter += 1\n",
635 | "\n",
636 | " if logging_it_counter%logging_interval == 0:\n",
637 | " #if early_stopping and previous_logging_interval_loss < logging_interval_loss:\n",
638 | " # print(\"\")\n",
639 | "\n",
640 | " print(\" Iteration %d to %d - Avg. Loss: %.3E Avg. BER: %.3E Min. @ BER[%d]=%.3E\" % (logging_it_counter-logging_interval,\n",
641 | " logging_it_counter,\n",
642 | " logging_interval_loss/logging_interval,\n",
643 | " np.mean(logging_interval_ber/logging_interval),\n",
644 | " np.argmin(logging_interval_ber/logging_interval),\n",
645 | " np.min(logging_interval_ber/logging_interval)))\n",
646 | " logging_interval_loss = 0.0\n",
647 | " logging_interval_ber = 0.0\n",
648 | " \n",
649 | " # save weights\n",
650 | " trained_stages += 1\n",
651 | " saver.save(sess,saver_path,global_step=trained_stages)\n",
652 | " print(\" -> saved as training stage: %s-%d\" % (model_name,trained_stages))\n",
653 | "\n",
654 | "# plots a BER curve\n",
655 | "def plot_bler_vs_ebnodb(ebnodb, ber, baseline=False):\n",
656 | " image = plt.figure(figsize=(12,6))\n",
657 | " plt.plot(ebnodb, ber, '-o')\n",
658 | " if baseline:\n",
659 | " plt.plot(ebnodb, baseline_ber, '--')\n",
660 | " plt.legend(['RNN Decoder', 'Viterbi Decoder']);\n",
661 | " plt.yscale('log')\n",
662 | " plt.xlabel('EbNo (dB)', fontsize=16)\n",
663 | " plt.ylabel('Bit-error rate', fontsize=16)"
664 | ],
665 | "execution_count": 12,
666 | "outputs": []
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "metadata": {
671 | "id": "2nLozXmgnIDd"
672 | },
673 | "source": [
674 | "## Training\n",
675 | "Let's define training parameters and begin with the training.\n",
676 | "\n",
677 | "Notice that we use so-called a priori ramp-up training [1].\n",
678 | "That is, we first set the a priori probability of either ones or zeros in the bit vector u to a small value and then, in subsequent training epochs, raise it up to 0.5, where ones and zeros are equally likely again."
679 | ]
680 | },
681 | {
682 | "cell_type": "code",
683 | "metadata": {
684 | "id": "DYsTNnVvnIDg"
685 | },
686 | "source": [
687 | "train_snr_db = 1.5\n",
688 | "training_params = {\n",
689 | " 'learning' : [ #batch_size, iterations, learning_rate, training_ebnodb, apriori\n",
690 | " [100, 1000, 0.001, train_snr_db, 0.01],\n",
691 | " [100, 1000, 0.001, train_snr_db, 0.1],\n",
692 | " [100, 1000, 0.001, train_snr_db, 0.2],\n",
693 | " [100, 1000, 0.001, train_snr_db, 0.3],\n",
694 | " [100, 1000, 0.001, train_snr_db, 0.4],\n",
695 | " [100, 500000, 0.0001, train_snr_db, 0.5],\n",
696 | " [500, 100000, 0.0001, train_snr_db, 0.5],\n",
697 | " [1000, 50000, 0.0001, train_snr_db, 0.5],\n",
698 | " [2000, 50000, 0.0001, train_snr_db, 0.5],\n",
699 | " ]\n",
700 | "}"
701 | ],
702 | "execution_count": 13,
703 | "outputs": []
704 | },
705 | {
706 | "cell_type": "code",
707 | "metadata": {
708 | "id": "60hOY8pGnIDp",
709 | "colab": {
710 | "base_uri": "https://localhost:8080/",
711 | "height": 1000
712 | },
713 | "outputId": "479f02b3-0c1c-4ad7-8ed0-1974759e441e"
714 | },
715 | "source": [
716 | "train(training_params)"
717 | ],
718 | "execution_count": 14,
719 | "outputs": [
720 | {
721 | "output_type": "stream",
722 | "name": "stdout",
723 | "text": [
724 | "\n",
725 | "Training Epoch - Batch Size: 100, Iterations: 1000, Learning Rate: 0.0010, EbNodB 1.5 (std: 0.841), P_apriori 0.01\n",
726 | " Iteration 0 to 100 - Avg. Loss: 2.168E-01 Avg. BER: 1.960E-02 Min. @ BER[14]=1.820E-02\n",
727 | " Iteration 100 to 200 - Avg. Loss: 4.481E-02 Avg. BER: 1.110E-02 Min. @ BER[3]=1.020E-02\n",
728 | " Iteration 200 to 300 - Avg. Loss: 2.168E-02 Avg. BER: 9.087E-03 Min. @ BER[0]=8.700E-03\n",
729 | " Iteration 300 to 400 - Avg. Loss: 2.077E-02 Avg. BER: 1.069E-02 Min. @ BER[6]=1.040E-02\n",
730 | " Iteration 400 to 500 - Avg. Loss: 1.726E-02 Avg. BER: 8.720E-03 Min. @ BER[13]=8.200E-03\n",
731 | " Iteration 500 to 600 - Avg. Loss: 1.409E-02 Avg. BER: 5.440E-03 Min. @ BER[7]=3.600E-03\n",
732 | " Iteration 600 to 700 - Avg. Loss: 7.269E-03 Avg. BER: 2.087E-03 Min. @ BER[5]=1.600E-03\n",
733 | " Iteration 700 to 800 - Avg. Loss: 6.081E-03 Avg. BER: 1.820E-03 Min. @ BER[2]=1.300E-03\n",
734 | " Iteration 800 to 900 - Avg. Loss: 4.063E-03 Avg. BER: 1.260E-03 Min. @ BER[6]=8.000E-04\n",
735 | " Iteration 900 to 1000 - Avg. Loss: 1.967E-03 Avg. BER: 6.600E-04 Min. @ BER[10]=3.000E-04\n",
736 | " -> saved as training stage: 0o1330o171m6_Model-1\n",
737 | "\n",
738 | "Training Epoch - Batch Size: 100, Iterations: 1000, Learning Rate: 0.0010, EbNodB 1.5 (std: 0.841), P_apriori 0.10\n",
739 | " Iteration 0 to 100 - Avg. Loss: 1.519E-01 Avg. BER: 6.061E-02 Min. @ BER[4]=5.730E-02\n",
740 | " Iteration 100 to 200 - Avg. Loss: 1.139E-01 Avg. BER: 4.728E-02 Min. @ BER[2]=4.320E-02\n",
741 | " Iteration 200 to 300 - Avg. Loss: 9.829E-02 Avg. BER: 4.007E-02 Min. @ BER[0]=3.650E-02\n",
742 | " Iteration 300 to 400 - Avg. Loss: 7.084E-02 Avg. BER: 2.793E-02 Min. @ BER[1]=2.380E-02\n",
743 | " Iteration 400 to 500 - Avg. Loss: 5.512E-02 Avg. BER: 2.075E-02 Min. @ BER[5]=1.790E-02\n",
744 | " Iteration 500 to 600 - Avg. Loss: 4.875E-02 Avg. BER: 1.853E-02 Min. @ BER[0]=1.540E-02\n",
745 | " Iteration 600 to 700 - Avg. Loss: 5.106E-02 Avg. BER: 1.949E-02 Min. @ BER[0]=1.540E-02\n",
746 | " Iteration 700 to 800 - Avg. Loss: 4.781E-02 Avg. BER: 1.792E-02 Min. @ BER[0]=1.440E-02\n",
747 | " Iteration 800 to 900 - Avg. Loss: 4.288E-02 Avg. BER: 1.667E-02 Min. @ BER[1]=1.220E-02\n",
748 | " Iteration 900 to 1000 - Avg. Loss: 4.000E-02 Avg. BER: 1.522E-02 Min. @ BER[3]=1.200E-02\n",
749 | " -> saved as training stage: 0o1330o171m6_Model-2\n",
750 | "\n",
751 | "Training Epoch - Batch Size: 100, Iterations: 1000, Learning Rate: 0.0010, EbNodB 1.5 (std: 0.841), P_apriori 0.20\n",
752 | " Iteration 0 to 100 - Avg. Loss: 1.995E-01 Avg. BER: 8.715E-02 Min. @ BER[0]=7.990E-02\n",
753 | " Iteration 100 to 200 - Avg. Loss: 1.838E-01 Avg. BER: 7.964E-02 Min. @ BER[0]=7.010E-02\n",
754 | " Iteration 200 to 300 - Avg. Loss: 1.764E-01 Avg. BER: 7.613E-02 Min. @ BER[0]=6.960E-02\n",
755 | " Iteration 300 to 400 - Avg. Loss: 1.582E-01 Avg. BER: 6.832E-02 Min. @ BER[3]=6.040E-02\n",
756 | " Iteration 400 to 500 - Avg. Loss: 1.508E-01 Avg. BER: 6.343E-02 Min. @ BER[0]=5.520E-02\n",
757 | " Iteration 500 to 600 - Avg. Loss: 1.463E-01 Avg. BER: 6.268E-02 Min. @ BER[2]=5.650E-02\n",
758 | " Iteration 600 to 700 - Avg. Loss: 1.455E-01 Avg. BER: 6.221E-02 Min. @ BER[0]=5.070E-02\n",
759 | " Iteration 700 to 800 - Avg. Loss: 1.361E-01 Avg. BER: 5.781E-02 Min. @ BER[0]=5.190E-02\n",
760 | " Iteration 800 to 900 - Avg. Loss: 1.379E-01 Avg. BER: 5.850E-02 Min. @ BER[0]=5.270E-02\n",
761 | " Iteration 900 to 1000 - Avg. Loss: 1.389E-01 Avg. BER: 5.840E-02 Min. @ BER[2]=5.300E-02\n",
762 | " -> saved as training stage: 0o1330o171m6_Model-3\n",
763 | "\n",
764 | "Training Epoch - Batch Size: 100, Iterations: 1000, Learning Rate: 0.0010, EbNodB 1.5 (std: 0.841), P_apriori 0.30\n",
765 | " Iteration 0 to 100 - Avg. Loss: 3.464E-01 Avg. BER: 1.693E-01 Min. @ BER[0]=1.521E-01\n",
766 | " Iteration 100 to 200 - Avg. Loss: 3.178E-01 Avg. BER: 1.538E-01 Min. @ BER[0]=1.445E-01\n",
767 | " Iteration 200 to 300 - Avg. Loss: 3.141E-01 Avg. BER: 1.505E-01 Min. @ BER[0]=1.394E-01\n",
768 | " Iteration 300 to 400 - Avg. Loss: 3.154E-01 Avg. BER: 1.534E-01 Min. @ BER[1]=1.452E-01\n",
769 | " Iteration 400 to 500 - Avg. Loss: 3.079E-01 Avg. BER: 1.483E-01 Min. @ BER[3]=1.389E-01\n",
770 | " Iteration 500 to 600 - Avg. Loss: 3.123E-01 Avg. BER: 1.508E-01 Min. @ BER[0]=1.375E-01\n",
771 | " Iteration 600 to 700 - Avg. Loss: 2.892E-01 Avg. BER: 1.386E-01 Min. @ BER[1]=1.267E-01\n",
772 | " Iteration 700 to 800 - Avg. Loss: 2.950E-01 Avg. BER: 1.413E-01 Min. @ BER[2]=1.317E-01\n",
773 | " Iteration 800 to 900 - Avg. Loss: 2.759E-01 Avg. BER: 1.323E-01 Min. @ BER[1]=1.236E-01\n",
774 | " Iteration 900 to 1000 - Avg. Loss: 2.751E-01 Avg. BER: 1.343E-01 Min. @ BER[0]=1.246E-01\n",
775 | " -> saved as training stage: 0o1330o171m6_Model-4\n",
776 | "\n",
777 | "Training Epoch - Batch Size: 100, Iterations: 1000, Learning Rate: 0.0010, EbNodB 1.5 (std: 0.841), P_apriori 0.40\n",
778 | " Iteration 0 to 100 - Avg. Loss: 5.074E-01 Avg. BER: 2.843E-01 Min. @ BER[1]=2.710E-01\n",
779 | " Iteration 100 to 200 - Avg. Loss: 4.780E-01 Avg. BER: 2.661E-01 Min. @ BER[1]=2.488E-01\n",
780 | " Iteration 200 to 300 - Avg. Loss: 5.105E-01 Avg. BER: 2.867E-01 Min. @ BER[0]=2.677E-01\n",
781 | " Iteration 300 to 400 - Avg. Loss: 4.990E-01 Avg. BER: 2.801E-01 Min. @ BER[0]=2.573E-01\n",
782 | " Iteration 400 to 500 - Avg. Loss: 4.972E-01 Avg. BER: 2.794E-01 Min. @ BER[0]=2.627E-01\n",
783 | " Iteration 500 to 600 - Avg. Loss: 4.795E-01 Avg. BER: 2.658E-01 Min. @ BER[0]=2.501E-01\n",
784 | " Iteration 600 to 700 - Avg. Loss: 4.837E-01 Avg. BER: 2.711E-01 Min. @ BER[2]=2.534E-01\n",
785 | " Iteration 700 to 800 - Avg. Loss: 4.583E-01 Avg. BER: 2.494E-01 Min. @ BER[0]=2.356E-01\n",
786 | " Iteration 800 to 900 - Avg. Loss: 4.650E-01 Avg. BER: 2.541E-01 Min. @ BER[1]=2.361E-01\n",
787 | " Iteration 900 to 1000 - Avg. Loss: 4.607E-01 Avg. BER: 2.525E-01 Min. @ BER[1]=2.432E-01\n",
788 | " -> saved as training stage: 0o1330o171m6_Model-5\n",
789 | "\n",
790 | "Training Epoch - Batch Size: 100, Iterations: 500000, Learning Rate: 0.0001, EbNodB 1.5 (std: 0.841), P_apriori 0.50\n"
791 | ]
792 | },
793 | {
794 | "output_type": "error",
795 | "ename": "KeyboardInterrupt",
796 | "evalue": "ignored",
797 | "traceback": [
798 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
799 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
800 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraining_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
801 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(training_params)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;31m# training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mu_feed\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mgenerator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0miterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mgradient_depth\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msequence_length\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0madditional_input\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdecision_offset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_sigma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mapriori\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcurr_loss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcurr_ber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mu_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;31m# logging\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mlogging_interval_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mcurr_loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
802 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(x_feed, u_feed, lr_feed)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;31m# runs a single training step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mu_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlr_feed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mber\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgen_feed_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mu_feed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlr_feed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;31m# runs a training set according to training_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
803 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 955\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 956\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
804 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1178\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1179\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1180\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
805 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1359\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1360\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1361\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
806 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1364\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1366\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1367\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
807 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1348\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1349\u001b[0m return self._call_tf_sessionrun(options, feed_dict, fetch_list,\n\u001b[0;32m-> 1350\u001b[0;31m target_list, run_metadata)\n\u001b[0m\u001b[1;32m 1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1352\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
808 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m 1441\u001b[0m return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,\n\u001b[1;32m 1442\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
809 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
810 | ]
811 | }
812 | ]
813 | },
814 | {
815 | "cell_type": "markdown",
816 | "metadata": {
817 | "id": "S2jKzN85nIDx"
818 | },
819 | "source": [
820 | "## Restoring already trained models\n",
821 | "Since we used the TensorFlow saver after each training epoch, we can load an already trained model and then evaluate it."
822 | ]
823 | },
824 | {
825 | "cell_type": "code",
826 | "metadata": {
827 | "id": "R5_apwuCnID1"
828 | },
829 | "source": [
830 | "#saver.restore(sess,\"%s-%d\" % (saver_path,9))"
831 | ],
832 | "execution_count": 15,
833 | "outputs": []
834 | },
835 | {
836 | "cell_type": "markdown",
837 | "metadata": {
838 | "id": "ODSuWtQWnID9"
839 | },
840 | "source": [
841 | "## Evaluation: Monte Carlo BER simulation"
842 | ]
843 | },
844 | {
845 | "cell_type": "code",
846 | "metadata": {
847 | "id": "0ulrSB8gnIEA",
848 | "colab": {
849 | "base_uri": "https://localhost:8080/"
850 | },
851 | "outputId": "a79ef38a-81a8-49d9-c64b-e24c9a5cb280"
852 | },
853 | "source": [
854 | "test_parameters = {\n",
855 | " 'batch_size' : 2000,\n",
856 | " 'iterations' : 100,\n",
857 | " 'ebnodb' : np.arange(0,5.6,0.5)\n",
858 | "}\n",
859 | "sim_ber = test(test_parameters,plot=False,plot_baseline=False)"
860 | ],
861 | "execution_count": 16,
862 | "outputs": [
863 | {
864 | "output_type": "stream",
865 | "name": "stdout",
866 | "text": [
867 | "SNR: 0.0\n",
868 | "BER: [0.45576 0.457 0.4558 0.455955 0.45695 0.45717 0.45542 0.45647\n",
869 | " 0.457045 0.455375 0.45677 0.45647 0.45766 0.45634 0.45731 ]\n",
870 | "SNR: 0.5\n",
871 | "BER: [0.44117 0.44179 0.43963 0.43956 0.43945 0.43944 0.43984 0.439545\n",
872 | " 0.44075 0.44156 0.4421 0.44107 0.44037 0.44228 0.44171 ]\n",
873 | "SNR: 1.0\n",
874 | "BER: [0.42144 0.4202 0.42042 0.421925 0.421415 0.42074 0.42227 0.42213\n",
875 | " 0.422155 0.42186 0.42179 0.422345 0.42302 0.423785 0.42311 ]\n",
876 | "SNR: 1.5\n",
877 | "BER: [0.398795 0.400125 0.39785 0.397695 0.39824 0.399015 0.39951 0.398105\n",
878 | " 0.40013 0.39963 0.398385 0.400075 0.40104 0.39912 0.39943 ]\n",
879 | "SNR: 2.0\n",
880 | "BER: [0.371985 0.37340001 0.37313 0.373795 0.373295 0.373975\n",
881 | " 0.374525 0.37428 0.37472 0.37539 0.375845 0.374815\n",
882 | " 0.376405 0.375295 0.3757 ]\n",
883 | "SNR: 2.5\n",
884 | "BER: [0.34763 0.34754 0.34817 0.347765 0.348205 0.34797 0.34716 0.346935\n",
885 | " 0.348895 0.348965 0.350425 0.349585 0.350415 0.35102 0.3515 ]\n",
886 | "SNR: 3.0\n",
887 | "BER: [0.318415 0.31746 0.317815 0.31865 0.31829 0.318725 0.319555 0.319345\n",
888 | " 0.31983 0.319635 0.32059 0.32066 0.322305 0.321975 0.322405]\n",
889 | "SNR: 3.5\n",
890 | "BER: [0.292285 0.2923 0.2931 0.293585 0.292635 0.29359 0.293415 0.29419\n",
891 | " 0.29523 0.295885 0.29455 0.29592 0.297125 0.298355 0.298075]\n",
892 | "SNR: 4.0\n",
893 | "BER: [0.27332 0.272845 0.27345 0.27313 0.27271 0.27439 0.2748 0.273955\n",
894 | " 0.27474 0.27544 0.27759 0.27654 0.277185 0.278665 0.280505]\n",
895 | "SNR: 4.5\n",
896 | "BER: [0.25164 0.25231 0.251455 0.25138 0.252965 0.253595 0.25241 0.252215\n",
897 | " 0.25412 0.25438 0.255165 0.25613 0.256195 0.257295 0.258125]\n",
898 | "SNR: 5.0\n",
899 | "BER: [0.23532 0.236625 0.235905 0.23755 0.2362 0.23742 0.237885 0.237495\n",
900 | " 0.23814 0.23905 0.239465 0.240705 0.24107 0.242395 0.242365]\n",
901 | "SNR: 5.5\n",
902 | "BER: [0.21717 0.21814 0.21866 0.21848 0.217555 0.218565\n",
903 | " 0.219365 0.219425 0.22072499 0.221625 0.222785 0.22249\n",
904 | " 0.22271 0.22363 0.225355 ]\n",
905 | "Final BER: [[0.45576 0.457 0.4558 0.455955 0.45695 0.45717\n",
906 | " 0.45542 0.45647 0.457045 0.455375 0.45677 0.45647\n",
907 | " 0.45766 0.45634 0.45731 ]\n",
908 | " [0.44117 0.44179 0.43963 0.43956 0.43945 0.43944\n",
909 | " 0.43984 0.439545 0.44075 0.44156 0.4421 0.44107\n",
910 | " 0.44037 0.44228 0.44171 ]\n",
911 | " [0.42144 0.4202 0.42042 0.421925 0.421415 0.42074\n",
912 | " 0.42227 0.42213 0.422155 0.42186 0.42179 0.422345\n",
913 | " 0.42302 0.423785 0.42311 ]\n",
914 | " [0.398795 0.400125 0.39785 0.397695 0.39824 0.399015\n",
915 | " 0.39951 0.398105 0.40013 0.39963 0.398385 0.400075\n",
916 | " 0.40104 0.39912 0.39943 ]\n",
917 | " [0.371985 0.37340001 0.37313 0.373795 0.373295 0.373975\n",
918 | " 0.374525 0.37428 0.37472 0.37539 0.375845 0.374815\n",
919 | " 0.376405 0.375295 0.3757 ]\n",
920 | " [0.34763 0.34754 0.34817 0.347765 0.348205 0.34797\n",
921 | " 0.34716 0.346935 0.348895 0.348965 0.350425 0.349585\n",
922 | " 0.350415 0.35102 0.3515 ]\n",
923 | " [0.318415 0.31746 0.317815 0.31865 0.31829 0.318725\n",
924 | " 0.319555 0.319345 0.31983 0.319635 0.32059 0.32066\n",
925 | " 0.322305 0.321975 0.322405 ]\n",
926 | " [0.292285 0.2923 0.2931 0.293585 0.292635 0.29359\n",
927 | " 0.293415 0.29419 0.29523 0.295885 0.29455 0.29592\n",
928 | " 0.297125 0.298355 0.298075 ]\n",
929 | " [0.27332 0.272845 0.27345 0.27313 0.27271 0.27439\n",
930 | " 0.2748 0.273955 0.27474 0.27544 0.27759 0.27654\n",
931 | " 0.277185 0.278665 0.280505 ]\n",
932 | " [0.25164 0.25231 0.251455 0.25138 0.252965 0.253595\n",
933 | " 0.25241 0.252215 0.25412 0.25438 0.255165 0.25613\n",
934 | " 0.256195 0.257295 0.258125 ]\n",
935 | " [0.23532 0.236625 0.235905 0.23755 0.2362 0.23742\n",
936 | " 0.237885 0.237495 0.23814 0.23905 0.239465 0.240705\n",
937 | " 0.24107 0.242395 0.242365 ]\n",
938 | " [0.21717 0.21814 0.21866 0.21848 0.217555 0.218565\n",
939 | " 0.219365 0.219425 0.22072499 0.221625 0.222785 0.22249\n",
940 | " 0.22271 0.22363 0.225355 ]]\n"
941 | ]
942 | }
943 | ]
944 | },
945 | {
946 | "cell_type": "code",
947 | "metadata": {
948 | "id": "qDuKGcjUnIEI",
949 | "colab": {
950 | "base_uri": "https://localhost:8080/",
951 | "height": 393
952 | },
953 | "outputId": "c48acfce-b0ec-41c0-f9e7-f9dd05b5dacf"
954 | },
955 | "source": [
956 | "plot_bler_vs_ebnodb(test_parameters['ebnodb'], sim_ber[:,0])\n",
957 | "plt.plot(np.arange(0,5.6,0.5),code.viterbi_reference)\n",
958 | "plt.ylim(1e-6,0.5)\n",
959 | "plt.grid()\n",
960 | "plt.legend(['RNN-Decoder','Viterbi reference'])\n",
961 | "plt.show();"
962 | ],
963 | "execution_count": 17,
964 | "outputs": [
965 | {
966 | "output_type": "display_data",
967 | "data": {
968 | "text/plain": [
969 | ""
970 | ],
971 | "image/png": "\n"
972 | },
973 | "metadata": {
974 | "needs_background": "light"
975 | }
976 | }
977 | ]
978 | },
979 | {
980 | "cell_type": "code",
981 | "metadata": {
982 | "id": "TWgUOTgynIEQ"
983 | },
984 | "source": [
985 | ""
986 | ],
987 | "execution_count": null,
988 | "outputs": []
989 | }
990 | ]
991 | }
--------------------------------------------------------------------------------
/coding.py:
--------------------------------------------------------------------------------
1 | '''
2 | These scripts require the python library CommPy!
3 |
4 | Install:
5 |
6 | $ git clone https://github.com/veeresht/CommPy.git
7 | $ cd CommPy
8 | $ python3 setup.py install
9 | '''
10 |
11 |
12 |
13 | import numpy as np
14 | from commpy.channelcoding import convcode as cc
15 |
class code:
    """Convolutional code described by the 1x2 generator matrix ``[[d1, d2]]``.

    The ``commpy_*`` methods wrap CommPy's trellis encoder and Viterbi
    decoder. ``encode_sequence`` / ``encode_batch`` are a NumPy-only encoder
    that convolves the zero-interleaved input with the code's impulse
    response and reduces the result modulo 2.
    """

    def __init__(self,d1,d2,m):
        """Build the CommPy trellis and derive the encoder's impulse response.

        Args:
            d1: First generator polynomial, entry [0, 0] of G(D).
            d2: Second generator polynomial, entry [0, 1] of G(D).
            m: Number of delay elements in the convolutional encoder.
        """
        self.d1 = d1
        self.d2 = d2
        self.m = m  # Number of delay elements in the convolutional encoder
        # G(D) corresponding to the convolutional encoder
        self.generator_matrixNSC = np.array([[self.d1, self.d2]])
        # Create trellis data structure
        self.trellisNSC = cc.Trellis(np.array([self.m]), self.generator_matrixNSC)
        self.tb_depth = 5*(self.m + 1)  # Traceback depth of the decoder
        self.code_rate = self.trellisNSC.k / self.trellisNSC.n  # the code rate
        # Impulse response: the unterminated encoding of a single 1 followed
        # by m zeros, stored as int8 for use by the NumPy-only encoder.
        impulse = np.concatenate(
            [np.array([1], dtype=np.int8), np.zeros([self.m], dtype=np.int8)],
            axis=0)
        self.impulse_response = self.commpy_encode_sequence(impulse).astype(np.int8)

    @staticmethod
    def _stack_rows(rows):
        """Stack a list of equally long 1-D arrays into one 2-D int8 array.

        An empty list yields an array of shape (0, 0), because there is no
        row from which the output width could be inferred.
        """
        if not rows:
            return np.empty((0, 0), dtype=np.int8)
        return np.stack(rows).astype(np.int8)

    def commpy_encode_sequence(self,u,terminate=False):
        """Encode one bit sequence with CommPy's ``conv_encode``.

        Args:
            u: 1-D sequence of information bits.
            terminate: If True, return CommPy's output unchanged. If False,
                drop the final ``n * total_memory`` code symbols
                (presumably the tail CommPy emits for zero-termination --
                confirm against the CommPy version in use).

        Returns:
            The encoded sequence as returned (and optionally trimmed) from
            ``cc.conv_encode``.
        """
        encoded = cc.conv_encode(u, self.trellisNSC, code_type = 'default')
        if terminate:
            return encoded
        # Use an explicit end index: a negative-slice ``[:-tail]`` would
        # return an empty array when the tail length is 0.
        tail = self.trellisNSC.n * self.trellisNSC.total_memory
        return encoded[:len(encoded) - tail]

    def commpy_encode_batch(self,u,terminate=False):
        """Encode every row of ``u`` with ``commpy_encode_sequence``.

        Returns:
            2-D int8 array with one encoded row per input row; shape (0, 0)
            for an empty batch.
        """
        return self._stack_rows(
            [self.commpy_encode_sequence(row,terminate) for row in u])

    def commpy_decode_sequence(self,y):
        """Viterbi-decode one received sequence with unquantized inputs.

        Uses the trellis and traceback depth stored on the instance.
        """
        return cc.viterbi_decode(y, self.trellisNSC, self.tb_depth,'unquantized')

    def commpy_decode_batch(self,y):
        """Viterbi-decode every row of ``y`` with ``commpy_decode_sequence``.

        Returns:
            2-D int8 array with one decoded row per input row; shape (0, 0)
            for an empty batch.
        """
        return self._stack_rows([self.commpy_decode_sequence(row) for row in y])

    def zero_pad(self,u):
        """Interleave ``u`` with zeros: ``[u0, 0, u1, 0, ...]``.

        Each element of ``u`` is paired with a zero along a new second axis
        and the result is flattened to 1-D.
        """
        return np.reshape(np.stack([u,np.zeros_like(u)],axis=1),(-1,))

    def encode_sequence(self,u,terminate=False):
        """Encode one bit sequence using NumPy only.

        The zero-interleaved input is convolved with ``impulse_response``
        and reduced modulo 2.

        Args:
            u: 1-D array of information bits.
            terminate: If True, keep the whole convolution except its last
                sample. If False, keep only the first ``2 * len(u)`` samples.

        Returns:
            1-D array of code bits.
        """
        full = np.convolve(self.zero_pad(u),self.impulse_response,mode='full')
        end = len(full) - 1 if terminate else len(u)*2
        return full[:end] % 2

    def encode_batch(self,u,terminate=False):
        """Encode every row of ``u`` with ``encode_sequence``.

        Returns:
            2-D int8 array with one encoded row per input row; shape (0, 0)
            for an empty batch.
        """
        return self._stack_rows(
            [self.encode_sequence(row,terminate) for row in u])
--------------------------------------------------------------------------------