├── README.md ├── RecursiveNN.py ├── SenTree.py ├── TreeLSTM.py └── trees ├── dev.txt ├── small.txt ├── test.txt └── train.txt /README.md: -------------------------------------------------------------------------------- 1 | PyTorch implementation of
2 | 1) RNN: Recursive Neural Network from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf
3 | python RecursiveNN.py
4 | 2) TreeLSTM from https://arxiv.org/abs/1503.00075
5 | python TreeLSTM.py
6 | 7 | Add `cuda` as the first command-line argument to run on the GPU, e.g. `python RecursiveNN.py cuda`
8 | 9 | Requirements:
10 | nltk
11 | pytorch
12 | progressbar
13 | 14 | Mixed code from:
15 | Socher's cs224d class (see, e.g., https://github.com/kingtaurus/cs224d/tree/master/assignment3)
16 | and https://gist.github.com/wolet/1b49c03968b2c83897a4a15c78980b18
# --------------------------------------------------------------------------------
# /RecursiveNN.py:
# --------------------------------------------------------------------------------
import sys
import random
import progressbar
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
from SenTree import *


class RecursiveNN(nn.Module):
    """Recursive neural network that classifies every node of a binary tree.

    Leaves are embedded by word index; an internal node is computed from the
    concatenation of its two children through a single linear layer.  Every
    node representation is projected to ``numClasses`` scores.

    Args:
        vocabSize: number of rows in the embedding table (cast to ``int``).
        embedSize: size of the node representation.
        numClasses: number of output classes per node.
    """

    def __init__(self, vocabSize, embedSize=100, numClasses=5):
        super(RecursiveNN, self).__init__()
        self.embedding = nn.Embedding(int(vocabSize), embedSize)
        self.W = nn.Linear(2*embedSize, embedSize, bias=True)
        self.projection = nn.Linear(embedSize, numClasses, bias=True)
        self.activation = F.relu
        # Filled by traverse() during a forward pass; reset in forward().
        self.nodeProbList = []
        self.labelList = []

    def traverse(self, node):
        """Recursively compute the representation of ``node``.

        As a side effect, appends the projected class scores and the label of
        every visited node to ``self.nodeProbList`` / ``self.labelList``.
        Children are visited before their parent, so the node passed to the
        outermost call is appended last.

        Returns:
            The activated representation of ``node``.
        """
        if node.isLeaf():
            currentNode = self.activation(
                self.embedding(Var(torch.LongTensor([node.getLeafWord()]))))
        else:
            children = torch.cat(
                (self.traverse(node.left()), self.traverse(node.right())), 1)
            currentNode = self.activation(self.W(children))
        self.nodeProbList.append(self.projection(currentNode))
        self.labelList.append(torch.LongTensor([node.label()]))
        return currentNode

    def forward(self, x):
        """Score every node of tree ``x``.

        Resets and repopulates ``self.labelList`` (left as a single label
        tensor wrapped by ``Var``) and returns the concatenated per-node
        class scores, in the same order as the labels.
        """
        self.nodeProbList = []
        self.labelList = []
        self.traverse(x)
        self.labelList = Var(torch.cat(self.labelList))
        return torch.cat(self.nodeProbList)

    def getLoss(self, tree):
        """Return ``(predictions, loss)`` for one tree.

        ``predictions`` holds the arg-max class of each node; ``loss`` is the
        cross-entropy between the node scores and ``self.labelList``.
        """
        nodes = self.forward(tree)
        predictions = nodes.max(dim=1)[1]
        loss = F.cross_entropy(input=nodes, target=self.labelList)
        return predictions, loss

    def evaluate(self, trees):
        """Compute accuracy over ``trees``.

        Uses the module-level ``widgets`` list for the progress bar.

        Returns:
            ``(root_accuracy, all_node_accuracy)`` where the root of each tree
            is the last node produced by the traversal.
        """
        pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trees)).start()
        n = nAll = correctRoot = correctAll = 0.0
        for j, tree in enumerate(trees):
            predictions, _ = self.getLoss(tree)
            # Flatten instead of squeeze(): squeeze() turns the result for a
            # single-node tree into a 0-d tensor, which cannot be indexed or
            # asked for size()[0].
            correct = (predictions.data == self.labelList.data).view(-1)
            correctAll += int(correct.sum())
            nAll += correct.numel()
            correctRoot += int(correct[-1])
            n += 1
            pbar.update(j)
        pbar.finish()
        return correctRoot / n, correctAll/nAll


def Var(v):
    """Wrap tensor ``v`` in a ``Variable``, moving it to the GPU when CUDA is set."""
    if CUDA:
        return Variable(v.cuda())
    else:
        return Variable(v)


# GPU execution is enabled by passing "cuda" (any case) as the first argument.
CUDA=False
if len(sys.argv)>1:
    if sys.argv[1].lower()=="cuda": CUDA=True

print("Reading and parsing trees")
trn = SenTree.getTrees("./trees/train.txt","train.vocab")
dev = SenTree.getTrees("./trees/dev.txt",vocabIndicesMapFile="train.vocab")

if CUDA: model = RecursiveNN(SenTree.vocabSize).cuda()
else: model = RecursiveNN(SenTree.vocabSize)
max_epochs = 100
widgets = [progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()]

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0)
bestAll=bestRoot=0.0
for epoch in range(max_epochs):
    print("Epoch %d" % epoch)
    pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trn)).start()
    for step, tree in enumerate(trn):
        predictions, loss = model.getLoss(tree)
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm(model.parameters(), 5, norm_type=2.)
        # NOTE(review): the rest of this loop body continues in the listing below.
82 | optimizer.step() 83 | pbar.update(step) 84 | pbar.finish() 85 | correctRoot, correctAll = model.evaluate(dev) 86 | if bestAll1: 81 | if sys.argv[1].lower()=="cuda": CUDA=True 82 | 83 | print("Reading and parsing trees") 84 | trn = SenTree.getTrees("./trees/train.txt","train.vocab") 85 | dev = SenTree.getTrees("./trees/dev.txt",vocabIndicesMapFile="train.vocab") 86 | 87 | if CUDA: model = TreeLSTM(SenTree.vocabSize).cuda() 88 | else: model = TreeLSTM(SenTree.vocabSize) 89 | max_epochs = 100 90 | widgets = [progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()] 91 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0) 92 | bestAll=bestRoot=0.0 93 | for epoch in range(max_epochs): 94 | print("Epoch %d" % epoch) 95 | pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trn)).start() 96 | for step, tree in enumerate(trn): 97 | predictions, loss = model.getLoss(tree) 98 | optimizer.zero_grad() 99 | loss.backward() 100 | clip_grad_norm(model.parameters(), 5, norm_type=2.) 101 | optimizer.step() 102 | pbar.update(step) 103 | pbar.finish() 104 | correctRoot, correctAll = model.evaluate(dev) 105 | if bestAll