├── nlp_commons ├── dep │ ├── __init__.py │ ├── rhead.py │ ├── lhead.py │ ├── tiger.py │ ├── depgraph.py │ ├── depset.py │ ├── model.py │ ├── dwsj.py │ ├── conll.py │ └── dnegra.py ├── __init__.py ├── PKG-INFO ├── setup.py ├── README.txt ├── paramdict.py ├── lbranch.py ├── negra.py ├── sentence.py ├── rbranch.py ├── eval.py ├── ubound.py ├── wsj.py ├── negra10.py ├── util.py ├── graph.py ├── cast3lb10.py ├── wsj10.py ├── model.py ├── cast3lb.py └── bracketing.py ├── modules ├── __init__.py ├── projection.py ├── utils.py ├── markov_flow_model.py └── dmv_viterbi_model.py ├── LICENSE ├── prepare_data.py ├── preprocess_ptb.py ├── dmv_viterbi_train.py ├── README.md ├── markov_flow_train.py └── dmv_flow_train.py /nlp_commons/dep/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .markov_flow_model import * 3 | from .dmv_viterbi_model import * 4 | from .dmv_flow_model import * 5 | from .projection import * 6 | -------------------------------------------------------------------------------- /nlp_commons/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2017-11-30 Junxian He 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | """ 10 | Modules of the lq_nlp_commons library 11 | 12 | """ 13 | 14 | 15 | -------------------------------------------------------------------------------- /nlp_commons/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: lq-nlp-commons 3 | Version: 0.2.0 4 | Summary: Franco M. Luque's Common Python Code for NLP 5 | Home-page: http://www.cs.famaf.unc.edu.ar/~francolq/ 6 | Author: Franco M. Luque 7 | Author-email: francolq@famaf.unc.edu.ar 8 | License: GNU General Public License 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /nlp_commons/dep/rhead.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # rhead.py: RHEAD baseline for unsupervised dependency parsing. 6 | 7 | from dep import model 8 | from dep import depset 9 | 10 | class RHead(model.DepModel): 11 | trained = True 12 | tested = True 13 | 14 | def __init__(self, treebank=None): 15 | model.DepModel.__init__(self, treebank) 16 | self.Parse = [depset.rhead_depset(b.length) for b in self.Gold] 17 | 18 | 19 | def main(): 20 | print "WSJ10" 21 | import dep.dwsj 22 | tb = dep.dwsj.DepWSJ10() 23 | m = RHead(tb) 24 | m.eval() 25 | 26 | """ 27 | from dep import rhead 28 | rhead.main() 29 | 30 | WSJ10 31 | Number of Trees: 7422 32 | Directed Accuracy: 33.5 33 | Undirected Accuracy: 56.4 34 | """ 35 | -------------------------------------------------------------------------------- /nlp_commons/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. 
Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | from distutils.core import setup 6 | 7 | setup(name='lq-nlp-commons', 8 | # Read the following page for advice on version numbering: 9 | # http://docs.python.org/distutils/setupscript.html#additional-meta-data 10 | version='0.2.0', 11 | description="Franco M. Luque's Common Python Code for NLP", 12 | author='Franco M. Luque', 13 | author_email='francolq@famaf.unc.edu.ar', 14 | url='http://www.cs.famaf.unc.edu.ar/~francolq/', 15 | packages=['dep'], 16 | py_modules=['bracketing', 'lbranch', 'paramdict', 17 | 'treebank', 'wsj10', 'cast3lb', 'model', 'rbranch', 18 | 'ubound', 'cast3lb10', 'negra', 'sentence', 'util', 19 | 'eval', 'negra10', 'setup', 'wsj', 'graph'], 20 | license='GNU General Public License', 21 | ) 22 | -------------------------------------------------------------------------------- /nlp_commons/README.txt: -------------------------------------------------------------------------------- 1 | lq-nlp-commons - Franco M. Luque's Common Python Code for NLP 2 | 3 | Copyright (C) 2007-2011 Franco M. Luque 4 | URL: 5 | For license information, see LICENSE.txt 6 | 7 | 8 | Introduction 9 | ============ 10 | 11 | This package includes basic data structures, interfaces to the WSJ, Negra and 12 | Cast3LB corpuses (English, German and Spanish respectively), and some baselines 13 | and upper bounds computation code for unsupervised parsing (Klein and 14 | Manning, 2004). 15 | 16 | This work was done as part of the PhD in Computer Science I am doing in FaMAF, 17 | Universidad Nacional de Cordoba, Argentina, under the supervision of 18 | Gabriel Infante-Lopez, with a research fellowship from CONICET. 19 | 20 | No direct usage instructions for now. Only to be used as dependency for other 21 | software. 22 | 23 | 24 | References 25 | ========== 26 | 27 | Klein, D. and Manning, C. D. (2004). Corpus-based induction of syntactic 28 | structure: Models of dependency and constituency. In ACL, pages 478-485. 29 | -------------------------------------------------------------------------------- /nlp_commons/dep/lhead.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # lhead.py: LHEAD baseline for unsupervised dependency parsing. 6 | 7 | from dep import model 8 | from dep import depset 9 | 10 | class LHead(model.DepModel): 11 | trained = True 12 | tested = True 13 | 14 | def __init__(self, treebank=None): 15 | model.DepModel.__init__(self, treebank) 16 | self.Parse = [depset.lhead_depset(b.length) for b in self.Gold] 17 | 18 | 19 | def main(): 20 | print "WSJ10" 21 | import dep.dwsj 22 | tb = dep.dwsj.DepWSJ10() 23 | m = LHead(tb) 24 | m.eval() 25 | 26 | """ 27 | from dep import lhead 28 | lhead.main() 29 | 30 | WSJ10 31 | Number of Trees: 7422 32 | Directed Accuracy: 23.7 33 | Undirected Accuracy: 55.6 34 | Debe dar: 24.0, 55.9. 35 | 36 | >>> m.count_length_2_1 = True 37 | >>> m.eval() 38 | Number of Trees: 7422 39 | Directed Accuracy: 24.0 40 | Undirected Accuracy: 55.7 41 | (52248, 0.23974123411422446, 0.55726534986985143) 42 | Mas cerca... 
43 | """ 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junxian He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /nlp_commons/paramdict.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # paramdict.py: ParamDict is a comfortable dictionary with commonly needed 6 | # functions. 7 | 8 | class ParamDict(object): 9 | 10 | def __init__(self, d=None, default_val=0.0, count_evidence=False): 11 | if d is None: 12 | self.d = {} 13 | else: 14 | self.d = d 15 | self.count_evidence = count_evidence 16 | if count_evidence: 17 | self.evidence = {} 18 | self.default_val = default_val 19 | 20 | def set_default_val(self, val): 21 | self.default_val = val 22 | 23 | def val(self, x): 24 | return self.d.get(x, self.default_val) 25 | 26 | def setVal(self, x, val): 27 | self.d[x] = val 28 | 29 | def add1(self, x): 30 | self.add(x, 1.0) 31 | 32 | def add(self, x, y): 33 | add(self.d, x, y) 34 | if self.count_evidence and y > 0.0: 35 | add(self.evidence, x, 1.0) 36 | 37 | def iteritems(self): 38 | return iter(self.d.items()) 39 | 40 | 41 | # Common procedure used in ParamDict: 42 | def add(dict, x, val): 43 | dict[x] = dict.get(x, 0) + val 44 | -------------------------------------------------------------------------------- /nlp_commons/lbranch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # lbranch.py: LBRANCH baseline for unsupervised parsing. 6 | 7 | from . import bracketing 8 | from . import model 9 | 10 | class LBranch(model.BracketingModel): 11 | trained = True 12 | tested = True 13 | 14 | def __init__(self, treebank=None): 15 | model.BracketingModel.__init__(self, treebank) 16 | self.Parse = [bracketing.lbranch_bracketing(b.length) for b in self.Gold] 17 | 18 | 19 | def main(): 20 | print('WSJ10') 21 | main1() 22 | print('NEGRA10') 23 | main2() 24 | print('CAST3LB10') 25 | main3() 26 | 27 | def main1(): 28 | from . import wsj10 29 | tb = wsj10.WSJ10() 30 | m = LBranch(tb) 31 | m.eval() 32 | 33 | def main2(): 34 | from . 
import negra10 35 | tb = negra10.Negra10() 36 | tb.simplify_tags() 37 | m = LBranch(tb) 38 | m.eval() 39 | 40 | def main3(): 41 | from . import cast3lb10 42 | tb = cast3lb10.Cast3LB10() 43 | tb.simplify_tags() 44 | m = LBranch(tb) 45 | m.eval() 46 | 47 | """ 48 | from lbranch import * 49 | main() 50 | 51 | WSJ10 52 | Cantidad de arboles: 7422.0 53 | Medidas sumando todos los brackets: 54 | Precision: 25.7 55 | Recall: 32.6 56 | Media harmonica F1: 28.7 57 | NEGRA10 58 | Cantidad de arboles: 7537.0 59 | Medidas sumando todos los brackets: 60 | Precision: 27.4 61 | Recall: 48.6 62 | Media harmonica F1: 35.1 63 | CAST3LB10 64 | Cantidad de arboles: 712.0 65 | Medidas sumando todos los brackets: 66 | Precision: 26.9 67 | Recall: 38.4 68 | Media harmonica F1: 31.7 69 | """ 70 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2019-08-14 Junxian 6 | # 7 | # Distributed under terms of the MIT license. 8 | 9 | import argparse 10 | import requests 11 | import tarfile 12 | import os 13 | 14 | def download_file_from_google_drive(id, destination): 15 | URL = "https://docs.google.com/uc?export=download" 16 | 17 | session = requests.Session() 18 | 19 | response = session.get(URL, params = { 'id' : id }, stream = True) 20 | token = get_confirm_token(response) 21 | 22 | if token: 23 | params = { 'id' : id, 'confirm' : token } 24 | response = session.get(URL, params = params, stream = True) 25 | 26 | save_response_content(response, destination) 27 | 28 | def get_confirm_token(response): 29 | for key, value in response.cookies.items(): 30 | if key.startswith('download_warning'): 31 | return value 32 | 33 | return None 34 | 35 | def save_response_content(response, destination): 36 | CHUNK_SIZE = 32768 37 | 38 | with open(destination, "wb") as f: 39 | for chunk in response.iter_content(CHUNK_SIZE): 40 | if chunk: # filter out keep-alive new chunks 41 | f.write(chunk) 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description="data downloading") 45 | 46 | data_id = "1cE0GQ3B4zJhP8305hFbOkifBvf4pBR6S" 47 | 48 | file_id = [data_id] 49 | 50 | destination = "datasets_sample.tar.gz" 51 | 52 | for file_id_e in file_id: 53 | download_file_from_google_drive(file_id_e, destination) 54 | tar = tarfile.open(destination, "r:gz") 55 | tar.extractall() 56 | tar.close() 57 | os.remove(destination) 58 | 59 | -------------------------------------------------------------------------------- /nlp_commons/dep/tiger.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | 4 | import nltk 5 | from nltk.corpus.reader import api 6 | 7 | import treebank 8 | 9 | basedir = 'corpora/TiGerDB' 10 | 11 | class Tiger10(api.SyntaxCorpusReader): 12 | 13 | def __init__(self): 14 | api.SyntaxCorpusReader.__init__(self, nltk.data.find(basedir), 'fdsc-Apr08/.*\.fdsc') 15 | 16 | def _read_block(self, stream): 17 | return [stream.readlines()] 18 | #return [stream.read()] 19 | #s = stream.readline() 20 | #while not s.startswith('sentence_form'): 21 | # s = stream.readline() 22 | 23 | def _word(self, s): 24 | # jump to sentence: 25 | i = 0 26 | while i < len(s) and not s[i].startswith('sentence_form('): 27 | i += 1 28 | assert i < len(s) 29 | l = s[i] 30 | 31 | return l[14:-3].split() 32 | 33 | def _tag(self, s, simplify_tags=False): 34 | return 
[(x, x) for x in self._word(s)] 35 | 36 | def _parse(self, s): 37 | #print s 38 | 39 | # get sentence length: 40 | w = self._word(s) 41 | n = len(w) 42 | 43 | # jump to structure: 44 | i = 0 45 | while i < len(s) and not s[i].startswith('structure('): 46 | i += 1 47 | assert i < len(s) 48 | 49 | # read dependencies: 50 | deps = [] 51 | i += 1 52 | while i < len(s) and not s[i].startswith(')'): 53 | l = s[i] 54 | #print 'Empieza con', l 55 | l2 = [x for x in re.split(r'[\(,~\s\)]*', l) if x != ''] 56 | if len(l2) == 5: 57 | # this line encodes a dependency 58 | j = int(l2[4]) 59 | k = int(l2[2]) 60 | if j <= n and k <= n: 61 | deps += [(j, k)] 62 | else: 63 | assert len(l2) == 4 64 | i += 1 65 | assert i < len(s) 66 | 67 | deps.sort() 68 | 69 | return deps 70 | 71 | -------------------------------------------------------------------------------- /preprocess_ptb.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # vim:fenc=utf-8 4 | # 5 | # Copyright © 2018-01-19 Junxian He 6 | # 7 | # Distributed under terms of the MIT license. 8 | from __future__ import print_function 9 | 10 | import os 11 | import shutil 12 | import argparse 13 | from nlp_commons import wsj10 14 | from nlp_commons.dep import dwsj 15 | 16 | def generate_file(dir_name, fname, max_length=10): 17 | data_reader = dwsj.DepWSJ(max_length=max_length, basedir=dir_name) 18 | 19 | print('complete reading data') 20 | 21 | tag_sents, _ = data_reader.tagged_sents() 22 | deps_total = data_reader.get_gold_dep() 23 | with open(fname, "w") as fout: 24 | for tag_sent, sent_deps in zip(tag_sents, deps_total): 25 | deps = sent_deps.deps 26 | for i, (tag_word, dep) in enumerate(zip(tag_sent, deps)): 27 | fout.write('%d\t%s\t%s\t%d\n' % (i+1, tag_word[1], tag_word[0], dep[1]+1)) 28 | fout.write('\n') 29 | 30 | parser = argparse.ArgumentParser(description='preprocess ptb data') 31 | parser.add_argument('--ptbdir', type=str, help='input directory') 32 | # parser.add_argument('--task', type=str, choices=["tag", "parse"], 33 | # default="tag") 34 | 35 | args = parser.parse_args() 36 | 37 | if not os.path.exists("tmp_train"): 38 | os.makedirs("tmp_train") 39 | 40 | if not os.path.exists("tmp_test"): 41 | os.makedirs("tmp_test") 42 | 43 | abs_ptb = os.path.abspath(args.ptbdir) 44 | 45 | for i in range(2, 22): 46 | ind = str("%02d" % i) 47 | if not os.path.exists("tmp_train/%02d" % i): 48 | os.symlink(os.path.join(abs_ptb, ind), "tmp_train/%02d" % i) 49 | 50 | ind = 23 51 | if not os.path.exists("tmp_test/%02d" % ind): 52 | os.symlink(os.path.join(abs_ptb, str(ind)), "tmp_test/%02d" % ind) 53 | 54 | outdir = "ptb_parse_data" 55 | if not os.path.exists(outdir): 56 | os.makedirs(outdir) 57 | 58 | print("generate train file (len <= 10)") 59 | generate_file("tmp_train", os.path.join(outdir, "ptb_parse_train_len10.txt")) 60 | 61 | print("generate test file") 62 | generate_file("tmp_test", os.path.join(outdir, "ptb_parse_test.txt"), max_length=200) 63 | 64 | shutil.rmtree("tmp_train") 65 | shutil.rmtree("tmp_test") 66 | -------------------------------------------------------------------------------- /nlp_commons/negra.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import itertools 6 | 7 | from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader 8 | 9 | from . 
import treebank 10 | 11 | 12 | def is_ellipsis(s): 13 | #return s[:2] == '*T' 14 | return s[0] == '*' 15 | 16 | 17 | def is_punctuation(s): 18 | return s[0] == '$' 19 | 20 | 21 | class NegraTree(treebank.Tree): 22 | 23 | 24 | def is_ellipsis(self, s): 25 | return is_ellipsis(s) 26 | 27 | 28 | def is_punctuation(self, s): 29 | return is_punctuation(s) 30 | 31 | 32 | class Negra(treebank.SavedTreebank): 33 | default_basedir = 'negra-corpus' 34 | trees = [] 35 | filename = 'negra.treebank' 36 | 37 | 38 | def __init__(self, basedir=None): 39 | if basedir == None: 40 | basedir = self.default_basedir 41 | self.basedir = basedir 42 | self.reader = BracketParseCorpusReader(basedir, 'negra-corpus2.penn', comment_char='%') 43 | 44 | 45 | def parsed(self, files=None): 46 | #for t in treebank.SavedTreebank.parsed(self, files): 47 | for (i, t) in zip(itertools.count(), self.reader.parsed_sents()): 48 | yield NegraTree(t, labels=i) 49 | 50 | 51 | def get_tree(self, offset=0): 52 | t = self.get_trees2(offset, offset+1)[0] 53 | return t 54 | 55 | 56 | # Devuelve los arboles que se encuentran en la posicion i con start <= i < end 57 | def get_trees2(self, start=0, end=None): 58 | lt = [t for t in itertools.islice(self.parsed(), start, end)] 59 | return lt 60 | 61 | 62 | def is_ellipsis(self, s): 63 | return is_ellipsis(s) 64 | 65 | 66 | def is_punctuation(self, s): 67 | return is_punctuation(s) 68 | 69 | 70 | def test(): 71 | tb = Negra() 72 | trees = tb.get_trees() 73 | return tb 74 | 75 | """ 76 | PREPROCESAMIENTO DEL NEGRA: 77 | 78 | >>> f = open('negra-corpus/negra-corpus.penn') 79 | >>> g = open('negra-corpus/negra-corpus2.penn', 'w') 80 | >>> for l in f: 81 | ... if l[0] == '(': 82 | ... l = '(ROOT'+l[1:] 83 | ... g.write(l) 84 | ... 85 | >>> f.close() 86 | >>> g.close() 87 | """ 88 | -------------------------------------------------------------------------------- /nlp_commons/sentence.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # sentence.py: Class Sentence. 6 | 7 | import string 8 | 9 | class Sentence: # (list): 10 | 11 | def __init__(self, tag_list): 12 | if isinstance(tag_list, str): 13 | tag_list = tag_list.split() 14 | self.tag_list = tag_list 15 | # si heredara de list podria poner: 16 | #list.__init__(self, tag_list) 17 | 18 | def __str__(self): 19 | return string.join(self.tag_list) 20 | 21 | def __repr__(self): 22 | return str(self) 23 | 24 | def reverse(self): 25 | self.tag_list.reverse() 26 | 27 | # iterador sobre todas las subsecuencias de tags. 28 | # devuelve las subsecuencias separadas por espacios en un string. 29 | def itersubseqs(self): 30 | l = len(self.tag_list) 31 | x = 2 # span minimo 32 | for i in range(x, l+1): 33 | for j in range(l-i+1): 34 | yield string.join(self.tag_list[j:j+i]) 35 | 36 | # iterador sobre todos los contextos "a la CCM". 37 | # devuelve pares de tags. 
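    # A small illustrative example of the subsequence and context iterators
    # (itersubseqs above, itercontexts below) -- a sketch only, assuming the
    # Python 2 semantics of string.join used in this module:
    #   >>> s = Sentence('DT NN VB')
    #   >>> list(s.itersubseqs())
    #   ['DT NN', 'NN VB', 'DT NN VB']
    #   >>> list(s.itercontexts())
    #   [('START', 'VB'), ('DT', 'END'), ('START', 'END')]
    # i.e. every span of length >= 2 is paired with the tags immediately to
    # its left and right, with START/END at the sentence boundaries.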
38 | def itercontexts(self): 39 | s = self.tag_list + ['END', 'START'] 40 | l = len(self.tag_list) 41 | x = 2 # span minimo 42 | for i in range(x, l+1): 43 | for j in range(l-i+1): 44 | yield (s[j-1], s[j+i]) 45 | 46 | # Por francolq, basdo en tree.Tree de NLTK 0.9: 47 | 48 | #//////////////////////////////////////////////////////////// 49 | # Disabled list operations 50 | #//////////////////////////////////////////////////////////// 51 | 52 | def __rmul__(self, v): 53 | raise TypeError('Sentence does not support multiplication') 54 | 55 | #//////////////////////////////////////////////////////////// 56 | # Enabled list operations 57 | #/////////////////////////////////////////////////////////// 58 | 59 | # ver "Emulating numeric types": 60 | # http://docs.python.org/ref/numeric-types.html 61 | def __add__(self, v): 62 | return Sentence(self.tag_list + v) 63 | def __radd__(self, v): 64 | return Sentence(v + self.tag_list) 65 | def __iadd__(self, v): 66 | self.tag_list += v 67 | return self 68 | def __mul__(self, v): 69 | return Sentence(self.tag_list * v) 70 | 71 | #//////////////////////////////////////////////////////////// 72 | # Indexing 73 | #//////////////////////////////////////////////////////////// 74 | 75 | def __len__(self): 76 | return len(self.tag_list) 77 | 78 | def __getitem__(self, index): 79 | # return list.__getitem__(self, index) 80 | return self.tag_list[index] 81 | 82 | def __setitem__(self, index, value): 83 | self.tag_list[index] = value 84 | 85 | # se me hace que no la necesito: 86 | def __delitem__(self, index): 87 | del(self.tag_list[index]) 88 | -------------------------------------------------------------------------------- /nlp_commons/rbranch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # rbranch.py: RBRANCH baseline for unsupervised parsing. 6 | 7 | from . import bracketing, model 8 | 9 | class RBranch(model.BracketingModel): 10 | trained = True 11 | tested = True 12 | 13 | def __init__(self, treebank=None): 14 | model.BracketingModel.__init__(self, treebank) 15 | self.Parse = [bracketing.rbranch_bracketing(b.length) for b in self.Gold] 16 | 17 | 18 | def main(): 19 | print('WSJ10') 20 | main1() 21 | print('NEGRA10') 22 | main2() 23 | print('CAST3LB10') 24 | main3() 25 | 26 | def main1(): 27 | from . import wsj10 28 | tb = wsj10.WSJ10() 29 | m = RBranch(tb) 30 | m.eval() 31 | 32 | def main2(): 33 | from . import negra10 34 | tb = negra10.Negra10() 35 | tb.simplify_tags() 36 | m = RBranch(tb) 37 | m.eval() 38 | 39 | def main3(): 40 | from . import cast3lb10 41 | tb = cast3lb10.Cast3LB10() 42 | tb.simplify_tags() 43 | m = RBranch(tb) 44 | m.eval() 45 | 46 | """ 47 | from rbranch import * 48 | main() 49 | 50 | WSJ10 51 | Cantidad de arboles: 7422.0 52 | Medidas sumando todos los brackets: 53 | Precision: 55.2 54 | Recall: 70.0 55 | Media harmonica F1: 61.7 56 | NEGRA10 57 | Cantidad de arboles: 7537.0 58 | Medidas sumando todos los brackets: 59 | Precision: 33.9 60 | Recall: 60.1 61 | Media harmonica F1: 43.3 62 | CAST3LB10 63 | Cantidad de arboles: 712.0 64 | Medidas sumando todos los brackets: 65 | Precision: 46.9 66 | Recall: 67.0 67 | Media harmonica F1: 55.2 68 | """ 69 | 70 | # VIEJO: 71 | 72 | # No hace falta construir los parses binarios RBRANCH. 
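# (That is: the binary RBRANCH parses never need to be built explicitly. The
# commented-out block below, kept only for reference, scored RBRANCH directly
# by counting the gold brackets that reach the right edge of each sentence.)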
73 | """p = 0.0 74 | r = 0.0 75 | brackets_ok = 0 76 | brackets_parse = 0 77 | brackets_gold = 0 78 | # Cantidad de arboles: 79 | m = 0 80 | for b in bs: 81 | n = b.length 82 | #if n >= 3: 83 | if True: 84 | m = m+1 85 | # print str(m)+"-esima frase..." 86 | # s = t.spannings(leaves=False,root=False,unary=False) 87 | # s2 = filter(lambda (a,b): b == n, s) 88 | s = b.brackets 89 | s2 = filter(lambda (a,b): b == n+1, s) 90 | 91 | precision = float(len(s2)) / float(n-2) 92 | 93 | if len(s) > 0: 94 | recall = float(len(s2)) / float(len(s)) 95 | else: 96 | recall = 1.0 97 | 98 | brackets_ok += len(s2) 99 | brackets_parse += n-2 100 | brackets_gold += len(s) 101 | 102 | p = p + precision 103 | r = r + recall 104 | p = p / float(m) 105 | r = r / float(m) 106 | print "Cantidad de arboles:", m 107 | print "Medidas promediando p y r por frase:" 108 | print " Precision de RBRANCH:", p 109 | print " Recall de RBRANCH:", r 110 | print " Media harmonica F1:", 2*(p*r)/(p+r) 111 | p = float(brackets_ok) / float(brackets_parse) 112 | r = float(brackets_ok) / float(brackets_gold) 113 | print "Medidas sumando todos los brackets:" 114 | print " Precision de RBRANCH:", p 115 | print " Recall de RBRANCH:", r 116 | print " Media harmonica F1:", 2*(p*r)/(p+r)""" 117 | 118 | # Debugging: 119 | """ 120 | Para ir tirando arboles hasta encontrar el que da recall con division por 0 (RBRANCH): 121 | from wsj import * 122 | l = [] 123 | m = 0 124 | for e in get_treebank_iterator(): 125 | e.filter_tags() 126 | n = len(e.leaves()) 127 | if n <= 10: 128 | m = m+1 129 | print str(m)+"-esima frase..." 130 | l = l + [e] 131 | if m == 100: 132 | break 133 | """ 134 | -------------------------------------------------------------------------------- /nlp_commons/dep/depgraph.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | from nltk.parse import dependencygraph 6 | from nltk import tree 7 | 8 | import treebank 9 | 10 | class DepGraph(dependencygraph.DependencyGraph): 11 | 12 | def __init__(self, nltk_depgraph): 13 | dependencygraph.DependencyGraph.__init__(self) 14 | self.nodelist = nltk_depgraph.nodelist 15 | self.root = nltk_depgraph.root 16 | self.stream = nltk_depgraph.stream 17 | 18 | def remove_leaves(self, f): 19 | """f must be a function that takes a node dict and returns a boolean. 20 | """ 21 | nodelist = self.nodelist 22 | newnodelist = [nodelist[0].copy()] 23 | newindex = [0] 24 | i, j = 1, 1 25 | while i < len(nodelist): 26 | node = nodelist[i] 27 | if not f(node): 28 | # this node stays 29 | newnode = node.copy() 30 | newnode['address'] = j 31 | newnodelist.append(newnode) 32 | newindex.append(j) 33 | j += 1 34 | else: 35 | newindex.append(-1) 36 | i += 1 37 | #print newindex 38 | # fix attributes 'head' and 'deps': 39 | node = newnodelist[0] 40 | node['deps'] = [newindex[i] for i in node['deps'] if newindex[i] != -1] 41 | for node in newnodelist[1:]: 42 | i = newindex[node['head']] 43 | if i == -1: 44 | raise Exception('Removing non-leaf.') 45 | node['head'] = i 46 | node['deps'] = [newindex[i] for i in node['deps'] if newindex[i] != -1] 47 | self.nodelist = newnodelist 48 | 49 | def constree(self): 50 | # Some depgraphs have several roots (for instance, 512th of Turkish). 51 | #i = self.root['address'] 52 | roots = self.nodelist[0]['deps'] 53 | if len(roots) == 1: 54 | return treebank.Tree(self._constree(roots[0])) 55 | else: 56 | # TODO: check projectivity here also. 
57 | trees = [self._constree(i) for i in roots] 58 | return treebank.Tree(tree.Tree('TOP', trees)) 59 | 60 | def _constree(self, i): 61 | node = self.nodelist[i] 62 | word = node['word'] 63 | deps = node['deps'] 64 | if len(deps) == 0: 65 | t = tree.Tree(node['tag'], [word]) 66 | t.span = (i, i+1) 67 | return t 68 | address = node['address'] 69 | ldeps = [j for j in deps if j < address] 70 | rdeps = [j for j in deps if j > address] 71 | lsubtrees = [self._constree(j) for j in ldeps] 72 | rsubtrees = [self._constree(j) for j in rdeps] 73 | csubtree = tree.Tree(node['tag'], [word]) 74 | csubtree.span = (i, i+1) 75 | subtrees = lsubtrees+[csubtree]+rsubtrees 76 | 77 | # check projectivity: 78 | for j in range(len(subtrees)-1): 79 | if subtrees[j].span[1] != subtrees[j+1].span[0]: 80 | raise Exception('Non-projectable dependency graph.') 81 | 82 | t = tree.Tree(word, subtrees) 83 | j = subtrees[0].span[0] 84 | k = subtrees[-1].span[1] 85 | t.span = (j, k) 86 | return t 87 | 88 | 89 | def from_depset(depset, s): 90 | """Returns a DepGraph with the dependencies of depset over the sentence s. 91 | (i, j) in depset means that s[i] depends on s[j]. depset must be sorted. 92 | """ 93 | tab = "" 94 | for i, j in depset: 95 | tab += '\t'.join([str(i+1), s[i], '_', s[i], '_\t_', str(j+1), '_\t_\t_\n']) 96 | return DepGraph(dependencygraph.DependencyGraph(tab)) 97 | -------------------------------------------------------------------------------- /modules/projection.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ReLUNet(nn.Module): 10 | def __init__(self, hidden_layers, hidden_units, in_features, out_features): 11 | super(ReLUNet, self).__init__() 12 | 13 | self.hidden_layers = hidden_layers 14 | self.in_layer = nn.Linear(in_features, hidden_units, bias=True) 15 | self.out_layer = nn.Linear(hidden_units, out_features, bias=True) 16 | for i in range(hidden_layers): 17 | name = 'cell{}'.format(i) 18 | cell = nn.Linear(hidden_units, hidden_units, bias=True) 19 | setattr(self, name, cell) 20 | 21 | def reset_parameters(self): 22 | self.in_layer.reset_parameters() 23 | self.out_layer.reset_parameters() 24 | for i in range(self.hidden_layers): 25 | name = 'cell{}'.format(i) 26 | getattr(self, name).reset_parameters() 27 | 28 | def init_identity(self): 29 | self.in_layer.weight.data.zero_() 30 | self.in_layer.bias.data.zero_() 31 | self.out_layer.weight.data.zero_() 32 | self.out_layer.bias.data.zero_() 33 | for i in range(self.hidden_layers): 34 | name = 'cell{}'.format(i) 35 | getattr(self, name).weight.data.zero_() 36 | getattr(self, name).bias.data.zero_() 37 | 38 | def forward(self, input): 39 | """ 40 | input: (batch_size, seq_length, in_features) 41 | output: (batch_size, seq_length, out_features) 42 | 43 | """ 44 | h = self.in_layer(input) 45 | h = F.relu(h) 46 | for i in range(self.hidden_layers): 47 | name = 'cell{}'.format(i) 48 | h = getattr(self, name)(h) 49 | h = F.relu(h) 50 | return self.out_layer(h) 51 | 52 | 53 | class NICETrans(nn.Module): 54 | def __init__(self, 55 | couple_layers, 56 | cell_layers, 57 | hidden_units, 58 | features, 59 | device): 60 | super(NICETrans, self).__init__() 61 | 62 | self.device = device 63 | self.couple_layers = couple_layers 64 | 65 | for i in range(couple_layers): 66 | name = 'cell{}'.format(i) 67 | cell = ReLUNet(cell_layers, hidden_units, features//2, features//2) 
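            # Each coupling layer owns a small ReLU net over half of the
            # feature dimensions; forward() below adds its output to the other
            # half (an additive NICE coupling, alternating which half is
            # shifted), so the map is invertible and its log-det-Jacobian is
            # identically zero -- hence the constant jacobian_loss returned.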
68 | setattr(self, name, cell) 69 | 70 | def reset_parameters(self): 71 | for i in range(self.couple_layers): 72 | name = 'cell{}'.format(i) 73 | getattr(self, name).reset_parameters() 74 | 75 | def init_identity(self): 76 | for i in range(self.couple_layers): 77 | name = 'cell{}'.format(i) 78 | getattr(self, name).init_identity() 79 | 80 | 81 | def forward(self, input): 82 | """ 83 | input: (seq_length, batch_size, features) 84 | h: (seq_length, batch_size, features) 85 | 86 | """ 87 | 88 | # For NICE it is a constant 89 | jacobian_loss = torch.zeros(1, device=self.device, 90 | requires_grad=False) 91 | 92 | ep_size = input.size() 93 | features = ep_size[-1] 94 | # h = odd_input 95 | h = input 96 | for i in range(self.couple_layers): 97 | name = 'cell{}'.format(i) 98 | h1, h2 = torch.split(h, features//2, dim=-1) 99 | if i%2 == 0: 100 | h = torch.cat((h1, h2 + getattr(self, name)(h1)), dim=-1) 101 | else: 102 | h = torch.cat((h1 + getattr(self, name)(h2), h2), dim=-1) 103 | return h, jacobian_loss 104 | -------------------------------------------------------------------------------- /nlp_commons/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | 6 | 7 | from . import bracketing 8 | 9 | count_fullspan_bracket = True 10 | count_length_2 = True 11 | count_length_2_1 = False 12 | 13 | # Calculo de precision, recall y F1 para dos Bracketings: 14 | def eval(Gold, Parse, output=True, short=False, long=False): 15 | assert len(Gold) == len(Parse) 16 | 17 | # Medidas sumando brackets y despues promediando: 18 | brackets_ok = 0 19 | brackets_parse = 0 20 | brackets_gold = 0 21 | 22 | for gb, pb in zip(Gold, Parse): 23 | l = gb.length 24 | if count_length_2_1 or (count_length_2 and l == 2) or l >= 3: 25 | # Medidas sumando brackets y despues promediando: 26 | (b_ok, b_p, b_g) = measures(gb, pb) 27 | brackets_ok += b_ok 28 | brackets_parse += b_p 29 | brackets_gold += b_g 30 | 31 | """# Medidas sumando brackets y despues promediando: 32 | brackets_ok += n 33 | brackets_parse += len(p) 34 | brackets_gold += len(g)""" 35 | 36 | m = float(len(Gold)) 37 | Prec = float(brackets_ok) / float(brackets_parse) 38 | Rec = float(brackets_ok) / float(brackets_gold) 39 | F1 = 2*(Prec*Rec)/(Prec+Rec) 40 | if output and not short: 41 | print("Cantidad de arboles:", m) 42 | print("Medidas sumando todos los brackets:") 43 | print(" Precision: %2.1f" % (100*Prec)) 44 | print(" Recall: %2.1f" % (100*Rec)) 45 | print(" Media harmonica F1: %2.1f" % (100*F1)) 46 | if int: 47 | print("Brackets parse:", brackets_parse) 48 | print("Brackets gold:", brackets_gold) 49 | print("Brackets ok:", brackets_ok) 50 | elif output and short: 51 | print("F1 =", F1) 52 | else: 53 | return (m, Prec, Rec, F1) 54 | 55 | 56 | def string_measures(gs, ps): 57 | gb = bracketing.string_to_bracketing(gs) 58 | pb = bracketing.string_to_bracketing(ps) 59 | return measures(gb, pb) 60 | 61 | 62 | # FIXME: hacer andar con frases de largo 1! 63 | # devuelve la terna (brackets_ok, brackets_parse, brackets_gold) 64 | # del i-esimo arbol. Se usa para calcular las medidas 65 | # micro-promediadas. 66 | def measures(gb, pb): 67 | g, p = gb.brackets, pb.brackets 68 | n = bracketing.coincidences(gb, pb) 69 | if count_fullspan_bracket: 70 | return (n+1, len(p)+1, len(g)+1) 71 | else: 72 | return (n, len(p), len(g)) 73 | 74 | 75 | # TODO: esta funcion es util, podria pasar a model.BracketingModel. 
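# (Worked example for measures()/eval() above, with the default
# count_fullspan_bracket = True: a parse that shares n = 3 brackets with a
# 5-bracket gold bracketing while proposing 4 brackets contributes
# (3+1, 4+1, 5+1) = (4, 5, 6) to the micro-averaged totals, i.e.
# Precision = 4/5 = 0.80, Recall = 4/6 = 0.67 and F1 = 2*P*R/(P+R) = 0.73
# if it were the only tree in the treebank.)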
76 | # goldtb debe ser un treebank, parse una lista de bracketings. 77 | def eval_label(label, goldtb, parse): 78 | Rec = 0.0 79 | brackets_ok = 0 80 | brackets_gold = 0 81 | 82 | bad = [] 83 | 84 | for gt, pb in zip(goldtb.trees, parse): 85 | g = set(x[1] for x in gt.labelled_spannings(leaves=False, root=False, unary=False) if x[0] == label) 86 | gb = bracketing.Bracketing(pb.length, g, start_index=0) 87 | 88 | n = bracketing.coincidences(gb, pb) 89 | if len(g) > 0: 90 | rec = float(n) / float(len(g)) 91 | bad += [difference(gb, pb)] 92 | else: 93 | rec = 1.0 94 | bad += [set()] 95 | Rec += rec 96 | 97 | brackets_ok += n 98 | brackets_gold += len(g) 99 | 100 | m = len(parse) 101 | Rec = Rec / float(m) 102 | 103 | print("Recall:", Rec) 104 | print("Brackets gold:", brackets_gold) 105 | print("Brackets ok:", brackets_ok) 106 | 107 | return (Rec, bad) 108 | 109 | # Conj. de brackets que estan en b1 pero no en b2 110 | # los devuelve con indices comenzando del 0. 111 | def difference(b1, b2): 112 | s1 = set([(x_y[0] - b1.start_index, x_y[1] - b1.start_index) for x_y in b1.brackets]) 113 | s2 = set([(x_y1[0] - b2.start_index, x_y1[1] - b2.start_index) for x_y1 in b2.brackets]) 114 | return s1 - s2 115 | 116 | 117 | -------------------------------------------------------------------------------- /dmv_viterbi_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import math 6 | import time 7 | import pickle 8 | 9 | from collections import namedtuple 10 | from modules import read_conll, get_tag_set 11 | import modules.dmv_viterbi_model as dmv 12 | 13 | def init_config(): 14 | 15 | parser = argparse.ArgumentParser(description='train dmv with viterbi EM') 16 | 17 | # hyperparams 18 | parser.add_argument('--stop_adj', default=0.3, type=float, 19 | help='initial value for stop adjacent') 20 | parser.add_argument('--smth_const', default=1, type=int, 21 | help='laplace smooth parameter') 22 | 23 | # data input 24 | parser.add_argument('--train_file', type=str, help='train data path') 25 | parser.add_argument('--test_file', type=str, help='test data path') 26 | 27 | # others 28 | parser.add_argument('--train_from', type=str, default='', 29 | help='load a pre-trained checkpoint') 30 | parser.add_argument('--choice', choices=['random', 'minival', 'bias_middle', 31 | 'soft_bias_middle', 'exclude_end', 'bias_left'], default='exclude_end', 32 | help='tie breaking policy at initial stage') 33 | parser.add_argument('--valid_nepoch', default=1, type=int, 34 | help='test every n iterations') 35 | parser.add_argument('--epochs', default=10, type=int, 36 | help='number of epochs') 37 | 38 | args = parser.parse_args() 39 | 40 | save_dir = "dump_models/dmv" 41 | 42 | if not os.path.exists(save_dir): 43 | os.makedirs(save_dir) 44 | 45 | save_path = os.path.join(save_dir, "viterbi_dmv.pickle") 46 | args.save_path = save_path 47 | 48 | print(args) 49 | 50 | return args 51 | 52 | def main(args): 53 | 54 | train_sents, _ = read_conll(args.train_file) 55 | test_sents, _ = read_conll(args.test_file, max_len=10) 56 | 57 | train_tags = [sent["tag"] for sent in train_sents] 58 | test_tags = [sent["tag"] for sent in test_sents] 59 | test_deps = [sent["head"] for sent in test_sents] 60 | 61 | tag_set = get_tag_set(train_tags) 62 | print('%d tags' % len(tag_set)) 63 | 64 | model = dmv.DMV(args) 65 | model.init_params(train_tags, tag_set) 66 | 67 | model.set_harmonic(False) 68 | 69 | if args.train_from != '': 
70 | model = pickle.load(open(args.train_from, 'rb')) 71 | directed, undirected = model.eval(test_deps, test_tags) 72 | print('acc on length <= 10: #trees %d, undir %2.2f, dir %2.2f' \ 73 | % (len(test_deps), 100 * undirected, 100 * directed)) 74 | 75 | epoch = 0 76 | stop = False 77 | 78 | directed, undirected = model.eval(test_deps, test_tags) 79 | print('starting acc on length <= 10: #trees %d, undir %2.2f, dir %2.2f' \ 80 | % (len(test_deps), 100 * undirected, 100 * directed)) 81 | 82 | num_train = len(train_tags) 83 | begin_time = time.time() 84 | while epoch < args.epochs and (not stop): 85 | tita, count = dmv.DMVDict(), dmv.DMVDict() 86 | dmv.lplace_smooth(tita, count, tag_set, model.end_symbol, args.smth_const) 87 | log_likelihood = 0.0 88 | 89 | for i, s in enumerate(filter(lambda s: len(s) > 1, 90 | train_tags)): 91 | if i % 1000 == 0: 92 | print('epoch %d, sentence %d' % (epoch, i)) 93 | parse_tree, prob = model.dep_parse(s) 94 | log_likelihood += prob 95 | model.MStep_s(parse_tree, tita, count) 96 | 97 | model.MStep(tita, count) 98 | print('\n\navg_log_likelihood:%.5f time elapsed: %.2f sec\n\n' % \ 99 | (log_likelihood / num_train, time.time() - begin_time)) 100 | 101 | if epoch % args.valid_nepoch == 0: 102 | directed, undirected = model.eval(test_deps, test_tags) 103 | print('acc on length <= 10: #trees %d, undir %2.2f, dir %2.2f' \ 104 | % (len(test_deps), 100 * undirected, 100 * directed)) 105 | 106 | epoch += 1 107 | 108 | pickle.dump(model, open(args.save_path, 'wb')) 109 | 110 | 111 | if __name__ == '__main__': 112 | parse_args = init_config() 113 | main(parse_args) 114 | -------------------------------------------------------------------------------- /nlp_commons/ubound.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | #!/usr/bin/python 6 | 7 | # Calculo de precision y recall para el topline UBOUND 8 | 9 | 10 | """ 11 | WSJ10 12 | Cantidad de arboles: 7422.0 13 | Medidas sumando todos los brackets: 14 | Precision: 78.8 15 | Recall: 100.0 16 | Media harmonica F1: 88.1 17 | NEGRA10 18 | Cantidad de arboles: 7537.0 19 | Medidas sumando todos los brackets: 20 | Precision: 56.4 21 | Recall: 100.0 22 | Media harmonica F1: 72.1 23 | CAST3LB10 24 | Cantidad de arboles: 712.0 25 | Medidas sumando todos los brackets: 26 | Precision: 70.1 27 | Recall: 100.0 28 | Media harmonica F1: 82.4 29 | """ 30 | 31 | from . import model, bracketing 32 | 33 | class UBound(model.BracketingModel): 34 | trained = True 35 | tested = True 36 | 37 | def __init__(self, treebank): 38 | self.Gold = [bracketing.tree_to_bracketing(t) for t in treebank.trees] 39 | 40 | # FIXME: no esta bien adaptado para usar count_fullspan_bracket 41 | def measures(self, i): 42 | g = self.Gold[i] 43 | n = len(g.brackets) 44 | # m es la cant. de brackets del supuesto parse 45 | m = g.length - 2 46 | if m > 0: 47 | if self.count_fullspan_bracket: 48 | prec = float(n+1) / float(m+1) 49 | else: 50 | prec = float(n) / float(m) 51 | else: 52 | prec = 1.0 53 | rec = 1.0 54 | return (prec, rec) 55 | 56 | def measures2(self, i): 57 | g = self.Gold[i] 58 | n = len(g.brackets) 59 | m = g.length - 2 60 | if self.count_fullspan_bracket: 61 | return (n+1, m+1, n+1) 62 | else: 63 | return (n, m, n) 64 | 65 | def main(): 66 | print('WSJ10') 67 | main1() 68 | print('NEGRA10') 69 | main2() 70 | print('CAST3LB10') 71 | main3() 72 | 73 | def main1(): 74 | from . 
import wsj10 75 | tb = wsj10.WSJ10() 76 | m = UBound(tb) 77 | m.eval() 78 | return m 79 | 80 | def main2(): 81 | from . import negra10 82 | tb = negra10.Negra10() 83 | tb.simplify_tags() 84 | m = UBound(tb) 85 | m.eval() 86 | return m 87 | 88 | def main3(): 89 | from . import cast3lb10 90 | tb = cast3lb10.Cast3LB10() 91 | tb.simplify_tags() 92 | m = UBound(tb) 93 | m.eval() 94 | return m 95 | 96 | # VIEJO: 97 | 98 | """wsj10 = wsj.get_wsj10_treebank() 99 | 100 | # Recall es 1, obvio. 101 | p = 0.0 102 | r = 1.0 103 | brackets_ok = 0 104 | brackets_parse = 0 105 | brackets_gold = 0 106 | # Cantidad de arboles: 107 | m = 0 108 | for t in wsj10: 109 | n = len(t.leaves()) 110 | if n >= 3: 111 | m = m+1 112 | # print str(m)+"-esima frase..." 113 | s = t.spannings(leaves=False,root=False,unary=False) 114 | precision = float(len(s)) / float(n-2) 115 | brackets_parse += n-2 116 | brackets_gold += len(s) 117 | 118 | p = p + precision 119 | p = p / float(m) 120 | print "Cantidad de arboles:", m 121 | print "Medidas promediando p y r por frase:" 122 | print " Precision de UBOUND:", p 123 | print " Recall de UBOUND:", r 124 | print " Media harmonica F1:", 2*(p*r)/(p+r) 125 | p = float(brackets_gold) / float(brackets_parse) 126 | print "Medidas sumando todos los brackets:" 127 | print " Precision de UBOUND:", p 128 | print " Recall de UBOUND:", r 129 | print " Media harmonica F1:", 2*(p*r)/(p+r)""" 130 | 131 | # Cantidad de arboles: 7056 132 | # Medidas promediando p y r por frase: 133 | # Precision de UBOUND: 0.740901529262 134 | # Recall de UBOUND: 1.0 135 | # Media harmonica F1: 0.851169944777 136 | # Medidas sumando todos los brackets: 137 | # Precision de UBOUND: 0.747252747253 138 | # Recall de UBOUND: 1.0 139 | # Media harmonica F1: 0.85534591195 140 | 141 | # Intento de usar eval del que desisti antes de fracasar: 142 | # (deberia programar un binarize y que el parse sea eso) 143 | """import eval 144 | 145 | wsj10 = wsj.get_wsj10_treebank() 146 | Gold = [] 147 | Parse = [] 148 | for t in wsj10: 149 | if len(t.leaves()) >= 3: 150 | g = t.spannings(leaves=False,root=False)""" 151 | 152 | # Debugging: 153 | """ 154 | Para ir tirando arboles hasta encontrar el que da precision > 1 (UBOUND): 155 | from wsj import * 156 | l = [] 157 | m = 0 158 | for e in get_treebank_iterator(): 159 | e.filter_tags() 160 | n = len(e.leaves()) 161 | if n <= 10: 162 | m = m+1 163 | print str(m)+"-esima frase..." 164 | l = l + [t] 165 | # Cuento los spans que coinciden no trivialmente con rbranch: 166 | s = e.spannings(leaves=False) 167 | s.remove((0,n)) 168 | if len(s) > float(n-2): 169 | break 170 | """ 171 | -------------------------------------------------------------------------------- /nlp_commons/dep/depset.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # depset.py: Dependency set. 6 | 7 | from .. 
import util 8 | 9 | 10 | class DepSet: 11 | 12 | def __init__(self, length, deps): 13 | self.length = length 14 | self.deps = deps 15 | 16 | 17 | def from_depgraph(g): 18 | length = len(g.nodelist)-1 19 | deps = [(n['address']-1, n['head']-1) for n in g.nodelist[1:]] 20 | return DepSet(length, deps) 21 | 22 | 23 | def from_string(s): 24 | """ 25 | >>> d = from_string('[(0,3), (1,0), (2,1), (3,-1)]\n') 26 | """ 27 | """t = s[1:].split() 28 | l = len(t) 29 | deps = [] 30 | for x in t: 31 | y = x[1:-2].split(',') 32 | deps += [(int(y[0]), int(y[1]))]""" 33 | deps = util.safe_eval(s) 34 | l = len(deps) 35 | return DepSet(l, deps) 36 | 37 | 38 | def deptree_to_depset(t): 39 | return DepSet(len(t.leaves()), t.depset) 40 | 41 | 42 | def lhead_depset(length): 43 | deps = [(i, i-1) for i in range(length)] 44 | return DepSet(length, deps) 45 | 46 | 47 | def rhead_depset(length): 48 | deps = [(i, i+1) for i in range(length-1)] + [(length-1, -1)] 49 | return DepSet(length, deps) 50 | 51 | 52 | def _binary_depsets(n): 53 | """Helper for binary_depsets. 54 | """ 55 | if n == 0: 56 | return [[]] 57 | elif n == 1: 58 | return [[(0, -1)]] 59 | else: 60 | result = [] 61 | for i in range(n): 62 | lres = _binary_depsets(i) 63 | rres = _binary_depsets(n-1-i) 64 | lres = map(lambda l: [(j, (k!=-1 and k) or i) for (j,k) in l], lres) 65 | rres = map(lambda l: [(j+i+1, (k!=-1 and (k+i+1)) or i) for (j,k) in l], rres) 66 | #print i, lres, rres 67 | result += [l+[(i, -1)]+r for l in lres for r in rres] 68 | 69 | return result 70 | 71 | 72 | def binary_depsets(n): 73 | """Returns all the binary dependency trees for a sentence of length n. 74 | """ 75 | return map(lambda s: DepSet(n, s), _binary_depsets(n)) 76 | 77 | 78 | def _all_depsets(n): 79 | """Helper for all_depsets. 80 | """ 81 | # Dynamically programmed: 82 | depsets = {1: [[(0, -1)]]} 83 | sums = _all_sums(n) 84 | sums[0] = [[]] 85 | 86 | for i in range(2, n+1): 87 | result = [] 88 | for j in range(0, i): 89 | # j is the root. 90 | 91 | # to the left: 92 | lres = [] 93 | ll = sums[j] 94 | for l in ll: 95 | laux = [[]] 96 | acum = 0 97 | for k in l: 98 | # for instance, j=3, l=[1,2]. 
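                    # (Continuing the example: j is the root position, so the
                    # j tokens to its left are split by one composition l from
                    # _all_sums into consecutive subtrees -- here a 1-token
                    # subtree followed by a 2-token one. Each subtree keeps its
                    # internal dependencies, shifted by acum, and its own root
                    # is re-attached to j.)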
99 | laux2 = [] 100 | for m in depsets[k]: 101 | m2 = [(p+acum, (q!=-1 and (q+acum)) or j) for (p,q) in m] 102 | laux2 += [m2] 103 | laux = [o+m2 for o in laux for m2 in laux2] 104 | acum += k 105 | lres += laux 106 | 107 | # to the right: 108 | rres = [] 109 | ll = sums[i-1-j] 110 | for l in ll: 111 | laux = [[]] 112 | acum = j+1 113 | for k in l: 114 | laux2 = [] 115 | for m in depsets[k]: 116 | m2 = [(p+acum, (q!=-1 and (q+acum)) or j) for (p,q) in m] 117 | laux2 += [m2] 118 | laux = [o+m2 for o in laux for m2 in laux2] 119 | acum += k 120 | rres += laux 121 | 122 | #lres = map(lambda l: [(p, (q!=-1 and q) or j) for (p,q) in l], lres) 123 | #rres = map(lambda l: [(p+j+1, (q!=-1 and (q+j+1)) or j) for (p,q) in l], rres) 124 | 125 | result += [l+[(j, -1)]+r for l in lres for r in rres] 126 | depsets[i] = result 127 | 128 | return depsets 129 | 130 | 131 | """if n == 0: 132 | return [[]] 133 | elif n == 1: 134 | return [[(0, -1)]] 135 | else: 136 | result = [] 137 | for i in range(n): 138 | lres = [[]] 139 | for lsplits in range(0, i): 140 | 141 | lres = _all_depsets(i) 142 | 143 | 144 | rres = _all_depsets(n-1-i) 145 | lres = map(lambda l: [(j, (k!=-1 and k) or i) for (j,k) in l], lres) 146 | rres = map(lambda l: [(j+i+1, (k!=-1 and (k+i+1)) or i) for (j,k) in l], rres) 147 | #print i, lres, rres 148 | result += [l+[(i, -1)]+r for l in lres for r in rres] 149 | 150 | return result""" 151 | 152 | 153 | def all_depsets(n): 154 | """Returns all the dependency sets for a sentence of length n. 155 | """ 156 | return map(lambda s: DepSet(n, s), _all_depsets(n)) 157 | 158 | 159 | def _all_sums(n): 160 | """Helper for all_depsets. 161 | Returns a dictionary with keys 1, ..., n. 162 | The value for each i is a list with all the ways of summing i. 163 | """ 164 | # Dynamically programmed 165 | # sums(1) = 1 166 | # sums(2) = 1+1,2 167 | # sums(3) = 1+1+1,1+2,2+1,3 168 | # sums(4) = 1+1+1+1,1+1+2,1+2+1,2+1+1,2+2,1+3,3+1,4 169 | # = map (1+) (sums(3)), map (2+) (sums(2)), map (3+) (sums(1)), 4 170 | # = 1+1+1+1,1+1+2,1+2+1,1+3, 2+1+1,2+2, 3+1, 4 171 | sums = {1:[[1]]} 172 | for i in range(2, n+1): 173 | l = [] 174 | for j in range(1, i): 175 | l += [[j]+l2 for l2 in sums[i-j]] 176 | l += [[i]] 177 | sums[i] = l 178 | 179 | return sums 180 | -------------------------------------------------------------------------------- /nlp_commons/wsj.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import codecs 6 | import itertools 7 | import os 8 | 9 | from nltk.corpus.reader.util import read_sexpr_block 10 | from nltk.corpus.reader import bracket_parse 11 | from nltk import tree 12 | from nltk import Tree 13 | from nltk.util import LazyMap 14 | 15 | from . import treebank 16 | 17 | word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 18 | 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'] 19 | currency_tags_words = ['#', '$', 'C$', 'A$'] 20 | ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*', '*PPA*', '*NOT*'] 21 | punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``'] 22 | punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-'] 23 | # tag de -- - ; ... es : 24 | # tag de ? ! es . 
25 | # ' no es puntuacion sino POS (pronombre posesivo?) 26 | # '-LCB-', '-RCB-' son los corchetes 27 | 28 | # el puto arbol ['07/wsj_0758.mrg', 74] (antepenultimo) usa comillas simples 29 | 30 | 31 | # funciona tanto si las hojas son lexico como POS tags. 32 | def is_ellipsis(s): 33 | return s == '-NONE-' or s.partition('-')[0] in ellipsis 34 | 35 | 36 | # funciona tanto si las hojas son lexico como POS tags. 37 | def is_punctuation(s): 38 | # solo comparo con punctuation_words porque incluye a punctuation_tags. 39 | return s in punctuation_words 40 | 41 | 42 | class WSJTree(treebank.Tree): 43 | 44 | def is_ellipsis(self, s): 45 | return is_ellipsis(s) 46 | 47 | def is_punctuation(self, s): 48 | return is_punctuation(s) 49 | 50 | 51 | # TODO: Rename this class to WSJ. 52 | class WSJSents(bracket_parse.BracketParseCorpusReader): 53 | def __init__(self): 54 | bracket_parse.BracketParseCorpusReader.__init__(self, 'wsj_comb', '.*') 55 | 56 | def tagged_sents(self): 57 | # Remove punctuation, ellipsis and currency ($, #) at the same time: 58 | f = lambda s: [x for x in s if x[1] in word_tags] 59 | return LazyMap(f, bracket_parse.BracketParseCorpusReader.tagged_sents(self)) 60 | 61 | 62 | # TODO: remove this class and rename WSJSents to WSJ. 63 | class WSJ(treebank.SavedTreebank): 64 | default_basedir = 'wsj_comb' 65 | trees = [] 66 | filename = 'wsj.treebank' 67 | 68 | def __init__(self, basedir=None): 69 | if basedir == None: 70 | self.basedir = self.default_basedir 71 | else: 72 | self.basedir = basedir 73 | #self.reader = BracketParseCorpusReader(self.basedir, self.get_files()) 74 | 75 | def get_files(self): 76 | l = os.listdir(self.basedir) 77 | files = [] 78 | for d in l: 79 | files = files + [d+'/'+s for s in os.listdir(self.basedir+'/'+d)] 80 | return files 81 | 82 | """def parsed(self, files=None): 83 | if files is None: 84 | files = self.get_files() 85 | for (i, t) in itertools.izip(itertools.count(), treebank.SavedTreebank.parsed(self, files)): 86 | yield WSJTree(t, labels=i)""" 87 | 88 | def parsed(self, files=None): 89 | """ 90 | @param files: One or more WSJ treebank files to be processed 91 | @type files: L{string} or L{tuple(string)} 92 | @rtype: iterator over L{tree} 93 | """ 94 | if files is None or files == []: 95 | files = self.get_files() 96 | 97 | # Just one file to process? 
If so convert to a tuple so we can iterate 98 | if isinstance(files, str): 99 | files = (files,) 100 | 101 | size = 0 102 | for file in files: 103 | path = os.path.join(self.basedir, file) 104 | f = codecs.open(path, encoding='utf-8') 105 | i = 0 106 | t = read_parsed_tb_block(f) 107 | #print "Parsing", len(t), "trees from file", file 108 | # print "Parsing file", file 109 | while t != []: 110 | size += 1 111 | #yield treebank.Tree(t[0], [file, i]) 112 | yield WSJTree(t[0], [file, i]) 113 | i = i+1 114 | t = t[1:] 115 | if t == []: 116 | t = read_parsed_tb_block(f) 117 | print("Finished processing", size, "trees") 118 | 119 | def get_tree(self, offset=0): 120 | t = self.get_trees2(offset, offset+1)[0] 121 | return t 122 | 123 | # Devuelve los arboles que se encuentran en la posicion i con start <= i < end 124 | def get_trees2(self, start=0, end=None): 125 | lt = [t for t in itertools.islice(self.parsed(), start, end)] 126 | return lt 127 | 128 | def is_ellipsis(self, s): 129 | return is_ellipsis(s) 130 | 131 | def is_punctuation(self, s): 132 | return is_punctuation(s) 133 | 134 | 135 | def test(): 136 | tb = WSJ() 137 | trees = tb.get_trees() 138 | return tb 139 | 140 | 141 | # ROBADO DE nltk 0.8, nltk/corpus/treebank.py, despues eliminado de nltk. 142 | 143 | def treebank_bracket_parse(t): 144 | try: 145 | return Tree.fromstring(t, remove_empty_top_bracketing=True) 146 | except IndexError: 147 | # in case it's the real treebank format, 148 | # strip first and last brackets before parsing 149 | return tree.bracket_parse(t.strip()[1:-1]) 150 | 151 | def read_parsed_tb_block(stream): 152 | return [treebank_bracket_parse(t) for t in read_sexpr_block(stream)] 153 | -------------------------------------------------------------------------------- /nlp_commons/negra10.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import itertools 6 | 7 | from nltk.util import LazyMap 8 | 9 | from . import negra 10 | 11 | 12 | class Negran(negra.Negra): 13 | 14 | def __init__(self, n, basedir=None, load=True): 15 | negra.Negra.__init__(self, basedir) 16 | self.n = n 17 | self.filename = 'negra%02i.treebank' % n 18 | if load: 19 | self.get_trees() 20 | 21 | def _generate_trees(self): 22 | print("Parsing treebank...") 23 | # algunas frases quedan de largo 0 porque son solo '.' 24 | g = lambda l: (l <= self.n) and (l > 0) 25 | #f = lambda t: (len(t.leaves()) <= self.n) and (len(t.leaves()) > 0) 26 | f = lambda t: g(len(t.leaves())) 27 | m = lambda t: self._prepare(t) 28 | trees = [t for t in filter(f, map(m, self.parsed()))] 29 | return trees 30 | 31 | def _prepare(self, t): 32 | t.remove_leaves() 33 | t.remove_ellipsis() 34 | t.remove_punctuation() 35 | return t 36 | 37 | def tagged_sents(self): 38 | # LazyMap from nltk.util: 39 | f = lambda t: [(x,x) for x in t.leaves()] 40 | return LazyMap(f, self.get_trees()) 41 | 42 | # XXX: este simplify tags deja un tag '' en el corpus. leer implementacion. 43 | def simplify_tags(self): 44 | # XXX: esto no funciona cuando el '-' no esta, leer docs: 45 | #f = lambda s: s.rpartition('-')[0] 46 | 47 | # partition with firt or last '-'? 48 | # sent 1461 (4391 in the whole corpus) 49 | # has tag '--': with 1st: '', with 2nd: '-'. 50 | # this is the only sentence that has a tag with two '-'s. 
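        # (e.g. a leaf tag like 'NN-SB' becomes 'NN', while the tag '--'
        # becomes the empty string '' noted above -- hence the manual fix of
        # tree 1461 a few lines below.)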
51 | f = lambda s: s.partition('-')[0] 52 | list(map(lambda t: t.map_leaves(f), self.trees)) 53 | 54 | # manually fix tree 1461: 55 | self.trees[1461][1] = '-' 56 | 57 | 58 | class Negra10(Negran): 59 | 60 | def __init__(self, basedir=None, load=True): 61 | Negran.__init__(self, 10, basedir, load) 62 | 63 | """ 64 | Punctuation in NEGRA has some problems: 65 | 1. Opening and closing quotes are not distinguished (they are always "). 66 | 2. Closing parenthesis are always tagged as opening ('($*LRB* *RRB*)'). 67 | 3. Quotes are always tagged $*LRB*, getting confused with parenthesis. 68 | 4. Single quotes are always tagged $*LRB* ($*LRB*-PNC in two cases). They aren't even real punctuation but possesives in the most cases. 69 | 5. What about dashes? slash (/)? elipsis (...)? 70 | 5a. Slashes (/) are tagged $*LRB*. They usually are punctuation. 71 | 5b. Dashes (-) are tagged $*LRB*. They usually are punctuation. Sometimes bracket punctuation. Some dashes are tagged starting with *. Those are not punctuation but empty elements. 72 | 5c. Ellipsis (...) are tagged $*LRB*. They usually are punctuation. 73 | 6. In Penn format some punctuation has been crossed. What originally was 74 | " La mujer de Benjamen " ( Benjamins Frau , 1990 ) 75 | became 76 | " La mujer de Benjamen ( " Benjamins Frau , 1990 ) 77 | in Penn format (sentence 120). 78 | 79 | In this class we fix problems 2, 3, 4 and 5: 80 | a. To be sure that they are different, parenthesis are tagged $( and )$. 81 | b. Quotes are tagged with $", slashes $/, dashes $d (- is used to separate tag from function), ellipsis '$...'. 82 | c. Single quotes are removed by being tagged as ellipsis '*' (not '...' but empty elements). 83 | """ 84 | class Negra10P(Negra10): 85 | # sadly I need this list (redundant beacuse we have is_punctuation) to allow 86 | # usage by other classes that want to pickle this information: 87 | # XXX: only sure it works for subclass Negra10: 88 | punctuation_tags = ['$.', '$/', '$,', '$d', '$.-NMC', '$(', '$)', '$.-CD', '$"', '$...'] 89 | # this was found this way: 90 | #from negra10 import * 91 | #tb = Negra10P() 92 | #punct = set(sum(([x for x in t.leaves() if x[0] == '$'] for t in tb.trees), [])) 93 | 94 | stop_punctuation_tags = ['$.', '$/', '$,', '$d', '$.-NMC', '$.-CD', '$...'] 95 | bracket_punctuation_tag_pairs = [('$(', '$)'), ('$"',)] 96 | 97 | 98 | def __init__(self, basedir=None, load=True): 99 | Negra10.__init__(self, basedir, load=False) 100 | self.filename = 'negra%02ip.treebank' % self.n 101 | if load: 102 | self.get_trees() 103 | 104 | 105 | def _generate_trees(self): 106 | print("Parsing treebank...") 107 | # algunas frases quedan de largo 0 porque son solo '.' 108 | g = lambda l: (l <= self.n) and (l > 0) 109 | #f = lambda t: g(len(filter(lambda x: x not in self.punctuation_tags, t.leaves()))) 110 | f = lambda t: g(len([x for x in t.leaves() if not negra.is_punctuation(x)])) 111 | m = lambda t: self._prepare(t) 112 | trees = [t for t in filter(f, map(m, self.parsed()))] 113 | return trees 114 | 115 | 116 | def _prepare(self, t): 117 | # before removing leaves, punctuation tags must be fixed: 118 | for p in t.treepositions('leaves'): 119 | l = t[p] 120 | tag = t[p[:-1]] 121 | if l == '*RRB*': 122 | tag.node = '$)' 123 | elif l == '*LRB*': 124 | tag.node = '$(' 125 | elif l == '"': 126 | tag.node = '$"' 127 | elif l == '/': 128 | tag.node = '$/' 129 | elif l == '-' and tag.node[0] == '$': 130 | # some dashes are not punctutation but empty elements 131 | # (tag starting with *). 
132 | tag.node = '$d'
133 | elif l == '...':
134 | tag.node = '$...'
135 | elif l == '\'':
136 | tag.node = '*'
137 | 
138 | t.remove_leaves()
139 | t.remove_ellipsis()
140 | #t.remove_punctuation()
141 | return t
142 | 
143 | 
144 | def test():
145 | tb = Negra10()
146 | tb.simplify_tags()
147 | return tb
148 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # struct-learning-with-flow
2 | 
3 | This is a PyTorch implementation of the [paper](https://arxiv.org/abs/1808.09111):
4 | ```
5 | Unsupervised Learning of Syntactic Structure with Invertible Neural Projections
6 | Junxian He, Graham Neubig, Taylor Berg-Kirkpatrick
7 | EMNLP 2018
8 | ```
9 | 
10 | The code performs unsupervised structure learning on language, specifically learning Markov structure and dependency structure.
11 | 
12 | Please contact junxianh@cs.cmu.edu if you have any questions.
13 | 
14 | ## Requirements
15 | 
16 | - Python 3
17 | - PyTorch >=0.4
18 | - [scikit-learn](http://scikit-learn.org/stable/) (for tagging task only)
19 | - [NLTK](https://www.nltk.org/) (for parsing task only)
20 | 
21 | ## Data
22 | We provide the pre-trained word vector file we used in the paper and a small subset of Penn Treebank data for testing the tagging code. This dataset contains a 10% sample of the Penn Treebank and is publicly available in the [NLTK corpus](http://www.nltk.org/howto/corpus.html). The full Penn Treebank dataset requires an LDC license.
23 | 
24 | To download the sample data, run:
25 | ```shell
26 | python prepare_data.py
27 | ```
28 | The downloaded data is located in `sample_data`.
29 | 
30 | Both tasks use a simplified CoNLL format as data input, with four columns:
31 | ```
32 | ID Token Tag Head
33 | ```
34 | At training time only `Token` is used; `Head` is the dependency head index (used to evaluate the parsing task), and `Tag` is used to evaluate the tagging task. Pre-trained word vectors are required as observations of our generative model. The input word2vec map should be a pickled Python dict that maps tokens to vectors.
35 | 
36 | We also provide a script to preprocess the full Penn Treebank dataset for parsing (e.g. converting parse trees, removing punctuation, etc.). The `wsj` directory should look like:
37 | ```
38 | wsj
39 | +-- 00
40 | |   +-- wsj_0001.mrg
41 | |   +-- ...
42 | +-- 01
43 | +-- ...
44 | 
45 | ```
46 | Then run:
47 | ```shell
48 | python preprocess_ptb.py --ptbdir /path/to/wsj
49 | ```
50 | This command generates train/test files in `ptb_parse_data`. Note that the generated data files contain gold POS tags in the `Tag` column; they are therefore not the files we used in the paper, where the tags are induced from the Markov model.
51 | 
52 | **TODO**: Simplify the pipeline to generate train/test files without gold POS tags, in order to reproduce the parsing results.
53 | 
54 | ## Markov Structure for Tagging
55 | 
56 | ### Training
57 | 
58 | Train a Gaussian HMM baseline:
59 | 
60 | ```shell
61 | python markov_flow_train.py --model gaussian --train_file /path/to/train --word_vec /path/to/word_vec_file
62 | ```
63 | 
64 | By default we evaluate on the training data (this is not cheating in the unsupervised setting); a different test set can be specified with the `--test_file` option. Training uses the GPU when one is available and the CPU otherwise, but running on the CPU can be extremely slow. Full configuration options can be found in `markov_flow_train.py`.
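For example, a Gaussian run that overrides a few of those defaults looks like this (every flag below is defined in `markov_flow_train.py`; the values shown simply repeat the defaults and are only illustrative):

```shell
python markov_flow_train.py \
    --model gaussian \
    --train_file /path/to/train \
    --word_vec /path/to/word_vec_file \
    --num_state 45 \
    --epochs 50 \
    --lr 0.001 \
    --set_seed --seed 5783287
```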
After training, the model is saved in `dump_models/markov/`.
65 | 
66 | Unsupervised learning is usually very sensitive to initialization; for this task we run multiple random restarts and pick the one with the highest training-data likelihood, as described in the paper. Running 10 random restarts is generally sufficient. When running multiple random restarts, specify the `--jobid` or `--taskid` options to avoid overwriting saved models.
67 | 
68 | After training the Gaussian HMM, train a projection model with a Markov prior:
69 | 
70 | ```shell
71 | python markov_flow_train.py \
72 |     --model nice \
73 |     --lr 0.01 \
74 |     --train_file /path/to/train \
75 |     --word_vec /path/to/word_vec_file \
76 |     --load_gaussian /path/to/gaussian_model
77 | ```
78 | 
79 | Initializing the prior with the pre-trained Gaussian baseline makes training much more stable. By default, 4 coupling layers are used in the NICE projection.
80 | 
81 | ### Results
82 | 
83 | On the provided subset of the Penn Treebank, which contains 3914 sentences, the Gaussian HMM achieves ~76.5% M1 accuracy and ~0.692 VM score, and the projection model (4 layers) achieves ~79.2% M1 accuracy and ~0.718 VM score.
84 | 
85 | ### Prediction
86 | 
87 | After training, prediction can be performed with:
88 | 
89 | ```shell
90 | python markov_flow_train.py --model nice --train_file /path/to/tag_file --tag_from /path/to/pretrained_model
91 | ```
92 | 
93 | Here `--train_file` is the file to be tagged; the output file is written to the current directory.
94 | 
95 | 
96 | 
97 | 
98 | ## DMV Structure for Parsing
99 | ### Training
100 | 
101 | First train a vanilla DMV model with Viterbi EM (this runs on CPU only):
102 | 
103 | ```shell
104 | python dmv_viterbi_train.py --train_file /path/to/train_data --test_file /path/to/test_data
105 | ```
106 | 
107 | The trained model is saved in `dump_models/dmv/viterbi_dmv.pickle`. The implementation of this basic DMV training is partially based on [this repo](https://github.com/davidswelt/dmvccm).
108 | 
109 | 
110 | 
111 | Then use the pre-trained DMV to initialize the syntax model in the flow/Gaussian model:
112 | 
113 | ```shell
114 | python dmv_flow_train.py \
115 |     --model nice \
116 |     --train_file /path/to/train_data \
117 |     --test_file /path/to/test_data \
118 |     --word_vec /path/to/word_vec_file \
119 |     --load_viterbi_dmv dump_models/dmv/viterbi_dmv.pickle
120 | ```
121 | 
122 | The script trains a Gaussian baseline when `--model` is set to `gaussian`. Training uses the GPU when one is available and the CPU otherwise. The trained model is saved in `dump_models/dmv/`.
123 | 
124 | ## Acknowledgement
125 | The awesome `nlp_commons` package (used to preprocess the Penn Treebank) in this repo was originally developed by Franco M. Luque and can be found in this [repo](https://github.com/davidswelt/dmvccm).
126 | 
127 | 
128 | ## Reference
129 | ```
130 | @inproceedings{he2018unsupervised,
131 |   title = {Unsupervised Learning of Syntactic Structure with Invertible Neural Projections},
132 |   author = {Junxian He and Graham Neubig and Taylor Berg-Kirkpatrick},
133 |   booktitle = {Proceedings of EMNLP},
134 |   year = {2018}
135 | }
136 | ```
137 | 
138 | 
--------------------------------------------------------------------------------
/nlp_commons/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2007-2011 Franco M.
Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # util.py: Some utilities, mainly for serialization (pickling) of objects. 6 | 7 | import os 8 | import pickle 9 | import sys 10 | 11 | import nltk 12 | 13 | 14 | obj_basedir = 'lq-nlp-commons' 15 | 16 | 17 | def write_file(filename, content): 18 | f = open(filename, 'w') 19 | f.write(content) 20 | f.close() 21 | 22 | 23 | def read_file(filename): 24 | f = open(filename) 25 | content = f.read() 26 | f.close() 27 | return content 28 | 29 | 30 | # XXX: trabaja con listas aunque podria hacerse con set. 31 | def powerset(s): 32 | if len(s) == 0: 33 | return [[]] 34 | else: 35 | e = s[0] 36 | p = powerset(s[1:]) 37 | return p + [x+[e] for x in p] 38 | 39 | 40 | # me fijo si un bracketing no tiene cosas que se cruzan 41 | def tree_consistent(b): 42 | """FIXME: move this to the bracketing package. 43 | """ 44 | def crosses(xxx_todo_changeme, xxx_todo_changeme1): 45 | (a,b) = xxx_todo_changeme 46 | (c,d) = xxx_todo_changeme1 47 | return (a < c and c < b and b < d) or (c < a and a < d and d < b) 48 | 49 | for i in range(len(b)): 50 | for j in range(i+1,len(b)): 51 | if crosses(b[i], b[j]): 52 | return False 53 | return True 54 | 55 | 56 | def get_obj_basedir(): 57 | try: 58 | return nltk.data.find(obj_basedir) 59 | except LookupError: 60 | os.mkdir(os.path.join(nltk.data.path[0], obj_basedir)) 61 | return nltk.data.find(obj_basedir) 62 | 63 | 64 | # Guarda un objeto en un archivo, para luego ser cargado con load_obj. 65 | def save_obj(object, filename): 66 | path = os.path.join(get_obj_basedir(), filename) 67 | f = open(path, 'w') 68 | pickle.dump(object, f, pickle.HIGHEST_PROTOCOL) 69 | f.close() 70 | 71 | 72 | # Carga un objeto guardado en un archivo con save_obj. 73 | def load_obj(filename): 74 | path = os.path.join(get_obj_basedir(), filename) 75 | try: 76 | f = open(path, 'r') 77 | object = pickle.load(f) 78 | f.close() 79 | except IOError: 80 | object = None 81 | return object 82 | 83 | 84 | # Carga una lista de objetos guardados en un archivo usando ObjectSaver. 85 | def load_objs(filename): 86 | path = os.path.join(get_obj_basedir(), filename) 87 | try: 88 | f = open(path, 'r') 89 | objects = [] 90 | try: 91 | while True: 92 | objects += [pickle.load(f)] 93 | except EOFError: # It will always be thrown 94 | f.close() 95 | except IOError: 96 | objects = None 97 | return objects 98 | 99 | 100 | class ObjectSaver: 101 | 102 | # Si el archivo existe, lo abre, lo lee y comienza a escribir al final 103 | def __init__(self, filename): 104 | path = os.path.join(get_obj_basedir(), filename) 105 | self.f = open(path, 'a+') 106 | self.orig_objs = [] 107 | try: 108 | while True: 109 | self.orig_objs += [pickle.load(self.f)] 110 | except EOFError: # It will always be thrown 111 | pass 112 | 113 | def save_obj(self, object): 114 | pickle.dump(object, self.f, pickle.HIGHEST_PROTOCOL) 115 | 116 | def flush(self): 117 | self.f.flush() 118 | 119 | def close(self): 120 | self.f.close() 121 | 122 | 123 | class Progress: 124 | """ 125 | Helper class to ouput to stdout a fancy indicator of the progress of something. 126 | See model.Model for an example of usage. 
127 | 128 | >>> p = Progress('Parsed', 0, 200) 129 | Parsed 0 of 200 130 | >>> p.next() 131 | 1 of 200 132 | >>> p.next() 133 | 2 of 200 134 | """ 135 | 136 | def __init__(self, prefix, n_init, n_max): 137 | m = len(str(n_max)) 138 | o = "%"+str(m)+"d of "+str(n_max) 139 | self.i = 0 140 | print(prefix, o % self.i, end=' ') 141 | sys.stdout.flush() 142 | self.o = ("\b"*(2*m+5)) + o 143 | 144 | def __next__(self): 145 | self.i += 1 146 | print(self.o % self.i, end=' ') 147 | sys.stdout.flush() 148 | 149 | 150 | # Recipe 364469: "Safe" Eval (Python) by Michael Spencer 151 | # ActiveState Code (http://code.activestate.com/recipes/364469/) 152 | 153 | 154 | # import compiler 155 | 156 | # class Unsafe_Source_Error(Exception): 157 | # def __init__(self,error,descr = None,node = None): 158 | # self.error = error 159 | # self.descr = descr 160 | # self.node = node 161 | # self.lineno = getattr(node,"lineno",None) 162 | 163 | # def __repr__(self): 164 | # return "Line %d. %s: %s" % (self.lineno, self.error, self.descr) 165 | # __str__ = __repr__ 166 | 167 | # class SafeEval(object): 168 | 169 | # def visit(self, node,**kw): 170 | # cls = node.__class__ 171 | # meth = getattr(self,'visit'+cls.__name__,self.default) 172 | # return meth(node, **kw) 173 | 174 | # def default(self, node, **kw): 175 | # for child in node.getChildNodes(): 176 | # return self.visit(child, **kw) 177 | 178 | # visitExpression = default 179 | 180 | # def visitConst(self, node, **kw): 181 | # return node.value 182 | 183 | # def visitDict(self,node,**kw): 184 | # return dict([(self.visit(k),self.visit(v)) for k,v in node.items]) 185 | 186 | # def visitTuple(self,node, **kw): 187 | # return tuple(self.visit(i) for i in node.nodes) 188 | 189 | # def visitList(self,node, **kw): 190 | # return [self.visit(i) for i in node.nodes] 191 | 192 | # def visitUnarySub(self, node, **kw): 193 | # return -self.visit(node.getChildNodes()[0]) 194 | 195 | 196 | # class SafeEvalWithErrors(SafeEval): 197 | 198 | # def default(self, node, **kw): 199 | # raise Unsafe_Source_Error("Unsupported source construct", 200 | # node.__class__,node) 201 | 202 | # def visitName(self,node, **kw): 203 | # raise Unsafe_Source_Error("Strings must be quoted", 204 | # node.name, node) 205 | 206 | # # Add more specific errors if desired 207 | 208 | 209 | # def safe_eval(source, fail_on_error = True): 210 | # walker = fail_on_error and SafeEvalWithErrors() or SafeEval() 211 | # try: 212 | # ast = compiler.parse(source,"eval") 213 | # except SyntaxError as err: 214 | # raise 215 | # try: 216 | # return walker.visit(ast) 217 | # except Unsafe_Source_Error as err: 218 | # raise 219 | -------------------------------------------------------------------------------- /nlp_commons/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import pydot 6 | 7 | 8 | class Graph: 9 | 10 | def __init__(self, nodes, edges): 11 | """For instance Graph([1,2,3], [(1,2),(2,3),(3,1)]). 12 | """ 13 | self.nodes = nodes 14 | self.edges = edges 15 | 16 | def get_node(self): 17 | """Return any node of the graph. 18 | """ 19 | return self.nodes[0] 20 | 21 | def nodes_list(self): 22 | return self.nodes 23 | 24 | def edge(self, n1, n2): 25 | return (n1, n2) in self.edges or (n2, n1) in self.edges 26 | 27 | def neighbors(self, n): 28 | """Return a list with the neighbors of n. 
29 | """ 30 | return [p for (p,q) in self.edges if q==n] \ 31 | + [q for (p,q) in self.edges if p==n] 32 | 33 | def remove_node(self, n): 34 | if n in self.nodes: 35 | self.nodes.remove(n) 36 | for (p, q) in self.edges: 37 | if n in [p,q]: 38 | self.edges.remove((p,q)) 39 | 40 | def has_cicle(self): 41 | """assert 'self must be connected' 42 | Returns True if there is a cicle, False otherwise. 43 | """ 44 | #if len(self.edges) >= len(self.nodes): 45 | # return True 46 | n = self.get_node() 47 | stack = [n] 48 | visited = set() 49 | while stack != []: 50 | n = stack.pop() 51 | if n in visited: 52 | return True 53 | else: 54 | visited.add(n) 55 | stack += [m for m in self.neighbors(n) if m not in visited] 56 | 57 | return False 58 | 59 | def prune_subtrees(self, iterations=-1): 60 | """Iteratively remove the nodes of degree 1. iterations == -1 means to 61 | do this until there aren't nodes to remove. 62 | """ 63 | nodes1 = [n for n in self.nodes_list() if len(self.neighbors(n)) == 1] 64 | i = 0 65 | while nodes1 != [] and i != iterations: 66 | for n in nodes1: 67 | self.remove_node(n) 68 | nodes1 = [n for n in self.nodes_list() if len(self.neighbors(n)) == 1] 69 | i += 1 70 | 71 | def compute_connected_components(self): 72 | """Compute the connected components of the graph in the variable self.ccs. 73 | """ 74 | nodes = self.nodes_list() 75 | if not nodes: 76 | self.ccs = [] 77 | return 0 78 | #nodes = self.nodes 79 | ccs = {0: [nodes[0]]} 80 | nextcc = 1 81 | for i in range(1, len(nodes)): 82 | node = nodes[i] 83 | l = [] 84 | for j in ccs: 85 | if any(self.edge(node, m) for m in ccs[j]): 86 | l += [j] 87 | if l == []: 88 | # node is a new cc. 89 | ccs[nextcc] = [node] 90 | nextcc += 1 91 | else: 92 | # the node joins all the ccs in l. 93 | cc = ccs[l[0]] 94 | for j in l[1:]: 95 | cc += ccs[j] 96 | del ccs[j] 97 | cc += [node] 98 | 99 | self.ccs = list(ccs.values()) 100 | #return len(self.ccs) 101 | 102 | def is_independent_set(self, nodes, show=False): 103 | """Checks if a list of nodes is an independent set. If show=True and the 104 | result is False, prints the first counter-example found. 105 | """ 106 | #return all(not self.edge(m, n) for m in nodes for n in nodes) 107 | for m in nodes: 108 | for n in nodes: 109 | if self.edge(m, n): 110 | if show: 111 | print("Edge found:", (m, n)) 112 | return False 113 | return True 114 | 115 | def get_dot_graph(self, nodes=None): 116 | if nodes == None: 117 | nodes = self.nodes 118 | g = pydot.Dot() 119 | g.set_type('graph') 120 | for i in range(1, len(nodes)): 121 | node = nodes[i] 122 | for j in range(i): 123 | # maybe: 124 | #if self.edge(str(node), str(nodes[j])): 125 | if self.edge(node, nodes[j]): 126 | e = pydot.Edge(node, nodes[j]) 127 | g.add_edge(e) 128 | return g 129 | 130 | def draw_graph(self, filename, nodes=None): 131 | """Draw the graph in a JPG file. 132 | """ 133 | g = self.get_dot_graph(nodes) 134 | g.write_jpeg(filename, prog='dot') 135 | 136 | 137 | class WGraph(Graph): 138 | """Graph with weighted edges. 139 | """ 140 | 141 | def edge_weight(self, n1, n2): 142 | return 0 143 | 144 | def compute_connected_components(self, w_min=1): 145 | """Compute the connected components of the graph in the variable self.ccs. 146 | w_min is the minimal weight for the edges to consider. 
147 | """ 148 | nodes = self.nodes_list() 149 | if not nodes: 150 | self.ccs = [] 151 | return 0 152 | #nodes = self.nodes 153 | ccs = {0: [nodes[0]]} 154 | nextcc = 1 155 | for i in range(1, len(nodes)): 156 | node = nodes[i] 157 | l = [] 158 | for j in ccs: 159 | if any(self.edge_weight(node, m) >= w_min for m in ccs[j]): 160 | l += [j] 161 | if l == []: 162 | # node is a new cc. 163 | ccs[nextcc] = [node] 164 | nextcc += 1 165 | else: 166 | # the node joins all the ccs in l. 167 | cc = ccs[l[0]] 168 | for j in l[1:]: 169 | cc += ccs[j] 170 | del ccs[j] 171 | cc += [node] 172 | 173 | self.ccs = list(ccs.values()) 174 | #return len(self.ccs) 175 | 176 | def get_dot_graph(self, nodes=None, w_min=1): 177 | if nodes == None: 178 | nodes = self.nodes_list() 179 | g = pydot.Dot() 180 | g.set_type('graph') 181 | for i in range(1, len(nodes)): 182 | node = nodes[i] 183 | for j in range(i): 184 | # maybe: 185 | #if self.edge(str(node), str(nodes[j])): 186 | if self.edge_weight(node, nodes[j]) >= w_min: 187 | e = pydot.Edge(node, nodes[j]) 188 | g.add_edge(e) 189 | return g 190 | 191 | def draw_graph(self, filename, nodes=None, w_min=1): 192 | """Draw the graph in a JPG file. 193 | """ 194 | g = self.get_dot_graph(nodes, w_min=w_min) 195 | g.write_jpeg(filename, prog='dot') 196 | -------------------------------------------------------------------------------- /nlp_commons/cast3lb10.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import itertools 6 | 7 | from . import cast3lb 8 | 9 | class Cast3LBn(cast3lb.Cast3LB): 10 | 11 | 12 | def __init__(self, n, basedir=None, load=True): 13 | cast3lb.Cast3LB.__init__(self, basedir) 14 | self.n = n 15 | self.filename = 'cast3lb%02i.treebank' % n 16 | if load: 17 | self.get_trees() 18 | 19 | 20 | def _generate_trees(self): 21 | print("Parsing Cast3LB treebank...") 22 | f = lambda t: len(t.leaves()) <= self.n 23 | m = lambda t: self._prepare(t) 24 | trees = [t for t in filter(f, map(m, self.parsed()))] 25 | return trees 26 | 27 | 28 | def _prepare(self, t): 29 | t.remove_leaves() 30 | t.remove_ellipsis() 31 | t.remove_punctuation() 32 | return t 33 | 34 | 35 | def simplify_tags(self): 36 | list(map(lambda t: t.map_leaves(self.tag_filter), self.trees)) 37 | 38 | 39 | def tag_filter(self, t): 40 | t2 = t.lower() 41 | if t == '': 42 | print("Empty tag!", t) 43 | return t 44 | elif t2[0] == 'a': 45 | # Adjetivo: Dejo tipo (calificativo 'q' u ordinal 'o') y 46 | # numero (singular 's', plural 'p' o invariable 'n'). 47 | return t2[0:2]+t2[4] 48 | elif t2[0] == 'r': 49 | # Adverbio: solo hay 'rg' y 'rn' para la palabra 'no'. Lo dejo asi. 50 | return t2 51 | elif t2[0] == 'd': 52 | # Determinante: Dejo tipo y numero. 53 | return t2[0:2]+t2[4] 54 | elif t2[0] == 'n': 55 | # Nombre: Dejo tipo y numero. 56 | return t2[0:2]+t2[3] 57 | elif t2[0] == 'v': 58 | # Verbo: dejo tipo, modo y numero. 59 | return t2[0:3]+t2[5] 60 | elif t2 in ['i', 'y', 'zm', 'zp', 'z', 'w', 'cc', 'cs', 'x']: 61 | # [interjeccion, abreviatura, moneda, porcentaje, numero, fecha u hora, 62 | # conjuncion coordinada, conjuncion subordinada, elemento deconocido]. 63 | return t2 64 | elif t2[0:2] == 'sp': 65 | # Adposicion de tipo preposicion: dejo forma. 66 | return t2[0:3] 67 | elif t2[0] == 'p': 68 | # Pronombre: dejo tipo y numero. 69 | return t2[0:2]+t2[4] 70 | elif t2[0] == 'f': 71 | # Puntuacion. La devolvemos sin pasar a lowercase. 
72 | return t 73 | else: 74 | print("Unrecognized tag:", t) 75 | return t 76 | # Quedan colgados los tags: sn, sn.e.1, sn.co 77 | 78 | 79 | def simplify_tags_more(self): 80 | list(map(lambda t: t.map_leaves(self.tag_filter_more), self.trees)) 81 | 82 | 83 | def tag_filter_more(self, t): 84 | t2 = t.lower() 85 | if t == '': 86 | print("Empty tag!", t) 87 | return t 88 | elif t2[0] == 'a': 89 | # Adjetivo: Dejo numero (singular 's', plural 'p' o invariable 'n'). 90 | return t2[0]+t2[4] 91 | elif t2[0] == 'r': 92 | # Adverbio: solo hay 'rg' y 'rn' para la palabra 'no'. Lo dejo asi. 93 | return t2 94 | # Unifico todos: 95 | #return t[0] 96 | elif t2[0] == 'd': 97 | # Determinante: Dejo numero. 98 | return t2[0]+t2[4] 99 | elif t2[0] == 'n': 100 | # Nombre: Dejo tipo (comun o propio) y numero. 101 | return t2[0:2]+t2[3] 102 | elif t2[0] == 'v': 103 | # Verbo: dejo modo y numero. 104 | return t2[0]+t2[2]+t2[5] 105 | elif t2 in ['i', 'y', 'zm', 'zp', 'z', 'w', 'cc', 'cs', 'x']: 106 | # [interjeccion, abreviatura, moneda, porcentaje, numero, fecha u hora, 107 | # conjuncion coordinada, conjuncion subordinada, elemento deconocido]. 108 | return t2 109 | elif t2[0:2] == 'sp': 110 | # Adposicion de tipo preposicion: no dejo nada. 111 | return t2[0:2] 112 | elif t2[0] == 'p': 113 | # Pronombre: dejo numero. 114 | return t2[0]+t2[4] 115 | elif t2[0] == 'f': 116 | # Puntuacion. La devolvemos sin pasar a lowercase. 117 | return t 118 | else: 119 | print("Unrecognized tag:", t) 120 | return t 121 | # Quedan colgados los tags: sn, sn.e.1, sn.co 122 | 123 | 124 | class Cast3LB10(Cast3LBn): 125 | 126 | 127 | def __init__(self, basedir=None, load=True): 128 | Cast3LBn.__init__(self, 10, basedir, load) 129 | 130 | 131 | class Cast3LB30(Cast3LBn): 132 | 133 | 134 | def __init__(self, basedir=None, load=True): 135 | Cast3LBn.__init__(self, 30, basedir, load) 136 | 137 | 138 | class Cast3LBPn(Cast3LBn): 139 | # sadly I need this list (redundant beacuse we have is_punctuation) to allow 140 | # usage by other classes that want to pickle this information: 141 | punctuation_tags = ['Fp', 'Fs', 'Fpa', 'Fia', 'Fit', 'Fx', 'Fz', 'Fat', 'Fpt', 'Fc', 'Fd', 'Fe', 'Fg', 'Faa'] 142 | # this was found this way: 143 | #from cast3lb10 import * 144 | #tb = Cast3LB10P() 145 | #punct = set(sum(([x for x in t.leaves() if tb.is_punctuation(x)] for t in tb.trees), [])) 146 | 147 | stop_punctuation_tags = ['Fp', 'Fs', 'Fx', 'Fz', 'Fc', 'Fd', 'Fe', 'Fg'] 148 | bracket_punctuation_tag_pairs = [('Fpa', 'Fpt'), ('Fia', 'Fit'), ('Faa', 'Fat')] 149 | # these are: parenthesis, question marks, exclamation marks 150 | # quotes appear all with the Fe tag. 
151 | # other not present in Cast3LB10P: Fc*: [ ], Fr*: << >>, Fl*: { } 152 | 153 | 154 | def __init__(self, n, basedir=None, load=True): 155 | #Cast3LB10n.__init__(self, n, load=False) 156 | if basedir == None: 157 | self.basedir = self.default_basedir 158 | else: 159 | self.basedir = basedir 160 | 161 | self.n = n 162 | self.filename = 'cast3lb%02ip.treebank' % n 163 | if load: 164 | self.get_trees() 165 | 166 | 167 | def _generate_trees(self): 168 | print("Parsing Cast3LB treebank...") 169 | f = lambda t: len([x for x in t.leaves() if not cast3lb.is_punctuation(x)]) <= self.n 170 | m = lambda t: self._prepare(t) 171 | trees = [t for t in filter(f, map(m, self.parsed()))] 172 | return trees 173 | 174 | 175 | def _prepare(self, t): 176 | t.remove_leaves() 177 | t.remove_ellipsis() 178 | #t.remove_punctuation() 179 | return t 180 | 181 | 182 | # XXX: For consistency This class should be called Cast3LBP10: 183 | class Cast3LB10P(Cast3LBPn): 184 | 185 | 186 | def __init__(self, basedir=None, load=True): 187 | Cast3LBPn.__init__(self, 10, basedir, load) 188 | 189 | 190 | class Cast3LBP30(Cast3LBPn): 191 | 192 | 193 | def __init__(self, basedir=None, load=True): 194 | Cast3LBPn.__init__(self, 30, basedir, load) 195 | 196 | 197 | def test(): 198 | tb = Cast3LB10() 199 | tb.simplify_tags() 200 | return tb 201 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | import math 4 | import numpy as np 5 | from math import log 6 | 7 | class ConllSent(object): 8 | """docstring for ConllSent""" 9 | def __init__(self, key_list=["word", "tag", "head"]): 10 | super(ConllSent, self).__init__() 11 | self.sent_dict = {} 12 | self.keys = key_list 13 | for key in key_list: 14 | self.sent_dict[key] = [] 15 | 16 | def __getitem__(self, key): 17 | return self.sent_dict[key] 18 | 19 | def __setitem__(self, key, item): 20 | self.sent_dict[key] = item 21 | 22 | def __len__(self): 23 | return len(self.sent_dict["word"]) 24 | 25 | def is_number(s): 26 | try: 27 | float(s) 28 | return True 29 | except ValueError: 30 | return False 31 | 32 | def cast_to_int(s): 33 | try: 34 | return int(s) 35 | except ValueError: 36 | return s 37 | 38 | def word2id(sentences): 39 | """map words to word ids 40 | 41 | Args: 42 | sentences: a nested list of sentences 43 | 44 | """ 45 | ids = defaultdict(lambda: len(ids)) 46 | id_sents = [[ids[word] for word in sent] for sent in sentences] 47 | return id_sents, ids 48 | 49 | # Compute log sum exp in a numerically stable way for the forward algorithm 50 | def log_sum_exp(value, dim=None, keepdim=False): 51 | """Numerically stable implementation of the operation 52 | 53 | value.exp().sum(dim, keepdim).log() 54 | """ 55 | if dim is not None: 56 | m, _ = torch.max(value, dim=dim, keepdim=True) 57 | value0 = value - m 58 | if keepdim is False: 59 | m = m.squeeze(dim) 60 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 61 | else: 62 | m = torch.max(value) 63 | sum_exp = torch.sum(torch.exp(value - m)) 64 | return m + torch.log(sum_exp) 65 | 66 | def sents_to_vec(vec_dict, sentences): 67 | """read data, produce training data and labels. 68 | 69 | Args: 70 | vec_dict: a dict mapping words to vectors. 
71 | sentences: A list of ConllSent objects 72 | 73 | Returns: 74 | embeddings: a list of tensors 75 | tags: a nested list of gold tags 76 | """ 77 | embeddings = [] 78 | for sent in sentences: 79 | sample = [vec_dict[word] for word in sent["word"]] 80 | embeddings.append(sample) 81 | 82 | return embeddings 83 | 84 | def sents_to_tagid(sentences): 85 | """transform tagged sents to tagids, 86 | also return the look up table 87 | """ 88 | ids = defaultdict(lambda: len(ids)) 89 | id_sents = [[ids[tag] for tag in sent["tag"]] for sent in sentences] 90 | return id_sents, ids 91 | 92 | def read_conll(fname, max_len=1e3, rm_null=True, prc_num=True): 93 | sentences = [] 94 | sent = ConllSent() 95 | 96 | null_total = [] 97 | null_sent = [] 98 | loc = 0 99 | with open(fname) as fin: 100 | for line in fin: 101 | if line != '\n': 102 | line = line.strip().split('\t') 103 | sent["head"].append((int(line[0]), 104 | cast_to_int(line[3]))) 105 | if rm_null and line[2] == '-NONE-': 106 | null_sent.append(loc) 107 | else: 108 | sent["tag"].append(line[2]) 109 | if prc_num and is_number(line[1]): 110 | sent["word"].append('0') 111 | else: 112 | sent["word"].append(line[1]) 113 | 114 | loc += 1 115 | else: 116 | loc = 0 117 | if len(sent) > 0 and len(sent) <= max_len: 118 | sentences.append(sent) 119 | null_total.append(null_sent) 120 | 121 | null_sent = [] 122 | sent = ConllSent() 123 | 124 | return sentences, null_total 125 | 126 | def write_conll(fname, sentences, pred_tags, null_total): 127 | with open(fname, 'w') as fout: 128 | for (pred, null_sent, sent) in zip(pred_tags, null_total, sentences): 129 | word_list = sent["word"] 130 | head_list = sent["head"] 131 | length = len(sent) + len(null_sent) 132 | assert (length == len(head_list)) 133 | pred_tag_list = [str(k.item()) for k in pred] 134 | for null in null_sent: 135 | pred_tag_list.insert(null, '-NONE-') 136 | word_list.insert(null, '-NONE-') 137 | 138 | for i in range(length): 139 | fout.write("{}\t{}\t{}\t{}\n".format( 140 | i+1, word_list[i], pred_tag_list[i], 141 | head_list[i][1])) 142 | fout.write('\n') 143 | 144 | def input_transpose(sents, pad): 145 | max_len = max(len(s) for s in sents) 146 | batch_size = len(sents) 147 | 148 | sents_t = [] 149 | masks = [] 150 | for i in range(max_len): 151 | sents_t.append([sent[i] if len(sent) > i else pad for sent in sents]) 152 | masks.append([1 if len(sent) > i else 0 for sent in sents]) 153 | 154 | return sents_t, masks 155 | 156 | def to_input_tensor(sents, pad, device): 157 | """ 158 | return a tensor of shape (src_sent_len, batch_size) 159 | """ 160 | 161 | sents, masks = input_transpose(sents, pad) 162 | 163 | 164 | sents_t = torch.tensor(sents, dtype=torch.float32, requires_grad=False, device=device) 165 | masks_t = torch.tensor(masks, dtype=torch.float32, requires_grad=False, device=device) 166 | 167 | return sents_t, masks_t 168 | 169 | def data_iter(data, batch_size, is_test=False, shuffle=True): 170 | index_arr = np.arange(len(data)) 171 | # in_place operation 172 | 173 | if shuffle: 174 | np.random.shuffle(index_arr) 175 | 176 | batch_num = int(np.ceil(len(data) / float(batch_size))) 177 | for i in range(batch_num): 178 | batch_ids = index_arr[i * batch_size: (i + 1) * batch_size] 179 | batch_data = [data[index] for index in batch_ids] 180 | 181 | if is_test: 182 | # batch_data.sort(key=lambda e: -len(e[0])) 183 | test_data = [data_tuple[0] for data_tuple in batch_data] 184 | tags = [data_tuple[1] for data_tuple in batch_data] 185 | 186 | 187 | yield test_data, tags 188 | 189 | else: 190 | 
# batch_data.sort(key=lambda e: -len(e)) 191 | yield batch_data 192 | 193 | def generate_seed(data, size, shuffle=True): 194 | index_arr = np.arange(len(data)) 195 | # in_place operation 196 | 197 | if shuffle: 198 | np.random.shuffle(index_arr) 199 | 200 | seed = [data[index] for index in index_arr[:size]] 201 | 202 | return seed 203 | 204 | def get_tag_set(tag_list): 205 | tag_set = set() 206 | tag_set.update([x for s in tag_list for x in s]) 207 | return tag_set 208 | 209 | def stable_math_log(val, default_val=-1e20): 210 | if val == 0: 211 | return default_val 212 | 213 | return math.log(val) 214 | 215 | def unravel_index(input, size): 216 | """Unravel the index of tensor given size 217 | Args: 218 | input: LongTensor 219 | size: a tuple of integers 220 | 221 | Outputs: output, 222 | - **output**: the unraveled new tensor 223 | 224 | Examples:: 225 | <<< value = torch.LongTensor(4,5,7,9) 226 | <<< max_val, flat_index = torch.max(value.view(4, 5, -1), dim=-1) 227 | <<< index = unravel_index(flat_index, (7, 9)) 228 | <<< # output is a tensor with size (4, 5, 2) 229 | 230 | """ 231 | idx = [] 232 | for adim in size[::-1]: 233 | idx.append((input % adim).unsqueeze(dim=-1)) 234 | input = input / adim 235 | idx = idx[::-1] 236 | return torch.cat(idx, -1) 237 | -------------------------------------------------------------------------------- /nlp_commons/dep/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | 6 | from __future__ import absolute_import 7 | # dep/model.py: A general model for dependency parsing (class DepModel), and a 8 | # general model for projective dependency parsing, with evaluation also as 9 | # constituent trees. 10 | 11 | 12 | #from .. import model 13 | from .. import sentence, bracketing, model 14 | from . import depset 15 | from . import dwsj 16 | 17 | 18 | class DepModel(model.Model): 19 | """A general model for dependency parsing.""" 20 | count_length_2 = True 21 | count_length_2_1 = False 22 | 23 | def __init__(self, treebank=None): 24 | 25 | treebank = self._get_treebank(treebank) 26 | 27 | S, Gold = [], [] 28 | for t in treebank.get_trees(): 29 | s = sentence.Sentence(t.leaves()) 30 | S += [s] 31 | #Gold += [depset.deptree_to_depset(t)] 32 | Gold += [t.depset] 33 | 34 | self.S = S 35 | self.Gold = Gold 36 | 37 | def _get_treebank(self, treebank=None): 38 | if treebank is None: 39 | treebank = dwsj.DepWSJ10() 40 | return treebank 41 | 42 | def eval(self, output=True, short=False, long=False, max_length=None): 43 | Gold = self.Gold 44 | 45 | Count = 0 46 | Directed = 0.0 47 | Undirected = 0.0 48 | 49 | for i in range(len(Gold)): 50 | l = Gold[i].length 51 | if (max_length is None or l <= max_length) \ 52 | and (self.count_length_2_1 or (self.count_length_2 and l == 2) or l >= 3): 53 | (count, directed, undirected) = self.measures(i) 54 | Count += count 55 | Directed += directed 56 | Undirected += undirected 57 | 58 | Directed = Directed / Count 59 | Undirected = Undirected / Count 60 | 61 | self.evaluation = (Count, Directed, Undirected) 62 | self.evaluated = True 63 | 64 | if output and not short: 65 | print "Number of Trees:", len(Gold) 66 | print " Directed Accuracy: %2.1f" % (100*Directed) 67 | print " Undirected Accuracy: %2.1f" % (100*Undirected) 68 | elif output and short: 69 | print "L =", Directed, "UL =", Undirected 70 | 71 | return self.evaluation 72 | 73 | def measures(self, i): 74 | # Helper for eval(). 
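        # Returns (n, d, u): the sentence length, the number of gold dependencies
        # recovered with the correct direction, and the number recovered in either direction.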
75 | # Measures for the i-th parse. 76 | 77 | g, p = self.Gold[i].deps, self.Parse[i].deps 78 | (n, d, u) = (self.Gold[i].length, 0, 0) 79 | for (a, b) in g: 80 | b1 = (a, b) in p 81 | b2 = (b, a) in p 82 | if b1: 83 | d += 1 84 | if b1 or b2: 85 | u += 1 86 | 87 | return (n, d, u) 88 | 89 | #def eval_stats(self, output=True, short=False, long=False, max_length=None): 90 | def eval_stats(self, output=True, max_length=None): 91 | Gold, Parse = self.Gold, self.Parse 92 | gold_stats = {} 93 | parse_stats = {} 94 | stats = {} 95 | for i in range(len(Gold)): 96 | l = Gold[i].length 97 | if (max_length is None or l <= max_length) \ 98 | and (self.count_length_2_1 or (self.count_length_2 and l == 2) or l >= 3): 99 | #(count, directed, undirected) = self.measures(i) 100 | #Count += count 101 | #Directed += directed 102 | #Undirected += undirected 103 | s = self.S[i] + ['ROOT'] 104 | g, p = Gold[i].deps, Parse[i].deps 105 | lg = [(s[i], s[j], i < j) for i,j in g] 106 | lp = [(s[i], s[j], i < j) for i,j in p] 107 | for x in lg: 108 | gold_stats[x] = gold_stats.get(x, 0) + 1 109 | stats[x] = stats.get(x, 0) - 1 110 | for x in lp: 111 | parse_stats[x] = parse_stats.get(x, 0) + 1 112 | stats[x] = stats.get(x, 0) + 1 113 | lstats = sorted(stats.iteritems(), key=lambda x:x[1]) 114 | if output: 115 | # a -> b iif b is head of a. 116 | print 'Overproposals' 117 | for ((d, h, left), n) in lstats[:len(lstats)-10:-1]: 118 | if left: 119 | print '\t{0} -> {1}\t{2}'.format(d, h, n) 120 | else: 121 | print '\t{1} <- {0}\t{2}'.format(d, h, n) 122 | print 'Underproposals' 123 | for ((d, h, left), n) in lstats[:10]: 124 | if left: 125 | print '\t{0} -> {1}\t{2}'.format(d, h, -n) 126 | else: 127 | print '\t{1} <- {0}\t{2}'.format(d, h, -n) 128 | 129 | #return (gold_stats, parse_stats) 130 | return lstats 131 | 132 | 133 | class ProjDepModel(DepModel): 134 | """A general model for projective dependency parsing, with evaluation also 135 | as constituent trees. 136 | """ 137 | def __init__(self, treebank=None, training_corpus=None): 138 | """ 139 | The elements of the treebank must be trees with a DepSet in the 140 | attribute depset. 141 | """ 142 | treebank = self._get_treebank(treebank) 143 | if training_corpus == None: 144 | training_corpus = treebank 145 | self.test_corpus = treebank 146 | self.training_corpus = training_corpus 147 | S = [] 148 | for s in treebank.tagged_sents(): 149 | s = [x[1] for x in s] 150 | S += [sentence.Sentence(s)] 151 | self.S = S 152 | # Extract gold as DepSets: 153 | # FIXME: call super and do this there. 154 | self.Gold = [t.depset for t in treebank.parsed_sents()] 155 | 156 | # Extract gold as Bracketings: 157 | # self.bracketing_model = model.BracketingModel(treebank) 158 | 159 | def eval(self, output=True, short=False, long=False, max_length=None): 160 | """Compute evaluation of the parses against the test corpus. Computes 161 | unlabeled precision, recall and F1 between the bracketings, and directed 162 | and undirected dependency accuracy between the dependency structures. 
163 | """ 164 | # XXX: empezamos a lo bruto: 165 | self.bracketing_model.Parse = [bracketing.tree_to_bracketing(t) for t in self.Parse] 166 | #dmvccm.DMVCCM.eval(self, output, short, long, max_length) 167 | self.bracketing_model.eval(output, short, long, max_length) 168 | 169 | # Ahora eval de dependencias: 170 | self.DepParse = self.Parse 171 | # type no anda porque devuelve instance: 172 | #self.Parse = [type(self).tree_to_depset(t) for t in self.DepParse] 173 | self.Parse = [self.__class__.tree_to_depset(t) for t in self.DepParse] 174 | #model.DepModel.eval(self, output, short, long, max_length) 175 | DepModel.eval(self, output, short, long, max_length) 176 | self.Parse = self.DepParse 177 | 178 | def eval_stats(self, output=True, max_length=None): 179 | # Ahora eval de dependencias: 180 | self.DepParse = self.Parse 181 | # type no anda porque devuelve instance: 182 | #self.Parse = [type(self).tree_to_depset(t) for t in self.DepParse] 183 | self.Parse = [self.__class__.tree_to_depset(t) for t in self.DepParse] 184 | #model.DepModel.eval(self, output, short, long, max_length) 185 | DepModel.eval_stats(self, output, max_length) 186 | self.Parse = self.DepParse 187 | 188 | @staticmethod 189 | def tree_to_depset(t): 190 | """Function used to convert the trees returned by the parser to DepSets. 191 | """ 192 | raise Exception('Static function tree_to_depset must be overriden.') 193 | -------------------------------------------------------------------------------- /nlp_commons/dep/dwsj.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # dwsj.py: Dependency version of the WSJ corpus. 6 | 7 | """ 8 | from dep.dwsj import * 9 | import wsj 10 | tb = wsj.WSJ() 11 | t = tb.get_tree(2) 12 | t.leaves() 13 | # ['Rudolph', 'Agnew', ',', '55', 'years', 'old', 'and', 'former', 'chairman', 'of', 'Consolidated', 'Gold', 'Fields', 'PLC', ',', 'was', 'named', '*-1', 'a', 'nonexecutive', 'director', 'of', 'this', 'British', 'industrial', 'conglomerate', '.'] 14 | find_heads(t) 15 | t.depset = tree_to_depset(t) 16 | t.depset.deps 17 | """ 18 | 19 | from .. import wsj 20 | from .. import wsj10 21 | from . import depset 22 | 23 | 24 | class DepWSJ(wsj10.WSJn): 25 | 26 | def __init__(self, max_length, basedir=None, load=True, extra_tags=None): 27 | wsj10.WSJn.__init__(self, max_length, basedir, load=False, extra_tags=extra_tags) 28 | self.filename = '%s.treebank' % basedir 29 | if load: 30 | self.get_trees() 31 | 32 | def _generate_trees(self): 33 | trees = wsj10.WSJn._generate_trees(self) 34 | 35 | for t in trees: 36 | # First find the head for each constituent: 37 | find_heads(t) 38 | t.depset = tree_to_depset(t) 39 | return trees 40 | 41 | def get_gold_dep(self): 42 | """ 43 | Return: 44 | gold_dep: a list of DepSet 45 | 46 | """ 47 | gold_dep = [t.depset for t in self.trees] 48 | 49 | return gold_dep 50 | 51 | 52 | 53 | 54 | class DepWSJ10(DepWSJ): 55 | 56 | def __init__(self, basedir=None, load=True, extra_tags=None): 57 | DepWSJ.__init__(self, 10, basedir, load, extra_tags) 58 | 59 | 60 | def find_heads(t): 61 | """Mark heads in the constituent tree t using the Collins PhD Thesis (1999) 62 | rules. The heads are marked in every subtree st in the attributes st.node 63 | and st.head. 
64 | """ 65 | for st in t.subtrees(): 66 | label = st.label().split('-')[0].split('=')[0] 67 | # the children may be a tree or a leaf (type string): 68 | children = [(type(x) is bytes and x) or (type(x) is str and x) or 69 | x.label().split('-')[0] for x in st] 70 | st.head = get_head(label, children)-1 71 | st.set_label('['+children[st.head]+']') 72 | 73 | 74 | def tree_to_depset(t): 75 | """Returns the DepSet associated to the head marked tree t (with find_heads). 76 | """ 77 | leave_index = 0 78 | res = set() 79 | aux = {} 80 | # Traverse the tree from the leaves upwards (postorder) 81 | for p in t.treepositions(order='postorder'): 82 | st = t[p] 83 | if isinstance(st, bytes) or isinstance(st, str): 84 | # We are at leave with index leave_index. 85 | aux[p] = leave_index 86 | leave_index += 1 87 | else: 88 | # We are at a subtree. aux has the index of the 89 | # head for each subsubtree. 90 | head = st.head 91 | if type(st[head]) is bytes or type(st[head]) is str: 92 | # index of the leave at deptree[head] 93 | head_index = aux[p+(head,)] 94 | else: 95 | head_index = st[head].head_index 96 | st.head_index = head_index 97 | for i in range(len(st)): 98 | sst = st[i] 99 | if i == head: 100 | pass # skip self dependency 101 | elif type(sst) is bytes or type(sst) is str: 102 | res.add((aux[p+(i,)], head_index)) 103 | else: 104 | res.add((sst.head_index, head_index)) 105 | res.add((t.head_index, -1)) 106 | 107 | return depset.DepSet(len(t.leaves()), sorted(res)) 108 | 109 | 110 | def get_head(label, children): 111 | """children must be a not empty list. Returns the index of the head, 112 | starting from 1. 113 | (rules for the Penn Treebank, taken from p. 239 of Collins thesis). 114 | The rules for X and NX are not specified by Collins. We use the ones 115 | at (also at Yamada and Matsumoto 2003). 116 | (X only appears at wsj_0056.mrg and at wsj_0077.mrg) 117 | """ 118 | assert children != [] 119 | if len(children) == 1: 120 | # Used also when label is a POS tag and children is a word. 121 | res = 1 122 | elif label == 'NP': 123 | # Rules for NPs 124 | 125 | # search* returns indexes starting from 1 126 | # (to avoid confusion between 0 and False): 127 | res = (children[-1] in wsj.word_tags and len(children)) or \ 128 | searchr(children, set('NN NNP NNS NNPS NNS NX POS JJR'.split())) or \ 129 | searchl(children, 'NP') or \ 130 | searchr(children, set('$ ADJP PRN'.split())) or \ 131 | searchr(children, 'CD') or \ 132 | searchr(children, set('JJ JJS RB QP'.split())) or \ 133 | len(children) 134 | else: 135 | rule = head_rules[label] 136 | plist = rule[1] 137 | if plist == [] and rule[0] == 'r': 138 | res = len(children) 139 | # Redundant: 140 | #elif plist == [] and rule[0] == 'l': 141 | # res = 1 142 | else: 143 | res = None 144 | i, n = 0, len(plist) 145 | if rule[0] == 'l': 146 | while i < n and res is None: 147 | # search* returns indexes starting from 1 148 | res = searchl(children, plist[i]) 149 | i += 1 150 | else: 151 | #assert rule[0] == 'r' 152 | while i < n and res is None: 153 | # search* returns indexes starting from 1 154 | res = searchr(children, plist[i]) 155 | i += 1 156 | if res is None: 157 | res = 1 158 | 159 | # Rules for coordinated phrases 160 | #if 'CC' in [res-2 >= 0 and children[res-2], \ 161 | # res < len(children) and children[res]]: 162 | if res-2 >= 1 and children[res-2] == 'CC': 163 | # On the other case the head doesn't change. 164 | res -= 2 165 | 166 | return res 167 | 168 | 169 | def searchr(l, e): 170 | """As searchl but from right to left. 
When not None, returns the index 171 | starting from 1. 172 | """ 173 | l = l[::-1] 174 | r = searchl(l, e) 175 | if r is None: 176 | return None 177 | else: 178 | return len(l)-r+1 179 | 180 | 181 | def searchl(l, e): 182 | """Returns the index of the first occurrence of any member of e in l, 183 | starting from 1 (just for convenience in the usage, see get_head). Returns 184 | None if there is no occurrence. 185 | """ 186 | #print 'searchl('+str(l)+', '+str(e)+')' 187 | if type(e) is not set: 188 | e = set([e]) 189 | i, n = 0, len(l) 190 | while i < n and l[i] not in e: 191 | i += 1 192 | if i == n: 193 | return None 194 | else: 195 | #return l[i] 196 | return i+1 197 | 198 | 199 | head_rules = {'ADJP': ('l', 'NNS QP NN $ ADVP JJ VBN VBG ADJP JJR NP JJS DT FW RBR RBS SBAR RB'.split()), \ 200 | 'ADVP': ('r', 'RB RBR RBS FW ADVP TO CD JJR JJ IN NP JJS NN'.split()), \ 201 | 'CONJP': ('r', 'CC RB IN'.split()), \ 202 | 'FRAG': ('r', []), \ 203 | 'INTJ': ('l', []), \ 204 | 'LST': ('r', 'LS :'.split()), \ 205 | 'NAC': ('l', 'NN NNS NNP NNPS NP NAC EX $ CD QP PRP VBG JJ JJS JJR ADJP FW'.split()), \ 206 | 'PP': ('r', 'IN TO VBG VBN RP FW'.split()), \ 207 | 'PRN': ('l', []), \ 208 | 'PRT': ('r', 'RP'.split()), \ 209 | 'QP': ('l', '$ IN NNS NN JJ RB DT CD NCD QP JJR JJS'.split()), \ 210 | 'RRC': ('r', 'VP NP ADVP ADJP PP'.split()), \ 211 | 'S': ('l', 'TO IN VP S SBAR ADJP UCP NP'.split()), \ 212 | 'SBAR': ('l', 'WHNP WHPP WHADVP WHADJP IN DT S SQ SINV SBAR FRAG'.split()), \ 213 | 'SBARQ': ('l', 'SQ S SINV SBARQ FRAG'.split()), \ 214 | 'SINV': ('l', 'VBZ VBD VBP VB MD VP S SINV ADJP NP'.split()), \ 215 | 'SQ': ('l', 'VBZ VBD VBP VB MD VP SQ'.split()), \ 216 | 'UCP': ('r', []), \ 217 | 'VP': ('l', 'TO VBD VBN MD VBZ VB VBG VBP VP ADJP NN NNS NP'.split()), \ 218 | 'WHADJP': ('l', 'CC WRB JJ ADJP'.split()), \ 219 | 'WHADVP': ('r', 'CC WRB'.split()), \ 220 | 'WHNP': ('l', 'WDT WP WP$ WHADJP WHPP WHNP'.split()), \ 221 | 'WHPP': ('r', 'IN TO FW'.split()), \ 222 | 'NX': ('r', 'POS NN NNP NNPS NNS NX JJR CD JJ JJS RB QP NP'.split()), \ 223 | 'X': ('r', []) 224 | } 225 | -------------------------------------------------------------------------------- /markov_flow_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import pickle 4 | import argparse 5 | import sys 6 | import time 7 | import os 8 | 9 | import torch 10 | import numpy as np 11 | 12 | from modules import read_conll, \ 13 | to_input_tensor, \ 14 | data_iter, \ 15 | generate_seed, \ 16 | sents_to_vec 17 | 18 | 19 | from modules import MarkovFlow 20 | 21 | 22 | def init_config(): 23 | 24 | parser = argparse.ArgumentParser(description='POS tagging') 25 | 26 | # train and test data 27 | parser.add_argument('--word_vec', type=str, 28 | help='the word vector file (cPickle saved file)') 29 | parser.add_argument('--train_file', type=str, help='train data') 30 | parser.add_argument('--test_file', default='', type=str, help='test data') 31 | 32 | # optimization parameters 33 | parser.add_argument('--batch_size', default=32, type=int, help='batch_size') 34 | parser.add_argument('--epochs', default=50, type=int, help='number of training epochs') 35 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 36 | 37 | # model config 38 | parser.add_argument('--model', choices=['gaussian', 'nice'], default='gaussian') 39 | parser.add_argument('--num_state', default=45, type=int, 40 | help='number of hidden states of z') 41 | 
parser.add_argument('--couple_layers', default=4, type=int, 42 | help='number of coupling layers in NICE') 43 | parser.add_argument('--cell_layers', default=1, type=int, 44 | help='number of cell layers of ReLU net in each coupling layer') 45 | parser.add_argument('--hidden_units', default=50, type=int, help='hidden units in ReLU Net') 46 | 47 | # pretrained model options 48 | parser.add_argument('--load_nice', default='', type=str, 49 | help='load pretrained projection model, ignored by default') 50 | parser.add_argument('--load_gaussian', default='', type=str, 51 | help='load pretrained Gaussian model, ignored by default') 52 | 53 | # log parameters 54 | parser.add_argument('--valid_nepoch', default=1, type=int, help='valid_nepoch') 55 | 56 | # Others 57 | parser.add_argument('--tag_from', default='', type=str, 58 | help='load pretrained model and perform tagging') 59 | parser.add_argument('--seed', default=5783287, type=int, help='random seed') 60 | parser.add_argument('--set_seed', action='store_true', default=False, help='if set seed') 61 | 62 | # these are for slurm purpose to save model 63 | # they can also be used to run multiple random restarts with various settings, 64 | # to save models that can be identified with ids 65 | parser.add_argument('--jobid', type=int, default=0, help='slurm job id') 66 | parser.add_argument('--taskid', type=int, default=0, help='slurm task id') 67 | 68 | args = parser.parse_args() 69 | args.cuda = torch.cuda.is_available() 70 | 71 | save_dir = "dump_models/markov" 72 | 73 | if not os.path.exists(save_dir): 74 | os.makedirs(save_dir) 75 | 76 | id_ = "pos_%s_%dlayers_%d_%d" % (args.model, args.couple_layers, args.jobid, args.taskid) 77 | save_path = os.path.join(save_dir, id_ + '.pt') 78 | args.save_path = save_path 79 | print("model save path: ", save_path) 80 | 81 | if args.tag_from != '': 82 | if args.model == 'nice': 83 | args.load_nice = args.tag_from 84 | else: 85 | args.load_gaussian = args.tag_from 86 | args.tag_path = "pos_%s_%slayers_tagging%d_%d.txt" % \ 87 | (args.model, args.couple_layers, args.jobid, args.taskid) 88 | 89 | if args.set_seed: 90 | torch.manual_seed(args.seed) 91 | if args.cuda: 92 | torch.cuda.manual_seed(args.seed) 93 | np.random.seed(args.seed * 13 / 7) 94 | 95 | print(args) 96 | 97 | return args 98 | 99 | def main(args): 100 | 101 | word_vec = pickle.load(open(args.word_vec, 'rb')) 102 | print('complete loading word vectors') 103 | 104 | train_text, null_index = read_conll(args.train_file) 105 | if args.test_file != '': 106 | test_text, null_index = read_conll(args.test_file) 107 | else: 108 | test_text = train_text 109 | 110 | train_data = sents_to_vec(word_vec, train_text) 111 | test_data = sents_to_vec(word_vec, test_text) 112 | 113 | test_tags = [sent["tag"] for sent in test_text] 114 | 115 | num_dims = len(train_data[0][0]) 116 | print('complete reading data') 117 | 118 | print('#training sentences: %d' % len(train_data)) 119 | print('#testing sentences: %d' % len(test_data)) 120 | 121 | log_niter = (len(train_data)//args.batch_size)//10 122 | 123 | 124 | pad = np.zeros(num_dims) 125 | device = torch.device("cuda" if args.cuda else "cpu") 126 | args.device = device 127 | init_seed = to_input_tensor(generate_seed(train_data, args.batch_size), 128 | pad, device=device) 129 | 130 | model = MarkovFlow(args, num_dims).to(device) 131 | 132 | model.init_params(init_seed) 133 | 134 | if args.tag_from != '': 135 | model.eval() 136 | with torch.no_grad(): 137 | accuracy, vm = model.test(test_data, test_tags, 
sentences=test_text, 138 | tagging=True, path=args.tag_path, null_index=null_index) 139 | print('\n***** M1 %f, VM %f, max_var %.4f, min_var %.4f*****\n' 140 | % (accuracy, vm, model.var.data.max(), model.var.data.min())) 141 | return 142 | 143 | 144 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 145 | 146 | begin_time = time.time() 147 | print('begin training') 148 | 149 | train_iter = report_obj = report_jc = report_ll = report_num_words = 0 150 | 151 | # print the accuracy under init params 152 | model.eval() 153 | with torch.no_grad(): 154 | accuracy, vm = model.test(test_data, test_tags) 155 | print('\n*****starting M1 %f, VM %f, max_var %.4f, min_var %.4f*****\n' 156 | % (accuracy, vm, model.var.data.max(), model.var.data.min())) 157 | 158 | 159 | model.train() 160 | for epoch in range(args.epochs): 161 | # model.print_params() 162 | report_obj = report_jc = report_ll = report_num_words = 0 163 | for sents in data_iter(train_data, batch_size=args.batch_size, shuffle=True): 164 | train_iter += 1 165 | batch_size = len(sents) 166 | num_words = sum(len(sent) for sent in sents) 167 | sents_var, masks = to_input_tensor(sents, pad, device=args.device) 168 | optimizer.zero_grad() 169 | likelihood, jacobian_loss = model(sents_var, masks) 170 | neg_likelihood_loss = -likelihood 171 | 172 | avg_ll_loss = (neg_likelihood_loss + jacobian_loss)/batch_size 173 | 174 | avg_ll_loss.backward() 175 | 176 | optimizer.step() 177 | 178 | log_likelihood_val = -neg_likelihood_loss.item() 179 | jacobian_val = -jacobian_loss.item() 180 | obj_val = log_likelihood_val + jacobian_val 181 | 182 | report_ll += log_likelihood_val 183 | report_jc += jacobian_val 184 | report_obj += obj_val 185 | report_num_words += num_words 186 | 187 | if train_iter % log_niter == 0: 188 | print('epoch %d, iter %d, log_likelihood %.2f, jacobian %.2f, obj %.2f, max_var %.4f ' \ 189 | 'min_var %.4f time elapsed %.2f sec' % (epoch, train_iter, report_ll / report_num_words, \ 190 | report_jc / report_num_words, report_obj / report_num_words, model.var.max(), \ 191 | model.var.min(), time.time() - begin_time)) 192 | 193 | print('\nepoch %d, log_likelihood %.2f, jacobian %.2f, obj %.2f\n' % \ 194 | (epoch, report_ll / report_num_words, report_jc / report_num_words, 195 | report_obj / report_num_words)) 196 | 197 | if epoch % args.valid_nepoch == 0: 198 | model.eval() 199 | with torch.no_grad(): 200 | accuracy, vm = model.test(test_data, test_tags) 201 | print('\n*****epoch %d, iter %d, M1 %f, VM %f*****\n' % 202 | (epoch, train_iter, accuracy, vm)) 203 | model.train() 204 | 205 | torch.save(model.state_dict(), args.save_path) 206 | 207 | model.eval() 208 | with torch.no_grad(): 209 | accuracy, vm = model.test(test_data, test_tags) 210 | print('\n complete training, accuracy %f, vm %f\n' % (accuracy, vm)) 211 | 212 | if __name__ == '__main__': 213 | parse_args = init_config() 214 | main(parse_args) 215 | -------------------------------------------------------------------------------- /dmv_flow_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import time 6 | import sys 7 | import pickle 8 | 9 | import torch 10 | import numpy as np 11 | 12 | import modules.dmv_flow_model as dmv 13 | from modules import data_iter, \ 14 | read_conll, \ 15 | sents_to_vec, \ 16 | sents_to_tagid, \ 17 | to_input_tensor, \ 18 | generate_seed 19 | 20 | 21 | def init_config(): 22 | 23 | parser = 
argparse.ArgumentParser(description='dependency parsing') 24 | 25 | # train and test data 26 | parser.add_argument('--word_vec', type=str, 27 | help='the word vector file (cPickle saved file)') 28 | parser.add_argument('--train_file', type=str, help='train data') 29 | parser.add_argument('--test_file', default='', type=str, help='test data') 30 | parser.add_argument('--load_viterbi_dmv', type=str, 31 | help='load pretrained DMV') 32 | 33 | # optimization parameters 34 | parser.add_argument('--epochs', default=15, type=int, help='number of epochs') 35 | parser.add_argument('--batch_size', default=32, type=int, help='batch_size') 36 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 37 | parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients') 38 | 39 | # model config 40 | parser.add_argument('--model', choices=['gaussian', 'nice'], default='gaussian') 41 | parser.add_argument('--couple_layers', default=8, type=int, 42 | help='number of coupling layers in NICE') 43 | parser.add_argument('--cell_layers', default=1, type=int, 44 | help='number of cell layers of ReLU net in each coupling layer') 45 | parser.add_argument('--hidden_units', default=50, type=int, help='hidden units in ReLU Net') 46 | 47 | # others 48 | parser.add_argument('--train_from', type=str, default='', 49 | help='load a pre-trained checkpoint') 50 | parser.add_argument('--seed', default=5783287, type=int, help='random seed') 51 | parser.add_argument('--set_seed', action='store_true', default=False, 52 | help='if set seed') 53 | parser.add_argument('--valid_nepoch', default=1, type=int, 54 | help='valid every n epochs') 55 | parser.add_argument('--eval_all', action='store_true', default=False, 56 | help='if true, the script would evaluate on all lengths after training') 57 | 58 | # these are for slurm purpose to save model 59 | # they can also be used to run multiple random restarts with various settings, 60 | # to save models that can be identified with ids 61 | parser.add_argument('--jobid', type=int, default=0, help='slurm job id') 62 | parser.add_argument('--taskid', type=int, default=0, help='slurm task id') 63 | 64 | 65 | args = parser.parse_args() 66 | args.cuda = torch.cuda.is_available() 67 | 68 | save_dir = "dump_models/dmv" 69 | 70 | if not os.path.exists(save_dir): 71 | os.makedirs(save_dir) 72 | 73 | save_path = "parse_%s_%dlayers_%d_%d" % \ 74 | (args.model, args.couple_layers, args.jobid, args.taskid) 75 | save_path = os.path.join(save_dir, save_path + '.pt') 76 | args.save_path = save_path 77 | 78 | if args.set_seed: 79 | torch.manual_seed(args.seed) 80 | if args.cuda: 81 | torch.cuda.manual_seed(args.seed) 82 | np.random.seed(args.seed) 83 | 84 | print(args) 85 | 86 | return args 87 | 88 | 89 | def main(args): 90 | 91 | word_vec = pickle.load(open(args.word_vec, 'rb')) 92 | print('complete loading word vectors') 93 | 94 | train_sents, _ = read_conll(args.train_file, max_len=10) 95 | test_sents, _ = read_conll(args.test_file, max_len=10) 96 | test_deps = [sent["head"] for sent in test_sents] 97 | 98 | train_emb = sents_to_vec(word_vec, train_sents) 99 | test_emb = sents_to_vec(word_vec, test_sents) 100 | 101 | num_dims = len(train_emb[0][0]) 102 | 103 | train_tagid, tag2id = sents_to_tagid(train_sents) 104 | print('%d types of tags' % len(tag2id)) 105 | id2tag = {v: k for k, v in tag2id.items()} 106 | 107 | pad = np.zeros(num_dims) 108 | device = torch.device("cuda" if args.cuda else "cpu") 109 | args.device = device 110 | 111 | model = dmv.DMVFlow(args, 
id2tag, num_dims).to(device) 112 | 113 | init_seed = to_input_tensor(generate_seed(train_emb, args.batch_size), 114 | pad, device=device) 115 | 116 | with torch.no_grad(): 117 | model.init_params(init_seed, train_tagid, train_emb) 118 | print('complete init') 119 | 120 | if args.train_from != '': 121 | model.load_state_dict(torch.load(args.train_from)) 122 | with torch.no_grad(): 123 | directed, undirected = model.test(test_deps, test_emb) 124 | print('acc on length <= 10: #trees %d, undir %2.1f, dir %2.1f' \ 125 | % (len(test_deps), 100 * undirected, 100 * directed)) 126 | 127 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 128 | 129 | log_niter = (len(train_emb)//args.batch_size)//5 130 | report_ll = report_num_words = report_num_sents = epoch = train_iter = 0 131 | stop_avg_ll = stop_num_words = 0 132 | stop_avg_ll_last = 1 133 | dir_last = 0 134 | begin_time = time.time() 135 | 136 | print('begin training') 137 | 138 | with torch.no_grad(): 139 | directed, undirected = model.test(test_deps, test_emb) 140 | print('starting acc on length <= 10: #trees %d, undir %2.1f, dir %2.1f' \ 141 | % (len(test_deps), 100 * undirected, 100 * directed)) 142 | 143 | for epoch in range(args.epochs): 144 | report_ll = report_num_sents = report_num_words = 0 145 | for sents in data_iter(train_emb, batch_size=args.batch_size): 146 | batch_size = len(sents) 147 | num_words = sum(len(sent) for sent in sents) 148 | stop_num_words += num_words 149 | optimizer.zero_grad() 150 | 151 | sents_var, masks = to_input_tensor(sents, pad, device) 152 | sents_var, _ = model.transform(sents_var) 153 | sents_var = sents_var.transpose(0, 1) 154 | log_likelihood = model.p_inside(sents_var, masks) 155 | 156 | avg_ll_loss = -log_likelihood / batch_size 157 | 158 | avg_ll_loss.backward() 159 | 160 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) 161 | optimizer.step() 162 | 163 | report_ll += log_likelihood.item() 164 | report_num_words += num_words 165 | report_num_sents += batch_size 166 | 167 | stop_avg_ll += log_likelihood.item() 168 | 169 | if train_iter % log_niter == 0: 170 | print('epoch %d, iter %d, ll_per_sent %.4f, ll_per_word %.4f, ' \ 171 | 'max_var %.4f, min_var %.4f time elapsed %.2f sec' % \ 172 | (epoch, train_iter, report_ll / report_num_sents, \ 173 | report_ll / report_num_words, model.var.data.max(), \ 174 | model.var.data.min(), time.time() - begin_time), file=sys.stderr) 175 | 176 | train_iter += 1 177 | if epoch % args.valid_nepoch == 0: 178 | with torch.no_grad(): 179 | directed, undirected = model.test(test_deps, test_emb) 180 | print('\n\nacc on length <= 10: #trees %d, undir %2.1f, dir %2.1f, \n\n' \ 181 | % (len(test_deps), 100 * undirected, 100 * directed)) 182 | 183 | stop_avg_ll = stop_avg_ll / stop_num_words 184 | rate = (stop_avg_ll - stop_avg_ll_last) / abs(stop_avg_ll_last) 185 | 186 | print('\n\nlikelihood: %.4f, likelihood last: %.4f, rate: %f\n' % \ 187 | (stop_avg_ll, stop_avg_ll_last, rate)) 188 | 189 | if rate < 0.001 and epoch >= 5: 190 | break 191 | 192 | stop_avg_ll_last = stop_avg_ll 193 | stop_avg_ll = stop_num_words = 0 194 | 195 | torch.save(model.state_dict(), args.save_path) 196 | 197 | # eval on all lengths 198 | if args.eval_all: 199 | test_sents, _ = read_conll(args.test_file) 200 | test_deps = [sent["head"] for sent in test_sents] 201 | test_emb = sents_to_vec(word_vec, test_sents) 202 | print("start evaluating on all lengths") 203 | with torch.no_grad(): 204 | directed, undirected = model.test(test_deps, test_emb, eval_all=True) 205 | 
print('accuracy on all lengths: number of trees:%d, undir: %2.1f, dir: %2.1f' \ 206 | % (len(test_deps), 100 * undirected, 100 * directed)) 207 | 208 | 209 | if __name__ == '__main__': 210 | parse_args = init_config() 211 | main(parse_args) 212 | -------------------------------------------------------------------------------- /nlp_commons/dep/conll.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # conll.py: Classes to read CoNLL 2006 and 2007 corpora. 6 | # TODO: test projectiveness and project the dependency trees. 7 | 8 | import nltk 9 | from nltk.corpus.reader import dependency 10 | from nltk import tree 11 | from nltk import corpus 12 | 13 | from dep import depgraph 14 | from dep import depset 15 | import treebank 16 | 17 | class CoNLLTreebank(treebank.Treebank): 18 | def __init__(self, corpus, files=None, max_length=None): 19 | treebank.Treebank.__init__(self) 20 | self.corpus = corpus 21 | self.trees = [] 22 | #print is_punctuation 23 | i = 0 24 | non_projectable, empty = 0, 0 25 | non_leaf = [] 26 | for d in self.corpus.parsed_sents(files): 27 | # print "Voy por la ", i 28 | d2 = depgraph.DepGraph(d) 29 | try: 30 | d2.remove_leaves(type(self).is_punctuation) 31 | t = d2.constree() 32 | except Exception as e: 33 | msg = e[0] 34 | if msg.startswith('Non-projectable'): 35 | non_projectable += 1 36 | else: 37 | non_leaf += [i] 38 | else: 39 | s = t.leaves() 40 | if s != [] and (max_length is None or len(s) <= max_length): 41 | t.corpus_index = i 42 | t.depset = depset.from_depgraph(d2) 43 | self.trees += [t] 44 | else: 45 | empty += 1 46 | i += 1 47 | self.non_projectable = non_projectable 48 | self.empty = empty 49 | self.non_leaf = non_leaf 50 | 51 | @staticmethod 52 | def is_punctuation(n): 53 | # n['tag'] is the fifth column. 54 | return False 55 | 56 | 57 | class CoNLL06Treebank(CoNLLTreebank): 58 | def __init__(self, root, max_length=None, files=None): 59 | if files is None: 60 | files = self.files 61 | corpus = dependency.DependencyCorpusReader(nltk.data.find('corpora/conll06/data/'+root), files) 62 | CoNLLTreebank.__init__(self, corpus, None, max_length) 63 | 64 | 65 | class German(CoNLL06Treebank): 66 | root = 'german/tiger/' 67 | files = ['train/german_tiger_train.conll', \ 68 | 'test/german_tiger_test.conll'] 69 | 70 | def __init__(self, max_length=None, files=None): 71 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 72 | 73 | @staticmethod 74 | def is_punctuation(n): 75 | # n['tag'] is the fifth column. 76 | return n['tag'][0] == '$' 77 | 78 | 79 | class Turkish(CoNLL06Treebank): 80 | root = 'turkish/metu_sabanci/' 81 | files = ['train/turkish_metu_sabanci_train.conll', \ 82 | 'test/turkish_metu_sabanci_test.conll'] 83 | 84 | def __init__(self, max_length=None, files=None): 85 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 86 | 87 | @staticmethod 88 | def is_punctuation(n): 89 | # n['tag'] is the fifth column. 90 | return n['tag'] == 'Punc' 91 | 92 | 93 | class Danish(CoNLL06Treebank): 94 | root = 'danish/ddt/' 95 | files = ['train/danish_ddt_train.conll', 'test/danish_ddt_test.conll'] 96 | 97 | def __init__(self, max_length=None, files=None): 98 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 99 | 100 | @staticmethod 101 | def is_punctuation(n): 102 | # n['tag'] is the fifth column.
103 | return n['tag'] == 'XP' 104 | 105 | 106 | class Swedish(CoNLL06Treebank): 107 | root = 'swedish/talbanken05/' 108 | files = ['train/swedish_talbanken05_train.conll', 'test/swedish_talbanken05_test.conll'] 109 | 110 | def __init__(self, max_length=None, files=None): 111 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 112 | 113 | @staticmethod 114 | def is_punctuation(n): 115 | # n['tag'] is the fifth column. 116 | return n['tag'] == 'IP' 117 | 118 | 119 | class Portuguese(CoNLL06Treebank): 120 | root = 'portuguese/bosque/' 121 | files = ['treebank/portuguese_bosque_train.conll', 'test/portuguese_bosque_test.conll'] 122 | 123 | def __init__(self, max_length=None, files=None): 124 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 125 | 126 | @staticmethod 127 | def is_punctuation(n): 128 | # n['tag'] is the fifth column. 129 | return n['tag'] == 'punc' 130 | 131 | 132 | class Arabic(CoNLL06Treebank): 133 | root = 'arabic/PADT/' 134 | files = ['train/arabic.train', 'treebank/arabic_PADT_test.conll'] 135 | 136 | def __init__(self, max_length=None, files=None): 137 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 138 | 139 | @staticmethod 140 | def is_punctuation(n): 141 | # n['tag'] is the fifth column. 142 | return n['tag'] == 'G' 143 | 144 | 145 | class Bulgarian(CoNLL06Treebank): 146 | root = 'bulgarian/bultreebank/' 147 | files = ['train/bulgarian_bultreebank_train.conll', 'test/bulgarian_bultreebank_test.conll'] 148 | 149 | def __init__(self, max_length=None, files=None): 150 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 151 | 152 | @staticmethod 153 | def is_punctuation(n): 154 | # n['tag'] is the fifth column. 155 | return n['tag'] == 'Punct' 156 | 157 | 158 | class Chinese(CoNLL06Treebank): 159 | root = 'chinese/sinica/' 160 | files = ['train/chinese_sinica_train.conll', 'test/chinese_sinica_test.conll'] 161 | 162 | def __init__(self, max_length=None, files=None): 163 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 164 | 165 | @staticmethod 166 | def is_punctuation(n): 167 | # n['tag'] is the fifth column. 168 | return False 169 | 170 | 171 | class Czech(CoNLL06Treebank): 172 | root = 'czech/pdt/' 173 | files = ['train/czech.train', 'treebank/czech_pdt_test.conll'] 174 | 175 | def __init__(self, max_length=None, files=None): 176 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 177 | 178 | @staticmethod 179 | def is_punctuation(n): 180 | # see http://ufal.mff.cuni.cz/pdt2.0/doc/manuals/en/m-layer/html/ch02s02s01.html 181 | # n['tag'] is the fifth column. 182 | return n['tag'] == ':' 183 | 184 | 185 | class Dutch(CoNLL06Treebank): 186 | root = 'dutch/alpino/' 187 | files = ['train/dutch_alpino_train.conll', 'test/dutch_alpino_test.conll'] 188 | 189 | def __init__(self, max_length=None, files=None): 190 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 191 | 192 | @staticmethod 193 | def is_punctuation(n): 194 | # n['tag'] is the fifth column. 195 | return n['tag'] == 'Punc' 196 | 197 | 198 | class Japanese(CoNLL06Treebank): 199 | root = 'japanese/verbmobil/' 200 | files = ['train/japanese_verbmobil_train.conll', 'test/japanese_verbmobil_test.conll'] 201 | 202 | def __init__(self, max_length=None, files=None): 203 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 204 | 205 | @staticmethod 206 | def is_punctuation(n): 207 | # n['tag'] is the fifth column. 208 | return n['tag'] == '.' 
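# A minimal sketch (kept commented out) of how one more CoNLL 2006 treebank
# could be plugged in, mirroring the language classes in this module: a `root`
# path, a `files` pair, and a tag-based `is_punctuation` predicate. The
# language, directory, file names and punctuation tag below are hypothetical
# placeholders, not a corpus that ships with this package.
#
# class SomeLanguage(CoNLL06Treebank):
#     root = 'some_language/some_treebank/'
#     files = ['train/some_language_train.conll',
#              'test/some_language_test.conll']
#
#     def __init__(self, max_length=None, files=None):
#         CoNLL06Treebank.__init__(self, self.root, max_length, files)
#
#     @staticmethod
#     def is_punctuation(n):
#         # n['tag'] is the fifth column.
#         return n['tag'] == 'PUNC'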
209 | 210 | 211 | class Slovene(CoNLL06Treebank): 212 | root = 'slovene/sdt/' 213 | files = ['treebank/slovene_sdt_train.conll', 'test/slovene_sdt_test.conll'] 214 | 215 | def __init__(self, max_length=None, files=None): 216 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 217 | 218 | @staticmethod 219 | def is_punctuation(n): 220 | # n['tag'] is the fifth column. 221 | return n['tag'] == 'PUNC' 222 | 223 | 224 | class Spanish(CoNLL06Treebank): 225 | root = 'spanish/cast3lb/' 226 | files = ['train/spanish_cast3lb_train.conll', 'test/spanish_cast3lb_test.conll'] 227 | 228 | def __init__(self, max_length=None, files=None): 229 | CoNLL06Treebank.__init__(self, self.root, max_length, files) 230 | 231 | @staticmethod 232 | def is_punctuation(n): 233 | # n['tag'] is the fifth column. 234 | return n['tag'][0] == 'F' 235 | 236 | 237 | class Catalan(CoNLLTreebank): 238 | def __init__(self): 239 | CoNLLTreebank.__init__(self, corpus.conll2007, ['cat.test', 'cat.train']) 240 | 241 | @staticmethod 242 | def is_punctuation(n): 243 | return n['tag'].lower()[0] == 'f' 244 | 245 | 246 | class Basque(CoNLLTreebank): 247 | def __init__(self): 248 | CoNLLTreebank.__init__(self, corpus.conll2007, ['eus.test', 'eus.train']) 249 | 250 | @staticmethod 251 | def is_punctuation(n): 252 | return n['tag'] == 'PUNT' 253 | 254 | 255 | def stats(): 256 | cls = [German, Turkish, Danish, Swedish, Portuguese, Arabic, Bulgarian, \ 257 | Chinese, Czech, Dutch, Japanese, Slovene, Spanish] 258 | for c in cls: 259 | tb = c(max_length=10) 260 | print c, len(tb.trees) 261 | -------------------------------------------------------------------------------- /nlp_commons/wsj10.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | import itertools 6 | import nltk 7 | 8 | from nltk.util import LazyMap 9 | 10 | from . import wsj 11 | from . 
import util 12 | 13 | 14 | class WSJn(wsj.WSJ): 15 | 16 | def __init__(self, n, basedir=None, load=True, extra_tags=None): 17 | wsj.WSJ.__init__(self, basedir) 18 | self.n = n 19 | self.filename = '%s.treebank' % basedir 20 | self.extra_tags = extra_tags 21 | if load: 22 | self.get_trees() 23 | 24 | def _generate_trees(self): 25 | # trees = util.load_obj(self.filename + '_gold_len10') 26 | trees = None 27 | 28 | if trees is None: 29 | print("Parsing treebank...") 30 | 31 | f = lambda t: len(t.leaves()) <= self.n 32 | m = lambda t: self._prepare(t) 33 | # f = lambda t: t 34 | # m = lambda t: t 35 | trees = [t for t in filter(f, map(m, self.parsed()))] 36 | 37 | # util.save_obj(trees, self.filename + '_gold_len10') 38 | 39 | # w/o filtering 40 | # trees = [t for t in self.parsed()] 41 | 42 | 43 | 44 | # add hmm induced tags 45 | if self.extra_tags is not None: 46 | # new_trees = util.load_obj(self.filename + '_35_len10') 47 | new_trees = None 48 | if new_trees is None: 49 | new_trees = [] 50 | # cnt = 0 51 | for t, tags in zip(self.parsed(), self.extra_tags): 52 | # if cnt % 100 == 0: 53 | # print cnt 54 | # cnt += 1 55 | for st, tag in zip(t.subtrees(lambda x: x.height() == 2), tags): 56 | st.insert(0, tag) 57 | t = self._prepare(t) 58 | if len(t.leaves()) / 2 <= self.n: 59 | new_trees += [t] 60 | # util.save_obj(new_trees, self.filename + '_35_len10') 61 | 62 | 63 | self.gold_tag_sents = [[(sub.label(), sub.leaves()[0]) \ 64 | for sub in t.subtrees(lambda x: x.height() == 2)] \ 65 | for t in trees] 66 | 67 | if self.extra_tags is not None: 68 | self.induce_tag_sents = [[(sub.leaves()[0], sub.leaves()[1]) \ 69 | for sub in t.subtrees(lambda x: x.height() == 2)] \ 70 | for t in new_trees] 71 | else: 72 | self.induce_tag_sents = None 73 | 74 | 75 | return trees 76 | 77 | def _prepare(self, t): 78 | # t.remove_leaves() 79 | # Remove punctuation, ellipsis and currency ($, #) at the same time: 80 | t.filter_tags(lambda x: x in wsj.word_tags) 81 | return t 82 | 83 | def tagged_sents(self): 84 | return self.gold_tag_sents, self.induce_tag_sents 85 | 86 | 87 | class WSJ10(WSJn): 88 | 89 | def __init__(self, basedir=None, load=True, extra_tags=None): 90 | WSJn.__init__(self, 10, basedir, load, extra_tags) 91 | 92 | 93 | class WSJ40(WSJn): 94 | 95 | def __init__(self, basedir=None, load=True): 96 | WSJn.__init__(self, 40, basedir, load) 97 | 98 | 99 | class WSJ10P(wsj.WSJ): 100 | """The 7422 sentences of the WSJ10 treebank but including punctuation. 101 | """ 102 | # antes era puntuacion pero sin el punto final 103 | #valid_tags = wsj.word_tags + wsj.punctuation_tags[1:] 104 | #punctuation_tags = wsj.punctuation_tags[1:] 105 | # pero no da para dejar afuera el punto porque no solo aparece al final (y es tag de ? 
y !): 106 | valid_tags = wsj.word_tags + wsj.punctuation_tags 107 | punctuation_tags = wsj.punctuation_tags 108 | stop_punctuation_tags = [',', '.', ':'] 109 | bracket_punctuation_tag_pairs = [('-LRB-', '-RRB-'), ('``', '\'\'')] 110 | 111 | def __init__(self, basedir=None, load=True): 112 | n = 10 113 | wsj.WSJ.__init__(self, basedir) 114 | self.n = n 115 | self.filename = 'wsj%02ip.treebank' % n 116 | if load: 117 | self.get_trees() 118 | 119 | def _generate_trees(self): 120 | print("Parsing treebank...") 121 | f = lambda t: len([x for x in t.leaves() if x not in self.punctuation_tags]) <= self.n 122 | m = lambda t: self._prepare(t) 123 | trees = [t for t in filter(f, map(m, self.parsed()))] 124 | return trees 125 | 126 | def _prepare(self, t): 127 | t.remove_leaves() 128 | # Con esto elimino ellipsis y $ y # (currency) al mismo tiempo: 129 | t.filter_tags(lambda x: x in self.valid_tags) 130 | return t 131 | 132 | """ 133 | Comparo la version vieja con la nueva: 134 | 135 | >>> from wsj10 import * 136 | >>> tbold = WSJ10P(load=False) 137 | >>> tbold.filename = 'wsj10p.treebank.old' 138 | >>> ts = tbold.get_trees() 139 | >>> tb = WSJ10P() 140 | 141 | # l son los indices de los arboles con hojas distintas 142 | >>> l = [i for i in range(len(tbold.trees)) if tbold.trees[i].leaves() != tb.trees[i].leaves()] 143 | >>> len(l) 144 | 6713 145 | # l2 son los indices de los arboles que cambian algo mas ademas del punto al final. Vemos que quedan pocos, 683. 146 | >>> l2 = [j for j in l if tbold.trees[j].leaves() != tb.trees[j].leaves()[:-1]] 147 | >>> len(l2) 148 | 683 149 | # entre los 683 de l2 hay algunos que agregan punto pero no al final sino un lugar antes (despues se suele cerrar comillas). quitamos estos en l3 para quedarnos con los que realmente hacen alguna diferencia: 150 | >>> l3 = [k for k in l2 if tbold.trees[j].leaves()[:-1] != tb.trees[j].leaves()[:-2]] 151 | >>> len(l3) 152 | 0 153 | # NO HAY NINGUNO! o sea que son masomenos lo mismo los corpus, PERO SOLO PARA EL CASO DEL WSJ10. 154 | """ 155 | 156 | 157 | class WSJnTagged(WSJn): 158 | 159 | def __init__(self, n, basedir=None, load=True): 160 | wsj.WSJ.__init__(self, basedir) 161 | self.n = n 162 | self.filename = 'wsj%02i.tagged_treebank' % n 163 | self.tagger = WSJTagger() 164 | if load: 165 | self.get_trees() 166 | 167 | def _prepare(self, t): 168 | # quito puntuacion, ellipsis y monedas, sin quitar las hojas: 169 | #t.remove_punctuation() 170 | #t.remove_ellipsis() 171 | #t.filter_tags(lambda x: x not in wsj.currency_tags_words) 172 | t.filter_subtrees(lambda t: type(t) == str or len([x for x in t.pos() if x[1] in wsj.word_tags]) > 0) 173 | t.map_leaves(self.tagger.tag) 174 | return t 175 | 176 | 177 | class WSJ10Tagged(WSJnTagged): 178 | 179 | def __init__(self, basedir=None, load=True): 180 | WSJnTagged.__init__(self, 10, basedir, load) 181 | 182 | 183 | class WSJTagger: 184 | 185 | filename = '../obj/clusters.nem.32' 186 | 187 | def __init__(self): 188 | f = open(self.filename) 189 | self.tag_dict = {} 190 | for l in f: 191 | l2 = l.split() 192 | self.tag_dict[l2[0]] = l2[1]+'C' 193 | 194 | def tag(self, word): 195 | return self.tag_dict[word.upper()] 196 | 197 | 198 | """ 199 | Chequeo del corpus (pa ver si saca los mismos arboles que WSJ10): 200 | 201 | >>> from wsj10 import * 202 | >>> tb2 = WSJ10Tagged() 203 | >>> len(tb2.trees) 204 | 7412 205 | # significa que faltan arboles... 
deberian ser 7422 206 | >>> tb = WSJ10() 207 | >>> l = [i for i in range(len(tb2.trees)) if tb.trees[i].labels != tb2.trees[i].labels] 208 | >>> l[0] 209 | 2112 210 | >>> i = l[0] 211 | >>> l[1] 212 | 2113 213 | >>> tb.trees[i].labels 214 | ['07/wsj_0758.mrg', 74] 215 | >>> tb2.trees[i].labels 216 | ['07/wsj_0758.mrg', 75] 217 | 218 | QUE BOSTA, SE USAN COMILLAS SIMPLES CUANDO DEBERIAN SER DOBLES: 219 | 220 | ( (S 221 | (NP-SBJ (PRP You) ) 222 | (VP (MD might) (RB not) 223 | (VP (VB find) 224 | (NP (NN one) ) 225 | (PP-LOC (IN in) 226 | (NP (DT the) (`` `) (NN Jurisprudence) ('' ') (NN column) )))) 227 | (. .) )) 228 | """ 229 | 230 | 231 | class WSJnLex(WSJn): 232 | 233 | def __init__(self, n, load=True): 234 | wsj.WSJ.__init__(self) 235 | self.n = n 236 | self.filename = 'wsj%02i.lex_treebank' % n 237 | self.tagger = WSJTagger() 238 | if load: 239 | self.get_trees() 240 | 241 | def _prepare(self, t): 242 | # quito puntuacion, ellipsis y monedas, sin quitar las hojas: 243 | #t.remove_punctuation() 244 | #t.remove_ellipsis() 245 | #t.filter_tags(lambda x: x not in wsj.currency_tags_words) 246 | t.filter_subtrees(lambda t: type(t) == str or len([x for x in t.pos() if x[1] in wsj.word_tags]) > 0) 247 | return t 248 | 249 | 250 | class WSJ10Lex(WSJnLex): 251 | def __init__(self, load=True): 252 | WSJnLex.__init__(self, 10, load) 253 | 254 | """ 255 | CREO UN ARCHIVO DE TEXTO CON LAS FRASES DEL WSJ10: 256 | 257 | >>> from wsj10 import * 258 | >>> tb = WSJ10Lex() 259 | >>> for t in tb.trees: 260 | ... s = string.join(t.leaves())+'\n' 261 | ... f.write(s) 262 | ... 263 | >>> f.close() 264 | >>> 265 | 266 | """ 267 | 268 | def test(): 269 | tb = WSJ10() 270 | return tb 271 | -------------------------------------------------------------------------------- /nlp_commons/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # model.py: A general model for parsing (class Model). 6 | # Also a general model for bracketing parsing (class BracketingModel). 7 | 8 | import itertools 9 | import sys 10 | 11 | from . import util 12 | from . import sentence 13 | from . import bracketing 14 | from . 
import wsj10 15 | 16 | class Model: 17 | Gold = [] 18 | Parse = [] 19 | evaluation = None 20 | trained = False 21 | tested = False 22 | evaluated = False 23 | 24 | def train(self): 25 | self.trained = True 26 | 27 | def parse(self, s): 28 | return None 29 | 30 | #def test(self, S): 31 | # self.Parse = [self.parse(s) for s in S] 32 | # self.tested = True 33 | 34 | def test(self, short=False, max_length=None): 35 | self.Parse, self.Weight = [], 0.0 36 | 37 | #n = str(len(self.S)) 38 | #m = len(n) 39 | #o = "%"+str(m)+"d of "+n 40 | #i = 0 41 | #print "Parsed", o % i, 42 | #sys.stdout.flush() 43 | #o = ("\b"*(2*m+5)) + o 44 | p = util.Progress('Parsed', 0, len(self.S)) 45 | for s in self.S: 46 | if max_length is None or len(s) <= max_length: 47 | (parse, weight) = self.parse(s) 48 | else: 49 | (parse, weight) = (None, 0.0) 50 | self.Parse += [parse] 51 | self.Weight += weight 52 | #i += 1 53 | #print o % i, 54 | #sys.stdout.flush() 55 | next(p) 56 | print("\nFinished parsing.") 57 | self.eval(short=short, max_length=max_length) 58 | self.tested = True 59 | 60 | def eval(self, short=False, max_length=None): 61 | self.evaluated = True 62 | 63 | 64 | class BracketingModel(Model): 65 | count_fullspan_bracket = True 66 | count_length_2 = True 67 | count_length_2_1 = False 68 | 69 | def __init__(self, treebank=None, training_corpus=None): 70 | 71 | treebank = self._get_treebank(treebank) 72 | if training_corpus == None: 73 | training_corpus = treebank 74 | self.training_corpus = training_corpus 75 | 76 | S, Gold = [], [] 77 | #for s in treebank.sents(): 78 | for s in treebank.tagged_sents(): 79 | s = [x[1] for x in s] 80 | S += [sentence.Sentence(s)] 81 | 82 | for t in treebank.parsed_sents(): 83 | Gold += [bracketing.tree_to_bracketing(t)] 84 | 85 | self.S = S 86 | self.Gold = Gold 87 | 88 | def _get_treebank(self, treebank=None): 89 | if treebank is None: 90 | treebank = wsj10.WSJ10() 91 | return treebank 92 | 93 | def eval(self, output=True, short=False, long=False, max_length=None): 94 | """Compute precision, recall and F1 between the parsed bracketings and 95 | the gold bracketings. 
96 | """ 97 | Gold = self.Gold 98 | 99 | Prec = 0.0 100 | Rec = 0.0 101 | 102 | # Medidas sumando brackets y despues promediando: 103 | brackets_ok = 0 104 | brackets_parse = 0 105 | brackets_gold = 0 106 | 107 | for i in range(len(Gold)): 108 | l = Gold[i].length 109 | if (max_length is None or l <= max_length) \ 110 | and (self.count_length_2_1 or (self.count_length_2 and l == 2) or l >= 3): 111 | (prec, rec) = self.measures(i) 112 | Prec += prec 113 | Rec += rec 114 | 115 | # Medidas sumando brackets y despues promediando: 116 | (b_ok, b_p, b_g) = self.measures2(i) 117 | brackets_ok += b_ok 118 | brackets_parse += b_p 119 | brackets_gold += b_g 120 | 121 | m = float(len(Gold)) 122 | Prec2 = float(brackets_ok) / float(brackets_parse) 123 | Rec2 = float(brackets_ok) / float(brackets_gold) 124 | F12 = 2*(Prec2*Rec2)/(Prec2+Rec2) 125 | 126 | self.evaluation = (m, Prec2, Rec2, F12) 127 | self.evaluated = True 128 | 129 | if output and not short: 130 | #print "Cantidad de arboles:", int(m) 131 | #print "Medidas sumando todos los brackets:" 132 | #print " Precision: %2.1f" % (100*Prec2) 133 | #print " Recall: %2.1f" % (100*Rec2) 134 | #print " Media harmonica F1: %2.1f" % (100*F12) 135 | #if long: 136 | #print "Brackets parse:", brackets_parse 137 | #print "Brackets gold:", brackets_gold 138 | #print "Brackets ok:", brackets_ok 139 | #Prec = Prec / m 140 | #Rec = Rec / m 141 | #F1 = 2*(Prec*Rec)/(Prec+Rec) 142 | #print "Medidas promediando p y r por frase:" 143 | #print " Precision: %2.1f" % (100*Prec) 144 | #print " Recall: %2.1f" % (100*Rec) 145 | #print " Media harmonica F1: %2.1f" % (100*F1) 146 | print("Sentences:", int(m)) 147 | print("Micro-averaged measures:") 148 | print(" Precision: %2.1f" % (100*Prec2)) 149 | print(" Recall: %2.1f" % (100*Rec2)) 150 | print(" Harmonic mean F1: %2.1f" % (100*F12)) 151 | if int: 152 | print("Brackets parse:", brackets_parse) 153 | print("Brackets gold:", brackets_gold) 154 | print("Brackets ok:", brackets_ok) 155 | Prec = Prec / m 156 | Rec = Rec / m 157 | F1 = 2*(Prec*Rec)/(Prec+Rec) 158 | print("Macro-averaged measures:") 159 | print(" Precision: %2.1f" % (100*Prec)) 160 | print(" Recall: %2.1f" % (100*Rec)) 161 | print(" Harmonic mean F1: %2.1f" % (100*F1)) 162 | elif output and short: 163 | print("F1 =", F12) 164 | 165 | return self.evaluation 166 | 167 | # FIXME: no esta bien adaptado para usar count_fullspan_bracket 168 | # Funcion auxiliar de eval(); 169 | # Precision y recall del i-esimo parse respecto de su gold: 170 | def measures(self, i): 171 | g = self.Gold[i].brackets 172 | if self.Parse[i] is None: 173 | p, n = set(), 0 174 | else: 175 | p = self.Parse[i].brackets 176 | n = float(bracketing.coincidences(self.Gold[i], self.Parse[i])) 177 | 178 | if len(p) > 0: 179 | if self.count_fullspan_bracket: 180 | prec = (n+1) / float(len(p)+1) 181 | else: 182 | prec = n / float(len(p)) 183 | elif len(g) == 0: 184 | prec = 1.0 185 | else: 186 | # XXX: no deberia ser 1? 187 | prec = 0.0 188 | 189 | if len(g) > 0: 190 | if self.count_fullspan_bracket: 191 | rec = (n+1) / float(len(g)+1) 192 | else: 193 | rec = n / float(len(g)) 194 | else: 195 | rec = 1.0 196 | 197 | return (prec, rec) 198 | 199 | # FIXME: hacer andar con frases de largo 1! 200 | # devuelve la terna (brackets_ok, brackets_parse, brackets_gold) 201 | # del i-esimo arbol. Se usa para calcular las medidas 202 | # micro-promediadas. 
203 | def measures2(self, i): 204 | g = self.Gold[i].brackets 205 | if self.Parse[i] is None: 206 | p, n = set(), 0 207 | else: 208 | p = self.Parse[i].brackets 209 | n = float(bracketing.coincidences(self.Gold[i], self.Parse[i])) 210 | if self.count_fullspan_bracket: 211 | return (n+1, len(p)+1, len(g)+1) 212 | else: 213 | return (n, len(p), len(g)) 214 | 215 | # FIXME: pegado asi nomas: adaptar esto para usar measures. 216 | def eval_by_length(self): 217 | #Prec = {} 218 | #Rec = {} 219 | Gold = self.Gold 220 | Parse = self.Parse 221 | 222 | brackets_ok = {} 223 | brackets_parse = {} 224 | brackets_gold = {} 225 | 226 | for i in range(2, 11): 227 | brackets_ok[i] = 0 228 | brackets_parse[i] = 0 229 | brackets_gold[i] = 0 230 | 231 | for gb, pb in zip(Gold, Parse): 232 | gb.set_start_index(0) 233 | pb.set_start_index(0) 234 | l = gb.length 235 | for i in range(2, l): 236 | g = set([x_y for x_y in gb.brackets if x_y[1]-x_y[0] == i]) 237 | p = set([x_y1 for x_y1 in pb.brackets if x_y1[1]-x_y1[0] == i]) 238 | 239 | brackets_ok[i] += len(g & p) 240 | brackets_parse[i] += len(p) 241 | brackets_gold[i] += len(g) 242 | if self.count_fullspan_bracket and ((self.count_length_2 and l == 2) or l >= 3): 243 | brackets_ok[l] += 1 244 | brackets_parse[l] += 1 245 | brackets_gold[l] += 1 246 | 247 | Prec = {} 248 | Rec = {} 249 | F1 = {} 250 | print("i\tP\tR\tF1") 251 | for i in range(2, 10): 252 | Prec[i] = float(brackets_ok[i]) / float(brackets_parse[i]) 253 | Rec[i] = float(brackets_ok[i]) / float(brackets_gold[i]) 254 | F1[i] = 2*(Prec[i]*Rec[i])/(Prec[i]+Rec[i]) 255 | print("%i\t%2.2f\t%2.2f\t%2.2f" % (i, 100*Prec[i], 100*Rec[i], 100*F1[i])) 256 | 257 | return (Prec, Rec, F1) 258 | -------------------------------------------------------------------------------- /nlp_commons/dep/dnegra.py: -------------------------------------------------------------------------------- 1 | # dnegra.py: Dependency trees of the NEGRA corpus. 2 | 3 | from nltk import tree 4 | 5 | import treebank 6 | from dep import depset 7 | 8 | class Negra10(treebank.SavedTreebank): 9 | default_basedir = 'negra-corpus' 10 | trees = [] 11 | filename = 'negra10.deptreebank' 12 | 13 | def __init__(self, basedir=None, load=True): 14 | if basedir == None: 15 | basedir = self.default_basedir 16 | self.basedir = basedir 17 | if load: 18 | self.get_trees() 19 | 20 | def parsed(self): 21 | f = open(self.basedir+'/negra-corpus.export') 22 | self.f = f 23 | 24 | # go to first sentece 25 | s = f.readline() 26 | while not s.startswith('#BOS'): 27 | s = f.readline() 28 | 29 | while s != '': 30 | l = s.split() 31 | (num, origin) = (int(l[1]), int(l[4])) 32 | sent = [] 33 | l = f.readline().split() 34 | while l[0][0] != '#': 35 | #if l[4] != '0': 36 | if not l[1].startswith('$'): 37 | sent += [l] 38 | l = f.readline().split() 39 | 40 | parse = [] 41 | while l[0] != '#EOS': 42 | parse += [l] 43 | l = f.readline().split() 44 | 45 | if len(sent) > 0 and len(sent) <= 10: 46 | self.sent = sent 47 | self.parse = parse 48 | t = build_tree(sent, parse) 49 | t2 = treebank.Tree(t, (num, origin)) 50 | t2.depset = tree_to_depset(t) 51 | yield t2 52 | 53 | s = f.readline() 54 | 55 | 56 | def build_tree(sent, parse): 57 | entries = dict((l[0], l) for l in parse) 58 | # for sentences that have several roots (e.g. 
#BOS 77 3 863208763 1): 59 | entries['#0'] = ['#0', 'ROOT', '--', '--', ''] 60 | 61 | # add indexed lexical entries: 62 | for i in range(len(sent)): 63 | entries[i] = sent[i] 64 | 65 | return _build_tree(entries, '#0') 66 | 67 | 68 | def _build_tree(entries, root): 69 | """Helper for build_tree. (Fue un dolor de huevos.) 70 | """ 71 | entry = entries[root] 72 | if isinstance(root, int): 73 | t = tree.Tree(entry[1], [entry[0]]) 74 | t.head = 0 75 | t.start_index = root 76 | t.edge = entry[3] 77 | return t 78 | else: 79 | root = root[1:] 80 | subtrees = [] 81 | for (word, l) in entries.iteritems(): 82 | # parent = l[4] 83 | if l[4] == root: 84 | subtree = _build_tree(entries, word) 85 | subtrees += [subtree] 86 | subtrees = sorted(subtrees, key=lambda t: t.start_index) 87 | t = tree.Tree(entry[1], subtrees) 88 | t.start_index = subtrees[0].start_index 89 | t.edge = entry[3] 90 | 91 | # head-finding from http://maltparser.org/userguide.html: 92 | (dir, plist) = head_rules['CAT:'+t.node] 93 | plist = plist + [('LEXICAL', '')] 94 | if dir == 'r': 95 | # we will reverse again later. 96 | subtrees.reverse() 97 | found = False 98 | i = 0 99 | while i < len(plist) and not found: 100 | (type, val) = plist[i] 101 | j = 0 102 | while j < len(subtrees) and not found: 103 | subtree = subtrees[j] 104 | if (type == 'LABEL' and subtree.edge == val) or \ 105 | (type == 'CAT' and subtree.node.split('[')[0] == val) or \ 106 | (type == 'LEXICAL' and isinstance(subtree[0], str)): 107 | head_st = subtree 108 | found = True 109 | j += 1 110 | i += 1 111 | if not found: 112 | head_st = subtrees[0] 113 | if dir == 'r': 114 | subtrees.reverse() 115 | 116 | #if t.node == 'ROOT': 117 | # print dir, plist, subtrees, head_st 118 | 119 | # mark head: 120 | t.head = subtrees.index(head_st) 121 | t.node += '['+subtrees[t.head].node.split('[')[0]+']' 122 | 123 | return t 124 | 125 | 126 | head_rules = \ 127 | {'CAT:ROOT': ('l', []), \ 128 | 'CAT:AA': ('r', [('LABEL', 'HD')]), \ 129 | 'CAT:AP': ('r', [('LABEL', 'HD')]), \ 130 | 'CAT:AVP': ('r', [('LABEL', 'HD'), ('CAT', 'AVP')]), \ 131 | 'CAT:CAC': ('l', [('LABEL', 'CJ')]), \ 132 | 'CAT:CAP': ('l', [('LABEL', 'CJ')]), \ 133 | 'CAT:CAVP': ('l', [('LABEL', 'CJ')]), \ 134 | 'CAT:CCP': ('l', [('LABEL', 'CJ')]), \ 135 | 'CAT:CH': ('l', []), \ 136 | 'CAT:CNP': ('l', [('LABEL', 'CJ')]), \ 137 | 'CAT:CO': ('l', [('LABEL', 'CJ')]), \ 138 | 'CAT:CPP': ('l', [('LABEL', 'CJ')]), \ 139 | 'CAT:CS': ('l', [('LABEL', 'CJ')]), \ 140 | 'CAT:CVP': ('l', [('LABEL', 'CJ')]), \ 141 | 'CAT:CVZ': ('l', [('LABEL', 'CJ')]), \ 142 | 'CAT:DL': ('l', [('LABEL', 'DH')]), \ 143 | 'CAT:ISU': ('l', []), \ 144 | 'CAT:NM': ('r', []), \ 145 | 'CAT:NP': ('r', [('LABEL', 'NK')]), \ 146 | # Malt says 'PN' (why?) 147 | 'CAT:MPN': ('l', []), \ 148 | 'CAT:PP': ('r', [('LABEL', 'NK')]), \ 149 | 'CAT:S': ('r', [('LABEL', 'HD')]), \ 150 | 'CAT:VP': ('r', [('LABEL', 'HD')]), \ 151 | 'CAT:VROOT': ('l', []), \ 152 | # missing rules: 153 | # e.g. BOS 507: 154 | 'CAT:VZ': ('l', []), \ 155 | # e.g. BOS 5576: 156 | 'CAT:MTA': ('l', []) \ 157 | } 158 | 159 | 160 | def tree_to_depset(t): 161 | """Returns the DepSet associated to the partially head marked tree t. 162 | """ 163 | (res, head) = _tree_to_depset(t) 164 | if head != -1: 165 | res.append((head, -1)) 166 | return depset.DepSet(len(t.leaves()), sorted(res)) 167 | 168 | 169 | def _tree_to_depset(t): 170 | """Helper for tree_to_depset. (Fue un dolor de huevos.) 
171 | """ 172 | #if isinstance(t, str): 173 | # return ([], []) 174 | if isinstance(t[0], str): 175 | return ([], t.start_index) 176 | else: 177 | depset = [] 178 | heads = [] 179 | for st in t: 180 | (d, h) = _tree_to_depset(st) 181 | depset += d 182 | heads += [h] 183 | if t.head != -1: 184 | # resolve all unresolved dependencies: 185 | new_head = heads[t.head] 186 | new_depset = [(i, (j==-1 and new_head) or j) for (i, j) in depset] 187 | else: 188 | # propagate unresolved dependencies 189 | new_head = -1 190 | new_depset = depset 191 | new_depset += [(j, new_head) for j in heads if j != -1 and j != new_head] 192 | return (new_depset, new_head) 193 | 194 | 195 | def build_tree2(sent, parse): 196 | """Iterative version of build_tree. Maybe faster, but uglyer. 197 | """ 198 | dparse = dict((l[0][1:], l) for l in parse) 199 | # for sentences that have several roots (e.g. #BOS 77 3 863208763 1): 200 | dparse['0'] = ['#0', 'ROOT', '--', '--', ''] 201 | 202 | ltree = [] 203 | for i in range(len(sent)): 204 | l = sent[i] 205 | t = tree.Tree(l[1], [l[0]]) 206 | t.head = 0 207 | # to be used in tree_to_depset: 208 | t.start_index = i 209 | ltree += [(t, l[3], l[4])] 210 | 211 | # (XXX: not sure about the condition) 212 | while len(ltree) > 1: 213 | ids = set(dparse.keys()) - set(l[4] for l in dparse.itervalues()) 214 | new_ltree = [] 215 | last_id = -1 216 | for (t, edge, id) in ltree: 217 | if id in ids: 218 | if last_id != id: 219 | new_t = tree.Tree(dparse[id][1], [t]) 220 | new_t.head = -1 221 | new_ltree += [(new_t, dparse[id][3], dparse[id][4])] 222 | else: 223 | new_t = new_ltree[-1][0] 224 | new_t.append(t) 225 | # basic head-finding: 226 | if edge == 'HD': 227 | new_t.head = len(new_t)-1 228 | new_t.node += '['+t.node+']' 229 | """# head-finding from http://maltparser.org/userguide.html: 230 | # (section "Phrase structure parsing") 231 | elif hasattr(t, 'start_index') and new_t.head == -1: 232 | # "hasattr" says that t is a lexical item. 233 | new_t.head = len(new_t) - 1 234 | new_t.node += '['+t.node+']'""" 235 | else: 236 | new_ltree += [(t, edge, id)] 237 | last_id = id 238 | 239 | for id in ids: 240 | del dparse[id] 241 | ltree = new_ltree 242 | 243 | result = ltree[0][0] 244 | result.depset = depset 245 | 246 | return result 247 | 248 | 249 | """ 250 | head_rules = { \ 251 | 'CAT:AA': ('r', 'r[LABEL:HD]'), \ 252 | 'CAT:AP': ('r', 'r[LABEL:HD]'), \ 253 | 'CAT:AVP': ('r', 'r[LABEL:HD CAT:AVP]'), \ 254 | 'CAT:CAC': ('l', 'l[LABEL:CJ]'), \ 255 | 'CAT:CAP': ('l', 'l[LABEL:CJ]'), \ 256 | 'CAT:CAVP':('l', 'l[LABEL:CJ]'), \ 257 | 'CAT:CH': ('l', '*'), \ 258 | 'CAT:CNP': ('l', 'l[LABEL:CJ]'), \ 259 | 'CAT:CO': ('l', 'l[LABEL:CJ]'), \ 260 | 'CAT:CPP': ('l', 'l[LABEL:CJ]'), \ 261 | 'CAT:CS': ('l', 'l[LABEL:CJ]'), \ 262 | 'CAT:CVP': ('l', 'l[LABEL:CJ]'), \ 263 | 'CAT:CCP': ('l', 'l[LABEL:CJ]'), \ 264 | 'CAT:CVZ': ('l', 'l[LABEL:CJ]'), \ 265 | 'CAT:DL': ('l', 'l[LABEL:DH]'), \ 266 | 'CAT:ISU': ('l', '*'), \ 267 | 'CAT:NM': ('r', '*'), \ 268 | 'CAT:NP': ('r', 'r[LABEL:NK]'), \ 269 | 'CAT:PN': ('l', '*'), \ 270 | 'CAT:PP': ('r', 'r[LABEL:NK]'), \ 271 | 'CAT:S': ('r', 'r[LABEL:HD]'), \ 272 | 'CAT:VROOT':('l', '*'), \ 273 | 'CAT:VP': ('r', 'r[LABEL:HD]') } 274 | 275 | >>> for k,v in head_rules.iteritems(): 276 | ... (dir,plist)=v 277 | ... plist = plist[2:-1].split() 278 | ... plist = [(s.split(':')[0], s.split(':')[1]) for s in plist] 279 | ... head_rules[k] = (dir,plist) 280 | ... 
281 | >>> from pprint import pprint 282 | >>> pprint(head_rules) 283 | """ 284 | -------------------------------------------------------------------------------- /nlp_commons/cast3lb.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # -*- coding: iso-8859-1 -*- 6 | # Creado en base al modulo nltk_lite.corpora.treebank. 7 | 8 | # Natural Language Toolkit: Penn Treebank Reader 9 | # 10 | # Copyright (C) 2001-2005 University of Pennsylvania 11 | # Author: Steven Bird 12 | # Edward Loper 13 | # URL: 14 | # For license information, see LICENSE.TXT 15 | 16 | # from nltk_lite.corpora import get_basedir 17 | 18 | import itertools 19 | 20 | from nltk import tree 21 | 22 | from . import treebank 23 | 24 | # Funciona para el Cast3LB antes y despues de quitar las hojas, 25 | # y antes y despues de eliminar las funciones. 26 | def is_ellipsis(s): 27 | # 'sn.co' aparece como tag de una elipsis en '204_c-3.tbf', 11, 28 | # 'sn' en 'a12-4.tbf', 1 y en 'a14-0.tbf', 2. 29 | return s == '*0*' or \ 30 | s.split('-')[0] in ['sn.e', 'sn.e.1', 'sn.co', 'sn'] 31 | 32 | 33 | # Funciona para el Cast3LB solo si las hojas son POS tags. 34 | def is_punctuation(s): 35 | return s.lower()[0] == 'f' 36 | 37 | 38 | class Cast3LBTree(treebank.Tree): 39 | 40 | 41 | # Funciona para el Cast3LB antes y despues de quitar las hojas, 42 | # y antes y despues de eliminar las funciones. 43 | def is_ellipsis(self, s): 44 | return is_ellipsis(s) 45 | 46 | 47 | # Funciona para el Cast3LB solo si las hojas son POS tags. 48 | def is_punctuation(self, s): 49 | return is_punctuation(s) 50 | 51 | 52 | class Cast3LB(treebank.SavedTreebank): 53 | default_basedir = "3lb-cast" 54 | trees = [] 55 | filename = 'cast3lb.treebank' 56 | 57 | 58 | def __init__(self, basedir=None, load=False): 59 | if basedir == None: 60 | self.basedir = self.default_basedir 61 | else: 62 | self.basedir = basedir 63 | if load: 64 | self.get_trees() 65 | 66 | 67 | # Devuelve el arbol que se encuentra en la posicion offset de los archivos 68 | # files del treebank Cast3LB. Sin parametros devuelve un arbol cualquiera. 69 | # files puede ser un nombre de archivo o una lista de nombres de archivo. 70 | def get_tree(self, files=None, offset=0): 71 | # Parsear files y parar cuando se llegue al item offset+1. 72 | #t = [t for t in itertools.islice(parsed(files),offset+1)][offset] 73 | #if preprocess: 74 | # t = prepare(t) 75 | #return t 76 | t = self.get_trees2(files, offset, offset+1)[0] 77 | return t 78 | 79 | 80 | # Devuelve los arboles que se encuentran en la posicion i con start <= i < end 81 | # dentro de los archivo files del treebank Cast3LB. 82 | # files puede ser un nombre de archivo o una lista de nombres de archivo. 83 | def get_trees2(self, files=None, start=0, end=None): 84 | lt = [t for t in itertools.islice(self.parsed(files), start, end)] 85 | return lt 86 | 87 | 88 | """# puede ser reemplazado en las subclases para filtrar: 89 | # FIXME: capaz que get_trees2 hace lo mismo y esto es al pedo: 90 | def _generate_trees(self): 91 | print "Parseando el Cast3LB treebank..." 
92 | trees = [self._prepare(t) for t in self.parsed()] 93 | return trees 94 | 95 | 96 | # para ser reemplazado en las subclases: 97 | def _prepare(self, t): 98 | return t""" 99 | 100 | 101 | def remove_ellipsis(self): 102 | list(map(lambda t: t.remove_ellipsis(), self.trees)) 103 | 104 | 105 | def remove_punctuation(self): 106 | list(map(lambda t: t.remove_punctuation(), self.trees)) 107 | 108 | 109 | def parsed(self, files=None): 110 | for t in treebank.SavedTreebank.parsed(self, files): 111 | yield Cast3LBTree(tree.Tree('ROOT', [t]), t.labels) 112 | 113 | 114 | # Funciona para el Cast3LB antes y despues de quitar las hojas, 115 | # y antes y despues de eliminar las funciones. 116 | def is_ellipsis(self, s): 117 | return is_ellipsis(s) 118 | 119 | 120 | # Funciona para el Cast3LB solo si las hojas son POS tags. 121 | def is_punctuation(self, s): 122 | return is_punctuation(s) 123 | 124 | 125 | """# Devuelve el treebank Cast3LB entero. 126 | def get_treebank(): 127 | cast3lb_treebank = treebank.load_treebank('cast3lb.treebank') 128 | if cast3lb_treebank is None: 129 | return cast3lb_treebank 130 | 131 | # Devuelve los datos de entrenamiento del Cast3LB. 132 | def get_training_treebank(): 133 | training_treebank = treebank.load_treebank('cast3lb_training.treebank') 134 | if training_treebank is None: 135 | print "Parseando datos de entrenamiento del Cast3LB treebank..." 136 | training_files = get_training_files() 137 | trees = [prepare(t) for t in parsed(training_files)] 138 | training_treebank = treebank.Treebank(trees) 139 | training_treebank.save('cast3lb_training.treebank') 140 | return training_treebank 141 | 142 | # Devuelve los datos de testeo del Cast3LB. 143 | def get_test_treebank(): 144 | test_treebank = treebank.load_treebank('cast3lb_test.treebank') 145 | if test_treebank is None: 146 | print "Parseando datos de testeo del Cast3LB treebank..." 147 | test_files = get_test_files() 148 | trees = [prepare(t) for t in parsed(test_files)] 149 | test_treebank = treebank.Treebank(trees) 150 | # Ordena de menor a mayor largo de oracion. 151 | test_treebank.length_sort() 152 | test_treebank.save('cast3lb_test.treebank') 153 | return test_treebank 154 | """ 155 | 156 | """def get_files(filename): 157 | f = open(filename, 'r') 158 | s = f.read() 159 | return s.split() 160 | 161 | 162 | def get_training_files(filename='training.txt'): 163 | return get_files(filename) 164 | 165 | 166 | def get_test_files(filename='test.txt'): 167 | return get_files(filename) 168 | """ 169 | 170 | """ 171 | def prepare(t): 172 | "" 173 | Prepara un arbol obtenido del Cast3LB para ser usado para crear un 174 | PCFG. 175 | 176 | @param t: el arbol 177 | "" 178 | t.remove_leaves() 179 | return Cast3LBTree(tree.Tree('ROOT', [t]), t.labels) 180 | """ 181 | 182 | 183 | """def filter_nodes(t, f): 184 | if not isinstance(t, tree.Tree): 185 | return t 186 | 187 | subtrees = [] 188 | for st in t: 189 | if (isinstance(st, tree.Tree) and f(st.node)) or \ 190 | (not isinstance(st, tree.Tree) and f(st)): 191 | st = filter_nodes(st, f) 192 | subtrees += [st] 193 | return tree.Tree(t.node, subtrees)""" 194 | 195 | """ 196 | Raw: 197 | 198 | Pierre Vinken, 61 years old, will join the board as a nonexecutive 199 | director Nov. 29. 200 | 201 | Tagged: 202 | 203 | Pierre/NNP Vinken/NNP ,/, 61/CD years/NNS old/JJ ,/, will/MD join/VB 204 | the/DT board/NN as/IN a/DT nonexecutive/JJ director/NN Nov./NNP 29/CD ./. 
205 | 206 | NP-Chunked: 207 | 208 | [ Pierre/NNP Vinken/NNP ] 209 | ,/, 210 | [ 61/CD years/NNS ] 211 | old/JJ ,/, will/MD join/VB 212 | [ the/DT board/NN ] 213 | as/IN 214 | [ a/DT nonexecutive/JJ director/NN Nov./NNP 29/CD ] 215 | ./. 216 | 217 | Parsed: 218 | 219 | ( (S 220 | (NP-SBJ 221 | (NP (NNP Pierre) (NNP Vinken) ) 222 | (, ,) 223 | (ADJP 224 | (NP (CD 61) (NNS years) ) 225 | (JJ old) ) 226 | (, ,) ) 227 | (VP (MD will) 228 | (VP (VB join) 229 | (NP (DT the) (NN board) ) 230 | (PP-CLR (IN as) 231 | (NP (DT a) (JJ nonexecutive) (NN director) )) 232 | (NP-TMP (NNP Nov.) (CD 29) ))) 233 | (. .) )) 234 | """ 235 | 236 | 237 | """def chunked(files = 'chunked'): 238 | "" 239 | @param files: One or more treebank files to be processed 240 | @type files: L{string} or L{tuple(string)} 241 | @rtype: iterator over L{tree} 242 | "" 243 | 244 | # Just one file to process? If so convert to a tuple so we can iterate 245 | if isinstance(files, str): 246 | files = (files,) 247 | 248 | for file in files: 249 | path = os.path.join(get_basedir(), "treebank", file) 250 | s = open(path).read() 251 | for t in tokenize.blankline(s): 252 | yield tree.chunk(t) 253 | 254 | 255 | def tagged(files = 'chunked'): 256 | "" 257 | @param files: One or more treebank files to be processed 258 | @type files: L{string} or L{tuple(string)} 259 | @rtype: iterator over L{list(tuple)} 260 | "" 261 | 262 | # Just one file to process? If so convert to a tuple so we can iterate 263 | if isinstance(files, str): 264 | files = (files,) 265 | 266 | for file in files: 267 | path = os.path.join(get_basedir(), "treebank", file) 268 | f = open(path).read() 269 | for sent in tokenize.blankline(f): 270 | l = [] 271 | for t in tokenize.whitespace(sent): 272 | if (t != '[' and t != ']'): 273 | l.append(tag2tuple(t)) 274 | yield l 275 | 276 | def raw(files = 'raw'): 277 | "" 278 | @param files: One or more treebank files to be processed 279 | @type files: L{string} or L{tuple(string)} 280 | @rtype: iterator over L{list(string)} 281 | "" 282 | 283 | # Just one file to process? If so convert to a tuple so we can iterate 284 | if isinstance(files, str): 285 | files = (files,) 286 | 287 | for file in files: 288 | path = os.path.join(get_basedir(), "treebank", file) 289 | f = open(path).read() 290 | for sent in tokenize.blankline(f): 291 | l = [] 292 | for t in tokenize.whitespace(sent): 293 | l.append(t) 294 | yield l 295 | 296 | 297 | def demo(): 298 | from nltk_lite.corpora import treebank 299 | 300 | print "Parsed:" 301 | for tree in itertools.islice(treebank.parsed(), 3): 302 | print tree.pp() 303 | print 304 | 305 | print "Chunked:" 306 | for tree in itertools.islice(treebank.chunked(), 3): 307 | print tree.pp() 308 | print 309 | 310 | print "Tagged:" 311 | for sent in itertools.islice(treebank.tagged(), 3): 312 | print sent 313 | print 314 | 315 | print "Raw:" 316 | for sent in itertools.islice(treebank.raw(), 3): 317 | print sent 318 | print 319 | 320 | if __name__ == '__main__': 321 | demo() 322 | """ -------------------------------------------------------------------------------- /nlp_commons/bracketing.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2007-2011 Franco M. Luque 2 | # URL: 3 | # For license information, see LICENSE.txt 4 | 5 | # bracketing.py: Bracketing data structure. 6 | 7 | import itertools 8 | import random 9 | import string 10 | import math 11 | 12 | from nltk import tree 13 | 14 | from . 
import treebank 15 | 16 | class Bracketing: 17 | """For instance: 18 | Bracketing(10, set([(1, 3), (5, 11), (6, 11), (8, 10), (1, 4), (7, 11), 19 | (4, 11)]), 1). 20 | """ 21 | 22 | # FIXME: eliminar brackets unarios. 23 | def __init__(self, length, brackets=None, start_index=0): 24 | """brackets debe ser un set de pares de enteros. 25 | """ 26 | 27 | self.length = length 28 | self.start_index = start_index 29 | if brackets is None: 30 | self.brackets = set() 31 | else: 32 | brackets.discard((start_index, start_index+length)) 33 | self.brackets = brackets 34 | 35 | def __eq__(self, other): 36 | if not isinstance(other, Bracketing): 37 | return False 38 | return (self.length, self.brackets, self.start_index) == \ 39 | (other.length, other.brackets, other.start_index) 40 | 41 | def __ne__(self, other): 42 | return not self.__eq__(other) 43 | 44 | def __le__(self, other): 45 | if not isinstance(other, Bracketing): 46 | return False 47 | return (self.length, self.start_index) == \ 48 | (other.length, other.start_index) and \ 49 | self.brackets <= other.brackets 50 | 51 | def has(self, xxx_todo_changeme9): 52 | """Returns True if the bracket belongs to the bracketing or encloses 53 | the whole sentence.""" 54 | (i, j) = xxx_todo_changeme9 55 | return j - i == 1 or \ 56 | (i, j) == (self.start_index, self.start_index+self.length) or \ 57 | (i, j) in self.brackets 58 | 59 | def has_opening_bracket(self, i, whole=True): 60 | if whole and i == b.start_index: 61 | return True 62 | bs = [a_b3 for a_b3 in self.brackets if a_b3[0] == i] 63 | return bs != [] 64 | 65 | def has_closing_bracket(self, i, whole=True): 66 | if whole and i == self.start_index + self.length: 67 | return True 68 | bs = [a_b4 for a_b4 in self.brackets if a_b4[1] == i] 69 | return bs != [] 70 | 71 | def ibrackets(self, whole=False, unary=False): 72 | """Iterator over the brackets. 73 | """ 74 | if unary and (whole or self.length > 1): 75 | c1 = map(lambda a: (a, a+1), list(range(self.start_index, self.start_index+self.length))) 76 | else: 77 | c1 = [] 78 | #if whole and self.length > 1: 79 | if whole: 80 | c3 = [(self.start_index, self.start_index + self.length)] 81 | else: 82 | c3 = [] 83 | 84 | return itertools.chain(c1, self.brackets, c3) 85 | 86 | def set_start_index(self, start_index): 87 | """Change internal representation. 
88 | """ 89 | old = self.start_index 90 | new = start_index 91 | self.brackets = set([(a_b[0] - old + new, a_b[1] - old + new) for a_b in self.brackets]) 92 | self.start_index = new 93 | 94 | def is_binary(self): 95 | return (self.length < 3 or len(self.brackets) == self.length - 2) and \ 96 | self.non_crossing() 97 | 98 | def non_crossing(self): 99 | if len(self.brackets) < 2: 100 | return True 101 | 102 | def consistent(xxx_todo_changeme, xxx_todo_changeme8): 103 | # Disjuntos, 1 dentro de 2 o 2 dentro de 1 104 | (i1, j1) = xxx_todo_changeme 105 | (i2, j2) = xxx_todo_changeme8 106 | return j1 <= i2 or j2 <= i1 or \ 107 | (i2 <= i1 and j1 <= j2) or \ 108 | (i1 <= i2 and j2 <= j1) 109 | 110 | result = True 111 | blist = list(self.brackets) 112 | i, j, l = 0, 1, len(blist) 113 | while result and (i, j) != (l-1, l): 114 | result = result and consistent(blist[i], blist[j]) 115 | if j < l-1: 116 | j += 1 117 | else: 118 | i += 1 119 | j = i+1 120 | return result 121 | 122 | def treefy(self, s=None): 123 | if s is None: 124 | s = ['X'] * self.length 125 | b2 = set([(a_b1[0]-self.start_index, a_b1[1]-self.start_index) for a_b1 in self.brackets]) 126 | return treefy(s, b2) 127 | 128 | def strfy(self, s, whole=False): 129 | """Returns a string representation of the bracketing, using 130 | s as the bracketed sentence (e.g. 'DT (VB NN)'). 131 | """ 132 | s2 = [x for x in s] 133 | for (i, j) in self.ibrackets(whole=whole): 134 | s2[i] = '('+s2[i] 135 | s2[j-1] = s2[j-1]+')' 136 | return string.join(s2) 137 | 138 | def randomly_binarize(self, start=None, end=None): 139 | """Binarize the bracketing adding the missing brackets randomly. 140 | (start and end are used for the recursive call, do not use.) 141 | """ 142 | brackets = self.brackets 143 | if start is None: 144 | first = True 145 | l = self.length 146 | start = self.start_index 147 | end = start + l 148 | else: 149 | first = False 150 | l = end - start 151 | 152 | if l > 2: 153 | # lo primero es identificar los split points posibles: 154 | splits = [] 155 | i = 1 156 | while i < l: 157 | if self.splittable(start + i, start, end): 158 | splits += [i] 159 | i += 1 160 | """if first: 161 | print splits 162 | else: 163 | print 'start', start, 'end', end""" 164 | assert splits != [] 165 | 166 | # ahora elegimos un split al azar y agregamos los brackets: 167 | split = start + random.choice(splits) 168 | # esto elegiria si quiero binarizar lo mas parecido posible a rbranch: 169 | #split = start + splits[0] 170 | 171 | if start + 1 < split: 172 | brackets.add((start, split)) 173 | if split + 1 < end: 174 | brackets.add((split, end)) 175 | 176 | # ahora llenamos adentro 177 | self.randomly_binarize(start=start, end=split) 178 | self.randomly_binarize(start=split, end=end) 179 | 180 | def splittable(self, x, start=None, end=None): 181 | """Helper for randomly_binarize. 182 | """ 183 | if start is None: 184 | start = self.start_index 185 | if end is None: 186 | end = self.length 187 | bs = [a_b5 for a_b5 in list(self.brackets) if start < a_b5[0] or a_b5[1] < end] 188 | i = 0 189 | while i < len(bs) and (bs[i][1] <= x or x <= bs[i][0]): 190 | i += 1 191 | if i == len(bs): 192 | return True 193 | else: 194 | return False 195 | 196 | def reverse(self): 197 | """Reverse the bracketing. 198 | """ 199 | s = self.start_index 200 | n = self.length 201 | self.brackets = set((n-j+2*s, n-i+2*s) for (i, j) in self.brackets) 202 | 203 | 204 | def coincidences(b1, b2): 205 | """Count coincidences between two bracketings. 
206 | """ 207 | s1 = set([(x_y[0] - b1.start_index, x_y[1] - b1.start_index) for x_y in b1.brackets]) 208 | s2 = set([(x_y6[0] - b2.start_index, x_y6[1] - b2.start_index) for x_y6 in b2.brackets]) 209 | return len(s1 & s2) 210 | 211 | 212 | def treefy(s, b): 213 | """Convert a binary bracketing b of a sentence s to a NLTK tree. 214 | b is a set and must not have the trivial top bracket. 215 | """ 216 | l = len(s) 217 | if l == 2: 218 | t = tree.Tree('X', [s[0], s[1]]) 219 | # buscar los hijos de la raiz: 220 | elif (0, l-1) in b: 221 | b2 = b - set((0, l-1)) 222 | 223 | t2 = treefy(s[:-1], b2) 224 | 225 | t = tree.Tree('X', [t2, s[-1]]) 226 | elif (1, l) in b: 227 | b2 = b - set((1, l)) 228 | b2 = set([(i_j[0]-1, i_j[1]-1) for i_j in b2]) 229 | 230 | t2 = treefy(s[1:], b2) 231 | 232 | t = tree.Tree('X', [s[0], t2]) 233 | else: 234 | x = 2 235 | while not ((0, x) in b and (x, l) in b): 236 | x = x + 1 237 | 238 | b2 = set((i, j) for (i, j) in b if 0 <= i and j <= x) 239 | b3 = set((i-x, j-x) for (i, j) in b if x <= i and j <= l) 240 | 241 | t2 = treefy(s[:x], b2) 242 | t3 = treefy(s[x:], b3) 243 | 244 | t = tree.Tree('X', [t2, t3]) 245 | 246 | return t 247 | 248 | 249 | def string_to_bracketing(s): 250 | """Converts a string to a bracketing. 251 | 252 | >>> string_to_bracketing('(DT NNP NN) (VBD (DT (VBZ (DT JJ NN))))') 253 | """ 254 | s2 = s.replace('(', '(X ') 255 | s2 = '((X '+s2+'))' 256 | t = treebank.Tree(tree.bracket_parse(s2)) 257 | b = tree_to_bracketing(t) 258 | return b 259 | 260 | 261 | def tree_to_bracketing(t, start_index=0): 262 | """t must be instance of treebank.Tree. 263 | """ 264 | l = len(t.leaves()) 265 | spans = t.spannings(leaves=False,root=False,unary=False) 266 | moved_spans = set([(a_b7[0]+start_index, a_b7[1]+start_index) for a_b7 in spans]) 267 | return Bracketing(l, moved_spans, start_index) 268 | 269 | 270 | def add(B, x): 271 | """Helper for binary_bracketings. Adds x to the indices of the brackets. 272 | """ 273 | return [[(a_b2[0]+x,a_b2[1]+x) for a_b2 in s] for s in B] 274 | 275 | 276 | def _binary_bracketings(n): 277 | """Helper for binary_bracketings. 278 | """ 279 | if n == 1: 280 | return [[]] 281 | elif n == 2: 282 | return [[(0,2)]] 283 | else: 284 | b = {} 285 | for i in range(1, n): 286 | b[i] = _binary_bracketings(i) 287 | B = [] 288 | for i in range(1, n): 289 | # todas las combinaciones posibles de b[i] y add(b[n-i], i): 290 | b1 = b[i] 291 | b2 = add(b[n-i], i) 292 | for j in range(len(b1)): 293 | for k in range(len(b2)): 294 | B = B + [[(0,n)] + b1[j] + b2[k]] 295 | 296 | return B 297 | 298 | 299 | def binary_bracketings(n): 300 | """Returns all the possible binary bracketings of n leaves. 301 | """ 302 | # remove whole span bracket and wrap into a Bracketing object: 303 | return [Bracketing(n, set(b[1:])) for b in _binary_bracketings(n)] 304 | 305 | 306 | def binary_bracketings_count(n): 307 | """Returns the number of binary bracketings of n leaves (this is, the 308 | Catalan number C_{n-1}). 309 | """ 310 | return catalan(n-1) 311 | 312 | 313 | def catalan(n): 314 | """Helper for binary_bracketings_count(n). 315 | """ 316 | if n <= 1: 317 | return 1 318 | else: 319 | # http://mathworld.wolfram.com/CatalanNumber.html 320 | return catalan(n-1)*2*(2*n-1)/(n+1) 321 | 322 | 323 | def rbranch_bracketing(length, start_index=0): 324 | """Returns the rbranch bracketing of the given length. 
325 | """ 326 | b = set((i, start_index+length) for i in range(start_index+1, start_index+length-1)) 327 | return Bracketing(length, b, start_index=start_index) 328 | 329 | 330 | def lbranch_bracketing(length, start_index=0): 331 | """Returns the lbranch bracketing of the given length. 332 | """ 333 | b = set((start_index, i) for i in range(start_index+2, start_index+length)) 334 | return Bracketing(length, b, start_index=start_index) 335 | 336 | 337 | def P_split(n): 338 | """Returns a binary bracketing according to the P_split() distribution. 339 | n is the number of leaves. 340 | """ 341 | if n <= 2: 342 | return Bracketing(n) 343 | k = random.randint(1, n-1) 344 | # b = [(0, n)] + gP_split(0, k) + gP_split(k, n) 345 | b = Bracketing(n, set(gP_split(0, k) + gP_split(k, n))) 346 | return b 347 | 348 | 349 | def gP_split(i, j): 350 | """Helper for P_split(). 351 | """ 352 | if i+1 == j: 353 | b = [] 354 | else: 355 | k = random.randint(i+1, j-1) 356 | b = [(i, j)] + gP_split(i, k) + gP_split(k, j) 357 | return b 358 | 359 | 360 | # FIXME: I think it only works with b.start_index = 0. 361 | def P_split_prob(b): 362 | """Returns the probability of b according to the P_split() distribution. 363 | """ 364 | """n = b.length 365 | if n <= 2: 366 | p = 1.0 367 | else: 368 | k = 1 369 | # si el arbol es binario y n > 2 seguro que tiene que ser splittable. 370 | #while k < n and not b.splittable(k): 371 | while not b.splittable(k): 372 | k += 1 373 | 374 | p = (1.0 / float(n)) * gP_split_prob(b, 0, k) * gP_split_prob(b, k, n) 375 | 376 | return p""" 377 | return gP_split_prob(b, b.start_index, b.start_index+b.length) 378 | 379 | 380 | def gP_split_prob(b, i, j): 381 | n = j - i 382 | if n <= 2: 383 | p = 1.0 384 | else: 385 | k = i+1 386 | # si el arbol es binario y n > 2 seguro que tiene que ser splittable. 
387 | #while k < n and not b.splittable(k): 388 | while not b.splittable(k, i, j): 389 | k += 1 390 | 391 | p = (1.0 / float(n-1)) * gP_split_prob(b, i, k) * gP_split_prob(b, k, j) 392 | 393 | return p 394 | -------------------------------------------------------------------------------- /modules/markov_flow_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | from collections import Counter 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import numpy as np 11 | 12 | from torch.nn import Parameter 13 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 14 | from sklearn.metrics.cluster import v_measure_score 15 | 16 | from .utils import log_sum_exp, data_iter, to_input_tensor, \ 17 | write_conll 18 | from .projection import * 19 | 20 | 21 | 22 | class MarkovFlow(nn.Module): 23 | def __init__(self, args, num_dims): 24 | super(MarkovFlow, self).__init__() 25 | 26 | self.args = args 27 | self.device = args.device 28 | 29 | # Gaussian Variance 30 | self.var = torch.zeros(num_dims, dtype=torch.float32, 31 | device=self.device, requires_grad=False) 32 | 33 | self.num_state = args.num_state 34 | self.num_dims = num_dims 35 | self.couple_layers = args.couple_layers 36 | self.cell_layers = args.cell_layers 37 | self.hidden_units = num_dims // 2 38 | 39 | # transition parameters in log space 40 | self.tparams = Parameter( 41 | torch.Tensor(self.num_state, self.num_state)) 42 | 43 | # Gaussian means 44 | self.means = Parameter(torch.Tensor(self.num_state, self.num_dims)) 45 | 46 | if args.model == 'nice': 47 | self.nice_layer = NICETrans(self.couple_layers, 48 | self.cell_layers, 49 | self.hidden_units, 50 | self.num_dims, 51 | self.device) 52 | 53 | self.pi = torch.zeros(self.num_state, 54 | dtype=torch.float32, 55 | requires_grad=False, 56 | device=self.device).fill_(1.0/self.num_state) 57 | 58 | self.pi = torch.log(self.pi) 59 | 60 | def init_params(self, init_seed): 61 | """ 62 | init_seed:(sents, masks) 63 | sents: (seq_length, batch_size, features) 64 | masks: (seq_length, batch_size) 65 | 66 | """ 67 | 68 | # initialize transition matrix params 69 | # self.tparams.data.uniform_().add_(1) 70 | self.tparams.data.uniform_() 71 | 72 | # load pretrained model 73 | if self.args.load_nice != '': 74 | self.load_state_dict(torch.load(self.args.load_nice), strict=False) 75 | 76 | # load pretrained Gaussian baseline 77 | if self.args.load_gaussian != '': 78 | self.load_state_dict(torch.load(self.args.load_gaussian), strict=False) 79 | 80 | # initialize mean and variance with empirical values 81 | with torch.no_grad(): 82 | sents, masks = init_seed 83 | sents, _ = self.transform(sents) 84 | seq_length, _, features = sents.size() 85 | flat_sents = sents.view(-1, features) 86 | seed_mean = torch.sum(masks.view(-1, 1).expand_as(flat_sents) * 87 | flat_sents, dim=0) / masks.sum() 88 | seed_var = torch.sum(masks.view(-1, 1).expand_as(flat_sents) * 89 | ((flat_sents - seed_mean.expand_as(flat_sents)) ** 2), 90 | dim = 0) / masks.sum() 91 | self.var.copy_(seed_var) 92 | 93 | # add noise to the pretrained Gaussian mean 94 | if self.args.load_gaussian != '' and self.args.model == 'nice': 95 | self.means.data.add_(seed_mean.data.expand_as(self.means.data)) 96 | elif self.args.load_gaussian == '' and self.args.load_nice == '': 97 | self.means.data.normal_().mul_(0.04) 98 | self.means.data.add_(seed_mean.data.expand_as(self.means.data)) 99 | 100 | 
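    # Constant term of the diagonal-covariance Gaussian log-density,
    # -D/2 * log(2*pi) - 1/2 * sum_d log(var_d); it is shared by all states
    # because the variance vector self.var is tied across states.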
def _calc_log_density_c(self): 101 | # return -self.num_dims/2.0 * (math.log(2) + \ 102 | # math.log(np.pi)) - 0.5 * self.num_dims * (torch.log(self.var)) 103 | 104 | return -self.num_dims/2.0 * (math.log(2) + \ 105 | math.log(np.pi)) - 0.5 * torch.sum(torch.log(self.var)) 106 | 107 | def transform(self, x): 108 | """ 109 | Args: 110 | x: (sent_length, batch_size, num_dims) 111 | """ 112 | jacobian_loss = torch.zeros(1, device=self.device, requires_grad=False) 113 | 114 | if self.args.model == 'nice': 115 | x, jacobian_loss_new = self.nice_layer(x) 116 | jacobian_loss = jacobian_loss + jacobian_loss_new 117 | 118 | 119 | return x, jacobian_loss 120 | 121 | 122 | def forward(self, sents, masks): 123 | """ 124 | sents: (sent_length, batch_size, self.num_dims) 125 | masks: (sent_length, batch_size) 126 | 127 | """ 128 | max_length = sents.size()[0] 129 | sents, jacobian_loss = self.transform(sents) 130 | 131 | assert self.var.data.min() > 0 132 | 133 | batch_size = len(sents[0]) 134 | self.logA = self._calc_logA() 135 | self.log_density_c = self._calc_log_density_c() 136 | 137 | alpha = self.pi + self._eval_density(sents[0]) 138 | for t in range(1, max_length): 139 | density = self._eval_density(sents[t]) 140 | mask_ep = masks[t].expand(self.num_state, batch_size) \ 141 | .transpose(0, 1) 142 | alpha = torch.mul(mask_ep, 143 | self._forward_cell(alpha, density)) + \ 144 | torch.mul(1-mask_ep, alpha) 145 | 146 | # calculate objective from log space 147 | objective = torch.sum(log_sum_exp(alpha, dim=1)) 148 | 149 | return objective, jacobian_loss 150 | 151 | def _calc_alpha(self, sents, masks): 152 | """ 153 | sents: (sent_length, batch_size, self.num_dims) 154 | masks: (sent_length, batch_size) 155 | 156 | Returns: 157 | output: (batch_size, sent_length, num_state) 158 | 159 | """ 160 | max_length, batch_size, _ = sents.size() 161 | 162 | alpha_all = [] 163 | alpha = self.pi + self._eval_density(sents[0]) 164 | alpha_all.append(alpha.unsqueeze(1)) 165 | for t in range(1, max_length): 166 | density = self._eval_density(sents[t]) 167 | mask_ep = masks[t].expand(self.num_state, batch_size) \ 168 | .transpose(0, 1) 169 | alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \ 170 | torch.mul(1-mask_ep, alpha) 171 | alpha_all.append(alpha.unsqueeze(1)) 172 | 173 | return torch.cat(alpha_all, dim=1) 174 | 175 | def _forward_cell(self, alpha, density): 176 | batch_size = len(alpha) 177 | ep_size = torch.Size([batch_size, self.num_state, self.num_state]) 178 | alpha = log_sum_exp(alpha.unsqueeze(dim=2).expand(ep_size) + 179 | self.logA.expand(ep_size) + 180 | density.unsqueeze(dim=1).expand(ep_size), dim=1) 181 | 182 | return alpha 183 | 184 | def _backward_cell(self, beta, density): 185 | """ 186 | density: (batch_size, num_state) 187 | beta: (batch_size, num_state) 188 | 189 | """ 190 | batch_size = len(beta) 191 | ep_size = torch.Size([batch_size, self.num_state, self.num_state]) 192 | beta = log_sum_exp(self.logA.expand(ep_size) + 193 | density.unsqueeze(dim=1).expand(ep_size) + 194 | beta.unsqueeze(dim=1).expand(ep_size), dim=2) 195 | 196 | return beta 197 | 198 | def _eval_density(self, words): 199 | """ 200 | words: (batch_size, self.num_dims) 201 | 202 | """ 203 | 204 | batch_size = words.size(0) 205 | ep_size = torch.Size([batch_size, self.num_state, self.num_dims]) 206 | words = words.unsqueeze(dim=1).expand(ep_size) 207 | means = self.means.expand(ep_size) 208 | var = self.var.expand(ep_size) 209 | 210 | return self.log_density_c - \ 211 | 0.5 * torch.sum((means-words) ** 2 / 
var, dim=2) 212 | 213 | def _calc_logA(self): 214 | return (self.tparams - \ 215 | log_sum_exp(self.tparams, dim=1, keepdim=True) \ 216 | .expand(self.num_state, self.num_state)) 217 | 218 | def _calc_log_mul_emit(self): 219 | return self.emission - \ 220 | log_sum_exp(self.emission, dim=1, keepdim=True) \ 221 | .expand(self.num_state, self.vocab_size) 222 | 223 | def _viterbi(self, sents_var, masks): 224 | """ 225 | Args: 226 | sents_var: (sent_length, batch_size, num_dims) 227 | masks: (sent_length, batch_size) 228 | """ 229 | 230 | self.log_density_c = self._calc_log_density_c() 231 | self.logA = self._calc_logA() 232 | 233 | length, batch_size = masks.size() 234 | 235 | # (batch_size, num_state) 236 | delta = self.pi + self._eval_density(sents_var[0]) 237 | 238 | ep_size = torch.Size([batch_size, self.num_state, self.num_state]) 239 | index_all = [] 240 | 241 | # forward calculate delta 242 | for t in range(1, length): 243 | density = self._eval_density(sents_var[t]) 244 | delta_new = self.logA.expand(ep_size) + \ 245 | density.unsqueeze(dim=1).expand(ep_size) + \ 246 | delta.unsqueeze(dim=2).expand(ep_size) 247 | mask_ep = masks[t].view(-1, 1, 1).expand(ep_size) 248 | delta = mask_ep * delta_new + \ 249 | (1 - mask_ep) * delta.unsqueeze(dim=1).expand(ep_size) 250 | 251 | # index: (batch_size, num_state) 252 | delta, index = torch.max(delta, dim=1) 253 | index_all.append(index) 254 | 255 | assign_all = [] 256 | # assign: (batch_size) 257 | _, assign = torch.max(delta, dim=1) 258 | assign_all.append(assign.unsqueeze(dim=1)) 259 | 260 | # backward retrieve path 261 | # len(index_all) = length-1 262 | for t in range(length-2, -1, -1): 263 | assign_new = torch.gather(index_all[t], 264 | dim=1, 265 | index=assign.view(-1, 1)).squeeze(dim=1) 266 | 267 | assign_new = assign_new.float() 268 | assign = assign.float() 269 | assign = masks[t+1] * assign_new + (1 - masks[t+1]) * assign 270 | assign = assign.long() 271 | 272 | assign_all.append(assign.unsqueeze(dim=1)) 273 | 274 | assign_all = assign_all[-1::-1] 275 | 276 | return torch.cat(assign_all, dim=1) 277 | 278 | def test(self, 279 | test_data, 280 | test_tags, 281 | sentences=None, 282 | tagging=False, 283 | path=None, 284 | null_index=None): 285 | """Evaluate tagging performance with 286 | many-to-1 metric and VM score 287 | 288 | Args: 289 | test_data: nested list of sentences 290 | test_tags: nested list of gold tags 291 | tagging: output the predicted tags if True 292 | path: The output tag file path 293 | null_index: the null element location in Penn 294 | Treebank, only used for writing unsupervised 295 | tags for downstream parsing task 296 | 297 | Returns: 298 | Tuple1: (M1, VM score) 299 | 300 | """ 301 | 302 | pad = np.zeros(self.num_dims) 303 | 304 | total = 0.0 305 | correct = 0.0 306 | cnt_stats = {} 307 | match_dict = {} 308 | 309 | index_all = [] 310 | eval_tags = [] 311 | 312 | gold_vm = [] 313 | model_vm = [] 314 | 315 | for sents, tags in data_iter(list(zip(test_data, test_tags)), 316 | batch_size=self.args.batch_size, 317 | is_test=True, 318 | shuffle=False): 319 | total += sum(len(sent) for sent in sents) 320 | sents_var, masks = to_input_tensor(sents, 321 | pad, 322 | device=self.device) 323 | sents_var, _ = self.transform(sents_var) 324 | 325 | # index: (batch_size, seq_length) 326 | index = self._viterbi(sents_var, masks) 327 | 328 | index_all += list(index) 329 | eval_tags += tags 330 | 331 | # count 332 | for (seq_gold_tags, seq_model_tags) in zip(tags, index): 333 | for (gold_tag, model_tag) in zip(seq_gold_tags, 
seq_model_tags): 334 | model_tag = model_tag.item() 335 | gold_vm += [gold_tag] 336 | model_vm += [model_tag] 337 | if model_tag not in cnt_stats: 338 | cnt_stats[model_tag] = Counter() 339 | cnt_stats[model_tag][gold_tag] += 1 340 | # match 341 | for tag in cnt_stats: 342 | match_dict[tag] = cnt_stats[tag].most_common(1)[0][0] 343 | 344 | # eval many2one 345 | for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all): 346 | for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags): 347 | model_tag = model_tag.item() 348 | if match_dict[model_tag] == gold_tag: 349 | correct += 1 350 | 351 | if tagging: 352 | write_conll(path, sentences, index_all, null_index) 353 | 354 | return correct/total, v_measure_score(gold_vm, model_vm) 355 | -------------------------------------------------------------------------------- /modules/dmv_viterbi_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import random 4 | import math 5 | 6 | from nltk import tree 7 | from .utils import stable_math_log 8 | 9 | 10 | harmonic_constant = 2.0 11 | 12 | def add(dict, x, val): 13 | dict[x] = dict.get(x, 0) + val 14 | 15 | class DMVDict(object): 16 | def __init__(self, d=None, default_val=math.log(0.1)): 17 | if d is None: 18 | self.d = {} 19 | else: 20 | self.d = d 21 | self.default_val = default_val 22 | 23 | def set_default_val(self, val): 24 | self.default_val = val 25 | 26 | def val(self, x): 27 | return self.d.get(x, self.default_val) 28 | 29 | def setVal(self, x, val): 30 | self.d[x] = val 31 | 32 | def add1(self, x): 33 | self.add(x, 1.0) 34 | 35 | def add(self, x, y): 36 | add(self.d, x, y) 37 | 38 | def iteritems(self): 39 | return self.d.items() 40 | 41 | 42 | def lplace_smooth(tita, count, tag_set, end_symbol, smth_const): 43 | for h in tag_set: 44 | tita.add(('attach_left', h, end_symbol), smth_const) 45 | count.add(('attach_left', end_symbol), smth_const) 46 | for a in tag_set: 47 | tita.add(('attach_left', a, h), smth_const) 48 | tita.add(('attach_right', a, h), smth_const) 49 | count.add(('attach_right', h), smth_const) 50 | count.add(('attach_left', h), smth_const) 51 | 52 | class DMV(object): 53 | def __init__(self, args): 54 | 55 | self.end_symbol = 'END' 56 | self.tita = None 57 | 58 | self.harmonic = False 59 | self.args = args 60 | 61 | def set_harmonic(self, val): 62 | self.harmonic = val 63 | 64 | def init_params(self, train_tags, tag_set): 65 | tita, count = DMVDict(), DMVDict() 66 | # harmonic initializer 67 | lplace_smooth(tita, count, tag_set, 68 | self.end_symbol, self.args.smth_const) 69 | self.set_harmonic(True) 70 | for i, s in enumerate(filter(lambda s: len(s) > 1, \ 71 | train_tags)): 72 | if i % 1000 == 0: 73 | print('initialize, sentence %d' % i) 74 | parse_tree, prob = self.dep_parse(s) 75 | self.MStep_s(parse_tree, tita, count) 76 | self.MStep(tita, count) 77 | 78 | @staticmethod 79 | def tree_to_depset(t): 80 | # add the root symbol (-1) 81 | res = set([(t.label().index, -1)]) 82 | res.update(DMV._tree_to_depset(t)) 83 | return sorted(res) 84 | 85 | @staticmethod 86 | def _tree_to_depset(t): 87 | node = t.label() 88 | index = node.index 89 | mark = node.mark 90 | #res = set([(index, -1)]) 91 | # len(t) is the number of children 92 | if len(t) > 1: 93 | if mark == '<>': 94 | arg = t[0] 95 | elif mark == '>': 96 | arg = t[1] 97 | res = set([(arg.label().index, index)]) 98 | res.update(DMV._tree_to_depset(t[0]), DMV._tree_to_depset(t[1])) 99 | else: 100 | if not isinstance(t[0], 
str): 101 | res = DMV._tree_to_depset(t[0]) 102 | else: 103 | res = set() 104 | return res 105 | 106 | def eval(self, gold, tags, all_len=False): 107 | """ 108 | Args: 109 | gold: A nested list of heads 110 | all_len: True if evaluating on all lengths 111 | 112 | """ 113 | 114 | # parse: a list of DepSets 115 | parse = [] 116 | for k, s in enumerate(tags): 117 | parse.append(self.tree_to_depset(self.parse(s))) 118 | if all_len: 119 | if k % 10 == 0: 120 | print('parse %d trees' % k) 121 | 122 | cnt = 0 123 | dir_cnt = 0.0 124 | undir_cnt = 0.0 125 | 126 | for gold_s, parse_s in zip(gold, parse): 127 | length = len(gold_s) 128 | if length > 1: 129 | (directed, undirected) = self.measures(gold_s, parse_s) 130 | cnt += length 131 | dir_cnt += directed 132 | undir_cnt += undirected 133 | 134 | dir_acu = dir_cnt / cnt 135 | undir_acu = undir_cnt / cnt 136 | 137 | return (dir_acu, undir_acu) 138 | 139 | @staticmethod 140 | def measures(gold_s, parse_s): 141 | # Helper for eval(). 142 | (d, u) = (0, 0) 143 | for (a, b) in gold_s: 144 | (a, b) = (a-1, b-1) 145 | b1 = (a, b) in parse_s 146 | b2 = (b, a) in parse_s 147 | if b1: 148 | d += 1.0 149 | u += 1.0 150 | if b2: 151 | u += 1.0 152 | 153 | return (d, u) 154 | 155 | def EStep(self, s): 156 | pio = self.p_inside_outside(s) 157 | 158 | return pio 159 | 160 | def MStep(self, tita, count): 161 | 162 | for x, p in tita.iteritems(): 163 | p = float(p) 164 | if p == 0.0: 165 | raise ValueError 166 | elif x[0] == 'stop_left': 167 | tita.setVal(x, math.log(p / count.val(x))) 168 | elif x[0] == 'stop_right': 169 | tita.setVal(x, math.log(p / count.val(x))) 170 | elif x[0] == 'attach_left': 171 | tita.setVal(x, math.log(p / count.val(('attach_left', x[2])))) 172 | elif x[0] == 'attach_right': 173 | tita.setVal(x, math.log(p / count.val(('attach_right', x[2])))) 174 | p_new = tita.val(x) 175 | 176 | if p_new > 0: 177 | self.count = count 178 | print('(x, p, p_new) =', (x, p, p_new)) 179 | raise ValueError 180 | self.tita = tita 181 | 182 | def _calc_maxval(self, t): 183 | max_val = 0 184 | node = t.label() 185 | max_val = max(node.r_val, node.l_val) 186 | 187 | max_val_list = [max_val] 188 | for child in t: 189 | if not isinstance(child, str): 190 | max_val_list += [self._calc_maxval(child)] 191 | 192 | return max(max_val_list) 193 | 194 | 195 | 196 | def _calc_stats(self, t, tita, count): 197 | node = t.label() 198 | index = node.index 199 | mark = node.mark 200 | word = node.word 201 | l_val = node.l_val 202 | r_val = node.r_val 203 | 204 | # calc stop denom 205 | if mark == '>': 206 | count.add(('stop_right', word, r_val == 0), 1) 207 | elif mark == '<>': 208 | count.add(('stop_left', word, l_val == 0), 1) 209 | 210 | 211 | if len(t) > 1: 212 | if mark == '<>': 213 | arg = t[0] 214 | tita.add(('attach_left', arg.label().word, word), 1) 215 | count.add(('attach_left', word), 1) 216 | elif mark == '>': 217 | arg = t[1] 218 | tita.add(('attach_right', arg.label().word, word), 1) 219 | count.add(('attach_right', word), 1) 220 | self._calc_stats(t[0], tita, count) 221 | self._calc_stats(t[1], tita, count) 222 | else: 223 | if not isinstance(t[0], str): 224 | if mark == '|': 225 | tita.add(('stop_left', word, l_val == 0), 1) 226 | 227 | elif mark == '<>': 228 | tita.add(('stop_right', word, r_val == 0), 1) 229 | self._calc_stats(t[0], tita, count) 230 | else: 231 | assert mark == '>' 232 | 233 | def MStep_s(self, t, tita, count): 234 | 235 | h = self.end_symbol 236 | count.add(('attach_left', h), 1) 237 | tita.add(('attach_left', t.label().word, h), 1) 238 | 
self._calc_stats(t, tita, count) 239 | 240 | 241 | def parse(self, s): 242 | t, w = self.dep_parse(s) 243 | return t 244 | 245 | def dep_parse(self, s): 246 | """ 247 | output: 248 | returned t is a nltk.tree.Tree without root node 249 | """ 250 | parse = {} 251 | # OPTIMIZATION: END considered only explicitly 252 | # s = s + [self.end_symbol] 253 | 254 | n = len(s) 255 | 256 | for i in range(n): 257 | j = i + 1 258 | w = str(s[i]) 259 | t1 = tree.Tree(Node('>', w, i, 0, 0), [w]) 260 | 261 | parse[i, j] = ParseDict(self.unary_parses(math.log(1.0), t1, i, j)) 262 | 263 | for l in range(2, n+1): 264 | for i in range(n-l+1): 265 | j = i + l 266 | parse_dict = ParseDict() 267 | for k in range(i+1, j): 268 | for (p1, t1) in parse[i, k].itervalues(): 269 | for (p2, t2) in parse[k, j].itervalues(): 270 | n1 = t1.label() 271 | n2 = t2.label() 272 | if n1.mark == '>' and n2.mark == '|': 273 | m = n1.index 274 | h = n1.word 275 | p = self.p_nonstop_right(h, n1.r_val, self.harmonic) + \ 276 | self.p_attach_right(n2.word, h, self.harmonic, n2.index - m) + \ 277 | p1 + p2 278 | new_node = Node(n1.mark, n1.word, n1.index, n1.l_val, n1.r_val + 1) 279 | t = tree.Tree(new_node, [t1, t2]) 280 | parse_dict.add(p, t) 281 | if n1.mark == '|' and n2.mark == '<>': 282 | m = n2.index 283 | h = n2.word 284 | p = self.p_nonstop_left(h, n2.l_val, self.harmonic) + \ 285 | self.p_attach_left(n1.word, h, self.harmonic, m - n1.index) + \ 286 | p1 + p2 287 | new_node = Node(n2.mark, n2.word, n2.index, n2.l_val + 1, n2.r_val) 288 | t = tree.Tree(new_node, [t1, t2]) 289 | parse_dict.add(p, t) 290 | 291 | parse[i, j] = ParseDict(sum((self.unary_parses(p, t, i, j) \ 292 | for (p, t) in parse_dict.itervalues()), [])) 293 | 294 | w = s[0] 295 | (p1, t1) = parse[0, n].val('|'+w+'0') 296 | t_max, p_max = t1, p1 + self.p_attach_left(w, self.end_symbol, self.harmonic) 297 | l = [(t_max, p_max)] 298 | for i in range(1, n): 299 | w = s[i] 300 | (p1, t1) = parse[0, n].val('|'+w+str(i)) 301 | p = p1 + self.p_attach_left(w, self.end_symbol, self.harmonic) 302 | if p > p_max: 303 | p_max = p 304 | l = [(t1, p)] 305 | elif p == p_max: 306 | l += [(t1, p)] 307 | (t_max, p_max) = self.choice(l, self.args.choice) 308 | 309 | return (t_max, p_max) 310 | 311 | def choice(self, l, method): 312 | """ 313 | select on parse tree from list l, 314 | which is a list of tuple (t, p) 315 | 316 | """ 317 | if method == 'random': 318 | return random.choice(l) 319 | elif method == 'minival': 320 | (t_min, p_min) = l[0] 321 | val_min = 10 322 | for (t, p) in l: 323 | val = self._calc_maxval(t) 324 | print(val) 325 | if val < val_min: 326 | (t_min, p_min) = (t, p) 327 | val_min = val 328 | 329 | return (t_min, p_min) 330 | elif method == 'bias_middle': 331 | (t_min, p_min) = l[0] 332 | min_dist = 10 333 | for (t, p) in l: 334 | middle = (len(t.leaves())) / 2.0 335 | dist = abs(t.label().index - middle) 336 | if dist < min_dist: 337 | (t_min, p_min) = (t, p) 338 | min_dist = dist 339 | return (t_min, p_min) 340 | elif method == 'soft_bias_middle': 341 | new_list = [random.choice(l)] 342 | for (t, p) in l: 343 | middle = (len(t.leaves())) / 2.0 344 | dist = abs(t.label().index - middle) + 1 345 | if dist < middle: 346 | new_list += [(t, p)] 347 | return random.choice(new_list) 348 | elif method == 'exclude_end': 349 | new_list = [] 350 | for (t, p) in l: 351 | length = len(t.leaves()) 352 | if (t.label().index != 0 and t.label().index != length - 1) or (len(t.leaves()) < 5): 353 | new_list += [(t, p)] 354 | 355 | if len(new_list) == 0: 356 | new_list = l 357 
| return random.choice(new_list) 358 | elif method == 'bias_left': 359 | return l[0] 360 | 361 | 362 | 363 | def unary_parses(self, p, t, i, j): 364 | node = t.label() 365 | l_val = node.l_val 366 | r_val = node.r_val 367 | if node.mark == '|': 368 | res = [] 369 | 370 | elif node.mark == '<>': 371 | p2 = self.p_stop_left(node.word, l_val, self.harmonic) + p 372 | t2 = tree.Tree(Node('|', node.word, node.index, l_val, r_val), [t]) 373 | res = [(p2, t2)] 374 | elif node.mark == '>': 375 | p2 = self.p_stop_right(node.word, r_val, self.harmonic) + p 376 | t2 = tree.Tree(Node('<>', node.word, node.index, l_val, r_val), [t]) 377 | res = self.unary_parses(p2, t2, i, j) 378 | return [(p, t)] + res 379 | 380 | def p_nonstop_left(self, w, val, harmonic=False): 381 | try: 382 | return stable_math_log(1.0 - math.exp(self.p_stop_left(w, val, harmonic))) 383 | except ValueError: 384 | print(math.exp(self.p_stop_left(w, val, harmonic)), 385 | self.p_stop_left(w, val, harmonic)) 386 | 387 | def p_nonstop_right(self, w, val, harmonic=False): 388 | return stable_math_log(1.0 - math.exp(self.p_stop_right(w, val, harmonic))) 389 | 390 | def p_stop_left(self, w, val, harmonic=False): 391 | if harmonic: 392 | if val == 0: 393 | return math.log(self.args.stop_adj) 394 | else: 395 | return math.log(1 - self.args.stop_adj) 396 | 397 | return self.tita.val(('stop_left', w, val == 0)) 398 | 399 | def p_stop_right(self, w, val, harmonic=False): 400 | if harmonic: 401 | if val == 0: 402 | return math.log(self.args.stop_adj) 403 | else: 404 | return math.log(1 - self.args.stop_adj) 405 | 406 | return self.tita.val(('stop_right', w, val == 0)) 407 | 408 | def p_attach_left(self, a, h, harmonic=False, dist=None): 409 | if harmonic: 410 | if h == self.end_symbol: 411 | return math.log(0.02) 412 | return math.log(1.0 / (dist + harmonic_constant)) 413 | return self.tita.val(('attach_left', a, h)) 414 | 415 | def p_attach_right(self, a, h, harmonic=False, dist=None): 416 | if harmonic: 417 | return math.log(1.0 / (dist + harmonic_constant)) 418 | return self.tita.val(('attach_right', a, h)) 419 | 420 | 421 | class Node(object): 422 | def __init__(self, mark, word, index, l_val, r_val): 423 | self.mark = mark 424 | self.word = word 425 | self.index = index 426 | self.l_val = l_val 427 | self.r_val = r_val 428 | 429 | def __eq__(self, other): 430 | if not isinstance(other, Node): 431 | return False 432 | return (self.mark, self.word, self.index, self.l_val, self.r_val) \ 433 | == (other.mark, other.word, other.index, other.l_val, other.r_val) 434 | 435 | def __str__(self): 436 | return str(self.mark) + str(self.word) + str(self.index) 437 | 438 | def __repr__(self): 439 | return self.__str__() 440 | 441 | 442 | class ParseDict(object): 443 | def __init__(self, parses=None): 444 | self.dict = {} 445 | if parses is not None: 446 | self.add_all(parses) 447 | 448 | def val(self, node): 449 | return self.dict[str(node)] 450 | 451 | def add(self, p, t): 452 | n = t.label() 453 | s = str(n) 454 | if (s not in self.dict) or (self.dict[s][0] < p): 455 | self.dict[s] = (p, t) 456 | 457 | def add_all(self, parses): 458 | for (p, t) in parses: 459 | self.add(p, t) 460 | 461 | def itervalues(self): 462 | return self.dict.values() 463 | --------------------------------------------------------------------------------
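A minimal sketch of how the DMV pieces above fit together. This is only an illustration, not the project's training recipe: the actual EM/Viterbi loop lives in dmv_viterbi_train.py (not shown here), the `args` values (smth_const, stop_adj, choice) are placeholder numbers for whatever that script passes in, the toy tag corpus and gold heads are invented for the example, and the import assumes the package's dependencies (torch, nltk, scikit-learn) are installed so that `modules/__init__.py` loads. Gold heads are given as (dependent, head) pairs with 1-based token indices and 0 for the root, matching the index shift in measures().

from argparse import Namespace
from modules.dmv_viterbi_model import DMV

# Illustrative hyper-parameters; real values come from dmv_viterbi_train.py.
args = Namespace(smth_const=1.0,      # smoothing added by lplace_smooth()
                 stop_adj=0.75,       # harmonic-initializer stop probability
                 choice='bias_left')  # tie-breaking rule used by DMV.choice()

# A toy POS-tagged corpus (sentences of tags) and its tag inventory.
train_tags = [['DT', 'NN', 'VBZ'],
              ['DT', 'JJ', 'NN', 'VBD', 'DT', 'NN'],
              ['NNP', 'VBZ', 'RB']]
tag_set = set(t for s in train_tags for t in s)

dmv = DMV(args)
dmv.init_params(train_tags, tag_set)  # harmonic init: Viterbi parse + one M-step
dmv.set_harmonic(False)               # switch to the estimated parameters

# Viterbi-parse a sentence and read off its dependency set:
# a sorted list of (dependent_index, head_index) pairs, 0-based, root head = -1.
t = dmv.parse(['DT', 'NN', 'VBZ'])
print(DMV.tree_to_depset(t))

# Directed/undirected accuracy against (made-up) gold heads.
gold = [[(1, 2), (2, 3), (3, 0)]]
print(dmv.eval(gold, [['DT', 'NN', 'VBZ']]))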