├── README.md ├── homeworks ├── hw1_n_gram_generaiton.ipynb └── hw2_rnn_classification │ ├── biotech_news.tsv │ └── hw2_rnn_classification.ipynb ├── week01_text_classification ├── ag_news_test.csv ├── ag_news_train.csv ├── lecture1.pdf └── seminar1.ipynb ├── week02_generation ├── lecture2.pdf └── seminar2.ipynb ├── week03_transformer ├── lecture3.pdf └── seminar3.ipynb ├── week04_bert_gpt ├── lecture4.pdf └── seminar4.ipynb └── week05_distil_quant ├── lecture5.pdf └── seminar5.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Natural Language Processing (NLP), ФКН ВШЭ 2 | 3 | Этот репозиторий содержит материалы лекций, семинаров и домашние задания. 4 | 5 | # Темы курса 6 | 7 | 1. Классификация текста. Записи: [лекция](https://disk.yandex.ru/i/f4iwpQSXOGNXlA), [семинар](https://disk.yandex.ru/i/JRcJ3bIcsoJPYQ) 8 | 2. Генерация текста, RNN. Записи: [лекция](https://disk.yandex.ru/i/9q02Vbzy4GKw3w), [семинар](https://disk.yandex.ru/i/Gf2KyS3odxx_FQ) 9 | 3. Трансформер. Записи: [лекция](https://disk.yandex.ru/i/jNpFYKPMxFfjrg), [семинар](https://disk.yandex.ru/i/6DtHWdcH4KrvLQ) 10 | 11 | # Преподаватели 12 | 13 | * [Александр Шабалин](https://t.me/amshabalin) 14 | * [Егор Чимбулатов](https://t.me/m0rjique) 15 | * [Дарья Андреева](https://t.me/Xufana) 16 | * [Алексей Биршерт](https://t.me/Birshert) 17 | -------------------------------------------------------------------------------- /homeworks/hw1_n_gram_generaiton.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Глубинное обучение для текстовых данных, ФКН ВШЭ\n", 8 | "\n", 9 | "## Домашнее задание 1: Text Suggestion\n", 10 | "\n", 11 | "__Мягкий дедлайн: 24.09 23:59__ \n", 12 | "__Жесткий дедлайн: 27.09 23:59__\n", 13 | "\n", 14 | "### О задании\n", 15 | "\n", 16 | "В этом задании вам предстоит реализовать систему, предлагающую удачное продолжение слова или нескольких следующих слов в режиме реального времени по типу тех, которые используются в почте или поисковой строке. За дополнительные баллы полученную систему нужно будет обернуть в пользовательский интерфейс с помощью библиотеки [reflex](https://github.com/reflex-dev/reflex) или аналогов. В этой домашке вам не придется обучать никаких моделей, мы ограничимся n-граммной генерацией.\n", 17 | "\n", 18 | "### Структура\n", 19 | "\n", 20 | "Это домашнее задание состоит из двух частей: основной и бонусной. В первой вам нужно будет выполнить 5 заданий, по итогам которых вы получите минимально рабочее решение. А во второй, пользуясь тем, что вы уже сделали реализовать полноценную систему подсказки текста с пользовательским интерфейсом. Во второй части мы никак не будем ограничивать вашу фантазию. Делайте что угодно, лишь бы в результате получился удобный фреймворк. Чем лучше у вас будет результат, тем больше баллов вы получите. Если будет совсем хорошо, то мы добавим бонусов сверху по своему усмотрению.\n", 21 | "\n", 22 | "### Оценивание и штрафы\n", 23 | "\n", 24 | "Максимально допустимая оценка за работу — 15 баллов. Сдавать задание после жесткого дедлайна нельзя. При сдачи решения после мягкого дедлайна за каждый день просрочки снимается по __одному__ баллу.\n", 25 | "\n", 26 | "Задание выполняется самостоятельно. «Похожие» решения считаются плагиатом и все задействованные студенты (в том числе те, у кого списали) не могут получить за него больше 0 баллов. Весь код должен быть написан самостоятельно. Чужим кодом для пользоваться запрещается даже с указанием ссылки на источник. В разумных рамках, конечно. Взять пару очевидных строчек кода для реализации какого-то небольшого функционала можно.\n", 27 | "\n", 28 | "Неэффективная реализация кода может негативно отразиться на оценке. Также оценка может быть снижена за плохо читаемый код.\n", 29 | "\n", 30 | "При сдаче зададания в anytask вам будет необходимо сдать весь код, а если вы возьметесь за бонусную часть, то еще отчет и видео с демонстрацией вашего UI. За основную часть можно получить до __10-ти__ баллов, а за бонусную – до __5-ти__ баллов." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### Данные\n", 38 | "\n", 39 | "Для получения текстовых статистик используйте датасет `emails.csv`. Вы можете найти его по [ссылке](https://disk.yandex.ru/d/ikyUhWPlvfXxCg). Он содержит более 500 тысяч электронных писем на английском языке." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import pandas as pd\n", 49 | "\n", 50 | "emails = pd.read_csv('emails.csv')\n", 51 | "len(emails)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "Заметьте, что данные очень грязные. В каждом письме содержится различная мета-информация, которая будет только мешать при предсказании продолжения текста.\n", 59 | "\n", 60 | "__Задание 1 (2 балла).__ Очистите корпус текстов по вашему усмотрению и объясните свой выбор. В идеале обработанные тексты должны содержать только текст самого письма и ничего лишнего по типу ссылок, адресатов и прочих символов, которыми мы точно не хотим продолжать текст." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# your code here" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "Для следующего задания вам нужно будет токенизировать текст. Для этого просто разбейте его по словам. Очевидно, итоговый результат для финального пользователя будет лучше, если ваша система также будет предлагать уместную пунктуацию. Но если вы заметите, что из-за этого падает качество самого текса, то можете удалить все небуквенные символы на этапе токенизации." 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## Общая схема решения\n", 84 | "\n", 85 | "Мы хотим сделать систему, которая будет ускорять набор текста, советуя подходящие продолжения. Для подсказки следующего слова мы будем использовать n-граммную модель. Так как n-граммная модель работает с целыми словами, а советы мы хотим давать в риал-тайме даже когда слово еще не дописано, сперва надо научиться дополнять слово до целого." 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Дополнение слова\n", 93 | "\n", 94 | "В этой части вам предстоит реализовать метод дополнения слова до целого по его началу (префиксу). Для этого сперва необходимо научиться находить все слова, имеющие определенный префикс. Мы будем вызывать функцию поиска подходящих слов после каждой напечатанной пользователем буквы. Поэтому нам очень важно, чтобы поиск работал как можно быстрее. Простой перебор всех слов занимает $O(|V| \\cdot n)$ времени, где $|V|$ – размер словаря, а $n$ – длина префикса. Мы же напишем [префиксное дерево](https://ru.wikipedia.org/wiki/Префиксное_дерево), которое позволяет искать слова не больше чем за $O(n + mk)$, где $m$ - число подходящих слов, а $k$ – длина суффикса." 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "__Задание 2 (2 балла).__ Допишите префиксное дерево для поиска слов по префиксу." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 401, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "from typing import List\n", 111 | "\n", 112 | "class PrefixTreeNode:\n", 113 | " def __init__(self):\n", 114 | " self.children: dict[str, PrefixTreeNode] = {}\n", 115 | " self.is_end_of_word = False\n", 116 | "\n", 117 | "class PrefixTree:\n", 118 | " def __init__(self, vocabulary: List[str]):\n", 119 | " \"\"\"\n", 120 | " vocabulary: список всех уникальных токенов в корпусе\n", 121 | " \"\"\"\n", 122 | " self.root = PrefixTreeNode()\n", 123 | " \n", 124 | " # your code here\n", 125 | "\n", 126 | " def search_prefix(self, prefix) -> List[str]:\n", 127 | " \"\"\"\n", 128 | " Возвращает все слова, начинающиеся на prefix\n", 129 | " prefix: str – префикс слова\n", 130 | " \"\"\"\n", 131 | "\n", 132 | " # your code here" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "vocabulary = ['aa', 'aaa', 'abb', 'bba', 'bbb', 'bcd']\n", 142 | "prefix_tree = PrefixTree(vocabulary)\n", 143 | "\n", 144 | "assert set(prefix_tree.search_prefix('a')) == set(['aa', 'aaa', 'abb'])\n", 145 | "assert set(prefix_tree.search_prefix('bb')) == set(['bba', 'bbb'])" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "Теперь, когда у нас есть способ быстро находить все слова с определенным префиксом, нам нужно их упорядочить по вероятности, чтобы выбирать лучшее. Будем оценивать вероятность слова по частоте его __встречаемости в корпусе__.\n", 153 | "\n", 154 | "__Задание 3 (2 балла).__ Допишите класс `WordCompletor`, который формирует словарь и префиксное дерево, а так же умеет находить все возможные продолжения слова вместе с их вероятностями. В этом классе вы можете при необходимости дополнительно отфильтровать слова, например, удалив все самые редкие. Постарайтесь максимально оптимизировать ваш код." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 284, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "class WordCompletor:\n", 164 | " def __init__(self, corpus):\n", 165 | " \"\"\"\n", 166 | " corpus: list – корпус текстов\n", 167 | " \"\"\"\n", 168 | " # your code here\n", 169 | " self.prefix_tree = PrefixTree()\n", 170 | "\n", 171 | " def get_words_and_probs(self, prefix: str) -> (List[str], List[float]):\n", 172 | " \"\"\"\n", 173 | " Возвращает список слов, начинающихся на prefix,\n", 174 | " с их вероятностями (нормировать ничего не нужно)\n", 175 | " \"\"\"\n", 176 | " words, probs = [], []\n", 177 | " # your code here\n", 178 | " return words, probs" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "dummy_corpus = [\n", 188 | " [\"aa\", \"ab\"],\n", 189 | " [\"aaa\", \"abab\"],\n", 190 | " [\"abb\", \"aa\", \"ab\", \"bba\", \"bbb\", \"bcd\"],\n", 191 | "]\n", 192 | "\n", 193 | "word_completor = WordCompletor(dummy_corpus)\n", 194 | "words, probs = word_completor.get_words_and_probs('a')\n", 195 | "words_probs = list(zip(words, probs))\n", 196 | "assert set(words_probs) == {('aa', 0.2), ('ab', 0.2), ('aaa', 0.1), ('abab', 0.1), ('abb', 0.1)}" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## Предсказание следующих слов\n", 204 | "\n", 205 | "Теперь, когда мы умеем дописывать слово за пользователем, мы можем пойти дальше и предожить ему следующее слово (или несколько) с учетом дописанного. Для этого мы воспользуемся n-граммной моделью.\n", 206 | "\n", 207 | "Напомним, что вероятность последовательности для такой модели записывается по формуле\n", 208 | "$$\n", 209 | "P(w_1, \\dots, w_T) = \\prod_{i=1}^T P(w_i \\mid w_{i-1}, \\dots, w_{i-n}).\n", 210 | "$$\n", 211 | "\n", 212 | "$P(w_i \\mid w_{i-1}, \\dots, w_{i-n})$ оценивается по частоте встречаемости n-граммы. \n", 213 | "\n", 214 | "__Задание 4 (2 балла).__ Напишите класс для n-граммной модели. Никакого сглаживания добавлять не надо, мы же не хотим, чтобы модель советовала случайные слова (хоть и очень редко)." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 403, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "class NGramLanguageModel:\n", 224 | " def __init__(self, corpus, n):\n", 225 | " # your code here\n", 226 | "\n", 227 | " def get_next_words_and_probs(self, prefix: list) -> (List[str], List[float]):\n", 228 | " \"\"\"\n", 229 | " Возвращает список слов, которые могут идти после prefix,\n", 230 | " а так же список вероятностей этих слов\n", 231 | " \"\"\"\n", 232 | "\n", 233 | " next_words, probs = [], []\n", 234 | " \n", 235 | " # your code here\n", 236 | "\n", 237 | " return next_words, probs" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "dummy_corpus = [\n", 247 | " ['aa', 'aa', 'aa', 'aa', 'ab'],\n", 248 | " ['aaa', 'abab'],\n", 249 | " ['abb', 'aa', 'ab', 'bba', 'bbb', 'bcd']\n", 250 | "]\n", 251 | "\n", 252 | "n_gram_model = NGramLanguageModel(corpus=dummy_corpus, n=2)\n", 253 | "\n", 254 | "next_words, probs = n_gram_model.get_next_words_and_probs(['aa', 'aa'])\n", 255 | "words_probs = list(zip(next_words, probs))\n", 256 | "\n", 257 | "assert set(words_probs) == {('aa', 2/3), ('ab', 1/3)}" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "Отлично, мы теперь можем объединить два метода в автоматический дописыватель текстов: первый будет дополнять слово, а второй – предлагать продолжения. Хочется, чтобы предлагался список возможных продолжений, из который пользователь сможет выбрать наиболее подходящее. Самое сложное тут – аккуратно выбирать, что показывать, а что нет. \n", 265 | "\n", 266 | "__Задание 5 (2 балла).__ В качестве первого подхода к снаряду реализуйте метод, возвращающий всегда самое вероятное продолжение жадным способом. После этого можно добавить опцию генерации нескольких вариантов продолжений, что сделает метод гораздо лучше." 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 443, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "class TextSuggestion:\n", 276 | " def __init__(self, word_completor, n_gram_model):\n", 277 | " self.word_completor = word_completor\n", 278 | " self.n_gram_model = n_gram_model\n", 279 | "\n", 280 | " def suggest_text(self, text: Union[str, list], n_words=3, n_texts=1) -> list[list[str]]:\n", 281 | " \"\"\"\n", 282 | " Возвращает возможные варианты продолжения текста (по умолчанию только один)\n", 283 | " \n", 284 | " text: строка или список слов – написанный пользователем текст\n", 285 | " n_words: число слов, которые дописывает n-граммная модель\n", 286 | " n_texts: число возвращаемых продолжений (пока что только одно)\n", 287 | " \n", 288 | " return: list[list[srt]] – список из n_texts списков слов, по 1 + n_words слов в каждом\n", 289 | " Первое слово – это то, которое WordCompletor дополнил до целого.\n", 290 | " \"\"\"\n", 291 | "\n", 292 | " suggestions = []\n", 293 | "\n", 294 | " # your code here\n", 295 | "\n", 296 | " return suggestions" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "dummy_corpus = [\n", 306 | " ['aa', 'aa', 'aa', 'aa', 'ab'],\n", 307 | " ['aaa', 'abab'],\n", 308 | " ['abb', 'aa', 'ab', 'bba', 'bbb', 'bcd']\n", 309 | "]\n", 310 | "\n", 311 | "word_completor = WordCompletor(dummy_corpus)\n", 312 | "n_gram_model = NGramLanguageModel(corpus=dummy_corpus, n=2)\n", 313 | "text_suggestion = TextSuggestion(word_completor, n_gram_model)\n", 314 | "\n", 315 | "assert text_suggestion.suggest_text(['aa', 'aa'], n_words=3, n_texts=1) == [['aa', 'aa', 'aa', 'aa']]\n", 316 | "assert text_suggestion.suggest_text(['abb', 'aa', 'ab'], n_words=2, n_texts=1) == [['ab', 'bba', 'bbb']]" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 450, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "text_suggestion = TextSuggestion(word_completor, n_gram_model)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "## Бонусная часть: Добавляем UI" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "Запускать ячейки в юпитере – это хорошо, но будет лучше, если вашим решением действительно можно будет пользоваться. Для этого вам предлагается добавить полноценных User Interface. Мы рекомендуем использовать для этого [reflex](https://github.com/reflex-dev/reflex). Это Python библиотека для создания web-интерфейсом с очень богатым функционалом.\n", 340 | "\n", 341 | "Ваша задача – сделать поле для текстового ввода, при наборе текста в котором будут появляться подсказки в реальном времени. Продумайте, как пользователь будет выбирать подсказки, сколько продолжений рекомендавать и так далее. В общем, должно получиться красиво и удобно. В этой части вы можете модифицировать все классы по своему усмотрению и добавлять любые эвристики. Если нужно, то дополнительно обрабатывать текст и вообще делать все, что считаете нужным. \n", 342 | "\n", 343 | "За это задание можно получить до __5-ти бонусных баллов__ в зависимости о того, насколько хорошо и удобно у вас получилось. При сдаче задания прикрепите небольшой __отчет__ (полстраницы) с описанием вашей системы, а также __видео__ (1-2 минуты) с демонстрацией работы интерфейса.\n", 344 | "\n", 345 | "Мы настоятельно рекомендуем вам оформить код в проект, а не писать в ноутбуке. Но если вам очень хочется писать тут, то хотя бы не меняйте код в предыдущих заданиях, чтобы его можно было нормально оценивать." 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [] 354 | } 355 | ], 356 | "metadata": { 357 | "kernelspec": { 358 | "display_name": "Python 3 (ipykernel)", 359 | "language": "python", 360 | "name": "python3" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 3 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython3", 372 | "version": "3.12.6" 373 | }, 374 | "notebookId": "53997d2d-afb8-4477-8874-b6d46299f06c", 375 | "notebookPath": "seminar.ipynb" 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 4 379 | } 380 | -------------------------------------------------------------------------------- /homeworks/hw2_rnn_classification/hw2_rnn_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b1acf78a", 6 | "metadata": { 7 | "id": "b1acf78a" 8 | }, 9 | "source": [ 10 | "# Глубинное обучение для текстовых данных, ФКН ВШЭ\n", 11 | "\n", 12 | "## Домашнее задание 2: Рекуррентные нейронные сети\n", 13 | "\n", 14 | "### Оценивание и штрафы\n", 15 | "\n", 16 | "Максимально допустимая оценка за работу — __10 (+3) баллов__. Сдавать задание после указанного срока сдачи нельзя.\n", 17 | "\n", 18 | "Задание выполняется самостоятельно. «Похожие» решения считаются плагиатом и все задействованные студенты (в том числе те, у кого списали) не могут получить за него больше 0 баллов. Весь код должен быть написан самостоятельно. Чужим кодом для пользоваться запрещается даже с указанием ссылки на источник. В разумных рамках, конечно. Взять пару очевидных строчек кода для реализации какого-то небольшого функционала можно.\n", 19 | "\n", 20 | "Неэффективная реализация кода может негативно отразиться на оценке. Также оценка может быть снижена за плохо читаемый код и плохо оформленные графики. Все ответы должны сопровождаться кодом или комментариями о том, как они были получены.\n", 21 | "\n", 22 | "__Мягкий дедлайн: 5.10.25 23:59__ \n", 23 | "__Жесткий дедлайн: 8.10.25 23:59__\n", 24 | "\n", 25 | "### О задании\n", 26 | "\n", 27 | "В этом задании вам предстоит самостоятельно реализовать модель LSTM для решения задачи классификации с пересекающимися классами (multi-label classification). Это вид классификации, в которой каждый объект может относиться одновременно к нескольким классам. Такая задача часто возникает при классификации фильмов по жанрам, научных или новостных статей по темам, музыкальных композиций по инструментам и так далее.\n", 28 | "\n", 29 | "В нашем случае мы будем работать с датасетом биотехнических новостей и классифицировать их по темам. Этот датасет уже предобработан: текст приведен к нижнему регистру, удалена пунктуация, все слова разделены проблелом." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "af1a5fff", 36 | "metadata": { 37 | "colab": { 38 | "base_uri": "https://localhost:8080/", 39 | "height": 206 40 | }, 41 | "id": "af1a5fff", 42 | "outputId": "891c58cd-6964-4319-ade3-92bb90356f93" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import pandas as pd\n", 47 | "\n", 48 | "dataset = pd.read_csv('biotech_news.tsv', sep='\\t')\n", 49 | "dataset.head()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "HRBZwYd9QMMS", 55 | "metadata": { 56 | "id": "HRBZwYd9QMMS" 57 | }, 58 | "source": [ 59 | "## Предобработка лейблов\n", 60 | "\n", 61 | "\n", 62 | "__Задание 1 (0.5 балла)__. Как вы можете заметить, лейблы записаны в виде строк, разделенных запятыми. Для работы с ними нам нужно преобразовать их в числа. Так как каждый объект может принадлежать нескольким классам, закодируйте лейблы в виде векторов из 0 и 1, где 1 означает, что объект принадлежит соответствующему классу, а 0 – не принадлежит. Имея такую кодировку, мы сможем обучить модель, решая задачу бинарной классификации для каждого класса." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "7c65a9bf-dbe9-4cad-978d-3a0e10b1eac1", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# your code here" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "83c0296f-9699-475e-b4bd-c9e531dca2d4", 78 | "metadata": {}, 79 | "source": [ 80 | "## Предобработка данных" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "vMe0c5AAXM8d", 86 | "metadata": { 87 | "id": "vMe0c5AAXM8d" 88 | }, 89 | "source": [ 90 | "В этом задании мы будем обучать рекуррентные нейронные сети. Как вы знаете, они работают лучше для коротких текстов, так как не очень хорошо улавливают далекие зависимости. Для уменьшение длин текстов их стоит почистить.\n", 91 | "\n", 92 | "Сразу разделим выборку на обучающую и тестовую, чтобы считать все нужные статистики только по обучающей." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "f8135000", 99 | "metadata": { 100 | "id": "f8135000" 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "from sklearn.model_selection import train_test_split\n", 105 | "\n", 106 | "texts_train, texts_test, y_train, y_test = train_test_split(\n", 107 | " ,\n", 108 | " ,\n", 109 | " test_size=0.2, # do not change this\n", 110 | " random_state=0 # do not change this\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "4ace679c-db5f-45d3-8fa3-6a5c55eb912a", 116 | "metadata": {}, 117 | "source": [ 118 | "__Задание 2 (1 балл)__. Удалите из текстов стоп слова, слишком редкие и слишком частые слова. Гиперпараметры подберите самостоятельно (в идеале их стоит подбирать по качеству на тестовой выборке). Если вы считаете, что стоит добавить еще какую-то обработку, то сделайте это. Важно не удалить ничего, что может повлиять на предсказание класса." 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "BcmyCcoaXIqy", 125 | "metadata": { 126 | "id": "BcmyCcoaXIqy" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "# your code here" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "4f4848c2-7fe1-43f9-8564-1144015fc29b", 136 | "metadata": {}, 137 | "source": [ 138 | "__Задание 3 (1.5 балла)__. Осталось перевести тексты в индексы токенов, чтобы их можно было подавать в модель. У вас есть две опции, как это сделать:\n", 139 | "1. __(+0 баллов)__ Токенизировать тексты по словам.\n", 140 | "2. __(до +3 баллов)__ Реализовать свою токенизацию BPE. Количество баллов будет варьироваться в зависимости от эффективности реализации. При реализации нельзя пользоваться специализированными библиотеками.\n", 141 | "\n", 142 | "Токенизируйте тексты, переведите их в списки индексов и сложите вместе с лейблами в `DataLoader`. Не забудьте добавить в `DataLoader` `collate_fn`, которая будет дополнять все короткие тексты в батче паддингами. Для маппинга токенов в индексы вам может пригодиться `gensim.corpora.dictionary.Dictionary`." 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "627e78f9-6bef-46f4-8b58-818b7eb0c082", 148 | "metadata": {}, 149 | "source": [ 150 | "## Метрика качества\n", 151 | "\n", 152 | "Перед тем, как приступить к обучению, нам нужно выбрать метрику оценки качества. Так как в задаче классификации с пересекающимися классами классы часто несбалансированы, чаще всего в качестве метрики берется [F1 score](https://en.wikipedia.org/wiki/F-score).\n", 153 | "\n", 154 | "Функция `compute_f1` принимает истинные метки и предсказанные и считает среднее значение F1 по всем классам. Используйте ее для оценки качества моделей.\n", 155 | "\n", 156 | "$$\n", 157 | "F1_{total} = \\frac{1}{K} \\sum_{k=1}^K F1(Y_k, \\hat{Y}_k),\n", 158 | "$$\n", 159 | "где $Y_k$ – истинные значения для класса k, а $\\hat{Y}_k$ – предсказания." 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "671a0928-fd68-4f36-bae7-2dacb18fd865", 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "from sklearn.metrics import f1_score\n", 170 | "\n", 171 | "def compute_f1(y_true, y_pred):\n", 172 | " assert y_true.ndim == 2\n", 173 | " assert y_true.shape == y_pred.shape\n", 174 | "\n", 175 | " return f1_score(y_true, y_pred, average='macro')" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "aagj29J7Ap2H", 181 | "metadata": { 182 | "id": "aagj29J7Ap2H" 183 | }, 184 | "source": [ 185 | "## Обучение моделей" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "id": "56ae5666", 191 | "metadata": { 192 | "id": "56ae5666" 193 | }, 194 | "source": [ 195 | "### RNN\n", 196 | "\n", 197 | "В качестве бейзлайна обучим самую простую рекуррентную нейронную сеть. Напомним, что блок RNN выглядит таким образом.\n", 198 | "\n", 199 | "\"drawing\"\n", 200 | "\n", 201 | "Его скрытое состояние обновляется по формуле\n", 202 | "$h_t = \\sigma(W x_{t} + U h_{t-1} + b_h)$. А предсказание считается с помощью применения линейного слоя к последнему токену\n", 203 | "$o_T = V h_T + b_o$. В качестве функции активации выберите гиперболический тангенс. \n", 204 | "\n", 205 | "__Задание 4 (2 балла)__. Реализуйте RNN в соответствии с формулой выше и обучите ее на нашу задачу. Нулевой скрытый вектор инициализируйте нулями, так модель будет обучаться стабильнее, чем при случайной инициализации. После этого замеряйте качество на тестовой выборке. У вас должно получиться значение F1 не меньше 0.33, а само обучение не должно занимать много времени." 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "05743f95-dd39-43f5-81cf-1a79edc194fa", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "# your code here" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "xqt0dk6LEJUU", 221 | "metadata": { 222 | "id": "xqt0dk6LEJUU" 223 | }, 224 | "source": [ 225 | "### LSTM\n", 226 | "\n", 227 | "\"drawing\"\n", 228 | "\n", 229 | "Теперь перейдем к более продвинутым рекурренным моделям, а именно LSTM. Из-за дополнительного вектора памяти эта модель должна гораздо лучше улавливать далекие зависимости, что должно напрямую отражаться на качестве.\n", 230 | "\n", 231 | "Параметры блока LSTM обновляются вот так ($\\sigma$ означает сигмоиду):\n", 232 | "\\begin{align}\n", 233 | "f_{t} &= \\sigma(W_f x_{t} + U_f h_{t-1} + b_f) \\\\ \n", 234 | "i_{t} &= \\sigma(W_i x_{t} + U_i h_{t-1} + b_i) \\\\\n", 235 | "\\tilde{c}_{t} &= \\tanh(W_c x_{t} + U_c h_{t-1} + b_i) \\\\\n", 236 | "c_{t} &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c}_t \\\\\n", 237 | "o_{t} &= \\sigma(W_t x_{t} + U_t h_{t-1} + b_t) \\\\\n", 238 | "h_t &= o_t \\odot \\tanh(c_t)\n", 239 | "\\end{align}\n", 240 | "\n", 241 | "__Задание 5 (2 балла).__ Реализуйте LSTM по описанной схеме. Выберите гиперпараметры LSTM так, чтобы их общее число (без учета слоя эмбеддингов) примерно совпадало с числом параметров обычной RNN, но размерность скрытого слоя была не меньше 64. Так мы будем сравнивать архитектуры максимально независимо. Обучите LSTM до сходимости и сравните качество с RNN на тестовой выборке. Удалось ли получить лучший результат? Как вы можете это объяснить?" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "6e18b79b-f2c6-4474-a5c0-c8ce51f13afb", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "# your code here" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "abfdd1a1-51f4-4065-85f7-22ed773a2628", 257 | "metadata": {}, 258 | "source": [ 259 | "__Задание 6 (2 балла).__ Главный недостаток RNN моделей заключается в том, что при сжатии всей информации в один вектор, важные детали пропадают. Для решения этой проблемы был придуман механизм внимания. Реализуйте его по [оригинальной статье](https://arxiv.org/abs/1409.0473). Замерьте качество и сделайте выводы. \n", 260 | "Обратите внимание, что метод был предложен для Encoder-Decoder моделей. В нашем случае декодера нет, поэтому встройте внимание в энкодер: каждый блок LSTM будет смотреть на выходы всех предыдущих блоков. " 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "b5bd1fa9-2c1f-4268-be24-5c31752204ac", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "# your code here" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "id": "phQ-ka4mp0oS", 276 | "metadata": { 277 | "id": "phQ-ka4mp0oS" 278 | }, 279 | "source": [ 280 | "__Задание 7 (1 балл).__ Добавьте в вашу реализации возможность увеличивать число слоев LSTM. Обучите модель с двумя слоями и замерьте качество. Сделайте выводы: стоит ли увеличивать размер модели?" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "id": "ee7c7177", 287 | "metadata": { 288 | "id": "ee7c7177" 289 | }, 290 | "outputs": [], 291 | "source": [ 292 | "# your code here" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "accelerator": "GPU", 298 | "colab": { 299 | "gpuType": "T4", 300 | "provenance": [] 301 | }, 302 | "kernelspec": { 303 | "display_name": "Python 3 (ipykernel)", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.12.6" 318 | }, 319 | "widgets": { 320 | "application/vnd.jupyter.widget-state+json": { 321 | "12b0627d4aaf46c0adc64b442bf88d0a": { 322 | "model_module": "@jupyter-widgets/controls", 323 | "model_module_version": "1.5.0", 324 | "model_name": "DescriptionStyleModel", 325 | "state": { 326 | "_model_module": "@jupyter-widgets/controls", 327 | "_model_module_version": "1.5.0", 328 | "_model_name": "DescriptionStyleModel", 329 | "_view_count": null, 330 | "_view_module": "@jupyter-widgets/base", 331 | "_view_module_version": "1.2.0", 332 | "_view_name": "StyleView", 333 | "description_width": "" 334 | } 335 | }, 336 | "1d5b2e090c51406e953b4eec4b0b91ad": { 337 | "model_module": "@jupyter-widgets/base", 338 | "model_module_version": "1.2.0", 339 | "model_name": "LayoutModel", 340 | "state": { 341 | "_model_module": "@jupyter-widgets/base", 342 | "_model_module_version": "1.2.0", 343 | "_model_name": "LayoutModel", 344 | "_view_count": null, 345 | "_view_module": "@jupyter-widgets/base", 346 | "_view_module_version": "1.2.0", 347 | "_view_name": "LayoutView", 348 | "align_content": null, 349 | "align_items": null, 350 | "align_self": null, 351 | "border": null, 352 | "bottom": null, 353 | "display": null, 354 | "flex": null, 355 | "flex_flow": null, 356 | "grid_area": null, 357 | "grid_auto_columns": null, 358 | "grid_auto_flow": null, 359 | "grid_auto_rows": null, 360 | "grid_column": null, 361 | "grid_gap": null, 362 | "grid_row": null, 363 | "grid_template_areas": null, 364 | "grid_template_columns": null, 365 | "grid_template_rows": null, 366 | "height": null, 367 | "justify_content": null, 368 | "justify_items": null, 369 | "left": null, 370 | "margin": null, 371 | "max_height": null, 372 | "max_width": null, 373 | "min_height": null, 374 | "min_width": null, 375 | "object_fit": null, 376 | "object_position": null, 377 | "order": null, 378 | "overflow": null, 379 | "overflow_x": null, 380 | "overflow_y": null, 381 | "padding": null, 382 | "right": null, 383 | "top": null, 384 | "visibility": null, 385 | "width": null 386 | } 387 | }, 388 | "282f83858a424e2ea76990eb957dc5a0": { 389 | "model_module": "@jupyter-widgets/base", 390 | "model_module_version": "1.2.0", 391 | "model_name": "LayoutModel", 392 | "state": { 393 | "_model_module": "@jupyter-widgets/base", 394 | "_model_module_version": "1.2.0", 395 | "_model_name": "LayoutModel", 396 | "_view_count": null, 397 | "_view_module": "@jupyter-widgets/base", 398 | "_view_module_version": "1.2.0", 399 | "_view_name": "LayoutView", 400 | "align_content": null, 401 | "align_items": null, 402 | "align_self": null, 403 | "border": null, 404 | "bottom": null, 405 | "display": null, 406 | "flex": null, 407 | "flex_flow": null, 408 | "grid_area": null, 409 | "grid_auto_columns": null, 410 | "grid_auto_flow": null, 411 | "grid_auto_rows": null, 412 | "grid_column": null, 413 | "grid_gap": null, 414 | "grid_row": null, 415 | "grid_template_areas": null, 416 | "grid_template_columns": null, 417 | "grid_template_rows": null, 418 | "height": null, 419 | "justify_content": null, 420 | "justify_items": null, 421 | "left": null, 422 | "margin": null, 423 | "max_height": null, 424 | "max_width": null, 425 | "min_height": null, 426 | "min_width": null, 427 | "object_fit": null, 428 | "object_position": null, 429 | "order": null, 430 | "overflow": null, 431 | "overflow_x": null, 432 | "overflow_y": null, 433 | "padding": null, 434 | "right": null, 435 | "top": null, 436 | "visibility": null, 437 | "width": null 438 | } 439 | }, 440 | "32808478ae8c4242beb79f0272ea6b1f": { 441 | "model_module": "@jupyter-widgets/base", 442 | "model_module_version": "1.2.0", 443 | "model_name": "LayoutModel", 444 | "state": { 445 | "_model_module": "@jupyter-widgets/base", 446 | "_model_module_version": "1.2.0", 447 | "_model_name": "LayoutModel", 448 | "_view_count": null, 449 | "_view_module": "@jupyter-widgets/base", 450 | "_view_module_version": "1.2.0", 451 | "_view_name": "LayoutView", 452 | "align_content": null, 453 | "align_items": null, 454 | "align_self": null, 455 | "border": null, 456 | "bottom": null, 457 | "display": null, 458 | "flex": null, 459 | "flex_flow": null, 460 | "grid_area": null, 461 | "grid_auto_columns": null, 462 | "grid_auto_flow": null, 463 | "grid_auto_rows": null, 464 | "grid_column": null, 465 | "grid_gap": null, 466 | "grid_row": null, 467 | "grid_template_areas": null, 468 | "grid_template_columns": null, 469 | "grid_template_rows": null, 470 | "height": null, 471 | "justify_content": null, 472 | "justify_items": null, 473 | "left": null, 474 | "margin": null, 475 | "max_height": null, 476 | "max_width": null, 477 | "min_height": null, 478 | "min_width": null, 479 | "object_fit": null, 480 | "object_position": null, 481 | "order": null, 482 | "overflow": null, 483 | "overflow_x": null, 484 | "overflow_y": null, 485 | "padding": null, 486 | "right": null, 487 | "top": null, 488 | "visibility": null, 489 | "width": null 490 | } 491 | }, 492 | "34e8d1401c0e4dc1a8e71bbad7c2f74d": { 493 | "model_module": "@jupyter-widgets/controls", 494 | "model_module_version": "1.5.0", 495 | "model_name": "HTMLModel", 496 | "state": { 497 | "_dom_classes": [], 498 | "_model_module": "@jupyter-widgets/controls", 499 | "_model_module_version": "1.5.0", 500 | "_model_name": "HTMLModel", 501 | "_view_count": null, 502 | "_view_module": "@jupyter-widgets/controls", 503 | "_view_module_version": "1.5.0", 504 | "_view_name": "HTMLView", 505 | "description": "", 506 | "description_tooltip": null, 507 | "layout": "IPY_MODEL_b23f3b8b7247491c8d5e3ead7f54d886", 508 | "placeholder": "​", 509 | "style": "IPY_MODEL_cb632291897f4f9db86a00a5a71ca35f", 510 | "value": " 40/40 [36:41<00:00, 51.61s/it]" 511 | } 512 | }, 513 | "3735627f227d4b4f927955113111409f": { 514 | "model_module": "@jupyter-widgets/controls", 515 | "model_module_version": "1.5.0", 516 | "model_name": "DescriptionStyleModel", 517 | "state": { 518 | "_model_module": "@jupyter-widgets/controls", 519 | "_model_module_version": "1.5.0", 520 | "_model_name": "DescriptionStyleModel", 521 | "_view_count": null, 522 | "_view_module": "@jupyter-widgets/base", 523 | "_view_module_version": "1.2.0", 524 | "_view_name": "StyleView", 525 | "description_width": "" 526 | } 527 | }, 528 | "47f4f11bc6984b96ac3c3875d733f0ba": { 529 | "model_module": "@jupyter-widgets/controls", 530 | "model_module_version": "1.5.0", 531 | "model_name": "FloatProgressModel", 532 | "state": { 533 | "_dom_classes": [], 534 | "_model_module": "@jupyter-widgets/controls", 535 | "_model_module_version": "1.5.0", 536 | "_model_name": "FloatProgressModel", 537 | "_view_count": null, 538 | "_view_module": "@jupyter-widgets/controls", 539 | "_view_module_version": "1.5.0", 540 | "_view_name": "ProgressView", 541 | "bar_style": "success", 542 | "description": "", 543 | "description_tooltip": null, 544 | "layout": "IPY_MODEL_dc4f687f9d5940aba074e2bb41581c93", 545 | "max": 40, 546 | "min": 0, 547 | "orientation": "horizontal", 548 | "style": "IPY_MODEL_6e10fd6d1a6c47a9ac34a47ae5ba708b", 549 | "value": 40 550 | } 551 | }, 552 | "4aab16bb20824688aadbd23460adad9b": { 553 | "model_module": "@jupyter-widgets/controls", 554 | "model_module_version": "1.5.0", 555 | "model_name": "HBoxModel", 556 | "state": { 557 | "_dom_classes": [], 558 | "_model_module": "@jupyter-widgets/controls", 559 | "_model_module_version": "1.5.0", 560 | "_model_name": "HBoxModel", 561 | "_view_count": null, 562 | "_view_module": "@jupyter-widgets/controls", 563 | "_view_module_version": "1.5.0", 564 | "_view_name": "HBoxView", 565 | "box_style": "", 566 | "children": [ 567 | "IPY_MODEL_f65eec1b45de42e59fb9e24b99aad917", 568 | "IPY_MODEL_47f4f11bc6984b96ac3c3875d733f0ba", 569 | "IPY_MODEL_f58fddb1bf414071b0523701a619ad71" 570 | ], 571 | "layout": "IPY_MODEL_32808478ae8c4242beb79f0272ea6b1f" 572 | } 573 | }, 574 | "4de9492961d841aa9f3d7bc629911296": { 575 | "model_module": "@jupyter-widgets/base", 576 | "model_module_version": "1.2.0", 577 | "model_name": "LayoutModel", 578 | "state": { 579 | "_model_module": "@jupyter-widgets/base", 580 | "_model_module_version": "1.2.0", 581 | "_model_name": "LayoutModel", 582 | "_view_count": null, 583 | "_view_module": "@jupyter-widgets/base", 584 | "_view_module_version": "1.2.0", 585 | "_view_name": "LayoutView", 586 | "align_content": null, 587 | "align_items": null, 588 | "align_self": null, 589 | "border": null, 590 | "bottom": null, 591 | "display": null, 592 | "flex": null, 593 | "flex_flow": null, 594 | "grid_area": null, 595 | "grid_auto_columns": null, 596 | "grid_auto_flow": null, 597 | "grid_auto_rows": null, 598 | "grid_column": null, 599 | "grid_gap": null, 600 | "grid_row": null, 601 | "grid_template_areas": null, 602 | "grid_template_columns": null, 603 | "grid_template_rows": null, 604 | "height": null, 605 | "justify_content": null, 606 | "justify_items": null, 607 | "left": null, 608 | "margin": null, 609 | "max_height": null, 610 | "max_width": null, 611 | "min_height": null, 612 | "min_width": null, 613 | "object_fit": null, 614 | "object_position": null, 615 | "order": null, 616 | "overflow": null, 617 | "overflow_x": null, 618 | "overflow_y": null, 619 | "padding": null, 620 | "right": null, 621 | "top": null, 622 | "visibility": null, 623 | "width": null 624 | } 625 | }, 626 | "67ae0c089c4a426db3b52976fae1a9dc": { 627 | "model_module": "@jupyter-widgets/base", 628 | "model_module_version": "1.2.0", 629 | "model_name": "LayoutModel", 630 | "state": { 631 | "_model_module": "@jupyter-widgets/base", 632 | "_model_module_version": "1.2.0", 633 | "_model_name": "LayoutModel", 634 | "_view_count": null, 635 | "_view_module": "@jupyter-widgets/base", 636 | "_view_module_version": "1.2.0", 637 | "_view_name": "LayoutView", 638 | "align_content": null, 639 | "align_items": null, 640 | "align_self": null, 641 | "border": null, 642 | "bottom": null, 643 | "display": null, 644 | "flex": null, 645 | "flex_flow": null, 646 | "grid_area": null, 647 | "grid_auto_columns": null, 648 | "grid_auto_flow": null, 649 | "grid_auto_rows": null, 650 | "grid_column": null, 651 | "grid_gap": null, 652 | "grid_row": null, 653 | "grid_template_areas": null, 654 | "grid_template_columns": null, 655 | "grid_template_rows": null, 656 | "height": null, 657 | "justify_content": null, 658 | "justify_items": null, 659 | "left": null, 660 | "margin": null, 661 | "max_height": null, 662 | "max_width": null, 663 | "min_height": null, 664 | "min_width": null, 665 | "object_fit": null, 666 | "object_position": null, 667 | "order": null, 668 | "overflow": null, 669 | "overflow_x": null, 670 | "overflow_y": null, 671 | "padding": null, 672 | "right": null, 673 | "top": null, 674 | "visibility": null, 675 | "width": null 676 | } 677 | }, 678 | "6e10fd6d1a6c47a9ac34a47ae5ba708b": { 679 | "model_module": "@jupyter-widgets/controls", 680 | "model_module_version": "1.5.0", 681 | "model_name": "ProgressStyleModel", 682 | "state": { 683 | "_model_module": "@jupyter-widgets/controls", 684 | "_model_module_version": "1.5.0", 685 | "_model_name": "ProgressStyleModel", 686 | "_view_count": null, 687 | "_view_module": "@jupyter-widgets/base", 688 | "_view_module_version": "1.2.0", 689 | "_view_name": "StyleView", 690 | "bar_color": null, 691 | "description_width": "" 692 | } 693 | }, 694 | "b23f3b8b7247491c8d5e3ead7f54d886": { 695 | "model_module": "@jupyter-widgets/base", 696 | "model_module_version": "1.2.0", 697 | "model_name": "LayoutModel", 698 | "state": { 699 | "_model_module": "@jupyter-widgets/base", 700 | "_model_module_version": "1.2.0", 701 | "_model_name": "LayoutModel", 702 | "_view_count": null, 703 | "_view_module": "@jupyter-widgets/base", 704 | "_view_module_version": "1.2.0", 705 | "_view_name": "LayoutView", 706 | "align_content": null, 707 | "align_items": null, 708 | "align_self": null, 709 | "border": null, 710 | "bottom": null, 711 | "display": null, 712 | "flex": null, 713 | "flex_flow": null, 714 | "grid_area": null, 715 | "grid_auto_columns": null, 716 | "grid_auto_flow": null, 717 | "grid_auto_rows": null, 718 | "grid_column": null, 719 | "grid_gap": null, 720 | "grid_row": null, 721 | "grid_template_areas": null, 722 | "grid_template_columns": null, 723 | "grid_template_rows": null, 724 | "height": null, 725 | "justify_content": null, 726 | "justify_items": null, 727 | "left": null, 728 | "margin": null, 729 | "max_height": null, 730 | "max_width": null, 731 | "min_height": null, 732 | "min_width": null, 733 | "object_fit": null, 734 | "object_position": null, 735 | "order": null, 736 | "overflow": null, 737 | "overflow_x": null, 738 | "overflow_y": null, 739 | "padding": null, 740 | "right": null, 741 | "top": null, 742 | "visibility": null, 743 | "width": null 744 | } 745 | }, 746 | "bc4165ff8fc3480fb1590b6ecd39fb4f": { 747 | "model_module": "@jupyter-widgets/controls", 748 | "model_module_version": "1.5.0", 749 | "model_name": "HBoxModel", 750 | "state": { 751 | "_dom_classes": [], 752 | "_model_module": "@jupyter-widgets/controls", 753 | "_model_module_version": "1.5.0", 754 | "_model_name": "HBoxModel", 755 | "_view_count": null, 756 | "_view_module": "@jupyter-widgets/controls", 757 | "_view_module_version": "1.5.0", 758 | "_view_name": "HBoxView", 759 | "box_style": "", 760 | "children": [ 761 | "IPY_MODEL_cba16e32a9df4b1b89b4f7066945fc42", 762 | "IPY_MODEL_e8f0522f19c44066b5a78ded999f050a", 763 | "IPY_MODEL_34e8d1401c0e4dc1a8e71bbad7c2f74d" 764 | ], 765 | "layout": "IPY_MODEL_282f83858a424e2ea76990eb957dc5a0" 766 | } 767 | }, 768 | "cb632291897f4f9db86a00a5a71ca35f": { 769 | "model_module": "@jupyter-widgets/controls", 770 | "model_module_version": "1.5.0", 771 | "model_name": "DescriptionStyleModel", 772 | "state": { 773 | "_model_module": "@jupyter-widgets/controls", 774 | "_model_module_version": "1.5.0", 775 | "_model_name": "DescriptionStyleModel", 776 | "_view_count": null, 777 | "_view_module": "@jupyter-widgets/base", 778 | "_view_module_version": "1.2.0", 779 | "_view_name": "StyleView", 780 | "description_width": "" 781 | } 782 | }, 783 | "cba16e32a9df4b1b89b4f7066945fc42": { 784 | "model_module": "@jupyter-widgets/controls", 785 | "model_module_version": "1.5.0", 786 | "model_name": "HTMLModel", 787 | "state": { 788 | "_dom_classes": [], 789 | "_model_module": "@jupyter-widgets/controls", 790 | "_model_module_version": "1.5.0", 791 | "_model_name": "HTMLModel", 792 | "_view_count": null, 793 | "_view_module": "@jupyter-widgets/controls", 794 | "_view_module_version": "1.5.0", 795 | "_view_name": "HTMLView", 796 | "description": "", 797 | "description_tooltip": null, 798 | "layout": "IPY_MODEL_67ae0c089c4a426db3b52976fae1a9dc", 799 | "placeholder": "​", 800 | "style": "IPY_MODEL_12b0627d4aaf46c0adc64b442bf88d0a", 801 | "value": "100%" 802 | } 803 | }, 804 | "d7ed88f49793494bbdb3c2fffc01b216": { 805 | "model_module": "@jupyter-widgets/controls", 806 | "model_module_version": "1.5.0", 807 | "model_name": "DescriptionStyleModel", 808 | "state": { 809 | "_model_module": "@jupyter-widgets/controls", 810 | "_model_module_version": "1.5.0", 811 | "_model_name": "DescriptionStyleModel", 812 | "_view_count": null, 813 | "_view_module": "@jupyter-widgets/base", 814 | "_view_module_version": "1.2.0", 815 | "_view_name": "StyleView", 816 | "description_width": "" 817 | } 818 | }, 819 | "dc4f687f9d5940aba074e2bb41581c93": { 820 | "model_module": "@jupyter-widgets/base", 821 | "model_module_version": "1.2.0", 822 | "model_name": "LayoutModel", 823 | "state": { 824 | "_model_module": "@jupyter-widgets/base", 825 | "_model_module_version": "1.2.0", 826 | "_model_name": "LayoutModel", 827 | "_view_count": null, 828 | "_view_module": "@jupyter-widgets/base", 829 | "_view_module_version": "1.2.0", 830 | "_view_name": "LayoutView", 831 | "align_content": null, 832 | "align_items": null, 833 | "align_self": null, 834 | "border": null, 835 | "bottom": null, 836 | "display": null, 837 | "flex": null, 838 | "flex_flow": null, 839 | "grid_area": null, 840 | "grid_auto_columns": null, 841 | "grid_auto_flow": null, 842 | "grid_auto_rows": null, 843 | "grid_column": null, 844 | "grid_gap": null, 845 | "grid_row": null, 846 | "grid_template_areas": null, 847 | "grid_template_columns": null, 848 | "grid_template_rows": null, 849 | "height": null, 850 | "justify_content": null, 851 | "justify_items": null, 852 | "left": null, 853 | "margin": null, 854 | "max_height": null, 855 | "max_width": null, 856 | "min_height": null, 857 | "min_width": null, 858 | "object_fit": null, 859 | "object_position": null, 860 | "order": null, 861 | "overflow": null, 862 | "overflow_x": null, 863 | "overflow_y": null, 864 | "padding": null, 865 | "right": null, 866 | "top": null, 867 | "visibility": null, 868 | "width": null 869 | } 870 | }, 871 | "e7876fd73da349ea873c137c63d8d528": { 872 | "model_module": "@jupyter-widgets/controls", 873 | "model_module_version": "1.5.0", 874 | "model_name": "ProgressStyleModel", 875 | "state": { 876 | "_model_module": "@jupyter-widgets/controls", 877 | "_model_module_version": "1.5.0", 878 | "_model_name": "ProgressStyleModel", 879 | "_view_count": null, 880 | "_view_module": "@jupyter-widgets/base", 881 | "_view_module_version": "1.2.0", 882 | "_view_name": "StyleView", 883 | "bar_color": null, 884 | "description_width": "" 885 | } 886 | }, 887 | "e8f0522f19c44066b5a78ded999f050a": { 888 | "model_module": "@jupyter-widgets/controls", 889 | "model_module_version": "1.5.0", 890 | "model_name": "FloatProgressModel", 891 | "state": { 892 | "_dom_classes": [], 893 | "_model_module": "@jupyter-widgets/controls", 894 | "_model_module_version": "1.5.0", 895 | "_model_name": "FloatProgressModel", 896 | "_view_count": null, 897 | "_view_module": "@jupyter-widgets/controls", 898 | "_view_module_version": "1.5.0", 899 | "_view_name": "ProgressView", 900 | "bar_style": "success", 901 | "description": "", 902 | "description_tooltip": null, 903 | "layout": "IPY_MODEL_4de9492961d841aa9f3d7bc629911296", 904 | "max": 40, 905 | "min": 0, 906 | "orientation": "horizontal", 907 | "style": "IPY_MODEL_e7876fd73da349ea873c137c63d8d528", 908 | "value": 40 909 | } 910 | }, 911 | "f58fddb1bf414071b0523701a619ad71": { 912 | "model_module": "@jupyter-widgets/controls", 913 | "model_module_version": "1.5.0", 914 | "model_name": "HTMLModel", 915 | "state": { 916 | "_dom_classes": [], 917 | "_model_module": "@jupyter-widgets/controls", 918 | "_model_module_version": "1.5.0", 919 | "_model_name": "HTMLModel", 920 | "_view_count": null, 921 | "_view_module": "@jupyter-widgets/controls", 922 | "_view_module_version": "1.5.0", 923 | "_view_name": "HTMLView", 924 | "description": "", 925 | "description_tooltip": null, 926 | "layout": "IPY_MODEL_1d5b2e090c51406e953b4eec4b0b91ad", 927 | "placeholder": "​", 928 | "style": "IPY_MODEL_3735627f227d4b4f927955113111409f", 929 | "value": " 40/40 [1:08:10<00:00, 102.17s/it]" 930 | } 931 | }, 932 | "f65eec1b45de42e59fb9e24b99aad917": { 933 | "model_module": "@jupyter-widgets/controls", 934 | "model_module_version": "1.5.0", 935 | "model_name": "HTMLModel", 936 | "state": { 937 | "_dom_classes": [], 938 | "_model_module": "@jupyter-widgets/controls", 939 | "_model_module_version": "1.5.0", 940 | "_model_name": "HTMLModel", 941 | "_view_count": null, 942 | "_view_module": "@jupyter-widgets/controls", 943 | "_view_module_version": "1.5.0", 944 | "_view_name": "HTMLView", 945 | "description": "", 946 | "description_tooltip": null, 947 | "layout": "IPY_MODEL_f67dc08a01ac40ad98ed553fe6b7e948", 948 | "placeholder": "​", 949 | "style": "IPY_MODEL_d7ed88f49793494bbdb3c2fffc01b216", 950 | "value": "100%" 951 | } 952 | }, 953 | "f67dc08a01ac40ad98ed553fe6b7e948": { 954 | "model_module": "@jupyter-widgets/base", 955 | "model_module_version": "1.2.0", 956 | "model_name": "LayoutModel", 957 | "state": { 958 | "_model_module": "@jupyter-widgets/base", 959 | "_model_module_version": "1.2.0", 960 | "_model_name": "LayoutModel", 961 | "_view_count": null, 962 | "_view_module": "@jupyter-widgets/base", 963 | "_view_module_version": "1.2.0", 964 | "_view_name": "LayoutView", 965 | "align_content": null, 966 | "align_items": null, 967 | "align_self": null, 968 | "border": null, 969 | "bottom": null, 970 | "display": null, 971 | "flex": null, 972 | "flex_flow": null, 973 | "grid_area": null, 974 | "grid_auto_columns": null, 975 | "grid_auto_flow": null, 976 | "grid_auto_rows": null, 977 | "grid_column": null, 978 | "grid_gap": null, 979 | "grid_row": null, 980 | "grid_template_areas": null, 981 | "grid_template_columns": null, 982 | "grid_template_rows": null, 983 | "height": null, 984 | "justify_content": null, 985 | "justify_items": null, 986 | "left": null, 987 | "margin": null, 988 | "max_height": null, 989 | "max_width": null, 990 | "min_height": null, 991 | "min_width": null, 992 | "object_fit": null, 993 | "object_position": null, 994 | "order": null, 995 | "overflow": null, 996 | "overflow_x": null, 997 | "overflow_y": null, 998 | "padding": null, 999 | "right": null, 1000 | "top": null, 1001 | "visibility": null, 1002 | "width": null 1003 | } 1004 | } 1005 | } 1006 | } 1007 | }, 1008 | "nbformat": 4, 1009 | "nbformat_minor": 5 1010 | } 1011 | -------------------------------------------------------------------------------- /week01_text_classification/lecture1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashaba1in/hse-nlp/ec90e3ea026aa44932d06c4f36c5df4779c4ca06/week01_text_classification/lecture1.pdf -------------------------------------------------------------------------------- /week02_generation/lecture2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashaba1in/hse-nlp/ec90e3ea026aa44932d06c4f36c5df4779c4ca06/week02_generation/lecture2.pdf -------------------------------------------------------------------------------- /week02_generation/seminar2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Cеминар 2. Генерация текста\n", 8 | "\n", 9 | "\n", 10 | "#### План\n", 11 | "\n", 12 | "1. Токенизация\n", 13 | "2. RNN\n", 14 | "3. Метрики качества генерации" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "/home/echimbulatov/miniconda3/envs/cosmos/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 37 | " from .autonotebook import tqdm as notebook_tqdm\n" 38 | ] 39 | }, 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "device(type='cuda')" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import re\n", 53 | "import pandas as pd\n", 54 | "import numpy as np\n", 55 | "import torch\n", 56 | "from tokenizers import Tokenizer, models, trainers\n", 57 | "from tokenizers.normalizers import Lowercase\n", 58 | "from tokenizers.pre_tokenizers import Whitespace\n", 59 | "from tokenizers.processors import TemplateProcessing\n", 60 | "from tokenizers.decoders import WordPiece as WordPieceDecoder\n", 61 | "import matplotlib.pyplot as plt\n", 62 | "from tqdm.auto import tqdm\n", 63 | "from functools import partial\n", 64 | "from collections import Counter\n", 65 | "\n", 66 | "\n", 67 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 68 | "device" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Токенизация\n", 76 | "\n", 77 | "Токенизация – процесс разбиения текста на подстроки (токены). Токенизация в первую очередь нужна для уменьшения размера словаря. При этом при токенизации мы хотим, чтобы все токены были максимально репрезентативными. Таким образом, мы балансируем между размером словаря и репрезентативностью токенов." 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "from tokenizers import Tokenizer, models\n", 87 | "\n", 88 | "tokenizer = Tokenizer(models.BPE())" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "Токенизация включает в себя несколько компонент:\n", 96 | "\n", 97 | "1. **Нормализация** - предварительную очистку текста (приведение к нижнему регистру, замена unicode символов на ascii и т.д.)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "'hello how are u?'" 109 | ] 110 | }, 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "from tokenizers import normalizers\n", 118 | "from tokenizers.normalizers import NFD, StripAccents, Lowercase, Strip\n", 119 | "\n", 120 | "normalizer = normalizers.Sequence([NFD(), StripAccents(), Lowercase(), Strip()])\n", 121 | "normalizer.normalize_str(\" Héllò hôw are ü?\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "tokenizer.normalizer = normalizer" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "2. **Предварительная токенизация** — процесс разбиения текста на более мелкие фрагменты, задающие верхнюю границу того, какими будут токены в токенизации. Обычно предварительная токенизация разобивает текст на слова. Тогда итоговые токены будут частями этих слов." 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "[('Hello', (0, 5)),\n", 149 | " ('!', (5, 6)),\n", 150 | " ('How', (7, 10)),\n", 151 | " ('are', (11, 14)),\n", 152 | " ('you', (15, 18)),\n", 153 | " ('?', (18, 19)),\n", 154 | " ('I', (20, 21)),\n", 155 | " (\"'\", (21, 22)),\n", 156 | " ('m', (22, 23)),\n", 157 | " ('fine', (24, 28)),\n", 158 | " (',', (28, 29)),\n", 159 | " ('thank', (30, 35)),\n", 160 | " ('you', (36, 39)),\n", 161 | " ('.', (39, 40))]" 162 | ] 163 | }, 164 | "execution_count": 6, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "from tokenizers.pre_tokenizers import Whitespace\n", 171 | "\n", 172 | "pre_tokenizer = Whitespace()\n", 173 | "pre_tokenizer.pre_tokenize_str(\"Hello! How are you? I'm fine, thank you.\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "Заметьте, что строка \"I'm\" разбилась на [\"I\", \"'\", \"m\"]. Это может быть не очень хорошо, учитывая, что \"'m\" – то же самое, что \"am\". Для более \"правильного\" разбиения можно использовать претокенизацию на основе правил. Самые популярные инструменты для этого – `spaCy` и `Moses`." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 7, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "tokenizer.pre_tokenizer = pre_tokenizer" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "3. **Модель** – главная часть токенизации. Модель необходимо обучать на корпусе текстов, чтобы получить словарь, подходящий под конкретные данные. Она применяется к результату претокенизации. И ее задача – разбить \"слова\" на более мелкие составляющие согласно выученным правилам.\n", 197 | "\n", 198 | "Наиболее распространенные виды моделей: `BPE`, `WordPiece`, `Unigram`." 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "#### Byte-Pair Encoding (BPE)\n", 206 | "\n", 207 | "__Алгоритм обучения:__\n", 208 | "\n", 209 | "1. Считаем, сколько раз каждое слово встречается в корпусе\n", 210 | "2. Создаем словарь токенов, который пока что состоит из всех уникальных символов\n", 211 | "3. Находим пару токенов, которая чаще всего встречается вместе. Склеиваем ее в токен и добавляем в словарь.\n", 212 | "4. Повторяем 3, пока словарь не достигнет нужного размера.\n", 213 | "\n", 214 | "\n", 215 | "__Пример:__\n", 216 | "\n", 217 | "Корпус слов со встречаемостями: (`(\"hug\", 2)`, `(\"pug\", 1)`, `(\"pun\", 4)`, `(\"bun\", 3)`, `(\"hugs\", 1)`) \n", 218 | "Максимальный размер словаря – __9__.\n", 219 | "\n", 220 | "__Шаг 1.__ \n", 221 | "Словарь: [`\"b\"`, `\"g\"`, `\"h\"`, `\"n\"`, `\"p\"`, `\"s\"`, `\"u\"`], размер: __7__. \n", 222 | "Корпус: [`(\"h\" \"u\" \"g\", 3)`, `(\"p\" \"u\" \"g\", 1)`, `(\"p\" \"u\" \"n\", 3)`, `(\"b\" \"u\" \"n\", 3)`, `(\"h\" \"u\" \"g\" \"s\", 1)`]\n", 223 | "\n", 224 | "Чаще всего встречается пара (`\"u\" \"n\"`) – 6 раз. Добавляем ее в словарь и обновляем корпус.\n", 225 | "\n", 226 | "__Шаг 2.__ \n", 227 | "Словарь: [`\"b\"`, `\"g\"`, `\"h\"`, `\"n\"`, `\"p\"`, `\"s\"`, `\"u\"`, `\"un\"`], размер: __8__. \n", 228 | "Корпус: [`(\"h\" \"u\" \"g\", 3)`, `(\"p\" \"u\" \"g\", 1)`, `(\"p\" \"un\", 3)`, `(\"b\" \"un\", 3)`, `(\"h\" \"u\" \"g\" \"s\", 1)`]\n", 229 | "\n", 230 | "Чаще всего встречается пара (`\"u\" \"g\"`) – 5 раз. Добавляем ее в словарь и обновляем корпус.\n", 231 | "\n", 232 | "__Шаг 3.__ \n", 233 | "Словарь: [`\"b\"`, `\"g\"`, `\"h\"`, `\"n\"`, `\"p\"`, `\"s\"`, `\"u\"`, `\"un\"`, `\"ug\"`], размер: __9__. \n", 234 | "Корпус: [`(\"h\" \"ug\", 3)`, `(\"p\" \"ug\", 1)`, `(\"p\" \"un\", 3)`, `(\"b\" \"un\", 3)`, `(\"h\" \"ug\" \"s\", 1)`]\n", 235 | "\n", 236 | "\n", 237 | "__Процесс токенизации:__\n", 238 | "\n", 239 | "1. Нормализация\n", 240 | "1. Предварительная токенизация\n", 241 | "1. Разделение слов на отдельные символы\n", 242 | "1. Применение правил слияния (в порядке их появления в словаре) к разделенным словам \n" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "#### Byte-level BPE\n", 250 | "\n", 251 | "Если вы собираетесь обучать очень большую модель на всем интернете, то вы можете столкнуться с ситуацией, в которой на тестовой выборке попадаются токены не из словаря (например, эмоджи). В таком случае токены придется менять на UNK и терять эту информацию.\n", 252 | "\n", 253 | "Токенизаторы GPT-2 и RoBERTa (которые довольно похожи) имеют умный способ решения этой проблемы: они рассматривают слова не как символы Unicode, а как байты. Таким образом, базовый словарь имеет небольшой размер (256), но все символы, которые вы можете придумать, все равно будут включены и не будут преобразованы в неизвестный токен. Этот трюк называется byte-level BPE." 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "#### WordPiece\n", 261 | "\n", 262 | "WordPiece работает точно так же, как и BPE, за исключением выбора токенов для слияния. BPE не учитывает частоту встречаемости каждого из токенов по отдельности. WordPiece же исправляет это и вводит score, по которому оценивает необходимоть слияния двух токенов\n", 263 | "\n", 264 | "$$\n", 265 | "\\operatorname{score}=\\frac{\\text{freq of pair}}{\\text{freq of first element}\\times \\text{freq of second element}}\n", 266 | "$$" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 8, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "name": "stdout", 276 | "output_type": "stream", 277 | "text": [ 278 | "\n", 279 | "\n", 280 | "\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "from tokenizers import trainers\n", 286 | "\n", 287 | "texts = [\n", 288 | " \"An infinite number of mathematicians walk into a bar.\",\n", 289 | " \"The first one orders one beer. The second one - half a pint, the third one - a quarter.\",\n", 290 | " \"– Hey, hey, hey, stop! Here are two pints for everyone, and leave me alone!\"\n", 291 | "]\n", 292 | "\n", 293 | "trainer = trainers.BpeTrainer(\n", 294 | " vocab_size=128, min_frequency=1,\n", 295 | " special_tokens=['[BOS]', '[EOS]'],\n", 296 | " continuing_subword_prefix='##', # add special prefix to allow detokenization\n", 297 | ")\n", 298 | "tokenizer.train_from_iterator(texts, trainer)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 9, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "Tokens: ['an', 'infi', '##ni', '##te', 'numb', '##er', 'of', 'mathem', '##atic', '##ian', '##s', 'wal', '##k', 'into', 'a', 'bar', '.']\n", 311 | "Ids: [58, 119, 96, 92, 123, 50, 81, 122, 108, 99, 43, 87, 48, 120, 6, 69, 5]\n" 312 | ] 313 | } 314 | ], 315 | "source": [ 316 | "tokenized = tokenizer.encode(\"An infinite number of mathematicians walk into a bar.\")\n", 317 | "print('Tokens:', tokenized.tokens)\n", 318 | "print('Ids:', tokenized.ids)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "4. __Пост-обработка__ – последний шаг токенизации. Нужен для того, чтобы как-то изменить токенизированный текст. Например, чтобы добавить токены начала и конца последовательности." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 10, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "from tokenizers.processors import TemplateProcessing\n", 335 | "\n", 336 | "post_processor = TemplateProcessing(\n", 337 | " single=\"[BOS] $A [EOS]\",\n", 338 | " special_tokens=[(\"[BOS]\", 1), (\"[EOS]\", 2)],\n", 339 | ")" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 11, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "Tokens: ['[BOS]', 'an', 'infi', '##ni', '##te', 'numb', '##er', 'of', 'mathem', '##atic', '##ian', '##s', 'wal', '##k', 'into', 'a', 'bar', '.', '[EOS]']\n", 352 | "Ids: [1, 58, 119, 96, 92, 123, 50, 81, 122, 108, 99, 43, 87, 48, 120, 6, 69, 5, 2]\n" 353 | ] 354 | } 355 | ], 356 | "source": [ 357 | "processed = post_processor.process(tokenized)\n", 358 | "print('Tokens:', processed.tokens)\n", 359 | "print('Ids:', processed.ids)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 12, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "tokenizer.decoder = WordPieceDecoder(prefix='##', cleanup=True)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 13, 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "data": { 378 | "text/plain": [ 379 | "'an infinite number of mathematicians walk into a bar.'" 380 | ] 381 | }, 382 | "execution_count": 13, 383 | "metadata": {}, 384 | "output_type": "execute_result" 385 | } 386 | ], 387 | "source": [ 388 | "tokenizer.decode(tokenized.ids)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "metadata": {}, 394 | "source": [ 395 | "### Обработка данных" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 14, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stderr", 405 | "output_type": "stream", 406 | "text": [ 407 | "/home/echimbulatov/miniconda3/envs/cosmos/lib/python3.11/site-packages/requests/__init__.py:86: RequestsDependencyWarning: Unable to find acceptable character detection dependency (chardet or charset_normalizer).\n", 408 | " warnings.warn(\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "from datasets import load_from_disk\n", 414 | "\n", 415 | "dataset = load_from_disk('/home/echimbulatov/data/rocstories')\n", 416 | "dataset[\"train\"] = dataset[\"train\"][:5000]" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 15, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "train_texts = np.array(dataset['train']['target'])\n", 426 | "test_texts = np.array(dataset['test']['target'])" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 16, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "def print_texts(texts):\n", 436 | " for t in texts:\n", 437 | " print(t, end='\\n\\n')" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 17, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "Kelly was at home, trying to sleep. Suddenly, she heard footsteps in her kitchen. She grabbed a gun and stood at the top of the stairs. She warned whoever it was that she was armed. She heard them run out of the house and then called police.\n", 450 | "\n", 451 | "I bought a 1969 Mercury Montego with a loose front seat. The seat was loose because the car's floor had rusted through. I removed the seat and repaired the floor with pieces of sheet metal. My repair held the seat firmly in place after I reinstalled it. The car then successfully passed the safety inspection.\n", 452 | "\n" 453 | ] 454 | } 455 | ], 456 | "source": [ 457 | "idxs = [5, 9]\n", 458 | "print_texts(train_texts[idxs])" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 18, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "from my_tokenizers import *" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 19, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "Tokenized\n", 480 | "['[BOS]', 'k', 'e', 'l', 'l', 'y', ' ', 'w', 'a', 's', ' ', 'a', 't', ' ', 'h', 'o', 'm', 'e', ',', ' ', 't', 'r', 'y', 'i', 'n', 'g', ' ', 't', 'o', ' ', 's', 'l', 'e', 'e', 'p', '.', ' ', 's', 'u', 'd', 'd', 'e', 'n', 'l', 'y', ',', ' ', 's', 'h', 'e', ' ', 'h', 'e', 'a', 'r', 'd', ' ', 'f', 'o', 'o', 't', 's', 't', 'e', 'p', 's', ' ', 'i', 'n', ' ', 'h', 'e', 'r', ' ', 'k', 'i', 't', 'c', 'h', 'e', 'n', '.', ' ', 's', 'h', 'e', ' ', 'g', 'r', 'a', 'b', 'b', 'e', 'd', ' ', 'a', ' ', 'g', 'u', 'n', ' ', 'a', 'n', 'd', ' ', 's', 't', 'o', 'o', 'd', ' ', 'a', 't', ' ', 't', 'h', 'e', ' ', 't', 'o', 'p', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 's', 't', 'a', 'i', 'r', 's', '.', ' ', 's', 'h', 'e', ' ', 'w', 'a', 'r', 'n', 'e', 'd', ' ', 'w', 'h', 'o', 'e', 'v', 'e', 'r', ' ', 'i', 't', ' ', 'w', 'a', 's', ' ', 't', 'h', 'a', 't', ' ', 's', 'h', 'e', ' ', 'w', 'a', 's', ' ', 'a', 'r', 'm', 'e', 'd', '.', ' ', 's', 'h', 'e', ' ', 'h', 'e', 'a', 'r', 'd', ' ', 't', 'h', 'e', 'm', ' ', 'r', 'u', 'n', ' ', 'o', 'u', 't', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 'h', 'o', 'u', 's', 'e', ' ', 'a', 'n', 'd', ' ', 't', 'h', 'e', 'n', ' ', 'c', 'a', 'l', 'l', 'e', 'd', ' ', 'p', 'o', 'l', 'i', 'c', 'e', '.', '[EOS]']\n", 481 | "\n", 482 | "['[BOS]', 'i', ' ', 'b', 'o', 'u', 'g', 'h', 't', ' ', 'a', ' ', '1', '9', '6', '9', ' ', 'm', 'e', 'r', 'c', 'u', 'r', 'y', ' ', 'm', 'o', 'n', 't', 'e', 'g', 'o', ' ', 'w', 'i', 't', 'h', ' ', 'a', ' ', 'l', 'o', 'o', 's', 'e', ' ', 'f', 'r', 'o', 'n', 't', ' ', 's', 'e', 'a', 't', '.', ' ', 't', 'h', 'e', ' ', 's', 'e', 'a', 't', ' ', 'w', 'a', 's', ' ', 'l', 'o', 'o', 's', 'e', ' ', 'b', 'e', 'c', 'a', 'u', 's', 'e', ' ', 't', 'h', 'e', ' ', 'c', 'a', 'r', \"'\", 's', ' ', 'f', 'l', 'o', 'o', 'r', ' ', 'h', 'a', 'd', ' ', 'r', 'u', 's', 't', 'e', 'd', ' ', 't', 'h', 'r', 'o', 'u', 'g', 'h', '.', ' ', 'i', ' ', 'r', 'e', 'm', 'o', 'v', 'e', 'd', ' ', 't', 'h', 'e', ' ', 's', 'e', 'a', 't', ' ', 'a', 'n', 'd', ' ', 'r', 'e', 'p', 'a', 'i', 'r', 'e', 'd', ' ', 't', 'h', 'e', ' ', 'f', 'l', 'o', 'o', 'r', ' ', 'w', 'i', 't', 'h', ' ', 'p', 'i', 'e', 'c', 'e', 's', ' ', 'o', 'f', ' ', 's', 'h', 'e', 'e', 't', ' ', 'm', 'e', 't', 'a', 'l', '.', ' ', 'm', 'y', ' ', 'r', 'e', 'p', 'a', 'i', 'r', ' ', 'h', 'e', 'l', 'd', ' ', 't', 'h', 'e', ' ', 's', 'e', 'a', 't', ' ', 'f', 'i', 'r', 'm', 'l', 'y', ' ', 'i', 'n', ' ', 'p', 'l', 'a', 'c', 'e', ' ', 'a', 'f', 't', 'e', 'r', ' ', 'i', ' ', 'r', 'e', 'i', 'n', 's', 't', 'a', 'l', 'l', 'e', 'd', ' ', 'i', 't', '.', ' ', 't', 'h', 'e', ' ', 'c', 'a', 'r', ' ', 't', 'h', 'e', 'n', ' ', 's', 'u', 'c', 'c', 'e', 's', 's', 'f', 'u', 'l', 'l', 'y', ' ', 'p', 'a', 's', 's', 'e', 'd', ' ', 't', 'h', 'e', ' ', 's', 'a', 'f', 'e', 't', 'y', ' ', 'i', 'n', 's', 'p', 'e', 'c', 't', 'i', 'o', 'n', '.', '[EOS]']\n", 483 | "\n", 484 | "Detokenized\n", 485 | "kelly was at home, trying to sleep. suddenly, she heard footsteps in her kitchen. she grabbed a gun and stood at the top of the stairs. she warned whoever it was that she was armed. she heard them run out of the house and then called police.\n", 486 | "\n", 487 | "i bought a 1969 mercury montego with a loose front seat. the seat was loose because the car's floor had rusted through. i removed the seat and repaired the floor with pieces of sheet metal. my repair held the seat firmly in place after i reinstalled it. the car then successfully passed the safety inspection.\n", 488 | "\n", 489 | "Vocab size: 82\n" 490 | ] 491 | } 492 | ], 493 | "source": [ 494 | "tok = CharacterTokenizer(train_texts)\n", 495 | "\n", 496 | "tokenized = tok.encode(train_texts[idxs])\n", 497 | "print('Tokenized')\n", 498 | "print_texts(tokenized['tokens'])\n", 499 | "print('Detokenized')\n", 500 | "print_texts(tok.decode(tokenized['input_ids']))\n", 501 | "\n", 502 | "print(f'Vocab size: {len(tok.token2id)}')" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 20, 508 | "metadata": {}, 509 | "outputs": [ 510 | { 511 | "data": { 512 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAADcCAYAAACGcpEgAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAALK5JREFUeJzt3XlcE9f6P/BPQBIUSFgDyCaLV0FRW7SaulFBUNFqxbVWUSkqArfqrVZal2p7L9b21q2urUXb6vWKW9W6FDesGi2iKKJS9KJYMaBQwqIEgfP7w1/m65CgBBEYfN6vV16aM2fmnCczPDmcnAwixhgDIYQQQTJq7A4QQgipO0rihBAiYJTECSFEwCiJE0KIgFESJ4QQAaMkTgghAkZJnBBCBIySOCGECBglcUIIEbBmn8RFIhGio6MbuxuvDH9/f3Ts2LFR2vX392/wdknDmThxItq0adPY3XiuTZs2QSQS4fz58w3SnmCT+M2bNzF16lR4eHjA1NQUUqkUPXv2xIoVK/Do0aPG7t4Ly8nJwaefforU1NQGa/Pq1av49NNPcevWrQZrkxChWrNmDTZt2tTY3UCLxu5AXfzyyy8YOXIkJBIJJkyYgI4dO6K8vBynTp3C7NmzkZ6ejg0bNjR2N19ITk4OFi1ahDZt2qBLly4N0ubVq1exaNEi+Pv7C2LE87Rff/21sbtAXjFr1qyBra0tJk6c2Kj9EFwSz8rKwpgxY+Dm5oZjx47B0dGR2xYVFYUbN27gl19+adA+lZaWwszMrEHbrCsh9dUQYrG4QdsrKyuDWCyGkZFgf5klzYTgrsClS5eipKQEGzdu5CVwLS8vL3zwwQc65Xv27EHHjh0hkUjQoUMHHDp0iLf99u3bmD59Otq1a4eWLVvCxsYGI0eO1Jla0M53JSUlYfr06ZDL5XB2djboGABQWFiImTNnok2bNpBIJHB2dsaECRPw4MEDnDhxAt26dQMATJo0CSKRCCKRiPer27lz5zBgwADIZDK0atUKffv2xenTp3ltfPrppxCJRLh69SreffddWFlZoVevXnpf102bNmHkyJEAgLfeeotr88SJE1ydNWvWoEOHDpBIJGjdujWioqJQWFio93hP+/XXX9GqVSuMHTsWFRUVAIDr169jxIgRsLa2hqmpKbp27Yq9e/fqfa1Pnz6NWbNmwc7ODmZmZnjnnXdw//59Xt3qc+Jt2rThYqj+eDqmu3fvYvLkybC3t+euje+//5537BMnTkAkEmHbtm2YN28enJyc0KpVKxQVFdUY87Zt2+Dn5wcLCwtIpVL4+vpixYoVvDqFhYWYMWMGXFxcIJFI4OXlhS+++AJVVVU69SZOnAiZTAZLS0uEhYUhNTVV55qo6XMBfXPJVVVVWL58OTp06ABTU1PY29tj6tSp+Ouvv3j12rRpg8GDB+PUqVN44403YGpqCg8PD/zwww867TzrmtbSaDRYuHAhvLy8IJFI4OLigjlz5kCj0dT4Wj7Ly4jj8uXL6Nu3L1q2bAlnZ2d8/vnniI+Ph0gk4n6W27Rpg/T0dCQlJXHXVfXXXqPRPPe6PX/+PIKDg2Fra4uWLVvC3d0dkydPNug1ENxIfN++ffDw8MCbb75Z631OnTqFXbt2Yfr06bCwsMDKlSsRGhqK7Oxs2NjYAACSk5Nx5swZjBkzBs7Ozrh16xbWrl0Lf39/XL16Fa1ateIdc/r06bCzs8OCBQtQWlpq0DFKSkrQu3dvXLt2DZMnT8brr7+OBw8eYO/evfjzzz/h7e2NxYsXY8GCBZgyZQp69+4NAFzMx44dw8CBA+Hn54eFCxfCyMgI8fHx6NevH3777Te88cYbvL6OHDkSbdu2xb/+9S/UdOfhPn364O9//ztWrlyJjz/+GN7e3gDA/fvpp59i0aJFCAwMRGRkJDIyMrB27VokJyfj9OnTMDEx0Xvc/fv3Y8SIERg9ejS+//57GBsbIz09HT179oSTkxPmzp0LMzMzbN++HcOGDcPOnTvxzjvv8I4RExMDKysrLFy4ELdu3cLy5csRHR2N//73vzWe8+XLl6OkpIRXtmzZMqSmpnLnPDc3Fz169OA+/Lazs8PBgwcRHh6OoqIizJgxg7f/Z599BrFYjA8//BAajabG0X9iYiLGjh2LgIAAfPHFFwCAa9eu4fTp09wA4+HDh+jbty/u3r2LqVOnwtXVFWfOnEFsbCzu3buH5cuXAwAYYxg6dChOnTqFadOmwdvbG7t370ZYWFiNsdfG1KlTsWnTJkyaNAl///vfkZWVhW+++QYXL17UOZ83btzAiBEjEB4ejrCwMHz//feYOHEi/Pz80KFDBwDPv6ZtbW1RVVWFt99+G6dOncKUKVPg7e2NtLQ0LFu2DH/88Qf27NnT6HHcvXuXG8TExsbCzMwM3333HSQSCa/d5cuXIyYmBubm5vjkk08AAPb29rw6z7tu8/LyEBQUBDs7O8ydOxeWlpa4desWdu3aZdiLwARErVYzAGzo0KG13gcAE4vF7MaNG1zZpUuXGAC2atUqruzhw4c6+yqVSgaA/fDDD1xZfHw8A8B69erFKioqePVre4wFCxYwAGzXrl069auqqhhjjCUnJzMALD4+Xmd727ZtWXBwMFdX27a7uzvr378/V7Zw4UIGgI0dO1anHX0SEhIYAHb8+HFeeV5eHhOLxSwoKIhVVlZy5d988w0DwL7//nuurG/fvqxDhw6MMcZ27tzJTExMWEREBG+/gIAA5uvry8rKynhxvfnmm6xt27Zcmfa1DgwM5MU6c+ZMZmxszAoLC3nt9u3bt8bYtm/fzgCwxYsXc2Xh4eHM0dGRPXjwgFd3zJgxTCaTcefz+PHjDADz8PDQe46r++CDD5hUKtW5Pp722WefMTMzM/bHH3/wyufOncuMjY1ZdnY2Y4yxPXv2MABs6dKlXJ2KigrWu3dvneujptcgLCyMubm5cc9/++03BoBt2bKFV+/QoUM65W5ubgwAO3nyJFeWl5fHJBIJ+8c//sGV1eaa/vHHH5mRkRH77bffeNvXrVvHALDTp0/r7NvQccTExDCRSMQuXrzIleXn5zNra2sGgGVlZXHlHTp00Pt61/a63b17NwPAkpOTnxn38whqOkX766uFhYVB+wUGBsLT05N73qlTJ0ilUvzvf//jylq2bMn9//Hjx8jPz4eXlxcsLS1x4cIFnWNGRETA2NiYV1bbY+zcuROdO3fWGXECT5ZEPktqaioyMzPx7rvvIj8/Hw8ePMCDBw9QWlqKgIAAnDx5UufX8WnTpj3zmM9z5MgRlJeXY8aMGbw54IiICEilUr2fQfznP//B6NGjMXXqVKxfv57br6CgAMeOHcOoUaNQXFzM9T8/Px/BwcHIzMzE3bt3eceaMmUK73Xp3bs3Kisrcfv27Vr1/+rVq5g8eTKGDh2KefPmAXgywt25cyeGDBkCxhjXjwcPHiA4OBhqtVrnvIeFhfHOcU0sLS1RWlqKxMTEGuskJCSgd+/esLKy4rUdGBiIyspKnDx5EgBw4MABtGjRApGRkdy+xsbGiImJqVXsNbUtk8nQv39/Xtt+fn4wNzfH8ePHefV9fHy43wYBwM7ODu3ateP9/NTmmk5ISIC3tzfat2/Pa7dfv34AoNNuY8Rx6NAhKBQK3mICa2trjBs3zqC+Ac+/bi0tLQE8+W318ePHBh9fS1DTKVKpFABQXFxs0H6urq46ZVZWVrx5s0ePHiEuLg7x8fG4e/cub9pBrVbr7O/u7q5TVttj3Lx5E6GhoQbFoJWZmQkAz/x1Wq1Ww8rK6pl9NYT2omvXrh2vXCwWw8PDQyeZZmVl4b333sPIkSOxatUq3rYbN26AMYb58+dj/vz5etvLy8uDk5MT97z6+dPGVn3eU5+ioiIMHz4cTk5O+OGHH7gfqvv376OwsBAbNmyocSVTXl4e73ltX8fp06dj+/btGDhwIJycnBAUFIRRo0ZhwIABXJ3MzExcvnwZdnZ2z2z79u3bcHR0hLm5OW979XNhiMzMTKjVasjl8me2rVWbn5/aXNOZmZm4du3ac2OurZcRx+3bt6FQKHTqeXl5GdQ3fe1Vv2779u2L0NBQLFq0CMuWLYO/vz+GDRuGd999V2f65lkEl8Rbt26NK1euGLRf9RGz1tNJNiYmBvHx8ZgxYwYUCgVkMhlEIhHGjBmjM7IFoHdEZugx6kJ7nC+//LLGpYfVf+BrM3qsT46OjnB0dMSBAwdw/vx5dO3aldum7f+HH36I4OBgvftX/4GpzfmrycSJE5GTk4Pff/+dGwQ83Y/33nuvxjfETp068Z7X9nWUy+VITU3F4cOHcfDgQRw8eBDx8fGYMGECNm/ezLXfv39/zJkzR+8x/va3v9WqraeJRCK9r0llZSXveVVVFeRyObZs2aL3ONWT7Iu8/tXb9fX1xddff613u4uLi8HHa4w4aut57YlEIuzYsQNnz57Fvn37cPjwYUyePBn//ve/cfbsWZ2f45oIKokDwODBg7FhwwYolUq975h1tWPHDoSFheHf//43V1ZWVlar1ReGHsPT0/O5b0Q1Tatop4WkUikCAwNr3bfaqKlNNzc3AEBGRgY8PDy48vLycmRlZen0w9TUFPv370e/fv0wYMAAJCUlcR8cafc3MTGp9/5Xt2TJEuzZswe7du1C+/btedvs7OxgYWGBysrKl9IPsViMIUOGYMiQIaiqqsL06dOxfv16zJ8/H15eXvD09ERJSclz23Zzc8PRo0dRUlLC+6HOyMjQqWtlZcWbGtCq/puSp6cnjhw5gp49e9bbG3xtrmlPT09cunQJAQEBz502rG2b9R2Hm5sbbty4oVOur6w+YgCAHj16oEePHvjnP/+JrVu3Yty4cdi2bRvef//9Wu0vqDlxAJgzZw7MzMzw/vvvIzc3V2f7zZs3dZZy1YaxsbHOO/KqVat0RjH1cYzQ0FBcunQJu3fv1jmGdn/tWu7qbwB+fn7w9PTEV199pbP6AoDOEiZD1NRmYGAgxGIxVq5cyYtv48aNUKvVCAkJ0TmWTCbD4cOHIZfL0b9/f9y8eRPAk1Gqv78/1q9fj3v37tVr/5925MgRzJs3D5988gmGDRums93Y2BihoaHYuXOn3uTzIv3Iz8/nPTcyMuJG9dqldKNGjYJSqcThw4d19i8sLOSWYg4aNAgVFRVYu3Ytt72yslJnmgp4ktSuX7/O6/ulS5d0lp6OGjUKlZWV+Oyzz3SOUVFRYdDARas21/SoUaNw9+5dfPvttzp1Hj16xK3yqq2XEUdwcDCUSiXvm9IFBQV6R/tmZmZ1akPrr7/+0skX2t+uDVlyKbiRuKenJ7Zu3YrRo0fD29ub943NM2fOICEhoU7foBo8eDB+/PFHyGQy+Pj4QKlU4siRI9xytPo8xuzZs7Fjxw6MHDkSkydPhp+fHwoKCrB3716sW7cOnTt3hqenJywtLbFu3TpYWFjAzMwM3bt3h7u7O7777jsMHDgQHTp0wKRJk+Dk5IS7d+/i+PHjkEql2Ldvn8HxA08uIGNjY3zxxRdQq9WQSCTo168f5HI5YmNjsWjRIgwYMABvv/02MjIysGbNGnTr1g3vvfee3uPZ2toiMTERvXr1QmBgIE6dOgUnJyesXr0avXr1gq+vLyIiIuDh4YHc3FwolUr8+eefuHTpUp36/7SxY8fCzs4Obdu2xU8//cTb1r9/f9jb22PJkiU4fvw4unfvjoiICPj4+KCgoAAXLlzAkSNHUFBQUKe233//fRQUFKBfv35wdnbG7du3sWrVKnTp0oVbsjl79mzs3bsXgwcP5pa5lZaWIi0tDTt27MCtW7dga2uLIUOGoGfPnpg7dy5u3boFHx8f7Nq1S+/nNJMnT8bXX3+N4OBghIeHIy8vD+vWrUOHDh14a9r79u2LqVOnIi4uDqmpqQgKCoKJiQkyMzORkJCAFStWYMSIEQbFXJtrevz48di+fTumTZuG48ePo2fPnqisrMT169exfft2HD58mDf19jwvI445c+bgp59+Qv/+/RETE8MtMXR1dUVBQQFv9O3n54e1a9fi888/h5eXF+RyOfchbW1s3rwZa9aswTvvvANPT08UFxfj22+/hVQqxaBBg2rf6Rda29KI/vjjDxYREcHatGnDxGIxs7CwYD179mSrVq3iLV0DwKKionT2d3NzY2FhYdzzv/76i02aNInZ2toyc3NzFhwczK5fv65TT7t8SN+yoNoeg7Eny5aio6OZk5MTE4vFzNnZmYWFhfGWu/3888/Mx8eHtWjRQmc52cWLF9nw4cOZjY0Nk0gkzM3NjY0aNYodPXqUq6NdYnj//v1av67ffvst8/DwYMbGxjrLDb/55hvWvn17ZmJiwuzt7VlkZCT766+/ePs/vcRQ68aNG8zR0ZF5e3tzfbl58yabMGECc3BwYCYmJszJyYkNHjyY7dixg9uvptdau+Tv6b5VX14HoMbH0/vl5uayqKgo5uLiwkxMTJiDgwMLCAhgGzZs0GkvISGhVq/hjh07WFBQEJPL5UwsFjNXV1c2depUdu/ePV694uJiFhsby7y8vJhYLGa2trbszTffZF999RUrLy/n6uXn57Px48czqVTKZDIZGz9+PLt48aLeJag//fQT8/DwYGKxmHXp0oUdPnxYZ2me1oYNG5ifnx9r2bIls7CwYL6+vmzOnDksJyeHq+Pm5sZCQkJ09tW3nLE213R5eTn74osvWIcOHZhEImFWVlbMz8+PLVq0iKnV6me+rg0Vx8WLF1nv3r2ZRCJhzs7OLC4ujq1cuZIBYCqViqunUqlYSEgIs7CwYAC449T2ur1w4QIbO3Ysc3V1ZRKJhMnlcjZ48GB2/vz5Z74O1YkYe0mz+oSQl+bWrVtwd3dHfHx8o9+741UwY8YMrF+/HiUlJTV+YNlYBDcnTgghL1P1u6Dm5+fjxx9/RK9evZpcAgcEOCdOCCEvk0KhgL+/P7y9vZGbm4uNGzeiqKioxu81NDZK4oQQ8pRBgwZhx44d2LBhA0QiEV5//XVs3LgRffr0aeyu6UVz4oQQImA0J04IIQJGSZwQQgRMkHPiVVVVyMnJgYWFRb199ZUQQhoTYwzFxcVo3bq1QX8xSpBJPCcnx+Cb5RBCiBDcuXOH+2thtSHIJK69n/idO3d4d6YjhBChKioqgouLi8F/L0GQSVw7hSKVSimJE0KaFUOniOmDTUIIETBK4oQQImCUxAkhRMAoiRNCiIBREieEEAET5OoUQhpSm7m/6C2/tUT3z9IR0tBoJE4IIQJGSZwQQgSMkjghhAgYJXFCCBEwSuKEECJglMQJIUTAKIkTQoiA0TpxQv6/mtaDE9KU0UicEEIEjJI4IYQI2Asl8SVLlkAkEmHGjBlcWVlZGaKiomBjYwNzc3OEhoYiNzeXt192djZCQkLQqlUryOVyzJ49GxUVFS/SFUIIeSXVOYknJydj/fr16NSpE6985syZ2LdvHxISEpCUlIScnBwMHz6c215ZWYmQkBCUl5fjzJkz2Lx5MzZt2oQFCxbUPQpCCHlF1SmJl5SUYNy4cfj2229hZWXFlavVamzcuBFff/01+vXrBz8/P8THx+PMmTM4e/YsAODXX3/F1atX8dNPP6FLly4YOHAgPvvsM6xevRrl5eX1ExUhhLwi6pTEo6KiEBISgsDAQF55SkoKHj9+zCtv3749XF1doVQqAQBKpRK+vr6wt7fn6gQHB6OoqAjp6el629NoNCgqKuI9CCGE1GGJ4bZt23DhwgUkJyfrbFOpVBCLxbC0tOSV29vbQ6VScXWeTuDa7dpt+sTFxWHRokWGdpUQQpo9g0bid+7cwQcffIAtW7bA1NT0ZfVJR2xsLNRqNfe4c+dOg7VNCCFNmUFJPCUlBXl5eXj99dfRokULtGjRAklJSVi5ciVatGgBe3t7lJeXo7CwkLdfbm4uHBwcAAAODg46q1W0z7V1qpNIJJBKpbwHIYQQA5N4QEAA0tLSkJqayj26du2KcePGcf83MTHB0aNHuX0yMjKQnZ0NhUIBAFAoFEhLS0NeXh5XJzExEVKpFD4+PvUUFiGEvBoMmhO3sLBAx44deWVmZmawsbHhysPDwzFr1ixYW1tDKpUiJiYGCoUCPXr0AAAEBQXBx8cH48ePx9KlS6FSqTBv3jxERUVBIpHUU1iEPEF/Wo00d/V+75Rly5bByMgIoaGh0Gg0CA4Oxpo1a7jtxsbG2L9/PyIjI6FQKGBmZoawsDAsXry4vrtCCCHN3gsn8RMnTvCem5qaYvXq1Vi9enWN+7i5ueHAgQMv2jQhhLzy6N4phBAiYHQrWtJsNOVbyRoyN0/z+MQQNBInhBABoyROCCECRtMphNSjpjylQ5onSuJEcChREvJ/KIkTUkf0ZkKaAkripEnQlxBf5moMSsCkuaAPNgkhRMAoiRNCiIBREieEEAGjJE4IIQJGH2ySJos+fCTk+WgkTgghAkZJnBBCBIySOCGECBjNiRMiEHSLWqIPjcQJIUTAaCROSCOiFTjkRdFInBBCBIxG4oQ0QzR//uqgkTghhAiYQUl87dq16NSpE6RSKaRSKRQKBQ4ePMhtLysrQ1RUFGxsbGBubo7Q0FDk5ubyjpGdnY2QkBC0atUKcrkcs2fPRkVFRf1EQwghrxiDkrizszOWLFmClJQUnD9/Hv369cPQoUORnp4OAJg5cyb27duHhIQEJCUlIScnB8OHD+f2r6ysREhICMrLy3HmzBls3rwZmzZtwoIFC+o3KkIIeUWIGGPsRQ5gbW2NL7/8EiNGjICdnR22bt2KESNGAACuX78Ob29vKJVK9OjRAwcPHsTgwYORk5MDe3t7AMC6devw0Ucf4f79+xCLxbVqs6ioCDKZDGq1GlKp9EW6T5oIWqVRd/rmuWlOXHjqmtfq/MFmZWUlEhISUFpaCoVCgZSUFDx+/BiBgYFcnfbt28PV1ZVL4kqlEr6+vlwCB4Dg4GBERkYiPT0dr732mt62NBoNNBoN97yoqKiu3SaNjJI1IfXL4CSelpYGhUKBsrIymJubY/fu3fDx8UFqairEYjEsLS159e3t7aFSqQAAKpWKl8C127XbahIXF4dFixYZ2lVCXgn0xvhqMziJt2vXDqmpqVCr1dixYwfCwsKQlJT0MvrGiY2NxaxZs7jnRUVFcHFxealtkhdDiYWQhmFwEheLxfDy8gIA+Pn5ITk5GStWrMDo0aNRXl6OwsJC3mg8NzcXDg4OAAAHBwf8/vvvvONpV69o6+gjkUggkUgM7SohhDR7L7xOvKqqChqNBn5+fjAxMcHRo0e5bRkZGcjOzoZCoQAAKBQKpKWlIS8vj6uTmJgIqVQKHx+fF+0KIYS8cgwaicfGxmLgwIFwdXVFcXExtm7dihMnTuDw4cOQyWQIDw/HrFmzYG1tDalUipiYGCgUCvTo0QMAEBQUBB8fH4wfPx5Lly6FSqXCvHnzEBUVRSNtQgipA4OSeF5eHiZMmIB79+5BJpOhU6dOOHz4MPr37w8AWLZsGYyMjBAaGgqNRoPg4GCsWbOG29/Y2Bj79+9HZGQkFAoFzMzMEBYWhsWLF9dvVIQQ8op44XXijYHWiTd99MFm00TrxJuuBl8nTogWJWxCGg/dAIsQQgSMkjghhAgYTacQ8grRN/VF8+TCRiNxQggRMBqJE/KKozseChuNxAkhRMAoiRNCiIBREieEEAGjJE4IIQJGSZwQQgSMVqeQWqOv1xPS9NBInBBCBIxG4kQHjbgJQOvHhYKSOCHEIPTV/aaFplMIIUTAKIkTQoiAURInhBABoyROCCECRkmcEEIEjJI4IYQImEFJPC4uDt26dYOFhQXkcjmGDRuGjIwMXp2ysjJERUXBxsYG5ubmCA0NRW5uLq9OdnY2QkJC0KpVK8jlcsyePRsVFRUvHg0hhLxiDEriSUlJiIqKwtmzZ5GYmIjHjx8jKCgIpaWlXJ2ZM2di3759SEhIQFJSEnJycjB8+HBue2VlJUJCQlBeXo4zZ85g8+bN2LRpExYsWFB/URFCyCtCxBhjdd35/v37kMvlSEpKQp8+faBWq2FnZ4etW7dixIgRAIDr16/D29sbSqUSPXr0wMGDBzF48GDk5OTA3t4eALBu3Tp89NFHuH//PsRi8XPbLSoqgkwmg1qthlQqrWv3SQ3oG5vEUPRlnxdX17z2Qt/YVKvVAABra2sAQEpKCh4/fozAwECuTvv27eHq6solcaVSCV9fXy6BA0BwcDAiIyORnp6O11577UW6RAxECZvUB/qKfuOpcxKvqqrCjBkz0LNnT3Ts2BEAoFKpIBaLYWlpyatrb28PlUrF1Xk6gWu3a7fpo9FooNFouOdFRUV17TYhhDQrdV6dEhUVhStXrmDbtm312R+94uLiIJPJuIeLi8tLb5MQQoSgTkk8Ojoa+/fvx/Hjx+Hs7MyVOzg4oLy8HIWFhbz6ubm5cHBw4OpUX62ifa6tU11sbCzUajX3uHPnTl26TQghzY5BSZwxhujoaOzevRvHjh2Du7s7b7ufnx9MTExw9OhRriwjIwPZ2dlQKBQAAIVCgbS0NOTl5XF1EhMTIZVK4ePjo7ddiUQCqVTKexBCCDFwTjwqKgpbt27Fzz//DAsLC24OWyaToWXLlpDJZAgPD8esWbNgbW0NqVSKmJgYKBQK9OjRAwAQFBQEHx8fjB8/HkuXLoVKpcK8efMQFRUFiURS/xESQkgzZlASX7t2LQDA39+fVx4fH4+JEycCAJYtWwYjIyOEhoZCo9EgODgYa9as4eoaGxtj//79iIyMhEKhgJmZGcLCwrB48eIXi4QQQl5BL7ROvLHQOvH6Q0sMyctESwxrr655je6dQgghAkZJnBBCBIySOCGECBglcUIIETBK4oQQImCUxAkhRMAoiRNCiIBREieEEAGjJE4IIQJGSZwQQgTshf6yDyGE1IW+2z3QV/TrhpI4IeSloXvzvHyUxJshGuUQ8uqgOXFCCBEwSuKEECJglMQJIUTAaE78FUEfMBHSPNFInBBCBIySOCGECBhNpwgYTZGQ5qSm65mWxz4bjcQJIUTAKIkTQoiAGZzET548iSFDhqB169YQiUTYs2cPbztjDAsWLICjoyNatmyJwMBAZGZm8uoUFBRg3LhxkEqlsLS0RHh4OEpKSl4oEEIIeRUZnMRLS0vRuXNnrF69Wu/2pUuXYuXKlVi3bh3OnTsHMzMzBAcHo6ysjKszbtw4pKenIzExEfv378fJkycxZcqUukdBCCGvKIM/2Bw4cCAGDhyodxtjDMuXL8e8efMwdOhQAMAPP/wAe3t77NmzB2PGjMG1a9dw6NAhJCcno2vXrgCAVatWYdCgQfjqq6/QunXrFwiHEEJeLfU6J56VlQWVSoXAwECuTCaToXv37lAqlQAApVIJS0tLLoEDQGBgIIyMjHDu3Ln67A4hhDR79brEUKVSAQDs7e155fb29tw2lUoFuVzO70SLFrC2tubqVKfRaKDRaLjnRUVF9dltQggRLEGsTomLi4NMJuMeLi4ujd0lQghpEuo1iTs4OAAAcnNzeeW5ubncNgcHB+Tl5fG2V1RUoKCggKtTXWxsLNRqNfe4c+dOfXabEEIEq16nU9zd3eHg4ICjR4+iS5cuAJ5MfZw7dw6RkZEAAIVCgcLCQqSkpMDPzw8AcOzYMVRVVaF79+56jyuRSCCRSOqzq4QQgaBvcj6bwUm8pKQEN27c4J5nZWUhNTUV1tbWcHV1xYwZM/D555+jbdu2cHd3x/z589G6dWsMGzYMAODt7Y0BAwYgIiIC69atw+PHjxEdHY0xY8bQypQa0NfrCSE1MTiJnz9/Hm+99Rb3fNasWQCAsLAwbNq0CXPmzEFpaSmmTJmCwsJC9OrVC4cOHYKpqSm3z5YtWxAdHY2AgAAYGRkhNDQUK1eurIdwCCHk1SJijLHG7oShioqKIJPJoFarIZVKG7s7Lx2NxAnR1dymU+qa1wSxOoUQQoh+lMQJIUTAKIkTQoiAURInhBABo7/sQwgRJH0f+De3Dztrg5J4E0MrUQghhqDpFEIIETBK4oQQImA0ndJIaNqEEFIfKIkTQpqNV/FmWTSdQgghAkZJnBBCBIySOCGECBglcUIIETBK4oQQImC0OuUlo6WEhJCXiUbihBAiYDQSJ4Q0e835ZlmUxOsRTZ0QIhzN5YtBNJ1CCCECRkmcEEIEjKZT6oCmTQhpvoQ2zdJoI/HVq1ejTZs2MDU1Rffu3fH77783VlcIIUSwGiWJ//e//8WsWbOwcOFCXLhwAZ07d0ZwcDDy8vIaozuEECJYIsYYa+hGu3fvjm7duuGbb74BAFRVVcHFxQUxMTGYO3fuc/cvKiqCTCaDWq2GVCp9af2kaRNCyLPU5xRLXfNag8+Jl5eXIyUlBbGxsVyZkZERAgMDoVQq9e6j0Wig0Wi452q1GsCToF+mKs3Dl3p8Qoiw1WcO0h7L0HF1gyfxBw8eoLKyEvb29rxye3t7XL9+Xe8+cXFxWLRokU65i4vLS+kjIYTUhmx5/R+zuLgYMpms1vUFsTolNjYWs2bN4p5XVVWhoKAANjY2EIlEtT5OUVERXFxccOfOnZc6DdNYmnN8zTk2oHnH15xjA+ovPsYYiouL0bp1a4P2a/AkbmtrC2NjY+Tm5vLKc3Nz4eDgoHcfiUQCiUTCK7O0tKxzH6RSabO8mLSac3zNOTagecfXnGMD6ic+Q0bgWg2+OkUsFsPPzw9Hjx7lyqqqqnD06FEoFIqG7g4hhAhao0ynzJo1C2FhYejatSveeOMNLF++HKWlpZg0aVJjdIcQQgSrUZL46NGjcf/+fSxYsAAqlQpdunTBoUOHdD7srG8SiQQLFy7UmZppLppzfM05NqB5x9ecYwMaP75GWSdOCCGkftANsAghRMAoiRNCiIBREieEEAGjJE4IIQLWLJL4yZMnMWTIELRu3RoikQh79uzhbWeMYcGCBXB0dETLli0RGBiIzMxMXp2CggKMGzcOUqkUlpaWCA8PR0lJSQNGod/zYps4cSJEIhHvMWDAAF6dphpbXFwcunXrBgsLC8jlcgwbNgwZGRm8OmVlZYiKioKNjQ3Mzc0RGhqq80Wx7OxshISEoFWrVpDL5Zg9ezYqKioaMhQdtYnN399f59xNmzaNV6cpxgYAa9euRadOnbgvuCgUChw8eJDbLtTzpvW8+JrUuWPNwIEDB9gnn3zCdu3axQCw3bt387YvWbKEyWQytmfPHnbp0iX29ttvM3d3d/bo0SOuzoABA1jnzp3Z2bNn2W+//ca8vLzY2LFjGzgSXc+LLSwsjA0YMIDdu3ePexQUFPDqNNXYgoODWXx8PLty5QpLTU1lgwYNYq6urqykpISrM23aNObi4sKOHj3Kzp8/z3r06MHefPNNbntFRQXr2LEjCwwMZBcvXmQHDhxgtra2LDY2tjFC4tQmtr59+7KIiAjeuVOr1dz2phobY4zt3buX/fLLL+yPP/5gGRkZ7OOPP2YmJibsypUrjDHhnjet58XXlM5ds0jiT6ue6KqqqpiDgwP78ssvubLCwkImkUjYf/7zH8YYY1evXmUAWHJyMlfn4MGDTCQSsbt37zZY35+npiQ+dOjQGvcRSmyMMZaXl8cAsKSkJMbYk/NkYmLCEhISuDrXrl1jAJhSqWSMPXmTMzIyYiqViquzdu1aJpVKmUajadgAnqF6bIw9SQQffPBBjfsIJTYtKysr9t133zWr8/Y0bXyMNa1z1yymU54lKysLKpUKgYGBXJlMJkP37t25W98qlUpYWlqia9euXJ3AwEAYGRnh3LlzDd5nQ504cQJyuRzt2rVDZGQk8vPzuW1Cik17i2Fra2sAQEpKCh4/fsw7d+3bt4erqyvv3Pn6+vK+KBYcHIyioiKkp6c3YO+frXpsWlu2bIGtrS06duyI2NhYPHz4f7c/FkpslZWV2LZtG0pLS6FQKJrVeQN049NqKudOEHcxfBEqlQoA9N76VrtNpVJBLpfztrdo0QLW1tZcnaZqwIABGD58ONzd3XHz5k18/PHHGDhwIJRKJYyNjQUTW1VVFWbMmIGePXuiY8eOAJ6cF7FYrHOzs+rnTt+51W5rCvTFBgDvvvsu3Nzc0Lp1a1y+fBkfffQRMjIysGvXLgBNP7a0tDQoFAqUlZXB3Nwcu3fvho+PD1JTU5vFeaspPqBpnbtmn8SbuzFjxnD/9/X1RadOneDp6YkTJ04gICCgEXtmmKioKFy5cgWnTp1q7K7Uu5pimzJlCvd/X19fODo6IiAgADdv3oSnp2dDd9Ng7dq1Q2pqKtRqNXbs2IGwsDAkJSU1drfqTU3x+fj4NKlz1+ynU7S3t33WrW8dHBx0/r5nRUUFCgoKarw9blPl4eEBW1tb3LhxA4AwYouOjsb+/ftx/PhxODs7c+UODg4oLy9HYWEhr371c6fv3Gq3NbaaYtOne/fuAMA7d005NrFYDC8vL/j5+SEuLg6dO3fGihUrmsV5A2qOT5/GPHfNPom7u7vDwcGBd+vboqIinDt3jpvfUigUKCwsREpKClfn2LFjqKqq4k6OUPz555/Iz8+Ho6MjgKYdG2MM0dHR2L17N44dOwZ3d3fedj8/P5iYmPDOXUZGBrKzs3nnLi0tjfdGlZiYCKlUyv3q2xieF5s+qampAMA7d00xtppUVVVBo9EI+rw9izY+fRr13NXrx6SNpLi4mF28eJFdvHiRAWBff/01u3jxIrt9+zZj7MkSQ0tLS/bzzz+zy5cvs6FDh+pdYvjaa6+xc+fOsVOnTrG2bds2iWV4z4qtuLiYffjhh0ypVLKsrCx25MgR9vrrr7O2bduysrIy7hhNNbbIyEgmk8nYiRMneEu1Hj58yNWZNm0ac3V1ZceOHWPnz59nCoWCKRQKbrt2KVdQUBBLTU1lhw4dYnZ2do2+VO15sd24cYMtXryYnT9/nmVlZbGff/6ZeXh4sD59+nDHaKqxMcbY3LlzWVJSEsvKymKXL19mc+fOZSKRiP3666+MMeGeN61nxdfUzl2zSOLHjx9nAHQeYWFhjLEnywznz5/P7O3tmUQiYQEBASwjI4N3jPz8fDZ27Fhmbm7OpFIpmzRpEisuLm6EaPieFdvDhw9ZUFAQs7OzYyYmJszNzY1FRETwljUx1nRj0xcXABYfH8/VefToEZs+fTqzsrJirVq1Yu+88w67d+8e7zi3bt1iAwcOZC1btmS2trbsH//4B3v8+HEDR8P3vNiys7NZnz59mLW1NZNIJMzLy4vNnj2bt9aYsaYZG2OMTZ48mbm5uTGxWMzs7OxYQEAAl8AZE+5503pWfE3t3NGtaAkhRMCa/Zw4IYQ0Z5TECSFEwCiJE0KIgFESJ4QQAaMkTgghAkZJnBBCBIySOCGECBglcUIIETBK4oQQImCUxAkhRMAoiRNCiIBREieEEAH7f1NELUi/qMBYAAAAAElFTkSuQmCC", 513 | "text/plain": [ 514 | "
" 515 | ] 516 | }, 517 | "metadata": {}, 518 | "output_type": "display_data" 519 | } 520 | ], 521 | "source": [ 522 | "tokenized_lengths = [len(t) for t in tok.encode(test_texts)['tokens']]\n", 523 | "plt.figure(figsize=(4, 2))\n", 524 | "plt.hist(tokenized_lengths, bins=50)\n", 525 | "plt.title('Character tokenizer sequence lengths')\n", 526 | "plt.show()" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 21, 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "name": "stdout", 536 | "output_type": "stream", 537 | "text": [ 538 | "Tokenized\n", 539 | "['[BOS]', 'kelly', 'was', 'at', 'home', ',', 'trying', 'to', 'sleep', '.', 'suddenly', ',', 'she', 'heard', 'footsteps', 'in', 'her', 'kitchen', '.', 'she', 'grabbed', 'a', 'gun', 'and', 'stood', 'at', 'the', 'top', 'of', 'the', 'stairs', '.', 'she', 'warned', 'whoever', 'it', 'was', 'that', 'she', 'was', 'armed', '.', 'she', 'heard', 'them', 'run', 'out', 'of', 'the', 'house', 'and', 'then', 'called', 'police', '.', '[EOS]']\n", 540 | "\n", 541 | "['[BOS]', 'i', 'bought', 'a', '1969', 'mercury', 'montego', 'with', 'a', 'loose', 'front', 'seat', '.', 'the', 'seat', 'was', 'loose', 'because', 'the', 'car', \"'\", 's', 'floor', 'had', 'rusted', 'through', '.', 'i', 'removed', 'the', 'seat', 'and', 'repaired', 'the', 'floor', 'with', 'pieces', 'of', 'sheet', 'metal', '.', 'my', 'repair', 'held', 'the', 'seat', 'firmly', 'in', 'place', 'after', 'i', 'reinstalled', 'it', '.', 'the', 'car', 'then', 'successfully', 'passed', 'the', 'safety', 'inspection', '.', '[EOS]']\n", 542 | "\n", 543 | "Detokenized\n", 544 | "kelly was at home , trying to sleep . suddenly , she heard footsteps in her kitchen . she grabbed a gun and stood at the top of the stairs . she warned whoever it was that she was armed . she heard them run out of the house and then called police .\n", 545 | "\n", 546 | "i bought a 1969 mercury montego with a loose front seat . the seat was loose because the car ' s floor had rusted through . i removed the seat and repaired the floor with pieces of sheet metal . my repair held the seat firmly in place after i reinstalled it . the car then successfully passed the safety inspection .\n", 547 | "\n", 548 | "Vocab size: 11141\n" 549 | ] 550 | } 551 | ], 552 | "source": [ 553 | "tok = WordTokenizer(train_texts)\n", 554 | "\n", 555 | "tokenized = tok.encode(train_texts[idxs])\n", 556 | "print('Tokenized')\n", 557 | "print_texts(tokenized['tokens'])\n", 558 | "print('Detokenized')\n", 559 | "print_texts(tok.decode(tokenized['input_ids']))\n", 560 | "\n", 561 | "print(f'Vocab size: {len(tok.token2id)}')" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 22, 567 | "metadata": {}, 568 | "outputs": [ 569 | { 570 | "data": { 571 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAADcCAYAAABQ10tFAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJ3hJREFUeJzt3XtYVNX+P/D3cJnhIgOIMoABIpgIXjJIBCw7QaJhZWrqyQzzVop5PSaYNzRF7Wam5e18zdTqqJV5ySsiphJ5SROPIpYBoYBmwygqKLN+f/hjH7cDwigCG96v55nnkbXXzP6s2cPbzZo1e1RCCAEiIlIMi9ougIiIzMPgJiJSGAY3EZHCMLiJiBSGwU1EpDAMbiIihWFwExEpDIObiEhhGNxERArD4K6D9u7dC5VKhb179z70fTVv3hw9evR46Pspb7+DBg2q8f1SzXn66afx9NNP13YZlZoxYwZUKhUuXbpU26VUWYMN7nXr1kGlUuG7774z2da+fXuoVCokJyebbPPy8kJYWFhNlFipgwcPYsaMGdDr9bVdClGdN2fOHGzcuLG2y6gWDTa4O3fuDADYv3+/rN1gMCA9PR1WVlY4cOCAbFtOTg5ycnKk+9a2gwcPIiEhQZHBnZGRgeXLl9d2GdSAMLjrAQ8PD/j4+JgEd2pqKoQQePnll022lf38oMEthMD169cf6DGUTqPRwNrausb2V1RUVGP7InrYGmxwA7cD+JdffpGF6IEDBxAYGIju3bvjp59+gtFolG1TqVQIDw8HANy6dQuzZs2Cr68vNBoNmjdvjsmTJ6O4uFi2n7J55B07diA4OBi2trZYunQpAODPP/9Ez549YW9vD1dXV4wbN87k/uWZMWMGJk6cCADw8fGBSqWCSqXCH3/8YVZt5Vm1ahWsrKykxweAtLQ0dOvWDY6OjrCzs0OXLl1M/iIpmys8e/YsBg0aBCcnJzg6OuL111/HtWvXTJ6TO+e4y+ov71Y2JgA4ffo0+vTpg8aNG8PGxgbBwcHYtGmT7LE///xzqFQqpKSkYOTIkXB1dcUjjzxyzzF/8sknCAwMhJ2dHZydnREcHIwvv/xS1ic3NxeDBw+GTqeDRqNBYGAg/u///s/ksco7pjt27DB536Kief7y5oaLi4sxffp0+Pn5QaPRwNPTE2+//bbJ8VSpVBg1ahQ2btyINm3aSHVu377dZD+5ubkYMmQIPDw8oNFo4OPjgxEjRqCkpETqo9frMXbsWHh6ekKj0cDPzw/z5s2T/V6Y42GMY+/evQgODoaNjQ18fX2xdOlS6bV45+MVFRVh1apV0uvq7uder9dX+rrdtWsXOnfuDCcnJzRq1AitWrXC5MmT7+u5eBBWNb7HOqRz585YvXo10tLSpF+UAwcOICwsDGFhYSgsLER6ejratWsnbfP394eLiwsAYOjQoVi1ahX69OmDCRMmIC0tDYmJiTh16pTJ3HlGRgb++c9/4o033sCwYcPQqlUrXL9+HREREcjOzsbo0aPh4eGB1atXY8+ePZXW3qtXL5w5cwZfffUVPvroIzRp0gQA0LRpU7Nru9OyZcvw5ptvYvLkyXj33XcBAHv27EH37t0RFBSE6dOnw8LCAitXrsQzzzyDH3/8ER07dpQ9Rt++feHj44PExEQcPXoUK1asgKurK+bNm1fhflevXm3SNmXKFBQUFKBRo0YAgJMnTyI8PBzNmjVDXFwc7O3tsW7dOvTs2RPffPMNXnrpJdn9R44ciaZNm2LatGn3PONevnw5Ro8ejT59+mDMmDG4ceMGfv31V6SlpeGVV14BAOTn56NTp05SoDRt2hTbtm3DkCFDYDAYMHbsWAB4oGNaEaPRiBdeeAH79+/H8OHD0bp1a5w4cQIfffQRzpw5Y/Ln//79+/Htt99i5MiRcHBwwMKFC9G7d29kZ2dLr93z58+jY8eO0Ov1GD58OPz9/ZGbm4sNGzbg2rVrUKvVuHbtGrp06YLc3Fy88cYb8PLywsGDBxEfH48LFy5gwYIFtT6OX375Bd26dYO7uzsSEhJQWlqKmTNnSr8HZVavXo2hQ4eiY8eOGD58OADA19dX1qey1+3JkyfRo0cPtGvXDjNnzoRGo8HZs2dNTmBqhGjATp48KQCIWbNmCSGEuHnzprC3txerVq0SQgih0+nE4sWLhRBCGAwGYWlpKYYNGyaEEOLYsWMCgBg6dKjsMf/1r38JAGLPnj1Sm7e3twAgtm/fLuu7YMECAUCsW7dOaisqKhJ+fn4CgEhOTr5n/e+9954AIM6dOydrN7e26OhoIYQQH3/8sVCpVNLzIYQQRqNRtGzZUkRFRQmj0Si1X7t2Tfj4+Ihnn31Waps+fboAIAYPHizb70svvSRcXFxkbd7e3iImJqbCsc2fP18AEF988YXUFhERIdq2bStu3Lghqy8sLEy0bNlSalu5cqUAIDp37ixu3bpV4T7KvPjiiyIwMPCefYYMGSLc3d3FpUuXZO39+/cXjo6O4tq1a0II845pRc9Bly5dRJcuXaSfV69eLSwsLMSPP/4o67dkyRIBQBw4cEBqAyDUarU4e/as1Hb8+HEBQHzyySdS22uvvSYsLCzEoUOHTPZfdpxnzZol7O3txZkzZ2Tb4+LihKWlpcjOzja5b02P4/nnnxd2dnYiNzdXasvMzBRWVlbi7nizt7cv9/mu6uv2o48+EgDExYsX7znumtCgp0pat24NFxcXae76+PHjKCoqklaNhIWFSf+bpqamorS0VJrf/uGHHwAA48ePlz3mhAkTAABbt26Vtfv4+CAqKkrW9sMPP8Dd3R19+vSR2uzs7KQzgvtlbm0AMH/+fIwZMwbz5s3DlClTpPZjx44hMzMTr7zyCv766y9cunQJly5dQlFRESIiIrBv3z6TP5vffPNN2c9PPvkk/vrrLxgMhirVn5ycjPj4eLz11lsYOHAgAODy5cvYs2cP+vbtiytXrkh1/PXXX4iKikJmZiZyc3NljzNs2DBYWlpWuj8nJyf8+eefOHToULnbhRD45ptv8Pzzz0MIIe370qVLiIqKQmFhIY4ePQrg4RzT9evXo3Xr1vD395ft+5lnngEAk9VPkZGRsrPJdu3aQavV4vfffwdw+8x348aNeP755xEcHGyyv7IphvXr1+PJJ5+Es7OzbL+RkZEoLS3Fvn37anUcpaWl2L17N3r27AkPDw+pn5+fH7p3725WbUDlr1snJycAwPfff3/fU0XVpUFPlahUKoSFhUnhc+DAAbi6usLPzw/A7eBetGgRAEgBXhbcWVlZsLCwkPqWcXNzg5OTE7KysmTtPj4+JvvPysqCn5+fbC4OAFq1avVA4zK3tpSUFGzduhWTJk2SzWsDQGZmJgAgJiamwv0VFhbC2dlZ+tnLy0u2vWzb33//Da1We8/a//zzT/Tr1w/h4eH48MMPpfazZ89CCIGpU6di6tSp5d63oKAAzZo1k34u7zkvz6RJk7B792507NgRfn5+6Nq1K1555RXpvYyLFy9Cr9dj2bJlWLZsWYX7Bh7OMc3MzMSpU6dM/vy/e99l7n7+gdvH4O+//wZwezwGgwFt2rSpdL+//vprlfdbmeoeR0FBAa5fv27yOgdQbltlKnvd9uvXDytWrMDQoUMRFxeHiIgI9OrVC3369IGFRc2eAzfo4AZuB/HmzZtx4sQJaX67TFhYGCZOnIjc3Fzs378fHh4eaNGihez+d/+CVsTW1rZa666KqtYWGBgIvV6P1atX44033pAFXtmZxXvvvYfHHnus3PuXzUGXqegsV1TyLXklJSXo06cPNBoN1q1bByur/708y+r417/+ZfKXS5m7f1mr+py3bt0aGRkZ2LJlC7Zv345vvvkGn376KaZNm4aEhARp36+++mqF/4GVvQ9ijoqOT2lpqew5NBqNaNu2rew/sjt5enrKfr7f5/9uRqMRzz77LN5+++1ytz/66KNmP15tjKOqKtufra0t9u3bh+TkZGzduhXbt2/Hf/7zHzzzzDPYuXNnlf66qy4M7jvWcx84cEB6kwkAgoKCoNFosHfvXqSlpeG5556Ttnl7e8NoNCIzMxOtW7eW2vPz86HX6+Ht7V3pvr29vZGeng4hhOyXOCMjo0q1V/SLb25tTZo0wYYNG9C5c2dERERI/0kB/3sDR6vVIjIyskp13a/Ro0fj2LFj2LdvH3Q6nWxb2X+Y1tbWD6UOe3t79OvXD/369UNJSQl69eqF2bNnIz4+Hk2bNoWDgwNKS0sr3bc5x9TZ2bncNfhZWVmyEwRfX18cP34cERERVf7P+F6aNm0KrVaL9PT0e/bz9fXF1atXq+35ru5xuLq6wsbGBmfPnjXZVl5bdezTwsICERERiIiIwIcffog5c+bgnXfeQXJy8kP//ZDVUWN7qqPKlhGtXbsWubm5sjNujUaDxx9/HIsXL0ZRUZFs/XZZiN/9znrZ2UR0dHSl+37uuedw/vx5bNiwQWq7du1ahX+O383e3h4ATH7576e2Rx55BLt378b169fx7LPP4q+//gJw+z8vX19fvP/++7h69arJ/S5evFilWiuzcuVKLF26FIsXLzZZpQLc/iV9+umnsXTpUly4cKFa6ygbaxm1Wo2AgAAIIXDz5k1YWlqid+/e+Oabb8oNuzv3bc4x9fX1xU8//SRbfrdlyxbk5OTI+vXt2xe5ubnlfmDp+vXrZq9Rt7CwQM+ePbF582YcPnzYZHvZGWbfvn2RmpqKHTt2mPTR6/W4deuWWfut7nFYWloiMjISGzduxPnz56X2s2fPYtu2bSb97e3tH+jDapcvXzZpK/srtCrLbKtTgz/jVqvVeOKJJ/Djjz9Co9EgKChItj0sLAwffPABAPkHb9q3b4+YmBgsW7YMer0eXbp0wc8//4xVq1ahZ8+e+Mc//lHpvocNG4ZFixbhtddew5EjR+Du7o7Vq1fDzs6uSrWX1frOO++gf//+sLa2xvPPP3/ftfn5+WHnzp14+umnERUVhT179kCr1WLFihXo3r07AgMD8frrr6NZs2bIzc1FcnIytFotNm/eXKV6K3Lp0iWMHDkSAQEB0Gg0WLNmjWz7Sy+9BHt7eyxevBidO3dG27ZtMWzYMLRo0QL5+flITU3Fn3/+iePHj9/X/rt27Qo3NzeEh4dDp9Ph1KlTWLRoEaKjo+Hg4AAAmDt3LpKTkxESEoJhw4YhICAAly9fxtGjR7F7927pl9qcYzp06FBs2LAB3bp1Q9++ffHbb79hzZo1JsvUBg4ciHXr1uHNN99EcnIywsPDUVpaitOnT2PdunXS5wPMMWfOHOzcuRNdunSRluZduHAB69evx/79++Hk5ISJEydi06ZN6NGjBwYNGoSgoCAUFRXhxIkT2LBhA/744w9pGWpVPIxxzJgxAzt37kR4eDhGjBiB0tJSLFq0CG3atMGxY8dkfYOCgrB79258+OGH0gfwQkJCqryvmTNnYt++fYiOjoa3tzcKCgrw6aef4pFHHqn5T1PX0mqWOiU+Pl4AEGFhYSbbvv32WwFAODg4mCwtu3nzpkhISBA+Pj7C2tpaeHp6ivj4eNlyNSHkS+7ulpWVJV544QVhZ2cnmjRpIsaMGSO2b99epeWAQtxestWsWTNhYWEhWxr4ILWlpaUJBwcH8dRTT0nL3H755RfRq1cv4eLiIjQajfD29hZ9+/YVSUlJ0v3KllXdvVyqbHnencsW71wKd+7cOQGgwtud9/vtt9/Ea6+9Jtzc3IS1tbVo1qyZ6NGjh9iwYYPJ/spb6laepUuXiqeeekoam6+vr5g4caIoLCyU9cvPzxexsbHC09NTWFtbCzc3NxERESGWLVsm62fOMf3ggw9Es2bNhEajEeHh4eLw4cMmy+iEEKKkpETMmzdPBAYGCo1GI5ydnUVQUJBISEiQ1QlAxMbGmoyxvKWHWVlZ4rXXXhNNmzYVGo1GtGjRQsTGxori4mKpz5UrV0R8fLzw8/MTarVaNGnSRISFhYn3339flJSU3PN5ralxJCUliQ4dOgi1Wi18fX3FihUrxIQJE4SNjY2s3+nTp8VTTz0lbG1tBQDpcar6uk1KShIvvvii8PDwEGq1Wnh4eIh//vOfJssla4JKiIc0009Ekr179+If//gHkpOTFXHFPKXr2bMnTp48Ka2Kqm8a/Bw3ESnb3df9yczMxA8//FCv/4Ns8HPcRKRsLVq0wKBBg9CiRQtkZWXhs88+g1qtrnAZY33A4CYiRevWrRu++uor5OXlQaPRIDQ0FHPmzEHLli1ru7SHhnPcREQKwzluIiKFYXATESmMIue4jUYjzp8/DwcHh2r5GCsRUW0TQuDKlSvw8PCo9KJVigzu8+fPm1yQhoioPsjJyan0G5sUGdxlH0POycmp9DKhRERKYDAY4OnpKeXbvSgyuMumR7RaLYObiOqVqkz/8s1JIiKFYXATESkMg5uISGEY3ERECsPgJiJSGEWuKqH6p3ncVpO2P+ZW/vVvRA0Rz7iJiBSGwU1EpDAMbiIihWFwExEpDIObiEhhGNxERArD4CYiUhgGNxGRwjC4iYgUhsFNRKQw/Mg7USXK+zg+wI/kU+3hGTcRkcIwuImIFIbBTUSkMAxuIiKFYXATESkMg5uISGEY3ERECsPgJiJSGAY3EZHCMLiJiBSGwU1EpDAMbiIihWFwExEpDIObiEhhzA7u3NxcvPrqq3BxcYGtrS3atm2Lw4cPS9uFEJg2bRrc3d1ha2uLyMhIZGZmyh7j8uXLGDBgALRaLZycnDBkyBBcvXr1wUdDRNQAmBXcf//9N8LDw2FtbY1t27bhv//9Lz744AM4OztLfebPn4+FCxdiyZIlSEtLg729PaKionDjxg2pz4ABA3Dy5Ens2rULW7Zswb59+zB8+PDqGxURUT1m1hcpzJs3D56enli5cqXU5uPjI/1bCIEFCxZgypQpePHFFwEAX3zxBXQ6HTZu3Ij+/fvj1KlT2L59Ow4dOoTg4GAAwCeffILnnnsO77//Pjw8PKpjXERE9ZZZZ9ybNm1CcHAwXn75Zbi6uqJDhw5Yvny5tP3cuXPIy8tDZGSk1Obo6IiQkBCkpqYCAFJTU+Hk5CSFNgBERkbCwsICaWlpDzoeIqJ6z6zg/v333/HZZ5+hZcuW2LFjB0aMGIHRo0dj1apVAIC8vDwAgE6nk91Pp9NJ2/Ly8uDq6irbbmVlhcaNG0t97lZcXAyDwSC7ERE1VGZNlRiNRgQHB2POnDkAgA4dOiA9PR1LlixBTEzMQykQABITE5GQkPDQHp+ISEnMOuN2d3dHQECArK1169bIzs4GALi5uQEA8vPzZX3y8/OlbW5ubigoKJBtv3XrFi5fviz1uVt8fDwKCwulW05OjjllExHVK2YFd3h4ODIyMmRtZ86cgbe3N4Dbb1S6ubkhKSlJ2m4wGJCWlobQ0FAAQGhoKPR6PY4cOSL12bNnD4xGI0JCQsrdr0ajgVarld2IiBoqs6ZKxo0bh7CwMMyZMwd9+/bFzz//jGXLlmHZsmUAAJVKhbFjx+Ldd99Fy5Yt4ePjg6lTp8LDwwM9e/YEcPsMvVu3bhg2bBiWLFmCmzdvYtSoUejfvz9XlBARVYFZwf3EE0/gu+++Q3x8PGbOnAkfHx8sWLAAAwYMkPq8/fbbKCoqwvDhw6HX69G5c2ds374dNjY2Up+1a9di1KhRiIiIgIWFBXr37o2FCxdW36iIiOoxlRBC1HYR5jIYDHB0dERhYSGnTeqJ5nFbTdr+mBtdC5WYKq82oO7UR/WDObnGa5UQESkMg5uISGEY3ERECsPgJiJSGAY3EZHCMLiJiBSGwU1EpDAMbiIihTHrk5NU/9TlD74QUfl4xk1EpDAMbiIihWFwExEpDIObiEhhGNxERArD4CYiUhgGNxGRwjC4iYgUhsFNRKQwDG4iIoVhcBMRKQyDm4hIYXiRKSKF4LfNUxmecRMRKQyDm4hIYRjcREQKw+AmIlIYBjcRkcIwuImIFOaBgnvu3LlQqVQYO3as1Hbjxg3ExsbCxcUFjRo1Qu/evZGfny+7X3Z2NqKjo2FnZwdXV1dMnDgRt27depBSiIgajPsO7kOHDmHp0qVo166drH3cuHHYvHkz1q9fj5SUFJw/fx69evWStpeWliI6OholJSU4ePAgVq1ahc8//xzTpk27/1EQETUg9xXcV69exYABA7B8+XI4OztL7YWFhfj3v/+NDz/8EM888wyCgoKwcuVKHDx4ED/99BMAYOfOnfjvf/+LNWvW4LHHHkP37t0xa9YsLF68GCUlJdUzKiKieuy+gjs2NhbR0dGIjIyUtR85cgQ3b96Utfv7+8PLywupqakAgNTUVLRt2xY6nU7qExUVBYPBgJMnT5a7v+LiYhgMBtmNiKihMvsj719//TWOHj2KQ4cOmWzLy8uDWq2Gk5OTrF2n0yEvL0/qc2dol20v21aexMREJCQkmFsqEVG9ZNYZd05ODsaMGYO1a9fCxsbmYdVkIj4+HoWFhdItJyenxvZNRFTXmBXcR44cQUFBAR5//HFYWVnBysoKKSkpWLhwIaysrKDT6VBSUgK9Xi+7X35+Ptzc3AAAbm5uJqtMyn4u63M3jUYDrVYruxERNVRmTZVERETgxIkTsrbXX38d/v7+mDRpEjw9PWFtbY2kpCT07t0bAJCRkYHs7GyEhoYCAEJDQzF79mwUFBTA1dUVALBr1y5otVoEBARUx5iogSrv6nm8ch7VR2YFt4ODA9q0aSNrs7e3h4uLi9Q+ZMgQjB8/Ho0bN4ZWq8Vbb72F0NBQdOrUCQDQtWtXBAQEYODAgZg/fz7y8vIwZcoUxMbGQqPRVNOwiIjqr2q/HvdHH30ECwsL9O7dG8XFxYiKisKnn34qbbe0tMSWLVswYsQIhIaGwt7eHjExMZg5c2Z1l0JEVC89cHDv3btX9rONjQ0WL16MxYsXV3gfb29v/PDDDw+6a6rDKrroPxE9OH4DDtVZD/MbXzgfTkrG4KZ6jWf+VB8xuIn+P4Y8KQWDmxSHAUsNHYOb6D5xnpxqC79IgYhIYRjcREQKw+AmIlIYznHTA+ObhUQ1i2fcREQKwzNuqjKeWRPVDTzjJiJSGAY3EZHCcKqEqBo9zAtjEZVhcBPVAAY6VSdOlRARKQyDm4hIYThVQia47K/m8EJVdD94xk1EpDAMbiIihWFwExEpDOe4ieoYvsdAlWFwNxAMA6L6g1MlREQKw+AmIlIYBjcRkcIwuImIFMas4E5MTMQTTzwBBwcHuLq6omfPnsjIyJD1uXHjBmJjY+Hi4oJGjRqhd+/eyM/Pl/XJzs5GdHQ07Ozs4OrqiokTJ+LWrVsPPhoiogbArOBOSUlBbGwsfvrpJ+zatQs3b95E165dUVRUJPUZN24cNm/ejPXr1yMlJQXnz59Hr169pO2lpaWIjo5GSUkJDh48iFWrVuHzzz/HtGnTqm9URET1mEoIIe73zhcvXoSrqytSUlLw1FNPobCwEE2bNsWXX36JPn36AABOnz6N1q1bIzU1FZ06dcK2bdvQo0cPnD9/HjqdDgCwZMkSTJo0CRcvXoRara50vwaDAY6OjigsLIRWq73f8hsULgesv3htk/rBnFx7oHXchYWFAIDGjRsDAI4cOYKbN28iMjJS6uPv7w8vLy8puFNTU9G2bVsptAEgKioKI0aMwMmTJ9GhQ4cHKYnAkCaq7+47uI1GI8aOHYvw8HC0adMGAJCXlwe1Wg0nJydZX51Oh7y8PKnPnaFdtr1sW3mKi4tRXFws/WwwGO63bCIixbvvVSWxsbFIT0/H119/XZ31lCsxMRGOjo7SzdPT86Hvk4iorrqv4B41ahS2bNmC5ORkPPLII1K7m5sbSkpKoNfrZf3z8/Ph5uYm9bl7lUnZz2V97hYfH4/CwkLplpOTcz9lExHVC2YFtxACo0aNwnfffYc9e/bAx8dHtj0oKAjW1tZISkqS2jIyMpCdnY3Q0FAAQGhoKE6cOIGCggKpz65du6DVahEQEFDufjUaDbRarexGRNRQmTXHHRsbiy+//BLff/89HBwcpDlpR0dH2NrawtHREUOGDMH48ePRuHFjaLVavPXWWwgNDUWnTp0AAF27dkVAQAAGDhyI+fPnIy8vD1OmTEFsbCw0Gk31j5CIqJ4xazmgSqUqt33lypUYNGgQgNsfwJkwYQK++uorFBcXIyoqCp9++qlsGiQrKwsjRozA3r17YW9vj5iYGMydOxdWVlX7f4TLAe+Nq0qISwSVx5xce6B13LWFwX0bA5oqwuBWHnNyjdcqISJSGH6RQi2p6GyZZ0pEVBkGNxGZpbyTDp5w1CwGN1E9xL/o6jfOcRMRKQzPuBWAq0eI6E484yYiUhgGNxGRwjC4iYgUhnPcdQzns+lhMuf1xRUodReDm4jKxZOIuotTJURECsPgJiJSGAY3EZHCMLiJiBSGwU1EpDAMbiIihWFwExEpDIObiEhhGNxERArD4CYiUhh+5P0h48eGiai68YybiEhheMZNRA+M33FZsxjc1YjTIkRyDPSHg1MlREQKwzNuIqpx5Z2J8yy86hjc94FTIkRUm2otuBcvXoz33nsPeXl5aN++PT755BN07NixtsohojqKX7dmqlaC+z//+Q/Gjx+PJUuWICQkBAsWLEBUVBQyMjLg6upaGyURUS3jX7JVpxJCiJreaUhICJ544gksWrQIAGA0GuHp6Ym33noLcXFxld7fYDDA0dERhYWF0Gq11VKTOXNufIERKYsSzsTNybUaP+MuKSnBkSNHEB8fL7VZWFggMjISqamp5d6nuLgYxcXF0s+FhYUAbg/UXG2m76hyX69x681+fCKqe8r7XU5PiKqFSipWlmdVOZeu8eC+dOkSSktLodPpZO06nQ6nT58u9z6JiYlISEgwaff09HwoNRJR/ee4oLYrKN+VK1fg6Oh4zz6KWFUSHx+P8ePHSz8bjUZcvnwZLi4uUKlUNVKDwWCAp6cncnJyqm16piYpvX6AY6gLlF4/UHfHIITAlStX4OHhUWnfGg/uJk2awNLSEvn5+bL2/Px8uLm5lXsfjUYDjUYja3NycnpYJd6TVqutUwfbXEqvH+AY6gKl1w/UzTFUdqZdpsY/OalWqxEUFISkpCSpzWg0IikpCaGhoTVdDhGR4tTKVMn48eMRExOD4OBgdOzYEQsWLEBRURFef/312iiHiEhRaiW4+/Xrh4sXL2LatGnIy8vDY489hu3bt5u8YVmXaDQaTJ8+3WTKRimUXj/AMdQFSq8fqB9jqJV13EREdP94dUAiIoVhcBMRKQyDm4hIYRjcREQKw+C+w2effYZ27dpJC/NDQ0Oxbds2afuNGzcQGxsLFxcXNGrUCL179zb5IFFdMnfuXKhUKowdO1Zqq+tjmDFjBlQqlezm7+8vba/r9ZfJzc3Fq6++ChcXF9ja2qJt27Y4fPiwtF0IgWnTpsHd3R22traIjIxEZmZmLVb8P82bNzc5BiqVCrGxsQCUcQxKS0sxdepU+Pj4wNbWFr6+vpg1a5bsOiB1+RhUSpBk06ZNYuvWreLMmTMiIyNDTJ48WVhbW4v09HQhhBBvvvmm8PT0FElJSeLw4cOiU6dOIiwsrJarLt/PP/8smjdvLtq1ayfGjBkjtdf1MUyfPl0EBgaKCxcuSLeLFy9K2+t6/UIIcfnyZeHt7S0GDRok0tLSxO+//y527Nghzp49K/WZO3eucHR0FBs3bhTHjx8XL7zwgvDx8RHXr1+vxcpvKygokD3/u3btEgBEcnKyEEIZx2D27NnCxcVFbNmyRZw7d06sX79eNGrUSHz88cdSn7p8DCrD4K6Es7OzWLFihdDr9cLa2lqsX79e2nbq1CkBQKSmptZihaauXLkiWrZsKXbt2iW6dOkiBbcSxjB9+nTRvn37crcpoX4hhJg0aZLo3LlzhduNRqNwc3MT7733ntSm1+uFRqMRX331VU2UaJYxY8YIX19fYTQaFXMMoqOjxeDBg2VtvXr1EgMGDBBCKO8Y3I1TJRUoLS3F119/jaKiIoSGhuLIkSO4efMmIiMjpT7+/v7w8vKq8HK0tSU2NhbR0dGyWgEoZgyZmZnw8PBAixYtMGDAAGRnZwNQTv2bNm1CcHAwXn75Zbi6uqJDhw5Yvny5tP3cuXPIy8uTjcPR0REhISF1ahzA7cswr1mzBoMHD4ZKpVLMMQgLC0NSUhLOnDkDADh+/Dj279+P7t27A1DWMSiPIq4OWJNOnDiB0NBQ3LhxA40aNcJ3332HgIAAHDt2DGq12uTiVjqdDnl5ebVTbDm+/vprHD16FIcOHTLZlpeXV+fHEBISgs8//xytWrXChQsXkJCQgCeffBLp6emKqB8Afv/9d3z22WcYP348Jk+ejEOHDmH06NFQq9WIiYmRai3v0sZ1aRwAsHHjRuj1egwaNAiAMl5DABAXFweDwQB/f39YWlqitLQUs2fPxoABAwBAUcegPAzuu7Rq1QrHjh1DYWEhNmzYgJiYGKSkpNR2WVWSk5ODMWPGYNeuXbCxsantcu5L2RkRALRr1w4hISHw9vbGunXrYGtrW4uVVZ3RaERwcDDmzJkDAOjQoQPS09OxZMkSxMTE1HJ15vn3v/+N7t27V+lSo3XJunXrsHbtWnz55ZcIDAzEsWPHMHbsWHh4eCjuGJSHUyV3UavV8PPzQ1BQEBITE9G+fXt8/PHHcHNzQ0lJCfR6vaz/vS5HW9OOHDmCgoICPP7447CysoKVlRVSUlKwcOFCWFlZQafT1fkx3M3JyQmPPvoozp49q4hjAADu7u4ICAiQtbVu3Vqa8imr1ZxLG9eGrKws7N69G0OHDpXalHIMJk6ciLi4OPTv3x9t27bFwIEDMW7cOCQmJgJQzjGoCIO7EkajEcXFxQgKCoK1tbXscrQZGRnIzs6uM5ejjYiIwIkTJ3Ds2DHpFhwcjAEDBkj/rutjuNvVq1fx22+/wd3dXRHHAADCw8ORkZEhaztz5gy8vb0BAD4+PnBzc5ONw2AwIC0trU6NY+XKlXB1dUV09P++r1Epx+DatWuwsJDHm6WlJYxGIwDlHIMK1fa7o3VJXFycSElJEefOnRO//vqriIuLEyqVSuzcuVMIcXsZlJeXl9izZ484fPiwCA0NFaGhobVc9b3duapEiLo/hgkTJoi9e/eKc+fOiQMHDojIyEjRpEkTUVBQIISo+/ULcXspppWVlZg9e7bIzMwUa9euFXZ2dmLNmjVSn7lz5wonJyfx/fffi19//VW8+OKLdWopWmlpqfDy8hKTJk0y2aaEYxATEyOaNWsmLQf89ttvRZMmTcTbb78t9anrx+BeGNx3GDx4sPD29hZqtVo0bdpURERESKEthBDXr18XI0eOFM7OzsLOzk689NJL4sKFC7VYceXuDu66PoZ+/foJd3d3oVarRbNmzUS/fv1k65/rev1lNm/eLNq0aSM0Go3w9/cXy5Ytk203Go1i6tSpQqfTCY1GIyIiIkRGRkYtVWtqx44dAkC5NSnhGBgMBjFmzBjh5eUlbGxsRIsWLcQ777wjiouLpT51/RjcCy/rSkSkMJzjJiJSGAY3EZHCMLiJiBSGwU1EpDAMbiIihWFwExEpDIObiEhhGNxERArD4CYiUhgGNxGRwjC4iYgUhsFNRKQw/w+6Hwoc9gvjmgAAAABJRU5ErkJggg==", 572 | "text/plain": [ 573 | "
" 574 | ] 575 | }, 576 | "metadata": {}, 577 | "output_type": "display_data" 578 | } 579 | ], 580 | "source": [ 581 | "tokenized_lengths = [len(t) for t in tok.encode(test_texts)['tokens']]\n", 582 | "plt.figure(figsize=(4, 2))\n", 583 | "plt.hist(tokenized_lengths, bins=50)\n", 584 | "plt.title('Word tokenizer sequence lengths')\n", 585 | "plt.show()" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 23, 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "\n", 598 | "\n", 599 | "\n", 600 | "Tokenized\n", 601 | "['[BOS]', 'kelly', 'was', 'at', 'home', ',', 'trying', 'to', 'sleep', '.', 'suddenly', ',', 'she', 'heard', 'foot', '##st', '##ep', '##s', 'in', 'her', 'kitchen', '.', 'she', 'grabbed', 'a', 'gun', 'and', 'stood', 'at', 'the', 'top', 'of', 'the', 'stairs', '.', 'she', 'warned', 'who', '##ever', 'it', 'was', 'that', 'she', 'was', 'ar', '##med', '.', 'she', 'heard', 'them', 'run', 'out', 'of', 'the', 'house', 'and', 'then', 'called', 'police', '.', '[EOS]']\n", 602 | "\n", 603 | "['[BOS]', 'i', 'bought', 'a', '19', '##6', '##9', 'm', '##erc', '##ury', 'mo', '##nt', '##e', '##g', '##o', 'with', 'a', 'loose', 'front', 'seat', '.', 'the', 'seat', 'was', 'loose', 'because', 'the', 'car', \"'\", 's', 'floor', 'had', 'r', '##usted', 'through', '.', 'i', 'removed', 'the', 'seat', 'and', 'repair', '##ed', 'the', 'floor', 'with', 'pieces', 'of', 'she', '##et', 'metal', '.', 'my', 'repair', 'held', 'the', 'seat', 'fir', '##m', '##ly', 'in', 'place', 'after', 'i', 're', '##in', '##st', '##all', '##ed', 'it', '.', 'the', 'car', 'then', 'successfully', 'passed', 'the', 'sa', '##fet', '##y', 'insp', '##ection', '.', '[EOS]']\n", 604 | "\n", 605 | "Detokenized\n", 606 | "kelly was at home, trying to sleep. suddenly, she heard footsteps in her kitchen. she grabbed a gun and stood at the top of the stairs. she warned whoever it was that she was armed. she heard them run out of the house and then called police.\n", 607 | "\n", 608 | "i bought a 1969 mercury montego with a loose front seat. the seat was loose because the car ' s floor had rusted through. i removed the seat and repaired the floor with pieces of sheet metal. my repair held the seat firmly in place after i reinstalled it. the car then successfully passed the safety inspection.\n", 609 | "\n", 610 | "Vocab size: 4096\n" 611 | ] 612 | } 613 | ], 614 | "source": [ 615 | "tok = BPETokenizer(train_texts, vocab_size=4096)\n", 616 | "\n", 617 | "tokenized = tok.encode(train_texts[idxs])\n", 618 | "print('Tokenized')\n", 619 | "print_texts(tokenized['tokens'])\n", 620 | "\n", 621 | "print('Detokenized')\n", 622 | "print_texts(tok.decode(tokenized['input_ids']))\n", 623 | "\n", 624 | "print(f'Vocab size: {len(tok.token2id)}')" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 24, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "data": { 634 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAADcCAYAAABpsPoeAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAKOlJREFUeJzt3XtYVNXiPvCXiwzXGQRlEJWLSCGI4QFB1EQFQULTRP3q8ShyTEsxJc17XtOwzMRKpTodMW+dLLHUUAm81IlQOerxkkSFwlEHvMTFGyqzfn/4sH8OMyijCGx5P88zz+Osvfbeay9nXvasvWaPiRBCgIiIZMe0oRtARESPhgFORCRTDHAiIpligBMRyRQDnIhIphjgREQyxQAnIpIpBjgRkUwxwImIZIoBTnVi//79MDExwVdffdUg+92/f3+97pfqz9mzZ2FiYoKUlJSGbspDubu7o3///vW2P1kGeEpKCkxMTHQeTk5O6N27N9LS0vTq31/P1NQULi4uiIiI0HvTu7u762236tGvX7+Htmvp0qV48cUXoVarYWJigoULF9ZY9/z58xg2bBjs7e2hVCoxcOBA/PHHHzp1CgsLsWjRIgQFBaF58+Zo0aIFevXqhe+///6hbRk3bhxMTExq/WJas2aNLN4gRA3t9OnTWLhwIc6ePdvQTYF5QzfgcSxevBgeHh4QQqCoqAgpKSl44YUXsGPHDr3g6tu3L0aPHg0hBPLz87FmzRr06dMHu3btQlRUlFTP398f06ZN09uXi4vLQ9vz5ptvwtnZGZ07d8aePXtqrHft2jX07t0bpaWlmDNnDpo1a4aVK1ciNDQUx44dg6OjIwDgm2++wTvvvINBgwYhNjYWd+/exeeff46+ffvin//8J+Li4gxu/8iRI0hJSYGlpeVD21xlzZo1aNGiBcaMGVPrdRqDnj174ubNm7CwsGjoplATcfr0aSxatAi9evWCu7t7g7ZF1gEeFRWFwMBA6fnYsWOhVquxZcsWvQB/5pln8Le//U16/tJLL6FTp05ISkrSCfDWrVvr1DNGfn4+3N3dcfnyZbRs2bLGemvWrEFeXh4OHTqELl26SMfSsWNHrFixAm+//TYAoHfv3igoKECLFi2kdV999VX4+/tj/vz5BgNcCIHJkydj9OjRyMjIeKTjkBNTU1Oj/lDVhRs3bsDa2rpe90lkiCyHUGpib28PKysrmJs//O+Sn58fWrRogfz8/Drbf23/Gn/11Vfo0qWLFN4A4O3tjbCwMHz55ZdSma+vr054A4BCocALL7yA//3vfygvL9fb9oYNG3Dy5EksXbrUqHafOnUKBw4ckIaMevXqJS3/448/MHToUDg4OMDa2hpdu3bFrl27HrrdiooK9O/fHyqVCj/99BMAQKvVIikpCb6+vrC0tIRarcYrr7yCP//8U69N/fv3x48//oigoCBYWlqiXbt2+Pzzz3XqVR8DNzS8ZuiYAGDjxo0ICAiAlZUVHBwcMHz4cBQWFurU6dWrFzp27IicnBz07NkT1tbWmDNnTo3HrNFoEBcXhzZt2kChUKBVq1YYOHCg3sfttLQ0PP/887CxsYGdnR2io6Nx6tQpve1t374dHTt2hKWlJTp27IjU1FSMGTNG57VW03WAmsaOz5w5gyFDhsDBwQGWlpYIDAzEt99+q1Onqh///e9/Y+rUqWjZsiVsbGzw0ksv4dKlS3rtTEtLQ2hoKOzs7KBUKtGlSxds3rxZp052djb69esHlUoFa2trhIaG4t///neNffkwdX0cWq0WCxcuhIuLC6ytrdG7d2+cPn0a7u7u0ifTlJQUDB06FMC9E6yq11b1vn/Y6/bOnTtYtGgRvLy8YGlpCUdHR/To0QPp6elG9YGsA7y0tBSXL1/GpUuXcOrUKUyYMAHXrl2r1Rn0n3/+iT///FMarqhy584dXL58We9x8+bNOmmzVqvFf//7X51PDlWCgoLw+++/Gwzm+2k0GlhbW+udBZaXl2PmzJmYM2cOnJ2da92mpKQktGnTBt7e3tiwYQM2bNiAuXPnAgCKiorQrVs37NmzBxMnTsTSpUtx69YtvPjii0hNTa1xmzdv3sSAAQPw008/4fvvv0e3bt0AAK+88gqmT5+O7t27Y9WqVYiLi8OmTZsQGRmJO3fu6Gzjt99+w5AhQ9C3b1+sWLECzZs3x5gxYwwGXZWePXtKx1D1WLJkCQDAyclJqrd06VKMHj0aXl5eeP/995GQkICMjAz07NkTJSUlOtu8cuUKoqKi4O/vj6SkJPTu3bvG/cfExCA1NRVxcXFYs2YNJk+ejPLychQUFEh1NmzYgOjoaNja2uKdd97BvHnzcPr0afTo0UMn6Pfu3YuYmBiYmJggMTERgwYNQlxcHI4cOVLj/h/m1KlT6Nq1K3755RfMmjULK1asgI2NDQYNGmTw//O1117D8ePHsWDBAkyYMAE7duzApEmTdOqkpKQgOjoaV69exezZs7Fs2TL4+/tj9+7dUp3MzEz07NkTZWVlWLBgAd5++22UlJSgT58+OHToUKM4jtmzZ2PRokUIDAzE8uXL4eXlhcjISFy/fl2q07NnT0yePBkAMGfOHOk11qFDB6lObV63CxcuxKJFi9C7d2989NFHmDt3LlxdXfGf//zHuI4QMrRu3ToBQO+hUChESkqKXn0AYuzYseLSpUuiuLhYZGdni7CwMAFArFixQqrn5uZmcLsARGJiYq3bd+nSJQFALFiwoMZlixcv1lu2evVqAUCcOXOmxm3n5eUJS0tLMWrUKL1lb7zxhvDw8BC3bt2Sjic6OrpWbfb19RWhoaF65QkJCQKA+OGHH6Sy8vJy4eHhIdzd3UVlZaUQQoh9+/YJAGLr1q2ivLxchIaGihYtWoijR49K6/3www8CgNi0aZPOPnbv3q1XXvV/cfDgQamsuLhYKBQKMW3aNKmsar/79u0zeFw3b94UAQEBwsXFRVy8eFEIIcTZs2eFmZmZWLp0qU7dEydOCHNzc53y0NBQAUAkJyfX0HP/359//ikAiOXLl9dYp7y8XNjb24tx48bplGs0GqFSqXTK/f39RatWrURJSYlUtnfvXgFAuLm5PbQP8vPzBQCxbt06qSwsLEz4+flJrxEhhNBqtaJbt27Cy8tLKqt6j4WHhwutViuVv/7668LMzExqU0lJibCzsxPBwcHi5s2bOvuvWk+r1QovLy8RGRmps60bN24IDw8P0bdv3xr7q76OQ6PRCHNzczFo0CCdfS9cuFAAELGxsVLZ1q1ba3zN1fZ1+9xzz9X6vfkgsj4DX716NdLT05Geno6NGzeid+/eePnll7Ft2za9up999hlatmwJJycnBAcHSx+pEhISdOoFBwdL27z/MWLEiDppc9WZvEKh0FtWNZZb09n+jRs3MHToUFhZWWHZsmU6y3799VesWrUKy5cvN7jtR/Xdd98hKCgIPXr0kMpsbW0xfvx4nD17FqdPn9apX1paioiICJw5cwb79++Hv7+/tGzr1q1QqVTo27evzqebgIAA2NraYt++fTrb8vHxwfPPPy89b9myJZ599lm92ToPMnHiRJw4cQJff/219Klk27Zt0Gq1GDZsmE47nJ2d4eXlpdcOhUJR4wXj+1lZWcHCwgL79+/XGxKqkp6ejpKSEowYMUJn32ZmZggODpb2ffHiRRw7dgyxsbFQqVTS+n379oWPj0+tj/9+V69eRWZmJoYNG4by8nJp31euXEFkZCTy8vJw/vx5nXXGjx8PExMT6fnzzz+PyspKnDt3Tjqe8vJyzJo1S+9aRNV6x44dQ15eHv7617/iypUr0n6vX7+OsLAwHDx4EFqttkGPIyMjA3fv3sXEiRN11nvttddq3a4qtXnd2tvb49SpU8jLyzN6+/eT9UXMoKAgnaGIESNGoHPnzpg0aRL69++vMzNh4MCBmDRpEkxMTGBnZwdfX1/Y2NjobbNFixYIDw9/Ym22srICcG98uLpbt27p1LlfZWUlhg8fjtOnTyMtLU1vVsyUKVPQrVs3xMTE1Gl7z507h+DgYL3yqo+M586dQ8eOHaXyhIQE3Lp1C0ePHoWvr6/OOnl5eSgtLdUZyrhfcXGxznNXV1e9Os2bN68xHKv7+OOPsW7dOnz88cfo2rWrTjuEEPDy8jK4XrNmzXSet27dulazXBQKBd555x1MmzYNarUaXbt2Rf/+/TF69Gjpj0fVG7ZPnz4Gt6FUKgFAChZDbXz22WeN/6iNex/thRCYN28e5s2bZ7BOcXExWrduLT2v/n/QvHlzAJD+D37//XcA0HkNVFd1zLGxsTXWKS0tlbb9ME/iOKr6u3379jr1HBwcat2umvZVtb/7X7eLFy/GwIED8cwzz6Bjx47o168fRo0ahU6dOhm1L1kHeHWmpqbo3bs3Vq1ahby8PJ0AadOmzRMN5tpycHCAQqHAxYsX9ZZVlRmasjhu3Djs3LkTmzZt0nvzZ2ZmYvfu3di2bZvOGOrdu3dx8+ZNnD17Fg4ODlI4PEkDBw7EF198gWXLluHzzz+Hqen//5Cn1Wrh5OSETZs2GVy3+swdMzMzg/VELX4F8NChQ5gyZQpefvlljB8/XmeZVquFiYkJ0tLSDO7D1tZW57mhP6g1SUhIwIABA7B9+3bs2bMH8+bNQ2JiIjIzM9G5c2fpTHPDhg0Gr1PU5gJ8dfefWd6vsrJS53nVvt944w1ERkYaXKd6gD3O/0H1/S5fvlznE9n9qvd5bbZX38dRW7XZV8+ePfH777/jm2++wd69e/GPf/wDK1euRHJyMl5++eVa7+upCnDgXmgB9+ZaN0ampqbw8/MzeCEqOzsb7dq1g52dnU759OnTsW7dOiQlJRkcyqm6QDZ48GC9ZefPn4eHhwdWrlypN1x0v5pCwM3NDbm5uXrlZ86ckZbfb9CgQYiIiMCYMWNgZ2eHtWvXSss8PT3x/fffo3v37kaForEuXbqEIUOGwN/fH6tXr9Zb7unpCSEEPDw88Mwzz9T5/j09PTFt2jRMmzYNeXl58Pf3x4oVK7Bx40Z4enoCuHdB9UEnFFX9augjdvX/j6ozxOoXX6vOKqu0a9cOwL1PGHV1MlN1PCdPntQLzep1lEplnez3SRxHVX//9ttv8PDwkMqvXLmi94mvpveKsRwcHBAXF4e4uDhcu3YNPXv2xMKFC40KcFmPgVd3584d7N27FxYWFjpXhRubIUOG4PDhwzohnpubi8zMTGmKUpXly5fjvffew5w5czBlyhSD2+vTpw9SU1P1Hi1btkRgYCBSU1MxYMCAB7bJxsZGLwAA4IUXXsChQ4eQlZUllV2/fh2ffPIJ3N3dDY7Hjh49Gh988AGSk5Mxc+ZMqXzYsGGorKzEW2+9pbfO3bt3De7fWFVDTbdv38bXX39tcOhj8ODBMDMzw6JFi/TOwIQQuHLlyiPt+8aNG9IwWBVPT0/Y2dlJQ2aRkZFQKpV4++239WbdAJCmtrVq1Qr+/v5Yv349SktLpeXp6el61x3c3NxgZmaGgwcP6pSvWbNG57mTkxN69eqFjz/+2OAnQEPTAx8mIiICdnZ2SExM1Dv2qr4NCAiAp6cn3nvvPYMnVsbu90kcR1hYGMzNzXVOOADgo48+0qtbNfT6OK/X6q8xW1tbtG/f3uDQ6oPI+gw8LS1NOhMsLi7G5s2bkZeXh1mzZj3ycMH58+exceNGvXJbW1sMGjTogetu2LAB586dw40bNwAABw8elKawjRo1SvorP3HiRHz66aeIjo7GG2+8gWbNmuH999+HWq3W+RZoamoqZsyYAS8vL3To0EGvXX379oVarYarq6vBcbeEhASo1eqHthu49yZbu3YtlixZgvbt28PJyQl9+vTBrFmzsGXLFkRFRWHy5MlwcHDA+vXrkZ+fj6+//lpniOR+kyZNQllZGebOnQuVSoU5c+YgNDQUr7zyChITE3Hs2DFERESgWbNmyMvLw9atW7Fq1SoMGTLkoW19kOTkZGRmZuLVV1/VuxipVqvRt29feHp6YsmSJZg9ezbOnj2LQYMGwc7ODvn5+UhNTcX48ePxxhtvGL3vX3/9FWFhYRg2bBh8fHxgbm6O1NRUFBUVYfjw4QDunYWuXbsWo0aNwl/+8hcMHz4cLVu2REFBAXbt2oXu3btLoZGYmIjo6Gj06NEDf//733H16lV8+OGH8PX11QlClUqFoUOH4sMPP4SJiQk8PT2xc+dOvWsKwL0L/z169ICfnx/GjRuHdu3aoaioCFlZWfjf//6H48ePG3XMSqUSK1euxMsvv4wuXbrgr3/9K5o3b47jx4/jxo0bWL9+PUxNTfGPf/wDUVFR8PX1RVxcHFq3bo3z589j3759UCqV2LFjh1H7revjUKvVmDJlClasWIEXX3wR/fr1w/Hjx5GWloYWLVronHX7+/vDzMwM77zzDkpLS6FQKNCnT58ar+0Y4uPjg169eiEgIAAODg44cuQIvvrqK72pjQ/12PNYGoChaYSWlpbC399frF27Vme6kBD3phHGx8c/dLsPmkZ4/7StmlRNOTP0qD7lqLCwUAwZMkQolUpha2sr+vfvL/Ly8nTqLFiwoMbtGdqmoeOp7VQljUYjoqOjhZ2dnQCgM6Xw999/F0OGDBH29vbC0tJSBAUFiZ07d+qsf/80wvvNmDFDABAfffSRVPbJJ5+IgIAAYWVlJezs7ISfn5+YMWOGuHDhwkPbHhoaqtO26lPoHtRn1adJfv3116JHjx7CxsZG2NjYCG9vbxEfHy9yc3N19ufr61urPrx8+bKIj48X3t7ewsbGRqhUKhEcHCy+/PJLvbr79u0TkZGRQqVSCUtLS+Hp6SnGjBkjjhw5otfGDh06CIVCIXx8fMS2bdtEbGys3uvx0qVLIiYmRlhbW4vmzZuLV155RZw8eVJv+p0Q9/4/R48eLZydnUWzZs1E69atRf/+/cVXX30l1al6jx0+fFiv3YZee99++63o1q2bsLKyEkqlUgQFBYktW7bo1Dl69KgYPHiwcHR0FAqFQri5uYlhw4aJjIyMB/aroWmET+I47t69K+bNmyecnZ2FlZWV6NOnj/jll1+Eo6OjePXVV3XW//TTT0W7du2EmZmZznZq+7pdsmSJCAoKEvb29sLKykp4e3uLpUuXitu3bz+wL6ozEeIJjOIT0RMzZswY7N+/v1HcTOlpV1JSgubNm2PJkiXSl9sak6dqDJyI6FEZ+v5FUlISAOjdhqGxkPUYOBFRXfnXv/4l3dHU1tYWP/74I7Zs2YKIiAh07969oZtnEAOciAhAp06dYG5ujnfffRdlZWXShc2qiQiNEcfAiYhkimPgREQyxQAnIpIpWY6Ba7VaXLhwAXZ2dnX2tVYiooYkhEB5eTlcXFxq/IJcdbIM8AsXLqBt27YN3QwiojpXWFiINm3a1KquLAO86mZPhYWF9XKHPSKiJ62srAxt27bVu5ndg8gywKuGTZRKJQOciJ4qxgwL8yImEZFMMcCJiGSKAU5EJFMMcCIimWKAExHJlCxnoVDT5j5rl8Hys8ui67klRA2LZ+BERDLFACcikikGOBGRTDHAiYhkigFORCRTDHAiIpligBMRyRTngVOjYGhuN+d1Ez2YUWfga9euRadOnaTbuIaEhCAtLU1afuvWLcTHx8PR0RG2traIiYlBUVGRzjYKCgoQHR0Na2trODk5Yfr06bh7927dHA0RURNiVIC3adMGy5YtQ05ODo4cOYI+ffpg4MCBOHXqFADg9ddfx44dO7B161YcOHAAFy5cwODBg6X1KysrER0djdu3b+Onn37C+vXrkZKSgvnz59ftURERNQEmQgjxOBtwcHDA8uXLMWTIELRs2RKbN2/GkCFDAABnzpxBhw4dkJWVha5duyItLQ39+/fHhQsXoFarAQDJycmYOXMmLl26BAsLi1rts6ysDCqVCqWlpfxBh6eEMUMo/Co9PY0eJdce+SJmZWUlvvjiC1y/fh0hISHIycnBnTt3EB4eLtXx9vaGq6srsrKyAABZWVnw8/OTwhsAIiMjUVZWJp3FExFR7Rh9EfPEiRMICQnBrVu3YGtri9TUVPj4+ODYsWOwsLCAvb29Tn21Wg2NRgMA0Gg0OuFdtbxqWU0qKipQUVEhPS8rKzO22VTPeJZM9OQZfQb+7LPP4tixY8jOzsaECRMQGxuL06dPP4m2SRITE6FSqaQHf5GeiOgRAtzCwgLt27dHQEAAEhMT8dxzz2HVqlVwdnbG7du3UVJSolO/qKgIzs7OAABnZ2e9WSlVz6vqGDJ79myUlpZKj8LCQmObTUT01HnsL/JotVpUVFQgICAAzZo1Q0ZGhrQsNzcXBQUFCAkJAQCEhITgxIkTKC4uluqkp6dDqVTCx8enxn0oFApp6iJ/iZ6I6B6jxsBnz56NqKgouLq6ory8HJs3b8b+/fuxZ88eqFQqjB07FlOnToWDgwOUSiVee+01hISEoGvXrgCAiIgI+Pj4YNSoUXj33Xeh0Wjw5ptvIj4+HgqF4okcIFFtcdye5MaoAC8uLsbo0aNx8eJFqFQqdOrUCXv27EHfvn0BACtXroSpqSliYmJQUVGByMhIrFmzRlrfzMwMO3fuxIQJExASEgIbGxvExsZi8eLFdXtURERNgFEB/tlnnz1wuaWlJVavXo3Vq1fXWMfNzQ3fffedMbslemQ8q6anGW9mRUQkUwxwIiKZYoATEckUA5yISKYY4EREMsUAJyKSKQY4EZFMMcCJiGSKAU5EJFP8UWOqNX6rkahx4Rk4EZFMMcCJiGSKAU5EJFMMcCIimWKAExHJFAOciEimGOBERDLFACcikil+kYfoERn6YhO/1ET1iWfgREQyxQAnIpIpBjgRkUwxwImIZIoBTkQkUwxwIiKZYoATEckUA5yISKYY4EREMsUAJyKSKQY4EZFM8V4oTRzv50EkXzwDJyKSKQY4EZFMGRXgiYmJ6NKlC+zs7ODk5IRBgwYhNzdXp86tW7cQHx8PR0dH2NraIiYmBkVFRTp1CgoKEB0dDWtrazg5OWH69Om4e/fu4x8NEVETYlSAHzhwAPHx8fj555+Rnp6OO3fuICIiAtevX5fqvP7669ixYwe2bt2KAwcO4MKFCxg8eLC0vLKyEtHR0bh9+zZ++uknrF+/HikpKZg/f37dHRURURNg1EXM3bt36zxPSUmBk5MTcnJy0LNnT5SWluKzzz7D5s2b0adPHwDAunXr0KFDB/z888/o2rUr9u7di9OnT+P777+HWq2Gv78/3nrrLcycORMLFy6EhYVF3R0dEdFT7LHGwEtLSwEADg4OAICcnBzcuXMH4eHhUh1vb2+4uroiKysLAJCVlQU/Pz+o1WqpTmRkJMrKynDq1KnHaQ4RUZPyyNMItVotEhIS0L17d3Ts2BEAoNFoYGFhAXt7e526arUaGo1GqnN/eFctr1pmSEVFBSoqKqTnZWVlj9psIqKnxiOfgcfHx+PkyZP44osv6rI9BiUmJkKlUkmPtm3bPvF9EhE1do8U4JMmTcLOnTuxb98+tGnTRip3dnbG7du3UVJSolO/qKgIzs7OUp3qs1KqnlfVqW727NkoLS2VHoWFhY/SbCKip4pRAS6EwKRJk5CamorMzEx4eHjoLA8ICECzZs2QkZEhleXm5qKgoAAhISEAgJCQEJw4cQLFxcVSnfT0dCiVSvj4+Bjcr0KhgFKp1HkQETV1Ro2Bx8fHY/Pmzfjmm29gZ2cnjVmrVCpYWVlBpVJh7NixmDp1KhwcHKBUKvHaa68hJCQEXbt2BQBERETAx8cHo0aNwrvvvguNRoM333wT8fHxUCgUdX+ERERPKaMCfO3atQCAXr166ZSvW7cOY8aMAQCsXLkSpqamiImJQUVFBSIjI7FmzRqprpmZGXbu3IkJEyYgJCQENjY2iI2NxeLFix/vSIiImhijAlwI8dA6lpaWWL16NVavXl1jHTc3N3z33XfG7JpI1gzdNAzgjcPo8fBeKEREMsXbydJTo6azXKKnFc/AiYhkigFORCRTDHAiIpligBMRyRQDnIhIpjgL5SnEHyomahp4Bk5EJFMMcCIimWKAExHJFMfAqdHiNyuJHoxn4EREMsUAJyKSKQY4EZFMcQycHpsxY9Uc1yaqOwxwapL4h4SeBhxCISKSKQY4EZFMMcCJiGSKAU5EJFMMcCIimeIsFKKHqO9pkjVtg7cEpup4Bk5EJFMMcCIimWKAExHJFMfASQ+/pUgkDwxwGeBFLSIyhEMoREQyxQAnIpIpBjgRkUxxDJxI5gxdI+H1kaaBZ+BERDLFACcikimjA/zgwYMYMGAAXFxcYGJigu3bt+ssF0Jg/vz5aNWqFaysrBAeHo68vDydOlevXsXIkSOhVCphb2+PsWPH4tq1a491IERETY3RAX79+nU899xzWL16tcHl7777Lj744AMkJycjOzsbNjY2iIyMxK1bt6Q6I0eOxKlTp5Ceno6dO3fi4MGDGD9+/KMfBRFRE2T0RcyoqChERUUZXCaEQFJSEt58800MHDgQAPD5559DrVZj+/btGD58OH755Rfs3r0bhw8fRmBgIADgww8/xAsvvID33nsPLi4uj3E4RERNR52Ogefn50Oj0SA8PFwqU6lUCA4ORlZWFgAgKysL9vb2UngDQHh4OExNTZGdnW1wuxUVFSgrK9N5EBE1dXUa4BqNBgCgVqt1ytVqtbRMo9HAyclJZ7m5uTkcHBykOtUlJiZCpVJJj7Zt29Zls4mIZEkWs1Bmz56N0tJS6VFYWNjQTSIianB1+kUeZ2dnAEBRURFatWollRcVFcHf31+qU1xcrLPe3bt3cfXqVWn96hQKBRQKRV02tcnhHQaJnj51egbu4eEBZ2dnZGRkSGVlZWXIzs5GSEgIACAkJAQlJSXIycmR6mRmZkKr1SI4OLgum0NE9FQz+gz82rVr+O2336Tn+fn5OHbsGBwcHODq6oqEhAQsWbIEXl5e8PDwwLx58+Di4oJBgwYBADp06IB+/fph3LhxSE5Oxp07dzBp0iQMHz6cM1CIiIxgdIAfOXIEvXv3lp5PnToVABAbG4uUlBTMmDED169fx/jx41FSUoIePXpg9+7dsLS0lNbZtGkTJk2ahLCwMJiamiImJgYffPBBHRwOEVHTYXSA9+rVC0KIGpebmJhg8eLFWLx4cY11HBwcsHnzZmN3TURE95HFLBQiItLH28k2EP5MGhmLM4moOp6BExHJFAOciEimOIQiY/xITdS0McCJGhD/CNPjYIATNSH8/cynC8fAiYhkigFORCRTDHAiIpligBMRyRQDnIhIphjgREQyxQAnIpIpzgMnegrxC0JNA8/AiYhkigFORCRTDHAiIpniGHgjw7FLaiz4oyONH8/AiYhkigFORCRTDHAiIpligBMRyRQDnIhIpjgLpQ7x106IqD7xDJyISKZ4Bk5Ej41zxhsGA/wJ4xdz6GnD13TjwQAnauIYyPLFMXAiIpligBMRyRSHUB6AHy2JqDHjGTgRkUw1WICvXr0a7u7usLS0RHBwMA4dOtRQTSEikqUGGUL517/+halTpyI5ORnBwcFISkpCZGQkcnNz4eTk1BBNIqJGinPMa9YgAf7+++9j3LhxiIuLAwAkJydj165d+Oc//4lZs2Y1RJOI6Alg+D5Z9R7gt2/fRk5ODmbPni2VmZqaIjw8HFlZWQbXqaioQEVFhfS8tLQUAFBWVmb0/jsu2GP0OkRUt1xf31qv2zi5KFKvrKYsMFS3PlTlmRCi1uvUe4BfvnwZlZWVUKvVOuVqtRpnzpwxuE5iYiIWLVqkV962bdsn0kYierqokp5M3SehvLwcKpWqVnVlMY1w9uzZmDp1qvRcq9Xi6tWrcHR0hImJSQO27MkrKytD27ZtUVhYCKVS2dDNadTYV8Zhf9VeffSVEALl5eVwcXGp9Tr1HuAtWrSAmZkZioqKdMqLiorg7OxscB2FQgGFQqFTZm9v/6Sa2CgplUq+yWqJfWUc9lftPem+qu2Zd5V6n0ZoYWGBgIAAZGRkSGVarRYZGRkICQmp7+YQEclWgwyhTJ06FbGxsQgMDERQUBCSkpJw/fp1aVYKERE9XIME+P/93//h0qVLmD9/PjQaDfz9/bF79269C5t0b/howYIFekNIpI99ZRz2V+011r4yEcbMWSEiokaD90IhIpIpBjgRkUwxwImIZIoBTkQkUwzwRmjZsmUwMTFBQkKCVHbr1i3Ex8fD0dERtra2iImJ0fsyVFNy/vx5/O1vf4OjoyOsrKzg5+eHI0eOSMuFEJg/fz5atWoFKysrhIeHIy8vrwFb3DAqKysxb948eHh4wMrKCp6ennjrrbd07rfRVPvq4MGDGDBgAFxcXGBiYoLt27frLK9Nv1y9ehUjR46EUqmEvb09xo4di2vXrtXfQQhqVA4dOiTc3d1Fp06dxJQpU6TyV199VbRt21ZkZGSII0eOiK5du4pu3bo1XEMb0NWrV4Wbm5sYM2aMyM7OFn/88YfYs2eP+O2336Q6y5YtEyqVSmzfvl0cP35cvPjii8LDw0PcvHmzAVte/5YuXSocHR3Fzp07RX5+vti6dauwtbUVq1atkuo01b767rvvxNy5c8W2bdsEAJGamqqzvDb90q9fP/Hcc8+Jn3/+Wfzwww+iffv2YsSIEfV2DAzwRqS8vFx4eXmJ9PR0ERoaKgV4SUmJaNasmdi6datU95dffhEARFZWVgO1tuHMnDlT9OjRo8blWq1WODs7i+XLl0tlJSUlQqFQiC1bttRHExuN6Oho8fe//12nbPDgwWLkyJFCCPZVleoBXpt+OX36tAAgDh8+LNVJS0sTJiYm4vz58/XSbg6hNCLx8fGIjo5GeHi4TnlOTg7u3LmjU+7t7Q1XV9cab8H7NPv2228RGBiIoUOHwsnJCZ07d8ann34qLc/Pz4dGo9HpL5VKheDg4CbXX926dUNGRgZ+/fVXAMDx48fx448/IioqCgD7qia16ZesrCzY29sjMDBQqhMeHg5TU1NkZ2fXSztlcTfCpuCLL77Af/7zHxw+fFhvmUajgYWFhd4NvNRqNTQaTT21sPH4448/sHbtWkydOhVz5szB4cOHMXnyZFhYWCA2NlbqE0O3LG5q/TVr1iyUlZXB29sbZmZmqKysxNKlSzFy5EgAYF/VoDb9otFo9H5BzNzcHA4ODvXWdwzwRqCwsBBTpkxBeno6LC0tG7o5jZ5Wq0VgYCDefvttAEDnzp1x8uRJJCcnIzY2toFb17h8+eWX2LRpEzZv3gxfX18cO3YMCQkJcHFxYV89BTiE0gjk5OSguLgYf/nLX2Bubg5zc3McOHAAH3zwAczNzaFWq3H79m2UlJTorPegW/A+zVq1agUfHx+dsg4dOqCgoAAApD4x5pbFT6vp06dj1qxZGD58OPz8/DBq1Ci8/vrrSExMBMC+qklt+sXZ2RnFxcU6y+/evYurV6/WW98xwBuBsLAwnDhxAseOHZMegYGBGDlypPTvZs2a6dyCNzc3FwUFBU3yFrzdu3dHbm6uTtmvv/4KNzc3AICHhwecnZ11+qusrAzZ2dlNrr9u3LgBU1Pdt7mZmRm0Wi0A9lVNatMvISEhKCkpQU5OjlQnMzMTWq0WwcHB9dPQerlUSka7fxaKEPemEbq6uorMzExx5MgRERISIkJCQhqugQ3o0KFDwtzcXCxdulTk5eWJTZs2CWtra7Fx40apzrJly4S9vb345ptvxH//+18xcODAJjE1rrrY2FjRunVraRrhtm3bRIsWLcSMGTOkOk21r8rLy8XRo0fF0aNHBQDx/vvvi6NHj4pz584JIWrXL/369ROdO3cW2dnZ4scffxReXl6cRkj6AX7z5k0xceJE0bx5c2FtbS1eeuklcfHixYZrYAPbsWOH6Nixo1AoFMLb21t88sknOsu1Wq2YN2+eUKvVQqFQiLCwMJGbm9tArW04ZWVlYsqUKcLV1VVYWlqKdu3aiblz54qKigqpTlPtq3379gkAeo/Y2FghRO365cqVK2LEiBHC1tZWKJVKERcXJ8rLy+vtGHg7WSIimeIYOBGRTDHAiYhkigFORCRTDHAiIpligBMRyRQDnIhIphjgREQyxQAnIpIpBjgRkUwxwImIZIoBTkQkUwxwIiKZ+n91Ca6ur0I8eAAAAABJRU5ErkJggg==", 635 | "text/plain": [ 636 | "
" 637 | ] 638 | }, 639 | "metadata": {}, 640 | "output_type": "display_data" 641 | } 642 | ], 643 | "source": [ 644 | "tokenized_lengths = [len(t) for t in tok.encode(train_texts)['tokens']]\n", 645 | "plt.figure(figsize=(4, 2))\n", 646 | "plt.hist(tokenized_lengths, bins=50)\n", 647 | "plt.title('BPE 1024 tokenizer sequence lengths')\n", 648 | "plt.show()" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": {}, 654 | "source": [ 655 | "### Обучение RNN" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 25, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "import torch\n", 665 | "import torch.nn as nn\n", 666 | "from typing import Optional\n", 667 | "\n", 668 | "\n", 669 | "class RNN(nn.Module):\n", 670 | " def __init__(self, vocab_size, hidden_size):\n", 671 | " super().__init__()\n", 672 | " self.hidden_size = hidden_size\n", 673 | "\n", 674 | " self.embeddings = nn.Embedding(vocab_size, hidden_size)\n", 675 | "\n", 676 | " self.W = nn.Linear(hidden_size + hidden_size, hidden_size)\n", 677 | " self.O = nn.Linear(hidden_size, vocab_size)\n", 678 | "\n", 679 | " def forward(self, input_ids, h0: Optional[torch.Tensor] = None):\n", 680 | " batch_size, seq_len = input_ids.shape\n", 681 | " x = self.embeddings(input_ids)\n", 682 | "\n", 683 | " if h0 is None:\n", 684 | " h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)\n", 685 | " else:\n", 686 | " h_t = h0\n", 687 | "\n", 688 | " outputs = []\n", 689 | " for t in range(seq_len):\n", 690 | " x_t = x[:, t, :]\n", 691 | "\n", 692 | " h_t = torch.tanh(\n", 693 | " self.W(torch.cat((x_t, h_t), dim=-1))\n", 694 | " )\n", 695 | "\n", 696 | " o_t = self.O(h_t)\n", 697 | " outputs.append(o_t)\n", 698 | "\n", 699 | " outputs = torch.stack(outputs, dim=1)\n", 700 | "\n", 701 | " return outputs, h_t" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 26, 707 | "metadata": {}, 708 | "outputs": [], 709 | "source": [ 710 | "def train(model, dataloader, optimizer, loss_fn):\n", 711 | " model.train()\n", 712 | "\n", 713 | " accuracies = []\n", 714 | " for input_ids in dataloader:\n", 715 | " input_ids = input_ids.to(device)\n", 716 | "\n", 717 | " logits, _ = model(input_ids)\n", 718 | " \n", 719 | " shift_ids = input_ids[:, 1:]\n", 720 | " shift_logits = logits[:, :-1]\n", 721 | " loss = loss_fn(shift_logits.permute(0, 2, 1), shift_ids)\n", 722 | "\n", 723 | " optimizer.zero_grad()\n", 724 | " loss.backward()\n", 725 | " optimizer.step()\n", 726 | "\n", 727 | " accuracy = (shift_logits.argmax(-1) == shift_ids).float().mean().item()\n", 728 | " accuracies.append(accuracy)\n", 729 | " \n", 730 | " return np.mean(accuracies)\n", 731 | "\n", 732 | "\n", 733 | "@torch.no_grad()\n", 734 | "def evaluate(model, dataloader, loss_fn):\n", 735 | " model.eval()\n", 736 | "\n", 737 | " accuracies = []\n", 738 | " losses = []\n", 739 | " for input_ids in dataloader:\n", 740 | " input_ids = input_ids.to(device)\n", 741 | "\n", 742 | " logits, _ = model(input_ids)\n", 743 | " \n", 744 | " shift_ids = input_ids[:, 1:]\n", 745 | " shift_logits = logits[:, :-1]\n", 746 | " loss = loss_fn(shift_logits.permute(0, 2, 1), shift_ids)\n", 747 | "\n", 748 | " accuracies.append((shift_logits.argmax(-1) == shift_ids).float().mean().item())\n", 749 | " losses.append(loss.item())\n", 750 | "\n", 751 | " loss = np.mean(losses)\n", 752 | " accuracy = np.mean(accuracies)\n", 753 | "\n", 754 | " return accuracy, loss" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": 27, 760 | "metadata": {}, 761 | "outputs": [], 762 | "source": [ 763 | "from torch.distributions.categorical import Categorical\n", 764 | "\n", 765 | "\n", 766 | "@torch.inference_mode()\n", 767 | "def generate(tokenizer, model, batch_size=4, max_length=40):\n", 768 | " input_ids = torch.empty(size=(batch_size, 1), device=device).fill_(tokenizer.bos_token_id).int()\n", 769 | " h_t = torch.zeros(batch_size, model.hidden_size, device=device)\n", 770 | " gen_ids = []\n", 771 | " for i in range(max_length):\n", 772 | " logits, h_t = model(input_ids, h_t)\n", 773 | " input_ids = Categorical(logits=logits).sample()\n", 774 | " gen_ids.append(input_ids)\n", 775 | " \n", 776 | " return torch.cat(gen_ids, dim=1)" 777 | ] 778 | }, 779 | { 780 | "cell_type": "markdown", 781 | "metadata": { 782 | "tags": [] 783 | }, 784 | "source": [ 785 | "## Натренируем модели на разных токенизаторах" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": 28, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "from torch.nn.utils.rnn import pad_sequence\n", 795 | "from torch.utils.data import DataLoader\n", 796 | "\n", 797 | "\n", 798 | "def train_rnn_model(tokenizer): \n", 799 | " def collate_fn(batch, max_length=None):\n", 800 | " tokens = tokenizer.encode(batch, max_length=max_length)['input_ids']\n", 801 | " return pad_sequence(tokens, padding_value=tokenizer.pad_token_id, batch_first=True)\n", 802 | " max_length = 150\n", 803 | " collate = partial(collate_fn, max_length=max_length)\n", 804 | " train_loader = DataLoader(train_texts, collate_fn=collate, shuffle=True, batch_size=256)\n", 805 | " test_loader = DataLoader(test_texts, collate_fn=collate, shuffle=False, batch_size=256)\n", 806 | " model = RNN(vocab_size=len(tokenizer), hidden_size=512).to(device)\n", 807 | "\n", 808 | " print('Number of parameters:', sum(p.numel() for p in model.parameters()))\n", 809 | "\n", 810 | " optimizer = torch.optim.Adam(model.parameters(), lr=4e-4, weight_decay=1e-3)\n", 811 | " loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n", 812 | " \n", 813 | " for epoch in tqdm(range(20)):\n", 814 | " train_accuracy = train(model, train_loader, optimizer, loss_fn)\n", 815 | " test_accuracy, test_loss = evaluate(model, test_loader, loss_fn)\n", 816 | " if epoch % 5 == 0:\n", 817 | " print(train_accuracy, test_accuracy, test_loss)\n", 818 | " \n", 819 | " _, test_loss = evaluate(model, test_loader, loss_fn)\n", 820 | " \n", 821 | " return model, test_loss" 822 | ] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "execution_count": 29, 827 | "metadata": {}, 828 | "outputs": [ 829 | { 830 | "name": "stdout", 831 | "output_type": "stream", 832 | "text": [ 833 | "\n", 834 | "\n", 835 | "\n", 836 | "Number of parameters: 4723200\n" 837 | ] 838 | }, 839 | { 840 | "name": "stderr", 841 | "output_type": "stream", 842 | "text": [ 843 | " 5%|▌ | 1/20 [00:04<01:18, 4.12s/it]" 844 | ] 845 | }, 846 | { 847 | "name": "stdout", 848 | "output_type": "stream", 849 | "text": [ 850 | "0.05761686963851389 0.07558483742177487 6.160120987892151\n" 851 | ] 852 | }, 853 | { 854 | "name": "stderr", 855 | "output_type": "stream", 856 | "text": [ 857 | " 30%|███ | 6/20 [00:22<00:51, 3.71s/it]" 858 | ] 859 | }, 860 | { 861 | "name": "stdout", 862 | "output_type": "stream", 863 | "text": [ 864 | "0.11558546870946884 0.11518476940691472 5.2606532096862795\n" 865 | ] 866 | }, 867 | { 868 | "name": "stderr", 869 | "output_type": "stream", 870 | "text": [ 871 | " 55%|█████▌ | 11/20 [00:40<00:33, 3.69s/it]" 872 | ] 873 | }, 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "0.1248983919620514 0.12248268499970436 5.065859866142273\n" 879 | ] 880 | }, 881 | { 882 | "name": "stderr", 883 | "output_type": "stream", 884 | "text": [ 885 | " 80%|████████ | 16/20 [00:59<00:14, 3.68s/it]" 886 | ] 887 | }, 888 | { 889 | "name": "stdout", 890 | "output_type": "stream", 891 | "text": [ 892 | "0.13106198459863663 0.12670493684709072 4.956925714015961\n" 893 | ] 894 | }, 895 | { 896 | "name": "stderr", 897 | "output_type": "stream", 898 | "text": [ 899 | "100%|██████████| 20/20 [01:13<00:00, 3.70s/it]\n" 900 | ] 901 | } 902 | ], 903 | "source": [ 904 | "bpe_tokenizer = BPETokenizer(corpus=train_texts, vocab_size=4096)\n", 905 | "bpe_rnn, bpe_test_loss = train_rnn_model(bpe_tokenizer)\n", 906 | "torch.save(bpe_rnn.state_dict(), 'checkpoints/bpe_rnn.pt')" 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "execution_count": 30, 912 | "metadata": {}, 913 | "outputs": [ 914 | { 915 | "name": "stdout", 916 | "output_type": "stream", 917 | "text": [ 918 | "Number of parameters: 4723200\n" 919 | ] 920 | }, 921 | { 922 | "name": "stderr", 923 | "output_type": "stream", 924 | "text": [ 925 | " 5%|▌ | 1/20 [00:03<01:01, 3.22s/it]" 926 | ] 927 | }, 928 | { 929 | "name": "stdout", 930 | "output_type": "stream", 931 | "text": [ 932 | "0.0742159198365698 0.1041804663836956 5.520843744277954\n" 933 | ] 934 | }, 935 | { 936 | "name": "stderr", 937 | "output_type": "stream", 938 | "text": [ 939 | " 30%|███ | 6/20 [00:19<00:45, 3.22s/it]" 940 | ] 941 | }, 942 | { 943 | "name": "stdout", 944 | "output_type": "stream", 945 | "text": [ 946 | "0.14468048140406609 0.1473691951483488 4.59027373790741\n" 947 | ] 948 | }, 949 | { 950 | "name": "stderr", 951 | "output_type": "stream", 952 | "text": [ 953 | " 55%|█████▌ | 11/20 [00:35<00:28, 3.22s/it]" 954 | ] 955 | }, 956 | { 957 | "name": "stdout", 958 | "output_type": "stream", 959 | "text": [ 960 | "0.15467707738280295 0.15559826269745827 4.412833058834076\n" 961 | ] 962 | }, 963 | { 964 | "name": "stderr", 965 | "output_type": "stream", 966 | "text": [ 967 | " 80%|████████ | 16/20 [00:51<00:12, 3.22s/it]" 968 | ] 969 | }, 970 | { 971 | "name": "stdout", 972 | "output_type": "stream", 973 | "text": [ 974 | "0.1624450169503689 0.16117852926254272 4.328227865695953\n" 975 | ] 976 | }, 977 | { 978 | "name": "stderr", 979 | "output_type": "stream", 980 | "text": [ 981 | "100%|██████████| 20/20 [01:04<00:00, 3.22s/it]\n" 982 | ] 983 | } 984 | ], 985 | "source": [ 986 | "word_tokenizer = WordTokenizer(corpus=train_texts, vocab_size=4096)\n", 987 | "word_rnn, word_test_loss = train_rnn_model(word_tokenizer)\n", 988 | "torch.save(bpe_rnn.state_dict(), 'checkpoints/word_rnn.pt')" 989 | ] 990 | }, 991 | { 992 | "cell_type": "code", 993 | "execution_count": 31, 994 | "metadata": {}, 995 | "outputs": [ 996 | { 997 | "name": "stdout", 998 | "output_type": "stream", 999 | "text": [ 1000 | "Number of parameters: 608850\n" 1001 | ] 1002 | }, 1003 | { 1004 | "name": "stderr", 1005 | "output_type": "stream", 1006 | "text": [ 1007 | " 5%|▌ | 1/20 [00:03<00:57, 3.04s/it]" 1008 | ] 1009 | }, 1010 | { 1011 | "name": "stdout", 1012 | "output_type": "stream", 1013 | "text": [ 1014 | "0.27662071837112306 0.32306850627064704 2.3889542520046234\n" 1015 | ] 1016 | }, 1017 | { 1018 | "name": "stderr", 1019 | "output_type": "stream", 1020 | "text": [ 1021 | " 30%|███ | 6/20 [00:18<00:42, 3.01s/it]" 1022 | ] 1023 | }, 1024 | { 1025 | "name": "stdout", 1026 | "output_type": "stream", 1027 | "text": [ 1028 | "0.4120487689971924 0.4215001069009304 1.9297221899032593\n" 1029 | ] 1030 | }, 1031 | { 1032 | "name": "stderr", 1033 | "output_type": "stream", 1034 | "text": [ 1035 | " 55%|█████▌ | 11/20 [00:33<00:27, 3.01s/it]" 1036 | ] 1037 | }, 1038 | { 1039 | "name": "stdout", 1040 | "output_type": "stream", 1041 | "text": [ 1042 | "0.4659665122628212 0.468839792907238 1.7640975922346116\n" 1043 | ] 1044 | }, 1045 | { 1046 | "name": "stderr", 1047 | "output_type": "stream", 1048 | "text": [ 1049 | " 80%|████████ | 16/20 [00:48<00:12, 3.01s/it]" 1050 | ] 1051 | }, 1052 | { 1053 | "name": "stdout", 1054 | "output_type": "stream", 1055 | "text": [ 1056 | "0.49858393222093583 0.5012813329696655 1.6595043033361434\n" 1057 | ] 1058 | }, 1059 | { 1060 | "name": "stderr", 1061 | "output_type": "stream", 1062 | "text": [ 1063 | "100%|██████████| 20/20 [01:00<00:00, 3.01s/it]\n" 1064 | ] 1065 | } 1066 | ], 1067 | "source": [ 1068 | "char_tokenizer = CharacterTokenizer(corpus=train_texts)\n", 1069 | "char_rnn, char_test_loss = train_rnn_model(char_tokenizer)\n", 1070 | "torch.save(bpe_rnn.state_dict(), 'checkpoints/char_rnn.pt')" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "code", 1075 | "execution_count": 32, 1076 | "metadata": {}, 1077 | "outputs": [], 1078 | "source": [ 1079 | "bpe_outputs = generate(bpe_tokenizer, bpe_rnn, batch_size=5000)\n", 1080 | "word_outputs = generate(word_tokenizer, word_rnn, batch_size=5000)\n", 1081 | "char_outputs = generate(char_tokenizer, char_rnn, batch_size=5000)" 1082 | ] 1083 | }, 1084 | { 1085 | "cell_type": "markdown", 1086 | "metadata": {}, 1087 | "source": [ 1088 | "# Метрики качества\n", 1089 | "\n", 1090 | "## Валидационная перплексия\n", 1091 | "NLL:\n", 1092 | "$$\n", 1093 | "\\text{NLL} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log P(x_i \\mid x_{