├── speech ├── utils │ ├── __init__.py │ ├── wave.py │ ├── score.py │ ├── data_helpers.py │ ├── io.py │ └── convert.py ├── __init__.py ├── models │ ├── __init__.py │ ├── ctc_model.py │ ├── transducer_model.py │ ├── model.py │ ├── ctc_decoder.py │ └── seq2seq.py └── loader.py ├── examples ├── .gitignore ├── wsj │ ├── .gitignore │ ├── README.md │ ├── preprocess.sh │ ├── seq2seq_config.json │ └── preprocess.py ├── timit │ ├── .gitignore │ ├── data_prep.sh │ ├── phones.60-48-39.map │ ├── ctc_config.json │ ├── transducer_config.json │ ├── score.py │ ├── seq2seq_config.json │ ├── README.md │ └── preprocess.py └── librispeech │ ├── config.json │ ├── download.py │ └── preprocess.py ├── tests ├── test0.wav ├── test1.wav ├── wave_test.py ├── test.json ├── shared.py ├── model_test.py ├── io_test.py ├── ctc_test.py ├── loader_test.py └── seq2seq_test.py ├── setup.sh ├── requirements.txt ├── Makefile ├── .gitignore ├── eval.py ├── README.md ├── train.py └── LICENSE /speech/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | */*.json 2 | fb/ 3 | -------------------------------------------------------------------------------- /examples/wsj/.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | data 3 | -------------------------------------------------------------------------------- /examples/timit/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | models 3 | -------------------------------------------------------------------------------- /speech/__init__.py: -------------------------------------------------------------------------------- 1 | from speech.utils.io import save, load 2 | from speech.utils.score import compute_cer 3 | -------------------------------------------------------------------------------- /tests/test0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjimpson9551/Speech_to_Text_dev21/HEAD/tests/test0.wav -------------------------------------------------------------------------------- /tests/test1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjimpson9551/Speech_to_Text_dev21/HEAD/tests/test1.wav -------------------------------------------------------------------------------- /examples/timit/data_prep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | timit_path=$1 4 | python preprocess.py $timit_path 5 | ln -s $timit_path data 6 | -------------------------------------------------------------------------------- /speech/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from speech.models.model import Model 3 | from speech.models.seq2seq import Seq2Seq 4 | from speech.models.ctc_model import CTC 5 | from speech.models.transducer_model import Transducer 6 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run `source setup.sh` from this directory. 
3 | export PYTHONPATH=`pwd`:`pwd`/libs/warp-ctc/pytorch_binding:`pwd`/libs:$PYTHONPATH 4 | 5 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`pwd`/libs/warp-ctc/build 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cffi==1.11.2 2 | editdistance==0.3.1 3 | numpy==1.13.3 4 | protobuf==3.4.0 5 | py==1.10.0 6 | pycparser==2.18 7 | pytest==3.2.3 8 | PyYAML==5.4 9 | scipy==0.18.1 10 | six==1.11.0 11 | SoundFile==0.10.2 12 | tensorboard-logger==0.0.4 13 | tqdm==4.19.4 14 | -------------------------------------------------------------------------------- /tests/wave_test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import speech.utils.wave as wave 5 | 6 | def test_load(): 7 | audio, samp_rate = wave.array_from_wave("test0.wav") 8 | 9 | assert samp_rate == 16000 10 | assert audio.dtype == np.int16 11 | 12 | def test_duration(): 13 | duration = wave.wav_duration("test0.wav") 14 | 15 | assert round(duration, 3) == 1.101 16 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | all: warp transduce 3 | 4 | warp: 5 | git clone https://github.com/awni/warp-ctc.git libs/warp-ctc 6 | cd libs/warp-ctc; mkdir build; cd build; cmake ../ && make; \ 7 | cd ../pytorch_binding; python build.py 8 | 9 | # TODO, awni, put this into a package 10 | transduce: 11 | git clone https://github.com/awni/transducer.git libs/transducer 12 | cd libs/transducer; python build.py 13 | 14 | -------------------------------------------------------------------------------- /speech/utils/wave.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import soundfile 7 | 8 | def array_from_wave(file_name): 9 | audio, samp_rate = soundfile.read(file_name, dtype='int16') 10 | return audio, samp_rate 11 | 12 | def wav_duration(file_name): 13 | audio, samp_rate = soundfile.read(file_name, dtype='int16') 14 | nframes = audio.shape[0] 15 | duration = nframes / samp_rate 16 | return duration 17 | -------------------------------------------------------------------------------- /examples/wsj/README.md: -------------------------------------------------------------------------------- 1 | 2 | A Recipe for the Wall Street Journal (WSJ) corpus. 3 | 4 | The WSJ corpus consists of about 80 hours of read sentences taken from the Wall 5 | Street Journal. The WSJ corpus can be purchased from the LDC as LDC93S6B (wsj0) 6 | and LDC94S13B (wsj1). 
7 | 8 | In these experiments we use three subsets following the Kaldi WSJ recipe: 9 | 10 | - train: 37318 utterances, referred to as train_si284 in Kaldi 11 | - dev: 503 utterances, referred to as dev92 in Kaldi 12 | - test: 333 utterances, referred to as eval93 in Kaldi 13 | 14 | -------------------------------------------------------------------------------- /tests/test.json: -------------------------------------------------------------------------------- 1 | {"text" : "hello world", "duration" : 1.101, "audio" : "test0.wav"} 2 | {"text" : "hello world", "duration" : 1.101, "audio" : "test0.wav"} 3 | {"text" : "hello world", "duration" : 1.101, "audio" : "test0.wav"} 4 | {"text" : "hello world", "duration" : 1.101, "audio" : "test0.wav"} 5 | {"text" : "hello hi", "duration" : 1.571, "audio" : "test1.wav"} 6 | {"text" : "hello hi", "duration" : 1.571, "audio" : "test1.wav"} 7 | {"text" : "hello hi", "duration" : 1.571, "audio" : "test1.wav"} 8 | {"text" : "hello hi", "duration" : 1.571, "audio" : "test1.wav"} 9 | -------------------------------------------------------------------------------- /speech/utils/score.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import editdistance 6 | 7 | def compute_cer(results): 8 | """ 9 | Arguments: 10 | results (list): list of ground truth and 11 | predicted sequence pairs. 12 | 13 | Returns the CER for the full set. 14 | """ 15 | dist = sum(editdistance.eval(label, pred) 16 | for label, pred in results) 17 | total = sum(len(label) for label, _ in results) 18 | return dist / total 19 | -------------------------------------------------------------------------------- /speech/utils/data_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import os 7 | import tqdm 8 | 9 | from speech.utils import convert 10 | 11 | def convert_full_set(path, pattern, new_ext="wav", **kwargs): 12 | pattern = os.path.join(path, pattern) 13 | audio_files = glob.glob(pattern) 14 | for af in tqdm.tqdm(audio_files): 15 | base, ext = os.path.splitext(af) 16 | wav = base + os.path.extsep + new_ext 17 | convert.to_wave(af, wav, **kwargs) 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /examples/librispeech/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed" : 2017, 3 | "save_path" : "/deep/group/awni/speech_models/test", 4 | 5 | "data" : { 6 | "train_set" : "/deep/group/speech/datasets/LibriSpeech/train-toy.json", 7 | "dev_set" : "/deep/group/speech/datasets/LibriSpeech/dev-toy.json" 8 | }, 9 | 10 | "optimizer" : { 11 | "batch_size" : 8, 12 | "epochs" : 1000, 13 | "learning_rate" : 1e-3, 14 | "momentum" : 0.0 15 | }, 16 | 17 | "model" : { 18 | "encoder" : { 19 | "rnn" : { 20 | "dim" : 256, 21 | "layers" : 1 22 | } 23 | }, 24 | "decoder" : { 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tests/shared.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | model_config = { 5 | "dropout" : 0.0, 6 | "encoder" : { 7 | "conv" : [ 8 | [32, 5, 32, 2] 9 | ], 10 | "rnn" : { 11 | "dim" : 16, 12 | 
"bidirectional" : False, 13 | "layers" : 1 14 | } 15 | } 16 | } 17 | 18 | def gen_fake_data(freq_dim, output_dim, max_time=100, 19 | max_seq_len=20, batch_size=4): 20 | data = [] 21 | for i in range(batch_size): 22 | inputs = np.random.randn(max_time, freq_dim) 23 | labels = np.random.randint(0, output_dim, max_seq_len) 24 | data.append((inputs, labels)) 25 | inputs, labels = list(zip(*data)) 26 | return inputs, labels 27 | 28 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | import speech.models 6 | 7 | import shared 8 | 9 | def test_model(): 10 | time_steps = 100 11 | freq_dim = 40 12 | batch_size = 4 13 | 14 | model = speech.models.Model(freq_dim, shared.model_config) 15 | 16 | x = torch.randn(batch_size, time_steps, freq_dim) 17 | 18 | x_enc = model.encode(x) 19 | t_dim = model.conv_out_size(time_steps, 0) 20 | expected_size = torch.Size((batch_size, t_dim, model.encoder_dim)) 21 | 22 | # Check output size is correct. 23 | assert x_enc.size() == expected_size 24 | 25 | # Check cuda attribute works 26 | assert not model.is_cuda 27 | if torch.cuda.is_available(): 28 | model.cuda() 29 | assert model.is_cuda 30 | -------------------------------------------------------------------------------- /speech/utils/io.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | import torch 5 | 6 | MODEL = "model" 7 | PREPROC = "preproc.pyc" 8 | 9 | def get_names(path, tag): 10 | tag = tag + "_" if tag else "" 11 | model = os.path.join(path, tag + MODEL) 12 | preproc = os.path.join(path, tag + PREPROC) 13 | return model, preproc 14 | 15 | def save(model, preproc, path, tag=""): 16 | model_n, preproc_n = get_names(path, tag) 17 | torch.save(model, model_n) 18 | with open(preproc_n, 'wb') as fid: 19 | pickle.dump(preproc, fid) 20 | 21 | def load(path, tag=""): 22 | model_n, preproc_n = get_names(path, tag) 23 | model = torch.load(model_n) 24 | with open(preproc_n, 'rb') as fid: 25 | preproc = pickle.load(fid) 26 | return model, preproc 27 | 28 | 29 | -------------------------------------------------------------------------------- /examples/timit/phones.60-48-39.map: -------------------------------------------------------------------------------- 1 | aa aa aa 2 | ae ae ae 3 | ah ah ah 4 | ao ao aa 5 | aw aw aw 6 | ax ax ah 7 | ax-h ax ah 8 | axr er er 9 | ay ay ay 10 | b b b 11 | bcl vcl sil 12 | ch ch ch 13 | d d d 14 | dcl vcl sil 15 | dh dh dh 16 | dx dx dx 17 | eh eh eh 18 | el el l 19 | em m m 20 | en en n 21 | eng ng ng 22 | epi epi sil 23 | er er er 24 | ey ey ey 25 | f f f 26 | g g g 27 | gcl vcl sil 28 | h# sil sil 29 | hh hh hh 30 | hv hh hh 31 | ih ih ih 32 | ix ix ih 33 | iy iy iy 34 | jh jh jh 35 | k k k 36 | kcl cl sil 37 | l l l 38 | m m m 39 | n n n 40 | ng ng ng 41 | nx n n 42 | ow ow ow 43 | oy oy oy 44 | p p p 45 | pau sil sil 46 | pcl cl sil 47 | q 48 | r r r 49 | s s s 50 | sh sh sh 51 | t t t 52 | tcl cl sil 53 | th th th 54 | uh uh uh 55 | uw uw uw 56 | ux uw uw 57 | v v v 58 | w w w 59 | y y y 60 | z z z 61 | zh zh sh 62 | -------------------------------------------------------------------------------- /examples/timit/ctc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed" : 2017, 3 | "save_path" : "examples/timit/models/ctc_best", 4 | 5 | "data" : { 6 | "train_set" : 
"examples/timit/data/timit/train.json", 7 | "dev_set" : "examples/timit/data/timit/dev.json", 8 | "start_and_end" : false 9 | }, 10 | 11 | "optimizer" : { 12 | "batch_size" : 8, 13 | "epochs" : 200, 14 | "learning_rate" : 1e-3, 15 | "momentum" : 0.0 16 | }, 17 | 18 | "model" : { 19 | "class" : "CTC", 20 | "dropout" : 0.4, 21 | "encoder" : { 22 | "conv" : [ 23 | [32, 5, 32, 2], 24 | [32, 5, 32, 1] 25 | ], 26 | "rnn" : { 27 | "dim" : 256, 28 | "bidirectional" : true, 29 | "layers" : 4 30 | } 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /tests/io_test.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import speech.models 4 | import speech.loader 5 | 6 | import shared 7 | 8 | def test_save(): 9 | 10 | freq_dim = 120 11 | model = speech.models.Model(freq_dim, 12 | shared.model_config) 13 | 14 | batch_size = 2 15 | data_json = "test.json" 16 | preproc = speech.loader.Preprocessor(data_json) 17 | 18 | save_dir = tempfile.mkdtemp() 19 | speech.save(model, preproc, save_dir) 20 | 21 | s_model, s_preproc = speech.load(save_dir) 22 | assert hasattr(s_preproc, 'mean') 23 | assert hasattr(s_preproc, 'std') 24 | assert hasattr(s_preproc, 'int_to_char') 25 | assert hasattr(s_preproc, 'char_to_int') 26 | 27 | msd = model.state_dict() 28 | for k, v in s_model.state_dict().items(): 29 | assert k in msd 30 | assert hasattr(s_model, 'encoder_dim') 31 | assert hasattr(s_model, 'is_cuda') 32 | -------------------------------------------------------------------------------- /examples/timit/transducer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed" : 2017, 3 | "save_path" : "examples/timit/models/trans_best", 4 | 5 | "data" : { 6 | "train_set" : "examples/timit/data/timit/train.json", 7 | "dev_set" : "examples/timit/data/timit/dev.json", 8 | "start_and_end" : false 9 | }, 10 | 11 | "optimizer" : { 12 | "batch_size" : 8, 13 | "epochs" : 200, 14 | "learning_rate" : 1e-3, 15 | "momentum" : 0.0 16 | }, 17 | 18 | "model" : { 19 | "class" : "Transducer", 20 | "dropout" : 0.5, 21 | "encoder" : { 22 | "conv" : [ 23 | [8, 5, 32, 2], 24 | [8, 5, 32, 1] 25 | ], 26 | "rnn" : { 27 | "dim" : 256, 28 | "bidirectional" : true, 29 | "layers" : 4 30 | } 31 | }, 32 | "decoder" : { 33 | "embedding_dim" : 256, 34 | "layers" : 1 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /examples/timit/score.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import editdistance 7 | import json 8 | 9 | import preprocess 10 | 11 | def remap(data): 12 | _, m48_39 = preprocess.load_phone_map() 13 | for d in data: 14 | d['prediction'] = [m48_39[p] for p in d['prediction']] 15 | d['label'] = [m48_39[p] for p in d['label']] 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser( 19 | description="CER on Timit with reduced phoneme set.") 20 | 21 | parser.add_argument("data_json", 22 | help="JSON with the transcripts.") 23 | args = parser.parse_args() 24 | 25 | with open(args.data_json, 'r') as fid: 26 | data = [json.loads(l) for l in fid] 27 | 28 | remap(data) 29 | dist = sum(editdistance.eval(d['label'], d['prediction']) 30 | for d in data) 31 | total = sum(len(d['label']) for d in data) 32 | print("CER 
{:.3f}".format(dist / total)) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /examples/timit/seq2seq_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed" : 2017, 3 | "save_path" : "examples/timit/models/seq2seq_best", 4 | 5 | "data" : { 6 | "train_set" : "examples/timit/data/timit/train.json", 7 | "dev_set" : "examples/timit/data/timit/dev.json", 8 | "start_and_end" : true 9 | }, 10 | 11 | "optimizer" : { 12 | "batch_size" : 8, 13 | "epochs" : 200, 14 | "learning_rate" : 1e-3, 15 | "momentum" : 0.0 16 | }, 17 | 18 | "model" : { 19 | "class" : "Seq2Seq", 20 | "dropout" : 0.3, 21 | "encoder" : { 22 | "conv" : [ 23 | [8, 5, 32, 2], 24 | [8, 5, 32, 1] 25 | ], 26 | "rnn" : { 27 | "dim" : 256, 28 | "bidirectional" : true, 29 | "layers" : 4 30 | } 31 | }, 32 | "decoder" : { 33 | "sample_prob" : 0.4, 34 | "embedding_dim" : 256, 35 | "log_t" : true, 36 | "layers" : 1 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/wsj/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Call this script to preprocess the WSJ datasets. 4 | # 5 | # This script follows the Kaldi setup closely 6 | # https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5 7 | # 8 | # Usage: 9 | # ./preprocess.sh 10 | # 11 | # The default behavior is to convert all the *.wv1 SPHERE files 12 | # in the wsj corpus to 'wav' format. Thus you will need write access 13 | # to . 14 | # 15 | # Upon completion three files will be created and saved in 16 | # : 17 | # - train_si284.json (37318 utts) 18 | # - dev_93.json (503 utts) 19 | # - eval_92.json (333 utts) 20 | 21 | # Path where the dataset is stored. 22 | wsj_base=$1 23 | 24 | # Path to save dataset jsons. 25 | save_path=$2 26 | 27 | # Install sph2pipe 28 | sph_v=sph2pipe_v2.5 29 | wget http://www.openslr.org/resources/3/${sph_v}.tar.gz 30 | tar -xzvf ${sph_v}.tar.gz 31 | cd ${sph_v} && gcc -o sph2pipe *.c -lm 32 | cd .. 
33 | rm ${sph_v}.tar.gz 34 | 35 | python preprocess.py $1 $2 --convert 36 | rm -rf $sph_v 37 | -------------------------------------------------------------------------------- /examples/wsj/seq2seq_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed" : 2017, 3 | "save_path" : "examples/wsj/models/nn_logt_drop0.2_sample0.4", 4 | 5 | "data" : { 6 | "train_set" : "examples/wsj/data/train_si284.json", 7 | "dev_set" : "examples/wsj/data/test_dev93.json", 8 | "start_and_end" : true 9 | }, 10 | 11 | "optimizer" : { 12 | "batch_size" : 16, 13 | "epochs" : 50, 14 | "learning_rate" : 1e-3, 15 | "momentum" : 0.0 16 | }, 17 | 18 | "model" : { 19 | "class" : "Seq2Seq", 20 | "dropout" : 0.2, 21 | "encoder" : { 22 | "conv" : [ 23 | [32, 5, 8, 2], 24 | [32, 5, 8, 2] 25 | ], 26 | "rnn" : { 27 | "dim" : 256, 28 | "bidirectional" : true, 29 | "layers" : 4 30 | } 31 | }, 32 | "decoder" : { 33 | "sample_prob": 0.2, 34 | "embedding_dim" : 256, 35 | "log_t" : true, 36 | "layers" : 1 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /tests/ctc_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.autograd as autograd 4 | 5 | from speech.models import CTC 6 | 7 | import shared 8 | 9 | def test_ctc_model(): 10 | freq_dim = 40 11 | vocab_size = 10 12 | 13 | batch = shared.gen_fake_data(freq_dim, vocab_size) 14 | batch_size = len(batch[0]) 15 | 16 | model = CTC(freq_dim, vocab_size, shared.model_config) 17 | out = model(batch) 18 | 19 | assert out.size()[0] == batch_size 20 | 21 | # CTC model adds the blank token to the vocab 22 | assert out.size()[2] == (vocab_size + 1) 23 | 24 | assert len(out.size()) == 3 25 | 26 | loss = model.loss(batch) 27 | preds = model.infer(batch) 28 | assert len(preds) == batch_size 29 | 30 | 31 | def test_argmax_decode(): 32 | blank = 0 33 | pre = [1, 2, 2, 0, 0, 0, 2, 1] 34 | post = [1, 2, 2, 1] 35 | assert CTC.max_decode(pre, blank) == post 36 | 37 | pre = [2, 2, 2] 38 | post = [2] 39 | assert CTC.max_decode(pre, blank) == post 40 | 41 | pre = [0, 0, 0] 42 | post = [] 43 | assert CTC.max_decode(pre, blank) == post 44 | -------------------------------------------------------------------------------- /tests/loader_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from speech import loader 4 | 5 | def test_dataset(): 6 | batch_size = 2 7 | data_json = "test.json" 8 | preproc = loader.Preprocessor(data_json) 9 | dataset = loader.AudioDataset(data_json, preproc, batch_size) 10 | 11 | # Num chars plus start and end tokens 12 | assert preproc.vocab_size == 11 13 | s_idx = preproc.vocab_size - 1 14 | assert preproc.int_to_char[s_idx] == preproc.START 15 | 16 | inputs, targets = dataset[0] 17 | 18 | # Inputs should be time x frequency 19 | assert inputs.shape[1] == preproc.input_dim 20 | assert inputs.dtype == np.float32 21 | 22 | # Correct number of examples 23 | assert len(dataset.data) == 8 24 | 25 | def test_loader(): 26 | 27 | batch_size = 2 28 | data_json = "test.json" 29 | preproc = loader.Preprocessor(data_json) 30 | ldr = loader.make_loader(data_json, preproc, 31 | batch_size, num_workers=0) 32 | 33 | # Test that batches are properly sorted by size 34 | for inputs, labels in ldr: 35 | assert inputs[0].shape == inputs[1].shape 36 | -------------------------------------------------------------------------------- 
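The loader and the tests above consume the same JSON-lines manifest format used throughout the project: one JSON object per line with `text`, `duration`, and `audio` fields, exactly as in `tests/test.json` and in the files written by the `preprocess.py` scripts under `examples/`. As a rough sketch, building such a manifest for a custom dataset with `speech.utils.wave.wav_duration` could look like this (the paths and transcripts here are hypothetical):

```
import json

from speech.utils import wave

# Hypothetical (audio file, transcript) pairs for a custom dataset.
examples = [("data/utt0.wav", "hello world"),
            ("data/utt1.wav", "hello hi")]

# Write one JSON object per line, matching the format of tests/test.json.
with open("custom_train.json", "w") as fid:
    for audio_path, text in examples:
        datum = {"text": text,
                 "duration": wave.wav_duration(audio_path),  # seconds
                 "audio": audio_path}
        json.dump(datum, fid)
        fid.write("\n")
```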
/examples/librispeech/download.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | import tarfile 8 | import urllib.request 9 | 10 | EXT = ".tar.gz" 11 | FILES = ["raw-metadata", "train-clean-100", "dev-clean"] 12 | BASE_URL = "http://www.openslr.org/resources/12/" 13 | 14 | def download_and_extract(in_file, out_dir): 15 | in_file = in_file + EXT 16 | file_url = os.path.join(BASE_URL, in_file) 17 | out_file = os.path.join(out_dir, in_file) 18 | 19 | # Download and extract zip file. 20 | urllib.request.urlretrieve(file_url, filename=out_file) 21 | with tarfile.open(out_file) as tf: 22 | tf.extractall(path=out_dir) 23 | 24 | # Remove zip file after use 25 | os.remove(out_file) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser( 30 | description="Download librispeech dataset.") 31 | 32 | parser.add_argument("output_directory", 33 | help="The dataset is saved in /LibriSpeech.") 34 | args = parser.parse_args() 35 | 36 | for f in FILES: 37 | print("Downloading {}".format(f)) 38 | download_and_extract(f, args.output_directory) 39 | -------------------------------------------------------------------------------- /tests/seq2seq_test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.autograd as autograd 5 | 6 | from speech.models import Seq2Seq 7 | 8 | import shared 9 | 10 | def test_model(): 11 | freq_dim = 120 12 | vocab_size = 10 13 | 14 | np.random.seed(1337) 15 | torch.manual_seed(1337) 16 | 17 | conf = shared.model_config 18 | rnn_dim = conf['encoder']['rnn']['dim'] 19 | conf["decoder"] = {"embedding_dim" : rnn_dim, 20 | "layers" : 2} 21 | model = Seq2Seq(freq_dim, vocab_size + 1, conf) 22 | batch = shared.gen_fake_data(freq_dim, vocab_size) 23 | batch_size = len(batch[0]) 24 | 25 | out = model(batch) 26 | loss = model.loss(batch) 27 | 28 | assert out.size()[0] == batch_size 29 | assert out.size()[2] == vocab_size 30 | assert len(out.size()) == 3 31 | 32 | x, y = model.collate(*batch) 33 | x_enc = model.encode(x) 34 | 35 | state = None 36 | out_s = [] 37 | for t in range(y.size()[1] - 1): 38 | ox, state = model.decode_step(x_enc, y[:,t:t+1], state=state) 39 | out_s.append(ox) 40 | out_s = torch.stack(out_s, dim=1) 41 | assert out.size() == out_s.size() 42 | assert np.allclose(out_s.data.numpy(), 43 | out.data.numpy(), 44 | rtol=1e-5, 45 | atol=1e-7) 46 | 47 | -------------------------------------------------------------------------------- /speech/utils/convert.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import subprocess 6 | 7 | FFMPEG = "ffmpeg" 8 | AVCONV = "avconv" 9 | 10 | def check_install(*args): 11 | try: 12 | subprocess.check_output(args, 13 | stderr=subprocess.STDOUT) 14 | return True 15 | except OSError as e: 16 | return False 17 | 18 | def check_avconv(): 19 | """ 20 | Check if avconv is installed. 21 | """ 22 | return check_install(AVCONV, "-version") 23 | 24 | def check_ffmpeg(): 25 | """ 26 | Check if ffmpeg is installed. 
27 | """ 28 | return check_install(FFMPEG, "-version") 29 | 30 | 31 | USE_AVCONV = check_avconv() 32 | USE_FFMPEG = check_ffmpeg() 33 | if not (USE_AVCONV or USE_FFMPEG): 34 | raise OSError(("Must have avconv or ffmpeg " 35 | "installed to use conversion functions.")) 36 | USE_AVCONV = not USE_FFMPEG 37 | 38 | def to_wave(audio_file, wave_file, use_avconv=USE_AVCONV): 39 | """ 40 | Convert audio file to wave format. 41 | """ 42 | prog = AVCONV if use_avconv else FFMPEG 43 | args = [prog, "-y", "-i", audio_file, "-f", "wav", wave_file] 44 | subprocess.check_output(args, stderr=subprocess.STDOUT) 45 | 46 | if __name__ == "__main__": 47 | print("Use avconv", USE_AVCONV) 48 | 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # TODOs 2 | libs/ 3 | notebooks/ 4 | 5 | # vim 6 | .*.swp 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import json 7 | import torch 8 | import tqdm 9 | import speech 10 | import speech.loader as loader 11 | 12 | def eval_loop(model, ldr): 13 | all_preds = []; all_labels = [] 14 | for batch in tqdm.tqdm(ldr): 15 | preds = model.infer(batch) 16 | all_preds.extend(preds) 17 | all_labels.extend(batch[1]) 18 | return list(zip(all_labels, all_preds)) 19 | 20 | def run(model_path, dataset_json, 21 | batch_size=8, tag="best", 22 | out_file=None): 23 | 24 | use_cuda = torch.cuda.is_available() 25 | 26 | model, preproc = speech.load(model_path, tag=tag) 27 | ldr = loader.make_loader(dataset_json, 28 | preproc, batch_size) 29 | 30 | model.cuda() if use_cuda 
else model.cpu() 31 | model.set_eval() 32 | 33 | results = eval_loop(model, ldr) 34 | results = [(preproc.decode(label), preproc.decode(pred)) 35 | for label, pred in results] 36 | cer = speech.compute_cer(results) 37 | print("CER {:.3f}".format(cer)) 38 | 39 | if out_file is not None: 40 | with open(out_file, 'w') as fid: 41 | for label, pred in results: 42 | res = {'prediction' : pred, 43 | 'label' : label} 44 | json.dump(res, fid) 45 | fid.write("\n") 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser( 49 | description="Eval a speech model.") 50 | 51 | parser.add_argument("model", 52 | help="A path to a stored model.") 53 | parser.add_argument("dataset", 54 | help="A json file with the dataset to evaluate.") 55 | parser.add_argument("--last", action="store_true", 56 | help="Last saved model instead of best on dev set.") 57 | parser.add_argument("--save", 58 | help="Optional file to save predicted results.") 59 | args = parser.parse_args() 60 | 61 | run(args.model, args.dataset, 62 | tag=None if args.last else "best", 63 | out_file=args.save) 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # speech 2 | 3 | Speech is an open-source package to build end-to-end models for automatic 4 | speech recognition. Sequence-to-sequence models with attention, 5 | Connectionist Temporal Classification and the RNN Sequence Transducer 6 | are currently supported. 7 | 8 | The goal of this software is to facilitate research in end-to-end models for 9 | speech recognition. The models are implemented in PyTorch. 10 | 11 | The software has only been tested with Python 3.6. 12 | 13 | **We will not be providing backward compatibility for Python 2.7.** 14 | 15 | ## Install 16 | 17 | We recommend creating a virtual environment and installing the Python 18 | requirements there. 19 | 20 | ``` 21 | virtualenv <path_to_venv> 22 | source <path_to_venv>/bin/activate 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | Then follow the installation instructions for a version of 27 | [PyTorch](http://pytorch.org/) which works for your machine. 28 | 29 | After all the Python requirements are installed, from the top-level directory, 30 | run: 31 | 32 | ``` 33 | make 34 | ``` 35 | 36 | The build process requires CMake as well as Make. 37 | 38 | After that, source the `setup.sh` from the repo root. 39 | 40 | ``` 41 | source setup.sh 42 | ``` 43 | 44 | Consider adding this to your `.bashrc`. 45 | 46 | You can verify the install was successful by running the 47 | tests from the `tests` directory. 48 | 49 | ``` 50 | cd tests 51 | pytest 52 | ``` 53 | 54 | ## Run 55 | 56 | To train a model, run 57 | ``` 58 | python train.py <path_to_config> 59 | ``` 60 | 61 | After the model is done training you can evaluate it with 62 | 63 | ``` 64 | python eval.py <path_to_model> <path_to_data_json> 65 | ``` 66 | 67 | To see the available options for each script use `-h`: 68 | 69 | ``` 70 | python {train, eval}.py -h 71 | ``` 72 | 73 | ## Examples 74 | 75 | For examples of model configurations and datasets, visit the examples 76 | directory. Each example dataset should have instructions and/or scripts for 77 | downloading and preparing the data. There should also be one or more model 78 | configurations available. The results for each configuration will be documented in 79 | each example's corresponding `README.md`.
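`train.py` is driven by a single JSON configuration file. As a rough sketch, modeled on `examples/timit/ctc_config.json`, a configuration has the following shape; the paths and hyperparameters below are placeholders rather than recommended values.

```
{
    "seed" : 2017,
    "save_path" : "path/to/save/model",

    "data" : {
        "train_set" : "path/to/train.json",
        "dev_set" : "path/to/dev.json",
        "start_and_end" : false
    },

    "optimizer" : {
        "batch_size" : 8,
        "epochs" : 200,
        "learning_rate" : 1e-3,
        "momentum" : 0.0
    },

    "model" : {
        "class" : "CTC",
        "dropout" : 0.4,
        "encoder" : {
            "conv" : [
                [32, 5, 32, 2]
            ],
            "rnn" : {
                "dim" : 256,
                "bidirectional" : true,
                "layers" : 4
            }
        }
    }
}
```

The `class` field selects one of the model classes exported by `speech.models` (`CTC`, `Seq2Seq`, or `Transducer`); `Seq2Seq` and `Transducer` configurations additionally include a `decoder` section, as shown in the example configs.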
80 | -------------------------------------------------------------------------------- /speech/models/ctc_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.autograd as autograd 8 | 9 | import functions.ctc as ctc 10 | from . import model 11 | from .ctc_decoder import decode 12 | 13 | class CTC(model.Model): 14 | def __init__(self, freq_dim, output_dim, config): 15 | super().__init__(freq_dim, config) 16 | 17 | # include the blank token 18 | self.blank = output_dim 19 | self.fc = model.LinearND(self.encoder_dim, output_dim + 1) 20 | 21 | def forward(self, batch): 22 | x, y, x_lens, y_lens = self.collate(*batch) 23 | return self.forward_impl(x) 24 | 25 | def forward_impl(self, x, softmax=False): 26 | if self.is_cuda: 27 | x = x.cuda() 28 | x = self.encode(x) 29 | x = self.fc(x) 30 | if softmax: 31 | return torch.nn.functional.softmax(x, dim=2) 32 | return x 33 | 34 | def loss(self, batch): 35 | x, y, x_lens, y_lens = self.collate(*batch) 36 | out = self.forward_impl(x) 37 | 38 | loss_fn = ctc.CTCLoss() 39 | loss = loss_fn(out, y, x_lens, y_lens) 40 | return loss 41 | 42 | def collate(self, inputs, labels): 43 | max_t = max(i.shape[0] for i in inputs) 44 | max_t = self.conv_out_size(max_t, 0) 45 | x_lens = torch.IntTensor([max_t] * len(inputs)) 46 | x = torch.FloatTensor(model.zero_pad_concat(inputs)) 47 | y_lens = torch.IntTensor([len(l) for l in labels]) 48 | y = torch.IntTensor([l for label in labels for l in label]) 49 | batch = [x, y, x_lens, y_lens] 50 | if self.volatile: 51 | for v in batch: 52 | v.volatile = True 53 | return batch 54 | 55 | def infer(self, batch): 56 | x, y, x_lens, y_lens = self.collate(*batch) 57 | probs = self.forward_impl(x, softmax=True) 58 | probs = probs.data.cpu().numpy() 59 | return [decode(p, beam_size=1, blank=self.blank)[0] 60 | for p in probs] 61 | 62 | @staticmethod 63 | def max_decode(pred, blank): 64 | prev = pred[0] 65 | seq = [prev] if prev != blank else [] 66 | for p in pred[1:]: 67 | if p != blank and p != prev: 68 | seq.append(p) 69 | prev = p 70 | return seq 71 | -------------------------------------------------------------------------------- /examples/librispeech/preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import glob 7 | import json 8 | import os 9 | import tqdm 10 | import wave 11 | 12 | from speech.utils import data_helpers 13 | from speech.utils import wave 14 | 15 | SETS = { 16 | "train" : ["train-clean-100"], 17 | "dev" : ["dev-clean"], 18 | "test" : [] 19 | } 20 | 21 | def load_transcripts(path): 22 | pattern = os.path.join(path, "*/*/*.trans.txt") 23 | files = glob.glob(pattern) 24 | data = {} 25 | for f in files: 26 | with open(f) as fid: 27 | lines = (l.strip().split() for l in fid) 28 | lines = ((l[0], " ".join(l[1:])) for l in lines) 29 | data.update(lines) 30 | return data 31 | 32 | def path_from_key(key, prefix, ext): 33 | dirs = key.split("-") 34 | dirs[-1] = key 35 | path = os.path.join(prefix, *dirs) 36 | return path + os.path.extsep + ext 37 | 38 | def convert_to_wav(path): 39 | data_helpers.convert_full_set(path, "*/*/*/*.flac") 40 | 41 | def clean_text(text): 42 | return text.strip().lower() 43 | 44 | def 
build_json(path): 45 | transcripts = load_transcripts(path) 46 | dirname = os.path.dirname(path) 47 | basename = os.path.basename(path) + os.path.extsep + "json" 48 | with open(os.path.join(dirname, basename), 'w') as fid: 49 | for k, t in tqdm.tqdm(transcripts.items()): 50 | wave_file = path_from_key(k, path, ext="wav") 51 | dur = wave.wav_duration(wave_file) 52 | t = clean_text(t) 53 | datum = {'text' : t, 54 | 'duration' : dur, 55 | 'audio' : wave_file} 56 | json.dump(datum, fid) 57 | fid.write("\n") 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser( 61 | description="Preprocess librispeech dataset.") 62 | 63 | parser.add_argument("output_directory", 64 | help="The dataset is saved in /LibriSpeech.") 65 | args = parser.parse_args() 66 | 67 | path = os.path.join(args.output_directory, "LibriSpeech") 68 | 69 | print("Converting files from flac to wave...") 70 | convert_to_wav(path) 71 | for dataset, dirs in SETS.items(): 72 | for d in dirs: 73 | print("Preprocessing {}".format(d)) 74 | prefix = os.path.join(path, d) 75 | build_json(prefix) 76 | -------------------------------------------------------------------------------- /examples/timit/README.md: -------------------------------------------------------------------------------- 1 | The Timit Speech corpus must be purchased from the LDC to run these 2 | experiments. The catalog number is [LDC93S1]. 3 | 4 | The data is mapped from 61 to 48 phonemes for training. For final test set 5 | evaluation the 48 phonemes are again mapped to 39. The phoneme mapping is the 6 | standard recipe, the map used here is taken from the [Kaldi TIMIT recipe]. 7 | 8 | ## Setup 9 | 10 | Once you have the TIMIT data downloaded, run 11 | 12 | ``` 13 | ./data_prep.sh 14 | ``` 15 | 16 | This script will convert the `.flac` to `.wav` files and store them in the same 17 | location. You'll need write access to directory where timit is stored. It will 18 | then symlink the timit directory to `./data`. There should be three data json 19 | files in `data/timit`: 20 | 21 | - `train.json`: 3696 utterances from 462 speakers 22 | - `dev.json`: 400 utterances from 50 held-out speakers 23 | - `test.json`: 192 utterances from 24 speakers, the standard TIMIT test set 24 | 25 | ## Train 26 | 27 | There is a CTC and a sequence-to-sequence with attention configuration. Before 28 | training a model, edit the configuration file. In particular, set the 29 | `save_path` to a location where you'd like to store the model. Edit any other 30 | parameters for your experiment. From the top-level directory, you can train the 31 | model with 32 | 33 | ``` 34 | python train.py examples/timit/seq2seq_config.json 35 | ``` 36 | 37 | ## Score 38 | 39 | Save the 48 phoneme predictions with the top-level `eval.py` script. 40 | 41 | ``` 42 | python eval.py examples/timit/data/timit/test.json --save predictions.json 43 | ``` 44 | 45 | To score using the reduced phoneme set (39 phonemes) run 46 | 47 | ``` 48 | python examples/timit/score.py predictions.json 49 | ``` 50 | 51 | ## Results 52 | 53 | TODO, awni, results are from an earlier version of the training set. Need to 54 | update the results for the 462 speaker training set. 55 | 56 | *NB* for best results with all models, evaluate with a batch size of 1. 57 | Otherwise the scores can be slightly worse due to the fact that we pad the 58 | inputs to all be the same length in a given batch. 59 | 60 | ### seq2seq 61 | 62 | These are the dev and test results for the best sequence-to-sequence model with 63 | attention. 
The configuration can be found in `seq2seq_config.json`. Note this 64 | is without an external LM and with a beam size of 1. Also we don't use any 65 | speaker adaptation or sophisticated features (MFCCs). Results *should* improve 66 | with these features. 67 | 68 | - Dev: 16.8 PER 69 | - Test: 18.7 PER 70 | 71 | ### CTC 72 | 73 | These are the dev and test results for the best CTC model. The configuration 74 | can be found in `ctc_config.json`. Note this is without an external LM and with 75 | `argmax` decoding. Also we don't use any speaker adaptation or sophisticated 76 | features (MFCCs). Results *should* improve with these features. 77 | 78 | - Dev: 15.4 PER 79 | - Test: 17.6 PER 80 | 81 | ## Leaderboard 82 | 83 | | Paper | Test PER | Model | Features | Notes| 84 | |---|---|---|---|---| 85 | | [Speech Recognition with Deep Recurrent Neural Networks](https://arxiv.org/abs/1303.5778) | 17.7 | Transducer | MFCC + deltas | 3-layer bidirectional LSTM, beam search decoder, no external LM, pretrained CTC encoder | 86 | | [Attention-Based Models for Speech Recognition](https://arxiv.org/abs/1506.07503) | 17.6 | Seq2seq | MFCC + deltas | 3-layer bidirectional GRU, beam search decoder, no external LM | 87 | 88 | [Kaldi TIMIT recipe]: https://github.com/kaldi-asr/kaldi/blob/master/egs/timit/s5/conf/phones.60-48-39.map 89 | [LDC93S1]: https://catalog.ldc.upenn.edu/ldc93s1 90 | -------------------------------------------------------------------------------- /examples/timit/preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import collections 7 | import glob 8 | import json 9 | import os 10 | import random 11 | import tqdm 12 | 13 | from speech.utils import data_helpers 14 | from speech.utils import wave 15 | 16 | WAV_EXT = "wv" # using wv since NIST took wav 17 | TEST_SPEAKERS = [ # Core test set from timit/readme.doc 18 | 'mdab0', 'mwbt0', 'felc0', 'mtas1', 'mwew0', 'fpas0', 19 | 'mjmp0', 'mlnt0', 'fpkt0', 'mlll0', 'mtls0', 'fjlm0', 20 | 'mbpm0', 'mklt0', 'fnlp0', 'mcmj0', 'mjdh0', 'fmgd0', 21 | 'mgrt0', 'mnjm0', 'fdhc0', 'mjln0', 'mpam0', 'fmld0'] 22 | 23 | def load_phone_map(): 24 | with open("phones.60-48-39.map", 'r') as fid: 25 | lines = (l.strip().split() for l in fid) 26 | lines = [l for l in lines if len(l) == 3] 27 | m60_48 = {l[0] : l[1] for l in lines} 28 | m48_39 = {l[1] : l[2] for l in lines} 29 | return m60_48, m48_39 30 | 31 | def load_transcripts(path): 32 | pattern = os.path.join(path, "*/*/*.phn") 33 | m60_48, _ = load_phone_map() 34 | files = glob.glob(pattern) 35 | # Standard practic is to remove all "sa" sentences 36 | # for each speaker since they are the same for all. 
37 | filt_sa = lambda x : os.path.basename(x)[:2] != "sa" 38 | files = filter(filt_sa, files) 39 | data = {} 40 | for f in files: 41 | with open(f) as fid: 42 | lines = (l.strip() for l in fid) 43 | phonemes = (l.split()[-1] for l in lines) 44 | phonemes = [m60_48[p] for p in phonemes if p in m60_48] 45 | data[f] = phonemes 46 | return data 47 | 48 | def split_by_speaker(data, dev_speakers=50): 49 | 50 | def speaker_id(f): 51 | return os.path.basename(os.path.dirname(f)) 52 | 53 | speaker_dict = collections.defaultdict(list) 54 | for k, v in data.items(): 55 | speaker_dict[speaker_id(k)].append((k, v)) 56 | speakers = speaker_dict.keys() 57 | for t in TEST_SPEAKERS: 58 | speakers.remove(t) 59 | random.shuffle(speakers) 60 | dev = speakers[:dev_speakers] 61 | dev = dict(v for s in dev for v in speaker_dict[s]) 62 | test = dict(v for s in TEST_SPEAKERS for v in speaker_dict[s]) 63 | return dev, test 64 | 65 | def convert_to_wav(path): 66 | data_helpers.convert_full_set(path, "*/*/*/*.wav", 67 | new_ext=WAV_EXT, 68 | use_avconv=False) 69 | 70 | def build_json(data, path, set_name): 71 | basename = set_name + os.path.extsep + "json" 72 | with open(os.path.join(path, basename), 'w') as fid: 73 | for k, t in tqdm.tqdm(data.items()): 74 | wave_file = os.path.splitext(k)[0] + os.path.extsep + WAV_EXT 75 | dur = wave.wav_duration(wave_file) 76 | datum = {'text' : t, 77 | 'duration' : dur, 78 | 'audio' : wave_file} 79 | json.dump(datum, fid) 80 | fid.write("\n") 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser( 84 | description="Preprocess Timit dataset.") 85 | 86 | parser.add_argument("output_directory", 87 | help="Path where the dataset is saved.") 88 | args = parser.parse_args() 89 | 90 | path = os.path.join(args.output_directory, "timit") 91 | path = os.path.abspath(path) 92 | 93 | print("Converting files from NIST to standard wave format...") 94 | convert_to_wav(path) 95 | 96 | print("Preprocessing train") 97 | train = load_transcripts(os.path.join(path, "train")) 98 | build_json(train, path, "train") 99 | 100 | print("Preprocessing dev") 101 | transcripts = load_transcripts(os.path.join(path, "test")) 102 | dev, test = split_by_speaker(transcripts) 103 | build_json(dev, path, "dev") 104 | 105 | print("Preprocessing test") 106 | build_json(test, path, "test") 107 | -------------------------------------------------------------------------------- /speech/models/transducer_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.autograd as autograd 9 | 10 | import transducer.decoders as td 11 | import transducer.functions.transducer as transducer 12 | from . 
import model 13 | 14 | class Transducer(model.Model): 15 | def __init__(self, freq_dim, vocab_size, config): 16 | super().__init__(freq_dim, config) 17 | 18 | # For decoding 19 | decoder_cfg = config["decoder"] 20 | rnn_dim = self.encoder_dim 21 | embed_dim = decoder_cfg["embedding_dim"] 22 | self.embedding = nn.Embedding(vocab_size, embed_dim) 23 | self.dec_rnn = nn.GRU(input_size=embed_dim, 24 | hidden_size=rnn_dim, 25 | num_layers=decoder_cfg["layers"], 26 | batch_first=True, dropout=config["dropout"]) 27 | 28 | # include the blank token 29 | self.blank = vocab_size 30 | self.fc1 = model.LinearND(rnn_dim, rnn_dim) 31 | self.fc2 = model.LinearND(rnn_dim, vocab_size + 1) 32 | 33 | def forward(self, batch): 34 | x, y, x_lens, y_lens = self.collate(*batch) 35 | y_mat = self.label_collate(batch[1]) 36 | return self.forward_impl(x, y_mat) 37 | 38 | def forward_impl(self, x, y): 39 | if self.is_cuda: 40 | x = x.cuda() 41 | y = y.cuda() 42 | x = self.encode(x) 43 | out = self.decode(x, y) 44 | return out 45 | 46 | def loss(self, batch): 47 | x, y, x_lens, y_lens = self.collate(*batch) 48 | y_mat = self.label_collate(batch[1]) 49 | out = self.forward_impl(x, y_mat) 50 | loss_fn = transducer.TransducerLoss() 51 | loss = loss_fn(out, y, x_lens, y_lens) 52 | return loss 53 | 54 | def decode(self, x, y): 55 | """ 56 | x should be shape (batch, time, hidden dimension) 57 | y should be shape (batch, label sequence length) 58 | """ 59 | y = self.embedding(y) 60 | 61 | # preprend zeros 62 | b, t, h = y.shape 63 | start = torch.zeros((b, 1, h)) 64 | if self.is_cuda: 65 | start = start.cuda() 66 | y = torch.cat([start, y], dim=1) 67 | 68 | y, _ = self.dec_rnn(y) 69 | 70 | # Combine the input states and the output states 71 | x = x.unsqueeze(dim=2) 72 | y = y.unsqueeze(dim=1) 73 | out = self.fc1(x) + self.fc1(y) 74 | out = nn.functional.relu(out) 75 | out = self.fc2(out) 76 | out = nn.functional.log_softmax(out, dim=3) 77 | return out 78 | 79 | def collate(self, inputs, labels): 80 | max_t = max(i.shape[0] for i in inputs) 81 | max_t = self.conv_out_size(max_t, 0) 82 | x_lens = torch.IntTensor([max_t] * len(inputs)) 83 | x = torch.FloatTensor(model.zero_pad_concat(inputs)) 84 | y_lens = torch.IntTensor([len(l) for l in labels]) 85 | y = torch.IntTensor([l for label in labels for l in label]) 86 | batch = [x, y, x_lens, y_lens] 87 | if self.volatile: 88 | for v in batch: 89 | v.volatile = True 90 | return batch 91 | 92 | def infer(self, batch, beam_size=4): 93 | out = self(batch) 94 | out = out.cpu().data.numpy() 95 | preds = [] 96 | for e, (i, l) in enumerate(zip(*batch)): 97 | T = i.shape[0] 98 | U = len(l) + 1 99 | lp = out[e, :T, :U, :] 100 | preds.append(td.decode_static(lp, beam_size, blank=self.blank)[0]) 101 | return preds 102 | 103 | def label_collate(self, labels): 104 | # Doesn't matter what we pad the end with 105 | # since it will be ignored. 
106 | batch_size = len(labels) 107 | end_tok = labels[0][-1] 108 | max_len = max(len(l) for l in labels) 109 | cat_labels = np.full((batch_size, max_len), 110 | fill_value=end_tok, dtype=np.int64) 111 | for e, l in enumerate(labels): 112 | cat_labels[e, :len(l)] = l 113 | labels = torch.LongTensor(cat_labels) 114 | if self.volatile: 115 | labels.volatile = True 116 | return labels 117 | -------------------------------------------------------------------------------- /speech/models/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | class Model(nn.Module): 11 | 12 | def __init__(self, input_dim, config): 13 | super().__init__() 14 | self.input_dim = input_dim 15 | 16 | encoder_cfg = config["encoder"] 17 | conv_cfg = encoder_cfg["conv"] 18 | 19 | convs = [] 20 | in_c = 1 21 | for out_c, h, w, s in conv_cfg: 22 | conv = nn.Conv2d(in_c, out_c, (h, w), 23 | stride=(s, s), padding=0) 24 | convs.extend([conv, nn.ReLU()]) 25 | if config["dropout"] != 0: 26 | convs.append(nn.Dropout(p=config["dropout"])) 27 | in_c = out_c 28 | 29 | self.conv = nn.Sequential(*convs) 30 | conv_out = out_c * self.conv_out_size(input_dim, 1) 31 | assert conv_out > 0, \ 32 | "Convolutional ouptut frequency dimension is negative." 33 | 34 | rnn_cfg = encoder_cfg["rnn"] 35 | self.rnn = nn.GRU(input_size=conv_out, 36 | hidden_size=rnn_cfg["dim"], 37 | num_layers=rnn_cfg["layers"], 38 | batch_first=True, dropout=config["dropout"], 39 | bidirectional=rnn_cfg["bidirectional"]) 40 | self._encoder_dim = rnn_cfg["dim"] 41 | 42 | self.volatile = False 43 | 44 | def conv_out_size(self, n, dim): 45 | for c in self.conv.children(): 46 | if type(c) == nn.Conv2d: 47 | # assuming a valid convolution 48 | k = c.kernel_size[dim] 49 | s = c.stride[dim] 50 | n = (n - k + 1) / s 51 | n = int(math.ceil(n)) 52 | return n 53 | 54 | def forward(self, batch): 55 | """ 56 | Must be overridden by subclasses. 57 | """ 58 | raise NotImplementedError 59 | 60 | def encode(self, x): 61 | x = x.unsqueeze(1) 62 | x = self.conv(x) 63 | 64 | # At this point x should have shape 65 | # (batch, channels, time, freq) 66 | x = torch.transpose(x, 1, 2).contiguous() 67 | 68 | # Reshape x to be (batch, time, freq * channels) 69 | # for the RNN 70 | b, t, f, c = x.size() 71 | x = x.view((b, t, f * c)) 72 | 73 | x, h = self.rnn(x) 74 | 75 | if self.rnn.bidirectional: 76 | half = x.size()[-1] // 2 77 | x = x[:, :, :half] + x[:, :, half:] 78 | 79 | return x 80 | 81 | def loss(self, x, y): 82 | """ 83 | Must be overridden by subclasses. 84 | """ 85 | raise NotImplementedError 86 | 87 | def set_eval(self): 88 | """ 89 | Set the model to evaluation mode. 90 | """ 91 | self.eval() 92 | self.volatile = True 93 | 94 | def set_train(self): 95 | """ 96 | Set the model to training mode. 97 | """ 98 | self.train() 99 | self.volatile = False 100 | 101 | def infer(self, x): 102 | """ 103 | Must be overridden by subclasses. 104 | """ 105 | raise NotImplementedError 106 | 107 | @property 108 | def is_cuda(self): 109 | return list(self.parameters())[0].is_cuda 110 | 111 | @property 112 | def encoder_dim(self): 113 | return self._encoder_dim 114 | 115 | class LinearND(nn.Module): 116 | 117 | def __init__(self, *args): 118 | """ 119 | A torch.nn.Linear layer modified to accept ND arrays. 
120 | The function treats the last dimension of the input 121 | as the hidden dimension. 122 | """ 123 | super(LinearND, self).__init__() 124 | self.fc = nn.Linear(*args) 125 | 126 | def forward(self, x): 127 | size = x.size() 128 | n = int(np.prod(size[:-1])) 129 | out = x.contiguous().view(n, size[-1]) 130 | out = self.fc(out) 131 | size = list(size) 132 | size[-1] = out.size()[-1] 133 | return out.view(size) 134 | 135 | def zero_pad_concat(inputs): 136 | max_t = max(inp.shape[0] for inp in inputs) 137 | shape = (len(inputs), max_t, inputs[0].shape[1]) 138 | input_mat = np.zeros(shape, dtype=np.float32) 139 | for e, inp in enumerate(inputs): 140 | input_mat[e, :inp.shape[0], :] = inp 141 | return input_mat 142 | 143 | -------------------------------------------------------------------------------- /examples/wsj/preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import json 7 | import glob 8 | import os 9 | import re 10 | import subprocess 11 | import tqdm 12 | 13 | from speech.utils import wave 14 | 15 | DATASETS = { 16 | "train_si284" : ["wsj1/doc/indices/si_tr_s.ndx", 17 | "wsj0/doc/indices/train/tr_s_wv1.ndx"], 18 | "eval_92" : ["wsj0/doc/indices/test/nvp/si_et_20.ndx"], 19 | "dev_93" : ["wsj1/doc/indices/h1_p0.ndx"] 20 | } 21 | DOT_PATHS = ["wsj0/transcrp/dots/*/*/*.dot", 22 | "wsj1/trans/wsj1/*/*/*.dot", 23 | "wsj0/si_et_20/*/*.dot"] 24 | ALLOWED = set("abcdefghijklmnopqrstuvwxyz.' -") 25 | REPLACE = { 26 | ".point" : "point", 27 | ".period": "period", 28 | "'single-quote": "single-quote", 29 | "'single-close-quote": "single-close-quote", 30 | "`single-quote" : "single-quote", 31 | "-hyphen": "hyphen", 32 | ")close_paren" : "close-paren", 33 | "(left(-paren)-": "left-", 34 | "." 
: "", 35 | "--dash" : "dash", 36 | "-dash" : "dash", 37 | } 38 | 39 | def load_text(wsj_base): 40 | transcripts = [] 41 | dots = [] 42 | for d in DOT_PATHS: 43 | dots.extend(glob.glob(os.path.join(wsj_base, d))) 44 | for f in dots: 45 | with open(f, 'r') as fid: 46 | transcripts.extend(l.strip() for l in fid) 47 | transcripts = (t.split() for t in transcripts) 48 | # Key text by utterance id 49 | transcripts = {t[-1][1:-1] : clean(" ".join(t[:-1])) 50 | for t in transcripts} 51 | return transcripts 52 | 53 | def load_waves(wsj_base, files): 54 | waves = [] 55 | for f in files: 56 | flist = os.path.join(wsj_base, f) 57 | with open(flist, 'r') as fid: 58 | lines = (l.split(":")[1].strip().strip("/") 59 | for l in fid if l[0] != ';') 60 | lines = (os.path.join(wsj_base, l) for l in lines) 61 | # Replace wv1 with wav 62 | lines = (os.path.splitext(l)[0] + ".wav" for l in lines) 63 | waves.extend(sorted(lines)) 64 | return waves 65 | 66 | def clean(line): 67 | pl = line 68 | line = line.lower() 69 | line = re.sub("<|>|\\\\|\[\S+\]", "", line) 70 | toks = line.split() 71 | clean_toks = [] 72 | for tok in toks: 73 | if re.match("\S+-dash", tok): 74 | clean_toks.extend(tok.split("-")) 75 | else: 76 | clean_toks.append(REPLACE.get(tok, tok)) 77 | line = " ".join(t for t in clean_toks if t).strip() 78 | line = re.sub("\(\S*\)", "", line) 79 | line = re.sub("[()\*\":\?;!}{\~<>/&,\$\%\~]", "", line) 80 | line = re.sub("`", "'", line) 81 | line = " ".join(line.split()) 82 | return line 83 | 84 | def write_json(save_path, dataset, waves, transcripts): 85 | out_file = os.path.join(save_path, dataset + ".json") 86 | with open(out_file, 'w') as fid: 87 | for wave_file in tqdm.tqdm(waves): 88 | dur = wave.wav_duration(wave_file) 89 | key = os.path.basename(wave_file) 90 | key = os.path.splitext(key)[0] 91 | datum = {'text' : transcripts[key], 92 | 'duration' : dur, 93 | 'audio' : wave_file} 94 | json.dump(datum, fid) 95 | fid.write("\n") 96 | 97 | def convert_sph_to_wav(files): 98 | command = ["sph2pipe_v2.5/sph2pipe", "-p", "-f", 99 | "wav", "-c", "1"] 100 | for out_f in tqdm.tqdm(files): 101 | sph_f = os.path.splitext(out_f)[0] + ".wv1" 102 | subprocess.call(command + [sph_f, out_f]) 103 | 104 | if __name__ == "__main__": 105 | parser = argparse.ArgumentParser( 106 | description="Preprocess WSJ dataset.") 107 | parser.add_argument("wsj_base", 108 | help="Path where the dataset is stored.") 109 | parser.add_argument("save_path", 110 | help="Path to save dataset jsons.") 111 | parser.add_argument("--convert", action="store_true", 112 | help="Convert sphere to wav format.") 113 | args = parser.parse_args() 114 | 115 | transcripts = load_text(args.wsj_base) 116 | for d, v in DATASETS.items(): 117 | waves = load_waves(args.wsj_base, v) 118 | if args.convert: 119 | print("Converting {}".format(d)) 120 | convert_sph_to_wav(waves) 121 | if d == "train_si284": 122 | waves = filter(lambda x: "wsj0/si_tr_s/401" not in x, waves) 123 | print("Writing {}".format(d)) 124 | write_json(args.save_path, d, waves, transcripts) 125 | 126 | -------------------------------------------------------------------------------- /speech/models/ctc_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Awni Hannun 3 | 4 | This is an example CTC decoder written in Python. The code is 5 | intended to be a simple example and is not designed to be 6 | especially efficient. 7 | 8 | The algorithm is a prefix beam search for a model trained 9 | with the CTC loss function. 
10 | 11 | For more details checkout either of these references: 12 | https://distill.pub/2017/ctc/#inference 13 | https://arxiv.org/abs/1408.2873 14 | 15 | """ 16 | 17 | import numpy as np 18 | import math 19 | import collections 20 | 21 | NEG_INF = -float("inf") 22 | 23 | def make_new_beam(): 24 | fn = lambda : (NEG_INF, NEG_INF) 25 | return collections.defaultdict(fn) 26 | 27 | def logsumexp(*args): 28 | """ 29 | Stable log sum exp. 30 | """ 31 | if all(a == NEG_INF for a in args): 32 | return NEG_INF 33 | a_max = max(args) 34 | lsp = math.log(sum(math.exp(a - a_max) 35 | for a in args)) 36 | return a_max + lsp 37 | 38 | def decode(probs, beam_size=10, blank=0): 39 | """ 40 | Performs inference for the given output probabilities. 41 | 42 | Arguments: 43 | probs: The output probabilities (e.g. post-softmax) for each 44 | time step. Should be an array of shape (time x output dim). 45 | beam_size (int): Size of the beam to use during inference. 46 | blank (int): Index of the CTC blank label. 47 | 48 | Returns the output label sequence and the corresponding negative 49 | log-likelihood estimated by the decoder. 50 | """ 51 | T, S = probs.shape 52 | probs = np.log(probs) 53 | 54 | # Elements in the beam are (prefix, (p_blank, p_no_blank)) 55 | # Initialize the beam with the empty sequence, a probability of 56 | # 1 for ending in blank and zero for ending in non-blank 57 | # (in log space). 58 | beam = [(tuple(), (0.0, NEG_INF))] 59 | 60 | for t in range(T): # Loop over time 61 | 62 | # A default dictionary to store the next step candidates. 63 | next_beam = make_new_beam() 64 | 65 | for s in range(S): # Loop over vocab 66 | p = probs[t, s] 67 | 68 | # The variables p_b and p_nb are respectively the 69 | # probabilities for the prefix given that it ends in a 70 | # blank and does not end in a blank at this time step. 71 | for prefix, (p_b, p_nb) in beam: # Loop over beam 72 | 73 | # If we propose a blank the prefix doesn't change. 74 | # Only the probability of ending in blank gets updated. 75 | if s == blank: 76 | n_p_b, n_p_nb = next_beam[prefix] 77 | n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p) 78 | next_beam[prefix] = (n_p_b, n_p_nb) 79 | continue 80 | 81 | # Extend the prefix by the new character s and add it to 82 | # the beam. Only the probability of not ending in blank 83 | # gets updated. 84 | end_t = prefix[-1] if prefix else None 85 | n_prefix = prefix + (s,) 86 | n_p_b, n_p_nb = next_beam[n_prefix] 87 | if s != end_t: 88 | n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p) 89 | else: 90 | # We don't include the previous probability of not ending 91 | # in blank (p_nb) if s is repeated at the end. The CTC 92 | # algorithm merges characters not separated by a blank. 93 | n_p_nb = logsumexp(n_p_nb, p_b + p) 94 | 95 | # *NB* this would be a good place to include an LM score. 96 | next_beam[n_prefix] = (n_p_b, n_p_nb) 97 | 98 | # If s is repeated at the end we also update the unchanged 99 | # prefix. This is the merging case. 100 | if s == end_t: 101 | n_p_b, n_p_nb = next_beam[prefix] 102 | n_p_nb = logsumexp(n_p_nb, p_nb + p) 103 | next_beam[prefix] = (n_p_b, n_p_nb) 104 | 105 | # Sort and trim the beam before moving on to the 106 | # next time-step. 
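# The ranking key is the total prefix probability, i.e. the log-sum-exp of the blank and non-blank scores.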
107 | beam = sorted(next_beam.items(), 108 | key=lambda x : logsumexp(*x[1]), 109 | reverse=True) 110 | beam = beam[:beam_size] 111 | 112 | best = beam[0] 113 | return best[0], -logsumexp(*best[1]) 114 | 115 | if __name__ == "__main__": 116 | np.random.seed(3) 117 | 118 | time = 50 119 | output_dim = 20 120 | 121 | probs = np.random.rand(time, output_dim) 122 | probs = probs / np.sum(probs, axis=1, keepdims=True) 123 | 124 | labels, score = decode(probs) 125 | print(labels) 126 | print("Score {:.3f}".format(score)) 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import json 7 | import random 8 | import time 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim 12 | import tqdm 13 | 14 | import speech 15 | import speech.loader as loader 16 | import speech.models as models 17 | 18 | # TODO, (awni) why does putting this above crash.. 19 | import tensorboard_logger as tb 20 | 21 | def run_epoch(model, optimizer, train_ldr, it, avg_loss): 22 | 23 | model_t = 0.0; data_t = 0.0 24 | end_t = time.time() 25 | tq = tqdm.tqdm(train_ldr) 26 | for batch in tq: 27 | start_t = time.time() 28 | optimizer.zero_grad() 29 | loss = model.loss(batch) 30 | loss.backward() 31 | 32 | grad_norm = nn.utils.clip_grad_norm(model.parameters(), 200) 33 | loss = loss.data[0] 34 | 35 | optimizer.step() 36 | prev_end_t = end_t 37 | end_t = time.time() 38 | model_t += end_t - start_t 39 | data_t += start_t - prev_end_t 40 | 41 | exp_w = 0.99 42 | avg_loss = exp_w * avg_loss + (1 - exp_w) * loss 43 | tb.log_value('train_loss', loss, it) 44 | tq.set_postfix(iter=it, loss=loss, 45 | avg_loss=avg_loss, grad_norm=grad_norm, 46 | model_time=model_t, data_time=data_t) 47 | it += 1 48 | 49 | return it, avg_loss 50 | 51 | def eval_dev(model, ldr, preproc): 52 | losses = []; all_preds = []; all_labels = [] 53 | 54 | model.set_eval() 55 | 56 | for batch in tqdm.tqdm(ldr): 57 | preds = model.infer(batch) 58 | loss = model.loss(batch) 59 | losses.append(loss.data[0]) 60 | all_preds.extend(preds) 61 | all_labels.extend(batch[1]) 62 | 63 | model.set_train() 64 | 65 | loss = sum(losses) / len(losses) 66 | results = [(preproc.decode(l), preproc.decode(p)) 67 | for l, p in zip(all_labels, all_preds)] 68 | cer = speech.compute_cer(results) 69 | print("Dev: Loss {:.3f}, CER {:.3f}".format(loss, cer)) 70 | return loss, cer 71 | 72 | def run(config): 73 | 74 | opt_cfg = config["optimizer"] 75 | data_cfg = config["data"] 76 | model_cfg = config["model"] 77 | 78 | # Loaders 79 | batch_size = opt_cfg["batch_size"] 80 | preproc = loader.Preprocessor(data_cfg["train_set"], 81 | start_and_end=data_cfg["start_and_end"]) 82 | train_ldr = loader.make_loader(data_cfg["train_set"], 83 | preproc, batch_size) 84 | dev_ldr = loader.make_loader(data_cfg["dev_set"], 85 | preproc, batch_size) 86 | 87 | # Model 88 | model_class = eval("models." 
+ model_cfg["class"]) 89 | model = model_class(preproc.input_dim, 90 | preproc.vocab_size, 91 | model_cfg) 92 | model.cuda() if use_cuda else model.cpu() 93 | 94 | # Optimizer 95 | optimizer = torch.optim.SGD(model.parameters(), 96 | lr=opt_cfg["learning_rate"], 97 | momentum=opt_cfg["momentum"]) 98 | 99 | run_state = (0, 0) 100 | best_so_far = float("inf") 101 | for e in range(opt_cfg["epochs"]): 102 | start = time.time() 103 | 104 | run_state = run_epoch(model, optimizer, train_ldr, *run_state) 105 | 106 | msg = "Epoch {} completed in {:.2f} (s)." 107 | print(msg.format(e, time.time() - start)) 108 | 109 | dev_loss, dev_cer = eval_dev(model, dev_ldr, preproc) 110 | 111 | # Log for tensorboard 112 | tb.log_value("dev_loss", dev_loss, e) 113 | tb.log_value("dev_cer", dev_cer, e) 114 | 115 | speech.save(model, preproc, config["save_path"]) 116 | 117 | # Save the best model on the dev set 118 | if dev_cer < best_so_far: 119 | best_so_far = dev_cer 120 | speech.save(model, preproc, 121 | config["save_path"], tag="best") 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser( 125 | description="Train a speech model.") 126 | 127 | parser.add_argument("config", 128 | help="A json file with the training configuration.") 129 | parser.add_argument("--deterministic", default=False, 130 | action="store_true", 131 | help="Run in deterministic mode (no cudnn). Only works on GPU.") 132 | args = parser.parse_args() 133 | 134 | with open(args.config, 'r') as fid: 135 | config = json.load(fid) 136 | 137 | random.seed(config["seed"]) 138 | torch.manual_seed(config["seed"]) 139 | 140 | tb.configure(config["save_path"]) 141 | 142 | use_cuda = torch.cuda.is_available() 143 | 144 | if use_cuda and args.deterministic: 145 | torch.backends.cudnn.enabled = False 146 | run(config) 147 | -------------------------------------------------------------------------------- /speech/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | import random 8 | import scipy.signal 9 | import torch 10 | import torch.autograd as autograd 11 | import torch.utils.data as tud 12 | 13 | from speech.utils import wave 14 | 15 | class Preprocessor(): 16 | 17 | END = "" 18 | START = "" 19 | 20 | def __init__(self, data_json, max_samples=100, start_and_end=True): 21 | """ 22 | Builds a preprocessor from a dataset. 23 | Arguments: 24 | data_json (string): A file containing a json representation 25 | of each example per line. 26 | max_samples (int): The maximum number of examples to be used 27 | in computing summary statistics. 28 | start_and_end (bool): Include start and end tokens in labels. 29 | """ 30 | data = read_data_json(data_json) 31 | 32 | # Compute data mean, std from sample 33 | audio_files = [d['audio'] for d in data] 34 | random.shuffle(audio_files) 35 | self.mean, self.std = compute_mean_std(audio_files[:max_samples]) 36 | self._input_dim = self.mean.shape[0] 37 | 38 | # Make char map 39 | chars = list(set(t for d in data for t in d['text'])) 40 | if start_and_end: 41 | # START must be last so it can easily be 42 | # excluded in the output classes of a model. 
43 | chars.extend([self.END, self.START]) 44 | self.start_and_end = start_and_end 45 | self.int_to_char = dict(enumerate(chars)) 46 | self.char_to_int = {v : k for k, v in self.int_to_char.items()} 47 | 48 | def encode(self, text): 49 | text = list(text) 50 | if self.start_and_end: 51 | text = [self.START] + text + [self.END] 52 | return [self.char_to_int[t] for t in text] 53 | 54 | def decode(self, seq): 55 | text = [self.int_to_char[s] for s in seq] 56 | if not self.start_and_end: 57 | return text 58 | 59 | s = text[0] == self.START 60 | e = len(text) 61 | if text[-1] == self.END: 62 | e = text.index(self.END) 63 | return text[s:e] 64 | 65 | def preprocess(self, wave_file, text): 66 | inputs = log_specgram_from_file(wave_file) 67 | inputs = (inputs - self.mean) / self.std 68 | targets = self.encode(text) 69 | return inputs, targets 70 | 71 | @property 72 | def input_dim(self): 73 | return self._input_dim 74 | 75 | @property 76 | def vocab_size(self): 77 | return len(self.int_to_char) 78 | 79 | def compute_mean_std(audio_files): 80 | samples = [log_specgram_from_file(af) 81 | for af in audio_files] 82 | samples = np.vstack(samples) 83 | mean = np.mean(samples, axis=0) 84 | std = np.std(samples, axis=0) 85 | return mean, std 86 | 87 | class AudioDataset(tud.Dataset): 88 | 89 | def __init__(self, data_json, preproc, batch_size): 90 | 91 | data = read_data_json(data_json) 92 | self.preproc = preproc 93 | 94 | bucket_diff = 4 95 | max_len = max(len(x['text']) for x in data) 96 | num_buckets = max_len // bucket_diff 97 | buckets = [[] for _ in range(num_buckets)] 98 | for d in data: 99 | bid = min(len(d['text']) // bucket_diff, num_buckets - 1) 100 | buckets[bid].append(d) 101 | 102 | # Sort by input length followed by output length 103 | sort_fn = lambda x : (round(x['duration'], 1), 104 | len(x['text'])) 105 | for b in buckets: 106 | b.sort(key=sort_fn) 107 | data = [d for b in buckets for d in b] 108 | self.data = data 109 | 110 | def __len__(self): 111 | return len(self.data) 112 | 113 | def __getitem__(self, idx): 114 | datum = self.data[idx] 115 | datum = self.preproc.preprocess(datum["audio"], 116 | datum["text"]) 117 | return datum 118 | 119 | 120 | class BatchRandomSampler(tud.sampler.Sampler): 121 | """ 122 | Batches the data consecutively and randomly samples 123 | by batch without replacement. 
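Because AudioDataset orders examples by length, consecutive batches contain utterances of similar duration (less padding), while shuffling whole batches keeps the epoch order random.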
124 | """ 125 | 126 | def __init__(self, data_source, batch_size): 127 | it_end = len(data_source) - batch_size + 1 128 | self.batches = [range(i, i + batch_size) 129 | for i in range(0, it_end, batch_size)] 130 | self.data_source = data_source 131 | 132 | def __iter__(self): 133 | random.shuffle(self.batches) 134 | return (i for b in self.batches for i in b) 135 | 136 | def __len__(self): 137 | return len(self.data_source) 138 | 139 | def make_loader(dataset_json, preproc, 140 | batch_size, num_workers=4): 141 | dataset = AudioDataset(dataset_json, preproc, 142 | batch_size) 143 | sampler = BatchRandomSampler(dataset, batch_size) 144 | loader = tud.DataLoader(dataset, 145 | batch_size=batch_size, 146 | sampler=sampler, 147 | num_workers=num_workers, 148 | collate_fn=lambda batch : zip(*batch), 149 | drop_last=True) 150 | return loader 151 | 152 | def log_specgram_from_file(audio_file): 153 | audio, sr = wave.array_from_wave(audio_file) 154 | return log_specgram(audio, sr) 155 | 156 | def log_specgram(audio, sample_rate, window_size=20, 157 | step_size=10, eps=1e-10): 158 | nperseg = int(window_size * sample_rate / 1e3) 159 | noverlap = int(step_size * sample_rate / 1e3) 160 | _, _, spec = scipy.signal.spectrogram(audio, 161 | fs=sample_rate, 162 | window='hann', 163 | nperseg=nperseg, 164 | noverlap=noverlap, 165 | detrend=False) 166 | return np.log(spec.T.astype(np.float32) + eps) 167 | 168 | def read_data_json(data_json): 169 | with open(data_json) as fid: 170 | return [json.loads(l) for l in fid] 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /speech/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import numpy as np 7 | import random 8 | import torch 9 | import torch.nn as nn 10 | import torch.autograd as autograd 11 | 12 | from . 
import model 13 | 14 | class Seq2Seq(model.Model): 15 | 16 | def __init__(self, freq_dim, vocab_size, config): 17 | super().__init__(freq_dim, config) 18 | 19 | # For decoding 20 | decoder_cfg = config["decoder"] 21 | rnn_dim = self.encoder_dim 22 | embed_dim = decoder_cfg["embedding_dim"] 23 | self.embedding = nn.Embedding(vocab_size, embed_dim) 24 | self.dec_rnn = nn.GRUCell(input_size=embed_dim, 25 | hidden_size=rnn_dim) 26 | 27 | self.attend = NNAttention(rnn_dim, log_t=decoder_cfg.get("log_t", False)) 28 | 29 | self.sample_prob = decoder_cfg.get("sample_prob", 0) 30 | self.scheduled_sampling = (self.sample_prob != 0) 31 | 32 | # *NB* we predict vocab_size - 1 classes since we 33 | # never need to predict the start of sequence token. 34 | self.fc = model.LinearND(rnn_dim, vocab_size - 1) 35 | 36 | def set_eval(self): 37 | """ 38 | Set the model to evaluation mode. 39 | """ 40 | self.eval() 41 | self.volatile = True 42 | self.scheduled_sampling = False 43 | 44 | def set_train(self): 45 | """ 46 | Set the model to training mode. 47 | """ 48 | self.train() 49 | self.volatile = False 50 | self.scheduled_sampling = (self.sample_prob != 0) 51 | 52 | def loss(self, batch): 53 | x, y = self.collate(*batch) 54 | if self.is_cuda: 55 | x = x.cuda() 56 | y = y.cuda() 57 | out, alis = self.forward_impl(x, y) 58 | batch_size, _, out_dim = out.size() 59 | out = out.view((-1, out_dim)) 60 | y = y[:,1:].contiguous().view(-1) 61 | loss = nn.functional.cross_entropy(out, y, 62 | size_average=False) 63 | loss = loss / batch_size 64 | return loss 65 | 66 | def forward_impl(self, x, y): 67 | x = self.encode(x) 68 | out, alis = self.decode(x, y) 69 | return out, alis 70 | 71 | def forward(self, batch): 72 | x, y = self.collate(*batch) 73 | if self.is_cuda: 74 | x = x.cuda() 75 | y = y.cuda() 76 | return self.forward_impl(x, y)[0] 77 | 78 | def decode(self, x, y): 79 | """ 80 | x should be shape (batch, time, hidden dimension) 81 | y should be shape (batch, label sequence length) 82 | """ 83 | 84 | inputs = self.embedding(y[:, :-1]) 85 | 86 | out = []; aligns = [] 87 | 88 | hx = torch.zeros((x.shape[0], x.shape[2]), requires_grad=False) 89 | if self.is_cuda: 90 | hx = hx.cuda() 91 | ax = None; sx = None; 92 | for t in range(y.size()[1] - 1): 93 | sample = (out and self.scheduled_sampling) 94 | if sample and random.random() < self.sample_prob: 95 | ix = torch.max(out[-1], dim=2)[1] 96 | ix = self.embedding(ix) 97 | else: 98 | ix = inputs[:, t:t+1, :] 99 | 100 | if sx is not None: 101 | ix = ix + sx 102 | 103 | hx = self.dec_rnn(ix.squeeze(dim=1), hx) 104 | ox = hx.unsqueeze(dim=1) 105 | 106 | sx, ax = self.attend(x, ox, ax) 107 | aligns.append(ax) 108 | out.append(self.fc(ox + sx)) 109 | 110 | out = torch.cat(out, dim=1) 111 | aligns = torch.stack(aligns, dim=1) 112 | return out, aligns 113 | 114 | def decode_step(self, x, y, state=None, softmax=False): 115 | """ 116 | x should be shape (batch, time, hidden dimension) 117 | y should be shape (batch, label sequence length) 118 | """ 119 | if state is None: 120 | hx = torch.zeros((x.shape[0], x.shape[2]), requires_grad=False) 121 | if self.is_cuda: 122 | hx = hx.cuda() 123 | ax = None; sx = None; 124 | else: 125 | hx, ax, sx = state 126 | 127 | ix = self.embedding(y) 128 | if sx is not None: 129 | ix = ix + sx 130 | hx = self.dec_rnn(ix.squeeze(dim=1), hx=hx) 131 | ox = hx.unsqueeze(dim=1) 132 | sx, ax = self.attend(x, ox, ax=ax) 133 | out = ox + sx 134 | out = self.fc(out.squeeze(dim=1)) 135 | if softmax: 136 | out = nn.functional.log_softmax(out, dim=1) 137 | return
out, (hx, ax, sx) 138 | 139 | def predict(self, batch): 140 | probs = self(batch) 141 | argmaxs = torch.max(probs, dim=2)[1] 142 | argmaxs = argmaxs.cpu().data.numpy() 143 | return [seq.tolist() for seq in argmaxs] 144 | 145 | def infer_decode(self, x, y, end_tok, max_len): 146 | probs = [] 147 | argmaxs = [y] 148 | state = None 149 | for e in range(max_len): 150 | out, state = self.decode_step(x, y, state=state) 151 | probs.append(out) 152 | y = torch.max(out, dim=1)[1] 153 | y = y.unsqueeze(dim=1) 154 | argmaxs.append(y) 155 | if torch.sum(y.data == end_tok) == y.numel(): 156 | break 157 | 158 | probs = torch.cat(probs) 159 | argmaxs = torch.cat(argmaxs, dim=1) 160 | return probs, argmaxs 161 | 162 | def infer(self, batch, max_len=200): 163 | """ 164 | Infer a likely output. No beam search yet. 165 | """ 166 | x, y = self.collate(*batch) 167 | end_tok = y.data[0, -1] # TODO 168 | t = y 169 | if self.is_cuda: 170 | x = x.cuda() 171 | t = y.cuda() 172 | x = self.encode(x) 173 | 174 | # needs to be the start token, TODO 175 | y = t[:, 0:1] 176 | _, argmaxs = self.infer_decode(x, y, end_tok, max_len) 177 | argmaxs = argmaxs.cpu().data.numpy() 178 | return [seq.tolist() for seq in argmaxs] 179 | 180 | def beam_search(self, batch, beam_size=10, max_len=200): 181 | x, y = self.collate(*batch) 182 | start_tok = y.data[0, 0] 183 | end_tok = y.data[0, -1] # TODO 184 | if self.is_cuda: 185 | x = x.cuda() 186 | y = y.cuda() 187 | x = self.encode(x) 188 | 189 | y = y[:, 0:1].clone() 190 | 191 | beam = [((start_tok,), 0, None)]; 192 | complete = [] 193 | for _ in range(max_len): 194 | new_beam = [] 195 | for hyp, score, state in beam: 196 | 197 | y[0] = hyp[-1] 198 | out, state = self.decode_step(x, y, state=state, softmax=True) 199 | out = out.cpu().data.numpy().squeeze(axis=0).tolist() 200 | for i, p in enumerate(out): 201 | new_score = score + p 202 | new_hyp = hyp + (i,) 203 | new_beam.append((new_hyp, new_score, state)) 204 | new_beam = sorted(new_beam, key=lambda x: x[1], reverse=True) 205 | 206 | # Remove complete hypotheses 207 | for cand in new_beam[:beam_size]: 208 | if cand[0][-1] == end_tok: 209 | complete.append(cand) 210 | 211 | beam = filter(lambda x : x[0][-1] != end_tok, new_beam) 212 | beam = beam[:beam_size] 213 | 214 | if len(beam) == 0: 215 | break 216 | 217 | # Stopping criteria: 218 | # complete contains beam_size more probable 219 | # candidates than anything left in the beam 220 | if sum(c[1] > beam[0][1] for c in complete) >= beam_size: 221 | break 222 | 223 | complete = sorted(complete, key=lambda x: x[1], reverse=True) 224 | if len(complete) == 0: 225 | complete = beam 226 | hyp, score, _ = complete[0] 227 | return [hyp] 228 | 229 | def collate(self, inputs, labels): 230 | inputs = model.zero_pad_concat(inputs) 231 | labels = end_pad_concat(labels) 232 | inputs = torch.from_numpy(inputs) 233 | labels = torch.from_numpy(labels) 234 | if self.volatile: 235 | inputs.volatile = True 236 | labels.volatile = True 237 | return inputs, labels 238 | 239 | def end_pad_concat(labels): 240 | # Assumes last item in each example is the end token. 
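# Shorter sequences are right-padded with that end token so every row of the returned array has length max_len.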
241 | batch_size = len(labels) 242 | end_tok = labels[0][-1] 243 | max_len = max(len(l) for l in labels) 244 | cat_labels = np.full((batch_size, max_len), 245 | fill_value=end_tok, dtype=np.int64) 246 | for e, l in enumerate(labels): 247 | cat_labels[e, :len(l)] = l 248 | return cat_labels 249 | 250 | class Attention(nn.Module): 251 | 252 | def __init__(self, kernel_size=11, log_t=False): 253 | """ 254 | Module which performs a single attention step along the 255 | second axis of a given encoded input. The module uses 256 | both 'content' and 'location' based attention. 257 | 258 | The 'content' based attention is an inner product of the 259 | decoder hidden state with each time-step of the encoder 260 | state. 261 | 262 | The 'location' based attention performs a 1D convolution 263 | on the previous attention vector and adds this into the 264 | next attention vector prior to normalization. 265 | 266 | *NB* Should compute attention differently if using cuda or cpu 267 | based on performance. See 268 | https://gist.github.com/awni/9989dd31642d42405903dec8ab91d1f0 269 | """ 270 | super(Attention, self).__init__() 271 | assert kernel_size % 2 == 1, \ 272 | "Kernel size should be odd for 'same' conv." 273 | padding = (kernel_size - 1) // 2 274 | self.conv = nn.Conv1d(1, 1, kernel_size, padding=padding) 275 | self.log_t = log_t 276 | 277 | def forward(self, eh, dhx, ax=None): 278 | """ 279 | Arguments: 280 | eh (FloatTensor): the encoder hidden state with 281 | shape (batch size, time, hidden dimension). 282 | dhx (FloatTensor): one time step of the decoder hidden 283 | state with shape (batch size, hidden dimension). 284 | The hidden dimension must match that of the 285 | encoder state. 286 | ax (FloatTensor): one time step of the attention 287 | vector. 288 | 289 | Returns the summary of the encoded hidden state 290 | and the corresponding alignment. 291 | """ 292 | # Compute inner product of decoder slice with every 293 | # encoder slice. 294 | # ('content' based attention). 295 | pax = eh * dhx 296 | pax = torch.sum(pax, dim=2) 297 | 298 | 299 | if ax is not None: 300 | ax = ax.unsqueeze(dim=1) 301 | ax = self.conv(ax).squeeze(dim=1) 302 | pax = pax + ax 303 | 304 | if self.log_t: 305 | log_t = math.log(pax.size()[1]) 306 | pax = log_t * pax 307 | ax = nn.functional.softmax(pax, dim=1) 308 | 309 | # At this point ax should have size (batch size, time). 310 | # Reduce the encoder state across time weighting each 311 | # slice by its corresponding value in ax. 312 | sx = ax.unsqueeze(2) 313 | sx = torch.sum(eh * sx, dim=1, keepdim=True) 314 | return sx, ax 315 | 316 | class ProdAttention(nn.Module): 317 | 318 | def __init__(self): 319 | super(ProdAttention, self).__init__() 320 | 321 | def forward(self, eh, dhx, ax=None): 322 | pax = eh * dhx 323 | pax = torch.sum(pax, dim=2) 324 | 325 | ax = nn.functional.softmax(pax, dim=1) 326 | 327 | sx = ax.unsqueeze(2) 328 | sx = torch.sum(eh * sx, dim=1, keepdim=True) 329 | return sx, ax 330 | 331 | class NNAttention(nn.Module): 332 | 333 | def __init__(self, n_channels, kernel_size=15, log_t=False): 334 | super(NNAttention, self).__init__() 335 | assert kernel_size % 2 == 1, \ 336 | "Kernel size should be odd for 'same' conv."
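# With an odd kernel this 'same' padding keeps the convolved attention vector the same length as the encoder time axis.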
337 | padding = (kernel_size - 1) // 2 338 | self.conv = nn.Conv1d(1, n_channels, kernel_size, padding=padding) 339 | self.nn = nn.Sequential( 340 | nn.ReLU(), 341 | model.LinearND(n_channels, 1)) 342 | self.log_t = log_t 343 | 344 | def forward(self, eh, dhx, ax=None): 345 | pax = eh + dhx 346 | if ax is not None: 347 | ax = ax.unsqueeze(dim=1) 348 | ax = self.conv(ax).transpose(1, 2) 349 | pax = pax + ax 350 | 351 | pax = self.nn(pax) 352 | pax = pax.squeeze(dim=2) 353 | if self.log_t: 354 | log_t = math.log(pax.size()[1]) 355 | pax = log_t * pax 356 | ax = nn.functional.softmax(pax, dim=1) 357 | 358 | sx = ax.unsqueeze(2) 359 | sx = torch.sum(eh * sx, dim=1, keepdim=True) 360 | return sx, ax 361 | --------------------------------------------------------------------------------