├── .gitignore ├── .gitmodules ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── data ├── example_data │ ├── RUSpellRU │ │ ├── corrections.txt │ │ └── sources.txt │ ├── bea60k │ │ ├── bea_txt │ │ │ ├── corrections.txt │ │ │ └── sources.txt │ │ ├── subsample │ │ │ ├── corrections.txt │ │ │ └── sources.txt │ │ ├── test.bea60k │ │ └── test.bea60k.noise │ └── jfleg │ │ ├── corrections.txt │ │ └── sources.txt └── sanity_check_samples │ ├── RUSpellRU │ ├── corrections.txt │ └── sources.txt │ ├── corrected_sents.txt │ ├── corruptor_tests │ ├── broken_csv_file_columns │ │ └── data.csv │ ├── broken_csv_file_nans │ │ └── data.csv │ ├── broken_csv_file_opening │ │ └── data.csv │ ├── broken_text_files │ │ ├── corrections.txt │ │ └── sources.txt │ ├── corrections.txt │ ├── csv │ │ └── data.csv │ ├── sources.txt │ └── wrong_names │ │ ├── corrections_.txt │ │ ├── data_.csv │ │ └── sources_.txt │ └── source_sents.txt ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── images │ ├── bea60k_side_by_side.jpg │ ├── benchmark.png │ ├── ruspellru_side_by_side.jpg │ ├── sage-black.svg │ ├── sage-white.svg │ └── sage_banner.jpg │ ├── index.rst │ └── rst │ ├── datasets │ ├── GitHubTypoCorpusRu.rst │ ├── MedSpellchecker.rst │ ├── MultidomainGold.rst │ └── RUSpellRU.rst │ ├── evaluation │ ├── RuErrant.rst │ └── RuSpellEval.rst │ ├── spelling_correction │ ├── FredT5-large.rst │ ├── M2M100-418M.rst │ ├── RuM2M100-1.2B.rst │ ├── T5.rst │ ├── sage-fredt5-distilled-95m.rst │ ├── sage-fredt5-large.rst │ ├── sage-m2m100-1.2B.rst │ └── sage-mt5-large.rst │ └── spelling_corruption │ ├── Augmentex.rst │ └── SBSC.rst ├── images ├── bea60k_side_by_side.jpg ├── ruspellru_side_by_side.jpg ├── sage-black.svg └── sage-white.svg ├── notebooks ├── text_correction_demo.ipynb └── text_corruption_demo.ipynb ├── sage ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── readme.md │ ├── ruerrant_wrapper │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── merger.py │ │ └── scorer.py │ ├── ruspelleval.py │ └── scorer.py ├── spelling_correction │ ├── __init__.py │ ├── corrector.py │ ├── m2m_correctors.py │ └── t5_correctors.py ├── spelling_corruption │ ├── __init__.py │ ├── configuration_corruptor.py │ ├── corruptor.py │ └── sbsc │ │ ├── __init__.py │ │ ├── base_classes.py │ │ ├── labeler.py │ │ ├── model.py │ │ ├── sbsc.py │ │ └── typings_positions_conditions.py └── utils │ ├── __init__.py │ ├── data_load_utils.py │ ├── lang_utils.py │ └── utils.py ├── setup.py ├── tests ├── corruptor_api_unittests.py ├── sbsc_corruptor_unittests.py ├── test_correctors.py ├── test_corruptors.py ├── test_evaluate.py ├── test_metrics.py ├── test_ruspelleval.py ├── test_utils.py ├── tests.py └── tests_english.py └── wheels └── augmentex-1.0.3-py3-none-any.whl /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | __pycache__ 3 | .ipynb_checkpoints 4 | sage.egg-info 5 | build 6 | .DS_Store 7 | .idea 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/.gitmodules -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - 
requirements: docs/requirements.txt 11 | - method: pip 12 | path: . 13 | 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AI Forever 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/broken_csv_file_columns/data.csv: -------------------------------------------------------------------------------- 1 | source,corrections 2 | очень классная тетка ктобы что не говорил.,очень классная тетка кто бы что ни говорил 3 | Может выгоднее втулку продать и купить колесо в сборе?,Может выгоднее втулку продать и купить колесо в сборе 4 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву.,Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 5 | "Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera.",Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 6 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях:,Апофеозом дня для меня сегодня стала фраза услышанная в новостях 7 | "Ну не было поста, так небыло!",Ну не было поста так не было 8 | "Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал.",Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 9 | "Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно.",Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 10 | "Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек...",Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей 
повседневности занимает этот человек 11 | Пояним эту мысль.,Поясним эту мысль -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/broken_csv_file_nans/data.csv: -------------------------------------------------------------------------------- 1 | source,correction 2 | очень классная тетка ктобы что не говорил.,очень классная тетка кто бы что ни говорил 3 | Может выгоднее втулку продать и купить колесо в сборе?,Может выгоднее втулку продать и купить колесо в сборе 4 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву.,Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 5 | "Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera.",Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 6 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях:,Апофеозом дня для меня сегодня стала фраза услышанная в новостях 7 | "Ну не было поста, так небыло!",Ну не было поста так не было 8 | "Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал.",Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 9 | "Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно.",Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 10 | "Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек...",Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 11 | , -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/broken_csv_file_opening/data.csv: -------------------------------------------------------------------------------- 1 | source,correction 2 | очень классная тетка ктобы что не говорил.,очень классная тетка кто бы что ни говорил 3 | Может выгоднее втулку продать и купить колесо в сборе?,Может выгоднее втулку продать и купить колесо в сборе,asdf 4 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву.,Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 5 | "Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera.",Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 6 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях:,Апофеозом дня для меня сегодня стала фраза услышанная в новостях 7 | "Ну не было поста, так небыло!",Ну не было поста так не было 8 | "Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал.",Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в 
придачу то такого концерта мой кот не устраивал 9 | "Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно.",Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 10 | "Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек...",Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 11 | Пояним эту мысль.,Поясним эту мысль -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/broken_text_files/corrections.txt: -------------------------------------------------------------------------------- 1 | очень классная тетка кто бы что ни говорил 2 | Может выгоднее втулку продать и купить колесо в сборе 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 4 | Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 5 | Апофеозом дня для меня сегодня стала фраза услышанная в новостях 6 | Ну не было поста так не было 7 | Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 8 | Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 9 | Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 10 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/broken_text_files/sources.txt: -------------------------------------------------------------------------------- 1 | очень классная тетка ктобы что не говорил. 2 | Может выгоднее втулку продать и купить колесо в сборе? 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву. 4 | Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera. 5 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях: 6 | Ну не было поста, так небыло! 7 | Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал. 8 | Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно. 9 | Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек... 10 | Пояним эту мысль. 
11 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/corrections.txt: -------------------------------------------------------------------------------- 1 | очень классная тетка кто бы что ни говорил 2 | Может выгоднее втулку продать и купить колесо в сборе 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 4 | Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 5 | Апофеозом дня для меня сегодня стала фраза услышанная в новостях 6 | Ну не было поста так не было 7 | Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 8 | Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 9 | Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 10 | Поясним эту мысль 11 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/csv/data.csv: -------------------------------------------------------------------------------- 1 | source,correction 2 | очень классная тетка ктобы что не говорил.,очень классная тетка кто бы что ни говорил 3 | Может выгоднее втулку продать и купить колесо в сборе?,Может выгоднее втулку продать и купить колесо в сборе 4 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву.,Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 5 | "Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera.",Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 6 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях:,Апофеозом дня для меня сегодня стала фраза услышанная в новостях 7 | "Ну не было поста, так небыло!",Ну не было поста так не было 8 | "Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал.",Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 9 | "Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно.",Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 10 | "Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек...",Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 11 | Пояним эту мысль.,Поясним эту мысль -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/sources.txt: -------------------------------------------------------------------------------- 1 | очень классная 
тетка ктобы что не говорил. 2 | Может выгоднее втулку продать и купить колесо в сборе? 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву. 4 | Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera. 5 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях: 6 | Ну не было поста, так небыло! 7 | Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал. 8 | Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно. 9 | Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек... 10 | Пояним эту мысль. 11 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/wrong_names/corrections_.txt: -------------------------------------------------------------------------------- 1 | очень классная тетка кто бы что ни говорил 2 | Может выгоднее втулку продать и купить колесо в сборе 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 4 | Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 5 | Апофеозом дня для меня сегодня стала фраза услышанная в новостях 6 | Ну не было поста так не было 7 | Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 8 | Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 9 | Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 10 | Поясним эту мысль 11 | -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/wrong_names/data_.csv: -------------------------------------------------------------------------------- 1 | source,correction 2 | очень классная тетка ктобы что не говорил.,очень классная тетка кто бы что ни говорил 3 | Может выгоднее втулку продать и купить колесо в сборе?,Может выгоднее втулку продать и купить колесо в сборе 4 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву.,Довольно большая часть пришедших сходила с дорожек и усаживалась на траву 5 | "Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera.",Симпатичнейшее шпионское устройство такой себе гламурный фотоаппарат девушки Бонда миниатюрная модель камеры Superheadz Clap Camera 6 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях:,Апофеозом дня для меня сегодня стала фраза услышанная в новостях 7 | "Ну не было поста, так небыло!",Ну не было поста так не было 8 | "Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал.",Хотя странно когда я забирала к себе на выходные старого кота который живет у родителей да и собаку в придачу то такого концерта мой кот не устраивал 9 | "Думаю, что 
лет через 10 ретроспективно просматривать это будет мне невероятно интересно.",Думаю что лет через 10 ретроспективно просматривать это будет мне невероятно интересно 10 | "Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек...",Зато я считаю что это будет полезно и для меня и для всех тех кто меня окружает ведь когда расстаешься с человеком на какое-то время то многое становится прозрачным я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек 11 | Пояним эту мысль.,Поясним эту мысль -------------------------------------------------------------------------------- /data/sanity_check_samples/corruptor_tests/wrong_names/sources_.txt: -------------------------------------------------------------------------------- 1 | очень классная тетка ктобы что не говорил. 2 | Может выгоднее втулку продать и купить колесо в сборе? 3 | Довольно большая часть пришедших сходила с дорожек и усаживалась на траву. 4 | Симпатичнейшое шпионское устройство, такой себе гламурный фотоаппарат девушки Бонда - миниатюрная модель камеры Superheadz Clap Camera. 5 | Опофеозом дня для меня сегодня стала фраза услышанная в новостях: 6 | Ну не было поста, так небыло! 7 | Хотя странно, когда я забирала к себе на выходные старого кота, который живет у родителей, да и собаку в придачу, то такого концерта мой кот не устраивал. 8 | Думаю, что лет через 10 ретроспективно просматривать это будет мне невероятно интересно. 9 | Зато я считаю, что это будет полезно и для меня и для всех тех, кто меня окружает, ведь когда расстаешься с человеком на какое-то время, то многое становится прозрачным, я имею ввиду мы начинаем понимать какое место в нашей повседневности занимает этот человек... 10 | Пояним эту мысль. 11 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.1.2 2 | sphinx-rtd-theme==1.3.0rc1 3 | docutils==0.18.1 4 | sphinxemoji 5 | Levenshtein 6 | git+https://github.com/Askinkaty/errant/@4183e57 7 | https://huggingface.co/spacy/ru_core_news_lg/resolve/main/ru_core_news_lg-any-py3-none-any.whl 8 | . -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # -- Project information 4 | 5 | import os 6 | import sys 7 | sys.path.insert(0, os.path.abspath('.')) 8 | 9 | print(sys.path) 10 | print(os.listdir(".")) 11 | 12 | project = 'SAGE' 13 | copyright = '2021, Graziella' 14 | author = 'Nikita Martynov' 15 | 16 | release = '1.0' 17 | version = '1.1.0' 18 | 19 | # -- General configuration 20 | 21 | extensions = [ 22 | 'sphinx.ext.duration', 23 | 'sphinx.ext.doctest', 24 | 'sphinx.ext.autodoc', 25 | 'sphinx.ext.autosummary', 26 | 'sphinx.ext.intersphinx', 27 | 'sphinx_rtd_theme', 28 | 'sphinxemoji.sphinxemoji', 29 | ] 30 | 31 | intersphinx_mapping = { 32 | 'python': ('https://docs.python.org/3/', None), 33 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 34 | } 35 | intersphinx_disabled_domains = ['std'] 36 | 37 | templates_path = ['_templates'] 38 | 39 | source_suffix = ['.rst'] 40 | 41 | html_static_path = ['images'] 42 | 43 | # -- Options for HTML output 44 | 45 | html_theme = 'sphinx_rtd_theme' 46 | 47 | html_theme_options = { 48 | 49 | 'collapse_navigation': False, 50 | 'sticky_navigation': True, 51 | 'navigation_depth': 4, 52 | 'includehidden': True, 53 | 'titles_only': False 54 | } 55 | 56 | # -- Options for EPUB output 57 | epub_show_urls = 'footnote' 58 | -------------------------------------------------------------------------------- /docs/source/images/bea60k_side_by_side.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/docs/source/images/bea60k_side_by_side.jpg -------------------------------------------------------------------------------- /docs/source/images/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/docs/source/images/benchmark.png -------------------------------------------------------------------------------- /docs/source/images/ruspellru_side_by_side.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/docs/source/images/ruspellru_side_by_side.jpg -------------------------------------------------------------------------------- /docs/source/images/sage_banner.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/docs/source/images/sage_banner.jpg -------------------------------------------------------------------------------- /docs/source/rst/datasets/GitHubTypoCorpusRu.rst: -------------------------------------------------------------------------------- 1 | 🐙 GitHubTypoCorpusRu 2 | ------------------- 3 | 4 | The dataset is a part of `spellcheck_punctuation_benchmark `_: 5 | 6 | .. image:: ../../images/benchmark.png 7 | :align: center 8 | 9 | The Benchmark includes four datasets, each of which consists of pairs of sentences in Russian language. Each pair embodies sentence, which may contain spelling and punctuation errors, and its corresponding correction. Datasets were gathered from various sources and domains including social networks, internet blogs, github commits, medical anamnesis, literature, news, reviews and more. 10 | 11 | All datasets were passed through two-stage manual labeling pipeline. The correction of a sentence is defined by an agreement of at least two human annotators. Manual labeling scheme accounts for jargonisms, collocations and common language, hence in some cases it encourages annotators not to amend a word in favor of preserving style of a text. 12 | 13 | The latter does not apply to punctuation. Punctuation signs are rigorously marked in accordance to the rules of the Russian punctuation system. 14 | 15 | 16 | Table of contents 17 | ^^^^^^^^^^^^^^^^^ 18 | 19 | * `Dataset description <#id1>`_ 20 | 21 | * `Dataset summary <#id2>`_ 22 | * `Supported Tasks and Leaderboards <#id3>`_ 23 | * `Languages <#id4>`_ 24 | 25 | * `Dataset Structure <#id5>`_ 26 | 27 | * `Data Instances <#id6>`_ 28 | * `Data Fields <#id7>`_ 29 | * `Data Splits <#id8>`_ 30 | 31 | * `Dataset Creation <#id9>`_ 32 | 33 | * `Initial Data Collection and Normalization <#id10>`_ 34 | * `Annotation process <#id11>`_ 35 | * `Who are the annotators? <#id12>`_ 36 | 37 | * `Considerations for Using the Data <#id13>`_ 38 | 39 | * `Discussion of Biases <#id14>`_ 40 | * `Other Known Limitations <#id15>`_ 41 | 42 | * `Additional Information <#id16>`_ 43 | 44 | * `Future plans <#id17>`_ 45 | * `Dataset Curators <#id18>`_ 46 | * `Licensing Information <#id19>`_ 47 | * `Citation Information <#id20>`_ 48 | 49 | Dataset Description 50 | ^^^^^^^^^^^^^^^^^^^ 51 | 52 | - **Repository:** `SAGE `_ 53 | - **Paper:** `EACL 2024 `_ 54 | - **Point of Contact:** nikita.martynov.98@list.ru 55 | 56 | 57 | Dataset Summary 58 | ################ 59 | 60 | The Russian language part of `GitHub Typo Corpus `_. 61 | The texts are from GitHub commits. Passed the second-step of two-step manual annotation. 62 | 63 | Supported Tasks and Leaderboards 64 | ################################# 65 | 66 | - **Task:** automatic spelling correction. 67 | - **Metrics:** https://www.dialog-21.ru/media/3427/sorokinaaetal.pdf. 68 | - **ERRANT:** https://github.com/chrisjbryant/errant. 69 | 70 | 71 | Languages 72 | ######### 73 | 74 | Russian. 75 | 76 | Dataset Structure 77 | ^^^^^^^^^^^^^^^^^ 78 | 79 | Data Instances 80 | ################ 81 | 82 | - **Size of downloaded dataset files:** 1.23 Mb 83 | - **Size of the generated dataset:** 0.48 Mb 84 | - **Total amount of disk used:** 1.71 Mb 85 | 86 | An example of "test" looks as follows 87 | 88 | .. 
code-block:: 89 | 90 | 91 | { 92 | "source": "text: Пожалуйста выберите чат, чтобы начать общение", 93 | "correction": "text: Пожалуйста, выберите чат, чтобы начать общение.", 94 | } 95 | 96 | Data Fields 97 | ################ 98 | 99 | - `source`: a `string` feature 100 | - `correction`: a `string` feature 101 | - `domain`: a `string` feature 102 | 103 | Data Splits 104 | ################ 105 | 106 | +--------------------+------+ 107 | | | test | 108 | +====================+======+ 109 | | GitHubTypoCorpusRu | 868 | 110 | +--------------------+------+ 111 | 112 | Dataset Creation 113 | ^^^^^^^^^^^^^^^^^ 114 | 115 | Initial Data Collection and Normalization 116 | ########################################## 117 | 118 | For the reference on the original data collection please see the `paper `_. 119 | We extracted the Russian part from the original corpus and passed the texts trough the second step of two-stage manual annotation. 120 | 121 | Annotation process 122 | ########################################## 123 | 124 | We set up two-stage annotation project via a crowd-sourcing platform Toloka: 125 | 126 | 1. Data gathering stage: we provide the texts with possible mistakes to annotators and ask them to write the sentence correctly; 127 | 2. Validation stage: we provide annotators with the pair of sentences (source and its corresponding correction from the previous stage) and ask them to check if the correction is right. 128 | 129 | We prepared instructions for annotators for each task. The instructions ask annotators to correct misspellings if it does not alter the original style of the text. 130 | Instructions do not provide rigorous criteria on the matter of distinguishing the nature of an error in terms of its origin - whether it came from an urge to endow a sentence with particular stylistic features or from unintentional spelling violation since it is time-consuming and laborious to describe every possible case of employing slang, dialect, colloquialisms, etc. instead of proper language. Instructions also do not distinguish errors that come from the geographical or social background of the source. Instead, we rely on annotators’ knowledge and understanding of a language since, in this work, the important factor is to preserve the original style of the text. 131 | To ensure we receive qualified expertise, we set up test iteration on a small subset of the data for both stages. We manually validated the test results and selected annotators, who processed at least six samples (2% of the total test iteration) and did not make a single error. After test iteration, we cut 85% and 86% of labellers for gathering and validation stages. 132 | We especially urge annotators to correct mistakes associated with the substitution of the letters "ё" "й" and "щ" for corresponding "е" "и" and "ш" and not to explain abbreviations and correct punctuation errors. Each annotator is also warned about potentially sensitive topics in data (e.g., politics, societal minorities, and religion). 133 | 134 | The annotation of punctuation errors has been done in one iteration considering the low variation and difficulty of the task (relative to spelling correction). The annotators have been asked to correct punctuation signs in accordance with the rules of the Russian punctuation system. 135 | 136 | Who are the annotators? 137 | ######################## 138 | 139 | Native Russian speakers who passed the language exam. 140 | 141 | The annotators for punctuation errors are also professional editors and linguists. 
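Model corrections for this dataset can be scored against the gold annotations with the library's ``Scorer`` (see the RuErrant page). The sketch below is illustrative only: it adapts the pair from the Data Instances section above as a toy input, whereas in practice the aligned lists of the full test split and real model outputs would be passed.

.. code-block:: python

    from sage.evaluation.scorer import Scorer

    # sources, gold corrections and model corrections must be aligned lists;
    # the single pair and the hypothesis below are illustrative.
    sources = ["Пожалуйста выберите чат, чтобы начать общение"]
    gold = ["Пожалуйста, выберите чат, чтобы начать общение."]
    hypotheses = ["Пожалуйста, выберите чат, чтобы начать общение."]

    scorer = Scorer()
    print(scorer.score(sources, gold, hypotheses, metrics=["errant"]))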
142 | 143 | 144 | Considerations for Using the Data 145 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 146 | 147 | Discussion of Biases 148 | ##################### 149 | 150 | We clearly state our work’s aims and 151 | implications, making it open source and transparent. The data will be available under a public license. As our research involved anonymized textual data, informed consent from human participants was not required. However, we obtained permission to access publicly available datasets and 152 | ensured compliance with any applicable terms of 153 | service or usage policies. 154 | 155 | Other Known Limitations 156 | ######################## 157 | 158 | The data used in our research may be limited to specific 159 | domains, preventing comprehensive coverage of 160 | all possible text variations. Despite these limitations, we tried to address the issue of data diversity 161 | by incorporating single-domain and multi-domain 162 | datasets in the proposed research. This approach 163 | allowed us to shed light on the diversity and variances within the data, providing valuable insights 164 | despite the inherent constraints. 165 | 166 | We primarily focus on the Russian language. Further 167 | research is needed to expand the datasets for a wider 168 | range of languages. 169 | 170 | Additional Information 171 | ^^^^^^^^^^^^^^^^^^^^^^^^ 172 | 173 | Future plans 174 | ############### 175 | 176 | We are planning to expand our benchmark with both new Russian datasets and datasets in other languages including (but not limited to) European and CIS languages. 177 | If you would like to contribute, please contact us. 178 | 179 | Dataset Curators 180 | ################### 181 | 182 | Nikita Martynov nikita.martynov.98@list.ru (Spellcheck Punctuation Benchmark) 183 | 184 | Licensing Information 185 | ###################### 186 | 187 | All our datasets are published by MIT License. 188 | 189 | Citation Information 190 | ####################### 191 | 192 | .. code-block:: 193 | 194 | @inproceedings{martynov2023augmentation, 195 | title={Augmentation methods for spelling corruptions}, 196 | author={Martynov, Nikita and Baushenko, Mark and Abramov, Alexander and Fenogenova, Alena}, 197 | booktitle={Proceedings of the International Conference “Dialogue}, 198 | volume={2023}, 199 | year={2023} 200 | } 201 | 202 | @inproceedings{martynov-etal-2024-methodology, 203 | title = "A Methodology for Generative Spelling Correction via Natural Spelling Errors Emulation across Multiple Domains and Languages", 204 | author = "Martynov, Nikita and 205 | Baushenko, Mark and 206 | Kozlova, Anastasia and 207 | Kolomeytseva, Katerina and 208 | Abramov, Aleksandr and 209 | Fenogenova, Alena", 210 | editor = "Graham, Yvette and 211 | Purver, Matthew", 212 | booktitle = "Findings of the Association for Computational Linguistics: EACL 2024", 213 | month = mar, 214 | year = "2024", 215 | address = "St. Julian{'}s, Malta", 216 | publisher = "Association for Computational Linguistics", 217 | url = "https://aclanthology.org/2024.findings-eacl.10", 218 | pages = "138--155", 219 | abstract = "Large language models excel in text generation and generalization, however they face challenges in text editing tasks, especially in correcting spelling errors and mistyping.In this paper, we present a methodology for generative spelling correction (SC), tested on English and Russian languages and potentially can be extended to any language with minor changes. 
Our research mainly focuses on exploring natural spelling errors and mistyping in texts and studying how those errors can be emulated in correct sentences to enrich generative models{'} pre-train procedure effectively. We investigate the effects of emulations in various text domains and examine two spelling corruption techniques: 1) first one mimics human behavior when making a mistake through leveraging statistics of errors from a particular dataset, and 2) second adds the most common spelling errors, keyboard miss clicks, and some heuristics within the texts.We conducted experiments employing various corruption strategies, models{'} architectures, and sizes in the pre-training and fine-tuning stages and evaluated the models using single-domain and multi-domain test sets. As a practical outcome of our work, we introduce SAGE (Spell checking via Augmentation and Generative distribution Emulation).", 220 | } 221 | 222 | -------------------------------------------------------------------------------- /docs/source/rst/datasets/MedSpellchecker.rst: -------------------------------------------------------------------------------- 1 | 🫀 MedSpellchecker 2 | ------------------- 3 | 4 | The dataset is a part of `spellcheck_punctuation_benchmark `_: 5 | 6 | .. image:: ../../images/benchmark.png 7 | :align: center 8 | 9 | The Benchmark includes four datasets, each of which consists of pairs of sentences in Russian language. Each pair embodies sentence, which may contain spelling and punctuation errors, and its corresponding correction. Datasets were gathered from various sources and domains including social networks, internet blogs, github commits, medical anamnesis, literature, news, reviews and more. 10 | 11 | All datasets were passed through two-stage manual labeling pipeline. The correction of a sentence is defined by an agreement of at least two human annotators. Manual labeling scheme accounts for jargonisms, collocations and common language, hence in some cases it encourages annotators not to amend a word in favor of preserving style of a text. 12 | 13 | The latter does not apply to punctuation. Punctuation signs are rigorously marked in accordance to the rules of the Russian punctuation system. 14 | 15 | 16 | Table of contents 17 | ^^^^^^^^^^^^^^^^^ 18 | 19 | * `Dataset description <#id1>`_ 20 | 21 | * `Dataset summary <#id2>`_ 22 | * `Supported Tasks and Leaderboards <#id4>`_ 23 | * `Languages <#id5>`_ 24 | 25 | * `Dataset Structure <#id6>`_ 26 | 27 | * `Data Instances <#id7>`_ 28 | * `Data Fields <#id8>`_ 29 | * `Data Splits <#id9>`_ 30 | 31 | * `Dataset Creation <#id10>`_ 32 | 33 | * `Initial Data Collection and Normalization <#id11>`_ 34 | * `Annotation process <#id12>`_ 35 | * `Who are the annotators? <#id13>`_ 36 | 37 | * `Considerations for Using the Data <#id14>`_ 38 | 39 | * `Discussion of Biases <#id15>`_ 40 | * `Other Known Limitations <#id16>`_ 41 | 42 | * `Additional Information <#id17>`_ 43 | 44 | * `Future plans <#id18>`_ 45 | * `Dataset Curators <#id19>`_ 46 | * `Licensing Information <#id20>`_ 47 | * `Citation Information <#id21>`_ 48 | 49 | Dataset Description 50 | ^^^^^^^^^^^^^^^^^^^ 51 | 52 | - **Repository:** `SAGE `_ 53 | - **Paper:** `EACL 2024 `_ 54 | - **Point of Contact:** nikita.martynov.98@list.ru 55 | 56 | 57 | Dataset Summary 58 | ################ 59 | 60 | The dataset is obtained from the `MedSpellchecker `_ project. 61 | It originally consisted of texts from medical anamnesys that then have been passed through two-step manual annotation procedure. 
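A minimal loading sketch is given below. It mirrors the ``load_available_dataset_from_hf`` call shown on the RuSpellEval page and assumes this dataset is registered under ``sage.utils.DatasetsAvailable``; the exact enum member name may differ, so check the enum before use.

.. code-block:: python

    from sage.utils import DatasetsAvailable, load_available_dataset_from_hf

    # Enum member name is assumed here; verify it against DatasetsAvailable.
    sources, corrections = load_available_dataset_from_hf(
        DatasetsAvailable.MedSpellchecker.name, for_labeler=True, split="test"
    )
    print(len(sources), len(corrections))  # the test split holds 1054 pairs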
62 | 63 | Supported Tasks and Leaderboards 64 | ################################# 65 | 66 | - **Task:** automatic spelling correction. 67 | - **Metrics:** https://www.dialog-21.ru/media/3427/sorokinaaetal.pdf. 68 | - **ERRANT:** https://github.com/chrisjbryant/errant. 69 | 70 | 71 | Languages 72 | ######### 73 | 74 | Russian. 75 | 76 | Dataset Structure 77 | ^^^^^^^^^^^^^^^^^ 78 | 79 | Data Instances 80 | ################ 81 | 82 | - **Size of downloaded dataset files:** 1.49 Mb 83 | - **Size of the generated dataset:** 0.54 Mb 84 | - **Total amount of disk used:** 2.03 Mb 85 | 86 | An example of "train" / "test" looks as follows 87 | 88 | .. code-block:: 89 | 90 | { 91 | "source": "Накануне (18.02.2012 г", 92 | "correction": "Накануне (18.02.2012 г.).", 93 | } 94 | 95 | Data Fields 96 | ################ 97 | 98 | - `source`: a `string` feature 99 | - `correction`: a `string` feature 100 | - `domain`: a `string` feature 101 | 102 | Data Splits 103 | ################ 104 | 105 | +---------------+------+ 106 | | | test | 107 | +===============+======+ 108 | | MedSpellcheck | 1054 | 109 | +---------------+------+ 110 | 111 | Dataset Creation 112 | ^^^^^^^^^^^^^^^^^ 113 | 114 | Initial Data Collection and Normalization 115 | ########################################## 116 | 117 | The source data gathering procedure in described in the `paper `_. 118 | We took the released splits and set up the annotation process. 119 | 120 | Annotation process 121 | ########################################## 122 | 123 | We set up two-stage annotation project via a crowd-sourcing platform Toloka: 124 | 125 | 1. Data gathering stage: we provide the texts with possible mistakes to annotators and ask them to write the sentence correctly; 126 | 2. Validation stage: we provide annotators with the pair of sentences (source and its corresponding correction from the previous stage) and ask them to check if the correction is right. 127 | 128 | We prepared instructions for annotators for each task. The instructions ask annotators to correct misspellings if it does not alter the original style of the text. 129 | Instructions do not provide rigorous criteria on the matter of distinguishing the nature of an error in terms of its origin - whether it came from an urge to endow a sentence with particular stylistic features or from unintentional spelling violation since it is time-consuming and laborious to describe every possible case of employing slang, dialect, colloquialisms, etc. instead of proper language. Instructions also do not distinguish errors that come from the geographical or social background of the source. Instead, we rely on annotators’ knowledge and understanding of a language since, in this work, the important factor is to preserve the original style of the text. 130 | To ensure we receive qualified expertise, we set up test iteration on a small subset of the data for both stages. We manually validated the test results and selected annotators, who processed at least six samples (2% of the total test iteration) and did not make a single error. After test iteration, we cut 85% and 86% of labellers for gathering and validation stages. 131 | We especially urge annotators to correct mistakes associated with the substitution of the letters "ё" "й" and "щ" for corresponding "е" "и" and "ш" and not to explain abbreviations and correct punctuation errors. Each annotator is also warned about potentially sensitive topics in data (e.g., politics, societal minorities, and religion). 
132 | 133 | The annotation of punctuation errors has been done in one iteration considering the low variation and difficulty of the task (relative to spelling correction). The annotators have been asked to correct punctuation signs in accordance with the rules of the Russian punctuation system. 134 | 135 | Who are the annotators? 136 | ######################## 137 | 138 | Native Russian speakers who passed the language exam. 139 | 140 | The annotators for punctuation errors are also professional editors and linguists. 141 | 142 | 143 | Considerations for Using the Data 144 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 145 | 146 | Discussion of Biases 147 | ##################### 148 | 149 | We clearly state our work’s aims and 150 | implications, making it open source and transparent. The data will be available under a public license. As our research involved anonymized textual data, informed consent from human participants was not required. However, we obtained permission to access publicly available datasets and 151 | ensured compliance with any applicable terms of 152 | service or usage policies. 153 | 154 | Other Known Limitations 155 | ######################## 156 | 157 | The data used in our research may be limited to specific 158 | domains, preventing comprehensive coverage of 159 | all possible text variations. Despite these limitations, we tried to address the issue of data diversity 160 | by incorporating single-domain and multi-domain 161 | datasets in the proposed research. This approach 162 | allowed us to shed light on the diversity and variances within the data, providing valuable insights 163 | despite the inherent constraints. 164 | 165 | We primarily focus on the Russian language. Further 166 | research is needed to expand the datasets for a wider 167 | range of languages. 168 | 169 | Additional Information 170 | ^^^^^^^^^^^^^^^^^^^^^^^^ 171 | 172 | Future plans 173 | ############### 174 | 175 | We are planning to expand our benchmark with both new Russian datasets and datasets in other languages including (but not limited to) European and CIS languages. 176 | If you would like to contribute, please contact us. 177 | 178 | Dataset Curators 179 | ################### 180 | 181 | Nikita Martynov nikita.martynov.98@list.ru (Spellcheck Punctuation Benchmark) 182 | 183 | Licensing Information 184 | ###################### 185 | 186 | All our datasets are published by MIT License. 187 | 188 | Citation Information 189 | ####################### 190 | 191 | .. code-block:: 192 | 193 | @inproceedings{martynov2023augmentation, 194 | title={Augmentation methods for spelling corruptions}, 195 | author={Martynov, Nikita and Baushenko, Mark and Abramov, Alexander and Fenogenova, Alena}, 196 | booktitle={Proceedings of the International Conference “Dialogue}, 197 | volume={2023}, 198 | year={2023} 199 | } 200 | 201 | @inproceedings{martynov-etal-2024-methodology, 202 | title = "A Methodology for Generative Spelling Correction via Natural Spelling Errors Emulation across Multiple Domains and Languages", 203 | author = "Martynov, Nikita and 204 | Baushenko, Mark and 205 | Kozlova, Anastasia and 206 | Kolomeytseva, Katerina and 207 | Abramov, Aleksandr and 208 | Fenogenova, Alena", 209 | editor = "Graham, Yvette and 210 | Purver, Matthew", 211 | booktitle = "Findings of the Association for Computational Linguistics: EACL 2024", 212 | month = mar, 213 | year = "2024", 214 | address = "St. 
Julian{'}s, Malta", 215 | publisher = "Association for Computational Linguistics", 216 | url = "https://aclanthology.org/2024.findings-eacl.10", 217 | pages = "138--155", 218 | abstract = "Large language models excel in text generation and generalization, however they face challenges in text editing tasks, especially in correcting spelling errors and mistyping.In this paper, we present a methodology for generative spelling correction (SC), tested on English and Russian languages and potentially can be extended to any language with minor changes. Our research mainly focuses on exploring natural spelling errors and mistyping in texts and studying how those errors can be emulated in correct sentences to enrich generative models{'} pre-train procedure effectively. We investigate the effects of emulations in various text domains and examine two spelling corruption techniques: 1) first one mimics human behavior when making a mistake through leveraging statistics of errors from a particular dataset, and 2) second adds the most common spelling errors, keyboard miss clicks, and some heuristics within the texts.We conducted experiments employing various corruption strategies, models{'} architectures, and sizes in the pre-training and fine-tuning stages and evaluated the models using single-domain and multi-domain test sets. As a practical outcome of our work, we introduce SAGE (Spell checking via Augmentation and Generative distribution Emulation).", 219 | } 220 | 221 | -------------------------------------------------------------------------------- /docs/source/rst/datasets/RUSpellRU.rst: -------------------------------------------------------------------------------- 1 | 📕 RUSpellRU 2 | ------------------- 3 | 4 | 5 | The dataset is a part of `spellcheck_punctuation_benchmark `_: 6 | 7 | .. image:: ../../images/benchmark.png 8 | :align: center 9 | 10 | 11 | The Benchmark includes four datasets, each of which consists of pairs of sentences in Russian language. Each pair embodies sentence, which may contain spelling and punctuation errors, and its corresponding correction. Datasets were gathered from various sources and domains including social networks, internet blogs, github commits, medical anamnesis, literature, news, reviews and more. 12 | 13 | All datasets were passed through two-stage manual labeling pipeline. The correction of a sentence is defined by an agreement of at least two human annotators. Manual labeling scheme accounts for jargonisms, collocations and common language, hence in some cases it encourages annotators not to amend a word in favor of preserving style of a text. 14 | 15 | The latter does not apply to punctuation. Punctuation signs are rigorously marked in accordance to the rules of the Russian punctuation system. 
16 | 17 | 18 | Table of contents 19 | ^^^^^^^^^^^^^^^^^ 20 | 21 | * `Dataset description <#id1>`_ 22 | 23 | * `Dataset summary <#id2>`_ 24 | * `Supported Tasks and Leaderboards <#id3>`_ 25 | * `Languages <#id4>`_ 26 | 27 | * `Dataset Structure <#id5>`_ 28 | 29 | * `Data Instances <#id6>`_ 30 | * `Data Fields <#id7>`_ 31 | * `Data Splits <#id8>`_ 32 | 33 | * `Considerations for Using the Data <#id9>`_ 34 | 35 | * `Discussion of Biases <#id9>`_ 36 | * `Other Known Limitations <#id10>`_ 37 | 38 | * `Additional Information <#id11>`_ 39 | 40 | * `Future plans <#id12>`_ 41 | * `Dataset Curators <#id13>`_ 42 | * `Licensing Information <#id14>`_ 43 | * `Citation Information <#id15>`_ 44 | 45 | 46 | Dataset Description 47 | ^^^^^^^^^^^^^^^^^^^ 48 | 49 | - **Repository:** `SAGE `_ 50 | - **Paper:** `EACL 2024 `_ 51 | - **Point of Contact:** nikita.martynov.98@list.ru 52 | 53 | Dataset Summary 54 | ################ 55 | 56 | The dataset origins from `RuSpellEval competition `_. 57 | The texts were gathered from `LiveJournal `_ and annotated by linguistic experts in two rounds. 58 | RUSpellRU amounts for 4k sentence pairs that represented Social Networks and Internet Blogs text domains. 59 | 60 | Supported Tasks and Leaderboards 61 | ################################# 62 | 63 | - **Task:** automatic spelling correction. 64 | - **Metrics:** https://www.dialog-21.ru/media/3427/sorokinaaetal.pdf. 65 | - **ERRANT:** https://github.com/chrisjbryant/errant. 66 | 67 | Languages 68 | ######### 69 | 70 | Russian. 71 | 72 | Dataset Structure 73 | ^^^^^^^^^^^^^^^^^ 74 | 75 | Data Instances 76 | ################ 77 | 78 | - **Size of downloaded dataset files:** 3.65 Mb 79 | - **Size of the generated dataset:** 1.31 Mb 80 | - **Total amount of disk used:** 4.96 Mb 81 | 82 | An example of "train" / "test" looks as follows 83 | 84 | .. code-block:: 85 | 86 | { 87 | "source": "очень классная тетка ктобы что не говорил.", 88 | "correction": "очень классная тетка кто бы что ни говорил", 89 | } 90 | 91 | Data Fields 92 | ################ 93 | 94 | - `source`: a `string` feature 95 | - `correction`: a `string` feature 96 | - `domain`: a `string` feature 97 | 98 | Data Splits 99 | ################ 100 | 101 | +-----------+-------+------+ 102 | | | train | test | 103 | +===========+=======+======+ 104 | | RUSpellRU | 2000 | 2008 | 105 | +-----------+-------+------+ 106 | 107 | 108 | Considerations for Using the Data 109 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 110 | 111 | Discussion of Biases 112 | ##################### 113 | 114 | We clearly state our work’s aims and 115 | implications, making it open source and transparent. The data will be available under a public license. As our research involved anonymized textual data, informed consent from human participants was not required. However, we obtained permission to access publicly available datasets and 116 | ensured compliance with any applicable terms of 117 | service or usage policies. 118 | 119 | Other Known Limitations 120 | ######################## 121 | 122 | The data used in our research may be limited to specific 123 | domains, preventing comprehensive coverage of 124 | all possible text variations. Despite these limitations, we tried to address the issue of data diversity 125 | by incorporating single-domain and multi-domain 126 | datasets in the proposed research. This approach 127 | allowed us to shed light on the diversity and variances within the data, providing valuable insights 128 | despite the inherent constraints. 
129 | 130 | We primarily focus on the Russian language. Further 131 | research is needed to expand the datasets for a wider 132 | range of languages. 133 | 134 | Additional Information 135 | ^^^^^^^^^^^^^^^^^^^^^^^^ 136 | 137 | Future plans 138 | ############### 139 | 140 | We are planning to expand our benchmark with both new Russian datasets and datasets in other languages including (but not limited to) European and CIS languages. 141 | If you would like to contribute, please contact us. 142 | 143 | Dataset Curators 144 | ################### 145 | 146 | Nikita Martynov nikita.martynov.98@list.ru (Spellcheck Punctuation Benchmark) 147 | 148 | Licensing Information 149 | ###################### 150 | 151 | All our datasets are published by MIT License. 152 | 153 | Citation Information 154 | ####################### 155 | 156 | .. code-block:: 157 | 158 | @inproceedings{martynov2023augmentation, 159 | title={Augmentation methods for spelling corruptions}, 160 | author={Martynov, Nikita and Baushenko, Mark and Abramov, Alexander and Fenogenova, Alena}, 161 | booktitle={Proceedings of the International Conference “Dialogue}, 162 | volume={2023}, 163 | year={2023} 164 | } 165 | 166 | @inproceedings{martynov-etal-2024-methodology, 167 | title = "A Methodology for Generative Spelling Correction via Natural Spelling Errors Emulation across Multiple Domains and Languages", 168 | author = "Martynov, Nikita and 169 | Baushenko, Mark and 170 | Kozlova, Anastasia and 171 | Kolomeytseva, Katerina and 172 | Abramov, Aleksandr and 173 | Fenogenova, Alena", 174 | editor = "Graham, Yvette and 175 | Purver, Matthew", 176 | booktitle = "Findings of the Association for Computational Linguistics: EACL 2024", 177 | month = mar, 178 | year = "2024", 179 | address = "St. Julian{'}s, Malta", 180 | publisher = "Association for Computational Linguistics", 181 | url = "https://aclanthology.org/2024.findings-eacl.10", 182 | pages = "138--155", 183 | abstract = "Large language models excel in text generation and generalization, however they face challenges in text editing tasks, especially in correcting spelling errors and mistyping.In this paper, we present a methodology for generative spelling correction (SC), tested on English and Russian languages and potentially can be extended to any language with minor changes. Our research mainly focuses on exploring natural spelling errors and mistyping in texts and studying how those errors can be emulated in correct sentences to enrich generative models{'} pre-train procedure effectively. We investigate the effects of emulations in various text domains and examine two spelling corruption techniques: 1) first one mimics human behavior when making a mistake through leveraging statistics of errors from a particular dataset, and 2) second adds the most common spelling errors, keyboard miss clicks, and some heuristics within the texts.We conducted experiments employing various corruption strategies, models{'} architectures, and sizes in the pre-training and fine-tuning stages and evaluated the models using single-domain and multi-domain test sets. 
As a practical outcome of our work, we introduce SAGE (Spell checking via Augmentation and Generative distribution Emulation).", 184 | } 185 | 186 | -------------------------------------------------------------------------------- /docs/source/rst/evaluation/RuErrant.rst: -------------------------------------------------------------------------------- 1 | 💥 RuErrant 2 | ------------------- 3 | 4 | RuERRANT is an adaptation of the `ERRANT metric `_ to the Russian language. The adaptation was primarily done in https://github.com/Askinkaty/errant and further developed within SAGE. The changes to the original ERRANT implementation for English are the following: 5 | 6 | 1. Basic parsing model changed to Spacy's `ru_core_news_lg`. 7 | 2. Included a dictionary of Russian words (main forms). 8 | 3. Introduced detection of error correction types specific for Russian (degrees of adjectives, verb aspect). 9 | 4. [our contribution] Introduced a simplified error correction typology: 10 | - `CASE`: spelling corrections including only character case change; 11 | - `PUNCT`: punctuation corrections; 12 | - `YO`: spelling corrections regarding "е"/"ё" substitutions; 13 | - `SPELL`: all other word-level spelling corrections. 14 | 5. [our contribution] Introduced detection of multiple error correction types per word, e.g. "федор" -> "Фёдор" contains both CASE and YO corrections. 15 | 6. [our contribution] Introduced detection of inner word punctuation corrections which covers joint ("AB") vs. hyphen ("A-B") vs. space ("A B") word spelling. Corrections of this type are attributed to the `SPELL` category. 16 | 17 | Scoring 18 | ^^^^^^^^ 19 | 20 | To score model's corrections against gold corrections, use a Scorer instance: 21 | 22 | .. code-block:: python 23 | 24 | from sage.evaluation.scorer import Scorer 25 | 26 | s = Scorer() 27 | 28 | s.score( 29 | ["спел Кейс ее .", "спел Кейс ее ."], 30 | ["спелл кейс её !", "спелл кейс её !"], 31 | ["спел кейс её .", "спелл Кейс ее !"], 32 | metrics=["errant"] 33 | ) 34 | >>> {'CASE_Precision': 100.0, 'CASE_Recall': 50.0, 'CASE_F1': 66.67, 35 | 'YO_Precision': 100.0, 'YO_Recall': 50.0, 'YO_F1': 66.67, 36 | 'SPELL_Precision': 100.0, 'SPELL_Recall': 50.0, 'SPELL_F1': 66.67, 37 | 'PUNCT_Precision': 100.0, 'PUNCT_Recall': 50.0, 'PUNCT_F1': 66.67} 38 | 39 | -------------------------------------------------------------------------------- /docs/source/rst/evaluation/RuSpellEval.rst: -------------------------------------------------------------------------------- 1 | ⚠️ RuSpellEval 2 | ------------------- 3 | 4 | RuSpellEval is described in `paper `_. 5 | 6 | The metric does not account for punctuation and register associated errors, so we primarily used it in earlier releases. 7 | 8 | You can still invoke the RuSpellEval metric by calling the following: 9 | 10 | .. code-block:: python 11 | 12 | from sage.evaluation import Scorer 13 | from sage.utils import DatasetsAvailable, load_available_dataset_from_hf 14 | 15 | sources, corrections = load_available_dataset_from_hf(DatasetsAvailable.RUSpellRU.name, for_labeler=True, split="test") 16 | 17 | scorer = Scorer() 18 | metrics = scorer.score(sources, corrections, corrections, metrics=["ruspelleval"]) 19 | print(metrics) 20 | 21 | # {'Precision': 100.0, 'Recall': 100.0, 'F1': 100.0} 22 | 23 | 24 | ... or in conjunction with RuErrant metric: 25 | 26 | .. 
code-block:: python 27 | 28 | import os 29 | import torch 30 | from sage.utils import DatasetsAvailable 31 | from sage.spelling_correction import AvailableCorrectors 32 | from sage.spelling_correction import T5ModelForSpellingCorruption 33 | 34 | corrector_fred_95m = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.sage_fredt5_distilled_95m.value) 35 | corrector_mt5 = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.sage_mt5_large.value) 36 | 37 | corrector_fred_95m.model.to(torch.device("cuda:0")) 38 | corrector_mt5.model.to(torch.device("cuda:0")) 39 | 40 | metrics = corrector_fred_95m.evaluate("RUSpellRU", metrics=["errant", "ruspelleval"], batch_size=32) 41 | print(metrics) 42 | # {'CASE_Precision': 94.41, 'CASE_Recall': 92.55, 'CASE_F1': 93.47, 'SPELL_Precision': 77.52, 'SPELL_Recall': 64.09, 'SPELL_F1': 70.17, 'PUNCT_Precision': 86.77, 'PUNCT_Recall': 80.59, 'PUNCT_F1': 83.56, 'YO_Precision': 46.21, 'YO_Recall': 73.83, 'YO_F1': 56.84, 'Precision': 83.48, 'Recall': 74.75, 'F1': 78.87} 43 | -------------------------------------------------------------------------------- /docs/source/rst/spelling_correction/T5.rst: -------------------------------------------------------------------------------- 1 | 🇬🇧 T5-large 2 | ------------------- 3 | 4 | The model corrects spelling errors and typos by bringing all words in the text to the standard English language. 5 | The proofreader was trained based on the `T5-large `_ model. 6 | An extensive dataset with “artificial” errors was taken as a training corpus: the corpus was assembled on the basis of the English-language Wikipedia and News blogs, then typos and spelling errors were automatically introduced into it using the functionality of the `SAGE library `_. 7 | 8 | 9 | Table of contents 10 | ^^^^^^^^^^^^^^^^^ 11 | 12 | * `Public references <#id2>`_ 13 | * `Examples <#id3>`_ 14 | * `Metrics <#id4>`_ 15 | * `How to use <#id5>`_ 16 | * `API <#id6>`_ 17 | * `Resources <#id7>`_ 18 | * `License <#id10>`_ 19 | * `Specifications <#id12>`_ 20 | * `Contacts <#id13>`_ 21 | 22 | 23 | Public references 24 | ^^^^^^^^^^^^^^^^^ 25 | 26 | - `SAGE library announcement `_, DataFest 2023 27 | - `Paper about synthetic error generation methods `_, Dialogue 2023 28 | - `EACL 2024 paper `_ 29 | 30 | 31 | Examples 32 | ^^^^^^^^^ 33 | 34 | +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 35 | | Input | Output | 36 | +===============================================================================================================================================================================================================================================================================+===========================================================================================================================================================================================================================================================================+ 37 | | Th festeivаl was excelzecnt in many ways, and in particular it beinganinternational festjival sss a 
chаllenging, bet brilli an t ea. | The festival was excellent in many ways, and in particular it beinganinternational festival is a challenging, but brilliant one to see. | 38 | +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 39 | | That 's why I believe in the solution which is the closest to human nature and can help us to avoid boredome. I am sure that eventually we will take off our clothes and in the future we will be undressed and free. There wo n't be any problem with being up - do - date . | That's why I believe in the solution which is the closest to human nature and can help us to avoid boredom. I am sure that eventually we will take off our clothes and in the future we will be undressed and free. There won't be any problem with being up - do - date. | 40 | +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 41 | | If you bought something goregous, you well be very happy. | If you bought something gorgeous, you will be very happy. | 42 | +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 43 | 44 | 45 | Metrics 46 | ^^^^^^^^^ 47 | 48 | Below are automatic metrics for determining the correctness of the spell checkers. 
49 | We present a comparison of our solution both with open automatic spell checkers and with the ChatGPT family of models on two available datasets: 50 | - **BEA60K**: English spelling errors collected from several domains; 51 | - **JFLEG**: 1601 sentences in English, which contain about 2 thousand spelling errors; 52 | 53 | 54 | **BEA60K** 55 | 56 | +------------------------------------------------+-----------+--------+------+ 57 | | Model | Precision | Recall | F1 | 58 | +================================================+===========+========+======+ 59 | | T5-large-spell | 66.5 | 83.1 | 73.9 | 60 | +------------------------------------------------+-----------+--------+------+ 61 | | ChatGPT gpt-3.5-turbo-0301 | 66.9 | 84.1 | 74.5 | 62 | +------------------------------------------------+-----------+--------+------+ 63 | | ChatGPT gpt-4-0314 | 68.6 | 85.2 | 76.0 | 64 | +------------------------------------------------+-----------+--------+------+ 65 | | ChatGPT text-davinci-003 | 67.8 | 83.9 | 75.0 | 66 | +------------------------------------------------+-----------+--------+------+ 67 | | Bert (https://github.com/neuspell/neuspell) | 65.8 | 79.6 | 72.0 | 68 | +------------------------------------------------+-----------+--------+------+ 69 | | SC-LSTM (https://github.com/neuspell/neuspell) | 62.2 | 80.3 | 72.0 | 70 | +------------------------------------------------+-----------+--------+------+ 71 | 72 | 73 | **JFLEG** 74 | 75 | +------------------------------------------------+-----------+--------+------+ 76 | | Model | Precision | Recall | F1 | 77 | +================================================+===========+========+======+ 78 | | T5-large-spell | 83.4 | 84.3 | 83.8 | 79 | +------------------------------------------------+-----------+--------+------+ 80 | | ChatGPT gpt-3.5-turbo-0301 | 77.8 | 88.6 | 82.9 | 81 | +------------------------------------------------+-----------+--------+------+ 82 | | ChatGPT gpt-4-0314 | 77.9 | 88.3 | 82.8 | 83 | +------------------------------------------------+-----------+--------+------+ 84 | | ChatGPT text-davinci-003 | 76.8 | 88.5 | 82.2 | 85 | +------------------------------------------------+-----------+--------+------+ 86 | | Bert (https://github.com/neuspell/neuspell) | 78.5 | 85.4 | 81.8 | 87 | +------------------------------------------------+-----------+--------+------+ 88 | | SC-LSTM (https://github.com/neuspell/neuspell) | 80.6 | 86.1 | 83.2 | 89 | +------------------------------------------------+-----------+--------+------+ 90 | 91 | 92 | 93 | How to use 94 | ^^^^^^^^^^^^ 95 | 96 | .. code-block:: python 97 | 98 | from transformers import T5ForConditionalGeneration, AutoTokenizer 99 | 100 | path_to_model = "ai-forever/T5-large-spell" 101 | model = T5ForConditionalGeneration.from_pretrained(path_to_model) 102 | tokenizer = AutoTokenizer.from_pretrained(path_to_model) 103 | 104 | prefix = "grammar: " 105 | sentence = "If you bought something goregous, you well be very happy." 106 | sentence = prefix + sentence 107 | encodings = tokenizer(sentence, return_tensors="pt") 108 | generated_tokens = model.generate(**encodings) 109 | answer = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 110 | print(answer) 111 | 112 | # ["If you bought something gorgeous, you will be very happy."] 113 | 114 | 115 | API 116 | ^^^^^ 117 | 118 | .. 
autoclass:: sage.spelling_correction.t5_correctors.T5ModelForSpellingCorruption 119 | :members: 120 | :inherited-members: 121 | :show-inheritance: 122 | 123 | 124 | Resources 125 | ^^^^^^^^^ 126 | 127 | - `SAGE library `_, GitHub 128 | - `SAGE v1.1.0 models release `_, HuggingFace 129 | - `EACL 2024 paper `_ 130 | 131 | License 132 | ^^^^^^^^^ 133 | 134 | The `T5-large `_ model, on which our solution is based, and its source code are supplied under the APACHE-2.0 license. 135 | Our solution is supplied under MIT license. 136 | 137 | Specifications 138 | ^^^^^^^^^^^^^^^ 139 | - File size: 3 Gb; 140 | - Framework: pytorch 141 | - Format: AI Service 142 | - Version: v1.0 143 | - Developer: SberDevices, AGI NLP 144 | 145 | Contacts 146 | ^^^^^^^^^^^ 147 | nikita.martynov.98@list.ru 148 | 149 | -------------------------------------------------------------------------------- /docs/source/rst/spelling_correction/sage-mt5-large.rst: -------------------------------------------------------------------------------- 1 | 🎯 sage-mt5-large 2 | ------------------- 3 | 4 | The model is a part of `SAGE v1.1.0 release `_ 5 | 6 | .. image:: ../../images/sage_banner.jpg 7 | :align: center 8 | 9 | The model corrects spelling errors and typos in both Russian and English languages by bringing all the words in the text to the norm of the language. 10 | Corrector had been trained based on the model `mT5-large `_ architecture. 11 | An extensive dataset with “artificial” errors was taken as a training corpus: the corpus was assembled on the basis of the Russian-language Wikipedia and transcripts of Russian-language videos, then typos and spelling errors were automatically introduced into it using the library `SAGE library `_. 12 | 13 | 14 | Table of contents 15 | ^^^^^^^^^^^^^^^^^ 16 | 17 | * `Public references <#id1>`_ 18 | * `Examples <#id2>`_ 19 | * `Metrics <#id3>`_ 20 | * `How to use <#id4>`_ 21 | * `API <#id5>`_ 22 | * `Limitations <#id6>`_ 23 | * `Resources <#id7>`_ 24 | * `License <#id10>`_ 25 | * `Specifications <#id12>`_ 26 | * `Contacts <#id13>`_ 27 | 28 | 29 | Public references 30 | ^^^^^^^^^^^^^^^^^ 31 | 32 | - `SAGE library announcement `_, DataFest 2023 33 | - `Paper about synthetic error generation methods `_, Dialogue 2023 34 | - `EACL 2024 paper `_ 35 | 36 | 37 | Examples 38 | ^^^^^^^^^ 39 | 40 | +--------------------------------------------------------------------+--------------------------------------------------------------------------+ 41 | | Input | Output | 42 | +====================================================================+==========================================================================+ 43 | | Перведи мне текст на аглиском: "Screw you kuys, I am goin hme (c). | Переведи мне текст на английском: "Screw you guys, I am going home" (c). | 44 | +--------------------------------------------------------------------+--------------------------------------------------------------------------+ 45 | | И не чсно прохожим в этот день непогожйи почему я веселый такйо | И мне ясно прохожим в этот день непогожий, почему я веселый такой | 46 | +--------------------------------------------------------------------+--------------------------------------------------------------------------+ 47 | | If you bought something goregous, you well be very happy. | If you bought something gorgeous, you will be very happy. 
| 48 | +--------------------------------------------------------------------+--------------------------------------------------------------------------+ 49 | 50 | 51 | Metrics 52 | ^^^^^^^^^ 53 | 54 | Below are automatic metrics for determining the correctness of the spell checkers. 55 | We compare our solution with both open automatic spell checkers and the ChatGPT family of models on all six available datasets: 56 | - **RUSpellRU**: texts collected from (`LiveJournal `_), with manually corrected typos and errors; 57 | - **MultidomainGold**: examples from 7 text sources, including the open web, news, social media, reviews, subtitles, policy documents and literary works; 58 | - **MedSpellChecker**: texts with errors from medical anamnesis; 59 | - **GitHubTypoCorpusRu**: spelling errors and typos in commits from GitHub; 60 | - **BEA60K**: English spelling errors collected from several domains; 61 | - **JFLEG**: 1601 sentences in English, which contain about 2 thousand spelling errors; 62 | 63 | RUSpellRU, MultidomainGold, MedSpellChecker, GitHubTypoCorpusRu are datasets for the Russian spellchecking and BEA60K and JFLEG are those for the English language. 64 | 65 | 66 | **RUSpellRU** 67 | 68 | +----------------------+-----------+--------+------+ 69 | | Model | Precision | Recall | F1 | 70 | +======================+===========+========+======+ 71 | | sage-mt5-large | 55.7 | 68.5 | 61.4 | 72 | +----------------------+-----------+--------+------+ 73 | | sage-mt5-large (ft.) | 88.4 | 71.6 | 79.1 | 74 | +----------------------+-----------+--------+------+ 75 | | sage-ai-service | 93.5 | 82.4 | 87.6 | 76 | +----------------------+-----------+--------+------+ 77 | | gpt-3.5-turbo | 39.6 | 62.3 | 48.5 | 78 | +----------------------+-----------+--------+------+ 79 | | gpt-4 | 69.5 | 81.0 | 74.8 | 80 | +----------------------+-----------+--------+------+ 81 | 82 | 83 | **MultidomainGold** 84 | 85 | +----------------------+-----------+--------+------+ 86 | | Model | Precision | Recall | F1 | 87 | +======================+===========+========+======+ 88 | | sage-mt5-large | 35.4 | 57.9 | 43.9 | 89 | +----------------------+-----------+--------+------+ 90 | | sage-mt5-large (ft.) | 65.3 | 62.7 | 63.9 | 91 | +----------------------+-----------+--------+------+ 92 | | sage-ai-service | 70.9 | 68.8 | 69.9 | 93 | +----------------------+-----------+--------+------+ 94 | | gpt-3.5-turbo | 17.8 | 56.1 | 27.0 | 95 | +----------------------+-----------+--------+------+ 96 | | gpt-4 | 31.1 | 78.1 | 44.5 | 97 | +----------------------+-----------+--------+------+ 98 | 99 | 100 | **MedSpellChecker** 101 | 102 | +----------------------+-----------+--------+------+ 103 | | Model | Precision | Recall | F1 | 104 | +======================+===========+========+======+ 105 | | sage-mt5-large | 35.1 | 70.8 | 47.0 | 106 | +----------------------+-----------+--------+------+ 107 | | sage-mt5-large (ft.) 
| 77.7 | 77.5 | 77.6 | 108 | +----------------------+-----------+--------+------+ 109 | | sage-ai-service | 73.4 | 76.2 | 74.9 | 110 | +----------------------+-----------+--------+------+ 111 | | gpt-3.5-turbo | 15.1 | 53.6 | 23.5 | 112 | +----------------------+-----------+--------+------+ 113 | | gpt-4 | 48.9 | 88.7 | 63.1 | 114 | +----------------------+-----------+--------+------+ 115 | 116 | 117 | **GitHubTypoCorpusRu** 118 | 119 | +----------------------+-----------+--------+------+ 120 | | Model | Precision | Recall | F1 | 121 | +======================+===========+========+======+ 122 | | sage-mt5-large | 47.4 | 53.8 | 50.4 | 123 | +----------------------+-----------+--------+------+ 124 | | sage-mt5-large (ft.) | 69.5 | 46.0 | 55.3 | 125 | +----------------------+-----------+--------+------+ 126 | | sage-ai-service | 76.1 | 51.2 | 61.2 | 127 | +----------------------+-----------+--------+------+ 128 | | gpt-3.5-turbo | 23.7 | 43.9 | 30.8 | 129 | +----------------------+-----------+--------+------+ 130 | | gpt-4 | 34.7 | 60.5 | 44.1 | 131 | +----------------------+-----------+--------+------+ 132 | 133 | 134 | **BEA60K** 135 | 136 | +------------------------------------------------+-----------+--------+------+ 137 | | Model | Precision | Recall | F1 | 138 | +================================================+===========+========+======+ 139 | | sage-mt5-large | 64.7 | 83.8 | 73.0 | 140 | +------------------------------------------------+-----------+--------+------+ 141 | | gpt-3.5-turbo | 66.9 | 84.1 | 74.5 | 142 | +------------------------------------------------+-----------+--------+------+ 143 | | gpt-4 | 68.6 | 85.2 | 76.0 | 144 | +------------------------------------------------+-----------+--------+------+ 145 | | Bert (https://github.com/neuspell/neuspell) | 65.8 | 79.6 | 72.0 | 146 | +------------------------------------------------+-----------+--------+------+ 147 | | SC-LSTM (https://github.com/neuspell/neuspell) | 62.2 | 80.3 | 72.0 | 148 | +------------------------------------------------+-----------+--------+------+ 149 | 150 | 151 | **JFLEG** 152 | 153 | +------------------------------------------------+-----------+--------+------+ 154 | | Model | Precision | Recall | F1 | 155 | +================================================+===========+========+======+ 156 | | sage-mt5-large | 74.9 | 88.4 | 81.1 | 157 | +------------------------------------------------+-----------+--------+------+ 158 | | gpt-3.5-turbo | 77.8 | 88.6 | 82.9 | 159 | +------------------------------------------------+-----------+--------+------+ 160 | | gpt-4 | 77.9 | 88.3 | 82.8 | 161 | +------------------------------------------------+-----------+--------+------+ 162 | | Bert (https://github.com/neuspell/neuspell) | 78.5 | 85.4 | 81.8 | 163 | +------------------------------------------------+-----------+--------+------+ 164 | | SC-LSTM (https://github.com/neuspell/neuspell) | 80.6 | 86.1 | 83.2 | 165 | +------------------------------------------------+-----------+--------+------+ 166 | 167 | 168 | How to use 169 | ^^^^^^^^^^^^ 170 | 171 | .. code-block:: python 172 | 173 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 174 | 175 | tokenizer = AutoTokenizer.from_pretrained("ai-forever/sage-mt5-large") 176 | model = AutoModelForSeq2SeqLM.from_pretrained("ai-forever/sage-mt5-large", device_map='cuda') 177 | 178 | sentence = "Перведи мне текст на аглиском: \"Screw you kuys, I am goin hme (c)." 
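    # Note: no task prefix is used here (unlike the "grammar: " prefix in the T5-large-spell example);
    # the raw sentence is passed to the tokenizer as-is, and generation length is capped below at 1.5x the number of input tokens.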
179 | inputs = tokenizer(sentence, max_length=None, padding="longest", truncation=False, return_tensors="pt") 180 | outputs = model.generate(**inputs.to(model.device), max_length = inputs["input_ids"].size(1) * 1.5) 181 | print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) 182 | 183 | # ["Переведи мне текст на английском: "Screw you guys, I am going home" (c)."] 184 | 185 | 186 | API 187 | ^^^^^ 188 | 189 | .. autoclass:: sage.spelling_correction.t5_correctors.T5ModelForSpellingCorruption 190 | :members: 191 | :inherited-members: 192 | :show-inheritance: 193 | 194 | 195 | Limitations 196 | ^^^^^^^^^^^^ 197 | 198 | - For the Russian language the model is intended to be fine-tuned for better performance. 199 | 200 | Resources 201 | ^^^^^^^^^^^^ 202 | 203 | - `SAGE library `_, GitHub 204 | - `sage-fredt5-large `_, HuggingFace 205 | - `sage-fredt5-distilled-95m `_, HuggingFace 206 | - `sage-m2m100-1.2B `_, HuggingFace 207 | - `sage-mt5-large `_, HuggingFace 208 | 209 | 210 | License 211 | ^^^^^^^^^ 212 | 213 | Model `mt5-large `_, on the basis of which our solution is made, and its source code are supplied under the Apache-2.0 license. 214 | Our solution comes with MIT license. 215 | 216 | Specifications 217 | ^^^^^^^^^^^^^^^ 218 | - File size: 5 Gb; 219 | - Framework: pytorch 220 | - Format: AI Service 221 | - Version: v1.0 222 | - Developer: SberDevices, AGI NLP 223 | 224 | Contacts 225 | ^^^^^^^^^^^ 226 | nikita.martynov.98@list.ru 227 | 228 | -------------------------------------------------------------------------------- /docs/source/rst/spelling_corruption/Augmentex.rst: -------------------------------------------------------------------------------- 1 | ✏️ Augmentex 2 | ------------------- 3 | 4 | We implemented two methods for spelling corruption. **S**\ tatistic-\ **b**\ ased **S**\ pelling **C**\ orruption (\ **SBSC**\ ) aims 5 | to mimic human behaviour when making an error. While `Augmentex `_ relies on rule-based heuristics and common 6 | errors and mistypings especially those committed while typing text on a keyboard. 7 | 8 | 🚀 Both methods proved their effectiveness for spelling correction systems and celebrated substantial **performance gains** 9 | fully reported in our `Paper `_. 10 | 11 | 12 | **Augmentex** introduces rule-based and common statistic (empowered by `KartaSlov `_ project) 13 | approach to insert errors in text. It is fully described again in the `Paper `_ 14 | and in this 🗣️\ `Talk `_. 15 | 16 | 🖇️ Augmentex allows you to operate on two levels of granularity when it comes to text corruption and offers you sets of 17 | specific methods suited for particular level: 18 | 19 | 20 | * **Word level**\ : 21 | 22 | * *replace* - replace a random word with its incorrect counterpart; 23 | * *delete* - delete random word; 24 | * *swap* - swap two random words; 25 | * *stopword* - add random words from stop-list; 26 | * *reverse* - change a case of the first letter of a random word; 27 | 28 | * **Character level**\ : 29 | 30 | * *shift* - randomly swaps upper / lower case in a string; 31 | * *orfo* - substitute correct characters with their common incorrect counterparts; 32 | * *typo* - substitute correct characters as if they are mistyped on a keyboard; 33 | * *delete* - delete random character; 34 | * *multiply* - multiply random character; 35 | * *swap* - swap two adjacent characters; 36 | * *insert* - insert random character; 37 | 38 | To access Augmentex you only need these few manipulations: 39 | 40 | .. 
code-block:: python 41 | 42 | from sage.spelling_corruption import CharAugConfig, CharAugCorruptor 43 | 44 | config = CharAugConfig( 45 | unit_prob=0.3, # proportion of characters that are going to undergo edits 46 | min_aug=1, # minimum number of edits 47 | max_aug=5, # maximum number of edits 48 | mult_num=3 # `multiply` edit 49 | ) 50 | corruptor = CharAugCorruptor.from_config(config) 51 | 52 | ... or like this: 53 | 54 | .. code-block:: python 55 | 56 | from sage.spelling_corruption import WordAugConfig, WordAugCorruptor 57 | 58 | config = WordAugConfig( 59 | unit_prob=0.4, # proportion of words that are going to undergo edits 60 | min_aug=1, # minimum number of edits 61 | max_aug=5, # maximum number of edits 62 | ) 63 | corruptor = WordAugCorruptor.from_config(config) 64 | 65 | Augmentex was created by our fellow team; the project has its own `repo `_\ , so do not forget to take a look! 66 | -------------------------------------------------------------------------------- /docs/source/rst/spelling_corruption/SBSC.rst: -------------------------------------------------------------------------------- 1 | 📊 SBSC 2 | ------------------- 3 | 4 | We implemented two methods for spelling corruption. **S**\ tatistic-\ **b**\ ased **S**\ pelling **C**\ orruption (\ **SBSC**\ ) aims 5 | to mimic human behaviour when making an error, while `Augmentex `_ relies on rule-based heuristics and common 6 | errors and mistypings, especially those committed while typing text on a keyboard. 7 | 8 | 🚀 Both methods proved their effectiveness for spelling correction systems and brought substantial **performance gains**, 9 | fully reported in our `Paper `_. 10 | 11 | 12 | **SBSC** is thoroughly described in a separate `Paper `_ 13 | and in this 🗣️\ `Talk `_. 14 | 15 | Briefly, SBSC follows two simple steps: 16 | 17 | 18 | * 🧠 Analyze errors, their types and positions in a source text; 19 | * ✏️ Reproduce errors from the source text in a new sentence; 20 | 21 | 🧠 To analyze errors in a source sentence we need its corresponding correction in order to build a 22 | `Levenshtein matrix `_\ , traverse it back starting from the 23 | bottom-right entry and determine the exact position and type of each error. We then aggregate all obtained statistics and 24 | normalize them into valid discrete distributions. 25 | 26 | ✏️ The "reproduce" step is even simpler: we sample the number of errors per sentence, their types and relative 27 | positions from the corresponding distributions and apply them to a correct sentence. 28 | 29 | As stated, you need a parallel dataset to "fit" SBSC. We provide a set of four datasets with natural errors covering 30 | an exhaustive range of domains: 31 | 32 | 33 | * **RUSpellRU**\ : texts collected from `LiveJournal `_\ , with manually corrected typos and errors; 34 | * **MultidomainGold**\ : examples from 7 text sources, including the open web, news, social media, reviews, subtitles, policy documents and literary works; 35 | * **MedSpellChecker**\ : texts with errors from medical anamnesis; 36 | * **GitHubTypoCorpusRu**\ : spelling errors and typos in commits from GitHub; 37 | 38 | You can use them as simply as 39 | 40 | ..
code-block:: python 41 | 42 | import sage 43 | from sage.spelling_corruption import SBSCConfig, SBSCCorruptor 44 | from sage.utils import DatasetsAvailable 45 | 46 | # Instantiate SBSC corruptor from a dataset with errors in medical anamnesis 47 | config = SBSCConfig( 48 | reference_dataset_name_or_path=DatasetsAvailable.MedSpellchecker.name, 49 | reference_dataset_split="test" 50 | ) 51 | corruptor = SBSCCorruptor.from_config(config) 52 | 53 | ... or you can initialize your SBSC from locally stored dataset: 54 | 55 | .. code-block:: python 56 | 57 | import os 58 | from sage.spelling_corruption import SBSCConfig, SBSCCorruptor 59 | 60 | # Instantiate SBSC corruptor from a JFLEG dataset 61 | config = SBSCConfig( 62 | lang="en", 63 | reference_dataset_name_or_path=os.path.join("data", "example_data", "jfleg"), 64 | ) 65 | corruptor = SBSCCorruptor.from_config(config) 66 | 67 | ✅ To check how good SBSC actually approximates original errors, you can plot side-by-side graphs of original and 68 | synthetically generated distributions: 69 | 70 | 71 | |pic1| |pic2| 72 | 73 | .. |pic1| image:: ../../images/bea60k_side_by_side.jpg 74 | :width: 45% 75 | 76 | .. |pic2| image:: ../../images/ruspellru_side_by_side.jpg 77 | :width: 45% 78 | 79 | 80 | 81 | To access these graphs you can simply 82 | 83 | .. code-block:: python 84 | 85 | from sage.utils import load_available_dataset_from_hf, draw_and_save_errors_distributions_comparison_charts 86 | from sage.spelling_corruption.sbsc.labeler import process_mistypings 87 | from sage.spelling_corruption import SBSCCorruptor 88 | 89 | sources, corrections = load_available_dataset_from_hf("RUSpellRU", for_labeler=True, split="train") 90 | ruspellru_stats, ruspellru_confusion_matrix, ruspellru_typos_cnt = process_mistypings(sources, corrections) 91 | 92 | corruptor = SBSCCorruptor.from_default_config() 93 | spoiled_sentences = corruptor.batch_corrupt(corrections) 94 | 95 | sbsc_stats, sbsc_confusion_matrix, sbsc_typos_cnt = process_mistypings(spoiled_sentences, corrections) 96 | 97 | draw_and_save_errors_distributions_comparison_charts( 98 | actual_typos_cnt = sbsc_typos_cnt, 99 | reference_typos_cnt=ruspellru_typos_cnt, 100 | actual_stats=sbsc_stats, 101 | reference_stats=ruspellru_stats, 102 | path_to_save="ruspellru_sbsc.jpg" 103 | ) 104 | -------------------------------------------------------------------------------- /images/bea60k_side_by_side.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/images/bea60k_side_by_side.jpg -------------------------------------------------------------------------------- /images/ruspellru_side_by_side.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/images/ruspellru_side_by_side.jpg -------------------------------------------------------------------------------- /sage/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | __author__ = "Nikita Martynov, Mark Baushenko, Alexandr Abramov and Alena Fenogenova" 3 | __email__ = "nikita.martynov.98@list.ru" 4 | 5 | from .utils.lang_utils import AVAILABLE_LANG_CODES 6 | 7 | __all__ = [ 8 | "AVAILABLE_LANG_CODES", 9 | ] 10 | -------------------------------------------------------------------------------- /sage/evaluation/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .scorer import Scorer 2 | 3 | __all__ = [ 4 | "Scorer" 5 | ] 6 | -------------------------------------------------------------------------------- /sage/evaluation/readme.md: -------------------------------------------------------------------------------- 1 | ### Evaluation in SAGE 2 | 3 | #### RuERRANT 4 | 5 | RuERRANT is an adaptation of the ERRANT metric ([repo](https://github.com/chrisjbryant/errant), [paper-2016](https://aclanthology.org/C16-1079.pdf), [paper-2017](https://aclanthology.org/P17-1074.pdf)) to the Russian language. The adaptation was primarily done in https://github.com/Askinkaty/errant and further developed within SAGE. The changes to the original ERRANT implementation for English are the following: 6 | 7 | 1. Basic parsing model changed to Spacy's `ru_core_news_lg`. 8 | 1. Included a dictionary of Russian words (main forms). 9 | 1. Introduced detection of error correction types specific for Russian (degrees of adjectives, verb aspect). 10 | 1. [our contribution] Introduced a simplified error correction typology: 11 | - `CASE`: spelling corrections including only character case change; 12 | - `PUNCT`: punctuation corrections; 13 | - `YO`: spelling corrections regarding "е"/"ё" substitutions; 14 | - `SPELL`: all other word-level spelling corrections. 15 | 1. [our contribution] Introduced detection of multiple error correction types per word, e.g. "федор" -> "Фёдор" contains both CASE and YO corrections. 16 | 1. [our contribution] Introduced detection of inner word punctuation corrections which covers joint ("AB") vs. hyphen ("A-B") vs. space ("A B") word spelling. Corrections of this type are attributed to the `SPELL` category. 17 | 18 | #### Scoring 19 | 20 | To score model's corrections against gold corrections, use a Scorer instance: 21 | 22 | ```python 23 | from sage.evaluation.scorer import Scorer 24 | 25 | s = Scorer() 26 | 27 | s.score( 28 | ["спел Кейс ее .", "спел Кейс ее ."], 29 | ["спелл кейс её !", "спелл кейс её !"], 30 | ["спел кейс её .", "спелл Кейс ее !"], 31 | metrics=["errant"] 32 | ) 33 | >>> {'CASE_Precision': 100.0, 'CASE_Recall': 50.0, 'CASE_F1': 66.67, 34 | 'YO_Precision': 100.0, 'YO_Recall': 50.0, 'YO_F1': 66.67, 35 | 'SPELL_Precision': 100.0, 'SPELL_Recall': 50.0, 'SPELL_F1': 66.67, 36 | 'PUNCT_Precision': 100.0, 'PUNCT_Recall': 50.0, 'PUNCT_F1': 66.67} 37 | ``` 38 | -------------------------------------------------------------------------------- /sage/evaluation/ruerrant_wrapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/sage/evaluation/ruerrant_wrapper/__init__.py -------------------------------------------------------------------------------- /sage/evaluation/ruerrant_wrapper/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import defaultdict 4 | from string import punctuation 5 | 6 | import Levenshtein 7 | from errant.edit import Edit 8 | 9 | 10 | def edit_to_tuple(edit: Edit, idx: int = 0) -> tuple[int, int, str, str, int]: 11 | cor_toks_str = " ".join([tok.text for tok in edit.c_toks]) 12 | return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx] 13 | 14 | 15 | def classify(edit: Edit) -> list[Edit]: 16 | """Classifies an Edit via updating its `type` attribute.""" 17 | # Insertion and deletion 18 
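    # (empty o_toks with non-empty c_toks is an insertion; non-empty o_toks with empty c_toks is a deletion)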
| if ((not edit.o_toks and edit.c_toks) or (edit.o_toks and not edit.c_toks)): 19 | error_cats = get_one_sided_type(edit.o_toks, edit.c_toks) 20 | elif edit.o_toks != edit.c_toks: 21 | error_cats = get_two_sided_type(edit.o_toks, edit.c_toks) 22 | else: 23 | error_cats = {"NA": edit.c_toks[0].text} 24 | new_edit_list = [] 25 | if error_cats: 26 | for error_cat, correct_str in error_cats.items(): 27 | edit.type = error_cat 28 | edit_tuple = edit_to_tuple(edit) 29 | edit_tuple[3] = correct_str 30 | new_edit_list.append(edit_tuple) 31 | return new_edit_list 32 | 33 | 34 | def get_edit_info(toks): 35 | pos = [] 36 | dep = [] 37 | morph = dict() 38 | for tok in toks: 39 | pos.append(tok.tag_) 40 | dep.append(tok.dep_) 41 | morphs = str(tok.morph).split('|') 42 | for m in morphs: 43 | if len(m.strip()): 44 | k, v = m.strip().split('=') 45 | morph[k] = v 46 | return pos, dep, morph 47 | 48 | 49 | def get_one_sided_type(o_toks, c_toks): 50 | """Classifies a zero-to-one or one-to-zero error based on a token list.""" 51 | pos_list, _, _ = get_edit_info(o_toks if o_toks else c_toks) 52 | if "PUNCT" in pos_list or "SPACE" in pos_list: 53 | return {"PUNCT": c_toks[0].text if c_toks else ""} 54 | return {"SPELL": c_toks[0].text if c_toks else ""} 55 | 56 | 57 | def get_two_sided_type(o_toks, c_toks) -> dict[str, str]: 58 | """Classifies a one-to-one or one-to-many or many-to-one error based on token lists.""" 59 | # one-to-one cases 60 | if len(o_toks) == len(c_toks) == 1: 61 | if ( 62 | all(char in punctuation + " " for char in o_toks[0].text) and 63 | all(char in punctuation + " " for char in c_toks[0].text) 64 | ): 65 | return {"PUNCT": c_toks[0].text} 66 | source_w, correct_w = o_toks[0].text, c_toks[0].text 67 | if source_w != correct_w: 68 | # if both string are lowercase or both are uppercase, 69 | # and there is no "ё" in both, then it may be only "SPELL" error type 70 | if (((source_w.islower() and correct_w.islower()) or 71 | (source_w.isupper() and correct_w.isupper())) and 72 | "ё" not in source_w + correct_w): 73 | return {"SPELL": correct_w} 74 | # edits with multiple errors (e.g. SPELL + CASE) 75 | # Step 1. Make char-level Levenstein table 76 | char_edits = Levenshtein.editops(source_w, correct_w) 77 | # Step 2. Classify operations (CASE, YO, SPELL) 78 | edits_classified = classify_char_edits(char_edits, source_w, correct_w) 79 | # Step 3. 
Combine the same-typed errors into minimal string pairs 80 | separated_edits = get_edit_strings(source_w, correct_w, edits_classified) 81 | return separated_edits 82 | # one-to-many and many-to-one cases 83 | if all(char in punctuation + " " for char in o_toks.text + c_toks.text): 84 | return {"PUNCT": c_toks.text} 85 | joint_corr_str = " ".join([tok.text for tok in c_toks]) 86 | joint_corr_str = joint_corr_str.replace("- ", "-").replace(" -", "-") 87 | return {"SPELL": joint_corr_str} 88 | 89 | 90 | def classify_char_edits(char_edits, source_w, correct_w): 91 | """Classifies char-level Levenstein operations into SPELL, YO and CASE.""" 92 | edits_classified = [] 93 | for edit in char_edits: 94 | if edit[0] == "replace": 95 | if "ё" in [source_w[edit[1]], correct_w[edit[2]]]: 96 | edits_classified.append((*edit, "YO")) 97 | elif source_w[edit[1]].lower() == correct_w[edit[2]].lower(): 98 | edits_classified.append((*edit, "CASE")) 99 | else: 100 | if ( 101 | (source_w[edit[1]].islower() and correct_w[edit[2]].isupper()) or 102 | (source_w[edit[1]].isupper() and correct_w[edit[2]].islower()) 103 | ): 104 | edits_classified.append((*edit, "CASE")) 105 | edits_classified.append((*edit, "SPELL")) 106 | else: 107 | edits_classified.append((*edit, "SPELL")) 108 | return edits_classified 109 | 110 | 111 | def get_edit_strings(source: str, correction: str, 112 | edits_classified: list[tuple]) -> dict[str, str]: 113 | """ 114 | Applies classified (SPELL, YO and CASE) char operations to source word separately. 115 | Returns a dict mapping error type to source string with corrections of this type only. 116 | """ 117 | separated_edits = defaultdict(lambda: source) 118 | shift = 0 # char position shift to consider on deletions and insertions 119 | for edit in edits_classified: 120 | edit_type = edit[3] 121 | curr_src = separated_edits[edit_type] 122 | if edit_type == "CASE": # SOURCE letter spelled in CORRECTION case 123 | if correction[edit[2]].isupper(): 124 | correction_char = source[edit[1]].upper() 125 | else: 126 | correction_char = source[edit[1]].lower() 127 | else: 128 | if edit[0] == "delete": 129 | correction_char = "" 130 | elif edit[0] == "insert": 131 | correction_char = correction[edit[2]] 132 | elif source[edit[1]].isupper(): 133 | correction_char = correction[edit[2]].upper() 134 | else: 135 | correction_char = correction[edit[2]].lower() 136 | if edit[0] == "replace": 137 | separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \ 138 | curr_src[edit[1]+shift + 1:] 139 | elif edit[0] == "delete": 140 | separated_edits[edit_type] = curr_src[:edit[1] + shift] + \ 141 | curr_src[edit[1]+shift + 1:] 142 | shift -= 1 143 | elif edit[0] == "insert": 144 | separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \ 145 | curr_src[edit[1]+shift:] 146 | shift += 1 147 | return dict(separated_edits) 148 | -------------------------------------------------------------------------------- /sage/evaluation/ruerrant_wrapper/merger.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import itertools 4 | import re 5 | from string import punctuation 6 | 7 | import Levenshtein 8 | from errant.alignment import Alignment 9 | from errant.edit import Edit 10 | 11 | 12 | def get_rule_edits(alignment: Alignment) -> list[Edit]: 13 | """Groups word-level alignment according to merging rules.""" 14 | edits = [] 15 | # Split alignment into groups 16 | alignment_groups = group_alignment(alignment, 
"new") 17 | for op, group in alignment_groups: 18 | group = list(group) 19 | # Ignore M 20 | if op == "M": 21 | continue 22 | # T is always split 23 | if op == "T": 24 | for seq in group: 25 | edits.append(Edit(alignment.orig, alignment.cor, seq[1:])) 26 | # Process D, I and S subsequence 27 | else: 28 | processed = process_seq(group, alignment) 29 | # Turn the processed sequence into edits 30 | for seq in processed: 31 | edits.append(Edit(alignment.orig, alignment.cor, seq[1:])) 32 | return edits 33 | 34 | 35 | def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]: 36 | """ 37 | Does initial alignment grouping: 38 | 1. Make groups of MDM, MIM od MSM. 39 | 2. In remaining operations, make groups of Ms, groups of Ts, and D/I/Ss. 40 | Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS]. 41 | 3. Sort groups by the order in which they appear in the alignment. 42 | """ 43 | if mode == "new": 44 | op_groups = [] 45 | # Format operation types sequence as string to use regex sequence search 46 | all_ops_seq = "".join([op[0][0] for op in alignment.align_seq]) 47 | # Find M[DIS]M groups and merge (need them to detect hyphen vs. space spelling) 48 | ungrouped_ids = list(range(len(alignment.align_seq))) 49 | for match in re.finditer("M[DIS]M", all_ops_seq): 50 | start, end = match.start(), match.end() 51 | op_groups.append(("MSM", alignment.align_seq[start:end])) 52 | for idx in range(start, end): 53 | ungrouped_ids.remove(idx) 54 | # Group remaining operations by default rules (groups of M, T and rest) 55 | if ungrouped_ids: 56 | def get_group_type(operation): 57 | return operation if operation in {"M", "T"} else "DIS" 58 | curr_group = [alignment.align_seq[ungrouped_ids[0]]] 59 | last_oper_type = get_group_type(curr_group[0][0][0]) 60 | for i, idx in enumerate(ungrouped_ids[1:], start=1): 61 | operation = alignment.align_seq[idx] 62 | oper_type = get_group_type(operation[0][0]) 63 | if (oper_type == last_oper_type and 64 | (idx - ungrouped_ids[i-1] == 1 or oper_type in {"M", "T"})): 65 | curr_group.append(operation) 66 | else: 67 | op_groups.append((last_oper_type, curr_group)) 68 | curr_group = [operation] 69 | last_oper_type = oper_type 70 | if curr_group: 71 | op_groups.append((last_oper_type, curr_group)) 72 | # Sort groups by the start id of the first group entry 73 | op_groups = sorted(op_groups, key=lambda x: x[1][0][1]) 74 | else: 75 | grouped = itertools.groupby(alignment.align_seq, 76 | lambda x: x[0][0] if x[0][0] in {"M", "T"} else False) 77 | op_groups = [(op, list(group)) for op, group in grouped] 78 | return op_groups 79 | 80 | 81 | def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]: 82 | """Applies merging rules to previously formed alignment groups (`seq`).""" 83 | # Return single alignments 84 | if len(seq) <= 1: 85 | return seq 86 | # Get the ops for the whole sequence 87 | ops = [op[0] for op in seq] 88 | 89 | # Get indices of all start-end combinations in the seq: 012 = 01, 02, 12 90 | combos = list(itertools.combinations(range(0, len(seq)), 2)) 91 | # Sort them starting with largest spans first 92 | combos.sort(key=lambda x: x[1] - x[0], reverse=True) 93 | # Loop through combos 94 | for start, end in combos: 95 | # Ignore ranges that do NOT contain a substitution, deletion or insertion. 96 | if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]): 97 | continue 98 | # Merge all D xor I ops. (95% of human multi-token edits contain S). 
99 | if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}: 100 | return (process_seq(seq[:start], alignment) 101 | + merge_edits(seq[start:end + 1]) 102 | + process_seq(seq[end + 1:], alignment)) 103 | # Get the tokens in orig and cor. 104 | o = alignment.orig[seq[start][1]:seq[end][2]] 105 | c = alignment.cor[seq[start][3]:seq[end][4]] 106 | if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]: 107 | # merge hyphens 108 | if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c): 109 | return (process_seq(seq[:start], alignment) 110 | + merge_edits(seq[start:end + 1]) 111 | + process_seq(seq[end + 1:], alignment)) 112 | # if it is not a hyphen-space edit, return only punct edit 113 | return seq[start + 1: end] 114 | # Merge possessive suffixes: [friends -> friend 's] 115 | if o[-1].tag_ == "POS" or c[-1].tag_ == "POS": 116 | return (process_seq(seq[:end - 1], alignment) 117 | + merge_edits(seq[end - 1:end + 1]) 118 | + process_seq(seq[end + 1:], alignment)) 119 | # Case changes 120 | if o[-1].lower == c[-1].lower: 121 | # Merge first token I or D: [Cat -> The big cat] 122 | if (start == 0 and 123 | (len(o) == 1 and c[0].text[0].isupper()) or 124 | (len(c) == 1 and o[0].text[0].isupper())): 125 | return (merge_edits(seq[start:end + 1]) 126 | + process_seq(seq[end + 1:], alignment)) 127 | # Merge with previous punctuation: [, we -> . We], [we -> . We] 128 | if (len(o) > 1 and is_punct(o[-2])) or \ 129 | (len(c) > 1 and is_punct(c[-2])): 130 | return (process_seq(seq[:end - 1], alignment) 131 | + merge_edits(seq[end - 1:end + 1]) 132 | + process_seq(seq[end + 1:], alignment)) 133 | # Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway] 134 | s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o])) 135 | t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c])) 136 | if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""): 137 | return (process_seq(seq[:start], alignment) 138 | + merge_edits(seq[start:end + 1]) 139 | + process_seq(seq[end + 1:], alignment)) 140 | # Merge same POS or auxiliary/infinitive/phrasal verbs: 141 | # [to eat -> eating], [watch -> look at] 142 | pos_set = set([tok.pos for tok in o] + [tok.pos for tok in c]) 143 | if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})): 144 | return (process_seq(seq[:start], alignment) 145 | + merge_edits(seq[start:end + 1]) 146 | + process_seq(seq[end + 1:], alignment)) 147 | # Split rules take effect when we get to smallest chunks 148 | if end - start < 2: 149 | # Split adjacent substitutions 150 | if len(o) == len(c) == 2: 151 | return (process_seq(seq[:start + 1], alignment) 152 | + process_seq(seq[start + 1:], alignment)) 153 | # Split similar substitutions at sequence boundaries 154 | if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or 155 | (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)): 156 | return (process_seq(seq[:start + 1], alignment) 157 | + process_seq(seq[start + 1:], alignment)) 158 | # Split final determiners 159 | if (end == len(seq) - 1 and 160 | ((ops[-1] in {"D", "S"} and o[-1].pos == "DET") or 161 | (ops[-1] in {"I", "S"} and c[-1].pos == "DET"))): 162 | return process_seq(seq[:-1], alignment) + [seq[-1]] 163 | return seq 164 | 165 | 166 | def is_punct(token) -> bool: 167 | return token.text in punctuation 168 | 169 | 170 | def char_cost(a: str, b: str) -> float: 171 | """Calculate the cost of character alignment; i.e. 
char similarity.""" 172 | 173 | return Levenshtein.ratio(a, b) 174 | 175 | 176 | def merge_edits(seq: list[tuple]) -> list[tuple]: 177 | """Merge the input alignment sequence to a single edit span.""" 178 | 179 | if seq: 180 | return [("X", seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])] 181 | return seq 182 | -------------------------------------------------------------------------------- /sage/evaluation/ruerrant_wrapper/scorer.py: -------------------------------------------------------------------------------- 1 | """A wrapper over the 'errant' library fork from https://github.com/Askinkaty/errant/. 2 | 3 | This implemetation brings some changes over the fork (which provided initial adaptation 4 | of the ERRANT metric for the Russian language). The changes deal with token merging and 5 | error classification and are described in detail in the readme. 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import re 11 | from collections import Counter, namedtuple 12 | from typing import Iterable 13 | from tqdm.auto import tqdm 14 | 15 | from errant.annotator import Annotator 16 | from errant.commands.compare_m2 import process_edits 17 | from errant.commands.compare_m2 import evaluate_edits 18 | from errant.commands.compare_m2 import merge_dict 19 | from errant.edit import Edit 20 | import spacy 21 | from spacy.tokenizer import Tokenizer 22 | from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex 23 | 24 | from sage.evaluation.ruerrant_wrapper import classifier 25 | from sage.evaluation.ruerrant_wrapper import merger 26 | 27 | 28 | def update_spacy_tokenizer(nlp): 29 | """ 30 | Changes Spacy tokenizer to parse additional patterns. 31 | """ 32 | infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + ["\]\("]) 33 | simple_url_re = re.compile(r'''^https?://''') 34 | nlp.tokenizer = Tokenizer( 35 | nlp.vocab, 36 | prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + ['\\\\\"']).search, 37 | suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + ['\\\\']).search, 38 | infix_finditer=infix_re.finditer, 39 | token_match=None, 40 | url_match=simple_url_re.match 41 | ) 42 | return nlp 43 | 44 | 45 | class RuErrantScorer: 46 | """A scorer to evaluate spelling correction triplets with ERRANT metric.""" 47 | 48 | def __init__(self) -> None: 49 | self.annotator = Annotator("ru", 50 | nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")), 51 | merger=merger, 52 | classifier=classifier) 53 | 54 | def annotate_errors(self, orig: str, cor: str, merging: str = "rules") -> list[Edit]: 55 | """ 56 | Overrides `Annotator.annotate()` function to allow multiple errors per token. 57 | This is nesessary to parse combined errors, e.g.: 58 | ["werd", "Word"] >>> Errors: ["SPELL", "CASE"] 59 | The `classify()` method called inside is implemented in ruerrant_classifier.py 60 | (also overrides the original classifier). 61 | """ 62 | 63 | alignment = self.annotator.align(orig, cor, False) 64 | edits = self.annotator.merge(alignment, merging) 65 | classified_edits = [] 66 | for edit in edits: 67 | classified_edits.extend(self.annotator.classify(edit)) 68 | return sorted(classified_edits, key=lambda x: (x[0], x[2])) 69 | 70 | def evaluate(self, 71 | sources: Iterable[str], 72 | corrections: Iterable[str], 73 | answers: Iterable[str]) -> dict[str, tuple[float, float, float]]: 74 | """ 75 | Evaluates iterables of sources, hyp and ref corrections with ERRANT metric. 
76 | 77 | Args: 78 | sources (Iterable[str]): an iterable of source texts; 79 | corrections (Iterable[str]): an iterable of gold corrections for the source texts; 80 | answers (Iterable[str]): an iterable of evaluated corrections for the source texts; 81 | 82 | Returns: 83 | dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding 84 | P, R, F1 metric values. 85 | """ 86 | 87 | best_dict = Counter({"tp": 0, "fp": 0, "fn": 0}) 88 | best_cats = {} 89 | sents = zip(sources, corrections, answers) 90 | pb = tqdm(sents, desc="Calculating errant metric", total=len(sources)) 91 | for sent_id, sent in enumerate(pb): 92 | src = self.annotator.parse(sent[0]) 93 | ref = self.annotator.parse(sent[1]) 94 | hyp = self.annotator.parse(sent[2]) 95 | # Align hyp and ref corrections and annotate errors 96 | hyp_edits = self.annotate_errors(src, hyp) 97 | ref_edits = self.annotate_errors(src, ref) 98 | # Process the edits for detection/correction based on args 99 | ProcessingArgs = namedtuple("ProcessingArgs", 100 | ["dt", "ds", "single", "multi", "filt", "cse"], 101 | defaults=[False, False, False, False, [], True]) 102 | processing_args = ProcessingArgs() 103 | hyp_dict = process_edits(hyp_edits, processing_args) 104 | ref_dict = process_edits(ref_edits, processing_args) 105 | # Evaluate edits and get best TP, FP, FN hyp+ref combo. 106 | EvaluationArgs = namedtuple("EvaluationArgs", 107 | ["beta", "verbose"], 108 | defaults=[1.0, False]) 109 | evaluation_args = EvaluationArgs() 110 | count_dict, cat_dict = evaluate_edits( 111 | hyp_dict, ref_dict, best_dict, sent_id, evaluation_args) 112 | # Merge these dicts with best_dict and best_cats 113 | best_dict += Counter(count_dict) # corpus-level TP, FP, FN 114 | best_cats = merge_dict(best_cats, cat_dict) # corpus-level errortype-wise TP, FP, FN 115 | cat_prf = {} 116 | for cat, values in best_cats.items(): 117 | tp, fp, fn = values # fp - extra corrections, fn - missed corrections 118 | p = float(tp) / (tp + fp) if tp + fp else 1.0 119 | r = float(tp) / (tp + fn) if tp + fn else 1.0 120 | f = (2 * p * r) / (p + r) if p + r else 0.0 121 | cat_prf[cat] = (p, r, f) 122 | 123 | for error_category in ["CASE", "PUNCT", "SPELL", "YO"]: 124 | if error_category not in cat_prf: 125 | cat_prf[error_category] = (1.0, 1.0, 1.0) 126 | 127 | return cat_prf 128 | -------------------------------------------------------------------------------- /sage/evaluation/scorer.py: -------------------------------------------------------------------------------- 1 | """Generic evaluator for spelling correction task.""" 2 | from __future__ import annotations 3 | 4 | import warnings 5 | from typing import Iterable 6 | 7 | from sage.evaluation.ruerrant_wrapper.scorer import RuErrantScorer 8 | from sage.evaluation.ruspelleval import evaluation as calculate_ruspelleval_metric 9 | 10 | 11 | class Scorer: 12 | """ 13 | Generic evaluator for spelling correction task. 14 | Specific evaluation function calls are implemented in the `score()` function. 15 | 16 | If it is not planned to use "errant" metric with a particular class instance, 17 | consider passing `load_errant=False` to optimize for time and memory. 18 | 19 | Attributes: 20 | errant: a RuErrantScorer instance (unless Scorer is initialized with load_errant=False). 
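
    Example (a minimal usage sketch, scoring with `ruspelleval` only so that ERRANT need not be loaded):

        scorer = Scorer(load_errant=False)
        metrics = scorer.score(
            ["спел Кейс ее ."],       # sources
            ["спелл кейс её !"],      # gold corrections
            ["спел кейс её ."],       # answers to evaluate
            metrics=["ruspelleval"],
        )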
21 | """ 22 | 23 | def __init__(self, load_errant=True) -> None: 24 | if load_errant: 25 | self.errant = RuErrantScorer() 26 | else: 27 | self.errant = None 28 | 29 | def score(self, sources: Iterable[str], corrections: Iterable[str], answers: Iterable[str], 30 | metrics: Iterable[str]) -> dict[str, float]: 31 | """ 32 | Evaluate spelling correction using the specified metrics. 33 | 34 | Args: 35 | sources (Iterable[str]): an iterable of source texts; 36 | corrections (Iterable[str]): an iterable of gold corrections for the source texts; 37 | answers (Iterable[str]): an iterable of evaluated corrections for the source texts; 38 | metrics (Iterable[str]): an iterable of metric to evaluate with; 39 | 40 | Returns: 41 | dict[str, float]: a dict mapping metric names to their values 42 | (the names may not be the same as in the `metrics` arg). 43 | """ 44 | 45 | if metrics: 46 | for metric in metrics: 47 | if metric == "errant": 48 | if self.errant is None: 49 | raise AttributeError( 50 | "You called for `errant` metric which has not been loaded.", 51 | "To use, reinitialize the Scorer with `load_errant=True`.") 52 | elif metric != "ruspelleval": 53 | raise ValueError(f"You provided a wrong metric name: `{metric}`.", 54 | "Available metrics are: [`errant`, `ruspelleval`].") 55 | else: 56 | raise ValueError("The `metrics` argument must contain at least one metric name.") 57 | if isinstance(sources, str) or isinstance(corrections, str) or isinstance(answers, str): 58 | raise ValueError("The `sources`, `corrections`, and `answers` arguments", 59 | "must be iterables of strings.") 60 | if "" in sources or "" in corrections: 61 | # probably too greedy condition (spacy in errant cannot parse empty strings) 62 | raise ValueError("All input strings must not be empty.") 63 | if "" in answers: 64 | warnings.warn("Some of the answers are empty. They will be removed from the evaluation.", UserWarning) 65 | sources = [source for source, answer in zip(sources, answers) if answer] 66 | corrections = [correction for correction, answer in zip(corrections, answers) if answer] 67 | answers = [answer for answer in answers if answer] 68 | result = {} 69 | for metric in metrics: 70 | if metric == "errant" and self.errant is not None: 71 | metrics_by_cats = self.errant.evaluate(sources, corrections, answers) 72 | result_dict = {} 73 | metrics = ["Precision", "Recall", "F1"] 74 | for cat, values in metrics_by_cats.items(): 75 | for metric_name, metric_value in zip(metrics, values): 76 | result_dict[f"{cat}_{metric_name}"] = round(float(metric_value) * 100, 2) 77 | result.update(result_dict) 78 | elif metric == "ruspelleval": 79 | result.update(calculate_ruspelleval_metric(sources, corrections, answers)) 80 | return result 81 | -------------------------------------------------------------------------------- /sage/spelling_correction/__init__.py: -------------------------------------------------------------------------------- 1 | from .m2m_correctors import RuM2M100ModelForSpellingCorrection 2 | from .t5_correctors import T5ModelForSpellingCorruption 3 | from .corrector import AvailableCorrectors 4 | 5 | __all__ = [ 6 | "RuM2M100ModelForSpellingCorrection", 7 | "T5ModelForSpellingCorruption", 8 | "AvailableCorrectors" 9 | ] 10 | -------------------------------------------------------------------------------- /sage/spelling_correction/corrector.py: -------------------------------------------------------------------------------- 1 | """Abstract API to spelling correction models. 
2 | 3 | The file also contains available pre-trained models for spelling correction 4 | in Russian and English (yet more is to come). 5 | 6 | To see all available models: 7 | 8 | models = [model.name for model in AvailableCorrectors] 9 | 10 | To launch one of the available models: 11 | 12 | model_path = AvailableCorrectors.m2m100_1B.value 13 | ... # pass model path for initialization 14 | 15 | """ 16 | 17 | import os 18 | import enum 19 | from abc import ABCMeta, abstractmethod 20 | from typing import List, Union, Dict, Optional, Any 21 | 22 | import pandas as pd 23 | 24 | from ..evaluation.scorer import Scorer 25 | from ..utils.data_load_utils import load_available_dataset_from_hf, DatasetsAvailable 26 | 27 | 28 | datasets_available = [dataset.name for dataset in DatasetsAvailable] 29 | 30 | 31 | class AvailableCorrectors(enum.Enum): 32 | """Available models for spelling and punctuation correction""" 33 | 34 | sage_fredt5_large = "ai-forever/sage-fredt5-large" 35 | sage_fredt5_distilled_95m = "ai-forever/sage-fredt5-distilled-95m" 36 | sage_m2m100_1B = "ai-forever/sage-m2m100-1.2B" 37 | sage_mt5_large = "ai-forever/sage-mt5-large" 38 | 39 | m2m100_1B = "ai-forever/RuM2M100-1.2B" 40 | m2m100_418M = "ai-forever/RuM2M100-418M" 41 | fred_large = "ai-forever/FRED-T5-large-spell" 42 | ent5_large = "ai-forever/T5-large-spell" 43 | 44 | 45 | class Corrector(metaclass=ABCMeta): 46 | """Base class for all correctors.""" 47 | 48 | @classmethod 49 | @abstractmethod 50 | def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): 51 | pass 52 | 53 | def correct(self, sentence: str, prefix: Optional[str] = "", **generation_params) -> List[str]: 54 | """ 55 | Corrects a single input sentence. 56 | 57 | :param sentence: a source sentence; 58 | :type sentence: str 59 | :param prefix: some models need some sort of a prompting; 60 | :type prefix: str 61 | :param generation_params: parameters passed to `generate` method of a HuggingFace model; 62 | :type generation_params: dict 63 | :return: corresponding corrected sentence 64 | :rtype: list of str 65 | """ 66 | return self.batch_correct([sentence], 1, prefix, **generation_params)[-1] 67 | 68 | def evaluate( 69 | self, 70 | dataset_name_or_path: Optional[Union[str, os.PathLike]], 71 | metrics: List, 72 | batch_size: int, 73 | prefix: str = "", 74 | dataset_split: str = "test", 75 | **generation_params, 76 | ) -> Dict[str, float]: 77 | """ 78 | Evaluate the particular model on the spellcheck datasets. 
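        A typical call might look like this (illustrative; the path points to the small
        example dataset shipped with the repository):

            corrector.evaluate("data/example_data/RUSpellRU", metrics=["ruspelleval"], batch_size=32)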
79 | 80 | :param dataset_name_or_path: a path to a locally situated dataset or a name of a dataset on HuggingFace; 81 | :type dataset_name_or_path: str 82 | :param metrics: set of metrics to be used to report performance; 83 | :type metrics: list of str 84 | :param batch_size: size of subsample of input sentences; 85 | :type batch_size: int 86 | :param prefix: some models need some sort of a prompting; 87 | :type prefix: str 88 | :param dataset_split: train / test / dev part to be evaluated on; 89 | :type dataset_split: str 90 | :param generation_params: parameters passed to `generate` method of a HuggingFace model; 91 | :type generation_params: dict 92 | :return: mapping between metric's name and its corresponding value 93 | :rtype: dict[str, float] 94 | """ 95 | dataset_name_or_path = str(dataset_name_or_path) 96 | if dataset_name_or_path in datasets_available: 97 | sources, corrections = load_available_dataset_from_hf( 98 | dataset_name_or_path, for_labeler=True, split=dataset_split) 99 | elif os.path.isdir(dataset_name_or_path): 100 | if os.path.isfile(os.path.join(dataset_name_or_path, "sources.txt")) and \ 101 | os.path.isfile(os.path.join(dataset_name_or_path, "corrections.txt")): 102 | src_file = open(os.path.join(dataset_name_or_path, "sources.txt"), encoding="utf8") 103 | corr_file = open(os.path.join(dataset_name_or_path, "corrections.txt"), encoding="utf8") 104 | sources = src_file.read().split("\n") 105 | corrections = corr_file.read().split("\n") 106 | src_file.close() 107 | corr_file.close() 108 | if len(sources) != len(corrections): 109 | raise RuntimeError("Sources and corrections must be of the same length, but get {} vs {}".format( 110 | len(sources), len(corrections))) 111 | elif os.path.isfile(os.path.join(dataset_name_or_path, "data.csv")): 112 | try: 113 | data = pd.read_csv(os.path.join(dataset_name_or_path, "data.csv")) 114 | except Exception as e: 115 | raise RuntimeError("Wrong format of file {}. 
Raised an error: {}".format( 116 | os.path.join(dataset_name_or_path, "data.csv"), str(e))) 117 | if not ("source" in data and "correction" in data): 118 | raise RuntimeError("You must provide 'source' and 'correction' columns in {}".format( 119 | os.path.join(dataset_name_or_path, "data.csv") 120 | )) 121 | if data.isna().any().max(): 122 | raise ValueError("Your data at {} contain unnecessary nans".format( 123 | os.path.join(dataset_name_or_path, "data.csv"))) 124 | sources = data.source.values.tolist() 125 | corrections = data.correction.values.tolist() 126 | else: 127 | raise RuntimeError("You must provide either 'data.csv' or 'sources.txt'/'corrections.txt' in {}".format( 128 | dataset_name_or_path 129 | )) 130 | else: 131 | raise ValueError("You must provide either valid path or available dataset's name, you provided {}".format( 132 | dataset_name_or_path 133 | )) 134 | 135 | answers = self.batch_correct(sources, batch_size, prefix, **generation_params) 136 | if "num_return_sequences" in generation_params and generation_params["num_return_sequences"] > 1: 137 | num_sequences = generation_params["num_return_sequences"] 138 | answers = [batch_answers[::num_sequences] for batch_answers in answers] 139 | answers = sum(answers, []) 140 | scorer = Scorer("errant" in metrics) 141 | metrics_dict = scorer.score(sources, corrections, answers, metrics) 142 | return metrics_dict 143 | 144 | @abstractmethod 145 | def batch_correct( 146 | self, 147 | sentences: List[str], 148 | batch_size: int, 149 | prefix: Optional[str] = "", 150 | **generation_params, 151 | ) -> List[List[Any]]: 152 | """Correct multiple sentences""" 153 | -------------------------------------------------------------------------------- /sage/spelling_correction/m2m_correctors.py: -------------------------------------------------------------------------------- 1 | """API to M2M100-based models for spelling correction. 2 | 3 | To load a model: 4 | 5 | from corrector import AvailableCorrectors 6 | 7 | model = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.m2m100_1B.value) 8 | ... 9 | """ 10 | 11 | import os 12 | from typing import List, Optional, Union, Any 13 | from tqdm.auto import tqdm 14 | from transformers import M2M100ForConditionalGeneration 15 | from transformers.models.m2m_100.tokenization_m2m_100 import M2M100Tokenizer 16 | 17 | from .corrector import Corrector 18 | 19 | 20 | class RuM2M100ModelForSpellingCorrection(Corrector): 21 | """M2M100-based models.""" 22 | 23 | def __init__(self, model_name_or_path: Union[str, os.PathLike]): 24 | """ 25 | Initialize the M2M100-type corrector from a pre-trained checkpoint. 26 | The latter can be either locally situated checkpoint or a name of a model on HuggingFace. 27 | 28 | NOTE: This method does not really load the weights, it just stores the path or name. 29 | 30 | :param model_name_or_path: the aforementioned name or path to checkpoint; 31 | :type model_name_or_path: str or os.PathLike; 32 | """ 33 | self.model_name_or_path = model_name_or_path 34 | 35 | @classmethod 36 | def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): 37 | """ 38 | Initialize the M2M100-type corrector from a pre-trained checkpoint. 39 | The latter can be either locally situated checkpoint or a name of a model on HuggingFace. 
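        Example (an illustrative sketch):

            from sage.spelling_correction import RuM2M100ModelForSpellingCorrection, AvailableCorrectors

            corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(
                AvailableCorrectors.m2m100_418M.value)
            corrector.correct("превет, как дила?")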
40 | 41 | :param model_name_or_path: the aforementioned name or path to checkpoint; 42 | :type model_name_or_path: str or os.PathLike 43 | :return: corrector initialized from pre-trained weights 44 | :rtype: object of :class:`RuM2M100ModelForSpellingCorrection` 45 | """ 46 | 47 | engine = cls(model_name_or_path) 48 | engine.model = M2M100ForConditionalGeneration.from_pretrained(model_name_or_path) 49 | engine.tokenizer = M2M100Tokenizer.from_pretrained(model_name_or_path, src_lang="ru", tgt_lang="ru") 50 | 51 | return engine 52 | 53 | def batch_correct( 54 | self, 55 | sentences: List[str], 56 | batch_size: int, 57 | prefix: Optional[str] = "", 58 | **generation_params, 59 | ) -> List[List[Any]]: 60 | """ 61 | Corrects multiple sentences. 62 | 63 | :param sentences: input sentences to correct; 64 | :type sentences: list of str 65 | :param batch_size: size of subsample of input sentences; 66 | :type batch_size: int 67 | :param prefix: some models need some sort of a prompting; 68 | :type prefix: str 69 | :param generation_params: parameters passed to `generate` method of a HuggingFace model; 70 | :type generation_params: dict 71 | :return: corresponding corrections 72 | :rtype: list of list of str 73 | """ 74 | if not hasattr(self, "model"): 75 | raise RuntimeError("Please load weights using `from_pretrained` method from one of the available models.") 76 | batches = [sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)] 77 | result = [] 78 | pb = tqdm(total=len(batches)) 79 | device = self.model.device 80 | if "forced_bos_token_id" in generation_params: 81 | generation_params.pop("forced_bos_token_id") 82 | for batch in batches: 83 | encodings = self.tokenizer.batch_encode_plus( 84 | batch, max_length=None, padding="longest", truncation=False, return_tensors='pt') 85 | for k, v in encodings.items(): 86 | encodings[k] = v.to(device) 87 | generated_tokens = self.model.generate( 88 | **encodings, **generation_params, forced_bos_token_id=self.tokenizer.get_lang_id("ru"), 89 | max_length=int(1.5*encodings["input_ids"].shape[1])) 90 | ans = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 91 | result.append(ans) 92 | pb.update(1) 93 | return result 94 | -------------------------------------------------------------------------------- /sage/spelling_correction/t5_correctors.py: -------------------------------------------------------------------------------- 1 | """API to T5-based models for spelling correction. 2 | 3 | To load a model: 4 | 5 | from corrector import AvailableCorrectors 6 | 7 | model = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.fred_large.value) 8 | ... 9 | """ 10 | 11 | import os 12 | from typing import List, Optional, Union, Any 13 | 14 | import torch 15 | from tqdm.auto import tqdm 16 | from transformers import T5ForConditionalGeneration, AutoTokenizer 17 | 18 | from .corrector import Corrector 19 | 20 | 21 | class T5ModelForSpellingCorruption(Corrector): 22 | """T5-based models.""" 23 | 24 | def __init__(self, model_name_or_path: Union[str, os.PathLike]): 25 | """ 26 | Initialize the T5-type corrector from a pre-trained checkpoint. 27 | The latter can be either locally situated checkpoint or a name of a model on HuggingFace. 28 | 29 | NOTE: This method does not really load the weights, it just stores the path or name. 
30 | 31 | :param model_name_or_path: the aforementioned name or path to checkpoint; 32 | :type model_name_or_path: str or os.PathLike; 33 | """ 34 | self.model_name_or_path = model_name_or_path 35 | self.max_model_length = 512 36 | 37 | @classmethod 38 | def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]): 39 | """ 40 | Initialize the T5-type corrector from a pre-trained checkpoint. 41 | The latter can be either locally situated checkpoint or a name of a model on HuggingFace. 42 | 43 | :param model_name_or_path: the aforementioned name or path to checkpoint; 44 | :type model_name_or_path: str or os.PathLike 45 | :return: corrector initialized from pre-trained weights 46 | :rtype: object of :class:`T5ModelForSpellingCorruption` 47 | """ 48 | engine = cls(model_name_or_path) 49 | engine.model = T5ForConditionalGeneration.from_pretrained(model_name_or_path) 50 | engine.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 51 | 52 | return engine 53 | 54 | def batch_correct( 55 | self, 56 | sentences: List[str], 57 | batch_size: int, 58 | prefix: Optional[str] = "", 59 | **generation_params, 60 | ) -> List[List[Any]]: 61 | """ 62 | Corrects multiple sentences. 63 | 64 | :param sentences: input sentences to correct; 65 | :type sentences: list of str 66 | :param batch_size: size of subsample of input sentences; 67 | :type batch_size: int 68 | :param prefix: some models need some sort of a prompting; 69 | :type prefix: str 70 | :param generation_params: parameters passed to `generate` method of a HuggingFace model; 71 | :type generation_params: dict 72 | :return: corresponding corrections 73 | :rtype: list of list of str 74 | """ 75 | if not hasattr(self, "model"): 76 | raise RuntimeError("Please load weights using `from_pretrained` method from one of the available models.") 77 | batches = [sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)] 78 | result = [] 79 | pb = tqdm(total=len(batches)) 80 | device = self.model.device 81 | for batch in batches: 82 | batch_prefix = [prefix + sentence for sentence in batch] 83 | with torch.inference_mode(): 84 | encodings = self.tokenizer.batch_encode_plus( 85 | batch_prefix, max_length=None, padding="longest", truncation=False, return_tensors='pt') 86 | for k, v in encodings.items(): 87 | encodings[k] = v.to(device) 88 | generated_tokens = self.model.generate( 89 | **encodings, **generation_params, max_length=encodings['input_ids'].size(1) * 1.5) 90 | ans = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 91 | result.append(ans) 92 | pb.update(1) 93 | return result 94 | -------------------------------------------------------------------------------- /sage/spelling_corruption/__init__.py: -------------------------------------------------------------------------------- 1 | from .corruptor import WordAugCorruptor, CharAugCorruptor, SBSCCorruptor 2 | from .configuration_corruptor import WordAugConfig, CharAugConfig, SBSCConfig 3 | from .sbsc.labeler import TyposTypes 4 | 5 | __all__ = [ 6 | "WordAugCorruptor", 7 | "CharAugCorruptor", 8 | "SBSCCorruptor", 9 | "WordAugConfig", 10 | "CharAugConfig", 11 | "SBSCConfig", 12 | "TyposTypes", 13 | ] 14 | -------------------------------------------------------------------------------- /sage/spelling_corruption/configuration_corruptor.py: -------------------------------------------------------------------------------- 1 | """Configuration classes for corruption methods. 
2 | 3 | Currently, three options are maintained: word- and char-level Augmentex and SBSC (Statistic-based 4 | spelling corruption). 5 | 6 | Examples: 7 | from corruptor import WordAugCorruptor 8 | 9 | config = WordAugConfig() 10 | corruptor = WordAugCorruptor.from_config(config) 11 | 12 | ... 13 | 14 | from corruptor import SBSCCorruptor 15 | 16 | config = SBSCConfig( 17 | lang="ru", 18 | reference_dataset_name_or_path="RUSpellRU" 19 | ) 20 | corruptor = SBSCCorruptor.from_config(config) 21 | """ 22 | 23 | import os 24 | from dataclasses import dataclass, field 25 | from typing import List, Dict, Union, Optional 26 | 27 | 28 | @dataclass 29 | class WordAugConfig: 30 | """Word-level Augmentex config. 31 | 32 | Attributes: 33 | min_aug (int): The minimum amount of augmentation. Defaults to 1. 34 | max_aug (int): The maximum amount of augmentation. Defaults to 5. 35 | unit_prob (float): Percentage of the phrase to which augmentations will be applied. Defaults to 0.3. 36 | """ 37 | min_aug: Optional[int] = field( 38 | default=1, 39 | metadata={"help": "The minimum amount of augmentation. Defaults to 1."}, 40 | ) 41 | 42 | max_aug: Optional[int] = field( 43 | default=5, 44 | metadata={"help": "The maximum amount of augmentation. Defaults to 5."}, 45 | ) 46 | 47 | unit_prob: Optional[float] = field( 48 | default=0.3, 49 | metadata={ 50 | "help": "Percentage of the phrase to which augmentations will be applied. Defaults to 0.3."} 51 | ) 52 | 53 | 54 | @dataclass 55 | class CharAugConfig: 56 | """Char-level Augmentex config. 57 | 58 | Attributes: 59 | min_aug (int): The minimum amount of augmentation. Defaults to 1. 60 | max_aug (int): The maximum amount of augmentation. Defaults to 5. 61 | unit_prob (float): Percentage of the phrase to which augmentations will be applied. Defaults to 0.3. 62 | mult_num (int): Maximum repetitions of characters. Defaults to 5. 63 | """ 64 | min_aug: Optional[int] = field( 65 | default=1, 66 | metadata={"help": "The minimum amount of augmentation. Defaults to 1."}, 67 | ) 68 | 69 | max_aug: Optional[int] = field( 70 | default=5, 71 | metadata={"help": "The maximum amount of augmentation. Defaults to 5."}, 72 | ) 73 | 74 | unit_prob: Optional[float] = field( 75 | default=0.3, 76 | metadata={ 77 | "help": "Percentage of the phrase to which augmentations will be applied. Defaults to 0.3."} 78 | ) 79 | 80 | mult_num: Optional[int] = field( 81 | default=5, 82 | metadata={"help": "Maximum repetitions of characters. Defaults to 5."}, 83 | ) 84 | 85 | 86 | @dataclass 87 | class SBSCConfig: 88 | """Config for statistic-based spelling corruption. 89 | 90 | Attributes: 91 | lang (str): source language; 92 | typos_count (List[int]): number of typos per sentence; 93 | stats (Dict[str, Dict[str, List[float]]]): 94 | types of typos and their absolute and relative positions in a sentence; 95 | confusion_matrix (Dict[str, Dict[str, int]]): Candidate replacements with corresponding frequencies; 96 | skip_if_position_not_found (bool): 97 | Whether to search for suitable position in a sentence when position is not found in interval; 98 | reference_dataset_name_or_path (bool): Path to or name of reference dataset 99 | reference_dataset_split (str): Dataset split to use when acquiring statistics. 
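    Example (an illustrative sketch; the path below is the English example dataset
    shipped with the repository):

        config = SBSCConfig(
            lang="en",
            reference_dataset_name_or_path="data/example_data/bea60k/subsample",
        )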
100 | """ 101 | lang: str = field( 102 | default="ru", 103 | metadata={"help": "Source language"} 104 | ) 105 | typos_count: Optional[List[int]] = field( 106 | default=None, 107 | metadata={"help": "Number of errors per sentence"}, 108 | ) 109 | 110 | stats: Optional[Dict[str, Dict[str, List[float]]]] = field( 111 | default=None, 112 | metadata={"help": "Relative and absolute positions of errors of corresponding types"}, 113 | ) 114 | 115 | confusion_matrix: Optional[Dict[str, Dict[str, int]]] = field( 116 | default=None, 117 | metadata={"help": "Candidate replacements with corresponding frequencies"}, 118 | ) 119 | 120 | skip_if_position_not_found: bool = field( 121 | default=True, 122 | metadata={ 123 | "help": "Whether to search for suitable position in a sentence when position is not found in interval"}, 124 | ) 125 | 126 | reference_dataset_name_or_path: Optional[Union[str, os.PathLike]] = field( 127 | default="RUSpellRU", 128 | metadata={"help": "Path to or name of reference dataset"}, 129 | ) 130 | 131 | reference_dataset_split: str = field( 132 | default="train", 133 | metadata={"help": "Dataset split to use when acquiring statistics."}, 134 | ) 135 | -------------------------------------------------------------------------------- /sage/spelling_corruption/corruptor.py: -------------------------------------------------------------------------------- 1 | """API to available methods of spelling corruption. 2 | 3 | Currently, three options are available: word- and char-level Augmentex and 4 | Statistical-based spelling corruption (SBSC). 5 | 6 | Examples: 7 | from configuration_corruptor import CharAugConfig 8 | 9 | config = CharAugConfig(min_aug=10, max_aug=50, unit_prob=0.5) 10 | corruptor = CharAugCorruptor.from_config(config) 11 | print(corruptor.corrupt(sentence)) 12 | 13 | ... 14 | 15 | corruptor = SBSCCorruptor.from_default_config() 16 | print(corruptor.corrupt(sentence)) 17 | """ 18 | 19 | import dataclasses 20 | from dataclasses import asdict 21 | from typing import List, Union, Optional 22 | from abc import ABCMeta, abstractmethod 23 | 24 | from augmentex.char import CharAug 25 | from augmentex.word import WordAug 26 | 27 | from .sbsc.sbsc import StatisticBasedSpellingCorruption 28 | from .configuration_corruptor import WordAugConfig, CharAugConfig, SBSCConfig 29 | 30 | 31 | class Corruptor(metaclass=ABCMeta): 32 | """Base class for all corruptors. 33 | 34 | Attributes: 35 | config (Dict[str, Any]): config for every particular corruption class; 36 | engine (Union[WordAugCorruptor, CharAugCorruptor, SBSCCorruptor]): 37 | corruptor class; 38 | """ 39 | 40 | engine = None 41 | 42 | def __init__(self): 43 | self.config = asdict(self.get_default_config()) 44 | 45 | @classmethod 46 | def from_config(cls, config: Union[WordAugConfig, CharAugConfig, SBSCConfig]): 47 | """Initialize corruptor from a given config. 48 | 49 | Args: 50 | config (Union[WordAugConfig, CharAugConfig, SBSCConfig]): 51 | config for every particular corruption class; 52 | 53 | Returns: 54 | particular corruptor class initialized from a given config; 55 | """ 56 | corruptor = cls() 57 | corruptor.config = {field.name: getattr(config, field.name) for field in dataclasses.fields(config)} 58 | corruptor.engine = corruptor.engine(**corruptor.config) 59 | 60 | return corruptor 61 | 62 | @classmethod 63 | def from_default_config(cls): 64 | """Initialize corruptor from a default config. 
65 | 66 | Returns: 67 | particular corruptor class initialized from a default config; 68 | """ 69 | corruptor = cls() 70 | corruptor.engine = corruptor.engine(**corruptor.config) 71 | 72 | return corruptor 73 | 74 | @abstractmethod 75 | def corrupt(self, sentence: str, action: Optional[str] = None, seed: Optional[int] = 42) -> str: 76 | pass 77 | 78 | @abstractmethod 79 | def batch_corrupt( 80 | self, sentences: List[str], action: Optional[str] = None, batch_prob: Optional[float] = 0.3, 81 | seed: Optional[int] = 42) -> List[str]: 82 | pass 83 | 84 | @staticmethod 85 | @abstractmethod 86 | def get_default_config(): 87 | pass 88 | 89 | 90 | class AugCorruptor(Corruptor, metaclass=ABCMeta): 91 | """Base class for Augmentex-based corruptors.""" 92 | 93 | def corrupt(self, sentence: str, action: Optional[str] = None, seed: Optional[int] = 42) -> str: 94 | return self.engine.augment(sentence, seed=seed, action=action) 95 | 96 | def batch_corrupt( 97 | self, sentences: List[str], action: Optional[str] = None, batch_prob: Optional[float] = 0.3, 98 | seed: Optional[int] = 42) -> List[str]: 99 | return self.engine.aug_batch(sentences, seed=seed, batch_prob=batch_prob, action=action) 100 | 101 | 102 | class WordAugCorruptor(AugCorruptor): 103 | 104 | engine = WordAug 105 | 106 | @staticmethod 107 | def get_default_config(): 108 | return WordAugConfig() 109 | 110 | 111 | class CharAugCorruptor(AugCorruptor): 112 | 113 | engine = CharAug 114 | 115 | @staticmethod 116 | def get_default_config(): 117 | return CharAugConfig() 118 | 119 | 120 | class SBSCCorruptor(Corruptor): 121 | 122 | engine = StatisticBasedSpellingCorruption 123 | 124 | def corrupt(self, sentence: str, action: Optional[str] = None, seed: Optional[int] = 42) -> str: 125 | return self.engine.corrupt(sentence, seed) 126 | 127 | def batch_corrupt( 128 | self, sentences: List[str], action: Optional[str] = None, batch_prob: Optional[float] = 0.3, 129 | seed: Optional[int] = 42) -> List[str]: 130 | return self.engine.batch_corrupt(sentences, seed) 131 | 132 | @staticmethod 133 | def get_default_config(): 134 | return SBSCConfig() 135 | -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/sage/spelling_corruption/sbsc/__init__.py -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/base_classes.py: -------------------------------------------------------------------------------- 1 | """Base classes for misspellings. 2 | 3 | Includes parent abstract class and corresponding APIs for each type of errors, 4 | as well as API to discrete distributions (`class Distribution`). 
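Example of the `Distribution` helper (an illustrative sketch):

    import numpy as np

    d = Distribution([1, 1, 2, 3], exclude_zero=False)
    rng = np.random.default_rng(42)
    d.sample(rng)  # -> 1, 2 or 3, drawn with empirical probabilities 0.5 / 0.25 / 0.25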
5 | """ 6 | 7 | import logging 8 | from abc import ABCMeta, abstractmethod 9 | from typing import Optional, Callable, List 10 | 11 | import numpy as np 12 | 13 | from .typings_positions_conditions import initialize_conditions 14 | from ...utils.lang_utils import INSERTION_OPTIONS 15 | 16 | conditions = initialize_conditions() 17 | MISSPELLINGS = {} 18 | 19 | 20 | def register_misspelling(cls): 21 | MISSPELLINGS[cls.description()] = cls() 22 | 23 | 24 | class Distribution: 25 | """Emulates discrete distribution.""" 26 | 27 | def __init__(self, evidences: List[int], exclude_zero: bool): 28 | if exclude_zero: 29 | evidences = [elem for elem in evidences if elem != 0] 30 | self.values, counts = np.unique(evidences, return_counts=True) 31 | self.p = counts / sum(counts) 32 | 33 | def sample(self, rng: np.random.default_rng): 34 | if len(self.values) == 0: 35 | raise ValueError("You cannot sample from empty distribution, provide some statistics first") 36 | value = rng.choice(self.values, size=1, p=self.p)[0] 37 | return value 38 | 39 | 40 | class Typo(metaclass=ABCMeta): 41 | """Base class for all handlers. 42 | 43 | Attributes: 44 | condition (typings_positions_conditions.Condition): 45 | condition for appropriate position of a typo in a sentence. 46 | """ 47 | 48 | def __init__(self): 49 | self.condition = None if self.desc is None else conditions[self.desc] 50 | 51 | @staticmethod 52 | @abstractmethod 53 | def description(): 54 | """We will need this in object of Fabric class 55 | when instantiating an object from dict of possible misspellings. 56 | """ 57 | 58 | @property 59 | @abstractmethod 60 | def desc(self): 61 | """We need this to identify particular type of error 62 | And use it while initialization. 63 | """ 64 | 65 | @abstractmethod 66 | def apply( 67 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 68 | substitutions: Optional[Distribution] = None 69 | ) -> str: 70 | """Insert typo in particular `pos` in a `sentence`. 71 | 72 | Args: 73 | pos (int): position to insert typo; 74 | sentence (str): original sentence; 75 | lang (str): language code; 76 | rng (np.random.default_rng): random generator; 77 | substitutions (Distribution): optional, set of options for substitution; 78 | """ 79 | 80 | def adjust_position( 81 | self, pos: int, most_left: int, most_right: int, skip_if_position_not_found: bool, 82 | used_positions: List[int], rng: np.random.default_rng, lang: str, sentence: Optional[str] = None 83 | ) -> int: 84 | """Select appropriate position in interval from `most_left` to `most_right` 85 | starting from `pos` in a `sentence`. 
86 | 87 | Args: 88 | pos (int): starting position; 89 | most_left (int): starting position of interval; 90 | most_right (int): ending position of interval; 91 | skip_if_position_not_found (bool): whether to skip, when appropriate position for typo cannot be found; 92 | used_positions (List[int]): array of taken positions; 93 | rng (np.random.default_rng): random generator; 94 | lang (str): language code; 95 | sentence (str): original sentence; 96 | """ 97 | effective_tries = (most_right - most_left) * 2 98 | cnt_tries = 0 99 | while self.condition.condition(pos, used_positions, sentence, lang): 100 | pos = rng.integers(low=most_left, high=most_right, size=1)[0] 101 | cnt_tries += 1 102 | if cnt_tries == effective_tries: 103 | logging.info("Falling back on {}".format(self.desc)) 104 | pos = self._fallback(skip_if_position_not_found)(used_positions, lang, most_left, most_right, sentence) 105 | break 106 | return pos 107 | 108 | def _fallback(self, skip_if_position_not_found: bool) -> Callable: 109 | if skip_if_position_not_found: 110 | return self._skip_fallback_strategy 111 | return self._default_fallback_strategy 112 | 113 | def _default_fallback_strategy(self, used_positions: List[int], lang: str, most_left: Optional[int] = None, 114 | most_right: Optional[int] = None, sentence: Optional[str] = None 115 | ) -> Optional[int]: 116 | """Iterate through the whole `sentence` and search for appropriate position. 117 | When one found, stop iterating. 118 | 119 | Args: 120 | used_positions (List[int]): array of taken positions; 121 | lang (str): language code; 122 | most_left (int): starting position of interval; 123 | most_right (int): ending position of interval; 124 | sentence (str): original sentence; 125 | """ 126 | pos = None 127 | for i, ch in enumerate(sentence): 128 | if self.condition.condition(i, used_positions, sentence, lang): 129 | continue 130 | pos = i 131 | break 132 | return pos 133 | 134 | @staticmethod 135 | def _skip_fallback_strategy(used_positions: List[int], lang: str, most_left: Optional[int] = None, 136 | most_right: Optional[int] = None, sentence: Optional[str] = None, 137 | ) -> Optional[int]: 138 | """Skipping current typo if there is no appropriate position for it.""" 139 | 140 | return None 141 | 142 | 143 | class Fabric: 144 | """This acts as somewhat factory for the handlers. 145 | 146 | Attributes: 147 | used_positions (List[int]): array of taken positions; 148 | """ 149 | 150 | def __init__(self): 151 | self.used_positions = [] 152 | 153 | def finish(self, pos: int, typo: str) -> None: 154 | """Alter `used_positions` after typo has been inserted. 155 | 156 | Args: 157 | pos (int): position of typo; 158 | typo (str): type of typo; 159 | """ 160 | self.used_positions.append(pos) 161 | executor = conditions[typo] 162 | executor.alter_positions(pos, self.used_positions) 163 | 164 | @staticmethod 165 | def get_handler(typo: str) -> Typo: 166 | return MISSPELLINGS[typo] 167 | 168 | 169 | @register_misspelling 170 | class Insertion(Typo): 171 | """API to insertion typo. 172 | 173 | Insertion type of error implies insertion of unnecessary characters in 174 | an original sentence. 175 | 176 | Examples of error: 177 | 1. Error -> Errror; 178 | 2. 
Мама дома мыла раму -> Марма дома мыла раму; 179 | """ 180 | 181 | @staticmethod 182 | def description(): 183 | return "insertion" 184 | 185 | @property 186 | def desc(self): 187 | return "insertion" 188 | 189 | def apply( 190 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 191 | substitutions: Optional[Distribution] = None 192 | ) -> str: 193 | insertions = getattr(INSERTION_OPTIONS, lang) 194 | insertion = rng.choice(insertions, size=1)[0] 195 | sentence = sentence[:pos] + insertion + sentence[pos:] 196 | return sentence 197 | 198 | 199 | @register_misspelling 200 | class Deletion(Typo): 201 | """API to deletion typo. 202 | 203 | Deletion type of error implies deletion of characters in 204 | an original sentence. 205 | 206 | Examples of error: 207 | 1. Error -> Eror; 208 | 2. Мама дома мыла раму -> Мма дома мыла раму; 209 | """ 210 | 211 | @staticmethod 212 | def description(): 213 | return "deletion" 214 | 215 | @property 216 | def desc(self): 217 | return "deletion" 218 | 219 | def apply( 220 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 221 | substitutions: Optional[Distribution] = None 222 | ) -> str: 223 | sentence = sentence[:pos] + sentence[min(len(sentence), pos + 1):] 224 | return sentence 225 | 226 | 227 | @register_misspelling 228 | class Transposition(Typo): 229 | """API to transposition typo. 230 | 231 | Transposition type of error implies swapping two adjacent characters. 232 | 233 | Examples of error: 234 | 1. Error -> Errro; 235 | 2. Мама дома мыла раму -> Маам дома мыла раму; 236 | """ 237 | 238 | @staticmethod 239 | def description(): 240 | return "transposition" 241 | 242 | @property 243 | def desc(self): 244 | return "transposition" 245 | 246 | def apply( 247 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 248 | substitutions: Optional[Distribution] = None 249 | ) -> str: 250 | sentence = sentence[:pos] + sentence[pos + 1] + sentence[pos] + sentence[min(len(sentence), pos + 2):] 251 | return sentence 252 | 253 | 254 | @register_misspelling 255 | class Substitution(Typo): 256 | """API to substitution typo. 257 | 258 | Substitution type of error implies substitution of one character 259 | in an original sentence. 260 | 261 | Examples of error: 262 | 1. Error -> Errar; 263 | 2. Мама дома мыла раму -> Мама Мома мыла раму; 264 | """ 265 | @staticmethod 266 | def description(): 267 | return "substitution" 268 | 269 | @property 270 | def desc(self): 271 | return "substitution" 272 | 273 | def apply( 274 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 275 | substitutions: Optional[Distribution] = None 276 | ) -> str: 277 | substitution = substitutions.sample(rng) 278 | if sentence[pos].isupper(): 279 | substitution = substitution.upper() 280 | sentence = sentence[:pos] + substitution + sentence[min(len(sentence), pos + 1):] 281 | return sentence 282 | 283 | 284 | @register_misspelling 285 | class ExtraSeparator(Typo): 286 | """API to extra separator typo. 287 | 288 | ExtraSeparator type of error implies insertion of extra gap 289 | in an original sentence. 290 | 291 | Examples of error: 292 | 1. Error -> Err or; 293 | 2. 
Мама дома мыла раму -> Ма ма дома мыла раму; 294 | """ 295 | @staticmethod 296 | def description(): 297 | return "extra_separator" 298 | 299 | @property 300 | def desc(self): 301 | return "extra_separator" 302 | 303 | def apply( 304 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 305 | substitutions: Optional[Distribution] = None 306 | ) -> str: 307 | sentence = sentence[:pos] + " " + sentence[pos:] 308 | return sentence 309 | 310 | 311 | @register_misspelling 312 | class MissingSeparator(Typo): 313 | """API to missing separator typo. 314 | 315 | MissingSeparator type of error implies deletion of a gap 316 | in an original sentence. 317 | 318 | Examples of error: 319 | 1. Error specifically made for this example-> Errorspecifically made for this example; 320 | 2. Мама дома мыла раму -> Мама домамыла раму; 321 | """ 322 | @staticmethod 323 | def description(): 324 | return "missing_separator" 325 | 326 | @property 327 | def desc(self): 328 | return "missing_separator" 329 | 330 | def apply( 331 | self, pos: int, sentence: str, lang: str, rng: np.random.default_rng, 332 | substitutions: Optional[Distribution] = None 333 | ) -> str: 334 | sentence = sentence[:pos] + sentence[min(len(sentence), pos + 1):] 335 | return sentence 336 | -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/labeler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides functionality to detect type and position of mistyping 3 | given source sentence and corresponding corrected sentence. 4 | 5 | """ 6 | 7 | import re 8 | import enum 9 | import string 10 | from typing import List, Dict 11 | 12 | import numpy as np 13 | from tqdm.auto import tqdm 14 | 15 | 16 | class TyposTypes(enum.Enum): 17 | """Available types of errors.""" 18 | 19 | insertion = "Extra character" 20 | deletion = "Missing character" 21 | substitution = "Wrong character" 22 | transposition = "Two adjacent characters shuffled" 23 | missing_separator = "Missing gap" 24 | extra_separator = "Extra gap" 25 | 26 | 27 | def make_levenshtein_table(source, correct, allow_transpositions=False, removal_cost=1.0, insertion_cost=1.0, 28 | replace_cost=1.0, transposition_cost=1.0): 29 | first_length, second_length = len(source), len(correct) 30 | table = np.zeros(shape=(first_length + 1, second_length + 1), dtype=float) 31 | for i in range(1, second_length + 1): 32 | table[0][i] = i 33 | for i in range(1, first_length + 1): 34 | table[i][0] = i 35 | for i, first_word in enumerate(source, 1): 36 | for j, second_word in enumerate(correct, 1): 37 | if first_word == second_word: 38 | table[i][j] = table[i-1][j-1] 39 | else: 40 | table[i][j] = min((table[i-1][j-1] + replace_cost, 41 | table[i][j-1] + removal_cost, 42 | table[i-1][j] + insertion_cost)) 43 | if (allow_transpositions and min(i, j) >= 2 44 | and first_word == correct[j-2] and second_word == source[i-2]): 45 | table[i][j] = min(table[i][j], table[i-2][j-2] + transposition_cost) 46 | return table 47 | 48 | 49 | def process_group(source: str, correction: str, levenshtein_table: np.array) -> \ 50 | [Dict[str, List[int]], Dict[str, List[int]], Dict[str, Dict[str, int]]]: 51 | 52 | """ 53 | Identify type of mistyping and its position. 54 | 55 | Trace back table of Levenshtein distances and detect 56 | type of mistyping by the node from which it came from. 57 | 58 | Args: 59 | source (str): source sequence. 60 | correction (str): corrected sequence. 
61 | levenshtein_table (np.array): table filled with distances between prefixes. 62 | 63 | Returns: 64 | d: Dict[str, List[int]], distribution of mistypings in sequence. 65 | d_src: Dict[str, List[int]], analogously, but positions are related to source sentence. 66 | confusion_matrix: Dict[str, Dict[str, int]], confusion matrix. 67 | 68 | """ 69 | 70 | source = source.lower() 71 | correction = correction.lower() 72 | i = len(source) 73 | j = len(correction) 74 | d = {typo_type.name: [] for typo_type in TyposTypes} 75 | d_src = {typo_type.name: [] for typo_type in TyposTypes} 76 | confusion_matrix = {} 77 | 78 | while i > 0 and j > 0: 79 | # If characters are same 80 | if source[i-1] == correction[j-1]: 81 | i -= 1 82 | j -= 1 83 | 84 | # Substitution 85 | elif levenshtein_table[i][j] == levenshtein_table[i - 1][j - 1] + 1: 86 | d["substitution"].append(j - 1) 87 | d_src["substitution"].append(i - 1) 88 | correct_char = correction[j - 1] 89 | source_char = source[i - 1] 90 | if correct_char in confusion_matrix: 91 | if source_char in confusion_matrix[correct_char]: 92 | confusion_matrix[correct_char][source_char] += 1 93 | else: 94 | confusion_matrix[correct_char][source_char] = 1 95 | else: 96 | confusion_matrix[correct_char] = {source_char: 1} 97 | 98 | j -= 1 99 | i -= 1 100 | 101 | # Insertion 102 | elif levenshtein_table[i][j] == levenshtein_table[i - 1][j] + 1: 103 | if source[i - 1] == " ": 104 | d["extra_separator"].append(j) 105 | d_src["extra_separator"].append(i - 1) 106 | else: 107 | d["insertion"].append(j) 108 | d_src["insertion"].append(i - 1) 109 | i -= 1 110 | 111 | # Deletion 112 | elif levenshtein_table[i][j] == levenshtein_table[i][j - 1] + 1: 113 | if j < len(source) and correction[j - 1] == " ": 114 | d["missing_separator"].append(j - 1) 115 | d_src["missing_separator"].append(i) 116 | else: 117 | d["deletion"].append(j - 1) 118 | d_src["deletion"].append(i) 119 | j -= 1 120 | 121 | # Transposition 122 | elif min(i, j) >= 2 and levenshtein_table[i][j] == levenshtein_table[i-2][j-2] + 1: 123 | d["transposition"].append(j - 2) 124 | d_src["transposition"].append(i - 2) 125 | i -= 2 126 | j -= 2 127 | 128 | if i > 0: 129 | d["insertion"].extend([0] * i) 130 | d_src["insertion"].extend(list(range(i))) 131 | if j > 0: 132 | d["deletion"].extend(list(range(j))) 133 | d_src["deletion"].extend([0] * j) 134 | 135 | return d, d_src, confusion_matrix 136 | 137 | 138 | def process_mistypings( 139 | src: List[str], corr: List[str], 140 | ) -> [Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, int]], List[int]]: 141 | 142 | """ 143 | Processes allignment groups and outputs mistypings distribution. 144 | We have following classification of mistypings that goes like this: 145 | 1. insertion ("туберкПулёз" -> "туберкулёз") 146 | 2. deletion ("тубркулёз" -> "туберкулёз") 147 | 3. substitution ("тубИркулёз" -> "туберкулёз") 148 | 4. transposition ("тубРЕкулёз" -> "туберкулёз") 149 | 5. extra_separator ("туберкулёз" -> "тубе ркулёз") 150 | 6. missing_separator ("острый туберкулёз" -> "острыйтуберкулёз") 151 | 152 | Args: 153 | src (List[str]): original sequences with mistypings. 154 | corr (List[str]): corrected sequences. 155 | 156 | Returns: 157 | global_stats: Dict[str, Dict[str, List[float]]], distributions of positions 158 | across the whole corpus. 159 | global_cm: Dict[str, Dict[str, int]], confusion matrix on the whole corpus. 160 | mistypings_cnt: List[int], number of mistypings in each sentence. 
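    Example (an illustrative sketch; note the order of the returned values):

        global_stats, global_cm, mistypings_cnt = process_mistypings(
            ["тубркулёз"], ["туберкулёз"])
        mistypings_cnt  # -> [1], the single deletion-type error in the only sentence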
161 | 162 | """ 163 | global_stats = {typo_type.name: {"abs": [], "rel": []} for typo_type in TyposTypes} 164 | global_cm = {} 165 | mistypings_cnt = [] 166 | pattern = string.punctuation.replace("-", "") 167 | 168 | l = len(src) 169 | for source, correction in tqdm(zip(src, corr), total=l): 170 | source = re.sub(r"[{}]".format(pattern), "", source.lower().strip()) 171 | correction = re.sub(r"[{}]".format(pattern), "", correction.lower().strip()) 172 | 173 | dp = make_levenshtein_table(source, correction, allow_transpositions=True) 174 | # We gather distributions from source sentences, NOT from corrections 175 | _, local_stats, local_cm = process_group(source, correction, dp) 176 | 177 | mistypings_cnt.append(sum((len(v) for _, v in local_stats.items()))) 178 | for typo, positions in local_stats.items(): 179 | global_stats[typo]["abs"].extend(positions) 180 | global_stats[typo]["rel"].extend([0. if len(source) == 0 else pos / len(source) for pos in positions]) 181 | for correct_char, candidates in local_cm.items(): 182 | if correct_char not in global_cm: 183 | global_cm[correct_char] = {} 184 | for candidate, cnt in candidates.items(): 185 | if candidate in global_cm[correct_char]: 186 | global_cm[correct_char][candidate] += cnt 187 | else: 188 | global_cm[correct_char][candidate] = cnt 189 | 190 | return global_stats, global_cm, mistypings_cnt 191 | -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the main functionality to make statistical mistypings 3 | that is embodied in Model class. 4 | 5 | """ 6 | 7 | import math 8 | from functools import reduce 9 | from typing import List, Dict, Optional, Union 10 | 11 | import numpy as np 12 | 13 | from .base_classes import Fabric, Distribution 14 | from .labeler import TyposTypes 15 | from ...utils.lang_utils import SUBSTITUTION_OPTIONS, AVAILABLE_LANG_CODES 16 | 17 | 18 | class Model: 19 | """Statistical model parametrized by fetched distributions. 20 | 21 | Given parallel corpus, number of typos per sentence, types of error and their 22 | corresponding positions and substitution statistics are first gathered. 23 | Raw statistics are then fed to `Model` and normalized to appropriate discrete 24 | distributions. `Model` is parametrized by these distributions, and is used 25 | to corrupt text in a statistic-based manner. 26 | 27 | Attributes: 28 | debug_mode (bool): used for tests purposes; 29 | stats (Dict[str, List[int]]): used for tests purposes; 30 | lang (str): language of original text; 31 | skip_if_position_not_found (bool): whether to skip typo, when appropriate position cannot be found; 32 | 33 | Usage: 34 | from labeler import process_mistypings 35 | 36 | sources, corrections = load_data(...) 
37 | typos_cnt, cm, stats = process_mistypings(sources, corrections) 38 | model = Model(typos_cnt, stats, cm, True, "ru") 39 | print(model.transform(clean_sentence)) 40 | """ 41 | names = [typo_type.name for typo_type in TyposTypes] 42 | 43 | def __init__( 44 | self, typos_count: List[int], 45 | stats: Dict[str, Dict[str, List[float]]], 46 | confusion_matrix: Dict[str, Dict[str, int]], 47 | skip_if_position_not_found: bool, 48 | lang: str, 49 | debug_mode: bool = False, 50 | ): 51 | # For debugging purposes only 52 | self.debug_mode = debug_mode 53 | self.stats = { 54 | "used_positions_pre": [], 55 | "used_positions_after": [], 56 | "pos": [], 57 | } 58 | 59 | self.validate_inputs(stats, confusion_matrix, typos_count, lang) 60 | 61 | self.lang = lang.strip("_ ").lower() 62 | self.skip_if_position_not_found = skip_if_position_not_found 63 | 64 | # Number of mistypings per sentence 65 | self._register_distribution("number_of_errors_per_sent", typos_count, False) 66 | 67 | # Type of mistypings 68 | typos_cnt = {typo: len(v["abs"]) for typo, v in stats.items()} 69 | typos = reduce(lambda x, y: x + y, [[k] * v for k, v in typos_cnt.items()]) 70 | self._register_distribution("type_of_typo", typos) 71 | 72 | # Relative positions of mistypings 73 | self._bins = [0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.] 74 | for typo, v in stats.items(): 75 | # To avoid 1.s being thrown in 11th bucket 76 | rel_positions = [pos if pos < 1. else pos - 0.00001 for pos in v["rel"]] 77 | 78 | buckets = np.digitize(rel_positions, self._bins) 79 | self._register_distribution(typo + "_positions", buckets) 80 | 81 | # Substitutions (confusion matrix) 82 | for ch, candidates in confusion_matrix.items(): 83 | counts = reduce(lambda x, y: x + y, [[k] * v for k, v in candidates.items()]) 84 | self._register_distribution("substitutions_for_{}".format(ord(ch)), counts) 85 | 86 | @classmethod 87 | def validate_inputs(cls, stats: Dict[str, Dict[str, List[float]]], confusion_matrix: Dict[str, Dict[str, int]], 88 | typos_counts: List[int], lang: str): 89 | lang = lang.strip("_ ").lower() 90 | if lang not in AVAILABLE_LANG_CODES: 91 | raise ValueError( 92 | "Wrong language code: {}. 
Available codes are {}".format(lang, " ".join(AVAILABLE_LANG_CODES))) 93 | if len(stats) == 0: 94 | raise ValueError("Stats are empty, you should provide some") 95 | total_pos_num = 0 96 | for k, v in stats.items(): 97 | if k not in cls.names: 98 | raise ValueError("You provided stats in wrong format, the key {} is not expected".format(k)) 99 | if len(v["abs"]) != len(v["rel"]): 100 | raise ValueError("Your inputs' lengths in stats (abs / rel) do not match for {}".format(k)) 101 | illegal_positions = [i for i, elem in enumerate(v["abs"]) if elem < 0] 102 | if len(illegal_positions) != 0: 103 | raise ValueError("Provide non-negative values for absolute positions for {} at positions {}".format( 104 | k, illegal_positions)) 105 | illegal_positions = [i for i, elem in enumerate(v["rel"]) if elem < 0 or elem > 1] 106 | if len(illegal_positions) != 0: 107 | raise ValueError("Provide values between 0 and 1 for relative positions for {} at positions {}".format( 108 | k, illegal_positions)) 109 | total_pos_num += len(v["abs"]) 110 | if total_pos_num == 0: 111 | raise ValueError("Provide some actual statistics") 112 | if len(typos_counts) == 0: 113 | raise ValueError("Typos counts are empty, you should provide some") 114 | if min(typos_counts) < 0: 115 | raise ValueError("Provide non-negative number of errors") 116 | if len(confusion_matrix) == 0 and "substitution" in stats and len(stats["substitution"]["abs"]) > 0: 117 | raise ValueError("Confusion matrix is empty, but substitution is in stats") 118 | for k, v in confusion_matrix.items(): 119 | if len(k) != 1: 120 | raise ValueError("Wrong format of key {} in confusion matrix".format(k)) 121 | for sub, count in v.items(): 122 | if len(sub) != 1: 123 | raise ValueError("Wrong format of substitution {} in confusion matrix".format(sub)) 124 | if count < 0: 125 | raise ValueError("Provide non-negative value for count for key {} and substitution {}".format( 126 | k, sub)) 127 | 128 | def _register_distribution( 129 | self, distribution: str, evidences: Union[List[int], np.array], exclude_zero: Optional[bool] = False): 130 | if hasattr(self, distribution): 131 | raise ValueError("You already defined that distribution {}".format(distribution)) 132 | d = Distribution(evidences, exclude_zero) 133 | setattr(self, distribution, d) 134 | 135 | def _factorization_scheme(self, interval_idx: int, sequence_length: int) -> [int, int]: 136 | """Calculates exact absolute edge positions in a sentence, considering relative positions in a sentence. 137 | 138 | Args: 139 | interval_idx (int): 140 | interval id ranging from 1 to 10, representing equal non-overlaping semi-open 141 | intervals in [0,1]; 142 | sequence_length (int): number of characters in a sentence; 143 | """ 144 | left, right = self._bins[interval_idx - 1], self._bins[interval_idx] 145 | most_left = math.ceil(sequence_length * left) 146 | most_right = math.ceil(sequence_length * right) 147 | return most_left, most_right 148 | 149 | def transform(self, sentence: str, rng: np.random.default_rng): 150 | """Spelling corruption procedure. 151 | 152 | The algorithm follows consequtive steps: 153 | 1. Sample number of errors; 154 | 2. For each error sample its type and corresponding interval in a sentence; 155 | 3. Calculate absolute start and ending positions for typo; 156 | 4. In a given interval find appropriate position for typo; 157 | 5. 
Insert typo; 158 | 159 | Args: 160 | sentence (str): original sentence; 161 | rng (np.random.default_rng): random generator; 162 | 163 | Returns: 164 | sentence (str): original sentence, but with errors; 165 | """ 166 | # Sample number of mistypings 167 | num_typos = self.number_of_errors_per_sent.sample(rng) 168 | fabric = Fabric() 169 | 170 | for _ in range(num_typos): 171 | # take len() for every typo, because with each 172 | # typo length of the sentence changes 173 | l = len(sentence) 174 | 175 | # sample typo and corresponding interval for position 176 | typo = self.type_of_typo.sample(rng) 177 | handler = fabric.get_handler(typo) 178 | position_distribution = getattr(self, typo + "_positions") 179 | 180 | # sample bin a.k.a. interval for typo's position 181 | # and initial exact position inside this interval 182 | effective_tries = l 183 | most_left, most_right = -1, -1 184 | while effective_tries >= 0: 185 | interval_idx = position_distribution.sample(rng) 186 | most_left, most_right = self._factorization_scheme(interval_idx, l) 187 | if most_right - most_left >= 1: # for fixed bins that means length of sentence < 10 188 | break 189 | effective_tries -= 1 190 | if most_right - most_left < 1: 191 | continue 192 | 193 | pos = rng.integers(low=most_left, high=most_right, size=1)[0] 194 | 195 | # Correct the position 196 | pos = handler.adjust_position( 197 | pos, most_left, most_right, self.skip_if_position_not_found, 198 | fabric.used_positions, rng, self.lang, sentence 199 | ) 200 | if pos is not None: 201 | try: 202 | substitutions = getattr(self, "substitutions_for_{}".format(ord(sentence[pos].lower()))) 203 | except AttributeError: 204 | substitutions = Distribution(getattr(SUBSTITUTION_OPTIONS, self.lang), False) 205 | sentence = handler.apply(pos, sentence, self.lang, rng, substitutions) 206 | 207 | if self.debug_mode: 208 | used_positions_cp = fabric.used_positions.copy() 209 | self.stats["used_positions_pre"].append(used_positions_cp) 210 | self.stats["pos"].append(pos) 211 | 212 | fabric.finish(pos, typo) 213 | 214 | if self.debug_mode: 215 | used_positions_cp = fabric.used_positions.copy() 216 | self.stats["used_positions_after"].append(used_positions_cp) 217 | 218 | return sentence 219 | -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/sbsc.py: -------------------------------------------------------------------------------- 1 | """API to Statistical-based Spelling Corruption method. 2 | 3 | Examples: 4 | corruptor = StatisticBasedSpellingCorruption( 5 | lang="ru", 6 | reference_dataset_name_or_path="RUSpellRU", 7 | ) 8 | print(corruptor.corrupt(sentence)) 9 | 10 | .... 11 | 12 | from labeler import process_mistypings 13 | 14 | sources, corrections = load_data(...) 
15 | typos_cnt, cm, stats = process_mistypings(sources, corrections) 16 | corruptor = StatisticBasedSpellingCorruption( 17 | lang="ru", 18 | typos_count=typos_cnt, 19 | stats=stats, 20 | confusion_matrix=cm, 21 | ) 22 | print(corruptor.corrupt(sentence)) 23 | """ 24 | 25 | import os 26 | from typing import List, Dict, Optional, Union 27 | 28 | import numpy as np 29 | import pandas as pd 30 | from tqdm.auto import tqdm 31 | 32 | from .model import Model 33 | from .labeler import process_mistypings 34 | from ...utils.data_load_utils import load_available_dataset_from_hf, DatasetsAvailable 35 | 36 | datasets_available = [dataset.name for dataset in DatasetsAvailable] 37 | 38 | 39 | class StatisticBasedSpellingCorruption: 40 | """API to `Model` class from model.py. 41 | 42 | Attributes: 43 | model (model.Model): statistic-based spelling corruption model; 44 | """ 45 | 46 | def __init__( 47 | self, 48 | lang: str, 49 | typos_count: Optional[List[int]] = None, 50 | stats: Optional[Dict[str, Dict[str, List[float]]]] = None, 51 | confusion_matrix: Optional[Dict[str, Dict[str, int]]] = None, 52 | skip_if_position_not_found: bool = True, 53 | reference_dataset_name_or_path: Optional[Union[str, os.PathLike]] = None, 54 | reference_dataset_split: str = "train", 55 | ): 56 | typos_count_ = None 57 | stats_ = None 58 | confusion_matrix_ = None 59 | 60 | if (typos_count is None or stats is None or confusion_matrix is None) and reference_dataset_name_or_path is None: 61 | raise RuntimeError('''You should provide at least one of :typos_count:/:stats:/:confusion_matrix: 62 | or :reference_dataset_name_or_path:''') 63 | if (typos_count is None or stats is None or confusion_matrix is None) and \ 64 | reference_dataset_name_or_path is not None: 65 | reference_dataset_name_or_path = str(reference_dataset_name_or_path) 66 | if reference_dataset_name_or_path in datasets_available: 67 | sources, corrections = load_available_dataset_from_hf( 68 | reference_dataset_name_or_path, for_labeler=True, split=reference_dataset_split) 69 | stats_, confusion_matrix_, typos_count_ = process_mistypings(sources, corrections) 70 | elif os.path.isdir(reference_dataset_name_or_path): 71 | if os.path.isfile(os.path.join(reference_dataset_name_or_path, "sources.txt")) and \ 72 | os.path.isfile(os.path.join(reference_dataset_name_or_path, "corrections.txt")): 73 | src_file = open(os.path.join(reference_dataset_name_or_path, "sources.txt"), encoding="utf8") 74 | corr_file = open(os.path.join(reference_dataset_name_or_path, "corrections.txt"), encoding="utf8") 75 | sources = src_file.read().split("\n") 76 | corrections = corr_file.read().split("\n") 77 | src_file.close() 78 | corr_file.close() 79 | if len(sources) != len(corrections): 80 | raise RuntimeError("Sources and corrections must be of the same length, but get {} vs {}".format( 81 | len(sources), len(corrections))) 82 | stats_, confusion_matrix_, typos_count_ = process_mistypings(sources, corrections) 83 | elif os.path.isfile(os.path.join(reference_dataset_name_or_path, "data.csv")): 84 | try: 85 | data = pd.read_csv(os.path.join(reference_dataset_name_or_path, "data.csv")) 86 | except Exception as e: 87 | raise RuntimeError("Wrong format of file {}. 
Raised an error: {}".format( 88 | os.path.join(reference_dataset_name_or_path, "data.csv"), str(e))) 89 | if not ("source" in data and "correction" in data): 90 | raise RuntimeError("You must provide 'source' and 'correction' columns in {}".format( 91 | os.path.join(reference_dataset_name_or_path, "data.csv") 92 | )) 93 | if data.isna().any().max(): 94 | raise ValueError("Your data at {} contain unnecessary nans".format( 95 | os.path.join(reference_dataset_name_or_path, "data.csv"))) 96 | sources = data.source.values.tolist() 97 | corrections = data.correction.values.tolist() 98 | stats_, confusion_matrix_, typos_count_ = process_mistypings(sources, corrections) 99 | else: 100 | raise RuntimeError("You must provide either 'data.csv' or 'sources.txt'/'corrections.txt' in {}".format( 101 | reference_dataset_name_or_path 102 | )) 103 | else: 104 | raise ValueError("You must provide either valid path or available dataset's name, you provided {}".format( 105 | reference_dataset_name_or_path 106 | )) 107 | if typos_count is not None: 108 | typos_count_ = typos_count 109 | if stats is not None: 110 | stats_ = stats 111 | if confusion_matrix is not None: 112 | confusion_matrix_ = confusion_matrix 113 | 114 | self.model = Model( 115 | typos_count=typos_count_, 116 | stats=stats_, 117 | confusion_matrix=confusion_matrix_, 118 | skip_if_position_not_found=skip_if_position_not_found, 119 | lang=lang, 120 | ) 121 | 122 | @staticmethod 123 | def show_reference_datasets_available(): 124 | print(*datasets_available, sep="\n") 125 | 126 | def corrupt(self, sentence: str, seed: int) -> str: 127 | return self.batch_corrupt([sentence], seed)[0] 128 | 129 | def batch_corrupt(self, sentences: List[str], seed: int) -> List[str]: 130 | result = [] 131 | pb = tqdm(total=len(sentences)) 132 | rng = np.random.default_rng(seed) 133 | for sentence in sentences: 134 | result.append(self.model.transform(sentence, rng)) 135 | pb.update(1) 136 | return result 137 | -------------------------------------------------------------------------------- /sage/spelling_corruption/sbsc/typings_positions_conditions.py: -------------------------------------------------------------------------------- 1 | """Conditions to search for appropriate position for corresponding typo. 2 | 3 | Each class embodies necessary conditions in order for particular type of error 4 | to be properly inserted. 5 | """ 6 | 7 | import string 8 | from abc import ABCMeta, abstractmethod 9 | from typing import List 10 | 11 | 12 | class Condition(metaclass=ABCMeta): 13 | """Base class for all conditions.""" 14 | 15 | @staticmethod 16 | @abstractmethod 17 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 18 | """Checks whether particular position `pos` satisfies typo's requirements. 19 | 20 | Args: 21 | pos (int): position to check on; 22 | used_positions (List[str]): taken positions; 23 | sentence (str): original sentence; 24 | lang (str): language of original sentence; 25 | 26 | Returns: 27 | Whether or not `pos` is appropriate position to insert a particular typo. 28 | """ 29 | 30 | @staticmethod 31 | @abstractmethod 32 | def alter_positions(pos: int, used_positions: List[int]): 33 | """Corrects list of taken positions accordingly. 34 | 35 | When typo is inserted, one must include its position in a list of 36 | taken positions and alter present positions. 
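        For instance, `InsertionConditions.alter_positions(4, [2, 7])` mutates the list of
        taken positions into `[2, 8]`: everything to the right of the freshly inserted
        character shifts by one.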
37 | 38 | Args: 39 | pos (int): position of inserted typo; 40 | used_positions (List[str]): list of taken positions; 41 | """ 42 | 43 | 44 | class InsertionConditions(Condition): 45 | @staticmethod 46 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 47 | return pos in used_positions 48 | 49 | @staticmethod 50 | def alter_positions(pos: int, used_positions: List[int]): 51 | for i, p in enumerate(used_positions): 52 | used_positions[i] = p if p <= pos else p + 1 53 | 54 | 55 | class DeletionConditions(Condition): 56 | @staticmethod 57 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 58 | punctuation = string.punctuation.replace("-", "") 59 | return pos in used_positions or sentence[pos] == " " or sentence[pos] in punctuation 60 | 61 | @staticmethod 62 | def alter_positions(pos: int, used_positions: List[int]): 63 | for i, p in enumerate(used_positions): 64 | used_positions[i] = p if p <= pos else p - 1 65 | 66 | 67 | class TranspositionConditions(Condition): 68 | @staticmethod 69 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 70 | return pos == len(sentence) - 1 or pos in used_positions or (pos + 1) in used_positions or \ 71 | sentence[pos] in string.punctuation or sentence[pos + 1] in string.punctuation or \ 72 | sentence[pos] == sentence[pos + 1] 73 | 74 | @staticmethod 75 | def alter_positions(pos: int, used_positions: List[int]): 76 | used_positions.append(pos + 1) 77 | 78 | 79 | class SubstitutionConditions(Condition): 80 | @staticmethod 81 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 82 | return pos in used_positions or sentence[pos] == " " or sentence[pos] in string.punctuation or \ 83 | sentence[pos] in string.digits or ((sentence[pos] in string.ascii_letters) == (lang == "ru")) 84 | 85 | @staticmethod 86 | def alter_positions(pos: int, used_positions: List[int]): 87 | pass 88 | 89 | 90 | class ExtraSeparatorConditions(Condition): 91 | @staticmethod 92 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 93 | return pos == 0 or sentence[pos - 1] == " " or sentence[pos] == " " or pos in used_positions or \ 94 | sentence[pos] in string.punctuation 95 | 96 | @staticmethod 97 | def alter_positions(pos: int, used_positions: List[int]): 98 | for i, p in enumerate(used_positions): 99 | used_positions[i] = p if p <= pos else p + 1 100 | 101 | 102 | class MissingSeparatorConditions(Condition): 103 | @staticmethod 104 | def condition(pos: int, used_positions: List[int], sentence: str, lang: str) -> bool: 105 | return sentence[pos] != " " 106 | 107 | @staticmethod 108 | def alter_positions(pos: int, used_positions: List[int]): 109 | for i, p in enumerate(used_positions): 110 | used_positions[i] = p if p <= pos else p - 1 111 | 112 | 113 | def initialize_conditions(): 114 | return { 115 | "insertion": InsertionConditions, 116 | "deletion": DeletionConditions, 117 | "substitution": SubstitutionConditions, 118 | "transposition": TranspositionConditions, 119 | "extra_separator": ExtraSeparatorConditions, 120 | "missing_separator": MissingSeparatorConditions, 121 | } 122 | -------------------------------------------------------------------------------- /sage/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_load_utils import load_available_dataset_from_hf, DatasetsAvailable 2 | from .utils import 
draw_and_save_errors_distributions_comparison_charts
3 | 
4 | __all__ = [
5 | "load_available_dataset_from_hf",
6 | "draw_and_save_errors_distributions_comparison_charts",
7 | "DatasetsAvailable"
8 | ]
9 | 
-------------------------------------------------------------------------------- /sage/utils/data_load_utils.py: --------------------------------------------------------------------------------
1 | """Utils for loading datasets from the Hugging Face hub."""
2 | 
3 | import enum
4 | from typing import Optional, Union, List, Tuple
5 | 
6 | import pandas as pd
7 | from datasets import load_dataset
8 | 
9 | 
10 | class DatasetsAvailable(enum.Enum):
11 | """Datasets available"""
12 | 
13 | MultidomainGold = "Multidomain gold dataset. For more see `ai-forever/spellcheck_punctuation_benchmark`."
14 | RUSpellRU = "Social media texts and blogs. For more see `ai-forever/spellcheck_punctuation_benchmark`."
15 | MedSpellchecker = "Medical anamnesis. For more see `ai-forever/spellcheck_punctuation_benchmark`."
16 | GitHubTypoCorpusRu = "GitHub commits. For more see `ai-forever/spellcheck_punctuation_benchmark`."
17 | 
18 | MultidomainGold_orth = "Multidomain gold dataset orthography only. For more see `ai-forever/spellcheck_benchmark`."
19 | RUSpellRU_orth = "Social media texts and blogs orthography only. For more see `ai-forever/spellcheck_benchmark`."
20 | MedSpellchecker_orth = "Medical anamnesis orthography only. For more see `ai-forever/spellcheck_benchmark`."
21 | GitHubTypoCorpusRu_orth = "GitHub commits orthography only. For more see `ai-forever/spellcheck_benchmark`."
22 | 
23 | 
24 | datasets_available = [dataset.name for dataset in DatasetsAvailable]
25 | 
26 | 
27 | def load_available_dataset_from_hf(
28 | dataset_name: str, for_labeler: bool, split: Optional[str] = None
29 | ) -> Union[Tuple[List[str], List[str]], pd.DataFrame]:
30 | if dataset_name not in datasets_available:
31 | raise ValueError("You provided wrong dataset name: {}\nAvailable datasets are: {}".format(
32 | dataset_name, ", ".join(datasets_available)))
33 | source_collection = "spellcheck_punctuation_benchmark"
34 | if dataset_name[-4:] == "orth":
35 | source_collection = "spellcheck_benchmark"
36 | dataset_name = dataset_name[:-5]
37 | dataset = load_dataset("ai-forever/{}".format(source_collection), dataset_name, split=split)
38 | if split is None:
39 | dataset = pd.concat([dataset[split].to_pandas() for split in dataset.keys()]).reset_index(drop=True)
40 | else:
41 | dataset = dataset.to_pandas()
42 | if for_labeler:
43 | sources = dataset.source.values.tolist()
44 | corrections = dataset.correction.values.tolist()
45 | return sources, corrections
46 | return dataset
47 | 
-------------------------------------------------------------------------------- /sage/utils/lang_utils.py: --------------------------------------------------------------------------------
1 | """Language-related utils"""
2 | 
3 | from collections import namedtuple
4 | 
5 | AVAILABLE_LANG_CODES = ["ru", "en"]
6 | 
7 | 
8 | class InsertionOptions(namedtuple("insertion_options", AVAILABLE_LANG_CODES)):
9 | pass
10 | 
11 | 
12 | class SubstitutionOptions(namedtuple("substitution_options", AVAILABLE_LANG_CODES)):
13 | pass
14 | 
15 | 
16 | INSERTION_OPTIONS = InsertionOptions(
17 | ru=list("абвгдеёжзийклмнопрстуфхцчшщъыьэюя"),
18 | en=list("abcdefghijklmnopqrstuvwxyz")
19 | )
20 | 
21 | SUBSTITUTION_OPTIONS = SubstitutionOptions(
22 | ru=list("абвгдеёжзийклмнопрстуфхцчшщъыьэюя "),
23 | en=list("abcdefghijklmnopqrstuvwxyz ")
24 | )
25 | 
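The two modules above provide the data side of SBSC: `load_available_dataset_from_hf` pulls a parallel (source, correction) benchmark from the Hugging Face hub, and the labeler's `process_mistypings` turns such pairs into the statistics the corruptor consumes. A minimal usage sketch follows; it is illustrative only and not a file in this repository, and the dataset name and split are simply the RUSpellRU defaults used elsewhere in the tests.

from sage.spelling_corruption.sbsc.labeler import process_mistypings
from sage.utils import load_available_dataset_from_hf, DatasetsAvailable

# Parallel noisy/corrected sentences from the hub.
sources, corrections = load_available_dataset_from_hf(
    DatasetsAvailable.RUSpellRU.name, for_labeler=True, split="train")

# process_mistypings returns, in this order: per-error-type position statistics,
# a character confusion matrix, and the number of typos in each sentence.
stats, confusion_matrix, typos_count = process_mistypings(sources, corrections)

print(sorted(stats.keys()))                  # insertion, deletion, substitution, ...
print(sum(typos_count) / len(typos_count))   # average number of typos per sentence

These three objects can be passed directly to `StatisticBasedSpellingCorruption` (or to `SBSCConfig` as `stats`, `confusion_matrix` and `typos_count`) instead of a `reference_dataset_name_or_path`.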
-------------------------------------------------------------------------------- /sage/utils/utils.py: --------------------------------------------------------------------------------
1 | """General utils"""
2 | 
3 | import os
4 | from typing import List, Dict, Any, Union
5 | 
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | 
9 | 
10 | def _draw_distributions_with_spines(
11 | axes: plt.axes, row: int, column: int, actual_values: List[Any], reference_values: List[Any], title: str):
12 | """Draws two discrete distributions on a single plot."""
13 | 
14 | axes[row][column].hist(reference_values, bins=20, ec="black", fc="red", label="Reference", density=True)
15 | axes[row][column].hist(actual_values, bins=20, ec="black", fc="green", label="Actual", alpha=0.7, density=True)
16 | axes[row][column].grid(b=True, color='grey', linestyle='-.', linewidth=0.5, alpha=0.3)
17 | for s in ['top', 'bottom', 'left', 'right']:
18 | axes[row][column].spines[s].set_visible(False)
19 | axes[row][column].set_title(title)
20 | axes[row][column].legend()
21 | 
22 | 
23 | def draw_and_save_errors_distributions_comparison_charts(
24 | actual_typos_cnt: List[int],
25 | reference_typos_cnt: List[int],
26 | actual_stats: Dict[str, Dict[str, List]],
27 | reference_stats: Dict[str, Dict[str, List]],
28 | path_to_save: Union[str, os.PathLike],
29 | ):
30 | """Draws the following distributions for reference and actual values:
31 | 1. Number of errors per sentence;
32 | 2. Types of errors;
33 | 3. Relative positions of each type of errors;
34 | ... and saves the resulting charts to `path_to_save`.
35 | 
36 | Args:
37 | actual_typos_cnt (List[int]): number of errors per sentence, actual;
38 | reference_typos_cnt (List[int]): number of errors per sentence, reference;
39 | actual_stats (Dict[str, Dict[str, List]]): types of errors and their relative positions, actual;
40 | reference_stats (Dict[str, Dict[str, List]]): types of errors and their relative positions, reference;
41 | path_to_save (Union[str, os.PathLike]): where to save charts;
42 | """
43 | 
44 | def _stats(d):
45 | tmp = {k: len(v["abs"]) for k, v in d.items()}
46 | total = sum(tmp.values())
47 | return tmp, total
48 | 
49 | _, ax = plt.subplots(4, 2, figsize=(15, 15))
50 | plt.rcParams["figure.autolayout"] = True
51 | 
52 | _draw_distributions_with_spines(
53 | ax, 0, 0, actual_typos_cnt, reference_typos_cnt, "Number of errors per sentence")
54 | _draw_distributions_with_spines(
55 | ax, 1, 0, actual_stats["insertion"]["rel"], reference_stats["insertion"]["rel"],
56 | "Relative positions of insertions")
57 | _draw_distributions_with_spines(
58 | ax, 1, 1, actual_stats["deletion"]["rel"], reference_stats["deletion"]["rel"],
59 | "Relative positions of deletions")
60 | _draw_distributions_with_spines(
61 | ax, 2, 0, actual_stats["substitution"]["rel"], reference_stats["substitution"]["rel"],
62 | "Relative positions of substitutions")
63 | _draw_distributions_with_spines(
64 | ax, 2, 1, actual_stats["transposition"]["rel"], reference_stats["transposition"]["rel"],
65 | "Relative positions of transpositions")
66 | _draw_distributions_with_spines(
67 | ax, 3, 0, actual_stats["missing_separator"]["rel"], reference_stats["missing_separator"]["rel"],
68 | "Relative positions of missing separators")
69 | _draw_distributions_with_spines(
70 | ax, 3, 1, actual_stats["extra_separator"]["rel"], reference_stats["extra_separator"]["rel"],
71 | "Relative positions of extra separators")
72 | 
73 | width = 0.3
74 | d_ref, t_ref = _stats(reference_stats)
75 | d_act, t_act = _stats(actual_stats)
76 | d_ref = {k: v / (t_ref + 0.000001) for k, v in d_ref.items()} 77 | d_act = {k: v / (t_act + 0.000001) for k, v in d_act.items()} 78 | labels = list(d_ref.keys()) 79 | ids = np.array(range(len(labels))) 80 | 81 | ax[0][1].barh(ids, list(d_act.values()), width, color='green', label='Actual') 82 | ax[0][1].barh(ids + width, list(d_ref.values()), width, color='red', label='Reference') 83 | ax[0][1].set(yticks=ids + width, yticklabels=labels) 84 | ax[0][1].grid(b=True, color='grey', linestyle='-.', linewidth=0.5, alpha=0.3) 85 | for s in ['top', 'bottom', 'left', 'right']: 86 | ax[0][1].spines[s].set_visible(False) 87 | ax[0][1].legend() 88 | ax[0][1].set_title("Types of errors") 89 | 90 | plt.savefig(path_to_save) 91 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup, find_packages 4 | 5 | with open("README.md", mode="r", encoding="utf-8") as readme_file: 6 | readme = readme_file.read() 7 | 8 | setup_dir = os.path.abspath(os.path.dirname(__file__)) 9 | augmentex_path = os.path.join(setup_dir, "wheels/augmentex-1.0.3-py3-none-any.whl") 10 | 11 | requirements = [ 12 | "numpy", 13 | "pandas", 14 | "tqdm", 15 | "pyyaml", 16 | "packaging", 17 | "requests", 18 | "sentencepiece", 19 | "datasets", 20 | "protobuf", 21 | "timeout_decorator", 22 | "matplotlib>=3.2,<3.7", 23 | "torch>=1.9.0,<=2.2.0", 24 | "transformers>=4.20.0", 25 | f"augmentex @ file://{augmentex_path}" 26 | ] 27 | 28 | extras_requirements = { 29 | "errant": [ 30 | "ru-core-news-lg @ https://huggingface.co/spacy/ru_core_news_lg/resolve/main/ru_core_news_lg-any-py3-none-any.whl", 31 | "errant @ git+https://github.com/Askinkaty/errant/@4183e57", 32 | "Levenshtein" 33 | ] 34 | } 35 | 36 | setup( 37 | name="sage", 38 | version="1.1.0", 39 | author="Nikita Martynov, Mark Baushenko, Alena Fenogenova and Alexandr Abramov", 40 | author_email="nikita.martynov.98@list.ru", 41 | description="SAGE: Spell checking via Augmentation and Generative distribution Emulation", 42 | long_description=readme, 43 | long_description_content_type="text/markdown", 44 | license="MIT", 45 | url="https://github.com/orgs/ai-forever/sage", 46 | packages=find_packages(), 47 | classifiers=[ 48 | "Natural Language :: English", 49 | "Natural Language :: Russian", 50 | "Programming Language :: Python :: 3.8", 51 | "Programming Language :: Python :: 3.9", 52 | "Programming Language :: Python :: 3.10", 53 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 54 | "Topic :: Software Development :: Libraries :: Python Modules", 55 | "Topic :: Text Editors :: Text Processing", 56 | ], 57 | python_requires=">=3.8.0,<3.11.0", 58 | install_requires=requirements, 59 | extras_require=extras_requirements, 60 | keywords="sage spelling correction nlp deep learning transformers pytorch" 61 | ) 62 | -------------------------------------------------------------------------------- /tests/corruptor_api_unittests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import numpy as np 4 | 5 | from sage.spelling_corruption.sbsc.labeler import process_mistypings 6 | from sage.spelling_corruption import CharAugConfig, WordAugConfig, SBSCConfig 7 | from sage.utils.data_load_utils import load_available_dataset_from_hf 8 | from sage.utils.utils import draw_and_save_errors_distributions_comparison_charts 9 | from sage.spelling_corruption import 
WordAugCorruptor, CharAugCorruptor, SBSCCorruptor 10 | 11 | SEED = 0 12 | 13 | 14 | class CorruptorApiTests(unittest.TestCase): 15 | sentence = "я пошел домой" 16 | sentences = [sentence] * 3 17 | n_tests = 10 18 | sources, corrections = load_available_dataset_from_hf("RUSpellRU", for_labeler=True, split="train") 19 | ruspellru_stats, ruspellru_confusion_matrix, ruspellru_typos_cnt = process_mistypings(sources, corrections) 20 | 21 | def _draw_generated_distributions(self, corruptor, file_name): 22 | spoiled_sentences = corruptor.batch_corrupt(self.corrections, seed=SEED) 23 | ours_ruspellru_stats, ours_ruspellru_confusion_matrix, ours_ruspellru_typos_cnt = \ 24 | process_mistypings(spoiled_sentences, self.corrections) 25 | draw_and_save_errors_distributions_comparison_charts( 26 | actual_typos_cnt=ours_ruspellru_typos_cnt, 27 | reference_typos_cnt=self.ruspellru_typos_cnt, 28 | actual_stats=ours_ruspellru_stats, 29 | reference_stats=self.ruspellru_stats, 30 | path_to_save=file_name 31 | ) 32 | 33 | def _test_generation_correct(self, corruptor): 34 | res = corruptor.corrupt(self.sentence) 35 | self.assertEqual(type(res), str) 36 | 37 | res = corruptor.batch_corrupt(self.sentences) 38 | self.assertEqual(3, len(res)) 39 | 40 | def test_word_augmenter(self): 41 | default_config = WordAugConfig() 42 | corruptor = WordAugCorruptor.from_default_config() 43 | self.assertEqual(corruptor.engine.min_aug, default_config.min_aug) 44 | self.assertEqual(corruptor.engine.max_aug, default_config.max_aug) 45 | self.assertEqual(corruptor.engine.unit_prob, default_config.unit_prob) 46 | self._test_generation_correct(corruptor) 47 | 48 | config = WordAugConfig(min_aug=2, max_aug=6, unit_prob=0.1) 49 | corruptor = WordAugCorruptor.from_config(config) 50 | self.assertEqual(corruptor.engine.min_aug, config.min_aug) 51 | self.assertEqual(corruptor.engine.max_aug, config.max_aug) 52 | self.assertEqual(corruptor.engine.unit_prob, config.unit_prob) 53 | self._test_generation_correct(corruptor) 54 | 55 | def test_char_augmenter(self): 56 | default_config = CharAugConfig() 57 | corruptor = CharAugCorruptor.from_default_config() 58 | self.assertEqual(corruptor.engine.min_aug, default_config.min_aug) 59 | self.assertEqual(corruptor.engine.max_aug, default_config.max_aug) 60 | self.assertEqual(corruptor.engine.unit_prob, default_config.unit_prob) 61 | self.assertEqual(corruptor.engine.mult_num, default_config.mult_num) 62 | self._test_generation_correct(corruptor) 63 | 64 | config = CharAugConfig(min_aug=2, max_aug=6, unit_prob=0.1, mult_num=3) 65 | corruptor = CharAugCorruptor.from_config(config) 66 | self.assertEqual(corruptor.engine.min_aug, config.min_aug) 67 | self.assertEqual(corruptor.engine.max_aug, config.max_aug) 68 | self.assertEqual(corruptor.engine.unit_prob, config.unit_prob) 69 | self.assertEqual(corruptor.engine.mult_num, config.mult_num) 70 | self._test_generation_correct(corruptor) 71 | 72 | def test_sbsc_corruptor(self): 73 | # From custom stats 74 | typos_count = [3] 75 | stats = { 76 | "insertion": { 77 | "abs": [0], "rel": [np.random.uniform(high=0.1)] 78 | }, 79 | "deletion": { 80 | "abs": [2], "rel": [np.random.uniform(low=0.2, high=0.3)] 81 | }, 82 | "substitution": { 83 | "abs": [1], "rel": [np.random.uniform(low=0.4, high=0.5)] 84 | }, 85 | "transposition": { 86 | "abs": [3], "rel": [np.random.uniform(low=0.6, high=0.7)] 87 | }, 88 | "extra_separator": { 89 | "abs": [4], "rel": [np.random.uniform(low=0.8, high=0.9)] 90 | }, 91 | "missing_separator": { 92 | "abs": [5], "rel": 
[np.random.uniform(low=0.9, high=1.)] 93 | }, 94 | } 95 | config = SBSCConfig( 96 | typos_count=typos_count, 97 | stats=stats, 98 | confusion_matrix={" ": {" ": 1}}, 99 | ) 100 | corruptor = SBSCCorruptor.from_config(config) 101 | self._test_generation_correct(corruptor) 102 | 103 | # From txt files 104 | config = SBSCConfig( 105 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests") 106 | ) 107 | corruptor = SBSCCorruptor.from_config(config) 108 | self._test_generation_correct(corruptor) 109 | 110 | # From csv file 111 | config = SBSCConfig( 112 | reference_dataset_name_or_path=os.path.join( 113 | os.getcwd(), "data", "sanity_check_samples", "corruptor_tests", "csv") 114 | ) 115 | corruptor = SBSCCorruptor.from_config(config) 116 | self._test_generation_correct(corruptor) 117 | 118 | # From partial custom stats 119 | config = SBSCConfig( 120 | typos_count=None, 121 | stats=stats, 122 | confusion_matrix={" ": {" ": 1}}, 123 | skip_if_position_not_found=True, 124 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 125 | ) 126 | corruptor = SBSCCorruptor.from_config(config) 127 | self._test_generation_correct(corruptor) 128 | 129 | config = SBSCConfig( 130 | typos_count=typos_count, 131 | stats=None, 132 | confusion_matrix={" ": {" ": 1}}, 133 | skip_if_position_not_found=True, 134 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 135 | ) 136 | corruptor = SBSCCorruptor.from_config(config) 137 | self._test_generation_correct(corruptor) 138 | 139 | config = SBSCConfig( 140 | typos_count=typos_count, 141 | stats=stats, 142 | confusion_matrix=None, 143 | skip_if_position_not_found=True, 144 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 145 | ) 146 | corruptor = SBSCCorruptor.from_config(config) 147 | self._test_generation_correct(corruptor) 148 | 149 | config = SBSCConfig( 150 | typos_count=None, 151 | stats=None, 152 | confusion_matrix={" ": {" ": 1}}, 153 | skip_if_position_not_found=True, 154 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 155 | ) 156 | corruptor = SBSCCorruptor.from_config(config) 157 | self._test_generation_correct(corruptor) 158 | 159 | config = SBSCConfig( 160 | typos_count=None, 161 | stats=stats, 162 | confusion_matrix=None, 163 | skip_if_position_not_found=True, 164 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests") 165 | ) 166 | corruptor = SBSCCorruptor.from_config(config) 167 | self._test_generation_correct(corruptor) 168 | 169 | config = SBSCConfig( 170 | typos_count=typos_count, 171 | stats=None, 172 | confusion_matrix=None, 173 | skip_if_position_not_found=True, 174 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests") 175 | ) 176 | corruptor = SBSCCorruptor.from_config(config) 177 | self._test_generation_correct(corruptor) 178 | 179 | def test_sbsc_corruptor_batch_corrupt(self): 180 | typos_count = [3] 181 | stats = { 182 | "insertion": { 183 | "abs": [0], "rel": [np.random.uniform(high=0.1)] 184 | }, 185 | "deletion": { 186 | "abs": [2], "rel": [np.random.uniform(low=0.2, high=0.3)] 187 | }, 188 | "substitution": { 189 | "abs": [1], "rel": [np.random.uniform(low=0.4, high=0.5)] 190 | }, 191 | "transposition": { 192 | "abs": [3], "rel": 
[np.random.uniform(low=0.6, high=0.7)] 193 | }, 194 | "extra_separator": { 195 | "abs": [4], "rel": [np.random.uniform(low=0.8, high=0.9)] 196 | }, 197 | "missing_separator": { 198 | "abs": [5], "rel": [np.random.uniform(low=0.9, high=1.)] 199 | }, 200 | } 201 | config = SBSCConfig( 202 | typos_count=typos_count, 203 | stats=stats, 204 | confusion_matrix={" ": {" ": 1}}, 205 | ) 206 | corruptor = SBSCCorruptor.from_config(config) 207 | self._draw_generated_distributions(corruptor, "sbsc_random_stats.jpg") 208 | 209 | config = SBSCConfig( 210 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 211 | ) 212 | corruptor = SBSCCorruptor.from_config(config) 213 | self._draw_generated_distributions(corruptor, "sbsc_stats_from_txt.jpg") 214 | 215 | config = SBSCConfig( 216 | reference_dataset_name_or_path=os.path.join( 217 | os.getcwd(), "data", "sanity_check_samples", "corruptor_tests", "csv"), 218 | ) 219 | corruptor = SBSCCorruptor.from_config(config) 220 | self._draw_generated_distributions(corruptor, "sbsc_stats_from_csv.jpg") 221 | 222 | config = SBSCConfig( 223 | typos_count=None, 224 | stats=stats, 225 | confusion_matrix={" ": {" ": 1}}, 226 | skip_if_position_not_found=True, 227 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "sanity_check_samples", "corruptor_tests"), 228 | ) 229 | corruptor = SBSCCorruptor.from_config(config) 230 | self._draw_generated_distributions(corruptor, "sbsc_custom_stats_typos_from_txt.jpg") 231 | 232 | config = SBSCConfig( 233 | reference_dataset_name_or_path="RUSpellRU", 234 | reference_dataset_split="train" 235 | ) 236 | corruptor = SBSCCorruptor.from_config(config) 237 | self._draw_generated_distributions(corruptor, "sbsc_from_dataset.jpg") 238 | 239 | # English version 240 | en_config = SBSCConfig( 241 | lang="en", 242 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "example_data", "bea60k", "subsample") 243 | ) 244 | corruptor = SBSCCorruptor.from_config(en_config) 245 | 246 | with open(os.path.join(os.getcwd(), "data", "example_data", "bea60k", "subsample", "sources.txt")) as src: 247 | sources = src.read().split("\n") 248 | with open(os.path.join(os.getcwd(), "data", "example_data", "bea60k", "subsample", "corrections.txt")) as corr: 249 | corrections = corr.read().split("\n") 250 | 251 | bea_stats, bea_confusion_matrix, bea_typos_cnt = process_mistypings(sources, corrections) 252 | spoiled_sentences = corruptor.batch_corrupt(corrections) 253 | ours_bea_stats, ours_bea_confusion_matrix, ours_bea_typos_cnt = process_mistypings(spoiled_sentences, corrections) 254 | draw_and_save_errors_distributions_comparison_charts( 255 | actual_typos_cnt=ours_bea_typos_cnt, 256 | reference_typos_cnt=bea_typos_cnt, 257 | actual_stats=ours_bea_stats, 258 | reference_stats=bea_stats, 259 | path_to_save="bea60k.jpg" 260 | ) 261 | 262 | with self.assertRaises(RuntimeError): 263 | corruptor = SBSCCorruptor.from_config( 264 | config=SBSCConfig( 265 | reference_dataset_name_or_path=None, 266 | ) 267 | ) 268 | 269 | def test_word_aug_batch_corrupt(self): 270 | config = WordAugConfig(min_aug=2, max_aug=6, unit_prob=0.1) 271 | corruptor = WordAugCorruptor.from_config(config) 272 | self._draw_generated_distributions(corruptor, "word_aug.jpg") 273 | 274 | config = CharAugConfig(min_aug=2, max_aug=6, unit_prob=0.1, mult_num=3) 275 | corruptor = CharAugCorruptor.from_config(config) 276 | self._draw_generated_distributions(corruptor, "char_aug.jpg") 277 | 278 | 279 | if __name__ == 
'__main__': 280 | unittest.main() 281 | -------------------------------------------------------------------------------- /tests/test_correctors.py: -------------------------------------------------------------------------------- 1 | from sage.spelling_correction import RuM2M100ModelForSpellingCorrection, T5ModelForSpellingCorruption 2 | from sage.spelling_correction import AvailableCorrectors 3 | 4 | example_sentences_ru = [ 5 | "Я пшёл домой", 6 | "Очень классная тетка ктобы что не говорил." 7 | ] 8 | 9 | example_sentences_en = [ 10 | "Fathr telll me, we get or we dereve.", 11 | "Scrw you guys, I am goin homee. (c)" 12 | ] 13 | 14 | example_sentences_ru_en = [ 15 | "Перведи мне текст на аглиском: \"Screw you kuys, I am goin hme (c).", 16 | "\"Don't you went to go upstayers?\", - сказл мне както дед." 17 | ] 18 | 19 | if __name__ == "__main__": 20 | m2m_large_corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.m2m100_1B.value) 21 | m2m_small_corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.m2m100_418M.value) 22 | fred_corrector = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.fred_large.value) 23 | ent5_corrector = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.ent5_large.value) 24 | 25 | sage_fredt5_large = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.sage_fredt5_large.value) 26 | sage_fredt5_distilled = T5ModelForSpellingCorruption.from_pretrained( 27 | AvailableCorrectors.sage_fredt5_distilled_95m.value) 28 | sage_mt5_large = T5ModelForSpellingCorruption.from_pretrained(AvailableCorrectors.sage_mt5_large.value) 29 | sage_m2m_100 = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.sage_m2m100_1B.value) 30 | 31 | print("\n------------------------------------------------------------------\n") 32 | 33 | print("m2m_large_corrector: {}".format(m2m_large_corrector.correct(example_sentences_ru[0]))) 34 | print("m2m_small_corrector: {}".format(m2m_small_corrector.correct(example_sentences_ru[0]))) 35 | print("fred_corrector: {}".format(fred_corrector.correct(example_sentences_ru[0], prefix="Исправь: "))) 36 | print("ent5_corrector: {}".format(ent5_corrector.correct(example_sentences_en[0], prefix="grammar: "))) 37 | 38 | print("sage_fredt5_large: {}".format(sage_fredt5_large.correct(example_sentences_ru[0], prefix=""))) 39 | print("sage_fredt5_distilled: {}".format(sage_fredt5_distilled.correct(example_sentences_ru[0], prefix=""))) 40 | print("sage_mt5_large: {}".format(sage_mt5_large.correct(example_sentences_ru_en[0]))) 41 | print("sage_m2m_100: {}".format(sage_m2m_100.correct(example_sentences_ru[0]))) 42 | 43 | print("\n------------------------------------------------------------------\n") 44 | 45 | print("\nm2m_large_corrector:\n") 46 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 47 | example_sentences_ru, m2m_large_corrector.batch_correct(example_sentences_ru, 1))])) 48 | 49 | print("\nm2m_small_corrector: \n") 50 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 51 | example_sentences_ru, m2m_small_corrector.batch_correct(example_sentences_ru, 1))])) 52 | 53 | print("\nfred_corrector: \n") 54 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 55 | example_sentences_ru, fred_corrector.batch_correct(example_sentences_ru, 1, "Исправь: "))])) 56 | 57 | print("\nent5_corrector: \n") 58 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 59 | example_sentences_en, 
ent5_corrector.batch_correct(example_sentences_en, 1, "grammar: "))])) 60 | 61 | print("\nsage_fredt5_large:\n") 62 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 63 | example_sentences_ru, sage_fredt5_large.batch_correct(example_sentences_ru, 1, prefix=""))])) 64 | 65 | print("\nsage_fredt5_distilled: \n") 66 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 67 | example_sentences_ru, sage_fredt5_distilled.batch_correct(example_sentences_ru, 1, prefix=""))])) 68 | 69 | print("\nsage_mt5_large: \n") 70 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 71 | example_sentences_ru_en, sage_mt5_large.batch_correct(example_sentences_ru_en, 1))])) 72 | 73 | print("\nsage_m2m_100: \n") 74 | print("\n".join(["{}: {}".format(k, v[0]) for k, v in zip( 75 | example_sentences_ru, sage_m2m_100.batch_correct(example_sentences_en, 1))])) 76 | -------------------------------------------------------------------------------- /tests/test_corruptors.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sage.spelling_corruption import WordAugCorruptor, CharAugCorruptor, SBSCCorruptor 4 | from sage.spelling_corruption import WordAugConfig, CharAugConfig, SBSCConfig 5 | 6 | corrections = [ 7 | "Я пошёл домой. Больше мне тут делать нечего", 8 | "Заметьте, - не я это предложил." 9 | ] 10 | 11 | en_corrections = [ 12 | "Father tell me: we get or we deserve", 13 | "Screw you guys, I am going home (c)." 14 | ] 15 | 16 | if __name__ == "__main__": 17 | word_aug_default = WordAugCorruptor.from_default_config() 18 | word_aug_config = WordAugConfig( 19 | min_aug=2, max_aug=6, unit_prob=0.1 20 | ) 21 | word_aug_custom = WordAugCorruptor.from_config(word_aug_config) 22 | 23 | char_aug_default = CharAugCorruptor.from_default_config() 24 | char_aug_config = CharAugConfig( 25 | min_aug=2, max_aug=6, unit_prob=0.1, mult_num=3 26 | ) 27 | char_aug_custom = CharAugCorruptor.from_config(char_aug_config) 28 | 29 | sbsc_default = SBSCCorruptor.from_default_config() 30 | sbsc_config = SBSCConfig( 31 | reference_dataset_name_or_path="MedSpellchecker", 32 | reference_dataset_split="test" 33 | ) 34 | sbsc_custom = SBSCCorruptor.from_config(sbsc_config) 35 | 36 | sbsc_config_en = SBSCConfig( 37 | lang="en", 38 | reference_dataset_name_or_path=os.path.join(os.getcwd(), "data", "example_data", "bea60k", "subsample") 39 | ) 40 | sbsc_english = SBSCCorruptor.from_config(sbsc_config_en) 41 | 42 | print("\n------------------------------------------------------------------\n") 43 | 44 | print("word_aug_default: {}".format(word_aug_default.corrupt(corrections[0]))) 45 | print("word_aug_custom: {}".format(word_aug_custom.corrupt(corrections[0]))) 46 | print("char_aug_default: {}".format(char_aug_default.corrupt(corrections[0]))) 47 | print("char_aug_custom: {}".format(char_aug_custom.corrupt(corrections[0]))) 48 | print("sbsc_default: {}".format(sbsc_default.corrupt(corrections[0]))) 49 | print("sbsc_custom: {}".format(sbsc_custom.corrupt(corrections[0]))) 50 | print("sbsc_english: {}".format(sbsc_english.corrupt(en_corrections[0]))) 51 | 52 | print("\n------------------------------------------------------------------\n") 53 | 54 | print("word_aug_default: \n{}\n".format("\n".join(word_aug_default.batch_corrupt(corrections, batch_prob=0.5)))) 55 | print("word_aug_custom: \n{}\n".format("\n".join(word_aug_custom.batch_corrupt(corrections)))) 56 | print("char_aug_default: \n{}\n".format("\n".join(char_aug_default.batch_corrupt(corrections, 
batch_prob=0.5)))) 57 | print("char_aug_custom: \n{}\n".format("\n".join(char_aug_custom.batch_corrupt(corrections)))) 58 | print("sbsc_default: \n{}\n".format("\n".join(sbsc_default.batch_corrupt(corrections)))) 59 | print("sbsc_custom: \n{}\n".format("\n".join(sbsc_custom.batch_corrupt(corrections)))) 60 | print("sbsc_english: {}".format("\n".join(sbsc_english.batch_corrupt(en_corrections)))) 61 | -------------------------------------------------------------------------------- /tests/test_evaluate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sage.evaluation import Scorer 4 | 5 | 6 | class TestEvaluationKit(unittest.TestCase): 7 | 8 | scorer = Scorer() 9 | 10 | sources = ["спел Кейс ее .", "спел Кейс ее ."] 11 | corrections = ["спелл кейс её !", "спелл кейс её !"] 12 | answers = ["спел кейс её .", "спелл Кейс ее !"] 13 | 14 | def test_scorer_errant_only(self): 15 | metrics = self.scorer.score(self.sources, self.corrections, self.answers, metrics=["errant"]) 16 | expected_metrics = { 17 | "CASE_Precision": 100.0, "CASE_Recall": 50.0, "CASE_F1": 66.67, 18 | "YO_Precision": 100.0, "YO_Recall": 50.0, "YO_F1": 66.67, 19 | "SPELL_Precision": 100.0, "SPELL_Recall": 50.0, "SPELL_F1": 66.67, 20 | "PUNCT_Precision": 100.0, "PUNCT_Recall": 50.0, "PUNCT_F1": 66.67 21 | } 22 | self.assertDictEqual(metrics, expected_metrics) 23 | 24 | def test_scorer_ruspelleval_only(self): 25 | metrics = self.scorer.score(self.sources, self.corrections, self.answers, metrics=["ruspelleval"]) 26 | self.assertDictEqual(metrics, {"Precision": 100.0, "Recall": 50.0, "F1": 66.67}) 27 | 28 | def test_scorer_errant_ruspelleval(self): 29 | metrics = self.scorer.score(self.sources, self.corrections, self.answers, metrics=["errant", "ruspelleval"]) 30 | expected_metrics = { 31 | "CASE_Precision": 100.0, "CASE_Recall": 50.0, "CASE_F1": 66.67, 32 | "YO_Precision": 100.0, "YO_Recall": 50.0, "YO_F1": 66.67, 33 | "SPELL_Precision": 100.0, "SPELL_Recall": 50.0, "SPELL_F1": 66.67, 34 | "PUNCT_Precision": 100.0, "PUNCT_Recall": 50.0, "PUNCT_F1": 66.67, 35 | "Precision": 100.0, "Recall": 50.0, "F1": 66.67 36 | } 37 | self.assertDictEqual(metrics, expected_metrics) 38 | 39 | def test_empty_errant(self): 40 | scorer = Scorer(False) 41 | self.assertRaises( 42 | AttributeError, 43 | scorer.score, 44 | **{"sources": self.sources, "corrections": self.corrections, "answers": self.answers, "metrics": ["errant"]} 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/test_ruspelleval.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sage.utils import load_available_dataset_from_hf, DatasetsAvailable 4 | from sage.evaluation.ruspelleval import evaluation 5 | 6 | 7 | class TestRuSpellEval(unittest.TestCase): 8 | 9 | def test_ruspelleval_edge_cases(self): 10 | ruspell_sources, ruspell_corrections = load_available_dataset_from_hf( 11 | DatasetsAvailable.RUSpellRU.name, for_labeler=True, split="test") 12 | coincide_metrics = evaluation(ruspell_sources, ruspell_corrections, ruspell_corrections) 13 | self.assertEqual(coincide_metrics, {"Precision": 100.0, "Recall": 100.0, "F1": 100.0}) 14 | loose_metrics = evaluation( 15 | ruspell_sources, ruspell_corrections, ruspell_sources[:-1] + [ruspell_corrections[-1]]) 16 | self.assertEqual(loose_metrics, {"Precision": 100.0, "Recall": 0.05, "F1": 0.1}) 17 | 18 | def 
test_ruspelleval_general_case(self): 19 | source = " ".join(["фотка", "классная", "кстате", "хоть", "и", "не", "по", "теме"]) 20 | correction = " ".join(["фотка", "классная", "кстати", "хоть", "и", "не", "по", "теме"]) 21 | answer = " ".join(["фотка", "классная", "кстати", "хотя", "не", "по", "теме"]) 22 | case1_metrics = evaluation([source], [correction], [answer]) 23 | self.assertEqual(case1_metrics, {"Precision": 50.0, "Recall": 100.0, "F1": 66.67}) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sage.utils import load_available_dataset_from_hf, DatasetsAvailable 4 | 5 | datasets_available = [dataset.name for dataset in DatasetsAvailable] 6 | 7 | DF2LENS = { 8 | "MultidomainGold": {"train": 3569, "test": 4106}, 9 | "RUSpellRU": {"train": 2000, "test": 2008}, 10 | "MedSpellchecker": {"test": 1054}, 11 | "GitHubTypoCorpusRu": {"test": 868}, 12 | "MultidomainGold_orth": {"train": 3571, "test": 4107}, 13 | "RUSpellRU_orth": {"train": 2000, "test": 2008}, 14 | "MedSpellchecker_orth": {"test": 1054}, 15 | "GitHubTypoCorpusRu_orth": {"test": 868}, 16 | } 17 | 18 | 19 | class TestUtils(unittest.TestCase): 20 | def test_load_datasets_from_hf(self): 21 | for dataset_name in datasets_available: 22 | splits = list(DF2LENS[dataset_name].keys()) 23 | for split in splits: 24 | dataset_split = load_available_dataset_from_hf(dataset_name, split=split, for_labeler=False) 25 | self.assertEqual(len(dataset_split), DF2LENS[dataset_name][split]) 26 | sources, corrections = load_available_dataset_from_hf(dataset_name, split=split, for_labeler=True) 27 | self.assertEqual(len(sources), DF2LENS[dataset_name][split]) 28 | self.assertEqual(len(corrections), DF2LENS[dataset_name][split]) 29 | dataset = load_available_dataset_from_hf(dataset_name, for_labeler=False) 30 | self.assertEqual(len(dataset), sum(DF2LENS[dataset_name].values())) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /wheels/augmentex-1.0.3-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-forever/sage/71736b6d14b1dd4a369ea9289eb2cac9a5a25f21/wheels/augmentex-1.0.3-py3-none-any.whl --------------------------------------------------------------------------------
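As a closing illustration (not a repository file), the pieces exercised by the tests above can be chained into a single round trip: corrupt clean sentences with SBSC, correct them with one of the pretrained correctors, and score the result with the ruspelleval metric. The clean sentences and the RUSpellRU-based config mirror tests/test_corruptors.py and tests/corruptor_api_unittests.py; the checkpoint choice and batch size of 1 follow tests/test_correctors.py, and downloading the model from the hub is assumed to succeed.

from sage.spelling_correction import AvailableCorrectors, RuM2M100ModelForSpellingCorrection
from sage.spelling_corruption import SBSCConfig, SBSCCorruptor
from sage.evaluation import Scorer

clean = [
    "Я пошёл домой. Больше мне тут делать нечего",
    "Заметьте, - не я это предложил.",
]

# Corrupt with error statistics estimated from the RUSpellRU train split.
corruptor = SBSCCorruptor.from_config(
    SBSCConfig(reference_dataset_name_or_path="RUSpellRU", reference_dataset_split="train"))
noisy = corruptor.batch_corrupt(clean)

# Correct with the smaller M2M100-based checkpoint; batch_correct returns a list
# of hypotheses per sentence, so keep the first one.
corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(AvailableCorrectors.m2m100_418M.value)
answers = [hypotheses[0] for hypotheses in corrector.batch_correct(noisy, 1)]

# Score the round trip: sources are the corrupted texts, corrections the clean originals.
print(Scorer().score(noisy, clean, answers, metrics=["ruspelleval"]))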