├── machine-learning-with-chinese-food ├── data │ └── recipe.txt ├── machine-learning-with-chinese-food.ipynb └── readme.md ├── nba-allstar-prediction ├── data.csv ├── how-to-find-nba-all-star.ipynb └── label.csv ├── nba-prediction ├── data │ └── all.txt ├── nba_score_prediction.ipynb └── readme.md └── readme.md /machine-learning-with-chinese-food/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lijin-THU/play-with-machine-learning/1768ff47eddceaf6975eb1484fab1245598f30e2/machine-learning-with-chinese-food/readme.md -------------------------------------------------------------------------------- /nba-prediction/nba_score_prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false, 8 | "deletable": true, 9 | "editable": true 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stderr", 14 | "output_type": "stream", 15 | "text": [ 16 | "Using TensorFlow backend.\n" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import keras\n", 22 | "import numpy as np\n", 23 | "from keras.models import Sequential, Model\n", 24 | "from keras.layers import Input, Embedding, Flatten, Reshape, Dense, Dropout, dot, concatenate\n", 25 | "from keras.preprocessing.sequence import pad_sequences " 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "处理一下tensorflow的问题:" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": { 39 | "collapsed": false, 40 | "deletable": true, 41 | "editable": true 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "import keras.backend as K\n", 46 | "\n", 47 | "if K.backend() == \"tensorflow\":\n", 48 | " config = K.tf.ConfigProto()\n", 49 | " config.gpu_options.allow_growth = True  # grab GPU memory on demand instead of all at once\n", 50 | " session = K.tf.Session(config=config)\n", 51 | " K.set_session(session)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | 
"metadata": {}, 57 | "source": [ 58 | "读数据:" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": { 65 | "collapsed": false, 66 | "deletable": true, 67 | "editable": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "with open(\"data/all.txt\") as f:\n", 72 | " all_data = [line.strip().split(\";\") for line in f]" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "metadata": { 79 | "collapsed": true, 80 | "deletable": true, 81 | "editable": true 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "data_x_team_id_raw = all_data[0::9]  # each game is a 9-line record; field k is line k of the record\n", 86 | "data_y_raw = all_data[1::9]\n", 87 | "data_x_team_abbr_raw = all_data[2::9]\n", 88 | "data_x_home_min_raw = all_data[3::9]\n", 89 | "data_x_home_id_raw = all_data[4::9]\n", 90 | "data_x_home_name_raw = all_data[5::9]\n", 91 | "data_x_visitor_min_raw = all_data[6::9]\n", 92 | "data_x_visitor_id_raw = all_data[7::9]\n", 93 | "data_x_visitor_name_raw = all_data[8::9]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": { 100 | "collapsed": true, 101 | "deletable": true, 102 | "editable": true 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "def flatten(x):  # yield the items of a sequence of sequences, in order\n", 107 | " for seq in x:\n", 108 | " for s in seq:\n", 109 | " yield s" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "把球队名称和ID对应:" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "metadata": { 123 | "collapsed": true, 124 | "deletable": true, 125 | "editable": true 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "id2team = dict(zip(flatten(data_x_team_id_raw), flatten(data_x_team_abbr_raw)))\n", 130 | "id2player = dict(zip(flatten(data_x_home_id_raw + data_x_visitor_id_raw), \n", 131 | " flatten(data_x_home_name_raw + data_x_visitor_name_raw)))\n", 132 | "team2id = dict(zip(flatten(data_x_team_abbr_raw), flatten(data_x_team_id_raw)))\n", 133 | "player2id = 
dict(zip(flatten(data_x_home_name_raw + data_x_visitor_name_raw),\n", 134 | " flatten(data_x_home_id_raw + data_x_visitor_id_raw)))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": { 141 | "collapsed": false, 142 | "deletable": true, 143 | "editable": true 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "total_teams 30\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "print \"total_teams\", len(id2team)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 8, 161 | "metadata": { 162 | "collapsed": false, 163 | "deletable": true, 164 | "editable": true 165 | }, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "total_players 898\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "print \"total_players\", len(id2player)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "将球队id和队员id序列化方便Embedding:" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 9, 189 | "metadata": { 190 | "collapsed": true, 191 | "deletable": true, 192 | "editable": true 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "tid2index = {tid: idx for idx, tid in enumerate(id2team)}\n", 197 | "index2tid = {idx: tid for idx, tid in enumerate(id2team)}\n", 198 | "\n", 199 | "pid2index = {pid: idx+1 for idx, pid in enumerate(id2player)}  # start at 1: index 0 is the pad_sequences padding value\n", 200 | "index2pid = {idx+1: pid for idx, pid in enumerate(id2player)}" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "将时间转化为秒:" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 10, 213 | "metadata": { 214 | "collapsed": true, 215 | "deletable": true, 216 | "editable": true 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "def str2time(ms):  # 'MM:SS' -> total seconds\n", 221 | " m, s = ms.split(\":\")\n", 222 | " return int(m) * 60 + 
int(s)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "读取数据:" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 11, 235 | "metadata": { 236 | "collapsed": false, 237 | "deletable": true, 238 | "editable": true 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "data_x_team_id = np.array([[tid2index[tid] for tid in x] for x in data_x_team_id_raw])\n", 243 | "data_x_home_id = pad_sequences([[pid2index[pid] for pid in x] for x in data_x_home_id_raw], padding=\"post\", maxlen=13)  # NOTE(review): rosters longer than 13 are truncated from the front (truncating='pre' default)\n", 244 | "data_x_visitor_id = pad_sequences([[pid2index[pid] for pid in x] for x in data_x_visitor_id_raw], padding=\"post\", maxlen=13)\n", 245 | "data_x_home_min = pad_sequences([[str2time(ms) for ms in x] for x in data_x_home_min_raw], padding=\"post\", maxlen=13)\n", 246 | "data_x_visitor_min = pad_sequences([[str2time(ms) for ms in x] for x in data_x_visitor_min_raw], padding=\"post\", maxlen=13)\n", 247 | "\n", 248 | "data_x_home_min = 5 * data_x_home_min.astype(K.floatx()) / data_x_home_min.sum(axis=-1)[:,None]  # rescale each lineup's playing time to sum to 5\n", 249 | "data_x_visitor_min = 5 * data_x_visitor_min.astype(K.floatx()) / data_x_visitor_min.sum(axis=-1)[:,None]\n", 250 | "\n", 251 | "data_y = np.array(data_y_raw, dtype=int)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "构造比分预测模型:" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 12, 264 | "metadata": { 265 | "collapsed": false, 266 | "deletable": true, 267 | "editable": true 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "x_t = Input((2,), name=\"team_id\") \n", 272 | "\n", 273 | "x_h_id = Input((13,), name=\"home_player_id\")\n", 274 | "x_h_min = Input((13,1), name=\"home_player_time\")\n", 275 | "\n", 276 | "x_v_id = Input((13,), name=\"visitor_player_id\")\n", 277 | "x_v_min = Input((13,1), name=\"visitor_player_time\")\n", 278 | "\n", 279 | "emb_dim = 256\n", 280 | 
"\n", 281 | "team_emb = Sequential(name=\"team_emb\")\n", 282 | "team_emb.add(Embedding(input_dim=len(tid2index), output_dim=emb_dim, input_length=2))  # one row per team index\n", 283 | "team_emb.add(Flatten())\n", 284 | "\n", 285 | "player_emb = Sequential(name=\"player_emb\")\n", 286 | "player_emb.add(Embedding(input_dim=len(id2player)+1, output_dim=emb_dim, input_length=13))\n", 287 | "\n", 288 | "feat_t = team_emb(x_t)\n", 289 | "\n", 290 | "feat_h_id = player_emb(x_h_id)\n", 291 | "feat_h = dot([x_h_min, feat_h_id], axes=1, name=\"home_player_sum\")  # playing-time-weighted sum of the 13 player embeddings\n", 292 | "feat_h = Reshape((emb_dim,), name=\"home_feat\")(feat_h)\n", 293 | "\n", 294 | "feat_v_id = player_emb(x_v_id)\n", 295 | "feat_v = dot([x_v_min, feat_v_id], axes=1, name=\"visitor_player_sum\")\n", 296 | "feat_v = Reshape((emb_dim,), name=\"visitor_feat\")(feat_v)\n", 297 | "\n", 298 | "feat = concatenate([feat_t, feat_h, feat_v], name=\"all_feat\")\n", 299 | "\n", 300 | "hid = Dense(256, activation=\"relu\", name=\"hidden_1\")(feat)\n", 301 | "hid = Dropout(0.2, name=\"dropout_1\")(hid)\n", 302 | "hid = Dense(128, activation=\"relu\", name=\"hidden_2\")(hid)\n", 303 | "hid = Dropout(0.2, name=\"dropout_2\")(hid)\n", 304 | "score = Dense(2, activation=\"relu\", name=\"score\")(hid)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 14, 310 | "metadata": { 311 | "collapsed": false, 312 | "deletable": true, 313 | "editable": true 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "model = Model(inputs=[x_t, x_h_id, x_h_min, x_v_id, x_v_min], outputs=score)\n", 318 | "\n", 319 | "\n", 320 | "def win(y_true, y_pred):  # share of games whose predicted winner (larger of the two scores) matches the actual one\n", 321 | " return K.mean(K.equal(K.argmax(y_pred, axis=-1), K.argmax(y_true, axis=-1)))\n", 322 | "\n", 323 | "model.compile(optimizer=\"adam\", loss=\"mae\", metrics=[win])" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "metadata": { 330 | "collapsed": true, 331 | "deletable": true, 332 | "editable": true 333 | }, 334 | "outputs": [], 335 | "source": [ 336 | 
"np.random.seed(1105)\n", 337 | "\n", 338 | "idx = np.arange(len(data_y))\n", 339 | "np.random.shuffle(idx)\n", 340 | "\n", 341 | "train_idx = idx[600:]  # hold out 600 shuffled games for validation\n", 342 | "valid_idx = idx[:600]" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 16, 348 | "metadata": { 349 | "collapsed": false, 350 | "deletable": true, 351 | "editable": true 352 | }, 353 | "outputs": [ 354 | { 355 | "data": { 356 | "text/plain": [ 357 | "6632" 358 | ] 359 | }, 360 | "execution_count": 16, 361 | "metadata": {}, 362 | "output_type": "execute_result" 363 | } 364 | ], 365 | "source": [ 366 | "len(idx)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 17, 372 | "metadata": { 373 | "collapsed": false, 374 | "deletable": true, 375 | "editable": true 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "hist = model.fit([data_x_team_id[train_idx], \n", 380 | " data_x_home_id[train_idx], data_x_home_min[train_idx, :, None], \n", 381 | " data_x_visitor_id[train_idx], data_x_visitor_min[train_idx, :, None]],\n", 382 | " data_y[train_idx],\n", 383 | " validation_data = ([data_x_team_id[valid_idx], \n", 384 | " data_x_home_id[valid_idx], data_x_home_min[valid_idx, :, None], \n", 385 | " data_x_visitor_id[valid_idx], data_x_visitor_min[valid_idx, :, None]],\n", 386 | " data_y[valid_idx]),\n", 387 | " verbose=0,\n", 388 | " epochs=30)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 18, 394 | "metadata": { 395 | "collapsed": false, 396 | "deletable": true, 397 | "editable": true 398 | }, 399 | "outputs": [ 400 | { 401 | "data": { 402 | "text/plain": [ 403 | "[10.165847600301106,\n", 404 | " 9.9543780390421546,\n", 405 | " 9.8001865132649737,\n", 406 | " 10.289466730753581,\n", 407 | " 9.228977038065592,\n", 408 | " 11.430291798909506,\n", 409 | " 9.3323127492268885,\n", 410 | " 9.1152853902180997,\n", 411 | " 9.1859752019246415,\n", 412 | " 9.4495480346679681,\n", 413 | " 9.5419914118448901,\n", 414 | 
" 9.3206252034505201,\n", 415 | " 9.0971596018473306,\n", 416 | " 9.0549177932739262,\n", 417 | " 9.2275110244750973,\n", 418 | " 9.3106695048014316,\n", 419 | " 9.1169172541300458,\n", 420 | " 9.3980016962687181,\n", 421 | " 9.1642890930175778,\n", 422 | " 9.702613741556803,\n", 423 | " 9.2596215184529616,\n", 424 | " 9.3760626475016284,\n", 425 | " 8.9202607472737636,\n", 426 | " 9.2350659434000644,\n", 427 | " 9.2720768737792962,\n", 428 | " 10.19280756632487,\n", 429 | " 9.3464718373616531,\n", 430 | " 9.4991256459554041,\n", 431 | " 9.2756049474080395,\n", 432 | " 9.1515356826782224]" 433 | ] 434 | }, 435 | "execution_count": 18, 436 | "metadata": {}, 437 | "output_type": "execute_result" 438 | } 439 | ], 440 | "source": [ 441 | "hist.history[\"val_loss\"]" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "预测测试数据正确率:" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 20, 454 | "metadata": { 455 | "collapsed": false, 456 | "deletable": true, 457 | "editable": true 458 | }, 459 | "outputs": [ 460 | { 461 | "name": "stdout", 462 | "output_type": "stream", 463 | "text": [ 464 | "0.656666666667\n", 465 | "0.563333333333\n" 466 | ] 467 | } 468 | ], 469 | "source": [ 470 | "data_t = model.predict([data_x_team_id[valid_idx], \n", 471 | " data_x_home_id[valid_idx], data_x_home_min[valid_idx , :, None], \n", 472 | " data_x_visitor_id[valid_idx], data_x_visitor_min[valid_idx , :, None]])\n", 473 | "\n", 474 | "print np.mean(data_t.argmax(axis=-1) == data_y[valid_idx].argmax(axis=-1))  # model's winner accuracy\n", 475 | "print np.mean(0 == data_y[valid_idx].argmax(axis=-1))  # baseline: always pick the first (presumably home) team" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "测试集上的预测比分:" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 21, 488 | "metadata": { 489 | "collapsed": false, 490 | "deletable": true, 491 | "editable": true, 492 | "scrolled": true 493 | }, 494 | "outputs": [ 495 | { 496 
| "data": { 497 | "text/plain": [ 498 | "array([[ 118.31840515, 106.11000824],\n", 499 | " [ 99.41278076, 93.19446564],\n", 500 | " [ 100.38309479, 99.36334991],\n", 501 | " ..., \n", 502 | " [ 99.76564026, 93.42744446],\n", 503 | " [ 91.10274506, 97.74150848],\n", 504 | " [ 103.73945618, 95.22878265]], dtype=float32)" 505 | ] 506 | }, 507 | "execution_count": 21, 508 | "metadata": {}, 509 | "output_type": "execute_result" 510 | } 511 | ], 512 | "source": [ 513 | "data_t" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "真实比分:" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 24, 526 | "metadata": { 527 | "collapsed": false, 528 | "deletable": true, 529 | "editable": true 530 | }, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "array([[116, 92],\n", 536 | " [107, 112],\n", 537 | " [ 94, 100],\n", 538 | " ..., \n", 539 | " [105, 109],\n", 540 | " [112, 100],\n", 541 | " [ 95, 94]])" 542 | ] 543 | }, 544 | "execution_count": 24, 545 | "metadata": {}, 546 | "output_type": "execute_result" 547 | } 548 | ], 549 | "source": [ 550 | "data_y[valid_idx]" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 22, 556 | "metadata": { 557 | "collapsed": true, 558 | "deletable": true, 559 | "editable": true 560 | }, 561 | "outputs": [], 562 | "source": [ 563 | "data_x_team_abbr = np.array(data_x_team_abbr_raw)" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 23, 569 | "metadata": { 570 | "collapsed": false, 571 | "deletable": true, 572 | "editable": true 573 | }, 574 | "outputs": [ 575 | { 576 | "data": { 577 | "text/plain": [ 578 | "array([['SAC', 'LAL'],\n", 579 | " ['LAL', 'HOU'],\n", 580 | " ['MIN', 'NYK'],\n", 581 | " ..., \n", 582 | " ['WAS', 'MIL'],\n", 583 | " ['CHA', 'GSW'],\n", 584 | " ['SAC', 'WAS']], \n", 585 | " dtype='|S3')" 586 | ] 587 | }, 588 | "execution_count": 23, 589 | "metadata": {}, 590 | 
"output_type": "execute_result" 591 | } 592 | ], 593 | "source": [ 594 | "data_x_team_abbr[valid_idx]" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": {}, 600 | "source": [ 601 | "训练集上的比分和队伍:" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 25, 607 | "metadata": { 608 | "collapsed": true, 609 | "deletable": true, 610 | "editable": true 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "data_p = model.predict([data_x_team_id[train_idx], \n", 615 | " data_x_home_id[train_idx], data_x_home_min[train_idx , :, None], \n", 616 | " data_x_visitor_id[train_idx], data_x_visitor_min[train_idx , :, None]])" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 26, 622 | "metadata": { 623 | "collapsed": false, 624 | "deletable": true, 625 | "editable": true, 626 | "scrolled": true 627 | }, 628 | "outputs": [ 629 | { 630 | "data": { 631 | "text/plain": [ 632 | "array([[ 111.05187225, 93.65827179],\n", 633 | " [ 103.43552399, 105.09658813],\n", 634 | " [ 108.43019104, 107.60032654],\n", 635 | " ..., \n", 636 | " [ 85.79730988, 81.65973663],\n", 637 | " [ 107.5667038 , 106.44389343],\n", 638 | " [ 99.8336792 , 98.61862183]], dtype=float32)" 639 | ] 640 | }, 641 | "execution_count": 26, 642 | "metadata": {}, 643 | "output_type": "execute_result" 644 | } 645 | ], 646 | "source": [ 647 | "data_p" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 27, 653 | "metadata": { 654 | "collapsed": false, 655 | "deletable": true, 656 | "editable": true 657 | }, 658 | "outputs": [ 659 | { 660 | "data": { 661 | "text/plain": [ 662 | "array([['HOU', 'PHI'],\n", 663 | " ['DEN', 'ATL'],\n", 664 | " ['PHX', 'MIN'],\n", 665 | " ..., \n", 666 | " ['DET', 'IND'],\n", 667 | " ['LAL', 'BKN'],\n", 668 | " ['MIN', 'UTA']], \n", 669 | " dtype='|S3')" 670 | ] 671 | }, 672 | "execution_count": 27, 673 | "metadata": {}, 674 | "output_type": "execute_result" 675 | } 676 | ], 677 | "source": [ 678 | 
"data_x_team_abbr[train_idx]" 679 | ] 680 | }, 681 | { 682 | "cell_type": "markdown", 683 | "metadata": {}, 684 | "source": [ 685 | "真实比分:" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 28, 691 | "metadata": { 692 | "collapsed": false, 693 | "deletable": true, 694 | "editable": true 695 | }, 696 | "outputs": [ 697 | { 698 | "data": { 699 | "text/plain": [ 700 | "array([[120, 98],\n", 701 | " [105, 119],\n", 702 | " [107, 104],\n", 703 | " ..., \n", 704 | " [ 77, 88],\n", 705 | " [105, 114],\n", 706 | " [ 92, 94]])" 707 | ] 708 | }, 709 | "execution_count": 28, 710 | "metadata": {}, 711 | "output_type": "execute_result" 712 | } 713 | ], 714 | "source": [ 715 | "data_y[train_idx]" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": {}, 721 | "source": [ 722 | "模型结构:" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 29, 728 | "metadata": { 729 | "collapsed": false, 730 | "deletable": true, 731 | "editable": true 732 | }, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "image/svg+xml": [ 737 | "\n", 738 | "\n", 739 | "G\n", 740 | "\n", 741 | "\n", 742 | "140572603596240\n", 743 | "\n", 744 | "home_player_id: InputLayer\n", 745 | "\n", 746 | "\n", 747 | "140572603639952\n", 748 | "\n", 749 | "player_emb: Sequential\n", 750 | "\n", 751 | "\n", 752 | "140572603596240->140572603639952\n", 753 | "\n", 754 | "\n", 755 | "\n", 756 | "\n", 757 | "140572603596368\n", 758 | "\n", 759 | "visitor_player_id: InputLayer\n", 760 | "\n", 761 | "\n", 762 | "140572603596368->140572603639952\n", 763 | "\n", 764 | "\n", 765 | "\n", 766 | "\n", 767 | "140572603596560\n", 768 | "\n", 769 | "home_player_time: InputLayer\n", 770 | "\n", 771 | "\n", 772 | "140572603639824\n", 773 | "\n", 774 | "home_player_sum: Dot\n", 775 | "\n", 776 | "\n", 777 | "140572603596560->140572603639824\n", 778 | "\n", 779 | "\n", 780 | "\n", 781 | "\n", 782 | "140572603639952->140572603639824\n", 783 | "\n", 784 | "\n", 785 | "\n", 
786 | "\n", 787 | "140575373044240\n", 788 | "\n", 789 | "visitor_player_sum: Dot\n", 790 | "\n", 791 | "\n", 792 | "140572603639952->140575373044240\n", 793 | "\n", 794 | "\n", 795 | "\n", 796 | "\n", 797 | "140572603638352\n", 798 | "\n", 799 | "visitor_player_time: InputLayer\n", 800 | "\n", 801 | "\n", 802 | "140572603638352->140575373044240\n", 803 | "\n", 804 | "\n", 805 | "\n", 806 | "\n", 807 | "140572603596176\n", 808 | "\n", 809 | "team_id: InputLayer\n", 810 | "\n", 811 | "\n", 812 | "140572603639632\n", 813 | "\n", 814 | "team_emb: Sequential\n", 815 | "\n", 816 | "\n", 817 | "140572603596176->140572603639632\n", 818 | "\n", 819 | "\n", 820 | "\n", 821 | "\n", 822 | "140573144624848\n", 823 | "\n", 824 | "home_feat: Reshape\n", 825 | "\n", 826 | "\n", 827 | "140572603639824->140573144624848\n", 828 | "\n", 829 | "\n", 830 | "\n", 831 | "\n", 832 | "140573144209040\n", 833 | "\n", 834 | "visitor_feat: Reshape\n", 835 | "\n", 836 | "\n", 837 | "140575373044240->140573144209040\n", 838 | "\n", 839 | "\n", 840 | "\n", 841 | "\n", 842 | "140572603815568\n", 843 | "\n", 844 | "all_feat: Concatenate\n", 845 | "\n", 846 | "\n", 847 | "140572603639632->140572603815568\n", 848 | "\n", 849 | "\n", 850 | "\n", 851 | "\n", 852 | "140573144624848->140572603815568\n", 853 | "\n", 854 | "\n", 855 | "\n", 856 | "\n", 857 | "140573144209040->140572603815568\n", 858 | "\n", 859 | "\n", 860 | "\n", 861 | "\n", 862 | "140572603817872\n", 863 | "\n", 864 | "hidden_1: Dense\n", 865 | "\n", 866 | "\n", 867 | "140572603815568->140572603817872\n", 868 | "\n", 869 | "\n", 870 | "\n", 871 | "\n", 872 | "140572603817936\n", 873 | "\n", 874 | "dropout_1: Dropout\n", 875 | "\n", 876 | "\n", 877 | "140572603817872->140572603817936\n", 878 | "\n", 879 | "\n", 880 | "\n", 881 | "\n", 882 | "140573143359568\n", 883 | "\n", 884 | "hidden_2: Dense\n", 885 | "\n", 886 | "\n", 887 | "140572603817936->140573143359568\n", 888 | "\n", 889 | "\n", 890 | "\n", 891 | "\n", 892 | 
"140573143359888\n", 893 | "\n", 894 | "dropout_2: Dropout\n", 895 | "\n", 896 | "\n", 897 | "140573143359568->140573143359888\n", 898 | "\n", 899 | "\n", 900 | "\n", 901 | "\n", 902 | "140573143295952\n", 903 | "\n", 904 | "score: Dense\n", 905 | "\n", 906 | "\n", 907 | "140573143359888->140573143295952\n", 908 | "\n", 909 | "\n", 910 | "\n", 911 | "\n", 912 | "" 913 | ], 914 | "text/plain": [ 915 | "" 916 | ] 917 | }, 918 | "execution_count": 29, 919 | "metadata": {}, 920 | "output_type": "execute_result" 921 | } 922 | ], 923 | "source": [ 924 | "from IPython.display import SVG\n", 925 | "from keras.utils.vis_utils import model_to_dot, plot_model\n", 926 | "\n", 927 | "plot_model(model, to_file=\"model.png\")\n", 928 | "SVG(model_to_dot(model).create(prog='dot', format='svg'))" 929 | ] 930 | }, 931 | { 932 | "cell_type": "markdown", 933 | "metadata": {}, 934 | "source": [ 935 | "关公战秦琼:" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": 34, 941 | "metadata": { 942 | "collapsed": false, 943 | "deletable": true, 944 | "editable": true 945 | }, 946 | "outputs": [], 947 | "source": [ 948 | "names1 = \"Mike Miller, LeBron James, Chris Bosh, Dwyane Wade, Mario Chalmers,\" \\\n", 949 | " \" Ray Allen, Shane Battier, Chris Andersen, Udonis Haslem\"\n", 950 | "pid1 = [[pid2index[player2id[name]] for name in names1.split(\", \")]]\n", 951 | "min1 = [[19,47,28,39,40,20,29,19,2]]  # playing time per player, in names1 order" 952 | ] 953 | }, 954 | { 955 | "cell_type": "code", 956 | "execution_count": 35, 957 | "metadata": { 958 | "collapsed": false, 959 | "deletable": true, 960 | "editable": true 961 | }, 962 | "outputs": [], 963 | "source": [ 964 | "names2 = \"LeBron James, JR Smith, Kevin Love, Kyrie Irving, Tristan Thompson, Richard Jefferson, Mo Williams, Iman Shumpert\"\n", 965 | "pid2 = [[pid2index[player2id[name]] for name in names2.split(\", \")]]\n", 966 | "min2 = [[47,39,30,43,32,26,5,19]]  # playing time per player, in names2 order" 967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": 36, 972 | 
"metadata": { 973 | "collapsed": false, 974 | "deletable": true, 975 | "editable": true 976 | }, 977 | "outputs": [], 978 | "source": [ 979 | "pid1 = pad_sequences(pid1, maxlen=13, padding=\"post\")\n", 980 | "pid2 = pad_sequences(pid2, maxlen=13, padding=\"post\")\n", 981 | "min1 = pad_sequences(min1, maxlen=13, padding=\"post\")\n", 982 | "min2 = pad_sequences(min2, maxlen=13, padding=\"post\")\n", 983 | "\n", 984 | "min1 = 5 * min1.astype(np.float32) / min1.sum()  # same sum-to-5 scaling as the training data, so units cancel\n", 985 | "min2 = 5 * min2.astype(np.float32) / min2.sum()" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": 37, 991 | "metadata": { 992 | "collapsed": false, 993 | "deletable": true, 994 | "editable": true 995 | }, 996 | "outputs": [ 997 | { 998 | "data": { 999 | "text/plain": [ 1000 | "4" 1001 | ] 1002 | }, 1003 | "execution_count": 37, 1004 | "metadata": {}, 1005 | "output_type": "execute_result" 1006 | } 1007 | ], 1008 | "source": [ 1009 | "tid2index[team2id[\"MIA\"]]" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": 38, 1015 | "metadata": { 1016 | "collapsed": false, 1017 | "deletable": true, 1018 | "editable": true 1019 | }, 1020 | "outputs": [ 1021 | { 1022 | "data": { 1023 | "text/plain": [ 1024 | "14" 1025 | ] 1026 | }, 1027 | "execution_count": 38, 1028 | "metadata": {}, 1029 | "output_type": "execute_result" 1030 | } 1031 | ], 1032 | "source": [ 1033 | "tid2index[team2id[\"CLE\"]]" 1034 | ] 1035 | }, 1036 | { 1037 | "cell_type": "code", 1038 | "execution_count": 39, 1039 | "metadata": { 1040 | "collapsed": false, 1041 | "deletable": true, 1042 | "editable": true 1043 | }, 1044 | "outputs": [ 1045 | { 1046 | "data": { 1047 | "text/plain": [ 1048 | "array([[ 102.45483398, 100.24012756]], dtype=float32)" 1049 | ] 1050 | }, 1051 | "execution_count": 39, 1052 | "metadata": {}, 1053 | "output_type": "execute_result" 1054 | } 1055 | ], 1056 | "source": [ 1057 | "model.predict([np.array([[tid2index[team2id[\"CLE\"]], tid2index[team2id[\"MIA\"]]]]), pid2, min2[..., None], pid1, min1[...,
 None]])" 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "code", 1062 | "execution_count": 40, 1063 | "metadata": { 1064 | "collapsed": false, 1065 | "deletable": true, 1066 | "editable": true 1067 | }, 1068 | "outputs": [ 1069 | { 1070 | "data": { 1071 | "text/plain": [ 1072 | "array([[ 110.89669037, 106.30657196]], dtype=float32)" 1073 | ] 1074 | }, 1075 | "execution_count": 40, 1076 | "metadata": {}, 1077 | "output_type": "execute_result" 1078 | } 1079 | ], 1080 | "source": [ 1081 | "model.predict([np.array([[tid2index[team2id[\"MIA\"]], tid2index[team2id[\"CLE\"]]]]), pid1, min1[..., None], pid2, min2[..., None]])" 1082 | ] 1083 | } 1084 | ], 1085 | "metadata": { 1086 | "kernelspec": { 1087 | "display_name": "Python 2", 1088 | "language": "python", 1089 | "name": "python2" 1090 | }, 1091 | "language_info": { 1092 | "codemirror_mode": { 1093 | "name": "ipython", 1094 | "version": 2 1095 | }, 1096 | "file_extension": ".py", 1097 | "mimetype": "text/x-python", 1098 | "name": "python", 1099 | "nbconvert_exporter": "python", 1100 | "pygments_lexer": "ipython2", 1101 | "version": "2.7.12" 1102 | } 1103 | }, 1104 | "nbformat": 4, 1105 | "nbformat_minor": 2 1106 | } 1107 | -------------------------------------------------------------------------------- /nba-prediction/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lijin-THU/play-with-machine-learning/1768ff47eddceaf6975eb1484fab1245598f30e2/nba-prediction/readme.md -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lijin-THU/play-with-machine-learning/1768ff47eddceaf6975eb1484fab1245598f30e2/readme.md --------------------------------------------------------------------------------