├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── build_data.py
├── compare-annotation.R
├── config.py
├── data_utils.py
├── general_utils.py
├── images
│   └── word-importance.png
├── main.py
├── model.py
├── model_utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sushant Kafle
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Generic Bi-LSTM Model for Word Importance Prediction in Spoken Dialogues:
2 | ==============================================================
3 |
4 | This project demonstrates the use of a generic bi-directional LSTM model for predicting how important each word in a spoken dialogue is to understanding its meaning. The model is trained and evaluated on a human-annotated corpus of word importance. The corpus can be downloaded from: http://latlab.ist.rit.edu/lrec2018
5 |
6 | 
7 |
8 | Performance Summary (to be updated soon):
9 | -------------------------------------------
10 |
11 |
12 | | Model | Classes | f-score | rms |
13 | |:------------------------------: |:-------: |:-------: |------ |
14 | | bi-LSTM (char-embedding + CRF) | 3 | 0.73 | 0.598 |
15 | | bi-LSTM (char-embedding + CRF) | 6 | 0.60 | 0.154 |
16 |
17 |
18 |
19 | You can cite this work and/or the corpus using:
20 |
21 | > Sushant Kafle and Matt Huenerfauth. 2018. A Corpus for Modeling Word Importance in Spoken Dialogue Transcripts. In Proceedings of the 11th edition of the Language Resources and Evaluation Conference (LREC). ACM.
22 |
23 |
24 | I. Data Preparation
25 | ====================
26 |
27 | 1. Download and locate the text transcripts of the Switchboard Corpus and the corresponding word importance corpus:
28 | * The word importance corpus can be downloaded from this website: http://latlab.ist.rit.edu/lrec2018
29 | * The text transcripts of the Switchboard corpus can be downloaded via this link: https://www.isip.piconepress.com/projects/switchboard/releases/switchboard_word_alignments.tar.gz
30 |
31 | * In the `config.py` file, update the variables shown below:
32 |
33 | # location of the Word Importance Corpus "annotations folder"
34 | wimp_corpus = --HERE--
35 |
36 | # location of the Switchboard transcripts
37 | swd_transcripts = --AND HERE--
38 |
39 | 2. Download the GloVe vectors `glove.6B.300d.txt` from http://nlp.stanford.edu/data/glove.6B.zip and update `glove_filename` in `config.py`.
40 |
41 | 3. Run `build_data.py` to prepare the data for training, development and testing:
42 |
43 | ```python build_data.py```
44 |
45 | This will create all the necessary files (such as the word vocabulary, character vocabulary and the training, development and test files) in the `$PROJECT_HOME/data/` directory.
46 |
47 | II. Install Python Dependencies
48 | ======================
49 |
50 | `pip install -r requirements.txt`
51 |
52 |
53 | III. Running the model
54 | ======================
55 |
56 | 1. Go to the directory of the model you want to train, open the `config.py` file, and review the configuration options:
57 |
58 | * model : type of model to run (options: lstm_crf or lstm_sig)
59 | * wimp_corpus : Path to the Word Importance Corpus
60 | * swd_transcripts : Path to the Switchboard Transcripts
61 | * output_path : Path to the output directory
62 | * model_output : Path to save the best performing model
63 | * log_path : Path to store the log
64 | * confusion_mat : Path to save the image of the confusion matrix (part of analysis on the test data)
65 | * compare_predictions : Path to save the predictions of the model (.csv file is produced)
66 |
67 | * random_seed : Random seed
68 | * opt_metric : Metric used to evaluate progress after each epoch
69 | * nclass : Number of classes for prediction
70 |
71 | * dim : Size of the word embeddings used in the model
72 | * dim_char : Size of the character embeddings
73 | * glove_filename : Path to the glove-embeddings file
74 | * trimmed_filename : Path to save the trimmed glove-embeddings
75 |
76 | * dev_filename : Path to the development data, used to select the best epoch
77 | * test_filename : Path to the test data, used for evaluating the model performance
78 | * train_filename : Path to the train data
79 |
80 | * words_filename : Path to the word vocabulary
81 | * tags_filename : Path to the vocabulary of the tags
82 | * chars_filename : Path to the vocabulary of the characters
83 |
84 | * train_embeddings : If True, trains the word-level embeddings
85 | * nepochs : Maximum number of epochs to run
86 | * dropout : Dropout keep probability used during training (1.0 keeps all units)
87 | * batch_size : Number of examples in each batch
88 | * lr_method : Optimization strategy (options: adam, adagrad, sgd, rmsprop)
89 | * lr : Learning rate
90 | * lr_decay : Rate of decay of the learning rate
91 | * clip : Gradient clipping, if negative no clipping
92 | * nepoch_no_imprv : Number of epochs without improvement before early termination
93 | * reload : Reload the latest trained model
94 |
95 | * word_rnn_size : Size of the word-level LSTM hidden layers
96 | * char_rnn_size : Size of the char-level LSTM hidden layers
97 |
98 | 2. Run the model by:
99 |
100 | ```python main.py```
101 |
102 | Summary:
103 | * trained model saved at `model_output` (declared inside `config.py`).
104 | * log of the analysis at `log_path` (declared inside `config.py`) - contains train, dev and test performance.
105 | * confusion matrix at `confusion_mat` (declared inside `config.py`).
106 | * CSV file containing the actual scores and the predicted scores at `compare_predictions` (declared inside `config.py`).
107 |
108 |
109 | IV. Running the agreement analysis
110 | ====================================
111 |
112 | 1. Locate the CSV file (saved at `compare_predictions`) containing the actual scores from the annotators and the predicted scores from the model.
113 |
114 | 2. Open the `compare-annotation.R` file and update the `annotation_src` variable with this new location.
115 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SushantKafle/speechtext-wimp-labeler/32b71e72f86ab7f864e75e8517bb32f4400352d4/__init__.py
--------------------------------------------------------------------------------
/build_data.py:
--------------------------------------------------------------------------------
1 | from math import floor
2 |
3 | from config import Config
4 | import xlrd, os, random, csv
5 | from data_utils import AnnotationDataset, get_vocabs, UNK, NUM, \
6 | get_glove_vocab, write_vocab, load_vocab, get_char_vocab, \
7 | export_trimmed_glove_vectors, get_processing_word
8 | from general_utils import clean_word
9 |
10 | def build_data(config):
11 | annotations = []
12 | meta_filename = 'sw%s%s-ms98-a-trans.text' # % (file_id, speaker_id)
13 |
14 | for idx in os.listdir(config.wimp_corpus):
15 | idx_path = os.path.join(config.wimp_corpus, idx)
16 | if os.path.isfile(idx_path):
17 | continue
18 |
19 | for file_id in os.listdir(idx_path):
20 | folder = os.path.join(idx_path, file_id)
21 | if os.path.isfile(folder):
22 | continue
23 |
24 | wimp_trans_files = [os.path.join(folder, meta_filename % (file_id, 'A')),
25 | os.path.join(folder, meta_filename % (file_id, 'B'))]
26 |
27 | swd_trans_files = [os.path.join(config.swd_transcripts, idx, file_id, meta_filename % (file_id, 'A')),
28 | os.path.join(config.swd_transcripts, idx, file_id, meta_filename % (file_id, 'B'))]
29 |
30 | for i, wimp_trans_file in enumerate(wimp_trans_files):
31 | swd_trans_file = swd_trans_files[i]
32 | file_id, speaker = swd_trans_file.split("/")[-2:]
33 | speaker = speaker[6]
34 | with open(wimp_trans_file) as w_file_obj, open(swd_trans_file) as s_file_obj:
35 | for line_num, (anns_, wrds_) in enumerate(zip(w_file_obj, s_file_obj)):
36 | sentence = []
37 | anns = anns_.strip().split(' ')[3:]
38 | wrds = wrds_.strip().split(' ')[3:]
39 | assert(len(anns) == len(wrds)), \
40 | "file mismatch, line %d : %s and %s" % (line_num, swd_trans_file, wimp_trans_file)
41 |
42 | for id_, wrd in enumerate(wrds):
43 | wrd = clean_word(wrd)
44 | if wrd != '':
45 | sentence.append([(file_id, line_num, speaker), wrd, float(anns[id_])])
46 |
47 | if len(sentence) != 0:
48 | annotations.append(sentence)
49 |
50 | random.shuffle(annotations)
51 |
52 | #80% for training, 10% dev, 10% test
53 | d_train = annotations[ : floor(0.8 * len(annotations))]
54 | d_test = annotations[floor(0.8 * len(annotations)) : floor(0.9 * len(annotations))]
55 | d_dev = annotations[floor(0.9 * len(annotations)): ]
56 |
57 | def prep_text_data(D, outfile):
58 | with open(outfile, 'w') as f:
59 | for sent in D:
60 | for _, word, label in sent:
61 | f.write("%s %f\n" % (word, label))
62 | f.write("\n")
63 |
64 | prep_text_data(d_train, config.train_filename)
65 | prep_text_data(d_test, config.test_filename)
66 | prep_text_data(d_dev, config.dev_filename)
67 |
68 | processing_word = get_processing_word(lowercase=True)
69 |
70 | # Generators
71 | dev = AnnotationDataset(config.dev_filename, processing_word)
72 | test = AnnotationDataset(config.test_filename, processing_word)
73 | train = AnnotationDataset(config.train_filename, processing_word)
74 |
75 | # Build Word and Tag vocab
76 | # Vocabulary is built using training data
77 | vocab_words, vocab_tags = get_vocabs([train])
78 | vocab_glove = get_glove_vocab(config.glove_filename)
79 |
80 | vocab = vocab_words & vocab_glove
81 | vocab.add(UNK)
82 | vocab.add(NUM)
83 |
84 | # Save vocab
85 | write_vocab(vocab, config.words_filename)
86 | write_vocab(vocab_tags, config.tags_filename)
87 |
88 | # Trim GloVe Vectors
89 | vocab = load_vocab(config.words_filename)
90 | export_trimmed_glove_vectors(vocab, config.glove_filename,
91 | config.trimmed_filename, config.dim)
92 |
93 | # Build and save char vocab
94 | train = AnnotationDataset(config.train_filename)
95 | vocab_chars = get_char_vocab(train)
96 | write_vocab(vocab_chars, config.chars_filename)
97 |
98 |
99 | if __name__ == "__main__":
100 | config = Config()
101 | build_data(config)
--------------------------------------------------------------------------------
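For orientation, `build_data()` above pairs each Switchboard transcript line with the annotation line at the same position: it drops the first three whitespace-separated fields of every line (the utterance id and start/end times in the Switchboard ms98 layout), while the annotation files mirror the same layout with an importance score in place of each word. `prep_text_data()` then writes one `word score` pair per line and a blank line between sentences. A rough sketch of both formats (illustrative values only, not real corpus content):

```
# a line from sw2005A-ms98-a-trans.text (Switchboard side)
sw2005A-ms98-a-0002 0.88 1.25 yeah i think so

# the corresponding slice of data/train.txt written by prep_text_data()
yeah 0.166667
i 0.333333
think 0.500000
so 0.166667
```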
/compare-annotation.R:
--------------------------------------------------------------------------------
1 | annotation_src = '/compare-predictions.csv'
2 | annotations = read.csv(annotation_src)
3 |
4 | method1 <- annotations[['truth']]
5 | method2 <- annotations[['predictions']]
6 |
7 | library(epiR)
8 |
9 | tmp <- data.frame(method1, method2)
10 | tmp.ccc <- epi.ccc(method1, method2, ci = "z-transform", conf.level = 0.95)
11 |
12 | tmp.lab <- data.frame(lab = paste("CCC: ",
13 | round(tmp.ccc$rho.c[,1], digits = 2), " (95% CI ",
14 | round(tmp.ccc$rho.c[,2], digits = 2), " - ",
15 | round(tmp.ccc$rho.c[,3], digits = 2), ")", sep = ""))
16 |
17 | z <- lm(method2 ~ method1)
18 | alpha <- summary(z)$coefficients[1,1]
19 | beta <- summary(z)$coefficients[2,1]
20 | tmp.lm <- data.frame(alpha, beta)
21 |
22 | ## Concordance correlation plot:
23 | library(ggplot2)
24 |
25 | ggplot(tmp, aes(x = method1, y = method2)) +
26 | geom_point() +
27 | geom_abline(intercept = 0, slope = 1) +
28 | geom_abline(data = tmp.lm, aes(intercept = alpha, slope = beta),
29 | linetype = "dashed") +
30 | xlim(0, 1.1) +
31 | ylim(0, 1.1) +
32 | xlab("Annotated importance (truth)") +
33 | ylab("Predicted importance") +
34 | geom_text(data = tmp.lab, x = 0.5, y = 1.05, label = tmp.lab$lab) +
35 | coord_fixed(ratio = 1 / 1)
36 |
37 | ## In this plot the solid line represents the line of perfect concordance.
38 | ## The dashed line represents the regression of the predictions on the annotated scores.
39 |
40 |
41 | ## Bland and Altman plot (Figure 2 from Bland and Altman 1986):
42 | tmp.ccc <- epi.ccc(method1, method2, ci = "z-transform", conf.level = 0.95,
43 | rep.measure = FALSE)
44 | tmp <- data.frame(mean = tmp.ccc$blalt[,1], delta = tmp.ccc$blalt[,2])
45 |
46 |
47 | library(ggplot2)
48 |
49 | ggplot(tmp.ccc$blalt, aes(x = mean, y = delta)) +
50 | geom_point() +
51 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = lower), linetype = 2) +
52 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = upper), linetype = 2) +
53 | geom_hline(data = tmp.ccc$sblalt, aes(yintercept = est), linetype = 1) +
54 | xlab("Average of truth and prediction") +
55 | ylab("Difference between truth and prediction") +
56 | xlim(0, 1) +
57 | ylim(-1,1)
58 |
59 |
60 | ## Intraclass Correlation (ICC)
61 | library(psych)
62 |
63 | sf <- data.matrix(annotations)
64 | ICC(sf, missing=FALSE, alpha = 0.05)
65 |
66 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | from general_utils import get_logger
3 |
4 | class Config():
5 | def __init__(self):
6 | if not os.path.exists(self.output_path):
7 | os.makedirs(self.output_path)
8 | self.logger = get_logger(self.log_path)
9 |
10 | # location of the Word Importance Corpus
11 | wimp_corpus = "--UPDATE--"
12 |
13 | # location of the Switchboard transcripts
14 | swd_transcripts = "--UPDATE--"
15 |
16 | # type of model
17 | model = "lstm_crf"
18 | opt_metric = "f-score"
19 | nclass = 6
20 |
21 | random_seed = 1000
22 |
23 | # general config
24 | output_path = "results/exp-1/"
25 | model_output = output_path + "model.weights/"
26 | log_path = output_path + "log.txt"
27 | confusion_mat = output_path + "confusion-mat.png"
28 | compare_predictions = output_path + "compare-predictions.csv"
29 |
30 | # embeddings
31 | dim = 300
32 | dim_char = 100
33 | glove_filename = "data/glove.6B/glove.6B.300d.txt"
34 | trimmed_filename = "data/glove.6B.300d.trimmed.npz"
35 |
36 | # dataset
37 | dev_filename = "data/testa.txt"
38 | test_filename = "data/testb.txt"
39 | train_filename = "data/train.txt"
40 |
41 | # vocab
42 | words_filename = "data/words.txt"
43 | tags_filename = "data/tags.txt"
44 | chars_filename = "data/chars.txt"
45 |
46 | # training
47 | train_embeddings = False
48 | nepochs = 20
49 | dropout = 0.5
50 | batch_size = 20
51 | lr_method = "adam"
52 | lr = 0.001
53 | lr_decay = 0.9
54 | nepoch_no_imprv = 7
55 | reload = False
56 |
57 | # model hyperparameters
58 | word_rnn_size = 300
59 | char_rnn_size = 100
60 |
61 |
62 | # some utility functions
63 | def ann2class(self, tag):
64 | tag = float(tag)
65 | if self.nclass == 6:
66 | if tag < 0.1:
67 | return 0
68 | return int(math.ceil(tag/0.2))
69 | elif self.nclass == 3:
70 | if tag < 0.3:
71 | return 0
72 | elif tag < 0.6:
73 | return 1
74 | return 2
75 | elif self.nclass == 2:
76 | if tag < 0.5:
77 | return 0
78 | return 1
79 |
80 | def class2ann(self, tag):
81 | tag = float(tag)
82 | if self.nclass == 6:
83 | return tag/5.
84 | elif self.nclass == 3:
85 | return ((tag + 1) * 0.3 - 0.1)
86 | elif self.nclass == 2:
87 | return 0.25 if tag == 0 else 0.75
88 |
89 | def digitize_labels(self, tags):
90 | return list(map(self.ann2class, tags))
91 |
92 |
93 |
--------------------------------------------------------------------------------
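To make the label bucketing above concrete, here is a small illustration of `ann2class`, `class2ann`, and `digitize_labels` (a sketch for intuition only; note that constructing `Config()` creates the `results/exp-1/` output directory and log file as a side effect):

```python
from config import Config

config = Config()  # side effect: creates results/exp-1/ and log.txt

# default nclass = 6: scores below 0.1 map to class 0, the rest to ceil(score / 0.2)
print(config.digitize_labels([0.05, 0.15, 0.45, 1.0]))   # [0, 1, 3, 5]

# class2ann maps a class index back to a representative score in [0, 1]
print([config.class2ann(c) for c in range(6)])           # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# with nclass = 3, the thresholds are 0.3 and 0.6
config.nclass = 3
print(config.digitize_labels([0.05, 0.45, 0.9]))         # [0, 1, 2]
```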
/data_utils.py:
--------------------------------------------------------------------------------
1 |
2 | """Most utility functions here has been adopted from:
3 | https://github.com/guillaumegenthial/sequence_tagging/blob/master/model/data_utils.py
4 | """
5 |
6 | import numpy as np
7 | import os
8 | import math
9 |
10 | # shared global variables to be imported from model also
11 | UNK = "$UNK$"
12 | NUM = "$NUM$"
13 | NONE = "O"
14 |
15 | # read the word importance scores
16 | class AnnotationDataset(object):
17 | def __init__(self, filename, processing_word=None):
18 | self.filename = filename
19 | self.processing_word = processing_word
20 | self.length = None
21 |
22 | def __iter__(self):
23 | with open(self.filename) as f:
24 | words, tags = [], []
25 | for line in f:
26 | line = line.strip()
27 | if (len(line) == 0):
28 | if len(words) != 0:
29 | yield words, tags
30 | words, tags = [], []
31 | else:
32 | ls = line.split(' ')
33 | word, tag = ls[0], ls[-1]
34 | if self.processing_word is not None:
35 | word = self.processing_word(word)
36 | words += [word]
37 | tags += [tag]
38 |
39 |
40 | def __len__(self):
41 | if self.length is None:
42 | self.length = 0
43 | for _ in self:
44 | self.length += 1
45 | return self.length
46 |
47 | def get_vocabs(datasets):
48 | print("Building vocab...")
49 | vocab_words = set()
50 | vocab_tags = set()
51 | for dataset in datasets:
52 | for words, tags in dataset:
53 | vocab_words.update(words)
54 | vocab_tags.update(tags)
55 | print("- done. {} tokens".format(len(vocab_words)))
56 | return vocab_words, vocab_tags
57 |
58 |
59 | def get_char_vocab(dataset):
60 | vocab_char = set()
61 | for words, _ in dataset:
62 | for word in words:
63 | vocab_char.update(word)
64 |
65 | return vocab_char
66 |
67 |
68 | def get_glove_vocab(filename):
69 | print("Building vocab...")
70 | vocab = set()
71 | with open(filename) as f:
72 | for line in f:
73 | word = line.strip().split(' ')[0]
74 | vocab.add(word)
75 | print("- done. {} tokens".format(len(vocab)))
76 | return vocab
77 |
78 | def get_google_vocab(filename):
79 | from gensim.models import Word2Vec
80 | model = Word2Vec.load_word2vec_format(filename, binary=True)
81 |
82 | print ("Building vocab...")
83 | vocab = set(model.vocab.keys())
84 |
85 | print ("- done. {} tokens".format(len(vocab)))
86 | return model, vocab
87 |
88 |
89 | def get_senna_vocab(filename):
90 | print ("Building vocab...")
91 | vocab = set()
92 | with open(filename) as f:
93 | for line in f:
94 | word = line.strip()
95 | vocab.add(word)
96 | print ("- done. {} tokens".format(len(vocab)))
97 | return vocab
98 |
99 |
100 | def write_vocab(vocab, filename):
101 | print("Writing vocab...")
102 | with open(filename, "w") as f:
103 | for i, word in enumerate(vocab):
104 | if i != len(vocab) - 1:
105 | f.write("{}\n".format(word))
106 | else:
107 | f.write(word)
108 | print("- done. {} tokens".format(len(vocab)))
109 |
110 |
111 | def load_vocab(filename):
112 | try:
113 | d = dict()
114 | with open(filename) as f:
115 | for idx, word in enumerate(f):
116 | word = word.strip()
117 | d[word] = idx
118 |
119 | except IOError:
120 | raise IOError("Unable to load vocabulary file: {}. Did you run build_data.py first?".format(filename))
121 | return d
122 |
123 |
124 | def export_trimmed_glove_vectors(vocab, glove_filename, trimmed_filename, dim):
125 | embeddings = np.zeros([len(vocab), dim])
126 | with open(glove_filename) as f:
127 | for line in f:
128 | line = line.strip().split(' ')
129 | word = line[0]
130 | embedding = [float(x) for x in line[1:]]
131 | if word in vocab:
132 | word_idx = vocab[word]
133 | embeddings[word_idx] = np.asarray(embedding)
134 |
135 | np.savez_compressed(trimmed_filename, embeddings=embeddings)
136 |
137 |
138 | def export_trimmed_google_vectors(vocab, google_model, trimmed_filename, dim, random):
139 | embeddings = np.asarray(random.normal(loc=0.0, scale=0.1, size= [len(vocab), dim]), dtype=np.float32)
140 | for word in google_model.vocab.keys():
141 | if word in vocab:
142 | word_idx = vocab[word]
143 | embedding = google_model[word]
144 | embeddings[word_idx] = np.asarray(embedding)
145 |
146 | np.savez_compressed(trimmed_filename, embeddings=embeddings)
147 |
148 |
149 | def export_trimmed_senna_vectors(vocab, vocab_emb, senna_filename, trimmed_filename, dim):
150 | embeddings = np.zeros([len(vocab), dim])
151 | vocab_emb = list(vocab_emb)
152 | with open(senna_filename) as f:
153 | for i, line in enumerate(f):
154 | line = line.strip().split(' ')
155 | word = vocab_emb[i]
156 | embedding = [float(x) for x in line]
157 | if word in vocab:
158 | word_idx = vocab[word]
159 | embeddings[word_idx] = np.asarray(embedding)
160 |
161 | np.savez_compressed(trimmed_filename, embeddings=embeddings)
162 |
163 |
164 | def get_trimmed_glove_vectors(filename):
165 | try:
166 | with np.load(filename) as data:
167 | return data["embeddings"]
168 |
169 | except IOError:
170 | raise IOError("Unable to load trimmed embeddings file: {}. Did you run build_data.py first?".format(filename))
171 |
172 |
173 | def get_trimmed_vectors(filename):
174 | return get_trimmed_glove_vectors(filename)
175 |
176 |
177 | def get_processing_word(vocab_words=None, vocab_chars=None,
178 | lowercase=False, chars=False):
179 | def f(word):
180 | # 0. get chars of words
181 | if vocab_chars is not None and chars == True:
182 | char_ids = []
183 | for char in word:
184 | # ignore chars out of vocabulary
185 | if char in vocab_chars:
186 | char_ids += [vocab_chars[char]]
187 |
188 | # 1. preprocess word
189 | if lowercase:
190 | word = word.lower()
191 | if word.isdigit():
192 | word = NUM
193 |
194 | # 2. get id of word
195 | if vocab_words is not None:
196 | if word in vocab_words:
197 | word = vocab_words[word]
198 | else:
199 | word = vocab_words[UNK]
200 |
201 | # 3. return tuple char ids, word id
202 | if vocab_chars is not None and chars == True:
203 | return char_ids, word
204 | else:
205 | return word
206 |
207 | return f
208 |
209 |
210 | def _pad_sequences(sequences, pad_tok, max_length):
211 | sequence_padded, sequence_length = [], []
212 |
213 | for seq in sequences:
214 | seq = list(seq)
215 | seq_ = seq[:max_length] + [pad_tok]*max(max_length - len(seq), 0)
216 | sequence_padded += [seq_]
217 | sequence_length += [min(len(seq), max_length)]
218 |
219 | return sequence_padded, sequence_length
220 |
221 |
222 | def pad_sequences(sequences, pad_tok, nlevels=1):
223 | if nlevels == 1:
224 | max_length = max(map(lambda x : len(x), sequences))
225 | sequence_padded, sequence_length = _pad_sequences(sequences,
226 | pad_tok, max_length)
227 |
228 | elif nlevels == 2:
229 | max_length_word = max([max(map(lambda x: len(x), seq)) for seq in sequences])
230 | sequence_padded, sequence_length = [], []
231 | for seq in sequences:
232 | # all words are same length now
233 | sp, sl = _pad_sequences(seq, pad_tok, max_length_word)
234 | sequence_padded += [sp]
235 | sequence_length += [sl]
236 |
237 | max_length_sentence = max(map(lambda x : len(x), sequences))
238 | sequence_padded, _ = _pad_sequences(sequence_padded, [pad_tok]*max_length_word,
239 | max_length_sentence)
240 | sequence_length, _ = _pad_sequences(sequence_length, 0, max_length_sentence)
241 |
242 |
243 | return sequence_padded, sequence_length
244 |
245 |
246 | def minibatches(data, minibatch_size):
247 | x_batch, y_batch = [], []
248 | for (x, y) in data:
249 | if len(x_batch) == minibatch_size:
250 | yield x_batch, y_batch
251 | x_batch, y_batch = [], []
252 |
253 | if type(x[0]) == tuple:
254 | x = zip(*x)
255 | x_batch += [x]
256 | y_batch += [y]
257 |
258 | if len(x_batch) != 0:
259 | yield x_batch, y_batch
260 |
--------------------------------------------------------------------------------
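Since `get_feed_dict` in `model.py` leans on `pad_sequences` above, a short illustration of its two modes may help (illustrative ids only):

```python
from data_utils import pad_sequences

# nlevels=1: pad every sentence of word ids to the longest sentence in the batch
word_ids = [[4, 8, 15], [16, 23]]
padded, lengths = pad_sequences(word_ids, pad_tok=0)
print(padded)   # [[4, 8, 15], [16, 23, 0]]
print(lengths)  # [3, 2]

# nlevels=2: pad every word's char ids to the longest word,
# then pad every sentence to the longest sentence
char_ids = [[[1, 2], [3]], [[4, 5, 6]]]
padded, lengths = pad_sequences(char_ids, pad_tok=0, nlevels=2)
print(padded)   # [[[1, 2, 0], [3, 0, 0]], [[4, 5, 6], [0, 0, 0]]]
print(lengths)  # [[2, 1], [3, 0]]
```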
/general_utils.py:
--------------------------------------------------------------------------------
1 |
2 | """Most utility functions here has been adopted from:
3 | https://github.com/guillaumegenthial/sequence_tagging/blob/master/model/general_utils.py
4 | """
5 |
6 | import time
7 | import sys
8 | import logging
9 | import numpy as np
10 | import re
11 | import itertools
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 |
15 | def get_logger(filename):
16 | logger = logging.getLogger('logger')
17 | logger.setLevel(logging.DEBUG)
18 | logging.basicConfig(format='%(message)s', level=logging.DEBUG)
19 | handler = logging.FileHandler(filename)
20 | handler.setLevel(logging.DEBUG)
21 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
22 | logging.getLogger().addHandler(handler)
23 | return logger
24 |
25 |
26 | def plot_confusion_matrix(config, cm, classes,
27 | normalize=False,
28 | title='Confusion matrix',
29 | cmap=plt.cm.Blues):
30 | if normalize:
31 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
32 | print("Normalized confusion matrix")
33 | else:
34 | print('Confusion matrix, without normalization')
35 |
36 | print(cm)
37 |
38 | fig = plt.figure()
39 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
40 | plt.title(title)
41 | plt.colorbar()
42 | tick_marks = np.arange(len(classes))
43 | plt.xticks(tick_marks, classes, rotation=45)
44 | plt.yticks(tick_marks, classes)
45 |
46 | fmt = '.2f' if normalize else 'd'
47 | thresh = cm.max() / 2.
48 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
49 | plt.text(j, i, format(cm[i, j], fmt),
50 | horizontalalignment="center",
51 | color="white" if cm[i, j] > thresh else "black")
52 |
53 | plt.tight_layout()
54 | plt.ylabel('True label')
55 | plt.xlabel('Predicted label')
56 | fig.savefig(config.confusion_mat, dpi=fig.dpi)
57 |
58 |
59 | class Progbar(object):
60 | """Progbar class copied from keras (https://github.com/fchollet/keras/)"""
61 |
62 | def __init__(self, target, width=30, verbose=1):
63 | self.width = width
64 | self.target = target
65 | self.sum_values = {}
66 | self.unique_values = []
67 | self.start = time.time()
68 | self.total_width = 0
69 | self.seen_so_far = 0
70 | self.verbose = verbose
71 |
72 | def update(self, current, values=[], exact=[], strict=[]):
73 | for k, v in values:
74 | if k not in self.sum_values:
75 | self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far]
76 | self.unique_values.append(k)
77 | else:
78 | self.sum_values[k][0] += v * (current - self.seen_so_far)
79 | self.sum_values[k][1] += (current - self.seen_so_far)
80 | for k, v in exact:
81 | if k not in self.sum_values:
82 | self.unique_values.append(k)
83 | self.sum_values[k] = [v, 1]
84 |
85 | for k, v in strict:
86 | if k not in self.sum_values:
87 | self.unique_values.append(k)
88 | self.sum_values[k] = v
89 |
90 | self.seen_so_far = current
91 |
92 | now = time.time()
93 | if self.verbose == 1:
94 | prev_total_width = self.total_width
95 | sys.stdout.write("\b" * prev_total_width)
96 | sys.stdout.write("\r")
97 |
98 | numdigits = int(np.floor(np.log10(self.target))) + 1
99 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
100 | bar = barstr % (current, self.target)
101 | prog = float(current)/self.target
102 | prog_width = int(self.width*prog)
103 | if prog_width > 0:
104 | bar += ('='*(prog_width-1))
105 | if current < self.target:
106 | bar += '>'
107 | else:
108 | bar += '='
109 | bar += ('.'*(self.width-prog_width))
110 | bar += ']'
111 | sys.stdout.write(bar)
112 | self.total_width = len(bar)
113 |
114 | if current:
115 | time_per_unit = (now - self.start) / current
116 | else:
117 | time_per_unit = 0
118 | eta = time_per_unit*(self.target - current)
119 | info = ''
120 | if current < self.target:
121 | info += ' - ETA: %ds' % eta
122 | else:
123 | info += ' - %ds' % (now - self.start)
124 | for k in self.unique_values:
125 | if type(self.sum_values[k]) is list:
126 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))
127 | else:
128 | info += ' - %s: %s' % (k, self.sum_values[k])
129 |
130 | self.total_width += len(info)
131 | if prev_total_width > self.total_width:
132 | info += ((prev_total_width-self.total_width) * " ")
133 |
134 | sys.stdout.write(info)
135 | sys.stdout.flush()
136 |
137 | if current >= self.target:
138 | sys.stdout.write("\n")
139 |
140 | if self.verbose == 2:
141 | if current >= self.target:
142 | info = '%ds' % (now - self.start)
143 | for k in self.unique_values:
144 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))
145 | sys.stdout.write(info + "\n")
146 |
147 | def add(self, n, values=[]):
148 | self.update(self.seen_so_far+n, values)
149 |
150 | def clean_word(word):
151 | word = word.strip()
152 | if word != "":
153 | if word[0] in ["<"]:
154 | return ""
155 |
156 | #remove [vocalized-noise] tags
157 | word = re.sub(r'\[vocalized-(.*?)\]', '', word)
158 |
159 | #[laughter-yeah], [laughter-i], etc.
160 | word = re.sub(r'\[[^]]*?-(.*?)\]', '\\1', word)
161 |
162 | #[atteck/attend], [regwet/regret], etc.
163 | word = re.sub(r'\[[^]]*?\/(.*?)\]', '\\1', word)
164 |
165 | #-[ok]ay, etc.
166 | word = re.sub(r'-\[([^\]\s]+)\]([^\[\s]+)', '\\1\\2', word)
167 |
168 | #th[ey]- , sim[ilar]- , etc.
169 | word = re.sub(r'([^\[\s]+)\[([^\]\s]+)\]-', '\\1\\2', word)
170 |
171 | #[silence], [noise], etc.
172 | word = re.sub(r'\[.*?\]', '', word)
173 |
174 | #remove angle-bracket markup, e.g. <b_aside>, <e_aside>, etc.
175 | word = re.sub(r'<.*?>', '', word)
176 |
177 | word = re.sub(r'(\w)-([ \n])', '\\1\\2', word)
178 |
179 | #remove symbols like { and }
180 | word = re.sub(r'[{}]', '', word)
181 |
182 | return word
183 |
184 |
185 |
186 |
--------------------------------------------------------------------------------
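A few illustrative calls to `clean_word` above, which strips Switchboard markup before words reach the dataset (expected results shown as comments, based on the regexes):

```python
from general_utils import clean_word

print(clean_word("[laughter-yeah]"))    # 'yeah' - the spoken word is kept
print(clean_word("[vocalized-noise]"))  # ''     - vocalized-noise tags are dropped
print(clean_word("[silence]"))          # ''     - bracketed events are dropped
print(clean_word("<b_aside>"))          # ''     - tokens starting with '<' are dropped
print(clean_word("okay"))               # 'okay' - ordinary words pass through unchanged
```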
/images/word-importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SushantKafle/speechtext-wimp-labeler/32b71e72f86ab7f864e75e8517bb32f4400352d4/images/word-importance.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from data_utils import get_trimmed_glove_vectors, load_vocab, \
2 | get_processing_word, AnnotationDataset
3 | from config import Config
4 | from model import WImpModel
5 |
6 | def main(config):
7 | # load vocabs
8 | vocab_words = load_vocab(config.words_filename)
9 | vocab_chars = load_vocab(config.chars_filename)
10 |
11 | # get processing functions
12 | processing_word = get_processing_word(vocab_words, vocab_chars,
13 | lowercase=True, chars=True)
14 |
15 | # get pre trained embeddings
16 | embeddings = get_trimmed_glove_vectors(config.trimmed_filename)
17 |
18 | # create dataset
19 | dev = AnnotationDataset(config.dev_filename, processing_word)
20 | test = AnnotationDataset(config.test_filename, processing_word)
21 | train = AnnotationDataset(config.train_filename, processing_word)
22 |
23 | print ("Num. train: %d" % len(train))
24 | print ("Num. test: %d" % len(test))
25 | print ("Num. dev: %d" % len(dev))
26 |
27 | model = WImpModel(config, embeddings, ntags=config.nclass,
28 | nchars=len(vocab_chars))
29 |
30 | # build WImpModel
31 | model.build_graph()
32 |
33 | # train, evaluate and interact
34 | model.train(train, dev)
35 | model.evaluate(test)
36 |
37 | if __name__ == "__main__":
38 | # create instance of config
39 | config = Config()
40 | main(config)
41 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, csv, math
3 | import tensorflow as tf
4 | from data_utils import minibatches, pad_sequences
5 | from general_utils import Progbar, plot_confusion_matrix
6 | from model_utils import create_feedforward, get_rnn_cell
7 | from sklearn.metrics import precision_recall_fscore_support as score
8 | from sklearn.metrics import confusion_matrix
9 |
10 |
11 | class WImpModel(object):
12 |
13 | def __init__(self, config, embeddings, ntags, nchars = None):
14 | self.config = config
15 | self.embeddings = embeddings
16 | self.nchars = nchars
17 | self.ntags = ntags
18 | self.logger = config.logger
19 | self.rng = np.random.RandomState(self.config.random_seed)
20 |
21 |
22 | def _init_graph_(self):
23 | self.word_ids = tf.placeholder(tf.int32, shape=[None, None],
24 | name="word_ids")
25 | self.sequence_lengths = tf.placeholder(tf.int32, shape=[None],
26 | name="sequence_lengths")
27 |
28 | self.char_ids = tf.placeholder(tf.int32, shape=[None, None, None],
29 | name="char_ids")
30 | self.word_lengths = tf.placeholder(tf.int32, shape=[None, None],
31 | name="word_lengths")
32 |
33 | if self.config.model == "lstm_crf":
34 | self.imp_labels = tf.placeholder(tf.int32, shape=[None, None],
35 | name="imp_labels")
36 | else:
37 | self.imp_labels = tf.placeholder(tf.float32, shape=[None, None],
38 | name="imp_labels")
39 |
40 | self.dropout = tf.placeholder(dtype=tf.float32, shape=[],
41 | name="dropout")
42 | self.lr = tf.placeholder(dtype=tf.float32, shape=[],
43 | name="lr")
44 |
45 |
46 | def create_initializer(self, size):
47 | return tf.constant(np.asarray(self.rng.normal(loc = 0.0, scale = 0.1, size = size), dtype = np.float32))
48 |
49 |
50 | def get_feed_dict(self, words, imp_labels = None, lr = None, dropout = None):
51 | char_ids, word_ids = zip(*words)
52 | word_ids, sequence_lengths = pad_sequences(word_ids, 0)
53 | char_ids, word_lengths = pad_sequences(char_ids, pad_tok=0, nlevels=2)
54 |
55 | feed = {
56 | self.word_ids: word_ids,
57 | self.sequence_lengths: sequence_lengths
58 | }
59 |
60 | feed[self.char_ids] = char_ids
61 | feed[self.word_lengths] = word_lengths
62 |
63 | if imp_labels is not None:
64 | imp_labels, _ = pad_sequences(imp_labels, 0)
65 | feed[self.imp_labels] = imp_labels
66 |
67 | if lr is not None:
68 | feed[self.lr] = lr
69 |
70 | if dropout is not None:
71 | feed[self.dropout] = dropout
72 |
73 | return feed, sequence_lengths
74 |
75 | def _define_embeddings_(self):
76 | with tf.variable_scope("words"):
77 | _word_embeddings = tf.Variable(self.embeddings, name = "_word_embeddings", dtype = tf.float32,
78 | trainable = self.config.train_embeddings)
79 | word_embeddings = tf.nn.embedding_lookup(_word_embeddings, self.word_ids,
80 | name = "word_embeddings")
81 |
82 | with tf.variable_scope("chars"):
83 | _char_embeddings = tf.get_variable(name = "_char_embeddings", dtype = tf.float32,
84 | shape = [self.nchars, self.config.dim_char])
85 | char_embeddings = tf.nn.embedding_lookup(_char_embeddings, self.char_ids,
86 | name = "char_embeddings")
87 |
88 | s = tf.shape(char_embeddings)
89 | char_embeddings = tf.reshape(char_embeddings, shape = [-1, s[-2], self.config.dim_char])
90 | word_lengths = tf.reshape(self.word_lengths, shape = [-1])
91 |
92 | cell_fw = get_rnn_cell(self.config.char_rnn_size, "LSTM", state_is_tuple = True)
93 | cell_bw = get_rnn_cell(self.config.char_rnn_size, "LSTM", state_is_tuple = True)
94 |
95 | _, ((_, output_fw), (_, output_bw)) = tf.nn.bidirectional_dynamic_rnn(cell_fw,
96 | cell_bw, char_embeddings, sequence_length = word_lengths,
97 | dtype = tf.float32)
98 |
99 | output = tf.concat([output_fw, output_bw], axis = -1)
100 | output = tf.reshape(output, shape= [-1, s[1], 2 * self.config.char_rnn_size])
101 |
102 | word_embeddings = tf.concat([word_embeddings, output], axis=-1)
103 | self.word_embeddings = tf.nn.dropout(word_embeddings, self.dropout)
104 |
105 |
106 | def _define_logits_(self):
107 | with tf.variable_scope("bi-lstm"):
108 | cell_fw = get_rnn_cell(self.config.word_rnn_size, "LSTM")
109 | cell_bw = get_rnn_cell(self.config.word_rnn_size, "LSTM")
110 | (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_fw,
111 | cell_bw, self.word_embeddings, sequence_length=self.sequence_lengths,
112 | dtype=tf.float32)
113 | output = tf.concat([output_fw, output_bw], axis=-1)
114 |
115 | ntime_steps = tf.shape(output)[1]
116 | output = tf.reshape(output, [-1, 2 * self.config.word_rnn_size])
117 | pred_dim = 1 if self.config.model == "lstm_sig" else self.ntags
118 | pred = create_feedforward(output, 2 * self.config.word_rnn_size, pred_dim, self.create_initializer,
119 | "linear", "projection")
120 |
121 | if self.config.model == "lstm_sig":
122 | pred = tf.sigmoid(pred)
123 | self.logits = tf.reshape(pred, [-1, ntime_steps])
124 | else:
125 | self.logits = tf.reshape(pred, [-1, ntime_steps, self.ntags])
126 |
127 | def _define_predictions_(self):
128 | self.imp_pred = self.logits
129 |
130 |
131 | def _define_loss_(self):
132 | if self.config.model == "lstm_sig":
133 | Y = tf.reshape(self.imp_labels, [-1, 1])
134 | pred_Y = tf.reshape(self.logits, [-1, 1])
135 | self.loss = tf.sqrt(tf.reduce_sum(tf.pow(pred_Y - Y, 2)))
136 | else:
137 | log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
138 | self.logits, self.imp_labels, self.sequence_lengths)
139 | self.loss = tf.reduce_mean(-log_likelihood)
140 |
141 | def _setup_optimizer_(self):
142 | with tf.variable_scope("optimizer_setup"):
143 | if self.config.lr_method == 'adam':
144 | optimizer = tf.train.AdamOptimizer(self.lr)
145 | elif self.config.lr_method == 'adagrad':
146 | optimizer = tf.train.AdagradOptimizer(self.lr)
147 | elif self.config.lr_method == 'sgd':
148 | optimizer = tf.train.GradientDescentOptimizer(self.lr)
149 | else:
150 | optimizer = tf.train.RMSPropOptimizer(self.lr)
151 |
152 | self.optimize_ = optimizer.minimize(self.loss)
153 |
154 |
155 | def build_graph(self):
156 | self._init_graph_()
157 | self._define_embeddings_()
158 | self._define_logits_()
159 | self._define_predictions_()
160 | self._define_loss_()
161 | self._setup_optimizer_()
162 |
163 | self.init = tf.global_variables_initializer()
164 |
165 |
166 | def predict_batch(self, sess, words):
167 | fd, sequence_lengths = self.get_feed_dict(words, dropout=1.0)
168 | if self.config.model == "lstm_sig":
169 | imp_pred = sess.run(self.imp_pred, feed_dict=fd)
170 | else:
171 | imp_pred = []
172 | logits, transition_params = sess.run([self.logits, self.transition_params],
173 | feed_dict=fd)
174 | # iterate over the sentences
175 | for logit, sequence_length in zip(logits, sequence_lengths):
176 | # keep only the valid time steps
177 | logit = logit[:sequence_length]
178 | viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
179 | logit, transition_params)
180 | imp_pred += [viterbi_sequence]
181 | return imp_pred, sequence_lengths
182 |
183 |
184 | def run_epoch(self, sess, train, dev, epoch):
185 | nbatches = (len(train) + self.config.batch_size - 1) // self.config.batch_size
186 | prog = Progbar(target=nbatches)
187 | for i, (words, imp_labels) in enumerate(minibatches(train, self.config.batch_size)):
188 |
189 | if self.config.model == "lstm_crf":
190 | imp_labels = list(map(self.config.digitize_labels, imp_labels))
191 |
192 | fd, _ = self.get_feed_dict(words, imp_labels, self.config.lr, self.config.dropout)
193 | _, train_loss = sess.run([self.optimize_, self.loss], feed_dict=fd)
194 | prog.update(i + 1, [("train loss", train_loss)])
195 |
196 | result = self.run_evaluate(sess, dev)
197 | self.logger.info("- dev acc {:04.4f} - f {:04.4f} - rms {:04.4f}".format(100*result['accuracy'],
198 | 100 * result['f-score'], -1 * result['rms']))
199 | return result
200 |
201 | def run_evaluate(self, sess, test, save=False):
202 | accs, rms = [], []
203 | labs, labs_ = [], []
204 | for words, imp_labels in minibatches(test, self.config.batch_size):
205 | imp_labels_, sequence_lengths = self.predict_batch(sess, words)
206 | for lab, lab_, length in zip(imp_labels, imp_labels_, sequence_lengths):
207 | lab = lab[:length]
208 | lab_ = lab_[:length]
209 |
210 | if self.config.model == "lstm_sig":
211 | d_lab = list(map(self.config.ann2class, lab))
212 | d_lab_ = list(map(self.config.ann2class, lab_))
213 | else:
214 | d_lab = list(map(self.config.ann2class, lab))
215 | d_lab_ = lab_[:]
216 | lab_ = list(map(self.config.class2ann, d_lab_))
217 |
218 | rms += [pow((float(a)-float(b)), 2) for (a,b) in zip(lab, lab_)]
219 | accs += [a==b for (a, b) in zip(d_lab, d_lab_)]
220 |
221 | labs.extend(d_lab)
222 | labs_.extend(d_lab_)
223 |
224 | if save:
225 | with open(self.config.compare_predictions, 'w') as f:
226 | csv_writer = csv.writer(f)
227 | csv_writer.writerow(['truth', 'predictions'])
228 | for y, pred_y in zip(labs, labs_):
229 | csv_writer.writerow([y, pred_y])
230 | print("Predictions saved to '%s'" % self.config.compare_predictions)
231 |
232 | p, r, f, s = score(labs, labs_, average="macro")
233 | cnf_mat = confusion_matrix(labs, labs_)
234 | acc = np.mean(accs)
235 | rms_ = np.sqrt(np.mean(rms))
236 | return {'accuracy': acc, 'precision': p, 'recall': r, 'f-score': f, 'cnf': cnf_mat, 'rms': -1 * rms_}
237 |
238 |
239 | def train(self, train, dev):
240 | best_score = -100
241 | saver = tf.train.Saver()
242 |
243 | nepoch_no_imprv = 0
244 | with tf.Session() as sess:
245 | sess.run(self.init)
246 |
247 | for epoch in range(self.config.nepochs):
248 | self.logger.info("Epoch {:} out of {:}".format(epoch + 1, self.config.nepochs))
249 | result = self.run_epoch(sess, train, dev, epoch)
250 |
251 | self.config.lr *= self.config.lr_decay
252 |
253 | if result[self.config.opt_metric] >= best_score:
254 | nepoch_no_imprv = 0
255 | if not os.path.exists(self.config.model_output):
256 | os.makedirs(self.config.model_output)
257 | saver.save(sess, self.config.model_output)
258 | best_score = result[self.config.opt_metric]
259 | self.logger.info("- new best score!")
260 |
261 | else:
262 | nepoch_no_imprv += 1
263 | if nepoch_no_imprv >= self.config.nepoch_no_imprv:
264 | self.logger.info("- early stopping {} epochs without improvement".format(
265 | nepoch_no_imprv))
266 | break
267 |
268 |
269 | def evaluate(self, test):
270 | saver = tf.train.Saver()
271 | with tf.Session() as sess:
272 | self.logger.info("Testing model over test set")
273 | saver.restore(sess, self.config.model_output)
274 | result = self.run_evaluate(sess, test, save=True)
275 |
276 | #plot the confusion matrix
277 | plot_confusion_matrix(self.config, result['cnf'], classes=[str(i) for i in range(self.ntags)], normalize=True,
278 | title='Normalized confusion matrix')
279 | self.logger.info("- test acc {:04.4f} - f {:04.4f} - rms {:04.4f}".format(100 * result['accuracy'],
280 | 100 * result['f-score'], -1 * result['rms']))
281 |
282 |
283 | def annotate_files(self, file_path, out_path, processing_word):
284 | output_file = open(out_path, 'w')
285 | saver = tf.train.Saver()
286 | with tf.Session() as sess:
287 | saver.restore(sess, self.config.model_output)
288 | with open(file_path) as f:
289 | for line in f:
290 | sentence = line.strip()
291 |
292 | if sentence == "":
293 | output_file.write(" 0\n")
294 | output_file.write("\n")
295 | else:
296 | words_raw = sentence.strip().split(" ")
297 |
298 | words = [processing_word(w) for w in words_raw]
299 | if type(words[0]) == tuple:
300 | words = zip(*words)
301 | preds, _ = self.predict_batch(sess, [words])
302 | preds = preds[0]
303 |
304 | for w, pred in zip(words_raw, preds):
305 | output_file.write(w + " " + str(pred) + "\n")
306 | output_file.write("\n")
307 | output_file.close()
308 |
309 |
310 | def reset(self):
311 | tf.reset_default_graph()
312 |
--------------------------------------------------------------------------------
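Beyond `main.py`, the `annotate_files` method above can label new text once a model has been trained. A minimal sketch of how it could be wired up, mirroring `main.py` (this assumes `build_data.py` and `main.py` have already been run, so the vocab files, trimmed embeddings, and saved weights at `model_output` exist; `raw_sentences.txt` and `labeled.txt` are hypothetical file names, with one sentence per line in the input file):

```python
from config import Config
from data_utils import get_trimmed_glove_vectors, load_vocab, get_processing_word
from model import WImpModel

config = Config()
vocab_words = load_vocab(config.words_filename)
vocab_chars = load_vocab(config.chars_filename)

# same word/char processing as in main.py
processing_word = get_processing_word(vocab_words, vocab_chars,
                                       lowercase=True, chars=True)
embeddings = get_trimmed_glove_vectors(config.trimmed_filename)

model = WImpModel(config, embeddings, ntags=config.nclass, nchars=len(vocab_chars))
model.build_graph()

# restores the weights saved during training and writes "word prediction" pairs
model.annotate_files("raw_sentences.txt", "labeled.txt", processing_word)
```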
/model_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def get_rnn_cell(size, type_, state_is_tuple = True):
4 | if type_ == "LSTM":
5 | return tf.contrib.rnn.LSTMCell(size, state_is_tuple=state_is_tuple)
6 | elif type_ == "GRU":
7 | return tf.contrib.rnn.GRUCell(size)
8 |
9 | def create_feedforward(input_tensor, input_size, output_size, fn_initializer, activation, scope):
10 | with tf.variable_scope(scope):
11 | weights = tf.get_variable("W_", dtype = tf.float32, initializer = fn_initializer((input_size, output_size)))
12 | bias = tf.get_variable("b_", dtype = tf.float32, initializer = fn_initializer((output_size,)))
13 | output = tf.matmul(input_tensor, weights) + bias
14 | if activation == "tanh":
15 | output = tf.tanh(output)
16 | elif activation == "sigmoid":
17 | output = tf.sigmoid(output)
18 | return output
--------------------------------------------------------------------------------
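A tiny self-contained sketch of how `create_feedforward` above is meant to be called, mirroring `create_initializer` in `model.py` (TF 1.x style, matching the pinned tensorflow==1.14):

```python
import numpy as np
import tensorflow as tf
from model_utils import create_feedforward

rng = np.random.RandomState(0)

def init(size):
    # same idea as WImpModel.create_initializer: small random constants as initial values
    return tf.constant(np.asarray(rng.normal(loc=0.0, scale=0.1, size=size), dtype=np.float32))

x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
y = create_feedforward(x, 4, 2, init, "tanh", "demo")  # a 4 -> 2 tanh projection

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.ones((3, 4), dtype=np.float32)})
    print(out.shape)  # (3, 2)
```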
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.14
2 | numpy==1.18.1
3 | scikit-learn==0.22.1
4 | scipy==1.4.1
5 | matplotlib==3.1.3
--------------------------------------------------------------------------------