├── cstlstm ├── __init__.py ├── prev_states.py ├── encoder.py ├── cell.py └── tree_batch.py ├── ext ├── __init__.py ├── pickling.py ├── parameters.py ├── vocab_emb.py ├── histories.py ├── models.py └── training.py ├── glovar.py ├── tests ├── __init__.py └── ext_tests.py ├── pre_process.py ├── train_sst.py ├── train_nli.py ├── .gitignore ├── eval_nli.py ├── models ├── sentiment.py └── inference.py ├── README.md └── data ├── nli.py └── sst.py /cstlstm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ext/__init__.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | 4 | NLP = spacy.load('en') 5 | -------------------------------------------------------------------------------- /glovar.py: -------------------------------------------------------------------------------- 1 | """Global variables.""" 2 | import os 3 | 4 | 5 | APP_DIR = os.getcwd() 6 | CKPT_DIR = os.path.join(APP_DIR, 'ckpts/') 7 | DATA_DIR = '/home/hanshan/dev/data/' 8 | GLOVE_DIR = '/home/hanshan/dev/data/glove/glove.840B.300d.txt' 9 | PKL_DIR = os.path.join(APP_DIR, 'pickles/') 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from tests import ext_tests 3 | 4 | 5 | # > python -m unittest discover 6 | 7 | 8 | test_cases = [ 9 | ext_tests.HistoryTests, 10 | ] 11 | 12 | 13 | def load_tests(loader, tests, pattern): 14 | suite = unittest.TestSuite() 15 | for test_case in test_cases: 16 | tests = loader.loadTestsFromTestCase(test_case) 17 | suite.addTests(tests) 18 | return suite -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | """For pre-processing the data.""" 2 | from ext import vocab_emb, pickling 3 | from data import sst, nli 4 | import glovar 5 | import os 6 | 7 | 8 | if not os.path.exists(glovar.PKL_DIR): 9 | os.makedirs(glovar.PKL_DIR) 10 | if not os.path.exists(glovar.CKPT_DIR): 11 | os.makedirs(glovar.CKPT_DIR) 12 | 13 | 14 | # Create the vocab dictionary 15 | print('Creating vocab dict...') 16 | sst_text = sst.get_text() 17 | nli_text = nli.get_text() 18 | all_text = ' '.join([sst_text, nli_text]) 19 | vocab_dict, _ = vocab_emb.create_vocab_dict(all_text) 20 | pickling.save(vocab_dict, glovar.PKL_DIR, 'vocab_dict.pkl') 21 | print('Success.') 22 | 23 | 24 | # Create GloVe embeddings 25 | print('Creating GloVe embeddings...') 26 | embedding_mat = vocab_emb.create_embeddings(vocab_dict, 300, glovar.GLOVE_DIR) 27 | pickling.save(embedding_mat, glovar.PKL_DIR, 'glove_embeddings.pkl') 28 | print('Success.') 29 | -------------------------------------------------------------------------------- /tests/ext_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from ext import histories 3 | 4 | 5 | class HistoryTests(unittest.TestCase): 6 | def test_last_change(self): 7 | empty_series = [] 8 | single_value = [1.3] 9 | no_change = [1.3, 1.3] 10 | pos_change = [1.3, 1.4] 11 | neg_change = [1.3, 1.2] 12 | exception = False 13 | try: 14 | histories.History.last_change(empty_series) 15 | except ValueError: 16 | exception = True 17 | self.assertTrue(exception) 18 | 
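        # For the non-empty cases below, last_change should return the single
        # value itself for a one-element series, and last minus second-to-last
        # otherwise (see ext.histories.History.last_change).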
self.assertEqual(round(histories.History.last_change(single_value), 1), 19 | 1.3) 20 | self.assertEqual(round(histories.History.last_change(no_change), 1), 21 | 0.0) 22 | self.assertEqual(round(histories.History.last_change(pos_change), 1), 23 | 0.1) 24 | self.assertEqual(round(histories.History.last_change(neg_change), 1), 25 | -0.1) 26 | -------------------------------------------------------------------------------- /ext/pickling.py: -------------------------------------------------------------------------------- 1 | """Convenient interface for loading and saving pickles.""" 2 | import pickle 3 | import os 4 | 5 | 6 | def load(pkl_dir, pkl_name): 7 | """Load a pickle. 8 | 9 | Args: 10 | pkl_dir: String, the directory in which to save the pickle. 11 | pkl_name: String, the file_name for the pickle. 12 | 13 | Returns: 14 | Object. 15 | 16 | Raises: 17 | Exception if pickle not found. 18 | """ 19 | file_path = os.path.join(pkl_dir, pkl_name) 20 | try: 21 | with open(file_path, 'rb') as file: 22 | obj = pickle.load(file) 23 | return obj 24 | except FileNotFoundError: 25 | raise Exception('Pickle not found: %s' % file_path) 26 | 27 | 28 | def save(obj, pkl_dir, pkl_name): 29 | """Save a pickle. 30 | 31 | Args: 32 | obj: Object, the object to pickle. 33 | pkl_dir: String, the directory in which to save the pickle. 34 | pkl_name: String, the file_name for the pickle. 35 | """ 36 | file_path = os.path.join(pkl_dir, pkl_name) 37 | with open(file_path, 'wb') as file: 38 | pickle.dump(obj, file) 39 | -------------------------------------------------------------------------------- /train_sst.py: -------------------------------------------------------------------------------- 1 | """Train Child-Sum Tree-LSTM model on the Stanford Sentiment Treebank.""" 2 | import glovar 3 | from data import sst 4 | from ext import parameters, pickling, training, histories 5 | from models import sentiment 6 | 7 | # Parse configuration settings from command line 8 | params, arg_config = parameters.parse_arguments() 9 | 10 | 11 | # Get or create History 12 | history = histories.get( 13 | glovar.PKL_DIR, params.name, params.override, arg_config) 14 | 15 | 16 | # Report config to be used 17 | config = history.config 18 | print(config) 19 | 20 | 21 | print('Load embedding matrix...') 22 | embedding_matrix = pickling.load(glovar.PKL_DIR, 'glove_embeddings.pkl')[0] 23 | 24 | 25 | print('Loading data...') 26 | train_data, dev_data, _ = sst.get_data() 27 | train_loader = sst.get_data_loader(train_data, config.batch_size) 28 | dev_loader = sst.get_data_loader(dev_data, config.batch_size) 29 | 30 | 31 | print('Loading model...') 32 | model = sentiment.SentimentModel(params.name, config, embedding_matrix) 33 | 34 | 35 | print('Loading trainer...') 36 | trainer = training.PyTorchTrainer( 37 | model, history, train_loader, dev_loader, glovar.CKPT_DIR) 38 | 39 | 40 | print('Training...') 41 | trainer.train() 42 | -------------------------------------------------------------------------------- /train_nli.py: -------------------------------------------------------------------------------- 1 | """For training on NLI data.""" 2 | import glovar 3 | from data import nli 4 | from ext import parameters, histories, pickling, training 5 | from models import inference 6 | 7 | # Parse configuration settings from command line 8 | params, arg_config = parameters.parse_arguments() 9 | 10 | 11 | # Get or create History 12 | history = histories.get( 13 | glovar.PKL_DIR, params.name, params.override, arg_config) 14 | 15 | 16 | # Report config to be 
used 17 | config = history.config 18 | print(config) 19 | 20 | 21 | # Get vocab dict and embeddings 22 | print('Load vocab dict and embedding matrix...') 23 | vocab_dict = pickling.load(glovar.PKL_DIR, 'vocab_dict.pkl') 24 | embedding_matrix = pickling.load(glovar.PKL_DIR, 'glove_embeddings.pkl')[0] 25 | 26 | 27 | print('Loading data...') 28 | mnli_train = nli.load_json('mnli', 'train') 29 | snli_train = nli.load_json('snli', 'train') 30 | mnli_dev_matched = nli.load_json('mnli', 'dev_matched') 31 | train_data = nli.NYUDataSet( 32 | mnli_train, snli_train, vocab_dict, params.train_subset) 33 | tune_data = nli.NLIDataSet(mnli_dev_matched, vocab_dict, params.tune_subset) 34 | train_loader = nli.get_data_loader(train_data, config.batch_size) 35 | dev_loader = nli.get_data_loader(tune_data, config.batch_size) 36 | 37 | 38 | print('Loading model...') 39 | model = inference.InferenceModel(params.name, config, embedding_matrix) 40 | 41 | 42 | print('Loading trainer...') 43 | trainer = training.PyTorchTrainer( 44 | model, history, train_loader, dev_loader, glovar.CKPT_DIR) 45 | 46 | 47 | print('Training...') 48 | trainer.train() 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints/ 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # Pycharm 97 | .idea/ 98 | 99 | # Project specific 100 | graphs/ 101 | logs/ 102 | ckpts/ 103 | histories/ 104 | pickles/ 105 | main.py 106 | scratch.py 107 | prep.py 108 | eval.py 109 | kai/ 110 | -------------------------------------------------------------------------------- /ext/parameters.py: -------------------------------------------------------------------------------- 1 | """Configuration settings and mappings.""" 2 | import argparse 3 | from ext import models 4 | 5 | 6 | class Params: 7 | def __init__(self, name, override, train_subset, tune_subset): 8 | self.name = name 9 | self.override = override 10 | self.train_subset = train_subset 11 | self.tune_subset = tune_subset 12 | 13 | 14 | def parse_arguments(): 15 | base_config = models.Config() 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('name', 18 | type=str, 19 | help='The name for the training run (unique).') 20 | parser.add_argument('--override', 21 | action='store_true', 22 | help='Set true to overwrite an old history.') 23 | parser.add_argument('--tune_embeddings', 24 | action='store_true', 25 | help='Set true to tune embeddings.') 26 | parser.add_argument('--train_subset', 27 | type=int, 28 | help='Size of subset to select from training data.', 29 | default=None) 30 | parser.add_argument('--tune_subset', 31 | type=int, 32 | help='Size of subset to select from tuning data.', 33 | default=None) 34 | arg_config = {} 35 | for key in [k for k in base_config.keys() if k != 'tune_embeddings']: 36 | parser.add_argument( 37 | '--%s' % key, 38 | help='Set config.%s' % key, 39 | type=type(base_config[key])) 40 | arg_config[key] = base_config[key] 41 | args = parser.parse_args() 42 | params = Params( 43 | args.name, args.override, args.train_subset, args.tune_subset) 44 | for key in base_config.keys(): 45 | passed_value = getattr(args, key) 46 | if passed_value is not None: 47 | arg_config[key] = passed_value 48 | return params, arg_config 49 | -------------------------------------------------------------------------------- /eval_nli.py: -------------------------------------------------------------------------------- 1 | """For evaluating on NLI data.""" 2 | from data import nli 3 | from models import inference 4 | from ext import parameters, histories, pickling, training 5 | import glovar 6 | 7 | 8 | # Parse configuration settings from command line 9 | params, arg_config = parameters.parse_arguments() 10 | 11 | 12 | # Get or create History 13 | history = histories.get( 14 | glovar.PKL_DIR, params.name, params.override, arg_config) 15 | 16 | 17 | # 
Report config to be used 18 | config = history.config 19 | print(config) 20 | 21 | 22 | # Get vocab dict and embeddings 23 | print('Load vocab dict and embedding matrix...') 24 | vocab_dict = pickling.load(glovar.PKL_DIR, 'vocab_dict.pkl') 25 | embedding_matrix = pickling.load(glovar.PKL_DIR, 'glove_embeddings.pkl')[0] 26 | 27 | 28 | print('Loading data...') 29 | mnli_train = nli.load_json('mnli', 'train') 30 | snli_train = nli.load_json('snli', 'train') 31 | mnli_dev_matched = nli.load_json('mnli', 'dev_matched') 32 | train_data = nli.NYUDataSet(mnli_train, snli_train, vocab_dict) 33 | tune_data = nli.NLIDataSet(mnli_dev_matched, vocab_dict) 34 | train_loader = nli.get_data_loader(train_data, config.batch_size) 35 | dev_loader = nli.get_data_loader(tune_data, config.batch_size) 36 | 37 | 38 | print('Loading model...') 39 | model = inference.InferenceModel(params.name, config, embedding_matrix) 40 | 41 | 42 | print('Loading best checkpoint...') 43 | saver = training.Saver(glovar.CKPT_DIR) 44 | saver.load(model, history.name, True) 45 | 46 | 47 | print('Evaluating...') 48 | for db in nli.NLI_DBS: 49 | for coll in nli.NLI_COLLS[db]: 50 | if not (db == 'mnli' and coll.startswith('test')): 51 | subset_size = None 52 | if coll == 'train': # For both mnli and snli. 53 | subset_size = 10000 54 | data = nli.NLIDataSet( 55 | nli.load_json(db, coll), vocab_dict, subset_size) 56 | data_loader = nli.get_data_loader(data, config.batch_size) 57 | cum_acc = 0. 58 | for _, batch in enumerate(data_loader): 59 | __, ___, acc = model.forward(batch) 60 | cum_acc += acc 61 | acc = cum_acc / len(data_loader) 62 | print('%s\t%s\t%55.3f%%' % (db, coll, acc)) 63 | -------------------------------------------------------------------------------- /models/sentiment.py: -------------------------------------------------------------------------------- 1 | """Model for sentiment analysis with the Stanford Sentiment Treebank.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from cstlstm import encoder 8 | from ext import models 9 | 10 | 11 | class SentimentModel(models.PyTorchModel): 12 | """Classifier for Stanford Sentiment Treebank.""" 13 | 14 | def __init__(self, name, config, embedding_matrix): 15 | super(SentimentModel, self).__init__(name, config, embedding_matrix) 16 | 17 | # Define encoder. 18 | self.encoder = encoder.ChildSumTreeLSTMEncoder( 19 | self.embed_size, self.hidden_size, self.embedding, 20 | self.p_keep_input, self.p_keep_rnn) 21 | 22 | # Define linear classification layer. 23 | self.logits_layer = nn.Linear(self.hidden_size, 5).cuda() 24 | 25 | # Define optimizer. 26 | print(len(list(self.encoder.cell.parameters()))) 27 | params = [{'params': self.encoder.cell.parameters()}, 28 | {'params': self.logits_layer.parameters()}] 29 | if self.tune_embeddings: 30 | params.append({'params': self.embeddings.parameters(), 31 | 'lr': self.learning_rate / 10.}) # Avoid overfitting 32 | self.optimizer = optim.Adam(params, lr=self.learning_rate) 33 | 34 | # Init params with xavier. 
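        # gain=1 suits this plain linear logits layer; the ReLU-activated
        # layers in models/inference.py use gain=sqrt(2) instead.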
35 | nn.init.xavier_uniform(self.logits_layer.weight.data, gain=1) 36 | 37 | @staticmethod 38 | def annotated_encodings(encodings, annotation_ixs): 39 | selected = [] 40 | for l in range(max(encodings.keys()) + 1): 41 | selected += [encodings[l][1][i] for i in annotation_ixs[l]] 42 | return torch.stack(selected, 0) 43 | 44 | @staticmethod 45 | def current_batch_size(forest): 46 | return len(forest.labels) 47 | 48 | def forward(self, forest): 49 | labels = Variable( 50 | torch.from_numpy(np.array(forest.labels)), 51 | requires_grad=False).cuda() 52 | logits = self.logits(forest) 53 | loss = self.loss(logits, labels) 54 | predictions = self.predictions(logits).type_as(labels) 55 | correct = self.correct_predictions(predictions, labels) 56 | accuracy = self.accuracy(correct, self.current_batch_size(forest))[0] 57 | return predictions, loss, accuracy 58 | 59 | def logits(self, forest): 60 | encodings = self.encoder.forward(forest) 61 | annotated = self.annotated_encodings(encodings, forest.annotation_ixs) 62 | logits = self.logits_layer(annotated) 63 | return logits 64 | -------------------------------------------------------------------------------- /cstlstm/prev_states.py: -------------------------------------------------------------------------------- 1 | """For getting hidden states for nodes on lower level.""" 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | 6 | class PreviousStates: 7 | """For getting previous hidden states from lower level given wirings.""" 8 | 9 | def __init__(self, hidden_size): 10 | """Create a new PreviousStates. 11 | 12 | Args: 13 | hidden_size: Integer, number of units in a hidden state vector. 14 | """ 15 | self.hidden_size = hidden_size 16 | 17 | def __call__(self, level_nodes, level_up_wirings, prev_outputs): 18 | """Get previous hidden states. 19 | 20 | Args: 21 | level_nodes: List of nodes on the level to be processed. 22 | level_up_wirings: List of Lists: the list is of the same length as the 23 | level_nodes list. Each sublist gives the integer indices of the 24 | child nodes in the node list on the previous (lower) level. This 25 | defines how the child nodes wire up to the parent nodes. 26 | prev_outputs: List of previous hidden state tuples for the level below 27 | from which we will select from. 28 | 29 | Returns: 30 | ? 31 | """ 32 | # Count how many nodes on this level of the forest. 33 | level_length = len(level_nodes) 34 | 35 | # grab the cell states 36 | cell_states = self.states( 37 | level_nodes, level_length, prev_outputs[0], level_up_wirings) 38 | 39 | # grab the hidden states 40 | hidden_states = self.states( 41 | level_nodes, level_length, prev_outputs[1], level_up_wirings) 42 | 43 | # mind the order of returning 44 | return cell_states, hidden_states 45 | 46 | @staticmethod 47 | def children(prev_out, child_ixs_level_i): 48 | # doesn't work for empty sets of children - in which case don't call 49 | selector = Variable(torch.LongTensor(child_ixs_level_i)).cuda() 50 | return prev_out.index_select(0, selector) 51 | 52 | def states(self, level_nodes, level_length, prev_out, child_ixs_level): 53 | return [(self.zero_vec() 54 | if (level_nodes[i].is_leaf or len(child_ixs_level[i]) == 0) 55 | else self.children(prev_out, child_ixs_level[i])) 56 | for i in range(level_length)] 57 | 58 | def zero_level(self, level_length): 59 | # Doing this right away should save a bit more time each batch. 
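        # zero_level seeds the deepest level of the forest, where no node has
        # children yet: every node gets a single (1, hidden_size) zero vector
        # for both its child cell states and child hidden states.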
60 | cell_states = [self.zero_vec() for _ in range(level_length)] 61 | hidden_states = [self.zero_vec() for _ in range(level_length)] 62 | return cell_states, hidden_states 63 | 64 | def zero_vec(self): 65 | return Variable(torch.zeros(1, self.hidden_size), 66 | requires_grad=False).cuda() 67 | -------------------------------------------------------------------------------- /ext/vocab_emb.py: -------------------------------------------------------------------------------- 1 | """For creating vocab dictionaries and word embedding matrices.""" 2 | import numpy as np 3 | import spacy 4 | import collections 5 | 6 | 7 | PADDING = "" 8 | UNKNOWN = "" 9 | LBR = '(' 10 | RBR = ')' 11 | 12 | 13 | def create_embeddings(vocab, emb_size, embedding_file_path): 14 | """Create embeddings for the vocabulary. 15 | 16 | Creates an embedding matrix given the pre-trained word vectors, and any OOV 17 | tokens are initialized to random vectors. 18 | 19 | Args: 20 | vocab: Dictionary for the vocab with {token: id}. 21 | emb_size: Integer, the size of the word embeddings. 22 | embedding_file_path: String, file path to the pre-trained embeddings to 23 | use. 24 | 25 | Returns: 26 | embeddings, oov: 2D numpy.ndarray of shape vocab_size x emb_size, 27 | Dictionary of OOV vocab items. 28 | """ 29 | print('Creating word embeddings from %s...' % embedding_file_path) 30 | vocab_size = max(vocab.values()) + 1 31 | print('vocab_size = %s' % vocab_size) 32 | oov = dict(vocab) 33 | embeddings = np.random.normal(size=(vocab_size, emb_size))\ 34 | .astype('float32', copy=False) 35 | with open(embedding_file_path, 'r', encoding='utf-8') as f: 36 | for i, line in enumerate(f): 37 | s = line.split() 38 | if len(s) > 301: # a hack I have seemed to require for GloVe 840B 39 | s = [s[0]] + s[-300:] 40 | assert len(s) == 301 41 | if s[0] in vocab.keys(): 42 | if s[0] in oov.keys(): # seems we get some duplicate vectors. 43 | oov.pop(s[0]) 44 | try: 45 | embeddings[vocab[s[0]], :] = np.asarray(s[1:]) 46 | except Exception as e: 47 | print('i: %s' % i) 48 | print('s[0]: %s' % s[0]) 49 | print('vocab_[s[0]]: %s' % vocab[s[0]]) 50 | print('len(vocab): %s' % len(vocab)) 51 | print('vocab_min_val: %s' % min(vocab.values())) 52 | print('vocab_max_val: %s' % max(vocab.values())) 53 | raise e 54 | print('Success.') 55 | print('OOV count = %s' % len(oov)) 56 | print(oov) 57 | return embeddings, oov 58 | 59 | 60 | def create_vocab_dict(text): 61 | """Create vocab dictionary. 62 | 63 | Args: 64 | text: String. Join all the text in the corpus on a space. It will be 65 | tokenized by SpaCy. 66 | 67 | Returns: 68 | Dictionary {token: id}, collections.Counter() with token counts. 69 | """ 70 | nlp = spacy.load('en') 71 | doc = nlp(text) 72 | counter = collections.Counter() 73 | counter.update([t.text for t in doc]) 74 | tokens = set([t for t in counter] + [UNKNOWN, LBR, RBR]) 75 | # Make sure 0 is padding. 
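    # Ids are assigned from 1 upward so that id 0 stays free for the padding
    # token, which is added below.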
76 | vocab_dict = dict(zip(tokens, range(1, len(tokens) + 1))) 77 | assert PADDING not in vocab_dict.keys() 78 | assert 0 not in vocab_dict.values() 79 | vocab_dict[PADDING] = 0 80 | return vocab_dict, counter 81 | -------------------------------------------------------------------------------- /models/inference.py: -------------------------------------------------------------------------------- 1 | """Natural Language Inference model.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from cstlstm import encoder 9 | from ext import models 10 | 11 | 12 | class InferenceModel(models.PyTorchModel): 13 | """Natural language inference model.""" 14 | 15 | def __init__(self, name, config, embedding_matrix): 16 | super(InferenceModel, self).__init__(name, config, embedding_matrix) 17 | 18 | # Define encoder. 19 | self.encoder = encoder.ChildSumTreeLSTMEncoder( 20 | self.embed_size, self.hidden_size, self.embedding, 21 | self.p_keep_input, self.p_keep_rnn) 22 | 23 | # Define dropouts 24 | self.drop_fc = nn.Dropout(p=1.0 - self.p_keep_fc) 25 | 26 | # Define MLP 27 | self.fc1 = nn.Linear(self.hidden_size * 4, self.hidden_size).cuda() 28 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size).cuda() 29 | self.logits_layer = nn.Linear(self.hidden_size, 3).cuda() 30 | 31 | # Define optimizer 32 | params = [{'params': self.encoder.cell.parameters()}, 33 | {'params': self.fc1.parameters()}, 34 | {'params': self.fc2.parameters()}, 35 | {'params': self.logits_layer.parameters()}] 36 | if self.tune_embeddings: 37 | params.append({'params': self.embeddings.parameters(), 38 | 'lr': self.learning_rate / 10.}) # Avoid overfitting 39 | self.optimizer = optim.Adam(params, lr=self.learning_rate) 40 | 41 | # Initialize parameters 42 | nn.init.xavier_uniform(self.fc1.weight.data, gain=np.sqrt(2.0)) 43 | nn.init.xavier_uniform(self.fc2.weight.data, gain=np.sqrt(2.0)) 44 | nn.init.xavier_uniform(self.logits_layer.weight.data, gain=1.) 45 | 46 | @staticmethod 47 | def current_batch_size(forest): 48 | return int(len(forest.nodes[0]) / 2) 49 | 50 | def forward(self, forest): 51 | labels = Variable( 52 | torch.from_numpy(np.array(forest.labels)), 53 | requires_grad=False).cuda() 54 | logits = self.logits(forest) 55 | loss = self.loss(logits, labels) 56 | predictions = self.predictions(logits).type_as(labels) 57 | correct = self.correct_predictions(predictions, labels) 58 | accuracy = self.accuracy(correct, self.current_batch_size(forest))[0] 59 | return predictions, loss, accuracy 60 | 61 | def logits(self, forest): 62 | # Following the DataLoader collate fn, the premises and hypotheses are 63 | # concatenated in the forest, in order, so splitting the root level of 64 | # the forest into two yields premises and hypotheses encodings. 65 | encodings = self.encoder.forward(forest)[0][1] # 1 selects hs, not cs. 66 | premises, hypotheses = encodings.split( 67 | self.current_batch_size(forest), 0) 68 | 69 | # Mou et al. 
concat layer 70 | diff = premises - hypotheses 71 | mul = premises * hypotheses 72 | x = torch.cat([premises, hypotheses, diff, mul], 1) 73 | 74 | # MLP 75 | h1 = self.drop_fc(F.relu(self.fc1(x))) 76 | h2 = self.drop_fc(F.relu(self.fc2(h1))) 77 | logits = self.logits_layer(h2) 78 | return logits 79 | -------------------------------------------------------------------------------- /cstlstm/encoder.py: -------------------------------------------------------------------------------- 1 | """Child-Sum Tree-LSTM batch sentence encoder module.""" 2 | import torch 3 | import torch.nn as nn 4 | from cstlstm import prev_states, cell 5 | from torch.autograd import Variable 6 | 7 | 8 | class ChildSumTreeLSTMEncoder(nn.Module): 9 | """Child-Sum Tree-LSTM Encoder Module.""" 10 | 11 | def __init__(self, embed_size, hidden_size, embeddings, 12 | p_keep_input, p_keep_rnn): 13 | """Create a new ChildSumTreeLSTMEncoder. 14 | 15 | Args: 16 | embed_size: Integer, number of units in word embeddings vectors. 17 | hidden_size: Integer, number of units in hidden state vectors. 18 | embeddings: torch.nn.Embedding. 19 | p_keep_input: Float, the probability of keeping an input unit. 20 | p_keep_rnn: Float, the probability of keeping an rnn unit. 21 | """ 22 | super(ChildSumTreeLSTMEncoder, self).__init__() 23 | 24 | self._embeddings = embeddings 25 | 26 | # Define dropout layer for embedding lookup 27 | self._drop_input = nn.Dropout(p=1.0 - p_keep_input) 28 | 29 | # Initialize the batch Child-Sum Tree-LSTM cell 30 | self.cell = cell.BatchChildSumTreeLSTMCell( 31 | input_size=embed_size, 32 | hidden_size=hidden_size, 33 | p_dropout=1.0 - p_keep_rnn).cuda() 34 | 35 | # Initialize previous states (to get wirings from nodes on lower level) 36 | self._prev_states = prev_states.PreviousStates(hidden_size) 37 | 38 | def forward(self, forest): 39 | """Get encoded vectors for each node in the forest. 40 | 41 | Args: 42 | nodes: Dictionary of structure {Integer (level_index): List (nodes)} 43 | where each node is represented by a ext.Node object. 44 | up_wirings: Dictionary of structure 45 | {Integer (level_index): List of Lists (up wirings)}, where the up 46 | wirings List is the same length as the number of nodes on the 47 | current level, and each sublist gives the indices of it's children 48 | on the lower level's node list, thus defining the upward wiring. 49 | 50 | Returns: 51 | Dictionary of hidden states for all nodes on all levels, indexed by 52 | level number, with the list order following that of forest.nodes[l] 53 | for each level, l. 54 | """ 55 | outputs = {} 56 | 57 | # Work backwards through level indices - i.e. bottom up. 58 | for l in reversed(range(forest.max_level + 1)): 59 | # Get input word vectors for this level. 60 | inputs = [(self._word_vec(n.vocab_ix) if n.token 61 | else self._prev_states.zero_vec()) 62 | for n in forest.nodes[l]] 63 | 64 | # Get previous hidden states for this level. 
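            # forest.child_ixs[l][i] lists the positions of node i's children
            # within forest.nodes[l + 1], whose states were computed on the
            # previous (deeper) iteration; PreviousStates gathers and stacks
            # those child states per parent. E.g. if forest.nodes[l + 1] holds
            # [a, b, c] and node i's children are a and c, then
            # forest.child_ixs[l][i] == [0, 2].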
65 |             if l == forest.max_level: 66 |                 hidden_states = self._prev_states.zero_level( 67 |                     len(forest.nodes[l])) 68 |             else: 69 |                 hidden_states = self._prev_states( 70 |                     level_nodes=forest.nodes[l], 71 |                     level_up_wirings=forest.child_ixs[l], 72 |                     prev_outputs=outputs[l+1]) 73 | 74 |             outputs[l] = self.cell(inputs, hidden_states) 75 | 76 |         return outputs 77 | 78 |     def _word_vec(self, vocab_ix): 79 |         lookup_tensor = Variable( 80 |             torch.LongTensor([vocab_ix]), 81 |             requires_grad=False).cuda() 82 |         word_vec = self._embeddings(lookup_tensor)\ 83 |             .type(torch.FloatTensor)\ 84 |             .cuda() 85 |         word_vec = self._drop_input(word_vec) 86 |         return word_vec 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cstlstm 2 | Child-Sum Tree-LSTM Implementation in PyTorch. 3 | 4 | Based on the paper "Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks" by Tai et al. 2015 (http://www.aclweb.org/anthology/P15-1150). 5 | 6 | The standalone components are contained in the folder `cstlstm/`. 7 | 8 | Dependencies: 9 | ``` 10 | python 3.5 11 | PyTorch 0.2 12 | spacy 13 | numpy 14 | ``` 15 | 16 | ## Implementation Notes 17 | 18 | Tree-structured recursive neural networks are challenging because of computational efficiency issues, especially when compared to linear RNNs. Implementing Tai's model with recursion over nodes was found to be inefficient. 19 | 20 | Therefore, the model is implemented here with a parallelization strategy: process all nodes on each (depth) level of the forest in parallel, working upwards from the deepest leaves. The LSTM cell (`cstlstm/cell.py`) is designed with this in mind, which is why it is called a "BatchCell". 21 | 22 | This does indeed improve speed considerably; however, it is still slow compared to a linear RNN. 23 | 24 | WARNING: I have found that opting to tune the word embeddings during training is brutally inefficient. I hope this can be solved in the future and welcome any advice (as always). 25 | 26 | ## Encoder 27 | 28 | The file `cstlstm/encoder.py` is a reusable encoder, a PyTorch `nn.Module`. It will return a hidden state for every node in the forest. These are returned as a dictionary structure, mirroring the input structure (see "Input" section below). 29 | 30 | Usage example of the encoder: 31 | 32 | ``` 33 | from cstlstm.encoder import ChildSumTreeLSTMEncoder 34 | encoder = ChildSumTreeLSTMEncoder( 35 |     embed_size, hidden_size, embeddings, p_keep_input, p_keep_rnn) 36 | encodings = encoder(forest) 37 | ``` 38 | 39 | ## Input 40 | 41 | The file `cstlstm/tree_batch.py` deals with converting SpaCy sents, or dependency trees given as s-expressions, to Tree and Forest objects that contain all the information (critically, the upward wiring information) needed by the encoder. 42 | 43 | For example, with NLI data we are given Stanford Parser dependency parses as strings, which can be turned into a Tree like so: 44 | 45 | ``` 46 | from cstlstm import tree_batch 47 | s = '(ROOT (S (NP ...
)))' 48 | tree = tree_batch.sexpr_to_tree(s) 49 | ``` 50 | 51 | Alternatively (and my personal preference), we can pass a sentence as a string to SpaCy to do our dependency parsing for us, and use `tree_batch.py` to turn it into a tree: 52 | 53 | ``` 54 | from cstlstm import tree_batch 55 | import spacy 56 | nlp = spacy.load('en') 57 | s = nlp('If he had a mind, there was something on it.') 58 | tree = tree_batch.sent_to_tree(s) 59 | ``` 60 | 61 | This gives us individual trees, but for mini-batching (or even for a single NLI training sample) we need to input more than one sentence. We need a Forest. `tree_batch.py` takes care of this, defining the `Forest` object with a constructor that takes a list of `Tree`s. 62 | 63 | ``` 64 | from cstlstm import tree_batch 65 | import spacy 66 | nlp = spacy.load('en') 67 | sents = [nlp('First sentence.'), nlp('Second sentence.')] 68 | trees = [tree_batch.sent_to_tree(s) for s in sents] 69 | forest = tree_batch.Forest(trees) 70 | ``` 71 | 72 | Inputting the trees in the order you want, the forest keeps the nodes ordered along each level - e.g. the roots are in order on the first level, so `Forest.nodes[0]` gives a list of root nodes that can be indexed by this order. This ordering is mirrored in the encoding output - so `encodings[0]` gives a list of hidden states in the same order. The same applies to the nodes on all levels. 73 | 74 | ## Models 75 | 76 | I have implemented this model for Natural Language Inference data (`models/inference.py`), my own research interest. That model trains successfully. 77 | 78 | The Stanford Sentiment Treebank model (`models/sentiment.py`) is incomplete. 79 | 80 | ## Getting Started 81 | 82 | TODO: explain glovar.py, necessary folders, data downloads, pre-processing. 83 | -------------------------------------------------------------------------------- /data/nli.py: -------------------------------------------------------------------------------- 1 | """For handling Natural Language Inference data.""" 2 | import json 3 | import random 4 | 5 | from torch.utils.data import dataset, dataloader 6 | 7 | import glovar 8 | from cstlstm import tree_batch 9 | from ext import NLP 10 | 11 | LABEL_MAP = { 12 | "entailment": 0, 13 | "neutral": 1, 14 | "contradiction": 2, 15 | "hidden": -1} 16 | NLI_DBS = ['snli', 'mnli'] 17 | NLI_COLLS = { 18 | 'snli': ['train', 'dev', 'test'], 19 | 'mnli': ['train', 20 | 'dev_matched', 'dev_mismatched', 21 | 'test_matched', 'test_mismatched']} 22 | 23 | 24 | def get_data_loader(data_set, batch_size): 25 | return dataloader.DataLoader( 26 | data_set, 27 | batch_size, 28 | shuffle=True, 29 | num_workers=4, 30 | collate_fn=data_set.collate) 31 | 32 | 33 | def get_text(): 34 | premises = [] 35 | hypotheses = [] 36 | for db in NLI_DBS: 37 | for coll in NLI_COLLS[db]: 38 | for x in load_json(db, coll): 39 | premises.append(x['sentence1']) 40 | hypotheses.append(x['sentence2']) 41 | premises = ' '.join(premises) 42 | hypotheses = ' '.join(hypotheses) 43 | nli_text = ' '.join([premises, hypotheses]) 44 | return nli_text 45 | 46 | 47 | def load_json(db, coll): 48 | filename = '%s%s/%s_%s.jsonl' % (glovar.DATA_DIR, db, db, coll) 49 | data = [] 50 | with open(filename, 'r') as file: 51 | for line in file.readlines(): 52 | x = json.loads(line) 53 | if x['gold_label'] in LABEL_MAP.keys(): 54 | data.append(x) 55 | return data 56 | 57 | 58 | class NLIDataSet(dataset.Dataset): 59 | def __init__(self, data, vocab_dict, subset_size=None): 60 | super(NLIDataSet, self).__init__() 61 | self.data = data 62 | 
self.subset_size = subset_size 63 | self.vocab_dict = vocab_dict 64 | self._prepare_epoch() 65 | self._subsample() 66 | random.shuffle(self.epoch_data) 67 | self.len = len(self.epoch_data) 68 | 69 | def __getitem__(self, index): 70 | item = self.epoch_data[index] 71 | if index == self.len - 1: 72 | self._prepare_epoch() 73 | self._subsample() 74 | random.shuffle(self.epoch_data) 75 | return item 76 | 77 | def __len__(self): 78 | return self.len 79 | 80 | def collate(self, batch_data): 81 | # Create a forest from premises and hypotheses, in order 82 | premises = [NLP(x['sentence1'].rstrip()) for x in batch_data] 83 | hypotheses = [NLP(x['sentence2'].rstrip()) for x in batch_data] 84 | premises = [tree_batch.sent_to_tree(x) for x in premises] 85 | hypotheses = [tree_batch.sent_to_tree(x) for x in hypotheses] 86 | forest = tree_batch.Forest(premises + hypotheses) 87 | # Get the labels 88 | forest.labels = [LABEL_MAP[x['gold_label']] for x in batch_data] 89 | # Pre-lookup dictionary ixs - the encoder expects an attribute vocab_ix 90 | for node in forest.node_list: 91 | node.vocab_ix = self.vocab_dict[node.token] 92 | return forest 93 | 94 | def _subsample(self): 95 | if self.subset_size: 96 | self.epoch_data = random.sample(self.epoch_data, self.subset_size) 97 | 98 | def _prepare_epoch(self): 99 | self.epoch_data = self.data 100 | 101 | 102 | class NYUDataSet(NLIDataSet): 103 | def __init__(self, mnli_train, snli_train, vocab_dict, 104 | subset_size=None, alpha=0.15): 105 | self.mnli_train = mnli_train 106 | self.snli_train = snli_train 107 | self.alpha = alpha 108 | self.n_snli = int(len(snli_train) * alpha) 109 | super(NYUDataSet, self).__init__([], vocab_dict, subset_size) 110 | 111 | def _prepare_epoch(self): 112 | self.epoch_data = self.mnli_train + random.sample( 113 | self.snli_train, self.n_snli) 114 | -------------------------------------------------------------------------------- /ext/histories.py: -------------------------------------------------------------------------------- 1 | """For tracking and saving training histories.""" 2 | import numpy as np 3 | from ext import pickling, models 4 | import glovar 5 | import os 6 | 7 | 8 | def get(pkl_dir, name, override, arg_config): 9 | print('Getting history with name %s; override=%s...' % (name, override)) 10 | pkl_name = 'history_%s.pkl' % name 11 | exists = os.path.exists(os.path.join(pkl_dir, pkl_name)) 12 | print('Exists: %s' % exists) 13 | if exists: 14 | if override: 15 | print('Overriding...') 16 | return History(name, models.Config(**arg_config)) 17 | else: 18 | print('Loading...') 19 | return pickling.load(pkl_dir, pkl_name) 20 | else: 21 | print('Creating...') 22 | return History(name, models.Config(**arg_config)) 23 | 24 | 25 | class History: 26 | """Wraps config, training run name, and all training history values.""" 27 | 28 | def __init__(self, name, config=None): 29 | """Create a new History. 30 | 31 | Args: 32 | name: String, unique identifying name of the training run. 33 | config: coldnet.models.Config. If creating a new History object, 34 | this cannot be None. 35 | 36 | Raises: 37 | ValueError: if name is not found and config is None. 38 | """ 39 | if not config: 40 | raise ValueError('config cannot be None for new Histories.') 41 | # Global Variables 42 | self.name = name # This ends up being the _id 43 | self.config = config 44 | # Epoch Variables 45 | self.global_epoch = 1 46 | self.epoch_losses = [] 47 | self.epoch_accs = [] 48 | self.epoch_times = [] 49 | self.cum_epoch_loss = 0. 50 | self.cum_epoch_acc = 0. 
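        # Running best of the epoch-average accuracy; end_epoch() compares
        # each new epoch's average against it, and the trainer uses the
        # result to decide when to save a 'best' checkpoint.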
51 | self.best_epoch_acc = 0. 52 | # Step Variables 53 | self.global_step = 1 54 | self.epoch_step_times = [] # only keep for one epoch 55 | self.cum_loss = 0. 56 | self.cum_acc = 0. 57 | # Tuning Variables 58 | self.tuning_accs = [] 59 | 60 | def end_epoch(self, time_taken): 61 | self.epoch_times.append(time_taken) 62 | avg_time = np.average(self.epoch_times) 63 | self.epoch_losses.append(self.cum_epoch_loss) 64 | avg_loss = np.average(self.epoch_losses) 65 | change_loss = self.last_change(self.epoch_losses) 66 | self.epoch_accs.append(self.cum_epoch_acc) 67 | avg_acc = np.average(self.epoch_accs) 68 | change_acc = self.last_change(self.epoch_accs) 69 | is_best = avg_acc > self.best_epoch_acc 70 | if is_best: 71 | self.best_epoch_acc = avg_acc 72 | self.epoch_step_times = [] 73 | self.cum_epoch_loss = 0. 74 | self.cum_epoch_acc = 0. 75 | self.global_epoch += 1 76 | return avg_time, avg_loss, change_loss, avg_acc, change_acc, is_best 77 | 78 | def end_step(self, time_taken, loss, accuracy): 79 | self.epoch_step_times.append(time_taken) 80 | avg_time = np.average(self.epoch_step_times) 81 | self.cum_loss += loss 82 | avg_loss = self.cum_loss / self.global_step 83 | self.cum_acc += accuracy 84 | avg_acc = self.cum_acc / self.global_step 85 | self.cum_epoch_loss += loss 86 | self.cum_epoch_acc += accuracy 87 | self.global_step += 1 88 | return self.global_step, avg_time, avg_loss, avg_acc 89 | 90 | def end_tuning(self, accuracy): 91 | self.tuning_accs.append(accuracy) 92 | avg_acc = np.average(self.tuning_accs) 93 | change_acc = self.last_change(self.tuning_accs) 94 | return avg_acc, change_acc 95 | 96 | @staticmethod 97 | def last_change(series): 98 | if len(series) == 0: 99 | raise ValueError('Series has no elements.') 100 | elif len(series) == 1: 101 | return series[0] 102 | else: 103 | return series[-1] - series[-2] 104 | 105 | @staticmethod 106 | def load(name): 107 | pkl_name = 'history_%s.pkl' % name 108 | return pickling.load(glovar.PKL_DIR, pkl_name) 109 | 110 | def save(self): 111 | pickling.save(self, glovar.PKL_DIR, 'history_%s.pkl' % self.name) 112 | 113 | def to_json(self): 114 | json = dict(self.__dict__) 115 | json.pop('name') 116 | json['_id'] = self.name 117 | json['config'] = self.config.to_json() 118 | return json 119 | -------------------------------------------------------------------------------- /cstlstm/cell.py: -------------------------------------------------------------------------------- 1 | """Batch Child-Sum Tree-LSTM cell for parallel processing of nodes per level.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BatchChildSumTreeLSTMCell(nn.Module): 8 | """Child-Sum Tree-LSTM Cell implementation for mini batches. 9 | 10 | Based on https://arxiv.org/abs/1503.00075. 11 | Equations on p.3 as follows. 12 | 13 | .. math:: 14 | 15 | \begin{array}{ll} 16 | \tilde{h_j} = \sum_{k \in C(j)} h_k \\ 17 | i_j = \mathrm{sigmoid}(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}) \\ 18 | f_{jk} = \mathrm{sigmoid}(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}) \\ 19 | o_j = \mathrm{sigmoid}(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)}) \\ 20 | u_j = \tanh(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)}) \\ 21 | c_j = i_j \circ u_j + \sum_{k \in C(j)} f_{jk} \circ c_k \\ 22 | h_j = o_j \circ \tanh(c_j) 23 | \end{array} 24 | """ 25 | 26 | def __init__(self, input_size, hidden_size, p_dropout): 27 | """Create a new ChildSumTreeLSTMCell. 28 | 29 | Args: 30 | input_size: Integer, the size of the input vector. 
31 | hidden_size: Integer, the size of the hidden state to return. 32 | dropout: torch.nn.Dropout module. 33 | """ 34 | super(BatchChildSumTreeLSTMCell, self).__init__() 35 | self.input_size = input_size 36 | self.hidden_size = hidden_size 37 | self.dropout = nn.Dropout(p=p_dropout) 38 | self.W_combined = nn.Parameter( 39 | torch.Tensor(input_size + hidden_size, 3 * hidden_size), 40 | requires_grad=True) 41 | self.b_combined = nn.Parameter( 42 | torch.zeros(1, 3 * hidden_size), 43 | requires_grad=True) 44 | self.W_f = nn.Parameter( 45 | torch.Tensor(input_size, hidden_size), 46 | requires_grad=True) 47 | self.U_f = nn.Parameter( 48 | torch.Tensor(hidden_size, hidden_size), 49 | requires_grad=True) 50 | self.b_f = nn.Parameter( 51 | torch.zeros(1, hidden_size), 52 | requires_grad=True) 53 | nn.init.xavier_uniform(self.W_combined, gain=1.0) 54 | nn.init.xavier_uniform(self.W_f, gain=1.0) 55 | nn.init.xavier_uniform(self.U_f, gain=1.0) 56 | 57 | def forward(self, inputs, previous_states): 58 | """Calculate the next hidden state given the inputs. 59 | 60 | This is for custom control over a batch, designed for efficiency. 61 | I hope it is efficient... 62 | 63 | Args: 64 | inputs: List of Tensors of shape (1, input_size) - row vectors. 65 | previous_states: Tuple of (List, List), being cell_states and 66 | hidden_states respectively. Inside the lists, for nodes with 67 | multiple children, we expect they are already concatenated into 68 | matrices. 69 | 70 | Returns: 71 | cell_states, hidden_states: being state tuple, where both states are 72 | row vectors of length hidden_size. 73 | """ 74 | # prepare the inputs 75 | cell_states = previous_states[0] 76 | hidden_states = previous_states[1] 77 | inputs_mat = torch.cat(inputs) 78 | h_tilde_mat = torch.cat([torch.sum(h, 0).expand(1, self.hidden_size) 79 | for h in hidden_states], 80 | dim=0) 81 | prev_c_mat = torch.cat(cell_states, 0) 82 | big_cat_in = torch.cat([inputs_mat, h_tilde_mat], 1) 83 | 84 | # process in parallel those parts we can 85 | big_cat_out = big_cat_in.mm(self.W_combined) + self.b_combined.expand( 86 | big_cat_in.size()[0], 87 | 3 * self.hidden_size) 88 | z_i, z_o, z_u = big_cat_out.split(self.hidden_size, 1) 89 | 90 | # apply dropout to u, like the Fold boys 91 | z_u = self.dropout(z_u) 92 | 93 | # forget gates 94 | f_inputs = inputs_mat.mm(self.W_f) 95 | # we can concat the matrices along the row axis, 96 | # but we need to calculate cumsums for splitting after 97 | 98 | # NOTE: I could probably pass this information from pre-processing 99 | # yes, I think that's the idea: move this out. Test it out there. 100 | # then come back to here. That's my next job. And moving the other 101 | # stuff out of the CSTLSTM model. 102 | lens = [t.size()[0] for t in hidden_states] 103 | start = [sum([lens[j] for j in range(i)]) for i in range(len(lens))] 104 | end = [start[i] + lens[i] for i in range(len(lens))] 105 | 106 | # we can then go ahead and concatenate for matmul 107 | prev_h_mat = torch.cat(hidden_states, 0) 108 | f_hiddens = prev_h_mat.mm(self.U_f) 109 | # compute the f_jks by expanding the inputs to the same number 110 | # of rows as there are prev_hs for each, then just do a simple add. 
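        # This realises f_jk = sigmoid(W_f x_j + U_f h_k + b_f): each parent's
        # W_f x_j row is repeated once per child so it lines up row-for-row
        # with the U_f h_k rows before the add. E.g. with lens == [2, 1],
        # rows 0-1 of f_jks belong to the first parent's two children and
        # row 2 to the second parent's only child.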
111 | f_inputs_split = f_inputs.split(1, 0) 112 | f_inputs_expanded = [f_inputs_split[i].expand(lens[i], self.hidden_size) 113 | for i in range(len(lens))] 114 | f_inputs_ready = torch.cat(f_inputs_expanded, 0) 115 | f_jks = F.sigmoid( 116 | f_inputs_ready + f_hiddens + self.b_f.expand( 117 | f_hiddens.size()[0], self.hidden_size)) 118 | 119 | # cell and hidden state 120 | fc_mul = f_jks * prev_c_mat 121 | split_fcs = [fc_mul[start[i]:end[i]] for i in range(len(lens))] 122 | fc_term = torch.cat([torch.sum(item, 0).expand(1, self.hidden_size) 123 | for item in split_fcs]) 124 | c = F.sigmoid(z_i) * F.tanh(z_u) + fc_term 125 | h = F.sigmoid(z_o) * F.tanh(c) 126 | 127 | return c, h 128 | -------------------------------------------------------------------------------- /ext/models.py: -------------------------------------------------------------------------------- 1 | """Base classes for models.""" 2 | import torch 3 | from torch import nn 4 | 5 | 6 | FRAMEWORKS = ['tf', 'torch'] 7 | _DEFAULT_CONFIG = { 8 | 'batch_size': 32, 9 | 'embed_size': 300, 10 | 'hidden_size': 300, 11 | 'projection_size': 200, 12 | 'learning_rate': 1e-3, 13 | 'grad_clip_norm': 0.0, 14 | '_lambda': 0.0, 15 | 'p_keep_input': 0.9, 16 | 'p_keep_rnn': 0.9, 17 | 'p_keep_fc': 0.9, 18 | 'tune_embeddings': True 19 | } 20 | 21 | 22 | class Config: 23 | """Wrapper of config variables.""" 24 | 25 | def __init__(self, default=_DEFAULT_CONFIG, **kwargs): 26 | """Create a new Config. 27 | 28 | Args: 29 | default: Dictionary of default values. These can be passed in, or else 30 | the _DEFAULT_CONFIG from this file will be used. 31 | """ 32 | self.default = default 33 | self.kwargs = kwargs 34 | self.batch_size = self._value('batch_size', kwargs) 35 | self.embed_size = self._value('embed_size', kwargs) 36 | self.hidden_size = self._value('hidden_size', kwargs) 37 | self.projection_size = self._value('projection_size', kwargs) 38 | self.learning_rate = self._value('learning_rate', kwargs) 39 | self.grad_clip_norm = self._value('grad_clip_norm', kwargs) 40 | self._lambda = self._value('_lambda', kwargs) 41 | self.p_keep_input = self._value('p_keep_input', kwargs) 42 | self.p_keep_rnn = self._value('p_keep_rnn', kwargs) 43 | self.p_keep_fc = self._value('p_keep_fc', kwargs) 44 | self.tune_embeddings = self._value('tune_embeddings', kwargs) 45 | for key in [k for k in kwargs.keys() 46 | if k not in self.default.keys()]: 47 | setattr(self, key, kwargs[key]) 48 | 49 | def __delitem__(self, key): 50 | pass 51 | 52 | def __getitem__(self, key): 53 | return self.__getattribute__(key) 54 | 55 | def __repr__(self): 56 | x = 'Config as follows:\n' 57 | for key in sorted(self.keys()): 58 | x += '\t%s \t%s%s\n' % \ 59 | (key, '\t' if len(key) < 15 else '', self[key]) 60 | return x 61 | 62 | def __setitem__(self, key, value): 63 | self.__setattr__(key, value) 64 | 65 | def dropout_keys(self): 66 | return [k for k in self.__dict__.keys() if k.startswith('p_keep_')] 67 | 68 | def keys(self): 69 | return [key for key in self.__dict__.keys() 70 | if key not in ['default', 'kwargs']] 71 | 72 | def to_json(self): 73 | return dict(self.__dict__) 74 | 75 | def _value(self, key, kwargs): 76 | if key in kwargs.keys(): 77 | return kwargs[key] 78 | else: 79 | return self.default[key] 80 | 81 | 82 | class Model: 83 | """Base class for a model of any kind.""" 84 | 85 | def __init__(self, framework, config): 86 | """Create a new Model. 87 | 88 | Args: 89 | framework: String, the framework of the model, e.g. 'pytorch'. 
90 | config: Config object, a configuration settings wrapper. 91 | """ 92 | self.framework = framework 93 | self.config = config 94 | for key in config.keys(): 95 | setattr(self, key, config[key]) 96 | 97 | def accuracy(self, *args): 98 | raise NotImplementedError 99 | 100 | def forward(self, *args): 101 | """Forward step of the network. 102 | 103 | Returns: 104 | predictions, loss, accuracy. 105 | """ 106 | raise NotImplementedError 107 | 108 | def logits(self, *args): 109 | raise NotImplementedError 110 | 111 | def loss(self, *args): 112 | raise NotImplementedError 113 | 114 | def optimize(self, *args): 115 | raise NotImplementedError 116 | 117 | def predictions(self, *args): 118 | raise NotImplementedError 119 | 120 | 121 | class PyTorchModel(Model, nn.Module): 122 | """Base for a PyTorch model.""" 123 | 124 | def __init__(self, name, config, embedding_matrix): 125 | Model.__init__(self, 'pytorch', config) 126 | nn.Module.__init__(self) 127 | 128 | self.name = name 129 | 130 | # Define embedding. 131 | self.embedding = nn.Embedding( 132 | embedding_matrix.shape[0], embedding_matrix.shape[1], sparse=False) 133 | embedding_tensor = torch.from_numpy(embedding_matrix) 134 | self.embedding.weight = nn.Parameter( 135 | embedding_tensor, 136 | requires_grad=self.tune_embeddings) 137 | self.embedding.cuda() 138 | 139 | # Define loss 140 | self.criterion = torch.nn.CrossEntropyLoss().cuda() 141 | 142 | @staticmethod 143 | def accuracy(correct_predictions, batch_size): 144 | # batch_size may vary - i.e. the last batch of the data set. 145 | correct = correct_predictions.cpu().sum().data.numpy() 146 | return correct / float(batch_size) 147 | 148 | def _biases(self): 149 | return [p for n, p in self.named_parameters() if n in ['bias']] 150 | 151 | @staticmethod 152 | def correct_predictions(predictions, labels): 153 | return predictions.eq(labels) 154 | 155 | def forward(self, forest): 156 | # Need to return predictions, loss, accuracy 157 | raise NotImplementedError 158 | 159 | def logits(self, forest): 160 | raise NotImplementedError 161 | 162 | def loss(self, logits, labels): 163 | loss = self.criterion(logits, labels) 164 | return loss 165 | 166 | def optimize(self, loss): 167 | loss.backward() 168 | self.optimizer.step() 169 | 170 | @staticmethod 171 | def predictions(logits): 172 | return logits.max(1)[1] 173 | 174 | def _weights(self): 175 | return [p for n, p in self.named_parameters() if n in ['weight']] 176 | 177 | def zero_grad(self): 178 | self.optimizer.zero_grad() 179 | -------------------------------------------------------------------------------- /data/sst.py: -------------------------------------------------------------------------------- 1 | """For handling the Stanford Sentiment Treebank data.""" 2 | import os 3 | 4 | import spacy 5 | from nltk.tokenize import sexpr 6 | from torch.utils.data import dataset, dataloader 7 | 8 | import glovar 9 | from cstlstm import tree_batch 10 | from ext import pickling 11 | 12 | NLP = spacy.load('en') 13 | SST_DIR = os.path.join(glovar.DATA_DIR, 'sst/') 14 | 15 | 16 | def annotate_data(): 17 | raw_data = get_raw_data() 18 | parsed_data = get_parsed_data(raw_data) 19 | # combining text at nodes occurs within these functions 20 | sst_trees = get_sst_trees(raw_data) 21 | dep_trees = get_dep_trees(parsed_data) 22 | # use compare and annotate to annotate 23 | for dataset in dep_trees.keys(): 24 | dep_set = dep_trees[dataset] 25 | sst_set = sst_trees[dataset] 26 | for i in range(len(dep_set)): 27 | compare_and_annotate(sst_set[i], dep_set[i]) 28 | # 
report every so often to check integrity 29 | if i % 100 == 0: 30 | print('ORIGINAL') 31 | for node in sst_set[i].node_list: 32 | print('%s\t%s\t%s' % (node.id, node.tag, node.text_at_node)) 33 | print('DEP') 34 | for node in dep_set[i].node_list: 35 | print('%s\t%s\t%s' % ( 36 | node.id, node.annotation, node.text_at_node)) 37 | # save a pickle 38 | pickling.save(dep_trees, glovar.PKL_DIR, 'annotated_dep_trees.pkl') 39 | return dep_trees 40 | 41 | 42 | def compare_and_annotate(sst_tree, dep_tree): 43 | for dep_node in dep_tree.node_list: 44 | # init an empty property as all nodes need one 45 | dep_node.annotation = None 46 | dep_doc = NLP(dep_node.text_at_node) 47 | # check for a match in the sst tree 48 | for sst_node in sst_tree.node_list: 49 | sst_doc = NLP(sst_node.text_at_node) 50 | if len(dep_doc) == len(sst_doc): 51 | match = True 52 | i = 0 53 | while match and i <= len(dep_doc) - 1: 54 | match = dep_doc[i].text == sst_doc[i].text 55 | i += 1 56 | if match: 57 | dep_node.annotation = sst_node.tag 58 | 59 | 60 | def get_data(): 61 | vocab_dict = load_vocab_dict() 62 | data = pickling.load(glovar.PKL_DIR, 'annotated_dep_trees.pkl') 63 | train = SSTDataset(data['train'], vocab_dict) 64 | dev = SSTDataset(data['dev'], vocab_dict) 65 | test = SSTDataset(data['test'], vocab_dict) 66 | return train, dev, test 67 | 68 | 69 | def get_data_loader(data_set, batch_size): 70 | return dataloader.DataLoader( 71 | data_set, 72 | batch_size, 73 | shuffle=True, 74 | num_workers=4, 75 | collate_fn=data_set.collate) 76 | 77 | 78 | def get_dep_trees(parsed_data): 79 | dep_trees = {} 80 | for dataset in parsed_data.keys(): 81 | trees = [] 82 | for sample in parsed_data[dataset]: 83 | sent = NLP(sample['text']) 84 | tree = tree_batch.sent_to_tree(sent) 85 | tree_batch.combine_text_at_nodes(tree) 86 | trees.append(tree) 87 | dep_trees[dataset] = trees 88 | return dep_trees 89 | 90 | 91 | def get_parsed_data(raw_data): 92 | parsed = {} 93 | for dataset in raw_data.keys(): 94 | labels_texts = [] 95 | for root_sexpr in raw_data[dataset]: 96 | label, text = parse(root_sexpr) 97 | labels_texts.append({'label': label, 'text': text}) 98 | parsed[dataset] = labels_texts 99 | return parsed 100 | 101 | 102 | def get_raw_data(): 103 | data = {} 104 | files = ['train.txt', 'dev.txt', 'test.txt'] 105 | for file in files: 106 | with open(SST_DIR + file, 'r') as f: 107 | lines = f.readlines() 108 | lines = [l.rstrip() for l in lines] 109 | data[file.split('.')[0]] = lines 110 | return data 111 | 112 | 113 | def get_sst_trees(raw_data): 114 | sst_trees = {} 115 | for dataset in raw_data.keys(): 116 | trees = [] 117 | for sexpr in raw_data[dataset]: 118 | tree = tree_batch.sexpr_to_tree(sexpr) 119 | tree_batch.combine_text_at_nodes(tree) 120 | trees.append(tree) 121 | sst_trees[dataset] = trees 122 | return sst_trees 123 | 124 | 125 | def get_text(): 126 | sst_raw = get_raw_data() 127 | sst_parsed = get_parsed_data(sst_raw) 128 | sst_data = sst_parsed['train'] + sst_parsed['dev'] + sst_parsed['test'] 129 | sst_text = ' '.join([s['text'] for s in sst_data]) 130 | return sst_text 131 | 132 | 133 | def load_vocab_dict(): 134 | return pickling.load(glovar.PKL_DIR, 'vocab_dict.pkl') 135 | 136 | 137 | def parse(root_sexpr): 138 | label, sub_sexpr = root_sexpr[1:-1].split(None, 1) 139 | tokens = [] 140 | stack = Stack() 141 | for sub_sexpr in reversed(sexpr.sexpr_tokenize(sub_sexpr)): 142 | stack.push(sub_sexpr) 143 | while not stack.empty: 144 | _, next_sexpr = stack.pop()[1:-1].split(None, 1) 145 | # Leaf: if the length of 
the next is 1 and the string isn't in brackets 146 | next_sexprs = sexpr.sexpr_tokenize(next_sexpr) 147 | if len(next_sexprs) == 1 and ('(' not in next_sexprs[0] 148 | and ')' not in next_sexprs): 149 | tokens.append(next_sexprs[0]) 150 | # Otherwise, add them to the stack in reverse order 151 | else: 152 | for sub_sexpr in reversed(next_sexprs): 153 | stack.push(sub_sexpr) 154 | return label, ' '.join(tokens) 155 | 156 | 157 | class SSTDataset(dataset.Dataset): 158 | """Dataset wrapper for the Stanford Sentiment Treebank.""" 159 | 160 | def __init__(self, data, vocab_dict): 161 | super(SSTDataset, self).__init__() 162 | self.data = list(data) 163 | self.len = len(self.data) 164 | self.vocab_dict = vocab_dict 165 | 166 | def __getitem__(self, index): 167 | return self.data[index] 168 | 169 | def __len__(self): 170 | return self.len 171 | 172 | @staticmethod 173 | def annotation_ixs(forest): 174 | ixs = {} 175 | for l in range(forest.max_level + 1): 176 | l_nodes = forest.nodes[l] 177 | l_ixs = [i for i in range(len(l_nodes)) if l_nodes[i].annotation] 178 | ixs[l] = l_ixs 179 | return ixs 180 | 181 | def collate(self, batch_data): 182 | """For collating a batch of trees. 183 | 184 | Args: 185 | batch_data: List of tree_batch.Tree. 186 | 187 | Returns: 188 | tree_batch.Forest. 189 | """ 190 | forest = tree_batch.Forest(batch_data) 191 | forest.labels = [] 192 | 193 | # Setting annotation_ixs here necessary downstream and for labels 194 | forest.annotation_ixs = self.annotation_ixs(forest) 195 | 196 | # Get labels and pre-emptively perform dictionary lookup. 197 | for l in range(forest.max_level + 1): 198 | forest.labels += [int(forest.nodes[l][i].annotation) 199 | for i in forest.annotation_ixs[l]] 200 | for node in [n for n in forest.nodes[l] if n.token]: 201 | node.vocab_ix = self.vocab_dict[node.token] 202 | 203 | return forest 204 | 205 | 206 | class Stack: 207 | # Internal utility class for parsing sexprs in the correct order. 208 | 209 | def __init__(self): 210 | self.items = [] 211 | 212 | @property 213 | def empty(self): 214 | return len(self.items) == 0 215 | 216 | def push(self, item): 217 | self.items.append(item) 218 | 219 | def pop(self): 220 | item = self.items[-1] 221 | del self.items[-1] 222 | return item 223 | -------------------------------------------------------------------------------- /ext/training.py: -------------------------------------------------------------------------------- 1 | """Base code for training.""" 2 | import torch 3 | import time 4 | import numpy as np 5 | import os 6 | 7 | 8 | # Utility Functions 9 | 10 | 11 | def pretty_time(secs): 12 | """Get a readable string for a quantity of seconds. 13 | Args: 14 | secs: Integer, seconds. 15 | Returns: 16 | String, nicely formatted. 17 | """ 18 | if secs < 60.0: 19 | return '%4.2f secs' % secs 20 | elif secs < 3600.0: 21 | return '%4.2f mins' % (secs / 60) 22 | elif secs < 86400.0: 23 | return '%4.2f hrs' % (secs / 60 / 60) 24 | else: 25 | return '%3.2f days' % (secs / 60 / 60 / 24) 26 | 27 | 28 | def _print_dividing_lines(): 29 | # For visuals, when reporting results to terminal. 30 | print('--------\t ----------------\t------------------' 31 | '\t--------\t--------') 32 | 33 | 34 | def _print_epoch_start(epoch): 35 | _print_dividing_lines() 36 | print('Epoch %s \t loss \t accuracy ' 37 | '\tt(avg.)\t\tremaining' 38 | % epoch) 39 | print(' \t last avg. \t last avg. 
\t \t') 40 | _print_dividing_lines() 41 | 42 | 43 | # Base Trainer Class 44 | 45 | 46 | class TrainerBase: 47 | """Wraps a model and implements a train method.""" 48 | 49 | def __init__(self, model, history, train_loader, tune_loader, ckpt_dir): 50 | """Create a new training wrapper. 51 | Args: 52 | model: any model to be trained, be it TensorFlow or PyTorch. 53 | history: histories.History object for storing training statistics. 54 | train_loader: the data to be used for training. 55 | tune_loader: the data to be used for tuning; can be list of data sets. 56 | ckpt_dir: String, path to checkpoint file directory. 57 | """ 58 | self.model = model 59 | self.history = history 60 | self.train_loader = train_loader 61 | self.tune_loader = tune_loader 62 | self.batches_per_epoch = len(train_loader) 63 | self.ckpt_dir = ckpt_dir 64 | # Load the latest checkpoint if necessary 65 | if self.history.global_step > 1: 66 | print('Loading last checkpoint...') 67 | self._load_last() 68 | 69 | def _checkpoint(self, is_best): 70 | raise NotImplementedError('Deriving classes must implement.') 71 | 72 | def ckpt_path(self, is_best): 73 | return os.path.join( 74 | self.ckpt_dir, 75 | '%s_%s' % (self.model.name, 'best' if is_best else 'latest')) 76 | 77 | def _end_epoch(self): 78 | self._epoch_end = time.time() 79 | time_taken = self._epoch_end - self._epoch_start 80 | avg_time, avg_loss, change_loss, avg_acc, change_acc, is_best = \ 81 | self.history.end_epoch(time_taken) 82 | self._report_epoch(avg_time) 83 | self._checkpoint(is_best) 84 | self.history.save() 85 | 86 | def _end_step(self, loss, acc): 87 | self.step_end = time.time() 88 | time_taken = self.step_end - self.step_start 89 | global_step, avg_time, avg_loss, avg_acc = \ 90 | self.history.end_step(time_taken, loss, acc) 91 | self._report_step(global_step, loss, avg_loss, acc, avg_acc, avg_time) 92 | 93 | def _load_last(self): 94 | raise NotImplementedError('Deriving classes must implement.') 95 | 96 | @property 97 | def progress_percent(self): 98 | percent = (self.history.global_step % self.batches_per_epoch) \ 99 | / self.batches_per_epoch \ 100 | * 100 101 | rounded = int(np.ceil(percent / 10.0) * 10) 102 | return rounded 103 | 104 | def _report_epoch(self, avg_time): 105 | _print_dividing_lines() 106 | print('\t\t\t\t\t\t\t%s' 107 | % pretty_time(np.average(avg_time))) 108 | 109 | @property 110 | def report_every(self): 111 | return int(np.floor(self.batches_per_epoch / 10)) 112 | 113 | def _report_step(self, global_step, loss, avg_loss, acc, avg_acc, avg_time): 114 | if global_step % self.report_every == 0: 115 | print('%s%%:\t\t' 116 | '%8.4f %8.4f\t' 117 | '%6.4f%% %6.4f%%\t' 118 | '%s\t' 119 | '%s' 120 | % (self.progress_percent, 121 | loss, 122 | avg_loss, 123 | acc * 100, 124 | avg_acc * 100, 125 | pretty_time(avg_time), 126 | pretty_time(avg_time * self.steps_remaining))) 127 | 128 | def _start_epoch(self): 129 | _print_epoch_start(self.history.global_epoch) 130 | self.model.train() 131 | self._epoch_start = time.time() 132 | 133 | def _start_step(self): 134 | self.step_start = time.time() 135 | 136 | def step(self, *args): 137 | """Take a training step. 138 | Calculate loss and accuracy and do optimization. 139 | Returns: 140 | Float, Float: loss, accuracy for the batch. 
141 | """ 142 | raise NotImplementedError('Deriving classes must implement.') 143 | 144 | @property 145 | def steps_remaining(self): 146 | return self.batches_per_epoch \ 147 | - (self.history.global_step % self.batches_per_epoch) 148 | 149 | def _stopping_condition_met(self): 150 | # Override this method to set a custom stopping condition. 151 | return False 152 | 153 | def train(self): 154 | """Run the training algorithm.""" 155 | while not self._stopping_condition_met(): 156 | self._start_epoch() 157 | for _, batch in enumerate(self.train_loader): 158 | self._start_step() 159 | loss, acc = self.step(batch) 160 | self._end_step(loss, acc) 161 | self._tuning() 162 | self._end_epoch() 163 | 164 | def _tune(self, tune_loader): 165 | cum_acc = 0. 166 | for _, batch in enumerate(tune_loader): 167 | _, _, acc = self.model.forward(batch) 168 | cum_acc += acc 169 | tuning_acc = cum_acc / len(tune_loader) 170 | avg_acc, change_acc = self.history.end_tuning(tuning_acc) 171 | print('Tuning accuracy: %5.3f%%' % tuning_acc) 172 | print('Average tuning accuracy: %5.3f%% (%s%5.3f%%)' % 173 | (avg_acc * 100, 174 | '+' if change_acc > 0 else '', 175 | change_acc * 100)) 176 | 177 | def _tuning(self): 178 | self.model.eval() 179 | if isinstance(self.tune_loader, list): 180 | for tune_loader in self.tune_loader: 181 | self._tune(tune_loader) 182 | else: 183 | self._tune(self.tune_loader) 184 | self.model.train() 185 | 186 | 187 | # PyTorch Trainer 188 | 189 | 190 | class PyTorchTrainer(TrainerBase): 191 | """Training wrapper for a PyTorch model.""" 192 | 193 | def __init__(self, model, history, train_loader, tune_loader, ckpt_dir): 194 | """Create a new PyTorchTrainer. 195 | 196 | Args: 197 | model: a Pytorch model that inherits from torch.nn.Module. 198 | history: History object. 199 | train_loader: torch.util.data.dataloader.DataLoader. 200 | tune_loader: torch.util.data.dataloader.DataLoader. 201 | """ 202 | super(PyTorchTrainer, self).__init__( 203 | model, history, train_loader, tune_loader, ckpt_dir) 204 | self.model.cuda() 205 | 206 | def _checkpoint(self, is_best): 207 | file_path = self.ckpt_path(False) 208 | torch.save(self.model.state_dict(), file_path) 209 | if is_best: 210 | print('Saving checkpoint with new best tuning accuracy...') 211 | file_path = self.ckpt_path(True) 212 | torch.save(self.model.state_dict(), file_path) 213 | 214 | def _load_last(self): 215 | file_path = self.ckpt_path(False) 216 | self.model.load_state_dict(torch.load(file_path)) 217 | 218 | def step(self, batch): 219 | self.model.zero_grad() 220 | _, loss, acc = self.model.forward(batch) 221 | self.model.optimize(loss) 222 | return loss.cpu().data.numpy()[0], acc 223 | 224 | 225 | class Saver: 226 | """For loading and saving models.""" 227 | 228 | def __init__(self, ckpt_dir): 229 | self.ckpt_dir = ckpt_dir 230 | 231 | def ckpt_path(self, name, is_best): 232 | return os.path.join( 233 | self.ckpt_dir, 234 | '%s_%s' % (name, 'best' if is_best else 'latest')) 235 | 236 | def load(self, model, name, is_best): 237 | path = self.ckpt_path(name, is_best) 238 | print('Loading checkpoint at %s...' 
              % path)
239 |         model.load_state_dict(torch.load(path))
240 | 
241 |     def save(self, model, name, is_best):
242 |         path = self.ckpt_path(name, False)  # always refresh the 'latest' checkpoint
243 |         torch.save(model.state_dict(), path)
244 |         if is_best:
245 |             print('Checkpointing with new best accuracy...')
246 |             path = self.ckpt_path(name, is_best=True)
247 |             torch.save(model.state_dict(), path)
248 | 
--------------------------------------------------------------------------------
/cstlstm/tree_batch.py:
--------------------------------------------------------------------------------
1 | """Tree data structures and functions for parallel processing."""
2 | import numpy as np
3 | from nltk.tokenize import sexpr
4 | 
5 | 
6 | def cumsum(seq):
7 |     """Get the cumulative sum of a sequence of sequences at each index.
8 | 
9 |     Args:
10 |         seq: List of sequences.
11 | 
12 |     Returns:
13 |         List of integers.
14 |     """
15 |     r, s = [], 0
16 |     for e in seq:
17 |         l = len(e)
18 |         r.append(l + s)
19 |         s += l
20 |     return r
21 | 
22 | 
23 | def flatten_list_of_lists(list_of_lists):
24 |     """Flatten a list of lists.
25 | 
26 |     Args:
27 |         list_of_lists: List of Lists.
28 | 
29 |     Returns:
30 |         List.
31 |     """
32 |     return [item for sub_list in list_of_lists for item in sub_list]
33 | 
34 | 
35 | def get_adj_mat(nodes):
36 |     """Get an adjacency matrix from a node set.
37 | 
38 |     A row in the matrix indicates the children of the node at that index.
39 |     A column in the matrix indicates the parent of the node at that index.
40 | 
41 |     Args:
42 |         nodes: List of Nodes.
43 | 
44 |     Returns:
45 |         2D numpy.ndarray: an adjacency matrix.
46 |     """
47 |     size = len(nodes)
48 |     mat = np.zeros((size, size), dtype='int32')
49 |     for node in nodes:
50 |         if node.parent_id >= 0:
51 |             mat[node.parent_id][node.id] = 1
52 |     return mat
53 | 
54 | 
55 | def get_child_ixs(nodes, adj_mat):
56 |     """Get lists of children indices at each level.
57 | 
58 |     We need this for batching, to show the wiring of the nodes at each level,
59 |     as we process them in parallel.
60 | 
61 |     Args:
62 |         nodes: Dictionary of {Integer: [List of Nodes]} for the nodes at each
63 |             level in the tree / forest.
64 |         adj_mat: 2D numpy.ndarray, adjacency matrix for all nodes.
65 | 
66 |     Returns:
67 |         Dictionary of {Integer: [[List of child_ixs @ l+1] for parent_ixs @ l]}.
68 |     """
69 |     child_ixs = {}
70 |     # We don't need child_ixs for the last level so just range(max_level) not +1
71 |     for l in range(max(nodes.keys())):
72 |         child_nodes = nodes[l+1]
73 |         id_to_ix = {child_nodes[ix].id: ix for ix in
74 |                     range(len(child_nodes))}
75 |         ids = [np.nonzero(adj_mat[n.id])[0] for n in nodes[l]]
76 |         try:
77 |             ixs = [[id_to_ix[id] for id in id_list] for id_list in ids]
78 |         except Exception as e:
79 |             print('level: %s' % l)
80 |             print('child_ixs state')
81 |             print(child_ixs)
82 |             print('child_nodes')
83 |             print(child_nodes)
84 |             print('id_to_ix')
85 |             print(id_to_ix)
86 |             raise e
87 |         child_ixs[l] = ixs
88 |     return child_ixs
89 | 
90 | 
91 | def get_max_level(nodes):
92 |     """Get the highest level number given a list of nodes.
93 | 
94 |     Args:
95 |         nodes: List of Nodes.
96 | 
97 |     Returns:
98 |         Integer, the highest level number. It is zero-based, so add one to it
99 |         if the actual number of levels is needed.
100 |     """
101 |     return max([n.level for n in nodes])
102 | 
103 | 
104 | def get_nodes_at_levels(nodes):
105 |     """Get a dictionary listing nodes at each level.
106 | 
107 |     Args:
108 |         nodes: List of Nodes.
109 | 
110 |     Returns:
111 |         Dictionary of {Integer: [List of Nodes]} for each level. 
112 |     """
113 |     max_level = get_max_level(nodes)
114 |     return dict(zip(
115 |         range(max_level+1),
116 |         [[n for n in nodes if n.level == l]
117 |          for l in range(max_level+1)]))
118 | 
119 | 
120 | def get_parent_ixs(nodes, adj_mat):
121 |     """Get lists of parent indices at each level.
122 | 
123 |     We need this for batching, to show the wiring of the nodes at each level,
124 |     as we process them in parallel.
125 | 
126 |     Args:
127 |         nodes: Dictionary of {Integer: [List of Nodes]} for the nodes at each
128 |             level in the tree / forest.
129 |         adj_mat: 2D numpy.ndarray, adjacency matrix for all nodes.
130 | 
131 |     Returns:
132 |         Dictionary of {Integer: [List of parent_ixs @ l-1 for child_ixs @ l]}.
133 |     """
134 |     parent_ixs = {}
135 |     # We don't need parent_ixs for the first level, 0.
136 |     for l in range(1, max(nodes.keys()) + 1):
137 |         parent_nodes = nodes[l - 1]
138 |         id_to_ix = {parent_nodes[ix].id: ix for ix in
139 |                     range(len(parent_nodes))}
140 |         ids = [np.nonzero(adj_mat[:, n.id])[0][0] for n in nodes[l]]
141 |         ixs = [id_to_ix[id] for id in ids]
142 |         parent_ixs[l] = ixs
143 |     return parent_ixs
144 | 
145 | 
146 | def offset_node_lists(node_lists):
147 |     """Offset the ids in the list of node lists.
148 | 
149 |     Args:
150 |         node_lists: List of Lists of Nodes.
151 | 
152 |     Returns:
153 |         List of Lists of Nodes.
154 |     """
155 |     cumsums = cumsum(node_lists)
156 |     for list_ix in range(len(node_lists)):
157 |         for node in node_lists[list_ix]:
158 |             offset = cumsums[list_ix - 1] if list_ix > 0 else 0
159 |             node.id = node.id + offset
160 |             node.parent_id = node.parent_id + offset \
161 |                 if node.parent_id >= 0 \
162 |                 else -1
163 |             node.text_ix = node.text_ix + offset
164 |     return node_lists
165 | 
166 | 
167 | # Model Classes
168 | 
169 | 
170 | class Forest:
171 |     """Forest data structure.
172 | 
173 |     Designed for the parallel processing of trees in a batch. Will offset the ids
174 |     of its constituent trees, and define global wirings between all levels,
175 |     allowing each level to be processed in parallel either upwards or downwards.
176 | 
177 |     Attributes:
178 |         trees: List of Trees.
179 |         node_list: List of all nodes in the forest.
180 |         nodes: Dictionary of {Integer: [List of Nodes]}, defining the nodes at
181 |             each level of depth.
182 |         size: Integer, the number of nodes in the forest.
183 |         max_level: Integer, the maximum level (depth) of the deepest tree in the
184 |             forest.
185 |         adj_mat: 2D numpy.ndarray, adjacency matrix for all nodes.
186 |         child_ixs: Dictionary {Int: [List of List of ixs]}, defining the upward
187 |             wirings.
188 |         parent_ixs: Dictionary {Int: [List of ixs]}, defining the downward
189 |             wirings.
190 |     """
191 | 
192 |     def __init__(self, trees):
193 |         """Create a new Forest.
194 | 
195 |         Args:
196 |             trees: List of Trees. They will be processed in order. Pass them in
197 |                 the desired order.
198 |         """
199 |         self.trees = trees
200 |         node_lists = offset_node_lists([tree.node_list for tree in trees])
201 |         self.node_list = flatten_list_of_lists(node_lists)
202 |         self.nodes = get_nodes_at_levels(self.node_list)
203 |         self.size = len(self.node_list)
204 |         self.max_level = get_max_level(self.node_list)
205 |         self.adj_mat = get_adj_mat(self.node_list)
206 |         self.child_ixs = get_child_ixs(self.nodes, self.adj_mat)
207 |         #self.parent_ixs = get_parent_ixs(self.nodes, self.adj_mat)
208 | 
209 | 
210 | class Node:
211 |     """Node data structure.
212 | 
213 |     Attributes:
214 |         tag: String, the tag of the token - e.g. VBP.
215 |         pos: String, the part of speech - e.g. VERB. 
216 |         token: String, the text of the token - e.g. 'do'.
217 |         id: Integer, the unique id of the node in its original tree.
218 |         parent_id: Integer, the unique id of the node's parent in its original
219 |             tree. For ROOT nodes, this should be -1 by convention.
220 |         relationship: String, the relation of this node to the parent - e.g.
221 |             'aux'. For the ROOT of a tree, this should be 'ROOT' by convention.
222 |         text_ix: Integer, the unique index of this node in the order of text, if
223 |             any.
224 |         level: Integer, the level (depth) this node is on in its original tree,
225 |             where the ROOT level is zero-indexed.
226 |         is_leaf: Boolean indicating whether this node is a leaf.
227 |     """
228 | 
229 |     def __init__(self, tag, pos, token, id, parent_id, relationship, text_ix,
230 |                  level, is_leaf):
231 |         """Create a new Node."""
232 |         self.tag = tag
233 |         self.pos = pos
234 |         self.token = token
235 |         self.id = id
236 |         self.parent_id = parent_id
237 |         self.relationship = relationship
238 |         self.text_ix = text_ix
239 |         self.level = level
240 |         self.is_leaf = is_leaf
241 |         self.has_token = token is not None
242 |         self.vocab_ix = None  # For vocab_dict index
243 | 
244 |     def __repr__(self):
245 |         return '\n'.join(['%s: %s' % (key, value)
246 |                           for key, value
247 |                           in self.__dict__.items()])
248 | 
249 | 
250 | class Tree:
251 |     """Tree data structure.
252 | 
253 |     Attributes:
254 |         node_list: List of all nodes in the tree.
255 |         nodes: Dictionary {Int: [List of Nodes]}, giving the Nodes at each level.
256 |         size: Integer, the count of the nodes in the tree.
257 |         max_level: Integer, the max level (depth) of the tree.
258 |         adj_mat: 2D numpy.ndarray, an adjacency matrix giving the relationships
259 |             between all nodes.
260 |         child_ixs: Dictionary {Int: [List of List of ixs]}, defining the upward
261 |             wirings.
262 |         parent_ixs: Dictionary {Int: [List of ixs]}, defining the downward
263 |             wirings.
264 |     """
265 | 
266 |     def __init__(self, nodes):
267 |         """Create a new Tree.
268 | 
269 |         Args:
270 |             nodes: List of Nodes. 
271 | """ 272 | self.node_list = nodes 273 | self.nodes = get_nodes_at_levels(self.node_list) 274 | self.size = len(self.node_list) 275 | self.max_level = get_max_level(self.node_list) 276 | self.adj_mat = get_adj_mat(self.node_list) 277 | self.child_ixs = get_child_ixs(self.nodes, self.adj_mat) 278 | #self.parent_ixs = get_parent_ixs(self.nodes, self.adj_mat) 279 | 280 | 281 | # Parsing Classes and Functions 282 | 283 | 284 | class Queue: 285 | def __init__(self): 286 | self.data = [] 287 | 288 | def empty(self): 289 | return len(self.data) == 0 290 | 291 | def push(self, token, level): 292 | self.data.append((token, level)) 293 | 294 | def pop(self): 295 | token, level = self.data[0] 296 | del self.data[0] 297 | return token, level 298 | 299 | 300 | class Stack: 301 | def __init__(self): 302 | self.items = [] 303 | 304 | def empty(self): 305 | return len(self.items) == 0 306 | 307 | def push(self, sexpr, level, parent_ix): 308 | self.items.append((sexpr, level, parent_ix)) 309 | 310 | def pop(self): 311 | sexpr, level, parent_ix = self.items[-1] 312 | del self.items[-1] 313 | return sexpr, level, parent_ix 314 | 315 | 316 | # Parsing SpaCy Sents 317 | 318 | 319 | def sent_to_tree(sent): 320 | nodes = [] 321 | q = Queue() 322 | head = next(t for t in sent if t.head == t) 323 | q.push(head, 0) 324 | while not q.empty(): 325 | token, level = q.pop() 326 | node = token_to_node(token, level) 327 | nodes.append(node) 328 | for child in token.children: 329 | q.push(child, level + 1) 330 | return Tree(nodes) 331 | 332 | 333 | def token_to_node(token, level): 334 | return Node( 335 | tag=token.tag_, 336 | pos=token.pos_, 337 | token=token.text, 338 | id=token.i, 339 | parent_id=token.head.i if token.head.i != token.i else -1, 340 | relationship=token.dep_, 341 | text_ix=token.i, 342 | level=level, 343 | is_leaf=len(list(token.children)) == 0) 344 | 345 | 346 | # Parsing S-Expressions 347 | 348 | 349 | def tokenize(x): 350 | """Tokenizes S-expression dependency parse trees that come with NLI data. 351 | 352 | This one has been tested here: 353 | https://github.com/timniven/hsnli/blob/master/hsnli/tests/tree_sexpr_tests.py 354 | 355 | Args: 356 | x: String, the tree (or subtree) S-expression. 357 | 358 | Returns: 359 | String, List(String), Boolean: tag, [S-expression for the node], is_leaf 360 | flag indicating whether this node is a leaf. 361 | """ 362 | remove_outer_brackets = x[1:-1] 363 | if '(' not in remove_outer_brackets: # means it's a leaf 364 | split = remove_outer_brackets.split(' ') 365 | tag, data = split[0], [split[1]] 366 | else: 367 | sexpr_tokenized = sexpr.sexpr_tokenize(remove_outer_brackets) 368 | tag = sexpr_tokenized[0] 369 | del sexpr_tokenized[0] 370 | data = sexpr_tokenized 371 | is_leaf = len(data) == 1 and not (data[0][0] == '(' and data[0][-1] == ')') 372 | return tag, data, is_leaf 373 | 374 | 375 | def sexpr_to_tree(sexpr): 376 | """Returns all nodes in a tree. 377 | 378 | Args: 379 | sexpr: String, a sexpr. 380 | 381 | Returns: 382 | Tree. 
383 | """ 384 | nodes = [] 385 | id = -1 386 | text_ix = -1 387 | 388 | stack = Stack() 389 | stack.push(sexpr, 0, id) 390 | 391 | while not stack.empty(): 392 | sexpr, level, parent_id = stack.pop() 393 | tag, data, is_leaf = tokenize(sexpr) 394 | id += 1 395 | if not is_leaf: 396 | for sexpr in reversed(data): # reversing here gives desired order 397 | stack.push(sexpr, level + 1, id) 398 | else: 399 | text_ix += 1 400 | nodes.append(Node( 401 | tag=tag, 402 | pos=None, # don't have it in these sexpr Strings 403 | token=data[0] if is_leaf else None, 404 | id=id, 405 | parent_id=parent_id, 406 | relationship=None, # don't have it 407 | text_ix=text_ix if is_leaf else None, 408 | level=level, 409 | is_leaf=is_leaf)) 410 | 411 | return Tree(nodes) 412 | 413 | 414 | # Combining text into internal nodes 415 | 416 | 417 | def combine_text_at_nodes(tree): 418 | for l in reversed(range(tree.max_level + 1)): 419 | nodes = tree.nodes[l] 420 | for ix in range(len(nodes)): 421 | node = nodes[ix] 422 | # for lowest level, set text_at_node to the token 423 | if l == tree.max_level: 424 | node.text_at_node = node.token 425 | # for higher levels, compose these strings 426 | if l < tree.max_level: 427 | node.text_at_node = node.token if node.token else '' 428 | children = [tree.nodes[l+1][cix] 429 | for cix in tree.child_ixs[l][ix]] 430 | sorted_nodes = sorted([node] + [c for c in children], 431 | key=lambda x: x.id) 432 | nodes_text = [n.text_at_node 433 | for n in sorted_nodes 434 | if n.text_at_node != ''] 435 | node.text_at_node = ' '.join([tok for tok in nodes_text]) 436 | --------------------------------------------------------------------------------
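The s-expression path in cstlstm/tree_batch.py (tokenize, sexpr_to_tree, combine_text_at_nodes) is easiest to see end to end on a tiny input. The following is a minimal illustrative sketch, not a file from this repository: the toy SST-style s-expression is made up, and it assumes the package is importable as cstlstm.tree_batch with nltk installed (sexpr_to_tree relies on nltk.tokenize.sexpr).

# Illustrative sketch only: the toy s-expression below is made up.
from cstlstm import tree_batch

toy = '(3 (2 A) (4 (2 good) (3 film)))'
tree = tree_batch.sexpr_to_tree(toy)      # tags -> Node.tag, leaves carry tokens
tree_batch.combine_text_at_nodes(tree)    # every node now has text_at_node

for level in range(tree.max_level + 1):
    for node in tree.nodes[level]:
        print(level, node.tag, node.text_at_node)
# Expected: the root (tag '3') spans 'A good film', the internal node
# tagged '4' spans 'good film', and the leaves carry single tokens.

For SST trees built this way (see get_sst_trees in data/sst.py), the root tag is the sentence-level sentiment label, and text_at_node is filled bottom-up so internal nodes span the text of their subtrees.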
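Forest is the batching structure the rest of the code builds on: node ids are offset so a whole batch shares one adjacency matrix, and child_ixs gives the per-level upward wiring. Below is a hedged sketch of the dependency-parse path, mirroring get_dep_trees in data/sst.py; it assumes a spaCy installation with the 'en' model available (as this repository expects), and the two sentences are made-up examples.

# Illustrative sketch only: sentences are made up; spaCy's 'en' model is assumed.
import spacy
from cstlstm import tree_batch

nlp = spacy.load('en')
texts = ['The film is warm and funny .', 'It is dull .']

trees = []
for text in texts:
    tree = tree_batch.sent_to_tree(nlp(text))   # dependency parse -> Tree of Nodes
    tree_batch.combine_text_at_nodes(tree)      # set text_at_node before batching
    trees.append(tree)

forest = tree_batch.Forest(trees)               # node ids offset into one id space
print(forest.size, forest.max_level)
for level in range(forest.max_level):
    # child_ixs[level][i] holds the indices (at level + 1) of node i's children.
    print(level, forest.child_ixs[level])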
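Downstream, SSTDataset.collate in data/sst.py is what a training loop actually consumes: each batch comes back as a single Forest carrying labels, annotation_ixs and per-node vocab_ix lookups. The sketch below is illustrative only; it assumes pre-processing has already produced vocab_dict.pkl and annotated_dep_trees.pkl in glovar.PKL_DIR, and the batch size of 25 is an arbitrary choice.

# Illustrative sketch only: assumes the SST pickles and text files are in place.
from data import sst

train, dev, test = sst.get_data()                  # SSTDataset objects
loader = sst.get_data_loader(train, batch_size=25)

for forest in loader:                              # collate() yields a tree_batch.Forest
    # forest.labels holds one sentiment label per annotated node, level by level;
    # forest.annotation_ixs[l] are the annotated positions at level l;
    # node.vocab_ix is the pre-looked-up vocab_dict index for token-bearing nodes.
    print(forest.size, len(forest.labels))
    break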