├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── build_data.py ├── compare-annotation.R ├── config.py ├── data_utils.py ├── general_utils.py ├── images └── word-importance.png ├── main.py ├── model.py ├── model_utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sushant Kafle 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Generic Bi-LSTM Model for Word Importance Prediction in Spoken Dialogues: 2 | ============================================================== 3 | 4 | This project demonstrates the use of generic bi-directional LSTM models for predicting the importance of words in a spoken dialogue for understanding its meaning. The model is trained and evaluated on a human-annotated corpus of word importance. The corpus can be downloaded from: http://latlab.ist.rit.edu/lrec2018 5 | 6 | ![Word Importance Visualization in a Dialogue](https://github.com/SushantKafle/speechtext-wimp-labeler/blob/master/images/word-importance.png "Word Importance Visualization in a Dialogue") 7 | 8 | Performance Summary (to be updated soon): 9 | ------------------------------------------- 10 |
11 | 12 | | Model | Classes | f-score | rms | 13 | |:------------------------------: |:-------: |:-------: |------ | 14 | | bi-LSTM (char-embedding + CRF) | 3 | 0.73 | 0.598 | 15 | | bi-LSTM (char-embedding + CRF) | 6 | 0.60 | 0.154 | 16 | 17 |
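The "Classes" column above refers to binning the corpus's continuous word-importance scores (in [0, 1]) into 3 or 6 discrete labels. As a point of reference, the sketch below mirrors the `ann2class` mapping in `config.py`; the standalone function name is only illustrative:

```python
import math

def score_to_class(score, nclass=6):
    """Bin a continuous word-importance score in [0, 1] into discrete classes.

    Mirrors Config.ann2class in config.py; the standalone name used here is
    only for illustration.
    """
    score = float(score)
    if nclass == 6:
        # class 0 for near-zero scores, then classes 1-5 in 0.2-wide bands
        return 0 if score < 0.1 else int(math.ceil(score / 0.2))
    if nclass == 3:
        # low (< 0.3), medium (0.3-0.6), high (>= 0.6)
        return 0 if score < 0.3 else (1 if score < 0.6 else 2)
    raise ValueError("unsupported number of classes: %d" % nclass)

# Example: score_to_class(0.55, nclass=6) -> 3, score_to_class(0.55, nclass=3) -> 1
```

The f-score in the table is computed over these discrete classes, while rms is computed on the original [0, 1] scale after predicted classes are mapped back to scores (see `class2ann` in `config.py`).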
18 | 19 | You can cite this work and/or the corpus using: 20 | 21 | > Sushant Kafle and Matt Huenerfauth. 2018. A Corpus for Modeling Word Importance in Spoken Dialogue Transcripts. In Proceedings of the 11th edition of the Language Resources and Evaluation Conference (LREC). ACM. 22 | 23 | 24 | I. Data Preparation 25 | ==================== 26 | 27 | 1. Download and locate the text transcripts of the Switchboard Corpus and the corresponding word importance corpus: 28 | * The word importance corpus can be downloaded from this website: http://latlab.ist.rit.edu/lrec2018 29 | * The text transcripts from the Switchboard corpus can be downloaded via this link: https://www.isip.piconepress.com/projects/switchboard/releases/switchboard_word_alignments.tar.gz 30 | 31 | * In the “config.py” file, update the variables shown below: 32 | 33 | # location of the Word Importance Corpus "annotations folder" 34 | wimp_corpus = --HERE-- 35 | 36 | # location of the Switchboard transcripts 37 | swd_transcripts = --AND HERE-- 38 | 39 | 2. Download the GloVe vectors `glove.6B.300d.txt` from http://nlp.stanford.edu/data/glove.6B.zip and update `glove_filename` in `config.py`. 40 | 41 | 3. Run ‘build_data.py’ to prepare the data for training, development and testing: 42 | 43 | ```python build_data.py``` 44 | 45 | This will create all the necessary files (such as the word vocabulary, character vocabulary and the training, development and test files) in the “$PROJECT_HOME/data/” directory. 46 | 47 | II. Install Python Dependencies 48 | ====================== 49 | 50 | `pip install -r requirements.txt` 51 | 52 | 53 | III. Running the model 54 | ====================== 55 | 56 | 1. Open the ‘config.py’ file and review the configurations: 57 | 58 | * model : type of model to run (options: lstm_crf or lstm_sig) 59 | * wimp_corpus : Path to the Word Importance Corpus 60 | * swd_transcripts : Path to the Switchboard Transcripts 61 | * output_path : Path to the output directory 62 | * model_output : Path to save the best performing model 63 | * log_path : Path to store the log 64 | * confusion_mat : Path to save the image of the confusion matrix (part of the analysis on the test data) 65 | * compare_predictions : Path to save the predictions of the model (a .csv file is produced) 66 | 67 | * random_seed : Random seed 68 | * opt_metric : Metric used to evaluate progress at each epoch 69 | * nclass : Number of classes for prediction 70 | 71 | * dim : Size of the word embeddings used in the model 72 | * dim_char : Size of the character embeddings 73 | * glove_filename : Path to the GloVe embeddings file 74 | * trimmed_filename : Path to save the trimmed GloVe embeddings 75 | 76 | * dev_filename : Path to the development data, used to select the best epoch 77 | * test_filename : Path to the test data, used for evaluating the model performance 78 | * train_filename : Path to the training data 79 | 80 | * words_filename : Path to the word vocabulary 81 | * tags_filename : Path to the vocabulary of the tags 82 | * chars_filename : Path to the vocabulary of the characters 83 | 84 | * train_embeddings : If True, trains the word-level embeddings 85 | * nepochs : Maximum number of epochs to run 86 | * dropout : Dropout keep probability used during training (1.0 means no dropout) 87 | * batch_size : Number of examples in each batch 88 | * lr_method : Optimization strategy (options: adam, adagrad, sgd, rmsprop) 89 | * lr : Learning rate 90 | * lr_decay : Rate of decay of the learning rate 91 | * clip : Gradient clipping,
if negative no clipping 92 | * nepoch_no_imprv : Number of epoch without improvement for early termination 93 | * reload : Reload the latest trained model 94 | 95 | * word_rnn_size : Size of the word-level LSTM hidden layers 96 | * char_rnn_size : Size of the char-level LSTM hidden layers 97 | 98 | 2. Run the model by: 99 | 100 | ```python main.py``` 101 | 102 | Summary: 103 | * trained model saved at “model_output” (declared inside config.py). 104 | * log of the analysis at “log_path” (declared inside config.py) - contains train, dev and test performance. 105 | * confusion matrix at “confusion_mat” (declared inside config.py). 106 | * CSV file containing the actual scores and the predicted score at “compare_predictions” (declared inside config.py). 107 | 108 | 109 | IV. Running the agreement analysis 110 | ==================================== 111 | 112 | 1. Locate the csv file containing the actual scores annotated by the annotators and the predicted scores. 113 | 114 | 2. Open up “compare-annotation.R” file and update the “annotation_src” variable with this new location. 115 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SushantKafle/speechtext-wimp-labeler/32b71e72f86ab7f864e75e8517bb32f4400352d4/__init__.py -------------------------------------------------------------------------------- /build_data.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | 3 | from config import Config 4 | import xlrd, os, random, csv 5 | from data_utils import AnnotationDataset, get_vocabs, UNK, NUM, \ 6 | get_glove_vocab, write_vocab, load_vocab, get_char_vocab, \ 7 | export_trimmed_glove_vectors, get_processing_word 8 | from general_utils import clean_word 9 | 10 | def build_data(config): 11 | annotations = [] 12 | meta_filename = 'sw%s%s-ms98-a-trans.text' # % (file_id, speaker_id) 13 | 14 | for idx in os.listdir(config.wimp_corpus): 15 | idx_path = os.path.join(config.wimp_corpus, idx) 16 | if os.path.isfile(idx_path): 17 | continue 18 | 19 | for file_id in os.listdir(idx_path): 20 | folder = os.path.join(idx_path, file_id) 21 | if os.path.isfile(folder): 22 | continue 23 | 24 | wimp_trans_files = [os.path.join(folder, meta_filename % (file_id, 'A')), 25 | os.path.join(folder, meta_filename % (file_id, 'B'))] 26 | 27 | swd_trans_files = [os.path.join(config.swd_transcripts, idx, file_id, meta_filename % (file_id, 'A')), 28 | os.path.join(config.swd_transcripts, idx, file_id, meta_filename % (file_id, 'B'))] 29 | 30 | for i, wimp_trans_file in enumerate(wimp_trans_files): 31 | swd_trans_file = swd_trans_files[i] 32 | file_id, speaker = swd_trans_file.split("/")[-2:] 33 | speaker = speaker[6] 34 | with open(wimp_trans_file) as w_file_obj, open(swd_trans_file) as s_file_obj: 35 | for line_num, (anns_, wrds_) in enumerate(zip(w_file_obj, s_file_obj)): 36 | sentence = [] 37 | anns = anns_.strip().split(' ')[3:] 38 | wrds = wrds_.strip().split(' ')[3:] 39 | assert(len(anns) == len(wrds)), \ 40 | "file mismatch, line %d : %s and %s" % (line_num, swd_trans_file, wimp_trans_file) 41 | 42 | for id_, wrd in enumerate(wrds): 43 | wrd = clean_word(wrd) 44 | if wrd != '': 45 | sentence.append([(file_id, line_num, speaker), wrd, float(anns[id_])]) 46 | 47 | if len(sentence) != 0: 48 | annotations.append(sentence) 49 | 50 | random.shuffle(annotations) 51 | 52 | #80% for training, 10% dev, 10% 
test 53 | d_train = annotations[ : floor(0.8 * len(annotations))] 54 | d_test = annotations[floor(0.8 * len(annotations)) : floor(0.9 * len(annotations))] 55 | d_dev = annotations[floor(0.9 * len(annotations)): ] 56 | 57 | def prep_text_data(D, outfile): 58 | with open(outfile, 'w') as f: 59 | for sent in D: 60 | for _, word, label in sent: 61 | f.write("%s %f\n" % (word, label)) 62 | f.write("\n") 63 | 64 | prep_text_data(d_train, config.train_filename) 65 | prep_text_data(d_test, config.test_filename) 66 | prep_text_data(d_dev, config.dev_filename) 67 | 68 | processing_word = get_processing_word(lowercase=True) 69 | 70 | # Generators 71 | dev = AnnotationDataset(config.dev_filename, processing_word) 72 | test = AnnotationDataset(config.test_filename, processing_word) 73 | train = AnnotationDataset(config.train_filename, processing_word) 74 | 75 | # Build Word and Tag vocab 76 | # Vocabulary is built using training data 77 | vocab_words, vocab_tags = get_vocabs([train]) 78 | vocab_glove = get_glove_vocab(config.glove_filename) 79 | 80 | vocab = vocab_words & vocab_glove 81 | vocab.add(UNK) 82 | vocab.add(NUM) 83 | 84 | # Save vocab 85 | write_vocab(vocab, config.words_filename) 86 | write_vocab(vocab_tags, config.tags_filename) 87 | 88 | # Trim GloVe Vectors 89 | vocab = load_vocab(config.words_filename) 90 | export_trimmed_glove_vectors(vocab, config.glove_filename, 91 | config.trimmed_filename, config.dim) 92 | 93 | # Build and save char vocab 94 | train = AnnotationDataset(config.train_filename) 95 | vocab_chars = get_char_vocab(train) 96 | write_vocab(vocab_chars, config.chars_filename) 97 | 98 | 99 | if __name__ == "__main__": 100 | config = Config() 101 | build_data(config) -------------------------------------------------------------------------------- /compare-annotation.R: -------------------------------------------------------------------------------- 1 | annotation_src = '/compare-predictions.csv' 2 | annotations = read.csv(annotation_src) 3 | 4 | method1 <- annotations[['truth']] 5 | method2 <- annotations[['predictions']] 6 | 7 | library(epiR) 8 | 9 | tmp <- data.frame(method1, method2) 10 | tmp.ccc <- epi.ccc(method1, method2, ci = "z-transform", conf.level = 0.95) 11 | 12 | tmp.lab <- data.frame(lab = paste("CCC: ", 13 | round(tmp.ccc$rho.c[,1], digits = 2), " (95% CI ", 14 | round(tmp.ccc$rho.c[,2], digits = 2), " - ", 15 | round(tmp.ccc$rho.c[,3], digits = 2), ")", sep = "")) 16 | 17 | z <- lm(method2 ~ method1) 18 | alpha <- summary(z)$coefficients[1,1] 19 | beta <- summary(z)$coefficients[2,1] 20 | tmp.lm <- data.frame(alpha, beta) 21 | 22 | ## Concordance correlation plot: 23 | library(ggplot2) 24 | 25 | ggplot(tmp, aes(x = method1, y = method2)) + 26 | geom_point() + 27 | geom_abline(intercept = 0, slope = 1) + 28 | geom_abline(data = tmp.lm, aes(intercept = alpha, slope = beta), 29 | linetype = "dashed") + 30 | xlim(0, 1.1) + 31 | ylim(0, 1.1) + 32 | xlab("Annotator 1") + 33 | ylab("Annotator 2") + 34 | geom_text(data = tmp.lab, x = 0.5, y = 2.95, label = tmp.lab$lab) + 35 | coord_fixed(ratio = 1 / 1) 36 | 37 | ## In this plot the dashed line represents the line of perfect concordance. 38 | ## The solid line represents the reduced major axis. 
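## The Bland-Altman plot below gives another view of agreement between the
## two sets of scores (annotator "truth" vs. model predictions): each point is
## a word, placed at the mean of the two scores (x) against their difference
## (y); the solid line marks the mean difference and the dashed lines the 95%
## limits of agreement.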
39 | 40 | 41 | ## Bland and Altman plot (Figure 2 from Bland and Altman 1986): 42 | tmp.ccc <- epi.ccc(method1, method2, ci = "z-transform", conf.level = 0.95, 43 | rep.measure = FALSE) 44 | tmp <- data.frame(mean = tmp.ccc$blalt[,1], delta = tmp.ccc$blalt[,2]) 45 | 46 | 47 | library(ggplot2) 48 | 49 | ggplot(tmp.ccc$blalt, aes(x = mean, y = delta)) + 50 | geom_point() + 51 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = lower), linetype = 2) + 52 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = upper), linetype = 2) + 53 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = est), linetype = 1) + 54 | xlab("Average PEFR by two meters (L/min)") + 55 | ylab("Difference in PEFR (L/min)") + 56 | xlim(0, 1) + 57 | ylim(-1,1) 58 | 59 | 60 | ## Interclass Correlation 61 | library(psych) 62 | 63 | sf <- data.matrix(annotations) 64 | ICC(sf, missing=FALSE, alpha = 0.05) 65 | 66 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | from general_utils import get_logger 3 | 4 | class Config(): 5 | def __init__(self): 6 | if not os.path.exists(self.output_path): 7 | os.makedirs(self.output_path) 8 | self.logger = get_logger(self.log_path) 9 | 10 | # location of the Word Improtance Corpus 11 | wimp_corpus = "--UPDATE--" 12 | 13 | # location of the Switchboard transcripts 14 | swd_transcripts = "--UPDATE--" 15 | 16 | # type of model 17 | model = "lstm_crf" 18 | opt_metric = "f-score" 19 | nclass = 6 20 | 21 | random_seed = 1000 22 | 23 | # general config 24 | output_path = "results/exp-1/" 25 | model_output = output_path + "model.weights/" 26 | log_path = output_path + "log.txt" 27 | confusion_mat = output_path + "confusion-mat.png" 28 | compare_predictions = output_path + "compare-predictions.csv" 29 | 30 | # embeddings 31 | dim = 300 32 | dim_char = 100 33 | glove_filename = "data/glove.6B/glove.6B.300d.txt" 34 | trimmed_filename = "data/glove.6B.300d.trimmed.npz" 35 | 36 | # dataset 37 | dev_filename = "data/testa.txt" 38 | test_filename = "data/testb.txt" 39 | train_filename = "data/train.txt" 40 | 41 | # vocab 42 | words_filename = "data/words.txt" 43 | tags_filename = "data/tags.txt" 44 | chars_filename = "data/chars.txt" 45 | 46 | # training 47 | train_embeddings = False 48 | nepochs = 20 49 | dropout = 0.5 50 | batch_size = 20 51 | lr_method = "adam" 52 | lr = 0.001 53 | lr_decay = 0.9 54 | nepoch_no_imprv = 7 55 | reload = False 56 | 57 | # model hyperparameters 58 | word_rnn_size = 300 59 | char_rnn_size = 100 60 | 61 | 62 | # some utility functions 63 | def ann2class(self, tag): 64 | tag = float(tag) 65 | if self.nclass == 6: 66 | if tag < 0.1: 67 | return 0 68 | return int(math.ceil(tag/0.2)) 69 | elif self.nclass == 3: 70 | if tag < 0.3: 71 | return 0 72 | elif tag < 0.6: 73 | return 1 74 | return 2 75 | elif self.nclass == 2: 76 | if tag < 0.5: 77 | return 0 78 | return 1 79 | 80 | def class2ann(self, tag): 81 | tag = float(tag) 82 | if self.nclass == 6: 83 | return tag/5. 
84 | elif self.nclass == 3: 85 | return ((tag + 1) * 0.3 - 0.1) 86 | elif self.nclass == 2: 87 | return 0.25 if tag == 0 else 0.75 88 | 89 | def digitize_labels(self, tags): 90 | return list(map(self.ann2class, tags)) 91 | 92 | 93 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """Most utility functions here has been adopted from: 3 | https://github.com/guillaumegenthial/sequence_tagging/blob/master/model/data_utils.py 4 | """ 5 | 6 | import numpy as np 7 | import os 8 | import math 9 | 10 | # shared global variables to be imported from model also 11 | UNK = "$UNK$" 12 | NUM = "$NUM$" 13 | NONE = "O" 14 | 15 | # read the word importance scores 16 | class AnnotationDataset(object): 17 | def __init__(self, filename, processing_word=None): 18 | self.filename = filename 19 | self.processing_word = processing_word 20 | self.length = None 21 | 22 | def __iter__(self): 23 | with open(self.filename) as f: 24 | words, tags = [], [] 25 | for line in f: 26 | line = line.strip() 27 | if (len(line) == 0): 28 | if len(words) != 0: 29 | yield words, tags 30 | words, tags = [], [] 31 | else: 32 | ls = line.split(' ') 33 | word, tag = ls[0], ls[-1] 34 | if self.processing_word is not None: 35 | word = self.processing_word(word) 36 | words += [word] 37 | tags += [tag] 38 | 39 | 40 | def __len__(self): 41 | if self.length is None: 42 | self.length = 0 43 | for _ in self: 44 | self.length += 1 45 | return self.length 46 | 47 | def get_vocabs(datasets): 48 | print("Building vocab...") 49 | vocab_words = set() 50 | vocab_tags = set() 51 | for dataset in datasets: 52 | for words, tags in dataset: 53 | vocab_words.update(words) 54 | vocab_tags.update(tags) 55 | print("- done. {} tokens".format(len(vocab_words))) 56 | return vocab_words, vocab_tags 57 | 58 | 59 | def get_char_vocab(dataset): 60 | vocab_char = set() 61 | for words, _ in dataset: 62 | for word in words: 63 | vocab_char.update(word) 64 | 65 | return vocab_char 66 | 67 | 68 | def get_glove_vocab(filename): 69 | print("Building vocab...") 70 | vocab = set() 71 | with open(filename) as f: 72 | for line in f: 73 | word = line.strip().split(' ')[0] 74 | vocab.add(word) 75 | print("- done. {} tokens".format(len(vocab))) 76 | return vocab 77 | 78 | def get_google_vocab(filename): 79 | from gensim.models import Word2Vec 80 | model = Word2Vec.load_word2vec_format(filename, binary=True) 81 | 82 | print ("Building vocab...") 83 | vocab = set(model.vocab.keys()) 84 | 85 | print ("- done. {} tokens".format(len(vocab))) 86 | return model, vocab 87 | 88 | 89 | def get_senna_vocab(filename): 90 | print ("Building vocab...") 91 | vocab = set() 92 | with open(filename) as f: 93 | for line in f: 94 | word = line.strip() 95 | vocab.add(word) 96 | print ("- done. {} tokens".format(len(vocab))) 97 | return vocab 98 | 99 | 100 | def write_vocab(vocab, filename): 101 | print("Writing vocab...") 102 | with open(filename, "w") as f: 103 | for i, word in enumerate(vocab): 104 | if i != len(vocab) - 1: 105 | f.write("{}\n".format(word)) 106 | else: 107 | f.write(word) 108 | print("- done. 
{} tokens".format(len(vocab))) 109 | 110 | 111 | def load_vocab(filename): 112 | try: 113 | d = dict() 114 | with open(filename) as f: 115 | for idx, word in enumerate(f): 116 | word = word.strip() 117 | d[word] = idx 118 | 119 | except IOError: 120 | raise MyIOError(filename) 121 | return d 122 | 123 | 124 | def export_trimmed_glove_vectors(vocab, glove_filename, trimmed_filename, dim): 125 | embeddings = np.zeros([len(vocab), dim]) 126 | with open(glove_filename) as f: 127 | for line in f: 128 | line = line.strip().split(' ') 129 | word = line[0] 130 | embedding = [float(x) for x in line[1:]] 131 | if word in vocab: 132 | word_idx = vocab[word] 133 | embeddings[word_idx] = np.asarray(embedding) 134 | 135 | np.savez_compressed(trimmed_filename, embeddings=embeddings) 136 | 137 | 138 | def export_trimmed_google_vectors(vocab, google_model, trimmed_filename, dim, random): 139 | embeddings = np.asarray(random.normal(loc=0.0, scale=0.1, size= [len(vocab), dim]), dtype=np.float32) 140 | for word in google_model.vocab.keys(): 141 | if word in vocab: 142 | word_idx = vocab[word] 143 | embedding = google_model[word] 144 | embeddings[word_idx] = np.asarray(embedding) 145 | 146 | np.savez_compressed(trimmed_filename, embeddings=embeddings) 147 | 148 | 149 | def export_trimmed_senna_vectors(vocab, vocab_emb, senna_filename, trimmed_filename, dim): 150 | embeddings = np.zeros([len(vocab), dim]) 151 | vocab_emb = list(vocab_emb) 152 | with open(senna_filename) as f: 153 | for i, line in enumerate(f): 154 | line = line.strip().split(' ') 155 | word = vocab_emb[i] 156 | embedding = map(float, line) 157 | if word in vocab: 158 | word_idx = vocab[word] 159 | embeddings[word_idx] = np.asarray(embedding) 160 | 161 | np.savez_compressed(trimmed_filename, embeddings=embeddings) 162 | 163 | 164 | def get_trimmed_glove_vectors(filename): 165 | try: 166 | with np.load(filename) as data: 167 | return data["embeddings"] 168 | 169 | except IOError: 170 | raise MyIOError(filename) 171 | 172 | 173 | def get_trimmed_vectors(filename): 174 | return get_trimmed_glove_vectors(filename) 175 | 176 | 177 | def get_processing_word(vocab_words=None, vocab_chars=None, 178 | lowercase=False, chars=False): 179 | def f(word): 180 | # 0. get chars of words 181 | if vocab_chars is not None and chars == True: 182 | char_ids = [] 183 | for char in word: 184 | # ignore chars out of vocabulary 185 | if char in vocab_chars: 186 | char_ids += [vocab_chars[char]] 187 | 188 | # 1. preprocess word 189 | if lowercase: 190 | word = word.lower() 191 | if word.isdigit(): 192 | word = NUM 193 | 194 | # 2. get id of word 195 | if vocab_words is not None: 196 | if word in vocab_words: 197 | word = vocab_words[word] 198 | else: 199 | word = vocab_words[UNK] 200 | 201 | # 3. 
return tuple char ids, word id 202 | if vocab_chars is not None and chars == True: 203 | return char_ids, word 204 | else: 205 | return word 206 | 207 | return f 208 | 209 | 210 | def _pad_sequences(sequences, pad_tok, max_length): 211 | sequence_padded, sequence_length = [], [] 212 | 213 | for seq in sequences: 214 | seq = list(seq) 215 | seq_ = seq[:max_length] + [pad_tok]*max(max_length - len(seq), 0) 216 | sequence_padded += [seq_] 217 | sequence_length += [min(len(seq), max_length)] 218 | 219 | return sequence_padded, sequence_length 220 | 221 | 222 | def pad_sequences(sequences, pad_tok, nlevels=1): 223 | if nlevels == 1: 224 | max_length = max(map(lambda x : len(x), sequences)) 225 | sequence_padded, sequence_length = _pad_sequences(sequences, 226 | pad_tok, max_length) 227 | 228 | elif nlevels == 2: 229 | max_length_word = max([max(map(lambda x: len(x), seq)) for seq in sequences]) 230 | sequence_padded, sequence_length = [], [] 231 | for seq in sequences: 232 | # all words are same length now 233 | sp, sl = _pad_sequences(seq, pad_tok, max_length_word) 234 | sequence_padded += [sp] 235 | sequence_length += [sl] 236 | 237 | max_length_sentence = max(map(lambda x : len(x), sequences)) 238 | sequence_padded, _ = _pad_sequences(sequence_padded, [pad_tok]*max_length_word, 239 | max_length_sentence) 240 | sequence_length, _ = _pad_sequences(sequence_length, 0, max_length_sentence) 241 | 242 | 243 | return sequence_padded, sequence_length 244 | 245 | 246 | def minibatches(data, minibatch_size): 247 | x_batch, y_batch = [], [] 248 | for (x, y) in data: 249 | if len(x_batch) == minibatch_size: 250 | yield x_batch, y_batch 251 | x_batch, y_batch = [], [] 252 | 253 | if type(x[0]) == tuple: 254 | x = zip(*x) 255 | x_batch += [x] 256 | y_batch += [y] 257 | 258 | if len(x_batch) != 0: 259 | yield x_batch, y_batch 260 | -------------------------------------------------------------------------------- /general_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """Most utility functions here has been adopted from: 3 | https://github.com/guillaumegenthial/sequence_tagging/blob/master/model/general_utils.py 4 | """ 5 | 6 | import time 7 | import sys 8 | import logging 9 | import numpy as np 10 | import re 11 | import itertools 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | 15 | def get_logger(filename): 16 | logger = logging.getLogger('logger') 17 | logger.setLevel(logging.DEBUG) 18 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 19 | handler = logging.FileHandler(filename) 20 | handler.setLevel(logging.DEBUG) 21 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 22 | logging.getLogger().addHandler(handler) 23 | return logger 24 | 25 | 26 | def plot_confusion_matrix(config, cm, classes, 27 | normalize=False, 28 | title='Confusion matrix', 29 | cmap=plt.cm.Blues): 30 | if normalize: 31 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 32 | print("Normalized confusion matrix") 33 | else: 34 | print('Confusion matrix, without normalization') 35 | 36 | print(cm) 37 | 38 | fig = plt.figure() 39 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 40 | plt.title(title) 41 | plt.colorbar() 42 | tick_marks = np.arange(len(classes)) 43 | plt.xticks(tick_marks, classes, rotation=45) 44 | plt.yticks(tick_marks, classes) 45 | 46 | fmt = '.2f' if normalize else 'd' 47 | thresh = cm.max() / 2. 
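    # Write each cell's value (count, or rate if normalized) into the plot,
    # using white text on cells darker than half the colormap maximum so the
    # labels stay readable.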
48 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 49 | plt.text(j, i, format(cm[i, j], fmt), 50 | horizontalalignment="center", 51 | color="white" if cm[i, j] > thresh else "black") 52 | 53 | plt.tight_layout() 54 | plt.ylabel('True label') 55 | plt.xlabel('Predicted label') 56 | fig.savefig(config.confusion_mat, dpi=fig.dpi) 57 | 58 | 59 | class Progbar(object): 60 | """Progbar class copied from keras (https://github.com/fchollet/keras/)""" 61 | 62 | def __init__(self, target, width=30, verbose=1): 63 | self.width = width 64 | self.target = target 65 | self.sum_values = {} 66 | self.unique_values = [] 67 | self.start = time.time() 68 | self.total_width = 0 69 | self.seen_so_far = 0 70 | self.verbose = verbose 71 | 72 | def update(self, current, values=[], exact=[], strict=[]): 73 | for k, v in values: 74 | if k not in self.sum_values: 75 | self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far] 76 | self.unique_values.append(k) 77 | else: 78 | self.sum_values[k][0] += v * (current - self.seen_so_far) 79 | self.sum_values[k][1] += (current - self.seen_so_far) 80 | for k, v in exact: 81 | if k not in self.sum_values: 82 | self.unique_values.append(k) 83 | self.sum_values[k] = [v, 1] 84 | 85 | for k, v in strict: 86 | if k not in self.sum_values: 87 | self.unique_values.append(k) 88 | self.sum_values[k] = v 89 | 90 | self.seen_so_far = current 91 | 92 | now = time.time() 93 | if self.verbose == 1: 94 | prev_total_width = self.total_width 95 | sys.stdout.write("\b" * prev_total_width) 96 | sys.stdout.write("\r") 97 | 98 | numdigits = int(np.floor(np.log10(self.target))) + 1 99 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 100 | bar = barstr % (current, self.target) 101 | prog = float(current)/self.target 102 | prog_width = int(self.width*prog) 103 | if prog_width > 0: 104 | bar += ('='*(prog_width-1)) 105 | if current < self.target: 106 | bar += '>' 107 | else: 108 | bar += '=' 109 | bar += ('.'*(self.width-prog_width)) 110 | bar += ']' 111 | sys.stdout.write(bar) 112 | self.total_width = len(bar) 113 | 114 | if current: 115 | time_per_unit = (now - self.start) / current 116 | else: 117 | time_per_unit = 0 118 | eta = time_per_unit*(self.target - current) 119 | info = '' 120 | if current < self.target: 121 | info += ' - ETA: %ds' % eta 122 | else: 123 | info += ' - %ds' % (now - self.start) 124 | for k in self.unique_values: 125 | if type(self.sum_values[k]) is list: 126 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 127 | else: 128 | info += ' - %s: %s' % (k, self.sum_values[k]) 129 | 130 | self.total_width += len(info) 131 | if prev_total_width > self.total_width: 132 | info += ((prev_total_width-self.total_width) * " ") 133 | 134 | sys.stdout.write(info) 135 | sys.stdout.flush() 136 | 137 | if current >= self.target: 138 | sys.stdout.write("\n") 139 | 140 | if self.verbose == 2: 141 | if current >= self.target: 142 | info = '%ds' % (now - self.start) 143 | for k in self.unique_values: 144 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 145 | sys.stdout.write(info + "\n") 146 | 147 | def add(self, n, values=[]): 148 | self.update(self.seen_so_far+n, values) 149 | 150 | def clean_word(word): 151 | word = word.strip() 152 | if word != "": 153 | if word[0] in ["<"]: 154 | return "" 155 | 156 | #remove [vocablized-noise] tags 157 | word = re.sub(r'\[vocalized-(.*?)\]', '', word) 158 | 159 | #[laughter-yeah], [laughter-i], etc. 
160 | word = re.sub(r'\[[^]]*?-(.*?)\]', '\\1', word) 161 | 162 | #[atteck/attend], [regwet/regret], etc. 163 | word = re.sub(r'\[[^]]*?\/(.*?)\]', '\\1', word) 164 | 165 | #-[ok]ay, etc. 166 | word = re.sub(r'-\[([^\]\s]+)\]([^\[\s]+)', '\\1\\2', word) 167 | 168 | #th[ey]- , sim[ilar]- , etc. 169 | word = re.sub(r'([^\[\s]+)\[([^\]\s]+)\]-', '\\1\\2', word) 170 | 171 | #[silence], [noise], etc. 172 | word = re.sub(r'\[.*?\]', '', word) 173 | 174 | #, , etc. 175 | word = re.sub(r'<.*?>', '', word) 176 | 177 | word = re.sub(r'(\w)-([ \n])', '\\1\\2', word) 178 | 179 | #remove symbols like { and } 180 | word = re.sub(r'[{}]', '', word) 181 | 182 | return word 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /images/word-importance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SushantKafle/speechtext-wimp-labeler/32b71e72f86ab7f864e75e8517bb32f4400352d4/images/word-importance.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from data_utils import get_trimmed_glove_vectors, load_vocab, \ 2 | get_processing_word, AnnotationDataset 3 | from config import Config 4 | from model import WImpModel 5 | 6 | def main(config): 7 | # load vocabs 8 | vocab_words = load_vocab(config.words_filename) 9 | vocab_chars = load_vocab(config.chars_filename) 10 | 11 | # get processing functions 12 | processing_word = get_processing_word(vocab_words, vocab_chars, 13 | lowercase=True, chars=True) 14 | 15 | # get pre trained embeddings 16 | embeddings = get_trimmed_glove_vectors(config.trimmed_filename) 17 | 18 | # create dataset 19 | dev = AnnotationDataset(config.dev_filename, processing_word) 20 | test = AnnotationDataset(config.test_filename, processing_word) 21 | train = AnnotationDataset(config.train_filename, processing_word) 22 | 23 | print ("Num. train: %d" % len(train)) 24 | print ("Num. test: %d" % len(test)) 25 | print ("Num. 
dev: %d" % len(dev)) 26 | 27 | model = WImpModel(config, embeddings, ntags=config.nclass, 28 | nchars=len(vocab_chars)) 29 | 30 | # build WImpModel 31 | model.build_graph() 32 | 33 | # train, evaluate and interact 34 | model.train(train, dev) 35 | model.evaluate(test) 36 | 37 | if __name__ == "__main__": 38 | # create instance of config 39 | config = Config() 40 | main(config) 41 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, csv, math 3 | import tensorflow as tf 4 | from data_utils import minibatches, pad_sequences 5 | from general_utils import Progbar, plot_confusion_matrix 6 | from model_utils import create_feedforward, get_rnn_cell 7 | from sklearn.metrics import precision_recall_fscore_support as score 8 | from sklearn.metrics import confusion_matrix 9 | 10 | 11 | class WImpModel(object): 12 | 13 | def __init__(self, config, embeddings, ntags, nchars = None): 14 | self.config = config 15 | self.embeddings = embeddings 16 | self.nchars = nchars 17 | self.ntags = ntags 18 | self.logger = config.logger 19 | self.rng = np.random.RandomState(self.config.random_seed) 20 | 21 | 22 | def _init_graph_(self): 23 | self.word_ids = tf.placeholder(tf.int32, shape=[None, None], 24 | name="word_ids") 25 | self.sequence_lengths = tf.placeholder(tf.int32, shape=[None], 26 | name="sequence_lengths") 27 | 28 | self.char_ids = tf.placeholder(tf.int32, shape=[None, None, None], 29 | name="char_ids") 30 | self.word_lengths = tf.placeholder(tf.int32, shape=[None, None], 31 | name="word_lengths") 32 | 33 | if self.config.model == "lstm_crf": 34 | self.imp_labels = tf.placeholder(tf.int32, shape=[None, None], 35 | name="imp_labels") 36 | else: 37 | self.imp_labels = tf.placeholder(tf.float32, shape=[None, None], 38 | name="imp_labels") 39 | 40 | self.dropout = tf.placeholder(dtype=tf.float32, shape=[], 41 | name="dropout") 42 | self.lr = tf.placeholder(dtype=tf.float32, shape=[], 43 | name="lr") 44 | 45 | 46 | def create_initializer(self, size): 47 | return tf.constant(np.asarray(self.rng.normal(loc = 0.0, scale = 0.1, size = size), dtype = np.float32)) 48 | 49 | 50 | def get_feed_dict(self, words, imp_labels = None, lr = None, dropout = None): 51 | char_ids, word_ids = zip(*words) 52 | word_ids, sequence_lengths = pad_sequences(word_ids, 0) 53 | char_ids, word_lengths = pad_sequences(char_ids, pad_tok=0, nlevels=2) 54 | 55 | feed = { 56 | self.word_ids: word_ids, 57 | self.sequence_lengths: sequence_lengths 58 | } 59 | 60 | feed[self.char_ids] = char_ids 61 | feed[self.word_lengths] = word_lengths 62 | 63 | if imp_labels is not None: 64 | imp_labels, _ = pad_sequences(imp_labels, 0) 65 | feed[self.imp_labels] = imp_labels 66 | 67 | if lr is not None: 68 | feed[self.lr] = lr 69 | 70 | if dropout is not None: 71 | feed[self.dropout] = dropout 72 | 73 | return feed, sequence_lengths 74 | 75 | def _define_embedings_(self): 76 | with tf.variable_scope("words"): 77 | _word_embeddings = tf.Variable(self.embeddings, name = "_word_embeddings", dtype = tf.float32, 78 | trainable = self.config.train_embeddings) 79 | word_embeddings = tf.nn.embedding_lookup(_word_embeddings, self.word_ids, 80 | name = "word_embeddings") 81 | 82 | with tf.variable_scope("chars"): 83 | _char_embeddings = tf.get_variable(name = "_char_embeddings", dtype = tf.float32, 84 | shape = [self.nchars, self.config.dim_char]) 85 | char_embeddings = 
tf.nn.embedding_lookup(_char_embeddings, self.char_ids, 86 | name = "char_embeddings") 87 | 88 | s = tf.shape(char_embeddings) 89 | char_embeddings = tf.reshape(char_embeddings, shape = [-1, s[-2], self.config.dim_char]) 90 | word_lengths = tf.reshape(self.word_lengths, shape = [-1]) 91 | 92 | cell_fw = get_rnn_cell(self.config.char_rnn_size, "LSTM", state_is_tuple = True) 93 | cell_bw = get_rnn_cell(self.config.char_rnn_size, "LSTM", state_is_tuple = True) 94 | 95 | _, ((_, output_fw), (_, output_bw)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, 96 | cell_bw, char_embeddings, sequence_length = word_lengths, 97 | dtype = tf.float32) 98 | 99 | output = tf.concat([output_fw, output_bw], axis = -1) 100 | output = tf.reshape(output, shape= [-1, s[1], 2 * self.config.char_rnn_size]) 101 | 102 | word_embeddings = tf.concat([word_embeddings, output], axis=-1) 103 | self.word_embeddings = tf.nn.dropout(word_embeddings, self.dropout) 104 | 105 | 106 | def _define_logits_(self): 107 | with tf.variable_scope("bi-lstm"): 108 | cell_fw = get_rnn_cell(self.config.word_rnn_size, "LSTM") 109 | cell_bw = get_rnn_cell(self.config.word_rnn_size, "LSTM") 110 | (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, 111 | cell_bw, self.word_embeddings, sequence_length=self.sequence_lengths, 112 | dtype=tf.float32) 113 | output = tf.concat([output_fw, output_bw], axis=-1) 114 | 115 | ntime_steps = tf.shape(output)[1] 116 | output = tf.reshape(output, [-1, 2 * self.config.word_rnn_size]) 117 | pred_dim = 1 if self.config.model == "lstm_sig" else self.ntags 118 | pred = create_feedforward(output, 2 * self.config.word_rnn_size, pred_dim, self.create_initializer, 119 | "linear", "projection") 120 | 121 | if self.config.model == "lstm_sig": 122 | pred = tf.sigmoid(pred) 123 | self.logits = tf.reshape(pred, [-1, ntime_steps]) 124 | else: 125 | self.logits = tf.reshape(pred, [-1, ntime_steps, self.ntags]) 126 | 127 | def _define_predictions_(self): 128 | self.imp_pred = self.logits 129 | 130 | 131 | def _define_loss_(self): 132 | if self.config.model == "lstm_sig": 133 | Y = tf.reshape(self.imp_labels, [-1, 1]) 134 | pred_Y = tf.reshape(self.logits, [-1, 1]) 135 | self.loss = tf.sqrt(tf.reduce_sum(tf.pow(pred_Y - Y, 2))) 136 | else: 137 | log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood( 138 | self.logits, self.imp_labels, self.sequence_lengths) 139 | self.loss = tf.reduce_mean(-log_likelihood) 140 | 141 | def _setup_optimizer_(self): 142 | with tf.variable_scope("optimizer_setup"): 143 | if self.config.lr_method == 'adam': 144 | optimizer = tf.train.AdamOptimizer(self.lr) 145 | elif self.config.lr_method == 'adagrad': 146 | optimizer = tf.train.AdagradOptimizer(self.lr) 147 | elif self.config.lr_method == 'sgd': 148 | optimizer = tf.train.GradientDescentOptimizer(self.lr) 149 | else: 150 | optimizer = tf.train.RMSPropOptimizer(self.lr) 151 | 152 | self.optimize_ = optimizer.minimize(self.loss) 153 | 154 | 155 | def build_graph(self): 156 | self._init_graph_() 157 | self._define_embedings_() 158 | self._define_logits_() 159 | self._define_predictions_() 160 | self._define_loss_() 161 | self._setup_optimizer_() 162 | 163 | self.init = tf.global_variables_initializer() 164 | 165 | 166 | def predict_batch(self, sess, words): 167 | fd, sequence_lengths = self.get_feed_dict(words, dropout=1.0) 168 | if self.config.model == "lstm_sig": 169 | imp_pred = sess.run(self.imp_pred, feed_dict=fd) 170 | else: 171 | imp_pred = [] 172 | logits, transition_params = sess.run([self.logits, 
self.transition_params], 173 | feed_dict=fd) 174 | # iterate over the sentences 175 | for logit, sequence_length in zip(logits, sequence_lengths): 176 | # keep only the valid time steps 177 | logit = logit[:sequence_length] 178 | viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode( 179 | logit, transition_params) 180 | imp_pred += [viterbi_sequence] 181 | return imp_pred, sequence_lengths 182 | 183 | 184 | def run_epoch(self, sess, train, dev, epoch): 185 | nbatches = (len(train) + self.config.batch_size - 1) // self.config.batch_size 186 | prog = Progbar(target=nbatches) 187 | for i, (words, imp_labels) in enumerate(minibatches(train, self.config.batch_size)): 188 | 189 | if self.config.model == "lstm_crf": 190 | imp_labels = list(map(self.config.digitize_labels, imp_labels)) 191 | 192 | fd, _ = self.get_feed_dict(words, imp_labels, self.config.lr, self.config.dropout) 193 | _, train_loss = sess.run([self.optimize_, self.loss], feed_dict=fd) 194 | prog.update(i + 1, [("train loss", train_loss)]) 195 | 196 | result = self.run_evaluate(sess, dev) 197 | self.logger.info("- dev acc {:04.4f} - f {:04.4f} - rms {:04.4f}".format(100*result['accuracy'], 198 | 100 * result['f-score'], -1 * result['rms'])) 199 | return result 200 | 201 | def run_evaluate(self, sess, test, save=False): 202 | accs, rms = [], [] 203 | labs, labs_ = [], [] 204 | for words, imp_labels in minibatches(test, self.config.batch_size): 205 | imp_labels_, sequence_lengths = self.predict_batch(sess, words) 206 | for lab, lab_, length in zip(imp_labels, imp_labels_, sequence_lengths): 207 | lab = lab[:length] 208 | lab_ = lab_[:length] 209 | 210 | if self.config.model == "lstm_sig": 211 | d_lab = map(self.config.ann2class, lab) 212 | d_lab_ = map(self.config.ann2class, lab_) 213 | else: 214 | d_lab = list(map(self.config.ann2class, lab)) 215 | d_lab_ = lab_[:] 216 | lab_ = list(map(self.config.class2ann, d_lab_)) 217 | 218 | rms += [pow((float(a)-float(b)), 2) for (a,b) in zip(lab, lab_)] 219 | accs += [a==b for (a, b) in zip(d_lab, d_lab_)] 220 | 221 | labs.extend(d_lab) 222 | labs_.extend(d_lab_) 223 | 224 | if save: 225 | with open(self.config.compare_predictions, 'w') as f: 226 | csv_writer = csv.writer(f) 227 | csv_writer.writerow(['truth', 'predictions']) 228 | for y, pred_y in zip(labs, labs_): 229 | csv_writer.writerow([y, pred_y]) 230 | print ("'compare.csv' file saved!") 231 | 232 | p, r, f, s = score(labs, labs_, average="macro") 233 | cnf_mat = confusion_matrix(labs, labs_) 234 | acc = np.mean(accs) 235 | rms_ = np.sqrt(np.mean(rms)) 236 | return {'accuracy': acc, 'precision': p, 'recall': r, 'f-score': f, 'cnf': cnf_mat, 'rms': -1 * rms_} 237 | 238 | 239 | def train(self, train, dev): 240 | best_score = -100 241 | saver = tf.train.Saver() 242 | 243 | nepoch_no_imprv = 0 244 | with tf.Session() as sess: 245 | sess.run(self.init) 246 | 247 | for epoch in range(self.config.nepochs): 248 | self.logger.info("Epoch {:} out of {:}".format(epoch + 1, self.config.nepochs)) 249 | result = self.run_epoch(sess, train, dev, epoch) 250 | 251 | self.config.lr *= self.config.lr_decay 252 | 253 | if result[self.config.opt_metric] >= best_score: 254 | nepoch_no_imprv = 0 255 | if not os.path.exists(self.config.model_output): 256 | os.makedirs(self.config.model_output) 257 | saver.save(sess, self.config.model_output) 258 | best_score = result[self.config.opt_metric] 259 | self.logger.info("- new best score!") 260 | 261 | else: 262 | nepoch_no_imprv += 1 263 | if nepoch_no_imprv >= self.config.nepoch_no_imprv: 264 | 
self.logger.info("- early stopping {} epochs without improvement".format( 265 | nepoch_no_imprv)) 266 | break 267 | 268 | 269 | def evaluate(self, test): 270 | saver = tf.train.Saver() 271 | with tf.Session() as sess: 272 | self.logger.info("Testing model over test set") 273 | saver.restore(sess, self.config.model_output) 274 | result = self.run_evaluate(sess, test, save=True) 275 | 276 | #plot the confustion matrix 277 | plot_confusion_matrix(self.config, result['cnf'], classes=[str(i) for i in range(0, 6)], normalize=True, 278 | title='Normalized confusion matrix') 279 | self.logger.info("- test acc {:04.4f} - f {:04.4f} - rms {:04.4f}".format(100 * result['accuracy'], 280 | 100 * result['f-score'], -1 * result['rms'])) 281 | 282 | 283 | def annotate_files(self, file_path, out_path, processing_word): 284 | output_file = open(out_path, 'w') 285 | saver = tf.train.Saver() 286 | with tf.Session() as sess: 287 | saver.restore(sess, self.config.model_output) 288 | with open(file_path) as f: 289 | for line in f: 290 | sentence = line.strip() 291 | 292 | if sentence == "": 293 | output_file.write(" 0\n") 294 | output_file.write("\n") 295 | else: 296 | words_raw = sentence.strip().split(" ") 297 | 298 | words = [processing_word(w) for w in words_raw] 299 | if type(words[0]) == tuple: 300 | words = zip(*words) 301 | preds, _ = self.predict_batch(sess, [words]) 302 | preds = preds[0] 303 | 304 | for w, pred in zip(words_raw, preds): 305 | output_file.write(w + " " + str(pred) + "\n") 306 | output_file.write("\n") 307 | output_file.close() 308 | 309 | 310 | def reset(self): 311 | tf.reset_default_graph() 312 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_rnn_cell(size, type_, state_is_tuple = True): 4 | if type_ == "LSTM": 5 | return tf.contrib.rnn.LSTMCell(size, state_is_tuple=state_is_tuple) 6 | elif type_ == "GRU": 7 | return tf.contrib.rnn.GRUCell(size) 8 | 9 | def create_feedforward(input_tensor, input_size, output_size, fn_initializer, activation, scope): 10 | with tf.variable_scope(scope): 11 | weights = tf.get_variable("W_", dtype = tf.float32, initializer = fn_initializer((input_size, output_size))) 12 | bias = tf.get_variable("b_", dtype = tf.float32, initializer = fn_initializer((output_size,))) 13 | output = tf.matmul(input_tensor, weights) + bias 14 | if activation == "tanh": 15 | output = tf.tanh(output) 16 | elif activation == "sigmoid": 17 | output = tf.sigmoid(output) 18 | return output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.14 2 | numpy==1.18.1 3 | scikit-learn==0.22.1 4 | scipy==1.4.1 5 | matplotlib==3.1.3 --------------------------------------------------------------------------------