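├── Code-Statistics.zip
├── Code.zip
├── DE
├── NLP
│   └── ai_news_project
│       ├── NLP.ipynb
│       └── readme.md
├── README.md
├── Skip_gram.ipynb
├── Verizon.ipynb
├── gpt_labeled_data.csv
├── ml_final_project.ipynb
├── patent_classification_transformer_wandb.ipynb
└── test.ipynb

/Code-Statistics.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easonwangzk/UChicago/660c0f58516554220d042b01bdd18296ce167d41/Code-Statistics.zip
--------------------------------------------------------------------------------
/Code.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easonwangzk/UChicago/660c0f58516554220d042b01bdd18296ce167d41/Code.zip
--------------------------------------------------------------------------------
/DE:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/NLP/ai_news_project/readme.md:
--------------------------------------------------------------------------------
1 | # 🧭 Project Execution Plan: A Text-Mining Analysis of How AI Affects Industries and Employment
2 |
3 | ---
4 |
5 | ## 1. 🌐 Project Overview and Core Objectives
6 |
7 | This project analyzes **200,000 news articles related to AI, data science, and machine learning** in order to:
8 |
9 | 1. Identify AI's potential impact on different industries and job roles
10 | 2. Extract successful and failed AI application cases
11 | 3. Analyze AI technology trends and how sentiment evolves
12 | 4. Provide automation adoption recommendations for companies and organizations
13 |
14 | ---
15 |
16 | ## 2. 📅 Phased Execution Plan
17 |
18 | | Phase | Goal | Tools and Methods | Deliverables |
19 | |------|------|------------|--------|
20 | | 1️⃣ Data understanding and cleaning | Understand the data structure; remove noisy content | pandas, regular expressions, NLP preprocessing | Data summary table, cleaning-effect charts |
21 | | 2️⃣ Topic identification and industry mapping | Identify the main AI topics and their corresponding industries | BERTopic + zero-shot classification | Industry keyword heatmap, topic distribution chart |
22 | | 3️⃣ Sentiment and impact modeling | Judge whether AI applications had positive or negative effects | Self-supervised labeling + RoBERTa + wandb hyperparameter tuning | Sentiment trend chart, industry impact matrix |
23 | | 4️⃣ Technology evolution analysis | Analyze the emergence and spread of frontier AI technologies | Keyword trend analysis + temporal heatmap | Technology evolution timeline |
24 | | 5️⃣ Named entity and organization recognition | Extract participating institutions and investment activity | NER + Gradio demo | Investor network graph, list of successful companies |
25 | | 6️⃣ Integration and presentation of results | Present interactive, visual, actionable conclusions | PowerPoint + Gradio dashboard | PowerPoint report, Gradio tool page |
26 |
27 | ---
28 |
29 | ## 3. 🔧 Modeling Techniques in Detail (by Module)
30 |
31 | ### 1️⃣ Topic Modeling and Industry Mapping
32 |
33 | - **Methods**:
34 |   - Use **BERTopic** to automatically generate context-aware topics.
35 |   - Combine it with a **zero-shot classification** model (such as `facebook/bart-large-mnli`) to tag news text with industry labels (for example finance, healthcare, education).
36 |
37 | - **Notes**:
38 |   - Concatenating the title with the first 500 words can improve topic extraction accuracy.
39 |   - Use Top-N keyword coverage to show how well each topic represents the corpus.
40 |
41 | - **Visualization**:
42 |   - Heatmap of topic ↔ industry coverage
43 |   - Word clouds of the top terms for each topic
44 |
45 | ---
46 |
47 | ### 2️⃣ Sentiment Analysis and AI Effectiveness Assessment
48 |
49 | - **Label construction**:
50 |   - Select 300 to 500 sample articles and label them positive, negative, or neutral using GPT or human-assisted annotation
51 |   - Labeling scheme: positive = successful AI deployment, negative = causes job losses or controversy, neutral = purely technical news
52 |
53 | - **Model recommendations**:
54 |   - Fine-tune a lightweight `RoBERTa-base` model and log the training process with wandb.
55 |   - Add class-weight balancing so the neutral label does not dominate the model's predictions.
56 |
57 | - **Optimization tips**:
58 |   - Use text augmentation (such as EDA or back translation) to increase sample diversity.
59 |   - Use wandb for hyperparameter tuning and model version comparison.
60 |
61 | - **Outputs**:
62 |   - Sentiment distribution by industry (positive/neutral/negative)
63 |   - Sentiment trend line chart (by quarter or year)
64 |   - List of success/failure news summaries with the extracted reasons for the sentiment
65 |
66 | ---
67 |
68 | ### 3️⃣ Tracking Technology Evolution and Application Trends
69 |
70 | - **Suggested keywords**:
71 |   - "ChatGPT", "LLM", "Stable Diffusion", "AutoML", "Midjourney", "Generative AI", "Conversational AI", "LangChain", and so on
72 |
73 | - **Workflow**:
74 |   - Build keyword time series by counting monthly occurrences
75 |   - Analyze how they relate to the ratio of positive to negative sentiment
76 |
77 | - **Visualization**:
78 |   - Timeline of technology popularity trends
79 |   - Technology ↔ industry cross-association graph
80 |
81 | ---
82 |
83 | ### 4️⃣ Named Entity Recognition and Organizational Behavior Analysis
84 |
85 | - **Goal**: identify organizations (companies, universities, governments), people, locations, and other entities in the text
86 |
87 | - **Technical approach**:
88 |   - Extract entities with the NER components of `SpaCy` or `transformers`
89 |   - Aggregate by company to identify the organizations with the most positive coverage
90 |
91 | - **Analysis suggestions**:
92 |   - Build a table of "organization → technology → industry" triples
93 |   - Extract application case summaries for companies with a successful AI transformation
94 |
95 | - **Gradio visualization**:
96 |   - Build an interactive company-technology-sentiment graph
97 |   - Provide an interface that takes a keyword as input → returns industry sentiment and technology impact
98 |
99 | ---
100 |
101 | ## 4. 📈 Recommendations for Presenting with wandb and Gradio
102 |
103 | ### 🟣 wandb (visual hyperparameter tuning, experiment tracking)
104 |
105 | - Use it to:
106 |   - Log the sentiment model's training loss, accuracy, and F1 score
107 |   - Present comparison experiments across models and hyperparameter combinations
108 |   - Generate charts to embed in the slides (such as confusion matrices and ROC curves)
109 |
110 | - Example:
111 |   - `wandb.init(project="ai_news_impact")`
112 |   - Log loss, accuracy, and learning rate every epoch
113 |
114 | ---
115 |
116 | ### 🟢 Gradio (interactive NLP analysis platform)
117 |
118 | - Use it to build an interactive demo app:
119 |   - The user enters a news summary and gets back:
120 |     - Industry identification (top 3)
121 |     - Topic assignment
122 |     - Sentiment score with an explanation
123 |     - Whether it contains emerging-technology keywords
124 |
125 | - Suggested modules:
126 |   - "News Industry Predictor": input text → returns industry labels with confidence scores
127 |   - "AI Sentiment Radar": shows the sentiment distribution by industry
128 |   - "Technology Trend Explorer": input a keyword → returns related industries and article title summaries
129 |
130 | ---
131 |
132 | ## 5. 📢 Business Conclusions and Recommendation Directions
133 |
134 | | Type | Example recommendation |
135 | |------|----------|
136 | | Industry insight | Legal, education, customer service, and office automation are the most affected by AI; deploying assistive systems is recommended |
137 | | Success stories | Organizations such as Google, Microsoft, and OpenAI are driving AI adoption; their technology stacks and organizational models are worth learning from |
138 | | Failure warnings | In industries such as construction and logistics, operational complexity makes AI a poor fit; watch for the risk of misuse |
139 | | Investment advice | Governments can steer AI resources toward industries with high substitution potential to reduce structural unemployment |
140 | | Corporate strategy | Before implementing AI, strengthen data governance, cross-functional coordination, employee training, and process redesign |
141 |
142 | ---
143 |
144 | ## 6. 📘 Summary: Key Criteria for a Successful Delivery
145 |
146 | | Criterion | Met? |
147 | |------|----------|
148 | | Sound data processing | ✅ Thorough cleaning and exploratory analysis |
149 | | Appropriate use of NLP techniques | ✅ Covers topic modeling, NER, and sentiment modeling |
150 | | Sufficient visualization | ✅ wandb experiment logs + Gradio interactive tool |
151 | | Clear business value | ✅ Actionable recommendations backed by cases |
152 | | Clear, well-structured communication | ✅ Slides with both text and visuals, balancing technical and business content |
153 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UChicago
2 | Chicago Project
3 |
--------------------------------------------------------------------------------
/Skip_gram.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "machine_shape": "hm", 8 | "gpuType": "T4", 9 | "authorship_tag": "ABX9TyMNQ04t0z15woNIXBmdOC4W",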
10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "source": [ 35 | "import os\n", 36 | "import re\n", 37 | "import json\n", 38 | "import torch\n", 39 | "import random\n", 40 | "import requests\n", 41 | "import numpy as np\n", 42 | "from torch import nn, optim\n", 43 | "from torch.utils.data import Dataset, DataLoader\n", 44 | "from collections import Counter, OrderedDict\n", 45 | "from itertools import chain\n", 46 | "from typing import List" 47 | ], 48 | "metadata": { 49 | "id": "0gTrfO1Grc0n" 50 | }, 51 | "execution_count": 1, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "source": [ 57 | "## Device setup" 58 | ], 59 | "metadata": { 60 | "id": "Khj_Sy7Mr5Xw" 61 | } 62 | }, 63 | { 64 | "cell_type": "code", 65 | "source": [ 66 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 67 | "print(f\"Using {device} as device.\")" 68 | ], 69 | "metadata": { 70 | "colab": { 71 | "base_uri": "https://localhost:8080/" 72 | }, 73 | "id": "2Q7kTaujr8nL", 74 | "outputId": "5481db4e-92e9-4b5b-c43d-4df819678660" 75 | }, 76 | "execution_count": 2, 77 | "outputs": [ 78 | { 79 | "output_type": "stream", 80 | "name": "stdout", 81 | "text": [ 82 | "Using cuda as device.\n" 83 | ] 84 | } 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "source": [ 90 | "## Configs" 91 | ], 92 | "metadata": { 93 | "id": "WPyPU2egr_m_" 94 | } 95 | }, 96 | { 97 | "cell_type": "code", 98 | "source": [ 99 | "data_dir = \"./data\"\n", 100 | "model_dir = \"./models\"\n", 101 | "debug = True\n", 102 | "\n", 103 | "if debug:\n", 104 | "    CONTEXT_WINDOW = 2\n", 105 | "    EMBEDDING_SIZE = 5\n", 106 | "    MIN_FREQ = 5\n", 107 | "    BATCH_SIZE = 3\n", 108 | "    N_EPOCHS = 1\n", 109 | "else:\n", 110 | "    CONTEXT_WINDOW = 4\n", 111 | "    EMBEDDING_SIZE = 100\n", 112 | "    MIN_FREQ = 25\n", 113 | "    BATCH_SIZE = 64\n", 114 | "    N_EPOCHS = 3" 115 | ], 116 | "metadata": { 117 | "id": "TkfrEtnysDEe" 118 | }, 119 | "execution_count": 3, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "source": [ 125 | "## Create dirs" 126 | ], 127 | "metadata": { 128 | "id": "J5GwXnhWsSIP" 129 | } 130 | }, 131 | { 132 | "cell_type": "code", 133 | "source": [ 134 | "os.makedirs(data_dir, exist_ok=True)\n", 135 | "os.makedirs(model_dir, exist_ok=True)" 136 | ], 137 | "metadata": { 138 | "id": "PGFbWq1ssVlu" 139 | }, 140 | "execution_count": 4, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "source": [ 146 | "## Download and tokenize text" 147 | ], 148 | "metadata": { 149 | "id": "UFa7y2nhsYD_" 150 | } 151 | }, 152 | { 153 | "cell_type": "code", 154 | "source": [ 155 | "url = \"https://www.gutenberg.org/cache/epub/7370/pg7370.txt\"\n", 156 | "response = requests.get(url)\n", 157 | "raw_text = response.text.lower()\n", 158 | "raw_text = re.sub(r'[^a-z\\s]', '', raw_text)\n", 159 | "sentences = [line.split() for line in raw_text.split('\\n') if line.strip()]\n", 160 | "print(f\"Number of sentences: {len(sentences):,}\")" 161 | ], 162 | "metadata": { 163 | "colab": { 164 | "base_uri": "https://localhost:8080/" 165 | }, 166 | "id": "DussqJ0MsagF", 167 | "outputId": "3d9ec680-dcd4-47da-d2cc-a0353c615d1d" 168 | }, 169 | "execution_count": 5, 170 | "outputs": [ 171 | { 172 | "output_type": "stream", 173 | "name": "stdout", 174 | "text": [ 175 | "Number of sentences: 4,957\n" 176 | ] 177 | } 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "source": [ 183 | "\n", 184 | "## Vocabulary class (same as CBOW version)" 185 | ], 186 | "metadata": { 187 | "id": "G6e44DWrsc9W" 188 | } 189 | }, 190 | { 191 | "cell_type": "code", 192 | "source": [ 193 | "class Vocab:\n", 194 | "    def __init__(self, word_counts: OrderedDict, min_freq: int = 1, max_size: int = None, specials: List[str] = None, unk_token: str = \"<unk>\"):\n", 195 | "        self.word_counts = word_counts\n", 196 | "        self.min_freq = min_freq\n", 197 | "        self.max_size = max_size\n", 198 | "        self.unk_token = unk_token\n", 199 | "        self.specials = list(specials) if specials else []\n", 200 | "\n", 201 | "        if self.unk_token not in self.specials:\n", 202 | "            self.specials.insert(0, self.unk_token)\n", 203 | "\n", 204 | "        self.token2idx = {}\n", 205 | "        self.idx2token = []\n", 206 | "        self._prepare_vocab()\n", 207 | "\n", 208 | "    def __len__(self):\n", 209 | "        return len(self.idx2token)\n", 210 | "\n", 211 | "    def __contains__(self, value):\n", 212 | "        return value in self.idx2token\n", 213 | "\n", 214 | "    def _prepare_vocab(self):\n", 215 | "        vocab_list = self.specials.copy()\n", 216 | "        filtered_words = [word for word, freq in self.word_counts.items() if freq >= self.min_freq and word not in self.specials]\n", 217 | "        if self.max_size is not None:\n", 218 | "            filtered_words = filtered_words[:self.max_size - len(self.specials)]\n", 219 | "        vocab_list.extend(filtered_words)\n", 220 | "        self.idx2token = vocab_list\n", 221 | "        self.token2idx = {word: idx for idx, word in enumerate(vocab_list)}\n", 222 | "\n", 223 | "    def get_token(self, idx: int) -> str:\n", 224 | "        return self.idx2token[idx] if 0 <= idx < len(self.idx2token) else self.unk_token\n", 225 | "\n", 226 | "    def get_index(self, token: str) -> int:\n", 227 | "        return self.token2idx.get(token, self.token2idx[self.unk_token])\n", 228 | "\n", 229 | "    def get_tokens(self, indices: List[int]) -> List[str]:\n", 230 | "        return [self.get_token(idx) for idx in indices]\n", 231 | "\n", 232 | "    def get_indices(self, tokens: List[str]) -> List[int]:\n", 233 | "        return [self.get_index(token) for token in tokens]" 234 | ], 235 | "metadata": { 236 | "id": "x9gV2NoasfnN" 237 | }, 238 | "execution_count": 6, 239 | "outputs": []
240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "source": [ 244 | "## Padding for skip-gram" 245 | ], 246 | "metadata": { 247 | "id": "PES9ALTWsmXd" 248 | } 249 | }, 250 | { 251 | "cell_type": "code", 252 | "source": [ 253 | "def pad_sentences(sentences: List[List[str]], context_length: int, pad_token: str = \"<pad>\") -> List[List[str]]:\n", 254 | "    padded_sentences = []\n", 255 | "    for sentence in sentences:\n", 256 | "        padded_sentence = [pad_token] * context_length + sentence + [pad_token] * context_length\n", 257 | "        padded_sentences.append(padded_sentence)\n", 258 | "    return padded_sentences" 259 | ], 260 | "metadata": { 261 | "id": "z6Cuo7aispl7" 262 | }, 263 | "execution_count": 7, 264 | "outputs": [] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "source": [ 269 | "sentences = pad_sentences(sentences, CONTEXT_WINDOW)" 270 | ], 271 | "metadata": { 272 | "id": "6bunxvYPs0w5" 273 | }, 274 | "execution_count": 8, 275 | "outputs": [] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "source": [ 280 | "## Build vocab" 281 | ], 282 | "metadata": { 283 | "id": "OqfQctNksswi" 284 | } 285 | }, 286 | { 287 | "cell_type": "code", 288 | "source": [ 289 | "vocab = Vocab(OrderedDict(Counter(chain.from_iterable(sentences))), min_freq=MIN_FREQ, specials=[\"<pad>\"])\n", 290 | "print(f\"Size of Vocabulary: {len(vocab):,}\")" 291 | ], 292 | "metadata": { 293 | "colab": { 294 | "base_uri": "https://localhost:8080/" 295 | }, 296 | "id": "1nPBlky7s4D1", 297 | "outputId": "cedabe98-c965-4814-e8a2-77ac4a2ce821" 298 | }, 299 | "execution_count": 9, 300 | "outputs": [ 301 | { 302 | "output_type": "stream", 303 | "name": "stdout", 304 | "text": [ 305 | "Size of Vocabulary: 1,158\n" 306 | ] 307 | } 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "source": [ 313 | "\n", 314 | "## Skip-gram pair generation" 315 | ], 316 | "metadata": { 317 | "id": "xAXtdfags6ok" 318 | } 319 | }, 320 | { 321 | "cell_type": "code", 322 | "source": [ 323 | "def
generate_skipgram_pairs(sentences: List[List[str]], context_length: int, vocab: Vocab):\n", 324 | "    inputs = []\n", 325 | "    outputs = []\n", 326 | "    for sentence in sentences:\n", 327 | "        encoded = vocab.get_indices(sentence)\n", 328 | "        for center_idx in range(context_length, len(encoded) - context_length):\n", 329 | "            center_word = encoded[center_idx]\n", 330 | "            context = encoded[center_idx - context_length : center_idx] + encoded[center_idx + 1 : center_idx + context_length + 1]\n", 331 | "            for context_word in context:\n", 332 | "                inputs.append(center_word)\n", 333 | "                outputs.append(context_word)\n", 334 | "    return torch.tensor(inputs), torch.tensor(outputs)\n", 335 | "\n", 336 | "inputs, outputs = generate_skipgram_pairs(sentences, CONTEXT_WINDOW, vocab)\n", 337 | "print(f\"Number of training examples: {len(inputs):,}\")" 338 | ], 339 | "metadata": { 340 | "colab": { 341 | "base_uri": "https://localhost:8080/" 342 | }, 343 | "id": "x51_SV14s9yY", 344 | "outputId": "72e5b9db-71a2-4192-858b-dd553829b95d" 345 | }, 346 | "execution_count": 10, 347 | "outputs": [ 348 | { 349 | "output_type": "stream", 350 | "name": "stdout", 351 | "text": [ 352 | "Number of training examples: 236,704\n" 353 | ] 354 | } 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "source": [ 360 | "## Dataset class" 361 | ], 362 | "metadata": { 363 | "id": "lqnhGseutA1j" 364 | } 365 | }, 366 | { 367 | "cell_type": "code", 368 | "source": [ 369 | "class SkipGramDataset(Dataset):\n", 370 | "    def __init__(self, inputs, targets):\n", 371 | "        self.inputs = inputs\n", 372 | "        self.targets = targets\n", 373 | "\n", 374 | "    def __len__(self):\n", 375 | "        return len(self.inputs)\n", 376 | "\n", 377 | "    def __getitem__(self, idx):\n", 378 | "        return self.inputs[idx], self.targets[idx]" 379 | ], 380 | "metadata": { 381 | "id": "0HG8k7KmtDG-" 382 | }, 383 | "execution_count": 11, 384 | "outputs": [] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "source": [ 389 | "## Skip-gram model"
390 | ], 391 | "metadata": { 392 | "id": "UmZDRhZStFc1" 393 | } 394 | }, 395 | { 396 | "cell_type": "code", 397 | "source": [ 398 | "class SkipGram(nn.Module):\n", 399 | "    def __init__(self, vocab_size, embedding_dim):\n", 400 | "        super().__init__()\n", 401 | "        self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", 402 | "        self.linear = nn.Linear(embedding_dim, vocab_size)\n", 403 | "\n", 404 | "    def forward(self, center_words):\n", 405 | "        embeds = self.embeddings(center_words)\n", 406 | "        out = self.linear(embeds)\n", 407 | "        return out\n", 408 | "\n", 409 | "    def debug_forward(self, center_words):\n", 410 | "        embeds = self.embeddings(center_words)\n", 411 | "        print(\"\\nembeddings shape:\", embeds.shape)\n", 412 | "        print(embeds)\n", 413 | "        out = self.linear(embeds)\n", 414 | "        print(\"\\nlogits shape:\", out.shape)\n", 415 | "        print(out)\n", 416 | "        return out" 417 | ], 418 | "metadata": { 419 | "id": "hGxDDaIUtFL2" 420 | }, 421 | "execution_count": 12, 422 | "outputs": [] 423 | }, 424 | { 425 | "cell_type": "markdown", 426 | "source": [ 427 | "## Instantiate model" 428 | ], 429 | "metadata": { 430 | "id": "XmXRsNBStNnC" 431 | } 432 | }, 433 | { 434 | "cell_type": "code", 435 | "source": [ 436 | "model = SkipGram(vocab_size=len(vocab), embedding_dim=EMBEDDING_SIZE).to(device)\n", 437 | "print(model)" 438 | ], 439 | "metadata": { 440 | "colab": { 441 | "base_uri": "https://localhost:8080/" 442 | }, 443 | "id": "U2-3V0setNZK", 444 | "outputId": "61630f59-9743-47d6-bb91-668785ae2256" 445 | }, 446 | "execution_count": 13, 447 | "outputs": [ 448 | { 449 | "output_type": "stream", 450 | "name": "stdout", 451 | "text": [ 452 | "SkipGram(\n", 453 | "  (embeddings): Embedding(1158, 5)\n", 454 | "  (linear): Linear(in_features=5, out_features=1158, bias=True)\n", 455 | ")\n" 456 | ] 457 | } 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "source": [ 463 | "## Loss and optimizer" 464 | ], 465 | "metadata": { 466 | "id": "9CofeFdWtTGL" 467 | } 468 | },
469 | { 470 | "cell_type": "code", 471 | "source": [ 472 | "criterion = nn.CrossEntropyLoss(ignore_index=vocab.get_index(vocab.unk_token))\n", 473 | "optimizer = optim.Adam(model.parameters(), lr=0.001)" 474 | ], 475 | "metadata": { 476 | "id": "RE_giWxqtV_S" 477 | }, 478 | "execution_count": 14, 479 | "outputs": [] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "source": [ 484 | "## Dataloader" 485 | ], 486 | "metadata": { 487 | "id": "Wa0WOIABtZnk" 488 | } 489 | }, 490 | { 491 | "cell_type": "code", 492 | "source": [ 493 | "dataset = SkipGramDataset(inputs, outputs)\n", 494 | "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)" 495 | ], 496 | "metadata": { 497 | "id": "356_6QCTtb1U" 498 | }, 499 | "execution_count": 15, 500 | "outputs": [] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "source": [ 505 | "## Training loop" 506 | ], 507 | "metadata": { 508 | "id": "3xJ6qhMUteOJ" 509 | } 510 | }, 511 | { 512 | "cell_type": "code", 513 | "source": [ 514 | "for epoch in range(N_EPOCHS):\n", 515 | "    total_loss = 0\n", 516 | "    for batch_inputs, batch_outputs in dataloader:\n", 517 | "        batch_inputs, batch_outputs = batch_inputs.to(device), batch_outputs.to(device)\n", 518 | "\n", 519 | "        optimizer.zero_grad()\n", 520 | "        if debug:\n", 521 | "            predictions = model.debug_forward(batch_inputs)\n", 522 | "        else:\n", 523 | "            predictions = model(batch_inputs)\n", 524 | "\n", 525 | "        loss = criterion(predictions, batch_outputs)\n", 526 | "        loss.backward()\n", 527 | "        optimizer.step()\n", 528 | "        total_loss += loss.item()\n", 529 | "\n", 530 | "        if debug: break\n", 531 | "    if debug: break\n", 532 | "    print(f\"Epoch {epoch+1}/{N_EPOCHS}, Loss: {total_loss/len(dataloader):.4f}\")" 533 | ], 534 | "metadata": { 535 | "colab": { 536 | "base_uri": "https://localhost:8080/" 537 | }, 538 | "id": "Dh1QxsmotgQK", 539 | "outputId": "e6ed8cb9-adc9-4ea0-a0d8-2960f7401f48" 540 | }, 541 | "execution_count": 16, 542 | "outputs": [ 543 | { 544 | "output_type": "stream", 545 | "name": "stdout", 546 | "text": [ 547 | "\n", 548 | "embeddings shape: torch.Size([3, 5])\n", 549 | "tensor([[ 0.0099,  0.8007, -0.2172, -1.7865, -0.1345],\n", 550 | "        [-0.1325, -1.2426, -0.1149,  1.1431,  0.3546],\n", 551 | "        [-2.8135,  0.0679,  0.0196, -0.9808,  0.5849]], device='cuda:0',\n", 552 | "       grad_fn=<EmbeddingBackward0>)\n", 553 | "\n", 554 | "logits shape: torch.Size([3, 1158])\n", 555 | "tensor([[ 0.6994,  0.6161,  0.1455,  ...,  0.0060,  0.0517,  0.4258],\n", 556 | "        [-0.0163,  0.1403, -0.6382,  ..., -0.2566, -0.6915, -0.8407],\n", 557 | "        [-0.2563, -0.1448, -0.0099,  ..., -0.1927,  0.5144,  0.6950]],\n", 558 | "       device='cuda:0', grad_fn=<AddmmBackward0>)\n" 559 | ] 560 | } 561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "source": [ 566 | "## Save trained model weights and vocab" 567 | ], 568 | "metadata": { 569 | "id": "o-0K3JTiu3_V" 570 | } 571 | }, 572 | { 573 | "cell_type": "code", 574 | "source": [ 575 | "import torch.nn.functional as F\n", 576 | "import pickle" 577 | ], 578 | "metadata": { 579 | "id": "fKQPx6ODu_Ar" 580 | }, 581 | "execution_count": 18, 582 | "outputs": [] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "source": [ 587 | "torch.save(model.embeddings.weight.data, f\"{model_dir}/weights.pt\")\n", 588 | "with open(f\"{model_dir}/vocab.pkl\", \"wb\") as f:\n", 589 | "    pickle.dump(vocab, f)" 590 | ], 591 | "metadata": { 592 | "id": "Bo4Pzav2u8hj" 593 | }, 594 | "execution_count": 19, 595 | "outputs": [] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "source": [ 600 | "\n", 601 | "## Define function to compute closest words" 602 | ], 603 | "metadata": { 604 | "id": "2BNSog-9vBrx" 605 | } 606 | }, 607 | { 608 | "cell_type": "code", 609 | "source": [ 610 | "def closest_words(embeddings, vocab, word, n=10):\n", 611 | "    if word not in vocab.token2idx:\n", 612 | "        raise ValueError(f\"'{word}' not in vocabulary\")\n", 613 | "\n", 614 | "    word_idx = vocab.get_index(word)\n", 615 | "    word_embedding = embeddings[word_idx]\n", 616 | "\n", 617 | "    similarities = F.cosine_similarity(word_embedding.unsqueeze(0), embeddings, dim=1)\n", 618 | "    similarities[word_idx] = -1  # exclude itself\n", 619 | "\n", 620 | "    top_indices = similarities.topk(n).indices\n", 621 | "    return [(vocab.get_token(idx), similarities[idx].item()) for idx in top_indices.tolist()]" 622 | ], 623 | "metadata": { 624 | "id": "Fjt94vcavFkG" 625 | }, 626 | "execution_count": 20, 627 | "outputs": [] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "source": [ 632 | "if torch.cuda.is_available():\n", 633 | "    loaded_embeddings = torch.load(f\"{model_dir}/weights.pt\", weights_only=True)\n", 634 | "else:\n", 635 | "    loaded_embeddings = torch.load(f\"{model_dir}/weights.pt\", weights_only=True, map_location=torch.device(\"cpu\"))\n", 636 | "\n", 637 | "with open(f\"{model_dir}/vocab.pkl\", \"rb\") as f:\n", 638 | "    loaded_vocab = pickle.load(f)" 639 | ], 640 | "metadata": { 641 | "id": "l3K0NsZKvOeQ" 642 | }, 643 | "execution_count": 21, 644 | "outputs": [] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "source": [ 649 | "## Run similarity search" 650 | ], 651 | "metadata": { 652 | "id": "FtVJ112bvTtm" 653 | } 654 | }, 655 | { 656 | "cell_type": "code", 657 | "source": [ 658 | "print(\"Trained model:\")\n", 659 | "print(closest_words(embeddings=loaded_embeddings, vocab=loaded_vocab, word=\"love\", n=10))" 660 | ], 661 | "metadata": { 662 | "colab": { 663 | "base_uri": "https://localhost:8080/" 664 | }, 665 | "id": "dMRGNCodvX9h", 666 | "outputId": "cf846faa-9b2f-473e-fc9d-4fd1850d8b8a" 667 | }, 668 | "execution_count": 22, 669 | "outputs": [ 670 | { 671 | "output_type": "stream", 672 | "name": "stdout", 673 | "text": [ 674 | "Trained model:\n", 675 | "[('was', 0.9637019038200378), ('make', 0.905871570110321), ('judge', 0.9049926996231079), ('body', 0.9028736352920532), ('brought', 0.8975038528442383), ('minority', 0.8772290945053101), ('wise', 0.8752244710922241), ('kingdom', 0.8742192387580872), ('grown', 0.8639228940010071), ('food',
0.8584514856338501)]\n" 676 | ] 677 | } 678 | ] 679 | }, 680 | { 681 | "cell_type": "markdown", 682 | "source": [ 683 | "\n", 684 | "## Compare with untrained model" 685 | ], 686 | "metadata": { 687 | "id": "9g12zcLpviU4" 688 | } 689 | }, 690 | { 691 | "cell_type": "code", 692 | "source": [ 693 | "model_untrained = SkipGram(vocab_size=len(vocab), embedding_dim=EMBEDDING_SIZE)\n", 694 | "untrained_embeddings = model_untrained.embeddings.weight.data\n", 695 | "\n", 696 | "print(\"\\nUntrained model:\")\n", 697 | "print(closest_words(embeddings=untrained_embeddings, vocab=vocab, word=\"love\", n=10))" 698 | ], 699 | "metadata": { 700 | "colab": { 701 | "base_uri": "https://localhost:8080/" 702 | }, 703 | "id": "zQR_iEdevgGX", 704 | "outputId": "d9d3ef2b-d17b-4533-8b9e-817c4fd29ca9" 705 | }, 706 | "execution_count": 23, 707 | "outputs": [ 708 | { 709 | "output_type": "stream", 710 | "name": "stdout", 711 | "text": [ 712 | "\n", 713 | "Untrained model:\n", 714 | "[('or', 0.9682488441467285), ('draw', 0.9493830800056458), ('agreed', 0.9472517371177673), ('paragraph', 0.9458328485488892), ('many', 0.9302278161048889), ('grants', 0.928244411945343), ('measure', 0.9108309149742126), ('understood', 0.9096970558166504), ('freemen', 0.9056882262229919), ('power', 0.89534991979599)]\n" 715 | ] 716 | } 717 | ] 718 | } 719 | ] 720 | } -------------------------------------------------------------------------------- /Verizon.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Data Dictionary**" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "- enrolldt : Enrollment Date
\n", 15 | "- price: Membership price
\n", 16 | "- downpmt: Downpayment
\n", 17 | "- monthdue: Months Due
\n", 18 | "- pmttype: Payment Type (1: Credit Card, 3: Cash, 4: Check, 5: Debit Card)
\n", 19 | "- use: Usage
\n", 20 | "- age: Age of customer
\n", 21 | "- gender: Gender of customer (1: Male, 2: Female)
\n", 22 | "- default: 1: Default, 0: Non-Default
" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | " We have used Logistic Regression, Decision Tree and Random Forest to predict if a customer will default or not. Try 2 other models of your choice and evaluate them on 2 metrics of your choice that we haven't used so far" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": { 43 | "tags": [] 44 | }, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "Epoch 1/100\n" 51 | ] 52 | }, 53 | { 54 | "name": "stderr", 55 | "output_type": "stream", 56 | "text": [ 57 | "/Users/easonwang/anaconda3/lib/python3.11/site-packages/keras/src/layers/core/input_layer.py:25: UserWarning: Argument `input_shape` is deprecated. Use `shape` instead.\n", 58 | " warnings.warn(\n" 59 | ] 60 | }, 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 728us/step - accuracy: 0.5867 - loss: 1.4448 - val_accuracy: 0.8344 - val_loss: -0.2755\n", 66 | "Epoch 2/100\n", 67 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 388us/step - accuracy: 0.7738 - loss: -0.4064 - val_accuracy: 0.8469 - val_loss: -0.9330\n", 68 | "Epoch 3/100\n", 69 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 369us/step - accuracy: 0.8351 - loss: -0.8675 - val_accuracy: 0.8532 - val_loss: -1.0668\n", 70 | "Epoch 4/100\n", 71 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 376us/step - accuracy: 0.8562 - loss: -1.2058 - val_accuracy: 0.8628 - val_loss: -1.1302\n", 72 | "Epoch 5/100\n", 73 | 
"\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8619 - loss: -1.2597 - val_accuracy: 0.8687 - val_loss: -1.1850\n", 74 | "Epoch 6/100\n", 75 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8689 - loss: -1.4898 - val_accuracy: 0.8631 - val_loss: -1.1925\n", 76 | "Epoch 7/100\n", 77 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 369us/step - accuracy: 0.8687 - loss: -1.5509 - val_accuracy: 0.8620 - val_loss: -1.2102\n", 78 | "Epoch 8/100\n", 79 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 367us/step - accuracy: 0.8715 - loss: -1.5750 - val_accuracy: 0.8694 - val_loss: -1.2491\n", 80 | "Epoch 9/100\n", 81 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8730 - loss: -1.5086 - val_accuracy: 0.8635 - val_loss: -1.2463\n", 82 | "Epoch 10/100\n", 83 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8723 - loss: -1.5804 - val_accuracy: 0.8709 - val_loss: -1.2653\n", 84 | "Epoch 11/100\n", 85 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 384us/step - accuracy: 0.8754 - loss: -1.5467 - val_accuracy: 0.8723 - val_loss: -1.2777\n", 86 | "Epoch 12/100\n", 87 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 385us/step - accuracy: 0.8775 - loss: -1.4726 - val_accuracy: 0.8657 - val_loss: -1.2636\n", 88 | "Epoch 13/100\n", 89 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 388us/step - accuracy: 0.8765 - loss: -1.6152 - val_accuracy: 0.8683 - 
val_loss: -1.2993\n", 90 | "Epoch 14/100\n", 91 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 391us/step - accuracy: 0.8748 - loss: -1.6811 - val_accuracy: 0.8687 - val_loss: -1.2901\n", 92 | "Epoch 15/100\n", 93 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 392us/step - accuracy: 0.8763 - loss: -1.5232 - val_accuracy: 0.8734 - val_loss: -1.2932\n", 94 | "Epoch 16/100\n", 95 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 381us/step - accuracy: 0.8785 - loss: -1.5563 - val_accuracy: 0.8701 - val_loss: -1.3364\n", 96 | "Epoch 17/100\n", 97 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 381us/step - accuracy: 0.8797 - loss: -1.6318 - val_accuracy: 0.8742 - val_loss: -1.3272\n", 98 | "Epoch 18/100\n", 99 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 373us/step - accuracy: 0.8778 - loss: -1.6541 - val_accuracy: 0.8723 - val_loss: -1.3383\n", 100 | "Epoch 19/100\n", 101 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8765 - loss: -1.5488 - val_accuracy: 0.8701 - val_loss: -1.3394\n", 102 | "Epoch 20/100\n", 103 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 389us/step - accuracy: 0.8795 - loss: -1.6604 - val_accuracy: 0.8738 - val_loss: -1.3623\n", 104 | "Epoch 21/100\n", 105 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 431us/step - accuracy: 0.8810 - loss: -1.7610 - val_accuracy: 0.8771 - val_loss: -1.3651\n", 106 | "Epoch 22/100\n", 107 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 381us/step 
- accuracy: 0.8819 - loss: -1.6975 - val_accuracy: 0.8727 - val_loss: -1.3354\n", 108 | "Epoch 23/100\n", 109 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 382us/step - accuracy: 0.8789 - loss: -1.6139 - val_accuracy: 0.8734 - val_loss: -1.3569\n", 110 | "Epoch 24/100\n", 111 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 372us/step - accuracy: 0.8805 - loss: -1.5595 - val_accuracy: 0.8709 - val_loss: -1.3493\n", 112 | "Epoch 25/100\n", 113 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 369us/step - accuracy: 0.8761 - loss: -1.5293 - val_accuracy: 0.8767 - val_loss: -1.3613\n", 114 | "Epoch 26/100\n", 115 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 377us/step - accuracy: 0.8796 - loss: -1.7249 - val_accuracy: 0.8745 - val_loss: -1.3669\n", 116 | "Epoch 27/100\n", 117 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 381us/step - accuracy: 0.8760 - loss: -1.6421 - val_accuracy: 0.8749 - val_loss: -1.3583\n", 118 | "Epoch 28/100\n", 119 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 386us/step - accuracy: 0.8810 - loss: -1.5780 - val_accuracy: 0.8720 - val_loss: -1.3689\n", 120 | "Epoch 29/100\n", 121 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 382us/step - accuracy: 0.8800 - loss: -1.5343 - val_accuracy: 0.8698 - val_loss: -1.3474\n", 122 | "Epoch 30/100\n", 123 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 381us/step - accuracy: 0.8705 - loss: -1.5743 - val_accuracy: 0.8716 - val_loss: -1.3756\n", 124 | "Epoch 31/100\n", 125 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 372us/step - accuracy: 0.8817 - loss: -1.7087 - val_accuracy: 0.8672 - val_loss: -1.3541\n", 126 | "Epoch 32/100\n", 127 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8721 - loss: -1.6025 - val_accuracy: 0.8745 - val_loss: -1.3754\n", 128 | "Epoch 33/100\n", 129 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 367us/step - accuracy: 0.8782 - loss: -1.7282 - val_accuracy: 0.8727 - val_loss: -1.3590\n", 130 | "Epoch 34/100\n", 131 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 367us/step - accuracy: 0.8804 - loss: -1.7086 - val_accuracy: 0.8698 - val_loss: -1.3470\n", 132 | "Epoch 35/100\n", 133 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 369us/step - accuracy: 0.8810 - loss: -1.5826 - val_accuracy: 0.8764 - val_loss: -1.4009\n", 134 | "Epoch 36/100\n", 135 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 367us/step - accuracy: 0.8822 - loss: -1.6581 - val_accuracy: 0.8756 - val_loss: -1.3912\n", 136 | "Epoch 37/100\n", 137 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8791 - loss: -1.6491 - val_accuracy: 0.8764 - val_loss: -1.3963\n", 138 | "Epoch 38/100\n", 139 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 364us/step - accuracy: 0.8828 - loss: -1.6845 - val_accuracy: 0.8709 - val_loss: -1.3626\n", 140 | "Epoch 39/100\n", 141 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8765 - loss: -1.7465 - val_accuracy: 0.8709 - 
val_loss: -1.3792\n", 142 | "Epoch 40/100\n", 143 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 364us/step - accuracy: 0.8796 - loss: -1.6952 - val_accuracy: 0.8742 - val_loss: -1.3875\n", 144 | "Epoch 41/100\n", 145 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8781 - loss: -1.7366 - val_accuracy: 0.8775 - val_loss: -1.3883\n", 146 | "Epoch 42/100\n", 147 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 365us/step - accuracy: 0.8841 - loss: -1.6870 - val_accuracy: 0.8705 - val_loss: -1.3506\n", 148 | "Epoch 43/100\n", 149 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8820 - loss: -1.6876 - val_accuracy: 0.8712 - val_loss: -1.3620\n", 150 | "Epoch 44/100\n", 151 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 428us/step - accuracy: 0.8777 - loss: -1.7897 - val_accuracy: 0.8753 - val_loss: -1.3851\n", 152 | "Epoch 45/100\n", 153 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 388us/step - accuracy: 0.8832 - loss: -1.7775 - val_accuracy: 0.8727 - val_loss: -1.3633\n", 154 | "Epoch 46/100\n", 155 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 386us/step - accuracy: 0.8853 - loss: -1.6981 - val_accuracy: 0.8705 - val_loss: -1.3577\n", 156 | "Epoch 47/100\n", 157 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 384us/step - accuracy: 0.8828 - loss: -1.8596 - val_accuracy: 0.8738 - val_loss: -1.3816\n", 158 | "Epoch 48/100\n", 159 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
372us/step - accuracy: 0.8837 - loss: -1.8544 - val_accuracy: 0.8756 - val_loss: -1.3749\n", 160 | "Epoch 49/100\n", 161 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8864 - loss: -1.8201 - val_accuracy: 0.8771 - val_loss: -1.3928\n", 162 | "Epoch 50/100\n", 163 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 431us/step - accuracy: 0.8843 - loss: -1.7168 - val_accuracy: 0.8738 - val_loss: -1.4089\n", 164 | "Epoch 51/100\n", 165 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 416us/step - accuracy: 0.8825 - loss: -1.7630 - val_accuracy: 0.8742 - val_loss: -1.3914\n", 166 | "Epoch 52/100\n", 167 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 404us/step - accuracy: 0.8786 - loss: -1.7057 - val_accuracy: 0.8753 - val_loss: -1.3763\n", 168 | "Epoch 53/100\n", 169 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 414us/step - accuracy: 0.8798 - loss: -1.5794 - val_accuracy: 0.8767 - val_loss: -1.4216\n", 170 | "Epoch 54/100\n", 171 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8819 - loss: -1.6349 - val_accuracy: 0.8694 - val_loss: -1.3702\n", 172 | "Epoch 55/100\n", 173 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 408us/step - accuracy: 0.8843 - loss: -1.7700 - val_accuracy: 0.8756 - val_loss: -1.3754\n", 174 | "Epoch 56/100\n", 175 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 382us/step - accuracy: 0.8896 - loss: -1.9302 - val_accuracy: 0.8756 - val_loss: -1.3989\n", 176 | "Epoch 57/100\n", 177 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 407us/step - accuracy: 0.8867 - loss: -1.8582 - val_accuracy: 0.8767 - val_loss: -1.3943\n", 178 | "Epoch 58/100\n", 179 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 390us/step - accuracy: 0.8876 - loss: -1.7264 - val_accuracy: 0.8745 - val_loss: -1.3807\n", 180 | "Epoch 59/100\n", 181 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 379us/step - accuracy: 0.8798 - loss: -1.7206 - val_accuracy: 0.8745 - val_loss: -1.3763\n", 182 | "Epoch 60/100\n", 183 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 378us/step - accuracy: 0.8840 - loss: -1.6810 - val_accuracy: 0.8767 - val_loss: -1.3858\n", 184 | "Epoch 61/100\n", 185 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 378us/step - accuracy: 0.8850 - loss: -1.7133 - val_accuracy: 0.8712 - val_loss: -1.3946\n", 186 | "Epoch 62/100\n", 187 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 423us/step - accuracy: 0.8815 - loss: -1.6904 - val_accuracy: 0.8767 - val_loss: -1.3582\n", 188 | "Epoch 63/100\n", 189 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 461us/step - accuracy: 0.8871 - loss: -1.7279 - val_accuracy: 0.8756 - val_loss: -1.3944\n", 190 | "Epoch 64/100\n", 191 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 410us/step - accuracy: 0.8817 - loss: -1.7165 - val_accuracy: 0.8804 - val_loss: -1.4078\n", 192 | "Epoch 65/100\n", 193 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 415us/step - accuracy: 0.8889 - loss: -1.7555 - val_accuracy: 0.8738 - 
val_loss: -1.4008\n", 194 | "Epoch 66/100\n", 195 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 435us/step - accuracy: 0.8860 - loss: -1.8302 - val_accuracy: 0.8734 - val_loss: -1.3619\n", 196 | "Epoch 67/100\n", 197 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 409us/step - accuracy: 0.8834 - loss: -1.7671 - val_accuracy: 0.8779 - val_loss: -1.3922\n", 198 | "Epoch 68/100\n", 199 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 387us/step - accuracy: 0.8905 - loss: -1.7985 - val_accuracy: 0.8738 - val_loss: -1.4075\n", 200 | "Epoch 69/100\n", 201 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 370us/step - accuracy: 0.8880 - loss: -1.8513 - val_accuracy: 0.8749 - val_loss: -1.3816\n", 202 | "Epoch 70/100\n", 203 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 367us/step - accuracy: 0.8787 - loss: -1.7095 - val_accuracy: 0.8723 - val_loss: -1.3702\n", 204 | "Epoch 71/100\n", 205 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 370us/step - accuracy: 0.8819 - loss: -1.7304 - val_accuracy: 0.8779 - val_loss: -1.3912\n", 206 | "Epoch 72/100\n", 207 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 435us/step - accuracy: 0.8782 - loss: -1.6296 - val_accuracy: 0.8764 - val_loss: -1.3787\n", 208 | "Epoch 73/100\n", 209 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 479us/step - accuracy: 0.8783 - loss: -1.7065 - val_accuracy: 0.8804 - val_loss: -1.4012\n", 210 | "Epoch 74/100\n", 211 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
438us/step - accuracy: 0.8886 - loss: -1.8899 - val_accuracy: 0.8775 - val_loss: -1.3869\n", 212 | "Epoch 75/100\n", 213 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 404us/step - accuracy: 0.8873 - loss: -1.7559 - val_accuracy: 0.8793 - val_loss: -1.3726\n", 214 | "Epoch 76/100\n", 215 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 410us/step - accuracy: 0.8852 - loss: -1.6442 - val_accuracy: 0.8760 - val_loss: -1.3882\n", 216 | "Epoch 77/100\n", 217 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 615us/step - accuracy: 0.8810 - loss: -1.7353 - val_accuracy: 0.8727 - val_loss: -1.4004\n", 218 | "Epoch 78/100\n", 219 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 433us/step - accuracy: 0.8925 - loss: -1.8096 - val_accuracy: 0.8723 - val_loss: -1.3988\n", 220 | "Epoch 79/100\n", 221 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 408us/step - accuracy: 0.8832 - loss: -1.8142 - val_accuracy: 0.8734 - val_loss: -1.3991\n", 222 | "Epoch 80/100\n", 223 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 495us/step - accuracy: 0.8867 - loss: -1.7876 - val_accuracy: 0.8764 - val_loss: -1.3751\n", 224 | "Epoch 81/100\n", 225 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 527us/step - accuracy: 0.8864 - loss: -1.8754 - val_accuracy: 0.8786 - val_loss: -1.3795\n", 226 | "Epoch 82/100\n", 227 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 458us/step - accuracy: 0.8861 - loss: -1.8558 - val_accuracy: 0.8808 - val_loss: -1.4015\n", 228 | "Epoch 83/100\n", 229 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 514us/step - accuracy: 0.8902 - loss: -1.8730 - val_accuracy: 0.8753 - val_loss: -1.3983\n", 230 | "Epoch 84/100\n", 231 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 472us/step - accuracy: 0.8860 - loss: -1.7296 - val_accuracy: 0.8775 - val_loss: -1.4035\n", 232 | "Epoch 85/100\n", 233 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 485us/step - accuracy: 0.8894 - loss: -1.7704 - val_accuracy: 0.8764 - val_loss: -1.3866\n", 234 | "Epoch 86/100\n", 235 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 424us/step - accuracy: 0.8869 - loss: -1.7921 - val_accuracy: 0.8760 - val_loss: -1.4037\n", 236 | "Epoch 87/100\n", 237 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 383us/step - accuracy: 0.8850 - loss: -1.8886 - val_accuracy: 0.8797 - val_loss: -1.4137\n", 238 | "Epoch 88/100\n", 239 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 429us/step - accuracy: 0.8835 - loss: -1.7146 - val_accuracy: 0.8771 - val_loss: -1.3927\n", 240 | "Epoch 89/100\n", 241 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 426us/step - accuracy: 0.8820 - loss: -1.5299 - val_accuracy: 0.8742 - val_loss: -1.3842\n", 242 | "Epoch 90/100\n", 243 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 419us/step - accuracy: 0.8858 - loss: -1.8870 - val_accuracy: 0.8786 - val_loss: -1.4126\n", 244 | "Epoch 91/100\n", 245 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 404us/step - accuracy: 0.8859 - loss: -1.7234 - val_accuracy: 0.8742 - 
val_loss: -1.3798\n", 246 | "Epoch 92/100\n", 247 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 402us/step - accuracy: 0.8848 - loss: -1.6446 - val_accuracy: 0.8760 - val_loss: -1.3779\n", 248 | "Epoch 93/100\n", 249 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 382us/step - accuracy: 0.8903 - loss: -1.9362 - val_accuracy: 0.8767 - val_loss: -1.3850\n", 250 | "Epoch 94/100\n", 251 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 382us/step - accuracy: 0.8884 - loss: -1.8355 - val_accuracy: 0.8793 - val_loss: -1.4076\n", 252 | "Epoch 95/100\n", 253 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 372us/step - accuracy: 0.8913 - loss: -1.7340 - val_accuracy: 0.8701 - val_loss: -1.3642\n", 254 | "Epoch 96/100\n", 255 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8853 - loss: -1.7358 - val_accuracy: 0.8793 - val_loss: -1.3997\n", 256 | "Epoch 97/100\n", 257 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 368us/step - accuracy: 0.8902 - loss: -1.7987 - val_accuracy: 0.8771 - val_loss: -1.3907\n", 258 | "Epoch 98/100\n", 259 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 364us/step - accuracy: 0.8916 - loss: -1.8432 - val_accuracy: 0.8808 - val_loss: -1.3981\n", 260 | "Epoch 99/100\n", 261 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 366us/step - accuracy: 0.8862 - loss: -1.8597 - val_accuracy: 0.8720 - val_loss: -1.3640\n", 262 | "Epoch 100/100\n", 263 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
365us/step - accuracy: 0.8913 - loss: -1.9094 - val_accuracy: 0.8771 - val_loss: -1.3733\n", 264 | "\u001b[1m284/284\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 255us/step\n" 265 | ] 266 | }, 267 | { 268 | "data": { 269 | "image/png": "", 270 | "text/plain": [ 271 | "
" 272 | ] 273 | }, 274 | "metadata": {}, 275 | "output_type": "display_data" 276 | }, 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "Epoch 1/100\n" 282 | ] 283 | }, 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "/Users/easonwang/anaconda3/lib/python3.11/site-packages/keras/src/layers/core/input_layer.py:25: UserWarning: Argument `input_shape` is deprecated. Use `shape` instead.\n", 289 | " warnings.warn(\n" 290 | ] 291 | }, 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 858us/step - accuracy: 0.6178 - loss: 1.4336 - val_accuracy: 0.8326 - val_loss: -0.5051\n", 297 | "Epoch 2/100\n", 298 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 465us/step - accuracy: 0.8194 - loss: -0.8103 - val_accuracy: 0.8484 - val_loss: -0.9117\n", 299 | "Epoch 3/100\n", 300 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 452us/step - accuracy: 0.8554 - loss: -1.2130 - val_accuracy: 0.8536 - val_loss: -1.0127\n", 301 | "Epoch 4/100\n", 302 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8613 - loss: -1.2440 - val_accuracy: 0.8617 - val_loss: -1.1181\n", 303 | "Epoch 5/100\n", 304 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8684 - loss: -1.3051 - val_accuracy: 0.8642 - val_loss: -1.1608\n", 305 | "Epoch 6/100\n", 306 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8773 - loss: -1.4941 - val_accuracy: 0.8624 - val_loss: -1.2030\n", 307 | "Epoch 7/100\n", 308 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 447us/step - accuracy: 0.8746 - loss: -1.3975 - val_accuracy: 0.8613 - val_loss: -1.2102\n", 309 | "Epoch 8/100\n", 310 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8632 - loss: -1.4111 - val_accuracy: 0.8690 - val_loss: -1.2659\n", 311 | "Epoch 9/100\n", 312 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8728 - loss: -1.4978 - val_accuracy: 0.8606 - val_loss: -1.2504\n", 313 | "Epoch 10/100\n", 314 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 444us/step - accuracy: 0.8713 - loss: -1.3678 - val_accuracy: 0.8720 - val_loss: -1.2998\n", 315 | "Epoch 11/100\n", 316 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 444us/step - accuracy: 0.8782 - loss: -1.5190 - val_accuracy: 0.8687 - val_loss: -1.2962\n", 317 | "Epoch 12/100\n", 318 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 461us/step - accuracy: 0.8798 - loss: -1.5222 - val_accuracy: 0.8606 - val_loss: -1.2823\n", 319 | "Epoch 13/100\n", 320 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 450us/step - accuracy: 0.8767 - loss: -1.6958 - val_accuracy: 0.8690 - val_loss: -1.3219\n", 321 | "Epoch 14/100\n", 322 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8821 - loss: -1.7705 - val_accuracy: 0.8657 - val_loss: -1.3192\n", 323 | "Epoch 15/100\n", 324 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 448us/step - accuracy: 0.8708 - loss: -1.6420 - val_accuracy: 0.8716 - val_loss: 
-1.3340\n", 325 | "Epoch 16/100\n", 326 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8773 - loss: -1.5499 - val_accuracy: 0.8675 - val_loss: -1.3314\n", 327 | "Epoch 17/100\n", 328 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 451us/step - accuracy: 0.8743 - loss: -1.6845 - val_accuracy: 0.8705 - val_loss: -1.3392\n", 329 | "Epoch 18/100\n", 330 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8805 - loss: -1.6495 - val_accuracy: 0.8642 - val_loss: -1.3504\n", 331 | "Epoch 19/100\n", 332 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8763 - loss: -1.6581 - val_accuracy: 0.8650 - val_loss: -1.3095\n", 333 | "Epoch 20/100\n", 334 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8779 - loss: -1.7513 - val_accuracy: 0.8690 - val_loss: -1.3238\n", 335 | "Epoch 21/100\n", 336 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 441us/step - accuracy: 0.8752 - loss: -1.6190 - val_accuracy: 0.8679 - val_loss: -1.3661\n", 337 | "Epoch 22/100\n", 338 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8768 - loss: -1.7060 - val_accuracy: 0.8712 - val_loss: -1.3538\n", 339 | "Epoch 23/100\n", 340 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8784 - loss: -1.6088 - val_accuracy: 0.8709 - val_loss: -1.3472\n", 341 | "Epoch 24/100\n", 342 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 446us/step 
- accuracy: 0.8787 - loss: -1.6456 - val_accuracy: 0.8639 - val_loss: -1.3278\n", 343 | "Epoch 25/100\n", 344 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8743 - loss: -1.7221 - val_accuracy: 0.8664 - val_loss: -1.3729\n", 345 | "Epoch 26/100\n", 346 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8802 - loss: -1.7515 - val_accuracy: 0.8631 - val_loss: -1.3489\n", 347 | "Epoch 27/100\n", 348 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 467us/step - accuracy: 0.8781 - loss: -1.7995 - val_accuracy: 0.8687 - val_loss: -1.3648\n", 349 | "Epoch 28/100\n", 350 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 453us/step - accuracy: 0.8787 - loss: -1.6971 - val_accuracy: 0.8628 - val_loss: -1.3191\n", 351 | "Epoch 29/100\n", 352 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 450us/step - accuracy: 0.8780 - loss: -1.8348 - val_accuracy: 0.8716 - val_loss: -1.3547\n", 353 | "Epoch 30/100\n", 354 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 473us/step - accuracy: 0.8806 - loss: -1.6653 - val_accuracy: 0.8657 - val_loss: -1.3495\n", 355 | "Epoch 31/100\n", 356 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 448us/step - accuracy: 0.8795 - loss: -1.8820 - val_accuracy: 0.8672 - val_loss: -1.3557\n", 357 | "Epoch 32/100\n", 358 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8819 - loss: -1.7013 - val_accuracy: 0.8628 - val_loss: -1.3336\n", 359 | "Epoch 33/100\n", 360 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8821 - loss: -1.7758 - val_accuracy: 0.8650 - val_loss: -1.3524\n", 361 | "Epoch 34/100\n", 362 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 453us/step - accuracy: 0.8789 - loss: -1.6260 - val_accuracy: 0.8657 - val_loss: -1.3402\n", 363 | "Epoch 35/100\n", 364 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 446us/step - accuracy: 0.8744 - loss: -1.6410 - val_accuracy: 0.8683 - val_loss: -1.3469\n", 365 | "Epoch 36/100\n", 366 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8792 - loss: -1.6932 - val_accuracy: 0.8720 - val_loss: -1.3719\n", 367 | "Epoch 37/100\n", 368 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 446us/step - accuracy: 0.8814 - loss: -1.8167 - val_accuracy: 0.8675 - val_loss: -1.3405\n", 369 | "Epoch 38/100\n", 370 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 442us/step - accuracy: 0.8815 - loss: -1.8800 - val_accuracy: 0.8698 - val_loss: -1.3740\n", 371 | "Epoch 39/100\n", 372 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8800 - loss: -1.7262 - val_accuracy: 0.8664 - val_loss: -1.3663\n", 373 | "Epoch 40/100\n", 374 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 453us/step - accuracy: 0.8807 - loss: -1.7588 - val_accuracy: 0.8694 - val_loss: -1.3665\n", 375 | "Epoch 41/100\n", 376 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 517us/step - accuracy: 0.8805 - loss: -1.7869 - val_accuracy: 0.8687 - 
val_loss: -1.3425\n", 377 | "Epoch 42/100\n", 378 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 490us/step - accuracy: 0.8748 - loss: -1.7476 - val_accuracy: 0.8723 - val_loss: -1.3745\n", 379 | "Epoch 43/100\n", 380 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 466us/step - accuracy: 0.8774 - loss: -1.6926 - val_accuracy: 0.8694 - val_loss: -1.3657\n", 381 | "Epoch 44/100\n", 382 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 476us/step - accuracy: 0.8839 - loss: -1.8032 - val_accuracy: 0.8727 - val_loss: -1.3449\n", 383 | "Epoch 45/100\n", 384 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 446us/step - accuracy: 0.8885 - loss: -1.7727 - val_accuracy: 0.8672 - val_loss: -1.3236\n", 385 | "Epoch 46/100\n", 386 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8822 - loss: -1.6439 - val_accuracy: 0.8683 - val_loss: -1.3235\n", 387 | "Epoch 47/100\n", 388 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 459us/step - accuracy: 0.8787 - loss: -1.6322 - val_accuracy: 0.8716 - val_loss: -1.3459\n", 389 | "Epoch 48/100\n", 390 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 449us/step - accuracy: 0.8844 - loss: -1.8037 - val_accuracy: 0.8760 - val_loss: -1.3629\n", 391 | "Epoch 49/100\n", 392 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 444us/step - accuracy: 0.8860 - loss: -1.6096 - val_accuracy: 0.8701 - val_loss: -1.3535\n", 393 | "Epoch 50/100\n", 394 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
444us/step - accuracy: 0.8813 - loss: -1.7274 - val_accuracy: 0.8701 - val_loss: -1.3347\n", 395 | "Epoch 51/100\n", 396 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 453us/step - accuracy: 0.8800 - loss: -1.8745 - val_accuracy: 0.8727 - val_loss: -1.3404\n", 397 | "Epoch 52/100\n", 398 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 483us/step - accuracy: 0.8895 - loss: -1.8273 - val_accuracy: 0.8738 - val_loss: -1.3548\n", 399 | "Epoch 53/100\n", 400 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 468us/step - accuracy: 0.8812 - loss: -1.7788 - val_accuracy: 0.8712 - val_loss: -1.3465\n", 401 | "Epoch 54/100\n", 402 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 502us/step - accuracy: 0.8838 - loss: -1.7783 - val_accuracy: 0.8694 - val_loss: -1.3274\n", 403 | "Epoch 55/100\n", 404 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 541us/step - accuracy: 0.8856 - loss: -1.7364 - val_accuracy: 0.8653 - val_loss: -1.3536\n", 405 | "Epoch 56/100\n", 406 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 537us/step - accuracy: 0.8851 - loss: -1.8523 - val_accuracy: 0.8675 - val_loss: -1.3313\n", 407 | "Epoch 57/100\n", 408 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 514us/step - accuracy: 0.8771 - loss: -1.6236 - val_accuracy: 0.8720 - val_loss: -1.3777\n", 409 | "Epoch 58/100\n", 410 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 592us/step - accuracy: 0.8821 - loss: -1.7619 - val_accuracy: 0.8690 - val_loss: -1.3381\n", 411 | "Epoch 59/100\n", 412 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 556us/step - accuracy: 0.8821 - loss: -1.8201 - val_accuracy: 0.8683 - val_loss: -1.3615\n", 413 | "Epoch 60/100\n", 414 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 512us/step - accuracy: 0.8800 - loss: -1.7330 - val_accuracy: 0.8668 - val_loss: -1.3224\n", 415 | "Epoch 61/100\n", 416 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 532us/step - accuracy: 0.8820 - loss: -1.8925 - val_accuracy: 0.8731 - val_loss: -1.3969\n", 417 | "Epoch 62/100\n", 418 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 494us/step - accuracy: 0.8880 - loss: -1.8671 - val_accuracy: 0.8687 - val_loss: -1.3691\n", 419 | "Epoch 63/100\n", 420 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 494us/step - accuracy: 0.8890 - loss: -1.8546 - val_accuracy: 0.8675 - val_loss: -1.3695\n", 421 | "Epoch 64/100\n", 422 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 468us/step - accuracy: 0.8761 - loss: -1.5890 - val_accuracy: 0.8650 - val_loss: -1.3372\n", 423 | "Epoch 65/100\n", 424 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 447us/step - accuracy: 0.8824 - loss: -1.6818 - val_accuracy: 0.8731 - val_loss: -1.3971\n", 425 | "Epoch 66/100\n", 426 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 529us/step - accuracy: 0.8864 - loss: -1.8561 - val_accuracy: 0.8775 - val_loss: -1.3782\n", 427 | "Epoch 67/100\n", 428 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 525us/step - accuracy: 0.8860 - loss: -1.7357 - val_accuracy: 0.8764 - 
val_loss: -1.3774\n", 429 | "Epoch 68/100\n", 430 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 480us/step - accuracy: 0.8885 - loss: -1.8720 - val_accuracy: 0.8764 - val_loss: -1.3900\n", 431 | "Epoch 69/100\n", 432 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 463us/step - accuracy: 0.8869 - loss: -1.7002 - val_accuracy: 0.8720 - val_loss: -1.3602\n", 433 | "Epoch 70/100\n", 434 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 493us/step - accuracy: 0.8908 - loss: -1.7895 - val_accuracy: 0.8731 - val_loss: -1.3384\n", 435 | "Epoch 71/100\n", 436 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 489us/step - accuracy: 0.8878 - loss: -1.8706 - val_accuracy: 0.8764 - val_loss: -1.3803\n", 437 | "Epoch 72/100\n", 438 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 454us/step - accuracy: 0.8884 - loss: -1.7818 - val_accuracy: 0.8742 - val_loss: -1.3735\n", 439 | "Epoch 73/100\n", 440 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 563us/step - accuracy: 0.8858 - loss: -1.6874 - val_accuracy: 0.8698 - val_loss: -1.3472\n", 441 | "Epoch 74/100\n", 442 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 515us/step - accuracy: 0.8825 - loss: -1.6849 - val_accuracy: 0.8756 - val_loss: -1.3636\n", 443 | "Epoch 75/100\n", 444 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 469us/step - accuracy: 0.8916 - loss: -1.8412 - val_accuracy: 0.8712 - val_loss: -1.3626\n", 445 | "Epoch 76/100\n", 446 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
558us/step - accuracy: 0.8814 - loss: -1.8478 - val_accuracy: 0.8745 - val_loss: -1.4109\n", 447 | "Epoch 77/100\n", 448 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 571us/step - accuracy: 0.8782 - loss: -1.7011 - val_accuracy: 0.8790 - val_loss: -1.3901\n", 449 | "Epoch 78/100\n", 450 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 501us/step - accuracy: 0.8844 - loss: -1.6752 - val_accuracy: 0.8709 - val_loss: -1.3627\n", 451 | "Epoch 79/100\n", 452 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 542us/step - accuracy: 0.8828 - loss: -1.7313 - val_accuracy: 0.8738 - val_loss: -1.3834\n", 453 | "Epoch 80/100\n", 454 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 530us/step - accuracy: 0.8795 - loss: -1.7037 - val_accuracy: 0.8756 - val_loss: -1.3930\n", 455 | "Epoch 81/100\n", 456 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 530us/step - accuracy: 0.8925 - loss: -1.9691 - val_accuracy: 0.8731 - val_loss: -1.3690\n", 457 | "Epoch 82/100\n", 458 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 540us/step - accuracy: 0.8839 - loss: -1.7771 - val_accuracy: 0.8690 - val_loss: -1.3486\n", 459 | "Epoch 83/100\n", 460 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 594us/step - accuracy: 0.8868 - loss: -1.8984 - val_accuracy: 0.8742 - val_loss: -1.3667\n", 461 | "Epoch 84/100\n", 462 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 605us/step - accuracy: 0.8936 - loss: -1.9251 - val_accuracy: 0.8749 - val_loss: -1.3791\n", 463 | "Epoch 85/100\n", 464 | "\u001b[1m170/170\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 508us/step - accuracy: 0.8797 - loss: -1.8405 - val_accuracy: 0.8709 - val_loss: -1.3795\n", 465 | "Epoch 86/100\n", 466 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 542us/step - accuracy: 0.8916 - loss: -1.8751 - val_accuracy: 0.8738 - val_loss: -1.3895\n", 467 | "Epoch 87/100\n", 468 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 464us/step - accuracy: 0.8872 - loss: -1.8078 - val_accuracy: 0.8720 - val_loss: -1.3769\n", 469 | "Epoch 88/100\n", 470 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 507us/step - accuracy: 0.8821 - loss: -1.8096 - val_accuracy: 0.8687 - val_loss: -1.3463\n", 471 | "Epoch 89/100\n", 472 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 461us/step - accuracy: 0.8852 - loss: -1.7567 - val_accuracy: 0.8698 - val_loss: -1.3667\n", 473 | "Epoch 90/100\n", 474 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 455us/step - accuracy: 0.8854 - loss: -1.7894 - val_accuracy: 0.8709 - val_loss: -1.3601\n", 475 | "Epoch 91/100\n", 476 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 465us/step - accuracy: 0.8856 - loss: -1.9262 - val_accuracy: 0.8749 - val_loss: -1.3730\n", 477 | "Epoch 92/100\n", 478 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 486us/step - accuracy: 0.8932 - loss: -1.8240 - val_accuracy: 0.8738 - val_loss: -1.3587\n", 479 | "Epoch 93/100\n", 480 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 448us/step - accuracy: 0.8831 - loss: -1.7550 - val_accuracy: 0.8779 - 
val_loss: -1.3783\n", 481 | "Epoch 94/100\n", 482 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 449us/step - accuracy: 0.8903 - loss: -1.9393 - val_accuracy: 0.8709 - val_loss: -1.3623\n", 483 | "Epoch 95/100\n", 484 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8807 - loss: -1.8304 - val_accuracy: 0.8709 - val_loss: -1.3631\n", 485 | "Epoch 96/100\n", 486 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8861 - loss: -1.7503 - val_accuracy: 0.8723 - val_loss: -1.3687\n", 487 | "Epoch 97/100\n", 488 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8899 - loss: -1.8278 - val_accuracy: 0.8716 - val_loss: -1.3401\n", 489 | "Epoch 98/100\n", 490 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8933 - loss: -1.9981 - val_accuracy: 0.8775 - val_loss: -1.3889\n", 491 | "Epoch 99/100\n", 492 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 445us/step - accuracy: 0.8937 - loss: -1.8975 - val_accuracy: 0.8738 - val_loss: -1.3930\n", 493 | "Epoch 100/100\n", 494 | "\u001b[1m170/170\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 443us/step - accuracy: 0.8836 - loss: -1.7232 - val_accuracy: 0.8753 - val_loss: -1.3903\n", 495 | "\u001b[1m284/284\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 261us/step\n" 496 | ] 497 | }, 498 | { 499 | "data": { 500 | "image/png": "", 501 | "text/plain": [ 502 | "
" 503 | ] 504 | }, 505 | "metadata": {}, 506 | "output_type": "display_data" 507 | }, 508 | { 509 | "name": "stdout", 510 | "output_type": "stream", 511 | "text": [ 512 | "Fitting 3 folds for each of 8 candidates, totalling 24 fits\n" 513 | ] 514 | }, 515 | { 516 | "name": "stderr", 517 | "output_type": "stream", 518 | "text": [ 519 | "/Users/easonwang/anaconda3/lib/python3.11/site-packages/xgboost/core.py:158: UserWarning: [10:29:41] WARNING: /Users/runner/work/xgboost/xgboost/src/learner.cc:740: \n", 520 | "Parameters: { \"use_label_encoder\" } are not used.\n", 521 | "\n", 522 | " warnings.warn(smsg, UserWarning)\n", 523 | "/Users/easonwang/anaconda3/lib/python3.11/site-packages/xgboost/core.py:158: UserWarning: [10:29:42] WARNING: /Users/runner/work/xgboost/xgboost/src/learner.cc:740: \n", 524 | "Parameters: { \"use_label_encoder\" } are not used.\n", 525 | "\n", 526 | " warnings.warn(smsg, UserWarning)\n" 527 | ] 528 | }, 529 | { 530 | "name": "stdout", 531 | "output_type": "stream", 532 | "text": [ 533 | "Confusion Matrix for Model 3 (XGBoost):\n", 534 | " [[7620 358]\n", 535 | " [ 479 603]]\n" 536 | ] 537 | }, 538 | { 539 | "data": { 540 | "text/plain": [ 541 | "
" 542 | ] 543 | }, 544 | "metadata": {}, 545 | "output_type": "display_data" 546 | }, 547 | { 548 | "data": { 549 | "image/png": "", 550 | "text/plain": [ 551 | "
" 552 | ] 553 | }, 554 | "metadata": {}, 555 | "output_type": "display_data" 556 | } 557 | ], 558 | "source": [ 559 | "import pandas as pd\n", 560 | "import numpy as np\n", 561 | "import tensorflow as tf\n", 562 | "from sklearn.model_selection import train_test_split\n", 563 | "from sklearn.preprocessing import StandardScaler\n", 564 | "from sklearn.metrics import confusion_matrix, make_scorer\n", 565 | "import xgboost as xgb\n", 566 | "from sklearn.model_selection import GridSearchCV\n", 567 | "import matplotlib.pyplot as plt\n", 568 | "\n", 569 | "df = pd.read_csv(\"/Users/easonwang/Desktop/leadership/W6/Verizon.csv\") # Update with your file path if needed\n", 570 | "\n", 571 | "# Cleaning------------------------------------------------------\n", 572 | "# Drop rows where price is 0\n", 573 | "df = df[df['price'] != 0]\n", 574 | "# Keep rows where the monthly payment is strictly less than the price\n", 575 | "df = df[df['monthly_payment'] < df['price']]\n", 576 | "# Drop implausible ages (under 16)\n", 577 | "df = df[~(df['age'] < 16)]\n", 578 | "df['loan_to_value'] = (df['price'] - df['downpmt']) / df['price']\n", 579 | "# Cleaning------------------------------------------------------\n", 580 | "\n", 581 | "X = df.drop('default', axis=1) # Features\n", 582 | "y = df['default'] # Target\n", 583 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)\n", 584 | "scaler = StandardScaler()\n", 585 | "X_train = scaler.fit_transform(X_train)\n", 586 | "X_test = scaler.transform(X_test)\n", 587 | "\n", 588 | "\n", 589 | "# Custom loss: binary crossentropy minus a scaled payoff term built from soft TP/FP/TN/FN counts\n", 590 | "def custom_loss(y_true, y_pred):\n", 591 | " bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)\n", 592 | " y_pred_binary = tf.sigmoid(y_pred) # NOTE: y_pred is already a sigmoid output, so this squashes it into roughly [0.5, 0.73]; the counts below are soft, not binary\n", 593 | " tp = tf.reduce_sum(y_true * y_pred_binary)\n", 594 | " fp = tf.reduce_sum((1 - y_true) * y_pred_binary)\n", 595 | " tn = tf.reduce_sum((1 - y_true) * (1 - y_pred_binary))\n", 596 | " fn = 
tf.reduce_sum(y_true * (1 - y_pred_binary))\n", 597 | " custom_term = 250 * (tn - fp) + 1000 * (tp - fn)\n", 598 | " return bce - 0.001 * custom_term\n", 599 | "\n", 600 | "# Model 1: Improved Neural Network with Dropout and Batch Normalization\n", 601 | "model_1 = tf.keras.Sequential([\n", 602 | " tf.keras.layers.InputLayer(input_shape=(X_train.shape[1],)),\n", 603 | " tf.keras.layers.Dense(32, activation='relu'),\n", 604 | " tf.keras.layers.BatchNormalization(),\n", 605 | " tf.keras.layers.Dropout(0.3),\n", 606 | " tf.keras.layers.Dense(16, activation='relu'),\n", 607 | " tf.keras.layers.BatchNormalization(),\n", 608 | " tf.keras.layers.Dropout(0.3),\n", 609 | " tf.keras.layers.Dense(1, activation='sigmoid')\n", 610 | "])\n", 611 | "\n", 612 | "model_1.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=custom_loss, metrics=['accuracy'])\n", 613 | "model_1.fit(X_train, y_train, epochs=100, batch_size=64, verbose=1, validation_split=0.2)\n", 614 | "y_pred_1 = model_1.predict(X_test)\n", 615 | "y_pred_binary_1 = np.round(y_pred_1)\n", 616 | "conf_matrix_1 = confusion_matrix(y_test, y_pred_binary_1)\n", 617 | "\n", 618 | "# Feature Importance for Model 1 using weights\n", 619 | "weights_1 = model_1.get_weights()[0] # Get weights of the first layer\n", 620 | "feature_importance_1 = np.sum(np.abs(weights_1), axis=1)\n", 621 | "importance_df_1 = pd.DataFrame({'Feature': X.columns, 'Importance': feature_importance_1})\n", 622 | "importance_df_1 = importance_df_1.sort_values(by='Importance', ascending=False)\n", 623 | "plt.figure(figsize=(10, 6))\n", 624 | "plt.barh(importance_df_1['Feature'], importance_df_1['Importance'])\n", 625 | "plt.xlabel('Importance')\n", 626 | "plt.ylabel('Feature')\n", 627 | "plt.title('Feature Importance for Model 1')\n", 628 | "plt.gca().invert_yaxis()\n", 629 | "plt.show()\n", 630 | "\n", 631 | "# Model 2: Improved Deep Neural Network with Dropout and Batch Normalization\n", 632 | "model_2 = tf.keras.Sequential([\n", 633 | 
" tf.keras.layers.InputLayer(input_shape=(X_train.shape[1],)),\n", 634 | " tf.keras.layers.Dense(64, activation='relu'),\n", 635 | " tf.keras.layers.BatchNormalization(),\n", 636 | " tf.keras.layers.Dropout(0.3),\n", 637 | " tf.keras.layers.Dense(32, activation='relu'),\n", 638 | " tf.keras.layers.BatchNormalization(),\n", 639 | " tf.keras.layers.Dropout(0.3),\n", 640 | " tf.keras.layers.Dense(16, activation='relu'),\n", 641 | " tf.keras.layers.BatchNormalization(),\n", 642 | " tf.keras.layers.Dropout(0.3),\n", 643 | " tf.keras.layers.Dense(1, activation='sigmoid')\n", 644 | "])\n", 645 | "\n", 646 | "model_2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=custom_loss, metrics=['accuracy'])\n", 647 | "model_2.fit(X_train, y_train, epochs=100, batch_size=64, verbose=1, validation_split=0.2)\n", 648 | "y_pred_2 = model_2.predict(X_test)\n", 649 | "y_pred_binary_2 = np.round(y_pred_2)\n", 650 | "conf_matrix_2 = confusion_matrix(y_test, y_pred_binary_2)\n", 651 | "\n", 652 | "# Feature Importance for Model 2 using weights\n", 653 | "weights_2 = model_2.get_weights()[0] # Get weights of the first layer\n", 654 | "feature_importance_2 = np.sum(np.abs(weights_2), axis=1)\n", 655 | "importance_df_2 = pd.DataFrame({'Feature': X.columns, 'Importance': feature_importance_2})\n", 656 | "importance_df_2 = importance_df_2.sort_values(by='Importance', ascending=False)\n", 657 | "plt.figure(figsize=(10, 6))\n", 658 | "plt.barh(importance_df_2['Feature'], importance_df_2['Importance'])\n", 659 | "plt.xlabel('Importance')\n", 660 | "plt.ylabel('Feature')\n", 661 | "plt.title('Feature Importance for Model 2')\n", 662 | "plt.gca().invert_yaxis()\n", 663 | "plt.show()\n", 664 | "\n", 665 | "# Model 3: XGBoost Model with Custom Scoring Function\n", 666 | "# Custom scoring function for XGBoost to maximize 250 * (tn - fp) + 1000 * (tp - fn)\n", 667 | "def custom_scorer(y_true, y_pred):\n", 668 | " y_pred_binary = (y_pred >= 0.5).astype(int)\n", 669 | " tp = 
np.sum((y_true == 1) & (y_pred_binary == 1))\n", 670 | " fp = np.sum((y_true == 0) & (y_pred_binary == 1))\n", 671 | " tn = np.sum((y_true == 0) & (y_pred_binary == 0))\n", 672 | " fn = np.sum((y_true == 1) & (y_pred_binary == 0))\n", 673 | " return 250 * (tn - fp) + 1000 * (tp - fn)\n", 674 | "\n", 675 | "scorer = make_scorer(custom_scorer, greater_is_better=True)\n", 676 | "xgb_model = xgb.XGBClassifier(eval_metric='logloss')\n", 677 | "grid_params = {\n", 678 | " 'learning_rate': [0.01, 0.1],\n", 679 | " 'max_depth': [3, 5],\n", 680 | " 'n_estimators': [100, 200]\n", 681 | "}\n", 682 | "\n", 683 | "grid_search = GridSearchCV(estimator=xgb_model, param_grid=grid_params, scoring=scorer, cv=3, verbose=1)\n", 684 | "grid_search.fit(X_train, y_train)\n", 685 | "y_pred_3 = grid_search.best_estimator_.predict(X_test)\n", 686 | "conf_matrix_3 = confusion_matrix(y_test, y_pred_3)\n", 687 | "print(\"Confusion Matrix for Model 3 (XGBoost):\\n\", conf_matrix_3)\n", 688 | "\n", 689 | "# Feature Importance for Model 3 (drawn on the figure created here, not on a new one)\n", 690 | "plt.figure(figsize=(10, 6))\n", 691 | "xgb.plot_importance(grid_search.best_estimator_, ax=plt.gca(), importance_type='weight', show_values=False, title='Feature Importance for Model 3')\n", 692 | "\n", 693 | "# Baseline for comparison: Random Forest with default hyperparameters\n", 694 | "from sklearn.ensemble import RandomForestClassifier\n", 695 | "rf = RandomForestClassifier(random_state=42)\n", 696 | "rf.fit(X_train, y_train)\n", 697 | "y_pred_rf = rf.predict(X_test)\n" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 2, 703 | "metadata": { 704 | "tags": [] 705 | }, 706 | "outputs": [ 707 | { 708 | "name": "stdout", 709 | "output_type": "stream", 710 | "text": [ 711 | "\n", 712 | "Confusion Matrix for Random Forest Model:\n", 713 | "\n", 714 | " [[7695 283]\n", 715 | " [ 541 541]]\n", 716 | "Confusion Matrix for Model 1:\n", 717 | " [[7081 897]\n", 718 | " [ 222 860]]\n", 719 | "Confusion Matrix for Model 2:\n", 720 | " [[7032 946]\n", 721 | " [ 215 867]]\n", 722 
| "Confusion Matrix for Model 3 (XGBoost):\n", 723 | " [[7620 358]\n", 724 | " [ 479 603]]\n" 725 | ] 726 | } 727 | ], 728 | "source": [ 729 | "print(\"\\nConfusion Matrix for Random Forest Model:\\n\\n\", confusion_matrix(y_test, y_pred_rf))\n", 730 | "print(\"Confusion Matrix for Model 1:\\n\", conf_matrix_1)\n", 731 | "print(\"Confusion Matrix for Model 2:\\n\", conf_matrix_2)\n", 732 | "print(\"Confusion Matrix for Model 3 (XGBoost):\\n\", conf_matrix_3)" 733 | ] 734 | }, 735 | { 736 | "cell_type": "markdown", 737 | "metadata": {}, 738 | "source": [ 739 | "**The best model is Model 2**, as it achieves the highest recall of the four models, which matters most for minimizing financial risk from defaulting customers. Here are the main reasons why Model 2 is chosen:\n", 740 | "\n", 741 | "### Success Metrics for Model Evaluation (Model 2, computed from the confusion matrix printed above):\n", 742 | "- **Accuracy**: 87.2% - This indicates the proportion of total predictions that are correct.\n", 743 | "- **Precision**: 47.8% - This indicates the proportion of true positives among all positive predictions (i.e., predicted as default).\n", 744 | "- **Recall (Sensitivity)**: 80.1% - This indicates the proportion of actual positive cases that are correctly predicted.\n", 745 | "- **F1 Score**: 59.9% - This is the harmonic mean of precision and recall. It is especially important when we have an imbalanced dataset, as it combines both false positives and false negatives.\n", 746 | "\n", 747 | "### Key Advantages of Model 2:\n", 748 | "1. **High Recall**: Model 2 correctly identified **867 true defaulters** out of 1,082, resulting in a **recall of 80.1%**. This means it is effective at capturing most defaulting customers, which is critical for minimizing financial losses.\n", 749 | "\n", 750 | "2. **Lower False Negatives**: Model 2 has **215 false negatives**, the fewest of the four models (Model 1: 222, XGBoost: 479, Random Forest: 541). 
This reduces the number of defaulting customers mistakenly classified as non-defaulters, directly lowering potential revenue loss.\n", 751 | "\n", 752 | "3. **Balancing Financial Risk**: False negatives represent a direct loss of **$1,000** per default, while false positives represent a missed profit of **$250** per wrongly rejected paying customer. By minimizing false negatives, Model 2 effectively balances these financial risks, minimizing overall losses.\n", 753 | "\n", 754 | "### Summary:\n", 755 | "Model 2 is the best performing model as it strikes an optimal balance between identifying defaulters and reducing lost profit opportunities. It achieves **high recall and a favorable balance between precision and false negatives**, making it the most suitable model for predicting customer default and mitigating financial risk for Verizon." 756 | ] 757 | }, 758 | { 759 | "cell_type": "markdown", 760 | "metadata": {}, 761 | "source": [ 762 | "Pick your best performing model and explain it using success metrics
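
A minimal sketch of how these success metrics can be derived directly from the confusion matrices printed earlier in this notebook. It assumes scikit-learn's layout (rows are true labels, columns are predictions, so each matrix reads [[TN, FP], [FN, TP]]); the helper name `metrics_from_confusion` is hypothetical and not part of the notebook.

```python
# Confusion matrices copied from the printed notebook output.
# Layout assumed to follow scikit-learn: [[TN, FP], [FN, TP]].
CONFUSION = {
    'Random Forest': [[7695, 283], [541, 541]],
    'Model 1': [[7081, 897], [222, 860]],
    'Model 2': [[7032, 946], [215, 867]],
    'Model 3 (XGBoost)': [[7620, 358], [479, 603]],
}


def metrics_from_confusion(matrix):
    # Hypothetical helper: accuracy, precision, recall and F1 from a 2x2 matrix.
    (tn, fp), (fn, tp) = matrix
    total = tn + fp + fn + tp
    return {
        'accuracy': (tn + tp) / total if total else 0.0,
        'precision': tp / (tp + fp) if (tp + fp) else 0.0,
        'recall': tp / (tp + fn) if (tp + fn) else 0.0,
        'f1': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0,
    }


for name, matrix in CONFUSION.items():
    m = metrics_from_confusion(matrix)
    print(f'{name}: ' + ', '.join(f'{k}={v:.1%}' for k, v in m.items()))
```

Recomputing the figures this way is a quick check that the numbers quoted in a write-up match the run that produced the outputs.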
" 763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": {}, 768 | "source": [ 769 | "Following are theoretical questions:" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": {}, 775 | "source": [ 776 | "For your business case which is more important - precision or recall? Why?
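
One way to make the precision-versus-recall question concrete is to put a cost on each error type and pick the decision threshold that minimises the total. The sketch below assumes the $1,000 (missed defaulter) and $250 (wrongly rejected payer) figures quoted in this notebook; the labels and scores are invented toy values, and `cost_at` and `best_threshold` are hypothetical helper names.

```python
# Cost assumptions taken from the notebook's business case.
COST_FN = 1000  # a defaulter predicted as non-default
COST_FP = 250   # a paying customer predicted as default

# Toy data, invented purely for illustration (1 = default).
labels = [1, 0, 1, 0, 0, 1, 0, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.35, 0.3, 0.2, 0.15, 0.1]


def cost_at(threshold, labels, scores):
    # Predict default when score >= threshold, then price the two error types.
    fn = sum(1 for y, s in zip(labels, scores) if y == 1 and s < threshold)
    fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= threshold)
    return COST_FN * fn + COST_FP * fp


def best_threshold(labels, scores, candidates):
    # Return the candidate threshold with the lowest total misclassification cost.
    return min(candidates, key=lambda t: cost_at(t, labels, scores))


candidates = [0.15, 0.3, 0.5, 0.7]
for t in candidates:
    print(t, cost_at(t, labels, scores))
print('cheapest threshold:', best_threshold(labels, scores, candidates))
```

Because a false negative is costed at four times a false positive, the cheapest threshold on this toy data sits below 0.5, which trades precision for recall.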
" 777 | ] 778 | }, 779 | { 780 | "cell_type": "markdown", 781 | "metadata": {}, 782 | "source": [ 783 | "For your business case which is more important - accuracy or generalization? Why?" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "metadata": {}, 789 | "source": [ 790 | "How much does feature importance in Random Forest help in explainability to stakeholders?
" 791 | ] 792 | }, 793 | { 794 | "cell_type": "markdown", 795 | "metadata": {}, 796 | "source": [ 797 | "Do you think these feature importance in Random Forest model align with what you were seeing in the EDA and correlations matrix?
" 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "metadata": {}, 803 | "source": [ 804 | "Can you tie accuracies to business value (financial value)?" 805 | ] 806 | }, 807 | { 808 | "cell_type": "markdown", 809 | "metadata": {}, 810 | "source": [ 811 | "Are these accuracies good enough to deliver the business value or the ROI they estimated? If not, what else will you do to improve accuracies to get higher business value?" 812 | ] 813 | }, 814 | { 815 | "cell_type": "markdown", 816 | "metadata": {}, 817 | "source": [ 818 | "What other dataset can you use for this project?" 819 | ] 820 | }, 821 | { 822 | "cell_type": "markdown", 823 | "metadata": {}, 824 | "source": [ 825 | "What other pre-processing or processing can be done to improve the model?" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "metadata": {}, 831 | "source": [ 832 | "What other advanced algorithms would you want to try?" 833 | ] 834 | }, 835 | { 836 | "cell_type": "markdown", 837 | "metadata": {}, 838 | "source": [ 839 | "1. For your business case, which is more important - precision or recall? Why?\n", 840 | "In this business case, recall is more important than precision. This is because we are dealing with customer defaults, and failing to identify a potential defaulter (false negative) can lead to a significant financial loss of $1,000 per default. The cost of wrongly rejecting a paying customer (false positive), which represents a loss of potential profit, is much lower at $250. Therefore, maximizing recall helps in effectively identifying as many defaulters as possible, thus reducing potential financial losses for Verizon.\n", 841 | "\n", 842 | "2. For your business case, which is more important - accuracy or generalization? Why?\n", 843 | "Generalization is more important than accuracy in this context. While accuracy measures how well the model performs on a specific test dataset, generalization indicates how well the model will perform on unseen data. 
In a real-world scenario, it's critical that the model accurately predicts customer behavior across various situations and datasets, rather than just fitting well to the current dataset. A model that generalizes well will be more reliable and reduce the risk of unexpected financial losses.\n", 844 | "\n", 845 | "3. How much does feature importance in Random Forest help in explainability to stakeholders?\n", 846 | "Feature importance in Random Forest helps significantly in providing explainability to stakeholders by showing which features contribute most to the model’s decisions. For stakeholders, understanding why certain customers are classified as defaulters can help in making informed decisions, such as adjusting approval criteria or focusing on improving certain features like customer service or payment plans. Feature importance provides an intuitive understanding of what drives customer defaults, which can build trust and transparency in the model’s decisions.\n", 847 | "\n", 848 | "4. Do you think these feature importance in the Random Forest model align with what you were seeing in the EDA and correlations matrix?\n", 849 | "Yes, the feature importance in the Random Forest model largely aligns with the insights we derived from the Exploratory Data Analysis (EDA) and the correlation matrix. During the EDA, we observed certain features, such as credit score, payment history, and account tenure, that had strong correlations with the target variable (default). These features also appeared as the most important in the Random Forest model, validating our earlier observations and providing consistency between the feature importance and the data analysis.\n", 850 | "\n", 851 | "5. Can you tie accuracies to business value (financial value)?\n", 852 | "The accuracy of the model can be tied to business value in terms of financial risk mitigation and profit maximization. 
A higher accuracy means more correct predictions, both in detecting defaulters and in correctly approving non-defaulters. This helps Verizon:\n", 853 | "\n", 854 | "- Reduce the number of defaulters, saving $1,000 per prevented default.\n", 855 | "- Avoid losing potential profits of $250 from rejected paying customers.\n", 856 | "\nFor instance, with Model 2’s accuracy of 87.2% (computed from its confusion matrix above), a large majority of decisions are correct, which directly affects the bottom line through fewer losses and more retained revenue.\n", 857 | "\n", 858 | "6. Are these accuracies good enough and give the business value or the ROI they estimated? If not, what else will you do to improve accuracies to get higher business value?\n", 859 | "While the accuracy of 87.2% is fairly good, there is still room for improvement, especially given the high stakes associated with false negatives. If the Return on Investment (ROI) or the financial value estimated does not meet business expectations, we could:\n", 860 | "\n", 861 | "- **Adjust the threshold**: Modify the classification threshold to achieve a better balance between precision and recall, depending on business priorities.\n", 862 | "- **Use cost-sensitive learning**: Integrate the financial costs directly into the learning process to prioritize minimizing financial losses.\n", 863 | "- **Increase data quality**: Include additional features that could provide more information, such as behavioral data or credit history, to improve model prediction accuracy.\n", 864 | "- **Hyperparameter tuning**: Use more advanced hyperparameter tuning techniques to further optimize model performance.\n", 865 | "\n", 866 | "7. 
What other dataset can you use for this project?\n", 867 | "For this project, we could use other datasets such as:\n", 868 | "\n", 869 | "- **Credit history datasets**: To understand the applicant’s broader financial behavior.\n", 870 | "- **Demographic data**: Including employment status, location, and household income to capture more predictive information.\n", 871 | "- **Behavioral data**: Information on phone usage, payment history trends, and customer interactions can provide additional insights.\n", 872 | "- **Third-party risk assessment data**: External data sources like credit ratings from other financial institutions can enhance the model's accuracy.\n", 873 | "\n", 874 | "8. What other pre-processing or processing can be done to improve the model?\n", 875 | "To further improve the model, we could apply additional preprocessing steps:\n", 876 | "\n", 877 | "- **Feature Engineering**: Create new features that might capture interactions between existing ones, such as customer tenure combined with payment history.\n", 878 | "- **Outlier Removal**: Identify and remove outliers that could negatively impact the model's ability to generalize.\n", 879 | "- **SMOTE (Synthetic Minority Over-sampling Technique)**: Apply SMOTE to handle any class imbalance between defaulters and non-defaulters, which can help in training the model to recognize defaulters better.\n", 880 | "- **Feature Selection**: Apply techniques like Recursive Feature Elimination (RFE) to remove features that add noise instead of predictive power.\n", 881 | "\n", 882 | "9. 
What other advanced algorithms would you want to try?\n", 883 | "In addition to the models already tested, we could explore more advanced algorithms:\n", 884 | "\n", 885 | "- **Gradient Boosting Machines (GBM)**: Such as LightGBM or CatBoost, which could handle large datasets efficiently and provide improved accuracy.\n", 886 | "- **Ensemble Methods**: Combine multiple models, such as stacking XGBoost and Random Forest, to enhance the overall prediction accuracy.\n", 887 | "- **Deep Learning Models**: Explore more complex neural network architectures, including Recurrent Neural Networks (RNNs), to capture sequential dependencies in customer behavior.\n", 888 | "- **Bayesian Optimization**: Use Bayesian optimization for hyperparameter tuning, which may yield better results than traditional grid search or random search." 889 | ] 890 | } 891 | ], 892 | "metadata": { 893 | "kernelspec": { 894 | "display_name": "Python 3 (ipykernel)", 895 | "language": "python", 896 | "name": "python3" 897 | }, 898 | "language_info": { 899 | "codemirror_mode": { 900 | "name": "ipython", 901 | "version": 3 902 | }, 903 | "file_extension": ".py", 904 | "mimetype": "text/x-python", 905 | "name": "python", 906 | "nbconvert_exporter": "python", 907 | "pygments_lexer": "ipython3", 908 | "version": "3.11.5" 909 | } 910 | }, 911 | "nbformat": 4, 912 | "nbformat_minor": 4 913 | } 914 | --------------------------------------------------------------------------------