├── brnolm ├── __init__.py ├── runtime │ ├── __init__.py │ ├── model_statistics.py │ ├── reporting.py │ ├── tensor_reorganization.py │ ├── runtime_utils.py │ ├── loggers.py │ └── runtime_multifile.py ├── smm_itf │ ├── __init__.py │ ├── pack-smm.py │ ├── xtract-ivecs-example.py │ ├── ivec_appenders.py │ └── smm_ivec_extractor.py ├── data_pipeline │ ├── __init__.py │ ├── temporal_splitting.py │ ├── threaded.py │ ├── masked.py │ ├── augmentation.py │ ├── reading.py │ ├── flexible_pipeline.py │ ├── multistream.py │ └── split_corpus_dataset.py ├── oov_clustering │ ├── __init__.py │ ├── embeddings_computation.py │ └── embeddings_io.py ├── language_models │ ├── __init__.py │ ├── encoders.py │ ├── decoders.py │ ├── ffnn_models.py │ ├── lstm_model.py │ ├── vocab.py │ └── transformer.py ├── kaldi_itf.py ├── lm-info.py ├── analysis.py ├── zoo.py ├── rmn-activation-plotter.py ├── rmn-grad-plotter.py ├── rmn-plotter.py ├── plotting.py ├── analyze-ivec-distribution.py ├── files-to-bow.py ├── multifile-ml-unigram-ppl.py ├── investigate-ivecs.py ├── multifile-ivec-unigram-ppl.py ├── multifile-ml-unigram-tranfer-ppl.py ├── srilm-debug2.py └── analyze-ivec-changes.py ├── test ├── __init__.py ├── utils.py ├── test_language_models │ ├── test_decoders.py │ └── test_lstm.py ├── test_model_statistics.py ├── test_runtime_util.py ├── test_data_pipeline │ └── test_reading.py ├── test_runtime │ └── test_evaluation.py ├── test_smm_ivec_extractor.py ├── test_det.py └── test_analysis.py ├── .gitignore ├── scripts ├── model-info.py ├── export-torchscript.py ├── oov-clustering │ ├── apply-linear-transform.py │ ├── plot-det.py │ ├── insert-oovs.py │ ├── compare-references.py │ ├── compute-wc-covariance.py │ ├── compute-edit-distance.py │ ├── collect-embeddings.py │ ├── evaluate-embeddings.py │ ├── predict-embeddings.py │ ├── evaluate-embeddings-selective.py │ ├── evaluate-embeddings-large-scale.py │ ├── reference-matrix-by-word-alignment.py │ └── process-hybrid-paths.py ├── get-char-vocab.py ├── migrator-batch-first.py ├── rescoring │ ├── score-combiner.py │ ├── pick-best.py │ ├── plot-2d.py │ ├── rescoring-combine-scores.py │ ├── rescore-nbest-continuous.py │ └── rescore-kaldi-latts-continuous.py ├── train │ ├── logger.py │ └── train-multifile.py ├── migrator.py ├── eval │ ├── eval.py │ ├── eval-multifile.py │ ├── eval-independent.py │ ├── eval-noivecs-domain-adaptation.py │ ├── eval-chime.py │ ├── eval-ivecs-oracle.py │ ├── eval-chime-v2.py │ ├── eval-ivecs-partial.py │ └── eval-ivecs-domain-adaptation.py ├── model-building │ ├── build-shallow-nn.py │ ├── build-shallow-nn-with-ivec.py │ ├── build-lstmp.py │ ├── build-transformer.py │ └── build-lstm.py ├── corpus-stats.py ├── sample-from-lm.py └── display-augmented-data.py ├── pyproject.toml ├── LICENSE └── README.md /brnolm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brnolm/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brnolm/smm_itf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
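The tree above covers three areas: the `brnolm` package itself (where `language_models` holds token encoders, LSTM/FFNN/transformer model bodies and softmax decoders, `data_pipeline` holds batching and augmentation utilities, `runtime` holds training/evaluation helpers, and `smm_itf`/`oov_clustering` provide the SMM i-vector and OOV-embedding interfaces), the `test` suite, and the command-line entry points under `scripts`. For orientation only, the minimal sketch below shows how the core pieces compose; it uses toy dimensions and simply mirrors `test/test_language_models/test_lstm.py` further down, so treat it as an illustration rather than a prescribed workflow.

```
import torch

from brnolm.language_models.encoders import FlatEmbedding
from brnolm.language_models.lstm_model import LSTMLanguageModel

# A toy-sized model: 4-token vocabulary, 10-dimensional embeddings and LSTM, 2 layers.
encoder = FlatEmbedding(4, 10)
model = LSTMLanguageModel(token_encoder=encoder, dim_input=10, dim_lstm=10, nb_layers=2, dropout=0.0)

h0 = model.init_hidden(3)                          # fresh hidden state for a batch of 3 sequences
o, h1 = model(torch.tensor([[1], [2], [3]]), h0)   # one input token per sequence
```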
/.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | dist/* 3 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brnolm/oov_clustering/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brnolm/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /brnolm/oov_clustering/embeddings_computation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tensor_from_words(words, vocab): 5 | return torch.tensor([vocab[w] for w in words]).view(1, -1) 6 | -------------------------------------------------------------------------------- /brnolm/kaldi_itf.py: -------------------------------------------------------------------------------- 1 | 2 | def split_nbest_key(key): 3 | fields = key.split('-') 4 | segment = '-'.join(fields[:-1]) 5 | trans_id = fields[-1] 6 | 7 | return segment, trans_id 8 | -------------------------------------------------------------------------------- /test/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | def getStream(words): 4 | data_source = io.StringIO() 5 | data_source.write(" ".join(words)) 6 | data_source.seek(0) 7 | 8 | return data_source 9 | -------------------------------------------------------------------------------- /brnolm/language_models/encoders.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FlatEmbedding(nn.Module): 5 | def __init__(self, nb_tokens, dim_embs, init_range=0.1): 6 | super().__init__() 7 | self.embeddings = nn.Embedding(nb_tokens, dim_embs) 8 | nn.init.uniform_(self.embeddings.weight, -init_range, init_range) 9 | 10 | def forward(self, x): 11 | return self.embeddings(x) 12 | -------------------------------------------------------------------------------- /scripts/model-info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import torch 4 | 5 | from brnolm.runtime.model_statistics import ModelStatistics 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('model_path') 11 | args = parser.parse_args() 12 | 13 | lm = torch.load(args.model_path, map_location='cpu') 14 | print(ModelStatistics(lm)) 15 | 16 | 17 | if __name__ == '__main__': 18 | main() 19 | -------------------------------------------------------------------------------- /test/test_language_models/test_decoders.py: -------------------------------------------------------------------------------- 1 | from test.common import TestCase 2 | import torch 3 | 4 | from brnolm.language_models.decoders import FullSoftmaxDecoder 5 | 6 | 7 | class FullSoftmaxDecoderTests(TestCase): 8 | def test_raw_log_prob_shape(self): 9 | decoder = FullSoftmaxDecoder(4, 3) 10 | o = torch.zeros((2, 3, 4), dtype=torch.float64) 11 | t = torch.tensor([ 12 | [0, 1, 2], 13 | [2, 1, 1], 14 | ]) 15 | y = decoder.neg_log_prob_raw(o, t) 16 | 17 | 
self.assertEqual(y.shape, (2, 3)) 18 | -------------------------------------------------------------------------------- /brnolm/lm-info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import torch 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 8 | parser.add_argument('--vocab', action='store_true', help='print out the full vocab') 9 | parser.add_argument('load', help='where to load a model from') 10 | args = parser.parse_args() 11 | 12 | lm = torch.load(args.load, map_location='cpu') 13 | print(lm.model) 14 | print("Vocab len:", len(lm.vocab)) 15 | if args.vocab: 16 | print([c for c in lm.vocab.w2i_]) 17 | -------------------------------------------------------------------------------- /scripts/export-torchscript.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from brnolm.language_models import language_model 4 | 5 | 6 | def main(args): 7 | if args.force_cpu: 8 | lm = torch.load(args.lm, map_location='cpu') 9 | else: 10 | lm = torch.load(args.lm) 11 | language_model.torchscript_export(lm, args.frozen_lm) 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--force-cpu', action='store_true') 17 | parser.add_argument('lm') 18 | parser.add_argument('frozen_lm') 19 | args = parser.parse_args() 20 | 21 | main(args) 22 | -------------------------------------------------------------------------------- /test/test_language_models/test_lstm.py: -------------------------------------------------------------------------------- 1 | from test.common import TestCase 2 | 3 | import torch 4 | 5 | from brnolm.language_models.lstm_model import LSTMLanguageModel 6 | from brnolm.language_models.encoders import FlatEmbedding 7 | 8 | 9 | class OutputExtractionTests(TestCase): 10 | def test_multilayer(self): 11 | encoder = FlatEmbedding(4, 10) 12 | model = LSTMLanguageModel(token_encoder=encoder, dim_input=10, dim_lstm=10, nb_layers=2, dropout=0.0) 13 | h0 = model.init_hidden(3) 14 | o, h1 = model(torch.tensor([[1], [2], [3]]), h0) 15 | 16 | self.assertEqual(model.extract_output_from_h(h1).unsqueeze(1), o) 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "brnolm" 3 | version = "0.3.0" 4 | description = "A language modeling toolit" 5 | readme = "README.md" 6 | requires-python = ">=3.6" 7 | license = {file = "LICENSE"} 8 | keywords = ["language modeling"] 9 | 10 | authors = [ 11 | {email = "ibenes@fit.vutbr.cz", name = "Karel Beneš"} 12 | ] 13 | 14 | classifiers=[ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: POSIX", 18 | ] 19 | 20 | dependencies = [ 21 | "numpy", 22 | "torch>=1.4", 23 | "scikit-learn", 24 | ] 25 | 26 | [project.urls] 27 | repository = "https://github.com/BUTSpeechFIT/BrnoLM" 28 | -------------------------------------------------------------------------------- /scripts/oov-clustering/apply-linear-transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | 8 | from brnolm.oov_clustering.embeddings_io import emb_line_iterator, str_from_embedding 9 | 10 | 
if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--transform', required=True) 13 | args = parser.parse_args() 14 | 15 | transform = np.loadtxt(args.transform) 16 | 17 | for key, embedding in emb_line_iterator(sys.stdin): 18 | projected = embedding @ transform 19 | 20 | emb_str = str_from_embedding(projected) 21 | sys.stdout.write("{} {}\n".format(key, emb_str)) 22 | -------------------------------------------------------------------------------- /scripts/get-char-vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--offset', type=int, default=0, help='where to start numbering') 8 | parser.add_argument('file', help='where to collect characters from') 9 | args = parser.parse_args() 10 | 11 | bag_of_letters = set() 12 | with open(args.file) as f: 13 | for line in f: 14 | bag_of_letters = bag_of_letters | set(line) 15 | 16 | bag_of_letters = bag_of_letters - set('\n') 17 | 18 | for i, c in enumerate(bag_of_letters): 19 | sys.stdout.write("'{}' {}\n".format(c, i + args.offset)) 20 | -------------------------------------------------------------------------------- /scripts/oov-clustering/plot-det.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import pickle 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--log-det', action='store_true') 10 | parser.add_argument('--baseline', action='store_true') 11 | parser.add_argument('--free-axis', action='store_true') 12 | parser.add_argument('--eer-line', action='store_true') 13 | parser.add_argument('file', help="where is the pickled DETCurve") 14 | args = parser.parse_args() 15 | 16 | with open(args.file, 'rb') as f: 17 | det = pickle.load(f) 18 | 19 | det.plot(args.log_det, not args.free_axis, args.eer_line, filename=None) 20 | -------------------------------------------------------------------------------- /brnolm/analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def categorical_entropy(p, eps=1e-100): 5 | zeros = p <= eps 6 | 7 | log_p = p.log() 8 | log_p.masked_fill_(zeros, 0.0) # eliminates -inf for p[x] = 0.0 9 | 10 | H_p = - torch.sum(p*log_p, dim=-1) 11 | return H_p / torch.log(torch.FloatTensor([2])) 12 | 13 | 14 | def categorical_cross_entropy(p, q, eps=1e-100): 15 | zeros = p <= eps 16 | 17 | log_q = q.log() 18 | log_q = torch.zeros_like(p) + log_q 19 | log_q.masked_fill_(zeros, 0.0) # eliminates -inf for p[x] = 0.0 20 | 21 | Xent = - torch.sum(p*log_q, dim=-1) 22 | return Xent / torch.log(torch.FloatTensor([2])) 23 | 24 | 25 | def categorical_kld(p, q): 26 | return categorical_cross_entropy(p, q) - categorical_entropy(p) 27 | -------------------------------------------------------------------------------- /brnolm/zoo.py: -------------------------------------------------------------------------------- 1 | sloth = """⣾⣷⣶⣦⣶⡆⢸⣧⣾⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ 2 | ⠀⠀⠀⠀⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ 3 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡀⠀⠀⠀⠀⠀⠀⢀⣀⣀⠤⠤⠖⠒⠋⠋⠀ 4 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡀⢠⣶⣾⣀⣸⣿⣿⡤⠴⠒⠚⠋⠉⠁⠀⠀⠀⢀⣀⣠⠤⠴⠆ 5 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⣾⣾⣤⡼⠿⠷⢺⠋⠉⢯⠀⠀⠀⠀⣀⣀⡠⠤⠴⠒⠚⠉⠉⠁⠀⠀⠀⠀⠀ 6 | ⠀⠀⠀⠀⠠⠤⠖⠒⠋⢹⠟⠉⢳⡀⢀⣀⣸⠀⠀⠈⢷⡎⠉⣉⣷⠤⠤⡖⠲⡍⠉⠉⢹⡗⠦⡀⠀⠀⠀ 7 | ⠀⠀⠀⠀⠀⣀⣀⡠⠤⢾⠀⠀⠀⢿⠁⠀⢿⠀⠀⠀⠸⣇⠞⠉⠢⡀⣾⡿⠀⣸⠀⠀⣼⡀⠀⠈⢢⠀⠀ 8 | ⠀⠀⠀⠀⠀⠉⠀⠀⠀⡟⠀⠀⠀⠘⡀⠀⢸⡄⠀⠀⠀⢻⠀⠀⢀⣿⣿⠓⠊⠁⠀⠀⠀⠈⠑⢄⠀⢳⠀ 9 | 
⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢯⠉⢩⡇⠀⠀⠀⠸⡄⠀⠈⢿⣾⠀⢀⡤⠤⡀⠀⠀⠀⠈⣆⠈⣇ 10 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢣⠀⠀⠀⠀⢀⢃⣶⣄⠈⡆⠀⠀⠀⢸⠀⢸ 11 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⢻⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⣆⠀⠀⠀⢸⠈⠛⠋⠀⢸⠀⠀⠀⢸⠀⣸ 12 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⢆⠀⠀⡎⠀⠀⠀⠀⢸⠄⠀⣠⠃⢠⠇ 13 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢦⣰⠃⠀⠀⠀⠀⢸⣠⠖⠁⣠⠏⠀ 14 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠹⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠑⠒⠐⠒⠊⠉⠀⢀⡼⠋⠀⠀ 15 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢆⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠚⠁⠀⠀⠀⠀ 16 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠦⣀⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡤⠔⠚⠉⠀⠀⠀⠀⠀⠀⠀⠀ 17 | ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠉⠒⠒⠒⠒⠚⠉⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ 18 | """ 19 | -------------------------------------------------------------------------------- /test/test_model_statistics.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from brnolm.runtime.model_statistics import scaled_int_str 4 | 5 | 6 | class ScaledIntRepreTests(TestCase): 7 | def test_order_0(self): 8 | self.assertEqual(scaled_int_str(0), '0') 9 | 10 | def test_order_1(self): 11 | self.assertEqual(scaled_int_str(10), '10') 12 | 13 | def test_order_2(self): 14 | self.assertEqual(scaled_int_str(210), '210') 15 | 16 | def test_order_3(self): 17 | self.assertEqual(scaled_int_str(3210), '3.2k') 18 | 19 | def test_order_4(self): 20 | self.assertEqual(scaled_int_str(43210), '43.2k') 21 | 22 | def test_order_5(self): 23 | self.assertEqual(scaled_int_str(543210), '543.2k') 24 | 25 | def test_order_6(self): 26 | self.assertEqual(scaled_int_str(6543210), '6.5M') 27 | -------------------------------------------------------------------------------- /scripts/migrator-batch-first.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | '''Migrates old LM from before proper brnolm package was introduced. 3 | 4 | Build around the proposition of this SO answer: 5 | https://stackoverflow.com/a/53327348/9703830 6 | 7 | Uses a separate, monkey-patched pickle (`my_pickle`) for de-serialization 8 | in order to ensure that the pure system pickle is ready to serialize the model. 
9 | ''' 10 | 11 | import argparse 12 | import torch 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('source') 18 | parser.add_argument('target') 19 | args = parser.parse_args() 20 | 21 | lm = torch.load(args.source, map_location='cpu') 22 | lm.model.rnn.batch_first = True 23 | lm.model.batch_first = True 24 | torch.save(lm, args.target) 25 | 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/temporal_splitting.py: -------------------------------------------------------------------------------- 1 | class TemporalSplits(): 2 | def __init__(self, seq, nb_inputs_necessary, nb_targets_parallel): 3 | self._seq = seq 4 | self._nb_inputs_necessary = nb_inputs_necessary 5 | self._nb_target_parallel = nb_targets_parallel 6 | 7 | def __iter__(self): 8 | for lend, rend in self.ranges(): 9 | yield ( 10 | self._seq[lend:rend], 11 | self._seq[lend+self._nb_inputs_necessary:rend+1] 12 | ) 13 | 14 | def __len__(self): 15 | return max(len(self._seq) - self._nb_inputs_necessary - self._nb_target_parallel + 1, 0) 16 | 17 | def ranges(self): 18 | for i in range(0, len(self), self._nb_target_parallel): 19 | lend = i 20 | rend = i + self._nb_inputs_necessary + self._nb_target_parallel - 1 21 | yield lend, rend 22 | -------------------------------------------------------------------------------- /brnolm/rmn-activation-plotter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pickle 4 | import argparse 5 | 6 | import plotting 7 | 8 | import numpy as np 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--color-bar", action='store_true') 14 | parser.add_argument("activations") 15 | args = parser.parse_args() 16 | 17 | with open(args.activations, 'rb') as f: 18 | activations = pickle.load(f) 19 | 20 | fig_titles = [] 21 | for a in activations: 22 | nb_zeros = np.sum(a==0.0) 23 | fig_titles.append(", ".join([str(x) for x in [nb_zeros, np.mean(a), np.max(a)]])) 24 | 25 | color_setup = { 26 | 'cmap': 'gnuplot', 27 | 'vmin': 0, 28 | 'vmax': 30, 29 | 'colorbar': args.color_bar, 30 | } 31 | 32 | plotting.grid_plot(activations, lambda x:x, "activations", fig_titles, coloring=color_setup) 33 | -------------------------------------------------------------------------------- /test/test_runtime_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .common import TestCase 3 | 4 | from brnolm.runtime.runtime_utils import repackage_hidden 5 | 6 | 7 | class TensorReorganizerTests(TestCase): 8 | def setUp(self): 9 | tensor = torch.tensor( 10 | [[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]], 11 | requires_grad=True 12 | ) 13 | self.computed_tensor = tensor * 3 14 | 15 | def test_data_kept(self): 16 | repackaged = repackage_hidden(self.computed_tensor) 17 | self.assertEqual(self.computed_tensor, repackaged) 18 | 19 | def test_result_requires_grad(self): 20 | repackaged = repackage_hidden(self.computed_tensor) 21 | self.assertTrue(repackaged.requires_grad_) 22 | 23 | def test_is_detached(self): 24 | self.assertFalse(self.computed_tensor.grad_fn is None) 25 | repackaged = repackage_hidden(self.computed_tensor) 26 | self.assertTrue(repackaged.grad_fn is None) 27 | -------------------------------------------------------------------------------- /brnolm/rmn-grad-plotter.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import pickle 4 | import argparse 5 | 6 | import plotting 7 | 8 | import numpy as np 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--norms", action='store_true', 14 | help="print average norm of the gradient per layer") 15 | parser.add_argument("grads") 16 | args = parser.parse_args() 17 | 18 | with open(args.grads, 'rb') as f: 19 | grads = pickle.load(f) 20 | 21 | fig_titles = [] 22 | if args.norms: 23 | for g in grads: 24 | norms = np.linalg.norm(g, ord=2, axis=1) 25 | fig_titles.append(", ".join([str(x) for x in [np.min(norms), np.mean(norms), np.max(norms)]])) 26 | 27 | coloring = { 28 | 'vmin' : -5e-3, 29 | 'vmax' : 5e-3, 30 | 'cmap' : 'RdBu' 31 | } 32 | 33 | plotting.grid_plot(grads, lambda x:x, "Weighted grads", fig_titles, coloring) 34 | -------------------------------------------------------------------------------- /brnolm/smm_itf/pack-smm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | 5 | import pickle 6 | 7 | import sys 8 | from smm import SMM 9 | import utils 10 | 11 | import .smm_ivec_extractor 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--smm', required=True, help="path to a trained SMM model") 17 | parser.add_argument('--tokenizer', required=True, help="path to the tokenizer used for processing the data for training the SMM") 18 | parser.add_argument('--save', required=True, help="where to put the final SMM ivec extractor") 19 | args = parser.parse_args() 20 | 21 | model_f = os.path.realpath(args.smm) 22 | model, config = utils.load_model_and_config(model_f) 23 | 24 | with open(args.tokenizer, 'rb') as f: 25 | tokenizer = pickle.load(f) 26 | 27 | ivec_extractor = smm_ivec_extractor.IvecExtractor(model, nb_iters=10, lr=config['eta'], tokenizer=tokenizer) 28 | 29 | with open(args.save, 'wb') as f: 30 | ivec_extractor.save(f) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Karel Benes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/oov-clustering/insert-oovs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import copy 5 | import sys 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--unk', default="") 10 | parser.add_argument('--unk-oi', default="") 11 | parser.add_argument('--oov-list', required=True) 12 | args = parser.parse_args() 13 | 14 | with open(args.oov_list) as f: 15 | oovs = f.read().split() 16 | 17 | oov_counts = {oov: 0 for oov in oovs} 18 | 19 | for line in sys.stdin: 20 | words = line.split() 21 | oov_hidden = [args.unk if w in oovs else w for w in words] 22 | 23 | if oov_hidden == words: 24 | continue 25 | 26 | for i, w in enumerate(words): 27 | if w in oovs: 28 | path_line = copy.deepcopy(oov_hidden) 29 | path_line[i] = args.unk_oi 30 | path_line_str = " ".join(path_line) 31 | 32 | path_key = "{}:{}".format(w, oov_counts[w]) 33 | oov_counts[w] += 1 34 | sys.stdout.write("{} {}\n".format(path_key, path_line_str)) 35 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/threaded.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import queue 3 | 4 | 5 | class DataCreator(threading.Thread): 6 | def __init__(self, q, data_stream, device): 7 | super().__init__(daemon=True) 8 | self.q = q 9 | try: 10 | self.in_stream = iter(data_stream) 11 | except TypeError: 12 | self.in_stream = data_stream 13 | 14 | self.device = device 15 | 16 | def run(self): 17 | # for batch in self.in_stream: 18 | while True: 19 | try: 20 | batch = next(self.in_stream) 21 | except StopIteration: 22 | break 23 | 24 | batch = (x.to(self.device) for x in batch) 25 | self.q.put(batch) 26 | 27 | 28 | class OndemandDataProvider: 29 | def __init__(self, in_data, device): 30 | self.data = in_data 31 | self.device = device 32 | 33 | def __iter__(self): 34 | q = queue.Queue(maxsize=10) 35 | feeder_thread = DataCreator(q, self.data, self.device) 36 | feeder_thread.start() 37 | 38 | while feeder_thread.is_alive() or not q.empty(): 39 | yield q.get() 40 | -------------------------------------------------------------------------------- /brnolm/rmn-plotter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import argparse 5 | import plotting 6 | import numpy as np 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("model") 12 | 13 | args = parser.parse_args() 14 | 15 | with open(args.model, 'rb') as f: 16 | model = torch.load(f) 17 | 18 | cs_numpied = [c.weight.data.numpy() for c in model._cs] 19 | fig_titles = [] 20 | for c in cs_numpied: 21 | norms = np.linalg.norm(c, ord=2, axis=1) 22 | fig_titles.append(", ".join([str(x) for x in [np.min(norms), np.mean(norms), np.max(norms)]])) 23 | plotting.grid_plot(cs_numpied, lambda x: x, "C-Weights", fig_titles) 24 | 25 | plotting.grid_plot(model._cs, lambda c: c.bias.data.numpy()[...,None], "Biases") 26 | 27 | ps_numpied = [p.weight.data.numpy() for p in model._ps] 28 | fig_titles = [] 29 | for p in ps_numpied: 30 | norms = np.linalg.norm(p, ord=2, axis=1) 31 | fig_titles.append(", ".join([str(x) for x in [np.min(norms), np.mean(norms), np.max(norms)]])) 32 | plotting.grid_plot(ps_numpied, lambda x: x, "P-Weights", fig_titles) 33 
| -------------------------------------------------------------------------------- /scripts/oov-clustering/compare-references.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('ref1') 8 | parser.add_argument('ref2') 9 | args = parser.parse_args() 10 | 11 | nb_00 = 0 12 | nb_01 = 0 13 | nb_10 = 0 14 | nb_11 = 0 15 | 16 | with open(args.ref1) as f1, open(args.ref2) as f2: 17 | for l1, l2 in zip(f1, f2): 18 | fields_1 = [int(float(r)) for r in l1.split()] 19 | fields_2 = [int(float(r)) for r in l2.split()] 20 | 21 | for r1, r2 in zip(fields_1, fields_2): 22 | if r1 == 0 and r2 == 0: 23 | nb_00 += 1 24 | elif r1 == 0 and r2 == 1: 25 | nb_01 += 1 26 | elif r1 == 1 and r2 == 0: 27 | nb_10 += 1 28 | elif r1 == 1 and r2 == 1: 29 | nb_11 += 1 30 | 31 | nb_total = nb_00 + nb_01 + nb_10 + nb_11 32 | 33 | print("{:.2f} {:.2f}".format(100.0*nb_00/nb_total, 100.0*nb_01/nb_total)) 34 | print("{:.2f} {:.2f}".format(100.0*nb_10/nb_total, 100.0*nb_11/nb_total)) 35 | -------------------------------------------------------------------------------- /brnolm/oov_clustering/embeddings_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, List 3 | 4 | 5 | def emb_line_iterator(f): 6 | for line in f: 7 | fields = line.split() 8 | key = fields[0] 9 | embedding = np.asarray([float(e) for e in fields[1:]]) 10 | 11 | yield key, embedding 12 | 13 | 14 | def all_embs_from_file(f): 15 | embs = [] 16 | keys = [] 17 | 18 | for key, emb in emb_line_iterator(f): 19 | embs.append(emb) 20 | keys.append(key) 21 | 22 | return keys, np.stack(embs) 23 | 24 | 25 | def str_from_embedding(emb): 26 | return " ".join(["{:.4f}".format(e) for e in emb]) 27 | 28 | 29 | def all_embs_by_key(f, shall_be_collected=lambda w: True, key_transform=lambda w: w): 30 | collection: Dict[str, List[np.ndarray]] = {} 31 | 32 | for word, emb in emb_line_iterator(f): 33 | word = key_transform(word) 34 | 35 | if not shall_be_collected(word): 36 | continue 37 | 38 | if word in collection: 39 | collection[word].append(emb) 40 | else: 41 | collection[word] = [emb] 42 | 43 | for w in collection: 44 | collection[w] = np.stack(collection[w]) 45 | 46 | return collection 47 | -------------------------------------------------------------------------------- /scripts/rescoring/score-combiner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 9 | parser.add_argument('first_scores', help='first set of scores to consider') 10 | parser.add_argument('first_weight', type=float, help='weight of the first set') 11 | parser.add_argument('second_scores', help='second set of scores to consider') 12 | parser.add_argument('second_weight', type=float, help='weight of the second set') 13 | args = parser.parse_args() 14 | 15 | sys.stderr.write("{}\n".format(sys.argv)) 16 | 17 | with open(args.first_scores, 'r') as first_f, open(args.second_scores, 'r') as second_f: 18 | 19 | for first_line, second_line in zip(first_f, second_f): 20 | first_fields = first_line.split() 21 | second_fields = second_line.split() 22 | 23 | assert first_fields[0] == second_fields[0] 24 | 25 | first_s = 
float(first_fields[1]) 26 | second_s = float(second_fields[1]) 27 | 28 | combined_score = first_s * args.first_weight + second_s * args.second_weight 29 | print("{} {:.4f}".format(first_fields[0], combined_score)) 30 | -------------------------------------------------------------------------------- /scripts/train/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | from torch.utils.tensorboard import SummaryWriter 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, log_dir, update_freq): 7 | """Create a summary writer logging to log_dir.""" 8 | self.writer = SummaryWriter(log_dir) 9 | self.update_freq = update_freq 10 | self.step = 0 11 | 12 | def next_step(self): 13 | self.step += 1 14 | 15 | def logging_step(self): 16 | return (self.step+1) % self.update_freq == 0 17 | 18 | def scalar_summary(self, tag, value, enforce=False): 19 | if not enforce and not self.logging_step(): 20 | return 21 | 22 | self.writer.add_scalar(tag, value, self.step) 23 | 24 | def hierarchical_scalar_summary(self, master_tag, tag, value, enforce=False): 25 | if not enforce and not self.logging_step(): 26 | return 27 | 28 | self.writer.add_scalars(master_tag, {tag: value}, self.step) 29 | 30 | def histo_summary(self, tag, values, bins=1000, enforce=False): 31 | if not enforce and not self.logging_step(): 32 | return 33 | 34 | self.writer.add_histogram(tag, values, self.step, max_bins=bins) 35 | -------------------------------------------------------------------------------- /brnolm/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import atexit 3 | 4 | atexit.register(plt.show) 5 | 6 | DEFAULT_COLORING = {'cmap':plt.cm.RdBu, 'vmin':-0.5, 'vmax':0.5, 'color_bar': True} 7 | 8 | def _flip_ord(X, xsize, ysize): 9 | assert len(X) == xsize * ysize 10 | reord_X = [] 11 | for i in range(len(X)): 12 | x_y = i // xsize 13 | x_x = i % xsize 14 | reord_X.append(X[x_x * ysize + x_y]) 15 | 16 | return reord_X 17 | 18 | 19 | def grid_plot(X, numpy_accessor, title, fig_titles=None, 20 | coloring=DEFAULT_COLORING): 21 | 22 | if fig_titles == None: 23 | fig_titles = [""] * len(X) 24 | assert(len(X) == len(fig_titles)) 25 | 26 | reord_X = _flip_ord(X, 5, 3) 27 | 28 | fig, axes = plt.subplots(nrows=3, ncols=5, figsize=(20,20)) 29 | for x, ax, f_title in zip(X, axes.flat, fig_titles): 30 | im = ax.imshow( 31 | numpy_accessor(x), 32 | cmap=coloring['cmap'], 33 | vmin=coloring['vmin'], 34 | vmax=coloring['vmax'], 35 | ) 36 | ax.set_title(f_title, fontsize=8) 37 | ax.axis('off') 38 | if coloring['colorbar']: 39 | fig.colorbar(im, ax=axes.ravel().tolist()) 40 | fig.canvas.set_window_title(title) 41 | plt.show(block=False) 42 | -------------------------------------------------------------------------------- /brnolm/smm_itf/xtract-ivecs-example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # author : KarelB 5 | # e-mail : ibenes AT fit.vutbr.cz 6 | 7 | import argparse 8 | import torch 9 | 10 | import .smm_ivec_extractor 11 | 12 | 13 | def bow_from_sentence(sentence, vocab): 14 | bow = torch.zeros((1, len(vocab))).float() 15 | 16 | for w in sentence.split(): 17 | bow[0, vocab[w]] += 1.0 18 | 19 | return bow 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | 
parser.add_argument("--complete-smm", required=True, help="path to a complete SMM model file") 25 | parser.add_argument('--mkl', default=1, type=int, help='number of MKL threads') 26 | 27 | args = parser.parse_args() 28 | 29 | torch.set_num_threads(args.mkl) 30 | torch.manual_seed(0) 31 | 32 | with open(args.complete_smm, 'rb') as f: 33 | ivec_xtractor = smm_ivec_extractor.load(f) 34 | 35 | s_biology1 = "whale rat elephant hippopotamus bee zoology mammals" 36 | s_biology2 = "flower cat dog insect insect" 37 | s_buildings1 = "bridge tower house chimney" 38 | s_buildings2 = "castle factory architecture wall" 39 | 40 | sentences = [s_biology1, s_biology2, s_buildings1, s_buildings2] 41 | ivecs = [ivec_xtractor(s) for s in sentences] 42 | ivecs = torch.stack(ivecs) 43 | 44 | print(ivecs @ ivecs.t()) 45 | -------------------------------------------------------------------------------- /scripts/migrator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | '''Migrates old LM from before proper brnolm package was introduced. 3 | 4 | Build around the proposition of this SO answer: 5 | https://stackoverflow.com/a/53327348/9703830 6 | 7 | Uses a separate, monkey-patched pickle (`my_pickle`) for de-serialization 8 | in order to ensure that the pure system pickle is ready to serialize the model. 9 | ''' 10 | 11 | import argparse 12 | import importlib 13 | import pickle 14 | import torch 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('source') 20 | parser.add_argument('target') 21 | args = parser.parse_args() 22 | 23 | my_pickle = load_module_extra('pickle') 24 | my_pickle.Unpickler = MyUnpickler 25 | 26 | lm = torch.load(args.source, map_location='cpu', pickle_module=my_pickle) 27 | torch.save(lm, args.target) 28 | 29 | 30 | def load_module_extra(identifier): 31 | spec = importlib.util.find_spec(identifier) 32 | module = importlib.util.module_from_spec(spec) 33 | spec.loader.exec_module(module) 34 | 35 | return module 36 | 37 | 38 | class MyUnpickler(pickle.Unpickler): 39 | def find_class(self, module, name): 40 | renamed_module = module 41 | if module.startswith('language_models'): 42 | renamed_module = 'brnolm.' 
+ module 43 | return super().find_class(renamed_module, name) 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /brnolm/analyze-ivec-distribution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import io 4 | import math 5 | import sys 6 | import torch 7 | import numpy as np 8 | 9 | import split_corpus_dataset 10 | import ivec_appenders 11 | import smm_ivec_extractor 12 | 13 | from runtime_utils import filenames_file_to_filenames 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--file-list', required=True, 18 | help="file with list of files analyze") 19 | parser.add_argument('--ivec-extractor', required=True, 20 | help="iVector extractor to use") 21 | parser.add_argument('--output', required=True, 22 | help="where to put the ivectors") 23 | args = parser.parse_args() 24 | print(args) 25 | 26 | print("loading SMM iVector extractor ...") 27 | with open(args.ivec_extractor, 'rb') as f: 28 | ivec_extractor = smm_ivec_extractor.load(f) 29 | print(ivec_extractor) 30 | 31 | documents = filenames_file_to_filenames(args.file_list) 32 | 33 | ivecs = [] 34 | for doc in documents: 35 | with open(doc) as f: 36 | content = f.read() 37 | 38 | complete_ivec = ivec_extractor(content) 39 | ivecs.append(complete_ivec) 40 | 41 | ivecs = torch.stack(ivecs) 42 | print(ivecs) 43 | 44 | with open(args.output, 'w') as f: 45 | np.savetxt(f, ivecs.cpu().numpy()) 46 | -------------------------------------------------------------------------------- /brnolm/runtime/model_statistics.py: -------------------------------------------------------------------------------- 1 | def scaled_int_str(value): 2 | if value < 1000: 3 | return f'{value}' 4 | elif value < 1000000: 5 | return f'{value/1000:.1f}k' 6 | else: 7 | return f'{value/1000000:.1f}M' 8 | 9 | 10 | class ModelStatistics: 11 | def __init__(self, model): 12 | self.model = model 13 | 14 | def total_nb_params(self): 15 | return sum(p.numel() for p in self.model.parameters()) 16 | 17 | def nb_trainable_params(self): 18 | return sum(p.numel() for p in self.model.parameters() if p.requires_grad) 19 | 20 | def trainable_params_breakup(self): 21 | per_param_desc = (f'{name} {scaled_int_str(p.numel())}\n' for name, p in self.model.named_parameters() if p.requires_grad) 22 | return ''.join(per_param_desc) 23 | 24 | def nb_nontrainable_params(self): 25 | return sum(p.numel() for p in self.model.parameters() if not p.requires_grad) 26 | 27 | def __str__(self): 28 | torch_desc = f'{self.model}\n' 29 | nb_params_desc = f'Total number of parameters: {scaled_int_str(self.total_nb_params())}\n' 30 | nb_trainable_desc = f'Number of trainable parameters: {scaled_int_str(self.nb_trainable_params())}\n' 31 | nb_nontrainable_desc = f'Number of nontrainable parameters: {scaled_int_str(self.nb_nontrainable_params())}\n' 32 | return torch_desc + nb_params_desc + nb_trainable_desc + nb_nontrainable_desc + self.trainable_params_breakup() 33 | -------------------------------------------------------------------------------- /brnolm/files-to-bow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | 5 | import lstm_model 6 | import vocab 7 | import language_model 8 | 9 | import scipy.io as sio 10 | from sklearn.feature_extraction.text import CountVectorizer 11 | 12 | if __name__ == 
'__main__': 13 | parser = argparse.ArgumentParser(description='PyTorch LSTM Language Model') 14 | parser.add_argument('--filelist', type=str, required=True, 15 | help='list of files with documents (one per line)') 16 | parser.add_argument('--load', type=str, required=True, 17 | help='path to corresponding NN LM') 18 | parser.add_argument('--save', type=str, required=True, 19 | help='path to mtx file') 20 | args = parser.parse_args() 21 | 22 | print("loading model...") 23 | with open(args.load, 'rb') as f: 24 | lm = language_model.load(f) 25 | print(lm.model) 26 | vocab = lm.vocab 27 | 28 | documents = [] 29 | with open(args.filelist) as fl: 30 | filenames = fl.read().split() 31 | for filename in filenames: 32 | with open(filename) as f: 33 | documents.append(f.read()) 34 | 35 | cvect = CountVectorizer(documents, analyzer='word', lowercase=False, vocabulary=vocab) 36 | 37 | document_bows = cvect.fit_transform(documents) 38 | vocab = cvect.get_feature_names() 39 | print('document_bows:', document_bows.shape) 40 | 41 | sio.mmwrite(args.save, document_bows.T) 42 | -------------------------------------------------------------------------------- /brnolm/multifile-ml-unigram-ppl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import runtime_utils 6 | import vocab 7 | 8 | import torch 9 | import analysis 10 | 11 | 12 | def bows_to_ps(bows): 13 | uni_ps = bows.t() / bows.sum(dim=1) 14 | return uni_ps.t() 15 | 16 | 17 | def bows_to_ent(bows): 18 | uni_ps = bows_to_ps(bows) 19 | entropies = analysis.categorical_entropy(uni_ps) 20 | 21 | avg_entropy = entropies @ bows.sum(dim=1) / bows.sum() 22 | return avg_entropy 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--file-list') 28 | parser.add_argument('--vocab') 29 | parser.add_argument('--unk', default='') 30 | args = parser.parse_args() 31 | 32 | fns = runtime_utils.filenames_file_to_filenames(args.file_list) 33 | 34 | documents = [] 35 | for fn in fns: 36 | with open(fn) as f: 37 | documents.append(f.read().split()) 38 | 39 | with open(args.vocab) as f: 40 | vocab = vocab.vocab_from_kaldi_wordlist(f, args.unk) 41 | 42 | bows = torch.zeros(len(documents), len(vocab)).long() 43 | 44 | for doc_no, doc in enumerate(documents): 45 | for w in doc: 46 | bows[doc_no, vocab[w]] += 1 47 | 48 | avg_entropy = bows_to_ent(bows.float()) 49 | print("{:.4f} {:.2f}".format(avg_entropy, 2**avg_entropy)) 50 | 51 | bows_combined = bows.sum(dim=0, keepdim=True) 52 | overall_entropy = bows_to_ent(bows_combined.float()) 53 | print("{:.4f} {:.2f}".format(overall_entropy, 2**overall_entropy)) 54 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/masked.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | 4 | 5 | def masked_tensor_from_sentences(sentences: List[List[int]], filler=0, device=torch.device('cpu'), target_all=False): 6 | try: 7 | sentences[0][0] 8 | except TypeError: 9 | raise ValueError("masked_tensor_from_sentences() consumes List of Lists (batch X time)") 10 | 11 | batch_size = len(sentences) 12 | max_len = max(len(s) for s in sentences) 13 | 14 | shape = (batch_size, max_len-1) 15 | dtype = torch.int64 16 | input = torch.zeros(shape, dtype=dtype, device=device) 17 | target = torch.zeros(shape, dtype=dtype, device=device) 18 | mask = torch.zeros(shape, 
dtype=dtype, device=device) 19 | 20 | for s in range(len(sentences)): 21 | for t in range(len(sentences[s]) - 1): 22 | input[s, t] = sentences[s][t] 23 | target[s, t] = sentences[s][t+1] 24 | mask[s, t] = 1 25 | 26 | if target_all: 27 | first_inputs = torch.tensor([s[0] for s in sentences], dtype=dtype, device=device) 28 | target = torch.cat([first_inputs.view(-1, 1), target], dim=1) 29 | 30 | batch_of_ones = torch.ones((batch_size, 1), dtype=mask.dtype, device=mask.device) 31 | mask = torch.cat([batch_of_ones, mask], dim=1) 32 | 33 | if input.shape[1] == 0: 34 | batch_of_zeros = torch.zeros((batch_size, 1), dtype=mask.dtype, device=mask.device) 35 | mask = torch.cat([mask, batch_of_zeros], dim=1) 36 | input = batch_of_zeros.clone().detach() 37 | target = torch.cat([target, batch_of_zeros], dim=1) 38 | 39 | return input, target, mask 40 | -------------------------------------------------------------------------------- /brnolm/investigate-ivecs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | import smm_ivec_extractor 6 | 7 | from runtime_utils import init_seeds, filenames_file_to_filenames 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 12 | parser.add_argument('--filelist', type=str, required=True, 13 | help='file with paths to documents') 14 | parser.add_argument('--seed', type=int, default=1111, 15 | help='random seed') 16 | parser.add_argument('--cuda', action='store_true', 17 | help='use CUDA') 18 | parser.add_argument('--ivec-extractor', type=str, required=True, 19 | help='where to load a ivector extractor from') 20 | args = parser.parse_args() 21 | print(args) 22 | 23 | init_seeds(args.seed, args.cuda) 24 | 25 | print("loading SMM iVector extractor ...") 26 | with open(args.ivec_extractor, 'rb') as f: 27 | ivec_extractor = smm_ivec_extractor.load(f) 28 | print(ivec_extractor) 29 | 30 | print("reading data...") 31 | filenames = filenames_file_to_filenames(args.filelist) 32 | texts = [] 33 | for fn in filenames: 34 | with open(fn) as f: 35 | texts.append(f.read()) 36 | 37 | print("computing iVectors...") 38 | ivecs = [ivec_extractor(t) for t in texts] 39 | ivecs = torch.stack(ivecs) 40 | 41 | print("Elements mean:\t", ivecs.mean()) 42 | print("Elements var:\t", ivecs.var()) 43 | 44 | sq_magnitudes = ivecs.pow(2).sum(1) 45 | print("Sq magn mean:\t", sq_magnitudes.mean()) 46 | print("Sq magn var:\t", sq_magnitudes.var()) 47 | -------------------------------------------------------------------------------- /scripts/oov-clustering/compute-wc-covariance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | from scipy.linalg import fractional_matrix_power 8 | 9 | from brnolm.oov_clustering.embeddings_io import all_embs_by_key 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--show-cov', action='store_true') 14 | parser.add_argument( 15 | '--filter', 16 | help='file with words. 
If present, cov_wc will be computed exclusively from these' 17 | ) 18 | args = parser.parse_args() 19 | 20 | if args.filter: 21 | with open(args.filter) as f: 22 | words_to_collect = f.read().split() 23 | shall_be_collected = lambda w: w in words_to_collect 24 | else: 25 | shall_be_collected = lambda _: True 26 | 27 | collection = all_embs_by_key(sys.stdin, shall_be_collected) 28 | sys.stderr.write( 29 | "INFO: Used total of {} words, {} unique\n".format( 30 | sum(block.shape[0] for block in collection.values()), 31 | len(collection) 32 | ) 33 | ) 34 | 35 | centered = [] 36 | for w in collection: 37 | w_vectors = collection[w] 38 | mean = w_vectors.mean(axis=0) 39 | centered.append(w_vectors - mean) 40 | 41 | all_centered = np.concatenate(centered) 42 | covariance = np.cov(all_centered, rowvar=False) 43 | 44 | whitener = fractional_matrix_power(covariance, -0.5) 45 | np.savetxt(sys.stdout, whitener) 46 | 47 | if args.show_cov: 48 | import matplotlib.pyplot as plt 49 | 50 | plt.figure() 51 | plt.imshow(covariance) 52 | plt.colorbar() 53 | plt.show() 54 | -------------------------------------------------------------------------------- /test/test_data_pipeline/test_reading.py: -------------------------------------------------------------------------------- 1 | from test.common import TestCase 2 | 3 | import io 4 | import torch 5 | 6 | from brnolm.language_models.vocab import Vocabulary 7 | 8 | from brnolm.data_pipeline.reading import get_independent_lines 9 | 10 | 11 | def get_stream(string): 12 | data_source = io.StringIO() 13 | data_source.write(string) 14 | data_source.seek(0) 15 | 16 | return data_source 17 | 18 | 19 | class IndependentSentecesTests(TestCase): 20 | def setUp(self): 21 | self.vocab = Vocabulary('', 0) 22 | self.vocab.add_from_text('a b c') 23 | 24 | def test_single_word(self): 25 | f = get_stream('a\n') 26 | lines = get_independent_lines(f, self.vocab) 27 | self.assertEqual(lines, [torch.tensor([1])]) 28 | 29 | def test_two_words(self): 30 | f = get_stream('a b\n') 31 | lines = get_independent_lines(f, self.vocab) 32 | self.assertEqual(lines, [torch.tensor([1, 2])]) 33 | 34 | def test_two_lines(self): 35 | f = get_stream('a b\nb c a\n') 36 | lines = get_independent_lines(f, self.vocab) 37 | expected = [ 38 | torch.tensor([1, 2]), 39 | torch.tensor([2, 3, 1]), 40 | ] 41 | self.assertEqual(lines, expected) 42 | 43 | def test_empty_line_skipped(self): 44 | f = get_stream('a b\n\nb c a\n') 45 | lines = get_independent_lines(f, self.vocab) 46 | expected = [ 47 | torch.tensor([1, 2]), 48 | torch.tensor([2, 3, 1]), 49 | ] 50 | self.assertEqual(lines, expected) 51 | 52 | def test_empty_file_empty_output(self): 53 | f = get_stream('') 54 | lines = get_independent_lines(f, self.vocab) 55 | self.assertEqual(lines, []) 56 | -------------------------------------------------------------------------------- /scripts/oov-clustering/compute-edit-distance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | 6 | def levenshtein_distance(s1, s2): 7 | ''' taken from stackoverflow 8 | ''' 9 | if len(s1) > len(s2): 10 | s1, s2 = s2, s1 11 | 12 | distances = range(len(s1) + 1) 13 | for i2, c2 in enumerate(s2): 14 | distances_ = [i2+1] 15 | for i1, c1 in enumerate(s1): 16 | if c1 == c2: 17 | distances_.append(distances[i1]) 18 | else: 19 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 20 | distances = distances_ 21 | return distances[-1] 22 | 23 | 24 | if __name__ == 
'__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--lexicon', help='Compute edit dist. in a phonetic space given by the LEXICON') 27 | args = parser.parse_args() 28 | 29 | lexicon = {} 30 | if args.lexicon: 31 | with open(args.lexicon) as f: 32 | for line in f: 33 | fields = line.split() 34 | lexicon[fields[0]] = fields[1:] 35 | 36 | words = sys.stdin.read().split() 37 | unique_pairs = [(words[i], words[j]) for i in range(len(words)) for j in range(i)] 38 | 39 | if args.lexicon: 40 | pair_distances = sorted( 41 | [(a, b, levenshtein_distance(lexicon[a], lexicon[b])) for a, b in unique_pairs], 42 | key=lambda pair_with_dist: pair_with_dist[2] 43 | ) 44 | else: 45 | pair_distances = sorted( 46 | [(a, b, levenshtein_distance(a, b)) for a, b in unique_pairs], 47 | key=lambda pair_with_dist: pair_with_dist[2] 48 | ) 49 | 50 | for a, b, d in pair_distances: 51 | sys.stdout.write("{} {} {}\n".format(a, b, d)) 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BrnoLM 2 | A neural language modeling toolkit built on PyTorch. 3 | 4 | This is a scientific piece of code, so expect rough edges. 5 | 6 | BrnoLM has so far powered language modeling in the following papers: 7 | * Beneš et al. [Text Augmentation for Language Models in High Error Recognition Scenario](https://arxiv.org/pdf/2011.06056.pdf) 8 | * Žmolíková et al. [BUT System for CHiME-6 Challenge](https://www.fit.vutbr.cz/research/groups/speech/publi/2020/zmolikova_CHiME_2020_abstract.pdf) 9 | * Beneš et al. [i-vectors in language modeling: An efficient way of domain adaptation for feed-forward models](http://www.fit.vutbr.cz/research/groups/speech/publi/2018/benes_interspeech2018_1070.pdf) 10 | * Beneš et al. [Unsupervised Language Model Adaptation for Speech Recognition with no Extra Resources](http://www.fit.vutbr.cz/research/groups/speech/publi/2019/benes_DAGA_2019.pdf) 11 | 12 | 13 | ## Installation 14 | To install, clone this repository and exploit the provided `setup.py`, e.g.: 15 | 16 | ``` 17 | git clone git@github.com:BUTSpeechFIT/BrnoLM.git 18 | cd BrnoLM 19 | pip install . # or, if you don't care about environmental pollution: python setup.py install 20 | ``` 21 | 22 | If you want to edit the sources, [pip with `-e`](https://pip.pypa.io/en/stable/reference/pip_install/#editable-installs) or [setup.py develop](https://setuptools.readthedocs.io/en/latest/setuptools.html#development-mode). 23 | 24 | Occasionally, a PIP version is produced, so you can simply `pip install brnolm` to obtain the last pre-packed version. 25 | 26 | 27 | ### Requirements 28 | The above way of installation takes care of dependencies. 29 | If you want to prepare an environment yourself, know that BrnoLM requires: 30 | 31 | ``` 32 | torch 33 | numpy 34 | scikit-learn 35 | ``` 36 | Exact tested versions are provided in `setup.py`. 
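
### Checking a trained model
A quick way to verify the installation and inspect a trained model is the bundled `scripts/model-info.py`, which prints parameter counts via `brnolm.runtime.model_statistics.ModelStatistics`. The equivalent few lines of Python are sketched below; the model path is a placeholder for any LM you have saved with `torch.save`.

```
import torch
from brnolm.runtime.model_statistics import ModelStatistics

lm = torch.load('path/to/your.lm', map_location='cpu')  # placeholder path
print(ModelStatistics(lm))          # layer-by-layer parameter counts
print('Vocab size:', len(lm.vocab))
```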
37 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Substitutor: 5 | def __init__(self, rate, replacements_range): 6 | self.rate = rate 7 | self.replacements_range = replacements_range 8 | if not isinstance(replacements_range, int) or replacements_range < 0: 9 | raise ValueError(f"Replacements range needs to be a positive integer, got {replacements_range}") 10 | 11 | def __call__(self, X, targets): 12 | replacements = torch.randint(0, self.replacements_range, X.shape, device=X.device, dtype=X.dtype) 13 | mask = torch.full(X.shape, self.rate, device=X.device) 14 | mask = torch.bernoulli(mask).long() 15 | 16 | return X * (1-mask) + replacements * mask, targets 17 | 18 | 19 | class Deletor: 20 | def __init__(self, rate): 21 | self.rate = rate 22 | 23 | def __call__(self, X, targets): 24 | timemask = torch.full((X.shape[1], ), 1 - self.rate, device=X.device) 25 | timemask = torch.bernoulli(timemask).bool() 26 | return X[:, timemask], targets[:, timemask] 27 | 28 | 29 | class Corruptor: 30 | def __init__(self, source, substitution_rate=-1.0, replacements_range=None, deletion_rate=-1.0): 31 | self.source = source 32 | if substitution_rate > 0.0: 33 | self.substitutor = Substitutor(substitution_rate, replacements_range) 34 | else: 35 | self.substitutor = None 36 | 37 | if deletion_rate > 0.0: 38 | self.deletor = Deletor(deletion_rate) 39 | else: 40 | self.deletor = None 41 | 42 | def __iter__(self): 43 | for X, t in self.source: 44 | if self.deletor: 45 | X, t = self.deletor(X, t) 46 | if self.substitutor: 47 | X, t = self.substitutor(X, t) 48 | yield X, t 49 | -------------------------------------------------------------------------------- /brnolm/multifile-ivec-unigram-ppl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import runtime_utils 6 | 7 | import torch 8 | 9 | import language_model 10 | import smm_ivec_extractor 11 | 12 | 13 | def bows_to_ps(bows): 14 | uni_ps = bows.t() / bows.sum(dim=1) 15 | return uni_ps.t() 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--file-list') 20 | parser.add_argument('--ivec-lm') 21 | parser.add_argument('--ivec-extractor') 22 | args = parser.parse_args() 23 | 24 | print("loading LM...") 25 | with open(args.ivec_lm, 'rb') as f: 26 | lm = language_model.load(f) 27 | lm.model.cuda() 28 | print(lm.model) 29 | 30 | print("loading SMM iVector extractor ...") 31 | with open(args.ivec_extractor, 'rb') as f: 32 | ivec_extractor = smm_ivec_extractor.load(f) 33 | print(ivec_extractor) 34 | 35 | fns = runtime_utils.filenames_file_to_filenames(args.file_list) 36 | documents = [] 37 | for fn in fns: 38 | with open(fn) as f: 39 | documents.append(f.read().split()) 40 | 41 | bows = torch.zeros(len(documents), len(lm.vocab)).long() 42 | for doc_no, doc in enumerate(documents): 43 | for w in doc: 44 | bows[doc_no, lm.vocab[w]] += 1 45 | unigram_ps = bows_to_ps(bows.float()).cuda() 46 | 47 | cross_entropies = [] 48 | for doc_no, doc in enumerate(documents): 49 | text = " ".join(doc) 50 | ivec = ivec_extractor(text).cuda() 51 | qs = lm.model.ivec_to_logprobs(ivec).data 52 | cross_entropies.append(unigram_ps[doc_no] @ qs) 53 | 54 | cross_entropies = torch.FloatTensor(cross_entropies) 55 | avg_ce = -cross_entropies @ 
bows.float().sum(dim=1) / bows.sum() 56 | 57 | print("{:.4f} {:.2f}".format(avg_ce, math.exp(avg_ce))) 58 | -------------------------------------------------------------------------------- /brnolm/smm_itf/ivec_appenders.py: -------------------------------------------------------------------------------- 1 | from brnolm.runtime.tensor_reorganization import TensorReorganizer 2 | 3 | class CheatingIvecAppender(): 4 | def __init__(self, tokens, ivec_eetor): 5 | """ 6 | Args: 7 | tokens (TokenizedSplit): Source of tokens, represents single 'document'. 8 | """ 9 | self.tokens = tokens 10 | all_words = " ".join(self.tokens.input_words()) 11 | self._ivec = ivec_eetor(all_words) 12 | 13 | 14 | def __iter__(self): 15 | for x, t in self.tokens: 16 | yield (x, t, self._ivec) 17 | 18 | 19 | class HistoryIvecAppender(): 20 | def __init__(self, tokens, ivec_eetor): 21 | """ 22 | Args: 23 | tokens (TokenizedSplit): Source of tokens, represents single 'document'. 24 | """ 25 | self.tokens = tokens 26 | self._ivec_eetor = ivec_eetor 27 | 28 | 29 | def __iter__(self): 30 | history_words = [] 31 | for (x, t), words in zip(self.tokens, self.tokens.input_words()): 32 | ivec = self._ivec_eetor(" ".join(history_words)) 33 | history_words += words.split() 34 | yield (x, t, ivec) 35 | 36 | 37 | class ParalelIvecAppender: 38 | def __init__(self, stream, extractor, translator): 39 | self._stream = stream 40 | self._extractor = extractor 41 | self._translator = translator 42 | self._reorganizer = TensorReorganizer(extractor.zero_bows) 43 | 44 | def __iter__(self): 45 | old_bows = None 46 | for x, t, mask in self._stream: 47 | corresponding_bows = self._reorganizer(old_bows, mask, x.size(0)) 48 | try: 49 | ivectors = self._extractor(corresponding_bows) 50 | except RuntimeError: 51 | print(x.size(), mask, corresponding_bows.size()) 52 | raise 53 | old_bows = corresponding_bows + self._translator(x) 54 | yield x, t, ivectors, mask 55 | -------------------------------------------------------------------------------- /scripts/rescoring/pick-best.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | 7 | def read_latt(f): 8 | line = f.readline() 9 | while line == '\n': 10 | line = f.readline() 11 | 12 | if line == '': 13 | return None, None, None 14 | 15 | fields = line.strip().split('-') 16 | segment_id = '-'.join(fields[:-1]) 17 | trans_id = fields[-1] 18 | 19 | content = "" 20 | line = f.readline() 21 | while line != '\n': 22 | content += line 23 | line = f.readline() 24 | 25 | return segment_id, trans_id, content 26 | 27 | 28 | def read_pick(f): 29 | return tuple(f.readline().split()) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('pick_file') 35 | 36 | args = parser.parse_args() 37 | 38 | tot_printed = 0 39 | unserved = 0 40 | 41 | with open(args.pick_file, 'r') as p_f: 42 | segment, best_trans = read_pick(p_f) 43 | served = False 44 | 45 | while True: 46 | seg_id, trans_id, latt = read_latt(sys.stdin) 47 | if not seg_id: 48 | if p_f.readline() == '': # both files ended 49 | break 50 | else: 51 | raise ValueError("Latts file (stdin) ended sooner than picks file") 52 | 53 | if seg_id != segment: 54 | if not served: 55 | sys.stderr.write("Unserved picks of segment " + segment + ", wanted " + best_trans + "-th transcription \n") 56 | unserved += 1 57 | 58 | segment, best_trans = read_pick(p_f) 59 | served = False 60 | 61 | if trans_id == best_trans: 62 | 
print(segment + "\n" + latt + "\n") 63 | tot_printed += 1 64 | served = True 65 | 66 | if unserved > 0: 67 | sys.stderr.write("ERROR: Unserved " + str(unserved) + " picks \n") 68 | sys.exit(1) 69 | -------------------------------------------------------------------------------- /test/test_runtime/test_evaluation.py: -------------------------------------------------------------------------------- 1 | from test.common import TestCase 2 | import math 3 | import torch 4 | 5 | from brnolm.runtime.evaluation import get_oov_additional_cost 6 | from brnolm.runtime.evaluation import OovCostApplicator 7 | 8 | 9 | class OovCostTests(TestCase): 10 | def test_simple(self): 11 | oov_cost = get_oov_additional_cost(100, 1000) 12 | expected = -math.log(1.0/900) 13 | self.assertEqual(oov_cost, expected) 14 | 15 | 16 | class OovCostApplicatorTests(TestCase): 17 | def setUp(self): 18 | self.cost_applier = OovCostApplicator(1.0, 0) 19 | 20 | def test_void(self): 21 | line_ids = torch.tensor([2, 1]) 22 | losses = torch.tensor([0.5, 0.2]) 23 | self.assertEqual(self.cost_applier(line_ids, losses), losses) 24 | 25 | def test_single_line_single_oov(self): 26 | line_ids = torch.tensor([2, 0, 1]) 27 | losses = torch.tensor([0.5, 0.7, 0.2]) 28 | adjusted_losses = torch.tensor([0.5, 1.7, 0.2]) 29 | self.assertEqual(self.cost_applier(line_ids, losses), adjusted_losses) 30 | 31 | def test_single_line_multiple_oovs(self): 32 | line_ids = torch.tensor([0, 2, 0, 1]) 33 | losses = torch.tensor([0.3, 0.5, 0.7, 0.2]) 34 | adjusted_losses = torch.tensor([1.3, 0.5, 1.7, 0.2]) 35 | self.assertEqual(self.cost_applier(line_ids, losses), adjusted_losses) 36 | 37 | def test_non_matching_len(self): 38 | line_ids = torch.tensor([0, 2, 1]) 39 | losses = torch.tensor([0.3, 0.5, 0.7, 0.0]) 40 | adjusted_losses = torch.tensor([1.3, 0.5, 0.7, 0.0]) 41 | self.assertEqual(self.cost_applier(line_ids, losses), adjusted_losses) 42 | 43 | def test_zero_penalty(self): 44 | zero_applier = OovCostApplicator(0.0, 0) 45 | line_ids = torch.tensor([0, 2, 1]) 46 | losses = torch.tensor([0.3, 0.5, 0.7, 0.0]) 47 | adjusted_losses = torch.tensor([0.3, 0.5, 0.7, 0.0]) 48 | self.assertEqual(zero_applier(line_ids, losses), adjusted_losses) 49 | -------------------------------------------------------------------------------- /scripts/rescoring/plot-2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def parse_line(line): 8 | # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 9 | # %WER 10.35 [ 852 / 8234, 77 ins, 136 del, 639 sub ] /path/pick-1.0-17.0-20-0.5/wer_1_0.0 10 | 11 | fields = line.split() 12 | wer = fields[1] 13 | nb_errs = fields[3] 14 | nb_words = fields[5][:-1] 15 | nb_ins = fields[6] 16 | nb_dels = fields[8] 17 | nb_subs = fields[10] 18 | vals = { 19 | 'wer': float(wer), 20 | 'nb_errs': int(nb_errs), 21 | 'nb_words': int(nb_words), 22 | 'nb_ins': int(nb_ins), 23 | 'nb_dels': int(nb_dels), 24 | 'nb_subs': int(nb_subs), 25 | } 26 | 27 | fn = fields[13] 28 | 29 | parts_of_path = fn.split('/') 30 | assert parts_of_path[-1] == 'wer_1_0.0' 31 | 32 | coords_field = parts_of_path[-2] 33 | coords_fields = coords_field.split('-') 34 | assert coords_fields[0] == 'pick' 35 | 36 | coords = tuple(float(c) for c in coords_fields[1:]) 37 | 38 | return coords, vals 39 | 40 | 41 | if __name__ == '__main__': 42 | coord1 = 2 43 | coord2 = 3 44 | key = 'nb_errs' 45 | 46 | default_coords = (1.0, 17.0, 17.0, 0.0) 47 | 48 | measurements = {} 49 | for line in 
sys.stdin: 50 | coords, vals = parse_line(line) 51 | measurements[coords] = vals 52 | 53 | default_vals = measurements[default_coords] 54 | 55 | xs = [] 56 | ys = [] 57 | cs = [] 58 | for k in measurements: 59 | xs.append(k[coord1]) 60 | ys.append(k[coord2]) 61 | cs.append(measurements[k][key] - default_vals[key]) 62 | 63 | max_range = max(abs(x) for x in cs) 64 | print(max_range) 65 | 66 | plt.figure() 67 | # plt.scatter(xs, ys, c=cs, cmap='seismic', vmin=-max_range, vmax=max_range) 68 | plt.tricontourf(xs, ys, cs, 100, cmap='seismic', vmin=-max_range, vmax=max_range) 69 | plt.plot(xs, ys, 'ko ', markersize=1.0) 70 | plt.show() 71 | -------------------------------------------------------------------------------- /scripts/oov-clustering/collect-embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | import torch 7 | 8 | from brnolm.oov_clustering.embeddings_io import str_from_embedding 9 | from brnolm.oov_clustering.embeddings_computation import tensor_from_words 10 | 11 | 12 | def embs_from_words(words, lm): 13 | words = [""] + words 14 | th_data = tensor_from_words(words, lm.vocab)[:, :-1] 15 | h0 = lm.model.init_hidden(th_data.size(0)) 16 | 17 | if not lm.model.batch_first: 18 | th_data = th_data.t() 19 | emb, h = lm.model(th_data, h0) 20 | return [an_embedding[0].detach() for an_embedding in emb] 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--fwd-lm') 26 | parser.add_argument('--bwd-lm') 27 | args = parser.parse_args() 28 | 29 | if not args.fwd_lm and not args.bwd_lm: 30 | sys.stderr.write("At least one of '--fwd-lm' and '--bwd-lm' needs to be specified\n") 31 | sys.exit(1) 32 | 33 | if args.fwd_lm: 34 | fwd_lm = torch.load(args.fwd_lm, map_location=lambda storage, location: storage) 35 | fwd_lm.eval() 36 | if args.bwd_lm: 37 | bwd_lm = torch.load(args.bwd_lm, map_location=lambda storage, location: storage) 38 | bwd_lm.eval() 39 | 40 | vocabulary = fwd_lm.vocab if args.fwd_lm else bwd_lm.vocab 41 | 42 | for line in sys.stdin: 43 | words = line.split() 44 | 45 | data_cols = [words] 46 | 47 | if args.fwd_lm: 48 | fwd_embs = embs_from_words(words, fwd_lm) 49 | fwd_embs_strs = [str_from_embedding(emb) for emb in fwd_embs] 50 | data_cols.append(fwd_embs_strs) 51 | 52 | if args.bwd_lm: 53 | bwd_embs = reversed(list(embs_from_words(list(reversed(words)), bwd_lm))) 54 | bwd_embs_strs = [str_from_embedding(emb) for emb in bwd_embs] 55 | data_cols.append(bwd_embs_strs) 56 | 57 | for data_row in zip(*data_cols): 58 | elem_strs = ["{}".format(elem) for elem in data_row] 59 | sys.stdout.write(" ".join(elem_strs) + "\n") 60 | -------------------------------------------------------------------------------- /brnolm/runtime/reporting.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import torch 3 | import math 4 | 5 | 6 | class ValidationWatcher: 7 | def __init__(self, val_fn, initial_val_loss, freq_in_tokens, workdir, lm, lr_control=None): 8 | self.val_losses = [initial_val_loss] 9 | self.validation_fn = val_fn 10 | self.lm = lm 11 | 12 | pathlib.Path(workdir).mkdir(parents=True, exist_ok=True) 13 | 14 | assert(freq_in_tokens > 0) 15 | self.freq = freq_in_tokens 16 | 17 | self.report_fn = '{}/validation-report.txt'.format(workdir) 18 | self.best_model_fn = '{}/best.lm'.format(workdir) 19 | self.latest_model_fn = '{}/latest.lm'.format(workdir) 20 | 21 | self.running_loss 
= 0.0 22 | self.running_targets = 0 23 | self.running_updates = 0 24 | 25 | self.nb_total_updates = 0 26 | 27 | self.lr_control = lr_control 28 | 29 | def log_training_update(self, loss, nb_targets): 30 | self.running_loss += loss 31 | self.running_targets += nb_targets 32 | self.running_updates += 1 33 | self.nb_total_updates += 1 34 | 35 | if self.running_targets > self.freq: 36 | val_loss = self.run_validation() 37 | 38 | running_ppl = math.exp(self.running_loss / self.running_updates) 39 | val_ppl = math.exp(val_loss) 40 | 41 | desc = '{} updates: {:.2f} {:.2f} {:.3f}\n'.format( 42 | self.nb_total_updates, running_ppl, val_ppl, val_ppl - running_ppl 43 | ) 44 | with open(self.report_fn, 'a') as f: 45 | f.write(desc) 46 | 47 | torch.save(self.lm, self.latest_model_fn) 48 | if min(self.val_losses) == self.val_losses[-1]: 49 | torch.save(self.lm, self.best_model_fn) 50 | 51 | self.running_loss = 0.0 52 | self.running_targets = 0 53 | self.running_updates = 0 54 | 55 | def run_validation(self): 56 | val_loss = self.validation_fn() 57 | self.val_losses.append(val_loss) 58 | 59 | if self.lr_control is not None: 60 | self.lr_control.step(val_loss) 61 | 62 | return val_loss 63 | -------------------------------------------------------------------------------- /scripts/eval/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import logging 5 | import math 6 | import torch 7 | from safe_gpu.safe_gpu import GPUOwner 8 | 9 | from brnolm.runtime.runtime_utils import init_seeds 10 | from brnolm.runtime.evaluation import EnblockEvaluator 11 | 12 | 13 | def main(args): 14 | print(args) 15 | 16 | init_seeds(args.seed, args.cuda) 17 | 18 | print("loading model...") 19 | device = torch.device('cuda') if args.cuda else torch.device('cpu') 20 | lm = torch.load(args.load, map_location=device) 21 | print(lm) 22 | 23 | evaluator = EnblockEvaluator( 24 | lm, 25 | args.data, 26 | args.batch_size, 27 | args.target_seq_len, 28 | tokenize_regime='chars' if args.characters else 'words', 29 | ) 30 | eval_report = evaluator.evaluate() 31 | 32 | print('total loss {:.1f} | per token loss {:5.2f} | ppl {:8.2f}'.format(eval_report.total_loss, eval_report.loss_per_token, math.exp(eval_report.loss_per_token))) 33 | 34 | 35 | if __name__ == '__main__': 36 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s') 37 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 38 | parser.add_argument('--data', type=str, required=True, 39 | help='location of the data corpus') 40 | parser.add_argument('--characters', action='store_true', 41 | help='Treat the file as containing character tokens') 42 | 43 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 44 | help='batch size') 45 | parser.add_argument('--target-seq-len', type=int, default=35, 46 | help='sequence length') 47 | 48 | parser.add_argument('--seed', type=int, default=1111, 49 | help='random seed') 50 | parser.add_argument('--cuda', action='store_true', 51 | help='use CUDA') 52 | parser.add_argument('--load', type=str, required=True, 53 | help='where to load a model from') 54 | args = parser.parse_args() 55 | 56 | if args.cuda: 57 | gpu_owner = GPUOwner() 58 | 59 | main(args) 60 | -------------------------------------------------------------------------------- /brnolm/multifile-ml-unigram-tranfer-ppl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | 3 | import argparse 4 | import math 5 | import runtime_utils 6 | import vocab 7 | 8 | import torch 9 | import analysis 10 | 11 | 12 | def bows_to_ps(bows): 13 | uni_ps = bows.t() / bows.sum(dim=1) 14 | return uni_ps.t() 15 | 16 | 17 | def bows_to_ent(bows): 18 | uni_ps = bows_to_ps(bows) 19 | entropies = analysis.categorical_entropy(uni_ps) 20 | 21 | avg_entropy = entropies @ bows.sum(dim=1) / bows.sum() 22 | return avg_entropy 23 | 24 | def documents_from_fn(fn_filelist): 25 | fns = runtime_utils.filenames_file_to_filenames(fn_filelist) 26 | documents = [] 27 | for fn in fns: 28 | with open(fn) as f: 29 | documents.append(f.read().split()) 30 | 31 | return documents 32 | 33 | def bow_from_documents(documents, vocab): 34 | bows = torch.zeros(len(documents), len(vocab)).long() 35 | for doc_no, doc in enumerate(documents): 36 | for w in doc: 37 | bows[doc_no, vocab[w]] += 1 38 | 39 | return bows 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--source-list') 45 | parser.add_argument('--file-list') 46 | parser.add_argument('--vocab') 47 | parser.add_argument('--unk', default='') 48 | args = parser.parse_args() 49 | 50 | with open(args.vocab) as f: 51 | vocab = vocab.vocab_from_kaldi_wordlist(f, args.unk) 52 | 53 | documents = documents_from_fn(args.source_list) 54 | bows = bow_from_documents(documents, vocab).float() 55 | unigram_ps = bows_to_ps(bows.sum(dim=0, keepdim=True)).squeeze() 56 | 57 | test_documents = documents_from_fn(args.file_list) 58 | test_bows = bow_from_documents(test_documents, vocab).float() 59 | test_unigrams = bows_to_ps(test_bows) 60 | 61 | print(unigram_ps.size()) 62 | print(test_unigrams.size()) 63 | 64 | cross_entropies = analysis.categorical_cross_entropy(test_unigrams, unigram_ps) 65 | # print(cross_entropies) 66 | 67 | test_lengths = test_bows.sum(dim=1) 68 | avg_entropy = cross_entropies @ test_lengths / test_bows.sum() 69 | print("{:.4f} {:.2f}".format(avg_entropy, 2**avg_entropy)) 70 | -------------------------------------------------------------------------------- /brnolm/runtime/tensor_reorganization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import Dict, Any 4 | 5 | 6 | class Singleton(type): 7 | _instances: Dict[Any, Any] = {} # TODO what is the actual type? 
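    # (presumably Dict[type, Any] would be the tighter annotation: the metaclass maps
    #  each class that uses it to its single instance, e.g. InfiniNoneType() is InfiniNoneType() -> True)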
8 | 9 | def __call__(cls, *args, **kwargs): 10 | if cls not in cls._instances: 11 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 12 | return cls._instances[cls] 13 | 14 | 15 | class InfiniNoneType(metaclass=Singleton): 16 | def __eq__(self, other): 17 | return other is None or isinstance(other, InfiniNone) 18 | 19 | def __iter__(self): 20 | while True: 21 | yield InfiniNoneType() 22 | 23 | 24 | InfiniNone = InfiniNoneType() 25 | 26 | 27 | def reorg_single(orig, mask, new=None): 28 | reorg = torch.index_select(orig, dim=-2, index=mask) 29 | if new is not InfiniNone: 30 | reorg = torch.cat([reorg, new], dim=-2) 31 | 32 | return reorg 33 | 34 | 35 | class TensorReorganizer(): 36 | def __init__(self, zeros_provider): 37 | self._zeros_provider = zeros_provider 38 | 39 | def __call__(self, orig, mask, batch_size): 40 | if len(mask) == 0: 41 | return self._zeros_provider(batch_size) 42 | 43 | if mask.size(0) > batch_size: 44 | raise ValueError("Cannot reorganize mask {} to batch size {}".format(mask, batch_size)) 45 | 46 | if isinstance(orig, tuple): 47 | single_var = False 48 | elif isinstance(orig, torch.Tensor): 49 | single_var = True 50 | else: 51 | raise TypeError( 52 | "orig has unsupported type {}, " 53 | "only tuples and Tensors are accepted".format( 54 | orig.__class__ 55 | ) 56 | ) 57 | 58 | adding = mask.size(0) < batch_size 59 | if adding: 60 | nb_needed_new = batch_size - mask.size(0) 61 | new = self._zeros_provider(nb_needed_new) 62 | else: 63 | new = InfiniNone 64 | 65 | if single_var: 66 | reorg = reorg_single(orig, mask, new) 67 | else: 68 | reorg = tuple(reorg_single(o, mask, n) for o, n in zip(orig, new)) 69 | 70 | return reorg 71 | -------------------------------------------------------------------------------- /scripts/model-building/build-shallow-nn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | from brnolm.language_models import ffnn_models, vocab, language_model 6 | from brnolm.language_models.decoders import FullSoftmaxDecoder 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='PyTorch FFNN Language Model') 11 | parser.add_argument('--wordlist', type=str, required=True, 12 | help='word -> int map; Kaldi style "words.txt"') 13 | parser.add_argument('--unk', type=str, default="", 14 | help='expected form of "unk" word. Most likely a or ') 15 | parser.add_argument('--emsize', type=int, default=200, 16 | help='size of word embeddings') 17 | parser.add_argument('--nhid', type=int, default=200, 18 | help='number of hidden units per layer') 19 | parser.add_argument('--hist-len', type=int, default=2, 20 | help='number of input words. If n-grams are being modelled, then (n-1)') 21 | parser.add_argument('--dropout', type=float, default=0.2, 22 | help='dropout applied to layers (0 = no dropout)') 23 | parser.add_argument('--tied', action='store_true', 24 | help='tie the word embedding and softmax weights') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--save', type=str, required=True, 28 | help='path to save the final model') 29 | args = parser.parse_args() 30 | 31 | # Set the random seed manually for reproducibility. 
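    # (the seed here only affects the random initialisation of the model and decoder weights; this script does no training)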
32 | torch.manual_seed(args.seed) 33 | 34 | print("loading vocabulary...") 35 | with open(args.wordlist, 'r') as f: 36 | vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) 37 | 38 | print("building model...") 39 | 40 | model = ffnn_models.BengioModel( 41 | len(vocabulary), args.emsize, args.hist_len, 42 | args.nhid, args.dropout 43 | ) 44 | 45 | decoder = FullSoftmaxDecoder(args.nhid, len(vocabulary)) 46 | 47 | lm = language_model.LanguageModel(model, decoder, vocabulary) 48 | torch.save(lm, args.save) 49 | -------------------------------------------------------------------------------- /scripts/oov-clustering/evaluate-embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | from scipy.spatial.distance import pdist, squareform 8 | 9 | from brnolm.oov_clustering.embeddings_io import all_embs_from_file 10 | from brnolm.oov_clustering.det import DETCurve 11 | 12 | 13 | def trial_scores_list(keys, similarities): 14 | score_tg = [] 15 | for i in range(len(keys)): 16 | for j in range(i+1, len(keys)): 17 | a = keys[i].split(':')[0] 18 | b = keys[j].split(':')[0] 19 | 20 | score = similarities[i, j] 21 | 22 | if a == b: 23 | score_tg.append((score, 1)) 24 | else: 25 | score_tg.append((score, 0)) 26 | 27 | return score_tg 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--length-norm', action='store_true') 33 | parser.add_argument('--log-det', action='store_true') 34 | parser.add_argument('--eps', type=float, default=1e-3, help='to prevent log of zero') 35 | parser.add_argument('--plot', help='where to store the DET curve') 36 | parser.add_argument('--baseline', action='store_true') 37 | parser.add_argument('--free-axis', action='store_true') 38 | parser.add_argument('--eer-line', action='store_true') 39 | parser.add_argument('--metric', default='inner_prod', choices=['inner_prod', 'l2_dist']) 40 | args = parser.parse_args() 41 | 42 | keys, embs = all_embs_from_file(sys.stdin) 43 | if args.length_norm: 44 | embs /= np.linalg.norm(embs, axis=1)[:, None] 45 | 46 | if args.metric == 'inner_prod': 47 | similarities = embs @ embs.T 48 | elif args.metric == 'l2_dist': 49 | similarities = -squareform(pdist(embs)) 50 | 51 | if args.plot: 52 | import matplotlib.pyplot as plt 53 | plt.figure() 54 | plt.imshow(similarities) 55 | plt.colorbar() 56 | 57 | score_tg = trial_scores_list(keys, similarities) 58 | 59 | det = DETCurve(score_tg, args.baseline, max_det_points=20) 60 | sys.stdout.write(det.textual_report()) 61 | if args.plot is not None: 62 | det.plot(args.log_det, not args.free_axis, args.eer_line, args.plot) 63 | -------------------------------------------------------------------------------- /scripts/eval/eval-multifile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import torch 6 | 7 | from brnolm.data_pipeline.multistream import BatchBuilder 8 | 9 | from brnolm.data_pipeline.reading import tokens_from_file 10 | from brnolm.data_pipeline.temporal_splitting import TemporalSplits 11 | from brnolm.runtime.runtime_utils import CudaStream, init_seeds, filelist_to_objects 12 | from brnolm.runtime.runtime_multifile import evaluate 13 | 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 17 | parser.add_argument('--file-list', type=str, 
required=True, 18 | help='file with paths to training documents') 19 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 20 | help='batch size') 21 | parser.add_argument('--target-seq-len', type=int, default=35, 22 | help='sequence length') 23 | parser.add_argument('--seed', type=int, default=1111, 24 | help='random seed') 25 | parser.add_argument('--cuda', action='store_true', 26 | help='use CUDA') 27 | parser.add_argument('--concat-articles', action='store_true', 28 | help='pass hidden states over article boundaries') 29 | parser.add_argument('--load', type=str, required=True, 30 | help='where to load a model from') 31 | args = parser.parse_args() 32 | print(args) 33 | 34 | init_seeds(args.seed, args.cuda) 35 | 36 | print("loading model...") 37 | lm = torch.load(args.load) 38 | if args.cuda: 39 | lm.cuda() 40 | print(lm.model) 41 | 42 | print("preparing data...") 43 | 44 | def temp_splits_from_fn(fn): 45 | tokens = tokens_from_file(fn, lm.vocab, randomize=False) 46 | return TemporalSplits(tokens, lm.model.in_len, args.target_seq_len) 47 | 48 | tss = filelist_to_objects(args.file_list, temp_splits_from_fn) 49 | data = BatchBuilder(tss, args.batch_size, 50 | discard_h=not args.concat_articles) 51 | if args.cuda: 52 | data = CudaStream(data) 53 | 54 | loss = evaluate(lm, data, use_ivecs=False) 55 | print('loss {:5.2f} | ppl {:8.2f}'.format(loss, math.exp(loss))) 56 | -------------------------------------------------------------------------------- /scripts/eval/eval-independent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import logging 5 | import math 6 | import torch 7 | 8 | from brnolm.runtime.runtime_utils import init_seeds 9 | from brnolm.runtime.evaluation import IndependentLinesEvaluator 10 | 11 | 12 | def main(): 13 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s') 14 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, required=True, 16 | help='location of the data corpus') 17 | parser.add_argument('--prefix', type=str, 18 | help='') 19 | parser.add_argument('--total-vocab-size', type=int, 20 | help='how many words should be assumed to exist overall') 21 | 22 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 23 | help='batch size') 24 | parser.add_argument('--max-tokens', type=int, default=1000, metavar='N', 25 | help='Maximal number of softmaxes in a batch') 26 | 27 | parser.add_argument('--seed', type=int, default=1111, 28 | help='random seed') 29 | parser.add_argument('--cuda', action='store_true', 30 | help='use CUDA') 31 | parser.add_argument('--load', type=str, required=True, 32 | help='where to load a model from') 33 | args = parser.parse_args() 34 | print(args) 35 | 36 | init_seeds(args.seed, args.cuda) 37 | 38 | print("loading model...") 39 | lm = torch.load(args.load, map_location='cpu') 40 | lm.nb_nonzero_masks = 0 41 | lm.eval() 42 | if args.cuda: 43 | lm.cuda() 44 | print(lm) 45 | 46 | evaluator = IndependentLinesEvaluator( 47 | lm=lm, 48 | fn_evalset=args.data, 49 | max_batch_size=args.batch_size, 50 | max_tokens=args.max_tokens, 51 | total_vocab_size=args.total_vocab_size 52 | ) 53 | eval_report = evaluator.evaluate(args.prefix) 54 | 55 | print(f'Utilization: {100.0*eval_report.utilization:.2f} %') 56 | print('total loss {:.1f} | per token loss {:5.2f} | ppl {:8.2f}'.format(eval_report.total_loss, 
eval_report.loss_per_token, math.exp(eval_report.loss_per_token))) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /scripts/oov-clustering/predict-embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | import torch 6 | 7 | from brnolm.oov_clustering.embeddings_io import str_from_embedding 8 | from brnolm.oov_clustering.embeddings_computation import tensor_from_words 9 | 10 | 11 | def relevant_prefix(transcript, word_of_interest): 12 | first_oov_oi_loc = transcript.index(word_of_interest) 13 | if word_of_interest in transcript[first_oov_oi_loc+1:]: 14 | raise ValueError("there are multiple OOVs of interest!") 15 | 16 | return transcript[:first_oov_oi_loc] 17 | 18 | 19 | BATCH_SIZE = 1 20 | 21 | 22 | def emb_from_string(transcript, lm): 23 | prefix = relevant_prefix(transcript, args.unk_oi) 24 | prefix = [""] + prefix 25 | 26 | th_data = tensor_from_words(prefix, lm.vocab) 27 | h0 = lm.model.init_hidden(th_data.size(0)) 28 | 29 | if not lm.model.batch_first: 30 | th_data = th_data.t() 31 | emb, h = lm.model(th_data, h0) 32 | out_emb = emb[0][-1].data 33 | 34 | return out_emb 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--unk', default="") 40 | parser.add_argument('--unk-oi', default="") 41 | parser.add_argument('--fwd-lm') 42 | parser.add_argument('--bwd-lm') 43 | args = parser.parse_args() 44 | 45 | if not args.fwd_lm and not args.bwd_lm: 46 | sys.stderr.write("At least one of '--fwd-lm' and '--bwd-lm' needs to be specified\n") 47 | sys.exit(1) 48 | 49 | if args.fwd_lm: 50 | fwd_lm = torch.load(args.fwd_lm, map_location=lambda storage, location: storage) 51 | fwd_lm.eval() 52 | if args.bwd_lm: 53 | bwd_lm = torch.load(args.bwd_lm, map_location=lambda storage, location: storage) 54 | bwd_lm.eval() 55 | 56 | for line_no, line in enumerate(sys.stdin): 57 | fields = line.split() 58 | key = fields[0] 59 | transcript = fields[1:] 60 | 61 | output = key 62 | if args.fwd_lm: 63 | fwd_emb = emb_from_string(transcript, fwd_lm) 64 | output += " " + str_from_embedding(fwd_emb) 65 | if args.bwd_lm: 66 | bwd_emb = emb_from_string(list(reversed(transcript)), bwd_lm) 67 | output += " " + str_from_embedding(bwd_emb) 68 | output += '\n' 69 | 70 | sys.stdout.write(output) 71 | -------------------------------------------------------------------------------- /scripts/model-building/build-shallow-nn-with-ivec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | from brnolm.language_models import language_model, vocab, ffnn_models 6 | from brnolm.language_models.decoders import FullSoftmaxDecoder 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='PyTorch FFNN Language Model') 11 | parser.add_argument('--wordlist', type=str, required=True, 12 | help='word -> int map; Kaldi style "words.txt"') 13 | parser.add_argument('--unk', type=str, default="", 14 | help='expected form of "unk" word. 
Most likely a or ') 15 | parser.add_argument('--emsize', type=int, default=200, 16 | help='size of word embeddings') 17 | parser.add_argument('--nhid', type=int, default=200, 18 | help='number of hidden units per layer') 19 | parser.add_argument('--ivec-dim', type=int, required=True, 20 | help='ivector dimensionality') 21 | parser.add_argument('--hist-len', type=int, default=2, 22 | help='number of input words. If n-grams are being modelled, then (n-1)') 23 | parser.add_argument('--dropout', type=float, default=0.2, 24 | help='dropout applied to layers (0 = no dropout)') 25 | parser.add_argument('--tied', action='store_true', 26 | help='tie the word embedding and softmax weights') 27 | parser.add_argument('--seed', type=int, default=1111, 28 | help='random seed') 29 | parser.add_argument('--save', type=str, required=True, 30 | help='path to save the final model') 31 | args = parser.parse_args() 32 | 33 | # Set the random seed manually for reproducibility. 34 | torch.manual_seed(args.seed) 35 | 36 | print("loading vocabulary...") 37 | with open(args.wordlist, 'r') as f: 38 | vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) 39 | 40 | print("building model...") 41 | 42 | model = ffnn_models.BengioModelIvecInput( 43 | len(vocabulary), args.emsize, args.hist_len, 44 | args.nhid, args.dropout, args.ivec_dim 45 | ) 46 | 47 | decoder = FullSoftmaxDecoder(args.nhid, len(vocabulary)) 48 | 49 | lm = language_model.LanguageModel(model, decoder, vocabulary) 50 | torch.save(lm, args.save) 51 | -------------------------------------------------------------------------------- /scripts/rescoring/rescoring-combine-scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | 5 | import brnolm.kaldi_itf as kaldi_itf 6 | 7 | 8 | def dict_argmin(dict): 9 | return min(dict, key=dict.get) 10 | 11 | 12 | def write_best(scores, key, out_f): 13 | best = dict_argmin(scores) 14 | out_f.write(key + ' ' + best + '\n') 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 19 | parser.add_argument('--ac-scale', type=float, default=1.0, help='weight of acoustic score') 20 | parser.add_argument('--gr-scale', type=float, required=True, help='weight of graph scores') 21 | parser.add_argument('--lm-scale', type=float, required=True, help='weight of rnnlm scores') 22 | parser.add_argument('acoustic_scores', help='file with acoustic scores') 23 | parser.add_argument('graph_scores', help='file with graph scores') 24 | parser.add_argument('rnnlm_scores', help='file with rnnlm scores') 25 | parser.add_argument('out_filename', help='where to put the best picked transcripts') 26 | args = parser.parse_args() 27 | 28 | print(args) 29 | 30 | curr_seg = None 31 | segment_utts_scores = {} 32 | 33 | with open(args.acoustic_scores, 'r') as ac_f, \ 34 | open(args.graph_scores, 'r') as gr_f, \ 35 | open(args.rnnlm_scores, 'r') as lm_f, \ 36 | open(args.out_filename, 'w') as out_f: 37 | 38 | for ac_line, gr_line, lm_line in zip(ac_f, gr_f, lm_f): 39 | ac_fields = ac_line.split() 40 | gr_fields = gr_line.split() 41 | lm_fields = lm_line.split() 42 | 43 | assert ac_fields[0] == gr_fields[0] and gr_fields[0] == lm_fields[0] 44 | segment, trans_id = kaldi_itf.split_nbest_key(ac_fields[0]) 45 | 46 | if not curr_seg: 47 | curr_seg = segment 48 | 49 | if segment != curr_seg: 50 | write_best(segment_utts_scores, curr_seg, out_f) 51 | 52 | curr_seg = segment 53 | segment_utts_scores = {} 
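                # start a fresh score table for the newly encountered segment; the weighted sums
                # collected below act as costs, so write_best/dict_argmin picks the transcription
                # with the lowest combined score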
54 | 55 | ac_s = float(ac_fields[1]) 56 | gr_s = float(gr_fields[1]) 57 | lm_s = float(lm_fields[1]) 58 | 59 | segment_utts_scores[trans_id] = ( 60 | args.ac_scale * ac_s + 61 | args.gr_scale * gr_s + 62 | args.lm_scale * lm_s 63 | ) 64 | 65 | write_best(segment_utts_scores, curr_seg, out_f) 66 | -------------------------------------------------------------------------------- /scripts/corpus-stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import torch 5 | import sys 6 | import collections 7 | 8 | from brnolm.data_pipeline.reading import tokens_from_fn, word_splitter, char_splitter 9 | 10 | 11 | def get_oovs(fn, regime, vocab): 12 | with open(fn) as f: 13 | lines = f.read().split('\n') 14 | 15 | oov_counts = collections.defaultdict(int) 16 | 17 | if regime == 'words': 18 | tokenizer = word_splitter 19 | elif regime == 'chars': 20 | tokenizer = lambda line: char_splitter(line, '') 21 | else: 22 | raise ValueError("unsupported regime {}".format(regime)) 23 | 24 | for line in lines: 25 | tokens = tokenizer(line) 26 | for tok in tokens: 27 | if vocab[tok] == vocab.unk_ind: 28 | oov_counts[tok] += 1 29 | 30 | return oov_counts 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 35 | parser.add_argument('--characters', action='store_true', 36 | help='work on character level, whitespace is significant') 37 | parser.add_argument('--lm', type=str, required=True, 38 | help='where to load a model from') 39 | parser.add_argument('--dump-oovs', type=str, 40 | help='where to write oovs ("-" for stdout)') 41 | parser.add_argument('train', type=str, 42 | help='location of the train corpus') 43 | args = parser.parse_args() 44 | 45 | lm = torch.load(args.lm, map_location='cpu') 46 | 47 | tokenize_regime = 'words' 48 | if args.characters: 49 | tokenize_regime = 'chars' 50 | 51 | train_ids = tokens_from_fn(args.train, lm.vocab, randomize=False, regime=tokenize_regime) 52 | sys.stdout.write('Vocabulary size: {}\n'.format(len(lm.vocab))) 53 | 54 | nb_tokens = len(train_ids) 55 | sys.stdout.write('Nb tokens: {}\n'.format(nb_tokens)) 56 | 57 | oov_mask = train_ids == lm.vocab.unk_ind 58 | nb_oovs = oov_mask.sum() 59 | 60 | sys.stdout.write('Nb oovs: {} ({:.2f} %)\n'.format(nb_oovs, 100.0 * nb_oovs/nb_tokens)) 61 | 62 | if args.dump_oovs: 63 | oov_counts = get_oovs(args.train, 'chars' if args.characters else 'words', lm.vocab) 64 | with open(args.dump_oovs, 'w') as f: 65 | for tok, count in sorted(oov_counts.items(), key=lambda item: item[1], reverse=True): 66 | f.write(f'{tok} {count}\n') 67 | -------------------------------------------------------------------------------- /scripts/model-building/build-lstmp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import torch 5 | 6 | from brnolm.language_models import lstm_model, vocab, language_model 7 | from brnolm.language_models.decoders import FullSoftmaxDecoder 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='PyTorch LSTM Language Model') 12 | parser.add_argument('--wordlist', type=str, required=True, 13 | help='word -> int map; Kaldi style "words.txt"') 14 | parser.add_argument('--quoted-wordlist', action='store_true', 15 | help='assume the words are quoted (with a single quote)') 16 | parser.add_argument('--unk', type=str, default="", 17 | help='expected 
form of "unk" word. Most likely a or ') 18 | parser.add_argument('--emsize', type=int, default=200, 19 | help='size of word embeddings') 20 | parser.add_argument('--nhid', type=int, default=200, 21 | help='number of hidden units per layer') 22 | parser.add_argument('--nlayers', type=int, default=2, 23 | help='number of layers') 24 | parser.add_argument('--dropout', type=float, default=0.2, 25 | help='dropout applied to layers (0 = no dropout)') 26 | parser.add_argument('--tied', action='store_true', 27 | help='tie the word embedding and softmax weights') 28 | parser.add_argument('--seed', type=int, default=1111, 29 | help='random seed') 30 | parser.add_argument('--save', type=str, required=True, 31 | help='path to save the final model') 32 | args = parser.parse_args() 33 | 34 | # Set the random seed manually for reproducibility. 35 | torch.manual_seed(args.seed) 36 | 37 | print("loading vocabulary...") 38 | with open(args.wordlist, 'r') as f: 39 | if args.quoted_wordlist: 40 | vocabulary = vocab.quoted_vocab_from_kaldi_wordlist(f, args.unk) 41 | else: 42 | vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) 43 | 44 | print("building model...") 45 | 46 | model = lstm_model.LSTMPLanguageModel( 47 | len(vocabulary), args.emsize, args.nhid, 48 | args.nlayers, args.dropout, args.tied 49 | ) 50 | 51 | decoder = FullSoftmaxDecoder(args.emsize, len(vocabulary)) 52 | 53 | lm = language_model.LanguageModel(model, decoder, vocabulary) 54 | torch.save(lm, args.save) 55 | -------------------------------------------------------------------------------- /scripts/oov-clustering/evaluate-embeddings-selective.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from os.path import commonprefix 5 | import sys 6 | 7 | import numpy as np 8 | from scipy.spatial.distance import pdist, squareform 9 | 10 | from brnolm.oov_clustering.embeddings_io import all_embs_by_key 11 | from brnolm.oov_clustering.det import DETCurve 12 | 13 | 14 | def only_differ_in_suffix(a, b, suffix_maxlen=1): 15 | prefix = commonprefix([a, b]) 16 | a_suffix = a[len(prefix):] 17 | b_suffix = b[len(prefix):] 18 | 19 | return max([len(a_suffix), len(b_suffix)]) <= suffix_maxlen 20 | 21 | 22 | def extract_unique_scores(square_scores): 23 | return square_scores[np.triu_indices(square_scores.shape[0], k=0)] 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--log-det', action='store_true') 29 | parser.add_argument('--eps', type=float, default=1e-3, help='to prevent log of zero') 30 | parser.add_argument('--plot', action='store_true') 31 | parser.add_argument('--baseline', action='store_true') 32 | parser.add_argument('--trials', required=True, help='file with word pairs to compare') 33 | parser.add_argument('--disregard-suffixes', action='store_true') 34 | parser.add_argument('--free-axis', action='store_true') 35 | parser.add_argument('--eer-line', action='store_true') 36 | parser.add_argument('--metric', default='inner_prod', choices=['inner_prod']) 37 | args = parser.parse_args() 38 | 39 | trial_pairs = [] 40 | with open(args.trials) as f: 41 | for line in f: 42 | trial_pairs.append(tuple(line.split())) 43 | 44 | emb_collection = all_embs_by_key(sys.stdin, key_transform=lambda w: w.split(':')[0]) 45 | 46 | score_tg = [] 47 | for w in emb_collection: 48 | embs = emb_collection[w] 49 | similarities = embs @ embs.T 50 | score_tg.extend([(s, 1) for s in extract_unique_scores(similarities)]) 51 | 52 | 
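    # embeddings sharing a key above form the target trials (label 1); the word
    # pairs read from --trials below provide the non-target trials (label 0)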
for a, b in trial_pairs: 53 | if a not in emb_collection or b not in emb_collection: 54 | continue 55 | 56 | if args.disregard_suffixes and only_differ_in_suffix(a, b): 57 | continue 58 | 59 | a_embs = emb_collection[a] 60 | b_embs = emb_collection[b] 61 | similarities = a_embs @ b_embs.T 62 | score_tg.extend([(s, 0) for s in similarities.flat]) 63 | 64 | det = DETCurve(score_tg, args.baseline, max_det_points=200) 65 | sys.stdout.write(det.textual_report()) 66 | if args.plot: 67 | det.plot(args.log_det, not args.free_axis, args.eer_line) 68 | -------------------------------------------------------------------------------- /scripts/eval/eval-noivecs-domain-adaptation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import torch 6 | 7 | from brnolm.data_pipeline.split_corpus_dataset import DomainAdaptationSplitFFMultiTarget 8 | from brnolm.data_pipeline.multistream import BatchBuilder 9 | 10 | from brnolm.runtime.runtime_utils import CudaStream, filelist_to_objects, init_seeds 11 | from brnolm.runtime.runtime_multifile import evaluate_ 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 16 | parser.add_argument('--file-list', type=str, required=True, 17 | help='file with paths to training documents') 18 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 19 | help='batch size') 20 | parser.add_argument('--target-seq-len', type=int, default=35, 21 | help='sequence length') 22 | parser.add_argument('--seed', type=int, default=1111, 23 | help='random seed') 24 | parser.add_argument('--cuda', action='store_true', 25 | help='use CUDA') 26 | parser.add_argument('--concat-articles', action='store_true', 27 | help='pass hidden states over article boundaries') 28 | parser.add_argument('--domain-portion', type=float, required=True, 29 | help='portion of text to use as domain documents. 
Taken from the back.') 30 | parser.add_argument('--load', type=str, required=True, 31 | help='where to load a model from') 32 | parser.add_argument('--ivec-nb-iters', type=int, 33 | help='override the number of iterations when extracting ivectors') 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | init_seeds(args.seed, args.cuda) 38 | 39 | print("loading LM...") 40 | lm = torch.load(args.load) 41 | if args.cuda: 42 | lm.cuda() 43 | print(lm.model) 44 | 45 | print("preparing data...") 46 | 47 | def ivec_ts_from_file(f): 48 | da_ts = DomainAdaptationSplitFFMultiTarget( 49 | f, lm.vocab, lm.model.in_len, 50 | args.target_seq_len, end_portion=args.domain_portion, 51 | ) 52 | return da_ts 53 | 54 | tss = filelist_to_objects(args.file_list, ivec_ts_from_file) 55 | data = BatchBuilder(tss, args.batch_size, 56 | discard_h=not args.concat_articles) 57 | if args.cuda: 58 | data = CudaStream(data) 59 | 60 | loss = evaluate_( 61 | lm, data, 62 | use_ivecs=False, 63 | custom_batches=True, 64 | ) 65 | print('loss {:5.2f} | ppl {:8.2f}'.format(loss, math.exp(loss))) 66 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/reading.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def word_splitter(line): 5 | return line.split() 6 | 7 | 8 | def char_splitter(line, sentence_end_token=None): 9 | chars = list(line) 10 | 11 | if sentence_end_token is None: 12 | return chars 13 | else: 14 | return chars + [sentence_end_token] 15 | 16 | 17 | class WordIdProvider: 18 | def __init__(self, vocab): 19 | self.vocab = vocab 20 | 21 | def __call__(self, text): 22 | return [self.vocab[w] for w in text.split()] 23 | 24 | 25 | class WordIdLineEndProvider: 26 | def __init__(self, vocab, line_end=''): 27 | self.vocab = vocab 28 | self.line_end = line_end 29 | 30 | def __call__(self, text): 31 | return [self.vocab[w] for w in text.split() + [self.line_end]] 32 | 33 | 34 | class CharIdProvider: 35 | def __init__(self, vocab, sentence_end_token=''): 36 | self.vocab = vocab 37 | self.sentence_end_token = sentence_end_token 38 | 39 | def __call__(self, text): 40 | chars = [self.sentence_end_token if c == '\n' else c for c in text] 41 | return [self.vocab[c] for c in chars] 42 | 43 | 44 | class TokenizerFactory: 45 | tokenize_regimes = { 46 | 'words': WordIdProvider, 47 | 'words-lines': WordIdLineEndProvider, 48 | 'chars': CharIdProvider, 49 | } 50 | 51 | regimes_names = list(tokenize_regimes.keys()) 52 | 53 | def register_parameter(self, parser, param_name): 54 | parser.add_argument( 55 | param_name, 56 | choices=self.regimes_names, 57 | help='words are separated by whitespace, words-lines turns \\n into , chars are verbatim + \\n => ' 58 | ) 59 | 60 | def construct_tokenizer(self, regime, vocab): 61 | if regime in self.tokenize_regimes: 62 | return self.tokenize_regimes[regime](vocab) 63 | else: 64 | raise ValueError(f'Unsupported tokenization regime {regime}') 65 | 66 | 67 | tokenizer_factory = TokenizerFactory() 68 | 69 | 70 | def tokens_from_file(f, randomize, tokenizer): 71 | ids = [] 72 | 73 | lines = f.read().split('\n') 74 | 75 | if randomize: 76 | import random 77 | random.shuffle(lines) 78 | 79 | for line in lines: 80 | ids.extend(tokenizer(line)) 81 | 82 | return torch.LongTensor(ids) 83 | 84 | 85 | def tokens_from_fn(fn, randomize, tokenizer): 86 | with open(fn, 'r') as f: 87 | return tokens_from_file(f, randomize, tokenizer) 88 | 89 | 90 | def get_independent_lines(f, vocab): 91 | lines = [] 92 
| for line in f: 93 | words = line.split() 94 | if words: 95 | lines.append(torch.tensor([vocab[w] for w in words])) 96 | 97 | return lines 98 | -------------------------------------------------------------------------------- /scripts/eval/eval-chime.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import logging 5 | import math 6 | import torch 7 | 8 | from brnolm.runtime.runtime_utils import init_seeds 9 | from brnolm.runtime.evaluation import SubstitutionalEnblockEvaluator 10 | from brnolm.data_pipeline.augmentation import Corruptor 11 | 12 | 13 | if __name__ == '__main__': 14 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s') 15 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 16 | parser.add_argument('--data', type=str, required=True, 17 | help='location of the data corpus') 18 | parser.add_argument('--shuffle-lines', action='store_true', 19 | help='shuffle lines before every epoch') 20 | 21 | parser.add_argument('--subs-rate', type=float, required=True, 22 | help='what ratio of input tokens should be substituted') 23 | parser.add_argument('--del-rate', type=float, required=True, 24 | help='what ratio of tokens should be deleted') 25 | parser.add_argument('--rounds', type=int, required=True, 26 | help='how many times to run through the eval data') 27 | parser.add_argument('--individual', action='store_true', 28 | help='report individual rounds') 29 | 30 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 31 | help='batch size') 32 | parser.add_argument('--target-seq-len', type=int, default=35, 33 | help='sequence length') 34 | 35 | parser.add_argument('--seed', type=int, default=1111, 36 | help='random seed') 37 | parser.add_argument('--cuda', action='store_true', 38 | help='use CUDA') 39 | parser.add_argument('--load', type=str, required=True, 40 | help='where to load a model from') 41 | args = parser.parse_args() 42 | print(args) 43 | 44 | init_seeds(args.seed, args.cuda) 45 | 46 | print("loading model...") 47 | device = torch.device('cuda') if args.cuda else torch.device('cpu') 48 | lm = torch.load(args.load, map_location=device) 49 | print(lm) 50 | 51 | evaluator = SubstitutionalEnblockEvaluator( 52 | lm, 53 | args.data, 54 | args.batch_size, 55 | args.target_seq_len, 56 | lambda data: Corruptor(data, substitution_rate=args.subs_rate, replacements_range=len(lm.vocab), deletion_rate=args.del_rate), 57 | args.rounds, 58 | ) 59 | eval_report = evaluator.evaluate(report_individual=args.individual) 60 | 61 | print('total loss {:.1f} | per token loss {:5.2f} | ppl {:8.2f}'.format(eval_report.total_loss, eval_report.loss_per_token, math.exp(eval_report.loss_per_token))) 62 | -------------------------------------------------------------------------------- /scripts/model-building/build-transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import torch 5 | 6 | from brnolm.language_models import transformer, vocab, language_model 7 | from brnolm.language_models.decoders import FullSoftmaxDecoder 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='PyTorch LSTM Language Model') 12 | parser.add_argument('--wordlist', type=str, required=True, 13 | help='word -> int map; Kaldi style "words.txt"') 14 | parser.add_argument('--quoted-wordlist', action='store_true', 15 | help='assume the 
words are quoted (with a single quote)') 16 | parser.add_argument('--unk', type=str, default="", 17 | help='expected form of "unk" word. Most likely a or ') 18 | parser.add_argument('--dim-residual', type=int, 19 | help='Dimension of most of the model') 20 | parser.add_argument('--dim-ff', type=int, 21 | help='Dimension of the feed-forwad layers') 22 | parser.add_argument('--nb-layers', type=int, 23 | help='number of layers') 24 | parser.add_argument('--nb-heads', type=int, 25 | help='number of layers') 26 | parser.add_argument('--dropout', type=float, default=0.0, 27 | help='dropout applied to layers (0 = no dropout)') 28 | parser.add_argument('--tied', action='store_true', 29 | help='tie the word embedding and softmax weights') 30 | parser.add_argument('--seed', type=int, default=1111, 31 | help='random seed') 32 | parser.add_argument('--save', type=str, required=True, 33 | help='path to save the final model') 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | # Set the random seed manually for reproducibility. 38 | torch.manual_seed(args.seed) 39 | 40 | print("loading vocabulary...") 41 | with open(args.wordlist, 'r') as f: 42 | if args.quoted_wordlist: 43 | vocabulary = vocab.quoted_vocab_from_kaldi_wordlist(f, args.unk) 44 | else: 45 | vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) 46 | 47 | if not vocabulary.is_continuous(): 48 | raise ValueError("Vocabulary is not continuous, missing indexes {}".format(vocabulary.missing_indexes())) 49 | 50 | print("building model...") 51 | 52 | model = transformer.TransformerLM( 53 | vocab_size=len(vocabulary), 54 | nb_heads=args.nb_heads, 55 | dim_res=args.dim_residual, 56 | dim_ff=args.dim_ff, 57 | nb_layers=args.nb_layers, 58 | dropout=args.dropout, 59 | ) 60 | 61 | decoder = FullSoftmaxDecoder(args.dim_residual, len(vocabulary)) 62 | 63 | lm = language_model.LanguageModel(model, decoder, vocabulary) 64 | torch.save(lm, args.save) 65 | -------------------------------------------------------------------------------- /brnolm/srilm-debug2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import math 4 | import sys 5 | 6 | import torch 7 | 8 | 9 | BATCH_SIZE = 1 10 | 11 | 12 | def per_word_logprobs(words, lm): 13 | words_tensor = torch.tensor([lm.vocab[w] for w in words], requires_grad=False).view(1, -1) 14 | 15 | x = words_tensor[:, :-1] 16 | t = words_tensor[:, 1:] 17 | 18 | h0 = lm.model.init_hidden(x.size(0)) 19 | if not lm.model.batch_first: 20 | x = x.t() 21 | t = t.t() 22 | 23 | with torch.no_grad(): 24 | out_embs, h = lm.model(x, h0) 25 | log_probs = -lm.decoder.neg_log_prob_raw(out_embs, t) 26 | 27 | return log_probs.view(-1) 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--lm', required=True) 33 | parser.add_argument('--sb', default='') 34 | parser.add_argument('--add-boundaries', action='store_true') 35 | args = parser.parse_args() 36 | 37 | lm = torch.load(args.lm, map_location=lambda storage, location: storage) 38 | lm.eval() 39 | 40 | nb_sentences = 0 41 | nb_words = 0 42 | first_line = True 43 | total_logprob = 0.0 44 | sentence_boundary_logprob = 0.0 45 | for line_no, line in enumerate(sys.stdin): 46 | words = line.split() 47 | 48 | if len(words) <= 1: 49 | sys.stderr.write('Skipping line {} due to only containing "{}".\n'.format(line_no, words)) 50 | continue 51 | 52 | if args.add_boundaries: 53 | if words[-1] == args.sb: 54 | sys.stderr.write('Skipping line {}, contains sentence 
boundary, despite --add-boundaries. "{}".'.format(line_no, args.sb)) 55 | continue 56 | else: 57 | words.append(args.sb) 58 | else: 59 | if words[-1] != args.sb: 60 | sys.stderr.write('Skipping line {}, lacks sentence boundary "{}".'.format(line_no, args.sb)) 61 | continue 62 | 63 | log_probs = per_word_logprobs([args.sb] + words, lm) 64 | 65 | if first_line: 66 | first_line = False 67 | else: 68 | sys.stdout.write('\n') 69 | 70 | sys.stdout.write(line) 71 | for w, log_p in zip(words, log_probs): 72 | word_field = "p( {} | ...)".format(w) 73 | sys.stdout.write("\t{}\t= [1gram] {:.3f}\n".format(word_field, log_p.item(), log_p.exp().item())) 74 | 75 | nb_sentences += 1 76 | nb_words += len(words) 77 | total_logprob += log_probs.sum().item() 78 | sentence_boundary_logprob += log_probs[-1] 79 | 80 | nb_words -= nb_sentences 81 | sys.stdout.write('{} sentences, {} words, 0 OOVs\n'.format(nb_sentences, nb_words)) 82 | sys.stdout.write('0 zeroprobs, logprob= {:.3f} ppl= {:.2f} ppl1= {:.2f}\n'.format( 83 | total_logprob, math.exp(-total_logprob/(nb_words+nb_sentences)), 84 | math.exp(-(total_logprob-sentence_boundary_logprob)/nb_words) 85 | )) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /brnolm/language_models/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FullSoftmaxDecoder(torch.nn.Module): 5 | def __init__(self, nb_hidden, nb_output, init_range=0.1): 6 | super().__init__() 7 | 8 | self.projection = torch.nn.Linear(nb_hidden, nb_output) 9 | self.log_softmax = torch.nn.LogSoftmax(dim=-1) 10 | 11 | self.projection.weight.data.uniform_(-init_range, init_range) 12 | self.projection.bias.data.fill_(0) 13 | 14 | self.nllloss = torch.nn.NLLLoss(reduction='sum') 15 | 16 | def forward(self, X): 17 | a = self.projection(X) 18 | return self.log_softmax(a) 19 | 20 | def neg_log_prob_raw(self, X, targets): 21 | orig_shape = targets.shape 22 | preds = self.forward(X) 23 | targets_flat = targets.view(-1) 24 | preds_flat = preds.view(-1, preds.size(-1)) 25 | 26 | return torch.nn.functional.nll_loss(preds_flat, targets_flat, reduction='none').view(orig_shape) 27 | 28 | @torch.jit.export 29 | def neg_log_prob(self, X, targets): 30 | return self.neg_log_prob_raw(X, targets).sum(), targets.numel() 31 | 32 | 33 | class CustomLossFullSoftmaxDecoder(torch.nn.Module): 34 | def __init__(self, nb_hidden, nb_output, init_range=0.1, label_smoothing=None): 35 | super().__init__() 36 | 37 | assert label_smoothing is None or (label_smoothing >= 0.0 and label_smoothing < 1.0) 38 | 39 | self.projection = torch.nn.Linear(nb_hidden, nb_output) 40 | self.log_softmax = torch.nn.LogSoftmax(dim=-1) 41 | 42 | self.projection.weight.data.uniform_(-init_range, init_range) 43 | self.projection.bias.data.fill_(0) 44 | 45 | if label_smoothing is not None: 46 | self.core_loss = LabelSmoothedNLLLoss(label_smoothing) 47 | else: 48 | self.core_loss = plain_nll_loss 49 | 50 | def forward(self, X): 51 | a = self.projection(X) 52 | return self.log_softmax(a) 53 | 54 | def neg_log_prob_raw(self, X, targets): 55 | orig_shape = targets.shape 56 | preds = self.forward(X) 57 | targets_flat = targets.view(-1) 58 | preds_flat = preds.view(-1, preds.size(-1)) 59 | 60 | return self.core_loss(preds_flat, targets_flat).view(orig_shape) 61 | 62 | def neg_log_prob(self, X, targets): 63 | return self.neg_log_prob_raw(X, targets).sum(), targets.numel() 64 | 65 | 66 | def 
plain_nll_loss(preds, targets): 67 | return torch.nn.functional.nll_loss(preds, targets, reduction='none') 68 | 69 | 70 | class LabelSmoothedNLLLoss: 71 | def __init__(self, amount): 72 | self.amount = amount 73 | 74 | def __call__(self, preds, targets): 75 | eps = self.amount 76 | n_class = preds.size(1) 77 | 78 | assert len(targets.shape) == 1 79 | assert len(preds.shape) == 2 80 | 81 | smooth_targets = torch.zeros_like(preds).scatter_(1, targets.unsqueeze(1), 1) 82 | smooth_targets = smooth_targets * (1 - eps) + (1 - smooth_targets) * eps / n_class 83 | loss = -(smooth_targets * preds).sum(axis=1) 84 | 85 | return loss 86 | -------------------------------------------------------------------------------- /scripts/eval/eval-ivecs-oracle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import torch 6 | 7 | from brnolm.data_pipeline.split_corpus_dataset import TokenizedSplitFFBase 8 | from brnolm.smm_itf import ivec_appenders 9 | from brnolm.smm_itf import smm_ivec_extractor 10 | from brnolm.data_pipeline.multistream import BatchBuilder 11 | from brnolm.data_pipeline.temporal_splitting import TemporalSplits 12 | 13 | from brnolm.runtime.runtime_utils import CudaStream, filelist_to_objects, init_seeds 14 | from brnolm.runtime.runtime_multifile import evaluate 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 19 | parser.add_argument('--file-list', type=str, required=True, 20 | help='file with paths to training documents') 21 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 22 | help='batch size') 23 | parser.add_argument('--target-seq-len', type=int, default=35, 24 | help='sequence length') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--cuda', action='store_true', 28 | help='use CUDA') 29 | parser.add_argument('--concat-articles', action='store_true', 30 | help='pass hidden states over article boundaries') 31 | parser.add_argument('--load', type=str, required=True, 32 | help='where to load a model from') 33 | parser.add_argument('--ivec-extractor', type=str, required=True, 34 | help='where to load a ivector extractor from') 35 | parser.add_argument('--ivec-nb-iters', type=int, 36 | help='override the number of iterations when extracting ivectors') 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | init_seeds(args.seed, args.cuda) 41 | 42 | print("loading LM...") 43 | lm = torch.load(args.load) 44 | if args.cuda: 45 | lm.cuda() 46 | print(lm.model) 47 | 48 | print("loading SMM iVector extractor ...") 49 | with open(args.ivec_extractor, 'rb') as f: 50 | ivec_extractor = smm_ivec_extractor.load(f) 51 | if args.ivec_nb_iters: 52 | ivec_extractor._nb_iters = args.ivec_nb_iters 53 | print(ivec_extractor) 54 | 55 | print("preparing data...") 56 | 57 | def ivec_ts_from_file(f): 58 | ts = TokenizedSplitFFBase( 59 | f, lm.vocab, 60 | lambda seq: TemporalSplits(seq, lm.model.in_len, args.target_seq_len) 61 | ) 62 | return ivec_appenders.CheatingIvecAppender(ts, ivec_extractor) 63 | 64 | data_ivecs = filelist_to_objects(args.file_list, ivec_ts_from_file) 65 | data = BatchBuilder( 66 | data_ivecs, 67 | args.batch_size, 68 | discard_h=not args.concat_articles 69 | ) 70 | 71 | if args.cuda: 72 | data = CudaStream(data) 73 | 74 | loss = evaluate(lm, data, use_ivecs=True) 75 | print('loss {:5.2f} | ppl {:8.2f}'.format(loss, 
math.exp(loss))) 76 | -------------------------------------------------------------------------------- /scripts/eval/eval-chime-v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import logging 5 | import math 6 | import torch 7 | 8 | from safe_gpu.safe_gpu import GPUOwner 9 | from brnolm.runtime.runtime_utils import init_seeds 10 | from brnolm.runtime.evaluation import SubstitutionalEnblockEvaluator_v2 11 | from brnolm.data_pipeline.aug_paper_pipeline import Corruptor 12 | 13 | 14 | if __name__ == '__main__': 15 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s') 16 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 17 | parser.add_argument('--data', type=str, required=True, 18 | help='location of the data corpus') 19 | parser.add_argument('--shuffle-lines', action='store_true', 20 | help='shuffle lines before every epoch') 21 | 22 | parser.add_argument('--subs-rate', type=float, required=True, 23 | help='what ratio of input tokens should be substituted') 24 | parser.add_argument('--del-rate', type=float, required=True, 25 | help='what ratio of tokens should be deleted') 26 | parser.add_argument('--ins-rate', type=float, required=True, 27 | help='what ratio of tokens should be inserted') 28 | parser.add_argument('--rounds', type=int, required=True, 29 | help='how many times to run through the eval data') 30 | parser.add_argument('--individual', action='store_true', 31 | help='report individual rounds') 32 | 33 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 34 | help='batch size') 35 | parser.add_argument('--target-seq-len', type=int, default=35, 36 | help='sequence length') 37 | 38 | parser.add_argument('--seed', type=int, default=1111, 39 | help='random seed') 40 | parser.add_argument('--cuda', action='store_true', 41 | help='use CUDA') 42 | parser.add_argument('--load', type=str, required=True, 43 | help='where to load a model from') 44 | args = parser.parse_args() 45 | print(args) 46 | 47 | init_seeds(args.seed, args.cuda) 48 | 49 | print("loading model...") 50 | device = torch.device('cuda') if args.cuda else torch.device('cpu') 51 | if args.cuda: 52 | gpu_owner = GPUOwner(lambda: torch.zeros((1), device='cuda')) 53 | 54 | lm = torch.load(args.load, map_location=device) 55 | print(lm) 56 | 57 | evaluator = SubstitutionalEnblockEvaluator_v2( 58 | lm, 59 | args.data, 60 | args.batch_size, 61 | args.target_seq_len, 62 | lambda streams: Corruptor(streams, args.subs_rate, len(lm.vocab), args.del_rate, args.ins_rate, protected=[lm.vocab['']]), 63 | args.rounds, 64 | ) 65 | eval_report = evaluator.evaluate(report_individual=args.individual) 66 | 67 | print('total loss {:.1f} | per token loss {:5.2f} | ppl {:8.2f}'.format(eval_report.total_loss, eval_report.loss_per_token, math.exp(eval_report.loss_per_token))) 68 | -------------------------------------------------------------------------------- /scripts/eval/eval-ivecs-partial.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import torch 6 | 7 | from brnolm.smm_itf import ivec_appenders 8 | from brnolm.smm_itf import smm_ivec_extractor 9 | from brnolm.data_pipeline.multistream import BatchBuilder 10 | from brnolm.data_pipeline.temporal_splitting import TemporalSplits 11 | from brnolm.data_pipeline.split_corpus_dataset import TokenizedSplitFFBase 12 | 13 | 
from brnolm.runtime.runtime_utils import CudaStream, filelist_to_objects, init_seeds 14 | from brnolm.runtime.runtime_multifile import evaluate 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 19 | parser.add_argument('--file-list', type=str, required=True, 20 | help='file with paths to training documents') 21 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 22 | help='batch size') 23 | parser.add_argument('--target-seq-len', type=int, default=35, 24 | help='sequence length') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--cuda', action='store_true', 28 | help='use CUDA') 29 | parser.add_argument('--concat-articles', action='store_true', 30 | help='pass hidden states over article boundaries') 31 | parser.add_argument('--load', type=str, required=True, 32 | help='where to load a model from') 33 | parser.add_argument('--ivec-extractor', type=str, required=True, 34 | help='where to load a ivector extractor from') 35 | parser.add_argument('--ivec-nb-iters', type=int, 36 | help='override the number of iterations when extracting ivectors') 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | init_seeds(args.seed, args.cuda) 41 | 42 | print("loading LM...") 43 | lm = torch.load(args.load) 44 | if args.cuda: 45 | lm.cuda() 46 | print(lm.model) 47 | 48 | print("loading SMM iVector extractor ...") 49 | with open(args.ivec_extractor, 'rb') as f: 50 | ivec_extractor = smm_ivec_extractor.load(f) 51 | if args.ivec_nb_iters is not None: 52 | ivec_extractor._nb_iters = args.ivec_nb_iters 53 | print(ivec_extractor) 54 | 55 | print("preparing data...") 56 | 57 | def ts_from_file(f): 58 | return TokenizedSplitFFBase( 59 | f, lm.vocab, 60 | lambda seq: TemporalSplits(seq, lm.model.in_len, args.target_seq_len) 61 | ) 62 | 63 | tss = filelist_to_objects(args.file_list, ts_from_file) 64 | data = BatchBuilder(tss, args.batch_size, 65 | discard_h=not args.concat_articles) 66 | if args.cuda: 67 | data = CudaStream(data) 68 | data_ivecs = ivec_appenders.ParalelIvecAppender( 69 | data, ivec_extractor, ivec_extractor.build_translator(lm.vocab) 70 | ) 71 | 72 | print("evaluating...") 73 | loss = evaluate(lm, data_ivecs, use_ivecs=True) 74 | print('loss {:5.2f} | ppl {:8.2f}'.format(loss, math.exp(loss))) 75 | -------------------------------------------------------------------------------- /scripts/sample-from-lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import torch 5 | from torch.distributions import Categorical 6 | 7 | import sys 8 | 9 | 10 | class Unbuffered: 11 | def __init__(self, stream): 12 | self.stream = stream 13 | 14 | def write(self, data): 15 | self.stream.write(data) 16 | self.stream.flush() 17 | 18 | def writelines(self, datas): 19 | self.stream.writelines(datas) 20 | self.stream.flush() 21 | 22 | def __getattr__(self, attr): 23 | return getattr(self.stream, attr) 24 | 25 | 26 | def get_max(log_probs): 27 | return torch.max(log_probs, 0)[1] 28 | 29 | 30 | def sample(log_probs, temperature): 31 | annealed_logits = log_probs / temperature 32 | dist = Categorical(logits=annealed_logits) 33 | return dist.sample() 34 | 35 | 36 | def get_sampler(args): 37 | if args.sampler == 'max': 38 | return get_max 39 | elif args.sampler == 'sample': 40 | return lambda y: sample(y, args.temperature) 41 | else: 42 | raise ValueError(f"Unacceptable sampler 
{args.sampler}") 43 | 44 | 45 | class NextIndexProducer: 46 | def __init__(self, lm, sampler, seed_text): 47 | self.lm = lm 48 | self.sampler = sampler 49 | self.inds_to_process = [lm.vocab[c] for c in seed_text] 50 | self.h = self.lm.model.init_hidden(1) 51 | 52 | def __call__(self): 53 | o, self.h = self.lm.model(torch.tensor(self.inds_to_process).unsqueeze(0), self.h) 54 | y = self.lm.decoder(o[0, -1, :]).squeeze().detach() 55 | 56 | sample = self.sampler(y).item() 57 | self.inds_to_process = [sample] 58 | 59 | return self.lm.vocab.i2w(sample) 60 | 61 | 62 | class LineWriter: 63 | def __init__(self, f): 64 | self._f = f 65 | 66 | def write(self, string): 67 | self._f.write(string) 68 | 69 | def __enter__(self): 70 | self._put_delim() 71 | return self 72 | 73 | def __exit__(self, type, value, traceback): 74 | self._put_delim() 75 | self._f.write('\n') 76 | 77 | def _put_delim(self): 78 | self._f.write("'") 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--sb') 84 | parser.add_argument('--sampler', choices=['max', 'sample'], default='sample') 85 | parser.add_argument('--temperature', type=float, default=1.0) 86 | parser.add_argument('lm') 87 | parser.add_argument('seed_text') 88 | parser.add_argument('nb_tokens', nargs='?', type=int, default=10) 89 | args = parser.parse_args() 90 | 91 | lm = torch.load(args.lm, map_location="cpu") 92 | if args.sb: 93 | assert args.sb in lm.vocab 94 | 95 | index_producer = NextIndexProducer(lm, get_sampler(args), args.seed_text) 96 | 97 | sys.stdout = Unbuffered(sys.stdout) 98 | with LineWriter(sys.stdout) as writer: 99 | writer.write(args.seed_text) 100 | for i in range(args.nb_tokens): 101 | char = index_producer() 102 | 103 | if char == args.sb: 104 | char = '\n' 105 | 106 | writer.write(char) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /scripts/eval/eval-ivecs-domain-adaptation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import math 5 | import torch 6 | 7 | from brnolm.smm_itf import ivec_appenders 8 | from brnolm.data_pipeline.split_corpus_dataset import DomainAdaptationSplitFFMultiTarget 9 | from brnolm.data_pipeline.multistream import BatchBuilder 10 | from brnolm.smm_itf import smm_ivec_extractor 11 | 12 | from brnolm.runtime.runtime_utils import CudaStream, filelist_to_objects, init_seeds 13 | from brnolm.runtime.runtime_multifile import evaluate_ 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 18 | parser.add_argument('--file-list', type=str, required=True, 19 | help='file with paths to training documents') 20 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 21 | help='batch size') 22 | parser.add_argument('--target-seq-len', type=int, default=35, 23 | help='sequence length') 24 | parser.add_argument('--seed', type=int, default=1111, 25 | help='random seed') 26 | parser.add_argument('--cuda', action='store_true', 27 | help='use CUDA') 28 | parser.add_argument('--concat-articles', action='store_true', 29 | help='pass hidden states over article boundaries') 30 | parser.add_argument('--domain-portion', type=float, required=True, 31 | help='portion of text to use as domain documents. 
Taken from the back.') 32 | parser.add_argument('--load', type=str, required=True, 33 | help='where to load a model from') 34 | parser.add_argument('--ivec-extractor', type=str, required=True, 35 | help='where to load a ivector extractor from') 36 | parser.add_argument('--ivec-nb-iters', type=int, 37 | help='override the number of iterations when extracting ivectors') 38 | args = parser.parse_args() 39 | print(args) 40 | 41 | init_seeds(args.seed, args.cuda) 42 | 43 | print("loading LM...") 44 | lm = torch.load(args.load) 45 | if args.cuda: 46 | lm.cuda() 47 | print(lm.model) 48 | 49 | print("loading SMM iVector extractor ...") 50 | with open(args.ivec_extractor, 'rb') as f: 51 | ivec_extractor = smm_ivec_extractor.load(f) 52 | if args.ivec_nb_iters is not None: 53 | ivec_extractor._nb_iters = args.ivec_nb_iters 54 | print(ivec_extractor) 55 | 56 | print("preparing data...") 57 | 58 | def ivec_ts_from_file(f): 59 | da_ts = DomainAdaptationSplitFFMultiTarget( 60 | f, lm.vocab, lm.model.in_len, 61 | args.target_seq_len, end_portion=args.domain_portion, 62 | ) 63 | return ivec_appenders.CheatingIvecAppender(da_ts, ivec_extractor) 64 | 65 | tss = filelist_to_objects(args.file_list, ivec_ts_from_file) 66 | data = BatchBuilder(tss, args.batch_size, 67 | discard_h=not args.concat_articles) 68 | if args.cuda: 69 | data = CudaStream(data) 70 | 71 | loss = evaluate_( 72 | lm, data, 73 | use_ivecs=True, 74 | custom_batches=True, 75 | ) 76 | print('loss {:5.2f} | ppl {:8.2f}'.format(loss, math.exp(loss))) 77 | -------------------------------------------------------------------------------- /brnolm/language_models/ffnn_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BengioModel(nn.Module): 7 | """Container module with an encoder, a recurrent module, and a decoder.""" 8 | 9 | def __init__(self, ntoken, emb_size, in_len, nb_hidden, dropout=0.5): 10 | super().__init__() 11 | self.drop = nn.Dropout(dropout) 12 | self.encoder = nn.Embedding(ntoken, emb_size) 13 | self.emb2h = nn.ModuleList([nn.Linear(emb_size, nb_hidden) for _ in range(in_len)]) 14 | 15 | self.init_weights() 16 | 17 | self.nb_hidden = nb_hidden 18 | self.in_len = in_len 19 | self.emb_size = emb_size 20 | 21 | self.batch_first = True 22 | 23 | def init_weights(self): 24 | initrange = 0.1 25 | self.encoder.weight.data.uniform_(-initrange, initrange) 26 | for e2h in self.emb2h: 27 | e2h.weight.data.uniform_(-initrange, initrange) 28 | 29 | def forward(self, input, hidden): 30 | emb = self.drop(self.encoder(input)) 31 | projections = [proj(emb[:, i:emb.size(1)-(self.in_len-i)+1]) for i, proj in enumerate(self.emb2h)] 32 | projections = torch.stack(projections, dim=-1) 33 | output = F.tanh(torch.sum(projections, dim=-1)) 34 | output = self.drop(output) 35 | return output, hidden 36 | 37 | def init_hidden(self, bsz): 38 | # not used, but to fit into the framework of other ivec-LMs 39 | weight = next(self.parameters()).data 40 | return weight.new_zeros(1, bsz, self.nb_hidden) 41 | 42 | 43 | class BengioModelIvecInput(nn.Module): 44 | """Container module with an encoder, a recurrent module, and a decoder.""" 45 | 46 | def __init__(self, ntoken, emb_size, in_len, nb_hidden, dropout, ivec_dim): 47 | super().__init__() 48 | self.drop = nn.Dropout(dropout) 49 | self.encoder = nn.Embedding(ntoken, emb_size) 50 | self.emb2h = nn.ModuleList([nn.Linear(emb_size, nb_hidden) for _ in range(in_len)]) 51 | self.ivec2h = 
nn.Linear(ivec_dim, nb_hidden) 52 | 53 | self.init_weights() 54 | 55 | self.nb_hidden = nb_hidden 56 | self.in_len = in_len 57 | self.emb_size = emb_size 58 | 59 | self.batch_first = True 60 | 61 | def init_weights(self): 62 | initrange = 0.1 63 | self.encoder.weight.data.uniform_(-initrange, initrange) 64 | for e2h in self.emb2h: 65 | e2h.weight.data.uniform_(-initrange, initrange) 66 | 67 | def forward(self, input, hidden, ivec): 68 | if len(ivec.size()) == 1: 69 | ivec = ivec.unsqueeze(0) 70 | emb = self.drop(self.encoder(input)) 71 | projections = [proj(emb[:, i:emb.size(1)-(self.in_len-i)+1]) for i, proj in enumerate(self.emb2h)] 72 | nb_timesteps = projections[0].size(1) 73 | projected_ivec = self.ivec2h(ivec).unsqueeze(dim=-2).expand(-1, nb_timesteps, -1) 74 | projections = torch.stack(projections + [projected_ivec], dim=-1) 75 | output = F.tanh(torch.sum(projections, dim=-1)) 76 | output = self.drop(output) 77 | return output, hidden 78 | 79 | def init_hidden(self, bsz): 80 | # not used, but to fit into the framework of other ivec-LMs 81 | weight = next(self.parameters()).data 82 | return weight.new_zeros(1, bsz, self.nb_hidden) 83 | -------------------------------------------------------------------------------- /scripts/model-building/build-lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import torch 5 | 6 | from brnolm.language_models import lstm_model, vocab, language_model 7 | from brnolm.language_models.decoders import CustomLossFullSoftmaxDecoder 8 | from brnolm.language_models.encoders import FlatEmbedding 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='PyTorch LSTM Language Model') 13 | parser.add_argument('--wordlist', type=str, required=True, 14 | help='word -> int map; Kaldi style "words.txt"') 15 | parser.add_argument('--quoted-wordlist', action='store_true', 16 | help='assume the words are quoted (with a single quote)') 17 | parser.add_argument('--unk', type=str, default="<unk>", 18 | help='expected form of the "unk" word. Most likely a <unk> or <UNK>') 19 | parser.add_argument('--emsize', type=int, default=200, 20 | help='size of word embeddings') 21 | parser.add_argument('--nhid', type=int, default=200, 22 | help='number of hidden units per layer') 23 | parser.add_argument('--nlayers', type=int, default=2, 24 | help='number of layers') 25 | parser.add_argument('--dropout', type=float, default=0.2, 26 | help='dropout applied to layers (0 = no dropout)') 27 | parser.add_argument('--label-smoothing', type=float, 28 | help='amount of correct probability to distribute across others') 29 | parser.add_argument('--tied', action='store_true', 30 | help='tie the word embedding and softmax weights') 31 | parser.add_argument('--embedding-init-range', type=float, default=0.1, 32 | help='Range for initialization of both input and output weight matrices') 33 | parser.add_argument('--seed', type=int, default=1111, 34 | help='random seed') 35 | parser.add_argument('--save', type=str, required=True, 36 | help='path to save the final model') 37 | args = parser.parse_args() 38 | 39 | # Set the random seed manually for reproducibility.
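# (Fixing the seed makes the uniform initialisation of the embeddings, the LSTM
#  and the decoder projection below reproducible, so repeated runs with the same
#  arguments save identical models.)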
40 | torch.manual_seed(args.seed) 41 | 42 | print("loading vocabulary...") 43 | with open(args.wordlist, 'r') as f: 44 | if args.quoted_wordlist: 45 | vocabulary = vocab.quoted_vocab_from_kaldi_wordlist(f, args.unk) 46 | else: 47 | vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) 48 | 49 | if not vocabulary.is_continuous(): 50 | raise ValueError("Vocabulary is not continuous, missing indexes {}".format(vocabulary.missing_indexes())) 51 | 52 | print("building model...") 53 | 54 | encoder = FlatEmbedding(len(vocabulary), args.emsize, init_range=args.embedding_init_range) 55 | 56 | model = lstm_model.LSTMLanguageModel( 57 | token_encoder=encoder, 58 | dim_input=args.emsize, 59 | dim_lstm=args.nhid, 60 | nb_layers=args.nlayers, 61 | dropout=args.dropout 62 | ) 63 | 64 | decoder = CustomLossFullSoftmaxDecoder(args.nhid, len(vocabulary), init_range=args.embedding_init_range) 65 | 66 | if args.tied: 67 | decoder.projection.weight = encoder.embeddings.weight 68 | 69 | lm = language_model.LanguageModel(model, decoder, vocabulary) 70 | torch.save(lm, args.save) 71 | -------------------------------------------------------------------------------- /scripts/oov-clustering/evaluate-embeddings-large-scale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from brnolm.oov_clustering.det import DETCurve 5 | from typing import List, Tuple 6 | import pickle 7 | import random 8 | import sys 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--ref', required=True, 13 | help='reference matrix, text, triangular') 14 | parser.add_argument('--selection', help='file with candidates to consider') 15 | parser.add_argument('--posteriors', help='file with per-candidate posteriors') 16 | parser.add_argument('--posteriors-percentage', type=float, default=0.1, 17 | help='what top fraction of candidates, ranked by posterior, to keep') 18 | parser.add_argument('--scores', required=True, 19 | help='score matrix, text, triangular') 20 | parser.add_argument('--sampling-rate', type=float, default=0.1, 21 | help='what fraction of score-target pairs to randomly sample') 22 | parser.add_argument('--det-file', required=True, 23 | help='where to put the pickled DETCurve object') 24 | args = parser.parse_args() 25 | 26 | if args.posteriors and not args.selection: 27 | posteriors = [] 28 | with open(args.posteriors) as f: 29 | for line in f: 30 | fields = line.split() 31 | posteriors.append(float(fields[1])) 32 | sorted_posteriors = sorted(posteriors, reverse=True) 33 | threshold = sorted_posteriors[int(len(posteriors)*args.posteriors_percentage)] 34 | index_is_interesting = lambda x: posteriors[x] > threshold 35 | elif args.posteriors and args.selection: 36 | all_candidates = [] 37 | with open(args.posteriors) as f: 38 | for line in f: 39 | fields = line.split() 40 | all_candidates.append(fields[0]) 41 | with open(args.selection) as f: 42 | selected = f.read().split() 43 | indexes_of_interesting = [all_candidates.index(s) for s in selected] 44 | index_is_interesting = lambda x: x in indexes_of_interesting 45 | elif not args.posteriors and args.selection: 46 | sys.stderr.write("When --selection is given, --posteriors must be given as well\n") 47 | sys.exit(1) 48 | else: 49 | index_is_interesting = lambda x: True 50 | 51 | score_tg: List[Tuple[float, float]] = [] 52 | with open(args.ref) as ref_f, open(args.scores) as scores_f: 53 | for i, (ref_line, score_line) in enumerate(zip(ref_f, scores_f)): 54 | if not index_is_interesting(i): 55 |
continue 56 | 57 | ref_fields = [float(x) for x in ref_line.split()] 58 | score_fields = [float(x) for x in score_line.split()] 59 | 60 | line_score_tg = list(zip(score_fields, ref_fields)) 61 | if args.posteriors: 62 | line_score_tg = [ 63 | stg for j, stg in enumerate(line_score_tg) if index_is_interesting(j) 64 | ] 65 | 66 | score_tg.extend(random.sample( 67 | line_score_tg, 68 | int(len(line_score_tg)*args.sampling_rate) 69 | )) 70 | 71 | det = DETCurve(score_tg, baseline=True, max_det_points=500) 72 | 73 | with open(args.det_file, 'wb') as f: 74 | pickle.dump(det, f) 75 | -------------------------------------------------------------------------------- /brnolm/language_models/lstm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | 6 | class LSTMLanguageModel(nn.Module): 7 | """Container module with an encoder, a recurrent module, and a decoder.""" 8 | 9 | def __init__(self, token_encoder, dim_input, dim_lstm, nb_layers, dropout=0.5): 10 | super(LSTMLanguageModel, self).__init__() 11 | self.drop = nn.Dropout(dropout) 12 | self.encoder = token_encoder 13 | self.rnn = nn.LSTM(dim_input, dim_lstm, nb_layers, dropout=dropout, batch_first=True) 14 | 15 | self.dim_lstm = dim_lstm 16 | self.nlayers = nb_layers 17 | # self.nhid = dim_lstm 18 | 19 | self.batch_first = True 20 | self.in_len = 1 21 | 22 | @property 23 | def nhid(self): 24 | return self.rnn.weight_hh_l1.shape[1] 25 | 26 | def init_weights(self): 27 | initrange = 0.1 28 | self.encoder.weight.data.uniform_(-initrange, initrange) 29 | 30 | def forward(self, input, hidden: Tuple[torch.Tensor, torch.Tensor]): 31 | emb = self.drop(self.encoder(input)) 32 | output, hidden = self.rnn(emb, hidden) 33 | output = self.drop(output) 34 | return output, hidden 35 | 36 | def output_expected_embs(self, input): 37 | assert (len(input.size()) == 2) # batch X time index 38 | assert (input.size()[0] == 1) 39 | 40 | hidden = self.init_hidden(1) 41 | emb = self.drop(self.encoder(input)) 42 | outputs, _ = self.rnn(emb, hidden) 43 | return outputs 44 | 45 | @torch.jit.export 46 | def init_hidden(self, bsz: int): 47 | weight = self.rnn.weight_ih_l0 48 | return (weight.new_zeros(self.nlayers, bsz, self.nhid), 49 | weight.new_zeros(self.nlayers, bsz, self.nhid)) 50 | 51 | @torch.jit.export 52 | def extract_output_from_h(self, hidden_state: Tuple[torch.Tensor, torch.Tensor]): 53 | h = hidden_state[0] # hidden state is (h, c) 54 | return h[-1] # last layer is the output one 55 | 56 | 57 | class LSTMPLanguageModel(nn.Module): 58 | """Container module with an encoder, a recurrent module, and a decoder.""" 59 | 60 | def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): 61 | super().__init__() 62 | self.drop = nn.Dropout(dropout) 63 | self.encoder = nn.Embedding(ntoken, ninp) 64 | self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout, batch_first=True) 65 | self.proj = nn.Linear(nhid, ninp) 66 | 67 | if tie_weights: 68 | raise NotImplementedError 69 | 70 | self.init_weights() 71 | 72 | self.nhid = nhid 73 | self.nlayers = nlayers 74 | 75 | self.batch_first = True 76 | self.in_len = 1 77 | 78 | def init_weights(self): 79 | initrange = 0.1 80 | self.encoder.weight.data.uniform_(-initrange, initrange) 81 | 82 | def forward(self, input, hidden): 83 | emb = self.drop(self.encoder(input)) 84 | output, hidden = self.rnn(emb, hidden) 85 | output = self.drop(output) 86 | projected = self.proj(output) 87 | return projected, 
hidden 88 | 89 | def init_hidden(self, bsz): 90 | weight = next(self.parameters()).data 91 | return (weight.new_zeros(self.nlayers, bsz, self.nhid), 92 | weight.new_zeros(self.nlayers, bsz, self.nhid)) 93 | -------------------------------------------------------------------------------- /scripts/display-augmented-data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import logging 5 | import pickle 6 | import sys 7 | import torch 8 | 9 | from brnolm.data_pipeline.reading import tokens_from_fn 10 | from brnolm.data_pipeline.aug_paper_pipeline import form_input_targets 11 | from brnolm.data_pipeline.aug_paper_pipeline import Corruptor 12 | from brnolm.data_pipeline.aug_paper_pipeline import Confuser, StatisticsCorruptor 13 | 14 | 15 | RED_MARK = '\033[91m' 16 | END_MARK = '\033[0m' 17 | 18 | 19 | def main(args): 20 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s::%(name)s] %(message)s') 21 | 22 | lm = torch.load(args.load, map_location='cpu') 23 | 24 | tokenize_regime = 'words' 25 | train_ids = tokens_from_fn(args.data, lm.vocab, randomize=False, regime=tokenize_regime) 26 | train_streams = form_input_targets(train_ids) 27 | if args.statistics: 28 | with open(args.statistics, 'rb') as f: 29 | summary = pickle.load(f) 30 | confuser = Confuser(summary.confusions, lm.vocab, mincount=5) 31 | corrupted_provider = StatisticsCorruptor(train_streams, confuser, args.ins_rate, protected=[lm.vocab['']]) 32 | else: 33 | corrupted_provider = Corruptor(train_streams, args.subs_rate, len(lm.vocab), args.del_rate, args.ins_rate, protected=[lm.vocab['']]) 34 | 35 | inputs, targets = corrupted_provider.provide() 36 | 37 | for i in range(args.nb_tokens): 38 | in_word = lm.vocab.i2w(inputs[i].item()) 39 | target_word = lm.vocab.i2w(targets[i].item()) 40 | 41 | is_error = i > 0 and inputs[i] != targets[i-1] 42 | if args.color and is_error: 43 | sys.stdout.write(f'{RED_MARK}{in_word}{END_MARK} {target_word}\n') 44 | else: 45 | sys.stdout.write(f'{in_word} {target_word}\n') 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--data', type=str, required=True, 51 | help='location of the train corpus') 52 | parser.add_argument('--nb-tokens', type=int, default=10, 53 | help='how many input-target pairs to show') 54 | parser.add_argument('--color', action='store_true', 55 | help='Use ANSI colorcodes to highlight errors') 56 | 57 | parser.add_argument('--ins-rate', type=float, required=True, 58 | help='what ratio of tokens should be inserted') 59 | parser.add_argument('--subs-rate', type=float, 60 | help='what ratio of input tokens should be randomly') 61 | parser.add_argument('--del-rate', type=float, 62 | help='what ratio of tokens should be removed') 63 | parser.add_argument('--statistics', type=str, 64 | help='Use these statistics to determine exact mistakes') 65 | 66 | parser.add_argument('--load', type=str, required=True, 67 | help='where to load a model from') 68 | args = parser.parse_args() 69 | print(args) 70 | 71 | if (args.del_rate is None or args.subs_rate is None) and not args.statistics: 72 | sys.stderr.write('either (--del-rate and --subs-rate) or (--statistics) must be provided\n') 73 | sys.exit(2) 74 | 75 | if (args.del_rate is not None or args.subs_rate is not None) and args.statistics: 76 | sys.stderr.write('(--del-rate and --subs-rate) and (--statistics) are mutually exclusive\n') 77 | sys.exit(2) 78 | 79 | main(args) 80 | 
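# A minimal usage sketch; the paths and rates below are only an illustration:
#
#   python scripts/display-augmented-data.py \
#       --load exp/lstm.lm --data data/train.txt \
#       --subs-rate 0.05 --del-rate 0.02 --ins-rate 0.02 \
#       --nb-tokens 50 --color
#
# This prints 50 input/target token pairs from the corrupted training stream,
# highlighting corrupted input tokens in red (--color).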
-------------------------------------------------------------------------------- /scripts/oov-clustering/reference-matrix-by-word-alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | 8 | from brnolm.oov_clustering.oov_alignment_lib import align, extract_mismatch 9 | from brnolm.oov_clustering.oov_alignment_lib import find_in_mismatches, number_of_errors 10 | 11 | from typing import Dict, List, Tuple 12 | 13 | 14 | def parse_oov_id(oov_id): 15 | return tuple(oov_id.split('_')) 16 | 17 | 18 | def intersection(a, b): 19 | return list(set(a) & set(b)) 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--text-references', required=True) 25 | parser.add_argument('--oov-list', required=True) 26 | parser.add_argument('--reference-file', required=True) 27 | args = parser.parse_args() 28 | 29 | with open(args.oov_list) as f: 30 | oov_list = f.read().split() 31 | 32 | np.set_printoptions(threshold=2000, linewidth=np.inf) 33 | 34 | total_nb_errors = 0 35 | total_ref_len = 0 36 | 37 | oov_hits: Dict[int, List[Tuple[str, List[str]]]] = {} 38 | 39 | references = {} 40 | with open(args.text_references) as f: 41 | for line in f: 42 | fields = line.split() 43 | references[fields[0]] = fields[1:] 44 | 45 | candidate_possible_words = [] 46 | for line in sys.stdin: 47 | fields = line.split() 48 | _, utt_id, _, _, _ = parse_oov_id(fields[0]) 49 | 50 | candidate_line = fields[1:] 51 | reference_line = references[utt_id] 52 | alignment = align(reference_line, candidate_line) 53 | mismatches = extract_mismatch(alignment) 54 | oov_mismatch = find_in_mismatches(mismatches, "") 55 | 56 | total_ref_len += len(reference_line) 57 | total_nb_errors += number_of_errors(mismatches) 58 | matching_oovs = intersection(oov_list, oov_mismatch[0]) 59 | 60 | if len(matching_oovs) in oov_hits: 61 | oov_hits[len(matching_oovs)].append((utt_id, matching_oovs)) 62 | else: 63 | oov_hits[len(matching_oovs)] = [(utt_id, matching_oovs)] 64 | 65 | candidate_possible_words.append(oov_mismatch[0]) 66 | print(fields[0], oov_mismatch[0], '--', oov_mismatch[1]) 67 | 68 | with open(args.reference_file, 'w') as f: 69 | for i, c1 in enumerate(candidate_possible_words): 70 | intersecting = [] 71 | for c2 in candidate_possible_words[i+1:]: 72 | intersecting.append("1" if len(intersection(c1, c2)) > 0 else "0") 73 | 74 | f.write(" ".join(intersecting) + "\n") 75 | 76 | wer_fmt = "Total WER on candidate paths: {:.2f} % ({} / {})\n" 77 | sys.stderr.write(wer_fmt.format(100.0*total_nb_errors/total_ref_len, total_nb_errors, total_ref_len)) 78 | 79 | oov_hit_fmt = "Total number of candidates where the hypothesised OOV may match an actual OOV : {:.2f} % ({} / {})\n" 80 | total_nb_candidates = sum(len(k) for k in oov_hits.values()) 81 | total_nb_oov_hits = sum(len(oov_hits[k]) for k in oov_hits if k > 0) 82 | sys.stderr.write(oov_hit_fmt.format(100.0*total_nb_oov_hits/total_nb_candidates, total_nb_oov_hits, total_nb_candidates)) 83 | for k in oov_hits: 84 | nb_candidates = len(oov_hits[k]) 85 | percentage_candidates = 100.0 * nb_candidates/total_nb_candidates 86 | nb_unique_candidates = len(set((hit[0], tuple(hit[1])) for hit in oov_hits[k])) 87 | sys.stderr.write("{}: {} {:.2f}% {}\n".format(k, nb_candidates, percentage_candidates, nb_unique_candidates)) 88 | -------------------------------------------------------------------------------- 
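# (Context for the vocabulary module below: a Kaldi-style wordlist is plain text
#  with one "word index" pair per line, e.g. this invented fragment:
#      <eps> 0
#      <unk> 1
#      the 2
#      of 3
#  vocab_from_kaldi_wordlist() parses exactly this shape; build-lstm.py additionally
#  requires the indices to form a continuous 0..N-1 range.)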
/brnolm/language_models/vocab.py: -------------------------------------------------------------------------------- 1 | try: 2 | from collections.abc import Mapping 3 | except ImportError: 4 | from collections import Mapping 5 | 6 | import re 7 | 8 | 9 | class IndexGenerator(): 10 | def __init__(self, assigned): 11 | self.next_ = 0 12 | self.assigned_ = assigned 13 | 14 | def next(self): 15 | while self.next_ in self.assigned_: 16 | self.next_ += 1 17 | 18 | retval = self.next_ 19 | self.next_ += 1 20 | return retval 21 | 22 | 23 | class Vocabulary(Mapping): 24 | def __init__(self, unk_word, unk_index): 25 | self.w2i_ = {unk_word: unk_index} 26 | self.i2w_ = {unk_index: unk_word} 27 | self.unk_index_ = unk_index 28 | self.unk_word_ = unk_word 29 | self.ind_gen_ = IndexGenerator([unk_index]) 30 | 31 | def add_from_text(self, text): 32 | assert self.ind_gen_ 33 | 34 | words = text.split() 35 | for word in words: 36 | self.add_word(word) 37 | 38 | def add_word(self, word): 39 | if word not in self.w2i_: 40 | index = self.ind_gen_.next() 41 | self.w2i_[word] = index 42 | self.i2w_[index] = word 43 | else: 44 | pass # do not do anything for known words 45 | 46 | def is_continuous(self): 47 | return max(self.i2w_.keys()) == len(self) - 1 48 | 49 | def missing_indexes(self): 50 | return [i for i in range(len(self)) if i not in self.i2w_] 51 | 52 | def w2i(self, word): 53 | return self.w2i_.get(word, self.unk_index_) 54 | 55 | @property 56 | def unk_word(self): 57 | return self.unk_word_ 58 | 59 | @property 60 | def unk_ind(self): 61 | return self.unk_index_ 62 | 63 | def __getitem__(self, idx): 64 | return self.w2i(idx) 65 | 66 | def i2w(self, index): 67 | return self.i2w_[index] 68 | 69 | def __len__(self): 70 | return len(self.w2i_) 71 | 72 | def __iter__(self): 73 | return iter(self.w2i_) 74 | 75 | 76 | def vocab_from_kaldi_wordlist_base(f, unk_word, word_re, remove_quotes): 77 | d = {} 78 | line_re = re.compile(r'\s*(?P<word>' + word_re + r')\s+(?P<ind>[0-9]+)\s*\n?') 79 | for i, line in enumerate(f): 80 | m = line_re.fullmatch(line) 81 | 82 | if m is None: 83 | raise ValueError("Weird line {}: '{}'".format(i, line)) 84 | 85 | w = m.group('word') 86 | if remove_quotes: 87 | assert w[0] == w[-1] == "'" 88 | w = w[1:-1] 89 | i = int(m.group('ind')) 90 | assert i >= 0 91 | 92 | if w in d: 93 | raise ValueError(f'Attempt to redefine "{w}" to {i}, while it already has {d[w]} assigned') 94 | d[w] = i 95 | 96 | try: 97 | vocab = Vocabulary(unk_word, d[unk_word]) 98 | except KeyError: 99 | raise ValueError("Unk word {} not present in the kaldi wordlist!".format(unk_word)) 100 | vocab.ind_gen_ = None 101 | vocab.w2i_ = d 102 | for w in d: 103 | vocab.i2w_[d[w]] = w 104 | 105 | vocab.size_ = max(d.values()) + 1 106 | 107 | return vocab 108 | 109 | 110 | def vocab_from_kaldi_wordlist(f, unk_word='<unk>'): 111 | return vocab_from_kaldi_wordlist_base(f, unk_word, word_re=r'\S+', remove_quotes=False) 112 | 113 | 114 | def quoted_vocab_from_kaldi_wordlist(f, unk_word='<unk>'): 115 | return vocab_from_kaldi_wordlist_base(f, unk_word, word_re="'.+'", remove_quotes=True) 116 | -------------------------------------------------------------------------------- /brnolm/language_models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | r"""Inject some information about the relative or absolute position of the tokens 9
| in the sequence. The positional encodings have the same dimension as 10 | the embeddings, so that the two can be summed. Here, we use sine and cosine 11 | functions of different frequencies. 12 | .. math:: 13 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 14 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 15 | \text{where pos is the word position and i is the embed idx) 16 | Args: 17 | d_model: the embed dim (required). 18 | dropout: the dropout value (default=0.1). 19 | max_len: the max. length of the incoming sequence (default=5000). 20 | Examples: 21 | >>> pos_encoder = PositionalEncoding(d_model) 22 | """ 23 | 24 | def __init__(self, d_model, dropout=0.1, max_len=5000): 25 | super(PositionalEncoding, self).__init__() 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | pe = torch.zeros(max_len, d_model) 29 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 30 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 31 | pe[:, 0::2] = torch.sin(position * div_term) 32 | pe[:, 1::2] = torch.cos(position * div_term) 33 | pe = pe.unsqueeze(0).transpose(0, 1) 34 | self.register_buffer('pe', pe) 35 | 36 | def forward(self, x): 37 | r"""Inputs of forward function 38 | Args: 39 | x: the sequence fed to the positional encoder model (required). 40 | Shape: 41 | x: [sequence length, batch size, embed dim] 42 | output: [sequence length, batch size, embed dim] 43 | Examples: 44 | >>> output = pos_encoder(x) 45 | """ 46 | 47 | x = x + self.pe[:x.size(0), :] 48 | return self.dropout(x) 49 | 50 | 51 | def generate_square_subsequent_mask(sz): 52 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 53 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 54 | return mask 55 | 56 | 57 | class TransformerLM(nn.Module): 58 | '''Adapted from PyTorch-examples 59 | ''' 60 | def __init__(self, vocab_size, nb_heads, dim_res, dim_ff, nb_layers, dropout): 61 | super().__init__() 62 | self.src_mask = None 63 | self.encoder = nn.Embedding(vocab_size, dim_res) 64 | self.pos_encoder = PositionalEncoding(dim_res, dropout) 65 | 66 | layer_norm = nn.LayerNorm(dim_res) 67 | encoder_layers = TransformerEncoderLayer(dim_res, nb_heads, dim_ff, dropout) 68 | self.transformer_encoder = TransformerEncoder(encoder_layers, nb_layers, layer_norm) 69 | self.dim_res = dim_res 70 | 71 | self.init_weights() 72 | self.in_len = 1 73 | self.batch_first = True 74 | 75 | def init_hidden(self, batch_size): 76 | w = next(self.parameters()) 77 | return torch.zeros((batch_size, 1), dtype=w.dtype, device=w.device) 78 | 79 | def init_weights(self): 80 | initrange = 0.1 81 | self.encoder.weight.data.uniform_(-initrange, initrange) 82 | 83 | def forward(self, src, hidden=None): 84 | src = src.t() 85 | device = src.device 86 | if self.src_mask is None or self.src_mask.size(0) != len(src): 87 | mask = generate_square_subsequent_mask(len(src)).to(device) 88 | self.src_mask = mask 89 | 90 | src = self.encoder(src) * math.sqrt(self.dim_res) 91 | src = self.pos_encoder(src) 92 | output = self.transformer_encoder(src, self.src_mask) 93 | 94 | return output.permute(1, 0, 2), hidden 95 | -------------------------------------------------------------------------------- /scripts/rescoring/rescore-nbest-continuous.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import logging 5 | import torch 6 | 7 | from brnolm.rescoring.segment_scoring 
import SegmentScorer 8 | from safe_gpu.safe_gpu import GPUOwner 9 | 10 | import typing 11 | 12 | import brnolm.kaldi_itf 13 | 14 | 15 | def select_hidden_state_to_pass(hidden_states): 16 | return hidden_states['1'] 17 | 18 | 19 | def spk_sess(segment_name): 20 | return segment_name.split('-')[0].split('_') 21 | 22 | 23 | def main(args): 24 | logging.info(args) 25 | 26 | logging.info("reading model...") 27 | device = torch.device('cuda') if args.cuda else torch.device('cpu') 28 | lm = torch.load(args.model_from, map_location=device) 29 | 30 | lm.eval() 31 | 32 | curr_seg = '' 33 | segment_utts: typing.Dict[str, typing.Any] = {} 34 | 35 | custom_h0 = None 36 | nb_carry_overs = 0 37 | nb_new_hs = 0 38 | 39 | with open(args.in_filename) as in_f, open(args.out_filename, 'w') as out_f: 40 | scorer = SegmentScorer(lm, out_f) 41 | 42 | for line in in_f: 43 | fields = line.split() 44 | segment, trans_id = brnolm.kaldi_itf.split_nbest_key(fields[0]) 45 | 46 | words = fields[1:] 47 | 48 | if not curr_seg: 49 | curr_seg = segment 50 | 51 | if segment != curr_seg: 52 | result = scorer.process_segment(curr_seg, segment_utts, custom_h0) 53 | if args.carry_over == 'always': 54 | custom_h0 = select_hidden_state_to_pass(result.hidden_states) 55 | nb_carry_overs += 1 56 | elif args.carry_over == 'speaker': 57 | if spk_sess(segment) == spk_sess(curr_seg): 58 | custom_h0 = select_hidden_state_to_pass(result.hidden_states) 59 | nb_carry_overs += 1 60 | else: 61 | custom_h0 = None 62 | nb_new_hs += 1 63 | elif args.carry_over == 'never': 64 | custom_h0 = None 65 | nb_new_hs += 1 66 | else: 67 | raise ValueError(f'Unsupported carry over regime {args.carry_over}') 68 | for hyp_no, cost in result.scores.items(): 69 | out_f.write(f"{curr_seg}-{hyp_no} {cost}\n") 70 | 71 | curr_seg = segment 72 | segment_utts = {} 73 | 74 | segment_utts[trans_id] = words 75 | 76 | # Last segment: 77 | result = scorer.process_segment(curr_seg, segment_utts) 78 | for hyp_no, cost in result.scores.items(): 79 | out_f.write(f"{curr_seg}-{hyp_no} {cost}\n") 80 | 81 | logging.info(f'Hidden state was carried over {nb_carry_overs} times and reset {nb_new_hs} times') 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 86 | parser.add_argument('--carry-over', default='always', choices=['always', 'speaker', 'never'], 87 | help='When to use the previous hidden state') 88 | parser.add_argument('--cuda', action='store_true', 89 | help='use CUDA') 90 | parser.add_argument('--character-lm', action='store_true', 91 | help='Process strings by characters') 92 | parser.add_argument('--model-from', type=str, required=True, 93 | help='where to load the model from') 94 | parser.add_argument('in_filename', help='second output of nbest-to-linear, textual') 95 | parser.add_argument('out_filename', help='where to put the LM scores') 96 | args = parser.parse_args() 97 | 98 | logging.basicConfig(level=logging.DEBUG) 99 | 100 | if args.cuda: 101 | gpu_owner = GPUOwner() 102 | 103 | main(args) 104 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/flexible_pipeline.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class SequenceReadingHead: 6 | def __init__(self, seq, start=0): 7 | self.seq = seq 8 | self.pos = start 9 | 10 | def __next__(self): 11 | val = self.seq[self.pos] 12 | self.pos = (self.pos + 1) % len(self.seq) 13 | return val 14 | 15 | 16 | class 
FileReadingHead: 17 | def __init__(self, fn, pos, tokenizer, buffer_size=512): 18 | self.file = open(fn, 'rb') 19 | self.file.seek(pos) 20 | self.file.readline() # move to next full line 21 | self.tokenizer = tokenizer 22 | 23 | self.buffer = [] 24 | self.idx_in_buffer = 0 25 | self.target_buffer_size = buffer_size 26 | 27 | def __next__(self): 28 | assert self.idx_in_buffer <= len(self.buffer) 29 | 30 | if self.idx_in_buffer == len(self.buffer): 31 | self.refill_buffer() 32 | 33 | tok = self.buffer[self.idx_in_buffer] 34 | self.idx_in_buffer += 1 35 | 36 | return tok 37 | 38 | def refill_buffer(self): 39 | self.buffer.clear() 40 | while len(self.buffer) < self.target_buffer_size: 41 | line = self.file.readline().decode() 42 | if line == '': 43 | self.file.seek(0) 44 | self.buffer.extend(self.tokenizer(line)) 45 | 46 | self.idx_in_buffer = 0 47 | 48 | 49 | class StreamingCorruptor: 50 | def __init__(self, stream_provider, subs_rate, subs_range, del_rate, ins_rate, protected=[]): 51 | self.stream_provider = stream_provider 52 | self.sr = subs_rate 53 | self.subs_range = subs_range 54 | self.dr = del_rate 55 | self.ir = ins_rate 56 | self.protected = protected 57 | self.nb_nonprotected = self.nb_subs = self.nb_dels = self.nb_inss = 0 # corruption statistics for summary() 58 | self.to_be_input = next(self.stream_provider) 59 | self.to_be_target = next(self.stream_provider) 60 | 61 | def __next__(self): 62 | # corruption counters live on self (see __init__) so that summary() can report them 63 | 64 | 65 | 66 | 67 | if self.to_be_input in self.protected: 68 | x, t = self.to_be_input, self.to_be_target 69 | self._move_one_token() 70 | return x, t 71 | 72 | self.nb_nonprotected += 1 73 | 74 | roll = random.random() 75 | if roll < self.dr: 76 | self._move_one_token() 77 | self.nb_dels += 1 78 | return next(self) 79 | elif roll < self.dr + self.sr: 80 | x, t = random.randrange(self.subs_range), self.to_be_target 81 | self._move_one_token() 82 | self.nb_subs += 1 83 | return x, t 84 | elif roll < self.dr + self.sr + self.ir: 85 | x, t = random.randrange(self.subs_range), self.to_be_target 86 | self.nb_inss += 1 87 | return x, t 88 | else: 89 | x, t = self.to_be_input, self.to_be_target 90 | self._move_one_token() 91 | return x, t 92 | 93 | def _move_one_token(self): 94 | self.to_be_input = self.to_be_target 95 | self.to_be_target = next(self.stream_provider) 96 | 97 | def summary(self): 98 | print(f'proper {self.nb_nonprotected} | D: {100.0*self.nb_dels/self.nb_nonprotected:.2f} % ({self.nb_dels}) S: {100.0*self.nb_subs/self.nb_nonprotected:.2f} % ({self.nb_subs}) I: {100.0*self.nb_inss/self.nb_nonprotected:.2f} % ({self.nb_inss})') 99 | 100 | 101 | class BatchingSlicingIterator: 102 | def __init__(self, sources, seq_len): 103 | self.sources = sources 104 | self.seq_len = seq_len 105 | 106 | def __next__(self): 107 | samples = [] 108 | for _ in range(self.seq_len): 109 | samples.append(list(next(s) for s in self.sources)) 110 | 111 | inputs = torch.tensor([[s[0] for s in time_slice] for time_slice in samples]) 112 | targets = torch.tensor([[s[1] for s in time_slice] for time_slice in samples]) 113 | 114 | return inputs.permute(1, 0), targets.permute(1, 0).contiguous() 115 | -------------------------------------------------------------------------------- /scripts/oov-clustering/process-hybrid-paths.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | 5 | from brnolm.language_models.vocab import vocab_from_kaldi_wordlist 6 | 7 | ST_WORDS = 0 8 | ST_OOV_INTEREST = 1 9 | ST_OOV_OTHER = 2 10 | 11 | 12 | def words_from_idx(idx_list): 13 | transcript = [] 14 | state = ST_WORDS 15 | for idx
in idxes: 16 | if state == ST_WORDS: 17 | if idx == oov_start_idx + args.interest_constant: 18 | state = ST_OOV_INTEREST 19 | elif idx == oov_start_idx: 20 | state = ST_OOV_OTHER 21 | elif idx == oov_end_idx + args.interest_constant: 22 | raise ValueError("Unacceptable end of OOV-OI within WORDS ({}, key {})".format(line_no, key)) 23 | elif idx == oov_end_idx: 24 | raise ValueError("Unacceptable end of OOV-NI within WORDS ({}, key {})".format(line_no, key)) 25 | else: 26 | transcript.append(decoder_vocabulary.i2w(idx)) 27 | elif state == ST_OOV_INTEREST: 28 | if idx == oov_end_idx + args.interest_constant: 29 | transcript.append(args.unk_oi) 30 | state = ST_WORDS 31 | elif idx == oov_end_idx: 32 | raise ValueError("Unacceptable end of OOV-NI within OOV-OI ({}, key {})".format(line_no, key)) 33 | elif idx == oov_start_idx + args.interest_constant: 34 | raise ValueError("Unacceptable start of OOV-OI within OOV-OI ({}, key {})".format(line_no, key)) 35 | elif idx == oov_start_idx: 36 | raise ValueError("Unacceptable start of OOV-NI within OOV-OI ({}, key {})".format(line_no, key)) 37 | else: 38 | pass 39 | elif state == ST_OOV_OTHER: 40 | if idx == oov_end_idx: 41 | transcript.append(args.unk) 42 | state = ST_WORDS 43 | elif idx == oov_end_idx + args.interest_constant: 44 | raise ValueError("Unacceptable end of OOV-OI within OOV-NI ({}, key {})".format(line_no, key)) 45 | elif idx == oov_start_idx + args.interest_constant: 46 | raise ValueError("Unacceptable start of OOV-OI within OOV-NI ({}, key {})".format(line_no, key)) 47 | elif idx == oov_start_idx: 48 | raise ValueError("Unacceptable start of OOV-NI within OOV-NI ({}, key {})".format(line_no, key)) 49 | else: 50 | pass 51 | else: 52 | raise RuntimeError("got into an impossible state {}".format(state)) 53 | 54 | if state == ST_OOV_INTEREST: 55 | raise ValueError("Incomplete OOV of interest on line '{}'".format(idx_list)) 56 | elif state == ST_OOV_OTHER: 57 | raise ValueError("Incomplete OOV (not of interest) on line '{}'".format(idx_list)) 58 | 59 | return transcript 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--unk', default="") 65 | parser.add_argument('--unk-oi', default="") 66 | parser.add_argument('--oov-start', required=True) 67 | parser.add_argument('--oov-end', required=True) 68 | parser.add_argument('--interest-constant', type=int, required=True) 69 | parser.add_argument('--decoder-wordlist', required=True) 70 | args = parser.parse_args() 71 | 72 | with open(args.decoder_wordlist) as f: 73 | decoder_vocabulary = vocab_from_kaldi_wordlist(f, unk_word=args.unk) 74 | 75 | oov_start_idx = decoder_vocabulary[args.oov_start] 76 | oov_end_idx = decoder_vocabulary[args.oov_end] 77 | 78 | for line_no, line in enumerate(sys.stdin): 79 | fields = line.split() 80 | key = fields[0] 81 | idxes = [int(idx) for idx in fields[1:]] 82 | 83 | try: 84 | transcript = words_from_idx(idxes) 85 | except ValueError: 86 | sys.stderr.write("WARNING: there was a problem with input line {} (counting from 0)\n".format(line_no)) 87 | continue 88 | 89 | sys.stdout.write("{} {}\n".format(key, " ".join(str(w) for w in transcript))) 90 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/multistream.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LineTooLongError(Exception): 5 | pass 6 | 7 | 8 | def batchify(data, bsz, cuda): 9 | """ For simple rearranging of 'single sentence' 
data. 10 | """ 11 | # Work out how cleanly we can divide the dataset into bsz parts. 12 | nbatch = data.size(0) // bsz 13 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 14 | data = data.narrow(0, 0, nbatch * bsz) 15 | # Evenly divide the data across the bsz batches. 16 | data = data.view(bsz, -1).t().contiguous() 17 | if cuda: 18 | data = data.cuda() 19 | return data 20 | 21 | 22 | class Batcher: 23 | """Groups sequences in a list into batches 24 | """ 25 | def __init__(self, samples, max_batch_size=None, max_total_len=None): 26 | self.max_batch_size = max_batch_size 27 | self.max_total_len = max_total_len 28 | self.samples = samples 29 | 30 | def _batch_size_ok(self, i, j): 31 | if not self.max_batch_size: 32 | return True 33 | 34 | return (j-i) <= self.max_batch_size 35 | 36 | def _total_len_ok(self, i, j): 37 | if not self.max_total_len: 38 | return True 39 | 40 | return max((len(t) for t in self.samples[i:j]), default=0) * (j-i) <= self.max_total_len 41 | 42 | def __iter__(self): 43 | i = 0 44 | while i < len(self.samples): 45 | j = i 46 | while j <= len(self.samples) and self._batch_size_ok(i, j) and self._total_len_ok(i, j): 47 | j += 1 48 | j -= 1 49 | 50 | if i == j: 51 | raise LineTooLongError(f'Failed to construct a batch on line {i} (zero-based)') 52 | 53 | batch = self.samples[i:j] 54 | assert len(batch) > 0 55 | assert sum(len(s) for s in batch) > 0 56 | 57 | yield batch 58 | i = j 59 | 60 | 61 | class BatchBuilder(): 62 | def __init__(self, streams, max_batch_size, discard_h=True): 63 | """ For complex combination of different lenghts sources. 64 | """ 65 | self._streams = streams 66 | 67 | if max_batch_size <= 0: 68 | raise ValueError("BatchBuilder must be constructed" 69 | "with a positive batch size, (got {})".format(max_batch_size) 70 | ) 71 | self._max_bsz = max_batch_size 72 | self._discard_h = discard_h 73 | 74 | def __iter__(self): 75 | streams = [iter(s) for s in self._streams] 76 | active_streams = [] 77 | reserve_streams = streams 78 | 79 | while True: 80 | batch = [] 81 | streams_continued = [] 82 | streams_ended = [] 83 | for i, s in enumerate(active_streams): 84 | try: 85 | batch.append(next(s)) 86 | streams_continued.append(i) 87 | except StopIteration: 88 | streams_ended.append(i) 89 | 90 | active_streams = [active_streams[i] for i in streams_continued] 91 | 92 | # refill the batch (of active streams) 93 | while len(reserve_streams) > 0: 94 | if len(batch) == self._max_bsz: 95 | break 96 | 97 | stream = reserve_streams[0] 98 | del reserve_streams[0] 99 | try: 100 | batch.append(next(stream)) 101 | active_streams.append(stream) 102 | except StopIteration: 103 | pass 104 | 105 | if len(batch) == 0: 106 | return 107 | 108 | if self._discard_h: 109 | hs_passed_on = streams_continued 110 | else: 111 | hs_passed_on = (streams_continued + streams_ended)[:len(batch)] 112 | 113 | parts = zip(*batch) 114 | parts = [torch.stack(part) for part in parts] 115 | yield tuple(parts) + (torch.LongTensor(hs_passed_on), ) 116 | -------------------------------------------------------------------------------- /brnolm/runtime/runtime_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from brnolm.data_pipeline import split_corpus_dataset 5 | 6 | import sys 7 | import math 8 | 9 | 10 | class CudaStream(): 11 | def __init__(self, source): 12 | self._source = source 13 | 14 | def __iter__(self): 15 | for batch in self._source: 16 | yield tuple(x.cuda() for x in batch) 17 | 18 | 19 
| class TransposeWrapper: 20 | def __init__(self, stream): 21 | self._stream = stream 22 | 23 | def __iter__(self): 24 | for a_tuple in self._stream: 25 | yield tuple(x.t().contiguous() for x in a_tuple) 26 | 27 | def __len__(self): 28 | return len(self._stream) 29 | 30 | 31 | def repackage_hidden(h): 32 | """Detaches a tuple of tensors them from their history.""" 33 | if isinstance(h, torch.Tensor): 34 | return h.detach() 35 | else: 36 | return tuple(repackage_hidden(v) for v in h) 37 | 38 | 39 | def filelist_to_tokenized_splits(filelist_filename, vocab, bptt, wrapper=split_corpus_dataset.TokenizedSplit): 40 | filenames = filenames_file_to_filenames(filelist_filename) 41 | tss = [] 42 | for filename in filenames: 43 | with open(filename, 'r') as f: 44 | tss.append(wrapper(f, vocab, bptt)) 45 | 46 | return tss 47 | 48 | 49 | def filelist_to_objects(filelist_filename, action): 50 | filenames = filenames_file_to_filenames(filelist_filename) 51 | objects = [] 52 | for filename in filenames: 53 | with open(filename, 'r') as f: 54 | objects.append(action(f)) 55 | 56 | return objects 57 | 58 | 59 | def filenames_file_to_filenames(filelist_filename): 60 | with open(filelist_filename) as filelist: 61 | filenames = filelist.read().split() 62 | 63 | return filenames 64 | 65 | 66 | def init_seeds(seed, cuda): 67 | random.seed(seed) 68 | torch.manual_seed(seed) 69 | if cuda and torch.cuda.is_available(): 70 | torch.cuda.manual_seed(seed) 71 | 72 | 73 | class BatchFilter: 74 | def __init__(self, data, batch_size, bptt, min_batch_size): 75 | self._data = data 76 | self._batch_size = batch_size 77 | self._bptt = bptt 78 | self._min_batch_size = min_batch_size 79 | 80 | self._nb_skipped_updates = 0 81 | self._nb_skipped_words = 0 82 | self._nb_skipped_seqs = 0 # accumulates size of skipped batches 83 | 84 | def __iter__(self): 85 | for batch in self._data: 86 | X = batch[0] 87 | if X.size(0) >= self._min_batch_size: 88 | yield batch 89 | else: 90 | self._nb_skipped_updates += 1 91 | self._nb_skipped_words += X.size(0) * X.size(1) 92 | self._nb_skipped_seqs += X.size(0) 93 | 94 | def report(self): 95 | if self._nb_skipped_updates > 0: 96 | sys.stderr.write( 97 | "WARNING: due to skipping, a total of {} updates was skipped," 98 | " containing {} words. Avg batch size {}. 
Equal to {} full batches" 99 | "\n".format( 100 | self._nb_skipped_updates, 101 | self._nb_skipped_words, 102 | self._nb_skipped_seqs/self._nb_skipped_updates, 103 | self._nb_skipped_words/(self._batch_size*self._bptt) 104 | ) 105 | ) 106 | 107 | 108 | def epoch_summary(epoch_no, nb_updates, elapsed_time, loss): 109 | delim_line = '-' * 89 + '\n' 110 | 111 | epoch_stmt = 'end of epoch {:3d}'.format(epoch_no) 112 | updates_stmt = '# updates: {}'.format(nb_updates) 113 | time_stmt = 'time: {:5.2f}s'.format(elapsed_time) 114 | loss_stmt = 'valid loss {:5.2f}'.format(loss) 115 | ppl_stmt = 'valid ppl {:8.2f}'.format(math.exp(loss)) 116 | values_line = '| {} | {} | {} | {} | {}\n'.format( 117 | epoch_stmt, updates_stmt, time_stmt, loss_stmt, ppl_stmt 118 | ) 119 | 120 | return delim_line + values_line + delim_line 121 | -------------------------------------------------------------------------------- /test/test_smm_ivec_extractor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from .common import TestCase 3 | import os 4 | import sys 5 | 6 | try: 7 | import brnolm.smm_itf.smm_ivec_extractor as smm_ivec_extractor 8 | except ImportError: 9 | sys.stderr.write('Failed to import SMM implementation\n') 10 | from brnolm.language_models.vocab import Vocabulary 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | import numpy as np 16 | 17 | from sklearn.feature_extraction.text import CountVectorizer 18 | 19 | 20 | class DummySMM(nn.Module): 21 | def __init__(self, ivec_dim): 22 | self.W = Variable(torch.zeros(ivec_dim, 1), requires_grad=True) 23 | self.T = Variable(torch.zeros(ivec_dim, 20), requires_grad=True) 24 | 25 | 26 | @unittest.skipIf(os.environ.get('TEST_SMM') != 'yes', "For SMM tests, set TEST_SMM='yes'") 27 | class IvecExtractorTests(TestCase): 28 | def setUp(self): 29 | self.documents_six = ["text consisting of SIX different words", "text"] 30 | 31 | # note that following also contains puctuation marks, to be ignored by CountVectorizer 32 | self.documents_seven = ["text consisting of SEVEN different words", ", seventh ."] 33 | 34 | def build_neededs(self, documents): 35 | smm = DummySMM(ivec_dim=4) 36 | cvect = CountVectorizer(documents, strip_accents='ascii', analyzer='word') 37 | cvect.fit(documents) 38 | vocab = cvect.get_feature_names() 39 | self.cvect = CountVectorizer(documents, strip_accents='ascii', analyzer='word', vocabulary=vocab) 40 | 41 | self.extractor = smm_ivec_extractor.IvecExtractor(smm, nb_iters=10, lr=0.1, tokenizer=self.cvect) 42 | self.vocab = Vocabulary(unk_word="", unk_index=0) 43 | self.vocab.add_from_text(" ".join(documents)) 44 | 45 | def test_one_empty_bow(self): 46 | self.build_neededs(self.documents_six) 47 | 48 | ivecs = self.extractor.zero_bows(1) 49 | expectation = torch.zeros(1, 6) 50 | self.assertEqual(ivecs, expectation) 51 | 52 | def test_two_empty_bows(self): 53 | self.build_neededs(self.documents_six) 54 | ivecs = self.extractor.zero_bows(2) 55 | expectation = torch.zeros(2, 6) 56 | self.assertEqual(ivecs, expectation) 57 | 58 | def test_build_translator_single_word_translation(self): 59 | self.build_neededs(self.documents_six) 60 | translator = self.extractor.build_translator(self.vocab) 61 | word = "SIX" 62 | lm_word = torch.LongTensor([self.vocab[word]]) 63 | cv_word = torch.from_numpy(self.cvect.transform([word]).A.astype(np.float32)).squeeze() 64 | self.assertEqual(translator(lm_word), cv_word) 65 | 66 | def 
test_build_translator_two_word_translation(self): 67 | self.build_neededs(self.documents_six) 68 | translator = self.extractor.build_translator(self.vocab) 69 | words = "SIX words" 70 | lm_words = torch.LongTensor([self.vocab[w] for w in words.split()]) 71 | cv_words = torch.from_numpy(self.cvect.transform([words]).A.astype(np.float32)).squeeze() 72 | self.assertEqual(translator(lm_words), cv_words) 73 | 74 | def test_build_translator_two_two_word_translations(self): 75 | self.build_neededs(self.documents_six) 76 | translator = self.extractor.build_translator(self.vocab) 77 | words = ["SIX words", "of text"] 78 | lm_words = torch.LongTensor([[self.vocab[w] for w in seq.split()] for seq in words]) 79 | cv_words = torch.from_numpy(self.cvect.transform(words).A.astype(np.float32)).squeeze() 80 | self.assertEqual(translator(lm_words), cv_words) 81 | 82 | def test_build_translator_two_two_word_translations_different_vocab(self): 83 | self.build_neededs(self.documents_seven) 84 | translator = self.extractor.build_translator(self.vocab) 85 | words = ["SIX words", "of text"] 86 | lm_words = torch.LongTensor([[self.vocab[w] for w in seq.split()] for seq in words]) 87 | cv_words = torch.from_numpy(self.cvect.transform(words).A.astype(np.float32)).squeeze() 88 | self.assertEqual(translator(lm_words), cv_words) 89 | -------------------------------------------------------------------------------- /brnolm/data_pipeline/split_corpus_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from brnolm.data_pipeline.temporal_splitting import TemporalSplits 4 | 5 | 6 | class TokenizedSplitFFBase(): 7 | def __init__(self, f, vocab, temporal_split_builder): 8 | """ 9 | Args: 10 | f (file): File with a document. 11 | vocab (Vocabulary): Vocabulary for translation word -> index 12 | """ 13 | sentence = f.read() 14 | self._words = sentence.split() 15 | self._tokens = torch.LongTensor([vocab[w] for w in self._words]) 16 | 17 | self._temp_splits = temporal_split_builder(self._tokens) 18 | 19 | def __iter__(self): 20 | for x, t in self._temp_splits: 21 | yield x, t 22 | 23 | def __len__(self): 24 | return len(self._temp_splits) 25 | 26 | def input_words(self): 27 | for lend, rend in self._temp_splits.ranges(): 28 | yield " ".join(self._words[lend:rend]) 29 | 30 | 31 | class TokenizedSplit(TokenizedSplitFFBase): 32 | def __init__(self, f, vocab, unroll_length): 33 | """ 34 | Args: 35 | f (file): File with a document. 36 | vocab (Vocabulary): Vocabulary for translation word -> index 37 | """ 38 | ts_builder = lambda seq: TemporalSplits(seq, nb_inputs_necessary=1, nb_targets_parallel=unroll_length) 39 | super().__init__(f, vocab, ts_builder) 40 | 41 | 42 | class TokenizedSplitSingleTarget(TokenizedSplitFFBase): 43 | def __init__(self, f, vocab, unroll_length): 44 | """ 45 | Args: 46 | f (file): File with a document. 47 | vocab (Vocabulary): Vocabulary for translation word -> index 48 | """ 49 | ts_builder = lambda seq: TemporalSplits(seq, nb_inputs_necessary=unroll_length, nb_targets_parallel=1) 50 | super().__init__(f, vocab, ts_builder) 51 | 52 | 53 | class TokenizedSplitFFMultiTarget(TokenizedSplitFFBase): 54 | def __init__(self, f, vocab, hist_len, nb_targets_parallel): 55 | """ 56 | Args: 57 | f (file): File with a document. 
58 | vocab (Vocabulary): Vocabulary for translation word -> index 59 | """ 60 | ts_builder = lambda seq: TemporalSplits(seq, nb_inputs_necessary=hist_len, nb_targets_parallel=nb_targets_parallel) 61 | super().__init__(f, vocab, ts_builder) 62 | 63 | 64 | class DomainAdaptationSplitFFBase: 65 | def __init__(self, f, vocab, end_portion, ts_builder): 66 | sentence = f.read() 67 | words = sentence.split() 68 | 69 | nb_domain_words = int(len(words)*end_portion-0.01) 70 | 71 | self._tokens = torch.LongTensor([vocab[w] for w in words[:-nb_domain_words]]) 72 | self._domain_string = " ".join(words[len(words)-nb_domain_words:]) 73 | 74 | self._temp_splitter = ts_builder(self._tokens) 75 | 76 | def __iter__(self): 77 | for x, t in self._temp_splitter: 78 | yield x, t 79 | 80 | def __len__(self): 81 | return len(self._temp_splitter) 82 | 83 | def input_words(self): 84 | return [self._domain_string] 85 | 86 | 87 | class DomainAdaptationSplitFFMultiTarget(DomainAdaptationSplitFFBase): 88 | def __init__(self, f, vocab, hist_len, nb_targets_parallel, end_portion): 89 | """ 90 | Args: 91 | f (file): File with a document. 92 | vocab (Vocabulary): Vocabulary for translation word -> index 93 | """ 94 | ts_builder = lambda seq: TemporalSplits(seq, nb_inputs_necessary=hist_len, nb_targets_parallel=nb_targets_parallel) 95 | super().__init__(f, vocab, end_portion, ts_builder) 96 | 97 | 98 | class DomainAdaptationSplit(DomainAdaptationSplitFFBase): 99 | def __init__(self, f, vocab, unroll_length, end_portion): 100 | """ 101 | Args: 102 | f (file): File with a document. 103 | vocab (Vocabulary): Vocabulary for translation word -> index 104 | """ 105 | 106 | ts_builder = lambda seq: TemporalSplits(seq, nb_inputs_necessary=unroll_length, nb_targets_parallel=1) 107 | super().__init__(f, vocab, end_portion, ts_builder) 108 | -------------------------------------------------------------------------------- /test/test_det.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from brnolm.oov_clustering.det import det_points_from_score_tg 4 | from brnolm.oov_clustering.det import subsample_list 5 | from brnolm.oov_clustering.det import area_under_curve, eer 6 | 7 | 8 | class DetPointTests(TestCase): 9 | def test_trivial(self): 10 | score_tg = [ 11 | (0.0, 0), 12 | (1.0, 1) 13 | ] 14 | 15 | det_points = [ 16 | [0.5, 0.0], 17 | [0.0, 0.0], 18 | [0.0, 0.5], 19 | ] 20 | 21 | self.assertEqual(det_points_from_score_tg(score_tg)[0], det_points) 22 | 23 | def test_permutation(self): 24 | score_tg = [ 25 | (1.0, 1), 26 | (0.0, 0), 27 | ] 28 | 29 | det_points = [ 30 | [0.5, 0.0], 31 | [0.0, 0.0], 32 | [0.0, 0.5], 33 | ] 34 | 35 | self.assertEqual(det_points_from_score_tg(score_tg)[0], det_points) 36 | 37 | def test_simple_bad_system(self): 38 | score_tg = [ 39 | (1.0, 0), 40 | (0.0, 1), 41 | ] 42 | 43 | det_points = [ 44 | [0.5, 0.0], 45 | [0.5, 0.5], 46 | [0.0, 0.5], 47 | ] 48 | 49 | self.assertEqual(det_points_from_score_tg(score_tg)[0], det_points) 50 | 51 | def test_multiple_points(self): 52 | score_tg = [ 53 | (0.0, 0), 54 | (0.1, 1), 55 | (0.2, 0), 56 | (0.3, 1), 57 | (0.4, 1), 58 | ] 59 | 60 | det_points = [ 61 | [0.6, 0.0], 62 | [0.4, 0.0], 63 | [0.2, 0.0], 64 | [0.2, 0.2], 65 | [0.0, 0.2], 66 | [0.0, 0.4], 67 | ] 68 | 69 | self.assertEqual(det_points_from_score_tg(score_tg)[0], det_points) 70 | 71 | 72 | class ListSubsamplingTests(TestCase): 73 | def test_trivial(self): 74 | self.assertEqual( 75 | subsample_list([1, 2], 2), 76 | [1, 2] 77 | ) 78 | 79 | 
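    # Judging from the cases below, subsample_list(xs, n) keeps the first and last elements
    # and picks the remaining n-2 items roughly evenly from the middle of the list,
    # e.g. subsample_list([1, 2, 3, 4, 5], 3) -> [1, 3, 5] (illustrative reading, not a spec).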
def test_halving_longer(self): 80 | self.assertEqual( 81 | subsample_list([1, 2, 3, 4], 2), 82 | [1, 4] 83 | ) 84 | 85 | def test_halving_uneven(self): 86 | self.assertEqual( 87 | subsample_list([1, 2, 3, 4, 5], 2), 88 | [1, 5] 89 | ) 90 | 91 | def test_halving_uneven_2(self): 92 | self.assertEqual( 93 | subsample_list([1, 2, 3, 4, 5], 3), 94 | [1, 3, 5] 95 | ) 96 | 97 | def test_every_third(self): 98 | self.assertEqual( 99 | subsample_list([1, 2, 3, 4, 5, 6], 2), 100 | [1, 6] 101 | ) 102 | 103 | def test_include_most(self): 104 | self.assertIn( 105 | subsample_list([1, 2, 3, 4, 5, 6], 5), 106 | [ 107 | [1, 3, 4, 5, 6], 108 | [1, 2, 4, 5, 6], 109 | [1, 2, 3, 5, 6], 110 | [1, 2, 3, 4, 6], 111 | ] 112 | ) 113 | 114 | def test_include_4_of_6(self): 115 | self.assertIn( 116 | subsample_list([1, 2, 3, 4, 5, 6], 4), 117 | [ 118 | [1] + middle + [6] for middle in [ 119 | [2, 3], [2, 4], [2, 5], 120 | [3, 4], [3, 5], 121 | [4, 5], 122 | ] 123 | ] 124 | ) 125 | 126 | 127 | class AreaComputationTests(TestCase): 128 | def test_trivial(self): 129 | self.assertAlmostEqual(area_under_curve([0.0, 1.0], [1.0, 0.0]), 0.5) 130 | 131 | def test_breakpoint(self): 132 | self.assertAlmostEqual(area_under_curve([0.0, 0.1, 1.0], [1.0, 0.1, 0.0]), 0.1*0.1 + 2*0.1*0.9/2) 133 | 134 | def test_end_addition(self): 135 | self.assertAlmostEqual(area_under_curve([0.1], [0.1]), 0.1*0.1 + 2*0.1*0.9/2) 136 | 137 | 138 | class EerComputationTests(TestCase): 139 | def test_trivial(self): 140 | self.assertAlmostEqual(eer([0.0, 1.0], [1.0, 0.0]), 0.5) 141 | 142 | def test_skewed(self): 143 | self.assertAlmostEqual(eer([0.0, 1.0], [0.5, 0.0]), 1.0/3.0) 144 | 145 | def test_exact_hit(self): 146 | self.assertAlmostEqual(eer([0.0, 0.1, 1.0], [1.0, 0.1, 0.0]), 0.1) 147 | -------------------------------------------------------------------------------- /brnolm/runtime/loggers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import math 4 | 5 | import torch 6 | 7 | 8 | class BaseLogger(): 9 | def __init__(self, report_period, output_file=sys.stdout): 10 | self._start_time = time.time() 11 | self._nb_logs = 0 12 | self._report_period = report_period 13 | self._of = output_file 14 | self._construction_time = time.time() 15 | 16 | def log(self, *args): 17 | self._log(*args) 18 | self._nb_logs += 1 19 | 20 | if self._nb_logs % self._report_period == 0: 21 | self._flush() 22 | self._reset() 23 | self._start_time = time.time() 24 | 25 | def nb_updates(self): 26 | return self._nb_logs 27 | 28 | def time_since_creation(self): 29 | return time.time() - self._construction_time 30 | 31 | def _flush(self): 32 | pass 33 | 34 | def _reset(self): 35 | pass 36 | 37 | def _log(self, *args): 38 | pass 39 | 40 | 41 | class InfinityLogger(BaseLogger): 42 | def __init__(self, epoch, report_period, lr, output_file=sys.stdout): 43 | super().__init__(report_period, output_file) 44 | self._running_loss = 0.0 45 | self._epoch = epoch 46 | self._lr = lr 47 | 48 | def _log(self, loss, lr): 49 | self._running_loss += loss 50 | self._lr = lr 51 | 52 | def _flush(self): 53 | ms_per_log = (time.time() - self._start_time) * 1000 / self._report_period 54 | cur_loss = (self._running_loss / self._report_period).item() 55 | line = f'| epoch {self._epoch:3d} | {self._nb_logs:5d} batches done | lr {self._lr:.3e} | ms/batch {ms_per_log:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}\n' 56 | self._of.write(line) 57 | 58 | def _reset(self): 59 | self._running_loss = 0.0 60 | 61 | 62 | 63 | 
class GradLogger(BaseLogger): 64 | def __init__(self, report_period, named_params, output_file=sys.stdout): 65 | super().__init__(report_period, output_file) 66 | self._named_params = list(named_params) 67 | 68 | self._grads = {} 69 | for name, param in self._named_params: 70 | self._grads[name] = [] 71 | 72 | self._of.write("{}\n".format(" ".join(self._grads))) 73 | 74 | def _log(self): 75 | for name, param in self._named_params: 76 | self._grads[name].append(param.grad.abs().mean()) 77 | 78 | def _flush(self): 79 | grad_mavs = [] 80 | for name in self._grads: 81 | all_grads = torch.stack(self._grads[name]) 82 | grad_mavs.append(all_grads.mean().data.item()) 83 | 84 | fmt_string = " ".join("{:.7f}" for _ in grad_mavs) + "\n" 85 | line = fmt_string.format(*grad_mavs) 86 | self._of.write(line) 87 | 88 | def _reset(self): 89 | for name in self._grads: 90 | self._grads[name] = [] 91 | 92 | 93 | class ProgressLogger(): 94 | def __init__(self, epoch, report_period, lr, nb_updates, output_file=sys.stdout): 95 | self._start_time = time.time() 96 | self._nb_logs = 0 97 | self._running_loss = 0.0 98 | self._epoch = epoch 99 | self._report_period = report_period 100 | self._of = output_file 101 | self._lr = lr 102 | self._construction_time = time.time() 103 | self._nb_updates = nb_updates 104 | 105 | def log(self, loss): 106 | self._running_loss += loss 107 | self._nb_logs += 1 108 | 109 | if self._nb_logs % self._report_period == 0: 110 | self._flush() 111 | self._reset() 112 | 113 | def time_since_creation(self): 114 | return time.time() - self._construction_time 115 | 116 | def nb_updates(self): 117 | return self._nb_updates 118 | 119 | def _flush(self): 120 | ms_per_log = (time.time() - self._start_time) * 1000 / self._report_period 121 | cur_loss = (self._running_loss / self._report_period).item() 122 | fmt_string = '| epoch {:3d} | {:5d}/{:5d} batches | lr {:.3e} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}\n' 123 | line = fmt_string.format( 124 | self._epoch, self._nb_logs, self._nb_updates, self._lr, 125 | ms_per_log, cur_loss, math.exp(cur_loss) 126 | ) 127 | self._of.write(line) 128 | 129 | def _reset(self): 130 | self._running_loss = 0.0 131 | self._start_time = time.time() 132 | 133 | 134 | class NoneLogger(): 135 | def log(self, *args): 136 | pass 137 | -------------------------------------------------------------------------------- /brnolm/runtime/runtime_multifile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .runtime_utils import repackage_hidden 4 | from .tensor_reorganization import TensorReorganizer 5 | 6 | 7 | def prepare_inputs(inputs, do_transpose, use_ivecs, custom_batches): 8 | X = inputs[0] 9 | batch_size = X.size(0) 10 | if do_transpose: 11 | X = X.t() 12 | 13 | targets = inputs[1] 14 | if do_transpose: 15 | targets = targets.t().contiguous() 16 | 17 | if use_ivecs: 18 | ivecs = inputs[2] 19 | else: 20 | ivecs = None 21 | 22 | if custom_batches: 23 | mask = inputs[-1] # 3 24 | else: 25 | mask = None 26 | 27 | return X, targets, ivecs, mask, batch_size 28 | 29 | 30 | def evaluate_(lm, data_source, use_ivecs, custom_batches): 31 | lm.eval() 32 | 33 | total_loss = 0.0 34 | total_timesteps = 0 35 | 36 | if custom_batches: 37 | hs_reorganizer = TensorReorganizer(lm.model.init_hidden) 38 | 39 | hidden = None 40 | do_transpose = not lm.model.batch_first 41 | 42 | for inputs in data_source: 43 | X, targets, ivecs, mask, batch_size = prepare_inputs( 44 | inputs, 45 | do_transpose, use_ivecs, custom_batches 46 | 
) 47 | 48 | if hidden is None: 49 | hidden = lm.model.init_hidden(batch_size) 50 | 51 | if custom_batches: 52 | hidden = hs_reorganizer(hidden, mask, batch_size) 53 | 54 | hidden = repackage_hidden(hidden) 55 | 56 | if use_ivecs: 57 | output, hidden = lm.model(X, hidden, ivecs) 58 | else: 59 | output, hidden = lm.model(X, hidden) 60 | 61 | loss, nb_words = lm.decoder.neg_log_prob(output, targets) 62 | total_loss += loss.data 63 | total_timesteps += nb_words 64 | 65 | return total_loss.item() / total_timesteps 66 | 67 | 68 | def evaluate(model, data_source, use_ivecs): 69 | return evaluate_( 70 | model, data_source, 71 | use_ivecs, custom_batches=True 72 | ) 73 | 74 | 75 | def evaluate_no_transpose(model, data_source, use_ivecs): 76 | return evaluate_( 77 | model, data_source, 78 | use_ivecs, custom_batches=False 79 | ) 80 | 81 | 82 | def train_(lm, data, optim, logger, clip, use_ivecs, custom_batches, tb_logger=None): 83 | lm.train() 84 | 85 | if custom_batches: 86 | hs_reorganizer = TensorReorganizer(lm.model.init_hidden) 87 | 88 | hidden = None 89 | do_transpose = not lm.model.batch_first 90 | 91 | for inputs in data: 92 | X, targets, ivecs, mask, batch_size = prepare_inputs( 93 | inputs, 94 | do_transpose, use_ivecs, custom_batches 95 | ) 96 | 97 | if hidden is None: 98 | hidden = lm.model.init_hidden(batch_size) 99 | 100 | if custom_batches: 101 | hidden = hs_reorganizer(hidden, mask, batch_size) 102 | hidden = repackage_hidden(hidden) 103 | 104 | if use_ivecs: 105 | output, hidden = lm.model(X, hidden, ivecs) 106 | else: 107 | output, hidden = lm.model(X, hidden) 108 | loss, nb_words = lm.decoder.neg_log_prob(output, targets) 109 | loss /= nb_words 110 | 111 | optim.zero_grad() 112 | loss.backward() 113 | torch.nn.utils.clip_grad_norm(lm.parameters(), clip) 114 | 115 | optim.step() 116 | logger.log(loss.data) 117 | 118 | if tb_logger is not None: 119 | tb_logger.next_step() 120 | info = { 121 | 'loss/train': loss.item(), 122 | 'ppl/train': loss.exp().item(), 123 | } 124 | for tag, value in info.items(): 125 | tb_logger.scalar_summary(tag, value) 126 | tb_logger.hierarchical_scalar_summary(tag.split('/')[0], tag.split('/')[0], value, enforce=True) 127 | 128 | for tag, value in lm.named_parameters(): 129 | tag = tag.replace('.', '/') 130 | tb_logger.histo_summary(tag, value.data.cpu().numpy()) 131 | tb_logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy()) 132 | 133 | 134 | def train(model, data, optim, logger, clip, use_ivecs): 135 | train_( 136 | model, data, optim, logger, clip, 137 | use_ivecs, custom_batches=True 138 | ) 139 | 140 | 141 | def train_no_transpose(model, data, optim, logger, clip, use_ivecs): 142 | train_( 143 | lm, data, optim, logger, clip, 144 | use_ivecs, custom_batches=False 145 | ) 146 | -------------------------------------------------------------------------------- /brnolm/analyze-ivec-changes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import io 4 | import math 5 | import sys 6 | import torch 7 | 8 | import split_corpus_dataset 9 | import ivec_appenders 10 | import smm_ivec_extractor 11 | 12 | from runtime_utils import filenames_file_to_filenames 13 | 14 | class DummyDict: 15 | def __getitem__(self, index): 16 | return 0 17 | 18 | 19 | def euclidean_distance(a, b): 20 | return (a-b).pow(2).sum(dim=-1).pow(0.5) 21 | 22 | 23 | def length(a): 24 | return a.pow(2).sum(dim=-1).pow(0.5) 25 | 26 | 27 | def euclidean_distance(a, b): 28 | return 
(a-b).pow(2).sum(dim=-1).pow(0.5) 29 | 30 | 31 | def cosine_similarity(a, b): 32 | return (a*b).sum(dim=-1) / (length(a) * length(b)) 33 | 34 | 35 | def analyze_document(text, ivec_extractor): 36 | nb_words = len(text.split()) 37 | 38 | if args.unroll_steps is None: 39 | unroll = args.unroll 40 | else: 41 | if nb_words % args.unroll_steps == 0: 42 | text = text.rsplit(maxsplit=1)[0] 43 | nb_words -= 1 44 | unroll = nb_words // args.unroll_steps 45 | 46 | complete_ivec = ivec_extractor(text) 47 | complete_ivec_len = length(complete_ivec) 48 | 49 | partial_ivecs = [ 50 | ivec_extractor(" ".join(prefix)) for prefix in [ 51 | text.split()[:l] for l in range(0, nb_words, unroll) 52 | ] 53 | ] 54 | 55 | if args.unroll_steps: 56 | partial_ivecs = torch.stack(partial_ivecs[:args.unroll_steps]) 57 | 58 | if partial_ivecs.size(0) != args.unroll_steps: 59 | print(partial_ivecs.size(0)) 60 | 61 | distances = euclidean_distance(partial_ivecs, complete_ivec) 62 | cos_sims = cosine_similarity(partial_ivecs, complete_ivec) 63 | 64 | return distances, complete_ivec_len, cos_sims 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--unroll', type=int, default=35, 70 | help="bptt equivalent. Applies only when --unroll-steps is not set.") 71 | parser.add_argument('--unroll-steps', type=int, 72 | help="for how many steps the text should unrolled. Overrides --unroll.") 73 | source_opt = parser.add_mutually_exclusive_group(required=True) 74 | source_opt.add_argument('--document', help="what to perform the analysis on") 75 | source_opt.add_argument('--file-list', help="file with list of files analyze") 76 | parser.add_argument('--ivec-extractor', required=True, 77 | help="iVector extractor to use") 78 | args = parser.parse_args() 79 | print(args) 80 | 81 | print("loading SMM iVector extractor ...") 82 | with open(args.ivec_extractor, 'rb') as f: 83 | ivec_extractor = smm_ivec_extractor.load(f) 84 | print(ivec_extractor) 85 | 86 | if args.document: 87 | with open(args.document) as f: 88 | content = f.read() 89 | distances = analyze_document(content, ivec_extractor) 90 | 91 | print(distances) 92 | 93 | elif args.file_list: 94 | if not args.unroll_steps: 95 | raise ValueError("When analyzing a filelist, --unroll-steps HAS to be specified.") 96 | 97 | documents = filenames_file_to_filenames(args.file_list) 98 | 99 | distances = [] 100 | ci_lens = [] 101 | cos_sims = [] 102 | nb_failed = 0 103 | for doc in documents: 104 | with open(doc) as f: 105 | content = f.read() 106 | 107 | try: 108 | distance, ci_len, cos_sim = analyze_document(content, ivec_extractor) 109 | distances.append(distance) 110 | ci_lens.append(ci_len) 111 | cos_sims.append(cos_sim) 112 | except ValueError: 113 | nb_failed += 1 114 | 115 | if nb_failed > 0: 116 | sys.stderr.write("Failed analyzing {} documents, because they are too short.\n".format(nb_failed)) 117 | 118 | distances = torch.stack(distances) 119 | ci_lens = torch.stack(ci_lens) 120 | cos_sims = torch.stack(cos_sims) 121 | 122 | print(torch.stack([distances.min(dim=0)[0], distances.mean(dim=0), distances.max(dim=0)[0], distances.var(dim=0)]).t()) 123 | 124 | distances_normed = distances / ci_lens 125 | print(torch.stack([ 126 | distances_normed.min(dim=0)[0], distances_normed.mean(dim=0), 127 | distances_normed.max(dim=0)[0], distances_normed.var(dim=0) 128 | ]).t()) 129 | 130 | print(torch.stack([cos_sims.min(dim=0)[0], cos_sims.mean(dim=0), cos_sims.max(dim=0)[0], cos_sims.var(dim=0)]).t()) 131 | 
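# Quick sanity checks for the distance helpers above (illustrative, not part of the script):
#
#     >>> length(torch.tensor([3.0, 4.0]))
#     tensor(5.)
#     >>> euclidean_distance(torch.zeros(2), torch.tensor([3.0, 4.0]))
#     tensor(5.)
#     >>> cosine_similarity(torch.tensor([1.0, 0.0]), torch.tensor([1.0, 0.0]))
#     tensor(1.)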
-------------------------------------------------------------------------------- /test/test_analysis.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from brnolm.analysis import categorical_entropy, categorical_cross_entropy, categorical_kld 5 | 6 | 7 | class CategoricalEntropyTests(unittest.TestCase): 8 | def setUp(self): 9 | pass 10 | 11 | 12 | def test_uniform_2d(self): 13 | p_x = torch.FloatTensor([0.5, 0.5]) 14 | H_x = torch.FloatTensor([1.0]) 15 | 16 | self.assertTrue(torch.equal(categorical_entropy(p_x), H_x)) 17 | 18 | 19 | def test_uniform_4d(self): 20 | p_x = torch.FloatTensor([0.25, 0.25, 0.25, 0.25]) 21 | H_x = torch.FloatTensor([2.0]) 22 | 23 | self.assertTrue(torch.equal(categorical_entropy(p_x), H_x)) 24 | 25 | 26 | def test_nonuniform_3d(self): 27 | p_x = torch.FloatTensor([0.5, 0.25, 0.25]) 28 | H_x = torch.FloatTensor([1.5]) 29 | 30 | self.assertTrue(torch.equal(categorical_entropy(p_x), H_x)) 31 | 32 | 33 | def test_sparse_2d(self): 34 | p_x = torch.FloatTensor([0.0, 1.0]) 35 | H_x = torch.FloatTensor([0.0]) 36 | 37 | H_x_hat = categorical_entropy(p_x) 38 | 39 | self.assertTrue(torch.equal(H_x_hat, H_x)) 40 | 41 | 42 | def test_uniform_2_dists(self): 43 | p_x = torch.FloatTensor([[0.5, 0.5], [0.0, 1.0]]) 44 | H_x = torch.FloatTensor([1.0, 0.0]) 45 | 46 | H_x_hat = categorical_entropy(p_x) 47 | 48 | self.assertTrue(torch.equal(H_x_hat, H_x)) 49 | 50 | 51 | class CategoricalCrossEntropyTests(unittest.TestCase): 52 | def setUp(self): 53 | pass 54 | 55 | 56 | def test_simple(self): 57 | p_x = torch.FloatTensor([0.5, 0.5]) 58 | q_x = torch.FloatTensor([0.25, 0.75]) 59 | 60 | xent = categorical_cross_entropy(p_x, q_x) 61 | 62 | self.assertAlmostEqual(xent[0], 1.207518749639422, delta=1e-7) 63 | 64 | 65 | def test_sparse_true_dist(self): 66 | p_x = torch.FloatTensor([1, 0.0]) 67 | q_x = torch.FloatTensor([0.25, 0.75]) 68 | 69 | xent = categorical_cross_entropy(p_x, q_x) 70 | 71 | self.assertEqual(xent[0], 2) 72 | 73 | 74 | def test_sparse_both_same(self): 75 | p_x = torch.FloatTensor([0.5, 0.5, 0.0]) 76 | q_x = torch.FloatTensor([0.5, 0.5, 0.0]) 77 | 78 | xent = categorical_cross_entropy(p_x, q_x) 79 | 80 | self.assertEqual(xent[0], 1) 81 | 82 | 83 | def test_sparse_both_different(self): 84 | p_x = torch.FloatTensor([0.5, 0.0, 0.5]) 85 | q_x = torch.FloatTensor([0.5, 0.5, 0.0]) 86 | 87 | xent = categorical_cross_entropy(p_x, q_x) 88 | 89 | self.assertEqual(xent[0], float("inf")) 90 | 91 | 92 | def test_2_dists(self): 93 | p_x = torch.FloatTensor([[0.5, 0.5], [1.0, 0.0]]) 94 | q_x = torch.FloatTensor([[0.25, 0.75], [0.25, 0.75]]) 95 | 96 | xent = categorical_cross_entropy(p_x, q_x) 97 | 98 | self.assertAlmostEqual(xent[0], 1.207518749639422, delta=1e-7) 99 | self.assertEqual(xent[1], 2) 100 | 101 | 102 | def test_one_vs_many(self): 103 | p_x = torch.FloatTensor([[0.5, 0.5], [1.0, 0.0]]) 104 | q_x = torch.FloatTensor([0.25, 0.75]) 105 | 106 | xent = categorical_cross_entropy(p_x, q_x) 107 | 108 | self.assertAlmostEqual(xent[0], 1.207518749639422, delta=1e-7) 109 | self.assertEqual(xent[1], 2) 110 | 111 | 112 | class CategoricalKLDTests(unittest.TestCase): 113 | def setUp(self): 114 | pass 115 | 116 | 117 | def test_same(self): 118 | p_x = torch.FloatTensor([0.5, 0.5]) 119 | q_x = torch.FloatTensor([0.5, 0.5]) 120 | 121 | kld = categorical_kld(p_x, q_x) 122 | 123 | self.assertAlmostEqual(kld[0], 0.0) 124 | 125 | 126 | def test_simple_different(self): 127 | p_x = torch.FloatTensor([0.5, 0.5]) 128 | q_x 
= torch.FloatTensor([0.75, 0.25]) 129 | 130 | kld = categorical_kld(p_x, q_x) 131 | 132 | self.assertAlmostEqual(kld[0], 0.207518749639422, delta=1e-7) 133 | 134 | 135 | def test_true_has_zero(self): 136 | p_x = torch.FloatTensor([1.0, 0.0]) 137 | q_x = torch.FloatTensor([0.25, 0.75]) 138 | 139 | kld = categorical_kld(p_x, q_x) 140 | 141 | self.assertAlmostEqual(kld[0], 2, delta=1e-7) 142 | 143 | 144 | def test_infinite_kld(self): 145 | p_x = torch.FloatTensor([0.5, 0.5]) 146 | q_x = torch.FloatTensor([0.0, 1.0]) 147 | 148 | kld = categorical_kld(p_x, q_x) 149 | 150 | self.assertEqual(kld[0], float("inf")) 151 | 152 | 153 | def test_2_dists(self): 154 | p_x = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]]) 155 | q_x = torch.FloatTensor([[0.0, 1.0], [0.5, 0.5]]) 156 | 157 | kld = categorical_kld(p_x, q_x) 158 | 159 | self.assertEqual(kld[0], float("inf")) 160 | self.assertEqual(kld[1], 0.0) 161 | 162 | -------------------------------------------------------------------------------- /scripts/rescoring/rescore-kaldi-latts-continuous.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import logging 5 | import torch 6 | 7 | import brnolm.language_models.vocab as vocab 8 | from brnolm.rescoring.segment_scoring import SegmentScorer 9 | from safe_gpu.safe_gpu import GPUOwner 10 | 11 | import typing 12 | 13 | import brnolm.kaldi_itf 14 | 15 | 16 | def translate_latt_to_model(word_ids, latt_vocab, model_vocab, mode='words'): 17 | words = [latt_vocab.i2w(i) for i in word_ids] 18 | if mode == 'words': 19 | return words + [''] 20 | elif mode == 'chars': 21 | chars = list(" ".join(words)) 22 | return chars + [''] 23 | else: 24 | raise ValueError('Got unexpected mode "{}"'.format(mode)) 25 | 26 | 27 | def select_hidden_state_to_pass(hidden_states): 28 | return hidden_states['1'] 29 | 30 | 31 | def spk_sess(segment_name): 32 | return segment_name.split('-')[0].split('_') 33 | 34 | 35 | def main(args): 36 | logging.info(args) 37 | 38 | mode = 'chars' if args.character_lm else 'words' 39 | 40 | logging.info("reading lattice vocab...") 41 | 42 | with open(args.latt_vocab, 'r') as f: 43 | latt_vocab = vocab.vocab_from_kaldi_wordlist(f, unk_word=args.latt_unk) 44 | 45 | logging.info("reading model...") 46 | device = torch.device('cuda') if args.cuda else torch.device('cpu') 47 | lm = torch.load(args.model_from, map_location=device) 48 | 49 | lm.eval() 50 | 51 | curr_seg = '' 52 | segment_utts: typing.Dict[str, typing.Any] = {} 53 | 54 | custom_h0 = None 55 | nb_carry_overs = 0 56 | nb_new_hs = 0 57 | 58 | with open(args.in_filename) as in_f, open(args.out_filename, 'w') as out_f: 59 | scorer = SegmentScorer(lm, out_f) 60 | 61 | for line in in_f: 62 | fields = line.split() 63 | segment, trans_id = brnolm.kaldi_itf.split_nbest_key(fields[0]) 64 | 65 | word_ids = [int(wi) for wi in fields[1:]] 66 | words = translate_latt_to_model(word_ids, latt_vocab, lm.vocab, mode) 67 | 68 | if not curr_seg: 69 | curr_seg = segment 70 | 71 | if segment != curr_seg: 72 | result = scorer.process_segment(curr_seg, segment_utts, custom_h0) 73 | if args.carry_over == 'always': 74 | custom_h0 = select_hidden_state_to_pass(result.hidden_states) 75 | nb_carry_overs += 1 76 | elif args.carry_over == 'speaker': 77 | if spk_sess(segment) == spk_sess(curr_seg): 78 | custom_h0 = select_hidden_state_to_pass(result.hidden_states) 79 | nb_carry_overs += 1 80 | else: 81 | custom_h0 = None 82 | nb_new_hs += 1 83 | elif args.carry_over == 'never': 84 | custom_h0 
= None 85 | nb_new_hs += 1 86 | else: 87 | raise ValueError(f'Unsupported carry over regime {args.carry_over}') 88 | for hyp_no in segment_utts.keys(): 89 | out_f.write(f"{curr_seg}-{hyp_no} {result.scores[hyp_no]}\n") 90 | 91 | curr_seg = segment 92 | segment_utts = {} 93 | 94 | segment_utts[trans_id] = words 95 | 96 | # Last segment: 97 | result = scorer.process_segment(curr_seg, segment_utts) 98 | for hyp_no in segment_utts.keys(): 99 | out_f.write(f"{curr_seg}-{hyp_no} {result.scores[hyp_no]}\n") 100 | 101 | logging.info(f'Hidden state was carried over {nb_carry_overs} times and reset {nb_new_hs} times') 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 106 | parser.add_argument('--latt-vocab', type=str, required=True, 107 | help='word -> int map; Kaldi style "words.txt"') 108 | parser.add_argument('--latt-unk', type=str, default='', 109 | help='unk symbol used in the lattice') 110 | parser.add_argument('--carry-over', default='always', choices=['always', 'speaker', 'never'], 111 | help='When to use the previous hidden state') 112 | parser.add_argument('--cuda', action='store_true', 113 | help='use CUDA') 114 | parser.add_argument('--character-lm', action='store_true', 115 | help='Process strings by characters') 116 | parser.add_argument('--model-from', type=str, required=True, 117 | help='where to load the model from') 118 | parser.add_argument('in_filename', help='second output of nbest-to-linear, textual') 119 | parser.add_argument('out_filename', help='where to put the LM scores') 120 | args = parser.parse_args() 121 | 122 | logging.basicConfig(level=logging.DEBUG) 123 | 124 | if args.cuda: 125 | gpu_owner = GPUOwner() 126 | 127 | main(args) 128 | -------------------------------------------------------------------------------- /brnolm/smm_itf/smm_ivec_extractor.py: -------------------------------------------------------------------------------- 1 | import io 2 | import tempfile 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from smm import update_ws 9 | 10 | 11 | class IvecExtractor(): 12 | def __init__(self, model, nb_iters, lr, tokenizer): 13 | self._model = model 14 | self._nb_iters = nb_iters 15 | self._lr = lr 16 | self._tokenizer = tokenizer 17 | 18 | def __call__(self, sentence): 19 | """ Extract i-vectors given the model and stats """ 20 | if isinstance(sentence, str): 21 | data = self._tokenizer.transform([sentence]) 22 | data = torch.from_numpy(data.A.astype(np.float32)) 23 | 24 | else: 25 | data = sentence 26 | 27 | if self._model.cuda: 28 | data = data.cuda() 29 | X = data.t() 30 | 31 | self._model.reset_w(X.size(-1)) # initialize i-vectors to zeros 32 | opt_w = torch.optim.Adagrad([self._model.W], lr=self._lr) 33 | 34 | loss = self._model.loss(X) 35 | 36 | # TODO this is a very nasty hack, needs to be 37 | # completely reworked, a separate class should be 38 | # prepared to implement this 39 | if self._nb_iters < 0: 40 | initrange = 10**(self._nb_iters) 41 | self._model.W.data.uniform_(-initrange, initrange) 42 | else: 43 | for i in range(self._nb_iters): 44 | loss = update_ws(self._model, opt_w, loss, X) 45 | 46 | return self._model.W.data.t().squeeze() 47 | 48 | def __str__(self): 49 | name = "IvecExtractor" 50 | ivec_size = self._model.W.size(0) 51 | 52 | fmt_str = "{} (\n\tiVectors size: {}\n\tLearning rate: {}\n\t # iterations: {}\n)\n" 53 | return fmt_str.format(name, ivec_size, self._lr, self._nb_iters) 54 | 55 | def save(self, f): 56 | tmp_f = 
tempfile.TemporaryFile() 57 | # self._model.cpu() 58 | torch.save(self._model, tmp_f) 59 | tmp_f.seek(0) 60 | model_bytes = io.BytesIO(tmp_f.read()) 61 | 62 | nb_iters_bytes = io.BytesIO() 63 | pickle.dump(self._nb_iters, nb_iters_bytes) 64 | 65 | lr_bytes = io.BytesIO() 66 | pickle.dump(self._lr, lr_bytes) 67 | 68 | tokenizer_byters = io.BytesIO() 69 | pickle.dump(self._tokenizer, tokenizer_byters) 70 | 71 | complete_smm = {'model': model_bytes, 'tokenizer': tokenizer_byters, 72 | 'lr': lr_bytes, 'nb_iters': nb_iters_bytes} 73 | pickle.dump(complete_smm, f) 74 | 75 | def __eq__(self, other): 76 | return (torch.equal(self._model.T, other._model.T) and 77 | self._lr == other._lr and 78 | self._nb_iters == other._nb_iters and 79 | self._tokenizer == other._tokenizer) 80 | 81 | def zero_bows(self, nb_bows): 82 | empty_docs = ["" for _ in range(nb_bows)] 83 | bows = self._tokenizer.transform(empty_docs) 84 | bows = torch.from_numpy(bows.A.astype(np.float32)) 85 | if self._model.T.is_cuda: 86 | bows = bows.cuda() 87 | return bows 88 | 89 | def build_translator(self, source_vocabulary): 90 | maxes = [] 91 | argmaxes = [] 92 | for w in source_vocabulary: 93 | bow = self._tokenizer.transform([w]) 94 | prototype = torch.from_numpy(bow.A.astype(np.float32)) 95 | p_max, p_argmax = prototype.max(dim=1) 96 | maxes.append(p_max) 97 | argmaxes.append(p_argmax) 98 | 99 | maxes = torch.cat(maxes, dim=0) 100 | argmaxes = torch.cat(argmaxes, dim=0) 101 | 102 | if self._model.T.is_cuda: 103 | maxes = maxes.cuda() 104 | argmaxes = argmaxes.cuda() 105 | 106 | return lambda W: translate(W, argmaxes, 1-maxes, prototype.size(1)) 107 | 108 | 109 | def translate(W, translation_table, translation_mask, dst_vocab_size): 110 | W_flat = W.view(-1) # W was [B, T], W_flat is [BxT] 111 | translation = translation_table[W_flat] # [BxT] 112 | invalid_translations = translation_mask[W_flat].nonzero().view(-1) # [number of words without translation] 113 | 114 | one_hot = W.new(translation.size() + (dst_vocab_size,)).float() # [BxT, SMM vocab_size] 115 | one_hot.zero_() 116 | one_hot.scatter_(1, translation.view(-1, 1), 1) 117 | if len(invalid_translations.size()) > 0: 118 | one_hot[invalid_translations] = 0 # zeroes the entries where no translation should ever happen 119 | 120 | one_hot_reshaped = one_hot.view(W.size() + (dst_vocab_size, )) # [B, T, SMM vocab_size] 121 | return one_hot_reshaped.sum(dim=-2) # [B, SMM vocab_size] 122 | 123 | 124 | def load(f): 125 | complete_lm = pickle.load(f) 126 | 127 | model_bytes = complete_lm['model'] 128 | tmp_f = tempfile.TemporaryFile() 129 | tmp_f.write(model_bytes.getvalue()) 130 | tmp_f.seek(0) 131 | model = torch.load(tmp_f) 132 | 133 | tokenizer_bytes = complete_lm['tokenizer'] 134 | tokenizer_bytes.seek(0) 135 | tokenizer = pickle.load(tokenizer_bytes) 136 | 137 | lr_bytes = complete_lm['lr'] 138 | lr_bytes.seek(0) 139 | lr = pickle.load(lr_bytes) 140 | 141 | nb_iters_bytes = complete_lm['nb_iters'] 142 | nb_iters_bytes.seek(0) 143 | nb_iters = pickle.load(nb_iters_bytes) 144 | 145 | return IvecExtractor(model, nb_iters, lr, tokenizer) 146 | -------------------------------------------------------------------------------- /scripts/train/train-multifile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import random 5 | 6 | import torch 7 | 8 | from brnolm.data_pipeline.multistream import BatchBuilder 9 | 10 | from brnolm.data_pipeline.reading import tokens_from_file 11 | from 
brnolm.data_pipeline.temporal_splitting import TemporalSplits 12 | 13 | from brnolm.runtime.runtime_utils import CudaStream, init_seeds, filelist_to_objects, BatchFilter, epoch_summary 14 | from brnolm.runtime.runtime_multifile import evaluate, train 15 | 16 | from brnolm.runtime.loggers import InfinityLogger 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model') 21 | parser.add_argument('--train-list', type=str, required=True, 22 | help='file with paths to training documents') 23 | parser.add_argument('--valid-list', type=str, required=True, 24 | help='file with paths to validation documents') 25 | parser.add_argument('--lr', type=float, default=20, 26 | help='initial learning rate') 27 | parser.add_argument('--beta', type=float, default=0, 28 | help='L2 regularization penalty') 29 | parser.add_argument('--clip', type=float, default=0.25, 30 | help='gradient clipping') 31 | parser.add_argument('--epochs', type=int, default=40, 32 | help='upper epoch limit') 33 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 34 | help='batch size') 35 | parser.add_argument('--target-seq-len', type=int, default=35, 36 | help='sequence length') 37 | parser.add_argument('--seed', type=int, default=1111, 38 | help='random seed') 39 | parser.add_argument('--cuda', action='store_true', 40 | help='use CUDA') 41 | parser.add_argument('--concat-articles', action='store_true', 42 | help='pass hidden states over article boundaries') 43 | parser.add_argument('--shuffle-articles', action='store_true', 44 | help='shuffle the order of articles (at the start of the training)') 45 | parser.add_argument('--keep-shuffling', action='store_true', 46 | help='shuffle the order of articles for each epoch') 47 | parser.add_argument('--min-batch-size', type=int, default=1, 48 | help='stop, once batch is smaller than given size') 49 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 50 | help='report interval') 51 | parser.add_argument('--load', type=str, required=True, 52 | help='where to load a model from') 53 | parser.add_argument('--save', type=str, required=True, 54 | help='path to save the final model') 55 | args = parser.parse_args() 56 | print(args) 57 | 58 | init_seeds(args.seed, args.cuda) 59 | 60 | print("loading model...") 61 | lm = torch.load(args.load) 62 | if args.cuda: 63 | lm.cuda() 64 | print(lm.model) 65 | 66 | print("preparing data...") 67 | 68 | def temp_splits_from_fn(fn): 69 | tokens = tokens_from_file(fn, lm.vocab, randomize=False) 70 | return TemporalSplits(tokens, lm.model.in_len, args.target_seq_len) 71 | 72 | print("\ttraining...") 73 | train_tss = filelist_to_objects(args.train_list, temp_splits_from_fn) 74 | train_data = BatchBuilder(train_tss, args.batch_size, 75 | discard_h=not args.concat_articles) 76 | if args.cuda: 77 | train_data = CudaStream(train_data) 78 | 79 | print("\tvalidation...") 80 | valid_tss = filelist_to_objects(args.valid_list, temp_splits_from_fn) 81 | valid_data = BatchBuilder(valid_tss, args.batch_size, 82 | discard_h=not args.concat_articles) 83 | if args.cuda: 84 | valid_data = CudaStream(valid_data) 85 | 86 | print("training...") 87 | lr = args.lr 88 | best_val_loss = None 89 | 90 | for epoch in range(1, args.epochs+1): 91 | if args.keep_shuffling: 92 | random.shuffle(train_tss) 93 | train_data = BatchBuilder(train_tss, args.batch_size, 94 | discard_h=not args.concat_articles) 95 | if args.cuda: 96 | train_data = CudaStream(train_data) 97 | 98 | logger = 
InfinityLogger(epoch, args.log_interval, lr) 99 | train_data_filtered = BatchFilter( 100 | train_data, args.batch_size, args.target_seq_len, args.min_batch_size 101 | ) 102 | optim = torch.optim.SGD(lm.parameters(), lr=lr, weight_decay=args.beta) 103 | 104 | train( 105 | lm, train_data_filtered, optim, logger, 106 | clip=args.clip, 107 | use_ivecs=False 108 | ) 109 | train_data_filtered.report() 110 | 111 | val_loss = evaluate(lm, valid_data, use_ivecs=False) 112 | print(epoch_summary(epoch, logger.nb_updates(), logger.time_since_creation(), val_loss)) 113 | 114 | # Save the model if the validation loss is the best we've seen so far. 115 | if not best_val_loss or val_loss < best_val_loss: 116 | torch.save(lm, args.save) 117 | best_val_loss = val_loss 118 | else: 119 | lr /= 2.0 120 | pass 121 | --------------------------------------------------------------------------------
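The runtime helpers in brnolm/runtime/runtime_utils.py are small enough that their contract can be read directly from the code: a file list is just a whitespace-separated list of paths, filelist_to_objects opens every listed file and applies a callable to the open handle, and repackage_hidden detaches a tensor or a (nested) tuple of tensors from autograd history so that truncated BPTT does not backpropagate across batch boundaries. A minimal sketch of how these pieces might be exercised (the temporary files and the word-counting action below are made up for illustration):

import tempfile

import torch

from brnolm.runtime.runtime_utils import filelist_to_objects, repackage_hidden

# A throw-away "document" and a file list pointing at it.
doc = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False)
doc.write('hello world\n')
doc.close()

filelist = tempfile.NamedTemporaryFile('w', suffix='.lst', delete=False)
filelist.write(doc.name + '\n')
filelist.close()

# filelist_to_objects() opens every listed file and applies `action` to the file handle.
word_counts = filelist_to_objects(filelist.name, action=lambda f: len(f.read().split()))
print(word_counts)  # [2]

# repackage_hidden() detaches recurrent state (a tensor or a tuple of tensors) instead of
# carrying its autograd history into the next update.
h = (torch.zeros(1, 3, requires_grad=True), torch.ones(1, 3, requires_grad=True))
h = repackage_hidden(h)
print(all(not t.requires_grad for t in h))  # True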