├── DIEN_train_example.ipynb
├── DIN_train_example.ipynb
├── README.md
├── __pycache__
├── activations.cpython-37.pyc
├── alibaba_data_reader.cpython-37.pyc
├── layers.cpython-37.pyc
├── loss.cpython-37.pyc
├── model.cpython-37.pyc
└── utils.cpython-37.pyc
├── activations.py
├── alibaba_data_reader.py
├── layers.py
├── loss.py
├── main.ipynb
├── main.py
├── model.py
├── tensorboard.log
├── tensorboard.sh
└── utils.py
/DIEN_train_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import tensorflow as tf\n",
10 | "from tensorflow.keras import layers\n",
11 | "from layers import AUGRU\n",
12 | "from activations import Dice,dice\n",
13 | "import pandas as pd\n",
14 | "from model import DIEN\n",
15 | "import alibaba_data_reader as data_reader\n",
16 | "import utils\n",
17 | "import matplotlib\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "from matplotlib.font_manager import FontProperties\n",
20 | "from matplotlib.pyplot import MultipleLocator\n",
21 | "import numpy as np\n",
22 | "import os\n",
23 | "from loss import AuxLayer"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "def mkdir(path):\n",
33 | " try:\n",
34 | " if not os.path.exists(path):\n",
35 | " os.makedirs(path)\n",
36 | " return 0\n",
37 | " except:\n",
38 | " return 1\n",
39 | "model_name = \"dien\""
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "def is_in_notebook():\n",
49 | " import sys\n",
50 | " return 'ipykernel' in sys.modules\n",
51 | "def clear_output():\n",
52 | " \"\"\"\n",
53 | " clear output for both jupyter notebook and the console\n",
54 | " \"\"\"\n",
55 | " import os\n",
56 | " os.system('cls' if os.name == 'nt' else 'clear')\n",
57 | " if is_in_notebook():\n",
58 | " from IPython.display import clear_output as clear\n",
59 | " clear()"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "name": "stdout",
69 | "output_type": "stream",
70 | "text": [
71 | "2\n"
72 | ]
73 | }
74 | ],
75 | "source": [
76 | "print(1)\n",
77 | "clear_output()\n",
78 | "print(2)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 5,
84 | "metadata": {
85 | "tags": []
86 | },
87 | "outputs": [
88 | {
89 | "name": "stdout",
90 | "output_type": "stream",
91 | "text": [
92 | "2.0.0\n",
93 | "GPU Available: True\n"
94 | ]
95 | }
96 | ],
97 | "source": [
98 | "print(tf.__version__)\n",
99 | "print(\"GPU Available: \", tf.test.is_gpu_available())"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 6,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "file_path = \"/nfs/project/boweihan_2/DIEN/dien_final/\"\n",
109 | "file_path = \"\""
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "# 模型训练"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 7,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "data": {
126 | "text/html": [
127 | "
\n",
128 | "\n",
141 | "
\n",
142 | " \n",
143 | " \n",
144 | " | \n",
145 | " brand | \n",
146 | " cate | \n",
147 | " cms_segid | \n",
148 | " cms_group | \n",
149 | " gender | \n",
150 | " age | \n",
151 | " pvalue | \n",
152 | " shopping | \n",
153 | " occupation | \n",
154 | " user_class_level | \n",
155 | "
\n",
156 | " \n",
157 | " \n",
158 | " \n",
159 | " 0 | \n",
160 | " 460561 | \n",
161 | " 12968 | \n",
162 | " 97 | \n",
163 | " 13 | \n",
164 | " 2 | \n",
165 | " 7 | \n",
166 | " 3 | \n",
167 | " 3 | \n",
168 | " 2 | \n",
169 | " 4 | \n",
170 | "
\n",
171 | " \n",
172 | "
\n",
173 | "
"
174 | ],
175 | "text/plain": [
176 | " brand cate cms_segid cms_group gender age pvalue shopping \\\n",
177 | "0 460561 12968 97 13 2 7 3 3 \n",
178 | "\n",
179 | " occupation user_class_level \n",
180 | "0 2 4 "
181 | ]
182 | },
183 | "execution_count": 7,
184 | "metadata": {},
185 | "output_type": "execute_result"
186 | }
187 | ],
188 | "source": [
189 | "train_data, test_data, embedding_count = data_reader.get_data()\n",
190 | "embedding_count"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 8,
196 | "metadata": {},
197 | "outputs": [],
198 | "source": [
199 | "embedding_features_list = data_reader.get_embedding_features_list()\n",
200 | "user_behavior_features = data_reader.get_user_behavior_features()\n",
201 | "embedding_count_dict = data_reader.get_embedding_count_dict(embedding_features_list, embedding_count)\n",
202 | "embedding_dim_dict = data_reader.get_embedding_dim_dict(embedding_features_list)"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 9,
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "import time\n",
212 | "stamp = time.strftime(\"%Y%m%d-%H%M%S\", time.localtime())\n",
213 | "mkdir(\"./train_log/\" + model_name)\n",
214 | "log_path = \"./train_log/\"+model_name+\"/%s\" % stamp\n",
215 | "train_summary_writer = tf.summary.create_file_writer(log_path)\n",
216 | "tf.summary.trace_on(graph=True, profiler=True)\n",
217 | "loss_file_name = utils.get_file_name()\n",
218 | "mkdir(\"./loss/\" + model_name + \"/\")\n",
219 | "utils.make_train_loss_dir(loss_file_name, cols=[\"train_aux_loss\",\"train_target_loss\",\"train_final_loss\"], model=model_name)\n",
220 | "utils.make_test_loss_dir(loss_file_name, cols=[\"test_aux_loss\",\"test_target_loss\",\"test_final_loss\"], model=model_name)"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 10,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "data": {
230 | "text/plain": [
231 | ""
232 | ]
233 | },
234 | "execution_count": 10,
235 | "metadata": {},
236 | "output_type": "execute_result"
237 | }
238 | ],
239 | "source": [
240 | "model = DIEN(\n",
241 | " embedding_count_dict, \n",
242 | " embedding_dim_dict, \n",
243 | " embedding_features_list, \n",
244 | " user_behavior_features, \n",
245 | " activation=\"dice\")\n",
246 | "model"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 11,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "min_batch = 0\n",
256 | "batch = 100\n",
257 | "optimizer = tf.keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)\n",
258 | "loss_metric = tf.keras.metrics.Sum()\n",
259 | "auc_metric = tf.keras.metrics.AUC()\n",
260 | "alpha = 1\n",
261 | "epochs = 3"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 12,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 13,
276 | "metadata": {},
277 | "outputs": [],
278 | "source": [
279 | "def get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show):\n",
280 | " user_profile_dict = {\n",
281 | " \"cms_segid\": cms_segid,\n",
282 | " \"cms_group\": cms_group,\n",
283 | " \"gender\": gender,\n",
284 | " \"age\": age,\n",
285 | " \"pvalue\": pvalue,\n",
286 | " \"shopping\": shopping,\n",
287 | " \"occupation\": occupation,\n",
288 | " \"user_class_level\": user_class_level\n",
289 | " }\n",
290 | " user_profile_list = [\"cms_segid\", \"cms_group\", \"gender\", \"age\", \"pvalue\", \"shopping\", \"occupation\", \"user_class_level\"]\n",
291 | " user_behavior_list = [\"brand\", \"cate\"]\n",
292 | " click_behavior_dict = {\n",
293 | " \"brand\": hist_brand_behavior_clk,\n",
294 | " \"cate\": hist_cate_behavior_clk\n",
295 | " }\n",
296 | " noclick_behavior_dict = {\n",
297 | " \"brand\": hist_brand_behavior_show,\n",
298 | " \"cate\": hist_cate_behavior_show\n",
299 | " }\n",
300 | " target_item_dict = {\n",
301 | " \"brand\": target_cate,\n",
302 | " \"cate\": target_brand\n",
303 | " }\n",
304 | " return user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 14,
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show) "
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 15,
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "def train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label):\n",
323 | " with tf.GradientTape() as tape:\n",
324 | " output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)\n",
325 | " target_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,labels=tf.cast(label, dtype=tf.float32)))\n",
326 | " final_loss = target_loss + alpha * aux_loss\n",
327 | " #print(\"[Train Loss] aux_loss=\" + str(aux_loss.numpy()) + \", target_loss=\" + str(target_loss.numpy()) + \", final_loss=\" + str(final_loss.numpy()))\n",
328 | " gradient = tape.gradient(final_loss, model.trainable_variables)\n",
329 | " clip_gradient, _ = tf.clip_by_global_norm(gradient, 5.0)\n",
330 | " optimizer.apply_gradients(zip(clip_gradient, model.trainable_variables))\n",
331 | " loss_metric(final_loss)\n",
332 | " return aux_loss.numpy(), target_loss.numpy(), final_loss.numpy()"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 16,
338 | "metadata": {},
339 | "outputs": [],
340 | "source": [
341 | "def get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label):\n",
342 | " output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)\n",
343 | " target_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,labels=tf.cast(label, dtype=tf.float32)))\n",
344 | " final_loss = target_loss + alpha * aux_loss\n",
345 | " #print(\"[Test Loss] aux_loss=\" + str(aux_loss.numpy()) + \", target_loss=\" + str(target_loss.numpy()) + \", final_loss=\" + str(final_loss.numpy()))\n",
346 | " return aux_loss.numpy(), target_loss.numpy(), final_loss.numpy()"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 18,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "#aux_loss, target_loss, final_loss = train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 17,
361 | "metadata": {},
362 | "outputs": [
363 | {
364 | "name": "stdout",
365 | "output_type": "stream",
366 | "text": [
367 | "WARNING:tensorflow:Layer dien is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
368 | "\n",
369 | "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
370 | "\n",
371 | "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
372 | "\n"
373 | ]
374 | },
375 | {
376 | "data": {
377 | "text/plain": [
378 | "(0.89547175, 0.69206244, 1.5875342)"
379 | ]
380 | },
381 | "execution_count": 17,
382 | "metadata": {},
383 | "output_type": "execute_result"
384 | }
385 | ],
386 | "source": [
387 | "get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 18,
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "Model: \"dien\"\n",
400 | "_________________________________________________________________\n",
401 | "Layer (type) Output Shape Param # \n",
402 | "=================================================================\n",
403 | "embedding_5 (Embedding) multiple 448 \n",
404 | "_________________________________________________________________\n",
405 | "embedding_1 (Embedding) multiple 32000000 \n",
406 | "_________________________________________________________________\n",
407 | "embedding (Embedding) multiple 32100992 \n",
408 | "_________________________________________________________________\n",
409 | "embedding_3 (Embedding) multiple 832 \n",
410 | "_________________________________________________________________\n",
411 | "embedding_2 (Embedding) multiple 6208 \n",
412 | "_________________________________________________________________\n",
413 | "embedding_4 (Embedding) multiple 192 \n",
414 | "_________________________________________________________________\n",
415 | "embedding_8 (Embedding) multiple 320 \n",
416 | "_________________________________________________________________\n",
417 | "embedding_6 (Embedding) multiple 640 \n",
418 | "_________________________________________________________________\n",
419 | "embedding_7 (Embedding) multiple 256 \n",
420 | "_________________________________________________________________\n",
421 | "embedding_9 (Embedding) multiple 320 \n",
422 | "_________________________________________________________________\n",
423 | "gru (GRU) multiple 99072 \n",
424 | "_________________________________________________________________\n",
425 | "softmax (Softmax) multiple 0 \n",
426 | "_________________________________________________________________\n",
427 | "aux_layer (AuxLayer) multiple 31876 \n",
428 | "_________________________________________________________________\n",
429 | "augru (AUGRU) multiple 98688 \n",
430 | "_________________________________________________________________\n",
431 | "sequential_1 (Sequential) multiple 148122 \n",
432 | "=================================================================\n",
433 | "Total params: 64,487,966\n",
434 | "Trainable params: 64,485,614\n",
435 | "Non-trainable params: 2,352\n",
436 | "_________________________________________________________________\n"
437 | ]
438 | }
439 | ],
440 | "source": [
441 | "model.summary()"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": 19,
447 | "metadata": {},
448 | "outputs": [
449 | {
450 | "name": "stdout",
451 | "output_type": "stream",
452 | "text": [
453 | "dien/embedding_5/embeddings:0\n",
454 | "dien/embedding_1/embeddings:0\n",
455 | "dien/embedding/embeddings:0\n",
456 | "dien/embedding_3/embeddings:0\n",
457 | "dien/embedding_2/embeddings:0\n",
458 | "dien/embedding_4/embeddings:0\n",
459 | "dien/embedding_8/embeddings:0\n",
460 | "dien/embedding_6/embeddings:0\n",
461 | "dien/embedding_7/embeddings:0\n",
462 | "dien/embedding_9/embeddings:0\n",
463 | "dien/gru/kernel:0\n",
464 | "dien/gru/recurrent_kernel:0\n",
465 | "dien/gru/bias:0\n",
466 | "dien/aux_layer/sequential/batch_normalization/gamma:0\n",
467 | "dien/aux_layer/sequential/batch_normalization/beta:0\n",
468 | "dien/aux_layer/sequential/dense/kernel:0\n",
469 | "dien/aux_layer/sequential/dense/bias:0\n",
470 | "dien/aux_layer/sequential/dense_1/kernel:0\n",
471 | "dien/aux_layer/sequential/dense_1/bias:0\n",
472 | "dien/aux_layer/sequential/dense_2/kernel:0\n",
473 | "dien/aux_layer/sequential/dense_2/bias:0\n",
474 | "dien/augru/gru_gates/dense_3/kernel:0\n",
475 | "dien/augru/gru_gates/dense_3/bias:0\n",
476 | "dien/augru/gru_gates/dense_4/kernel:0\n",
477 | "dien/augru/gru_gates_1/dense_5/kernel:0\n",
478 | "dien/augru/gru_gates_1/dense_5/bias:0\n",
479 | "dien/augru/gru_gates_1/dense_6/kernel:0\n",
480 | "dien/augru/gru_gates_2/dense_7/kernel:0\n",
481 | "dien/augru/gru_gates_2/dense_7/bias:0\n",
482 | "dien/augru/gru_gates_2/dense_8/kernel:0\n",
483 | "dien/sequential_1/batch_normalization_1/gamma:0\n",
484 | "dien/sequential_1/batch_normalization_1/beta:0\n",
485 | "dien/sequential_1/dense_9/kernel:0\n",
486 | "dien/sequential_1/dense_9/bias:0\n",
487 | "Variable:0\n",
488 | "Variable:0\n",
489 | "dien/sequential_1/dense_10/kernel:0\n",
490 | "dien/sequential_1/dense_10/bias:0\n",
491 | "Variable:0\n",
492 | "Variable:0\n",
493 | "dien/sequential_1/dense_11/kernel:0\n",
494 | "dien/sequential_1/dense_11/bias:0\n"
495 | ]
496 | }
497 | ],
498 | "source": [
499 | "for var in model.trainable_variables:\n",
500 | " print(var.name)"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 20,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "def get_loss_fig(train_loss, test_loss):\n",
510 | " loss_list = [\"aux_loss\", \"final_loss\"]\n",
511 | " color_list = [\"r\", \"b\"]\n",
512 | " plt.figure()\n",
513 | " cnt = 0\n",
514 | " for k in loss_list:\n",
515 | " loss = train_loss[k]\n",
516 | " step = list(np.arange(len(loss)))\n",
517 | " plt.plot(step,loss,color_list[cnt]+\"-\",label=\"train_\" + k, linestyle=\"--\")\n",
518 | " cnt += 1\n",
519 | " cnt = 0\n",
520 | " for k in loss_list:\n",
521 | " loss = test_loss[k]\n",
522 | " step = list(np.arange(len(loss)))\n",
523 | " plt.plot(step,loss,color_list[cnt],label=\"test_\" + k)\n",
524 | " cnt += 1\n",
525 | " plt.title(\"Loss\")\n",
526 | " plt.xlabel('iteration')\n",
527 | " plt.ylabel('loss')\n",
528 | " plt.legend()\n",
529 | " clear_output()\n",
530 | " plt.savefig(\"./loss/\" + model_name + \"/loss.png\")\n",
531 | " clear_output()\n",
532 | " plt.show()"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 21,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "def record_test_loss(test_loss, test_data, step):\n",
542 | " label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, clk_length, show_length = data_reader.get_test_data(test_data)\n",
543 | " user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n",
544 | " aux_loss, target_loss, final_loss = get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n",
545 | " loss_dict = dict()\n",
546 | " loss_dict[\"aux_loss\"] = str(aux_loss)\n",
547 | " loss_dict[\"target_loss\"] = str(target_loss)\n",
548 | " loss_dict[\"final_loss\"] = str(final_loss)\n",
549 | " utils.add_loss(loss_dict, loss_file_name, level=\"test\")\n",
550 | " test_loss[\"aux_loss\"].append(float(aux_loss))\n",
551 | " test_loss[\"target_loss\"].append(float(target_loss))\n",
552 | " test_loss[\"final_loss\"].append(float(final_loss))\n",
553 | " with train_summary_writer.as_default():\n",
554 | " tf.summary.scalar(\"test_aux_loss epoch: \"+str(epoch), aux_loss, step = step)\n",
555 | " tf.summary.scalar(\"test_target_loss epoch: \"+str(epoch), target_loss, step = step)\n",
556 | " tf.summary.scalar(\"test_final_loss epoch: \"+str(epoch), final_loss, step = step)"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": 22,
562 | "metadata": {},
563 | "outputs": [],
564 | "source": [
565 | "mkdir(\"./checkpoint/\" + model_name)\n",
566 | "checkpoint_path = \"./checkpoint/\" + model_name + \"/cp-{epoch:04d}.ckpt\"\n",
567 | "checkpoint_dir = os.path.dirname(checkpoint_path)"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 23,
573 | "metadata": {},
574 | "outputs": [
575 | {
576 | "data": {
577 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOydd3yO5/fHP1cGQZCI2DTEJkbF3rU3NVo1apdqtVWt+qpVRYeiWuqrRmmNmu23rZYaMYoSxOZnE5tYEUGS6/fHJ7f7SfJkep48Gef9ej2vO/e67vOMXOe6zjnXOUprDUEQBCHz4uRoAQRBEATHIopAEAQhkyOKQBAEIZMjikAQBCGTI4pAEAQhkyOKQBAEIZMjikAQBCGTI4pAEOJBKXVeKdXU0XIIgr0RRSAIgpDJEUUgCMlEKTVQKXVaKRWilPqfUqpQ9HGllJqulLqhlLqnlDqklKoYfa61UuqYUuqBUuqyUmqEY9+FIJiIIhCEZKCUegnAFADdABQEcAHA8ujTzQE0AFAagAeAVwDcjj43H8AbWuucACoC2JyKYgtCgrg4WgBBSGf0ALBAa70fAJRSowDcUUr5AHgKICeAsgD2aK2PW9z3FEB5pdRBrfUdAHdSVWpBSACZEQhC8igEzgIAAFrrUHDUX1hrvRnAtwBmAbiulJqrlMoVfWlnAK0BXFBKbVVK1U5luQUhXkQRCELyuALgBWNHKZUDgBeAywCgtZ6pta4GoAJoIvog+vherXUHAPkA/AJgRSrLLQjxIopAEBLGVSnlZrzADryvUqqKUiorgMkA/tVan1dKVVdK1VRKuQJ4CCAcQKRSKotSqodSKrfW+imA+wAiHfaOBCEWoggEIWHWAXhk8aoPYAyA1QCuAvAF8Gr0tbkAfA/a/y+AJqOp0ed6ATivlLoPYDCAnqkkvyAkipLCNIIgCJkbmREIgiBkckQRCIIgZHJEEQiCIGRyRBEIgiBkctLdyuK8efNqHx8fR4shCIKQrti3b98trbW3tXPpThH4+PggMDDQ0WIIgiCkK5RSF+I7J6YhQRCETI4oAkEQhEyOKAJBEIRMTrrzEQiCkDo8ffoUwcHBCA8Pd7QoQjJwc3NDkSJF4OrqmuR77KYIlFILALQFcENrXTGeaxoBmAHAFcAtrXVDe8kjCELyCA4ORs6cOeHj4wOllKPFEZKA1hq3b99GcHAwihcvnuT77Gka+gFAy/hOKqU8AMwG0F5rXQFAVzvKIghCMgkPD4eXl5cogXSEUgpeXl7JnsXZTRForbcBCEngktcArNFaX4y+/oa9ZBEEIWWIEkh/pOQ7c6SzuDQAT6VUgFJqn1Kqd3wXKqUGKaUClVKBN2/eTNHDLl0Chg0DHj1KqbiCIAgZE0cqAhcA1QC0AdACwBilVGlrF2qt52qt/bXW/t7eVhfGJcqOHcA33wAdOqRYXkEQhAyJIxVBMIC/tNYPtda3AGwDUNleD3v1VaBwYeDvv4Hff7fXUwRBsCV3797F7Nmzk31f69atcffuXTtIlHICAgLQtm1bR4thFUcqgl8B1FdKuSilsgOoCeC4vR6mFPDXX9x27w48eWKvJwmCYCviUwSRkQlX+ly3bh08PDzsJVaGw26KQCm1DMAuAGWUUsFKqf5KqcFKqcEAoLU+DuAvAIcA7AEwT2t9xF7yAEDFisCgQUBoKPDKK/Z8kiBkQBo1ivsyOumwMOvnf/iB52/dinsuCXz00Uc4c+YMqlSpgurVq6Nx48Z47bXX4OfnBwDo2LEjqlWrhgoVKmDu3LnP7vPx8cGtW7dw/vx5lCtXDgMHDkSFChXQvHlzPErAUfj999+jevXqqFy5Mjp37oywsDAAQJ8+fbBq1apn17m7uwMA1q5di6ZNm0JrjatXr6J06dK4du1aou8rJCQEHTt2RKVKlVCrVi0cOnQIALB161ZUqVIFVapUQdWqVfHgwQNc
vXoVDRo0QJUqVVCxYkVs3749SZ9dcrBn1FB3rXVBrbWr1rqI1nq+1nqO1nqOxTVfaq3La60raq1n2EsWS2bPBry9gV9/BY4dS40nCoKQUj777DP4+voiKCgIX375Jfbs2YNJkybhWPQ/74IFC7Bv3z4EBgZi5syZuH37dpw2Tp06haFDh+Lo0aPw8PDA6tWr433eyy+/jL179+LgwYMoV64c5s+fn6B8nTp1QoECBTBr1iwMHDgQEyZMQIECBRJ9X+PGjUPVqlVx6NAhTJ48Gb17M1Zm6tSpmDVrFoKCgrB9+3Zky5YNS5cuRYsWLRAUFISDBw+iSpUqibafXDLdymInJ2DTJqBhQ2DAAGD7dsDZ2dFSCUI6ICAg/nPZsyd8Pm/ehM8nkRo1asRYKDVz5kysXbsWAHDp0iWcOnUKXl5eMe4pXrz4s86zWrVqOH/+fLztHzlyBB9//DHu3r2L0NBQtGjRIlGZvvnmG1SsWBG1atVC9+7dk/Q+duzY8UwhvfTSS7h9+zbu3buHunXrYvjw4ejRowdefvllFClSBNWrV0e/fv3w9OlTdOzY0S6KIFPmGvLzA77+Gti1CxgxwtHSCIKQVHLkyPHs74CAAGzcuBG7du3CwYMHUbVqVasLqbJmzfrsb2dnZ0RERMTbfp8+ffDtt9/i8OHDGDdu3LP2XFxcEBUVBYCrd59YOBkvX74MJycnXL9+/dk1iaG1jnNMKYWPPvoI8+bNw6NHj1CrVi2cOHECDRo0wLZt21C4cGH06tULixcvTtIzkkOmVAQA0LMnUKgQMGMGsHOno6URBMEaOXPmxIMHD6yeu3fvHjw9PZE9e3acOHECu3fvfu7nPXjwAAULFsTTp0+xZMmSZ8d9fHywb98+AMCvv/6Kp0+fAgAiIiLQt29fLF26FOXKlcO0adOS9JwGDRo8az8gIAB58+ZFrly5cObMGfj5+WHkyJHw9/fHiRMncOHCBeTLlw8DBw5E//79sX///ud+n7HJdKYhA6WAZctoImrfHrh+XUxEgpDW8PLyQt26dVGxYkVky5YN+fPnf3auZcuWmDNnDipVqoQyZcqgVq1az/28iRMnombNmnjhhRfg5+f3TAkNHDgQHTp0QI0aNdCkSZNnM5PJkyejfv36qF+//jOHdps2bVCuXLkEnzN+/Hj07dsXlSpVQvbs2bFo0SIAwIwZM7BlyxY4OzujfPnyaNWqFZYvX44vv/wSrq6ucHd3t8uMQFmboqRl/P39tS0rlHXrBqxcCfTqBdjh8xWEdMvx48cT7dCEtIm1704ptU9r7W/t+kxrGjJYuhTw8AB+/JGrjwVBEDIbmV4RuLgwlNTVFXjzTSDa9CcIQgZm6NChz+L1jdfChQtt0vb69evjtN2pUyebtG0vMq2PwJIGDegv6NIF+PRTYMIER0skCII9mTVrlt3abtGiRZLCTtMSmX5GYNC5M/0FEycCa9Y4WhpBEITUQxSBBWPGcNuzJ9NQCIIgZAZEEVhQsSLwwQesWdC6taOlEQRBSB1EEcTis88AHx+mnrCR70gQBCFNI4ogFkqxZoGTE6OIQhIqtikIgl1J7XoEJ06ceJb588yZM6hTp06y2zCInbE0No0aNYIt10Q9D6IIrFCyJPDf/wKRkUxbnc7W3AlChiG16xH88ssv6NChAw4cOABfX1/szCT5Z0QRxMOAAcCkScDq1QwpFYTMjgPKEaRqPYJ169ZhxowZmDdvHho3bgzArDsQEBCARo0aoUuXLihbtix69OjxLHHcJ598gurVq6NixYoYNGiQ1YRyibFs2TL4+fmhYsWKGDlyJAAquz59+qBixYrw8/PD9OnTATDjavny5VGpUiW8+uqryX6WNWQdQQK8/z6T0o0dCzRuDNSr52iJBCFz8dlnn+HIkSMICgpCQEAA2rRpgyNHjjxLRb1gwQLkyZMHjx49QvXq1dG5c+c4aahPnTqFZcuW4fvvv0e3bt2wevVq9OzZM86zWrdu
jcGDB8Pd3R0jrKQlPnDgAI4ePYpChQqhbt26+Oeff1CvXj289dZbGDt2LACgV69e+P3339GuXbskv8crV65g5MiR2LdvHzw9PdG8eXP88ssvKFq0KC5fvowjR1ivyzB1ffbZZzh37hyyZs1qs3KcdlMESqkFANoCuKG1rpjAddUB7AbwitY6foOaA3By4oimeXOgVSvg6lUgeoAgCJmONFCOwO71CBJ7dpEiRQAAVapUwfnz51GvXj1s2bIFX3zxBcLCwhASEoIKFSokSxHs3bsXjRo1gre3NwCgR48e2LZtG8aMGYOzZ8/i7bffRps2bdC8eXMAQKVKldCjRw907NgRHTt2TNF7iY09TUM/AGiZ0AVKKWcAnwNYb0c5notmzYC33+a6gqROZwVBsA/2rkeQENbaCQ8Px5tvvolVq1bh8OHDGDhwoFUZEiI+U5KnpycOHjyIRo0aYdasWRgwYAAA4I8//sDQoUOxb98+VKtWLcXvxxJ7lqrcBiCxmJu3AawGcMNectiCr78GypcH9u0DpkxxtDSCkHlI7XoEycXo9PPmzYvQ0NAEo4Tio2bNmti6dStu3bqFyMhILFu2DA0bNsStW7cQFRWFzp07Y+LEidi/fz+ioqJw6dIlNG7cGF988cWzSmrPi8N8BEqpwgA6AXgJQPVErh0EYBAAFCtWzP7CxXk+p7UlS1IR9OgBOEAMQch0pHY9guTi4eGBgQMHws/PDz4+PqhePcGuzCoFCxbElClT0LhxY2it0bp1a3To0AEHDx5E3759n1U9mzJlCiIjI9GzZ0/cu3cPWmu89957KYqOio1d6xEopXwA/G7NR6CUWgngK631bqXUD9HXJapObV2PIDmcOQNUrQpUqkTF4CKudiEDI/UI0i/pqR6BP4DlSqnzALoAmK2Uso3nw074+jKK6J9/gN69HS2NIAiCbXCYItBaF9da+2itfQCsAvCm1voXR8mTVF57jYVsli0zY6QFQUhf2LMeQWw6deoU51nr16et+Bh7ho8uA9AIQF6lVDCAcQBcAUBrPcdez7U3bm7A1q00EfXvD5QtCzjANCkIwnNgz3oEsTHCW9MydlMEWuvuybi2j73ksAeVKnFG8MorQNOmXDXp5uZoqQRBEFKGpJhIId26cX3Bw4dMRyEIgpBekbiX52DGDCammz0baNsWsFHaD0EQhFRFZgTPgZMTlUGdOkCfPsDhw46WSBAEIfmIInhOXF2BgQOBx4+Bhg2lxKUg2JKU1iMAgBkzZiAsLMzGEiWd8ePHY+rUqQ57fnIQRWAD+vSh4/jOHaBFC6lfIAi2Ij0rgvSE+AhsxE8/Abt3Azt3Ah9/zFoGgpBhePddICjItm1WqULbagJY1iNo1qwZ8uXLhxUrVuDx48fo1KkTJkyYgIcPH6Jbt24IDg5GZGQkxowZg+vXr+PKlSto3Lgx8ubNiy1btlhtf8iQIdi7dy8ePXqELl26YMKECQBYzyAwMBB58+ZFYGAgRowYgYCAAAwbNgx58+bF2LFjsX79ekyaNAkBAQFwckp4TB0UFITBgwcjLCwMvr6+WLBgATw9PTFz5kzMmTMHLi4uKF++PJYvX46tW7finXfeAQAopbBt2zbkzJkzBR9w0hFFYCNcXJh2onRpYPJk1i9o2tTRUglC+sayHsGGDRuwatUq7NmzB1prtG/fHtu2bcPNmzdRqFAh/PHHHwCYjC537tyYNm0atmzZgrx588bb/qRJk5AnTx5ERkaiSZMmOHToECpVqpSgPNWrV0f9+vUxbNgwrFu3LlElAAC9e/fGN998g4YNG2Ls2LGYMGECZsyYYbW2wNSpUzFr1izUrVsXoaGhcEuF2HRRBDbEx4frC0aMYARRYCCPCUK6J5GRe2qwYcMGbNiwAVWrVgUAhIaG4tSpU6hfvz5GjBiBkSNHom3btqhfv36S21yxYgXmzp2LiIgIXL16FceOHUtQEWTPnh3ff/89GjRogOnTp8PX1zfRZ9y7
dw93795Fw4YNAQCvv/46unbtCsB6bYG6deti+PDh6NGjB15++eVnNRDsifgIbEznzsCGDUBEBNCyJUv4CYLw/GitMWrUKAQFBSEoKAinT59G//79Ubp0aezbtw9+fn4YNWoUPvnkkyS1d+7cOUydOhWbNm3CoUOH0KZNm2dppV1cXJ5l/YxdX+Dw4cPw8vLClStXnvs9Wast8NFHH2HevHl49OgRatWqhRMnTjz3cxJDFIEdKFWKi81OngQ6dRLnsSCkFMt6BC1atMCCBQue5d+/fPkybty4gStXriB79uzo2bMnRowYgf3798e51xr3799Hjhw5kDt3bly/fh1//vnns3M+Pj7Yt28fAGD16tXPjl+4cAFfffUVDhw4gD///BP//vtvou8hd+7c8PT0xPbt2wEAP/74Ixo2bBhvbYEzZ87Az88PI0eOhL+/f6ooAjEN2YkRIzib3rABGDcOSOIgRRAECyzrEbRq1QqvvfYaateuDYCF5X/66SecPn0aH3zwAZycnODq6orvvvsOADBo0CC0atUKBQsWtOosrly5MqpWrYoKFSqgRIkSqFu37rNz48aNQ//+/TF58mTUrFkTAGck/fv3x9SpU1GoUCHMnz8fffr0wd69exO14y9atOiZs7hEiRJYuHBhvLUFxowZgy1btsDZ2Rnly5dHq1atbPVxxotd6xHYA0fWI0gu69YB7dtz9fHatYCNyosKQqog9QjSL+mpHkGGp3Vr4O+/+Xf37sDp046VRxAEwRpiGrIzjRsDb74JLF4MtGsH7NrFegaCIKQeNWvWxOPHj2Mc+/HHH+Hn5/fcbU+aNAkrV66Mcaxr164YPXr0c7edWohpKJXYupXrCho1Av78U8pcCmkfMQ2lX8Q0lEZp2JDmoY0bgaFDHS2NIAiCid0UgVJqgVLqhlLqSDzneyilDkW/diqlKttLlrTCsGHMWDp3LhAd2CAIguBw7Dkj+AFAywTOnwPQUGtdCcBEAHPtKEuawN8f+OIL/j10KGcHgiAIjsZuikBrvQ1ASALnd2qt70Tv7gZg/3XUaYDhw4GRI7nIrH174NQpR0skCEJmJ634CPoD+DO+k0qpQUqpQKVU4M2bN1NRLNujFDBlClccOzkxkig615QgCLGwdxrqlStXoly5cmjcuDECAwMxbNiwFD0L4GrkW7duxXve3d09xW3bG4crAqVUY1ARjIzvGq31XK21v9ba39vbO/WEsxNKAatXc8HZ2bOsfxwR4WipBCHtYW9FMH/+fMyePRtbtmyBv78/Zs6cmaJnpXccGsSolKoEYB6AVlrr246UJbVRiiUu+/al8/jdd4Fvv3W0VIJgHQeVI7BrPYJPPvkEO3bswLlz59C+fXu0adMGU6dOxe+//47x48fj4sWLOHv2LC5evIh333332WyhY8eOuHTpEsLDw/HOO+9g0KBByXrfWmt8+OGH+PPPP6GUwscff4xXXnkFV69exSuvvIL79+8jIiIC3333HerUqYP+/fsjMDAQSin069cP7733XrKelxQcpgiUUsUArAHQS2v9f46Sw5HcuAFE57XCrFlMVhddj0IQBNi3HsHYsWOxefNmTJ06Ff7+/ggICIhx/sSJE9iyZQsePHiAMmXKYMiQIXB1dcWCBQuQJ08ePHr0CNWrV0fnzp3h5eWV5Pe0Zs0aBAUF4eDBg7h16xaqV6+OBg0aYOnSpWjRogVGjx6NyMhIhIWFISgoCJcvX8aRIwy+vGsnO7LdFIFSahmARgDyKqWCAYwD4AoAWus5AMYC8AIwWykFABHxLXbIqBQqBOzdy0iijz4C3nsPKFOG6asFIS2RBsoR2KUeQUK0adMGWbNmRdasWZEvXz5cv34dRYoUwcyZM7F27VoAwKVLl3Dq1KlkKYIdO3age/fucHZ2Rv78+dGwYUPs3bsX1atXR79+/fD06VN07NgRVapUQYkSJXD27Fm8/fbbaNOmDZo3b26T9xYbe0YNdddaF9Rau2qt
i2it52ut50QrAWitB2itPbXWVaJfmUoJGCjFKKIvv2QkUadOwPHjjpZKENIetq5HkBhZs2Z99rezszMiIiIQEBCAjRs3YteuXTh48CCqVq0ap15BUt6HNRo0aIBt27ahcOHC6NWrFxYvXgxPT08cPHgQjRo1wqxZszBgwIDnek/x4XBnsUBGjODK46xZgbZtgQSCDwQh02DPegQp4d69e/D09ET27Nlx4sQJ7N69O9ltNGjQAD///DMiIyNx8+ZNbNu2DTVq1MCFCxeQL18+DBw4EP3798f+/ftx69YtREVFoXPnzpg4ceKz92ZrJONNGmLpUmD3buYjevllZi61GJQIQqbDnvUIUkLLli0xZ84cVKpUCWXKlEGtWrWS3UanTp2wa9cuVK5cGUopfPHFFyhQoAAWLVqEL7/8Eq6urnB3d8fixYtx+fJl9O3b91m1tClTptjkfcRGks6lQfr0ARYtAl5/HVi4kOYjQUhtJOlc+kWSzmUAihbldtEiMyWFIAiCvRDTUBpk5Ejg99+BgwcZTVS6NJ3IgiCkDHvWI7Dk9u3baNKkSZzjmzZtSlZkUWojiiAN4u4ObNsGvPgiEBwM9OwJbN/OfUFITbTWUBnANpmUIvO2wMvLC0G2XnmXTFJi7hfTUBolZ05g2jTg6VNWNGvXDrh82dFSCZkJNzc33L59O0Udi+AYtNa4ffs23NzcknWfzAjSMG3bAlu2ALlzA3XrAh06cKaQPbujJRMyA0WKFEFwcDDSe6LHzIabmxuKFEleMmdRBGkYpYD69YGoKK4xmDcP6N0bWLGCmUsFwZ64urqiePHijhZDSAWkO0kHKAWcOQNkycKspWPGOFoiQRAyEqII0gFKAd9/Dzg703cweTKweLGjpRIEIaMgiiCdUKIEsGkT4OpKH8HAgcCOHY6WShCEjIAognRErVrA8uWMJPL25tqCs2cdLZUgCOkdUQTpjGbNgIsXGU0UGcnIonv3HC2VIAjpGVEE6ZACBVjEZuxY4NQp4JVXpNSlIAgpRxRBOmX/fhaycXYG1q8Hhg93tESCIKRXRBGkU158Efj0U+Cll7j/zTdMUicIgpBc7KYIlFILlFI3lFJH4jmvlFIzlVKnlVKHlFKSSSeZjB7N5HS9enF/wADOFARBEJKDPWcEPwBIqPpuKwClol+DAHxnR1kyLE5OnAksXkzfQadOUt1MEITkYc+axdsAhCRwSQcAizXZDcBDKVXQXvJkZJTirGDtWuD6daBrV3EeC4KQdBzpIygM4JLFfnD0sTgopQYppQKVUoGSACt+ypdngrqAAOA//3G0NIIgpBccqQisJTm3mu9Waz1Xa+2vtfb39va2s1jpl+zZgZYtARcX4MsvmZxOEAQhMRypCIIBFLXYLwLgioNkyTC89RbNQiVKAP37c52BIAhCQjhSEfwPQO/o6KFaAO5pra86UJ4Mgb8/UK0acPs2HcndugHh4Y6WShCEtIw9w0eXAdgFoIxSKlgp1V8pNVgpNTj6knUAzgI4DeB7AG/aS5bMhFLMR1SgAHMTBQUBH3zgaKkEQUjL2K0wjda6eyLnNYCh9np+ZqZkSeDwYWYqHT4cmD6dC886dXK0ZIIgpEVkZXEGxdWV244dgapVgX79gPPnHSqSIAhpFFEEGZibN4FWregwfvCA5S5lfYEgCLERRZCB8fYGfvqJDuTISGD3buCLLxwtlSAIaQ1RBBmcTp2YjyhnTuCFF4Dx44GDBx0tlSAIaQlRBJmAHDmYgqJqVSBPHv79+LGjpRIEIa2QuRRBVJSjJXAYM2cyF9G8eYwomjDB0RIJgpBWyDyKYPNmoHRp4IjVrNgZHmdnbnPlAmrWBD7/XFJWC4JAMo8iyJ0bOHMGaNIkU88MliwB/v0X8PAA3niDTmRBEDI3mUcRVKsGdO4M3LhhVnLJhHz1FVCoEPViYCAwe7ajJRIEwdFkHkUAAD//DHh5AUuXAn/95WhpHIK7O8tanjvH8NLRo4HL
lx0tlSAIjiRzKQJnZ2DDBibk6dwZCAtztEQO4eWXgUmTuODsyRNg2DBHSyQIgiPJXIoAYNX399+nEhg71tHSOIyBA7m+4D//AdasATZudLREgiA4CsXcb4lcpNQ7ABYCeABgHoCqAD7SWm+wr3hx8ff314GBgc/f0BtvAN9/D2zZAjRs+PztpVMePwbKleOCs/37zegiQRAyFkqpfVprf2vnkjoj6Ke1vg+gOQBvAH0BfGYj+RzDtGlA0aJMxpNJjeQXLgCTJwMffggcOgR8+62jJRIEwREkVREYZSVbA1iotT4I66Um0w85cgAffQQ8egQ0auRoaRzClSvAJ58Al6IrR48fDzx86FCRBEFwAElVBPuUUhtARbBeKZUTQPoPxh8yBKhbFzh9mon7Mxk1atAkNHky9+/eBaZOdaxMgiCkPklVBP0BfASgutY6DIAraB5KEKVUS6XUSaXUaaXUR1bO51ZK/aaUOqiUOqqUSrRNm7N+PWMqp08Htm5N9cc7EmdnwNPT3M+Th9lJr11znEyCIKQ+SVUEtQGc1FrfVUr1BPAxgHsJ3aCUcgYwC0ArAOUBdFdKlY912VAAx7TWlQE0AvCVUipLMuR/fnLkANat499dutBUlIn48UegTh1g8GAgNJTOY0lVLQiZi6Qqgu8AhCmlKgP4EMAFAIsTuacGgNNa67Na6ycAlgPoEOsaDSCnUkoBcAcQAiD1S6fUrw98/TVw6xZDSzMRDRoA//wDtGnDNQWNGwNz5sisQBAyE0lVBBHRNYY7APhaa/01gJyJ3FMYwCWL/eDoY5Z8C6AcgCsADgN4R2sdx/eglBqklApUSgXevHkziSInk2HDWOX9u++YkS2T0aYNsGMHU048fgx8+aWjJRIEIbVIqiJ4oJQaBaAXgD+izT6uidxjLaoo9qKFFgCCABQCUAXAt0qpXHFu0nqu1tpfa+3v7e2dRJFTwKRJTMIzahSwZ4/9npMGUYp+81KlgJ49qQ+vX3e0VIIgpAZJVQSvAHgMrie4Bo7sExszBgMoarFfBBz5W9IXwBpNTgM4B6BsEmWyPa6uwLJlgNZAs2ZAeLjDRHEUnTvTiRweDsya5WhpBEFIDZKkCKI7/yUAciul2gII11on5iPYCxZJ48YAACAASURBVKCUUqp4tAP4VQD/i3XNRQBNAEAplR9AGQBnkyG/7WnViquO799nyupMxoMHwIEDQLt2NBNlMt+5IGRKkqQIlFLdAOwB0BVANwD/KqW6JHSP1joCwFsA1gM4DmCF1vqoUmqwUmpw9GUTAdRRSh0GsAnASK31rZS9FRvy3Xe0kezcmelKedWsyVXGQ4YAt28zqkgQhIxNUnMNHQTQTGt9I3rfG8DG6LDPVMVmuYYS49YtoGJFDpF37wb8/Oz/zDTAH38AbdsCAQFcYxcWBhw9CjhlvvSEgpChsEWuISdDCURzOxn3pk/y5gWCgug8fvllDo8zATVqcPvvv4ykPXGCa+4EQci4JLUz/0sptV4p1Ucp1QfAHwDW2U+sNEKBAsBPP7HEZe3amaLEpbc38PrrwAsvcH2dtzcwd66jpRIEwZ4kyTQEAEqpzgDqgmGh27TWa+0pWHykmmnIkjp1gF27GFeZSYzmjx+zhs/WrcCMGUxMV7Cgo6USBCGl2MI0BK31aq31cK31e45SAg5j0yYm5fnpJ+CHHxwtTaowahTQvj0L1kRGAgsXOloiQRDsRYKKQCn1QCl138rrgVLqfmoJ6XCyZQO2b2eA/YABDKvJwFy7xsL2o0cDBw8CJUuyhk8msIwJQqYkQUWgtc6ptc5l5ZVTax1nBXCGpkIFYN48/t2pE3M2Z1AKFAC2bQM+/RSoXh14+hQ4f17KWQpCRiVjR/7Ymj59aDi/dAl49VUgIvXz46U2vXoBXl4MnlqyxNHSCIJgD0QRJJeXXgK++YYxlR1iJ1PNeAwdCuzbx9QTa9bISmNByIiIIkgJr7/OEJp162hIz8AYC8l69GC9
gt9/d6w8giDYHlEEKcHNjcPkbNlY53HlSkdLZFf69mXWjYIFxTwkCBkRUQQppWBBhpU6OQHduwP79ztaIrvx9ClXGnfvzknQnTuOlkgQBFsiiuB5qF2bkURRUQy6v3Ej8XvSIeXK0T/esSOVwurVjpZIEARbIorgeenbl3GVISHM1vbwoaMlsjnlynHr5gb4+gKrVjlWHkEQbIsoAlvw0kssaBMYCNSrx6W4GQgj8WpgIPMPbdpEvScIQsZAFIGtaNsWKF2aGUtbtGCVswxCyZJMs1S4MNC1K5dP/Pqro6USBMFWiCKwFc7OHDIXK8Yh85tvOloim6EUc+21bw+8+CLg4yPmIUHISNhVESilWiqlTiqlTiulPornmkZKqSCl1FGl1FZ7ymN33N05I3B3B+bMASZOdLRENuXmTSA4mOahv//O0Fk2BCFTYTdFoJRyBjALQCsA5QF0V0qVj3WNB4DZANprrSuApTDTN56eTFmdNSswblyGGTprDZQty8qdXbsyeuh/sStQC4KQLrHnjKAGgNNa67Na6ycAlgOInZPhNQBrtNYXASBWFbT0S8WKHDrXqQO89hrw11+Olui5UYrRsv/8w0R0xYpl+HV0gpBpsKciKAzgksV+cPQxS0oD8FRKBSil9imleltrSCk1SCkVqJQKvHnzpp3EtTF58zIfQ7FiQJs2TFaXzqlbl6UrQ0JoHtqwAbh3z9FSCYLwvNhTESgrx2KH0rgAqAagDYAWAMYopUrHuUnruVprf621v7e3t+0ltRceHkxBERVFZZDaldVsTL163NavD5QpAzx5Avz2m2NlEgTh+bGnIggGUNRivwiAK1au+Utr/VBrfQvANgCV7ShT6tOtG/D554y5rFuXlV7SKTVr8u306MFspIULZxgXiCBkauypCPYCKKWUKq6UygLgVQCx3Yu/AqivlHJRSmUHUBPAcTvK5Bg+/BAYM4ZD6Jo1022FsyxZgJ9/ZsJVLy+ah/76C3jwwNGSCYLwPNhNEWitIwC8BWA92Lmv0FofVUoNVkoNjr7mOIC/ABwCsAfAPK31EXvJ5FA++YTKQCkuODt50tESpZjNm4FFi6gIHj+W1NSCkN5ROp2tgPX399eB6dnWfvQo0Lgxs5Zu3MgIo3RG375mobYiRRhNJInoBCFto5Tap7X2t3ZOVhanNhUqMJfzzZtArVrAqVOOlijZFC8OXLlCS1fnznw7oaGOlkoQhJQiisAR+PuzytnDh0CNGunOTFS8OLcXLnBxWXg4lYEgCOkTUQSO4vvvGY959y4VQ1CQoyVKMoYiOHeOgVD588viMkFIz4gicBTOzvSyNm5Mu0qdOkxNkQ7w8eH2/Hm+jZdf5owgA5ZiEIRMgSgCR5I7NzOVfv45S182a8b9NE6hQsCZM8CgQdzv2hUIC8sQmTQEIVMiisDRKMV1Bv/8w3QULVqk+RAcJyegRAluAa409vYW85AgpFdEEaQVChTgOoPISAboz5njaIkS5O+/geHD+beLC81Dv/8OPHrkWLkEQUg+ogjSEt2704msFDBkCDBqVJqtdHbwIDB9ulmysksX+gjEPCQI6Q9RBGmNAQOYyc3ZGfjsM6BXrzRZA7lCBW6PHuW2USMmXP35Z4eJJAhCChFFkBZp04Y+g/z5gSVLONxOYzYXY0H0woWMgHVxodP4f/+T3EOCkN4QRZBWqVkTuHoVmDmTleLr1wdu33a0VM8oUoQ1dxYuBN56i8d69KC+Wrs2ZW0uXw7kyMH8RYIgpB6iCNIySgFvv01P7L59QOXKrAyTBlCKk5X9+5lPD+BSCB8fHk8Jp08zDPXqVZuJKQhCEhBFkB744gsm/798GahSJU3lc6halaGkISEUr0cP5tK7di35bVWqxO2tW7aVURCEhBFFkB4oUQI4dszM+9y2LTBjhqOlekZEBH0GXbrwFRWVMqfx2bPc3sgYlasFId0giiC9kCsXV2xNn86VXO+9Bwwdyl7Ywbi4AN9+C/z7
L/DLL5wl/PRT8tv5+29uRREIQuoiiiC98e679BN88AEwezbQpAlw546jpcLLLwOlSlG0nj1Znvn//i95beTKxW1UlO3lEwQhfkQRpEdKlqTfoG9fYNs2wM/PDOh3IF5e9BW8+iqdycmdFTg5Ab6+QL9+9pFPEATr2FURKKVaKqVOKqVOK6U+SuC66kqpSKVUF3vKk+H45BPmhL58GahWzeHJfvLkoSIoVAho1YoTluSsKXj0CMiWzX7yCYJgHbspAqWUM4BZAFoBKA+gu1KqfDzXfQ7WNhaSQ5EijN+sU4dO5G7dGNTvIL/B0qXAjh38e/x4LnuYPj3p94eHA0eOAO+8YxfxBEGIB3vOCGoAOK21Pqu1fgJgOYAOVq57G8BqAOIiTAkeHkBAADBsGL22s2ZxOO6AGMzcuQE3N/5dvTpXGk+axLxESWHuXKau2LrVfjKmJzZuBCZPdrQUQmbAnoqgMIBLFvvB0ceeoZQqDKATgARTbSqlBimlApVSgTdv3rS5oOkeV1fg66/pNJ4/n36DKlU4W0hFtm6lPjImJLNn02/QqROtV4lRpAhQuzZw/bp95Uwv9OkDjB7taCmEzIA9FYGycix2Ks0ZAEZqrRPMqqa1nqu19tda+3t7e9tMwAyHuzs9rbNmseetVQtYvDjVHn/4MPDNN2YQU968DCe9eZMZMhKbGSxezGijmzclcghg9FWWLI6WQsgM2FMRBAMoarFfBMCVWNf4A1iulDoPoAuA2UqpjnaUKXPQrRuH1hERwOuvc73Bkyd2f2yePNwaqakBoEYNmjjCw1ma+b33gAkTqCBiM3kySzdHRqZOROzixcBHFiEM4eGp8jElGScnypMGk88KGQ2ttV1eAFwAnAVQHEAWAAcBVEjg+h8AdEms3WrVqmkhCdy6pXXRolq7uGgNaP3ii1oHB9v1kX/+yUft3Bn33I0bWvftq7WzM68BtJ44Ueu9e7V+8IDXFCumtZeX1tWqaX3xol1F1Vqbcmitdbdu/HvNGvs+s1cvradNS9q1hnz//a/WZctqHRmZ8PUREVovWsRtZiQiQuuoKEdLkXYBEKjj6VftNiPQWkcAeAuMBjoOYIXW+qhSarBSarC9nitE4+UFbN7MYjeurkxR8eKLdvXEWpsRGHh7s87OihXmsTFj6FTOmZOplK5eZRtNmjCD6erVXK189ar9TUXBwdxu3mzf5xw8CGzYkLx73niDC/WCghK+buZMTgBT0RqYpihYkLOo+/cdLUn6w8WejWut1wFYF+uYVcew1rqPPWXJlJQsyV5h9mzgwgUu/23ShOsPRo3iqi8b4uXFJjt0AF54ARgxgoXW7t3j2LZ06ZjXr1zJ60+e5GvJEjqKZ8yIa6JxdQWKFmVZZ2uvokXpIkkJERE0CwHApk0payOp+PoCx48n7dps2WjVO3GCZUD/+ou6PD7q1+c2Z87nlzMh3nuP39H27Yz0WrTI5j+lFGHEkTx8aK5SB6gYvv0WGDmS9Z5SwpUr/H1ZtpuRsKsiENII7u6My5w5E2jZkqEou3cz8D+lvacVSpQALl5kp1y5MqNYASoGa7b3qCi6MwyWLmVBtl27GDFTsCBDUAE6Th8/pj779VcgNDSu7dzTM35FUawY27PWETx4YI4ijx/nP32hQsl77zduAPnyJXzN7t1mrYaoKI5e4yM8nAvsPD3px1i3LvH2s2fn1t4+hfPnuZ0/H/jxRyqCtETsGk4ffECF5eLCbO6jRjGoLjkULswU659+yu+kdWubiZsmEEWQmWjWjEPeli1ZDrN0aa5BiD1UTyFKAX/8wb8/+YT/OCdPxm+Nih0meuECkDUrO7wbN7hEwqBLF/7zbdzIt7F4MctjXrxovi5cAC5dAs6d42j17t2Y7bu48B/aUAwGv/1Gc1bVqsCBA8CWLUynbRAezmUZ+fNzZhKbs2c50g8IABo2jP/zsTTtXL5MhRkfhrN89GjmbEpK575gAbdhYdxu2gQ0aBBT5tBQ
KoyElFBiNGtGZ39ICD+TtDAbsMR4/waGqfLePZomu3VLviIAqAC/+oq/kTRaSjzFSK6hzISTE/DSSxx6e3vT+F6lCutL2ojB0d6fnDkZSrp1qzlSff11dvAhIUDv3kxSZ0mhQjQveXtTSfTqxWI1gDntN/DxYUdaty7dIP/8ww4pJIR6zehImzXjaHrOHODDD3m9Urze4PXXed+BA1wQ98knzPT95pssG920KZ+1YgVnDNeuxayiZtReSKyympFmGwDOnEn42qxZWaQOoEno5EnOIix9JQEBwPvvs3MHzIJAT57Q19G0KfD55+b1T57we3n//YSfbcm5cwwLDgri9/n4MbBnD89t28bv6dixpLeXGsSeERjjHOO3dO5c8tqzVMJ9+sRsK6MgM4LMSJcu9BcMHkw7TIcOHHpOmJByI2o0b7wBlCnDDnPYMCqCW7c4pe7ShZ08ENecEBoKTJ0KtGtnzggAoEAB4JVXuNgMMHMXWYaXak2xly9nB75uHY9Vrcrjholq+HA6a48c4flHj9iJ37gRc2ZhvHbtiun47tkzpsxubjFH10OGUJFly2a+smThTMTFxVSKrVtTqfz2m3nO2qtuXTrLhw1j+05OwHffUXfnysXZ17RptNnnyMHPcNgwfgezZvGeixdNeY0OcuVK4NAhYMAAKtGEKFMGePqUCvXvv/kdGt+dMeP6/XcquZdeMpW+LTlwgLMao052fAwaRBNQbBk+/pjLa378kfvJTXPu7Mz3fP48zUIAv7v33kteO2ma+MKJ0upLwkdtzKNHjOsEtG7WjGGnNuDcOTY5d67181FRfLTBxYu8/vvvtW7SROtatbTu0UPrr7+Oed/ixaaoBg0aaJ0njxluCWh99KjWnTtrXaaMed20aTx34QL3T5/W+pVXtA4MNK8JDtb6778ZqtmhQ8w2p0zReulSrWfN0vrTT7UeMULroUO1rleP54sV07pFC60bNtS6Rg2tCxXS2tdX65Iltfbx0drVVeusWbXOm1drDw+t3d21dnMzI3xT+lJK69y5+XfBglrXqWPKVLmy1u+/r/WECVpPn651tmxaN25syvvvv1ofPqz1mTNaX7mi9d27Wj95Yn4els/p0UPr7dv5d9Wq5nHj72PHrH/Xly/ze7txI7FfTUz+/psvyzDf52HyZLbzzjspu//4cVOWdu2Sft+NG/xs4wttPXqUbe7ZkzK5kgoSCB+VGUFmx82Nw7yFC2mAr1qVQ00/v+dq1rDBDxpEM1DWrDHPd+rEkM3AQO4bo9Vs2YB69Wi1WryYMwLAdK4aMwLj+pAQmij69uVb8PTkbGH8eIafGvdevWru//030L49R8QBAVx798EHHDkeOsSRXkgI/Q2WlCvHyVNs5s1jsr2mTelABTha9vSkWeLkSR5r3Jjv7cUXaaZ55RWzDa0pZ0QEX4sX0zRlUKkSZRszhiaj+/eZ0G/vXrabNy9nGXfv0haePTs/83PnaBZ7+NBsa8sWbi9eNM1PsXF2ZhsuLpQtMhLYuZNmKoDPM2YJBw7w2FdfcQaRJQtfWbNye+AAZy5DhnA25ORkvpSKuW+8oqLYviWHD1MuJyduY8+gQkMZ+5AjB2cQxrX9+vG30bQp20lu5M+xY5wwV67MfaWYUDGpfPklzWqxfRcGv/7K7cqVDKd2BKIIBP53rl1LL9q1a+ytVq2K+5+YDCydkbGVAMD1AoYSAMzwTTc3duJXrnCaX6IEHbC5cnE63qgRr3v6lFvD1t6hAx2jTZqwY7fMyN26NTsGwy9gtBsQwP2zZ9k5vv12zLUQWbLQ3GGsLYgvpbYhi+V7NqKQjAgbwOyAW7akucxSESjFjsvZmZ+X4W9wcaFiqFCBiiBPHqBNG55btYqKoHJldnJr11JBNm5MZ3fhwqZDOiKCnXi7dnTwGo765cv5zLAw+koAmvHCwtjWvHnm+y5e3PwsgoP5nTk7mzZ0QwnGx3ff8ZVSjJrWycHJyXTsbtzIz/m//6Ws
Li78XSRknnNx4cBi/35zNXy9evwe+vaNe6219n79lZ/VlCkMgIh9rRFEcP8+8Oef1mWIjOTgoXBhBibYGlEEAunYkYWGO3fmf07r1hxK9u+f4ibPnIk/V47RGRkjfcsZAWCmsy5RgsrBsOuWL8+Q0sOHuX/hArc+PuZo/dNP6fA9d46d13qLBOeTJnF0N2CAecxYTJYrlxlhExICnDplzmw8Pa3H8AcG0tXy4YcxaykYysHIxmqJry9t/wlh+EDCwznCLVyYsllGWhnK89o1fl2PH9OXEhbGr7NNG7NzdnExbfrffmuG5fr6MvUHwKijrVs5OzK+txo16Fu5epVRSLt301n8zjt09nt6molu58/nWOLJk5ivRYv4nTRrxjQihtPbeBmzoagooEULTlBr1TLTkX/1FX07J09SGfj6cqDw4ovmDCoighnYAf6Eq1aNObsyfCXZs9M99vRpzHvje4WHx40+u3qVUV8hIeygs2XjtZZtGt+/JYklEPzvf/lKiJEjGcBga0QRCCadOnHINnIkh9UDBrDn+c9/UtRciRLxnytQgP8wd+7QwWo5I9ixwxwtly1L5/GpU9w/c4b/iIZiMEbcL7xgtt2nD80knp6s7Nm7N49nycK3Nnw4ZweGqcBSERidubGW4MAByvL993EXav38M6ux/fwz77Wc+vv6shM0FMuBA3Tifvcdz929y47EmIFYcvQoTS65c3PEvWwZ73F2Zgdp8Mcf3DeUg1JUGtev87VgATskY6Wx4fj28jLNL2fOmIrAcO7v3UtHNcD3166dOUOpVcvspE+d4k9k3jyeu3vX+rIUw3mbK5f5LGtERbFj9fWlYjd4912Olr28qCiuXKGp5+uvY95vKIJGjcy/Ac6ELl5kO15eHCAYCvLqVX63r74afxT1N9+YDnuAlfcCA81nWFvJXKoUI4vCw83BwMKFVNixlcaNGxzQlCsXVxEZ17m48Pdr+bnYElEEQkzeeIMKwdOT/x2jR9P4OmmSTQPGDdv/tWv856xXj7btbNnYeV64wJBRHx9GGhkd/5QpVBSG6cfHhx215ZoDrTnFLluWEUT587N9w8bs7GyGXAKmEsqVywzPfPSIppi6dSnj1au8p2BB876dO7mdNImdU+w1BJbx+xcusIM1ynEC7IStKYKzZ2nTb9yYo9f27TkStmYaGTuWncTSpTQ9Zc8es2qpEeoJmIpg5Uqae3LkMMMgHz/mCmaAHXvduuyA9u7lAvX4kv4aCmLSpJiLAw2iosyOMj4bucHdu7w+S5aYIanOzlT0UVE0/f3zD+UPCzOVjGVYbezwUeP5oaG07W/bZp47fx4YN46/r2+/tS6XZeRY3bpUwJ9+ys+lWjXr95QsSUVuaRZ9+tT6osAyZRgx5+fHdh2BrCMQ4pIvH3sxYzgzZQqHgDZM+FO5MlfM/vYb/wEOHYppmilWzPwn8/Zmx/joETuAMmVoPgC4Xb48bvseHlyHsH8/23nwwFQ+GzdS1wHs/MaPZyfr6ckO57ffzE793j2OksuUoYnCEmNEeOgQbfG9epnn1q+n3nzpJe4baw0KFIipCKzRrh3txmvX0i79xRdURBER5nqH2rUZRvroEdC8OTvHtWv51RlKrmzZmJ2Y8ffmzexAL140s6/evGmaLox6Ejdvsvid4WS3hqEIhgwxQ3wNNm9mJ960KZXx2bOUuU4d62sPDPnGjmWop7e3mZfJMAHWrMm2gJhmMkslExpK2SMjuRivXTsenzeP7Vp+JoaSWL+eAwVra0Hc3PhZurvTVJY7NwcbFy9aV+QAn2uskzFCeY3nxl5l/+OPXMrj0PUY8YUTpdWXhI+mIjt3MubRiJl75x2bp3ds355N16/PMMebN+NeExCg9XvvaX3/vtatW2tdqpTWGzdqHRISM9TRGvfva/34sdbr1jHTqdZaz5/PZ547F//bmTGDYaCVK5tv/403Yl7z+LF5rmxZrcPDtZ49W+vbt5kx1Dj34IHW48Yx
zPPpU76uXo3/2VeumBlZjTZ++knr3r21LlyYL+P40KH8LLp2ZZiq1lqvWsVzbdow26vxnKlTzftq1eJzDA4eNMN3jSS1Rrjk0qVxZTTa+fdf3rt+PeWwpF8/XrNwIcODly0z7/vxx7ht7t5tnq9encfOnjWP1avHY3/8wf1du8x7w8IY0lu0KD8HQOu33uL2//6Pv6vy5c22jN/NihXcN37mb75ptnn3bkz5pk/XesMGfkaWYbWrVvH7NggPN8+dOMHPP2tWrT/8UOu//tL69dd5jUGJEry2YcO4n4ktQQLhow7v2JP7EkWQyly7xoB6I3/0pEk2aTYqih15iRJcN2B0zufPJ3xf/fpaOznx2mnTKNaYMcl7ttGR7N5NBfOf/8Q8/9dfXH9QrBjXHEydSjmbN2eM/uLFvG76dK2/+IIZvgGuDQC07tlT6y+/NDuD06epRPLlS5p8xYuzDa3NNtau5bqF2GsIKlTgtlIlrl3Qmh38L79Q+QBa37tntj1sWMxO/PXX2ZFv2cJjmzdTmdWoYXbMf/wRV8ZWrXju8mXu166tddOmMa+5do3XvPSS1qtXm8oxvu94506ty5WjUgOoXAYNMjtbg8BA8zOxhp8fzxvrBv78k8fLlTPfu7Gmwfjd5crF9RzDh3MNSY8eWlesaL39n3+O+R00aqR1jhymwj12zDy3ciXXB2zaRAU7YYLWWbJQcWnNe7Jn57V+ftafZysSUgRiGhISJn9+Rg95etJLOHo04y9tQMGCNBeULWvada1F2QCc+l++TBNPlSo0IQ0fzul/fNPz+DBMRG3a0Ok3eTKdkAaDBjFCpVQpmqjef5+O4mvX6Bf47TeaTz74gM5uw05tRJe89hpNSgbXrrEdwzkN0IE7ZYp1+aw5kXPkYCQMEDM6qXx5bo8c4Vc0cyYXiHfowOsbN6ZZae9eflaWzlxvb0b0rFtnmi3y5KG5ZM8e3gdYj7tftIi2dsPmnScP7e+GCQfgWgMnJ5qIevWi+8nFxQxpPXo0Zmnt2rVpHjHqNLdoYf7ULB315crx/RhmN4AmnYsXac3s25fHDDNjq1b0ER0/zs+oalXTBGSYhh494u8iJIRO5SVL+Jleu8aoMMt4idy5uf3iC27Dwmi6NPxYJUqYDvo//qA5KVs2Rn6dOcPPff58ZnANDTXNWqlRjCle4tMQafUlMwIHcf8+59OtW3NIHt9wLBl4enIk9O235gjKcvRqScmSWnfpwpWm27dzqr9hA0e9164l77nBwebzGjTQz0w7BlWq8NjgweaxunXNe3LlMkd9nTrFHaX/+2/Mkffq1XFleP11mnhiExHBe8aN437WrNw3iv2EhNCs5ObG41OmmM/p399cJL5+vdnmggU8liePeR/A0WiJEhyBG+aOixe1/uEH/v3559weOpT4Z9qrl9nukSOcYTRtan7HxuvTT7UeMoQjb0DrV1+13l6NGub30KSJ1qdOJfz8PXvMZ/TqRTOcMXO0fA0dGvO+p09jflctW2pdurTW+fNzf8kSziQ6dzbvefhQ65Mntd66NWbb//yj9Z07nIFcuaKfmbgArefM4aymbl3+5jp1ohnr5Enzuo4dY8oWEmJbSyxkRiA8Nzlzcog1ZAi9uz160BP7HJQqRUfn0KGm89dytGtJ7docgTZpwtGdsbJ1/HjTeZhULCM3WrdmDpsffjCPaW3KZzBqlLlg7P59lngAGOZpRNoadOnCdQIvvhgzNNYSX1/OcK5dixnBZMwqjBmBkUguRw5uPT05cq1dm/tGOObnn9MZasxOLB3XGzdyGxLCr+3dd/l1KkUn+KlTjLw6doyzNGPG5ONDR3xSQha9vMy///c/jnw3bqRju1gxzgSMz2b/ftPB7+9PZ/vu3XSqNmrEz//ff80Z08SJjMKxZNkyjqqNsGLLldMXL/K9WEZtGTMK43M14h5cXDgDad+e+ydP0tH7zjv8rNevZ0SYZdRU9uyMNoodJXb2LGePrVqZOYmO
HOF29Wq+jzNn+N23bMmV69u38/ykSWaKcoB/58ljPRDCHthVESilWiqlTiqlTiulPrJyvodS6lD0mXe5zgAAFsRJREFUa6dSqrI95RGek5AQZinz9GRP3K4de7MUUqCAGfmxbBmbspbmGeA/3Y0bNH1YZvFMCa6uXEVbrBjXFRw+HDPVgrGYyjDFADQjVanC4i+urgw19PCgIrt92+yoAf6Dt2zJ3Pe3bjERXrFitLAZGJFDBQuaqQuAmCYagM9bvz7umozNm9lhNm3Ka43PxFAExjFfX4aWGh3xuXM0zTRowP1SpdiZurvT5OLiYioCZ2cqiKQUujHkLVuWK2kt1yzcu2eu8zhzhseuRFcvr1GDMfrvv8/O+/BhM0rZMGNdv24qZ4MhQ6h8S5dmWKZl1NDWrfyZPn5smrUKFqSMn3zClBMdoyujL1nC35SRTsQYkJQtS9PT4sWU3yj6A/BZI0bE/QzOnOGahfr1+RsqU8Y0ORUuzPccEsLvpFw53lOoEJ9trNu4f5/v1Xjevn3mM60tUrMVdlMESilnALMAtAJQHkB3pVT5WJedA9BQa10JwEQAtjE+C/YhTx4O0wIC2LPdu8ehlOVwLBns28cVo1qzA0qoGEydOty++65tSjE+fBiz87akShX+8zZubB67cIEj2dKl6ZsAaMPOn58jyNu3mQcoKoofjWVo6N9/85+9QgXzmGWaAEvF5unJvDxGzplZs7iAKaH6QfPnU1n8/rupCDw8GM5prD2YMYPbBw8YkmosxqpQgdf9+KO5yKpQISre4GAzLUZi9O/P0e9rr3E0b4zI587lz8RQJj/9FHP2UKiQqYxi+0aM99ypU0xfAhDTB3P9uumbMOz3RtoGw19QqJD5M/Xw4OcVGsrZy+efc+ZSp44ZslyyJMODDR+MZbaVqChzZbux+vn4ca75uHKFszEnJ4Ym16vH95cnDzvzsDD+fgxle+sWP/8dO6gsq1bls/Lm5XdjrPNYsIC/jeRmTk0y8dmMnvcFoDaA9Rb7owCMSuB6TwCXE2tXfAQOJjKSYToeHjRyOjnRuJlYZXUrHD5sRuAk5bGGLXbq1GQ/Kg6lSzNKxBr/939a798f89ibb/LZo0dzPyqK/oxRo3h87FjzWiO75zvv0G5syP30qXnNzZsx7cvxFZyvVCnh81qbmV7nz2dmVCNSx+DxY25Pn9Z6zRrTj2FJ5860hVvy7rta58wZ/3OtcewYfQsdO2pdoIDWI0fyeVeuMDz4xg1Gahnv++FDRl4B9AvUrGm2FR6u9YABPGdE2RhYZj/19zf/9vWN+bmuXq3122/z/QH0KxkZVCtV0rpaNWaDNfw906bRbxEayufMmhXXjxEVZbbfuzf9CVpr/dprPHb2rHntzp30JXz6Kc+FhPD4/ftae3vzc/rss5hZVocP5zXt2zMqTGvK5uOTvO8iNnCQj6AwAMv8jcHRx+KjP4A/rZ1QSg1SSgUqpQJvxq5QIqQuTk6cX69YwVU7/fpxiJmCNBQVK8a0ZSf22OnT+bctavJu2sSRvDVKlYppFgLM2YORo0cpmh2MlaPh4fQj/Oc/5mj/+nWaBACOAF0s1vF7edE2/vbbHJkaZSBu3eJI0ljUdeiQ+bz42LuX29y5mbG0ePGYsyvD1OXra464DROIwZ07cSOV7t0zR9hJpVw55l0qX572f2MldpYsXD3r7U05XniBEVjZs5u+mD17Ys4WjGp1zs5xo8k2b2ZqCMD0wXTpwhkJwIiutWs5Sp85k4vUAH6+9eox3cOhQ5yVVq7MmU/Hjhytjx5tft9vvkmzpSWW30Xu3Oas4ckTtlW8OPdHjeJzGjQwZ4NGUZycOTm6z5OHKSwMP0yWLIyuAjgrOXuWM5CAgJgzVJsTn4Z43heArgDmWez3AvBNPNc2BnAcgFdi7cqMIA1hhDUMHsyhzIIFdn1cQguc7M2kSXy25UIgrTmrAcwRatGirMEAcEGV1lp/9x1zzicFI4Lq+nXu
G3UWEqJbN17z00/c37cv/toARgy+0WZUFCNlgJg59jt25LHy5ZMmt0FUFEfEV69yf8YMtrNmTfz3HDnCa5ydtf74Y/P406cxZY2Nsc7hp58YJTR+PJ8LcCRvyd69PP6//3H/yRPG8xszAa25cBHg6DwxrMn15EnMWZ8RaXbnDveNmZklrVtzVqI1I7YsZ34XLnB2evs225k2LXG5EpbZMTOCYACWVVmLALgS+yKlVCUA8wB00FonI8u34HA8PTk8KlmS3q1Bg8zcznbAmAwao+XUxEjFHNtGazi3jZGvmxswcCCjn4wR9uDB5qgxNo8esf6BkeffcLIaUSdBQWbUT3y8+y639eszsuqDD+KPvorta1CKo+5ChWLawUuVoi3dMlY/KWhNP8rMmdw3IrRi5+9ftMhMH1GqFN/jjRuMrDFIrK5yyZJ02hcuzJH5N9/wfU+dylG/JUY2W2PdgKsrv1NfX9OhbNjtrdWciI2fn+lwNjDSShsYEUNG5tPYmXjHjOEaDuO5RYvGLBBYrBg/GyM9iWWeK5sTn4Z43heY0O4sgOIAsgA4CKBCrGuKATgNoE5S25UZQRrj2jUaklu0YDB+njwcxtiBu3cZK3/7tl2aT5AJEzgqO3485nFjhN29O7d+fhwZAlpPnJh4u5GRXCswYgT3U2KXt2TpUj7755+tnzdGzF5eKX9GYhQpwmd89hlnCBs2xHQhBQXxfOnSibcFcLSfEJbV6YYNo80/NkZ6jd9+i3vOGIXfvctrWrVKXK47d+LODmNjrDg3/A2xMfwW/ftbP//oESv0rVnDVdLP+28FR8wItNYRAN4CsB40+6zQWh9VSg1WSkWXOMdYAF4AZiulgpRSgfE0J6RV8ufnEHT9eg6ltGa8pGVmLxuROzdj5ZO7ktgWjB5N233ZsjGPv/gis3ka9Yzd3GIWlUkMIxvpqVP8CGfMMGcDKcFYhxDfLKJAAbafWK3i58EYuYaEcMbRrJn10b1lUfj4KFAg5hqN2Ozbx+eMG8cRfmgo/QCxQy3fe4/RW23bmsdWrmR4rDEKz52bcf0//5y4XB4e1gsuWfL77xzxxxedZsyWjBlBbFxcGFYbGEh/g+W6Fltj13UEWut1WuvSWmtfrfWk6GNztNZzov8eoLX21FpXiX4lkK1cSLO89Rbn159/ThvI+fP03Nkz8DmVcXZmXHhslOJbNf7Zc+Wi+eXuXTpNk0KtWnReFypEk0bz5imX01AiCdWCaNIk5toFW2M4mONT2EbK8PjKZFpy/37ckqGWGAvqChfm524UIYq9vMXJKWaKD4CKfcUKVnozqFfPNsEIABViq1bxnzcc+tbWJABUBD4+NGtdvBh3LYUtkXoEwvPj6ckFAYsW8b+ob18O43r35momG9YxSKuUKsXVoUZBneRE2/TrxzjxwEBzpWlKMZ4fOyrIEssynvbAsIVbRgBZ8sILVHxJUQRhYcCuXYlfZ4yqS5emErCsTxEfhnzr1yf8edkLQ+b79+OXt2RJ+o/KlYu5At3WiCIQbIOPD+fn58+bQ6/lyzmMHjcuwyuDQoVSXMgNdeowvPA5Fmk/QynWE3IkbdrQJBKfIgCS7oQOCUmaic3oVFetokJNiiKoWJFby+p2qYkxY7p92yyJGhvLFej2/BeSXEOCbYmMZHD44Gg30IQJLF0VX+V3AUoBa9Yw5j0jUL8+zU+26GA9PRM21RgdpbHNkyfpprUGDRjkNnLkc4mYYjp1YjqMKlXiv8bIsZSYP+J5Udqehic74O/vrwMDxaecpomK4uqaF14wS2EVLUofQrduGX52IKQer77K9BAnTzpaEvsQGkqzUO3a9Gc8D0qpffH5YWVGINgeJyeGz/zwA30EW7cywPvVV5llbMmSuPX6BCEFvP22ueI8I+LuzhXedl1DAPERCPbEMmSiZUsO2y5dYqzlsGGcG3ftSoNxfGlHBSEBjKydGRWtuWiudGn7PkdMQ0Lq8OABvYiWYTFZsnBmkDMn
DbaNGnFbqVL8pcoEQUgRCZmGZEYgpA45czJT2MaNXD115w7DNlxdafxcvZp1/QCalkqUYIB9tWoMG2nenMeSEkIiCEKykP8qIfVwcaGJqGXLmMdDQ5nA3smJ27AwJmK/c4cJ7A2cnRkeUrgwt76+dEIXLcqk7oUL2z+8QhAyIGIaEtIWWjPK6NAhKoyQEOC//wW++oorb1xduWI5e/aYZakM3N0ZQ+jjQ+WQJw9zHxcsyG2+fNx6eXHVl2WWL0HIwCRkGhJFIKQPtKbpaNIkJozp14/mpt9+47GbN6k0jNSkFSowUf3ly9aLBhtkzWoqg1y5GLheoQJnGR4eVDj58/O4hwefabzc3enLkHBYIR0gikDIHGhNp/T165wNZMvGPM6rVrGG4I4dnEU8esREeQ8fshZkWBiTBYWG0nmdKxfXQiRlTb9SdHrnyMHZipsbn+3tbRbK8vCgIila1ExG5O5OJZM9O+/ZsYP3d+1qHhMFI9gQUQSCYElYmFncNzZXr1JRlCjBNJbffMPZxu3bNFmdOcMlwH5+LFS7eDHbs0ylWb48ZxlXr8YttpscnJ2pHLJk4cwlWzaatLy8qKguXqRiefyYkVZZsnARX86cVGgXLlC22rUpX8GCLIbr5sb24tuKQz5DIopAEOxJRAQjoQoUYAft7s4O9f59Lnt98IDKZO9eOsT79OFof8kS4J9/2HEXLcqUpS4uTKl57x4wezbNWkZea4DZ7fLm5fnjx+2TktIwlTk58W/j5eFBZWIoORcXKirjlS8fZ0MATXJublS4xgpzFxcqM6Uoe4ECdPhnzcrPzcOD7fzf/5npXnPm5L1Zs1LRZcnCz9vJif4fmTUlGVEEgpCeiYqiMrlzhzMAI1tZVBSVh1JMfp83LxP9aM1E/eHh7DBXrqSfw9mZazXCw6mErl6lknn8mDOI/Pl5/+PHXBH+4AE73YgIPitvXhZkCAszS379f3t3FyPVXcZx/Ptj3QUKLSsvWoIEKEEMNloIL22oTS9qpcR01ZtWTFpfklpT6ktiCNrE1Du00YQrG4yN1WC5sNZy0UgbCxK7oYDI8iJFqEACRZamgC10l93l8eL5T3Y67Cy72xnOnD3PJzmZmXPOzP6f+WfPM/8zZ55/udZWP/h3dfVPjFBvki/NzZ5kxozxv93c7PGOGePLlCl+es7Mf9TY2+sJu7XVLz6YPduTzsWLnrSbmjxxlZb5833/Eye8H8pHaH19/ZdCnzvn70950uruhiVL/Dlvv+3v+8yZfjqxpcXfs0mT/O+cOOHJ9qab/P2fPNm3l9ozbtzVU50N+a2KRBBCqCWz/ppSpWXCBP/03t3tn+q7u33+yfZ2P93W1uaFczZt8v3On/eRlJkfCO+/31/n+HGf5f3YMX985IgfhFet8sdbt/pzOzv9wH35sp8amz/fD+qvvuoXDvT1+XLlim+7+WY/yHd0+EH70iXfBl7draXFTwGeOXN1vDfc0B9v6TlZWLPGa3aNQGaJQNIKYD3QhE9kv65iu9L2lcAl4Otmtmew14xEEEKoiZ4eT1izZvVP5vz++76Utvf2+u2MGZ48zHwS4a6u/u+Axo+HhQt9v0OHfMRRGmVJ/ql+7lzf3t7uI5RTp3x7T48n0OXLPWnt3++Jr6vLP/1fuOAjmWXLfPvixf7r+xHIJBFIagL+DXwen8h+F/BVM/tX2T4rgcfxRLAMWG9mg05XEYkghBCGL6vqo0uBo2b2HzO7DGwC2ir2aQN+l+ZW3gG0Sqpznb0QQgjl6pkIZgDls42eTOuGuw+SHpG0W9Lus2fP1ryhIYRQZPVMBANd11V5Hmoo+2BmG8xssZktnjZtWk0aF0IIwdUzEZwEZpY9/gTw1gj2CSGEUEf1TAS7gHmS5khqAR4ENlfssxl4SO524IKZna5jm0IIIVSo22/JzaxX0mpgC3756DNmdlDSo2n708BL+BVDR/HLR79Rr/aEEEIYWF2LipjZS/jBvnzd02X3DXis
nm0IIYQwuJi8PoQQCi53JSYknQVOjPDpU4EPUQ6y4YymeEZTLDC64olYGtdw4pllZgNedpm7RPBhSNpd7Zd1eTSa4hlNscDoiidiaVy1iidODYUQQsFFIgghhIIrWiLYkHUDamw0xTOaYoHRFU/E0rhqEk+hviMIIYRwtaKNCEIIIVSIRBBCCAVXmEQgaYWkw5KOSlqbdXuGS9JxSfsl7ZW0O62bLOkVSUfS7Uezbmc1kp6R1CnpQNm6qu2X9KPUV4clfSGbVg+sSixPSjqV+mdvmnSptK2RY5kpaaukQ5IOSvpeWp/XvqkWT+76R9I4STsldaRYfprW175vzGzUL3itozeBW4AWoANYkHW7hhnDcWBqxbqfA2vT/bXAz7Ju5yDtvwtYBBy4VvuBBamPxgJzUt81ZR3DNWJ5EvjhAPs2eizTgUXp/o34rIILctw31eLJXf/gZfonpvvNwOvA7fXom6KMCIYyW1oetQHPpvvPAl/KsC2DMrPtwDsVq6u1vw3YZGbdZnYML0q49Lo0dAiqxFJNo8dy2tI84Wb2LnAInxwqr31TLZ5qGjYec++lh81pMerQN0VJBEOaCa3BGfCypH9IeiSt+7ilst3p9mOZtW5kqrU/r/21WtK+dOqoNFzPTSySZgML8U+eue+binggh/0jqUnSXqATeMXM6tI3RUkEQ5oJrcEtN7NFwH3AY5LuyrpBdZTH/voVMBe4DTgN/CKtz0UskiYCzwPfN7P/DbbrAOvyEE8u+8fM+szsNnzSrqWSbh1k9xHHUpREkPuZ0MzsrXTbCbyAD/nOSJoOkG47s2vhiFRrf+76y8zOpH/aK8Cv6R+SN3wskprxg+ZGM/tTWp3bvhkonjz3D4CZnQe2ASuoQ98UJREMZba0hiVpgqQbS/eBe4EDeAwPp90eBl7MpoUjVq39m4EHJY2VNAeYB+zMoH1DVvrHTL6M9w80eCySBPwGOGRmvyzblMu+qRZPHvtH0jRJren+eOAe4A3q0TdZfzN+Hb+BX4lfQfAm8ETW7Rlm22/BrwboAA6W2g9MAf4KHEm3k7Nu6yAxPIcPyXvwTy7fGqz9wBOprw4D92Xd/iHE8ntgP7Av/UNOz0ksd+KnD/YBe9OyMsd9Uy2e3PUP8Bngn6nNB4CfpPU175soMRFCCAVXlFNDIYQQqohEEEIIBReJIIQQCi4SQQghFFwkghBCKLhIBKGwJLWn29mSVtX4tX880N8KoRHF5aOh8CTdjVem/OIwntNkZn2DbH/PzCbWon0h1FuMCEJhSSpVdlwHfC7Vqf9BKvT1lKRdqUjZt9P+d6da93/Af5yEpD+nQoAHS8UAJa0DxqfX21j+t+SeknRAPr/EA2WvvU3SHyW9IWlj+pVsCHX3kawbEEIDWEvZiCAd0C+Y2RJJY4HXJL2c9l0K3Gpe5hfgm2b2TioBsEvS82a2VtJq82Jhlb6CFz77LDA1PWd72rYQ+DReH+Y1YDnw99qHG8IHxYgghKvdCzyUyv++jv+kf17atrMsCQB8V1IHsAMv+DWPwd0JPGdeAO0M8DdgSdlrnzQvjLYXmF2TaEK4hhgRhHA1AY+b2ZYPrPTvEi5WPL4HuMPMLknaBowbwmtX0112v4/4/wzXSYwIQoB38WkNS7YA30nljJH0yVT1tdIk4FxKAp/CpxEs6Sk9v8J24IH0PcQ0fNrLhqh2GYorPnGE4NUde9Mpnt8C6/HTMnvSF7ZnGXga0L8Aj0rah1d73FG2bQOwT9IeM/ta2foXgDvwSrIGrDGz/6ZEEkIm4vLREEIouDg1FEIIBReJIIQQCi4SQQghFFwkghBCKLhIBCGEUHCRCEIIoeAiEYQQQsH9HybIIEsbut9xAAAAAElFTkSuQmCC\n",
578 | "text/plain": [
579 | ""
580 | ]
581 | },
582 | "metadata": {
583 | "needs_background": "light"
584 | },
585 | "output_type": "display_data"
586 | }
587 | ],
588 | "source": [
589 | "train_loss = {\"aux_loss\":[], \"target_loss\":[], \"final_loss\":[]}\n",
590 | "test_loss = {\"aux_loss\":[], \"target_loss\":[], \"final_loss\":[]}\n",
591 | "for epoch in range(epochs):\n",
592 | " for i in range(int(len(train_data) / batch)):\n",
593 | " record_test_loss(test_loss, test_data, i)\n",
594 | " label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length, show_length = data_reader.get_batch_data(train_data, min_batch, batch = batch)\n",
595 | " user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n",
596 | " aux_loss, target_loss, final_loss = train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n",
597 | " # Record this step's aux/target/final losses\n",
598 | " loss_dict = dict()\n",
599 | " loss_dict[\"aux_loss\"] = str(aux_loss)\n",
600 | " loss_dict[\"target_loss\"] = str(target_loss)\n",
601 | " loss_dict[\"final_loss\"] = str(final_loss)\n",
602 | " utils.add_loss(loss_dict, loss_file_name, level=\"train\")\n",
603 | " train_loss[\"aux_loss\"].append(float(aux_loss))\n",
604 | " train_loss[\"target_loss\"].append(float(target_loss))\n",
605 | " train_loss[\"final_loss\"].append(float(final_loss))\n",
606 | " get_loss_fig(train_loss, test_loss)\n",
607 | " tf.summary.trace_on(graph=True, profiler=True)\n",
608 | " with train_summary_writer.as_default():\n",
609 | " tf.summary.scalar(\"train_aux_loss epoch: \"+str(epoch), aux_loss, step = i)\n",
610 | " tf.summary.scalar(\"train_target_loss epoch: \"+str(epoch), target_loss, step = i)\n",
611 | " tf.summary.scalar(\"train_final_loss epoch: \"+str(epoch), final_loss, step = i)\n",
612 | " tf.summary.trace_export(\n",
613 | " name=\"DIEN\", \n",
614 | " step=i, \n",
615 | " profiler_outdir=log_path)\n",
616 | " model.save_weights(checkpoint_path.format(epoch=epoch))"
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {},
622 | "source": [
623 | "# 模型评估"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": 24,
629 | "metadata": {},
630 | "outputs": [
631 | {
632 | "name": "stdout",
633 | "output_type": "stream",
634 | "text": [
635 | "./checkpoint/cp-0002.ckpt\n"
636 | ]
637 | },
638 | {
639 | "data": {
640 | "text/plain": [
641 | ""
642 | ]
643 | },
644 | "execution_count": 24,
645 | "metadata": {},
646 | "output_type": "execute_result"
647 | }
648 | ],
649 | "source": [
650 | "last_model = DIEN(embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation=\"dice\")\n",
651 | "latest = tf.train.latest_checkpoint(checkpoint_dir)\n",
652 | "print(latest)\n",
653 | "last_model.load_weights(latest)"
654 | ]
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": 26,
659 | "metadata": {},
660 | "outputs": [
661 | {
662 | "name": "stdout",
663 | "output_type": "stream",
664 | "text": [
665 | "WARNING:tensorflow:Layer dien_1 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
666 | "\n",
667 | "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
668 | "\n",
669 | "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
670 | "\n"
671 | ]
672 | },
673 | {
674 | "data": {
675 | "text/plain": [
676 | "(0.029646765, 0.26222047, 0.29186723)"
677 | ]
678 | },
679 | "execution_count": 26,
680 | "metadata": {},
681 | "output_type": "execute_result"
682 | }
683 | ],
684 | "source": [
685 | "model= last_model\n",
686 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show, clk_length, show_length = data_reader.get_test_data(test_data)\n",
687 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n",
688 | "aux_loss, target_loss, final_loss = get_test_loss(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label)\n",
689 | "aux_loss, target_loss, final_loss"
690 | ]
691 | },
692 | {
693 | "cell_type": "code",
694 | "execution_count": 27,
695 | "metadata": {},
696 | "outputs": [],
697 | "source": [
698 | "def convert_tensor(data):\n",
699 | " return tf.convert_to_tensor(data)\n",
700 | "\n",
701 | "def get_normal_data(data, col):\n",
702 | " return data[col].values\n",
703 | "\n",
704 | "def get_sequence_data(data, col):\n",
705 | " rst = []\n",
706 | " max_length = 0\n",
707 | " for i in data[col].values:\n",
708 | " temp = len(list(map(eval,i[1:-1].split(\",\"))))\n",
709 | " if temp > max_length:\n",
710 | " max_length = temp\n",
711 | "\n",
712 | " for i in data[col].values:\n",
713 | " temp = list(map(eval,i[1:-1].split(\",\")))\n",
714 | " padding = np.zeros(max_length - len(temp))\n",
715 | " rst.append(list(np.append(np.array(temp), padding)))\n",
716 | " return rst\n",
717 | "\n",
718 | "def get_evaluate_data(data):\n",
719 | " batch_data = data\n",
720 | " click = get_normal_data(batch_data, \"guide_dien_final_train_data.clk\")\n",
721 | " target_cate = get_normal_data(batch_data, \"guide_dien_final_train_data.cate_id\")\n",
722 | " target_brand = get_normal_data(batch_data, \"guide_dien_final_train_data.brand\")\n",
723 | " cms_segid = get_normal_data(batch_data, \"guide_dien_final_train_data.cms_segid\")\n",
724 | " cms_group = get_normal_data(batch_data, \"guide_dien_final_train_data.cms_group_id\")\n",
725 | " gender = get_normal_data(batch_data, \"guide_dien_final_train_data.final_gender_code\")\n",
726 | " age = get_normal_data(batch_data, \"guide_dien_final_train_data.age_level\")\n",
727 | " pvalue = get_normal_data(batch_data, \"guide_dien_final_train_data.pvalue_level\")\n",
728 | " shopping = get_normal_data(batch_data, \"guide_dien_final_train_data.shopping_level\")\n",
729 | " occupation = get_normal_data(batch_data, \"guide_dien_final_train_data.occupation\")\n",
730 | " user_class_level = get_normal_data(batch_data, \"guide_dien_final_train_data.new_user_class_level\")\n",
731 | " hist_brand_behavior_clk = get_sequence_data(batch_data, \"guide_dien_final_train_data.click_brand\")\n",
732 | " hist_cate_behavior_clk = get_sequence_data(batch_data, \"guide_dien_final_train_data.click_cate\")\n",
733 | " hist_brand_behavior_show = get_sequence_data(batch_data, \"guide_dien_final_train_data.show_brand\")\n",
734 | " hist_cate_behavior_show = get_sequence_data(batch_data, \"guide_dien_final_train_data.show_cate\")\n",
735 | " return tf.one_hot(click, 2), convert_tensor(target_cate), convert_tensor(target_brand), convert_tensor(cms_segid), convert_tensor(cms_group), convert_tensor(gender), convert_tensor(age), convert_tensor(pvalue), convert_tensor(shopping), convert_tensor(occupation), convert_tensor(user_class_level), convert_tensor(hist_brand_behavior_clk), convert_tensor(hist_cate_behavior_clk), convert_tensor(hist_brand_behavior_show), convert_tensor(hist_cate_behavior_show)"
736 | ]
737 | },
738 | {
739 | "cell_type": "code",
740 | "execution_count": 29,
741 | "metadata": {},
742 | "outputs": [],
743 | "source": [
744 | "label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show = get_evaluate_data(test_data)\n",
745 | "user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict = get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)\n",
746 | "output, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": 30,
752 | "metadata": {},
753 | "outputs": [
754 | {
755 | "name": "stdout",
756 | "output_type": "stream",
757 | "text": [
758 | "[训练集]正例:负例=501 : 9435\n",
759 | "[测试集]正例:负例=56 : 943\n"
760 | ]
761 | }
762 | ],
763 | "source": [
764 | "train_label = train_data[\"guide_dien_final_train_data.clk\"].values\n",
765 | "positive_num = len(train_label[train_label == 1])\n",
766 | "negative_num = len(train_label[train_label == 0])\n",
767 | "print(\"[训练集]正例:负例=%d : %d\" % (positive_num, negative_num))\n",
768 | "test_label = test_data[\"guide_dien_final_train_data.clk\"].values\n",
769 | "positive_num = len(test_label[test_label == 1])\n",
770 | "negative_num = len(test_label[test_label == 0])\n",
771 | "print(\"[测试集]正例:负例=%d : %d\" % (positive_num, negative_num))"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": 31,
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "y_true = label.numpy()[:,-1]\n",
781 | "y_score = output.numpy()[:,-1]"
782 | ]
783 | },
784 | {
785 | "cell_type": "code",
786 | "execution_count": 48,
787 | "metadata": {},
788 | "outputs": [],
789 | "source": [
790 | "threshold = 0.0031\n",
791 | "y_pre = y_score.copy()\n",
792 | "y_pre[y_pre > threshold] = 1\n",
793 | "y_pre[y_pre <= threshold] = 0"
794 | ]
795 | },
796 | {
797 | "cell_type": "code",
798 | "execution_count": 34,
799 | "metadata": {},
800 | "outputs": [],
801 | "source": [
802 | "import numpy as np\n",
803 | "from sklearn.metrics import accuracy_score\n",
804 | "from sklearn.metrics import f1_score\n",
805 | "from sklearn.metrics import auc\n",
806 | "import sklearn.metrics as sm\n",
807 | "from sklearn.metrics import roc_curve, auc\n",
808 | "import matplotlib as mpl \n",
809 | "import matplotlib.pyplot as plt"
810 | ]
811 | },
812 | {
813 | "cell_type": "code",
814 | "execution_count": 50,
815 | "metadata": {},
816 | "outputs": [
817 | {
818 | "name": "stdout",
819 | "output_type": "stream",
820 | "text": [
821 | "0.8818818818818819\n"
822 | ]
823 | }
824 | ],
825 | "source": [
826 | "print(accuracy_score(y_true, y_pre))"
827 | ]
828 | },
829 | {
830 | "cell_type": "code",
831 | "execution_count": 51,
832 | "metadata": {},
833 | "outputs": [
834 | {
835 | "name": "stdout",
836 | "output_type": "stream",
837 | "text": [
838 | "混淆矩阵为:\n",
839 | "[[876 67]\n",
840 | " [ 51 5]]\n"
841 | ]
842 | }
843 | ],
844 | "source": [
845 | "m = sm.confusion_matrix(y_true, y_pre)\n",
846 | "print('混淆矩阵为:', m, sep='\\n')"
847 | ]
848 | },
849 | {
850 | "cell_type": "code",
851 | "execution_count": 52,
852 | "metadata": {},
853 | "outputs": [
854 | {
855 | "name": "stdout",
856 | "output_type": "stream",
857 | "text": [
858 | "分类报告为:\n",
859 | " precision recall f1-score support\n",
860 | "\n",
861 | " 0.0 0.94 0.93 0.94 943\n",
862 | " 1.0 0.07 0.09 0.08 56\n",
863 | "\n",
864 | " accuracy 0.88 999\n",
865 | " macro avg 0.51 0.51 0.51 999\n",
866 | "weighted avg 0.90 0.88 0.89 999\n",
867 | "\n"
868 | ]
869 | }
870 | ],
871 | "source": [
872 | "r = sm.classification_report(y_true, y_pre)\n",
873 | "print('分类报告为:', r, sep='\\n')"
874 | ]
875 | },
876 | {
877 | "cell_type": "code",
878 | "execution_count": 53,
879 | "metadata": {},
880 | "outputs": [
881 | {
882 | "data": {
883 | "text/plain": [
884 | "0.679821239206181"
885 | ]
886 | },
887 | "execution_count": 53,
888 | "metadata": {},
889 | "output_type": "execute_result"
890 | }
891 | ],
892 | "source": [
893 | "from sklearn.metrics import roc_auc_score\n",
894 | "auc_score = roc_auc_score(y_true,y_score)\n",
895 | "auc_score"
896 | ]
897 | },
898 | {
899 | "cell_type": "code",
900 | "execution_count": 54,
901 | "metadata": {},
902 | "outputs": [],
903 | "source": [
904 | "def plot_roc(labels, predict_prob):\n",
905 | " false_positive_rate,true_positive_rate,thresholds=roc_curve(labels, predict_prob)\n",
906 | " roc_auc=auc(false_positive_rate, true_positive_rate)\n",
907 | " plt.title('ROC')\n",
908 | " plt.plot(false_positive_rate, true_positive_rate,'b',label='AUC = %0.4f'% roc_auc)\n",
909 | " plt.legend(loc='lower right')\n",
910 | " plt.plot([0,1],[0,1],'r--')\n",
911 | " plt.ylabel('TPR')\n",
912 | " plt.xlabel('FPR')\n",
913 | " plt.show()"
914 | ]
915 | },
916 | {
917 | "cell_type": "code",
918 | "execution_count": 55,
919 | "metadata": {},
920 | "outputs": [
921 | {
922 | "data": {
923 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3de5xV8/7H8dendFVIJemioXtUNLq4K5JcyuV0ooMcdHI/+HUUJ0KOhOOSlOR6OIWi4nRBJ+RacSpdlCQZJU3RfaqZ+f7++M7UNGamPbXXXvvyfj4e89iz1v7uvT/LZH32+n6/6/M15xwiIpK6yoQdgIiIhEuJQEQkxSkRiIikOCUCEZEUp0QgIpLilAhERFKcEoGISIpTIhApgZmtMLNtZrbZzH42sxfNrEqB5080s/+a2SYz22Bmb5tZ80LvcZCZPW5mK/PeZ1nedo3YH5HI7ykRiOzd+c65KkBr4DhgAICZdQDeBSYCRwBpwDzgEzM7Kq9NeWA60ALoAhwEnAisA9rG9jBEima6s1ikeGa2ArjGOfd+3vZQoIVz7lwzmwl87Zy7vtBrpgBrnXNXmNk1wAPA0c65zTEOXyQiuiIQiZCZ1QXOAZaZWWX8N/s3imj6OnBW3u9nAlOVBCSeKRGI7N0EM9sE/Aj8AtwDHIr//2d1Ee1XA/n9/9WLaSMSN5QIRPauu3OuKnA60BR/kv8VyAVqF9G+NpCZ9/u6YtqIxA0lApEIOec+BF4EHnHObQE+A/5QRNMe+AFigPeBs83swJgEKbIPlAhESudx4Cwzaw30B640s5vNrKqZVTOzwUAH4N689v/CdymNN7OmZlbGzKqb2Z1m1jWcQxDZkxKBSCk459YCLwMDnXMfA2cDF+HHAX7ATy892Tn3bV777fgB42+A94CNwCx899IXMT8AkSJo+qiISIrTFYGISIpTIhARSXFKBCIiKU6JQEQkxR0QdgClVaNGDdegQYOwwxARSShffvllpnOuZlHPJVwiaNCgAXPmzAk7DBGRhGJmPxT3nLqGRERSnBKBiEiKUyIQEUlxSgQiIilOiUBEJMUFlgjM7Hkz+8XMFhTzvJnZk3kLec83s+ODikVERIoX5BXBi/jFuotzDtAo76cPMCLAWEREpBiB3UfgnPvIzBqU0KQb8LLz5U8/N7NDzKy2c07L+olIUsrNhSefhPXrS/e6Mjk7qfbb9zTr1pjOnaMfV5g3lNXBL9iRLyNv3+8SgZn1wV81UL9+/ZgEJyISbUuXwq23+t/NIntNa/c/nuPPHMYvPFNxKZ07R3+xuzAHi4v6z1Dk4gjOuVHOuXTnXHrNmkXeIS0iEvdycvzj66/7q4MSf7ZmkXvHAL4qewLH1VpNnfHDuO/RYFY8DfOKIAOoV2C7LrAqpFhEROJL9+4wbRpcdRU8+ihUqxbYR4V5RTAJuCJv9lB7YIPGB0QkmV15pX8sU9yZd9MmyMryv/fvD+++C88/H2gSgGCnj44BPgOamFmGmV1tZn3NrG9ek8nAcmAZ8CxwfVCxiIiE7bvv4MsvIT0dzjyziAbTpsExx8D99/vt00+Hs86KSWxBzhq6dC/PO+CGoD5fRCSevPKKHyB+8004+OACT6xfD7fdBi+9BE2bwrnnxjy2hCtDLSKyr2bPhpkzw/ns55+Hjh2hXsGR0enToVcvWLcO7roL/v53qFgx5rEpEYhIyujVC779NrzPf+SRQjsOOwzS0mDqVGjdOpSYQIlARFLE99/7JDB0KPzlL7H//LJl4cDKDl58Cb76yt9Zduyx8Omnkd9UEBAlAhFJCe++6x/PPx8OOiiEAL7/3meg996DU06BbdugUqXQkwAoEYhIgrvrLhg1au/tNm/2/fNNmgQf0x5ycmD4cBgwwM8bffppnxCKnUMae0oEIpLQPvsMDjgALrpo7227dAnhC3hmJtx9N5x2GowcCXFYJkeJQEQS
XqNG/kt33Ni5E159Fa64AmrV8mMCaWlx0Q1UFCUCEUkY2dkwfrzv5sm3apWffBM3vvwS/vxnmD8fateGs8+Go44KO6oSKRGISML47DPo2fP3+1u2jH0sv7NtG9x7r58jethh8NZbPgkkACUCEUkYO3b4xzfegHbtdu8//PBw4tlD9+5+atI118DDD8Mhh4QdUcSUCEQk4dSqVegO3bBs3Ajly/u7ge+8E/72N+jUKeyoSi1+5i+JiCSSyZN9kbj77vPbp52WkEkAlAhEREonMxMuv9wXh6taFS64IOyI9psSgYhIpN57D5o3h7Fj/b0BX30F7duHHdV+0xiBiITGOT++umlTZO3nzw82nr2qXRsaN4YRI3ydoCShRCAioZk4ES68sPSvC3jBrt2cg+eeg//9z9+xdswxvo51nN4Ytq+UCEQkNEOH+htuJ06M/NxatSoceWSwcQGwfDlcey38979+tbA4KhIXbUoEIhKKTz7xN4gNGxZnvSw5Ob5E9F13+SJGzzzj7w2IoyJx0aZEICJR99tvcPHFsGFD8W1++gkOPRSuuip2cUUkM9PfIdypkx8LqFs37IgCp0QgIlH37be+RyU93d/8VZTDD/crhh14YGxjK9KOHX5R4d69fcBz5/r+pyTsBiqKEoGIBGbQoFDWYi+d2bN9kbgFC/y3/86doUGDsKOKKSUCkSTwzTeQkRF2FLstWRJ2BBHYutXfC/DYY35a6KRJPgmkICUCkQS3cyccf7yf1BJvQlkSMlLdusH770OfPn760sEHhx1RaJQIRBLc2rU+CfTrF1/VDg48EFq3DjuKQjZsgAoVfJG4gQN9obgzzgg7qtApEYgkuJ9/9o8dOsDJJ4cbS1x75x3o29fXCXrwQTj11LAjihvJOzFWJEVcd51/jIua/PFo7Vq47DI4/3w/XzWSxY1TjBKBSALLzIRZs6By5Ti7KStevPuuLxI3bpy/N2DOHDjhhLCjijvqGhJJYPlF2CZMgCpVwo0lLtWpA82a+RvDWrQIO5q4pUQgkoAyM+H772HKFL/dqlW48cSN3FwYPdoXics/+X/0UdhRxT0lApEE1KnT7quBevX8Wukpb9kyXyTugw/8TKD8InGyVxojEElAGzZAx45+Isz06WFHE7KcHHj0UWjZ0i8U8+yz/j+KkkDEAk0EZtbFzJaY2TIz61/E8web2dtmNs/MFppZvJWfEolb9er58g2NGoUdScgyM2HwYDjrLFi0yFcKTZEaQdESWNeQmZUFhgNnARnAbDOb5JxbVKDZDcAi59z5ZlYTWGJmrzrndgQVl0gYNm6Em26KfCWuvfnll+i8T8Lavh1efhmuvnp3kbj69ZUA9lGQYwRtgWXOueUAZjYW6AYUTAQOqGpmBlQB1gPZAcYkEoq5c/15Ky0tOrN7GjZM2bI48MUXPgEsXOgrhHbuHKOVapJXkImgDvBjge0MoF2hNk8Bk4BVQFXgj8653MJvZGZ9gD4A9evXDyRYkVgYPdr37cs+2LLFl4V4/HE/LfQ//0nhbBhdQSaCoq7RXKHts4G5QEfgaOA9M5vpnNu4x4ucGwWMAkhPTy/8HiJRs2WLXzAl2n78ce9tZC+6d/dF4q67DoYMifOKdoklyESQAdQrsF0X/82/oKuAIc45Bywzs++BpsCsAOMSKdY55/i1yYNSsWJw752UfvvNF4mrVMmXjB44UDWCAhBkIpgNNDKzNOAnoCdwWaE2K4FOwEwzqwU0AZYHGJNIiTIzoW1buOWW6L93lSrQrnDnqBRv0iT/7f/yy/0VwCmnhB1R0gosETjnss3sRmAaUBZ43jm30Mz65j0/ErgfeNHMvsZ3Jd3hnMsMKiaRSBx5pK9RJiH55Re4+WZ47TV/b8All4QdUdIL9M5i59xkYHKhfSML/L4K0GiPhG78eD/2uGoVHHNM2NGksKlT/ULGmzfD/ffDHXdAuXJhR5X0VGJCBH/OWboUatRQD0So6tXzZVSfftpXDZWY
UCIQwS/u0quXr04gMZSbC88842+0eOYZXyTugw/CjirlKBFI0nGudHfe5ub6tUtq1QouJinC0qW+HMTMmb48RFaWplWFRIlAksqOHdC1674VYqtTJ/rxSBGys32RuHvu8dNCX3gBrrxS5SFCpEQgScM5X89n+nS/JnndupG/tnx56NEjuNikgHXr4KGHfMYePhxq1w47opSnRCBJY/hwGDUKBgyABx4IOxrZw/bt8OKLfr2AWrVg3jw/MCxxQYlAEtIPP8BTT/leBvBdQs8849cnHzw43NikkM8+80XiFi+Go4+GM89UEogzSgSSkF5/HR55BKpW3d21fNJJ8MorUEbLLcWHzZvh73+HJ5/0J/6pU30SkLijRCAJyeWVHvz5Z6hcOdxYpBjdu/sBmxtvhH/8w2dtiUv67iQJJydnd5eQxJlff/VrBQMMGuSnhg4bpiQQ53RFIAklKwsaNIA1a/y2uoHiyJtvwg03wBVX+FlBJ58cdkQSIf1vJAllyxafBM47D8aM0f1HceHnn31huIsvhsMPh549w45ISkmJQBJS584638SFKVN8TaB33vHjALNmwXHHhR2VlJK6hiShzNKSRfHlyCP9iX/4cGjaNOxoZB/pikASyuOP+8eGDcONI2Xl5vobOK691m83b+5nBikJJDQlAkkoZn6Vr3POCTuSFLRkiV8m8qab/CLMWVlhRyRRokQgIiXbuRMefBBatYJFi3ypiClTNFKfRJQIJG798APUrOnPN/k/776rKaMx9+uv8PDDvn7HokWqFJqENFgscWvlSr+Y/B//6O8dyHfWWaGFlDqysuD556FvXzjsMJg/v3TlXCWhKBFI3Lv2WujUKewoUsjHH/sicUuXQuPGvj6QkkBSUyKQmMvOhpdegk2bSm733XexiUfybNrka3gPH+4vwd59V0XiUoQSgcTcrFl+hcJIlCundUtipnt3mDEDbrnF1/KuUiXsiCRGlAgk5nbu9I+TJsEpp5Tctnx5VRcN1Pr1fhS+cmW4/34/CNyhQ9hRSYwpEUhoqlSBQw4JO4oUNm6cLxJ35ZUwdCiceGLYEUlIlAgkcP37+/XJ8+3Y4R81AzEkq1f7BPDWW9CmDfTqFXZEEjIlAgncJ5/AAQfABRfs3le1KpxwQngxpaz//Af+9Cc/PfShh+C22/wfR1Ka/gVITDRtCiNGhB2FcNRRPgM/9ZSfGiqCEoFEyVdfwbx5RT/3889Qv35s45E8OTn+pD9/Pjz3HDRr5qeFihSgRCBR8ac/weLFxT+vbqAQLFrk5+l+9hl07eq7g1QfSIqgRCBRsX27n4aeXya6sCOOiG08KW3HDj8L6P77/WDMK6/AZZdpdF6KFWgiMLMuwBNAWWC0c25IEW1OBx4HygGZzrnTgoxJglOlil+nREL222/w2GNw4YXw5JO+VpBICQKr42hmZYHhwDlAc+BSM2teqM0hwNPABc65FsAfgopHou/++6FJE/+zcmXY0aS4bdv8WEBurj/xf/01jB2rJCARCbKgb1tgmXNuuXNuBzAW6FaozWXAm865lQDOuV8CjEeibOpU/+Xz+OP92uV//nPYEaWojz7yawXcdJMvEQHqi5NSCbJrqA7wY4HtDKBdoTaNgXJm9gFQFXjCOfdy4Tcysz5AH4D6mn4SV1q2hDFjwo4iRW3c6O/WGzEC0tLg/fdVplX2SZCJoKiRKVfE57cBOgGVgM/M7HPn3NI9XuTcKGAUQHp6euH3kBj55RdfoTjfunWqAxSq7t3hgw/g1lt9P92BB4YdkSSoIBNBBlCvwHZdYFURbTKdc1uALWb2EdAKWIrEnbvugtGj99zXsmU4saSszEyffStXhgce8DOB2rcPOypJcEGOEcwGGplZmpmVB3oCkwq1mQicYmYHmFllfNdRCbPRJUzbtvn1SebN2/3z0kthR5UinPODv82awT33+H0dOigJSFQEdkXgnMs2sxuBafjpo8875xaaWd+850c65xab2VRg
PpCLn2K6IKiYZP9VqKCrgJj76Se4/npft/uEE+CKK8KOSJJMoPcROOcmA5ML7RtZaPth4OEg45DSe/BBX5yyoO++g2rVwoknZb3zjq8OunMnPPII/PWvULZs2FFJktGdxVKkN96AjAxIT9+9r0YNTUqJuYYN/ToBw4b530UCoEQgxWrf3vdGSAzl5Pi7gefNgxdf9GVbp0wJOypJckoEAvjzzxdf+LpksPeF5SUACxfC1Vf7P8S556pInMSMEoEAviu6e/c997VtG04sKWfHDhgyxC8Yf/DB8O9/Q8+eKhInMaNEIABs3uwfX3kF6uXd/dGqVXjxpJTffvPdQX/4gy/fWrNm2BFJilEikD20bQuNGoUdRQrYuhWefRZuvHF3kbjatcOOSlJUqW8oM7OyZqbVrkX21YwZcOyxfiroBx/4fUoCEqJiE4GZHWRmA8zsKTPrbN5NwHKgR+xCFEkSGzbAX/4CHTv6/v8ZMzQfV+JCSV1D/wJ+BT4DrgH6AeWBbs65uTGITSS5dO/uS0b36weDBqlin8SNkhLBUc65YwHMbDSQCdR3zmliYRLaujXsCJLU2rW+Kmjlyv527bJltYCzxJ2Sxgh25v/inMsBvlcSSF6vvuof9SU1Spzz00ALFolr315JQOJSSVcErcxsI7vXFahUYNs55w4KPDqJmUMPhXLloE6dsCNJAhkZcN11/uaMdu2gd++wIxIpUbGJwDmnylYppmnTsCNIApMmwZ/+5G/Vfuwxv3ykisRJnCs2EZhZRaAv0BBfJvp551x2rAKT2MjMhNtvh88/90XlZD81bgwnn+wXkj/qqLCjEYlISWMELwHpwNdAV+DRmEQkMfXFF/Dyy76kzQUXhB1NAsrO9uWh89cIaNoUJk9WEpCEUtIYQfMCs4aeA2bFJiQJw2uvaRyz1ObP90Xi5syBbt1UJE4SVqSzhtQllKRWrw47ggS0fbufCdSmDaxcCa+/7lfxURKQBFXSFUHrvFlC4GcKadZQkpk3D6691v+uc1gpbNwITz8Nl17qB4SrVw87IpH9UlIimOecOy5mkUjMZWb6x3794Jhjwo0l7m3ZAqNGwc03++qgCxZArVphRyUSFSV1DbmYRSGhOv98lb4v0fTpvkjcbbfBhx/6fUoCkkRKuiI4zMxuK+5J59w/A4hHJH789hv83//Bc8/52twffginnhp2VCJRV1IiKAtUYfedxZJk8usLaXygGBdeCDNnwh13+MHhSpXCjkgkECUlgtXOuftiFonE3Jo1/lG9HAWsWQNVqvhCcUOGwAEH+NlBIkmspDECXQkkuYUL/eNhh4UbR1xwDv71L2jefHeRuHbtlAQkJZSUCLRiRhJ77DG/PO6hh6priJUr4dxz/d3BTZr4m8REUkhJRefWxzIQia2MDP/45pvhxhG6iRN9kTjn/ALy11+vInGScrR4fQqrUgVOOy3sKELinJ8z27QpnH46DBsGDRqEHZVIKJQIUsjKlTB8uK+Tlr9mesrJzoZHH4Wvv4ZXXvFdQW+/HXZUIqFSIkghb7wBQ4f6CTFm0LZt2BHF2Lx58Oc/w1df+amhKhInAigRpJTcXP+4Zo1PBikjKwsGD4aHHvJ1gcaNg4svDjsqkbhR0qwhSSLZ2X7RrJS0aRM88wz06gWLFikJiBQSaCIwsy5mtsTMlplZ/xLanWBmOWZ2SZDxpKr33/c9IAMG+O0yqZD+N2/2C8bk5PgicYsWwYsv+vmyIrKHwLqGzKwsMBw4C8gAZpvZJOfcoiLaPQRMCyqWVPf99/582K8ftG6dApUS3n0X+vTxo+Nt2sAZZ/hkICJFCvK7YVtgmXNuuXNuBzAW6FZEu5uA8cAvAcYiwC23wGWXhR1FgNavh6uugrPP9pdAM2f6JCAiJQpysLgO8GOB7QygXcEGZlYHuBDoCBS7UKKZ9QH6ANSvXz/qgSarr7/2C2fNmRN2JDFy4YXwySdw550wcKBmBIlEKMhEUFStosJrHDwO
3OGcy7ESCuI750YBowDS09O1TkKE7r0Xxo/3v9eqBYccEm48gfj5Z6ha1U+DevhhKF/e93+JSMSC7BrKAOoV2K4LrCrUJh0Ya2YrgEuAp82se4AxpZRVq6BjRz8+sHp1kk0Zdc4P/jZvDnff7fe1baskILIPgkwEs4FGZpZmZuWBnsCkgg2cc2nOuQbOuQbAOOB659yEAGNKKWvWwOGH+1lCSbUC2YoV0KWLHw9o0cIPDIvIPgssETjnsoEb8bOBFgOvO+cWmllfM+sb1OeKN38+LF/uE0FSeestv8Dyp5/CU0/5VcOaNAk7KpGEFuidxc65ycDkQvtGFtO2d5CxpJr33/ePlyTLnRn5ReJatIAzz4QnnoAjjww7KpGkkAq3FqWkefP81UCHDmFHsp927oR//MPfFQzQuDFMmKAkIBJFqjWUJGbNgs8/37390UfQqlV48UTFV1/5RWLmzoUePWD7dqhQIeyoRJKOEkGSuP56+PLLPff1TdSRmG3b4L77/HTQmjX9uEB3TSYTCYoSQZLYuRO6dvXL7oLvTq9WLdyY9tmWLfDcc3Dllb5eUMIeiEhiUCJIIuXLJ3BNtU2bYMQIuP12qFHDF4mrUSPsqERSggaLk0T+WgMJaepUPyW0f39fHwiUBERiSIkgCeTmwg8/QJ06YUdSSuvW+e6fc87xtz1/8olfP1hEYkpdQ0lgxQrfs5Jws4QuusjfGDZwINx1l2YEiYREiSDBvf++71mBBEkEq1f7InFVqviB4PLlEyRwkeSlRJDANmyAs87yvx90kO9mj1vOwQsvwG23+QXk//lPOKHYyuMiEkMaI0hgO3f6x3vu8WMElSuHG0+xli+Hzp39zWGtWiXwDQ4iyUlXBEmgRo04XmvgzTfh8suhbFk/PbRPnxRZNFkkcSgRSDDyi8Qde6wvGf3441Cv3t5fJyIxp69mEl07dsDgwX5xZOegUSO/TJqSgEjcUiKQ6Jkzxw8ADxzot3fsCDceEYmIuoYSSG4uTJ/uS/EAbNwYbjy7bNvmR6wffdTXvp44ES64IOyoRCRCSgQJ5NNP/eSbwkIfKN6yxa8ffPXVMHRoHAQkIqWhRJBAtm71jy+8sHuN9nLl/PrtMbdxIzz9NPTr56ctLV4M1auHEIiI7C8lggTUuPHuRBCK//zH3wuwahW0b+/rAykJiCQsDRYnkL//PeQA1q71S0aedx4cfLDvq1KROJGEpyuCBOGcX7kRQiwlcfHFfj3MQYNgwABfJ0hEEp4SQYLYsAFycvzEnIMOiuEH//ST//ZfpQo89pivEBrXRY1EpLSUCOLYggWQmel/z8jwj7VqxejDnYPRo+H//s/PBvrnP6FNmxh9uIjEkhJBnPr5Z2jZ0p+PC0pLi8GHf/cdXHstzJgBZ5wBN9wQgw8VkbAoEcSplSt9EhgyBNq18/uqVInBl/Jx4+CKK/y81FGj4JprfM0gEUlaSgRxas0a/9ixY4zK9ucXiWvVCs49148H1K0bgw8WkbBp+micevFF/xj4mMCOHXDvvdCz5+4icW+8oSQgkkKUCOLU2rX+MdAF6WfN8n1NgwbBAQeoSJxIilIiiGNnnOHXc4m6rVv9bKAOHeDXX+Htt+HVV7V4vEiK0hhBiJYv9/cHFGXz5gBrt23bBq+84lcLe+ihGN+YICLxJtBEYGZdgCeAssBo59yQQs/3Au7I29wMXOecmxdkTPFi2TLfHV+Sc8+N4gdu2ABPPQV33OHrAi1eDNWqRfEDRCRRBZYIzKwsMBw4C8gAZpvZJOfcogLNvgdOc879ambnAKOAdkHFFE+++84/DhkCTZsW3SY9PUof9vbbvkjczz/DSSf5+kBKAiKSJ8grgrbAMufccgAzGwt0A3YlAufcpwXafw6kzFSV/OmhF1209yuDfbZ2Ldx8M4wd69cOnjgxitlFRJJFkImgDvBjge0MSv62fzUwpagnzKwP0Aegfv360Yov5nJz4fbbffmeb7/1+w4/PMAP
zC8Sd999vktIReJEpAhBJoKibkd1RezDzM7AJ4KTi3reOTcK321Eenp6ke+RCFasgMcfhyOO8HXcLrrI3y0cVRkZfpS5ShX/YRUqQIsWUf4QEUkmQU4fzQDqFdiuC6wq3MjMWgKjgW7OuXUBxhO6/O6g0aNh0SIYPz6K1Rtyc+GZZ/xyZfmLxx9/vJKAiOxVkFcEs4FGZpYG/AT0BC4r2MDM6gNvApc755YGGEsocnPhhx92F477+mv/GPW7hb/91heJ+/BD6NQJbropyh8gIskssETgnMs2sxuBafjpo8875xaaWd+850cCdwPVgafNfzXOds4lzWjm3/8ODz74+/1RvVv4jTd8kbgKFeC55+Cqq1QkTkRKJdD7CJxzk4HJhfaNLPD7NcA1QcYQprVrfXf9E0/s3le7dpSuCPKLxB13HHTr5tcLOOKIKLyxiKQa3VkcsMqV/Rf2qNm+HR54wN8Q9vrr0LChnx4qIrKPVGsoQO+845eXjJrPP/cDwPffD5UqqUiciESFEkGANm0qvpZQqWzZArfeCiee6N908mR4+WUViRORqFAiCFD58n4yz37LyvLdP9dfDwsXwjnnROFNRUQ8jRFE0YYNkJ29ezs3dz/e7LffYNgwGDBgd5G4wMqRikgq0xVBlLz+uj9P16ix+2fDBr/0b6lNmOBvDLv3Xvg0rxyTkoCIBERXBFHyY15VpaFDoWJF/7sZdO9eijdZs8bfDPbGG37t4LffjsFq9SKS6pQIoqxvX6hadR9ffMklfvnIwYPhb3/bx8sJEZHSUSKIgg8+8F/e98nKlX5tgKpV4ckn/Uyg5s2jGZ6ISIk0RhAFDzwAH33k67tVqhThi3JzYfhw/6K77/b7jjtOSUBEYk6JIApyc/3CXwsWwAGRXGMtWQKnnQY33ugXkL/llsBjFBEpjhJBFDi3u8LoXr3+uh8IXrAAXngBpk2DBg2CDE9EpERKBPtpxw6YMQN27txLw/xM0aaNX5Fm8WLo3VuVQkUkdEoE+2nrVv/YsGExDbKy4K67/Iwg5+Doo+Hf/w54jUoRkcgpEUTJCScUsfPTT/0A8D/+4WcFqUiciMQhTR8tZPx432sTqW3biti5eTPceSc89RTUqwdTp8LZZ0ctRhUTFpcAAA2zSURBVBGRaFIiKOTyy4s5uZegbFnf47PLjh0wbhzccMPuqwERkTilrqFCcnL8Tb07d0b+s307nH/Sehg0yFedO/RQf1kxbJiSgIjEPV0RFKFMmQjvB8g3frz/9p+ZCR07wqmnwsEHBxafiEg06YqggDPP9L06ZSL9r7J6NVx8sZ8RdMQRMGeOTwIiIglEVwQF5Fd87t07whf06AGzZ8OQIXD77aW8jBARiQ86cxVQrhz85S/QqFEJjX74wY8BVK3qxwAqVYImTWIWo0hYdu7cSUZGBllZWWGHIiWoWLEidevWpVwpqhcrEeT58UfYuLGEBvlF4gYMgGuugccfh9atYxafSNgyMjKoWrUqDRo0wHRHfFxyzrFu3ToyMjJIS0uL+HUaI8jzyCP+sV69Ip785hvf93/zzXDKKX4heZEUk5WVRfXq1ZUE4piZUb169VJftSkR5Nm50y8FcNtthZ4YO9YXiVu8GF5+GSZPhiOPDCVGkbApCcS/ffkbKREUcNBBBTbyV54/4QT4wx9g0SJ/t5n+RxCRJJPSiWDDBjj2WKhTB156Ke8cv20b9O/vp4XmF4l75RWoVSvscEUEeOuttzAzvvnmm137PvjgA84777w92vXu3Ztx48YBfqC7f//+NGrUiGOOOYa2bdsyZcqU/Y7lwQcfpGHDhjRp0oRp06YV227YsGE0adKEFi1a8Le//Q2AV199ldatW+/6KVOmDHPnzgXgtddeo2XLlnu0B1i5ciVnnHEGxx13HC1btmTy5Mn7fQyQ4oPFGRl+WYCOHeGoo+CCajOh9TWwdClcfbXvLypfPuwwRaSAMWPGcPLJJzN27FgGDRoU0WsGDhzI
6tWrWbBgARUqVGDNmjV8+OGH+xXHokWLGDt2LAsXLmTVqlWceeaZLF26lLJly+7RbsaMGUycOJH58+dToUIFfvnlFwB69epFr169APj666/p1q0brVu3Zt26dfTr148vv/ySmjVrcuWVVzJ9+nQ6derE4MGD6dGjB9dddx2LFi2ia9eurFixYr+OA1I8EeS78cpNXPhFf3j4aUhLg/fe83eXiUiR/vpXyPvyGjWtW/vJeCXZvHkzn3zyCTNmzOCCCy6IKBFs3bqVZ599lu+//54KFSoAUKtWLXr06LFf8U6cOJGePXtSoUIF0tLSaNiwIbNmzaJDhw57tBsxYgT9+/ff9dmHHXbY795rzJgxXHrppQAsX76cxo0bU7NmTQDOPPNMxo8fT6dOnTAzNuZNb9ywYQNHHHHEfh1DvpRKBDNnQsHkmZHhHy1nJ0yY4P91Dx4MBx4YSnwiUrIJEybQpUsXGjduzKGHHspXX33F8ccfX+Jrli1bRv369Tloj0HAot16663MmDHjd/t79uxJ//7999j3008/0b59+13bdevW5aeffvrda5cuXcrMmTO56667qFixIo888ggnFKpb/9prrzFx4kQAGjZsyDfffMOKFSuoW7cuEyZMYEdeCftBgwbRuXNnhg0bxpYtW3j//ff3ekyRSJlEsH277wLKzvbbh7KOW3iCstzNIWmH+imiKhAnEpG9fXMPypgxY/jrX/8K+JPzmDFjOP7444udKVPaGTSPPfZYxG1dEevTFvV52dnZ/Prrr3z++efMnj2bHj16sHz58l1tv/jiCypXrswxxxwDQLVq1RgxYgR//OMfKVOmDCeeeCLLly8H/PH37t2b22+/nc8++4zLL7+cBQsWUCbiujhFCzQRmFkX4AmgLDDaOTek0POW93xXYCvQ2zn3VRCxrFnjk8A/HnBcVXUcNe69kbIb1nPda2dR8/RTACUBkXi2bt06/vvf/7JgwQLMjJycHMyMoUOHUr16dX799dc92q9fv54aNWrQsGFDVq5cyaZNm6i6ly97pbkiqFu3Lj/++OOu7YyMjCK7aurWrctFF12EmdG2bVvKlClDZmbmrq6fsWPH7uoWynf++edz/vnnAzBq1Khd4w7PPfccU6dOBaBDhw5kZWWRmZlZZHdTqTjnAvnBn/y/A44CygPzgOaF2nQFpgAGtAe+2Nv7tmnTxu2LL75wrjY/uVXtu/u15tu0cW7u3H16L5FUtGjRolA/f+TIka5Pnz577Dv11FPdRx995LKyslyDBg12xbhixQpXv35999tvvznnnOvXr5/r3bu32759u3POuVWrVrl//etf+xXPggULXMuWLV1WVpZbvny5S0tLc9nZ2b9rN2LECDdw4EDnnHNLlixxdevWdbm5uc4553JyclydOnXcd999t8dr1qxZ45xzbv369a5Vq1ZuyZIlzjnnunTp4l544QXnnP971K5de9d7FVTU3wqY44o7Xxf3xP7+AB2AaQW2BwADCrV5Bri0wPYSoHZJ77uviWDSJOdmcpLLqVDRuaFDndu5c5/eRyRVhZ0ITjvtNDdlypQ99j3xxBOub9++zjnnPv74Y9euXTvXqlUrl56e7t59991d7bZv3+769evnjj76aNeiRQvXtm1bN3Xq1P2OafDgwe6oo45yjRs3dpMnT961/+qrr3azZ8/e9dm9evVyLVq0cMcdd5ybPn36rnYzZsxw7dq1+9379uzZ0zVr1sw1a9bMjRkzZtf+hQsXuhNPPNG1bNnStWrVyk2bNq3IuEqbCMwV0c8VDWZ2CdDFOXdN3vblQDvn3I0F2rwDDHHOfZy3PR24wzk3p9B79QH6ANSvX7/NDz/8UOp4PvkExt89j/73VuKwkxvv62GJpKzFixfTrFmzsMOQCBT1tzKzL51z6UW1D3KMoKhRmsJZJ5I2OOdGAaMA0tPT9ylznXQSnDS91b68VEQkqQV5Z3EGULCEW11g1T60ERGRAAWZCGYDjcws
zczKAz2BSYXaTAKuMK89sME5tzrAmERkPwTVlSzRsy9/o8C6hpxz2WZ2IzANP4PoeefcQjPrm/f8SGAyfubQMvz00auCikdE9k/FihVZt26dSlHHMZe3HkHFihVL9brABouDkp6e7ubMmbP3hiISVVqhLDEUt0JZWIPFIpJEypUrV6pVryRxpHQZahERUSIQEUl5SgQiIiku4QaLzWwtUPpbi70aQGYUw0kEOubUoGNODftzzEc652oW9UTCJYL9YWZzihs1T1Y65tSgY04NQR2zuoZERFKcEoGISIpLtUQwKuwAQqBjTg065tQQyDGn1BiBiIj8XqpdEYiISCFKBCIiKS4pE4GZdTGzJWa2zMz6F/G8mdmTec/PN7Pjw4gzmiI45l55xzrfzD41s4RfpWdvx1yg3QlmlpO3al5Ci+SYzex0M5trZgvN7MNYxxhtEfzbPtjM3jazeXnHnNBVjM3seTP7xcwWFPN89M9fxa1hmag/+JLX3wFHAeWBeUDzQm26AlPwK6S1B74IO+4YHPOJQLW8389JhWMu0O6/+JLnl4Qddwz+zocAi4D6eduHhR13DI75TuChvN9rAuuB8mHHvh/HfCpwPLCgmOejfv5KxiuCtsAy59xy59wOYCzQrVCbbsDLzvscOMTMasc60Cja6zE75z51zv2at/k5fjW4RBbJ3xngJmA88EssgwtIJMd8GfCmc24lgHMu0Y87kmN2QFXziyRUwSeC7NiGGT3OuY/wx1CcqJ+/kjER1AF+LLCdkbevtG0SSWmP52r8N4pEttdjNrM6wIXAyBjGFaRI/s6NgWpm9oGZfWlmV8QsumBEcsxPAc3wy9x+DdzinMuNTXihiPr5KxnXIyhq6aTCc2QjaZNIIj4eMzsDnwhODjSi4EVyzI8DdzjncpJkRa1IjvkAoA3QCagEfGZmnzvnlgYdXEAiOeazgblAR+Bo4D0zm+mc2xh0cCGJ+vkrGRNBBlCvwHZd/DeF0rZJJBEdj5m1BEYD5zjn1sUotqBEcszpwNi8JFAD6Gpm2c65CbEJMeoi/bed6ZzbAmwxs4+AVkCiJoJIjvkqYIjzHejLzOx7oCkwKzYhxlzUz1/J2DU0G2hkZmlmVh7oCUwq1GYScEXe6Ht7YINzbnWsA42ivR6zmdUH3gQuT+BvhwXt9Zidc2nOuQbOuQbAOOD6BE4CENm/7YnAKWZ2gJlVBtoBi2McZzRFcswr8VdAmFktoAmwPKZRxlbUz19Jd0XgnMs2sxuBafgZB8875xaaWd+850fiZ5B0BZYBW/HfKBJWhMd8N1AdeDrvG3K2S+DKjREec1KJ5Jidc4vNbCowH8gFRjvnipyGmAgi/DvfD7xoZl/ju03ucM4lbHlqMxsDnA7UMLMM4B6gHAR3/lKJCRGRFJeMXUMiIlIKSgQiIilOiUBEJMUpEYiIpDglAhGRFKdEIBKhvAqmcwv8NMir9LnBzP5nZovN7J68tgX3f2Nmj4Qdv0hxku4+ApEAbXPOtS64w8waADOdc+eZ2YHAXDN7J+/p/P2VgP+Z2VvOuU9iG7LI3umKQCRK8so6fImvd1Nw/zZ8LZxELmwoSUyJQCRylQp0C71V+Ekzq46vD7+w0P5qQCPgo9iEKVI66hoSidzvuobynGJm/8OXdBiSVwLh9Lz98/G1b4Y4536OYawiEVMiENl/M51z5xW338waAx/njRHMjXVwInujriGRgOVVe30QuCPsWESKokQgEhsjgVPNLC3sQEQKU/VREZEUpysCEZEUp0QgIpLilAhERFKcEoGISIpTIhARSXFKBCIiKU6JQEQkxf0/8T/s5c/KQgMAAAAASUVORK5CYII=\n",
924 | "text/plain": [
925 | ""
926 | ]
927 | },
928 | "metadata": {
929 | "needs_background": "light"
930 | },
931 | "output_type": "display_data"
932 | }
933 | ],
934 | "source": [
935 | "plot_roc(y_true, y_score)"
936 | ]
937 | },
938 | {
939 | "cell_type": "markdown",
940 | "metadata": {},
941 | "source": [
942 | "# 整体训练图像"
943 | ]
944 | },
945 | {
946 | "cell_type": "code",
947 | "execution_count": 57,
948 | "metadata": {},
949 | "outputs": [
950 | {
951 | "data": {
952 | "text/html": [
953 | "\n",
954 | "\n",
967 | "
\n",
968 | " \n",
969 | " \n",
970 | " | \n",
971 | " train_aux_loss | \n",
972 | " train_target_loss | \n",
973 | " train_final_loss | \n",
974 | "
\n",
975 | " \n",
976 | " \n",
977 | " \n",
978 | " 0 | \n",
979 | " 0.895453 | \n",
980 | " 0.692025 | \n",
981 | " 1.587478 | \n",
982 | "
\n",
983 | " \n",
984 | " 1 | \n",
985 | " 0.883613 | \n",
986 | " 0.691035 | \n",
987 | " 1.574647 | \n",
988 | "
\n",
989 | " \n",
990 | " 2 | \n",
991 | " 0.871820 | \n",
992 | " 0.690196 | \n",
993 | " 1.562016 | \n",
994 | "
\n",
995 | " \n",
996 | " 3 | \n",
997 | " 0.860334 | \n",
998 | " 0.689409 | \n",
999 | " 1.549743 | \n",
1000 | "
\n",
1001 | " \n",
1002 | " 4 | \n",
1003 | " 0.848613 | \n",
1004 | " 0.688840 | \n",
1005 | " 1.537453 | \n",
1006 | "
\n",
1007 | " \n",
1008 | " ... | \n",
1009 | " ... | \n",
1010 | " ... | \n",
1011 | " ... | \n",
1012 | "
\n",
1013 | " \n",
1014 | " 292 | \n",
1015 | " 0.030206 | \n",
1016 | " 0.197515 | \n",
1017 | " 0.227721 | \n",
1018 | "
\n",
1019 | " \n",
1020 | " 293 | \n",
1021 | " 0.028985 | \n",
1022 | " 0.140821 | \n",
1023 | " 0.169806 | \n",
1024 | "
\n",
1025 | " \n",
1026 | " 294 | \n",
1027 | " 0.028990 | \n",
1028 | " 0.081985 | \n",
1029 | " 0.110975 | \n",
1030 | "
\n",
1031 | " \n",
1032 | " 295 | \n",
1033 | " 0.028055 | \n",
1034 | " 0.166338 | \n",
1035 | " 0.194393 | \n",
1036 | "
\n",
1037 | " \n",
1038 | " 296 | \n",
1039 | " 0.028797 | \n",
1040 | " 0.197161 | \n",
1041 | " 0.225958 | \n",
1042 | "
\n",
1043 | " \n",
1044 | "
\n",
1045 | "
297 rows × 3 columns
\n",
1046 | "
"
1047 | ],
1048 | "text/plain": [
1049 | " train_aux_loss train_target_loss train_final_loss\n",
1050 | "0 0.895453 0.692025 1.587478\n",
1051 | "1 0.883613 0.691035 1.574647\n",
1052 | "2 0.871820 0.690196 1.562016\n",
1053 | "3 0.860334 0.689409 1.549743\n",
1054 | "4 0.848613 0.688840 1.537453\n",
1055 | ".. ... ... ...\n",
1056 | "292 0.030206 0.197515 0.227721\n",
1057 | "293 0.028985 0.140821 0.169806\n",
1058 | "294 0.028990 0.081985 0.110975\n",
1059 | "295 0.028055 0.166338 0.194393\n",
1060 | "296 0.028797 0.197161 0.225958\n",
1061 | "\n",
1062 | "[297 rows x 3 columns]"
1063 | ]
1064 | },
1065 | "execution_count": 57,
1066 | "metadata": {},
1067 | "output_type": "execute_result"
1068 | }
1069 | ],
1070 | "source": [
1071 | "train_loss_data = pd.read_csv(\"./loss/dien/train_loss.csv.2020_09_22_21_35_06\")\n",
1072 | "train_loss_data"
1073 | ]
1074 | },
1075 | {
1076 | "cell_type": "code",
1077 | "execution_count": 56,
1078 | "metadata": {},
1079 | "outputs": [
1080 | {
1081 | "data": {
1082 | "text/html": [
1083 | "\n",
1084 | "\n",
1097 | "
\n",
1098 | " \n",
1099 | " \n",
1100 | " | \n",
1101 | " test_aux_loss | \n",
1102 | " test_target_loss | \n",
1103 | " test_final_loss | \n",
1104 | "
\n",
1105 | " \n",
1106 | " \n",
1107 | " \n",
1108 | " 0 | \n",
1109 | " 0.895550 | \n",
1110 | " 0.692121 | \n",
1111 | " 1.587671 | \n",
1112 | "
\n",
1113 | " \n",
1114 | " 1 | \n",
1115 | " 0.883785 | \n",
1116 | " 0.691325 | \n",
1117 | " 1.575110 | \n",
1118 | "
\n",
1119 | " \n",
1120 | " 2 | \n",
1121 | " 0.872121 | \n",
1122 | " 0.690532 | \n",
1123 | " 1.562653 | \n",
1124 | "
\n",
1125 | " \n",
1126 | " 3 | \n",
1127 | " 0.860558 | \n",
1128 | " 0.689721 | \n",
1129 | " 1.550279 | \n",
1130 | "
\n",
1131 | " \n",
1132 | " 4 | \n",
1133 | " 0.849101 | \n",
1134 | " 0.688917 | \n",
1135 | " 1.538019 | \n",
1136 | "
\n",
1137 | " \n",
1138 | " ... | \n",
1139 | " ... | \n",
1140 | " ... | \n",
1141 | " ... | \n",
1142 | "
\n",
1143 | " \n",
1144 | " 292 | \n",
1145 | " 0.030182 | \n",
1146 | " 0.261107 | \n",
1147 | " 0.291289 | \n",
1148 | "
\n",
1149 | " \n",
1150 | " 293 | \n",
1151 | " 0.030074 | \n",
1152 | " 0.261199 | \n",
1153 | " 0.291273 | \n",
1154 | "
\n",
1155 | " \n",
1156 | " 294 | \n",
1157 | " 0.029966 | \n",
1158 | " 0.261354 | \n",
1159 | " 0.291320 | \n",
1160 | "
\n",
1161 | " \n",
1162 | " 295 | \n",
1163 | " 0.029859 | \n",
1164 | " 0.261639 | \n",
1165 | " 0.291498 | \n",
1166 | "
\n",
1167 | " \n",
1168 | " 296 | \n",
1169 | " 0.029752 | \n",
1170 | " 0.261937 | \n",
1171 | " 0.291690 | \n",
1172 | "
\n",
1173 | " \n",
1174 | "
\n",
1175 | "
297 rows × 3 columns
\n",
1176 | "
"
1177 | ],
1178 | "text/plain": [
1179 | " test_aux_loss test_target_loss test_final_loss\n",
1180 | "0 0.895550 0.692121 1.587671\n",
1181 | "1 0.883785 0.691325 1.575110\n",
1182 | "2 0.872121 0.690532 1.562653\n",
1183 | "3 0.860558 0.689721 1.550279\n",
1184 | "4 0.849101 0.688917 1.538019\n",
1185 | ".. ... ... ...\n",
1186 | "292 0.030182 0.261107 0.291289\n",
1187 | "293 0.030074 0.261199 0.291273\n",
1188 | "294 0.029966 0.261354 0.291320\n",
1189 | "295 0.029859 0.261639 0.291498\n",
1190 | "296 0.029752 0.261937 0.291690\n",
1191 | "\n",
1192 | "[297 rows x 3 columns]"
1193 | ]
1194 | },
1195 | "execution_count": 56,
1196 | "metadata": {},
1197 | "output_type": "execute_result"
1198 | }
1199 | ],
1200 | "source": [
1201 | "test_loss_data = pd.read_csv(\"./loss/dien/test_loss.csv.2020_09_22_21_35_06\")\n",
1202 | "test_loss_data"
1203 | ]
1204 | },
1205 | {
1206 | "cell_type": "code",
1207 | "execution_count": 58,
1208 | "metadata": {},
1209 | "outputs": [],
1210 | "source": [
1211 | "def get_loss_fig_aux(train_loss_data, test_loss_data):\n",
1212 | " train_loss = {\n",
1213 | " \"aux_loss\":list(train_loss_data[\"train_\" + \"aux_loss\"].values), \n",
1214 | " \"target_loss\":list(train_loss_data[\"train_\" + \"target_loss\"].values), \n",
1215 | " \"final_loss\":list(train_loss_data[\"train_\" + \"final_loss\"].values)\n",
1216 | " }\n",
1217 | " test_loss = {\n",
1218 | " \"aux_loss\":list(test_loss_data[\"test_\" + \"aux_loss\"].values), \n",
1219 | " \"target_loss\":list(test_loss_data[\"test_\" + \"target_loss\"].values), \n",
1220 | " \"final_loss\":list(test_loss_data[\"test_\" + \"final_loss\"].values)\n",
1221 | " }\n",
1222 | " get_loss_fig(train_loss, test_loss)"
1223 | ]
1224 | },
1225 | {
1226 | "cell_type": "code",
1227 | "execution_count": 59,
1228 | "metadata": {},
1229 | "outputs": [
1230 | {
1231 | "data": {
1232 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOydd3hUVfPHvycFAgRICKGDoddQJPQuvYMURYp0QRQVUeRFmghYEBAFeamCUqTqq6IgJRQFIUBC50eH0CG0EAIkOb8/vrncTbKp7GZT5vM8+9zdW86dbWfOmZkzo7TWEARBEDIvTo4WQBAEQXAsoggEQRAyOaIIBEEQMjmiCARBEDI5oggEQRAyOaIIBEEQMjmiCARBEDI5oggEIR6UUueVUs0cLYcg2BtRBIIgCJkcUQSCkEyUUoOUUqeVUiFKqf8ppQpF71dKqRlKqRtKqXtKqUNKqUrRx9oopY4ppR4opS4rpUY69l0IgokoAkFIBkqplwBMBdAdQEEAFwCsjD7cAkBDAGUAeAB4BcDt6GMLAbyhtc4JoBKArakotiAkiIujBRCEdEZPAIu01gcAQCk1GsAdpZQPgKcAcgIoB2Cv1vq4xXVPAVRQSgVpre8AuJOqUgtCAsiMQBCSRyFwFgAA0FqHgqP+wlrrrQC+BTAbwHWl1DylVK7oU7sAaAPgglJqu1KqTirLLQjxIopAEJLHFQAvGC+UUjkAeAG4DABa61la6+oAKoImog+i9+/TWncEkA/AzwBWpbLcghAvoggEIWFclVJuxgPswPsppaoqpbICmALgX631eaVUDaVULaWUK4CHAMIBRCqlsiileiqlcmutnwK4DyDSYe9IEGIhikAQEmYDgEcWjwYAxgJYC+AqgJIAXo0+NxeA+aD9/wJoMpoWfaw3gPNKqfsAhgDolUryC0KiKClMIwiCkLmRGYEgCEImRxSBIAhCJkcUgSAIQiZHFIEgCEImJ92tLM6bN6/28fFxtBiCIAjpiv3799/SWntbO5buFIGPjw8CAgIcLYYgCEK6Qil1Ib5jYhoSBEHI5IgiEARByOSIIhAEQcjkpDsfgSAIqcPTp08RHByM8PBwR4siJAM3NzcUKVIErq6uSb7GbopAKbUIQDsAN7TWleI5pzGAmQBcAdzSWjeylzyCICSP4OBg5MyZEz4+PlBKOVocIQlorXH79m0EBwejePHiSb7Onqah7wG0iu+gUsoDwBwAHbTWFQF0s6MsgiAkk/DwcHh5eYkSSEcopeDl5ZXsWZzdFIHWegeAkAROeQ3AOq31xejzb9hLFkEQUoYogfRHSr4zRzqLywDwVEr5K6X2K6X6xHeiUmqwUipAKRVw8+bNFN3s0iVg+HDg0aOUiisIgpAxcaQicAFQHUBbAC0BjFVKlbF2otZ6ntbaT2vt5+1tdWFcouzaBXzzDdCxY4rlFQRByJA4UhEEA/hTa/1Qa30LwA4AVex1s1dfBQoXBv76C/jtN3vdRRAEW3L37l3MmTMn2de1adMGd+/etYNEKcff3x/t2rVztBhWcaQi+AVAA6WUi1IqO4BaAI7b62ZKAX/+yW2PHsCTJ/a6kyAItiI+RRAZmXClzw0bNsDDw8NeYmU47KYIlFIrAOwGUFYpFayUGqCUGqKUGgIAWuvjAP4EcAjAXgALtNZH7CUPAFSqBAweDISGAq+8Ys87CUIGpHHjuA+jkw4Ls378++95/NatuMeSwEcffYQzZ86gatWqqFGjBpo0aYLXXnsNvr6+AIBOnTqhevXqqFixIubNm/fsOh8fH9y6dQvnz59H+fLlMWjQIFSsWBEtWrTAowQchfPnz0eNGjVQpUoVdOnSBWFhYQCAvn37Ys2aNc/Oc3d3BwCsX78ezZo1g9YaV69eRZkyZXDt2rVE31dISAg6deqEypUro3bt2jh06BAAYPv27ahatSqqVq2KatWq4cGDB7h
69SoaNmyIqlWrolKlSti5c2eSPrvkYM+ooR5a64Jaa1etdRGt9UKt9Vyt9VyLc77UWlfQWlfSWs+0lyyWzJkDeHsDv/wCHDuWGncUBCGlfPbZZyhZsiQCAwPx5ZdfYu/evZg8eTKORf95Fy1ahP379yMgIACzZs3C7du347Rx6tQpDBs2DEePHoWHhwfWrl0b7/1efvll7Nu3D0FBQShfvjwWLlyYoHydO3dGgQIFMHv2bAwaNAgTJ05EgQIFEn1f48ePR7Vq1XDo0CFMmTIFffowVmbatGmYPXs2AgMDsXPnTmTLlg3Lly9Hy5YtERgYiKCgIFStWjXR9pNLpltZ7OQEbNkCNGoEDBwI7NwJODs7WipBSAf4+8d/LHv2hI/nzZvw8SRSs2bNGAulZs2ahfXr1wMALl26hFOnTsHLyyvGNcWLF3/WeVavXh3nz5+Pt/0jR47g448/xt27dxEaGoqWLVsmKtM333yDSpUqoXbt2ujRo0eS3seuXbueKaSXXnoJt2/fxr1791CvXj2MGDECPXv2xMsvv4wiRYqgRo0a6N+/P54+fYpOnTrZRRFkylxDvr7A118Du3cDI0c6WhpBEJJKjhw5nj339/fH5s2bsXv3bgQFBaFatWpWF1JlzZr12XNnZ2dERETE237fvn3x7bff4vDhwxg/fvyz9lxcXBAVFQWAq3efWDgZL1++DCcnJ1y/fv3ZOYmhtY6zTymFjz76CAsWLMCjR49Qu3ZtnDhxAg0bNsSOHTtQuHBh9O7dG0uXLk3SPZJDplQEANCrF1CoEDBzJvDPP46WRhAEa+TMmRMPHjyweuzevXvw9PRE9uzZceLECezZs+e57/fgwQMULFgQT58+xbJly57t9/Hxwf79+wEAv/zyC54+fQoAiIiIQL9+/bB8+XKUL18e06dPT9J9GjZs+Kx9f39/5M2bF7ly5cKZM2fg6+uLUaNGwc/PDydOnMCFCxeQL18+DBo0CAMGDMCBAwee+33GJtOZhgyUAlasoImoQwfg+nUxEQlCWsPLywv16tVDpUqVkC1bNuTPn//ZsVatWmHu3LmoXLkyypYti9q1az/3/SZNmoRatWrhhRdegK+v7zMlNGjQIHTs2BE1a9ZE06ZNn81MpkyZggYNGqBBgwbPHNpt27ZF+fLlE7zPhAkT0K9fP1SuXBnZs2fHkiVLAAAzZ87Etm3b4OzsjAoVKqB169ZYuXIlvvzyS7i6usLd3d0uMwJlbYqSlvHz89O2rFDWvTuwejXQuzdgh89XENItx48fT7RDE9Im1r47pdR+rbWftfMzrWnIYPlywMMD+OEHrj4WBEHIbGR6ReDiwlBSV1fgzTeBaNOfIAgZmGHDhj2L1zceixcvtknbGzdujNN2586dbdK2vci0PgJLGjakv6BrV+DTT4GJEx0tkSAI9mT27Nl2a7tly5ZJCjtNS2T6GYFBly70F0yaBKxb52hpBEEQUg9RBBaMHcttr15MQyEIgpAZEEVgQaVKwAcfsGZBmzaOlkYQBCF1EEUQi88+A3x8mHrCRr4jQRCENI0oglgoxZoFTk6MIgpJqNimIAh2JbXrEZw4ceJZ5s8zZ86gbt26yW7DIHbG0tg0btwYtlwT9TyIIrBCqVLAf/8LREYybXU6W3MnCBmG1K5H8PPPP6Njx444ePAgSpYsiX8ySf4ZUQTxMHAgMHkysHYtQ0oFIbPjgHIEqVqPYMOGDZg5cyYWLFiAJk2aADDrDvj7+6Nx48bo2rUrypUrh549ez5LHPfJJ5+gRo0aqFSpEgYPHmw1oVxirFixAr6+vqhUqRJGjRoFgMqub9++qFSpEnx9fTFjxgwAzLhaoUIFVK5cGa+++mqy72UNWUeQAO+/z6R048YBTZoA9es7WiJByFx89tlnOHLkCAIDA+Hv74+2bdviyJEjz1JRL1q0CHny5MGjR49Qo0YNdOnSJU4a6lOnTmHFihWYP38+unfvjrVr16JXr15x7tWmTRs
MGTIE7u7uGGklLfHBgwdx9OhRFCpUCPXq1cPff/+N+vXr46233sK4ceMAAL1798Zvv/2G9u3bJ/k9XrlyBaNGjcL+/fvh6emJFi1a4Oeff0bRokVx+fJlHDnCel2Gqeuzzz7DuXPnkDVrVpuV47SbIlBKLQLQDsANrXWlBM6rAWAPgFe01vEb1ByAkxNHNC1aAK1bA1evAtEDBEHIdKSBcgR2r0eQ2L2LFCkCAKhatSrOnz+P+vXrY9u2bfjiiy8QFhaGkJAQVKxYMVmKYN++fWjcuDG8vb0BAD179sSOHTswduxYnD17Fm+//Tbatm2LFi1aAAAqV66Mnj17olOnTujUqVOK3kts7Gka+h5Aq4ROUEo5A/gcwEY7yvFcNG8OvP021xUkdTorCIJ9sHc9goSw1k54eDjefPNNrFmzBocPH8agQYOsypAQ8ZmSPD09ERQUhMaNG2P27NkYOHAgAOD333/HsGHDsH//flSvXj3F78cSe5aq3AEgsZibtwGsBXDDXnLYgq+/BipUAPbvB6ZOdbQ0gpB5SO16BMnF6PTz5s2L0NDQBKOE4qNWrVrYvn07bt26hcjISKxYsQKNGjXCrVu3EBUVhS5dumDSpEk4cOAAoqKicOnSJTRp0gRffPHFs0pqz4vDfARKqcIAOgN4CUCNRM4dDGAwABQrVsz+wsW5P6e1pUpREfTsCThADEHIdKR2PYLk4uHhgUGDBsHX1xc+Pj6oUSPBrswqBQsWxNSpU9GkSRNordGmTRt07NgRQUFB6Nev37OqZ1OnTkVkZCR69eqFe/fuQWuN9957L0XRUbGxaz0CpZQPgN+s+QiUUqsBfKW13qOU+j76vETVqa3rESSHM2eAatWAypWpGFzE1S5kYKQeQfolPdUj8AOwUil1HkBXAHOUUrbxfNiJkiUZRfT330CfPo6WRhAEwTY4TBForYtrrX201j4A1gB4U2v9s6PkSSqvvcZCNitWmDHSgiCkL+xZjyA2nTt3jnOvjRvTVnyMPcNHVwBoDCCvUioYwHgArgCgtZ5rr/vaGzc3YPt2mogGDADKlQMcYJoUBOE5sGc9gtgY4a1pGbspAq11j2Sc29dectiDypU5I3jlFaBZM66adHNztFSCIAgpQ1JMpJDu3bm+4OFDpqMQBEFIr0jcy3MwcyYT082ZA7RrB9go7YcgCEKqIjOC58DJicqgbl2gb1/g8GFHSyQIgpB8RBE8J66uwKBBwOPHQKNGUuJSEGxJSusRAMDMmTMRFhZmY4mSzoQJEzBt2jSH3T85iCKwAX370nF85w7QsqXULxAEW5GeFUF6QnwENuLHH4E9e4B//gE+/pi1DAQhw/Duu0BgoG3brFqVttUEsKxH0Lx5c+TLlw+rVq3C48eP0blzZ0ycOBEPHz5E9+7dERwcjMjISIwdOxbXr1/HlStX0KRJE+TNmxfbtm2z2v7QoUOxb98+PHr0CF27dsXEiRMBsJ5BQEAA8ubNi4CAAIwcORL+/v4YPnw48ubNi3HjxmHjxo2YPHky/P394eSU8Jg6MDAQQ4YMQVhYGEqWLIlFixbB09MTs2bNwty5c+Hi4oIKFSpg5cqV2L59O9555x0AgFIKO3bsQM6cOVPwAScdUQQ2wsWFaSfKlAGmTGH9gmbNHC2VIKRvLOsRbNq0CWvWrMHevXuhtUaHDh2wY8cO3Lx5E4UKFcLvv/8OgMnocufOjenTp2Pbtm3ImzdvvO1PnjwZefLkQWRkJJo2bYpDhw6hcuXKCcpTo0YNNGjQAMOHD8eGDRsSVQIA0KdPH3zzzTdo1KgRxo0bh4kTJ2LmzJlWawtMmzYNs2fPRr169RAaGgq3VIhNF0VgQ3x8uL5g5EhGEAUEcJ8gpHsSGbmnBps2bcKmTZtQrVo1AEBoaChOnTqFBg0aYOTIkRg1ahTatWuHBg0aJLnNVatWYd68eYiIiMDVq1dx7NixBBVB9uzZMX/+fDRs2BAzZsxAyZIlE73HvXv
3cPfuXTRq1AgA8Prrr6Nbt24ArNcWqFevHkaMGIGePXvi5ZdfflYDwZ6Ij8DGdOkCbNoEREQArVqxhJ8gCM+P1hqjR49GYGAgAgMDcfr0aQwYMABlypTB/v374evri9GjR+OTTz5JUnvnzp3DtGnTsGXLFhw6dAht27Z9llbaxcXlWdbP2PUFDh8+DC8vL1y5cuW535O12gIfffQRFixYgEePHqF27do4ceLEc98nMUQR2IHSpbnY7ORJoHNncR4LQkqxrEfQsmVLLFq06Fn+/cuXL+PGjRu4cuUKsmfPjl69emHkyJE4cOBAnGutcf/+feTIkQO5c+fG9evX8ccffzw75uPjg/379wMA1q5d+2z/hQsX8NVXX+HgwYP4448/8O+//yb6HnLnzg1PT0/s3LkTAPDDDz+gUaNG8dYWOHPmDHx9fTFq1Cj4+fmliiIQ05CdGDmSs+lNm4Dx44EkDlIEQbDAsh5B69at8dprr6FOnToAWFj+xx9/xOnTp/HBBx/AyckJrq6u+O677wAAgwcPRuvWrVGwYEGrzuIqVaqgWrVqqFixIkqUKIF69eo9OzZ+/HgMGDAAU6ZMQa1atQBwRjJgwABMmzYNhQoVwsKFC9G3b1/s27cvUTv+kiVLnjmLS5QogcWLF8dbW2Ds2LHYtm0bnJ2dUaFCBbRu3dpWH2e82LUegT1wZD2C5LJhA9ChA1cfr18P2Ki8qCCkClKPIP2SnuoRZHjatAH++ovPe/QATp92rDyCIAjWENOQnWnSBHjzTWDpUqB9e2D3btYzEAQh9ahVqxYeP34cY98PP/wAX1/f52578uTJWL16dYx93bp1w5gxY5677dRCTEOpxPbtXFfQuDHwxx9S5lJI+4hpKP0ipqE0SqNGNA9t3gwMG+ZoaQRBEEzspgiUUouUUjeUUkfiOd5TKXUo+vGPUqqKvWRJKwwfzoyl8+YB0YENgiAIDseeM4LvAbRK4Pg5AI201pUBTAIwz46ypAn8/IAvvuDzYcM4OxAEQXA0dlMEWusdAEISOP6P1vpO9Ms9AOy/jjoNMGIEMGoUF5l16ACcOuVoiQRByOykFR/BAAB/xHdQKTVYKRWglAq4efNmKople5QCpk7limMnJ0YSReeaEgQhFvZOQ7169WqUL18eTZo0QUBAAIYPH56iewFcjXzr1q14j7u7u6e4bXvjcEWglGoCKoJR8Z2jtZ6ntfbTWvt5e3unnnB2Qilg7VouODt7lvWPIyIcLZUgpD3srQgWLlyIOXPmYNu2bfDz88OsWbNSdK/0jkODGJVSlQEsANBaa33bkbKkNkqxxGW/fnQev/su8O23jpZKEKzjoHIEdq1H8Mknn2DXrl04d+4cOnTogLZt22LatGn47bffMGHCBFy8eBFnz57FxYsX8e677z6bLXTq1AmXLl1CeHg43nnnHQwePDhZ71trjQ8//BB//PEHlFL4+OOP8corr+Dq1at45ZVXcP/+fUREROC7775D3bp1MWDAAAQEBEAphf79++O9995L1v2SgsMUgVKqGIB1AHprrf/PUXI4khs3gOi8Vpg9m8nqoutRCIIA+9YjGDduHLZu3Ypp06bBz88P/v7+MY6fOHEC27Ztw4MHD1C2bFkMHToUrq6uWLRoEfLkyYNHjx6hRo0a6NKlC7y8vJL8ntatW4fAwEAEBQXh1q1bqFGjBho2bIjly5ejZcuWGDNmDCIjIxEWFobAwEBcvnwZR44w+PKunezIdlMESqkVABoDyKuUCgYwHoArAGit5wIYB8ALwBylFABExLfYIaNSqBCwbx8jiT76CHjvPaBsWaavFoS0RBooR2CXegQJ0bZtW2TNmhVZs2ZFvnz5cP36dRQpUgSzZs3C+vXrAQCXLl3CqVOnkqUIdu3ahR49esDZ2Rn58+dHo0aNsG/fPtSoUQP9+/fH06dP0alTJ1StWhUlSpTA2bNn8fbbb6Nt27Zo0aKFTd5bbOwZNdRDa11Qa+2qtS6
itV6otZ4brQSgtR6otfbUWleNfmQqJWCgFKOIvvySkUSdOwPHjztaKkFIe9i6HkFiZM2a9dlzZ2dnREREwN/fH5s3b8bu3bsRFBSEatWqxalXkJT3YY2GDRtix44dKFy4MHr37o2lS5fC09MTQUFBaNy4MWbPno2BAwc+13uKD4c7iwUyciRXHmfNCrRrByQQfCAImQZ71iNICffu3YOnpyeyZ8+OEydOYM+ePcluo2HDhvjpp58QGRmJmzdvYseOHahZsyYuXLiAfPnyYdCgQRgwYAAOHDiAW7duISoqCl26dMGkSZOevTdbIxlv0hDLlwN79jAf0csvM3OpxaBEEDId9qxHkBJatWqFuXPnonLlyihbtixq166d7DY6d+6M3bt3o0qVKlBK4YsvvkCBAgWwZMkSfPnll3B1dYW7uzuWLl2Ky5cvo1+/fs+qpU2dOtUm7yM2knQuDdK3L7BkCfD668DixTQfCUJqI0nn0i+SdC4DULQot0uWmCkpBEEQ7IWYhtIgo0YBv/0GBAUxmqhMGTqRBUFIGfasR2DJ7du30bRp0zj7t2zZkqzIotRGFEEaxN0d2LEDePFFIDgY6NUL2LmTrwUhNdFaQ2UA22RSiszbAi8vLwTaeuVdMkmJuV9MQ2mUnDmB6dOBp09Z0ax9e+DyZUdLJWQm3NzccPv27RR1LIJj0Frj9u3bcHNzS9Z1MiNIw7RrB2zbBuTODdSrB3TsyJlC9uyOlkzIDBQpUgTBwcFI74keMxtubm4oUiR5yZxFEaRhlAIaNACiorjGYMECoE8fYNUqZi4VBHvi6uqK4sWLO1oMIRWQ7iQdoBRw5gyQJQuzlo4d62iJBEHISIgiSAcoBcyfDzg703cwZQqwdKmjpRIEIaMgiiCdUKIEsGUL4OpKH8GgQcCuXY6WShCEjIAognRE7drAypWMJPL25tqCs2cdLZUgCOkdUQTpjObNgYsXGU0UGcnIonv3HC2VIAjpGVEE6ZACBVjEZtw44NQp4JVXpNSlIAgpRxRBOuXAARaycXYGNm4ERoxwtESCIKRXRBGkU158Efj0U+Cll/j6m2+YpE4QBCG52E0RKKUWKaVuKKWOxHNcKaVmKaVOK6UOKaUkk04yGTOGyel69+brgQM5UxAEQUgO9pwRfA8goeq7rQGUjn4MBvCdHWXJsDg5cSawdCl9B507S3UzQRCShz1rFu8AEJLAKR0BLNVkDwAPpVRBe8mTkVGKs4L164Hr14Fu3cR5LAhC0nGkj6AwgEsWr4Oj98VBKTVYKRWglAqQBFjxU6ECE9T5+wP/+Y+jpREEIb3gSEVgLcm51Xy3Wut5Wms/rbWft7e3ncVKv2TPDrRqBbi4AF9+yeR0giAIieFIRRAMoKjF6yIArjhIlgzDW2/RLFSiBDBgANcZCIIgJIQjFcH/APSJjh6qDeCe1vqqA+XJEPj5AdWrA7dv05HcvTsQHu5oqQRBSMvYM3x0BYDdAMoqpYKVUgOUUkOUUkOiT9kA4CyA0wDmA3jTXrJkJpRiPqICBZibKDAQ+OADR0slCEJaxm6FabTWPRI5rgEMs9f9MzOlSgGHDzNT6YgRwIwZXHjWubOjJRMEIS0iK4szKK6u3HbqBFSrBvTvD5w/71CRBEFIo4giyMDcvAm0bk2H8YMHLHcp6wsEQYiNKIIMjLc38OOPdCBHRgJ79gBffOFoqQRBSGuIIsjgdO7MfEQ5cwIvvABMmAAEBTlaKkEQ0hKiCDIBOXIwBUW1akCePHz++LGjpRIEIa2QuRRBVJSjJXAYs2YxF9GCBYwomjjR0RIJgpBWyDyKYOtWoEwZ4IjVrNgZHmdnbnPlAmrVAj7/XFJWC4JAMo8iyJ0bOHMGaNo0U88Mli0D/v0X8PAA3niDTmRBEDI3mUcRVK8OdOkC3LhhVnLJhHz1FVCoEPViQAAwZ46jJRIEwdFkHkUAAD/9BHh5AcuXA3/+6WhpHIK7O8tanjvH8NIxY4DLlx0
tlSAIjiRzKQJnZ2DTJibk6dIFCAtztEQO4eWXgcmTueDsyRNg+HBHSyQIgiPJXIoAYNX399+nEhg3ztHSOIxBg7i+4D//AdatAzZvdrREgiA4CsXcb4mcpNQ7ABYDeABgAYBqAD7SWm+yr3hx8fPz0wEBAc/f0BtvAPPnA9u2AY0aPX976ZTHj4Hy5bng7MABM7pIEISMhVJqv9baz9qxpM4I+mut7wNoAcAbQD8An9lIPscwfTpQtCiT8WRSI/mFC8CUKcCHHwKHDgHffutoiQRBcARJVQRGWck2ABZrrYNgvdRk+iFHDuCjj4BHj4DGjR0tjUO4cgX45BPgUnTl6AkTgIcPHSqSIAgOIKmKYL9SahOoCDYqpXICSP/B+EOHAvXqAadPM3F/JqNmTZqEpkzh67t3gWnTHCuTIAipT1IVwQAAHwGoobUOA+AKmocSRCnVSil1Uil1Win1kZXjuZVSvyqlgpRSR5VSibZpczZuZEzljBnA9u2pfntH4uwMeHqar/PkYXbSa9ccJ5MgCKlPUhVBHQAntdZ3lVK9AHwM4F5CFyilnAHMBtAaQAUAPZRSFWKdNgzAMa11FQCNAXyllMqSDPmfnxw5gA0b+LxrV5qKMhE//ADUrQsMGQKEhtJ5LKmqBSFzkVRF8B2AMKVUFQAfArgAYGki19QEcFprfVZr/QTASgAdY52jAeRUSikA7gBCAKR+6ZQGDYCvvwZu3WJoaSaiYUPg77+Btm25pqBJE2DuXJkVCEJmIqmKICK6xnBHAF9rrb8GkDORawoDuGTxOjh6nyXfAigP4AqAwwDe0VrH8T0opQYrpQKUUgE3b95MosjJZPhwVnn/7jtmZMtktG0L7NrFlBOPHwNffuloiQRBSC2SqggeKKVGA+gN4Pdos49rItdYiyqKvWihJYBAAIUAVAXwrVIqV5yLtJ6ntfbTWvt5e3snUeQUMHkyk/CMHg3s3Wu/+6RBlKLfvHRpoFcv6sPr1x0tlSAIqUFSFcErAB6D6wmugSP7xMaMwQCKWrwuAo78LekHYJ0mpwGcA1AuiTLZHldXYMUKQGugeXMgPNxhojiKLl3oRA4PB2bPdrQ0giCkBklSBNGd/zIAuZVS7QCEa60T8xHsA1BaKTRY4gYAACAASURBVFU82gH8KoD/xTrnIoCmAKCUyg+gLICzyZDf9rRuzVXH9+8zZXUm48ED4OBBoH17mokyme9cEDIlSVIESqnuAPYC6AagO4B/lVJdE7pGax0B4C0AGwEcB7BKa31UKTVEKTUk+rRJAOoqpQ4D2AJglNb6Vsreig357jvaSP75J9OV8qpVi6uMhw4Fbt9mVJEgCBmbpOYaCgLQXGt9I/q1N4DN0WGfqYrNcg0lxq1bQKVKHCLv2QP4+tr/nmmA338H2rUD/P25xi4sDDh6FHDKfOkJBSFDYYtcQ06GEojmdjKuTZ/kzQsEBtJ5/PLLHB5nAmrW5PbffxlJe+IE19wJgpBxSWpn/qdSaqNSqq9Sqi+A3wFssJ9YaYQCBYAff2SJyzp1MkWJS29v4PXXgRde4Po6b29g3jxHSyUIgj1JkmkIAJRSXQDUA8NCd2it19tTsPhINdOQJXXrArt3M64ykxjNHz9mDZ/t24GZM5mYrmBBR0slCEJKsYVpCFrrtVrrEVrr9xylBBzGli1MyvPjj8D33ztamlRh9GigQwcWrImMBBYvdrREgiDYiwQVgVLqgVLqvpXHA6XU/dQS0uFkywbs3MkA+4EDGVaTgbl2jYXtx4wBgoKAUqVYwycTWMYEIVOSoCLQWufUWuey8siptY6zAjhDU7EisGABn3fuzJzNGZQCBYAdO4BPPwVq1ACePgXOn5dyloKQUcnYkT+2pm9fGs4vXQJefRWISP38eKlN796AlxeDp5Ytc7Q0giDYA1EEyeWll4BvvmFMZcfYyVQzHsOGAfv3M/XEunWy0lgQMiKiCFLC668zhGbDBhrSMzDGQrKePVmv4Lf
fHCuPIAi2RxRBSnBz4zA5WzbWeVy92tES2ZV+/Zh1o2BBMQ8JQkZEFEFKKViQYaVOTkCPHsCBA46WyG48fcqVxj16cBJ0546jJRIEwZaIInge6tRhJFFUFIPub9xI/Jp0SPny9I936kSlsHatoyUSBMGWiCJ4Xvr1Y1xlSAiztT186GiJbE758ty6uQElSwJr1jhWHkEQbIsoAlvw0kssaBMQANSvz6W4GQgj8WpAAPMPbdlCvScIQsZAFIGtaNcOKFOGGUtbtmSVswxCqVJMs1S4MNCtG5dP/PKLo6USBMFWiCKwFc7OHDIXK8Yh85tvOloim6EUc+116AC8+CLg4yPmIUHISNhVESilWimlTiqlTiulPornnMZKqUCl1FGl1HZ7ymN33N05I3B3B+bOBSZNcrRENuXmTSA4mOahv/7K0Fk2BCFTYTdFoJRyBjAbQGsAFQD0UEpViHWOB4A5ADporSuCpTDTN56eTFmdNSswfnyGGTprDZQrx8qd3boxeuh/sStQC4KQLrHnjKAmgNNa67Na6ycAVgKInZPhNQDrtNYXASBWFbT0S6VKHDrXrQu89hrw55+Olui5UYrRsn//zUR0xYpl+HV0gpBpsKciKAzgksXr4Oh9lpQB4KmU8ldK7VdK9bHWkFJqsFIqQCkVcPPmTTuJa2Py5mU+hmLFgLZtmawunVOvHktXhoTQPLRpE3DvnqOlEgThebGnIlBW9sUOpXEBUB1AWwAtAYxVSpWJc5HW87TWflprP29vb9tLai88PJiCIiqKyiC1K6vZmPr1uW3QAChbFnjyBPj1V8fKJAjC82NPRRAMoKjF6yIArlg550+t9UOt9S0AOwBUsaNMqU/37sDnnzPmsl49VnpJp9SqxbfTsyezkRYunGFcIIKQqbGnItgHoLRSqrhSKguAVwHEdi/+AqCBUspFKZUdQC0Ax+0ok2P48ENg7FgOoWvVSrcVzrJkAX76iQlXvbxoHvrzT+DBA0dLJgjC82A3RaC1jgDwFoCNYOe+Smt9VCk1RCk1JPqc4wD+BHAIwF4AC7TWR+wlk0P55BMqA6W44OzkSUdLlGK2bgWWLKEiePxYUlMLQnpH6XS2AtbPz08HpGdb+9GjQJMmzFq6eTMjjNIZ/fqZhdqKFGE0kSSiE4S0jVJqv9baz9oxWVmc2lSsyFzON28CtWsDp045WqJkU7w4cOUKLV1duvDthIY6WipBEFKKKAJH4OfHKmcPHwI1a6Y7M1Hx4txeuMDFZeHhVAaCIKRPRBE4ivnzGY959y4VQ2CgoyVKMoYiOHeOgVD588viMkFIz4gicBTOzvSyNmlCu0rdukxNkQ7w8eH2/Hm+jZdf5owgA5ZiEIRMgSgCR5I7NzOVfv45S182b87XaZxChYAzZ4DBg/m6WzcgLCxDZNIQhEyJKAJHoxTXGfz9N9NRtGyZ5kNwnJyAEiW4BbjS2NtbzEOCkF4RRZBWKFCA6wwiIxmgP3euoyVKkL/+AkaM4HMXF5qHfvsNePTIsXIJgpB8RBGkJXr0oBNZKWDoUGD06DRb6SwoCJgxwyxZ2bUrfQRiHhKE9IcogrTGwIHM5ObsDHz2GdC7d5qsgVyxIrdHj3LbuDETrv70k8NEEgQhhYgiSIu0bUufQf78wLJlHG6nMZuLsSB68WJGwLq40Gn8v/9J7iFBSG+IIkir1KoFXL0KzJrFSvENGgC3bztaqmcUKcKaO4sXA2+9xX09e1JfrV+fsjZXrgRy5GD+IkEQUg9RBGkZpYC336Yndv9+oEoVVoZJAyjFycqBA8ynB3AphI8P96eE06cZhnr1qs3EFAQhCYgiSA988QWT/1++DFStmqbyOVSrxlDSkBCK17Mnc+ldu5b8tipX5vbWLdvKKAhCwogiSA+UKAEcO2bmfW7XDpg509FSPSMigj6Drl35iIpKmdP47Flub2SMytWCkG4QRZBeyJWLK7ZmzOBKrvfeA4YNYy/sYFxcgG+/Bf79F/j
5Z84Sfvwx+e389Re3oggEIXURRZDeePdd+gk++ACYMwdo2hS4c8fRUuHll4HSpSlar14sz/x//5e8NnLl4jYqyvbyCYIQP6II0iOlStFv0K8fsGMH4OtrBvQ7EC8v+gpefZXO5OTOCpycgJIlgf797SOfIAjWsasiUEq1UkqdVEqdVkp9lMB5NZRSkUqprvaUJ8PxySfMCX35MlC9usOT/eTJQ0VQqBDQujUnLMlZU/DoEZAtm/3kEwTBOnZTBEopZwCzAbQGUAFAD6VUhXjO+xysbSwkhyJFGL9Zty6dyN27M6jfQX6D5cuBXbv4fMIELnuYMSPp14eHA0eO0PolCELqYc8ZQU0Ap7XWZ7XWTwCsBNDRynlvA1gLQFyEKcHDA/D3B4YPp9d29mwOxx0Qg5k7N+Dmxuc1anCl8eTJzEuUFObNY+oKf3+7iZiu2LwZmDLF0VIImQF7KoLCAC5ZvA6O3vcMpVRhAJ0BJJhqUyk1WCkVoJQKuHnzps0FTfe4ugJff02n8cKF9BtUrcrZQiqyfTv1kTEhmTOHfoPOnWm9SowiRYA6dYDr1+0rZ3qhb19gzBhHSyFkBuypCJSVfbFTac4EMEprnWBWNa31PK21n9baz9vb22YCZjjc3elpnT2bPW/t2sDSpal2+8OHgW++MYOY8uZlOOnNm8yQkdjMYOlSRhvdvCmRQwCjr7JkcbQUQmbAnoogGEBRi9dFAFyJdY4fgJVKqfMAugKYo5TqZEeZMgfdu3NoHREBvP461xs8eWL32+bJw62RmhoAatakiSM8nKWZ33sPmDiRCiI2U6awdHNkZOpExC5dCnxkEcIQHp4qH1OScXKiPGkw+ayQ0dBa2+UBwAXAWQDFAWQBEASgYgLnfw+ga2LtVq9eXQtJ4NYtrYsW1drFRWtA6xdf1Do42K63/OMP3uqff+Ieu3FD6379tHZ25jmA1pMmab1vn9YPHvCcYsW09vLSunp1rS9etKuoWmtTDq217t6dz9ets+89e/fWevr0pJ1ryPff/2pdrpzWkZEJnx8RofWSJdxmRiIitI6KcrQUaRcAATqeftVuMwKtdQSAt8BooOMAVmmtjyqlhiilhtjrvkI0Xl7A1q0sduPqyhQVL75IQ76dsDYjMPD2Zp2dVavMfWPH0qmcMydTKV29yjaaNmUG07VruVr56lX7m4qCg7ndutW+9wkKAjZtSt41b7zBhXqBgQmfN2sWJ4CpaA1MUxQsyFnU/fuOliT94WLPxrXWGwBsiLXPqmNYa93XnrJkSkqVYq8wZw5w4QKX/zZtyvUHo0dz1ZcN8fJikx07Ai+8AIwcyUJr9+5xbFumTMzzV6/m+SdP8rFsGR3FM2fGNdG4ugJFi7Kss7VH0aJ0kaSEiAiahQBgy5aUtZFUSpYEjh9P2rnZstGqd+IEy4D++Sd1eXw0aMBtzpzPL2dCvPcev6OdOxnptWSJzX9KKcKII3n40FylDlAxfPstMGoU6z2lhCtX+PuybDcjYVdFIKQR3N0ZlzlrFtCqFUNR9uxh4H9Ke08rlCgBXLzITrlKFUaxAlQM1mzvUVF0ZxgsX86CbLt3M2KmYEGGoAJ0nD5+TH32yy9AaGhc27mnZ/yKolgxtmetI3jwwBxFHj/OP32hQsl77zduAPnyJXzOnj1mrYaoKI5e4yM8nAvsPD3px9iwIfH2s2fn1t4+hfPnuV24EPjhByqCtETsGk4ffECF5eLCbO6jRzOoLjkULswU659+yu+kTRubiZsmEEWQmWjenEPeVq1YDrNMGQbtxx6qpxClgN9/5/NPPuEf5+TJ+K1RscNEL1wAsmZlh3fjBpdIGHTtyj/f5s18G0uXsjzmxYvm48IF4NIl4Nw5jlbv3o3ZvosL/9CGYjD49Veas6pVAw4eBLZtYzptg/BwLsvIn58zk9icPcuRvr8/0KhR/J+PpWnn8mUqzPgwnOVjxjBnU1I690WLuA0L43bLFqBhw5gyh4ZSYSSkhBK
jeXM6+0NC+JmkhdmAJcb7NzBMlffu0TTZvXvyFQFABfjVV/yNpNFS4ilGcg1lJpycgJde4tDb25vG96pVWV/SRgyJ9v7kzMlQ0u3bzZHq66+zgw8JAfr0YZI6SwoVonnJ25tKondvFqsBzGm/gY8PO9J69egG+ftvdkghIdRrRkfavDlH03PnAh9+yPOV4vkGr7/O6w4e5IK4Tz5hpu8332TZ6GbNeK9VqzhjuHYtZhU1o/ZCYpXVjDTbAHDmTMLnZs3KInUATUInT3IWYekr8fcH3n+fnTtgFgR68oS+jmbNgM8/N89/8oTfy/vvJ3xvS86dY1hwYCC/z8ePgb17eWzHDn5Px44lvb3UIPaMwBjnGL+lc+eS156lEu7bN2ZbGQWZEWRGunalv2DIENphOnbk0HPixJQbUaN54w2gbFl2mMOHUxHcusUpddeu7OSBuOaE0FBg2jSgfXtzRgAABQoAr7zCxWaAmbvIMrxUa4q9ciU78A0buK9aNe43TFQjRtBZe+QIjz96xE78xo2YMwvjsXt3TMd3r14xZXZzizm6HjqUiixbNvORJQtnIi4uplJs04ZK5ddfzWPWHvXq0Vk+fDjbd3ICvvuOujtXLs6+pk+nzT5HDn6Gw4fzO5g9m9dcvGjKa3SQq1cDhw4BAwdSiSZE2bLA06dUqH/9xe/Q+O6MGddvv1HJvfSSqfRtycGDnNUYdbLjY/BgmoBiy/Dxx1xe88MPfJ3cNOfOznzP58/TLATwu3vvveS1k6aJL5worT4kfNTGPHrEuE5A6+bNGXZqA86dY5Pz5lk/HhXFWxtcvMjz58/XumlTrWvX1rpnT62//jrmdUuXmqIaNGyodZ48ZrgloPXRo1p36aJ12bLmedOn89iFC3x9+rTWr7yidUCAeU5wsNZ//cVQzY4dY7Y5darWy5drPXu21p9+qvXIkVoPG6Z1/fo8XqyY1i1bat2okdY1a2pdqJDWJUtqXaqU1j4+Wru6ap01q9Z582rt4aG1u7vWbm5mhG9KH0ppnTs3nxcsqHXduqZMVapo/f77Wk+cqPWMGVpny6Z1kyamvP/+q/Xhw1qfOaP1lSta372r9ZMn5udheZ+ePbXeuZPPq1Uz9xvPjx2z/l1fvszv7caNxH41MfnrLz4sw3yfhylT2M4776Ts+uPHTVnat0/6dTdu8LONL7T16FG2uXdvyuRKKkggfFRmBJkdNzcO8xYvpgG+WjUONX19n6tZwwY/eDDNQFmzxjzeuTNDNgMC+NoYrWbLBtSvT6vV0qWcEQCmc9WYERjnh4TQRNGvH9+CpydnCxMmMPzUuPbqVfP1X38BHTpwROzvz7V3H3zAkeOhQxzphYTQ32BJ+fKcPMVmwQIm22vWjA5UgKNlT0+aJU6e5L4mTfjeXnyRZppXXjHb0JpyRkTwsXQpTVMGlStTtrFjaTK6f58J/fbtY7t583KWcfcubeHZs/MzP3eOZrGHD822tm3j9uJF0/wUG2dntuHiQtkiI4F//qGZCuD9jFnCwYPc99VXnEFkycJH1qzcHjzImcvQoZwNOTmZD6VivjYeUVFs35LDhymXkxO3sWdQoaGMfciRgzMI49z+/fnbaNaM7SQ38ufYMU6Yq1Tha6WYUDGpfPklzWqxfRcGv/zC7erVDKd2BKIIBP4716+nF+3aNfZWa9bE/ScmA0tnZGwlAHC9gKEEADN8082NnfiVK5zmlyhBB2yuXJyON27M854+5dawtXfsSMdo06bs2C0zcrdpw47B8AsY7RrJ7c6eZef49tsx10JkyUJzh7G2IL6U2oYslu/ZiEIyImwAswNu1YrmMktFoBQ7Lmdnfl6Gv8HFhYqhYkUqgjx5gLZteWzNGiqCKlXYya1fTwXZpAmd3YULmw7piAh24u3b08FrOOpXruQ9w8LoKwFoxgsLY1sLFpjvu3hx87MIDuZ35uxs2tANJRgf333HR0oxalonBycn07G7eTM/5//+l7K6uPB3kZB
5zsWFA4sDB8zV8PXr83vo1y/uudba++UXflZTpzIAIva5RhDB/fvAH39YlyEykoOHwoUZmGBrRBEIpFMnFhru0oX/nDZtOJQcMCDFTZ45E3+uHKMzMkb6ljMCwExnXaIElYNh161QgSGlhw/z9YUL3Pr4mKP1Tz+lw/fcOXZeGy0SnE+ezNHdwIHmPmMxWa5cZoRNSAhw6pQ5s/H0tB7DHxBAV8uHH8aspWAoByMbqyUlS9L2nxCGDyQ8nCPcwoUpm2WklaE8r13j1/X4MX0pYWH8Otu2NTtnFxfTpv/tt2ZYbsmSTP0BMOpo+3bOjozvrWZN+lauXmUU0p49dBa/8w6d/Z6eZqLbhQs5lnjyJOZjyRJ+J82bM42I4fQ2HsZsKCoKaNmSE9TatXkPgDONoCDOrCpXpsxXrvD7MGZQERHMwA7wJ1ytWszZleEryZ6d7rGnT2NeG98jPDxu9NnVq4z6CglhB50tG8+1bNP4/i1JLIHgf//LR0KMGsUABlsjikAw6dyZQ7ZRozisHjiQPc9//pOi5kqUiP9YgQL8w9y5Qwer5Yxg1y5ztFyuHJ3Hp07x9Zkz/CMaisEYcb/wgtl23740k3h6srZBnz7cnyUL39qIEZwdGKYCS0VgdObGWoKDBynL/PlxF2r99BOrsf30E6+1nPqXLMlO0FAsBw/Sifvddzx29y47EmMGYsnRozS55M7NEfeKFbzG2ZkdpMHvv/O1oRyUotK4fp2PRYvYIRkrjQ3Ht5eXaX45c8ZUBIZzf98+OqoBvr/27c0ZSu3aZid96hR/IgsW8Njdu9aXpRjO21y5zHtZIyqKHWvJklTsBu++y9GylxcVxZUrNPV8/XXM6w1F0Lix+RzgTOjiRbbj5cUBgqEgr17ld/vqq/FHUX/zjemwB1h5LyDAvIe1lcylSzOyKDzcHAwsXkyFHVtp3LjBAU358nEVkXGeiwt/v5afiy0RRSDE5I03qBA8PfnvGDOGxtfJk20aMG7Y/q9d45+zfn3atrNlY+d54QJDRn18GGlkdPxTp1JRGKYfHx921JZrDrTmFLtcOUYQ5c/P9g0bs7OzGXIJmEooVy4zPPPRI5pi6tWjjFev8pqCBc3r/vmH28mT2TnFXkNgGb9/4QI7WKMcJ8BO2JoiOHuWNv0mTTh67dCBI2FrppFx49hJLF9O01P27DGrlhqhnoCpCFavprknRw4zDPLxY65gBtix16vHDmjfPi5Qjy/pr6EgJk+OuTjQICrK7Cjjs5Eb3L3L87NkiRmS6uxMRR8VRdPf339T/rAwU8lYhtXGDh817h8aStv+jh3msfPngfHj+fv69lvrcllGjtWrRwX86af8XKpXt35NqVJU5JZm0adPrS8KLFuWEXO+vmzXEcg6AiEu+fKxFzOGM1Oncghow4Q/Vapwxeyvv/IPcOhQTNNMsWLmn8zbmx3jo0fsAMqWpfkA4Hblyrjte3hwHcKBA2znwQNT+WzeTF0HsPObMIGdrKcnO5xffzU79Xv3OEouW5YmCkuMEeGhQ7TF9+5tHtu4kXrzpZf42lhrUKBATEVgjfbtaTdev5526S++oCKKiDDXO9Spw8/w0SOgRQt2juvX86szlFy5cjE7MeP51q3sQC9eNLOv3rxpmi6MehI3b7L4neFkt4ahCIYONUN8DbZuZSferBmV8dmzDH2tW9f62gNDvnHjGOrp7W3mZTJMgLVqsS0gppnMUsmEhlL2yEguxmvfnvsXLGC7lp+JoSQ2buRAwdpaEDc3fpbu7jSV5c7NwcbFi9YVOcD7GutkjFBe476xV9n/8AOX8jh0PUZ84URp9SHho6nIP/8w5tGImXvnHZund+zQgU03aMAwx5s3457j76/1e+9pff++1m3aaF26tNabN2sdEhIz1NEa9+9r/fix1hs2MNOp1lovXMh7njsX/9uZOZNhoFWqmG//jTdinvP4sXmsXDmtw8O1njNH69u3mTHUOPbggdbjxzPM8+lTPq5
ejf/eV66YGVmNNn78Ues+fbQuXJgPY/+wYfwsunVjmKrWWq9Zw2Nt2zLbq3GfadPM62rX5n0MgoLM8F0jSa0RLrl8eVwZjXb+/ZfXbtpEOSzp35/nLF7M8OAVK8zrfvghbpt79pjHa9TgvrNnzX3163Pf77/z9e7d5rVhYQzpLVqUnwOg9Vtvcft//8ffVYUKZlvG72bVKr42fuZvvmm2efduTPlmzOD7nD8/ZljtmjX8vg3Cw81jJ07w88+aVesPP9T6zz+1fv11nmNQogTPbdQo7mdiS5BA+KjDO/bkPkQRpDLXrjGg3sgfPXmyTZqNimJHXqIE1w0YnfP58wlf16CB1k5OPHf6dIo1dmzy7m10JHv2UMH85z8xj//5J9cfFCvGNQfTplHOFi0Yo790Kc+bMUPrL75ghm+AawMArXv10vrLL83O4PRpKpF8+ZImX/HibENrs43167luIfYagooVua1cmWsXtGYH//PPVD6A1vfumW0PHx6zE3/9dXbk27Zx39atVGY1a5od8++/x5WxdWseu3yZr+vU0bpZs5jnXLvGc156Seu1a03lGN93vHu31uXLU6kBWm/cqPXgwWZnaxAQYH4m1vD15XFj3cAff3B/+fLmezfWNBi/u1y5uJ5jxAiuIenZU+tKlay3/9NPMb+Dxo21zpHDVLjHjpnHVq/m+oAtW6hgJ07UOksWKi6teU327DzX19f6/WxFQopATENCwuTPz+ghT096CceMYfylDShYkOaCcuVMu661KBuAU//Ll2niqVqVJqQRIzj9j296Hh+GiahtWzr9pkyhE9Jg8GBGqJQuTRPV++/TUXztGv0Cv/5K88kHH9DZbdipjeiS116jScng2jW2YzinATpwp061Lp81J3KOHIyEAWJGJ1WowO2RI/yKZs3iAvGOHXl+kyY0K+3bx8/K0pnr7c2Ing0bTLNFnjw0l+zdy+sA63H3S5bQ1m7YvPPkof3dMOEAXGvg5EQTUe/edD+5uJghrUePxiytXbs2zSNGneaWLc2fmqWjvnx5vh/D7AbQpHPxIq2Z/fpxn2FmbN2aPqLjx/kZVatmmoAM09CjR/xdhITQqbxsGT/Ta9cYFWYZL5E7N7dffMFtWBhNl4Yfq0QJ00H/++80J2XLxsivM2f4uS9cyAyuoaGmWSs1ijHFS3waIq0+ZEbgIO7f53y6TRsOyeMbjiUDT0+OhL791hxBWY5eLSlVSuuuXbnSdOdOTvU3beKo99q15N03ONi8X8OG+plpx6BqVe4bMsTcV6+eeU2uXOaor3PnuKP0f/+NOfJeuzauDK+/ThNPbCIieM348XydNStfG8V+QkJoVnJz4/6pU837DBhgLhLfuNFsc9Ei7suTx7wO4Gi0RAmOwA1zx8WLWn//PZ9//jm3hw4l/pn27m22e+QIZxjNmpnfsfH49FOthw7lyBvQ+tVXrbdXs6b5PTRtqvWpUwnff+9e8x69e9MMZ8wcLR/DhsW87unTmN9Vq1Zalymjdf78fL1sGWcSXbqY1zx8qPXJk1pv3x6z7b//1vrOHc5ArlzRz0xcgNZz53JWU68ef3OdO9OMdfKkeV6nTjFlCwmxrSUWMiMQnpucOTnEGjqU3t2ePemJfQ5Kl6ajc9gw0/lrOdq1pE4djkCbNuXozljZOmGC6TxMKpaRG23aMIfN99+b+7Q25TMYPdpcMHb/Pks8AAzzNCJtDbp25TqBF1+MGRprScmSnOFcuxYzgsmYVRgzAiORXI4c3Hp6cuRapw5fG+GYn39OZ6gxO7F0XG/ezG1ICL+2d9/l16kUneCnTjHy6tgxztKMGZOPDx3xSQlZ9PIyn//vfxz5bt5Mx3axYpwJGJ/NgQOmg9/Pj872PXvoVG3cmJ//v/+aM6ZJkxiFY8mKFRxVG2HFliunL17ke7GM2jJmFMbnasQ9uLhwBtKhA1+fPElH7zvv8LPeuJERYZZRU9mzM9oodpTY2bOcPbZubeYkOnKE27V
r+T7OnOF336oVV67v3MnjkyebKcoBPs+Tx3oghD2wqyJQSrVSSp1USp1WSn1k5XhPpdSh6Mc/SqkqpcMungAAFr9JREFU9pRHeE5CQpilzNOTPXH79uzNUkiBAmbkx4oVbMpammeAf7obN2j6sMzimRJcXbmKtlgxris4fDhmqgVjMZVhigFoRqpalcVfXF0ZaujhQUV2+7bZUQP8g7dqxdz3t24xEV6xYrSwGRiRQwULmqkLgJgmGoD327gx7pqMrVvZYTZrxnONz8RQBMa+kiUZWmp0xOfO0TTTsCFfly7NztTdnSYXFxdTETg7U0EkpdCNIW+5clxJa7lm4d49c53HmTPcdyW6ennNmozRf/99dt6HD5tRyoYZ6/p1UzkbDB1K5VumDMMyLaOGtm/nz/TxY9OsVbAgZfzkE6ac6BRdGX3ZMv6mjHQixoCkXDmanpYupfxG0R+A9xo5Mu5ncOYM1yw0aMDfUNmypsmpcGG+55AQfifly/OaQoV4b2Pdxv37fK/G/fbvN+9pbZGarbCbIlBKOQOYDaA1gAoAeiilKsQ67RyARlrrygAmAbCN8VmwD3nycJjm78+e7d49DqUsh2PJYP9+rhjVmh1QQsVg6tbl9t13bVOK8eHDmJ23JVWr8s/bpIm578IFjmTLlKFvAqANO39+jiBv32YeoKgofjSWoaF//cU/e8WK5j7LNAGWis3Tk3l5jJwzs2dzAVNC9YMWLqSy+O03UxF4eDCc01h7MHMmtw8eMCTVWIxVsSLP++EHc5FVoUJUvMHBZlqMxBgwgKPf117jaN4Ykc+bx5+JoUx+/DHm7KFQIVMZxfaNGO+5c+eYvgQgpg/m+nXTN2HY7420DYa/oFAh82fq4cHPKzSUs5fPP+fMpW5dM2S5VCmGBxs+GMtsK1FR5sp2Y/Xz8eNc83HlCmdjTk4MTa5fn+8vTx525mFh/P0YyvbWLX7+u3ZRWVarxnvlzcvvxljnsWgRfxvJzZyaZOKzGT3vA0AdABstXo8GMDqB8z0BXE6sXfEROJjISIbpeHjQyOnkRONmYpXVrXD4sBmBk5TbGrbYadOSfas4lCnDKBFr/N//aX3gQMx9b77Je48Zw9dRUfRnjB7N/ePGmeca2T3feYd2Y0Pup0/Nc27ejGlfjq/gfOXKCR/X2sz0unAhM6MakToGjx9ze/q01uvWmX4MS7p0oS3cknff1Tpnzvjva41jx+hb6NRJ6wIFtB41ive7coXhwTduMFLLeN8PHzLyCqBfoFYts63wcK0HDuQxI8rGwDL7qZ+f+bxkyZif69q1Wr/9Nt8fQL+SkUG1cmWtq1dnNljD3zN9Ov0WoaG8z+zZcf0YUVFm+3360J+gtdavvcZ9Z8+a5/7zD30Jn37KYyEh3H//Pu/bqZPWn30WM8vqiBE8p0MHRoVpTdl8fJL3XcQGDvIRFAZgmb8xOHpffAwA8Ie1A0qpwUqpAKVUwM3YFUqE1MXJifPrVau4aqd/fw4xU5CGolKlmLbsxG47Ywaf26Im75YtHMlbo3TpmGYhwJw9GDl6lKLZwVg5Gh5OP8J//mOO9q9fp0kA4AjQxWIdv5cXbeNvv82RqVEG4tYtjiSNRV2HDpn3i499+7jNnZsZS4sXjzm7MkxdJUuaI27DBGJw507cSKV798wRdlIpX555lypU4KzGWImdJQtXz3p7U44XXmAEVvbspi9m796YswWjWp2zc9xosq1bmRoCMH0wXbtyRgIwomv9eo7SZ83iIjWAn2/9+kz3cOgQZ6VVqnDm06kTR+tjxpjf95tv0mxpieV3kTu3OWt48oRtFS/O16NH8z4NG5qzQaMoTs6c/PvkycMUFoYfJksWRlcBnJWcPcsZiL9/zBmqzYlPQzzvA0A3AAssXvcG8E085zYBcByAV2LtyowgDWGENQwZwqHMokV2vV1CC5zszeTJvLflQiCtOasBzBFq0aKswQBwQZXWWn/3HXPOJwUjgur6db426iw
kRPfuPOfHH/l6//74awMYMfhGm1FRjJQBYubY79SJ+ypUSJrcBlFRHBFfvcrXM2eynXXr4r/myBGe4+ys9ccfm/ufPo0pa2yMdQ4//sgooQkTeF+AI3lL9u3j/v/9j6+fPGE8vzET0JoLFwGOzhPDmlxPnsSc9RmRZnfu8LUxM7OkTRvOSrRmxJblzO/CBc5Ob99mO9OnJy5XwjI7ZkYQDMCyKmsRAFdin6SUqgxgAYCOWutkZPkWHI6nJ4dHpUrRuzV4sJnb2Q4Yk0FjtJyaGKmYY9toDee2MfJ1cwMGDWL0kzHCHjLEHDXG5tEj1j8w8vwbTlYj6iQw0Iz6iY933+W2QQNGVn3wQfzRV7F9DUpx1F2oUEw7eOnStKVbxuonBa3pR5k1i6+NCK3Y+fuXLDHTR5Quzfd44wYjawwSq6tcqhSd9oULc2T+zTd839OmcdRviZHN1lg34OrK77RkSdOhbNjtrdWciI2vr+lwNjDSShsYEUNG5tPYmXjHjuUaDuO+RYvGLBBYrBg/GyM9iWWeK5sTn4Z43geY0O4sgOIAsgAIAlAx1jnFAJwGUDep7cqMII1x7RoNyS1bMhg/Tx4OY+zA3buMlb992y7NJ8jEiRyVHT8ec78xwu7Rg1tfX44MAa0nTUq83chIrhUYOZKvU2KXt2T5ct77p5+sHzdGzF5eKb9HYhQpwnt89hlnCJs2xXQhBQbyeJkyibcFcLSfEJbV6YYPp+09NkZ6jV9/jXvMGIXfvctzWrdOXK47d+LODmNjrDg3/A2xMfwWAwZYP/7oESv0rVvHVdLP+7eCI2YEWusIAG8B2AiafVZprY8qpYYopaJLnGMcAC8Ac5RSgUqpgHiaE9Iq+fNzCLpxI4dSWjNe0jKzl43InZux8sldSWwLxoyh7b5cuZj7X3yR2TyNesZubjGLyiSGkY301Cl+hDNnmrOBlGCsQ4hvFlGgANtPrFbx82CMXENCOONo3tz66N6yKHx8FCgQc41GbPbv533Gj+cIPzSUfoDYoZbvvcforXbtzH2rVzM81hiF587NuP6ffkpcLg8P6wWXLPntN47444tOM2ZLxowgNi4uDKsNCKC/wXJdi62x6zoCrfUGrXUZrXVJrfXk6H1ztdZzo58P1Fp7aq2rRj8SyFYupFneeovz688/pw3k/Hl67uwZ+JzKODszLjw2SvGtGn/2XLlofrl7l07TpFC7Np3XhQrRpNGiRcrlNJRIQrUgmjaNuXbB1hgO5vgUtpEyPL4ymZbcvx+3ZKglxoK6woX5uRtFiGIvb3FyipniA6BiX7WKld4M6te3TTACQIXYunX8xw2HvrU1CQAVgY8PzVoXL8ZdS2FLpB6B8Px4enJBwJIl/Bf168dhXJ8+XM1kwzoGaZXSpbk61Ciok5xom/79GSceEGCuNE0pxv1jRwVZYlnG0x4YtnDLCCBLXniBii8piiAsDNi9O/HzjFF1mTJUApb1KeLDkG/jxoQ/L3thyHz/fvzylipF/1H58jFXoNsaUQSCbfDx4fz8/Hlz6LVyJYfR48dneGVQqFCKC7mhbl2GFz7HIu1nKMV6Qo6kbVuaROJTBEDSndAhIUkzsRmd6po1VKhJUQSVKnFrWd0uNTFmTLdvmyVRY2O5At2efyHJNSTYlshIBocPiXYDTZzI0lXxVX4XoBSwbh1j3jMCDRrQ/GSLDtbTM2FTjdFRGts8eZJuWmvYkEFuo0Y9l4gppnNnpsOoWjX+c4wcS4n5I54Xpe1peLIDfn5+OiBAfMppmqgorq554QWzFFbRovQhdO+e4WcHQurx6qtMD3HypKMlsQ+hoTQL1alDf8bzoJTaH58fVmYEgu1xcmL4zPff00ewfTsDvF99lVnGli2LW69PEFLA22+bK84zIu7uXOFt1zUEEB+BYE8sQyZateKw7dIlxloOH865cbduNBjHl3ZUEBLAyNqZUdGai+bKlLHvfcQ0JKQODx7Qi2gZFpMlC2cGOXPSYNu
4MbeVK8dfqkwQhBSRkGlIZgRC6pAzJzOFbd7M1VN37jBsw9WVxs+1a1nXD6BpqUQJBthXr86wkRYtuC8pISSCICQL+VcJqYeLC01ErVrF3B8aygT2Tk7choUxEfudO0xgb+DszPCQwoW5LVmSTuiiRZnUvXBh+4dXCEIGRExDQtpCa0YZHTpEhRESAvz3v8BXX3HljasrVyxnzx6zLJWBuztjCH18qBzy5GHu44IFuc2Xj1svL676sszyJQgZmIRMQ6IIhPSB1jQdTZ7MhDH9+9Pc9Ouv3HfzJpWGkZq0YkUmqr982XrRYIOsWU1lkCsXA9crVuQsw8ODCid/fu738OA9jYe7O30ZEg4rpANEEQiZA63plL5+nbOBbNmYx3nNGtYQ3LWLs4hHj5go7+FD1oIMC2OyoNBQOq9z5eJaiKSs6VeKTu8cOThbcXPjvb29zUJZHh5UJEWLmsmI3N2pZLJn5zW7dvH6bt3MfaJgBBsiikAQLAkLM4v7xubqVSqKEiWYxvKbbzjbuH2bJqszZ7gE2NeXhWqXLmV7lqk0K1TgLOPq1bjFdpODszOVQ5YsnLlky0aTlpcXFdXFi1Qsjx8z0ipLFi7iy5mTCu3CBcpWpw7lK1iQxXDd3NhefFtxyGdIRBEIgj2JiGAkVIEC7KDd3dmh3r/PZa8PHlCZ7NtHh3jfvhztL1sG/P03O+6iRZmy1MWFKTXv3QPmzKFZy8hrDTC7Xd68PH78uH1SUhqmMicnPjceHh5UJoaSc3GhojIe+fJxNgTQJOfmRoVrrDB3caEyU4qyFyhAh3/WrPzcPDzYzv/9n5nuNWdOXps1KxVdliz8vJ2c6P+RWVOSEUUgCOmZqCgqkzt3OAMwspVFRVF5KMXk93nzMtGP1kzUHx7ODnP1avo5nJ25ViM8nEro6lUqmcePOYPIn5/XP37MFeEPHrDTjYjgvfLmZUGGsDCz5Jcl/9/e/cVIdZZxHP/+WJZdCo1rAS3BBihBDDZaCIU21KYXtVLSdNWbVkxa/yS1ptQ/iSFoE1Pv0EYTrmwwNlaD5cJa5aKRNhYkllBAZPkjRahAAkWWpIAUusvu8njxvJOdDjvD7naGM2fP80lOZuacM7Pvs2/mPPOeOfO8HR1+8O/pGZwYodEkX1pbPcmMG+d/u7XV4x03zpcpU/z0nJn/qLG/3xN2R4dffDBrliedixc9abe0eOIqLfPm+f7Hj3s/lI/QBgYGL4U+e9b/P+VJq7fXJ3Bua/MR4qlTnvgnTfJ92tv9woXx4/31p0/3fr5wwfu6vX2wPe3tV091Nux/VSSCEEI9mQ3WlCotkyb5p/feXv9U39vr809u2+an2zo7vXDOhg2+37lzPpIy8wPhQw/56xw75rO8Hz3qjw8f9oPwihX+ePNmf253tx+4L1/2U2Pz5vlB/fXX/cKBgQFfrlzxbTff7Af5ri4/aF+65NvAq7tNmOCnAE+fvjreG24YjLf0nCysWuU1u0Yhs0QgaRmwFmjBJ7JfU7Fdafty4BLwNTPbXes1IxGEEOqir88T1syZg5M5v/++L6Xt/f1+O2OGJw8zn0S4p2fwO6CJE2HBAt/v4EEfcZRGWZJ/qp8zx7dv2+YjlJMnfXtfnyfQpUs9ae3b54mvp8c//Z8/7yOZJUt8+6JF/uv7UcgkEUhqAf4NfB6fyH4n8BUz+1fZPsuBp/BEsARYa2Y1p6uIRBBCCCOXVfXRxcARM/uPmV0GNgCdFft0Ar9NcytvBzokNbjOXgghhHKNTAQzgPLZRk+kdSPdB0mPS9oladeZM2fq3tAQQiiyRiaCoa7rqjwPNZx9MLN1ZrbIzBZNmzatLo0LIYTgGpkITgC3lD3+BPDOKPYJIYTQQI1MBDuBuZJmS5oAPAJsrNhnI/Co3J3AeTM71cA2hRBCqNCw35KbWb+klcAm/PLR583sgKQn0vbngFfwK4aO4JePfr1R7QkhhDC0hhYVMbNX8IN9+brnyu4b8GQj2xB
CCKG2mLw+hBAKLnclJiSdAY6P8ulTgQ9RDrLpjKV4xlIsMLbiiVia10jimWlmQ152mbtE8GFI2lXtl3V5NJbiGUuxwNiKJ2JpXvWKJ04NhRBCwUUiCCGEgitaIliXdQPqbCzFM5ZigbEVT8TSvOoST6G+IwghhHC1oo0IQgghVIhEEEIIBVeYRCBpmaRDko5IWp11e0ZK0jFJ+yTtkbQrrbtJ0muSDqfbj2bdzmokPS+pW9L+snVV2y/ph6mvDkn6QjatHlqVWJ6RdDL1z5406VJpWzPHcoukzZIOSjog6btpfV77plo8uesfSe2SdkjqSrH8JK2vf9+Y2Zhf8FpHbwO3AhOALmB+1u0aYQzHgKkV634GrE73VwM/zbqdNdp/D7AQ2H+t9gPzUx+1AbNT37VkHcM1YnkG+MEQ+zZ7LNOBhen+jfisgvNz3DfV4sld/+Bl+ien+63Am8CdjeiboowIhjNbWh51Ai+k+y8AX8ywLTWZ2Vbg3YrV1drfCWwws14zO4oXJVx8XRo6DFViqabZYzllaZ5wM7sAHMQnh8pr31SLp5qmjcfce+lha1qMBvRNURLBsGZCa3IGvCrpH5IeT+s+bqlsd7r9WGatG51q7c9rf62UtDedOioN13MTi6RZwAL8k2fu+6YiHshh/0hqkbQH6AZeM7OG9E1REsGwZkJrckvNbCHwAPCkpHuyblAD5bG/fgnMAW4HTgE/T+tzEYukycBLwPfM7H+1dh1iXR7iyWX/mNmAmd2OT9q1WNJtNXYfdSxFSQS5nwnNzN5Jt93Ay/iQ77Sk6QDptju7Fo5Ktfbnrr/M7HR6014BfsXgkLzpY5HUih8015vZH9Pq3PbNUPHkuX8AzOwcsAVYRgP6piiJYDizpTUtSZMk3Vi6D9wP7MdjeCzt9hjw52xaOGrV2r8ReERSm6TZwFxgRwbtG7bSGzP5Et4/0OSxSBLwa+Cgmf2ibFMu+6ZaPHnsH0nTJHWk+xOB+4C3aETfZP3N+HX8Bn45fgXB28DTWbdnhG2/Fb8aoAs4UGo/MAX4K3A43d6UdVtrxPAiPiTvwz+5fLNW+4GnU18dAh7Iuv3DiOV3wD5gb3pDTs9JLHfjpw/2AnvSsjzHfVMtntz1D/AZ4J+pzfuBH6f1de+bKDERQggFV5RTQyGEEKqIRBBCCAUXiSCEEAouEkEIIRRcJIIQQii4SAShsCRtS7ezJK2o82v/aKi/FUIzistHQ+FJuhevTPngCJ7TYmYDNba/Z2aT69G+EBotRgShsCSVKjuuAT6X6tR/PxX6elbSzlSk7Ftp/3tTrfvf4z9OQtKfUiHAA6VigJLWABPT660v/1tyz0raL59f4uGy194i6Q+S3pK0Pv1KNoSGG591A0JoAqspGxGkA/p5M7tDUhvwhqRX076LgdvMy/wCfMPM3k0lAHZKesnMVktaaV4srNKX8cJnnwWmpudsTdsWAJ/G68O8ASwF/l7/cEP4oBgRhHC1+4FHU/nfN/Gf9M9N23aUJQGA70jqArbjBb/mUtvdwIvmBdBOA38D7ih77RPmhdH2ALPqEk0I1xAjghCuJuApM9v0gZX+XcLFisf3AXeZ2SVJW4D2Ybx2Nb1l9weI92e4TmJEEAJcwKc1LNkEfDuVM0bSJ1PV10ofAc6mJPApfBrBkr7S8ytsBR5O30NMw6e9bIpql6G44hNHCF7dsT+d4vkNsBY/LbM7fWF7hqGnAf0L8ISkvXi1x+1l29YBeyXtNrOvlq1/GbgLryRrwCoz+29KJCFkIi4fDSGEgotTQyGEUHCRCEIIoeAiEYQQQsFFIgghhIKLRBBCCAUXiSCEEAouEkEIIRTc/wFqDiBNP/7TuwAAAABJRU5ErkJggg==\n",
1233 | "text/plain": [
1234 | ""
1235 | ]
1236 | },
1237 | "metadata": {
1238 | "needs_background": "light"
1239 | },
1240 | "output_type": "display_data"
1241 | }
1242 | ],
1243 | "source": [
1244 | "get_loss_fig_aux(train_loss_data, test_loss_data)"
1245 | ]
1246 | },
1247 | {
1248 | "cell_type": "code",
1249 | "execution_count": null,
1250 | "metadata": {},
1251 | "outputs": [],
1252 | "source": []
1253 | }
1254 | ],
1255 | "metadata": {
1256 | "kernelspec": {
1257 | "display_name": "Python 3",
1258 | "language": "python",
1259 | "name": "python3"
1260 | },
1261 | "language_info": {
1262 | "codemirror_mode": {
1263 | "name": "ipython",
1264 | "version": 3
1265 | },
1266 | "file_extension": ".py",
1267 | "mimetype": "text/x-python",
1268 | "name": "python",
1269 | "nbconvert_exporter": "python",
1270 | "pygments_lexer": "ipython3",
1271 | "version": "3.7.6"
1272 | }
1273 | },
1274 | "nbformat": 4,
1275 | "nbformat_minor": 4
1276 | }
1277 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DIEN-DIN
2 |
3 | 本项目使用tensorflow2.0复现阿里兴趣排序模型DIEN与DIN。
4 |
5 | DIN论文链接: https://arxiv.org/pdf/1706.06978.pdf
6 |
7 | DIEN论文链接: https://arxiv.org/pdf/1809.03672.pdf
8 |
9 | 数据集使用阿里数据集测试模型代码, 数据集链接: https://tianchi.aliyun.com/dataset/dataDetail?dataId=56
10 |
11 | # 调用方法:
12 |
13 | ## 0. 简介:
14 |
15 | DIEN的输入特征中主要包含三个部分特征: 用户历史行为序列, 目标商品特征, 用户画像特征。
16 | 用户历史行为序列需包含点击序列与非点击序列。
17 | 请按如下1~2方法处理输入特征。
18 |
19 | ## 1. 初始化:
20 |
21 | 初始化DIEN时需传入5个参数:
22 |
23 | (注:feature_list中的特征名称,需要与embedding_dict中的特征名称一样)
24 |
25 | - embedding_count_dict:string->int格式,该变量记录每个需要embedding的特征的词典大小,即最大整数索引+1;
26 |
27 | - embedding_dim_dict:string->int格式,该变量记录每个需要embedding的特征的输出维数,即稠密嵌入向量的维度;
28 |
29 | - embedding_features_list:list(string)格式,该变量记录DIEN中user_profile部分所有需要embedding的feature名称;
30 |
31 | - user_behavior_features:list(string)格式,该变量记录DIEN中user_behavior与target_item部分所有需要embedding的feature名称;
32 |
33 | - activation:string格式,默认值"PReLU",该变量控制全连接层的激活函数:"PReLU"->PReLU,"Dice"->Dice
34 |
35 | ## 2. 模型调用:
36 |
37 | 模型调用需传入6个参数:
38 |
39 | (注:feature_list中的特征名称,需要与dict中的特征名称一样)
40 |
41 | - user_profile_dict:dict:string->Tensor格式,记录user_profile部分的所有输入特征的训练数据;
42 |
43 | - user_profile_list:list(string)格式,记录user_profile部分的所有特征名称;
44 |
45 | - click_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有点击输入特征的训练数据;
46 |
47 | - noclick_behavior_dict:dict:string->Tensor格式,记录user_behavior部分所有未点击输入特征的训练数据;
48 |
49 | - target_item_dict:dict:string->Tensor格式,记录target_item部分输入特征的训练数据;
50 |
51 | - user_behavior_list:list(string)格式,记录user_behavior部分的所有特征名称。
52 |
53 | # 调用演示代码:
54 |
55 | ## DIEN:
56 |
57 | DIEN_train_example.ipynb
58 |
59 | ## DIN:
60 |
61 | DIN_train_example.ipynb
62 |
63 | # 代码:
64 |
65 | - model.py: 定义模型代码
66 |
67 | - layers.py: 自定义层
68 |
69 | - loss.py: 定义Auxiliary Loss用到的NN
70 |
71 | - activations.py: 定义Dice激活函数
72 |
73 | - alibaba_data_reader.py: 输入数据处理函数(代码中使用数据已用spark处理后得到了所需序列数据, 及特征embedding词典数)
74 |
--------------------------------------------------------------------------------
/__pycache__/activations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/activations.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/alibaba_data_reader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/alibaba_data_reader.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/activations.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
class Dice(tf.keras.layers.Layer):
    """Dice activation (DIN paper) with a single learnable scalar ``alpha``.

    The input is normalized by a non-affine BatchNormalization (no centre/scale
    parameters), squashed by a sigmoid into a gate ``p``, and the output is
    ``alpha * (1 - p) * x + p * x``.
    """

    def __init__(self, alpha_initializer="zeros", **kwargs):
        """Create the layer.

        Args:
            alpha_initializer: Keras initializer for the scalar ``alpha``.
                Defaults to ``"zeros"``, the same start point used by the
                per-channel ``dice`` layer. Previously no initializer was given,
                so Keras' default glorot-uniform drew a random scalar in roughly
                [-1.7, 1.7].
            **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        super(Dice, self).__init__(**kwargs)
        # center=False/scale=False: only normalize; Dice's gate has no BN affine params.
        self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False)
        self.alpha = self.add_weight(
            shape=(), dtype=tf.float32, name='alpha', initializer=alpha_initializer
        )

    def call(self, x, training=None):
        """Apply Dice to ``x``; ``training`` is forwarded to the inner BatchNormalization."""
        x_normed = self.bn(x, training=training)
        x_p = tf.sigmoid(x_normed)
        return self.alpha * (1.0 - x_p) * x + x_p * x
13 |
class dice(tf.keras.layers.Layer):
    """Per-channel Dice activation with learnable ``alphas`` and a gate scale ``beta``.

    Output is ``alphas * (1 - p) * x + p * x`` with
    ``p = sigmoid(beta * BN(x))``, where BN is a non-affine BatchNormalization.
    ``alphas`` and ``beta`` both have shape ``[feat_dim]`` and start at zero.
    """

    def __init__(self, feat_dim, **kwargs):
        """Create the layer.

        Args:
            feat_dim: Size of the last input axis; ``alphas``/``beta`` broadcast
                against it.
            **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        super(dice, self).__init__(**kwargs)
        self.feat_dim = feat_dim
        self.alphas = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32)
        self.beta = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32)

        self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False)

    def call(self, _x, axis=-1, epsilon=0.000000001, training=None):
        """Apply Dice to ``_x``.

        The former manual mean/std computation was never used in the output
        (the BatchNormalization below performs the normalization), so it was
        removed. ``axis`` and ``epsilon`` are kept only for signature
        compatibility and have no effect. ``training`` is forwarded to the
        BatchNormalization.
        """
        x_normed = self.bn(_x, training=training)
        x_p = tf.keras.activations.sigmoid(self.beta * x_normed)

        return self.alphas * (1.0 - x_p) * _x + x_p * _x
--------------------------------------------------------------------------------
/alibaba_data_reader.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import tensorflow as tf
3 | import numpy as np
4 |
def get_embedding_features_list():
    """Return a fresh list with the name of every feature that gets an embedding table.

    The first two names, "cate" and "brand", are the same names returned by
    ``get_user_behavior_features``; the remaining eight come after them in a
    fixed order.
    """
    item_side = ["cate", "brand"]
    profile_side = [
        "cms_segid", "cms_group",
        "gender", "age", "pvalue", "shopping",
        "occupation", "user_class_level",
    ]
    return item_side + profile_side
10 |
def get_user_behavior_features():
    """Return a fresh list of the feature names used in behavior sequences and the target item."""
    return ["cate", "brand"]
14 |
def get_embedding_count(feature, embedding_count):
    """Return the first-row value of column ``feature`` in the ``embedding_count`` table."""
    column = embedding_count[feature]
    first_row_value = column.values[0]
    return first_row_value
17 |
def get_embedding_count_dict(embedding_features_list, embedding_count):
    """Build ``{feature: vocabulary size}`` for the requested features.

    Each requested feature is first looked up with ``get_embedding_count``.
    Afterwards a fixed set of hard-coded sizes is applied unconditionally: those
    keys overwrite any looked-up value and are added even if they were not in
    ``embedding_features_list``.
    """
    counts = {
        name: get_embedding_count(name, embedding_count)
        for name in embedding_features_list
    }
    # Hard-coded vocabulary sizes that take precedence over the CSV lookup.
    counts.update({
        "brand": 500000,
        "cate": 501578,
        "gender": 3,
        "pvalue": 10,
        "shopping": 4,
        "occupation": 5,
        "user_class_level": 5,
    })
    return counts
30 |
def get_embedding_dim_dict(embedding_features_list, dim=64):
    """Map every feature name to the same embedding output dimension.

    Args:
        embedding_features_list: Iterable of feature names.
        dim: Embedding size assigned to each feature (default 64, the value
            that used to be hard-coded).

    Returns:
        A new dict ``{feature: dim}``, in the iteration order of the input.
    """
    return {feature: dim for feature in embedding_features_list}
36 |
def get_data(data_dir="./data"):
    """Load the train/test behavior tables and the embedding-count table.

    Args:
        data_dir: Directory containing ``train.csv`` and ``test.csv``
            (tab-separated) and ``embedding_count.csv`` (comma-separated).
            Defaults to ``"./data"``, the previously hard-coded location.

    Returns:
        ``(train_data, test_data, embedding_count)`` DataFrames. In the train
        and test frames, missing values are filled with 0. Rows whose
        ``click_cate`` or ``click_brand`` value is 0 after that fill (i.e. was
        missing) are then dropped.
    """
    import os

    def _load_split(file_name):
        # Shared cleaning for train and test: NaN -> 0, then drop rows with no click history.
        frame = pd.read_csv(os.path.join(data_dir, file_name), sep="\t").fillna(0)
        has_clicks = (
            (frame["guide_dien_final_train_data.click_cate"] != 0)
            & (frame["guide_dien_final_train_data.click_brand"] != 0)
        )
        return frame[has_clicks]

    train_data = _load_split("train.csv")
    test_data = _load_split("test.csv")
    embedding_count = pd.read_csv(os.path.join(data_dir, "embedding_count.csv"))
    return train_data, test_data, embedding_count
48 |
def get_normal_data(data, col):
    """Return the raw values of column ``col`` of ``data`` (``Series.values``)."""
    column = data[col]
    return column.values
51 |
def _parse_sequence(text):
    """Parse a serialized sequence such as ``"[1,2,3]"`` into a list of numbers.

    The first and last characters (the brackets) are dropped and the rest is
    split on commas. Each token is parsed with ``ast.literal_eval``, which
    accepts numeric literals but, unlike ``eval``, refuses to execute code, so a
    tampered CSV cannot run arbitrary Python. Malformed tokens raise
    ``ValueError``/``SyntaxError``.
    """
    import ast

    return [ast.literal_eval(token.strip()) for token in text[1:-1].split(",")]


def get_sequence_data(data, col):
    """Parse every serialized sequence in ``data[col]`` and right-pad with zeros.

    All rows are padded to the length of the longest sequence in this frame.

    Returns:
        A list with one list per row. Each value is a numpy float64, because
        the values are concatenated with a float zero padding. An empty frame
        gives ``[]``.
    """
    # Parse each row once and reuse the result for both the max-length scan and the padding.
    sequences = [_parse_sequence(text) for text in data[col].values]
    max_length = max((len(seq) for seq in sequences), default=0)

    rst = []
    for seq in sequences:
        padding = np.zeros(max_length - len(seq))
        rst.append(list(np.append(np.array(seq), padding)))
    return rst


def get_length(data, col):
    """Return the unpadded length of each serialized sequence in ``data[col]``."""
    return [len(_parse_sequence(text)) for text in data[col].values]
72 |
def convert_tensor(data):
    """Wrap ``data`` (list, numpy array, ...) as a TensorFlow tensor via ``tf.convert_to_tensor``."""
    tensor = tf.convert_to_tensor(data)
    return tensor
75 |
# Scalar (one id per row) feature columns, in the order they appear in the returned tuple.
_SCALAR_FEATURE_COLUMNS = (
    "guide_dien_final_train_data.cate_id",
    "guide_dien_final_train_data.brand",
    "guide_dien_final_train_data.cms_segid",
    "guide_dien_final_train_data.cms_group_id",
    "guide_dien_final_train_data.final_gender_code",
    "guide_dien_final_train_data.age_level",
    "guide_dien_final_train_data.pvalue_level",
    "guide_dien_final_train_data.shopping_level",
    "guide_dien_final_train_data.occupation",
    "guide_dien_final_train_data.new_user_class_level",
)

# Serialized behavior-sequence columns (zero-padded per batch), in return order.
_SEQUENCE_FEATURE_COLUMNS = (
    "guide_dien_final_train_data.click_brand",
    "guide_dien_final_train_data.click_cate",
    "guide_dien_final_train_data.show_brand",
    "guide_dien_final_train_data.show_cate",
)


def _batch_to_tensors(batch_data):
    """Convert the rows of ``batch_data`` into model inputs.

    Returns:
        ``(features, clk_length, show_length)``. ``features`` is a tuple of:
        - the one-hot (depth 2) click label;
        - the ten scalar-feature tensors, in ``_SCALAR_FEATURE_COLUMNS`` order;
        - the four padded sequence tensors, in ``_SEQUENCE_FEATURE_COLUMNS`` order.
        ``clk_length`` and ``show_length`` are Python lists of per-row lengths
        of the click-brand and show-brand sequences.
    """
    label = tf.one_hot(get_normal_data(batch_data, "guide_dien_final_train_data.clk"), 2)
    scalars = tuple(
        convert_tensor(get_normal_data(batch_data, column))
        for column in _SCALAR_FEATURE_COLUMNS
    )
    sequences = tuple(
        convert_tensor(get_sequence_data(batch_data, column))
        for column in _SEQUENCE_FEATURE_COLUMNS
    )
    clk_length = get_length(batch_data, "guide_dien_final_train_data.click_brand")
    show_length = get_length(batch_data, "guide_dien_final_train_data.show_brand")
    return (label,) + scalars + sequences, clk_length, show_length


def get_batch_data(data, min_batch, batch=100):
    """Draw a random training batch and convert it to tensors.

    Rows are sampled with ``DataFrame.sample(n=batch)`` (without replacement,
    so pandas raises if ``batch`` exceeds ``len(data)``). ``min_batch`` does not
    affect which rows are chosen. It is only advanced and returned as a cursor
    (``min_batch + batch``) for callers that track progress.

    Returns:
        An 18-tuple: the 15 feature tensors described in ``_batch_to_tensors``,
        then ``min_batch + batch``, ``clk_length``, ``show_length``.
    """
    batch_data = data.sample(n=batch)
    features, clk_length, show_length = _batch_to_tensors(batch_data)
    return features + (min_batch + batch, clk_length, show_length)


def get_test_data(data, size=150):
    """Convert the first ``size`` rows of ``data`` (default 150) to tensors.

    Returns:
        A 17-tuple: the 15 feature tensors described in ``_batch_to_tensors``,
        then ``clk_length``, ``show_length``.
    """
    batch_data = data.head(size)
    features, clk_length, show_length = _batch_to_tensors(batch_data)
    return features + (clk_length, show_length)
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 | from activations import Dice,dice
4 |
class GRU_GATES(tf.keras.layers.Layer):
    """A single GRU gate built from two dense projections.

    ``a`` goes through a projection with bias and ``b`` through one without.
    If ``gate_b`` is omitted, the two projections are summed and passed
    through a sigmoid; AUGRU uses this form for its update and reset gates.
    If ``gate_b`` is given, the ``b`` projection is first multiplied
    element-wise by ``gate_b`` and the sum goes through tanh; AUGRU uses this
    form for the candidate state.
    """

    def __init__(self, units):
        super().__init__()
        # Projection applied to `a` (the step input in AUGRU), with bias.
        self.linear_act = layers.Dense(units, activation=None, use_bias=True)
        # Projection applied to `b` (the previous state in AUGRU), without bias.
        self.linear_noact = layers.Dense(units, activation=None, use_bias=False)

    def call(self, a, b, gate_b=None):
        input_part = self.linear_act(a)
        state_part = self.linear_noact(b)
        if gate_b is not None:
            return tf.keras.activations.tanh(input_part + tf.math.multiply(gate_b, state_part))
        return tf.keras.activations.sigmoid(input_part + state_part)
16 |
class AUGRU(layers.Layer):
    """GRU cell whose update gate is scaled by an attention score (AUGRU).

    For one step:
        u_t   = sigmoid-gate(x_t, h_{t-1})
        r_t   = sigmoid-gate(x_t, h_{t-1})
        h~_t  = tanh-gate(x_t, h_{t-1}, r_t)
        u'_t  = a_t * u_t
        h_t   = (1 - u'_t) * h_{t-1} + u'_t * h~_t
    """

    def __init__(self, units):
        super().__init__()
        self.u_gate = GRU_GATES(units)  # update gate u_t
        self.r_gate = GRU_GATES(units)  # reset gate r_t
        self.c_memo = GRU_GATES(units)  # candidate state h~_t

    def call(self, inputs, state, att_score):
        """Advance one time step and return the next hidden state.

        Args:
            inputs: step input x_t.
            state: previous hidden state h_{t-1}.
            att_score: attention score a_t for this step; it must broadcast
                against the gate output (DIEN passes one score per example).
        """
        update_gate = self.u_gate(inputs, state)
        reset_gate = self.r_gate(inputs, state)
        candidate = self.c_memo(inputs, state, reset_gate)
        # The attention score rescales how much of the candidate is let in.
        attentive_update = att_score * update_gate
        return (1 - attentive_update) * state + attentive_update * candidate
31 |
class attention(tf.keras.layers.Layer):
    """DIN-style target attention followed by weighted sum pooling.

    Every history item (key) is scored against the target item (query) by a
    small MLP over ``[q, k, q - k, q * k]``. Padded positions are masked out.
    The scores go through a sigmoid, and the keys are summed with those
    weights.
    """

    def __init__(self, keys_dim):
        """Args:
            keys_dim: size of the key vectors; scores are divided by
                ``sqrt(keys_dim)`` before the sigmoid.
        """
        super(attention, self).__init__()
        self.keys_dim = keys_dim
        self.fc = tf.keras.Sequential()
        self.fc.add(layers.BatchNormalization())
        self.fc.add(layers.Dense(36, activation="sigmoid"))
        self.fc.add(dice(36))
        self.fc.add(layers.Dense(1, activation=None))

    def call(self, queries, keys, keys_length):
        """Return the attention-pooled keys.

        Args:
            queries: target-item embedding, rank 2 (tiled along a new axis 1
                to match ``keys``).
            keys: history embeddings, rank 3 ``[batch, T, dim]``; padded
                steps may follow the valid ones.
            keys_length: number of valid steps per example; presumably shape
                ``[batch]`` (TODO confirm against the data reader).

        Returns:
            Tensor ``[batch, dim]``; the batch axis is kept even for batch 1.
        """
        # Padded length of the history. Using the tensor's own length, not
        # max(keys_length), keeps the mask the same width as the scores even
        # when every sequence in the batch is shorter than the padding.
        max_len = tf.shape(keys)[1]
        queries = tf.tile(tf.expand_dims(queries, 1), [1, max_len, 1])
        din_all = tf.concat([queries, keys, queries - keys, queries * keys], axis=-1)
        outputs = tf.transpose(self.fc(din_all), [0, 2, 1])  # [batch, 1, T]
        key_masks = tf.sequence_mask(keys_length, max_len, dtype=tf.bool)
        key_masks = tf.expand_dims(key_masks, 1)
        # A huge negative score makes the sigmoid weight of padded steps ~0.
        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(key_masks, outputs, paddings)
        outputs = outputs / (self.keys_dim ** 0.5)
        # Sigmoid (not softmax): each step is weighted independently.
        outputs = tf.keras.activations.sigmoid(outputs)

        # Sum pooling: [batch, 1, T] x [batch, T, dim] -> [batch, dim].
        return tf.squeeze(tf.matmul(outputs, keys), axis=1)
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 |
class AuxLayer(layers.Layer):
    """Auxiliary classifier used by DIEN's auxiliary loss.

    Applies BatchNorm -> Dense(100, sigmoid) -> ReLU -> Dense(50, sigmoid)
    -> ReLU -> Dense(2) to the last axis of its input. It returns a softmax
    over the two outputs, so the input's leading axes are preserved.
    """

    def __init__(self):
        super().__init__()
        self.fc = tf.keras.Sequential()
        self.fc.add(layers.BatchNormalization())
        self.fc.add(layers.Dense(100, activation="sigmoid"))
        # A ReLU after a sigmoid is an identity (sigmoid output is positive).
        # It is kept so the layer stack stays exactly as before.
        self.fc.add(layers.ReLU())
        self.fc.add(layers.Dense(50, activation="sigmoid"))
        self.fc.add(layers.ReLU())
        self.fc.add(layers.Dense(2, activation=None))

    def call(self, input):
        # No tf.squeeze: the last axis always has size 2. An axis-less squeeze
        # would silently drop a batch or time axis of size 1 and break callers
        # that index the result as [batch, time, class].
        logit = self.fc(input)
        return tf.keras.activations.softmax(logit)
18 |
19 |
--------------------------------------------------------------------------------
/main.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StephenBo-China/DIEN-DIN/e1d9bb0591f0e0ce5be35cbf328077f6da2a45d2/main.ipynb
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 | from layers import AUGRU
4 | from activations import Dice
5 | import pandas as pd
6 | from model import DIEN
7 | import alibaba_data_reader as data_reader
8 |
def train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list, label, optimizer, model, alpha, loss_metric):
    """Run one optimisation step of ``model`` on a single batch.

    The loss is ``target_loss + alpha * aux_loss``. Gradients are clipped to a
    global norm of 5.0 before ``optimizer`` applies them, and the final loss
    is accumulated into ``loss_metric``.

    NOTE(review): the target loss is a per-logit sigmoid cross-entropy, while
    the model's ``output`` is a softmax over the same two logits. Confirm this
    pairing is intended.
    """
    with tf.GradientTape() as tape:
        _, logit, aux_loss = model(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)
        float_label = tf.cast(label, dtype=tf.float32)
        per_logit_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=float_label)
        target_loss = tf.reduce_mean(per_logit_loss)
        final_loss = target_loss + alpha * aux_loss
    print("[Train Step] aux_loss=" + str(aux_loss.numpy()) + ", target_loss=" + str(target_loss.numpy()) + ", final_loss=" + str(final_loss.numpy()))
    variables = model.trainable_variables
    grads = tape.gradient(final_loss, variables)
    clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)
    optimizer.apply_gradients(zip(clipped_grads, variables))
    loss_metric(final_loss)
19 |
def get_train_data(label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping, occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show):
    """Group one batch of raw feature tensors into the dicts DIEN expects.

    ``label`` is accepted for call-site symmetry with the data reader but is
    not used here.

    Returns:
        ``(user_profile_dict, user_profile_list, user_behavior_list,
        click_behavior_dict, noclick_behavior_dict, target_item_dict)``.
        The behaviour and target dicts are keyed by the behaviour feature
        names ("brand", "cate"), which also select the embedding table.
    """
    user_profile_dict = {
        "cms_segid": cms_segid,
        "cms_group": cms_group,
        "gender": gender,
        "age": age,
        "pvalue": pvalue,
        "shopping": shopping,
        "occupation": occupation,
        "user_class_level": user_class_level,
    }
    user_profile_list = ["cms_segid", "cms_group", "gender", "age", "pvalue", "shopping", "occupation", "user_class_level"]
    user_behavior_list = ["brand", "cate"]
    click_behavior_dict = {
        "brand": hist_brand_behavior_clk,
        "cate": hist_cate_behavior_clk,
    }
    noclick_behavior_dict = {
        "brand": hist_brand_behavior_show,
        "cate": hist_cate_behavior_show,
    }
    # Each key must carry its own feature: the key picks the embedding layer,
    # so the brand ids go under "brand" and the category ids under "cate".
    target_item_dict = {
        "brand": target_brand,
        "cate": target_cate,
    }
    return user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict, noclick_behavior_dict, target_item_dict
46 |
def main():
    """Train DIEN on the Alibaba data set in mini-batches of 100 examples.

    Builds the model from the data reader's feature metadata. Each epoch walks
    the training data batch by batch through ``train_one_step``. After each
    epoch the running sum of the final loss is written as a TensorBoard
    scalar under ``./train_log/``.
    """
    train_data, _test_data, embedding_count = data_reader.get_data()
    embedding_features_list = data_reader.get_embedding_features_list()
    user_behavior_features = data_reader.get_user_behavior_features()
    embedding_count_dict = data_reader.get_embedding_count_dict(embedding_features_list, embedding_count)
    embedding_dim_dict = data_reader.get_embedding_dim_dict(embedding_features_list)
    model = DIEN(embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features)
    batch = 100
    log_path = "./train_log/"
    train_summary_writer = tf.summary.create_file_writer(log_path)
    # `learning_rate` is the supported name of Adam's former `lr` argument;
    # `decay=0.0` was the default and is omitted.
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    # Running sum of the final loss over all steps so far.
    loss_metric = tf.keras.metrics.Sum()
    alpha = 1  # weight of the auxiliary loss in the final loss
    epochs = 1
    for epoch in range(epochs):
        # Batch offset; get_batch_data returns the value for the next call.
        min_batch = 0
        for _ in range(len(train_data) // batch):
            (label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue, shopping,
             occupation, user_class_level, hist_brand_behavior_clk, hist_cate_behavior_clk,
             hist_brand_behavior_show, hist_cate_behavior_show, min_batch, clk_length,
             show_length) = data_reader.get_batch_data(train_data, min_batch, batch=batch)
            (user_profile_dict, user_profile_list, user_behavior_list, click_behavior_dict,
             noclick_behavior_dict, target_item_dict) = get_train_data(
                label, target_cate, target_brand, cms_segid, cms_group, gender, age, pvalue,
                shopping, occupation, user_class_level, hist_brand_behavior_clk,
                hist_cate_behavior_clk, hist_brand_behavior_show, hist_cate_behavior_show)
            train_one_step(user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict,
                           noclick_behavior_dict, user_behavior_list, label, optimizer, model, alpha,
                           loss_metric)
        with train_summary_writer.as_default():
            tf.summary.scalar("train_final_loss_sum", loss_metric.result(), step=epoch)
71 |
72 |
if __name__ == "__main__":
    # Report the TensorFlow build and whether a GPU is visible, then train.
    print(tf.__version__)
    gpu_available = tf.test.is_gpu_available()
    print("GPU Available: ", gpu_available)
    main()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 | from layers import AUGRU,attention
4 | from activations import Dice,dice
5 | from loss import AuxLayer
6 | import utils
7 |
class DIEN(tf.keras.Model):
    """Deep Interest Evolution Network (DIEN).

    Pipeline:
    1. Embed the user-profile, target-item and click-behaviour features.
    2. Run a GRU over the click sequence; it is trained with an auxiliary
       next-click loss.
    3. Compute a softmax attention of the target item over the GRU states.
    4. Feed the states through an AUGRU driven by those attention scores.
    5. Pass the result through a fully connected head that emits two logits.
    """

    def __init__(self, embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation="PReLU"):
        """Create the embedding, GRU, attention, auxiliary, AUGRU and FC layers.

        Args:
            embedding_count_dict: str -> int; vocabulary size of every embedded
                feature, i.e. its largest integer index + 1.
            embedding_dim_dict: str -> int; embedding output size of every
                embedded feature.
            embedding_features_list: list[str]; every feature that gets an
                embedding layer. It must cover the profile and behaviour
                features looked up in ``call``.
            user_behavior_features: list[str]; behaviour/target-item features.
                Their embedding sizes are summed to give the GRU/AUGRU width.
            activation: activation added after each hidden FC layer:
                "PReLU" (default) -> PReLU, "Dice" -> Dice, "dice" -> dice.
                Any other value adds no extra activation layer.
        """
        # Subclassed model: nothing is forwarded to tf.keras.Model. Older Keras
        # ignored extra positional arguments; newer versions may reject them.
        super(DIEN, self).__init__()
        # Embedding layers, one per feature.
        self.embedding_dim_dict = embedding_dim_dict
        self.embedding_count_dict = embedding_count_dict
        self.embedding_layers = dict()
        for feature in embedding_features_list:
            self.embedding_layers[feature] = layers.Embedding(embedding_count_dict[feature], embedding_dim_dict[feature])
        gru_units = self.get_GRU_input_dim(embedding_dim_dict, user_behavior_features)
        # Interest-extraction GRU; returns one state per time step.
        self.user_behavior_gru = layers.GRU(gru_units, return_sequences=True)
        # Softmax over the last axis (time) of the target/state similarity.
        self.attention_layer = layers.Softmax()
        # Classifier for the auxiliary loss.
        self.AuxNet = AuxLayer()
        # Interest-evolution AUGRU.
        self.user_behavior_augru = AUGRU(gru_units)
        # FC head. Layer order is:
        #   BN, Dense(200), [act], Dense(80), [act], Dense(2)
        self.fc = tf.keras.Sequential()
        self.fc.add(layers.BatchNormalization())
        for units in (200, 80):
            self.fc.add(layers.Dense(units, activation="relu"))
            if activation == "Dice":
                self.fc.add(Dice())
            elif activation == "dice":
                self.fc.add(dice(units))
            elif activation == "PReLU":
                self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None))
        self.fc.add(layers.Dense(2, activation=None))

    def get_GRU_input_dim(self, embedding_dim_dict, user_behavior_features):
        """Return the summed embedding size of ``user_behavior_features``."""
        return utils.get_input_dim(embedding_dim_dict, user_behavior_features)

    def _embed(self, feature_names, feature_data):
        """Embed ``feature_data[name]`` with its own layer for each name, in order."""
        return {name: self.embedding_layers[name](feature_data[name]) for name in feature_names}

    def get_emb(self, user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list):
        """Embed the profile, target-item and click features.

        Each group's per-feature embeddings are concatenated along the last
        axis. ``noclick_behavior_dict`` is accepted but not used: the
        auxiliary loss only looks at clicks.

        Returns:
            ``(user_profile_embedding, target_item_embedding,
            click_behavior_embedding)``.
        """
        user_profile_feature_embedding = self._embed(user_profile_list, user_profile_dict)
        target_item_feature_embedding = self._embed(user_behavior_list, target_item_dict)
        click_behavior_embedding = self._embed(user_behavior_list, click_behavior_dict)
        return (utils.concat_features(user_profile_feature_embedding),
                utils.concat_features(target_item_feature_embedding),
                utils.concat_features(click_behavior_embedding))

    def auxiliary_loss(self, hidden_states, embedding_out, epsilon=1e-8):
        """Auxiliary loss: each GRU state should predict the next click.

        The reference code from the paper concatenates the hidden states with
        both the clicked and the shown (non-clicked) sequences, classifies
        each with an MLP + softmax, and takes the log loss. This version uses
        only the clicked sequence: it minimises ``-log p(click)`` for every
        (state, next click) pair.

        Args:
            hidden_states: GRU states h(0) .. h(n-1), ``[batch, T-1, H]``.
            embedding_out: GRU input embeddings e(1) .. e(n), ``[batch, T-1, D]``.
            epsilon: added to the probability before the log so that a
                probability of exactly 0 yields a finite loss.
        """
        click_input_ = tf.concat([hidden_states, embedding_out], -1)
        # `[..., 0]` instead of `[:, :, 0]` still works if the aux net drops a
        # size-1 batch/time axis; the reshape below restores [batch, T-1].
        click_prop_ = self.AuxNet(click_input_)[..., 0]
        click_loss_ = -tf.reshape(tf.math.log(click_prop_ + epsilon), [-1, tf.shape(embedding_out)[1]])
        return tf.reduce_mean(click_loss_)

    def call(self, user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list):
        """Run one forward pass on a batch.

        The caller combines ``logit`` and ``aux_loss`` into the training loss
        and applies the gradients.

        Args:
            user_profile_dict: str -> Tensor; user-profile inputs.
            user_profile_list: list[str]; user-profile feature names.
            click_behavior_dict: str -> Tensor; clicked-behaviour sequences.
            target_item_dict: str -> Tensor; target-item inputs.
            noclick_behavior_dict: str -> Tensor; shown-but-not-clicked
                sequences (currently unused).
            user_behavior_list: list[str]; behaviour feature names.

        Returns:
            ``(output, logit, aux_loss)``: softmax probabilities and logits of
            shape ``[batch, 2]``, plus the scalar auxiliary loss.
        """
        user_profile_embedding, target_item_embedding, click_behavior_emebedding = self.get_emb(
            user_profile_dict, user_profile_list, click_behavior_dict, target_item_dict, noclick_behavior_dict, user_behavior_list)
        # Interest extraction: one GRU state per clicked item.
        click_gru_emb = self.user_behavior_gru(click_behavior_emebedding)
        # State h(t) is paired with the next click e(t+1).
        aux_loss = self.auxiliary_loss(click_gru_emb[:, :-1, :], click_behavior_emebedding[:, 1:, :])
        # Target/state similarity -> [batch, 1, T], softmax over T (assuming
        # target_item_embedding is [batch, H]).
        hist_attn = self.attention_layer(tf.matmul(tf.expand_dims(target_item_embedding, 1), click_gru_emb, transpose_b=True))
        # Interest evolution: step the AUGRU through time (time-major views).
        augru_hidden_state = tf.zeros_like(click_gru_emb[:, 0, :])
        for in_emb, in_att in zip(tf.transpose(click_gru_emb, [1, 0, 2]), tf.transpose(hist_attn, [2, 0, 1])):
            augru_hidden_state = self.user_behavior_augru(in_emb, augru_hidden_state, in_att)
        join_emb = tf.concat([augru_hidden_state, user_profile_embedding], -1)
        # No tf.squeeze: it would drop the batch axis for a one-example batch.
        logit = self.fc(join_emb)
        output = tf.keras.activations.softmax(logit)
        return output, logit, aux_loss
134 |
class DIN(tf.keras.Model):
    """Deep Interest Network (DIN).

    Embeds the user-profile, target-item and behaviour-history features. The
    history is pooled with target attention, and a fully connected head
    classifies ``[profile, target, pooled history]`` into two logits.
    """

    def __init__(self, embedding_count_dict, embedding_dim_dict, embedding_features_list, user_behavior_features, activation="PReLU"):
        """Create the embedding, attention and FC layers.

        Args:
            embedding_count_dict: str -> int; vocabulary size (max index + 1)
                of every embedded feature.
            embedding_dim_dict: str -> int; embedding output size of every
                embedded feature.
            embedding_features_list: list[str]; every feature that gets an
                embedding layer.
            user_behavior_features: list[str]; behaviour/target-item features.
                Their summed embedding size is the attention key size.
            activation: activation after each hidden FC layer: "PReLU"
                (default), "Dice" or "dice". Any other value adds none.
        """
        # Subclassed model: nothing is forwarded to tf.keras.Model. Older Keras
        # ignored extra positional arguments; newer versions may reject them.
        super(DIN, self).__init__()
        # Embedding layers, one per feature.
        self.embedding_dim_dict = embedding_dim_dict
        self.embedding_count_dict = embedding_count_dict
        self.embedding_layers = dict()
        for feature in embedding_features_list:
            self.embedding_layers[feature] = layers.Embedding(embedding_count_dict[feature], embedding_dim_dict[feature])
        # DIN attention + sum pooling over the history.
        self.hist_at = attention(utils.get_input_dim(embedding_dim_dict, user_behavior_features))
        # FC head. Layer order is:
        #   BN, Dense(200), [act], Dense(80), [act], Dense(2)
        self.fc = tf.keras.Sequential()
        self.fc.add(layers.BatchNormalization())
        for units in (200, 80):
            self.fc.add(layers.Dense(units, activation="relu"))
            if activation == "Dice":
                self.fc.add(Dice())
            elif activation == "dice":
                self.fc.add(dice(units))
            elif activation == "PReLU":
                self.fc.add(layers.PReLU(alpha_initializer='zeros', weights=None))
        self.fc.add(layers.Dense(2, activation=None))

    def _embed(self, feature_names, feature_data):
        """Embed ``feature_data[name]`` with its own layer for each name, in order."""
        return {name: self.embedding_layers[name](feature_data[name]) for name in feature_names}

    def get_emb_din(self, user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list):
        """Embed the profile, target-item and history features.

        Each group's per-feature embeddings are concatenated along the last
        axis.

        Returns:
            ``(user_profile_embedding, target_item_embedding,
            hist_behavior_embedding)``.
        """
        user_profile_feature_embedding = self._embed(user_profile_list, user_profile_dict)
        target_item_feature_embedding = self._embed(user_behavior_list, target_item_dict)
        hist_behavior_embedding = self._embed(user_behavior_list, hist_behavior_dict)
        return (utils.concat_features(user_profile_feature_embedding),
                utils.concat_features(target_item_feature_embedding),
                utils.concat_features(hist_behavior_embedding))

    def call(self, user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list, length):
        """Run one forward pass.

        ``length`` is forwarded to the attention layer as the number of valid
        history steps per example.

        Returns:
            ``(output, logit)``: softmax probabilities and logits of shape
            ``[batch, 2]``.
        """
        user_profile_embedding, target_item_embedding, hist_behavior_emebedding = self.get_emb_din(
            user_profile_dict, user_profile_list, hist_behavior_dict, target_item_dict, user_behavior_list)
        hist_attn_emb = self.hist_at(target_item_embedding, hist_behavior_emebedding, length)
        join_emb = tf.concat([user_profile_embedding, target_item_embedding, hist_attn_emb], -1)
        # No tf.squeeze: it would drop the batch axis for a one-example batch.
        logit = self.fc(join_emb)
        output = tf.keras.activations.softmax(logit)
        return output, logit
194 |
if __name__ == "__main__":
    # Smoke check: build a DIN with no features to exercise the constructor.
    model = DIN({}, {}, [], [])
197 |
--------------------------------------------------------------------------------
/tensorboard.log:
--------------------------------------------------------------------------------
1 | nohup: ignoring input
2 | TensorBoard 2.0.0 at http://10.186.3.226:8028/ (Press CTRL+C to quit)
3 |
--------------------------------------------------------------------------------
/tensorboard.sh:
--------------------------------------------------------------------------------
# Serve the TensorBoard logs under ./train_log/din/ on a fixed host and port.
tensorboard \
  --logdir=./train_log/din/ \
  --host=10.186.3.226 \
  --port=8028
2 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import time

import tensorflow as tf
3 |
def get_file_name():
    """Return a loss-file name stamped with the current local time.

    Example: ``loss.csv.2020_01_31_23_59_59``.
    """
    stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    return f"loss.csv.{stamp}"
7 |
def make_train_loss_dir(file_name, cols=("train_aux_loss", "train_target_loss", "train_final_loss"), model="dien"):
    """Create (or truncate) the training-loss CSV and write its header row.

    Despite the name, this writes a file, ``./loss/<model>/train_<file_name>``.
    The ``./loss/<model>/`` directory must already exist.

    Args:
        file_name: suffix of the CSV name (e.g. from ``get_file_name()``).
        cols: header column names, written comma-separated. The default is a
            tuple so it cannot be mutated between calls.
        model: sub-directory of ``./loss/``.
    """
    with open("./loss/" + model + "/train_" + file_name, "w") as f:
        f.write(",".join(cols) + "\n")
12 |
def make_test_loss_dir(file_name, cols=("test_aux_loss", "test_target_loss", "test_final_loss"), model="dien"):
    """Create (or truncate) the test-loss CSV and write its header row.

    Despite the name, this writes a file, ``./loss/<model>/test_<file_name>``.
    The ``./loss/<model>/`` directory must already exist.

    Args:
        file_name: suffix of the CSV name (e.g. from ``get_file_name()``).
        cols: header column names, written comma-separated. The default is a
            tuple so it cannot be mutated between calls.
        model: sub-directory of ``./loss/``.
    """
    with open("./loss/" + model + "/test_" + file_name, "w") as f:
        f.write(",".join(cols) + "\n")
17 |
def add_loss(loss_dict, file_name, cols=("aux_loss", "target_loss", "final_loss"), level="train", model="dien"):
    """Append one row of loss values to ``./loss/<model>/<level>_<file_name>``.

    Args:
        loss_dict: maps each name in ``cols`` to its value. Values may be
            numbers or strings; each is written via ``str()``.
        file_name: suffix of the CSV name (as passed to ``make_*_loss_dir``).
        cols: keys of ``loss_dict`` to write, in column order.
        level: file-name prefix, e.g. ``"train"`` or ``"test"``.
        model: sub-directory of ``./loss/``.

    Raises:
        KeyError: if a name in ``cols`` is missing from ``loss_dict``.
    """
    # str() so numeric losses can be joined (str.join only accepts strings).
    row = ",".join(str(loss_dict[col]) for col in cols)
    with open("./loss/" + model + "/" + level + "_" + file_name, "a") as f:
        f.write(row + "\n")
25 |
def get_input_dim(embedding_dim_dict, user_behavior_features):
    """Return the total embedding size of the given features.

    Args:
        embedding_dim_dict: feature name -> embedding size.
        user_behavior_features: names whose sizes are added up.
    """
    return sum(embedding_dim_dict[name] for name in user_behavior_features)
31 |
def concat_features(feature_data_dict):
    """Concatenate every tensor in ``feature_data_dict`` along the last axis.

    Tensors are taken in the dict's iteration (insertion) order.
    """
    return tf.concat(list(feature_data_dict.values()), -1)
37 |
def mkdir(path):
    """Create ``path`` and any missing parents if it does not exist yet.

    Requires the module-level ``import os``. Without it, the old bare
    ``except`` hid the NameError and the function always returned 1.

    Returns:
        0 if ``path`` exists afterwards (or already existed), 1 if the
        filesystem refused to create it.
    """
    try:
        if not os.path.exists(path):
            # exist_ok guards against a concurrent creator winning the race.
            os.makedirs(path, exist_ok=True)
    except OSError:
        return 1
    return 0
--------------------------------------------------------------------------------