├── DIEN_train_example.ipynb ├── DIN_train_example.ipynb ├── README.md ├── __pycache__ ├── activations.cpython-37.pyc ├── alibaba_data_reader.cpython-37.pyc ├── layers.cpython-37.pyc ├── loss.cpython-37.pyc ├── model.cpython-37.pyc └── utils.cpython-37.pyc ├── activations.py ├── alibaba_data_reader.py ├── layers.py ├── loss.py ├── main.ipynb ├── main.py ├── model.py ├── tensorboard.log ├── tensorboard.sh └── utils.py /DIEN_train_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "from tensorflow.keras import layers\n", 11 | "from layers import AUGRU\n", 12 | "from activations import Dice,dice\n", 13 | "import pandas as pd\n", 14 | "from model import DIEN\n", 15 | "import alibaba_data_reader as data_reader\n", 16 | "import utils\n", 17 | "import matplotlib\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "from matplotlib.font_manager import FontProperties\n", 20 | "from matplotlib.pyplot import MultipleLocator\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "from loss import AuxLayer" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def mkdir(path):\n", 33 | " try:\n", 34 | " if not os.path.exists(path):\n", 35 | " os.makedirs(path)\n", 36 | " return 0\n", 37 | " except:\n", 38 | " return 1\n", 39 | "model_name = \"dien\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def is_in_notebook():\n", 49 | " import sys\n", 50 | " return 'ipykernel' in sys.modules\n", 51 | "def clear_output():\n", 52 | " \"\"\"\n", 53 | " clear output for both jupyter notebook and the console\n", 54 | " \"\"\"\n", 55 | " import os\n", 56 | " os.system('cls' if os.name == 'nt' else 
'clear')\n", 57 | " if is_in_notebook():\n", 58 | " from IPython.display import clear_output as clear\n", 59 | " clear()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "2\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "print(1)\n", 77 | "clear_output()\n", 78 | "print(2)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": { 85 | "tags": [] 86 | }, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "2.0.0\n", 93 | "GPU Available: True\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "print(tf.__version__)\n", 99 | "print(\"GPU Available: \", tf.test.is_gpu_available())" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "file_path = \"/nfs/project/boweihan_2/DIEN/dien_final/\"\n", 109 | "file_path = \"\"" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "# 模型训练" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 7, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/html": [ 127 | "
\n", 128 | "\n", 141 | "\n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | "
brandcatecms_segidcms_groupgenderagepvalueshoppingoccupationuser_class_level
0460561129689713273324
\n", 173 | "
" 174 | ], 175 | "text/plain": [ 176 | " brand cate cms_segid cms_group gender age pvalue shopping \\\n", 177 | "0 460561 12968 97 13 2 7 3 3 \n", 178 | "\n", 179 | " occupation user_class_level \n", 180 | "0 2 4 " 181 | ] 182 | }, 183 | "execution_count": 7, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "train_data, test_data, embedding_count = data_reader.get_data()\n", 190 | "embedding_count" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 8, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "embedding_features_list = data_reader.get_embedding_features_list()\n", 200 | "user_behavior_features = data_reader.get_user_behavior_features()\n", 201 | "embedding_count_dict = data_reader.get_embedding_count_dict(embedding_features_list, embedding_count)\n", 202 | "embedding_dim_dict = data_reader.get_embedding_dim_dict(embedding_features_list)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "import time\n", 212 | "stamp = time.strftime(\"%Y%m%d-%H%M%S\", time.localtime())\n", 213 | "mkdir(\"./train_log/\" + model_name)\n", 214 | "log_path = \"./train_log/\"+model_name+\"/%s\" % stamp\n", 215 | "train_summary_writer = tf.summary.create_file_writer(log_path)\n", 216 | "tf.summary.trace_on(graph=True, profiler=True)\n", 217 | "loss_file_name = utils.get_file_name()\n", 218 | "mkdir(\"./loss/\" + model_name + \"/\")\n", 219 | "utils.make_train_loss_dir(loss_file_name, cols=[\"train_aux_loss\",\"train_target_loss\",\"train_final_loss\"], model=model_name)\n", 220 | "utils.make_test_loss_dir(loss_file_name, cols=[\"test_aux_loss\",\"test_target_loss\",\"test_final_loss\"], model=model_name)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 10, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "data": { 230 | "text/plain": [ 231 | "" 232 | ] 233 | 
}, 234 | "execution_count": 10, 235 | "metadata": {}, 236 | "output_type": "execute_result" 237 | } 238 | ], 239 | "source": [ 240 | "model = DIEN(\n", 241 | " embedding_count_dict, \n", 242 | " embedding_dim_dict, \n", 243 | " embedding_features_list, \n", 244 | " user_behavior_features, \n", 245 | " activation=\"dice\")\n", 246 | "model" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 11, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "min_batch = 0\n", 256 | "batch = 100\n", 257 | "optimizer = tf.keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)\n", 258 | "loss_metric = tf.keras.metrics.Sum()\n", 259 | "auc_metric = tf.keras.metrics.AUC()\n", 260 | "alpha = 1\n", 261 | "epochs = 3" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 12, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 13, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show):\n", 280 | " user_profile_dict = {\n", 281 | " \"cms_segid\": cms_segid,\n", 282 | " \"cms_group\": cms_group,\n", 283 | " \"gender\": gender,\n", 284 | " \"age\": age,\n", 285 | " \"pvalue\": pvalue,\n", 286 | " \"shopping\": shopping,\n", 287 | " \"occupation\": occupation,\n", 288 | " \"user_class_level\": user_class_level\n", 289 | " }\n", 290 | " 
user_profile_list = [\"cms_segid\", \"cms_group\", \"gender\", \"age\", \"pvalue\", \"shopping\", \"occupation\", \"user_class_level\"]\n", 291 | " user_behavior_list = [\"brand\", \"cate\"]\n", 292 | " click_behavior_dict = {\n", 293 | " \"brand\": hist_brand_behavior_clk,\n", 294 | " \"cate\": hist_cate_behavior_clk\n", 295 | " }\n", 296 | " noclick_behavior_dict = {\n", 297 | " \"brand\": hist_brand_behavior_show,\n", 298 | " \"cate\": hist_cate_behavior_show\n", 299 | " }\n", 300 | " target_item_dict = {\n", 301 | " \"brand\": target_cate,\n", 302 | " \"cate\": target_brand\n", 303 | " }\n", 304 | " return user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 14, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show) " 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 15, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "def train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label):\n", 323 | " with tf.GradientTape() as tape:\n", 324 | " output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)\n", 325 | " target_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,labels=tf.cast(label, dtype=tf.float32)))\n", 326 | " final_loss = target_loss + alpha * aux_loss\n", 327 | " #print(\"[Train Loss] aux_loss=\" + 
str(aux_loss.numpy()) + \", target_loss=\" + str(target_loss.numpy()) + \", final_loss=\" + str(final_loss.numpy()))\n", 328 | " gradient = tape.gradient(final_loss, model.trainable_variables)\n", 329 | " clip_gradient, _ = tf.clip_by_global_norm(gradient, 5.0)\n", 330 | " optimizer.apply_gradients(zip(clip_gradient, model.trainable_variables))\n", 331 | " loss_metric(final_loss)\n", 332 | " return aux_loss.numpy(), target_loss.numpy(), final_loss.numpy()" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 16, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "def get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label):\n", 342 | " output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)\n", 343 | " target_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,labels=tf.cast(label, dtype=tf.float32)))\n", 344 | " final_loss = target_loss + alpha * aux_loss\n", 345 | " #print(\"[Test Loss] aux_loss=\" + str(aux_loss.numpy()) + \", target_loss=\" + str(target_loss.numpy()) + \", final_loss=\" + str(final_loss.numpy()))\n", 346 | " return aux_loss.numpy(), target_loss.numpy(), final_loss.numpy()" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 18, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "#aux_loss, target_loss, final_loss = train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 17, 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "WARNING:tensorflow:Layer dien is casting an input tensor from dtype float64 to the layer's dtype of float32, 
which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n", 368 | "\n", 369 | "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n", 370 | "\n", 371 | "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n", 372 | "\n" 373 | ] 374 | }, 375 | { 376 | "data": { 377 | "text/plain": [ 378 | "(0.89547175, 0.69206244, 1.5875342)" 379 | ] 380 | }, 381 | "execution_count": 17, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 18, 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "Model: \"dien\"\n", 400 | "_________________________________________________________________\n", 401 | "Layer (type) Output Shape Param # \n", 402 | "=================================================================\n", 403 | "embedding_5 (Embedding) multiple 448 \n", 404 | "_________________________________________________________________\n", 405 | "embedding_1 (Embedding) multiple 32000000 \n", 406 | "_________________________________________________________________\n", 407 | "embedding (Embedding) multiple 32100992 \n", 408 | "_________________________________________________________________\n", 409 | "embedding_3 (Embedding) multiple 832 \n", 410 | "_________________________________________________________________\n", 411 | 
"embedding_2 (Embedding) multiple 6208 \n", 412 | "_________________________________________________________________\n", 413 | "embedding_4 (Embedding) multiple 192 \n", 414 | "_________________________________________________________________\n", 415 | "embedding_8 (Embedding) multiple 320 \n", 416 | "_________________________________________________________________\n", 417 | "embedding_6 (Embedding) multiple 640 \n", 418 | "_________________________________________________________________\n", 419 | "embedding_7 (Embedding) multiple 256 \n", 420 | "_________________________________________________________________\n", 421 | "embedding_9 (Embedding) multiple 320 \n", 422 | "_________________________________________________________________\n", 423 | "gru (GRU) multiple 99072 \n", 424 | "_________________________________________________________________\n", 425 | "softmax (Softmax) multiple 0 \n", 426 | "_________________________________________________________________\n", 427 | "aux_layer (AuxLayer) multiple 31876 \n", 428 | "_________________________________________________________________\n", 429 | "augru (AUGRU) multiple 98688 \n", 430 | "_________________________________________________________________\n", 431 | "sequential_1 (Sequential) multiple 148122 \n", 432 | "=================================================================\n", 433 | "Total params: 64,487,966\n", 434 | "Trainable params: 64,485,614\n", 435 | "Non-trainable params: 2,352\n", 436 | "_________________________________________________________________\n" 437 | ] 438 | } 439 | ], 440 | "source": [ 441 | "model.summary()" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 19, 447 | "metadata": {}, 448 | "outputs": [ 449 | { 450 | "name": "stdout", 451 | "output_type": "stream", 452 | "text": [ 453 | "dien/embedding_5/embeddings:0\n", 454 | "dien/embedding_1/embeddings:0\n", 455 | "dien/embedding/embeddings:0\n", 456 | "dien/embedding_3/embeddings:0\n", 457 | 
"dien/embedding_2/embeddings:0\n", 458 | "dien/embedding_4/embeddings:0\n", 459 | "dien/embedding_8/embeddings:0\n", 460 | "dien/embedding_6/embeddings:0\n", 461 | "dien/embedding_7/embeddings:0\n", 462 | "dien/embedding_9/embeddings:0\n", 463 | "dien/gru/kernel:0\n", 464 | "dien/gru/recurrent_kernel:0\n", 465 | "dien/gru/bias:0\n", 466 | "dien/aux_layer/sequential/batch_normalization/gamma:0\n", 467 | "dien/aux_layer/sequential/batch_normalization/beta:0\n", 468 | "dien/aux_layer/sequential/dense/kernel:0\n", 469 | "dien/aux_layer/sequential/dense/bias:0\n", 470 | "dien/aux_layer/sequential/dense_1/kernel:0\n", 471 | "dien/aux_layer/sequential/dense_1/bias:0\n", 472 | "dien/aux_layer/sequential/dense_2/kernel:0\n", 473 | "dien/aux_layer/sequential/dense_2/bias:0\n", 474 | "dien/augru/gru_gates/dense_3/kernel:0\n", 475 | "dien/augru/gru_gates/dense_3/bias:0\n", 476 | "dien/augru/gru_gates/dense_4/kernel:0\n", 477 | "dien/augru/gru_gates_1/dense_5/kernel:0\n", 478 | "dien/augru/gru_gates_1/dense_5/bias:0\n", 479 | "dien/augru/gru_gates_1/dense_6/kernel:0\n", 480 | "dien/augru/gru_gates_2/dense_7/kernel:0\n", 481 | "dien/augru/gru_gates_2/dense_7/bias:0\n", 482 | "dien/augru/gru_gates_2/dense_8/kernel:0\n", 483 | "dien/sequential_1/batch_normalization_1/gamma:0\n", 484 | "dien/sequential_1/batch_normalization_1/beta:0\n", 485 | "dien/sequential_1/dense_9/kernel:0\n", 486 | "dien/sequential_1/dense_9/bias:0\n", 487 | "Variable:0\n", 488 | "Variable:0\n", 489 | "dien/sequential_1/dense_10/kernel:0\n", 490 | "dien/sequential_1/dense_10/bias:0\n", 491 | "Variable:0\n", 492 | "Variable:0\n", 493 | "dien/sequential_1/dense_11/kernel:0\n", 494 | "dien/sequential_1/dense_11/bias:0\n" 495 | ] 496 | } 497 | ], 498 | "source": [ 499 | "for var in model.trainable_variables:\n", 500 | " print(var.name)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 20, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "def get_loss_fig(train_loss, 
test_loss):\n", 510 | " loss_list = [\"aux_loss\", \"final_loss\"]\n", 511 | " color_list = [\"r\", \"b\"]\n", 512 | " plt.figure()\n", 513 | " cnt = 0\n", 514 | " for k in loss_list:\n", 515 | " loss = train_loss[k]\n", 516 | " step = list(np.arange(len(loss)))\n", 517 | " plt.plot(step,loss,color_list[cnt]+\"-\",label=\"train_\" + k, linestyle=\"--\")\n", 518 | " cnt += 1\n", 519 | " cnt = 0\n", 520 | " for k in loss_list:\n", 521 | " loss = test_loss[k]\n", 522 | " step = list(np.arange(len(loss)))\n", 523 | " plt.plot(step,loss,color_list[cnt],label=\"test_\" + k)\n", 524 | " cnt += 1\n", 525 | " plt.title(\"Loss\")\n", 526 | " plt.xlabel('iteration')\n", 527 | " plt.ylabel('loss')\n", 528 | " plt.legend()\n", 529 | " clear_output()\n", 530 | " plt.savefig(\"./loss/\" + model_name + \"/loss.png\")\n", 531 | " clear_output()\n", 532 | " plt.show()" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 21, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "def record_test_loss(test_loss, test_data, step):\n", 542 | " label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, clk_length, show_length = data_reader.get_test_data(test_data)\n", 543 | " user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n", 544 | " aux_loss, target_loss, final_loss = get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n", 545 | " loss_dict = dict()\n", 546 | " loss_dict[\"aux_loss\"] = str(aux_loss)\n", 547 
| " loss_dict[\"target_loss\"] = str(target_loss)\n", 548 | " loss_dict[\"final_loss\"] = str(final_loss)\n", 549 | " utils.add_loss(loss_dict, loss_file_name, level=\"test\")\n", 550 | " test_loss[\"aux_loss\"].append(float(aux_loss))\n", 551 | " test_loss[\"target_loss\"].append(float(target_loss))\n", 552 | " test_loss[\"final_loss\"].append(float(final_loss))\n", 553 | " with train_summary_writer.as_default():\n", 554 | " tf.summary.scalar(\"test_aux_loss epoch: \"+str(epoch), aux_loss, step = step)\n", 555 | " tf.summary.scalar(\"test_target_loss epoch: \"+str(epoch), target_loss, step = step)\n", 556 | " tf.summary.scalar(\"test_final_loss epoch: \"+str(epoch), final_loss, step = step)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 22, 562 | "metadata": {}, 563 | "outputs": [], 564 | "source": [ 565 | "mkdir(\"./checkpoint/\" + model_name)\n", 566 | "checkpoint_path = \"./checkpoint/\" + model_name + \"/cp-{epoch:04d}.ckpt\"\n", 567 | "checkpoint_dir = os.path.dirname(checkpoint_path)" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 23, 573 | "metadata": {}, 574 | "outputs": [ 575 | { 576 | "data": { 577 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOydd3yO5/fHP1cGQZCI2DTEJkbF3rU3NVo1apdqtVWt+qpVRYeiWuqrRmmNmu23rZYaMYoSxOZnE5tYEUGS6/fHJ7f7SfJkep48Gef9ej2vO/e67vOMXOe6zjnXOUprDUEQBCHz4uRoAQRBEATHIopAEAQhkyOKQBAEIZMjikAQBCGTI4pAEAQhkyOKQBAEIZMjikAQBCGTI4pAEOJBKXVeKdXU0XIIgr0RRSAIgpDJEUUgCMlEKTVQKXVaKRWilPqfUqpQ9HGllJqulLqhlLqnlDqklKoYfa61UuqYUuqBUuqyUmqEY9+FIJiIIhCEZKCUegnAFADdABQEcAHA8ujTzQE0AFAagAeAVwDcjj43H8AbWuucACoC2JyKYgtCgrg4WgBBSGf0ALBAa70fAJRSowDcUUr5AHgKICeAsgD2aK2PW9z3FEB5pdRBrfUdAHdSVWpBSACZEQhC8igEzgIAAFrrUHDUX1hrvRnAtwBmAbiulJqrlMoVfWlnAK0BXFBKbVVK1U5luQUhXkQRCELyuALgBWNHKZUDgBeAywCgtZ6pta4GoAJoIvog+vherXUHAPkA/AJgRSrLLQjxIopAEBLGVSnlZrzADryvUqqKUiorgMkA/tVan1dKVVdK1VRKuQJ4CCAcQKRSKotSqodSKrfW+imA+wAiHfaOBCEWoggEIWHWAXhk8aoPYAyA1QCuAvAF8Gr0tbkAfA/a/y+AJqOp0ed6ATivlLoPYDCAnqkkvyAkipLCNIIgCJkbmREIgiBkckQRCIIgZHJEEQiCIGRyRBEIgiBkctLdyuK8efNqHx8fR4shCIKQrti3b98trbW3tXPpThH4+PggMDDQ0WIIgiCkK5RSF+I7J6YhQRCETI4oAkEQhEyOKAJBEIRMTrrzEQiCkDo8ffoUwcHBCA8Pd7QoQjJwc3NDkSJF4OrqmuR77KYIlFILALQFcENrXTGeaxoBmAHAFcAtrXVDe8kjCELyCA4ORs6cOeHj4wOllKPFEZKA1hq3b99GcHAwihcvnuT77Gka+gFAy/hOKqU8AMwG0F5rXQFAVzvKIghCMgkPD4eXl5cogXSEUgpeXl7JnsXZTRForbcBCEngktcArNFaX4y+/oa9ZBEEIWWIEkh/pOQ7c6SzuDQAT6VUgFJqn1Kqd3wXKqUGKaUClVKBN2/eTNHDLl0Chg0DHj1KqbiCIAgZE0cqAhcA1QC0AdACwBilVGlrF2qt52qt/bXW/t7eVhfGJcqOHcA33wAdOqRYXkEQhAyJIxVBMIC/tNYPtda3AGwDUNleD3v1VaBwYeDvv4Hff7fXUwRBsCV3797F7Nmzk31f69atcffuXTtIlHICAgLQtm1bR4thFUcqgl8B1FdKuSilsgOoCeC4vR6mFPDXX9x27w48eWKvJwmCYCviUwSRkQlX+ly3bh08PDzsJVaGw26KQCm1DMAuAGWUUsFKqf5KqcFKqcEAoLU+DuAvAIcA7AEwT2t9xF7yAEDFisCgQUBoKPDKK/Z8kiBkQBo1ivsyOumwMOvnf/iB52/dinsuCXz00Uc4c+YMqlSpgurVq6Nx48Z47bXX4OfnBwDo2LEjqlWrhgoVKmDu3LnP7vPx8cGtW7dw/vx5lCtXDgMHDkSFChXQvHlzPErAUfj999+jevXqqFy5Mjp37oywsDAAQJ8+fbBq1apn17m7uwMA1q5di6ZNm0JrjatXr6J06dK4du1aou8rJCQEHTt2RKVKlVCrVi0cOnQIALB161ZUqVIFVapUQdWqVfHgwQNcvXoVDRo0QJUqVVCxYkV
s3749SZ9dcrBn1FB3rXVBrbWr1rqI1nq+1nqO1nqOxTVfaq3La60raq1n2EsWS2bPBry9gV9/BY4dS40nCoKQUj777DP4+voiKCgIX375Jfbs2YNJkybhWPQ/74IFC7Bv3z4EBgZi5syZuH37dpw2Tp06haFDh+Lo0aPw8PDA6tWr433eyy+/jL179+LgwYMoV64c5s+fn6B8nTp1QoECBTBr1iwMHDgQEyZMQIECBRJ9X+PGjUPVqlVx6NAhTJ48Gb17M1Zm6tSpmDVrFoKCgrB9+3Zky5YNS5cuRYsWLRAUFISDBw+iSpUqibafXDLdymInJ2DTJqBhQ2DAAGD7dsDZ2dFSCUI6ICAg/nPZsyd8Pm/ehM8nkRo1asRYKDVz5kysXbsWAHDp0iWcOnUKXl5eMe4pXrz4s86zWrVqOH/+fLztHzlyBB9//DHu3r2L0NBQtGjRIlGZvvnmG1SsWBG1atVC9+7dk/Q+duzY8UwhvfTSS7h9+zbu3buHunXrYvjw4ejRowdefvllFClSBNWrV0e/fv3w9OlTdOzY0S6KIFPmGvLzA77+Gti1CxgxwtHSCIKQVHLkyPHs74CAAGzcuBG7du3CwYMHUbVqVasLqbJmzfrsb2dnZ0RERMTbfp8+ffDtt9/i8OHDGDdu3LP2XFxcEBUVBYCrd59YOBkvX74MJycnXL9+/dk1iaG1jnNMKYWPPvoI8+bNw6NHj1CrVi2cOHECDRo0wLZt21C4cGH06tULixcvTtIzkkOmVAQA0LMnUKgQMGMGsHOno6URBMEaOXPmxIMHD6yeu3fvHjw9PZE9e3acOHECu3fvfu7nPXjwAAULFsTTp0+xZMmSZ8d9fHywb98+AMCvv/6Kp0+fAgAiIiLQt29fLF26FOXKlcO0adOS9JwGDRo8az8gIAB58+ZFrly5cObMGfj5+WHkyJHw9/fHiRMncOHCBeTLlw8DBw5E//79sX///ud+n7HJdKYhA6WAZctoImrfHrh+XUxEgpDW8PLyQt26dVGxYkVky5YN+fPnf3auZcuWmDNnDipVqoQyZcqgVq1az/28iRMnombNmnjhhRfg5+f3TAkNHDgQHTp0QI0aNdCkSZNnM5PJkyejfv36qF+//jOHdps2bVCuXLkEnzN+/Hj07dsXlSpVQvbs2bFo0SIAwIwZM7BlyxY4OzujfPnyaNWqFZYvX44vv/wSrq6ucHd3t8uMQFmboqRl/P39tS0rlHXrBqxcCfTqBdjh8xWEdMvx48cT7dCEtIm1704ptU9r7W/t+kxrGjJYuhTw8AB+/JGrjwVBEDIbmV4RuLgwlNTVFXjzTSDa9CcIQgZm6NChz+L1jdfChQtt0vb69evjtN2pUyebtG0vMq2PwJIGDegv6NIF+PRTYMIER0skCII9mTVrlt3abtGiRZLCTtMSmX5GYNC5M/0FEycCa9Y4WhpBEITUQxSBBWPGcNuzJ9NQCIIgZAZEEVhQsSLwwQesWdC6taOlEQRBSB1EEcTis88AHx+mnrCR70gQBCFNI4ogFkqxZoGTE6OIQhIqtikIgl1J7XoEJ06ceJb588yZM6hTp06y2zCInbE0No0aNYIt10Q9D6IIrFCyJPDf/wKRkUxbnc7W3AlChiG16xH88ssv6NChAw4cOABfX1/szCT5Z0QRxMOAAcCkScDq1QwpFYTMjgPKEaRqPYJ169ZhxowZmDdvHho3bgzArDsQEBCARo0aoUuXLihbtix69OjxLHHcJ598gurVq6NixYoYNGiQ1YRyibFs2TL4+fmhYsWKGDlyJAAquz59+qBixYrw8/PD9OnTATDjavny5VGpUiW8+uqryX6WNWQdQQK8/z6T0o0dCzRuDNSr52iJBCFz8dlnn+HIkSMICgpCQEAA2rRpgyNHjjxLRb1gwQLkyZMHjx49QvXq1dG5c+c4aahPnTqFZcuW4fvvv0e3bt2wevVq9OzZM86zWrdujcGDB8Pd3R0jrKQlPnD
gAI4ePYpChQqhbt26+Oeff1CvXj289dZbGDt2LACgV69e+P3339GuXbskv8crV65g5MiR2LdvHzw9PdG8eXP88ssvKFq0KC5fvowjR1ivyzB1ffbZZzh37hyyZs1qs3KcdlMESqkFANoCuKG1rpjAddUB7AbwitY6foOaA3By4oimeXOgVSvg6lUgeoAgCJmONFCOwO71CBJ7dpEiRQAAVapUwfnz51GvXj1s2bIFX3zxBcLCwhASEoIKFSokSxHs3bsXjRo1gre3NwCgR48e2LZtG8aMGYOzZ8/i7bffRps2bdC8eXMAQKVKldCjRw907NgRHTt2TNF7iY09TUM/AGiZ0AVKKWcAnwNYb0c5notmzYC33+a6gqROZwVBsA/2rkeQENbaCQ8Px5tvvolVq1bh8OHDGDhwoFUZEiI+U5KnpycOHjyIRo0aYdasWRgwYAAA4I8//sDQoUOxb98+VKtWLcXvxxJ7lqrcBiCxmJu3AawGcMNectiCr78GypcH9u0DpkxxtDSCkHlI7XoEycXo9PPmzYvQ0NAEo4Tio2bNmti6dStu3bqFyMhILFu2DA0bNsStW7cQFRWFzp07Y+LEidi/fz+ioqJw6dIlNG7cGF988cWzSmrPi8N8BEqpwgA6AXgJQPVErh0EYBAAFCtWzP7CxXk+p7UlS1IR9OgBOEAMQch0pHY9guTi4eGBgQMHws/PDz4+PqhePcGuzCoFCxbElClT0LhxY2it0bp1a3To0AEHDx5E3759n1U9mzJlCiIjI9GzZ0/cu3cPWmu89957KYqOio1d6xEopXwA/G7NR6CUWgngK631bqXUD9HXJapObV2PIDmcOQNUrQpUqkTF4CKudiEDI/UI0i/pqR6BP4DlSqnzALoAmK2Uso3nw074+jKK6J9/gN69HS2NIAiCbXCYItBaF9da+2itfQCsAvCm1voXR8mTVF57jYVsli0zY6QFQUhf2LMeQWw6deoU51nr16et+Bh7ho8uA9AIQF6lVDCAcQBcAUBrPcdez7U3bm7A1q00EfXvD5QtCzjANCkIwnNgz3oEsTHCW9MydlMEWuvuybi2j73ksAeVKnFG8MorQNOmXDXp5uZoqQRBEFKGpJhIId26cX3Bw4dMRyEIgpBekbiX52DGDCammz0baNsWsFHaD0EQhFRFZgTPgZMTlUGdOkCfPsDhw46WSBAEIfmIInhOXF2BgQOBx4+Bhg2lxKUg2JKU1iMAgBkzZiAsLMzGEiWd8ePHY+rUqQ57fnIQRWAD+vSh4/jOHaBFC6lfIAi2Ij0rgvSE+AhsxE8/Abt3Azt3Ah9/zFoGgpBhePddICjItm1WqULbagJY1iNo1qwZ8uXLhxUrVuDx48fo1KkTJkyYgIcPH6Jbt24IDg5GZGQkxowZg+vXr+PKlSto3Lgx8ubNiy1btlhtf8iQIdi7dy8ePXqELl26YMKECQBYzyAwMBB58+ZFYGAgRowYgYCAAAwbNgx58+bF2LFjsX79ekyaNAkBAQFwckp4TB0UFITBgwcjLCwMvr6+WLBgATw9PTFz5kzMmTMHLi4uKF++PJYvX46tW7finXfeAQAopbBt2zbkzJkzBR9w0hFFYCNcXJh2onRpYPJk1i9o2tTRUglC+sayHsGGDRuwatUq7NmzB1prtG/fHtu2bcPNmzdRqFAh/PHHHwCYjC537tyYNm0atmzZgrx588bb/qRJk5AnTx5ERkaiSZMmOHToECpVqpSgPNWrV0f9+vUxbNgwrFu3LlElAAC9e/fGN998g4YNG2Ls2LGYMGECZsyYYbW2wNSpUzFr1izUrVsXoaGhcEuF2HRRBDbEx4frC0aMYARRYCCPCUK6J5GRe2qwYcMGbNiwAVWrVgUAhIaG4tSpU6hfvz5GjBiBkSNHom3btqhfv36S21yxYgXmzp2LiIgIXL16FceOHUtQEWTPnh3ff/89GjRogOnTp8PX1zfRZ9y7dw93795Fw4YNAQCvv/4
6unbtCsB6bYG6deti+PDh6NGjB15++eVnNRDsifgIbEznzsCGDUBEBNCyJUv4CYLw/GitMWrUKAQFBSEoKAinT59G//79Ubp0aezbtw9+fn4YNWoUPvnkkyS1d+7cOUydOhWbNm3CoUOH0KZNm2dppV1cXJ5l/YxdX+Dw4cPw8vLClStXnvs9Wast8NFHH2HevHl49OgRatWqhRMnTjz3cxJDFIEdKFWKi81OngQ6dRLnsSCkFMt6BC1atMCCBQue5d+/fPkybty4gStXriB79uzo2bMnRowYgf3798e51xr3799Hjhw5kDt3bly/fh1//vnns3M+Pj7Yt28fAGD16tXPjl+4cAFfffUVDhw4gD///BP//vtvou8hd+7c8PT0xPbt2wEAP/74Ixo2bBhvbYEzZ87Az88PI0eOhL+/f6ooAjEN2YkRIzib3rABGDcOSOIgRRAECyzrEbRq1QqvvfYaateuDYCF5X/66SecPn0aH3zwAZycnODq6orvvvsOADBo0CC0atUKBQsWtOosrly5MqpWrYoKFSqgRIkSqFu37rNz48aNQ//+/TF58mTUrFkTAGck/fv3x9SpU1GoUCHMnz8fffr0wd69exO14y9atOiZs7hEiRJYuHBhvLUFxowZgy1btsDZ2Rnly5dHq1atbPVxxotd6xHYA0fWI0gu69YB7dtz9fHatYCNyosKQqog9QjSL+mpHkGGp3Vr4O+/+Xf37sDp046VRxAEwRpiGrIzjRsDb74JLF4MtGsH7NrFegaCIKQeNWvWxOPHj2Mc+/HHH+Hn5/fcbU+aNAkrV66Mcaxr164YPXr0c7edWohpKJXYupXrCho1Av78U8pcCmkfMQ2lX8Q0lEZp2JDmoY0bgaFDHS2NIAiCid0UgVJqgVLqhlLqSDzneyilDkW/diqlKttLlrTCsGHMWDp3LhAd2CAIguBw7Dkj+AFAywTOnwPQUGtdCcBEAHPtKEuawN8f+OIL/j10KGcHgiAIjsZuikBrvQ1ASALnd2qt70Tv7gZg/3XUaYDhw4GRI7nIrH174NQpR0skCEJmJ634CPoD+DO+k0qpQUqpQKVU4M2bN1NRLNujFDBlClccOzkxkig615QgCLGwdxrqlStXoly5cmjcuDECAwMxbNiwFD0L4GrkW7duxXve3d09xW3bG4crAqVUY1ARjIzvGq31XK21v9ba39vbO/WEsxNKAatXc8HZ2bOsfxwR4WipBCHtYW9FMH/+fMyePRtbtmyBv78/Zs6cmaJnpXccGsSolKoEYB6AVlrr246UJbVRiiUu+/al8/jdd4Fvv3W0VIJgHQeVI7BrPYJPPvkEO3bswLlz59C+fXu0adMGU6dOxe+//47x48fj4sWLOHv2LC5evIh333332WyhY8eOuHTpEsLDw/HOO+9g0KBByXrfWmt8+OGH+PPPP6GUwscff4xXXnkFV69exSuvvIL79+8jIiIC3333HerUqYP+/fsjMDAQSin069cP7733XrKelxQcpgiUUsUArAHQS2v9f46Sw5HcuAFE57XCrFlMVhddj0IQBNi3HsHYsWOxefNmTJ06Ff7+/ggICIhx/sSJE9iyZQsePHiAMmXKYMiQIXB1dcWCBQuQJ08ePHr0CNWrV0fnzp3h5eWV5Pe0Zs0aBAUF4eDBg7h16xaqV6+OBg0aYOnSpWjRogVGjx6NyMhIhIWFISgoCJcvX8aRIwy+vGsnO7LdFIFSahmARgDyKqWCAYwD4AoAWus5AMYC8AIwWykFABHxLXbIqBQqBOzdy0iijz4C3nsPKFOG6asFIS2RBsoR2KUeQUK0adMGWbNmRdasWZEvXz5cv34dRYoUwcyZM7F27VoAwKVLl3Dq1KlkKYIdO3age/fucHZ2Rv78+dGwYUPs3bsX1atXR79+/fD06VN07NgRVapUQYkSJXD27Fm8/fbbaNOmDZo3b26T9xYbe0YNdddaF9Rau2qti2it52ut50QrAWitB2i
tPbXWVaJfmUoJGCjFKKIvv2QkUadOwPHjjpZKENIetq5HkBhZs2Z99rezszMiIiIQEBCAjRs3YteuXTh48CCqVq0ap15BUt6HNRo0aIBt27ahcOHC6NWrFxYvXgxPT08cPHgQjRo1wqxZszBgwIDnek/x4XBnsUBGjODK46xZgbZtgQSCDwQh02DPegQp4d69e/D09ET27Nlx4sQJ7N69O9ltNGjQAD///DMiIyNx8+ZNbNu2DTVq1MCFCxeQL18+DBw4EP3798f+/ftx69YtREVFoXPnzpg4ceKz92ZrJONNGmLpUmD3buYjevllZi61GJQIQqbDnvUIUkLLli0xZ84cVKpUCWXKlEGtWrWS3UanTp2wa9cuVK5cGUopfPHFFyhQoAAWLVqEL7/8Eq6urnB3d8fixYtx+fJl9O3b91m1tClTptjkfcRGks6lQfr0ARYtAl5/HVi4kOYjQUhtJOlc+kWSzmUAihbldtEiMyWFIAiCvRDTUBpk5Ejg99+BgwcZTVS6NJ3IgiCkDHvWI7Dk9u3baNKkSZzjmzZtSlZkUWojiiAN4u4ObNsGvPgiEBwM9OwJbN/OfUFITbTWUBnANpmUIvO2wMvLC0G2XnmXTFJi7hfTUBolZ05g2jTg6VNWNGvXDrh82dFSCZkJNzc33L59O0Udi+AYtNa4ffs23NzcknWfzAjSMG3bAlu2ALlzA3XrAh06cKaQPbujJRMyA0WKFEFwcDDSe6LHzIabmxuKFEleMmdRBGkYpYD69YGoKK4xmDcP6N0bWLGCmUsFwZ64urqiePHijhZDSAWkO0kHKAWcOQNkycKspWPGOFoiQRAyEqII0gFKAd9/Dzg703cweTKweLGjpRIEIaMgiiCdUKIEsGkT4OpKH8HAgcCOHY6WShCEjIAognRErVrA8uWMJPL25tqCs2cdLZUgCOkdUQTpjGbNgIsXGU0UGcnIonv3HC2VIAjpGVEE6ZACBVjEZuxY4NQp4JVXpNSlIAgpRxRBOmX/fhaycXYG1q8Hhg93tESCIKRXRBGkU158Efj0U+Cll7j/zTdMUicIgpBc7KYIlFILlFI3lFJH4jmvlFIzlVKnlVKHlFKSSSeZjB7N5HS9enF/wADOFARBEJKDPWcEPwBIqPpuKwClol+DAHxnR1kyLE5OnAksXkzfQadOUt1MEITkYc+axdsAhCRwSQcAizXZDcBDKVXQXvJkZJTirGDtWuD6daBrV3EeC4KQdBzpIygM4JLFfnD0sTgopQYppQKVUoGSACt+ypdngrqAAOA//3G0NIIgpBccqQisJTm3mu9Waz1Xa+2vtfb39va2s1jpl+zZgZYtARcX4MsvmZxOEAQhMRypCIIBFLXYLwLgioNkyTC89RbNQiVKAP37c52BIAhCQjhSEfwPQO/o6KFaAO5pra86UJ4Mgb8/UK0acPs2HcndugHh4Y6WShCEtIw9w0eXAdgFoIxSKlgp1V8pNVgpNTj6knUAzgI4DeB7AG/aS5bMhFLMR1SgAHMTBQUBH3zgaKkEQUjL2K0wjda6eyLnNYCh9np+ZqZkSeDwYWYqHT4cmD6dC886dXK0ZIIgpEVkZXEGxdWV244dgapVgX79gPPnHSqSIAhpFFEEGZibN4FWregwfvCA5S5lfYEgCLERRZCB8fYGfvqJDuTISGD3buCLLxwtlSAIaQ1RBBmcTp2YjyhnTuCFF4Dx44GDBx0tlSAIaQlRBJmAHDmYgqJqVSBPHv79+LGjpRIEIa2QuRRBVJSjJXAYM2cyF9G8eYwomjDB0RIJgpBWyDyKYPNmoHRp4IjVrNgZHmdnbnPlAmrWBD7/XFJWC4JAMo8iyJ0bOHMGaNIkU88MliwB/v0X8PAA3niDTmRBEDI3mUcRVKsGdO4M3LhhVnLJhHz1FVCoEPViYCAwe7ajJRIEwdFkHkUAAD//DHh5AUuXAn/95WhpHIK7O8tanjvH8NLRo4HLlx0tlSAIjiRzKQJnZ2D
DBibk6dwZCAtztEQO4eWXgUmTuODsyRNg2DBHSyQIgiPJXIoAYNX399+nEhg71tHSOIyBA7m+4D//AdasATZudLREgiA4CsXcb4lcpNQ7ABYCeABgHoCqAD7SWm+wr3hx8ff314GBgc/f0BtvAN9/D2zZAjRs+PztpVMePwbKleOCs/37zegiQRAyFkqpfVprf2vnkjoj6Ke1vg+gOQBvAH0BfGYj+RzDtGlA0aJMxpNJjeQXLgCTJwMffggcOgR8+62jJRIEwREkVREYZSVbA1iotT4I66Um0w85cgAffQQ8egQ0auRoaRzClSvAJ58Al6IrR48fDzx86FCRBEFwAElVBPuUUhtARbBeKZUTQPoPxh8yBKhbFzh9mon7Mxk1atAkNHky9+/eBaZOdaxMgiCkPklVBP0BfASgutY6DIAraB5KEKVUS6XUSaXUaaXUR1bO51ZK/aaUOqiUOqqUSrRNm7N+PWMqp08Htm5N9cc7EmdnwNPT3M+Th9lJr11znEyCIKQ+SVUEtQGc1FrfVUr1BPAxgHsJ3aCUcgYwC0ArAOUBdFdKlY912VAAx7TWlQE0AvCVUipLMuR/fnLkANat499dutBUlIn48UegTh1g8GAgNJTOY0lVLQiZi6Qqgu8AhCmlKgP4EMAFAIsTuacGgNNa67Na6ycAlgPoEOsaDSCnUkoBcAcQAiD1S6fUrw98/TVw6xZDSzMRDRoA//wDtGnDNQWNGwNz5sisQBAyE0lVBBHRNYY7APhaa/01gJyJ3FMYwCWL/eDoY5Z8C6AcgCsADgN4R2sdx/eglBqklApUSgXevHkziSInk2HDWOX9u++YkS2T0aYNsGMHU048fgx8+aWjJRIEIbVIqiJ4oJQaBaAXgD+izT6uidxjLaoo9qKFFgCCABQCUAXAt0qpXHFu0nqu1tpfa+3v7e2dRJFTwKRJTMIzahSwZ4/9npMGUYp+81KlgJ49qQ+vX3e0VIIgpAZJVQSvAHgMrie4Bo7sExszBgMoarFfBBz5W9IXwBpNTgM4B6BsEmWyPa6uwLJlgNZAs2ZAeLjDRHEUnTvTiRweDsya5WhpBEFIDZKkCKI7/yUAciul2gII11on5iPYCxZJ48YAACAASURBVKCUUqp4tAP4VQD/i3XNRQBNAEAplR9AGQBnkyG/7WnViquO799nyupMxoMHwIEDQLt2NBNlMt+5IGRKkqQIlFLdAOwB0BVANwD/KqW6JHSP1joCwFsA1gM4DmCF1vqoUmqwUmpw9GUTAdRRSh0GsAnASK31rZS9FRvy3Xe0kezcmelKedWsyVXGQ4YAt28zqkgQhIxNUnMNHQTQTGt9I3rfG8DG6LDPVMVmuYYS49YtoGJFDpF37wb8/Oz/zDTAH38AbdsCAQFcYxcWBhw9CjhlvvSEgpChsEWuISdDCURzOxn3pk/y5gWCgug8fvllDo8zATVqcPvvv4ykPXGCa+4EQci4JLUz/0sptV4p1Ucp1QfAHwDW2U+sNEKBAsBPP7HEZe3amaLEpbc38PrrwAsvcH2dtzcwd66jpRIEwZ4kyTQEAEqpzgDqgmGh27TWa+0pWHykmmnIkjp1gF27GFeZSYzmjx+zhs/WrcCMGUxMV7Cgo6USBCGl2MI0BK31aq31cK31e45SAg5j0yYm5fnpJ+CHHxwtTaowahTQvj0L1kRGAgsXOloiQRDsRYKKQCn1QCl138rrgVLqfmoJ6XCyZQO2b2eA/YABDKvJwFy7xsL2o0cDBw8CJUuyhk8msIwJQqYkQUWgtc6ptc5l5ZVTax1nBXCGpkIFYN48/t2pE3M2Z1AKFAC2bQM+/RSoXh14+hQ4f17KWQpCRiVjR/7Ymj59aDi/dAl49VUgIvXz46U2vXoBXl4MnlqyxNHSCIJgD0QRJJeXXgK++YYxlR1iJ1PNeAwdCuzbx9QTa9bISmNByIiIIkgJr7/OEJp162hIz8AYC8l69GC9gt9/d6w8giDYHlEEKcH
NjcPkbNlY53HlSkdLZFf69mXWjYIFxTwkCBkRUQQppWBBhpU6OQHduwP79ztaIrvx9ClXGnfvzknQnTuOlkgQBFsiiuB5qF2bkURRUQy6v3Ej8XvSIeXK0T/esSOVwurVjpZIEARbIorgeenbl3GVISHM1vbwoaMlsjnlynHr5gb4+gKrVjlWHkEQbIsoAlvw0kssaBMYCNSrx6W4GQgj8WpgIPMPbdpEvScIQsZAFIGtaNsWKF2aGUtbtGCVswxCyZJMs1S4MNC1K5dP/Pqro6USBMFWiCKwFc7OHDIXK8Yh85tvOloim6EUc+21bw+8+CLg4yPmIUHISNhVESilWiqlTiqlTiulPornmkZKqSCl1FGl1FZ7ymN33N05I3B3B+bMASZOdLRENuXmTSA4mOahv//O0Fk2BCFTYTdFoJRyBjALQCsA5QF0V0qVj3WNB4DZANprrSuApTDTN56eTFmdNSswblyGGTprDZQty8qdXbsyeuh/sStQC4KQLrHnjKAGgNNa67Na6ycAlgOInZPhNQBrtNYXASBWFbT0S8WKHDrXqQO89hrw11+Olui5UYrRsv/8w0R0xYpl+HV0gpBpsKciKAzgksV+cPQxS0oD8FRKBSil9imleltrSCk1SCkVqJQKvHnzpp3EtTF58zIfQ7FiQJs2TFaXzqlbl6UrQ0JoHtqwAbh3z9FSCYLwvNhTESgrx2KH0rgAqAagDYAWAMYopUrHuUnruVprf621v7e3t+0ltRceHkxBERVFZZDaldVsTL163NavD5QpAzx5Avz2m2NlEgTh+bGnIggGUNRivwiAK1au+Utr/VBrfQvANgCV7ShT6tOtG/D554y5rFuXlV7SKTVr8u306MFspIULZxgXiCBkauypCPYCKKWUKq6UygLgVQCx3Yu/AqivlHJRSmUHUBPAcTvK5Bg+/BAYM4ZD6Jo1022FsyxZgJ9/ZsJVLy+ah/76C3jwwNGSCYLwPNhNEWitIwC8BWA92Lmv0FofVUoNVkoNjr7mOIC/ABwCsAfAPK31EXvJ5FA++YTKQCkuODt50tESpZjNm4FFi6gIHj+W1NSCkN5ROp2tgPX399eB6dnWfvQo0Lgxs5Zu3MgIo3RG375mobYiRRhNJInoBCFto5Tap7X2t3ZOVhanNhUqMJfzzZtArVrAqVOOlijZFC8OXLlCS1fnznw7oaGOlkoQhJQiisAR+PuzytnDh0CNGunOTFS8OLcXLnBxWXg4lYEgCOkTUQSO4vvvGY959y4VQ1CQoyVKMoYiOHeOgVD588viMkFIz4gicBTOzvSyNm5Mu0qdOkxNkQ7w8eH2/Hm+jZdf5owgA5ZiEIRMgSgCR5I7NzOVfv45S182a8b9NE6hQsCZM8CgQdzv2hUIC8sQmTQEIVMiisDRKMV1Bv/8w3QULVqk+RAcJyegRAluAa409vYW85AgpFdEEaQVChTgOoPISAboz5njaIkS5O+/geHD+beLC81Dv/8OPHrkWLkEQUg+ogjSEt2704msFDBkCDBqVJqtdHbwIDB9ulmysksX+gjEPCQI6Q9RBGmNAQOYyc3ZGfjsM6BXrzRZA7lCBW6PHuW2USMmXP35Z4eJJAhCChFFkBZp04Y+g/z5gSVLONxOYzYXY0H0woWMgHVxodP4f/+T3EOCkN4QRZBWqVkTuHoVmDmTleLr1wdu33a0VM8oUoQ1dxYuBN56i8d69KC+Wrs2ZW0uXw7kyMH8RYIgpB6iCNIySgFvv01P7L59QOXKrAyTBlCKk5X9+5lPD+BSCB8fHk8Jp08zDPXqVZuJKQhCEhBFkB744gsm/798GahSJU3lc6halaGkISEUr0cP5tK7di35bVWqxO2tW7aVURCEhBFFkB4oUQI4dszM+9y2LTBjhqOlekZEBH0GXbrwFRWVMqfx2bPc3sgYlasFId0giiC9kCsXV2xNn86VXO+9Bwwdyl7Ywbi4AN9+C/z7L/DLL5wl/PRT8tv5+29
uRREIQuoiiiC98e679BN88AEwezbQpAlw546jpcLLLwOlSlG0nj1Znvn//i95beTKxW1UlO3lEwQhfkQRpEdKlqTfoG9fYNs2wM/PDOh3IF5e9BW8+iqdycmdFTg5Ab6+QL9+9pFPEATr2FURKKVaKqVOKqVOK6U+SuC66kqpSKVUF3vKk+H45BPmhL58GahWzeHJfvLkoSIoVAho1YoTluSsKXj0CMiWzX7yCYJgHbspAqWUM4BZAFoBKA+gu1KqfDzXfQ7WNhaSQ5EijN+sU4dO5G7dGNTvIL/B0qXAjh38e/x4LnuYPj3p94eHA0eOAO+8YxfxBEGIB3vOCGoAOK21Pqu1fgJgOYAOVq57G8BqAOIiTAkeHkBAADBsGL22s2ZxOO6AGMzcuQE3N/5dvTpXGk+axLxESWHuXKau2LrVfjKmJzZuBCZPdrQUQmbAnoqgMIBLFvvB0ceeoZQqDKATgARTbSqlBimlApVSgTdv3rS5oOkeV1fg66/pNJ4/n36DKlU4W0hFtm6lPjImJLNn02/QqROtV4lRpAhQuzZw/bp95Uwv9OkDjB7taCmEzIA9FYGycix2Ks0ZAEZqrRPMqqa1nqu19tda+3t7e9tMwAyHuzs9rbNmseetVQtYvDjVHn/4MPDNN2YQU968DCe9eZMZMhKbGSxezGijmzclcghg9FWWLI6WQsgM2FMRBAMoarFfBMCVWNf4A1iulDoPoAuA2UqpjnaUKXPQrRuH1hERwOuvc73Bkyd2f2yePNwaqakBoEYNmjjCw1ma+b33gAkTqCBiM3kySzdHRqZOROzixcBHFiEM4eGp8jElGScnypMGk88KGQ2ttV1eAFwAnAVQHEAWAAcBVEjg+h8AdEms3WrVqmkhCdy6pXXRolq7uGgNaP3ii1oHB9v1kX/+yUft3Bn33I0bWvftq7WzM68BtJ44Ueu9e7V+8IDXFCumtZeX1tWqaX3xol1F1Vqbcmitdbdu/HvNGvs+s1cvradNS9q1hnz//a/WZctqHRmZ8PUREVovWsRtZiQiQuuoKEdLkXYBEKjj6VftNiPQWkcAeAuMBjoOYIXW+qhSarBSarC9nitE4+UFbN7MYjeurkxR8eKLdvXEWpsRGHh7s87OihXmsTFj6FTOmZOplK5eZRtNmjCD6erVXK189ar9TUXBwdxu3mzf5xw8CGzYkLx73niDC/WCghK+buZMTgBT0RqYpihYkLOo+/cdLUn6w8WejWut1wFYF+uYVcew1rqPPWXJlJQsyV5h9mzgwgUu/23ShOsPRo3iqi8b4uXFJjt0AF54ARgxgoXW7t3j2LZ06ZjXr1zJ60+e5GvJEjqKZ8yIa6JxdQWKFmVZZ2uvokXpIkkJERE0CwHApk0payOp+PoCx48n7dps2WjVO3GCZUD/+ou6PD7q1+c2Z87nlzMh3nuP39H27Yz0WrTI5j+lFGHEkTx8aK5SB6gYvv0WGDmS9Z5SwpUr/H1ZtpuRsKsiENII7u6My5w5E2jZkqEou3cz8D+lvacVSpQALl5kp1y5MqNYASoGa7b3qCi6MwyWLmVBtl27GDFTsCBDUAE6Th8/pj779VcgNDSu7dzTM35FUawY27PWETx4YI4ijx/nP32hQsl77zduAPnyJXzN7t1mrYaoKI5e4yM8nAvsPD3px1i3LvH2s2fn1t4+hfPnuZ0/H/jxRyqCtETsGk4ffECF5eLCbO6jRjGoLjkULswU659+yu+kdWubiZsmEEWQmWjWjEPeli1ZDrN0aa5BiD1UTyFKAX/8wb8/+YT/OCdPxm+Nih0meuECkDUrO7wbN7hEwqBLF/7zbdzIt7F4MctjXrxovi5cAC5dAs6d42j17t2Y7bu48B/aUAwGv/1Gc1bVqsCBA8CWLUynbRAezmUZ+fNzZhKbs2c50g8IABo2jP/zsTTtXL5MhRkfhrN89GjmbEpK575gAbdhYdxu2gQ0aBBT5tBQKoyElFBiNGtGZ39ICD+
TtDAbsMR4/waGqfLePZomu3VLviIAqAC/+oq/kTRaSjzFSK6hzISTE/DSSxx6e3vT+F6lCutL2ojB0d6fnDkZSrp1qzlSff11dvAhIUDv3kxSZ0mhQjQveXtTSfTqxWI1gDntN/DxYUdaty7dIP/8ww4pJIR6zehImzXjaHrOHODDD3m9Urze4PXXed+BA1wQ98knzPT95pssG920KZ+1YgVnDNeuxayiZtReSKyympFmGwDOnEn42qxZWaQOoEno5EnOIix9JQEBwPvvs3MHzIJAT57Q19G0KfD55+b1T57we3n//YSfbcm5cwwLDgri9/n4MbBnD89t28bv6dixpLeXGsSeERjjHOO3dO5c8tqzVMJ9+sRsK6MgM4LMSJcu9BcMHkw7TIcOHHpOmJByI2o0b7wBlCnDDnPYMCqCW7c4pe7ShZ08ENecEBoKTJ0KtGtnzggAoEAB4JVXuNgMMHMXWYaXak2xly9nB75uHY9Vrcrjholq+HA6a48c4flHj9iJ37gRc2ZhvHbtiun47tkzpsxubjFH10OGUJFly2a+smThTMTFxVSKrVtTqfz2m3nO2qtuXTrLhw1j+05OwHffUXfnysXZ17RptNnnyMHPcNgwfgezZvGeixdNeY0OcuVK4NAhYMAAKtGEKFMGePqUCvXvv/kdGt+dMeP6/XcquZdeMpW+LTlwgLMao052fAwaRBNQbBk+/pjLa378kfvJTXPu7Mz3fP48zUIAv7v33kteO2ma+MKJ0upLwkdtzKNHjOsEtG7WjGGnNuDcOTY5d67181FRfLTBxYu8/vvvtW7SROtatbTu0UPrr7+Oed/ixaaoBg0aaJ0njxluCWh99KjWnTtrXaaMed20aTx34QL3T5/W+pVXtA4MNK8JDtb6778ZqtmhQ8w2p0zReulSrWfN0vrTT7UeMULroUO1rleP54sV07pFC60bNtS6Rg2tCxXS2tdX65Iltfbx0drVVeusWbXOm1drDw+t3d21dnMzI3xT+lJK69y5+XfBglrXqWPKVLmy1u+/r/WECVpPn651tmxaN25syvvvv1ofPqz1mTNaX7mi9d27Wj95Yn4els/p0UPr7dv5d9Wq5nHj72PHrH/Xly/ze7txI7FfTUz+/psvyzDf52HyZLbzzjspu//4cVOWdu2Sft+NG/xs4wttPXqUbe7ZkzK5kgoSCB+VGUFmx82Nw7yFC2mAr1qVQ00/v+dq1rDBDxpEM1DWrDHPd+rEkM3AQO4bo9Vs2YB69Wi1WryYMwLAdK4aMwLj+pAQmij69uVb8PTkbGH8eIafGvdevWru//030L49R8QBAVx798EHHDkeOsSRXkgI/Q2WlCvHyVNs5s1jsr2mTelABTha9vSkWeLkSR5r3Jjv7cUXaaZ55RWzDa0pZ0QEX4sX0zRlUKkSZRszhiaj+/eZ0G/vXrabNy9nGXfv0haePTs/83PnaBZ7+NBsa8sWbi9eNM1PsXF2ZhsuLpQtMhLYuZNmKoDPM2YJBw7w2FdfcQaRJQtfWbNye+AAZy5DhnA25ORkvpSKuW+8oqLYviWHD1MuJyduY8+gQkMZ+5AjB2cQxrX9+vG30bQp20lu5M+xY5wwV67MfaWYUDGpfPklzWqxfRcGv/7K7cqVDKd2BKIIBP53rl1LL9q1a+ytVq2K+5+YDCydkbGVAMD1AoYSAMzwTTc3duJXrnCaX6IEHbC5cnE63qgRr3v6lFvD1t6hAx2jTZqwY7fMyN26NTsGwy9gtBsQwP2zZ9k5vv12zLUQWbLQ3GGsLYgvpbYhi+V7NqKQjAgbwOyAW7akucxSESjFjsvZmZ+X4W9wcaFiqFCBiiBPHqBNG55btYqKoHJldnJr11JBNm5MZ3fhwqZDOiKCnXi7dnTwGo765cv5zLAw+koAmvHCwtjWvHnm+y5e3PwsgoP5nTk7mzZ0QwnGx3ff8ZVSjJrWycHJyXTsbtzIz/m//6WsLi78XSRknnNx4cBi/35
zNXy9evwe+vaNe6219n79lZ/VlCkMgIh9rRFEcP8+8Oef1mWIjOTgoXBhBibYGlEEAunYkYWGO3fmf07r1hxK9u+f4ibPnIk/V47RGRkjfcsZAWCmsy5RgsrBsOuWL8+Q0sOHuX/hArc+PuZo/dNP6fA9d46d13qLBOeTJnF0N2CAecxYTJYrlxlhExICnDplzmw8Pa3H8AcG0tXy4YcxaykYysHIxmqJry9t/wlh+EDCwznCLVyYsllGWhnK89o1fl2PH9OXEhbGr7NNG7NzdnExbfrffmuG5fr6MvUHwKijrVs5OzK+txo16Fu5epVRSLt301n8zjt09nt6molu58/nWOLJk5ivRYv4nTRrxjQihtPbeBmzoagooEULTlBr1TLTkX/1FX07J09SGfj6cqDw4ovmDCoighnYAf6Eq1aNObsyfCXZs9M99vRpzHvje4WHx40+u3qVUV8hIeygs2XjtZZtGt+/JYklEPzvf/lKiJEjGcBga0QRCCadOnHINnIkh9UDBrDn+c9/UtRciRLxnytQgP8wd+7QwWo5I9ixwxwtly1L5/GpU9w/c4b/iIZiMEbcL7xgtt2nD80knp6s7Nm7N49nycK3Nnw4ZweGqcBSERidubGW4MAByvL993EXav38M6ux/fwz77Wc+vv6shM0FMuBA3Tifvcdz929y47EmIFYcvQoTS65c3PEvWwZ73F2Zgdp8Mcf3DeUg1JUGtev87VgATskY6Wx4fj28jLNL2fOmIrAcO7v3UtHNcD3166dOUOpVcvspE+d4k9k3jyeu3vX+rIUw3mbK5f5LGtERbFj9fWlYjd4912Olr28qCiuXKGp5+uvY95vKIJGjcy/Ac6ELl5kO15eHCAYCvLqVX63r74afxT1N9+YDnuAlfcCA81nWFvJXKoUI4vCw83BwMKFVNixlcaNGxzQlCsXVxEZ17m48Pdr+bnYElEEQkzeeIMKwdOT/x2jR9P4OmmSTQPGDdv/tWv856xXj7btbNnYeV64wJBRHx9GGhkd/5QpVBSG6cfHhx215ZoDrTnFLluWEUT587N9w8bs7GyGXAKmEsqVywzPfPSIppi6dSnj1au8p2BB876dO7mdNImdU+w1BJbx+xcusIM1ynEC7IStKYKzZ2nTb9yYo9f27TkStmYaGTuWncTSpTQ9Zc8es2qpEeoJmIpg5Uqae3LkMMMgHz/mCmaAHXvduuyA9u7lAvX4kv4aCmLSpJiLAw2iosyOMj4bucHdu7w+S5aYIanOzlT0UVE0/f3zD+UPCzOVjGVYbezwUeP5oaG07W/bZp47fx4YN46/r2+/tS6XZeRY3bpUwJ9+ys+lWjXr95QsSUVuaRZ9+tT6osAyZRgx5+fHdh2BrCMQ4pIvH3sxYzgzZQqHgDZM+FO5MlfM/vYb/wEOHYppmilWzPwn8/Zmx/joETuAMmVoPgC4Xb48bvseHlyHsH8/23nwwFQ+GzdS1wHs/MaPZyfr6ckO57ffzE793j2OksuUoYnCEmNEeOgQbfG9epnn1q+n3nzpJe4baw0KFIipCKzRrh3txmvX0i79xRdURBER5nqH2rUZRvroEdC8OTvHtWv51RlKrmzZmJ2Y8ffmzexAL140s6/evGmaLox6Ejdvsvid4WS3hqEIhgwxQ3wNNm9mJ960KZXx2bOUuU4d62sPDPnGjmWop7e3mZfJMAHWrMm2gJhmMkslExpK2SMjuRivXTsenzeP7Vp+JoaSWL+eAwVra0Hc3PhZurvTVJY7NwcbFy9aV+QAn2uskzFCeY3nxl5l/+OPXMrj0PUY8YUTpdWXhI+mIjt3MubRiJl75x2bp3ds355N16/PMMebN+NeExCg9XvvaX3/vtatW2tdqpTWGzdqHRISM9TRGvfva/34sdbr1jHTqdZaz5/PZ547F//bmTGDYaCVK5tv/403Yl7z+LF5rmxZrcPDtZ49W+vbt5kx1Dj34IHW48YxzPPpU76uXo3/2VeumBl
ZjTZ++knr3r21LlyYL+P40KH8LLp2ZZiq1lqvWsVzbdow26vxnKlTzftq1eJzDA4eNMN3jSS1Rrjk0qVxZTTa+fdf3rt+PeWwpF8/XrNwIcODly0z7/vxx7ht7t5tnq9encfOnjWP1avHY3/8wf1du8x7w8IY0lu0KD8HQOu33uL2//6Pv6vy5c22jN/NihXcN37mb75ptnn3bkz5pk/XesMGfkaWYbWrVvH7NggPN8+dOMHPP2tWrT/8UOu//tL69dd5jUGJEry2YcO4n4ktQQLhow7v2JP7EkWQyly7xoB6I3/0pEk2aTYqih15iRJcN2B0zufPJ3xf/fpaOznx2mnTKNaYMcl7ttGR7N5NBfOf/8Q8/9dfXH9QrBjXHEydSjmbN2eM/uLFvG76dK2/+IIZvgGuDQC07tlT6y+/NDuD06epRPLlS5p8xYuzDa3NNtau5bqF2GsIKlTgtlIlrl3Qmh38L79Q+QBa37tntj1sWMxO/PXX2ZFv2cJjmzdTmdWoYXbMf/wRV8ZWrXju8mXu166tddOmMa+5do3XvPSS1qtXm8oxvu94506ty5WjUgOoXAYNMjtbg8BA8zOxhp8fzxvrBv78k8fLlTPfu7Gmwfjd5crF9RzDh3MNSY8eWlesaL39n3+O+R00aqR1jhymwj12zDy3ciXXB2zaRAU7YYLWWbJQcWnNe7Jn57V+ftafZysSUgRiGhISJn9+Rg95etJLOHo04y9tQMGCNBeULWvada1F2QCc+l++TBNPlSo0IQ0fzul/fNPz+DBMRG3a0Ok3eTKdkAaDBjFCpVQpmqjef5+O4mvX6Bf47TeaTz74gM5uw05tRJe89hpNSgbXrrEdwzkN0IE7ZYp1+aw5kXPkYCQMEDM6qXx5bo8c4Vc0cyYXiHfowOsbN6ZZae9eflaWzlxvb0b0rFtnmi3y5KG5ZM8e3gdYj7tftIi2dsPmnScP7e+GCQfgWgMnJ5qIevWi+8nFxQxpPXo0Zmnt2rVpHjHqNLdoYf7ULB315crx/RhmN4AmnYsXac3s25fHDDNjq1b0ER0/zs+oalXTBGSYhh494u8iJIRO5SVL+Jleu8aoMMt4idy5uf3iC27Dwmi6NPxYJUqYDvo//qA5KVs2Rn6dOcPPff58ZnANDTXNWqlRjCle4tMQafUlMwIHcf8+59OtW3NIHt9wLBl4enIk9O235gjKcvRqScmSWnfpwpWm27dzqr9hA0e9164l77nBwebzGjTQz0w7BlWq8NjgweaxunXNe3LlMkd9nTrFHaX/+2/Mkffq1XFleP11mnhiExHBe8aN437WrNw3iv2EhNCs5ObG41OmmM/p399cJL5+vdnmggU8liePeR/A0WiJEhyBG+aOixe1/uEH/v3559weOpT4Z9qrl9nukSOcYTRtan7HxuvTT7UeMoQjb0DrV1+13l6NGub30KSJ1qdOJfz8PXvMZ/TqRTOcMXO0fA0dGvO+p09jflctW2pdurTW+fNzf8kSziQ6dzbvefhQ65Mntd66NWbb//yj9Z07nIFcuaKfmbgArefM4aymbl3+5jp1ohnr5Enzuo4dY8oWEmJbSyxkRiA8Nzlzcog1ZAi9uz160BP7HJQqRUfn0KGm89dytGtJ7docgTZpwtGdsbJ1/HjTeZhULCM3WrdmDpsffjCPaW3KZzBqlLlg7P59lngAGOZpRNoadOnCdQIvvhgzNNYSX1/OcK5dixnBZMwqjBmBkUguRw5uPT05cq1dm/tGOObnn9MZasxOLB3XGzdyGxLCr+3dd/l1KkUn+KlTjLw6doyzNGPG5ONDR3xSQha9vMy///c/jnw3bqRju1gxzgSMz2b/ftPB7+9PZ/vu3XSqNmrEz//ff80Z08SJjMKxZNkyjqqNsGLLldMXL/K9WEZtGTMK43M14h5cXDgDad+e+ydP0tH7zjv8rNevZ0SYZdRU9uyMNoodJXb2LGePrVqZOYmOHOF29Wq+jzNn+N23bMm
V69u38/ykSWaKcoB/58ljPRDCHthVESilWiqlTiqlTiulPrJyvodS6lD0mXe5zgAAFsRJREFUa6dSqrI95RGek5AQZinz9GRP3K4de7MUUqCAGfmxbBmbspbmGeA/3Y0bNH1YZvFMCa6uXEVbrBjXFRw+HDPVgrGYyjDFADQjVanC4i+urgw19PCgIrt92+yoAf6Dt2zJ3Pe3bjERXrFitLAZGJFDBQuaqQuAmCYagM9bvz7umozNm9lhNm3Ka43PxFAExjFfX4aWGh3xuXM0zTRowP1SpdiZurvT5OLiYioCZ2cqiKQUujHkLVuWK2kt1yzcu2eu8zhzhseuRFcvr1GDMfrvv8/O+/BhM0rZMGNdv24qZ4MhQ6h8S5dmWKZl1NDWrfyZPn5smrUKFqSMn3zClBMdoyujL1nC35SRTsQYkJQtS9PT4sWU3yj6A/BZI0bE/QzOnOGahfr1+RsqU8Y0ORUuzPccEsLvpFw53lOoEJ9trNu4f5/v1Xjevn3mM60tUrMVdlMESilnALMAtAJQHkB3pVT5WJedA9BQa10JwEQAtjE+C/YhTx4O0wIC2LPdu8ehlOVwLBns28cVo1qzA0qoGEydOty++65tSjE+fBiz87akShX+8zZubB67cIEj2dKl6ZsAaMPOn58jyNu3mQcoKoofjWVo6N9/85+9QgXzmGWaAEvF5unJvDxGzplZs7iAKaH6QfPnU1n8/rupCDw8GM5prD2YMYPbBw8YkmosxqpQgdf9+KO5yKpQISre4GAzLUZi9O/P0e9rr3E0b4zI587lz8RQJj/9FHP2UKiQqYxi+0aM99ypU0xfAhDTB3P9uumbMOz3RtoGw19QqJD5M/Xw4OcVGsrZy+efc+ZSp44ZslyyJMODDR+MZbaVqChzZbux+vn4ca75uHKFszEnJ4Ym16vH95cnDzvzsDD+fgxle+sWP/8dO6gsq1bls/Lm5XdjrPNYsIC/jeRmTk0y8dmMnvcFoDaA9Rb7owCMSuB6TwCXE2tXfAQOJjKSYToeHjRyOjnRuJlYZXUrHD5sRuAk5bGGLXbq1GQ/Kg6lSzNKxBr/939a798f89ibb/LZo0dzPyqK/oxRo3h87FjzWiO75zvv0G5syP30qXnNzZsx7cvxFZyvVCnh81qbmV7nz2dmVCNSx+DxY25Pn9Z6zRrTj2FJ5860hVvy7rta58wZ/3OtcewYfQsdO2pdoIDWI0fyeVeuMDz4xg1Gahnv++FDRl4B9AvUrGm2FR6u9YABPGdE2RhYZj/19zf/9vWN+bmuXq3122/z/QH0KxkZVCtV0rpaNWaDNfw906bRbxEayufMmhXXjxEVZbbfuzf9CVpr/dprPHb2rHntzp30JXz6Kc+FhPD4/ftae3vzc/rss5hZVocP5zXt2zMqTGvK5uOTvO8iNnCQj6AwAMv8jcHRx+KjP4A/rZ1QSg1SSgUqpQJvxq5QIqQuTk6cX69YwVU7/fpxiJmCNBQVK8a0ZSf22OnT+bctavJu2sSRvDVKlYppFgLM2YORo0cpmh2MlaPh4fQj/Oc/5mj/+nWaBACOAF0s1vF7edE2/vbbHJkaZSBu3eJI0ljUdeiQ+bz42LuX29y5mbG0ePGYsyvD1OXra464DROIwZ07cSOV7t0zR9hJpVw55l0qX572f2MldpYsXD3r7U05XniBEVjZs5u+mD17Ys4WjGp1zs5xo8k2b2ZqCMD0wXTpwhkJwIiutWs5Sp85k4vUAH6+9eox3cOhQ5yVVq7MmU/Hjhytjx5tft9vvkmzpSWW30Xu3Oas4ckTtlW8OPdHjeJzGjQwZ4NGUZycOTm6z5OHKSwMP0yWLIyuAjgrOXuWM5CAgJgzVJsTn4Z43heArgDmWez3AvBNPNc2BnAcgFdi7cqMIA1hhDUMHsyhzIIFdn1cQguc7M2kSXy25UIgrTmrAcwRatGirMEAcEGV1lp/9x1zzicFI4Lq+nXuG3UWEqJbN17z00/c37c
v/toARgy+0WZUFCNlgJg59jt25LHy5ZMmt0FUFEfEV69yf8YMtrNmTfz3HDnCa5ydtf74Y/P406cxZY2Nsc7hp58YJTR+PJ8LcCRvyd69PP6//3H/yRPG8xszAa25cBHg6DwxrMn15EnMWZ8RaXbnDveNmZklrVtzVqI1I7YsZ34XLnB2evs225k2LXG5EpbZMTOCYACWVVmLALgS+yKlVCUA8wB00FonI8u34HA8PTk8KlmS3q1Bg8zcznbAmAwao+XUxEjFHNtGazi3jZGvmxswcCCjn4wR9uDB5qgxNo8esf6BkeffcLIaUSdBQWbUT3y8+y639eszsuqDD+KPvorta1CKo+5ChWLawUuVoi3dMlY/KWhNP8rMmdw3IrRi5+9ftMhMH1GqFN/jjRuMrDFIrK5yyZJ02hcuzJH5N9/wfU+dylG/JUY2W2PdgKsrv1NfX9OhbNjtrdWciI2fn+lwNjDSShsYEUNG5tPYmXjHjOEaDuO5RYvGLBBYrBg/GyM9iWWeK5sTn4Z43heY0O4sgOIAsgA4CKBCrGuKATgNoE5S25UZQRrj2jUaklu0YDB+njwcxtiBu3cZK3/7tl2aT5AJEzgqO3485nFjhN29O7d+fhwZAlpPnJh4u5GRXCswYgT3U2KXt2TpUj7755+tnzdGzF5eKX9GYhQpwmd89hlnCBs2xHQhBQXxfOnSibcFcLSfEJbV6YYNo80/NkZ6jd9+i3vOGIXfvctrWrVKXK47d+LODmNjrDg3/A2xMfwW/ftbP//oESv0rVnDVdLP+28FR8wItNYRAN4CsB40+6zQWh9VSg1WSkWXOMdYAF4AZiulgpRSgfE0J6RV8ufnEHT9eg6ltGa8pGVmLxuROzdj5ZO7ktgWjB5N233ZsjGPv/gis3ka9Yzd3GIWlUkMIxvpqVP8CGfMMGcDKcFYhxDfLKJAAbafWK3i58EYuYaEcMbRrJn10b1lUfj4KFAg5hqN2Ozbx+eMG8cRfmgo/QCxQy3fe4/RW23bmsdWrmR4rDEKz52bcf0//5y4XB4e1gsuWfL77xzxxxedZsyWjBlBbFxcGFYbGEh/g+W6Fltj13UEWut1WuvSWmtfrfWk6GNztNZzov8eoLX21FpXiX4lkK1cSLO89Rbn159/ThvI+fP03Nkz8DmVcXZmXHhslOJbNf7Zc+Wi+eXuXTpNk0KtWnReFypEk0bz5imX01AiCdWCaNIk5toFW2M4mONT2EbK8PjKZFpy/37ckqGWGAvqChfm524UIYq9vMXJKWaKD4CKfcUKVnozqFfPNsEIABViq1bxnzcc+tbWJABUBD4+NGtdvBh3LYUtkXoEwvPj6ckFAYsW8b+ob18O43r35momG9YxSKuUKsXVoUZBneRE2/TrxzjxwEBzpWlKMZ4fOyrIEssynvbAsIVbRgBZ8sILVHxJUQRhYcCuXYlfZ4yqS5emErCsTxEfhnzr1yf8edkLQ+b79+OXt2RJ+o/KlYu5At3WiCIQbIOPD+fn58+bQ6/lyzmMHjcuwyuDQoVSXMgNdeowvPA5Fmk/QynWE3IkbdrQJBKfIgCS7oQOCUmaic3oVFetokJNiiKoWJFby+p2qYkxY7p92yyJGhvLFej2/BeSXEOCbYmMZHD44Gg30IQJLF0VX+V3AUoBa9Yw5j0jUL8+zU+26GA9PRM21RgdpbHNkyfpprUGDRjkNnLkc4mYYjp1YjqMKlXiv8bIsZSYP+J5Udqehic74O/vrwMDxaecpomK4uqaF14wS2EVLUofQrduGX52IKQer77K9BAnTzpaEvsQGkqzUO3a9Gc8D0qpffH5YWVGINgeJyeGz/zwA30EW7cywPvVV5llbMmSuPX6BCEFvP22ueI8I+LuzhXedl1DAPERCPbEMmSiZUsO2y5dYqzlsGGcG3ftSoNxfGlHBSEBjKydGRWtuWiudGn7PkdMQ0Lq8OABvYiWYTFZsnBmkDMnDbaNGnFbqVL8pcoEQUg
RCZmGZEYgpA45czJT2MaNXD115w7DNlxdafxcvZp1/QCalkqUYIB9tWoMG2nenMeSEkIiCEKykP8qIfVwcaGJqGXLmMdDQ5nA3smJ27AwJmK/c4cJ7A2cnRkeUrgwt76+dEIXLcqk7oUL2z+8QhAyIGIaEtIWWjPK6NAhKoyQEOC//wW++oorb1xduWI5e/aYZakM3N0ZQ+jjQ+WQJw9zHxcsyG2+fNx6eXHVl2WWL0HIwCRkGhJFIKQPtKbpaNIkJozp14/mpt9+47GbN6k0jNSkFSowUf3ly9aLBhtkzWoqg1y5GLheoQJnGR4eVDj58/O4hwefabzc3enLkHBYIR0gikDIHGhNp/T165wNZMvGPM6rVrGG4I4dnEU8esREeQ8fshZkWBiTBYWG0nmdKxfXQiRlTb9SdHrnyMHZipsbn+3tbRbK8vCgIila1ExG5O5OJZM9O+/ZsYP3d+1qHhMFI9gQUQSCYElYmFncNzZXr1JRlCjBNJbffMPZxu3bNFmdOcMlwH5+LFS7eDHbs0ylWb48ZxlXr8YttpscnJ2pHLJk4cwlWzaatLy8qKguXqRiefyYkVZZsnARX86cVGgXLlC22rUpX8GCLIbr5sb24tuKQz5DIopAEOxJRAQjoQoUYAft7s4O9f59Lnt98IDKZO9eOsT79OFof8kS4J9/2HEXLcqUpS4uTKl57x4wezbNWkZea4DZ7fLm5fnjx+2TktIwlTk58W/j5eFBZWIoORcXKirjlS8fZ0MATXJublS4xgpzFxcqM6Uoe4ECdPhnzcrPzcOD7fzf/5npXnPm5L1Zs1LRZcnCz9vJif4fmTUlGVEEgpCeiYqiMrlzhzMAI1tZVBSVh1JMfp83LxP9aM1E/eHh7DBXrqSfw9mZazXCw6mErl6lknn8mDOI/Pl5/+PHXBH+4AE73YgIPitvXhZkCAszS379f3t3FyPVXcZx/Ptj3QUKLSsvWoIEKEEMNloIL22oTS9qpcR01ZtWTFpfklpT6ktiCNrE1Du00YQrG4yN1WC5sNZy0UgbCxK7oYDI8iJFqEACRZamgC10l93l8eL5T3Y67Cy72xnOnD3PJzmZmXPOzP6f+WfPM/8zZ55/udZWP/h3dfVPjFBvki/NzZ5kxozxv93c7PGOGePLlCl+es7Mf9TY2+sJu7XVLz6YPduTzsWLnrSbmjxxlZb5833/Eye8H8pHaH19/ZdCnzvn70950uruhiVL/Dlvv+3v+8yZfjqxpcXfs0mT/O+cOOHJ9qab/P2fPNm3l9ozbtzVU50N+a2KRBBCqCWz/ppSpWXCBP/03t3tn+q7u33+yfZ2P93W1uaFczZt8v3On/eRlJkfCO+/31/n+HGf5f3YMX985IgfhFet8sdbt/pzOzv9wH35sp8amz/fD+qvvuoXDvT1+XLlim+7+WY/yHd0+EH70iXfBl7draXFTwGeOXN1vDfc0B9v6TlZWLPGa3aNQGaJQNIKYD3QhE9kv65iu9L2lcAl4Otmtmew14xEEEKoiZ4eT1izZvVP5vz++76Utvf2+u2MGZ48zHwS4a6u/u+Axo+HhQt9v0OHfMRRGmVJ/ql+7lzf3t7uI5RTp3x7T48n0OXLPWnt3++Jr6vLP/1fuOAjmWXLfPvixf7r+xHIJBFIagL+DXwen8h+F/BVM/tX2T4rgcfxRLAMWG9mg05XEYkghBCGL6vqo0uBo2b2HzO7DGwC2ir2aQN+l+ZW3gG0Sqpznb0QQgjl6pkIZgDls42eTOuGuw+SHpG0W9Lus2fP1ryhIYRQZPVMBANd11V5Hmoo+2BmG8xssZktnjZtWk0aF0IIwdUzEZwEZpY9/gTw1gj2CSGEUEf1TAS7gHmS5khqAR4ENlfssxl4SO524IKZna5jm0IIIVSo22/JzaxX0mpgC3756DNmdlDSo2n708BL+BVDR/HLR79Rr/aEEEIYWF2LipjZS/jBvnzd02X3DXisnm0IIYQwuJi8PoQQCi5
3JSYknQVOjPDpU4EPUQ6y4YymeEZTLDC64olYGtdw4pllZgNedpm7RPBhSNpd7Zd1eTSa4hlNscDoiidiaVy1iidODYUQQsFFIgghhIIrWiLYkHUDamw0xTOaYoHRFU/E0rhqEk+hviMIIYRwtaKNCEIIIVSIRBBCCAVXmEQgaYWkw5KOSlqbdXuGS9JxSfsl7ZW0O62bLOkVSUfS7Uezbmc1kp6R1CnpQNm6qu2X9KPUV4clfSGbVg+sSixPSjqV+mdvmnSptK2RY5kpaaukQ5IOSvpeWp/XvqkWT+76R9I4STsldaRYfprW175vzGzUL3itozeBW4AWoANYkHW7hhnDcWBqxbqfA2vT/bXAz7Ju5yDtvwtYBBy4VvuBBamPxgJzUt81ZR3DNWJ5EvjhAPs2eizTgUXp/o34rIILctw31eLJXf/gZfonpvvNwOvA7fXom6KMCIYyW1oetQHPpvvPAl/KsC2DMrPtwDsVq6u1vw3YZGbdZnYML0q49Lo0dAiqxFJNo8dy2tI84Wb2LnAInxwqr31TLZ5qGjYec++lh81pMerQN0VJBEOaCa3BGfCypH9IeiSt+7ilst3p9mOZtW5kqrU/r/21WtK+dOqoNFzPTSySZgML8U+eue+binggh/0jqUnSXqATeMXM6tI3RUkEQ5oJrcEtN7NFwH3AY5LuyrpBdZTH/voVMBe4DTgN/CKtz0UskiYCzwPfN7P/DbbrAOvyEE8u+8fM+szsNnzSrqWSbh1k9xHHUpREkPuZ0MzsrXTbCbyAD/nOSJoOkG47s2vhiFRrf+76y8zOpH/aK8Cv6R+SN3wskprxg+ZGM/tTWp3bvhkonjz3D4CZnQe2ASuoQ98UJREMZba0hiVpgqQbS/eBe4EDeAwPp90eBl7MpoUjVq39m4EHJY2VNAeYB+zMoH1DVvrHTL6M9w80eCySBPwGOGRmvyzblMu+qRZPHvtH0jRJren+eOAe4A3q0TdZfzN+Hb+BX4lfQfAm8ETW7Rlm22/BrwboAA6W2g9MAf4KHEm3k7Nu6yAxPIcPyXvwTy7fGqz9wBOprw4D92Xd/iHE8ntgP7Av/UNOz0ksd+KnD/YBe9OyMsd9Uy2e3PUP8Bngn6nNB4CfpPU175soMRFCCAVXlFNDIYQQqohEEEIIBReJIIQQCi4SQQghFFwkghBCKLhIBKGwJLWn29mSVtX4tX880N8KoRHF5aOh8CTdjVem/OIwntNkZn2DbH/PzCbWon0h1FuMCEJhSSpVdlwHfC7Vqf9BKvT1lKRdqUjZt9P+d6da93/Af5yEpD+nQoAHS8UAJa0DxqfX21j+t+SeknRAPr/EA2WvvU3SHyW9IWlj+pVsCHX3kawbEEIDWEvZiCAd0C+Y2RJJY4HXJL2c9l0K3Gpe5hfgm2b2TioBsEvS82a2VtJq82Jhlb6CFz77LDA1PWd72rYQ+DReH+Y1YDnw99qHG8IHxYgghKvdCzyUyv++jv+kf17atrMsCQB8V1IHsAMv+DWPwd0JPGdeAO0M8DdgSdlrnzQvjLYXmF2TaEK4hhgRhHA1AY+b2ZYPrPTvEi5WPL4HuMPMLknaBowbwmtX0112v4/4/wzXSYwIQoB38WkNS7YA30nljJH0yVT1tdIk4FxKAp/CpxEs6Sk9v8J24IH0PcQ0fNrLhqh2GYorPnGE4NUde9Mpnt8C6/HTMnvSF7ZnGXga0L8Aj0rah1d73FG2bQOwT9IeM/ta2foXgDvwSrIGrDGz/6ZEEkIm4vLREEIouDg1FEIIBReJIIQQCi4SQQghFFwkghBCKLhIBCGEUHCRCEIIoeAiEYQQQsH9HybIIEsbut9xAAAAAElFTkSuQmCC\n", 578 | "text/plain": [ 579 | "
" 580 | ] 581 | }, 582 | "metadata": { 583 | "needs_background": "light" 584 | }, 585 | "output_type": "display_data" 586 | } 587 | ], 588 | "source": [ 589 | "train_loss = {\"aux_loss\":[], \"target_loss\":[], \"final_loss\":[]}\n", 590 | "test_loss = {\"aux_loss\":[], \"target_loss\":[], \"final_loss\":[]}\n", 591 | "for epoch in range(epochs):\n", 592 | " for i in range(int(len(train_data) / batch)):\n", 593 | " record_test_loss(test_loss, test_data, i)\n", 594 | " label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch)\n", 595 | " user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n", 596 | " aux_loss, target_loss, final_loss = train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n", 597 | " #Record_loss12\n", 598 | " loss_dict = dict()\n", 599 | " loss_dict[\"aux_loss\"] = str(aux_loss)\n", 600 | " loss_dict[\"target_loss\"] = str(target_loss)\n", 601 | " loss_dict[\"final_loss\"] = str(final_loss)\n", 602 | " utils.add_loss(loss_dict, loss_file_name, level=\"train\")\n", 603 | " train_loss[\"aux_loss\"].append(float(aux_loss))\n", 604 | " train_loss[\"target_loss\"].append(float(target_loss))\n", 605 | " train_loss[\"final_loss\"].append(float(final_loss))\n", 606 | " get_loss_fig(train_loss, test_loss)\n", 607 | " tf.summary.trace_on(graph=True, profiler=True)\n", 608 | " with train_summary_writer.as_default():\n", 
609 | " tf.summary.scalar(\"train_aux_loss epoch: \"+str(epoch), aux_loss, step = i)\n", 610 | " tf.summary.scalar(\"train_target_loss epoch: \"+str(epoch), target_loss, step = i)\n", 611 | " tf.summary.scalar(\"train_final_loss epoch: \"+str(epoch), final_loss, step = i)\n", 612 | " tf.summary.trace_export(\n", 613 | " name=\"DIEN\", \n", 614 | " step=i, \n", 615 | " profiler_outdir=log_path)\n", 616 | " model.save_weights(checkpoint_path.format(epoch=epoch))" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "# 模型评估" 624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": 24, 629 | "metadata": {}, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "./checkpoint/cp-0002.ckpt\n" 636 | ] 637 | }, 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "" 642 | ] 643 | }, 644 | "execution_count": 24, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | } 648 | ], 649 | "source": [ 650 | "last_model = DIEN(embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation=\"dice\")\n", 651 | "latest = tf.train.latest_checkpoint(checkpoint_dir)\n", 652 | "print(latest)\n", 653 | "last_model.load_weights(latest)" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 26, 659 | "metadata": {}, 660 | "outputs": [ 661 | { 662 | "name": "stdout", 663 | "output_type": "stream", 664 | "text": [ 665 | "WARNING:tensorflow:Layer dien_1 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n", 666 | "\n", 667 | "If you intended to run this layer in float32, you can safely ignore this warning. 
If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n", 668 | "\n", 669 | "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n", 670 | "\n" 671 | ] 672 | }, 673 | { 674 | "data": { 675 | "text/plain": [ 676 | "(0.029646765, 0.26222047, 0.29186723)" 677 | ] 678 | }, 679 | "execution_count": 26, 680 | "metadata": {}, 681 | "output_type": "execute_result" 682 | } 683 | ], 684 | "source": [ 685 | "model= last_model\n", 686 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, clk_length, show_length = data_reader.get_test_data(test_data)\n", 687 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n", 688 | "aux_loss, target_loss, final_loss = get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n", 689 | "aux_loss, target_loss, final_loss" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 27, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "def convert_tensor(data):\n", 699 | " return tf.convert_to_tensor(data)\n", 700 | "\n", 701 | "def get_normal_data(data, col):\n", 702 | " return data[col].values\n", 703 | "\n", 704 | "def get_sequence_data(data, col):\n", 705 | " rst = 
[]\n", 706 | " max_length = 0\n", 707 | " for i in data[col].values:\n", 708 | " temp = len(list(map(eval,i[1:-1].split(\",\"))))\n", 709 | " if temp > max_length:\n", 710 | " max_length = temp\n", 711 | "\n", 712 | " for i in data[col].values:\n", 713 | " temp = list(map(eval,i[1:-1].split(\",\")))\n", 714 | " padding = np.zeros(max_length - len(temp))\n", 715 | " rst.append(list(np.append(np.array(temp), padding)))\n", 716 | " return rst\n", 717 | "\n", 718 | "def get_evaluate_data(data):\n", 719 | " batch_data = data\n", 720 | " click = get_normal_data(batch_data, \"guide_dien_final_train_data.clk\")\n", 721 | " target_cate = get_normal_data(batch_data, \"guide_dien_final_train_data.cate_id\")\n", 722 | " target_brand = get_normal_data(batch_data, \"guide_dien_final_train_data.brand\")\n", 723 | " cms_segid = get_normal_data(batch_data, \"guide_dien_final_train_data.cms_segid\")\n", 724 | " cms_group = get_normal_data(batch_data, \"guide_dien_final_train_data.cms_group_id\")\n", 725 | " gender = get_normal_data(batch_data, \"guide_dien_final_train_data.final_gender_code\")\n", 726 | " age = get_normal_data(batch_data, \"guide_dien_final_train_data.age_level\")\n", 727 | " pvalue = get_normal_data(batch_data, \"guide_dien_final_train_data.pvalue_level\")\n", 728 | " shopping = get_normal_data(batch_data, \"guide_dien_final_train_data.shopping_level\")\n", 729 | " occupation = get_normal_data(batch_data, \"guide_dien_final_train_data.occupation\")\n", 730 | " user_class_level = get_normal_data(batch_data, \"guide_dien_final_train_data.new_user_class_level\")\n", 731 | " hist_brand_behavior_clk = get_sequence_data(batch_data, \"guide_dien_final_train_data.click_brand\")\n", 732 | " hist_cate_behavior_clk = get_sequence_data(batch_data, \"guide_dien_final_train_data.click_cate\")\n", 733 | " hist_brand_behavior_show = get_sequence_data(batch_data, \"guide_dien_final_train_data.show_brand\")\n", 734 | " hist_cate_behavior_show = get_sequence_data(batch_data, 
\"guide_dien_final_train_data.show_cate\")\n", 735 | " return tf.one_hot(click, 2), convert_tensor(target_cate), convert_tensor(target_brand), convert_tensor(cms_segid), convert_tensor(cms_group), convert_tensor(gender), convert_tensor(age), convert_tensor(pvalue), convert_tensor(shopping), convert_tensor(occupation), convert_tensor(user_class_level), convert_tensor(hist_brand_behavior_clk), convert_tensor(hist_cate_behavior_clk), convert_tensor(hist_brand_behavior_show), convert_tensor(hist_cate_behavior_show)" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 29, 741 | "metadata": {}, 742 | "outputs": [], 743 | "source": [ 744 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show = get_evaluate_data(test_data)\n", 745 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n", 746 | "output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": 30, 752 | "metadata": {}, 753 | "outputs": [ 754 | { 755 | "name": "stdout", 756 | "output_type": "stream", 757 | "text": [ 758 | "[训练集]正例:负例=501 : 9435\n", 759 | "[测试集]正例:负例=56 : 943\n" 760 | ] 761 | } 762 | ], 763 | "source": [ 764 | "train_label = train_data[\"guide_dien_final_train_data.clk\"].values\n", 765 | "positive_num = len(train_label[train_label == 1])\n", 766 | "negative_num = len(train_label[train_label == 0])\n", 767 | "print(\"[训练集]正例:负例=%d : %d\" % 
(positive_num, negative_num))\n", 768 | "test_label = test_data[\"guide_dien_final_train_data.clk\"].values\n", 769 | "positive_num = len(test_label[test_label == 1])\n", 770 | "negative_num = len(test_label[test_label == 0])\n", 771 | "print(\"[测试集]正例:负例=%d : %d\" % (positive_num, negative_num))" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 31, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "y_true = label.numpy()[:,-1]\n", 781 | "y_score = output.numpy()[:,-1]" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": 48, 787 | "metadata": {}, 788 | "outputs": [], 789 | "source": [ 790 | "threshold = 0.0031\n", 791 | "y_pre = y_score.copy()\n", 792 | "y_pre[y_pre > threshold] = 1\n", 793 | "y_pre[y_pre <= threshold] = 0" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 34, 799 | "metadata": {}, 800 | "outputs": [], 801 | "source": [ 802 | "import numpy as np\n", 803 | "from sklearn.metrics import accuracy_score\n", 804 | "from sklearn.metrics import f1_score\n", 805 | "from sklearn.metrics import auc\n", 806 | "import sklearn.metrics as sm\n", 807 | "from sklearn.metrics import roc_curve, auc\n", 808 | "import matplotlib as mpl \n", 809 | "import matplotlib.pyplot as plt" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": 50, 815 | "metadata": {}, 816 | "outputs": [ 817 | { 818 | "name": "stdout", 819 | "output_type": "stream", 820 | "text": [ 821 | "0.8818818818818819\n" 822 | ] 823 | } 824 | ], 825 | "source": [ 826 | "print(accuracy_score(y_true, y_pre))" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 51, 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "name": "stdout", 836 | "output_type": "stream", 837 | "text": [ 838 | "混淆矩阵为:\n", 839 | "[[876 67]\n", 840 | " [ 51 5]]\n" 841 | ] 842 | } 843 | ], 844 | "source": [ 845 | "m = sm.confusion_matrix(y_true, y_pre)\n", 846 | "print('混淆矩阵为:', m, 
sep='\\n')" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": 52, 852 | "metadata": {}, 853 | "outputs": [ 854 | { 855 | "name": "stdout", 856 | "output_type": "stream", 857 | "text": [ 858 | "分类报告为:\n", 859 | " precision recall f1-score support\n", 860 | "\n", 861 | " 0.0 0.94 0.93 0.94 943\n", 862 | " 1.0 0.07 0.09 0.08 56\n", 863 | "\n", 864 | " accuracy 0.88 999\n", 865 | " macro avg 0.51 0.51 0.51 999\n", 866 | "weighted avg 0.90 0.88 0.89 999\n", 867 | "\n" 868 | ] 869 | } 870 | ], 871 | "source": [ 872 | "r = sm.classification_report(y_true, y_pre)\n", 873 | "print('分类报告为:', r, sep='\\n')" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": 53, 879 | "metadata": {}, 880 | "outputs": [ 881 | { 882 | "data": { 883 | "text/plain": [ 884 | "0.679821239206181" 885 | ] 886 | }, 887 | "execution_count": 53, 888 | "metadata": {}, 889 | "output_type": "execute_result" 890 | } 891 | ], 892 | "source": [ 893 | "from sklearn.metrics import roc_auc_score\n", 894 | "auc_score = roc_auc_score(y_true,y_score)\n", 895 | "auc_score" 896 | ] 897 | }, 898 | { 899 | "cell_type": "code", 900 | "execution_count": 54, 901 | "metadata": {}, 902 | "outputs": [], 903 | "source": [ 904 | "def plot_roc(labels, predict_prob):\n", 905 | " false_positive_rate,true_positive_rate,thresholds=roc_curve(labels, predict_prob)\n", 906 | " roc_auc=auc(false_positive_rate, true_positive_rate)\n", 907 | " plt.title('ROC')\n", 908 | " plt.plot(false_positive_rate, true_positive_rate,'b',label='AUC = %0.4f'% roc_auc)\n", 909 | " plt.legend(loc='lower right')\n", 910 | " plt.plot([0,1],[0,1],'r--')\n", 911 | " plt.ylabel('TPR')\n", 912 | " plt.xlabel('FPR')\n", 913 | " plt.show()" 914 | ] 915 | }, 916 | { 917 | "cell_type": "code", 918 | "execution_count": 55, 919 | "metadata": {}, 920 | "outputs": [ 921 | { 922 | "data": { 923 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3de5xV8/7H8dendFVIJemioXtUNLq4K5JcyuV0ooMcdHI/+HUUJ0KOhOOSlOR6OIWi4nRBJ+RacSpdlCQZJU3RfaqZ+f7++M7UNGamPbXXXvvyfj4e89iz1v7uvT/LZH32+n6/6/M15xwiIpK6yoQdgIiIhEuJQEQkxSkRiIikOCUCEZEUp0QgIpLilAhERFKcEoGISIpTIhApgZmtMLNtZrbZzH42sxfNrEqB5080s/+a2SYz22Bmb5tZ80LvcZCZPW5mK/PeZ1nedo3YH5HI7ykRiOzd+c65KkBr4DhgAICZdQDeBSYCRwBpwDzgEzM7Kq9NeWA60ALoAhwEnAisA9rG9jBEima6s1ikeGa2ArjGOfd+3vZQoIVz7lwzmwl87Zy7vtBrpgBrnXNXmNk1wAPA0c65zTEOXyQiuiIQiZCZ1QXOAZaZWWX8N/s3imj6OnBW3u9nAlOVBCSeKRGI7N0EM9sE/Aj8AtwDHIr//2d1Ee1XA/n9/9WLaSMSN5QIRPauu3OuKnA60BR/kv8VyAVqF9G+NpCZ9/u6YtqIxA0lApEIOec+BF4EHnHObQE+A/5QRNMe+AFigPeBs83swJgEKbIPlAhESudx4Cwzaw30B640s5vNrKqZVTOzwUAH4N689v/CdymNN7OmZlbGzKqb2Z1m1jWcQxDZkxKBSCk459YCLwMDnXMfA2cDF+HHAX7ATy892Tn3bV777fgB42+A94CNwCx899IXMT8AkSJo+qiISIrTFYGISIpTIhARSXFKBCIiKU6JQEQkxR0QdgClVaNGDdegQYOwwxARSShffvllpnOuZlHPJVwiaNCgAXPmzAk7DBGRhGJmPxT3nLqGRERSnBKBiEiKUyIQEUlxSgQiIilOiUBEJMUFlgjM7Hkz+8XMFhTzvJnZk3kLec83s+ODikVERIoX5BXBi/jFuotzDtAo76cPMCLAWEREpBiB3UfgnPvIzBqU0KQb8LLz5U8/N7NDzKy2c07L+olIUsrNhSefhPXrS/e6Mjk7qfbb9zTr1pjOnaMfV5g3lNXBL9iRLyNv3+8SgZn1wV81UL9+/ZgEJyISbUuXwq23+t/NIntNa/c/nuPPHMYvPFNxKZ07R3+xuzAHi4v6z1Dk4gjOuVHOuXTnXHrNmkXeIS0iEvdycvzj66/7q4MSf7ZmkXvHAL4qewLH1VpNnfHDuO/RYFY8DfOKIAOoV2C7LrAqpFhEROJL9+4wbRpcdRU8+ihUqxbYR4V5RTAJuCJv9lB7YIPGB0QkmV15pX8sU9yZd9MmyMryv/fvD+++C88/H2gSgGCnj44BPgOamFmGmV1tZn3NrG9ek8nAcmAZ8CxwfVCxiIiE7bvv4MsvIT0dzjyziAbTpsExx8D99/vt00+Hs86KSWxBzhq6dC/PO+CGoD5fRCSevPKKHyB+8004+OACT6xfD7fdBi+9BE2bwrnnxjy2hCtDLSKyr2bPhpkzw/ns55+Hjh2hXsGR0enToVcvWLcO7roL/v53qFgx5rEpEYhIyujVC779NrzPf+SRQjsOOwzS0mDqVGjdOpSYQIlARFLE99/7JDB0KPzlL7H//LJl4cDKDl58Cb76yt9Zduyx8Omnkd9UEBAlAhFJCe++6x/PPx8OOiiEAL7/3meg996DU06BbdugUqXQkwAoEYhIgrvrLhg1au/tNm/2/fNNmgQf0x5ycmD4cBgwwM8bffppnxCKnUMae0oEIpLQPvsMDjgALrpo7227dAnhC3hmJtx9N5x2GowcCXFYJkeJQEQSXqNG/kt33Ni5E159Fa6
4AmrV8mMCaWlx0Q1UFCUCEUkY2dkwfrzv5sm3apWffBM3vvwS/vxnmD8fateGs8+Go44KO6oSKRGISML47DPo2fP3+1u2jH0sv7NtG9x7r58jethh8NZbPgkkACUCEUkYO3b4xzfegHbtdu8//PBw4tlD9+5+atI118DDD8Mhh4QdUcSUCEQk4dSqVegO3bBs3Ajly/u7ge+8E/72N+jUKeyoSi1+5i+JiCSSyZN9kbj77vPbp52WkEkAlAhEREonMxMuv9wXh6taFS64IOyI9psSgYhIpN57D5o3h7Fj/b0BX30F7duHHdV+0xiBiITGOT++umlTZO3nzw82nr2qXRsaN4YRI3ydoCShRCAioZk4ES68sPSvC3jBrt2cg+eeg//9z9+xdswxvo51nN4Ytq+UCEQkNEOH+htuJ06M/NxatSoceWSwcQGwfDlcey38979+tbA4KhIXbUoEIhKKTz7xN4gNGxZnvSw5Ob5E9F13+SJGzzzj7w2IoyJx0aZEICJR99tvcPHFsGFD8W1++gkOPRSuuip2cUUkM9PfIdypkx8LqFs37IgCp0QgIlH37be+RyU93d/8VZTDD/crhh14YGxjK9KOHX5R4d69fcBz5/r+pyTsBiqKEoGIBGbQoFDWYi+d2bN9kbgFC/y3/86doUGDsKOKKSUCkSTwzTeQkRF2FLstWRJ2BBHYutXfC/DYY35a6KRJPgmkICUCkQS3cyccf7yf1BJvQlkSMlLdusH770OfPn760sEHhx1RaJQIRBLc2rU+CfTrF1/VDg48EFq3DjuKQjZsgAoVfJG4gQN9obgzzgg7qtApEYgkuJ9/9o8dOsDJJ4cbS1x75x3o29fXCXrwQTj11LAjihvJOzFWJEVcd51/jIua/PFo7Vq47DI4/3w/XzWSxY1TjBKBSALLzIRZs6By5Ti7KStevPuuLxI3bpy/N2DOHDjhhLCjijvqGhJJYPlF2CZMgCpVwo0lLtWpA82a+RvDWrQIO5q4pUQgkoAyM+H772HKFL/dqlW48cSN3FwYPdoXics/+X/0UdhRxT0lApEE1KnT7quBevX8Wukpb9kyXyTugw/8TKD8InGyVxojEElAGzZAx45+Isz06WFHE7KcHHj0UWjZ0i8U8+yz/j+KkkDEAk0EZtbFzJaY2TIz61/E8web2dtmNs/MFppZvJWfEolb9er58g2NGoUdScgyM2HwYDjrLFi0yFcKTZEaQdESWNeQmZUFhgNnARnAbDOb5JxbVKDZDcAi59z5ZlYTWGJmrzrndgQVl0gYNm6Em26KfCWuvfnll+i8T8Lavh1efhmuvnp3kbj69ZUA9lGQYwRtgWXOueUAZjYW6AYUTAQOqGpmBlQB1gPZAcYkEoq5c/15Ky0tOrN7GjZM2bI48MUXPgEsXOgrhHbuHKOVapJXkImgDvBjge0MoF2hNk8Bk4BVQFXgj8653MJvZGZ9gD4A9evXDyRYkVgYPdr37cs+2LLFl4V4/HE/LfQ//0nhbBhdQSaCoq7RXKHts4G5QEfgaOA9M5vpnNu4x4ucGwWMAkhPTy/8HiJRs2WLXzAl2n78ce9tZC+6d/dF4q67DoYMifOKdoklyESQAdQrsF0X/82/oKuAIc45Bywzs++BpsCsAOMSKdY55/i1yYNSsWJw752UfvvNF4mrVMmXjB44UDWCAhBkIpgNNDKzNOAnoCdwWaE2K4FOwEwzqwU0AZYHGJNIiTIzoW1buOWW6L93lSrQrnDnqBRv0iT/7f/yy/0VwCmnhB1R0gosETjnss3sRmAaUBZ43jm30Mz65j0/ErgfeNHMvsZ3Jd3hnMsMKiaRSBx5pK9RJiH55Re4+WZ47TV/b8All4QdUdIL9M5i59xkYHKhfSML/L4K0GiPhG78eD/2uGoVHHNM2NGksKlT/ULGmzfD/ffDHXdAuXJhR5X0VGJCBH/OWboUatRQD0So6tXzZVSfftpXDZWYUCIQwS/u0quXr04gMZS
bC88842+0eOYZXyTugw/CjirlKBFI0nGudHfe5ub6tUtq1QouJinC0qW+HMTMmb48RFaWplWFRIlAksqOHdC1674VYqtTJ/rxSBGys32RuHvu8dNCX3gBrrxS5SFCpEQgScM5X89n+nS/JnndupG/tnx56NEjuNikgHXr4KGHfMYePhxq1w47opSnRCBJY/hwGDUKBgyABx4IOxrZw/bt8OKLfr2AWrVg3jw/MCxxQYlAEtIPP8BTT/leBvBdQs8849cnHzw43NikkM8+80XiFi+Go4+GM89UEogzSgSSkF5/HR55BKpW3d21fNJJ8MorUEbLLcWHzZvh73+HJ5/0J/6pU30SkLijRCAJyeWVHvz5Z6hcOdxYpBjdu/sBmxtvhH/8w2dtiUv67iQJJydnd5eQxJlff/VrBQMMGuSnhg4bpiQQ53RFIAklKwsaNIA1a/y2uoHiyJtvwg03wBVX+FlBJ58cdkQSIf1vJAllyxafBM47D8aM0f1HceHnn31huIsvhsMPh549w45ISkmJQBJS584638SFKVN8TaB33vHjALNmwXHHhR2VlJK6hiShzNKSRfHlyCP9iX/4cGjaNOxoZB/pikASyuOP+8eGDcONI2Xl5vobOK691m83b+5nBikJJDQlAkkoZn6Vr3POCTuSFLRkiV8m8qab/CLMWVlhRyRRokQgIiXbuRMefBBatYJFi3ypiClTNFKfRJQIJG798APUrOnPN/k/776rKaMx9+uv8PDDvn7HokWqFJqENFgscWvlSr+Y/B//6O8dyHfWWaGFlDqysuD556FvXzjsMJg/v3TlXCWhKBFI3Lv2WujUKewoUsjHH/sicUuXQuPGvj6QkkBSUyKQmMvOhpdegk2bSm733XexiUfybNrka3gPH+4vwd59V0XiUoQSgcTcrFl+hcJIlCundUtipnt3mDEDbrnF1/KuUiXsiCRGlAgk5nbu9I+TJsEpp5Tctnx5VRcN1Pr1fhS+cmW4/34/CNyhQ9hRSYwpEUhoqlSBQw4JO4oUNm6cLxJ35ZUwdCiceGLYEUlIlAgkcP37+/XJ8+3Y4R81AzEkq1f7BPDWW9CmDfTqFXZEEjIlAgncJ5/AAQfABRfs3le1KpxwQngxpaz//Af+9Cc/PfShh+C22/wfR1Ka/gVITDRtCiNGhB2FcNRRPgM/9ZSfGiqCEoFEyVdfwbx5RT/3889Qv35s45E8OTn+pD9/Pjz3HDRr5qeFihSgRCBR8ac/weLFxT+vbqAQLFrk5+l+9hl07eq7g1QfSIqgRCBRsX27n4aeXya6sCOOiG08KW3HDj8L6P77/WDMK6/AZZdpdF6KFWgiMLMuwBNAWWC0c25IEW1OBx4HygGZzrnTgoxJglOlil+nREL222/w2GNw4YXw5JO+VpBICQKr42hmZYHhwDlAc+BSM2teqM0hwNPABc65FsAfgopHou/++6FJE/+zcmXY0aS4bdv8WEBurj/xf/01jB2rJCARCbKgb1tgmXNuuXNuBzAW6FaozWXAm865lQDOuV8CjEeibOpU/+Xz+OP92uV//nPYEaWojz7yawXcdJMvEQHqi5NSCbJrqA7wY4HtDKBdoTaNgXJm9gFQFXjCOfdy4Tcysz5AH4D6mn4SV1q2hDFjwo4iRW3c6O/WGzEC0tLg/fdVplX2SZCJoKiRKVfE57cBOgGVgM/M7HPn3NI9XuTcKGAUQHp6euH3kBj55RdfoTjfunWqAxSq7t3hgw/g1lt9P92BB4YdkSSoIBNBBlCvwHZdYFURbTKdc1uALWb2EdAKWIrEnbvugtGj99zXsmU4saSszEyffStXhgce8DOB2rcPOypJcEGOEcwGGplZmpmVB3oCkwq1mQicYmYHmFllfNdRCbPRJUzbtvn1SebN2/3z0kthR5UinPODv82awT33+H0dOigJSFQEdkXgnMs2sxuBafjpo8875xaaWd+850c65xab2VRgPpCLn2K6IKiYZP9VqKC
rgJj76Se4/npft/uEE+CKK8KOSJJMoPcROOcmA5ML7RtZaPth4OEg45DSe/BBX5yyoO++g2rVwoknZb3zjq8OunMnPPII/PWvULZs2FFJktGdxVKkN96AjAxIT9+9r0YNTUqJuYYN/ToBw4b530UCoEQgxWrf3vdGSAzl5Pi7gefNgxdf9GVbp0wJOypJckoEAvjzzxdf+LpksPeF5SUACxfC1Vf7P8S556pInMSMEoEAviu6e/c997VtG04sKWfHDhgyxC8Yf/DB8O9/Q8+eKhInMaNEIABs3uwfX3kF6uXd/dGqVXjxpJTffvPdQX/4gy/fWrNm2BFJilEikD20bQuNGoUdRQrYuhWefRZuvHF3kbjatcOOSlJUqW8oM7OyZqbVrkX21YwZcOyxfiroBx/4fUoCEqJiE4GZHWRmA8zsKTPrbN5NwHKgR+xCFEkSGzbAX/4CHTv6/v8ZMzQfV+JCSV1D/wJ+BT4DrgH6AeWBbs65uTGITSS5dO/uS0b36weDBqlin8SNkhLBUc65YwHMbDSQCdR3zmliYRLaujXsCJLU2rW+Kmjlyv527bJltYCzxJ2Sxgh25v/inMsBvlcSSF6vvuof9SU1Spzz00ALFolr315JQOJSSVcErcxsI7vXFahUYNs55w4KPDqJmUMPhXLloE6dsCNJAhkZcN11/uaMdu2gd++wIxIpUbGJwDmnylYppmnTsCNIApMmwZ/+5G/Vfuwxv3ykisRJnCs2EZhZRaAv0BBfJvp551x2rAKT2MjMhNtvh88/90XlZD81bgwnn+wXkj/qqLCjEYlISWMELwHpwNdAV+DRmEQkMfXFF/Dyy76kzQUXhB1NAsrO9uWh89cIaNoUJk9WEpCEUtIYQfMCs4aeA2bFJiQJw2uvaRyz1ObP90Xi5syBbt1UJE4SVqSzhtQllKRWrw47ggS0fbufCdSmDaxcCa+/7lfxURKQBFXSFUHrvFlC4GcKadZQkpk3D6691v+uc1gpbNwITz8Nl17qB4SrVw87IpH9UlIimOecOy5mkUjMZWb6x3794Jhjwo0l7m3ZAqNGwc03++qgCxZArVphRyUSFSV1DbmYRSGhOv98lb4v0fTpvkjcbbfBhx/6fUoCkkRKuiI4zMxuK+5J59w/A4hHJH789hv83//Bc8/52twffginnhp2VCJRV1IiKAtUYfedxZJk8usLaXygGBdeCDNnwh13+MHhSpXCjkgkECUlgtXOuftiFonE3Jo1/lG9HAWsWQNVqvhCcUOGwAEH+NlBIkmspDECXQkkuYUL/eNhh4UbR1xwDv71L2jefHeRuHbtlAQkJZSUCLRiRhJ77DG/PO6hh6priJUr4dxz/d3BTZr4m8REUkhJRefWxzIQia2MDP/45pvhxhG6iRN9kTjn/ALy11+vInGScrR4fQqrUgVOOy3sKELinJ8z27QpnH46DBsGDRqEHZVIKJQIUsjKlTB8uK+Tlr9mesrJzoZHH4Wvv4ZXXvFdQW+/HXZUIqFSIkghb7wBQ4f6CTFm0LZt2BHF2Lx58Oc/w1df+amhKhInAigRpJTcXP+4Zo1PBikjKwsGD4aHHvJ1gcaNg4svDjsqkbhR0qwhSSLZ2X7RrJS0aRM88wz06gWLFikJiBQSaCIwsy5mtsTMlplZ/xLanWBmOWZ2SZDxpKr33/c9IAMG+O0yqZD+N2/2C8bk5PgicYsWwYsv+vmyIrKHwLqGzKwsMBw4C8gAZpvZJOfcoiLaPQRMCyqWVPf99/582K8ftG6dApUS3n0X+vTxo+Nt2sAZZ/hkICJFCvK7YVtgmXNuuXNuBzAW6FZEu5uA8cAvAcYiwC23wGWXhR1FgNavh6uugrPP9pdAM2f6JCAiJQpysLgO8GOB7QygXcEGZlYHuBDoCBS7UKKZ9QH6ANSvXz/qgSarr7/2C2fNmRN2JDFy4YXwySdw550wcKBmBIlEKMhEUFStosJrHDwO3OGcy7ESCuI750YBowD
S09O1TkKE7r0Xxo/3v9eqBYccEm48gfj5Z6ha1U+DevhhKF/e93+JSMSC7BrKAOoV2K4LrCrUJh0Ya2YrgEuAp82se4AxpZRVq6BjRz8+sHp1kk0Zdc4P/jZvDnff7fe1baskILIPgkwEs4FGZpZmZuWBnsCkgg2cc2nOuQbOuQbAOOB659yEAGNKKWvWwOGH+1lCSbUC2YoV0KWLHw9o0cIPDIvIPgssETjnsoEb8bOBFgOvO+cWmllfM+sb1OeKN38+LF/uE0FSeestv8Dyp5/CU0/5VcOaNAk7KpGEFuidxc65ycDkQvtGFtO2d5CxpJr33/ePlyTLnRn5ReJatIAzz4QnnoAjjww7KpGkkAq3FqWkefP81UCHDmFHsp927oR//MPfFQzQuDFMmKAkIBJFqjWUJGbNgs8/37390UfQqlV48UTFV1/5RWLmzoUePWD7dqhQIeyoRJKOEkGSuP56+PLLPff1TdSRmG3b4L77/HTQmjX9uEB3TSYTCYoSQZLYuRO6dvXL7oLvTq9WLdyY9tmWLfDcc3Dllb5eUMIeiEhiUCJIIuXLJ3BNtU2bYMQIuP12qFHDF4mrUSPsqERSggaLk0T+WgMJaepUPyW0f39fHwiUBERiSIkgCeTmwg8/QJ06YUdSSuvW+e6fc87xtz1/8olfP1hEYkpdQ0lgxQrfs5Jws4QuusjfGDZwINx1l2YEiYREiSDBvf++71mBBEkEq1f7InFVqviB4PLlEyRwkeSlRJDANmyAs87yvx90kO9mj1vOwQsvwG23+QXk//lPOKHYyuMiEkMaI0hgO3f6x3vu8WMElSuHG0+xli+Hzp39zWGtWiXwDQ4iyUlXBEmgRo04XmvgzTfh8suhbFk/PbRPnxRZNFkkcSgRSDDyi8Qde6wvGf3441Cv3t5fJyIxp69mEl07dsDgwX5xZOegUSO/TJqSgEjcUiKQ6Jkzxw8ADxzot3fsCDceEYmIuoYSSG4uTJ/uS/EAbNwYbjy7bNvmR6wffdTXvp44ES64IOyoRCRCSgQJ5NNP/eSbwkIfKN6yxa8ffPXVMHRoHAQkIqWhRJBAtm71jy+8sHuN9nLl/PrtMbdxIzz9NPTr56ctLV4M1auHEIiI7C8lggTUuPHuRBCK//zH3wuwahW0b+/rAykJiCQsDRYnkL//PeQA1q71S0aedx4cfLDvq1KROJGEpyuCBOGcX7kRQiwlcfHFfj3MQYNgwABfJ0hEEp4SQYLYsAFycvzEnIMOiuEH//ST//ZfpQo89pivEBrXRY1EpLSUCOLYggWQmel/z8jwj7VqxejDnYPRo+H//s/PBvrnP6FNmxh9uIjEkhJBnPr5Z2jZ0p+PC0pLi8GHf/cdXHstzJgBZ5wBN9wQgw8VkbAoEcSplSt9EhgyBNq18/uqVInBl/Jx4+CKK/y81FGj4JprfM0gEUlaSgRxas0a/9ixY4zK9ucXiWvVCs49148H1K0bgw8WkbBp+micevFF/xj4mMCOHXDvvdCz5+4icW+8oSQgkkKUCOLU2rX+MdAF6WfN8n1NgwbBAQeoSJxIilIiiGNnnOHXc4m6rVv9bKAOHeDXX+Htt+HVV7V4vEiK0hhBiJYv9/cHFGXz5gBrt23bBq+84lcLe+ihGN+YICLxJtBEYGZdgCeAssBo59yQQs/3Au7I29wMXOecmxdkTPFi2TLfHV+Sc8+N4gdu2ABPPQV33OHrAi1eDNWqRfEDRCRRBZYIzKwsMBw4C8gAZpvZJOfcogLNvgdOc879ambnAKOAdkHFFE+++84/DhkCTZsW3SY9PUof9vbbvkjczz/DSSf5+kBKAiKSJ8grgrbAMufccgAzGwt0A3YlAufcpwXafw6kzFSV/OmhF1209yuDfbZ2Ldx8M4wd69cOnjgxitlFRJJFkImgDvBjge0MSv62fzUwpagnzKwP0Aegfv360Yov5nJz4fbbffmeb7/1+w4/PMAPzC8Sd999vktIReJEpAh
BJoKibkd1RezDzM7AJ4KTi3reOTcK321Eenp6ke+RCFasgMcfhyOO8HXcLrrI3y0cVRkZfpS5ShX/YRUqQIsWUf4QEUkmQU4fzQDqFdiuC6wq3MjMWgKjgW7OuXUBxhO6/O6g0aNh0SIYPz6K1Rtyc+GZZ/xyZfmLxx9/vJKAiOxVkFcEs4FGZpYG/AT0BC4r2MDM6gNvApc755YGGEsocnPhhx92F477+mv/GPW7hb/91heJ+/BD6NQJbropyh8gIskssETgnMs2sxuBafjpo8875xaaWd+850cCdwPVgafNfzXOds4lzWjm3/8ODz74+/1RvVv4jTd8kbgKFeC55+Cqq1QkTkRKJdD7CJxzk4HJhfaNLPD7NcA1QcYQprVrfXf9E0/s3le7dpSuCPKLxB13HHTr5tcLOOKIKLyxiKQa3VkcsMqV/Rf2qNm+HR54wN8Q9vrr0LChnx4qIrKPVGsoQO+845eXjJrPP/cDwPffD5UqqUiciESFEkGANm0qvpZQqWzZArfeCiee6N908mR4+WUViRORqFAiCFD58n4yz37LyvLdP9dfDwsXwjnnROFNRUQ8jRFE0YYNkJ29ezs3dz/e7LffYNgwGDBgd5G4wMqRikgq0xVBlLz+uj9P16ix+2fDBr/0b6lNmOBvDLv3Xvg0rxyTkoCIBERXBFHyY15VpaFDoWJF/7sZdO9eijdZs8bfDPbGG37t4LffjsFq9SKS6pQIoqxvX6hadR9ffMklfvnIwYPhb3/bx8sJEZHSUSKIgg8+8F/e98nKlX5tgKpV4ckn/Uyg5s2jGZ6ISIk0RhAFDzwAH33k67tVqhThi3JzYfhw/6K77/b7jjtOSUBEYk6JIApyc/3CXwsWwAGRXGMtWQKnnQY33ugXkL/llsBjFBEpjhJBFDi3u8LoXr3+uh8IXrAAXngBpk2DBg2CDE9EpERKBPtpxw6YMQN27txLw/xM0aaNX5Fm8WLo3VuVQkUkdEoE+2nrVv/YsGExDbKy4K67/Iwg5+Doo+Hf/w54jUoRkcgpEUTJCScUsfPTT/0A8D/+4WcFqUiciMQhTR8tZPx432sTqW3biti5eTPceSc89RTUqwdTp8LZZ0ctRhUTFpcAAA2zSURBVBGRaFIiKOTyy4s5uZegbFnf47PLjh0wbhzccMPuqwERkTilrqFCcnL8Tb07d0b+s307nH/Sehg0yFedO/RQf1kxbJiSgIjEPV0RFKFMmQjvB8g3frz/9p+ZCR07wqmnwsEHBxafiEg06YqggDPP9L06ZSL9r7J6NVx8sZ8RdMQRMGeOTwIiIglEVwQF5Fd87t07whf06AGzZ8OQIXD77aW8jBARiQ86cxVQrhz85S/QqFEJjX74wY8BVK3qxwAqVYImTWIWo0hYdu7cSUZGBllZWWGHIiWoWLEidevWpVwpqhcrEeT58UfYuLGEBvlF4gYMgGuugccfh9atYxafSNgyMjKoWrUqDRo0wHRHfFxyzrFu3ToyMjJIS0uL+HUaI8jzyCP+sV69Ip785hvf93/zzXDKKX4heZEUk5WVRfXq1ZUE4piZUb169VJftSkR5Nm50y8FcNtthZ4YO9YXiVu8GF5+GSZPhiOPDCVGkbApCcS/ffkbKREUcNBBBTbyV54/4QT4wx9g0SJ/t5n+RxCRJJPSiWDDBjj2WKhTB156Ke8cv20b9O/vp4XmF4l75RWoVSvscEUEeOuttzAzvvnmm137PvjgA84777w92vXu3Ztx48YBfqC7f//+NGrUiGOOOYa2bdsyZcqU/Y7lwQcfpGHDhjRp0oRp06YV227YsGE0adKEFi1a8Le//Q2AV199ldatW+/6KVOmDHPnzgXgtddeo2XLlnu0B1i5ciVnnHEGxx13HC1btmTy5Mn7fQyQ4oPFGRl+WYCOHeGoo+CCajOh9TWwdClcfbXvLypfPuwwRaSAMWPGcPLJJzN27FgGDRoU0WsGDhzI6tWrWbBgARUqVGDNmjV
8+OGH+xXHokWLGDt2LAsXLmTVqlWceeaZLF26lLJly+7RbsaMGUycOJH58+dToUIFfvnlFwB69epFr169APj666/p1q0brVu3Zt26dfTr148vv/ySmjVrcuWVVzJ9+nQ6derE4MGD6dGjB9dddx2LFi2ia9eurFixYr+OA1I8EeS78cpNXPhFf3j4aUhLg/fe83eXiUiR/vpXyPvyGjWtW/vJeCXZvHkzn3zyCTNmzOCCCy6IKBFs3bqVZ599lu+//54KFSoAUKtWLXr06LFf8U6cOJGePXtSoUIF0tLSaNiwIbNmzaJDhw57tBsxYgT9+/ff9dmHHXbY795rzJgxXHrppQAsX76cxo0bU7NmTQDOPPNMxo8fT6dOnTAzNuZNb9ywYQNHHHHEfh1DvpRKBDNnQsHkmZHhHy1nJ0yY4P91Dx4MBx4YSnwiUrIJEybQpUsXGjduzKGHHspXX33F8ccfX+Jrli1bRv369Tloj0HAot16663MmDHjd/t79uxJ//7999j3008/0b59+13bdevW5aeffvrda5cuXcrMmTO56667qFixIo888ggnFKpb/9prrzFx4kQAGjZsyDfffMOKFSuoW7cuEyZMYEdeCftBgwbRuXNnhg0bxpYtW3j//ff3ekyRSJlEsH277wLKzvbbh7KOW3iCstzNIWmH+imiKhAnEpG9fXMPypgxY/jrX/8K+JPzmDFjOP7444udKVPaGTSPPfZYxG1dEevTFvV52dnZ/Prrr3z++efMnj2bHj16sHz58l1tv/jiCypXrswxxxwDQLVq1RgxYgR//OMfKVOmDCeeeCLLly8H/PH37t2b22+/nc8++4zLL7+cBQsWUCbiujhFCzQRmFkX4AmgLDDaOTek0POW93xXYCvQ2zn3VRCxrFnjk8A/HnBcVXUcNe69kbIb1nPda2dR8/RTACUBkXi2bt06/vvf/7JgwQLMjJycHMyMoUOHUr16dX799dc92q9fv54aNWrQsGFDVq5cyaZNm6i6ly97pbkiqFu3Lj/++OOu7YyMjCK7aurWrctFF12EmdG2bVvKlClDZmbmrq6fsWPH7uoWynf++edz/vnnAzBq1Khd4w7PPfccU6dOBaBDhw5kZWWRmZlZZHdTqTjnAvnBn/y/A44CygPzgOaF2nQFpgAGtAe+2Nv7tmnTxu2LL75wrjY/uVXtu/u15tu0cW7u3H16L5FUtGjRolA/f+TIka5Pnz577Dv11FPdRx995LKyslyDBg12xbhixQpXv35999tvvznnnOvXr5/r3bu32759u3POuVWrVrl//etf+xXPggULXMuWLV1WVpZbvny5S0tLc9nZ2b9rN2LECDdw4EDnnHNLlixxdevWdbm5uc4553JyclydOnXcd999t8dr1qxZ45xzbv369a5Vq1ZuyZIlzjnnunTp4l544QXnnP971K5de9d7FVTU3wqY44o7Xxf3xP7+AB2AaQW2BwADCrV5Bri0wPYSoHZJ77uviWDSJOdmcpLLqVDRuaFDndu5c5/eRyRVhZ0ITjvtNDdlypQ99j3xxBOub9++zjnnPv74Y9euXTvXqlUrl56e7t59991d7bZv3+769evnjj76aNeiRQvXtm1bN3Xq1P2OafDgwe6oo45yjRs3dpMnT961/+qrr3azZ8/e9dm9evVyLVq0cMcdd5ybPn36rnYzZsxw7dq1+9379uzZ0zVr1sw1a9bMjRkzZtf+hQsXuhNPPNG1bNnStWrVyk2bNq3IuEqbCMwV0c8VDWZ2CdDFOXdN3vblQDvn3I0F2rwDDHHOfZy3PR24wzk3p9B79QH6ANSvX7/NDz/8UOp4PvkExt89j/73VuKwkxvv62GJpKzFixfTrFmzsMOQCBT1tzKzL51z6UW1D3KMoKhRmsJZJ5I2OOdGAaMA0tPT9ylznXQSnDS91b68VEQkqQV5Z3EGULCEW11g1T60ERGRAAWZCGYDjcwszczKAz2BSYXaTAKuMK8
9sME5tzrAmERkPwTVlSzRsy9/o8C6hpxz2WZ2IzANP4PoeefcQjPrm/f8SGAyfubQMvz00auCikdE9k/FihVZt26dSlHHMZe3HkHFihVL9brABouDkp6e7ubMmbP3hiISVVqhLDEUt0JZWIPFIpJEypUrV6pVryRxpHQZahERUSIQEUl5SgQiIiku4QaLzWwtUPpbi70aQGYUw0kEOubUoGNODftzzEc652oW9UTCJYL9YWZzihs1T1Y65tSgY04NQR2zuoZERFKcEoGISIpLtUQwKuwAQqBjTg065tQQyDGn1BiBiIj8XqpdEYiISCFKBCIiKS4pE4GZdTGzJWa2zMz6F/G8mdmTec/PN7Pjw4gzmiI45l55xzrfzD41s4RfpWdvx1yg3QlmlpO3al5Ci+SYzex0M5trZgvN7MNYxxhtEfzbPtjM3jazeXnHnNBVjM3seTP7xcwWFPN89M9fxa1hmag/+JLX3wFHAeWBeUDzQm26AlPwK6S1B74IO+4YHPOJQLW8389JhWMu0O6/+JLnl4Qddwz+zocAi4D6eduHhR13DI75TuChvN9rAuuB8mHHvh/HfCpwPLCgmOejfv5KxiuCtsAy59xy59wOYCzQrVCbbsDLzvscOMTMasc60Cja6zE75z51zv2at/k5fjW4RBbJ3xngJmA88EssgwtIJMd8GfCmc24lgHMu0Y87kmN2QFXziyRUwSeC7NiGGT3OuY/wx1CcqJ+/kjER1AF+LLCdkbevtG0SSWmP52r8N4pEttdjNrM6wIXAyBjGFaRI/s6NgWpm9oGZfWlmV8QsumBEcsxPAc3wy9x+DdzinMuNTXihiPr5KxnXIyhq6aTCc2QjaZNIIj4eMzsDnwhODjSi4EVyzI8DdzjncpJkRa1IjvkAoA3QCagEfGZmnzvnlgYdXEAiOeazgblAR+Bo4D0zm+mc2xh0cCGJ+vkrGRNBBlCvwHZd/DeF0rZJJBEdj5m1BEYD5zjn1sUotqBEcszpwNi8JFAD6Gpm2c65CbEJMeoi/bed6ZzbAmwxs4+AVkCiJoJIjvkqYIjzHejLzOx7oCkwKzYhxlzUz1/J2DU0G2hkZmlmVh7oCUwq1GYScEXe6Ht7YINzbnWsA42ivR6zmdUH3gQuT+BvhwXt9Zidc2nOuQbOuQbAOOD6BE4CENm/7YnAKWZ2gJlVBtoBi2McZzRFcswr8VdAmFktoAmwPKZRxlbUz19Jd0XgnMs2sxuBafgZB8875xaaWd+850fiZ5B0BZYBW/HfKBJWhMd8N1AdeDrvG3K2S+DKjREec1KJ5Jidc4vNbCowH8gFRjvnipyGmAgi/DvfD7xoZl/ju03ucM4lbHlqMxsDnA7UMLMM4B6gHAR3/lKJCRGRFJeMXUMiIlIKSgQiIilOiUBEJMUpEYiIpDglAhGRFKdEIBKhvAqmcwv8NMir9LnBzP5nZovN7J68tgX3f2Nmj4Qdv0hxku4+ApEAbXPOtS64w8waADOdc+eZ2YHAXDN7J+/p/P2VgP+Z2VvOuU9iG7LI3umKQCRK8so6fImvd1Nw/zZ8LZxELmwoSUyJQCRylQp0C71V+Ekzq46vD7+w0P5qQCPgo9iEKVI66hoSidzvuobynGJm/8OXdBiSVwLh9Lz98/G1b4Y4536OYawiEVMiENl/M51z5xW338waAx/njRHMjXVwInujriGRgOVVe30QuCPsWESKokQgEhsjgVPNLC3sQEQKU/VREZEUpysCEZEUp0QgIpLilAhERFKcEoGISIpTIhARSXFKBCIiKU6JQEQkxf0/8T/s5c/KQgMAAAAASUVORK5CYII=\n", 924 | "text/plain": [ 925 | "
" 926 | ] 927 | }, 928 | "metadata": { 929 | "needs_background": "light" 930 | }, 931 | "output_type": "display_data" 932 | } 933 | ], 934 | "source": [ 935 | "plot_roc(y_true, y_score)" 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "metadata": {}, 941 | "source": [ 942 | "# 整体训练图像" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 57, 948 | "metadata": {}, 949 | "outputs": [ 950 | { 951 | "data": { 952 | "text/html": [ 953 | "
\n", 954 | "\n", 967 | "\n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | "
train_aux_losstrain_target_losstrain_final_loss
00.8954530.6920251.587478
10.8836130.6910351.574647
20.8718200.6901961.562016
30.8603340.6894091.549743
40.8486130.6888401.537453
............
2920.0302060.1975150.227721
2930.0289850.1408210.169806
2940.0289900.0819850.110975
2950.0280550.1663380.194393
2960.0287970.1971610.225958
\n", 1045 | "

297 rows × 3 columns

\n", 1046 | "
" 1047 | ], 1048 | "text/plain": [ 1049 | " train_aux_loss train_target_loss train_final_loss\n", 1050 | "0 0.895453 0.692025 1.587478\n", 1051 | "1 0.883613 0.691035 1.574647\n", 1052 | "2 0.871820 0.690196 1.562016\n", 1053 | "3 0.860334 0.689409 1.549743\n", 1054 | "4 0.848613 0.688840 1.537453\n", 1055 | ".. ... ... ...\n", 1056 | "292 0.030206 0.197515 0.227721\n", 1057 | "293 0.028985 0.140821 0.169806\n", 1058 | "294 0.028990 0.081985 0.110975\n", 1059 | "295 0.028055 0.166338 0.194393\n", 1060 | "296 0.028797 0.197161 0.225958\n", 1061 | "\n", 1062 | "[297 rows x 3 columns]" 1063 | ] 1064 | }, 1065 | "execution_count": 57, 1066 | "metadata": {}, 1067 | "output_type": "execute_result" 1068 | } 1069 | ], 1070 | "source": [ 1071 | "train_loss_data = pd.read_csv(\"./loss/dien/train_loss.csv.2020_09_22_21_35_06\")\n", 1072 | "train_loss_data" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "code", 1077 | "execution_count": 56, 1078 | "metadata": {}, 1079 | "outputs": [ 1080 | { 1081 | "data": { 1082 | "text/html": [ 1083 | "
\n", 1084 | "\n", 1097 | "\n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | "
test_aux_losstest_target_losstest_final_loss
00.8955500.6921211.587671
10.8837850.6913251.575110
20.8721210.6905321.562653
30.8605580.6897211.550279
40.8491010.6889171.538019
............
2920.0301820.2611070.291289
2930.0300740.2611990.291273
2940.0299660.2613540.291320
2950.0298590.2616390.291498
2960.0297520.2619370.291690
\n", 1175 | "

297 rows × 3 columns

\n", 1176 | "
" 1177 | ], 1178 | "text/plain": [ 1179 | " test_aux_loss test_target_loss test_final_loss\n", 1180 | "0 0.895550 0.692121 1.587671\n", 1181 | "1 0.883785 0.691325 1.575110\n", 1182 | "2 0.872121 0.690532 1.562653\n", 1183 | "3 0.860558 0.689721 1.550279\n", 1184 | "4 0.849101 0.688917 1.538019\n", 1185 | ".. ... ... ...\n", 1186 | "292 0.030182 0.261107 0.291289\n", 1187 | "293 0.030074 0.261199 0.291273\n", 1188 | "294 0.029966 0.261354 0.291320\n", 1189 | "295 0.029859 0.261639 0.291498\n", 1190 | "296 0.029752 0.261937 0.291690\n", 1191 | "\n", 1192 | "[297 rows x 3 columns]" 1193 | ] 1194 | }, 1195 | "execution_count": 56, 1196 | "metadata": {}, 1197 | "output_type": "execute_result" 1198 | } 1199 | ], 1200 | "source": [ 1201 | "test_loss_data = pd.read_csv(\"./loss/dien/test_loss.csv.2020_09_22_21_35_06\")\n", 1202 | "test_loss_data" 1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "code", 1207 | "execution_count": 58, 1208 | "metadata": {}, 1209 | "outputs": [], 1210 | "source": [ 1211 | "def get_loss_fig_aux(train_loss_data, test_loss_data):\n", 1212 | " train_loss = {\n", 1213 | " \"aux_loss\":list(train_loss_data[\"train_\" + \"aux_loss\"].values), \n", 1214 | " \"target_loss\":list(train_loss_data[\"train_\" + \"target_loss\"].values), \n", 1215 | " \"final_loss\":list(train_loss_data[\"train_\" + \"final_loss\"].values)\n", 1216 | " }\n", 1217 | " test_loss = {\n", 1218 | " \"aux_loss\":list(test_loss_data[\"test_\" + \"aux_loss\"].values), \n", 1219 | " \"target_loss\":list(test_loss_data[\"test_\" + \"target_loss\"].values), \n", 1220 | " \"final_loss\":list(test_loss_data[\"test_\" + \"final_loss\"].values)\n", 1221 | " }\n", 1222 | " get_loss_fig(train_loss, test_loss)" 1223 | ] 1224 | }, 1225 | { 1226 | "cell_type": "code", 1227 | "execution_count": 59, 1228 | "metadata": {}, 1229 | "outputs": [ 1230 | { 1231 | "data": { 1232 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOydd3hUVfPHvycFAgRICKGDoddQJPQuvYMURYp0QRQVUeRFmghYEBAFeamCUqTqq6IgJRQFIUBC50eH0CG0EAIkOb8/vrncTbKp7GZT5vM8+9zdW86dbWfOmZkzo7TWEARBEDIvTo4WQBAEQXAsoggEQRAyOaIIBEEQMjmiCARBEDI5oggEQRAyOaIIBEEQMjmiCARBEDI5oggEIR6UUueVUs0cLYcg2BtRBIIgCJkcUQSCkEyUUoOUUqeVUiFKqf8ppQpF71dKqRlKqRtKqXtKqUNKqUrRx9oopY4ppR4opS4rpUY69l0IgokoAkFIBkqplwBMBdAdQEEAFwCsjD7cAkBDAGUAeAB4BcDt6GMLAbyhtc4JoBKArakotiAkiIujBRCEdEZPAIu01gcAQCk1GsAdpZQPgKcAcgIoB2Cv1vq4xXVPAVRQSgVpre8AuJOqUgtCAsiMQBCSRyFwFgAA0FqHgqP+wlrrrQC+BTAbwHWl1DylVK7oU7sAaAPgglJqu1KqTirLLQjxIopAEJLHFQAvGC+UUjkAeAG4DABa61la6+oAKoImog+i9+/TWncEkA/AzwBWpbLcghAvoggEIWFclVJuxgPswPsppaoqpbICmALgX631eaVUDaVULaWUK4CHAMIBRCqlsiileiqlcmutnwK4DyDSYe9IEGIhikAQEmYDgEcWjwYAxgJYC+AqgJIAXo0+NxeA+aD9/wJoMpoWfaw3gPNKqfsAhgDolUryC0KiKClMIwiCkLmRGYEgCEImRxSBIAhCJkcUgSAIQiZHFIEgCEImJ92tLM6bN6/28fFxtBiCIAjpiv3799/SWntbO5buFIGPjw8CAgIcLYYgCEK6Qil1Ib5jYhoSBEHI5IgiEARByOSIIhAEQcjkpDsfgSAIqcPTp08RHByM8PBwR4siJAM3NzcUKVIErq6uSb7GbopAKbUIQDsAN7TWleI5pzGAmQBcAdzSWjeylzyCICSP4OBg5MyZEz4+PlBKOVocIQlorXH79m0EBwejePHiSb7Onqah7wG0iu+gUsoDwBwAHbTWFQF0s6MsgiAkk/DwcHh5eYkSSEcopeDl5ZXsWZzdFIHWegeAkAROeQ3AOq31xejzb9hLFkEQUoYogfRHSr4zRzqLywDwVEr5K6X2K6X6xHeiUmqwUipAKRVw8+bNFN3s0iVg+HDg0aOUiisIgpAxcaQicAFQHUBbAC0BjFVKlbF2otZ6ntbaT2vt5+1tdWFcouzaBXzzDdCxY4rlFQRByJA4UhEEA/hTa/1Qa30LwA4AVex1s1dfBQoXBv76C/jtN3vdRRAEW3L37l3MmTMn2de1adMGd+/etYNEKcff3x/t2rVztBhWcaQi+AVAA6WUi1IqO4BaAI7b62ZKAX/+yW2PHsCTJ/a6kyAItiI+RRAZmXClzw0bNsDDw8NeYmU47KYIlFIrAOwGUFYpFayUGqCUGqKUGgIAWuvjAP4EcAjAXgALtNZH7CUPAFSqBAweDISGAq+8Ys87CUIGpHHjuA+jkw4Ls378++95/NatuMeSwEcffYQzZ86gatWqqFGjBpo0aYLXXnsNvr6+AIBOnTqhevXqqFixIubNm/fsOh8fH9y6dQvnz59H+fLlMWjQIFSsWBEtWrTAowQchfPnz0eNGjVQpUoVdOnSBWFhYQCAvn37Ys2aNc/Oc3d3BwCsX78ezZo1g9YaV69eRZkyZXDt2rVE31dISAg6deqEypUro3bt2jh06BAAYPv27ahatSqqVq2KatWq4cGDB7h69SoaNmyIqlWrolKlSti
5c2eSPrvkYM+ooR5a64Jaa1etdRGt9UKt9Vyt9VyLc77UWlfQWlfSWs+0lyyWzJkDeHsDv/wCHDuWGncUBCGlfPbZZyhZsiQCAwPx5ZdfYu/evZg8eTKORf95Fy1ahP379yMgIACzZs3C7du347Rx6tQpDBs2DEePHoWHhwfWrl0b7/1efvll7Nu3D0FBQShfvjwWLlyYoHydO3dGgQIFMHv2bAwaNAgTJ05EgQIFEn1f48ePR7Vq1XDo0CFMmTIFffowVmbatGmYPXs2AgMDsXPnTmTLlg3Lly9Hy5YtERgYiKCgIFStWjXR9pNLpltZ7OQEbNkCNGoEDBwI7NwJODs7WipBSAf4+8d/LHv2hI/nzZvw8SRSs2bNGAulZs2ahfXr1wMALl26hFOnTsHLyyvGNcWLF3/WeVavXh3nz5+Pt/0jR47g448/xt27dxEaGoqWLVsmKtM333yDSpUqoXbt2ujRo0eS3seuXbueKaSXXnoJt2/fxr1791CvXj2MGDECPXv2xMsvv4wiRYqgRo0a6N+/P54+fYpOnTrZRRFkylxDvr7A118Du3cDI0c6WhpBEJJKjhw5nj339/fH5s2bsXv3bgQFBaFatWpWF1JlzZr12XNnZ2dERETE237fvn3x7bff4vDhwxg/fvyz9lxcXBAVFQWAq3efWDgZL1++DCcnJ1y/fv3ZOYmhtY6zTymFjz76CAsWLMCjR49Qu3ZtnDhxAg0bNsSOHTtQuHBh9O7dG0uXLk3SPZJDplQEANCrF1CoEDBzJvDPP46WRhAEa+TMmRMPHjyweuzevXvw9PRE9uzZceLECezZs+e57/fgwQMULFgQT58+xbJly57t9/Hxwf79+wEAv/zyC54+fQoAiIiIQL9+/bB8+XKUL18e06dPT9J9GjZs+Kx9f39/5M2bF7ly5cKZM2fg6+uLUaNGwc/PDydOnMCFCxeQL18+DBo0CAMGDMCBAwee+33GJtOZhgyUAlasoImoQwfg+nUxEQlCWsPLywv16tVDpUqVkC1bNuTPn//ZsVatWmHu3LmoXLkyypYti9q1az/3/SZNmoRatWrhhRdegK+v7zMlNGjQIHTs2BE1a9ZE06ZNn81MpkyZggYNGqBBgwbPHNpt27ZF+fLlE7zPhAkT0K9fP1SuXBnZs2fHkiVLAAAzZ87Etm3b4OzsjAoVKqB169ZYuXIlvvzyS7i6usLd3d0uMwJlbYqSlvHz89O2rFDWvTuwejXQuzdgh89XENItx48fT7RDE9Im1r47pdR+rbWftfMzrWnIYPlywMMD+OEHrj4WBEHIbGR6ReDiwlBSV1fgzTeBaNOfIAgZmGHDhj2L1zceixcvtknbGzdujNN2586dbdK2vci0PgJLGjakv6BrV+DTT4GJEx0tkSAI9mT27Nl2a7tly5ZJCjtNS2T6GYFBly70F0yaBKxb52hpBEEQUg9RBBaMHcttr15MQyEIgpAZEEVgQaVKwAcfsGZBmzaOlkYQBCF1EEUQi88+A3x8mHrCRr4jQRCENI0oglgoxZoFTk6MIgpJqNimIAh2JbXrEZw4ceJZ5s8zZ86gbt26yW7DIHbG0tg0btwYtlwT9TyIIrBCqVLAf/8LREYybXU6W3MnCBmG1K5H8PPPP6Njx444ePAgSpYsiX8ySf4ZUQTxMHAgMHkysHYtQ0oFIbPjgHIEqVqPYMOGDZg5cyYWLFiAJk2aADDrDvj7+6Nx48bo2rUrypUrh549ez5LHPfJJ5+gRo0aqFSpEgYPHmw1oVxirFixAr6+vqhUqRJGjRoFgMqub9++qFSpEnx9fTFjxgwAzLhaoUIFVK5cGa+++mqy72UNWUeQAO+/z6R048YBTZoA9es7WiJByFx89tlnOHLkCAIDA+Hv74+2bdviyJEjz1JRL1q0CHny5MGjR49Qo0YNdOnSJU4a6lOnTmHFihWYP38+unfvjrVr16JXr15x7tWmTRsMGTIE7u7uGGklLfHBgwd
x9OhRFCpUCPXq1cPff/+N+vXr46233sK4ceMAAL1798Zvv/2G9u3bJ/k9XrlyBaNGjcL+/fvh6emJFi1a4Oeff0bRokVx+fJlHDnCel2Gqeuzzz7DuXPnkDVrVpuV47SbIlBKLQLQDsANrXWlBM6rAWAPgFe01vEb1ByAkxNHNC1aAK1bA1evAtEDBEHIdKSBcgR2r0eQ2L2LFCkCAKhatSrOnz+P+vXrY9u2bfjiiy8QFhaGkJAQVKxYMVmKYN++fWjcuDG8vb0BAD179sSOHTswduxYnD17Fm+//Tbatm2LFi1aAAAqV66Mnj17olOnTujUqVOK3kts7Gka+h5Aq4ROUEo5A/gcwEY7yvFcNG8OvP021xUkdTorCIJ9sHc9goSw1k54eDjefPNNrFmzBocPH8agQYOsypAQ8ZmSPD09ERQUhMaNG2P27NkYOHAgAOD333/HsGHDsH//flSvXj3F78cSe5aq3AEgsZibtwGsBXDDXnLYgq+/BipUAPbvB6ZOdbQ0gpB5SO16BMnF6PTz5s2L0NDQBKOE4qNWrVrYvn07bt26hcjISKxYsQKNGjXCrVu3EBUVhS5dumDSpEk4cOAAoqKicOnSJTRp0gRffPHFs0pqz4vDfARKqcIAOgN4CUCNRM4dDGAwABQrVsz+wsW5P6e1pUpREfTsCThADEHIdKR2PYLk4uHhgUGDBsHX1xc+Pj6oUSPBrswqBQsWxNSpU9GkSRNordGmTRt07NgRQUFB6Nev37OqZ1OnTkVkZCR69eqFe/fuQWuN9957L0XRUbGxaz0CpZQPgN+s+QiUUqsBfKW13qOU+j76vETVqa3rESSHM2eAatWAypWpGFzE1S5kYKQeQfolPdUj8AOwUil1HkBXAHOUUrbxfNiJkiUZRfT330CfPo6WRhAEwTY4TBForYtrrX201j4A1gB4U2v9s6PkSSqvvcZCNitWmDHSgiCkL+xZjyA2nTt3jnOvjRvTVnyMPcNHVwBoDCCvUioYwHgArgCgtZ5rr/vaGzc3YPt2mogGDADKlQMcYJoUBOE5sGc9gtgY4a1pGbspAq11j2Sc29dectiDypU5I3jlFaBZM66adHNztFSCIAgpQ1JMpJDu3bm+4OFDpqMQBEFIr0jcy3MwcyYT082ZA7RrB9go7YcgCEKqIjOC58DJicqgbl2gb1/g8GFHSyQIgpB8RBE8J66uwKBBwOPHQKNGUuJSEGxJSusRAMDMmTMRFhZmY4mSzoQJEzBt2jSH3T85iCKwAX370nF85w7QsqXULxAEW5GeFUF6QnwENuLHH4E9e4B//gE+/pi1DAQhw/Duu0BgoG3brFqVttUEsKxH0Lx5c+TLlw+rVq3C48eP0blzZ0ycOBEPHz5E9+7dERwcjMjISIwdOxbXr1/HlStX0KRJE+TNmxfbtm2z2v7QoUOxb98+PHr0CF27dsXEiRMBsJ5BQEAA8ubNi4CAAIwcORL+/v4YPnw48ubNi3HjxmHjxo2YPHky/P394eSU8Jg6MDAQQ4YMQVhYGEqWLIlFixbB09MTs2bNwty5c+Hi4oIKFSpg5cqV2L59O9555x0AgFIKO3bsQM6cOVPwAScdUQQ2wsWFaSfKlAGmTGH9gmbNHC2VIKRvLOsRbNq0CWvWrMHevXuhtUaHDh2wY8cO3Lx5E4UKFcLvv/8OgMnocufOjenTp2Pbtm3ImzdvvO1PnjwZefLkQWRkJJo2bYpDhw6hcuXKCcpTo0YNNGjQAMOHD8eGDRsSVQIA0KdPH3zzzTdo1KgRxo0bh4kTJ2LmzJlWawtMmzYNs2fPRr169RAaGgq3VIhNF0VgQ3x8uL5g5EhGEAUEcJ8gpHsSGbmnBps2bcKmTZtQrVo1AEBoaChOnTqFBg0aYOTIkRg1ahTatWuHBg0aJLnNVatWYd68eYiIiMDVq1dx7NixBBVB9uzZMX/+fDRs2BAzZsxAyZIlE73HvXv3cPfuXTRq1AgA8Prrr6N
bt24ArNcWqFevHkaMGIGePXvi5ZdfflYDwZ6Ij8DGdOkCbNoEREQArVqxhJ8gCM+P1hqjR49GYGAgAgMDcfr0aQwYMABlypTB/v374evri9GjR+OTTz5JUnvnzp3DtGnTsGXLFhw6dAht27Z9llbaxcXlWdbP2PUFDh8+DC8vL1y5cuW535O12gIfffQRFixYgEePHqF27do4ceLEc98nMUQR2IHSpbnY7ORJoHNncR4LQkqxrEfQsmVLLFq06Fn+/cuXL+PGjRu4cuUKsmfPjl69emHkyJE4cOBAnGutcf/+feTIkQO5c+fG9evX8ccffzw75uPjg/379wMA1q5d+2z/hQsX8NVXX+HgwYP4448/8O+//yb6HnLnzg1PT0/s3LkTAPDDDz+gUaNG8dYWOHPmDHx9fTFq1Cj4+fmliiIQ05CdGDmSs+lNm4Dx44EkDlIEQbDAsh5B69at8dprr6FOnToAWFj+xx9/xOnTp/HBBx/AyckJrq6u+O677wAAgwcPRuvWrVGwYEGrzuIqVaqgWrVqqFixIkqUKIF69eo9OzZ+/HgMGDAAU6ZMQa1atQBwRjJgwABMmzYNhQoVwsKFC9G3b1/s27cvUTv+kiVLnjmLS5QogcWLF8dbW2Ds2LHYtm0bnJ2dUaFCBbRu3dpWH2e82LUegT1wZD2C5LJhA9ChA1cfr18P2Ki8qCCkClKPIP2SnuoRZHjatAH++ovPe/QATp92rDyCIAjWENOQnWnSBHjzTWDpUqB9e2D3btYzEAQh9ahVqxYeP34cY98PP/wAX1/f52578uTJWL16dYx93bp1w5gxY5677dRCTEOpxPbtXFfQuDHwxx9S5lJI+4hpKP0ipqE0SqNGNA9t3gwMG+ZoaQRBEEzspgiUUouUUjeUUkfiOd5TKXUo+vGPUqqKvWRJKwwfzoyl8+YB0YENgiAIDseeM4LvAbRK4Pg5AI201pUBTAIwz46ypAn8/IAvvuDzYcM4OxAEQXA0dlMEWusdAEISOP6P1vpO9Ms9AOy/jjoNMGIEMGoUF5l16ACcOuVoiQRByOykFR/BAAB/xHdQKTVYKRWglAq4efNmKople5QCpk7limMnJ0YSReeaEgQhFvZOQ7169WqUL18eTZo0QUBAAIYPH56iewFcjXzr1q14j7u7u6e4bXvjcEWglGoCKoJR8Z2jtZ6ntfbTWvt5e3unnnB2Qilg7VouODt7lvWPIyIcLZUgpD3srQgWLlyIOXPmYNu2bfDz88OsWbNSdK/0jkODGJVSlQEsANBaa33bkbKkNkqxxGW/fnQev/su8O23jpZKEKzjoHIEdq1H8Mknn2DXrl04d+4cOnTogLZt22LatGn47bffMGHCBFy8eBFnz57FxYsX8e677z6bLXTq1AmXLl1CeHg43nnnHQwePDhZ71trjQ8//BB//PEHlFL4+OOP8corr+Dq1at45ZVXcP/+fUREROC7775D3bp1MWDAAAQEBEAphf79++O9995L1v2SgsMUgVKqGIB1AHprrf/PUXI4khs3gOi8Vpg9m8nqoutRCIIA+9YjGDduHLZu3Ypp06bBz88P/v7+MY6fOHEC27Ztw4MHD1C2bFkMHToUrq6uWLRoEfLkyYNHjx6hRo0a6NKlC7y8vJL8ntatW4fAwEAEBQXh1q1bqFGjBho2bIjly5ejZcuWGDNmDCIjIxEWFobAwEBcvnwZR44w+PKunezIdlMESqkVABoDyKuUCgYwHoArAGit5wIYB8ALwBylFABExLfYIaNSqBCwbx8jiT76CHjvPaBsWaavFoS0RBooR2CXegQJ0bZtW2TNmhVZs2ZFvnz5cP36dRQpUgSzZs3C+vXrAQCXLl3CqVOnkqUIdu3ahR49esDZ2Rn58+dHo0aNsG/fPtSoUQP9+/fH06dP0alTJ1StWhUlSpTA2bNn8fbbb6Nt27Zo0aKFTd5bbOwZNdRDa11Qa+2qtS6itV6otZ4brQSgtR6otfb
UWleNfmQqJWCgFKOIvvySkUSdOwPHjztaKkFIe9i6HkFiZM2a9dlzZ2dnREREwN/fH5s3b8bu3bsRFBSEatWqxalXkJT3YY2GDRtix44dKFy4MHr37o2lS5fC09MTQUFBaNy4MWbPno2BAwc+13uKD4c7iwUyciRXHmfNCrRrByQQfCAImQZ71iNICffu3YOnpyeyZ8+OEydOYM+ePcluo2HDhvjpp58QGRmJmzdvYseOHahZsyYuXLiAfPnyYdCgQRgwYAAOHDiAW7duISoqCl26dMGkSZOevTdbIxlv0hDLlwN79jAf0csvM3OpxaBEEDId9qxHkBJatWqFuXPnonLlyihbtixq166d7DY6d+6M3bt3o0qVKlBK4YsvvkCBAgWwZMkSfPnll3B1dYW7uzuWLl2Ky5cvo1+/fs+qpU2dOtUm7yM2knQuDdK3L7BkCfD668DixTQfCUJqI0nn0i+SdC4DULQot0uWmCkpBEEQ7IWYhtIgo0YBv/0GBAUxmqhMGTqRBUFIGfasR2DJ7du30bRp0zj7t2zZkqzIotRGFEEaxN0d2LEDePFFIDgY6NUL2LmTrwUhNdFaQ2UA22RSiszbAi8vLwTaeuVdMkmJuV9MQ2mUnDmB6dOBp09Z0ax9e+DyZUdLJWQm3NzccPv27RR1LIJj0Frj9u3bcHNzS9Z1MiNIw7RrB2zbBuTODdSrB3TsyJlC9uyOlkzIDBQpUgTBwcFI74keMxtubm4oUiR5yZxFEaRhlAIaNACiorjGYMECoE8fYNUqZi4VBHvi6uqK4sWLO1oMIRWQ7iQdoBRw5gyQJQuzlo4d62iJBEHISIgiSAcoBcyfDzg703cwZQqwdKmjpRIEIaMgiiCdUKIEsGUL4OpKH8GgQcCuXY6WShCEjIAognRE7drAypWMJPL25tqCs2cdLZUgCOkdUQTpjObNgYsXGU0UGcnIonv3HC2VIAjpGVEE6ZACBVjEZtw44NQp4JVXpNSlIAgpRxRBOuXAARaycXYGNm4ERoxwtESCIKRXRBGkU158Efj0U+Cll/j6m2+YpE4QBCG52E0RKKUWKaVuKKWOxHNcKaVmKaVOK6UOKaUkk04yGTOGyel69+brgQM5UxAEQUgO9pwRfA8goeq7rQGUjn4MBvCdHWXJsDg5cSawdCl9B507S3UzQRCShz1rFu8AEJLAKR0BLNVkDwAPpVRBe8mTkVGKs4L164Hr14Fu3cR5LAhC0nGkj6AwgEsWr4Oj98VBKTVYKRWglAqQBFjxU6ECE9T5+wP/+Y+jpREEIb3gSEVgLcm51Xy3Wut5Wms/rbWft7e3ncVKv2TPDrRqBbi4AF9+yeR0giAIieFIRRAMoKjF6yIArjhIlgzDW2/RLFSiBDBgANcZCIIgJIQjFcH/APSJjh6qDeCe1vqqA+XJEPj5AdWrA7dv05HcvTsQHu5oqQRBSMvYM3x0BYDdAMoqpYKVUgOUUkOUUkOiT9kA4CyA0wDmA3jTXrJkJpRiPqICBZibKDAQ+OADR0slCEJaxm6FabTWPRI5rgEMs9f9MzOlSgGHDzNT6YgRwIwZXHjWubOjJRMEIS0iK4szKK6u3HbqBFSrBvTvD5w/71CRBEFIo4giyMDcvAm0bk2H8YMHLHcp6wsEQYiNKIIMjLc38OOPdCBHRgJ79gBffOFoqQRBSGuIIsjgdO7MfEQ5cwIvvABMmAAEBTlaKkEQ0hKiCDIBOXIwBUW1akCePHz++LGjpRIEIa2QuRRBVJSjJXAYs2YxF9GCBYwomjjR0RIJgpBWyDyKYOtWoEwZ4IjVrNgZHmdnbnPlAmrVAj7/XFJWC4JAMo8iyJ0bOHMGaNo0U88Mli0D/v0X8PAA3niDTmRBEDI3mUcRVK8OdOkC3LhhVnLJhHz1FVCoEPViQAAwZ46jJRIEwdFkHkUAAD/9BHh5AcuXA3/+6WhpHIK7O8tanjvH8NIxY4DLlx0tlSAIjiRzKQJnZ2DTJib
k6dIFCAtztEQO4eWXgcmTueDsyRNg+HBHSyQIgiPJXIoAYNX399+nEhg3ztHSOIxBg7i+4D//AdatAzZvdrREgiA4CsXcb4mcpNQ7ABYDeABgAYBqAD7SWm+yr3hx8fPz0wEBAc/f0BtvAPPnA9u2AY0aPX976ZTHj4Hy5bng7MABM7pIEISMhVJqv9baz9qxpM4I+mut7wNoAcAbQD8An9lIPscwfTpQtCiT8WRSI/mFC8CUKcCHHwKHDgHffutoiQRBcARJVQRGWck2ABZrrYNgvdRk+iFHDuCjj4BHj4DGjR0tjUO4cgX45BPgUnTl6AkTgIcPHSqSIAgOIKmKYL9SahOoCDYqpXICSP/B+EOHAvXqAadPM3F/JqNmTZqEpkzh67t3gWnTHCuTIAipT1IVwQAAHwGoobUOA+AKmocSRCnVSil1Uil1Win1kZXjuZVSvyqlgpRSR5VSibZpczZuZEzljBnA9u2pfntH4uwMeHqar/PkYXbSa9ccJ5MgCKlPUhVBHQAntdZ3lVK9AHwM4F5CFyilnAHMBtAaQAUAPZRSFWKdNgzAMa11FQCNAXyllMqSDPmfnxw5gA0b+LxrV5qKMhE//ADUrQsMGQKEhtJ5LKmqBSFzkVRF8B2AMKVUFQAfArgAYGki19QEcFprfVZr/QTASgAdY52jAeRUSikA7gBCAKR+6ZQGDYCvvwZu3WJoaSaiYUPg77+Btm25pqBJE2DuXJkVCEJmIqmKICK6xnBHAF9rrb8GkDORawoDuGTxOjh6nyXfAigP4AqAwwDe0VrH8T0opQYrpQKUUgE3b95MosjJZPhwVnn/7jtmZMtktG0L7NrFlBOPHwNffuloiQRBSC2SqggeKKVGA+gN4Pdos49rItdYiyqKvWihJYBAAIUAVAXwrVIqV5yLtJ6ntfbTWvt5e3snUeQUMHkyk/CMHg3s3Wu/+6RBlKLfvHRpoFcv6sPr1x0tlSAIqUFSFcErAB6D6wmugSP7xMaMwQCKWrwuAo78LekHYJ0mpwGcA1AuiTLZHldXYMUKQGugeXMgPNxhojiKLl3oRA4PB2bPdrQ0giCkBklSBNGd/zIAuZVS7QCEa60T8xHsA1BaKTRY4gYAACAASURBVFU82gH8KoD/xTrnIoCmAKCUyg+gLICzyZDf9rRuzVXH9+8zZXUm48ED4OBBoH17mokyme9cEDIlSVIESqnuAPYC6AagO4B/lVJdE7pGax0B4C0AGwEcB7BKa31UKTVEKTUk+rRJAOoqpQ4D2AJglNb6Vsreig357jvaSP75J9OV8qpVi6uMhw4Fbt9mVJEgCBmbpOYaCgLQXGt9I/q1N4DN0WGfqYrNcg0lxq1bQKVKHCLv2QP4+tr/nmmA338H2rUD/P25xi4sDDh6FHDKfOkJBSFDYYtcQ06GEojmdjKuTZ/kzQsEBtJ5/PLLHB5nAmrW5PbffxlJe+IE19wJgpBxSWpn/qdSaqNSqq9Sqi+A3wFssJ9YaYQCBYAff2SJyzp1MkWJS29v4PXXgRde4Po6b29g3jxHSyUIgj1JkmkIAJRSXQDUA8NCd2it19tTsPhINdOQJXXrArt3M64ykxjNHz9mDZ/t24GZM5mYrmBBR0slCEJKsYVpCFrrtVrrEVrr9xylBBzGli1MyvPjj8D33ztamlRh9GigQwcWrImMBBYvdrREgiDYiwQVgVLqgVLqvpXHA6XU/dQS0uFkywbs3MkA+4EDGVaTgbl2jYXtx4wBgoKAUqVYwycTWMYEIVOSoCLQWufUWuey8siptY6zAjhDU7EisGABn3fuzJzNGZQCBYAdO4BPPwVq1ACePgXOn5dyloKQUcnYkT+2pm9fGs4vXQJefRWISP38eKlN796AlxeDp5Ytc7Q0giDYA1EEyeWll4BvvmFMZcfYyVQzHsOGAfv3M/XEunWy0lgQMiKiCFLC668zhGbDBhrSMzDGQrKePVmv4LffHCuPIAi2RxRBSnBz4zA
5WzbWeVy92tES2ZV+/Zh1o2BBMQ8JQkZEFEFKKViQYaVOTkCPHsCBA46WyG48fcqVxj16cBJ0546jJRIEwZaIInge6tRhJFFUFIPub9xI/Jp0SPny9I936kSlsHatoyUSBMGWiCJ4Xvr1Y1xlSAiztT186GiJbE758ty6uQElSwJr1jhWHkEQbIsoAlvw0kssaBMQANSvz6W4GQgj8WpAAPMPbdlCvScIQsZAFIGtaNcOKFOGGUtbtmSVswxCqVJMs1S4MNCtG5dP/PKLo6USBMFWiCKwFc7OHDIXK8Yh85tvOloim6EUc+116AC8+CLg4yPmIUHISNhVESilWimlTiqlTiulPornnMZKqUCl1FGl1HZ7ymN33N05I3B3B+bOBSZNcrRENuXmTSA4mOahv/7K0Fk2BCFTYTdFoJRyBjAbQGsAFQD0UEpViHWOB4A5ADporSuCpTDTN56eTFmdNSswfnyGGTprDZQrx8qd3boxeuh/sStQC4KQLrHnjKAmgNNa67Na6ycAVgKInZPhNQDrtNYXASBWFbT0S6VKHDrXrQu89hrw55+Olui5UYrRsn//zUR0xYpl+HV0gpBpsKciKAzgksXr4Oh9lpQB4KmU8ldK7VdK9bHWkFJqsFIqQCkVcPPmTTuJa2Py5mU+hmLFgLZtmawunVOvHktXhoTQPLRpE3DvnqOlEgThebGnIlBW9sUOpXEBUB1AWwAtAYxVSpWJc5HW87TWflprP29vb9tLai88PJiCIiqKyiC1K6vZmPr1uW3QAChbFnjyBPj1V8fKJAjC82NPRRAMoKjF6yIArlg550+t9UOt9S0AOwBUsaNMqU/37sDnnzPmsl49VnpJp9SqxbfTsyezkRYunGFcIIKQqbGnItgHoLRSqrhSKguAVwHEdi/+AqCBUspFKZUdQC0Ax+0ok2P48ENg7FgOoWvVSrcVzrJkAX76iQlXvbxoHvrzT+DBA0dLJgjC82A3RaC1jgDwFoCNYOe+Smt9VCk1RCk1JPqc4wD+BHAIwF4AC7TWR+wlk0P55BMqA6W44OzkSUdLlGK2bgWWLKEiePxYUlMLQnpH6XS2AtbPz08HpGdb+9GjQJMmzFq6eTMjjNIZ/fqZhdqKFGE0kSSiE4S0jVJqv9baz9oxWVmc2lSsyFzON28CtWsDp045WqJkU7w4cOUKLV1duvDthIY6WipBEFKKKAJH4OfHKmcPHwI1a6Y7M1Hx4txeuMDFZeHhVAaCIKRPRBE4ivnzGY959y4VQ2CgoyVKMoYiOHeOgVD588viMkFIz4gicBTOzvSyNmlCu0rdukxNkQ7w8eH2/Hm+jZdf5owgA5ZiEIRMgSgCR5I7NzOVfv45S182b87XaZxChYAzZ4DBg/m6WzcgLCxDZNIQhEyJKAJHoxTXGfz9N9NRtGyZ5kNwnJyAEiW4BbjS2NtbzEOCkF4RRZBWKFCA6wwiIxmgP3euoyVKkL/+AkaM4HMXF5qHfvsNePTIsXIJgpB8RBGkJXr0oBNZKWDoUGD06DRb6SwoCJgxwyxZ2bUrfQRiHhKE9IcogrTGwIHM5ObsDHz2GdC7d5qsgVyxIrdHj3LbuDETrv70k8NEEgQhhYgiSIu0bUufQf78wLJlHG6nMZuLsSB68WJGwLq40Gn8v/9J7iFBSG+IIkir1KoFXL0KzJrFSvENGgC3bztaqmcUKcKaO4sXA2+9xX09e1JfrV+fsjZXrgRy5GD+IkEQUg9RBGkZpYC336Yndv9+oEoVVoZJAyjFycqBA8ynB3AphI8P96eE06cZhnr1qs3EFAQhCYgiSA988QWT/1++DFStmqbyOVSrxlDSkBCK17Mnc+ldu5b8tipX5vbWLdvKKAhCwogiSA+UKAEcO2bmfW7XDpg509FSPSMigj6Drl35iIpKmdP47Flub2SMytWCkG4QRZBeyJWLK7ZmzOBKrvfeA4YNYy/sYFxcgG+/Bf79F/j5Z84Sfvwx+e389Re3ogg
EIXURRZDeePdd+gk++ACYMwdo2hS4c8fRUuHll4HSpSlar14sz/x//5e8NnLl4jYqyvbyCYIQP6II0iOlStFv0K8fsGMH4OtrBvQ7EC8v+gpefZXO5OTOCpycgJIlgf797SOfIAjWsasiUEq1UkqdVEqdVkp9lMB5NZRSkUqprvaUJ8PxySfMCX35MlC9usOT/eTJQ0VQqBDQujUnLMlZU/DoEZAtm/3kEwTBOnZTBEopZwCzAbQGUAFAD6VUhXjO+xysbSwkhyJFGL9Zty6dyN27M6jfQX6D5cuBXbv4fMIELnuYMSPp14eHA0eO0PolCELqYc8ZQU0Ap7XWZ7XWTwCsBNDRynlvA1gLQFyEKcHDA/D3B4YPp9d29mwOxx0Qg5k7N+Dmxuc1anCl8eTJzEuUFObNY+oKf3+7iZiu2LwZmDLF0VIImQF7KoLCAC5ZvA6O3vcMpVRhAJ0BJJhqUyk1WCkVoJQKuHnzps0FTfe4ugJff02n8cKF9BtUrcrZQiqyfTv1kTEhmTOHfoPOnWm9SowiRYA6dYDr1+0rZ3qhb19gzBhHSyFkBuypCJSVfbFTac4EMEprnWBWNa31PK21n9baz9vb22YCZjjc3elpnT2bPW/t2sDSpal2+8OHgW++MYOY8uZlOOnNm8yQkdjMYOlSRhvdvCmRQwCjr7JkcbQUQmbAnoogGEBRi9dFAFyJdY4fgJVKqfMAugKYo5TqZEeZMgfdu3NoHREBvP461xs8eWL32+bJw62RmhoAatakiSM8nKWZ33sPmDiRCiI2U6awdHNkZOpExC5dCnxkEcIQHp4qH1OScXKiPGkw+ayQ0dBa2+UBwAXAWQDFAWQBEASgYgLnfw+ga2LtVq9eXQtJ4NYtrYsW1drFRWtA6xdf1Do42K63/OMP3uqff+Ieu3FD6379tHZ25jmA1pMmab1vn9YPHvCcYsW09vLSunp1rS9etKuoWmtTDq217t6dz9ets+89e/fWevr0pJ1ryPff/2pdrpzWkZEJnx8RofWSJdxmRiIitI6KcrQUaRcAATqeftVuMwKtdQSAt8BooOMAVmmtjyqlhiilhtjrvkI0Xl7A1q0sduPqyhQVL75IQ76dsDYjMPD2Zp2dVavMfWPH0qmcMydTKV29yjaaNmUG07VruVr56lX7m4qCg7ndutW+9wkKAjZtSt41b7zBhXqBgQmfN2sWJ4CpaA1MUxQsyFnU/fuOliT94WLPxrXWGwBsiLXPqmNYa93XnrJkSkqVYq8wZw5w4QKX/zZtyvUHo0dz1ZcN8fJikx07Ai+8AIwcyUJr9+5xbFumTMzzV6/m+SdP8rFsGR3FM2fGNdG4ugJFi7Kss7VH0aJ0kaSEiAiahQBgy5aUtZFUSpYEjh9P2rnZstGqd+IEy4D++Sd1eXw0aMBtzpzPL2dCvPcev6OdOxnptWSJzX9KKcKII3n40FylDlAxfPstMGoU6z2lhCtX+PuybDcjYVdFIKQR3N0ZlzlrFtCqFUNR9uxh4H9Ke08rlCgBXLzITrlKFUaxAlQM1mzvUVF0ZxgsX86CbLt3M2KmYEGGoAJ0nD5+TH32yy9AaGhc27mnZ/yKolgxtmetI3jwwBxFHj/OP32hQsl77zduAPnyJXzOnj1mrYaoKI5e4yM8nAvsPD3px9iwIfH2s2fn1t4+hfPnuV24EPjhByqCtETsGk4ffECF5eLCbO6jRzOoLjkULswU659+yu+kTRubiZsmEEWQmWjenEPeVq1YDrNMGQbtxx6qpxClgN9/5/NPPuEf5+TJ+K1RscNEL1wAsmZlh3fjBpdIGHTtyj/f5s18G0uXsjzmxYvm48IF4NIl4Nw5jlbv3o3ZvosL/9CGYjD49Veas6pVAw4eBLZtYzptg/BwLsvIn58zk9icPcuRvr8/0KhR/J+PpWnn8mUqzPgwnOVjxjBnU1I690WLuA0L43bLFqBhw5gyh4ZSYSSkhBKjeXM6+0NC+JmkhdmAJcb
7NzBMlffu0TTZvXvyFQFABfjVV/yNpNFS4ilGcg1lJpycgJde4tDb25vG96pVWV/SRgyJ9v7kzMlQ0u3bzZHq66+zgw8JAfr0YZI6SwoVonnJ25tKondvFqsBzGm/gY8PO9J69egG+ftvdkghIdRrRkfavDlH03PnAh9+yPOV4vkGr7/O6w4e5IK4Tz5hpu8332TZ6GbNeK9VqzhjuHYtZhU1o/ZCYpXVjDTbAHDmTMLnZs3KInUATUInT3IWYekr8fcH3n+fnTtgFgR68oS+jmbNgM8/N89/8oTfy/vvJ3xvS86dY1hwYCC/z8ePgb17eWzHDn5Px44lvb3UIPaMwBjnGL+lc+eS156lEu7bN2ZbGQWZEWRGunalv2DIENphOnbk0HPixJQbUaN54w2gbFl2mMOHUxHcusUpddeu7OSBuOaE0FBg2jSgfXtzRgAABQoAr7zCxWaAmbvIMrxUa4q9ciU78A0buK9aNe43TFQjRtBZe+QIjz96xE78xo2YMwvjsXt3TMd3r14xZXZzizm6HjqUiixbNvORJQtnIi4uplJs04ZK5ddfzWPWHvXq0Vk+fDjbd3ICvvuOujtXLs6+pk+nzT5HDn6Gw4fzO5g9m9dcvGjKa3SQq1cDhw4BAwdSiSZE2bLA06dUqH/9xe/Q+O6MGddvv1HJvfSSqfRtycGDnNUYdbLjY/BgmoBiy/Dxx1xe88MPfJ3cNOfOznzP58/TLATwu3vvveS1k6aJL5worT4kfNTGPHrEuE5A6+bNGXZqA86dY5Pz5lk/HhXFWxtcvMjz58/XumlTrWvX1rpnT62//jrmdUuXmqIaNGyodZ48ZrgloPXRo1p36aJ12bLmedOn89iFC3x9+rTWr7yidUCAeU5wsNZ//cVQzY4dY7Y5darWy5drPXu21p9+qvXIkVoPG6Z1/fo8XqyY1i1bat2okdY1a2pdqJDWJUtqXaqU1j4+Wru6ap01q9Z582rt4aG1u7vWbm5mhG9KH0ppnTs3nxcsqHXduqZMVapo/f77Wk+cqPWMGVpny6Z1kyamvP/+q/Xhw1qfOaP1lSta372r9ZMn5udheZ+ePbXeuZPPq1Uz9xvPjx2z/l1fvszv7caNxH41MfnrLz4sw3yfhylT2M4776Ts+uPHTVnat0/6dTdu8LONL7T16FG2uXdvyuRKKkggfFRmBJkdNzcO8xYvpgG+WjUONX19n6tZwwY/eDDNQFmzxjzeuTNDNgMC+NoYrWbLBtSvT6vV0qWcEQCmc9WYERjnh4TQRNGvH9+CpydnCxMmMPzUuPbqVfP1X38BHTpwROzvz7V3H3zAkeOhQxzphYTQ32BJ+fKcPMVmwQIm22vWjA5UgKNlT0+aJU6e5L4mTfjeXnyRZppXXjHb0JpyRkTwsXQpTVMGlStTtrFjaTK6f58J/fbtY7t583KWcfcubeHZs/MzP3eOZrGHD822tm3j9uJF0/wUG2dntuHiQtkiI4F//qGZCuD9jFnCwYPc99VXnEFkycJH1qzcHjzImcvQoZwNOTmZD6VivjYeUVFs35LDhymXkxO3sWdQoaGMfciRgzMI49z+/fnbaNaM7SQ38ufYMU6Yq1Tha6WYUDGpfPklzWqxfRcGv/zC7erVDKd2BKIIBP4716+nF+3aNfZWa9bE/ScmA0tnZGwlAHC9gKEEADN8082NnfiVK5zmlyhBB2yuXJyON27M854+5dawtXfsSMdo06bs2C0zcrdpw47B8AsY7RrJ7c6eZef49tsx10JkyUJzh7G2IL6U2oYslu/ZiEIyImwAswNu1YrmMktFoBQ7Lmdnfl6Gv8HFhYqhYkUqgjx5gLZteWzNGiqCKlXYya1fTwXZpAmd3YULmw7piAh24u3b08FrOOpXruQ9w8LoKwFoxgsLY1sLFpjvu3hx87MIDuZ35uxs2tANJRgf333HR0oxalonBycn07G7eTM/5//+l7K6uPB3kZB5zsWFA4sDB8zV8PXr83v
o1y/uudba++UXflZTpzIAIva5RhDB/fvAH39YlyEykoOHwoUZmGBrRBEIpFMnFhru0oX/nDZtOJQcMCDFTZ45E3+uHKMzMkb6ljMCwExnXaIElYNh161QgSGlhw/z9YUL3Pr4mKP1Tz+lw/fcOXZeGy0SnE+ezNHdwIHmPmMxWa5cZoRNSAhw6pQ5s/H0tB7DHxBAV8uHH8aspWAoByMbqyUlS9L2nxCGDyQ8nCPcwoUpm2WklaE8r13j1/X4MX0pYWH8Otu2NTtnFxfTpv/tt2ZYbsmSTP0BMOpo+3bOjozvrWZN+lauXmUU0p49dBa/8w6d/Z6eZqLbhQs5lnjyJOZjyRJ+J82bM42I4fQ2HsZsKCoKaNmSE9TatXkPgDONoCDOrCpXpsxXrvD7MGZQERHMwA7wJ1ytWszZleEryZ6d7rGnT2NeG98jPDxu9NnVq4z6CglhB50tG8+1bNP4/i1JLIHgf//LR0KMGsUABlsjikAw6dyZQ7ZRozisHjiQPc9//pOi5kqUiP9YgQL8w9y5Qwer5Yxg1y5ztFyuHJ3Hp07x9Zkz/CMaisEYcb/wgtl23740k3h6srZBnz7cnyUL39qIEZwdGKYCS0VgdObGWoKDBynL/PlxF2r99BOrsf30E6+1nPqXLMlO0FAsBw/Sifvddzx29y47EmMGYsnRozS55M7NEfeKFbzG2ZkdpMHvv/O1oRyUotK4fp2PRYvYIRkrjQ3Ht5eXaX45c8ZUBIZzf98+OqoBvr/27c0ZSu3aZid96hR/IgsW8Njdu9aXpRjO21y5zHtZIyqKHWvJklTsBu++y9GylxcVxZUrNPV8/XXM6w1F0Lix+RzgTOjiRbbj5cUBgqEgr17ld/vqq/FHUX/zjemwB1h5LyDAvIe1lcylSzOyKDzcHAwsXkyFHVtp3LjBAU358nEVkXGeiwt/v5afiy0RRSDE5I03qBA8PfnvGDOGxtfJk20aMG7Y/q9d45+zfn3atrNlY+d54QJDRn18GGlkdPxTp1JRGKYfHx921JZrDrTmFLtcOUYQ5c/P9g0bs7OzGXIJmEooVy4zPPPRI5pi6tWjjFev8pqCBc3r/vmH28mT2TnFXkNgGb9/4QI7WKMcJ8BO2JoiOHuWNv0mTTh67dCBI2FrppFx49hJLF9O01P27DGrlhqhnoCpCFavprknRw4zDPLxY65gBtix16vHDmjfPi5Qjy/pr6EgJk+OuTjQICrK7Cjjs5Eb3L3L87NkiRmS6uxMRR8VRdPf339T/rAwU8lYhtXGDh817h8aStv+jh3msfPngfHj+fv69lvrcllGjtWrRwX86af8XKpXt35NqVJU5JZm0adPrS8KLFuWEXO+vmzXEcg6AiEu+fKxFzOGM1Oncghow4Q/Vapwxeyvv/IPcOhQTNNMsWLmn8zbmx3jo0fsAMqWpfkA4Hblyrjte3hwHcKBA2znwQNT+WzeTF0HsPObMIGdrKcnO5xffzU79Xv3OEouW5YmCkuMEeGhQ7TF9+5tHtu4kXrzpZf42lhrUKBATEVgjfbtaTdev5526S++oCKKiDDXO9Spw8/w0SOgRQt2juvX86szlFy5cjE7MeP51q3sQC9eNLOv3rxpmi6MehI3b7L4neFkt4ahCIYONUN8DbZuZSferBmV8dmzDH2tW9f62gNDvnHjGOrp7W3mZTJMgLVqsS0gppnMUsmEhlL2yEguxmvfnvsXLGC7lp+JoSQ2buRAwdpaEDc3fpbu7jSV5c7NwcbFi9YVOcD7GutkjFBe476xV9n/8AOX8jh0PUZ84URp9SHho6nIP/8w5tGImXvnHZund+zQgU03aMAwx5s3457j76/1e+9pff++1m3aaF26tNabN2sdEhIz1NEa9+9r/fix1hs2MNOp1lovXMh7njsX/9uZOZNhoFWqmG//jTdinvP4sXmsXDmtw8O1njNH69u3mTHUOPbggdbjxzPM8+lTPq5ejf/eV66YGVmNNn78Ues
+fbQuXJgPY/+wYfwsunVjmKrWWq9Zw2Nt2zLbq3GfadPM62rX5n0MgoLM8F0jSa0RLrl8eVwZjXb+/ZfXbtpEOSzp35/nLF7M8OAVK8zrfvghbpt79pjHa9TgvrNnzX3163Pf77/z9e7d5rVhYQzpLVqUnwOg9Vtvcft//8ffVYUKZlvG72bVKr42fuZvvmm2efduTPlmzOD7nD8/ZljtmjX8vg3Cw81jJ07w88+aVesPP9T6zz+1fv11nmNQogTPbdQo7mdiS5BA+KjDO/bkPkQRpDLXrjGg3sgfPXmyTZqNimJHXqIE1w0YnfP58wlf16CB1k5OPHf6dIo1dmzy7m10JHv2UMH85z8xj//5J9cfFCvGNQfTplHOFi0Yo790Kc+bMUPrL75ghm+AawMArXv10vrLL83O4PRpKpF8+ZImX/HibENrs43167luIfYagooVua1cmWsXtGYH//PPVD6A1vfumW0PHx6zE3/9dXbk27Zx39atVGY1a5od8++/x5WxdWseu3yZr+vU0bpZs5jnXLvGc156Seu1a03lGN93vHu31uXLU6kBWm/cqPXgwWZnaxAQYH4m1vD15XFj3cAff3B/+fLmezfWNBi/u1y5uJ5jxAiuIenZU+tKlay3/9NPMb+Dxo21zpHDVLjHjpnHVq/m+oAtW6hgJ07UOksWKi6teU327DzX19f6/WxFQopATENCwuTPz+ghT096CceMYfylDShYkOaCcuVMu661KBuAU//Ll2niqVqVJqQRIzj9j296Hh+GiahtWzr9pkyhE9Jg8GBGqJQuTRPV++/TUXztGv0Cv/5K88kHH9DZbdipjeiS116jScng2jW2YzinATpwp061Lp81J3KOHIyEAWJGJ1WowO2RI/yKZs3iAvGOHXl+kyY0K+3bx8/K0pnr7c2Ing0bTLNFnjw0l+zdy+sA63H3S5bQ1m7YvPPkof3dMOEAXGvg5EQTUe/edD+5uJghrUePxiytXbs2zSNGneaWLc2fmqWjvnx5vh/D7AbQpHPxIq2Z/fpxn2FmbN2aPqLjx/kZVatmmoAM09CjR/xdhITQqbxsGT/Ta9cYFWYZL5E7N7dffMFtWBhNl4Yfq0QJ00H/++80J2XLxsivM2f4uS9cyAyuoaGmWSs1ijHFS3waIq0+ZEbgIO7f53y6TRsOyeMbjiUDT0+OhL791hxBWY5eLSlVSuuuXbnSdOdOTvU3beKo99q15N03ONi8X8OG+plpx6BqVe4bMsTcV6+eeU2uXOaor3PnuKP0f/+NOfJeuzauDK+/ThNPbCIieM348XydNStfG8V+QkJoVnJz4/6pU837DBhgLhLfuNFsc9Ei7suTx7wO4Gi0RAmOwA1zx8WLWn//PZ9//jm3hw4l/pn27m22e+QIZxjNmpnfsfH49FOthw7lyBvQ+tVXrbdXs6b5PTRtqvWpUwnff+9e8x69e9MMZ8wcLR/DhsW87unTmN9Vq1Zalymjdf78fL1sGWcSXbqY1zx8qPXJk1pv3x6z7b//1vrOHc5ArlzRz0xcgNZz53JWU68ef3OdO9OMdfKkeV6nTjFlCwmxrSUWMiMQnpucOTnEGjqU3t2ePemJfQ5Kl6ajc9gw0/lrOdq1pE4djkCbNuXozljZOmGC6TxMKpaRG23aMIfN99+b+7Q25TMYPdpcMHb/Pks8AAzzNCJtDbp25TqBF1+MGRprScmSnOFcuxYzgsmYVRgzAiORXI4c3Hp6cuRapw5fG+GYn39OZ6gxO7F0XG/ezG1ICL+2d9/l16kUneCnTjHy6tgxztKMGZOPDx3xSQlZ9PIyn//vfxz5bt5Mx3axYpwJGJ/NgQOmg9/Pj872PXvoVG3cmJ//v/+aM6ZJkxiFY8mKFRxVG2HFliunL17ke7GM2jJmFMbnasQ9uLhwBtKhA1+fPElH7zvv8LPeuJERYZZRU9mzM9oodpTY2bOcPbZubeYkOnKE27Vr+T7OnOF336oVV67v3Mn
jkyebKcoBPs+Tx3oghD2wqyJQSrVSSp1USp1WSn1k5XhPpdSh6Mc/SqkqpcMungAAFr9JREFU9pRHeE5CQpilzNOTPXH79uzNUkiBAmbkx4oVbMpammeAf7obN2j6sMzimRJcXbmKtlgxris4fDhmqgVjMZVhigFoRqpalcVfXF0ZaujhQUV2+7bZUQP8g7dqxdz3t24xEV6xYrSwGRiRQwULmqkLgJgmGoD327gx7pqMrVvZYTZrxnONz8RQBMa+kiUZWmp0xOfO0TTTsCFfly7NztTdnSYXFxdTETg7U0EkpdCNIW+5clxJa7lm4d49c53HmTPcdyW6ennNmozRf/99dt6HD5tRyoYZ6/p1UzkbDB1K5VumDMMyLaOGtm/nz/TxY9OsVbAgZfzkE6ac6BRdGX3ZMv6mjHQixoCkXDmanpYupfxG0R+A9xo5Mu5ncOYM1yw0aMDfUNmypsmpcGG+55AQfifly/OaQoV4b2Pdxv37fK/G/fbvN+9pbZGarbCbIlBKOQOYDaA1gAoAeiilKsQ67RyARlrrygAmAbCN8VmwD3nycJjm78+e7d49DqUsh2PJYP9+rhjVmh1QQsVg6tbl9t13bVOK8eHDmJ23JVWr8s/bpIm578IFjmTLlKFvAqANO39+jiBv32YeoKgofjSWoaF//cU/e8WK5j7LNAGWis3Tk3l5jJwzs2dzAVNC9YMWLqSy+O03UxF4eDCc01h7MHMmtw8eMCTVWIxVsSLP++EHc5FVoUJUvMHBZlqMxBgwgKPf117jaN4Ykc+bx5+JoUx+/DHm7KFQIVMZxfaNGO+5c+eYvgQgpg/m+nXTN2HY7420DYa/oFAh82fq4cHPKzSUs5fPP+fMpW5dM2S5VCmGBxs+GMtsK1FR5sp2Y/Xz8eNc83HlCmdjTk4MTa5fn+8vTx525mFh/P0YyvbWLX7+u3ZRWVarxnvlzcvvxljnsWgRfxvJzZyaZOKzGT3vA0AdABstXo8GMDqB8z0BXE6sXfEROJjISIbpeHjQyOnkRONmYpXVrXD4sBmBk5TbGrbYadOSfas4lCnDKBFr/N//aX3gQMx9b77Je48Zw9dRUfRnjB7N/ePGmeca2T3feYd2Y0Pup0/Nc27ejGlfjq/gfOXKCR/X2sz0unAhM6MakToGjx9ze/q01uvWmX4MS7p0oS3cknff1Tpnzvjva41jx+hb6NRJ6wIFtB41ive7coXhwTduMFLLeN8PHzLyCqBfoFYts63wcK0HDuQxI8rGwDL7qZ+f+bxkyZif69q1Wr/9Nt8fQL+SkUG1cmWtq1dnNljD3zN9Ov0WoaG8z+zZcf0YUVFm+3360J+gtdavvcZ9Z8+a5/7zD30Jn37KYyEh3H//Pu/bqZPWn30WM8vqiBE8p0MHRoVpTdl8fJL3XcQGDvIRFAZgmb8xOHpffAwA8Ie1A0qpwUqpAKVUwM3YFUqE1MXJifPrVau4aqd/fw4xU5CGolKlmLbsxG47Ywaf26Im75YtHMlbo3TpmGYhwJw9GDl6lKLZwVg5Gh5OP8J//mOO9q9fp0kA4AjQxWIdv5cXbeNvv82RqVEG4tYtjiSNRV2HDpn3i499+7jNnZsZS4sXjzm7MkxdJUuaI27DBGJw507cSKV798wRdlIpX555lypU4KzGWImdJQtXz3p7U44XXmAEVvbspi9m796YswWjWp2zc9xosq1bmRoCMH0wXbtyRgIwomv9eo7SZ83iIjWAn2/9+kz3cOgQZ6VVqnDm06kTR+tjxpjf95tv0mxpieV3kTu3OWt48oRtFS/O16NH8z4NG5qzQaMoTs6c/PvkycMUFoYfJksWRlcBnJWcPcsZiL9/zBmqzYlPQzzvA0A3AAssXvcG8E085zYBcByAV2LtyowgDWGENQwZwqHMokV2vV1CC5zszeTJvLflQiCtOasBzBFq0aKswQBwQZXWWn/3HXPOJwUjgur6db426iwkRPfuPOfHH/l6//74awM
YMfhGm1FRjJQBYubY79SJ+ypUSJrcBlFRHBFfvcrXM2eynXXr4r/myBGe4+ys9ccfm/ufPo0pa2yMdQ4//sgooQkTeF+AI3lL9u3j/v/9j6+fPGE8vzET0JoLFwGOzhPDmlxPnsSc9RmRZnfu8LUxM7OkTRvOSrRmxJblzO/CBc5Ob99mO9OnJy5XwjI7ZkYQDMCyKmsRAFdin6SUqgxgAYCOWutkZPkWHI6nJ4dHpUrRuzV4sJnb2Q4Yk0FjtJyaGKmYY9toDee2MfJ1cwMGDWL0kzHCHjLEHDXG5tEj1j8w8vwbTlYj6iQw0Iz6iY933+W2QQNGVn3wQfzRV7F9DUpx1F2oUEw7eOnStKVbxuonBa3pR5k1i6+NCK3Y+fuXLDHTR5Quzfd44wYjawwSq6tcqhSd9oULc2T+zTd839OmcdRviZHN1lg34OrK77RkSdOhbNjtrdWciI2vr+lwNjDSShsYEUNG5tPYmXjHjuUaDuO+RYvGLBBYrBg/GyM9iWWeK5sTn4Z43geY0O4sgOIAsgAIAlAx1jnFAJwGUDep7cqMII1x7RoNyS1bMhg/Tx4OY+zA3buMlb992y7NJ8jEiRyVHT8ec78xwu7Rg1tfX44MAa0nTUq83chIrhUYOZKvU2KXt2T5ct77p5+sHzdGzF5eKb9HYhQpwnt89hlnCJs2xXQhBQbyeJkyibcFcLSfEJbV6YYPp+09NkZ6jV9/jXvMGIXfvctzWrdOXK47d+LODmNjrDg3/A2xMfwWAwZYP/7oESv0rVvHVdLP+7eCI2YEWusIAG8B2AiafVZprY8qpYYopaJLnGMcAC8Ac5RSgUqpgHiaE9Iq+fNzCLpxI4dSWjNe0jKzl43InZux8sldSWwLxoyh7b5cuZj7X3yR2TyNesZubjGLyiSGkY301Cl+hDNnmrOBlGCsQ4hvFlGgANtPrFbx82CMXENCOONo3tz66N6yKHx8FCgQc41GbPbv533Gj+cIPzSUfoDYoZbvvcforXbtzH2rVzM81hiF587NuP6ffkpcLg8P6wWXLPntN47444tOM2ZLxowgNi4uDKsNCKC/wXJdi62x6zoCrfUGrXUZrXVJrfXk6H1ztdZzo58P1Fp7aq2rRj8SyFYupFneeovz688/pw3k/Hl67uwZ+JzKODszLjw2SvGtGn/2XLlofrl7l07TpFC7Np3XhQrRpNGiRcrlNJRIQrUgmjaNuXbB1hgO5vgUtpEyPL4ymZbcvx+3ZKglxoK6woX5uRtFiGIvb3FyipniA6BiX7WKld4M6te3TTACQIXYunX8xw2HvrU1CQAVgY8PzVoXL8ZdS2FLpB6B8Px4enJBwJIl/Bf168dhXJ8+XM1kwzoGaZXSpbk61Ciok5xom/79GSceEGCuNE0pxv1jRwVZYlnG0x4YtnDLCCBLXniBii8piiAsDNi9O/HzjFF1mTJUApb1KeLDkG/jxoQ/L3thyHz/fvzylipF/1H58jFXoNsaUQSCbfDx4fz8/Hlz6LVyJYfR48dneGVQqFCKC7mhbl2GFz7HIu1nKMV6Qo6kbVuaROJTBEDSndAhIUkzsRmd6po1VKhJUQSVKnFrWd0uNTFmTLdvmyVRY2O5At2efyHJNSTYlshIBocPiXYDTZzI0lXxVX4XoBSwbh1j3jMCDRrQ/GSLDtbTM2FTjdFRGts8eZJuWmvYkEFuo0Y9l4gppnNnpsOoWjX+c4wcS4n5I54Xpe1peLIDfn5+OiBAfMppmqgorq554QWzFFbRovQhdO+e4WcHQurx6qtMD3HypKMlsQ+hoTQL1alDf8bzoJTaH58fVmYEgu1xcmL4zPff00ewfTsDvF99lVnGli2LW69PEFLA22+bK84zIu7uXOFt1zUEEB+BYE8sQyZateKw7dIlxloOH865cbduNBjHl3ZUEBLAyNqZUdGai+bKlLHvfcQ0JKQODx7Qi2gZFpMlC2cGOXPSYNu4MbeVK8dfqkwQhBSRkGl
IZgRC6pAzJzOFbd7M1VN37jBsw9WVxs+1a1nXD6BpqUQJBthXr86wkRYtuC8pISSCICQL+VcJqYeLC01ErVrF3B8aygT2Tk7choUxEfudO0xgb+DszPCQwoW5LVmSTuiiRZnUvXBh+4dXCEIGRExDQtpCa0YZHTpEhRESAvz3v8BXX3HljasrVyxnzx6zLJWBuztjCH18qBzy5GHu44IFuc2Xj1svL676sszyJQgZmIRMQ6IIhPSB1jQdTZ7MhDH9+9Pc9Ouv3HfzJpWGkZq0YkUmqr982XrRYIOsWU1lkCsXA9crVuQsw8ODCid/fu738OA9jYe7O30ZEg4rpANEEQiZA63plL5+nbOBbNmYx3nNGtYQ3LWLs4hHj5go7+FD1oIMC2OyoNBQOq9z5eJaiKSs6VeKTu8cOThbcXPjvb29zUJZHh5UJEWLmsmI3N2pZLJn5zW7dvH6bt3MfaJgBBsiikAQLAkLM4v7xubqVSqKEiWYxvKbbzjbuH2bJqszZ7gE2NeXhWqXLmV7lqk0K1TgLOPq1bjFdpODszOVQ5YsnLlky0aTlpcXFdXFi1Qsjx8z0ipLFi7iy5mTCu3CBcpWpw7lK1iQxXDd3NhefFtxyGdIRBEIgj2JiGAkVIEC7KDd3dmh3r/PZa8PHlCZ7NtHh3jfvhztL1sG/P03O+6iRZmy1MWFKTXv3QPmzKFZy8hrDTC7Xd68PH78uH1SUhqmMicnPjceHh5UJoaSc3GhojIe+fJxNgTQJOfmRoVrrDB3caEyU4qyFyhAh3/WrPzcPDzYzv/9n5nuNWdOXps1KxVdliz8vJ2c6P+RWVOSEUUgCOmZqCgqkzt3OAMwspVFRVF5KMXk93nzMtGP1kzUHx7ODnP1avo5nJ25ViM8nEro6lUqmcePOYPIn5/XP37MFeEPHrDTjYjgvfLmZUGGsDCz5Jcl/9/e/cVIdZZxHP/+WJZdCo1rAS3BBihBDDZaCIU21KYXtVLSdNWbVkxa/yS1ptQ/iSFoE1Pv0EYTrmwwNlaD5cJa5aKRNhYkllBAZPkjRahAAkWWpIAUusvu8njxvJOdDjvD7naGM2fP80lOZuacM7Pvs2/mPPOeOfO8HR1+8O/pGZwYodEkX1pbPcmMG+d/u7XV4x03zpcpU/z0nJn/qLG/3xN2R4dffDBrliedixc9abe0eOIqLfPm+f7Hj3s/lI/QBgYGL4U+e9b/P+VJq7fXJ3Bua/MR4qlTnvgnTfJ92tv9woXx4/31p0/3fr5wwfu6vX2wPe3tV091Nux/VSSCEEI9mQ3WlCotkyb5p/feXv9U39vr809u2+an2zo7vXDOhg2+37lzPpIy8wPhQw/56xw75rO8Hz3qjw8f9oPwihX+ePNmf253tx+4L1/2U2Pz5vlB/fXX/cKBgQFfrlzxbTff7Af5ri4/aF+65NvAq7tNmOCnAE+fvjreG24YjLf0nCysWuU1u0Yhs0QgaRmwFmjBJ7JfU7Fdafty4BLwNTPbXes1IxGEEOqir88T1syZg5M5v/++L6Xt/f1+O2OGJw8zn0S4p2fwO6CJE2HBAt/v4EEfcZRGWZJ/qp8zx7dv2+YjlJMnfXtfnyfQpUs9ae3b54mvp8c//Z8/7yOZJUt8+6JF/uv7UcgkEUhqAf4NfB6fyH4n8BUz+1fZPsuBp/BEsARYa2Y1p6uIRBBCCCOXVfXRxcARM/uPmV0GNgCdFft0Ar9NcytvBzokNbjOXgghhHKNTAQzgPLZRk+kdSPdB0mPS9oladeZM2fq3tAQQiiyRiaCoa7rqjwPNZx9MLN1ZrbIzBZNmzatLo0LIYTgGpkITgC3lD3+BPDOKPYJIYTQQI1MBDuBuZJmS5oAPAJsrNhnI/Co3J3AeTM71cA2hRBCqNCw35KbWb+klcAm/PLR583sgKQn0vbngFfwK4aO4JePfr1R7QkhhDC0hhYVMbNX8IN9+brnyu4b8GQj2xBCCKG2mLw+hBAKLnclJiS
dAY6P8ulTgQ9RDrLpjKV4xlIsMLbiiVia10jimWlmQ152mbtE8GFI2lXtl3V5NJbiGUuxwNiKJ2JpXvWKJ04NhRBCwUUiCCGEgitaIliXdQPqbCzFM5ZigbEVT8TSvOoST6G+IwghhHC1oo0IQgghVIhEEEIIBVeYRCBpmaRDko5IWp11e0ZK0jFJ+yTtkbQrrbtJ0muSDqfbj2bdzmokPS+pW9L+snVV2y/ph6mvDkn6QjatHlqVWJ6RdDL1z5406VJpWzPHcoukzZIOSjog6btpfV77plo8uesfSe2SdkjqSrH8JK2vf9+Y2Zhf8FpHbwO3AhOALmB+1u0aYQzHgKkV634GrE73VwM/zbqdNdp/D7AQ2H+t9gPzUx+1AbNT37VkHcM1YnkG+MEQ+zZ7LNOBhen+jfisgvNz3DfV4sld/+Bl+ien+63Am8CdjeiboowIhjNbWh51Ai+k+y8AX8ywLTWZ2Vbg3YrV1drfCWwws14zO4oXJVx8XRo6DFViqabZYzllaZ5wM7sAHMQnh8pr31SLp5qmjcfce+lha1qMBvRNURLBsGZCa3IGvCrpH5IeT+s+bqlsd7r9WGatG51q7c9rf62UtDedOioN13MTi6RZwAL8k2fu+6YiHshh/0hqkbQH6AZeM7OG9E1REsGwZkJrckvNbCHwAPCkpHuyblAD5bG/fgnMAW4HTgE/T+tzEYukycBLwPfM7H+1dh1iXR7iyWX/mNmAmd2OT9q1WNJtNXYfdSxFSQS5nwnNzN5Jt93Ay/iQ77Sk6QDptju7Fo5Ktfbnrr/M7HR6014BfsXgkLzpY5HUih8015vZH9Pq3PbNUPHkuX8AzOwcsAVYRgP6piiJYDizpTUtSZMk3Vi6D9wP7MdjeCzt9hjw52xaOGrV2r8ReERSm6TZwFxgRwbtG7bSGzP5Et4/0OSxSBLwa+Cgmf2ibFMu+6ZaPHnsH0nTJHWk+xOB+4C3aETfZP3N+HX8Bn45fgXB28DTWbdnhG2/Fb8aoAs4UGo/MAX4K3A43d6UdVtrxPAiPiTvwz+5fLNW+4GnU18dAh7Iuv3DiOV3wD5gb3pDTs9JLHfjpw/2AnvSsjzHfVMtntz1D/AZ4J+pzfuBH6f1de+bKDERQggFV5RTQyGEEKqIRBBCCAUXiSCEEAouEkEIIRRcJIIQQii4SAShsCRtS7ezJK2o82v/aKi/FUIzistHQ+FJuhevTPngCJ7TYmYDNba/Z2aT69G+EBotRgShsCSVKjuuAT6X6tR/PxX6elbSzlSk7Ftp/3tTrfvf4z9OQtKfUiHAA6VigJLWABPT660v/1tyz0raL59f4uGy194i6Q+S3pK0Pv1KNoSGG591A0JoAqspGxGkA/p5M7tDUhvwhqRX076LgdvMy/wCfMPM3k0lAHZKesnMVktaaV4srNKX8cJnnwWmpudsTdsWAJ/G68O8ASwF/l7/cEP4oBgRhHC1+4FHU/nfN/Gf9M9N23aUJQGA70jqArbjBb/mUtvdwIvmBdBOA38D7ih77RPmhdH2ALPqEk0I1xAjghCuJuApM9v0gZX+XcLFisf3AXeZ2SVJW4D2Ybx2Nb1l9weI92e4TmJEEAJcwKc1LNkEfDuVM0bSJ1PV10ofAc6mJPApfBrBkr7S8ytsBR5O30NMw6e9bIpql6G44hNHCF7dsT+d4vkNsBY/LbM7fWF7hqGnAf0L8ISkvXi1x+1l29YBeyXtNrOvlq1/GbgLryRrwCoz+29KJCFkIi4fDSGEgotTQyGEUHCRCEIIoeAiEYQQQsFFIgghhIKLRBBCCAUXiSCEEAouEkEIIRTc/wFqDiBNP/7TuwAAAABJRU5ErkJggg==\n", 1233 | "text/plain": [ 1234 | "
" 1235 | ] 1236 | }, 1237 | "metadata": { 1238 | "needs_background": "light" 1239 | }, 1240 | "output_type": "display_data" 1241 | } 1242 | ], 1243 | "source": [ 1244 | "get_loss_fig_aux(train_loss_data, test_loss_data)" 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "execution_count": null, 1250 | "metadata": {}, 1251 | "outputs": [], 1252 | "source": [] 1253 | } 1254 | ], 1255 | "metadata": { 1256 | "kernelspec": { 1257 | "display_name": "Python 3", 1258 | "language": "python", 1259 | "name": "python3" 1260 | }, 1261 | "language_info": { 1262 | "codemirror_mode": { 1263 | "name": "ipython", 1264 | "version": 3 1265 | }, 1266 | "file_extension": ".py", 1267 | "mimetype": "text/x-python", 1268 | "name": "python", 1269 | "nbconvert_exporter": "python", 1270 | "pygments_lexer": "ipython3", 1271 | "version": "3.7.6" 1272 | } 1273 | }, 1274 | "nbformat": 4, 1275 | "nbformat_minor": 4 1276 | } 1277 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DIEN-DIN 2 | 3 | 本项目使用tensorflow2.0复现阿里兴趣排序模型DIEN与DIN。 4 | 5 | DIN论文链接: https://arxiv.org/pdf/1706.06978.pdf 6 | 7 | DIEN论文链接: https://arxiv.org/pdf/1809.03672.pdf 8 | 9 | 数据集使用阿里数据集测试模型代码, 数据集链接: https://tianchi.aliyun.com/dataset/dataDetail?dataId=56 10 | 11 | # 调用方法: 12 | 13 | ## 0. 简介: 14 | 15 | DIEN的输入特征中主要包含三个部分特征: 用户历史行为序列, 目标商品特征, 用户画像特征。 16 | 用户历史行为序列需包含点击序列与非点击序列。 17 | 请按如下1~2方法处理输入特征。 18 | 19 | ## 1. 
初始化: 20 | 21 | 初始化DIEN时需传入5个参数: 22 | 23 | (注:feature_list中的特征名称,需要与embedding_dict中的特征名称一样) 24 | 25 | - embedding_count_dict:string->int格式,该变量记录需要embedding的各个特征的词典大小,即最大整数索引 + 1; 26 | 27 | - embedding_dim_dict:string->int格式,该变量记录需要embedding各个特征的输出维数,即密集嵌入的尺寸; 28 | 29 | - embedding_features_list:list(string)格式,该变量记录DIEN中user_profile部分所有需要embedding的feature名称; 30 | 31 | - user_behavior_features:list(string)格式,该变量记录DIEN中user_behavior与target_item部分所有需要embedding的feature名称; 32 | 33 | - activation:string格式,默认值"PReLU",该变量控制全连接层激活函数,"PReLU"->PReLU,"Dice"->Dice 34 | 35 | ## 2. 模型调用: 36 | 37 | 模型调用需传入6个参数: 38 | 39 | (注:feature_list中的特征名称,需要与dict中的特征名称一样) 40 | 41 | - user_profile_dict:dict:string->Tensor格式,记录user_profile部分的所有输入特征的训练数据; 42 | 43 | - user_profile_list:list(string)格式,记录user_profile部分的所有特征名称; 44 | 45 | - click_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有点击输入特征的训练数据; 46 | 47 | - noclick_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有未点击输入特征的训练数据; 48 | 49 | - target_item_dict:dict:string->Tensor格式,记录target_item部分输入特征的训练数据; 50 | 51 | - user_behavior_list:list(string)格式,记录user_behavior部分的所有特征名称。 52 | 53 | # 调用演示代码: 54 | 55 | ## DIEN: 56 | 57 | DIEN_train_example.ipynb 58 | 59 | ## DIN: 60 | 61 | DIN_train_example.ipynb 62 | 63 | # 代码: 64 | 65 | - model.py: 定义模型代码 66 | 67 | - layers.py: 自定义层 68 | 69 | - loss.py: 定义Auxiliary Loss用到的NN 70 | 71 | - activations.py: 定义Dice激活函数 72 | 73 | - alibaba_data_reader.py: 输入数据处理函数(代码中使用数据已用spark处理后得到了所需序列数据, 及特征embedding词典数) 74 | -------------------------------------------------------------------------------- /__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/alibaba_data_reader.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/alibaba_data_reader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /activations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Dice(tf.keras.layers.Layer): 4 | def __init__(self): 5 | super(Dice, self).__init__() 6 | self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False) 7 | self.alpha = self.add_weight(shape=(), dtype=tf.float32, 
name='alpha') 8 | 9 | def call(self, x): 10 | x_normed = self.bn(x) 11 | x_p = tf.sigmoid(x_normed) 12 | return self.alpha * (1.0 - x_p) * x + x_p * x 13 | 14 | class dice(tf.keras.layers.Layer): 15 | def __init__(self, feat_dim): 16 | super(dice, self).__init__() 17 | self.feat_dim = feat_dim 18 | self.alphas= tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32) 19 | self.beta = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32) 20 | 21 | self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False) 22 | 23 | def call(self, _x, axis=-1, epsilon=0.000000001): 24 | 25 | reduction_axes = list(range(len(_x.get_shape()))) 26 | del reduction_axes[axis] 27 | broadcast_shape = [1] * len(_x.get_shape()) 28 | broadcast_shape[axis] = self.feat_dim 29 | 30 | mean = tf.reduce_mean(_x, axis=reduction_axes) 31 | brodcast_mean = tf.reshape(mean, broadcast_shape) 32 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 33 | std = tf.sqrt(std) 34 | brodcast_std = tf.reshape(std, broadcast_shape) 35 | 36 | x_normed = self.bn(_x) 37 | x_p = tf.keras.activations.sigmoid(self.beta * x_normed) 38 | 39 | return self.alphas * (1.0 - x_p) * _x + x_p * _x -------------------------------------------------------------------------------- /alibaba_data_reader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | def get_embedding_features_list(): 6 | embedding_features_list = ["cate", "brand", "cms_segid", "cms_group", 7 | "gender", "age", "pvalue", "shopping", 8 | "occupation", "user_class_level"] 9 | return embedding_features_list 10 | 11 | def get_user_behavior_features(): 12 | user_behavior_features = ["cate", "brand"] 13 | return user_behavior_features 14 | 15 | def get_embedding_count(feature, embedding_count): 16 | return embedding_count[feature].values[0] 17 | 18 | def get_embedding_count_dict(embedding_features_list, 
embedding_count): 19 | embedding_count_dict = dict() 20 | for feature in embedding_features_list: 21 | embedding_count_dict[feature] = get_embedding_count(feature, embedding_count) 22 | embedding_count_dict["brand"] = 500000 23 | embedding_count_dict["cate"] = 501578 24 | embedding_count_dict["gender"] = 3 25 | embedding_count_dict["pvalue"] = 10 26 | embedding_count_dict["shopping"] = 4 27 | embedding_count_dict["occupation"] = 5 28 | embedding_count_dict["user_class_level"] = 5 29 | return embedding_count_dict 30 | 31 | def get_embedding_dim_dict(embedding_features_list): 32 | embedding_dim_dict = dict() 33 | for feature in embedding_features_list: 34 | embedding_dim_dict[feature] = 64 35 | return embedding_dim_dict 36 | 37 | def get_data(): 38 | train_data = pd.read_csv("./data/train.csv", sep = "\t") 39 | train_data = train_data.fillna(0) 40 | train_data = train_data[train_data["guide_dien_final_train_data.click_cate"] != 0] 41 | train_data = train_data[train_data["guide_dien_final_train_data.click_brand"] != 0] 42 | test_data = pd.read_csv("./data/test.csv", sep = "\t") 43 | test_data = test_data.fillna(0) 44 | test_data = test_data[test_data["guide_dien_final_train_data.click_cate"] != 0] 45 | test_data = test_data[test_data["guide_dien_final_train_data.click_brand"] != 0] 46 | embedding_count = pd.read_csv("./data/embedding_count.csv") 47 | return train_data, test_data, embedding_count 48 | 49 | def get_normal_data(data, col): 50 | return data[col].values 51 | 52 | def get_sequence_data(data, col): 53 | rst = [] 54 | max_length = 0 55 | for i in data[col].values: 56 | temp = len(list(map(eval,i[1:-1].split(",")))) 57 | if temp > max_length: 58 | max_length = temp 59 | 60 | for i in data[col].values: 61 | temp = list(map(eval,i[1:-1].split(","))) 62 | padding = np.zeros(max_length - len(temp)) 63 | rst.append(list(np.append(np.array(temp), padding))) 64 | return rst 65 | 66 | def get_length(data, col): 67 | rst = [] 68 | for i in data[col].values: 69 | temp = 
len(list(map(eval,i[1:-1].split(",")))) 70 | rst.append(temp) 71 | return rst 72 | 73 | def convert_tensor(data): 74 | return tf.convert_to_tensor(data) 75 | 76 | def get_batch_data(data, min_batch, batch=100): 77 | # batch_data = None 78 | # if min_batch + batch <= len(data): 79 | # batch_data = data.loc[min_batch:min_batch + batch - 1] 80 | # else: 81 | # batch_data = data.loc[min_batch:] 82 | batch_data = data.sample(n=batch) 83 | click = get_normal_data(batch_data, "guide_dien_final_train_data.clk") 84 | #no_click = get_normal_data(batch_data, "guide_dien_final_train_data.nonclk") 85 | #label = [click, no_click] 86 | #label = click 87 | target_cate = get_normal_data(batch_data, "guide_dien_final_train_data.cate_id") 88 | target_brand = get_normal_data(batch_data, "guide_dien_final_train_data.brand") 89 | cms_segid = get_normal_data(batch_data, "guide_dien_final_train_data.cms_segid") 90 | cms_group = get_normal_data(batch_data, "guide_dien_final_train_data.cms_group_id") 91 | gender = get_normal_data(batch_data, "guide_dien_final_train_data.final_gender_code") 92 | age = get_normal_data(batch_data, "guide_dien_final_train_data.age_level") 93 | pvalue = get_normal_data(batch_data, "guide_dien_final_train_data.pvalue_level") 94 | shopping = get_normal_data(batch_data, "guide_dien_final_train_data.shopping_level") 95 | occupation = get_normal_data(batch_data, "guide_dien_final_train_data.occupation") 96 | user_class_level = get_normal_data(batch_data, "guide_dien_final_train_data.new_user_class_level") 97 | hist_brand_behavior_clk = get_sequence_data(batch_data, "guide_dien_final_train_data.click_brand") 98 | hist_cate_behavior_clk = get_sequence_data(batch_data, "guide_dien_final_train_data.click_cate") 99 | hist_brand_behavior_show = get_sequence_data(batch_data, "guide_dien_final_train_data.show_brand") 100 | hist_cate_behavior_show = get_sequence_data(batch_data, "guide_dien_final_train_data.show_cate") 101 | #reshape_len = 
convert_tensor(label).numpy().shape[1] 102 | clk_length = get_length(batch_data, "guide_dien_final_train_data.click_brand") 103 | show_length = get_length(batch_data, "guide_dien_final_train_data.show_brand") 104 | return tf.one_hot(click, 2), convert_tensor(target_cate), convert_tensor(target_brand), convert_tensor(cms_segid), convert_tensor(cms_group), convert_tensor(gender), convert_tensor(age), convert_tensor(pvalue), convert_tensor(shopping), convert_tensor(occupation), convert_tensor(user_class_level), convert_tensor(hist_brand_behavior_clk), convert_tensor(hist_cate_behavior_clk), convert_tensor(hist_brand_behavior_show), convert_tensor(hist_cate_behavior_show), min_batch + batch, clk_length, show_length 105 | 106 | def get_test_data(data): 107 | batch_data = data.head(150) 108 | #batch_data = data.sample(n = 50) 109 | click = get_normal_data(batch_data, "guide_dien_final_train_data.clk") 110 | target_cate = get_normal_data(batch_data, "guide_dien_final_train_data.cate_id") 111 | target_brand = get_normal_data(batch_data, "guide_dien_final_train_data.brand") 112 | cms_segid = get_normal_data(batch_data, "guide_dien_final_train_data.cms_segid") 113 | cms_group = get_normal_data(batch_data, "guide_dien_final_train_data.cms_group_id") 114 | gender = get_normal_data(batch_data, "guide_dien_final_train_data.final_gender_code") 115 | age = get_normal_data(batch_data, "guide_dien_final_train_data.age_level") 116 | pvalue = get_normal_data(batch_data, "guide_dien_final_train_data.pvalue_level") 117 | shopping = get_normal_data(batch_data, "guide_dien_final_train_data.shopping_level") 118 | occupation = get_normal_data(batch_data, "guide_dien_final_train_data.occupation") 119 | user_class_level = get_normal_data(batch_data, "guide_dien_final_train_data.new_user_class_level") 120 | hist_brand_behavior_clk = get_sequence_data(batch_data, "guide_dien_final_train_data.click_brand") 121 | hist_cate_behavior_clk = get_sequence_data(batch_data, 
"guide_dien_final_train_data.click_cate") 122 | hist_brand_behavior_show = get_sequence_data(batch_data, "guide_dien_final_train_data.show_brand") 123 | hist_cate_behavior_show = get_sequence_data(batch_data, "guide_dien_final_train_data.show_cate") 124 | clk_length = get_length(batch_data, "guide_dien_final_train_data.click_brand") 125 | show_length = get_length(batch_data, "guide_dien_final_train_data.show_brand") 126 | return tf.one_hot(click, 2), convert_tensor(target_cate), convert_tensor(target_brand), convert_tensor(cms_segid), convert_tensor(cms_group), convert_tensor(gender), convert_tensor(age), convert_tensor(pvalue), convert_tensor(shopping), convert_tensor(occupation), convert_tensor(user_class_level), convert_tensor(hist_brand_behavior_clk), convert_tensor(hist_cate_behavior_clk), convert_tensor(hist_brand_behavior_show), convert_tensor(hist_cate_behavior_show), clk_length, show_length -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | from activations import Dice,dice 4 | 5 | class GRU_GATES(tf.keras.layers.Layer): 6 | def __init__(self, units): 7 | super(GRU_GATES, self).__init__() 8 | self.linear_act = layers.Dense(units, activation=None, use_bias=True) 9 | self.linear_noact = layers.Dense(units, activation=None, use_bias=False) 10 | 11 | def call(self, a, b, gate_b=None): 12 | if gate_b is None: 13 | return tf.keras.activations.sigmoid(self.linear_act(a) + self.linear_noact(b)) 14 | else: 15 | return tf.keras.activations.tanh(self.linear_act(a) + tf.math.multiply(gate_b, self.linear_noact(b))) 16 | 17 | class AUGRU(layers.Layer): 18 | def __init__(self, units): 19 | super(AUGRU, self).__init__() 20 | self.u_gate = GRU_GATES(units) 21 | self.r_gate = GRU_GATES(units) 22 | self.c_memo = GRU_GATES(units) 23 | 24 | def call(self, inputs, state, att_score): 25 
| u = self.u_gate(inputs, state) #u_t 26 | r = self.r_gate(inputs, state) #r_t 27 | c = self.c_memo(inputs, state, r) #\tilde{h_t} 28 | u_= att_score * u #\tilde{u_{t}'} [AUGRU Add] 29 | state_next = (1 - u_) * state + u_ * c #h_t [AUGRU change u_t on output] 30 | return state_next 31 | 32 | class attention(tf.keras.layers.Layer): 33 | def __init__(self, keys_dim): 34 | super(attention, self).__init__() 35 | self.keys_dim = keys_dim 36 | self.fc = tf.keras.Sequential() 37 | self.fc.add(layers.BatchNormalization()) 38 | self.fc.add(layers.Dense(36, activation="sigmoid")) 39 | self.fc.add(dice(36)) 40 | self.fc.add(layers.Dense(1, activation=None)) 41 | 42 | def call(self, queries, keys, keys_length): 43 | #Attention 44 | queries = tf.tile(tf.expand_dims(queries, 1), [1, tf.shape(keys)[1], 1]) 45 | din_all = tf.concat([queries, keys, queries-keys, queries*keys], axis=-1) 46 | outputs = tf.transpose(self.fc(din_all), [0,2,1]) 47 | key_masks = tf.sequence_mask(keys_length, max(keys_length), dtype=tf.bool) 48 | key_masks = tf.expand_dims(key_masks, 1) 49 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1) 50 | outputs = tf.where(key_masks, outputs, paddings) 51 | outputs = outputs / (self.keys_dim ** 0.5) 52 | #outputs = tf.keras.activations.softmax(outputs, -1) 53 | outputs = tf.keras.activations.sigmoid(outputs) 54 | 55 | #Sum Pooling 56 | outputs = tf.squeeze(tf.matmul(outputs, keys)) 57 | print("outputs:" + str(outputs.numpy().shape)) 58 | return outputs -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | 4 | class AuxLayer(layers.Layer): 5 | def __init__(self): 6 | super().__init__() 7 | self.fc = tf.keras.Sequential() 8 | self.fc.add(layers.BatchNormalization()) 9 | self.fc.add(layers.Dense(100, activation="sigmoid")) 10 | self.fc.add(layers.ReLU()) 11 | 
self.fc.add(layers.Dense(50, activation="sigmoid")) 12 | self.fc.add(layers.ReLU()) 13 | self.fc.add(layers.Dense(2, activation=None)) 14 | 15 | def call(self, input): 16 | logit = tf.squeeze(self.fc(input)) 17 | return tf.keras.activations.softmax(logit) 18 | 19 | -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/main.ipynb -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | from layers import AUGRU 4 | from activations import Dice 5 | import pandas as pd 6 | from model import DIEN 7 | import alibaba_data_reader as data_reader 8 | 9 | def train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label, optimizer, model, alpha, loss_metric): 10 | with tf.GradientTape() as tape: 11 | output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list) 12 | target_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,labels=tf.cast(label, dtype=tf.float32))) 13 | final_loss = target_loss + alpha * aux_loss 14 | print("[Train Step] aux_loss=" + str(aux_loss.numpy()) + ", target_loss=" + str(target_loss.numpy()) + ", final_loss=" + str(final_loss.numpy())) 15 | gradient = tape.gradient(final_loss, model.trainable_variables) 16 | clip_gradient, _ = tf.clip_by_global_norm(gradient, 5.0) 17 | optimizer.apply_gradients(zip(clip_gradient, model.trainable_variables)) 18 | loss_metric(final_loss) 19 | 20 | def get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, 
age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show): 21 | user_profile_dict = { 22 | "cms_segid": cms_segid, 23 | "cms_group": cms_group, 24 | "gender": gender, 25 | "age": age, 26 | "pvalue": pvalue, 27 | "shopping": shopping, 28 | "occupation": occupation, 29 | "user_class_level": user_class_level 30 | } 31 | user_profile_list = ["cms_segid", "cms_group", "gender", "age", "pvalue", "shopping", "occupation", "user_class_level"] 32 | user_behavior_list = ["brand", "cate"] 33 | click_behavior_dict = { 34 | "brand": hist_brand_behavior_clk, 35 | "cate": hist_cate_behavior_clk 36 | } 37 | noclick_behavior_dict = { 38 | "brand": hist_brand_behavior_show, 39 | "cate": hist_cate_behavior_show 40 | } 41 | target_item_dict = { 42 | "brand": target_cate, 43 | "cate": target_brand 44 | } 45 | return user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict 46 | 47 | def main(): 48 | train_data, test_data, embedding_count = data_reader.get_data() 49 | embedding_features_list = data_reader.get_embedding_features_list() 50 | user_behavior_features = data_reader.get_user_behavior_features() 51 | embedding_count_dict = data_reader.get_embedding_count_dict(embedding_features_list, embedding_count) 52 | embedding_dim_dict = data_reader.get_embedding_dim_dict(embedding_features_list) 53 | model = DIEN(embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features) 54 | min_batch = 0 55 | batch = 100 56 | label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch) 57 | user_profile_dict, user_profile_list, user_behavior_list, 
click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show) 58 | log_path = "./train_log/" 59 | train_summary_writer = tf.summary.create_file_writer(log_path) 60 | optimizer = tf.keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) 61 | loss_metric = tf.keras.metrics.Sum() 62 | auc_metric = tf.keras.metrics.AUC() 63 | alpha = 1 64 | epochs = 1 65 | for epoch in range(epochs): 66 | min_batch = 0 67 | for i in range(int(len(train_data) / batch)): 68 | label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch) 69 | user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show) 70 | train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label, optimizer, model, alpha, loss_metric) 71 | 72 | 73 | if __name__ == "__main__": 74 | print(tf.__version__) 75 | print("GPU Available: ", tf.test.is_gpu_available()) 76 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | from layers import AUGRU,attention 
4 | from activations import Dice,dice 5 | from loss import AuxLayer 6 | import utils 7 | 8 | class DIEN(tf.keras.Model): 9 | def __init__(self, embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation="PReLU"): 10 | super(DIEN, self).__init__(embedding_count_dict, embedding_dim_dict, embedding_features_list, activation) 11 | """DIEN初始化model函数 12 | 13 | 该函数在调用DIEN时进行DIEN的Embedding层,GRU层,AUGRU层,全连接层的初始化操作 14 | 15 | Args: 16 | embedding_count_dict:string->int格式,该变量记录需要embedding各个特征的词典个数,即最大整数索引+ 1的大小; 17 | embedding_dim_dict:string->int格式,该变量记录需要embedding各个特征的输出维数,即密集嵌入的尺寸; 18 | embedding_features_list:list(string)格式,该变量记录DIEN中user_profile部分所有需要embedding的feature名称; 19 | user_behavior_features:list(string)格式,该变量记录DIEN中user_behavior与target_item部分所有需要embedding的feature名称 20 | activation:string格式,默认值"PReLU",该变量空值全连接层激活函数,”PReLU“->PReLU,"Dice"->Dice 21 | """ 22 | #Init Embedding Layer 23 | self.embedding_dim_dict = embedding_dim_dict 24 | self.embedding_count_dict = embedding_count_dict 25 | self.embedding_layers = dict() 26 | for feature in embedding_features_list: 27 | self.embedding_layers[feature] = layers.Embedding(embedding_count_dict[feature], embedding_dim_dict[feature]) 28 | #Init GRU Layer 29 | self.user_behavior_gru = layers.GRU(self.get_GRU_input_dim(embedding_dim_dict, user_behavior_features), return_sequences=True) 30 | #Init Attention Layer 31 | self.attention_layer = layers.Softmax() 32 | #Init Auxiliary Layer 33 | self.AuxNet = AuxLayer() 34 | #Init AUGRU Layer 35 | self.user_behavior_augru = AUGRU(self.get_GRU_input_dim(embedding_dim_dict, user_behavior_features)) 36 | #Init Fully Connection Layer 37 | self.fc = tf.keras.Sequential() 38 | self.fc.add(layers.BatchNormalization()) 39 | self.fc.add(layers.Dense(200, activation="relu")) 40 | if activation == "Dice": 41 | self.fc.add(Dice()) 42 | elif activation == "dice": 43 | self.fc.add(dice(200)) 44 | elif activation == "PReLU": 45 | 
self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None)) 46 | self.fc.add(layers.Dense(80, activation="relu")) 47 | if activation == "Dice": 48 | self.fc.add(Dice()) 49 | elif activation == "dice": 50 | self.fc.add(dice(80)) 51 | elif activation == "PReLU": 52 | self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None)) 53 | self.fc.add(layers.Dense(2, activation=None)) 54 | 55 | def get_GRU_input_dim(self, embedding_dim_dict, user_behavior_features): 56 | rst = 0 57 | for feature in user_behavior_features: 58 | rst += embedding_dim_dict[feature] 59 | return rst 60 | 61 | def get_emb(self, user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list): 62 | user_profile_feature_embedding = dict() 63 | for feature in user_profile_list: 64 | data = user_profile_dict[feature] 65 | embedding_layer = self.embedding_layers[feature] 66 | user_profile_feature_embedding[feature] = embedding_layer(data) 67 | 68 | target_item_feature_embedding = dict() 69 | for feature in user_behavior_list: 70 | data = target_item_dict[feature] 71 | embedding_layer = self.embedding_layers[feature] 72 | target_item_feature_embedding[feature] = embedding_layer(data) 73 | 74 | click_behavior_embedding = dict() 75 | for feature in user_behavior_list: 76 | data = click_behavior_dict[feature] 77 | embedding_layer = self.embedding_layers[feature] 78 | click_behavior_embedding[feature] = embedding_layer(data) 79 | 80 | # noclick_behavior_embedding = dict() 81 | # for feature in user_behavior_list: 82 | # data = noclick_behavior_dict[feature] 83 | # embedding_layer = self.embedding_layers[feature] 84 | # noclick_behavior_embedding[feature] = embedding_layer(data) 85 | 86 | return utils.concat_features(user_profile_feature_embedding), utils.concat_features(target_item_feature_embedding), utils.concat_features(click_behavior_embedding)#, utils.concat_features(noclick_behavior_embedding) 87 | 88 | def auxiliary_loss(self, 
hidden_states, embedding_out): 89 | """Auxiliary Loss Function 90 | 91 | 论文中包含的源代码aux loss是通过hidden state与点击序列concate和hidden state 92 | 与展现序列concat后进一个全连接神经网络,通过softmax得到最终二分类结果与点击序列和展现序列求解log_loss的到最终aux loss。 93 | 该方法只使用用户的点击序列。 94 | 95 | Args: 96 | hidden_states: gru产出的所有hidden state,从h(0)到h(n-1) 97 | embedding_out: gru输入的embedding特征,从e(1)到e(n) 98 | """ 99 | click_input_ = tf.concat([hidden_states, embedding_out], -1) 100 | click_prop_ = self.AuxNet(click_input_)[:, :, 0] 101 | click_loss_ = - tf.reshape(tf.math.log(click_prop_), [-1, tf.shape(embedding_out)[1]]) 102 | return tf.reduce_mean(click_loss_) 103 | 104 | def call(self, user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list): 105 | """输入batch训练数据, 调用DIEN初始化后的model进行一次前向传播 106 | 107 | 调用该函数进行一次前向传播得到output, logit, aux_loss后,在自定义的训练函数内得出target_loss与final_loss后使用tensorflow中的梯度计算函数通过链式法则得到各层梯度后使用自定义优化器进行一次权重更新 108 | 109 | Args: 110 | user_profile_dict:dict:string->Tensor格式,记录user_profile部分的所有输入特征的训练数据; 111 | user_profile_list:list(string)格式,记录user_profile部分的所有特征名称; 112 | click_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有点击输入特征的训练数据; 113 | noclick_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有未点击输入特征的训练数据; 114 | target_item_dict:dict:string->Tensor格式,记录target_item部分输入特征的训练数据; 115 | user_behavior_list:list(string)Tensor格式,记录user_behavior部分的所有特征名称。 116 | """ 117 | #Embedding Layer 118 | user_profile_embedding, target_item_embedding, click_behavior_emebedding = self.get_emb(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list) 119 | #GRU Layer 120 | click_gru_emb = self.user_behavior_gru(click_behavior_emebedding) 121 | #noclick_gru_emb = self.user_behavior_gru(noclick_behavior_emebedding) 122 | #Auxiliary Loss 123 | aux_loss = self.auxiliary_loss(click_gru_emb[:, :-1, :], click_behavior_emebedding[:, 1:, :]) 124 | #Attention Layer 125 | hist_attn = 
self.attention_layer(tf.matmul(tf.expand_dims(target_item_embedding, 1), click_gru_emb, transpose_b=True)) 126 | #AUGRU Layer 127 | augru_hidden_state = tf.zeros_like(click_gru_emb[:, 0, :]) 128 | for in_emb, in_att in zip(tf.transpose(click_gru_emb, [1, 0, 2]), tf.transpose(hist_attn, [2, 0, 1])): 129 | augru_hidden_state = self.user_behavior_augru(in_emb, augru_hidden_state, in_att) 130 | join_emb = tf.concat([augru_hidden_state, user_profile_embedding], -1) 131 | logit = tf.squeeze(self.fc(join_emb)) 132 | output = tf.keras.activations.softmax(logit) 133 | return output, logit, aux_loss 134 | 135 | class DIN(tf.keras.Model): 136 | def __init__(self, embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation="PReLU"): 137 | super(DIN, self).__init__(embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation) 138 | #Init Embedding Layer 139 | self.embedding_dim_dict = embedding_dim_dict 140 | self.embedding_count_dict = embedding_count_dict 141 | self.embedding_layers = dict() 142 | for feature in embedding_features_list: 143 | self.embedding_layers[feature] = layers.Embedding(embedding_count_dict[feature], embedding_dim_dict[feature]) 144 | #DIN Attention+Sum pooling 145 | self.hist_at = attention(utils.get_input_dim(embedding_dim_dict, user_behavior_features)) 146 | #Init Fully Connection Layer 147 | self.fc = tf.keras.Sequential() 148 | self.fc.add(layers.BatchNormalization()) 149 | self.fc.add(layers.Dense(200, activation="relu")) 150 | if activation == "Dice": 151 | self.fc.add(Dice()) 152 | elif activation == "dice": 153 | self.fc.add(dice(200)) 154 | elif activation == "PReLU": 155 | self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None)) 156 | self.fc.add(layers.Dense(80, activation="relu")) 157 | if activation == "Dice": 158 | self.fc.add(Dice()) 159 | elif activation == "dice": 160 | self.fc.add(dice(80)) 161 | elif activation == "PReLU": 162 | 
self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None)) 163 | self.fc.add(layers.Dense(2, activation=None)) 164 | 165 | def get_emb_din(self, user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list): 166 | user_profile_feature_embedding = dict() 167 | for feature in user_profile_list: 168 | data = user_profile_dict[feature] 169 | embedding_layer = self.embedding_layers[feature] 170 | user_profile_feature_embedding[feature] = embedding_layer(data) 171 | 172 | target_item_feature_embedding = dict() 173 | for feature in user_behavior_list: 174 | data = target_item_dict[feature] 175 | embedding_layer = self.embedding_layers[feature] 176 | target_item_feature_embedding[feature] = embedding_layer(data) 177 | 178 | hist_behavior_embedding = dict() 179 | for feature in user_behavior_list: 180 | data = hist_behavior_dict[feature] 181 | embedding_layer = self.embedding_layers[feature] 182 | hist_behavior_embedding[feature] = embedding_layer(data) 183 | 184 | return utils.concat_features(user_profile_feature_embedding), utils.concat_features(target_item_feature_embedding), utils.concat_features(hist_behavior_embedding) 185 | 186 | def call(self, user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list, length): 187 | #Embedding Layer 188 | user_profile_embedding, target_item_embedding, hist_behavior_emebedding = self.get_emb_din(user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list) 189 | hist_attn_emb = self.hist_at(target_item_embedding, hist_behavior_emebedding, length) 190 | join_emb = tf.concat([user_profile_embedding, target_item_embedding, hist_attn_emb], -1) 191 | logit = tf.squeeze(self.fc(join_emb)) 192 | output = tf.keras.activations.softmax(logit) 193 | return output, logit 194 | 195 | if __name__ == "__main__": 196 | model = DIN(dict(), dict(), list(), list()) 197 | 
-------------------------------------------------------------------------------- /tensorboard.log: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | TensorBoard 2.0.0 at http://10.186.3.226:8028/ (Press CTRL+C to quit) 3 | -------------------------------------------------------------------------------- /tensorboard.sh: -------------------------------------------------------------------------------- 1 | tensorboard --logdir=./train_log/din/ --host=10.186.3.226 --port=8028 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tensorflow as tf 3 | 4 | def get_file_name(): 5 | now_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) 6 | return "loss.csv." + now_time 7 | 8 | def make_train_loss_dir(file_name, cols=["train_aux_loss","train_target_loss","train_final_loss"], model="dien"): 9 | f = open("./loss/" + model + "/train_" + file_name, "w") 10 | f.write(",".join(cols) + "\n") 11 | f.close() 12 | 13 | def make_test_loss_dir(file_name, cols=["test_aux_loss","test_target_loss","test_final_loss"], model="dien"): 14 | f = open("./loss/" + model + "/test_" + file_name, "w") 15 | f.write(",".join(cols) + "\n") 16 | f.close() 17 | 18 | def add_loss(loss_dict, file_name, cols = ["aux_loss", "target_loss", "final_loss"], level="train", model="dien"): 19 | loss_list = list() 20 | for col in cols: 21 | loss_list.append(loss_dict[col]) 22 | f = open("./loss/" + model + "/" + level + "_" + file_name, "a") 23 | f.write(",".join(loss_list) + "\n") 24 | f.close() 25 | 26 | def get_input_dim(embedding_dim_dict, user_behavior_features): 27 | rst = 0 28 | for feature in user_behavior_features: 29 | rst += embedding_dim_dict[feature] 30 | return rst 31 | 32 | def concat_features(feature_data_dict): 33 | concat_list = [] 34 | for k in feature_data_dict: 35 | 
concat_list.append(feature_data_dict[k]) 36 | return tf.concat(concat_list, -1) 37 | 38 | def mkdir(path): 39 | try: 40 | if not os.path.exists(path): 41 | os.makedirs(path) 42 | return 0 43 | except: 44 | return 1 --------------------------------------------------------------------------------