├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── docker-compose.yml └── work ├── dataset ├── 1_find_numbers.ipynb ├── 2_inverse_normalize.ipynb ├── 3_process_itn.ipynb ├── 4_process_kaggle.ipynb └── word_to_number_ru │ ├── LICENSE.txt │ ├── extractor.py │ └── number.py ├── infer ├── examples.json └── infer.ipynb ├── replaces.py ├── tb.ipynb └── train ├── readme.md ├── train-distributed.ipynb └── train.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .Trash* 2 | models 3 | data 4 | **/.ipynb_checkpoints/* 5 | **/__pycache__ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMG 2 | FROM $BASE_IMG 3 | 4 | USER root 5 | 6 | RUN apt update && apt install -y curl 7 | RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - 8 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y nodejs 9 | 10 | RUN pip install --upgrade -v \ 11 | "datasets" \ 12 | "ipywidgets" \ 13 | "jupyter" \ 14 | "jupyterlab-git" \ 15 | "jupyterlab>=4.0.0" \ 16 | "matplotlib" \ 17 | "pip" \ 18 | "requests" \ 19 | "sentencepiece" \ 20 | "tensorboard" \ 21 | "tqdm==4.62.2" \ 22 | "transformers[torch]" \ 23 | && rm -rf ~/.cache/pip/* 24 | RUN apt update && apt install -y git 25 | 26 | RUN if ! id jovyan >/dev/null 2>&1; then \ 27 | useradd -m -u 1000 -g 100 -s /bin/bash -d /home/jovyan jovyan; \ 28 | fi 29 | RUN apt install -y sudo \ 30 | && echo "jovyan ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers \ 31 | && usermod -a -G root jovyan \ 32 | && mkdir -p /home/jovyan/.local && mkdir -p /home/jovyan/.jupyter \ 33 | && chown -R 1000:100 /home/jovyan/.local && chown -R 1000:100 /home/jovyan/.jupyter 34 | 35 | ENV SHELL=/bin/bash 36 | ENV JUPYTER_DATA_DIR=/home/jovyan/.local 37 | USER jovyan 38 | WORKDIR /home/jovyan 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Alexander Stupnikov 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | init: 3 | echo "TOKEN=test" >> .env 4 | up: 5 | docker compose build nb && docker compose up -d nb 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text normalization 2 | 3 | The fully working model is [on Hugging Face](https://huggingface.co/saarus72/russian_text_normalizer), and a [nice chat](https://huggingface.co/spaces/saarus72/russian-text-normalization) is on HF Space as well. 4 | 5 | ### Why 6 | 7 | A pet project, as the (only) [other solution](https://github.com/snakers4/russian_stt_text_normalization) does not seem to be maintainable by either its owner or others. Some others do exist and are, like, *bad*. 8 | 9 | > E.g. for the input `я купил iphone 10X за 14990 руб без 3-x часов полдень и т.д.` the output of 10 | > * [russian_stt_text_normalization](https://github.com/snakers4/russian_stt_text_normalization) itself is `я купил ифон десять кс за четыре девять девять ноль рублей без третьи часов полдень и т.д.`, 11 | > * [text-normalization-ru-terrible](https://huggingface.co/maximxls/text-normalization-ru-terrible) is `я купил айфон сто икс за тысячу четыреста девяносто рубле без третьих часов пол`, 12 | > * [text-normalization-ru-new](https://huggingface.co/alexue4/text-normalization-ru-new) is `я купил ифон десять икс за четырнадцать тысяч девять`. 13 | 14 | ### The plan 15 | 16 | Here is the plan I went along with! The steps were: 17 | 18 | 1. Get a dataset. 19 | > Done with notebooks to [find](./work/dataset/1_find_numbers.ipynb) and to [itn](./work/dataset/2_inverse_normalize.ipynb) texts, then to [construct the dataset](./work/dataset/3_process_itn.ipynb). 20 | 1. Download any vast (informal?) Russian raw text corpus. Could be 21 | * [IlyaGusev/ficbook](https://huggingface.co/datasets/IlyaGusev/ficbook), 22 | * [IlyaGusev/librusec_full](https://huggingface.co/datasets/IlyaGusev/librusec_full), or 23 | * ~~[Taiga Corpus](https://tatianashavrina.github.io/taiga_site)~~ [pikabu](https://huggingface.co/datasets/IlyaGusev/pikabu)! 24 | 1. Find occurrences with regexp patterns like `r"двадцат\S+"`. 25 | 1. Make sure there is nothing but Cyrillic. 26 | 1. Perform inverse text normalization (that task is more straightforward and many good solutions exist). 27 | * Used ~~[NeMo Text Processing](https://github.com/NVIDIA/NeMo-text-processing)~~ [another Python package](https://github.com/flockentanz/word_to_number_ru) with some additions. 28 | 1. Polish things roughly: balance the data (as `два` seems to be *far* more common than `двумястами`), get rid of ITN mistakes, etc. 29 | 1. Train an MVP. 30 | > Done with notebooks to [train](./work/train/train.ipynb) and to [train distributed](./work/train/train-distributed.ipynb) a model. 31 | 1. Get a relatively big LLM as we are going to prune it afterwards (and to export it to ONNX as well so that the resulting performance is comparable to the solution I've mentioned). 32 | * Seems to be [ai-forever/FRED-T5-1.7B](https://huggingface.co/ai-forever/FRED-T5-1.7B) as it is encoder-decoder, trainable on a single **RTX3060 12GB** and good enough to get an MVP. 33 | > Turned out that 12GB is only enough for inference, so I've trained [ai-forever/FRED-T5-large](https://huggingface.co/ai-forever/FRED-T5-large) instead.
34 | 35 | > I've managed to run **FRED-T5-1.7B** training on two 12GB GPUs using the [`tensor_parallel`](https://github.com/BlackSamorez/tensor_parallel) package, but the model did not perform notably better. Also, the point is to have a small and fast model to run inference on CPU. 36 | 1. Train, like, any barely working model. 37 | * Several attempts are required as it is not clear which prompt is better. Maybe 38 | ``` 39 | Было у отца [3] сына и [2-3] пиджака с блёстками. 40 | ``` 41 | > Turned out the pattern below works well so I've made no experiments here. 42 | 1. Test and analyze. 43 | 1. ~~Regret deeply.~~ 44 | 1. To obtain a dataset of better quality, we want to ask a really big smart-ass LLM to **(not inverse!)** normalize texts during training. 45 | * Unfortunately, the LLM experiments failed. I took instruct models ([Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), [ruGPT-3.5 13B LoRA](https://huggingface.co/evilfreelancer/ruGPT-3.5-13B-lora), [GigaSaiga](https://huggingface.co/IlyaGusev/gigasaiga_lora), [Saiga2 13B](https://huggingface.co/IlyaGusev/saiga2_13b_lora)) and plain generation ones ([ruGPT-3.5-13B-GPTQ](https://huggingface.co/fffrrt/ruGPT-3.5-13B-GPTQ) and [Vikhr-7b-0.1](https://huggingface.co/AlexWortega/Vikhr-7b-0.1)), but there were always too many mistakes which could not be caught automatically. Well, they _were_, so I decided to... 46 | 1. Take the [Kaggle Text Normalization Challenge](https://www.kaggle.com/competitions/text-normalization-challenge-russian-language) dataset! That way I got Latin normalization as well. 47 | > Done with a notebook to [process kaggle data](./work/dataset/4_process_kaggle.ipynb). 48 | 1. Train everything again at last. Put it [on hf](https://huggingface.co/saarus72/russian_text_normalizer). 49 | 50 | ## Inverse Text Normalization 51 | 52 | There are but a few packages, namely from 53 | * NVidia's [NeMo](https://github.com/NVIDIA/NeMo-text-processing), 54 | * [Oknolaz](https://github.com/Oknolaz/Russian_w2n), 55 | * [SergeyShk](https://github.com/SergeyShk/Word-to-Number-Russian) and its forks from 56 | * [averkij](https://github.com/averkij/Word-to-Number-Russian) and 57 | * [flockentanz](https://github.com/flockentanz/word_to_number_ru). 58 | 59 | **NeMo** works well but tends to miss many cases I wouldn't have missed (see the comparison table below). I used it for the first attempt but then did my own research. 60 | 61 | **Oknolaz** needs to be fed with extracted numbers only and makes many mistakes even then, so it is a bad choice for us. 62 | 63 | **SergeyShk** does either 64 | * `replace_groups` — `тысяча сто` to `1100` but `сто двести триста` to `400`, or 65 | * `replace` — `сто двести триста` to `100 200 300` but `тысяча сто` to `1000 100`. 66 | 67 | It is obvious that addition should be done on decreasing values only, so there are some forks that fix it (the overall code is messy enough that I didn't want to do it myself anyway). 68 | 69 | **averkij** and **flockentanz** both work fine but have some bugs, so I took the latter and fixed them. I also cover cases like `с половиной` and `одна целая две десятых`; a minimal usage sketch is shown below, before the comparison table.
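A minimal usage sketch of the fixed extractor, mirroring how the dataset notebooks call it. It assumes the bundled `work/dataset/word_to_number_ru` package is importable from the working directory and that `natasha==0.10.0` and `yargy==0.12.0` are installed (see `2_inverse_normalize.ipynb`); the expected outputs are taken from the comparison table below.

```python
# Sketch only: the patched word_to_number_ru extractor, as used in 2_inverse_normalize.ipynb.
# pip install natasha==0.10.0 yargy==0.12.0
from word_to_number_ru.extractor import NumberExtractor

ne = NumberExtractor()

# The extractor is used as a callable that rewrites number words in the text.
print(ne("тысяча сто"))                # expected: "1100"
print(ne("три с половиной человека"))  # expected: "3.5 человека"
```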
70 | 71 | | Original | 🟡 NeMo TP | 🔴 Oknolaz `replace` | 🔴 SergeyShk `replace_groups` | 🔴 SergeyShk `replace` | 🔴 averkij `replace` | 🔴 flockentanz `replace_groups_sa` | 🟢 flockentanz fixed | 72 | |--|--|--|--|--|--|--|--| 73 | | `сто двести триста да хоть тысячу раз` | 🟢`100 200 300 да хоть 1000 раз` | 🔴`600000` | 🔴`400 да хоть 1000 раз` | 🟢`100 200 300 да хоть 1000 раз` | 🔴`10200 300 да хоть 1000 раз` | 🟢`100 200 300 да хоть 1000 раз` | 🟢`100 200 300 да хоть 1000 раз` | 74 | | `тысяча сто` | 🟢`1100` | 🟢`1100` | 🟢`1100` | 🔴`1000 100` | 🟢`1100` | 🟢`1100` | 🟢`1100` | 75 | | `я видел сто-двести штук` | 🟡`я видел сто-двести штук` | 🔴`300` | 🟢`я видел 100-200 штук` | 🟢`я видел 100-200 штук` | 🟢`я видел 100-200 штук` | 🟢`я видел 100-200 штук` | 🟢`я видел 100-200 штук` | 76 | | `восемь девятьсот двадцать два пять пять пять тридцать пять тридцать пять, лучше позвонить, чем занимать` | 🟡`восемь 922 пять пять пять 35 35 , лучше позвонить, чем занимать` | 🔴`8` | 🔴`115, лучше позвонить, чем занимать` | 🔴`8 900 20 2 5 5 5 30 5 30 5, лучше позвонить, чем занимать` | 🟢`8 922 5 5 5 35 35, лучше позвонить, чем занимать` | 🟢`8 922 5 5 5 35 35, лучше позвонить, чем занимать` | 🟢`8 922 5 5 5 35 35, лучше позвонить, чем занимать` | 77 | | `три с половиной человека` | 🟡`три с половиной человека` | 🔴`3` | 🟡`3 с половиной человека` | 🟡`3 с половиной человека` | 🟢`3.5 человека` | 🟡`3 с половиной человека` | 🟢`3.5 человека` | 78 | | `миллион сто тысяч сто зайцев` | 🟢`1100100 зайцев` | ❌`list index out of range` | 🔴`1000100100 зайцев` | 🔴`1000000 100000 100 зайцев` | `1100100 зайцев` | 🔴`1000100100 зайцев` | 🟢`1100100 зайцев` | 79 | | `одни двойки и ни одной пятёрки` | 🟡`одни двойки и ни одной пятёрки` | 🟡`No valid number words found! ...` | 🟡`1 двойки и ни 1 пятёрки` | 🟡`1 двойки и ни 1 пятёрки` | 🟡`1 двойки и ни 1 пятёрки` | 🟡`1 двойки и ни 1 пятёрки` | 🟡`1 двойки и ни 1 пятёрки` | 80 | | `без одной минуты два` |🟢 `01:59` | 🔴`2` | 🟢`без 1 минуты 2` | 🟢`без 1 минуты 2` | 🟢`без 1 минуты 2` | 🟢`без 1 минуты 2` | 🟢`без 1 минуты 2` | 81 | | `вторая дача пять соток` | 🟡`вторая дача пять соток` | 🔴`5` | 🟢`2 дача 5 соток` | 🟢`2 дача 5 соток` | 🟢`2 дача 5 соток` | 🟢`2 дача 5 соток` | 🟢`2 дача 5 соток` | 82 | | `двести пятьдесят с половиной тысяч отборных солдат Ирака` | 🟡`250 с половиной 1000 отборных солдат Ирака` | 🔴`250000` | 🟡`250 с половиной 1000 отборных солдат Ирака` | 🔴`200 50 с половиной 1000 отборных солдат Ирака` | 🔴`2050000.5 отборных солдат Ирака` | 🟡`250 с половиной 1000 отборных солдат Ирака` | 🟢`250500 отборных солдат Ирака` | 83 | | `ноль целых ноль десятых минус две целых шесть сотых` | 🟢`0,0 -2,06` | 🟡`Redundant number word! ...` | 🔴`0 целых 0.0 минус 2 целых 0.06` | 🔴`0 целых 0.0 минус 2 целых 0.06` | 🔴`0 целых 0.0 минус 2 целых 0.06` | 🔴`0 целых 0.0 минус 2 целых 0.06` | 🟢`0 минус 2.06` | 84 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.3' 2 | 3 | services: 4 | nb: 5 | build: 6 | context: . 
7 | dockerfile: Dockerfile 8 | args: 9 | BASE_IMG: pytorch/pytorch:2.1.2-cuda12.1-cudnn8-devel 10 | ports: 11 | - 8020:8888 12 | - 7860:7860 # gradio one 13 | - 8006-8010:8006-8010 # tensorboard 14 | volumes: 15 | - ./work/:/home/jovyan/work/ 16 | - ./data/:/home/jovyan/data/ 17 | - ./models/:/home/jovyan/models/ 18 | - /srv/wdc1/:/home/jovyan/wdc1/ 19 | command: 20 | - jupyter 21 | - lab 22 | - --ip=0.0.0.0 23 | - --port=8888 24 | - --no-browser 25 | - --notebook-dir="/home/jovyan/" 26 | - --allow-root 27 | - --LabApp.token="$TOKEN" 28 | user: jovyan 29 | runtime: nvidia 30 | shm_size: '8gb' 31 | restart: always 32 | deploy: 33 | resources: 34 | reservations: 35 | devices: 36 | - driver: nvidia 37 | device_ids: ['0', '1', '2', '3'] 38 | capabilities: [gpu] 39 | -------------------------------------------------------------------------------- /work/dataset/1_find_numbers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "24a48cdd-f6ec-4663-afa2-2d07c73e1855", 6 | "metadata": {}, 7 | "source": [ 8 | "# Find numbers\n", 9 | "\n", 10 | "among plain text corpus.\n", 11 | "\n", 12 | "We are looking for numbers now. First, we download that.\n", 13 | "\n", 14 | "* [IlyaGusev/ficbook](https://huggingface.co/datasets/IlyaGusev/ficbook)\n", 15 | "* [IlyaGusev/librusec](https://huggingface.co/datasets/IlyaGusev/librusec)\n", 16 | "* [IlyaGusev/pikabu](https://huggingface.co/datasets/IlyaGusev/pikabu)\n", 17 | "\n", 18 | "> `pip install datasets zstandard jsonlines pysimdjson` is advised.\n", 19 | "\n", 20 | "The most simple way is to execute `git clone https://huggingface.co/datasets/IlyaGusev/librusec` eg.\n", 21 | "\n", 22 | "> One is necessarily to turn on an lfs support though.\n", 23 | "> \n", 24 | "> ```\n", 25 | "> curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n", 26 | "> sudo apt-get install git-lfs\n", 27 | "> git lfs install\n", 28 | "> ```" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "d603b6dc-70ea-4272-a7cf-a4636c606fb6", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "from datasets import load_dataset\n", 39 | "from pprint import pprint, pformat\n", 40 | "from tqdm.notebook import tqdm\n", 41 | "import json" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "02e7c078-d109-4504-b1d9-196a1d9c0762", 47 | "metadata": {}, 48 | "source": [ 49 | "Change pathes below to where the datasets are downloaded to." 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "4470b84f-2702-4ac1-a494-81affd14f9da", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "PATHES = {\n", 60 | " \"librusec\": {\n", 61 | " \"input\": \"/home/jovyan/wdc1/datasets/_PLAIN/librusec\",\n", 62 | " \"output\": \"/home/jovyan/data/librusec.jsonl\",\n", 63 | " },\n", 64 | " \"ficbook\": {\n", 65 | " \"input\": \"/home/jovyan/wdc1/datasets/_PLAIN/ficbook\",\n", 66 | " \"output\": \"/home/jovyan/data/ficbook.jsonl\",\n", 67 | " },\n", 68 | " \"pikabu\": {\n", 69 | " \"input\": \"/home/jovyan/wdc1/datasets/_WEB20/pikabu\",\n", 70 | " \"output\": \"/home/jovyan/data/pikabu.json\",\n", 71 | " },\n", 72 | "}" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "76ee54e2-e2f2-4199-a6fd-69dc18ebb06b", 78 | "metadata": {}, 79 | "source": [ 80 | "Now we use the most direct approach and just morph a number in all the ways possible." 
81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "6d69319c-7301-4f16-ad4a-88822b9fe74f", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "numbers = [\n", 91 | " 'ноль', 'нуль',\n", 92 | " 'один', 'два', 'двe', 'три', 'четыре', 'пять', 'шесть', 'семь', 'восемь', 'девять', 'десять',\n", 93 | " 'одиннадцать', 'двенадцать', 'тринадцать', 'четырнадцать', 'пятнадцать', 'шестнадцать', 'семнадцать', 'восемнадцать', 'девятнадцать', 'двадцать',\n", 94 | " 'тридцать', 'сорок', 'пятьдесят', 'шестьдесят', 'семьдесят', 'восемьдесят', 'девяносто', 'сто',\n", 95 | " 'двести', 'триста', 'четыреста', 'пятьсот', 'шестьсот', 'семьсот', 'восемьсот', 'девятьсот',\n", 96 | " 'тысяча', 'миллион', 'миллиард', 'триллион',\n", 97 | "]" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "4e5fbcbd-32bd-43a6-a830-ee2520342c9f", 103 | "metadata": {}, 104 | "source": [ 105 | "Turned out that some breaking changes happened between 0.8 and 0.9 versions of pymorphy.\n", 106 | "Particularily, `второй` is no longer in a lexeme of `два` in 0.9.\n", 107 | "As I more fond of the previous behaviour I downgrade the package to 0.8.\n", 108 | "\n", 109 | " pip install pymorphy2==0.8 pymorphy2-dicts-ru" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "281c4ede-b349-4fcc-9054-1a3d58b06219", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "import pymorphy2\n", 120 | "from itertools import chain\n", 121 | "\n", 122 | "morph = pymorphy2.MorphAnalyzer()\n", 123 | "\n", 124 | "def get_lexeme(word):\n", 125 | " return set(chain(*([_.word for _ in parsing.lexeme] for parsing in morph.parse(word) if parsing.tag.POS in (\"NUMR\", \"NOUN\"))))\n", 126 | "\n", 127 | "get_lexeme(\"два\")" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "5a1cb98d-a606-40f3-a3e3-2dcf0261c8b2", 133 | "metadata": {}, 134 | "source": [ 135 | "We face some mistakes as `семь` would be inflected as `семью` which is a form of `семья` as well so that we might want to do something about in in the future.\n", 136 | "Anyway we may do not find any numbers there later.\n", 137 | "We do an MVP now though so let it be." 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "69874e0d-718c-4fee-bdae-85f3bc5fbe92", 143 | "metadata": {}, 144 | "source": [ 145 | "To not to search all the forms inflected one may to find a common part and change the (future) corresponding regexp according to it—and perform a fast `.contains()` check beforehand." 
146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "82fe660b-51b3-442d-8a5e-c07e09d05ffc", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "def get_max_common(words):\n", 156 | " \"\"\"\n", 157 | " Find a leading part only.\n", 158 | "\n", 159 | " get_max_common([\"мама\", \"мать\", \"матриарх\"]) -> \"ма\"\n", 160 | " \"\"\"\n", 161 | " words = list(words)\n", 162 | " if not words:\n", 163 | " return None\n", 164 | " result = words[0]\n", 165 | " for word in words[1:]:\n", 166 | " if word.startswith(result):\n", 167 | " continue\n", 168 | " for i, (ch1, ch2) in enumerate(zip(result, word)):\n", 169 | " if ch1 != ch2:\n", 170 | " result = result[:i]\n", 171 | " break\n", 172 | " return result\n", 173 | "\n", 174 | "get_max_common(get_lexeme(\"три\"))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "e2c75598-23e5-4440-9bc5-09bf59a9eaf3", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "import re\n", 185 | "\n", 186 | "numbers_data = {}\n", 187 | "for number in numbers:\n", 188 | " elem = {\n", 189 | " \"word\": number,\n", 190 | " \"lexeme\": get_lexeme(number)\n", 191 | " }\n", 192 | " elem[\"substr\"] = get_max_common(elem[\"lexeme\"])\n", 193 | " elem[\"regexp\"] = re.compile(fr'\\b({elem[\"substr\"]}(?:{\"|\".join((_[len(elem[\"substr\"]):] for _ in elem[\"lexeme\"]))}))\\b')\n", 194 | " numbers_data[number] = elem\n", 195 | "numbers_data[\"одиннадцать\"]" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "id": "5d494b04-ee12-4809-919c-16c43ce8686f", 201 | "metadata": {}, 202 | "source": [ 203 | "Now lets inspect what had we downloaded so far." 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "id": "16efd176-889a-4c31-ba5a-103c9779a8f4", 209 | "metadata": {}, 210 | "source": [ 211 | "# pikabu" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "503ef2d6-e387-4848-9191-61e87c91c2b5", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "next(iter(load_dataset(PATHES[\"pikabu\"][\"input\"], split=\"train\", streaming=True)))" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "id": "bcfc66d5-2c66-471b-8982-a50998b84805", 227 | "metadata": {}, 228 | "source": [ 229 | "So we want to split texts as they are too big to fit into GPU as LLM train.\n", 230 | "\n", 231 | "We do not want to split on **sentences** now as the LLM we will train should see not single sentences only.\n", 232 | "One is not trivial to combine arbitrary sentences together.\n", 233 | "\n", 234 | "To split on paragraths (like `.split(\"\\n\")`) seems to be a good approach.\n", 235 | "\n", 236 | "We do not want to see latin and digits for now as we dont know how to normalize it so we filter any sentence containing." 
237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "d2e2b308-905a-4fdd-867a-642a7885a462", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "import re\n", 247 | "\n", 248 | "\n", 249 | "regexp_lat_dig = re.compile(r\"[a-zA-Z\\d]+\")\n", 250 | "\n", 251 | "\n", 252 | "def get_matches(text):\n", 253 | " texts = text.split(\"\\n\")\n", 254 | " result = []\n", 255 | " for text in texts:\n", 256 | " if re.search(regexp_lat_dig, text):\n", 257 | " continue\n", 258 | " matches = []\n", 259 | " for number, elem in numbers_data.items():\n", 260 | " if elem[\"substr\"] and elem[\"substr\"] not in text:\n", 261 | " continue\n", 262 | " if match := re.search(elem[\"regexp\"], text):\n", 263 | " matches.append({\"number\": number, \"place\": match.span(), \"form\": match[0]})\n", 264 | " if matches:\n", 265 | " result.append({\n", 266 | " \"text\": text,\n", 267 | " \"matches\": matches\n", 268 | " })\n", 269 | " return result" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "id": "d51a306d-e23e-4be0-af99-4a1c7bff63fb", 275 | "metadata": {}, 276 | "source": [ 277 | "Now we are going to process a corpus and save the result into `jsonl` file now.\n", 278 | "\n", 279 | "I use multiprocessing as multiprocessing goes brrr." 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "31a2be56-76e4-481f-ada5-dce60feba86c", 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "from multiprocessing import Process, Queue\n", 290 | "from multiprocessing import Pool\n", 291 | "\n", 292 | "\n", 293 | "queue = Queue()\n", 294 | "\n", 295 | "\n", 296 | "def process_example(**kwargs):\n", 297 | " if matches := get_matches(kwargs[\"text_markdown\"]):\n", 298 | " queue.put({\n", 299 | " \"index\": kwargs[\"id\"],\n", 300 | " \"texts\": matches\n", 301 | " })\n", 302 | "\n", 303 | "\n", 304 | "def write(queue):\n", 305 | " f = open(PATHES[\"pikabu\"][\"output\"], \"w\")\n", 306 | " while True:\n", 307 | " item = queue.get()\n", 308 | " if item is None:\n", 309 | " break\n", 310 | " json.dump(item, f, ensure_ascii=False)\n", 311 | " f.write(\"\\n\")\n", 312 | " f.close()\n", 313 | "\n", 314 | "\n", 315 | "writer = Process(target=write, args=(queue, ))\n", 316 | "writer.start()\n", 317 | "dataset = load_dataset(PATHES[\"pikabu\"][\"input\"], split=\"train\", streaming=True)\n", 318 | "with Pool(15) as p:\n", 319 | " for example in tqdm(dataset):\n", 320 | " p.apply(process_example, kwds={**example})\n", 321 | " queue.put(None)\n", 322 | "p.join()\n", 323 | "writer.join()" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "id": "bcd352da-184d-46f8-8d67-b8ec574c0521", 329 | "metadata": {}, 330 | "source": [ 331 | "# librusec" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "id": "27da7a82-c7a7-4fc7-a9f5-214b78d693b9", 337 | "metadata": { 338 | "collapsed": true, 339 | "jupyter": { 340 | "outputs_hidden": true 341 | } 342 | }, 343 | "source": [ 344 | "Paragraths here are too big so we sentencize the texts." 
345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "id": "c6f7ef5b-f66d-4a48-afb5-6435113093bd", 350 | "metadata": { 351 | "collapsed": true, 352 | "jupyter": { 353 | "outputs_hidden": true 354 | } 355 | }, 356 | "source": [ 357 | "Tried to use stanza but it turned out to be too slow.\n", 358 | "\n", 359 | "```\n", 360 | "!pip install stanza\n", 361 | "import stanza\n", 362 | "stanza.download('ru')\n", 363 | "nlp = stanza.Pipeline('ru', processors='tokenize')\n", 364 | "```\n", 365 | "\n", 366 | "Ended up with using spacy (haha classic)." 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "id": "52f99306-109d-47e6-b50a-05c1416a05d1", 373 | "metadata": { 374 | "scrolled": true 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "!pip install -U pip setuptools wheel\n", 379 | "!pip install -U spacy\n", 380 | "!python -m spacy download ru_core_news_sm" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "8a997e8e-aed5-4cb7-99fd-2622165e62c6", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "import re\n", 391 | "import spacy\n", 392 | "\n", 393 | "\n", 394 | "nlp_sentencizer = spacy.blank(\"ru\")\n", 395 | "nlp_sentencizer.add_pipe(\"sentencizer\")\n", 396 | "text = '\"В ходе проверочных мероприятий в целях профилактики правонарушений сотрудниками полиции было доставлено для административного разбирательства из центральной части города около 3 тысяч иностранных граждан. Как выяснилось, более 600 мигрантов находятся на территории России с различными нарушениями миграционного законодательства. Все они привлечены к административной ответственности\", - отметил собеседник агентства.'\n", 397 | "tokens = nlp_sentencizer(text)\n", 398 | "[str(sent) for sent in tokens.sents]" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "539dbb31-b03c-423f-87ef-9f0e5e3b432a", 404 | "metadata": {}, 405 | "source": [ 406 | "Some boilerplate here.\n", 407 | "Have to design text parts separation externally—as a function, at least.\n", 408 | "\n", 409 | "Better to do a nice class but that depends on would I do that process again for some other corpus." 
410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "id": "0343e58e-d7e9-4686-b2d5-da01a4c9bc56", 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "import re\n", 420 | "\n", 421 | "\n", 422 | "regexp_lat_dig = re.compile(r\"[a-zA-Z\\d]+\")\n", 423 | "\n", 424 | "\n", 425 | "def get_matches(text):\n", 426 | " texts = text.split(\"\\n\")\n", 427 | " result = []\n", 428 | " for text in texts:\n", 429 | " if re.search(regexp_lat_dig, text):\n", 430 | " continue\n", 431 | " matches = []\n", 432 | " for number, elem in numbers_data.items():\n", 433 | " if elem[\"substr\"] and elem[\"substr\"] not in text:\n", 434 | " continue\n", 435 | " if match := re.search(elem[\"regexp\"], text):\n", 436 | " matches.append({\"number\": number, \"place\": match.span(), \"form\": match[0]})\n", 437 | " if matches:\n", 438 | " result.append({\n", 439 | " \"text\": text,\n", 440 | " \"matches\": matches\n", 441 | " })\n", 442 | " return result" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "id": "c80c0a95-c08f-4ff3-bc9c-10e45917f957", 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "regexp_lat_dig = re.compile(r\"[a-zA-Z\\d]+\")\n", 453 | "\n", 454 | "\n", 455 | "def get_matches(text):\n", 456 | " nlp_sentencizer.max_length = len(text) + 100\n", 457 | " doc = nlp_sentencizer(text)\n", 458 | " result = []\n", 459 | " for sentence in doc.sents:\n", 460 | " text = str(sentence)\n", 461 | " if re.search(regexp_lat_dig, text):\n", 462 | " continue\n", 463 | " matches = []\n", 464 | " for number, elem in numbers_data.items():\n", 465 | " if elem[\"substr\"] and elem[\"substr\"] not in text:\n", 466 | " continue\n", 467 | " if match := re.search(elem[\"regexp\"], text):\n", 468 | " matches.append({\"number\": number, \"place\": match.span(), \"form\": match[0]})\n", 469 | " if matches:\n", 470 | " result.append({\n", 471 | " \"text\": text,\n", 472 | " \"matches\": matches\n", 473 | " })\n", 474 | " return result" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "id": "70c3b3ef-5155-4aa4-b91d-4e7479adcef6", 480 | "metadata": {}, 481 | "source": [ 482 | "Mostly the same but boilerplace again as the key is not `text_markdown` but `text` now.\n", 483 | "\n", 484 | "Should make some refactoring later." 
485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "id": "2ffbce50-d038-4506-b268-a86196e6447a", 491 | "metadata": { 492 | "scrolled": true 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "from multiprocessing import Process, Queue\n", 497 | "from multiprocessing import Pool\n", 498 | "\n", 499 | "\n", 500 | "queue = Queue()\n", 501 | "\n", 502 | "\n", 503 | "def process_example(**kwargs):\n", 504 | " if matches := get_matches(kwargs[\"text\"]):\n", 505 | " queue.put({\n", 506 | " \"index\": kwargs[\"id\"],\n", 507 | " \"texts\": matches\n", 508 | " })\n", 509 | "\n", 510 | "\n", 511 | "def write(queue):\n", 512 | " f = open(PATHES[\"librusec\"][\"output\"], \"w\")\n", 513 | " while True:\n", 514 | " item = queue.get()\n", 515 | " if item is None:\n", 516 | " break\n", 517 | " json.dump(item, f, ensure_ascii=False)\n", 518 | " f.write(\"\\n\")\n", 519 | " f.close()\n", 520 | "\n", 521 | "\n", 522 | "writer = Process(target=write, args=(queue, ))\n", 523 | "writer.start()\n", 524 | "dataset = load_dataset(PATHES[\"librusec\"][\"input\"], split=\"train\", streaming=True)\n", 525 | "with Pool(10) as p:\n", 526 | " for example in tqdm(dataset):\n", 527 | " p.apply(process_example, kwds={**example})\n", 528 | " queue.put(None)\n", 529 | "p.join()\n", 530 | "writer.join()" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "id": "aac0058a-502a-4dcf-a735-8788397ab1b6", 536 | "metadata": {}, 537 | "source": [ 538 | "# ficbook" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "id": "70c2e010-7c27-409e-998f-1d7bf044ebe7", 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "from multiprocessing import Process, Queue\n", 549 | "from multiprocessing import Pool\n", 550 | "\n", 551 | "\n", 552 | "queue = Queue()\n", 553 | "\n", 554 | "\n", 555 | "def process_example(**kwargs):\n", 556 | " for part in kwargs[\"parts\"]:\n", 557 | " if matches := get_matches(part[\"clean_text\"]):\n", 558 | " queue.put({\n", 559 | " \"index\": part[\"url\"],\n", 560 | " \"texts\": matches\n", 561 | " })\n", 562 | "\n", 563 | "\n", 564 | "def write(queue):\n", 565 | " f = open(PATHES[\"ficbook\"][\"output\"], \"w\")\n", 566 | " while True:\n", 567 | " item = queue.get()\n", 568 | " if item is None:\n", 569 | " break\n", 570 | " json.dump(item, f, ensure_ascii=False)\n", 571 | " f.write(\"\\n\")\n", 572 | " f.close()\n", 573 | "\n", 574 | "\n", 575 | "writer = Process(target=write, args=(queue, ))\n", 576 | "writer.start()\n", 577 | "dataset = load_dataset(PATHES[\"ficbook\"][\"input\"], split=\"train\", streaming=True)\n", 578 | "with Pool(10) as p:\n", 579 | " for example in tqdm(dataset):\n", 580 | " p.apply(process_example, kwds={**example})\n", 581 | " queue.put(None)\n", 582 | "p.join()\n", 583 | "writer.join()" 584 | ] 585 | } 586 | ], 587 | "metadata": { 588 | "kernelspec": { 589 | "display_name": "Python 3 (ipykernel)", 590 | "language": "python", 591 | "name": "python3" 592 | }, 593 | "language_info": { 594 | "codemirror_mode": { 595 | "name": "ipython", 596 | "version": 3 597 | }, 598 | "file_extension": ".py", 599 | "mimetype": "text/x-python", 600 | "name": "python", 601 | "nbconvert_exporter": "python", 602 | "pygments_lexer": "ipython3", 603 | "version": "3.10.13" 604 | } 605 | }, 606 | "nbformat": 4, 607 | "nbformat_minor": 5 608 | } 609 | -------------------------------------------------------------------------------- /work/dataset/2_inverse_normalize.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "163f62b8-7ac3-41f5-82f1-7a5c9aff67df", 6 | "metadata": {}, 7 | "source": [ 8 | "Now we need to perform an inverse text normalization to obtain examples to train an LLM on." 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "4608e1d2-a7d7-4731-9336-2ec6cf7988d7", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import json\n", 19 | "from tqdm.notebook import tqdm\n", 20 | "import random" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "a16fcf55-dfda-4608-ad85-53d40b3ec413", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from collections import defaultdict, Counter" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "9d02c5a1-445b-4949-8715-67357b4bed79", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "PATHES = {\n", 41 | " \"librusec\": {\n", 42 | " \"input\": \"/home/jovyan/data/librusec.jsonl\",\n", 43 | " \"output\": \"/home/jovyan/data/librusec_pairs.jsonl\",\n", 44 | " },\n", 45 | " \"ficbook\": {\n", 46 | " \"input\": \"/home/jovyan/data/ficbook.jsonl\",\n", 47 | " \"output\": \"/home/jovyan/data/ficbook_pairs.jsonl\",\n", 48 | " },\n", 49 | " \"pikabu\": {\n", 50 | " \"input\": \"/home/jovyan/data/pikabu.jsonl\",\n", 51 | " \"output\": \"/home/jovyan/data/pikabu_pairs.jsonl\",\n", 52 | " },\n", 53 | "}" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "e2e3e3bc-2840-42e6-a2d0-15a33a7a5799", 59 | "metadata": {}, 60 | "source": [ 61 | "Uncomment pairs of input and output file pathes below one by one." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "34c439b2-24ea-4084-adab-ab14e43601eb", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "DATASET = \"librusec\"\n", 72 | "DATASET = \"ficbook\"\n", 73 | "DATASET = \"pikabu\"" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "809971b9-395a-407a-82fd-f338f3530bd1", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# !pip install jsonlines pysimdjson" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "b528e32c-999a-4fb3-8abc-64bc663bb1f6", 90 | "metadata": { 91 | "scrolled": true 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "import simdjson\n", 96 | "\n", 97 | "\n", 98 | "data = []\n", 99 | "parser = simdjson.Parser()\n", 100 | "with open(PATHES[DATASET][\"input\"]) as f:\n", 101 | " for i, line in tqdm(enumerate(f)):\n", 102 | " data.append(parser.parse(line).as_dict())\n", 103 | "random.choice(data)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "ecb3cb29-3ff5-4228-b2a0-9b5a58cd49df", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "for elem in tqdm(data):\n", 114 | " for text_elem in elem[\"texts\"]:\n", 115 | " text_elem[\"matches\"] = [_[\"form\"] for _ in text_elem[\"matches\"]]\n", 116 | "random.choice(data)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "fd154913-097f-4432-a480-85d0d065d664", 122 | "metadata": { 123 | "collapsed": true, 124 | "jupyter": { 125 | "outputs_hidden": true 126 | } 127 | }, 128 | "source": [ 129 | "~~I end up using `nemo_text_processing` as it is fast.~~\n", 130 | "~~This part is vital so better use another one here next time.~~\n", 131 | "\n", 132 
| "~~Seems ok for a first attempt anyway.~~" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "520b9522-e296-49e7-90b2-8ec56856f427", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer\n", 143 | "# inverse_normalizer = InverseNormalizer(lang='ru')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "id": "41e140cd-3174-46b2-9b0e-852a6aad6b8e", 149 | "metadata": {}, 150 | "source": [ 151 | "Use [another](https://github.com/flockentanz/word_to_number_ru) itn now as NeMo covers not much of what I want to.\n", 152 | "\n", 153 | "Do `pip install natasha==0.10.0 yargy==0.12.0` first." 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "60916ab0-db85-4fc9-a636-24d7b85546ae", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "from word_to_number_ru.extractor import NumberExtractor\n", 164 | "\n", 165 | "\n", 166 | "ne = NumberExtractor()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "869c24c2-89c5-4d09-88eb-69b08506d7a5", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "sum(len(_[\"texts\"]) for _ in data)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "03efe745-e7fb-40f0-8d50-0a7fb50fc8f0", 182 | "metadata": {}, 183 | "source": [ 184 | "Make some **ROUGH** balancing here.\n", 185 | "First, perform a quick distribution check." 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "bbda9606-a856-4681-872f-07c6fd869113", 192 | "metadata": { 193 | "scrolled": true 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "counter = Counter()\n", 198 | "for elem in tqdm(data):\n", 199 | " for text in elem[\"texts\"]:\n", 200 | " counter.update(text[\"matches\"])\n", 201 | "sum(counter.values()), len(counter), counter" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "id": "b4ceca62-9db6-43d3-b60a-5d1176195597", 207 | "metadata": {}, 208 | "source": [ 209 | "Seems too much of numbers less than 10.\n", 210 | "So we strip a part of the most common numbers." 
211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "00d930e4-1b29-41e0-8942-15faeb52c54b", 217 | "metadata": { 218 | "scrolled": true 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "from multiprocessing import Process, Queue\n", 223 | "from multiprocessing import Pool\n", 224 | "from itertools import islice, chain\n", 225 | "\n", 226 | "\n", 227 | "queue = Queue()\n", 228 | "occurances = Counter()\n", 229 | "occurances_limit = sum(counter.values()) / 25 # len(counter)\n", 230 | "print(f\"{occurances_limit=}\")\n", 231 | "\n", 232 | "\n", 233 | "def inverse_normalize(elem):\n", 234 | " result = []\n", 235 | " for i_text, text in enumerate(elem[\"texts\"]):\n", 236 | " matches = [match for match in text[\"matches\"] if occurances[match] < occurances_limit]\n", 237 | " if not matches:\n", 238 | " continue\n", 239 | " occurances.update(text[\"matches\"])\n", 240 | " # itn = inverse_normalizer.inverse_normalize(text[\"text\"], verbose=False)\n", 241 | " itn = ne(text[\"text\"])\n", 242 | " if itn == text[\"text\"]:\n", 243 | " continue\n", 244 | " result.append({\n", 245 | " \"tn\": text[\"text\"],\n", 246 | " \"itn\": itn,\n", 247 | " \"orig_index\": elem[\"index\"],\n", 248 | " \"text_index\": i_text\n", 249 | " })\n", 250 | " return result\n", 251 | "\n", 252 | "\n", 253 | "b_size = 10000 # you may wnat to decrease it in case of librusec as its texts are vast\n", 254 | "f = open(PATHES[DATASET][\"output\"], \"w\")\n", 255 | "i = 0\n", 256 | "pbar = tqdm(total=len(data))\n", 257 | "while i < len(data):\n", 258 | " with Pool(15) as p:\n", 259 | " result = p.imap_unordered(inverse_normalize, data[i:i + b_size])\n", 260 | " for elem in chain(*result):\n", 261 | " json.dump(elem, f, ensure_ascii=False)\n", 262 | " f.write(\"\\n\")\n", 263 | " i += b_size\n", 264 | " pbar.update(b_size)\n", 265 | "pbar.close()\n", 266 | "f.close()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "151ddaa9-3f2b-453e-9b16-f19598bfa5e4", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "occurances" 277 | ] 278 | } 279 | ], 280 | "metadata": { 281 | "kernelspec": { 282 | "display_name": "Python 3 (ipykernel)", 283 | "language": "python", 284 | "name": "python3" 285 | }, 286 | "language_info": { 287 | "codemirror_mode": { 288 | "name": "ipython", 289 | "version": 3 290 | }, 291 | "file_extension": ".py", 292 | "mimetype": "text/x-python", 293 | "name": "python", 294 | "nbconvert_exporter": "python", 295 | "pygments_lexer": "ipython3", 296 | "version": "3.10.13" 297 | } 298 | }, 299 | "nbformat": 4, 300 | "nbformat_minor": 5 301 | } 302 | -------------------------------------------------------------------------------- /work/dataset/3_process_itn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "50936b71-42e6-4283-91a9-d91c4c3952cd", 6 | "metadata": {}, 7 | "source": [ 8 | "This notebook is to construct `Replaces` objects out of text pairs we have saved before.\n", 9 | "\n", 10 | "Not much comments here so one is necessary to look into `../replaces.py` file (bad idea probably)." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "c2f8da5f-da5f-4e3d-93aa-b101478eb266", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "PATHES = {\n", 21 | " \"librusec\": {\n", 22 | " \"input\": \"/home/jovyan/data/librusec_pairs.jsonl\",\n", 23 | " \"output\": \"/home/jovyan/data/librusec_replaces.jsonl\",\n", 24 | " },\n", 25 | " \"ficbook\": {\n", 26 | " \"input\": \"/home/jovyan/data/ficbook_pairs.jsonl\",\n", 27 | " \"output\": \"/home/jovyan/data/ficbook_replaces.jsonl\",\n", 28 | " },\n", 29 | " \"pikabu\": {\n", 30 | " \"input\": \"/home/jovyan/data/pikabu_pairs.jsonl\",\n", 31 | " \"output\": \"/home/jovyan/data/pikabu_replaces.jsonl\",\n", 32 | " },\n", 33 | "}" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "9481d1b5-47c1-4c80-8fd6-8cdceb1c8ad0", 39 | "metadata": {}, 40 | "source": [ 41 | "I use `Replaces` class as list of changes have been made upon the text." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "747c82a0-96eb-44d7-b1b6-4e5a11573d7f", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import os\n", 52 | "import sys\n", 53 | "sys.path.append(os.path.join(os.path.dirname(os.getcwd())))" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "b43cb93b-762d-4af0-b856-ab9a8e37b92b", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "from replaces import Replace, Replaces\n", 64 | "\n", 65 | "\n", 66 | "Replaces.from_sequences(\n", 67 | " \"мама мыла раму с мылом\".split(),\n", 68 | " \"мама раму уронила\".split(),\n", 69 | " False\n", 70 | ")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "535f86e4-a87c-457d-8912-738491ad796a", 76 | "metadata": {}, 77 | "source": [ 78 | "One is possible to construct Replaces object out from list of dicts so some kind of serialization could be done easily here" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "f347a04c-e5e0-4cd0-88c0-eb2ece0397be", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "Replaces([{\"text_from\": \"a\", \"text_to\": \"a\"}, {\"text_from\": \"b\", \"text_to\": \"c\"}])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "77985a6e-9c9e-46a6-bc61-6627b6a2ab55", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "DATASET = \"librusec\"\n", 99 | "# DATASET = \"ficbook\"\n", 100 | "# DATASET = \"pikabu\"" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "1922bed2-3ed7-4cf8-962f-044d53471e95", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "import json\n", 111 | "from tqdm.auto import tqdm\n", 112 | "\n", 113 | "\n", 114 | "with open(PATHES[DATASET][\"input\"]) as f:\n", 115 | " data = [json.loads(line) for line in tqdm(f)]" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "8dbec4b1-df30-4d35-b0b9-941cd80da2e7", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "import re\n", 126 | "\n", 127 | "\n", 128 | "re_tokens = re.compile(r\"[а-яА-ЯёЁ]+\\s*|[a-zA-Z]+\\s*|\\d+(?:\\.\\d+)?\\s*|[^а-яА-ЯёЁa-zA-Z\\d\\s]+\\s*\")\n", 129 | "\n", 130 | "\n", 131 | "def tokenize(text):\n", 132 | " return re.findall(re_tokens, text)\n", 133 | "\n", 134 | "\n", 135 | "\"|\".join(tokenize(\"ты, да я, да мы c тобой - вместе 2.5°C. C'est la vie! 
iPhone 10C pro15 f2f.\"))" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "id": "87d69de0-1adb-4847-ba39-8a75fc0a99fb", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "for elem in tqdm(data):\n", 146 | " if \"replaces\" not in elem:\n", 147 | " elem[\"replaces\"] = Replaces.from_sequences(tokenize(elem[\"itn\"]), tokenize(elem[\"tn\"]))\n", 148 | " for r1, r2 in zip(elem[\"replaces\"], elem[\"replaces\"][1:]):\n", 149 | " if r1.type != \"E\" and r1.text_from.endswith(\" \") and r1.text_to.endswith(\" \"):\n", 150 | " r1.text_from = r1.text_from[:-1]\n", 151 | " r1.text_to = r1.text_to[:-1]\n", 152 | " r2.text_from = \" \" + r2.text_from\n", 153 | " r2.text_to = \" \" + r2.text_to\n", 154 | "data[0]" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "b5de8bfe-dcae-4d29-9e97-0f128c7c5fff", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "for i, elem in enumerate(data[:10]):\n", 165 | " print(f'{i}\\n{elem[\"replaces\"]}\\n')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "175ab729-3fea-42c3-9f96-7f05672e5c63", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "import json\n", 176 | "\n", 177 | "\n", 178 | "with open(PATHES[DATASET][\"output\"], \"w\") as f:\n", 179 | " for elem in tqdm(data):\n", 180 | " json.dump({\"replaces\": elem[\"replaces\"]}, f, ensure_ascii=False)\n", 181 | " f.write(\"\\n\")" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "d8d658ba-988e-4a7c-8496-b9fbdb2aca8e", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "data[1]" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "Python 3 (ipykernel)", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.10.13" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 5 216 | } 217 | -------------------------------------------------------------------------------- /work/dataset/4_process_kaggle.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a1eb69d9-90f1-400c-8142-9e000cec037f", 6 | "metadata": {}, 7 | "source": [ 8 | "This notebook is for [Kaggle Russian Normalization challenge](https://www.kaggle.com/competitions/text-normalization-challenge-russian-language).\n", 9 | "\n", 10 | "In order to reproduce the results one is necessary to download `ru_train.csv` file trom the challenge website and put it alongside the notebook." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "c2f8da5f-da5f-4e3d-93aa-b101478eb266", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "PATHES = {\n", 21 | " \"load\": \"ru_train.csv\",\n", 22 | " \"save\": \"/home/jovyan/data/kaggle.jsonl\"\n", 23 | "}" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "9481d1b5-47c1-4c80-8fd6-8cdceb1c8ad0", 29 | "metadata": {}, 30 | "source": [ 31 | "I use `Replaces` class as list of changes have been made upon the text." 
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "id": "747c82a0-96eb-44d7-b1b6-4e5a11573d7f", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import os\n", 42 | "import sys\n", 43 | "sys.path.append(os.path.join(os.path.dirname(os.getcwd())))" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "id": "b43cb93b-762d-4af0-b856-ab9a8e37b92b", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "E|мама\n", 56 | "R|мыла => \n", 57 | "E|раму\n", 58 | "R|смылом => уронила" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "from replaces import Replace, Replaces\n", 68 | "\n", 69 | "\n", 70 | "Replaces.from_sequences(\n", 71 | " \"мама мыла раму с мылом\".split(),\n", 72 | " \"мама раму уронила\".split(),\n", 73 | " False\n", 74 | ")" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "535f86e4-a87c-457d-8912-738491ad796a", 80 | "metadata": {}, 81 | "source": [ 82 | "One is possible to construct Replaces object out from list of dicts so some kind of serialization could be done easily here" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "id": "f347a04c-e5e0-4cd0-88c0-eb2ece0397be", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "E|a\n", 95 | "R|b => c" 96 | ] 97 | }, 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "Replaces([{\"text_from\": \"a\", \"text_to\": \"a\"}, {\"text_from\": \"b\", \"text_to\": \"c\"}])" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "6e748a92-c4c2-440b-b5ad-36e60081b8e9", 110 | "metadata": {}, 111 | "source": [ 112 | "Parse kaggle train file" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 5, 118 | "id": "3fad0379-b19c-4073-85f0-5cdca37b7324", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stderr", 123 | "output_type": "stream", 124 | "text": [ 125 | "10574516it [00:08, 1181950.33it/s]\n" 126 | ] 127 | }, 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "(761436,\n", 132 | " {0: {'class': 'PLAIN', 'before': 'По', 'after': 'По'},\n", 133 | " 1: {'class': 'PLAIN', 'before': 'состоянию', 'after': 'состоянию'},\n", 134 | " 2: {'class': 'PLAIN', 'before': 'на', 'after': 'на'},\n", 135 | " 3: {'class': 'DATE',\n", 136 | " 'before': '1862 год',\n", 137 | " 'after': 'тысяча восемьсот шестьдесят второй год'},\n", 138 | " 4: {'class': 'PUNCT', 'before': '.', 'after': '.'}})" 139 | ] 140 | }, 141 | "execution_count": 5, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "from collections import defaultdict\n", 148 | "from tqdm import tqdm\n", 149 | "import csv\n", 150 | "\n", 151 | "\n", 152 | "data = defaultdict(dict)\n", 153 | "with open(PATHES[\"load\"]) as f:\n", 154 | " reader = csv.reader(f)\n", 155 | " next(reader, None) # ['sentence_id', 'token_id', 'class', 'before', 'after']\n", 156 | " for row in tqdm(reader):\n", 157 | " data[int(row[0])][int(row[1])] = {\n", 158 | " \"class\": row[2],\n", 159 | " \"before\": row[3],\n", 160 | " \"after\": row[4],\n", 161 | " }\n", 162 | "len(data), data[0]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "id": "2272846b-dfbc-4851-a068-1ba736754848", 168 | "metadata": {}, 169 | "source": [ 170 | "Quick check on are tokens indices ok" 171 | ] 172 | }, 
173 | { 174 | "cell_type": "code", 175 | "execution_count": 6, 176 | "id": "176af51a-04c1-489d-a39a-bda19f10ea51", 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "for sent in data.values():\n", 181 | " if len(sent) == len(set(sent)) == max(sent) + 1:\n", 182 | " continue\n", 183 | " print(sent)\n", 184 | " break" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "15824559-7c3b-45ae-8c1c-00efc487f4a1", 190 | "metadata": {}, 191 | "source": [ 192 | "Reformat it" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "id": "32801693-2f6b-44c7-b50a-148cbd4687fa", 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stderr", 203 | "output_type": "stream", 204 | "text": [ 205 | "100%|██████████| 761436/761436 [00:03<00:00, 219834.88it/s]\n" 206 | ] 207 | }, 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "{'sentence_id': 0,\n", 212 | " 'tokens': [{'class': 'PLAIN', 'before': 'По', 'after': 'По'},\n", 213 | " {'class': 'PLAIN', 'before': 'состоянию', 'after': 'состоянию'},\n", 214 | " {'class': 'PLAIN', 'before': 'на', 'after': 'на'},\n", 215 | " {'class': 'DATE',\n", 216 | " 'before': '1862 год',\n", 217 | " 'after': 'тысяча восемьсот шестьдесят второй год'},\n", 218 | " {'class': 'PUNCT', 'before': '.', 'after': '.'}]}" 219 | ] 220 | }, 221 | "execution_count": 7, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "data = [{\n", 228 | " \"sentence_id\": sent_id,\n", 229 | " \"tokens\": [token for i_token, token in sorted(sent.items(), key=lambda x: x[0])]\n", 230 | "} for sent_id, sent in tqdm(data.items())]\n", 231 | "data[0]" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "id": "46e5d535-bde8-4de4-9346-0d308935d9d0", 237 | "metadata": {}, 238 | "source": [ 239 | "Check on what classes are there" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 8, 245 | "id": "abb9a1c8-9441-4159-956c-d7033ea07d70", 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "text/plain": [ 251 | "{'CARDINAL',\n", 252 | " 'DATE',\n", 253 | " 'DECIMAL',\n", 254 | " 'DIGIT',\n", 255 | " 'ELECTRONIC',\n", 256 | " 'FRACTION',\n", 257 | " 'LETTERS',\n", 258 | " 'MEASURE',\n", 259 | " 'MONEY',\n", 260 | " 'ORDINAL',\n", 261 | " 'PLAIN',\n", 262 | " 'PUNCT',\n", 263 | " 'TELEPHONE',\n", 264 | " 'TIME',\n", 265 | " 'VERBATIM'}" 266 | ] 267 | }, 268 | "execution_count": 8, 269 | "metadata": {}, 270 | "output_type": "execute_result" 271 | } 272 | ], 273 | "source": [ 274 | "from itertools import chain\n", 275 | "set(chain(*[[_[\"class\"] for _ in sent[\"tokens\"]] for sent in data]))" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "id": "8e410be1-5f68-4878-a911-cddf4ac226a7", 281 | "metadata": {}, 282 | "source": [ 283 | "One is necessary to polish spaces of tokens" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "id": "95a58d8e-297d-4f44-bfb6-3e2d8a0e9bc5", 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "name": "stderr", 294 | "output_type": "stream", 295 | "text": [ 296 | "100%|██████████| 761436/761436 [00:06<00:00, 126458.97it/s]\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "import re\n", 302 | "\n", 303 | "\n", 304 | "re_trans = re.compile(r\"_(trans|latin) *\")\n", 305 | "examples = []\n", 306 | "for elem in tqdm(data):\n", 307 | " for key in (\"before\", \"after\"):\n", 308 | " for token in elem[\"tokens\"]:\n", 309 | " if \"_\" in token[key]: # quicker check to 
speedup\n", 310 | " token[key] = re.sub(re_trans, \"\", str(token[key]))\n", 311 | " for i1, (t1, t2) in enumerate(zip(elem[\"tokens\"], elem[\"tokens\"][1:])):\n", 312 | " if t1[key] in (\"(\", \"«\"):\n", 313 | " t1[key] = \" \" + t1[key]\n", 314 | " t2[key] = t2[key].strip()\n", 315 | " elif t2[\"class\"] == \"PUNCT\":\n", 316 | " if t2[key] == \"—\":\n", 317 | " t1[key] += \" \"\n", 318 | " else:\n", 319 | " pass\n", 320 | " elif t1[\"class\"] == t2[\"class\"] == \"ORDINAL\" and t2[key].startswith(\"—\"):\n", 321 | " pass\n", 322 | " elif t1[\"class\"] == \"VERBATIM\" or t2[\"class\"] == \"VERBATIM\" and t1[\"class\"] != \"PUNCT\":\n", 323 | " pass\n", 324 | " else:\n", 325 | " t1[key] += \" \"\n" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "id": "7d88eaf9-afee-4339-85d8-de33f3822774", 331 | "metadata": {}, 332 | "source": [ 333 | "Check whether everything went ok" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 10, 339 | "id": "4dceb061-709d-4808-9a03-8cb32d3ead11", 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "data": { 344 | "text/plain": [ 345 | "{'sentence_id': 0,\n", 346 | " 'tokens': [{'class': 'PLAIN', 'before': 'По ', 'after': 'По '},\n", 347 | " {'class': 'PLAIN', 'before': 'состоянию ', 'after': 'состоянию '},\n", 348 | " {'class': 'PLAIN', 'before': 'на ', 'after': 'на '},\n", 349 | " {'class': 'DATE',\n", 350 | " 'before': '1862 год',\n", 351 | " 'after': 'тысяча восемьсот шестьдесят второй год'},\n", 352 | " {'class': 'PUNCT', 'before': '.', 'after': '.'}]}" 353 | ] 354 | }, 355 | "execution_count": 10, 356 | "metadata": {}, 357 | "output_type": "execute_result" 358 | } 359 | ], 360 | "source": [ 361 | "data[0]" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 11, 367 | "id": "9d43ada1-8350-48a8-b674-4b7680d94a5c", 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "0 По состоянию на 1862 год.\n", 375 | "0 По состоянию на тысяча восемьсот шестьдесят второй год.\n", 376 | "1 Оснащались латными рукавицами и сабатонами с не длинными носками.\n", 377 | "1 Оснащались латными рукавицами и сабатонами с не длинными носками.\n", 378 | "2 В конце 1811 года, вследствие конфликта с проезжим вельможей (графом Салтыковым) вынужден был оставить службу по личному прошению.\n", 379 | "2 В конце тысяча восемьсот одиннадцатого года, вследствие конфликта с проезжим вельможей (графом Салтыковым) вынужден был оставить службу по личному прошению.\n", 380 | "3 Тиберий Юлий Поллиен Ауспекс (лат. Tiberius Julius Pollienus Auspex) — римский политический деятель начала III века.\n", 381 | "3 Тиберий Юлий Поллиен Ауспекс (лат. тибериус джулиус поллиенус оспекс) — римский политический деятель начала третьего века.\n", 382 | "4 Севернее Дудинки и северо-восточнее Белочи, в низменной долине Неруссы — урочище Узлив.\n", 383 | "4 Севернее Дудинки и северо-восточнее Белочи, в низменной долине Неруссы — урочище Узлив.\n", 384 | "5 Получение информации об адресах, почтовых индексах, странах, городах.\n", 385 | "5 Получение информации об адресах, почтовых индексах, странах, городах.\n", 386 | "6 Проверено 12 февраля 2013. Архивировано из первоисточника 15 февраля 2013. TV, ты меня не любишь?\n", 387 | "6 Проверено двенадцатого февраля две тысячи тринадцатого года. Архивировано из первоисточника пятнадцатого февраля две тысячи тринадцатого года. t v, ты меня не любишь?\n", 388 | "7 Теперь все уважительно зовут Ямамото Аники (яп. 
— 兄貴, Старший брат).\n", 389 | "7 Теперь все уважительно зовут Ямамото Аники (яп. — 兄貴, Старший брат).\n", 390 | "8 Муниципалитет находится в составе района (комарки) Альто-Дева.\n", 391 | "8 Муниципалитет находится в составе района (комарки) Альто-Дева.\n", 392 | "9 Впоследствии многие пилоты, которые были коллегами экипажа рейса 254, были озадачены: как можно было допустить такую ошибку.\n", 393 | "9 Впоследствии многие пилоты, которые были коллегами экипажа рейса двести пятьдесят четыре, были озадачены: как можно было допустить такую ошибку.\n" 394 | ] 395 | } 396 | ], 397 | "source": [ 398 | "for i, elem in enumerate(data[:10]):\n", 399 | " for key in (\"before\", \"after\"):\n", 400 | " print(i, \"\".join([_[key] for _ in elem[\"tokens\"]]))" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "id": "5778719c-0ee7-4b69-9fdb-a51e5e88b907", 406 | "metadata": {}, 407 | "source": [ 408 | "Looks well done. Construct Replaces now." 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 12, 414 | "id": "e223186c-3b6a-4e13-b81a-7e76271dbfc9", 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "name": "stderr", 419 | "output_type": "stream", 420 | "text": [ 421 | "100%|██████████| 761436/761436 [00:16<00:00, 45418.81it/s]\n" 422 | ] 423 | } 424 | ], 425 | "source": [ 426 | "for elem in tqdm(data):\n", 427 | " elem[\"replaces\"] = Replaces.from_sequences(\n", 428 | " [_[\"before\"] for _ in elem[\"tokens\"]],\n", 429 | " [_[\"after\"] for _ in elem[\"tokens\"]],\n", 430 | " False\n", 431 | " )\n", 432 | " for r1, r2 in zip(elem[\"replaces\"], elem[\"replaces\"][1:]):\n", 433 | " if r1.type != \"E\" and r1.text_from.endswith(\" \") and r1.text_to.endswith(\" \"):\n", 434 | " r1.text_from = r1.text_from[:-1]\n", 435 | " r1.text_to = r1.text_to[:-1]\n", 436 | " r2.text_from = \" \" + r2.text_from\n", 437 | " r2.text_to = \" \" + r2.text_to" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "id": "3b0d838d-8a28-44d5-95f1-170eaba16175", 443 | "metadata": {}, 444 | "source": [ 445 | "Get rid of examples where latin or digits exist in resulting text" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 13, 451 | "id": "8f685be8-e1b0-46fd-914b-a7d56edfe29d", 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "name": "stderr", 456 | "output_type": "stream", 457 | "text": [ 458 | "100%|██████████| 761436/761436 [00:02<00:00, 290752.10it/s]\n" 459 | ] 460 | }, 461 | { 462 | "data": { 463 | "text/plain": [ 464 | "(761436, 378074)" 465 | ] 466 | }, 467 | "execution_count": 13, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "re_digits_latin = re.compile(r\"[a-zA-Z\\d]\")\n", 474 | "good_data = []\n", 475 | "for elem in tqdm(data):\n", 476 | " if all(r.type == \"E\" for r in elem[\"replaces\"]):\n", 477 | " continue\n", 478 | " is_ok = True\n", 479 | " for r in elem[\"replaces\"]:\n", 480 | " if r.type == \"E\" and re.search(re_digits_latin, r.text_from):\n", 481 | " is_ok = False\n", 482 | " break\n", 483 | " if re.search(re_digits_latin, r.text_to):\n", 484 | " is_ok = False\n", 485 | " break\n", 486 | " if is_ok:\n", 487 | " good_data.append(elem)\n", 488 | "len(data), len(good_data)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "id": "f77d2b4d-718d-4f02-81e6-e151f3d11c6c", 494 | "metadata": {}, 495 | "source": [ 496 | "Check on how many examples we have so far and how latin and digits are distributed there." 
497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 14, 502 | "id": "66a93e72-59db-46c1-ac58-ec7f4dc18d83", 503 | "metadata": {}, 504 | "outputs": [ 505 | { 506 | "name": "stderr", 507 | "output_type": "stream", 508 | "text": [ 509 | "100%|██████████| 378074/378074 [00:00<00:00, 597261.03it/s]\n" 510 | ] 511 | }, 512 | { 513 | "name": "stdout", 514 | "output_type": "stream", 515 | "text": [ 516 | "re_digits 293157\n" 517 | ] 518 | }, 519 | { 520 | "name": "stderr", 521 | "output_type": "stream", 522 | "text": [ 523 | "100%|██████████| 378074/378074 [00:00<00:00, 621195.20it/s]\n" 524 | ] 525 | }, 526 | { 527 | "name": "stdout", 528 | "output_type": "stream", 529 | "text": [ 530 | "re_digits_latin 344709\n" 531 | ] 532 | }, 533 | { 534 | "name": "stderr", 535 | "output_type": "stream", 536 | "text": [ 537 | "100%|██████████| 378074/378074 [00:00<00:00, 434262.38it/s]" 538 | ] 539 | }, 540 | { 541 | "name": "stdout", 542 | "output_type": "stream", 543 | "text": [ 544 | "re_latin 100163\n" 545 | ] 546 | }, 547 | { 548 | "name": "stderr", 549 | "output_type": "stream", 550 | "text": [ 551 | "\n" 552 | ] 553 | } 554 | ], 555 | "source": [ 556 | "stat_regs = {\n", 557 | " \"re_digits\": re.compile(r\"\\d\"),\n", 558 | " \"re_digits_latin\": re.compile(r\"[a-zA-Z\\d]\"),\n", 559 | " \"re_latin\": re.compile(r\"[a-zA-Z]\")\n", 560 | "}\n", 561 | "for stat_name, stat_re in stat_regs.items():\n", 562 | " print(\n", 563 | " stat_name,\n", 564 | " len([elem for elem in tqdm(good_data) if any(re.search(stat_re, r.text_from) for r in elem[\"replaces\"])])\n", 565 | " )" 566 | ] 567 | }, 568 | { 569 | "cell_type": "markdown", 570 | "id": "d3442f2b-6007-4533-9290-52df88b1f976", 571 | "metadata": {}, 572 | "source": [ 573 | "Look of what they are." 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 15, 579 | "id": "a07c4ed4-780c-4bfa-8e22-46f98333180c", 580 | "metadata": {}, 581 | "outputs": [ 582 | { 583 | "name": "stdout", 584 | "output_type": "stream", 585 | "text": [ 586 | "0\n", 587 | "E|По состоянию на \n", 588 | "R|1862 год => тысяча восемьсот шестьдесят второй год\n", 589 | "E|.\n", 590 | "\n", 591 | "2\n", 592 | "E|В конце \n", 593 | "R|1811 года => тысяча восемьсот одиннадцатого года\n", 594 | "E|, вследствие конфликта с проезжим вельможей (графом Салтыковым) вынужден был оставить службу по личному прошению.\n", 595 | "\n", 596 | "3\n", 597 | "E|Тиберий Юлий Поллиен Ауспекс (лат. \n", 598 | "R|Tiberius Julius Pollienus Auspex => тибериус джулиус поллиенус оспекс\n", 599 | "E|) — римский политический деятель начала \n", 600 | "R|III => третьего\n", 601 | "E| века.\n", 602 | "\n", 603 | "9\n", 604 | "E|Впоследствии многие пилоты, которые были коллегами экипажа рейса \n", 605 | "R|254 => двести пятьдесят четыре\n", 606 | "E|, были озадачены: как можно было допустить такую ошибку.\n", 607 | "\n", 608 | "10\n", 609 | "E|Полудоспех — англ. 
\n", 610 | "R|Half Armor => халф армор\n", 611 | "E| — латная защита рук и корпуса.\n", 612 | "\n", 613 | "11\n", 614 | "E|в \n", 615 | "R|1895—1896 => тысяча восемьсот девяносто пятом тысяча восемьсот девяносто шестом\n", 616 | "E| годах служил на Черноморском флоте на канонерской лодке «Терец».\n", 617 | "\n", 618 | "12\n", 619 | "E|Данная поправка была внесена на рассмотрение Съезда народных депутатов \n", 620 | "R|РСФСР => р с ф с р\n", 621 | "E|.\n", 622 | "\n", 623 | "13\n", 624 | "E|Революция \n", 625 | "R|1905 года => тысяча девятьсот пятого года\n", 626 | "E| потерпела поражение.\n", 627 | "\n", 628 | "15\n", 629 | "E|Производством сыра занимается компания \n", 630 | "R|Sbrinz Kase GmbH => сбринс кейс гмб\n", 631 | "E|.\n", 632 | "\n", 633 | "21\n", 634 | "E|Проверено \n", 635 | "R|17 июля 2014 => семнадцатого июля две тысячи четырнадцатого года\n", 636 | "E|. \n", 637 | "R|The next supermoon in 2014 is July 12 => зэ некст супермун ин две тысячи четырнадцать ис джули двенадцать\n", 638 | "E| (англ.).\n", 639 | "\n" 640 | ] 641 | } 642 | ], 643 | "source": [ 644 | "for i, elem in enumerate(good_data[:10]):\n", 645 | " print(f'{elem[\"sentence_id\"]}\\n{elem[\"replaces\"]}\\n')" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "id": "455ee756-98cc-41ba-b93d-e2c687f427ae", 651 | "metadata": {}, 652 | "source": [ 653 | "Save it finally" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 16, 659 | "id": "175ab729-3fea-42c3-9f96-7f05672e5c63", 660 | "metadata": {}, 661 | "outputs": [ 662 | { 663 | "name": "stderr", 664 | "output_type": "stream", 665 | "text": [ 666 | "100%|██████████| 378074/378074 [00:04<00:00, 77999.29it/s]\n" 667 | ] 668 | } 669 | ], 670 | "source": [ 671 | "import json\n", 672 | "\n", 673 | "\n", 674 | "with open(\"kaggle.jsonl\", \"w\") as f:\n", 675 | " for elem in tqdm(good_data):\n", 676 | " json.dump(\n", 677 | " {\n", 678 | " \"sentence_id\": elem[\"sentence_id\"],\n", 679 | " \"replaces\": elem[\"replaces\"],\n", 680 | " },\n", 681 | " f,\n", 682 | " ensure_ascii=False\n", 683 | " )\n", 684 | " f.write(\"\\n\")" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 18, 690 | "id": "38ca137e-892d-4f37-9cc3-9d855394b322", 691 | "metadata": {}, 692 | "outputs": [ 693 | { 694 | "name": "stdout", 695 | "output_type": "stream", 696 | "text": [ 697 | "378074 /home/jovyan/data/kaggle.jsonl\n" 698 | ] 699 | } 700 | ], 701 | "source": [ 702 | "!wc -l {PATHES[\"save\"]}" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": 19, 708 | "id": "b691b0d8-e382-42cd-956f-14166deffd37", 709 | "metadata": {}, 710 | "outputs": [ 711 | { 712 | "name": "stdout", 713 | "output_type": "stream", 714 | "text": [ 715 | "-rw-r--r-- 1 jovyan users 187M Jan 17 19:59 /home/jovyan/data/kaggle.jsonl\n" 716 | ] 717 | } 718 | ], 719 | "source": [ 720 | "!ls -lh {PATHES[\"save\"]}" 721 | ] 722 | } 723 | ], 724 | "metadata": { 725 | "kernelspec": { 726 | "display_name": "Python 3 (ipykernel)", 727 | "language": "python", 728 | "name": "python3" 729 | }, 730 | "language_info": { 731 | "codemirror_mode": { 732 | "name": "ipython", 733 | "version": 3 734 | }, 735 | "file_extension": ".py", 736 | "mimetype": "text/x-python", 737 | "name": "python", 738 | "nbconvert_exporter": "python", 739 | "pygments_lexer": "ipython3", 740 | "version": "3.10.13" 741 | } 742 | }, 743 | "nbformat": 4, 744 | "nbformat_minor": 5 745 | } 746 | -------------------------------------------------------------------------------- 
/work/dataset/word_to_number_ru/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 2019-2021 Shkarin Sergey 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | The above copyright notice and this permission notice shall be included in all 10 | copies or substantial portions of the Software. 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 18 | -------------------------------------------------------------------------------- /work/dataset/word_to_number_ru/extractor.py: -------------------------------------------------------------------------------- 1 | from .number import NUMBER 2 | from natasha.extractors import Extractor 3 | import math 4 | 5 | 6 | class NumberExtractor(Extractor): 7 | def __init__(self): 8 | super(NumberExtractor, self).__init__(NUMBER) 9 | 10 | @staticmethod 11 | def __n_digits(n): 12 | if n > 0: 13 | digits = int(math.log10(n)) + 1 14 | elif n == 0: 15 | digits = 1 16 | else: 17 | digits = int(math.log10(-n)) + 2 # +1 if you don't count the '-' 18 | return digits 19 | 20 | @staticmethod 21 | def __trailing_zeros(n: int): 22 | """ 23 | Count trailing zeros of a number 24 | 25 | Args: 26 | n: number 27 | 28 | Result: 29 | cnt: count of zeros 30 | """ 31 | cnt = 0 32 | while n % 10 == 0 and n != 0: 33 | cnt += 1 34 | n = n / 10 35 | return cnt 36 | 37 | def _get_groups(self, text): 38 | start = 0 39 | matches = list(self.parser.findall(text)) 40 | groups = [] 41 | group_matches = [] 42 | for i, match in enumerate(matches): 43 | if i == 0: 44 | start = match.span.start 45 | if i == len(matches) - 1: 46 | next_match = match 47 | else: 48 | next_match = matches[i + 1] 49 | group_matches.append(match.fact) 50 | if text[match.span.stop: next_match.span.start].strip() or next_match == match: 51 | groups.append((group_matches, start, match.span.stop)) 52 | group_matches = [] 53 | start = next_match.span.start 54 | return groups 55 | 56 | def __call__(self, text): 57 | """ 58 | Замена сгруппированных составных чисел в тексте и отдельно стоящих чисел без их суммирования 59 | 60 | Аргументы: 61 | text: исходный текст 62 | 63 | Результат: 64 | new_text: текст с замененными числами 65 | """ 66 | groups = self._get_groups(text) 67 | new_text = '' 68 | start = 0 69 | for group in groups: 70 | new_text += text[start:group[1]] 71 | nums = [] 72 | prev_tz = 0 73 | prev_mult = None 74 | for match in group[0]: 75 | mult = match.multiplier if match.multiplier else 1 76 | curr_num = (match.int if match.int is not None else 1) + (match.with_half or 0) 77 | tz = self.__trailing_zeros(curr_num) 78 | if tz < prev_tz and mult >= prev_mult and curr_num != 0 and \ 79 | 
self.__n_digits(curr_num) < self.__n_digits(nums[0][0]) and \ 80 | self.__n_digits(curr_num) <= prev_tz: 81 | nums[0] = (nums[0][0] + curr_num, mult) 82 | else: 83 | nums.insert(0, (curr_num, mult)) 84 | prev_mult = mult 85 | prev_tz = tz 86 | prev_mult = None 87 | new_nums = [] 88 | for num, mult in nums: 89 | if mult == 10 ** -1: 90 | power = 1 91 | elif mult == 10 ** -2: 92 | power = 2 93 | elif mult == 10 ** -3: 94 | power = 3 95 | else: 96 | power = None 97 | new_num = round(num * mult, power) if power else num * mult 98 | if not prev_mult or mult <= prev_mult: 99 | new_nums.append(new_num) 100 | else: 101 | new_nums[-1] += new_num 102 | prev_mult = mult 103 | new_nums = [int(_) if isinstance(_, float) and _ == int(_) else _ for _ in new_nums[::-1]] 104 | new_text += ' '.join(map(str, new_nums)) 105 | start = group[2] 106 | new_text += text[start:] 107 | return new_text -------------------------------------------------------------------------------- /work/dataset/word_to_number_ru/number.py: -------------------------------------------------------------------------------- 1 | from yargy import rule, or_ 2 | from yargy.pipelines import morph_pipeline, caseless_pipeline 3 | from yargy.interpretation import fact, const 4 | from yargy.predicates import eq, caseless, normalized, type 5 | 6 | Number = fact('Number', ['int', 'with_half', 'multiplier']) 7 | NUMS_RAW = { 8 | 'ноль': 0, 9 | 'нуль': 0, 10 | 'один': 1, 11 | 'полтора': 1.5, 12 | 'два': 2, 13 | 'три': 3, 14 | 'четыре': 4, 15 | 'пять': 5, 16 | 'шесть': 6, 17 | 'семь': 7, 18 | 'восемь': 8, 19 | 'девять': 9, 20 | 'десять': 10, 21 | 'одиннадцать': 11, 22 | 'двенадцать': 12, 23 | 'тринадцать': 13, 24 | 'четырнадцать': 14, 25 | 'пятнадцать': 15, 26 | 'шестнадцать': 16, 27 | 'семнадцать': 17, 28 | 'восемнадцать': 18, 29 | 'девятнадцать': 19, 30 | 'двадцать': 20, 31 | 'тридцать': 30, 32 | 'сорок': 40, 33 | 'пятьдесят': 50, 34 | 'шестьдесят': 60, 35 | 'семьдесят': 70, 36 | 'восемьдесят': 80, 37 | 'девяносто': 90, 38 | 'сто': 100, 39 | 'двести': 200, 40 | 'триста': 300, 41 | 'четыреста': 400, 42 | 'пятьсот': 500, 43 | 'шестьсот': 600, 44 | 'семьсот': 700, 45 | 'восемьсот': 800, 46 | 'девятьсот': 900, 47 | } 48 | NUMS_RAW_BIG = { 49 | 'тысяча': 10 ** 3, 50 | 'миллион': 10 ** 6, 51 | 'миллиард': 10 ** 9, 52 | 'триллион': 10 ** 12, 53 | } 54 | 55 | DOT = eq('.') 56 | INT = type('INT') 57 | THOUSANDTH = rule(caseless_pipeline(['тысячных', 'тысячная'])).interpretation(const(10**-3)) 58 | HUNDREDTH = rule(caseless_pipeline(['сотых', 'сотая'])).interpretation(const(10**-2)) 59 | TENTH = rule(caseless_pipeline(['десятых', 'десятая'])).interpretation(const(10**-1)) 60 | INTEGER_PART = rule(caseless_pipeline(['целых', 'целая'])).interpretation(const(10**0)) 61 | THOUSAND = or_( 62 | rule(caseless('т'), DOT), 63 | rule(caseless('тыс'), DOT.optional()), 64 | rule(normalized('тысяча')), 65 | rule(normalized('тыща')) 66 | ).interpretation(const(10**3)) 67 | MILLION = or_( 68 | rule(caseless('млн'), DOT.optional()), 69 | rule(normalized('миллион')) 70 | ).interpretation(const(10**6)) 71 | MILLIARD = or_( 72 | rule(caseless('млрд'), DOT.optional()), 73 | rule(normalized('миллиард')) 74 | ).interpretation(const(10**9)) 75 | TRILLION = or_( 76 | rule(caseless('трлн'), DOT.optional()), 77 | rule(normalized('триллион')) 78 | ).interpretation(const(10**12)) 79 | WITH_HALF = or_( 80 | rule(caseless('с'), normalized('половина')), 81 | ).interpretation(const(0.5)).interpretation(Number.with_half) 82 | MULTIPLIER = or_( 83 | THOUSANDTH, 84 | HUNDREDTH, 85 | TENTH, 86 
| INTEGER_PART, 87 | THOUSAND, 88 | MILLION, 89 | MILLIARD, 90 | TRILLION 91 | ).interpretation(Number.multiplier) 92 | NUM_RAW_BIG = rule(morph_pipeline(NUMS_RAW_BIG).interpretation(Number.multiplier.normalized().custom(NUMS_RAW_BIG.get))) 93 | NUM_RAW = rule(morph_pipeline(NUMS_RAW).interpretation(Number.int.normalized().custom(NUMS_RAW.get))) 94 | NUM_INT = rule(INT).interpretation(Number.int.custom(int)) 95 | NUM = or_( 96 | NUM_RAW_BIG, 97 | NUM_RAW, 98 | NUM_INT 99 | ) 100 | NUMBER = or_( 101 | rule( 102 | NUM, 103 | WITH_HALF.optional(), 104 | MULTIPLIER.optional() 105 | ) 106 | ).interpretation(Number) 107 | -------------------------------------------------------------------------------- /work/infer/examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "я купил iphone 12X за 142 990 руб без 3-x часов полдень и т.д.": "я купил айфон двенадцать икс за сто сорок две тысячи девятьсот девяносто руб без трёх часов полдень и т.д.", 3 | "я купил айфон за 14 970 рублей": "я купил айфон за четырнадцать тысяч девятьсот семьдесят рублей", 4 | "Временами я думаю, какое применение найти тем 14 697 рублям, что лежат уже больше 33 лет?": "Временами я думаю, какое применение найти тем четырнадцати тысячам шестистам девяносто семи рублям, что лежат уже больше тридцати трёх лет?", 5 | "Было у отца 3 сына, но не было даже 2-3 пиджаков с блёстками за 142 990 рублей.": "Было у отца три сына, но не было даже двух-трёх пиджаков с блёстками за сто сорок две тысячи девятьсто девяносто рублей.", 6 | "В школе у меня одни 5.": "В школе у меня одни пятёрки.", 7 | "Было у отца 3 сына. Старшему было 35, среднему - не меньше 33, а младший на 4 младше всех. Бывает.": "Было у отца три сына. Старшему было тридцать пять, среднему - не меньше тридцати трех, а младший на четыре младше всех. Бывает.", 8 | "я вырос на the beatles, меня не испугать даже 33 yellow submarine": "я вырос на зе битлз, меня не испугать даже тридцатью тремя йеллоу сабмэрин", 9 | "слыш nigga ты слыхал про gitdata?": "слыш нигга ты слыхал про гитдата?", 10 | "стоимость samsung 32MX Pro — всего 189 600 руб!": "стоимость самсунг тридцать два эм икс про — всего сто восемьдесят девять тысяч шестьсот руб!" 
11 | } -------------------------------------------------------------------------------- /work/infer/infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "b7ac75be-4e65-4949-bfed-83912bf307da", 7 | "metadata": { 8 | "scrolled": true 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install gradio sentencepiece" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "8bff11d0-f103-4139-b94b-6dab26563795", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import gradio as gr\n", 23 | "import json\n", 24 | "import re\n", 25 | "import torch\n", 26 | "from transformers import GPT2Tokenizer, T5ForConditionalGeneration\n", 27 | "from IPython.display import IFrame" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "8d1b2e67-8a5e-4fde-a87b-7289236ec56d", 34 | "metadata": { 35 | "scrolled": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "# device = \"cuda:0\"\n", 40 | "device = \"cpu\"\n", 41 | "HOST_IP = \"192.168.31.167\"\n", 42 | "GRADIO_PORT = 7860" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "a34f2d4b-8f1e-4668-99a6-20a34d35e3ec", 48 | "metadata": {}, 49 | "source": [ 50 | "## FRED-T5-large-FT" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "00325c3d-d348-49e9-9c8d-7b387d1f7fcb", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# path = \"/home/jovyan/wdc1/models/FRED-T5-large\"\n", 61 | "path = \"/home/jovyan/models/3_fred-t5/checkpoint-11000\"\n", 62 | "path = \"/home/jovyan/models/7_fred-t5-large/checkpoint-35000\"\n", 63 | "tokenizer = GPT2Tokenizer.from_pretrained(path, eos_token='')\n", 64 | "model = T5ForConditionalGeneration.from_pretrained(path).to(device)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "id": "5a18857a-faf0-42b6-9df4-a31ba2f360c2", 70 | "metadata": {}, 71 | "source": [ 72 | "## ruT5-base-FT" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "8990431a-fb43-4a87-8bef-34fa0644dfd6", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", 83 | "path = \"/home/jovyan/models/8_ruT5-base/checkpoint-17000/\"\n", 84 | "model = T5ForConditionalGeneration.from_pretrained(path).to(device)\n", 85 | "tokenizer = T5Tokenizer.from_pretrained(path)\n", 86 | "tokenizer.add_tokens(\"\\n\")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "f27a9773-255f-4be0-b896-274d6dfb2ea3", 92 | "metadata": {}, 93 | "source": [ 94 | "## Common code then" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "72df65fc-34da-49ea-8d06-f26647cadea6", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "def predict(text):\n", 105 | " input_ids = torch.tensor([tokenizer.encode(text)]).to(device)\n", 106 | " with torch.no_grad():\n", 107 | " outputs = model.generate(input_ids, max_new_tokens=50, eos_token_id=tokenizer.eos_token_id, early_stopping=True)\n", 108 | " return tokenizer.decode(outputs[0][1:])\n", 109 | "\n", 110 | "\n", 111 | "predict(\"Было у отца [3] сына, но не было даже [2- 3] пиджаков с блёстками за [142 990] руб.\")" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "67e8bd54-af8f-4dcb-a2b5-39eb96f5439b", 118 | "metadata": {}, 119 | "outputs": [], 120 | 
"source": [ 121 | "with open(\"examples.json\") as f:\n", 122 | " test_examples = json.load(f)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "3617f748-5fa9-40f3-a8d6-a366c6ce8d73", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# re_tokens = re.compile(r\"[а-яА-Я]+\\s*|\\d+(?:\\.\\d+)?\\s*|[^а-яА-Я\\d\\s]+\\s*\")\n", 133 | "re_tokens = re.compile(r\"(?:[.,!?]|[а-яА-Я]\\S*|\\d\\S*(?:\\.\\d+)?|[^а-яА-Я\\d\\s]+)\\s*\")\n", 134 | "\n", 135 | "\n", 136 | "def tokenize(text):\n", 137 | " return re.findall(re_tokens, text)\n", 138 | "\n", 139 | "\n", 140 | "def strip_numbers(s):\n", 141 | " result = []\n", 142 | " for part in s.split():\n", 143 | " if part.isdigit():\n", 144 | " while len(part) > 3:\n", 145 | " result.append(part[:- 3 * ((len(part) - 1) // 3)])\n", 146 | " part = part[- 3 * ((len(part) - 1) // 3):]\n", 147 | " if part:\n", 148 | " result.append(part)\n", 149 | " else:\n", 150 | " result.append(part)\n", 151 | " return \" \".join(result)\n", 152 | "\n", 153 | "\n", 154 | "def construct_prompt(text):\n", 155 | " result = \"\"\n", 156 | " etid = 0\n", 157 | " token_to_add = \"\"\n", 158 | " for token in tokenize(text) + [\"\"]:\n", 159 | " if not re.search(\"[a-zA-Z\\d]\", token):\n", 160 | " if token_to_add:\n", 161 | " end_match = re.search(r\"(.+?)(\\W*)$\", token_to_add, re.M).groups()\n", 162 | " result += f\"[{strip_numbers(end_match[0])}]{end_match[1]}\"\n", 163 | " etid += 1\n", 164 | " token_to_add = \"\"\n", 165 | " result += token\n", 166 | " else:\n", 167 | " token_to_add += token\n", 168 | " return result\n", 169 | "\n", 170 | "\n", 171 | "construct_prompt('я купил iphone 12X за 142 990 руб без 3-x часов 12:00, и т.д.')" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "1a10cf6b-69a0-480b-8b95-5bdfffca669e", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "def construct_answer(prompt:str, prediction:str) -> str:\n", 182 | " replaces = []\n", 183 | " re_prompt = re.compile(r\"\\[([^\\]]+)\\]\")\n", 184 | " re_pred = re.compile(r\"\\(.+?)(?=\\|)\")\n", 185 | " pred_data = {}\n", 186 | " for match in re.finditer(re_pred, prediction.replace(\"\\n\", \" \")):\n", 187 | " pred_data[match[1]] = match[2].strip()\n", 188 | " while match := re.search(re_prompt, prompt):\n", 189 | " replace = pred_data.get(match[2], match[1])\n", 190 | " prompt = prompt[:match.span()[0]] + replace + prompt[match.span()[1]:]\n", 191 | " return prompt.replace(\"\", \"\")\n", 192 | " \n", 193 | "construct_answer(\n", 194 | " 'Было у отца [3] сына. Старшему было [35], среднему - не меньше [33], а младший на [4] младше всех. 
Бывает.',\n", 195 | " \"\"\" три\n", 196 | " тридцать пять\n", 197 | " тридцати трех\n", 198 | " четыре\n", 199 | "\"\"\"\n", 200 | ")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "6e0af086-7e52-4ecc-8535-c99a4a7936e8", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "def norm(message, history):\n", 211 | " prompt = construct_prompt(message)\n", 212 | " yield f\"```Prompt:\\n{prompt}\\nPrediction:\\n...```\\n...\"\n", 213 | " prediction = predict(prompt)\n", 214 | " answer = construct_answer(prompt, prediction)\n", 215 | " yield f\"Prompt:\\n```{prompt}```\\nPrediction:\\n```\\n{prediction}\\n```\\n{answer}\"\n", 216 | "\n", 217 | "\n", 218 | "demo = gr.ChatInterface(norm, stop_btn=None, examples=list(test_examples.keys())).queue()\n", 219 | "demo.launch(inline=False, server_name=\"0.0.0.0\", server_port=GRADIO_PORT, inbrowser=True)\n", 220 | "IFrame(src=f\"http://{HOST_IP}:{GRADIO_PORT}\", width='100%', height='500px')" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "4cbe5608-2b20-44bc-8599-4f36c24b7cdd", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# found bad results with batch generation on encoder-decoder architectures surprisingly so one by one here\n", 231 | "for lm_text, gt in test_examples.items():\n", 232 | " prompt = construct_prompt(lm_text)\n", 233 | " prediction = predict(prompt)\n", 234 | " answer = construct_answer(prompt, prediction)\n", 235 | " if gt == answer:\n", 236 | " print(f\"{gt}\\n\")\n", 237 | " else:\n", 238 | " print(f\"{lm_text}\\n{prompt}\\n{gt}\\n{answer}\\n{prediction}\\n\")" 239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3 (ipykernel)", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.10.13" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 5 263 | } 264 | -------------------------------------------------------------------------------- /work/replaces.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from difflib import SequenceMatcher 3 | import re 4 | 5 | 6 | class Replace(dict): 7 | """ 8 | `type` is "E" for equal if `text_to` is None or is the same with `text_from`. 9 | Otherwise `type` is "R" for replace. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | text_from: str, text_to: Optional[str]=None, 15 | *args, **kwargs 16 | ): 17 | """ 18 | If `text_to` is None the Replace is supposed to be equal. 
19 | """ 20 | super().__init__(*args, **kwargs) 21 | self["text_from"] = text_from 22 | self["text_to"] = text_to 23 | 24 | @property 25 | def type(self): 26 | return "E" if (self.text_from == self.text_to or self["text_to"] is None) else "R" 27 | 28 | @property 29 | def text_from(self): 30 | return self["text_from"] 31 | @text_from.setter 32 | def text_from(self, value): 33 | self["text_from"] = value 34 | 35 | @property 36 | def text_to(self): 37 | return self["text_to"] if self["text_to"] is not None else self.text_from 38 | @text_to.setter 39 | def text_to(self, value): 40 | self["text_to"] = value 41 | 42 | def extend(self, r): 43 | if self.type != r.type: 44 | raise Exception("Replace type mismatch") 45 | self.text_from += r.text_from 46 | self.text_to += r.text_to 47 | return self 48 | 49 | 50 | class Replaces(list): 51 | __re_digits_latin = re.compile(r"[a-zA-Z\d]") 52 | 53 | def __init__(self, *args): 54 | super().__init__(*args) 55 | for i, elem in enumerate(self): 56 | if not isinstance(elem, Replace): 57 | if isinstance(elem, dict): 58 | self[i] = Replace(**elem) 59 | else: 60 | self[i] = Replace(elem) 61 | 62 | def add(self, r: Replace): 63 | if self and r.type == self[-1].type: 64 | self[-1].extend(r) 65 | else: 66 | return super().append(r) 67 | 68 | def __repr__(self): 69 | return "\n".join((f'{r.type}|{r.text_from}{" => " + r.text_to if r.type != "E" else ""}' for r in self)) 70 | 71 | @staticmethod 72 | def from_sequences(seq1, seq2, ingore_not_digit_latin=True): 73 | """ 74 | If `ingore_not_digit_latin` element pairs containing no digits and latin would be treated as equal to `seq1` element. 75 | """ 76 | sm = SequenceMatcher( 77 | # lambda x: not re.search(r"\w", x.strip()), 78 | a=seq1, 79 | b=seq2, 80 | autojunk=False 81 | ) 82 | result = Replaces() 83 | for tag, i1, i2, j1, j2 in sm.get_opcodes(): 84 | text_from, text_to = "".join(seq1[i1:i2]), "".join(seq2[j1:j2]) 85 | if tag == "equal": 86 | pass 87 | elif tag == "replace" and "".join((_.strip() for _ in seq1[i1:i2])) == "".join((_.strip() for _ in seq2[j1:j2])): 88 | text_to = None 89 | elif ingore_not_digit_latin and \ 90 | not re.search(Replaces.__re_digits_latin, text_from) and \ 91 | not re.search(Replaces.__re_digits_latin, text_to): 92 | text_to = None 93 | result.add(Replace(text_from, text_to)) 94 | return result 95 | -------------------------------------------------------------------------------- /work/tb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ee65c542-36c9-4e14-97f2-620f61a70a14", 7 | "metadata": { 8 | "scrolled": true 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install tensorboardX tensorboard" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "b4899e1b-3db3-4918-808c-4549e46c743d", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%load_ext tensorboard" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "b3cf987b-0734-4dae-b9c8-b308b53234ec", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%tensorboard --logdir /home/jovyan/models/8_ruT5-base/runs --host 0.0.0.0 --port 8007" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "c4a206b3-d761-4151-82ab-815172165026", 38 | "metadata": {}, 39 | "source": [ 40 | "In order to delete stuck and unaccessable tensorboard instances which hold ports one is possible to observe not only the 
corresponding list via API call but containment of its folder as well so that extra instances info could be removed manually. " 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "f13097f2-3110-4e6c-9da3-4e07a57c94e4", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from tensorboard import notebook\n", 51 | "notebook.list()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "7d294816-cfbb-4a7d-8e95-226b4bcfab74", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "!ls /tmp/.tensorboard-info" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "8c8e3ba8-5ee2-4557-92fd-ed173fcfe51a", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "!rm /tmp/.tensorboard-info/pid-243.info" 72 | ] 73 | } 74 | ], 75 | "metadata": { 76 | "kernelspec": { 77 | "display_name": "Python 3 (ipykernel)", 78 | "language": "python", 79 | "name": "python3" 80 | }, 81 | "language_info": { 82 | "codemirror_mode": { 83 | "name": "ipython", 84 | "version": 3 85 | }, 86 | "file_extension": ".py", 87 | "mimetype": "text/x-python", 88 | "name": "python", 89 | "nbconvert_exporter": "python", 90 | "pygments_lexer": "ipython3", 91 | "version": "3.10.13" 92 | } 93 | }, 94 | "nbformat": 4, 95 | "nbformat_minor": 5 96 | } 97 | -------------------------------------------------------------------------------- /work/train/readme.md: -------------------------------------------------------------------------------- 1 | # Single GPU train [.ipynb](./work/train/train.ipynb) 2 | 3 | First model has been trained on singe GPU. 4 | 5 | I personally have four RTX 3060 12Gb. 6 | [FRED-T5-1.7B](https://huggingface.co/ai-forever/FRED-T5-1.7B) does not fit into a single GPU. 7 | Surprisingly (not), if not restricted to single GPU, [FRED-T5-large](https://huggingface.co/ai-forever/FRED-T5-large) training causes CUDA OOM. 8 | Seems to be some additional, like, occupation of gradient data leaking from all the GPUs to the main one. 9 | So I have done that `import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0"` thing and used **FRED-T5-large**. 10 | 11 | According to the [memory calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage), I need these pieces of amount of GPU RAM. 12 | 13 | | Model | Train RAM (f32) | Train RAM (f16) | Inference RAM (f32) | Inference RAM (f16) | Inference RAM (int8) | Inference RAM (int4) | 14 | |:--------------|:---------------:|:---------------:|:-------------------:|:-------------------:|:--------------------:|:--------------------:| 15 | | FRED-T5-1.7B | 24.78 GB | 12.39 GB | 6.2 GB | 3.1 GB | 1.55 GB | 792.98 MB | 16 | | FRED-T5-large | 11.46 GB | 5.73 GB | 2.86 GB | 1.43 GB | 733.3 MB | 366.65 MB | 17 | | [ruT5-large](https://huggingface.co/ai-forever/ruT5-large) | 10.99 GB | 5.5 GB | 2.75 GB | 1.37 GB | 703.5 MB | 351.75 MB | 18 | | [ruT5-base](https://huggingface.co/ai-forever/ruT5-base) | 3.32 GB | 1.66 GB | 850.31 MB | 425.15 MB | 212.58 MB | 106.29 MB | 19 | 20 | # Two-GPU distributed train [.ipynb](./work/train/train.ipynb) 21 | 22 | **transformers** itself suggests [three](https://huggingface.co/docs/transformers/perf_train_gpu_many) options to train on several GPU as a model doesn't fit into a single one. 23 | I chose TensorParallel as I have found a good (but a bit obsolete) package for that :) 24 | 25 | The second model has been trained on two GPUs with [`tensor_parallel`](https://github.com/BlackSamorez/tensor_parallel) package. 
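A condensed sketch of that wrapping step (model path and device list as used elsewhere in this repo; the distributed-training notebook has the full runnable version):

```python
import torch
import tensor_parallel as tp
from transformers import T5ForConditionalGeneration

# load once on CPU, then shard the weights across both visible GPUs
model = T5ForConditionalGeneration.from_pretrained(
    "/home/jovyan/wdc1/models/FRED-T5-1.7B", torch_dtype=torch.bfloat16
)
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])
```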
26 | I suppose I could use 3+ GPUs but there was `Bus error (core dumped)` error (not caused by the library as vanilla train of a _small_ model do cause it as well). 27 | > I've managed to fix it then by increasing SHM size up to 8 GB as suggested [here](https://github.com/pytorch/pytorch/issues/2244). 28 | -------------------------------------------------------------------------------- /work/train/train-distributed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f4b1b475-1644-468a-afab-95306a4d69df", 6 | "metadata": {}, 7 | "source": [ 8 | "It is the time to thain something finally.\n", 9 | "\n", 10 | "Based on [translation.ipynb](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/translation.ipynb) and [fred-t5 finetune repo](https://github.com/Den4ikAI/FRED-T5-Finetuning). Modified as in [`tensor_parallel` example](https://github.com/BlackSamorez/tensor_parallel/blob/main/examples/training_flan-t5-xl.ipynb).\n", 11 | "\n", 12 | "~~I use two of my 4x RTX3060 12GB rig as use of 3+ GPUs cause `Bus error (core dumped)` error. One is necessary to restart the jupyterlab docker container then in order to recover it.~~ Fixed it by [increasing](https://github.com/pytorch/pytorch/issues/2244) shared memoty container size.\n", 13 | "\n", 14 | "> It is possible to use `\"cuda:3\"` device for a single gpu but `\"cuda:2,3\"` seems to be not supported by 🤗 thansformers." 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "db03144e-02ba-427d-bb11-b2f9e759437e", 20 | "metadata": {}, 21 | "source": [ 22 | "`tensor_parallel` does not work with modern versions of transformers (despite its official requirements) so I had to downgrade it manually.\n", 23 | "```\n", 24 | "!pip install tensor_parallel\n", 25 | "!pip install transformers==4.29.2\n", 26 | "```" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "34cf75a0-c419-447a-ae82-7d4d38821c0d", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "b3edf4a5-488c-494c-b520-acd70b9028e3", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "from difflib import SequenceMatcher\n", 48 | "import re\n", 49 | "import json\n", 50 | "from tqdm.notebook import tqdm\n", 51 | "import random" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "57b44495-e1b9-445a-b1c5-81436dee6820", 58 | "metadata": { 59 | "scrolled": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "# !pip install datasets transformers[torch]" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "a74c47bb-4eba-473b-b609-600e2b54bc61", 69 | "metadata": {}, 70 | "source": [ 71 | "I use a part of the data I have as the model trains too long otherwise.\n", 72 | "8-12 hours of finetuning was just fine for my usual task so I prefer to hold on to this here." 
73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "ce617a23-6133-47e7-93fe-7e2f74c943d2", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "FILES = [\n", 83 | " \"/home/jovyan/data/kaggle.jsonl\",\n", 84 | " \"/home/jovyan/data/ficbook_replaces.jsonl\",\n", 85 | " \"/home/jovyan/data/pikabu_replaces.jsonl\",\n", 86 | " # \"/home/jovyan/data/librusec_replaces.jsonl\"\n", 87 | "]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "9ca096ee-0ec4-4b29-b4d3-c3a8e21db0c8", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# MODEL_PATH = \"/home/jovyan/wdc1/models/FRED-T5-1.7B\"\n", 98 | "# MODEL_PATH = \"/home/jovyan/wdc1/models/FRED-T5-large\"\n", 99 | "MODEL_PATH = \"/home/jovyan/wdc1/models/ruT5-base\"\n", 100 | "# MODEL_PATH = \"/home/jovyan/models/3_fred-t5/checkpoint-11000\"\n", 101 | "\n", 102 | "TRAINED_SAVE_PATH = \"/home/jovyan/models/8_ruT5-base\"" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "f30e7f0b-c5c3-49d9-83c3-51874e5d61ca", 108 | "metadata": {}, 109 | "source": [ 110 | "In case of `ruT5-base` training do\n", 111 | "\n", 112 | "```python\n", 113 | "!pip install datasets transformers[sentencepiece]\n", 114 | "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", 115 | "path = \"./ruT5-base\"\n", 116 | "model = T5ForConditionalGeneration.from_pretrained(path)\n", 117 | "tokenizer = T5Tokenizer.from_pretrained(path)\n", 118 | "```" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "1c30a2cb-5b29-42c9-9153-25f97f693b44", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# !pip install transformers[sentencepiece]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "c38bd9ac-f71d-4592-a6f2-fb8d50932f93", 134 | "metadata": {}, 135 | "source": [ 136 | "# He obtayn" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "18dcf538-7d11-49f4-a98b-d7732d9bb0d8", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "dataset = []\n", 147 | "for file in FILES:\n", 148 | " with open(file) as f:\n", 149 | " for line in tqdm(f, desc=file):\n", 150 | " dataset.append({\"replaces\": json.loads(line)[\"replaces\"]})" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "68172e55-55b3-4288-9404-646a5a730c23", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "dataset[1000]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "1b61374b-eb2c-489e-91e2-2c3a1ece0a18", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "from transformers import GPT2Tokenizer, T5Tokenizer\n", 171 | "\n", 172 | "\n", 173 | "tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)\n", 174 | "# tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH, eos_token='')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "c3677846-72b3-47c2-b92c-a1a7a146006b", 180 | "metadata": {}, 181 | "source": [ 182 | "One problem about the last train iteration was deluded prediction of long numbers like `125678`.\n", 183 | "It could possibly happen because of tokenization of numbers if divided on parts which are not easy to operate.\n", 184 | "Lets check it out now." 
185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "976a2e95-d83d-4ea4-93e4-af39707368c0", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "good, wrong = [], []\n", 195 | "for i in range(100, 1000):\n", 196 | " a = str(i)\n", 197 | " ids = tokenizer.encode(a, add_special_tokens=False)\n", 198 | " b = \"|\".join([tokenizer.decode(_, skip_special_tokens=True) for _ in ids])\n", 199 | " (wrong if a != b else good).append(b)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "bb03ffd0-915c-43cf-ae1a-719d2f7287f2", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "len(good), len(wrong)\n", 210 | "# (28, 872)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "id": "8c25ee62-2672-4b6b-862f-de792fa2063f", 216 | "metadata": {}, 217 | "source": [ 218 | "The particular `FRED-T5-large` tokenizer splitted the majority of the three digits numbers.\n", 219 | "May be it would be better if numbers are forced splitted on single digits like `123456` to `1 2 3 4 5 6`.\n", 220 | "\n", 221 | "Other option is to divide numbers by three digit groups such that `1234567` would turn into `1 234 567`. We try that option first." 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "94d00c8d-463b-46f3-a447-c5a26036885b", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "def strip_numbers(s):\n", 232 | " return \" \".join(((\" \".join(part) if part.isdigit() else part) for part in s.split()))\n", 233 | "\n", 234 | "\n", 235 | "def strip_numbers(s):\n", 236 | " result = []\n", 237 | " for part in s.split():\n", 238 | " if part.isdigit():\n", 239 | " while len(part) > 3:\n", 240 | " result.append(part[:- 3 * ((len(part) - 1) // 3)])\n", 241 | " part = part[- 3 * ((len(part) - 1) // 3):]\n", 242 | " if part:\n", 243 | " result.append(part)\n", 244 | " else:\n", 245 | " result.append(part)\n", 246 | " return \" \".join(result)\n", 247 | "\n", 248 | "\n", 249 | "strip_numbers(\"у нас было 1234567890 пакетиков травы, 750 ампул новокаина, 55555 пакетиков диэтиламида лизергиновой кислоты, солонка, на 1000/2000 наполненная кокаином\")\n", 250 | "# \"у нас было 1 234 567 890 пакетиков травы, 750 ампул новокаина, 55 555 пакетиков диэтиламида лизергиновой кислоты, солонка, на 1000/2000 наполненная кокаином\"" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "3ab6a775-95df-4077-a81a-d43bb231b8b5", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "import sys\n", 261 | "import os\n", 262 | "sys.path.append(os.path.join(os.path.dirname(os.getcwd())))\n", 263 | "from replaces import Replace, Replaces" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "33bb3d93-4339-4f06-b41c-5a37ea0dd5f5", 270 | "metadata": { 271 | "scrolled": true 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "from collections import Counter\n", 276 | "from itertools import chain\n", 277 | "data = []\n", 278 | "# occ_limit = len(dataset) / 100 # rough trim here\n", 279 | "added = Counter()\n", 280 | "for elem in tqdm(dataset):\n", 281 | " elem[\"replaces\"] = Replaces(elem[\"replaces\"]) # recover class stuff\n", 282 | " if all(_.type == \"E\" for _ in elem[\"replaces\"]):\n", 283 | " continue\n", 284 | " if \"prompt\" in elem and \"target\" in elem:\n", 285 | " continue\n", 286 | " replace_words = list(chain(*(r.text_to.strip().lower().split() for r in 
elem[\"replaces\"] if r.type != \"E\")))\n", 287 | " # if not any(added[word] < occ_limit for word in replace_words):\n", 288 | " # continue\n", 289 | " added.update(replace_words)\n", 290 | " prompt, target = \"\", \"\"\n", 291 | " etid = 0\n", 292 | " for r in elem[\"replaces\"]:\n", 293 | " if r.type == \"E\":\n", 294 | " prompt += r.text_to\n", 295 | " else:\n", 296 | " ws_number = len(r.text_from) - len(r.text_from.rstrip())\n", 297 | " prompt += f\"[{strip_numbers(r.text_from.rstrip())}]{' ' * ws_number}\"\n", 298 | " target += f\" {r.text_to.strip()}\\n\"\n", 299 | " etid += 1\n", 300 | " elem[\"prompt\"] = f\"{prompt}\"\n", 301 | " elem[\"target\"] = f\"{target}\"\n", 302 | " data.append(elem)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "60e309d4-d538-40a3-af7e-9e2bb1202850", 308 | "metadata": {}, 309 | "source": [ 310 | "We made here train examples of that kind\n", 311 | "\n", 312 | " Временами я думаю, какое применение найти тем [14 697] рублям, что лежат уже больше [33] лет?\n", 313 | "\n", 314 | "and we want to predict a text like this\n", 315 | "\n", 316 | " четырнадцати тысячам шестистам девяноста семи\n", 317 | " тридцати трёх \n", 318 | "\n", 319 | "Lets check what have we added so far like the most (un)common __words__." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "id": "dbde1073-7c33-4537-b0cb-587ed17de332", 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "added.most_common()[:10], added.most_common()[-10:], data[0]" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "id": "fadf4e99-c7ca-46d8-b86c-8ee8a67a5938", 335 | "metadata": {}, 336 | "source": [ 337 | "Besides rare mistakes it seems to be trainable on.\n", 338 | "\n", 339 | "The distribution is shifted anyway to my taste as will be shown later.\n", 340 | "One fast and simple thing to do about it is to iterate over and filter examples as we have seen too much of **all** the replaced words at the moment." 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "851a6301-7160-421d-bc04-174abfc3a5ac", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "occ_limit = (sum(added.values()) / len(added)) ** 2\n", 351 | "print(occ_limit)\n", 352 | "added2 = Counter()\n", 353 | "balanced_data = []\n", 354 | "for elem in tqdm(data):\n", 355 | " replace_words = list(chain(*[r.text_to.strip().lower().split() for r in elem[\"replaces\"] if r.type != \"E\"]))\n", 356 | " if any((added2[word] < occ_limit for word in replace_words)):\n", 357 | " balanced_data.append(elem)\n", 358 | " added2.update(replace_words)\n", 359 | "len(balanced_data), len(balanced_data) / len(data)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "id": "ef0b9beb-a30e-4b3a-8014-dde28be310f0", 365 | "metadata": {}, 366 | "source": [ 367 | "We have gotten rid of 2/3 of the data we had had!\n", 368 | "Check it out visually now." 
369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "fe982700-1d5c-458c-9a60-4ce1f85889ae", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "stat_regs = {\n", 379 | " \"re_digits\": re.compile(r\"\\d\"),\n", 380 | " \"re_digits_latin\": re.compile(r\"[a-zA-Z\\d]\"),\n", 381 | " \"re_latin\": re.compile(r\"[a-zA-Z]\")\n", 382 | "}\n", 383 | "for stat_name, stat_re in stat_regs.items():\n", 384 | " print(\n", 385 | " stat_name,\n", 386 | " len([elem for elem in tqdm(balanced_data) if any(re.search(stat_re, r.text_from) for r in elem[\"replaces\"])])\n", 387 | " )" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "f52f0fe3-b7b0-444c-b124-d3e7a9a86257", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "from matplotlib import pyplot as plt \n", 398 | "\n", 399 | "axs = plt.subplot()\n", 400 | "axs.set_yscale('log')\n", 401 | "axs.plot([_[1] for _ in added2.most_common()[:1000]])\n", 402 | "axs.plot([_[1] for _ in added.most_common()[:1000]])" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "c8e31a12-47a8-4273-9d34-e50074d5b8fe", 408 | "metadata": {}, 409 | "source": [ 410 | "Only extra data lost so far." 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "id": "df676f67-1c4d-441d-9072-c30d44c9c9c2", 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "del data" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "id": "ffb88cae-2d41-4d23-989f-15fd8afc6ed2", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "from datasets import Dataset\n", 431 | "\n", 432 | "\n", 433 | "dataset = Dataset.from_list(balanced_data).train_test_split(test_size=0.01)\n", 434 | "dataset" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "0b6a8d28-64e2-4758-8b7f-e930151d2bcd", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "def preprocess_function(examples):\n", 445 | " model_inputs = tokenizer(\n", 446 | " examples[\"prompt\"],\n", 447 | " text_target=examples[\"target\"],\n", 448 | " max_length=128, # NB should affect memory consumption\n", 449 | " truncation=True\n", 450 | " )\n", 451 | " return model_inputs\n", 452 | "\n", 453 | "\n", 454 | "dataset = dataset.map(preprocess_function, batched=True, num_proc=10)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "id": "f9d9a47b-1042-4334-9692-2af334eb1265", 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "dataset = dataset.remove_columns([\"prompt\", \"target\", \"replaces\"])" 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "id": "9c0abc0d-8a6e-4ff7-a45b-2da3151548c5", 470 | "metadata": {}, 471 | "source": [ 472 | "Just in case I get rid of examples with possible truncation mistakes." 
473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "id": "9d66e513-dfa3-4f93-b798-6738dae70584", 479 | "metadata": { 480 | "scrolled": true 481 | }, 482 | "outputs": [], 483 | "source": [ 484 | "from collections import Counter\n", 485 | "c = Counter([len(_[\"input_ids\"]) for _ in dataset[\"train\"]])\n", 486 | "sum([v for k, v in c.items() if k < 128]), c" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "id": "6917cfec-1674-4658-aac6-c4f03df8705b", 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "for k, v in dataset.items():\n", 497 | " dataset[k] = [_ for _ in v if 10 < len(_[\"input_ids\"]) < 126]\n", 498 | "{k:len(v) for k, v in dataset.items()}" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "id": "669201f4-ad10-4c6b-8eed-14e712d3683a", 504 | "metadata": {}, 505 | "source": [ 506 | "# He trayn\n", 507 | "\n", 508 | "Time to train actually as last!" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "id": "296c31ca-1755-4033-a45b-f2443d14ba7f", 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "from transformers import T5ForConditionalGeneration\n", 519 | "import torch\n", 520 | "\n", 521 | "\n", 522 | "model = T5ForConditionalGeneration.from_pretrained(\n", 523 | " MODEL_PATH,\n", 524 | " torch_dtype=torch.bfloat16,\n", 525 | " low_cpu_mem_usage=True,\n", 526 | " offload_state_dict=True\n", 527 | ")" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "id": "949871ea-8c32-4304-9c8f-36716766590a", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "import tensor_parallel as tp\n", 538 | "\n", 539 | "\n", 540 | "model = tp.tensor_parallel(\n", 541 | " model,\n", 542 | " [\"cuda:0\", \"cuda:1\"]\n", 543 | ")" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "id": "23105760-1e79-4641-8d75-e017719ef647", 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "input_ids = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\").input_ids.to(\"cuda\")\n", 554 | "output_ids = tokenizer(\"A cat sat did not sit on a mat\", return_tensors=\"pt\").input_ids.to(\"cuda\")\n", 555 | "loss = model(input_ids=input_ids, labels=output_ids).loss\n", 556 | "loss.backward() # check nvidia-smi for gpu memory usage :)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "id": "616ab02c-e7a6-42c1-8274-eca363083ebd", 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "# !pip install bitsandbytes scipy" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "id": "e12cc228-f9db-441e-b406-62c7dfed0390", 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq\n", 577 | "\n", 578 | "\n", 579 | "training_args = Seq2SeqTrainingArguments(\n", 580 | " output_dir=TRAINED_SAVE_PATH,\n", 581 | " # optim=\"adamw_bnb_8bit\",\n", 582 | " optim=\"adafactor\",\n", 583 | " evaluation_strategy=\"steps\",\n", 584 | " eval_steps=1000,\n", 585 | " save_steps=1000,\n", 586 | " logging_first_step=True,\n", 587 | " learning_rate=1e-4,\n", 588 | " lr_scheduler_type=\"constant\",\n", 589 | " # gradient_checkpointing=True,\n", 590 | " gradient_checkpointing=False,\n", 591 | " gradient_accumulation_steps=1,\n", 592 | " per_device_train_batch_size=32,\n", 593 | " 
per_device_eval_batch_size=32,\n", 594 | " save_total_limit=10,\n", 595 | " num_train_epochs=2,\n", 596 | " # predict_with_generate=True,\n", 597 | " # fp16=True,\n", 598 | " push_to_hub=False,\n", 599 | " remove_unused_columns=False,\n", 600 | " load_best_model_at_end=True,\n", 601 | " # auto_find_batch_size=True,\n", 602 | " auto_find_batch_size=False,\n", 603 | " dataloader_num_workers=4,\n", 604 | ")" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": null, 610 | "id": "6b445baa-b12a-48f1-9dba-2601b8175c3c", 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "import transformers\n", 615 | "transformers.logging.set_verbosity_info()" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": null, 621 | "id": "4642ba79-cca3-4ddc-8b5e-6255931af701", 622 | "metadata": {}, 623 | "outputs": [], 624 | "source": [ 625 | "with tp.save_tensor_parallel(model):\n", 626 | " trainer = Seq2SeqTrainer(\n", 627 | " model=model,\n", 628 | " args=training_args,\n", 629 | " train_dataset=dataset[\"train\"],\n", 630 | " eval_dataset=dataset[\"test\"],\n", 631 | " tokenizer=tokenizer,\n", 632 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model),\n", 633 | " # compute_metrics=compute_metrics,\n", 634 | " # optimizers=(adam_bnb_optim, None),\n", 635 | " )\n", 636 | " trainer.train()" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": null, 642 | "id": "c5e2fe2b-1a58-48c3-a7be-37675fb723c9", 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "with tp.save_tensor_parallel(model):\n", 647 | " model.save_pretrained(os.path.join(TRAINED_SAVE_PATH, \"final\"))\n", 648 | " tokenizer.save_pretrained(os.path.join(TRAINED_SAVE_PATH, \"final\"))" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "id": "b447b82f-5214-41b7-85ec-2e23271ddce6", 654 | "metadata": {}, 655 | "source": [ 656 | "# But most importantly he explayn" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "id": "011a1834-3688-47f2-9519-2158f9c930e1", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "import torch\n", 667 | "lm_text = 'я купил [iphone 12X] за [142 990] руб без [3-x] часов полдень и т.д.'\n", 668 | "# lm_text = 'я купил айфон за [14 970] рублей'\n", 669 | "lm_text = \"Временами я думаю, какое применение найти тем [14 697] рублям, что лежат уже больше [33] лет?\"\n", 670 | "# lm_text = \"Было у отца [3] сына, но не было даже [2-3] пиджаков с блёстками за [142 990 руб].\"\n", 671 | "# lm_text = \"В школе у меня одни [5].\"\n", 672 | "# lm_text = 'Было у отца [3] сына. Старшему было [35], среднему - не меньше [33], а младший на [4] младше всех. 
Бывает.'\n", 673 | "lm_text = \"Временами я думаю, какое применение найти тем [265 948 697] рублям, что лежат уже больше [33] лет?\"\n", 674 | "input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(\"cuda:0\")\n", 675 | "outputs = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id, early_stopping=True)\n", 676 | "print(tokenizer.decode(outputs[0][1:]))" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": null, 682 | "id": "a30d7a17-ee66-4e45-844b-07b7cae253b5", 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "!nvidia-smi" 687 | ] 688 | } 689 | ], 690 | "metadata": { 691 | "kernelspec": { 692 | "display_name": "Python 3 (ipykernel)", 693 | "language": "python", 694 | "name": "python3" 695 | }, 696 | "language_info": { 697 | "codemirror_mode": { 698 | "name": "ipython", 699 | "version": 3 700 | }, 701 | "file_extension": ".py", 702 | "mimetype": "text/x-python", 703 | "name": "python", 704 | "nbconvert_exporter": "python", 705 | "pygments_lexer": "ipython3", 706 | "version": "3.10.13" 707 | } 708 | }, 709 | "nbformat": 4, 710 | "nbformat_minor": 5 711 | } 712 | -------------------------------------------------------------------------------- /work/train/train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f4b1b475-1644-468a-afab-95306a4d69df", 6 | "metadata": {}, 7 | "source": [ 8 | "It is the time to thain something finally.\n", 9 | "\n", 10 | "Based on [translation.ipynb](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/translation.ipynb) and [fred-t5 finetune repo](https://github.com/Den4ikAI/FRED-T5-Finetuning)." 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "f9674bb5-a85e-498b-82b3-68da63a67ce8", 16 | "metadata": {}, 17 | "source": [ 18 | "I use a single RTX3060 12GB as naive use of 2+ GPUs cause OOM in case of `FRED-T5-large`.\n", 19 | "`ruT5-base` training is possible with `CUDA_VISIBLE_DEVICES=2,3` out of the box though." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "34cf75a0-c419-447a-ae82-7d4d38821c0d", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2\"\n", 31 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "b3edf4a5-488c-494c-b520-acd70b9028e3", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from difflib import SequenceMatcher\n", 42 | "import re\n", 43 | "import json\n", 44 | "from tqdm.notebook import tqdm\n", 45 | "import random" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "a74c47bb-4eba-473b-b609-600e2b54bc61", 51 | "metadata": {}, 52 | "source": [ 53 | "I use a part of the data I have only as the model trains too long otherwise.\n", 54 | "8-12 hours of finetuning was just fine for my usual task so I prefer to hold on to this here." 
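If the data still turns out to be too large, one simple knob (beyond commenting out a whole file, as done below) is to cap how many lines are read per file. A minimal sketch with a made-up limit, not part of the original notebook:

```python
import json
from itertools import islice

MAX_LINES_PER_FILE = 200_000  # hypothetical cap, tune to the training time budget

def read_jsonl_head(path, limit=MAX_LINES_PER_FILE):
    """Read at most `limit` JSONL records from the beginning of `path`."""
    with open(path) as f:
        return [json.loads(line) for line in islice(f, limit)]
```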
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "ce617a23-6133-47e7-93fe-7e2f74c943d2",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "DATASET_FILES = [\n",
65 | "    \"/home/jovyan/data/kaggle.jsonl\",\n",
66 | "    \"/home/jovyan/data/ficbook_replaces.jsonl\",\n",
67 | "    \"/home/jovyan/data/pikabu_replaces.jsonl\",\n",
68 | "    # \"/home/jovyan/data/librusec_replaces.jsonl\"\n",
69 | "]"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "9ca096ee-0ec4-4b29-b4d3-c3a8e21db0c8",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "MODEL = {\n",
80 | "    0: {\n",
81 | "        \"type\": \"FRED-T5\",\n",
82 | "        \"path\": \"/home/jovyan/wdc1/models/FRED-T5-large\",\n",
83 | "        # \"path\": \"/home/jovyan/wdc1/models/FRED-T5-1.7B\"\n",
84 | "        # \"path\": \"/home/jovyan/models/3_fred-t5/checkpoint-11000\"\n",
85 | "    },\n",
86 | "    1: {\n",
87 | "        \"type\": \"ruT5\",\n",
88 | "        \"path\": \"/home/jovyan/wdc1/models/ruT5-base\",\n",
89 | "    },\n",
90 | "}[0]\n",
91 | "MODEL_PATH = MODEL[\"path\"]  # used by the tokenizer and model cells below\nTRAINED_SAVE_PATH = \"/home/jovyan/models/7_fred-t5-large\""
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "id": "c38bd9ac-f71d-4592-a6f2-fb8d50932f93",
97 | "metadata": {},
98 | "source": [
99 | "# He obtayn"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "id": "e61a5649-c87d-45e0-ac24-c011e6b20cd7",
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "import sys\n",
110 | "import os\n",
111 | "sys.path.append(os.path.join(os.path.dirname(os.getcwd())))\n",
112 | "from replaces import Replace, Replaces"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "id": "e98757f1-8be5-4b6a-a4d6-30709880ea56",
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "dataset = []\n",
123 | "for file in DATASET_FILES:\n",
124 | "    with open(file) as f:\n",
125 | "        for line in tqdm(f, desc=file):\n",
126 | "            dataset.append({\"replaces\": Replaces(json.loads(line)[\"replaces\"])})"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "68172e55-55b3-4288-9404-646a5a730c23",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "dataset[1000]"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "id": "d573fa5c-75c0-433d-97db-bb44c6b9e782",
142 | "metadata": {},
143 | "source": [
144 | "Also, the `ruT5` sentencepiece tokenizer misses the newline `\"\n\"` symbol, so ```\n``` encodes and decodes into ``` extra_id_0 extra_id_1```. Rather than fixing its output afterwards (possible but painful), [one is advised](https://github.com/google/sentencepiece/issues/101) to add the symbol explicitly."
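The round trip described above is easy to check directly. A minimal sketch, assuming a local `ruT5-base` checkpoint at the path from the config cell; the exact tokens printed may differ:

```python
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("/home/jovyan/wdc1/models/ruT5-base")

ids = tok.encode("первая строка\nвторая строка", add_special_tokens=False)
print(repr(tok.decode(ids)))  # the line break does not survive the round trip

tok.add_tokens("\n")          # register the newline explicitly
ids = tok.encode("первая строка\nвторая строка", add_special_tokens=False)
print(repr(tok.decode(ids)))  # the newline is now kept as a token of its own
```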
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "1b61374b-eb2c-489e-91e2-2c3a1ece0a18",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "from transformers import GPT2Tokenizer, T5Tokenizer, AutoTokenizer\n",
155 | "\n",
156 | "\n",
157 | "if MODEL[\"type\"] == \"ruT5\":\n",
158 | "    tokenizer = T5Tokenizer.from_pretrained(MODEL_PATH)\n",
159 | "    tokenizer.add_tokens(\"\\n\")\n",
160 | "elif MODEL[\"type\"] == \"FRED-T5\":\n",
161 | "    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH, eos_token=\"</s>\")\n",
162 | "else:\n",
163 | "    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "id": "c3677846-72b3-47c2-b92c-a1a7a146006b",
169 | "metadata": {},
170 | "source": [
171 | "One problem with the last training iteration was garbled prediction of long numbers like `125678`.\n",
172 | "It could well be caused by tokenization: long numbers are split into pieces that are hard for the model to operate on.\n",
173 | "Let's check that now."
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "id": "976a2e95-d83d-4ea4-93e4-af39707368c0",
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "good, wrong = [], []\n",
184 | "for i in range(100, 1000):\n",
185 | "    a = str(i)\n",
186 | "    ids = tokenizer.encode(a)\n",
187 | "    b = \"|\".join([tokenizer.decode(_) for _ in ids])\n",
188 | "    (wrong if a != b else good).append(b)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "id": "bb03ffd0-915c-43cf-ae1a-719d2f7287f2",
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "len(good), len(wrong)\n",
199 | "# (28, 872) in case of FRED-T5, (19, 881) in case of ruT5"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "id": "8c25ee62-2672-4b6b-862f-de792fa2063f",
205 | "metadata": {},
206 | "source": [
207 | "The particular `FRED-T5-large` tokenizer splits the majority of three-digit numbers into several tokens.\n",
208 | "Maybe it would be better to force-split numbers into single digits, like `123456` into `1 2 3 4 5 6`.\n",
209 | "\n",
210 | "Another option is to divide numbers into three-digit groups, so that `1234567` turns into `1 234 567`. We try that option first."
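A quick way to compare the two forms with the tokenizer loaded above (the trailing comment is only an expectation, not captured output):

```python
for text in ["1234567", "1 234 567"]:
    ids = tokenizer.encode(text, add_special_tokens=False)
    print(text, "->", [tokenizer.decode([i]) for i in ids])
# the grouped form should break into short one-to-three digit pieces,
# which are easier for the model to map onto spelled-out numbers
```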
211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "94d00c8d-463b-46f3-a447-c5a26036885b", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "def strip_numbers(s):\n", 221 | " return \" \".join(((\" \".join(part) if part.isdigit() else part) for part in s.split()))\n", 222 | "\n", 223 | "\n", 224 | "def strip_numbers(s):\n", 225 | " result = []\n", 226 | " for part in s.split():\n", 227 | " if part.isdigit():\n", 228 | " while len(part) > 3:\n", 229 | " result.append(part[:- 3 * ((len(part) - 1) // 3)])\n", 230 | " part = part[- 3 * ((len(part) - 1) // 3):]\n", 231 | " if part:\n", 232 | " result.append(part)\n", 233 | " else:\n", 234 | " result.append(part)\n", 235 | " return \" \".join(result)\n", 236 | "\n", 237 | "\n", 238 | "strip_numbers(\"у нас было 1234567890 пакетиков травы, 750 ампул новокаина, 55555 пакетиков диэтиламида лизергиновой кислоты, солонка, на 1000/2000 наполненная кокаином\")\n", 239 | "# \"у нас было 1 234 567 890 пакетиков травы, 750 ампул новокаина, 55 555 пакетиков диэтиламида лизергиновой кислоты, солонка, на 1000/2000 наполненная кокаином\"" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "id": "33bb3d93-4339-4f06-b41c-5a37ea0dd5f5", 246 | "metadata": { 247 | "scrolled": true 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "from collections import Counter\n", 252 | "from itertools import chain\n", 253 | "data = []\n", 254 | "added = Counter()\n", 255 | "for elem in tqdm(dataset):\n", 256 | " if all(_.type == \"E\" for _ in elem[\"replaces\"]):\n", 257 | " continue\n", 258 | " if \"prompt\" in elem and \"target\" in elem:\n", 259 | " continue\n", 260 | " replace_words = list(chain(*(r.text_to.strip().lower().split() for r in elem[\"replaces\"] if r.type != \"E\")))\n", 261 | " added.update(replace_words)\n", 262 | " prompt, target = \"\", \"\"\n", 263 | " etid = 0\n", 264 | " for r in elem[\"replaces\"]:\n", 265 | " if r.type == \"E\":\n", 266 | " prompt += r.text_to\n", 267 | " else:\n", 268 | " ws_number = len(r.text_from) - len(r.text_from.rstrip())\n", 269 | " prompt += f\"[{strip_numbers(r.text_from.rstrip())}]{' ' * ws_number}\"\n", 270 | " target += f\" {r.text_to.strip()} \\n\"\n", 271 | " etid += 1\n", 272 | " elem[\"prompt\"] = f\"{prompt}\"\n", 273 | " elem[\"target\"] = f\"{target}\"\n", 274 | " data.append(elem)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "60e309d4-d538-40a3-af7e-9e2bb1202850", 280 | "metadata": {}, 281 | "source": [ 282 | "We made here train examples of that kind\n", 283 | "\n", 284 | " Временами я думаю, какое применение найти тем [14 697] рублям, что лежат уже больше [33] лет?\n", 285 | "\n", 286 | "and we want to predict a text like this\n", 287 | "\n", 288 | " четырнадцати тысячам шестистам девяноста семи\n", 289 | " тридцати трёх \n", 290 | "\n", 291 | "Lets check what have we added so far like the most (un)common __words__." 
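The loop above only relies on each `Replace` carrying `type`, `text_from`, and `text_to`; the real class lives in `replaces.py` (not shown here), but its shape is roughly this (illustrative sketch, not the actual implementation):

```python
from dataclasses import dataclass

@dataclass
class ReplaceSketch:
    type: str       # "E" for a chunk kept verbatim, another marker for a normalized span
    text_from: str  # numeric/source form, e.g. "14 697 "
    text_to: str    # spelled-out target form, e.g. "четырнадцати тысячам шестистам девяноста семи"

# An example is then a list of such pieces: plain "E" chunks interleaved with
# spans whose text_from goes into the bracketed prompt and text_to into the target.
```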
292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "id": "dbde1073-7c33-4537-b0cb-587ed17de332", 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "added.most_common()[:10], added.most_common()[-10:], data[0]" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "fadf4e99-c7ca-46d8-b86c-8ee8a67a5938", 307 | "metadata": {}, 308 | "source": [ 309 | "Besides rare mistakes it seems to be trainable on.\n", 310 | "\n", 311 | "The distribution is shifted anyway to my taste as will be shown later.\n", 312 | "One fast and simple thing to do about it is to iterate over and filter examples as we have seen too much of **all** the replaced words at the moment." 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "851a6301-7160-421d-bc04-174abfc3a5ac", 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "occ_limit = (sum(added.values()) / len(added)) ** 2 # feel free to find another heuristic\n", 323 | "print(occ_limit)\n", 324 | "added2 = Counter()\n", 325 | "balanced_data = []\n", 326 | "for elem in tqdm(data):\n", 327 | " replace_words = list(chain(*[r.text_to.strip().lower().split() for r in elem[\"replaces\"] if r.type != \"E\"]))\n", 328 | " if any((added2[word] < occ_limit for word in replace_words)):\n", 329 | " balanced_data.append(elem)\n", 330 | " added2.update(replace_words)\n", 331 | "len(balanced_data), len(balanced_data) / len(data)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "id": "ef0b9beb-a30e-4b3a-8014-dde28be310f0", 337 | "metadata": {}, 338 | "source": [ 339 | "We have gotten rid of 2/3 of the data we had had!\n", 340 | "Check it out visually now." 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "f584e0d2-03d9-4972-88a1-afdc58a9a7cb", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "stat_regs = {\n", 351 | " \"re_digits\": re.compile(r\"\\d\"),\n", 352 | " \"re_digits_latin\": re.compile(r\"[a-zA-Z\\d]\"),\n", 353 | " \"re_latin\": re.compile(r\"[a-zA-Z]\")\n", 354 | "}\n", 355 | "for stat_name, stat_re in stat_regs.items():\n", 356 | " print(\n", 357 | " stat_name,\n", 358 | " len([elem for elem in tqdm(balanced_data) if any(re.search(stat_re, r.text_from) for r in elem[\"replaces\"])])\n", 359 | " )" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "id": "f52f0fe3-b7b0-444c-b124-d3e7a9a86257", 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "from matplotlib import pyplot as plt \n", 370 | "\n", 371 | "axs = plt.subplot()\n", 372 | "axs.set_yscale('log')\n", 373 | "axs.plot([_[1] for _ in added2.most_common()[:1000]])\n", 374 | "axs.plot([_[1] for _ in added.most_common()[:1000]])" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "id": "c8e31a12-47a8-4273-9d34-e50074d5b8fe", 380 | "metadata": {}, 381 | "source": [ 382 | "Only extra data lost so far." 
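One note before building the dataset: unlike the distributed variant of this notebook, no `datasets` import is visible here, so the next cell presumably needs:

```python
from datasets import Dataset
```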
383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "id": "ffb88cae-2d41-4d23-989f-15fd8afc6ed2", 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "dataset = Dataset.from_list(balanced_data).train_test_split(test_size=0.01)\n", 393 | "dataset" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "id": "0b6a8d28-64e2-4758-8b7f-e930151d2bcd", 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "def preprocess_function(examples):\n", 404 | " model_inputs = tokenizer(\n", 405 | " examples[\"prompt\"],\n", 406 | " text_target=examples[\"target\"],\n", 407 | " max_length=128, # NB should affect memory consumption\n", 408 | " truncation=True\n", 409 | " )\n", 410 | " return model_inputs\n", 411 | "\n", 412 | "\n", 413 | "dataset = dataset.map(preprocess_function, batched=True, num_proc=10)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "f9d9a47b-1042-4334-9692-2af334eb1265", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "dataset = dataset.remove_columns([\"prompt\", \"target\", \"replaces\"])" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "id": "9c0abc0d-8a6e-4ff7-a45b-2da3151548c5", 429 | "metadata": {}, 430 | "source": [ 431 | "Just in case I get rid of examples with possible truncation mistakes." 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "id": "9d66e513-dfa3-4f93-b798-6738dae70584", 438 | "metadata": { 439 | "scrolled": true 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "from collections import Counter\n", 444 | "c = Counter([len(_[\"input_ids\"]) for _ in dataset[\"train\"]])\n", 445 | "sum([v for k, v in c.items() if k < 128]), c" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "id": "6917cfec-1674-4658-aac6-c4f03df8705b", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "for k, v in dataset.items():\n", 456 | " dataset[k] = [_ for _ in v if 10 < len(_[\"input_ids\"]) < 126]\n", 457 | "{k:len(v) for k, v in dataset.items()}" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "id": "669201f4-ad10-4c6b-8eed-14e712d3683a", 463 | "metadata": {}, 464 | "source": [ 465 | "# He trayn\n", 466 | "\n", 467 | "Time to train actually as last!" 
468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "id": "296c31ca-1755-4033-a45b-f2443d14ba7f", 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "from transformers import T5ForConditionalGeneration\n", 478 | "\n", 479 | "\n", 480 | "model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "id": "e12cc228-f9db-441e-b406-62c7dfed0390", 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [ 490 | "from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq\n", 491 | "\n", 492 | "\n", 493 | "training_args = Seq2SeqTrainingArguments(\n", 494 | " output_dir=TRAINED_SAVE_PATH,\n", 495 | " optim=\"adafactor\",\n", 496 | " evaluation_strategy=\"steps\",\n", 497 | " eval_steps=1000,\n", 498 | " save_steps=1000,\n", 499 | " logging_first_step=True,\n", 500 | " learning_rate=1e-4,\n", 501 | " lr_scheduler_type=\"constant\",\n", 502 | " # gradient_checkpointing=True,\n", 503 | " gradient_checkpointing=False,\n", 504 | " gradient_accumulation_steps=8,\n", 505 | " per_device_train_batch_size=4,\n", 506 | " per_device_eval_batch_size=2,\n", 507 | " save_total_limit=20,\n", 508 | " num_train_epochs=2,\n", 509 | " # predict_with_generate=True,\n", 510 | " # fp16=True,\n", 511 | " push_to_hub=False,\n", 512 | " remove_unused_columns=False,\n", 513 | " load_best_model_at_end=True,\n", 514 | " # auto_find_batch_size=True,\n", 515 | " dataloader_num_workers=4,\n", 516 | " report_to=\"tensorboard\",\n", 517 | ")\n" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "id": "6b445baa-b12a-48f1-9dba-2601b8175c3c", 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "import transformers\n", 528 | "transformers.logging.set_verbosity_info()" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "id": "4642ba79-cca3-4ddc-8b5e-6255931af701", 535 | "metadata": {}, 536 | "outputs": [], 537 | "source": [ 538 | "trainer = Seq2SeqTrainer(\n", 539 | " model=model,\n", 540 | " args=training_args,\n", 541 | " train_dataset=dataset[\"train\"],\n", 542 | " eval_dataset=dataset[\"test\"],\n", 543 | " tokenizer=tokenizer,\n", 544 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)\n", 545 | ")\n", 546 | "trainer.train()" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "id": "c5e2fe2b-1a58-48c3-a7be-37675fb723c9", 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "model.save_pretrained(os.path.join(TRAINED_SAVE_PATH, \"final\"), safe_serialization=False)\n", 557 | "tokenizer.save_pretrained(os.path.join(TRAINED_SAVE_PATH, \"final\"))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "id": "b447b82f-5214-41b7-85ec-2e23271ddce6", 563 | "metadata": {}, 564 | "source": [ 565 | "# But most importantly he explayn" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "id": "011a1834-3688-47f2-9519-2158f9c930e1", 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "import torch\n", 576 | "lm_text = 'я купил [iphone 12X] за [142 990] руб без [3-x] часов полдень и т.д.'\n", 577 | "lm_text = 'я купил айфон за [14 970] рублей'\n", 578 | "lm_text = \"Временами я думаю, какое применение найти тем [14 697] рублям, что лежат уже больше [33] лет?\"\n", 579 | "lm_text = \"Было у отца [3] сына, но не было даже [2-3] 
пиджаков с блёстками за [142 990 руб].\"\n", 580 | "lm_text = \"В школе у меня одни [5].\"\n", 581 | "lm_text = 'Было у отца [3] сына. Старшему было [35], среднему - не меньше [33], а младший на [4] младше всех. Бывает.'\n", 582 | "lm_text = \"Временами я думаю, какое применение найти тем [265 948 697] рублям, что лежат уже больше [33] лет у нашего [DevOps]?\"\n", 583 | "input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(\"cuda:0\")\n", 584 | "outputs = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id, early_stopping=True, max_new_tokens=50)\n", 585 | "print(tokenizer.decode(outputs[0][1:]))" 586 | ] 587 | } 588 | ], 589 | "metadata": { 590 | "kernelspec": { 591 | "display_name": "Python 3 (ipykernel)", 592 | "language": "python", 593 | "name": "python3" 594 | }, 595 | "language_info": { 596 | "codemirror_mode": { 597 | "name": "ipython", 598 | "version": 3 599 | }, 600 | "file_extension": ".py", 601 | "mimetype": "text/x-python", 602 | "name": "python", 603 | "nbconvert_exporter": "python", 604 | "pygments_lexer": "ipython3", 605 | "version": "3.10.13" 606 | } 607 | }, 608 | "nbformat": 4, 609 | "nbformat_minor": 5 610 | } 611 | --------------------------------------------------------------------------------