├── .gitignore ├── LICENSE ├── README.md ├── als.ipynb ├── als.py ├── train_result.png └── week18_deeplearning_cv └── ch18_8_image_search_part_3.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Hyeong Jun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ALS_implementation 2 | Implementation of the ALS algorithm from "Collaborative Filtering for Implicit Feedback Datasets" [[paper]](http://yifanhu.net/PUB/cf.pdf) 3 | 4 | ## Contents 5 | - als.ipynb : notebook which contains step-by-step comments and visualization. 6 | - als.py : simplified python implementation 7 | 8 | ## Parameters to fix 9 | - r_lambda = 40 10 | - nf = 200 11 | - alpha = 40 12 | 13 | ## Train Result 14 | 15 | 16 | ## More detail 17 | - als.ipynb contains details of each algorithm step. 18 | - More information can be found in my blog post (in Korean). [[blog]](https://yeomko.tistory.com/8) 19 | 20 | -------------------------------------------------------------------------------- /als.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## ALS Implementation\n", 8 | "- This notebook is an implementation of the ALS algorithm from \"Collaborative Filtering for Implicit Feedback Datasets\"\n", 9 | "\n", 10 | "### Initialize parameters\n", 11 | "- r_lambda: regularization parameter \n", 12 | "- alpha: confidence level \n", 13 | "- nf: dimension of latent vector of each user and item \n", 14 | "- initialized values (40, 200, 40) are the best parameters from the paper" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 98, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "r_lambda = 40\n", 24 | "nf = 200\n", 25 | "alpha = 40" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### Initialize original rating matrix data\n", 33 | "- make sample (10 x 11) matrix\n", 34 | "- 10 : num of users\n", 35 | "- 11 : num of items" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 
99, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "(10, 11)\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "import numpy as np\n", 53 | "\n", 54 | "\n", 55 | "# sample rating matrix\n", 56 | "R = np.array([[0, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0],\n", 57 | " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],\n", 58 | " [0, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0],\n", 59 | " [0, 3, 4, 0, 3, 0, 0, 2, 2, 0, 0],\n", 60 | " [0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0],\n", 61 | " [0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],\n", 62 | " [0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5],\n", 63 | " [0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4],\n", 64 | " [0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],\n", 65 | " [0, 0, 0, 3, 0, 0, 0, 0, 4, 5, 0]])\n", 66 | "print(R.shape)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### Initialize user and item latent factor matrix\n", 74 | "- nu: num of users (10)\n", 75 | "- ni: num of items (11)\n", 76 | "- nf: dimension of latent vector" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 100, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "[[0.00457303 0.00424385 0.00616191 ... 0.0055276 0.00441167 0.0032725 ]\n", 89 | " [0.00197897 0.00110453 0.00558329 ... 0.00915551 0.00749223 0.006756 ]\n", 90 | " [0.00549164 0.00358061 0.00012187 ... 0.00723139 0.00441681 0.00632714]\n", 91 | " ...\n", 92 | " [0.00380493 0.005744 0.0024226 ... 0.00737873 0.00131759 0.00736437]\n", 93 | " [0.00115033 0.00066236 0.00046558 ... 0.00804704 0.00986983 0.00343679]\n", 94 | " [0.00504484 0.00616501 0.00130519 ... 
0.00236188 0.00136329 0.00666413]]\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "nu = R.shape[0]\n", 100 | "ni = R.shape[1]\n", 101 | "\n", 102 | "# initialize X and Y with very small values\n", 103 | "X = np.random.rand(nu, nf) * 0.01\n", 104 | "Y = np.random.rand(ni, nf) * 0.01\n", 105 | "\n", 106 | "print(X)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### Initialize Binary Rating Matrix P\n", 114 | "- Convert original rating matrix R into P\n", 115 | "- Pui = 1 if Rui > 0\n", 116 | "- Pui = 0 if Rui = 0" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 101, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "[[0 0 0 1 1 0 0 0 0 0 0]\n", 129 | " [0 0 0 0 0 0 0 0 0 0 1]\n", 130 | " [0 0 0 0 0 0 0 1 0 1 0]\n", 131 | " [0 1 1 0 1 0 0 1 1 0 0]\n", 132 | " [0 1 1 0 0 0 0 0 0 0 0]\n", 133 | " [0 0 0 0 0 0 1 0 0 1 0]\n", 134 | " [0 0 1 0 0 0 0 0 0 0 1]\n", 135 | " [0 0 0 0 0 1 0 0 0 0 1]\n", 136 | " [0 0 0 0 0 0 1 0 0 1 0]\n", 137 | " [0 0 0 1 0 0 0 0 1 1 0]]\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "P = np.copy(R)\n", 143 | "P[P > 0] = 1\n", 144 | "print(P)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "### Initialize Confidence Matrix C\n", 152 | "- Cui = 1 + alpha * Rui\n", 153 | "- Cui means confidence level of certain rating data" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 102, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "[[ 1 1 1 161 161 1 1 1 1 1 1]\n", 166 | " [ 1 1 1 1 1 1 1 1 1 1 41]\n", 167 | " [ 1 1 1 1 1 1 1 41 1 161 1]\n", 168 | " [ 1 121 161 1 121 1 1 81 81 1 1]\n", 169 | " [ 1 201 201 1 1 1 1 1 1 1 1]\n", 170 | " [ 1 1 1 1 1 1 201 1 1 201 1]\n", 171 | " [ 1 1 161 1 1 1 1 1 1 1 201]\n", 172 | " [ 1 1 1 1 1 161 1 
1 1 1 161]\n", 173 | " [ 1 1 1 1 1 1 201 1 1 201 1]\n", 174 | " [ 1 1 1 121 1 1 1 1 161 201 1]]\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "C = 1 + alpha * R\n", 180 | "print(C)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Set up loss function\n", 188 | "- C: confidence matrix\n", 189 | "- P: binary rating matrix\n", 190 | "- X: user latent matrix\n", 191 | "- Y: item latent matrix\n", 192 | "- r_lambda: regularization lambda\n", 193 | "- xTy: predict matrix \n", 194 | " \n", 195 | "- Total_loss = (confidence_level * predict loss) + regularization loss" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 103, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "def loss_function(C, P, xTy, X, Y, r_lambda):\n", 205 | " predict_error = np.square(P - xTy)\n", 206 | " confidence_error = np.sum(C * predict_error)\n", 207 | " regularization = r_lambda * (np.sum(np.square(X)) + np.sum(np.square(Y)))\n", 208 | " total_loss = confidence_error + regularization\n", 209 | " return np.sum(predict_error), confidence_error, regularization, total_loss" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "### Optimization Function for user and item\n", 217 | "- X[u] = (yT Cu y + lambda*I)^-1 yT Cu p(u)\n", 218 | "- Y[i] = (xT Ci x + lambda*I)^-1 xT Ci p(i)\n", 219 | "- the two formulas are the same once X is swapped with Y and u with i" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 104, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "def optimize_user(X, Y, C, P, nu, nf, r_lambda):\n", 229 | " yT = np.transpose(Y)\n", 230 | " for u in range(nu):\n", 231 | " Cu = np.diag(C[u])\n", 232 | " yT_Cu_y = np.matmul(np.matmul(yT, Cu), Y)\n", 233 | " lI = np.dot(r_lambda, np.identity(nf))\n", 234 | " yT_Cu_pu = np.matmul(np.matmul(yT, Cu), P[u])\n", 235 | " X[u] = np.linalg.solve(yT_Cu_y + lI, yT_Cu_pu)\n", 
236 | "\n", 237 | "def optimize_item(X, Y, C, P, ni, nf, r_lambda):\n", 238 | " xT = np.transpose(X)\n", 239 | " for i in range(ni):\n", 240 | " Ci = np.diag(C[:, i])\n", 241 | " xT_Ci_x = np.matmul(np.matmul(xT, Ci), X)\n", 242 | " lI = np.dot(r_lambda, np.identity(nf))\n", 243 | " xT_Ci_pi = np.matmul(np.matmul(xT, Ci), P[:, i])\n", 244 | " Y[i] = np.linalg.solve(xT_Ci_x + lI, xT_Ci_pi)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "### Train\n", 252 | "- usually ALS algorithm repeat train steps for 10 ~ 15 times" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 105, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "----------------step 0----------------\n", 265 | "predict error: 22.767452\n", 266 | "confidence error: 3466.937317\n", 267 | "regularization: 5.636577\n", 268 | "total loss: 3472.573894\n", 269 | "----------------step 1----------------\n", 270 | "predict error: 31.023025\n", 271 | "confidence error: 296.327809\n", 272 | "regularization: 641.240713\n", 273 | "total loss: 937.568522\n", 274 | "----------------step 2----------------\n", 275 | "predict error: 34.018641\n", 276 | "confidence error: 139.069458\n", 277 | "regularization: 651.405197\n", 278 | "total loss: 790.474655\n", 279 | "----------------step 3----------------\n", 280 | "predict error: 32.074433\n", 281 | "confidence error: 120.069403\n", 282 | "regularization: 650.898109\n", 283 | "total loss: 770.967512\n", 284 | "----------------step 4----------------\n", 285 | "predict error: 29.512463\n", 286 | "confidence error: 109.299027\n", 287 | "regularization: 653.525673\n", 288 | "total loss: 762.824700\n", 289 | "----------------step 5----------------\n", 290 | "predict error: 27.060229\n", 291 | "confidence error: 102.230607\n", 292 | "regularization: 656.474772\n", 293 | "total loss: 758.705379\n", 294 | "----------------step 
6----------------\n", 295 | "predict error: 25.016805\n", 296 | "confidence error: 97.437625\n", 297 | "regularization: 658.928086\n", 298 | "total loss: 756.365710\n", 299 | "----------------step 7----------------\n", 300 | "predict error: 23.442905\n", 301 | "confidence error: 94.190686\n", 302 | "regularization: 660.770057\n", 303 | "total loss: 754.960743\n", 304 | "----------------step 8----------------\n", 305 | "predict error: 22.284994\n", 306 | "confidence error: 92.007313\n", 307 | "regularization: 662.092210\n", 308 | "total loss: 754.099523\n", 309 | "----------------step 9----------------\n", 310 | "predict error: 21.454691\n", 311 | "confidence error: 90.546515\n", 312 | "regularization: 663.021474\n", 313 | "total loss: 753.567989\n", 314 | "----------------step 10----------------\n", 315 | "predict error: 20.867422\n", 316 | "confidence error: 89.570690\n", 317 | "regularization: 663.667698\n", 318 | "total loss: 753.238388\n", 319 | "----------------step 11----------------\n", 320 | "predict error: 20.455195\n", 321 | "confidence error: 88.918571\n", 322 | "regularization: 664.113910\n", 323 | "total loss: 753.032480\n", 324 | "----------------step 12----------------\n", 325 | "predict error: 20.167422\n", 326 | "confidence error: 88.482663\n", 327 | "regularization: 664.419671\n", 328 | "total loss: 752.902334\n", 329 | "----------------step 13----------------\n", 330 | "predict error: 19.967775\n", 331 | "confidence error: 88.191857\n", 332 | "regularization: 664.626891\n", 333 | "total loss: 752.818748\n", 334 | "----------------step 14----------------\n", 335 | "predict error: 19.830538\n", 336 | "confidence error: 87.999084\n", 337 | "regularization: 664.764920\n", 338 | "total loss: 752.764004\n", 339 | "final predict\n", 340 | "[array([[ 0. , 0.7213401 , 0.68235019, 0.8576162 , 0.83827723,\n", 341 | " 0.09239648, 0.58977576, 0.68720774, 0.81166324, 0.81802531,\n", 342 | " 0.30765598],\n", 343 | " [ 0. 
, 0.3361484 , 0.41909192, 0.10608421, 0.15560257,\n", 344 | " 0.39303202, -0.02675532, 0.14828855, 0.13284078, 0.03463808,\n", 345 | " 0.4837343 ],\n", 346 | " [ 0. , 0.46243844, 0.42946834, 0.63028311, 0.58534636,\n", 347 | " 0.03177001, 0.7517364 , 0.56616991, 0.62686475, 0.86784737,\n", 348 | " 0.16049371],\n", 349 | " [ 0. , 0.92285236, 0.94187576, 0.89650102, 0.90838244,\n", 350 | " 0.3563045 , 0.60756749, 0.77780712, 0.87870053, 0.86042895,\n", 351 | " 0.62460443],\n", 352 | " [ 0. , 0.84981068, 0.92676941, 0.62501363, 0.68101786,\n", 353 | " 0.53344484, 0.26330552, 0.57464173, 0.62942323, 0.4788263 ,\n", 354 | " 0.77920821],\n", 355 | " [ 0. , 0.29236108, 0.24197338, 0.54119403, 0.46961086,\n", 356 | " -0.08184888, 0.88182444, 0.51154855, 0.55064509, 0.93896817,\n", 357 | " -0.00732 ],\n", 358 | " [ 0. , 0.74376899, 0.8749514 , 0.37604512, 0.45811821,\n", 359 | " 0.68817876, 0.07352049, 0.40765911, 0.41042402, 0.2330001 ,\n", 360 | " 0.89526845],\n", 361 | " [ 0. , 0.56340473, 0.73794323, 0.08895129, 0.18834466,\n", 362 | " 0.78203442, -0.1124117 , 0.20273044, 0.15151988, -0.02678813,\n", 363 | " 0.92869513],\n", 364 | " [ 0. , 0.29236108, 0.24197338, 0.54119403, 0.46961086,\n", 365 | " -0.08184888, 0.88182444, 0.51154855, 0.55064509, 0.93896817,\n", 366 | " -0.00732 ],\n", 367 | " [ 0. 
, 0.68684109, 0.6415058 , 0.86952565, 0.83246496,\n", 368 | " 0.05966073, 0.75878686, 0.72385157, 0.83634861, 0.96332154,\n", 369 | " 0.26036486]])]\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "predict_errors = []\n", 375 | "confidence_errors = []\n", 376 | "regularization_list = []\n", 377 | "total_losses = []\n", 378 | "\n", 379 | "for i in range(15):\n", 380 | " if i!=0: \n", 381 | " optimize_user(X, Y, C, P, nu, nf, r_lambda)\n", 382 | " optimize_item(X, Y, C, P, ni, nf, r_lambda)\n", 383 | " predict = np.matmul(X, np.transpose(Y))\n", 384 | " predict_error, confidence_error, regularization, total_loss = loss_function(C, P, predict, X, Y, r_lambda)\n", 385 | " \n", 386 | " predict_errors.append(predict_error)\n", 387 | " confidence_errors.append(confidence_error)\n", 388 | " regularization_list.append(regularization)\n", 389 | " total_losses.append(total_loss)\n", 390 | " \n", 391 | " print('----------------step %d----------------' % i)\n", 392 | " print(\"predict error: %f\" % predict_error)\n", 393 | " print(\"confidence error: %f\" % confidence_error)\n", 394 | " print(\"regularization: %f\" % regularization)\n", 395 | " print(\"total loss: %f\" % total_loss)\n", 396 | " \n", 397 | "predict = np.matmul(X, np.transpose(Y))\n", 398 | "print('final predict')\n", 399 | "print([predict])" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 106, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "
" 411 | ] 412 | }, 413 | "metadata": {}, 414 | "output_type": "display_data" 415 | }, 416 | { 417 | "data": { 418 | "image/png": "\n", 419 | "text/plain": [ 420 | "
" 421 | ] 422 | }, 423 | "metadata": { 424 | "needs_background": "light" 425 | }, 426 | "output_type": "display_data" 427 | } 428 | ], 429 | "source": [ 430 | "from matplotlib import pyplot as plt\n", 431 | "%matplotlib inline\n", 432 | "\n", 433 | "plt.subplots_adjust(wspace=100.0, hspace=20.0)\n", 434 | "fig = plt.figure()\n", 435 | "fig.set_figheight(10)\n", 436 | "fig.set_figwidth(10)\n", 437 | "predict_error_line = fig.add_subplot(2, 2, 1)\n", 438 | "confidence_error_line = fig.add_subplot(2, 2, 2)\n", 439 | "regularization_error_line = fig.add_subplot(2, 2, 3)\n", 440 | "total_loss_line = fig.add_subplot(2, 2, 4)\n", 441 | "\n", 442 | "predict_error_line.set_title(\"Predict Error\") \n", 443 | "predict_error_line.plot(predict_errors)\n", 444 | "\n", 445 | "confidence_error_line.set_title(\"Confidence Error\")\n", 446 | "confidence_error_line.plot(confidence_errors)\n", 447 | "\n", 448 | "regularization_error_line.set_title(\"Regularization\")\n", 449 | "regularization_error_line.plot(regularization_list)\n", 450 | "\n", 451 | "total_loss_line.set_title(\"Total Loss\")\n", 452 | "total_loss_line.plot(total_losses)\n", 453 | "plt.show()" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [] 462 | } 463 | ], 464 | "metadata": { 465 | "kernelspec": { 466 | "display_name": "Python 3", 467 | "language": "python", 468 | "name": "python3" 469 | }, 470 | "language_info": { 471 | "codemirror_mode": { 472 | "name": "ipython", 473 | "version": 3 474 | }, 475 | "file_extension": ".py", 476 | "mimetype": "text/x-python", 477 | "name": "python", 478 | "nbconvert_exporter": "python", 479 | "pygments_lexer": "ipython3", 480 | "version": "3.7.1" 481 | } 482 | }, 483 | "nbformat": 4, 484 | "nbformat_minor": 2 485 | } 486 | -------------------------------------------------------------------------------- /als.py: 
"""Alternating Least Squares (ALS) for implicit-feedback collaborative filtering.

Simplified implementation of the algorithm from Hu, Koren and Volinsky,
"Collaborative Filtering for Implicit Feedback Datasets".
"""
import numpy as np

try:
    from matplotlib import pyplot as plt
except ImportError:  # plotting is optional; training itself only needs numpy
    plt = None


# Sample (10 users x 11 items) rating matrix used when no data is supplied.
SAMPLE_RATINGS = np.array([[0, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                           [0, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0],
                           [0, 3, 4, 0, 3, 0, 0, 2, 2, 0, 0],
                           [0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],
                           [0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5],
                           [0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4],
                           [0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],
                           [0, 0, 0, 3, 0, 0, 0, 0, 4, 5, 0]])


def loss_function(C, P, xTy, X, Y, r_lambda):
    """Compute the components of the ALS objective.

    Args:
        C: confidence matrix, shape (nu, ni).
        P: binary preference matrix, shape (nu, ni).
        xTy: predicted preference matrix ``X @ Y.T``, shape (nu, ni).
        X: user latent factors, shape (nu, nf).
        Y: item latent factors, shape (ni, nf).
        r_lambda: regularization strength.

    Returns:
        Tuple ``(predict_error, confidence_error, regularization, total_loss)``
        where ``predict_error`` is the unweighted sum of squared errors,
        ``confidence_error`` is the same sum weighted by ``C``, and
        ``total_loss = confidence_error + regularization``.
    """
    predict_error = np.square(P - xTy)
    confidence_error = np.sum(C * predict_error)
    regularization = r_lambda * (np.sum(np.square(X)) + np.sum(np.square(Y)))
    total_loss = confidence_error + regularization
    return np.sum(predict_error), confidence_error, regularization, total_loss


def optimize_user(X, Y, C, P, nu, nf, r_lambda):
    """Update every user vector in place with the item factors held fixed.

    For each user ``u`` this solves
    ``(Y^T C^u Y + lambda * I) x_u = Y^T C^u p(u)``.

    Args:
        X: user latent factors, shape (nu, nf); rows are overwritten.
        Y: item latent factors, shape (ni, nf).
        C: confidence matrix, shape (nu, ni).
        P: binary preference matrix, shape (nu, ni).
        nu: number of users to update.
        nf: number of latent factors.
        r_lambda: regularization strength.

    Returns:
        None. ``X`` is modified in place.
    """
    yT = np.transpose(Y)
    # The ridge term is identical for every user, so build it once.
    lI = r_lambda * np.identity(nf)
    for u in range(nu):
        # Scaling column i of yT by C[u, i] equals yT @ diag(C[u]) without
        # materialising the dense (ni x ni) diagonal matrix.
        yT_Cu = yT * C[u]
        yT_Cu_y = np.matmul(yT_Cu, Y)
        yT_Cu_pu = np.matmul(yT_Cu, P[u])
        X[u] = np.linalg.solve(yT_Cu_y + lI, yT_Cu_pu)


def optimize_item(X, Y, C, P, ni, nf, r_lambda):
    """Update every item vector in place with the user factors held fixed.

    For each item ``i`` this solves
    ``(X^T C^i X + lambda * I) y_i = X^T C^i p(i)``.

    Args:
        X: user latent factors, shape (nu, nf).
        Y: item latent factors, shape (ni, nf); rows are overwritten.
        C: confidence matrix, shape (nu, ni).
        P: binary preference matrix, shape (nu, ni).
        ni: number of items to update.
        nf: number of latent factors.
        r_lambda: regularization strength.

    Returns:
        None. ``Y`` is modified in place.
    """
    xT = np.transpose(X)
    # The ridge term is identical for every item, so build it once.
    lI = r_lambda * np.identity(nf)
    for i in range(ni):
        # Same trick as in optimize_user: xT @ diag(C[:, i]) via broadcasting.
        xT_Ci = xT * C[:, i]
        xT_Ci_x = np.matmul(xT_Ci, X)
        xT_Ci_pi = np.matmul(xT_Ci, P[:, i])
        Y[i] = np.linalg.solve(xT_Ci_x + lI, xT_Ci_pi)


def plot_losses(predict_errors, confidence_errors, regularization_list, total_losses):
    """Plot the four per-step loss curves in a 2 x 2 grid and show the figure.

    Args:
        predict_errors: unweighted squared error per step.
        confidence_errors: confidence-weighted squared error per step.
        regularization_list: regularization term per step.
        total_losses: total objective value per step.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for plot_losses()")

    # Create the figure first and space *that* figure. Calling
    # plt.subplots_adjust() before plt.figure() would act on an implicit,
    # empty figure and leave the real one untouched.
    fig = plt.figure(figsize=(10, 10))
    fig.subplots_adjust(wspace=0.3, hspace=0.3)

    panels = (
        ("Predict Error", predict_errors),
        ("Confidence Error", confidence_errors),
        ("Regularization", regularization_list),
        ("Total Loss", total_losses),
    )
    for position, (title, values) in enumerate(panels, start=1):
        axis = fig.add_subplot(2, 2, position)
        axis.set_title(title)
        axis.plot(values)
    plt.show()


def train(R=None, nf=200, alpha=40, r_lambda=40, n_steps=15, seed=None):
    """Fit ALS factors to a rating matrix and log the loss at every step.

    Step 0 reports the loss of the random initialisation; every later step
    runs one user update followed by one item update before measuring.

    Args:
        R: rating matrix of shape (nu, ni); defaults to ``SAMPLE_RATINGS``.
        nf: number of latent factors.
        alpha: confidence scaling, ``C = 1 + alpha * R``.
        r_lambda: regularization strength.
        n_steps: number of logged steps (including the initial step 0).
        seed: optional seed for reproducible initial factors. When ``None``
            the global numpy random state is used.

    Returns:
        Tuple ``(predict_errors, confidence_errors, regularization_list,
        total_losses)``, each a list with one entry per step.

    Raises:
        ValueError: if ``R`` is not two-dimensional.
    """
    R = SAMPLE_RATINGS if R is None else np.asarray(R)
    if R.ndim != 2:
        raise ValueError("R must be a 2-D (users x items) matrix")
    nu, ni = R.shape

    # initialize X and Y with very small values
    rng = np.random if seed is None else np.random.RandomState(seed)
    X = rng.rand(nu, nf) * 0.01
    Y = rng.rand(ni, nf) * 0.01

    # Binary preference: 1 where the user interacted with the item.
    P = np.copy(R)
    P[P > 0] = 1
    # Confidence grows linearly with the observed rating.
    C = 1 + alpha * R

    predict_errors = []
    confidence_errors = []
    regularization_list = []
    total_losses = []

    for i in range(n_steps):
        if i != 0:
            optimize_user(X, Y, C, P, nu, nf, r_lambda)
            optimize_item(X, Y, C, P, ni, nf, r_lambda)
        predict = np.matmul(X, np.transpose(Y))
        predict_error, confidence_error, regularization, total_loss = loss_function(
            C, P, predict, X, Y, r_lambda)

        predict_errors.append(predict_error)
        confidence_errors.append(confidence_error)
        regularization_list.append(regularization)
        total_losses.append(total_loss)

        print('----------------step %d----------------' % i)
        print("predict error: %f" % predict_error)
        print("confidence error: %f" % confidence_error)
        print("regularization: %f" % regularization)
        print("total loss: %f" % total_loss)

    predict = np.matmul(X, np.transpose(Y))
    print('final predict')
    print([predict])

    return predict_errors, confidence_errors, regularization_list, total_losses


if __name__ == '__main__':
    predict_errors, confidence_errors, regularization_list, total_losses = train(
        SAMPLE_RATINGS, nf=200, alpha=40, r_lambda=40)
    plot_losses(predict_errors, confidence_errors, regularization_list, total_losses)