├── saved_model └── placeholder ├── .gitignore ├── Constants.py ├── fetch_and_preprocess.sh ├── metrics.py ├── LICENSE ├── tree.py ├── lib ├── CollapseUnaryTransformer.java ├── DependencyParse.java └── ConstituencyParse.java ├── README.md ├── vocab.py ├── utils.py ├── config.py ├── scripts ├── download.py └── preprocess-sst.py ├── trainer.py ├── dataset.py ├── sentiment.py └── model.py /saved_model/placeholder: -------------------------------------------------------------------------------- 1 | A meowing place to make sure this folder exist on git 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | *.pyc 4 | 5 | *.pth 6 | *.class 7 | 8 | data/ 9 | lib/stanford-parser 10 | lib/stanford-tagger -------------------------------------------------------------------------------- /Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '' 7 | UNK_WORD = '' 8 | BOS_WORD = '' 9 | EOS_WORD = '' -------------------------------------------------------------------------------- /fetch_and_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | python2.7 scripts/download.py 4 | 5 | CLASSPATH="lib:lib/stanford-parser/stanford-parser.jar:lib/stanford-parser/stanford-parser-3.5.1-models.jar" 6 | javac -cp $CLASSPATH lib/*.java 7 | python2.7 scripts/preprocess-sst.py 8 | 9 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable as Var 5 | 6 | class Metrics(): 7 | def __init__(self, num_classes): 8 | self.num_classes = num_classes 9 | 10 | def pearson(self, predictions, labels): 11 | #hack cai nay cho no thanh accuracy 12 | x = deepcopy(predictions) 13 | y = deepcopy(labels) 14 | x -= x.mean() 15 | x /= x.std() 16 | y -= y.mean() # FIXME: 'list' object has no attribute 'mean' 17 | # label is a list, not tensor 18 | y /= y.std() 19 | return torch.mean(torch.mul(x,y)) 20 | 21 | def mse(self, predictions, labels): 22 | x = Var(deepcopy(predictions), volatile=True) 23 | y = Var(deepcopy(labels), volatile=True) 24 | return nn.MSELoss()(x,y).data[0] 25 | 26 | def sentiment_accuracy_score(self, predictions, labels, fine_gained = True): 27 | correct = (predictions==labels).sum() 28 | total = labels.size(0) 29 | acc = float(correct)/total 30 | return acc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Riddhiman Dasgupta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the 
Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /tree.py: -------------------------------------------------------------------------------- 1 | # tree object from stanfordnlp/treelstm 2 | class Tree(object): 3 | def __init__(self): 4 | self.parent = None 5 | self.num_children = 0 6 | self.children = list() 7 | self.gold_label = None # node label for SST 8 | self.output = None # output node for SST 9 | 10 | def add_child(self,child): 11 | child.parent = self 12 | self.num_children += 1 13 | self.children.append(child) 14 | 15 | def size(self): 16 | if getattr(self,'_size'): 17 | return self._size 18 | count = 1 19 | for i in xrange(self.num_children): 20 | count += self.children[i].size() 21 | self._size = count 22 | return self._size 23 | 24 | def depth(self): 25 | if getattr(self,'_depth'): 26 | return self._depth 27 | count = 0 28 | if self.num_children>0: 29 | for i in xrange(self.num_children): 30 | child_depth = self.children[i].depth() 31 | if child_depth>count: 32 | count = child_depth 33 | count += 1 34 | self._depth = count 35 | return self._depth 36 | -------------------------------------------------------------------------------- /lib/CollapseUnaryTransformer.java: -------------------------------------------------------------------------------- 1 | import java.util.List; 2 | 3 | import edu.stanford.nlp.ling.Label; 4 | import edu.stanford.nlp.trees.Tree; 5 | import edu.stanford.nlp.trees.TreeTransformer; 6 | import edu.stanford.nlp.util.Generics; 7 | 8 | /** 9 | * This transformer collapses chains of unary nodes so that the top 10 | * node is the only node left. The Sentiment model does not handle 11 | * unary nodes, so this simplifies them to make a binary tree consist 12 | * entirely of binary nodes and preterminals. A new tree with new 13 | * nodes and labels is returned; the original tree is unchanged. 14 | * 15 | * @author John Bauer 16 | */ 17 | public class CollapseUnaryTransformer implements TreeTransformer { 18 | public Tree transformTree(Tree tree) { 19 | if (tree.isPreTerminal() || tree.isLeaf()) { 20 | return tree.deepCopy(); 21 | } 22 | 23 | Label label = tree.label().labelFactory().newLabel(tree.label()); 24 | Tree[] children = tree.children(); 25 | while (children.length == 1 && !children[0].isLeaf()) { 26 | children = children[0].children(); 27 | } 28 | List processedChildren = Generics.newArrayList(); 29 | for (Tree child : children) { 30 | processedChildren.add(transformTree(child)); 31 | } 32 | return tree.treeFactory().newTreeNode(label, processedChildren); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tree-Structured Long Short-Term Memory Networks 2 | A [PyTorch](http://pytorch.org/) based implementation of Tree-LSTM from Kai Sheng Tai's paper 3 | [Improved Semantic Representations From Tree-Structured Long Short-Term Memory 4 | Networks](http://arxiv.org/abs/1503.00075). 
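At its core, the child-sum Tree-LSTM composes the representation of each tree node from the node's word embedding and the hidden/cell states of its children. A minimal sketch of that node update is shown below; it is illustrative only, written against a current PyTorch API with placeholder layer names, and is not the repository's actual `model.py`.

```python
import torch
import torch.nn as nn

class ChildSumTreeLSTMCell(nn.Module):
    """Single node update of a child-sum Tree-LSTM (Tai et al., 2015) -- illustrative sketch."""
    def __init__(self, in_dim, mem_dim):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.ioux = nn.Linear(in_dim, 3 * mem_dim)   # input/output/update gates from the node input x_j
        self.iouh = nn.Linear(mem_dim, 3 * mem_dim)  # same gates from the summed child hidden states
        self.fx = nn.Linear(in_dim, mem_dim)         # forget-gate contribution from x_j
        self.fh = nn.Linear(mem_dim, mem_dim)        # one forget gate per child hidden state h_k

    def forward(self, x, child_c, child_h):
        # x: (1, in_dim); child_c, child_h: (num_children, mem_dim); pass zeros for leaves
        h_sum = child_h.sum(dim=0, keepdim=True)
        i, o, u = torch.chunk(self.ioux(x) + self.iouh(h_sum), 3, dim=1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        f = torch.sigmoid(self.fx(x) + self.fh(child_h))        # broadcasts over the children
        c = i * u + (f * child_c).sum(dim=0, keepdim=True)
        h = o * torch.tanh(c)
        return c, h

# Leaf-node example: no real children, so zero child states.
cell = ChildSumTreeLSTMCell(in_dim=300, mem_dim=150)
c, h = cell(torch.randn(1, 300), torch.zeros(1, 150), torch.zeros(1, 150))
```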
5 | 6 | ### Requirements 7 | - [PyTorch](http://pytorch.org/) Deep learning library 8 | - [tqdm](https://github.com/tqdm/tqdm): displays a progress bar 9 | - [meowlogtool](https://pypi.python.org/pypi/meowlogtool): a logger that mirrors console output to a file 10 | - Java >= 8 (for Stanford CoreNLP utilities) 11 | - Python >= 3 12 | 13 | ## Usage 14 | First run the script `./fetch_and_preprocess.sh` 15 | 16 | This downloads the following data: 17 | - [Stanford Sentiment Treebank](http://nlp.stanford.edu/sentiment/index.html) (sentiment classification task) 18 | - [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840B) -- **Warning:** this is a 2GB download! 19 | 20 | and the following libraries: 21 | 22 | - [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml) 23 | - [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml) 24 | 25 | ### Sentiment classification 26 | 27 | ``` 28 | python sentiment.py --name <run_name> --model_name <constituency|dependency> --epochs 10 29 | ``` 30 | Fine-grained classification has not been fully tested yet. Binary classification accuracy for both models matches the original paper. 31 | 32 | ### Acknowledgements 33 | [Kai Sheng Tai](https://github.com/kaishengtai/) for the [original LuaTorch implementation](https://github.com/stanfordnlp/treelstm)
34 | [PyTorch team](https://github.com/pytorch/pytorch#the-team) for the PyTorch library
35 | [Riddhiman Dasgupta](https://researchweb.iiit.ac.in/~riddhiman.dasgupta/) for his implement on sentiment relatedness [https://github.com/dasguptar/treelstm.pytorch](https://github.com/dasguptar/treelstm.pytorch) which I based on as starter code. 36 | 37 | 38 | 39 | 40 | 41 | 42 | ### License 43 | MIT 44 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # vocab object from harvardnlp/opennmt-py 2 | class Vocab(object): 3 | def __init__(self, filename=None, data=None, lower=False): 4 | self.idxToLabel = {} 5 | self.labelToIdx = {} 6 | self.lower = lower 7 | 8 | # Special entries will not be pruned. 9 | self.special = [] 10 | 11 | if data is not None: 12 | self.addSpecials(data) 13 | if filename is not None: 14 | self.loadFile(filename) 15 | 16 | def size(self): 17 | return len(self.idxToLabel) 18 | 19 | # Load entries from a file. 20 | def loadFile(self, filename): 21 | idx = 0 22 | for line in open(filename): 23 | token = line.rstrip('\n') 24 | self.add(token) 25 | idx += 1 26 | 27 | def getIndex(self, key, default=None): 28 | if self.lower: 29 | key = key.lower() 30 | try: 31 | return self.labelToIdx[key] 32 | except KeyError: 33 | return default 34 | 35 | def getLabel(self, idx, default=None): 36 | try: 37 | return self.idxToLabel[idx] 38 | except KeyError: 39 | return default 40 | 41 | # Mark this `label` and `idx` as special 42 | def addSpecial(self, label, idx=None): 43 | idx = self.add(label) 44 | self.special += [idx] 45 | 46 | # Mark all labels in `labels` as specials 47 | def addSpecials(self, labels): 48 | for label in labels: 49 | self.addSpecial(label) 50 | 51 | # Add `label` in the dictionary. Use `idx` as its index if given. 52 | def add(self, label): 53 | if self.lower: 54 | label = label.lower() 55 | 56 | if label in self.labelToIdx: 57 | idx = self.labelToIdx[label] 58 | else: 59 | idx = len(self.idxToLabel) 60 | self.idxToLabel[idx] = label 61 | self.labelToIdx[label] = idx 62 | return idx 63 | 64 | # Convert `labels` to indices. Use `unkWord` if not found. 65 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 66 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 67 | vec = [] 68 | 69 | if bosWord is not None: 70 | vec += [self.getIndex(bosWord)] 71 | 72 | unk = self.getIndex(unkWord) 73 | vec += [self.getIndex(label, default=unk) for label in labels] 74 | 75 | if eosWord is not None: 76 | vec += [self.getIndex(eosWord)] 77 | 78 | return vec 79 | 80 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 
81 | def convertToLabels(self, idx, stop): 82 | labels = [] 83 | 84 | for i in idx: 85 | labels += [self.getLabel(i)] 86 | if i == stop: 87 | break 88 | 89 | return labels 90 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os, math 4 | import torch 5 | from tree import Tree 6 | from vocab import Vocab 7 | 8 | # loading GLOVE word vectors 9 | # if .pth file is found, will load that 10 | # else will load from .txt file & save 11 | def load_word_vectors(path): 12 | if os.path.isfile(path+'.pth') and os.path.isfile(path+'.vocab'): 13 | print('==> File found, loading to memory') 14 | vectors = torch.load(path+'.pth') 15 | vocab = Vocab(filename=path+'.vocab') 16 | return vocab, vectors 17 | # saved file not found, read from txt file 18 | # and create tensors for word vectors 19 | print('==> File not found, preparing, be patient') 20 | count = sum(1 for line in open(path+'.txt')) 21 | with open(path+'.txt','r') as f: 22 | contents = f.readline().rstrip('\n').split(' ') 23 | dim = len(contents[1:]) 24 | words = [None]*(count) 25 | vectors = torch.zeros(count,dim) 26 | with open(path+'.txt','r') as f: 27 | idx = 0 28 | for line in f: 29 | contents = line.rstrip('\n').split(' ') 30 | words[idx] = contents[0] 31 | #vectors[idx] = torch.Tensor(map(float, contents[1:])) 32 | vectors[idx] = torch.Tensor(list(map(float, contents[1:]))) 33 | idx += 1 34 | with open(path+'.vocab','w') as f: 35 | for word in words: 36 | f.write(word+'\n') 37 | vocab = Vocab(filename=path+'.vocab') 38 | torch.save(vectors, path+'.pth') 39 | return vocab, vectors 40 | 41 | # write unique words from a set of files to a new file 42 | def build_vocab(filenames, vocabfile): 43 | vocab = set() 44 | for filename in filenames: 45 | with open(filename,'r') as f: 46 | for line in f: 47 | tokens = line.rstrip('\n').split(' ') 48 | vocab |= set(tokens) 49 | with open(vocabfile,'w') as f: 50 | for token in vocab: 51 | f.write(token+'\n') 52 | 53 | # mapping from scalar to vector 54 | def map_label_to_target(label,num_classes): 55 | target = torch.Tensor(1,num_classes) 56 | ceil = int(math.ceil(label)) 57 | floor = int(math.floor(label)) 58 | if ceil==floor: 59 | target[0][floor-1] = 1 60 | else: 61 | target[0][floor-1] = ceil - label 62 | target[0][ceil-1] = label - floor 63 | return target 64 | 65 | def map_label_to_target_sentiment(label, num_classes = 0 ,fine_grain = False): 66 | # num_classes not use yet 67 | target = torch.LongTensor(1) 68 | target[0] = int(label) # nothing to do here as we preprocess data 69 | return target 70 | 71 | def count_param(model): 72 | print('_param count_') 73 | params = list(model.parameters()) 74 | sum_param = 0 75 | for p in params: 76 | sum_param+= p.numel() 77 | print (p.size()) 78 | # emb_sum = params[0].numel() 79 | # sum_param-= emb_sum 80 | print ('sum', sum_param) 81 | print('____________') 82 | 83 | def print_tree(tree, level): 84 | indent = '' 85 | for i in range(level): 86 | indent += '| ' 87 | line = indent + str(tree.idx) 88 | print (line) 89 | for i in xrange(tree.num_children): 90 | print_tree(tree.children[i], level+1) 91 | 92 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(type=0): 4 | if type == 0: 5 | parser = 
argparse.ArgumentParser(description='PyTorch TreeLSTM for Sentence Similarity on Dependency Trees') 6 | parser.add_argument('--data', default='data/sick/', 7 | help='path to dataset') 8 | parser.add_argument('--glove', default='data/glove/', 9 | help='directory with GLOVE embeddings') 10 | parser.add_argument('--batchsize', default=25, type=int, 11 | help='batchsize for optimizer updates') 12 | parser.add_argument('--epochs', default=15, type=int, 13 | help='number of total epochs to run') 14 | parser.add_argument('--lr', default=0.01, type=float, 15 | metavar='LR', help='initial learning rate') 16 | parser.add_argument('--wd', default=1e-4, type=float, 17 | help='weight decay (default: 1e-4)') 18 | parser.add_argument('--optim', default='adam', 19 | help='optimizer (default: adam)') 20 | parser.add_argument('--seed', default=123, type=int, 21 | help='random seed (default: 123)') 22 | cuda_parser = parser.add_mutually_exclusive_group(required=False) 23 | cuda_parser.add_argument('--cuda', dest='cuda', action='store_true') 24 | cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false') 25 | parser.set_defaults(cuda=True) 26 | 27 | args = parser.parse_args() 28 | return args 29 | else: # for sentiment classification on SST 30 | parser = argparse.ArgumentParser(description='PyTorch TreeLSTM for Sentiment Analysis Trees') 31 | parser.add_argument('--name', default='default_name', 32 | help='name for log and saved models') 33 | parser.add_argument('--saved', default='saved_model', 34 | help='name for log and saved models') 35 | 36 | parser.add_argument('--model_name', default='constituency', 37 | help='model name constituency or dependency') 38 | parser.add_argument('--data', default='data/sst/', 39 | help='path to dataset') 40 | parser.add_argument('--glove', default='data/glove/', 41 | help='directory with GLOVE embeddings') 42 | parser.add_argument('--batchsize', default=25, type=int, 43 | help='batchsize for optimizer updates') 44 | parser.add_argument('--epochs', default=10, type=int, 45 | help='number of total epochs to run') 46 | parser.add_argument('--lr', default=0.05, type=float, 47 | metavar='LR', help='initial learning rate') 48 | parser.add_argument('--emblr', default=0.1, type=float, 49 | metavar='EMLR', help='initial embedding learning rate') 50 | parser.add_argument('--wd', default=1e-4, type=float, 51 | help='weight decay (default: 1e-4)') 52 | parser.add_argument('--reg', default=1e-4, type=float, 53 | help='l2 regularization (default: 1e-4)') 54 | parser.add_argument('--optim', default='adagrad', 55 | help='optimizer (default: adagrad)') 56 | parser.add_argument('--seed', default=123, type=int, 57 | help='random seed (default: 123)') 58 | parser.add_argument('--fine_grain', default=0, type=int, 59 | help='fine grained (default 0 - binary mode)') 60 | # untest on fine_grain yet. 
61 | cuda_parser = parser.add_mutually_exclusive_group(required=False) 62 | cuda_parser.add_argument('--cuda', dest='cuda', action='store_true') 63 | cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false') 64 | cuda_parser.add_argument('--lower', dest='cuda', action='store_true') 65 | parser.set_defaults(cuda=True) 66 | parser.set_defaults(lower=True) 67 | 68 | args = parser.parse_args() 69 | return args 70 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Downloads the following: 3 | - Stanford parser 4 | - Stanford POS tagger 5 | - Glove vectors 6 | - SICK dataset (semantic relatedness task) 7 | - Stanford Sentiment Treebank (sentiment classification task) 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | import urllib2 13 | import sys 14 | import os 15 | import shutil 16 | import zipfile 17 | import gzip 18 | 19 | def download(url, dirpath): 20 | filename = url.split('/')[-1] 21 | filepath = os.path.join(dirpath, filename) 22 | try: 23 | u = urllib2.urlopen(url) 24 | except: 25 | print("URL %s failed to open" %url) 26 | raise Exception 27 | try: 28 | f = open(filepath, 'wb') 29 | except: 30 | print("Cannot write %s" %filepath) 31 | raise Exception 32 | try: 33 | filesize = int(u.info().getheaders("Content-Length")[0]) 34 | except: 35 | print("URL %s failed to report length" %url) 36 | raise Exception 37 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 38 | 39 | downloaded = 0 40 | block_sz = 8192 41 | status_width = 70 42 | while True: 43 | buf = u.read(block_sz) 44 | if not buf: 45 | print('') 46 | break 47 | else: 48 | print('', end='\r') 49 | downloaded += len(buf) 50 | f.write(buf) 51 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 52 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
/ filesize)) 53 | print(status, end='') 54 | sys.stdout.flush() 55 | f.close() 56 | return filepath 57 | 58 | def unzip(filepath): 59 | print("Extracting: " + filepath) 60 | dirpath = os.path.dirname(filepath) 61 | with zipfile.ZipFile(filepath) as zf: 62 | zf.extractall(dirpath) 63 | os.remove(filepath) 64 | 65 | def download_tagger(dirpath): 66 | tagger_dir = 'stanford-tagger' 67 | if os.path.exists(os.path.join(dirpath, tagger_dir)): 68 | print('Found Stanford POS Tagger - skip') 69 | return 70 | url = 'http://nlp.stanford.edu/software/stanford-postagger-2015-01-29.zip' 71 | filepath = download(url, dirpath) 72 | zip_dir = '' 73 | with zipfile.ZipFile(filepath) as zf: 74 | zip_dir = zf.namelist()[0] 75 | zf.extractall(dirpath) 76 | os.remove(filepath) 77 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, tagger_dir)) 78 | 79 | def download_parser(dirpath): 80 | parser_dir = 'stanford-parser' 81 | if os.path.exists(os.path.join(dirpath, parser_dir)): 82 | print('Found Stanford Parser - skip') 83 | return 84 | url = 'http://nlp.stanford.edu/software/stanford-parser-full-2015-01-29.zip' 85 | filepath = download(url, dirpath) 86 | zip_dir = '' 87 | with zipfile.ZipFile(filepath) as zf: 88 | zip_dir = zf.namelist()[0] 89 | zf.extractall(dirpath) 90 | os.remove(filepath) 91 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, parser_dir)) 92 | 93 | def download_wordvecs(dirpath): 94 | if os.path.exists(dirpath): 95 | print('Found Glove vectors - skip') 96 | return 97 | else: 98 | os.makedirs(dirpath) 99 | url = 'http://www-nlp.stanford.edu/data/glove.840B.300d.zip' 100 | unzip(download(url, dirpath)) 101 | 102 | def download_sst(dirpath): 103 | if os.path.exists(dirpath): 104 | print('Found SST dataset - skip') 105 | return 106 | url = 'http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip' 107 | parent_dir = os.path.dirname(dirpath) 108 | unzip(download(url, parent_dir)) 109 | os.rename( 110 | os.path.join(parent_dir, 'stanfordSentimentTreebank'), 111 | os.path.join(parent_dir, 'sst')) 112 | shutil.rmtree(os.path.join(parent_dir, '__MACOSX')) # remove extraneous dir 113 | 114 | if __name__ == '__main__': 115 | base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 116 | 117 | # data 118 | data_dir = os.path.join(base_dir, 'data') 119 | wordvec_dir = os.path.join(data_dir, 'glove') 120 | sst_dir = os.path.join(data_dir, 'sst') 121 | 122 | # libraries 123 | lib_dir = os.path.join(base_dir, 'lib') 124 | 125 | # download dependencies 126 | download_tagger(lib_dir) 127 | download_parser(lib_dir) 128 | download_wordvecs(wordvec_dir) 129 | download_sst(sst_dir) 130 | -------------------------------------------------------------------------------- /lib/DependencyParse.java: -------------------------------------------------------------------------------- 1 | import edu.stanford.nlp.process.WordTokenFactory; 2 | import edu.stanford.nlp.ling.HasWord; 3 | import edu.stanford.nlp.ling.Word; 4 | import edu.stanford.nlp.ling.TaggedWord; 5 | import edu.stanford.nlp.parser.nndep.DependencyParser; 6 | import edu.stanford.nlp.process.PTBTokenizer; 7 | import edu.stanford.nlp.trees.TypedDependency; 8 | import edu.stanford.nlp.util.StringUtils; 9 | import edu.stanford.nlp.tagger.maxent.MaxentTagger; 10 | 11 | import java.io.BufferedWriter; 12 | import java.io.FileWriter; 13 | import java.io.StringReader; 14 | import java.util.ArrayList; 15 | import java.util.Collection; 16 | import java.util.List; 17 | import java.util.Properties; 18 | import 
java.util.Scanner; 19 | 20 | public class DependencyParse { 21 | 22 | public static final String TAGGER_MODEL = "stanford-tagger/models/english-left3words-distsim.tagger"; 23 | public static final String PARSER_MODEL = "edu/stanford/nlp/models/parser/nndep/english_SD.gz"; 24 | 25 | public static void main(String[] args) throws Exception { 26 | Properties props = StringUtils.argsToProperties(args); 27 | if (!props.containsKey("tokpath") || 28 | !props.containsKey("parentpath") || 29 | !props.containsKey("relpath")) { 30 | System.err.println( 31 | "usage: java DependencyParse -tokenize - -tokpath -parentpath -relpath "); 32 | System.exit(1); 33 | } 34 | 35 | boolean tokenize = false; 36 | if (props.containsKey("tokenize")) { 37 | tokenize = true; 38 | } 39 | 40 | String tokPath = props.getProperty("tokpath"); 41 | String parentPath = props.getProperty("parentpath"); 42 | String relPath = props.getProperty("relpath"); 43 | 44 | BufferedWriter tokWriter = new BufferedWriter(new FileWriter(tokPath)); 45 | BufferedWriter parentWriter = new BufferedWriter(new FileWriter(parentPath)); 46 | BufferedWriter relWriter = new BufferedWriter(new FileWriter(relPath)); 47 | 48 | MaxentTagger tagger = new MaxentTagger(TAGGER_MODEL); 49 | DependencyParser parser = DependencyParser.loadFromModelFile(PARSER_MODEL); 50 | Scanner stdin = new Scanner(System.in); 51 | int count = 0; 52 | long start = System.currentTimeMillis(); 53 | while (stdin.hasNextLine()) { 54 | String line = stdin.nextLine(); 55 | List tokens = new ArrayList<>(); 56 | if (tokenize) { 57 | PTBTokenizer tokenizer = new PTBTokenizer( 58 | new StringReader(line), new WordTokenFactory(), ""); 59 | for (Word label; tokenizer.hasNext(); ) { 60 | tokens.add(tokenizer.next()); 61 | } 62 | } else { 63 | for (String word : line.split(" ")) { 64 | tokens.add(new Word(word)); 65 | } 66 | } 67 | 68 | List tagged = tagger.tagSentence(tokens); 69 | 70 | int len = tagged.size(); 71 | Collection tdl = parser.predict(tagged).typedDependencies(); 72 | int[] parents = new int[len]; 73 | for (int i = 0; i < len; i++) { 74 | // if a node has a parent of -1 at the end of parsing, then the node 75 | // has no parent. 
76 | parents[i] = -1; 77 | } 78 | 79 | String[] relns = new String[len]; 80 | for (TypedDependency td : tdl) { 81 | // let root have index 0 82 | int child = td.dep().index(); 83 | int parent = td.gov().index(); 84 | relns[child - 1] = td.reln().toString(); 85 | parents[child - 1] = parent; 86 | } 87 | 88 | // print tokens 89 | StringBuilder sb = new StringBuilder(); 90 | for (int i = 0; i < len - 1; i++) { 91 | if (tokenize) { 92 | sb.append(PTBTokenizer.ptbToken2Text(tokens.get(i).word())); 93 | } else { 94 | sb.append(tokens.get(i).word()); 95 | } 96 | sb.append(' '); 97 | } 98 | if (tokenize) { 99 | sb.append(PTBTokenizer.ptbToken2Text(tokens.get(len - 1).word())); 100 | } else { 101 | sb.append(tokens.get(len - 1).word()); 102 | } 103 | sb.append('\n'); 104 | tokWriter.write(sb.toString()); 105 | 106 | // print parent pointers 107 | sb = new StringBuilder(); 108 | for (int i = 0; i < len - 1; i++) { 109 | sb.append(parents[i]); 110 | sb.append(' '); 111 | } 112 | sb.append(parents[len - 1]); 113 | sb.append('\n'); 114 | parentWriter.write(sb.toString()); 115 | 116 | // print relations 117 | sb = new StringBuilder(); 118 | for (int i = 0; i < len - 1; i++) { 119 | sb.append(relns[i]); 120 | sb.append(' '); 121 | } 122 | sb.append(relns[len - 1]); 123 | sb.append('\n'); 124 | relWriter.write(sb.toString()); 125 | 126 | count++; 127 | if (count % 1000 == 0) { 128 | double elapsed = (System.currentTimeMillis() - start) / 1000.0; 129 | System.err.printf("Parsed %d lines (%.2fs)\n", count, elapsed); 130 | } 131 | } 132 | 133 | long totalTimeMillis = System.currentTimeMillis() - start; 134 | System.err.printf("Done: %d lines in %.2fs (%.1fms per line)\n", 135 | count, totalTimeMillis / 1000.0, totalTimeMillis / (double) count); 136 | tokWriter.close(); 137 | parentWriter.close(); 138 | relWriter.close(); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from torch.autograd import Variable as Var 4 | from utils import map_label_to_target, map_label_to_target_sentiment 5 | import torch.nn.functional as F 6 | 7 | class SentimentTrainer(object): 8 | """ 9 | For Sentiment module 10 | """ 11 | def __init__(self, args, model, embedding_model ,criterion, optimizer): 12 | super(SentimentTrainer, self).__init__() 13 | self.args = args 14 | self.model = model 15 | self.embedding_model = embedding_model 16 | self.criterion = criterion 17 | self.optimizer = optimizer 18 | self.epoch = 0 19 | 20 | # helper function for training 21 | def train(self, dataset): 22 | self.model.train() 23 | self.embedding_model.train() 24 | self.embedding_model.zero_grad() 25 | self.optimizer.zero_grad() 26 | loss, k = 0.0, 0 27 | # torch.manual_seed(789) 28 | indices = torch.randperm(len(dataset)) 29 | for idx in tqdm(range(len(dataset)),desc='Training epoch '+str(self.epoch+1)+''): 30 | tree, sent, label = dataset[indices[idx]] 31 | input = Var(sent) 32 | target = Var(map_label_to_target_sentiment(label,dataset.num_classes, fine_grain=self.args.fine_grain)) 33 | if self.args.cuda: 34 | input = input.cuda() 35 | target = target.cuda() 36 | emb = F.torch.unsqueeze(self.embedding_model(input), 1) 37 | output, err = self.model.forward(tree, emb, training = True) 38 | #params = self.model.childsumtreelstm.getParameters() 39 | # params_norm = params.norm() 40 | err = err/self.args.batchsize # + 0.5*self.args.reg*params_norm*params_norm # 
custom bias 41 | loss += err.data[0] # 42 | err.backward() 43 | k += 1 44 | if k==self.args.batchsize: 45 | for f in self.embedding_model.parameters(): 46 | f.data.sub_(f.grad.data * self.args.emblr) 47 | self.optimizer.step() 48 | self.embedding_model.zero_grad() 49 | self.optimizer.zero_grad() 50 | k = 0 51 | self.epoch += 1 52 | return loss/len(dataset) 53 | 54 | # helper function for testing 55 | def test(self, dataset): 56 | self.model.eval() 57 | self.embedding_model.eval() 58 | loss = 0 59 | predictions = torch.zeros(len(dataset)) 60 | #predictions = predictions 61 | indices = torch.range(1,dataset.num_classes) 62 | for idx in tqdm(range(len(dataset)),desc='Testing epoch '+str(self.epoch)+''): 63 | tree, sent, label = dataset[idx] 64 | input = Var(sent, volatile=True) 65 | target = Var(map_label_to_target_sentiment(label,dataset.num_classes, fine_grain=self.args.fine_grain), volatile=True) 66 | if self.args.cuda: 67 | input = input.cuda() 68 | target = target.cuda() 69 | emb = F.torch.unsqueeze(self.embedding_model(input),1) 70 | output, _ = self.model(tree, emb) # size(1,5) 71 | err = self.criterion(output, target) 72 | loss += err.data[0] 73 | output[:,1] = -9999 # no need middle (neutral) value 74 | val, pred = torch.max(output, 1) 75 | #predictions[idx] = pred.data.cpu()[0][0] 76 | predictions[idx] = pred.data.cpu()[0] 77 | # predictions[idx] = torch.dot(indices,torch.exp(output.data.cpu())) 78 | return loss/len(dataset), predictions 79 | 80 | 81 | class Trainer(object): 82 | def __init__(self, args, model, criterion, optimizer): 83 | super(Trainer, self).__init__() 84 | self.args = args 85 | self.model = model 86 | self.criterion = criterion 87 | self.optimizer = optimizer 88 | self.epoch = 0 89 | 90 | # helper function for training 91 | def train(self, dataset): 92 | self.model.train() 93 | self.optimizer.zero_grad() 94 | loss, k = 0.0, 0 95 | indices = torch.randperm(len(dataset)) 96 | for idx in tqdm(range(len(dataset)),desc='Training epoch '+str(self.epoch+1)+''): 97 | ltree,lsent,rtree,rsent,label = dataset[indices[idx]] 98 | linput, rinput = Var(lsent), Var(rsent) 99 | target = Var(map_label_to_target(label,dataset.num_classes)) 100 | if self.args.cuda: 101 | linput, rinput = linput.cuda(), rinput.cuda() 102 | target = target.cuda() 103 | output = self.model(ltree,linput,rtree,rinput) 104 | err = self.criterion(output, target) 105 | loss += err.data[0] 106 | err.backward() 107 | k += 1 108 | if k%self.args.batchsize==0: 109 | self.optimizer.step() 110 | self.optimizer.zero_grad() 111 | self.epoch += 1 112 | return loss/len(dataset) 113 | 114 | # helper function for testing 115 | def test(self, dataset): 116 | self.model.eval() 117 | loss = 0 118 | predictions = torch.zeros(len(dataset)) 119 | indices = torch.range(1,dataset.num_classes) 120 | for idx in tqdm(range(len(dataset)),desc='Testing epoch '+str(self.epoch)+''): 121 | ltree,lsent,rtree,rsent,label = dataset[idx] 122 | linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True) 123 | target = Var(map_label_to_target(label,dataset.num_classes), volatile=True) 124 | if self.args.cuda: 125 | linput, rinput = linput.cuda(), rinput.cuda() 126 | target = target.cuda() 127 | output = self.model(ltree,linput,rtree,rinput) 128 | err = self.criterion(output, target) 129 | loss += err.data[0] 130 | predictions[idx] = torch.dot(indices,torch.exp(output.data.cpu())) 131 | return loss/len(dataset), predictions 132 | -------------------------------------------------------------------------------- 
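A note on `SentimentTrainer.train` above: it emulates mini-batches by backpropagating one tree at a time, scaling each loss by `1/batchsize`, and stepping the optimizer only after `batchsize` trees; the embedding matrix is updated by hand with plain SGD at rate `emblr`, while the configured optimizer handles the remaining parameters. A condensed sketch of that pattern, using hypothetical standalone names rather than the trainer's exact attributes:

```python
import torch

def train_minibatch(model, embedding_model, optimizer, examples, batchsize, emblr):
    """Accumulate gradients example-by-example, then apply one combined update."""
    optimizer.zero_grad()
    embedding_model.zero_grad()
    for tree, sent, _ in examples[:batchsize]:
        emb = embedding_model(sent).unsqueeze(1)     # (seq_len, 1, emb_dim), as in the trainer
        _, err = model(tree, emb, training=True)     # model returns (output, loss), as in trainer.py
        (err / batchsize).backward()                 # gradients sum across the pseudo-batch
    for p in embedding_model.parameters():           # manual SGD step for the embedding weights
        p.data.sub_(emblr * p.grad.data)
    optimizer.step()                                 # configured optimizer updates everything else
```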
/lib/ConstituencyParse.java: -------------------------------------------------------------------------------- 1 | import edu.stanford.nlp.process.WordTokenFactory; 2 | import edu.stanford.nlp.ling.HasWord; 3 | import edu.stanford.nlp.ling.Word; 4 | import edu.stanford.nlp.ling.CoreLabel; 5 | import edu.stanford.nlp.process.PTBTokenizer; 6 | import edu.stanford.nlp.util.StringUtils; 7 | import edu.stanford.nlp.parser.lexparser.LexicalizedParser; 8 | import edu.stanford.nlp.parser.lexparser.TreeBinarizer; 9 | import edu.stanford.nlp.trees.GrammaticalStructure; 10 | import edu.stanford.nlp.trees.GrammaticalStructureFactory; 11 | import edu.stanford.nlp.trees.PennTreebankLanguagePack; 12 | import edu.stanford.nlp.trees.Tree; 13 | import edu.stanford.nlp.trees.Trees; 14 | import edu.stanford.nlp.trees.TreebankLanguagePack; 15 | import edu.stanford.nlp.trees.TypedDependency; 16 | 17 | import java.io.BufferedWriter; 18 | import java.io.FileWriter; 19 | import java.io.StringReader; 20 | import java.io.IOException; 21 | import java.util.ArrayList; 22 | import java.util.Collection; 23 | import java.util.List; 24 | import java.util.HashMap; 25 | import java.util.Properties; 26 | import java.util.Scanner; 27 | 28 | public class ConstituencyParse { 29 | 30 | private boolean tokenize; 31 | private BufferedWriter tokWriter, parentWriter; 32 | private LexicalizedParser parser; 33 | private TreeBinarizer binarizer; 34 | private CollapseUnaryTransformer transformer; 35 | private GrammaticalStructureFactory gsf; 36 | 37 | private static final String PCFG_PATH = "edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz"; 38 | 39 | public ConstituencyParse(String tokPath, String parentPath, boolean tokenize) throws IOException { 40 | this.tokenize = tokenize; 41 | if (tokPath != null) { 42 | tokWriter = new BufferedWriter(new FileWriter(tokPath)); 43 | } 44 | parentWriter = new BufferedWriter(new FileWriter(parentPath)); 45 | parser = LexicalizedParser.loadModel(PCFG_PATH); 46 | binarizer = TreeBinarizer.simpleTreeBinarizer( 47 | parser.getTLPParams().headFinder(), parser.treebankLanguagePack()); 48 | transformer = new CollapseUnaryTransformer(); 49 | 50 | // set up to produce dependency representations from constituency trees 51 | TreebankLanguagePack tlp = new PennTreebankLanguagePack(); 52 | gsf = tlp.grammaticalStructureFactory(); 53 | } 54 | 55 | public List sentenceToTokens(String line) { 56 | List tokens = new ArrayList<>(); 57 | if (tokenize) { 58 | PTBTokenizer tokenizer = new PTBTokenizer(new StringReader(line), new WordTokenFactory(), ""); 59 | for (Word label; tokenizer.hasNext(); ) { 60 | tokens.add(tokenizer.next()); 61 | } 62 | } else { 63 | for (String word : line.split(" ")) { 64 | tokens.add(new Word(word)); 65 | } 66 | } 67 | 68 | return tokens; 69 | } 70 | 71 | public Tree parse(List tokens) { 72 | Tree tree = parser.apply(tokens); 73 | return tree; 74 | } 75 | 76 | public int[] constTreeParents(Tree tree) { 77 | Tree binarized = binarizer.transformTree(tree); 78 | Tree collapsedUnary = transformer.transformTree(binarized); 79 | Trees.convertToCoreLabels(collapsedUnary); 80 | collapsedUnary.indexSpans(); 81 | List leaves = collapsedUnary.getLeaves(); 82 | int size = collapsedUnary.size() - leaves.size(); 83 | int[] parents = new int[size]; 84 | HashMap index = new HashMap(); 85 | 86 | int idx = leaves.size(); 87 | int leafIdx = 0; 88 | for (Tree leaf : leaves) { 89 | Tree cur = leaf.parent(collapsedUnary); // go to preterminal 90 | int curIdx = leafIdx++; 91 | boolean done = false; 92 | while 
(!done) { 93 | Tree parent = cur.parent(collapsedUnary); 94 | if (parent == null) { 95 | parents[curIdx] = 0; 96 | break; 97 | } 98 | 99 | int parentIdx; 100 | int parentNumber = parent.nodeNumber(collapsedUnary); 101 | if (!index.containsKey(parentNumber)) { 102 | parentIdx = idx++; 103 | index.put(parentNumber, parentIdx); 104 | } else { 105 | parentIdx = index.get(parentNumber); 106 | done = true; 107 | } 108 | 109 | parents[curIdx] = parentIdx + 1; 110 | cur = parent; 111 | curIdx = parentIdx; 112 | } 113 | } 114 | 115 | return parents; 116 | } 117 | 118 | // convert constituency parse to a dependency representation and return the 119 | // parent pointer representation of the tree 120 | public int[] depTreeParents(Tree tree, List tokens) { 121 | GrammaticalStructure gs = gsf.newGrammaticalStructure(tree); 122 | Collection tdl = gs.typedDependencies(); 123 | int len = tokens.size(); 124 | int[] parents = new int[len]; 125 | for (int i = 0; i < len; i++) { 126 | // if a node has a parent of -1 at the end of parsing, then the node 127 | // has no parent. 128 | parents[i] = -1; 129 | } 130 | 131 | for (TypedDependency td : tdl) { 132 | // let root have index 0 133 | int child = td.dep().index(); 134 | int parent = td.gov().index(); 135 | parents[child - 1] = parent; 136 | } 137 | 138 | return parents; 139 | } 140 | 141 | public void printTokens(List tokens) throws IOException { 142 | int len = tokens.size(); 143 | StringBuilder sb = new StringBuilder(); 144 | for (int i = 0; i < len - 1; i++) { 145 | if (tokenize) { 146 | sb.append(PTBTokenizer.ptbToken2Text(tokens.get(i).word())); 147 | } else { 148 | sb.append(tokens.get(i).word()); 149 | } 150 | sb.append(' '); 151 | } 152 | 153 | if (tokenize) { 154 | sb.append(PTBTokenizer.ptbToken2Text(tokens.get(len - 1).word())); 155 | } else { 156 | sb.append(tokens.get(len - 1).word()); 157 | } 158 | 159 | sb.append('\n'); 160 | tokWriter.write(sb.toString()); 161 | } 162 | 163 | public void printParents(int[] parents) throws IOException { 164 | StringBuilder sb = new StringBuilder(); 165 | int size = parents.length; 166 | for (int i = 0; i < size - 1; i++) { 167 | sb.append(parents[i]); 168 | sb.append(' '); 169 | } 170 | sb.append(parents[size - 1]); 171 | sb.append('\n'); 172 | parentWriter.write(sb.toString()); 173 | } 174 | 175 | public void close() throws IOException { 176 | if (tokWriter != null) tokWriter.close(); 177 | parentWriter.close(); 178 | } 179 | 180 | public static void main(String[] args) throws Exception { 181 | Properties props = StringUtils.argsToProperties(args); 182 | if (!props.containsKey("parentpath")) { 183 | System.err.println( 184 | "usage: java ConstituencyParse -deps - -tokenize - -tokpath -parentpath "); 185 | System.exit(1); 186 | } 187 | 188 | // whether to tokenize input sentences 189 | boolean tokenize = false; 190 | if (props.containsKey("tokenize")) { 191 | tokenize = true; 192 | } 193 | 194 | // whether to produce dependency trees from the constituency parse 195 | boolean deps = false; 196 | if (props.containsKey("deps")) { 197 | deps = true; 198 | } 199 | 200 | String tokPath = props.containsKey("tokpath") ? 
props.getProperty("tokpath") : null; 201 | String parentPath = props.getProperty("parentpath"); 202 | ConstituencyParse processor = new ConstituencyParse(tokPath, parentPath, tokenize); 203 | 204 | Scanner stdin = new Scanner(System.in); 205 | int count = 0; 206 | long start = System.currentTimeMillis(); 207 | while (stdin.hasNextLine()) { 208 | String line = stdin.nextLine(); 209 | List tokens = processor.sentenceToTokens(line); 210 | Tree parse = processor.parse(tokens); 211 | 212 | // produce parent pointer representation 213 | int[] parents = deps ? processor.depTreeParents(parse, tokens) 214 | : processor.constTreeParents(parse); 215 | 216 | // print 217 | if (tokPath != null) { 218 | processor.printTokens(tokens); 219 | } 220 | processor.printParents(parents); 221 | 222 | count++; 223 | if (count % 1000 == 0) { 224 | double elapsed = (System.currentTimeMillis() - start) / 1000.0; 225 | System.err.printf("Parsed %d lines (%.2fs)\n", count, elapsed); 226 | } 227 | } 228 | 229 | long totalTimeMillis = System.currentTimeMillis() - start; 230 | System.err.printf("Done: %d lines in %.2fs (%.1fms per line)\n", 231 | count, totalTimeMillis / 1000.0, totalTimeMillis / (double) count); 232 | processor.close(); 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from tqdm import tqdm 4 | import torch 5 | import torch.utils.data as data 6 | from tree import Tree 7 | from vocab import Vocab 8 | import Constants 9 | import utils 10 | 11 | # Dataset class for SICK dataset 12 | class SICKDataset(data.Dataset): 13 | def __init__(self, path, vocab, num_classes): 14 | super(SICKDataset, self).__init__() 15 | self.vocab = vocab 16 | self.num_classes = num_classes 17 | 18 | self.lsentences = self.read_sentences(os.path.join(path,'a.toks')) 19 | self.rsentences = self.read_sentences(os.path.join(path,'b.toks')) 20 | 21 | self.ltrees = self.read_trees(os.path.join(path,'a.parents')) 22 | self.rtrees = self.read_trees(os.path.join(path,'b.parents')) 23 | 24 | self.labels = self.read_labels(os.path.join(path,'sim.txt')) 25 | 26 | self.size = self.labels.size(0) 27 | 28 | def __len__(self): 29 | return self.size 30 | 31 | def __getitem__(self, index): 32 | ltree = deepcopy(self.ltrees[index]) 33 | rtree = deepcopy(self.rtrees[index]) 34 | lsent = deepcopy(self.lsentences[index]) 35 | rsent = deepcopy(self.rsentences[index]) 36 | label = deepcopy(self.labels[index]) 37 | return (ltree,lsent,rtree,rsent,label) 38 | 39 | def read_sentences(self, filename): 40 | with open(filename,'r') as f: 41 | sentences = [self.read_sentence(line) for line in tqdm(f.readlines())] 42 | return sentences 43 | 44 | def read_sentence(self, line): 45 | indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD) 46 | return torch.LongTensor(indices) 47 | 48 | def read_trees(self, filename): 49 | with open(filename,'r') as f: 50 | trees = [self.read_tree(line) for line in tqdm(f.readlines())] 51 | return trees 52 | 53 | def read_tree(self, line): 54 | parents = map(int,line.split()) 55 | trees = dict() 56 | root = None 57 | for i in xrange(1,len(parents)+1): 58 | #if not trees[i-1] and parents[i-1]!=-1: 59 | if i-1 not in trees.keys() and parents[i-1]!=-1: 60 | idx = i 61 | prev = None 62 | while True: 63 | parent = parents[idx-1] 64 | if parent == -1: 65 | break 66 | tree = Tree() 67 | if prev is not None: 68 | tree.add_child(prev) 69 | 
trees[idx-1] = tree 70 | tree.idx = idx-1 71 | #if trees[parent-1] is not None: 72 | if parent-1 in trees.keys(): 73 | trees[parent-1].add_child(tree) 74 | break 75 | elif parent==0: 76 | root = tree 77 | break 78 | else: 79 | prev = tree 80 | idx = parent 81 | return root 82 | 83 | def read_labels(self, filename): 84 | with open(filename,'r') as f: 85 | labels = map(lambda x: float(x), f.readlines()) 86 | labels = torch.Tensor(labels) 87 | return labels 88 | 89 | # Dataset class for SICK dataset 90 | class SSTDataset(data.Dataset): 91 | def __init__(self, path, vocab, num_classes, fine_grain, model_name): 92 | super(SSTDataset, self).__init__() 93 | self.vocab = vocab 94 | self.num_classes = num_classes 95 | self.fine_grain = fine_grain 96 | self.model_name = model_name 97 | 98 | temp_sentences = self.read_sentences(os.path.join(path,'sents.toks')) 99 | if model_name == "dependency": 100 | temp_trees = self.read_trees(os.path.join(path,'dparents.txt'), os.path.join(path,'dlabels.txt')) 101 | else: 102 | temp_trees = self.read_trees(os.path.join(path, 'parents.txt'), os.path.join(path, 'labels.txt')) 103 | 104 | # self.labels = self.read_labels(os.path.join(path,'dlabels.txt')) 105 | self.labels = [] 106 | 107 | if not self.fine_grain: 108 | # only get pos or neg 109 | new_trees = [] 110 | new_sentences = [] 111 | for i in range(len(temp_trees)): 112 | if temp_trees[i].gold_label != 1: # 0 neg, 1 neutral, 2 pos 113 | new_trees.append(temp_trees[i]) 114 | new_sentences.append(temp_sentences[i]) 115 | self.trees = new_trees 116 | self.sentences = new_sentences 117 | else: 118 | self.trees = temp_trees 119 | self.sentences = temp_sentences 120 | 121 | for i in range(0, len(self.trees)): 122 | self.labels.append(self.trees[i].gold_label) 123 | self.labels = torch.Tensor(self.labels) # let labels be tensor 124 | self.size = len(self.trees) 125 | 126 | def __len__(self): 127 | return self.size 128 | 129 | def __getitem__(self, index): 130 | # ltree = deepcopy(self.ltrees[index]) 131 | # rtree = deepcopy(self.rtrees[index]) 132 | # lsent = deepcopy(self.lsentences[index]) 133 | # rsent = deepcopy(self.rsentences[index]) 134 | # label = deepcopy(self.labels[index]) 135 | tree = deepcopy(self.trees[index]) 136 | sent = deepcopy(self.sentences[index]) 137 | label = deepcopy(self.labels[index]) 138 | return (tree, sent, label) 139 | 140 | def read_sentences(self, filename): 141 | with open(filename,'r') as f: 142 | sentences = [self.read_sentence(line) for line in tqdm(f.readlines())] 143 | return sentences 144 | 145 | def read_sentence(self, line): 146 | indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD) 147 | return torch.LongTensor(indices) 148 | 149 | def read_trees(self, filename_parents, filename_labels): 150 | pfile = open(filename_parents, 'r') # parent node 151 | lfile = open(filename_labels, 'r') # label node 152 | p = pfile.readlines() 153 | l = lfile.readlines() 154 | pl = zip(p, l) # (parent, label) tuple 155 | trees = [self.read_tree(p_line, l_line) for p_line, l_line in tqdm(pl)] 156 | 157 | return trees 158 | 159 | def parse_dlabel_token(self, x): 160 | if x == '#': 161 | return None 162 | else: 163 | if self.fine_grain: # -2 -1 0 1 2 => 0 1 2 3 4 164 | return int(x)+2 165 | else: # # -2 -1 0 1 2 => 0 1 2 166 | tmp = int(x) 167 | if tmp < 0: 168 | return 0 169 | elif tmp == 0: 170 | return 1 171 | elif tmp >0 : 172 | return 2 173 | 174 | def read_tree(self, line, label_line): 175 | # FIXED: tree.idx, also tree dict() use base 1 as it was in dataset 176 | # parents is 
list base 0, keep idx-1 177 | # labels is list base 0, keep idx-1 178 | #parents = map(int,line.split()) # split each number and turn to int 179 | parents = list(map(int,line.split())) # split each number and turn to int 180 | trees = dict() # this is dict 181 | root = None 182 | #labels = map(self.parse_dlabel_token, label_line.split()) 183 | labels = list(map(self.parse_dlabel_token, label_line.split())) 184 | for i in range(1,len(parents)+1): 185 | #for i in range(1,len(list(parents))+1): 186 | #if not trees[i-1] and parents[i-1]!=-1: 187 | if i not in trees.keys() and parents[i-1]!=-1: 188 | idx = i 189 | prev = None 190 | while True: 191 | parent = parents[idx-1] 192 | if parent == -1: 193 | break 194 | tree = Tree() 195 | if prev is not None: 196 | tree.add_child(prev) 197 | trees[idx] = tree 198 | tree.idx = idx # -1 remove -1 here to prevent embs[tree.idx -1] = -1 while tree.idx = 0 199 | tree.gold_label = labels[idx-1] # add node label 200 | #if trees[parent-1] is not None: 201 | if parent in trees.keys(): 202 | trees[parent].add_child(tree) 203 | break 204 | elif parent==0: 205 | root = tree 206 | break 207 | else: 208 | prev = tree 209 | idx = parent 210 | return root 211 | 212 | def read_labels(self, filename): 213 | # Not in used 214 | with open(filename,'r') as f: 215 | labels = map(lambda x: float(x), f.readlines()) 216 | labels = torch.Tensor(labels) 217 | return labels 218 | -------------------------------------------------------------------------------- /sentiment.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os, time, argparse 4 | from tqdm import tqdm 5 | import numpy 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable as Var 10 | import utils 11 | import gc 12 | import sys 13 | from meowlogtool import log_util 14 | 15 | 16 | # IMPORT CONSTANTS 17 | import Constants 18 | # NEURAL NETWORK MODULES/LAYERS 19 | from model import * 20 | # DATA HANDLING CLASSES 21 | from tree import Tree 22 | from vocab import Vocab 23 | # DATASET CLASS FOR SICK DATASET 24 | from dataset import SSTDataset 25 | # METRICS CLASS FOR EVALUATION 26 | from metrics import Metrics 27 | # UTILITY FUNCTIONS 28 | from utils import load_word_vectors, build_vocab 29 | # CONFIG PARSER 30 | from config import parse_args 31 | # TRAIN AND TEST HELPER FUNCTIONS 32 | from trainer import SentimentTrainer 33 | 34 | # MAIN BLOCK 35 | def main(): 36 | global args 37 | args = parse_args(type=1) 38 | args.input_dim= 300 39 | if args.model_name == 'dependency': 40 | args.mem_dim = 168 41 | elif args.model_name == 'constituency': 42 | args.mem_dim = 150 43 | if args.fine_grain: 44 | args.num_classes = 5 # 0 1 2 3 4 45 | else: 46 | args.num_classes = 3 # 0 1 2 (1 neutral) 47 | args.cuda = args.cuda and torch.cuda.is_available() 48 | # args.cuda = False 49 | print(args) 50 | # torch.manual_seed(args.seed) 51 | # if args.cuda: 52 | # torch.cuda.manual_seed(args.seed) 53 | 54 | train_dir = os.path.join(args.data,'train/') 55 | dev_dir = os.path.join(args.data,'dev/') 56 | test_dir = os.path.join(args.data,'test/') 57 | 58 | # write unique words from all token files 59 | token_files = [os.path.join(split, 'sents.toks') for split in [train_dir, dev_dir, test_dir]] 60 | vocab_file = os.path.join(args.data,'vocab-cased.txt') # use vocab-cased 61 | # build_vocab(token_files, vocab_file) NO, DO NOT BUILD VOCAB, USE OLD VOCAB 62 | 63 | # get vocab object from vocab file 
previously written 64 | vocab = Vocab(filename=vocab_file) 65 | print('==> SST vocabulary size : %d ' % vocab.size()) 66 | 67 | # Load SST dataset splits 68 | 69 | is_preprocessing_data = False # let program turn off after preprocess data 70 | 71 | # train 72 | train_file = os.path.join(args.data,'sst_train.pth') 73 | if os.path.isfile(train_file): 74 | train_dataset = torch.load(train_file) 75 | else: 76 | train_dataset = SSTDataset(train_dir, vocab, args.num_classes, args.fine_grain, args.model_name) 77 | torch.save(train_dataset, train_file) 78 | is_preprocessing_data = True 79 | 80 | # dev 81 | dev_file = os.path.join(args.data,'sst_dev.pth') 82 | if os.path.isfile(dev_file): 83 | dev_dataset = torch.load(dev_file) 84 | else: 85 | dev_dataset = SSTDataset(dev_dir, vocab, args.num_classes, args.fine_grain, args.model_name) 86 | torch.save(dev_dataset, dev_file) 87 | is_preprocessing_data = True 88 | 89 | # test 90 | test_file = os.path.join(args.data,'sst_test.pth') 91 | if os.path.isfile(test_file): 92 | test_dataset = torch.load(test_file) 93 | else: 94 | test_dataset = SSTDataset(test_dir, vocab, args.num_classes, args.fine_grain, args.model_name) 95 | torch.save(test_dataset, test_file) 96 | is_preprocessing_data = True 97 | 98 | criterion = nn.NLLLoss() 99 | # initialize model, criterion/loss_function, optimizer 100 | model = TreeLSTMSentiment( 101 | args.cuda, vocab.size(), 102 | args.input_dim, args.mem_dim, 103 | args.num_classes, args.model_name, criterion 104 | ) 105 | 106 | embedding_model = nn.Embedding(vocab.size(), args.input_dim) 107 | 108 | if args.cuda: 109 | embedding_model = embedding_model.cuda() 110 | 111 | if args.cuda: 112 | model.cuda(), criterion.cuda() 113 | if args.optim=='adam': 114 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd) 115 | elif args.optim=='adagrad': 116 | # optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd) 117 | optimizer = optim.Adagrad([ 118 | {'params': model.parameters(), 'lr': args.lr} 119 | ], lr=args.lr, weight_decay=args.wd) 120 | metrics = Metrics(args.num_classes) 121 | 122 | utils.count_param(model) 123 | 124 | # for words common to dataset vocab and GLOVE, use GLOVE vectors 125 | # for other words in dataset vocab, use random normal vectors 126 | emb_file = os.path.join(args.data, 'sst_embed.pth') 127 | if os.path.isfile(emb_file): 128 | emb = torch.load(emb_file) 129 | else: 130 | 131 | # load glove embeddings and vocab 132 | glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove,'glove.840B.300d')) 133 | print('==> GLOVE vocabulary size: %d ' % glove_vocab.size()) 134 | 135 | emb = torch.zeros(vocab.size(),glove_emb.size(1)) 136 | 137 | for word in vocab.labelToIdx.keys(): 138 | if glove_vocab.getIndex(word): 139 | emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)] 140 | else: 141 | emb[vocab.getIndex(word)] = torch.Tensor(emb[vocab.getIndex(word)].size()).normal_(-0.05,0.05) 142 | torch.save(emb, emb_file) 143 | is_preprocessing_data = True # flag to quit 144 | print('done creating emb, quit') 145 | 146 | if is_preprocessing_data: 147 | print ('done preprocessing data, quit program to prevent memory leak') 148 | print ('please run again') 149 | quit() 150 | 151 | # plug these into embedding matrix inside model 152 | if args.cuda: 153 | emb = emb.cuda() 154 | 155 | # model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb) 156 | 
embedding_model.state_dict()['weight'].copy_(emb) 157 | 158 | # create trainer object for training and testing 159 | trainer = SentimentTrainer(args, model, embedding_model ,criterion, optimizer) 160 | 161 | mode = 'EXPERIMENT' 162 | if mode == 'DEBUG': 163 | for epoch in range(args.epochs): 164 | dev_loss = trainer.train(dev_dataset) 165 | dev_loss, dev_pred = trainer.test(dev_dataset) 166 | test_loss, test_pred = trainer.test(test_dataset) 167 | 168 | dev_acc = metrics.sentiment_accuracy_score(dev_pred, dev_dataset.labels) 169 | test_acc = metrics.sentiment_accuracy_score(test_pred, test_dataset.labels) 170 | print('==> Dev loss : %f \t' % dev_loss, end="") 171 | print('Epoch ', epoch, 'dev percentage ', dev_acc) 172 | elif mode == "PRINT_TREE": 173 | for i in range(0, 10): 174 | ttree, tsent, tlabel = dev_dataset[i] 175 | utils.print_tree(ttree, 0) 176 | print('_______________') 177 | print('break') 178 | quit() 179 | elif mode == "EXPERIMENT": 180 | max_dev = 0 181 | max_dev_epoch = 0 182 | filename = args.name + '.pth' 183 | for epoch in range(args.epochs): 184 | train_loss = trainer.train(train_dataset) 185 | dev_loss, dev_pred = trainer.test(dev_dataset) 186 | dev_acc = metrics.sentiment_accuracy_score(dev_pred, dev_dataset.labels) 187 | print('==> Train loss : %f \t' % train_loss, end="") 188 | print('Epoch ', epoch, 'dev percentage ', dev_acc) 189 | torch.save(model, args.saved + str(epoch) + '_model_' + filename) 190 | torch.save(embedding_model, args.saved + str(epoch) + '_embedding_' + filename) 191 | if dev_acc > max_dev: 192 | max_dev = dev_acc 193 | max_dev_epoch = epoch 194 | gc.collect() 195 | print('epoch ' + str(max_dev_epoch) + ' dev score of ' + str(max_dev)) 196 | print('eva on test set ') 197 | model = torch.load(args.saved + str(max_dev_epoch) + '_model_' + filename) 198 | embedding_model = torch.load(args.saved + str(max_dev_epoch) + '_embedding_' + filename) 199 | trainer = SentimentTrainer(args, model, embedding_model, criterion, optimizer) 200 | test_loss, test_pred = trainer.test(test_dataset) 201 | test_acc = metrics.sentiment_accuracy_score(test_pred, test_dataset.labels) 202 | print('Epoch with max dev:' + str(max_dev_epoch) + ' |test percentage ' + str(test_acc)) 203 | print('____________________' + str(args.name) + '___________________') 204 | else: 205 | for epoch in range(args.epochs): 206 | train_loss = trainer.train(train_dataset) 207 | train_loss, train_pred = trainer.test(train_dataset) 208 | dev_loss, dev_pred = trainer.test(dev_dataset) 209 | test_loss, test_pred = trainer.test(test_dataset) 210 | 211 | train_acc = metrics.sentiment_accuracy_score(train_pred, train_dataset.labels) 212 | dev_acc = metrics.sentiment_accuracy_score(dev_pred, dev_dataset.labels) 213 | test_acc = metrics.sentiment_accuracy_score(test_pred, test_dataset.labels) 214 | print('==> Train loss : %f \t' % train_loss, end="") 215 | print('Epoch ', epoch, 'train percentage ', train_acc) 216 | print('Epoch ', epoch, 'dev percentage ', dev_acc) 217 | print('Epoch ', epoch, 'test percentage ', test_acc) 218 | 219 | 220 | if __name__ == "__main__": 221 | # log to console and file 222 | logger1 = log_util.create_logger("temp_file", print_console=True) 223 | logger1.info("LOG_FILE") # log using loggerba 224 | # attach log to stdout (print function) 225 | s1 = log_util.StreamToLogger(logger1) 226 | sys.stdout = s1 227 | print ('_________________________________start___________________________________') 228 | main() 
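One caveat about the checkpoints written by the `EXPERIMENT` branch above: the path is built by plain string concatenation (`args.saved + str(epoch) + '_model_' + filename`), so with the default `--saved saved_model` the files land next to, not inside, the `saved_model/` directory (e.g. `saved_model0_model_default_name.pth`); pass `--saved saved_model/` to write into the folder. A minimal reload sketch, assuming a hypothetical run that used `--saved saved_model/ --name default_name` and whose best dev epoch was 3:

```python
import torch

# The repo modules (model.py etc.) must be importable, since torch.save stored whole objects.
model = torch.load('saved_model/3_model_default_name.pth')
embedding_model = torch.load('saved_model/3_embedding_default_name.pth')
model.eval()
embedding_model.eval()
```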
-------------------------------------------------------------------------------- /scripts/preprocess-sst.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing script for Stanford Sentiment Treebank data. 3 | 4 | """ 5 | 6 | import os 7 | import glob 8 | 9 | # 10 | # Trees and tree loading 11 | # 12 | 13 | class ConstTree(object): 14 | def __init__(self): 15 | self.left = None 16 | self.right = None 17 | 18 | def size(self): 19 | self.size = 1 20 | if self.left is not None: 21 | self.size += self.left.size() 22 | if self.right is not None: 23 | self.size += self.right.size() 24 | return self.size 25 | 26 | def set_spans(self): 27 | if self.word is not None: 28 | self.span = self.word 29 | return self.span 30 | 31 | self.span = self.left.set_spans() 32 | if self.right is not None: 33 | self.span += ' ' + self.right.set_spans() 34 | return self.span 35 | 36 | def get_labels(self, spans, labels, dictionary): 37 | if self.span in dictionary: 38 | spans[self.idx] = self.span 39 | labels[self.idx] = dictionary[self.span] 40 | if self.left is not None: 41 | self.left.get_labels(spans, labels, dictionary) 42 | if self.right is not None: 43 | self.right.get_labels(spans, labels, dictionary) 44 | 45 | class DepTree(object): 46 | def __init__(self): 47 | self.children = [] 48 | self.lo, self.hi = None, None 49 | 50 | def size(self): 51 | self.size = 1 52 | for c in self.children: 53 | self.size += c.size() 54 | return self.size 55 | 56 | def set_spans(self, words): 57 | self.lo, self.hi = self.idx, self.idx + 1 58 | if len(self.children) == 0: 59 | self.span = words[self.idx] 60 | return 61 | for c in self.children: 62 | c.set_spans(words) 63 | self.lo = min(self.lo, c.lo) 64 | self.hi = max(self.hi, c.hi) 65 | self.span = ' '.join(words[self.lo : self.hi]) 66 | 67 | def get_labels(self, spans, labels, dictionary): 68 | if self.span in dictionary: 69 | spans[self.idx] = self.span 70 | labels[self.idx] = dictionary[self.span] 71 | for c in self.children: 72 | c.get_labels(spans, labels, dictionary) 73 | 74 | def load_trees(dirpath): 75 | const_trees, dep_trees, toks = [], [], [] 76 | with open(os.path.join(dirpath, 'parents.txt')) as parentsfile, \ 77 | open(os.path.join(dirpath, 'dparents.txt')) as dparentsfile, \ 78 | open(os.path.join(dirpath, 'sents.txt')) as toksfile: 79 | parents, dparents = [], [] 80 | for line in parentsfile: 81 | parents.append(map(int, line.split())) 82 | for line in dparentsfile: 83 | dparents.append(map(int, line.split())) 84 | for line in toksfile: 85 | toks.append(line.strip().split()) 86 | for i in xrange(len(toks)): 87 | const_trees.append(load_constituency_tree(parents[i], toks[i])) 88 | dep_trees.append(load_dependency_tree(dparents[i])) 89 | return const_trees, dep_trees, toks 90 | 91 | def load_constituency_tree(parents, words): 92 | trees = [] 93 | root = None 94 | size = len(parents) 95 | for i in xrange(size): 96 | trees.append(None) 97 | 98 | word_idx = 0 99 | for i in xrange(size): 100 | if not trees[i]: 101 | idx = i 102 | prev = None 103 | prev_idx = None 104 | word = words[word_idx] 105 | word_idx += 1 106 | while True: 107 | tree = ConstTree() 108 | parent = parents[idx] - 1 109 | tree.word, tree.parent, tree.idx = word, parent, idx 110 | word = None 111 | if prev is not None: 112 | if tree.left is None: 113 | tree.left = prev 114 | else: 115 | tree.right = prev 116 | trees[idx] = tree 117 | if parent >= 0 and trees[parent] is not None: 118 | if trees[parent].left is None: 119 | trees[parent].left = 
tree 120 | else: 121 | trees[parent].right = tree 122 | break 123 | elif parent == -1: 124 | root = tree 125 | break 126 | else: 127 | prev = tree 128 | prev_idx = idx 129 | idx = parent 130 | return root 131 | 132 | def load_dependency_tree(parents): 133 | trees = [] 134 | root = None 135 | size = len(parents) 136 | for i in xrange(size): 137 | trees.append(None) 138 | 139 | for i in xrange(size): 140 | if not trees[i]: 141 | idx = i 142 | prev = None 143 | prev_idx = None 144 | while True: 145 | tree = DepTree() 146 | parent = parents[idx] - 1 147 | 148 | # node is not in tree 149 | if parent == -2: 150 | break 151 | 152 | tree.parent, tree.idx = parent, idx 153 | if prev is not None: 154 | tree.children.append(prev) 155 | trees[idx] = tree 156 | if parent >= 0 and trees[parent] is not None: 157 | trees[parent].children.append(tree) 158 | break 159 | elif parent == -1: 160 | root = tree 161 | break 162 | else: 163 | prev = tree 164 | prev_idx = idx 165 | idx = parent 166 | return root 167 | 168 | # 169 | # Various utilities 170 | # 171 | 172 | def make_dirs(dirs): 173 | for d in dirs: 174 | if not os.path.exists(d): 175 | os.makedirs(d) 176 | 177 | def load_sents(dirpath): 178 | sents = [] 179 | with open(os.path.join(dirpath, 'SOStr.txt')) as sentsfile: 180 | for line in sentsfile: 181 | sent = ' '.join(line.split('|')) 182 | sents.append(sent.strip()) 183 | return sents 184 | 185 | def load_splits(dirpath): 186 | splits = [] 187 | with open(os.path.join(dirpath, 'datasetSplit.txt')) as splitfile: 188 | splitfile.readline() 189 | for line in splitfile: 190 | idx, split = line.split(',') 191 | splits.append(int(split)) 192 | return splits 193 | 194 | def load_parents(dirpath): 195 | parents = [] 196 | with open(os.path.join(dirpath, 'STree.txt')) as parentsfile: 197 | for line in parentsfile: 198 | p = ' '.join(line.split('|')) 199 | parents.append(p.strip()) 200 | return parents 201 | 202 | def load_dictionary(dirpath): 203 | labels = [] 204 | with open(os.path.join(dirpath, 'sentiment_labels.txt')) as labelsfile: 205 | labelsfile.readline() 206 | for line in labelsfile: 207 | idx, rating = line.split('|') 208 | idx = int(idx) 209 | rating = float(rating) 210 | if rating <= 0.2: 211 | label = -2 212 | elif rating <= 0.4: 213 | label = -1 214 | elif rating > 0.8: 215 | label = +2 216 | elif rating > 0.6: 217 | label = +1 218 | else: 219 | label = 0 220 | labels.append(label) 221 | 222 | d = {} 223 | with open(os.path.join(dirpath, 'dictionary.txt')) as dictionary: 224 | for line in dictionary: 225 | s, idx = line.split('|') 226 | d[s] = labels[int(idx)] 227 | return d 228 | 229 | def build_vocab(filepaths, dst_path, lowercase=True): 230 | vocab = set() 231 | for filepath in filepaths: 232 | with open(filepath) as f: 233 | for line in f: 234 | if lowercase: 235 | line = line.lower() 236 | vocab |= set(line.split()) 237 | with open(dst_path, 'w') as f: 238 | for w in sorted(vocab): 239 | f.write(w + '\n') 240 | 241 | def split(sst_dir, train_dir, dev_dir, test_dir): 242 | sents = load_sents(sst_dir) 243 | splits = load_splits(sst_dir) 244 | parents = load_parents(sst_dir) 245 | 246 | with open(os.path.join(train_dir, 'sents.txt'), 'w') as train, \ 247 | open(os.path.join(dev_dir, 'sents.txt'), 'w') as dev, \ 248 | open(os.path.join(test_dir, 'sents.txt'), 'w') as test, \ 249 | open(os.path.join(train_dir, 'parents.txt'), 'w') as trainparents, \ 250 | open(os.path.join(dev_dir, 'parents.txt'), 'w') as devparents, \ 251 | open(os.path.join(test_dir, 'parents.txt'), 'w') as testparents: 252 | 
253 | for sent, split, p in zip(sents, splits, parents): 254 | if split == 1: 255 | train.write(sent) 256 | train.write('\n') 257 | trainparents.write(p) 258 | trainparents.write('\n') 259 | elif split == 2: 260 | test.write(sent) 261 | test.write('\n') 262 | 263 | testparents.write(p) 264 | testparents.write('\n') 265 | else: 266 | dev.write(sent) 267 | dev.write('\n') 268 | devparents.write(p) 269 | devparents.write('\n') 270 | 271 | def get_labels(tree, dictionary): 272 | size = tree.size() 273 | spans, labels = [], [] 274 | for i in xrange(size): 275 | labels.append(None) 276 | spans.append(None) 277 | tree.get_labels(spans, labels, dictionary) 278 | return spans, labels 279 | 280 | def write_labels(dirpath, dictionary): 281 | print('Writing labels for trees in ' + dirpath) 282 | with open(os.path.join(dirpath, 'labels.txt'), 'w') as labels, \ 283 | open(os.path.join(dirpath, 'dlabels.txt'), 'w') as dlabels: 284 | # load constituency and dependency trees 285 | const_trees, dep_trees, toks = load_trees(dirpath) 286 | 287 | # write span labels 288 | for i in xrange(len(const_trees)): 289 | const_trees[i].set_spans() 290 | dep_trees[i].set_spans(toks[i]) 291 | 292 | # const tree labels 293 | s, l = [], [] 294 | for j in xrange(const_trees[i].size()): 295 | s.append(None) 296 | l.append(None) 297 | const_trees[i].get_labels(s, l, dictionary) 298 | labels.write(' '.join(map(str, l)) + '\n') 299 | 300 | # dep tree labels 301 | dep_trees[i].span = const_trees[i].span 302 | s, l = [], [] 303 | for j in xrange(len(toks[i])): 304 | s.append(None) 305 | l.append('#') 306 | dep_trees[i].get_labels(s, l, dictionary) 307 | dlabels.write(' '.join(map(str, l)) + '\n') 308 | 309 | def dependency_parse(filepath, cp='', tokenize=True): 310 | print('\nDependency parsing ' + filepath) 311 | dirpath = os.path.dirname(filepath) 312 | filepre = os.path.splitext(os.path.basename(filepath))[0] 313 | tokpath = os.path.join(dirpath, filepre + '.toks') 314 | parentpath = os.path.join(dirpath, 'dparents.txt') 315 | relpath = os.path.join(dirpath, 'rels.txt') 316 | tokenize_flag = '-tokenize - ' if tokenize else '' 317 | cmd = ('java -cp %s DependencyParse -tokpath %s -parentpath %s -relpath %s %s < %s' 318 | % (cp, tokpath, parentpath, relpath, tokenize_flag, filepath)) 319 | os.system(cmd) 320 | 321 | if __name__ == '__main__': 322 | print('=' * 80) 323 | print('Preprocessing Stanford Sentiment Treebank') 324 | print('=' * 80) 325 | 326 | base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 327 | data_dir = os.path.join(base_dir, 'data') 328 | lib_dir = os.path.join(base_dir, 'lib') 329 | sst_dir = os.path.join(data_dir, 'sst') 330 | train_dir = os.path.join(sst_dir, 'train') 331 | dev_dir = os.path.join(sst_dir, 'dev') 332 | test_dir = os.path.join(sst_dir, 'test') 333 | make_dirs([train_dir, dev_dir, test_dir]) 334 | 335 | # produce train/dev/test splits 336 | split(sst_dir, train_dir, dev_dir, test_dir) 337 | sent_paths = glob.glob(os.path.join(sst_dir, '*/sents.txt')) 338 | 339 | # produce dependency parses 340 | classpath = ':'.join([ 341 | lib_dir, 342 | os.path.join(lib_dir, 'stanford-parser/stanford-parser.jar'), 343 | os.path.join(lib_dir, 'stanford-parser/stanford-parser-3.5.1-models.jar')]) 344 | for filepath in sent_paths: 345 | dependency_parse(filepath, cp=classpath, tokenize=False) 346 | 347 | # get vocabulary 348 | build_vocab(sent_paths, os.path.join(sst_dir, 'vocab.txt')) 349 | build_vocab(sent_paths, os.path.join(sst_dir, 'vocab-cased.txt'), lowercase=False) 350 | 351 | # 
write sentiment labels for nodes in trees 352 | dictionary = load_dictionary(sst_dir) 353 | write_labels(train_dir, dictionary) 354 | write_labels(dev_dir, dictionary) 355 | write_labels(test_dir, dictionary) 356 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as Var 5 | import Constants 6 | import utils 7 | 8 | class BinaryTreeLeafModule(nn.Module): 9 | """ 10 | local input = nn.Identity()() 11 | local c = nn.Linear(self.in_dim, self.mem_dim)(input) 12 | local h 13 | if self.gate_output then 14 | local o = nn.Sigmoid()(nn.Linear(self.in_dim, self.mem_dim)(input)) 15 | h = nn.CMulTable(){o, nn.Tanh()(c)} 16 | else 17 | h = nn.Tanh()(c) 18 | end 19 | 20 | local leaf_module = nn.gModule({input}, {c, h}) 21 | """ 22 | def __init__(self, cuda, in_dim, mem_dim): 23 | super(BinaryTreeLeafModule, self).__init__() 24 | self.cudaFlag = cuda 25 | self.in_dim = in_dim 26 | self.mem_dim = mem_dim 27 | 28 | self.cx = nn.Linear(self.in_dim, self.mem_dim) 29 | self.ox = nn.Linear(self.in_dim, self.mem_dim) 30 | if self.cudaFlag: 31 | self.cx = self.cx.cuda() 32 | self.ox = self.ox.cuda() 33 | 34 | def forward(self, input): 35 | c = self.cx(input) 36 | o = F.sigmoid(self.ox(input)) 37 | h = o * F.tanh(c) 38 | return c, h 39 | 40 | class BinaryTreeComposer(nn.Module): 41 | """ 42 | local lc, lh = nn.Identity()(), nn.Identity()() 43 | local rc, rh = nn.Identity()(), nn.Identity()() 44 | local new_gate = function() 45 | return nn.CAddTable(){ 46 | nn.Linear(self.mem_dim, self.mem_dim)(lh), 47 | nn.Linear(self.mem_dim, self.mem_dim)(rh) 48 | } 49 | end 50 | 51 | local i = nn.Sigmoid()(new_gate()) -- input gate 52 | local lf = nn.Sigmoid()(new_gate()) -- left forget gate 53 | local rf = nn.Sigmoid()(new_gate()) -- right forget gate 54 | local update = nn.Tanh()(new_gate()) -- memory cell update vector 55 | local c = nn.CAddTable(){ -- memory cell 56 | nn.CMulTable(){i, update}, 57 | nn.CMulTable(){lf, lc}, 58 | nn.CMulTable(){rf, rc} 59 | } 60 | 61 | local h 62 | if self.gate_output then 63 | local o = nn.Sigmoid()(new_gate()) -- output gate 64 | h = nn.CMulTable(){o, nn.Tanh()(c)} 65 | else 66 | h = nn.Tanh()(c) 67 | end 68 | local composer = nn.gModule( 69 | {lc, lh, rc, rh}, 70 | {c, h}) 71 | """ 72 | def __init__(self, cuda, in_dim, mem_dim): 73 | super(BinaryTreeComposer, self).__init__() 74 | self.cudaFlag = cuda 75 | self.in_dim = in_dim 76 | self.mem_dim = mem_dim 77 | 78 | def new_gate(): 79 | lh = nn.Linear(self.mem_dim, self.mem_dim) 80 | rh = nn.Linear(self.mem_dim, self.mem_dim) 81 | return lh, rh 82 | 83 | self.ilh, self.irh = new_gate() 84 | self.lflh, self.lfrh = new_gate() 85 | self.rflh, self.rfrh = new_gate() 86 | self.ulh, self.urh = new_gate() 87 | 88 | if self.cudaFlag: 89 | self.ilh = self.ilh.cuda() 90 | self.irh = self.irh.cuda() 91 | self.lflh = self.lflh.cuda() 92 | self.lfrh = self.lfrh.cuda() 93 | self.rflh = self.rflh.cuda() 94 | self.rfrh = self.rfrh.cuda() 95 | self.ulh = self.ulh.cuda() 96 | 97 | def forward(self, lc, lh , rc, rh): 98 | i = F.sigmoid(self.ilh(lh) + self.irh(rh)) 99 | lf = F.sigmoid(self.lflh(lh) + self.lfrh(rh)) 100 | rf = F.sigmoid(self.rflh(lh) + self.rfrh(rh)) 101 | update = F.tanh(self.ulh(lh) + self.urh(rh)) 102 | c = i* update + lf*lc + rf*rc 103 | h = F.tanh(c) 104 | return c, h 105 | 106 | 107 | 108 | 109 | 110 
| 111 | class BinaryTreeLSTM(nn.Module): 112 | def __init__(self, cuda, in_dim, mem_dim, criterion): 113 | super(BinaryTreeLSTM, self).__init__() 114 | self.cudaFlag = cuda 115 | self.in_dim = in_dim 116 | self.mem_dim = mem_dim 117 | self.criterion = criterion 118 | 119 | self.leaf_module = BinaryTreeLeafModule(cuda,in_dim, mem_dim) 120 | self.composer = BinaryTreeComposer(cuda, in_dim, mem_dim) 121 | self.output_module = None 122 | 123 | def set_output_module(self, output_module): 124 | self.output_module = output_module 125 | 126 | def getParameters(self): 127 | """ 128 | Get flatParameters 129 | note that getParameters and parameters is not equal in this case 130 | getParameters do not get parameters of output module 131 | :return: 1d tensor 132 | """ 133 | params = [] 134 | for m in [self.ix, self.ih, self.fx, self.fh, self.ox, self.oh, self.ux, self.uh]: 135 | # we do not get param of output module 136 | l = list(m.parameters()) 137 | params.extend(l) 138 | 139 | one_dim = [p.view(p.numel()) for p in params] 140 | params = F.torch.cat(one_dim) 141 | return params 142 | 143 | def forward(self, tree, embs, training = False): 144 | # add singleton dimension for future call to node_forward 145 | # embs = F.torch.unsqueeze(self.emb(inputs),1) 146 | 147 | loss = Var(torch.zeros(1)) # init zero loss 148 | if self.cudaFlag: 149 | loss = loss.cuda() 150 | 151 | if tree.num_children == 0: 152 | # leaf case 153 | tree.state = self.leaf_module.forward(embs[tree.idx-1]) 154 | else: 155 | for idx in range(tree.num_children): 156 | _, child_loss = self.forward(tree.children[idx], embs, training) 157 | loss = loss + child_loss 158 | lc, lh, rc, rh = self.get_child_state(tree) 159 | tree.state = self.composer.forward(lc, lh, rc, rh) 160 | 161 | if self.output_module != None: 162 | output = self.output_module.forward(tree.state[1], training) 163 | tree.output = output 164 | if training and tree.gold_label != None: 165 | target = Var(utils.map_label_to_target_sentiment(tree.gold_label)) 166 | if self.cudaFlag: 167 | target = target.cuda() 168 | loss = loss + self.criterion(output, target) 169 | return tree.state, loss 170 | 171 | 172 | def get_child_state(self, tree): 173 | lc, lh = tree.children[0].state 174 | rc, rh = tree.children[1].state 175 | return lc, lh, rc, rh 176 | 177 | ################################################################### 178 | 179 | # module for childsumtreelstm 180 | class ChildSumTreeLSTM(nn.Module): 181 | def __init__(self, cuda, in_dim, mem_dim, criterion): 182 | super(ChildSumTreeLSTM, self).__init__() 183 | self.cudaFlag = cuda 184 | self.in_dim = in_dim 185 | self.mem_dim = mem_dim 186 | 187 | # self.emb = nn.Embedding(vocab_size,in_dim, 188 | # padding_idx=Constants.PAD) 189 | # torch.manual_seed(123) 190 | 191 | self.ix = nn.Linear(self.in_dim,self.mem_dim) 192 | self.ih = nn.Linear(self.mem_dim,self.mem_dim) 193 | 194 | self.fh = nn.Linear(self.mem_dim, self.mem_dim) 195 | self.fx = nn.Linear(self.in_dim,self.mem_dim) 196 | 197 | self.ux = nn.Linear(self.in_dim,self.mem_dim) 198 | self.uh = nn.Linear(self.mem_dim,self.mem_dim) 199 | 200 | self.ox = nn.Linear(self.in_dim,self.mem_dim) 201 | self.oh = nn.Linear(self.mem_dim,self.mem_dim) 202 | 203 | self.criterion = criterion 204 | self.output_module = None 205 | 206 | def set_output_module(self, output_module): 207 | self.output_module = output_module 208 | 209 | def getParameters(self): 210 | """ 211 | Get flatParameters 212 | note that getParameters and parameters is not equal in this case 213 | getParameters do 
not get parameters of output module 214 | :return: 1d tensor 215 | """ 216 | params = [] 217 | for m in [self.ix, self.ih, self.fx, self.fh, self.ox, self.oh, self.ux, self.uh]: 218 | # we do not get param of output module 219 | l = list(m.parameters()) 220 | params.extend(l) 221 | 222 | one_dim = [p.view(p.numel()) for p in params] 223 | params = F.torch.cat(one_dim) 224 | return params 225 | 226 | 227 | def node_forward(self, inputs, child_c, child_h): 228 | """ 229 | 230 | :param inputs: (1, 300) 231 | :param child_c: (num_children, 1, mem_dim) 232 | :param child_h: (num_children, 1, mem_dim) 233 | :return: (tuple) 234 | c: (1, mem_dim) 235 | h: (1, mem_dim) 236 | """ 237 | 238 | child_h_sum = F.torch.sum(torch.squeeze(child_h,1),0) 239 | 240 | i = F.sigmoid(self.ix(inputs)+self.ih(child_h_sum)) 241 | o = F.sigmoid(self.ox(inputs)+self.oh(child_h_sum)) 242 | u = F.tanh(self.ux(inputs)+self.uh(child_h_sum)) 243 | 244 | # add extra singleton dimension 245 | fx = F.torch.unsqueeze(self.fx(inputs),1) 246 | f = F.torch.cat([self.fh(child_hi)+fx for child_hi in child_h], 0) 247 | f = F.sigmoid(f) 248 | 249 | # f = F.torch.unsqueeze(f,1) # comment to fix dimension missmatch 250 | fc = F.torch.squeeze(F.torch.mul(f,child_c),1) 251 | 252 | c = F.torch.mul(i,u) + F.torch.sum(fc,0) 253 | h = F.torch.mul(o, F.tanh(c)) 254 | 255 | return c, h 256 | 257 | def forward(self, tree, embs, training = False): 258 | """ 259 | Child sum tree LSTM forward function 260 | :param tree: 261 | :param embs: (sentence_length, 1, 300) 262 | :param training: 263 | :return: 264 | """ 265 | 266 | # add singleton dimension for future call to node_forward 267 | # embs = F.torch.unsqueeze(self.emb(inputs),1) 268 | 269 | loss = Var(torch.zeros(1)) # init zero loss 270 | if self.cudaFlag: 271 | loss = loss.cuda() 272 | 273 | for idx in range(tree.num_children): 274 | _, child_loss = self.forward(tree.children[idx], embs, training) 275 | loss = loss + child_loss 276 | child_c, child_h = self.get_child_states(tree) 277 | tree.state = self.node_forward(embs[tree.idx-1], child_c, child_h) 278 | 279 | if self.output_module != None: 280 | output = self.output_module.forward(tree.state[1], training) 281 | tree.output = output 282 | if training and tree.gold_label != None: 283 | target = Var(utils.map_label_to_target_sentiment(tree.gold_label)) 284 | if self.cudaFlag: 285 | target = target.cuda() 286 | loss = loss + self.criterion(output, target) 287 | return tree.state, loss 288 | 289 | def get_child_states(self, tree): 290 | """ 291 | Get c and h of all children 292 | :param tree: 293 | :return: (tuple) 294 | child_c: (num_children, 1, mem_dim) 295 | child_h: (num_children, 1, mem_dim) 296 | """ 297 | # add extra singleton dimension in middle... 298 | # because pytorch needs mini batches... 
:sad: 299 | if tree.num_children==0: 300 | child_c = Var(torch.zeros(1,1,self.mem_dim)) 301 | child_h = Var(torch.zeros(1,1,self.mem_dim)) 302 | if self.cudaFlag: 303 | child_c, child_h = child_c.cuda(), child_h.cuda() 304 | else: 305 | child_c = Var(torch.Tensor(tree.num_children,1,self.mem_dim)) 306 | child_h = Var(torch.Tensor(tree.num_children,1,self.mem_dim)) 307 | if self.cudaFlag: 308 | child_c, child_h = child_c.cuda(), child_h.cuda() 309 | for idx in range(tree.num_children): 310 | child_c[idx] = tree.children[idx].state[0] 311 | child_h[idx] = tree.children[idx].state[1] 312 | # child_c[idx], child_h[idx] = tree.children[idx].state 313 | return child_c, child_h 314 | 315 | ############################################################################## 316 | 317 | # output module 318 | class SentimentModule(nn.Module): 319 | def __init__(self, cuda, mem_dim, num_classes, dropout = False): 320 | super(SentimentModule, self).__init__() 321 | self.cudaFlag = cuda 322 | self.mem_dim = mem_dim 323 | self.num_classes = num_classes 324 | self.dropout = dropout 325 | # torch.manual_seed(456) 326 | self.l1 = nn.Linear(self.mem_dim, self.num_classes) 327 | self.logsoftmax = nn.LogSoftmax() 328 | if self.cudaFlag: 329 | self.l1 = self.l1.cuda() 330 | 331 | def forward(self, vec, training = False): 332 | """ 333 | Sentiment module forward function 334 | :param vec: (1, mem_dim) 335 | :param training: 336 | :return: 337 | (1, number_of_class) 338 | """ 339 | if self.dropout: 340 | out = self.logsoftmax(self.l1(F.dropout(vec, training = training))) 341 | else: 342 | out = self.logsoftmax(self.l1(vec)) 343 | return out 344 | 345 | class TreeLSTMSentiment(nn.Module): 346 | def __init__(self, cuda, vocab_size, in_dim, mem_dim, num_classes, model_name, criterion): 347 | super(TreeLSTMSentiment, self).__init__() 348 | self.cudaFlag = cuda 349 | self.model_name = model_name 350 | if self.model_name == 'dependency': 351 | self.tree_module = ChildSumTreeLSTM(cuda, in_dim, mem_dim, criterion) 352 | elif self.model_name == 'constituency': 353 | self.tree_module = BinaryTreeLSTM(cuda, in_dim, mem_dim, criterion) 354 | self.output_module = SentimentModule(cuda, mem_dim, num_classes, dropout=True) 355 | self.tree_module.set_output_module(self.output_module) 356 | 357 | def forward(self, tree, inputs, training = False): 358 | """ 359 | TreeLSTMSentiment forward function 360 | :param tree: 361 | :param inputs: (sentence_length, 1, 300) 362 | :param training: 363 | :return: 364 | """ 365 | tree_state, loss = self.tree_module(tree, inputs, training) 366 | output = tree.output 367 | return output, loss 368 | --------------------------------------------------------------------------------
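
For orientation, model.py's TreeLSTMSentiment wraps either ChildSumTreeLSTM (dependency trees) or BinaryTreeLSTM (constituency trees) and attaches SentimentModule as the per-node classifier. Below is a minimal CPU-only sketch of one forward pass under the same older PyTorch API this repository targets (Variable, F.sigmoid); the three-node toy tree, the random embeddings, and the NLLLoss criterion are illustrative assumptions rather than the project's actual training setup:

```python
# Minimal sketch: one inference pass of the dependency (ChildSumTreeLSTM) variant
# on a toy three-node tree. Node idx values are 1-based because the model reads
# embs[tree.idx - 1]; gold_label is left unset, so no loss term is accumulated.
import torch
import torch.nn as nn
from torch.autograd import Variable

from tree import Tree
from model import TreeLSTMSentiment

root, left, right = Tree(), Tree(), Tree()
left.idx, root.idx, right.idx = 1, 2, 3            # token positions in the sentence
root.add_child(left)
root.add_child(right)

model = TreeLSTMSentiment(cuda=False, vocab_size=10, in_dim=300, mem_dim=150,
                          num_classes=3, model_name='dependency',
                          criterion=nn.NLLLoss())   # NLLLoss pairs with the LogSoftmax output
embs = Variable(torch.randn(3, 1, 300))            # (sentence_length, 1, in_dim)

output, loss = model(root, embs, training=False)
print(output.size())                               # (1, num_classes) log-probabilities at the root
```

The constituency variant is exercised the same way with model_name='constituency', but it expects binarized trees in which every internal node has exactly two children (see get_child_state).
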