├── README.md ├── sample_reviews.parquet ├── subject_transfer_learning_v1.pdf └── transfer_learning_AI.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # **Transfer Learning and Fine-Tuning in Sentiment Analysis** - 42 Urduliz Bizkaia AI Project 2 | 3 | ## **Project Overview** 4 | 5 | This project introduces **Transfer Learning** in **Sentiment Analysis**, a key NLP technique used to determine whether a text conveys positive or negative sentiment. By leveraging **pretrained large language models (LLMs)**, we aim to create a sentiment analysis system with minimal training data and computational resources. 6 | 7 | The focus is on developing a binary sentiment analysis model using **transfer learning** techniques. Students will explore pretrained models, experiment with tokenization, and adjust **hyperparameters** to optimize performance while considering resource limitations. 8 | 9 | ## **Project Aim** 10 | 11 | The project provides hands-on experience in: 12 | 1. **Transfer Learning**: Adapt pretrained models like **BERT**, **RoBERTa**, or **GPT-2** for sentiment analysis. 13 | 2. **Sentiment Analysis**: Create a binary classifier to categorize text (positive/negative). 14 | 3. **Tokenization**: Research and apply suitable tokenization methods. 15 | 4. **Hyperparameter Tuning**: Optimize hyperparameters like **learning rate** and **batch size**. 16 | 5. **Resource Efficiency**: Train models within limited computing resources. 17 | 18 | ## **Key Features** 19 | - Choose from a range of **pretrained models** (e.g., BERT, RoBERTa). 20 | - Flexible resources: Use **local or cloud computing** (e.g., **sgoinfre**). 21 | - Customize the project based on personal research and interests. 22 | 23 | ## **Important Clarification** 24 | 25 | The provided **sample dataset** is not intended for training but to give an idea of what sentiment analysis datasets look like. Please choose your own dataset (e.g., IMDB, Yelp, Amazon reviews) for training and evaluation. 26 | 27 | ## **Bonus Challenges** 28 | - Test your model on additional datasets for generalization. 29 | - Compare tokenization techniques. 30 | - Deploy your model for real-time sentiment prediction via a web interface. 31 | 32 | ## **Note on Fine-Tuning** 33 | **Fine-tuning pretrained models** is a critical part of this project. By adjusting hyperparameters and training on your specific dataset, you can improve model performance and efficiency. 34 | 35 | # Solution Notebook: 36 | ### transfer_learning_AI.ipynb 37 | 38 | A complete solution to the project is provided in the notebook transfer_learning_AI.ipynb. This notebook was developed and executed in Amazon SageMaker Studio Lab, a cloud-based environment designed for experimenting with machine learning workflows in a flexible and scalable way. 39 | 40 | The solution reflects the project's core objective: to gain hands-on experience in cloud-based development using transfer learning and fine-tuning techniques in Natural Language Processing (NLP). 41 | 42 | The notebook includes: 43 | 44 | ✅ The full implementation of a sentiment analysis model using transfer learning. 45 | 46 | 47 | ✅ Executed code cells, allowing immediate inspection of results and outputs. 48 | 49 | 50 | ✅ Additional code snippets and in-line explanations to support a deeper understanding of each phase, from tokenization to evaluation.
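For a quick feel for that workflow before opening the notebook, here is a minimal sketch of the standard Hugging Face fine-tuning pattern it follows. The model (distilbert-base-uncased), the GLUE SST-2 dataset, and the `sentence` column match the notebook; the output directory and training hyperparameters below are illustrative placeholders rather than the notebook's exact settings.

```python
# Minimal sketch of the fine-tuning workflow (illustrative hyperparameters).
from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

# Load the GLUE SST-2 sentiment dataset and a pretrained DistilBERT checkpoint.
dataset = load_dataset("glue", "sst2")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2)

def tokenize(batch):
    # Turn raw sentences into padded/truncated token IDs the model can consume.
    return tokenizer(batch["sentence"], padding="max_length", truncation=True)

encoded = dataset.map(tokenize, batched=True)

# Illustrative settings: tune epochs, batch size and learning rate for your own run.
args = TrainingArguments(output_dir="sst2-distilbert", num_train_epochs=1,
                         per_device_train_batch_size=16, learning_rate=2e-5)
trainer = Trainer(model=model, args=args,
                  train_dataset=encoded["train"],
                  eval_dataset=encoded["validation"])
trainer.train()
```

The notebook applies this same pattern to a small SST-2 subset and adds cells for inspecting tokenization and for labelling sentences with the model.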
51 | 52 | 53 | 💡 This notebook is both a functional solution and a guided walkthrough for understanding how to apply transfer learning in a modern, cloud-based NLP project. 54 | 55 | 56 | Before running the notebook in your own cloud environment (e.g., Google Colab or SageMaker Studio Lab), remember to update file paths if necessary. 57 | 58 | -------------------------------------------------------------------------------- /sample_reviews.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcdata42/LLMs_transfer_learning/dc261e1e74dae530dd2d15b12e91fa0a89f2e7dc/sample_reviews.parquet -------------------------------------------------------------------------------- /subject_transfer_learning_v1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcdata42/LLMs_transfer_learning/dc261e1e74dae530dd2d15b12e91fa0a89f2e7dc/subject_transfer_learning_v1.pdf -------------------------------------------------------------------------------- /transfer_learning_AI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b655f09b-7c91-4b6d-ac13-7203a360903e", 6 | "metadata": { 7 | "id": "b655f09b-7c91-4b6d-ac13-7203a360903e" 8 | }, 9 | "source": [ 10 | "# training a pre trained model with data for sentiment analysis" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "81b2f6a9-de86-45d8-8a4b-1365d830bd39", 16 | "metadata": { 17 | "id": "81b2f6a9-de86-45d8-8a4b-1365d830bd39" 18 | }, 19 | "source": [ 20 | "* Glue -- sst2 dataset\n", 21 | "* train and tests dataset\n", 22 | "* ver interesting table: sentence - tokens - ids - n_tokens - essential_tokens\n", 23 | "* predicting labels from other datasets different fron sst2" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "e464377d-8975-4cb3-97c1-d9569eae5b26", 29 | "metadata": { 30 | "id": "e464377d-8975-4cb3-97c1-d9569eae5b26" 31 | }, 32 | "source": [ 33 | "# libraries" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "id": "ab3afc5b-e646-40ff-b30c-54a0c39aefa3", 40 | "metadata": { 41 | "colab": { 42 | "base_uri": "https://localhost:8080/" 43 | }, 44 | "executionInfo": { 45 | "elapsed": 30520, 46 | "status": "ok", 47 | "timestamp": 1725026930710, 48 | "user": { 49 | "displayName": "jcdata", 50 | "userId": "10219721298441697414" 51 | }, 52 | "user_tz": -120 53 | }, 54 | "id": "ab3afc5b-e646-40ff-b30c-54a0c39aefa3", 55 | "outputId": "925e325b-bf13-41af-cd90-1e548d256170", 56 | "tags": [] 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "Note: you may need to restart the kernel to use updated packages.\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "pip install datasets transformers evaluate torch scikit-learn accelerate -U" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "b05cb47a-d04a-4c50-9847-b2710da3eb8e", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 1, 82 | "id": "f356f062-ed4b-4260-ae96-6967f82ab1b5", 83 | "metadata": { 84 | "id": "f356f062-ed4b-4260-ae96-6967f82ab1b5", 85 | "outputId": "2c853a43-b22c-402a-a629-e11e31dd6ef9", 86 | "tags": [] 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "from datasets import load_dataset\n", 91 | "\n", 92 | "from transformers import AutoTokenizer\n", 93 | 
"from transformers import AutoModelForSequenceClassification\n", 94 | "from transformers import Trainer, TrainingArguments\n", 95 | "\n", 96 | "import evaluate\n", 97 | "import torch\n", 98 | "import pandas as pd\n", 99 | "from sklearn.metrics import accuracy_score\n", 100 | "from tqdm import tqdm" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "42bda474-c51f-418e-aa20-0957f2d8be71", 106 | "metadata": { 107 | "id": "42bda474-c51f-418e-aa20-0957f2d8be71", 108 | "tags": [] 109 | }, 110 | "source": [ 111 | "# importing the model - tokenizer - dataset" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "id": "9d04084c-b880-44f0-ad0c-1fcf19b353ad", 118 | "metadata": { 119 | "id": "9d04084c-b880-44f0-ad0c-1fcf19b353ad", 120 | "outputId": "0f3dcdf7-4c47-4fbf-be45-af46a932b2cc", 121 | "tags": [] 122 | }, 123 | "outputs": [ 124 | { 125 | "name": "stderr", 126 | "output_type": "stream", 127 | "text": [ 128 | "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", 129 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "# 'num_labels=2' specifies that this is a binary classification task\n", 135 | "model = AutoModelForSequenceClassification.from_pretrained(\n", 136 | " \"distilbert-base-uncased\", num_labels=2)\n", 137 | "# DistilBERT is a smaller, faster version of BERT. It has already been\n", 138 | "# pre-trained on general language tasks" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 4, 144 | "id": "2b0dd2fd-cfc3-4195-a010-a60bb6eeac23", 145 | "metadata": { 146 | "id": "2b0dd2fd-cfc3-4195-a010-a60bb6eeac23", 147 | "outputId": "899adae5-aa5a-4c97-db6a-53edadaf52e9", 148 | "tags": [] 149 | }, 150 | "outputs": [ 151 | { 152 | "name": "stderr", 153 | "output_type": "stream", 154 | "text": [ 155 | "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", 156 | " warnings.warn(\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "# Load the tokenizer for DistilBERT (or any other model to fine-tune).\n", 162 | "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 5, 168 | "id": "f8a2f99f-560b-42ca-9f56-3e3e353452a3", 169 | "metadata": { 170 | "tags": [] 171 | }, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "\u001b[0;31mSignature:\u001b[0m \n", 177 | "\u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", 178 | "\u001b[0;34m\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 179 | "\u001b[0;34m\u001b[0m \u001b[0mtext_pair\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 180 | "\u001b[0;34m\u001b[0m \u001b[0mtext_target\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 181 | "\u001b[0;34m\u001b[0m \u001b[0mtext_pair_target\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 182 | "\u001b[0;34m\u001b[0m \u001b[0madd_special_tokens\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 183 | "\u001b[0;34m\u001b[0m \u001b[0mpadding\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgeneric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPaddingStrategy\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 184 | "\u001b[0;34m\u001b[0m \u001b[0mtruncation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenization_utils_base\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTruncationStrategy\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 185 | "\u001b[0;34m\u001b[0m \u001b[0mmax_length\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 186 | "\u001b[0;34m\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 187 | "\u001b[0;34m\u001b[0m \u001b[0mis_split_into_words\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 188 | "\u001b[0;34m\u001b[0m \u001b[0mpad_to_multiple_of\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 189 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgeneric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorType\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 190 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_token_type_ids\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 191 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_attention_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 192 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_overflowing_tokens\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 193 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_special_tokens_mask\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 194 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_offsets_mapping\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 195 | "\u001b[0;34m\u001b[0m \u001b[0mreturn_length\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 196 | "\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 197 | "\u001b[0;34m\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 198 | "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenization_utils_base\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBatchEncoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 199 | "\u001b[0;31mType:\u001b[0m DistilBertTokenizerFast\n", 200 | "\u001b[0;31mString form:\u001b[0m \n", 201 | "DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_lengt <...> Token(\"[MASK]\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", 202 | " }\n", 203 | "\u001b[0;31mLength:\u001b[0m 30522\n", 204 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/default/lib/python3.9/site-packages/transformers/models/distilbert/tokenization_distilbert_fast.py\n", 205 | "\u001b[0;31mDocstring:\u001b[0m \n", 206 | "Construct a \"fast\" DistilBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.\n", 207 | "\n", 208 | "This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should\n", 209 | "refer to this superclass for more information regarding those methods.\n", 210 | "\n", 211 | "Args:\n", 212 | " vocab_file (`str`):\n", 213 | " File containing the vocabulary.\n", 214 | " do_lower_case (`bool`, *optional*, defaults to `True`):\n", 215 | " Whether or not to lowercase the input when tokenizing.\n", 216 | " unk_token (`str`, *optional*, defaults to `\"[UNK]\"`):\n", 217 | " The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this\n", 218 | " token instead.\n", 219 | " sep_token (`str`, *optional*, defaults to `\"[SEP]\"`):\n", 220 | " The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for\n", 221 | " sequence classification or for a text and a question for question answering. It is also used as the last\n", 222 | " token of a sequence built with special tokens.\n", 223 | " pad_token (`str`, *optional*, defaults to `\"[PAD]\"`):\n", 224 | " The token used for padding, for example when batching sequences of different lengths.\n", 225 | " cls_token (`str`, *optional*, defaults to `\"[CLS]\"`):\n", 226 | " The classifier token which is used when doing sequence classification (classification of the whole sequence\n", 227 | " instead of per-token classification). It is the first token of the sequence when built with special tokens.\n", 228 | " mask_token (`str`, *optional*, defaults to `\"[MASK]\"`):\n", 229 | " The token used for masking values. This is the token used when training this model with masked language\n", 230 | " modeling. This is the token which the model will try to predict.\n", 231 | " clean_text (`bool`, *optional*, defaults to `True`):\n", 232 | " Whether or not to clean the text before tokenization by removing any control characters and replacing all\n", 233 | " whitespaces by the classic one.\n", 234 | " tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):\n", 235 | " Whether or not to tokenize Chinese characters. 
This should likely be deactivated for Japanese (see [this\n", 236 | " issue](https://github.com/huggingface/transformers/issues/328)).\n", 237 | " strip_accents (`bool`, *optional*):\n", 238 | " Whether or not to strip all accents. If this option is not specified, then it will be determined by the\n", 239 | " value for `lowercase` (as in the original BERT).\n", 240 | " wordpieces_prefix (`str`, *optional*, defaults to `\"##\"`):\n", 241 | " The prefix for subwords.\n", 242 | "\u001b[0;31mCall docstring:\u001b[0m\n", 243 | "Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of\n", 244 | "sequences.\n", 245 | "\n", 246 | "Args:\n", 247 | " text (`str`, `List[str]`, `List[List[str]]`, *optional*):\n", 248 | " The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n", 249 | " (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n", 250 | " `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n", 251 | " text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):\n", 252 | " The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n", 253 | " (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n", 254 | " `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n", 255 | " text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):\n", 256 | " The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a\n", 257 | " list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),\n", 258 | " you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n", 259 | " text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):\n", 260 | " The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a\n", 261 | " list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),\n", 262 | " you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n", 263 | "\n", 264 | " add_special_tokens (`bool`, *optional*, defaults to `True`):\n", 265 | " Whether or not to add special tokens when encoding the sequences. This will use the underlying\n", 266 | " `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are\n", 267 | " automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens\n", 268 | " automatically.\n", 269 | " padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):\n", 270 | " Activates and controls padding. 
Accepts the following values:\n", 271 | "\n", 272 | " - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n", 273 | " sequence if provided).\n", 274 | " - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum\n", 275 | " acceptable input length for the model if that argument is not provided.\n", 276 | " - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different\n", 277 | " lengths).\n", 278 | " truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):\n", 279 | " Activates and controls truncation. Accepts the following values:\n", 280 | "\n", 281 | " - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or\n", 282 | " to the maximum acceptable input length for the model if that argument is not provided. This will\n", 283 | " truncate token by token, removing a token from the longest sequence in the pair if a pair of\n", 284 | " sequences (or a batch of pairs) is provided.\n", 285 | " - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the\n", 286 | " maximum acceptable input length for the model if that argument is not provided. This will only\n", 287 | " truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n", 288 | " - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the\n", 289 | " maximum acceptable input length for the model if that argument is not provided. This will only\n", 290 | " truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.\n", 291 | " - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths\n", 292 | " greater than the model maximum admissible input size).\n", 293 | " max_length (`int`, *optional*):\n", 294 | " Controls the maximum length to use by one of the truncation/padding parameters.\n", 295 | "\n", 296 | " If left unset or set to `None`, this will use the predefined model maximum length if a maximum length\n", 297 | " is required by one of the truncation/padding parameters. If the model has no specific maximum input\n", 298 | " length (like XLNet) truncation/padding to a maximum length will be deactivated.\n", 299 | " stride (`int`, *optional*, defaults to 0):\n", 300 | " If set to a number along with `max_length`, the overflowing tokens returned when\n", 301 | " `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence\n", 302 | " returned to provide some overlap between truncated and overflowing sequences. The value of this\n", 303 | " argument defines the number of overlapping tokens.\n", 304 | " is_split_into_words (`bool`, *optional*, defaults to `False`):\n", 305 | " Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the\n", 306 | " tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)\n", 307 | " which it will tokenize. This is useful for NER or token classification.\n", 308 | " pad_to_multiple_of (`int`, *optional*):\n", 309 | " If set will pad the sequence to a multiple of the provided value. 
Requires `padding` to be activated.\n", 310 | " This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability\n", 311 | " `>= 7.5` (Volta).\n", 312 | " return_tensors (`str` or [`~utils.TensorType`], *optional*):\n", 313 | " If set, will return tensors instead of list of python integers. Acceptable values are:\n", 314 | "\n", 315 | " - `'tf'`: Return TensorFlow `tf.constant` objects.\n", 316 | " - `'pt'`: Return PyTorch `torch.Tensor` objects.\n", 317 | " - `'np'`: Return Numpy `np.ndarray` objects.\n", 318 | "\n", 319 | " return_token_type_ids (`bool`, *optional*):\n", 320 | " Whether to return token type IDs. If left to the default, will return the token type IDs according to\n", 321 | " the specific tokenizer's default, defined by the `return_outputs` attribute.\n", 322 | "\n", 323 | " [What are token type IDs?](../glossary#token-type-ids)\n", 324 | " return_attention_mask (`bool`, *optional*):\n", 325 | " Whether to return the attention mask. If left to the default, will return the attention mask according\n", 326 | " to the specific tokenizer's default, defined by the `return_outputs` attribute.\n", 327 | "\n", 328 | " [What are attention masks?](../glossary#attention-mask)\n", 329 | " return_overflowing_tokens (`bool`, *optional*, defaults to `False`):\n", 330 | " Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch\n", 331 | " of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead\n", 332 | " of returning overflowing tokens.\n", 333 | " return_special_tokens_mask (`bool`, *optional*, defaults to `False`):\n", 334 | " Whether or not to return special tokens mask information.\n", 335 | " return_offsets_mapping (`bool`, *optional*, defaults to `False`):\n", 336 | " Whether or not to return `(char_start, char_end)` for each token.\n", 337 | "\n", 338 | " This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using\n", 339 | " Python's tokenizer, this method will raise `NotImplementedError`.\n", 340 | " return_length (`bool`, *optional*, defaults to `False`):\n", 341 | " Whether or not to return the lengths of the encoded inputs.\n", 342 | " verbose (`bool`, *optional*, defaults to `True`):\n", 343 | " Whether or not to print more information and warnings.\n", 344 | " **kwargs: passed to the `self.tokenize()` method\n", 345 | "\n", 346 | "Return:\n", 347 | " [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:\n", 348 | "\n", 349 | " - **input_ids** -- List of token ids to be fed to a model.\n", 350 | "\n", 351 | " [What are input IDs?](../glossary#input-ids)\n", 352 | "\n", 353 | " - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or\n", 354 | " if *\"token_type_ids\"* is in `self.model_input_names`).\n", 355 | "\n", 356 | " [What are token type IDs?](../glossary#token-type-ids)\n", 357 | "\n", 358 | " - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when\n", 359 | " `return_attention_mask=True` or if *\"attention_mask\"* is in `self.model_input_names`).\n", 360 | "\n", 361 | " [What are attention masks?](../glossary#attention-mask)\n", 362 | "\n", 363 | " - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and\n", 364 | " `return_overflowing_tokens=True`).\n", 365 | " - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is 
specified and\n", 366 | " `return_overflowing_tokens=True`).\n", 367 | " - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying\n", 368 | " regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).\n", 369 | " - **length** -- The length of the inputs (when `return_length=True`)\n" 370 | ] 371 | }, 372 | "metadata": {}, 373 | "output_type": "display_data" 374 | } 375 | ], 376 | "source": [ 377 | "tokenizer?" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 6, 383 | "id": "d0e4c076-449b-4991-a82b-4ee948bd127a", 384 | "metadata": { 385 | "id": "d0e4c076-449b-4991-a82b-4ee948bd127a", 386 | "tags": [] 387 | }, 388 | "outputs": [ 389 | { 390 | "data": { 391 | "application/vnd.jupyter.widget-view+json": { 392 | "model_id": "482d1ce18adc49d88847fcf50bf28f79", 393 | "version_major": 2, 394 | "version_minor": 0 395 | }, 396 | "text/plain": [ 397 | "README.md: 0%| | 0.00/35.3k [00:00 integer\n", 591 | "int(x, base=10) -> integer\n", 592 | "\n", 593 | "Convert a number or string to an integer, or return 0 if no arguments\n", 594 | "are given. If x is a number, return x.__int__(). For floating point\n", 595 | "numbers, this truncates towards zero.\n", 596 | "\n", 597 | "If x is not a number or if base is given, then x must be a string,\n", 598 | "bytes, or bytearray instance representing an integer literal in the\n", 599 | "given base. The literal can be preceded by '+' or '-' and be surrounded\n", 600 | "by whitespace. The base defaults to 10. Valid bases are 0 and 2-36.\n", 601 | "Base 0 means to interpret the base from the string as an integer literal.\n", 602 | ">>> int('0b100', base=0)\n", 603 | "4\n" 604 | ] 605 | }, 606 | "metadata": {}, 607 | "output_type": "display_data" 608 | } 609 | ], 610 | "source": [ 611 | "pad_token_id?" 
612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 14, 617 | "id": "77e0b133-bbec-4d74-b8dc-35259640359c", 618 | "metadata": { 619 | "id": "77e0b133-bbec-4d74-b8dc-35259640359c", 620 | "tags": [] 621 | }, 622 | "outputs": [], 623 | "source": [ 624 | "# Function to process the tokenized dataset and extract necessary fields\n", 625 | "def extract_token_info_with_essential_tokens(example):\n", 626 | " # Get the original sentence\n", 627 | " original_sentence = example['sentence'] if 'sentence' in example else None\n", 628 | "\n", 629 | " # Get the tokenized sentence by converting token IDs back to tokens\n", 630 | " tokenized_sentence = tokenizer.convert_ids_to_tokens(example['input_ids'])\n", 631 | "\n", 632 | " # Count the number of tokens excluding padding\n", 633 | " essential_tokens = sum(1 for token_id in example['input_ids'] if token_id != pad_token_id)\n", 634 | "\n", 635 | " # Return original sentence, tokenized sentence, token IDs, total tokens, and essential tokens\n", 636 | " return {\n", 637 | " 'sentence': original_sentence, # The original sentence\n", 638 | " 'tokenized_sentence': \" \".join(tokenized_sentence), # Tokenized sentence as a string\n", 639 | " 'token_ids': example['input_ids'], # List of token IDs\n", 640 | " 'num_tokens': len(example['input_ids']), # Total number of tokens (including padding)\n", 641 | " 'essential_tokens': essential_tokens # Number of tokens excluding padding\n", 642 | " }\n" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 15, 648 | "id": "b0e64893-90c8-469e-b528-4ac7284cb47b", 649 | "metadata": { 650 | "tags": [] 651 | }, 652 | "outputs": [], 653 | "source": [ 654 | "# Apply the extraction function to the already tokenized dataset\n", 655 | "processed_test = tokenized_test.map(extract_token_info_with_essential_tokens, batched=False)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 16, 661 | "id": "e065b506-07d2-4551-a34e-fe110a393239", 662 | "metadata": { 663 | "tags": [] 664 | }, 665 | "outputs": [ 666 | { 667 | "data": { 668 | "text/plain": [ 669 | "\u001b[0;31mType:\u001b[0m Dataset\n", 670 | "\u001b[0;31mString form:\u001b[0m\n", 671 | "Dataset({\n", 672 | " features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask', 'tokenized_sentence', 'token_ids', 'num_tokens', 'essential_tokens'],\n", 673 | " num_rows: 150\n", 674 | "})\n", 675 | "\u001b[0;31mLength:\u001b[0m 150\n", 676 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/default/lib/python3.9/site-packages/datasets/arrow_dataset.py\n", 677 | "\u001b[0;31mDocstring:\u001b[0m A Dataset backed by an Arrow table.\n" 678 | ] 679 | }, 680 | "metadata": {}, 681 | "output_type": "display_data" 682 | } 683 | ], 684 | "source": [ 685 | "processed_test?" 
686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 17, 691 | "id": "6ce30866-b861-4280-8cf5-78533eafec46", 692 | "metadata": { 693 | "tags": [] 694 | }, 695 | "outputs": [], 696 | "source": [ 697 | "# Create a DataFrame from the processed dataset\n", 698 | "test_data = [\n", 699 | " {\n", 700 | " \"sentence\": ex['sentence'],\n", 701 | " \"tokenized_sentence\": ex['tokenized_sentence'], # Tokenized sentence as a string\n", 702 | " \"token_ids\": ex['token_ids'], #Tokens id\n", 703 | " \"essential_tokens\": ex['essential_tokens'], # Number of tokens without padding\n", 704 | " \"num_tokens\": ex['num_tokens'], # Total tokens (with padding)\n", 705 | " }\n", 706 | " for ex in processed_test\n", 707 | "]" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 18, 713 | "id": "55cf0bf9-2a97-401f-9bcd-debadffa5664", 714 | "metadata": { 715 | "tags": [] 716 | }, 717 | "outputs": [ 718 | { 719 | "data": { 720 | "text/plain": [ 721 | "\u001b[0;31mType:\u001b[0m list\n", 722 | "\u001b[0;31mString form:\u001b[0m [{'sentence': 'at least one scene is so disgusting that viewers may be hard pressed to retain the <...> 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'essential_tokens': 37, 'num_tokens': 54}]\n", 723 | "\u001b[0;31mLength:\u001b[0m 150\n", 724 | "\u001b[0;31mDocstring:\u001b[0m \n", 725 | "Built-in mutable sequence.\n", 726 | "\n", 727 | "If no argument is given, the constructor creates a new empty list.\n", 728 | "The argument must be an iterable if specified.\n" 729 | ] 730 | }, 731 | "metadata": {}, 732 | "output_type": "display_data" 733 | } 734 | ], 735 | "source": [ 736 | "test_data?" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": 19, 742 | "id": "ba48ac88-6fb6-45d6-99c8-326267cfe222", 743 | "metadata": { 744 | "tags": [] 745 | }, 746 | "outputs": [], 747 | "source": [ 748 | "# Create a pandas DataFrame\n", 749 | "df_test = pd.DataFrame(test_data)\n", 750 | "\n", 751 | "# Sort the DataFrame by the number of essential tokens in descending order\n", 752 | "df_test = df_test.sort_values(by=\"essential_tokens\", ascending=False)" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 20, 758 | "id": "cfcedba6-efda-48ee-ba28-1d62efc6a44c", 759 | "metadata": { 760 | "tags": [] 761 | }, 762 | "outputs": [ 763 | { 764 | "data": { 765 | "text/html": [ 766 | "
\n", 767 | "\n", 780 | "\n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | "
sentencetokenized_sentencetoken_idsessential_tokensnum_tokens
13for all its technical virtuosity , the film is...[CLS] for all its technical vi ##rt ##uo ##sit...[101, 2005, 2035, 2049, 4087, 6819, 5339, 1909...5454
34the special effects and many scenes of weightl...[CLS] the special effects and many scenes of w...[101, 1996, 2569, 3896, 1998, 2116, 5019, 1997...4754
47but the power of these ( subjects ) is obscure...[CLS] but the power of these ( subjects ) is o...[101, 2021, 1996, 2373, 1997, 2122, 1006, 5739...4754
113it does nothing new with the old story , excep...[CLS] it does nothing new with the old story ,...[101, 2009, 2515, 2498, 2047, 2007, 1996, 2214...4754
54the tale of tok ( andy lau ) , a sleek sociopa...[CLS] the tale of to ##k ( andy lau ) , a slee...[101, 1996, 6925, 1997, 2000, 2243, 1006, 5557...4654
..................
97big fat waste of time .[CLS] big fat waste of time . [SEP] [PAD] [PAD...[101, 2502, 6638, 5949, 1997, 2051, 1012, 102,...854
66it treats women like idiots .[CLS] it treats women like idiots . [SEP] [PAD...[101, 2009, 18452, 2308, 2066, 28781, 1012, 10...854
109a deep and meaningful film .[CLS] a deep and meaningful film . [SEP] [PAD]...[101, 1037, 2784, 1998, 15902, 2143, 1012, 102...854
88a wildly inconsistent emotional experience .[CLS] a wildly inconsistent emotional experien...[101, 1037, 13544, 20316, 6832, 3325, 1012, 10...854
64one from the heart .[CLS] one from the heart . [SEP] [PAD] [PAD] [...[101, 2028, 2013, 1996, 2540, 1012, 102, 0, 0,...754
\n", 882 | "

150 rows × 5 columns

\n", 883 | "
" 884 | ], 885 | "text/plain": [ 886 | " sentence \\\n", 887 | "13 for all its technical virtuosity , the film is... \n", 888 | "34 the special effects and many scenes of weightl... \n", 889 | "47 but the power of these ( subjects ) is obscure... \n", 890 | "113 it does nothing new with the old story , excep... \n", 891 | "54 the tale of tok ( andy lau ) , a sleek sociopa... \n", 892 | ".. ... \n", 893 | "97 big fat waste of time . \n", 894 | "66 it treats women like idiots . \n", 895 | "109 a deep and meaningful film . \n", 896 | "88 a wildly inconsistent emotional experience . \n", 897 | "64 one from the heart . \n", 898 | "\n", 899 | " tokenized_sentence \\\n", 900 | "13 [CLS] for all its technical vi ##rt ##uo ##sit... \n", 901 | "34 [CLS] the special effects and many scenes of w... \n", 902 | "47 [CLS] but the power of these ( subjects ) is o... \n", 903 | "113 [CLS] it does nothing new with the old story ,... \n", 904 | "54 [CLS] the tale of to ##k ( andy lau ) , a slee... \n", 905 | ".. ... \n", 906 | "97 [CLS] big fat waste of time . [SEP] [PAD] [PAD... \n", 907 | "66 [CLS] it treats women like idiots . [SEP] [PAD... \n", 908 | "109 [CLS] a deep and meaningful film . [SEP] [PAD]... \n", 909 | "88 [CLS] a wildly inconsistent emotional experien... \n", 910 | "64 [CLS] one from the heart . [SEP] [PAD] [PAD] [... \n", 911 | "\n", 912 | " token_ids essential_tokens \\\n", 913 | "13 [101, 2005, 2035, 2049, 4087, 6819, 5339, 1909... 54 \n", 914 | "34 [101, 1996, 2569, 3896, 1998, 2116, 5019, 1997... 47 \n", 915 | "47 [101, 2021, 1996, 2373, 1997, 2122, 1006, 5739... 47 \n", 916 | "113 [101, 2009, 2515, 2498, 2047, 2007, 1996, 2214... 47 \n", 917 | "54 [101, 1996, 6925, 1997, 2000, 2243, 1006, 5557... 46 \n", 918 | ".. ... ... \n", 919 | "97 [101, 2502, 6638, 5949, 1997, 2051, 1012, 102,... 8 \n", 920 | "66 [101, 2009, 18452, 2308, 2066, 28781, 1012, 10... 8 \n", 921 | "109 [101, 1037, 2784, 1998, 15902, 2143, 1012, 102... 8 \n", 922 | "88 [101, 1037, 13544, 20316, 6832, 3325, 1012, 10... 8 \n", 923 | "64 [101, 2028, 2013, 1996, 2540, 1012, 102, 0, 0,... 7 \n", 924 | "\n", 925 | " num_tokens \n", 926 | "13 54 \n", 927 | "34 54 \n", 928 | "47 54 \n", 929 | "113 54 \n", 930 | "54 54 \n", 931 | ".. ... 
\n", 932 | "97 54 \n", 933 | "66 54 \n", 934 | "109 54 \n", 935 | "88 54 \n", 936 | "64 54 \n", 937 | "\n", 938 | "[150 rows x 5 columns]" 939 | ] 940 | }, 941 | "execution_count": 20, 942 | "metadata": {}, 943 | "output_type": "execute_result" 944 | } 945 | ], 946 | "source": [ 947 | "df_test" 948 | ] 949 | }, 950 | { 951 | "cell_type": "markdown", 952 | "id": "3c6223ae-f687-4a28-8d7b-896143d5bae8", 953 | "metadata": { 954 | "id": "3c6223ae-f687-4a28-8d7b-896143d5bae8" 955 | }, 956 | "source": [ 957 | "## labelling (predicting) with the pretrained model" 958 | ] 959 | }, 960 | { 961 | "cell_type": "markdown", 962 | "id": "80b37cce-4f7e-46ad-9d4c-301b43e00d87", 963 | "metadata": { 964 | "id": "80b37cce-4f7e-46ad-9d4c-301b43e00d87" 965 | }, 966 | "source": [ 967 | "### labelling a single sentence" 968 | ] 969 | }, 970 | { 971 | "cell_type": "code", 972 | "execution_count": 21, 973 | "id": "8819d80b-acab-49af-93de-03e60ffab663", 974 | "metadata": { 975 | "id": "8819d80b-acab-49af-93de-03e60ffab663", 976 | "tags": [] 977 | }, 978 | "outputs": [], 979 | "source": [ 980 | "def predict_label_sentence(sentence):\n", 981 | " \"\"\"\n", 982 | " Takes a sentence and returns the original sentence, the tokenized sentence,\n", 983 | " the token IDs, the softmax probabilities, and the predicted label.\n", 984 | "\n", 985 | " Args:\n", 986 | " - sentence (str): The input sentence for sentiment analysis.\n", 987 | "\n", 988 | " Returns:\n", 989 | " - dict: A dictionary containing:\n", 990 | " - 'original_sentence': The original sentence.\n", 991 | " - 'tokenized_sentence': The tokenized version of the sentence.\n", 992 | " - 'input_ids': The token IDs (numerical representation).\n", 993 | " - 'softmax_probs': The softmax probabilities for each class.\n", 994 | " - 'predicted_label': The predicted class label.\n", 995 | " - 'sentiment': Sentiment as 'positive' or 'negative'.\n", 996 | " \"\"\"\n", 997 | "\n", 998 | " # Step 1: Tokenize the sentence\n", 999 | " tokens = tokenizer(sentence, return_tensors=\"pt\", padding=\"longest\", truncation=True, max_length=512)\n", 1000 | "\n", 1001 | " # Step 2: Get the token IDs and tokenized sentence\n", 1002 | " input_ids = tokens['input_ids']\n", 1003 | " tokenized_sentence = tokenizer.convert_ids_to_tokens(input_ids[0])\n", 1004 | "\n", 1005 | " # Step 3: Pass the tokenized input to the model to get logits\n", 1006 | " with torch.no_grad(): # Disable gradient computation for evaluation\n", 1007 | " output = model(**tokens)\n", 1008 | " logits = output.logits\n", 1009 | "\n", 1010 | " # Step 4: Apply softmax using torch to get probabilities\n", 1011 | " softmax_probs = torch.softmax(logits, dim=-1)\n", 1012 | "\n", 1013 | " # Step 5: Get the predicted label (argmax of softmax output)\n", 1014 | " predicted_label = torch.argmax(softmax_probs, dim=-1).item()\n", 1015 | "\n", 1016 | " # Step 6: Determine sentiment based on the predicted label\n", 1017 | " sentiment_label = \"positive\" if predicted_label == 1 else \"negative\"\n", 1018 | "\n", 1019 | " # Step 7: Prepare result dictionary\n", 1020 | " result = {\n", 1021 | " \"original_sentence\": sentence,\n", 1022 | " \"tokenized_sentence\": tokenized_sentence,\n", 1023 | " \"input_ids\": input_ids[0].tolist(),\n", 1024 | " \"softmax_probs\": softmax_probs[0].tolist(), # Convert tensor to list\n", 1025 | " \"predicted_label\": predicted_label,\n", 1026 | " \"sentiment\": sentiment_label\n", 1027 | " }\n", 1028 | "\n", 1029 | " return result\n" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "code", 1034 | 
"execution_count": 22, 1035 | "id": "183cf205-35cf-480c-8602-75a09a1fa490", 1036 | "metadata": { 1037 | "id": "183cf205-35cf-480c-8602-75a09a1fa490", 1038 | "tags": [] 1039 | }, 1040 | "outputs": [], 1041 | "source": [ 1042 | "sentence_label = predict_label_sentence(\"I am happy\")" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": 23, 1048 | "id": "a54f7f32-5935-47f3-b660-da05144a8c43", 1049 | "metadata": { 1050 | "id": "a54f7f32-5935-47f3-b660-da05144a8c43", 1051 | "outputId": "78c08d01-c712-4f32-e83e-6188e1d3a8c5", 1052 | "tags": [] 1053 | }, 1054 | "outputs": [ 1055 | { 1056 | "data": { 1057 | "text/plain": [ 1058 | "{'original_sentence': 'I am happy',\n", 1059 | " 'tokenized_sentence': ['[CLS]', 'i', 'am', 'happy', '[SEP]'],\n", 1060 | " 'input_ids': [101, 1045, 2572, 3407, 102],\n", 1061 | " 'softmax_probs': [0.46126559376716614, 0.5387344360351562],\n", 1062 | " 'predicted_label': 1,\n", 1063 | " 'sentiment': 'positive'}" 1064 | ] 1065 | }, 1066 | "execution_count": 23, 1067 | "metadata": {}, 1068 | "output_type": "execute_result" 1069 | } 1070 | ], 1071 | "source": [ 1072 | "sentence_label" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "markdown", 1077 | "id": "37a94232-7568-4e2e-bea5-ef50bb27f1d9", 1078 | "metadata": { 1079 | "id": "37a94232-7568-4e2e-bea5-ef50bb27f1d9" 1080 | }, 1081 | "source": [ 1082 | "### labelling a whole dataset" 1083 | ] 1084 | }, 1085 | { 1086 | "cell_type": "code", 1087 | "execution_count": 24, 1088 | "id": "e3d3ced7-6b44-4497-b6c7-ac39b9cad1dc", 1089 | "metadata": { 1090 | "tags": [] 1091 | }, 1092 | "outputs": [], 1093 | "source": [ 1094 | "def predict_label_dataset(new_data, text_columns=[\"text\", \"sentence\", \"content\", \"title\"]):\n", 1095 | " \"\"\"\n", 1096 | " Evaluates the model on a subset of data and returns a DataFrame with\n", 1097 | " all the sentences, true labels, predicted labels, number of tokens,\n", 1098 | " softmax probabilities, tokenized sentences, and input IDs, along with accuracy.\n", 1099 | "\n", 1100 | " Args:\n", 1101 | " - new_data: The subset of the dataset to evaluate (already selected).\n", 1102 | " - text_columns: A list of possible text columns to use (default: ['text', 'sentence', 'content', 'title']).\n", 1103 | "\n", 1104 | " Returns:\n", 1105 | " - df: DataFrame containing detailed prediction information.\n", 1106 | " - accuracy: Accuracy of the model on the dataset (None if labels are unavailable).\n", 1107 | " \"\"\"\n", 1108 | " \n", 1109 | " # Step 1: Find the appropriate text column\n", 1110 | " for col in text_columns:\n", 1111 | " if col in new_data.column_names:\n", 1112 | " text_column = col\n", 1113 | " break\n", 1114 | " else:\n", 1115 | " raise ValueError(f\"None of the specified text columns {text_columns} were found in the dataset.\")\n", 1116 | " \n", 1117 | " # Initialize results dictionary\n", 1118 | " results = {\n", 1119 | " \"Sentence\": [],\n", 1120 | " \"Tokenized Sentence\": [],\n", 1121 | " \"Input IDs\": [],\n", 1122 | " \"Number of Tokens\": [],\n", 1123 | " \"Softmax Probs\": [],\n", 1124 | " \"Predicted Label\": [],\n", 1125 | " \"True Label\": [],\n", 1126 | " \"Sentiment\": []\n", 1127 | " }\n", 1128 | " \n", 1129 | " # Iterate through each example in the dataset with a progress bar\n", 1130 | " for i in tqdm(range(len(new_data)), desc=\"Labelling Sentences\"):\n", 1131 | " sentence = new_data[i][text_column]\n", 1132 | " true_label = new_data[i].get('label', -1) # Use -1 if label is missing\n", 1133 | " \n", 1134 | " # Use predict_label_sentence 
function to get predictions and other details\n", 1135 | " prediction = predict_label_sentence(sentence)\n", 1136 | " \n", 1137 | " # Count non-padding tokens\n", 1138 | " input_ids = prediction[\"input_ids\"]\n", 1139 | " pad_token_id = tokenizer.pad_token_id\n", 1140 | " num_tokens = sum([1 for token_id in input_ids if token_id != pad_token_id])\n", 1141 | " \n", 1142 | " # Append the details to the results dictionary\n", 1143 | " results[\"Sentence\"].append(prediction[\"original_sentence\"])\n", 1144 | " results[\"Tokenized Sentence\"].append(prediction[\"tokenized_sentence\"])\n", 1145 | " results[\"Input IDs\"].append(prediction[\"input_ids\"])\n", 1146 | " results[\"Number of Tokens\"].append(num_tokens)\n", 1147 | " results[\"Softmax Probs\"].append(prediction[\"softmax_probs\"])\n", 1148 | " results[\"Predicted Label\"].append(prediction[\"predicted_label\"])\n", 1149 | " results[\"True Label\"].append(true_label)\n", 1150 | " results[\"Sentiment\"].append(prediction[\"sentiment\"])\n", 1151 | " \n", 1152 | " # Convert results to a DataFrame\n", 1153 | " df = pd.DataFrame(results)\n", 1154 | " \n", 1155 | " # Check if true labels are available (i.e., not all -1)\n", 1156 | " if df[\"True Label\"].isin([-1]).all():\n", 1157 | " accuracy = None\n", 1158 | " else:\n", 1159 | " # Calculate Accuracy\n", 1160 | " accuracy = accuracy_score(df[\"True Label\"], df[\"Predicted Label\"])\n", 1161 | " \n", 1162 | " # Return the DataFrame and accuracy\n", 1163 | " return df, accuracy\n" 1164 | ] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "execution_count": 25, 1169 | "id": "d818af2b-c43c-47d7-a362-d51bf8c79476", 1170 | "metadata": { 1171 | "tags": [] 1172 | }, 1173 | "outputs": [ 1174 | { 1175 | "name": "stderr", 1176 | "output_type": "stream", 1177 | "text": [ 1178 | "Labelling Sentences: 100%|██████████| 150/150 [00:09<00:00, 15.71it/s]\n" 1179 | ] 1180 | } 1181 | ], 1182 | "source": [ 1183 | "df_test_non_trained, accuracy_test_non_trained = predict_label_dataset(test_dataset)" 1184 | ] 1185 | }, 1186 | { 1187 | "cell_type": "code", 1188 | "execution_count": 26, 1189 | "id": "200e6c87-4287-46af-a100-4152cba122e5", 1190 | "metadata": { 1191 | "tags": [] 1192 | }, 1193 | "outputs": [ 1194 | { 1195 | "data": { 1196 | "text/plain": [ 1197 | "0.5066666666666667" 1198 | ] 1199 | }, 1200 | "execution_count": 26, 1201 | "metadata": {}, 1202 | "output_type": "execute_result" 1203 | } 1204 | ], 1205 | "source": [ 1206 | "accuracy_test_non_trained" 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "execution_count": 27, 1212 | "id": "26e85dc6-3765-47fa-902c-ef8d4c3175a6", 1213 | "metadata": { 1214 | "tags": [] 1215 | }, 1216 | "outputs": [ 1217 | { 1218 | "data": { 1219 | "text/html": [ 1220 | "
\n", 1221 | "\n", 1234 | "\n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | " \n", 1293 | " \n", 1294 | " \n", 1295 | " \n", 1296 | " \n", 1297 | " \n", 1298 | " \n", 1299 | " \n", 1300 | " \n", 1301 | " \n", 1302 | " \n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | " \n", 1315 | " \n", 1316 | " \n", 1317 | " \n", 1318 | " \n", 1319 | " \n", 1320 | " \n", 1321 | " \n", 1322 | " \n", 1323 | " \n", 1324 | " \n", 1325 | " \n", 1326 | " \n", 1327 | " \n", 1328 | " \n", 1329 | " \n", 1330 | " \n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | "
SentenceTokenized SentenceInput IDsNumber of TokensSoftmax ProbsPredicted LabelTrue LabelSentiment
0at least one scene is so disgusting that viewe...[[CLS], at, least, one, scene, is, so, disgust...[101, 2012, 2560, 2028, 3496, 2003, 2061, 1942...20[0.49043747782707214, 0.5095624923706055]10positive
1even the finest chef ca n't make a hotdog into...[[CLS], even, the, finest, chef, ca, n, ', t, ...[101, 2130, 1996, 10418, 10026, 6187, 1050, 10...44[0.4584480822086334, 0.541551947593689]10positive
2collateral damage finally delivers the goods f...[[CLS], collateral, damage, finally, delivers,...[101, 24172, 4053, 2633, 18058, 1996, 5350, 20...14[0.47445327043533325, 0.525546669960022]11positive
3exciting and direct , with ghost imagery that ...[[CLS], exciting, and, direct, ,, with, ghost,...[101, 10990, 1998, 3622, 1010, 2007, 5745, 134...20[0.443885862827301, 0.5561141967773438]11positive
4and when you 're talking about a slapstick com...[[CLS], and, when, you, ', re, talking, about,...[101, 1998, 2043, 2017, 1005, 2128, 3331, 2055...22[0.4730995297431946, 0.5269004702568054]10positive
...........................
145it 's a bit disappointing that it only manages...[[CLS], it, ', s, a, bit, disappointing, that,...[101, 2009, 1005, 1055, 1037, 2978, 15640, 200...20[0.4603593051433563, 0.5396407246589661]10positive
146a breezy romantic comedy that has the punch of...[[CLS], a, bree, ##zy, romantic, comedy, that,...[101, 1037, 21986, 9096, 6298, 4038, 2008, 203...24[0.46655434370040894, 0.5334456562995911]11positive
147the film tries too hard to be funny and tries ...[[CLS], the, film, tries, too, hard, to, be, f...[101, 1996, 2143, 5363, 2205, 2524, 2000, 2022...18[0.4715810716152191, 0.5284189581871033]10positive
148thanks to scott 's charismatic roger and eisen...[[CLS], thanks, to, scott, ', s, charismatic, ...[101, 4283, 2000, 3660, 1005, 1055, 23916, 507...35[0.4877658486366272, 0.512234091758728]11positive
149drops you into a dizzying , volatile , pressur...[[CLS], drops, you, into, a, dizzy, ##ing, ,, ...[101, 9010, 2017, 2046, 1037, 14849, 2075, 101...37[0.47137606143951416, 0.5286239385604858]11positive
\n", 1372 | "

150 rows × 8 columns

\n", 1373 | "
" 1374 | ], 1375 | "text/plain": [ 1376 | " Sentence \\\n", 1377 | "0 at least one scene is so disgusting that viewe... \n", 1378 | "1 even the finest chef ca n't make a hotdog into... \n", 1379 | "2 collateral damage finally delivers the goods f... \n", 1380 | "3 exciting and direct , with ghost imagery that ... \n", 1381 | "4 and when you 're talking about a slapstick com... \n", 1382 | ".. ... \n", 1383 | "145 it 's a bit disappointing that it only manages... \n", 1384 | "146 a breezy romantic comedy that has the punch of... \n", 1385 | "147 the film tries too hard to be funny and tries ... \n", 1386 | "148 thanks to scott 's charismatic roger and eisen... \n", 1387 | "149 drops you into a dizzying , volatile , pressur... \n", 1388 | "\n", 1389 | " Tokenized Sentence \\\n", 1390 | "0 [[CLS], at, least, one, scene, is, so, disgust... \n", 1391 | "1 [[CLS], even, the, finest, chef, ca, n, ', t, ... \n", 1392 | "2 [[CLS], collateral, damage, finally, delivers,... \n", 1393 | "3 [[CLS], exciting, and, direct, ,, with, ghost,... \n", 1394 | "4 [[CLS], and, when, you, ', re, talking, about,... \n", 1395 | ".. ... \n", 1396 | "145 [[CLS], it, ', s, a, bit, disappointing, that,... \n", 1397 | "146 [[CLS], a, bree, ##zy, romantic, comedy, that,... \n", 1398 | "147 [[CLS], the, film, tries, too, hard, to, be, f... \n", 1399 | "148 [[CLS], thanks, to, scott, ', s, charismatic, ... \n", 1400 | "149 [[CLS], drops, you, into, a, dizzy, ##ing, ,, ... \n", 1401 | "\n", 1402 | " Input IDs Number of Tokens \\\n", 1403 | "0 [101, 2012, 2560, 2028, 3496, 2003, 2061, 1942... 20 \n", 1404 | "1 [101, 2130, 1996, 10418, 10026, 6187, 1050, 10... 44 \n", 1405 | "2 [101, 24172, 4053, 2633, 18058, 1996, 5350, 20... 14 \n", 1406 | "3 [101, 10990, 1998, 3622, 1010, 2007, 5745, 134... 20 \n", 1407 | "4 [101, 1998, 2043, 2017, 1005, 2128, 3331, 2055... 22 \n", 1408 | ".. ... ... \n", 1409 | "145 [101, 2009, 1005, 1055, 1037, 2978, 15640, 200... 20 \n", 1410 | "146 [101, 1037, 21986, 9096, 6298, 4038, 2008, 203... 24 \n", 1411 | "147 [101, 1996, 2143, 5363, 2205, 2524, 2000, 2022... 18 \n", 1412 | "148 [101, 4283, 2000, 3660, 1005, 1055, 23916, 507... 35 \n", 1413 | "149 [101, 9010, 2017, 2046, 1037, 14849, 2075, 101... 37 \n", 1414 | "\n", 1415 | " Softmax Probs Predicted Label True Label \\\n", 1416 | "0 [0.49043747782707214, 0.5095624923706055] 1 0 \n", 1417 | "1 [0.4584480822086334, 0.541551947593689] 1 0 \n", 1418 | "2 [0.47445327043533325, 0.525546669960022] 1 1 \n", 1419 | "3 [0.443885862827301, 0.5561141967773438] 1 1 \n", 1420 | "4 [0.4730995297431946, 0.5269004702568054] 1 0 \n", 1421 | ".. ... ... ... \n", 1422 | "145 [0.4603593051433563, 0.5396407246589661] 1 0 \n", 1423 | "146 [0.46655434370040894, 0.5334456562995911] 1 1 \n", 1424 | "147 [0.4715810716152191, 0.5284189581871033] 1 0 \n", 1425 | "148 [0.4877658486366272, 0.512234091758728] 1 1 \n", 1426 | "149 [0.47137606143951416, 0.5286239385604858] 1 1 \n", 1427 | "\n", 1428 | " Sentiment \n", 1429 | "0 positive \n", 1430 | "1 positive \n", 1431 | "2 positive \n", 1432 | "3 positive \n", 1433 | "4 positive \n", 1434 | ".. ... 
\n", 1435 | "145 positive \n", 1436 | "146 positive \n", 1437 | "147 positive \n", 1438 | "148 positive \n", 1439 | "149 positive \n", 1440 | "\n", 1441 | "[150 rows x 8 columns]" 1442 | ] 1443 | }, 1444 | "execution_count": 27, 1445 | "metadata": {}, 1446 | "output_type": "execute_result" 1447 | } 1448 | ], 1449 | "source": [ 1450 | "df_test_non_trained" 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "markdown", 1455 | "id": "4e50d387-fa63-4260-af7e-3a705f2546ec", 1456 | "metadata": { 1457 | "id": "4e50d387-fa63-4260-af7e-3a705f2546ec" 1458 | }, 1459 | "source": [ 1460 | "## prepare datasets for pytorch" 1461 | ] 1462 | }, 1463 | { 1464 | "cell_type": "code", 1465 | "execution_count": 28, 1466 | "id": "5214926c-9aec-4347-8eba-ab678498f8df", 1467 | "metadata": { 1468 | "id": "5214926c-9aec-4347-8eba-ab678498f8df", 1469 | "outputId": "1a7951ae-ee9f-438a-ba02-91c2d937ebd4", 1470 | "tags": [] 1471 | }, 1472 | "outputs": [ 1473 | { 1474 | "name": "stdout", 1475 | "output_type": "stream", 1476 | "text": [ 1477 | "['label', 'input_ids', 'attention_mask']\n", 1478 | "['label', 'input_ids', 'attention_mask']\n" 1479 | ] 1480 | } 1481 | ], 1482 | "source": [ 1483 | "# Remove the 'sentence' and 'idx' columns from the tokenized_train dataset\n", 1484 | "tokenized_train = tokenized_train.remove_columns([\"sentence\", \"idx\"])\n", 1485 | "\n", 1486 | "# Remove the 'sentence' and 'idx' columns from the tokenized_test dataset\n", 1487 | "tokenized_test = tokenized_test.remove_columns([\"sentence\", \"idx\"])\n", 1488 | "\n", 1489 | "# Check the columns after removal (optional, for confirmation)\n", 1490 | "print(tokenized_train.column_names)\n", 1491 | "print(tokenized_test.column_names)\n" 1492 | ] 1493 | }, 1494 | { 1495 | "cell_type": "code", 1496 | "execution_count": 29, 1497 | "id": "448991c0-2d04-48c9-a934-1f6a0e82b16f", 1498 | "metadata": { 1499 | "tags": [] 1500 | }, 1501 | "outputs": [ 1502 | { 1503 | "data": { 1504 | "text/plain": [ 1505 | "\u001b[0;31mType:\u001b[0m Dataset\n", 1506 | "\u001b[0;31mString form:\u001b[0m\n", 1507 | "Dataset({\n", 1508 | " features: ['label', 'input_ids', 'attention_mask'],\n", 1509 | " num_rows: 600\n", 1510 | "})\n", 1511 | "\u001b[0;31mLength:\u001b[0m 600\n", 1512 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/default/lib/python3.9/site-packages/datasets/arrow_dataset.py\n", 1513 | "\u001b[0;31mDocstring:\u001b[0m A Dataset backed by an Arrow table.\n" 1514 | ] 1515 | }, 1516 | "metadata": {}, 1517 | "output_type": "display_data" 1518 | } 1519 | ], 1520 | "source": [ 1521 | "tokenized_train?" 
1522 | ] 1523 | }, 1524 | { 1525 | "cell_type": "code", 1526 | "execution_count": 30, 1527 | "id": "79a09813-76bb-40ab-8288-f419a9e73fd8", 1528 | "metadata": { 1529 | "tags": [] 1530 | }, 1531 | "outputs": [ 1532 | { 1533 | "data": { 1534 | "text/plain": [ 1535 | "(600, 3)" 1536 | ] 1537 | }, 1538 | "execution_count": 30, 1539 | "metadata": {}, 1540 | "output_type": "execute_result" 1541 | } 1542 | ], 1543 | "source": [ 1544 | "tokenized_train.shape" 1545 | ] 1546 | }, 1547 | { 1548 | "cell_type": "code", 1549 | "execution_count": 31, 1550 | "id": "25878e34-238b-4347-abce-6903d84887d9", 1551 | "metadata": { 1552 | "id": "25878e34-238b-4347-abce-6903d84887d9", 1553 | "tags": [] 1554 | }, 1555 | "outputs": [], 1556 | "source": [ 1557 | "# Hugging Face's Trainer API expects the data in PyTorch format.\n", 1558 | "tokenized_train.set_format(\"torch\")\n", 1559 | "tokenized_test.set_format(\"torch\")" 1560 | ] 1561 | }, 1562 | { 1563 | "cell_type": "code", 1564 | "execution_count": 32, 1565 | "id": "579f7081-f982-41b6-ac63-7db027c5b784", 1566 | "metadata": { 1567 | "id": "579f7081-f982-41b6-ac63-7db027c5b784", 1568 | "outputId": "6bd0b892-f49f-492e-aeaf-d8385c27ebac", 1569 | "tags": [] 1570 | }, 1571 | "outputs": [ 1572 | { 1573 | "data": { 1574 | "text/plain": [ 1575 | "\u001b[0;31mType:\u001b[0m Dataset\n", 1576 | "\u001b[0;31mString form:\u001b[0m\n", 1577 | "Dataset({\n", 1578 | " features: ['label', 'input_ids', 'attention_mask'],\n", 1579 | " num_rows: 600\n", 1580 | "})\n", 1581 | "\u001b[0;31mLength:\u001b[0m 600\n", 1582 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/default/lib/python3.9/site-packages/datasets/arrow_dataset.py\n", 1583 | "\u001b[0;31mDocstring:\u001b[0m A Dataset backed by an Arrow table.\n" 1584 | ] 1585 | }, 1586 | "metadata": {}, 1587 | "output_type": "display_data" 1588 | } 1589 | ], 1590 | "source": [ 1591 | "tokenized_train?" 
1592 | ] 1593 | }, 1594 | { 1595 | "cell_type": "markdown", 1596 | "id": "733e8ffa-902a-431b-a35e-25c96aad5923", 1597 | "metadata": { 1598 | "id": "733e8ffa-902a-431b-a35e-25c96aad5923" 1599 | }, 1600 | "source": [ 1601 | "# Finetuning the model" 1602 | ] 1603 | }, 1604 | { 1605 | "cell_type": "code", 1606 | "execution_count": 33, 1607 | "id": "8f6a8ad4-fe2d-4ced-809b-9baa46f5542c", 1608 | "metadata": { 1609 | "id": "8f6a8ad4-fe2d-4ced-809b-9baa46f5542c", 1610 | "tags": [] 1611 | }, 1612 | "outputs": [ 1613 | { 1614 | "data": { 1615 | "text/plain": [ 1616 | "DistilBertForSequenceClassification(\n", 1617 | " (distilbert): DistilBertModel(\n", 1618 | " (embeddings): Embeddings(\n", 1619 | " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", 1620 | " (position_embeddings): Embedding(512, 768)\n", 1621 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 1622 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1623 | " )\n", 1624 | " (transformer): Transformer(\n", 1625 | " (layer): ModuleList(\n", 1626 | " (0-5): 6 x TransformerBlock(\n", 1627 | " (attention): MultiHeadSelfAttention(\n", 1628 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1629 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 1630 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 1631 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 1632 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 1633 | " )\n", 1634 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 1635 | " (ffn): FFN(\n", 1636 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1637 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 1638 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 1639 | " (activation): GELUActivation()\n", 1640 | " )\n", 1641 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 1642 | " )\n", 1643 | " )\n", 1644 | " )\n", 1645 | " )\n", 1646 | " (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n", 1647 | " (classifier): Linear(in_features=768, out_features=2, bias=True)\n", 1648 | " (dropout): Dropout(p=0.2, inplace=False)\n", 1649 | ")" 1650 | ] 1651 | }, 1652 | "execution_count": 33, 1653 | "metadata": {}, 1654 | "output_type": "execute_result" 1655 | } 1656 | ], 1657 | "source": [ 1658 | "model" 1659 | ] 1660 | }, 1661 | { 1662 | "cell_type": "code", 1663 | "execution_count": 34, 1664 | "id": "8cd56f01-cf4a-4716-a992-051aa52e93ee", 1665 | "metadata": { 1666 | "id": "8cd56f01-cf4a-4716-a992-051aa52e93ee", 1667 | "tags": [] 1668 | }, 1669 | "outputs": [ 1670 | { 1671 | "data": { 1672 | "text/plain": [ 1673 | "\u001b[0;31mSignature:\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1674 | "\u001b[0;31mType:\u001b[0m DistilBertForSequenceClassification\n", 1675 | "\u001b[0;31mString form:\u001b[0m \n", 1676 | "DistilBertForSequenceClassification(\n", 1677 | " (distilbert): DistilBertModel(\n", 1678 | " (embeddings): Embedding <...> : Linear(in_features=768, out_features=2, bias=True)\n", 1679 | " (dropout): Dropout(p=0.2, inplace=False)\n", 1680 | " )\n", 1681 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/default/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py\n", 1682 | 
"\u001b[0;31mDocstring:\u001b[0m \n", 1683 | "DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the\n", 1684 | "pooled output) e.g. for GLUE tasks.\n", 1685 | "\n", 1686 | "\n", 1687 | "This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n", 1688 | "library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n", 1689 | "etc.)\n", 1690 | "\n", 1691 | "This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n", 1692 | "Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n", 1693 | "and behavior.\n", 1694 | "\n", 1695 | "Parameters:\n", 1696 | " config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.\n", 1697 | " Initializing with a config file does not load the weights associated with the model, only the\n", 1698 | " configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n", 1699 | "\u001b[0;31mInit docstring:\u001b[0m Initialize internal Module state, shared by both nn.Module and ScriptModule.\n" 1700 | ] 1701 | }, 1702 | "metadata": {}, 1703 | "output_type": "display_data" 1704 | } 1705 | ], 1706 | "source": [ 1707 | "model?" 1708 | ] 1709 | }, 1710 | { 1711 | "cell_type": "code", 1712 | "execution_count": 35, 1713 | "id": "0deb2a00-44a0-41e1-be14-0a9add0b229a", 1714 | "metadata": { 1715 | "id": "0deb2a00-44a0-41e1-be14-0a9add0b229a", 1716 | "outputId": "a62483f2-7e63-455d-db6e-97466471b267", 1717 | "tags": [] 1718 | }, 1719 | "outputs": [ 1720 | { 1721 | "data": { 1722 | "text/plain": [ 1723 | "DistilBertConfig {\n", 1724 | " \"_name_or_path\": \"distilbert-base-uncased\",\n", 1725 | " \"activation\": \"gelu\",\n", 1726 | " \"architectures\": [\n", 1727 | " \"DistilBertForMaskedLM\"\n", 1728 | " ],\n", 1729 | " \"attention_dropout\": 0.1,\n", 1730 | " \"dim\": 768,\n", 1731 | " \"dropout\": 0.1,\n", 1732 | " \"hidden_dim\": 3072,\n", 1733 | " \"initializer_range\": 0.02,\n", 1734 | " \"max_position_embeddings\": 512,\n", 1735 | " \"model_type\": \"distilbert\",\n", 1736 | " \"n_heads\": 12,\n", 1737 | " \"n_layers\": 6,\n", 1738 | " \"pad_token_id\": 0,\n", 1739 | " \"qa_dropout\": 0.1,\n", 1740 | " \"seq_classif_dropout\": 0.2,\n", 1741 | " \"sinusoidal_pos_embds\": false,\n", 1742 | " \"tie_weights_\": true,\n", 1743 | " \"transformers_version\": \"4.44.2\",\n", 1744 | " \"vocab_size\": 30522\n", 1745 | "}" 1746 | ] 1747 | }, 1748 | "execution_count": 35, 1749 | "metadata": {}, 1750 | "output_type": "execute_result" 1751 | } 1752 | ], 1753 | "source": [ 1754 | "model.config" 1755 | ] 1756 | }, 1757 | { 1758 | "cell_type": "markdown", 1759 | "id": "32cc46a7-b9f3-496f-8c1d-7c6700e68edd", 1760 | "metadata": { 1761 | "id": "32cc46a7-b9f3-496f-8c1d-7c6700e68edd", 1762 | "jp-MarkdownHeadingCollapsed": true, 1763 | "tags": [] 1764 | }, 1765 | "source": [ 1766 | "## Detailed Explanation of `DistilBertConfig`" 1767 | ] 1768 | }, 1769 | { 1770 | "cell_type": "markdown", 1771 | "id": "49fd8331-9a0f-4d49-b453-eb0d2bf23c53", 1772 | "metadata": { 1773 | "id": "49fd8331-9a0f-4d49-b453-eb0d2bf23c53" 1774 | }, 1775 | "source": [ 1776 | "\n", 1777 | "This configuration describes the architecture and hyperparameters for the `DistilBERT` model. Below is an in-depth explanation of each field in the configuration:\n", 1778 | "\n", 1779 | "1. 
**`_name_or_path`: \"distilbert-base-uncased\"` \n", 1780 | " - This specifies the name or path of the pretrained model.\n", 1781 | " - `\"distilbert-base-uncased\"` is a smaller, lighter version of BERT that removes the case sensitivity of text (i.e., it treats \"Hello\" and \"hello\" the same way).\n", 1782 | " \n", 1783 | "2. **`activation`: \"gelu\"` \n", 1784 | " - This defines the activation function used in the model. \n", 1785 | " - `\"gelu\"` stands for **Gaussian Error Linear Unit**, which is a smoother version of ReLU and commonly used in transformer models.\n", 1786 | "\n", 1787 | "3. **`architectures`: [\"DistilBertForMaskedLM\"]** \n", 1788 | " - This indicates the type of architecture being used. \n", 1789 | " - `DistilBertForMaskedLM` is the architecture for **Masked Language Modeling**, where the model predicts missing or masked words in sentences. This is used for pretraining BERT-based models.\n", 1790 | "\n", 1791 | "4. **`attention_dropout`: 0.1** \n", 1792 | " - Dropout rate for the attention layers. \n", 1793 | " - Dropout is a regularization technique used to prevent overfitting by randomly setting a fraction of the attention scores to zero during training. In this case, the rate is 10% (0.1).\n", 1794 | "\n", 1795 | "5. **`dim`: 768** \n", 1796 | " - The dimensionality of the hidden representations in the model. \n", 1797 | " - Each input token is represented by a vector of size 768 in this version of DistilBERT.\n", 1798 | "\n", 1799 | "6. **`dropout`: 0.1** \n", 1800 | " - The general dropout rate applied throughout the model. \n", 1801 | " - This helps prevent overfitting by randomly dropping 10% of the neurons during training.\n", 1802 | "\n", 1803 | "7. **`hidden_dim`: 3072** \n", 1804 | " - This represents the size of the hidden layer in the feedforward neural network part of the transformer model. \n", 1805 | " - Specifically, this is the size of the intermediate layer in each transformer block, which typically has a larger dimension (3072) compared to the input/output dimension (768).\n", 1806 | "8. **`initializer_range`: 0.02** \n", 1807 | " - This defines the range used to initialize the weights in the model. \n", 1808 | " - The model’s weights are initialized using a uniform distribution in the range [-0.02, 0.02].\n", 1809 | "\n", 1810 | "9. **`max_position_embeddings`: 512** \n", 1811 | " - The maximum number of tokens or positions that the model can handle. \n", 1812 | " - For DistilBERT, this is capped at 512 tokens. Any input longer than 512 tokens will be truncated.\n", 1813 | "\n", 1814 | "10. **`model_type`: \"distilbert\"` \n", 1815 | " - This defines the type of model being used. \n", 1816 | " - `distilbert` is a distilled version of the BERT model, which retains 97% of BERT’s performance but is 60% faster and smaller in size.\n", 1817 | "\n", 1818 | "11. **`n_heads`: 12** \n", 1819 | " - The number of attention heads in the multi-head attention mechanism. \n", 1820 | " - In transformer architectures like BERT, the attention mechanism is split into multiple \"heads\" that focus on different parts of the input sequence. DistilBERT uses 12 attention heads.\n", 1821 | "\n", 1822 | "12. **`n_layers`: 6** \n", 1823 | " - The number of layers (transformer blocks) in the model. \n", 1824 | " - DistilBERT has 6 layers, as opposed to the 12 layers in BERT. This reduction is one reason why DistilBERT is faster and smaller.\n", 1825 | "\n", 1826 | "13. **`pad_token_id`: 0** \n", 1827 | " - The token ID used to represent padding in the input sequence. 
\n", 1828 | " - Padding tokens are added to make all sequences in a batch the same length, and `0` is the ID for the padding token.\n", 1829 | "\n", 1830 | "14. **`qa_dropout`: 0.1** \n", 1831 | " - Dropout rate applied during the Question Answering (QA) head of the model. \n", 1832 | " - This is used in tasks like SQuAD (Stanford Question Answering Dataset), where a 10% dropout rate is applied.\n", 1833 | "\n", 1834 | "15. **`seq_classif_dropout`: 0.2** \n", 1835 | " - Dropout rate used in the sequence classification head of the model. \n", 1836 | " - This is applicable for tasks like text classification, where a 20% dropout rate is applied to prevent overfitting.\n", 1837 | "\n", 1838 | "16. **`sinusoidal_pos_embds`: false** \n", 1839 | " - This flag indicates whether sinusoidal positional embeddings are used. \n", 1840 | " - DistilBERT uses learned positional embeddings (as in the original BERT) instead of sinusoidal ones.\n", 1841 | "\n", 1842 | "17. **`tie_weights_`: true** \n", 1843 | " - This indicates whether the weights of the embeddings and the output layer are tied. \n", 1844 | " - Weight tying reduces the number of parameters in the model and ensures that the input and output embeddings are similar.\n", 1845 | "\n", 1846 | "18. **`transformers_version`: \"4.44.0\"** \n", 1847 | " - This specifies the version of the Hugging Face Transformers library used to configure the model. \n", 1848 | " - In this case, the version is 4.44.0.\n", 1849 | "\n", 1850 | "19. **`vocab_size`: 30522** \n", 1851 | " - The size of the vocabulary used by the tokenizer and the model. \n", 1852 | " - DistilBERT inherits the BERT tokenizer, which uses a vocabulary of 30,522 tokens. This includes words, subwords, and special tokens (like [PAD], [CLS], etc.).\n" 1853 | ] 1854 | }, 1855 | { 1856 | "cell_type": "markdown", 1857 | "id": "ed9bbb9e-d86d-4799-86eb-a3bb2e95627d", 1858 | "metadata": { 1859 | "id": "ed9bbb9e-d86d-4799-86eb-a3bb2e95627d" 1860 | }, 1861 | "source": [ 1862 | "## training arguments" 1863 | ] 1864 | }, 1865 | { 1866 | "cell_type": "code", 1867 | "execution_count": 36, 1868 | "id": "f9d2679d-08aa-4fab-b845-e777b3b8a632", 1869 | "metadata": { 1870 | "id": "f9d2679d-08aa-4fab-b845-e777b3b8a632", 1871 | "tags": [] 1872 | }, 1873 | "outputs": [], 1874 | "source": [ 1875 | "# Load accuracy as the evaluation metric. This will be used to compute\n", 1876 | "# the accuracy of the model on the validation dataset during evaluation.\n", 1877 | "accuracy_metric = evaluate.load(\"accuracy\")\n", 1878 | "\n", 1879 | "# Define the function to compute metrics (accuracy in this case).\n", 1880 | "def compute_metrics(eval_pred):\n", 1881 | " predictions, labels = eval_pred\n", 1882 | " predictions = predictions.argmax(axis=1)\n", 1883 | " return accuracy_metric.compute(predictions=predictions, references=labels)\n" 1884 | ] 1885 | }, 1886 | { 1887 | "cell_type": "code", 1888 | "execution_count": 37, 1889 | "id": "6d256a14-d35d-4d11-b102-fa3fe76b6732", 1890 | "metadata": { 1891 | "id": "6d256a14-d35d-4d11-b102-fa3fe76b6732", 1892 | "tags": [] 1893 | }, 1894 | "outputs": [ 1895 | { 1896 | "name": "stderr", 1897 | "output_type": "stream", 1898 | "text": [ 1899 | "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. 
Use `eval_strategy` instead\n", 1900 | " warnings.warn(\n" 1901 | ] 1902 | } 1903 | ], 1904 | "source": [ 1905 | "# Define the training arguments, which control how the model will be trained.\n", 1906 | "# Each argument has a direct or indirect impact on both the computation time\n", 1907 | "# and the model's final performance.\n", 1908 | "training_args = TrainingArguments(\n", 1909 | " output_dir=\"./results\", # Directory where the model's checkpoints\n", 1910 | " # and outputs will be saved.\n", 1911 | " # (Doesn't directly affect training time)\n", 1912 | "\n", 1913 | " eval_steps=5, # Evaluate the model every eval_steps.\n", 1914 | " # Frequent evaluations can slow down training,\n", 1915 | " # but provide insights into model performance\n", 1916 | " # during training\n", 1917 | "\n", 1918 | " learning_rate=2e-5, # Learning rate controls the speed at which\n", 1919 | " # the model updates weights during training.\n", 1920 | " # A higher rate may lead to faster convergence,\n", 1921 | " # but could also risk overshooting optima,\n", 1922 | " # while a lower rate results in slower but\n", 1923 | " # potentially more stable training.\n", 1924 | "\n", 1925 | " per_device_train_batch_size=16, # Batch size for training on each device (GPU/CPU).\n", 1926 | " # A larger batch size speeds up training by\n", 1927 | " # processing more data per step, but uses more memory.\n", 1928 | " # If you run out of memory, reduce this value.\n", 1929 | " # Smaller batch sizes mean more updates per epoch.\n", 1930 | "\n", 1931 | " per_device_eval_batch_size=64, # Batch size for evaluation (validation/test set).\n", 1932 | " # Larger batch sizes can make evaluation faster\n", 1933 | " # but require more memory. Evaluation only happens\n", 1934 | " # during the validation phase, so it doesn't affect\n", 1935 | " # the training speed.\n", 1936 | "\n", 1937 | " num_train_epochs=4, # Number of training epochs. Each epoch is one full\n", 1938 | " # pass through the training dataset. More epochs\n", 1939 | " # increase training time but give the model more\n", 1940 | " # chances to learn. Fewer epochs result in faster\n", 1941 | " # training but risk underfitting the model.\n", 1942 | "\n", 1943 | " gradient_accumulation_steps=3, # Accumulate gradients over multiple steps before\n", 1944 | " # updating model weights. This simulates a larger\n", 1945 | " # batch size (e.g., with batch_size=16 and\n", 1946 | " # gradient_accumulation_steps=3, the model behaves\n", 1947 | " # like batch_size=48). This reduces memory usage\n", 1948 | " # but slows down training because updates happen\n", 1949 | " # less frequently.\n", 1950 | "\n", 1951 | " weight_decay=0.01, # Weight decay applies regularization during training\n", 1952 | " # to prevent overfitting by penalizing large weights.\n", 1953 | " # It improves generalization and helps ensure that\n", 1954 | " # the model performs well on unseen data.\n", 1955 | "\n", 1956 | " logging_dir=\"./logs\", # Directory for saving logs. Logging doesn't directly\n", 1957 | " # affect training speed, but frequent logging\n", 1958 | " # (e.g., at every step) can slow down the process.\n", 1959 | " # Set appropriate intervals for logging to balance\n", 1960 | " # information and speed.\n", 1961 | "\n", 1962 | " logging_steps=100, # Log metrics every 100 steps. Too frequent logging\n", 1963 | " # can slow training down, while infrequent logging\n", 1964 | " # might not provide enough insight into the model's\n", 1965 | " # performance during training. 
Adjust based on your\n", 1966 | " # need for monitoring.\n", 1967 | "\n", 1968 | " save_strategy=\"epoch\", # Save the model's checkpoints at the end of each\n", 1969 | " # epoch. This is generally efficient and safe\n", 1970 | " # unless you need more frequent saving (e.g., \"steps\").\n", 1971 | " # More frequent saving can slow down training,\n", 1972 | " # as saving checkpoints takes time.\n", 1973 | "\n", 1974 | " load_best_model_at_end=True, # Load the best model based on the evaluation\n", 1975 | " # metric after training finishes. While this\n", 1976 | " # doesn't affect training speed, it ensures the\n", 1977 | " # best-performing model (usually evaluated on\n", 1978 | " # validation accuracy or loss) is kept.\n", 1979 | "\n", 1980 | " metric_for_best_model=\"accuracy\",# Monitor accuracy to select the best model.\n", 1981 | " # This defines the metric used to determine\n", 1982 | " # which model is considered the best when\n", 1983 | " # `load_best_model_at_end` is set to True.\n", 1984 | "\n", 1985 | " evaluation_strategy=\"epoch\", # Run evaluation at the end of each epoch.\n", 1986 | " # This balances training and evaluation time,\n", 1987 | " # allowing for regular checks on validation\n", 1988 | " # performance without frequent interruptions.\n", 1989 | "\n", 1990 | " report_to=\"none\", # No need to report results to external platforms\n", 1991 | " # like TensorBoard or Weights & Biases. This keeps\n", 1992 | " # overhead minimal and speeds up the training process\n", 1993 | " # if you're not interested in reporting metrics\n", 1994 | " # elsewhere.\n", 1995 | "\n", 1996 | " seed=42 # Sets a fixed random seed to ensure reproducibility.\n", 1997 | " # Doesn't affect computation time but helps ensure\n", 1998 | " # the same results on re-runs.\n", 1999 | ")\n" 2000 | ] 2001 | }, 2002 | { 2003 | "cell_type": "code", 2004 | "execution_count": 38, 2005 | "id": "6125ca4c-4b61-4225-9d45-da303f5fc0ab", 2006 | "metadata": { 2007 | "id": "6125ca4c-4b61-4225-9d45-da303f5fc0ab", 2008 | "tags": [] 2009 | }, 2010 | "outputs": [], 2011 | "source": [ 2012 | "trainer = Trainer(\n", 2013 | " model=model, # The model to fine-tune\n", 2014 | " args=training_args, # Training arguments (from TrainingArguments)\n", 2015 | " train_dataset=tokenized_train, # The tokenized training dataset\n", 2016 | " eval_dataset=tokenized_test, # The tokenized evaluation dataset\n", 2017 | " compute_metrics=compute_metrics, # Function to compute evaluation metrics\n", 2018 | "\n", 2019 | " # Additional Arguments\n", 2020 | " tokenizer=tokenizer, # The tokenizer to use (optional, but useful if you\n", 2021 | " # want to use it for decoding or processing inputs).\n", 2022 | "\n", 2023 | " data_collator=None, # A function to prepare batches of data. This is\n", 2024 | " # typically left as `None`, and the default\n", 2025 | " # collator is used, but you can define your own\n", 2026 | " # data collator if necessary (e.g., for dynamic\n", 2027 | " # padding).\n", 2028 | "\n", 2029 | " optimizers=(None, None), # You can provide your own optimizer and scheduler\n", 2030 | " # (learning rate scheduler). If `None`, the default\n", 2031 | " # AdamW optimizer and linear scheduler are used.\n", 2032 | "\n", 2033 | " callbacks=None, # List of callbacks, such as `EarlyStoppingCallback`,\n", 2034 | " # to run during training. 
Callbacks can be used\n", 2035 | " # to perform additional actions during training.\n", 2036 | "\n", 2037 | " preprocess_logits_for_metrics=None, # If you want to pre-process logits before\n", 2038 | " # computing metrics, define a function here.\n", 2039 | ")\n" 2040 | ] 2041 | }, 2042 | { 2043 | "cell_type": "code", 2044 | "execution_count": 39, 2045 | "id": "22028822-7885-42d8-9474-e20a6e26ce4c", 2046 | "metadata": { 2047 | "id": "22028822-7885-42d8-9474-e20a6e26ce4c", 2048 | "tags": [] 2049 | }, 2050 | "outputs": [], 2051 | "source": [ 2052 | "#trainer?" 2053 | ] 2054 | }, 2055 | { 2056 | "cell_type": "code", 2057 | "execution_count": 40, 2058 | "id": "585c5804-3cf4-4d7d-990d-e2f0ce2eebd7", 2059 | "metadata": { 2060 | "id": "585c5804-3cf4-4d7d-990d-e2f0ce2eebd7", 2061 | "outputId": "d55f182b-66b7-451a-cf54-b2a38e2ebf90", 2062 | "tags": [] 2063 | }, 2064 | "outputs": [ 2065 | { 2066 | "data": { 2067 | "text/html": [ 2068 | "\n", 2069 | "
\n", 2070 | " \n", 2071 | " \n", 2072 | " [48/48 05:18, Epoch 3/4]\n", 2073 | "
\n", 2074 | " \n", 2075 | " \n", 2076 | " \n", 2077 | " \n", 2078 | " \n", 2079 | " \n", 2080 | " \n", 2081 | " \n", 2082 | " \n", 2083 | " \n", 2084 | " \n", 2085 | " \n", 2086 | " \n", 2087 | " \n", 2088 | " \n", 2089 | " \n", 2090 | " \n", 2091 | " \n", 2092 | " \n", 2093 | " \n", 2094 | " \n", 2095 | " \n", 2096 | " \n", 2097 | " \n", 2098 | " \n", 2099 | " \n", 2100 | " \n", 2101 | " \n", 2102 | " \n", 2103 | "
Epoch | Training Loss | Validation Loss | Accuracy
0     | No log        | 0.675232        | 0.506667
1     | No log        | 0.598589        | 0.680000
3     | No log        | 0.489973        | 0.873333

" 2104 | ], 2105 | "text/plain": [ 2106 | "" 2107 | ] 2108 | }, 2109 | "metadata": {}, 2110 | "output_type": "display_data" 2111 | }, 2112 | { 2113 | "data": { 2114 | "text/plain": [ 2115 | "TrainOutput(global_step=48, training_loss=0.5745113690694174, metrics={'train_runtime': 326.155, 'train_samples_per_second': 7.358, 'train_steps_per_second': 0.147, 'total_flos': 31264375885920.0, 'train_loss': 0.5745113690694174, 'epoch': 3.7894736842105265})" 2116 | ] 2117 | }, 2118 | "execution_count": 40, 2119 | "metadata": {}, 2120 | "output_type": "execute_result" 2121 | } 2122 | ], 2123 | "source": [ 2124 | "trainer.train()\n", 2125 | "#trainer.train()" 2126 | ] 2127 | }, 2128 | { 2129 | "cell_type": "code", 2130 | "execution_count": 41, 2131 | "id": "526bf3b2-a324-4a02-aa82-6726d3649ac5", 2132 | "metadata": { 2133 | "id": "526bf3b2-a324-4a02-aa82-6726d3649ac5", 2134 | "outputId": "48c2b14f-7a45-4424-e394-3d336bfd0dcb", 2135 | "tags": [] 2136 | }, 2137 | "outputs": [ 2138 | { 2139 | "data": { 2140 | "text/html": [ 2141 | "\n", 2142 | "

\n", 2143 | " \n", 2144 | " \n", 2145 | " [10/10 00:21]\n", 2146 | "
\n", 2147 | " " 2148 | ], 2149 | "text/plain": [ 2150 | "" 2151 | ] 2152 | }, 2153 | "metadata": {}, 2154 | "output_type": "display_data" 2155 | } 2156 | ], 2157 | "source": [ 2158 | "# Run evaluation on both training and validation datasets\n", 2159 | "train_results = trainer.evaluate(eval_dataset=tokenized_train) # Evaluate on training set\n", 2160 | "test_results = trainer.evaluate(eval_dataset=tokenized_test) # Evaluate on test set\n" 2161 | ] 2162 | }, 2163 | { 2164 | "cell_type": "code", 2165 | "execution_count": 42, 2166 | "id": "e38d4b0d-6f2a-4de8-a13f-4fa7ffa740f6", 2167 | "metadata": { 2168 | "id": "e38d4b0d-6f2a-4de8-a13f-4fa7ffa740f6", 2169 | "outputId": "cb495624-5487-4fb2-fb25-74fe9d52e598", 2170 | "tags": [] 2171 | }, 2172 | "outputs": [ 2173 | { 2174 | "data": { 2175 | "text/plain": [ 2176 | "{'eval_loss': 0.4245406687259674,\n", 2177 | " 'eval_accuracy': 0.8816666666666667,\n", 2178 | " 'eval_runtime': 19.1689,\n", 2179 | " 'eval_samples_per_second': 31.301,\n", 2180 | " 'eval_steps_per_second': 0.522,\n", 2181 | " 'epoch': 3.7894736842105265}" 2182 | ] 2183 | }, 2184 | "execution_count": 42, 2185 | "metadata": {}, 2186 | "output_type": "execute_result" 2187 | } 2188 | ], 2189 | "source": [ 2190 | "train_results" 2191 | ] 2192 | }, 2193 | { 2194 | "cell_type": "code", 2195 | "execution_count": 43, 2196 | "id": "72714d56-c6c2-45dc-ab45-673ea1f88064", 2197 | "metadata": { 2198 | "id": "72714d56-c6c2-45dc-ab45-673ea1f88064", 2199 | "outputId": "b3c14914-510b-4a60-8c54-0953bd9393ba", 2200 | "tags": [] 2201 | }, 2202 | "outputs": [ 2203 | { 2204 | "data": { 2205 | "text/plain": [ 2206 | "{'eval_loss': 0.48997294902801514,\n", 2207 | " 'eval_accuracy': 0.8733333333333333,\n", 2208 | " 'eval_runtime': 4.4629,\n", 2209 | " 'eval_samples_per_second': 33.611,\n", 2210 | " 'eval_steps_per_second': 0.672,\n", 2211 | " 'epoch': 3.7894736842105265}" 2212 | ] 2213 | }, 2214 | "execution_count": 43, 2215 | "metadata": {}, 2216 | "output_type": "execute_result" 2217 | } 2218 | ], 2219 | "source": [ 2220 | "test_results" 2221 | ] 2222 | }, 2223 | { 2224 | "cell_type": "markdown", 2225 | "id": "82b522bd-0077-499a-9993-668fbb1d5df7", 2226 | "metadata": { 2227 | "id": "82b522bd-0077-499a-9993-668fbb1d5df7" 2228 | }, 2229 | "source": [ 2230 | "## labelling with the trained model" 2231 | ] 2232 | }, 2233 | { 2234 | "cell_type": "code", 2235 | "execution_count": 44, 2236 | "id": "4210f56d-569a-445f-a74e-f23f777603bf", 2237 | "metadata": { 2238 | "id": "4210f56d-569a-445f-a74e-f23f777603bf", 2239 | "outputId": "04448d5c-24a9-4838-f454-338b4e9f6c55" 2240 | }, 2241 | "outputs": [], 2242 | "source": [ 2243 | "sentence_label = predict_label_sentence(\"I am happy\")" 2244 | ] 2245 | }, 2246 | { 2247 | "cell_type": "code", 2248 | "execution_count": 45, 2249 | "id": "b5b396d6-7161-4352-ad18-e4e379e9e281", 2250 | "metadata": { 2251 | "id": "b5b396d6-7161-4352-ad18-e4e379e9e281" 2252 | }, 2253 | "outputs": [ 2254 | { 2255 | "data": { 2256 | "text/plain": [ 2257 | "{'original_sentence': 'I am happy',\n", 2258 | " 'tokenized_sentence': ['[CLS]', 'i', 'am', 'happy', '[SEP]'],\n", 2259 | " 'input_ids': [101, 1045, 2572, 3407, 102],\n", 2260 | " 'softmax_probs': [0.19336703419685364, 0.8066329956054688],\n", 2261 | " 'predicted_label': 1,\n", 2262 | " 'sentiment': 'positive'}" 2263 | ] 2264 | }, 2265 | "execution_count": 45, 2266 | "metadata": {}, 2267 | "output_type": "execute_result" 2268 | } 2269 | ], 2270 | "source": [ 2271 | "sentence_label" 2272 | ] 2273 | }, 2274 | { 2275 | "cell_type": "code", 2276 | 
"execution_count": 46, 2277 | "id": "c599ffd8-8c6f-4870-851e-2329dbe5a3ec", 2278 | "metadata": { 2279 | "tags": [] 2280 | }, 2281 | "outputs": [ 2282 | { 2283 | "name": "stderr", 2284 | "output_type": "stream", 2285 | "text": [ 2286 | "Labelling Sentences: 100%|██████████| 150/150 [00:10<00:00, 14.60it/s]\n" 2287 | ] 2288 | } 2289 | ], 2290 | "source": [ 2291 | "df_test_trained, accuracy_test_trained = predict_label_dataset(test_dataset)" 2292 | ] 2293 | }, 2294 | { 2295 | "cell_type": "code", 2296 | "execution_count": 47, 2297 | "id": "1bba9e97-18c0-47f0-8338-140b5fb201d9", 2298 | "metadata": { 2299 | "tags": [] 2300 | }, 2301 | "outputs": [ 2302 | { 2303 | "data": { 2304 | "text/html": [ 2305 | "
\n", 2306 | "\n", 2319 | "\n", 2320 | " \n", 2321 | " \n", 2322 | " \n", 2323 | " \n", 2324 | " \n", 2325 | " \n", 2326 | " \n", 2327 | " \n", 2328 | " \n", 2329 | " \n", 2330 | " \n", 2331 | " \n", 2332 | " \n", 2333 | " \n", 2334 | " \n", 2335 | " \n", 2336 | " \n", 2337 | " \n", 2338 | " \n", 2339 | " \n", 2340 | " \n", 2341 | " \n", 2342 | " \n", 2343 | " \n", 2344 | " \n", 2345 | " \n", 2346 | " \n", 2347 | " \n", 2348 | " \n", 2349 | " \n", 2350 | " \n", 2351 | " \n", 2352 | " \n", 2353 | " \n", 2354 | " \n", 2355 | " \n", 2356 | " \n", 2357 | " \n", 2358 | " \n", 2359 | " \n", 2360 | " \n", 2361 | " \n", 2362 | " \n", 2363 | " \n", 2364 | " \n", 2365 | " \n", 2366 | " \n", 2367 | " \n", 2368 | " \n", 2369 | " \n", 2370 | " \n", 2371 | " \n", 2372 | " \n", 2373 | " \n", 2374 | " \n", 2375 | " \n", 2376 | " \n", 2377 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2381 | " \n", 2382 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2386 | " \n", 2387 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2391 | " \n", 2392 | " \n", 2393 | " \n", 2394 | " \n", 2395 | " \n", 2396 | " \n", 2397 | " \n", 2398 | " \n", 2399 | " \n", 2400 | " \n", 2401 | " \n", 2402 | " \n", 2403 | " \n", 2404 | " \n", 2405 | " \n", 2406 | " \n", 2407 | " \n", 2408 | " \n", 2409 | " \n", 2410 | " \n", 2411 | " \n", 2412 | " \n", 2413 | " \n", 2414 | " \n", 2415 | " \n", 2416 | " \n", 2417 | " \n", 2418 | " \n", 2419 | " \n", 2420 | " \n", 2421 | " \n", 2422 | " \n", 2423 | " \n", 2424 | " \n", 2425 | " \n", 2426 | " \n", 2427 | " \n", 2428 | " \n", 2429 | " \n", 2430 | " \n", 2431 | " \n", 2432 | " \n", 2433 | " \n", 2434 | " \n", 2435 | " \n", 2436 | " \n", 2437 | " \n", 2438 | " \n", 2439 | " \n", 2440 | " \n", 2441 | " \n", 2442 | " \n", 2443 | " \n", 2444 | " \n", 2445 | " \n", 2446 | " \n", 2447 | " \n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | "
    | Sentence | Tokenized Sentence | Input IDs | Number of Tokens | Softmax Probs | Predicted Label | True Label | Sentiment
0   | at least one scene is so disgusting that viewe... | [[CLS], at, least, one, scene, is, so, disgust... | [101, 2012, 2560, 2028, 3496, 2003, 2061, 1942... | 20 | [0.6496649980545044, 0.350335031747818] | 0 | 0 | negative
1   | even the finest chef ca n't make a hotdog into... | [[CLS], even, the, finest, chef, ca, n, ', t, ... | [101, 2130, 1996, 10418, 10026, 6187, 1050, 10... | 44 | [0.564052939414978, 0.4359470307826996] | 0 | 0 | negative
2   | collateral damage finally delivers the goods f... | [[CLS], collateral, damage, finally, delivers,... | [101, 24172, 4053, 2633, 18058, 1996, 5350, 20... | 14 | [0.5163972973823547, 0.48360273241996765] | 0 | 1 | negative
3   | exciting and direct , with ghost imagery that ... | [[CLS], exciting, and, direct, ,, with, ghost,... | [101, 10990, 1998, 3622, 1010, 2007, 5745, 134... | 20 | [0.1598498672246933, 0.8401501178741455] | 1 | 1 | positive
4   | and when you 're talking about a slapstick com... | [[CLS], and, when, you, ', re, talking, about,... | [101, 1998, 2043, 2017, 1005, 2128, 3331, 2055... | 22 | [0.6049067974090576, 0.39509323239326477] | 0 | 0 | negative
... | ... | ... | ... | ... | ... | ... | ... | ...
145 | it 's a bit disappointing that it only manages... | [[CLS], it, ', s, a, bit, disappointing, that,... | [101, 2009, 1005, 1055, 1037, 2978, 15640, 200... | 20 | [0.5640332102775574, 0.4359667897224426] | 0 | 0 | negative
146 | a breezy romantic comedy that has the punch of... | [[CLS], a, bree, ##zy, romantic, comedy, that,... | [101, 1037, 21986, 9096, 6298, 4038, 2008, 203... | 24 | [0.18238702416419983, 0.8176130056381226] | 1 | 1 | positive
147 | the film tries too hard to be funny and tries ... | [[CLS], the, film, tries, too, hard, to, be, f... | [101, 1996, 2143, 5363, 2205, 2524, 2000, 2022... | 18 | [0.614990770816803, 0.3850092589855194] | 0 | 0 | negative
148 | thanks to scott 's charismatic roger and eisen... | [[CLS], thanks, to, scott, ', s, charismatic, ... | [101, 4283, 2000, 3660, 1005, 1055, 23916, 507... | 35 | [0.2577976882457733, 0.7422022819519043] | 1 | 1 | positive
149 | drops you into a dizzying , volatile , pressur... | [[CLS], drops, you, into, a, dizzy, ##ing, ,, ... | [101, 9010, 2017, 2046, 1037, 14849, 2075, 101... | 37 | [0.6184784173965454, 0.381521612405777] | 0 | 1 | negative
\n", 2457 | "

150 rows × 8 columns

\n", 2458 | "
" 2459 | ], 2460 | "text/plain": [ 2461 | " Sentence \\\n", 2462 | "0 at least one scene is so disgusting that viewe... \n", 2463 | "1 even the finest chef ca n't make a hotdog into... \n", 2464 | "2 collateral damage finally delivers the goods f... \n", 2465 | "3 exciting and direct , with ghost imagery that ... \n", 2466 | "4 and when you 're talking about a slapstick com... \n", 2467 | ".. ... \n", 2468 | "145 it 's a bit disappointing that it only manages... \n", 2469 | "146 a breezy romantic comedy that has the punch of... \n", 2470 | "147 the film tries too hard to be funny and tries ... \n", 2471 | "148 thanks to scott 's charismatic roger and eisen... \n", 2472 | "149 drops you into a dizzying , volatile , pressur... \n", 2473 | "\n", 2474 | " Tokenized Sentence \\\n", 2475 | "0 [[CLS], at, least, one, scene, is, so, disgust... \n", 2476 | "1 [[CLS], even, the, finest, chef, ca, n, ', t, ... \n", 2477 | "2 [[CLS], collateral, damage, finally, delivers,... \n", 2478 | "3 [[CLS], exciting, and, direct, ,, with, ghost,... \n", 2479 | "4 [[CLS], and, when, you, ', re, talking, about,... \n", 2480 | ".. ... \n", 2481 | "145 [[CLS], it, ', s, a, bit, disappointing, that,... \n", 2482 | "146 [[CLS], a, bree, ##zy, romantic, comedy, that,... \n", 2483 | "147 [[CLS], the, film, tries, too, hard, to, be, f... \n", 2484 | "148 [[CLS], thanks, to, scott, ', s, charismatic, ... \n", 2485 | "149 [[CLS], drops, you, into, a, dizzy, ##ing, ,, ... \n", 2486 | "\n", 2487 | " Input IDs Number of Tokens \\\n", 2488 | "0 [101, 2012, 2560, 2028, 3496, 2003, 2061, 1942... 20 \n", 2489 | "1 [101, 2130, 1996, 10418, 10026, 6187, 1050, 10... 44 \n", 2490 | "2 [101, 24172, 4053, 2633, 18058, 1996, 5350, 20... 14 \n", 2491 | "3 [101, 10990, 1998, 3622, 1010, 2007, 5745, 134... 20 \n", 2492 | "4 [101, 1998, 2043, 2017, 1005, 2128, 3331, 2055... 22 \n", 2493 | ".. ... ... \n", 2494 | "145 [101, 2009, 1005, 1055, 1037, 2978, 15640, 200... 20 \n", 2495 | "146 [101, 1037, 21986, 9096, 6298, 4038, 2008, 203... 24 \n", 2496 | "147 [101, 1996, 2143, 5363, 2205, 2524, 2000, 2022... 18 \n", 2497 | "148 [101, 4283, 2000, 3660, 1005, 1055, 23916, 507... 35 \n", 2498 | "149 [101, 9010, 2017, 2046, 1037, 14849, 2075, 101... 37 \n", 2499 | "\n", 2500 | " Softmax Probs Predicted Label True Label \\\n", 2501 | "0 [0.6496649980545044, 0.350335031747818] 0 0 \n", 2502 | "1 [0.564052939414978, 0.4359470307826996] 0 0 \n", 2503 | "2 [0.5163972973823547, 0.48360273241996765] 0 1 \n", 2504 | "3 [0.1598498672246933, 0.8401501178741455] 1 1 \n", 2505 | "4 [0.6049067974090576, 0.39509323239326477] 0 0 \n", 2506 | ".. ... ... ... \n", 2507 | "145 [0.5640332102775574, 0.4359667897224426] 0 0 \n", 2508 | "146 [0.18238702416419983, 0.8176130056381226] 1 1 \n", 2509 | "147 [0.614990770816803, 0.3850092589855194] 0 0 \n", 2510 | "148 [0.2577976882457733, 0.7422022819519043] 1 1 \n", 2511 | "149 [0.6184784173965454, 0.381521612405777] 0 1 \n", 2512 | "\n", 2513 | " Sentiment \n", 2514 | "0 negative \n", 2515 | "1 negative \n", 2516 | "2 negative \n", 2517 | "3 positive \n", 2518 | "4 negative \n", 2519 | ".. ... 
\n", 2520 | "145 negative \n", 2521 | "146 positive \n", 2522 | "147 negative \n", 2523 | "148 positive \n", 2524 | "149 negative \n", 2525 | "\n", 2526 | "[150 rows x 8 columns]" 2527 | ] 2528 | }, 2529 | "execution_count": 47, 2530 | "metadata": {}, 2531 | "output_type": "execute_result" 2532 | } 2533 | ], 2534 | "source": [ 2535 | "df_test_trained" 2536 | ] 2537 | }, 2538 | { 2539 | "cell_type": "code", 2540 | "execution_count": 48, 2541 | "id": "285f61e8-545b-42c2-bccf-ec961bcf7435", 2542 | "metadata": { 2543 | "tags": [] 2544 | }, 2545 | "outputs": [ 2546 | { 2547 | "data": { 2548 | "text/plain": [ 2549 | "0.8733333333333333" 2550 | ] 2551 | }, 2552 | "execution_count": 48, 2553 | "metadata": {}, 2554 | "output_type": "execute_result" 2555 | } 2556 | ], 2557 | "source": [ 2558 | "accuracy_test_trained" 2559 | ] 2560 | }, 2561 | { 2562 | "cell_type": "markdown", 2563 | "id": "951f027b-f4bc-46e7-bf49-21fbf1bc86cd", 2564 | "metadata": { 2565 | "id": "951f027b-f4bc-46e7-bf49-21fbf1bc86cd" 2566 | }, 2567 | "source": [ 2568 | "# compatible datasasets" 2569 | ] 2570 | }, 2571 | { 2572 | "cell_type": "code", 2573 | "execution_count": 49, 2574 | "id": "37b6a6d6-6aa3-425b-b8d3-ebb10e6a7071", 2575 | "metadata": {}, 2576 | "outputs": [], 2577 | "source": [ 2578 | "# Define the list of compatible datasets\n", 2579 | "compatible_datasets = {\n", 2580 | " \"imdb\": \"imdb\",\n", 2581 | " \"yelp\": \"yelp_polarity\",\n", 2582 | " \"amazon\": \"amazon_polarity\"\n", 2583 | "}" 2584 | ] 2585 | }, 2586 | { 2587 | "cell_type": "code", 2588 | "execution_count": 50, 2589 | "id": "e8a3e5d1-3583-422d-b66e-16dd6b9f38a3", 2590 | "metadata": { 2591 | "id": "e8a3e5d1-3583-422d-b66e-16dd6b9f38a3", 2592 | "tags": [] 2593 | }, 2594 | "outputs": [], 2595 | "source": [ 2596 | "def download_and_select_samples(dataset_name, n_samples):\n", 2597 | " \"\"\"\n", 2598 | " Downloads a dataset for sentiment analysis and selects a random subset of n_samples.\n", 2599 | "\n", 2600 | " Args:\n", 2601 | " - dataset_name: Name of the dataset (must be one of the compatible datasets).\n", 2602 | " - n_samples: Number of random samples to select.\n", 2603 | "\n", 2604 | " Returns:\n", 2605 | " - new_data: A subset of the dataset with n_samples randomly selected.\n", 2606 | " \"\"\"\n", 2607 | "\n", 2608 | " # Step 1: Load the dataset\n", 2609 | " if dataset_name not in compatible_datasets:\n", 2610 | " raise ValueError(f\"Dataset '{dataset_name}' not found. 
Choose from {list(compatible_datasets.keys())}\")\n", 2611 | "\n", 2612 | " dataset_info = compatible_datasets[dataset_name]\n", 2613 | "\n", 2614 | " # Some datasets require specifying a subset\n", 2615 | " if isinstance(dataset_info, tuple):\n", 2616 | " dataset = load_dataset(*dataset_info)\n", 2617 | " else:\n", 2618 | " dataset = load_dataset(dataset_info)\n", 2619 | "\n", 2620 | " # Use the test split if available, otherwise use the train split\n", 2621 | " split = 'test' if 'test' in dataset else 'train'\n", 2622 | " data = dataset[split]\n", 2623 | "\n", 2624 | " # Step 2: Select a random sample of n_samples from the dataset\n", 2625 | " new_data = data.shuffle(seed=17).select(range(n_samples))\n", 2626 | "\n", 2627 | " return new_data\n" 2628 | ] 2629 | }, 2630 | { 2631 | "cell_type": "markdown", 2632 | "id": "2e9da99e-ba90-43c7-8a45-4387afac1af6", 2633 | "metadata": { 2634 | "id": "2e9da99e-ba90-43c7-8a45-4387afac1af6" 2635 | }, 2636 | "source": [ 2637 | "### imdb" 2638 | ] 2639 | }, 2640 | { 2641 | "cell_type": "code", 2642 | "execution_count": 51, 2643 | "id": "69c10bba-99fa-453d-a131-3f709895418f", 2644 | "metadata": { 2645 | "id": "69c10bba-99fa-453d-a131-3f709895418f", 2646 | "tags": [] 2647 | }, 2648 | "outputs": [ 2649 | { 2650 | "data": { 2651 | "application/vnd.jupyter.widget-view+json": { 2652 | "model_id": "f02e606b9f534a65afa2c040a8467207", 2653 | "version_major": 2, 2654 | "version_minor": 0 2655 | }, 2656 | "text/plain": [ 2657 | "README.md: 0%| | 0.00/7.81k [00:00\n", 2752 | "\n", 2765 | "\n", 2766 | " \n", 2767 | " \n", 2768 | " \n", 2769 | " \n", 2770 | " \n", 2771 | " \n", 2772 | " \n", 2773 | " \n", 2774 | " \n", 2775 | " \n", 2776 | " \n", 2777 | " \n", 2778 | " \n", 2779 | " \n", 2780 | " \n", 2781 | " \n", 2782 | " \n", 2783 | " \n", 2784 | " \n", 2785 | " \n", 2786 | " \n", 2787 | " \n", 2788 | " \n", 2789 | " \n", 2790 | " \n", 2791 | " \n", 2792 | " \n", 2793 | " \n", 2794 | " \n", 2795 | " \n", 2796 | " \n", 2797 | " \n", 2798 | " \n", 2799 | " \n", 2800 | " \n", 2801 | " \n", 2802 | " \n", 2803 | " \n", 2804 | " \n", 2805 | " \n", 2806 | " \n", 2807 | " \n", 2808 | " \n", 2809 | " \n", 2810 | " \n", 2811 | " \n", 2812 | " \n", 2813 | " \n", 2814 | " \n", 2815 | " \n", 2816 | " \n", 2817 | " \n", 2818 | " \n", 2819 | " \n", 2820 | " \n", 2821 | " \n", 2822 | " \n", 2823 | " \n", 2824 | " \n", 2825 | " \n", 2826 | " \n", 2827 | " \n", 2828 | " \n", 2829 | " \n", 2830 | " \n", 2831 | " \n", 2832 | " \n", 2833 | " \n", 2834 | " \n", 2835 | " \n", 2836 | " \n", 2837 | " \n", 2838 | " \n", 2839 | " \n", 2840 | " \n", 2841 | " \n", 2842 | " \n", 2843 | " \n", 2844 | " \n", 2845 | " \n", 2846 | " \n", 2847 | " \n", 2848 | " \n", 2849 | " \n", 2850 | " \n", 2851 | " \n", 2852 | " \n", 2853 | " \n", 2854 | " \n", 2855 | " \n", 2856 | " \n", 2857 | " \n", 2858 | " \n", 2859 | " \n", 2860 | " \n", 2861 | " \n", 2862 | " \n", 2863 | " \n", 2864 | " \n", 2865 | " \n", 2866 | " \n", 2867 | " \n", 2868 | " \n", 2869 | " \n", 2870 | " \n", 2871 | " \n", 2872 | " \n", 2873 | " \n", 2874 | " \n", 2875 | " \n", 2876 | " \n", 2877 | " \n", 2878 | " \n", 2879 | " \n", 2880 | " \n", 2881 | " \n", 2882 | " \n", 2883 | " \n", 2884 | " \n", 2885 | " \n", 2886 | " \n", 2887 | " \n", 2888 | " \n", 2889 | " \n", 2890 | " \n", 2891 | " \n", 2892 | " \n", 2893 | " \n", 2894 | " \n", 2895 | " \n", 2896 | " \n", 2897 | " \n", 2898 | " \n", 2899 | " \n", 2900 | " \n", 2901 | " \n", 2902 | "
    | Sentence | Tokenized Sentence | Input IDs | Number of Tokens | Softmax Probs | Predicted Label | True Label | Sentiment
0   | I wish I had read the comments on IMDb before ... | [[CLS], i, wish, i, had, read, the, comments, ... | [101, 1045, 4299, 1045, 2018, 3191, 1996, 7928... | 159 | [0.5610942244529724, 0.4389057457447052] | 0 | 0 | negative
1   | I loved this movie! So worth the long running ... | [[CLS], i, loved, this, movie, !, so, worth, t... | [101, 1045, 3866, 2023, 3185, 999, 2061, 4276,... | 149 | [0.4507291316986084, 0.5492709279060364] | 1 | 1 | positive
2   | I actually went to see this movie with low exp... | [[CLS], i, actually, went, to, see, this, movi... | [101, 1045, 2941, 2253, 2000, 2156, 2023, 3185... | 222 | [0.39913123846054077, 0.6008687615394592] | 1 | 1 | positive
3   | For anyone who cares to know something about t... | [[CLS], for, anyone, who, cares, to, know, som... | [101, 2005, 3087, 2040, 14977, 2000, 2113, 224... | 201 | [0.522199273109436, 0.4778006672859192] | 0 | 0 | negative
4   | Eric Idle, Robbie Coltraine, Janet Suzman - it... | [[CLS], eric, idle, ,, robbie, colt, ##raine, ... | [101, 4388, 18373, 1010, 12289, 9110, 26456, 1... | 125 | [0.5400406122207642, 0.45995938777923584] | 0 | 0 | negative
... | ... | ... | ... | ... | ... | ... | ... | ...
145 | I bought this Chuck Norris DVD knowing that it... | [[CLS], i, bought, this, chuck, norris, dvd, k... | [101, 1045, 4149, 2023, 8057, 15466, 4966, 420... | 178 | [0.4604353904724121, 0.5395646095275879] | 1 | 0 | positive
146 | This movie is based on the book, \"A Many Splen... | [[CLS], this, movie, is, based, on, the, book,... | [101, 2023, 3185, 2003, 2241, 2006, 1996, 2338... | 266 | [0.31596124172210693, 0.6840387582778931] | 1 | 1 | positive
147 | I must say, when I saw this film at a 6.5 on t... | [[CLS], i, must, say, ,, when, i, saw, this, f... | [101, 1045, 2442, 2360, 1010, 2043, 1045, 2387... | 500 | [0.5218430757522583, 0.4781569540500641] | 0 | 0 | negative
148 | I really enjoyed this movie. I have a real sen... | [[CLS], i, really, enjoyed, this, movie, ., i,... | [101, 1045, 2428, 5632, 2023, 3185, 1012, 1045... | 350 | [0.4625762701034546, 0.5374237298965454] | 1 | 1 | positive
149 | ...and boy is the collision deafening. A femal... | [[CLS], ., ., ., and, boy, is, the, collision,... | [101, 1012, 1012, 1012, 1998, 2879, 2003, 1996... | 148 | [0.5410314798355103, 0.45896854996681213] | 0 | 0 | negative
\n", 2903 | "

150 rows × 8 columns

\n", 2904 | "" 2905 | ], 2906 | "text/plain": [ 2907 | " Sentence \\\n", 2908 | "0 I wish I had read the comments on IMDb before ... \n", 2909 | "1 I loved this movie! So worth the long running ... \n", 2910 | "2 I actually went to see this movie with low exp... \n", 2911 | "3 For anyone who cares to know something about t... \n", 2912 | "4 Eric Idle, Robbie Coltraine, Janet Suzman - it... \n", 2913 | ".. ... \n", 2914 | "145 I bought this Chuck Norris DVD knowing that it... \n", 2915 | "146 This movie is based on the book, \"A Many Splen... \n", 2916 | "147 I must say, when I saw this film at a 6.5 on t... \n", 2917 | "148 I really enjoyed this movie. I have a real sen... \n", 2918 | "149 ...and boy is the collision deafening. A femal... \n", 2919 | "\n", 2920 | " Tokenized Sentence \\\n", 2921 | "0 [[CLS], i, wish, i, had, read, the, comments, ... \n", 2922 | "1 [[CLS], i, loved, this, movie, !, so, worth, t... \n", 2923 | "2 [[CLS], i, actually, went, to, see, this, movi... \n", 2924 | "3 [[CLS], for, anyone, who, cares, to, know, som... \n", 2925 | "4 [[CLS], eric, idle, ,, robbie, colt, ##raine, ... \n", 2926 | ".. ... \n", 2927 | "145 [[CLS], i, bought, this, chuck, norris, dvd, k... \n", 2928 | "146 [[CLS], this, movie, is, based, on, the, book,... \n", 2929 | "147 [[CLS], i, must, say, ,, when, i, saw, this, f... \n", 2930 | "148 [[CLS], i, really, enjoyed, this, movie, ., i,... \n", 2931 | "149 [[CLS], ., ., ., and, boy, is, the, collision,... \n", 2932 | "\n", 2933 | " Input IDs Number of Tokens \\\n", 2934 | "0 [101, 1045, 4299, 1045, 2018, 3191, 1996, 7928... 159 \n", 2935 | "1 [101, 1045, 3866, 2023, 3185, 999, 2061, 4276,... 149 \n", 2936 | "2 [101, 1045, 2941, 2253, 2000, 2156, 2023, 3185... 222 \n", 2937 | "3 [101, 2005, 3087, 2040, 14977, 2000, 2113, 224... 201 \n", 2938 | "4 [101, 4388, 18373, 1010, 12289, 9110, 26456, 1... 125 \n", 2939 | ".. ... ... \n", 2940 | "145 [101, 1045, 4149, 2023, 8057, 15466, 4966, 420... 178 \n", 2941 | "146 [101, 2023, 3185, 2003, 2241, 2006, 1996, 2338... 266 \n", 2942 | "147 [101, 1045, 2442, 2360, 1010, 2043, 1045, 2387... 500 \n", 2943 | "148 [101, 1045, 2428, 5632, 2023, 3185, 1012, 1045... 350 \n", 2944 | "149 [101, 1012, 1012, 1012, 1998, 2879, 2003, 1996... 148 \n", 2945 | "\n", 2946 | " Softmax Probs Predicted Label True Label \\\n", 2947 | "0 [0.5610942244529724, 0.4389057457447052] 0 0 \n", 2948 | "1 [0.4507291316986084, 0.5492709279060364] 1 1 \n", 2949 | "2 [0.39913123846054077, 0.6008687615394592] 1 1 \n", 2950 | "3 [0.522199273109436, 0.4778006672859192] 0 0 \n", 2951 | "4 [0.5400406122207642, 0.45995938777923584] 0 0 \n", 2952 | ".. ... ... ... \n", 2953 | "145 [0.4604353904724121, 0.5395646095275879] 1 0 \n", 2954 | "146 [0.31596124172210693, 0.6840387582778931] 1 1 \n", 2955 | "147 [0.5218430757522583, 0.4781569540500641] 0 0 \n", 2956 | "148 [0.4625762701034546, 0.5374237298965454] 1 1 \n", 2957 | "149 [0.5410314798355103, 0.45896854996681213] 0 0 \n", 2958 | "\n", 2959 | " Sentiment \n", 2960 | "0 negative \n", 2961 | "1 positive \n", 2962 | "2 positive \n", 2963 | "3 negative \n", 2964 | "4 negative \n", 2965 | ".. ... 
\n", 2966 | "145 positive \n", 2967 | "146 positive \n", 2968 | "147 negative \n", 2969 | "148 positive \n", 2970 | "149 negative \n", 2971 | "\n", 2972 | "[150 rows x 8 columns]" 2973 | ] 2974 | }, 2975 | "execution_count": 55, 2976 | "metadata": {}, 2977 | "output_type": "execute_result" 2978 | } 2979 | ], 2980 | "source": [ 2981 | "imdb_df" 2982 | ] 2983 | }, 2984 | { 2985 | "cell_type": "code", 2986 | "execution_count": 56, 2987 | "id": "dd73c530-8aa2-43f1-a6a2-097d344a4d8a", 2988 | "metadata": { 2989 | "tags": [] 2990 | }, 2991 | "outputs": [ 2992 | { 2993 | "data": { 2994 | "text/plain": [ 2995 | "\"I wish I had read the comments on IMDb before I saw this movie. The first 1 hour was OK, though it did make me wonder why everything was centered at Chicago and why no one reported any weather anomaly from outside US. Isolated acts of nature (of this magnitude) are unthinkable. But beyond the first 60 minutes, the movie just drags on like a never-ending story. The screenplay is horrible. As for the actors, very poor choice. Only the people hired to run in panic stick to their roles. But I do have to agree that this movie has got some good 'special effects'. If you rented it on a DVD and would want to watch the movie, despite the reviews, then play it on maximum speed your player would allow!\"" 2996 | ] 2997 | }, 2998 | "execution_count": 56, 2999 | "metadata": {}, 3000 | "output_type": "execute_result" 3001 | } 3002 | ], 3003 | "source": [ 3004 | "imdb_df['Sentence'][0]" 3005 | ] 3006 | }, 3007 | { 3008 | "cell_type": "markdown", 3009 | "id": "7e8538fa-9a32-4f5b-8730-885494cfc249", 3010 | "metadata": { 3011 | "id": "7e8538fa-9a32-4f5b-8730-885494cfc249" 3012 | }, 3013 | "source": [ 3014 | "### yelp" 3015 | ] 3016 | }, 3017 | { 3018 | "cell_type": "code", 3019 | "execution_count": 57, 3020 | "id": "491887d2-7dc5-4612-91e0-bb8ef8baac84", 3021 | "metadata": { 3022 | "id": "491887d2-7dc5-4612-91e0-bb8ef8baac84", 3023 | "tags": [] 3024 | }, 3025 | "outputs": [ 3026 | { 3027 | "data": { 3028 | "application/vnd.jupyter.widget-view+json": { 3029 | "model_id": "979f00b7068745d5a02aa1bcdd550ec8", 3030 | "version_major": 2, 3031 | "version_minor": 0 3032 | }, 3033 | "text/plain": [ 3034 | "README.md: 0%| | 0.00/8.93k [00:00\n", 3132 | "\n", 3145 | "\n", 3146 | " \n", 3147 | " \n", 3148 | " \n", 3149 | " \n", 3150 | " \n", 3151 | " \n", 3152 | " \n", 3153 | " \n", 3154 | " \n", 3155 | " \n", 3156 | " \n", 3157 | " \n", 3158 | " \n", 3159 | " \n", 3160 | " \n", 3161 | " \n", 3162 | " \n", 3163 | " \n", 3164 | " \n", 3165 | " \n", 3166 | " \n", 3167 | " \n", 3168 | " \n", 3169 | " \n", 3170 | " \n", 3171 | " \n", 3172 | " \n", 3173 | " \n", 3174 | " \n", 3175 | " \n", 3176 | " \n", 3177 | " \n", 3178 | " \n", 3179 | " \n", 3180 | " \n", 3181 | " \n", 3182 | " \n", 3183 | " \n", 3184 | " \n", 3185 | " \n", 3186 | " \n", 3187 | " \n", 3188 | " \n", 3189 | " \n", 3190 | " \n", 3191 | " \n", 3192 | " \n", 3193 | " \n", 3194 | " \n", 3195 | " \n", 3196 | " \n", 3197 | " \n", 3198 | " \n", 3199 | " \n", 3200 | " \n", 3201 | " \n", 3202 | " \n", 3203 | " \n", 3204 | " \n", 3205 | " \n", 3206 | " \n", 3207 | " \n", 3208 | " \n", 3209 | " \n", 3210 | " \n", 3211 | " \n", 3212 | " \n", 3213 | " \n", 3214 | " \n", 3215 | " \n", 3216 | " \n", 3217 | " \n", 3218 | " \n", 3219 | " \n", 3220 | " \n", 3221 | " \n", 3222 | " \n", 3223 | " \n", 3224 | " \n", 3225 | " \n", 3226 | " \n", 3227 | " \n", 3228 | " \n", 3229 | " \n", 3230 | " \n", 3231 | " \n", 3232 | " \n", 3233 | " \n", 3234 | " \n", 3235 | " \n", 3236 | " 
\n", 3237 | " \n", 3238 | " \n", 3239 | " \n", 3240 | " \n", 3241 | " \n", 3242 | " \n", 3243 | " \n", 3244 | " \n", 3245 | " \n", 3246 | " \n", 3247 | " \n", 3248 | " \n", 3249 | " \n", 3250 | " \n", 3251 | " \n", 3252 | " \n", 3253 | " \n", 3254 | " \n", 3255 | " \n", 3256 | " \n", 3257 | " \n", 3258 | " \n", 3259 | " \n", 3260 | " \n", 3261 | " \n", 3262 | " \n", 3263 | " \n", 3264 | " \n", 3265 | " \n", 3266 | " \n", 3267 | " \n", 3268 | " \n", 3269 | " \n", 3270 | " \n", 3271 | " \n", 3272 | " \n", 3273 | " \n", 3274 | " \n", 3275 | " \n", 3276 | " \n", 3277 | " \n", 3278 | " \n", 3279 | " \n", 3280 | " \n", 3281 | " \n", 3282 | "
    | Sentence | Tokenized Sentence | Input IDs | Number of Tokens | Softmax Probs | Predicted Label | True Label | Sentiment
0   | Service and food were awesome! Highly recommen... | [[CLS], service, and, food, were, awesome, !, ... | [101, 2326, 1998, 2833, 2020, 12476, 999, 3811... | 23 | [0.2670975625514984, 0.732902467250824] | 1 | 1 | positive
1   | The food was OK, it was kind of slow so the fi... | [[CLS], the, food, was, ok, ,, it, was, kind, ... | [101, 1996, 2833, 2001, 7929, 1010, 2009, 2001... | 163 | [0.5392808318138123, 0.46071913838386536] | 0 | 0 | negative
2   | The gym is dirty and old and the whole place i... | [[CLS], the, gym, is, dirty, and, old, and, th... | [101, 1996, 9726, 2003, 6530, 1998, 2214, 1998... | 279 | [0.5794830918312073, 0.4205169379711151] | 0 | 0 | negative
3   | Just arrived from the overnight train, arrived... | [[CLS], just, arrived, from, the, overnight, t... | [101, 2074, 3369, 2013, 1996, 11585, 3345, 101... | 68 | [0.275282621383667, 0.724717378616333] | 1 | 1 | positive
4   | So just in case this is the first review you'v... | [[CLS], so, just, in, case, this, is, the, fir... | [101, 2061, 2074, 1999, 2553, 2023, 2003, 1996... | 309 | [0.4868742823600769, 0.5131257176399231] | 1 | 1 | positive
... | ... | ... | ... | ... | ... | ... | ... | ...
145 | Usually I am not a big stickler for customer s... | [[CLS], usually, i, am, not, a, big, stick, ##... | [101, 2788, 1045, 2572, 2025, 1037, 2502, 6293... | 512 | [0.4426237642765045, 0.5573763251304626] | 1 | 1 | positive
146 | Of the cheaper casinos on the Strip, Bally's h... | [[CLS], of, the, cheaper, casinos, on, the, st... | [101, 1997, 1996, 16269, 27300, 2006, 1996, 61... | 185 | [0.44811585545539856, 0.551884114742279] | 1 | 1 | positive
147 | Extradinarilly big for a cafe! They've got eve... | [[CLS], extra, ##dina, ##rill, ##y, big, for, ... | [101, 4469, 18979, 24714, 2100, 2502, 2005, 10... | 140 | [0.4032253921031952, 0.5967746376991272] | 1 | 1 | positive
148 | The serving is good, but the steak dinner is n... | [[CLS], the, serving, is, good, ,, but, the, s... | [101, 1996, 3529, 2003, 2204, 1010, 2021, 1996... | 32 | [0.5420543551445007, 0.45794567465782166] | 0 | 0 | negative
149 | We have been here a few times and the food is ... | [[CLS], we, have, been, here, a, few, times, a... | [101, 2057, 2031, 2042, 2182, 1037, 2261, 2335... | 142 | [0.5248243808746338, 0.4751756191253662] | 0 | 0 | negative
\n", 3283 | "

150 rows × 8 columns

\n", 3284 | "" 3285 | ], 3286 | "text/plain": [ 3287 | " Sentence \\\n", 3288 | "0 Service and food were awesome! Highly recommen... \n", 3289 | "1 The food was OK, it was kind of slow so the fi... \n", 3290 | "2 The gym is dirty and old and the whole place i... \n", 3291 | "3 Just arrived from the overnight train, arrived... \n", 3292 | "4 So just in case this is the first review you'v... \n", 3293 | ".. ... \n", 3294 | "145 Usually I am not a big stickler for customer s... \n", 3295 | "146 Of the cheaper casinos on the Strip, Bally's h... \n", 3296 | "147 Extradinarilly big for a cafe! They've got eve... \n", 3297 | "148 The serving is good, but the steak dinner is n... \n", 3298 | "149 We have been here a few times and the food is ... \n", 3299 | "\n", 3300 | " Tokenized Sentence \\\n", 3301 | "0 [[CLS], service, and, food, were, awesome, !, ... \n", 3302 | "1 [[CLS], the, food, was, ok, ,, it, was, kind, ... \n", 3303 | "2 [[CLS], the, gym, is, dirty, and, old, and, th... \n", 3304 | "3 [[CLS], just, arrived, from, the, overnight, t... \n", 3305 | "4 [[CLS], so, just, in, case, this, is, the, fir... \n", 3306 | ".. ... \n", 3307 | "145 [[CLS], usually, i, am, not, a, big, stick, ##... \n", 3308 | "146 [[CLS], of, the, cheaper, casinos, on, the, st... \n", 3309 | "147 [[CLS], extra, ##dina, ##rill, ##y, big, for, ... \n", 3310 | "148 [[CLS], the, serving, is, good, ,, but, the, s... \n", 3311 | "149 [[CLS], we, have, been, here, a, few, times, a... \n", 3312 | "\n", 3313 | " Input IDs Number of Tokens \\\n", 3314 | "0 [101, 2326, 1998, 2833, 2020, 12476, 999, 3811... 23 \n", 3315 | "1 [101, 1996, 2833, 2001, 7929, 1010, 2009, 2001... 163 \n", 3316 | "2 [101, 1996, 9726, 2003, 6530, 1998, 2214, 1998... 279 \n", 3317 | "3 [101, 2074, 3369, 2013, 1996, 11585, 3345, 101... 68 \n", 3318 | "4 [101, 2061, 2074, 1999, 2553, 2023, 2003, 1996... 309 \n", 3319 | ".. ... ... \n", 3320 | "145 [101, 2788, 1045, 2572, 2025, 1037, 2502, 6293... 512 \n", 3321 | "146 [101, 1997, 1996, 16269, 27300, 2006, 1996, 61... 185 \n", 3322 | "147 [101, 4469, 18979, 24714, 2100, 2502, 2005, 10... 140 \n", 3323 | "148 [101, 1996, 3529, 2003, 2204, 1010, 2021, 1996... 32 \n", 3324 | "149 [101, 2057, 2031, 2042, 2182, 1037, 2261, 2335... 142 \n", 3325 | "\n", 3326 | " Softmax Probs Predicted Label True Label \\\n", 3327 | "0 [0.2670975625514984, 0.732902467250824] 1 1 \n", 3328 | "1 [0.5392808318138123, 0.46071913838386536] 0 0 \n", 3329 | "2 [0.5794830918312073, 0.4205169379711151] 0 0 \n", 3330 | "3 [0.275282621383667, 0.724717378616333] 1 1 \n", 3331 | "4 [0.4868742823600769, 0.5131257176399231] 1 1 \n", 3332 | ".. ... ... ... \n", 3333 | "145 [0.4426237642765045, 0.5573763251304626] 1 1 \n", 3334 | "146 [0.44811585545539856, 0.551884114742279] 1 1 \n", 3335 | "147 [0.4032253921031952, 0.5967746376991272] 1 1 \n", 3336 | "148 [0.5420543551445007, 0.45794567465782166] 0 0 \n", 3337 | "149 [0.5248243808746338, 0.4751756191253662] 0 0 \n", 3338 | "\n", 3339 | " Sentiment \n", 3340 | "0 positive \n", 3341 | "1 negative \n", 3342 | "2 negative \n", 3343 | "3 positive \n", 3344 | "4 positive \n", 3345 | ".. ... 
\n", 3346 | "145 positive \n", 3347 | "146 positive \n", 3348 | "147 positive \n", 3349 | "148 negative \n", 3350 | "149 negative \n", 3351 | "\n", 3352 | "[150 rows x 8 columns]" 3353 | ] 3354 | }, 3355 | "execution_count": 61, 3356 | "metadata": {}, 3357 | "output_type": "execute_result" 3358 | } 3359 | ], 3360 | "source": [ 3361 | "yelp_df" 3362 | ] 3363 | }, 3364 | { 3365 | "cell_type": "code", 3366 | "execution_count": 62, 3367 | "id": "c3720448-ad09-402c-871d-40294d0b667c", 3368 | "metadata": { 3369 | "tags": [] 3370 | }, 3371 | "outputs": [ 3372 | { 3373 | "data": { 3374 | "text/plain": [ 3375 | "\"Service and food were awesome! Highly recommend the French onion soup. Can't wait to come back.\"" 3376 | ] 3377 | }, 3378 | "execution_count": 62, 3379 | "metadata": {}, 3380 | "output_type": "execute_result" 3381 | } 3382 | ], 3383 | "source": [ 3384 | "yelp_df['Sentence'][0]" 3385 | ] 3386 | }, 3387 | { 3388 | "cell_type": "markdown", 3389 | "id": "3bcdad9e-2681-419f-9d1e-836070d2821b", 3390 | "metadata": { 3391 | "id": "3bcdad9e-2681-419f-9d1e-836070d2821b" 3392 | }, 3393 | "source": [ 3394 | "### amazon" 3395 | ] 3396 | }, 3397 | { 3398 | "cell_type": "code", 3399 | "execution_count": null, 3400 | "id": "bbaf0990-cd57-4fe8-996e-257764bef4c8", 3401 | "metadata": {}, 3402 | "outputs": [], 3403 | "source": [] 3404 | }, 3405 | { 3406 | "cell_type": "code", 3407 | "execution_count": 63, 3408 | "id": "a929ef6c-f245-4048-aa02-1c2ded483581", 3409 | "metadata": { 3410 | "id": "a929ef6c-f245-4048-aa02-1c2ded483581", 3411 | "tags": [] 3412 | }, 3413 | "outputs": [ 3414 | { 3415 | "data": { 3416 | "application/vnd.jupyter.widget-view+json": { 3417 | "model_id": "aeedb257d6f74fa499030302d8e189cb", 3418 | "version_major": 2, 3419 | "version_minor": 0 3420 | }, 3421 | "text/plain": [ 3422 | "README.md: 0%| | 0.00/6.81k [00:00\n", 3520 | "\n", 3533 | "\n", 3534 | " \n", 3535 | " \n", 3536 | " \n", 3537 | " \n", 3538 | " \n", 3539 | " \n", 3540 | " \n", 3541 | " \n", 3542 | " \n", 3543 | " \n", 3544 | " \n", 3545 | " \n", 3546 | " \n", 3547 | " \n", 3548 | " \n", 3549 | " \n", 3550 | " \n", 3551 | " \n", 3552 | " \n", 3553 | " \n", 3554 | " \n", 3555 | " \n", 3556 | " \n", 3557 | " \n", 3558 | " \n", 3559 | " \n", 3560 | " \n", 3561 | " \n", 3562 | " \n", 3563 | " \n", 3564 | " \n", 3565 | " \n", 3566 | " \n", 3567 | " \n", 3568 | " \n", 3569 | " \n", 3570 | " \n", 3571 | " \n", 3572 | " \n", 3573 | " \n", 3574 | " \n", 3575 | " \n", 3576 | " \n", 3577 | " \n", 3578 | " \n", 3579 | " \n", 3580 | " \n", 3581 | " \n", 3582 | " \n", 3583 | " \n", 3584 | " \n", 3585 | " \n", 3586 | " \n", 3587 | " \n", 3588 | " \n", 3589 | " \n", 3590 | " \n", 3591 | " \n", 3592 | " \n", 3593 | " \n", 3594 | " \n", 3595 | " \n", 3596 | " \n", 3597 | " \n", 3598 | " \n", 3599 | " \n", 3600 | " \n", 3601 | " \n", 3602 | " \n", 3603 | " \n", 3604 | " \n", 3605 | " \n", 3606 | " \n", 3607 | " \n", 3608 | " \n", 3609 | " \n", 3610 | " \n", 3611 | " \n", 3612 | " \n", 3613 | " \n", 3614 | " \n", 3615 | " \n", 3616 | " \n", 3617 | " \n", 3618 | " \n", 3619 | " \n", 3620 | " \n", 3621 | " \n", 3622 | " \n", 3623 | " \n", 3624 | " \n", 3625 | " \n", 3626 | " \n", 3627 | " \n", 3628 | " \n", 3629 | " \n", 3630 | " \n", 3631 | " \n", 3632 | " \n", 3633 | " \n", 3634 | " \n", 3635 | " \n", 3636 | " \n", 3637 | " \n", 3638 | " \n", 3639 | " \n", 3640 | " \n", 3641 | " \n", 3642 | " \n", 3643 | " \n", 3644 | " \n", 3645 | " \n", 3646 | " \n", 3647 | " \n", 3648 | " \n", 3649 | " \n", 3650 | " \n", 3651 | " \n", 3652 | " \n", 3653 | " 
\n", 3654 | " \n", 3655 | " \n", 3656 | " \n", 3657 | " \n", 3658 | " \n", 3659 | " \n", 3660 | " \n", 3661 | " \n", 3662 | " \n", 3663 | " \n", 3664 | " \n", 3665 | " \n", 3666 | " \n", 3667 | " \n", 3668 | " \n", 3669 | " \n", 3670 | "
SentenceTokenized SentenceInput IDsNumber of TokensSoftmax ProbsPredicted LabelTrue LabelSentiment
0Ben Harper was brought to my attention through...[[CLS], ben, harper, was, brought, to, my, att...[101, 3841, 8500, 2001, 2716, 2000, 2026, 3086...132[0.5244153141975403, 0.47558465600013733]00negative
1I think I'm one of the few folks that recieved...[[CLS], i, think, i, ', m, one, of, the, few, ...[101, 1045, 2228, 1045, 1005, 1049, 2028, 1997...60[0.5591208338737488, 0.44087913632392883]01negative
2First, why did I read this book? Do I have an ...[[CLS], first, ,, why, did, i, read, this, boo...[101, 2034, 1010, 2339, 2106, 1045, 3191, 2023...224[0.4386332631111145, 0.5613666772842407]11positive
3if they taught this kind of history in school,...[[CLS], if, they, taught, this, kind, of, hist...[101, 2065, 2027, 4036, 2023, 2785, 1997, 2381...69[0.48176589608192444, 0.518234133720398]11positive
4Well, well ,well the latest in the o'malley sa...[[CLS], well, ,, well, ,, well, the, latest, i...[101, 2092, 1010, 2092, 1010, 2092, 1996, 6745...66[0.24162133038043976, 0.7583786249160767]11positive
...........................
145I highly recommend Gary Chapman's \"5 Love Lang...[[CLS], i, highly, recommend, gary, chapman, '...[101, 1045, 3811, 16755, 5639, 11526, 1005, 10...71[0.2884887158870697, 0.7115112543106079]11positive
146I purchased one of the HP 540 series PDA's (in...[[CLS], i, purchased, one, of, the, hp, 540, s...[101, 1045, 4156, 2028, 1997, 1996, 6522, 2026...91[0.5197734832763672, 0.4802265465259552]00negative
147After reading the many rave reviews, I was exp...[[CLS], after, reading, the, many, rave, revie...[101, 2044, 3752, 1996, 2116, 23289, 4391, 101...192[0.5477862358093262, 0.45221370458602905]00negative
148For most of my 7th and 8th grade year, a few o...[[CLS], for, most, of, my, 7th, and, 8th, grad...[101, 2005, 2087, 1997, 2026, 5504, 1998, 5893...211[0.45704326033592224, 0.5429567098617554]11positive
149This has to be RUSH's best work since the earl...[[CLS], this, has, to, be, rush, ', s, best, w...[101, 2023, 2038, 2000, 2022, 5481, 1005, 1055...166[0.35129478573799133, 0.6487051844596863]11positive
\n", 3671 | "

150 rows × 8 columns

\n", 3672 | "" 3673 | ], 3674 | "text/plain": [ 3675 | " Sentence \\\n", 3676 | "0 Ben Harper was brought to my attention through... \n", 3677 | "1 I think I'm one of the few folks that recieved... \n", 3678 | "2 First, why did I read this book? Do I have an ... \n", 3679 | "3 if they taught this kind of history in school,... \n", 3680 | "4 Well, well ,well the latest in the o'malley sa... \n", 3681 | ".. ... \n", 3682 | "145 I highly recommend Gary Chapman's \"5 Love Lang... \n", 3683 | "146 I purchased one of the HP 540 series PDA's (in... \n", 3684 | "147 After reading the many rave reviews, I was exp... \n", 3685 | "148 For most of my 7th and 8th grade year, a few o... \n", 3686 | "149 This has to be RUSH's best work since the earl... \n", 3687 | "\n", 3688 | " Tokenized Sentence \\\n", 3689 | "0 [[CLS], ben, harper, was, brought, to, my, att... \n", 3690 | "1 [[CLS], i, think, i, ', m, one, of, the, few, ... \n", 3691 | "2 [[CLS], first, ,, why, did, i, read, this, boo... \n", 3692 | "3 [[CLS], if, they, taught, this, kind, of, hist... \n", 3693 | "4 [[CLS], well, ,, well, ,, well, the, latest, i... \n", 3694 | ".. ... \n", 3695 | "145 [[CLS], i, highly, recommend, gary, chapman, '... \n", 3696 | "146 [[CLS], i, purchased, one, of, the, hp, 540, s... \n", 3697 | "147 [[CLS], after, reading, the, many, rave, revie... \n", 3698 | "148 [[CLS], for, most, of, my, 7th, and, 8th, grad... \n", 3699 | "149 [[CLS], this, has, to, be, rush, ', s, best, w... \n", 3700 | "\n", 3701 | " Input IDs Number of Tokens \\\n", 3702 | "0 [101, 3841, 8500, 2001, 2716, 2000, 2026, 3086... 132 \n", 3703 | "1 [101, 1045, 2228, 1045, 1005, 1049, 2028, 1997... 60 \n", 3704 | "2 [101, 2034, 1010, 2339, 2106, 1045, 3191, 2023... 224 \n", 3705 | "3 [101, 2065, 2027, 4036, 2023, 2785, 1997, 2381... 69 \n", 3706 | "4 [101, 2092, 1010, 2092, 1010, 2092, 1996, 6745... 66 \n", 3707 | ".. ... ... \n", 3708 | "145 [101, 1045, 3811, 16755, 5639, 11526, 1005, 10... 71 \n", 3709 | "146 [101, 1045, 4156, 2028, 1997, 1996, 6522, 2026... 91 \n", 3710 | "147 [101, 2044, 3752, 1996, 2116, 23289, 4391, 101... 192 \n", 3711 | "148 [101, 2005, 2087, 1997, 2026, 5504, 1998, 5893... 211 \n", 3712 | "149 [101, 2023, 2038, 2000, 2022, 5481, 1005, 1055... 166 \n", 3713 | "\n", 3714 | " Softmax Probs Predicted Label True Label \\\n", 3715 | "0 [0.5244153141975403, 0.47558465600013733] 0 0 \n", 3716 | "1 [0.5591208338737488, 0.44087913632392883] 0 1 \n", 3717 | "2 [0.4386332631111145, 0.5613666772842407] 1 1 \n", 3718 | "3 [0.48176589608192444, 0.518234133720398] 1 1 \n", 3719 | "4 [0.24162133038043976, 0.7583786249160767] 1 1 \n", 3720 | ".. ... ... ... \n", 3721 | "145 [0.2884887158870697, 0.7115112543106079] 1 1 \n", 3722 | "146 [0.5197734832763672, 0.4802265465259552] 0 0 \n", 3723 | "147 [0.5477862358093262, 0.45221370458602905] 0 0 \n", 3724 | "148 [0.45704326033592224, 0.5429567098617554] 1 1 \n", 3725 | "149 [0.35129478573799133, 0.6487051844596863] 1 1 \n", 3726 | "\n", 3727 | " Sentiment \n", 3728 | "0 negative \n", 3729 | "1 negative \n", 3730 | "2 positive \n", 3731 | "3 positive \n", 3732 | "4 positive \n", 3733 | ".. ... 
\n", 3734 | "145 positive \n", 3735 | "146 negative \n", 3736 | "147 negative \n", 3737 | "148 positive \n", 3738 | "149 positive \n", 3739 | "\n", 3740 | "[150 rows x 8 columns]" 3741 | ] 3742 | }, 3743 | "execution_count": 67, 3744 | "metadata": {}, 3745 | "output_type": "execute_result" 3746 | } 3747 | ], 3748 | "source": [ 3749 | "amazon_df" 3750 | ] 3751 | }, 3752 | { 3753 | "cell_type": "code", 3754 | "execution_count": 68, 3755 | "id": "793fa6a9-3b3a-440a-badc-552f2827ca26", 3756 | "metadata": {}, 3757 | "outputs": [ 3758 | { 3759 | "data": { 3760 | "text/plain": [ 3761 | "\"Ben Harper was brought to my attention through his association with Jack Johnson. Then Direct TV showed Ben Harper this month on their free concert. I only saw part of the show and decided to buy Live from Mars as my first (and last) Ben Harper CD. I can't get into his music...it doesn't have any flow. His guitar playing is mediocre at best, and his vocals even worse. At times I thought Tiny Tim had come back from the dead. I'll stick with Jack Johnson. Ben Harper was not what I expected, and I utterly fail to see what all the hype is about.\"" 3762 | ] 3763 | }, 3764 | "execution_count": 68, 3765 | "metadata": {}, 3766 | "output_type": "execute_result" 3767 | } 3768 | ], 3769 | "source": [ 3770 | "amazon_df['Sentence'][0]" 3771 | ] 3772 | }, 3773 | { 3774 | "cell_type": "markdown", 3775 | "id": "09db2f8c-6e95-4d3d-a014-002e5db8070a", 3776 | "metadata": {}, 3777 | "source": [ 3778 | "# Comparing tokenizers" 3779 | ] 3780 | }, 3781 | { 3782 | "cell_type": "code", 3783 | "execution_count": 69, 3784 | "id": "faf92fcd-2c4f-4124-add8-105fd225d971", 3785 | "metadata": { 3786 | "tags": [] 3787 | }, 3788 | "outputs": [], 3789 | "source": [ 3790 | "# Define the tokenizers to be compared\n", 3791 | "tokenizers = {\n", 3792 | " \"BART\": \"facebook/bart-base\",\n", 3793 | " \"DistilBERT\": \"distilbert-base-uncased\",\n", 3794 | " \"GPT-2\": \"gpt2\",\n", 3795 | " \"T5\": \"t5-small\",\n", 3796 | " \"Albert\": \"albert-base-v2\",\n", 3797 | " \"XLM-Roberta\": \"xlm-roberta-base\"\n", 3798 | "}" 3799 | ] 3800 | }, 3801 | { 3802 | "cell_type": "code", 3803 | "execution_count": 70, 3804 | "id": "6d72a660-c06f-4f85-b0dc-1c34414373fe", 3805 | "metadata": { 3806 | "tags": [] 3807 | }, 3808 | "outputs": [], 3809 | "source": [ 3810 | "def load_tokenizer(model_name):\n", 3811 | " \"\"\"\n", 3812 | " Load and return a tokenizer based on the provided model name.\n", 3813 | " \"\"\"\n", 3814 | " return AutoTokenizer.from_pretrained(model_name)" 3815 | ] 3816 | }, 3817 | { 3818 | "cell_type": "code", 3819 | "execution_count": 71, 3820 | "id": "7876fc25-74fb-4f92-bbe6-82a1b488aab2", 3821 | "metadata": { 3822 | "tags": [] 3823 | }, 3824 | "outputs": [], 3825 | "source": [ 3826 | "\n", 3827 | "def get_vocab_size(tokenizer):\n", 3828 | " \"\"\"\n", 3829 | " Return the vocabulary size of the provided tokenizer.\n", 3830 | " \"\"\"\n", 3831 | " return tokenizer.vocab_size" 3832 | ] 3833 | }, 3834 | { 3835 | "cell_type": "code", 3836 | "execution_count": 72, 3837 | "id": "928f33de-9673-48b9-a770-528b85fcc13a", 3838 | "metadata": { 3839 | "tags": [] 3840 | }, 3841 | "outputs": [], 3842 | "source": [ 3843 | "def tokenize_sentence(tokenizer, sentence):\n", 3844 | " \"\"\"\n", 3845 | " Tokenize the sentence using the provided tokenizer and return the tokens and token IDs.\n", 3846 | " \"\"\"\n", 3847 | " tokens = tokenizer.tokenize(sentence)\n", 3848 | " token_ids = tokenizer.encode(sentence, add_special_tokens=True)\n", 3849 | " return tokens, 
token_ids" 3850 | ] 3851 | }, 3852 | { 3853 | "cell_type": "code", 3854 | "execution_count": 73, 3855 | "id": "985620a7-65fe-46d1-9b34-1c934b995aec", 3856 | "metadata": { 3857 | "tags": [] 3858 | }, 3859 | "outputs": [], 3860 | "source": [ 3861 | "def print_tokenizer_info(tokenizer_name, tokens, token_ids, vocab_size):\n", 3862 | " \"\"\"\n", 3863 | " Print information about the tokenizer, including tokens, token IDs, and vocabulary size.\n", 3864 | " \"\"\"\n", 3865 | " print(f\"\\n{tokenizer_name} Vocabulary Size: {vocab_size}\")\n", 3866 | " print(f\"{tokenizer_name} Tokenized Sentence:\")\n", 3867 | " print(tokens)\n", 3868 | " print(f\"{tokenizer_name} Token IDs:\")\n", 3869 | " print(token_ids)\n" 3870 | ] 3871 | }, 3872 | { 3873 | "cell_type": "code", 3874 | "execution_count": 74, 3875 | "id": "616fa702-22cf-46f1-91b3-a9901cbb0fa5", 3876 | "metadata": { 3877 | "tags": [] 3878 | }, 3879 | "outputs": [], 3880 | "source": [ 3881 | "\n", 3882 | "# Define the tokenizers to be compared\n", 3883 | "tokenizers = {\n", 3884 | " \"BART\": \"facebook/bart-base\",\n", 3885 | " \"DistilBERT\": \"distilbert-base-uncased\",\n", 3886 | " \"GPT-2\": \"gpt2\",\n", 3887 | " \"T5\": \"t5-small\",\n", 3888 | " \"Albert\": \"albert-base-v2\",\n", 3889 | " \"XLM-Roberta\": \"xlm-roberta-base\"\n", 3890 | "}" 3891 | ] 3892 | }, 3893 | { 3894 | "cell_type": "code", 3895 | "execution_count": 75, 3896 | "id": "e520c2b0-7d7f-4d48-be42-88991bc0e910", 3897 | "metadata": { 3898 | "tags": [] 3899 | }, 3900 | "outputs": [], 3901 | "source": [ 3902 | "# Sentence to tokenize\n", 3903 | "sentence = \"This is how a tokenized expression looks. En español es distinto\"" 3904 | ] 3905 | }, 3906 | { 3907 | "cell_type": "code", 3908 | "execution_count": 76, 3909 | "id": "fb818940-05ae-4036-93cf-cc0492d3afcf", 3910 | "metadata": { 3911 | "tags": [] 3912 | }, 3913 | "outputs": [ 3914 | { 3915 | "name": "stderr", 3916 | "output_type": "stream", 3917 | "text": [ 3918 | "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", 3919 | " warnings.warn(\n" 3920 | ] 3921 | }, 3922 | { 3923 | "name": "stdout", 3924 | "output_type": "stream", 3925 | "text": [ 3926 | "\n", 3927 | "BART Vocabulary Size: 50265\n", 3928 | "BART Tokenized Sentence:\n", 3929 | "['This', 'Ġis', 'Ġhow', 'Ġa', 'Ġtoken', 'ized', 'Ġexpression', 'Ġlooks', '.', 'ĠEn', 'Ġes', 'pa', 'ñ', 'ol', 'Ġes', 'Ġdist', 'into']\n", 3930 | "BART Token IDs:\n", 3931 | "[0, 713, 16, 141, 10, 19233, 1538, 8151, 1326, 4, 2271, 2714, 6709, 6303, 1168, 2714, 7018, 12473, 2]\n", 3932 | "\n", 3933 | "DistilBERT Vocabulary Size: 30522\n", 3934 | "DistilBERT Tokenized Sentence:\n", 3935 | "['this', 'is', 'how', 'a', 'token', '##ized', 'expression', 'looks', '.', 'en', 'es', '##pan', '##ol', 'es', 'di', '##sti', '##nto']\n", 3936 | "DistilBERT Token IDs:\n", 3937 | "[101, 2023, 2003, 2129, 1037, 19204, 3550, 3670, 3504, 1012, 4372, 9686, 9739, 4747, 9686, 4487, 16643, 13663, 102]\n", 3938 | "\n", 3939 | "GPT-2 Vocabulary Size: 50257\n", 3940 | "GPT-2 Tokenized Sentence:\n", 3941 | "['This', 'Ġis', 'Ġhow', 'Ġa', 'Ġtoken', 'ized', 'Ġexpression', 'Ġlooks', '.', 'ĠEn', 'Ġes', 'pa', 'ñ', 'ol', 'Ġes', 'Ġdist', 'into']\n", 3942 | "GPT-2 Token IDs:\n", 3943 | "[1212, 318, 703, 257, 11241, 1143, 5408, 3073, 13, 2039, 1658, 8957, 12654, 349, 1658, 1233, 20424]\n", 3944 | "\n", 3945 | "T5 Vocabulary Size: 32100\n", 3946 | "T5 Tokenized Sentence:\n", 3947 | "['▁This', '▁is', '▁how', '▁', 'a', '▁token', 'ized', '▁expression', '▁looks', '.', '▁En', '▁esp', 'a', 'ñ', 'o', 'l', '▁', 'e', 's', '▁', 'distin', 'to']\n", 3948 | "T5 Token IDs:\n", 3949 | "[100, 19, 149, 3, 9, 14145, 1601, 3893, 1416, 5, 695, 16159, 9, 2, 32, 40, 3, 15, 7, 3, 19694, 235, 1]\n", 3950 | "\n", 3951 | "Albert Vocabulary Size: 30000\n", 3952 | "Albert Tokenized Sentence:\n", 3953 | "['▁this', '▁is', '▁how', '▁a', '▁to', 'ken', 'ized', '▁expression', '▁looks', '.', '▁en', '▁espanol', '▁', 'es', '▁dis', 'tin', 'to']\n", 3954 | "Albert Token IDs:\n", 3955 | "[2, 48, 25, 184, 21, 20, 2853, 1333, 1803, 1879, 9, 1957, 24339, 13, 160, 1460, 2864, 262, 3]\n", 3956 | "\n", 3957 | "XLM-Roberta Vocabulary Size: 250002\n", 3958 | "XLM-Roberta Tokenized Sentence:\n", 3959 | "['▁This', '▁is', '▁how', '▁a', '▁to', 'ken', 'ized', '▁expression', '▁looks', '.', '▁En', '▁español', '▁es', '▁distin', 'to']\n", 3960 | "XLM-Roberta Token IDs:\n", 3961 | "[0, 3293, 83, 3642, 10, 47, 1098, 29367, 125195, 33342, 5, 357, 36131, 198, 34973, 188, 2]\n" 3962 | ] 3963 | } 3964 | ], 3965 | "source": [ 3966 | "# Load tokenizers and compare them\n", 3967 | "for name, model_name in tokenizers.items():\n", 3968 | " tokenizer = load_tokenizer(model_name)\n", 3969 | " vocab_size = get_vocab_size(tokenizer)\n", 3970 | " tokens, token_ids = tokenize_sentence(tokenizer, sentence)\n", 3971 | " print_tokenizer_info(name, tokens, token_ids, vocab_size)" 3972 | ] 3973 | } 3974 | ], 3975 | "metadata": { 3976 | "colab": { 3977 | "collapsed_sections": [ 3978 | "32cc46a7-b9f3-496f-8c1d-7c6700e68edd" 3979 | ], 3980 | "provenance": [], 3981 | "toc_visible": true 3982 | }, 3983 | "kernelspec": { 3984 | "display_name": ".conda-default:Python", 3985 | "language": "python", 3986 | "name": "conda-env-.conda-default-py" 3987 | }, 3988 | "language_info": { 3989 | "codemirror_mode": { 3990 | "name": "ipython", 3991 | "version": 3 3992 | }, 3993 | "file_extension": ".py", 3994 | "mimetype": "text/x-python", 3995 | "name": "python", 3996 | "nbconvert_exporter": "python", 3997 | 
"pygments_lexer": "ipython3", 3998 | "version": "3.9.19" 3999 | } 4000 | }, 4001 | "nbformat": 4, 4002 | "nbformat_minor": 5 4003 | } 4004 | --------------------------------------------------------------------------------