├── VERSION
├── omogre
│   ├── accentuator
│   │   ├── __init__.py
│   │   ├── unk_vocab.py
│   │   ├── unk_model.py
│   │   ├── unk_reader.py
│   │   ├── reader.py
│   │   ├── tokenizer.py
│   │   ├── model.py
│   │   └── bert.py
│   ├── transcriptor
│   │   ├── __init__.py
│   │   ├── transcriptor.py
│   │   └── unk_words.py
│   ├── downloader.py
│   └── __init__.py
├── download_data.py
├── test.py
├── setup.py
├── ruslan_markup.py
├── .gitignore
├── README_eng.md
├── Wav2vec2_ru_ipa.ipynb
├── README.md
└── LICENSE
/VERSION:
--------------------------------------------------------------------------------
1 | 0.1.0
2 |
--------------------------------------------------------------------------------
/omogre/accentuator/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from .model import Accentuator
4 |
--------------------------------------------------------------------------------
/omogre/transcriptor/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from .transcriptor import Transcriptor
4 |
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from omogre import find_model
4 | from argparse import ArgumentParser
5 |
6 | parser = ArgumentParser(description="Download omogre model")
7 |
 8 | parser.add_argument("--data_path", type=str, default=None, help="omogre model directory")
 9 | parser.add_argument("--file_name", type=str, default='accentuator_transcriptor_tiny', help="omogre model file name")
10 | args = parser.parse_args()
11 |
12 | if __name__ == "__main__":
13 | path = find_model(file_name=args.file_name, cache_dir=args.data_path, reload=True)
14 | print('find_model', path)
15 |
16 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from omogre import Accentuator, Transcriptor
2 |
 3 | # data will be downloaded to the 'omogre_data' directory
4 |
5 | transcriptor = Transcriptor(data_path='omogre_data')
6 |
7 | accentuator = Accentuator(data_path='omogre_data')
8 |
9 | sentence_list = ['А на камне том туристы понаписали всякой ху-у-у-лиганщины, типа: "Здесь бил Вася", "Коля + Света + сосед с четвертого этажа + ротвейлер тети Маши + взвод лейтенанта Миши + Белая лошадь (виски такое) = любовь".']
10 |
11 | #['А на камне том туристы понаписали всякой ху-у-у-лиганщины, типа: "Здесь бил Вася", "Коля ! Света ! сосед с четвертого этажа ! ротвейлер тети Маши ! взвод лейтенанта Миши ! Белая лошадь (виски такое) = любовь".']
12 |
13 | #['стены замка']
14 |
15 | print('accentuator()', accentuator(sentence_list))
16 |
17 | print('transcriptor()', transcriptor(sentence_list))
18 |
19 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list))
20 | print('accentuator.accentuate', accentuator.accentuate(sentence_list))
21 |
22 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list))
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
2 | from setuptools import find_packages, setup
3 |
4 | setup(
5 | name="omogre",
6 | version="0.1.0",
7 | author="omogr",
8 | author_email="omogrus@ya.ru",
9 | description="Russian accentuator and IPA transcriptor",
10 | long_description=open("README.md", "r", encoding='utf-8').read(),
11 | long_description_content_type="text/markdown",
12 | keywords='Russian accentuator IPA transcriptor',
13 | license='CC BY-NC-SA 4.0',
14 | url="https://github.com/omogr/omogre",
15 | packages=find_packages(),
16 |
17 | install_requires=['torch>=0.4.1',
18 | 'numpy',
19 | 'requests',
20 | 'tqdm'],
21 |
22 | classifiers=[
23 | 'Programming Language :: Python :: 3',
24 | 'License :: Free for non-commercial use',
25 | 'Natural Language :: Russian',
26 | 'Topic :: Text Processing :: Linguistic',
27 | 'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
28 | 'Topic :: Multimedia :: Sound/Audio :: Speech',
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------
/omogre/downloader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import logging
4 | import os
5 | import tarfile
6 | import tempfile
7 | import sys
8 | import requests
9 | from tqdm import tqdm
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | def download_model(cache_dir: str, file_name: str):
14 | model_url: str = "https://huggingface.co/omogr/omogre/resolve/main/%s.gz?download=true"%file_name
15 |
16 | etag = None
17 | try:
18 | response = requests.head(model_url, allow_redirects=True)
19 | if response.status_code != 200:
20 | raise EnvironmentError('Cannot load model, response.status_code %d'%response.status_code)
21 | else:
22 | etag = response.headers.get("ETag")
23 | except EnvironmentError:
24 | etag = None
25 |
26 | if etag is None:
27 | raise EnvironmentError('Cannot load model, etag error')
28 |
29 |     with tempfile.TemporaryFile() as temp_file:
30 |         logger.info("model not found in cache, downloading to temporary file")
31 |
32 | req = requests.get(model_url, stream=True)
33 | content_length = req.headers.get('Content-Length')
34 | total = int(content_length) if content_length is not None else None
35 | progress = tqdm(unit="B", total=total)
36 | for chunk in req.iter_content(chunk_size=1024):
37 | if chunk: # filter out keep-alive new chunks
38 | progress.update(len(chunk))
39 | temp_file.write(chunk)
40 | progress.close()
41 |
42 | # we are processing the file before closing it, so flush to avoid truncation
43 | temp_file.flush()
44 | # sometimes fileobj starts at the current position, so go to the start
45 | temp_file.seek(0)
46 | etag_file_name = os.path.join(cache_dir, 'etag')
47 | try:
48 |             logger.info("model archive extractall to %s", cache_dir)
49 | with tarfile.open(fileobj=temp_file, mode='r:gz') as archive:
50 | archive.extractall(cache_dir)
51 |
52 | with open(etag_file_name, mode='w', encoding='utf-8') as fout:
53 | print(model_url, file=fout)
54 | print(etag, file=fout)
55 |
56 |         except Exception:
57 |             # remove the etag marker so a failed extraction is not mistaken for a valid cache
58 |             if os.path.isfile(etag_file_name):
59 |                 os.remove(etag_file_name)
60 |             raise
61 |     return cache_dir
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     path: str = "omogre_data"
66 |     if not os.path.exists(path):
67 |         os.mkdir(path)
68 |     download_model(cache_dir=path, file_name='accentuator_transcriptor_tiny')
69 | 
--------------------------------------------------------------------------------
/omogre/transcriptor/transcriptor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
 2 | from typing import Optional
3 | import os
4 | import pickle
5 | from .unk_words import UnkWords
6 |
7 | vowels = 'аеёиоуыэюяАЕЁИОУЫЭЮЯ'
8 |
 9 | def get_single_vowel(src_str: str) -> Optional[int]:
10 | pos = None
11 | for indx, tc in enumerate(src_str):
12 | if tc in vowels:
13 | if pos is not None:
14 | return None
15 | pos = indx
16 | return pos
17 |
18 |
19 | def normalize_accent(src_str: str) -> str:
20 | # .casefold() ?
21 | if '+' in src_str:
22 | return src_str
23 | if 'ё' in src_str:
24 | return src_str.replace('ё', '+ё')
25 |
26 | vowel_pos = get_single_vowel(src_str)
27 | if vowel_pos is None:
28 | return src_str
29 | return src_str[:vowel_pos] + '+' + src_str[vowel_pos:]
30 |
31 |
32 | def get_g2p_without_accent(grapheme_to_phoneme_vocab: dict) -> dict:
33 | grapheme_phoneme = {}
34 | grapheme_freq = {}
35 | for grapheme_with_accent, phoneme_vars in grapheme_to_phoneme_vocab.items():
36 | grapheme = grapheme_with_accent.replace('+', '')
37 | for phoneme, freq in phoneme_vars.items():
38 | if (grapheme_freq.get(grapheme, 0) < freq):
39 | grapheme_phoneme[grapheme] = phoneme
40 | grapheme_freq[grapheme] = freq
41 | return grapheme_phoneme
42 |
43 |
44 | def get_max_freq_phoneme(key: str, vocab: dict) -> Optional[str]:
45 | if key not in vocab:
46 | return None
47 | max_freq = -1
48 | max_phoneme = None
49 |
50 | for t_phoneme, freq in vocab[key].items():
51 | if freq > max_freq:
52 | max_freq = freq
53 | max_phoneme = t_phoneme
54 | return max_phoneme
55 |
56 |
57 | class Transcriptor:
58 | def __init__(self, data_path: str):
59 | transcriptor_data_path = os.path.join(data_path, 'word_vocab.pickle')
60 | with open(transcriptor_data_path, "rb") as finp:
61 | self.grapheme_to_phoneme_vocab = pickle.load(finp)
62 |
63 | self.g2p_without_accent = get_g2p_without_accent(self.grapheme_to_phoneme_vocab)
64 | self.unk_words = UnkWords(data_path=data_path)
65 |
66 | def transcribe(self, src_str: str) -> str:
67 | word_str = src_str.casefold()
68 | word_str_norm = normalize_accent(word_str)
69 | if '+' in word_str_norm:
70 | max_phoneme = get_max_freq_phoneme(word_str_norm, self.grapheme_to_phoneme_vocab)
71 | if max_phoneme is not None:
72 | return max_phoneme
73 |
74 | if word_str in self.g2p_without_accent:
75 | return self.g2p_without_accent[word_str]
76 |
77 | return self.unk_words.transcribe(word_str)
78 |
79 |
--------------------------------------------------------------------------------
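
Note: a minimal sketch of the accent convention implemented in transcriptor.py above. The `+` marker is inserted immediately before the stressed vowel; the assertions follow directly from the `normalize_accent` code.

```python
assert normalize_accent('дом') == 'д+ом'       # one vowel: stress is unambiguous
assert normalize_accent('ёж') == '+ёж'         # 'ё' marks its own stress
assert normalize_accent('замок') == 'замок'    # two vowels: left for dictionary/model lookup
assert normalize_accent('з+амок') == 'з+амок'  # already accented: returned unchanged
```
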
/ruslan_markup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from omogre import Transcriptor, find_model
4 | import time
5 | from argparse import ArgumentParser
6 |
7 | parser = ArgumentParser(description="Accentuate and transcribe natasha ruslan markup")
8 |
 9 | parser.add_argument("--data_path", type=str, default=None, help="Omogre model directory")
10 | parser.add_argument("--head", type=int, default=None, help="Process only head of file")
11 | args = parser.parse_args()
12 |
13 | transcriptor = Transcriptor(data_path=args.data_path)
14 |
15 | def save_markup(src_markup_parts: list, sentence_list: list, fout_name: str):
16 | assert len(src_markup_parts) == len(sentence_list)
17 |
18 | with open(fout_name, 'w', encoding='utf-8') as fout:
19 | for parts, out_sent in zip(src_markup_parts, sentence_list):
20 | parts[1] = out_sent
21 | print('|'.join(parts), file=fout)
22 |
23 |
24 | def process(dataset_name: str):
25 | print('dataset', dataset_name)
26 | finp_name = 'natasha_ruslan_markup/%s.txt'%dataset_name
27 |
28 | sentence_list = []
29 | src_markup_parts = []
30 | print('reading', finp_name)
31 |
32 | with open(finp_name, 'r', encoding='utf-8') as finp:
33 | for entry in finp:
34 | parts = entry.strip().split('|')
35 | assert len(parts) >= 2
36 |
37 | sentence_list.append(parts[1].replace('+', ''))
38 | src_markup_parts.append(parts)
39 | if args.head is not None:
40 | if len(src_markup_parts) >= args.head:
41 | break
42 |
43 | start = time.time()
44 | output_sentences = transcriptor.accentuate(sentence_list)
45 | dt = time.time() - start
46 | print('Accentuated', dataset_name, len(sentence_list), 'sentences, dtime %.1f s'%dt)
47 |
48 | fout_name = 'natasha_ruslan_markup/%s.accentuate'%dataset_name
49 | save_markup(src_markup_parts, output_sentences, fout_name)
50 |
51 | start = time.time()
52 | output_sentences = transcriptor.transcribe(sentence_list)
53 | dt = time.time() - start
54 | print('Transcribed', dataset_name, len(sentence_list), 'sentences, dtime %.1f s'%dt)
55 |
56 | fout_name = 'natasha_ruslan_markup/%s.transcribe'%dataset_name
57 | save_markup(src_markup_parts, output_sentences, fout_name)
58 |
59 |
60 | if __name__ == '__main__':
61 |     # Habr article (in Russian): Open Source speech synthesis SOVA
62 | # https://habr.com/ru/companies/ashmanov_net/articles/528296/
63 |
64 | # ruslan
65 | # http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar
66 | # natasha
67 | # http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar
68 |
69 |
70 | find_model(cache_dir='natasha_ruslan_markup', file_name='natasha_ruslan')
71 |
72 | for dataset_name in ['natasha', 'ruslan']:
73 | process(dataset_name)
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/omogre/accentuator/unk_vocab.py:
--------------------------------------------------------------------------------
1 |
2 | import bisect
3 | import pickle
4 | import os
5 |
6 | def common_left(str1: str, str2: str) -> int:
7 | if len(str1) < len(str2):
8 | for index, tc in enumerate(str1):
9 | if str2[index] != tc:
10 | return index
11 | return len(str1)
12 |
13 | for index, tc in enumerate(str2):
14 | if str1[index] != tc:
15 | return index
16 | return len(str2)
17 |
18 |
19 | def get_neib_pos(key_pos_list: list, key: str, pos: int, min_len: int = 3) -> list:
20 | result = [pos]
21 | len1 = 0
22 | len2 = 0
23 | len3 = common_left(key, key_pos_list[pos][0])
24 | max_len = len3
25 |
26 | tp1 = pos - 1
27 | if tp1 >= 0:
28 | len1 = common_left(key, key_pos_list[tp1][0])
29 | max_len = max(max_len, len1)
30 | result.append(tp1)
31 |
32 | tp2 = pos + 1
33 | if tp2 < len(key_pos_list):
34 | len2 = common_left(key, key_pos_list[tp2][0])
35 | max_len = max(max_len, len2)
36 | result.append(tp2)
37 |
38 | if max_len < min_len:
39 | return result
40 |
41 | if len1 >= max_len:
42 | result.append(tp1)
43 | tp = pos - 2
44 | while tp >= 0:
45 | len1 = common_left(key, key_pos_list[tp][0])
46 | if len1 < max_len:
47 | break
48 | result.append(tp)
49 | tp -= 1
50 |
51 | if len2 >= max_len:
52 | result.append(tp2)
53 | tp = pos + 2
54 | while tp < len(key_pos_list):
55 | len1 = common_left(key, key_pos_list[tp][0])
56 | if len1 < max_len:
57 | break
58 | result.append(tp)
59 | tp += 1
60 | return result
61 |
62 |
63 | class UnkVocab:
64 | def __init__(self, data_path: str, encoding: str = 'utf-8'):
65 | unk_file = os.path.join(data_path, 'unk_vocab.pickle')
66 |
67 | with open(unk_file, 'rb') as finp:
68 | self.acc_vocab, self.all_tails = pickle.load(finp)
69 |
70 | def cmp_form_norm(self, form: str, tpos: int, res: list):
71 | norm = self.acc_vocab[tpos][0]
72 | num_equ_chars: int = common_left(form, norm)
73 | key = (form[num_equ_chars:], norm[num_equ_chars:])
74 |
75 | if key in self.all_tails:
76 | res.append( (num_equ_chars, self.acc_vocab[tpos]) )
77 |
78 | def search_neibs(self, text: str) -> list:
79 | key = (text, '')
80 |
81 | all_len = len(self.acc_vocab)
82 | if all_len < 1:
83 | return []
84 | if key < self.acc_vocab[0]:
85 | return get_neib_pos(self.acc_vocab, text, 0)
86 |
87 | try_pos = bisect.bisect_left(self.acc_vocab, key)
88 |
89 | return get_neib_pos(self.acc_vocab, text, try_pos)
90 |
91 | def get_neibs(self, text: str) -> list:
92 | res = []
93 | for tpos in self.search_neibs(text):
94 | self.cmp_form_norm(text, tpos, res)
95 | return res
96 |
97 | def get_acc_pos(self, text: str) -> int:
98 | result = self.get_neibs(text)
99 | if len(result) < 1:
100 | return -1
101 | result.sort()
102 | cl, key_acc = result[-1]
103 | acc_pos_list = [int(tp) for tp in key_acc[1].split(',')]
104 | if len(acc_pos_list) < 1:
105 | return -1
106 | acc_pos = acc_pos_list[-1]
107 | if acc_pos < 0:
108 | return -1
109 | if acc_pos >= cl:
110 | return -1
111 | return acc_pos
112 |
113 |
--------------------------------------------------------------------------------
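
Note: `common_left` in unk_vocab.py above is a longest-common-prefix helper that drives the nearest-neighbour search over the sorted accent vocabulary. The values below follow directly from the function:

```python
assert common_left('молоко', 'молоток') == 4  # shared prefix 'моло'
assert common_left('кот', 'коты') == 3        # the shorter string is a prefix of the longer one
assert common_left('а', 'б') == 0             # no common prefix
```
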
/omogre/accentuator/unk_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import collections
8 | import math
9 | import os
10 | import random
11 | import sys
12 | import numpy as np
13 | import torch
14 |
15 | from .bert import BertForTokenClassification
16 |
17 | from .unk_reader import InfBatchFromSentenceList
18 | from .unk_vocab import UnkVocab
19 |
20 |
21 | def norm_word(tword: str) -> str:
22 | return tword.casefold().replace('ё', 'е').replace(' ', '!')
23 |
24 |
25 | def check_ee(tword: str) -> int:
26 | acc_pos = tword.find('ё')
27 | if acc_pos >= 0:
28 | return acc_pos
29 |
30 | acc_pos = tword.find('Ё')
31 | if acc_pos >= 0:
32 | return acc_pos
33 | return -1
34 |
35 |
36 | class UnkModel:
37 | def __init__(self, data_path: str, device_name: str = None):
38 | model_data_path = os.path.join(data_path, 'unk_model')
39 |
40 | if device_name is None:
41 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | else:
43 | self.device = torch.device(device_name)
44 |
45 | self.model = BertForTokenClassification.from_pretrained(model_data_path, num_labels=1, cache_dir=None)
46 | assert self.model
47 | self.unk_vocab = UnkVocab(data_path)
48 | self.model.eval()
49 | self.model.to(self.device)
50 | self.error_counter = 0
51 |
52 | #def seed(self, seed_value: int):
53 | # random.seed(seed_value)
54 | # np.random.seed(seed_value)
55 | # torch.manual_seed(seed_value)
56 |
57 | def process_model_batch(self, input_ids, attention_mask, batch_text, all_res):
58 | with torch.no_grad():
59 | input_ids = input_ids.to(self.device)
60 | attention_mask = attention_mask.to(self.device)
61 |
62 | logits = self.model(input_ids, attention_mask=attention_mask)
63 | logits = logits.squeeze(-1)
64 |         logits = logits.detach().cpu().tolist()
65 |         input_ids = input_ids.detach().cpu().tolist()
66 |
67 | for batch_indx, (t_logits, t_input_ids) in enumerate(zip(logits, input_ids)):
68 | word_text = batch_text[batch_indx]
69 | if len(word_text) < 2:
70 | all_res.append(word_text)
71 | continue
72 |
73 |             max_pos = 0
74 |             max_logit = t_logits[max_pos + 1]  # logits are offset by one: position 0 is the BOS token
75 | for token_indx in range(len(word_text)):
76 |
77 | if t_logits[token_indx+1] > max_logit:
78 | max_pos = token_indx
79 | max_logit = t_logits[token_indx+1]
80 | ct = (word_text, max_pos)
81 | all_res.append(ct)
82 |
83 |     def get_acc_pos(self, unk_word: str) -> int:
84 | acc_pos = self.unk_vocab.get_acc_pos(unk_word)
85 | if acc_pos >= 0:
86 | return acc_pos
87 |
88 | file_reader = InfBatchFromSentenceList([unk_word])
89 |
90 | batch = file_reader.get_next_batch()
91 | if batch is None:
92 | return -1
93 | (all_input_ids, all_attention_mask, sentence_data) = batch
94 | sum_batch = []
95 | self.process_model_batch(all_input_ids, all_attention_mask, sentence_data, sum_batch)
96 | if len(sum_batch) != 1:
97 | return -1
98 | (word_text, max_pos) = sum_batch[0]
99 | if word_text != unk_word:
100 | return -1
101 | return max_pos
102 |
103 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/omogre/accentuator/unk_reader.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import collections
4 | import random
5 |
6 | pad_token_id = 0
7 | bos_token_id = 1
8 | eos_token_id = 2
9 | alp_token_id = 3
10 | alphabet = ' абвгдеёжзийклмнопрстуфхцчшщъыьэюя' #АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ'
11 |
12 | TokenSpan = collections.namedtuple( # pylint: disable=invalid-name
13 | "TokenSpan", ["text", "label", "punct", "span"])
14 |
15 | MAX_BATCH_TOKEN_NUM = 6000 # 250 * 24
16 |
17 | def get_tokens(sentence: str):
18 | tokens = [bos_token_id]
19 |
20 | err_cnt = 0
21 | for tc in sentence:
22 | if tc == '+':
23 | continue
24 | else:
25 | tid = alphabet.find(tc)
26 | if tid < 0:
27 | err_cnt += 1
28 | tid = pad_token_id
29 | else:
30 | tid += alp_token_id
31 | tokens.append(tid)
32 | tokens.append(eos_token_id)
33 | return tokens, err_cnt
34 |
35 |
36 | class InfBatchFromSentenceList:
37 | def __init__(self, sentence_list: list):
38 | self.all_entries = []
39 |
40 | for line_indx, sentence in enumerate(sentence_list):
41 | tokens, err_cnt = get_tokens(sentence.casefold())
42 | if err_cnt == 0:
43 | ct = (tokens, sentence)
44 | self.all_entries.append(ct)
45 |
46 | self.file_pos = -1
47 | self.iter = 0
48 | assert len(self.all_entries) > 0
49 |
50 | def is_first_iter(self):
51 | if self.iter > 0:
52 | return False
53 | if (self.file_pos + 1) >= len(self.all_entries):
54 | return False
55 | return True
56 |
57 | def get_next_pos(self):
58 | self.file_pos += 1
59 | if self.file_pos >= len(self.all_entries):
60 | self.iter += 1
61 | self.file_pos = 0
62 |
63 | def get_next_batch(self, is_test: bool = True):
64 | if is_test:
65 | assert self.is_first_iter()
66 |
67 | max_length = 1
68 | batch_data = []
69 | sentence_data = []
70 | while True:
71 | self.get_next_pos()
72 | if is_test:
73 | if self.iter > 0:
74 | break
75 | sentence_pos = self.file_pos
76 | input_ids, sentence = self.all_entries[sentence_pos]
77 | len_input_ids = len(input_ids)
78 | max_length = max(max_length, len_input_ids)
79 |
80 | token_cnt = max_length * (1 + len(sentence_data))
81 | if token_cnt > MAX_BATCH_TOKEN_NUM:
82 | break
83 |             batch_data.append(input_ids)
84 | sentence_data.append(sentence)
85 |
86 | if len(batch_data) < 1:
87 | return None
88 |
89 | all_input_ids = []
90 | all_attention_mask = []
91 |
92 | for input_ids in batch_data:
93 | len_input_ids = len(input_ids)
94 | attention_mask = [1] * len_input_ids
95 | assert len(input_ids) <= max_length
96 | if len_input_ids < max_length:
97 | # Pad input_ids and attention_mask to max length
98 | padding_length = max_length - len_input_ids
99 | input_ids += [pad_token_id] * padding_length
100 | attention_mask += [0] * padding_length
101 |
102 | all_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
103 | all_attention_mask.append(torch.tensor(attention_mask, dtype=torch.long))
104 |
105 | all_input_ids = torch.stack(all_input_ids)
106 | all_attention_mask = torch.stack(all_attention_mask)
107 | batch = (all_input_ids, all_attention_mask, sentence_data)
108 | return batch
109 |
110 |
111 | if __name__ == '__main__':
112 | fr = InfBatchFromSentenceList(['квазисублимирующие'])
113 |
114 | x = fr.get_next_batch()
115 | print('get_next_batch', x)
116 |
--------------------------------------------------------------------------------
/omogre/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import os
4 | import sys
5 | from .transcriptor import Transcriptor as TranscriptorImpl
6 | from .accentuator import Accentuator as AccentuatorImpl
7 |
8 | INITIAL_MODEL = 'accentuator_transcriptor_tiny'
9 |
10 | def punctuation_filter(input_string: str, punct: str = None):
11 |     output_string = []
12 |     is_space = False
13 |     for tc in input_string:
14 |         if punct is None or tc in punct:
15 |             output_string.append(tc)
16 |             is_space = False  # a kept character ends the current filtered run
17 |         elif not is_space:
18 |             output_string.append(' ')
19 |             is_space = True
20 |     return ''.join(output_string)
21 |
22 |
23 | def find_model(file_name: str = INITIAL_MODEL, cache_dir: str = None, download: bool = True, reload: bool = False):
24 | from pathlib import Path
25 |
26 | if cache_dir is None:
27 | try:
28 | omogr_cache = Path(os.getenv('OMOGR_CACHE', Path.home() / '.omogr_data'))
29 | except (AttributeError, ImportError):
30 | omogr_cache = os.getenv('OMOGR_CACHE', os.path.join(os.path.expanduser("~"), '.omogr_data'))
31 |
32 | cache_dir = omogr_cache
33 |
34 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
35 | cache_dir = str(cache_dir)
36 |
37 | if not cache_dir:
38 | raise EnvironmentError('Cannot find OMOGR_CACHE path')
39 |
40 | if os.path.exists(cache_dir):
41 | if os.path.isdir(cache_dir):
42 | etag_file_name = os.path.join(cache_dir, 'etag')
43 | if os.path.isfile(etag_file_name):
44 | if not reload:
45 | return cache_dir
46 |
47 | if not download:
48 | raise EnvironmentError('Cannot find model data')
49 |
50 | print('data_path', cache_dir, file=sys.stderr)
51 |
52 | if not os.path.exists(cache_dir):
53 | os.makedirs(cache_dir)
54 |
55 | if not os.path.isdir(cache_dir):
56 | raise EnvironmentError('Cannot create directory %s'%cache_dir)
57 |
58 | from .downloader import download_model
59 | return download_model(cache_dir, file_name=file_name)
60 |
61 |
62 | class Transcriptor:
63 |     def __init__(self, data_path: str = None, download: bool = True, device_name: str = None, punct: str = '.,!?'):
64 | loaded_data_path = find_model(file_name=INITIAL_MODEL,
65 | cache_dir=data_path, download=download)
66 |
67 | self.punct = punct
68 | transcriptor_data_path = os.path.join(loaded_data_path, 'transcriptor/')
69 | self.transcriptor = TranscriptorImpl(data_path=transcriptor_data_path)
70 |
71 | accentuator_data_path = os.path.join(loaded_data_path, 'accentuator/')
72 | self.accentuator = AccentuatorImpl(data_path=accentuator_data_path, device_name=device_name)
73 |
74 | def accentuate(self, text):
75 | return self.accentuator.accentuate(text)
76 |
77 | def transcribe(self, sentence_list: list) -> list:
78 | sentence_word_list = self.accentuator.accentuate_by_words(sentence_list)
79 | transcribed_sentence_list = []
80 | for t_sentence in sentence_word_list:
81 | transcribed_sentence = []
82 | for t_punct, t_word in t_sentence:
83 | if t_punct:
84 | transcribed_sentence.append(punctuation_filter(t_punct, punct=self.punct))
85 | if t_word:
86 | transcribed_sentence.append(self.transcriptor.transcribe(t_word))
87 | transcribed_sentence_list.append(''.join(transcribed_sentence))
88 | return transcribed_sentence_list
89 |
90 | def __call__(self, sentence_list: list) -> list:
91 | return self.transcribe(sentence_list)
92 |
93 |
94 | class Accentuator:
95 |     def __init__(self, data_path: str = None, download: bool = True, device_name: str = None):
96 | loaded_data_path = find_model(file_name=INITIAL_MODEL,
97 | cache_dir=data_path, download=download)
98 |
99 | accentuator_data_path = os.path.join(loaded_data_path, 'accentuator/')
100 | self.accentuator = AccentuatorImpl(data_path=accentuator_data_path, device_name=device_name)
101 |
102 | def accentuate(self, text):
103 | return self.accentuator.accentuate(text)
104 |
105 | def __call__(self, text):
106 | return self.accentuate(text)
--------------------------------------------------------------------------------
/README_eng.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Omogre
5 |
6 | ## Russian Accentuator and IPA Transcriptor
7 |
8 | A library for [`Python 3`](https://www.python.org/). Automatic stress placement and [IPA transcription](https://en.wikipedia.org/wiki/International_Phonetic_Alphabet) for the Russian language.
9 |
10 | ## Dependencies
11 |
12 | Installing the library will also install [`Pytorch`](https://pytorch.org/) and [`Numpy`](https://numpy.org/). Additionally, for model downloading, [`tqdm`](https://tqdm.github.io/) and [`requests`](https://pypi.org/project/requests/) will be installed.
13 |
14 | ## Installation
15 |
16 | ### Using GIT
17 |
18 | ```bash
19 | pip install git+https://github.com/omogr/omogre.git
20 | ```
21 |
22 | ### Using pip
23 |
24 | Download the code from [GitHub](https://github.com/omogr/omogre). In the directory containing [`setup.py`](https://github.com/omogr/omogre/blob/main/setup.py), run:
25 |
26 | ```bash
27 | pip install -e .
28 | ```
29 |
30 | ### Manually
31 |
32 | Download the code from [GitHub](https://github.com/omogr/omogre). Install [`Pytorch`](https://pytorch.org/), [`Numpy`](https://numpy.org/), [`tqdm`](https://tqdm.github.io/), and [`requests`](https://pypi.org/project/requests/). Run [`test.py`](https://github.com/omogr/omogre/blob/main/test.py).
33 |
34 | ## Model downloading
35 |
36 | By default, data for models will be downloaded on the first run of the library. The script [`download_data.py`](https://github.com/omogr/omogre/blob/main/download_data.py) can also be used to download this data.
37 |
38 | You can specify a path where the model data should be stored. If data already exists in this directory, it won't be downloaded again.
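
For example, to fetch the default model into a local directory, using the flags defined in [`download_data.py`](https://github.com/omogr/omogre/blob/main/download_data.py):

```bash
python download_data.py --data_path omogre_data
```
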
39 |
40 | ## Example
41 |
42 | Script [`test.py`](https://github.com/omogr/omogre/blob/main/test.py).
43 |
44 | ```python
45 | from omogre import Accentuator, Transcriptor
46 |
47 | # Data will be downloaded to the 'omogre_data' directory
48 | transcriptor = Transcriptor(data_path='omogre_data')
49 | accentuator = Accentuator(data_path='omogre_data')
50 |
51 | sentence_list = ['стены замка']
52 |
53 | print('transcriptor', transcriptor(sentence_list))
54 | print('accentuator', accentuator(sentence_list))
55 |
56 | # Alternative call methods, differing only in notation
57 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list))
58 | print('accentuator.accentuate', accentuator.accentuate(sentence_list))
59 |
60 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list))
61 | ```
62 |
63 | ## Class Parameters
64 |
65 | ### Transcriptor
66 |
67 | All initialization parameters for the class are optional.
68 |
69 | ```python
70 | class Transcriptor(data_path: str = None,
71 | download: bool = True,
72 | device_name: str = None,
73 | punct: str = '.,!?')
74 | ```
75 |
76 | - `data_path`: Directory where the model should be located.
77 | - `device_name`: Parameter defining GPU usage. Corresponds to the initialization parameter of [torch.device](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Valid values include `"cpu"`, `"cuda"`, `"cuda:0"`, etc. Defaults to `"cuda"` if GPU is available, otherwise `"cpu"`.
78 | - `punct`: List of non-letter characters to be carried over from the source text to the transcription. Default is `'.,!?'`.
79 | - `download`: Whether to download the model from the internet if not found in `data_path`. Default is `True`.
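
For instance, an illustrative combination of these parameters that forces CPU inference and keeps only periods and commas in the output:

```python
from omogre import Transcriptor

transcriptor = Transcriptor(data_path='omogre_data',
                            device_name='cpu',
                            punct='.,')
```
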
80 |
81 | Class methods:
82 |
83 | ```python
84 | accentuate(sentence_list: list) -> list
85 | transcribe(sentence_list: list) -> list
86 | ```
87 |
88 | `accentuate` places stresses, `transcribe` produces transcriptions. Both methods take a list of strings and return a list of strings; the strings can be sentences or other reasonably short chunks of text.
89 |
90 | ### Accentuator
91 |
92 | The `Accentuator` class for stress placement is identical to the `Transcriptor` in terms of stress functionality, except it doesn't load transcription data, reducing initialization time and memory usage.
93 |
94 | All initialization parameters are optional, with the same meanings as for `Transcriptor`.
95 |
96 | ```python
97 | class Accentuator(data_path: str = None,
98 | download: bool = True,
99 | device_name: str = None)
100 | ```
101 |
102 | - `data_path`: Directory where the model should be located.
103 | - `device_name`: Parameter for GPU usage. See above for details.
104 | - `download`: Whether to download the model if not found. Default is `True`.
105 |
106 | Class method:
107 |
108 | ```python
109 | accentuate(sentence_list: list) -> list
110 | ```
111 |
112 | ## Usage Examples
113 |
114 | ### Markup files for acoustic corpora
115 |
116 | The script [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) places stresses and generates transcriptions for markup files of the acoustic corpora [`RUSLAN`](https://ruslan-corpus.github.io/) ([`RUSLAN with manually accentuated markup`](http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar)) and [`NATASHA`](http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar).
117 |
118 | These markup files already contain [manually placed stresses](https://habr.com/ru/companies/ashmanov_net/articles/528296/).
119 |
120 | The script [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) generates its own stress placement for the same files. The original manual markup is not used by the script and was not used during training, so it can serve as an independent reference for evaluating stress placement accuracy.
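
A typical invocation, using the flags defined in the script (`--head` limits processing to the first N lines of each markup file):

```bash
python ruslan_markup.py --data_path omogre_data --head 100
```
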
121 |
122 | ### Speech Synthesis
123 |
124 | Accentuation and transcription can be useful for speech synthesis. The [`Colab` notebook](https://github.com/omogr/omogre/blob/main/XTTS_ru_ipa.ipynb) contains an example of running the [`XTTS`](https://github.com/coqui-ai/TTS) model trained on transcription for the Russian language. The model was trained on the [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru) datasets.
125 | The model weights can be downloaded from [`Hugging Face`](https://huggingface.co/omogr/XTTS-ru-ipa).
126 |
127 | ### Extraction of transcription from audio files
128 |
129 | Accentuation and transcription can be useful for acoustic corpora analysis. The [`Colab` notebook](https://github.com/omogr/omogre/blob/main/Wav2vec2_ru_ipa.ipynb) contains an example of running the wav2vec2-lv-60-espeak-cv-ft model fine-tuned on transcriptions of [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru).
130 |
131 | ## Context Awareness and Other Features
132 |
133 | ### Stresses
134 |
135 | Stresses are placed with context taken into account. If very long strings are encountered (for the current model, more than 510 tokens), context will not be used for them; stresses in such strings are placed only where they can be determined without context.
136 |
137 | Stresses are also placed in one-syllable words, which might look unusual but simplifies subsequent transcription determination.
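
A minimal sketch of the output convention, assuming the model data is already in `omogre_data`: the stress marker `+` is placed immediately before the stressed vowel (exact outputs depend on the model):

```python
from omogre import Accentuator

accentuator = Accentuator(data_path='omogre_data')
# each word comes back with '+' before its stressed vowel,
# including one-syllable words (illustrative)
print(accentuator(['стены замка']))
```
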
138 |
139 | ### Transcription
140 |
141 | During transcription generation, extraneous characters are filtered out. The set of non-letter characters that survive filtering can be specified via a parameter; by default, four punctuation marks (`.,!?`) are kept. Transcription is determined word by word, without context. The following symbols are used for transcription:
142 |
143 | ```
144 | ʲ`ɪətrsɐnjvmapkɨʊleɫdizofʂɕbɡxːuʐæɵʉɛ
145 | ```
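
A sketch of the punctuation handling, assuming the model data is already in `omogre_data`: characters listed in `punct` are carried over, while every other non-letter character is collapsed to whitespace.

```python
from omogre import Transcriptor

# keep only '.' and ',' in the transcription output
transcriptor = Transcriptor(data_path='omogre_data', punct='.,')
print(transcriptor(['Привет, мир!']))  # ',' survives; '!' becomes whitespace
```
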
146 |
147 | ## Feedback
148 | Email for questions, comments and suggestions - `omogrus@ya.ru`.
149 |
150 | ## License
151 | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
152 |
153 | (translated by grok-2-2024-08-13)
154 |
--------------------------------------------------------------------------------
/Wav2vec2_ru_ipa.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "source": [
20 | "# Wav2vec2 example\n",
21 |         "Accentuation and transcription can be useful for acoustic corpora analysis. This notebook contains an example of running the wav2vec2-lv-60-espeak-cv-ft model fine-tuned on [`RUSLAN`](https://ruslan-corpus.github.io/) and [`Common Voice`](https://commonvoice.mozilla.org/ru)\n"
22 | ],
23 | "metadata": {
24 | "id": "xEmXMe9SEjK2"
25 | }
26 | },
27 | {
28 | "cell_type": "code",
29 | "source": [
30 | "# @title Download model from Hugging Face\n",
31 | "!mkdir model\n",
32 | "!git clone https://huggingface.co/omogr/wav2vec2-lv-60-ru-ipa model"
33 | ],
34 | "metadata": {
35 | "colab": {
36 | "base_uri": "https://localhost:8080/"
37 | },
38 | "id": "g5ZEeq1LMH2E",
39 | "outputId": "ff112dc9-6292-4d44-b0c1-1227f760da7b"
40 | },
41 | "execution_count": 2,
42 | "outputs": [
43 | {
44 | "output_type": "stream",
45 | "name": "stdout",
46 | "text": [
47 | "Cloning into 'model'...\n",
48 | "remote: Enumerating objects: 20, done.\u001b[K\n",
49 | "remote: Counting objects: 100% (16/16), done.\u001b[K\n",
50 | "remote: Compressing objects: 100% (16/16), done.\u001b[K\n",
51 | "remote: Total 20 (delta 2), reused 0 (delta 0), pack-reused 4 (from 1)\u001b[K\n",
52 | "Unpacking objects: 100% (20/20), 304.43 KiB | 6.34 MiB/s, done.\n"
53 | ]
54 | }
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "source": [
60 | "import os\n",
61 | "import sys\n",
62 | "import numpy as np\n",
63 | "import torch\n",
64 | "import torchaudio\n",
65 | "import random\n",
66 | "\n",
67 | "from transformers import Wav2Vec2CTCTokenizer\n",
68 | "from transformers import Wav2Vec2FeatureExtractor\n",
69 | "from transformers import Wav2Vec2Processor\n",
70 | "from transformers import Wav2Vec2ForCTC\n",
71 | "\n",
72 | "MODEL_PATH = 'model'\n",
73 | "\n",
74 | "tokenizer = Wav2Vec2CTCTokenizer(\n",
75 | " \"model/vocab.json\",\n",
76 | " bos_token=\"\",\n",
77 | " eos_token=\"\",\n",
78 | " unk_token=\"\",\n",
79 | " pad_token=\"\",\n",
80 | " word_delimiter_token=\"|\",\n",
81 | " do_lower_case=False\n",
82 | ")\n",
83 | "\n",
84 | "# @title Load model and processor\n",
85 | "feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_PATH)\n",
86 | "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n",
87 | "\n",
88 | "model = Wav2Vec2ForCTC.from_pretrained(\n",
89 | " MODEL_PATH,\n",
90 | " attention_dropout=0.0,\n",
91 | " hidden_dropout=0.0,\n",
92 | " feat_proj_dropout=0.0,\n",
93 | " mask_time_prob=0.0,\n",
94 | " layerdrop=0.0,\n",
95 | " gradient_checkpointing=True,\n",
96 | " ctc_loss_reduction=\"mean\",\n",
97 | " ctc_zero_infinity=True,\n",
98 | " bos_token_id=processor.tokenizer.bos_token_id,\n",
99 | " eos_token_id=processor.tokenizer.eos_token_id,\n",
100 | " pad_token_id=processor.tokenizer.pad_token_id,\n",
101 | " vocab_size=len(processor.tokenizer.get_vocab()),\n",
102 | " )\n",
103 | "\n",
104 | "def process_wav_file(wav_file_path: str):\n",
105 | " # read soundfiles\n",
106 | " waveform, sample_rate = torchaudio.load(wav_file_path)\n",
107 | "\n",
108 | " bundle_sample_rate = 16000\n",
109 | " if sample_rate != bundle_sample_rate:\n",
110 | " waveform = torchaudio.functional.resample(waveform, sample_rate, bundle_sample_rate)\n",
111 | "\n",
112 | " # tokenize\n",
113 | " input_values = processor(waveform, sampling_rate=16000, return_tensors=\"pt\").input_values\n",
114 | " # retrieve logits\n",
115 | " with torch.no_grad():\n",
116 | " logits = model(input_values.view(1, -1)).logits\n",
117 | " # take argmax and decode\n",
118 | " predicted_ids = torch.argmax(logits, dim=-1)\n",
119 | " return processor.batch_decode(predicted_ids)"
120 | ],
121 | "metadata": {
122 | "colab": {
123 | "base_uri": "https://localhost:8080/"
124 | },
125 | "id": "l1UCVSToMGVr",
126 | "outputId": "ba94419d-8401-4511-f33b-879f1d109777"
127 | },
128 | "execution_count": 5,
129 | "outputs": [
130 | {
131 | "output_type": "stream",
132 | "name": "stderr",
133 | "text": [
134 | "Some weights of the model checkpoint at model were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.masked_spec_embed']\n",
135 | "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
136 | "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
137 | ]
138 | }
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "source": [
144 | "sample_wav_files = [\n",
145 | " 'model/sample_wav_files/common_voice_ru_38488940.wav',\n",
146 | " 'model/sample_wav_files/common_voice_ru_38488941.wav',\n",
147 | "]\n",
148 | "\n",
149 | "# @title Transcribe wav files\n",
150 | "for wav_file_path in sample_wav_files:\n",
151 | " print('File:', wav_file_path)\n",
152 | " transcription = process_wav_file(wav_file_path)\n",
153 | " print('Transcription:', transcription)\n",
154 | "\n"
155 | ],
156 | "metadata": {
157 | "colab": {
158 | "base_uri": "https://localhost:8080/"
159 | },
160 | "id": "sIFV210KM6tw",
161 | "outputId": "d8ab7223-8aa9-47d9-d17f-387eaa19880a"
162 | },
163 | "execution_count": 6,
164 | "outputs": [
165 | {
166 | "output_type": "stream",
167 | "name": "stdout",
168 | "text": [
169 | "File: model/sample_wav_files/common_voice_ru_38488940.wav\n",
170 | "Transcription: ['kak v tr`udnɨje tak i d`obrɨj vrʲɪmʲɪn`a n`aʂɨ məɫɐdʲ`ɵʂ `ɛtə ɡɫ`avnəjə bɐɡ`atstvə']\n",
171 | "File: model/sample_wav_files/common_voice_ru_38488941.wav\n",
172 | "Transcription: ['mɨ nɐdʲ`ejɪmsʲə ʂto fsʲe ɡəsʊd`arstvə pɐdʲː`erʐɨvəjɪt `ɛtət tʲekst pənʲɪm`ajɪt ʂto n`ɨnʲɪʂnʲɪjə bʲɪzʲdʲ`ejstvʲɪje nʲɪprʲɪ`jemlʲɪmə']\n"
173 | ]
174 | }
175 | ]
176 | }
177 | ]
178 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Omogre
2 |
3 | ## Russian accentuator and IPA transcriptor.
4 |
5 | [English README](https://github.com/omogr/omogre/blob/main/README_eng.md)
6 |
7 | ## Автоматическая расстановка ударений и [IPA](https://ru.wikipedia.org/wiki/%D0%9C%D0%B5%D0%B6%D0%B4%D1%83%D0%BD%D0%B0%D1%80%D0%BE%D0%B4%D0%BD%D1%8B%D0%B9_%D1%84%D0%BE%D0%BD%D0%B5%D1%82%D0%B8%D1%87%D0%B5%D1%81%D0%BA%D0%B8%D0%B9_%D0%B0%D0%BB%D1%84%D0%B0%D0%B2%D0%B8%D1%82) транскрипция для русского языка.
8 |
9 | Библиотека для [`Python 3`](https://www.python.org/).
10 |
11 | ## Зависимости
12 |
13 | Установка библиотеки повлечет за собой установку [`Pytorch`](https://pytorch.org/) и [`Numpy`](https://numpy.org/). Кроме того, для скачивания моделей установятся [`tqdm`](https://tqdm.github.io/) и [`requests`](https://pypi.org/project/requests/).
14 |
15 | ## Установка
16 |
17 | ### С помощью GIT
18 |
19 | ```bash
20 | pip install git+https://github.com/omogr/omogre.git
21 | ```
22 |
23 | ### При помощи pip
24 |
25 | Скачать код с [гитхаба](https://github.com/omogr/omogre). В директории, в которой находится файл [`setup.py`](https://github.com/omogr/omogre/blob/main/setup.py), выполнить
26 |
27 | ```bash
28 | pip install -e .
29 | ```
30 |
31 | ### Вручную
32 |
33 | Скачать код с [гитхаба](https://github.com/omogr/omogre). Установить [`Pytorch`](https://pytorch.org/), [`Numpy`](https://numpy.org/), [`tqdm`](https://tqdm.github.io/) и [`requests`](https://pypi.org/project/requests/). Запустить [`test.py`](https://github.com/omogr/omogre/blob/main/test.py).
34 |
35 | ## Загрузка моделей
36 |
37 | По умолчанию при первом запуске библиотеки скачиваются данные для моделей. Скрипт [`download_data.py`](https://github.com/omogr/omogre/blob/main/download_data.py) также позволяет загружать эти данные.
38 |
39 | При желании можно указывать путь, в котором должны располагаться данные для моделей. Если в этой директории уже есть данные, то их повторного скачивания не будет.
40 |
41 | ## Пример запуска
42 |
43 | Скрипт [`test.py`](https://github.com/omogr/omogre/blob/main/test.py).
44 |
45 | ```python
46 | from omogre import Accentuator, Transcriptor
47 |
48 | # данные будут скачаны в директорию 'omogre_data'
49 | transcriptor = Transcriptor(data_path='omogre_data')
50 | accentuator = Accentuator(data_path='omogre_data')
51 |
52 | sentence_list = ['стены замка']
53 |
54 | print('transcriptor', transcriptor(sentence_list))
55 | print('accentuator', accentuator(sentence_list))
56 |
57 | # другие способы вызова, отличающиеся только формой записи
58 | print('transcriptor.transcribe', transcriptor.transcribe(sentence_list))
59 | print('accentuator.accentuate', accentuator.accentuate(sentence_list))
60 |
61 | print('transcriptor.accentuate', transcriptor.accentuate(sentence_list))
62 | ```
63 |
64 | ## Параметры классов
65 |
66 | ### Transcriptor
67 |
68 | Все параметры инициализации класса не являются обязательными.
69 |
70 | ```python
71 | class Transcriptor(data_path: str = None,
72 | download: bool = True,
73 | device_name: str = None,
74 | punct: str = '.,!?')
75 | ```
76 |
77 | - `data_path` - директория, в которой должна находиться модель.
78 |
79 | - `device_name` - параметр, определяющий использование GPU. Соответствует параметру инициализации класса [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Допустимые значения - `"cpu"`, `"cuda"`, `"cuda:0"` и т.д. По умолчанию если torch видит GPU, то `"cuda"`, иначе `"cpu"`.
80 |
81 | - `punct` - список небуквенных символов, которые переносятся из исходного текста в транскрипцию. По умолчанию `'.,!?'`.
82 |
83 | - `download` - следует ли загружать модель из интернета, если она не найдена в директории `data_path`. По умолчанию `True`.
84 |
85 |
86 | Входы класса `Transcriptor`:
87 |
88 | ```python
89 | accentuate(sentence_list: list) -> list
90 | transcribe(sentence_list: list) -> list
91 | ```
92 |
93 | В случае `accentuate` выполняется расстановка ударений, в случае `transcribe` - транскрипция. Оба метода получают на вход список строк и возвращают список строк. Строками могут быть предложения или не очень большие куски текста.
94 |
95 | ### Accentuator
96 |
97 | Расстановка ударений классом Accentuator ничем не отличается от расстановки ударений классом Transcriptor. Разница только в том, что Accentuator не загружает данные для транскрипции. Это позволяет уменьшить время инициализации класса и расход оперативной памяти.
98 |
99 | Все параметры инициализации класса не являются обязательными. Смысл параметров инициализации такой же, как у класса Transcriptor.
100 |
101 | ```python
102 | class Accentuator(data_path: str = None,
103 | download: bool = True,
104 | device_name: str = None)
105 | ```
106 |
107 | - `data_path` - директория, в которой должна находиться модель.
108 |
109 | - `device_name` - параметр, определяющий использование GPU. Соответствует параметру инициализации класса [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device). Допустимые значения - `"cpu"`, `"cuda"`, `"cuda:0"` и т.д. По умолчанию если torch видит GPU, то `"cuda"`, иначе `"cpu"`.
110 |
111 | - `download` - следует ли загружать модель из интернета, если она не найдена в директории `data_path`. По умолчанию `True`.
112 |
113 | Входы класса `Accentuator`:
114 |
115 | ```python
116 | accentuate(sentence_list: list) -> list
117 | ```
118 |
119 | ## Примеры работы
120 |
121 | ### markup-файлы для акустических корпусов
122 |
123 | Скрипт [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) расставляет ударения и порождает транскрипцию для markup-файлов акустических корпусов [`RUSLAN`](https://ruslan-corpus.github.io/) ([`RUSLAN с ручной разметкой ударений`](http://dataset.sova.ai/SOVA-TTS/ruslan/ruslan_dataset.tar)) и [`NATASHA`](http://dataset.sova.ai/SOVA-TTS/natasha/natasha_dataset.tar).
124 |
125 | markup-файлы этих корпусов уже содержат расстановку ударений, которая [была сделана](https://habr.com/ru/companies/ashmanov_net/articles/528296/) вручную.
126 |
127 | Скрипт [`ruslan_markup.py`](https://github.com/omogr/omogre/blob/main/ruslan_markup.py) порождает для тех же файлов свою собственную расстановку ударений. Изначальная ручная разметка никак не используется при тестировании и не использовалась при обучении. Таким образом, её можно использовать для оценки точности расстановки ударений.
128 |
129 | ### Синтез речи
130 |
131 | Расстановка ударений и транскрипция могут быть полезны при синтезе речи. [Ноутбук](https://github.com/omogr/omogre/blob/main/XTTS_ru_ipa.ipynb) содержит пример запуска [`XTTS`](https://github.com/coqui-ai/TTS) модели, обученной на транскрипции для русского языка. Модель обучалась на корпусах [`RUSLAN`](https://ruslan-corpus.github.io/) и [`Common Voice`](https://commonvoice.mozilla.org/ru).
132 | Веса модели можно скачать с [Hugging Face](https://huggingface.co/omogr/XTTS-ru-ipa).
133 |
134 | ### Извлечение транскрипции из акустических файлов
135 |
136 | Расстановка ударений и транскрипция могут быть полезны при анализе речи. [Ноутбук](https://github.com/omogr/omogre/blob/main/Wav2vec2_ru_ipa.ipynb) содержит пример запуска модели wav2vec2-lv-60-espeak-cv-ft, дообученной на транскрипции акустических корпусов [`RUSLAN`](https://ruslan-corpus.github.io/) и [`Common Voice`](https://commonvoice.mozilla.org/ru).
137 |
138 | ## Учёт контекста и некоторые другие особенности
139 |
140 | ### Ударения
141 |
142 | Ударения расставляются с учётом контекста. Если во входном списке строк встретятся очень длинные строки (для текущей модели это больше 510 токенов), то для таких длинных строк контекст учитываться не будет. Ударения в этих строках будут ставиться только там, где это возможно без учёта контекста.
143 |
144 | В словах из одного слога ударение тоже ставится. В некоторых случаях это может выглядеть странно, но упрощает последующее определение транскрипции.
145 |
146 | ### Транскрипция
147 |
148 | При порождении транскрипции посторонние символы фильтруются. Список небуквенных символов, которые не фильтруются можно задавать отдельным параметром. По умолчанию не фильтруются четыре знака пунктуации (`.,!?`). Транскрипция определяется пословно, без учёта контекста. Для транскрипции слов используются следующие символы:
149 |
150 | ```
151 | ʲ`ɪətrsɐnjvmapkɨʊleɫdizofʂɕbɡxːuʐæɵʉɛ
152 | ```
153 |
154 | ## Обратная связь
155 |
156 | Почта для вопросов, замечаний и предложений - `omogrus@ya.ru`.
157 |
158 | ## Лицензия
159 |
160 | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ru)
161 |
--------------------------------------------------------------------------------
/omogre/accentuator/reader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import os
4 | import collections
5 | import torch
6 | from .tokenizer import BertTokenizer
7 | import pickle
8 |
9 | InfTokenSpan = collections.namedtuple("TokenSpan", ["word_tokens", "punct", "first", "last"])
10 |
11 | alphabet = '-абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ'
12 | vowels = 'аеёиоуыэюяАЕЁИОУЫЭЮЯ'
13 |
14 | MAX_SENTENCE_LEN = 510
15 |
16 | class AccentVocab:
17 | def __init__(self, data_path):
18 | vocab_file = os.path.join(data_path, 'wv_word_acc.pickle')
19 | with open(vocab_file, "rb") as finp:
20 | (self.vocab, self.vocab_index) = pickle.load(finp)
21 |
22 |
23 | class AccentTokenizer:
24 | def __init__(self, data_path):
25 | self.accent_vocab = AccentVocab(data_path=data_path)
26 | bert_vocab_path = os.path.join(data_path, 'model/vocab.txt')
27 | self.tokenizer = BertTokenizer(bert_vocab_path, do_lower_case=False)
28 |
29 | self.get_vowel_pos()
30 | self.pad_token_id = 0
31 |
32 | def get_vowel_pos(self):
33 | self.token_vowel_pos = {}
34 |
35 | for token_text, token_id in self.tokenizer.vocab.items():
36 |
37 | vowel_pos = []
38 | letter_pos = 0
39 | if token_text.startswith('##'):
40 | tt = token_text[2:]
41 | else:
42 | tt = token_text
43 | for letter_pos, tc in enumerate(tt):
44 | if tc in vowels:
45 | vowel_pos.append(letter_pos)
46 | if letter_pos > 0:
47 | assert tc != '#', (token_text, token_id)
48 | ct = (len(tt), vowel_pos)
49 | self.token_vowel_pos[token_id] = ct
50 |
51 | def encode(self, txt):
52 | tokens = self.tokenizer.tokenize(txt)
53 | return self.tokenizer.convert_tokens_to_ids(tokens)
54 |
55 | def get_num_vars(self, tword):
56 | tword_index = self.accent_vocab.vocab.get(tword)
57 | if tword_index is None:
58 | return -1
59 | return len(self.accent_vocab.vocab_index[tword_index])
60 |
61 | def tokenize_word(self, letter_list, tokens, first_pos, last_pos):
62 | tword = ''.join(letter_list) # casefold() ?
63 |
64 | num_vars = self.get_num_vars(tword)
65 | if num_vars >= 0:
66 | tokens.append(InfTokenSpan(self.encode(tword), [], first_pos, last_pos))
67 |
68 | if num_vars > 1:
69 | return True
70 | return False
71 |
72 | if '-' not in tword:
73 | tokens.append(InfTokenSpan(self.encode(tword), [], first_pos, last_pos))
74 | return False
75 |
76 | parts = tword.split('-')
77 | if len(parts) < 1:
78 | return False
79 |
80 | tpos = first_pos
81 |
82 | for tp in parts:
83 | next_pos = tpos + len(tp)
84 | tokens[-1].punct.append('-')
85 | if len(tp) > 0:
86 | tokens.append(InfTokenSpan(self.encode(tp), [], tpos, next_pos))
87 | tpos = next_pos + 1
88 |
89 | for tp in parts:
90 | if self.get_num_vars(tp) > 1:
91 | return True
92 | return False
93 |
94 | def tokenize_punct(self, letters_list):
95 | return self.encode(''.join(letters_list))
96 |
97 | def get_inf_tokens(self, sentence0):
98 | sentence = sentence0.replace('+', ' ')
99 | tokens = []
100 | tokenizer_bos = 2
101 | tokenizer_sep = 3
102 |
103 | ct = InfTokenSpan([tokenizer_bos], [], 0, 0)
104 | tokens.append(ct)
105 | tword = []
106 |
107 | first_pos = -1
108 | is_easy = True
109 | for char_pos, cur_char0 in enumerate(sentence):
110 | if cur_char0 in alphabet:
111 | cur_char = cur_char0.casefold()
112 | if first_pos < 0:
113 | first_pos = char_pos
114 | tword.append(cur_char)
115 | continue
116 |
117 | if tword:
118 | if self.tokenize_word(tword, tokens, first_pos, char_pos):
119 | is_easy = False
120 |
121 | tword = []
122 | first_pos = -1
123 | tokens[-1].punct.append(cur_char0)
124 |
125 | if tword:
126 | if self.tokenize_word(tword, tokens, first_pos, len(sentence)):
127 | is_easy = False
128 | tword = []
129 |
130 | ct = InfTokenSpan([tokenizer_sep], [], 0, 0)
131 | tokens.append(ct)
132 |
133 | all_ids = []
134 | all_spans = []
135 | for ws in tokens:
136 | first_token = len(all_ids)
137 | all_ids.extend(ws.word_tokens)
138 |
139 | if ws.last > 0:
140 | last_token = len(all_ids)
141 | text_span = (ws.first, ws.last)
142 | token_span = (first_token, last_token)
143 | ct = (text_span, token_span)
144 | all_spans.append(ct)
145 |
146 | tpunct_str = ''.join(ws.punct).replace(' ', '')
147 | if len(tpunct_str) > 0:
148 | punct_tokens = self.encode(tpunct_str)
149 | all_ids.extend(punct_tokens)
150 |
151 | if len(all_ids) > MAX_SENTENCE_LEN:
152 | is_easy = True
153 |
154 | return all_ids, all_spans, is_easy
155 |
156 |
157 | class AccentDocument:
158 | def __init__(self, acc_tokenizer, all_sentences, first_pos=0, max_batch_token_num=2048):
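   |         # Split the input into "easy" sentences (resolved without BERT)
   |         # and padded model batches of at most max_batch_token_num tokens.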
159 | self.all_sentences = all_sentences
160 |
161 | self.max_batch_token_num = max_batch_token_num
162 | self.easy_sentences = []
163 | self.model_batches = []
164 | self.too_long_sentence_cnt = 0
165 | self.pad_token_id = acc_tokenizer.pad_token_id
166 |
167 | self.get_batches(acc_tokenizer, first_pos)
168 |
169 | def num_entries(self):
170 |         return len(self.all_sentences)
171 |
172 | def add_model_batch(self, bert_sentences, max_length):
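   |         # Pad every sentence in the group to the batch length and stack
   |         # the ids and attention masks into tensors.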
173 | all_input_ids = []
174 | sentence_spans = []
175 | all_attention_mask = []
176 |
177 | batch_length = max_length
178 |
179 | for t_doc_pos, t_input_ids, t_spans in bert_sentences:
180 | len_input_ids = len(t_input_ids)
181 | attention_mask = [1] * len_input_ids
182 | ct = (t_doc_pos, t_spans)
183 | sentence_spans.append(ct)
184 |             # Truncation should never be needed here: get_batches()
185 |             # guarantees that every sentence placed in a batch fits
186 |             # within the batch length.
187 |             assert len_input_ids <= batch_length
188 |
189 |             if len_input_ids < batch_length:
191 | # Pad t_input_ids and attention_mask to max length
192 | padding_length = batch_length - len_input_ids
193 | t_input_ids += [self.pad_token_id] * padding_length
194 | attention_mask += [0] * padding_length
195 |
196 | all_input_ids.append(torch.tensor(t_input_ids, dtype=torch.long))
197 | all_attention_mask.append(torch.tensor(attention_mask, dtype=torch.long))
198 |
199 | all_input_ids = torch.stack(all_input_ids)
200 | all_attention_mask = torch.stack(all_attention_mask)
201 | batch = (sentence_spans, all_input_ids, all_attention_mask)
202 |
203 | self.model_batches.append(batch)
204 |
205 | def get_batches(self, acc_tokenizer, first_pos):
206 | sorted_bert_sentences = []
207 | doc_pos = first_pos
208 |
209 | while doc_pos < len(self.all_sentences):
210 |
211 | cur_sentence = self.all_sentences[doc_pos]
212 | all_ids, all_spans, is_easy = acc_tokenizer.get_inf_tokens(cur_sentence)
213 | if is_easy:
214 | ct = (doc_pos, all_spans)
215 | self.easy_sentences.append(ct)
216 | doc_pos += 1
217 | continue
218 |
219 | len_input_ids = len(all_ids)
220 |
221 | if (1 + len_input_ids) >= self.max_batch_token_num:
222 |                 ct = (doc_pos, all_spans)
223 |                 self.easy_sentences.append(ct)
224 | doc_pos += 1
225 | self.too_long_sentence_cnt += 1
226 | continue
227 |
228 | ct = (len_input_ids, doc_pos, all_ids, all_spans)
229 | sorted_bert_sentences.append(ct)
230 | doc_pos += 1
231 |
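   |         # Sort by token count so that each batch packs sentences of
   |         # similar length, keeping padding overhead low.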
232 | sorted_bert_sentences.sort()
233 |
234 | max_length = 1
235 | bert_sentences = []
236 |
237 | for len_input_ids, doc_pos, all_ids, all_spans in sorted_bert_sentences:
238 | new_max_length = max(max_length, len_input_ids)
239 | if new_max_length * (1 + len(bert_sentences)) >= self.max_batch_token_num:
240 | assert len(bert_sentences) > 0
241 | self.add_model_batch(bert_sentences, max_length)
242 | max_length = len_input_ids
243 | bert_sentences = []
244 | else:
245 | max_length = new_max_length
246 |
247 | ct = (doc_pos, all_ids, all_spans)
248 | bert_sentences.append(ct)
249 |
250 | if len(bert_sentences) > 0:
251 | self.add_model_batch(bert_sentences, max_length)
252 |
253 | def get_all_easy(self, acc_tokenizer, first_pos=0):
254 | doc_pos = first_pos
255 | self.easy_sentences = []
256 | while doc_pos < len(self.all_sentences):
257 | all_ids, all_spans, is_easy = acc_tokenizer.get_inf_tokens(self.all_sentences[doc_pos])
258 |
259 | ct = (doc_pos, all_spans)
260 | self.easy_sentences.append(ct)
261 | doc_pos += 1
262 |
263 |
264 | if __name__ == '__main__':
265 | pass
266 |
--------------------------------------------------------------------------------
/omogre/transcriptor/unk_words.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import os
4 | import sys
5 | import numpy as np
6 | import pickle
7 | import json
  | import copy
8 |
9 | INPUT_TEXT_ALPHABET = ' <>-`абвгдеёжзийклмнопрстуфхцчшщъыьэюя?'
10 | BOS_SYMBOL = ''
11 | DELETE_SYMBOL = ''
12 | UNK_PATH_EVAL = -100.0
13 | UNK_PATH_THRESHOLD = -99.0
14 |
15 | AUXILIARY_SYMBOL_REPLACEMENTS = [
16 | (DELETE_SYMBOL, ""),
17 | ("+", ""),
18 | ("~", ""),
19 | ("ʑ", "ɕ:"),
20 | ("ɣ", "x"),
21 | (":", "ː"),
22 | ("'", "`"),
23 | ("_", "")
24 | ]
25 |
26 | class UnkWordsConfig(object):
27 | """Configuration class to store the configuration of a `UnkWords`.
28 | """
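   |     # Parameter notes (as used in UnkWords below):
   |     #   window_size            - beam width kept after each Viterbi step
   |     #   unk_bigram_eval        - fallback score for unseen phoneme bigrams
   |     #   kind_of_stupid_backoff - negative weight applied when backing off
   |     #                            from trigram to bigram statistics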
29 | def __init__(self,
30 | window_size=7,
31 | unk_bigram_eval=20,
32 | kind_of_stupid_backoff=-1.31,
33 | gram_data_name='gram_model.pickle'):
34 |
35 | self.window_size = window_size
36 | self.unk_bigram_eval = unk_bigram_eval
37 | self.kind_of_stupid_backoff = kind_of_stupid_backoff
38 | self.gram_data_name = gram_data_name
39 |
40 | @classmethod
41 | def from_dict(cls, json_object):
42 |         """Constructs a `UnkWordsConfig` from a Python dictionary of parameters."""
43 |         config = cls()
44 | for key, value in json_object.items():
45 | config.__dict__[key] = value
46 | return config
47 |
48 | @classmethod
49 | def from_json_file(cls, json_file):
50 |         """Constructs a `UnkWordsConfig` from a json file of parameters."""
51 | with open(json_file, "r", encoding='utf-8') as reader:
52 | text = reader.read()
53 | return cls.from_dict(json.loads(text))
54 |
55 | def __repr__(self):
56 | return str(self.to_json_string())
57 |
58 | def to_dict(self):
59 | """Serializes this instance to a Python dictionary."""
60 | output = copy.deepcopy(self.__dict__)
61 | return output
62 |
63 | def to_json_string(self):
64 | """Serializes this instance to a JSON string."""
65 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
66 |
67 | def to_json_file(self, json_file_path):
68 | """ Save this instance to a json file."""
69 | with open(json_file_path, "w", encoding='utf-8') as writer:
70 | writer.write(self.to_json_string())
71 |
72 |
73 | def invert_vocab(vocab: dict) -> dict:
74 | """Inverts a dictionary, swapping keys and values.
75 | Args:
76 | vocab (dict): The dictionary to invert.
77 |
78 | Returns:
79 | dict: The inverted dictionary.
80 | """
81 |
82 | inv = {}
83 | for key, value in vocab.items():
84 | inv[value] = key
85 | return inv
86 |
87 |
88 | def replace_auxiliary_symbols(text: str) -> str:
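   |     # Map auxiliary internal symbols to the IPA characters used in the
   |     # final transcription (e.g. ":" -> "ː") and drop service marks.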
89 | result = text
90 | for src, dst in AUXILIARY_SYMBOL_REPLACEMENTS:
91 | if result.find(src) >= 0:
92 | result = result.replace(src, dst)
93 | return result
94 |
95 |
96 | def get_input_token_vocab() -> dict:
97 | # vocab for input text tokenizer
98 | vocab = {}
99 | for indx, tc in enumerate(INPUT_TEXT_ALPHABET):
100 | vocab[tc] = indx
101 | return vocab
102 |
103 |
104 | def token_lower_bound(raw_search_key: list, id_list: list) -> int:
105 |     # Binary search over the sorted (key, transcription) list: returns the
    |     # position of the last entry whose key is <= search_key (0 if none).
106 | search_key = tuple(raw_search_key)
107 | first_pos = 0
108 | last_pos = len(id_list)
109 | while True:
110 | half_range = (last_pos - first_pos) // 2
111 | if half_range < 1:
112 | break
113 |
114 | mid_pos = first_pos + half_range
115 | if id_list[mid_pos][0] > search_key:
116 | last_pos = mid_pos
117 | continue
118 | first_pos = mid_pos
119 | if id_list[mid_pos][0] == search_key:
120 | break
121 | return first_pos
122 |
123 |
124 | class UnkWords:
125 | def __init__(self, data_path: str, config=None):
126 | if config is None:
127 | self.config = UnkWordsConfig()
128 | else:
129 | self.config = config
130 | transcriptor_data_path = os.path.join(data_path, self.config.gram_data_name)
131 | with open(transcriptor_data_path, 'rb') as finp:
132 | (
133 | self.dst_alphabet,
134 | self.gram_prob,
135 | self.bigram_eval,
136 | self.char_phrase_table,
137 | self.gram_phrase_table,
138 | self.head_transcriptions,
139 | self.tail_transcriptions,
140 | ) = pickle.load(finp)
141 |
142 | self.inv_dst_alphabet = invert_vocab(self.dst_alphabet)
143 |
144 | self.input_token_vocab = get_input_token_vocab()
145 | self.src_stress_sign = "`"
146 |
147 | self.src_stress_indx = self.input_token_vocab.get(self.src_stress_sign)
148 | self.delete_indx = self.dst_alphabet.get(DELETE_SYMBOL, -1)
149 | self.dst_stress_indx = self.dst_alphabet.get("'", -1)
150 |
151 | assert self.delete_indx >= 0
152 | assert self.dst_stress_indx >= 0
153 | assert self.src_stress_indx == 4
154 |
155 | def tokenize(self, input_text: str) -> list:
156 | bos_input_letters_eos = ['<', '<'] + [tc for tc in input_text] + ['>', '>']
157 | return [self.input_token_vocab.get(tc, 0) for tc in bos_input_letters_eos]
158 |
159 | def viterbi_step(self, best_path: list, emission_logprobs: np.ndarray) -> list:
160 | """Performs a single step in the Viterbi search algorithm.
161 |
162 | Args:
163 | best_path (list): The best paths from the previous step.
164 | emission_logprobs (np.ndarray): Character emission probabilities.
165 |
166 | Returns:
167 | list: The top `config.window_size` best paths for the current step.
168 | """
169 |
170 | new_best_paths: dict = {}
171 | for char_indx, t_char_eval in enumerate(emission_logprobs):
172 | if t_char_eval < UNK_PATH_THRESHOLD:
173 | continue
174 | next_dst_char = self.inv_dst_alphabet[char_indx]
175 |
176 | for path_indx, (prev_path_eval, _, phoneme1, phoneme2) in enumerate(best_path):
177 | key = (phoneme1, phoneme2, next_dst_char)
178 | if key in self.gram_prob:
179 | current_gram_eval = self.gram_prob[key]
180 | else:
181 | current_gram_eval = self.config.kind_of_stupid_backoff * self.bigram_eval.get(
182 | (phoneme1, phoneme2), self.config.unk_bigram_eval)
183 |
184 | new_path_eval = prev_path_eval + current_gram_eval + t_char_eval
185 | new_key = (phoneme2, next_dst_char)
186 | if new_key not in new_best_paths:
187 | new_best_paths[new_key] = (new_path_eval, path_indx)
188 | else:
189 | best_so_far, _ = new_best_paths[new_key]
190 | if new_path_eval > best_so_far:
191 | new_best_paths[new_key] = (new_path_eval, path_indx)
192 |
193 | best_list = sorted( [
194 | (path_eval, path_indx, phoneme1, phoneme2)
195 | for (phoneme1, phoneme2), (path_eval, path_indx) in new_best_paths.items()
196 | ], reverse=True)
197 | return best_list[:self.config.window_size]
198 |
199 | def transcribe(self, input_word_text: str) -> str:
200 | """
201 | Receives a word as input, returns its transcription
202 | """
203 | if not input_word_text:
204 | return ""
205 |
206 |         # In the input word, stress may be marked with a plus sign;
207 |         # replace it with the stress symbol used in the n-gram data.
208 |
209 | word_text = input_word_text.casefold().replace('+', self.src_stress_sign)
210 | input_ids = self.tokenize(word_text)
211 | dst_alphabet_len = len(self.dst_alphabet)
212 | word_len = len(input_ids)
213 |
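   |         # Find the longest known prefix ("head") and suffix ("tail")
   |         # transcription patterns for this word. Characters they cover get
   |         # fixed emissions in the Viterbi lattice; the rest are scored from
   |         # the n-gram and single-character phrase tables.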
214 | head_lower_bound = token_lower_bound(input_ids, self.head_transcriptions)
215 | head_len = -1
216 | for indx, tid in enumerate(self.head_transcriptions[head_lower_bound][0]):
217 | if indx >= len(input_ids):
218 | break
219 | if tid != input_ids[indx]:
220 | break
221 |
222 | pattern_indx = self.head_transcriptions[head_lower_bound][1][indx]
223 | if pattern_indx in [self.delete_indx, self.dst_stress_indx]:
224 | continue
225 | head_len = indx
226 |
227 | reversed_input_ids = list(reversed(input_ids))
228 | tail_lower_bound = token_lower_bound(reversed_input_ids, self.tail_transcriptions)
229 | tail_len = -1
230 | for indx, tid in enumerate(self.tail_transcriptions[tail_lower_bound][0]):
231 | if indx >= len(input_ids):
232 | break
233 | if tid != reversed_input_ids[indx]:
234 | break
235 | pattern_indx = self.tail_transcriptions[tail_lower_bound][1][indx]
236 | if pattern_indx in [self.delete_indx, self.dst_stress_indx]:
237 | continue
238 | tail_len = indx
239 |
240 | best_path = {}
241 | best_path[1] = [ (0.0, 0, BOS_SYMBOL, BOS_SYMBOL) ]
242 | indx = 2
243 | while indx < word_len:
244 | emission_logprobs = np.ones(dst_alphabet_len, dtype=np.float32) * UNK_PATH_EVAL
245 | indx1 = word_len - indx - 1
246 | is_empty = True
247 | if indx < head_len:
248 | vocab_char = self.head_transcriptions[head_lower_bound][1][indx]
249 | emission_logprobs[vocab_char] = 0.0
250 | is_empty = False
251 | if indx1 < tail_len:
252 | vocab_char = self.tail_transcriptions[tail_lower_bound][1][indx1]
253 | emission_logprobs[vocab_char] = 0.0
254 | is_empty = False
255 | if is_empty and (indx < word_len-1):
256 | i1 = indx - 1
257 | i2 = indx + 1
258 | if self.src_stress_indx == input_ids[i1]:
259 | if i1 > 0:
260 | i1 -= 1
261 | if self.src_stress_indx == input_ids[i2]:
262 | if i2 < word_len-1:
263 | i2 += 1
264 | src = tuple(input_ids[i1:i2+1])
265 | if src in self.gram_phrase_table:
266 | is_empty = False
267 | for key, value in self.gram_phrase_table[src].items():
268 | emission_logprobs[key] = value
269 | if is_empty:
270 | src = input_ids[indx]
271 | if src in self.char_phrase_table:
272 | for key, value in self.char_phrase_table[src].items():
273 | emission_logprobs[key] = value
274 |
275 | best_path[indx] = self.viterbi_step(best_path[indx-1], emission_logprobs)
276 | indx += 1
277 |
278 | indx = word_len-1
279 | prev = 0
280 | res = []
281 |
282 | while indx > 0 and len(best_path[indx]) > 0:
283 | path_eval, prev, phoneme1, phoneme2 = best_path[indx][prev]
284 | res.append(phoneme1)
285 | indx -= 1
286 |
287 | res.reverse()
288 |
289 | if len(res) > 3: # 4?
290 | # strip BOS_SYMBOL BOS_SYMBOL ... EOS_SYMBOL
291 | return replace_auxiliary_symbols(''.join(res[2:-1]))
292 | return ""
293 |
294 |
295 |
--------------------------------------------------------------------------------
/omogre/accentuator/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # edited from pytorch_pretrained_bert
4 |
5 | # https://github.com/google-research/bert
6 | # https://github.com/maknotavailable/pytorch-pretrained-BERT
7 |
8 |
9 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
10 | #
11 | # Licensed under the Apache License, Version 2.0 (the "License");
12 | # you may not use this file except in compliance with the License.
13 | # You may obtain a copy of the License at
14 | #
15 | # http://www.apache.org/licenses/LICENSE-2.0
16 | #
17 | # Unless required by applicable law or agreed to in writing, software
18 | # distributed under the License is distributed on an "AS IS" BASIS,
19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | # See the License for the specific language governing permissions and
21 | # limitations under the License.
22 |
23 | """Tokenization classes."""
24 |
25 | from __future__ import absolute_import, division, print_function, unicode_literals
26 |
27 | import collections
28 | import logging
29 | import os
30 | import unicodedata
31 | from io import open
32 |
33 |
34 | logger = logging.getLogger(__name__)
35 |
36 | VOCAB_NAME = 'vocab.txt'
37 |
38 | def load_vocab(vocab_file):
39 | """Loads a vocabulary file into a dictionary."""
40 | vocab = collections.OrderedDict()
41 | index = 0
42 | with open(vocab_file, "r", encoding="utf-8") as reader:
43 | while True:
44 | token = reader.readline()
45 | if not token:
46 | break
47 | token = token.strip()
48 | vocab[token] = index
49 | index += 1
50 | return vocab
51 |
52 |
53 | def whitespace_tokenize(text):
54 | """Runs basic whitespace cleaning and splitting on a piece of text."""
55 | text = text.strip()
56 | if not text:
57 | return []
58 | tokens = text.split()
59 | return tokens
60 |
61 |
62 | class BertTokenizer(object):
63 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
64 |
65 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
66 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
67 | """Constructs a BertTokenizer.
68 |
69 | Args:
70 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
71 | do_lower_case: Whether to lower case the input
72 | Only has an effect when do_wordpiece_only=False
73 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
74 | max_len: An artificial maximum length to truncate tokenized sequences to;
75 | Effective maximum length is always the minimum of this
76 | value (if specified) and the underlying BERT model's
77 | sequence length.
78 | never_split: List of tokens which will never be split during tokenization.
79 | Only has an effect when do_wordpiece_only=False
80 | """
81 | if not os.path.isfile(vocab_file):
82 |             raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
83 | self.vocab = load_vocab(vocab_file)
84 | self.ids_to_tokens = collections.OrderedDict(
85 | [(ids, tok) for tok, ids in self.vocab.items()])
86 | self.do_basic_tokenize = do_basic_tokenize
87 | if do_basic_tokenize:
88 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
89 | never_split=never_split)
90 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
91 | self.max_len = max_len if max_len is not None else int(1e12)
92 |
93 | def tokenize(self, text):
94 | split_tokens = []
95 | if self.do_basic_tokenize:
96 | for token in self.basic_tokenizer.tokenize(text):
97 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
98 | split_tokens.append(sub_token)
99 | else:
100 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
101 | return split_tokens
102 |
103 | def convert_tokens_to_ids(self, tokens):
104 | """Converts a sequence of tokens into ids using the vocab."""
105 | ids = []
106 | for token in tokens:
107 | ids.append(self.vocab[token])
108 | if len(ids) > self.max_len:
109 | logger.warning(
110 | "Token indices sequence length is longer than the specified maximum "
111 | " sequence length for this BERT model ({} > {}). Running this"
112 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
113 | )
114 | return ids
115 |
116 | def convert_ids_to_tokens(self, ids):
117 |         """Converts a sequence of ids into wordpiece tokens using the vocab."""
118 | tokens = []
119 | for i in ids:
120 | tokens.append(self.ids_to_tokens[i])
121 | return tokens
122 |
123 | def save_vocabulary(self, vocab_path):
124 | """Save the tokenizer vocabulary to a directory or file."""
125 | index = 0
126 |         if os.path.isdir(vocab_path):
127 |             vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    |         else:
    |             vocab_file = vocab_path
128 | with open(vocab_file, "w", encoding="utf-8") as writer:
129 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
130 | if index != token_index:
131 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
132 | " Please check that the vocabulary is not corrupted!".format(vocab_file))
133 | index = token_index
134 | writer.write(token + u'\n')
135 | index += 1
136 | return vocab_file
137 |
138 | @classmethod
139 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
140 | """
141 | Instantiate a PreTrainedBertModel from a pre-trained model file.
142 | Download and cache the pre-trained model file if needed.
143 | """
144 | vocab_file = pretrained_model_name_or_path
145 | if os.path.isdir(vocab_file):
146 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
147 | # Instantiate tokenizer.
148 | tokenizer = cls(vocab_file, *inputs, **kwargs)
149 | return tokenizer
150 |
151 |
152 | class BasicTokenizer(object):
153 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
154 |
155 | def __init__(self,
156 | do_lower_case=True,
157 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
158 | """Constructs a BasicTokenizer.
159 |
160 | Args:
161 | do_lower_case: Whether to lower case the input.
162 | """
163 | self.do_lower_case = do_lower_case
164 | self.never_split = never_split
165 |
166 | def tokenize(self, text):
167 | """Tokenizes a piece of text."""
168 | text = self._clean_text(text)
169 | # This was added on November 1st, 2018 for the multilingual and Chinese
170 | # models. This is also applied to the English models now, but it doesn't
171 | # matter since the English models were not trained on any Chinese data
172 | # and generally don't have any Chinese data in them (there are Chinese
173 | # characters in the vocabulary because Wikipedia does have some Chinese
174 |         # words in the English Wikipedia).
175 | text = self._tokenize_chinese_chars(text)
176 | orig_tokens = whitespace_tokenize(text)
177 | split_tokens = []
178 | for token in orig_tokens:
179 | if self.do_lower_case and token not in self.never_split:
180 | token = token.lower()
181 | token = self._run_strip_accents(token)
182 | split_tokens.extend(self._run_split_on_punc(token))
183 |
184 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
185 | return output_tokens
186 |
187 | def _run_strip_accents(self, text):
188 | """Strips accents from a piece of text."""
189 | text = unicodedata.normalize("NFD", text)
190 | output = []
191 | for char in text:
192 | cat = unicodedata.category(char)
193 | if cat == "Mn":
194 | continue
195 | output.append(char)
196 | return "".join(output)
197 |
198 | def _run_split_on_punc(self, text):
199 | """Splits punctuation on a piece of text."""
200 | if text in self.never_split:
201 | return [text]
202 | chars = list(text)
203 | i = 0
204 | start_new_word = True
205 | output = []
206 | while i < len(chars):
207 | char = chars[i]
208 | if _is_punctuation(char):
209 | output.append([char])
210 | start_new_word = True
211 | else:
212 | if start_new_word:
213 | output.append([])
214 | start_new_word = False
215 | output[-1].append(char)
216 | i += 1
217 |
218 | return ["".join(x) for x in output]
219 |
220 | def _tokenize_chinese_chars(self, text):
221 | """Adds whitespace around any CJK character."""
222 | output = []
223 | for char in text:
224 | cp = ord(char)
225 | if self._is_chinese_char(cp):
226 | output.append(" ")
227 | output.append(char)
228 | output.append(" ")
229 | else:
230 | output.append(char)
231 | return "".join(output)
232 |
233 | def _is_chinese_char(self, cp):
234 | """Checks whether CP is the codepoint of a CJK character."""
235 | # This defines a "chinese character" as anything in the CJK Unicode block:
236 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
237 | #
238 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
239 | # despite its name. The modern Korean Hangul alphabet is a different block,
240 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
241 | # space-separated words, so they are not treated specially and handled
242 |         # like all of the other languages.
243 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
244 | (cp >= 0x3400 and cp <= 0x4DBF) or #
245 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
246 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
247 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
248 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
249 | (cp >= 0xF900 and cp <= 0xFAFF) or #
250 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
251 | return True
252 |
253 | return False
254 |
255 | def _clean_text(self, text):
256 | """Performs invalid character removal and whitespace cleanup on text."""
257 | output = []
258 | for char in text:
259 | cp = ord(char)
260 | if cp == 0 or cp == 0xfffd or _is_control(char):
261 | continue
262 | if _is_whitespace(char):
263 | output.append(" ")
264 | else:
265 | output.append(char)
266 | return "".join(output)
267 |
268 |
269 | class WordpieceTokenizer(object):
270 | """Runs WordPiece tokenization."""
271 |
272 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
273 | self.vocab = vocab
274 | self.unk_token = unk_token
275 | self.max_input_chars_per_word = max_input_chars_per_word
276 |
277 | def tokenize(self, text):
278 | """Tokenizes a piece of text into its word pieces.
279 |
280 | This uses a greedy longest-match-first algorithm to perform tokenization
281 | using the given vocabulary.
282 |
283 | For example:
284 | input = "unaffable"
285 | output = ["un", "##aff", "##able"]
286 |
287 | Args:
288 | text: A single token or whitespace separated tokens. This should have
289 | already been passed through `BasicTokenizer`.
290 |
291 | Returns:
292 | A list of wordpiece tokens.
293 | """
294 |
295 | output_tokens = []
296 | for token in whitespace_tokenize(text):
297 | chars = list(token)
298 | if len(chars) > self.max_input_chars_per_word:
299 | output_tokens.append(self.unk_token)
300 | continue
301 |
302 | is_bad = False
303 | start = 0
304 | sub_tokens = []
305 | while start < len(chars):
306 | end = len(chars)
307 | cur_substr = None
308 | while start < end:
309 | substr = "".join(chars[start:end])
310 | if start > 0:
311 | substr = "##" + substr
312 | if substr in self.vocab:
313 | cur_substr = substr
314 | break
315 | end -= 1
316 | if cur_substr is None:
317 | is_bad = True
318 | break
319 | sub_tokens.append(cur_substr)
320 | start = end
321 |
322 | if is_bad:
323 | output_tokens.append(self.unk_token)
324 | else:
325 | output_tokens.extend(sub_tokens)
326 | return output_tokens
327 |
328 |
329 | def _is_whitespace(char):
330 | """Checks whether `chars` is a whitespace character."""
331 |     # \t, \n, and \r are technically control characters but we treat them
332 | # as whitespace since they are generally considered as such.
333 | if char == " " or char == "\t" or char == "\n" or char == "\r":
334 | return True
335 | cat = unicodedata.category(char)
336 | if cat == "Zs":
337 | return True
338 | return False
339 |
340 |
341 | def _is_control(char):
342 | """Checks whether `chars` is a control character."""
343 | # These are technically control characters but we count them as whitespace
344 | # characters.
345 | if char == "\t" or char == "\n" or char == "\r":
346 | return False
347 | cat = unicodedata.category(char)
348 | if cat.startswith("C"):
349 | return True
350 | return False
351 |
352 |
353 | def _is_punctuation(char):
354 | """Checks whether `chars` is a punctuation character."""
355 | cp = ord(char)
356 | # We treat all non-letter/number ASCII as punctuation.
357 | # Characters such as "^", "$", and "`" are not in the Unicode
358 | # Punctuation class but we treat them as punctuation anyways, for
359 | # consistency.
360 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
361 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
362 | return True
363 | cat = unicodedata.category(char)
364 | if cat.startswith("P"):
365 | return True
366 | return False
367 |
--------------------------------------------------------------------------------
/omogre/accentuator/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import collections
8 | import math
9 | import os
10 | import random
11 | import sys
12 | import numpy as np
13 | import torch
14 |
15 | from .bert import BertForTokenClassification
16 |
17 | from .reader import AccentTokenizer, AccentDocument
18 | from .unk_model import UnkModel
19 |
20 | WordAcc = collections.namedtuple("WordAcc", ["first", "last", "pos", "state"])
21 | PunctWord = collections.namedtuple("PunctWord", ["punct", "word"])
22 |
23 | vowels = "аеёиоуыэюяАЕЁИОУЫЭЮЯ"
24 | vowel_plus = "аеёиоуыэюяАЕЁИОУЫЭЮЯ+"
25 |
26 | def count_vowels(text):
27 | return sum(1 for char in text if char in vowels)
28 |
29 |
30 | def _compute_softmax(scores):
31 | """Compute softmax probability over raw logits."""
32 | if not scores:
33 | return []
34 |
35 | max_score = None
36 | for score in scores:
37 | if max_score is None or score > max_score:
38 | max_score = score
39 |
40 | exp_scores = []
41 | total_sum = 0.0
42 | for score in scores:
43 | x = math.exp(score - max_score)
44 | exp_scores.append(x)
45 | total_sum += x
46 |
47 | probs = []
48 | for score in exp_scores:
49 | probs.append(score / total_sum)
50 | return probs
51 |
52 |
53 | def list_arg_max(iterable):
54 | return max(enumerate(iterable), key=lambda x: x[1])
55 |
56 |
57 | def norm_word(tword):
58 | return tword.casefold().replace('ё', 'е').replace(' ', '!')
59 |
60 |
61 | def check_ee_comu(tword):
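   |     # Stress positions readable straight off the spelling: the letter 'ё'
   |     # is always stressed, and 'кому' is hard-wired to position 3 ('у').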
62 | lcw = tword.casefold()
63 | if lcw == 'кому':
64 | return 3
65 | acc_pos = lcw.find('ё')
66 | if acc_pos >= 0:
67 | return acc_pos
68 | return -1
69 |
70 |
71 | class Accentuator:
72 | def __init__(self, data_path, device_name=None):
73 | model_data_path = os.path.join(data_path, 'model')
74 | self.unk_model = UnkModel(data_path, device_name=device_name)
75 |
76 | self.tokenizer = AccentTokenizer(data_path=data_path)
77 | if device_name is None:
78 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79 | else:
80 | self.device = torch.device(device_name)
81 |
82 | self.model = BertForTokenClassification.from_pretrained(model_data_path, num_labels=10, cache_dir=None)
83 | assert self.model
84 | self.model.eval()
85 | self.model.to(self.device)
86 | self.error_counter = 0
87 |
88 | def seed(self, seed_value):
89 | random.seed(seed_value)
90 | np.random.seed(seed_value)
91 | torch.manual_seed(seed_value)
92 |
93 | def process_model_batch(self, doc, batch, sum_batch):
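   |         # Run one padded batch through BERT and, for each word with more
   |         # than one admissible stress, pick the vowel with the highest
   |         # softmax probability among the positions the accent vocabulary
   |         # allows. Unknown words fall back to the character-level model.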
94 | sentence_spans, input_ids, attention_mask = batch
95 | input_ids = input_ids.to(self.device)
96 | attention_mask = attention_mask.to(self.device)
97 |
98 | with torch.no_grad():
99 | logits = self.model(input_ids, attention_mask=attention_mask)
100 |         logits = logits.detach().cpu().tolist()
101 |         input_ids = input_ids.detach().cpu().tolist()
102 |
103 | for batch_indx, (t_logits, t_input_ids) in enumerate(zip(logits, input_ids)):
104 | doc_pos, all_spans = sentence_spans[batch_indx]
105 |
106 | sentence = doc.all_sentences[doc_pos]
107 | word_list = []
108 |
109 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans:
110 | tword = sentence[first_word_pos:last_word_pos]
111 |
112 | acc_pos = check_ee_comu(tword)
113 | if acc_pos >= 0:
114 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'eee')
115 | word_list.append(tw)
116 | continue
117 |
118 | if count_vowels(tword) < 1:
119 | tw = WordAcc(first_word_pos, last_word_pos, -1, 'emp')
120 | word_list.append(tw)
121 | continue
122 |
123 | tword_key = norm_word(tword)
124 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key, 0)
125 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id]
126 |
127 | if len(acc_pos) < 1:
128 | acc_pos = self.unk_model.get_acc_pos(tword.casefold())
129 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'oov')
130 | word_list.append(tw)
131 | continue
132 |
133 | if len(acc_pos) < 2:
134 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos[0], 'sin')
135 | word_list.append(tw)
136 | continue
137 |
138 | best_word_index = -1
139 | best_word_prob = -1.0
140 | best_letter_pos = -1
141 | letter_pos = 0
142 |
143 | for token_indx in range(first_token_pos, last_token_pos):
144 | soft_probs = _compute_softmax(t_logits[token_indx])
145 | ct = self.tokenizer.token_vowel_pos.get(t_input_ids[token_indx])
146 | if ct is None:
147 | self.error_counter += 1
148 | break
149 | num_letters, vowel_pos = ct
150 |
151 | best_prob = 0
152 | best_index = -1
153 |
154 | for prob_index, tprob in enumerate(soft_probs):
155 | if prob_index < 1:
156 | continue
157 | if prob_index > len(vowel_pos):
158 | break
159 | tpos = letter_pos + vowel_pos[prob_index-1]
160 |
161 | if tpos not in acc_pos:
162 | continue
163 | if tprob > best_prob:
164 | best_prob = tprob
165 | best_index = prob_index
166 |
167 | if best_index > 0:
168 | if best_prob > best_word_prob:
169 | best_word_index = best_index
170 | best_word_prob = best_prob
171 | best_letter_pos = letter_pos + vowel_pos[best_index-1]
172 | letter_pos += num_letters
173 |
174 | if best_letter_pos >= 0:
175 | tw = WordAcc(first_word_pos, last_word_pos, best_letter_pos, 'var')
176 | word_list.append(tw)
177 | continue
178 |
179 | acc_pos = self.unk_model.get_acc_pos(tword.casefold())
180 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'unk')
181 | word_list.append(tw)
182 |
183 | ct = (doc_pos, word_list)
184 | sum_batch.append(ct)
185 |
186 | def process_without_bert(self, doc, doc_pos, all_spans):
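   |         # Stress an "easy" sentence: every word is either unambiguous in
   |         # the vocabulary or handled by the character-level fallback, so
   |         # no BERT call is needed.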
187 | sentence = doc.all_sentences[doc_pos]
188 | word_list = []
189 |
190 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans:
191 | tword = sentence[first_word_pos:last_word_pos]
192 |
193 | acc_pos = check_ee_comu(tword)
194 | if acc_pos >= 0:
195 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'eee')
196 | word_list.append(tw)
197 | continue
198 |
199 | if count_vowels(tword) < 1:
200 | tw = WordAcc(first_word_pos, last_word_pos, -1, 'emp')
201 | word_list.append(tw)
202 | continue
203 |
204 | tword_key = norm_word(tword)
205 | # state = 'var'
206 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key)
207 | if word_acc_id is not None:
208 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id]
209 |                 if len(acc_pos) == 1:
210 |                     tw = WordAcc(first_word_pos, last_word_pos, acc_pos[0], 'sin')
212 | word_list.append(tw)
213 | continue
214 |
215 | acc_pos = self.unk_model.get_acc_pos(tword.casefold())
216 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'unk')
217 | word_list.append(tw)
218 | return word_list
219 |
220 | def easy_loop(self, doc, sum_batch):
221 | for doc_pos, all_spans in doc.easy_sentences:
222 | word_list = self.process_without_bert(doc, doc_pos, all_spans)
223 | ct = (doc_pos, word_list)
224 | sum_batch.append(ct)
225 |
226 | def result_to_text(self, sentence, word_list):
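   |         # Render the sentence with '+' inserted before each stressed letter.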
227 | scale = [0 for _ in range(len(sentence))]
228 | for first_word_pos, last_word_pos, acc_pos, state in word_list:
229 | if acc_pos >= 0:
230 | tp = first_word_pos + acc_pos
231 | assert tp < len(sentence)
232 | scale[tp] = 1
233 |
234 | res = []
235 | for indx, tc in enumerate(sentence):
236 | if scale[indx]:
237 | res.append('+')
238 | res.append(tc)
239 | return ''.join(res)
240 |
241 | def result_to_word_list(self, sentence, word_list):
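   |         # Split the accented sentence into (punct, word) pairs; a final
   |         # pair with an empty word carries any trailing punctuation.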
242 | prev_pos = 0
243 | out_word_list = []
244 | for first_word_pos, last_word_pos, acc_pos, state in word_list:
245 | if acc_pos < 0:
246 | tword = sentence[first_word_pos:last_word_pos]
247 | else:
248 | tword_letters = []
249 | rel_pos = acc_pos + first_word_pos
250 | for tpos in range(first_word_pos, last_word_pos):
251 | if tpos == rel_pos:
252 | tword_letters.append('+')
253 | tword_letters.append(sentence[tpos])
254 | tword = ''.join(tword_letters)
255 |
256 | punct = sentence[prev_pos:first_word_pos]
257 | prev_pos = last_word_pos
258 | ct = PunctWord(punct, tword)
259 | out_word_list.append(ct)
260 | punct = sentence[prev_pos:]
261 | ct = PunctWord(punct, '')
262 | out_word_list.append(ct)
263 | return out_word_list
264 |
265 | def accentuate_by_words(self, input_sentence_list):
266 | if not bool(input_sentence_list):
267 | raise ValueError('list of strings is required')
268 |
269 | if not isinstance(input_sentence_list, list):
270 | raise ValueError('list of strings is required')
271 |
272 | if not all([isinstance(elem, str) for elem in input_sentence_list]):
273 | raise ValueError('list of strings is required')
274 |
275 | doc = AccentDocument(self.tokenizer, input_sentence_list)
276 |
277 | sum_batch = []
278 | self.easy_loop(doc, sum_batch)
279 | for t_batch in doc.model_batches:
280 | self.process_model_batch(doc, t_batch, sum_batch)
281 |
282 | sum_batch_index = []
283 | for indx in range(len(sum_batch)):
284 | pos = sum_batch[indx][0]
285 | sum_batch_index.append((pos, indx))
286 |
287 | output_sentence_word_list = []
288 |
289 | for pos, indx in sorted(sum_batch_index):
290 | sentence = doc.all_sentences[pos]
291 | sentence_words = self.result_to_word_list(sentence, sum_batch[indx][1])
292 | output_sentence_word_list.append(sentence_words)
293 | return output_sentence_word_list
294 |
295 | def accentuate_sentence_list(self, input_sentence_list):
296 | if not bool(input_sentence_list):
297 | raise ValueError('a list of strings is required')
298 | if not isinstance(input_sentence_list, list):
299 | raise ValueError('a list of strings is required')
300 | if not all([
301 | isinstance(elem, str) for elem in input_sentence_list]):
302 | raise ValueError('a list of strings is required')
303 |
304 | output_sentence_word_list = self.accentuate_by_words(input_sentence_list)
305 | output_sentence_text_list = []
306 | for t_sentence in output_sentence_word_list:
307 | word_list = []
308 | for punct, tword in t_sentence:
309 | word_list.append(punct)
310 | word_list.append(tword)
311 | output_sentence_text_list.append(''.join(word_list))
312 |
313 | return output_sentence_text_list
314 |
315 | def accentuate(self, input_text):
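   |         # Accepts a single string or a list of strings and returns the
   |         # accented text in the same shape (for a string, sentences are
   |         # joined with newlines).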
316 | if not bool(input_text):
317 | raise ValueError('a string or list of strings is required')
318 | if isinstance(input_text, list):
319 | input_sentence_list = input_text
320 |
321 | if not all([
322 | isinstance(elem, str) for elem in input_sentence_list]):
323 | raise ValueError('a string or list of strings is required')
324 |
325 | else:
326 | if isinstance(input_text, str):
327 | input_sentence_list = [input_text]
328 | else:
329 | raise ValueError('a string or list of strings is required')
330 |
331 | output_sentence_list = self.accentuate_sentence_list(input_sentence_list)
332 |
333 | if isinstance(input_text, str):
334 | return "\n".join(output_sentence_list)
335 | return output_sentence_list
336 |
337 | # ---------------------------- this is for debugging -----------------------------
338 |
339 | def process_all_easy_sentence(self, doc, doc_pos, all_spans):
340 | sentence = doc.all_sentences[doc_pos]
341 | word_list = []
342 |
343 | for (first_word_pos, last_word_pos), (first_token_pos, last_token_pos) in all_spans:
344 | tword = sentence[first_word_pos:last_word_pos]
345 |
346 | acc_pos = check_ee_comu(tword)
347 | if acc_pos >= 0:
348 | tw = WordAcc(first_word_pos, last_word_pos, [acc_pos], 'eee')
349 | word_list.append(tw)
350 | continue
351 |
352 | tword_key = norm_word(tword)
353 |
354 | word_acc_id = self.tokenizer.accent_vocab.vocab.get(tword_key)
355 | if word_acc_id is not None:
356 | acc_pos = self.tokenizer.accent_vocab.vocab_index[word_acc_id]
357 |
358 | tw = WordAcc(first_word_pos, last_word_pos, acc_pos, 'sin')
359 | word_list.append(tw)
360 | continue
361 |
362 | tw = WordAcc(first_word_pos, last_word_pos, [], 'unk')
363 | word_list.append(tw)
364 | return word_list
365 |
366 | def all_easy_loop(self, doc, sum_batch):
367 | for doc_pos, all_spans in doc.easy_sentences:
368 | word_list = self.process_all_easy_sentence(doc, doc_pos, all_spans)
369 | ct = (doc_pos, word_list)
370 | sum_batch.append(ct)
371 |
372 | def all_easy_to_word_list(self, sentence, word_list):
373 | prev_pos = 0
374 | out_word_list = []
375 | for first_word_pos, last_word_pos, acc_pos, state in word_list:
376 |             tword = sentence[first_word_pos:last_word_pos]
377 |
378 |             if acc_pos:
379 |                 tword_letters = []
380 |                 stress_sign = '+'
381 |                 for tpos in range(first_word_pos, last_word_pos):
382 |                     rel_pos = tpos - first_word_pos
383 |                     if rel_pos in acc_pos:
384 |                         tword_letters.append(stress_sign)
385 |                     tword_letters.append(sentence[tpos])
386 |                 tword = ''.join(tword_letters)
393 |
394 | punct = sentence[prev_pos:first_word_pos]
395 | prev_pos = last_word_pos
396 | ct = PunctWord(punct, tword)
397 | out_word_list.append(ct)
398 | punct = sentence[prev_pos:]
399 | ct = PunctWord(punct, '')
400 | out_word_list.append(ct)
401 | return out_word_list
402 |
403 | def accentuate_all_easy(self, input_sentence_list):
404 | if not bool(input_sentence_list):
405 | raise ValueError('list of strings is required')
406 |
407 | if not isinstance(input_sentence_list, list):
408 | raise ValueError('list of strings is required')
409 |
410 | if not all([isinstance(elem, str) for elem in input_sentence_list]):
411 | raise ValueError('list of strings is required')
412 |
413 | doc = AccentDocument(self.tokenizer, input_sentence_list)
414 |
415 | doc.get_all_easy(self.tokenizer)
416 |
417 | sum_batch = []
418 | self.all_easy_loop(doc, sum_batch)
419 |
420 | sum_batch_index = []
421 | for indx in range(len(sum_batch)):
422 | pos = sum_batch[indx][0]
423 | sum_batch_index.append((pos, indx))
424 |
425 | output_sentence_word_list = []
426 |
427 | for pos, indx in sorted(sum_batch_index):
428 | sentence = doc.all_sentences[pos]
429 | sentence_words = self.all_easy_to_word_list(sentence, sum_batch[indx][1])
430 | output_sentence_word_list.append(sentence_words)
431 |
432 | output_sentence_text_list = []
433 | for t_sentence in output_sentence_word_list:
434 | word_list = []
435 | for punct, tword in t_sentence:
436 | word_list.append(punct)
437 | word_list.append(tword)
438 | output_sentence_text_list.append(''.join(word_list))
439 |
440 | return output_sentence_text_list
441 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the “Licensor.” The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
438 |
--------------------------------------------------------------------------------
/omogre/accentuator/bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # edited from pytorch_pretrained_bert
4 |
5 | # https://github.com/google-research/bert
6 | # https://github.com/maknotavailable/pytorch-pretrained-BERT
7 |
8 |
9 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
10 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
11 | #
12 | # Licensed under the Apache License, Version 2.0 (the "License");
13 | # you may not use this file except in compliance with the License.
14 | # You may obtain a copy of the License at
15 | #
16 | # http://www.apache.org/licenses/LICENSE-2.0
17 | #
18 | # Unless required by applicable law or agreed to in writing, software
19 | # distributed under the License is distributed on an "AS IS" BASIS,
20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | # See the License for the specific language governing permissions and
22 | # limitations under the License.
23 |
24 | """PyTorch BERT model."""
25 |
26 | from __future__ import absolute_import, division, print_function, unicode_literals
27 |
28 | import copy
29 | import json
30 | import logging
31 | import math
32 | import os
33 | import shutil
34 | import tarfile
35 | import tempfile
36 | import sys
37 | from io import open
38 |
39 | import torch
40 | from torch import nn
41 | from torch.nn import CrossEntropyLoss
42 |
43 | CONFIG_NAME = "config.json"
44 | WEIGHTS_NAME = "pytorch_model.bin"
45 |
46 | logger = logging.getLogger(__name__)
47 |
48 | BERT_CONFIG_NAME = 'bert_config.json'
49 | TF_WEIGHTS_NAME = 'model.ckpt'
50 |
51 | def gelu(x):
52 | """Implementation of the gelu activation function.
53 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
54 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
55 | Also see https://arxiv.org/abs/1606.08415
56 | """
57 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
58 |
59 |
60 | def swish(x):
61 | return x * torch.sigmoid(x)
62 |
63 |
64 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
65 |
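# Illustrative sketch (assumes a recent PyTorch build): the erf-based gelu
# above is the "exact" GELU and agrees with the default
# torch.nn.functional.gelu, e.g.
#     x = torch.linspace(-3.0, 3.0, steps=7)
#     torch.allclose(gelu(x), torch.nn.functional.gelu(x), atol=1e-6)  # True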
66 |
67 | class BertConfig(object):
68 | """Configuration class to store the configuration of a `BertModel`.
69 | """
70 | def __init__(self,
71 | vocab_size_or_config_json_file,
72 | hidden_size=768,
73 | num_hidden_layers=12,
74 | num_attention_heads=12,
75 | intermediate_size=3072,
76 | hidden_act="gelu",
77 | hidden_dropout_prob=0.1,
78 | attention_probs_dropout_prob=0.1,
79 | max_position_embeddings=512,
80 | type_vocab_size=2,
81 | initializer_range=0.02):
82 | """Constructs BertConfig.
83 |
84 | Args:
85 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
86 | hidden_size: Size of the encoder layers and the pooler layer.
87 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
88 | num_attention_heads: Number of attention heads for each attention layer in
89 | the Transformer encoder.
90 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
91 | layer in the Transformer encoder.
92 | hidden_act: The non-linear activation function (function or string) in the
93 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
94 |             hidden_dropout_prob: The dropout probability for all fully connected
95 | layers in the embeddings, encoder, and pooler.
96 | attention_probs_dropout_prob: The dropout ratio for the attention
97 | probabilities.
98 | max_position_embeddings: The maximum sequence length that this model might
99 | ever be used with. Typically set this to something large just in case
100 | (e.g., 512 or 1024 or 2048).
101 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
102 | `BertModel`.
103 |             initializer_range: The stddev of the truncated_normal_initializer for
104 | initializing all weight matrices.
105 | """
106 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
107 | and isinstance(vocab_size_or_config_json_file, unicode)):
108 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
109 | json_config = json.loads(reader.read())
110 | for key, value in json_config.items():
111 | self.__dict__[key] = value
112 | elif isinstance(vocab_size_or_config_json_file, int):
113 | self.vocab_size = vocab_size_or_config_json_file
114 | self.hidden_size = hidden_size
115 | self.num_hidden_layers = num_hidden_layers
116 | self.num_attention_heads = num_attention_heads
117 | self.hidden_act = hidden_act
118 | self.intermediate_size = intermediate_size
119 | self.hidden_dropout_prob = hidden_dropout_prob
120 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
121 | self.max_position_embeddings = max_position_embeddings
122 | self.type_vocab_size = type_vocab_size
123 | self.initializer_range = initializer_range
124 | else:
125 |             raise ValueError("First argument must be either a vocabulary size (int) "
126 |                              "or the path to a pretrained model config file (str)")
127 |
128 | @classmethod
129 | def from_dict(cls, json_object):
130 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
131 | config = BertConfig(vocab_size_or_config_json_file=-1)
132 | for key, value in json_object.items():
133 | config.__dict__[key] = value
134 | return config
135 |
136 | @classmethod
137 | def from_json_file(cls, json_file):
138 | """Constructs a `BertConfig` from a json file of parameters."""
139 | with open(json_file, "r", encoding='utf-8') as reader:
140 | text = reader.read()
141 | return cls.from_dict(json.loads(text))
142 |
143 | def __repr__(self):
144 | return str(self.to_json_string())
145 |
146 | def to_dict(self):
147 | """Serializes this instance to a Python dictionary."""
148 | output = copy.deepcopy(self.__dict__)
149 | return output
150 |
151 | def to_json_string(self):
152 | """Serializes this instance to a JSON string."""
153 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
154 |
155 | def to_json_file(self, json_file_path):
156 | """ Save this instance to a json file."""
157 | with open(json_file_path, "w", encoding='utf-8') as writer:
158 | writer.write(self.to_json_string())
159 |
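# Illustrative sketch (the file name is arbitrary): a BertConfig round-trips
# through JSON with the helpers above, e.g.
#     cfg = BertConfig(vocab_size_or_config_json_file=32000)
#     cfg.to_json_file('config.json')
#     assert BertConfig.from_json_file('config.json').to_dict() == cfg.to_dict()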
160 | #try:
161 | # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
162 | #except ImportError:
163 | # logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
164 |
165 | class BertLayerNorm(nn.Module):
166 | def __init__(self, hidden_size, eps=1e-12):
167 | """Construct a layernorm module in the TF style (epsilon inside the square root).
168 | """
169 | super(BertLayerNorm, self).__init__()
170 | self.weight = nn.Parameter(torch.ones(hidden_size))
171 | self.bias = nn.Parameter(torch.zeros(hidden_size))
172 | self.variance_epsilon = eps
173 |
174 | def forward(self, x):
175 | u = x.mean(-1, keepdim=True)
176 | s = (x - u).pow(2).mean(-1, keepdim=True)
177 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
178 | return self.weight * x + self.bias
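    # Illustrative sketch: with its default parameters BertLayerNorm is
    # numerically equivalent to torch.nn.LayerNorm(hidden_size, eps=1e-12),
    # e.g.
    #     x = torch.randn(2, 5, 768)
    #     torch.allclose(BertLayerNorm(768)(x), nn.LayerNorm(768, eps=1e-12)(x))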
179 |
180 | class BertEmbeddings(nn.Module):
181 | """Construct the embeddings from word, position and token_type embeddings.
182 | """
183 | def __init__(self, config):
184 | super(BertEmbeddings, self).__init__()
185 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
186 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
187 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
188 |
189 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
190 | # any TensorFlow checkpoint file
191 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
192 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
193 |
194 | def forward(self, input_ids, token_type_ids=None):
195 | seq_length = input_ids.size(1)
196 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
197 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
198 | if token_type_ids is None:
199 | token_type_ids = torch.zeros_like(input_ids)
200 |
201 | words_embeddings = self.word_embeddings(input_ids)
202 | position_embeddings = self.position_embeddings(position_ids)
203 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
204 |
205 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
206 | embeddings = self.LayerNorm(embeddings)
207 | embeddings = self.dropout(embeddings)
208 | return embeddings
209 |
210 |
211 | class BertSelfAttention(nn.Module):
212 | def __init__(self, config):
213 | super(BertSelfAttention, self).__init__()
214 | if config.hidden_size % config.num_attention_heads != 0:
215 | raise ValueError(
216 | "The hidden size (%d) is not a multiple of the number of attention "
217 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
218 | self.num_attention_heads = config.num_attention_heads
219 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
220 | self.all_head_size = self.num_attention_heads * self.attention_head_size
221 |
222 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
223 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
224 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
225 |
226 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227 |
228 | def transpose_for_scores(self, x):
229 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
230 | x = x.view(*new_x_shape)
231 | return x.permute(0, 2, 1, 3)
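        # Shape note: the view + permute above turn
        # [batch, seq_len, all_head_size] into
        # [batch, num_heads, seq_len, attention_head_size], so attention can be
        # computed per head with batched matmuls.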
232 |
233 | def forward(self, hidden_states, attention_mask):
234 | mixed_query_layer = self.query(hidden_states)
235 | mixed_key_layer = self.key(hidden_states)
236 | mixed_value_layer = self.value(hidden_states)
237 |
238 | query_layer = self.transpose_for_scores(mixed_query_layer)
239 | key_layer = self.transpose_for_scores(mixed_key_layer)
240 | value_layer = self.transpose_for_scores(mixed_value_layer)
241 |
242 | # Take the dot product between "query" and "key" to get the raw attention scores.
243 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
244 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245 |         # Apply the attention mask (precomputed for all layers in BertModel forward() function)
246 | attention_scores = attention_scores + attention_mask
247 |
248 | # Normalize the attention scores to probabilities.
249 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
250 |
251 | # This is actually dropping out entire tokens to attend to, which might
252 | # seem a bit unusual, but is taken from the original Transformer paper.
253 | attention_probs = self.dropout(attention_probs)
254 |
255 | context_layer = torch.matmul(attention_probs, value_layer)
256 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
257 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
258 | context_layer = context_layer.view(*new_context_layer_shape)
259 | return context_layer
260 |
261 |
262 | class BertSelfOutput(nn.Module):
263 | def __init__(self, config):
264 | super(BertSelfOutput, self).__init__()
265 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
266 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
267 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
268 |
269 | def forward(self, hidden_states, input_tensor):
270 | hidden_states = self.dense(hidden_states)
271 | hidden_states = self.dropout(hidden_states)
272 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
273 | return hidden_states
274 |
275 |
276 | class BertAttention(nn.Module):
277 | def __init__(self, config):
278 | super(BertAttention, self).__init__()
279 | self.self = BertSelfAttention(config)
280 | self.output = BertSelfOutput(config)
281 |
282 | def forward(self, input_tensor, attention_mask):
283 | self_output = self.self(input_tensor, attention_mask)
284 | attention_output = self.output(self_output, input_tensor)
285 | return attention_output
286 |
287 |
288 | class BertIntermediate(nn.Module):
289 | def __init__(self, config):
290 | super(BertIntermediate, self).__init__()
291 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
292 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
293 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
294 | else:
295 | self.intermediate_act_fn = config.hidden_act
296 |
297 | def forward(self, hidden_states):
298 | hidden_states = self.dense(hidden_states)
299 | hidden_states = self.intermediate_act_fn(hidden_states)
300 | return hidden_states
301 |
302 |
303 | class BertOutput(nn.Module):
304 | def __init__(self, config):
305 | super(BertOutput, self).__init__()
306 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
307 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
308 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
309 |
310 | def forward(self, hidden_states, input_tensor):
311 | hidden_states = self.dense(hidden_states)
312 | hidden_states = self.dropout(hidden_states)
313 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
314 | return hidden_states
315 |
316 |
317 | class BertLayer(nn.Module):
318 | def __init__(self, config):
319 | super(BertLayer, self).__init__()
320 | self.attention = BertAttention(config)
321 | self.intermediate = BertIntermediate(config)
322 | self.output = BertOutput(config)
323 |
324 | def forward(self, hidden_states, attention_mask):
325 | attention_output = self.attention(hidden_states, attention_mask)
326 | intermediate_output = self.intermediate(attention_output)
327 | layer_output = self.output(intermediate_output, attention_output)
328 | return layer_output
329 |
330 |
331 | class BertEncoder(nn.Module):
332 | def __init__(self, config):
333 | super(BertEncoder, self).__init__()
334 | layer = BertLayer(config)
335 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
336 |
337 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
338 | all_encoder_layers = []
339 | for layer_module in self.layer:
340 | hidden_states = layer_module(hidden_states, attention_mask)
341 | if output_all_encoded_layers:
342 | all_encoder_layers.append(hidden_states)
343 | if not output_all_encoded_layers:
344 | all_encoder_layers.append(hidden_states)
345 | return all_encoder_layers
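        # Note: when output_all_encoded_layers is False the list still holds
        # exactly one tensor (the final layer), so callers can always read
        # encoded_layers[-1].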
346 |
347 |
348 | class BertPooler(nn.Module):
349 | def __init__(self, config):
350 | super(BertPooler, self).__init__()
351 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
352 | self.activation = nn.Tanh()
353 |
354 | def forward(self, hidden_states):
355 | # We "pool" the model by simply taking the hidden state corresponding
356 | # to the first token.
357 | first_token_tensor = hidden_states[:, 0]
358 | pooled_output = self.dense(first_token_tensor)
359 | pooled_output = self.activation(pooled_output)
360 | return pooled_output
361 |
362 |
363 | class BertPredictionHeadTransform(nn.Module):
364 | def __init__(self, config):
365 | super(BertPredictionHeadTransform, self).__init__()
366 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
367 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
368 | self.transform_act_fn = ACT2FN[config.hidden_act]
369 | else:
370 | self.transform_act_fn = config.hidden_act
371 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
372 |
373 | def forward(self, hidden_states):
374 | hidden_states = self.dense(hidden_states)
375 | hidden_states = self.transform_act_fn(hidden_states)
376 | hidden_states = self.LayerNorm(hidden_states)
377 | return hidden_states
378 |
379 |
380 | class BertLMPredictionHead(nn.Module):
381 | def __init__(self, config, bert_model_embedding_weights):
382 | super(BertLMPredictionHead, self).__init__()
383 | self.transform = BertPredictionHeadTransform(config)
384 |
385 | # The output weights are the same as the input embeddings, but there is
386 | # an output-only bias for each token.
387 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
388 | bert_model_embedding_weights.size(0),
389 | bias=False)
390 | self.decoder.weight = bert_model_embedding_weights
391 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
392 |
393 | def forward(self, hidden_states):
394 | hidden_states = self.transform(hidden_states)
395 | hidden_states = self.decoder(hidden_states) + self.bias
396 | return hidden_states
397 |
398 |
399 | class BertOnlyMLMHead(nn.Module):
400 | def __init__(self, config, bert_model_embedding_weights):
401 | super(BertOnlyMLMHead, self).__init__()
402 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
403 |
404 | def forward(self, sequence_output):
405 | prediction_scores = self.predictions(sequence_output)
406 | return prediction_scores
407 |
408 |
409 | class BertOnlyNSPHead(nn.Module):
410 | def __init__(self, config):
411 | super(BertOnlyNSPHead, self).__init__()
412 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
413 |
414 | def forward(self, pooled_output):
415 | seq_relationship_score = self.seq_relationship(pooled_output)
416 | return seq_relationship_score
417 |
418 |
419 | class BertPreTrainingHeads(nn.Module):
420 | def __init__(self, config, bert_model_embedding_weights):
421 | super(BertPreTrainingHeads, self).__init__()
422 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
423 | self.seq_relationship = nn.Linear(config.hidden_size, 2)
424 |
425 | def forward(self, sequence_output, pooled_output):
426 | prediction_scores = self.predictions(sequence_output)
427 | seq_relationship_score = self.seq_relationship(pooled_output)
428 | return prediction_scores, seq_relationship_score
429 |
430 |
431 | class BertPreTrainedModel(nn.Module):
432 | """ An abstract class to handle weights initialization and
433 |         a simple interface for downloading and loading pretrained models.
434 | """
435 | def __init__(self, config, *inputs, **kwargs):
436 | super(BertPreTrainedModel, self).__init__()
437 | if not isinstance(config, BertConfig):
438 | raise ValueError(
439 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
440 | "To create a model from a Google pretrained model use "
441 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
442 | self.__class__.__name__, self.__class__.__name__
443 | ))
444 | self.config = config
445 |
446 | def init_bert_weights(self, module):
447 | """ Initialize the weights.
448 | """
449 | if isinstance(module, (nn.Linear, nn.Embedding)):
450 | # Slightly different from the TF version which uses truncated_normal for initialization
451 | # cf https://github.com/pytorch/pytorch/pull/5617
452 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
453 | elif isinstance(module, BertLayerNorm):
454 | module.bias.data.zero_()
455 | module.weight.data.fill_(1.0)
456 | if isinstance(module, nn.Linear) and module.bias is not None:
457 | module.bias.data.zero_()
458 |
459 | @classmethod
460 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
461 | """
462 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
463 | Download and cache the pre-trained model file if needed.
464 |
465 | Params:
466 | pretrained_model_name_or_path: either:
467 | - a str with the name of a pre-trained model to load selected in the list of:
468 | . `bert-base-uncased`
469 | . `bert-large-uncased`
470 | . `bert-base-cased`
471 | . `bert-large-cased`
472 | . `bert-base-multilingual-uncased`
473 | . `bert-base-multilingual-cased`
474 | . `bert-base-chinese`
475 | - a path or url to a pretrained model archive containing:
476 | . `bert_config.json` a configuration file for the model
477 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
478 | - a path or url to a pretrained model archive containing:
479 | . `bert_config.json` a configuration file for the model
480 |                     . `model.ckpt` a TensorFlow checkpoint
481 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint
482 | cache_dir: an optional path to a folder in which the pre-trained models will be cached.
483 |             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
484 | *inputs, **kwargs: additional input for the specific Bert class
485 | (ex: num_labels for BertForSequenceClassification)
486 | """
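        # Illustrative sketch (the path is hypothetical): in this trimmed-down
        # copy, pretrained_model_name_or_path is a local directory holding
        # config.json and pytorch_model.bin, e.g.
        #     model = BertForMaskedLM.from_pretrained('/path/to/model_dir')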
487 |         # Pop the keyword arguments consumed by from_pretrained itself so
488 |         # they are not forwarded to the model constructor below; pop with a
489 |         # default is equivalent to the original get-then-pop pairs.
490 |         state_dict = kwargs.pop('state_dict', None)
491 |         cache_dir = kwargs.pop('cache_dir', None)
492 |         from_tf = kwargs.pop('from_tf', False)
493 |
494 | archive_file = pretrained_model_name_or_path
495 | logger.info("loading archive file {}".format(archive_file))
496 | resolved_archive_file = archive_file
497 | '''
498 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
499 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
500 | else:
501 | archive_file = pretrained_model_name_or_path
502 | # redirect to the cache, if necessary
503 | try:
504 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
505 | except EnvironmentError:
506 | logger.error(
507 | "Model name '{}' was not found in model name list ({}). "
508 | "We assumed '{}' was a path or url but couldn't find any file "
509 | "associated to this path or url.".format(
510 | pretrained_model_name_or_path,
511 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
512 | archive_file))
513 | return None
514 | if resolved_archive_file == archive_file:
515 | logger.info("loading archive file {}".format(archive_file))
516 | else:
517 | logger.info("loading archive file {} from cache at {}".format(
518 | archive_file, resolved_archive_file))
519 | '''
520 | tempdir = None
521 | serialization_dir = resolved_archive_file
522 | '''
523 | if os.path.isdir(resolved_archive_file) or from_tf:
524 | serialization_dir = resolved_archive_file
525 | else:
526 | # Extract archive to temp dir
527 | tempdir = tempfile.mkdtemp()
528 | logger.info("extracting archive file {} to temp dir {}".format(
529 | resolved_archive_file, tempdir))
530 | with tarfile.open(resolved_archive_file, 'r:gz') as archive:
531 | archive.extractall(tempdir)
532 | serialization_dir = tempdir
533 | '''
534 | # Load config
535 | config_file = os.path.join(serialization_dir, CONFIG_NAME)
536 | #if not os.path.exists(config_file):
537 | # # Backward compatibility with old naming format
538 | # config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
539 | config = BertConfig.from_json_file(config_file)
540 | logger.info("Model config {}".format(config))
541 | # Instantiate model.
542 | model = cls(config, *inputs, **kwargs)
543 | if state_dict is None and not from_tf:
544 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
545 | state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
546 | if tempdir:
547 | # Clean up temp dir
548 | shutil.rmtree(tempdir)
549 | if from_tf:
550 | # Directly load from a TensorFlow checkpoint
551 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
552 | return load_tf_weights_in_bert(model, weights_path)
553 | # Load from a PyTorch state_dict
554 | old_keys = []
555 | new_keys = []
556 | for key in state_dict.keys():
557 | new_key = None
558 | if 'gamma' in key:
559 | new_key = key.replace('gamma', 'weight')
560 | if 'beta' in key:
561 | new_key = key.replace('beta', 'bias')
562 | if new_key:
563 | old_keys.append(key)
564 | new_keys.append(new_key)
565 | for old_key, new_key in zip(old_keys, new_keys):
566 | state_dict[new_key] = state_dict.pop(old_key)
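        # Note: TensorFlow-style checkpoints name the LayerNorm parameters
        # 'gamma'/'beta'; the loop above renames them to the 'weight'/'bias'
        # keys used by the modules in this file.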
567 |
568 | missing_keys = []
569 | unexpected_keys = []
570 | error_msgs = []
571 | # copy state_dict so _load_from_state_dict can modify it
572 | metadata = getattr(state_dict, '_metadata', None)
573 | state_dict = state_dict.copy()
574 | if metadata is not None:
575 | state_dict._metadata = metadata
576 |
577 | def load(module, prefix=''):
578 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
579 | module._load_from_state_dict(
580 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
581 | for name, child in module._modules.items():
582 | if child is not None:
583 | load(child, prefix + name + '.')
584 | start_prefix = ''
585 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
586 | start_prefix = 'bert.'
587 | load(model, prefix=start_prefix)
588 | if len(missing_keys) > 0:
589 | logger.info("Weights of {} not initialized from pretrained model: {}".format(
590 | model.__class__.__name__, missing_keys))
591 | if len(unexpected_keys) > 0:
592 | logger.info("Weights from pretrained model not used in {}: {}".format(
593 | model.__class__.__name__, unexpected_keys))
594 | if len(error_msgs) > 0:
595 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
596 | model.__class__.__name__, "\n\t".join(error_msgs)))
597 | return model
598 |
599 |
600 | class BertModel(BertPreTrainedModel):
601 |     """BERT model ("Bidirectional Encoder Representations from Transformers").
602 |
603 | Params:
604 | config: a BertConfig class instance with the configuration to build a new model
605 |
606 | Inputs:
607 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
608 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
609 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
610 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
611 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
612 | a `sentence B` token (see BERT paper for more details).
613 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
614 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
615 | input sequence length in the current batch. It's the mask that we typically use for attention when
616 | a batch has varying length sentences.
617 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
618 |
619 | Outputs: Tuple of (encoded_layers, pooled_output)
620 |         `encoded_layers`: controlled by `output_all_encoded_layers` argument:
621 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
622 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
623 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
624 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
625 | to the last attention block of shape [batch_size, sequence_length, hidden_size],
626 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
627 |         classifier pretrained on top of the hidden state associated with the first token of the
628 |         input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
629 |
630 | Example usage:
631 | ```python
632 | # Already been converted into WordPiece token ids
633 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
634 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
635 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
636 |
637 |     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
638 |         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
639 |
640 |     model = BertModel(config=config)
641 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
642 | ```
643 | """
644 | def __init__(self, config):
645 | super(BertModel, self).__init__(config)
646 | self.embeddings = BertEmbeddings(config)
647 | self.encoder = BertEncoder(config)
648 | self.pooler = BertPooler(config)
649 | self.apply(self.init_bert_weights)
650 |
651 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
652 | if attention_mask is None:
653 | attention_mask = torch.ones_like(input_ids)
654 | if token_type_ids is None:
655 | token_type_ids = torch.zeros_like(input_ids)
656 |
657 | # We create a 3D attention mask from a 2D tensor mask.
658 | # Sizes are [batch_size, 1, 1, to_seq_length]
659 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
660 |         # this attention mask is simpler than the triangular masking of causal attention
661 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
662 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
663 |
664 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
665 | # masked positions, this operation will create a tensor which is 0.0 for
666 | # positions we want to attend and -10000.0 for masked positions.
667 | # Since we are adding it to the raw scores before the softmax, this is
668 | # effectively the same as removing these entirely.
669 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
670 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
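        # Illustrative example: a padding mask [1, 1, 0] becomes the additive
        # mask [0.0, 0.0, -10000.0], so padded positions receive a negligible
        # attention weight after the softmax.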
671 |
672 | embedding_output = self.embeddings(input_ids, token_type_ids)
673 | encoded_layers = self.encoder(embedding_output,
674 | extended_attention_mask,
675 | output_all_encoded_layers=output_all_encoded_layers)
676 | sequence_output = encoded_layers[-1]
677 | pooled_output = self.pooler(sequence_output)
678 | if not output_all_encoded_layers:
679 | encoded_layers = encoded_layers[-1]
680 | return encoded_layers, pooled_output
681 |
682 |
683 | class BertForPreTraining(BertPreTrainedModel):
684 | """BERT model with pre-training heads.
685 | This module comprises the BERT model followed by the two pre-training heads:
686 | - the masked language modeling head, and
687 | - the next sentence classification head.
688 |
689 | Params:
690 | config: a BertConfig class instance with the configuration to build a new model.
691 |
692 | Inputs:
693 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
694 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
695 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
696 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
697 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
698 | a `sentence B` token (see BERT paper for more details).
699 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
700 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
701 | input sequence length in the current batch. It's the mask that we typically use for attention when
702 | a batch has varying length sentences.
703 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
704 |         with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked), the loss
705 |         is only computed for the labels set in [0, ..., vocab_size - 1]
706 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
707 | with indices selected in [0, 1].
708 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
709 |
710 | Outputs:
711 | if `masked_lm_labels` and `next_sentence_label` are not `None`:
712 | Outputs the total_loss which is the sum of the masked language modeling loss and the next
713 | sentence classification loss.
714 | if `masked_lm_labels` or `next_sentence_label` is `None`:
715 | Outputs a tuple comprising
716 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
717 | - the next sentence classification logits of shape [batch_size, 2].
718 |
719 | Example usage:
720 | ```python
721 | # Already been converted into WordPiece token ids
722 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
723 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
724 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
725 |
726 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
727 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
728 |
729 | model = BertForPreTraining(config)
730 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
731 | ```
732 | """
733 | def __init__(self, config):
734 | super(BertForPreTraining, self).__init__(config)
735 | self.bert = BertModel(config)
736 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
737 | self.apply(self.init_bert_weights)
738 |
739 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
740 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
741 | output_all_encoded_layers=False)
742 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
743 |
744 | if masked_lm_labels is not None and next_sentence_label is not None:
745 | loss_fct = CrossEntropyLoss(ignore_index=-1)
746 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
747 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
748 | total_loss = masked_lm_loss + next_sentence_loss
749 | return total_loss
750 | else:
751 | return prediction_scores, seq_relationship_score
752 |
753 |
754 | class BertForMaskedLM(BertPreTrainedModel):
755 | """BERT model with the masked language modeling head.
756 | This module comprises the BERT model followed by the masked language modeling head.
757 |
758 | Params:
759 | config: a BertConfig class instance with the configuration to build a new model.
760 |
761 | Inputs:
762 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
763 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
764 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
765 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
766 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
767 | a `sentence B` token (see BERT paper for more details).
768 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
769 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
770 | input sequence length in the current batch. It's the mask that we typically use for attention when
771 | a batch has varying length sentences.
772 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
773 |         with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked), the loss
774 |         is only computed for the labels set in [0, ..., vocab_size - 1]
775 |
776 | Outputs:
777 | if `masked_lm_labels` is not `None`:
778 | Outputs the masked language modeling loss.
779 | if `masked_lm_labels` is `None`:
780 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
781 |
782 | Example usage:
783 | ```python
784 | # Already been converted into WordPiece token ids
785 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
786 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
787 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
788 |
789 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
790 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
791 |
792 | model = BertForMaskedLM(config)
793 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
794 | ```
795 | """
796 | def __init__(self, config):
797 | super(BertForMaskedLM, self).__init__(config)
798 | self.bert = BertModel(config)
799 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
800 | self.apply(self.init_bert_weights)
801 |
802 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
803 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
804 | output_all_encoded_layers=False)
805 | prediction_scores = self.cls(sequence_output)
806 |
807 | if masked_lm_labels is not None:
808 | loss_fct = CrossEntropyLoss(ignore_index=-1)
809 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
810 | return masked_lm_loss
811 | else:
812 | return prediction_scores
813 |
814 |
815 | class BertForNextSentencePrediction(BertPreTrainedModel):
816 | """BERT model with next sentence prediction head.
817 | This module comprises the BERT model followed by the next sentence classification head.
818 |
819 | Params:
820 | config: a BertConfig class instance with the configuration to build a new model.
821 |
822 | Inputs:
823 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
824 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
825 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
826 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
827 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
828 | a `sentence B` token (see BERT paper for more details).
829 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
830 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
831 | input sequence length in the current batch. It's the mask that we typically use for attention when
832 | a batch has varying length sentences.
833 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
834 | with indices selected in [0, 1].
835 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
836 |
837 | Outputs:
838 | if `next_sentence_label` is not `None`:
839 |             Outputs the next sentence classification loss computed over the
840 |             [batch_size, 2] relationship logits.
841 | if `next_sentence_label` is `None`:
842 | Outputs the next sentence classification logits of shape [batch_size, 2].
843 |
844 | Example usage:
845 | ```python
846 | # Already been converted into WordPiece token ids
847 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
848 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
849 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
850 |
851 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
852 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
853 |
854 | model = BertForNextSentencePrediction(config)
855 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
856 | ```
857 | """
858 | def __init__(self, config):
859 | super(BertForNextSentencePrediction, self).__init__(config)
860 | self.bert = BertModel(config)
861 | self.cls = BertOnlyNSPHead(config)
862 | self.apply(self.init_bert_weights)
863 |
864 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
865 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
866 | output_all_encoded_layers=False)
867 |         seq_relationship_score = self.cls(pooled_output)
868 |
869 | if next_sentence_label is not None:
870 | loss_fct = CrossEntropyLoss(ignore_index=-1)
871 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
872 | return next_sentence_loss
873 | else:
874 | return seq_relationship_score
875 |
876 |
877 | class BertForSequenceClassification(BertPreTrainedModel):
878 | """BERT model for classification.
879 | This module is composed of the BERT model with a linear layer on top of
880 | the pooled output.
881 |
882 | Params:
883 | `config`: a BertConfig class instance with the configuration to build a new model.
884 | `num_labels`: the number of classes for the classifier. Default = 2.
885 |
886 | Inputs:
887 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
888 |         with the word token indices in the vocabulary. Items in the batch should begin with the special "[CLS]" token (see the tokens preprocessing logic in the scripts
889 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
890 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
891 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
892 | a `sentence B` token (see BERT paper for more details).
893 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
894 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
895 | input sequence length in the current batch. It's the mask that we typically use for attention when
896 | a batch has varying length sentences.
897 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
898 |         with indices selected in [0, ..., num_labels - 1].
899 |
900 | Outputs:
901 | if `labels` is not `None`:
902 | Outputs the CrossEntropy classification loss of the output with the labels.
903 | if `labels` is `None`:
904 | Outputs the classification logits of shape [batch_size, num_labels].
905 |
906 | Example usage:
907 | ```python
908 | # Already been converted into WordPiece token ids
909 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
910 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
911 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
912 |
913 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
914 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
915 |
916 | num_labels = 2
917 |
918 | model = BertForSequenceClassification(config, num_labels)
919 | logits = model(input_ids, token_type_ids, input_mask)
920 | ```
921 | """
922 | def __init__(self, config, num_labels):
923 | super(BertForSequenceClassification, self).__init__(config)
924 | self.num_labels = num_labels
925 | self.bert = BertModel(config)
926 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
927 | self.classifier = nn.Linear(config.hidden_size, num_labels)
928 | self.apply(self.init_bert_weights)
929 |
930 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
931 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
932 | pooled_output = self.dropout(pooled_output)
933 | logits = self.classifier(pooled_output)
934 |
935 | if labels is not None:
936 | loss_fct = CrossEntropyLoss()
937 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
938 | return loss
939 | else:
940 | return logits
941 |
942 |
943 | class BertForMultipleChoice(BertPreTrainedModel):
944 | """BERT model for multiple choice tasks.
945 | This module is composed of the BERT model with a linear layer on top of
946 | the pooled output.
947 |
948 | Params:
949 | `config`: a BertConfig class instance with the configuration to build a new model.
950 | `num_choices`: the number of classes for the classifier. Default = 2.
951 |
952 | Inputs:
953 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
954 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
955 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
956 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
957 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
958 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
959 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
960 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
961 | input sequence length in the current batch. It's the mask that we typically use for attention when
962 | a batch has varying length sentences.
963 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
964 |             with indices selected in [0, ..., num_choices - 1].
965 |
966 | Outputs:
967 | if `labels` is not `None`:
968 | Outputs the CrossEntropy classification loss of the output with the labels.
969 | if `labels` is `None`:
970 | Outputs the classification logits of shape [batch_size, num_labels].
971 |
972 | Example usage:
973 | ```python
974 | # Already been converted into WordPiece token ids
975 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
976 |     input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
977 |     token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])
978 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
979 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
980 |
981 | num_choices = 2
982 |
983 | model = BertForMultipleChoice(config, num_choices)
984 | logits = model(input_ids, token_type_ids, input_mask)
985 | ```
986 | """
987 | def __init__(self, config, num_choices):
988 | super(BertForMultipleChoice, self).__init__(config)
989 | self.num_choices = num_choices
990 | self.bert = BertModel(config)
991 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
992 | self.classifier = nn.Linear(config.hidden_size, 1)
993 | self.apply(self.init_bert_weights)
994 |
995 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
996 | flat_input_ids = input_ids.view(-1, input_ids.size(-1))
997 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
998 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
999 | _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
1000 | pooled_output = self.dropout(pooled_output)
1001 | logits = self.classifier(pooled_output)
1002 | reshaped_logits = logits.view(-1, self.num_choices)
1003 |
1004 | if labels is not None:
1005 | loss_fct = CrossEntropyLoss()
1006 | loss = loss_fct(reshaped_logits, labels)
1007 | return loss
1008 | else:
1009 | return reshaped_logits
1010 |
1011 |
1012 | class BertForTokenClassification(BertPreTrainedModel):
1013 | """BERT model for token-level classification.
1014 | This module is composed of the BERT model with a linear layer on top of
1015 | the full hidden state of the last layer.
1016 |
1017 | Params:
1018 | `config`: a BertConfig class instance with the configuration to build a new model.
1019 | `num_labels`: the number of classes for the classifier. Default = 2.
1020 |
1021 | Inputs:
1022 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1023 |         with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
1024 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1025 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1026 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1027 | a `sentence B` token (see BERT paper for more details).
1028 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1029 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1030 | input sequence length in the current batch. It's the mask that we typically use for attention when
1031 | a batch has varying length sentences.
1032 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
1033 |         with indices selected in [0, ..., num_labels - 1].
1034 |
1035 | Outputs:
1036 | if `labels` is not `None`:
1037 | Outputs the CrossEntropy classification loss of the output with the labels.
1038 | if `labels` is `None`:
1039 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
1040 |
1041 | Example usage:
1042 | ```python
1043 | # Already been converted into WordPiece token ids
1044 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
1045 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
1046 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
1047 |
1048 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1049 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1050 |
1051 | num_labels = 2
1052 |
1053 | model = BertForTokenClassification(config, num_labels)
1054 | logits = model(input_ids, token_type_ids, input_mask)
1055 | ```
1056 | """
1057 | def __init__(self, config, num_labels):
1058 | super(BertForTokenClassification, self).__init__(config)
1059 | self.num_labels = num_labels
1060 | self.bert = BertModel(config)
1061 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
1062 | self.classifier = nn.Linear(config.hidden_size, num_labels)
1063 | self.apply(self.init_bert_weights)
1064 |
1065 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
1066 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
1067 | sequence_output = self.dropout(sequence_output)
1068 | logits = self.classifier(sequence_output)
1069 |
1070 | if labels is not None:
1071 | loss_fct = CrossEntropyLoss()
1072 | # Only keep active parts of the loss
1073 | if attention_mask is not None:
1074 | active_loss = attention_mask.view(-1) == 1
1075 | active_logits = logits.view(-1, self.num_labels)[active_loss]
1076 | active_labels = labels.view(-1)[active_loss]
1077 | loss = loss_fct(active_logits, active_labels)
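                # Note: active_loss is a flat boolean mask over the
                # batch_size * sequence_length positions, so the loss is
                # computed only on non-padded tokens.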
1078 | else:
1079 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1080 | return loss
1081 | else:
1082 | return logits
1083 |
1084 |
1085 | class BertForQuestionAnswering(BertPreTrainedModel):
1086 | """BERT model for Question Answering (span extraction).
1087 | This module is composed of the BERT model with a linear layer on top of
1088 | the sequence output that computes start_logits and end_logits
1089 |
1090 | Params:
1091 | `config`: a BertConfig class instance with the configuration to build a new model.
1092 |
1093 | Inputs:
1094 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1095 |             with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
1096 | `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1097 |         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1098 |             type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1
1099 |             corresponds to a `sentence B` token (see the BERT paper for more details).
1100 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1101 |             selected in [0, 1]. It is a mask used when an input sequence is shorter than the maximum
1102 |             input sequence length in the current batch; it is the mask we typically use for attention
1103 |             when a batch contains sentences of varying lengths.
1104 |         `start_positions`: position of the first token of the labeled span: torch.LongTensor of shape [batch_size].
1105 |             Positions are clamped to the length of the sequence, and positions outside of the sequence are
1106 |             not taken into account when computing the loss.
1107 |         `end_positions`: position of the last token of the labeled span: torch.LongTensor of shape [batch_size].
1108 |             Positions are clamped to the length of the sequence, and positions outside of the sequence are
1109 |             not taken into account when computing the loss.
1110 |
1111 | Outputs:
1112 | if `start_positions` and `end_positions` are not `None`:
1113 |             Outputs the total_loss, which is the average of the CrossEntropy losses for the start and end token positions.
1114 | if `start_positions` or `end_positions` is `None`:
1115 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
1116 | position tokens of shape [batch_size, sequence_length].
1117 |
1118 | Example usage:
1119 | ```python
1121 |     # Inputs have already been converted into WordPiece token ids
1121 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
1122 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
1123 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
1124 |
1125 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1126 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1127 |
1128 | model = BertForQuestionAnswering(config)
1129 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
1130 | ```
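    To compute the training loss instead of the logits, pass the labeled span boundaries
    as well (a minimal sketch; the positions below are illustrative):

    ```python
    start_positions = torch.LongTensor([0, 1])  # illustrative gold start indices, one per batch element
    end_positions = torch.LongTensor([1, 2])    # illustrative gold end indices
    total_loss = model(input_ids, token_type_ids, input_mask, start_positions, end_positions)
    ```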
1131 | """
1132 | def __init__(self, config):
1133 | super(BertForQuestionAnswering, self).__init__(config)
1134 | self.bert = BertModel(config)
1135 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
1136 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
1137 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1138 | self.apply(self.init_bert_weights)
1139 |
1140 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
1141 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
1142 |         logits = self.qa_outputs(sequence_output)  # [batch_size, seq_len, 2]
1143 |         start_logits, end_logits = logits.split(1, dim=-1)  # two tensors of shape [batch_size, seq_len, 1]
1144 | start_logits = start_logits.squeeze(-1)
1145 | end_logits = end_logits.squeeze(-1)
1146 |
1147 | if start_positions is not None and end_positions is not None:
1148 |             # If we are on multi-GPU, the batch split may add an extra dimension; squeeze it away
1149 | if len(start_positions.size()) > 1:
1150 | start_positions = start_positions.squeeze(-1)
1151 | if len(end_positions.size()) > 1:
1152 | end_positions = end_positions.squeeze(-1)
1153 |             # Sometimes the start/end positions fall outside the model inputs; ignore those terms in the loss
1154 |             ignored_index = start_logits.size(1)  # one past the last valid position
1155 | start_positions.clamp_(0, ignored_index)
1156 | end_positions.clamp_(0, ignored_index)
1157 |
1158 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1159 | start_loss = loss_fct(start_logits, start_positions)
1160 | end_loss = loss_fct(end_logits, end_positions)
1161 | total_loss = (start_loss + end_loss) / 2
1162 | return total_loss
1163 | else:
1164 | return start_logits, end_logits
1165 |
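# To turn the start/end logits returned above into a predicted answer span at inference
# time, one can take the argmax over the sequence dimension (a minimal sketch in
# comments; assumes `start_logits`/`end_logits` come from the forward pass above and
# the span for batch element 0 is wanted):
#
#     start_index = int(torch.argmax(start_logits, dim=1)[0])
#     end_index = int(torch.argmax(end_logits, dim=1)[0])
#     predicted_span = (start_index, end_index)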
--------------------------------------------------------------------------------