├── .gitignore ├── data ├── data-goes-here.md ├── outputs │ └── results-go-here.md ├── test.txt.tmp ├── train.txt.tmp └── validation.txt.tmp ├── eval.pl ├── generate_results.py ├── graph.py ├── model_reader.py ├── pos_eval.py ├── readme.md ├── run_all.sh ├── run_epoch.py ├── run_model.py └── saveload.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.txt 6 | *.csv 7 | *.pkl 8 | .DS_Store 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # IPython Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | -------------------------------------------------------------------------------- /data/data-goes-here.md: -------------------------------------------------------------------------------- 1 | # Your Data Goes In This Folder 2 | -------------------------------------------------------------------------------- /data/outputs/results-go-here.md: -------------------------------------------------------------------------------- 1 | # Results Go In This Folder (Automatically) 2 | -------------------------------------------------------------------------------- /eval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 86 | elsif ($nbrOfFeatures != $#features and @features != 0) { 87 | printf STDERR "unexpected number of features: %d (%d)\n", 88 | $#features+1,$nbrOfFeatures+1; 89 | exit(1); 90 | } 91 | if (@features == 0 or 92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 93 | if (@features < 2) { 94 | die "conlleval: unexpected number of features in line $line\n"; 95 | } 96 | if ($raw) { 97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 99 | if ($features[$#features] ne "O") { 100 | $features[$#features] = "B-$features[$#features]"; 101 | } 102 | if 
($features[$#features-1] ne "O") { 103 | $features[$#features-1] = "B-$features[$#features-1]"; 104 | } 105 | } 106 | # 20040126 ET code which allows hyphens in the types 107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 108 | $guessed = $1; 109 | $guessedType = $2; 110 | } else { 111 | $guessed = $features[$#features]; 112 | $guessedType = ""; 113 | } 114 | pop(@features); 115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 116 | $correct = $1; 117 | $correctType = $2; 118 | } else { 119 | $correct = $features[$#features]; 120 | $correctType = ""; 121 | } 122 | pop(@features); 123 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 124 | # ($correct,$correctType) = split(/-/,pop(@features)); 125 | $guessedType = $guessedType ? $guessedType : ""; 126 | $correctType = $correctType ? $correctType : ""; 127 | $firstItem = shift(@features); 128 | 129 | # 1999-06-26 sentence breaks should always be counted as out of chunk 130 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 131 | 132 | if ($inCorrect) { 133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 135 | $lastGuessedType eq $lastCorrectType) { 136 | $inCorrect=$false; 137 | $correctChunk++; 138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 139 | $correctChunk{$lastCorrectType}+1 : 1; 140 | } elsif ( 141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 143 | $guessedType ne $correctType ) { 144 | $inCorrect=$false; 145 | } 146 | } 147 | 148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 150 | $guessedType eq $correctType) { $inCorrect = $true; } 151 | 152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 153 | $foundCorrect++; 154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 155 | $foundCorrect{$correctType}+1 : 1; 156 | } 157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 158 | $foundGuessed++; 159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 160 | $foundGuessed{$guessedType}+1 : 1; 161 | } 162 | if ( $firstItem ne $boundary ) { 163 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 164 | $correctTags++; 165 | } 166 | $tokenCounter++; 167 | } 168 | 169 | $lastGuessed = $guessed; 170 | $lastCorrect = $correct; 171 | $lastGuessedType = $guessedType; 172 | $lastCorrectType = $correctType; 173 | } 174 | if ($inCorrect) { 175 | $correctChunk++; 176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
177 | $correctChunk{$lastCorrectType}+1 : 1; 178 | } 179 | 180 | if (not $latex) { 181 | # compute overall precision, recall and FB1 (default values are 0.0) 182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 184 | $FB1 = 2*$precision*$recall/($precision+$recall) 185 | if ($precision+$recall > 0); 186 | 187 | # print overall performance 188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 190 | if ($tokenCounter>0) { 191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 192 | printf "precision: %6.2f%%; ",$precision; 193 | printf "recall: %6.2f%%; ",$recall; 194 | printf "FB1: %6.2f\n",$FB1; 195 | } 196 | } 197 | 198 | # sort chunk type names 199 | undef($lastType); 200 | @sortedTypes = (); 201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 202 | if (not($lastType) or $lastType ne $i) { 203 | push(@sortedTypes,($i)); 204 | } 205 | $lastType = $i; 206 | } 207 | # print performance per chunk type 208 | if (not $latex) { 209 | for $i (@sortedTypes) { 210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 213 | if (not($foundCorrect{$i})) { $recall = 0.0; } 214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 217 | printf "%17s: ",$i; 218 | printf "precision: %6.2f%%; ",$precision; 219 | printf "recall: %6.2f%%; ",$recall; 220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 221 | } 222 | } else { 223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 224 | for $i (@sortedTypes) { 225 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 226 | if (not($foundGuessed{$i})) { $precision = 0.0; } 227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 228 | if (not($foundCorrect{$i})) { $recall = 0.0; } 229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 233 | $i,$precision,$recall,$FB1; 234 | } 235 | print "\\hline\n"; 236 | $precision = 0.0; 237 | $recall = 0; 238 | $FB1 = 0.0; 239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 241 | $FB1 = 2*$precision*$recall/($precision+$recall) 242 | if ($precision+$recall > 0); 243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 244 | $precision,$recall,$FB1; 245 | } 246 | 247 | exit 0; 248 | 249 | # endOfChunk: checks if a chunk ended between the previous and current word 250 | # arguments: previous and current chunk tags, previous and current types 251 | # note: this code is capable of handling other chunk representations 252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 254 | 255 | sub endOfChunk { 256 | my $prevTag = shift(@_); 257 | my $tag = shift(@_); 258 | my $prevType = shift(@_); 259 | my $type = shift(@_); 260 | my $chunkEnd = $false; 261 | 262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 266 | 267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 271 | 272 | if ($prevTag ne "O" and $prevTag ne "." 
and $prevType ne $type) { 273 | $chunkEnd = $true; 274 | } 275 | 276 | # corrected 1998-12-22: these chunks are assumed to have length 1 277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 279 | 280 | return($chunkEnd); 281 | } 282 | 283 | # startOfChunk: checks if a chunk started between the previous and current word 284 | # arguments: previous and current chunk tags, previous and current types 285 | # note: this code is capable of handling other chunk representations 286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 288 | 289 | sub startOfChunk { 290 | my $prevTag = shift(@_); 291 | my $tag = shift(@_); 292 | my $prevType = shift(@_); 293 | my $type = shift(@_); 294 | my $chunkStart = $false; 295 | 296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 300 | 301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 305 | 306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 307 | $chunkStart = $true; 308 | } 309 | 310 | # corrected 1998-12-22: these chunks are assumed to have length 1 311 | if ( $tag eq "[" ) { $chunkStart = $true; } 312 | if ( $tag eq "]" ) { $chunkStart = $true; } 313 | 314 | return($chunkStart); 315 | } 316 | -------------------------------------------------------------------------------- /generate_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pos_eval import pos_eval 3 | 4 | import argparse 5 | 6 | 7 | def generate_results(path): 8 | 9 | # def = ../../data/current_outcome 10 | 11 | chunk_train = path + 'chunk_pred_train.txt' 12 | chunk_val = path + 'chunk_pred_val.txt' 13 | chunk_comb = path + 'chunk_pred_combined.txt' 14 | chunk_test = path + 'chunk_pred_test.txt' 15 | pos_train = path + 'pos_pred_train.txt' 16 | pos_val = path + 'pos_pred_val.txt' 17 | pos_comb = path + 'pos_pred_combined.txt' 18 | pos_test = path + 'pos_pred_test.txt' 19 | 20 | print('generating latex tables - chunk train') 21 | cmd = 'perl eval.pl -l < ' + chunk_train 22 | os.system(cmd) 23 | 24 | print('generating latex tables - chunk valid') 25 | cmd = 'perl eval.pl -l < ' + chunk_val 26 | os.system(cmd) 27 | 28 | print('generating latex tables - chunk combined') 29 | cmd = 'perl eval.pl -l < ' + chunk_comb 30 | os.system(cmd) 31 | 32 | print('generating latex tables - chunk test') 33 | cmd = 'perl eval.pl -l < ' + chunk_test 34 | os.system(cmd) 35 | 36 | print('generating accuracy - pos train') 37 | print(pos_eval(pos_train)) 38 | 39 | print('generating accuracy - pos valid') 40 | print(pos_eval(pos_val)) 41 | 42 | print('generating accuracy - pos combined') 43 | print(pos_eval(pos_comb)) 44 | 45 | print('generating accuracy - pos test') 46 | print(pos_eval(pos_test)) 47 | 48 | print('done') 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--path") 53 | args = parser.parse_args() 54 | path = args.path 55 | generate_results(path) 56 |
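Note: `pos_eval` just measures token-level accuracy on the prediction files, which `run_model.py` writes as whitespace-separated `word gold-POS gold-chunk predicted-tag` rows. A minimal, dependency-free sketch of the same calculation (the `token_accuracy` helper and its default column indices are illustrative assumptions, not part of the repo):

```python
# Minimal sketch: token-level accuracy over a whitespace-separated prediction
# file. Column indices mirror pos_eval.py (gold tag in column 1, prediction in
# column 3); adjust them if your file layout differs.
def token_accuracy(path, gold_col=1, pred_col=3):
    correct = total = 0
    with open(path) as f:
        for line in f:
            fields = line.split()
            if len(fields) <= max(gold_col, pred_col):
                continue  # skip blank or malformed lines
            correct += fields[gold_col] == fields[pred_col]
            total += 1
    return correct / total if total else 0.0
```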
-------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib import rnn 8 | 9 | import pdb 10 | 11 | 12 | class Shared_Model(object): 13 | """Tensorflow Graph For Shared Pos & Chunk Model""" 14 | 15 | def __init__(self, config, is_training): 16 | self.max_grad_norm = config.max_grad_norm 17 | self.num_steps = num_steps = config.num_steps 18 | self.encoder_size = config.encoder_size 19 | self.pos_decoder_size = config.pos_decoder_size 20 | self.chunk_decoder_size = config.chunk_decoder_size 21 | self.batch_size = config.batch_size 22 | self.vocab_size = config.vocab_size 23 | self.num_pos_tags = config.num_pos_tags 24 | self.num_chunk_tags = config.num_chunk_tags 25 | self.input_data = tf.placeholder(tf.int32, [config.batch_size, num_steps]) 26 | self.word_embedding_size = config.word_embedding_size 27 | self.pos_embedding_size = config.pos_embedding_size 28 | self.num_shared_layers = config.num_shared_layers 29 | self.argmax = config.argmax 30 | 31 | # add input size - size of pos tags 32 | self.pos_targets = tf.placeholder(tf.float32, [(self.batch_size * num_steps), 33 | self.num_pos_tags]) 34 | self.chunk_targets = tf.placeholder(tf.float32, [(self.batch_size * num_steps), 35 | self.num_chunk_tags]) 36 | 37 | self._build_graph(config, is_training) 38 | 39 | def _shared_layer(self, input_data, config, is_training): 40 | """Build the model up until decoding. 41 | 42 | Args: 43 | input_data = size batch_size X num_steps X embedding size 44 | 45 | Returns: 46 | output units 47 | """ 48 | 49 | with tf.variable_scope('encoder'): 50 | lstm_cell = rnn.BasicLSTMCell(config.encoder_size, reuse=tf.get_variable_scope().reuse, forget_bias=1.0) 51 | if is_training and config.keep_prob < 1: 52 | lstm_cell = rnn.DropoutWrapper( 53 | lstm_cell, output_keep_prob=config.keep_prob) 54 | encoder_outputs, encoder_states = tf.nn.dynamic_rnn(lstm_cell, 55 | input_data, 56 | dtype=tf.float32, 57 | scope="encoder_rnn") 58 | 59 | return encoder_outputs 60 | 61 | def _pos_private(self, encoder_units, config, is_training): 62 | """Decode model for pos 63 | 64 | Args: 65 | encoder_units - these are the encoder units 66 | num_pos - the number of pos tags there are (output units) 67 | 68 | returns: 69 | logits 70 | """ 71 | with tf.variable_scope("pos_decoder"): 72 | pos_decoder_cell = rnn.BasicLSTMCell(config.pos_decoder_size, 73 | forget_bias=1.0, reuse=tf.get_variable_scope().reuse) 74 | 75 | if is_training and config.keep_prob < 1: 76 | pos_decoder_cell = rnn.DropoutWrapper( 77 | pos_decoder_cell, output_keep_prob=config.keep_prob) 78 | 79 | encoder_units = tf.transpose(encoder_units, [1, 0, 2]) 80 | 81 | decoder_outputs, decoder_states = tf.nn.dynamic_rnn(pos_decoder_cell, 82 | encoder_units, 83 | dtype=tf.float32, 84 | scope="pos_rnn") 85 | 86 | output = tf.reshape(tf.concat(decoder_outputs, 1), 87 | [-1, config.pos_decoder_size]) 88 | 89 | softmax_w = tf.get_variable("softmax_w", 90 | [config.pos_decoder_size, 91 | config.num_pos_tags]) 92 | softmax_b = tf.get_variable("softmax_b", [config.num_pos_tags]) 93 | logits = tf.matmul(output, softmax_w) + softmax_b 94 | 95 | return logits, decoder_states 96 | 97 | def _chunk_private(self, encoder_units, pos_prediction, config, is_training): 98 | """Decode model for chunks 99 | 
100 | Args: 101 | encoder_units - these are the encoder units: 102 | [batch_size X encoder_size] with the one the pos prediction 103 | pos_prediction: 104 | must be the same size as the encoder_size 105 | 106 | returns: 107 | logits 108 | """ 109 | # concatenate the encoder_units and the pos_prediction 110 | 111 | pos_prediction = tf.reshape(pos_prediction, 112 | [self.batch_size, self.num_steps, self.pos_embedding_size]) 113 | encoder_units = tf.transpose(encoder_units, [1, 0, 2]) 114 | chunk_inputs = tf.concat([pos_prediction, encoder_units], 2) 115 | 116 | with tf.variable_scope("chunk_decoder"): 117 | cell = rnn.BasicLSTMCell(config.chunk_decoder_size, forget_bias=1.0, reuse=tf.get_variable_scope().reuse) 118 | 119 | if is_training and config.keep_prob < 1: 120 | cell = rnn.DropoutWrapper( 121 | cell, output_keep_prob=config.keep_prob) 122 | 123 | decoder_outputs, decoder_states = tf.nn.dynamic_rnn(cell, 124 | chunk_inputs, 125 | dtype=tf.float32, 126 | scope="chunk_rnn") 127 | 128 | output = tf.reshape(tf.concat(decoder_outputs, 1), 129 | [-1, config.chunk_decoder_size]) 130 | 131 | softmax_w = tf.get_variable("softmax_w", 132 | [config.chunk_decoder_size, 133 | config.num_chunk_tags]) 134 | softmax_b = tf.get_variable("softmax_b", [config.num_chunk_tags]) 135 | logits = tf.matmul(output, softmax_w) + softmax_b 136 | 137 | return logits, decoder_states 138 | 139 | def _loss(self, logits, labels): 140 | """Calculate loss for both pos and chunk 141 | Args: 142 | logits from the decoder 143 | labels - one-hot 144 | returns: 145 | loss as tensor of type float 146 | """ 147 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, 148 | labels=labels, 149 | name='xentropy') 150 | loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 151 | (_, int_targets) = tf.nn.top_k(labels, 1) 152 | (_, int_predictions) = tf.nn.top_k(logits, 1) 153 | num_true = tf.reduce_sum(tf.cast(tf.equal(int_targets, int_predictions), tf.float32)) 154 | accuracy = num_true / (self.num_steps * self.batch_size) 155 | return loss, accuracy, int_predictions, int_targets 156 | 157 | def _training(self, loss, config): 158 | """Sets up training ops 159 | 160 | Creates the optimiser 161 | 162 | The op returned from this is what is passed to session run 163 | 164 | Args: 165 | loss float 166 | learning_rate float 167 | 168 | returns: 169 | 170 | Op for training 171 | """ 172 | # Create the gradient descent optimizer with the 173 | # given learning rate. 
174 | tvars = tf.trainable_variables() 175 | grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 176 | config.max_grad_norm) 177 | optimizer = tf.train.AdamOptimizer() 178 | train_op = optimizer.apply_gradients(zip(grads, tvars)) 179 | return train_op 180 | 181 | def _build_graph(self, config, is_training): 182 | word_embedding = tf.get_variable("word_embedding", [config.vocab_size, config.word_embedding_size]) 183 | inputs = tf.nn.embedding_lookup(word_embedding, self.input_data) 184 | pos_embedding = tf.get_variable("pos_embedding", [config.num_pos_tags, config.pos_embedding_size]) 185 | 186 | if is_training and config.keep_prob < 1: 187 | inputs = tf.nn.dropout(inputs, config.keep_prob) 188 | 189 | encoding = self._shared_layer(inputs, config, is_training) 190 | 191 | encoding = tf.stack(encoding) 192 | encoding = tf.transpose(encoding, perm=[1, 0, 2]) 193 | 194 | pos_logits, pos_states = self._pos_private(encoding, config, is_training) 195 | pos_loss, pos_accuracy, pos_int_pred, pos_int_targ = self._loss(pos_logits, self.pos_targets) 196 | self.pos_loss = pos_loss 197 | 198 | self.pos_int_pred = pos_int_pred 199 | self.pos_int_targ = pos_int_targ 200 | 201 | # choose either argmax or dot product for pos 202 | if config.argmax == 1: 203 | pos_to_chunk_embed = tf.nn.embedding_lookup(pos_embedding, pos_int_pred) 204 | else: 205 | pos_to_chunk_embed = tf.matmul(tf.nn.softmax(pos_logits), pos_embedding) 206 | 207 | chunk_logits, chunk_states = self._chunk_private(encoding, pos_to_chunk_embed, config, is_training) 208 | chunk_loss, chunk_accuracy, chunk_int_pred, chunk_int_targ = self._loss(chunk_logits, self.chunk_targets) 209 | self.chunk_loss = chunk_loss 210 | 211 | self.chunk_int_pred = chunk_int_pred 212 | self.chunk_int_targ = chunk_int_targ 213 | self.joint_loss = chunk_loss + pos_loss 214 | 215 | # return pos embedding 216 | self.pos_embedding = pos_embedding 217 | 218 | if not is_training: 219 | return 220 | 221 | self.pos_op = self._training(pos_loss, config) 222 | self.chunk_op = self._training(chunk_loss, config) 223 | self.joint_op = self._training(chunk_loss + pos_loss, config) 224 | -------------------------------------------------------------------------------- /model_reader.py: -------------------------------------------------------------------------------- 1 | """Utilities for parsing CONll text files.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import collections 7 | import os 8 | import sys 9 | import time 10 | import pandas as pd 11 | import csv 12 | import pdb 13 | import pickle 14 | 15 | import numpy as np 16 | 17 | """ 18 | 1.0. 
Utility Methods 19 | """ 20 | 21 | 22 | def read_tokens(filename, padding_val, col_val=-1): 23 | # Col Values 24 | # 0 - words 25 | # 1 - POS 26 | # 2 - tags 27 | 28 | with open(filename, 'rt', encoding='utf8') as csvfile: 29 | r = csv.reader(csvfile, delimiter=' ') 30 | words = np.transpose(np.array([x for x in list(r) if x != []])).astype(object) 31 | # padding token '0' 32 | print('reading ' + str(col_val) + ' ' + filename) 33 | if col_val!=-1: 34 | words = words[col_val] 35 | return np.pad( 36 | words, pad_width=(padding_val, 0), mode='constant', constant_values=0) 37 | 38 | 39 | def _build_vocab(filename, padding_width, col_val): 40 | # can be used for input vocab 41 | data = read_tokens(filename, padding_width, col_val) 42 | counter = collections.Counter(data) 43 | # get rid of all words with frequency == 1 44 | counter = {k: v for k, v in counter.items() if v > 1} 45 | counter[''] = 10000 46 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 47 | words, _ = list(zip(*count_pairs)) 48 | word_to_id = dict(zip(words, range(len(words)))) 49 | 50 | return word_to_id 51 | 52 | def _build_tags(filename, padding_width, col_val): 53 | # can be used for classifications and input vocab 54 | data = read_tokens(filename, padding_width, col_val) 55 | counter = collections.Counter(data) 56 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 57 | words, _ = list(zip(*count_pairs)) 58 | tag_to_id = dict(zip(words, range(len(words)))) 59 | if col_val == 1: 60 | pickle.dump(tag_to_id,open('pos_to_id.pkl','wb')) 61 | pickle.dump(count_pairs,open('pos_counts.pkl','wb')) 62 | 63 | return tag_to_id 64 | 65 | 66 | """ 67 | 1.1. Word Methods 68 | """ 69 | 70 | 71 | def _file_to_word_ids(filename, word_to_id, padding_width): 72 | # assumes _build_vocab has been called first as is called word to id 73 | data = read_tokens(filename, padding_width, 0) 74 | default_value = word_to_id[''] 75 | return [word_to_id.get(word, default_value) for word in data] 76 | 77 | """ 78 | 1.2. 
tag Methods 79 | """ 80 | 81 | 82 | def _int_to_tag(tag_int, tag_vocab_size): 83 | # creates the one-hot vector 84 | a = np.empty(tag_vocab_size) 85 | a.fill(0) 86 | np.put(a, tag_int, 1) 87 | return a 88 | 89 | 90 | def _seq_tag(tag_integers, tag_vocab_size): 91 | # create the array of one-hot vectors for your sequence 92 | return np.vstack(_int_to_tag( 93 | tag, tag_vocab_size) for tag in tag_integers) 94 | 95 | 96 | def _file_to_tag_classifications(filename, tag_to_id, padding_width, col_val): 97 | # assumes _build_vocab has been called first and is called tag to id 98 | data = read_tokens(filename, padding_width, col_val) 99 | return [tag_to_id[tag] for tag in data] 100 | 101 | 102 | def raw_x_y_data(data_path, num_steps): 103 | train = "train.txt" 104 | valid = "validation.txt" 105 | train_valid = "train_val_combined.txt" 106 | comb = "all_combined.txt" 107 | test = "test.txt" 108 | 109 | train_path = os.path.join(data_path, train) 110 | valid_path = os.path.join(data_path, valid) 111 | train_valid_path = os.path.join(data_path, train_valid) 112 | comb_path = os.path.join(data_path, comb) 113 | test_path = os.path.join(data_path, test) 114 | 115 | # checking for all combined 116 | if not os.path.exists(data_path + '/train_val_combined.txt'): 117 | print('writing train validation combined') 118 | train_data = pd.read_csv(data_path + '/train.txt', sep= ' ',header=None) 119 | validation_data = pd.read_csv(data_path + '/validation.txt', sep= ' ',header=None) 120 | 121 | comb = pd.concat([train_data,validation_data]) 122 | comb.to_csv(data_path + '/train_val_combined.txt', sep=' ', index=False, header=False) 123 | 124 | if not os.path.exists(data_path + '/all_combined.txt'): 125 | print('writing combined') 126 | test_data = pd.read_csv(data_path + '/test.txt', sep= ' ',header=None) 127 | train_data = pd.read_csv(data_path + '/train.txt', sep= ' ',header=None) 128 | val_data = pd.read_csv(data_path + '/validation.txt', sep=' ', header=None) 129 | 130 | comb = pd.concat([train_data,val_data,test_data]) 131 | comb.to_csv(data_path + '/all_combined.txt', sep=' ', index=False, header=False) 132 | 133 | word_to_id = _build_vocab(train_path, num_steps-1, 0) 134 | # use the full training set for building the target tags 135 | pos_to_id = _build_tags(comb_path, num_steps-1, 1) 136 | 137 | chunk_to_id = _build_tags(comb_path, num_steps-1, 2) 138 | 139 | word_data_t = _file_to_word_ids(train_path, word_to_id, num_steps-1) 140 | pos_data_t = _file_to_tag_classifications(train_path, pos_to_id, num_steps-1, 1) 141 | chunk_data_t = _file_to_tag_classifications(train_path, chunk_to_id, num_steps-1, 2) 142 | 143 | word_data_v = _file_to_word_ids(valid_path, word_to_id, num_steps-1) 144 | pos_data_v = _file_to_tag_classifications(valid_path, pos_to_id, num_steps-1, 1) 145 | chunk_data_v = _file_to_tag_classifications(valid_path, chunk_to_id, num_steps-1, 2) 146 | 147 | word_data_c = _file_to_word_ids(train_valid_path, word_to_id, num_steps-1) 148 | pos_data_c = _file_to_tag_classifications(train_valid_path, pos_to_id, num_steps-1, 1) 149 | chunk_data_c = _file_to_tag_classifications(train_valid_path, chunk_to_id, num_steps-1, 2) 150 | 151 | word_data_test = _file_to_word_ids(test_path, word_to_id, num_steps-1) 152 | pos_data_test = _file_to_tag_classifications(test_path, pos_to_id, num_steps-1, 1) 153 | chunk_data_test = _file_to_tag_classifications(test_path, chunk_to_id, num_steps-1, 2) 154 | 155 | return word_data_t, pos_data_t, chunk_data_t, word_data_v, \ 156 | pos_data_v, chunk_data_v, word_to_id, 
pos_to_id, chunk_to_id, \ 157 | word_data_test, pos_data_test, chunk_data_test, word_data_c, \ 158 | pos_data_c, chunk_data_c 159 | 160 | 161 | def create_batches(raw_words, raw_pos, raw_chunk, batch_size, num_steps, pos_vocab_size, 162 | chunk_vocab_size): 163 | """Tokenize and create batches From words (inputs), raw_pos (output 1), raw_chunk(output 2). The parameters 164 | of the minibatch are defined by the batch_size, the length of the sequence. 165 | 166 | :param raw_words: 167 | :param raw_pos: 168 | :param raw_chunk: 169 | :param batch_size: 170 | :param num_steps: 171 | :param pos_vocab_size: 172 | :param chunk_vocab_size: 173 | :return: 174 | """ 175 | 176 | def _reshape_and_pad(tokens, batch_size, num_steps): 177 | tokens = np.array(tokens, dtype=np.int32) 178 | data_len = len(tokens) 179 | post_padding_required = (batch_size*num_steps) - np.mod(data_len, batch_size*num_steps) 180 | 181 | tokens = np.pad(tokens, (0, post_padding_required), 'constant', 182 | constant_values=0) 183 | epoch_length = len(tokens) // (batch_size*num_steps) 184 | tokens = tokens.reshape([batch_size, num_steps*epoch_length]) 185 | return tokens 186 | 187 | """ 188 | 1. Prepare the input (word) data 189 | """ 190 | word_data = _reshape_and_pad(raw_words, batch_size, num_steps) 191 | pos_data = _reshape_and_pad(raw_pos, batch_size, num_steps) 192 | chunk_data = _reshape_and_pad(raw_chunk, batch_size, num_steps) 193 | 194 | """ 195 | 3. Do the epoch thing and iterate 196 | """ 197 | data_len = len(raw_words) 198 | # how many times do you iterate to reach the end of the epoch 199 | epoch_size = (data_len // (batch_size*num_steps)) + 1 200 | 201 | if epoch_size == 0: 202 | raise ValueError("epoch_size == 0, decrease batch_size or num_steps") 203 | 204 | for i in range(epoch_size): 205 | x = word_data[:, i*num_steps:(i+1)*num_steps] 206 | y_pos = np.vstack(_seq_tag(pos_data[tag, i*num_steps:(i+1)*num_steps], 207 | pos_vocab_size) for tag in range(batch_size)) 208 | y_chunk = np.vstack(_seq_tag(chunk_data[tag, i*num_steps:(i+1)*num_steps], 209 | chunk_vocab_size) for tag in range(batch_size)) 210 | y_pos = y_pos.astype(np.int32) 211 | y_chunk = y_chunk.astype(np.int32) 212 | yield (x, y_pos, y_chunk) 213 | 214 | 215 | def _int_to_string(int_pred, d): 216 | 217 | # integers are the Values 218 | keys = [] 219 | for x in int_pred: 220 | keys.append([k for k, v in d.items() if v == (x)]) 221 | 222 | return keys 223 | 224 | 225 | def res_to_list(res, batch_size, num_steps, to_id, w_length): 226 | 227 | tmp = np.concatenate([x.reshape(batch_size, num_steps) 228 | for x in res], axis=1).reshape(-1) 229 | tmp = np.squeeze(_int_to_string(tmp, to_id)) 230 | return tmp[range(num_steps-1, w_length)].reshape(-1,1) 231 | -------------------------------------------------------------------------------- /pos_eval.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pdb 4 | 5 | def pos_eval(path): 6 | import numpy as np 7 | data = pd.read_csv(path, sep=' ', header=None) 8 | targ = data[1].as_matrix() 9 | pred = data[3].as_matrix() 10 | return np.sum(targ == pred)/float(len(targ)) 11 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ### Multi-Task Learning 2 | 3 | ------- 4 | 5 | ## Introduction 6 | 7 | This is an example of how to construct a multi-task neural net in Tensorflow. 
Here we're looking at Natural Language Processing ("NLP"), specifically at whether learning Part of Speech (POS) tagging and Shallow Parsing (Chunking) at the same time can improve performance on both. 8 | 9 | ## Network Structure 10 | 11 | Our network looks a little bit like this, with Task 1 being Part of Speech (POS) and Task 2 being Chunk: 12 | 13 | 14 | 15 | As you can see, you can train either task separately (by calling the individual training ops), or you can train the tasks jointly (by calling the joint training op). 16 | 17 | We have also added an explicit connection from POS to Chunk, which actually makes the network into something similar to a ladder network with an explicit hidden state representation. 18 | 19 | ## Quick Start (Mac and Linux) 20 | 21 | * This is Python 3, so please install Anaconda 3 and TensorFlow. This should be enough to get you started. 22 | * Then, go into the data folder and remove the ``.tmp`` extensions from the data files. 23 | * Then run ``$ sh run_all.sh`` - this will start the joint training. Once it's finished, the outputs will be stored in ``./data/outputs`` 24 | * You can then print out the evaluations by typing ``python generate_results.py --path "./data/outputs/predictions/"`` 25 | 26 | ## How to do single training 27 | 28 | If you want to train each task separately and compare the results, you just need to change an argument in the ``run_all.sh`` script. 29 | 30 | ### POS Single 31 | ```bash 32 | python3 run_model.py --model_type "POS" \ 33 | --dataset_path "./data" \ 34 | --save_path "./data/outputs/" 35 | 36 | ``` 37 | 38 | ### Chunk Single 39 | ```bash 40 | python3 run_model.py --model_type "CHUNK" \ 41 | --dataset_path "./data" \ 42 | --save_path "./data/outputs/" 43 | 44 | ``` 45 | 46 | ### Joint 47 | ```bash 48 | python3 run_model.py --model_type "JOINT" \ 49 | --dataset_path "./data" \ 50 | --save_path "./data/outputs/" 51 | 52 | ``` 53 | -------------------------------------------------------------------------------- /run_all.sh: -------------------------------------------------------------------------------- 1 | echo 'Running Model' 2 | python3 run_model.py --model_type "JOINT" \ 3 | --dataset_path "./data" \ 4 | --save_path "./data/outputs/" 5 | -------------------------------------------------------------------------------- /run_epoch.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import tensorflow as tf 7 | import model_reader as reader 8 | import numpy as np 9 | 10 | 11 | def run_epoch(session, m, words, pos, chunk, pos_vocab_size, chunk_vocab_size, 12 | verbose=False, valid=False, model_type='JOINT'): 13 | """Runs the model on the given data.""" 14 | epoch_size = ((len(words) // m.batch_size) - 1) // m.num_steps 15 | start_time = time.time() 16 | comb_loss = 0.0 17 | pos_total_loss = 0.0 18 | chunk_total_loss = 0.0 19 | iters = 0 20 | accuracy = 0.0 21 | pos_predictions = [] 22 | pos_true = [] 23 | chunk_predictions = [] 24 | chunk_true = [] 25 | 26 | for step, (x, y_pos, y_chunk) in enumerate(reader.create_batches(words, pos, chunk, m.batch_size, 27 | m.num_steps, pos_vocab_size, chunk_vocab_size)): 28 | 29 | if model_type == 'POS': 30 | if valid: 31 | eval_op = tf.no_op() 32 | else: 33 | eval_op = m.pos_op 34 | elif model_type == 'CHUNK': 35 | if valid: 36 | eval_op = tf.no_op() 37 | else: 38 | eval_op = m.chunk_op 39 | else: 40 | if valid: 41 |
eval_op = tf.no_op() 42 | else: 43 | eval_op = m.joint_op 44 | 45 | joint_loss, _, pos_int_pred, chunk_int_pred, pos_int_true, \ 46 | chunk_int_true, pos_loss, chunk_loss = \ 47 | session.run([m.joint_loss, eval_op, m.pos_int_pred, 48 | m.chunk_int_pred, m.pos_int_targ, m.chunk_int_targ, 49 | m.pos_loss, m.chunk_loss], 50 | {m.input_data: x, 51 | m.pos_targets: y_pos, 52 | m.chunk_targets: y_chunk}) 53 | comb_loss += joint_loss 54 | chunk_total_loss += chunk_loss 55 | pos_total_loss += pos_loss 56 | iters += 1 57 | if verbose and step % 5 == 0: 58 | if model_type == 'POS': 59 | costs = pos_total_loss 60 | cost = pos_loss 61 | elif model_type == 'CHUNK': 62 | costs = chunk_total_loss 63 | cost = chunk_loss 64 | else: 65 | costs = comb_loss 66 | cost = joint_loss 67 | print("Type: %s,cost: %3f, total cost: %3f" % (model_type, cost, costs)) 68 | 69 | pos_int_pred = np.reshape(pos_int_pred, [m.batch_size, m.num_steps]) 70 | pos_predictions.append(pos_int_pred) 71 | pos_true.append(pos_int_true) 72 | 73 | chunk_int_pred = np.reshape(chunk_int_pred, [m.batch_size, m.num_steps]) 74 | chunk_predictions.append(chunk_int_pred) 75 | chunk_true.append(chunk_int_true) 76 | 77 | return (comb_loss / iters), pos_predictions, chunk_predictions, pos_true, \ 78 | chunk_true, (pos_total_loss / iters), (chunk_total_loss / iters) 79 | -------------------------------------------------------------------------------- /run_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import tensorflow as tf 5 | import tensorflow.python.platform 6 | 7 | import model_reader as reader 8 | import numpy as np 9 | import pdb 10 | from graph import Shared_Model 11 | from run_epoch import run_epoch 12 | import argparse 13 | import saveload 14 | 15 | 16 | class Config(object): 17 | """Configuration for the network""" 18 | init_scale = 0.1 # initialisation scale 19 | learning_rate = 0.001 # learning_rate (if you are using SGD) 20 | max_grad_norm = 5 # for gradient clipping 21 | num_steps = 20 # length of sequence 22 | word_embedding_size = 400 # size of the embedding 23 | encoder_size = 200 # first layer 24 | pos_decoder_size = 200 # second layer 25 | chunk_decoder_size = 200 # second layer 26 | max_epoch = 1 # maximum number of epochs 27 | keep_prob = 0.5 # for dropout 28 | batch_size = 64 # number of sequence 29 | vocab_size = 20000 # this isn't used - need to look at this 30 | num_pos_tags = 45 # hard coded, should it be? 
31 | num_chunk_tags = 23 # as above 32 | pos_embedding_size = 400 33 | num_shared_layers = 1 34 | argmax = 0 35 | 36 | 37 | def main(model_type, dataset_path, save_path): 38 | """Main""" 39 | config = Config() 40 | raw_data = reader.raw_x_y_data( 41 | dataset_path, config.num_steps) 42 | words_t, pos_t, chunk_t, words_v, \ 43 | pos_v, chunk_v, word_to_id, pos_to_id, \ 44 | chunk_to_id, words_test, pos_test, chunk_test, \ 45 | words_c, pos_c, chunk_c = raw_data 46 | 47 | config.num_pos_tags = len(pos_to_id) 48 | config.num_chunk_tags = len(chunk_to_id) 49 | 50 | with tf.Graph().as_default(), tf.Session() as session: 51 | initializer = tf.random_uniform_initializer(-config.init_scale, 52 | config.init_scale) 53 | 54 | # model to train hyperparameters on 55 | with tf.variable_scope("hyp_model", reuse=None, initializer=initializer): 56 | m = Shared_Model(is_training=True, config=config) 57 | with tf.variable_scope("hyp_model", reuse=True, initializer=initializer): 58 | mvalid = Shared_Model(is_training=False, config=config) 59 | 60 | # model that trains, given hyper-parameters 61 | with tf.variable_scope("final_model", reuse=None, initializer=initializer): 62 | mTrain = Shared_Model(is_training=True, config=config) 63 | with tf.variable_scope("final_model", reuse=True, initializer=initializer): 64 | mTest = Shared_Model(is_training=False, config=config) 65 | 66 | tf.initialize_all_variables().run() 67 | 68 | # Create an empty array to hold [epoch number, loss] 69 | best_epoch = [0, 100000] 70 | 71 | print('finding best epoch parameter') 72 | # ==================================== 73 | # Create vectors for training results 74 | # ==================================== 75 | 76 | # Create empty vectors for loss 77 | train_loss_stats = np.array([]) 78 | train_pos_loss_stats = np.array([]) 79 | train_chunk_loss_stats = np.array([]) 80 | # Create empty vectors for accuracy 81 | train_pos_stats = np.array([]) 82 | train_chunk_stats = np.array([]) 83 | 84 | # ==================================== 85 | # Create vectors for validation results 86 | # ==================================== 87 | # Create empty vectors for loss 88 | valid_loss_stats = np.array([]) 89 | valid_pos_loss_stats = np.array([]) 90 | valid_chunk_loss_stats = np.array([]) 91 | # Create empty vectors for accuracy 92 | valid_pos_stats = np.array([]) 93 | valid_chunk_stats = np.array([]) 94 | 95 | for i in range(config.max_epoch): 96 | print("Epoch: %d" % (i + 1)) 97 | mean_loss, posp_t, chunkp_t, post_t, chunkt_t, pos_loss, chunk_loss = \ 98 | run_epoch(session, m, 99 | words_t, pos_t, chunk_t, 100 | config.num_pos_tags, config.num_chunk_tags, 101 | verbose=True, model_type=model_type) 102 | 103 | # Save stats for charts 104 | train_loss_stats = np.append(train_loss_stats, mean_loss) 105 | train_pos_loss_stats = np.append(train_pos_loss_stats, pos_loss) 106 | train_chunk_loss_stats = np.append(train_chunk_loss_stats, chunk_loss) 107 | 108 | # get predictions as list 109 | posp_t = reader.res_to_list(posp_t, config.batch_size, config.num_steps, 110 | pos_to_id, len(words_t)) 111 | chunkp_t = reader.res_to_list(chunkp_t, config.batch_size, 112 | config.num_steps, chunk_to_id, len(words_t)) 113 | post_t = reader.res_to_list(post_t, config.batch_size, config.num_steps, 114 | pos_to_id, len(words_t)) 115 | chunkt_t = reader.res_to_list(chunkt_t, config.batch_size, 116 | config.num_steps, chunk_to_id, len(words_t)) 117 | 118 | # find the accuracy 119 | pos_acc = np.sum(posp_t == post_t) / float(len(posp_t)) 120 | chunk_acc = np.sum(chunkp_t == 
chunkt_t) / float(len(chunkp_t)) 121 | 122 | # add to array 123 | train_pos_stats = np.append(train_pos_stats, pos_acc) 124 | train_chunk_stats = np.append(train_chunk_stats, chunk_acc) 125 | 126 | # print for tracking 127 | print("Pos Training Accuracy After Epoch %d : %3f" % (i + 1, pos_acc)) 128 | print("Chunk Training Accuracy After Epoch %d : %3f" % (i + 1, chunk_acc)) 129 | 130 | valid_loss, posp_v, chunkp_v, post_v, chunkt_v, pos_v_loss, chunk_v_loss = \ 131 | run_epoch(session, mvalid, words_v, pos_v, chunk_v, 132 | config.num_pos_tags, config.num_chunk_tags, 133 | verbose=True, valid=True, model_type=model_type) 134 | 135 | # Save loss for charts 136 | valid_loss_stats = np.append(valid_loss_stats, valid_loss) 137 | valid_pos_loss_stats = np.append(valid_pos_loss_stats, pos_v_loss) 138 | valid_chunk_loss_stats = np.append(valid_chunk_loss_stats, chunk_v_loss) 139 | 140 | # get predictions as list 141 | 142 | posp_v = reader.res_to_list(posp_v, config.batch_size, config.num_steps, 143 | pos_to_id, len(words_v)) 144 | chunkp_v = reader.res_to_list(chunkp_v, config.batch_size, 145 | config.num_steps, chunk_to_id, len(words_v)) 146 | chunkt_v = reader.res_to_list(chunkt_v, config.batch_size, 147 | config.num_steps, chunk_to_id, len(words_v)) 148 | post_v = reader.res_to_list(post_v, config.batch_size, config.num_steps, 149 | pos_to_id, len(words_v)) 150 | 151 | # find accuracy 152 | pos_acc = np.sum(posp_v == post_v) / float(len(posp_v)) 153 | chunk_acc = np.sum(chunkp_v == chunkt_v) / float(len(chunkp_v)) 154 | 155 | print("Pos Validation Accuracy After Epoch %d : %3f" % (i + 1, pos_acc)) 156 | print("Chunk Validation Accuracy After Epoch %d : %3f" % (i + 1, chunk_acc)) 157 | 158 | # add to stats 159 | valid_pos_stats = np.append(valid_pos_stats, pos_acc) 160 | valid_chunk_stats = np.append(valid_chunk_stats, chunk_acc) 161 | 162 | # update best parameters 163 | if (valid_loss < best_epoch[1]): 164 | best_epoch = [i + 1, valid_loss] 165 | 166 | # Save loss & accuracy plots 167 | np.savetxt(save_path + '/loss/valid_loss_stats.txt', valid_loss_stats) 168 | np.savetxt(save_path + '/loss/valid_pos_loss_stats.txt', valid_pos_loss_stats) 169 | np.savetxt(save_path + '/loss/valid_chunk_loss_stats.txt', valid_chunk_loss_stats) 170 | np.savetxt(save_path + '/accuracy/valid_pos_stats.txt', valid_pos_stats) 171 | np.savetxt(save_path + '/accuracy/valid_chunk_stats.txt', valid_chunk_stats) 172 | 173 | np.savetxt(save_path + '/loss/train_loss_stats.txt', train_loss_stats) 174 | np.savetxt(save_path + '/loss/train_pos_loss_stats.txt', train_pos_loss_stats) 175 | np.savetxt(save_path + '/loss/train_chunk_loss_stats.txt', train_chunk_loss_stats) 176 | np.savetxt(save_path + '/accuracy/train_pos_stats.txt', train_pos_stats) 177 | np.savetxt(save_path + '/accuracy/train_chunk_stats.txt', train_chunk_stats) 178 | 179 | # Train given epoch parameter 180 | print('Train Given Best Epoch Parameter :' + str(best_epoch[0])) 181 | for i in range(best_epoch[0]): 182 | print("Epoch: %d" % (i + 1)) 183 | _, posp_c, chunkp_c, _, _, _, _ = \ 184 | run_epoch(session, mTrain, 185 | words_c, pos_c, chunk_c, 186 | config.num_pos_tags, config.num_chunk_tags, 187 | verbose=True, model_type=model_type) 188 | 189 | print('Getting Testing Predictions') 190 | _, posp_test, chunkp_test, _, _, _, _ = \ 191 | run_epoch(session, mTest, 192 | words_test, pos_test, chunk_test, 193 | config.num_pos_tags, config.num_chunk_tags, 194 | verbose=True, valid=True, model_type=model_type) 195 | 196 | print('Writing Predictions') 197 | # 
prediction reshaping 198 | posp_c = reader.res_to_list(posp_c, config.batch_size, config.num_steps, 199 | pos_to_id, len(words_c)) 200 | posp_test = reader.res_to_list(posp_test, config.batch_size, config.num_steps, 201 | pos_to_id, len(words_test)) 202 | chunkp_c = reader.res_to_list(chunkp_c, config.batch_size, 203 | config.num_steps, chunk_to_id, len(words_c)) 204 | chunkp_test = reader.res_to_list(chunkp_test, config.batch_size, config.num_steps, 205 | chunk_to_id, len(words_test)) 206 | 207 | # save pickle - save_path + '/saved_variables.pkl' 208 | print('saving variables (pickling)') 209 | saveload.save(save_path + '/saved_variables.pkl', session) 210 | 211 | train_custom = reader.read_tokens(dataset_path + '/train.txt', 0) 212 | valid_custom = reader.read_tokens(dataset_path + '/validation.txt', 0) 213 | combined = reader.read_tokens(dataset_path + '/train_val_combined.txt', 0) 214 | test_data = reader.read_tokens(dataset_path + '/test.txt', 0) 215 | 216 | print('loaded text') 217 | 218 | chunk_pred_train = np.concatenate((np.transpose(train_custom), chunkp_t), axis=1) 219 | chunk_pred_val = np.concatenate((np.transpose(valid_custom), chunkp_v), axis=1) 220 | chunk_pred_c = np.concatenate((np.transpose(combined), chunkp_c), axis=1) 221 | chunk_pred_test = np.concatenate((np.transpose(test_data), chunkp_test), axis=1) 222 | pos_pred_train = np.concatenate((np.transpose(train_custom), posp_t), axis=1) 223 | pos_pred_val = np.concatenate((np.transpose(valid_custom), posp_v), axis=1) 224 | pos_pred_c = np.concatenate((np.transpose(combined), posp_c), axis=1) 225 | pos_pred_test = np.concatenate((np.transpose(test_data), posp_test), axis=1) 226 | 227 | print('finished concatenating, about to start saving') 228 | 229 | np.savetxt(save_path + '/predictions/chunk_pred_train.txt', 230 | chunk_pred_train, fmt='%s') 231 | print('writing to ' + save_path + '/predictions/chunk_pred_train.txt') 232 | np.savetxt(save_path + '/predictions/chunk_pred_val.txt', 233 | chunk_pred_val, fmt='%s') 234 | print('writing to ' + save_path + '/predictions/chunk_pred_val.txt') 235 | np.savetxt(save_path + '/predictions/chunk_pred_combined.txt', 236 | chunk_pred_c, fmt='%s') 237 | print('writing to ' + save_path + '/predictions/chunk_pred_combined.txt') 238 | np.savetxt(save_path + '/predictions/chunk_pred_test.txt', 239 | chunk_pred_test, fmt='%s') 240 | print('writing to ' + save_path + '/predictions/chunk_pred_test.txt') 241 | np.savetxt(save_path + '/predictions/pos_pred_train.txt', 242 | pos_pred_train, fmt='%s') 243 | print('writing to ' + save_path + '/predictions/pos_pred_train.txt') 244 | np.savetxt(save_path + '/predictions/pos_pred_val.txt', 245 | pos_pred_val, fmt='%s') 246 | print('writing to ' + save_path + '/predictions/pos_pred_val.txt') 247 | np.savetxt(save_path + '/predictions/pos_pred_combined.txt', 248 | pos_pred_c, fmt='%s') 249 | np.savetxt(save_path + '/predictions/pos_pred_test.txt', 250 | pos_pred_test, fmt='%s') 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser() 255 | parser.add_argument("--model_type") 256 | parser.add_argument("--dataset_path") 257 | parser.add_argument("--save_path") 258 | args = parser.parse_args() 259 | if (str(args.model_type) != "POS") and (str(args.model_type) != "CHUNK"): 260 | args.model_type = 'JOINT' 261 | print('Model Selected : ' + str(args.model_type)) 262 | main(str(args.model_type), str(args.dataset_path), str(args.save_path)) 263 | --------------------------------------------------------------------------------
/saveload.py: -------------------------------------------------------------------------------- 1 | """Utility methods for pickling and unpickling models""" 2 | import pickle 3 | from tensorflow import Session 4 | import os 5 | import tensorflow as tf 6 | 7 | def save(save_path, sess): 8 | with open(save_path, "wb") as file: 9 | variables = tf.trainable_variables() 10 | values = sess.run(variables) 11 | pickle.dump({var.name: val for var, val in zip(variables, values)}, file) 12 | 13 | 14 | def load_np(save_path): 15 | if not os.path.exists(save_path): 16 | raise Exception("No saved weights at that location") 17 | else: 18 | v_dict = pickle.load(open(save_path, "rb")) 19 | for key in v_dict.keys(): 20 | print("Key name: " + key) 21 | 22 | return v_dict 23 | --------------------------------------------------------------------------------
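Note: `saveload.save` pickles the trained variable values keyed by variable name, and `load_np` only returns (and prints) that dictionary; pushing the values back into a rebuilt graph is left to the caller. A minimal sketch of one way to do that with the TF 1.x API (the `restore_np` helper is an illustrative assumption, not part of the repo, and it presumes the graph has been rebuilt under the same variable scopes before the session is used):

```python
import tensorflow as tf

import saveload


def restore_np(save_path, sess):
    # Load the {variable name: numpy value} dict written by saveload.save and
    # assign each value to the matching trainable variable in the live graph.
    v_dict = saveload.load_np(save_path)
    assign_ops = [tf.assign(var, v_dict[var.name])
                  for var in tf.trainable_variables()
                  if var.name in v_dict]
    sess.run(assign_ops)
```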