├── utils
│   ├── __init__.py
│   ├── locked_dropout.py
│   ├── hinton.py
│   ├── utils.py
│   └── listops_data.py
├── data
│   ├── listops
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── load_listops_data.py
│   │   └── make_data.py
│   └── propositionallogic
│       ├── __init__.py
│       ├── test0
│       ├── train0
│       ├── generate_neg_set_data.py
│       └── test1
├── EVALB
│   ├── evalb
│   ├── Makefile
│   ├── tgrep_proc.prl
│   ├── LICENSE
│   ├── sample
│   │   ├── sample.tst
│   │   ├── sample.gld
│   │   ├── sample.prm
│   │   └── sample.rsl
│   ├── COLLINS.prm
│   ├── new.prm
│   ├── README
│   └── evalb.c
├── .gitattributes
├── Ordered_Memory_Slides.pdf
├── requirements.txt
├── .ptignore
├── .gitignore
├── LICENSE
├── README.md
├── ordered_memory.py
├── listops.py
├── sentiment.py
└── proplog.py

/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/data/listops/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/data/propositionallogic/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/EVALB/evalb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/HEAD/EVALB/evalb
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
 1 | # Auto detect text files and perform LF normalization
 2 | * text=auto
--------------------------------------------------------------------------------
/EVALB/Makefile:
--------------------------------------------------------------------------------
 1 | all: evalb
 2 | 
 3 | evalb: evalb.c
 4 | 	gcc -Wall -g -o evalb evalb.c
--------------------------------------------------------------------------------
/data/propositionallogic/test0:
--------------------------------------------------------------------------------
 1 | # a b
 2 | # a e
 3 | = b b
 4 | # b d
 5 | # f e
 6 | # c e
--------------------------------------------------------------------------------
/Ordered_Memory_Slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikangshen/Ordered-Memory/HEAD/Ordered_Memory_Slides.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy>=1.13.3
 2 | matplotlib>=3.0.3
 3 | python_gflags
 4 | nltk
 5 | spacy
 6 | torch
 7 | torchtext
--------------------------------------------------------------------------------
/EVALB/tgrep_proc.prl:
--------------------------------------------------------------------------------
 1 | #!/usr/local/bin/perl
 2 | 
 3 | while(<>)
 4 | {
 5 |     if(m/TOP/) #skip lines which are blank
 6 |     {
 7 |         print;
 8 |     }
 9 | }
10 | 
--------------------------------------------------------------------------------
/.ptignore:
--------------------------------------------------------------------------------
 1 | *.pt
 2 | Philly*
 3 | exp_logic
 4 | *.diff
 5 | data/data_scan
 6 | data/scan
 7 | data/propositionallogic_exp1
 8 | data/mnli
 9 | data/SST
10 | data/penn
11 | data/listops
12 | data/treebank_proc.conllu
13 | data/test_proc.conllu
14 | 
-------------------------------------------------------------------------------- /data/propositionallogic/train0: -------------------------------------------------------------------------------- 1 | = a a 2 | # d f 3 | = f f 4 | # b e 5 | # d c 6 | # c f 7 | # e a 8 | # f c 9 | # e c 10 | # e f 11 | # e d 12 | # f d 13 | # a d 14 | # c a 15 | # c b 16 | # c d 17 | # f b 18 | # d e 19 | # d a 20 | # d b 21 | # a f 22 | = d d 23 | # f a 24 | # a c 25 | # b a 26 | # b c 27 | = e e 28 | # b f 29 | = c c 30 | # e b 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | core.* 2 | 3 | *.hdf5 4 | *.pt 5 | .vector_cache 6 | .data 7 | 8 | *.xml 9 | *.iml 10 | 11 | 12 | # vim and gedit cache: 13 | *.swp 14 | *.swo 15 | *.swn 16 | *.swl 17 | *~ 18 | 19 | # cluster logs 20 | SMART_DISPATCH_LOGS/* 21 | 22 | # model params 23 | model/* 24 | params/* 25 | tmptrees/* 26 | 27 | # logs 28 | tblogs/* 29 | logs/* 30 | 31 | # Byte-compiled / optimized / DLL files 32 | __pycache__/ 33 | *.py[cod] 34 | *$py.class 35 | *.pyc 36 | 37 | *.log 38 | 39 | -------------------------------------------------------------------------------- /data/listops/base.py: -------------------------------------------------------------------------------- 1 | from spinn import util 2 | 3 | NUMBERS = list(range(10)) 4 | 5 | FIXED_VOCABULARY = {str(x): i + 1 for i, x in enumerate(NUMBERS)} 6 | FIXED_VOCABULARY.update({ 7 | util.PADDING_TOKEN: 0, 8 | "[MIN": len(FIXED_VOCABULARY) + 1, 9 | "[MAX": len(FIXED_VOCABULARY) + 2, 10 | "[FIRST": len(FIXED_VOCABULARY) + 3, 11 | "[LAST": len(FIXED_VOCABULARY) + 4, 12 | "[MED": len(FIXED_VOCABULARY) + 5, 13 | "[SM": len(FIXED_VOCABULARY) + 6, 14 | "[PM": len(FIXED_VOCABULARY) + 7, 15 | "[FLSUM": len(FIXED_VOCABULARY) + 8, 16 | "]": len(FIXED_VOCABULARY) + 9 17 | }) 18 | assert len(set(FIXED_VOCABULARY.values())) == len(list(FIXED_VOCABULARY.values())) 19 | -------------------------------------------------------------------------------- /utils/locked_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | class LockedDropout(nn.Module): 6 | def __init__(self, dropout=0.5, dim=0): 7 | super().__init__() 8 | 9 | assert dim in [0, 1] 10 | self.dim = dim 11 | self.dropout = dropout 12 | 13 | def forward(self, x): 14 | assert len(x.size()) == 3 15 | if not self.training or not self.dropout: 16 | return x 17 | if self.dim == 0: 18 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout) 19 | elif self.dim == 1: 20 | m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout) 21 | mask = Variable(m, requires_grad=False) / (1 - self.dropout) 22 | mask = mask.expand_as(x) 23 | return mask * x 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yikang Shen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do 
so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/hinton.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | chars = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"] 4 | 5 | 6 | class BarHack(str): 7 | 8 | def __str__(self): 9 | return self.internal 10 | 11 | def __len__(self): 12 | return 1 13 | 14 | 15 | def plot(arr, max_val=None): 16 | if max_val is None: 17 | max_arr = arr 18 | max_val = max(abs(np.max(max_arr)), abs(np.min(max_arr))) 19 | 20 | opts = np.get_printoptions() 21 | np.set_printoptions(edgeitems=500) 22 | fig = np.array2string(arr, 23 | formatter={ 24 | 'float_kind': lambda x: visual(x, max_val), 25 | 'int_kind': lambda x: visual(x, max_val)}, 26 | max_line_width=5000 27 | ) 28 | np.set_printoptions(**opts) 29 | 30 | return fig 31 | 32 | 33 | def visual(val, max_val): 34 | val = np.clip(val, 0, max_val) 35 | if abs(val) == max_val: 36 | step = len(chars) - 1 37 | else: 38 | step = int(abs(float(val) / max_val) * len(chars)) 39 | colourstart = "" 40 | colourend = "" 41 | if val < 0: 42 | colourstart, colourend = '\033[90m', '\033[0m' 43 | return colourstart + chars[step] + colourend 44 | -------------------------------------------------------------------------------- /EVALB/LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ordered Memory 2 | 3 | This repository contains the code used for [Ordered Memory](https://arxiv.org/abs/1910.13466). 4 | 5 | The code comes with instructions for experiments: 6 | + [propositional logic experiments](https://www.aclweb.org/anthology/W15-4002.pdf) 7 | 8 | + [ListOps](https://arxiv.org/pdf/1804.06028.pdf) 9 | 10 | + [SST](https://nlp.stanford.edu/sentiment/treebank.html) 11 | 12 | If you use this code or our results in your research, please cite as appropriate: 13 | 14 | ``` 15 | @incollection{NIPS2019_8748, 16 | title = {Ordered Memory}, 17 | author = {Shen, Yikang and Tan, Shawn and Hosseini, Arian and Lin, Zhouhan and Sordoni, Alessandro and Courville, Aaron C}, 18 | booktitle = {Advances in Neural Information Processing Systems 32}, 19 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 20 | pages = {5038--5049}, 21 | year = {2019}, 22 | publisher = {Curran Associates, Inc.}, 23 | url = {http://papers.nips.cc/paper/8748-ordered-memory.pdf} 24 | } 25 | 26 | ``` 27 | 28 | ## Software Requirements 29 | 30 | Python 3, PyTorch 1.2, and torchtext are required for the current codebase. 31 | 32 | ## Experiments 33 | 34 | ### Propositional Logic 35 | 36 | + `python -u proplog.py --cuda --save logic.pt` 37 | 38 | ### ListOps 39 | 40 | + `python -u listops.py --cuda --name listops.pt` 41 | 42 | ### SST 43 | 44 | + `python -u main.py --subtrees --cuda --name sentiment.pt --glove/--elmo (--fine-grained)` 45 | 46 | -------------------------------------------------------------------------------- /EVALB/sample/sample.tst: -------------------------------------------------------------------------------- 1 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 2 | (S (A (P this)) (B (Q is) (C (R a) (T test)))) 3 | (S (A (P this)) (B (Q is) (A (R a) (U test)))) 4 | (S (C (P this)) (B (Q is) (A (R a) (U test)))) 5 | (S (A (P this)) (B (Q is) (R a) (A (T test)))) 6 | (S (A (P this) (Q is)) (A (R a) (T test))) 7 | (S (P this) (Q is) (R a) (T test)) 8 | (P this) (Q is) (R a) (T test) 9 | (S (A (P this)) (B (Q is) (A (A (R a) (T test))))) 10 | (S (A (P this)) (B (Q is) (A (A (A (A (A (R a) (T test)))))))) 11 | 12 | (S (A (P this)) (B (Q was) (A (A (R a) (T test))))) 13 | (S (A (P this)) (B (Q is) (U not) (A (A (R a) (T test))))) 14 | 15 | (TOP (S (A (P this)) (B (Q is) (A (R a) (T test))))) 16 | (S (A (P this)) (NONE *) (B (Q is) (A (R a) (T test)))) 17 | (S (A (P this)) (S (NONE abc) (A (NONE *))) (B (Q is) (A (R a) (T test)))) 18 | (S (A (P this)) (B (Q is) (A (R a) (TT test)))) 19 | (S (A (P This)) (B (Q is) (A (R a) (T test)))) 20 | (S (A (P That)) (B (Q is) (A (R a) (T test)))) 21 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test)))) 23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *)) 24 | (S (A (P this)) (B (Q is) 
(A (R a) (T test))) (: *)) 25 | -------------------------------------------------------------------------------- /EVALB/sample/sample.gld: -------------------------------------------------------------------------------- 1 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 2 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 3 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 4 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 5 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 6 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 7 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 8 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 9 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 10 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 11 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 12 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 13 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 14 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 15 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 16 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 17 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 18 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 19 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 20 | (S (A (P this)) (B (Q is) (A (R a) (T test)))) 21 | (S (A-SBJ-1 (P this)) (B-WHATEVER (Q is) (A (R a) (T test)))) 22 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test))) (A (P this)) (B (Q is) (A (R a) (T test)))) 23 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (-NONE- *)) 24 | (S (A (P this)) (B (Q is) (A (R a) (T test))) (: *)) 25 | -------------------------------------------------------------------------------- /data/listops/load_listops_data.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from spinn import util 3 | 4 | from spinn.data.listops.base import FIXED_VOCABULARY 5 | 6 | SENTENCE_PAIR_DATA = False 7 | OUTPUTS = list(range(10)) 8 | LABEL_MAP = {str(x): i for i, x in enumerate(OUTPUTS)} 9 | 10 | Node = namedtuple('Node', 'tag span') 11 | 12 | 13 | def spans(transitions, tokens=None): 14 | n = (len(transitions) + 1) // 2 15 | stack = [] 16 | buf = [Node("leaf", (l, r)) for l, r in zip(list(range(n)), list(range(1, n + 1)))] 17 | buf = list(reversed(buf)) 18 | 19 | nodes = [] 20 | reduced = [False] * n 21 | 22 | def SHIFT(item): 23 | nodes.append(item) 24 | return item 25 | 26 | def REDUCE(l, r): 27 | tag = None 28 | i = r.span[1] - 1 29 | if tokens is not None and tokens[i] == ']' and not reduced[i]: 30 | reduced[i] = True 31 | tag = "struct" 32 | new_stack_item = Node(tag=tag, span=(l.span[0], r.span[1])) 33 | nodes.append(new_stack_item) 34 | return new_stack_item 35 | 36 | for t in transitions: 37 | if t == 0: 38 | stack.append(SHIFT(buf.pop())) 39 | elif t == 1: 40 | r, l = stack.pop(), stack.pop() 41 | stack.append(REDUCE(l, r)) 42 | 43 | return nodes 44 | 45 | 46 | def load_data(path, lowercase=None, choose=lambda x: True, eval_mode=False): 47 | examples = [] 48 | with open(path) as f: 49 | for example_id, line in enumerate(f): 50 | line = line.strip() 51 | label, seq = line.split('\t') 52 | if len(seq) <= 1: 53 | continue 54 | 55 | 
tokens, transitions = util.ConvertBinaryBracketedSeq( 56 | seq.split(' ')) 57 | 58 | example = {} 59 | example["label"] = label 60 | example["sentence"] = seq 61 | example["tokens"] = tokens 62 | example["transitions"] = transitions 63 | example["example_id"] = str(example_id) 64 | 65 | examples.append(example) 66 | return examples 67 | -------------------------------------------------------------------------------- /data/listops/make_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | MIN = "[MIN" 5 | MAX = "[MAX" 6 | MED = "[MED" 7 | FIRST = "[FIRST" 8 | LAST = "[LAST" 9 | SUM_MOD = "[SM" 10 | END = "]" 11 | 12 | OPERATORS = [MIN, MAX, MED, SUM_MOD] # , FIRST, LAST] 13 | VALUES = range(10) 14 | 15 | VALUE_P = 0.25 16 | MAX_ARGS = 5 17 | MAX_DEPTH = 20 18 | 19 | DATA_POINTS = 1000000 20 | 21 | 22 | def generate_tree(depth): 23 | if depth < MAX_DEPTH: 24 | r = random.random() 25 | else: 26 | r = 1 27 | 28 | if r > VALUE_P: 29 | value = random.choice(VALUES) 30 | return value 31 | else: 32 | num_values = random.randint(2, MAX_ARGS) 33 | values = [] 34 | for _ in range(num_values): 35 | values.append(generate_tree(depth + 1)) 36 | 37 | op = random.choice(OPERATORS) 38 | t = (op, values[0]) 39 | for value in values[1:]: 40 | t = (t, value) 41 | t = (t, END) 42 | return t 43 | 44 | 45 | def to_string(t, parens=True): 46 | if isinstance(t, str): 47 | return t 48 | elif isinstance(t, int): 49 | return str(t) 50 | else: 51 | if parens: 52 | return '( ' + to_string(t[0]) + ' ' + to_string(t[1]) + ' )' 53 | 54 | 55 | def to_value(t): 56 | if not isinstance(t, tuple): 57 | return t 58 | l = to_value(t[0]) 59 | r = to_value(t[1]) 60 | if l in OPERATORS: # Create an unsaturated function. 61 | return (l, [r]) 62 | elif r == END: # l must be an unsaturated function. 63 | if l[0] == MIN: 64 | return min(l[1]) 65 | elif l[0] == MAX: 66 | return max(l[1]) 67 | elif l[0] == FIRST: 68 | return l[1][0] 69 | elif l[0] == LAST: 70 | return l[1][-1] 71 | elif l[0] == MED: 72 | return int(np.median(l[1])) 73 | elif l[0] == SUM_MOD: 74 | return (np.sum(l[1]) % 10) 75 | elif isinstance(l, tuple): # We've hit an unsaturated function and an argument. 76 | return (l[0], l[1] + [r]) 77 | 78 | 79 | data = set() 80 | while len(data) < DATA_POINTS: 81 | data.add(generate_tree(1)) 82 | 83 | for example in data: 84 | print(str(to_value(example)) + '\t' + to_string(example)) -------------------------------------------------------------------------------- /EVALB/COLLINS.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## 0: No debugging ## 4 | ## 1: print data for individual sentence ## 5 | ##------------------------------------------## 6 | DEBUG 0 7 | 8 | ##------------------------------------------## 9 | ## MAX error ## 10 | ## Number of error to stop the process. ## 11 | ## This is useful if there could be ## 12 | ## tokanization error. ## 13 | ## The process will stop when this number## 14 | ## of errors are accumulated. 
## 15 | ##------------------------------------------## 16 | MAX_ERROR 10 17 | 18 | ##------------------------------------------## 19 | ## Cut-off length for statistics ## 20 | ## At the end of evaluation, the ## 21 | ## statistics for the senetnces of length## 22 | ## less than or equal to this number will## 23 | ## be shown, on top of the statistics ## 24 | ## for all the sentences ## 25 | ##------------------------------------------## 26 | CUTOFF_LEN 40 27 | 28 | ##------------------------------------------## 29 | ## unlabeled or labeled bracketing ## 30 | ## 0: unlabeled bracketing ## 31 | ## 1: labeled bracketing ## 32 | ##------------------------------------------## 33 | LABELED 0 34 | 35 | ##------------------------------------------## 36 | ## Delete labels ## 37 | ## list of labels to be ignored. ## 38 | ## If it is a pre-terminal label, delete ## 39 | ## the word along with the brackets. ## 40 | ## If it is a non-terminal label, just ## 41 | ## delete the brackets (don't delete ## 42 | ## deildrens). ## 43 | ##------------------------------------------## 44 | DELETE_LABEL ROOT 45 | 46 | ##------------------------------------------## 47 | ## Delete labels for length calculation ## 48 | ## list of labels to be ignored for ## 49 | ## length calculation purpose ## 50 | ##------------------------------------------## 51 | DELETE_LABEL_FOR_LENGTH -NONE- 52 | 53 | ##------------------------------------------## 54 | ## Equivalent labels, words ## 55 | ## the pairs are considered equivalent ## 56 | ## This is non-directional. ## 57 | ##------------------------------------------## 58 | EQ_LABEL ADVP PRT 59 | 60 | # EQ_WORD Example example 61 | -------------------------------------------------------------------------------- /EVALB/sample/sample.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## print out data for individual sentence ## 4 | ##------------------------------------------## 5 | DEBUG 0 6 | 7 | ##------------------------------------------## 8 | ## MAX error ## 9 | ## Number of error to stop the process. ## 10 | ## This is useful if there could be ## 11 | ## tokanization error. ## 12 | ## The process will stop when this number## 13 | ## of errors are accumulated. ## 14 | ##------------------------------------------## 15 | MAX_ERROR 10 16 | 17 | ##------------------------------------------## 18 | ## Cut-off length for statistics ## 19 | ## At the end of evaluation, the ## 20 | ## statistics for the senetnces of length## 21 | ## less than or equal to this number will## 22 | ## be shown, on top of the statistics ## 23 | ## for all the sentences ## 24 | ##------------------------------------------## 25 | CUTOFF_LEN 40 26 | 27 | ##------------------------------------------## 28 | ## unlabeled or labeled bracketing ## 29 | ## 0: unlabeled bracketing ## 30 | ## 1: labeled bracketing ## 31 | ##------------------------------------------## 32 | LABELED 1 33 | 34 | ##------------------------------------------## 35 | ## Delete labels ## 36 | ## list of labels to be ignored. ## 37 | ## If it is a pre-terminal label, delete ## 38 | ## the word along with the brackets. ## 39 | ## If it is a non-terminal label, just ## 40 | ## delete the brackets (don't delete ## 41 | ## deildrens). 
## 42 | ##------------------------------------------## 43 | DELETE_LABEL TOP 44 | DELETE_LABEL -NONE- 45 | DELETE_LABEL , 46 | DELETE_LABEL : 47 | DELETE_LABEL `` 48 | DELETE_LABEL '' 49 | 50 | ##------------------------------------------## 51 | ## Delete labels for length calculation ## 52 | ## list of labels to be ignored for ## 53 | ## length calculation purpose ## 54 | ##------------------------------------------## 55 | DELETE_LABEL_FOR_LENGTH -NONE- 56 | 57 | 58 | ##------------------------------------------## 59 | ## Equivalent labels, words ## 60 | ## the pairs are considered equivalent ## 61 | ## This is non-directional. ## 62 | ##------------------------------------------## 63 | EQ_LABEL T TT 64 | 65 | EQ_WORD This this 66 | -------------------------------------------------------------------------------- /EVALB/sample/sample.rsl: -------------------------------------------------------------------------------- 1 | Sent. Matched Bracket Cross Correct Tag 2 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy 3 | ============================================================================ 4 | 1 4 0 100.00 100.00 4 4 4 0 4 4 100.00 5 | 2 4 0 75.00 75.00 3 4 4 0 4 4 100.00 6 | 3 4 0 100.00 100.00 4 4 4 0 4 3 75.00 7 | 4 4 0 75.00 75.00 3 4 4 0 4 3 75.00 8 | 5 4 0 75.00 75.00 3 4 4 0 4 4 100.00 9 | 6 4 0 50.00 66.67 2 4 3 1 4 4 100.00 10 | 7 4 0 25.00 100.00 1 4 1 0 4 4 100.00 11 | 8 4 0 0.00 0.00 0 4 0 0 4 4 100.00 12 | 9 4 0 100.00 80.00 4 4 5 0 4 4 100.00 13 | 10 4 0 100.00 50.00 4 4 8 0 4 4 100.00 14 | 11 4 2 0.00 0.00 0 0 0 0 4 0 0.00 15 | 12 4 1 0.00 0.00 0 0 0 0 4 0 0.00 16 | 13 4 1 0.00 0.00 0 0 0 0 4 0 0.00 17 | 14 4 2 0.00 0.00 0 0 0 0 4 0 0.00 18 | 15 4 0 100.00 100.00 4 4 4 0 4 4 100.00 19 | 16 4 1 0.00 0.00 0 0 0 0 4 0 0.00 20 | 17 4 1 0.00 0.00 0 0 0 0 4 0 0.00 21 | 18 4 0 100.00 100.00 4 4 4 0 4 4 100.00 22 | 19 4 0 100.00 100.00 4 4 4 0 4 4 100.00 23 | 20 4 1 0.00 0.00 0 0 0 0 4 0 0.00 24 | 21 4 0 100.00 100.00 4 4 4 0 4 4 100.00 25 | 22 44 0 100.00 100.00 34 34 34 0 44 44 100.00 26 | 23 4 0 100.00 100.00 4 4 4 0 4 4 100.00 27 | 24 5 0 100.00 100.00 4 4 4 0 4 4 100.00 28 | ============================================================================ 29 | 87.76 90.53 86 98 95 16 108 106 98.15 30 | === Summary === 31 | 32 | -- All -- 33 | Number of sentence = 24 34 | Number of Error sentence = 5 35 | Number of Skip sentence = 2 36 | Number of Valid sentence = 17 37 | Bracketing Recall = 87.76 38 | Bracketing Precision = 90.53 39 | Complete match = 52.94 40 | Average crossing = 0.06 41 | No crossing = 94.12 42 | 2 or less crossing = 100.00 43 | Tagging accuracy = 98.15 44 | 45 | -- len<=40 -- 46 | Number of sentence = 23 47 | Number of Error sentence = 5 48 | Number of Skip sentence = 2 49 | Number of Valid sentence = 16 50 | Bracketing Recall = 81.25 51 | Bracketing Precision = 85.25 52 | Complete match = 50.00 53 | Average crossing = 0.06 54 | No crossing = 93.75 55 | 2 or less crossing = 100.00 56 | Tagging accuracy = 96.88 57 | -------------------------------------------------------------------------------- /EVALB/new.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## 0: No debugging ## 4 | ## 1: print data for individual sentence ## 5 | ## 2: print detailed bracketing info ## 6 | ##------------------------------------------## 7 | DEBUG 0 8 | 9 | ##------------------------------------------## 10 | ## MAX error ## 11 | ## Number of error to 
stop the process. ## 12 | ## This is useful if there could be ## 13 | ## tokanization error. ## 14 | ## The process will stop when this number## 15 | ## of errors are accumulated. ## 16 | ##------------------------------------------## 17 | MAX_ERROR 10 18 | 19 | ##------------------------------------------## 20 | ## Cut-off length for statistics ## 21 | ## At the end of evaluation, the ## 22 | ## statistics for the senetnces of length## 23 | ## less than or equal to this number will## 24 | ## be shown, on top of the statistics ## 25 | ## for all the sentences ## 26 | ##------------------------------------------## 27 | CUTOFF_LEN 40 28 | 29 | ##------------------------------------------## 30 | ## unlabeled or labeled bracketing ## 31 | ## 0: unlabeled bracketing ## 32 | ## 1: labeled bracketing ## 33 | ##------------------------------------------## 34 | LABELED 1 35 | 36 | ##------------------------------------------## 37 | ## Delete labels ## 38 | ## list of labels to be ignored. ## 39 | ## If it is a pre-terminal label, delete ## 40 | ## the word along with the brackets. ## 41 | ## If it is a non-terminal label, just ## 42 | ## delete the brackets (don't delete ## 43 | ## deildrens). ## 44 | ##------------------------------------------## 45 | DELETE_LABEL TOP 46 | DELETE_LABEL S1 47 | DELETE_LABEL -NONE- 48 | DELETE_LABEL , 49 | DELETE_LABEL : 50 | DELETE_LABEL `` 51 | DELETE_LABEL '' 52 | DELETE_LABEL . 53 | DELETE_LABEL ? 54 | DELETE_LABEL ! 55 | 56 | ##------------------------------------------## 57 | ## Delete labels for length calculation ## 58 | ## list of labels to be ignored for ## 59 | ## length calculation purpose ## 60 | ##------------------------------------------## 61 | DELETE_LABEL_FOR_LENGTH -NONE- 62 | 63 | ##------------------------------------------## 64 | ## Labels to be considered for misquote ## 65 | ## (could be possesive or quote) ## 66 | ##------------------------------------------## 67 | QUOTE_LABEL `` 68 | QUOTE_LABEL '' 69 | QUOTE_LABEL POS 70 | 71 | ##------------------------------------------## 72 | ## These ones are less common, but ## 73 | ## are on occasion output by parsers: ## 74 | ##------------------------------------------## 75 | QUOTE_LABEL NN 76 | QUOTE_LABEL CD 77 | QUOTE_LABEL VBZ 78 | QUOTE_LABEL : 79 | 80 | ##------------------------------------------## 81 | ## Equivalent labels, words ## 82 | ## the pairs are considered equivalent ## 83 | ## This is non-directional. 
## 84 | ##------------------------------------------## 85 | EQ_LABEL ADVP PRT 86 | 87 | # EQ_WORD Example example 88 | -------------------------------------------------------------------------------- /data/propositionallogic/generate_neg_set_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from itertools import * 3 | from collections import * 4 | import random 5 | 6 | 7 | def powerset(iterable): 8 | "From itertools: powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 9 | s = list(iterable) 10 | return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) 11 | 12 | 13 | def get_candidate_worlds(num_vars): 14 | return powerset(set(range(num_vars))) 15 | 16 | 17 | def get_satisfying_worlds_for_tree(tree, candidate_worlds): 18 | if isinstance(tree, tuple): 19 | if tree[0] == 'not': 20 | child = get_satisfying_worlds_for_tree(tree[1], candidate_worlds) 21 | return candidate_worlds.difference(child) 22 | else: 23 | left = get_satisfying_worlds_for_tree(tree[0], candidate_worlds) 24 | right = get_satisfying_worlds_for_tree(tree[2], candidate_worlds) 25 | if tree[1] == "and": 26 | return left.intersection(right) 27 | elif tree[1] == "or": 28 | return left.union(right) 29 | else: 30 | print 'syntax error', tree 31 | else: 32 | result = [] 33 | for world in candidate_worlds: 34 | if tree in world: 35 | result.append(world) 36 | return set(result) 37 | 38 | 39 | def compute_relation(left, right, universe): 40 | ne_intersection = left.intersection(right) 41 | ne_just_left = left.difference(right) 42 | ne_just_right = right.difference(left) 43 | ne_outside = universe.difference(left.union(right)) 44 | if ne_intersection and not ne_just_right and not ne_just_left and ne_outside: 45 | return "=" 46 | elif ne_intersection and ne_just_right and not ne_just_left and ne_outside: 47 | return "<" 48 | elif ne_intersection and not ne_just_right and ne_just_left and ne_outside: 49 | return ">" 50 | elif not ne_intersection and ne_just_right and ne_just_left and not ne_outside: 51 | return "^" 52 | elif not ne_intersection and ne_just_right and ne_just_left and ne_outside: 53 | return "|" 54 | elif ne_intersection and ne_just_right and ne_just_left and not ne_outside: 55 | return "v" 56 | else: 57 | return "#" 58 | 59 | 60 | def create_sub_statement(universe, maxlen): 61 | operator = random.choice(operators) 62 | temp = () 63 | if operator == '0' or maxlen < 2: 64 | temp = random.choice(list(universe)) 65 | else: 66 | lhs = create_sub_statement(universe, maxlen / 2) 67 | rhs = create_sub_statement(universe, maxlen / 2) 68 | temp = tuple([lhs, operator, rhs]) 69 | 70 | neg_or_none = random.choice(neg_or_nones) 71 | if neg_or_none == '0': 72 | return temp 73 | else: 74 | return tuple([neg_or_none, temp]) 75 | 76 | 77 | def uniq(seq, idfun=None): 78 | # order preserving 79 | if idfun is None: 80 | def idfun(x): 81 | return x 82 | seen = {} 83 | result = [] 84 | for item in seq: 85 | marker = idfun(item) 86 | # in old Python versions: 87 | # if seen.has_key(marker) 88 | # but in new ones: 89 | if marker in seen: 90 | continue 91 | seen[marker] = 1 92 | result.append(item) 93 | return result 94 | 95 | 96 | def to_string(expr, individuals): 97 | if isinstance(expr, int): 98 | return individuals[expr] 99 | if isinstance(expr, str): 100 | return expr 101 | elif len(expr) == 3: 102 | return "( " + to_string(expr[0], individuals) + " ( " + to_string(expr[1], individuals) + " " + to_string(expr[2], individuals) + " ) )" 103 | 
else: 104 | return "( " + to_string(expr[0], individuals) + " " + to_string(expr[1], individuals) + " )" 105 | 106 | 107 | def get_len(tree): 108 | if isinstance(tree, tuple): 109 | accum = 0 110 | for entry in tree: 111 | accum += get_len(entry) 112 | return accum 113 | elif tree == 'and' or tree == 'or' or tree == 'not': 114 | return 1 115 | else: 116 | return 0 117 | 118 | individuals = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] 119 | 120 | worlds = set(get_candidate_worlds(6)) 121 | universe = set(range(6)) 122 | 123 | neg_or_nones = ['not', '0', '0'] 124 | operators = ['and', 'or', 'and', 'or', '0', '0', '0', '0', '0'] 125 | 126 | 127 | stats = Counter() 128 | total = 0 129 | outputs = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 130 | 6: [], 7: [], 8: [], 9: [], 10: [], 11: [], 12: []} 131 | while total < 500000: 132 | subuniverse = random.sample(universe, 4) 133 | lhs = create_sub_statement(subuniverse, 12) 134 | rhs = create_sub_statement(subuniverse, 12) 135 | sat1 = get_satisfying_worlds_for_tree(lhs, worlds) 136 | sat2 = get_satisfying_worlds_for_tree(rhs, worlds) 137 | if sat1 == worlds or len(sat1) == 0: 138 | continue 139 | if sat2 == worlds or len(sat2) == 0: 140 | continue 141 | rel = compute_relation(sat1, sat2, worlds) 142 | 143 | if rel != "?": 144 | stats[rel] += 1 145 | total += 1 146 | max_len = min(max(get_len(rhs), get_len(lhs)), 12) 147 | outputs[max_len].append("" + rel + "\t" + to_string( 148 | lhs, individuals) + "\t" + to_string(rhs, individuals)) 149 | 150 | TRAIN_PORTION = 0.85 151 | 152 | for length in outputs.keys(): 153 | outputs[length] = uniq(outputs[length]) 154 | 155 | filename = 'train' + str(length) 156 | f = open(filename, 'w') 157 | for i in range(int(TRAIN_PORTION * len(outputs[length]))): 158 | output = outputs[length][i] 159 | f.write(output + "\n") 160 | f.close() 161 | 162 | filename = 'test' + str(length) 163 | f = open(filename, 'w') 164 | for i in range(int(TRAIN_PORTION * len(outputs[length])), len(outputs[length])): 165 | output = outputs[length][i] 166 | f.write(output + "\n") 167 | f.close() 168 | 169 | print stats 170 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque, Counter 2 | 3 | import torch 4 | 5 | 6 | class Dictionary(object): 7 | def __init__(self): 8 | self.word2idx = {} 9 | self.idx2word = [] 10 | self.counter = Counter() 11 | self.total = 0 12 | 13 | def add_word(self, word): 14 | if word not in self.word2idx: 15 | self.idx2word.append(word) 16 | self.word2idx[word] = len(self.idx2word) - 1 17 | token_id = self.word2idx[word] 18 | self.counter[token_id] += 1 19 | self.total += 1 20 | return self.word2idx[word] 21 | 22 | def __len__(self): 23 | return len(self.idx2word) 24 | 25 | 26 | def build_tree(depth, sen): 27 | depth = depth 28 | queue = deque(sen) 29 | stack = [queue.popleft()] 30 | head = depth[0] - 1 31 | for point in depth[1:]: 32 | d = point - head 33 | if d > 0: 34 | for _ in range(d): 35 | if len(stack) == 1: 36 | break 37 | x1 = stack.pop() 38 | x2 = stack.pop() 39 | stack.append([x2, x1]) 40 | if len(queue) > 0: 41 | stack.append(queue.popleft()) 42 | head = point - 1 43 | while len(stack) > 2 and isinstance(stack, list): 44 | x1 = stack.pop() 45 | x2 = stack.pop() 46 | stack.append([x2, x1]) 47 | while len(stack) == 1 and isinstance(stack, list): 48 | stack = stack.pop() 49 | return stack 50 | 51 | 52 | def repackage_hidden(h): 53 | """Wraps hidden states 
in new Tensors, 54 | to detach them from their history.""" 55 | if isinstance(h, torch.Tensor): 56 | return h.detach() 57 | elif h is None: 58 | return None 59 | else: 60 | return tuple(repackage_hidden(v) for v in h) 61 | 62 | 63 | def batchify(data, bsz, args): 64 | # Work out how cleanly we can divide the dataset into bsz parts. 65 | nbatch = data.size(0) // bsz 66 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 67 | data = data.narrow(0, 0, nbatch * bsz) 68 | # Evenly divide the data across the bsz batches. 69 | data = data.view(bsz, -1).t().contiguous() 70 | if args.cuda: 71 | data = data.cuda() 72 | return data 73 | 74 | 75 | def get_batch(source, i, args, seq_len=None, evaluation=False): 76 | seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i) 77 | data = source[i:i + seq_len] 78 | target = source[i + 1:i + 1 + seq_len].view(-1) 79 | return data, target 80 | 81 | 82 | def evalb(pred_tree_list, targ_tree_list, evalb_path="EVALB"): 83 | import os 84 | import subprocess 85 | import re 86 | import nltk 87 | import tempfile 88 | 89 | temp_path = tempfile.TemporaryDirectory(prefix="evalb-") 90 | # temp_path = './test/' 91 | temp_file_path = os.path.join(temp_path.name, "pred_trees.txt") 92 | temp_targ_path = os.path.join(temp_path.name, "true_trees.txt") 93 | temp_eval_path = os.path.join(temp_path.name, "evals.txt") 94 | 95 | print("Temp: {}, {}".format(temp_file_path, temp_targ_path)) 96 | temp_tree_file = open(temp_file_path, "w") 97 | temp_targ_file = open(temp_targ_path, "w") 98 | 99 | for pred_tree, targ_tree in zip(pred_tree_list, targ_tree_list): 100 | def process_str_tree(str_tree): 101 | return re.sub('[ |\n]+', ' ', str_tree) 102 | 103 | def list2tree(node): 104 | if isinstance(node, nltk.Tree): 105 | return node 106 | if isinstance(node, list): 107 | tree = [] 108 | for child in node: 109 | tree.append(list2tree(child)) 110 | return nltk.Tree('', tree) 111 | elif isinstance(node, str): 112 | return nltk.Tree('', [node]) 113 | 114 | if re.search(r'[RRB|rrb]- [0-9]', process_str_tree(str(list2tree(targ_tree)))) is not None: 115 | continue 116 | temp_tree_file.write(process_str_tree(str(list2tree(pred_tree))) + '\n') 117 | temp_targ_file.write(process_str_tree(str(list2tree(targ_tree))) + '\n') 118 | 119 | temp_tree_file.close() 120 | temp_targ_file.close() 121 | 122 | evalb_dir = os.path.join(os.getcwd(), evalb_path) 123 | evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm") 124 | evalb_program_path = os.path.join(evalb_dir, "evalb") 125 | command = "{} -p {} {} {} > {}".format( 126 | evalb_program_path, 127 | evalb_param_path, 128 | temp_targ_path, 129 | temp_file_path, 130 | temp_eval_path) 131 | 132 | subprocess.run(command, shell=True) 133 | 134 | with open(temp_eval_path) as infile: 135 | for line in infile: 136 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line) 137 | if match: 138 | evalb_recall = float(match.group(1)) 139 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line) 140 | if match: 141 | evalb_precision = float(match.group(1)) 142 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line) 143 | if match: 144 | evalb_fscore = float(match.group(1)) 145 | break 146 | 147 | temp_path.cleanup() 148 | 149 | print('-' * 80) 150 | print('Evalb Prec:', evalb_precision, 151 | ', Evalb Reca:', evalb_recall, 152 | ', Evalb F1:', evalb_fscore) 153 | 154 | return evalb_fscore 155 | 156 | 157 | def remove_bracket(tree): 158 | if isinstance(tree, str): 159 | if tree in ['(', ')']: 160 | return None 
161 | else: 162 | return tree 163 | elif isinstance(tree, list): 164 | new_tree = [] 165 | for child in tree: 166 | new_child = remove_bracket(child) 167 | if new_child is not None: 168 | new_tree.append(new_child) 169 | if new_tree == []: 170 | return None 171 | else: 172 | while len(new_tree) == 1 and isinstance(new_tree, list): 173 | new_tree = new_tree[0] 174 | return new_tree 175 | 176 | 177 | def char2tree(s): 178 | stack = [] 179 | for w in s: 180 | if w == '(': 181 | stack.append(w) 182 | elif w == ')': 183 | node = [] 184 | e = stack.pop() 185 | while not e == '(': 186 | node.append(e) 187 | e = stack.pop() 188 | node = node[::-1] 189 | stack.append(node) 190 | else: 191 | stack.append(w) 192 | while len(stack) == 1 and isinstance(stack, list): 193 | stack = stack[0] 194 | return stack 195 | 196 | 197 | 198 | def makedirs(name): 199 | """helper function for python 2 and 3 to call os.makedirs() 200 | avoiding an error if the directory to be created already exists""" 201 | 202 | import os, errno 203 | 204 | try: 205 | os.makedirs(name) 206 | except OSError as ex: 207 | if ex.errno == errno.EEXIST and os.path.isdir(name): 208 | # ignore existing directory 209 | pass 210 | else: 211 | # a different error happened 212 | raise 213 | 214 | 215 | def ConvertBinaryBracketedSeq(seq): 216 | T_SHIFT = 0 217 | T_REDUCE = 1 218 | T_SKIP = 2 219 | 220 | tokens, transitions = [], [] 221 | for item in seq: 222 | if item != "(": 223 | if item != ")": 224 | tokens.append(item) 225 | transitions.append(T_REDUCE if item == ")" else T_SHIFT) 226 | return tokens, transitions 227 | -------------------------------------------------------------------------------- /ordered_memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class Distribution(nn.Module): 8 | def __init__(self, nslot, hidden_size, dropout): 9 | super(Distribution, self).__init__() 10 | 11 | self.query = nn.Sequential( 12 | nn.Dropout(dropout), 13 | nn.Linear(hidden_size, hidden_size), 14 | nn.LayerNorm(hidden_size), 15 | ) 16 | 17 | self.key = nn.Sequential( 18 | nn.Dropout(dropout), 19 | nn.Linear(hidden_size, hidden_size), 20 | nn.LayerNorm(hidden_size), 21 | ) 22 | 23 | self.beta = nn.Sequential( 24 | nn.ReLU(), 25 | nn.Dropout(dropout), 26 | nn.Linear(hidden_size, 1), 27 | ) 28 | 29 | self.hidden_size = hidden_size 30 | 31 | def init_p(self, bsz, nslot): 32 | return None 33 | 34 | @staticmethod 35 | def process_softmax(beta, prev_p): 36 | if prev_p is None: 37 | return torch.zeros_like(beta), torch.ones_like(beta), torch.zeros_like(beta) 38 | 39 | beta_normalized = beta - beta.max(dim=-1)[0][:, None] 40 | x = torch.exp(beta_normalized) 41 | 42 | prev_cp = torch.cumsum(prev_p, dim=1) 43 | mask = prev_cp[:, 1:] 44 | mask = mask.masked_fill(mask < 1e-5, 0.) 
45 | mask = F.pad(mask, (0, 1), value=1) 46 | 47 | x_masked = x * mask 48 | 49 | p = F.normalize(x_masked, p=1) 50 | cp = torch.cumsum(p, dim=1) 51 | rcp = torch.cumsum(p.flip([1]), dim=1).flip([1]) 52 | return cp, rcp, p 53 | 54 | def forward(self, in_val, prev_out_M, prev_p): 55 | query = self.query(in_val) 56 | key = self.key(prev_out_M) 57 | beta = self.beta(query[:, None, :] + key).squeeze(dim=2) 58 | beta = beta / math.sqrt(self.hidden_size) 59 | cp, rcp, p = self.process_softmax(beta, prev_p) 60 | return cp, rcp, p 61 | 62 | 63 | class Cell(nn.Module): 64 | def __init__(self, hidden_size, dropout, activation=None): 65 | super(Cell, self).__init__() 66 | self.hidden_size = hidden_size 67 | self.cell_hidden_size = 4 * hidden_size 68 | 69 | self.input_t = nn.Sequential( 70 | nn.Dropout(dropout), 71 | nn.Linear(hidden_size * 2, self.cell_hidden_size), 72 | nn.ReLU(), 73 | nn.Dropout(dropout), 74 | nn.Linear(self.cell_hidden_size, hidden_size * 4), 75 | ) 76 | 77 | self.gates = nn.Sequential( 78 | nn.Sigmoid(), 79 | ) 80 | 81 | assert activation is not None 82 | self.activation = activation 83 | 84 | self.drop = nn.Dropout(dropout) 85 | 86 | def forward(self, vi, hi): 87 | input = torch.cat([vi, hi], dim=-1) 88 | 89 | g_input, cell = self.input_t(input).split( 90 | (self.hidden_size * 3, self.hidden_size), 91 | dim=-1 92 | ) 93 | 94 | gates = self.gates(g_input) 95 | vg, hg, cg = gates.chunk(3, dim=1) 96 | output = self.activation(vg * vi + hg * hi + cg * cell) 97 | return output 98 | 99 | 100 | class OrderedMemoryRecurrent(nn.Module): 101 | def __init__(self, input_size, slot_size, nslot, 102 | dropout=0.2, dropoutm=0.2): 103 | super(OrderedMemoryRecurrent, self).__init__() 104 | 105 | self.activation = nn.LayerNorm(slot_size) 106 | self.input_projection = nn.Sequential( 107 | nn.Linear(input_size, slot_size), 108 | self.activation 109 | ) 110 | 111 | self.distribution = Distribution(nslot, slot_size, dropoutm) 112 | 113 | self.cell = Cell(slot_size, dropout, activation=self.activation) 114 | 115 | self.nslot = nslot 116 | self.slot_size = slot_size 117 | self.input_size = input_size 118 | 119 | def init_hidden(self, bsz): 120 | weight = next(self.parameters()).data 121 | zeros = weight.new(bsz, self.nslot, self.slot_size).zero_() 122 | p = self.distribution.init_p(bsz, self.nslot) 123 | return (zeros, zeros, p) 124 | 125 | def omr_step(self, in_val, prev_M, prev_out_M, prev_p): 126 | batch_size, nslot, slot_size = prev_M.size() 127 | _batch_size, slot_size = in_val.size() 128 | 129 | assert self.slot_size == slot_size 130 | assert self.nslot == nslot 131 | assert batch_size == _batch_size 132 | 133 | cp, rcp, p = self.distribution(in_val, prev_out_M, prev_p) 134 | 135 | curr_M = prev_M * (1 - rcp)[:, :, None] + prev_out_M * rcp[:, :, None] 136 | 137 | M_list = [] 138 | h = in_val 139 | for i in range(nslot): 140 | if i == nslot - 1 or cp[:, i+1].max() > 0: 141 | h = self.cell(h, curr_M[:, i, :]) 142 | h = in_val * (1 - cp)[:, i, None] + h * cp[:, i, None] 143 | M_list.append(h) 144 | out_M = torch.stack(M_list, dim=1) 145 | 146 | output = out_M[:, -1] 147 | return output, curr_M, out_M, p 148 | 149 | def forward(self, X, hidden, mask=None): 150 | prev_M, prev_memory_output, prev_p = hidden 151 | output_list = [] 152 | p_list = [] 153 | X_projected = self.input_projection(X) 154 | if mask is not None: 155 | padded = ~mask 156 | for t in range(X_projected.size(0)): 157 | output, prev_M, prev_memory_output, prev_p = self.omr_step( 158 | X_projected[t], prev_M, prev_memory_output, prev_p) 
159 | if mask is not None: 160 | padded_1 = padded[t, :, None] 161 | padded_2 = padded[t, :, None, None] 162 | output = output.masked_fill(padded_1, 0.) 163 | prev_p = prev_p.masked_fill(padded_1, 0.) 164 | prev_M = prev_M.masked_fill(padded_2, 0.) 165 | prev_memory_output = prev_memory_output.masked_fill(padded_2, 0.) 166 | output_list.append(output) 167 | p_list.append(prev_p) 168 | 169 | output = torch.stack(output_list) 170 | probs = torch.stack(p_list) 171 | 172 | return (output, 173 | probs, 174 | (prev_M, prev_memory_output, prev_p)) 175 | 176 | 177 | class OrderedMemory(nn.Module): 178 | def __init__(self, input_size, slot_size, 179 | nslot, dropout=0.2, dropoutm=0.1, 180 | bidirection=False): 181 | super(OrderedMemory, self).__init__() 182 | 183 | self.OM_forward = OrderedMemoryRecurrent(input_size, slot_size, nslot, 184 | dropout=dropout, dropoutm=dropoutm) 185 | if bidirection: 186 | self.OM_backward = OrderedMemoryRecurrent(input_size, slot_size, nslot, 187 | dropout=dropout, dropoutm=dropoutm) 188 | 189 | self.bidirection = bidirection 190 | 191 | def init_hidden(self, bsz): 192 | return self.OM_forward.init_hidden(bsz) 193 | 194 | def forward(self, X, mask, output_last=False): 195 | bsz = X.size(1) 196 | lengths = mask.sum(0) 197 | init_hidden = self.init_hidden(bsz) 198 | 199 | output_list = [] 200 | prob_list = [] 201 | 202 | om_output_forward, prob_forward, _ = self.OM_forward(X, init_hidden, mask) 203 | if output_last: 204 | output_list.append(om_output_forward[-1]) 205 | else: 206 | output_list.append(om_output_forward[lengths - 1, torch.arange(bsz).long()]) 207 | prob_list.append(prob_forward) 208 | 209 | if self.bidirection: 210 | om_output_backward, prob_backward, _ = self.OM_backward(X.flip([0]), init_hidden, mask.flip([0])) 211 | output_list.append(om_output_backward[-1]) 212 | prob_list.append(prob_backward.flip([0])) 213 | 214 | output = torch.cat(output_list, dim=-1) 215 | self.probs = prob_list[0] 216 | 217 | return output 218 | -------------------------------------------------------------------------------- /EVALB/README: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | # # 3 | # Bug fix and additional functionality for evalb # 4 | # # 5 | # This updated version of evalb fixes a bug in which sentences # 6 | # were incorrectly categorized as "length mismatch" when the # 7 | # the parse output had certain mislabeled parts-of-speech. # 8 | # # 9 | # The bug was the result of evalb treating one of the tags (in # 10 | # gold or test) as a label to be deleted (see sections [6],[7] # 11 | # for details), but not the corresponding tag in the other. # 12 | # This most often occurs with punctuation. See the subdir # 13 | # "bug" for an example gld and tst file demonstating the bug, # 14 | # as well as output of evalb with and without the bug fix. # 15 | # # 16 | # For the present version in case of length mismatch, the nodes # 17 | # causing the imbalance are reinserted to resolve the miscount. # 18 | # If the lengths of gold and test truly differ, the error is # 19 | # still reported. The parameter file "new.prm" (derived from # 20 | # COLLINS.prm) shows how to add new potential mislabelings for # 21 | # quotes (",``,',`). # 22 | # # 23 | # I have preserved DJB's revision for modern compilers except # 24 | # for the delcaration of "exit" which is provided by stdlib. 
# 25 | # # 26 | # Other changes: # 27 | # # 28 | # * output of F-Measure in addition to precision and recall # 29 | # (I did not update the documention in section [4] for this) # 30 | # # 31 | # * more comprehensive DEBUG output that includes bracketing # 32 | # information as evalb is processing each sentence # 33 | # (useful in working through this, and peraps other bugs). # 34 | # Use either the "-D" run-time switch or set DEBUG to 2 in # 35 | # the parameter file. # 36 | # # 37 | # * added DELETE_LABEL lines in new.prm for S1 nodes produced # 38 | # by the Charniak parser and "?", "!" punctuation produced by # 39 | # the Bikel parser. # 40 | # # 41 | # # 42 | # David Ellis (Brown) # 43 | # # 44 | # January.2006 # 45 | ################################################################# 46 | 47 | ################################################################# 48 | # # 49 | # Update of evalb for modern compilers # 50 | # # 51 | # This is an updated version of evalb, for use with modern C # 52 | # compilers. There are a few updates, each marked in the code: # 53 | # # 54 | # /* DJB: explanation of comment */ # 55 | # # 56 | # The updates are purely to help compilation with recent # 57 | # versions of GCC (and other C compilers). There are *NO* other # 58 | # changes to the algorithm itself. # 59 | # # 60 | # I have made these changes following recommendations from # 61 | # users of the Corpora Mailing List, especially Peet Morris and # 62 | # Ramon Ziai. # 63 | # # 64 | # David Brooks (Birmingham) # 65 | # # 66 | # September.2005 # 67 | ################################################################# 68 | 69 | ################################################################# 70 | # # 71 | # README file for evalb # 72 | # # 73 | # Satoshi Sekine (NYU) # 74 | # Mike Collins (UPenn) # 75 | # # 76 | # October.1997 # 77 | ################################################################# 78 | 79 | Contents of this README: 80 | 81 | [0] COPYRIGHT 82 | [1] INTRODUCTION 83 | [2] INSTALLATION AND RUN 84 | [3] OPTIONS 85 | [4] OUTPUT FORMAT FROM THE SCORER 86 | [5] HOW TO CREATE A GOLDFILE FROM THE TREEBANK 87 | [6] THE PARAMETER FILE 88 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM 89 | 90 | 91 | [0] COPYRIGHT 92 | 93 | The authors abandon the copyright of this program. Everyone is 94 | permitted to copy and distribute the program or a portion of the program 95 | with no charge and no restrictions unless it is harmful to someone. 96 | 97 | However, the authors are delightful for the user's kindness of proper 98 | usage and letting the authors know bugs or problems. 99 | 100 | This software is provided "AS IS", and the authors make no warranties, 101 | express or implied. 102 | 103 | To legally enforce the abandonment of copyright, this package is released 104 | under the Unlicense (see LICENSE). 105 | 106 | [1] INTRODUCTION 107 | 108 | Evaluation of bracketing looks simple, but in fact, there are minor 109 | differences from system to system. This is a program to parametarize 110 | such minor differences and to give an informative result. 111 | 112 | "evalb" evaluates bracketing accuracy in a test-file against a gold-file. 113 | It returns recall, precision, tagging accuracy. It uses an identical 114 | algorithm to that used in (Collins ACL97). 
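
Within this repository the scorer is normally driven programmatically rather than by hand: the evalb() helper in utils/utils.py writes the predicted and gold trees to temporary files, runs the compiled binary with COLLINS.prm, and parses the bracketing scores back out of the report. A minimal sketch of that invocation pattern (the function name and file paths below are illustrative placeholders, not names fixed by the scorer):

    import re
    import subprocess

    def run_evalb(evalb_bin, param_file, gold_file, test_file, report_file):
        # Same command shape as utils/utils.py: evalb -p <prm> <gold> <test> > <report>
        command = "{} -p {} {} {} > {}".format(
            evalb_bin, param_file, gold_file, test_file, report_file)
        subprocess.run(command, shell=True)

        # Pull the summary scores out of the report, as the wrapper does.
        patterns = {
            "recall": r"Bracketing Recall\s+=\s+(\d+\.\d+)",
            "precision": r"Bracketing Precision\s+=\s+(\d+\.\d+)",
            "fmeasure": r"Bracketing FMeasure\s+=\s+(\d+\.\d+)",
        }
        scores = {}
        with open(report_file) as report:
            for line in report:
                for name, pattern in patterns.items():
                    match = re.match(pattern, line)
                    if match:
                        scores[name] = float(match.group(1))
        return scores
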
115 | 116 | 117 | [2] Installation and Run 118 | 119 | To compile the scorer, type 120 | 121 | > make 122 | 123 | 124 | To run the scorer: 125 | 126 | > evalb -p Parameter_file Gold_file Test_file 127 | 128 | 129 | For example to use the sample files: 130 | 131 | > evalb -p sample.prm sample.gld sample.tst 132 | 133 | 134 | 135 | [3] OPTIONS 136 | 137 | You can specify system parameters in the command line options. 138 | Other options concerning to evaluation metrix should be specified 139 | in parameter file, described later. 140 | 141 | -p param_file parameter file 142 | -d debug mode 143 | -e n number of error to kill (default=10) 144 | -h help 145 | 146 | 147 | 148 | [4] OUTPUT FORMAT FROM THE SCORER 149 | 150 | The scorer gives individual scores for each sentence, for 151 | example: 152 | 153 | Sent. Matched Bracket Cross Correct Tag 154 | ID Len. Stat. Recal Prec. Bracket gold test Bracket Words Tags Accracy 155 | ============================================================================ 156 | 1 8 0 100.00 100.00 5 5 5 0 6 5 83.33 157 | 158 | At the end of the output the === Summary === section gives statistics 159 | for all sentences, and for sentences <=40 words in length. The summary 160 | contains the following information: 161 | 162 | i) Number of sentences -- total number of sentences. 163 | 164 | ii) Number of Error/Skip sentences -- should both be 0 if there is no 165 | problem with the parsed/gold files. 166 | 167 | iii) Number of valid sentences = Number of sentences - Number of Error/Skip 168 | sentences 169 | 170 | iv) Bracketing recall = (number of correct constituents) 171 | ---------------------------------------- 172 | (number of constituents in the goldfile) 173 | 174 | v) Bracketing precision = (number of correct constituents) 175 | ---------------------------------------- 176 | (number of constituents in the parsed file) 177 | 178 | vi) Complete match = percentaage of sentences where recall and precision are 179 | both 100%. 180 | 181 | vii) Average crossing = (number of constituents crossing a goldfile constituen 182 | ---------------------------------------------------- 183 | (number of sentences) 184 | 185 | viii) No crossing = percentage of sentences which have 0 crossing brackets. 186 | 187 | ix) 2 or less crossing = percentage of sentences which have <=2 crossing brackets. 188 | 189 | x) Tagging accuracy = percentage of correct POS tags (but see [5].3 for exact 190 | details of what is counted). 191 | 192 | 193 | 194 | [5] HOW TO CREATE A GOLDFILE FROM THE PENN TREEBANK 195 | 196 | 197 | The gold and parsed files are in a format similar to this: 198 | 199 | (TOP (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))) 200 | 201 | To create a gold file from the treebank: 202 | 203 | tgrep -wn '/.*/' | tgrep_proc.prl 204 | 205 | will produce a goldfile in the required format. ("tgrep -wn '/.*/'" prints 206 | parse trees, "tgrep_process.prl" just skips blank lines). 207 | 208 | For example, to produce a goldfile for section 23 of the treebank: 209 | 210 | tgrep -wn '/.*/' | tail +90895 | tgrep_process.prl | sed 2416q > sec23.gold 211 | 212 | 213 | 214 | [6] THE PARAMETER (.prm) FILE 215 | 216 | 217 | The .prm file sets options regarding the scoring method. COLLINS.prm gives 218 | the same scoring behaviour as the scorer used in (Collins 97). The options 219 | chosen were: 220 | 221 | 1) LABELED 1 222 | 223 | to give labelled precision/recall figures, i.e. 
a constituent must have the 224 | same span *and* label as a constituent in the goldfile. 225 | 226 | 2) DELETE_LABEL TOP 227 | 228 | Don't count the "TOP" label (which is always given in the output of tgrep) 229 | when scoring. 230 | 231 | 3) DELETE_LABEL -NONE- 232 | 233 | Remove traces (and all constituents which dominate nothing but traces) when 234 | scoring. For example 235 | 236 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .))) 237 | 238 | would be processed to give 239 | 240 | .... (VP (VBD reported)) (. .))) 241 | 242 | 243 | 4) 244 | DELETE_LABEL , -- for the purposes of scoring remove punctuation 245 | DELETE_LABEL : 246 | DELETE_LABEL `` 247 | DELETE_LABEL '' 248 | DELETE_LABEL . 249 | 250 | 5) DELETE_LABEL_FOR_LENGTH -NONE- -- don't include traces when calculating 251 | the length of a sentence (important 252 | when classifying a sentence as <=40 253 | words or >40 words) 254 | 255 | 6) EQ_LABEL ADVP PRT 256 | 257 | Count ADVP and PRT as being the same label when scoring. 258 | 259 | 260 | 261 | 262 | [7] MORE DETAILS ABOUT THE SCORING ALGORITHM 263 | 264 | 265 | 1) The scorer initially processes the files to remove all nodes specified 266 | by DELETE_LABEL in the .prm file. It also recursively removes nodes which 267 | dominate nothing due to all their children being removed. For example, if 268 | -NONE- is specified as a label to be deleted, 269 | 270 | .... (VP (VBD reported) (SBAR (-NONE- 0) (S (-NONE- *T*-1)))) (. .))) 271 | 272 | would be processed to give 273 | 274 | .... (VP (VBD reported)) (. .))) 275 | 276 | 2) The scorer also removes all functional tags attached to non-terminals 277 | (functional tags are prefixed with "-" or "=" in the treebank). For example 278 | "NP-SBJ" is processed to give "NP", "NP=2" is changed to "NP". 279 | 280 | 281 | 3) Tagging accuracy counts tags for all words *except* any tags which are 282 | deleted by a DELETE_LABEL specification in the .prm file. (For example, for 283 | COLLINS.prm, punctuation tagged as "," ":" etc. would not be included). 284 | 285 | 4) When calculating the length of a sentence, all words with POS tags not 286 | included in the "DELETE_LABEL_FOR_LENGTH" list in the .prm file are 287 | counted. (For COLLINS.prm, only "-NONE-" is specified in this list, so 288 | traces are removed before calculating the length of the sentence). 289 | 290 | 5) There are some subtleties in scoring when either the goldfile or parsed 291 | file contains multiple constituents for the same span which have the same 292 | non-terminal label. e.g. (NP (NP the man)) If the goldfile contains n 293 | constituents for the same span, and the parsed file contains m constituents 294 | with that nonterminal, the scorer works as follows: 295 | 296 | i) If m>n, then the precision is n/m, recall is 100% 297 | 298 | ii) If n>m, then the precision is 100%, recall is m/n. 299 | 300 | iii) If n==m, recall and precision are both 100%. 
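The three cases above amount to matching each duplicated (label, span) constituent at most min(n, m) times. A small numeric sketch in Python, assuming that min-matching reading (the authoritative logic is in evalb.c):

    # The gold file has n identical constituents for a span, the parsed file has m.
    n, m = 2, 3                   # e.g. gold contains (NP 0 2) twice, test three times
    matched = min(n, m)           # each copy can be matched at most once
    precision = matched / m       # 2/3, i.e. case i) with m > n
    recall = matched / n          # 2/2 = 100%, as stated in case i)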
301 | -------------------------------------------------------------------------------- /data/propositionallogic/test1: -------------------------------------------------------------------------------- 1 | # ( d ( and e ) ) ( d ( and c ) ) 2 | # ( e ( and d ) ) ( f ( or f ) ) 3 | > ( b ( or c ) ) ( b ( and d ) ) 4 | # ( e ( or e ) ) ( a ( and c ) ) 5 | # ( not d ) ( c ( or b ) ) 6 | # ( a ( and a ) ) ( d ( and f ) ) 7 | # ( d ( or f ) ) ( b ( or a ) ) 8 | # ( d ( and b ) ) ( a ( or a ) ) 9 | > ( f ( or b ) ) ( f ( or f ) ) 10 | # ( c ( or c ) ) ( b ( and b ) ) 11 | # ( f ( and f ) ) ( a ( and a ) ) 12 | < ( d ( and a ) ) ( a ( or c ) ) 13 | < ( f ( and c ) ) ( f ( and f ) ) 14 | # ( e ( and a ) ) ( f ( and a ) ) 15 | # ( a ( and c ) ) ( not f ) 16 | # ( f ( and f ) ) ( b ( or e ) ) 17 | > ( b ( or f ) ) ( b ( and d ) ) 18 | # ( d ( and f ) ) ( e ( or e ) ) 19 | # ( f ( and e ) ) ( b ( and f ) ) 20 | > ( a ( or a ) ) ( a ( and e ) ) 21 | < ( b ( and e ) ) ( e ( or b ) ) 22 | # ( d ( or e ) ) ( b ( and b ) ) 23 | # ( a ( or d ) ) ( d ( or b ) ) 24 | # ( b ( and b ) ) ( c ( and c ) ) 25 | < ( a ( or a ) ) ( e ( or a ) ) 26 | # ( c ( and f ) ) ( not b ) 27 | # ( b ( or a ) ) ( not c ) 28 | # ( b ( and a ) ) ( e ( and e ) ) 29 | # ( not b ) ( c ( and e ) ) 30 | > ( d ( or d ) ) ( e ( and d ) ) 31 | > ( f ( or a ) ) ( c ( and f ) ) 32 | # ( f ( or a ) ) ( b ( or f ) ) 33 | # ( e ( and f ) ) ( a ( and a ) ) 34 | # ( e ( or a ) ) ( b ( and b ) ) 35 | # ( a ( or e ) ) ( e ( or c ) ) 36 | > ( d ( or c ) ) ( d ( and b ) ) 37 | < ( b ( and d ) ) ( d ( or c ) ) 38 | > ( d ( or f ) ) ( a ( and f ) ) 39 | < ( a ( and f ) ) ( e ( or a ) ) 40 | = ( c ( and e ) ) ( e ( and c ) ) 41 | > ( f ( or f ) ) ( f ( and b ) ) 42 | # ( f ( and c ) ) ( d ( or d ) ) 43 | # ( c ( and c ) ) ( d ( or d ) ) 44 | # ( d ( and d ) ) ( b ( and c ) ) 45 | > ( b ( or f ) ) ( b ( and b ) ) 46 | < ( a ( and e ) ) ( e ( or a ) ) 47 | = ( c ( and d ) ) ( d ( and c ) ) 48 | # ( d ( and d ) ) ( e ( or e ) ) 49 | # ( c ( or d ) ) ( d ( or a ) ) 50 | # ( a ( or e ) ) ( c ( or c ) ) 51 | > ( d ( or f ) ) ( d ( and e ) ) 52 | # ( d ( or d ) ) ( e ( or b ) ) 53 | = ( e ( or f ) ) ( f ( or e ) ) 54 | # ( b ( and d ) ) ( c ( and d ) ) 55 | > ( c ( or b ) ) ( b ( and b ) ) 56 | # ( d ( or d ) ) ( a ( or e ) ) 57 | > ( e ( and e ) ) ( e ( and b ) ) 58 | # ( b ( or a ) ) ( f ( or c ) ) 59 | # ( f ( or d ) ) ( d ( or c ) ) 60 | > ( c ( or a ) ) ( a ( and a ) ) 61 | = ( e ( or d ) ) ( e ( or d ) ) 62 | = ( c ( or e ) ) ( c ( or e ) ) 63 | # ( a ( and e ) ) ( f ( or d ) ) 64 | # ( a ( or b ) ) ( c ( or d ) ) 65 | > ( e ( or a ) ) ( a ( and a ) ) 66 | # ( e ( or f ) ) ( d ( or f ) ) 67 | # ( b ( or f ) ) ( e ( or c ) ) 68 | # ( b ( or a ) ) ( e ( or f ) ) 69 | # ( a ( or e ) ) ( f ( and b ) ) 70 | < ( f ( and a ) ) ( b ( or f ) ) 71 | # ( a ( or e ) ) ( a ( or b ) ) 72 | < ( b ( or b ) ) ( b ( or c ) ) 73 | # ( f ( or c ) ) ( a ( and d ) ) 74 | # ( c ( and a ) ) ( e ( and d ) ) 75 | < ( f ( and a ) ) ( a ( or f ) ) 76 | # ( a ( and e ) ) ( a ( and d ) ) 77 | # ( e ( and a ) ) ( e ( and d ) ) 78 | # ( b ( or d ) ) ( f ( and f ) ) 79 | < ( a ( and c ) ) ( c ( or d ) ) 80 | < ( c ( or c ) ) ( c ( or f ) ) 81 | > ( d ( or b ) ) ( b ( and a ) ) 82 | # ( d ( and e ) ) ( b ( and c ) ) 83 | # ( e ( or a ) ) ( e ( or f ) ) 84 | # ( c ( and d ) ) ( c ( and e ) ) 85 | # ( a ( or a ) ) ( c ( or f ) ) 86 | < ( c ( or c ) ) ( c ( or e ) ) 87 | # ( f ( or f ) ) ( c ( or a ) ) 88 | # ( c ( and a ) ) ( e ( and a ) ) 89 | > ( d ( or a ) ) ( c ( and d ) ) 90 | # ( e ( 
or e ) ) ( d ( and c ) ) 91 | > ( b ( or d ) ) ( b ( and a ) ) 92 | < ( b ( and d ) ) ( b ( or d ) ) 93 | # ( c ( and a ) ) ( f ( or f ) ) 94 | # ( b ( or e ) ) ( d ( or b ) ) 95 | # ( e ( and f ) ) ( e ( and a ) ) 96 | > ( a ( or a ) ) ( a ( and d ) ) 97 | > ( f ( and f ) ) ( f ( and b ) ) 98 | > ( e ( or f ) ) ( c ( and e ) ) 99 | > ( c ( or d ) ) ( d ( and c ) ) 100 | > ( e ( or d ) ) ( e ( and e ) ) 101 | # ( f ( or e ) ) ( e ( or a ) ) 102 | > ( e ( or b ) ) ( e ( and a ) ) 103 | # ( d ( or f ) ) ( b ( and a ) ) 104 | < ( e ( and a ) ) ( e ( and e ) ) 105 | < ( d ( and d ) ) ( a ( or d ) ) 106 | < ( b ( and e ) ) ( c ( or e ) ) 107 | > ( f ( or b ) ) ( f ( and f ) ) 108 | # ( d ( and a ) ) ( c ( and b ) ) 109 | = ( e ( and c ) ) ( c ( and e ) ) 110 | # ( c ( and c ) ) ( f ( or d ) ) 111 | < ( e ( and e ) ) ( c ( or e ) ) 112 | # ( e ( or e ) ) ( d ( and d ) ) 113 | # ( f ( and f ) ) ( b ( or a ) ) 114 | # ( not c ) ( b ( or e ) ) 115 | # ( a ( or c ) ) ( f ( or f ) ) 116 | # ( d ( or e ) ) ( b ( or a ) ) 117 | > ( a ( or f ) ) ( d ( and a ) ) 118 | > ( a ( or c ) ) ( c ( and f ) ) 119 | # ( not f ) ( a ( and d ) ) 120 | > ( e ( or f ) ) ( e ( and f ) ) 121 | < ( f ( and f ) ) ( e ( or f ) ) 122 | # ( b ( and d ) ) ( b ( and e ) ) 123 | # ( c ( and f ) ) ( c ( and a ) ) 124 | # ( c ( and f ) ) ( a ( and c ) ) 125 | # ( c ( or c ) ) ( f ( and e ) ) 126 | # ( b ( or c ) ) ( c ( or a ) ) 127 | # ( b ( and b ) ) ( e ( or a ) ) 128 | # ( f ( or d ) ) ( b ( or e ) ) 129 | # ( c ( or b ) ) ( e ( or a ) ) 130 | # ( a ( and a ) ) ( d ( and c ) ) 131 | # ( b ( and b ) ) ( c ( and f ) ) 132 | # ( a ( and a ) ) ( c ( or e ) ) 133 | # ( e ( or f ) ) ( c ( or c ) ) 134 | # ( f ( and f ) ) ( a ( and c ) ) 135 | # ( not a ) ( b ( and f ) ) 136 | # ( a ( or e ) ) ( b ( or e ) ) 137 | < ( a ( and e ) ) ( d ( or e ) ) 138 | # ( a ( and a ) ) ( f ( or d ) ) 139 | = ( f ( or e ) ) ( f ( or e ) ) 140 | # ( f ( and f ) ) ( b ( or c ) ) 141 | # ( c ( or f ) ) ( d ( or a ) ) 142 | < ( f ( and a ) ) ( f ( or f ) ) 143 | = ( d ( or b ) ) ( b ( or d ) ) 144 | > ( a ( or b ) ) ( d ( and a ) ) 145 | < ( d ( and e ) ) ( d ( or b ) ) 146 | # ( b ( and b ) ) ( e ( and e ) ) 147 | < ( c ( and f ) ) ( b ( or f ) ) 148 | # ( f ( or e ) ) ( f ( or b ) ) 149 | < ( e ( or e ) ) ( d ( or e ) ) 150 | < ( a ( and e ) ) ( b ( or e ) ) 151 | # ( a ( or f ) ) ( f ( or b ) ) 152 | > ( f ( or a ) ) ( a ( and a ) ) 153 | > ( d ( or b ) ) ( b ( and b ) ) 154 | = ( a ( and d ) ) ( a ( and d ) ) 155 | # ( f ( or f ) ) ( e ( and c ) ) 156 | # ( e ( or c ) ) ( e ( or b ) ) 157 | # ( f ( and e ) ) ( a ( and f ) ) 158 | # ( a ( or c ) ) ( d ( or a ) ) 159 | # ( d ( and d ) ) ( f ( and e ) ) 160 | # ( e ( and d ) ) ( e ( and f ) ) 161 | # ( c ( or a ) ) ( c ( or f ) ) 162 | < ( f ( and f ) ) ( c ( or f ) ) 163 | = ( d ( and d ) ) ( d ( and d ) ) 164 | < ( d ( and e ) ) ( c ( or d ) ) 165 | = ( b ( or b ) ) ( b ( or b ) ) 166 | < ( b ( and a ) ) ( a ( and a ) ) 167 | < ( a ( and c ) ) ( c ( or c ) ) 168 | # ( c ( and c ) ) ( f ( and a ) ) 169 | # ( d ( and b ) ) ( c ( and c ) ) 170 | # ( c ( or d ) ) ( c ( or f ) ) 171 | > ( e ( or a ) ) ( a ( and d ) ) 172 | > ( c ( or d ) ) ( c ( and a ) ) 173 | < ( e ( and a ) ) ( a ( or b ) ) 174 | # ( c ( or e ) ) ( f ( or c ) ) 175 | > ( c ( or a ) ) ( c ( and c ) ) 176 | # ( b ( and f ) ) ( d ( and b ) ) 177 | # ( d ( and a ) ) ( e ( and d ) ) 178 | # ( f ( and b ) ) ( f ( and e ) ) 179 | > ( e ( or b ) ) ( e ( and d ) ) 180 | # ( c ( or a ) ) ( d ( or c ) ) 181 | < ( c ( or c ) ) ( b ( or c ) ) 182 | 
# ( e ( or d ) ) ( e ( or a ) ) 183 | # ( c ( and b ) ) ( a ( and b ) ) 184 | > ( c ( or b ) ) ( b ( or b ) ) 185 | # ( d ( and f ) ) ( e ( and d ) ) 186 | # ( f ( and a ) ) ( e ( and a ) ) 187 | > ( a ( or b ) ) ( a ( or a ) ) 188 | < ( d ( and a ) ) ( a ( and a ) ) 189 | # ( c ( or a ) ) ( d ( or a ) ) 190 | # ( a ( and a ) ) ( c ( or c ) ) 191 | # ( a ( or d ) ) ( a ( or f ) ) 192 | < ( e ( and f ) ) ( f ( or e ) ) 193 | < ( c ( and f ) ) ( c ( or f ) ) 194 | = ( c ( and b ) ) ( c ( and b ) ) 195 | # ( f ( and b ) ) ( a ( and d ) ) 196 | # ( a ( and b ) ) ( f ( or c ) ) 197 | < ( f ( or f ) ) ( f ( or c ) ) 198 | = ( f ( and d ) ) ( d ( and f ) ) 199 | # ( d ( or f ) ) ( d ( or b ) ) 200 | < ( c ( and a ) ) ( a ( or b ) ) 201 | # ( d ( or b ) ) ( e ( or b ) ) 202 | # ( c ( and c ) ) ( d ( and f ) ) 203 | > ( d ( or a ) ) ( a ( or a ) ) 204 | > ( c ( or a ) ) ( a ( and e ) ) 205 | # ( not c ) ( a ( and f ) ) 206 | < ( d ( and f ) ) ( a ( or f ) ) 207 | # ( c ( or f ) ) ( b ( or c ) ) 208 | # ( e ( and d ) ) ( e ( and c ) ) 209 | # ( a ( or d ) ) ( e ( or c ) ) 210 | # ( f ( and f ) ) ( c ( and d ) ) 211 | < ( d ( and d ) ) ( d ( or b ) ) 212 | < ( f ( and d ) ) ( d ( or a ) ) 213 | # ( c ( or c ) ) ( d ( or d ) ) 214 | # ( b ( and a ) ) ( b ( and c ) ) 215 | # ( a ( and f ) ) ( a ( and c ) ) 216 | # ( a ( or d ) ) ( b ( or d ) ) 217 | # ( f ( and b ) ) ( d ( or c ) ) 218 | # ( c ( or d ) ) ( e ( or c ) ) 219 | # ( a ( and b ) ) ( f ( and c ) ) 220 | # ( c ( or b ) ) ( f ( and e ) ) 221 | # ( e ( or a ) ) ( d ( or f ) ) 222 | # ( f ( or f ) ) ( e ( or c ) ) 223 | # ( e ( and f ) ) ( a ( and d ) ) 224 | # ( c ( and c ) ) ( a ( or e ) ) 225 | < ( b ( and f ) ) ( f ( or c ) ) 226 | # ( c ( or c ) ) ( f ( or d ) ) 227 | > ( c ( or b ) ) ( b ( and a ) ) 228 | < ( c ( and b ) ) ( e ( or b ) ) 229 | # ( e ( and b ) ) ( c ( or c ) ) 230 | # ( e ( and e ) ) ( d ( or c ) ) 231 | # ( a ( or d ) ) ( a ( or e ) ) 232 | # ( a ( or b ) ) ( e ( or b ) ) 233 | < ( e ( and b ) ) ( b ( or f ) ) 234 | < ( e ( and f ) ) ( f ( and f ) ) 235 | # ( e ( and b ) ) ( a ( or a ) ) 236 | < ( e ( and c ) ) ( e ( or f ) ) 237 | # ( e ( or e ) ) ( f ( or b ) ) 238 | # ( f ( or f ) ) ( c ( or e ) ) 239 | # ( f ( and a ) ) ( f ( and e ) ) 240 | > ( c ( or b ) ) ( c ( or c ) ) 241 | > ( a ( or f ) ) ( e ( and f ) ) 242 | # ( d ( and f ) ) ( c ( and c ) ) 243 | # ( d ( or e ) ) ( e ( or c ) ) 244 | # ( a ( and d ) ) ( a ( and c ) ) 245 | > ( d ( or b ) ) ( d ( and d ) ) 246 | < ( f ( and a ) ) ( a ( or e ) ) 247 | < ( a ( and c ) ) ( a ( or c ) ) 248 | > ( d ( or d ) ) ( d ( and a ) ) 249 | < ( a ( and b ) ) ( b ( or a ) ) 250 | = ( b ( and f ) ) ( f ( and b ) ) 251 | # ( c ( or e ) ) ( b ( and d ) ) 252 | # ( c ( or b ) ) ( e ( or c ) ) 253 | > ( d ( or b ) ) ( b ( and e ) ) 254 | < ( d ( and f ) ) ( d ( or f ) ) 255 | # ( d ( and f ) ) ( c ( and d ) ) 256 | # ( d ( and b ) ) ( e ( or a ) ) 257 | < ( d ( and c ) ) ( d ( or d ) ) 258 | # ( f ( or a ) ) ( f ( or e ) ) 259 | = ( a ( and d ) ) ( d ( and a ) ) 260 | # ( c ( and c ) ) ( e ( or b ) ) 261 | < ( c ( and e ) ) ( c ( or a ) ) 262 | = ( f ( or b ) ) ( b ( or f ) ) 263 | # ( c ( and f ) ) ( e ( and f ) ) 264 | > ( e ( or f ) ) ( e ( or e ) ) 265 | # ( e ( or c ) ) ( a ( and a ) ) 266 | # ( f ( and f ) ) ( c ( and b ) ) 267 | # ( e ( and a ) ) ( c ( and f ) ) 268 | # ( d ( and a ) ) ( e ( and f ) ) 269 | # ( b ( or e ) ) ( d ( and d ) ) 270 | > ( c ( or e ) ) ( a ( and e ) ) 271 | > ( b ( or c ) ) ( b ( and b ) ) 272 | # ( a ( or e ) ) ( d ( or d ) ) 273 | # ( b ( or e 
) ) ( d ( and a ) ) 274 | # ( f ( and b ) ) ( e ( and f ) ) 275 | # ( not e ) ( f ( or b ) ) 276 | # ( c ( or a ) ) ( b ( and d ) ) 277 | # ( e ( and f ) ) ( b ( and b ) ) 278 | = ( b ( and e ) ) ( e ( and b ) ) 279 | # ( a ( or a ) ) ( c ( or e ) ) 280 | > ( f ( or c ) ) ( f ( and f ) ) 281 | # ( d ( and f ) ) ( b ( or b ) ) 282 | # ( b ( and b ) ) ( c ( or e ) ) 283 | # ( b ( or b ) ) ( a ( or c ) ) 284 | # ( c ( or c ) ) ( d ( and a ) ) 285 | # ( a ( or e ) ) ( a ( or f ) ) 286 | # ( b ( or e ) ) ( c ( or f ) ) 287 | > ( d ( or e ) ) ( d ( or d ) ) 288 | # ( f ( and c ) ) ( a ( and d ) ) 289 | # ( d ( or e ) ) ( d ( or f ) ) 290 | # ( a ( or a ) ) ( b ( or d ) ) 291 | # ( f ( or e ) ) ( d ( or e ) ) 292 | < ( e ( and f ) ) ( e ( or e ) ) 293 | # ( d ( or a ) ) ( f ( or d ) ) 294 | # ( e ( or e ) ) ( c ( or b ) ) 295 | = ( f ( and f ) ) ( f ( or f ) ) 296 | # ( e ( and e ) ) ( d ( and d ) ) 297 | > ( f ( or e ) ) ( e ( and a ) ) 298 | # ( d ( or a ) ) ( f ( and b ) ) 299 | # ( d ( and b ) ) ( d ( and f ) ) 300 | # ( f ( or c ) ) ( e ( or a ) ) 301 | # ( b ( or d ) ) ( b ( or c ) ) 302 | # ( c ( or b ) ) ( d ( or e ) ) 303 | # ( d ( and a ) ) ( e ( and a ) ) 304 | < ( a ( and c ) ) ( a ( or a ) ) 305 | > ( f ( or b ) ) ( f ( and e ) ) 306 | # ( f ( or c ) ) ( e ( or f ) ) 307 | # ( d ( and d ) ) ( f ( or a ) ) 308 | # ( d ( or c ) ) ( e ( and e ) ) 309 | < ( a ( and f ) ) ( a ( or e ) ) 310 | # ( a ( or d ) ) ( c ( and f ) ) 311 | # ( d ( or d ) ) ( c ( and c ) ) 312 | > ( e ( or c ) ) ( b ( and c ) ) 313 | # ( a ( and d ) ) ( b ( and f ) ) 314 | < ( b ( and a ) ) ( b ( and b ) ) 315 | = ( f ( or a ) ) ( f ( or a ) ) 316 | # ( c ( or c ) ) ( b ( and a ) ) 317 | > ( c ( or c ) ) ( f ( and c ) ) 318 | > ( d ( and d ) ) ( d ( and a ) ) 319 | < ( b ( and c ) ) ( c ( or a ) ) 320 | > ( e ( or d ) ) ( d ( and b ) ) 321 | # ( f ( and a ) ) ( c ( or c ) ) 322 | # ( c ( or f ) ) ( d ( or b ) ) 323 | > ( f ( or a ) ) ( f ( and c ) ) 324 | > ( f ( or f ) ) ( c ( and f ) ) 325 | # ( e ( and f ) ) ( d ( and e ) ) 326 | # ( d ( and a ) ) ( d ( and f ) ) 327 | # ( f ( or a ) ) ( e ( and e ) ) 328 | # ( d ( and c ) ) ( c ( and a ) ) 329 | < ( e ( and b ) ) ( e ( or e ) ) 330 | # ( c ( or d ) ) ( f ( or c ) ) 331 | > ( d ( or c ) ) ( a ( and d ) ) 332 | # ( d ( or e ) ) ( a ( or c ) ) 333 | # ( e ( and d ) ) ( f ( and f ) ) 334 | > ( d ( or c ) ) ( e ( and c ) ) 335 | # ( c ( and a ) ) ( e ( and e ) ) 336 | # ( b ( or a ) ) ( f ( and f ) ) 337 | = ( e ( or d ) ) ( d ( or e ) ) 338 | < ( b ( and d ) ) ( e ( or d ) ) 339 | # ( f ( or d ) ) ( a ( or d ) ) 340 | > ( c ( and c ) ) ( c ( and d ) ) 341 | < ( c ( and e ) ) ( b ( or e ) ) 342 | # ( a ( or d ) ) ( f ( or b ) ) 343 | # ( c ( or c ) ) ( b ( or b ) ) 344 | # ( b ( and a ) ) ( e ( or d ) ) 345 | # ( a ( or e ) ) ( c ( or b ) ) 346 | = ( e ( or c ) ) ( c ( or e ) ) 347 | # ( c ( and a ) ) ( b ( or b ) ) 348 | # ( a ( or b ) ) ( a ( or e ) ) 349 | # ( b ( or d ) ) ( a ( or d ) ) 350 | # ( a ( or b ) ) ( c ( and d ) ) 351 | # ( d ( or e ) ) ( f ( or e ) ) 352 | > ( b ( or f ) ) ( a ( and f ) ) 353 | < ( a ( and c ) ) ( e ( or c ) ) 354 | < ( b ( or b ) ) ( b ( or e ) ) 355 | < ( a ( and f ) ) ( d ( or f ) ) 356 | # ( f ( or a ) ) ( d ( and e ) ) 357 | = ( a ( or a ) ) ( a ( and a ) ) 358 | # ( b ( or b ) ) ( e ( or c ) ) 359 | # ( e ( and b ) ) ( c ( and b ) ) 360 | < ( f ( and f ) ) ( f ( or e ) ) 361 | > ( f ( or d ) ) ( f ( and e ) ) 362 | # ( e ( or b ) ) ( b ( or a ) ) 363 | # ( a ( and c ) ) ( f ( and d ) ) 364 | > ( f ( or b ) ) ( b ( and e ) ) 365 
| > ( a ( or e ) ) ( c ( and a ) ) 366 | > ( d ( or c ) ) ( d ( or d ) ) 367 | # ( b ( or d ) ) ( c ( and e ) ) 368 | > ( c ( or d ) ) ( d ( and d ) ) 369 | # ( d ( and d ) ) ( a ( or a ) ) 370 | # ( f ( and a ) ) ( e ( and c ) ) 371 | = ( c ( and c ) ) ( c ( and c ) ) 372 | # ( f ( and f ) ) ( d ( or d ) ) 373 | = ( c ( or c ) ) ( c ( or c ) ) 374 | < ( b ( and b ) ) ( d ( or b ) ) 375 | > ( b ( or e ) ) ( e ( and c ) ) 376 | < ( d ( and b ) ) ( d ( or f ) ) 377 | # ( f ( or d ) ) ( a ( and a ) ) 378 | # ( f ( or f ) ) ( d ( or e ) ) 379 | < ( b ( and c ) ) ( b ( or c ) ) 380 | > ( e ( or c ) ) ( e ( and e ) ) 381 | < ( b ( and b ) ) ( a ( or b ) ) 382 | # ( d ( and d ) ) ( f ( or f ) ) 383 | > ( f ( or b ) ) ( c ( and b ) ) 384 | < ( a ( and c ) ) ( c ( and c ) ) 385 | # ( b ( and d ) ) ( b ( and f ) ) 386 | # ( f ( and c ) ) ( a ( or d ) ) 387 | < ( d ( and e ) ) ( e ( or a ) ) 388 | # ( c ( or b ) ) ( d ( or c ) ) 389 | = ( f ( and a ) ) ( a ( and f ) ) 390 | < ( b ( and d ) ) ( d ( and d ) ) 391 | # ( a ( and f ) ) ( c ( and a ) ) 392 | < ( e ( or e ) ) ( e ( or c ) ) 393 | < ( c ( and d ) ) ( d ( or c ) ) 394 | # ( d ( or d ) ) ( c ( or a ) ) 395 | # ( e ( or f ) ) ( a ( or c ) ) 396 | # ( c ( or e ) ) ( d ( and f ) ) 397 | # ( a ( or b ) ) ( a ( or d ) ) 398 | < ( b ( and e ) ) ( e ( and e ) ) 399 | # ( e ( and e ) ) ( c ( and a ) ) 400 | # ( a ( or a ) ) ( e ( or b ) ) 401 | > ( b ( or d ) ) ( e ( and b ) ) 402 | # ( a ( and b ) ) ( b ( and f ) ) 403 | # ( a ( or f ) ) ( b ( and e ) ) 404 | # ( c ( or e ) ) ( c ( or d ) ) 405 | # ( a ( and a ) ) ( f ( and f ) ) 406 | # ( a ( and d ) ) ( f ( and a ) ) 407 | # ( c ( and c ) ) ( b ( or d ) ) 408 | > ( c ( or b ) ) ( a ( and c ) ) 409 | # ( b ( or d ) ) ( c ( or c ) ) 410 | # ( a ( or e ) ) ( d ( and d ) ) 411 | -------------------------------------------------------------------------------- /listops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | 10 | import ordered_memory 11 | from utils.hinton import plot 12 | from utils.listops_data import load_data_and_embeddings, LABEL_MAP, PADDING_TOKEN, get_batch 13 | from utils.utils import build_tree, char2tree, evalb 14 | 15 | 16 | class ListOpsModel(nn.Module): 17 | def __init__(self, args): 18 | super(ListOpsModel, self).__init__() 19 | 20 | self.args = args 21 | self.padding_idx = args.padding_idx 22 | self.embedding = nn.Embedding(args.ntoken, args.ninp, 23 | padding_idx=self.padding_idx) 24 | 25 | self.encoder = ordered_memory.OrderedMemory(args.ninp, args.nhid, args.nslot, 26 | dropout=args.dropout, dropoutm=args.dropoutm, 27 | bidirection=args.bidirection) 28 | 29 | self.mlp = nn.Sequential( 30 | nn.Dropout(args.dropouto), 31 | nn.Linear(args.nhid * 2 if args.bidirection else args.nhid, args.nout), 32 | ) 33 | 34 | self.drop_input = nn.Dropout(args.dropouti) 35 | self.drop_output = nn.Dropout(args.dropouto) 36 | self.cost = nn.CrossEntropyLoss() 37 | 38 | def forward(self, input): 39 | mask = (input != self.padding_idx).bool() 40 | 41 | emb = self.embedding(input) 42 | emb.transpose_(0, 1) 43 | 44 | mask.transpose_(0, 1) 45 | emb = self.drop_input(emb) 46 | output = self.encoder(emb, mask, output_last=True) 47 | output = self.mlp(output) 48 | return output 49 | 50 | def set_pretrained_embeddings(self, ext_embeddings, ext_word_to_index, word_to_index, finetune=False): 51 | 
assert hasattr(self, 'embedding') 52 | embeddings = self.embedding.weight.data.cpu().numpy() 53 | for word, index in word_to_index.items(): 54 | if word in ext_word_to_index: 55 | embeddings[index] = ext_embeddings[ext_word_to_index[word]] 56 | embeddings = torch.from_numpy(embeddings).to(self.embedding.weight.device) 57 | self.embedding.weight.data.set_(embeddings) 58 | self.embedding.weight.requires_grad = finetune 59 | 60 | 61 | def model_save(fn): 62 | if args.philly: 63 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 64 | with open(fn, 'wb') as f: 65 | # torch.save([model, optimizer], f) 66 | torch.save({ 67 | 'epoch': epoch, 68 | 'model_state_dict': model.state_dict(), 69 | 'optimizer_state_dict': optimizer.state_dict(), 70 | 'loss': test_loss 71 | }, f) 72 | 73 | 74 | def model_load(fn): 75 | global model, optimizer 76 | if args.philly: 77 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 78 | with open(fn, 'rb') as f: 79 | checkpoint = torch.load(f) 80 | model.load_state_dict(checkpoint['model_state_dict']) 81 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 82 | epoch = checkpoint['epoch'] 83 | test_loss = checkpoint['loss'] 84 | 85 | 86 | ############################################################################### 87 | # Training code 88 | ############################################################################### 89 | 90 | @torch.no_grad() 91 | def evaluate(data_iter): 92 | # Turn on evaluation mode which disables dropout. 93 | model.eval() 94 | 95 | total_loss = 0 96 | total_datapoints = 0 97 | for batch, data in enumerate(data_iter): 98 | batch_data = get_batch(data) 99 | X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch_data 100 | 101 | X_batch = torch.from_numpy(X_batch).long().to('cuda' if args.cuda else 'cpu') 102 | y_batch = torch.from_numpy(y_batch).long().to('cuda' if args.cuda else 'cpu') 103 | 104 | lin_output = model(X_batch) 105 | count = y_batch.shape[0] 106 | total_loss += torch.sum( 107 | torch.argmax(lin_output, dim=1) == y_batch 108 | ).float().data 109 | total_datapoints += count 110 | 111 | return total_loss.item() / total_datapoints 112 | 113 | 114 | def train(): 115 | # Turn on training mode which enables dropout. 116 | model.train() 117 | 118 | total_loss = 0 119 | total_acc = 0 120 | start_time = time.time() 121 | for batch, data in enumerate(training_data_iter): 122 | # print(data) 123 | # batch_data = get_batch(next(training_data_iter)) 124 | data, n_batches = data 125 | batch_data = get_batch(data) 126 | X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch_data 127 | 128 | X_batch = torch.from_numpy(X_batch).long().to('cuda' if args.cuda else 'cpu') 129 | y_batch = torch.from_numpy(y_batch).long().to('cuda' if args.cuda else 'cpu') 130 | 131 | optimizer.zero_grad() 132 | 133 | lin_output = model(X_batch) 134 | loss = model.cost(lin_output, y_batch) 135 | acc = torch.mean( 136 | (torch.argmax(lin_output, dim=1) == y_batch).float()) 137 | loss.backward() 138 | 139 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
140 | if args.clip: 141 | torch.nn.utils.clip_grad_norm_(params, args.clip) 142 | optimizer.step() 143 | 144 | total_loss += loss.detach().data 145 | total_acc += acc.detach().data 146 | if batch % args.log_interval == 0 and batch > 0: 147 | elapsed = time.time() - start_time 148 | print( 149 | '| epoch {:3d} ' 150 | '| {:5d}/ {:5d} batches ' 151 | '| lr {:05.5f} | ms/batch {:5.2f} ' 152 | '| loss {:5.2f} | acc {:0.2f}'.format( 153 | epoch, 154 | batch, 155 | n_batches, 156 | optimizer.param_groups[0]['lr'], 157 | elapsed * 1000 / args.log_interval, 158 | total_loss.item() / args.log_interval, 159 | total_acc.item() / args.log_interval)) 160 | total_loss = 0 161 | total_acc = 0 162 | start_time = time.time() 163 | ### 164 | batch += 1 165 | if batch >= n_batches: 166 | break 167 | 168 | 169 | @torch.no_grad() 170 | def generate_parse(data_iter): 171 | model.eval() 172 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format}) 173 | pred_tree_list = [] 174 | targ_tree_list = [] 175 | crop_count = 0 176 | total_count = 0 177 | for batch, data in enumerate(data_iter): 178 | sents = data['tokens'] 179 | X = np.array([vocabulary[t] for t in data['tokens']]) 180 | # if len(sents) > 100: # In case Evalb fail to process very long sequences 181 | # continue 182 | 183 | X_batch = torch.from_numpy(X).long().to('cuda' if args.cuda else 'cpu') 184 | 185 | model(X_batch[None, :]) 186 | probs = model.encoder.probs 187 | distance = torch.argmax(probs, dim=-1) 188 | distance[0] = args.nslot 189 | 190 | total_count += 1 191 | depth = distance[:, 0] 192 | probs_k = probs[:, 0, :].data.cpu().numpy() 193 | 194 | try: 195 | parse_tree = build_tree(depth, sents) 196 | sen_tree = char2tree(data['sentence'].split()) 197 | except: 198 | crop_count += 1 199 | print('Unbalanced datapoint!') 200 | continue 201 | 202 | pred_tree_list.append(parse_tree) 203 | targ_tree_list.append(sen_tree) 204 | 205 | if batch % 100 > 0: 206 | continue 207 | print(batch) 208 | for i in range(len(sents)): 209 | if sents[i] == '': 210 | break 211 | print('%20s\t%2.2f\t%s' % (sents[i], depth[i], plot(probs_k[i], 1))) 212 | print(parse_tree) 213 | print(sen_tree) 214 | print() 215 | 216 | print('Cropped: %d, Total: %d' % (crop_count, total_count)) 217 | evalb(pred_tree_list, targ_tree_list, evalb_path="../EVALB") 218 | 219 | 220 | if __name__ == "__main__": 221 | parser = argparse.ArgumentParser(description='') 222 | 223 | parser.add_argument('--data', type=str, default='./data/listops', 224 | help='location of the data corpus') 225 | parser.add_argument('--bidirection', action='store_true', 226 | help='use bidirection model') 227 | parser.add_argument('--seq_len', type=int, default=100, 228 | help='max sequence length') 229 | parser.add_argument('--seq_len_test', type=int, default=1000, 230 | help='max sequence length') 231 | parser.add_argument('--no-smart-batching', action='store_true', # reverse 232 | help='batch based on length') 233 | parser.add_argument('--no-use_peano', action='store_true', 234 | help='batch based on length') 235 | parser.add_argument('--emsize', type=int, default=128, 236 | help='size of word embeddings') 237 | parser.add_argument('--nhid', type=int, default=128, 238 | help='number of hidden units per layer') 239 | parser.add_argument('--nslot', type=int, default=21, 240 | help='number of memory slots') 241 | parser.add_argument('--lr', type=float, default=0.001, 242 | help='initial learning rate') 243 | parser.add_argument('--clip', type=float, default=1., 244 | 
help='gradient clipping') 245 | parser.add_argument('--epochs', type=int, default=50, 246 | help='upper epoch limit') 247 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 248 | help='batch size') 249 | parser.add_argument('--batch_size_test', type=int, default=128, metavar='N', 250 | help='batch size') 251 | parser.add_argument('--dropout', type=float, default=0.1, 252 | help='dropout applied to layers (0 = no dropout)') 253 | parser.add_argument('--dropoutm', type=float, default=0.3, 254 | help='dropout applied to memory (0 = no dropout)') 255 | parser.add_argument('--dropouti', type=float, default=0.1, 256 | help='dropout for input embedding layers (0 = no dropout)') 257 | parser.add_argument('--dropouto', type=float, default=0.2, 258 | help='dropout applied to layers (0 = no dropout)') 259 | parser.add_argument('--seed', type=int, default=1111, 260 | help='random seed') 261 | parser.add_argument('--cuda', action='store_true', 262 | help='use CUDA') 263 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 264 | help='report interval') 265 | parser.add_argument('--test-only', action='store_true', 266 | help='Test only') 267 | 268 | randomhash = ''.join(str(time.time()).split('.')) 269 | parser.add_argument('--name', type=str, default=randomhash + '.pt', 270 | help='exp name') 271 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 272 | help='weight decay applied to all weights') 273 | parser.add_argument('--std', action='store_true', 274 | help='use standard LSTM') 275 | parser.add_argument('--philly', action='store_true', 276 | help='Use philly cluster') 277 | args = parser.parse_args() 278 | 279 | args.smart_batching = not args.no_smart_batching 280 | args.use_peano = not args.no_use_peano 281 | 282 | # Set the random seed manually for reproducibility. 
283 | np.random.seed(args.seed) 284 | torch.manual_seed(args.seed) 285 | if torch.cuda.is_available(): 286 | if not args.cuda: 287 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 288 | else: 289 | torch.cuda.manual_seed(args.seed) 290 | 291 | ############################################################################### 292 | # Load data 293 | ############################################################################### 294 | train_data_path = os.path.join(args.data, 'train_d20s.tsv') 295 | test_data_path = os.path.join(args.data, 'test_d20s.tsv') 296 | vocabulary, initial_embeddings, training_data_iter, eval_iterator, training_data_length, raw_eval_data \ 297 | = load_data_and_embeddings(args, train_data_path, test_data_path) 298 | dictionary = {} 299 | for k, v in vocabulary.items(): 300 | dictionary[v] = k 301 | # make iterator for splits 302 | vocab_size = len(vocabulary) 303 | num_classes = len(set(LABEL_MAP.values())) 304 | args.__dict__.update({'ntoken': vocab_size, 305 | 'ninp': args.emsize, 306 | 'nout': num_classes, 307 | 'padding_idx': vocabulary[PADDING_TOKEN]}) 308 | 309 | model = ListOpsModel(args) 310 | 311 | if args.cuda: 312 | model = model.cuda() 313 | 314 | params = list(model.parameters()) 315 | total_params = sum(x.size()[0] * x.size()[1] 316 | if len(x.size()) > 1 else x.size()[0] 317 | for x in params if x.size()) 318 | total_params_sanity = sum(np.prod(x.size()) for x in model.parameters()) 319 | assert total_params == total_params_sanity 320 | print("TOTAL PARAMS: %d" % sum(np.prod(x.size()) for x in model.parameters())) 321 | print('Args:', args) 322 | print('Model total parameters:', total_params) 323 | 324 | # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax) 325 | optimizer = torch.optim.Adam(params, 326 | lr=args.lr, 327 | betas=(0, 0.999), 328 | eps=1e-9, 329 | weight_decay=args.wdecay) 330 | 331 | if not args.test_only: 332 | # Loop over epochs. 333 | lr = args.lr 334 | stored_loss = 0. 335 | 336 | # At any point you can hit Ctrl + C to break out of training early. 
337 | try: 338 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0) 339 | for epoch in range(1, args.epochs + 1): 340 | epoch_start_time = time.time() 341 | train() 342 | test_loss = evaluate(eval_iterator) 343 | 344 | print('-' * 89) 345 | print( 346 | '| end of epoch {:3d} ' 347 | '| time: {:5.2f}s ' 348 | '| test acc: {:.4f} ' 349 | '|\n'.format( 350 | epoch, 351 | (time.time() - epoch_start_time), 352 | test_loss 353 | ) 354 | ) 355 | 356 | if test_loss > stored_loss: 357 | model_save(args.name) 358 | print('Saving model (new best validation)') 359 | stored_loss = test_loss 360 | print('-' * 89) 361 | 362 | scheduler.step(test_loss) 363 | except KeyboardInterrupt: 364 | print('-' * 89) 365 | print('Exiting from training early') 366 | 367 | model_load(args.name) 368 | generate_parse(raw_eval_data) 369 | test_loss = evaluate(eval_iterator) 370 | data = {'args': args.__dict__, 371 | 'parameters': total_params, 372 | 'test_acc': test_loss} 373 | print('-' * 89) 374 | print( 375 | '| test acc: {:.4f} ' 376 | '|\n'.format( 377 | test_loss 378 | ) 379 | ) 380 | -------------------------------------------------------------------------------- /sentiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim.lr_scheduler as lr_scheduler 11 | 12 | import ordered_memory 13 | from utils.hinton import plot 14 | from utils.locked_dropout import LockedDropout 15 | from utils.utils import build_tree 16 | 17 | 18 | class SSTClassifier(nn.Module): 19 | def __init__(self, args, elmo=None, glove=None): 20 | super(SSTClassifier, self).__init__() 21 | 22 | self.args = args 23 | self.padding_idx = args.padding_idx 24 | 25 | ninp = args.emsize 26 | if ninp > 0: 27 | self.embedding = nn.Embedding( 28 | args.ntoken, ninp, 29 | padding_idx=self.padding_idx, 30 | ) 31 | else: 32 | self.embedding = None 33 | 34 | self.elmo = elmo 35 | if elmo is not None: 36 | ninp += 1024 37 | 38 | self.glove = glove 39 | if glove is not None: 40 | ninp += 300 41 | 42 | self.lockdrop = LockedDropout(dropout=args.dropouti) 43 | 44 | self.encoder = ordered_memory.OrderedMemory(ninp, args.nhid, args.nslot, 45 | dropout=args.dropout, dropoutm=args.dropoutm, 46 | bidirection=args.bidirection) 47 | 48 | self.mlp = nn.Sequential( 49 | nn.Dropout(args.dropouto), 50 | nn.Linear(args.nhid, args.nhid), 51 | nn.ReLU(), 52 | nn.Dropout(args.dropouto), 53 | nn.Linear(args.nhid, args.nout), 54 | ) 55 | 56 | self.drop_input = nn.Dropout(args.dropouti) 57 | self.cost = nn.CrossEntropyLoss() 58 | 59 | def forward(self, input): 60 | if self.elmo is not None: 61 | input_elmo, input_torchtext = input 62 | else: 63 | input_torchtext = input 64 | mask = (input_torchtext != self.padding_idx) 65 | 66 | emb_list = [] 67 | if self.embedding is not None: 68 | emb_torchtext = self.embedding(input_torchtext) 69 | emb_list.append(emb_torchtext) 70 | if self.glove is not None: 71 | emb_glove = self.glove(input_torchtext).detach() 72 | emb_list.append(emb_glove) 73 | if self.elmo is not None: 74 | emb_elmo = self.elmo(input_elmo) 75 | assert (mask.long() == emb_elmo['mask']).all() 76 | emb_elmo = emb_elmo['elmo_representations'][0] 77 | emb_list.append(emb_elmo) 78 | emb = torch.cat(emb_list, dim=-1) 79 | 80 | emb.transpose_(0, 1) 81 | mask.transpose_(0, 1) 82 | emb = self.lockdrop(emb) 83 | 84 | output = self.encoder(emb, 
mask) 85 | 86 | output = self.mlp(output) 87 | 88 | return output 89 | 90 | @staticmethod 91 | def load_model(input_path): 92 | state = torch.load(input_path) 93 | print('Loading model from %s' % input_path) 94 | model = SSTClassifier(state['args']) 95 | model.load_state_dict(state['state_dict']) 96 | return model 97 | 98 | def save(self, output_path): 99 | state = dict(args=self.args, 100 | state_dict=self.state_dict()) 101 | torch.save(state, output_path) 102 | 103 | def set_pretrained_embeddings(self, ext_embeddings, ext_word_to_index, word_to_index, finetune=False): 104 | assert hasattr(self, 'embedding') 105 | embeddings = self.embedding.weight.data.cpu().numpy() 106 | for word, index in word_to_index.items(): 107 | if word in ext_word_to_index: 108 | embeddings[index] = ext_embeddings[ext_word_to_index[word]] 109 | embeddings = torch.from_numpy(embeddings).to(self.embedding.weight.device) 110 | self.embedding.weight.data.set_(embeddings) 111 | self.embedding.weight.requires_grad = finetune 112 | 113 | 114 | def model_save(fn): 115 | if args.philly: 116 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 117 | with open(fn, 'wb') as f: 118 | # torch.save([model, optimizer], f) 119 | torch.save({ 120 | 'epoch': epoch, 121 | 'model_state_dict': model.state_dict(), 122 | 'optimizer_state_dict': optimizer.state_dict(), 123 | 'loss': val_loss 124 | }, f) 125 | 126 | 127 | def model_load(fn): 128 | global model, optimizer 129 | if args.philly: 130 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 131 | with open(fn, 'rb') as f: 132 | checkpoint = torch.load(f) 133 | model.load_state_dict(checkpoint['model_state_dict']) 134 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 135 | epoch = checkpoint['epoch'] 136 | val_loss = checkpoint['loss'] 137 | 138 | 139 | ############################################################################### 140 | # Training code 141 | ############################################################################### 142 | 143 | 144 | def evaluate(data_iter): 145 | # Turn on evaluation mode which disables dropout. 146 | model.eval() 147 | total_loss = 0 148 | total_datapoints = 0 149 | for batch, data in enumerate(data_iter): 150 | sents = data.text 151 | lbls = data.label 152 | count = lbls.shape[0] 153 | lin_output = model(sents) 154 | total_loss += torch.sum( 155 | torch.argmax(lin_output, dim=1) == lbls 156 | ).float().data 157 | total_datapoints += count 158 | 159 | return total_loss.item() / total_datapoints 160 | 161 | 162 | def train(): 163 | # Turn on training mode which enables dropout. 164 | total_loss = 0 165 | total_acc = 0 166 | start_time = time.time() 167 | for batch, data in enumerate(train_iter): 168 | sents = data.text 169 | lbls = data.label 170 | 171 | model.train() 172 | optimizer.zero_grad() 173 | 174 | lin_output = model(sents) 175 | loss = model.cost(lin_output, lbls) 176 | acc = torch.mean( 177 | (torch.argmax(lin_output, dim=1) == lbls).float()) 178 | loss.backward() 179 | 180 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
181 | if args.clip: 182 | torch.nn.utils.clip_grad_norm_(params, args.clip) 183 | optimizer.step() 184 | 185 | total_loss += loss.detach().data 186 | total_acc += acc.detach().data 187 | if batch % args.log_interval == 0 and batch > 0: 188 | elapsed = time.time() - start_time 189 | print( 190 | '| epoch {:3d} ' 191 | '| {:5d}/{:5d} batches ' 192 | '| lr {:05.5f} | ms/batch {:5.2f} ' 193 | '| loss {:5.2f} | acc {:0.2f}'.format( 194 | epoch, 195 | batch, len(train_iter), 196 | optimizer.param_groups[0]['lr'], 197 | elapsed * 1000 / args.log_interval, 198 | total_loss.item() / args.log_interval, 199 | total_acc.item() / args.log_interval)) 200 | total_loss = 0 201 | total_acc = 0 202 | start_time = time.time() 203 | ### 204 | batch += 1 205 | 206 | 207 | def generate_parse(): 208 | from nltk import Tree 209 | from utils import evalb 210 | 211 | batch = [] 212 | pred_tree_list = [] 213 | targ_tree_list = [] 214 | 215 | def process_batch(): 216 | nonlocal batch, pred_tree_list, targ_tree_list 217 | idx = TEXT.process([example['sents'] for example in batch], device=hidden[0].device) 218 | 219 | model(idx) 220 | 221 | probs = model.encoder.probs 222 | distance = torch.argmax(probs, dim=-1) 223 | distance[0] = args.nslot 224 | probs = probs.data.cpu().numpy() 225 | 226 | for i, example in enumerate(batch): 227 | sents = example['sents'] 228 | sents_tree = example['sents_tree'] 229 | depth = distance[:, i] 230 | 231 | parse_tree = build_tree(depth, sents) 232 | 233 | if len(sents) <= 100: 234 | pred_tree_list.append(parse_tree) 235 | targ_tree_list.append(sents_tree) 236 | 237 | if i == 0: 238 | for j in range(len(sents)): 239 | print('%20s\t%2.2f\t%s' % (sents[j], depth[j], plot(probs[j, i], 1.))) 240 | print(parse_tree) 241 | print(sents_tree) 242 | print('-' * 80) 243 | 244 | batch = [] 245 | 246 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format}) 247 | 248 | model.eval() 249 | hidden = model.encoder.init_hidden(1) 250 | 251 | fin = open('.data/sst/trees/dev.txt', 'r') 252 | for line in fin: 253 | line = line.lower() 254 | sents_tree = Tree.fromstring(line) 255 | sents = sents_tree.leaves() 256 | batch.append({'sents_tree': sents_tree, 'sents': sents}) 257 | 258 | if len(batch) == 16: 259 | process_batch() 260 | 261 | if len(batch) > 0: 262 | process_batch() 263 | 264 | evalb(pred_tree_list, targ_tree_list, evalb_path='./EVALB') 265 | 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser(description='') 269 | 270 | parser.add_argument('--fine-grained', action='store_true', 271 | help='use fine grained label') 272 | parser.add_argument('--subtrees', action='store_true', 273 | help='use fine subtrees') 274 | parser.add_argument('--glove', action='store_true', 275 | help='use pretrained glove embedding') 276 | parser.add_argument('--elmo', action='store_true', 277 | help='use pretrained elmo') 278 | parser.add_argument('--bidirection', action='store_true', 279 | help='use bidirection model') 280 | parser.add_argument('--emsize', type=int, default=0, 281 | help='size of word embeddings') 282 | parser.add_argument('--nhid', type=int, default=300, 283 | help='number of hidden units per layer') 284 | parser.add_argument('--nslot', type=int, default=15, 285 | help='number of memory slots') 286 | parser.add_argument('--lr', type=float, default=0.001, 287 | help='initial learning rate') 288 | parser.add_argument('--clip', type=float, default=1., 289 | help='gradient clipping') 290 | parser.add_argument('--epochs', type=int, 
default=50, 291 | help='upper epoch limit') 292 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 293 | help='batch size') 294 | parser.add_argument('--dropout', type=float, default=0.2, 295 | help='dropout applied to layers (0 = no dropout)') 296 | parser.add_argument('--dropouti', type=float, default=0.3, 297 | help='dropout for input embedding layers (0 = no dropout)') 298 | parser.add_argument('--dropouto', type=float, default=0.4, 299 | help='dropout applied to layers (0 = no dropout)') 300 | parser.add_argument('--dropoutm', type=float, default=0.2, 301 | help='dropout applied to memory (0 = no dropout)') 302 | parser.add_argument('--attention', type=str, default='softmax', 303 | help='attention method') 304 | parser.add_argument('--seed', type=int, default=1111, 305 | help='random seed') 306 | parser.add_argument('--cuda', action='store_true', 307 | help='use CUDA') 308 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 309 | help='report interval') 310 | parser.add_argument('--test-only', action='store_true', 311 | help='Test only') 312 | 313 | randomhash = ''.join(str(time.time()).split('.')) 314 | parser.add_argument('--name', type=str, default=randomhash + '.pt', 315 | help='exp name') 316 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 317 | help='weight decay applied to all weights') 318 | parser.add_argument('--std', action='store_true', 319 | help='use standard LSTM') 320 | parser.add_argument('--philly', action='store_true', 321 | help='Use philly cluster') 322 | parser.add_argument('--resume', action='store_true', 323 | help='resume from checkpoint') 324 | args = parser.parse_args() 325 | 326 | # Set the random seed manually for reproducibility. 327 | np.random.seed(args.seed) 328 | torch.manual_seed(args.seed) 329 | if torch.cuda.is_available(): 330 | if not args.cuda: 331 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 332 | else: 333 | torch.cuda.manual_seed(args.seed) 334 | 335 | ############################################################################### 336 | # Load data 337 | ############################################################################### 338 | from torchtext import data 339 | from torchtext import datasets 340 | from torchtext.vocab import GloVe 341 | 342 | # set up fields 343 | TEXT = data.Field(lower=True, include_lengths=False, batch_first=True) 344 | LABEL = data.Field(sequential=False, unk_token=None) 345 | 346 | # make splits for data 347 | filter_pred = None 348 | if not args.fine_grained: 349 | filter_pred = lambda ex: ex.label != 'neutral' 350 | train_set, dev_set, test_set = datasets.SST.splits( 351 | TEXT, LABEL, 352 | train_subtrees=args.subtrees, 353 | fine_grained=args.fine_grained, 354 | filter_pred=filter_pred 355 | ) 356 | 357 | # build the vocabulary 358 | if args.glove: 359 | TEXT.build_vocab(train_set, dev_set, test_set, min_freq=1, vectors=GloVe(name='840B', dim=300)) 360 | else: 361 | TEXT.build_vocab(train_set, min_freq=2) 362 | LABEL.build_vocab(train_set) 363 | 364 | # make iterator for splits 365 | train_iter, dev_iter, test_iter = data.BucketIterator.splits( 366 | (train_set, dev_set, test_set), 367 | batch_size=args.batch_size, 368 | device='cuda' if args.cuda else 'cpu' 369 | ) 370 | 371 | args.__dict__.update({'ntoken': len(TEXT.vocab), 372 | 'nout': len(LABEL.vocab), 373 | 'padding_idx': TEXT.vocab.stoi['']}) 374 | 375 | if args.elmo: 376 | from allennlp.modules.elmo import Elmo, batch_to_ids 377 | 378 | options_file = 
"https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json" 379 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" 380 | 381 | elmo = Elmo(options_file, weight_file, 1, requires_grad=False, dropout=0) 382 | 383 | torchtext_process = TEXT.process 384 | 385 | 386 | def elmo_process(batch, device): 387 | elmo_tensor = batch_to_ids(batch) 388 | elmo_tensor = elmo_tensor.to(device=device) 389 | torchtext_tensor = torchtext_process(batch, device) 390 | return (elmo_tensor, torchtext_tensor) 391 | 392 | 393 | TEXT.process = elmo_process 394 | else: 395 | elmo = None 396 | 397 | if args.glove: 398 | glove = torch.nn.Embedding(args.ntoken, 300, _weight=TEXT.vocab.vectors) 399 | else: 400 | glove = None 401 | 402 | model = SSTClassifier(args, elmo=elmo, glove=glove) 403 | 404 | if args.resume: 405 | model_load(args.name) 406 | 407 | if args.cuda: 408 | model = model.cuda() 409 | 410 | params = list(model.parameters()) 411 | total_params = sum(x.size()[0] * x.size()[1] 412 | if len(x.size()) > 1 else x.size()[0] 413 | for x in params if x.size()) 414 | print('Args:', args) 415 | print('Model total parameters:', total_params) 416 | 417 | optimizer = torch.optim.Adam(params, 418 | lr=args.lr, 419 | betas=(0, 0.999), 420 | eps=1e-9, 421 | weight_decay=args.wdecay) 422 | 423 | if not args.test_only: 424 | # Loop over epochs. 425 | lr = args.lr 426 | stored_loss = 0. 427 | 428 | # At any point you can hit Ctrl + C to break out of training early. 429 | try: 430 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0) 431 | for epoch in range(1, args.epochs + 1): 432 | epoch_start_time = time.time() 433 | train() 434 | val_loss = evaluate(dev_iter) 435 | test_loss = evaluate(test_iter) 436 | 437 | print('-' * 89) 438 | print( 439 | '| end of epoch {:3d} ' 440 | '| time: {:5.2f}s ' 441 | '| valid acc: {:.4f} ' 442 | '| test acc: {:.4f} ' 443 | '|\n'.format( 444 | epoch, 445 | (time.time() - epoch_start_time), 446 | val_loss, 447 | test_loss 448 | ) 449 | ) 450 | 451 | if val_loss > stored_loss: 452 | model_save(args.name) 453 | print('Saving model (new best validation)') 454 | stored_loss = val_loss 455 | print('-' * 89) 456 | 457 | scheduler.step(val_loss) 458 | 459 | except KeyboardInterrupt: 460 | print('-' * 89) 461 | print('Exiting from training early') 462 | 463 | model_load(args.name) 464 | test_loss = evaluate(test_iter) 465 | val_loss = evaluate(dev_iter) 466 | 467 | try: 468 | generate_parse() 469 | except: 470 | print('Unable to parse') 471 | 472 | print('-' * 89) 473 | print( 474 | '| valid acc: {:.4f} ' 475 | '| test acc: {:.4f} ' 476 | '|\n'.format( 477 | val_loss, 478 | test_loss 479 | ) 480 | ) 481 | -------------------------------------------------------------------------------- /utils/listops_data.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | import itertools 4 | import time 5 | import sys 6 | 7 | import numpy as np 8 | 9 | from utils.utils import ConvertBinaryBracketedSeq 10 | NUMBERS = list(range(10)) 11 | PADDING_TOKEN = "_PAD" 12 | UNK_TOKEN = "_" 13 | SENTENCE_PADDING_SYMBOL = 0 14 | 15 | FIXED_VOCABULARY = {str(x): i + 1 for i, x in enumerate(NUMBERS)} 16 | FIXED_VOCABULARY.update({ 17 | PADDING_TOKEN: 0, 18 | "[MIN": len(FIXED_VOCABULARY) + 1, 19 | "[MAX": len(FIXED_VOCABULARY) + 
2, 20 | "[FIRST": len(FIXED_VOCABULARY) + 3, 21 | "[LAST": len(FIXED_VOCABULARY) + 4, 22 | "[MED": len(FIXED_VOCABULARY) + 5, 23 | "[SM": len(FIXED_VOCABULARY) + 6, 24 | "[PM": len(FIXED_VOCABULARY) + 7, 25 | "[FLSUM": len(FIXED_VOCABULARY) + 8, 26 | "]": len(FIXED_VOCABULARY) + 9 27 | }) 28 | assert len(set(FIXED_VOCABULARY.values())) == len(list(FIXED_VOCABULARY.values())) 29 | 30 | 31 | SENTENCE_PAIR_DATA = False 32 | OUTPUTS = list(range(10)) 33 | LABEL_MAP = {str(x): i for i, x in enumerate(OUTPUTS)} 34 | 35 | Node = namedtuple('Node', 'tag span') 36 | 37 | 38 | def spans(transitions, tokens=None): 39 | n = (len(transitions) + 1) // 2 40 | stack = [] 41 | buf = [Node("leaf", (l, r)) for l, r in zip(list(range(n)), list(range(1, n + 1)))] 42 | buf = list(reversed(buf)) 43 | 44 | nodes = [] 45 | reduced = [False] * n 46 | 47 | def SHIFT(item): 48 | nodes.append(item) 49 | return item 50 | 51 | def REDUCE(l, r): 52 | tag = None 53 | i = r.span[1] - 1 54 | if tokens is not None and tokens[i] == ']' and not reduced[i]: 55 | reduced[i] = True 56 | tag = "struct" 57 | new_stack_item = Node(tag=tag, span=(l.span[0], r.span[1])) 58 | nodes.append(new_stack_item) 59 | return new_stack_item 60 | 61 | for t in transitions: 62 | if t == 0: 63 | stack.append(SHIFT(buf.pop())) 64 | elif t == 1: 65 | r, l = stack.pop(), stack.pop() 66 | stack.append(REDUCE(l, r)) 67 | 68 | return nodes 69 | 70 | def PreprocessDataset( 71 | dataset, 72 | vocabulary, 73 | seq_length, 74 | eval_mode=False, 75 | sentence_pair_data=False, 76 | simple=True, 77 | allow_cropping=False, 78 | pad_from_left=True): 79 | dataset = TrimDataset( 80 | dataset, 81 | seq_length, 82 | eval_mode=eval_mode, 83 | sentence_pair_data=sentence_pair_data, 84 | logger=None, 85 | allow_cropping=allow_cropping) 86 | dataset = TokensToIDs( 87 | vocabulary, 88 | dataset, 89 | sentence_pair_data=sentence_pair_data) 90 | 91 | dataset = CropAndPadSimple( 92 | dataset, 93 | seq_length, 94 | logger=None, 95 | sentence_pair_data=sentence_pair_data, 96 | allow_cropping=allow_cropping, 97 | pad_from_left=pad_from_left) 98 | 99 | if sentence_pair_data: 100 | X = np.transpose(np.array([[example["premise_tokens"] for example in dataset], 101 | [example["hypothesis_tokens"] for example in dataset]], 102 | dtype=np.int32), (1, 2, 0)) 103 | if simple: 104 | transitions = np.zeros((len(dataset), 2, 0)) 105 | num_transitions = np.transpose(np.array( 106 | [[len(np.array(example["premise_tokens"]).nonzero()[0]) for example in dataset], 107 | [len(np.array(example["hypothesis_tokens"]).nonzero()[0]) for example in dataset]], 108 | dtype=np.int32), (1, 0)) 109 | else: 110 | transitions = np.transpose(np.array([[example["premise_transitions"] for example in dataset], 111 | [example["hypothesis_transitions"] for example in dataset]], 112 | dtype=np.int32), (1, 2, 0)) 113 | num_transitions = np.transpose(np.array( 114 | [[example["num_premise_transitions"] for example in dataset], 115 | [example["num_hypothesis_transitions"] for example in dataset]], 116 | dtype=np.int32), (1, 0)) 117 | else: 118 | X = np.array([example["tokens"] for example in dataset], 119 | dtype=np.int32) 120 | if simple: 121 | transitions = np.zeros((len(dataset), 0)) 122 | num_transitions = np.array( 123 | [len(np.array(example["tokens"]).nonzero()[0]) for example in dataset], 124 | dtype=np.int32) 125 | else: 126 | transitions = np.array([example["transitions"] 127 | for example in dataset], dtype=np.int32) 128 | num_transitions = np.array( 129 | [example["num_transitions"] for example in 
dataset], 130 | dtype=np.int32) 131 | 132 | y = np.array( 133 | [LABEL_MAP[example["label"]] for example in dataset], 134 | dtype=np.int32) 135 | 136 | # NP Array of Strings 137 | example_ids = np.array([example["example_id"] for example in dataset]) 138 | 139 | return X, transitions, y, num_transitions, example_ids 140 | 141 | 142 | def load_data(path, lowercase=None, choose=lambda x: True, eval_mode=False): 143 | examples = [] 144 | with open(path) as f: 145 | for example_id, line in enumerate(f): 146 | line = line.strip() 147 | label, seq = line.split('\t') 148 | if len(seq) <= 1: 149 | continue 150 | 151 | tokens, transitions = ConvertBinaryBracketedSeq( 152 | seq.split(' ')) 153 | 154 | example = {} 155 | example["label"] = label 156 | example["sentence"] = seq 157 | example["tokens"] = tokens 158 | example["transitions"] = transitions 159 | example["example_id"] = str(example_id) 160 | 161 | examples.append(example) 162 | return examples 163 | 164 | 165 | def load_data_and_embeddings(args, training_data_path, eval_data_path): 166 | raw_training_data = load_data(training_data_path, None, eval_mode=False) 167 | raw_eval_data = load_data(eval_data_path, None, eval_mode=True) 168 | import copy 169 | raw_eval_data_copy = copy.deepcopy(raw_eval_data) 170 | # Prepare the vocabulary 171 | vocabulary = FIXED_VOCABULARY 172 | print("In fixed vocabulary mode. Training embeddings from scratch.") 173 | initial_embeddings = None 174 | # Trim dataset, convert token sequences to integer sequences, crop, and 175 | # pad. 176 | print("Preprocessing training data.") 177 | training_data = PreprocessDataset( 178 | raw_training_data, 179 | vocabulary, 180 | args.seq_len, #def to 100 181 | eval_mode=False, 182 | sentence_pair_data=SENTENCE_PAIR_DATA, 183 | simple=True, 184 | allow_cropping=False, 185 | pad_from_left=True) 186 | training_data_iter = MakeTrainingIterator(training_data, args.batch_size, args.smart_batching, args.use_peano, sentence_pair_data=SENTENCE_PAIR_DATA) 187 | training_data_length = len(training_data[0]) 188 | # Preprocess eval sets. 189 | eval_data = PreprocessDataset( 190 | raw_eval_data, 191 | vocabulary, 192 | args.seq_len_test, 193 | eval_mode=True, 194 | sentence_pair_data=SENTENCE_PAIR_DATA, 195 | simple=True, #for RNNs and shit 196 | allow_cropping=True, 197 | pad_from_left=True) 198 | eval_it = MakeEvalIterator(eval_data, args.batch_size_test, None, 199 | bucket_eval=True, 200 | shuffle=False) 201 | 202 | return vocabulary, initial_embeddings, training_data_iter, eval_it, training_data_length, raw_eval_data_copy 203 | 204 | 205 | def MakeTrainingIterator( 206 | sources, 207 | batch_size, 208 | smart_batches=True, 209 | use_peano=True, 210 | sentence_pair_data=True, 211 | pad_from_left=True): 212 | # Make an iterator that exposes a dataset as random minibatches. 213 | 214 | def get_key(num_transitions): 215 | if use_peano and sentence_pair_data: 216 | prem_len, hyp_len = num_transitions 217 | key = Peano(prem_len, hyp_len) 218 | return key 219 | else: 220 | if not isinstance(num_transitions, list): 221 | num_transitions = [num_transitions] 222 | return max(num_transitions) 223 | 224 | def build_batches(): 225 | dataset_size = len(sources[0]) 226 | order = list(range(dataset_size)) 227 | random.shuffle(order) 228 | order = np.array(order) 229 | 230 | num_splits = 10 # TODO: Should we be smarter about split size? 
231 | order_limit = len(order) // num_splits * num_splits 232 | order = order[:order_limit] 233 | order_splits = np.split(order, num_splits) 234 | batches = [] 235 | 236 | for split in order_splits: 237 | # Put indices into buckets based on example length. 238 | keys = [] 239 | for i in split: 240 | num_transitions = sources[3][i] 241 | key = get_key(num_transitions) 242 | keys.append((i, key)) 243 | keys = sorted(keys, key=lambda __key: __key[1]) 244 | 245 | # Group indices from buckets into batches, so that 246 | # examples in each batch have similar length. 247 | batch = [] 248 | for i, _ in keys: 249 | batch.append(i) 250 | if len(batch) == batch_size: 251 | batches.append(batch) 252 | batch = [] 253 | return batches 254 | 255 | def batch_iter(): 256 | batches = build_batches() 257 | num_batches = len(batches) 258 | idx = -1 259 | order = list(range(num_batches)) 260 | random.shuffle(order) 261 | 262 | while True: 263 | idx += 1 264 | if idx >= num_batches: 265 | # Start another epoch. 266 | batches = build_batches() 267 | num_batches = len(batches) 268 | idx = 0 269 | order = list(range(num_batches)) 270 | random.shuffle(order) 271 | batch_indices = batches[order[idx]] 272 | yield tuple(source[batch_indices] for source in sources), num_batches 273 | 274 | def data_iter(): 275 | dataset_size = len(sources[0]) 276 | start = -1 * batch_size 277 | order = list(range(dataset_size)) 278 | random.shuffle(order) 279 | 280 | while True: 281 | start += batch_size 282 | if start > dataset_size - batch_size: 283 | # Start another epoch. 284 | start = 0 285 | random.shuffle(order) 286 | batch_indices = order[start:start + batch_size] 287 | yield tuple(source[batch_indices] for source in sources) 288 | 289 | train_iter = batch_iter if smart_batches else data_iter 290 | 291 | return train_iter() 292 | 293 | def MakeBucketEvalIterator(sources, batch_size): 294 | # Order in eval should not matter. Use batches sorted by length for speed 295 | # improvement. 296 | 297 | def single_sentence_key(num_transitions): 298 | return num_transitions 299 | 300 | def sentence_pair_key(num_transitions): 301 | sent1_len, sent2_len = num_transitions 302 | return Peano(sent1_len, sent2_len) 303 | 304 | dataset_size = len(sources[0]) 305 | 306 | # Sort examples by length. From longest to shortest. 307 | num_transitions = sources[3] 308 | sort_key = sentence_pair_key if len( 309 | num_transitions.shape) == 2 else single_sentence_key 310 | order = sorted(zip(list(range(dataset_size)), num_transitions), 311 | key=lambda x: sort_key(x[1])) 312 | order = list(reversed(order)) 313 | order = [x[0] for x in order] 314 | 315 | num_batches = dataset_size // batch_size 316 | batches = [] 317 | 318 | # Roll examples into batches so they have similar length. 
319 | for i in range(num_batches): 320 | batch_indices = order[i * batch_size:(i + 1) * batch_size] 321 | batch = tuple(source[batch_indices] for source in sources) 322 | batches.append(batch) 323 | 324 | examples_leftover = dataset_size - num_batches * batch_size 325 | 326 | # Create a short batch: 327 | if examples_leftover > 0: 328 | batch_indices = order[num_batches * 329 | batch_size:num_batches * 330 | batch_size + 331 | examples_leftover] 332 | batch = tuple(source[batch_indices] for source in sources) 333 | batches.append(batch) 334 | 335 | return batches 336 | 337 | 338 | def MakeEvalIterator(sources, batch_size, limit=None, shuffle=False, rseed=123, bucket_eval=False): 339 | return MakeBucketEvalIterator(sources, batch_size)[:limit] 340 | 341 | def TrimDataset(dataset, seq_length, eval_mode=False, 342 | sentence_pair_data=False, logger=None, allow_cropping=False): 343 | """Avoid using excessively long training examples.""" 344 | 345 | if sentence_pair_data: 346 | trimmed_dataset = [ 347 | example for example in dataset if len( 348 | example["premise_transitions"]) <= seq_length and len( 349 | example["hypothesis_transitions"]) <= seq_length] 350 | else: 351 | trimmed_dataset = [example for example in dataset if 352 | len(example["transitions"]) <= seq_length] 353 | 354 | diff = len(dataset) - len(trimmed_dataset) 355 | if eval_mode: 356 | assert allow_cropping or diff == 0, "allow_eval_cropping is false but there are over-length eval examples." 357 | if logger and diff > 0: 358 | logger.Log( 359 | "Warning: Cropping " + 360 | str(diff) + 361 | " over-length eval examples.") 362 | return dataset 363 | else: 364 | if allow_cropping: 365 | if logger and diff > 0: 366 | logger.Log( 367 | "Cropping " + 368 | str(diff) + 369 | " over-length training examples.") 370 | return dataset 371 | else: 372 | if logger and diff > 0: 373 | logger.Log( 374 | "Discarding " + 375 | str(diff) + 376 | " over-length training examples.") 377 | return trimmed_dataset 378 | 379 | def TokensToIDs(vocabulary, dataset, sentence_pair_data=False): 380 | """Replace strings in original boolean dataset with token IDs.""" 381 | if sentence_pair_data: 382 | keys = ["premise_tokens", "hypothesis_tokens"] 383 | else: 384 | keys = ["tokens"] 385 | 386 | tokens = 0 387 | unks = 0 388 | lowers = 0 389 | raises = 0 390 | 391 | for key in keys: 392 | if UNK_TOKEN in vocabulary: 393 | unk_id = vocabulary[UNK_TOKEN] 394 | for example in dataset: 395 | for i, token in enumerate(example[key]): 396 | if token in vocabulary: 397 | example[key][i] = vocabulary[token] 398 | elif token.lower() in vocabulary: 399 | example[key][i] = vocabulary[token.lower()] 400 | lowers += 1 401 | elif token.upper() in vocabulary: 402 | example[key][i] = vocabulary[token.upper()] 403 | raises += 1 404 | else: 405 | example[key][i] = unk_id 406 | unks += 1 407 | tokens += 1 408 | print("Unk rate {:2.6f}%, downcase rate {:2.6f}%, upcase rate {:2.6f}%".format((unks * 100.0 / tokens), (lowers * 100.0 / tokens), (raises * 100.0 / tokens))) 409 | else: 410 | for example in dataset: 411 | example[key] = [vocabulary[token] 412 | for token in example[key]] 413 | return dataset 414 | 415 | def CropAndPadExample( 416 | example, 417 | padding_amount, 418 | target_length, 419 | key, 420 | symbol=0, 421 | logger=None, 422 | allow_cropping=False, 423 | pad_from_left=True): 424 | """ 425 | Crop/pad a sequence value of the given dict `example`. 
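
    Note: `padding_amount` is expected to equal
    `target_length - len(example[key])` as computed by the caller; a negative
    value means the sequence is longer than the target, so it is cropped
    first (when `allow_cropping` is True) and then padded. `pad_from_left`
    selects which side receives the padding symbols.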
426 | """ 427 | if padding_amount < 0: 428 | if not allow_cropping: 429 | raise NotImplementedError( 430 | "Cropping not allowed. " 431 | "Please set seq_length and eval_seq_length to some sufficiently large value or (for non-SPINN models) use --allow_cropping and --allow_eval_cropping..") 432 | # Crop, then pad normally. 433 | if pad_from_left: 434 | example[key] = example[key][-padding_amount:] 435 | else: 436 | example[key] = example[key][:padding_amount] 437 | padding_amount = 0 438 | alternate_side_padding = target_length - \ 439 | (padding_amount + len(example[key])) 440 | if pad_from_left: 441 | example[key] = ([symbol] * padding_amount) + \ 442 | example[key] + ([symbol] * alternate_side_padding) 443 | else: 444 | example[key] = ([symbol] * alternate_side_padding) + \ 445 | example[key] + ([symbol] * padding_amount) 446 | 447 | 448 | def CropAndPadSimple( 449 | dataset, 450 | length, 451 | logger=None, 452 | sentence_pair_data=False, 453 | allow_cropping=True, 454 | pad_from_left=True): 455 | # NOTE: This can probably be done faster in NumPy if it winds up making a 456 | # difference. 457 | if sentence_pair_data: 458 | keys = ["premise_tokens", 459 | "hypothesis_tokens"] 460 | else: 461 | keys = ["tokens"] 462 | 463 | for example in dataset: 464 | for tokens_key in keys: 465 | num_tokens = len(example[tokens_key]) 466 | tokens_padding_amount = length - num_tokens 467 | CropAndPadExample( 468 | example, 469 | tokens_padding_amount, 470 | length, 471 | tokens_key, 472 | symbol=SENTENCE_PADDING_SYMBOL, 473 | logger=logger, 474 | allow_cropping=allow_cropping, 475 | pad_from_left=pad_from_left) 476 | return dataset 477 | 478 | def truncate(data, seq_length, max_length, left_padded): 479 | if left_padded: 480 | data = data[:, seq_length - max_length:] 481 | else: 482 | data = data[:, :max_length] 483 | return data 484 | 485 | 486 | def get_batch(batch): 487 | X_batch, transitions_batch, y_batch, num_transitions_batch, example_ids = batch 488 | 489 | # Truncate each batch to max length within the batch. 490 | X_batch_is_left_padded = True 491 | transitions_batch_is_left_padded = True 492 | max_length = np.max(num_transitions_batch) 493 | seq_length = X_batch.shape[1] 494 | 495 | # Truncate batch. 
496 |     X_batch = truncate(X_batch, seq_length, max_length, X_batch_is_left_padded)
497 |     transitions_batch = truncate(transitions_batch, seq_length,
498 |                                  max_length, transitions_batch_is_left_padded)
499 | 
500 |     return X_batch, transitions_batch, y_batch, num_transitions_batch, example_ids
501 | 
--------------------------------------------------------------------------------
/proplog.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import argparse
4 | import os
5 | import time
6 | 
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim.lr_scheduler as lr_scheduler
11 | 
12 | import ordered_memory
13 | from utils.utils import build_tree, evalb, remove_bracket, char2tree
14 | from utils.hinton import plot
15 | 
16 | # from orion.client import report_results
17 | 
18 | 
19 | if __name__ == "__main__":
20 |     parser = argparse.ArgumentParser(description='Ordered Memory model for propositional logic inference')
21 |     parser.add_argument('--data', type=str, default='data/propositionallogic/',
22 |                         help='location of the data corpus')
23 |     parser.add_argument('--max-op', type=int, default=6,
24 |                         help='maximum number of operators')
25 |     parser.add_argument('--emsize', type=int, default=200,
26 |                         help='size of word embeddings')
27 |     parser.add_argument('--nhid', type=int, default=200,
28 |                         help='number of hidden units per layer')
29 |     parser.add_argument('--nslot', type=int, default=12,
30 |                         help='number of memory slots')
31 |     parser.add_argument('--lr', type=float, default=0.001,
32 |                         help='initial learning rate')
33 |     parser.add_argument('--clip', type=float, default=1.,
34 |                         help='gradient clipping')
35 |     parser.add_argument('--epochs', type=int, default=50,
36 |                         help='upper epoch limit')
37 |     parser.add_argument('--batch_size', type=int, default=128, metavar='N',
38 |                         help='batch size')
39 |     parser.add_argument('--dropout', type=float, default=0.2,
40 |                         help='dropout applied to layers (0 = no dropout)')
41 |     parser.add_argument('--dropouti', type=float, default=0.1,
42 |                         help='dropout applied to input embeddings (0 = no dropout)')
43 |     parser.add_argument('--dropouto', type=float, default=0.3,
44 |                         help='dropout applied to the output classifier (0 = no dropout)')
45 |     parser.add_argument('--dropoutm', type=float, default=0.2,
46 |                         help='dropout applied to layers (0 = no dropout)')
47 |     parser.add_argument('--seed', type=int, default=1111,
48 |                         help='random seed')
49 |     parser.add_argument('--cuda', action='store_true',
50 |                         help='use CUDA')
51 |     parser.add_argument('--test-only', action='store_true',
52 |                         help='Test only')
53 | 
54 |     parser.add_argument('--log-interval', type=int, default=200, metavar='N',
55 |                         help='report interval')
56 |     randomhash = ''.join(str(time.time()).split('.'))
57 |     parser.add_argument('--save', type=str, default=randomhash + '.pt',
58 |                         help='path to save the final model')
59 |     parser.add_argument('--wdecay', type=float, default=1.2e-6,
60 |                         help='weight decay applied to all weights')
61 |     parser.add_argument('--std', action='store_true',
62 |                         help='use standard LSTM')
63 |     parser.add_argument('--philly', action='store_true',
64 |                         help='Use philly cluster')
65 |     args = parser.parse_args()
66 |     args.tied = True
67 | 
68 |     # Set the random seed manually for reproducibility.
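    # Note: this seeds NumPy, PyTorch, and (when --cuda is set) the CUDA RNG;
    # fully deterministic GPU runs may additionally require deterministic
    # cuDNN settings, which this script does not configure.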
69 | np.random.seed(args.seed) 70 | torch.manual_seed(args.seed) 71 | if torch.cuda.is_available(): 72 | if not args.cuda: 73 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 74 | else: 75 | torch.cuda.manual_seed(args.seed) 76 | 77 | 78 | ############################################################################### 79 | # Load data 80 | ############################################################################### 81 | 82 | 83 | def model_save(fn): 84 | if args.philly: 85 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 86 | with open(fn, 'wb') as f: 87 | # torch.save([model, optimizer], f) 88 | torch.save({ 89 | 'epoch': epoch, 90 | 'model_state_dict': model.state_dict(), 91 | 'optimizer_state_dict': optimizer.state_dict(), 92 | 'loss': val_loss 93 | }, f) 94 | 95 | 96 | def model_load(fn): 97 | global model, optimizer 98 | if args.philly: 99 | fn = os.path.join(os.environ['PT_OUTPUT_DIR'], fn) 100 | with open(fn, 'rb') as f: 101 | checkpoint = torch.load(f) 102 | model.load_state_dict(checkpoint['model_state_dict']) 103 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 104 | epoch = checkpoint['epoch'] 105 | val_loss = checkpoint['loss'] 106 | 107 | 108 | class LogicInference(object): 109 | def __init__(self, datapath='data/propositionallogic/', maxn=12): 110 | """maxn=0 indicates variable expression length.""" 111 | self.num2char = ['(', ')', 112 | 'a', 'b', 'c', 'd', 'e', 'f', 113 | 'or', 'and', 'not'] 114 | self.char2num = {self.num2char[i]: i 115 | for i in range(len(self.num2char))} 116 | 117 | self.num2lbl = list('<>=^|#v') 118 | self.lbl2num = {self.num2lbl[i]: i 119 | for i in range(len(self.num2lbl))} 120 | 121 | self.train_set, self.valid_set, self.test_set = [], [], [] 122 | counter = 0 123 | for i in range(maxn): 124 | itrainexample = self._readfile(os.path.join(datapath, "train" + str(i))) 125 | for e in itrainexample: 126 | counter += 1 127 | if counter % 10 == 0: 128 | self.valid_set.append(e) 129 | else: 130 | self.train_set.append(e) 131 | # self.train_set = self.train_set + itrainexample 132 | 133 | for i in range(13): 134 | itestexample = self._readfile(os.path.join(datapath, "test" + str(i))) 135 | self.test_set.append(itestexample) 136 | 137 | def _readfile(self, filepath): 138 | f = open(filepath, 'r') 139 | examples = [] 140 | for line in f.readlines(): 141 | relation, p1, p2 = line.strip().split('\t') 142 | p1 = p1.split() 143 | p2 = p2.split() 144 | examples.append((self.lbl2num[relation], 145 | [self.char2num[w] for w in p1], 146 | [self.char2num[w] for w in p2])) 147 | return examples 148 | 149 | def stream(self, dataset, batch_size, shuffle=False, pad=None): 150 | if pad is None: 151 | pad = len(self.num2char) 152 | import random 153 | import math 154 | batch_count = int(math.ceil(len(dataset) / float(batch_size))) 155 | 156 | def shuffle_stream(): 157 | if shuffle: 158 | random.shuffle(dataset) 159 | for i in range(batch_count): 160 | yield dataset[i * batch_size: (i + 1) * batch_size] 161 | 162 | def arrayify(stream, pad): 163 | for batch in stream: 164 | batch_lbls = np.array([x[0] for x in batch], dtype=np.int64) 165 | batch_sent = [x[1] for x in batch] + [x[2] for x in batch] 166 | max_len = max(len(s) for s in batch_sent) 167 | batch_idxs = np.full((max_len, len(batch_sent)), pad, 168 | dtype=np.int64) 169 | for i in range(len(batch_sent)): 170 | sentence = batch_sent[i] 171 | batch_idxs[:len(sentence), i] = sentence 172 | yield batch_idxs, batch_lbls 173 | 174 | stream = shuffle_stream() 175 | 
stream = arrayify(stream, pad) 176 | return stream 177 | 178 | 179 | corpus = LogicInference(datapath=args.data, maxn=args.max_op + 1) 180 | 181 | ############################################################################### 182 | # Build the model 183 | ############################################################################### 184 | ### 185 | # if args.resume: 186 | # print('Resuming model ...') 187 | # model_load(args.resume) 188 | # optimizer.param_groups[0]['lr'] = args.lr 189 | # model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute 190 | # if args.wdrop: 191 | # for rnn in model.rnn.cells: 192 | # rnn.hh.dropout = args.wdrop 193 | ### 194 | 195 | ntokens = len(corpus.num2char) + 1 196 | nlbls = len(corpus.num2lbl) 197 | 198 | 199 | class Classifier(nn.Module): 200 | """Container module with an encoder, a recurrent module, and a decoder.""" 201 | 202 | def __init__(self, ntoken, ninp, nhid, nout, nslot, dropout, dropouti, dropouto, dropoutm): 203 | super(Classifier, self).__init__() 204 | 205 | self.padding_idx = ntoken - 1 206 | self.embedding = nn.Embedding(ntoken, ninp, 207 | padding_idx=self.padding_idx) 208 | 209 | self.encoder = ordered_memory.OrderedMemory(ninp, nhid, nslot, 210 | dropout=dropout, dropoutm=dropoutm) 211 | 212 | self.mlp = nn.Sequential( 213 | nn.Dropout(dropouto), 214 | nn.Linear(4 * nhid, nhid), 215 | nn.ELU(), 216 | nn.Dropout(dropouto), 217 | nn.Linear(nhid, nout), 218 | ) 219 | 220 | self.drop = nn.Dropout(dropouti) 221 | 222 | self.cost = nn.CrossEntropyLoss() 223 | self.init_weights() 224 | 225 | def init_weights(self): 226 | initrange = 0.1 227 | self.embedding.weight.data.uniform_(-initrange, initrange) 228 | 229 | def forward(self, input): 230 | batch_size = input.size(1) 231 | mask = (input != self.padding_idx) 232 | emb = self.drop(self.embedding(input)) 233 | output = self.encoder(emb, mask) 234 | self.probs = self.encoder.probs 235 | 236 | clause_1 = output[:batch_size // 2] 237 | clause_2 = output[batch_size // 2:] 238 | output = self.mlp(torch.cat([clause_1, clause_2, 239 | clause_1 * clause_2, 240 | torch.abs(clause_1 - clause_2)], dim=1)) 241 | return output 242 | 243 | 244 | if __name__ == "__main__": 245 | model = Classifier( 246 | ntoken=ntokens, 247 | ninp=args.emsize, 248 | nhid=args.nhid, 249 | nout=nlbls, 250 | nslot=args.nslot, 251 | dropout=args.dropout, 252 | dropouti=args.dropouti, 253 | dropouto=args.dropouto, 254 | dropoutm=args.dropoutm, 255 | ) 256 | 257 | if args.cuda: 258 | model = model.cuda() 259 | # model = model.half() 260 | 261 | params = list(model.parameters()) 262 | total_params = sum(x.size()[0] * x.size()[1] 263 | if len(x.size()) > 1 else x.size()[0] 264 | for x in params if x.size()) 265 | print('Args:', args) 266 | print('Model total parameters:', total_params) 267 | 268 | optimizer = torch.optim.Adam(params, 269 | lr=args.lr, 270 | betas=(0, 0.999), 271 | eps=1e-9, 272 | weight_decay=args.wdecay) 273 | 274 | 275 | ############################################################################### 276 | # Training code 277 | ############################################################################### 278 | 279 | @torch.no_grad() 280 | def valid(): 281 | # Turn on evaluation mode which disables dropout. 
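    # Note: despite its name, `total_loss` below accumulates the number of
    # correct predictions, so valid() returns classification accuracy on the
    # held-out split rather than a loss value.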
282 | model.eval() 283 | total_loss = 0 284 | total_datapoints = 0 285 | for sents, lbls in corpus.stream(corpus.valid_set, args.batch_size * 2): 286 | count = lbls.shape[0] 287 | sents = torch.from_numpy(sents) 288 | lbls = torch.from_numpy(lbls) 289 | if args.cuda: 290 | sents = sents.cuda() 291 | lbls = lbls.cuda() 292 | lin_output = model(sents) 293 | total_loss += torch.sum( 294 | torch.argmax(lin_output, dim=1) == lbls 295 | ).float().data 296 | total_datapoints += count 297 | accs = total_loss.item() / total_datapoints 298 | return accs 299 | 300 | @torch.no_grad() 301 | def evaluate(): 302 | # Turn on evaluation mode which disables dropout. 303 | model.eval() 304 | model.encoder.OM_forward.nslot = args.nslot * 2 305 | 306 | accs = [] 307 | global_loss = 0 308 | global_datapoints = 0 309 | for l in range(13): 310 | total_loss = 0 311 | total_datapoints = 0 312 | for sents, lbls in corpus.stream(corpus.test_set[l], args.batch_size * 2): 313 | count = lbls.shape[0] 314 | sents = torch.from_numpy(sents) 315 | lbls = torch.from_numpy(lbls) 316 | if args.cuda: 317 | sents = sents.cuda() 318 | lbls = lbls.cuda() 319 | lin_output = model(sents) 320 | total_loss += torch.sum( 321 | torch.argmax(lin_output, dim=1) == lbls 322 | ).float().data.item() 323 | total_datapoints += count 324 | accs.append(total_loss / total_datapoints if total_datapoints > 0 else -1) 325 | global_loss += total_loss 326 | global_datapoints += total_datapoints 327 | 328 | accs.append(global_loss / global_datapoints) 329 | 330 | model.encoder.OM_forward.nslot = args.nslot 331 | return accs 332 | 333 | 334 | def train(): 335 | # Turn on training mode which enables dropout. 336 | total_loss = 0 337 | total_acc = 0 338 | start_time = time.time() 339 | batch = 0 340 | for sents, lbls in corpus.stream(corpus.train_set, args.batch_size, 341 | shuffle=True): 342 | sents = torch.from_numpy(sents) 343 | lbls = torch.from_numpy(lbls) 344 | if args.cuda: 345 | sents = sents.cuda() 346 | lbls = lbls.cuda() 347 | 348 | model.train() 349 | optimizer.zero_grad() 350 | 351 | lin_output = model(sents) 352 | loss = model.cost(lin_output, lbls) 353 | acc = torch.mean( 354 | (torch.argmax(lin_output, dim=1) == lbls).float()) 355 | loss.backward() 356 | 357 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
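        # Note: clip_grad_norm_ rescales the gradients of all parameters so
        # that their combined global norm is at most `args.clip` (1.0 by
        # default); gradients already below that norm are left unchanged.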
358 | if args.clip: 359 | torch.nn.utils.clip_grad_norm_(params, args.clip) 360 | optimizer.step() 361 | 362 | total_loss += loss.detach().data 363 | total_acc += acc.detach().data 364 | if batch % args.log_interval == 0 and batch > 0: 365 | elapsed = time.time() - start_time 366 | print( 367 | '| epoch {:3d} ' 368 | '| lr {:05.5f} | ms/batch {:5.2f} ' 369 | '| loss {:5.2f} | acc {:0.2f}'.format( 370 | epoch, 371 | optimizer.param_groups[0]['lr'], 372 | elapsed * 1000 / args.log_interval, 373 | total_loss.item() / args.log_interval, 374 | total_acc.item() / args.log_interval)) 375 | total_loss = 0 376 | total_acc = 0 377 | start_time = time.time() 378 | ### 379 | batch += 1 380 | 381 | @torch.no_grad() 382 | def genparse(): 383 | model.eval() 384 | model.encoder.OM_forward.nslot = args.nslot * 2 385 | 386 | np.set_printoptions(precision=2, suppress=True, linewidth=5000, formatter={'float': '{: 0.2f}'.format}) 387 | pred_tree_list = [] 388 | targ_tree_list = [] 389 | for l in range(13): 390 | for sents, lbls in corpus.stream(corpus.test_set[l], args.batch_size * 2): 391 | sents = torch.from_numpy(sents) 392 | if args.cuda: 393 | sents = sents.cuda() 394 | 395 | # hidden = model.encoder.init_hidden(sents.size(1)) 396 | # emb = model.drop(model.embedding(sents)) 397 | # raw_output, probs_batch, _ = model.encoder(emb, hidden) 398 | 399 | model(sents) 400 | probs_batch = model.probs 401 | 402 | for i in range(sents.size(1)): 403 | probs = probs_batch[:, i].view(-1, args.nslot * 2) 404 | # self.distance = (torch.cumsum(self.probs, dim=-1) < 0.5).sum(dim=-1) 405 | 406 | distance = torch.argmax(probs, dim=-1) 407 | distance[0] = args.nslot * 2 408 | sen = [corpus.num2char[x] 409 | for x in sents[:, i] if x < len(corpus.num2char)] 410 | if len(sen) < 2: 411 | continue 412 | depth = distance[:len(sen)] 413 | probs = probs.data.cpu().numpy() 414 | 415 | parse_tree = remove_bracket(build_tree(depth, sen)) 416 | sen_tree = char2tree(sen) 417 | 418 | pred_tree_list.append(parse_tree) 419 | targ_tree_list.append(sen_tree) 420 | 421 | if np.random.randint(0, 100) > 0: 422 | continue 423 | print() 424 | for i in range(len(sen)): 425 | print('%5s\t%2.2f\t%s' % (sen[i], distance[i], plot(probs[i], 1))) 426 | 427 | print(' '.join(sen)) 428 | # print(sen_tree) 429 | print(parse_tree) 430 | print('') 431 | 432 | evalb(pred_tree_list, targ_tree_list) 433 | 434 | model.encoder.OM_forward.nslot = args.nslot 435 | 436 | 437 | if __name__ == "__main__": 438 | # Loop over epochs. 439 | if not args.test_only: 440 | lr = args.lr 441 | stored_loss = 0. 442 | 443 | # At any point you can hit Ctrl + C to break out of training early. 444 | try: 445 | # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. 
Adaptive Softmax) 446 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', 0.5, patience=2, threshold=0) 447 | for epoch in range(1, args.epochs + 1): 448 | epoch_start_time = time.time() 449 | train() 450 | val_loss = valid() 451 | test_loss = evaluate() 452 | 453 | print('-' * 89) 454 | print( 455 | '| end of epoch {:3d} ' 456 | '| time: {:5.2f}s ' 457 | '| valid acc: {:.2f} ' 458 | '|\n'.format( 459 | epoch, 460 | (time.time() - epoch_start_time), 461 | val_loss 462 | ), 463 | ', '.join(str('{:0.2f}'.format(v)) for v in test_loss) 464 | ) 465 | 466 | if val_loss > stored_loss: 467 | model_save(args.save) 468 | print('Saving model (new best validation)') 469 | stored_loss = val_loss 470 | print('-' * 89) 471 | 472 | scheduler.step(val_loss) 473 | except KeyboardInterrupt: 474 | print('-' * 89) 475 | print('Exiting from training early') 476 | # Load the best saved model. 477 | model_load(args.save) 478 | 479 | genparse() 480 | 481 | test_loss = evaluate() 482 | val_loss = valid() 483 | print('-' * 89) 484 | print( 485 | '| valid acc: {:.2f} ' 486 | '|\n'.format( 487 | val_loss 488 | ), 489 | ', '.join(str('{:0.2f}'.format(v)) for v in test_loss) 490 | ) 491 | 492 | # report_results([dict( 493 | # name='val_loss', 494 | # type='objective', 495 | # value=val_loss)]) 496 | 497 | -------------------------------------------------------------------------------- /EVALB/evalb.c: -------------------------------------------------------------------------------- 1 | /*****************************************************************/ 2 | /* evalb [-p param_file] [-dh] [-e n] gold-file test-file */ 3 | /* */ 4 | /* Evaluate bracketing in test-file against gold-file. */ 5 | /* Return recall, precision, tagging accuracy. */ 6 | /* */ 7 | /*