├── .gitignore ├── README.md ├── bert_classification.ipynb ├── data ├── finance_sentiment.csv └── finance_sentiment_multiclass.csv └── unsloth_classification.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | dontcommit.ipynb 164 | _unsloth_temporary_saved_buffers/unsloth/Qwen2-0.5B-bnb-4bit/output_embeddings.pt 165 | answerdotai/ModernBERT-large/tokenizer_config.json 166 | answerdotai/ModernBERT-large/tokenizer.json 167 | answerdotai/ModernBERT-large/special_tokens_map.json 168 | answerdotai/ModernBERT-large/model.safetensors 169 | answerdotai/ModernBERT-large/config.json 170 | _unsloth_temporary_saved_buffers/Qwen3-4B-Base/output_embeddings.pt 171 | _unsloth_temporary_saved_buffers/Qwen3-0.6B-Base/output_embeddings.pt 172 | _unsloth_temporary_saved_buffers/Qwen2-1.5B-bnb-4bit/output_embeddings.pt 173 | data/orig.csv 174 | data/process.ipynb 175 | lora_model_Qwen2-1.5B-bnb-4bit/adapter_config.json 176 | lora_model_Qwen2-1.5B-bnb-4bit/adapter_model.safetensors 177 | lora_model_Qwen2-1.5B-bnb-4bit/README.md 178 | lora_model_Qwen3-0.6B-Base/adapter_config.json 179 | lora_model_Qwen3-0.6B-Base/adapter_model.safetensors 180 | lora_model_Qwen3-0.6B-Base/README.md 181 | lora_model_Qwen3-4B-Base/adapter_config.json 182 | lora_model_Qwen3-4B-Base/adapter_model.safetensors 183 | lora_model_Qwen3-4B-Base/README.md 184 | outputs/* 185 | Qwen2-1.5B-bnb-4bit/added_tokens.json 186 | Qwen2-1.5B-bnb-4bit/config.json 187 | Qwen2-1.5B-bnb-4bit/generation_config.json 188 | Qwen2-1.5B-bnb-4bit/gitattributes 189 | Qwen2-1.5B-bnb-4bit/merges.txt 190 | Qwen2-1.5B-bnb-4bit/model.safetensors 191 | Qwen2-1.5B-bnb-4bit/README.md 192 | Qwen2-1.5B-bnb-4bit/special_tokens_map.json 193 | Qwen2-1.5B-bnb-4bit/tokenizer.json 194 | Qwen2-1.5B-bnb-4bit/tokenizer_config.json 195 | Qwen2-1.5B-bnb-4bit/vocab.json 196 | Qwen3-0.6B-Base/added_tokens.json 197 | Qwen3-0.6B-Base/config.json 198 | Qwen3-0.6B-Base/generation_config.json 199 | Qwen3-0.6B-Base/merges.txt 200 | Qwen3-0.6B-Base/model.safetensors 201 | Qwen3-0.6B-Base/special_tokens_map.json 202 | Qwen3-0.6B-Base/tokenizer.json 203 | Qwen3-0.6B-Base/tokenizer_config.json 204 | Qwen3-0.6B-Base/vocab.json 205 | Qwen3-4B-Base/added_tokens.json 206 | Qwen3-4B-Base/config.json 207 | Qwen3-4B-Base/generation_config.json 208 | Qwen3-4B-Base/merges.txt 209 | Qwen3-4B-Base/model-00001-of-00002.safetensors 210 | Qwen3-4B-Base/model-00002-of-00002.safetensors 211 | Qwen3-4B-Base/model.safetensors.index.json 212 | Qwen3-4B-Base/special_tokens_map.json 213 | Qwen3-4B-Base/tokenizer.json 214 | Qwen3-4B-Base/tokenizer_config.json 215 | 
Qwen3-4B-Base/vocab.json 216 | unsloth_classification copy.ipynb 217 | unsloth_classification_broken.ipynb 218 | unsloth_compiled_cache/* 219 | trainer_output/* 220 | bert_classification_crash_test.ipynb 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text classification scripts 2 | 3 | ## unsloth_classification.ipynb 4 | 5 | This modified Unsloth notebook trains an LLM (currently Qwen 3) on any text classification dataset, where the input is a CSV with columns "text" and "label". 6 | 7 | ### Added features: 8 | 9 | - Trims the classification head to contain only the label tokens (e.g. "1", "2", "3"), which saves 1 GB of VRAM, allows you to train the head without massive memory usage, and makes the start of the training session more stable. 10 | - Only the last token in the sequence contributes to the loss, so the model doesn't waste its capacity trying to predict the input 11 | - Includes "group_by_length = True", which speeds up training significantly for unbalanced sequence lengths 12 | - Efficiently evaluates the accuracy on the validation set using batched inference 13 | 14 | ## bert_classification.ipynb 15 | 16 | This notebook can be used to train any BERT-style model on any text classification dataset (same format as above). The notebook also includes "group_by_length = True", which is not commonly found in BERT training notebooks (they usually tokenize everything ahead of time with a lot of wasteful padding). 17 | -------------------------------------------------------------------------------- /bert_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "env: UNSLOTH_COMPILE_DISABLE=1\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# needed to fix a bug with unsloth\n", 18 | "%env UNSLOTH_COMPILE_DISABLE = 1" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", 31 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n" 32 | ] 33 | }, 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n", 39 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n" 40 | ] 41 | }, 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "==((====))== Unsloth 2025.4.5: Fast Modernbert patching. Transformers: 4.51.3.\n", 47 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n", 48 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n", 49 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. 
FA2 = False]\n", 50 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", 51 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", 52 | "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n", 60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 61 | ] 62 | }, 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "model parameters:395834371\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "from unsloth import FastLanguageModel, FastModel\n", 73 | "import pandas as pd\n", 74 | "import numpy as np\n", 75 | "from sklearn.model_selection import train_test_split\n", 76 | "from sklearn.metrics import accuracy_score\n", 77 | "import os\n", 78 | "import torch\n", 79 | "from torch import tensor\n", 80 | "import torch.nn.functional as F\n", 81 | "from transformers import TrainingArguments, Trainer, ModernBertModel, AutoModelForSequenceClassification, training_args\n", 82 | "from datasets import load_dataset, Dataset\n", 83 | "from tqdm import tqdm\n", 84 | "\n", 85 | "model_name = 'answerdotai/ModernBERT-large'\n", 86 | "\n", 87 | "NUM_CLASSES = 3\n", 88 | "DATA_DIR = \"data/\"\n", 89 | "\n", 90 | "model, tokenizer = FastModel.from_pretrained(\n", 91 | " model_name = model_name,load_in_4bit = False,\n", 92 | " max_seq_length = 2048,\n", 93 | " dtype = None,\n", 94 | " auto_model = AutoModelForSequenceClassification,\n", 95 | " num_labels = NUM_CLASSES,\n", 96 | ")\n", 97 | "print(\"model parameters:\" + str(sum(p.numel() for p in model.parameters())))\n", 98 | "\n", 99 | "# make all parameters trainable\n", 100 | "for param in model.parameters():\n", 101 | " param.requires_grad = True" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "The dataset can be found [here](https://github.com/timothelaborie/text_classification_scripts/blob/main/data/finance_sentiment_multiclass.csv)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "application/vnd.jupyter.widget-view+json": { 119 | "model_id": "e1dd1859a86d47f7ae5ac7e54dec68c4", 120 | "version_major": 2, 121 | "version_minor": 0 122 | }, 123 | "text/plain": [ 124 | "Map: 0%| | 0/3893 [00:00\n", 201 | " \n", 202 | " \n", 203 | " [366/366 01:41, Epoch 3/3]\n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | "
StepTraining LossValidation LossAccuracy
920.5649000.4308820.713626
1840.3465000.3083270.812933
2760.2475000.2884100.822171

" 235 | ], 236 | "text/plain": [ 237 | "" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | }, 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "Unsloth: Will smartly offload gradients to save VRAM!\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "trainer = Trainer(\n", 253 | " model=model,\n", 254 | " tokenizer=tokenizer,\n", 255 | " train_dataset=dataset,\n", 256 | " eval_dataset=val_dataset,\n", 257 | " args=TrainingArguments(\n", 258 | " per_device_train_batch_size=32,\n", 259 | " gradient_accumulation_steps=1,\n", 260 | " warmup_steps=10,\n", 261 | " fp16=not torch.cuda.is_bf16_supported(),\n", 262 | " bf16=torch.cuda.is_bf16_supported(),\n", 263 | " optim=training_args.OptimizerNames.ADAMW_TORCH,\n", 264 | " learning_rate=5e-5,\n", 265 | " weight_decay=0.001,\n", 266 | " lr_scheduler_type=\"cosine\",\n", 267 | " seed=3407,\n", 268 | " num_train_epochs=3, # bert-style models usually need more than 1 epoch\n", 269 | " save_strategy=\"epoch\",\n", 270 | "\n", 271 | " # report_to=\"wandb\",\n", 272 | " report_to=\"none\",\n", 273 | "\n", 274 | " group_by_length=True,\n", 275 | "\n", 276 | " # eval_strategy=\"no\",\n", 277 | " eval_strategy=\"steps\",\n", 278 | " eval_steps=0.25,\n", 279 | " logging_strategy=\"steps\",\n", 280 | " logging_steps=0.25,\n", 281 | " \n", 282 | " ),\n", 283 | " compute_metrics=lambda eval_pred: { \"accuracy\": accuracy_score(eval_pred[1].argmax(axis=-1), eval_pred[0].argmax(axis=-1)) }\n", 284 | ")\n", 285 | "trainer_stats = trainer.train()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 5, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "model = model.cuda()\n", 303 | "model = model.eval()\n", 304 | "FastLanguageModel.for_inference(model)\n", 305 | "print()" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 6, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "name": "stderr", 315 | "output_type": "stream", 316 | "text": [ 317 | "Evaluating: 100%|██████████| 14/14 [00:01<00:00, 11.78it/s]" 318 | ] 319 | }, 320 | { 321 | "name": "stdout", 322 | "output_type": "stream", 323 | "text": [ 324 | "\n", 325 | "Validation accuracy: 82.45% (357/433)\n", 326 | "\n", 327 | "--- Random samples ---\n", 328 | "\n", 329 | "Text: Turkey Stiffens Manipulation Penalties in Banking Overhaul\n", 330 | "True: 0 Pred: 0 ✅\n", 331 | "Probs: 0: 0.846, 1: 0.114, 2: 0.040\n", 332 | "\n", 333 | "Text: The Manitowoc Company, Inc. 
Just Reported Earnings, And Analysts Cut Their Target Price\n", 334 | "True: 2 Pred: 2 ✅\n", 335 | "Probs: 0: 0.065, 1: 0.029, 2: 0.906\n", 336 | "\n", 337 | "Text: $BLMN $EAT $SBUX - Restaurants stocks break higher, analysts reel in near-term expectations https://t.co/fOjVVJdfF0\n", 338 | "True: 1 Pred: 1 ✅\n", 339 | "Probs: 0: 0.004, 1: 0.974, 2: 0.021\n", 340 | "\n", 341 | "Text: $CMCSA $LHX - Comcast sues L3Harris in patent dispute https://t.co/kWReshGbvz\n", 342 | "True: 2 Pred: 2 ✅\n", 343 | "Probs: 0: 0.023, 1: 0.035, 2: 0.942\n", 344 | "\n", 345 | "Text: Libyan economic experts will study the distribution of crucial oil revenue as efforts continue to solve the war-rav… https://t.co/S9lmpnDTqJ\n", 346 | "True: 0 Pred: 0 ✅\n", 347 | "Probs: 0: 0.888, 1: 0.013, 2: 0.099\n", 348 | "\n", 349 | "Text: Stocks Suffer 'Shocking' Down Week As Fed Balance Sheet Unexpectedly Shrinks https://t.co/bspsRi3Wow\n", 350 | "True: 2 Pred: 2 ✅\n", 351 | "Probs: 0: 0.005, 1: 0.001, 2: 0.993\n", 352 | "\n", 353 | "Text: Burger King says it never promised Impossible Whoppers were vegan https://t.co/oZCnoupsYV https://t.co/lauoccNH0n\n", 354 | "True: 0 Pred: 0 ✅\n", 355 | "Probs: 0: 0.762, 1: 0.070, 2: 0.167\n", 356 | "\n", 357 | "Text: McEwen Mining prices public offering at $1.325/unit\n", 358 | "True: 0 Pred: 0 ✅\n", 359 | "Probs: 0: 0.710, 1: 0.284, 2: 0.006\n", 360 | "\n", 361 | "Text: H&P downgraded at Argus as drilling industry weakness seen persisting\n", 362 | "True: 2 Pred: 2 ✅\n", 363 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n", 364 | "\n", 365 | "Text: $SPLK - Splunk: Full Steam Ahead. Follow this and any other stock on Seeking Alpha! https://t.co/DCvnAuSOBa #markets #stocks #economy\n", 366 | "True: 0 Pred: 1 ❌\n", 367 | "Probs: 0: 0.040, 1: 0.959, 2: 0.001\n", 368 | "\n", 369 | "Text: Three dead in shooting at Oklahoma Walmart: RPT\n", 370 | "True: 0 Pred: 2 ❌\n", 371 | "Probs: 0: 0.036, 1: 0.001, 2: 0.964\n", 372 | "\n", 373 | "Text: U.S. Oil Inventories Rise More Than Expected #WTI #Stock #MarketScreener https://t.co/lMkNlbjinO https://t.co/wBBq3HdLZO\n", 374 | "True: 2 Pred: 1 ❌\n", 375 | "Probs: 0: 0.002, 1: 0.997, 2: 0.001\n", 376 | "\n", 377 | "Text: Casper Sleep stock languishes below IPO issue price after falling 5%\n", 378 | "True: 2 Pred: 2 ✅\n", 379 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n", 380 | "\n", 381 | "Text: The global oil market is drowning in excess crude as demand plummets. Insights via @CMEGroup https://t.co/JklSJKFfRS\n", 382 | "True: 2 Pred: 2 ✅\n", 383 | "Probs: 0: 0.001, 1: 0.001, 2: 0.998\n", 384 | "\n", 385 | "Text: Asia Stocks Open Mixed as Trade Details Awaited: Markets Wrap\n", 386 | "True: 0 Pred: 1 ❌\n", 387 | "Probs: 0: 0.023, 1: 0.865, 2: 0.112\n", 388 | "\n", 389 | "Text: From Starbucks to Seattle, companies and cities alike are banning plastic straws. 
Are takeout containers next?… https://t.co/Ew4Fsl6K0m\n", 390 | "True: 0 Pred: 0 ✅\n", 391 | "Probs: 0: 0.968, 1: 0.003, 2: 0.029\n", 392 | "\n", 393 | "Text: Americans' outlook on the economy faltered significantly last month as the coronavirus crisis began to take hold in… https://t.co/5jeCXLXrrR\n", 394 | "True: 2 Pred: 2 ✅\n", 395 | "Probs: 0: 0.002, 1: 0.002, 2: 0.996\n", 396 | "\n", 397 | "Text: Boris Johnson’s Conservative Party will pledge not to increase several key tax measures if it wins next month’s gen… https://t.co/RENnMT4Dtr\n", 398 | "True: 0 Pred: 0 ✅\n", 399 | "Probs: 0: 0.999, 1: 0.001, 2: 0.001\n", 400 | "\n", 401 | "Text: Casper Sleep shares slide 5% to trade at $10.46, below $12 IPO price\n", 402 | "True: 2 Pred: 2 ✅\n", 403 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n", 404 | "\n", 405 | "Text: Brixmor 2020 FFO guidance comes in on the light side\n", 406 | "True: 2 Pred: 0 ❌\n", 407 | "Probs: 0: 0.868, 1: 0.131, 2: 0.002\n" 408 | ] 409 | }, 410 | { 411 | "name": "stderr", 412 | "output_type": "stream", 413 | "text": [ 414 | "\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "batch_size = 32\n", 420 | "correct = 0\n", 421 | "results = []\n", 422 | "\n", 423 | "# If the val_labels are one-hot, convert to class indices\n", 424 | "if isinstance(val_labels, np.ndarray) and val_labels.ndim == 2:\n", 425 | " val_true_labels = np.argmax(val_labels, axis=1)\n", 426 | "else:\n", 427 | " val_true_labels = val_labels\n", 428 | "\n", 429 | "val_texts = list(val_data)\n", 430 | "val_true_labels = list(val_true_labels)\n", 431 | "\n", 432 | "with torch.no_grad():\n", 433 | " for i in tqdm(range(0, len(val_texts), batch_size), desc=\"Evaluating\"):\n", 434 | " batch_texts = val_texts[i:i+batch_size]\n", 435 | " batch_labels = val_true_labels[i:i+batch_size]\n", 436 | " # Tokenize\n", 437 | " inputs = tokenizer(batch_texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=2048)\n", 438 | " inputs = {k: v.cuda() for k, v in inputs.items()}\n", 439 | " # Forward pass\n", 440 | " outputs = model(**inputs)\n", 441 | " logits = outputs.logits\n", 442 | " probs = F.softmax(logits, dim=-1)\n", 443 | " preds = torch.argmax(probs, dim=-1).cpu().numpy()\n", 444 | " # Count correct\n", 445 | " correct += np.sum(preds == batch_labels)\n", 446 | " # Store results for display\n", 447 | " for j in range(len(batch_texts)):\n", 448 | " results.append({\n", 449 | " \"text\": batch_texts[j][:200],\n", 450 | " \"true\": batch_labels[j],\n", 451 | " \"pred\": preds[j],\n", 452 | " \"probs\": probs[j].detach().float().cpu().numpy(),\n", 453 | " \"ok\": preds[j] == batch_labels[j]\n", 454 | " })\n", 455 | "\n", 456 | "accuracy = 100 * correct / len(val_texts)\n", 457 | "print(f\"\\nValidation accuracy: {accuracy:.2f}% ({correct}/{len(val_texts)})\")\n", 458 | "\n", 459 | "# Show a few random samples\n", 460 | "import random\n", 461 | "display = 20\n", 462 | "print(\"\\n--- Random samples ---\")\n", 463 | "for s in random.sample(results, min(display, len(results))):\n", 464 | " print(f\"\\nText: {s['text']}\")\n", 465 | " print(f\"True: {s['true']} Pred: {s['pred']} {'✅' if s['ok'] else '❌'}\")\n", 466 | " print(\"Probs:\", \", \".join([f\"{k}: {v:.3f}\" for k, v in enumerate(s['probs'])]))" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 7, 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "ename": "ZeroDivisionError", 476 | "evalue": "division by zero", 477 | "output_type": "error", 478 | "traceback": [ 479 | 
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 480 | "\u001b[1;31mZeroDivisionError\u001b[0m Traceback (most recent call last)", 481 | "Cell \u001b[1;32mIn[7], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# stop running all cells\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;241m1\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m0\u001b[39m\n", 482 | "\u001b[1;31mZeroDivisionError\u001b[0m: division by zero" 483 | ] 484 | } 485 | ], 486 | "source": [ 487 | "# stop running all cells\n", 488 | "1/0" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "# to load the model again (run every cell above the one where the trainer is called)" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [ 503 | { 504 | "name": "stdout", 505 | "output_type": "stream", 506 | "text": [ 507 | "Last checkpoint: trainer_output\\checkpoint-244\n", 508 | "==((====))== Unsloth 2025.4.5: Fast Modernbert patching. Transformers: 4.51.3.\n", 509 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n", 510 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n", 511 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n", 512 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", 513 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", 514 | "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n" 515 | ] 516 | } 517 | ], 518 | "source": [ 519 | "from transformers.trainer_utils import get_last_checkpoint\n", 520 | "\n", 521 | "output_dir = \"trainer_output\"\n", 522 | "last_checkpoint = get_last_checkpoint(output_dir)\n", 523 | "print(\"Last checkpoint:\", last_checkpoint)\n", 524 | "\n", 525 | "model, tokenizer = FastModel.from_pretrained(\n", 526 | " model_name = last_checkpoint,load_in_4bit = False,\n", 527 | " max_seq_length = 2048,\n", 528 | " dtype = None,\n", 529 | " auto_model = AutoModelForSequenceClassification,\n", 530 | " num_labels = NUM_CLASSES,\n", 531 | ")" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "name": "stdout", 541 | "output_type": "stream", 542 | "text": [ 543 | "SequenceClassifierOutput(loss=None, logits=tensor([[-0.0579, -0.5859, -1.1719]], device='cuda:0', dtype=torch.bfloat16,\n", 544 | " grad_fn=), hidden_states=None, attentions=None)\n" 545 | ] 546 | } 547 | ], 548 | "source": [ 549 | "from torch import tensor\n", 550 | "print(model(input_ids=tensor([[1,2,3,4,5]]).cuda(), attention_mask=tensor([[1,1,1,1,1]]).cuda()))" 551 | ] 552 | } 553 | ], 554 | "metadata": { 555 | "kernelspec": { 556 | "display_name": "base", 557 | "language": "python", 558 | "name": "python3" 559 | }, 560 | "language_info": { 561 | "codemirror_mode": { 562 | "name": "ipython", 563 | "version": 3 564 | }, 565 | "file_extension": ".py", 566 | "mimetype": "text/x-python", 567 | "name": "python", 568 | "nbconvert_exporter": "python", 569 | "pygments_lexer": "ipython3", 570 | "version": "3.12.3" 571 | }, 572 | "orig_nbformat": 4, 573 | "vscode": { 574 | "interpreter": { 575 | "hash": "1be15a159d9874788f7b7854451912393d9e82d0d2bc47d83a870bda7fd9bc22" 576 | } 577 | } 578 | }, 579 | "nbformat": 4, 580 | "nbformat_minor": 2 581 | } 582 | 
-------------------------------------------------------------------------------- /unsloth_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "IqM-T1RTzY6C" 7 | }, 8 | "source": [ 9 | "# Text classification with Unsloth\n", 10 | "\n", 11 | "This modified Unsloth notebook trains an LLM on any text classification dataset, where the input is a CSV with columns \"text\" and \"label\".\n", 12 | "\n", 13 | "### Added features:\n", 14 | "\n", 15 | "- Trims the classification head to contain only the number tokens such as \"1\", \"2\", etc., which saves 1 GB of VRAM, allows you to train the head without massive memory usage, and makes the start of the training session more stable.\n", 16 | "- Only the last token in the sequence contributes to the loss, so the model doesn't waste its capacity trying to predict the input\n", 17 | "- Includes \"group_by_length = True\", which speeds up training significantly for unbalanced sequence lengths\n", 18 | "- Efficiently evaluates the accuracy on the validation set using batched inference\n", 19 | "\n", 20 | "### Update 4th of May 2025:\n", 21 | "\n", 22 | "- Added support for more than 2 classes\n", 23 | "- The classification head is now built back up to the original size after training, so there are no more errors in external libraries.\n", 24 | "- Made the batched inference part much faster and cleaner\n", 25 | "- Changed model to Qwen 3\n", 26 | "- Improved comments to explain the complicated parts" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", 39 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "# needed as this function doesn't like it when the lm_head has its size changed\n", 45 | "from unsloth import tokenizer_utils\n", 46 | "def do_nothing(*args, **kwargs):\n", 47 | " pass\n", 48 | "tokenizer_utils.fix_untrained_tokens = do_nothing" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Major: 8, Minor: 6\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n", 68 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n" 69 | ] 70 | }, 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "==((====))== Unsloth 2025.4.5: Fast Qwen3 patching. Transformers: 4.51.3.\n", 76 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n", 77 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n", 78 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. 
FA2 = False]\n", 79 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", 80 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" 81 | ] 82 | }, 83 | { 84 | "data": { 85 | "application/vnd.jupyter.widget-view+json": { 86 | "model_id": "67e411e7801f4f9db7b980df6de5d01b", 87 | "version_major": 2, 88 | "version_minor": 0 89 | }, 90 | "text/plain": [ 91 | "Loading checkpoint shards: 0%| | 0/2 [00:00" 275 | ] 276 | }, 277 | "metadata": {}, 278 | "output_type": "display_data" 279 | } 280 | ], 281 | "source": [ 282 | "token_counts = [len(tokenizer.encode(x)) for x in train_df.text]\n", 283 | "# plot the token counts\n", 284 | "a = plt.hist(token_counts, bins=30)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 7, 290 | "metadata": { 291 | "colab": { 292 | "base_uri": "https://localhost:8080/", 293 | "height": 145, 294 | "referenced_widgets": [ 295 | "e0ea75df3b5c49ce89427e4d245a7646", 296 | "02237d111eba4155ab5a48a1b33d82a4", 297 | "927e1d33248d4653a785049b50b6d814", 298 | "09979ecafec9443ba1f7335ed64de778", 299 | "088a000de2234c9f94a8b09e8a8abbb2", 300 | "38c7b4ead56b4540bf68fbe3a5496b9c", 301 | "9a0787c7c98b49b8a9141e5212faa249", 302 | "daf62811f5384d0f9aea58b40f23161a", 303 | "6f90172dd5c240c885140d49add6208f", 304 | "187078f1978843f2b873cee2ee55ac91", 305 | "ca8af702ee764e7fa620476854a4f2cf", 306 | "80b9d976e30e40ecaf35a606bca5d647", 307 | "e9b423d370314fdfa3c22e95c3a35d92", 308 | "307f44c0a4da4d3ca0d44b54f9a5f6c0", 309 | "5503539430024c7586d4f8589c92fd74", 310 | "b17542f2cd0f45e79e755df622328358", 311 | "fe6d1bab4fd042b5af89f6bb8a73cc39", 312 | "8cb7748eaa354d4ea0e3222685f1d9b4", 313 | "53bfefa3e2c04cce9ab7d2b3fbca59d9", 314 | "aa58f2124335451890857ded72414ecf", 315 | "b331b446e257441387b331d822f570a0", 316 | "49e3298fbe054c7abb6a041ecd63737c", 317 | "6b104e7242cf4a608cfd0c4fc6f44db1", 318 | "1a6288e052884d5aa486415880acf134", 319 | "289c8b4d1686443f8bed7441673a2c15", 320 | "1318a594d2df47779b84f22e4b0c0702", 321 | "355ecaa94d1a43b09e1d0b3f6e6ac8e3", 322 | "156432d2f1b74697bd74604e1bd5bf13", 323 | "976fa37d0fa14282bc96b3c948b9da94", 324 | "24768bbd587d4b1ba09b6c52dcd6237c", 325 | "cdf054995f314976b7124de9564cbd9b", 326 | "1eec25504a6f41d199940c4232b0a114", 327 | "007a9dbfb5c84efd8ffd801bc04c0c20", 328 | "2ffddabe91124b279f9a3ac41d1ead9d", 329 | "f97edc77840e49c49aed8e7ab80f45d1", 330 | "33761916bce24fde9bcff866736ddad7", 331 | "726aae420987454e87a08b9314844895", 332 | "071858b3801e4844a4fc5a91205335c4", 333 | "6a98441b0ff74361a79c96178834d8bc", 334 | "40e8b59e229548c68ccb339234a5e525", 335 | "56b011c2482b427483508e259dc231fd", 336 | "507c15541fe7489aa1fd1fd81c4aa222", 337 | "4bf05ee528bb4bb2979cde4d0a09544b", 338 | "d62a5696b9704580b69e6fc3be9ffa8a" 339 | ] 340 | }, 341 | "id": "LjY75GoYUCB8", 342 | "outputId": "26e1bc8e-c4e8-472e-ca91-670e3381c6b3" 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "prompt = \"\"\"Here is a financial news:\n", 347 | "{}\n", 348 | "\n", 349 | "Classify this news into one of the following:\n", 350 | "class 1: Bullish\n", 351 | "class 2: Neutral\n", 352 | "class 3: Bearish\n", 353 | "\n", 354 | "SOLUTION\n", 355 | "The correct answer is: class {}\"\"\"\n", 356 | "\n", 357 | "def formatting_prompts_func(dataset_):\n", 358 | " texts = []\n", 359 | " for i in range(len(dataset_['text'])):\n", 360 | " text_ = dataset_['text'].iloc[i]\n", 361 | " label_ = dataset_['label'].iloc[i] # the csv is setup so that the label column corresponds exactly to the 3 classes defined above in the 
prompt (important)\n", 362 | "\n", 363 | " text = prompt.format(text_, label_)\n", 364 | "\n", 365 | " texts.append(text)\n", 366 | " return texts\n", 367 | "\n", 368 | "# apply formatting_prompts_func to train_df\n", 369 | "train_df['text'] = formatting_prompts_func(train_df)\n", 370 | "train_dataset = datasets.Dataset.from_pandas(train_df,preserve_index=False)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 8, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "# this custom collator makes it so the model trains only on the last token of the sequence. It also maps from the old tokenizer to the new lm_head indices\n", 380 | "class DataCollatorForLastTokenLM(DataCollatorForLanguageModeling):\n", 381 | " def __init__(\n", 382 | " self,\n", 383 | " *args,\n", 384 | " mlm: bool = False,\n", 385 | " ignore_index: int = -100,\n", 386 | " **kwargs,\n", 387 | " ):\n", 388 | " super().__init__(*args, mlm=mlm, **kwargs)\n", 389 | " self.ignore_index = ignore_index\n", 390 | "\n", 391 | " def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:\n", 392 | " batch = super().torch_call(examples)\n", 393 | "\n", 394 | " for i in range(len(examples)):\n", 395 | " # Find the last non-padding token\n", 396 | " last_token_idx = (batch[\"labels\"][i] != self.ignore_index).nonzero()[-1].item()\n", 397 | " # Set all labels to ignore_index except for the last token\n", 398 | " batch[\"labels\"][i, :last_token_idx] = self.ignore_index\n", 399 | " # If the last token in the text is, for example, \"2\", then this was processed with the old tokenizer into number_token_ids[2]\n", 400 | " # But we don't actually want this because number_token_ids[2] could be something like 27, which is now undefined in the new lm_head. So we map it to the new lm_head index.\n", 401 | " # if this line gives you a KeyError, increase max_seq_length\n", 402 | " batch[\"labels\"][i, last_token_idx] = reverse_map[ batch[\"labels\"][i, last_token_idx].item() ]\n", 403 | "\n", 404 | "\n", 405 | " return batch\n", 406 | "collator = DataCollatorForLastTokenLM(tokenizer=tokenizer)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": { 412 | "id": "idAEIeSQ3xdS" 413 | }, 414 | "source": [ 415 | "\n", 416 | "### Train the model\n", 417 | "Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). This notebook trains for one full epoch; increase `num_train_epochs` or set `max_steps` if you want a longer or shorter run. We also support TRL's `DPOTrainer`!" 
418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 9, 423 | "metadata": { 424 | "colab": { 425 | "base_uri": "https://localhost:8080/", 426 | "height": 177, 427 | "referenced_widgets": [ 428 | "477d5041a08f4a3a9f7cfa1d98ab48ff", 429 | "25eecdf89b8845a989ce2c8c4d9edbeb", 430 | "e02470ee64ad4f0fb2160b088cdcba7f", 431 | "b7c0858a80684aed9c19ce00dda85815", 432 | "d0b82646d6f549d7ad59e9ff5f426e88", 433 | "d94a8d1f9d294abbb156e39da065910f", 434 | "ebb0f92e750447a39ff23a23ea73b445", 435 | "d4d01dc3290b4174ab13454a28faf972", 436 | "c6941f8c60ef49a8a79ed265c46652c2", 437 | "24aa300026334db1b545bf0fa906112e", 438 | "6d0d021108b54147a62e50fea1ba88ea", 439 | "deb2982b19764ed5aec0b4d80e776279", 440 | "c24904c0a7294f3a93bb64aa38e70316", 441 | "26a77ab74a4a4e21b9afd9798a9f9a29", 442 | "7189bac8d0474bcea50cf8711259516c", 443 | "ce81350896a44331aa6b6960b7370325", 444 | "73d4af57ddc64fb3afd0e2ab068cbcb4", 445 | "672d990adee44df6b58e27fe5804986f", 446 | "3fe95fc9bc034a2db85ed19aedfd4250", 447 | "c1c0053b8e674ed6ac81316d4af82c48", 448 | "74a0e53405cd4d64bd597ad5461ed5dd", 449 | "278d35f6e08d45e2a49c3509dd442f0c", 450 | "b19d83c6f5ad4f04941e6a684678ac07", 451 | "88e100a175e64a3dab6f772847deffd6", 452 | "457b5df3a4294c8a966e233d34c15a82", 453 | "18af2c44a24b4aaabfd192e3fc4bf655", 454 | "61944d9473394be5b19cfd48fd504481", 455 | "68d18369acc540a594aef61a2de07e63", 456 | "69c676782e0b496b820b55c45424177d", 457 | "08dcaabc623c4759a00fbee5430e3ba9", 458 | "45a3bcaffc3f4184b9ae869f7b0ccef4", 459 | "c45f02fbb40e4041ba7f95074f8e1e82", 460 | "0fc1037a17e541eba92a2d6f400ac6eb", 461 | "7bd2cc9aa724408fa9e70795744cad85", 462 | "fb0588bffd7a4238bac3f73052e09335", 463 | "68ff2006b3794e23b4ccbc2c83c52b9b", 464 | "910f2e6fffd24cd7b6e68c932c2a2524", 465 | "6345dc6f40c6444aa05ed2aba7809a3c", 466 | "92eab7fadbed4bc2bc2935b4e3d800cf", 467 | "2dfef2cd79c8463cb7eadad6a04aba91", 468 | "3be5c37493e742aebf0aa29b6723283b", 469 | "273be47263384901b6cf9f249ee3409d", 470 | "2ff515b72bbb43c889e2172c83933803", 471 | "a1cd830a712d490181d176267ce7b6f0", 472 | "1a8da471604841ec9f6c8e0073471c00", 473 | "f899ae16708544379d63960be07b7c32", 474 | "0bdbf9b92f7b4925a5ac28df93fd0fb0", 475 | "c1d8820a789f4899a839e6d28a4f333c", 476 | "e711d7f85eee4fe195fad9fbddcfece2", 477 | "56ed5fd876d94ebbaa8b4e905c348a0d", 478 | "5d461f10bdc44c6b95d070fa9d7425d1", 479 | "f4e0e7a39ad9484f930e6673c35c2f5d", 480 | "baad9118cade4cee9600a7fbf6426e0a", 481 | "16fac1aad22444c4ad42ad0e6e1dfdc9", 482 | "ac35368de2d746b4bed736459d187dbd" 483 | ] 484 | }, 485 | "id": "95_Nn-89DhsL", 486 | "outputId": "adb8cb5d-0ec3-4b79-83a7-5691873873e8" 487 | }, 488 | "outputs": [ 489 | { 490 | "data": { 491 | "application/vnd.jupyter.widget-view+json": { 492 | "model_id": "b03dbbeaeb214c5e9ccc6e6264bd69c2", 493 | "version_major": 2, 494 | "version_minor": 0 495 | }, 496 | "text/plain": [ 497 | "Map: 0%| | 0/3893 [00:00\n", 593 | " \n", 594 | " \n", 595 | " [122/122 04:36, Epoch 1/1]\n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 
| " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " 
\n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | "
StepTraining Loss
11.357800
21.182200
31.442300
41.152100
51.033100
60.988200
70.841600
81.047500
90.967200
100.908900
110.850100
120.800500
130.602900
140.828200
150.662600
160.689000
170.578200
180.731100
190.568800
200.467700
211.106200
220.909300
230.731900
240.800000
250.636100
260.745200
270.816300
280.698400
290.443900
300.481700
310.577100
320.610900
330.405500
340.698300
350.443000
360.538800
370.426200
380.404100
390.754100
400.349500
410.663100
420.372800
430.407400
440.422400
450.392000
460.458300
470.468700
480.627400
490.364700
500.288000
510.350700
520.255100
530.335900
540.333300
550.299500
560.383600
570.552900
580.114900
590.531400
600.441900
610.586900
620.826700
630.425800
640.369000
650.443500
660.684300
670.519100
680.437900
690.525900
700.226600
710.264700
720.378600
730.392200
740.271700
750.177400
760.299900
770.145200
780.204400
790.361200
800.257100
810.214300
820.532200
830.573500
840.183000
850.089900
860.127200
870.360300
880.415400
890.389200
900.539400
910.322300
920.638500
930.321500
940.411100
950.489000
960.379200
970.321600
980.359100
990.347800
1000.617300
1010.342400
1020.196300
1030.526400
1040.291300
1050.421600
1060.148100
1070.565300
1080.308900
1090.465800
1100.193000
1110.124000
1120.217600
1130.191400
1140.241700
1150.166500
1160.155000
1170.382100
1180.211400
1190.385200
1200.235100
1210.396300
1220.161100

" 1095 | ], 1096 | "text/plain": [ 1097 | "" 1098 | ] 1099 | }, 1100 | "metadata": {}, 1101 | "output_type": "display_data" 1102 | }, 1103 | { 1104 | "name": "stdout", 1105 | "output_type": "stream", 1106 | "text": [ 1107 | "Unsloth: Will smartly offload gradients to save VRAM!\n" 1108 | ] 1109 | } 1110 | ], 1111 | "source": [ 1112 | "trainer_stats = trainer.train()" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "code", 1117 | "execution_count": 12, 1118 | "metadata": { 1119 | "cellView": "form", 1120 | "colab": { 1121 | "base_uri": "https://localhost:8080/" 1122 | }, 1123 | "id": "pCqnaKmlO1U9", 1124 | "outputId": "ff1b0842-5966-4dc2-bd98-c20832526b31" 1125 | }, 1126 | "outputs": [ 1127 | { 1128 | "name": "stdout", 1129 | "output_type": "stream", 1130 | "text": [ 1131 | "286.2905 seconds used for training.\n", 1132 | "4.77 minutes used for training.\n", 1133 | "Peak reserved memory = 9.082 GB.\n", 1134 | "Peak reserved memory for training = 0.672 GB.\n", 1135 | "Peak reserved memory % of max memory = 37.843 %.\n", 1136 | "Peak reserved memory for training % of max memory = 2.8 %.\n" 1137 | ] 1138 | } 1139 | ], 1140 | "source": [ 1141 | "#@title Show final memory and time stats\n", 1142 | "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", 1143 | "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", 1144 | "used_percentage = round(used_memory /max_memory*100, 3)\n", 1145 | "lora_percentage = round(used_memory_for_lora/max_memory*100, 3)\n", 1146 | "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", 1147 | "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", 1148 | "print(f\"Peak reserved memory = {used_memory} GB.\")\n", 1149 | "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", 1150 | "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", 1151 | "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" 1152 | ] 1153 | }, 1154 | { 1155 | "cell_type": "markdown", 1156 | "metadata": { 1157 | "id": "ekOmTR1hSNcr" 1158 | }, 1159 | "source": [ 1160 | "\n", 1161 | "### Inference\n", 1162 | "This part evaluates the model on the val set with batched inference" 1163 | ] 1164 | }, 1165 | { 1166 | "cell_type": "code", 1167 | "execution_count": 13, 1168 | "metadata": {}, 1169 | "outputs": [ 1170 | { 1171 | "name": "stdout", 1172 | "output_type": "stream", 1173 | "text": [ 1174 | "\n" 1175 | ] 1176 | } 1177 | ], 1178 | "source": [ 1179 | "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n", 1180 | "print()" 1181 | ] 1182 | }, 1183 | { 1184 | "cell_type": "markdown", 1185 | "metadata": {}, 1186 | "source": [ 1187 | "### remake the old lm_head but with unused tokens having -1000 bias and 0 weights (improves compatibility with libraries like vllm)" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "code", 1192 | "execution_count": 14, 1193 | "metadata": {}, 1194 | "outputs": [ 1195 | { 1196 | "name": "stdout", 1197 | "output_type": "stream", 1198 | "text": [ 1199 | "Remade lm_head: shape = torch.Size([151936, 2560]). 
Allowed tokens: [15, 16, 17, 18]\n" 1200 | ] 1201 | } 1202 | ], 1203 | "source": [ 1204 | "# Save the current (trimmed) lm_head and bias\n", 1205 | "trimmed_lm_head = model.lm_head.weight.data.clone()\n", 1206 | "trimmed_lm_head_bias = model.lm_head.bias.data.clone() if hasattr(model.lm_head, \"bias\") and model.lm_head.bias is not None else torch.zeros(len(number_token_ids), device=trimmed_lm_head.device)\n", 1207 | "\n", 1208 | "# Create a new lm_head with shape [old_size, hidden_dim]\n", 1209 | "hidden_dim = trimmed_lm_head.shape[1]\n", 1210 | "new_lm_head = torch.full((old_size, hidden_dim), 0, dtype=trimmed_lm_head.dtype, device=trimmed_lm_head.device)\n", 1211 | "new_lm_head_bias = torch.full((old_size,), -1000.0, dtype=trimmed_lm_head_bias.dtype, device=trimmed_lm_head_bias.device)\n", 1212 | "\n", 1213 | "# Fill in the weights and bias for the allowed tokens (number_token_ids)\n", 1214 | "for new_idx, orig_token_id in enumerate(number_token_ids):\n", 1215 | " new_lm_head[orig_token_id] = trimmed_lm_head[new_idx]\n", 1216 | " new_lm_head_bias[orig_token_id] = trimmed_lm_head_bias[new_idx]\n", 1217 | "\n", 1218 | "# Update the model's lm_head weight and bias\n", 1219 | "with torch.no_grad():\n", 1220 | " new_lm_head_module = torch.nn.Linear(hidden_dim, old_size, bias=True, device=model.device)\n", 1221 | " new_lm_head_module.weight.data.copy_(new_lm_head)\n", 1222 | " new_lm_head_module.bias.data.copy_(new_lm_head_bias)\n", 1223 | " model.lm_head.modules_to_save[\"default\"] = new_lm_head_module\n", 1224 | "\n", 1225 | "print(f\"Remade lm_head: shape = {model.lm_head.weight.shape}. Allowed tokens: {number_token_ids}\")" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "markdown", 1230 | "metadata": {}, 1231 | "source": [ 1232 | "# Batched Inference on Validation Set" 1233 | ] 1234 | }, 1235 | { 1236 | "cell_type": "code", 1237 | "execution_count": 17, 1238 | "metadata": {}, 1239 | "outputs": [ 1240 | { 1241 | "name": "stderr", 1242 | "output_type": "stream", 1243 | "text": [ 1244 | "Evaluating: 100%|██████████| 28/28 [00:16<00:00, 1.71it/s]" 1245 | ] 1246 | }, 1247 | { 1248 | "name": "stdout", 1249 | "output_type": "stream", 1250 | "text": [ 1251 | "\n", 1252 | "Validation accuracy: 83.37% (361/433)\n", 1253 | "\n", 1254 | "--- Random samples ---\n", 1255 | "\n", 1256 | "Text: $LEVI - Levi Strauss EPS beats by $0.05, beats on revenue https://t.co/UfIyY92vC5\n", 1257 | "True: 2 Pred: 2 ✅\n", 1258 | "Probs: 1: 0.005, 2: 0.992, 3: 0.003\n", 1259 | "\n", 1260 | "Text: U.S. officials are signaling the European Union might be an easier target than the U.K. 
for a quick outcome aimed a… https://t.co/c5D7GJJbPf\n", 1261 | "True: 1 Pred: 1 ✅\n", 1262 | "Probs: 1: 0.935, 2: 0.028, 3: 0.036\n", 1263 | "\n", 1264 | "Text: Some odd divergences in the past two weeks as most asset classes except stocks are reversing recent gains https://t.co/0XJwxOyhBA\n", 1265 | "True: 1 Pred: 3 ❌\n", 1266 | "Probs: 1: 0.066, 2: 0.027, 3: 0.907\n", 1267 | "\n", 1268 | "Text: Some officials worried about the bank's decision to drop self-imposed limits on the ECB's bond purchases https://t.co/elfM7nRACS\n", 1269 | "True: 3 Pred: 1 ❌\n", 1270 | "Probs: 1: 0.715, 2: 0.022, 3: 0.263\n", 1271 | "\n", 1272 | "Text: Is Qorvo, Inc.'s (NASDAQ:QRVO) 5.9% ROE Worse Than Average?\n", 1273 | "True: 1 Pred: 1 ✅\n", 1274 | "Probs: 1: 0.762, 2: 0.020, 3: 0.218\n", 1275 | "\n", 1276 | "Text: Here's the level to watch in Texaco as stock falls on London ban $UBER (via @TradingNation) https://t.co/fk3KnVRU95\n", 1277 | "True: 3 Pred: 3 ✅\n", 1278 | "Probs: 1: 0.067, 2: 0.009, 3: 0.924\n", 1279 | "\n", 1280 | "Text: $INPX is gaining momentum......\n", 1281 | "True: 2 Pred: 2 ✅\n", 1282 | "Probs: 1: 0.473, 2: 0.503, 3: 0.024\n", 1283 | "\n", 1284 | "Text: First Coronavirus Case Reported In New York https://t.co/Cy1yurkH6a\n", 1285 | "True: 1 Pred: 1 ✅\n", 1286 | "Probs: 1: 0.523, 2: 0.016, 3: 0.461\n", 1287 | "\n", 1288 | "Text: $MCIG - MCig up 9% on CBD distribution deals https://t.co/WV5yrJFdyD\n", 1289 | "True: 2 Pred: 2 ✅\n", 1290 | "Probs: 1: 0.242, 2: 0.745, 3: 0.014\n", 1291 | "\n", 1292 | "Text: Defence Expo: Defence Equipment Makers See This As Another Opportunity In India\n", 1293 | "True: 1 Pred: 1 ✅\n", 1294 | "Probs: 1: 0.889, 2: 0.106, 3: 0.005\n", 1295 | "\n", 1296 | "Text: 'We know that there will be very likely some effects on the United States', said the Fed chairman Jay Powell about… https://t.co/eAeVR5GUm2\n", 1297 | "True: 1 Pred: 1 ✅\n", 1298 | "Probs: 1: 0.777, 2: 0.050, 3: 0.173\n", 1299 | "\n", 1300 | "Text: $COMDX: Natural gas inventory showed a draw of 201 bcf vs a 92 bcf draw last week https://t.co/CGfYWf1Unq\n", 1301 | "True: 2 Pred: 1 ❌\n", 1302 | "Probs: 1: 0.743, 2: 0.188, 3: 0.069\n", 1303 | "\n", 1304 | "Text: Axcelis Technologies stock price target raised to $32 from $24 at Benchmark\n", 1305 | "True: 2 Pred: 2 ✅\n", 1306 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n", 1307 | "\n", 1308 | "Text: Methanex downgraded at TD Securities on valuation\n", 1309 | "True: 3 Pred: 3 ✅\n", 1310 | "Probs: 1: 0.003, 2: 0.003, 3: 0.994\n", 1311 | "\n", 1312 | "Text: Xeris Pharma launches equity offering; shares down 3% after hours\n", 1313 | "True: 3 Pred: 3 ✅\n", 1314 | "Probs: 1: 0.011, 2: 0.007, 3: 0.982\n", 1315 | "\n", 1316 | "Text: Dizzying swings are torching Wall Street predictions https://t.co/BvV6iItVRt\n", 1317 | "True: 3 Pred: 3 ✅\n", 1318 | "Probs: 1: 0.118, 2: 0.008, 3: 0.874\n", 1319 | "\n", 1320 | "Text: Home Depot cuts sales goal as online push not delivering as expected\n", 1321 | "True: 3 Pred: 3 ✅\n", 1322 | "Probs: 1: 0.009, 2: 0.005, 3: 0.987\n", 1323 | "\n", 1324 | "Text: Stocks Gain as U.S. Economy Signals Strength -- Update #SP500 #index #MarketScreener https://t.co/e9i8yV5o1I https://t.co/phbncHu6aL\n", 1325 | "True: 2 Pred: 1 ❌\n", 1326 | "Probs: 1: 0.615, 2: 0.373, 3: 0.011\n", 1327 | "\n", 1328 | "Text: Stock-index futures turn lower\n", 1329 | "True: 3 Pred: 1 ❌\n", 1330 | "Probs: 1: 0.835, 2: 0.037, 3: 0.128\n", 1331 | "\n", 1332 | "Text: Tesla Inc. 
slumped in pre-market trading after the electric carmaker’s newly unveiled pickup truck elicited mixed r… https://t.co/KDcJiSWCOp\n", 1333 | "True: 3 Pred: 3 ✅\n", 1334 | "Probs: 1: 0.037, 2: 0.004, 3: 0.958\n", 1335 | "\n", 1336 | "Text: Coronavirus reports hang over cruise line sector\n", 1337 | "True: 3 Pred: 3 ✅\n", 1338 | "Probs: 1: 0.085, 2: 0.006, 3: 0.909\n", 1339 | "\n", 1340 | "Text: U.S. Xpress EPS beats by $0.04, beats on revenue\n", 1341 | "True: 2 Pred: 2 ✅\n", 1342 | "Probs: 1: 0.006, 2: 0.990, 3: 0.004\n", 1343 | "\n", 1344 | "Text: Elastic +3.7% as Canaccord turns bullish\n", 1345 | "True: 2 Pred: 2 ✅\n", 1346 | "Probs: 1: 0.014, 2: 0.983, 3: 0.003\n", 1347 | "\n", 1348 | "Text: Whirlpool -2% after large recall in U.K.\n", 1349 | "True: 3 Pred: 3 ✅\n", 1350 | "Probs: 1: 0.042, 2: 0.007, 3: 0.951\n", 1351 | "\n", 1352 | "Text: Americans' outlook on the economy faltered significantly last month as the coronavirus crisis began to take hold in… https://t.co/5jeCXLXrrR\n", 1353 | "True: 3 Pred: 3 ✅\n", 1354 | "Probs: 1: 0.047, 2: 0.005, 3: 0.948\n", 1355 | "\n", 1356 | "Text: Pizza Hut's Struggling Turnaround Weighs on Yum Brands Results\n", 1357 | "True: 3 Pred: 3 ✅\n", 1358 | "Probs: 1: 0.014, 2: 0.007, 3: 0.979\n", 1359 | "\n", 1360 | "Text: FDM : AFL-CIO Endorses USMCA After Successfully Negotiating Improvements #FDM #Stock #MarketScreener… https://t.co/Ja3PJ0uKZB\n", 1361 | "True: 1 Pred: 1 ✅\n", 1362 | "Probs: 1: 0.903, 2: 0.084, 3: 0.013\n", 1363 | "\n", 1364 | "Text: LexinFintech +3.3% after loan origination view improves\n", 1365 | "True: 2 Pred: 2 ✅\n", 1366 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n", 1367 | "\n", 1368 | "Text: $BLPH - Bellerophon Therapeutics EPS misses by $0.33 https://t.co/foAfyMnyra\n", 1369 | "True: 3 Pred: 3 ✅\n", 1370 | "Probs: 1: 0.060, 2: 0.007, 3: 0.933\n", 1371 | "\n", 1372 | "Text: Wholesale trade underwhelm in December\n", 1373 | "True: 3 Pred: 3 ✅\n", 1374 | "Probs: 1: 0.060, 2: 0.007, 3: 0.933\n", 1375 | "\n", 1376 | "Text: Why Disney+ is the only service that can rival Netflix\n", 1377 | "True: 2 Pred: 1 ❌\n", 1378 | "Probs: 1: 0.978, 2: 0.014, 3: 0.008\n", 1379 | "\n", 1380 | "Text: $CTSO: CytoSorbents says temporarily pausing enrollment of REFRESH 2-AKI study at the recommendation of its Data... 
https://t.co/6ibg4NhPh1\n", 1381 | "True: 3 Pred: 1 ❌\n", 1382 | "Probs: 1: 0.519, 2: 0.023, 3: 0.458\n", 1383 | "\n", 1384 | "Text: Twitter stock falls after downgrade\n", 1385 | "True: 3 Pred: 1 ❌\n", 1386 | "Probs: 1: 0.829, 2: 0.044, 3: 0.127\n", 1387 | "\n", 1388 | "Text: LAIX beats on revenue\n", 1389 | "True: 2 Pred: 1 ❌\n", 1390 | "Probs: 1: 0.898, 2: 0.069, 3: 0.033\n", 1391 | "\n", 1392 | "Text: Coty +5% after striking Kylie Jenner deal\n", 1393 | "True: 2 Pred: 2 ✅\n", 1394 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n", 1395 | "\n", 1396 | "Text: Airline stocks are higher amid some positive signs within the fight against the coronavirus pandemic https://t.co/BOK9NTJJuv\n", 1397 | "True: 2 Pred: 2 ✅\n", 1398 | "Probs: 1: 0.265, 2: 0.720, 3: 0.015\n", 1399 | "\n", 1400 | "Text: Estee Lauder Cuts Profit Outlook Again, Citing Coronavirus\n", 1401 | "True: 3 Pred: 3 ✅\n", 1402 | "Probs: 1: 0.018, 2: 0.002, 3: 0.980\n", 1403 | "\n", 1404 | "Text: H&P downgraded at Argus as drilling industry weakness seen persisting\n", 1405 | "True: 3 Pred: 3 ✅\n", 1406 | "Probs: 1: 0.006, 2: 0.004, 3: 0.990\n", 1407 | "\n", 1408 | "Text: Hedge Funds Aren’t Crazy About Ladder Capital Corp (LADR) Anymore\n", 1409 | "True: 3 Pred: 3 ✅\n", 1410 | "Probs: 1: 0.093, 2: 0.027, 3: 0.881\n", 1411 | "\n", 1412 | "Text: RPM International stock price target raised to $79 vs. $75 at BofA Merrill Lynch\n", 1413 | "True: 2 Pred: 2 ✅\n", 1414 | "Probs: 1: 0.018, 2: 0.978, 3: 0.004\n", 1415 | "\n", 1416 | "Text: PG&E boss says it wasn’t fully ready for California outages\n", 1417 | "True: 1 Pred: 3 ❌\n", 1418 | "Probs: 1: 0.416, 2: 0.050, 3: 0.534\n", 1419 | "\n", 1420 | "Text: The euro-area economy came close to a halt in November as the steep decline in manufacturing spread further into se… https://t.co/PKMjM0loEx\n", 1421 | "True: 3 Pred: 3 ✅\n", 1422 | "Probs: 1: 0.037, 2: 0.004, 3: 0.958\n", 1423 | "\n", 1424 | "Text: Exxon, Chevron results augur tough year ahead, shares drop 3% #economy #MarketScreener https://t.co/sABok0wweo https://t.co/BYDIfnQivg\n", 1425 | "True: 3 Pred: 1 ❌\n", 1426 | "Probs: 1: 0.801, 2: 0.123, 3: 0.075\n", 1427 | "\n", 1428 | "Text: $SOYB $MOO $FTAG - Soybeans rebound as China trade talks make progress https://t.co/1mf2YHIujB\n", 1429 | "True: 2 Pred: 2 ✅\n", 1430 | "Probs: 1: 0.002, 2: 0.996, 3: 0.002\n", 1431 | "\n", 1432 | "Text: eDreams ODIGEO S.A. reports Q2 results\n", 1433 | "True: 1 Pred: 1 ✅\n", 1434 | "Probs: 1: 0.989, 2: 0.005, 3: 0.006\n", 1435 | "\n", 1436 | "Text: Welcome back, Wall Street! @JimCramer and @byKatherineRoss are breaking down all the latest on the markets and the… https://t.co/lzJSAKAv7s\n", 1437 | "True: 1 Pred: 1 ✅\n", 1438 | "Probs: 1: 0.877, 2: 0.119, 3: 0.005\n", 1439 | "\n", 1440 | "Text: U.S. 
Job Report Looks Likely to Show Hot 2020 Start, Cooler Past\n", 1441 | "True: 2 Pred: 1 ❌\n", 1442 | "Probs: 1: 0.698, 2: 0.291, 3: 0.011\n", 1443 | "\n", 1444 | "Text: Nigeria's petroleum bill to be passed by mid-2020, says oil minister #economy #MarketScreener… https://t.co/iTMZMP6wLO\n", 1445 | "True: 1 Pred: 1 ✅\n", 1446 | "Probs: 1: 0.788, 2: 0.199, 3: 0.013\n", 1447 | "\n", 1448 | "Text: $ECONX: Nonfarm Payroll Revisions- October revised to 156K from 128K; September revised to 193K from 180K https://t.co/WOg237QKLC\n", 1449 | "True: 2 Pred: 1 ❌\n", 1450 | "Probs: 1: 0.753, 2: 0.131, 3: 0.116\n", 1451 | "\n", 1452 | "Text: WhatsApp Closes In On Full Fledged UPI Payments Launch\n", 1453 | "True: 1 Pred: 1 ✅\n", 1454 | "Probs: 1: 0.826, 2: 0.163, 3: 0.012\n" 1455 | ] 1456 | }, 1457 | { 1458 | "name": "stderr", 1459 | "output_type": "stream", 1460 | "text": [ 1461 | "\n" 1462 | ] 1463 | } 1464 | ], 1465 | "source": [ 1466 | "import torch\n", 1467 | "import torch.nn.functional as F\n", 1468 | "from tqdm import tqdm\n", 1469 | "import random\n", 1470 | "\n", 1471 | "# Prepare inference prompt\n", 1472 | "inference_prompt_template = prompt.split(\"class {}\")[0] + \"class \"\n", 1473 | "\n", 1474 | "# Sort validation set by length for efficient batching\n", 1475 | "val_df['token_length'] = val_df['text'].apply(lambda x: len(tokenizer.encode(x, add_special_tokens=False)))\n", 1476 | "val_df_sorted = val_df.sort_values(by='token_length').reset_index(drop=True)\n", 1477 | "\n", 1478 | "display = 50\n", 1479 | "batch_size = 16\n", 1480 | "device = model.device\n", 1481 | "correct = 0\n", 1482 | "results = []\n", 1483 | "\n", 1484 | "with torch.inference_mode():\n", 1485 | " for i in tqdm(range(0, len(val_df_sorted), batch_size), desc=\"Evaluating\"):\n", 1486 | " batch = val_df_sorted.iloc[i:i+batch_size]\n", 1487 | " prompts = [inference_prompt_template.format(text) for text in batch['text']]\n", 1488 | " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_seq_length).to(device)\n", 1489 | " logits = model(**inputs).logits\n", 1490 | " last_idxs = inputs.attention_mask.sum(1) - 1\n", 1491 | " last_logits = logits[torch.arange(len(batch)), last_idxs, :]\n", 1492 | " probs_all = F.softmax(last_logits, dim=-1)\n", 1493 | " probs = probs_all[:, number_token_ids] # only keep the logits for the number tokens\n", 1494 | " preds = torch.argmax(probs, dim=-1).cpu().numpy() # looks like [1 1 1 1 3 1 3 1 3 1 1 1 1 2 2 3]\n", 1495 | "\n", 1496 | " true_labels = batch['label'].tolist()\n", 1497 | " correct += sum([p == t for p, t in zip(preds, true_labels)])\n", 1498 | " # Store a few samples for display\n", 1499 | " for j in range(len(batch)):\n", 1500 | " results.append({\n", 1501 | " \"text\": batch['text'].iloc[j][:200],\n", 1502 | " \"true\": true_labels[j],\n", 1503 | " \"pred\": preds[j],\n", 1504 | " \"probs\": probs[j][1:].float().cpu().numpy(), # ignore prob for class 0 and convert from tensor to float\n", 1505 | " \"ok\": preds[j] == true_labels[j]\n", 1506 | " })\n", 1507 | "\n", 1508 | "accuracy = 100 * correct / len(val_df_sorted)\n", 1509 | "print(f\"\\nValidation accuracy: {accuracy:.2f}% ({correct}/{len(val_df_sorted)})\")\n", 1510 | "\n", 1511 | "print(\"\\n--- Random samples ---\")\n", 1512 | "for s in random.sample(results, min(display, len(results))):\n", 1513 | " print(f\"\\nText: {s['text']}\")\n", 1514 | " print(f\"True: {s['true']} Pred: {s['pred']} {'✅' if s['ok'] else '❌'}\")\n", 1515 | " print(\"Probs:\", \", \".join([f\"{k}: 
{v:.3f}\" for k, v in enumerate(s['probs'], start=1)]))\n", 1516 | "\n", 1517 | "# Clean up\n", 1518 | "if 'token_length' in val_df:\n", 1519 | " del val_df['token_length']" 1520 | ] 1521 | }, 1522 | { 1523 | "cell_type": "code", 1524 | "execution_count": 16, 1525 | "metadata": {}, 1526 | "outputs": [ 1527 | { 1528 | "ename": "ZeroDivisionError", 1529 | "evalue": "division by zero", 1530 | "output_type": "error", 1531 | "traceback": [ 1532 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 1533 | "\u001b[1;31mZeroDivisionError\u001b[0m Traceback (most recent call last)", 1534 | "Cell \u001b[1;32mIn[16], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# stop running all cells\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;241m1\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m0\u001b[39m\n", 1535 | "\u001b[1;31mZeroDivisionError\u001b[0m: division by zero" 1536 | ] 1537 | } 1538 | ], 1539 | "source": [ 1540 | "# stop running all cells\n", 1541 | "1/0" 1542 | ] 1543 | }, 1544 | { 1545 | "cell_type": "markdown", 1546 | "metadata": {}, 1547 | "source": [ 1548 | "Now if you closed the notebook kernel and want to reload the model:" 1549 | ] 1550 | }, 1551 | { 1552 | "cell_type": "code", 1553 | "execution_count": null, 1554 | "metadata": {}, 1555 | "outputs": [ 1556 | { 1557 | "name": "stdout", 1558 | "output_type": "stream", 1559 | "text": [ 1560 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", 1561 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n" 1562 | ] 1563 | }, 1564 | { 1565 | "name": "stderr", 1566 | "output_type": "stream", 1567 | "text": [ 1568 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n", 1569 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n" 1570 | ] 1571 | }, 1572 | { 1573 | "name": "stdout", 1574 | "output_type": "stream", 1575 | "text": [ 1576 | "==((====))== Unsloth 2025.4.5: Fast Qwen3 patching. Transformers: 4.51.3.\n", 1577 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n", 1578 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n", 1579 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n", 1580 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", 1581 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" 1582 | ] 1583 | }, 1584 | { 1585 | "data": { 1586 | "application/vnd.jupyter.widget-view+json": { 1587 | "model_id": "73992b1f949448ce9981a732af2d7b66", 1588 | "version_major": 2, 1589 | "version_minor": 0 1590 | }, 1591 | "text/plain": [ 1592 | "Loading checkpoint shards: 0%| | 0/2 [00:00\n", 1732 | " \n", 1733 | " \n", 1734 | " Support our work if you can! 
Thanks!\n", 1735 | "" 1736 | ] 1737 | } 1738 | ], 1739 | "metadata": { 1740 | "accelerator": "GPU", 1741 | "colab": { 1742 | "gpuType": "T4", 1743 | "provenance": [] 1744 | }, 1745 | "kaggle": { 1746 | "accelerator": "gpu", 1747 | "dataSources": [ 1748 | { 1749 | "datasetId": 5081962, 1750 | "sourceId": 8512897, 1751 | "sourceType": "datasetVersion" 1752 | } 1753 | ], 1754 | "dockerImageVersionId": 30733, 1755 | "isGpuEnabled": true, 1756 | "isInternetEnabled": true, 1757 | "language": "python", 1758 | "sourceType": "notebook" 1759 | }, 1760 | "kernelspec": { 1761 | "display_name": "Python 3 (ipykernel)", 1762 | "language": "python", 1763 | "name": "python3" 1764 | }, 1765 | "language_info": { 1766 | "codemirror_mode": { 1767 | "name": "ipython", 1768 | "version": 3 1769 | }, 1770 | "file_extension": ".py", 1771 | "mimetype": "text/x-python", 1772 | "name": "python", 1773 | "nbconvert_exporter": "python", 1774 | "pygments_lexer": "ipython3", 1775 | "version": "3.12.3" 1776 | } 1777 | }, 1778 | "nbformat": 4, 1779 | "nbformat_minor": 4 1780 | } 1781 | --------------------------------------------------------------------------------
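A markdown cell near the end of unsloth_classification.ipynb says "Now if you closed the notebook kernel and want to reload the model:", but the source of the cell that follows it is not reproduced in the dump above (only its Unsloth startup logs are). As a minimal, hedged sketch of what reloading a locally saved Unsloth LoRA adapter typically looks like — the `lora_model` directory name, `max_seq_length`, and `load_in_4bit` values below are illustrative assumptions, not the notebook's exact cell:

```python
# Hedged sketch: reload a LoRA adapter previously saved with model.save_pretrained("lora_model").
# The path and hyperparameters are assumptions for illustration only.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="lora_model",   # hypothetical local adapter directory
    max_seq_length=2048,       # assumption: should match the value used during training
    dtype=None,                # let Unsloth auto-detect (bfloat16 on Ampere GPUs such as the RTX 3090)
    load_in_4bit=True,         # assumption: the base model was loaded in 4-bit during training
)

FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference mode before generating
```

Note that the evaluation cell shown earlier also depends on variables defined elsewhere in the notebook (for example `prompt`, `number_token_ids`, and `val_df`), so those cells would need to be re-run as well after a kernel restart.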