├── Datasets
│   ├── guangzhou_speed.npy
│   └── portland_volume.npy
├── Examples
│   ├── Toy examples.ipynb
│   └── test.ipynb
├── Figures
│   ├── algorithm1.png
│   ├── algorithm2.png
│   ├── missing patterns.png
│   ├── norm_compare.png
│   └── objective.png
├── Helper.py
├── Imputer.py
├── LICENSE
└── README.md
/Datasets/guangzhou_speed.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Datasets/guangzhou_speed.npy
--------------------------------------------------------------------------------
/Datasets/portland_volume.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Datasets/portland_volume.npy
--------------------------------------------------------------------------------
/Examples/Toy examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import numpy as np\n",
10 |     "import pandas as pd\n",
11 |     "import matplotlib.pyplot as plt\n",
12 |     "import IPython\n",
13 |     "import time\n",
14 |     "import random\n",
15 |     "import Helper as helper\n",
16 |     "from Imputer import LRTC_TSpN\n",
17 |     "\n",
18 |     "plt.rcParams['figure.figsize'] = (10,6)\n",
19 |     "%matplotlib inline\n",
20 |     "IPython.display.set_matplotlib_formats('svg')"
21 |    ]
22 |   },
23 |   {
24 |    "cell_type": "markdown",
25 |    "metadata": {},
26 |    "source": [
27 |     "This notebook gives a toy example of how to apply LRTC-TSpN (low-rank tensor completion based on the truncated tensor Schatten p-norm) to two small traffic datasets. The model can be applied to any spatiotemporal traffic data. For a more detailed discussion of LRTC-TSpN, please see [1] and our GitHub repository [**tensorlib - GitHub**](https://github.com/tongnie/tensorlib).\n",
28 |     "\n",
31 |     "[1] Tong Nie, Guoyang Qin, Jian Sun (2022). Truncated tensor Schatten p-norm based approach for spatiotemporal traffic data imputation with complicated missing patterns. arXiv:2205.09390 [PDF]"
34 |    ]
35 |   },
36 |   {
37 |    "cell_type": "markdown",
38 |    "metadata": {},
39 |    "source": [
40 |     "## Preparation\n",
41 |     "### Third-order Tensor Structure\n",
42 |     "\n",
43 |     "We organize the multivariate traffic time series as a third-order\n",
44 |     "tensor, i.e. $time~intervals×locations~(sensors)×days$. This three-dimensional structure\n",
45 |     "captures the integrated spatial-temporal information simultaneously, which makes the imputation of missing values more efficient.\n",
46 |     "\n",
47 |     "### Spatial-temporal traffic sensor data\n",
48 |     "\n",
49 |     "In this notebook, we conduct data imputation on the following two small subsets of traffic speed and volume data; the original datasets can be found in our GitHub repository [**tensorlib - GitHub**](https://github.com/tongnie/tensorlib).\n",
50 |     "- **Guangzhou-small:** An urban traffic speed dataset collected from 214 road segments in Guangzhou, China, over two months (61 days from August 1, 2016 to September 30, 2016) at a 10-minute resolution. We only use the speed data of the first 50 locations and the first 15 days, so the size is (144 × 50 × 15). \n",
51 |     "- **Portland-small:** A link volume dataset collected from highways in Portland, containing 1156 loop detectors over one month at a 15-minute resolution. Volume data of the first 80 locations and the first 15 days are used, so the size is (96 × 80 × 15).\n",
52 |     "\n",
53 |     "### Complicated missing patterns\n",
54 |     "Besides the element-wise random missing case, we define three structured fiber mode-$n$ missing scenarios, which are generated through the two-by-two combinations of tensor mode-$n$ fibers: \n",
55 |     "- **'Intervals' mode fiber-like missing (FM-0)** illustrates a temporal missing pattern, caused by adverse weather, breakdown of wireless connections, or apparatus maintenance; \n",
56 |     "- **'Locations' mode fiber-like missing (FM-1)** denotes a spatial missing pattern, explained by a loss of electricity for successive sensors or a malfunction of the Internet Data Center; \n",
57 |     "- **'Days' mode fiber-like missing (FM-2)** describes a mixed spatial-temporal missing situation in which specific sensors are offline (do not operate) at regular time intervals every day."
58 |    ]
59 |   },
60 |   {
61 |    "cell_type": "markdown",
62 |    "metadata": {},
63 |    "source": [
64 |     "Load the Guangzhou speed dataset"
65 |    ]
66 |   },
67 |   {
68 |    "cell_type": "code",
69 |    "execution_count": 2,
70 |    "metadata": {},
71 |    "outputs": [
72 |     {
73 |      "name": "stdout",
74 |      "output_type": "stream",
75 |      "text": [
76 |       "Random missing rate of tensor is:51.00%\n"
77 |      ]
78 |     }
79 |    ],
80 |    "source": [
81 |     "#Random missing pattern\n",
82 |     "speed_tensor = np.load('../Datasets/guangzhou_speed.npy')\n",
83 |     "\n",
84 |     "random.seed(123)\n",
85 |     "speed_tensor_lost = helper.generate_tensor_random_missing(speed_tensor,lost_rate=0.5)\n",
86 |     "tensor_miss_rate = helper.get_missing_rate(speed_tensor_lost)\n",
87 |     "print(f'Random missing rate of tensor is:{100*tensor_miss_rate:.2f}%')"
88 |    ]
89 |   },
90 |   {
91 |    "cell_type": "markdown",
92 |    "metadata": {},
93 |    "source": [
94 |     "Generate the three fiber-mode missing patterns"
95 |    ]
96 |   },
97 |   {
98 |    "cell_type": "code",
99 |    "execution_count": 3,
100 |    "metadata": {},
101 |    "outputs": [
102 |     {
103 |      "name": "stdout",
104 |      "output_type": "stream",
105 |      "text": [
106 |       "fiber-mode0 missing rate of tensor is:50.93%\n"
107 |      ]
108 |     }
109 |    ],
110 |    "source": [
111 |     "#Fiber mode-0 missing\n",
112 |     "random.seed(123)\n",
113 |     "speed_tensor_lost_fiber0 = helper.generate_fiber_missing(speed_tensor,lost_rate=0.5,mode=0)"
114 |    ]
115 |   },
116 |   {
117 |    "cell_type": "code",
118 |    "execution_count": 4,
119 |    "metadata": {},
120 |    "outputs": [
121 |     {
122 |      "name": "stdout",
123 |      "output_type": "stream",
124 |      "text": [
125 |       "fiber-mode1 missing rate of tensor is:51.00%\n"
126 |      ]
127 |     }
128 |    ],
129 |    "source": [
130 |     "#Fiber mode-1 missing\n",
131 |     "random.seed(123)\n",
132 |     "speed_tensor_lost_fiber1 = helper.generate_fiber_missing(speed_tensor,lost_rate=0.5,mode=1)"
133 |    ]
134 |   },
135 |   {
136 |    "cell_type": "code",
137 |    "execution_count": 5,
138 |    "metadata": {},
139 |    "outputs": [
140 |     {
141 |      "name": "stdout",
142 |      "output_type": "stream",
143 |      "text": [
144 |       "fiber-mode2 missing rate of tensor is:50.89%\n"
145 |      ]
146 |     }
147 |    ],
148 |    "source": [
149 |     "#Fiber mode-2 missing\n",
150 |     "random.seed(123)\n",
151 |     "speed_tensor_lost_fiber2 = helper.generate_fiber_missing(speed_tensor,lost_rate=0.5,mode=2)"
152 |    ]
153 |   },
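  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the element-wise random-missing tensor generated above can be imputed in exactly the same way as the fiber-missing cases below. The next cell is an illustrative sketch that was not part of the recorded run (its output is left empty); the parameter values mirror the cells below and can be tuned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative sketch: impute the random-missing tensor (not part of the recorded run)\n",
    "theta = 0.1\n",
    "it,X_hat,MAE_List_admm,RMSE_List_admm,_ = LRTC_TSpN(speed_tensor,speed_tensor_lost,theta=theta,p=0.7,beta=1e-5,incre=0.1,maxiter=200,show_plot=True)"
   ]
  },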
154 |   {
155 |    "cell_type": "code",
156 |    "execution_count": 6,
157 |    "metadata": {},
158 |    "outputs": [
159 |     {
160 |      "name": "stdout",
161 |      "output_type": "stream",
162 |      "text": [
163 |       "TSp_ADMM Iteration: \n",
164 |       " Processing loop 52\n",
165 |       " total iterations = 52 error=0.0008558335667463486\n",
166 |       "LRTC-TSpN imputation MAE = 2.832\n",
167 |       "LRTC-TSpN imputation RMSE = 4.261\n"
168 |      ]
169 |     },
170 |     {
171 |      "data": {
    |       "text/plain": [
    |        "<Figure size 360x216 with 1 Axes>  (inline SVG plot of the LRTC-TSpN convergence curve; raw SVG data omitted)"
1072 |       ]
1073 |      },
1074 |      "metadata": {
1075 |       "needs_background": "light"
1076 |      },
1077 |      "output_type": "display_data"
1078 |     }
1079 |    ],
1080 |    "source": [
1081 |     "#Data imputation and plot the convergence curve\n",
1082 |     "theta = 0.1\n",
1083 |     "plt.subplots(figsize = (5,3))\n",
1084 |     "it,X_hat,MAE_List_admm,RMSE_List_admm,_ = LRTC_TSpN(speed_tensor,speed_tensor_lost_fiber0,theta =theta,p=0.7,beta=1e-5,incre=0.1,maxiter = 200,show_plot = True)\n",
1085 |     "\n",
1086 |     "plt.title('LRTC-TSpN')\n",
1087 |     "plt.grid(alpha=0.3)\n",
1088 |     "ax = plt.gca()\n",
1089 |     "ax.set_axisbelow(True)\n",
1090 |     "lines = ax.lines\n",
1091 |     "for line in lines:\n",
1092 |     "    line.set_linewidth(2.3)\n",
1093 |     "    line.set_color('royalblue')\n",
1094 |     "    line.set_alpha(0.8)"
1095 |    ]
1096 |   },
1097 |   {
1098 |    "cell_type": "code",
1099 |    "execution_count": 7,
1100 |    "metadata": {},
1101 |    "outputs": [
1102 |     {
1103 |      "name": "stdout",
1104 |      "output_type": "stream",
1105 |      "text": [
1106 |       "TSp_ADMM Iteration: \n",
1107 |       " Processing loop 64\n",
1108 |       " total iterations = 64 error=0.0008833959353604114\n",
1109 |       "LRTC-TSpN imputation MAE = 2.426\n",
1110 |       "LRTC-TSpN imputation RMSE = 3.476\n"
1111 |      ]
1112 |     },
1113 |     {
1114 |      "data": {
    |       "text/plain": [
    |        "<Figure size 360x216 with 1 Axes>  (inline SVG plot of the LRTC-TSpN convergence curve; raw SVG data omitted)"
1969 |       ]
1970 |      },
1971 |      "metadata": {
1972 |       "needs_background": "light"
1973 |      },
1974 |      "output_type": "display_data"
1975 |     }
1976 |    ],
1977 |    "source": [
1978 |     "#Data imputation and plot the convergence curve\n",
1979 |     "theta = 0.1\n",
1980 |     "plt.subplots(figsize = (5,3))\n",
1981 |     "it,X_hat,MAE_List_admm,RMSE_List_admm,_ = LRTC_TSpN(speed_tensor,speed_tensor_lost_fiber1,theta =theta,p=0.8,beta=1e-5,incre=0.1,maxiter = 200,show_plot = True)\n",
1982 |     "\n",
1983 |     "plt.title('LRTC-TSpN')\n",
1984 |     "plt.grid(alpha=0.3)\n",
1985 |     "ax = plt.gca()\n",
1986 |     "ax.set_axisbelow(True)\n",
1987 |     "lines = ax.lines\n",
1988 |     "for line in lines:\n",
1989 |     "    line.set_linewidth(2.3)\n",
1990 |     "    line.set_color('royalblue')\n",
1991 |     "    line.set_alpha(0.8)"
1992 |    ]
1993 |   },
1994 |   {
1995 |    "cell_type": "code",
1996 |    "execution_count": 8,
1997 |    "metadata": {},
1998 |    "outputs": [
1999 |     {
2000 |      "name": "stdout",
2001 |      "output_type": "stream",
2002 |      "text": [
2003 |       "TSp_ADMM Iteration: \n",
2004 |       " Processing loop 64\n",
2005 |       " total iterations = 64 error=0.00099624077043648\n",
2006 |       "LRTC-TSpN imputation MAE = 2.397\n",
2007 |       "LRTC-TSpN imputation RMSE = 3.443\n"
2008 |      ]
2009 |     },
2010 |     {
2011 |      "data": {
    |       "text/plain": [
    |        "<Figure size 360x216 with 1 Axes>  (inline SVG plot of the LRTC-TSpN convergence curve; raw SVG data omitted)"
2866 |       ]
2867 |      },
2868 |      "metadata": {
2869 |       "needs_background": "light"
2870 |      },
2871 |      "output_type": "display_data"
2872 |     }
2873 |    ],
2874 |    "source": [
2875 |     "#Data imputation and plot the convergence curve\n",
2876 |     "theta = 0.1\n",
2877 |     "plt.subplots(figsize = (5,3))\n",
2878 |     "it,X_hat,MAE_List_admm,RMSE_List_admm,_ = LRTC_TSpN(speed_tensor,speed_tensor_lost_fiber2,theta =theta,p=0.8,beta=1e-5,incre=0.1,maxiter = 200,show_plot = True)\n",
2879 |     "\n",
2880 |     "plt.title('LRTC-TSpN')\n",
2881 |     "plt.grid(alpha=0.3)\n",
2882 |     "ax = plt.gca()\n",
2883 |     "ax.set_axisbelow(True)\n",
2884 |     "lines = ax.lines\n",
2885 |     "for line in lines:\n",
2886 |     "    line.set_linewidth(2.3)\n",
2887 |     "    line.set_color('royalblue')\n",
2888 |     "    line.set_alpha(0.8)"
2889 |    ]
2890 |   },
2891 |   {
2892 |    "cell_type": "markdown",
2893 |    "metadata": {},
2894 |    "source": [
2895 |     "# License\n",
2896 |     "\n",
2897 |     "
\n", 2898 | "This work is released under the MIT license.\n", 2899 | "
" 2900 | ] 2901 | } 2902 | ], 2903 | "metadata": { 2904 | "kernelspec": { 2905 | "display_name": "Python 3", 2906 | "language": "python", 2907 | "name": "python3" 2908 | }, 2909 | "language_info": { 2910 | "codemirror_mode": { 2911 | "name": "ipython", 2912 | "version": 3 2913 | }, 2914 | "file_extension": ".py", 2915 | "mimetype": "text/x-python", 2916 | "name": "python", 2917 | "nbconvert_exporter": "python", 2918 | "pygments_lexer": "ipython3", 2919 | "version": "3.8.3" 2920 | } 2921 | }, 2922 | "nbformat": 4, 2923 | "nbformat_minor": 4 2924 | } 2925 | -------------------------------------------------------------------------------- /Examples/test.ipynb: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Figures/algorithm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Figures/algorithm1.png -------------------------------------------------------------------------------- /Figures/algorithm2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Figures/algorithm2.png -------------------------------------------------------------------------------- /Figures/missing patterns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Figures/missing patterns.png -------------------------------------------------------------------------------- /Figures/norm_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Figures/norm_compare.png -------------------------------------------------------------------------------- /Figures/objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongnie/tensorlib/1515a6307dea0c28c742d594c7eed013fb63cca1/Figures/objective.png -------------------------------------------------------------------------------- /Helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Tools for tensor operations 5 | 6 | @author: nietong 7 | """ 8 | 9 | import numpy as np 10 | from collections import deque 11 | import matplotlib.pyplot as plt 12 | import random as random 13 | 14 | 15 | def shiftdim(array, n=None): 16 | if n is not None: 17 | if n >= 0: 18 | axes = tuple(range(len(array.shape))) 19 | new_axes = deque(axes) 20 | new_axes.rotate(n) 21 | return np.moveaxis(array, axes, tuple(new_axes)) 22 | return np.expand_dims(array, axis=tuple(range(-n))) 23 | else: 24 | idx = 0 25 | for dim in array.shape: 26 | if dim == 1: 27 | idx += 1 28 | else: 29 | break 30 | axes = tuple(range(idx)) 31 | # Note that this returns a tuple of 2 results 32 | 33 | 34 | def Fold(X, dim, i): 35 | #Fold a matrix into a tensor in mode i, dim is a tuple of the targeted tensor. 36 | dim = np.roll(dim, -i) 37 | X = shiftdim(np.reshape(X, dim,order='F'), len(dim)-i) 38 | return X 39 | 40 | 41 | def Unfold( X, dim, i ): 42 | #Unfold a tensor into a tensor in mode i. 
43 | X_unfold = np.reshape(shiftdim(X,i), (dim[i],-1),order='F') 44 | return X_unfold 45 | 46 | 47 | def TensorFromMat(mat,dim): 48 | #Construct a 3D tensor from a matrix 49 | days_slice = [(start_i,start_i + dim[0]) for start_i in list(range(0,dim[0]*dim[2],dim[0]))] 50 | array_list = [] 51 | for day_slice in days_slice: 52 | start_i,end_i = day_slice[0],day_slice[1] 53 | array_slice = mat[start_i:end_i,:] 54 | array_list.append(array_slice) 55 | tensor3d = np.array(np.stack(array_list,axis = 0).astype('float64')) 56 | tensor3d = np.moveaxis(tensor3d,0,-1) 57 | 58 | print(tensor3d.shape) 59 | 60 | return tensor3d 61 | 62 | 63 | def Tensor2Mat(tensor): 64 | #convert a tensor into a matrix by flattening the 'day' mode to 'time interval'. 65 | #The shape of given tensor should be 'time interval * locations * days'. 66 | #Note that this operation is slightly different from Unfold operation 67 | for k in range(np.shape(tensor)[-1]): 68 | if k == 0: 69 | stacked = np.vstack(tensor[:,:,k]) 70 | else: 71 | stacked = np.vstack((stacked,tensor[:,:,k])) 72 | return stacked 73 | 74 | 75 | def compute_MAE(X_masked,X_true,X_hat): #Only calculate the errors on the masked and nonzero positions 76 | pos_test = np.where((X_true != 0) & (X_masked == 0)) 77 | MAE = np.sum(abs(X_true[pos_test]-X_hat[pos_test]))/X_true[pos_test].shape[0] 78 | 79 | return MAE 80 | 81 | 82 | def compute_RMSE(X_masked,X_true,X_hat): 83 | pos_test = np.where((X_true != 0) & (X_masked == 0)) 84 | RMSE = np.sqrt(((X_true[pos_test]-X_hat[pos_test])**2).sum()/X_true[pos_test].shape[0]) 85 | 86 | return RMSE 87 | 88 | 89 | def compute_MAPE(X_masked,X_true,X_hat): 90 | pos_test = np.where((X_true != 0) & (X_masked == 0)) 91 | MAPE = np.sum(np.abs(X_true[pos_test]-X_hat[pos_test]) / X_true[pos_test]) / X_true[pos_test].shape[0] 92 | 93 | return MAPE 94 | 95 | 96 | def get_missing_rate(X_lost): 97 | o_channel_num = (X_lost == 0).astype(int).sum().sum() 98 | matrix_miss_rate = o_channel_num/(X_lost.size) 99 | 100 | return matrix_miss_rate 101 | 102 | 103 | def generate_fiber_missing(tensor3d_true,lost_rate,mode:int): 104 | #three kinds of fiber-like missing cases, the original tensor structure is intervals*links*days. 
105 | #mode0:links*days combination 106 | #mode1:intervals*days combination 107 | #mode2:intervals*links combination 108 | n = tensor3d_true.shape 109 | nn = np.delete(n,mode) 110 | S = np.ones(nn) 111 | coord = [] 112 | for i in range(nn[0]): 113 | for j in range(nn[1]): 114 | coord.append((i,j)) 115 | mask = random.sample(coord,int(lost_rate*len(coord))) 116 | for coord in mask: 117 | S[coord[0],coord[1]] = 0 118 | fai = np.expand_dims(S,mode).repeat(n[mode],axis=mode) 119 | tensor3d_lost_fiber = fai*tensor3d_true 120 | tensor_miss_rate = get_missing_rate(tensor3d_lost_fiber) 121 | print(f'fiber-mode{mode} missing rate of tensor is:{100*tensor_miss_rate:.2f}%') 122 | 123 | return tensor3d_lost_fiber 124 | 125 | def generate_tensor_random_missing(tensor3d_true,lost_rate): 126 | tensor3d_lost = tensor3d_true.copy() 127 | coord = [] 128 | m,n,q = tensor3d_lost.shape 129 | for i in range(m): 130 | for j in range(n): 131 | for k in range(q): 132 | coord.append((i,j,k)) 133 | 134 | mask = random.sample(coord,int(lost_rate*len(coord))) 135 | for coord in mask: 136 | tensor3d_lost[coord[0]][coord[1]][coord[2]] = 0 137 | return tensor3d_lost 138 | 139 | -------------------------------------------------------------------------------- /Imputer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Truncated tensor Schatten p-norm based low-rank tensor completion, LRTC-TSpN 5 | 6 | @author: nietong 7 | """ 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from Helper import Fold,Unfold,compute_MAE,compute_RMSE,compute_MAPE 12 | 13 | #Optiminize Truncated Schatten p-norm via ADMM 14 | 15 | def truncation(unfoldj,theta): 16 | #calculate the truncation of each unfolding 17 | dim = np.array(unfoldj.shape) 18 | wj = np.zeros(min(dim),) 19 | r = np.int(np.ceil(theta * min(dim))) 20 | wj[r:] = 1 21 | 22 | return wj 23 | 24 | 25 | def GST(sigma,w,p,J=5): #J is the ineer iterations of GST, J=5 is supposed to be enough 26 | #Generalized soft-thresholding algorithm 27 | if w == 0: 28 | Sp = sigma 29 | else: 30 | dt = np.zeros(J+1) 31 | tau = (2*w*(1-p))**(1/(2-p)) + w*p*(2*w*(1-p))**((p-1)/(2-p)) 32 | if np.abs(sigma) <= tau: 33 | Sp = 0 34 | else: 35 | dt[0] = np.abs(sigma) 36 | for k in range(J): 37 | dt[k+1] = np.abs(sigma) - w*p*(dt[k])**(p-1) 38 | Sp = np.sign(sigma)*dt[k].item() 39 | 40 | return Sp 41 | 42 | 43 | def update_Mi(mat,alphai,beta,p,theta): 44 | #update M variable 45 | delta = [] 46 | u,d,v = np.linalg.svd(mat,full_matrices=False) 47 | wi = truncation(mat,theta) 48 | for j in range(len(d)): 49 | deltaj = GST(d[j],(alphai/beta)*wi[j],p) 50 | delta.append(deltaj) 51 | delta = np.diag(delta) 52 | Mi = u@delta@v 53 | 54 | return Mi 55 | 56 | 57 | def TSpN_ADMM(X_true,X_missing,Omega,alpha,beta,incre,maxIter,epsilon,p,theta): 58 | X = X_missing.copy() 59 | X[Omega==False] = np.mean(X_missing[Omega]) #Initialize with mean values 60 | errList = [] 61 | MAE_List = [] 62 | RMSE_List = [] 63 | dim = X_missing.shape 64 | M = np.zeros(np.insert(dim, 0, len(dim))) #M is a 4-th order tensor 65 | Q = np.zeros(np.insert(dim, 0, len(dim))) #Q is a 4-th order tensor 66 | print('TSp_ADMM Iteration: ') 67 | 68 | for k in range(maxIter): 69 | beta = beta * (1+incre) #Increase beta with given step 70 | print(f'\r Processing loop {k}',end = '',flush=True) 71 | 72 | #Update M variable 73 | for i in range(np.ndim(X_missing)): 74 | M[i] = 
Fold(update_Mi(Unfold(X+(1/beta)*Q[i],dim,i),alpha[i],beta,p,theta),dim,i)  #M is a fourth-order tensor
75 | 
76 |         Xlast = X.copy()
77 |         X = np.sum(beta*M-Q,axis=0)/(beta*(X_missing.ndim))  #Update X variable
78 |         X[Omega] = X_missing[Omega]  #Keep the observed entries fixed
79 | 
80 |         Q = Q + beta*(np.broadcast_to(X, np.insert(dim, 0, len(dim)))-M)  #Update Q variable
81 | 
82 |         errList_k = np.linalg.norm(X-Xlast)/np.linalg.norm(Xlast)
83 |         errList.append(errList_k)
84 |         MAE_List.append(compute_MAE(X_missing,X_true,X))
85 |         RMSE_List.append(compute_RMSE(X_missing,X_true,X))
86 | 
87 |         if errList_k < epsilon:
88 |             break
89 | 
90 |     print(f'\n total iterations = {k} error={errList[-1]}')
91 | 
92 |     return X,MAE_List,RMSE_List,errList,k
93 | 
94 | 
95 | def LRTC_TSpN(complete_tensor,observed_tensor,theta=0.1,alpha=np.array([1,1,1]),p=0.5,beta=1e-5,incre=0.05,maxiter = 200,show_plot = True):
96 |     X_true = complete_tensor.copy()
97 |     X_missing = observed_tensor.copy()
98 |     Omega = (X_missing != 0)
99 |     alpha = alpha.reshape(-1,1)
100 |     alpha = alpha / np.sum(alpha)
101 |     epsilon = 1e-3
102 |     X_hat,MAE_List,RMSE_List,errList,it = TSpN_ADMM(X_true,X_missing,Omega,alpha,beta,incre,maxiter,epsilon,p,theta)
103 |     MAPE = compute_MAPE(X_missing,X_true,X_hat)
104 |     print(f'LRTC-TSpN imputation MAE = {MAE_List[-1]:.3f}')
105 |     print(f'LRTC-TSpN imputation RMSE = {RMSE_List[-1]:.3f}')
106 |     print(f'LRTC-TSpN imputation MAPE = {MAPE:.3f}')
107 | 
108 |     if show_plot:
109 |         plt.plot(range(len(MAE_List)),MAE_List)
110 |         plt.xlabel('epoch')
111 |         plt.ylabel('MAE')
112 |         plt.title('Convergence curve of LRTC-TSpN')
113 | 
114 |     return it,X_hat,MAE_List,RMSE_List,errList
115 | 
116 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Tong Nie
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Low-rank tensor completion algorithm for spatiotemporal traffic data imputation
2 | --------------
3 | ![Python 3.8](https://img.shields.io/badge/Python-3.8-blue.svg)
4 | [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](https://opensource.org/licenses/MIT)
5 | 
6 | **L**ow-**r**ank **t**ensor **c**ompletion using **T**runcated tensor **S**chatten **p**-**N**orm, LRTC-TSpN.
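
A minimal usage sketch (the functions below come from `Imputer.py` and `Helper.py` in this repository; the data path and parameter values are only illustrative):

```python
import numpy as np
import Helper as helper
from Imputer import LRTC_TSpN

# Load a (time intervals x locations x days) traffic tensor
speed_tensor = np.load('Datasets/guangzhou_speed.npy')

# Mask 50% of the entries with a fiber mode-0 (temporal) missing pattern
speed_tensor_lost = helper.generate_fiber_missing(speed_tensor, lost_rate=0.5, mode=0)

# Impute with LRTC-TSpN (theta: truncation rate, p: Schatten-p exponent)
it, X_hat, MAE_list, RMSE_list, err_list = LRTC_TSpN(
    speed_tensor, speed_tensor_lost, theta=0.1, p=0.7,
    beta=1e-5, incre=0.1, maxiter=200, show_plot=False)
```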
7 | 
8 | 
9 | > This is the code repository for the paper 'Truncated tensor Schatten p-norm based approach for spatiotemporal traffic data
10 | imputation with complicated missing patterns', published in Transportation Research Part C: Emerging Technologies.
11 | 
12 | ## Overview
13 | This project provides examples of how to use LRTC-TSpN to achieve efficient and accurate missing data imputation for transportation time series. We aim at off-line imputation tasks with several realistic structural missing patterns. The missing data imputation problem is modelled as a low-rank tensor completion (low-rank tensor learning) problem: the objective is to obtain a fully recovered tensor by minimizing a predefined tensor rank function, given the observations. We define a new **truncated tensor Schatten p-norm** to substitute for the traditional tensor nuclear norm. We recommend Kolda and Bader’s review [Tensor Decompositions and Applications](https://epubs.siam.org/doi/abs/10.1137/07070111x) for the basics of tensor algebra.
14 | 
15 | ## Model description
16 | We organise the multiple input time series into a third-order tensor structure of (time intervals × locations × days). Traditional methods resort to the tensor nuclear norm (or the sum of nuclear norms) to substitute for the tensor rank; however, these convex surrogates are not powerful enough in practice. Therefore, we use the newly emerging Schatten p-norm and its truncated version to approximate the tensor rank and achieve more accurate traffic data imputation.
17 | 
18 | 
19 | <div align="center">
20 | <img src="Figures/norm_compare.png"/>
21 | </div>
22 | 
23 | > Schatten p-norm always serves as a better rank surrogate (closer to the true rank) than the nuclear norm, and we can exploit its nonconvexity to better approximate the tensor rank.
24 | 
25 | The objective function of Schatten p-norm minimization is formulated as:
26 | 
27 | <div align="center">
28 | <img src="Figures/objective.png"/>
29 | </div>
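
In our notation (a sketch written to match the code in `Imputer.py`: `theta` is the truncation rate, `p` the Schatten exponent, `alpha` the mode weights, and Ω the index set of observed entries), the truncated Schatten p-norm of an unfolding with singular values $\sigma_1\ge\sigma_2\ge\cdots$ ignores the largest $r=\lceil\theta\cdot\min(m,n)\rceil$ of them,

$$
\|\mathbf{X}_{(i)}\|_{S_p,\theta}^{p}=\sum_{j=r+1}^{\min(m,n)}\sigma_j^{p},
$$

and the completion model is, roughly,

$$
\min_{\mathcal{X}}\;\sum_{i=1}^{3}\alpha_i\,\big\|\mathbf{X}_{(i)}\big\|_{S_p,\theta}^{p}
\quad\text{s.t.}\quad \mathcal{P}_{\Omega}(\mathcal{X})=\mathcal{P}_{\Omega}(\mathcal{M}).
$$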

30 | 
31 | 
32 | This is a typical nonconvex optimization problem. Previous works on tensor completion usually rely on a singular value thresholding (SVT) algorithm, but the existing SVT cannot be applied to our problem directly. The main challenge is therefore to develop a new generalized SVT algorithm for this new definition of the norm.
33 | 
34 | ## Solving algorithm
35 | We solve this non-convex problem by using the Alternating Direction Method of Multipliers (ADMM) and Generalized Soft Thresholding (GST).
36 | 
37 | Generalized soft-thresholding algorithm:
38 | 
39 | <div align="center">
40 | <img src="Figures/algorithm1.png"/>
41 | </div>
42 | 
43 | ADMM framework:
44 | <div align="center">
45 | <img src="Figures/algorithm2.png"/>
46 | </div>

47 | Despite the nonconvexity, the ADMM framework still ensures the convergence of our model. With a proper updating scheme, our algorithm converges within a few iterations. More algorithmic details can be found in our paper. The preprint version is available at [arXiv](https://arxiv.org/abs/2205.09390), and the published version can be found at the [Elsevier publisher](https://doi.org/10.1016/j.trc.2022.103737).
48 | 
49 | 
50 | ## Spatial-temporal data missing patterns
51 | Besides the element-wise random missing case, we define three structured fiber mode-n missing scenarios, which are generated through the two-by-two combinations of tensor mode-n fibers. They can be described as follows:
52 | - **'Intervals' mode fiber-like missing (FM-0)** illustrates a temporal missing pattern, caused by adverse weather, breakdown of wireless connections, or apparatus maintenance;
53 | - **'Locations' mode fiber-like missing (FM-1)** denotes a spatial missing pattern, explained by a loss of electricity for successive sensors or a malfunction of the Internet Data Center;
54 | - **'Days' mode fiber-like missing (FM-2)** describes a mixed spatial-temporal missing situation in which specific sensors are offline (do not operate) at regular time intervals every day.
55 | 
56 | ## Datasets
57 | In this repository, we use two small traffic flow datasets to show how to run our model:
58 | - **Guangzhou-small**: Speed data with the first 50 locations and the first 15 days. The size is (144 × 50 × 15).
59 | - **Portland-small**: Volume data with the first 80 locations and the first 15 days. The size is (96 × 80 × 15).
60 | 
61 | We provide the two datasets in [../Datasets/](https://github.com/tongnie/tensorlib/tree/main/Datasets).
62 | The original links for the complete datasets are given below.
63 | 
64 | - [Guangzhou urban traffic speed data set](https://doi.org/10.5281/zenodo.1205228)
65 | - [Portland highway traffic data set](https://portal.its.pdx.edu/home)
66 | 
67 | ## Implementation
68 | The Python implementation of LRTC-TSpN is given in [../Imputer.py](https://github.com/tongnie/tensorlib/blob/main/Imputer.py). The core of the algorithm is the GST and ADMM iteration module. We organize this implementation in a tensor-only way to make it more efficient. Some utilities and basic tensor operations are provided in [../Helper.py](https://github.com/tongnie/tensorlib/blob/main/Helper.py).
69 | 
70 | 
71 | ## Toy Examples
72 | We give some examples written as Jupyter notebooks in [../Examples/](https://github.com/tongnie/tensorlib/blob/main/Examples).
73 | 
74 | ## References
75 | 
76 | > Please cite our paper if this repo helps your research.
77 | 
78 | #### Cited as:
79 | bibtex:
80 | 
81 | ```
82 | @article{nie2022truncated,
83 |   title={Truncated tensor Schatten p-norm based approach for spatiotemporal traffic data imputation with complicated missing patterns},
84 |   author={Nie, Tong and Qin, Guoyang and Sun, Jian},
85 |   journal={Transportation Research Part C: Emerging Technologies},
86 |   volume={141},
87 |   pages={103737},
88 |   year={2022},
89 |   publisher={Elsevier}
90 | }
91 | ```
92 | 
93 | 
94 | ### Our Publications
95 | --------------
96 | - Tong Nie, Guoyang Qin, Yunpeng Wang, and Jian Sun (2023). **Towards better traffic volume estimation: Tackling both underdetermined and non-equilibrium problems via a correlation-adaptive graph convolution network**. arXiv preprint arXiv:2303.05660.
[[Preprint](https://doi.org/10.48550/arXiv.2303.05660)] [[Code](https://github.com/tongnie/GNN4Flow)] 96 | 97 | - Tong Nie, Guoyang Qin, Yunpeng Wang, and Jian Sun (2023). **Correlating sparse sensing for large-scale traffic speed estimation: 98 | A Laplacian-enhanced low-rank tensor kriging approach**. Transportation Research Part C: Emerging Technologies, 152, 104190, [[Preprint](https://doi.org/10.48550/arXiv.2210.11780)] [[DOI](https://doi.org/10.1016/j.trc.2023.104190)] [[Code](https://github.com/tongnie/tensor4kriging)] 99 | 100 | - Tong Nie, Guoyang Qin, and Jian Sun (2022). **Truncated tensor Schatten p-norm based approach for spatiotemporal traffic data imputation with complicated missing patterns**. Transportation research part C: emerging technologies, 141, 103737, [[Preprint](https://doi.org/10.48550/arXiv.2205.09390)] [[DOI](https://doi.org/10.1016/j.trc.2022.103737)] [[Code](https://github.com/tongnie/tensorlib)] 101 | 102 | 103 | License 104 | -------------- 105 | 106 | This work is released under the MIT license. 107 | --------------------------------------------------------------------------------