├── presentation.pdf ├── statetr ├── .predict.bash.swp ├── predict_wsj.bash ├── statetr.bash ├── statetr_g2g.bash ├── predict.bash ├── restore_conllu_lines.pl ├── conllu_to_conllx.pl ├── conllu_to_conllx_no_underline.pl ├── test.py ├── substitue_underline.py ├── statetr.yml ├── dep2conllx.py ├── README.md ├── file_utils.py └── featurize.py ├── senttr ├── parser │ ├── __init__.py │ ├── cmds │ │ ├── __init__.py │ │ ├── predict.py │ │ └── train.py │ ├── utils │ │ ├── __init__.py │ │ ├── data.py │ │ ├── corpus.py │ │ ├── scalar_mix.py │ │ ├── vocab.py │ │ └── base.py │ ├── metric.py │ ├── parser.py │ └── model.py ├── config.ini ├── senttr.bash ├── senttr_g2g.bash ├── config.py ├── restore_conllu_lines.pl ├── predict.bash ├── run.py ├── conllu_to_conllx.pl ├── conllu_to_conllx_no_underline.pl ├── substitue_underline.py ├── parserstate.py ├── environment.yml ├── transition.py └── README.md ├── README.md ├── sample_data ├── train.conll └── dev.conll └── LICENSE /presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alirezamshi-zz/G2GTr/HEAD/presentation.pdf -------------------------------------------------------------------------------- /statetr/.predict.bash.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alirezamshi-zz/G2GTr/HEAD/statetr/.predict.bash.swp -------------------------------------------------------------------------------- /senttr/parser/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .model import Model 4 | from .parser import Parser 5 | 6 | 7 | __all__ = ['Parser', 'Model'] 8 | -------------------------------------------------------------------------------- /senttr/parser/cmds/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .predict 
#!/bin/bash
# Evaluate a trained model on a single CoNLL-X file.
#
# Steps (all helper scripts are looked up under $main_path):
#   1. test.py       - loads the checkpoint named $model_name and prints UAS/LAS.
#   2. dep2conllx.py - its stdout is captured as $output_path/pred.conllx.
#   3. eval.pl       - compares the gold file ($datapath) with pred.conllx.
#
# NOTE(review): eval.pl is not part of the visible repository listing; it
# presumably has to be copied into $main_path by the user - confirm.
#
# Fill in the four empty placeholders below before running.
model_name=""     # checkpoint / vocabulary name passed to test.py
datapath=""       # gold CoNLL-X file to parse and score
batch_size=40
main_path=""      # directory that contains test.py, dep2conllx.py and eval.pl
output_path=""    # directory that receives pred.conllx

# Fail early with a readable message instead of running "python /test.py"
# with missing arguments when a placeholder was left empty.
for required in model_name datapath main_path output_path; do
    if [ -z "${!required}" ]; then
        echo "predict_wsj.bash: please set \$$required at the top of this script" >&2
        exit 1
    fi
done

# mkdir -p is a no-op when the directory already exists, so no -d test is needed.
mkdir -p "$output_path"

# All expansions are quoted so that paths containing spaces survive word splitting.
python "$main_path/test.py" --batchsize "$batch_size" --model_name "$model_name" --datapath "$output_path" --testfile "$datapath" --mainpath "$main_path"
python "$main_path/dep2conllx.py" "$datapath" "$main_path" "$model_name" > "$output_path/pred.conllx"
perl "$main_path/eval.pl" -g "$datapath" -s "$output_path/pred.conllx" -q
#!/bin/bash
# Launches run.py with the flag set used for the "statetr_g2g" output name
# (note --graphinput, which the plain statetr.bash launcher does not pass).
#
# The four paths below are absolute paths from the original author's machine;
# replace them with local paths before running.
#
# NOTE(review): this launcher passes --debug and uses --batchsize 12, whereas
# statetr.bash passes no --debug and uses --batchsize 40 - confirm that --debug
# is intended for a full training run.
DATA_PATH="/idiap/temp/amohammadshahi/Debug_transformer/edited-transformer-new-ud-swap/data"
BERT_NAME="bert-base-uncased"
BERT_PATH="/idiap/temp/amohammadshahi/Debug_transformer/graph-based-g2g-parser/"
MAIN_PATH="/idiap/temp/amohammadshahi/Debug_transformer/emnlp/EMNLP2020/emnlp_statetr"
python run.py --mean_seq --lr 1e-5 --graphinput --lowercase --usepos --withpunct \
    --batchsize 12 --nepochs 12 --warmupproportion 0.01 --Bertoptim \
    --nattentionlayer 6 --nattentionheads 12 --fhistmodel --seppoint --withbert \
    --outputname statetr_g2g --mainpath $MAIN_PATH \
    --datapath $DATA_PATH \
    --bertname $BERT_NAME --bertpath $BERT_PATH --use_topbuffer --use_justexist --debug
class Config(object):
    """Flat key/value view over an INI file, extendable at runtime.

    Every option of every section in ``fname`` is parsed with
    ``ast.literal_eval`` and stored in a single dictionary, ``self.kwargs``
    (section names are discarded). Unknown attributes resolve to ``None``
    instead of raising ``AttributeError``.

    Args:
        fname: path of the INI file. A path that cannot be read yields an
            empty configuration, because ``ConfigParser.read`` ignores
            missing files.

    Raises:
        ValueError / SyntaxError: if an option value is not a valid Python
            literal (propagated from ``literal_eval``).
    """

    def __init__(self, fname):
        super(Config, self).__init__()

        self.config = ConfigParser()
        self.config.read(fname)
        self.kwargs = dict((option, literal_eval(value))
                           for section in self.config.sections()
                           for option, value in self.config.items(section))

    def __repr__(self):
        """Render the options two per line as ``name value`` columns."""
        info = f"{self.__class__.__name__}:\n"
        items = list(self.kwargs.items())
        for i, (option, value) in enumerate(items):
            # str() first: None, lists, tuples and dicts reject the '<25'
            # format spec and used to raise TypeError here.
            info += f"{option:15} {str(value):<25}"
            if i % 2 > 0:
                info += '\n'
        # Terminate a dangling half-filled last line. Testing the item count
        # (rather than the loop variable) also works for an empty config.
        if len(items) % 2 == 1:
            info += '\n'

        return info

    def __getattr__(self, attr):
        """Return the option called ``attr``, or ``None`` when it is unknown.

        Only invoked when normal attribute lookup fails.
        """
        # Go through __dict__ so that an instance whose ``kwargs`` is not set
        # yet (e.g. created via __new__) does not recurse into __getattr__.
        kwargs = self.__dict__.get('kwargs')
        if kwargs is None:
            return None
        return kwargs.get(attr, None)

    def __getstate__(self):
        """Pickle support: the instance dictionary is the whole state."""
        return vars(self)

    def __setstate__(self, state):
        """Pickle support: restore the instance dictionary."""
        self.__dict__.update(state)

    def update(self, kwargs):
        """Merge ``kwargs`` into the stored options; existing keys are overwritten."""
        self.kwargs.update(kwargs)
class Metric(object):
    """Running attachment-score counter.

    Each call adds one batch of predictions to the totals. ``uas`` and
    ``las`` are the accumulated ratios, and the rich comparisons compare
    ``score`` (which is ``las``) with the other operand.
    """

    def __init__(self, eps=1e-5):
        super(Metric, self).__init__()

        # eps is added to the denominator so an empty metric divides safely.
        self.eps = eps
        self.total = 0.0
        self.correct_arcs = 0.0
        self.correct_rels = 0.0

    def __repr__(self):
        uas, las = self.uas, self.las
        return f"UAS: {uas:.2%} LAS: {las:.2%}"

    def __call__(self, pred_arcs, pred_rels, gold_arcs, gold_rels):
        """Accumulate one batch.

        A relation only counts as correct when its arc is correct too.
        The batch size is taken as ``len()`` of the arc comparison result.
        """
        arcs_ok = pred_arcs.eq(gold_arcs)
        rels_ok = pred_rels.eq(gold_rels) & arcs_ok

        n_arcs_ok = arcs_ok.sum().item()
        n_rels_ok = rels_ok.sum().item()

        self.total += len(arcs_ok)
        self.correct_arcs += n_arcs_ok
        self.correct_rels += n_rels_ok

    def __lt__(self, other):
        return self.score < other

    def __le__(self, other):
        return self.score <= other

    def __ge__(self, other):
        return self.score >= other

    def __gt__(self, other):
        return self.score > other

    def _ratio(self, hits):
        # Shared denominator for both scores.
        return hits / (self.total + self.eps)

    @property
    def score(self):
        """The value used for comparisons: LAS."""
        return self.las

    @property
    def uas(self):
        """Correct arcs over all counted items."""
        return self._ratio(self.correct_arcs)

    @property
    def las(self):
        """Correct arc+relation pairs over all counted items."""
        return self._ratio(self.correct_rels)
if __name__ == '__main__':
    # Command-line entry point: builds one sub-parser per sub-command, merges
    # the parsed arguments over the INI configuration and runs the command.
    parser = argparse.ArgumentParser(
        description='Create the Biaffine Parser model.'
    )
    subparsers = parser.add_subparsers(title='Commands', dest='mode')
    subcommands = {
        'predict': Predict(),
        'train': Train()
    }
    # Each sub-command registers its own options first; the options shared by
    # every sub-command are appended afterwards.
    for name, subcommand in subcommands.items():
        subparser = subcommand.add_subparser(name, subparsers)
        subparser.add_argument('--conf', '-c', default='config.ini',
                               help='path to config file')
        subparser.add_argument('--model', '-m', default='exp/',
                               help='path to model file')
        subparser.add_argument('--vocab', '-v', default='exp/',
                               help='path to vocab file')
        subparser.add_argument('--device', '-d', default='-1',
                               help='ID of GPU to use')
        subparser.add_argument('--seed', '-s', default=1, type=int,
                               help='seed for generating random numbers')
        subparser.add_argument('--threads', '-t', default=4, type=int,
                               help='max num of threads')
        # The help text used to be a copy of the buckets help.
        subparser.add_argument('--batch_size', default=1000, type=int,
                               help='batch size')

    args = parser.parse_args()

    # Sub-commands are optional for argparse on Python 3; without this guard a
    # bare "python run.py" crashed with AttributeError on args.threads.
    if args.mode is None:
        parser.error('a subcommand is required: ' + ', '.join(subcommands))

    print(f"Set the max num of threads to {args.threads}")
    print(f"Set the seed for generating random numbers to {args.seed}")
    # NOTE(review): the message below is informational only - nothing in this
    # block uses args.device to restrict the visible GPUs. Confirm whether
    # that is intended.
    print(f"Set the device with ID {args.device} visible")
    torch.set_num_threads(args.threads)
    torch.manual_seed(args.seed)

    print("Override the default configs with parsed arguments")
    config = Config(args.conf)
    # Command-line values win over the values read from the INI file.
    config.update(vars(args))

    print(f"Run the subcommand in mode {args.mode}")
    cmd = subcommands[args.mode]
    cmd(config)
#!/usr/bin/env perl
# Converts a CoNLL-U file (Universal Dependencies) to the older CoNLL-X format.
# The conversion is by definition lossy. It is a lightweight converter: we do not check for validity of the CoNLL-U input!
# Copyright © 2015, 2017 Dan Zeman
# License: GNU GPL

# Usage: conllu_to_conllx.pl < input.conllu > output.conllx
# Input is read from the files named on the command line, or from STDIN when
# there are none; the converted text is printed to STDOUT.

use utf8;
use open ':utf8';
binmode(STDIN, ':utf8');
binmode(STDOUT, ':utf8');
binmode(STDERR, ':utf8');

while(<>)
{
    # Discard sentence-level comment lines.
    next if(m/^\#/);
    # Discard lines of fused surface tokens. Syntactic words will be the node-level unit in the output file.
    next if(m/^\d+-\d+/);
    # Discard lines with empty nodes from the enhanced representation.
    next if(m/^\d+\./);
    # Only lines containing a tab are rewritten; other lines (such as the
    # empty line after a sentence) are printed unchanged.
    if(m/\t/)
    {
        # Strip the line ending (LF or CRLF); a plain "\n" is re-added below.
        s/\r?\n$//;
        my @fields = split(/\t/, $_);
        # CoNLL-U v2 (December 2016) allows spaces in FORM and LEMMA but older tools may not survive it.
        # Replace spaces by underscores.
        $fields[1] =~ s/ /_/g;
        $fields[2] =~ s/ /_/g;
        # CoNLL-X specification did not allow POSTAG to be empty if there was CPOSTAG, and some tools rely on it.
        # Also, some tools rely on POSTAG being a fine-grained version of CPOSTAG, i.e. CPOSTAG should be always
        # inferrable from POSTAG. This is not an explicit requirement in the format specification but we will
        # enforce it anyway.
        # Copy CPOSTAG to POSTAG if POSTAG is empty. Otherwise, prepend CPOSTAG to POSTAG.
        if($fields[4] eq '_')
        {
            $fields[4] = $fields[3];
        }
        else
        {
            $fields[4] = $fields[3].'_'.$fields[4];
        }
        # The last two columns ([8] and [9]) had different meaning in CoNLL-X.
        # In many cases it is probably harmless to keep their contents from CoNLL-U, but some tools may rely on their expectations about these columns,
        # especially in [8] they may require either '_' or a numeric value. Let's erase the contents of these columns to be on the safe side.
        $fields[8] = $fields[9] = '_';
        $_ = join("\t", @fields)."\n";
    }
    print;
}
#!/usr/bin/env perl
# Converts a CoNLL-U file (Universal Dependencies) to the older CoNLL-X format.
# The conversion is by definition lossy. It is a lightweight converter: we do not check for validity of the CoNLL-U input!
# Copyright © 2015, 2017 Dan Zeman
# License: GNU GPL

# Variant of conllu_to_conllx.pl that leaves spaces in FORM and LEMMA untouched
# (the two substitutions below are commented out). The predict scripts write
# its output to original_nounderline.conllx and pass that file to
# substitue_underline.py.
#
# Usage: conllu_to_conllx_no_underline.pl < input.conllu > output.conllx

use utf8;
use open ':utf8';
binmode(STDIN, ':utf8');
binmode(STDOUT, ':utf8');
binmode(STDERR, ':utf8');

while(<>)
{
    # Discard sentence-level comment lines.
    next if(m/^\#/);
    # Discard lines of fused surface tokens. Syntactic words will be the node-level unit in the output file.
    next if(m/^\d+-\d+/);
    # Discard lines with empty nodes from the enhanced representation.
    next if(m/^\d+\./);
    # Only lines containing a tab are rewritten; other lines (such as the
    # empty line after a sentence) are printed unchanged.
    if(m/\t/)
    {
        # Strip the line ending (LF or CRLF); a plain "\n" is re-added below.
        s/\r?\n$//;
        my @fields = split(/\t/, $_);
        # CoNLL-U v2 (December 2016) allows spaces in FORM and LEMMA but older tools may not survive it.
        # Unlike conllu_to_conllx.pl, this variant deliberately does NOT replace
        # spaces by underscores; the substitutions are kept here, disabled, for reference.
        #$fields[1] =~ s/ /_/g;
        #$fields[2] =~ s/ /_/g;
        # CoNLL-X specification did not allow POSTAG to be empty if there was CPOSTAG, and some tools rely on it.
        # Also, some tools rely on POSTAG being a fine-grained version of CPOSTAG, i.e. CPOSTAG should be always
        # inferrable from POSTAG. This is not an explicit requirement in the format specification but we will
        # enforce it anyway.
        # Copy CPOSTAG to POSTAG if POSTAG is empty. Otherwise, prepend CPOSTAG to POSTAG.
        if($fields[4] eq '_')
        {
            $fields[4] = $fields[3];
        }
        else
        {
            $fields[4] = $fields[3].'_'.$fields[4];
        }
        # The last two columns ([8] and [9]) had different meaning in CoNLL-X.
        # In many cases it is probably harmless to keep their contents from CoNLL-U, but some tools may rely on their expectations about these columns,
        # especially in [8] they may require either '_' or a numeric value. Let's erase the contents of these columns to be on the safe side.
        $fields[8] = $fields[9] = '_';
        $_ = join("\t", @fields)."\n";
    }
    print;
}
class Predict(object):
    """The ``predict`` sub-command: load a trained model, parse a file, save the result."""

    def add_subparser(self, name, parser):
        """Register the sub-command ``name`` on ``parser`` and return its sub-parser."""
        subparser = parser.add_parser(
            name, help='Use a trained model to make predictions.'
        )
        options = (
            ('--fdata', 'data/ptb/test.conllx', 'Path to test dataset'),
            ('--fpred', 'pred.conllx', 'Prediction path'),
            ('--modelname', 'None', 'Path to trained model'),
            ('--mainpath', 'None', 'Main path'),
        )
        for flag, default, description in options:
            subparser.add_argument(flag, default=default, help=description)
        return subparser

    def rearange(self, items, ids):
        """Reorder ``items`` by the flattened values of ``ids``.

        ``ids`` is an iterable of iterables. After flattening, the i-th value
        is the sort key of ``items[i]``; the items are returned in ascending
        key order.
        """
        flat_keys = [key for group in ids for key in group]
        order = sorted(range(len(flat_keys)), key=lambda pos: flat_keys[pos])
        return [items[pos] for pos in order]

    def __call__(self, args):
        """Run prediction with the settings in ``args`` and write ``args.fpred``."""
        print("Load the model")

        # Both paths are plain string concatenations, so mainpath / model /
        # vocab must already end with a path separator.
        model_dir = args.mainpath + args.model + args.modelname
        vocab_dir = args.mainpath + args.vocab + args.modelname
        modelpath = model_dir + "/model_weights"
        vocabpath = vocab_dir + "/vocab.tag"

        # The training-time configuration is stored inside the checkpoint.
        checkpoint = torch.load(modelpath)
        config = checkpoint['config']

        vocab = torch.load(vocabpath)
        parser = Parser.load(modelpath)
        model = Model(vocab, parser, config, vocab.n_rels)

        print("Load the dataset")
        corpus = Corpus.load(args.fdata)
        numericalized = vocab.numericalize(corpus)
        dataset = TextDataset(numericalized)
        # set the data loader
        loader, ids = batchify(dataset, 5*config.batch_size, config.buckets)

        print("Make predictions on the dataset")
        heads_pred, rels_pred, metric = model.predict(loader)

        print(metric)
        print(f"Save the predicted result to {args.fpred}")

        # presumably ``ids`` holds the original sentence positions per batch,
        # so this restores corpus order - verify against batchify
        ordered_heads = self.rearange(heads_pred, ids)
        ordered_rels = self.rearange(rels_pred, ids)

        corpus.heads = ordered_heads
        corpus.rels = ordered_rels
        corpus.save(args.fpred)
class CoNLLReader:
    """Iterate over the sentences of a tab-separated CoNLL file.

    Every sentence is a list of row dicts keyed by the CoNLL-U column names
    (ID FORM LEMMA UPOSTAG XPOSTAG FEATS HEAD DEPREL DEPS MISC).  Sentences
    are separated by one or more blank lines.
    """

    # Columns that may be missing from a row, in file order after ID and FORM.
    _OPTIONAL_COLUMNS = ("LEMMA", "UPOSTAG", "XPOSTAG", "FEATS",
                         "HEAD", "DEPREL", "DEPS", "MISC")

    def __init__(self, file):
        """
        :param file: open text file object (anything providing ``readline``/``close``)
        """
        self.file = file

    def __iter__(self):
        return self

    def __next__(self):
        sent = self.readsent()
        if sent == []:
            raise StopIteration()
        else:
            return sent

    def readsent(self):
        """Read the next sentence and return its rows; ``[]`` only at end of file.

        ID is converted to ``int``; every other column is kept as a string and
        columns absent from a row default to ``"_"``.

        Blank lines before a sentence are skipped, so a leading or repeated
        blank line no longer looks like end-of-input to ``__next__``.
        """
        sent = []
        while True:
            line = self.file.readline()
            if line == "":
                # readline() returns '' only at EOF (a blank line is '\n')
                break
            row_str = line.strip()
            if row_str == "":
                if sent:
                    break       # blank line terminates the current sentence
                continue        # nothing read yet: skip extra separators
            columns = row_str.split("\t")
            row = {"ID": int(columns[0]), "FORM": columns[1]}
            for offset, key in enumerate(self._OPTIONAL_COLUMNS, start=2):
                row[key] = columns[offset] if len(columns) > offset else "_"
            sent.append(row)
        return sent

    def close(self):
        """Close the underlying file object."""
        self.file.close()


def write_row(row,f):
    """Write one row dict to ``f`` as a tab-separated, newline-terminated line.

    ID and HEAD are passed through ``str``; the other columns must already be
    strings.
    """
    fields = [
        str(row["ID"]),
        row["FORM"],
        row["LEMMA"],
        row["UPOSTAG"],
        row["XPOSTAG"],
        row["FEATS"],
        str(row["HEAD"]),
        row["DEPREL"],
        row["DEPS"],
        row["MISC"],
    ]
    f.write("\t".join(fields) + "\n")
class ParserState:
    """Mutable configuration (stack, buffer, partial arcs) of a transition-based parser."""

    def __init__(self, sentence, transsys=None, goldrels=None):
        """
        :param sentence: token sequence; its first token is expected to be the root symbol
        :param transsys: optional transition system; when given, its
            ``_preparetransitionset`` is called with this state
        :param goldrels: optional gold tree as a list indexed by head id whose
            items are dicts keyed by child id.  When omitted, no projective
            order can be derived and ``proj_order`` is ``None``.
        """
        self.stack = [0]

        self.sentence = sentence
        # sentences should already have a root symbol as the first token
        self.buf = [i + 1 for i in range(len(sentence) - 1)]
        # head and relation labels
        self.head = [[-1, -1] for _ in range(len(sentence))]

        self.pos = [-1 for _ in range(len(sentence))]

        self.goldrels = goldrels

        # build_inorder() walks the gold tree, so it can only run when one is given
        self.proj_order = self.build_inorder() if goldrels is not None else None

        self.transsys = transsys
        if self.transsys is not None:
            self.transsys._preparetransitionset(self)

    def build_inorder(self):
        """Return, for every token id, its rank in the in-order traversal of the gold tree.

        The result is a numpy index array with one entry per token in
        ``stack + buf``.
        """
        ids = self.stack + self.buf
        parent_ids = [-1] * len(ids)
        for i, parent in enumerate(self.goldrels):
            for child in parent.keys():
                # every token may have at most one head
                assert parent_ids[child] == -1
                parent_ids[child] = i
        sentence = [ConllEntry(i, parent_id) for i, parent_id in zip(ids, parent_ids)]
        assert len(sentence) == len(ids) == len(parent_ids)

        new_tokens = inorder(sentence)

        assert len(new_tokens) == len(sentence), 'before inorder,sent:{},inoder:{},parent_ids:{}' \
            ''.format(len(sentence), len(new_tokens), parent_ids)

        # position of each token id inside the traversal
        all_ids = np.argsort(np.asarray([x.id for x in new_tokens]))

        assert len(all_ids) == len(sentence), 'after inorder'

        return all_ids

    def transitionset(self):
        """Return the transition set attached to this state by the transition system."""
        return self._transitionset

    def clone(self):
        """Return a copy whose stack, buffer, heads and tags can be changed independently."""
        res = ParserState([])
        res.sentence = self.sentence
        res.stack = copy(self.stack)
        res.buf = copy(self.buf)
        # copy each [head, relation] pair so edits to the clone do not leak back
        res.head = [copy(pair) for pair in self.head]
        res.pos = copy(self.pos)
        res.goldrels = copy(self.goldrels)
        res.proj_order = copy(self.proj_order)
        res.transsys = self.transsys
        if hasattr(self, '_transitionset'):
            res._transitionset = copy(self._transitionset)
        return res


class ConllEntry:
    """Minimal tree node: a token id and the id of its head."""

    def __init__(self, id, parent_id=None):
        self.id = id
        self.parent_id = parent_id


def inorder(sentence):
    """Return the entries of ``sentence`` in in-order traversal, starting at ``sentence[0]``.

    Entry ids are used as list positions, so ``sentence[k].id`` must equal ``k``.
    Only entries reachable from the first entry are returned.
    """
    def inorder_helper(sentence, i):
        results = []
        # dependents to the left of the head come first ...
        left_children = [entry for entry in sentence[:i] if entry.parent_id == i]
        for child in left_children:
            results += inorder_helper(sentence, child.id)
        results.append(sentence[i])

        # ... then the head itself, then the dependents to its right
        right_children = [entry for entry in sentence[i:] if entry.parent_id == i]
        for child in right_children:
            results += inorder_helper(sentence, child.id)
        return results
    return inorder_helper(sentence, sentence[0].id)
def write_row(row,f):
    """Emit one CoNLL row to ``f``: ten tab-separated columns followed by a newline.

    ID and HEAD are converted with ``str``; all other columns are written as
    they are and therefore have to be strings.
    """
    column_order = ("ID", "FORM", "LEMMA", "UPOSTAG", "XPOSTAG",
                    "FEATS", "HEAD", "DEPREL", "DEPS", "MISC")
    for column in column_order:
        value = row[column]
        if column in ("ID", "HEAD"):
            value = str(value)
        # the last column closes the line, every other one is followed by a tab
        terminator = "\n" if column == "MISC" else "\t"
        f.write(value + terminator)
def kmeans(x, k):
    """Cluster the scalar values ``x`` into at most ``k`` groups.

    Returns ``(centroids, clusters)``: ``clusters`` holds, for every non-empty
    cluster, the list of indices into ``x`` assigned to it, and ``centroids``
    the rounded mean of each of those clusters.
    """
    points = torch.tensor(x, dtype=torch.float)
    # start from k centroids drawn at random among the datapoints
    centers, previous = points[torch.randperm(len(points))[:k]], None
    # label every datapoint with its nearest centroid
    distances, labels = torch.abs_(points.unsqueeze(-1) - centers).min(dim=-1)

    while previous is None or not centers.equal(previous):
        for cluster in range(k):
            if labels.eq(cluster).any():
                continue
            # empty cluster: take the datapoint of the biggest cluster that
            # lies farthest from its centroid and move it over
            membership = labels.eq(torch.arange(k).unsqueeze(-1))
            sizes = membership.sum(dim=-1)
            biggest = membership[sizes.argmax()].nonzero().view(-1)
            farthest = distances[biggest].argmax()
            labels[biggest[farthest]] = cluster
        # recompute the centroids and remember the old ones for the stop test
        centers, previous = torch.tensor([points[labels.eq(cluster)].mean()
                                          for cluster in range(k)]), centers
        # re-assign all datapoints to clusters
        distances, labels = torch.abs_(points.unsqueeze(-1) - centers).min(dim=-1)

    clusters = []
    for cluster in range(k):
        members = labels.eq(cluster)
        if members.any():
            clusters.append(members.nonzero().view(-1).tolist())
    centroids = [round(points[members].mean().item()) for members in clusters]

    return centroids, clusters


def collate_fn(data):
    """Pad each field of a batch to a common length (batch first).

    Returns a generator over the padded tensors, moved to CUDA when it is
    available.
    """
    fields = zip(*data)
    padded = (pad_sequence(field, True) for field in fields)
    if torch.cuda.is_available():
        padded = (tensor.cuda() for tensor in padded)

    return padded


class TextSampler(Sampler):
    """Batch sampler that groups samples of similar length via k-means buckets."""

    def __init__(self, lengths, batch_size, n_buckets, shuffle=False):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        # NOTE: the final bucket count is less than or equal to n_buckets
        self.sizes, self.buckets = kmeans(x=lengths, k=n_buckets)
        # number of batches per bucket: (bucket size * member count) // batch_size, at least 1
        chunks = []
        for size, bucket in zip(self.sizes, self.buckets):
            chunks.append(max(size * len(bucket) // self.batch_size, 1))
        self.chunks = chunks

    def __iter__(self):
        # when shuffling, permute both the buckets and the samples inside each bucket
        ordering = torch.randperm if self.shuffle else torch.arange
        for b in ordering(len(self.buckets)):
            bucket = self.buckets[b]
            for piece in ordering(len(bucket)).chunk(self.chunks[b]):
                yield [bucket[j] for j in piece.tolist()]

    def __len__(self):
        return sum(self.chunks)


class TextDataset(Dataset):
    """Dataset over parallel sequences of fields; item ``i`` is the tuple of every field's i-th entry."""

    def __init__(self, items, n_buckets=1):
        super(TextDataset, self).__init__()

        self.items = items

    def __getitem__(self, index):
        return tuple(field[index] for field in self.items)

    def __len__(self):
        return len(self.items[0])

    @property
    def lengths(self):
        """Number of non-zero entries of every sample of the first field."""
        counts = []
        for sample in self.items[0]:
            counts.append(len(sample.nonzero()))
        return counts


def batchify(dataset, batch_size, n_buckets=64, shuffle=False):
    """Build a bucketed ``DataLoader`` for ``dataset``; returns ``(loader, buckets)``."""
    sampler = TextSampler(lengths=dataset.lengths,
                          batch_size=batch_size,
                          n_buckets=n_buckets,
                          shuffle=shuffle)
    loader = DataLoader(dataset=dataset,
                        batch_sampler=sampler,
                        collate_fn=collate_fn)

    return loader, sampler.buckets
certifi=2019.11.28=py36_0 14 | - cffi=1.13.2=py36h2e261b9_0 15 | - chardet=3.0.4=py36_1003 16 | - cryptography=2.8=py36h1ba5d50_0 17 | - cudatoolkit=10.0.130=0 18 | - docutils=0.15.2=py36_0 19 | - freetype=2.9.1=h8a8886c_1 20 | - idna=2.8=py36_0 21 | - intel-openmp=2019.4=243 22 | - jmespath=0.9.4=py_0 23 | - jpeg=9b=h024ee3a_2 24 | - ld_impl_linux-64=2.33.1=h53a641e_7 25 | - libedit=3.1.20181209=hc058e9b_0 26 | - libffi=3.2.1=hd88cf55_4 27 | - libgcc-ng=9.1.0=hdf63c60_0 28 | - libgfortran-ng=7.3.0=hdf63c60_0 29 | - libpng=1.6.37=hbc83047_0 30 | - libstdcxx-ng=9.1.0=hdf63c60_0 31 | - libtiff=4.1.0=h2733197_0 32 | - mkl=2019.4=243 33 | - ncurses=6.1=he6710b0_1 34 | - ninja=1.9.0=py36hfd86e86_0 35 | - numpy=1.14.2=py36hdbf6ddf_0 36 | - olefile=0.46=py36_0 37 | - openssl=1.1.1d=h7b6447c_3 38 | - pillow=7.0.0=py36hb39fc2d_0 39 | - pip=19.3.1=py36_0 40 | - pycparser=2.19=py36_0 41 | - pyopenssl=19.1.0=py36_0 42 | - pysocks=1.7.1=py36_0 43 | - python=3.6.10=h0371630_0 44 | - python-dateutil=2.8.0=py36_0 45 | - readline=7.0=h7b6447c_5 46 | - regex=2019.12.9=py36h7b6447c_0 47 | - requests=2.22.0=py36_1 48 | - s3transfer=0.2.1=py36_0 49 | - setuptools=44.0.0=py36_0 50 | - six=1.13.0=py36_0 51 | - sqlite=3.30.1=h7b6447c_0 52 | - tk=8.6.8=hbc83047_0 53 | - torchvision=0.5.0=py36_cu100 54 | - tqdm=4.41.1=py_0 55 | - urllib3=1.25.7=py36_0 56 | - wheel=0.33.6=py36_0 57 | - xz=5.2.4=h14c3975_4 58 | - zlib=1.2.11=h7b6447c_3 59 | - zstd=1.3.7=h0b5b093_0 60 | - pip: 61 | - alabaster==0.7.12 62 | - allennlp==0.9.0 63 | - attrs==19.3.0 64 | - babel==2.8.0 65 | - blis==0.2.4 66 | - click==7.0 67 | - conllu==1.3.1 68 | - cycler==0.10.0 69 | - cymem==2.0.3 70 | - editdistance==0.5.3 71 | - filelock==3.0.12 72 | - flaky==3.6.1 73 | - flask==1.1.1 74 | - flask-cors==3.0.8 75 | - ftfy==5.6 76 | - gevent==1.4.0 77 | - greenlet==0.4.15 78 | - h5py==2.10.0 79 | - imagesize==1.2.0 80 | - importlib-metadata==1.4.0 81 | - itsdangerous==1.1.0 82 | - jinja2==2.10.3 83 | - joblib==0.14.1 84 | - 
# One sentence in column-major form: every field holds one value per token.
Sentence = namedtuple(typename='Sentence',
                      field_names=['ID', 'FORM', 'LEMMA', 'CPOS',
                                   'POS', 'FEATS', 'HEAD', 'DEPREL',
                                   'PHEAD', 'PDEPREL'],
                      defaults=[None]*10)


## transit = ['L', 'R', 'S','H']
def read_seq(in_file, vocab):
    """Read gold transition sequences from ``in_file``.

    Sequences are separated by blank lines.  Lines are recognised by their
    number of whitespace-separated fields: three fields is a ``Shift``
    (action 2), one field is a ``Swap`` (action 3), two fields is
    ``Right-Arc <rel>`` (action 1) or ``Left-Arc <rel>`` (action 0).  Arc
    labels are mapped with ``vocab.rel2id``; Shift and Swap get label 0.

    Returns a list of ``{'act': [...], 'rel': [...]}`` dicts, one per sequence.
    The last sequence is kept even when the file does not end with a blank
    line, and repeated blank lines do not create empty sequences.
    """
    gold_seq, arcs, seq = [], [], []
    with open(in_file, 'r') as f:
        for raw in f:
            line = raw.strip().split()
            if len(line) == 0:
                if seq:
                    gold_seq.append({'act': seq, 'rel': arcs})
                arcs, seq = [], []
            elif len(line) == 3:
                assert line[0] == 'Shift'
                seq.append(2)
                arcs.append(0)
            elif len(line) == 1:
                assert line[0] == 'Swap'
                seq.append(3)
                arcs.append(0)
            elif len(line) == 2:
                if line[0].startswith('R'):
                    assert line[0] == 'Right-Arc'
                    seq.append(1)
                    arcs.append(vocab.rel2id(line[1]))
                elif line[0].startswith('L'):
                    assert line[0] == 'Left-Arc'
                    seq.append(0)
                    arcs.append(vocab.rel2id(line[1]))
    # flush a final sequence that is not followed by a blank line
    if seq:
        gold_seq.append({'act': seq, 'rel': arcs})
    return gold_seq


class Corpus(object):
    """In-memory collection of CoNLL sentences with column-wise accessors."""

    # Symbol prepended to every sentence by the accessors below.
    # NOTE(review): the value is an empty string here -- confirm this is the
    # intended root token and not a marker lost in transit.
    ROOT = ''

    def __init__(self, sentences):
        super(Corpus, self).__init__()

        self.sentences = sentences
        self.ids = [i+1 for i in range(len(sentences))]

    def __len__(self):
        return len(self.sentences)

    def __repr__(self):
        # one tab-separated line per token, one blank line between sentences;
        # fields that are unset (None / empty) are left out
        return '\n'.join(
            '\n'.join('\t'.join(map(str, i))
                      for i in zip(*(f for f in sentence if f))) + '\n'
            for sentence in self
        )

    def __getitem__(self, index):
        return self.sentences[index]

    @property
    def words(self):
        return [[self.ROOT] + list(sentence.FORM) for sentence in self]

    @property
    def tags(self):
        return [[self.ROOT] + list(sentence.CPOS) for sentence in self]

    @property
    def heads(self):
        return [[0] + list(map(int, sentence.HEAD)) for sentence in self]

    @property
    def rels(self):
        return [[self.ROOT] + list(sentence.DEPREL) for sentence in self]

    @heads.setter
    def heads(self, sequences):
        self.sentences = [sentence._replace(HEAD=sequence)
                          for sentence, sequence in zip(self, sequences)]

    @rels.setter
    def rels(self, sequences):
        self.sentences = [sentence._replace(DEPREL=sequence)
                          for sentence, sequence in zip(self, sequences)]

    @classmethod
    def load(cls, fname):
        """Read a whitespace-separated CoNLL file into a ``Corpus``.

        Sentences are separated by blank (or whitespace-only) lines.  The final
        sentence is kept even without a trailing blank line, and consecutive
        blank lines do not produce empty sentences.
        """
        sentences, rows = [], []
        with open(fname, 'r') as f:
            for line in f:
                columns = line.split()
                if columns:
                    rows.append(columns)
                elif rows:
                    # blank line: transpose the token rows into per-field tuples
                    sentences.append(Sentence(*zip(*rows)))
                    rows = []
        if rows:
            sentences.append(Sentence(*zip(*rows)))

        return cls(sentences)

    def save(self, fname):
        """Write the corpus to ``fname`` in the format produced by ``__repr__``."""
        with open(fname, 'w') as f:
            f.write(f"{self}\n")
def write_row(row):
    """Print one CoNLL row dict to stdout as ten tab-separated columns plus a newline.

    ID and HEAD are converted with ``str``; the remaining columns must be strings.
    """
    sys.stdout.write(str(row["ID"]) + "\t")
    sys.stdout.write(row["FORM"] + "\t")
    sys.stdout.write(row["LEMMA"] + "\t")
    sys.stdout.write(row["UPOSTAG"] + "\t")
    sys.stdout.write(row["XPOSTAG"] + "\t")
    sys.stdout.write(row["FEATS"] + "\t")
    sys.stdout.write(str(row["HEAD"]) + "\t")
    sys.stdout.write(row["DEPREL"] + "\t")
    sys.stdout.write(row["DEPS"] + "\t")
    sys.stdout.write(row["MISC"] + "\n")


# Command line: argv[1] = CoNLL file to annotate, argv[2] = base directory,
# argv[3] = model name used to locate the two pickle files below.
conll_reader = CoNLLReader(open(sys.argv[1]))
#pos_reader = sys.argv[2])
# Predicted dependencies: <argv[2]>/dependency/<argv[3]>.pkl
# NOTE(review): pickle.load runs arbitrary code from the file -- only use on trusted files.
with open(str(sys.argv[2]+"/dependency/"+sys.argv[3])+'.pkl', 'rb') as f:
    dependencies_total_old = pickle.load(f)
# Pickled parser/vocabulary object: <argv[2]>/vocab/<argv[3]>.pkl
# (only its ``n_deprel`` and ``id2tok`` attributes are used here)
with open(str(sys.argv[2])+"/vocab/"+sys.argv[3]+'.pkl', 'rb') as f:
    parser = pickle.load(f)
#print(dependencies_total[0])
# Read every sentence up front so the lengths are known before reordering.
conll_read = []
for conll_sent in conll_reader:
    conll_read.append(conll_sent)

lengths = [len(x)+1 for x in conll_read]

# index_sorted: sentence indices ordered by decreasing length (stable for ties).
# index_new: the inverse of that permutation, i.e. for every sentence its
# position in the length-sorted order.
# presumably the pickled predictions are stored in that length-sorted order --
# TODO confirm against the code that writes the pickle.
index_sorted = sorted(range(len(lengths)), key=lambda k: lengths[k],reverse=True)
index_new = sorted(range(len(index_sorted)), key=lambda k: index_sorted[k])

# Bring the predictions back into the order of the sentences in the file.
dependencies_total = [dependencies_total_old[i] for i in index_new]

assert len(dependencies_total) == len(conll_read)

for conll_sent, dependencies in zip(conll_read, dependencies_total):
    heads = []
    deprels = []
    # each item is indexed as item[0] (written to HEAD), item[1] (sort key) and
    # item[2] (relation id); sorting is in place on the loaded list
    dependencies.sort(key=lambda row: row[1])
    for item in dependencies:
        heads.append(item[0])
        #print(item)
        # fold relation ids outside [0, n_deprel) back into range by one period
        if item[2] < 0:
            x = item[2]+parser.n_deprel
        elif item[2] >= parser.n_deprel:
            x = item[2] - parser.n_deprel
        else:
            x = item[2]
        assert 0<= x < parser.n_deprel
        # keep the part of the vocabulary token after its first ':'
        deprels.append(parser.id2tok[x].split(":")[1])
    assert len(heads) == len(deprels)
    # fewer predictions than tokens: pad with head 0 and relation '-'
    if len(conll_sent) > len(heads):
        dif = len(conll_sent) - len(heads)
        heads = heads + [0]*dif
        deprels = deprels + ['-']*dif
    assert len(conll_sent) == len(heads) == len(deprels), "conll:{},pred:{},deprel:{}".format(len(conll_sent),len(heads),len(deprels))
    for conll_row, head, deprel in zip(conll_sent, heads, deprels):
        conll_row["HEAD"] = head
        conll_row["DEPREL"] = deprel
        write_row(conll_row)
    # blank line terminates the sentence
    sys.stdout.write("\n")


conll_reader.close()
class ScalarMixWithDropout(torch.nn.Module):
    """
    Computes a parameterised scalar mixture of N tensors, ``mixture = gamma * sum(s_k * tensor_k)``
    where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters.

    If ``do_layer_norm=True`` then apply layer normalization to each tensor before weighting.

    If ``dropout > 0``, then during training each scalar weight is, with probability
    ``dropout``, replaced by ``dropout_value`` before the softmax (a very negative
    number, so its softmax mass becomes ~0 and is redistributed over the remaining
    weights).  Dropout is a regulariser and is skipped in evaluation mode, so that
    ``module.eval()`` gives deterministic mixtures.
    """
    def __init__(self,
                 mixture_size: int,
                 do_layer_norm: bool = False,
                 initial_scalar_parameters: List[float] = None,
                 trainable: bool = True,
                 dropout: float = None,
                 dropout_value: float = -1e20) -> None:
        super(ScalarMixWithDropout, self).__init__()
        self.mixture_size = mixture_size
        self.do_layer_norm = do_layer_norm
        self.dropout = dropout

        if initial_scalar_parameters is None:
            initial_scalar_parameters = [0.0] * mixture_size
        elif len(initial_scalar_parameters) != mixture_size:
            # report the offending length, not the list itself
            raise ConfigurationError("Length of initial_scalar_parameters {} differs "
                                     "from mixture_size {}".format(
                                         len(initial_scalar_parameters), mixture_size))

        self.scalar_parameters = ParameterList(
            [Parameter(torch.FloatTensor([initial_scalar_parameters[i]]),
                       requires_grad=trainable) for i
             in range(mixture_size)])
        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)

        if self.dropout:
            # buffers follow the module across devices: scratch space for the
            # random draw and the constant used for dropped weights
            dropout_mask = torch.zeros(len(self.scalar_parameters))
            dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value)
            self.register_buffer("dropout_mask", dropout_mask)
            self.register_buffer("dropout_fill", dropout_fill)

    def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        Compute a weighted average of the ``tensors``.  The input tensors can be any shape
        with at least two dimensions, but must all be the same shape.

        When ``do_layer_norm=True``, the ``mask`` is required input.  If the ``tensors`` are
        dimensioned  ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
        ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
        ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.

        When ``do_layer_norm=False`` the ``mask`` is ignored.

        Raises ``ConfigurationError`` when the number of tensors differs from
        ``mixture_size``.
        """
        if len(tensors) != self.mixture_size:
            raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                     "mix {} tensors.".format(len(tensors), self.mixture_size))

        def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
            # mean and variance are taken over the unmasked elements only
            tensor_masked = tensor * broadcast_mask
            mean = torch.sum(tensor_masked) / num_elements_not_masked
            variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
            return (tensor - mean) / torch.sqrt(variance + 1E-12)

        weights = torch.cat([parameter for parameter in self.scalar_parameters])

        # layer dropout only while training; evaluation must be deterministic
        if self.dropout and self.training:
            weights = torch.where(self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill)

        normed_weights = torch.nn.functional.softmax(weights, dim=0)
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)

        if not self.do_layer_norm:
            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(weight * tensor)
            return self.gamma * sum(pieces)

        else:
            mask_float = mask.float()
            broadcast_mask = mask_float.unsqueeze(-1)
            input_dim = tensors[0].size(-1)
            num_elements_not_masked = torch.sum(mask_float) * input_dim

            pieces = []
            for weight, tensor in zip(normed_weights, tensors):
                pieces.append(weight * _do_layer_norm(tensor,
                                                      broadcast_mask, num_elements_not_masked))
            return self.gamma * sum(pieces)
class Classifier(nn.Module):
    """Two-layer feed-forward classifier: Linear -> LeakyReLU -> Linear.

    Both weight matrices are initialised with Xavier-uniform.
    """
    def __init__(self, input_size, hidden_size, n_labels):
        super(Classifier, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        nn.init.xavier_uniform_(self.layer1.weight)
        self.activation = nn.LeakyReLU()
        self.layer2 = nn.Linear(hidden_size, n_labels)
        nn.init.xavier_uniform_(self.layer2.weight)

    def forward(self, input_label):
        """Return the unnormalised scores of ``layer2`` for ``input_label``."""
        output = self.layer1(input_label)
        output = self.activation(output)
        output = self.layer2(output)

        return output

class Parser(nn.Module):
    """Transition-based parser: a BERT encoder plus one classifier for actions and one for labels."""

    def __init__(self, config, bertmodel):
        """
        :param config: run configuration; the attributes read here are main_path,
            modelname, n_attention_layer, n_rels, layernorm_value, layernorm_key,
            input_graph, n_trans, pad_index and unk_index
        :param bertmodel: module whose ``state_dict()`` initialises the encoder
        """
        super(Parser, self).__init__()

        self.config = config

        # build and load BERT G2G model
        bertconfig = BertConfig.from_pretrained(
                config.main_path+"/model"+"/model_"+config.modelname+'/config.json')

        # override the stored BERT configuration with the run's settings
        bertconfig.num_hidden_layers = config.n_attention_layer
        bertconfig.label_size = config.n_rels
        bertconfig.layernorm_value = config.layernorm_value
        bertconfig.layernorm_key = config.layernorm_key

        if self.config.input_graph:
            self.bert = BertGraphModel(bertconfig)
        else:
            self.bert = BertBaseModel(bertconfig)

        # strict=False: parameters missing from / unknown to ``bertmodel`` are ignored
        self.bert.load_state_dict(bertmodel.state_dict(),strict=False)
        # the action classifier consumes 3 concatenated hidden vectors, the label classifier 2
        self.mlp = Classifier(3*bertconfig.hidden_size,bertconfig.hidden_size,config.n_trans)
        self.mlp_rel = Classifier(2*bertconfig.hidden_size,bertconfig.hidden_size,config.n_rels)

        self.pad_index = config.pad_index
        self.unk_index = config.unk_index

    # build proper features for graph output mechanism
    def merge(self,states):
        """Stack ``state.feature()`` of every state into one tensor."""
        features = [state.feature() for state in states]
        features = torch.stack(features)
        return features

    def merge_label(self,states):
        """Stack ``state.feature_label()`` of every state into one tensor."""
        features = [state.feature_label() for state in states]
        features = torch.stack(features)
        return features

    # build graph input matrices for a batch
    def mix_graph(self,states):
        """Stack the ``graph`` and ``label`` tensors of every state; returns ``(graphs, labels)``."""
        graphs = torch.stack([state.graph for state in states])
        labels = torch.stack([state.label for state in states])
        return graphs,labels

    def forward(self, words, tags, masks, states, actions=None, rels=None):
        """Run the parser step by step over a batch of states.

        With ``actions``/``rels`` given, the states are advanced with those gold
        values for ``actions.size()[1]`` steps and ``(output_acts, output_rels)``
        -- the per-step action and label scores -- are returned.  Without them,
        the states are advanced with the model's own predictions for at most
        ``config.max_seq`` steps (stopping early once every state reports
        ``finished()``) and the updated ``states`` are returned.

        ``masks`` is accepted but not used; the padding mask is recomputed from
        ``words``.
        """

        mask = words.ne(self.pad_index)

        batch_size = words.size()[0]
        if actions is None:
            output_acts = torch.zeros((batch_size,self.config.max_seq,self.config.n_trans)).to(words.device)
            output_rels = torch.zeros((batch_size, self.config.max_seq,self.config.n_rels)).to(words.device)
            max_seq = self.config.max_seq
        else:
            output_acts = torch.zeros((batch_size,actions.size()[1],self.config.n_trans)).to(words.device)
            output_rels = torch.zeros((batch_size, actions.size()[1],self.config.n_rels)).to(words.device)
            max_seq = actions.size()[1]
        step = 0
        while step < max_seq:
            # the encoder is re-run at every step; with input_graph the current
            # graph/label tensors of the states are fed in as well
            if self.config.input_graph:
                graphs,labels = self.mix_graph(states)
                embs = self.bert(words,tags,mask,graphs,labels)[0]
            else:
                embs = self.bert(words, tags, mask)[0]
            # gather the embeddings at each state's feature positions and flatten per sample
            feats = self.merge(states)
            out = torch.stack([embs[i][feats[i]] for i in range(feats.size()[0])]).view(feats.size()[0],-1).clone()
            out_arc = self.mlp(out)

            feats_label = self.merge_label(states)
            out_label_input = torch.stack([embs[i][feats_label[i]] for i in range(feats_label.size()[0])])\
                .view(feats_label.size()[0],-1).clone()
            out_rel = self.mlp_rel(out_label_input)

            # mask PAD and BERT relations
            # (label ids 0 and 1 are pushed down by 1000 so they are never the argmax)
            out_rel[:,0:2] = out_rel[:,0:2] - 1000
            output_acts[:,step] = out_arc
            output_rels[:,step] = out_rel

            # predict action and label for the next iteration (use gold during training)
            if actions is None:
                # presumably legal_act() marks legal actions with 1 and illegal ones
                # with 0, so the +1000 bonus restricts the argmax to legal actions
                # -- TODO confirm against the state class
                legal_actions = torch.tensor([state.legal_act() for state in states])\
                    .long().to(words.device)
                _,act = torch.max(out_arc+1000*legal_actions,dim=1)
                _,rel = torch.max(out_rel,dim=1)
            else:
                act = actions[:,step]
                rel = rels[:,step]

            # advance every state with its chosen action/label
            for i,(state,a,r) in enumerate(zip(states,act,rel)):
                state.update(a,r)

            if all([state.finished() for state in states]) and actions is None:
                break
            step+=1

        if actions is None:
            return states
        else:
            return output_acts,output_rels
    @classmethod
    def load(cls, fname):
        """Rebuild a parser from a checkpoint written by ``save`` and move it to CUDA when available."""
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        state = torch.load(fname, map_location=device)
        parser = cls(state['config'],state['bertmodel'])

        # strict=False: key mismatches between checkpoint and module are ignored
        parser.load_state_dict(state['state_dict'],strict=False)
        parser.to(device)

        return parser

    def save(self, fname):
        """Write the encoder module object, the config and the full state dict to ``fname``."""
        state = {
            'bertmodel':self.bert,
            'config': self.config,
            'state_dict': self.state_dict()
        }
        torch.save(state, fname)
- jpeg=9b=h024ee3a_2 18 | - ld_impl_linux-64=2.33.1=h53a641e_7 19 | - libedit=3.1.20181209=hc058e9b_0 20 | - libffi=3.2.1=hd88cf55_4 21 | - libgcc-ng=9.1.0=hdf63c60_0 22 | - libgfortran-ng=7.3.0=hdf63c60_0 23 | - libpng=1.6.37=hbc83047_0 24 | - libstdcxx-ng=9.1.0=hdf63c60_0 25 | - libtiff=4.1.0=h2733197_0 26 | - mkl=2019.4=243 27 | - mkl-service=2.3.0=py37he904b0f_0 28 | - mkl_fft=1.0.15=py37ha843d7b_0 29 | - mkl_random=1.1.0=py37hd6b4f25_0 30 | - ncurses=6.1=he6710b0_1 31 | - ninja=1.9.0=py37hfd86e86_0 32 | - numpy=1.17.4=py37hc1035e2_0 33 | - numpy-base=1.17.4=py37hde5b4d6_0 34 | - olefile=0.46=py37_0 35 | - openssl=1.1.1g=h7b6447c_0 36 | - pillow=7.0.0=py37hb39fc2d_0 37 | - pip=19.3.1=py37_0 38 | - pycparser=2.19=py37_0 39 | - python=3.7.6=h0371630_2 40 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 41 | - readline=7.0=h7b6447c_5 42 | - regex=2019.12.9=py37h7b6447c_0 43 | - scikit-learn=0.22.1=py37hd81dba3_0 44 | - setuptools=44.0.0=py37_0 45 | - six=1.13.0=py37_0 46 | - sqlite=3.30.1=h7b6447c_0 47 | - tk=8.6.8=hbc83047_0 48 | - torchvision=0.5.0=py37_cu100 49 | - tqdm=4.41.1=py_0 50 | - wheel=0.33.6=py37_0 51 | - xz=5.2.4=h14c3975_4 52 | - zlib=1.2.11=h7b6447c_3 53 | - zstd=1.3.7=h0b5b093_0 54 | - pip: 55 | - absl-py==0.9.0 56 | - alabaster==0.7.12 57 | - allennlp==1.1.0 58 | - allennlp-models==1.1.0 59 | - astor==0.8.1 60 | - astunparse==1.6.3 61 | - attrs==19.3.0 62 | - babel==2.8.0 63 | - backcall==0.1.0 64 | - bleach==3.1.1 65 | - blis==0.2.4 66 | - boto3==1.14.60 67 | - botocore==1.17.60 68 | - cachetools==4.0.0 69 | - chardet==3.0.4 70 | - click==7.0 71 | - conllu==4.1 72 | - cycler==0.10.0 73 | - cymem==2.0.3 74 | - cython==0.29.16 75 | - decorator==4.4.2 76 | - defusedxml==0.6.0 77 | - dill==0.3.1.1 78 | - docutils==0.15.2 79 | - editdistance==0.5.3 80 | - entrypoints==0.3 81 | - fairseq==0.9.0 82 | - filelock==3.0.12 83 | - flaky==3.6.1 84 | - flask==1.1.1 85 | - flask-cors==3.0.8 86 | - ftfy==5.6 87 | - future==0.18.2 88 | - gast==0.3.3 89 | - 
gevent==1.4.0 90 | - google-auth==1.11.2 91 | - google-auth-oauthlib==0.4.1 92 | - google-pasta==0.1.8 93 | - greenlet==0.4.15 94 | - grpcio==1.27.2 95 | - h5py==2.10.0 96 | - idna==2.8 97 | - imagesize==1.2.0 98 | - importlib-metadata==1.4.0 99 | - ipykernel==5.1.4 100 | - ipython==7.13.0 101 | - ipython-genutils==0.2.0 102 | - ipywidgets==7.5.1 103 | - itsdangerous==1.1.0 104 | - jedi==0.16.0 105 | - jinja2==2.10.3 106 | - jmespath==0.9.4 107 | - jsonlines==1.2.0 108 | - jsonnet==0.14.0 109 | - jsonpickle==1.2 110 | - jsonschema==3.2.0 111 | - jupyter==1.0.0 112 | - jupyter-client==6.0.0 113 | - jupyter-console==6.1.0 114 | - jupyter-core==4.6.3 115 | - keras==2.4.3 116 | - keras-applications==1.0.8 117 | - keras-preprocessing==1.1.2 118 | - kiwisolver==1.1.0 119 | - markdown==3.2.1 120 | - markupsafe==1.1.1 121 | - matplotlib==3.1.2 122 | - mecab-python3==0.996.5 123 | - mistune==0.8.4 124 | - more-itertools==8.1.0 125 | - murmurhash==1.0.2 126 | - nbconvert==5.6.1 127 | - nbformat==5.0.4 128 | - nltk==3.4.5 129 | - notebook==6.0.3 130 | - numpydoc==0.9.2 131 | - oauthlib==3.1.0 132 | - opt-einsum==3.1.0 133 | - overrides==3.1.0 134 | - packaging==20.0 135 | - pandocfilters==1.4.2 136 | - parsimonious==0.8.1 137 | - parso==0.6.2 138 | - pexpect==4.8.0 139 | - pickleshare==0.7.5 140 | - plac==0.9.6 141 | - pluggy==0.13.1 142 | - portalocker==1.6.0 143 | - preshed==2.0.1 144 | - prometheus-client==0.7.1 145 | - prompt-toolkit==3.0.4 146 | - protobuf==3.11.2 147 | - ptyprocess==0.6.0 148 | - py==1.8.1 149 | - py-rouge==1.1 150 | - pyasn1==0.4.8 151 | - pyasn1-modules==0.2.8 152 | - pyconll==2.2.1 153 | - pygments==2.5.2 154 | - pyparsing==2.4.6 155 | - pyrsistent==0.15.7 156 | - pytest==5.3.4 157 | - python-dateutil==2.8.1 158 | - pytorch-pretrained-bert==0.6.2 159 | - pytorch-transformers==1.1.0 160 | - pytz==2019.3 161 | - pyyaml==5.3.1 162 | - pyzmq==19.0.0 163 | - qtconsole==4.7.1 164 | - qtpy==1.9.0 165 | - requests==2.22.0 166 | - requests-oauthlib==1.3.0 167 
| - responses==0.10.9 168 | - rsa==4.0 169 | - s3transfer==0.3.0 170 | - sacrebleu==1.4.6 171 | - sacremoses==0.0.38 172 | - scipy==1.4.1 173 | - send2trash==1.5.0 174 | - sentencepiece==0.1.85 175 | - snowballstemmer==2.0.0 176 | - spacy==2.1.9 177 | - sphinx==2.3.1 178 | - sphinxcontrib-applehelp==1.0.1 179 | - sphinxcontrib-devhelp==1.0.1 180 | - sphinxcontrib-htmlhelp==1.0.2 181 | - sphinxcontrib-jsmath==1.0.1 182 | - sphinxcontrib-qthelp==1.0.2 183 | - sphinxcontrib-serializinghtml==1.1.3 184 | - sqlparse==0.3.0 185 | - srsly==1.0.1 186 | - tabulate==0.8.7 187 | - tensorboard==2.3.0 188 | - tensorboard-plugin-wit==1.7.0 189 | - tensorboardx==2.0 190 | - tensorflow==2.3.0 191 | - tensorflow-estimator==2.3.0 192 | - termcolor==1.1.0 193 | - terminado==0.8.3 194 | - testpath==0.4.4 195 | - thinc==7.0.8 196 | - tokenizers==0.0.11 197 | - torch==1.4.0 198 | - torchtext==0.2.1 199 | - tornado==6.0.4 200 | - traitlets==4.3.3 201 | - transformers==2.4.1 202 | - typing==3.7.4.1 203 | - unidecode==1.1.1 204 | - urllib3==1.25.7 205 | - wasabi==0.6.0 206 | - wcwidth==0.1.8 207 | - webencodings==0.5.1 208 | - werkzeug==0.16.0 209 | - widgetsnbextension==3.5.1 210 | - word2number==1.1 211 | - wrapt==1.12.0 212 | - zipp==2.0.0 213 | prefix: /idiap/user/amohammadshahi/miniconda3/envs/g2g-gr 214 | -------------------------------------------------------------------------------- /sample_data/dev.conll: -------------------------------------------------------------------------------- 1 | 1 Influential _ JJ JJ _ 2 amod _ _ 2 | 2 members _ NNS NNS _ 10 nsubj _ _ 3 | 3 of _ IN IN _ 2 prep _ _ 4 | 4 the _ DT DT _ 6 det _ _ 5 | 5 House _ NNP NNP _ 6 nn _ _ 6 | 6 Ways _ NNP NNP _ 3 pobj _ _ 7 | 7 and _ CC CC _ 6 cc _ _ 8 | 8 Means _ NNP NNP _ 9 nn _ _ 9 | 9 Committee _ NNP NNP _ 6 conj _ _ 10 | 10 introduced _ VBD VBD _ 0 root _ _ 11 | 11 legislation _ NN NN _ 10 dobj _ _ 12 | 12 that _ WDT WDT _ 14 nsubj _ _ 13 | 13 would _ MD MD _ 14 aux _ _ 14 | 14 restrict _ VB VB _ 11 rcmod _ _ 15 
| 15 how _ WRB WRB _ 22 advmod _ _ 16 | 16 the _ DT DT _ 20 det _ _ 17 | 17 new _ JJ JJ _ 20 amod _ _ 18 | 18 savings-and-loan _ JJ JJ _ 20 nn _ _ 19 | 19 bailout _ NN NN _ 20 nn _ _ 20 | 20 agency _ NN NN _ 22 nsubj _ _ 21 | 21 can _ MD MD _ 22 aux _ _ 22 | 22 raise _ VB VB _ 14 ccomp _ _ 23 | 23 capital _ NN NN _ 22 dobj _ _ 24 | 24 , _ , , _ 14 punct _ _ 25 | 25 creating _ VBG VBG _ 14 xcomp _ _ 26 | 26 another _ DT DT _ 28 det _ _ 27 | 27 potential _ JJ JJ _ 28 amod _ _ 28 | 28 obstacle _ NN NN _ 25 dobj _ _ 29 | 29 to _ TO TO _ 28 prep _ _ 30 | 30 the _ DT DT _ 31 det _ _ 31 | 31 government _ NN NN _ 33 poss _ _ 32 | 32 's _ POS POS _ 31 possessive _ _ 33 | 33 sale _ NN NN _ 29 pobj _ _ 34 | 34 of _ IN IN _ 33 prep _ _ 35 | 35 sick _ JJ JJ _ 36 amod _ _ 36 | 36 thrifts _ NNS NNS _ 34 pobj _ _ 37 | 37 . _ . . _ 10 punct _ _ 38 | 39 | 1 The _ DT DT _ 2 det _ _ 40 | 2 bill _ NN NN _ 17 nsubj _ _ 41 | 3 , _ , , _ 2 punct _ _ 42 | 4 whose _ WP$ WP$ _ 5 poss _ _ 43 | 5 backers _ NNS NNS _ 6 nsubj _ _ 44 | 6 include _ VBP VBP _ 2 rcmod _ _ 45 | 7 Chairman _ NNP NNP _ 9 nn _ _ 46 | 8 Dan _ NNP NNP _ 9 nn _ _ 47 | 9 Rostenkowski _ NNP NNP _ 6 dobj _ _ 48 | 10 -LRB- _ -LRB- -LRB- _ 11 punct _ _ 49 | 11 D. _ NNP NNP _ 9 appos _ _ 50 | 12 , _ , , _ 11 punct _ _ 51 | 13 Ill. _ NNP NNP _ 11 dep _ _ 52 | 14 -RRB- _ -RRB- -RRB- _ 11 punct _ _ 53 | 15 , _ , , _ 2 punct _ _ 54 | 16 would _ MD MD _ 17 aux _ _ 55 | 17 prevent _ VB VB _ 0 root _ _ 56 | 18 the _ DT DT _ 21 det _ _ 57 | 19 Resolution _ NNP NNP _ 21 nn _ _ 58 | 20 Trust _ NNP NNP _ 21 nn _ _ 59 | 21 Corp. 
_ NNP NNP _ 17 dobj _ _ 60 | 22 from _ IN IN _ 17 prep _ _ 61 | 23 raising _ VBG VBG _ 22 pcomp _ _ 62 | 24 temporary _ JJ JJ _ 26 amod _ _ 63 | 25 working _ JJ JJ _ 26 amod _ _ 64 | 26 capital _ NN NN _ 23 dobj _ _ 65 | 27 by _ IN IN _ 17 prep _ _ 66 | 28 having _ VBG VBG _ 27 pcomp _ _ 67 | 29 an _ DT DT _ 31 det _ _ 68 | 30 RTC-owned _ JJ JJ _ 31 amod _ _ 69 | 31 bank _ NN NN _ 28 dobj _ _ 70 | 32 or _ CC CC _ 31 cc _ _ 71 | 33 thrift _ NN NN _ 35 nn _ _ 72 | 34 issue _ NN NN _ 35 nn _ _ 73 | 35 debt _ NN NN _ 31 conj _ _ 74 | 36 that _ WDT WDT _ 40 nsubjpass _ _ 75 | 37 would _ MD MD _ 40 aux _ _ 76 | 38 n't _ RB RB _ 40 neg _ _ 77 | 39 be _ VB VB _ 40 auxpass _ _ 78 | 40 counted _ VBN VBN _ 31 rcmod _ _ 79 | 41 on _ IN IN _ 40 prep _ _ 80 | 42 the _ DT DT _ 44 det _ _ 81 | 43 federal _ JJ JJ _ 44 amod _ _ 82 | 44 budget _ NN NN _ 41 pobj _ _ 83 | 45 . _ . . _ 17 punct _ _ 84 | -------------------------------------------------------------------------------- /senttr/transition.py: -------------------------------------------------------------------------------- 1 | class ArcStandard(TransitionSystem): 2 | @classmethod 3 | def actions_list(self): 4 | return ['Shift', 'Swap','Left-Arc', 'Right-Arc'] 5 | 6 | def _preparetransitionset(self, parserstate): 7 | SHIFT = self.mappings['action']['Shift'] 8 | SWAP = self.mappings['action']['Swap'] 9 | LEFTARC = self.mappings['action']['Left-Arc'] 10 | RIGHTARC = self.mappings['action']['Right-Arc'] 11 | 12 | stack, buf, head = parserstate.stack, parserstate.buf, parserstate.head 13 | 14 | t = [] 15 | 16 | if len(buf) > 0: 17 | t += [(SHIFT,)] 18 | 19 | if len(stack) > 1: 20 | t += [(SWAP,)] 21 | 22 | if len(stack) > 1: 23 | t += [(LEFTARC,)] 24 | 25 | if len(stack) > 1: 26 | t += [(RIGHTARC,)] 27 | 28 | parserstate._transitionset = t 29 | 30 | def advance(self, parserstate, action): 31 | SHIFT = self.mappings['action']['Shift'] 32 | SWAP = self.mappings['action']['Swap'] 33 | LEFTARC = self.mappings['action']['Left-Arc'] 34 
| RIGHTARC = self.mappings['action']['Right-Arc'] 35 | 36 | RELS = len(self.mappings['rel']) 37 | cand = parserstate.transitionset() 38 | 39 | if isinstance(action, int): 40 | a, rel = self.tuple_trans_from_int(cand, action) 41 | else: 42 | rel = action[-1] 43 | a = action[:-1] 44 | 45 | stack = parserstate.stack 46 | buf = parserstate.buf 47 | 48 | if a[0] == SHIFT: 49 | parserstate.stack = [buf[0]] + stack 50 | parserstate.buf = buf[1:] 51 | elif a[0] == SWAP: 52 | parserstate.stack = [stack[0]] + stack[2:] 53 | parserstate.buf = [stack[1]] + buf 54 | elif a[0] == LEFTARC: 55 | parserstate.head[stack[1]] = [stack[0], rel] 56 | parserstate.stack = [stack[0]] + stack[2:] 57 | elif a[0] == RIGHTARC: 58 | parserstate.head[stack[0]] = [stack[1], rel] 59 | parserstate.stack = stack[1:] 60 | 61 | self._preparetransitionset(parserstate) 62 | 63 | def goldtransition(self, parserstate, goldrels=None): 64 | SHIFT = self.mappings['action']['Shift'] 65 | SWAP = self.mappings['action']['Swap'] 66 | LEFTARC = self.mappings['action']['Left-Arc'] 67 | RIGHTARC = self.mappings['action']['Right-Arc'] 68 | 69 | goldrels = goldrels or parserstate.goldrels 70 | stack = parserstate.stack 71 | buf = parserstate.buf 72 | head = parserstate.head 73 | proj_order = parserstate.proj_order 74 | POS = len(self.mappings['pos']) 75 | 76 | 77 | if len(stack) < 2 and len(buf) > 0: 78 | return (SHIFT,-1) 79 | 80 | stack0_done = True 81 | for x in buf: 82 | if x in goldrels[stack[0]]: 83 | stack0_done = False 84 | break 85 | for y in stack: 86 | if y in goldrels[stack[0]]: 87 | stack0_done=False 88 | break 89 | 90 | stack1_done=True 91 | for x in buf: 92 | if x in goldrels[stack[1]]: 93 | stack1_done=False 94 | break 95 | for y in stack: 96 | if y in goldrels[stack[1]]: 97 | stack1_done=False 98 | break 99 | 100 | if stack[1] in goldrels[stack[0]] and stack1_done: 101 | rel = goldrels[stack[0]][stack[1]] 102 | return (LEFTARC, rel) 103 | elif stack[0] in goldrels[stack[1]] and stack0_done: 104 | rel 
= goldrels[stack[1]][stack[0]] 105 | return (RIGHTARC, rel) 106 | else: 107 | if stack[1] < stack[0] and proj_order[stack[0]] < proj_order[stack[1]]: 108 | return (SWAP, -1) 109 | else: 110 | return (SHIFT, -1) 111 | 112 | 113 | def trans_to_str(self, t, state, pos, fpos=None): 114 | SHIFT = self.mappings['action']['Shift'] 115 | SWAP = self.mappings['action']['Swap'] 116 | LEFTARC = self.mappings['action']['Left-Arc'] 117 | RIGHTARC = self.mappings['action']['Right-Arc'] 118 | if t[0] == SHIFT: 119 | if fpos is None: 120 | return "Shift\t%s" % (pos[state.buf[0]]) 121 | else: 122 | return "Shift\t%s\t%s" % (pos[state.buf[0]], fpos[state.buf[0]]) 123 | elif t[0] == SWAP: 124 | return "Swap\t" 125 | elif t[0] == LEFTARC: 126 | return "Left-Arc\t%s" % (self.invmappings['rel'][t[1]]) 127 | elif t[0] == RIGHTARC: 128 | return "Right-Arc\t%s" % (self.invmappings['rel'][t[1]]) 129 | 130 | @classmethod 131 | def trans_from_line(self, line): 132 | if line[0] == 'Left-Arc': 133 | fields = { 'action':line[0], 'rel':line[1] } 134 | elif line[0] == 'Right-Arc': 135 | fields = { 'action':line[0], 'rel':line[1] } 136 | elif line[0] == 'Swap': 137 | fields = { 'action':line[0], 'pos':None } 138 | elif line[0] == 'Shift': 139 | fields = { 'action':line[0], 'pos':line[1] } 140 | if len(line) > 2: 141 | fields['fpos'] = line[2] 142 | else: 143 | raise ValueError(line[0]) 144 | return fields 145 | 146 | def tuple_trans_to_int(self, cand, t): 147 | SHIFT = self.mappings['action']['Shift'] 148 | SWAP = self.mappings['action']['Swap'] 149 | LEFTARC = self.mappings['action']['Left-Arc'] 150 | RIGHTARC = self.mappings['action']['Right-Arc'] 151 | 152 | RELS = len(self.mappings['rel']) 153 | 154 | base = 0 155 | if t[0] == SHIFT: 156 | return base 157 | 158 | base += 1 159 | 160 | if t[0] == SWAP: 161 | return base 162 | base += 1 163 | 164 | if t[0] == LEFTARC: 165 | return base + t[1] 166 | 167 | base += RELS 168 | 169 | if t[0] == RIGHTARC: 170 | return base + t[1] 171 | 172 | def 
tuple_trans_from_int(self, cand, action): 173 | SHIFT = self.mappings['action']['Shift'] 174 | SWAP = self.mappings['action']['Swap'] 175 | LEFTARC = self.mappings['action']['Left-Arc'] 176 | RIGHTARC = self.mappings['action']['Right-Arc'] 177 | RELS = len(self.mappings['rel']) 178 | rel = -1 179 | 180 | base = 0 181 | if action == base: 182 | a = (SHIFT,) 183 | base += 1 184 | 185 | if action == base: 186 | a = (SWAP,) 187 | 188 | base += 1 189 | 190 | if base <= action < base + RELS: 191 | a = (LEFTARC,) 192 | rel = action - base 193 | base += RELS 194 | 195 | if base <= action < base + RELS: 196 | a = (RIGHTARC,) 197 | rel = action - base 198 | 199 | return a, rel 200 | -------------------------------------------------------------------------------- /statetr/README.md: -------------------------------------------------------------------------------- 1 | # Graph-to-Graph Transformer for Transition-based Dependency Parsing (State Transformer) 2 | 3 | Pytorch implementation of the paper for the State Transformer model 4 | 5 | ## Dependencies: 6 | You should install the following packages for training/evaluating the model: 7 | - Python 3.6 8 | - [Pytorch](https://pytorch.org/) > 1.0.1 9 | - [Numpy](https://numpy.org/) 10 | - [transformers](https://github.com/huggingface/transformers) 11 | - [pytorch-pretrained-bert](https://github.com/huggingface/transformers/tree/v0.6.2) 12 | - [Tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard) 13 | - [Torchvision](https://pytorch.org/) 14 | 15 | Or the easiest way is to run the following command: 16 | ``` 17 | conda env create -n statetr -f statetr.yml 18 | ``` 19 | 20 | ## Preparing the data: 21 | Just follow the exact instruction that is described in "Sentence Transformer" repository. 
22 | 23 | ## Training: 24 | 25 | Here are all parameters that are needed to train your own model: 26 | 27 | ``` 28 | run.py [-h] [--withpunct] [--graphinput] [--poolinghid] [--unlabeled] 29 | [--freezedp] [--lowercase] [--usepos] [--Bertoptim] 30 | [--pretrained] [--withbert] [--bertname BERTNAME] 31 | [--bertpath BERTPATH] [--fhistmodel] [--fcompmodel] 32 | [--layernorm] [--multigpu] [--seppoint] [--mean_seq] 33 | [--language LANGUAGE] [--datapath DATAPATH] 34 | [--trainfile TRAINFILE] [--devfile DEVFILE] 35 | [--testfile TESTFILE] [--seqpath SEQPATH] 36 | [--outputname OUTPUTNAME] [--batchsize BATCHSIZE] 37 | [--nepochs NEPOCHS] [--real_epoch REAL_EPOCH] [--lr LR] 38 | [--shuffle] [--ffhidden FFHIDDEN] [--clipword CLIPWORD] 39 | [--nclass NCLASS] [--ffdropout FFDROPOUT] 40 | [--nlayershistory NLAYERSHISTORY] [--embsize EMBSIZE] 41 | [--maxsteplength MAXSTEPLENGTH] [--updatelr UPDATELR] 42 | [--hiddensizelabel HIDDENSIZELABEL] [--histsize HISTSIZE] 43 | [--labelemb LABELEMB] [--nattentionlayer NATTENTIONLAYER] 44 | [--nattentionheads NATTENTIONHEADS] 45 | [--warmupproportion WARMUPPROPORTION] [--modelpath MODELPATH] 46 | [--use_topbuffer] [--use_justexist] [--use_two_opts] 47 | [--lr_nonbert LR_NONBERT] [--mainpath MAINPATH] [--debug] 48 | 49 | optional arguments: 50 | -h, --help show this help message and exit 51 | --withpunct Use punctuation 52 | --graphinput Input graph to the model 53 | --poolinghid Max Pooling the last hidden layer instead of CLS 54 | --unlabeled Unlabeled dependency parsing 55 | --freezedp Freeze the dependency relation embeddings 56 | --lowercase Lowercase the words 57 | --usepos Use POS tagger 58 | --Bertoptim Use BertAdam for optimization 59 | --pretrained Start with a checkpoint 60 | --withbert Initialize the model with BERT 61 | --bertname BERTNAME Type of Pre-trained BERT 62 | --bertpath BERTPATH Type of Pre-trained BERT 63 | --fhistmodel Apply history model 64 | --fcompmodel Apply composition model 65 | --layernorm Layer normalization 
for graph input 66 | --multigpu Run the model on multiple GPUs 67 | --seppoint Use CLS for dependency classifiers or graph output 68 | mechanism 69 | --mean_seq Used for computing total number of steps 70 | --language LANGUAGE Language to train 71 | --datapath DATAPATH Data directory for train/test 72 | --trainfile TRAINFILE 73 | File to train the model 74 | --devfile DEVFILE File to validate the model 75 | --testfile TESTFILE File to test the model 76 | --seqpath SEQPATH File to test the model 77 | --outputname OUTPUTNAME 78 | Name of the output model 79 | --batchsize BATCHSIZE 80 | Batch size number 81 | --nepochs NEPOCHS Number of epochs 82 | --real_epoch REAL_EPOCH 83 | Number of epochs that is reduced from total epochs 84 | (checkpoint) 85 | --lr LR Learning rate for training 86 | --shuffle Shuffle training inputs 87 | --ffhidden FFHIDDEN Size of hidden layer in classifier 88 | --clipword CLIPWORD Percentage of keeping the orginal words of dataset 89 | --nclass NCLASS Number of classes in classifier 90 | --ffdropout FFDROPOUT 91 | Amount of drop-out in classifier 92 | --nlayershistory NLAYERSHISTORY 93 | Number of layers in LSTM history model 94 | --embsize EMBSIZE Dimension of Embeddings 95 | --maxsteplength MAXSTEPLENGTH 96 | Maximum size of steps to de done on validation/test 97 | time 98 | --updatelr UPDATELR Step to update the learning rate 99 | --hiddensizelabel HIDDENSIZELABEL 100 | Size of hidden layer in label classifier 101 | --histsize HISTSIZE Size of embedding in history model 102 | --labelemb LABELEMB Size of label embeddings 103 | --nattentionlayer NATTENTIONLAYER 104 | Number of layers in self-attention model 105 | --nattentionheads NATTENTIONHEADS 106 | Number of attention heads in self-attention model 107 | --warmupproportion WARMUPPROPORTION 108 | Proportion of warm-up for BertAdam optimizer 109 | --modelpath MODELPATH 110 | Name of the pretrained model 111 | --use_topbuffer Use also top element of Buffer 112 | --use_justexist Use top buffer 
just for exist classifier 113 | --use_two_opts Use two optimizers for training 114 | --lr_nonbert LR_NONBERT 115 | Learning rate for non-bert 116 | --mainpath MAINPATH File to pre-trained char embeddings 117 | --debug Debug phase 118 | 119 | ``` 120 | 121 | To reproduce results of the paper, you can run the model with ```statetr.bash``` for the baseline model, and ```statetr_g2g.bash``` 122 | for the integrated one. For UD results, change the number of epochs to 20. 123 | 124 | ## Evaluation: 125 | 126 | To evaluate a trained model, add the location of the saved model, the input file, and output path to ```predict.bash``` or ```predict_wsj.bash``` file, 127 | then it computes the output CoNLL file, and LAS (UAS) scores. 128 | Here are the input requirements for evaluation: 129 | 130 | ``` 131 | test.py [-h] [--datapath DATAPATH] [--testfile TESTFILE] 132 | [--model_name MODEL_NAME] [--batchsize BATCHSIZE] 133 | [--mainpath MAINPATH] 134 | 135 | optional arguments: 136 | -h, --help show this help message and exit 137 | --datapath DATAPATH Data directory for train/test 138 | --testfile TESTFILE File to test the model 139 | --model_name MODEL_NAME 140 | Model directory 141 | --batchsize BATCHSIZE 142 | Batch size number 143 | --mainpath MAINPATH File to test the model 144 | 145 | ``` 146 | 147 | ```eval.pl``` and ```ud_eval.py``` files are used as the official evaluation scripts for English Penn Treebank, and UD Treebanks. 148 | 149 | ## Error Analysis: 150 | 151 | To replicate Figure 3 and Table 3 of the paper, you can download [MaltEval tool](https://cl.lingfil.uu.se/~nivre/research/MaltEval.html), and use 152 | the output predictions of ```predict_wsj.bash``` or ```predict.bash``` file, and gold dependencies to re-create plots. It's so easy, just one command! 
153 | 154 | -------------------------------------------------------------------------------- /senttr/README.md: -------------------------------------------------------------------------------- 1 | # Graph-to-Graph Transformer for Transition-based Dependency Parsing (Sentence Transformer) 2 | 3 | Pytorch implementation of the paper for the Sentence Transformer model 4 | 5 | ## Dependencies : 6 | You should install the following packages for training/testing the model: 7 | - Python 3.7 8 | - [Pytorch](https://pytorch.org/) >= 1.4.0 9 | - [Numpy](https://numpy.org/) 10 | - [transformers](https://github.com/huggingface/transformers) 11 | - [Tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard) 12 | - [Torchvision](https://pytorch.org/) 13 | 14 | Or the easiest way is to run the following command: 15 | ``` 16 | conda env create -n senttr -f environment.yml 17 | ``` 18 | 19 | ## Preparing the data: 20 | For each dataset, we should do some pre-processing steps to build the proper input. 21 | ### WSJ Penn Treebank: 22 | Download the data from [here](https://catalog.ldc.upenn.edu/LDC99T42). 23 | Now, convert constituency format to Stanford dependency style by following 24 | [this repository](https://github.com/hankcs/TreebankPreprocessing). 25 | Now, you can build the gold oracle for training data as follows (it's based on [arc-swift](https://github.com/qipeng/arc-swift) repo): 26 | 27 | ``` 28 | cd preprocess/utils 29 | ./create_mappings.sh ../data/train.conll > mappings-ptb.txt 30 | cd preprocess/src 31 | python gen_oracle_seq.py ../data/train.conll train.seq --transsys ASd --mappings ./utils/mappings-ptb.txt 32 | ``` 33 | To include the `SWAP` operation, you should update `transition.py` and `parserstate.py` files of the arc-swift repository with our `transition.py` and `parserstate.py` files (for `transition.py` file, update the `ArcStandard` class). 
34 | 35 | Finally, you should replace the gold PoS tags with the predicted ones from [Stanford PoS tagger](https://nlp.stanford.edu/software/tagger.shtml). 36 | You can use [this repository](https://github.com/shuoyangd/hoolock) to do this replacement. 37 | 38 | ### UD Treebanks: 39 | 40 | Download the data from [here](http://hdl.handle.net/11234/1-2895). 41 | Since our models work with CoNLL-X format, you should convert the dataset from CoNLL-U format to CoNLL-X format with [this tool](https://github.com/alirezamshi/G2GTr/blob/master/senttr/conllu_to_conllx_no_underline.pl). Then, you can generate the oracles with the modified version of arc-swift, as mentioned in the section above. 42 | ## Training : 43 | 44 | To train the Sentence Transformer model, and its combination with Graph2Graph Transformer, you can check the following details: 45 | 46 | ``` 47 | run.py train [-h] [--buckets BUCKETS] [--epochs EPOCHS] [--punct] 48 | [--ftrain FTRAIN] [--ftrain_seq FTRAIN_SEQ] [--fdev FDEV] 49 | [--ftest FTEST] [--warmupproportion WARMUPPROPORTION] 50 | [--lowercase] [--lower_for_nonbert] 51 | [--modelname MODELNAME] [--lr LR] [--lr2 LR2] 52 | [--input_graph] [--layernorm_key] [--layernorm_value] 53 | [--use_two_opts] [--mlp_dropout MLP_DROPOUT] 54 | [--weight_decay WEIGHT_DECAY] 55 | [--max_grad_norm MAX_GRAD_NORM] [--max_seq MAX_SEQ] 56 | [--n_attention_layer N_ATTENTION_LAYER] [--checkpoint] 57 | [--act_thr ACT_THR] [--bert_path BERT_PATH] 58 | [--main_path MAIN_PATH] [--conf CONF] [--model MODEL] 59 | [--vocab VOCAB] [--device DEVICE] [--seed SEED] 60 | [--threads THREADS] [--batch_size BATCH_SIZE] 61 | 62 | optional arguments: 63 | -h, --help show this help message and exit 64 | --buckets BUCKETS Max number of buckets to use 65 | --epochs EPOCHS Number of training epochs 66 | --punct Whether to include punctuation 67 | --ftrain FTRAIN Path to train data 68 | --ftrain_seq FTRAIN_SEQ 69 | Path to train oracle file 70 | --fdev FDEV Path to dev file 71 | --ftest FTEST Path to test file 72 | 
--warmupproportion WARMUPPROPORTION, -w WARMUPPROPORTION 73 | Warm up proportion for BertAdam optimizer 74 | --lowercase Whether to do lowercase in tokenisation step 75 | --lower_for_nonbert Divide warm-up proportion of optimiser for randomly 76 | initialised parameters 77 | --modelname MODELNAME 78 | Path to saved checkpoint 79 | --lr LR Learning rate for optimizer (for BERT parameters if 80 | two optimisers used) 81 | --lr2 LR2 Learning rate for non-BERT parameters (two optimisers) 82 | --input_graph Input dependency graph to attention mechanism 83 | --layernorm_key layer normalization for Key (graph input) 84 | --layernorm_value layer normalization for Value (graph input) 85 | --use_two_opts Use one optimizer for Bert and one for others 86 | --mlp_dropout MLP_DROPOUT 87 | MLP drop out 88 | --weight_decay WEIGHT_DECAY 89 | Weight Decay 90 | --max_grad_norm MAX_GRAD_NORM 91 | Clip gradient 92 | --max_seq MAX_SEQ Maximum number of actions per sentence 93 | --n_attention_layer N_ATTENTION_LAYER 94 | Number of Attention Layers 95 | --checkpoint Start from a checkpoint 96 | --act_thr ACT_THR Maximum number of actions per sentence (training data) 97 | --bert_path BERT_PATH 98 | path to BERT 99 | --main_path MAIN_PATH 100 | path to main directory 101 | --conf CONF, -c CONF path to config file 102 | --model MODEL, -m MODEL 103 | path to model file 104 | --vocab VOCAB, -v VOCAB 105 | path to vocab file 106 | --device DEVICE, -d DEVICE 107 | ID of GPU to use 108 | --seed SEED, -s SEED seed for generating random numbers 109 | --threads THREADS, -t THREADS 110 | max num of threads 111 | --batch_size BATCH_SIZE 112 | max num of buckets to use 113 | ``` 114 | 115 | To replicate our results, you can run ```senttr.bash``` for the baseline, and ```senttr_g2g.bash``` for the integrated model. 
116 | 117 | ## Evaluation: 118 | 119 | To evaluate the model, you can check the following input requirements: 120 | 121 | ``` 122 | run.py predict [-h] [--fdata FDATA] [--fpred FPRED] 123 | [--modelname MODELNAME] [--mainpath MAINPATH] 124 | [--conf CONF] [--model MODEL] [--vocab VOCAB] 125 | [--device DEVICE] [--seed SEED] [--threads THREADS] 126 | [--batch_size BATCH_SIZE] 127 | 128 | optional arguments: 129 | -h, --help show this help message and exit 130 | --fdata FDATA Path to test dataset 131 | --fpred FPRED Prediction path 132 | --modelname MODELNAME 133 | Path to trained model 134 | --mainpath MAINPATH Main path 135 | --conf CONF, -c CONF path to config file 136 | --model MODEL, -m MODEL 137 | path to model file 138 | --vocab VOCAB, -v VOCAB 139 | path to vocab file 140 | --device DEVICE, -d DEVICE 141 | ID of GPU to use 142 | --seed SEED, -s SEED seed for generating random numbers 143 | --threads THREADS, -t THREADS 144 | max num of threads 145 | --batch_size BATCH_SIZE 146 | max num of buckets to use 147 | ``` 148 | 149 | To predict and evaluate your trained model, fill requirements (data path, prediction path, model path) in the ```predict.bash``` file, then it 150 | produces the predicted output CoNLL file, LAS, and UAS scores. 
151 | -------------------------------------------------------------------------------- /senttr/parser/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from parser.metric import Metric 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import numpy 9 | from tqdm import tqdm 10 | import torch.nn.functional as F 11 | 12 | 13 | ## this class keeps parser state information 14 | class State(object): 15 | def __init__(self, mask,device,bert_label=None,input_graph=False): 16 | self.tok_buffer = mask.nonzero().squeeze(1) 17 | self.tok_stack = torch.zeros(len(self.tok_buffer)+1).long().to(device) 18 | self.tok_stack[0] = 1 19 | self.buf = [i+1 for i in range(len(self.tok_buffer))] 20 | self.stack = [0] 21 | self.head = [[-1, -1] for _ in range(len(self.tok_buffer)+1)] 22 | self.dict = {0:"LEFTARC", 1:"RIGHTARC" ,2:"SHIFT", 3:"SWAP"} 23 | self.graph,self.label,self.convert = self.build_graph(mask,device,bert_label) 24 | self.input_graph=input_graph 25 | 26 | # build partially constructed graph 27 | def build_graph(self,mask,device,bert_label): 28 | graph = torch.zeros((len(mask),len(mask))).long().to(device) 29 | label = torch.ones(len(mask) * bert_label).long().to(device) 30 | offset = self.tok_buffer.clone() 31 | convert = {0:1} 32 | convert.update({i+1:off.item() for i,off in enumerate(offset)}) 33 | convert.update({len(convert):len(mask)}) 34 | for i in range(len(offset)-1): 35 | graph[offset[i],offset[i]+1:offset[i+1]] = 1 36 | graph[offset[i]+1:offset[i+1],offset[i]] = 2 37 | label[offset] = 0 38 | label[:2] = 0 39 | del offset 40 | return graph,label,convert 41 | 42 | def get_graph(self): 43 | return self.graph,self.label 44 | 45 | #required features for graph output mechanism (exist classifier) 46 | def feature(self): 47 | return torch.cat((self.tok_stack[1].unsqueeze(0),self.tok_stack[0].unsqueeze(0) 48 | ,self.tok_buffer[0].unsqueeze(0))) 49 | 50 | # required 
features for graph output mechanism (relation classifier) 51 | def feature_label(self): 52 | return torch.cat((self.tok_stack[1].unsqueeze(0),self.tok_stack[0].unsqueeze(0))) 53 | 54 | # update state 55 | def update(self,act,rel=None): 56 | act = self.dict[act.item()] 57 | if not self.finished(): 58 | if act == "SHIFT": 59 | self.stack = [self.buf[0]] + self.stack 60 | self.buf = self.buf[1:] 61 | self.tok_buffer = torch.roll(self.tok_buffer,-1,dims=0).clone() 62 | self.tok_stack = torch.roll(self.tok_stack,1,dims=0).clone() 63 | self.tok_stack[0] = self.tok_buffer[-1].clone() 64 | elif act == "LEFTARC": 65 | self.head[self.stack[1]] = [self.stack[0], rel.item()] 66 | if self.input_graph: 67 | self.graph[self.convert[self.stack[0]],self.convert[self.stack[1]]] = 1 68 | self.graph[self.convert[self.stack[1]],self.convert[self.stack[0]]] = 2 69 | self.label[self.convert[self.stack[1]]] = rel 70 | self.stack = [self.stack[0]] + self.stack[2:] 71 | self.tok_stack = torch.cat( 72 | (self.tok_stack[0].unsqueeze(0),torch.roll(self.tok_stack[1:],-1,dims=0))).clone() 73 | elif act == "RIGHTARC": 74 | self.head[self.stack[0]] = [self.stack[1], rel.item()] 75 | if self.input_graph: 76 | self.graph[self.convert[self.stack[1]],self.convert[self.stack[0]]] = 1 77 | self.graph[self.convert[self.stack[0]],self.convert[self.stack[1]]] = 2 78 | self.label[self.convert[self.stack[0]]] = rel 79 | self.stack = self.stack[1:] 80 | self.tok_stack = torch.roll(self.tok_stack,-1,dims=0).clone() 81 | elif act == "SWAP": 82 | self.buf = [self.stack[1]] + self.buf 83 | self.stack = [self.stack[0]] + self.stack[2:] 84 | self.tok_stack = torch.cat( 85 | (self.tok_stack[0].unsqueeze(0), torch.roll(self.tok_stack[1:], -1, dims=0))).clone() 86 | self.tok_buffer = torch.roll(self.tok_buffer, 1, dims=0).clone() 87 | self.tok_buffer[0] = self.tok_stack[-1] 88 | 89 | # legal actions at evaluation time 90 | def legal_act(self): 91 | t = [0,0,0,0] 92 | if len(self.stack) >= 2 and self.stack[1] != 0: 93 
| t[0] = 1 94 | if len(self.stack) >= 2 and self.stack[0] != 0: 95 | t[1] = 1 96 | if len(self.buf) > 0: 97 | t[2] = 1 98 | if len(self.stack) >= 2 and 0 < self.stack[1] < self.stack[0]: 99 | t[3] = 1 100 | return t 101 | 102 | # check whether the dependency tree is completed or not. 103 | def finished(self): 104 | return len(self.stack) == 1 and len(self.buf) == 0 105 | 106 | def __repr__(self): 107 | return "State:\nConvert:{}\n Graph:{}\n,Label:{}\nHead:{}\n".\ 108 | format(self.convert,self.graph,self.label,self.head) 109 | 110 | class Model(object): 111 | 112 | def __init__(self, vocab, parser, config, num_labels): 113 | super(Model, self).__init__() 114 | self.vocab = vocab 115 | self.parser = parser 116 | self.num_labels = num_labels 117 | self.config = config 118 | self.criterion = nn.CrossEntropyLoss() 119 | 120 | def train(self, loader): 121 | self.parser.train() 122 | pbar = tqdm(total= len(loader)) 123 | 124 | for ccc,(words, tags, masks, actions, mask_actions, rels) in enumerate(loader): 125 | 126 | states = [State(mask,tags.device,self.vocab.bert_index,self.config.input_graph) 127 | for mask in masks] 128 | s_arc,s_rel = self.parser(words, tags, masks, states, actions, rels) 129 | 130 | 131 | if self.config.use_two_opts: 132 | self.optimizer_nonbert.zero_grad() 133 | self.optimizer_bert.zero_grad() 134 | else: 135 | self.optimizer.zero_grad() 136 | 137 | ## leftarc and rightarc have dependencies, so we filter swap and shift 138 | mask_rels = ((actions != 3).long() * (actions != 2).long() * mask_actions.long()).bool() 139 | 140 | actions = actions[mask_actions] 141 | s_arc = s_arc[mask_actions] 142 | 143 | rels = rels[mask_rels] 144 | s_rel = s_rel[mask_rels] 145 | 146 | loss = self.get_loss(s_arc,actions,s_rel,rels) 147 | loss.backward() 148 | ## optimization step 149 | if self.config.use_two_opts: 150 | self.optimizer_nonbert.step() 151 | self.optimizer_bert.step() 152 | self.scheduler_nonbert.step() 153 | self.scheduler_bert.step() 154 | else: 155 | 
self.optimizer.step() 156 | self.scheduler.step() 157 | del states,words,tags,masks,mask_actions,actions,rels,s_rel,s_arc,mask_rels 158 | 159 | pbar.update(1) 160 | 161 | @torch.no_grad() 162 | def evaluate(self, loader, punct=False): 163 | self.parser.eval() 164 | metric = Metric() 165 | pbar = tqdm(total=len(loader)) 166 | 167 | for words, tags, masks,heads,rels,mask_heads in loader: 168 | states = [State(mask, tags.device,self.vocab.bert_index,self.config.input_graph) 169 | for mask in masks] 170 | 171 | states = self.parser(words, tags, masks,states) 172 | 173 | pred_heads = [] 174 | pred_rels = [] 175 | for state in states: 176 | pred_heads.append([h[0] for h in state.head][1:]) 177 | pred_rels.append([h[1] for h in state.head][1:]) 178 | pred_heads = [item for sublist in pred_heads for item in sublist] 179 | pred_rels = [item for sublist in pred_rels for item in sublist] 180 | 181 | pred_heads = torch.tensor(pred_heads).to(heads.device) 182 | pred_rels = torch.tensor(pred_rels).to(heads.device) 183 | 184 | heads = heads[mask_heads] 185 | rels = rels[mask_heads] 186 | pbar.update(1) 187 | metric(pred_heads, pred_rels, heads, rels) 188 | del states 189 | 190 | return metric 191 | 192 | @torch.no_grad() 193 | def predict(self, loader): 194 | self.parser.eval() 195 | metric = Metric() 196 | pbar = tqdm(total=len(loader)) 197 | all_arcs, all_rels = [], [] 198 | for words, tags, masks,heads,rels,mask_heads in loader: 199 | states = [State(mask, tags.device, self.vocab.bert_index,self.config.input_graph) 200 | for mask in masks] 201 | states = self.parser(words, tags, masks, states) 202 | 203 | pred_heads = [] 204 | pred_rels = [] 205 | for state in states: 206 | pred_heads.append([h[0] for h in state.head][1:]) 207 | pred_rels.append([h[1] for h in state.head][1:]) 208 | 209 | pred_heads = [item for sublist in pred_heads for item in sublist] 210 | pred_rels = [item for sublist in pred_rels for item in sublist] 211 | 212 | pred_heads = 
torch.tensor(pred_heads).to(heads.device) 213 | pred_rels = torch.tensor(pred_rels).to(heads.device) 214 | 215 | heads = heads[mask_heads] 216 | rels = rels[mask_heads] 217 | 218 | metric(pred_heads, pred_rels, heads, rels) 219 | 220 | lens = masks.sum(1).tolist() 221 | 222 | all_arcs.extend(torch.split(pred_heads, lens)) 223 | all_rels.extend(torch.split(pred_rels, lens)) 224 | pbar.update(1) 225 | all_arcs = [seq.tolist() for seq in all_arcs] 226 | all_rels = [self.vocab.id2rel(seq) for seq in all_rels] 227 | 228 | return all_arcs, all_rels, metric 229 | 230 | def get_loss(self, s_arc, actions, s_rel, rels): 231 | arc_loss = self.criterion(s_arc, actions) 232 | rel_loss = self.criterion(s_rel, rels) 233 | loss = arc_loss + rel_loss 234 | return loss -------------------------------------------------------------------------------- /statetr/file_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import (absolute_import, division, print_function, unicode_literals) 3 | 4 | import sys 5 | import json 6 | import logging 7 | import os 8 | import shutil 9 | import tempfile 10 | import fnmatch 11 | from functools import wraps 12 | from hashlib import sha256 13 | import sys 14 | from io import open 15 | import boto3 16 | import requests 17 | from botocore.exceptions import ClientError 18 | from tqdm import tqdm 19 | 20 | try: 21 | from torch.hub import _get_torch_home 22 | torch_cache_home = _get_torch_home() 23 | except ImportError: 24 | torch_cache_home = os.path.expanduser( 25 | os.getenv('TORCH_HOME', os.path.join( 26 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 27 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 28 | 29 | try: 30 | from urllib.parse import urlparse 31 | except ImportError: 32 | from urlparse import urlparse 33 | 34 | try: 35 | from pathlib import Path 36 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 37 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 
default_cache_path)) 38 | except (AttributeError, ImportError): 39 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 40 | default_cache_path) 41 | 42 | CONFIG_NAME = "config.json" 43 | WEIGHTS_NAME = "pytorch_model.bin" 44 | 45 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 46 | 47 | 48 | def url_to_filename(url, etag=None): 49 | """ 50 | Convert `url` into a hashed filename in a repeatable way. 51 | If `etag` is specified, append its hash to the url's, delimited 52 | by a period. 53 | """ 54 | url_bytes = url.encode('utf-8') 55 | url_hash = sha256(url_bytes) 56 | filename = url_hash.hexdigest() 57 | 58 | if etag: 59 | etag_bytes = etag.encode('utf-8') 60 | etag_hash = sha256(etag_bytes) 61 | filename += '.' + etag_hash.hexdigest() 62 | 63 | return filename 64 | 65 | 66 | def filename_to_url(filename, cache_dir=None): 67 | """ 68 | Return the url and etag (which may be ``None``) stored for `filename`. 69 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 70 | """ 71 | if cache_dir is None: 72 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 73 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 74 | cache_dir = str(cache_dir) 75 | 76 | cache_path = os.path.join(cache_dir, filename) 77 | if not os.path.exists(cache_path): 78 | raise EnvironmentError("file {} not found".format(cache_path)) 79 | 80 | meta_path = cache_path + '.json' 81 | if not os.path.exists(meta_path): 82 | raise EnvironmentError("file {} not found".format(meta_path)) 83 | 84 | with open(meta_path, encoding="utf-8") as meta_file: 85 | metadata = json.load(meta_file) 86 | url = metadata['url'] 87 | etag = metadata['etag'] 88 | 89 | return url, etag 90 | 91 | 92 | def cached_path(url_or_filename, cache_dir=None): 93 | """ 94 | Given something that might be a URL (or might be a local path), 95 | determine which. If it's a URL, download the file and cache it, and 96 | return the path to the cached file. 
If it's already a local path, 97 | make sure the file exists and then return the path. 98 | """ 99 | if cache_dir is None: 100 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 101 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 102 | url_or_filename = str(url_or_filename) 103 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 104 | cache_dir = str(cache_dir) 105 | 106 | parsed = urlparse(url_or_filename) 107 | 108 | if parsed.scheme in ('http', 'https', 's3'): 109 | # URL, so get it from the cache (downloading if necessary) 110 | return get_from_cache(url_or_filename, cache_dir) 111 | elif os.path.exists(url_or_filename): 112 | # File, and it exists. 113 | return url_or_filename 114 | elif parsed.scheme == '': 115 | # File, but it doesn't exist. 116 | raise EnvironmentError("file {} not found".format(url_or_filename)) 117 | else: 118 | # Something unknown 119 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 120 | 121 | 122 | def split_s3_path(url): 123 | """Split a full s3 path into the bucket name and path.""" 124 | parsed = urlparse(url) 125 | if not parsed.netloc or not parsed.path: 126 | raise ValueError("bad s3 path {}".format(url)) 127 | bucket_name = parsed.netloc 128 | s3_path = parsed.path 129 | # Remove '/' at beginning of path. 130 | if s3_path.startswith("/"): 131 | s3_path = s3_path[1:] 132 | return bucket_name, s3_path 133 | 134 | 135 | def s3_request(func): 136 | """ 137 | Wrapper function for s3 requests in order to create more helpful error 138 | messages. 
139 | """ 140 | 141 | @wraps(func) 142 | def wrapper(url, *args, **kwargs): 143 | try: 144 | return func(url, *args, **kwargs) 145 | except ClientError as exc: 146 | if int(exc.response["Error"]["Code"]) == 404: 147 | raise EnvironmentError("file {} not found".format(url)) 148 | else: 149 | raise 150 | 151 | return wrapper 152 | 153 | 154 | @s3_request 155 | def s3_etag(url): 156 | """Check ETag on S3 object.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_object = s3_resource.Object(bucket_name, s3_path) 160 | return s3_object.e_tag 161 | 162 | 163 | @s3_request 164 | def s3_get(url, temp_file): 165 | """Pull a file directly from S3.""" 166 | s3_resource = boto3.resource("s3") 167 | bucket_name, s3_path = split_s3_path(url) 168 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 169 | 170 | 171 | def http_get(url, temp_file): 172 | req = requests.get(url, stream=True) 173 | content_length = req.headers.get('Content-Length') 174 | total = int(content_length) if content_length is not None else None 175 | progress = tqdm(unit="B", total=total) 176 | for chunk in req.iter_content(chunk_size=1024): 177 | if chunk: # filter out keep-alive new chunks 178 | progress.update(len(chunk)) 179 | temp_file.write(chunk) 180 | progress.close() 181 | 182 | 183 | def get_from_cache(url, cache_dir=None): 184 | """ 185 | Given a URL, look for the corresponding dataset in the local cache. 186 | If it's not there, download it. Then return the path to the cached file. 187 | """ 188 | if cache_dir is None: 189 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 190 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 191 | cache_dir = str(cache_dir) 192 | 193 | if not os.path.exists(cache_dir): 194 | os.makedirs(cache_dir) 195 | 196 | # Get eTag to add to filename, if it exists. 
197 | if url.startswith("s3://"): 198 | etag = s3_etag(url) 199 | else: 200 | try: 201 | response = requests.head(url, allow_redirects=True) 202 | if response.status_code != 200: 203 | etag = None 204 | else: 205 | etag = response.headers.get("ETag") 206 | except EnvironmentError: 207 | etag = None 208 | 209 | if sys.version_info[0] == 2 and etag is not None: 210 | etag = etag.decode('utf-8') 211 | filename = url_to_filename(url, etag) 212 | 213 | # get cache path to put the file 214 | cache_path = os.path.join(cache_dir, filename) 215 | 216 | # If we don't have a connection (etag is None) and can't identify the file 217 | # try to get the last downloaded one 218 | if not os.path.exists(cache_path) and etag is None: 219 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 220 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 221 | if matching_files: 222 | cache_path = os.path.join(cache_dir, matching_files[-1]) 223 | 224 | if not os.path.exists(cache_path): 225 | # Download to temporary file, then copy to cache dir once finished. 226 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
227 | with tempfile.NamedTemporaryFile() as temp_file: 228 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 229 | 230 | # GET file object 231 | if url.startswith("s3://"): 232 | s3_get(url, temp_file) 233 | else: 234 | http_get(url, temp_file) 235 | 236 | # we are copying the file before closing it, so flush to avoid truncation 237 | temp_file.flush() 238 | # shutil.copyfileobj() starts at the current position, so go to the start 239 | temp_file.seek(0) 240 | 241 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 242 | with open(cache_path, 'wb') as cache_file: 243 | shutil.copyfileobj(temp_file, cache_file) 244 | 245 | logger.info("creating metadata file for %s", cache_path) 246 | meta = {'url': url, 'etag': etag} 247 | meta_path = cache_path + '.json' 248 | with open(meta_path, 'w') as meta_file: 249 | output_string = json.dumps(meta) 250 | if sys.version_info[0] == 2 and isinstance(output_string, str): 251 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 252 | meta_file.write(output_string) 253 | 254 | logger.info("removing temp file %s", temp_file.name) 255 | 256 | return cache_path 257 | 258 | 259 | def read_set_from_file(filename): 260 | ''' 261 | Extract a de-duped collection (set) of text from a file. 262 | Expected file format is one item per line. 
263 | ''' 264 | collection = set() 265 | with open(filename, 'r', encoding='utf-8') as file_: 266 | for line in file_: 267 | collection.add(line.rstrip()) 268 | return collection 269 | 270 | 271 | def get_file_extension(path, dot=True, lower=True): 272 | ext = os.path.splitext(path)[1] 273 | ext = ext if dot else ext[1:] 274 | return ext.lower() if lower else ext 275 | -------------------------------------------------------------------------------- /senttr/parser/utils/vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import Counter 4 | import os 5 | import regex 6 | import torch 7 | from transformers import * 8 | 9 | class Vocab(object): 10 | PAD = '[PAD]' 11 | UNK = '[UNK]' 12 | BERT = '[BERT]' 13 | 14 | def __init__(self, config, words, tags, rels): 15 | 16 | self.config = config 17 | self.max_pad_length = config.max_seq_length 18 | 19 | self.words = [self.PAD, self.UNK] + sorted(words) 20 | 21 | self.tags = sorted(tags) 22 | self.tags = [self.PAD, self.UNK] + [':'+tag for tag in self.tags] 23 | 24 | self.rels = sorted(rels) 25 | 26 | self.bert_index = 1 27 | if self.config.input_graph: 28 | self.rels = [self.PAD] + [self.BERT] + self.rels 29 | else: 30 | self.rels = [self.PAD] + self.rels 31 | 32 | ## left-arc:L,right-arc:R,shift:S,swap:H 33 | self.trans = ['L', 'R', 'S','H'] 34 | self.trans_dict = {tr:i for i,tr in enumerate(self.trans)} 35 | 36 | self.word_dict = {word: i for i,word in enumerate(self.words)} 37 | self.punct = [word for word, i in self.word_dict.items() if regex.match(r'\p{P}+$', word)] 38 | ### Let's load a model and tokenizer############################ 39 | 40 | self.bertmodel = BertModel.from_pretrained(config.bert_path) 41 | self.tokenizer = BertTokenizer.from_pretrained(config.bert_path) 42 | 43 | self.tokenizer.add_tokens(self.tags + ['']+ self.punct + ['[CLS]']+ ['[SEP]']) 44 | 45 | all_tokens = self.tags + [''] + self.punct + ['[CLS]'] + 
['[SEP]'] + self.words 46 | all_pics = [] 47 | for word in all_tokens: 48 | tokens = self.tokenizer.tokenize(word) 49 | for token in tokens: 50 | all_pics.append(token) 51 | self.word2bert = {} 52 | cou = 0 53 | for pic in all_pics: 54 | index = self.tokenizer.convert_tokens_to_ids(pic) 55 | if index not in self.word2bert: 56 | self.word2bert[index] = cou 57 | cou+= 1 58 | 59 | self.bertmodel.resize_token_embeddings(len(self.tokenizer)) 60 | self.bertmodel.train() 61 | vectors = self.bertmodel.embeddings.word_embeddings.weight 62 | new_vectors = torch.index_select(vectors,0,torch.tensor(list(self.word2bert.keys()))) 63 | self.bertmodel.resize_token_embeddings(len(self.word2bert)) 64 | self.bertmodel.embeddings.word_embeddings = self.bertmodel.\ 65 | embeddings.word_embeddings.from_pretrained(new_vectors) 66 | for index in self.word2bert: 67 | assert torch.all(torch.eq(self.bertmodel.embeddings.word_embeddings( 68 | torch.tensor(self.word2bert[index])),vectors[index])),\ 69 | "index-word2bert:{}".format(vectors[index]-self.bertmodel.embeddings. 
70 | word_embeddings(torch.tensor(self.word2bert[index]))) 71 | 72 | # Train our model 73 | self.bertmodel.train() 74 | 75 | if os.path.exists(config.main_path + "/model" + "/model_" + config.modelname) != True: 76 | os.mkdir(config.main_path + "/model" + "/model_" + config.modelname) 77 | 78 | ### Now let's save our model and tokenizer to a directory 79 | self.bertmodel.save_pretrained(config.main_path + "/model" + "/model_" + config.modelname) 80 | 81 | self.tag_dict = {tag: i for i,tag in enumerate(self.tags)} 82 | 83 | self.rel_dict = {rel: i for i, rel in enumerate(self.rels)} 84 | # ids of punctuation that appear in words 85 | 86 | self.puncts = [] 87 | for punct in self.punct: 88 | self.puncts.append(self.word2bert.get(self.tokenizer.convert_tokens_to_ids(punct))) 89 | 90 | self.pad_index = self.tokenizer.convert_tokens_to_ids(self.PAD) 91 | self.pad_index = self.word2bert[self.pad_index] 92 | self.unk_index = self.tokenizer.convert_tokens_to_ids(self.UNK) 93 | self.unk_index = self.word2bert[self.unk_index] 94 | 95 | self.cls_index = self.word2bert[self.tokenizer.convert_tokens_to_ids('[CLS]') ] 96 | self.sep_index = self.word2bert[self.tokenizer.convert_tokens_to_ids('[SEP]')] 97 | 98 | self.n_words = len(self.words) 99 | self.n_tags = len(self.tags) 100 | self.n_rels = len(self.rels) 101 | self.n_trans = len(self.trans) 102 | self.n_train_words = self.n_words 103 | self.unk_count = 0 104 | self.total_count = 0 105 | self.long_seq = 0 106 | 107 | def __repr__(self): 108 | info = f"{self.__class__.__name__}: " 109 | info += f"{self.n_words} words, " 110 | info += f"{self.n_tags} tags, " 111 | info += f"{self.n_rels} rels" 112 | 113 | return info 114 | 115 | ## prepare data for train set 116 | def map_arcs_bert_pred(self, corpus, seq_corpus): 117 | 118 | all_words = [] 119 | all_tags = [] 120 | all_masks = [] 121 | all_actions = [] 122 | all_masks_action = [] 123 | all_rels = [] 124 | 125 | for i, (words, tags, seq) in enumerate(zip(corpus.words,corpus.tags, 
seq_corpus)): 126 | 127 | old_to_new_node = {0: 0} 128 | tokens_org, tokens_length = self.word2id(words) 129 | tokens = [item for sublist in tokens_org for item in sublist] 130 | 131 | index = 0 132 | for token_id, token_length in enumerate(tokens_length): 133 | index += token_length 134 | old_to_new_node[token_id + 1] = index 135 | 136 | # CLS heads and tags 137 | new_tags = [] 138 | offsets = torch.tensor(list(old_to_new_node.values()))[:-1] + 1 139 | 140 | for token_id, token_length in enumerate(tokens_length): 141 | for sub_token in range(token_length): 142 | new_tags.append(tags[token_id]) 143 | 144 | words_id = torch.tensor([self.cls_index] + tokens + [self.sep_index]) 145 | 146 | # 100 is the id of [UNK] 147 | self.unk_count += len((words_id == 100).nonzero()) 148 | self.total_count += len(words_id) 149 | 150 | tags = torch.tensor([self.cls_index] + self.tag2id(new_tags) + [self.sep_index]) 151 | 152 | masks = torch.zeros(len(words_id)).long() 153 | masks[offsets[1:]] = 1 154 | 155 | ## ignore some long sentences to fit the training phase in memory 156 | if len(seq['act']) < self.config.act_thr: 157 | all_words.append(words_id) 158 | all_tags.append(tags) 159 | all_masks.append(masks.bool()) 160 | all_actions.append(torch.tensor(seq['act'])) 161 | all_masks_action.append(torch.ones_like(torch.tensor(seq['act'])).bool()) 162 | all_rels.append(torch.tensor(seq['rel'])) 163 | 164 | print("Percentage of unkown tokens:{}".format(self.unk_count * 1.0 / self.total_count * 100)) 165 | self.unk_count = 0 166 | self.total_count = 0 167 | 168 | return all_words, all_tags, all_masks, all_actions, all_masks_action, all_rels 169 | 170 | ## prepare data for test and evaluation 171 | def map_arcs_bert(self, corpus): 172 | all_words = [] 173 | all_tags = [] 174 | all_masks = [] 175 | all_heads = [] 176 | all_rels = [] 177 | all_masks_head = [] 178 | 179 | for i, (words, tags,heads,rels) in enumerate(zip(corpus.words, corpus.tags, corpus.heads, corpus.rels)): 180 | 181 | 
old_to_new_node = {0: 0} 182 | tokens_org, tokens_length = self.word2id(words) 183 | tokens = [item for sublist in tokens_org for item in sublist] 184 | 185 | index = 0 186 | for token_id, token_length in enumerate(tokens_length): 187 | index += token_length 188 | old_to_new_node[token_id + 1] = index 189 | 190 | # CLS heads and tags 191 | new_tags = [] 192 | offsets = torch.tensor(list(old_to_new_node.values()))[:-1] + 1 193 | 194 | for token_id, token_length in enumerate(tokens_length): 195 | for sub_token in range(token_length): 196 | new_tags.append(tags[token_id]) 197 | 198 | words_id = torch.tensor([self.cls_index] + tokens + [self.sep_index]) 199 | 200 | self.unk_count += len((words_id == self.unk_index).nonzero()) 201 | self.total_count += len(words_id) 202 | 203 | tags = torch.tensor([self.cls_index] + self.tag2id(new_tags) + [self.sep_index]) 204 | 205 | rels = torch.tensor([self.rel2id(rel) for rel in rels]) 206 | masks = torch.zeros(len(words_id)).long() 207 | masks[offsets[1:]] = 1 208 | 209 | heads = torch.tensor(heads[1:]) 210 | masks_head = torch.ones_like(heads) 211 | 212 | 213 | if len(masks) < 512: 214 | all_words.append(words_id) 215 | all_tags.append(tags) 216 | all_masks.append(masks.bool()) 217 | all_rels.append(rels[1:]) 218 | all_heads.append(heads) 219 | all_masks_head.append(masks_head.bool()) 220 | 221 | print("Percentage of unknown tokens:{}".format(self.unk_count * 1.0 / self.total_count * 100)) 222 | self.unk_count = 0 223 | self.total_count = 0 224 | 225 | return all_words, all_tags, all_masks, all_heads, all_rels,all_masks_head 226 | 227 | def word2id(self, sequence): 228 | WORD2ID = [] 229 | lengths = [] 230 | for word in sequence: 231 | x = self.tokenizer.tokenize(word) 232 | if len(x) == 0: 233 | x = ['[UNK]'] 234 | x = self.tokenizer.convert_tokens_to_ids(x) 235 | x = [self.word2bert.get(y,self.unk_index) for y in x] 236 | lengths.append(len(x)) 237 | WORD2ID.append(x) 238 | return WORD2ID,lengths 239 | 240 | def tag2id(self, 
sequence): 241 | 242 | tags = [] 243 | for tag in sequence: 244 | tokenized_tag = self.tokenizer.tokenize(':'+tag) 245 | if len(tokenized_tag) != 1: 246 | tags.append(self.unk_index) 247 | else: 248 | tags.append(self.word2bert.get(self.tokenizer.convert_tokens_to_ids( 249 | tokenized_tag)[0],self.unk_index)) 250 | return tags 251 | 252 | def rel2id(self, rel): 253 | return self.rel_dict.get(rel, 0) 254 | 255 | def id2rel(self, ids): 256 | return [self.rels[i] for i in ids] 257 | 258 | def extend(self, words): 259 | self.words.extend(sorted(set(words).difference(self.word_dict))) 260 | self.word_dict = {word: i for i, word in enumerate(self.words)} 261 | self.puncts = sorted(i for word, i in self.word_dict.items() 262 | if regex.match(r'\p{P}+$', word)) 263 | self.n_words = len(self.words) 264 | 265 | 266 | def numericalize(self, corpus, seq_corpus = None): 267 | 268 | if seq_corpus is None: 269 | return self.map_arcs_bert(corpus) 270 | else: 271 | return self.map_arcs_bert_pred(corpus,seq_corpus) 272 | 273 | @classmethod 274 | def from_corpus(cls, config, corpus, corpus_dev=None, corpus_test=None, min_freq=0): 275 | if corpus_dev is not None: 276 | all_words = corpus.words + corpus_dev.words + corpus_test.words 277 | else: 278 | all_words = corpus.words 279 | words = Counter(word for seq in all_words for word in seq) 280 | words = list(word for word, freq in words.items() if freq >= min_freq) 281 | tags = list({tag for seq in corpus.tags for tag in seq}) 282 | rels = list({rel for seq in corpus.rels for rel in seq}) 283 | vocab = cls(config, words, tags, rels) 284 | 285 | return vocab 286 | -------------------------------------------------------------------------------- /senttr/parser/utils/base.py: -------------------------------------------------------------------------------- 1 | ## the base code is from https://github.com/huggingface/transformers 2 | import logging 3 | import math 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn 
class BertBaseEmbeddings(nn.Module):
    """Construct the embeddings from word, tag, position and token_type embeddings.

    Compared with the stock BERT embedding layer, ``forward`` accepts an
    extra ``pos_ids`` tensor.  Those ids are looked up in the *same* table
    as the word ids (``word_embeddings``) and the two lookups are summed.
    """

    def __init__(self, config):
        super(BertBaseEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, pos_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed a batch of token ids (or pre-computed embeddings).

        Args:
            input_ids: token ids; used only when ``inputs_embeds`` is None.
            pos_ids: optional ids added on top of the word embeddings via the
                word embedding table; ignored when ``inputs_embeds`` is given.
            token_type_ids: segment ids; all zeros when omitted.
            position_ids: position ids; ``0 .. seq_length-1`` when omitted.
            inputs_embeds: pre-computed input embeddings used instead of
                looking up ``input_ids``.

        Returns:
            The summed embeddings after LayerNorm and dropout.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
            # pos_ids defaults to None; looking up None in an nn.Embedding
            # raises, so only add the tag embedding when it was supplied.
            if pos_ids is not None:
                inputs_embeds = inputs_embeds + self.word_embeddings(pos_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
def forward(
    self,
    input_ids=None,
    pos_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
):
    """ Forward pass on the Model.

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`; an
    `encoder_hidden_states` is expected as an input to the forward pass.

    Exactly one of `input_ids` / `inputs_embeds` must be given, otherwise a
    `ValueError` is raised.  `pos_ids`, `position_ids` and `token_type_ids`
    are forwarded unchanged to `self.embeddings`.  `attention_mask` may be
    2D or 3D (any other rank raises `ValueError`); when omitted, a mask of
    ones is used.

    Returns the tuple `(sequence_output, pooled_output) + encoder_outputs[1:]`.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762

    """
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        # drop the trailing (embedding) dimension
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    # Defaults: attend to every position, single segment.
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, device=device)
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder:
            batch_size, seq_length = input_shape
            seq_ids = torch.arange(seq_length, device=device)
            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
            causal_mask = causal_mask.to(
                torch.long
            )  # not converting to long will cause errors with pytorch version < 1.3
            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                input_shape, attention_mask.shape
            )
        )

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
                    encoder_hidden_shape, encoder_attention_mask.shape
                )
            )

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = (
                head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            )  # We can specify head_mask for each layer
        head_mask = head_mask.to(
            dtype=next(self.parameters()).dtype
        )  # switch to float if needed + fp16 compatibility
    else:
        # one (absent) mask per encoder layer
        head_mask = [None] * self.config.num_hidden_layers

    embedding_output = self.embeddings(
        input_ids=input_ids, pos_ids=pos_ids, position_ids=position_ids, token_type_ids=token_type_ids,
        inputs_embeds=inputs_embeds)

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output)

    outputs = (sequence_output, pooled_output,) + encoder_outputs[
        1:
    ]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
20 | ) 21 | subparser.add_argument('--buckets', default=64, type=int, 22 | help='Max number of buckets to use') 23 | 24 | subparser.add_argument('--epochs', default=12, type=int, 25 | help='Number of training epochs') 26 | 27 | subparser.add_argument('--punct', default=False, action='store_true', 28 | help='Whether to include punctuation') 29 | 30 | subparser.add_argument('--ftrain', default='data/train.conll', 31 | help='Path to train data') 32 | 33 | subparser.add_argument('--ftrain_seq', default='data/train.seq', 34 | help='Path to train oracle file') 35 | 36 | subparser.add_argument('--fdev', default='data/dev.conll', 37 | help='Path to dev file') 38 | 39 | subparser.add_argument('--ftest', default='data/test.conll', 40 | help='Path to test file') 41 | 42 | subparser.add_argument('--warmupproportion', '-w', default=0.01, type=float, 43 | help='Warm up proportion for BertAdam optimizer') 44 | 45 | subparser.add_argument('--lowercase', default=False, action='store_true', 46 | help='Whether to do lowercase in tokenisation step') 47 | 48 | subparser.add_argument('--lower_for_nonbert', default=False, action='store_true', 49 | help='Divide warm-up proportion of optimiser ' 50 | 'for randomly initialised parameters') 51 | 52 | subparser.add_argument('--modelname', default='None', 53 | help='Path to saved checkpoint') 54 | 55 | subparser.add_argument('--lr', default=1e-5, type=float, 56 | help='Learning rate for optimizer (for BERT parameters if two optimisers used)') 57 | 58 | subparser.add_argument('--lr2', default=2e-3, type=float, 59 | help='Learning rate for non-BERT parameters (two optimisers)') 60 | 61 | subparser.add_argument('--input_graph', default=False, action='store_true', 62 | help='Input dependency graph to attention mechanism') 63 | 64 | subparser.add_argument('--layernorm_key', default=False, action='store_true', 65 | help='layer normalization for Key (graph input)') 66 | 67 | subparser.add_argument('--layernorm_value', default=False, 
def __call__(self, config):
    """Train the parser, checkpoint after each epoch and score the best model on test.

    Steps: load the three corpora, build (or reload) the vocabulary, read
    the training oracle, create data loaders, build the parser and its
    optimizer(s), optionally restore a checkpoint, then run the epoch loop
    with early stopping on ``config.patience``.  The best dev model is
    written to ``<config.model><config.modelname>/model_weights`` and is
    reloaded at the end for the test evaluation.

    NOTE(review): model weights, the vocabulary and ``baseline.txt`` are
    written relative to the working directory, while checkpoints are
    read/written under ``config.main_path`` - confirm this is intended.
    """
    print("Preprocess the data")
    train = Corpus.load(config.ftrain)
    dev = Corpus.load(config.fdev)
    test = Corpus.load(config.ftest)

    model_dir = config.model + config.modelname
    vocab_dir = config.vocab + config.modelname
    checkpoint_dir = config.main_path + config.model + config.modelname
    log_path = model_dir + "/baseline.txt"

    # makedirs(exist_ok=True) also creates missing parents and does not
    # fail when the directory appears between the check and the creation.
    # "model/" is kept because the original always created it.
    for directory in (config.model, "model/", model_dir, vocab_dir, checkpoint_dir):
        os.makedirs(directory, exist_ok=True)

    if config.checkpoint:
        vocab = torch.load(config.main_path + config.vocab + config.modelname + "/vocab.tag")
    else:
        vocab = Vocab.from_corpus(config=config, corpus=train,
                                  corpus_dev=dev, corpus_test=test, min_freq=0)
    train_seq = read_seq(config.ftrain_seq, vocab)
    total_act = sum(len(x) for x in train_seq)
    print("number of transitions:{}".format(total_act))

    torch.save(vocab, vocab_dir + "/vocab.tag")

    config.update({
        'n_words': vocab.n_train_words,
        'n_tags': vocab.n_tags,
        'n_rels': vocab.n_rels,
        'n_trans': vocab.n_trans,
        'pad_index': vocab.pad_index,
        'unk_index': vocab.unk_index
    })

    print("Load the dataset")
    trainset = TextDataset(vocab.numericalize(train, train_seq))
    devset = TextDataset(vocab.numericalize(dev))
    testset = TextDataset(vocab.numericalize(test))

    # set the data loaders
    train_loader, _ = batchify(dataset=trainset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets,
                               shuffle=True)
    dev_loader, _ = batchify(dataset=devset,
                             batch_size=config.batch_size,
                             n_buckets=config.buckets)
    test_loader, _ = batchify(dataset=testset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)

    print(f"{'train:':6} {len(trainset):5} sentences in total, "
          f"{len(train_loader):3} batches provided")
    print(f"{'dev:':6} {len(devset):5} sentences in total, "
          f"{len(dev_loader):3} batches provided")
    print(f"{'test:':6} {len(testset):5} sentences in total, "
          f"{len(test_loader):3} batches provided")
    print("Create the model")

    if config.checkpoint:
        parser = Parser.load(checkpoint_dir + "/parser-checkpoint")
    else:
        parser = Parser(config, vocab.bertmodel)

    print("number of parameters:{}".format(sum(p.numel() for p in parser.parameters()
                                               if p.requires_grad)))
    if torch.cuda.is_available():
        print('Train/Evaluate on GPU')
        device = torch.device('cuda')
        parser = parser.to(device)

    model = Model(vocab, parser, config, vocab.n_rels)
    total_time = timedelta()
    best_e, best_metric = 1, Metric()

    ## prepare optimisers
    num_train_optimization_steps = int(config.epochs * len(train_loader))
    warmup_steps = int(config.warmupproportion * num_train_optimization_steps)
    # parameters whose name contains one of these get no weight decay
    no_decay = ['bias', 'LayerNorm.weight']

    def grouped(named_params):
        # split (name, param) pairs into a decayed and a non-decayed group
        named_params = list(named_params)
        return [
            {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
             'weight_decay': config.weight_decay},
            {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

    ## one for parsing parameters, one for BERT parameters
    if config.use_two_opts:
        model_nonbert = []
        model_bert = []
        layernorm_params = ['layernorm_key_layer', 'layernorm_value_layer',
                            'dp_relation_k', 'dp_relation_v']
        for name, param in parser.named_parameters():
            if 'bert' in name and not any(nd in name for nd in layernorm_params):
                model_bert.append((name, param))
            else:
                model_nonbert.append((name, param))

        # Prepare optimizer and schedule (linear warmup and decay) for Non-bert parameters
        model.optimizer_nonbert = AdamW(grouped(model_nonbert), lr=config.lr2)
        model.scheduler_nonbert = get_linear_schedule_with_warmup(
            model.optimizer_nonbert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        # Prepare optimizer and schedule (linear warmup and decay) for Bert parameters
        model.optimizer_bert = AdamW(grouped(model_bert), lr=config.lr)
        model.scheduler_bert = get_linear_schedule_with_warmup(
            model.optimizer_bert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)
    else:
        # Prepare optimizer and schedule (linear warmup and decay)
        model.optimizer = AdamW(grouped(parser.named_parameters()), lr=config.lr)
        model.scheduler = get_linear_schedule_with_warmup(
            model.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

    start_epoch = 1

    ## load model, optimiser, and other parameters from a checkpoint
    if config.checkpoint:
        check_load = torch.load(checkpoint_dir + "/checkpoint")
        if config.use_two_opts:
            model.optimizer_bert.load_state_dict(check_load['optimizer_bert'])
            model.optimizer_nonbert.load_state_dict(check_load['optimizer_nonbert'])
            model.scheduler_bert.load_state_dict(check_load['lr_schedule_bert'])
            model.scheduler_nonbert.load_state_dict(check_load['lr_schedule_nonbert'])
        else:
            model.optimizer.load_state_dict(check_load['optimizer'])
            model.scheduler.load_state_dict(check_load['lr_schedule'])
        start_epoch = check_load['epoch'] + 1
        best_e = check_load['best_e']
        best_metric = check_load['best_metric']

    with open(log_path, "a") as f1:
        f1.write("New Model:\n")

    # Number of epochs timed in *this* run; differs from `epoch` when
    # resuming from a checkpoint and is 0 when the loop body never runs.
    epochs_run = 0
    for epoch in range(start_epoch, config.epochs + 1):
        start = datetime.now()
        # train one epoch and update the parameters
        model.train(train_loader)
        print(f"Epoch {epoch} / {config.epochs}:")
        dev_metric = model.evaluate(dev_loader, config.punct)
        print(f"{'dev:':6} {dev_metric}")
        with open(log_path, "a") as f1:
            f1.write(str(epoch) + "\n")
            f1.write(f"{'dev:':6} {dev_metric}")
            f1.write("\n")

        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            print(model_dir + "/model_weights")
            model.parser.save(model_dir + "/model_weights")
            print(f"{t}s elapsed (saved)\n")
        else:
            print(f"{t}s elapsed\n")
        total_time += t
        epochs_run += 1
        if epoch - best_e >= config.patience:
            break

        ## save checkpoint
        checkpoint = {
            "epoch": epoch,
            'best_metric': best_metric,
            'best_e': best_e
        }
        if config.use_two_opts:
            checkpoint.update({
                "optimizer_bert": model.optimizer_bert.state_dict(),
                "lr_schedule_bert": model.scheduler_bert.state_dict(),
                "lr_schedule_nonbert": model.scheduler_nonbert.state_dict(),
                "optimizer_nonbert": model.optimizer_nonbert.state_dict(),
            })
        else:
            checkpoint.update({
                "optimizer": model.optimizer.state_dict(),
                "lr_schedule": model.scheduler.state_dict(),
            })
        torch.save(checkpoint, checkpoint_dir + "/checkpoint")
        parser.save(checkpoint_dir + "/parser-checkpoint")

    model.parser = Parser.load(model_dir + "/model_weights")
    metric = model.evaluate(test_loader, config.punct)
    print(metric)
    print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
    print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
    # Average over the epochs actually run here; avoids an unbound `epoch`
    # (and a skewed average) when resuming from a checkpoint.
    average_time = total_time / epochs_run if epochs_run else total_time
    print(f"average time of each epoch is {average_time}s")
    print(f"{total_time}s elapsed")

def __init__(self, dataset, opt, dataset_dev=None, dataset_test=None):
    """Build the transition, label, POS and word dictionaries.

    Args:
        dataset: training examples; each is a dict with the keys
            ``'word'``, ``'pos'``, ``'head'`` and ``'label'``.
        opt: options object; the attributes ``unlabeled``, ``withpunct``,
            ``usepos``, ``language``, ``clipword``, ``nlayershistory`` and
            ``embsize`` are read.
        dataset_dev: optional examples whose words are added to the
            word dictionary.
        dataset_test: optional examples whose words are added to the
            word dictionary.

    Raises:
        ValueError: if no token with ``head == 0`` exists in ``dataset``.
    """
    self.embedding_shape = 0
    root_labels = [
        l for ex in dataset for (h, l) in zip(ex['head'], ex['label'])
        if h == 0
    ]
    counter = Counter(root_labels)
    if not counter:
        raise ValueError('no root token (head == 0) found in dataset')
    if len(counter) > 1:
        logging.info('Warning: more than one root label')
        logging.info(counter)
    self.root_label = counter.most_common()[0][0]
    # Root label first, the rest sorted so that label ids are reproducible
    # across runs (set iteration order depends on string hashing).
    deprel = [self.root_label] + sorted({
        w for ex in dataset
        for w in ex['label'] if w != self.root_label
    })
    self.unlabeled = opt.unlabeled
    self.with_punct = opt.withpunct
    self.use_pos = opt.usepos
    self.language = opt.language

    if self.unlabeled:
        ## L:left-arc,R:right-arc,S:shift,H:swap
        transit = ['L', 'R', 'S', 'H']
        self.n_deprel = 1
    else:
        transit = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['UNK_LABEL']
        self.n_deprel = len(deprel)

    tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}

    self.n_transit = len(transit)
    # index of the last transition ('UNK_LABEL' in the labelled setting)
    self.L_NULL = len(transit) - 1
    self.tran2id = {t: i for (i, t) in enumerate(transit)}
    self.id2tran = {i: t for (i, t) in enumerate(transit)}

    logging.info('Build dictionary for part-of-speech tags.')
    tok2id.update(
        build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
                   offset=len(tok2id)))
    tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
    tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
    tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)

    logging.info('Build dictionary for words.')
    # dev/test default to None; treat a missing split as empty instead of
    # failing on `list + None`.
    dataset_total = list(dataset) + list(dataset_dev or []) + list(dataset_test or [])
    word_counts = Counter([w for ex in dataset_total for w in ex['word']])
    # keep only the `clipword` fraction of most frequent word types
    clip = int(opt.clipword * len(word_counts))
    final_words = [word for word, _ in word_counts.most_common(clip)]
    del word_counts
    tok2id.update(
        build_dict(final_words,
                   offset=len(tok2id)))

    tok2id[UNK] = self.UNK = len(tok2id)
    tok2id[NULL] = self.NULL = len(tok2id)
    tok2id[ROOT] = self.ROOT = len(tok2id)
    tok2id[CLS] = self.CLS = len(tok2id)
    tok2id[SEP] = self.SEP = len(tok2id)

    self.tok2id = tok2id

    self.id2tok = {v: k for (k, v) in tok2id.items()}

    self.n_tokens = len(tok2id)

    self.layers_lstm = opt.nlayershistory
    self.emb_size = opt.embsize
## transit = ['L', 'R', 'S','H']
def read_seq(in_file, parser, reduced, thr):
    """Read gold transition sequences from an oracle file.

    The file holds one transition per line, with a blank line closing each
    sentence.  Lines are dispatched on their number of whitespace-separated
    fields:

    * 3 fields - must start with ``Shift``  -> action ``2``
    * 1 field  - must be ``Swap``           -> action ``3``
    * 2 fields - ``Right-Arc <label>``      -> action ``1``,
                 ``Left-Arc <label>``       -> action ``0``

    For arcs the paired entry is ``parser.tran2id['R-' + label]`` or
    ``parser.tran2id['L-' + label]``; for Shift/Swap it is
    ``parser.L_NULL``.  Lines of any other shape are ignored.

    Args:
        in_file: path of the oracle file.
        parser: object providing ``L_NULL`` and ``tran2id``.
        reduced: when true, stop after ``thr`` sentences.
        thr: sentence limit used when ``reduced`` is true.

    Returns:
        A list with one ``(actions, arcs)`` pair of equal-length lists per
        sentence.
    """
    gold_seq, arcs, seq = [], [], []
    max_read = 0
    with open(in_file, 'r') as f:
        for raw in f:
            if reduced and max_read == thr:
                break
            line = raw.strip().split()
            if len(line) == 0:
                gold_seq.append((seq, arcs))
                max_read += 1
                arcs, seq = [], []
            elif len(line) == 3:
                assert line[0] == 'Shift'
                seq.append(2)
                arcs.append(parser.L_NULL)
            elif len(line) == 1:
                assert line[0] == 'Swap'
                seq.append(3)
                arcs.append(parser.L_NULL)
            elif len(line) == 2:
                if line[0].startswith('R'):
                    assert line[0] == 'Right-Arc'
                    seq.append(1)
                    arcs.append(parser.tran2id['R-' + line[1]])
                elif line[0].startswith('L'):
                    assert line[0] == 'Left-Arc'
                    seq.append(0)
                    arcs.append(parser.tran2id['L-' + line[1]])

    # A file that does not end with a blank line used to lose its last
    # sentence; flush it unless the `thr` limit has already been reached.
    if seq and not (reduced and max_read == thr):
        gold_seq.append((seq, arcs))
    return gold_seq
def build_dict(keys, n_max=None, offset=0):
    """Assign consecutive ids to keys, most frequent key first.

    Args:
        keys: iterable of hashable items (repeats are counted).
        n_max: keep only the ``n_max`` most frequent keys; ``None`` keeps all.
        offset: id given to the most frequent key.

    Returns:
        Dict mapping each kept key to ``offset + rank``.
    """
    ranked = Counter(keys).most_common(n_max)
    return {key: offset + rank for rank, (key, _) in enumerate(ranked)}
% language) 260 | 261 | # preprocess the data without bert initialization 262 | def filter_random(opt,parser): 263 | 264 | word_vectors = {} 265 | embeddings_matrix = np.asarray( 266 | np.random.normal(-0.0279, 0.041, (parser.n_tokens, opt.embsize)), dtype='float32') 267 | ## loading the pre-trained BERT model, then modifying the word embedding part 268 | tempbert = BertModel.from_pretrained(str(opt.bertpath)+str(opt.bertname)) 269 | torch.save(tempbert.embeddings.position_embeddings.state_dict(),'position'+str(opt.outputname)) 270 | torch.save(tempbert.embeddings.token_type_embeddings.state_dict(),'token_type'+str(opt.outputname)) 271 | 272 | word_vectors = list(tempbert.parameters())[0].data.numpy() 273 | temptokenizer = BertTokenizer.from_pretrained(str(opt.bertpath)+str(opt.bertname)) 274 | 275 | UNK_ID = temptokenizer.convert_tokens_to_ids(temptokenizer.tokenize('[UNK]')) 276 | 277 | for token in parser.tok2id: 278 | i = parser.tok2id[token] 279 | j = temptokenizer.convert_tokens_to_ids(temptokenizer.tokenize(token)) 280 | if j != UNK_ID: 281 | embeddings_matrix[i] = word_vectors[j[0]] 282 | 283 | emb_size = embeddings_matrix.shape[1] 284 | embeddings_matrix = torch.from_numpy(embeddings_matrix) 285 | EMB = nn.Embedding(parser.n_tokens,emb_size,padding_idx=0) 286 | EMB.weight = nn.Parameter(embeddings_matrix) 287 | tempbert.embeddings.word_embeddings = EMB 288 | 289 | torch.save(tempbert.embeddings.word_embeddings.state_dict(),'word_emb'+str(opt.outputname)) 290 | 291 | del EMB, tempbert, word_vectors, temptokenizer 292 | 293 | return embeddings_matrix 294 | 295 | ## preprocess the data with bert initialization 296 | def filter_bert(opt,parser): 297 | 298 | word_vectors = {} 299 | embeddings_matrix = np.asarray( 300 | np.random.normal(-0.0279, 0.041, (parser.n_tokens, opt.embsize)), dtype='float32') 301 | ## loading the pre-trained BERT model, then modifying the word embedding part 302 | print(str(opt.bertpath)+str(opt.bertname)) 303 | print("Use Normal BERT") 
304 | tempbert = BertModel.from_pretrained(str(opt.bertpath)+str(opt.bertname)) 305 | 306 | dict = tempbert.state_dict() 307 | keys = list(dict.keys()) 308 | numbers = list(range(opt.nattentionlayer, 12)) 309 | 310 | deleted = [] 311 | for x in numbers: 312 | deleted.append(str(x)) 313 | 314 | for dl in deleted: 315 | for key in keys: 316 | if key.find(dl) != -1: 317 | del dict[key] 318 | word_vectors = list(tempbert.parameters())[0].data.numpy() 319 | 320 | print("Use Normal BERT") 321 | temptokenizer = BertTokenizer.from_pretrained(str(opt.bertpath)+str(opt.bertname)) 322 | 323 | UNK_ID = temptokenizer.convert_tokens_to_ids(temptokenizer.tokenize('[UNK]')) 324 | counter = 0.0 325 | total = 0.0 326 | for token in parser.tok2id: 327 | i = parser.tok2id[token] 328 | x = temptokenizer.tokenize(token) 329 | if "[UNK]" in x and len(x) == 1: 330 | counter+=1.0 331 | total +=1.0 332 | j = temptokenizer.convert_tokens_to_ids(temptokenizer.tokenize(token)) 333 | 334 | if j != UNK_ID and len(j) > 0: 335 | embeddings_matrix[i] = word_vectors[j[0]] 336 | 337 | print("unk ratio: {}".format(counter/total*100)) 338 | 339 | emb_size = embeddings_matrix.shape[1] 340 | embeddings_matrix = torch.from_numpy(embeddings_matrix) 341 | 342 | dict.update({'embeddings.word_embeddings.weight':embeddings_matrix}) 343 | 344 | 345 | if opt.graphinput: 346 | del dict['embeddings.token_type_embeddings.weight'] 347 | del dict['pooler.dense.weight'] 348 | del dict['pooler.dense.bias'] 349 | 350 | torch.save(dict,'small_bert'+str(opt.outputname)) 351 | 352 | del tempbert, word_vectors, temptokenizer, dict 353 | 354 | return embeddings_matrix 355 | 356 | 357 | # preprocess the data when starting from a checkpoint 358 | def load_and_preprocess_datap(opt,parser,reduced=True): 359 | 360 | print("Loading data...", ) 361 | start = time.time() 362 | train_set = read_conll( 363 | os.path.join(opt.datapath, opt.trainfile), 364 | lowercase=opt.lowercase) 365 | dev_set = read_conll( 366 | 
os.path.join(opt.datapath, opt.devfile), 367 | lowercase=opt.lowercase) 368 | test_set = read_conll( 369 | os.path.join(opt.datapath, opt.testfile), 370 | lowercase=opt.lowercase) 371 | 372 | thr = 128 373 | if reduced: 374 | train_set = train_set[:thr+3] 375 | dev_set = dev_set[:thr+1] 376 | test_set = test_set[:thr+1] 377 | 378 | print("took {:.2f} seconds".format(time.time() - start)) 379 | 380 | print("Building parser...", ) 381 | start = time.time() 382 | print("took {:.2f} seconds".format(time.time() - start)) 383 | 384 | print("Reading gold actions...") 385 | seq_train = read_seq(os.path.join(opt.datapath, opt.seqpath), parser,reduced,thr+3) 386 | 387 | print("Loading pretrained embeddings...", ) 388 | start = time.time() 389 | 390 | if opt.withbert: 391 | embeddings_matrix = filter_bert(opt,parser) 392 | else: 393 | embeddings_matrix = filter_random(opt,parser) 394 | 395 | print("took {:.2f} seconds".format(time.time() - start)) 396 | 397 | print("Vectorizing data...", ) 398 | start = time.time() 399 | train_set = parser.vectorize(train_set) 400 | dev_set = parser.vectorize(dev_set) 401 | test_set = parser.vectorize(test_set) 402 | print("took {:.2f} seconds".format(time.time() - start)) 403 | 404 | print("Preprocessing training data...", ) 405 | start = time.time() 406 | train_examples = parser.create_instances(train_set,seq_train) 407 | 408 | print("took {:.2f} seconds".format(time.time() - start)) 409 | return embeddings_matrix, train_examples, train_set, dev_set, test_set, {'P':4} 410 | 411 | # preprocess the data when not starting from a checkpoint 412 | def load_and_preprocess_data(opt,reduced=True): 413 | 414 | print("Loading data...", ) 415 | start = time.time() 416 | train_set = read_conll( 417 | os.path.join(opt.datapath, opt.trainfile), 418 | lowercase=opt.lowercase) 419 | dev_set = read_conll( 420 | os.path.join(opt.datapath, opt.devfile), 421 | lowercase=opt.lowercase) 422 | test_set = read_conll( 423 | os.path.join(opt.datapath, opt.testfile), 
424 | lowercase=opt.lowercase) 425 | 426 | thr = 64 427 | if reduced: 428 | train_set = train_set[:thr+3] 429 | dev_set = dev_set[:thr+1] 430 | test_set = test_set[:thr+1] 431 | 432 | print("took {:.2f} seconds".format(time.time() - start)) 433 | 434 | print("Building parser...", ) 435 | start = time.time() 436 | parser = Parser(train_set,opt,dev_set,test_set) 437 | print("took {:.2f} seconds".format(time.time() - start)) 438 | 439 | print("Reading gold actions...") 440 | seq_train = read_seq(os.path.join(opt.datapath, opt.seqpath), parser,reduced,thr+3) 441 | 442 | print("Loading pretrained embeddings...", ) 443 | start = time.time() 444 | 445 | if opt.withbert: 446 | embeddings_matrix = filter_bert(opt,parser) 447 | else: 448 | embeddings_matrix = filter_random(opt,parser) 449 | 450 | print("took {:.2f} seconds".format(time.time() - start)) 451 | 452 | print("Vectorizing data...", ) 453 | start = time.time() 454 | train_set = parser.vectorize(train_set) 455 | dev_set = parser.vectorize(dev_set) 456 | test_set = parser.vectorize(test_set) 457 | print("took {:.2f} seconds".format(time.time() - start)) 458 | 459 | print("Preprocessing training data...", ) 460 | start = time.time() 461 | train_examples = parser.create_instances(train_set,seq_train) 462 | 463 | print("took {:.2f} seconds".format(time.time() - start)) 464 | return parser, embeddings_matrix, train_examples, train_set, dev_set, test_set, {'P':4} 465 | 466 | # preprocess the test/evaluation data 467 | def load_and_preprocess_data_test(opt,parser,reduced=True): 468 | 469 | print("Loading data...", ) 470 | start = time.time() 471 | test_set = read_conll( 472 | os.path.join(opt.datapath, opt.testfile), 473 | lowercase=opt.lowercase) 474 | 475 | thr = 32 476 | if reduced: 477 | test_set = test_set[:thr+1] 478 | 479 | print("took {:.2f} seconds".format(time.time() - start)) 480 | 481 | print("Building parser...", ) 482 | start = time.time() 483 | print("took {:.2f} seconds".format(time.time() - start)) 484 | 
485 | if opt.withbert: 486 | embeddings_matrix = filter_bert(opt,parser) 487 | else: 488 | embeddings_matrix = filter_random(opt,parser) 489 | 490 | print("Vectorizing data...", ) 491 | start = time.time() 492 | test_set = parser.vectorize(test_set) 493 | print("took {:.2f} seconds".format(time.time() - start)) 494 | 495 | print("Preprocessing training data...", ) 496 | start = time.time() 497 | print("took {:.2f} seconds".format(time.time() - start)) 498 | return test_set, {'P':4} 499 | 500 | class AverageMeter(object): 501 | """Computes and stores the average and current value""" 502 | 503 | def __init__(self): 504 | self.reset() 505 | 506 | def reset(self): 507 | self.val = 0 508 | self.avg = 0 509 | self.sum = 0 510 | self.count = 0 511 | 512 | def update(self, val, n=1): 513 | self.val = val 514 | self.sum += val * n 515 | self.count += n 516 | self.avg = self.sum / self.count 517 | 518 | if __name__ == "__main__": 519 | pass 520 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. 
(Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. 
We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. 
You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 
113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. 
You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 
165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. 
If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. 
If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. 
EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 
292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 
331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | --------------------------------------------------------------------------------