├── preprocess.sh ├── Results.pdf ├── util.py ├── dataset.py ├── config.py ├── .gitignore ├── model.py ├── vocab.py ├── main.py ├── README.md ├── gru.py ├── rnn.py └── train.py /preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | paste $1 $2 > $3 4 | -------------------------------------------------------------------------------- /Results.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shivanshu-Gupta/Pytorch-POS-Tagger/HEAD/Results.pdf -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # not used in final implementation 2 | 3 | # write the unique words from a file to a new vocab file 4 | def build_vocab(filename, vocabfile): 5 | vocab = set() 6 | with open(filename, 'r') as f: 7 | for line in f: 8 | tokens = line.rstrip('\n').split(' ') 9 | vocab |= set(tokens) 10 | idx = 0 11 | print(vocabfile) 12 | with open(vocabfile, 'w') as f: 13 | for token in vocab: 14 | f.write(token + '\n') 15 | idx = idx + 1 16 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # custom Dataset - not used in final implementation. 2 | 3 | import os 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class POSDataset(Dataset): 9 | def __init__(self, path, sen_vocab, tag_vocab): 10 | super(POSDataset, self).__init__() 11 | self.sen_vocab = sen_vocab 12 | self.tag_vocab = tag_vocab 13 | self.num_classes = tag_vocab.size() 14 | sen_file = os.path.join(path, 'sentences.txt') 15 | tag_file = os.path.join(path, 'tags.txt') 16 | self.sentences = [] 17 | with open(sen_file, 'r') as f: 18 | for line in f: 19 | idxs = self.sen_vocab.toIdx(line.rstrip('\n').split(' ')) 20 | tensor = torch.LongTensor(idxs) 21 | self.sentences.append(tensor) 22 | 23 | self.tags = [] 24 | with open(tag_file, 'r') as f: 25 | for line in f: 26 | idxs = self.tag_vocab.toIdx(line.rstrip('\n').split(' ')) 27 | tensor = torch.LongTensor(idxs) 28 | self.tags.append(tensor) 29 | 30 | # make sure there are as many sentences as tag sequences.
31 | assert(len(self.sentences) == len(self.tags)) 32 | 33 | def __getitem__(self, index): 34 | sentence = self.sentences[index] 35 | tags = self.tags[index] 36 | return sentence, tags 37 | 38 | def __len__(self): 39 | return len(self.sentences) 40 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | print("parsing arguments") 6 | parser = argparse.ArgumentParser(description='PyTorch Parts-of-Speech Tagger') 7 | parser.add_argument('--use_gpu', default=False, action='store_true') 8 | 9 | parser.add_argument('--data_dir', default='RNN_Data_files/', metavar='PATH', 10 | help='directory containing train_data.tsv and val_data.tsv') 11 | parser.add_argument('--save_dir', default='/home/cse/dual/cs5130298/scratch/checkpoints2/', metavar='PATH') 12 | 13 | parser.add_argument('--rnn_class', choices=['lstm', 'gru', 'rnn', 'customgru'], default='lstm', 14 | help='class of underlying RNN to use') 15 | parser.add_argument('--reload', default='', metavar='PATH', 16 | help='path to checkpoint to load (default: none)') 17 | parser.add_argument('--test', default=False, action='store_true', 18 | help='test model on test set (use with --reload)') 19 | 20 | parser.add_argument('--batch_size', type=int, default=1, 21 | help='batchsize for optimizer updates') 22 | parser.add_argument('--epochs', type=int, default=1, 23 | help='number of total epochs to run') 24 | 25 | parser.add_argument('--lr', type=float, default=0.1, 26 | metavar='LR', help='initial learning rate') 27 | parser.add_argument('--step_size', type=int, default=10, metavar='N') 28 | parser.add_argument('--gamma', type=float, default=1) 29 | 30 | parser.add_argument('--seed', type=int, default=123, 31 | help='random seed (default: 123)') 32 | 33 | args = parser.parse_args() 34 | return args 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .DS_Store 107 | 108 | scratch/ 109 | checkpoints/ 110 | scripts/ 111 | runs/ 112 | results/ 113 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from rnn import CustomRNN 4 | from gru import CustomGRUCell 5 | 6 | 7 | class POSTagger(nn.Module): 8 | 9 | def __init__(self, rnn_class, embedding_dim, hidden_dim, vocab_size, target_size, use_gpu=True): 10 | super(POSTagger, self).__init__() 11 | self.rnn_class = rnn_class 12 | self.hidden_dim = hidden_dim 13 | 14 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=1) 15 | self.use_gpu = use_gpu 16 | if use_gpu: 17 | self.word_embeddings.cuda() 18 | self.num_layers = 1 19 | # The chosen RNN takes word embeddings as inputs and outputs hidden states 20 | # with dimensionality hidden_dim.
21 | if self.rnn_class == 'lstm': 22 | self.rnn = CustomRNN(nn.LSTMCell, embedding_dim, hidden_dim, batch_first=False) 23 | elif self.rnn_class == 'gru': 24 | self.rnn = CustomRNN(nn.GRUCell, embedding_dim, hidden_dim, batch_first=False) 25 | elif self.rnn_class == 'rnn': 26 | self.rnn = CustomRNN(nn.RNNCell, embedding_dim, hidden_dim, batch_first=False) 27 | else: 28 | self.rnn = CustomRNN(CustomGRUCell, embedding_dim, hidden_dim, batch_first=False) 29 | # The linear layer that maps from hidden state space to tag space 30 | self.hidden2tag = nn.Linear(hidden_dim, target_size) 31 | 32 | def forward(self, sentences, ranges, lengths): 33 | embeds = self.word_embeddings(sentences) 34 | lstm_out, _ = self.rnn(embeds, ranges, lengths) 35 | tag_space = self.hidden2tag(lstm_out) 36 | tag_scores = F.log_softmax(tag_space, dim=1) 37 | # do this if you want a 3D tensor (ref: https://github.com/pytorch/pytorch/issues/1020) 38 | # tag_scores = F.log_softmax(tag_space.transpose(0, 2)).transpose(0, 2) 39 | return tag_scores 40 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # reference: vocab object from harvardnlp/opennmt-py 2 | 3 | class Vocab(object): 4 | def __init__(self, filename=None, lower=False, unkWord='unk'): 5 | self.idxToLabel = {} 6 | self.labelToIdx = {} 7 | self.lower = lower 8 | 9 | if filename is not None: 10 | self.loadFile(filename) 11 | 12 | # have only 'unk' as special word 13 | idx = self.add(unkWord) 14 | self.unk = idx 15 | 16 | def size(self): 17 | return len(self.idxToLabel) 18 | 19 | # Load entries from a file. 20 | def loadFile(self, filename): 21 | for line in open(filename): 22 | token = line.rstrip('\n') 23 | self.add(token) 24 | 25 | def getIndex(self, key, default=None): 26 | if self.lower: 27 | key = key.lower() 28 | try: 29 | return self.labelToIdx[key] 30 | except KeyError: 31 | return default 32 | 33 | def getLabel(self, idx, default=None): 34 | try: 35 | return self.idxToLabel[idx] 36 | except KeyError: 37 | return default 38 | 39 | # Add `label` to the dictionary and return its index. 40 | def add(self, label): 41 | if self.lower: 42 | label = label.lower() 43 | 44 | if label in self.labelToIdx: 45 | idx = self.labelToIdx[label] 46 | else: 47 | idx = len(self.idxToLabel) + 1 48 | self.idxToLabel[idx] = label 49 | self.labelToIdx[label] = idx 50 | return idx 51 | 52 | # Convert `labels` to indices. Use `unkWord` if a label is not found. 53 | # (Unlike the OpenNMT original, no BOS/EOS words are inserted here.) 54 | def toIdx(self, labels): 55 | vec = [self.getIndex(label, default=self.unk) for label in labels] 56 | return vec 57 | 58 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 59 | def toLabels(self, idx, stop): 60 | labels = [] 61 | 62 | for i in idx: 63 | labels += [self.getLabel(i)] 64 | if i == stop: 65 | break 66 | 67 | return labels 68 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.optim import lr_scheduler 6 | from torchtext import data 7 | from config import parse_args 8 | from model import POSTagger 9 | from train import train_model, test_model 10 | 11 | # Dimensions of the word embeddings and the RNN hidden state.
12 | # They are fixed here rather than exposed as command-line arguments. 13 | EMBEDDING_DIM = 300 14 | HIDDEN_DIM = 200 15 | 16 | 17 | def load_datasets(): 18 | text = data.Field(include_lengths=True) 19 | tags = data.Field() 20 | train_data, val_data, test_data = data.TabularDataset.splits(path='RNN_Data_files/', train='train_data.tsv', validation='val_data.tsv', test='val_data.tsv', fields=[('text', text), ('tags', tags)], format='tsv') 21 | 22 | batch_sizes = (args.batch_size, args.batch_size, args.batch_size) 23 | train_loader, val_loader, test_loader = data.BucketIterator.splits((train_data, val_data, test_data), batch_sizes=batch_sizes, sort_key=lambda x: len(x.text)) 24 | 25 | text.build_vocab(train_data) 26 | tags.build_vocab(train_data) 27 | dataloaders = {'train': train_loader, 28 | 'validation': val_loader, 29 | 'test': val_loader} 30 | return text, tags, dataloaders 31 | 32 | 33 | def save_params(): 34 | os.makedirs(args.save_dir, exist_ok=True) 35 | param_file = args.save_dir + '/' + 'params.txt' 36 | with open(param_file, 'w') as fout: 37 | fout.write(str(args)) 38 | 39 | 40 | if __name__ == '__main__': 41 | global args 42 | args = parse_args() 43 | save_params() 44 | args.use_gpu = args.use_gpu and torch.cuda.is_available() 45 | print(args) 46 | torch.manual_seed(args.seed) 47 | torch.cuda.manual_seed(args.seed) 48 | 49 | text, tags, dataloaders = load_datasets() 50 | text_vocab_size = len(text.vocab.stoi) + 1 51 | tag_vocab_size = len(tags.vocab.stoi) - 1 # = 42 (not counting the unknown-word token) 52 | print(text_vocab_size) 53 | print(tag_vocab_size) 54 | 55 | model = POSTagger(args.rnn_class, EMBEDDING_DIM, HIDDEN_DIM, 56 | text_vocab_size, tag_vocab_size, args.use_gpu) 57 | if args.use_gpu: 58 | model = model.cuda() 59 | 60 | if args.reload: 61 | if os.path.isfile(args.reload): 62 | print("=> loading checkpoint '{}'".format(args.reload)) 63 | checkpoint = torch.load(args.reload) 64 | model.load_state_dict(checkpoint['state_dict']) 65 | # optimizer.load_state_dict(checkpoint['optimizer']) 66 | print("=> loaded checkpoint '{}' (epoch {}, accuracy {})" 67 | .format(args.reload, checkpoint['epoch'], checkpoint['best_acc'])) 68 | else: 69 | print("=> no checkpoint found at '{}'".format(args.reload)) 70 | 71 | if args.test: 72 | test_model(model, dataloaders['test'], use_gpu=args.use_gpu) 73 | else: 74 | criterion = nn.NLLLoss() 75 | optimizer = optim.SGD(model.parameters(), lr=args.lr) 76 | # Decay LR by a factor of gamma every step_size epochs 77 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 78 | 79 | print("begin training") 80 | model = train_model(model, dataloaders, criterion, optimizer, exp_lr_scheduler, args.save_dir, 81 | num_epochs=args.epochs, use_gpu=args.use_gpu) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parts-of-Speech Tagger 2 | 3 | The purpose of this project was to learn how to implement RNNs and compare different types of RNNs on the task of **Parts-of-Speech tagging** using a part of the **CoNLL-2012 dataset** with 42 possible tags. This repository contains: 1. a custom implementation of the *GRU cell*. 5 | 2. a custom implementation of the RNN architecture that may be configured to be used as an *LSTM*, *GRU* or *Vanilla RNN*. 6 | 3. a Parts-of-Speech tagger that can be configured to use any of the above custom RNN implementations (see the sketch below).
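A minimal sketch of how these pieces fit together, illustrative only: the dimensions, vocabulary size and tag count below are made up, and `main.py` is where the real values are derived from the torchtext vocabularies.

```python
import torch
from torch.autograd import Variable
from model import POSTagger

# Build a tagger around the custom GRU cell; passing 'lstm', 'gru' or 'rnn'
# instead selects the corresponding built-in PyTorch cell (see model.py).
tagger = POSTagger(rnn_class='customgru', embedding_dim=300, hidden_dim=200,
                   vocab_size=10000, target_size=42, use_gpu=False)

# A bucket of 2 sentences of length 5, encoded as word indices (seq_len x batch),
# mirroring how train.py feeds equal-length slices of a batch to the model.
sentences = Variable(torch.LongTensor(5, 2).random_(2, 10000))
tag_scores = tagger(sentences, ranges=[(0, 2)], lengths=[5, 5])
print(tag_scores.size())  # (seq_len * batch, target_size) log-probabilities
```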
7 | 8 | ## Requirements 9 | - python 3.5 10 | - [pytorch](http://pytorch.org/) 11 | - [torchtext](https://github.com/pytorch/text) 12 | 13 | ## Organisation 14 | The code in the repository is organised as follows: 15 | - [gru.py](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/gru.py): custom GRU 16 | - [rnn.py](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/rnn.py): custom RNN 17 | - [model.py](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/model.py): POS Tagger Model 18 | - [train.py](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/train.py): training/validation/testing code 19 | - [main.py](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/main.py): driver code 20 | 21 | The raw dataset is in [RNN_Data_files/]. 22 | 23 | ## Usage 24 | ### Preprocessing datasets 25 | Use [preprocess.sh](https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/preprocess.sh) to generate TSV datasets containing sentences and POS tags in the intended *data_dir* (*RNN_Data_files/* here). 26 | ```sh 27 | $ ./preprocess.sh RNN_Data_files/train/sentences.tsv RNN_Data_files/train/tags.tsv RNN_Data_files/train_data.tsv 28 | $ ./preprocess.sh RNN_Data_files/val/sentences.tsv RNN_Data_files/val/tags.tsv RNN_Data_files/val_data.tsv 29 | ``` 30 | ### Training/Testing 31 | ```sh 32 | usage: main.py [-h] [--use_gpu] [--data_dir PATH] [--save_dir PATH] 33 | [--rnn_class RNN_CLASS] [--reload PATH] [--test] 34 | [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--lr LR] 35 | [--step_size N] [--gamma GAMMA] [--seed SEED] 36 | 37 | PyTorch Parts-of-Speech Tagger 38 | 39 | optional arguments: 40 | -h, --help show this help message and exit 41 | --use_gpu 42 | --data_dir PATH directory containing train_data.tsv and val_data.tsv (default=RNN_Data_files/) 43 | --save_dir PATH 44 | --rnn_class RNN_CLASS 45 | class of underlying RNN to use 46 | --reload PATH path to checkpoint to load (default: none) 47 | --test test model on test set (use with --reload) 48 | --batch_size BATCH_SIZE 49 | batchsize for optimizer updates 50 | --epochs EPOCHS number of total epochs to run 51 | --lr LR initial learning rate 52 | --step_size N 53 | --gamma GAMMA 54 | --seed SEED random seed (default: 123) 55 | ``` 56 | ## Results 57 | [Results.pdf] compares the LSTM-, GRU- and Vanilla-RNN-based POS taggers on various metrics. The best accuracy, 96.12%, was obtained with the LSTM-based tagger. The pretrained model can be downloaded from [here](https://drive.google.com/open?id=0By07sE0zY59RRnhVd1VjUURlSWs). 58 | 59 | [RNN_Data_files/]: https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/tree/master/RNN_Data_files 60 | [Results.pdf]: https://github.com/Shivanshu-Gupta/Pytorch-POS-Tagger/blob/master/Results.pdf 61 | -------------------------------------------------------------------------------- /gru.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import matmul 5 | from torch.nn import Parameter 6 | from torch.nn.modules.rnn import RNNCellBase 7 | 8 | 9 | class CustomGRUCell(RNNCellBase): 10 | r"""A custom gated recurrent unit (GRU) cell 11 | ..
math:: 12 | \begin{array}{ll} 13 | r = \mathrm{sigmoid}(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ 14 | z = \mathrm{sigmoid}(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ 15 | n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ 16 | h' = (1 - z) * n + z * h 17 | \end{array} 18 | Args: 19 | input_size: The number of expected features in the input x 20 | hidden_size: The number of features in the hidden state h 21 | bias: If `False`, then the layer does not use bias weights `b_ih` and 22 | `b_hh`. Default: `True` 23 | Inputs: input, hidden 24 | - **input** (batch, input_size): tensor containing input features 25 | - **hidden** (batch, hidden_size): tensor containing the initial hidden 26 | state for each element in the batch. 27 | Outputs: h' 28 | - **h'**: (batch, hidden_size): tensor containing the next hidden state 29 | for each element in the batch 30 | Attributes: 31 | weight_ih: the learnable input-hidden weights, of shape 32 | `(3*hidden_size x input_size)` 33 | weight_hh: the learnable hidden-hidden weights, of shape 34 | `(3*hidden_size x hidden_size)` 35 | bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` 36 | bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` 37 | Examples:: 38 | >>> from gru import CustomGRUCell 39 | >>> rnn = CustomGRUCell(10, 20) 40 | >>> input = Variable(torch.randn(6, 3, 10)) 41 | >>> hx = Variable(torch.randn(3, 20)) 42 | >>> output = [] 43 | >>> for i in range(6): 44 | ... hx = rnn(input[i], hx) 45 | ... output.append(hx) 46 | """ 47 | 48 | def __init__(self, input_size, hidden_size, bias=True): 49 | super(CustomGRUCell, self).__init__() 50 | self.input_size = input_size 51 | self.hidden_size = hidden_size 52 | self.bias = bias 53 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 54 | self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size)) 55 | if bias: 56 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 57 | self.bias_hh = Parameter(torch.Tensor(3 * hidden_size)) 58 | else: 59 | self.register_parameter('bias_ih', None) 60 | self.register_parameter('bias_hh', None) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | stdv = 1.0 / math.sqrt(self.hidden_size) 65 | for weight in self.parameters(): 66 | weight.data.uniform_(-stdv, stdv) 67 | 68 | def forward(self, input, hx): 69 | wih = self.weight_ih 70 | bih = self.bias_ih 71 | whh = self.weight_hh 72 | bhh = self.bias_hh 73 | dim_h = self.hidden_size 74 | 75 | bih = bih.expand(hx.size(0), 3 * dim_h) # batch_size x 3*hidden_size 76 | bhh = bhh.expand(hx.size(0), 3 * dim_h) # batch_size x 3*hidden_size 77 | # reset gate - batch_size * hidden_size 78 | # r = \mathrm{sigmoid}(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) 79 | r = F.sigmoid(matmul(input, wih[:dim_h].t()) + bih[:, :dim_h] + 80 | matmul(hx, whh[:dim_h].t()) + bhh[:, :dim_h]) 81 | 82 | # update gate - batch_size * hidden_size 83 | # z = \mathrm{sigmoid}(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) 84 | z = F.sigmoid(matmul(input, wih[dim_h:2 * dim_h].t()) + bih[:, dim_h:2 * dim_h] + 85 | matmul(hx, whh[dim_h:2 * dim_h].t()) + bhh[:, dim_h:2 * dim_h]) 86 | 87 | # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) 88 | n = F.tanh(matmul(input, wih[2 * dim_h:].t()) + bih[:, 2 * dim_h:] + 89 | matmul(r * hx, whh[2 * dim_h:].t()) + bhh[:, 2 * dim_h:]) 90 | 91 | # h' = (1 - z) * n + z * h 92 | output = (1 - z) * n + z * hx 93 | 94 | return output 95 | -------------------------------------------------------------------------------- /rnn.py:
-------------------------------------------------------------------------------- 1 | # reference: https://github.com/jihunchoi/recurrent-batch-normalization-pytorch/blob/master/bnlstm.py 2 | import torch 3 | import torch.nn as nn 4 | from torch import autograd 5 | 6 | 7 | class CustomRNN(nn.Module): 8 | 9 | """A module that implements an RNN using the given Cell type.""" 10 | 11 | def __init__(self, cell_class, input_size, hidden_size, 12 | use_bias=True, batch_first=False, **kwargs): 13 | super(CustomRNN, self).__init__() 14 | self.cell_class = cell_class 15 | self.input_size = input_size 16 | self.hidden_size = hidden_size 17 | self.use_bias = use_bias 18 | self.batch_first = batch_first 19 | 20 | self.cell = cell_class(input_size=input_size, 21 | hidden_size=hidden_size, 22 | **kwargs) 23 | self.cell.reset_parameters() 24 | 25 | # def _forward_rnn(self, cell, input_, ranges, lengths, hx): 26 | # max_time, batch_size, _ = input_.size() 27 | # output = [[] for i in range(batch_size)] 28 | # # print(input_.size()) 29 | # curr = 0 30 | # # print(ranges) 31 | # for time in range(max_time): 32 | # beg = ranges[curr][0] 33 | # end = ranges[curr][1] 34 | # assert(input_[time].size(0) == hx[0].size(0)) 35 | # hx = cell(input=input_[time], hx=hx) 36 | # if isinstance(cell, nn.LSTMCell): 37 | # for idx in range(beg, end): 38 | # output[idx].append(hx[0][idx - beg]) 39 | # else: 40 | # for idx in range(beg, end): 41 | # output[idx].append(hx[idx - beg]) 42 | # if time == lengths[beg] - 1 and time != max_time - 1: 43 | # curr += 1 44 | # input_ = input_[ranges[curr][0] - beg:] 45 | # if isinstance(cell, nn.LSTMCell): 46 | # h_next = hx[0][ranges[curr][0] - beg:] 47 | # c_next = hx[0][ranges[curr][0] - beg:] 48 | # hx = (h_next, c_next) 49 | # else: 50 | # h_next = cell(input=input_[time], hx=hx) 51 | # h_next = h_next[ranges[curr][0] - beg:] 52 | # hx = h_next 53 | # output = [torch.stack(sentence_out, 0) for sentence_out in output] 54 | # output = torch.cat(output, 0) 55 | # return output, hx 56 | 57 | # def _forward_rnn(self, cell, input_, length, hx): 58 | # max_time = input_.size(0) 59 | # output = [] 60 | # for time in range(max_time): 61 | # if isinstance(cell, nn.LSTMCell): 62 | # h_next, c_next = cell(input_=input_[time], hx=hx, time=time) 63 | # mask = (time < length).float().unsqueeze(1).expand_as(h_next) 64 | # h_next = h_next * mask + hx[0] * (1 - mask) 65 | # c_next = c_next * mask + hx[1] * (1 - mask) 66 | # hx_next = (h_next, c_next) 67 | # else: 68 | # h_next = cell(input_=input_[time], hx=hx) 69 | # mask = (time < length).float().unsqueeze(1).expand_as(h_next) 70 | # h_next = h_next * mask + hx[0] * (1 - mask) 71 | # hx_next = h_next 72 | 73 | # output.append(h_next) 74 | # hx = hx_next 75 | # output = torch.stack(output, 0) 76 | 77 | def _forward_rnn_no_mask(self, cell, input_, hx): 78 | max_time = input_.size(0) 79 | output = [] 80 | for time in range(max_time): 81 | if isinstance(cell, nn.LSTMCell): 82 | h_next, c_next = cell(input=input_[time], hx=hx) 83 | hx = (h_next, c_next) 84 | else: 85 | h_next = cell(input=input_[time], hx=hx) 86 | hx = h_next 87 | output.append(h_next) 88 | output = torch.cat(output, 0) 89 | # output = torch.stack(output, 0) # do this if want 3D tensor 90 | return output, hx 91 | 92 | def forward(self, input_, ranges, lengths, hx=None): 93 | if self.batch_first: 94 | input_ = input_.transpose(0, 1) 95 | max_time, batch_size, _ = input_.size() 96 | if hx is None: 97 | hx = autograd.Variable(input_.data.new(batch_size, self.hidden_size).zero_(), 98 | 
requires_grad=False) 99 | if self.cell_class == nn.LSTMCell: 100 | hx = (hx, hx) 101 | cell = self.cell 102 | # output, h_n = self._forward_rnn(cell=cell, input_=input_, ranges=ranges, lengths=lengths, hx=hx) 103 | output, h_n = self._forward_rnn_no_mask(cell=cell, input_=input_, hx=hx) 104 | return output, h_n 105 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | from tensorboardX import SummaryWriter 4 | import torch 5 | 6 | 7 | def get_ranges(lengths): 8 | beg = 0 9 | curr = 0 10 | curr_length = lengths[0] 11 | ranges = [] 12 | for length in lengths: 13 | if length != curr_length: 14 | ranges.append((beg, curr)) 15 | curr_length = length 16 | beg = curr 17 | curr += 1 18 | ranges.append((beg, curr)) 19 | return ranges 20 | 21 | 22 | # run one epoch of training 23 | def train(model, train_loader, criterion, optimizer, use_gpu=False): 24 | model.train() # Set model to training mode 25 | running_loss = 0.0 26 | running_corrects = 0 27 | example_count = 0 28 | step = 0 29 | # Iterate over data. 30 | for batch in train_loader: 31 | all_sentences = batch.text[0] 32 | all_tags = batch.tags - 2 # first valid tag index is 2 33 | if not use_gpu: 34 | all_sentences = all_sentences.cpu() 35 | all_tags = all_tags.cpu() 36 | 37 | lengths = batch.text[1] 38 | batch_size = lengths.size(0) 39 | ranges = get_ranges(lengths) 40 | for rng in ranges: 41 | length = lengths[rng[0]] 42 | sentences = all_sentences[0:length, rng[0]:rng[1]].clone() 43 | tags = all_tags[0:length, rng[0]:rng[1]].clone() 44 | tags = tags.view(-1) 45 | 46 | # zero grad 47 | model.zero_grad() 48 | 49 | # forward 50 | tag_scores = model(sentences, [(0, rng[1] - rng[0])], lengths[rng[0]:rng[1]]) 51 | _, preds = torch.max(tag_scores, 1) 52 | loss = criterion(tag_scores, tags) 53 | 54 | # backward + optimize 55 | loss.backward() 56 | optimizer.step() 57 | 58 | # statistics 59 | running_loss += loss.data[0] 60 | running_corrects += torch.sum((preds == tags).data) 61 | example_count += torch.sum(lengths) 62 | step += 1 63 | # if step % 1000 == 0: 64 | # print('loss: {}, running_corrects: {}, example_count: {}, acc: {}'.format(loss.data[0], running_corrects, example_count, (running_corrects / example_count) * 100)) 65 | if step * batch_size == 40000: 66 | break 67 | loss = running_loss / example_count 68 | acc = (running_corrects / example_count) * 100 69 | print(loss) 70 | print(acc) 71 | # print('Train Loss: {:.4f} Acc: {:2.3f} ({}/{})'.format(loss, acc, running_corrects, example_count)) 72 | return loss, acc 73 | 74 | 75 | def validate(model, val_loader, criterion, use_gpu=False): 76 | model.eval() # Set model to evaluate mode 77 | running_loss = 0.0 78 | running_corrects = 0 79 | example_count = 0 80 | # Iterate over data. 
81 | for batch in val_loader: 82 | all_sentences = batch.text[0] 83 | all_tags = batch.tags - 2 # first valid tag index is 2 84 | if not use_gpu: 85 | all_sentences = all_sentences.cpu() 86 | all_tags = all_tags.cpu() 87 | 88 | lengths = batch.text[1] 89 | ranges = get_ranges(lengths) 90 | for rng in ranges: 91 | length = lengths[rng[0]] 92 | sentences = all_sentences[0:length, rng[0]:rng[1]].clone() 93 | tags = all_tags[0:length, rng[0]:rng[1]].clone() 94 | tags = tags.view(-1) 95 | 96 | # forward 97 | tag_scores = model(sentences, [(0, rng[1] - rng[0])], lengths[rng[0]:rng[1]]) 98 | _, preds = torch.max(tag_scores, 1) 99 | loss = criterion(tag_scores, tags) 100 | 101 | # statistics 102 | running_loss += loss.data[0] 103 | running_corrects += torch.sum((preds == tags).data) 104 | example_count += torch.sum(lengths) 105 | loss = running_loss / example_count 106 | acc = (running_corrects / example_count) * 100 107 | print('Validation Loss: {:.4f} Acc: {:2.3f} ({}/{})'.format(loss, acc, running_corrects, example_count)) 108 | return loss, acc 109 | 110 | 111 | def train_model(model, data_loaders, criterion, optimizer, scheduler, save_dir, num_epochs=25, use_gpu=False): 112 | print('Training Model with use_gpu={}...'.format(use_gpu)) 113 | since = time.time() 114 | 115 | best_model_wts = model.state_dict() 116 | best_acc = 0.0 117 | writer = SummaryWriter(save_dir) 118 | for epoch in range(num_epochs): 119 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 120 | print('-' * 10) 121 | train_begin = time.time() 122 | train_loss, train_acc = train(model, data_loaders['train'], criterion, optimizer, use_gpu) 123 | train_time = time.time() - train_begin 124 | print('Epoch Train Time: {:.0f}m {:.0f}s'.format(train_time // 60, train_time % 60)) 125 | writer.add_scalar('Train Loss', train_loss, epoch) 126 | writer.add_scalar('Train Accuracy', train_acc, epoch) 127 | 128 | validation_begin = time.time() 129 | val_loss, val_acc = validate(model, data_loaders['validation'], criterion, use_gpu) 130 | validation_time = time.time() - validation_begin 131 | print('Epoch Validation Time: {:.0f}m {:.0f}s'.format(validation_time // 60, validation_time % 60)) 132 | writer.add_scalar('Validation Loss', val_loss, epoch) 133 | writer.add_scalar('Validation Accuracy', val_acc, epoch) 134 | 135 | # deep copy the model 136 | is_best = val_acc > best_acc 137 | if is_best: 138 | best_acc = val_acc 139 | best_model_wts = model.state_dict() 140 | 141 | save_checkpoint(save_dir, { 142 | 'epoch': epoch, 143 | 'best_acc': best_acc, 144 | 'state_dict': model.state_dict(), 145 | # 'optimizer': optimizer.state_dict(), 146 | }, is_best) 147 | 148 | scheduler.step() 149 | 150 | time_elapsed = time.time() - since 151 | print('Training complete in {:.0f}m {:.0f}s'.format( 152 | time_elapsed // 60, time_elapsed % 60)) 153 | print('Best val Acc: {:4f}'.format(best_acc)) 154 | # load best model weights 155 | model.load_state_dict(best_model_wts) 156 | 157 | # export scalar data to JSON for external processing 158 | writer.export_scalars_to_json(save_dir + "/all_scalars.json") 159 | writer.close() 160 | 161 | return model 162 | 163 | 164 | def save_checkpoint(save_dir, state, is_best): 165 | savepath = save_dir + '/' + 'checkpoint.pth.tar' 166 | torch.save(state, savepath) 167 | if is_best: 168 | shutil.copyfile(savepath, save_dir + '/' + 'model_best.pth.tar') 169 | 170 | 171 | def test_model(model, test_loader, use_gpu=False): 172 | model.eval() # Set model to evaluate mode 173 | running_corrects = 0 174 | example_count = 0 175 | 
test_begin = time.time() 176 | # Iterate over data. 177 | for batch in test_loader: 178 | all_sentences = batch.text[0] 179 | all_tags = batch.tags - 2 # first valid tag index is 2 180 | # print(all_sentences) 181 | # print(all_tags) 182 | if not use_gpu: 183 | all_sentences = all_sentences.cpu() 184 | all_tags = all_tags.cpu() 185 | 186 | lengths = batch.text[1] 187 | ranges = get_ranges(lengths) 188 | for rng in ranges: 189 | length = lengths[rng[0]] 190 | sentences = all_sentences[0:length, rng[0]:rng[1]].clone() 191 | tags = all_tags[0:length, rng[0]:rng[1]].clone() 192 | tags = tags.view(-1) 193 | # tags = torch.cat(torch.split(tags, split_size=1, dim=1)).squeeze(1) # do this if want 2D 194 | 195 | # forward 196 | tag_scores = model(sentences, [(0, rng[1] - rng[0])], lengths[rng[0]:rng[1]]) 197 | _, preds = torch.max(tag_scores.data, 1) 198 | # do this if want 2D 199 | # tag_scores = torch.cat(torch.split(tag_scores, split_size=1, dim=1)).squeeze(1) 200 | # _, preds = torch.max(tag_scores, 2) 201 | 202 | # statistics 203 | running_corrects += torch.sum(preds == tags.data) 204 | example_count += length * (rng[1] - rng[0]) 205 | acc = (running_corrects / example_count) * 100 206 | print('Test Acc: {:2.3f} ({}/{})'.format(acc, running_corrects, example_count)) 207 | test_time = time.time() - test_begin 208 | print('Test Time: {:.0f}m {:.0f}s'.format(test_time // 60, test_time % 60)) 209 | return acc 210 | --------------------------------------------------------------------------------