├── VERSION ├── omogre ├── accentuator │ ├── __init__.py │ ├── unk_vocab.py │ ├── unk_model.py │ ├── unk_reader.py │ ├── reader.py │ ├── tokenizer.py │ ├── model.py │ └── bert.py ├── transcriptor │ ├── __init__.py │ ├── transcriptor.py │ └── unk_words.py ├── downloader.py └── __init__.py ├── download_data.py ├── test.py ├── setup.py ├── ruslan_markup.py ├── .gitignore ├── README_eng.md ├── Wav2vec2_ru_ipa.ipynb ├── README.md └── LICENSE /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /omogre/accentuator/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from .model import Accentuator 4 | -------------------------------------------------------------------------------- /omogre/transcriptor/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from .transcriptor import Transcriptor 4 | -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from omogre import find_model 4 | from argparse import ArgumentParser 5 | 6 | parser = ArgumentParser(description="Download omogre model") 7 | 8 | parser.add_argument("--data_path", type=str, default=None, help="omogre model direcory") 9 | parser.add_argument("--file_name", type=str, default='accentuator_transcriptor_tiny', help="omogre model direcory") 10 | args = parser.parse_args() 11 | 12 | if __name__ == "__main__": 13 | path = find_model(file_name=args.file_name, cache_dir=args.data_path, reload=True) 14 | print('find_model', path) 15 | 16 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | from omogre import Accentuator, Transcriptor 2 | 3 | # данные будут скачаны в директорию 'omogre_data' 4 | 5 | transcriptor = Transcriptor(data_path='omogre_data') 6 | 7 | accentuator = Accentuator(data_path='omogre_data') 8 | 9 | sentence_list = ['А на камне том туристы понаписали всякой ху-у-у-лиганщины, типа: "Здесь бил Вася", "Коля + Света + сосед с четвертого этажа + ротвейлер тети Маши + взвод лейтенанта Миши + Белая лошадь (виски такое) = любовь".'] 10 | 11 | #['А на камне том туристы понаписали всякой ху-у-у-лиганщины, типа: "Здесь бил Вася", "Коля ! Света ! сосед с четвертого этажа ! ротвейлер тети Маши ! взвод лейтенанта Миши ! Белая лошадь (виски такое) = любовь".'] 12 | 13 | #['стены замка'] 14 | 15 | print('accentuator()', accentuator(sentence_list)) 16 | 17 | print('transcriptor()', transcriptor(sentence_list)) 18 | 19 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list)) 20 | print('accentuator.accentuate', accentuator.accentuate(sentence_list)) 21 | 22 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list)) 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import find_packages, setup 3 | 4 | setup( 5 | name="omogre", 6 | version="0.1.0", 7 | author="omogr", 8 | author_email="omogrus@ya.ru", 9 | description="Russian accentuator and IPA transcriptor", 10 | long_description=open("README.md", "r", encoding='utf-8').read(), 11 | long_description_content_type="text/markdown", 12 | keywords='Russian accentuator IPA transcriptor', 13 | license='CC BY-NC-SA 4.0', 14 | url="https://github.com/omogr/omogre", 15 | packages=find_packages(), 16 | 17 | install_requires=['torch>=0.4.1', 18 | 'numpy', 19 | 'requests', 20 | 'tqdm'], 21 | 22 | classifiers=[ 23 | 'Programming Language :: Python :: 
3', 24 | 'License :: Free for non-commercial use', 25 | 'Natural Language :: Russian', 26 | 'Topic :: Text Processing :: Linguistic', 27 | 'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis', 28 | 'Topic :: Multimedia :: Sound/Audio :: Speech', 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /omogre/downloader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import logging 4 | import os 5 | import tarfile 6 | import tempfile 7 | import sys 8 | import requests 9 | from tqdm import tqdm 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def download_model(cache_dir: str, file_name: str): 14 | model_url: str = "https://huggingface.co/omogr/omogre/resolve/main/%s.gz?download=true"%file_name 15 | 16 | etag = None 17 | try: 18 | response = requests.head(model_url, allow_redirects=True) 19 | if response.status_code != 200: 20 | raise EnvironmentError('Cannot load model, response.status_code %d'%response.status_code) 21 | else: 22 | etag = response.headers.get("ETag") 23 | except EnvironmentError: 24 | etag = None 25 | 26 | if etag is None: 27 | raise EnvironmentError('Cannot load model, etag error') 28 | 29 | with tempfile.TemporaryFile() as temp_file: # NamedTemporaryFile f.name 30 | logger.info("model not found in cache, downloading to temporary file") # , model_url) #, temp_file.name) 31 | 32 | req = requests.get(model_url, stream=True) 33 | content_length = req.headers.get('Content-Length') 34 | total = int(content_length) if content_length is not None else None 35 | progress = tqdm(unit="B", total=total) 36 | for chunk in req.iter_content(chunk_size=1024): 37 | if chunk: # filter out keep-alive new chunks 38 | progress.update(len(chunk)) 39 | temp_file.write(chunk) 40 | progress.close() 41 | 42 | # we are processing the file before closing it, so flush to avoid truncation 43 | temp_file.flush() 44 | # sometimes fileobj starts at the current 
position, so go to the start 45 | temp_file.seek(0) 46 | etag_file_name = os.path.join(cache_dir, 'etag') 47 | try: 48 | logger.info("model archive extractall to %s", cache_dir) #, temp_file.name) 49 | with tarfile.open(fileobj=temp_file, mode='r:gz') as archive: 50 | archive.extractall(cache_dir) 51 | 52 | with open(etag_file_name, mode='w', encoding='utf-8') as fout: 53 | print(model_url, file=fout) 54 | print(etag, file=fout) 55 | 56 | except: 57 | if os.path.isfile(etag_file_name): 58 | os.remove(etag_file_name) 59 | return cache_dir 60 | 61 | 62 | if __name__ == "__main__": 63 | path: str = "omogre_data" 64 | if not os.path.exists(path): 65 | os.mkdir(path) 66 | download_model(cache_dir=path, file_name='accentuator_transcriptor_tiny') 67 | -------------------------------------------------------------------------------- /omogre/transcriptor/transcriptor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import pickle 5 | from .unk_words import UnkWords 6 | 7 | vowels = 'аеёиоуыэюяАЕЁИОУЫЭЮЯ' 8 | 9 | def get_single_vowel(src_str: str) -> int: 10 | pos = None 11 | for indx, tc in enumerate(src_str): 12 | if tc in vowels: 13 | if pos is not None: 14 | return None 15 | pos = indx 16 | return pos 17 | 18 | 19 | def normalize_accent(src_str: str) -> str: 20 | # .casefold() ? 
21 | if '+' in src_str: 22 | return src_str 23 | if 'ё' in src_str: 24 | return src_str.replace('ё', '+ё') 25 | 26 | vowel_pos = get_single_vowel(src_str) 27 | if vowel_pos is None: 28 | return src_str 29 | return src_str[:vowel_pos] + '+' + src_str[vowel_pos:] 30 | 31 | 32 | def get_g2p_without_accent(grapheme_to_phoneme_vocab: dict) -> dict: 33 | grapheme_phoneme = {} 34 | grapheme_freq = {} 35 | for grapheme_with_accent, phoneme_vars in grapheme_to_phoneme_vocab.items(): 36 | grapheme = grapheme_with_accent.replace('+', '') 37 | for phoneme, freq in phoneme_vars.items(): 38 | if (grapheme_freq.get(grapheme, 0) < freq): 39 | grapheme_phoneme[grapheme] = phoneme 40 | grapheme_freq[grapheme] = freq 41 | return grapheme_phoneme 42 | 43 | 44 | def get_max_freq_phoneme(key: str, vocab: dict) -> str: 45 | if key not in vocab: 46 | return None 47 | max_freq = -1 48 | max_phoneme = None 49 | 50 | for t_phoneme, freq in vocab[key].items(): 51 | if freq > max_freq: 52 | max_freq = freq 53 | max_phoneme = t_phoneme 54 | return max_phoneme 55 | 56 | 57 | class Transcriptor: 58 | def __init__(self, data_path: str): 59 | transcriptor_data_path = os.path.join(data_path, 'word_vocab.pickle') 60 | with open(transcriptor_data_path, "rb") as finp: 61 | self.grapheme_to_phoneme_vocab = pickle.load(finp) 62 | 63 | self.g2p_without_accent = get_g2p_without_accent(self.grapheme_to_phoneme_vocab) 64 | self.unk_words = UnkWords(data_path=data_path) 65 | 66 | def transcribe(self, src_str: str) -> str: 67 | word_str = src_str.casefold() 68 | word_str_norm = normalize_accent(word_str) 69 | if '+' in word_str_norm: 70 | max_phoneme = get_max_freq_phoneme(word_str_norm, self.grapheme_to_phoneme_vocab) 71 | if max_phoneme is not None: 72 | return max_phoneme 73 | 74 | if word_str in self.g2p_without_accent: 75 | return self.g2p_without_accent[word_str] 76 | 77 | return self.unk_words.transcribe(word_str) 78 | 79 | -------------------------------------------------------------------------------- 
/ruslan_markup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from omogre import Transcriptor, find_model 4 | import time 5 | from argparse import ArgumentParser 6 | 7 | parser = ArgumentParser(description="Accentuate and transcribe natasha ruslan markup") 8 | 9 | parser.add_argument("--data_path", type=str, default=None, help="Omogre model direcory") 10 | parser.add_argument("--head", type=int, default=None, help="Process only head of file") 11 | args = parser.parse_args() 12 | 13 | transcriptor = Transcriptor(data_path=args.data_path) 14 | 15 | def save_markup(src_markup_parts: list, sentence_list: list, fout_name: str): 16 | assert len(src_markup_parts) == len(sentence_list) 17 | 18 | with open(fout_name, 'w', encoding='utf-8') as fout: 19 | for parts, out_sent in zip(src_markup_parts, sentence_list): 20 | parts[1] = out_sent 21 | print('|'.join(parts), file=fout) 22 | 23 | 24 | def process(dataset_name: str): 25 | print('dataset', dataset_name) 26 | finp_name = 'natasha_ruslan_markup/%s.txt'%dataset_name 27 | 28 | sentence_list = [] 29 | src_markup_parts = [] 30 | print('reading', finp_name) 31 | 32 | with open(finp_name, 'r', encoding='utf-8') as finp: 33 | for entry in finp: 34 | parts = entry.strip().split('|') 35 | assert len(parts) >= 2 36 | 37 | sentence_list.append(parts[1].replace('+', '')) 38 | src_markup_parts.append(parts) 39 | if args.head is not None: 40 | if len(src_markup_parts) >= args.head: 41 | break 42 | 43 | start = time.time() 44 | output_sentences = transcriptor.accentuate(sentence_list) 45 | dt = time.time() - start 46 | print('Accentuated', dataset_name, len(sentence_list), 'sentences, dtime %.1f s'%dt) 47 | 48 | fout_name = 'natasha_ruslan_markup/%s.accentuate'%dataset_name 49 | save_markup(src_markup_parts, output_sentences, fout_name) 50 | 51 | start = time.time() 52 | output_sentences = transcriptor.transcribe(sentence_list) 53 | dt = time.time() - start 54 | 
print('Transcribed', dataset_name, len(sentence_list), 'sentences, dtime %.1f s'%dt) 55 | 56 | fout_name = 'natasha_ruslan_markup/%s.transcribe'%dataset_name 57 | save_markup(src_markup_parts, output_sentences, fout_name) 58 | 59 | 60 | if __name__ == '__main__': 61 | # Хабр: Open Source синтез речи SOVA 62 | # https://habr.com/ru/companies/ashmanov_net/articles/528296/ 63 | 64 | # ruslan 65 | # http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar 66 | # natasha 67 | # http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar 68 | 69 | 70 | find_model(cache_dir='natasha_ruslan_markup', file_name='natasha_ruslan') 71 | 72 | for dataset_name in ['natasha', 'ruslan']: 73 | process(dataset_name) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /omogre/accentuator/unk_vocab.py: -------------------------------------------------------------------------------- 1 | 2 | import bisect 3 | import pickle 4 | import os 5 | 6 | def common_left(str1: str, str2: str) -> int: 7 | if len(str1) < len(str2): 8 | for index, tc in enumerate(str1): 9 | if str2[index] != tc: 10 | return index 11 | return len(str1) 12 | 13 | for index, tc in enumerate(str2): 14 | if str1[index] != tc: 15 | return index 16 | return len(str2) 17 | 18 | 19 | def get_neib_pos(key_pos_list: list, key: str, pos: int, min_len: int = 3) -> list: 20 | result = [pos] 21 | len1 = 0 22 | len2 = 0 23 | len3 = common_left(key, key_pos_list[pos][0]) 24 | max_len = len3 25 | 26 | tp1 = pos - 1 27 | if tp1 >= 0: 28 | len1 = common_left(key, key_pos_list[tp1][0]) 29 | max_len = max(max_len, len1) 30 | result.append(tp1) 31 | 32 | tp2 = pos + 1 33 | if tp2 < len(key_pos_list): 34 | len2 = common_left(key, key_pos_list[tp2][0]) 35 | max_len = max(max_len, len2) 36 | result.append(tp2) 37 | 38 | if max_len < min_len: 39 | return result 40 | 41 | if len1 >= max_len: 42 | result.append(tp1) 43 | tp = pos - 2 44 | while tp >= 0: 45 | len1 = common_left(key, 
key_pos_list[tp][0]) 46 | if len1 < max_len: 47 | break 48 | result.append(tp) 49 | tp -= 1 50 | 51 | if len2 >= max_len: 52 | result.append(tp2) 53 | tp = pos + 2 54 | while tp < len(key_pos_list): 55 | len1 = common_left(key, key_pos_list[tp][0]) 56 | if len1 < max_len: 57 | break 58 | result.append(tp) 59 | tp += 1 60 | return result 61 | 62 | 63 | class UnkVocab: 64 | def __init__(self, data_path: str, encoding: str = 'utf-8'): 65 | unk_file = os.path.join(data_path, 'unk_vocab.pickle') 66 | 67 | with open(unk_file, 'rb') as finp: 68 | self.acc_vocab, self.all_tails = pickle.load(finp) 69 | 70 | def cmp_form_norm(self, form: str, tpos: int, res: list): 71 | norm = self.acc_vocab[tpos][0] 72 | num_equ_chars: int = common_left(form, norm) 73 | key = (form[num_equ_chars:], norm[num_equ_chars:]) 74 | 75 | if key in self.all_tails: 76 | res.append( (num_equ_chars, self.acc_vocab[tpos]) ) 77 | 78 | def search_neibs(self, text: str) -> list: 79 | key = (text, '') 80 | 81 | all_len = len(self.acc_vocab) 82 | if all_len < 1: 83 | return [] 84 | if key < self.acc_vocab[0]: 85 | return get_neib_pos(self.acc_vocab, text, 0) 86 | 87 | try_pos = bisect.bisect_left(self.acc_vocab, key) 88 | 89 | return get_neib_pos(self.acc_vocab, text, try_pos) 90 | 91 | def get_neibs(self, text: str) -> list: 92 | res = [] 93 | for tpos in self.search_neibs(text): 94 | self.cmp_form_norm(text, tpos, res) 95 | return res 96 | 97 | def get_acc_pos(self, text: str) -> int: 98 | result = self.get_neibs(text) 99 | if len(result) < 1: 100 | return -1 101 | result.sort() 102 | cl, key_acc = result[-1] 103 | acc_pos_list = [int(tp) for tp in key_acc[1].split(',')] 104 | if len(acc_pos_list) < 1: 105 | return -1 106 | acc_pos = acc_pos_list[-1] 107 | if acc_pos < 0: 108 | return -1 109 | if acc_pos >= cl: 110 | return -1 111 | return acc_pos 112 | 113 | -------------------------------------------------------------------------------- /omogre/accentuator/unk_model.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | import math 9 | import os 10 | import random 11 | import sys 12 | import numpy as np 13 | import torch 14 | 15 | from .bert import BertForTokenClassification 16 | 17 | from .unk_reader import InfBatchFromSentenceList 18 | from .unk_vocab import UnkVocab 19 | 20 | 21 | def norm_word(tword: str) -> str: 22 | return tword.casefold().replace('ё', 'е').replace(' ', '!') 23 | 24 | 25 | def check_ee(tword: str) -> int: 26 | acc_pos = tword.find('ё') 27 | if acc_pos >= 0: 28 | return acc_pos 29 | 30 | acc_pos = tword.find('Ё') 31 | if acc_pos >= 0: 32 | return acc_pos 33 | return -1 34 | 35 | 36 | class UnkModel: 37 | def __init__(self, data_path: str, device_name: str = None): 38 | model_data_path = os.path.join(data_path, 'unk_model') 39 | 40 | if device_name is None: 41 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | else: 43 | self.device = torch.device(device_name) 44 | 45 | self.model = BertForTokenClassification.from_pretrained(model_data_path, num_labels=1, cache_dir=None) 46 | assert self.model 47 | self.unk_vocab = UnkVocab(data_path) 48 | self.model.eval() 49 | self.model.to(self.device) 50 | self.error_counter = 0 51 | 52 | #def seed(self, seed_value: int): 53 | # random.seed(seed_value) 54 | # np.random.seed(seed_value) 55 | # torch.manual_seed(seed_value) 56 | 57 | def process_model_batch(self, input_ids, attention_mask, batch_text, all_res): 58 | with torch.no_grad(): 59 | input_ids = input_ids.to(self.device) 60 | attention_mask = attention_mask.to(self.device) 61 | 62 | logits = self.model(input_ids, attention_mask=attention_mask) 63 | logits = logits.squeeze(-1) 64 | logits = logits.detach().cpu().tolist() # cpu(). 
65 | input_ids = input_ids.detach().cpu().tolist() # cpu(). 66 | 67 | for batch_indx, (t_logits, t_input_ids) in enumerate(zip(logits, input_ids)): 68 | word_text = batch_text[batch_indx] 69 | if len(word_text) < 2: 70 | all_res.append(word_text) 71 | continue 72 | 73 | max_pos = 1 74 | max_logit = t_logits[max_pos] 75 | for token_indx in range(len(word_text)): 76 | 77 | if t_logits[token_indx+1] > max_logit: 78 | max_pos = token_indx 79 | max_logit = t_logits[token_indx+1] 80 | ct = (word_text, max_pos) 81 | all_res.append(ct) 82 | 83 | def get_acc_pos(self, unk_word:str) -> int: 84 | acc_pos = self.unk_vocab.get_acc_pos(unk_word) 85 | if acc_pos >= 0: 86 | return acc_pos 87 | 88 | file_reader = InfBatchFromSentenceList([unk_word]) 89 | 90 | batch = file_reader.get_next_batch() 91 | if batch is None: 92 | return -1 93 | (all_input_ids, all_attention_mask, sentence_data) = batch 94 | sum_batch = [] 95 | self.process_model_batch(all_input_ids, all_attention_mask, sentence_data, sum_batch) 96 | if len(sum_batch) != 1: 97 | return -1 98 | (word_text, max_pos) = sum_batch[0] 99 | if word_text != unk_word: 100 | return -1 101 | return max_pos 102 | 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /omogre/accentuator/unk_reader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import collections 4 | import random 5 | 6 | pad_token_id = 0 7 | bos_token_id = 1 8 | eos_token_id = 2 9 | alp_token_id = 3 10 | alphabet = ' абвгдеёжзийклмнопрстуфхцчшщъыьэюя' #АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ' 11 | 12 | TokenSpan = collections.namedtuple( # pylint: disable=invalid-name 13 | "TokenSpan", ["text", "label", "punct", "span"]) 14 | 15 | MAX_BATCH_TOKEN_NUM = 6000 # 250 * 24 16 | 17 | def get_tokens(sentence: str): 18 | tokens = [bos_token_id] 19 | 20 | err_cnt = 0 21 | for tc in sentence: 22 | if tc == '+': 23 | continue 24 | else: 25 | tid = alphabet.find(tc) 26 | if tid < 0: 27 | err_cnt += 1 28 | tid = pad_token_id 29 | else: 30 | tid += alp_token_id 31 | tokens.append(tid) 32 | tokens.append(eos_token_id) 33 | return tokens, err_cnt 34 | 35 | 36 | class InfBatchFromSentenceList: 37 | def __init__(self, sentence_list: list): 38 | self.all_entries = [] 39 | 40 | for line_indx, sentence in enumerate(sentence_list): 41 | tokens, err_cnt = get_tokens(sentence.casefold()) 42 | if err_cnt == 0: 43 | ct = (tokens, sentence) 44 | self.all_entries.append(ct) 45 | 46 | self.file_pos = -1 47 | self.iter = 0 48 | assert len(self.all_entries) > 0 49 | 50 | def is_first_iter(self): 51 | if self.iter > 0: 52 | return False 53 | if (self.file_pos + 1) >= len(self.all_entries): 54 | return False 55 | return True 56 | 57 | def get_next_pos(self): 58 | self.file_pos += 1 59 | if self.file_pos >= len(self.all_entries): 60 | self.iter += 1 61 | self.file_pos = 0 62 | 63 | def get_next_batch(self, is_test: bool = True): 64 | if is_test: 65 | assert self.is_first_iter() 66 | 67 | max_length = 1 68 | batch_data = [] 69 | sentence_data = [] 70 | while True: 71 | self.get_next_pos() 72 | if is_test: 73 | if self.iter > 0: 74 | break 
75 | sentence_pos = self.file_pos 76 | input_ids, sentence = self.all_entries[sentence_pos] 77 | len_input_ids = len(input_ids) 78 | max_length = max(max_length, len_input_ids) 79 | 80 | token_cnt = max_length * (1 + len(sentence_data)) 81 | if token_cnt > MAX_BATCH_TOKEN_NUM: 82 | break 83 | batch_data.append((input_ids)) 84 | sentence_data.append(sentence) 85 | 86 | if len(batch_data) < 1: 87 | return None 88 | 89 | all_input_ids = [] 90 | all_attention_mask = [] 91 | 92 | for input_ids in batch_data: 93 | len_input_ids = len(input_ids) 94 | attention_mask = [1] * len_input_ids 95 | assert len(input_ids) <= max_length 96 | if len_input_ids < max_length: 97 | # Pad input_ids and attention_mask to max length 98 | padding_length = max_length - len_input_ids 99 | input_ids += [pad_token_id] * padding_length 100 | attention_mask += [0] * padding_length 101 | 102 | all_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) 103 | all_attention_mask.append(torch.tensor(attention_mask, dtype=torch.long)) 104 | 105 | all_input_ids = torch.stack(all_input_ids) 106 | all_attention_mask = torch.stack(all_attention_mask) 107 | batch = (all_input_ids, all_attention_mask, sentence_data) 108 | return batch 109 | 110 | 111 | if __name__ == '__main__': 112 | fr = InfBatchFromSentenceList(['квазисублимирующие']) 113 | 114 | x = fr.get_next_batch() 115 | print('get_next_batch', x) 116 | -------------------------------------------------------------------------------- /omogre/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import sys 5 | from .transcriptor import Transcriptor as TranscriptorImpl 6 | from .accentuator import Accentuator as AccentuatorImpl 7 | 8 | INITIAL_MODEL = 'accentuator_transcriptor_tiny' 9 | 10 | def punctuation_filter(input_string: str, punct: str=None): 11 | output_string = [] 12 | is_stace = False 13 | for tc in input_string: 14 | if punct is None or tc in punct: 15 | 
output_string.append(tc) 16 | else: 17 | if not is_stace: 18 | output_string.append(' ') 19 | is_stace = True 20 | return ''.join(output_string) 21 | 22 | 23 | def find_model(file_name: str = INITIAL_MODEL, cache_dir: str = None, download: bool = True, reload=False): 24 | from pathlib import Path 25 | 26 | if cache_dir is None: 27 | try: 28 | omogr_cache = Path(os.getenv('OMOGR_CACHE', Path.home() / '.omogr_data')) 29 | except (AttributeError, ImportError): 30 | omogr_cache = os.getenv('OMOGR_CACHE', os.path.join(os.path.expanduser("~"), '.omogr_data')) 31 | 32 | cache_dir = omogr_cache 33 | 34 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 35 | cache_dir = str(cache_dir) 36 | 37 | if not cache_dir: 38 | raise EnvironmentError('Cannot find OMOGR_CACHE path') 39 | 40 | if os.path.exists(cache_dir): 41 | if os.path.isdir(cache_dir): 42 | etag_file_name = os.path.join(cache_dir, 'etag') 43 | if os.path.isfile(etag_file_name): 44 | if not reload: 45 | return cache_dir 46 | 47 | if not download: 48 | raise EnvironmentError('Cannot find model data') 49 | 50 | print('data_path', cache_dir, file=sys.stderr) 51 | 52 | if not os.path.exists(cache_dir): 53 | os.makedirs(cache_dir) 54 | 55 | if not os.path.isdir(cache_dir): 56 | raise EnvironmentError('Cannot create directory %s'%cache_dir) 57 | 58 | from .downloader import download_model 59 | return download_model(cache_dir, file_name=file_name) 60 | 61 | 62 | class Transcriptor: 63 | def __init__(self, data_path: str = None, download: bool = True, device_name:str = None, punct:str = '.,!?'): 64 | loaded_data_path = find_model(file_name=INITIAL_MODEL, 65 | cache_dir=data_path, download=download) 66 | 67 | self.punct = punct 68 | transcriptor_data_path = os.path.join(loaded_data_path, 'transcriptor/') 69 | self.transcriptor = TranscriptorImpl(data_path=transcriptor_data_path) 70 | 71 | accentuator_data_path = os.path.join(loaded_data_path, 'accentuator/') 72 | self.accentuator = 
AccentuatorImpl(data_path=accentuator_data_path, device_name=device_name) 73 | 74 | def accentuate(self, text): 75 | return self.accentuator.accentuate(text) 76 | 77 | def transcribe(self, sentence_list: list) -> list: 78 | sentence_word_list = self.accentuator.accentuate_by_words(sentence_list) 79 | transcribed_sentence_list = [] 80 | for t_sentence in sentence_word_list: 81 | transcribed_sentence = [] 82 | for t_punct, t_word in t_sentence: 83 | if t_punct: 84 | transcribed_sentence.append(punctuation_filter(t_punct, punct=self.punct)) 85 | if t_word: 86 | transcribed_sentence.append(self.transcriptor.transcribe(t_word)) 87 | transcribed_sentence_list.append(''.join(transcribed_sentence)) 88 | return transcribed_sentence_list 89 | 90 | def __call__(self, sentence_list: list) -> list: 91 | return self.transcribe(sentence_list) 92 | 93 | 94 | class Accentuator: 95 | def __init__(self, data_path: str = None, download: bool = True, device_name:str = None): 96 | loaded_data_path = find_model(file_name=INITIAL_MODEL, 97 | cache_dir=data_path, download=download) 98 | 99 | accentuator_data_path = os.path.join(loaded_data_path, 'accentuator/') 100 | self.accentuator = AccentuatorImpl(data_path=accentuator_data_path, device_name=device_name) 101 | 102 | def accentuate(self, text): 103 | return self.accentuator.accentuate(text) 104 | 105 | def __call__(self, text): 106 | return self.accentuate(text) -------------------------------------------------------------------------------- /README_eng.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Omogre 5 | 6 | ## Russian Accentuator and IPA Transcriptor 7 | 8 | A library for [`Python 3`](https://www.python.org/). Automatic stress placement and [IPA transcription](https://en.wikipedia.org/wiki/International_Phonetic_Alphabet) for the Russian language. 
9 | 10 | ## Dependencies 11 | 12 | Installing the library will also install [`Pytorch`](https://pytorch.org/) and [`Numpy`](https://numpy.org/). Additionally, for model downloading, [`tqdm`](https://tqdm.github.io/) and [`requests`](https://pypi.org/project/requests/) will be installed. 13 | 14 | ## Installation 15 | 16 | ### Using GIT 17 | 18 | ```bash 19 | pip install git+https://github.com/omogr/omogre.git 20 | ``` 21 | 22 | ### Using pip 23 | 24 | Download the code from [GitHub](https://github.com/omogr/omogre). In the directory containing [`setup.py`](https://github.com/omogr/omogre/blob/main/setup.py), run: 25 | 26 | ```bash 27 | pip install -e . 28 | ``` 29 | 30 | ### Manually 31 | 32 | Download the code from [GitHub](https://github.com/omogr/omogre). Install [`Pytorch`](https://pytorch.org/), [`Numpy`](https://numpy.org/), [`tqdm`](https://tqdm.github.io/), and [`requests`](https://pypi.org/project/requests/). Run [`test.py`](https://github.com/omogr/omogre/blob/main/test.py). 33 | 34 | ## Model downloading 35 | 36 | By default, data for models will be downloaded on the first run of the library. The script [`download_data.py`](https://github.com/omogr/omogre/blob/main/download_data.py) can also be used to download this data. 37 | 38 | You can specify a path where the model data should be stored. If data already exists in this directory, it won't be downloaded again. 39 | 40 | ## Example 41 | 42 | Script [`test.py`](https://github.com/omogr/omogre/blob/main/test.py). 
43 | 
44 | ```python
45 | from omogre import Accentuator, Transcriptor
46 | 
47 | # Data will be downloaded to the 'omogre_data' directory
48 | transcriptor = Transcriptor(data_path='omogre_data')
49 | accentuator = Accentuator(data_path='omogre_data')
50 | 
51 | sentence_list = ['стены замка']
52 | 
53 | print('transcriptor', transcriptor(sentence_list))
54 | print('accentuator', accentuator(sentence_list))
55 | 
56 | # Alternative call methods, differing only in notation
57 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list))
58 | print('accentuator.accentuate', accentuator.accentuate(sentence_list))
59 | 
60 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list))
61 | ```
62 | 
63 | ## Class Parameters
64 | 
65 | ### Transcriptor
66 | 
67 | All initialization parameters for the class are optional.
68 | 
69 | ```python
70 | class Transcriptor(data_path: str = None,
71 |                    download: bool = True,
72 |                    device_name: str = None,
73 |                    punct: str = '.,!?')
74 | ```
75 | 
76 | - `data_path`: Directory where the model should be located. If omitted, `$OMOGR_CACHE` is used, or `~/.omogr_data` if that variable is not set.
77 | - `device_name`: Parameter defining GPU usage. Corresponds to the initialization parameter of [torch.device](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Valid values include `"cpu"`, `"cuda"`, `"cuda:0"`, etc. Defaults to `"cuda"` if GPU is available, otherwise `"cpu"`.
78 | - `punct`: String of non-letter characters to be carried over from the source text to the transcription. Default is `'.,!?'`.
79 | - `download`: Whether to download the model from the internet if not found in `data_path`. Default is `True`.
80 | 
81 | Class methods:
82 | 
83 | ```python
84 | accentuate(sentence_list: list) -> list
85 | transcribe(sentence_list: list) -> list
86 | ```
87 | 
88 | `accentuate` places stresses, `transcribe` performs transcription. Both methods take a list of strings and return a list of strings.
89 | 90 | ### Accentuator 91 | 92 | The `Accentuator` class for stress placement is identical to the `Transcriptor` in terms of stress functionality, except it doesn't load transcription data, reducing initialization time and memory usage. 93 | 94 | All initialization parameters are optional, with the same meanings as for `Transcriptor`. 95 | 96 | ```python 97 | class Accentuator(data_path: str = None, 98 | download: bool = True, 99 | device_name: str = None) 100 | ``` 101 | 102 | - `data_path`: Directory where the model should be located. 103 | - `device_name`: Parameter for GPU usage. See above for details. 104 | - `download`: Whether to download the model if not found. Default is `True`. 105 | 106 | Class method: 107 | 108 | ```python 109 | accentuate(sentence_list: list) -> list 110 | ``` 111 | 112 | ## Usage Examples 113 | 114 | ### Markup files for acoustic corpora 115 | 116 | The script [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) places stresses and generates transcriptions for markup files of the acoustic corpora [`RUSLAN`](https://ruslan-corpus.github.io/) ([`RUSLAN with manually accentuated markup`](http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar)) and [`NATASHA`](http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar). 117 | 118 | These markup files already contain [manually placed stresses](https://habr.com/ru/companies/ashmanov_net/articles/528296/). 119 | 120 | The script [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) generates its own stress placement for these files. The original manual markup is not used during testing and was not used in training, so it can serve to evaluate the accuracy of stress placement. 121 | 122 | ### Speech Synthesis 123 | 124 | Accentuation and transcription can be useful for speech synthesis. The [`Colab` notebook](https://github.com/omogr/omogre/blob/main/XTTS_ru_ipa.ipynb) contains an example of running the [`XTTS`](https://github.com/coqui-ai/TTS) model trained on transcription for the Russian language.
The model was trained on the [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru) datasets. 125 | The model weights can be downloaded from [`Hugging Face`](https://huggingface.co/omogr/XTTS-ru-ipa). 126 | 127 | ### Extraction of transcription from audio files 128 | 129 | Accentuation and transcription can be useful for acoustic corpus analysis. The [`Colab` notebook](https://github.com/omogr/omogre/blob/main/Wav2vec2_ru_ipa.ipynb) contains an example of running the wav2vec2-lv-60-espeak-cv-ft model fine-tuned on transcriptions of [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru). 130 | 131 | ## Context Awareness and Other Features 132 | 133 | ### Stresses 134 | 135 | Stresses are placed considering context. If very long strings are encountered (for the current model, more than 510 tokens), context won't be considered for these. Stresses in these strings will be placed only where possible without context. 136 | 137 | Stresses are also placed in one-syllable words, which might look unusual but simplifies subsequent transcription determination. 138 | 139 | ### Transcription 140 | 141 | During transcription generation, extraneous characters are filtered out. Non-letter characters that are not filtered can be specified by a parameter. By default, four punctuation marks (`.,!?`) are not filtered. Transcription is determined word by word, without context. The following symbols are used for transcription: 142 | 143 | ``` 144 | ʲ`ɪətrsɐnjvmapkɨʊleɫdizofʂɕbɡxːuʐæɵʉɛ 145 | ``` 146 | 147 | ## Feedback 148 | Email for questions, comments and suggestions - `omogrus@ya.ru`.
149 | 150 | ## License 151 | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) 152 | 153 | (translated by grok-2-2024-08-13) 154 | -------------------------------------------------------------------------------- /Wav2vec2_ru_ipa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "# Wav2vec2 example\n", 21 | "Accentuation and transcription can be useful for acoustic corpus analysis. This notebook contains an example of running the wav2vec2-lv-60-espeak-cv-ft model fine-tuned on transcriptions of [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru).\n" 22 | ], 23 | "metadata": { 24 | "id": "xEmXMe9SEjK2" 25 | } 26 | }, 27 | { 28 | "cell_type": "code", 29 | "source": [ 30 | "# @title Download model from Hugging Face\n", 31 | "!mkdir model\n", 32 | "!git clone https://huggingface.co/omogr/wav2vec2-lv-60-ru-ipa model" 33 | ], 34 | "metadata": { 35 | "colab": { 36 | "base_uri": "https://localhost:8080/" 37 | }, 38 | "id": "g5ZEeq1LMH2E", 39 | "outputId": "ff112dc9-6292-4d44-b0c1-1227f760da7b" 40 | }, 41 | "execution_count": 2, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "name": "stdout", 46 | "text": [ 47 | "Cloning into 'model'...\n", 48 | "remote: Enumerating objects: 20, done.\u001b[K\n", 49 | "remote: Counting objects: 100% (16/16), done.\u001b[K\n", 50 | "remote: Compressing objects: 100% (16/16), done.\u001b[K\n", 51 | "remote: Total 20 (delta 2), reused 0 (delta 0), pack-reused 4 (from 1)\u001b[K\n", 52 | "Unpacking objects: 100% (20/20), 304.43 KiB | 6.34 MiB/s, done.\n" 53 | ] 54 | } 55 | ] 56 | }, 57 | { 58 | "cell_type":
"code", 59 | "source": [ 60 | "import os\n", 61 | "import sys\n", 62 | "import numpy as np\n", 63 | "import torch\n", 64 | "import torchaudio\n", 65 | "import random\n", 66 | "\n", 67 | "from transformers import Wav2Vec2CTCTokenizer\n", 68 | "from transformers import Wav2Vec2FeatureExtractor\n", 69 | "from transformers import Wav2Vec2Processor\n", 70 | "from transformers import Wav2Vec2ForCTC\n", 71 | "\n", 72 | "MODEL_PATH = 'model'\n", 73 | "\n", 74 | "tokenizer = Wav2Vec2CTCTokenizer(\n", 75 | " \"model/vocab.json\",\n", 76 | " bos_token=\"\",\n", 77 | " eos_token=\"\",\n", 78 | " unk_token=\"\",\n", 79 | " pad_token=\"\",\n", 80 | " word_delimiter_token=\"|\",\n", 81 | " do_lower_case=False\n", 82 | ")\n", 83 | "\n", 84 | "# @title Load model and processor\n", 85 | "feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_PATH)\n", 86 | "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n", 87 | "\n", 88 | "model = Wav2Vec2ForCTC.from_pretrained(\n", 89 | " MODEL_PATH,\n", 90 | " attention_dropout=0.0,\n", 91 | " hidden_dropout=0.0,\n", 92 | " feat_proj_dropout=0.0,\n", 93 | " mask_time_prob=0.0,\n", 94 | " layerdrop=0.0,\n", 95 | " gradient_checkpointing=True,\n", 96 | " ctc_loss_reduction=\"mean\",\n", 97 | " ctc_zero_infinity=True,\n", 98 | " bos_token_id=processor.tokenizer.bos_token_id,\n", 99 | " eos_token_id=processor.tokenizer.eos_token_id,\n", 100 | " pad_token_id=processor.tokenizer.pad_token_id,\n", 101 | " vocab_size=len(processor.tokenizer.get_vocab()),\n", 102 | " )\n", 103 | "\n", 104 | "def process_wav_file(wav_file_path: str):\n", 105 | " # read soundfiles\n", 106 | " waveform, sample_rate = torchaudio.load(wav_file_path)\n", 107 | "\n", 108 | " bundle_sample_rate = 16000\n", 109 | " if sample_rate != bundle_sample_rate:\n", 110 | " waveform = torchaudio.functional.resample(waveform, sample_rate, bundle_sample_rate)\n", 111 | "\n", 112 | " # tokenize\n", 113 | " input_values = processor(waveform, 
sampling_rate=16000, return_tensors=\"pt\").input_values\n", 114 | " # retrieve logits\n", 115 | " with torch.no_grad():\n", 116 | " logits = model(input_values.view(1, -1)).logits\n", 117 | " # take argmax and decode\n", 118 | " predicted_ids = torch.argmax(logits, dim=-1)\n", 119 | " return processor.batch_decode(predicted_ids)" 120 | ], 121 | "metadata": { 122 | "colab": { 123 | "base_uri": "https://localhost:8080/" 124 | }, 125 | "id": "l1UCVSToMGVr", 126 | "outputId": "ba94419d-8401-4511-f33b-879f1d109777" 127 | }, 128 | "execution_count": 5, 129 | "outputs": [ 130 | { 131 | "output_type": "stream", 132 | "name": "stderr", 133 | "text": [ 134 | "Some weights of the model checkpoint at model were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.masked_spec_embed']\n", 135 | "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 136 | "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 137 | ] 138 | } 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "source": [ 144 | "sample_wav_files = [\n", 145 | " 'model/sample_wav_files/common_voice_ru_38488940.wav',\n", 146 | " 'model/sample_wav_files/common_voice_ru_38488941.wav',\n", 147 | "]\n", 148 | "\n", 149 | "# @title Transcribe wav files\n", 150 | "for wav_file_path in sample_wav_files:\n", 151 | " print('File:', wav_file_path)\n", 152 | " transcription = process_wav_file(wav_file_path)\n", 153 | " print('Transcription:', transcription)\n", 154 | "\n" 155 | ], 156 | "metadata": { 157 | "colab": { 158 | "base_uri": "https://localhost:8080/" 159 | }, 160 | "id": "sIFV210KM6tw", 161 | "outputId": "d8ab7223-8aa9-47d9-d17f-387eaa19880a" 162 | }, 
163 | "execution_count": 6, 164 | "outputs": [ 165 | { 166 | "output_type": "stream", 167 | "name": "stdout", 168 | "text": [ 169 | "File: model/sample_wav_files/common_voice_ru_38488940.wav\n", 170 | "Transcription: ['kak v tr`udnɨje tak i d`obrɨj vrʲɪmʲɪn`a n`aʂɨ məɫɐdʲ`ɵʂ `ɛtə ɡɫ`avnəjə bɐɡ`atstvə']\n", 171 | "File: model/sample_wav_files/common_voice_ru_38488941.wav\n", 172 | "Transcription: ['mɨ nɐdʲ`ejɪmsʲə ʂto fsʲe ɡəsʊd`arstvə pɐdʲː`erʐɨvəjɪt `ɛtət tʲekst pənʲɪm`ajɪt ʂto n`ɨnʲɪʂnʲɪjə bʲɪzʲdʲ`ejstvʲɪje nʲɪprʲɪ`jemlʲɪmə']\n" 173 | ] 174 | } 175 | ] 176 | } 177 | ] 178 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Omogre 2 | 3 | ## Russian accentuator and IPA transcriptor. 4 | 5 | [English README](https://github.com/omogr/omogre/blob/main/README_eng.md) 6 | 7 | ## Автоматическая расстановка ударений и [IPA](https://ru.wikipedia.org/wiki/%D0%9C%D0%B5%D0%B6%D0%B4%D1%83%D0%BD%D0%B0%D1%80%D0%BE%D0%B4%D0%BD%D1%8B%D0%B9_%D1%84%D0%BE%D0%BD%D0%B5%D1%82%D0%B8%D1%87%D0%B5%D1%81%D0%BA%D0%B8%D0%B9_%D0%B0%D0%BB%D1%84%D0%B0%D0%B2%D0%B8%D1%82) транскрипция для русского языка. 8 | 9 | Библиотека для [`Python 3`](https://www.python.org/). 10 | 11 | ## Зависимости 12 | 13 | Установка библиотеки повлечет за собой установку [`Pytorch`](https://pytorch.org/) и [`Numpy`](https://numpy.org/). Кроме того, для скачивания моделей установятся [`tqdm`](https://tqdm.github.io/) и [`requests`](https://pypi.org/project/requests/). 14 | 15 | ## Установка 16 | 17 | ### С помощью GIT 18 | 19 | ```bash 20 | pip install git+https://github.com/omogr/omogre.git 21 | ``` 22 | 23 | ### При помощи pip 24 | 25 | Скачать код с [гитхаба](https://github.com/omogr/omogre). В директории, в которой находится файл [`setup.py`](https://github.com/omogr/omogre/blob/main/setup.py), выполнить 26 | 27 | ```bash 28 | pip install -e . 
29 | ``` 30 | 31 | ### Вручную 32 | 33 | Скачать код с [гитхаба](https://github.com/omogr/omogre). Установить [`Pytorch`](https://pytorch.org/), [`Numpy`](https://numpy.org/), [`tqdm`](https://tqdm.github.io/) и [`requests`](https://pypi.org/project/requests/). Запустить [`test.py`](https://github.com/omogr/omogre/blob/main/test.py). 34 | 35 | ## Загрузка моделей 36 | 37 | По умолчанию при первом запуске библиотеки скачиваются данные для моделей. Скрипт [`download_data.py`](https://github.com/omogr/omogre/blob/main/download_data.py) также позволяет загружать эти данные. 38 | 39 | При желании можно указывать путь, в котором должны располагаться данные для моделей. Если в этой директории уже есть данные, то их повторного скачивания не будет. 40 | 41 | ## Пример запуска 42 | 43 | Скрипт [`test.py`](https://github.com/omogr/omogre/blob/main/test.py). 44 | 45 | ```python 46 | from omogre import Accentuator, Transcriptor 47 | 48 | # данные будут скачаны в директорию 'omogre_data' 49 | transcriptor = Transcriptor(data_path='omogre_data') 50 | accentuator = Accentuator(data_path='omogre_data') 51 | 52 | sentence_list = ['стены замка'] 53 | 54 | print('transcriptor', transcriptor(sentence_list)) 55 | print('accentuator', accentuator(sentence_list)) 56 | 57 | # другие способы вызова, отличаются только формой записи 58 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list)) 59 | print('accentuator.accentuate', accentuator.accentuate(sentence_list)) 60 | 61 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list)) 62 | ``` 63 | 64 | ## Параметры классов 65 | 66 | ### Transcriptor 67 | 68 | Все параметры инициализации класса не являются обязательными. 69 | 70 | ```python 71 | class Transcriptor(data_path: str = None, 72 | download: bool = True, 73 | device_name: str = None, 74 | punct: str = '.,!?') 75 | ``` 76 | 77 | - `data_path` - директория, в которой должна находиться модель.
78 | 79 | - `device_name` - параметр, определяющий использование GPU. Соответствует параметру инициализации класса [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Допустимые значения - `"cpu"`, `"cuda"`, `"cuda:0"` и т.д. По умолчанию, если torch видит GPU, то `"cuda"`, иначе `"cpu"`. 80 | 81 | - `punct` - список небуквенных символов, которые переносятся из исходного текста в транскрипцию. По умолчанию `'.,!?'`. 82 | 83 | - `download` - следует ли загружать модель из интернета, если она не найдена в директории `data_path`. По умолчанию `True`. 84 | 85 | 86 | Входы класса `Transcriptor`: 87 | 88 | ```python 89 | accentuate(sentence_list: list) -> list 90 | transcribe(sentence_list: list) -> list 91 | ``` 92 | 93 | В случае `accentuate` выполняется расстановка ударений, в случае `transcribe` - транскрипция. Оба входа получают на вход список строк и возвращают список строк. Строками могут быть предложения или не очень большие куски текста. 94 | 95 | ### Accentuator 96 | 97 | Расстановка ударений классом Accentuator ничем не отличается от расстановки ударений классом Transcriptor. Разница только в том, что Accentuator не загружает данные для транскрипции. Это позволяет уменьшить время инициализации класса и расход оперативной памяти. 98 | 99 | Все параметры инициализации класса не являются обязательными. Смысл параметров инициализации такой же, как у класса Transcriptor. 100 | 101 | ```python 102 | class Accentuator(data_path: str = None, 103 | download: bool = True, 104 | device_name: str = None) 105 | ``` 106 | 107 | - `data_path` - директория, в которой должна находиться модель. 108 | 109 | - `device_name` - параметр, определяющий использование GPU. Соответствует параметру инициализации класса [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Допустимые значения - `"cpu"`, `"cuda"`, `"cuda:0"` и т.д. По умолчанию, если torch видит GPU, то `"cuda"`, иначе `"cpu"`.
110 | 111 | - `download` - следует ли загружать модель из интернета, если она не найдена в директории `data_path`. По умолчанию `True`. 112 | 113 | Входы класса `Accentuator`: 114 | 115 | ```python 116 | accentuate(sentence_list: list) -> list 117 | ``` 118 | 119 | ## Примеры работы 120 | 121 | ### markup-файлы для акустических корпусов 122 | 123 | Скрипт [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) расставляет ударения и порождает транскрипцию для markup-файлов акустических корпусов [`RUSLAN`](https://ruslan-corpus.github.io/) ([`RUSLAN с ручной разметкой ударений`](http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar)) и [`NATASHA`](http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar). 124 | 125 | markup-файлы этих корпусов уже содержат расстановку ударений, которая [была сделана](https://habr.com/ru/companies/ashmanov_net/articles/528296/) вручную. 126 | 127 | Скрипт [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) порождает для тех же файлов свою собственную расстановку ударений. Изначальная ручная разметка никак не используется при тестировании и не использовалась при обучении. Таким образом, её можно использовать для оценки точности расстановки ударений. 128 | 129 | ### Синтез речи 130 | 131 | Расстановка ударений и транскрипция могут быть полезны при синтезе речи. [Ноутбук](https://github.com/omogr/omogre/blob/main/XTTS_ru_ipa.ipynb) содержит пример запуска [`XTTS`](https://github.com/coqui-ai/TTS) модели, обученной на транскрипции для русского языка. Модель обучалась на корпусах [`RUSLAN`](https://ruslan-corpus.github.io/) и [`Common Voice`](https://commonvoice.mozilla.org/ru). 132 | Веса модели можно скачать с [Hugging Face](https://huggingface.co/omogr/XTTS-ru-ipa) 133 | 134 | ### Извлечение транскрипции из акустических файлов 135 | 136 | Расстановка ударений и транскрипция могут быть полезны при анализе речи. 
[Ноутбук](https://github.com/omogr/omogre/blob/main/Wav2vec2_ru_ipa.ipynb) содержит пример запуска модели wav2vec2-lv-60-espeak-cv-ft, дообученной на транскрипции акустических корпусов [`RUSLAN`](https://ruslan-corpus.github.io/) и [`Common Voice`](https://commonvoice.mozilla.org/ru). 137 | 138 | ## Учёт контекста и некоторые другие особенности 139 | 140 | ### Ударения 141 | 142 | Ударения расставляются с учётом контекста. Если во входном списке строк встретятся очень длинные строки (для текущей модели это больше 510 токенов), то для таких длинных строк контекст учитываться не будет. Ударения в этих строках будут ставиться только там, где это возможно без учёта контекста. 143 | 144 | В словах из одного слога ударение тоже ставится. В некоторых случаях это может выглядеть странно, но упрощает последующее определение транскрипции. 145 | 146 | ### Транскрипция 147 | 148 | При порождении транскрипции посторонние символы фильтруются. Список небуквенных символов, которые не фильтруются, можно задавать отдельным параметром. По умолчанию не фильтруются четыре знака пунктуации (`.,!?`). Транскрипция определяется пословно, без учёта контекста. Для транскрипции слов используются следующие символы: 149 | 150 | ``` 151 | ʲ`ɪətrsɐnjvmapkɨʊleɫdizofʂɕbɡxːuʐæɵʉɛ 152 | ``` 153 | 154 | ## Обратная связь 155 | 156 | Почта для вопросов, замечаний и предложений - `omogrus@ya.ru`.
157 | 158 | ## Лицензия 159 | 160 | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ru) 161 | -------------------------------------------------------------------------------- /omogre/accentuator/reader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import collections 5 | import torch 6 | from .tokenizer import BertTokenizer 7 | import pickle 8 | 9 | InfTokenSpan = collections.namedtuple("TokenSpan", ["word_tokens", "punct", "first", "last"]) 10 | 11 | alphabet = '-абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ' 12 | vowels = 'аеёиоуыэюяАЕЁИОУЫЭЮЯ' 13 | 14 | MAX_SENTENCE_LEN = 510 15 | 16 | class AccentVocab: 17 | def __init__(self, data_path): 18 | vocab_file = os.path.join(data_path, 'wv_word_acc.pickle') 19 | with open(vocab_file, "rb") as finp: 20 | (self.vocab, self.vocab_index) = pickle.load(finp) 21 | 22 | 23 | class AccentTokenizer: 24 | def __init__(self, data_path): 25 | self.accent_vocab = AccentVocab(data_path=data_path) 26 | bert_vocab_path = os.path.join(data_path, 'model/vocab.txt') 27 | self.tokenizer = BertTokenizer(bert_vocab_path, do_lower_case=False) 28 | 29 | self.get_vowel_pos() 30 | self.pad_token_id = 0 31 | 32 | def get_vowel_pos(self): 33 | self.token_vowel_pos = {} 34 | 35 | for token_text, token_id in self.tokenizer.vocab.items(): 36 | 37 | vowel_pos = [] 38 | letter_pos = 0 39 | if token_text.startswith('##'): 40 | tt = token_text[2:] 41 | else: 42 | tt = token_text 43 | for letter_pos, tc in enumerate(tt): 44 | if tc in vowels: 45 | vowel_pos.append(letter_pos) 46 | if letter_pos > 0: 47 | assert tc != '#', (token_text, token_id) 48 | ct = (len(tt), vowel_pos) 49 | self.token_vowel_pos[token_id] = ct 50 | 51 | def encode(self, txt): 52 | tokens = self.tokenizer.tokenize(txt) 53 | return self.tokenizer.convert_tokens_to_ids(tokens) 54 | 55 | def get_num_vars(self, tword): 56 | tword_index = 
self.accent_vocab.vocab.get(tword) 57 | if tword_index is None: 58 | return -1 59 | return len(self.accent_vocab.vocab_index[tword_index]) 60 | 61 | def tokenize_word(self, letter_list, tokens, first_pos, last_pos): 62 | tword = ''.join(letter_list) # casefold() ? 63 | 64 | num_vars = self.get_num_vars(tword) 65 | if num_vars >= 0: 66 | tokens.append(InfTokenSpan(self.encode(tword), [], first_pos, last_pos)) 67 | 68 | if num_vars > 1: 69 | return True 70 | return False 71 | 72 | if '-' not in tword: 73 | tokens.append(InfTokenSpan(self.encode(tword), [], first_pos, last_pos)) 74 | return False 75 | 76 | parts = tword.split('-') 77 | if len(parts) < 1: 78 | return False 79 | 80 | tpos = first_pos 81 | 82 | for tp in parts: 83 | next_pos = tpos + len(tp) 84 | tokens[-1].punct.append('-') 85 | if len(tp) > 0: 86 | tokens.append(InfTokenSpan(self.encode(tp), [], tpos, next_pos)) 87 | tpos = next_pos + 1 88 | 89 | for tp in parts: 90 | if self.get_num_vars(tp) > 1: 91 | return True 92 | return False 93 | 94 | def tokenize_punct(self, letters_list): 95 | return self.encode(''.join(letters_list)) 96 | 97 | def get_inf_tokens(self, sentence0): 98 | sentence = sentence0.replace('+', ' ') 99 | tokens = [] 100 | tokenizer_bos = 2 101 | tokenizer_sep = 3 102 | 103 | ct = InfTokenSpan([tokenizer_bos], [], 0, 0) 104 | tokens.append(ct) 105 | tword = [] 106 | 107 | first_pos = -1 108 | is_easy = True 109 | for char_pos, cur_char0 in enumerate(sentence): 110 | if cur_char0 in alphabet: 111 | cur_char = cur_char0.casefold() 112 | if first_pos < 0: 113 | first_pos = char_pos 114 | tword.append(cur_char) 115 | continue 116 | 117 | if tword: 118 | if self.tokenize_word(tword, tokens, first_pos, char_pos): 119 | is_easy = False 120 | 121 | tword = [] 122 | first_pos = -1 123 | tokens[-1].punct.append(cur_char0) 124 | 125 | if tword: 126 | if self.tokenize_word(tword, tokens, first_pos, len(sentence)): 127 | is_easy = False 128 | tword = [] 129 | 130 | ct = InfTokenSpan([tokenizer_sep], 
[], 0, 0) 131 | tokens.append(ct) 132 | 133 | all_ids = [] 134 | all_spans = [] 135 | for ws in tokens: 136 | first_token = len(all_ids) 137 | all_ids.extend(ws.word_tokens) 138 | 139 | if ws.last > 0: 140 | last_token = len(all_ids) 141 | text_span = (ws.first, ws.last) 142 | token_span = (first_token, last_token) 143 | ct = (text_span, token_span) 144 | all_spans.append(ct) 145 | 146 | tpunct_str = ''.join(ws.punct).replace(' ', '') 147 | if len(tpunct_str) > 0: 148 | punct_tokens = self.encode(tpunct_str) 149 | all_ids.extend(punct_tokens) 150 | 151 | if len(all_ids) > MAX_SENTENCE_LEN: 152 | is_easy = True 153 | 154 | return all_ids, all_spans, is_easy 155 | 156 | 157 | class AccentDocument: 158 | def __init__(self, acc_tokenizer, all_sentences, first_pos=0, max_batch_token_num=2048): 159 | self.all_sentences = all_sentences 160 | 161 | self.max_batch_token_num = max_batch_token_num 162 | self.easy_sentences = [] 163 | self.model_batches = [] 164 | self.too_long_sentence_cnt = 0 165 | self.pad_token_id = acc_tokenizer.pad_token_id 166 | 167 | self.get_batches(acc_tokenizer, first_pos) 168 | 169 | def num_entries(self): 170 | return len(self.sentence_list) 171 | 172 | def add_model_batch(self, bert_sentences, max_length): 173 | all_input_ids = [] 174 | sentence_spans = [] 175 | all_attention_mask = [] 176 | 177 | batch_length = max_length 178 | 179 | for t_doc_pos, t_input_ids, t_spans in bert_sentences: 180 | len_input_ids = len(t_input_ids) 181 | attention_mask = [1] * len_input_ids 182 | ct = (t_doc_pos, t_spans) 183 | sentence_spans.append(ct) 184 | if len_input_ids > batch_length: 185 | # Truncate t_input_ids and attention_mask to max length 186 | assert False 187 | t_input_ids = t_input_ids[:batch_length] 188 | 189 | attention_mask = attention_mask[:batch_length] 190 | elif len(t_input_ids) < batch_length: 191 | # Pad t_input_ids and attention_mask to max length 192 | padding_length = batch_length - len_input_ids 193 | t_input_ids += [self.pad_token_id] * 
padding_length 194 | attention_mask += [0] * padding_length 195 | 196 | all_input_ids.append(torch.tensor(t_input_ids, dtype=torch.long)) 197 | all_attention_mask.append(torch.tensor(attention_mask, dtype=torch.long)) 198 | 199 | all_input_ids = torch.stack(all_input_ids) 200 | all_attention_mask = torch.stack(all_attention_mask) 201 | batch = (sentence_spans, all_input_ids, all_attention_mask) 202 | 203 | self.model_batches.append(batch) 204 | 205 | def get_batches(self, acc_tokenizer, first_pos): 206 | sorted_bert_sentences = [] 207 | doc_pos = first_pos 208 | 209 | while doc_pos < len(self.all_sentences): 210 | 211 | cur_sentence = self.all_sentences[doc_pos] 212 | all_ids, all_spans, is_easy = acc_tokenizer.get_inf_tokens(cur_sentence) 213 | if is_easy: 214 | ct = (doc_pos, all_spans) 215 | self.easy_sentences.append(ct) 216 | doc_pos += 1 217 | continue 218 | 219 | len_input_ids = len(all_ids) 220 | 221 | if (1 + len_input_ids) >= self.max_batch_token_num: 222 | ct = (doc_pos, cur_sentence, all_spans) 223 | easy_sentences.append(ct) 224 | doc_pos += 1 225 | self.too_long_sentence_cnt += 1 226 | continue 227 | 228 | ct = (len_input_ids, doc_pos, all_ids, all_spans) 229 | sorted_bert_sentences.append(ct) 230 | doc_pos += 1 231 | 232 | sorted_bert_sentences.sort() 233 | 234 | max_length = 1 235 | bert_sentences = [] 236 | 237 | for len_input_ids, doc_pos, all_ids, all_spans in sorted_bert_sentences: 238 | new_max_length = max(max_length, len_input_ids) 239 | if new_max_length * (1 + len(bert_sentences)) >= self.max_batch_token_num: 240 | assert len(bert_sentences) > 0 241 | self.add_model_batch(bert_sentences, max_length) 242 | max_length = len_input_ids 243 | bert_sentences = [] 244 | else: 245 | max_length = new_max_length 246 | 247 | ct = (doc_pos, all_ids, all_spans) 248 | bert_sentences.append(ct) 249 | 250 | if len(bert_sentences) > 0: 251 | self.add_model_batch(bert_sentences, max_length) 252 | 253 | def get_all_easy(self, acc_tokenizer, first_pos=0): 254 
| doc_pos = first_pos 255 | self.easy_sentences = [] 256 | while doc_pos < len(self.all_sentences): 257 | all_ids, all_spans, is_easy = acc_tokenizer.get_inf_tokens(self.all_sentences[doc_pos]) 258 | 259 | ct = (doc_pos, all_spans) 260 | self.easy_sentences.append(ct) 261 | doc_pos += 1 262 | 263 | 264 | if __name__ == '__main__': 265 | pass 266 | -------------------------------------------------------------------------------- /omogre/transcriptor/unk_words.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import pickle 7 | import json 8 | 9 | INPUT_TEXT_ALPHABET = ' <>-`абвгдеёжзийклмнопрстуфхцчшщъыьэюя?' 10 | BOS_SYMBOL = '' 11 | DELETE_SYMBOL = '' 12 | UNK_PATH_EVAL = -100.0 13 | UNK_PATH_THRESHOLD = -99.0 14 | 15 | AUXILIARY_SYMBOL_REPLACEMENTS = [ 16 | (DELETE_SYMBOL, ""), 17 | ("+", ""), 18 | ("~", ""), 19 | ("ʑ", "ɕ:"), 20 | ("ɣ", "x"), 21 | (":", "ː"), 22 | ("'", "`"), 23 | ("_", "") 24 | ] 25 | 26 | class UnkWordsConfig(object): 27 | """Configuration class to store the configuration of a `UnkWords`. 
28 | """ 29 | def __init__(self, 30 | window_size=7, 31 | unk_bigram_eval=20, 32 | kind_of_stupid_backoff=-1.31, 33 | gram_data_name='gram_model.pickle'): 34 | 35 | self.window_size = window_size 36 | self.unk_bigram_eval = unk_bigram_eval 37 | self.kind_of_stupid_backoff = kind_of_stupid_backoff 38 | self.gram_data_name = gram_data_name 39 | 40 | @classmethod 41 | def from_dict(cls, json_object): 42 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 43 | config = BertConfig(vocab_size_or_config_json_file=-1) 44 | for key, value in json_object.items(): 45 | config.__dict__[key] = value 46 | return config 47 | 48 | @classmethod 49 | def from_json_file(cls, json_file): 50 | """Constructs a `BertConfig` from a json file of parameters.""" 51 | with open(json_file, "r", encoding='utf-8') as reader: 52 | text = reader.read() 53 | return cls.from_dict(json.loads(text)) 54 | 55 | def __repr__(self): 56 | return str(self.to_json_string()) 57 | 58 | def to_dict(self): 59 | """Serializes this instance to a Python dictionary.""" 60 | output = copy.deepcopy(self.__dict__) 61 | return output 62 | 63 | def to_json_string(self): 64 | """Serializes this instance to a JSON string.""" 65 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 66 | 67 | def to_json_file(self, json_file_path): 68 | """ Save this instance to a json file.""" 69 | with open(json_file_path, "w", encoding='utf-8') as writer: 70 | writer.write(self.to_json_string()) 71 | 72 | 73 | def invert_vocab(vocab: dict) -> dict: 74 | """Inverts a dictionary, swapping keys and values. 75 | Args: 76 | vocab (dict): The dictionary to invert. 77 | 78 | Returns: 79 | dict: The inverted dictionary. 
80 | """ 81 | 82 | inv = {} 83 | for key, value in vocab.items(): 84 | inv[value] = key 85 | return inv 86 | 87 | 88 | def replace_auxiliary_symbols(text: str) -> str: 89 | result = text 90 | for src, dst in AUXILIARY_SYMBOL_REPLACEMENTS: 91 | if result.find(src) >= 0: 92 | result = result.replace(src, dst) 93 | return result 94 | 95 | 96 | def get_input_token_vocab() -> dict: 97 | # vocab for input text tokenizer 98 | vocab = {} 99 | for indx, tc in enumerate(INPUT_TEXT_ALPHABET): 100 | vocab[tc] = indx 101 | return vocab 102 | 103 | 104 | def token_lower_bound(raw_search_key: list, id_list: list) -> int: 105 | # binary search (a kind of) 106 | search_key = tuple(raw_search_key) 107 | first_pos = 0 108 | last_pos = len(id_list) 109 | while True: 110 | half_range = (last_pos - first_pos) // 2 111 | if half_range < 1: 112 | break 113 | 114 | mid_pos = first_pos + half_range 115 | if id_list[mid_pos][0] > search_key: 116 | last_pos = mid_pos 117 | continue 118 | first_pos = mid_pos 119 | if id_list[mid_pos][0] == search_key: 120 | break 121 | return first_pos 122 | 123 | 124 | class UnkWords: 125 | def __init__(self, data_path: str, config=None): 126 | if config is None: 127 | self.config = UnkWordsConfig() 128 | else: 129 | self.config = config 130 | transcriptor_data_path = os.path.join(data_path, self.config.gram_data_name) 131 | with open(transcriptor_data_path, 'rb') as finp: 132 | ( 133 | self.dst_alphabet, 134 | self.gram_prob, 135 | self.bigram_eval, 136 | self.char_phrase_table, 137 | self.gram_phrase_table, 138 | self.head_transcriptions, 139 | self.tail_transcriptions, 140 | ) = pickle.load(finp) 141 | 142 | self.inv_dst_alphabet = invert_vocab(self.dst_alphabet) 143 | 144 | self.input_token_vocab = get_input_token_vocab() 145 | self.src_stress_sign = "`" 146 | 147 | self.src_stress_indx = self.input_token_vocab.get(self.src_stress_sign) 148 | self.delete_indx = self.dst_alphabet.get(DELETE_SYMBOL, -1) 149 | self.dst_stress_indx = 
self.dst_alphabet.get("'", -1)

        # Sanity checks on the loaded alphabets: the delete and stress symbols
        # must exist in the target alphabet, and the source stress sign is
        # expected to have id 4 in the input vocabulary.
        assert self.delete_indx >= 0
        assert self.dst_stress_indx >= 0
        assert self.src_stress_indx == 4

    def tokenize(self, input_text: str) -> list:
        """Convert a word into a list of input-vocabulary ids.

        The characters are framed by two '<' markers on the left and two '>'
        markers on the right before lookup. Characters missing from
        ``self.input_token_vocab`` map to id 0.
        """
        bos_input_letters_eos = ['<', '<'] + [tc for tc in input_text] + ['>', '>']
        return [self.input_token_vocab.get(tc, 0) for tc in bos_input_letters_eos]

    def viterbi_step(self, best_path: list, emission_logprobs: np.ndarray) -> list:
        """Performs a single step in the Viterbi search algorithm.

        Args:
            best_path (list): The best paths from the previous step. Each entry
                is a tuple (path_score, predecessor_index, phoneme1, phoneme2).
            emission_logprobs (np.ndarray): Character emission scores, indexed
                by target-alphabet id.

        Returns:
            list: The top `config.window_size` best paths for the current step,
            as (path_score, index_into_best_path, phoneme1, phoneme2) tuples
            sorted in descending order.
        """

        # For each new state (last two phonemes), keep only the highest-scoring
        # extension together with the index of its predecessor in best_path.
        new_best_paths: dict = {}
        for char_indx, t_char_eval in enumerate(emission_logprobs):
            # Target characters whose emission score is below the threshold
            # are not considered at this step.
            if t_char_eval < UNK_PATH_THRESHOLD:
                continue
            next_dst_char = self.inv_dst_alphabet[char_indx]

            for path_indx, (prev_path_eval, _, phoneme1, phoneme2) in enumerate(best_path):
                key = (phoneme1, phoneme2, next_dst_char)
                if key in self.gram_prob:
                    current_gram_eval = self.gram_prob[key]
                else:
                    # Unseen trigram: back off to a scaled score for the
                    # bigram context, with a default for unseen bigrams.
                    current_gram_eval = self.config.kind_of_stupid_backoff * self.bigram_eval.get(
                        (phoneme1, phoneme2), self.config.unk_bigram_eval)

                new_path_eval = prev_path_eval + current_gram_eval + t_char_eval
                new_key = (phoneme2, next_dst_char)
                if new_key not in new_best_paths:
                    new_best_paths[new_key] = (new_path_eval, path_indx)
                else:
                    best_so_far, _ = new_best_paths[new_key]
                    if new_path_eval > best_so_far:
                        new_best_paths[new_key] = (new_path_eval, path_indx)

        # Beam pruning: sort states by score and keep the window_size best.
        best_list = sorted( [
            (path_eval, path_indx, phoneme1, phoneme2)
            for (phoneme1, phoneme2), (path_eval, path_indx) in new_best_paths.items()
        ], reverse=True)
        return best_list[:self.config.window_size]

    def transcribe(self, input_word_text: str) -> str:
        """
        Receives a word as input, returns its transcription

        The word is case-folded and every '+' (stress mark) is replaced by
        ``self.src_stress_sign``. The resulting ids (see ``tokenize``) are
        decoded with a beam-limited Viterbi search. Position 1 holds the
        initial (BOS_SYMBOL, BOS_SYMBOL) state, and positions 2 .. word_len-1
        are scored one at a time.

        Returns "" for empty input, or when backtracking yields four or fewer
        symbols.
        """
        if not input_word_text:
            return ""

        # in the input text of the word, the stress can be indicated by a plus,
        # we replace it with the symbol that is used in n-grams...

        word_text = input_word_text.casefold().replace('+', self.src_stress_sign)
        input_ids = self.tokenize(word_text)
        dst_alphabet_len = len(self.dst_alphabet)
        word_len = len(input_ids)

        # Match the start of the word against the nearest head-transcription
        # entry. head_len ends up as the last matching prefix position whose
        # pattern symbol is not a delete/stress slot.
        # NOTE(review): assumes token_lower_bound always returns a valid index
        # into self.head_transcriptions; confirm in its definition.
        head_lower_bound = token_lower_bound(input_ids, self.head_transcriptions)
        head_len = -1
        for indx, tid in enumerate(self.head_transcriptions[head_lower_bound][0]):
            if indx >= len(input_ids):
                break
            if tid != input_ids[indx]:
                break

            pattern_indx = self.head_transcriptions[head_lower_bound][1][indx]
            if pattern_indx in [self.delete_indx, self.dst_stress_indx]:
                continue
            head_len = indx

        # Same matching for the end of the word, using reversed ids against
        # the tail-transcription entries.
        reversed_input_ids = list(reversed(input_ids))
        tail_lower_bound = token_lower_bound(reversed_input_ids, self.tail_transcriptions)
        tail_len = -1
        for indx, tid in enumerate(self.tail_transcriptions[tail_lower_bound][0]):
            if indx >= len(input_ids):
                break
            if tid != reversed_input_ids[indx]:
                break
            pattern_indx = self.tail_transcriptions[tail_lower_bound][1][indx]
            if pattern_indx in [self.delete_indx, self.dst_stress_indx]:
                continue
            tail_len = indx

        # best_path[i] holds the beam after scoring input position i.
        best_path = {}
        best_path[1] = [ (0.0, 0, BOS_SYMBOL, BOS_SYMBOL) ]
        indx = 2
        while indx < word_len:
            # Every target character starts at UNK_PATH_EVAL; the sources
            # below overwrite some of these scores.
            emission_logprobs = np.ones(dst_alphabet_len, dtype=np.float32) * UNK_PATH_EVAL
            # Position of the current input id counted from the end of the word.
            indx1 = word_len - indx - 1
            is_empty = True
            # Inside the matched prefix/suffix, the dictionary entry forces the
            # target character (score 0.0).
            if indx < head_len:
                vocab_char = self.head_transcriptions[head_lower_bound][1][indx]
                emission_logprobs[vocab_char] = 0.0
                is_empty = False
            if indx1 < tail_len:
                vocab_char = self.tail_transcriptions[tail_lower_bound][1][indx1]
                emission_logprobs[vocab_char] = 0.0
                is_empty = False
            if is_empty and (indx < word_len-1):
                # Context window of neighbouring ids. It is widened by one on a
                # side whose neighbour is the source stress sign, when the
                # bounds allow.
                i1 = indx - 1
                i2 = indx + 1
                if self.src_stress_indx == input_ids[i1]:
                    if i1 > 0:
                        i1 -= 1
                if self.src_stress_indx == input_ids[i2]:
                    if i2 < word_len-1:
                        i2 += 1
                src = tuple(input_ids[i1:i2+1])
                if src in self.gram_phrase_table:
                    is_empty = False
                    for key, value in self.gram_phrase_table[src].items():
                        emission_logprobs[key] = value
            # Fallback: per-character emission scores for the current id.
            if is_empty:
                src = input_ids[indx]
                if src in self.char_phrase_table:
                    for key, value in self.char_phrase_table[src].items():
                        emission_logprobs[key] = value

            best_path[indx] = self.viterbi_step(best_path[indx-1], emission_logprobs)
            indx += 1

        # Backtrack from the best final state. Follow predecessor indices down
        # to position 1, collecting the first phoneme of each state.
        indx = word_len-1
        prev = 0
        res = []

        while indx > 0 and len(best_path[indx]) > 0:
            path_eval, prev, phoneme1, phoneme2 = best_path[indx][prev]
            res.append(phoneme1)
            indx -= 1

        res.reverse()

        if len(res) > 3: # 4?
            # strip BOS_SYMBOL BOS_SYMBOL ... EOS_SYMBOL
            return replace_auxiliary_symbols(''.join(res[2:-1]))
        return ""


--------------------------------------------------------------------------------
/omogre/accentuator/tokenizer.py:
--------------------------------------------------------------------------------
# coding=utf-8

# edited from pytorch_pretrained_bert

# https://github.com/google-research/bert
# https://github.com/maknotavailable/pytorch-pretrained-BERT


# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | 23 | """Tokenization classes.""" 24 | 25 | from __future__ import absolute_import, division, print_function, unicode_literals 26 | 27 | import collections 28 | import logging 29 | import os 30 | import unicodedata 31 | from io import open 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | VOCAB_NAME = 'vocab.txt' 37 | 38 | def load_vocab(vocab_file): 39 | """Loads a vocabulary file into a dictionary.""" 40 | vocab = collections.OrderedDict() 41 | index = 0 42 | with open(vocab_file, "r", encoding="utf-8") as reader: 43 | while True: 44 | token = reader.readline() 45 | if not token: 46 | break 47 | token = token.strip() 48 | vocab[token] = index 49 | index += 1 50 | return vocab 51 | 52 | 53 | def whitespace_tokenize(text): 54 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 55 | text = text.strip() 56 | if not text: 57 | return [] 58 | tokens = text.split() 59 | return tokens 60 | 61 | 62 | class BertTokenizer(object): 63 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 64 | 65 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 66 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 67 | """Constructs a BertTokenizer. 68 | 69 | Args: 70 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 71 | do_lower_case: Whether to lower case the input 72 | Only has an effect when do_wordpiece_only=False 73 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 
74 | max_len: An artificial maximum length to truncate tokenized sequences to; 75 | Effective maximum length is always the minimum of this 76 | value (if specified) and the underlying BERT model's 77 | sequence length. 78 | never_split: List of tokens which will never be split during tokenization. 79 | Only has an effect when do_wordpiece_only=False 80 | """ 81 | if not os.path.isfile(vocab_file): 82 | raise ValueError( "Can't find a vocabulary file at path '{}'.".format(vocab_file)) 83 | self.vocab = load_vocab(vocab_file) 84 | self.ids_to_tokens = collections.OrderedDict( 85 | [(ids, tok) for tok, ids in self.vocab.items()]) 86 | self.do_basic_tokenize = do_basic_tokenize 87 | if do_basic_tokenize: 88 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 89 | never_split=never_split) 90 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 91 | self.max_len = max_len if max_len is not None else int(1e12) 92 | 93 | def tokenize(self, text): 94 | split_tokens = [] 95 | if self.do_basic_tokenize: 96 | for token in self.basic_tokenizer.tokenize(text): 97 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 98 | split_tokens.append(sub_token) 99 | else: 100 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 101 | return split_tokens 102 | 103 | def convert_tokens_to_ids(self, tokens): 104 | """Converts a sequence of tokens into ids using the vocab.""" 105 | ids = [] 106 | for token in tokens: 107 | ids.append(self.vocab[token]) 108 | if len(ids) > self.max_len: 109 | logger.warning( 110 | "Token indices sequence length is longer than the specified maximum " 111 | " sequence length for this BERT model ({} > {}). 
Running this" 112 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 113 | ) 114 | return ids 115 | 116 | def convert_ids_to_tokens(self, ids): 117 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 118 | tokens = [] 119 | for i in ids: 120 | tokens.append(self.ids_to_tokens[i]) 121 | return tokens 122 | 123 | def save_vocabulary(self, vocab_path): 124 | """Save the tokenizer vocabulary to a directory or file.""" 125 | index = 0 126 | if os.path.isdir(vocab_path): 127 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 128 | with open(vocab_file, "w", encoding="utf-8") as writer: 129 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 130 | if index != token_index: 131 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 132 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 133 | index = token_index 134 | writer.write(token + u'\n') 135 | index += 1 136 | return vocab_file 137 | 138 | @classmethod 139 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 140 | """ 141 | Instantiate a PreTrainedBertModel from a pre-trained model file. 142 | Download and cache the pre-trained model file if needed. 143 | """ 144 | vocab_file = pretrained_model_name_or_path 145 | if os.path.isdir(vocab_file): 146 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 147 | # Instantiate tokenizer. 148 | tokenizer = cls(vocab_file, *inputs, **kwargs) 149 | return tokenizer 150 | 151 | 152 | class BasicTokenizer(object): 153 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 154 | 155 | def __init__(self, 156 | do_lower_case=True, 157 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 158 | """Constructs a BasicTokenizer. 159 | 160 | Args: 161 | do_lower_case: Whether to lower case the input. 
162 | """ 163 | self.do_lower_case = do_lower_case 164 | self.never_split = never_split 165 | 166 | def tokenize(self, text): 167 | """Tokenizes a piece of text.""" 168 | text = self._clean_text(text) 169 | # This was added on November 1st, 2018 for the multilingual and Chinese 170 | # models. This is also applied to the English models now, but it doesn't 171 | # matter since the English models were not trained on any Chinese data 172 | # and generally don't have any Chinese data in them (there are Chinese 173 | # characters in the vocabulary because Wikipedia does have some Chinese 174 | # words in the English Wikipedia.). 175 | text = self._tokenize_chinese_chars(text) 176 | orig_tokens = whitespace_tokenize(text) 177 | split_tokens = [] 178 | for token in orig_tokens: 179 | if self.do_lower_case and token not in self.never_split: 180 | token = token.lower() 181 | token = self._run_strip_accents(token) 182 | split_tokens.extend(self._run_split_on_punc(token)) 183 | 184 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 185 | return output_tokens 186 | 187 | def _run_strip_accents(self, text): 188 | """Strips accents from a piece of text.""" 189 | text = unicodedata.normalize("NFD", text) 190 | output = [] 191 | for char in text: 192 | cat = unicodedata.category(char) 193 | if cat == "Mn": 194 | continue 195 | output.append(char) 196 | return "".join(output) 197 | 198 | def _run_split_on_punc(self, text): 199 | """Splits punctuation on a piece of text.""" 200 | if text in self.never_split: 201 | return [text] 202 | chars = list(text) 203 | i = 0 204 | start_new_word = True 205 | output = [] 206 | while i < len(chars): 207 | char = chars[i] 208 | if _is_punctuation(char): 209 | output.append([char]) 210 | start_new_word = True 211 | else: 212 | if start_new_word: 213 | output.append([]) 214 | start_new_word = False 215 | output[-1].append(char) 216 | i += 1 217 | 218 | return ["".join(x) for x in output] 219 | 220 | def _tokenize_chinese_chars(self, 
text): 221 | """Adds whitespace around any CJK character.""" 222 | output = [] 223 | for char in text: 224 | cp = ord(char) 225 | if self._is_chinese_char(cp): 226 | output.append(" ") 227 | output.append(char) 228 | output.append(" ") 229 | else: 230 | output.append(char) 231 | return "".join(output) 232 | 233 | def _is_chinese_char(self, cp): 234 | """Checks whether CP is the codepoint of a CJK character.""" 235 | # This defines a "chinese character" as anything in the CJK Unicode block: 236 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 237 | # 238 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 239 | # despite its name. The modern Korean Hangul alphabet is a different block, 240 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 241 | # space-separated words, so they are not treated specially and handled 242 | # like the all of the other languages. 243 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 244 | (cp >= 0x3400 and cp <= 0x4DBF) or # 245 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 246 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 247 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 248 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 249 | (cp >= 0xF900 and cp <= 0xFAFF) or # 250 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 251 | return True 252 | 253 | return False 254 | 255 | def _clean_text(self, text): 256 | """Performs invalid character removal and whitespace cleanup on text.""" 257 | output = [] 258 | for char in text: 259 | cp = ord(char) 260 | if cp == 0 or cp == 0xfffd or _is_control(char): 261 | continue 262 | if _is_whitespace(char): 263 | output.append(" ") 264 | else: 265 | output.append(char) 266 | return "".join(output) 267 | 268 | 269 | class WordpieceTokenizer(object): 270 | """Runs WordPiece tokenization.""" 271 | 272 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 273 | self.vocab = vocab 274 | self.unk_token = unk_token 275 | self.max_input_chars_per_word = 
max_input_chars_per_word 276 | 277 | def tokenize(self, text): 278 | """Tokenizes a piece of text into its word pieces. 279 | 280 | This uses a greedy longest-match-first algorithm to perform tokenization 281 | using the given vocabulary. 282 | 283 | For example: 284 | input = "unaffable" 285 | output = ["un", "##aff", "##able"] 286 | 287 | Args: 288 | text: A single token or whitespace separated tokens. This should have 289 | already been passed through `BasicTokenizer`. 290 | 291 | Returns: 292 | A list of wordpiece tokens. 293 | """ 294 | 295 | output_tokens = [] 296 | for token in whitespace_tokenize(text): 297 | chars = list(token) 298 | if len(chars) > self.max_input_chars_per_word: 299 | output_tokens.append(self.unk_token) 300 | continue 301 | 302 | is_bad = False 303 | start = 0 304 | sub_tokens = [] 305 | while start < len(chars): 306 | end = len(chars) 307 | cur_substr = None 308 | while start < end: 309 | substr = "".join(chars[start:end]) 310 | if start > 0: 311 | substr = "##" + substr 312 | if substr in self.vocab: 313 | cur_substr = substr 314 | break 315 | end -= 1 316 | if cur_substr is None: 317 | is_bad = True 318 | break 319 | sub_tokens.append(cur_substr) 320 | start = end 321 | 322 | if is_bad: 323 | output_tokens.append(self.unk_token) 324 | else: 325 | output_tokens.extend(sub_tokens) 326 | return output_tokens 327 | 328 | 329 | def _is_whitespace(char): 330 | """Checks whether `chars` is a whitespace character.""" 331 | # \t, \n, and \r are technically contorl characters but we treat them 332 | # as whitespace since they are generally considered as such. 333 | if char == " " or char == "\t" or char == "\n" or char == "\r": 334 | return True 335 | cat = unicodedata.category(char) 336 | if cat == "Zs": 337 | return True 338 | return False 339 | 340 | 341 | def _is_control(char): 342 | """Checks whether `chars` is a control character.""" 343 | # These are technically control characters but we count them as whitespace 344 | # characters. 
345 | if char == "\t" or char == "\n" or char == "\r": 346 | return False 347 | cat = unicodedata.category(char) 348 | if cat.startswith("C"): 349 | return True 350 | return False 351 | 352 | 353 | def _is_punctuation(char): 354 | """Checks whether `chars` is a punctuation character.""" 355 | cp = ord(char) 356 | # We treat all non-letter/number ASCII as punctuation. 357 | # Characters such as "^", "$", and "`" are not in the Unicode 358 | # Punctuation class but we treat them as punctuation anyways, for 359 | # consistency. 360 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 361 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 362 | return True 363 | cat = unicodedata.category(char) 364 | if cat.startswith("P"): 365 | return True 366 | return False 367 | -------------------------------------------------------------------------------- /omogre/accentuator/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | import math 9 | import os 10 | import random 11 | import sys 12 | import numpy as np 13 | import torch 14 | 15 | from .bert import BertForTokenClassification 16 | 17 | from .reader import AccentTokenizer, AccentDocument 18 | from .unk_model import UnkModel 19 | 20 | WordAcc = collections.namedtuple("WordAcc", ["first", "last", "pos", "state"]) 21 | PunctWord = collections.namedtuple("WordPunct", ["punct", "word"]) 22 | 23 | vowels = "аеёиоуыэюяАЕЁИОУЫЭЮЯ" 24 | vowel_plus = "аеёиоуыэюяАЕЁИОУЫЭЮЯ+" 25 | 26 | def count_vowels(text): 27 | return sum(1 for char in text if char in vowels) 28 | 29 | 30 | def _compute_softmax(scores): 31 | """Compute softmax probability over raw logits.""" 32 | if not scores: 33 | return [] 34 | 35 | max_score = None 36 | for score in scores: 37 | if max_score is None or score > max_score: 38 | max_score = 
score 39 | 40 | exp_scores = [] 41 | total_sum = 0.0 42 | for score in scores: 43 | x = math.exp(score - max_score) 44 | exp_scores.append(x) 45 | total_sum += x 46 | 47 | probs = [] 48 | for score in exp_scores: 49 | probs.append(score / total_sum) 50 | return probs 51 | 52 | 53 | def list_arg_max(iterable): 54 | return max(enumerate(iterable), key=lambda x: x[1]) 55 | 56 | 57 | def norm_word(tword): 58 | return tword.casefold().replace('ё', 'е').replace(' ', '!') 59 | 60 | 61 | def check_ee_comu(tword): 62 | lcw = tword.casefold() 63 | if lcw == 'кому': 64 | return 3 65 | acc_pos = lcw.find('ё') 66 | if acc_pos >= 0: 67 | return acc_pos 68 | return -1 69 | 70 | 71 | class Accentuator: 72 | def __init__(self, data_path, device_name=None): 73 | model_data_path = os.path.join(data_path, 'model') 74 | self.unk_model = UnkModel(data_path, device_name=device_name) 75 | 76 | self.tokenizer = AccentTokenizer(data_path=data_path) 77 | if device_name is None: 78 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 79 | else: 80 | self.device = torch.device(device_name) 81 | 82 | self.model = BertForTokenClassification.from_pretrained(model_data_path, num_labels=10, cache_dir=None) 83 | assert self.model 84 | self.model.eval() 85 | self.model.to(self.device) 86 | self.error_counter = 0 87 | 88 | def seed(self, seed_value): 89 | random.seed(seed_value) 90 | np.random.seed(seed_value) 91 | torch.manual_seed(seed_value) 92 | 93 | def process_model_batch(self, doc, batch, sum_batch): 94 | sentence_spans, input_ids, attention_mask = batch 95 | input_ids = input_ids.to(self.device) 96 | attention_mask = attention_mask.to(self.device) 97 | 98 | with torch.no_grad(): 99 | logits = self.model(input_ids, attention_mask=attention_mask) 100 | logits = logits.detach().cpu().tolist() # cpu(). 101 | input_ids = input_ids.detach().cpu().tolist() # cpu(). 
102 | 103 | for batch_indx, (t_logits, t_input_ids) in enumerate(zip(logits, input_ids)): 104 | doc_pos, all_spans = sentence_spans[batch_indx] 105 | 106 | sentence = doc.all_sentences[doc_pos] 107 | word_list = [] 108 | 109 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans: 110 | tword = sentence[first_word_pos:last_word_pos] 111 | 112 | acc_pos = check_ee_comu(tword) 113 | if acc_pos >= 0: 114 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'eee') 115 | word_list.append(tw) 116 | continue 117 | 118 | if count_vowels(tword) < 1: 119 | tw = WordAcc(first_word_pos, last_word_pos, -1, 'emp') 120 | word_list.append(tw) 121 | continue 122 | 123 | tword_key = norm_word(tword) 124 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key, 0) 125 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id] 126 | 127 | if len(acc_pos) < 1: 128 | acc_pos = self.unk_model.get_acc_pos(tword.casefold()) 129 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'oov') 130 | word_list.append(tw) 131 | continue 132 | 133 | if len(acc_pos) < 2: 134 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos[0], 'sin') 135 | word_list.append(tw) 136 | continue 137 | 138 | best_word_index = -1 139 | best_word_prob = -1.0 140 | best_letter_pos = -1 141 | letter_pos = 0 142 | 143 | for token_indx in range(first_token_pos, last_token_pos): 144 | soft_probs = _compute_softmax(t_logits[token_indx]) 145 | ct = self.tokenizer.token_vowel_pos.get(t_input_ids[token_indx]) 146 | if ct is None: 147 | self.error_counter += 1 148 | break 149 | num_letters, vowel_pos = ct 150 | 151 | best_prob = 0 152 | best_index = -1 153 | 154 | for prob_index, tprob in enumerate(soft_probs): 155 | if prob_index < 1: 156 | continue 157 | if prob_index > len(vowel_pos): 158 | break 159 | tpos = letter_pos + vowel_pos[prob_index-1] 160 | 161 | if tpos not in acc_pos: 162 | continue 163 | if tprob > best_prob: 164 | best_prob = tprob 165 | best_index = prob_index 166 
| 167 | if best_index > 0: 168 | if best_prob > best_word_prob: 169 | best_word_index = best_index 170 | best_word_prob = best_prob 171 | best_letter_pos = letter_pos + vowel_pos[best_index-1] 172 | letter_pos += num_letters 173 | 174 | if best_letter_pos >= 0: 175 | tw = WordAcc(first_word_pos, last_word_pos, best_letter_pos, 'var') 176 | word_list.append(tw) 177 | continue 178 | 179 | acc_pos = self.unk_model.get_acc_pos(tword.casefold()) 180 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'unk') 181 | word_list.append(tw) 182 | 183 | ct = (doc_pos, word_list) 184 | sum_batch.append(ct) 185 | 186 | def process_without_bert(self, doc, doc_pos, all_spans): 187 | sentence = doc.all_sentences[doc_pos] 188 | word_list = [] 189 | 190 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans: 191 | tword = sentence[first_word_pos:last_word_pos] 192 | 193 | acc_pos = check_ee_comu(tword) 194 | if acc_pos >= 0: 195 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'eee') 196 | word_list.append(tw) 197 | continue 198 | 199 | if count_vowels(tword) < 1: 200 | tw = WordAcc(first_word_pos, last_word_pos, -1, 'emp') 201 | word_list.append(tw) 202 | continue 203 | 204 | tword_key = norm_word(tword) 205 | # state = 'var' 206 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key) 207 | if word_acc_id is not None: 208 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id] 209 | if len(acc_pos) == 1: 210 | assert len(acc_pos) 211 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos[0], 'sin') 212 | word_list.append(tw) 213 | continue 214 | 215 | acc_pos = self.unk_model.get_acc_pos(tword.casefold()) 216 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'unk') 217 | word_list.append(tw) 218 | return word_list 219 | 220 | def easy_loop(self, doc, sum_batch): 221 | for doc_pos, all_spans in doc.easy_sentences: 222 | word_list = self.process_without_bert(doc, doc_pos, all_spans) 223 | ct = (doc_pos, word_list) 224 | 
sum_batch.append(ct) 225 | 226 | def result_to_text(self, sentence, word_list): 227 | scale = [0 for _ in range(len(sentence))] 228 | for first_word_pos, last_word_pos, acc_pos, state in word_list: 229 | if acc_pos >= 0: 230 | tp = first_word_pos + acc_pos 231 | assert tp < len(sentence) 232 | scale[tp] = 1 233 | 234 | res = [] 235 | for indx, tc in enumerate(sentence): 236 | if scale[indx]: 237 | res.append('+') 238 | res.append(tc) 239 | return ''.join(res) 240 | 241 | def result_to_word_list(self, sentence, word_list): 242 | prev_pos = 0 243 | out_word_list = [] 244 | for first_word_pos, last_word_pos, acc_pos, state in word_list: 245 | if acc_pos < 0: 246 | tword = sentence[first_word_pos:last_word_pos] 247 | else: 248 | tword_letters = [] 249 | rel_pos = acc_pos + first_word_pos 250 | for tpos in range(first_word_pos, last_word_pos): 251 | if tpos == rel_pos: 252 | tword_letters.append('+') 253 | tword_letters.append(sentence[tpos]) 254 | tword = ''.join(tword_letters) 255 | 256 | punct = sentence[prev_pos:first_word_pos] 257 | prev_pos = last_word_pos 258 | ct = PunctWord(punct, tword) 259 | out_word_list.append(ct) 260 | punct = sentence[prev_pos:] 261 | ct = PunctWord(punct, '') 262 | out_word_list.append(ct) 263 | return out_word_list 264 | 265 | def accentuate_by_words(self, input_sentence_list): 266 | if not bool(input_sentence_list): 267 | raise ValueError('list of strings is required') 268 | 269 | if not isinstance(input_sentence_list, list): 270 | raise ValueError('list of strings is required') 271 | 272 | if not all([isinstance(elem, str) for elem in input_sentence_list]): 273 | raise ValueError('list of strings is required') 274 | 275 | doc = AccentDocument(self.tokenizer, input_sentence_list) 276 | 277 | sum_batch = [] 278 | self.easy_loop(doc, sum_batch) 279 | for t_batch in doc.model_batches: 280 | self.process_model_batch(doc, t_batch, sum_batch) 281 | 282 | sum_batch_index = [] 283 | for indx in range(len(sum_batch)): 284 | pos = 
sum_batch[indx][0] 285 | sum_batch_index.append((pos, indx)) 286 | 287 | output_sentence_word_list = [] 288 | 289 | for pos, indx in sorted(sum_batch_index): 290 | sentence = doc.all_sentences[pos] 291 | sentence_words = self.result_to_word_list(sentence, sum_batch[indx][1]) 292 | output_sentence_word_list.append(sentence_words) 293 | return output_sentence_word_list 294 | 295 | def accentuate_sentence_list(self, input_sentence_list): 296 | if not bool(input_sentence_list): 297 | raise ValueError('a list of strings is required') 298 | if not isinstance(input_sentence_list, list): 299 | raise ValueError('a list of strings is required') 300 | if not all([ 301 | isinstance(elem, str) for elem in input_sentence_list]): 302 | raise ValueError('a list of strings is required') 303 | 304 | output_sentence_word_list = self.accentuate_by_words(input_sentence_list) 305 | output_sentence_text_list = [] 306 | for t_sentence in output_sentence_word_list: 307 | word_list = [] 308 | for punct, tword in t_sentence: 309 | word_list.append(punct) 310 | word_list.append(tword) 311 | output_sentence_text_list.append(''.join(word_list)) 312 | 313 | return output_sentence_text_list 314 | 315 | def accentuate(self, input_text): 316 | if not bool(input_text): 317 | raise ValueError('a string or list of strings is required') 318 | if isinstance(input_text, list): 319 | input_sentence_list = input_text 320 | 321 | if not all([ 322 | isinstance(elem, str) for elem in input_sentence_list]): 323 | raise ValueError('a string or list of strings is required') 324 | 325 | else: 326 | if isinstance(input_text, str): 327 | input_sentence_list = [input_text] 328 | else: 329 | raise ValueError('a string or list of strings is required') 330 | 331 | output_sentence_list = self.accentuate_sentence_list(input_sentence_list) 332 | 333 | if isinstance(input_text, str): 334 | return "\n".join(output_sentence_list) 335 | return output_sentence_list 336 | 337 | # ---------------------------- this is for 
debugging ----------------------------- 338 | 339 | def process_all_easy_sentence(self, doc, doc_pos, all_spans): 340 | sentence = doc.all_sentences[doc_pos] 341 | word_list = [] 342 | 343 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans: 344 | tword = sentence[first_word_pos:last_word_pos] 345 | 346 | acc_pos = check_ee_comu(tword) 347 | if acc_pos >= 0: 348 | tw = WordAcc(first_word_pos, last_word_pos, [acc_pos], 'eee') 349 | word_list.append(tw) 350 | continue 351 | 352 | tword_key = norm_word(tword) 353 | 354 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key) 355 | if word_acc_id is not None: 356 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id] 357 | 358 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'sin') 359 | word_list.append(tw) 360 | continue 361 | 362 | tw = WordAcc(first_word_pos, last_word_pos, [], 'unk') 363 | word_list.append(tw) 364 | return word_list 365 | 366 | def all_easy_loop(self, doc, sum_batch): 367 | for doc_pos, all_spans in doc.easy_sentences: 368 | word_list = self.process_all_easy_sentence(doc, doc_pos, all_spans) 369 | ct = (doc_pos, word_list) 370 | sum_batch.append(ct) 371 | 372 | def all_easy_to_word_list(self, sentence, word_list): 373 | prev_pos = 0 374 | out_word_list = [] 375 | for first_word_pos, last_word_pos, acc_pos, state in word_list: 376 | tword = sentence[first_word_pos:last_word_pos] 377 | 378 | if len(acc_pos) < 1: 379 | tword = sentence[first_word_pos:last_word_pos] 380 | else: 381 | tword_letters = [] 382 | stress_sign = '+' 383 | 384 | for tpos in range(first_word_pos, last_word_pos): 385 | rel_pos = tpos - first_word_pos 386 | 387 | if not acc_pos: 388 | tword_letters.append(stress_sign) 389 | elif rel_pos in acc_pos: 390 | tword_letters.append(stress_sign) 391 | tword_letters.append(sentence[tpos]) 392 | tword = ''.join(tword_letters) 393 | 394 | punct = sentence[prev_pos:first_word_pos] 395 | prev_pos = last_word_pos 396 | ct = 
PunctWord(punct, tword) 397 | out_word_list.append(ct) 398 | punct = sentence[prev_pos:] 399 | ct = PunctWord(punct, '') 400 | out_word_list.append(ct) 401 | return out_word_list 402 | 403 | def accentuate_all_easy(self, input_sentence_list): 404 | if not bool(input_sentence_list): 405 | raise ValueError('list of strings is required') 406 | 407 | if not isinstance(input_sentence_list, list): 408 | raise ValueError('list of strings is required') 409 | 410 | if not all([isinstance(elem, str) for elem in input_sentence_list]): 411 | raise ValueError('list of strings is required') 412 | 413 | doc = AccentDocument(self.tokenizer, input_sentence_list) 414 | 415 | doc.get_all_easy(self.tokenizer) 416 | 417 | sum_batch = [] 418 | self.all_easy_loop(doc, sum_batch) 419 | 420 | sum_batch_index = [] 421 | for indx in range(len(sum_batch)): 422 | pos = sum_batch[indx][0] 423 | sum_batch_index.append((pos, indx)) 424 | 425 | output_sentence_word_list = [] 426 | 427 | for pos, indx in sorted(sum_batch_index): 428 | sentence = doc.all_sentences[pos] 429 | sentence_words = self.all_easy_to_word_list(sentence, sum_batch[indx][1]) 430 | output_sentence_word_list.append(sentence_words) 431 | 432 | output_sentence_text_list = [] 433 | for t_sentence in output_sentence_word_list: 434 | word_list = [] 435 | for punct, tword in t_sentence: 436 | word_list.append(punct) 437 | word_list.append(tword) 438 | output_sentence_text_list.append(''.join(word_list)) 439 | 440 | return output_sentence_text_list 441 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. 
Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. 
If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. 
License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. 
Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. 
For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. 
Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. 
For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. 
You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 
354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. 
To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 
436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | -------------------------------------------------------------------------------- /omogre/accentuator/bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # edited from pytorch_pretrained_bert 4 | 5 | # https://github.com/google-research/bert 6 | # https://github.com/maknotavailable/pytorch-pretrained-BERT 7 | 8 | # coding=utf-8 9 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 10 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 23 | 24 | """PyTorch BERT model.""" 25 | 26 | from __future__ import absolute_import, division, print_function, unicode_literals 27 | 28 | import copy 29 | import json 30 | import logging 31 | import math 32 | import os 33 | import shutil 34 | import tarfile 35 | import tempfile 36 | import sys 37 | from io import open 38 | 39 | import torch 40 | from torch import nn 41 | from torch.nn import CrossEntropyLoss 42 | 43 | CONFIG_NAME = "config.json" 44 | WEIGHTS_NAME = "pytorch_model.bin" 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | BERT_CONFIG_NAME = 'bert_config.json' 49 | TF_WEIGHTS_NAME = 'model.ckpt' 50 | 51 | def gelu(x): 52 | """Implementation of the gelu activation function. 
53 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 54 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 55 | Also see https://arxiv.org/abs/1606.08415 56 | """ 57 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 58 | 59 | 60 | def swish(x): 61 | return x * torch.sigmoid(x) 62 | 63 | 64 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 65 | 66 | 67 | class BertConfig(object): 68 | """Configuration class to store the configuration of a `BertModel`. 69 | """ 70 | def __init__(self, 71 | vocab_size_or_config_json_file, 72 | hidden_size=768, 73 | num_hidden_layers=12, 74 | num_attention_heads=12, 75 | intermediate_size=3072, 76 | hidden_act="gelu", 77 | hidden_dropout_prob=0.1, 78 | attention_probs_dropout_prob=0.1, 79 | max_position_embeddings=512, 80 | type_vocab_size=2, 81 | initializer_range=0.02): 82 | """Constructs BertConfig. 83 | 84 | Args: 85 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 86 | hidden_size: Size of the encoder layers and the pooler layer. 87 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 88 | num_attention_heads: Number of attention heads for each attention layer in 89 | the Transformer encoder. 90 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 91 | layer in the Transformer encoder. 92 | hidden_act: The non-linear activation function (function or string) in the 93 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 94 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 95 | layers in the embeddings, encoder, and pooler. 96 | attention_probs_dropout_prob: The dropout ratio for the attention 97 | probabilities. 98 | max_position_embeddings: The maximum sequence length that this model might 99 | ever be used with. Typically set this to something large just in case 100 | (e.g., 512 or 1024 or 2048). 
101 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 102 | `BertModel`. 103 | initializer_range: The sttdev of the truncated_normal_initializer for 104 | initializing all weight matrices. 105 | """ 106 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 107 | and isinstance(vocab_size_or_config_json_file, unicode)): 108 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 109 | json_config = json.loads(reader.read()) 110 | for key, value in json_config.items(): 111 | self.__dict__[key] = value 112 | elif isinstance(vocab_size_or_config_json_file, int): 113 | self.vocab_size = vocab_size_or_config_json_file 114 | self.hidden_size = hidden_size 115 | self.num_hidden_layers = num_hidden_layers 116 | self.num_attention_heads = num_attention_heads 117 | self.hidden_act = hidden_act 118 | self.intermediate_size = intermediate_size 119 | self.hidden_dropout_prob = hidden_dropout_prob 120 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 121 | self.max_position_embeddings = max_position_embeddings 122 | self.type_vocab_size = type_vocab_size 123 | self.initializer_range = initializer_range 124 | else: 125 | raise ValueError("First argument must be either a vocabulary size (int)" 126 | "or the path to a pretrained model config file (str)") 127 | 128 | @classmethod 129 | def from_dict(cls, json_object): 130 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 131 | config = BertConfig(vocab_size_or_config_json_file=-1) 132 | for key, value in json_object.items(): 133 | config.__dict__[key] = value 134 | return config 135 | 136 | @classmethod 137 | def from_json_file(cls, json_file): 138 | """Constructs a `BertConfig` from a json file of parameters.""" 139 | with open(json_file, "r", encoding='utf-8') as reader: 140 | text = reader.read() 141 | return cls.from_dict(json.loads(text)) 142 | 143 | def __repr__(self): 144 | return str(self.to_json_string()) 145 | 
146 | def to_dict(self): 147 | """Serializes this instance to a Python dictionary.""" 148 | output = copy.deepcopy(self.__dict__) 149 | return output 150 | 151 | def to_json_string(self): 152 | """Serializes this instance to a JSON string.""" 153 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 154 | 155 | def to_json_file(self, json_file_path): 156 | """ Save this instance to a json file.""" 157 | with open(json_file_path, "w", encoding='utf-8') as writer: 158 | writer.write(self.to_json_string()) 159 | 160 | #try: 161 | # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 162 | #except ImportError: 163 | # logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 164 | 165 | class BertLayerNorm(nn.Module): 166 | def __init__(self, hidden_size, eps=1e-12): 167 | """Construct a layernorm module in the TF style (epsilon inside the square root). 168 | """ 169 | super(BertLayerNorm, self).__init__() 170 | self.weight = nn.Parameter(torch.ones(hidden_size)) 171 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 172 | self.variance_epsilon = eps 173 | 174 | def forward(self, x): 175 | u = x.mean(-1, keepdim=True) 176 | s = (x - u).pow(2).mean(-1, keepdim=True) 177 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 178 | return self.weight * x + self.bias 179 | 180 | class BertEmbeddings(nn.Module): 181 | """Construct the embeddings from word, position and token_type embeddings. 
182 | """ 183 | def __init__(self, config): 184 | super(BertEmbeddings, self).__init__() 185 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 186 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 187 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 188 | 189 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 190 | # any TensorFlow checkpoint file 191 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 192 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 193 | 194 | def forward(self, input_ids, token_type_ids=None): 195 | seq_length = input_ids.size(1) 196 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 197 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 198 | if token_type_ids is None: 199 | token_type_ids = torch.zeros_like(input_ids) 200 | 201 | words_embeddings = self.word_embeddings(input_ids) 202 | position_embeddings = self.position_embeddings(position_ids) 203 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 204 | 205 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 206 | embeddings = self.LayerNorm(embeddings) 207 | embeddings = self.dropout(embeddings) 208 | return embeddings 209 | 210 | 211 | class BertSelfAttention(nn.Module): 212 | def __init__(self, config): 213 | super(BertSelfAttention, self).__init__() 214 | if config.hidden_size % config.num_attention_heads != 0: 215 | raise ValueError( 216 | "The hidden size (%d) is not a multiple of the number of attention " 217 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 218 | self.num_attention_heads = config.num_attention_heads 219 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 220 | self.all_head_size = self.num_attention_heads * self.attention_head_size 221 
| 222 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 223 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 224 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 225 | 226 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 227 | 228 | def transpose_for_scores(self, x): 229 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 230 | x = x.view(*new_x_shape) 231 | return x.permute(0, 2, 1, 3) 232 | 233 | def forward(self, hidden_states, attention_mask): 234 | mixed_query_layer = self.query(hidden_states) 235 | mixed_key_layer = self.key(hidden_states) 236 | mixed_value_layer = self.value(hidden_states) 237 | 238 | query_layer = self.transpose_for_scores(mixed_query_layer) 239 | key_layer = self.transpose_for_scores(mixed_key_layer) 240 | value_layer = self.transpose_for_scores(mixed_value_layer) 241 | 242 | # Take the dot product between "query" and "key" to get the raw attention scores. 243 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 244 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 245 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 246 | attention_scores = attention_scores + attention_mask 247 | 248 | # Normalize the attention scores to probabilities. 249 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 250 | 251 | # This is actually dropping out entire tokens to attend to, which might 252 | # seem a bit unusual, but is taken from the original Transformer paper. 
253 | attention_probs = self.dropout(attention_probs) 254 | 255 | context_layer = torch.matmul(attention_probs, value_layer) 256 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 257 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 258 | context_layer = context_layer.view(*new_context_layer_shape) 259 | return context_layer 260 | 261 | 262 | class BertSelfOutput(nn.Module): 263 | def __init__(self, config): 264 | super(BertSelfOutput, self).__init__() 265 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 266 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 267 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 268 | 269 | def forward(self, hidden_states, input_tensor): 270 | hidden_states = self.dense(hidden_states) 271 | hidden_states = self.dropout(hidden_states) 272 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 273 | return hidden_states 274 | 275 | 276 | class BertAttention(nn.Module): 277 | def __init__(self, config): 278 | super(BertAttention, self).__init__() 279 | self.self = BertSelfAttention(config) 280 | self.output = BertSelfOutput(config) 281 | 282 | def forward(self, input_tensor, attention_mask): 283 | self_output = self.self(input_tensor, attention_mask) 284 | attention_output = self.output(self_output, input_tensor) 285 | return attention_output 286 | 287 | 288 | class BertIntermediate(nn.Module): 289 | def __init__(self, config): 290 | super(BertIntermediate, self).__init__() 291 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 292 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 293 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 294 | else: 295 | self.intermediate_act_fn = config.hidden_act 296 | 297 | def forward(self, hidden_states): 298 | hidden_states = self.dense(hidden_states) 299 | hidden_states = self.intermediate_act_fn(hidden_states) 300 | return 
hidden_states 301 | 302 | 303 | class BertOutput(nn.Module): 304 | def __init__(self, config): 305 | super(BertOutput, self).__init__() 306 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 307 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 308 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 309 | 310 | def forward(self, hidden_states, input_tensor): 311 | hidden_states = self.dense(hidden_states) 312 | hidden_states = self.dropout(hidden_states) 313 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 314 | return hidden_states 315 | 316 | 317 | class BertLayer(nn.Module): 318 | def __init__(self, config): 319 | super(BertLayer, self).__init__() 320 | self.attention = BertAttention(config) 321 | self.intermediate = BertIntermediate(config) 322 | self.output = BertOutput(config) 323 | 324 | def forward(self, hidden_states, attention_mask): 325 | attention_output = self.attention(hidden_states, attention_mask) 326 | intermediate_output = self.intermediate(attention_output) 327 | layer_output = self.output(intermediate_output, attention_output) 328 | return layer_output 329 | 330 | 331 | class BertEncoder(nn.Module): 332 | def __init__(self, config): 333 | super(BertEncoder, self).__init__() 334 | layer = BertLayer(config) 335 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 336 | 337 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): 338 | all_encoder_layers = [] 339 | for layer_module in self.layer: 340 | hidden_states = layer_module(hidden_states, attention_mask) 341 | if output_all_encoded_layers: 342 | all_encoder_layers.append(hidden_states) 343 | if not output_all_encoded_layers: 344 | all_encoder_layers.append(hidden_states) 345 | return all_encoder_layers 346 | 347 | 348 | class BertPooler(nn.Module): 349 | def __init__(self, config): 350 | super(BertPooler, self).__init__() 351 | self.dense = nn.Linear(config.hidden_size, 
config.hidden_size) 352 | self.activation = nn.Tanh() 353 | 354 | def forward(self, hidden_states): 355 | # We "pool" the model by simply taking the hidden state corresponding 356 | # to the first token. 357 | first_token_tensor = hidden_states[:, 0] 358 | pooled_output = self.dense(first_token_tensor) 359 | pooled_output = self.activation(pooled_output) 360 | return pooled_output 361 | 362 | 363 | class BertPredictionHeadTransform(nn.Module): 364 | def __init__(self, config): 365 | super(BertPredictionHeadTransform, self).__init__() 366 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 367 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 368 | self.transform_act_fn = ACT2FN[config.hidden_act] 369 | else: 370 | self.transform_act_fn = config.hidden_act 371 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 372 | 373 | def forward(self, hidden_states): 374 | hidden_states = self.dense(hidden_states) 375 | hidden_states = self.transform_act_fn(hidden_states) 376 | hidden_states = self.LayerNorm(hidden_states) 377 | return hidden_states 378 | 379 | 380 | class BertLMPredictionHead(nn.Module): 381 | def __init__(self, config, bert_model_embedding_weights): 382 | super(BertLMPredictionHead, self).__init__() 383 | self.transform = BertPredictionHeadTransform(config) 384 | 385 | # The output weights are the same as the input embeddings, but there is 386 | # an output-only bias for each token. 
387 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 388 | bert_model_embedding_weights.size(0), 389 | bias=False) 390 | self.decoder.weight = bert_model_embedding_weights 391 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 392 | 393 | def forward(self, hidden_states): 394 | hidden_states = self.transform(hidden_states) 395 | hidden_states = self.decoder(hidden_states) + self.bias 396 | return hidden_states 397 | 398 | 399 | class BertOnlyMLMHead(nn.Module): 400 | def __init__(self, config, bert_model_embedding_weights): 401 | super(BertOnlyMLMHead, self).__init__() 402 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 403 | 404 | def forward(self, sequence_output): 405 | prediction_scores = self.predictions(sequence_output) 406 | return prediction_scores 407 | 408 | 409 | class BertOnlyNSPHead(nn.Module): 410 | def __init__(self, config): 411 | super(BertOnlyNSPHead, self).__init__() 412 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 413 | 414 | def forward(self, pooled_output): 415 | seq_relationship_score = self.seq_relationship(pooled_output) 416 | return seq_relationship_score 417 | 418 | 419 | class BertPreTrainingHeads(nn.Module): 420 | def __init__(self, config, bert_model_embedding_weights): 421 | super(BertPreTrainingHeads, self).__init__() 422 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 423 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 424 | 425 | def forward(self, sequence_output, pooled_output): 426 | prediction_scores = self.predictions(sequence_output) 427 | seq_relationship_score = self.seq_relationship(pooled_output) 428 | return prediction_scores, seq_relationship_score 429 | 430 | 431 | class BertPreTrainedModel(nn.Module): 432 | """ An abstract class to handle weights initialization and 433 | a simple interface for dowloading and loading pretrained models. 
434 | """ 435 | def __init__(self, config, *inputs, **kwargs): 436 | super(BertPreTrainedModel, self).__init__() 437 | if not isinstance(config, BertConfig): 438 | raise ValueError( 439 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 440 | "To create a model from a Google pretrained model use " 441 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 442 | self.__class__.__name__, self.__class__.__name__ 443 | )) 444 | self.config = config 445 | 446 | def init_bert_weights(self, module): 447 | """ Initialize the weights. 448 | """ 449 | if isinstance(module, (nn.Linear, nn.Embedding)): 450 | # Slightly different from the TF version which uses truncated_normal for initialization 451 | # cf https://github.com/pytorch/pytorch/pull/5617 452 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 453 | elif isinstance(module, BertLayerNorm): 454 | module.bias.data.zero_() 455 | module.weight.data.fill_(1.0) 456 | if isinstance(module, nn.Linear) and module.bias is not None: 457 | module.bias.data.zero_() 458 | 459 | @classmethod 460 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 461 | """ 462 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 463 | Download and cache the pre-trained model file if needed. 464 | 465 | Params: 466 | pretrained_model_name_or_path: either: 467 | - a str with the name of a pre-trained model to load selected in the list of: 468 | . `bert-base-uncased` 469 | . `bert-large-uncased` 470 | . `bert-base-cased` 471 | . `bert-large-cased` 472 | . `bert-base-multilingual-uncased` 473 | . `bert-base-multilingual-cased` 474 | . `bert-base-chinese` 475 | - a path or url to a pretrained model archive containing: 476 | . `bert_config.json` a configuration file for the model 477 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 478 | - a path or url to a pretrained model archive containing: 479 | . 
`bert_config.json` a configuration file for the model 480 | . `model.chkpt` a TensorFlow checkpoint 481 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 482 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 483 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 484 | *inputs, **kwargs: additional input for the specific Bert class 485 | (ex: num_labels for BertForSequenceClassification) 486 | """ 487 | state_dict = kwargs.get('state_dict', None) 488 | kwargs.pop('state_dict', None) 489 | cache_dir = kwargs.get('cache_dir', None) 490 | kwargs.pop('cache_dir', None) 491 | from_tf = kwargs.get('from_tf', False) 492 | kwargs.pop('from_tf', None) 493 | 494 | archive_file = pretrained_model_name_or_path 495 | logger.info("loading archive file {}".format(archive_file)) 496 | resolved_archive_file = archive_file 497 | ''' 498 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 499 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 500 | else: 501 | archive_file = pretrained_model_name_or_path 502 | # redirect to the cache, if necessary 503 | try: 504 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 505 | except EnvironmentError: 506 | logger.error( 507 | "Model name '{}' was not found in model name list ({}). 
" 508 | "We assumed '{}' was a path or url but couldn't find any file " 509 | "associated to this path or url.".format( 510 | pretrained_model_name_or_path, 511 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 512 | archive_file)) 513 | return None 514 | if resolved_archive_file == archive_file: 515 | logger.info("loading archive file {}".format(archive_file)) 516 | else: 517 | logger.info("loading archive file {} from cache at {}".format( 518 | archive_file, resolved_archive_file)) 519 | ''' 520 | tempdir = None 521 | serialization_dir = resolved_archive_file 522 | ''' 523 | if os.path.isdir(resolved_archive_file) or from_tf: 524 | serialization_dir = resolved_archive_file 525 | else: 526 | # Extract archive to temp dir 527 | tempdir = tempfile.mkdtemp() 528 | logger.info("extracting archive file {} to temp dir {}".format( 529 | resolved_archive_file, tempdir)) 530 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 531 | archive.extractall(tempdir) 532 | serialization_dir = tempdir 533 | ''' 534 | # Load config 535 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 536 | #if not os.path.exists(config_file): 537 | # # Backward compatibility with old naming format 538 | # config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) 539 | config = BertConfig.from_json_file(config_file) 540 | logger.info("Model config {}".format(config)) 541 | # Instantiate model. 
542 | model = cls(config, *inputs, **kwargs) 543 | if state_dict is None and not from_tf: 544 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 545 | state_dict = torch.load(weights_path, map_location='cpu', weights_only=True) 546 | if tempdir: 547 | # Clean up temp dir 548 | shutil.rmtree(tempdir) 549 | if from_tf: 550 | # Directly load from a TensorFlow checkpoint 551 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) 552 | return load_tf_weights_in_bert(model, weights_path) 553 | # Load from a PyTorch state_dict 554 | old_keys = [] 555 | new_keys = [] 556 | for key in state_dict.keys(): 557 | new_key = None 558 | if 'gamma' in key: 559 | new_key = key.replace('gamma', 'weight') 560 | if 'beta' in key: 561 | new_key = key.replace('beta', 'bias') 562 | if new_key: 563 | old_keys.append(key) 564 | new_keys.append(new_key) 565 | for old_key, new_key in zip(old_keys, new_keys): 566 | state_dict[new_key] = state_dict.pop(old_key) 567 | 568 | missing_keys = [] 569 | unexpected_keys = [] 570 | error_msgs = [] 571 | # copy state_dict so _load_from_state_dict can modify it 572 | metadata = getattr(state_dict, '_metadata', None) 573 | state_dict = state_dict.copy() 574 | if metadata is not None: 575 | state_dict._metadata = metadata 576 | 577 | def load(module, prefix=''): 578 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 579 | module._load_from_state_dict( 580 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 581 | for name, child in module._modules.items(): 582 | if child is not None: 583 | load(child, prefix + name + '.') 584 | start_prefix = '' 585 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): 586 | start_prefix = 'bert.' 
587 | load(model, prefix=start_prefix) 588 | if len(missing_keys) > 0: 589 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 590 | model.__class__.__name__, missing_keys)) 591 | if len(unexpected_keys) > 0: 592 | logger.info("Weights from pretrained model not used in {}: {}".format( 593 | model.__class__.__name__, unexpected_keys)) 594 | if len(error_msgs) > 0: 595 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( 596 | model.__class__.__name__, "\n\t".join(error_msgs))) 597 | return model 598 | 599 | 600 | class BertModel(BertPreTrainedModel): 601 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 602 | 603 | Params: 604 | config: a BertConfig class instance with the configuration to build a new model 605 | 606 | Inputs: 607 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 608 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 609 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 610 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 611 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 612 | a `sentence B` token (see BERT paper for more details). 613 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 614 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 615 | input sequence length in the current batch. It's the mask that we typically use for attention when 616 | a batch has varying length sentences. 617 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 
618 | 619 | Outputs: Tuple of (encoded_layers, pooled_output) 620 | `encoded_layers`: controled by `output_all_encoded_layers` argument: 621 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 622 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 623 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 624 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 625 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 626 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 627 | classifier pretrained on top of the hidden state associated to the first character of the 628 | input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 629 | 630 | Example usage: 631 | ```python 632 | # Already been converted into WordPiece token ids 633 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 634 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 635 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 636 | 637 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 638 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 639 | 640 | model = modeling.BertModel(config=config) 641 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 642 | ``` 643 | """ 644 | def __init__(self, config): 645 | super(BertModel, self).__init__(config) 646 | self.embeddings = BertEmbeddings(config) 647 | self.encoder = BertEncoder(config) 648 | self.pooler = BertPooler(config) 649 | self.apply(self.init_bert_weights) 650 | 651 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 652 | if attention_mask is None: 653 | attention_mask = torch.ones_like(input_ids) 654 | if 
token_type_ids is None: 655 | token_type_ids = torch.zeros_like(input_ids) 656 | 657 | # We create a 3D attention mask from a 2D tensor mask. 658 | # Sizes are [batch_size, 1, 1, to_seq_length] 659 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 660 | # this attention mask is more simple than the triangular masking of causal attention 661 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 662 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 663 | 664 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 665 | # masked positions, this operation will create a tensor which is 0.0 for 666 | # positions we want to attend and -10000.0 for masked positions. 667 | # Since we are adding it to the raw scores before the softmax, this is 668 | # effectively the same as removing these entirely. 669 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 670 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 671 | 672 | embedding_output = self.embeddings(input_ids, token_type_ids) 673 | encoded_layers = self.encoder(embedding_output, 674 | extended_attention_mask, 675 | output_all_encoded_layers=output_all_encoded_layers) 676 | sequence_output = encoded_layers[-1] 677 | pooled_output = self.pooler(sequence_output) 678 | if not output_all_encoded_layers: 679 | encoded_layers = encoded_layers[-1] 680 | return encoded_layers, pooled_output 681 | 682 | 683 | class BertForPreTraining(BertPreTrainedModel): 684 | """BERT model with pre-training heads. 685 | This module comprises the BERT model followed by the two pre-training heads: 686 | - the masked language modeling head, and 687 | - the next sentence classification head. 688 | 689 | Params: 690 | config: a BertConfig class instance with the configuration to build a new model. 
691 | 692 | Inputs: 693 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 694 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 695 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 696 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 697 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 698 | a `sentence B` token (see BERT paper for more details). 699 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 700 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 701 | input sequence length in the current batch. It's the mask that we typically use for attention when 702 | a batch has varying length sentences. 703 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 704 | with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked), the loss 705 | is only computed for the labels set in [0, ..., vocab_size - 1]. 706 | `next_sentence_label`: optional next sentence classification labels: torch.LongTensor of shape [batch_size] 707 | with indices selected in [0, 1]. 708 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 709 | 710 | Outputs: 711 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 712 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 713 | sentence classification loss. 714 | if `masked_lm_labels` or `next_sentence_label` is `None`: 715 | Outputs a tuple comprising 716 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 717 | - the next sentence classification logits of shape [batch_size, 2].
718 | 719 | Example usage: 720 | ```python 721 | # Already been converted into WordPiece token ids 722 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 723 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 724 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 725 | 726 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 727 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 728 | 729 | model = BertForPreTraining(config) 730 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 731 | ``` 732 | """ 733 | def __init__(self, config): 734 | super(BertForPreTraining, self).__init__(config) 735 | self.bert = BertModel(config) 736 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 737 | self.apply(self.init_bert_weights) 738 | 739 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 740 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 741 | output_all_encoded_layers=False) 742 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 743 | 744 | if masked_lm_labels is not None and next_sentence_label is not None: 745 | loss_fct = CrossEntropyLoss(ignore_index=-1) 746 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 747 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 748 | total_loss = masked_lm_loss + next_sentence_loss 749 | return total_loss 750 | else: 751 | return prediction_scores, seq_relationship_score 752 | 753 | 754 | class BertForMaskedLM(BertPreTrainedModel): 755 | """BERT model with the masked language modeling head. 756 | This module comprises the BERT model followed by the masked language modeling head. 
757 | 758 | Params: 759 | config: a BertConfig class instance with the configuration to build a new model. 760 | 761 | Inputs: 762 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 763 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 764 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 765 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 766 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 767 | a `sentence B` token (see BERT paper for more details). 768 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 769 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 770 | input sequence length in the current batch. It's the mask that we typically use for attention when 771 | a batch has varying length sentences. 772 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 773 | with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked), the loss 774 | is only computed for the labels set in [0, ..., vocab_size - 1]. 775 | 776 | Outputs: 777 | if `masked_lm_labels` is not `None`: 778 | Outputs the masked language modeling loss. 779 | if `masked_lm_labels` is `None`: 780 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
781 | 782 | Example usage: 783 | ```python 784 | # Already been converted into WordPiece token ids 785 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 786 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 787 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 788 | 789 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 790 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 791 | 792 | model = BertForMaskedLM(config) 793 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 794 | ``` 795 | """ 796 | def __init__(self, config): 797 | super(BertForMaskedLM, self).__init__(config) 798 | self.bert = BertModel(config) 799 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 800 | self.apply(self.init_bert_weights) 801 | 802 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 803 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 804 | output_all_encoded_layers=False) 805 | prediction_scores = self.cls(sequence_output) 806 | 807 | if masked_lm_labels is not None: 808 | loss_fct = CrossEntropyLoss(ignore_index=-1) 809 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 810 | return masked_lm_loss 811 | else: 812 | return prediction_scores 813 | 814 | 815 | class BertForNextSentencePrediction(BertPreTrainedModel): 816 | """BERT model with next sentence prediction head. 817 | This module comprises the BERT model followed by the next sentence classification head. 818 | 819 | Params: 820 | config: a BertConfig class instance with the configuration to build a new model. 
821 | 822 | Inputs: 823 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 824 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 825 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 826 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 827 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 828 | a `sentence B` token (see BERT paper for more details). 829 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 830 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 831 | input sequence length in the current batch. It's the mask that we typically use for attention when 832 | a batch has varying length sentences. 833 | `next_sentence_label`: optional next sentence classification labels: torch.LongTensor of shape [batch_size] 834 | with indices selected in [0, 1]. 835 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 836 | 837 | Outputs: 838 | if `next_sentence_label` is not `None`: 839 | Outputs the next sentence classification (cross-entropy) loss; labels set to -1 are 840 | ignored. This model computes no masked language modeling loss. 841 | if `next_sentence_label` is `None`: 842 | Outputs the next sentence classification logits of shape [batch_size, 2].
843 | 844 | Example usage: 845 | ```python 846 | # Already been converted into WordPiece token ids 847 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 848 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 849 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 850 | 851 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 852 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 853 | 854 | model = BertForNextSentencePrediction(config) 855 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 856 | ``` 857 | """ 858 | def __init__(self, config): 859 | super(BertForNextSentencePrediction, self).__init__(config) 860 | self.bert = BertModel(config) 861 | self.cls = BertOnlyNSPHead(config) 862 | self.apply(self.init_bert_weights) 863 | 864 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 865 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 866 | output_all_encoded_layers=False) 867 | seq_relationship_score = self.cls( pooled_output) 868 | 869 | if next_sentence_label is not None: 870 | loss_fct = CrossEntropyLoss(ignore_index=-1) 871 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 872 | return next_sentence_loss 873 | else: 874 | return seq_relationship_score 875 | 876 | 877 | class BertForSequenceClassification(BertPreTrainedModel): 878 | """BERT model for classification. 879 | This module is composed of the BERT model with a linear layer on top of 880 | the pooled output. 881 | 882 | Params: 883 | `config`: a BertConfig class instance with the configuration to build a new model. 884 | `num_labels`: the number of classes for the classifier. Default = 2. 885 | 886 | Inputs: 887 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 888 | with the word token indices in the vocabulary. 
Items in the batch should begin with the special "CLS" token (see the tokens preprocessing logic in the scripts 889 | `extract_features.py`, `run_classifier.py` and `run_squad.py`). 890 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 891 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 892 | a `sentence B` token (see BERT paper for more details). 893 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 894 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 895 | input sequence length in the current batch. It's the mask that we typically use for attention when 896 | a batch has varying length sentences. 897 | `labels`: optional labels for the classification output: torch.LongTensor of shape [batch_size] 898 | with indices selected in [0, ..., num_labels - 1]. 899 | 900 | Outputs: 901 | if `labels` is not `None`: 902 | Outputs the CrossEntropy classification loss of the output with the labels. 903 | if `labels` is `None`: 904 | Outputs the classification logits of shape [batch_size, num_labels].
905 | 906 | Example usage: 907 | ```python 908 | # Already been converted into WordPiece token ids 909 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 910 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 911 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 912 | 913 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 914 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 915 | 916 | num_labels = 2 917 | 918 | model = BertForSequenceClassification(config, num_labels) 919 | logits = model(input_ids, token_type_ids, input_mask) 920 | ``` 921 | """ 922 | def __init__(self, config, num_labels): 923 | super(BertForSequenceClassification, self).__init__(config) 924 | self.num_labels = num_labels 925 | self.bert = BertModel(config) 926 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 927 | self.classifier = nn.Linear(config.hidden_size, num_labels) 928 | self.apply(self.init_bert_weights) 929 | 930 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 931 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 932 | pooled_output = self.dropout(pooled_output) 933 | logits = self.classifier(pooled_output) 934 | 935 | if labels is not None: 936 | loss_fct = CrossEntropyLoss() 937 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 938 | return loss 939 | else: 940 | return logits 941 | 942 | 943 | class BertForMultipleChoice(BertPreTrainedModel): 944 | """BERT model for multiple choice tasks. 945 | This module is composed of the BERT model with a linear layer on top of 946 | the pooled output. 947 | 948 | Params: 949 | `config`: a BertConfig class instance with the configuration to build a new model. 950 | `num_choices`: the number of classes for the classifier. Default = 2. 
951 | 952 | Inputs: 953 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 954 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 955 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 956 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 957 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 958 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 959 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 960 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 961 | input sequence length in the current batch. It's the mask that we typically use for attention when 962 | a batch has varying length sentences. 963 | `labels`: optional labels for the classification output: torch.LongTensor of shape [batch_size] 964 | with indices selected in [0, ..., num_choices - 1]. 965 | 966 | Outputs: 967 | if `labels` is not `None`: 968 | Outputs the CrossEntropy classification loss of the output with the labels. 969 | if `labels` is `None`: 970 | Outputs the classification logits of shape [batch_size, num_choices].
971 | 972 | Example usage: 973 | ```python 974 | # Already been converted into WordPiece token ids 975 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 976 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 977 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 978 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 979 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 980 | 981 | num_choices = 2 982 | 983 | model = BertForMultipleChoice(config, num_choices) 984 | logits = model(input_ids, token_type_ids, input_mask) 985 | ``` 986 | """ 987 | def __init__(self, config, num_choices): 988 | super(BertForMultipleChoice, self).__init__(config) 989 | self.num_choices = num_choices 990 | self.bert = BertModel(config) 991 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 992 | self.classifier = nn.Linear(config.hidden_size, 1) 993 | self.apply(self.init_bert_weights) 994 | 995 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 996 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 997 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 998 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 999 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) 1000 | pooled_output = self.dropout(pooled_output) 1001 | logits = self.classifier(pooled_output) 1002 | reshaped_logits = logits.view(-1, self.num_choices) 1003 | 1004 | if labels is not None: 1005 | loss_fct = CrossEntropyLoss() 1006 | loss = loss_fct(reshaped_logits, labels) 1007 | return loss 1008 | else: 1009 | return reshaped_logits 1010 | 1011 | 1012 | class BertForTokenClassification(BertPreTrainedModel): 1013 | """BERT model for 
token-level classification. 1014 | This module is composed of the BERT model with a linear layer on top of 1015 | the full hidden state of the last layer. 1016 | 1017 | Params: 1018 | `config`: a BertConfig class instance with the configuration to build a new model. 1019 | `num_labels`: the number of classes for the classifier. Default = 2. 1020 | 1021 | Inputs: 1022 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1023 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1024 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1025 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1026 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1027 | a `sentence B` token (see BERT paper for more details). 1028 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1029 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1030 | input sequence length in the current batch. It's the mask that we typically use for attention when 1031 | a batch has varying length sentences. 1032 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1033 | with indices selected in [0, ..., num_labels]. 1034 | 1035 | Outputs: 1036 | if `labels` is not `None`: 1037 | Outputs the CrossEntropy classification loss of the output with the labels. 1038 | if `labels` is `None`: 1039 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 
1040 | 1041 | Example usage: 1042 | ```python 1043 | # Already been converted into WordPiece token ids 1044 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1045 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1046 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1047 | 1048 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1049 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1050 | 1051 | num_labels = 2 1052 | 1053 | model = BertForTokenClassification(config, num_labels) 1054 | logits = model(input_ids, token_type_ids, input_mask) 1055 | ``` 1056 | """ 1057 | def __init__(self, config, num_labels): 1058 | super(BertForTokenClassification, self).__init__(config) 1059 | self.num_labels = num_labels 1060 | self.bert = BertModel(config) 1061 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1062 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1063 | self.apply(self.init_bert_weights) 1064 | 1065 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1066 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1067 | sequence_output = self.dropout(sequence_output) 1068 | logits = self.classifier(sequence_output) 1069 | 1070 | if labels is not None: 1071 | loss_fct = CrossEntropyLoss() 1072 | # Only keep active parts of the loss 1073 | if attention_mask is not None: 1074 | active_loss = attention_mask.view(-1) == 1 1075 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1076 | active_labels = labels.view(-1)[active_loss] 1077 | loss = loss_fct(active_logits, active_labels) 1078 | else: 1079 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1080 | return loss 1081 | else: 1082 | return logits 1083 | 1084 | 1085 | class BertForQuestionAnswering(BertPreTrainedModel): 1086 | """BERT model for Question Answering (span extraction). 
1087 | This module is composed of the BERT model with a linear layer on top of 1088 | the sequence output that computes start_logits and end_logits 1089 | 1090 | Params: 1091 | `config`: a BertConfig class instance with the configuration to build a new model. 1092 | 1093 | Inputs: 1094 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1095 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1096 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1097 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1098 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1099 | a `sentence B` token (see BERT paper for more details). 1100 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1101 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1102 | input sequence length in the current batch. It's the mask that we typically use for attention when 1103 | a batch has varying length sentences. 1104 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1105 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1106 | into account for computing the loss. 1107 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1108 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1109 | into account for computing the loss. 1110 | 1111 | Outputs: 1112 | if `start_positions` and `end_positions` are not `None`: 1113 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 
1114 | if `start_positions` or `end_positions` is `None`: 1115 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1116 | position tokens of shape [batch_size, sequence_length]. 1117 | 1118 | Example usage: 1119 | ```python 1120 | # Already been converted into WordPiece token ids 1121 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1122 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1123 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1124 | 1125 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1126 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1127 | 1128 | model = BertForQuestionAnswering(config) 1129 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1130 | ``` 1131 | """ 1132 | def __init__(self, config): 1133 | super(BertForQuestionAnswering, self).__init__(config) 1134 | self.bert = BertModel(config) 1135 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 1136 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 1137 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1138 | self.apply(self.init_bert_weights) 1139 | 1140 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 1141 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1142 | logits = self.qa_outputs(sequence_output) 1143 | start_logits, end_logits = logits.split(1, dim=-1) 1144 | start_logits = start_logits.squeeze(-1) 1145 | end_logits = end_logits.squeeze(-1) 1146 | 1147 | if start_positions is not None and end_positions is not None: 1148 | # If we are on multi-GPU, split add a dimension 1149 | if len(start_positions.size()) > 1: 1150 | start_positions = start_positions.squeeze(-1) 1151 | if len(end_positions.size()) > 1: 1152 | end_positions = 
end_positions.squeeze(-1) 1153 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1154 | ignored_index = start_logits.size(1) 1155 | start_positions.clamp_(0, ignored_index) 1156 | end_positions.clamp_(0, ignored_index) 1157 | 1158 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1159 | start_loss = loss_fct(start_logits, start_positions) 1160 | end_loss = loss_fct(end_logits, end_positions) 1161 | total_loss = (start_loss + end_loss) / 2 1162 | return total_loss 1163 | else: 1164 | return start_logits, end_logits 1165 | --------------------------------------------------------------------------------