├── .gitignore ├── LICENSE.txt ├── README.md ├── docs └── index.html ├── tagger.dy.py ├── tagger.pt.py ├── tagger.tf.py └── visualise.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Jonathan K Kummerfeld 2 | 3 | Permission to use, copy, modify, and/or distribute this software for any 4 | purpose with or without fee is hereby granted, provided that the above 5 | copyright notice and this permission notice appear in all copies. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH 8 | REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND 9 | FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, 10 | INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM 11 | LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR 12 | OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 13 | PERFORMANCE OF THIS SOFTWARE. 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Part-of-Speech tagger examples 2 | 3 | This repository contains the code for DyNet, PyTorch and Tensorflow versions of a reasonably good POS tagger. 4 | It also contains a program to convert that code (and comments) into a website for easy reading and comparison. 5 | 6 | For the website, go [here](http://jkk.name/neural-tagger-tutorial/). 
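All three taggers share the same input handling. Each line of a data file is one sentence of space-separated `word|TAG` pairs, digits are replaced with `0` to reduce sparsity, and strings are mapped to integer ids with `__PAD__` (and, for tokens, `__UNK__`) reserved at the start of each index. The sketch below restates that logic so it can be tried on a single line without the data files. The helper names `parse_line` and `build_index` are hypothetical and do not appear in the tagger scripts. Looking up unseen tokens as `__UNK__` is an assumption, based on the purpose the scripts give for that symbol.

```python
# Hypothetical stand-alone version of the preprocessing in tagger.*.py.
PAD = "__PAD__"
UNK = "__UNK__"


def parse_line(line):
    """Split 'word|TAG word|TAG ...' into parallel lists of tokens and tags."""
    pairs = [w.split("|") for w in line.strip().split()]
    return [p[0] for p in pairs], [p[1] for p in pairs]


def simplify_token(token):
    """Replace every digit with 0 to reduce sparsity."""
    return "".join("0" if ch.isdigit() else ch for ch in token)


def build_index(items, reserved):
    """Assign consecutive ids in first-seen order, after the reserved symbols."""
    item_to_id = {symbol: i for i, symbol in enumerate(reserved)}
    for item in items:
        if item not in item_to_id:
            item_to_id[item] = len(item_to_id)
    return item_to_id


tokens, tags = parse_line("Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ")
token_to_id = build_index([simplify_token(t) for t in tokens], [PAD, UNK])
tag_to_id = build_index(tags, [PAD])

# Unseen tokens fall back to the UNK id (an assumption, see above).
token_ids = [token_to_id.get(simplify_token(t), token_to_id[UNK]) for t in tokens]
tag_ids = [tag_to_id[t] for t in tags]
print(token_ids)  # [2, 3, 4, 5, 6, 7]
print(tag_ids)    # [1, 1, 2, 3, 4, 5]
```

Because digit normalisation maps both `61` and `42` to `00`, they share one id (and so one embedding), while a word absent from the index, such as `Smith`, is looked up as id 1.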
7 | 8 | To generate the site, install the Python library `pygments` and run: 9 | 10 | ``` 11 | ./visualise.py tagger.dy.py tagger.pt.py tagger.tf.py > docs/index.html 12 | ``` 13 | -------------------------------------------------------------------------------- /tagger.dy.py: -------------------------------------------------------------------------------- 1 | #### We use argparse for processing command line arguments, random for shuffling our data, sys for flushing output, and numpy for handling vectors of data. 2 | # DyNet Implementation 3 | import argparse 4 | import random 5 | import sys 6 | 7 | import numpy as np 8 | 9 | #### Typically, we would make many of these constants command line arguments and tune using the development set. For simplicity, I have fixed their values here to match Jiang, Liang and Zhang (CoLing 2018). 10 | PAD = "__PAD__" 11 | UNK = "__UNK__" 12 | DIM_EMBEDDING = 100 # DIM_EMBEDDING - number of dimensions in our word embeddings. 13 | LSTM_HIDDEN = 100 # LSTM_HIDDEN - number of dimensions in the hidden vectors for the LSTM. Based on NCRFpp (200 in the paper, but 100 per direction in code) 14 | BATCH_SIZE = 10 # BATCH_SIZE - number of examples considered in each model update. 15 | LEARNING_RATE = 0.015 # LEARNING_RATE - adjusts how rapidly model parameters change by rescaling the gradient vector. 16 | LEARNING_DECAY_RATE = 0.05 # LEARNING_DECAY_RATE - part of a rescaling of the learning rate after each pass through the data. 17 | EPOCHS = 100 # EPOCHS - number of passes through the data in training. 18 | KEEP_PROB = 0.5 # KEEP_PROB - probability of keeping a value when applying dropout. 19 | GLOVE = "../data/glove.6B.100d.txt" # GLOVE - location of glove vectors. 20 | WEIGHT_DECAY = 1e-8 # WEIGHT_DECAY - part of a rescaling of weights when an update occurs. 21 | 22 | #### DyNet library imports.
The first allows us to configure DyNet from within code rather than on the command line: mem is the amount of system memory initially allocated (DyNet has its own memory management), autobatch toggles automatic parallelisation of computations, weight_decay rescales weights by (1 - decay) after every update, random_seed sets the seed for random number generation. 23 | import dynet_config 24 | dynet_config.set(mem=256, autobatch=0, weight_decay=WEIGHT_DECAY,random_seed=0) 25 | # dynet_config.set_gpu() for when we want to run with GPUs 26 | import dynet as dy 27 | 28 | #### 29 | # Data reading 30 | def read_data(filename): 31 | #### We are expecting a minor variation on the raw Penn Treebank data, with one line per sentence, tokens separated by spaces, and the tag for each token placed next to its word (the | works as a separator as it does not appear as a token). 32 | """Example input: 33 | Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ 34 | """ 35 | content = [] 36 | with open(filename) as data_src: 37 | for line in data_src: 38 | t_p = [w.split("|") for w in line.strip().split()] 39 | tokens = [v[0] for v in t_p] 40 | tags = [v[1] for v in t_p] 41 | content.append((tokens, tags)) 42 | return content 43 | 44 | def simplify_token(token): 45 | chars = [] 46 | for char in token: 47 | #### Reduce sparsity by replacing all digits with 0. 48 | if char.isdigit(): 49 | chars.append("0") 50 | else: 51 | chars.append(char) 52 | return ''.join(chars) 53 | 54 | def main(): 55 | #### For the purpose of this example we only have arguments for locations of the data. 56 | parser = argparse.ArgumentParser(description='POS tagger.') 57 | parser.add_argument('training_data') 58 | parser.add_argument('dev_data') 59 | args = parser.parse_args() 60 | 61 | train = read_data(args.training_data) 62 | dev = read_data(args.dev_data) 63 | 64 | #### These indices map from strings to integers, which we apply to the input for our model. 
UNK is added to our mapping so that there is a vector we can use when we encounter unknown words. The special PAD symbol is used in PyTorch and Tensorflow as part of shaping the data in a batch to be a consistent size. It is not needed for DyNet, but kept for consistency. 65 | # Make indices 66 | id_to_token = [PAD, UNK] 67 | token_to_id = {PAD: 0, UNK: 1} 68 | id_to_tag = [PAD] 69 | tag_to_id = {PAD: 0} 70 | #### The '+ dev' may seem like an error, but it is done here for convenience. It means in the next section we will retain the GloVe embeddings that appear in dev but not train. They won't be updated during training, so it does not mean we are getting information we shouldn't. In practice, I would simply keep all the GloVe embeddings to avoid any potential incorrect use of the evaluation data. 71 | for tokens, tags in train + dev: 72 | for token in tokens: 73 | token = simplify_token(token) 74 | if token not in token_to_id: 75 | token_to_id[token] = len(token_to_id) 76 | id_to_token.append(token) 77 | for tag in tags: 78 | if tag not in tag_to_id: 79 | tag_to_id[tag] = len(tag_to_id) 80 | id_to_tag.append(tag) 81 | NWORDS = len(token_to_id) 82 | NTAGS = len(tag_to_id) 83 | 84 | # Load pre-trained GloVe vectors 85 | #### I am assuming these are 100-dimensional GloVe embeddings in their standard format. 86 | pretrained = {} 87 | for line in open(GLOVE): 88 | parts = line.strip().split() 89 | word = parts[0] 90 | vector = [float(v) for v in parts[1:]] 91 | pretrained[word] = vector 92 | #### We need the word vectors as a list to initialise the embeddings. Each entry in the list corresponds to the token with that index.
93 | pretrained_list = [] 94 | scale = np.sqrt(3.0 / DIM_EMBEDDING) 95 | for word in id_to_token: 96 | # apply lower() because all GloVe vectors are for lowercase words 97 | if word.lower() in pretrained: 98 | pretrained_list.append(np.array(pretrained[word.lower()])) 99 | else: 100 | #### For words that do not appear in GloVe we generate a random vector (note: the choice of scale here is important, and we follow Jiang, Liang and Zhang (CoLing 2018)). 101 | random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING]) 102 | pretrained_list.append(random_vector) 103 | 104 | #### The most significant difference between the frameworks is how the model parameters and their execution are defined. In DyNet we define parameters here and then define computation as needed. In PyTorch we use a class with the parameters defined in the constructor and the computation defined in the forward() method. In Tensorflow we define both parameters and computation here. 105 | # Model creation 106 | #### 107 | model = dy.ParameterCollection() 108 | # Create word embeddings and initialise 109 | #### Lookup parameters are a matrix that supports efficient sparse lookup. 110 | pEmbedding = model.add_lookup_parameters((NWORDS, DIM_EMBEDDING)) 111 | pEmbedding.init_from_array(np.array(pretrained_list)) 112 | # Create LSTM parameters 113 | #### Objects that create LSTM cells and the necessary parameters.
114 | stdv = 1.0 / np.sqrt(LSTM_HIDDEN) # Needed to match PyTorch 115 | f_lstm = dy.VanillaLSTMBuilder(1, DIM_EMBEDDING, LSTM_HIDDEN, model, 116 | forget_bias=(np.random.random_sample() - 0.5) * 2 * stdv) 117 | b_lstm = dy.VanillaLSTMBuilder(1, DIM_EMBEDDING, LSTM_HIDDEN, model, 118 | forget_bias=(np.random.random_sample() - 0.5) * 2 * stdv) 119 | # Create output layer 120 | pOutput = model.add_parameters((NTAGS, 2 * LSTM_HIDDEN)) 121 | 122 | # Set recurrent dropout values (not used in this case) 123 | f_lstm.set_dropouts(0.0, 0.0) 124 | b_lstm.set_dropouts(0.0, 0.0) 125 | # Initialise LSTM parameters 126 | #### To match PyTorch, we initialise the parameters with an unconventional approach. 127 | f_lstm.get_parameters()[0][0].set_value( 128 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, DIM_EMBEDDING])) 129 | f_lstm.get_parameters()[0][1].set_value( 130 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, LSTM_HIDDEN])) 131 | f_lstm.get_parameters()[0][2].set_value( 132 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN])) 133 | b_lstm.get_parameters()[0][0].set_value( 134 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, DIM_EMBEDDING])) 135 | b_lstm.get_parameters()[0][1].set_value( 136 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, LSTM_HIDDEN])) 137 | b_lstm.get_parameters()[0][2].set_value( 138 | np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN])) 139 | 140 | #### The trainer object is used to update the model. 141 | # Create the trainer 142 | trainer = dy.SimpleSGDTrainer(model, learning_rate=LEARNING_RATE) 143 | #### DyNet clips gradients by default, which we disable here (this can have a big impact on performance). 144 | trainer.set_clip_threshold(-1) 145 | 146 | #### To make the code match across the three versions, we group together some framework specific values needed when doing a pass over the data. 
147 | expressions = (pEmbedding, pOutput, f_lstm, b_lstm, trainer) 148 | #### Main training loop, in which we shuffle the data, set the learning rate, do one complete pass over the training data, then evaluate on the development data. 149 | for epoch in range(EPOCHS): 150 | random.shuffle(train) 151 | 152 | #### 153 | # Update learning rate 154 | trainer.learning_rate = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch) 155 | 156 | #### Training pass. 157 | loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions, True) 158 | #### Dev pass. 159 | _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions, False) 160 | print("{} loss {} t-acc {} d-acc {}".format(epoch, loss, tacc, dacc)) 161 | 162 | #### The syntax varies, but in all three cases either saving or loading the parameters of a model must be done after the model is defined. 163 | # Save model 164 | model.save("tagger.dy.model") 165 | 166 | # Load model 167 | model.populate("tagger.dy.model") 168 | 169 | # Evaluation pass. 170 | _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions, False) 171 | print("Test Accuracy: {:.3f}".format(test_acc)) 172 | 173 | #### Inference (the same function for train and test). 174 | def do_pass(data, token_to_id, tag_to_id, expressions, train): 175 | pEmbedding, pOutput, f_lstm, b_lstm, trainer = expressions 176 | 177 | 178 | # Loop over batches 179 | loss = 0 180 | match = 0 181 | total = 0 182 | for start in range(0, len(data), BATCH_SIZE): 183 | #### Form the batch and order it based on length (important for efficient processing in PyTorch). 184 | batch = data[start : start + BATCH_SIZE] 185 | batch.sort(key = lambda x: -len(x[0])) 186 | #### Log partial results so we can conveniently check progress. 187 | if start % 4000 == 0 and start > 0: 188 | print(loss, match / total) 189 | sys.stdout.flush() 190 | 191 | #### Start a new computation graph for this batch. 
192 | # Process batch 193 | dy.renew_cg() 194 | #### For each example, we will construct an expression that gives the loss. 195 | loss_expressions = [] 196 | predicted = [] 197 | #### Convert tokens and tags from strings to numbers using the indices. 198 | for n, (tokens, tags) in enumerate(batch): 199 | token_ids = [token_to_id.get(simplify_token(t), token_to_id[UNK]) for t in tokens] 200 | tag_ids = [tag_to_id[t] for t in tags] 201 | 202 | #### Now we define the computation to be performed with the model. Note that it is not applied yet; we are simply building the computation graph. 203 | # Look up word embeddings 204 | wembs = [dy.lookup(pEmbedding, w) for w in token_ids] 205 | # Apply dropout 206 | if train: 207 | wembs = [dy.dropout(w, 1.0 - KEEP_PROB) for w in wembs] 208 | # Feed words into the LSTM 209 | #### Create an expression for two LSTMs and feed in the embeddings (reversed in one case). 210 | #### We pull out the output vector from the cell state at each step. 211 | f_init = f_lstm.initial_state() 212 | f_lstm_output = [x.output() for x in f_init.add_inputs(wembs)] 213 | rev_embs = reversed(wembs) 214 | b_init = b_lstm.initial_state() 215 | b_lstm_output = [x.output() for x in b_init.add_inputs(rev_embs)] 216 | 217 | # For each output, calculate the output and loss 218 | pred_tags = [] 219 | for f, b, t in zip(f_lstm_output, reversed(b_lstm_output), tag_ids): 220 | # Combine the outputs 221 | combined = dy.concatenate([f,b]) 222 | # Apply dropout 223 | if train: 224 | combined = dy.dropout(combined, 1.0 - KEEP_PROB) 225 | # Matrix multiply to get scores for each tag 226 | r_t = pOutput * combined 227 | # Calculate cross-entropy loss 228 | if train: 229 | err = dy.pickneglogsoftmax(r_t, t) 230 | #### We are not actually evaluating the loss values here; instead, we collect them together in a list. This enables DyNet's autobatching.
231 | loss_expressions.append(err) 232 | # Calculate the highest scoring tag 233 | #### This call to .npvalue() will lead to evaluation of the graph and so we don't actually get the benefits of autobatching. With some refactoring we could get the benefit back (simply keep the r_t expressions around and do this after the update), but that would have complicated this code. 234 | chosen = np.argmax(r_t.npvalue()) 235 | pred_tags.append(chosen) 236 | predicted.append(pred_tags) 237 | 238 | # combine the losses for the batch, do an update, and record the loss 239 | if train: 240 | loss_for_batch = dy.esum(loss_expressions) 241 | loss_for_batch.backward() 242 | trainer.update() 243 | loss += loss_for_batch.scalar_value() 244 | 245 | #### 246 | # Update the number of correct tags and total tags 247 | for (_, g), a in zip(batch, predicted): 248 | total += len(g) 249 | for gt, at in zip(g, a): 250 | gt = tag_to_id[gt] 251 | if gt == at: 252 | match += 1 253 | 254 | return loss, match / total 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /tagger.pt.py: -------------------------------------------------------------------------------- 1 | #### We use argparse for processing command line arguments, random for shuffling our data, sys for flushing output, and numpy for handling vectors of data. 2 | # PyTorch Implementation 3 | import argparse 4 | import random 5 | import sys 6 | 7 | import numpy as np 8 | 9 | #### Typically, we would make many of these constants command line arguments and tune using the development set. For simplicity, I have fixed their values here to match Jiang, Liang and Zhang (CoLing 2018). 10 | PAD = "__PAD__" 11 | UNK = "__UNK__" 12 | DIM_EMBEDDING = 100 # DIM_EMBEDDING - number of dimensions in our word embeddings. 13 | LSTM_HIDDEN = 100 # LSTM_HIDDEN - number of dimensions in the hidden vectors for the LSTM. 
Based on NCRFpp (200 in the paper, but 100 per direction in code) 14 | BATCH_SIZE = 10 # BATCH_SIZE - number of examples considered in each model update. 15 | LEARNING_RATE = 0.015 # LEARNING_RATE - adjusts how rapidly model parameters change by rescaling the gradient vector. 16 | LEARNING_DECAY_RATE = 0.05 # LEARNING_DECAY_RATE - part of a rescaling of the learning rate after each pass through the data. 17 | EPOCHS = 100 # EPOCHS - number of passes through the data in training. 18 | KEEP_PROB = 0.5 # KEEP_PROB - probability of keeping a value when applying dropout. 19 | GLOVE = "../data/glove.6B.100d.txt" # GLOVE - location of glove vectors. 20 | WEIGHT_DECAY = 1e-8 # WEIGHT_DECAY - part of a rescaling of weights when an update occurs. 21 | 22 | #### PyTorch library import. 23 | import torch 24 | torch.manual_seed(0) 25 | 26 | #### 27 | # Data reading 28 | def read_data(filename): 29 | #### We are expecting a minor variation on the raw Penn Treebank data, with one line per sentence, tokens separated by spaces, and the tag for each token placed next to its word (the | works as a separator as it does not appear as a token). 30 | """Example input: 31 | Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ 32 | """ 33 | content = [] 34 | with open(filename) as data_src: 35 | for line in data_src: 36 | t_p = [w.split("|") for w in line.strip().split()] 37 | tokens = [v[0] for v in t_p] 38 | tags = [v[1] for v in t_p] 39 | content.append((tokens, tags)) 40 | return content 41 | 42 | def simplify_token(token): 43 | chars = [] 44 | for char in token: 45 | #### Reduce sparsity by replacing all digits with 0. 46 | if char.isdigit(): 47 | chars.append("0") 48 | else: 49 | chars.append(char) 50 | return ''.join(chars) 51 | 52 | def main(): 53 | #### For the purpose of this example we only have arguments for locations of the data. 
54 | parser = argparse.ArgumentParser(description='POS tagger.') 55 | parser.add_argument('training_data') 56 | parser.add_argument('dev_data') 57 | args = parser.parse_args() 58 | 59 | train = read_data(args.training_data) 60 | dev = read_data(args.dev_data) 61 | 62 | #### These indices map from strings to integers, which we apply to the input for our model. UNK is added to our mapping so that there is a vector we can use when we encounter unknown words. The special PAD symbol is used in PyTorch and Tensorflow as part of shaping the data in a batch to be a consistent size. It is not needed for DyNet, but kept for consistency. 63 | # Make indices 64 | id_to_token = [PAD, UNK] 65 | token_to_id = {PAD: 0, UNK: 1} 66 | id_to_tag = [PAD] 67 | tag_to_id = {PAD: 0} 68 | #### The '+ dev' may seem like an error, but it is done here for convenience. It means in the next section we will retain the GloVe embeddings that appear in dev but not train. They won't be updated during training, so it does not mean we are getting information we shouldn't. In practice, I would simply keep all the GloVe embeddings to avoid any potential incorrect use of the evaluation data. 69 | for tokens, tags in train + dev: 70 | for token in tokens: 71 | token = simplify_token(token) 72 | if token not in token_to_id: 73 | token_to_id[token] = len(token_to_id) 74 | id_to_token.append(token) 75 | for tag in tags: 76 | if tag not in tag_to_id: 77 | tag_to_id[tag] = len(tag_to_id) 78 | id_to_tag.append(tag) 79 | NWORDS = len(token_to_id) 80 | NTAGS = len(tag_to_id) 81 | 82 | # Load pre-trained GloVe vectors 83 | #### I am assuming these are 100-dimensional GloVe embeddings in their standard format. 84 | pretrained = {} 85 | for line in open(GLOVE): 86 | parts = line.strip().split() 87 | word = parts[0] 88 | vector = [float(v) for v in parts[1:]] 89 | pretrained[word] = vector 90 | #### We need the word vectors as a list to initialise the embeddings.
Each entry in the list corresponds to the token with that index. 91 | pretrained_list = [] 92 | scale = np.sqrt(3.0 / DIM_EMBEDDING) 93 | for word in id_to_token: 94 | # apply lower() because all GloVe vectors are for lowercase words 95 | if word.lower() in pretrained: 96 | pretrained_list.append(np.array(pretrained[word.lower()])) 97 | else: 98 | #### For words that do not appear in GloVe we generate a random vector (note: the choice of scale here is important, and we follow Jiang, Liang and Zhang (CoLing 2018)). 99 | random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING]) 100 | pretrained_list.append(random_vector) 101 | 102 | #### The most significant difference between the frameworks is how the model parameters and their execution are defined. In DyNet we define parameters here and then define computation as needed. In PyTorch we use a class with the parameters defined in the constructor and the computation defined in the forward() method. In Tensorflow we define both parameters and computation here. 103 | # Model creation 104 | #### 105 | model = TaggerModel(NWORDS, NTAGS, pretrained_list, id_to_token) 106 | # Create optimizer and configure the learning rate 107 | optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, 108 | weight_decay=WEIGHT_DECAY) 109 | #### The learning rate for each epoch is set by multiplying the initial rate by the factor produced by this function. 110 | rescale_lr = lambda epoch: 1 / (1 + LEARNING_DECAY_RATE * epoch) 111 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 112 | lr_lambda=rescale_lr) 113 | 114 | #### To make the code match across the three versions, we group together some framework specific values needed when doing a pass over the data. 115 | expressions = (model, optimizer) 116 | #### Main training loop, in which we shuffle the data, set the learning rate, do one complete pass over the training data, then evaluate on the development data.
117 | for epoch in range(EPOCHS): 118 | random.shuffle(train) 119 | 120 | #### 121 | # Update learning rate 122 | #### In PyTorch versions before 1.1, the first call to scheduler.step() applies rescale_lr(0), which is why it is made before the pass over the data. From PyTorch 1.1 onwards the scheduler applies rescale_lr(0) when it is created, so with newer versions this call belongs after the training pass; otherwise the first epoch uses rescale_lr(1). 123 | scheduler.step() 124 | 125 | #### Training mode (and evaluation mode below) switches components such as dropout on (and off). 126 | model.train() 127 | model.zero_grad() 128 | #### Training pass. 129 | loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions, 130 | True) 131 | 132 | #### 133 | model.eval() 134 | #### Dev pass. 135 | _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions, False) 136 | print("{} loss {} t-acc {} d-acc {}".format(epoch, loss, 137 | tacc, dacc)) 138 | 139 | #### The syntax varies, but in all three cases either saving or loading the parameters of a model must be done after the model is defined. 140 | # Save model 141 | torch.save(model.state_dict(), "tagger.pt.model") 142 | 143 | # Load model 144 | model.load_state_dict(torch.load('tagger.pt.model')) 145 | 146 | # Evaluation pass. 147 | _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions, False) 148 | print("Test Accuracy: {:.3f}".format(test_acc)) 149 | 150 | #### Neural network definition code. In PyTorch, networks are defined using classes that extend Module. 151 | class TaggerModel(torch.nn.Module): 152 | #### In the constructor we define objects that will do each of the computations.
153 | def __init__(self, nwords, ntags, pretrained_list, id_to_token): 154 | super().__init__() 155 | 156 | # Create word embeddings 157 | pretrained_tensor = torch.FloatTensor(pretrained_list) 158 | self.word_embedding = torch.nn.Embedding.from_pretrained( 159 | pretrained_tensor, freeze=False) 160 | # Create input dropout parameter 161 | self.word_dropout = torch.nn.Dropout(1 - KEEP_PROB) 162 | # Create LSTM parameters 163 | self.lstm = torch.nn.LSTM(DIM_EMBEDDING, LSTM_HIDDEN, num_layers=1, 164 | batch_first=True, bidirectional=True) 165 | # Create output dropout parameter 166 | self.lstm_output_dropout = torch.nn.Dropout(1 - KEEP_PROB) 167 | # Create final matrix multiply parameters 168 | self.hidden_to_tag = torch.nn.Linear(LSTM_HIDDEN * 2, ntags) 169 | 170 | def forward(self, sentences, labels, lengths, cur_batch_size): 171 | max_length = sentences.size(1) 172 | 173 | # Look up word vectors 174 | word_vectors = self.word_embedding(sentences) 175 | # Apply dropout 176 | dropped_word_vectors = self.word_dropout(word_vectors) 177 | # Run the LSTM over the input, reshaping data for efficiency 178 | #### Assuming the data is ordered longest to shortest, this provides a view of the data that fits with how cuDNN works. 179 | packed_words = torch.nn.utils.rnn.pack_padded_sequence( 180 | dropped_word_vectors, lengths, True) 181 | #### The None argument is an optional initial hidden state (default is a zero vector). The ignored return value contains the hidden states. 182 | lstm_out, _ = self.lstm(packed_words, None) 183 | #### Reverse the view shift made for cuDNN. Specifying total_length is not necessary in general (it can be inferred), but is necessary for parallel processing. The ignored return value contains the length of each sequence. 
184 | lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, 185 | batch_first=True, total_length=max_length) 186 | # Apply dropout 187 | lstm_out_dropped = self.lstm_output_dropout(lstm_out) 188 | # Matrix multiply to get scores for each tag 189 | output_scores = self.hidden_to_tag(lstm_out_dropped) 190 | 191 | # Calculate loss and predictions 192 | #### We reshape to [batch size * sequence length, ntags] for more efficient processing. 193 | output_scores = output_scores.view(cur_batch_size * max_length, -1) 194 | flat_labels = labels.view(cur_batch_size * max_length) 195 | #### The ignore index marks a target value that should not be scored, which we use to skip padding. 'reduction' defines how to combine the losses at each point in the sequence. The default ('mean', called 'elementwise_mean' in older releases) averages them, which would not do what we want: like the DyNet version, we sum them. 196 | loss_function = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='sum') 197 | loss = loss_function(output_scores, flat_labels) 198 | predicted_tags = torch.argmax(output_scores, 1) 199 | #### Reshape to have dimensions [batch size, sequence length]. 200 | predicted_tags = predicted_tags.view(cur_batch_size, max_length) 201 | return loss, predicted_tags 202 | 203 | #### Inference (the same function for train and test). 204 | def do_pass(data, token_to_id, tag_to_id, expressions, train): 205 | model, optimizer = expressions 206 | 207 | 208 | # Loop over batches 209 | loss = 0 210 | match = 0 211 | total = 0 212 | for start in range(0, len(data), BATCH_SIZE): 213 | #### Form the batch and order it based on length (important for efficient processing in PyTorch). 214 | batch = data[start : start + BATCH_SIZE] 215 | batch.sort(key = lambda x: -len(x[0])) 216 | #### Log partial results so we can conveniently check progress. 217 | if start % 4000 == 0 and start > 0: 218 | print(loss, match / total) 219 | sys.stdout.flush() 220 | 221 | #### 222 | # Prepare inputs 223 | #### Prepare input arrays, using .long() to cast the type from Tensor to LongTensor.
224 | cur_batch_size = len(batch) 225 | max_length = len(batch[0][0]) 226 | lengths = [len(v[0]) for v in batch] 227 | input_array = torch.zeros((cur_batch_size, max_length)).long() 228 | output_array = torch.zeros((cur_batch_size, max_length)).long() 229 | #### Convert tokens and tags from strings to numbers using the indices. 230 | for n, (tokens, tags) in enumerate(batch): 231 | token_ids = [token_to_id.get(simplify_token(t), token_to_id[UNK]) for t in tokens] 232 | tag_ids = [tag_to_id[t] for t in tags] 233 | 234 | #### Fill the arrays, leaving the remaining values as zero (our padding value). 235 | input_array[n, :len(tokens)] = torch.LongTensor(token_ids) 236 | output_array[n, :len(tags)] = torch.LongTensor(tag_ids) 237 | 238 | # Construct computation 239 | #### Calling the model as a function will run its forward() method, which constructs the computations. 240 | batch_loss, output = model(input_array, output_array, lengths, 241 | cur_batch_size) 242 | 243 | # Run computations 244 | if train: 245 | batch_loss.backward() 246 | optimizer.step() 247 | model.zero_grad() 248 | #### To get the loss value we use .item(). 249 | loss += batch_loss.item() 250 | #### Our output is an array (rather than a single value), so we use a different approach to get it into a usable form. 251 | predicted = output.cpu().data.numpy() 252 | 253 | #### 254 | # Update the number of correct tags and total tags 255 | for (_, g), a in zip(batch, predicted): 256 | total += len(g) 257 | for gt, at in zip(g, a): 258 | gt = tag_to_id[gt] 259 | if gt == at: 260 | match += 1 261 | 262 | return loss, match / total 263 | 264 | if __name__ == '__main__': 265 | main() 266 | -------------------------------------------------------------------------------- /tagger.tf.py: -------------------------------------------------------------------------------- 1 | #### We use argparse for processing command line arguments, random for shuffling our data, sys for flushing output, and numpy for handling vectors of data.
2 | # Tensorflow Implementation 3 | import argparse 4 | import random 5 | import sys 6 | 7 | import numpy as np 8 | 9 | #### Typically, we would make many of these constants command line arguments and tune using the development set. For simplicity, I have fixed their values here to match Jiang, Liang and Zhang (CoLing 2018). 10 | PAD = "__PAD__" 11 | UNK = "__UNK__" 12 | DIM_EMBEDDING = 100 # DIM_EMBEDDING - number of dimensions in our word embeddings. 13 | LSTM_HIDDEN = 100 # LSTM_HIDDEN - number of dimensions in the hidden vectors for the LSTM. Based on NCRFpp (200 in the paper, but 100 per direction in code) 14 | BATCH_SIZE = 10 # BATCH_SIZE - number of examples considered in each model update. 15 | LEARNING_RATE = 0.015 # LEARNING_RATE - adjusts how rapidly model parameters change by rescaling the gradient vector. 16 | LEARNING_DECAY_RATE = 0.05 # LEARNING_DECAY_RATE - part of a rescaling of the learning rate after each pass through the data. 17 | EPOCHS = 100 # EPOCHS - number of passes through the data in training. 18 | KEEP_PROB = 0.5 # KEEP_PROB - probability of keeping a value when applying dropout. 19 | GLOVE = "../data/glove.6B.100d.txt" # GLOVE - location of glove vectors. 20 | # WEIGHT_DECAY = 1e-8 Not used, see note at the bottom of the page 21 | 22 | #### Tensorflow library import. 23 | import tensorflow as tf 24 | 25 | #### 26 | # Data reading 27 | def read_data(filename): 28 | #### We are expecting a minor variation on the raw Penn Treebank data, with one line per sentence, tokens separated by spaces, and the tag for each token placed next to its word (the | works as a separator as it does not appear as a token). 
29 | """Example input: 30 | Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ 31 | """ 32 | content = [] 33 | with open(filename) as data_src: 34 | for line in data_src: 35 | t_p = [w.split("|") for w in line.strip().split()] 36 | tokens = [v[0] for v in t_p] 37 | tags = [v[1] for v in t_p] 38 | content.append((tokens, tags)) 39 | return content 40 | 41 | def simplify_token(token): 42 | chars = [] 43 | for char in token: 44 | #### Reduce sparsity by replacing all digits with 0. 45 | if char.isdigit(): 46 | chars.append("0") 47 | else: 48 | chars.append(char) 49 | return ''.join(chars) 50 | 51 | def main(): 52 | #### For the purpose of this example we only have arguments for locations of the data. 53 | parser = argparse.ArgumentParser(description='POS tagger.') 54 | parser.add_argument('training_data') 55 | parser.add_argument('dev_data') 56 | args = parser.parse_args() 57 | 58 | train = read_data(args.training_data) 59 | dev = read_data(args.dev_data) 60 | 61 | #### These indices map from strings to integers, which we apply to the input for our model. UNK is added to our mapping so that there is a vector we can use when we encounter unknown words. The special PAD symbol is used in PyTorch and Tensorflow as part of shaping the data in a batch to be a consistent size. It is not needed for DyNet, but kept for consistency. 62 | # Make indices 63 | id_to_token = [PAD, UNK] 64 | token_to_id = {PAD: 0, UNK: 1} 65 | id_to_tag = [PAD] 66 | tag_to_id = {PAD: 0} 67 | #### The '+ dev' may seem like an error, but is done here for convenience. It means in the next section we will retain the GloVe embeddings that appear in dev but not train. They won't be updated during training, so it does not mean we are getting information we shouldn't. In practice I would simply keep all the GloVe embeddings to avoid any potential incorrect use of the evaluation data.
68 | for tokens, tags in train + dev: 69 | for token in tokens: 70 | token = simplify_token(token) 71 | if token not in token_to_id: 72 | token_to_id[token] = len(token_to_id) 73 | id_to_token.append(token) 74 | for tag in tags: 75 | if tag not in tag_to_id: 76 | tag_to_id[tag] = len(tag_to_id) 77 | id_to_tag.append(tag) 78 | NWORDS = len(token_to_id) 79 | NTAGS = len(tag_to_id) 80 | 81 | # Load pre-trained GloVe vectors 82 | #### I am assuming these are 100-dimensional GloVe embeddings in their standard format. 83 | pretrained = {} 84 | for line in open(GLOVE): 85 | parts = line.strip().split() 86 | word = parts[0] 87 | vector = [float(v) for v in parts[1:]] 88 | pretrained[word] = vector 89 | #### We need the word vectors as a list to initialise the embeddings. Each entry in the list corresponds to the token with that index. 90 | pretrained_list = [] 91 | scale = np.sqrt(3.0 / DIM_EMBEDDING) 92 | for word in id_to_token: 93 | # apply lower() because all GloVe vectors are for lowercase words 94 | if word.lower() in pretrained: 95 | pretrained_list.append(np.array(pretrained[word.lower()])) 96 | else: 97 | #### For words that do not appear in GloVe we generate a random vector (note, the choice of scale here is important; we follow Yang, Liang and Zhang (CoLing 2018)). 98 | random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING]) 99 | pretrained_list.append(random_vector) 100 | 101 | #### The most significant difference between the frameworks is how the model parameters and their execution are defined. In DyNet we define parameters here and then define computation as needed. In PyTorch we use a class with the parameters defined in the constructor and the computation defined in the forward() method. In Tensorflow we define both parameters and computation here. 102 | # Model creation 103 | #### 104 | #### This line creates a new graph and makes it the default graph for operations to be registered to.
It is not necessary here because we only have one graph, but it is considered good practice (more discussion on Stack Overflow). 105 | with tf.Graph().as_default(): 106 | #### Placeholders are inputs/values that will be fed into the network each time it is run. We define their type, name, and shape (scalar, 1D vector, 2D vector, etc.). This includes what we normally think of as inputs (e.g. the tokens) as well as parameters we want to change at run time (e.g. the learning rate). 107 | # Define inputs 108 | e_input = tf.placeholder(tf.int32, [None, None], name='input') 109 | e_lengths = tf.placeholder(tf.int32, [None], name='lengths') 110 | e_mask = tf.placeholder(tf.int32, [None, None], name='mask') 111 | e_gold_output = tf.placeholder(tf.int32, [None, None], 112 | name='gold_output') 113 | e_keep_prob = tf.placeholder(tf.float32, name='keep_prob') 114 | e_learning_rate = tf.placeholder(tf.float32, name='learning_rate') 115 | 116 | # Define word embedding 117 | #### The embedding matrix is a variable (so its values can change during training), initialized with the vectors defined above. 118 | glove_init = tf.constant_initializer(np.array(pretrained_list)) 119 | e_embedding = tf.get_variable("embedding", [NWORDS, DIM_EMBEDDING], 120 | initializer=glove_init) 121 | e_embed = tf.nn.embedding_lookup(e_embedding, e_input) 122 | 123 | # Define LSTM cells 124 | #### We create an LSTM cell, then wrap it in a class that applies dropout to the input and output. 125 | e_cell_f = tf.contrib.rnn.BasicLSTMCell(LSTM_HIDDEN) 126 | e_cell_f = tf.contrib.rnn.DropoutWrapper(e_cell_f, 127 | input_keep_prob=e_keep_prob, output_keep_prob=e_keep_prob) 128 | # Recurrent dropout options 129 | #### We are not using recurrent dropout, but it is a common enough feature of networks that it's good to see how it is done. 130 | # variational_recurrent=True, dtype=tf.float32, 131 | # input_size=DIM_EMBEDDING) 132 | #### Similarly, multi-layer networks are a common use case.
In Tensorflow, we would wrap a list of cells with a MultiRNNCell. 133 | # Multi-layer cell creation 134 | # e_cell_f = tf.contrib.rnn.MultiRNNCell([e_cell_f]) 135 | #### We are making a bidirectional network, so we need another cell for the reverse direction. 136 | e_cell_b = tf.contrib.rnn.BasicLSTMCell(LSTM_HIDDEN) 137 | e_cell_b = tf.contrib.rnn.DropoutWrapper(e_cell_b, 138 | input_keep_prob=e_keep_prob, output_keep_prob=e_keep_prob) 139 | #### To use the cells we create a dynamic RNN. The 'dynamic' aspect means we can feed in the length of each input sequence (not counting padding), and the RNN will not process the padding. 140 | e_initial_state_f = e_cell_f.zero_state(BATCH_SIZE, dtype=tf.float32) 141 | e_initial_state_b = e_cell_b.zero_state(BATCH_SIZE, dtype=tf.float32) 142 | e_lstm_outputs, e_final_state = tf.nn.bidirectional_dynamic_rnn( 143 | cell_fw=e_cell_f, cell_bw=e_cell_b, inputs=e_embed, 144 | initial_state_fw=e_initial_state_f, 145 | initial_state_bw=e_initial_state_b, 146 | sequence_length=e_lengths, dtype=tf.float32) 147 | e_lstm_outputs_merged = tf.concat(e_lstm_outputs, 2) 148 | 149 | # Define output layer 150 | #### Matrix multiply to get scores for each class. 151 | e_predictions = tf.contrib.layers.fully_connected(e_lstm_outputs_merged, 152 | NTAGS, activation_fn=None) 153 | # Define loss and update 154 | #### Cross-entropy loss. The reduction flag is crucial (the default divides the sum by the number of non-zero weights, i.e. it averages over the non-padding tokens in the batch). The weights flag accounts for padding that makes all of the sequences the same length. 155 | e_loss = tf.losses.sparse_softmax_cross_entropy(e_gold_output, 156 | e_predictions, weights=e_mask, 157 | reduction=tf.losses.Reduction.SUM) 158 | e_train = tf.train.GradientDescentOptimizer(e_learning_rate).minimize(e_loss) 159 | # Update with gradient clipping 160 | #### If we wanted to do gradient clipping we would need to do the update in a few steps, first calculating the gradient, then modifying it before applying it.
161 | # e_optimiser = tf.train.GradientDescentOptimizer(e_learning_rate) 162 | # e_gradients = e_optimiser.compute_gradients(e_loss) 163 | # e_clipped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) 164 | # for grad, var in e_gradients] 165 | # e_train = e_optimiser.apply_gradients(e_clipped_gradients) 166 | 167 | # Define output 168 | e_auto_output = tf.argmax(e_predictions, 2, output_type=tf.int32) 169 | 170 | # Do training 171 | #### Configure the system environment. By default Tensorflow uses all available GPUs and allocates nearly all of their memory. These lines limit the number of GPUs used (a count of 0, as here, means only the CPU is used) and the fraction of GPU memory the process may allocate. To limit which GPUs are used, set the environment variable CUDA_VISIBLE_DEVICES (e.g. "export CUDA_VISIBLE_DEVICES=0,1"). 172 | config = tf.ConfigProto( 173 | device_count = {'GPU': 0}, 174 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.8) 175 | ) 176 | #### A session runs the graph. We use a 'with' block to ensure it is closed, which frees various resources. 177 | with tf.Session(config=config) as sess: 178 | #### Run executes operations, in this case initializing the variables. 179 | sess.run(tf.global_variables_initializer()) 180 | 181 | #### To make the code match across the three versions, we group together some framework specific values needed when doing a pass over the data. 182 | expressions = [ 183 | e_auto_output, e_gold_output, e_input, e_keep_prob, e_lengths, 184 | e_loss, e_train, e_mask, e_learning_rate, sess 185 | ] 186 | #### Main training loop, in which we shuffle the data, set the learning rate, do one complete pass over the training data, then evaluate on the development data. 187 | for epoch in range(EPOCHS): 188 | random.shuffle(train) 189 | 190 | #### 191 | # Determine the current learning rate 192 | current_lr = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch) 193 | 194 | #### Training pass. 195 | loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions, 196 | True, current_lr) 197 | #### Dev pass.
198 | _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions, 199 | False) 200 | print("{} loss {} t-acc {} d-acc {}".format(epoch, loss, tacc, 201 | dacc)) 202 | 203 | #### The syntax varies, but in all three cases either saving or loading the parameters of a model must be done after the model is defined. 204 | # Save model 205 | saver = tf.train.Saver() 206 | saver.save(sess, "./tagger.tf.model") 207 | 208 | # Load model 209 | saver.restore(sess, "./tagger.tf.model") 210 | 211 | # Evaluation pass. 212 | _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions, 213 | False) 214 | print("Test Accuracy: {:.3f}".format(test_acc)) 215 | 216 | #### Inference (the same function for train and test). 217 | def do_pass(data, token_to_id, tag_to_id, expressions, train, lr=0.0): 218 | e_auto_output, e_gold_output, e_input, e_keep_prob, e_lengths, e_loss, \ 219 | e_train, e_mask, e_learning_rate, session = expressions 220 | 221 | # Loop over batches 222 | loss = 0 223 | match = 0 224 | total = 0 225 | for start in range(0, len(data), BATCH_SIZE): 226 | #### Form the batch and order it based on length (important for efficient processing in PyTorch). 227 | batch = data[start : start + BATCH_SIZE] 228 | batch.sort(key = lambda x: -len(x[0])) 229 | #### Log partial results so we can conveniently check progress. 230 | if start % 4000 == 0 and start > 0: 231 | print(loss, match / total) 232 | sys.stdout.flush() 233 | 234 | #### 235 | # Add empty sentences to fill the batch 236 | #### We add empty sentences because every batch must be the same size here (the initial LSTM states were created with a fixed BATCH_SIZE). 237 | batch += [([], []) for _ in range(BATCH_SIZE - len(batch))] 238 | # Prepare inputs 239 | #### We do this here for convenience and to have greater alignment between implementations, but in practice it would be best to do this once in pre-processing.
240 | max_length = len(batch[0][0]) 241 | input_array = np.zeros([len(batch), max_length]) 242 | output_array = np.zeros([len(batch), max_length]) 243 | lengths = np.array([len(v[0]) for v in batch]) 244 | mask = np.zeros([len(batch), max_length]) 245 | #### Convert tokens and tags from strings to numbers using the indices. 246 | for n, (tokens, tags) in enumerate(batch): 247 | token_ids = [token_to_id.get(simplify_token(t), 0) for t in tokens] 248 | tag_ids = [tag_to_id[t] for t in tags] 249 | 250 | #### Fill the arrays, leaving the remaining values as zero (our padding value). 251 | input_array[n, :len(tokens)] = token_ids 252 | output_array[n, :len(tags)] = tag_ids 253 | mask[n, :len(tokens)] = np.ones([len(tokens)]) 254 | #### We can't change the computation graph to disable dropout when not training, so we just change the keep probability. 255 | cur_keep_prob = KEEP_PROB if train else 1.0 256 | #### This dictionary contains values for all of the placeholders we defined. 257 | feed = { 258 | e_input: input_array, 259 | e_gold_output: output_array, 260 | e_mask: mask, 261 | e_keep_prob: cur_keep_prob, 262 | e_lengths: lengths, 263 | e_learning_rate: lr 264 | } 265 | 266 | # Define the computations needed 267 | todo = [e_auto_output] 268 | #### If we are not training we do not need to compute a loss and we do not want to do the update. 269 | if train: 270 | todo.append(e_loss) 271 | todo.append(e_train) 272 | # Run computations 273 | outcomes = session.run(todo, feed_dict=feed) 274 | # Get outputs 275 | predicted = outcomes[0] 276 | if train: 277 | #### We do not request the e_train value because its work is done - it performed the update during its computation. 
278 | loss += outcomes[1] 279 | 280 | #### 281 | # Update the number of correct tags and total tags 282 | for (_, g), a in zip(batch, predicted): 283 | total += len(g) 284 | for gt, at in zip(g, a): 285 | gt = tag_to_id[gt] 286 | if gt == at: 287 | match += 1 288 | 289 | return loss, match / total 290 | 291 | if __name__ == '__main__': 292 | main() 293 | -------------------------------------------------------------------------------- /visualise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | 5 | import pygments 6 | from pygments.lexers import get_lexer_by_name 7 | from pygments.formatters import HtmlFormatter 8 | 9 | lexer = get_lexer_by_name("python", stripall=True) 10 | formatter = HtmlFormatter(cssclass="source") 11 | 12 | def highlight(raw_code): 13 | code = pygments.highlight(raw_code, lexer, formatter) 14 | if len(raw_code) > 0: 15 | if raw_code[-1] == '\n': 16 | code = code.split("")[0] +"\n" 17 | if raw_code[0] == ' ': 18 | indent = 0 19 | for i, char in enumerate(raw_code): 20 | if char != ' ': 21 | break 22 | indent += 1 23 | parts = code.split("""") and code.endswith(""): 27 | code = code[20:-6] 28 | return code 29 | 30 | def print_comment_and_code(content, p0, p1, p2): 31 | comment0 = '' 32 | comment1 = '' 33 | comment2 = '' 34 | code0 = ' ' 35 | code1 = ' ' 36 | code2 = ' ' 37 | if p0 is not None: 38 | part = content[0][p0] 39 | comment0 = [v[0] for v in part if v[0] is not None and v[1] is None] 40 | code0 = '\n'.join([v[1] for v in part if v[1] is not None]) 41 | code0 = highlight(code0) 42 | if p1 is not None: 43 | part = content[1][p1] 44 | comment1 = [v[0] for v in part if v[0] is not None and v[1] is None] 45 | code1 = '\n'.join([v[1] for v in part if v[1] is not None]) 46 | code1 = highlight(code1) 47 | if p2 is not None: 48 | part = content[2][p2] 49 | comment2 = [v[0] for v in part if v[0] is not None and v[1] is None] 50 | code2 = '\n'.join([v[1] for v in 
part if v[1] is not None]) 51 | code2 = highlight(code2) 52 | 53 | class_name = 'shared-content' 54 | comment = comment0 55 | if p0 is not None and p1 is None and p2 is None: 56 | if len(''.join(comment).strip()) > 0: 57 | class_name = 'dynet' 58 | elif p0 is None and p1 is not None and p2 is None: 59 | comment = comment1 60 | if len(''.join(comment).strip()) > 0: 61 | class_name = 'pytorch' 62 | elif p0 is None and p1 is None and p2 is not None: 63 | comment = comment2 64 | if len(''.join(comment).strip()) > 0: 65 | class_name = 'tensorflow' 66 | 67 | print("""
""".format(class_name)) 68 | 69 | div_comment = """""".format(class_name) 70 | if len(''.join(comment).strip()) > 0: 71 | div_comment += "\n
\n".join(comment) +"

" 72 | else: 73 | div_comment += " " 74 | div_comment += "
" 75 | 76 | if len(''.join(comment).strip()) > 0: 77 | if code0 == " ": 78 | if code1 == " ": 79 | order = 'tfonly' 80 | else: 81 | order = 'ptonly' 82 | 83 | div_code0 = """""" + code0 +"" 84 | div_code1 = """""" + code1 +"" 85 | div_code2 = """""" + code2 + "" 86 | if class_name == 'pytorch': 87 | div_code0 = """""" + code0 +"" 88 | div_code1 = """""" + code1 +"" 89 | div_code2 = """""" + code2 + "" 90 | elif class_name == 'tensorflow': 91 | div_code0 = """""" + code0 +"" 92 | div_code1 = """""" + code1 +"" 93 | div_code2 = """""" + code2 + "" 94 | 95 | print(div_comment, end="") 96 | print(div_code0, end="") 97 | print(div_code1, end="") 98 | print(div_code2, end="") 99 | 100 | print("
") 101 | 102 | def read_file(filename): 103 | parts = [[]] 104 | prev_comment = True 105 | for line in open(filename): 106 | line = line.strip('\n') 107 | 108 | # Update comment status 109 | if line.strip().startswith("####"): 110 | if not prev_comment: 111 | parts.append([]) 112 | prev_comment = True 113 | else: 114 | prev_comment = False 115 | 116 | # Divide up the line 117 | comment = None 118 | code = None 119 | if line.strip().startswith("####"): 120 | comment = line.strip()[4:].strip() 121 | elif '#' in line and line.strip()[0] != '#': 122 | comment = line.split("#")[-1] 123 | code = line[:-len(comment)-1] 124 | comment = comment.strip() 125 | else: 126 | code = line 127 | 128 | parts[-1].append((comment, code)) 129 | return parts 130 | 131 | def match(part0, part1, do_comments=False): 132 | if do_comments: 133 | part0 = ' '.join([v[0].strip() for v in part0 if v[0] is not None and v[1] is None]) 134 | part1 = ' '.join([v[0].strip() for v in part1 if v[0] is not None and v[1] is None]) 135 | return part0 == part1 and part0.strip() != '' 136 | else: 137 | part0 = ' '.join([v[1].strip() for v in part0 if v[1] is not None]) 138 | part1 = ' '.join([v[1].strip() for v in part1 if v[1] is not None]) 139 | return part0 == part1 140 | 141 | def align(content): 142 | # Find parts in common between all three 143 | matches = set() 144 | for i0, part0 in enumerate(content[0]): 145 | for i1, part1 in enumerate(content[1]): 146 | if match(part0, part1): 147 | for i2, part2 in enumerate(content[2]): 148 | if match(part0, part2): 149 | matches.add((i0, i1, i2)) 150 | if match(part0, part1, True): 151 | for i2, part2 in enumerate(content[2]): 152 | if match(part0, part2, True): 153 | matches.add((i0, i1, i2)) 154 | matches = sorted(list(matches)) 155 | return matches 156 | 157 | def main(): 158 | # Read data 159 | content = [read_file(filename) for filename in sys.argv[1:]] 160 | 161 | # Work out aligned sections 162 | matches = align(content) 163 | 164 | # Render 165 | 
print(head) 166 | 167 | print("""
""") 168 | positions = [0 for _ in content] 169 | for p0, p1, p2 in matches: 170 | while positions[0] < p0: 171 | print_comment_and_code(content, positions[0], None, None) 172 | positions[0] += 1 173 | while positions[1] < p1: 174 | print_comment_and_code(content, None, positions[1], None) 175 | positions[1] += 1 176 | while positions[2] < p2: 177 | print_comment_and_code(content, None, None, positions[2]) 178 | positions[2] += 1 179 | print_comment_and_code(content, p0, p1, p2) 180 | positions[0] += 1 181 | positions[1] += 1 182 | positions[2] += 1 183 | 184 | print("""
""") 185 | 186 | print(tail) 187 | 188 | ###style_dark = """ 189 | ### 347 | ###""" 348 | 349 | style_light = """ 350 | 529 | """ 530 | 531 | head = """ 533 | 534 | 535 | 536 | 537 | 544 | Neural Tagger Implementations 545 | 546 | """+ style_light +""" 547 | 548 | 549 | 550 |
551 |

Implementing a neural Part-of-Speech tagger

552 |

by Jonathan K. Kummerfeld [site]

553 |
554 |

555 | DyNet, PyTorch and Tensorflow are complex frameworks with different ways of approaching neural network implementation and variations in default behaviour. 556 | This page is intended to show how to implement the same non-trivial model in all three. 557 | The design of the page is motivated by my own preference for a complete program with annotations, rather than the more common tutorial style of introducing code piecemeal in between discussion. 558 | The design of the code is also geared towards providing a complete picture of how things fit together. 559 | For a non-tutorial version of this code it would be better to use abstraction to improve flexibility, but that would have complicated the flow here. 560 |

561 |

562 | Model: 563 | The three implementations below all define a part-of-speech tagger with word embeddings initialised using GloVe, fed into a one-layer bidirectional LSTM, followed by a matrix multiplication to produce scores for tags. 564 | They all score ~97.2% on the development set of the Penn Treebank. 565 | The specific hyperparameter choices follow Yang, Liang, and Zhang (CoLing 2018), and the taggers match their reported performance for the setting without a CRF layer or character-based word embeddings. 566 | The repository for this page provides the code in runnable form. 567 | The only dependencies are the respective frameworks (DyNet 2.0.3, PyTorch 0.4.1 and Tensorflow 1.9.0). 568 |
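The GloVe initialisation mentioned above can be sketched in plain NumPy, mirroring the logic in the taggers' main() functions. Each vocabulary entry is looked up in lowercase, and entries that are missing fall back to a uniform random vector with scale sqrt(3 / dim). This is a minimal, framework-independent sketch: the helper name `build_embedding_matrix`, the `seed` parameter and the toy `glove` table are hypothetical and exist only for illustration.

```python
import numpy as np


def build_embedding_matrix(id_to_token, pretrained, dim, seed=0):
    # Hypothetical helper: returns an array with one row per vocabulary id.
    # Known words are looked up in lowercase (the standard GloVe files only
    # contain lowercase words); unknown words get a uniform random vector
    # in [-scale, scale) with scale = sqrt(3 / dim).
    rng = np.random.RandomState(seed)
    scale = np.sqrt(3.0 / dim)
    rows = []
    for word in id_to_token:
        vector = pretrained.get(word.lower())
        if vector is not None:
            rows.append(np.array(vector, dtype=np.float64))
        else:
            rows.append(rng.uniform(-scale, scale, dim))
    return np.stack(rows)


# Toy 4-dimensional "GloVe" table with made-up values.
glove = {"the": [0.1, 0.2, 0.3, 0.4], "cat": [0.5, 0.6, 0.7, 0.8]}
vocab = ["__PAD__", "__UNK__", "The", "cat", "zyzzyva"]
matrix = build_embedding_matrix(vocab, glove, dim=4)
print(matrix.shape)  # (5, 4)
```

The fixed seed only makes the sketch reproducible. The taggers themselves call `np.random.uniform` directly and pass the resulting list to their framework's embedding initialiser.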

569 |

570 | Website usage: Use the buttons to show one or more implementations and their associated comments (note, depending on your screen size you may need to scroll to see all the code). 571 | Matching or closely related content is aligned. 572 | Framework-specific comments are highlighted in a colour that matches their button and a line is used to make the link from the comment to the code clear. 573 |

574 |

575 | New (2019) Runnable Version: I have made a slightly modified version of the Tensorflow code available as a Google Colaboratory Notebook. 576 |

577 |

578 | Making this helped me understand all three frameworks better. Hopefully you will find it informative too! 579 |

580 | 581 |
582 | 583 | 584 | 585 |
586 |
587 | 588 |
589 | 590 | """ 591 | 592 | tail = """ 593 | 594 | 693 | 694 |
695 |
696 |

697 | This code was last updated in August 2018. 698 | If one of the frameworks has changed in a way that should be reflected here, please let me know! 699 |

700 |

701 | A few miscellaneous notes: 702 |

    703 |
  • PyTorch 0.4 does not support recurrent dropout directly. For an example of how to achieve it, see the LSTM and QRNN Language Model Toolkit's WeightDrop class and how it is used.
  • 704 |
  • Tensorflow 1.9 does not support weight decay directly, but this pull request appears to add support and will be part of 1.10.
  • 705 |
706 |
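One possible workaround for the missing weight decay is to add it by hand. With plain SGD, which is what these taggers use, adding the penalty (λ/2)·‖w‖² to the loss gives the same update as folding λ·w into the gradient: w ← w − lr·(g + λ·w). In Tensorflow 1.x this is commonly done by adding `weight_decay * tf.nn.l2_loss(v)` to the loss for each trainable variable (`tf.nn.l2_loss` computes sum(v²)/2). The NumPy sketch below checks the equivalence numerically on a toy quadratic loss. The loss, the starting point and the decay value are illustrative assumptions, not values taken from the taggers.

```python
import numpy as np

# Toy loss L(w) = 0.5 * ||w - target||^2, whose gradient is (w - target).
target = np.array([1.0, -2.0, 3.0])
w = np.array([0.5, 0.5, 0.5])
lr = 0.1
weight_decay = 0.01  # illustrative value


def sgd_step_with_decay(w, grad, lr, weight_decay):
    # Weight decay folded into the SGD update: w <- w - lr * (grad + wd * w).
    return w - lr * (grad + weight_decay * w)


def penalised_loss(w):
    # Original loss plus an L2 penalty wd * sum(w^2) / 2
    # (the same form as tf.nn.l2_loss, which computes sum(t^2) / 2).
    return 0.5 * np.sum((w - target) ** 2) + weight_decay * 0.5 * np.sum(w ** 2)


def numerical_grad(f, w, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        grad[i] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad


w_decay = sgd_step_with_decay(w, w - target, lr, weight_decay)
w_penalty = w - lr * numerical_grad(penalised_loss, w)
print(np.allclose(w_decay, w_penalty))  # True
```

This equivalence is specific to plain SGD. With adaptive optimisers such as Adam, adding an L2 penalty to the loss and decaying the weights directly are no longer the same thing.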

707 |

708 | And a few other gotchas I've come across: 709 |

    710 |
  • For PyTorch, consider running your code with these two environment variables set: "OMP_NUM_THREADS=1" and "MKL_NUM_THREADS=1". They prevent low-level matrix manipulation libraries from creating unnecessary threads. See this Twitter thread for discussion.
  • 711 |
712 |
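If setting the variables in the shell is inconvenient, you can also set them from inside the script. The minimal sketch below assumes the lines run before NumPy or PyTorch is imported anywhere in the process, because the thread pools typically read these variables only once, when they are initialised. PyTorch also provides `torch.set_num_threads(n)` for limiting its own intra-op parallelism at run time.

```python
import os

# Set these before NumPy or PyTorch is imported: the OpenMP and MKL runtimes
# typically read them once, when their thread pools are first created.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Import the framework only after the variables are set, e.g.:
# import torch
```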

713 |

714 | I developed this code and webpage with help from many people and resources. In particular: 715 |

721 |

722 |
723 |
724 | 725 |
726 |
727 |
728 | 741 | 742 | 743 | """ 744 | 745 | if __name__ == '__main__': 746 | main() 747 | 748 | --------------------------------------------------------------------------------