├── README.md
├── RecursiveNN.py
├── SenTree.py
├── TreeLSTM.py
└── trees
├── dev.txt
├── small.txt
├── test.txt
└── train.txt
/README.md:
--------------------------------------------------------------------------------
1 | PyTorch implementations of:
2 | 1) RNN: Recursive Neural Network from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf
3 | python RecursiveNN.py
4 | 2) TreeLSTM from https://arxiv.org/abs/1503.00075
5 | python TreeLSTM.py
6 |
7 | Pass cuda as the first command-line argument to run on the GPU (e.g. python TreeLSTM.py cuda).
8 |
9 | Requirements:
10 | nltk
11 | pytorch
12 | progressbar
13 |
14 | Code adapted and combined from:
15 | Socher's CS224d class (see e.g. https://github.com/kingtaurus/cs224d/tree/master/assignment3)
16 | and https://gist.github.com/wolet/1b49c03968b2c83897a4a15c78980b18
17 |
--------------------------------------------------------------------------------
/RecursiveNN.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 | import progressbar
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | from torch.nn.utils import clip_grad_norm
9 | from SenTree import *
10 |
class RecursiveNN(nn.Module):
    """Recursive neural network that classifies every node of a binary parse tree.

    Each leaf is an embedded word, and each internal node applies one linear layer to
    the concatenation of its two children. Every node (leaves included) is passed
    through ``activation`` and then projected to ``numClasses`` scores.
    """

    def __init__(self, vocabSize, embedSize=100, numClasses=5):
        super(RecursiveNN, self).__init__()
        self.embedding = nn.Embedding(int(vocabSize), embedSize)
        # Composition layer: [left; right] (2*embedSize) -> parent (embedSize).
        self.W = nn.Linear(2 * embedSize, embedSize, bias=True)
        self.projection = nn.Linear(embedSize, numClasses, bias=True)
        self.activation = F.relu
        # Per-node class scores (unnormalised logits despite the name) and gold labels.
        # traverse() fills both in post-order, so the root is always the last entry.
        self.nodeProbList = []
        self.labelList = []

    def traverse(self, node):
        """Return the (1, embedSize) hidden vector of ``node``, recording its scores and label.

        Both children are visited before this node is appended. The recorded lists
        therefore end up in post-order, with the root last.
        """
        if node.isLeaf():
            word = Var(torch.LongTensor([node.getLeafWord()]))
            currentNode = self.activation(self.embedding(word))
        else:
            # Left is evaluated before right, as the tuple is built left-to-right.
            children = torch.cat((self.traverse(node.left()), self.traverse(node.right())), 1)
            currentNode = self.activation(self.W(children))
        self.nodeProbList.append(self.projection(currentNode))
        self.labelList.append(torch.LongTensor([node.label()]))
        return currentNode

    def forward(self, x):
        """Score every node of tree ``x``.

        Returns a (numNodes, numClasses) score matrix in post-order. As a side effect,
        ``self.labelList`` is set to the matching 1-D tensor of gold labels.
        """
        self.nodeProbList = []
        self.labelList = []
        self.traverse(x)
        self.labelList = Var(torch.cat(self.labelList))
        return torch.cat(self.nodeProbList)

    def getLoss(self, tree):
        """Return (per-node argmax predictions, cross-entropy loss averaged over all nodes)."""
        nodes = self.forward(tree)
        predictions = nodes.max(dim=1)[1]
        loss = F.cross_entropy(input=nodes, target=self.labelList)
        return predictions, loss

    def evaluate(self, trees):
        """Return (root accuracy, all-node accuracy) as plain Python floats over ``trees``.

        This method relies on the module-level ``widgets`` list for its progress bar.

        Raises:
            ValueError: if ``trees`` is empty, because accuracy is undefined then.
        """
        if len(trees) == 0:
            raise ValueError("evaluate() needs at least one tree")
        pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trees)).start()
        n = nAll = correctRoot = correctAll = 0.0
        for j, tree in enumerate(trees):
            predictions, _ = self.getLoss(tree)
            # view(-1) keeps a one-node tree 1-D. The old squeeze() produced a 0-dim
            # tensor, on which size()[0] raised IndexError.
            correct = (predictions.data == self.labelList.data).view(-1)
            # int() keeps the counters Python numbers rather than (GPU) tensors.
            correctAll += int(correct.sum())
            nAll += correct.numel()
            correctRoot += int(correct[-1])  # the root is recorded last by traverse()
            n += 1
            pbar.update(j)
        pbar.finish()
        return correctRoot / n, correctAll / nAll
54 |
def Var(v):
    """Wrap tensor ``v`` in an autograd Variable, first copying it to the GPU when ``CUDA`` is set."""
    tensor = v.cuda() if CUDA else v
    return Variable(tensor)
58 |
# Run on the GPU only when the first command-line argument is "cuda" (case-insensitive).
CUDA = len(sys.argv) > 1 and sys.argv[1].lower() == "cuda"

print("Reading and parsing trees")
# The dev set is parsed against the same vocabulary file as the training set
# (presumably so that word indices agree; see SenTree.getTrees).
trn = SenTree.getTrees("./trees/train.txt", "train.vocab")
dev = SenTree.getTrees("./trees/dev.txt", vocabIndicesMapFile="train.vocab")

# SenTree.vocabSize is read only after the trees are loaded. Keep this order,
# since getTrees presumably populates it.
model = RecursiveNN(SenTree.vocabSize)
if CUDA:
    model = model.cuda()
max_epochs = 100
widgets = [progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()]

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0)
bestRoot = 0.0
bestAll = 0.0
74 | for epoch in range(max_epochs):
75 | print("Epoch %d" % epoch)
76 | pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trn)).start()
77 | for step, tree in enumerate(trn):
78 | predictions, loss = model.getLoss(tree)
79 | optimizer.zero_grad()
80 | loss.backward()
81 | clip_grad_norm(model.parameters(), 5, norm_type=2.)
82 | optimizer.step()
83 | pbar.update(step)
84 | pbar.finish()
85 | correctRoot, correctAll = model.evaluate(dev)
86 | if bestAll1:
81 | if sys.argv[1].lower()=="cuda": CUDA=True
82 |
print("Reading and parsing trees")
# The dev set is parsed against the same vocabulary file as the training set
# (presumably so that word indices agree; see SenTree.getTrees).
trn = SenTree.getTrees("./trees/train.txt", "train.vocab")
dev = SenTree.getTrees("./trees/dev.txt", vocabIndicesMapFile="train.vocab")

# Build the model after loading, because SenTree.vocabSize is read here.
model = TreeLSTM(SenTree.vocabSize)
if CUDA:
    model = model.cuda()
max_epochs = 100
widgets = [progressbar.Percentage(), ' ', progressbar.Bar(), ' ', progressbar.ETA()]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0)
bestRoot = 0.0
bestAll = 0.0
93 | for epoch in range(max_epochs):
94 | print("Epoch %d" % epoch)
95 | pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(trn)).start()
96 | for step, tree in enumerate(trn):
97 | predictions, loss = model.getLoss(tree)
98 | optimizer.zero_grad()
99 | loss.backward()
100 | clip_grad_norm(model.parameters(), 5, norm_type=2.)
101 | optimizer.step()
102 | pbar.update(step)
103 | pbar.finish()
104 | correctRoot, correctAll = model.evaluate(dev)
105 | if bestAll