├── .gitignore
├── README.md
├── bert_classification.ipynb
├── data
│   ├── finance_sentiment.csv
│   └── finance_sentiment_multiclass.csv
└── unsloth_classification.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | dontcommit.ipynb
164 | _unsloth_temporary_saved_buffers/unsloth/Qwen2-0.5B-bnb-4bit/output_embeddings.pt
165 | answerdotai/ModernBERT-large/tokenizer_config.json
166 | answerdotai/ModernBERT-large/tokenizer.json
167 | answerdotai/ModernBERT-large/special_tokens_map.json
168 | answerdotai/ModernBERT-large/model.safetensors
169 | answerdotai/ModernBERT-large/config.json
170 | _unsloth_temporary_saved_buffers/Qwen3-4B-Base/output_embeddings.pt
171 | _unsloth_temporary_saved_buffers/Qwen3-0.6B-Base/output_embeddings.pt
172 | _unsloth_temporary_saved_buffers/Qwen2-1.5B-bnb-4bit/output_embeddings.pt
173 | data/orig.csv
174 | data/process.ipynb
175 | lora_model_Qwen2-1.5B-bnb-4bit/adapter_config.json
176 | lora_model_Qwen2-1.5B-bnb-4bit/adapter_model.safetensors
177 | lora_model_Qwen2-1.5B-bnb-4bit/README.md
178 | lora_model_Qwen3-0.6B-Base/adapter_config.json
179 | lora_model_Qwen3-0.6B-Base/adapter_model.safetensors
180 | lora_model_Qwen3-0.6B-Base/README.md
181 | lora_model_Qwen3-4B-Base/adapter_config.json
182 | lora_model_Qwen3-4B-Base/adapter_model.safetensors
183 | lora_model_Qwen3-4B-Base/README.md
184 | outputs/*
185 | Qwen2-1.5B-bnb-4bit/added_tokens.json
186 | Qwen2-1.5B-bnb-4bit/config.json
187 | Qwen2-1.5B-bnb-4bit/generation_config.json
188 | Qwen2-1.5B-bnb-4bit/gitattributes
189 | Qwen2-1.5B-bnb-4bit/merges.txt
190 | Qwen2-1.5B-bnb-4bit/model.safetensors
191 | Qwen2-1.5B-bnb-4bit/README.md
192 | Qwen2-1.5B-bnb-4bit/special_tokens_map.json
193 | Qwen2-1.5B-bnb-4bit/tokenizer.json
194 | Qwen2-1.5B-bnb-4bit/tokenizer_config.json
195 | Qwen2-1.5B-bnb-4bit/vocab.json
196 | Qwen3-0.6B-Base/added_tokens.json
197 | Qwen3-0.6B-Base/config.json
198 | Qwen3-0.6B-Base/generation_config.json
199 | Qwen3-0.6B-Base/merges.txt
200 | Qwen3-0.6B-Base/model.safetensors
201 | Qwen3-0.6B-Base/special_tokens_map.json
202 | Qwen3-0.6B-Base/tokenizer.json
203 | Qwen3-0.6B-Base/tokenizer_config.json
204 | Qwen3-0.6B-Base/vocab.json
205 | Qwen3-4B-Base/added_tokens.json
206 | Qwen3-4B-Base/config.json
207 | Qwen3-4B-Base/generation_config.json
208 | Qwen3-4B-Base/merges.txt
209 | Qwen3-4B-Base/model-00001-of-00002.safetensors
210 | Qwen3-4B-Base/model-00002-of-00002.safetensors
211 | Qwen3-4B-Base/model.safetensors.index.json
212 | Qwen3-4B-Base/special_tokens_map.json
213 | Qwen3-4B-Base/tokenizer.json
214 | Qwen3-4B-Base/tokenizer_config.json
215 | Qwen3-4B-Base/vocab.json
216 | unsloth_classification copy.ipynb
217 | unsloth_classification_broken.ipynb
218 | unsloth_compiled_cache/*
219 | trainer_output/*
220 | bert_classification_crash_test.ipynb
221 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text classification scripts
2 |
3 | ## unsloth_classification.ipynb
4 |
5 | This modified Unsloth notebook trains an LLM (Qwen3-4B-Base by default) on any text classification dataset, where the input is a CSV with columns "text" and "label".
6 |
7 | ### Added features:
8 |
9 | - Trims the classification head so it contains only the class-label tokens (e.g. "1", "2", "3"), which saves 1 GB of VRAM, allows you to train the head without massive memory usage, and makes the start of the training session more stable (a minimal sketch is shown below)
10 | - Only the last token in the sequence contributes to the loss, so the model doesn't waste its capacity by trying to predict the input
11 | - Includes "group_by_length = True", which speeds up training significantly for unbalanced sequence lengths
12 | - Efficiently evaluates the accuracy on the validation set using batched inference
13 |
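Below is a minimal sketch of the head-trimming idea, mirroring the code in the notebook. It assumes each class-label digit ("0", "1", ...) encodes to a single token; treat it as an illustration rather than a drop-in cell.

```python
import torch
from unsloth import FastLanguageModel

NUM_CLASSES = 3  # number of classes in the csv

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Base",
    max_seq_length=2048,
    load_in_4bit=False,
)

# Token ids of the digit labels "0".."NUM_CLASSES" (assumes each digit is a single token).
number_token_ids = [
    tokenizer.encode(str(i), add_special_tokens=False)[0]
    for i in range(NUM_CLASSES + 1)
]

# Keep only those rows of lm_head: the (vocab_size x hidden) head shrinks to
# (NUM_CLASSES + 1) x hidden, so the model can only emit the label tokens.
model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight[number_token_ids, :])

# Old vocabulary id -> new row index; the custom data collator uses this to rewrite
# the final label token of each sequence (the only position that contributes to the loss).
reverse_map = {old_id: new_id for new_id, old_id in enumerate(number_token_ids)}
```

After training, the notebook rebuilds the head back to its original vocabulary size, so the saved model stays compatible with external libraries.
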
14 | ## bert_classification.ipynb
15 |
16 | This notebook can be used to train any BERT-style model on any text classification dataset (same format as above). It also sets "group_by_length = True", which is not commonly found in BERT training notebooks (they usually tokenize everything ahead of time with a lot of wasteful padding); a minimal sketch is shown below.
17 |
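For reference, here is a minimal sketch of that setup. It assumes `model` and `tokenizer` were already loaded as in the notebook (ModernBERT through Unsloth's `FastModel` with `AutoModelForSequenceClassification`), uses plain integer labels instead of the notebook's one-hot encoding, and passes `processing_class` instead of the deprecated `tokenizer` argument:

```python
import pandas as pd
from datasets import Dataset
from transformers import Trainer, TrainingArguments

# columns are text,label with labels 1..3; the Trainer expects 0-based integer labels
df = pd.read_csv("data/finance_sentiment_multiclass.csv")
df["labels"] = df["label"] - 1
dataset = Dataset.from_pandas(df[["text", "labels"]], preserve_index=False)

def tokenize_function(examples):
    # No padding here: the Trainer's default collator pads each batch on the fly,
    # and group_by_length puts similar lengths in the same batch to minimize padding.
    return tokenizer(examples["text"])

dataset = dataset.map(tokenize_function, batched=True)

trainer = Trainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=TrainingArguments(
        output_dir="trainer_output",
        per_device_train_batch_size=32,
        num_train_epochs=3,   # bert-style models usually need more than 1 epoch
        group_by_length=True,
        report_to="none",
    ),
)
trainer.train()
```

Because nothing is padded up front, batches built from similar-length samples stay short, which is where most of the speedup comes from.
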
--------------------------------------------------------------------------------
/bert_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "env: UNSLOTH_COMPILE_DISABLE=1\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 |     "# work around an Unsloth bug by disabling its compilation step\n",
18 | "%env UNSLOTH_COMPILE_DISABLE = 1"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
31 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n"
32 | ]
33 | },
34 | {
35 | "name": "stderr",
36 | "output_type": "stream",
37 | "text": [
38 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n",
39 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n"
40 | ]
41 | },
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "==((====))== Unsloth 2025.4.5: Fast Modernbert patching. Transformers: 4.51.3.\n",
47 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n",
48 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
49 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
50 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
51 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
52 | "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
53 | ]
54 | },
55 | {
56 | "name": "stderr",
57 | "output_type": "stream",
58 | "text": [
59 | "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
61 | ]
62 | },
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "model parameters:395834371\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "from unsloth import FastLanguageModel, FastModel\n",
73 | "import pandas as pd\n",
74 | "import numpy as np\n",
75 | "from sklearn.model_selection import train_test_split\n",
76 | "from sklearn.metrics import accuracy_score\n",
77 | "import os\n",
78 | "import torch\n",
79 | "from torch import tensor\n",
80 | "import torch.nn.functional as F\n",
81 | "from transformers import TrainingArguments, Trainer, ModernBertModel, AutoModelForSequenceClassification, training_args\n",
82 | "from datasets import load_dataset, Dataset\n",
83 | "from tqdm import tqdm\n",
84 | "\n",
85 | "model_name = 'answerdotai/ModernBERT-large'\n",
86 | "\n",
87 | "NUM_CLASSES = 3\n",
88 | "DATA_DIR = \"data/\"\n",
89 | "\n",
90 | "model, tokenizer = FastModel.from_pretrained(\n",
91 | " model_name = model_name,load_in_4bit = False,\n",
92 | " max_seq_length = 2048,\n",
93 | " dtype = None,\n",
94 | " auto_model = AutoModelForSequenceClassification,\n",
95 | " num_labels = NUM_CLASSES,\n",
96 | ")\n",
97 | "print(\"model parameters:\" + str(sum(p.numel() for p in model.parameters())))\n",
98 | "\n",
99 | "# make all parameters trainable\n",
100 | "for param in model.parameters():\n",
101 | " param.requires_grad = True"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "The dataset can be found [here](https://github.com/timothelaborie/text_classification_scripts/blob/main/data/finance_sentiment_multiclass.csv)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 3,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "data": {
118 | "application/vnd.jupyter.widget-view+json": {
119 | "model_id": "e1dd1859a86d47f7ae5ac7e54dec68c4",
120 | "version_major": 2,
121 | "version_minor": 0
122 | },
123 | "text/plain": [
124 | "Map: 0%| | 0/3893 [00:00, ? examples/s]"
125 | ]
126 | },
127 | "metadata": {},
128 | "output_type": "display_data"
129 | },
130 | {
131 | "data": {
132 | "application/vnd.jupyter.widget-view+json": {
133 | "model_id": "6ed85ecfce1f4054ab7fc4f6135027db",
134 | "version_major": 2,
135 | "version_minor": 0
136 | },
137 | "text/plain": [
138 | "Map: 0%| | 0/433 [00:00, ? examples/s]"
139 | ]
140 | },
141 | "metadata": {},
142 | "output_type": "display_data"
143 | },
144 | {
145 | "data": {
146 | "text/plain": [
147 | "Dataset({\n",
148 | " features: ['text', 'labels', 'input_ids', 'attention_mask'],\n",
149 | " num_rows: 3893\n",
150 | "})"
151 | ]
152 | },
153 | "execution_count": 3,
154 | "metadata": {},
155 | "output_type": "execute_result"
156 | }
157 | ],
158 | "source": [
159 | "data = pd.read_csv(DATA_DIR + \"finance_sentiment_multiclass.csv\")\n",
160 | "\n",
161 | "labels = data[\"label\"].tolist()\n",
162 | "labels = [x-1 for x in labels]\n",
163 | "# convert labels to one hot vectors\n",
164 | "labels = np.eye(NUM_CLASSES)[labels]\n",
165 | "\n",
166 | "train_data,val_data, train_labels, val_labels = train_test_split(data[\"text\"], labels, test_size=0.1, random_state=42)\n",
167 | "dataset = Dataset.from_list([{'text': text, 'labels': label} for text, label in zip(train_data, train_labels)])\n",
168 | "val_dataset = Dataset.from_list([{'text': text, 'labels': label} for text, label in zip(val_data, val_labels)])\n",
169 | "\n",
170 | "def tokenize_function(examples):\n",
171 | " return tokenizer(examples['text'])\n",
172 | "\n",
173 | "dataset = dataset.map(tokenize_function, batched=True)\n",
174 | "val_dataset = val_dataset.map(tokenize_function, batched=True)\n",
175 | "dataset"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 4,
181 | "metadata": {},
182 | "outputs": [
183 | {
184 | "name": "stderr",
185 | "output_type": "stream",
186 | "text": [
187 | "C:\\Users\\Timothe\\AppData\\Local\\Temp\\ipykernel_66148\\29605135.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
188 | " trainer = Trainer(\n",
189 | "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
190 | " \\\\ /| Num examples = 3,893 | Num Epochs = 3 | Total steps = 366\n",
191 | "O^O/ \\_/ \\ Batch size per device = 32 | Gradient accumulation steps = 1\n",
192 | "\\ / Data Parallel GPUs = 1 | Total batch size (32 x 1 x 1) = 32\n",
193 | " \"-____-\" Trainable parameters = 395,834,371/395,834,371 (100.00% trained)\n"
194 | ]
195 | },
196 | {
197 | "data": {
198 | "text/html": [
199 | "\n",
200 |         "<div>\n",
201 |         "  <progress value='366' max='366' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
202 |         "  [366/366 01:41, Epoch 3/3]\n",
203 |         "</div>\n",
204 |         "<table border=\"1\" class=\"dataframe\">\n",
205 |         "  <thead>\n",
206 |         "    <tr style=\"text-align: left;\">\n",
207 |         "      <th>Step</th>\n",
208 |         "      <th>Training Loss</th>\n",
209 |         "      <th>Validation Loss</th>\n",
210 |         "      <th>Accuracy</th>\n",
211 |         "    </tr>\n",
212 |         "  </thead>\n",
213 |         "  <tbody>\n",
214 |         "    <tr>\n",
215 |         "      <td>92</td>\n",
216 |         "      <td>0.564900</td>\n",
217 |         "      <td>0.430882</td>\n",
218 |         "      <td>0.713626</td>\n",
219 |         "    </tr>\n",
220 |         "    <tr>\n",
221 |         "      <td>184</td>\n",
222 |         "      <td>0.346500</td>\n",
223 |         "      <td>0.308327</td>\n",
224 |         "      <td>0.812933</td>\n",
225 |         "    </tr>\n",
226 |         "    <tr>\n",
227 |         "      <td>276</td>\n",
228 |         "      <td>0.247500</td>\n",
229 |         "      <td>0.288410</td>\n",
230 |         "      <td>0.822171</td>\n",
231 |         "    </tr>\n",
232 |         "  </tbody>\n",
233 |         "</table>\n",
234 |         "<p></p>"
235 | ],
236 | "text/plain": [
237 |         "<IPython.core.display.HTML object>"
238 | ]
239 | },
240 | "metadata": {},
241 | "output_type": "display_data"
242 | },
243 | {
244 | "name": "stdout",
245 | "output_type": "stream",
246 | "text": [
247 | "Unsloth: Will smartly offload gradients to save VRAM!\n"
248 | ]
249 | }
250 | ],
251 | "source": [
252 | "trainer = Trainer(\n",
253 | " model=model,\n",
254 | " tokenizer=tokenizer,\n",
255 | " train_dataset=dataset,\n",
256 | " eval_dataset=val_dataset,\n",
257 | " args=TrainingArguments(\n",
258 | " per_device_train_batch_size=32,\n",
259 | " gradient_accumulation_steps=1,\n",
260 | " warmup_steps=10,\n",
261 | " fp16=not torch.cuda.is_bf16_supported(),\n",
262 | " bf16=torch.cuda.is_bf16_supported(),\n",
263 | " optim=training_args.OptimizerNames.ADAMW_TORCH,\n",
264 | " learning_rate=5e-5,\n",
265 | " weight_decay=0.001,\n",
266 | " lr_scheduler_type=\"cosine\",\n",
267 | " seed=3407,\n",
268 | " num_train_epochs=3, # bert-style models usually need more than 1 epoch\n",
269 | " save_strategy=\"epoch\",\n",
270 | "\n",
271 | " # report_to=\"wandb\",\n",
272 | " report_to=\"none\",\n",
273 | "\n",
274 | " group_by_length=True,\n",
275 | "\n",
276 | " # eval_strategy=\"no\",\n",
277 | " eval_strategy=\"steps\",\n",
278 | " eval_steps=0.25,\n",
279 | " logging_strategy=\"steps\",\n",
280 | " logging_steps=0.25,\n",
281 | " \n",
282 | " ),\n",
283 | " compute_metrics=lambda eval_pred: { \"accuracy\": accuracy_score(eval_pred[1].argmax(axis=-1), eval_pred[0].argmax(axis=-1)) }\n",
284 | ")\n",
285 | "trainer_stats = trainer.train()"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": 5,
291 | "metadata": {},
292 | "outputs": [
293 | {
294 | "name": "stdout",
295 | "output_type": "stream",
296 | "text": [
297 | "\n"
298 | ]
299 | }
300 | ],
301 | "source": [
302 | "model = model.cuda()\n",
303 | "model = model.eval()\n",
304 | "FastLanguageModel.for_inference(model)\n",
305 | "print()"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 6,
311 | "metadata": {},
312 | "outputs": [
313 | {
314 | "name": "stderr",
315 | "output_type": "stream",
316 | "text": [
317 | "Evaluating: 100%|██████████| 14/14 [00:01<00:00, 11.78it/s]"
318 | ]
319 | },
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "\n",
325 | "Validation accuracy: 82.45% (357/433)\n",
326 | "\n",
327 | "--- Random samples ---\n",
328 | "\n",
329 | "Text: Turkey Stiffens Manipulation Penalties in Banking Overhaul\n",
330 | "True: 0 Pred: 0 ✅\n",
331 | "Probs: 0: 0.846, 1: 0.114, 2: 0.040\n",
332 | "\n",
333 | "Text: The Manitowoc Company, Inc. Just Reported Earnings, And Analysts Cut Their Target Price\n",
334 | "True: 2 Pred: 2 ✅\n",
335 | "Probs: 0: 0.065, 1: 0.029, 2: 0.906\n",
336 | "\n",
337 | "Text: $BLMN $EAT $SBUX - Restaurants stocks break higher, analysts reel in near-term expectations https://t.co/fOjVVJdfF0\n",
338 | "True: 1 Pred: 1 ✅\n",
339 | "Probs: 0: 0.004, 1: 0.974, 2: 0.021\n",
340 | "\n",
341 | "Text: $CMCSA $LHX - Comcast sues L3Harris in patent dispute https://t.co/kWReshGbvz\n",
342 | "True: 2 Pred: 2 ✅\n",
343 | "Probs: 0: 0.023, 1: 0.035, 2: 0.942\n",
344 | "\n",
345 | "Text: Libyan economic experts will study the distribution of crucial oil revenue as efforts continue to solve the war-rav… https://t.co/S9lmpnDTqJ\n",
346 | "True: 0 Pred: 0 ✅\n",
347 | "Probs: 0: 0.888, 1: 0.013, 2: 0.099\n",
348 | "\n",
349 | "Text: Stocks Suffer 'Shocking' Down Week As Fed Balance Sheet Unexpectedly Shrinks https://t.co/bspsRi3Wow\n",
350 | "True: 2 Pred: 2 ✅\n",
351 | "Probs: 0: 0.005, 1: 0.001, 2: 0.993\n",
352 | "\n",
353 | "Text: Burger King says it never promised Impossible Whoppers were vegan https://t.co/oZCnoupsYV https://t.co/lauoccNH0n\n",
354 | "True: 0 Pred: 0 ✅\n",
355 | "Probs: 0: 0.762, 1: 0.070, 2: 0.167\n",
356 | "\n",
357 | "Text: McEwen Mining prices public offering at $1.325/unit\n",
358 | "True: 0 Pred: 0 ✅\n",
359 | "Probs: 0: 0.710, 1: 0.284, 2: 0.006\n",
360 | "\n",
361 | "Text: H&P downgraded at Argus as drilling industry weakness seen persisting\n",
362 | "True: 2 Pred: 2 ✅\n",
363 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n",
364 | "\n",
365 | "Text: $SPLK - Splunk: Full Steam Ahead. Follow this and any other stock on Seeking Alpha! https://t.co/DCvnAuSOBa #markets #stocks #economy\n",
366 | "True: 0 Pred: 1 ❌\n",
367 | "Probs: 0: 0.040, 1: 0.959, 2: 0.001\n",
368 | "\n",
369 | "Text: Three dead in shooting at Oklahoma Walmart: RPT\n",
370 | "True: 0 Pred: 2 ❌\n",
371 | "Probs: 0: 0.036, 1: 0.001, 2: 0.964\n",
372 | "\n",
373 | "Text: U.S. Oil Inventories Rise More Than Expected #WTI #Stock #MarketScreener https://t.co/lMkNlbjinO https://t.co/wBBq3HdLZO\n",
374 | "True: 2 Pred: 1 ❌\n",
375 | "Probs: 0: 0.002, 1: 0.997, 2: 0.001\n",
376 | "\n",
377 | "Text: Casper Sleep stock languishes below IPO issue price after falling 5%\n",
378 | "True: 2 Pred: 2 ✅\n",
379 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n",
380 | "\n",
381 | "Text: The global oil market is drowning in excess crude as demand plummets. Insights via @CMEGroup https://t.co/JklSJKFfRS\n",
382 | "True: 2 Pred: 2 ✅\n",
383 | "Probs: 0: 0.001, 1: 0.001, 2: 0.998\n",
384 | "\n",
385 | "Text: Asia Stocks Open Mixed as Trade Details Awaited: Markets Wrap\n",
386 | "True: 0 Pred: 1 ❌\n",
387 | "Probs: 0: 0.023, 1: 0.865, 2: 0.112\n",
388 | "\n",
389 | "Text: From Starbucks to Seattle, companies and cities alike are banning plastic straws. Are takeout containers next?… https://t.co/Ew4Fsl6K0m\n",
390 | "True: 0 Pred: 0 ✅\n",
391 | "Probs: 0: 0.968, 1: 0.003, 2: 0.029\n",
392 | "\n",
393 | "Text: Americans' outlook on the economy faltered significantly last month as the coronavirus crisis began to take hold in… https://t.co/5jeCXLXrrR\n",
394 | "True: 2 Pred: 2 ✅\n",
395 | "Probs: 0: 0.002, 1: 0.002, 2: 0.996\n",
396 | "\n",
397 | "Text: Boris Johnson’s Conservative Party will pledge not to increase several key tax measures if it wins next month’s gen… https://t.co/RENnMT4Dtr\n",
398 | "True: 0 Pred: 0 ✅\n",
399 | "Probs: 0: 0.999, 1: 0.001, 2: 0.001\n",
400 | "\n",
401 | "Text: Casper Sleep shares slide 5% to trade at $10.46, below $12 IPO price\n",
402 | "True: 2 Pred: 2 ✅\n",
403 | "Probs: 0: 0.000, 1: 0.000, 2: 1.000\n",
404 | "\n",
405 | "Text: Brixmor 2020 FFO guidance comes in on the light side\n",
406 | "True: 2 Pred: 0 ❌\n",
407 | "Probs: 0: 0.868, 1: 0.131, 2: 0.002\n"
408 | ]
409 | },
410 | {
411 | "name": "stderr",
412 | "output_type": "stream",
413 | "text": [
414 | "\n"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "batch_size = 32\n",
420 | "correct = 0\n",
421 | "results = []\n",
422 | "\n",
423 | "# If the val_labels are one-hot, convert to class indices\n",
424 | "if isinstance(val_labels, np.ndarray) and val_labels.ndim == 2:\n",
425 | " val_true_labels = np.argmax(val_labels, axis=1)\n",
426 | "else:\n",
427 | " val_true_labels = val_labels\n",
428 | "\n",
429 | "val_texts = list(val_data)\n",
430 | "val_true_labels = list(val_true_labels)\n",
431 | "\n",
432 | "with torch.no_grad():\n",
433 | " for i in tqdm(range(0, len(val_texts), batch_size), desc=\"Evaluating\"):\n",
434 | " batch_texts = val_texts[i:i+batch_size]\n",
435 | " batch_labels = val_true_labels[i:i+batch_size]\n",
436 | " # Tokenize\n",
437 | " inputs = tokenizer(batch_texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=2048)\n",
438 | " inputs = {k: v.cuda() for k, v in inputs.items()}\n",
439 | " # Forward pass\n",
440 | " outputs = model(**inputs)\n",
441 | " logits = outputs.logits\n",
442 | " probs = F.softmax(logits, dim=-1)\n",
443 | " preds = torch.argmax(probs, dim=-1).cpu().numpy()\n",
444 | " # Count correct\n",
445 | " correct += np.sum(preds == batch_labels)\n",
446 | " # Store results for display\n",
447 | " for j in range(len(batch_texts)):\n",
448 | " results.append({\n",
449 | " \"text\": batch_texts[j][:200],\n",
450 | " \"true\": batch_labels[j],\n",
451 | " \"pred\": preds[j],\n",
452 | " \"probs\": probs[j].detach().float().cpu().numpy(),\n",
453 | " \"ok\": preds[j] == batch_labels[j]\n",
454 | " })\n",
455 | "\n",
456 | "accuracy = 100 * correct / len(val_texts)\n",
457 | "print(f\"\\nValidation accuracy: {accuracy:.2f}% ({correct}/{len(val_texts)})\")\n",
458 | "\n",
459 | "# Show a few random samples\n",
460 | "import random\n",
461 | "display = 20\n",
462 | "print(\"\\n--- Random samples ---\")\n",
463 | "for s in random.sample(results, min(display, len(results))):\n",
464 | " print(f\"\\nText: {s['text']}\")\n",
465 | " print(f\"True: {s['true']} Pred: {s['pred']} {'✅' if s['ok'] else '❌'}\")\n",
466 | " print(\"Probs:\", \", \".join([f\"{k}: {v:.3f}\" for k, v in enumerate(s['probs'])]))"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 7,
472 | "metadata": {},
473 | "outputs": [
474 | {
475 | "ename": "ZeroDivisionError",
476 | "evalue": "division by zero",
477 | "output_type": "error",
478 | "traceback": [
479 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
480 | "\u001b[1;31mZeroDivisionError\u001b[0m Traceback (most recent call last)",
481 | "Cell \u001b[1;32mIn[7], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# stop running all cells\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;241m1\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m0\u001b[39m\n",
482 | "\u001b[1;31mZeroDivisionError\u001b[0m: division by zero"
483 | ]
484 | }
485 | ],
486 | "source": [
487 | "# stop running all cells\n",
488 | "1/0"
489 | ]
490 | },
491 | {
492 | "cell_type": "markdown",
493 | "metadata": {},
494 | "source": [
495 | "# to load the model again (run every cell above the one where the trainer is called)"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": null,
501 | "metadata": {},
502 | "outputs": [
503 | {
504 | "name": "stdout",
505 | "output_type": "stream",
506 | "text": [
507 | "Last checkpoint: trainer_output\\checkpoint-244\n",
508 | "==((====))== Unsloth 2025.4.5: Fast Modernbert patching. Transformers: 4.51.3.\n",
509 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n",
510 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
511 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
512 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
513 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
514 | "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
515 | ]
516 | }
517 | ],
518 | "source": [
519 | "from transformers.trainer_utils import get_last_checkpoint\n",
520 | "\n",
521 | "output_dir = \"trainer_output\"\n",
522 | "last_checkpoint = get_last_checkpoint(output_dir)\n",
523 | "print(\"Last checkpoint:\", last_checkpoint)\n",
524 | "\n",
525 | "model, tokenizer = FastModel.from_pretrained(\n",
526 | " model_name = last_checkpoint,load_in_4bit = False,\n",
527 | " max_seq_length = 2048,\n",
528 | " dtype = None,\n",
529 | " auto_model = AutoModelForSequenceClassification,\n",
530 | " num_labels = NUM_CLASSES,\n",
531 | ")"
532 | ]
533 | },
534 | {
535 | "cell_type": "code",
536 | "execution_count": null,
537 | "metadata": {},
538 | "outputs": [
539 | {
540 | "name": "stdout",
541 | "output_type": "stream",
542 | "text": [
543 | "SequenceClassifierOutput(loss=None, logits=tensor([[-0.0579, -0.5859, -1.1719]], device='cuda:0', dtype=torch.bfloat16,\n",
544 | " grad_fn=), hidden_states=None, attentions=None)\n"
545 | ]
546 | }
547 | ],
548 | "source": [
549 | "from torch import tensor\n",
550 | "print(model(input_ids=tensor([[1,2,3,4,5]]).cuda(), attention_mask=tensor([[1,1,1,1,1]]).cuda()))"
551 | ]
552 | }
553 | ],
554 | "metadata": {
555 | "kernelspec": {
556 | "display_name": "base",
557 | "language": "python",
558 | "name": "python3"
559 | },
560 | "language_info": {
561 | "codemirror_mode": {
562 | "name": "ipython",
563 | "version": 3
564 | },
565 | "file_extension": ".py",
566 | "mimetype": "text/x-python",
567 | "name": "python",
568 | "nbconvert_exporter": "python",
569 | "pygments_lexer": "ipython3",
570 | "version": "3.12.3"
571 | },
572 | "orig_nbformat": 4,
573 | "vscode": {
574 | "interpreter": {
575 | "hash": "1be15a159d9874788f7b7854451912393d9e82d0d2bc47d83a870bda7fd9bc22"
576 | }
577 | }
578 | },
579 | "nbformat": 4,
580 | "nbformat_minor": 2
581 | }
582 |
--------------------------------------------------------------------------------
/unsloth_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "IqM-T1RTzY6C"
7 | },
8 | "source": [
9 | "# Text classification with Unsloth\n",
10 | "\n",
11 | "This modified Unsloth notebook trains an LLM on any text classification dataset, where the input is a csv with columns \"text\" and \"label\".\n",
12 | "\n",
13 | "### Added features:\n",
14 | "\n",
15 | "- Trims the classification head to contain only the number tokens such as \"1\", \"2\" etc, which saves 1 GB of VRAM, allows you to train the head without massive memory usage, and makes the start of the training session more stable.\n",
16 |     "- Only the last token in the sequence contributes to the loss, so the model doesn't waste its capacity by trying to predict the input\n",
17 |     "- Includes \"group_by_length = True\", which speeds up training significantly for unbalanced sequence lengths\n",
18 | "- Efficiently evaluates the accuracy on the validation set using batched inference\n",
19 | "\n",
20 | "### Update 4th of May 2025:\n",
21 | "\n",
22 | "- Added support for more than 2 classes\n",
23 | "- The classification head is now built back up to the original size after training, no more errors in external libraries.\n",
24 | "- Made the batched inference part much faster and cleaner\n",
25 | "- Changed model to Qwen 3\n",
26 | "- Improved comments to explain the complicated parts"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "name": "stdout",
36 | "output_type": "stream",
37 | "text": [
38 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
39 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 |     "# unsloth's fix_untrained_tokens() breaks when the lm_head has its size changed, so replace it with a no-op\n",
45 | "from unsloth import tokenizer_utils\n",
46 | "def do_nothing(*args, **kwargs):\n",
47 | " pass\n",
48 | "tokenizer_utils.fix_untrained_tokens = do_nothing"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "name": "stdout",
58 | "output_type": "stream",
59 | "text": [
60 | "Major: 8, Minor: 6\n"
61 | ]
62 | },
63 | {
64 | "name": "stderr",
65 | "output_type": "stream",
66 | "text": [
67 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n",
68 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n"
69 | ]
70 | },
71 | {
72 | "name": "stdout",
73 | "output_type": "stream",
74 | "text": [
75 | "==((====))== Unsloth 2025.4.5: Fast Qwen3 patching. Transformers: 4.51.3.\n",
76 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n",
77 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
78 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
79 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
80 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
81 | ]
82 | },
83 | {
84 | "data": {
85 | "application/vnd.jupyter.widget-view+json": {
86 | "model_id": "67e411e7801f4f9db7b980df6de5d01b",
87 | "version_major": 2,
88 | "version_minor": 0
89 | },
90 | "text/plain": [
91 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
92 | ]
93 | },
94 | "metadata": {},
95 | "output_type": "display_data"
96 | }
97 | ],
98 | "source": [
99 | "import torch\n",
100 | "major_version, minor_version = torch.cuda.get_device_capability()\n",
101 | "print(f\"Major: {major_version}, Minor: {minor_version}\")\n",
102 | "from datasets import load_dataset\n",
103 | "import datasets\n",
104 | "from trl import SFTTrainer\n",
105 | "import pandas as pd\n",
106 | "import numpy as np\n",
107 | "import os\n",
108 | "import pandas as pd\n",
109 | "import numpy as np\n",
110 | "from unsloth import FastLanguageModel\n",
111 | "from trl import SFTTrainer\n",
112 | "from transformers import TrainingArguments, Trainer\n",
113 | "from typing import Tuple\n",
114 | "import warnings\n",
115 | "from typing import Any, Dict, List, Union\n",
116 | "from transformers import DataCollatorForLanguageModeling\n",
117 | "from sklearn.model_selection import train_test_split\n",
118 | "import matplotlib.pyplot as plt\n",
119 | "\n",
120 | "NUM_CLASSES = 3 # number of classes in the csv\n",
121 | "\n",
122 | "max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!\n",
123 | "dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
124 | "\n",
125 | "model_name = \"unsloth/Qwen3-4B-Base\";load_in_4bit = False\n",
126 | "# model_name = \"Qwen3-4B-Base\";load_in_4bit = False\n",
127 | "\n",
128 | "model, tokenizer = FastLanguageModel.from_pretrained(\n",
129 | " model_name = model_name,load_in_4bit = load_in_4bit,\n",
130 | " max_seq_length = max_seq_length,\n",
131 | " dtype = dtype,\n",
132 | ")"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "id": "SXd9bTZd1aaL"
139 | },
140 | "source": [
141 | "We now trim the classification head so the model can only say numbers 0-NUM_CLASSES and no other words. (We don't use 0 here but keeping it makes everything simpler)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 3,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stdout",
151 | "output_type": "stream",
152 | "text": [
153 | "torch.Size([4, 2560])\n",
154 | "torch.Size([151936, 2560])\n"
155 | ]
156 | },
157 | {
158 | "data": {
159 | "text/plain": [
160 | "{15: 0, 16: 1, 17: 2, 18: 3}"
161 | ]
162 | },
163 | "execution_count": 3,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "number_token_ids = []\n",
170 | "for i in range(0, NUM_CLASSES+1):\n",
171 | " number_token_ids.append(tokenizer.encode(str(i), add_special_tokens=False)[0])\n",
172 | "# keep only the number tokens from lm_head\n",
173 | "par = torch.nn.Parameter(model.lm_head.weight[number_token_ids, :])\n",
174 | "\n",
175 | "old_shape = model.lm_head.weight.shape\n",
176 | "old_size = old_shape[0]\n",
177 | "print(par.shape)\n",
178 | "print(old_shape)\n",
179 | "\n",
180 | "model.lm_head.weight = par\n",
181 | "\n",
182 | "reverse_map = {value: idx for idx, value in enumerate(number_token_ids)} # will be used later to convert an idx from the old tokenizer to the new lm_head\n",
183 | "reverse_map"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 4,
189 | "metadata": {},
190 | "outputs": [
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "Unsloth: Offloading output_embeddings to disk to save VRAM\n"
196 | ]
197 | },
198 | {
199 | "name": "stderr",
200 | "output_type": "stream",
201 | "text": [
202 | "Unsloth 2025.4.5 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.\n"
203 | ]
204 | },
205 | {
206 | "name": "stdout",
207 | "output_type": "stream",
208 | "text": [
209 | "Unsloth: Training lm_head in mixed precision to save VRAM\n",
210 | "trainable parameters: 33040384\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "from peft import LoftQConfig\n",
216 | "\n",
217 | "model = FastLanguageModel.get_peft_model(\n",
218 | " model,\n",
219 | " r = 16,\n",
220 | " target_modules = [\n",
221 | " \"lm_head\", # can easily be trained because it now has a small size\n",
222 | " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
223 | " \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
224 | " lora_alpha = 16,\n",
225 | " lora_dropout = 0, # Supports any, but = 0 is optimized\n",
226 | " bias = \"none\", # Supports any, but = \"none\" is optimized\n",
227 | " use_gradient_checkpointing = \"unsloth\",\n",
228 | " random_state = 3407,\n",
229 | " use_rslora = True, # We support rank stabilized LoRA\n",
230 | " # init_lora_weights = 'loftq',\n",
231 | " # loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1), # And LoftQ\n",
232 | ")\n",
233 | "print(\"trainable parameters:\", sum(p.numel() for p in model.parameters() if p.requires_grad))"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "The dataset can be found [here](https://github.com/timothelaborie/text_classification_scripts/blob/main/data/finance_sentiment_multiclass.csv)"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 5,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "name": "stdout",
250 | "output_type": "stream",
251 | "text": [
252 | "3893\n"
253 | ]
254 | }
255 | ],
256 | "source": [
257 | "kaggle = os.getcwd() == \"/kaggle/working\"\n",
258 | "input_dir = \"/kaggle/input/whatever/\" if kaggle else \"data/\"\n",
259 | "data = pd.read_csv(input_dir + \"finance_sentiment_multiclass.csv\") # columns are text,label\n",
260 | "\n",
261 | "train_df, val_df = train_test_split(data, test_size=0.1, random_state=42)\n",
262 | "print(len(train_df))"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": 6,
268 | "metadata": {},
269 | "outputs": [
270 | {
271 | "data": {
272 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfQElEQVR4nO3de2zV9f3H8dexlwPU9owW6fGMIkU7by3MFcfoUNgKJazIDMlQ8YIBE5XLOAPGdQnVaNthBqhMFh0BZsdqFqljQx1lYh0hzlrpLHVBjQXLbG2c9bTFeorl8/tjP77xtDDtBc+n3z4fyUns9/s57edtUZ75npvHGGMEAABgsYuivQEAAIAvQ7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsF5stDfQG2fOnNEHH3ygxMREeTyeaG8HAAB8BcYYtba2KhAI6KKLenbNZEAGywcffKC0tLRobwMAAPRCfX29Ro0a1aP7DMhgSUxMlPTfgZOSkqK8GwAA8FW0tLQoLS3N+Xu8JwZksJx9GCgpKYlgAQBggOnN0zl40i0AALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKwXG+0NIPrGrNnX6/seL87vx50AAHBuXGEBAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPX6FCxFRUXyeDwKBoPOMWOMCgoKFAgENHToUE2dOlW1tbUR9wuHw1q6dKlGjBihhIQEzZ49WydPnuzLVgAAgIv1OlgqKyv15JNPaty4cRHHN27cqE2bNmnr1q2qrKyU3+/X9OnT1dra6qwJBoMqKytTaWmpDh06pLa2Ns2aNUudnZ29nwQAALhWr4Klra1Nt99+u5566ikNHz7cOW6M0ZYtW7R+/XrNmTNHmZmZ2rVrlz799FPt3r1bkhQKhbR9+3b96le/0rRp03TdddeppKRENTU1OnDgQP9MBQAAXKVXwbJ48WLl5+dr2rRpEcfr6urU2NiovLw855jX69WUKVN0+PBhSVJVVZVOnz4dsSYQCCgzM9NZ01U4HFZLS0vEDQAADB6xPb1DaWmp3njjDVVWVnY719jYKElKTU2NOJ6amqoTJ044a+Lj4yOuzJxdc/b+XRUVFemBBx7o6VYBAIBL9OgKS319vZYtW6aSkhINGTLkvOs8Hk/E18aYbse6+l9r1q5dq1Ao5Nzq6+t7sm0AADDA9ShYqqqq1NTUpOzsbMXGxio2NlYVFRV67LHHFBsb61xZ6XqlpKmpyTnn9/vV0dGh5ubm867pyuv1KikpKeIGAAAGjx4FS25urmpqalRdXe3cJkyYoNtvv13V1dUaO3as/H6/ysvLnft0dHSooqJCOTk5kqTs7GzFxcVFrGloaNDRo0edNQAAAF/Uo+ewJCYmKjMzM+JYQkKCUlJSnOPBYFCFhYXKyMhQRkaGCgsLNWzYMM2bN0+S5PP5tHDhQq1YsUIpKSlKTk7WypUrlZWV1e1JvAAAAFIvnnT7ZVatWqX29nYtWrRIzc3Nmjhxovbv36/ExERnzebNmxUbG6u5c+eqvb1dubm52rlzp2JiYvp7OwAAwAU8xhgT7U30VEtLi3w+n0KhEM9n6Qdj1uzr9X2PF+f3404AAG7Wl7+/+SwhAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPVio70BDF5j1uzr9X2PF+f3404AALbjCgsAALAewQIAAKxHsAAAAOsRLAAAwHo86RZ90pcnzgIA8FVxhQUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPV6FCzbtm3TuHHjlJSUpKSkJE2aNEkvvPCCc94Yo4KCAgUCAQ0dOlRTp05VbW1txPcIh8NaunSpRowYoYSEBM2ePVsnT57sn2kAAIAr9ShYRo0apeLiYr3++ut6/fXX9cMf/lA//vGPnSjZuHGjNm3apK1bt6qyslJ+v1/Tp09Xa2ur8z2CwaDKyspUWlqqQ4cOqa2tTbNmzVJnZ2f/TgYAAFzDY4wxffkGycnJeuSRR7RgwQIFAgEFg0GtXr1a0n+vpqSmpuqXv/yl7r33XoVCIV1yySV6+umndcstt0iSPvjgA6Wlpen555/XjBkzvtLPbGlpkc/nUygUUlJSUl+2D0lj1uyL9hZ67HhxfrS3AADoob78/d3r57B0dnaqtLRUp06d0qRJk1RXV6fGxkbl5eU5a7xer6ZMmaLDhw9LkqqqqnT69OmINYFAQJmZmc6acwmHw2ppaYm4AQCAwaPHwVJTU6OLL75YXq9X9913n8rKynTNNdeosbFRkpSamhqxPjU11TnX2Nio+Ph4DR8+/LxrzqWoqEg+n8+5paWl9XTbAABgAOtxsFx55ZWqrq7Wq6++qvvvv1/z58/XW2+95Zz3eDwR640x3Y519WVr1q5dq1Ao5Nzq6+t7um0AADCA9ThY4uPjdcUVV2jChAkqKirS+PHj9eijj8rv90tStyslTU1NzlUXv9+vjo4ONTc3n3fNuXi9XueVSWdvAABg8Ojz+7AYYxQOh5Weni6/36/y8nLnXEdHhyoqKpSTkyNJys7OVlxcXMSahoYGHT161FkDAADQVWxPFq9bt04zZ85UWlqaWltbVVpaqpdfflkvvviiPB6PgsGgCgsLlZGRoYyMDBUWFmrYsGGaN2+eJMnn82nhwoVasWKFUlJSlJycrJUrVyorK0vTpk27IAMCAICBr0fB8uGHH+rOO+9UQ0ODfD6fxo0bpxdffFHTp0+XJK1atUrt7e1atGiRmpubNXHiRO3fv1+JiYnO99i8ebNiY2M1d+5ctbe
3Kzc3Vzt37lRMTEz/TgYAAFyjz+/DEg28D0v/4n1YAABfh6i8DwsAAMDXhWABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWK9Hb80PAF9VX95BmXcyBtAVV1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADW48MPAZfjQwgBuAFXWAAAgPUIFgAAYD2CBQAAWI9gAQAA1uNJtxh0eBIqAAw8XGEBAADWI1gAAID1eEjIJfryMAcAALbjCgsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB6vErIIr/QBAODcuMICAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArMc73WJA4l2BAWBw4QoLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHp/WDHxN+vIJ08eL8/txJwAw8HCFBQAAWI9gAQAA1iNYAACA9QgWAABgvR4FS1FRka6//nolJiZq5MiRuvnmm3Xs2LGINcYYFRQUKBAIaOjQoZo6dapqa2sj1oTDYS1dulQjRoxQQkKCZs+erZMnT/Z9GgAA4Eo9CpaKigotXrxYr776qsrLy/X5558rLy9Pp06dctZs3LhRmzZt0tatW1VZWSm/36/p06ertbXVWRMMBlVWVqbS0lIdOnRIbW1tmjVrljo7O/tvMgAA4Bo9elnziy++GPH1jh07NHLkSFVVVenGG2+UMUZbtmzR+vXrNWfOHEnSrl27lJqaqt27d+vee+9VKBTS9u3b9fTTT2vatGmSpJKSEqWlpenAgQOaMWNGP40GAADcok/PYQmFQpKk5ORkSVJdXZ0aGxuVl5fnrPF6vZoyZYoOHz4sSaqqqtLp06cj1gQCAWVmZjprAAAAvqjXbxxnjNHy5cs1efJkZWZmSpIaGxslSampqRFrU1NTdeLECWdNfHy8hg8f3m3N2ft3FQ6HFQ6Hna9bWlp6u20AADAA9foKy5IlS/Tmm2/qD3/4Q7dzHo8n4mtjTLdjXf2vNUVFRfL5fM4tLS2tt9sGAAADUK+CZenSpdq7d68OHjyoUaNGOcf9fr8kdbtS0tTU5Fx18fv96ujoUHNz83nXdLV27VqFQiHnVl9f35ttAwCAAapHwWKM0ZIlS7Rnzx699NJLSk9Pjzifnp4uv9+v8vJy51hHR4cqKiqUk5MjScrOzlZcXFzEmoaGBh09etRZ05XX61VSUlLEDQAADB49eg7L4sWLtXv3bv3pT39SYmKicyXF5/Np6NCh8ng8CgaDKiwsVEZGhjIyMlRYWKhhw4Zp3rx5ztqFCxdqxYoVSklJUXJyslauXKmsrCznVUMAAABf1KNg2bZtmyRp6tSpEcd37Nihu+++W5K0atUqtbe3a9GiRWpubtbEiRO1f/9+JSYmOus3b96s2NhYzZ07V+3t7crNzdXOnTsVExPTt2kAAIAr9ShYjDFfusbj8aigoEAFBQXnXTNkyBA9/vjjevzxx3vy4wEAwCDFZwkBAADrESwAAMB6BAsAALBer9/pFoD7jVmzL9pbAABJXGEBAAADAMECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArBcb7Q0AA8mYNfuivQUAGJS4wgIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADr8db8AFylLx+fcLw4vx93AqA/cYUFAABYj2ABAADWI1gAAID1eA4LAOv05XkoANyJKywAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArMf7sAADAO9LAmCw4woLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6PQ6WV155RTfddJMCgYA8Ho+ee+65iPPGGBUUFCgQCGjo0KGaOnWqamtrI9aEw2EtXbpUI0aMUEJCgmbPnq2TJ0/2aRAAAOBePQ6WU6dOafz48dq6des5z2/cuFGbNm3S1q1bVVlZKb/fr+nTp6u1tdVZEwwGVVZWptLSUh06dEhtbW2aNWuWOjs7ez8JAABwrdie3mHmzJmaOXPmOc8ZY7RlyxatX79ec+bMkSTt2rVLqamp2r17t+69916FQiFt375dTz/9tKZNmyZJKikpUVpamg4cOKAZM2b0YRwAAOBG/foclrq6OjU2NiovL8855vV6NWXKFB0+fFiSVFVVpdOnT0esCQQCyszMdNZ0FQ6H1dLSEnEDAACDR78GS2NjoyQpNTU14nhqaqpzrrGxUfHx8Ro+fPh513RVVFQkn8/n3NLS0vpz2wAAwHIX5FVCHo8n4mtjTLdjXf2vNWvXrlUoFHJu9fX1/bZXAABgv34NFr/fL0ndrpQ0NTU5V138fr86OjrU3Nx83jVdeb1eJSUlRdwAAMDg0a/Bkp6eLr/fr/LycudYR0eHKioqlJOTI0nKzs5WXFxcxJqGhgYdPXrUWQMAAPBFPX6VUFtbm959913n67q6OlVXVys5OVmjR49WMBhUYWGhMjIylJGRocLCQg0bNkzz5s2TJPl8Pi1cuFArVqxQSkqKkpOTtXLlSmVlZTmvGgIAAPiiHgfL66+/rh/84AfO18uXL5ckzZ8/Xzt37tSqVavU3t6uRYsWqbm5WRMnTtT+/fuVmJjo3Gfz5s2KjY3V3Llz1d7ertzcXO3cuVMxMTH9MBIAAHAbjzHGRHsTPdXS0iKfz6dQKOSq57OMWbMv2lsABrXjxfnR3gLgan35+5vPEgIAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPVio70BAHCDMWv29fq+x4vz+3EngDtxhQUAAFiPYAEAANbjISEA+H99eVgHwIXFFRYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANaLjfYG3GbMmn3R3gIAAK7DFRYAAGA9ggUAAFiPYAEAANbjOSwAMID19Xlzx4vz+2knwIXFFRYAAG
A9ggUAAFiPh4QAIMp4OwTgy3GFBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPd7o9B951EgAAu3CFBQAAWI9gAQAA1ovqQ0JPPPGEHnnkETU0NOjaa6/Vli1bdMMNN0RzSwAwqETrIfDjxflR+bkYuKJ2heWZZ55RMBjU+vXrdeTIEd1www2aOXOm3n///WhtCQAAWCpqwbJp0yYtXLhQ99xzj66++mpt2bJFaWlp2rZtW7S2BAAALBWVh4Q6OjpUVVWlNWvWRBzPy8vT4cOHu60Ph8MKh8PO16FQSJLU0tJyQfZ3JvzpBfm+AID/Gv2zP0bl5x59YEZUfm60ZG74a6/veyH+XZ39e9sY0+P7RiVYPvroI3V2dio1NTXieGpqqhobG7utLyoq0gMPPNDteFpa2gXbIwDAfXxbor2DgeNC/rtqbW2Vz+fr0X2i+qRbj8cT8bUxptsxSVq7dq2WL1/ufH3mzBl9/PHHSklJOef6L2ppaVFaWprq6+uVlJTUPxu3GPO6G/O622CbVxp8Mw/2eY0xam1tVSAQ6PH3ikqwjBgxQjExMd2upjQ1NXW76iJJXq9XXq834tg3vvGNHv3MpKSkQfGH4yzmdTfmdbfBNq80+GYezPP29MrKWVF50m18fLyys7NVXl4ecby8vFw5OTnR2BIAALBY1B4SWr58ue68805NmDBBkyZN0pNPPqn3339f9913X7S2BAAALBW1YLnlllv0n//8Rw8++KAaGhqUmZmp559/Xpdddlm//hyv16sNGzZ0e0jJrZjX3ZjX3QbbvNLgm5l5e89jevPaIgAAgK8RnyUEAACsR7AAAADrESwAAMB6BAsAALCeq4PliSeeUHp6uoYMGaLs7Gz9/e9/j/aW+s0rr7yim266SYFAQB6PR88991zEeWOMCgoKFAgENHToUE2dOlW1tbXR2WwfFRUV6frrr1diYqJGjhypm2++WceOHYtY46Z5t23bpnHjxjlvtDRp0iS98MILznk3zXouRUVF8ng8CgaDzjG3zVxQUCCPxxNx8/v9znm3zStJ//73v3XHHXcoJSVFw4YN07e//W1VVVU5590085gxY7r9fj0ejxYvXizJXbNK0ueff65f/OIXSk9P19ChQzV27Fg9+OCDOnPmjLOmX2Y2LlVaWmri4uLMU089Zd566y2zbNkyk5CQYE6cOBHtrfWL559/3qxfv948++yzRpIpKyuLOF9cXGwSExPNs88+a2pqaswtt9xiLr30UtPS0hKdDffBjBkzzI4dO8zRo0dNdXW1yc/PN6NHjzZtbW3OGjfNu3fvXrNv3z5z7Ngxc+zYMbNu3ToTFxdnjh49aoxx16xdvfbaa2bMmDFm3LhxZtmyZc5xt828YcMGc+2115qGhgbn1tTU5Jx327wff/yxueyyy8zdd99t/vGPf5i6ujpz4MAB8+677zpr3DRzU1NTxO+2vLzcSDIHDx40xrhrVmOMeeihh0xKSor5y1/+Yurq6swf//hHc/HFF5stW7Y4a/pjZtcGy3e/+11z3333RRy76qqrzJo1a6K0owuna7CcOXPG+P1+U1xc7Bz77LPPjM/nM7/5zW+isMP+1dTUZCSZiooKY4z75zXGmOHDh5vf/va3rp61tbXVZGRkmPLycjNlyhQnWNw484YNG8z48ePPec6N865evdpMnjz5vOfdOPMXLVu2zFx++eXmzJkzrpw1Pz/fLFiwIOLYnDlzzB133GGM6b/frysfEuro6FBVVZXy8vIijufl5enw4cNR2tXXp66uTo2NjRHze71eTZkyxRXzh0IhSVJycrIkd8/b2dmp0tJSnTp1SpMmTXL1rIsXL1Z+fr6mTZsWcdytM7/zzjsKBAJKT0/Xrbfeqvfee0+SO+fdu3evJkyYoJ/85CcaOXKkrrvuOj311FPOeTfOfFZHR4dKSkq0YMECeTweV846efJk/e1vf9Pbb78tSfrnP/+pQ4cO6Uc/+pGk/vv9RvXTmi+Ujz76SJ2dnd0+SDE1NbXbBy660dkZzzX/iRMnorGlfmOM0fLlyzV58mRlZmZKcue8NTU1mjRpkj777DNdfPHFKisr0zXXXOP8x+2mWSWptLRUb7zxhiorK7udc+Pvd+LEifrd736nb33rW/rwww/10EMPKScnR7W1ta6c97333tO2bdu0fPlyrVu3Tq+99pp++tOfyuv16q677nLlzGc999xz+uSTT3T33XdLcuef59WrVysUCumqq65STEyMOjs79fDDD+u2226T1H8zuzJYzvJ4PBFfG2O6HXMzN86/ZMkSvfnmmzp06FC3c26a98orr1R1dbU++eQTPfvss5o/f74qKiqc826atb6+XsuWLdP+/fs1ZMiQ865z08wzZ850/jkrK0uTJk3S5Zdfrl27dul73/ueJHfNe+bMGU2YMEGFhYWSpOuuu061tbXatm2b7rrrLmedm2Y+a/v27Zo5c6YCgUDEcTfN+swzz6ikpES7d+/Wtddeq+rqagWDQQUCAc2fP99Z19eZXfmQ0IgRIxQTE9PtakpTU1O3wnOjs682cNv8S5cu1d69e3Xw4EGNGjXKOe7GeePj43XFFVdowoQJKioq0vjx4/Xoo4+6ctaqqio1NTUpOztbsbGxio2NVUVFhR577DHFxsY6c7lp5q4SEhKUlZWld955x5W/40svvVTXXHNNxLGrr75a77//viR3/jcsSSdOnNCBAwd0zz33OMfcOOvPf/5zrVmzRrfeequysrJ055136mc/+5mKiook9d/MrgyW+Ph4ZWdnq7y8POJ4eXm5cnJyorSrr096err8fn/E/B0dHaqoqBiQ8xtjtGTJEu3Zs0cvvfSS0tPTI867bd5zMcYoHA67ctbc3FzV1NSourrauU2YMEG33367qqurNXbsWNfN3FU4HNa//vUvXXrppa78HX//+9/v9lYEb7/9tvNht26cWZJ27NihkSNHKj8/3znmxlk//fRTXXRRZE7ExMQ4L2vut5l7/7xgu519WfP27dvNW2+9ZYLBoElISDDHjx+P9tb6RWtrqzly5Ig5cuSIkWQ2bdpkjhw54rxsu7i42Ph8PrNnzx5TU1NjbrvttgH7srn777/f+Hw+8/LLL0e8VPDTTz911rhp3rVr15pXXnnF1NXVmTfffNOsW7fOXHTRRWb//v3GGHfNej5ffJWQMe6becWKFebll1827733nnn11VfNrFmzTGJiovP/J7fN+9prr5nY2Fjz8MMPm3feecf8/ve/N8OGDTMlJSXOGrfN3NnZaUaPHm1Wr17d7ZzbZp0/f7755je/6bysec+ePWbEiBFm1apVzpr+mNm1wWKMMb/+9a/NZZddZuLj4813vvMd52WwbnDw4EEjqdtt/vz5xpj/voxsw4YNxu/3G6/Xa2688UZTU1MT3U330rnmlGR27NjhrHHTvAsWLHD+3F5yySUmNzfXiRVj3DXr+XQNFrfNfPY9KOLi4kwgEDBz5swxtbW1znm3zWuMMX/+859NZmam8Xq95
qqrrjJPPvlkxHm3zfzXv/7VSDLHjh3rds5ts7a0tJhly5aZ0aNHmyFDhpixY8ea9evXm3A47Kzpj5k9xhjT28tAAAAAXwdXPocFAAC4C8ECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAev8HPhxKhsJKwhAAAAAASUVORK5CYII=",
273 | "text/plain": [
274 | ""
275 | ]
276 | },
277 | "metadata": {},
278 | "output_type": "display_data"
279 | }
280 | ],
281 | "source": [
282 | "token_counts = [len(tokenizer.encode(x)) for x in train_df.text]\n",
283 | "# plot the token counts\n",
284 | "a = plt.hist(token_counts, bins=30)"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 7,
290 | "metadata": {
291 | "colab": {
292 | "base_uri": "https://localhost:8080/",
293 | "height": 145,
294 | "referenced_widgets": [
295 | "e0ea75df3b5c49ce89427e4d245a7646",
296 | "02237d111eba4155ab5a48a1b33d82a4",
297 | "927e1d33248d4653a785049b50b6d814",
298 | "09979ecafec9443ba1f7335ed64de778",
299 | "088a000de2234c9f94a8b09e8a8abbb2",
300 | "38c7b4ead56b4540bf68fbe3a5496b9c",
301 | "9a0787c7c98b49b8a9141e5212faa249",
302 | "daf62811f5384d0f9aea58b40f23161a",
303 | "6f90172dd5c240c885140d49add6208f",
304 | "187078f1978843f2b873cee2ee55ac91",
305 | "ca8af702ee764e7fa620476854a4f2cf",
306 | "80b9d976e30e40ecaf35a606bca5d647",
307 | "e9b423d370314fdfa3c22e95c3a35d92",
308 | "307f44c0a4da4d3ca0d44b54f9a5f6c0",
309 | "5503539430024c7586d4f8589c92fd74",
310 | "b17542f2cd0f45e79e755df622328358",
311 | "fe6d1bab4fd042b5af89f6bb8a73cc39",
312 | "8cb7748eaa354d4ea0e3222685f1d9b4",
313 | "53bfefa3e2c04cce9ab7d2b3fbca59d9",
314 | "aa58f2124335451890857ded72414ecf",
315 | "b331b446e257441387b331d822f570a0",
316 | "49e3298fbe054c7abb6a041ecd63737c",
317 | "6b104e7242cf4a608cfd0c4fc6f44db1",
318 | "1a6288e052884d5aa486415880acf134",
319 | "289c8b4d1686443f8bed7441673a2c15",
320 | "1318a594d2df47779b84f22e4b0c0702",
321 | "355ecaa94d1a43b09e1d0b3f6e6ac8e3",
322 | "156432d2f1b74697bd74604e1bd5bf13",
323 | "976fa37d0fa14282bc96b3c948b9da94",
324 | "24768bbd587d4b1ba09b6c52dcd6237c",
325 | "cdf054995f314976b7124de9564cbd9b",
326 | "1eec25504a6f41d199940c4232b0a114",
327 | "007a9dbfb5c84efd8ffd801bc04c0c20",
328 | "2ffddabe91124b279f9a3ac41d1ead9d",
329 | "f97edc77840e49c49aed8e7ab80f45d1",
330 | "33761916bce24fde9bcff866736ddad7",
331 | "726aae420987454e87a08b9314844895",
332 | "071858b3801e4844a4fc5a91205335c4",
333 | "6a98441b0ff74361a79c96178834d8bc",
334 | "40e8b59e229548c68ccb339234a5e525",
335 | "56b011c2482b427483508e259dc231fd",
336 | "507c15541fe7489aa1fd1fd81c4aa222",
337 | "4bf05ee528bb4bb2979cde4d0a09544b",
338 | "d62a5696b9704580b69e6fc3be9ffa8a"
339 | ]
340 | },
341 | "id": "LjY75GoYUCB8",
342 | "outputId": "26e1bc8e-c4e8-472e-ca91-670e3381c6b3"
343 | },
344 | "outputs": [],
345 | "source": [
346 | "prompt = \"\"\"Here is a financial news:\n",
347 | "{}\n",
348 | "\n",
349 | "Classify this news into one of the following:\n",
350 | "class 1: Bullish\n",
351 | "class 2: Neutral\n",
352 | "class 3: Bearish\n",
353 | "\n",
354 | "SOLUTION\n",
355 | "The correct answer is: class {}\"\"\"\n",
356 | "\n",
357 | "def formatting_prompts_func(dataset_):\n",
358 | " texts = []\n",
359 | " for i in range(len(dataset_['text'])):\n",
360 | " text_ = dataset_['text'].iloc[i]\n",
361 | " label_ = dataset_['label'].iloc[i] # the csv is setup so that the label column corresponds exactly to the 3 classes defined above in the prompt (important)\n",
362 | "\n",
363 | " text = prompt.format(text_, label_)\n",
364 | "\n",
365 | " texts.append(text)\n",
366 | " return texts\n",
367 | "\n",
368 | "# apply formatting_prompts_func to train_df\n",
369 | "train_df['text'] = formatting_prompts_func(train_df)\n",
370 | "train_dataset = datasets.Dataset.from_pandas(train_df,preserve_index=False)"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": 8,
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | "# this custom collator makes it so the model trains only on the last token of the sequence. It also maps from the old tokenizer to the new lm_head indices\n",
380 | "class DataCollatorForLastTokenLM(DataCollatorForLanguageModeling):\n",
381 | " def __init__(\n",
382 | " self,\n",
383 | " *args,\n",
384 | " mlm: bool = False,\n",
385 | " ignore_index: int = -100,\n",
386 | " **kwargs,\n",
387 | " ):\n",
388 | " super().__init__(*args, mlm=mlm, **kwargs)\n",
389 | " self.ignore_index = ignore_index\n",
390 | "\n",
391 | " def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:\n",
392 | " batch = super().torch_call(examples)\n",
393 | "\n",
394 | " for i in range(len(examples)):\n",
395 | " # Find the last non-padding token\n",
396 | " last_token_idx = (batch[\"labels\"][i] != self.ignore_index).nonzero()[-1].item()\n",
397 | " # Set all labels to ignore_index except for the last token\n",
398 | " batch[\"labels\"][i, :last_token_idx] = self.ignore_index\n",
399 | " # If the last token in the text is, for example, \"2\", then this was processed with the old tokenizer into number_token_ids[2]\n",
400 | " # But we don't actually want this because number_token_ids[2] could be something like 27, which is now undefined in the new lm_head. So we map it to the new lm_head index.\n",
401 | " # if this line gives you a keyerror then increase max_seq_length\n",
402 | " batch[\"labels\"][i, last_token_idx] = reverse_map[ batch[\"labels\"][i, last_token_idx].item() ]\n",
403 | "\n",
404 | "\n",
405 | " return batch\n",
406 | "collator = DataCollatorForLastTokenLM(tokenizer=tokenizer)"
407 | ]
408 | },
409 | {
410 | "cell_type": "markdown",
411 | "metadata": {
412 | "id": "idAEIeSQ3xdS"
413 | },
414 | "source": [
415 | "\n",
416 | "### Train the model\n",
417 |     "Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). The arguments below train for one full epoch (`num_train_epochs = 1`); for a quick test run you can instead set `max_steps` to something small like 60. We also support TRL's `DPOTrainer`!"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 9,
423 | "metadata": {
424 | "colab": {
425 | "base_uri": "https://localhost:8080/",
426 | "height": 177,
427 | "referenced_widgets": [
428 | "477d5041a08f4a3a9f7cfa1d98ab48ff",
429 | "25eecdf89b8845a989ce2c8c4d9edbeb",
430 | "e02470ee64ad4f0fb2160b088cdcba7f",
431 | "b7c0858a80684aed9c19ce00dda85815",
432 | "d0b82646d6f549d7ad59e9ff5f426e88",
433 | "d94a8d1f9d294abbb156e39da065910f",
434 | "ebb0f92e750447a39ff23a23ea73b445",
435 | "d4d01dc3290b4174ab13454a28faf972",
436 | "c6941f8c60ef49a8a79ed265c46652c2",
437 | "24aa300026334db1b545bf0fa906112e",
438 | "6d0d021108b54147a62e50fea1ba88ea",
439 | "deb2982b19764ed5aec0b4d80e776279",
440 | "c24904c0a7294f3a93bb64aa38e70316",
441 | "26a77ab74a4a4e21b9afd9798a9f9a29",
442 | "7189bac8d0474bcea50cf8711259516c",
443 | "ce81350896a44331aa6b6960b7370325",
444 | "73d4af57ddc64fb3afd0e2ab068cbcb4",
445 | "672d990adee44df6b58e27fe5804986f",
446 | "3fe95fc9bc034a2db85ed19aedfd4250",
447 | "c1c0053b8e674ed6ac81316d4af82c48",
448 | "74a0e53405cd4d64bd597ad5461ed5dd",
449 | "278d35f6e08d45e2a49c3509dd442f0c",
450 | "b19d83c6f5ad4f04941e6a684678ac07",
451 | "88e100a175e64a3dab6f772847deffd6",
452 | "457b5df3a4294c8a966e233d34c15a82",
453 | "18af2c44a24b4aaabfd192e3fc4bf655",
454 | "61944d9473394be5b19cfd48fd504481",
455 | "68d18369acc540a594aef61a2de07e63",
456 | "69c676782e0b496b820b55c45424177d",
457 | "08dcaabc623c4759a00fbee5430e3ba9",
458 | "45a3bcaffc3f4184b9ae869f7b0ccef4",
459 | "c45f02fbb40e4041ba7f95074f8e1e82",
460 | "0fc1037a17e541eba92a2d6f400ac6eb",
461 | "7bd2cc9aa724408fa9e70795744cad85",
462 | "fb0588bffd7a4238bac3f73052e09335",
463 | "68ff2006b3794e23b4ccbc2c83c52b9b",
464 | "910f2e6fffd24cd7b6e68c932c2a2524",
465 | "6345dc6f40c6444aa05ed2aba7809a3c",
466 | "92eab7fadbed4bc2bc2935b4e3d800cf",
467 | "2dfef2cd79c8463cb7eadad6a04aba91",
468 | "3be5c37493e742aebf0aa29b6723283b",
469 | "273be47263384901b6cf9f249ee3409d",
470 | "2ff515b72bbb43c889e2172c83933803",
471 | "a1cd830a712d490181d176267ce7b6f0",
472 | "1a8da471604841ec9f6c8e0073471c00",
473 | "f899ae16708544379d63960be07b7c32",
474 | "0bdbf9b92f7b4925a5ac28df93fd0fb0",
475 | "c1d8820a789f4899a839e6d28a4f333c",
476 | "e711d7f85eee4fe195fad9fbddcfece2",
477 | "56ed5fd876d94ebbaa8b4e905c348a0d",
478 | "5d461f10bdc44c6b95d070fa9d7425d1",
479 | "f4e0e7a39ad9484f930e6673c35c2f5d",
480 | "baad9118cade4cee9600a7fbf6426e0a",
481 | "16fac1aad22444c4ad42ad0e6e1dfdc9",
482 | "ac35368de2d746b4bed736459d187dbd"
483 | ]
484 | },
485 | "id": "95_Nn-89DhsL",
486 | "outputId": "adb8cb5d-0ec3-4b79-83a7-5691873873e8"
487 | },
488 | "outputs": [
489 | {
490 | "data": {
491 | "application/vnd.jupyter.widget-view+json": {
492 | "model_id": "b03dbbeaeb214c5e9ccc6e6264bd69c2",
493 | "version_major": 2,
494 | "version_minor": 0
495 | },
496 | "text/plain": [
497 | "Map: 0%| | 0/3893 [00:00, ? examples/s]"
498 | ]
499 | },
500 | "metadata": {},
501 | "output_type": "display_data"
502 | }
503 | ],
504 | "source": [
505 | "trainer = SFTTrainer(\n",
506 | " model = model,\n",
507 | " tokenizer = tokenizer,\n",
508 | " train_dataset = train_dataset,\n",
509 | " max_seq_length = max_seq_length,\n",
510 | " dataset_num_proc = 1,\n",
511 | " packing = False, # not needed because group_by_length is True\n",
512 | " args = TrainingArguments(\n",
513 | " per_device_train_batch_size = 32,\n",
514 | " gradient_accumulation_steps = 1,\n",
515 | " warmup_steps = 10,\n",
516 | " learning_rate = 1e-4,\n",
517 | " fp16 = not torch.cuda.is_bf16_supported(),\n",
518 | " bf16 = torch.cuda.is_bf16_supported(),\n",
519 | " logging_steps = 1,\n",
520 | " optim = \"adamw_8bit\",\n",
521 | " weight_decay = 0.01,\n",
522 | " lr_scheduler_type = \"cosine\",\n",
523 | " seed = 3407,\n",
524 | " output_dir = \"outputs\",\n",
525 | " num_train_epochs = 1,\n",
526 | " # report_to = \"wandb\",\n",
527 | " report_to = \"none\",\n",
528 | " group_by_length = True,\n",
529 | " ),\n",
530 | " data_collator=collator,\n",
531 | " dataset_text_field=\"text\",\n",
532 | ")"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 10,
538 | "metadata": {
539 | "cellView": "form",
540 | "colab": {
541 | "base_uri": "https://localhost:8080/"
542 | },
543 | "id": "2ejIt2xSNKKp",
544 | "outputId": "815b67fb-14a2-43ab-d587-3cbe038b4349"
545 | },
546 | "outputs": [
547 | {
548 | "name": "stdout",
549 | "output_type": "stream",
550 | "text": [
551 | "GPU = NVIDIA GeForce RTX 3090. Max memory = 23.999 GB.\n",
552 | "8.41 GB of memory reserved.\n"
553 | ]
554 | }
555 | ],
556 | "source": [
557 | "#@title Show current memory stats\n",
558 | "gpu_stats = torch.cuda.get_device_properties(0)\n",
559 | "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
560 | "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
561 | "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
562 | "print(f\"{start_gpu_memory} GB of memory reserved.\")"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": 11,
568 | "metadata": {
569 | "colab": {
570 | "base_uri": "https://localhost:8080/",
571 | "height": 1000
572 | },
573 | "id": "yqxqAZ7KJ4oL",
574 | "outputId": "16039f41-2abb-44e6-f020-4e1d4fcc0102"
575 | },
576 | "outputs": [
577 | {
578 | "name": "stderr",
579 | "output_type": "stream",
580 | "text": [
581 | "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
582 | " \\\\ /| Num examples = 3,893 | Num Epochs = 1 | Total steps = 122\n",
583 | "O^O/ \\_/ \\ Batch size per device = 32 | Gradient accumulation steps = 1\n",
584 | "\\ / Data Parallel GPUs = 1 | Total batch size (32 x 1 x 1) = 32\n",
585 | " \"-____-\" Trainable parameters = 33,040,384/4,055,518,720 (0.81% trained)\n"
586 | ]
587 | },
588 | {
589 | "data": {
590 | "text/html": [
591 | "\n",
592 | " \n",
593 | " \n",
594 | "
\n",
595 | " [122/122 04:36, Epoch 1/1]\n",
596 | "
\n",
597 | " \n",
598 | " \n",
599 | " \n",
600 | " Step | \n",
601 | " Training Loss | \n",
602 | "
\n",
603 | " \n",
604 | " \n",
605 | " \n",
606 | " 1 | \n",
607 | " 1.357800 | \n",
608 | "
\n",
609 | " \n",
610 | " 2 | \n",
611 | " 1.182200 | \n",
612 | "
\n",
613 | " \n",
614 | " 3 | \n",
615 | " 1.442300 | \n",
616 | "
\n",
617 | " \n",
618 | " 4 | \n",
619 | " 1.152100 | \n",
620 | "
\n",
621 | " \n",
622 | " 5 | \n",
623 | " 1.033100 | \n",
624 | "
\n",
625 | " \n",
626 | " 6 | \n",
627 | " 0.988200 | \n",
628 | "
\n",
629 | " \n",
630 | " 7 | \n",
631 | " 0.841600 | \n",
632 | "
\n",
633 | " \n",
634 | " 8 | \n",
635 | " 1.047500 | \n",
636 | "
\n",
637 | " \n",
638 | " 9 | \n",
639 | " 0.967200 | \n",
640 | "
\n",
641 | " \n",
642 | " 10 | \n",
643 | " 0.908900 | \n",
644 | "
\n",
645 | " \n",
646 | " 11 | \n",
647 | " 0.850100 | \n",
648 | "
\n",
649 | " \n",
650 | " 12 | \n",
651 | " 0.800500 | \n",
652 | "
\n",
653 | " \n",
654 | " 13 | \n",
655 | " 0.602900 | \n",
656 | "
\n",
657 | " \n",
658 | " 14 | \n",
659 | " 0.828200 | \n",
660 | "
\n",
661 | " \n",
662 | " 15 | \n",
663 | " 0.662600 | \n",
664 | "
\n",
665 | " \n",
666 | " 16 | \n",
667 | " 0.689000 | \n",
668 | "
\n",
669 | " \n",
670 | " 17 | \n",
671 | " 0.578200 | \n",
672 | "
\n",
673 | " \n",
674 | " 18 | \n",
675 | " 0.731100 | \n",
676 | "
\n",
677 | " \n",
678 | " 19 | \n",
679 | " 0.568800 | \n",
680 | "
\n",
681 | " \n",
682 | " 20 | \n",
683 | " 0.467700 | \n",
684 | "
\n",
685 | " \n",
686 | " 21 | \n",
687 | " 1.106200 | \n",
688 | "
\n",
689 | " \n",
690 | " 22 | \n",
691 | " 0.909300 | \n",
692 | "
\n",
693 | " \n",
694 | " 23 | \n",
695 | " 0.731900 | \n",
696 | "
\n",
697 | " \n",
698 | " 24 | \n",
699 | " 0.800000 | \n",
700 | "
\n",
701 | " \n",
702 | " 25 | \n",
703 | " 0.636100 | \n",
704 | "
\n",
705 | " \n",
706 | " 26 | \n",
707 | " 0.745200 | \n",
708 | "
\n",
709 | " \n",
710 | " 27 | \n",
711 | " 0.816300 | \n",
712 | "
\n",
713 | " \n",
714 | " 28 | \n",
715 | " 0.698400 | \n",
716 | "
\n",
717 | " \n",
718 | " 29 | \n",
719 | " 0.443900 | \n",
720 | "
\n",
721 | " \n",
722 | " 30 | \n",
723 | " 0.481700 | \n",
724 | "
\n",
725 | " \n",
726 | " 31 | \n",
727 | " 0.577100 | \n",
728 | "
\n",
729 | " \n",
730 | " 32 | \n",
731 | " 0.610900 | \n",
732 | "
\n",
733 | " \n",
734 | " 33 | \n",
735 | " 0.405500 | \n",
736 | "
\n",
737 | " \n",
738 | " 34 | \n",
739 | " 0.698300 | \n",
740 | "
\n",
741 | " \n",
742 | " 35 | \n",
743 | " 0.443000 | \n",
744 | "
\n",
745 | " \n",
746 | " 36 | \n",
747 | " 0.538800 | \n",
748 | "
\n",
749 | " \n",
750 | " 37 | \n",
751 | " 0.426200 | \n",
752 | "
\n",
753 | " \n",
754 | " 38 | \n",
755 | " 0.404100 | \n",
756 | "
\n",
757 | " \n",
758 | " 39 | \n",
759 | " 0.754100 | \n",
760 | "
\n",
761 | " \n",
762 | " 40 | \n",
763 | " 0.349500 | \n",
764 | "
\n",
765 | " \n",
766 | " 41 | \n",
767 | " 0.663100 | \n",
768 | "
\n",
769 | " \n",
770 | " 42 | \n",
771 | " 0.372800 | \n",
772 | "
\n",
773 | " \n",
774 | " 43 | \n",
775 | " 0.407400 | \n",
776 | "
\n",
777 | " \n",
778 | " 44 | \n",
779 | " 0.422400 | \n",
780 | "
\n",
781 | " \n",
782 | " 45 | \n",
783 | " 0.392000 | \n",
784 | "
\n",
785 | " \n",
786 | " 46 | \n",
787 | " 0.458300 | \n",
788 | "
\n",
789 | " \n",
790 | " 47 | \n",
791 | " 0.468700 | \n",
792 | "
\n",
793 | " \n",
794 | " 48 | \n",
795 | " 0.627400 | \n",
796 | "
\n",
797 | " \n",
798 | " 49 | \n",
799 | " 0.364700 | \n",
800 | "
\n",
801 | " \n",
802 | " 50 | \n",
803 | " 0.288000 | \n",
804 | "
\n",
805 | " \n",
806 | " 51 | \n",
807 | " 0.350700 | \n",
808 | "
\n",
809 | " \n",
810 | " 52 | \n",
811 | " 0.255100 | \n",
812 | "
\n",
813 | " \n",
814 | " 53 | \n",
815 | " 0.335900 | \n",
816 | "
\n",
817 | " \n",
818 | " 54 | \n",
819 | " 0.333300 | \n",
820 | "
\n",
821 | " \n",
822 | " 55 | \n",
823 | " 0.299500 | \n",
824 | "
\n",
825 | " \n",
826 | " 56 | \n",
827 | " 0.383600 | \n",
828 | "
\n",
829 | " \n",
830 | " 57 | \n",
831 | " 0.552900 | \n",
832 | "
\n",
833 | " \n",
834 | " 58 | \n",
835 | " 0.114900 | \n",
836 | "
\n",
837 | " \n",
838 | " 59 | \n",
839 | " 0.531400 | \n",
840 | "
\n",
841 | " \n",
842 | " 60 | \n",
843 | " 0.441900 | \n",
844 | "
\n",
845 | " \n",
846 | " 61 | \n",
847 | " 0.586900 | \n",
848 | "
\n",
849 | " \n",
850 | " 62 | \n",
851 | " 0.826700 | \n",
852 | "
\n",
853 | " \n",
854 | " 63 | \n",
855 | " 0.425800 | \n",
856 | "
\n",
857 | " \n",
858 | " 64 | \n",
859 | " 0.369000 | \n",
860 | "
\n",
861 | " \n",
862 | " 65 | \n",
863 | " 0.443500 | \n",
864 | "
\n",
865 | " \n",
866 | " 66 | \n",
867 | " 0.684300 | \n",
868 | "
\n",
869 | " \n",
870 | " 67 | \n",
871 | " 0.519100 | \n",
872 | "
\n",
873 | " \n",
874 | " 68 | \n",
875 | " 0.437900 | \n",
876 | "
\n",
877 | " \n",
878 | " 69 | \n",
879 | " 0.525900 | \n",
880 | "
\n",
881 | " \n",
882 | " 70 | \n",
883 | " 0.226600 | \n",
884 | "
\n",
885 | " \n",
886 | " 71 | \n",
887 | " 0.264700 | \n",
888 | "
\n",
889 | " \n",
890 | " 72 | \n",
891 | " 0.378600 | \n",
892 | "
\n",
893 | " \n",
894 | " 73 | \n",
895 | " 0.392200 | \n",
896 | "
\n",
897 | " \n",
898 | " 74 | \n",
899 | " 0.271700 | \n",
900 | "
\n",
901 | " \n",
902 | " 75 | \n",
903 | " 0.177400 | \n",
904 | "
\n",
905 | " \n",
906 | " 76 | \n",
907 | " 0.299900 | \n",
908 | "
\n",
909 | " \n",
910 | " 77 | \n",
911 | " 0.145200 | \n",
912 | "
\n",
913 | " \n",
914 | " 78 | \n",
915 | " 0.204400 | \n",
916 | "
\n",
917 | " \n",
918 | " 79 | \n",
919 | " 0.361200 | \n",
920 | "
\n",
921 | " \n",
922 | " 80 | \n",
923 | " 0.257100 | \n",
924 | "
\n",
925 | " \n",
926 | " 81 | \n",
927 | " 0.214300 | \n",
928 | "
\n",
929 | " \n",
930 | " 82 | \n",
931 | " 0.532200 | \n",
932 | "
\n",
933 | " \n",
934 | " 83 | \n",
935 | " 0.573500 | \n",
936 | "
\n",
937 | " \n",
938 | " 84 | \n",
939 | " 0.183000 | \n",
940 | "
\n",
941 | " \n",
942 | " 85 | \n",
943 | " 0.089900 | \n",
944 | "
\n",
945 | " \n",
946 | " 86 | \n",
947 | " 0.127200 | \n",
948 | "
\n",
949 | " \n",
950 | " 87 | \n",
951 | " 0.360300 | \n",
952 | "
\n",
953 | " \n",
954 | " 88 | \n",
955 | " 0.415400 | \n",
956 | "
\n",
957 | " \n",
958 | " 89 | \n",
959 | " 0.389200 | \n",
960 | "
\n",
961 | " \n",
962 | " 90 | \n",
963 | " 0.539400 | \n",
964 | "
\n",
965 | " \n",
966 | " 91 | \n",
967 | " 0.322300 | \n",
968 | "
\n",
969 | " \n",
970 | " 92 | \n",
971 | " 0.638500 | \n",
972 | "
\n",
973 | " \n",
974 | " 93 | \n",
975 | " 0.321500 | \n",
976 | "
\n",
977 | " \n",
978 | " 94 | \n",
979 | " 0.411100 | \n",
980 | "
\n",
981 | " \n",
982 | " 95 | \n",
983 | " 0.489000 | \n",
984 | "
\n",
985 | " \n",
986 | " 96 | \n",
987 | " 0.379200 | \n",
988 | "
\n",
989 | " \n",
990 | " 97 | \n",
991 | " 0.321600 | \n",
992 | "
\n",
993 | " \n",
994 | " 98 | \n",
995 | " 0.359100 | \n",
996 | "
\n",
997 | " \n",
998 | " 99 | \n",
999 | " 0.347800 | \n",
1000 | "
\n",
1001 | " \n",
1002 | " 100 | \n",
1003 | " 0.617300 | \n",
1004 | "
\n",
1005 | " \n",
1006 | " 101 | \n",
1007 | " 0.342400 | \n",
1008 | "
\n",
1009 | " \n",
1010 | " 102 | \n",
1011 | " 0.196300 | \n",
1012 | "
\n",
1013 | " \n",
1014 | " 103 | \n",
1015 | " 0.526400 | \n",
1016 | "
\n",
1017 | " \n",
1018 | " 104 | \n",
1019 | " 0.291300 | \n",
1020 | "
\n",
1021 | " \n",
1022 | " 105 | \n",
1023 | " 0.421600 | \n",
1024 | "
\n",
1025 | " \n",
1026 | " 106 | \n",
1027 | " 0.148100 | \n",
1028 | "
\n",
1029 | " \n",
1030 | " 107 | \n",
1031 | " 0.565300 | \n",
1032 | "
\n",
1033 | " \n",
1034 | " 108 | \n",
1035 | " 0.308900 | \n",
1036 | "
\n",
1037 | " \n",
1038 | " 109 | \n",
1039 | " 0.465800 | \n",
1040 | "
\n",
1041 | " \n",
1042 | " 110 | \n",
1043 | " 0.193000 | \n",
1044 | "
\n",
1045 | " \n",
1046 | " 111 | \n",
1047 | " 0.124000 | \n",
1048 | "
\n",
1049 | " \n",
1050 | " 112 | \n",
1051 | " 0.217600 | \n",
1052 | "
\n",
1053 | " \n",
1054 | " 113 | \n",
1055 | " 0.191400 | \n",
1056 | "
\n",
1057 | " \n",
1058 | " 114 | \n",
1059 | " 0.241700 | \n",
1060 | "
\n",
1061 | " \n",
1062 | " 115 | \n",
1063 | " 0.166500 | \n",
1064 | "
\n",
1065 | " \n",
1066 | " 116 | \n",
1067 | " 0.155000 | \n",
1068 | "
\n",
1069 | " \n",
1070 | " 117 | \n",
1071 | " 0.382100 | \n",
1072 | "
\n",
1073 | " \n",
1074 | " 118 | \n",
1075 | " 0.211400 | \n",
1076 | "
\n",
1077 | " \n",
1078 | " 119 | \n",
1079 | " 0.385200 | \n",
1080 | "
\n",
1081 | " \n",
1082 | " 120 | \n",
1083 | " 0.235100 | \n",
1084 | "
\n",
1085 | " \n",
1086 | " 121 | \n",
1087 | " 0.396300 | \n",
1088 | "
\n",
1089 | " \n",
1090 | " 122 | \n",
1091 | " 0.161100 | \n",
1092 | "
\n",
1093 | " \n",
1094 | "
"
1095 | ],
1096 | "text/plain": [
1097 | ""
1098 | ]
1099 | },
1100 | "metadata": {},
1101 | "output_type": "display_data"
1102 | },
1103 | {
1104 | "name": "stdout",
1105 | "output_type": "stream",
1106 | "text": [
1107 | "Unsloth: Will smartly offload gradients to save VRAM!\n"
1108 | ]
1109 | }
1110 | ],
1111 | "source": [
1112 | "trainer_stats = trainer.train()"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "code",
1117 | "execution_count": 12,
1118 | "metadata": {
1119 | "cellView": "form",
1120 | "colab": {
1121 | "base_uri": "https://localhost:8080/"
1122 | },
1123 | "id": "pCqnaKmlO1U9",
1124 | "outputId": "ff1b0842-5966-4dc2-bd98-c20832526b31"
1125 | },
1126 | "outputs": [
1127 | {
1128 | "name": "stdout",
1129 | "output_type": "stream",
1130 | "text": [
1131 | "286.2905 seconds used for training.\n",
1132 | "4.77 minutes used for training.\n",
1133 | "Peak reserved memory = 9.082 GB.\n",
1134 | "Peak reserved memory for training = 0.672 GB.\n",
1135 | "Peak reserved memory % of max memory = 37.843 %.\n",
1136 | "Peak reserved memory for training % of max memory = 2.8 %.\n"
1137 | ]
1138 | }
1139 | ],
1140 | "source": [
1141 | "#@title Show final memory and time stats\n",
1142 | "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
1143 | "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
1144 | "used_percentage = round(used_memory /max_memory*100, 3)\n",
1145 | "lora_percentage = round(used_memory_for_lora/max_memory*100, 3)\n",
1146 | "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
1147 | "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
1148 | "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
1149 | "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
1150 | "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
1151 | "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
1152 | ]
1153 | },
1154 | {
1155 | "cell_type": "markdown",
1156 | "metadata": {
1157 | "id": "ekOmTR1hSNcr"
1158 | },
1159 | "source": [
1160 | "\n",
1161 | "### Inference\n",
1162 | "This part evaluates the model on the val set with batched inference"
1163 | ]
1164 | },
1165 | {
1166 | "cell_type": "code",
1167 | "execution_count": 13,
1168 | "metadata": {},
1169 | "outputs": [
1170 | {
1171 | "name": "stdout",
1172 | "output_type": "stream",
1173 | "text": [
1174 | "\n"
1175 | ]
1176 | }
1177 | ],
1178 | "source": [
1179 | "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n",
1180 | "print()"
1181 | ]
1182 | },
1183 | {
1184 | "cell_type": "markdown",
1185 | "metadata": {},
1186 | "source": [
1187 | "### remake the old lm_head but with unused tokens having -1000 bias and 0 weights (improves compatibility with libraries like vllm)"
1188 | ]
1189 | },
1190 | {
1191 | "cell_type": "code",
1192 | "execution_count": 14,
1193 | "metadata": {},
1194 | "outputs": [
1195 | {
1196 | "name": "stdout",
1197 | "output_type": "stream",
1198 | "text": [
1199 | "Remade lm_head: shape = torch.Size([151936, 2560]). Allowed tokens: [15, 16, 17, 18]\n"
1200 | ]
1201 | }
1202 | ],
1203 | "source": [
1204 | "# Save the current (trimmed) lm_head and bias\n",
1205 | "trimmed_lm_head = model.lm_head.weight.data.clone()\n",
1206 | "trimmed_lm_head_bias = model.lm_head.bias.data.clone() if hasattr(model.lm_head, \"bias\") and model.lm_head.bias is not None else torch.zeros(len(number_token_ids), device=trimmed_lm_head.device)\n",
1207 | "\n",
1208 | "# Create a new lm_head with shape [old_size, hidden_dim]\n",
1209 | "hidden_dim = trimmed_lm_head.shape[1]\n",
1210 | "new_lm_head = torch.full((old_size, hidden_dim), 0, dtype=trimmed_lm_head.dtype, device=trimmed_lm_head.device)\n",
1211 | "new_lm_head_bias = torch.full((old_size,), -1000.0, dtype=trimmed_lm_head_bias.dtype, device=trimmed_lm_head_bias.device)\n",
1212 | "\n",
1213 | "# Fill in the weights and bias for the allowed tokens (number_token_ids)\n",
1214 | "for new_idx, orig_token_id in enumerate(number_token_ids):\n",
1215 | " new_lm_head[orig_token_id] = trimmed_lm_head[new_idx]\n",
1216 | " new_lm_head_bias[orig_token_id] = trimmed_lm_head_bias[new_idx]\n",
1217 | "\n",
1218 | "# Update the model's lm_head weight and bias\n",
1219 | "with torch.no_grad():\n",
1220 | " new_lm_head_module = torch.nn.Linear(hidden_dim, old_size, bias=True, device=model.device)\n",
1221 | " new_lm_head_module.weight.data.copy_(new_lm_head)\n",
1222 | " new_lm_head_module.bias.data.copy_(new_lm_head_bias)\n",
1223 | " model.lm_head.modules_to_save[\"default\"] = new_lm_head_module\n",
1224 | "\n",
1225 | "print(f\"Remade lm_head: shape = {model.lm_head.weight.shape}. Allowed tokens: {number_token_ids}\")"
1226 | ]
1227 | },
1228 | {
1229 | "cell_type": "markdown",
1230 | "metadata": {},
1231 | "source": [
1232 | "# Batched Inference on Validation Set"
1233 | ]
1234 | },
1235 | {
1236 | "cell_type": "code",
1237 | "execution_count": 17,
1238 | "metadata": {},
1239 | "outputs": [
1240 | {
1241 | "name": "stderr",
1242 | "output_type": "stream",
1243 | "text": [
1244 | "Evaluating: 100%|██████████| 28/28 [00:16<00:00, 1.71it/s]"
1245 | ]
1246 | },
1247 | {
1248 | "name": "stdout",
1249 | "output_type": "stream",
1250 | "text": [
1251 | "\n",
1252 | "Validation accuracy: 83.37% (361/433)\n",
1253 | "\n",
1254 | "--- Random samples ---\n",
1255 | "\n",
1256 | "Text: $LEVI - Levi Strauss EPS beats by $0.05, beats on revenue https://t.co/UfIyY92vC5\n",
1257 | "True: 2 Pred: 2 ✅\n",
1258 | "Probs: 1: 0.005, 2: 0.992, 3: 0.003\n",
1259 | "\n",
1260 | "Text: U.S. officials are signaling the European Union might be an easier target than the U.K. for a quick outcome aimed a… https://t.co/c5D7GJJbPf\n",
1261 | "True: 1 Pred: 1 ✅\n",
1262 | "Probs: 1: 0.935, 2: 0.028, 3: 0.036\n",
1263 | "\n",
1264 | "Text: Some odd divergences in the past two weeks as most asset classes except stocks are reversing recent gains https://t.co/0XJwxOyhBA\n",
1265 | "True: 1 Pred: 3 ❌\n",
1266 | "Probs: 1: 0.066, 2: 0.027, 3: 0.907\n",
1267 | "\n",
1268 | "Text: Some officials worried about the bank's decision to drop self-imposed limits on the ECB's bond purchases https://t.co/elfM7nRACS\n",
1269 | "True: 3 Pred: 1 ❌\n",
1270 | "Probs: 1: 0.715, 2: 0.022, 3: 0.263\n",
1271 | "\n",
1272 | "Text: Is Qorvo, Inc.'s (NASDAQ:QRVO) 5.9% ROE Worse Than Average?\n",
1273 | "True: 1 Pred: 1 ✅\n",
1274 | "Probs: 1: 0.762, 2: 0.020, 3: 0.218\n",
1275 | "\n",
1276 | "Text: Here's the level to watch in Texaco as stock falls on London ban $UBER (via @TradingNation) https://t.co/fk3KnVRU95\n",
1277 | "True: 3 Pred: 3 ✅\n",
1278 | "Probs: 1: 0.067, 2: 0.009, 3: 0.924\n",
1279 | "\n",
1280 | "Text: $INPX is gaining momentum......\n",
1281 | "True: 2 Pred: 2 ✅\n",
1282 | "Probs: 1: 0.473, 2: 0.503, 3: 0.024\n",
1283 | "\n",
1284 | "Text: First Coronavirus Case Reported In New York https://t.co/Cy1yurkH6a\n",
1285 | "True: 1 Pred: 1 ✅\n",
1286 | "Probs: 1: 0.523, 2: 0.016, 3: 0.461\n",
1287 | "\n",
1288 | "Text: $MCIG - MCig up 9% on CBD distribution deals https://t.co/WV5yrJFdyD\n",
1289 | "True: 2 Pred: 2 ✅\n",
1290 | "Probs: 1: 0.242, 2: 0.745, 3: 0.014\n",
1291 | "\n",
1292 | "Text: Defence Expo: Defence Equipment Makers See This As Another Opportunity In India\n",
1293 | "True: 1 Pred: 1 ✅\n",
1294 | "Probs: 1: 0.889, 2: 0.106, 3: 0.005\n",
1295 | "\n",
1296 | "Text: 'We know that there will be very likely some effects on the United States', said the Fed chairman Jay Powell about… https://t.co/eAeVR5GUm2\n",
1297 | "True: 1 Pred: 1 ✅\n",
1298 | "Probs: 1: 0.777, 2: 0.050, 3: 0.173\n",
1299 | "\n",
1300 | "Text: $COMDX: Natural gas inventory showed a draw of 201 bcf vs a 92 bcf draw last week https://t.co/CGfYWf1Unq\n",
1301 | "True: 2 Pred: 1 ❌\n",
1302 | "Probs: 1: 0.743, 2: 0.188, 3: 0.069\n",
1303 | "\n",
1304 | "Text: Axcelis Technologies stock price target raised to $32 from $24 at Benchmark\n",
1305 | "True: 2 Pred: 2 ✅\n",
1306 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n",
1307 | "\n",
1308 | "Text: Methanex downgraded at TD Securities on valuation\n",
1309 | "True: 3 Pred: 3 ✅\n",
1310 | "Probs: 1: 0.003, 2: 0.003, 3: 0.994\n",
1311 | "\n",
1312 | "Text: Xeris Pharma launches equity offering; shares down 3% after hours\n",
1313 | "True: 3 Pred: 3 ✅\n",
1314 | "Probs: 1: 0.011, 2: 0.007, 3: 0.982\n",
1315 | "\n",
1316 | "Text: Dizzying swings are torching Wall Street predictions https://t.co/BvV6iItVRt\n",
1317 | "True: 3 Pred: 3 ✅\n",
1318 | "Probs: 1: 0.118, 2: 0.008, 3: 0.874\n",
1319 | "\n",
1320 | "Text: Home Depot cuts sales goal as online push not delivering as expected\n",
1321 | "True: 3 Pred: 3 ✅\n",
1322 | "Probs: 1: 0.009, 2: 0.005, 3: 0.987\n",
1323 | "\n",
1324 | "Text: Stocks Gain as U.S. Economy Signals Strength -- Update #SP500 #index #MarketScreener https://t.co/e9i8yV5o1I https://t.co/phbncHu6aL\n",
1325 | "True: 2 Pred: 1 ❌\n",
1326 | "Probs: 1: 0.615, 2: 0.373, 3: 0.011\n",
1327 | "\n",
1328 | "Text: Stock-index futures turn lower\n",
1329 | "True: 3 Pred: 1 ❌\n",
1330 | "Probs: 1: 0.835, 2: 0.037, 3: 0.128\n",
1331 | "\n",
1332 | "Text: Tesla Inc. slumped in pre-market trading after the electric carmaker’s newly unveiled pickup truck elicited mixed r… https://t.co/KDcJiSWCOp\n",
1333 | "True: 3 Pred: 3 ✅\n",
1334 | "Probs: 1: 0.037, 2: 0.004, 3: 0.958\n",
1335 | "\n",
1336 | "Text: Coronavirus reports hang over cruise line sector\n",
1337 | "True: 3 Pred: 3 ✅\n",
1338 | "Probs: 1: 0.085, 2: 0.006, 3: 0.909\n",
1339 | "\n",
1340 | "Text: U.S. Xpress EPS beats by $0.04, beats on revenue\n",
1341 | "True: 2 Pred: 2 ✅\n",
1342 | "Probs: 1: 0.006, 2: 0.990, 3: 0.004\n",
1343 | "\n",
1344 | "Text: Elastic +3.7% as Canaccord turns bullish\n",
1345 | "True: 2 Pred: 2 ✅\n",
1346 | "Probs: 1: 0.014, 2: 0.983, 3: 0.003\n",
1347 | "\n",
1348 | "Text: Whirlpool -2% after large recall in U.K.\n",
1349 | "True: 3 Pred: 3 ✅\n",
1350 | "Probs: 1: 0.042, 2: 0.007, 3: 0.951\n",
1351 | "\n",
1352 | "Text: Americans' outlook on the economy faltered significantly last month as the coronavirus crisis began to take hold in… https://t.co/5jeCXLXrrR\n",
1353 | "True: 3 Pred: 3 ✅\n",
1354 | "Probs: 1: 0.047, 2: 0.005, 3: 0.948\n",
1355 | "\n",
1356 | "Text: Pizza Hut's Struggling Turnaround Weighs on Yum Brands Results\n",
1357 | "True: 3 Pred: 3 ✅\n",
1358 | "Probs: 1: 0.014, 2: 0.007, 3: 0.979\n",
1359 | "\n",
1360 | "Text: FDM : AFL-CIO Endorses USMCA After Successfully Negotiating Improvements #FDM #Stock #MarketScreener… https://t.co/Ja3PJ0uKZB\n",
1361 | "True: 1 Pred: 1 ✅\n",
1362 | "Probs: 1: 0.903, 2: 0.084, 3: 0.013\n",
1363 | "\n",
1364 | "Text: LexinFintech +3.3% after loan origination view improves\n",
1365 | "True: 2 Pred: 2 ✅\n",
1366 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n",
1367 | "\n",
1368 | "Text: $BLPH - Bellerophon Therapeutics EPS misses by $0.33 https://t.co/foAfyMnyra\n",
1369 | "True: 3 Pred: 3 ✅\n",
1370 | "Probs: 1: 0.060, 2: 0.007, 3: 0.933\n",
1371 | "\n",
1372 | "Text: Wholesale trade underwhelm in December\n",
1373 | "True: 3 Pred: 3 ✅\n",
1374 | "Probs: 1: 0.060, 2: 0.007, 3: 0.933\n",
1375 | "\n",
1376 | "Text: Why Disney+ is the only service that can rival Netflix\n",
1377 | "True: 2 Pred: 1 ❌\n",
1378 | "Probs: 1: 0.978, 2: 0.014, 3: 0.008\n",
1379 | "\n",
1380 | "Text: $CTSO: CytoSorbents says temporarily pausing enrollment of REFRESH 2-AKI study at the recommendation of its Data... https://t.co/6ibg4NhPh1\n",
1381 | "True: 3 Pred: 1 ❌\n",
1382 | "Probs: 1: 0.519, 2: 0.023, 3: 0.458\n",
1383 | "\n",
1384 | "Text: Twitter stock falls after downgrade\n",
1385 | "True: 3 Pred: 1 ❌\n",
1386 | "Probs: 1: 0.829, 2: 0.044, 3: 0.127\n",
1387 | "\n",
1388 | "Text: LAIX beats on revenue\n",
1389 | "True: 2 Pred: 1 ❌\n",
1390 | "Probs: 1: 0.898, 2: 0.069, 3: 0.033\n",
1391 | "\n",
1392 | "Text: Coty +5% after striking Kylie Jenner deal\n",
1393 | "True: 2 Pred: 2 ✅\n",
1394 | "Probs: 1: 0.001, 2: 0.998, 3: 0.001\n",
1395 | "\n",
1396 | "Text: Airline stocks are higher amid some positive signs within the fight against the coronavirus pandemic https://t.co/BOK9NTJJuv\n",
1397 | "True: 2 Pred: 2 ✅\n",
1398 | "Probs: 1: 0.265, 2: 0.720, 3: 0.015\n",
1399 | "\n",
1400 | "Text: Estee Lauder Cuts Profit Outlook Again, Citing Coronavirus\n",
1401 | "True: 3 Pred: 3 ✅\n",
1402 | "Probs: 1: 0.018, 2: 0.002, 3: 0.980\n",
1403 | "\n",
1404 | "Text: H&P downgraded at Argus as drilling industry weakness seen persisting\n",
1405 | "True: 3 Pred: 3 ✅\n",
1406 | "Probs: 1: 0.006, 2: 0.004, 3: 0.990\n",
1407 | "\n",
1408 | "Text: Hedge Funds Aren’t Crazy About Ladder Capital Corp (LADR) Anymore\n",
1409 | "True: 3 Pred: 3 ✅\n",
1410 | "Probs: 1: 0.093, 2: 0.027, 3: 0.881\n",
1411 | "\n",
1412 | "Text: RPM International stock price target raised to $79 vs. $75 at BofA Merrill Lynch\n",
1413 | "True: 2 Pred: 2 ✅\n",
1414 | "Probs: 1: 0.018, 2: 0.978, 3: 0.004\n",
1415 | "\n",
1416 | "Text: PG&E boss says it wasn’t fully ready for California outages\n",
1417 | "True: 1 Pred: 3 ❌\n",
1418 | "Probs: 1: 0.416, 2: 0.050, 3: 0.534\n",
1419 | "\n",
1420 | "Text: The euro-area economy came close to a halt in November as the steep decline in manufacturing spread further into se… https://t.co/PKMjM0loEx\n",
1421 | "True: 3 Pred: 3 ✅\n",
1422 | "Probs: 1: 0.037, 2: 0.004, 3: 0.958\n",
1423 | "\n",
1424 | "Text: Exxon, Chevron results augur tough year ahead, shares drop 3% #economy #MarketScreener https://t.co/sABok0wweo https://t.co/BYDIfnQivg\n",
1425 | "True: 3 Pred: 1 ❌\n",
1426 | "Probs: 1: 0.801, 2: 0.123, 3: 0.075\n",
1427 | "\n",
1428 | "Text: $SOYB $MOO $FTAG - Soybeans rebound as China trade talks make progress https://t.co/1mf2YHIujB\n",
1429 | "True: 2 Pred: 2 ✅\n",
1430 | "Probs: 1: 0.002, 2: 0.996, 3: 0.002\n",
1431 | "\n",
1432 | "Text: eDreams ODIGEO S.A. reports Q2 results\n",
1433 | "True: 1 Pred: 1 ✅\n",
1434 | "Probs: 1: 0.989, 2: 0.005, 3: 0.006\n",
1435 | "\n",
1436 | "Text: Welcome back, Wall Street! @JimCramer and @byKatherineRoss are breaking down all the latest on the markets and the… https://t.co/lzJSAKAv7s\n",
1437 | "True: 1 Pred: 1 ✅\n",
1438 | "Probs: 1: 0.877, 2: 0.119, 3: 0.005\n",
1439 | "\n",
1440 | "Text: U.S. Job Report Looks Likely to Show Hot 2020 Start, Cooler Past\n",
1441 | "True: 2 Pred: 1 ❌\n",
1442 | "Probs: 1: 0.698, 2: 0.291, 3: 0.011\n",
1443 | "\n",
1444 | "Text: Nigeria's petroleum bill to be passed by mid-2020, says oil minister #economy #MarketScreener… https://t.co/iTMZMP6wLO\n",
1445 | "True: 1 Pred: 1 ✅\n",
1446 | "Probs: 1: 0.788, 2: 0.199, 3: 0.013\n",
1447 | "\n",
1448 | "Text: $ECONX: Nonfarm Payroll Revisions- October revised to 156K from 128K; September revised to 193K from 180K https://t.co/WOg237QKLC\n",
1449 | "True: 2 Pred: 1 ❌\n",
1450 | "Probs: 1: 0.753, 2: 0.131, 3: 0.116\n",
1451 | "\n",
1452 | "Text: WhatsApp Closes In On Full Fledged UPI Payments Launch\n",
1453 | "True: 1 Pred: 1 ✅\n",
1454 | "Probs: 1: 0.826, 2: 0.163, 3: 0.012\n"
1455 | ]
1456 | },
1457 | {
1458 | "name": "stderr",
1459 | "output_type": "stream",
1460 | "text": [
1461 | "\n"
1462 | ]
1463 | }
1464 | ],
1465 | "source": [
1466 | "import torch\n",
1467 | "import torch.nn.functional as F\n",
1468 | "from tqdm import tqdm\n",
1469 | "import random\n",
1470 | "\n",
1471 | "# Prepare inference prompt\n",
1472 | "inference_prompt_template = prompt.split(\"class {}\")[0] + \"class \"\n",
1473 | "\n",
1474 | "# Sort validation set by length for efficient batching\n",
1475 | "val_df['token_length'] = val_df['text'].apply(lambda x: len(tokenizer.encode(x, add_special_tokens=False)))\n",
1476 | "val_df_sorted = val_df.sort_values(by='token_length').reset_index(drop=True)\n",
1477 | "\n",
1478 | "display = 50\n",
1479 | "batch_size = 16\n",
1480 | "device = model.device\n",
1481 | "correct = 0\n",
1482 | "results = []\n",
1483 | "\n",
1484 | "with torch.inference_mode():\n",
1485 | " for i in tqdm(range(0, len(val_df_sorted), batch_size), desc=\"Evaluating\"):\n",
1486 | " batch = val_df_sorted.iloc[i:i+batch_size]\n",
1487 | " prompts = [inference_prompt_template.format(text) for text in batch['text']]\n",
1488 | " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_seq_length).to(device)\n",
1489 | " logits = model(**inputs).logits\n",
1490 | " last_idxs = inputs.attention_mask.sum(1) - 1\n",
1491 | " last_logits = logits[torch.arange(len(batch)), last_idxs, :]\n",
1492 | " probs_all = F.softmax(last_logits, dim=-1)\n",
1493 | " probs = probs_all[:, number_token_ids] # only keep the logits for the number tokens\n",
1494 | " preds = torch.argmax(probs, dim=-1).cpu().numpy() # looks like [1 1 1 1 3 1 3 1 3 1 1 1 1 2 2 3]\n",
1495 | "\n",
1496 | " true_labels = batch['label'].tolist()\n",
1497 | " correct += sum([p == t for p, t in zip(preds, true_labels)])\n",
1498 | " # Store a few samples for display\n",
1499 | " for j in range(len(batch)):\n",
1500 | " results.append({\n",
1501 | " \"text\": batch['text'].iloc[j][:200],\n",
1502 | " \"true\": true_labels[j],\n",
1503 | " \"pred\": preds[j],\n",
1504 | " \"probs\": probs[j][1:].float().cpu().numpy(), # ignore prob for class 0 and convert from tensor to float\n",
1505 | " \"ok\": preds[j] == true_labels[j]\n",
1506 | " })\n",
1507 | "\n",
1508 | "accuracy = 100 * correct / len(val_df_sorted)\n",
1509 | "print(f\"\\nValidation accuracy: {accuracy:.2f}% ({correct}/{len(val_df_sorted)})\")\n",
1510 | "\n",
1511 | "print(\"\\n--- Random samples ---\")\n",
1512 | "for s in random.sample(results, min(display, len(results))):\n",
1513 | " print(f\"\\nText: {s['text']}\")\n",
1514 | " print(f\"True: {s['true']} Pred: {s['pred']} {'✅' if s['ok'] else '❌'}\")\n",
1515 | " print(\"Probs:\", \", \".join([f\"{k}: {v:.3f}\" for k, v in enumerate(s['probs'], start=1)]))\n",
1516 | "\n",
1517 | "# Clean up\n",
1518 | "if 'token_length' in val_df:\n",
1519 | " del val_df['token_length']"
1520 | ]
1521 | },
1522 | {
1523 | "cell_type": "code",
1524 | "execution_count": 16,
1525 | "metadata": {},
1526 | "outputs": [
1527 | {
1528 | "ename": "ZeroDivisionError",
1529 | "evalue": "division by zero",
1530 | "output_type": "error",
1531 | "traceback": [
1532 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
1533 | "\u001b[1;31mZeroDivisionError\u001b[0m Traceback (most recent call last)",
1534 | "Cell \u001b[1;32mIn[16], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# stop running all cells\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;241m1\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m0\u001b[39m\n",
1535 | "\u001b[1;31mZeroDivisionError\u001b[0m: division by zero"
1536 | ]
1537 | }
1538 | ],
1539 | "source": [
1540 | "# stop running all cells\n",
1541 | "1/0"
1542 | ]
1543 | },
1544 | {
1545 | "cell_type": "markdown",
1546 | "metadata": {},
1547 | "source": [
1548 | "Now if you closed the notebook kernel and want to reload the model:"
1549 | ]
1550 | },
1551 | {
1552 | "cell_type": "code",
1553 | "execution_count": null,
1554 | "metadata": {},
1555 | "outputs": [
1556 | {
1557 | "name": "stdout",
1558 | "output_type": "stream",
1559 | "text": [
1560 | "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
1561 | "🦥 Unsloth Zoo will now patch everything to make training faster!\n"
1562 | ]
1563 | },
1564 | {
1565 | "name": "stderr",
1566 | "output_type": "stream",
1567 | "text": [
1568 | "c:\\ProgramData\\Anaconda3\\Lib\\site-packages\\unsloth_zoo\\gradient_checkpointing.py:330: UserWarning: expandable_segments not supported on this platform (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\c10/cuda/CUDAAllocatorConfig.h:28.)\n",
1569 | " GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f\"cuda:{i}\") for i in range(n_gpus)])\n"
1570 | ]
1571 | },
1572 | {
1573 | "name": "stdout",
1574 | "output_type": "stream",
1575 | "text": [
1576 | "==((====))== Unsloth 2025.4.5: Fast Qwen3 patching. Transformers: 4.51.3.\n",
1577 | " \\\\ /| NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.999 GB. Platform: Windows.\n",
1578 | "O^O/ \\_/ \\ Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0\n",
1579 | "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]\n",
1580 | " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
1581 | "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
1582 | ]
1583 | },
1584 | {
1585 | "data": {
1586 | "application/vnd.jupyter.widget-view+json": {
1587 | "model_id": "73992b1f949448ce9981a732af2d7b66",
1588 | "version_major": 2,
1589 | "version_minor": 0
1590 | },
1591 | "text/plain": [
1592 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
1593 | ]
1594 | },
1595 | "metadata": {},
1596 | "output_type": "display_data"
1597 | },
1598 | {
1599 | "name": "stderr",
1600 | "output_type": "stream",
1601 | "text": [
1602 | "Unsloth 2025.4.5 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.\n"
1603 | ]
1604 | },
1605 | {
1606 | "name": "stdout",
1607 | "output_type": "stream",
1608 | "text": [
1609 | "Model loaded successfully.\n",
1610 | "[\"Here is a financial news:\\nFor the global oil market, the coronavirus epidemic couldn't have hit a worse place\\n\\nClassify this news into one of the following:\\nclass 1: Bullish\\nclass 2: Neutral\\nclass 3: Bearish\\n\\nSOLUTION\\nThe correct answer is: class 3\"]\n"
1611 | ]
1612 | }
1613 | ],
1614 | "source": [
1615 | "# load the model\n",
1616 | "from unsloth import FastLanguageModel\n",
1617 | "model, tokenizer = FastLanguageModel.from_pretrained(\n",
1618 | " \"lora_model_Qwen3-4B-Base\",\n",
1619 | " load_in_4bit = False,\n",
1620 | " max_seq_length = 2048,\n",
1621 | " dtype = None,\n",
1622 | ")\n",
1623 | "print(\"Model loaded successfully.\")\n",
1624 | "\n",
1625 | "FastLanguageModel.for_inference(model)\n",
1626 | "\n",
1627 | "prompt = \"\"\"Here is a financial news:\n",
1628 | "For the global oil market, the coronavirus epidemic couldn't have hit a worse place\n",
1629 | "\n",
1630 | "Classify this news into one of the following:\n",
1631 | "class 1: Bullish\n",
1632 | "class 2: Neutral\n",
1633 | "class 3: Bearish\n",
1634 | "\n",
1635 | "SOLUTION\n",
1636 | "The correct answer is: class \"\"\"\n",
1637 | "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
1638 | "outputs = model.generate(**inputs, max_new_tokens=1, use_cache=True)\n",
1639 | "decoded = tokenizer.batch_decode(outputs)\n",
1640 | "print(decoded)"
1641 | ]
1642 | },
1643 | {
1644 | "cell_type": "markdown",
1645 | "metadata": {
1646 | "id": "f422JgM9sdVT"
1647 | },
1648 | "source": [
1649 | "### Saving to float16 for VLLM\n",
1650 | "\n",
1651 | "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."
1652 | ]
1653 | },
1654 | {
1655 | "cell_type": "code",
1656 | "execution_count": null,
1657 | "metadata": {
1658 | "id": "iHjt_SMYsd3P"
1659 | },
1660 | "outputs": [],
1661 | "source": [
1662 | "# Merge to 16bit\n",
1663 | "if False: model.save_pretrained_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\",)\n",
1664 | "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
1665 | "\n",
1666 | "# Merge to 4bit\n",
1667 | "if False: model.save_pretrained_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\",)\n",
1668 | "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
1669 | "\n",
1670 | "# Just LoRA adapters\n",
1671 | "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n",
1672 | "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")"
1673 | ]
1674 | },
1675 | {
1676 | "cell_type": "markdown",
1677 | "metadata": {
1678 | "id": "TCv4vXHd61i7"
1679 | },
1680 | "source": [
1681 | "### GGUF / llama.cpp Conversion\n",
1682 | "To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF."
1683 | ]
1684 | },
1685 | {
1686 | "cell_type": "code",
1687 | "execution_count": null,
1688 | "metadata": {
1689 | "id": "FqfebeAdT073"
1690 | },
1691 | "outputs": [],
1692 | "source": [
1693 | "# Save to 8bit Q8_0\n",
1694 | "if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n",
1695 | "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n",
1696 | "\n",
1697 | "# Save to 16bit GGUF\n",
1698 | "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n",
1699 | "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n",
1700 | "\n",
1701 | "# Save to q4_k_m GGUF\n",
1702 | "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n",
1703 | "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")"
1704 | ]
1705 | },
1706 | {
1707 | "cell_type": "markdown",
1708 | "metadata": {
1709 | "id": "bDp0zNpwe6U_"
1710 | },
1711 | "source": [
1712 | "Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in `llama.cpp` or a UI based system like `GPT4All`. You can install GPT4All by going [here](https://gpt4all.io/index.html)."
1713 | ]
1714 | },
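{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is a minimal, untested sketch using the `llama-cpp-python` bindings (`pip install llama-cpp-python`). The GGUF file name below is an assumption based on the note above; point `model_path` at whatever file `save_pretrained_gguf` actually produced."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal, untested sketch: classify one headline with the exported GGUF via llama-cpp-python.\n",
"# Assumption: the q4_k_m export above produced \"model/model-unsloth-Q4_K_M.gguf\"; adjust the path if yours differs.\n",
"from llama_cpp import Llama\n",
"\n",
"llm = Llama(model_path = \"model/model-unsloth-Q4_K_M.gguf\", n_ctx = 2048, verbose = False)\n",
"\n",
"gguf_prompt = \"\"\"Here is a financial news:\n",
"Methanex downgraded at TD Securities on valuation\n",
"\n",
"Classify this news into one of the following:\n",
"class 1: Bullish\n",
"class 2: Neutral\n",
"class 3: Bearish\n",
"\n",
"SOLUTION\n",
"The correct answer is: class \"\"\"\n",
"\n",
"out = llm(gguf_prompt, max_tokens = 1, temperature = 0.0)\n",
"print(out[\"choices\"][0][\"text\"]) # the validation run above predicted class 3 for this headline"
]
},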
1715 | {
1716 | "cell_type": "markdown",
1717 | "metadata": {
1718 | "id": "Zt9CHJqO6p30"
1719 | },
1720 | "source": [
1721 | "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/u54VK8m8tk) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!\n",
1722 | "\n",
1723 | "Some other links:\n",
1724 | "1. Zephyr DPO 2x faster [free Colab](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing)\n",
1725 | "2. Llama 7b 2x faster [free Colab](https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing)\n",
1726 | "3. TinyLlama 4x faster full Alpaca 52K in 1 hour [free Colab](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing)\n",
1727 | "4. CodeLlama 34b 2x faster [A100 on Colab](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing)\n",
1728 | "5. Llama 7b [free Kaggle](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp)\n",
1729 | "6. We also did a [blog](https://huggingface.co/blog/unsloth-trl) with 🤗 HuggingFace, and we're in the TRL [docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth)!\n",
1730 | "\n",
1731 | "\n",
1732 | "

\n",
1733 | "

\n",
1734 | "

Support our work if you can! Thanks!\n",
1735 | "
"
1736 | ]
1737 | }
1738 | ],
1739 | "metadata": {
1740 | "accelerator": "GPU",
1741 | "colab": {
1742 | "gpuType": "T4",
1743 | "provenance": []
1744 | },
1745 | "kaggle": {
1746 | "accelerator": "gpu",
1747 | "dataSources": [
1748 | {
1749 | "datasetId": 5081962,
1750 | "sourceId": 8512897,
1751 | "sourceType": "datasetVersion"
1752 | }
1753 | ],
1754 | "dockerImageVersionId": 30733,
1755 | "isGpuEnabled": true,
1756 | "isInternetEnabled": true,
1757 | "language": "python",
1758 | "sourceType": "notebook"
1759 | },
1760 | "kernelspec": {
1761 | "display_name": "Python 3 (ipykernel)",
1762 | "language": "python",
1763 | "name": "python3"
1764 | },
1765 | "language_info": {
1766 | "codemirror_mode": {
1767 | "name": "ipython",
1768 | "version": 3
1769 | },
1770 | "file_extension": ".py",
1771 | "mimetype": "text/x-python",
1772 | "name": "python",
1773 | "nbconvert_exporter": "python",
1774 | "pygments_lexer": "ipython3",
1775 | "version": "3.12.3"
1776 | }
1777 | },
1778 | "nbformat": 4,
1779 | "nbformat_minor": 4
1780 | }
1781 |
--------------------------------------------------------------------------------