├── .gitignore ├── LICENSE ├── README.md ├── amr_parser ├── AMRGraph.py ├── __init__.py ├── adam.py ├── amr.py ├── bert_utils.py ├── data.py ├── decoder.py ├── draw_attn.py ├── dump_attention.py ├── encoder.py ├── eval_dev_test.py ├── eval_test.py ├── extract.py ├── match.py ├── parser.py ├── postprocess.py ├── search.py ├── train.py ├── transformer.py ├── utils.py ├── visual_beam_search.py └── work.py ├── elit ├── __init__.py ├── callbacks │ ├── __init__.py │ └── fine_csv_logger.py ├── common │ ├── __init__.py │ ├── amr.py │ ├── component.py │ ├── configurable.py │ ├── conll.py │ ├── constant.py │ ├── dataset.py │ ├── io.py │ ├── keras_component.py │ ├── reflection.py │ ├── structure.py │ ├── torch_component.py │ ├── transform.py │ ├── transform_tf.py │ ├── util.py │ ├── visualization.py │ ├── vocab.py │ └── vocab_tf.py ├── components │ ├── __init__.py │ ├── amr │ │ ├── __init__.py │ │ └── amr_parser │ │ │ ├── __init__.py │ │ │ ├── adam.py │ │ │ ├── amr.py │ │ │ ├── amr_graph.py │ │ │ ├── amrio.py │ │ │ ├── customized_bart.py │ │ │ ├── data.py │ │ │ ├── decoder.py │ │ │ ├── encoder.py │ │ │ ├── graph_amr_decoder.py │ │ │ ├── graph_model.py │ │ │ ├── graph_parser.py │ │ │ ├── graph_sequence_parser.py │ │ │ ├── parser_model.py │ │ │ ├── postprocess.py │ │ │ ├── search.py │ │ │ ├── transformer.py │ │ │ ├── utils.py │ │ │ └── work.py │ ├── classifiers │ │ ├── __init__.py │ │ ├── transformer_classifier.py │ │ └── transformer_classifier_tf.py │ ├── coref │ │ ├── __init__.py │ │ └── end_to_end.py │ ├── distillation │ │ ├── __init__.py │ │ ├── distillable_component.py │ │ ├── losses.py │ │ └── schedulers.py │ ├── eos │ │ ├── __init__.py │ │ └── ngram.py │ ├── lambda_wrapper.py │ ├── lemmatizer.py │ ├── mtl │ │ ├── __init__.py │ │ ├── multi_task_learning.py │ │ └── tasks │ │ │ ├── __init__.py │ │ │ ├── amr.py │ │ │ ├── constituency.py │ │ │ ├── dep.py │ │ │ ├── dep_2nd.py │ │ │ ├── lem.py │ │ │ ├── ner │ │ │ ├── __init__.py │ │ │ ├── biaffine_ner.py │ │ │ └── tag_ner.py │ │ │ ├── pos.py │ │ │ ├── sdp.py │ │ │ ├── srl │ │ │ ├── __init__.py │ │ │ ├── bio_srl.py │ │ │ └── rank_srl.py │ │ │ ├── tok │ │ │ ├── __init__.py │ │ │ ├── reg_tok.py │ │ │ └── tag_tok.py │ │ │ └── ud.py │ ├── ner │ │ ├── __init__.py │ │ ├── biaffine_ner │ │ │ ├── __init__.py │ │ │ ├── biaffine_ner.py │ │ │ └── biaffine_ner_model.py │ │ ├── rnn_ner.py │ │ └── transformer_ner.py │ ├── ner_tf.py │ ├── parsers │ │ ├── __init__.py │ │ ├── alg.py │ │ ├── alg_tf.py │ │ ├── biaffine │ │ │ ├── __init__.py │ │ │ ├── biaffine.py │ │ │ ├── biaffine_2nd_dep.py │ │ │ ├── biaffine_dep.py │ │ │ ├── biaffine_model.py │ │ │ ├── biaffine_sdp.py │ │ │ ├── mlp.py │ │ │ ├── structual_attention.py │ │ │ └── variationalbilstm.py │ │ ├── biaffine_parser_tf.py │ │ ├── biaffine_tf │ │ │ ├── __init__.py │ │ │ ├── alg.py │ │ │ ├── layers.py │ │ │ └── model.py │ │ ├── chu_liu_edmonds.py │ │ ├── conll.py │ │ ├── constituency │ │ │ ├── __init__.py │ │ │ ├── constituency_dataset.py │ │ │ ├── crf_constituency_model.py │ │ │ ├── crf_constituency_parser.py │ │ │ └── treecrf.py │ │ ├── hpsg │ │ │ ├── __init__.py │ │ │ ├── bracket_eval.py │ │ │ ├── const_decoder.pyx │ │ │ ├── dep_eval.py │ │ │ ├── hpsg_dataset.py │ │ │ ├── hpsg_decoder.pyx │ │ │ ├── hpsg_parser.py │ │ │ ├── hpsg_parser_model.py │ │ │ └── trees.py │ │ ├── parse_alg.py │ │ ├── second_order │ │ │ ├── __init__.py │ │ │ ├── affine.py │ │ │ ├── model.py │ │ │ ├── tree_crf_dependency_parser.py │ │ │ └── treecrf_decoder.py │ │ └── ud │ │ │ ├── __init__.py │ │ │ ├── lemma_edit.py │ │ │ ├── tag_decoder.py │ 
│ │ ├── ud_model.py │ │ │ ├── ud_parser.py │ │ │ ├── udify_util.py │ │ │ └── util.py │ ├── pipeline.py │ ├── pos_tf.py │ ├── rnn_language_model.py │ ├── srl │ │ ├── __init__.py │ │ ├── span_bio │ │ │ ├── __init__.py │ │ │ ├── baffine_tagging.py │ │ │ └── span_bio.py │ │ └── span_rank │ │ │ ├── __init__.py │ │ │ ├── highway_variational_lstm.py │ │ │ ├── inference_utils.py │ │ │ ├── layer.py │ │ │ ├── span_rank.py │ │ │ ├── span_ranking_srl_model.py │ │ │ ├── srl_eval_utils.py │ │ │ └── util.py │ ├── taggers │ │ ├── __init__.py │ │ ├── cnn_tagger_tf.py │ │ ├── ngram_conv │ │ │ ├── __init__.py │ │ │ └── ngram_conv_tagger.py │ │ ├── rnn │ │ │ ├── __init__.py │ │ │ └── rnntaggingmodel.py │ │ ├── rnn_tagger.py │ │ ├── rnn_tagger_tf.py │ │ ├── tagger.py │ │ ├── tagger_tf.py │ │ ├── transformers │ │ │ ├── __init__.py │ │ │ ├── metrics_tf.py │ │ │ ├── transformer_tagger.py │ │ │ ├── transformer_tagger_tf.py │ │ │ └── transformer_transform_tf.py │ │ └── util.py │ ├── tok.py │ ├── tok_tf.py │ └── tokenizers │ │ ├── __init__.py │ │ ├── multi_criteria_cws_transformer.py │ │ └── transformer.py ├── datasets │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ └── sentiment.py │ ├── coref │ │ ├── __init__.py │ │ └── conll12coref.py │ ├── cws │ │ ├── __init__.py │ │ ├── chunking_dataset.py │ │ ├── ctb6.py │ │ ├── multi_criteria_cws │ │ │ ├── __init__.py │ │ │ └── mcws_dataset.py │ │ └── sighan2005 │ │ │ ├── __init__.py │ │ │ ├── as_.py │ │ │ ├── cityu.py │ │ │ ├── msr.py │ │ │ └── pku.py │ ├── eos │ │ ├── __init__.py │ │ ├── eos.py │ │ └── nn_eos.py │ ├── glue.py │ ├── lm │ │ ├── __init__.py │ │ └── lm_dataset.py │ ├── ner │ │ ├── __init__.py │ │ ├── conll03.py │ │ ├── json_ner.py │ │ ├── msra.py │ │ ├── resume.py │ │ ├── tsv.py │ │ └── weibo.py │ ├── parsing │ │ ├── __init__.py │ │ ├── _ctb_utils.py │ │ ├── amr.py │ │ ├── conll_dataset.py │ │ ├── ctb5.py │ │ ├── ctb7.py │ │ ├── ctb8.py │ │ ├── ctb9.py │ │ ├── ptb.py │ │ ├── semeval15.py │ │ ├── semeval16.py │ │ └── ud │ │ │ ├── __init__.py │ │ │ ├── ud23.py │ │ │ ├── ud23m.py │ │ │ ├── ud27.py │ │ │ └── ud27m.py │ ├── pos │ │ ├── __init__.py │ │ └── ctb5.py │ ├── qa │ │ ├── __init__.py │ │ └── hotpotqa.py │ ├── srl │ │ ├── __init__.py │ │ ├── conll2012.py │ │ └── ontonotes5 │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── chinese.py │ │ │ └── english.py │ └── tokenization │ │ ├── __init__.py │ │ └── txt.py ├── layers │ ├── __init__.py │ ├── context_layer.py │ ├── crf │ │ ├── __init__.py │ │ ├── crf.py │ │ ├── crf_layer_tf.py │ │ └── crf_tf.py │ ├── dropout.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── char_cnn.py │ │ ├── char_cnn_tf.py │ │ ├── char_rnn.py │ │ ├── char_rnn_tf.py │ │ ├── concat_embedding.py │ │ ├── contextual_string_embedding.py │ │ ├── contextual_string_embedding_tf.py │ │ ├── contextual_word_embedding.py │ │ ├── embedding.py │ │ ├── fast_text.py │ │ ├── fast_text_tf.py │ │ ├── util.py │ │ ├── util_tf.py │ │ ├── word2vec.py │ │ └── word2vec_tf.py │ ├── feed_forward.py │ ├── pass_through_encoder.py │ ├── scalar_mix.py │ ├── transformers │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── loader_tf.py │ │ ├── pt_imports.py │ │ ├── relative_transformer.py │ │ ├── tf_imports.py │ │ ├── utils.py │ │ └── utils_tf.py │ └── weight_normalization.py ├── losses │ ├── __init__.py │ └── sparse_categorical_crossentropy.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── amr │ │ ├── __init__.py │ │ └── smatch_eval.py │ ├── chunking │ │ ├── __init__.py │ │ ├── binary_chunking_f1.py │ │ ├── bmes.py │ │ ├── chunking_f1.py │ │ ├── 
chunking_f1_tf.py │ │ ├── conlleval.py │ │ ├── cws_eval.py │ │ ├── iobes_tf.py │ │ └── sequence_labeling.py │ ├── f1.py │ ├── metric.py │ ├── mtl.py │ ├── parsing │ │ ├── __init__.py │ │ ├── attachmentscore.py │ │ ├── conllx_eval.py │ │ ├── iwpt20_eval.py │ │ ├── iwpt20_xud_eval.py │ │ ├── labeled_f1.py │ │ ├── labeled_f1_tf.py │ │ ├── labeled_score.py │ │ ├── semdep_eval.py │ │ └── span.py │ └── srl │ │ ├── __init__.py │ │ └── srlconll.py ├── optimizers │ ├── __init__.py │ └── adamw │ │ ├── __init__.py │ │ └── optimization.py ├── pretrained │ ├── __init__.py │ ├── classifiers.py │ ├── dep.py │ ├── eos.py │ ├── fasttext.py │ ├── glove.py │ ├── mtl.py │ ├── ner.py │ ├── pos.py │ ├── rnnlm.py │ ├── sdp.py │ ├── tok.py │ └── word2vec.py ├── transform │ ├── __init__.py │ ├── conll_tf.py │ ├── glue_tf.py │ ├── table.py │ ├── tacred.py │ ├── text.py │ ├── transformer_tokenizer.py │ ├── tsv.py │ └── txt.py ├── utils │ ├── __init__.py │ ├── component_util.py │ ├── english_tokenizer.py │ ├── file_read_backwards │ │ ├── __init__.py │ │ ├── buffer_work_space.py │ │ └── file_read_backwards.py │ ├── init_util.py │ ├── io_util.py │ ├── lang │ │ ├── __init__.py │ │ └── zh │ │ │ ├── __init__.py │ │ │ ├── char_table.py │ │ │ └── localization.py │ ├── log_util.py │ ├── rules.py │ ├── span_util.py │ ├── string_util.py │ ├── tf_util.py │ ├── time_util.py │ └── torch_util.py └── version.py ├── requirements.txt ├── scripts ├── annotate_features.sh ├── download_artifacts.sh ├── env.sh ├── postprocess_2.0.sh ├── postprocess_3.0.sh ├── prepare_data.sh ├── prepare_vocab.sh ├── preprocess_2.0.sh ├── preprocess_3.0.sh ├── run_spotlight.sh ├── run_standford_corenlp_server.sh ├── train_joint.sh └── train_levi.sh ├── stog ├── .gitattributes ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── dict_merge.py │ └── maximum_spanning_tree.py ├── commands │ ├── __init__.py │ ├── evaluate.py │ ├── predict.py │ ├── subcommand.py │ └── train.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── dataset_builder.py │ ├── dataset_readers │ │ ├── __init__.py │ │ ├── abstract_meaning_representation.py │ │ ├── amr_parsing │ │ │ ├── __init__.py │ │ │ ├── amr.py │ │ │ ├── amr_concepts │ │ │ │ ├── __init__.py │ │ │ │ ├── date.py │ │ │ │ ├── entity.py │ │ │ │ ├── ordinal.py │ │ │ │ ├── polarity.py │ │ │ │ ├── polite.py │ │ │ │ ├── quantity.py │ │ │ │ ├── score.py │ │ │ │ └── url.py │ │ │ ├── graph_repair.py │ │ │ ├── io.py │ │ │ ├── node_utils.py │ │ │ ├── postprocess │ │ │ │ ├── __init__.py │ │ │ │ ├── expander.py │ │ │ │ ├── node_restore.py │ │ │ │ ├── postprocess.py │ │ │ │ └── wikification.py │ │ │ ├── preprocess │ │ │ │ ├── __init__.py │ │ │ │ ├── feature_annotator.py │ │ │ │ ├── input_cleaner.py │ │ │ │ ├── morph.py │ │ │ │ ├── recategorizer.py │ │ │ │ ├── sense_remover.py │ │ │ │ └── text_anonymizor.py │ │ │ └── propbank_reader.py │ │ └── dataset_reader.py │ ├── fields │ │ ├── __init__.py │ │ ├── adjacency_field.py │ │ ├── array_field.py │ │ ├── field.py │ │ ├── index_field.py │ │ ├── knowledge_graph_field.py │ │ ├── label_field.py │ │ ├── list_field.py │ │ ├── metadata_field.py │ │ ├── multilabel_field.py │ │ ├── production_rule_field.py │ │ ├── sequence_field.py │ │ ├── sequence_label_field.py │ │ ├── span_field.py │ │ └── text_field.py │ ├── instance.py │ ├── iterators │ │ ├── __init__.py │ │ ├── basic_iterator.py │ │ ├── bucket_iterator.py │ │ ├── data_iterator.py │ │ ├── epoch_tracking_bucket_iterator.py │ │ └── multiprocess_iterator.py │ ├── token_indexers │ │ ├── __init__.py │ │ ├── dep_label_indexer.py │ │ ├── 
elmo_indexer.py │ │ ├── ner_tag_indexer.py │ │ ├── openai_transformer_byte_pair_indexer.py │ │ ├── pos_tag_indexer.py │ │ ├── single_id_token_indexer.py │ │ ├── token_characters_indexer.py │ │ └── token_indexer.py │ ├── tokenizers │ │ ├── __init__.py │ │ ├── bert_tokenizer.py │ │ ├── character_tokenizer.py │ │ ├── token.py │ │ ├── tokenizer.py │ │ ├── word_filter.py │ │ ├── word_splitter.py │ │ ├── word_stemmer.py │ │ └── word_tokenizer.py │ └── vocabulary.py ├── metrics │ ├── __init__.py │ ├── attachment_score.py │ ├── metric.py │ └── seq2seq_metrics.py ├── models │ ├── __init__.py │ ├── model.py │ └── stog.py ├── modules │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── biaffine_attention.py │ │ ├── dot_production_attention.py │ │ └── mlp_attention.py │ ├── attention_layers │ │ ├── __init__.py │ │ └── global_attention.py │ ├── augmented_lstm.py │ ├── decoders │ │ ├── __init__.py │ │ ├── deep_biaffine_graph_decoder.py │ │ ├── generator.py │ │ ├── pointer_generator.py │ │ └── rnn_decoder.py │ ├── encoder_base.py │ ├── initializers.py │ ├── input_variational_dropout.py │ ├── linear │ │ ├── __init__.py │ │ └── bilinear.py │ ├── optimizer.py │ ├── seq2seq_encoders │ │ ├── __init__.py │ │ ├── pytorch_seq2seq_wrapper.py │ │ ├── seq2seq_bert_encoder.py │ │ └── seq2seq_encoder.py │ ├── seq2vec_encoders │ │ ├── __init__.py │ │ ├── boe_encoder.py │ │ ├── cnn_encoder.py │ │ ├── pytorch_seq2vec_wrapper.py │ │ └── seq2vec_encoder.py │ ├── stacked_bilstm.py │ ├── stacked_lstm.py │ ├── text_field_embedders │ │ ├── __init__.py │ │ ├── basic_text_field_embedder.py │ │ └── text_field_embedder.py │ ├── time_distributed.py │ └── token_embedders │ │ ├── __init__.py │ │ ├── embedding.py │ │ ├── openai_transformer_embedder.py │ │ ├── token_characters_encoder.py │ │ └── token_embedder.py ├── predictors │ ├── __init__.py │ ├── predictor.py │ └── stog.py ├── training │ ├── __init__.py │ ├── tensorboard.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── archival.py │ ├── checks.py │ ├── environment.py │ ├── exception_hook.py │ ├── extract_tokens_from_amr.py │ ├── file.py │ ├── from_params.py │ ├── logging.py │ ├── nn.py │ ├── params.py │ ├── registrable.py │ ├── string.py │ ├── time.py │ └── tqdm.py └── tools ├── amr-evaluation-tool-enhanced ├── .Rhistory ├── README.md ├── __init__.py ├── alignments.py ├── allscores.txt ├── amrdata.py ├── evaluation.sh ├── evaluation_label.sh ├── extract_np.py ├── onelabel.py ├── scores.py ├── scores.py.bak ├── smatch │ ├── .filt │ ├── .output_jamr.txt.swp │ ├── LICENSE.txt │ ├── README.txt │ ├── __init__.py │ ├── amr.py │ ├── amr.py.bak │ ├── amr_edited.py │ ├── sample_file_list │ ├── smatch-table.py │ ├── smatch.py │ ├── smatch_edited.py │ └── update_log ├── tmp_debug.txt └── unlabel.py └── fast_smatch ├── README.md ├── __init__.py ├── _gain.cc ├── _gain.h ├── _smatch.c ├── _smatch.cpp ├── _smatch.pyx ├── amr.py ├── api.py ├── compute_smatch.sh ├── fast_smatch.py ├── setup.py ├── smatch-table.py └── smatch.py /amr_parser/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-17 18:23 4 | -------------------------------------------------------------------------------- /amr_parser/eval_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-09-18 12:44 4 | import argparse 5 | 6 | from amr_parser.work import predict 7 | from elit.metrics.amr.smatch_eval import 
smatch_eval, SmatchScores 8 | from elit.utils.io_util import run_cmd 9 | from elit.utils.log_util import flash, cprint 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('checkpoint', type=str, help='The checkpoint') 15 | parser.add_argument('--version', type=str, help='AMR data version', default='2.0') 16 | args = parser.parse_args() 17 | version = args.version 18 | device = 0 19 | checkpoint = args.checkpoint 20 | test_pred = f'{checkpoint}_test_out.pred' 21 | flash(f'Running prediction {checkpoint} on testset [blink][yellow]...[/yellow][/blink]') 22 | predict(checkpoint, f'data/AMR/amr_{version}/test.txt.features.preproc', device=device) 23 | test_score = eval_checkpoint(test_pred, False, False, version) 24 | cprint(f'Official score on testset: [red]{test_score.score:.1%}[/red]') 25 | print(test_score) 26 | 27 | 28 | def eval_checkpoint(path, use_fast=True, dev=True, version='2.0'): 29 | run_cmd(f'sh postprocess_{version}.sh {path}') 30 | scores: SmatchScores = smatch_eval(f'{path}.post', 31 | f'/home/hhe43/amr_gs/data/AMR/amr_{version}/{"dev" if dev else "test"}.txt', 32 | use_fast=use_fast) 33 | return scores 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /amr_parser/match.py: -------------------------------------------------------------------------------- 1 | def match(output_file, input_file): 2 | block = [] 3 | blocks = [] 4 | for line in open(input_file, encoding='utf8').readlines(): 5 | if line.startswith('#'): 6 | block.append(line) 7 | else: 8 | if block: 9 | blocks.append(block) 10 | block = [] 11 | 12 | block1 = [] 13 | blocks1 = [] 14 | for line in open(output_file, encoding='utf8').readlines(): 15 | if not line.startswith('#'): 16 | block1.append(line) 17 | else: 18 | if block1: 19 | blocks1.append(block1) 20 | block1 = [] 21 | if block1: 22 | blocks1.append(block1) 23 | assert len(blocks) == len(blocks1), (len(blocks), len(blocks1)) 24 | 25 | 26 | with open(output_file+'.pred', 'w', encoding='utf8') as fo: 27 | for block, block1 in zip(blocks, blocks1): 28 | for line in block: 29 | fo.write(line) 30 | for line in block1: 31 | fo.write(line) 32 | -------------------------------------------------------------------------------- /amr_parser/visual_beam_search.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-22 19:10 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from matplotlib import ticker 7 | 8 | 9 | def to_scores(text): 10 | return [float(x) for x in text.split()] 11 | 12 | 13 | giis = to_scores('80.2 80.2 80 80 77.5') 14 | merge = to_scores('80 80.1 80 79.7 78.4') 15 | levi = to_scores('80 80 79.9 79.4 78.3') 16 | 17 | # evenly sampled time at 200ms intervals 18 | t = [1, 2, 4, 8, 16] 19 | t = list(reversed(t)) 20 | 21 | # red dashes, blue squares and green triangles 22 | plt.rcParams["figure.figsize"] = (4, 4) 23 | _, ax = plt.subplots() 24 | 25 | # Be sure to only pick integer tick locations. 
26 | for axis in [ax.xaxis, ax.yaxis]: 27 | axis.set_major_locator(ticker.MaxNLocator(integer=True)) 28 | plt.plot(t, giis, 'r--') 29 | plt.plot(t, merge, 'g:') 30 | plt.plot(t, levi, 'b') 31 | plt.legend(['GSII', 'ND + AD + BD', 'ND + AD + Levi']) 32 | plt.xlabel('Beam Size') 33 | plt.ylabel('Smatch') 34 | plt.savefig("/Users/hankcs/Dropbox/应用/Overleaf/NAACL-2021-AMR/fig/beam.pdf", bbox_inches='tight') 35 | plt.show() 36 | -------------------------------------------------------------------------------- /elit/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-06-13 18:05 4 | import elit.pretrained 5 | import elit.utils 6 | from elit.version import __version__ 7 | 8 | elit.utils.ls_resource_in_module(elit.pretrained) 9 | -------------------------------------------------------------------------------- /elit/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-05 02:10 -------------------------------------------------------------------------------- /elit/common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 14:45 4 | -------------------------------------------------------------------------------- /elit/common/component.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 14:45 4 | import inspect 5 | from abc import ABC, abstractmethod 6 | from typing import Any 7 | 8 | from elit.common.configurable import Configurable 9 | 10 | 11 | class Component(Configurable, ABC): 12 | @abstractmethod 13 | def predict(self, data: Any, **kwargs): 14 | """Predict on data. This is the base class for all components, including rule based and statistical ones. 15 | 16 | Args: 17 | data: Any type of data subject to sub-classes 18 | kwargs: Additional arguments 19 | 20 | Returns: Any predicted annotations. 21 | 22 | """ 23 | raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3])) 24 | 25 | def __call__(self, data: Any, **kwargs): 26 | """ 27 | A shortcut for :func:`~elit.common.component.predict`. 28 | 29 | Args: 30 | data: 31 | **kwargs: 32 | 33 | Returns: 34 | 35 | """ 36 | return self.predict(data, **kwargs) 37 | -------------------------------------------------------------------------------- /elit/common/configurable.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-16 22:24 4 | from elit.common.reflection import str_to_type, classpath_of 5 | 6 | 7 | class Configurable(object): 8 | @staticmethod 9 | def from_config(config: dict, **kwargs): 10 | """Build an object from config. 11 | 12 | Args: 13 | config: A ``dict`` holding parameters for its constructor. It has to contain a `classpath` key, 14 | which has a classpath str as its value. ``classpath`` will determine the type of object 15 | being deserialized. 16 | kwargs: Arguments not used. 17 | 18 | Returns: A deserialized object. 
19 | 20 | """ 21 | cls = config.get('classpath', None) 22 | assert cls, f'{config} doesn\'t contain classpath field' 23 | cls = str_to_type(cls) 24 | deserialized_config = dict(config) 25 | for k, v in config.items(): 26 | if isinstance(v, dict) and 'classpath' in v: 27 | deserialized_config[k] = Configurable.from_config(v) 28 | if cls.from_config == Configurable.from_config: 29 | deserialized_config.pop('classpath') 30 | return cls(**deserialized_config) 31 | else: 32 | return cls.from_config(deserialized_config) 33 | 34 | 35 | class AutoConfigurable(Configurable): 36 | @property 37 | def config(self) -> dict: 38 | """ 39 | The config of this object, which are public properties. If any properties needs to be excluded from this config, 40 | simply declare it with prefix ``_``. 41 | """ 42 | return dict([('classpath', classpath_of(self))] + 43 | [(k, v.config if hasattr(v, 'config') else v) 44 | for k, v in self.__dict__.items() if 45 | not k.startswith('_')]) 46 | 47 | def __repr__(self) -> str: 48 | return repr(self.config) 49 | -------------------------------------------------------------------------------- /elit/common/constant.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-06-13 22:41 4 | import os 5 | 6 | PAD = '' 7 | '''Padding token.''' 8 | UNK = '' 9 | '''Unknown token.''' 10 | CLS = '[CLS]' 11 | BOS = '' 12 | EOS = '' 13 | ROOT = BOS 14 | IDX = '_idx_' 15 | '''Key for index.''' 16 | HANLP_URL = os.getenv('HANLP_URL', 'https://file.hankcs.com/hanlp/') 17 | '''Resource URL.''' 18 | HANLP_VERBOSE = os.environ.get('HANLP_VERBOSE', '1').lower() in ('1', 'true', 'yes') 19 | '''Enable verbose or not.''' 20 | NULL = '' 21 | PRED = 'PRED' 22 | try: 23 | # noinspection PyUnresolvedReferences,PyStatementEffect 24 | get_ipython 25 | IPYTHON = True 26 | except NameError: 27 | IPYTHON = False 28 | -------------------------------------------------------------------------------- /elit/common/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-16 22:38 4 | import json 5 | import os 6 | import pickle 7 | import sys 8 | 9 | 10 | def save_pickle(item, path): 11 | with open(path, 'wb') as f: 12 | pickle.dump(item, f) 13 | 14 | 15 | def load_pickle(path): 16 | with open(path, 'rb') as f: 17 | return pickle.load(f) 18 | 19 | 20 | def save_json(item: dict, path: str, ensure_ascii=False, cls=None, default=lambda o: repr(o), indent=2): 21 | dirname = os.path.dirname(path) 22 | if dirname: 23 | os.makedirs(dirname, exist_ok=True) 24 | with open(path, 'w', encoding='utf-8') as out: 25 | json.dump(item, out, ensure_ascii=ensure_ascii, indent=indent, cls=cls, default=default) 26 | 27 | 28 | def load_json(path): 29 | with open(path, encoding='utf-8') as src: 30 | return json.load(src) 31 | 32 | 33 | def filename_is_json(filename): 34 | filename, file_extension = os.path.splitext(filename) 35 | return file_extension in ['.json', '.jsonl'] 36 | 37 | 38 | def eprint(*args, **kwargs): 39 | print(*args, file=sys.stderr, **kwargs) -------------------------------------------------------------------------------- /elit/common/reflection.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 16:41 4 | import importlib 5 | import inspect 6 | 7 | 8 | def classpath_of(obj) -> str: 9 | """get the full class path of object 
10 | 11 | Args: 12 | obj: return: 13 | 14 | Returns: 15 | 16 | """ 17 | if inspect.isfunction(obj): 18 | return module_path_of(obj) 19 | return "{0}.{1}".format(obj.__class__.__module__, obj.__class__.__name__) 20 | 21 | 22 | def module_path_of(func) -> str: 23 | return inspect.getmodule(func).__name__ + '.' + func.__name__ 24 | 25 | 26 | def object_from_classpath(classpath, **kwargs): 27 | classpath = str_to_type(classpath) 28 | if inspect.isfunction(classpath): 29 | return classpath 30 | return classpath(**kwargs) 31 | 32 | 33 | def str_to_type(classpath): 34 | """convert class path in str format to a type 35 | 36 | Args: 37 | classpath: class path 38 | 39 | Returns: 40 | type 41 | 42 | """ 43 | module_name, class_name = classpath.rsplit(".", 1) 44 | cls = getattr(importlib.import_module(module_name), class_name) 45 | return cls 46 | 47 | 48 | def type_to_str(type_object) -> str: 49 | """convert a type object to class path in str format 50 | 51 | Args: 52 | type_object: type 53 | 54 | Returns: 55 | class path 56 | 57 | """ 58 | cls_name = str(type_object) 59 | assert cls_name.startswith(""), 'illegal input' 62 | cls_name = cls_name[:-len("'>")] 63 | return cls_name 64 | -------------------------------------------------------------------------------- /elit/components/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 16:10 -------------------------------------------------------------------------------- /elit/components/amr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-20 17:35 4 | -------------------------------------------------------------------------------- /elit/components/amr/amr_parser/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-17 18:23 4 | -------------------------------------------------------------------------------- /elit/components/amr/amr_parser/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | from elit.common.constant import PAD 6 | from elit.common.vocab import Vocab 7 | 8 | DUM, NIL, END = '[unused0]', '', '[unused1]' 9 | REL = 'rel=' 10 | 11 | 12 | def list_to_tensor(xs, vocab: Vocab = None, local_vocabs=None, unk_rate=0.): 13 | pad = vocab.pad_idx if vocab else 0 14 | 15 | def to_idx(w, i): 16 | if vocab is None: 17 | return w 18 | if isinstance(w, list): 19 | return [to_idx(_, i) for _ in w] 20 | if random.random() < unk_rate: 21 | return vocab.unk_idx 22 | if local_vocabs is not None: 23 | local_vocab = local_vocabs[i] 24 | if (local_vocab is not None) and (w in local_vocab): 25 | return local_vocab[w] 26 | return vocab.get_idx(w) 27 | 28 | max_len = max(len(x) for x in xs) 29 | ys = [] 30 | for i, x in enumerate(xs): 31 | y = to_idx(x, i) + [pad] * (max_len - len(x)) 32 | ys.append(y) 33 | data = np.transpose(np.array(ys)) 34 | return data 35 | 36 | 37 | def lists_of_string_to_tensor(xs, vocab: Vocab, max_string_len=20): 38 | max_len = max(len(x) for x in xs) 39 | ys = [] 40 | for x in xs: 41 | y = x + [PAD] * (max_len - len(x)) 42 | zs = [] 43 | for z in y: 44 | z = list(z[:max_string_len]) 45 | zs.append(vocab([DUM] + z + [END]) + [vocab.pad_idx] * (max_string_len - len(z))) 46 | ys.append(zs) 47 | 48 | data = 
np.transpose(np.array(ys), (1, 0, 2)) 49 | return data 50 | -------------------------------------------------------------------------------- /elit/components/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 13:18 -------------------------------------------------------------------------------- /elit/components/coref/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-05 19:56 -------------------------------------------------------------------------------- /elit/components/distillation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-10-17 20:29 4 | -------------------------------------------------------------------------------- /elit/components/distillation/distillable_component.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-10-17 20:30 4 | from abc import ABC 5 | from copy import copy 6 | 7 | import elit 8 | from elit.common.torch_component import TorchComponent 9 | from elit.components.distillation.losses import KnowledgeDistillationLoss 10 | from elit.components.distillation.schedulers import TemperatureScheduler 11 | from elit.utils.torch_util import cuda_devices 12 | from elit.common.util import merge_locals_kwargs 13 | 14 | 15 | class DistillableComponent(TorchComponent, ABC): 16 | 17 | # noinspection PyMethodMayBeStatic,PyTypeChecker 18 | def build_teacher(self, teacher: str, devices) -> TorchComponent: 19 | return elit.load(teacher, load_kwargs={'devices': devices}) 20 | 21 | def distill(self, 22 | teacher: str, 23 | trn_data, 24 | dev_data, 25 | save_dir, 26 | batch_size=None, 27 | epochs=None, 28 | kd_criterion='kd_ce_loss', 29 | temperature_scheduler='flsw', 30 | devices=None, 31 | logger=None, 32 | seed=None, 33 | **kwargs): 34 | devices = devices or cuda_devices() 35 | if isinstance(kd_criterion, str): 36 | kd_criterion = KnowledgeDistillationLoss(kd_criterion) 37 | if isinstance(temperature_scheduler, str): 38 | temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler) 39 | teacher = self.build_teacher(teacher, devices=devices) 40 | self.vocabs = teacher.vocabs 41 | config = copy(teacher.config) 42 | batch_size = batch_size or config.get('batch_size', None) 43 | epochs = epochs or config.get('epochs', None) 44 | config.update(kwargs) 45 | return super().fit(**merge_locals_kwargs(locals(), 46 | config, 47 | excludes=('self', 'kwargs', '__class__', 'config'))) 48 | 49 | @property 50 | def _savable_config(self): 51 | config = super(DistillableComponent, self)._savable_config 52 | if 'teacher' in config: 53 | config.teacher = config.teacher.load_path 54 | return config 55 | -------------------------------------------------------------------------------- /elit/components/eos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-26 20:19 -------------------------------------------------------------------------------- /elit/components/lambda_wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-31 18:36 4 | 
from typing import Callable, Any 5 | 6 | from elit.common.component import Component 7 | from elit.common.reflection import classpath_of, object_from_classpath, str_to_type 8 | 9 | 10 | class LambdaComponent(Component): 11 | def __init__(self, function: Callable) -> None: 12 | super().__init__() 13 | self.config = {} 14 | self.function = function 15 | self.config['function'] = classpath_of(function) 16 | self.config['classpath'] = classpath_of(self) 17 | 18 | def predict(self, data: Any, **kwargs): 19 | unpack = kwargs.pop('_hanlp_unpack', None) 20 | if unpack: 21 | return self.function(*data, **kwargs) 22 | return self.function(data, **kwargs) 23 | 24 | @staticmethod 25 | def from_config(meta: dict, **kwargs): 26 | cls = str_to_type(meta['classpath']) 27 | function = meta['function'] 28 | function = object_from_classpath(function) 29 | return cls(function) 30 | -------------------------------------------------------------------------------- /elit/components/lemmatizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-08 18:35 4 | from typing import List 5 | 6 | from elit.common.transform import TransformList 7 | from elit.components.parsers.ud.lemma_edit import gen_lemma_rule, apply_lemma_rule 8 | from elit.components.taggers.transformers.transformer_tagger import TransformerTagger 9 | 10 | 11 | def add_lemma_rules_to_sample(sample: dict): 12 | if 'tag' in sample and 'lemma' not in sample: 13 | lemma_rules = [gen_lemma_rule(word, lemma) 14 | if lemma != "_" else "_" 15 | for word, lemma in zip(sample['token'], sample['tag'])] 16 | sample['lemma'] = sample['tag'] = lemma_rules 17 | return sample 18 | 19 | 20 | class TransformerLemmatizer(TransformerTagger): 21 | 22 | def __init__(self, **kwargs) -> None: 23 | """A transition based lemmatizer using transformer as encoder. 24 | 25 | Args: 26 | **kwargs: Predefined config. 
27 | """ 28 | super().__init__(**kwargs) 29 | 30 | def build_dataset(self, data, transform=None, **kwargs): 31 | if not isinstance(transform, list): 32 | transform = TransformList() 33 | transform.append(add_lemma_rules_to_sample) 34 | return super().build_dataset(data, transform, **kwargs) 35 | 36 | def prediction_to_human(self, pred, vocab: List[str], batch, token=None): 37 | if token is None: 38 | token = batch['token'] 39 | rules = super().prediction_to_human(pred, vocab, batch) 40 | for token_per_sent, rule_per_sent in zip(token, rules): 41 | lemma_per_sent = [apply_lemma_rule(t, r) for t, r in zip(token_per_sent, rule_per_sent)] 42 | yield lemma_per_sent 43 | -------------------------------------------------------------------------------- /elit/components/mtl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-20 19:54 -------------------------------------------------------------------------------- /elit/components/mtl/tasks/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-03 14:34 4 | -------------------------------------------------------------------------------- /elit/components/mtl/tasks/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-04 16:49 4 | -------------------------------------------------------------------------------- /elit/components/mtl/tasks/tok/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-11 16:34 -------------------------------------------------------------------------------- /elit/components/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-21 17:22 -------------------------------------------------------------------------------- /elit/components/ner/biaffine_ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-21 18:41 -------------------------------------------------------------------------------- /elit/components/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-22 12:46 -------------------------------------------------------------------------------- /elit/components/parsers/biaffine/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 20:43 4 | -------------------------------------------------------------------------------- /elit/components/parsers/biaffine_tf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-26 23:03 -------------------------------------------------------------------------------- /elit/components/parsers/constituency/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-28 19:26 4 | 
-------------------------------------------------------------------------------- /elit/components/parsers/hpsg/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-22 21:35 -------------------------------------------------------------------------------- /elit/components/parsers/hpsg/hpsg_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-22 21:36 4 | import os 5 | from typing import Union, List, Callable, Tuple 6 | 7 | from elit.common.dataset import TransformableDataset 8 | from elit.components.parsers.hpsg.trees import load_trees_from_str 9 | from elit.utils.io_util import read_tsv_as_sents, TimingFileIterator, get_resource 10 | 11 | 12 | class HeadDrivenPhraseStructureDataset(TransformableDataset): 13 | 14 | def __init__(self, data: Union[List, Tuple] = None, 15 | transform: Union[Callable, List] = None, cache=None) -> None: 16 | super().__init__(data, transform, cache) 17 | 18 | def load_data(self, data, generate_idx=False): 19 | if isinstance(data, tuple): 20 | data = list(self.load_file(data)) 21 | return data 22 | 23 | def load_file(self, filepath: tuple): 24 | phrase_tree_path = get_resource(filepath[0]) 25 | dep_tree_path = get_resource(filepath[1]) 26 | pf = TimingFileIterator(phrase_tree_path) 27 | message_prefix = f'Loading {os.path.basename(phrase_tree_path)} and {os.path.basename(dep_tree_path)}' 28 | for i, (dep_sent, phrase_sent) in enumerate(zip(read_tsv_as_sents(dep_tree_path), pf)): 29 | # Somehow the file contains escaped literals 30 | phrase_sent = phrase_sent.replace('\\/', '/') 31 | 32 | token = [x[1] for x in dep_sent] 33 | pos = [x[3] for x in dep_sent] 34 | head = [int(x[6]) for x in dep_sent] 35 | rel = [x[7] for x in dep_sent] 36 | phrase_tree = load_trees_from_str(phrase_sent, [head], [rel], [token]) 37 | assert len(phrase_tree) == 1, f'{phrase_tree_path} must have on tree per line.' 
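            # load_trees_from_str is called with batched arguments ([head], [rel], [token]),
            # so it is expected to return exactly one tree for the single phrase-structure line read here.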
38 | phrase_tree = phrase_tree[0] 39 | 40 | yield { 41 | 'FORM': token, 42 | 'CPOS': pos, 43 | 'HEAD': head, 44 | 'DEPREL': rel, 45 | 'tree': phrase_tree, 46 | 'hpsg': phrase_tree.convert() 47 | } 48 | pf.log(f'{message_prefix} {i + 1} samples [blink][yellow]...[/yellow][/blink]') 49 | pf.erase() 50 | -------------------------------------------------------------------------------- /elit/components/parsers/second_order/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-01 13:44 4 | -------------------------------------------------------------------------------- /elit/components/parsers/second_order/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-01 15:28 4 | from torch import nn 5 | 6 | 7 | # noinspection PyAbstractClass 8 | class DependencyModel(nn.Module): 9 | def __init__(self, embed: nn.Module, encoder: nn.Module, decoder: nn.Module): 10 | super().__init__() 11 | self.embed = embed 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | 15 | def forward(self, batch, mask): 16 | x = self.embed(batch, mask=mask) 17 | x = self.encoder(x, mask) 18 | return self.decoder(x, mask=mask) 19 | -------------------------------------------------------------------------------- /elit/components/parsers/second_order/treecrf_decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-01 16:51 4 | from typing import Any, Tuple 5 | 6 | import torch 7 | 8 | from elit.components.parsers.biaffine.biaffine_model import BiaffineDecoder 9 | from elit.components.parsers.biaffine.mlp import MLP 10 | from elit.components.parsers.constituency.treecrf import CRF2oDependency 11 | from elit.components.parsers.second_order.affine import Triaffine 12 | 13 | 14 | class TreeCRFDecoder(BiaffineDecoder): 15 | def __init__(self, hidden_size, n_mlp_arc, n_mlp_sib, n_mlp_rel, mlp_dropout, n_rels) -> None: 16 | super().__init__(hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, n_rels) 17 | self.mlp_sib_s = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout) 18 | self.mlp_sib_d = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout) 19 | self.mlp_sib_h = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout) 20 | 21 | self.sib_attn = Triaffine(n_in=n_mlp_sib, bias_x=True, bias_y=True) 22 | self.crf = CRF2oDependency() 23 | 24 | def forward(self, x, mask=None, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 25 | s_arc, s_rel = super(TreeCRFDecoder, self).forward(x, mask) 26 | sib_s = self.mlp_sib_s(x) 27 | sib_d = self.mlp_sib_d(x) 28 | sib_h = self.mlp_sib_h(x) 29 | # [batch_size, seq_len, seq_len, seq_len] 30 | s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2) 31 | return s_arc, s_sib, s_rel 32 | -------------------------------------------------------------------------------- /elit/components/parsers/ud/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-14 20:34 4 | -------------------------------------------------------------------------------- /elit/components/parsers/ud/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-14 20:44 4 | from elit.common.constant import ROOT 5 | 
from elit.components.parsers.ud.lemma_edit import gen_lemma_rule 6 | 7 | 8 | def generate_lemma_rule(sample: dict): 9 | if 'LEMMA' in sample: 10 | sample['lemma'] = [gen_lemma_rule(word, lemma) if lemma != "_" else "_" for word, lemma in 11 | zip(sample['FORM'], sample['LEMMA'])] 12 | return sample 13 | 14 | 15 | def append_bos(sample: dict): 16 | if 'FORM' in sample: 17 | sample['token'] = [ROOT] + sample['FORM'] 18 | if 'UPOS' in sample: 19 | sample['pos'] = sample['UPOS'][:1] + sample['UPOS'] 20 | sample['arc'] = [0] + sample['HEAD'] 21 | sample['rel'] = sample['DEPREL'][:1] + sample['DEPREL'] 22 | sample['lemma'] = sample['lemma'][:1] + sample['lemma'] 23 | sample['feat'] = sample['FEATS'][:1] + sample['FEATS'] 24 | return sample 25 | 26 | 27 | def sample_form_missing(sample: dict): 28 | return all(t == '_' for t in sample['FORM']) 29 | -------------------------------------------------------------------------------- /elit/components/pos_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-05 23:05 4 | from elit.components.taggers.cnn_tagger_tf import CNNTaggerTF 5 | from elit.components.taggers.rnn_tagger_tf import RNNTaggerTF 6 | 7 | 8 | class CNNPartOfSpeechTaggerTF(CNNTaggerTF): 9 | pass 10 | 11 | 12 | class RNNPartOfSpeechTaggerTF(RNNTaggerTF): 13 | pass 14 | -------------------------------------------------------------------------------- /elit/components/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-22 20:50 -------------------------------------------------------------------------------- /elit/components/srl/span_bio/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-04 13:59 4 | -------------------------------------------------------------------------------- /elit/components/srl/span_rank/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-19 22:22 -------------------------------------------------------------------------------- /elit/components/srl/span_rank/util.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/KiroSummer/A_Syntax-aware_MTL_Framework_for_Chinese_SRL 2 | import torch 3 | 4 | 5 | def block_orth_normal_initializer(input_size, output_size): 6 | weight = [] 7 | for o in output_size: 8 | for i in input_size: 9 | param = torch.FloatTensor(o, i) 10 | torch.nn.init.orthogonal_(param) 11 | weight.append(param) 12 | return torch.cat(weight) 13 | -------------------------------------------------------------------------------- /elit/components/taggers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-28 15:39 -------------------------------------------------------------------------------- /elit/components/taggers/ngram_conv/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 22:18 -------------------------------------------------------------------------------- /elit/components/taggers/rnn/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-19 15:41 -------------------------------------------------------------------------------- /elit/components/taggers/tagger_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-10-25 21:49 4 | import logging 5 | from abc import ABC 6 | 7 | import tensorflow as tf 8 | 9 | from elit.common.keras_component import KerasComponent 10 | from elit.layers.crf.crf_layer_tf import CRF, CRFLoss, CRFWrapper 11 | from elit.metrics.chunking.iobes_tf import IOBES_F1_TF 12 | 13 | 14 | class TaggerComponent(KerasComponent, ABC): 15 | 16 | def build_metrics(self, metrics, logger: logging.Logger, **kwargs): 17 | if metrics == 'f1': 18 | assert hasattr(self.transform, 'tag_vocab'), 'Name your tag vocab tag_vocab in your transform ' \ 19 | 'or override build_metrics' 20 | if not self.config.get('run_eagerly', None): 21 | logger.debug('ChunkingF1 runs only under eager mode, ' 22 | 'set run_eagerly=True to remove this warning') 23 | self.config.run_eagerly = True 24 | return IOBES_F1_TF(self.transform.tag_vocab) 25 | return super().build_metrics(metrics, logger, **kwargs) 26 | 27 | def build_loss(self, loss, **kwargs): 28 | assert self.model is not None, 'should create model before build loss' 29 | if loss == 'crf': 30 | if isinstance(self.model, tf.keras.models.Sequential): 31 | crf = CRF(len(self.transform.tag_vocab)) 32 | self.model.add(crf) 33 | loss = CRFLoss(crf, self.model.dtype) 34 | else: 35 | self.model = CRFWrapper(self.model, len(self.transform.tag_vocab)) 36 | loss = CRFLoss(self.model.crf, self.model.dtype) 37 | return loss 38 | return super().build_loss(loss, **kwargs) 39 | -------------------------------------------------------------------------------- /elit/components/taggers/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 13:57 -------------------------------------------------------------------------------- /elit/components/taggers/transformers/metrics_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 16:33 4 | import tensorflow as tf 5 | 6 | 7 | class Accuracy(tf.keras.metrics.SparseCategoricalAccuracy): 8 | 9 | def __init__(self, name='sparse_categorical_accuracy', dtype=None, mask_value=0): 10 | super().__init__(name, dtype) 11 | self.mask_value = mask_value 12 | 13 | def update_state(self, y_true, y_pred, sample_weight=None): 14 | sample_weight = tf.not_equal(y_true, self.mask_value) 15 | return super().update_state(y_true, y_pred, sample_weight) 16 | -------------------------------------------------------------------------------- /elit/components/taggers/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-01 00:31 4 | from typing import List, Tuple 5 | from alnlp.modules.conditional_random_field import allowed_transitions 6 | 7 | 8 | def guess_tagging_scheme(labels: List[str]) -> str: 9 | tagset = set(y.split('-')[0] for y in labels) 10 | for scheme in "BIO", "BIOUL", "BMES", 'IOBES': 11 | if tagset == set(list(scheme)): 12 | return scheme 13 | 14 | 15 | def guess_allowed_transitions(labels) -> List[Tuple[int, 
int]]: 16 | scheme = guess_tagging_scheme(labels) 17 | if not scheme: 18 | return None 19 | if scheme == 'IOBES': 20 | scheme = 'BIOUL' 21 | labels = [y.replace('E-', 'L-').replace('S-', 'U-') for y in labels] 22 | return allowed_transitions(scheme, dict(enumerate(labels))) 23 | -------------------------------------------------------------------------------- /elit/components/tok.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-12 13:08 4 | from typing import Any, Callable 5 | 6 | from elit.components.taggers.rnn_tagger import RNNTagger 7 | from elit.datasets.cws.chunking_dataset import ChunkingDataset 8 | from elit.metrics.chunking.chunking_f1 import ChunkingF1 9 | from elit.utils.span_util import bmes_to_words 10 | from elit.common.util import merge_locals_kwargs 11 | 12 | 13 | class RNNTokenizer(RNNTagger): 14 | 15 | def predict(self, sentence: Any, batch_size: int = None, **kwargs): 16 | flat = isinstance(sentence, str) 17 | if flat: 18 | sentence = [sentence] 19 | for i, s in enumerate(sentence): 20 | sentence[i] = list(s) 21 | outputs = RNNTagger.predict(self, sentence, batch_size, **kwargs) 22 | if flat: 23 | return outputs[0] 24 | return outputs 25 | 26 | def predict_data(self, data, batch_size, **kwargs): 27 | tags = RNNTagger.predict_data(self, data, batch_size, **kwargs) 28 | words = [bmes_to_words(c, t) for c, t in zip(data, tags)] 29 | return words 30 | 31 | def build_dataset(self, data, transform=None): 32 | dataset = ChunkingDataset(data) 33 | if 'transform' in self.config: 34 | dataset.append_transform(self.config.transform) 35 | if transform: 36 | dataset.append_transform(transform) 37 | return dataset 38 | 39 | def build_metric(self, **kwargs): 40 | return ChunkingF1() 41 | 42 | def update_metrics(self, metric, logits, y, mask, batch): 43 | pred = self.decode_output(logits, mask, batch) 44 | pred = self._id_to_tags(pred) 45 | gold = batch['tag'] 46 | metric(pred, gold) 47 | 48 | def fit(self, trn_data, dev_data, save_dir, batch_size=50, epochs=100, embed=100, rnn_input=None, rnn_hidden=256, 49 | drop=0.5, lr=0.001, patience=10, crf=True, optimizer='adam', token_key='char', tagging_scheme=None, 50 | anneal_factor: float = 0.5, anneal_patience=2, devices=None, logger=None, 51 | verbose=True, transform: Callable = None, **kwargs): 52 | return super().fit(**merge_locals_kwargs(locals(), kwargs)) 53 | 54 | 55 | -------------------------------------------------------------------------------- /elit/components/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-11 02:48 -------------------------------------------------------------------------------- /elit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-06-13 18:15 4 | -------------------------------------------------------------------------------- /elit/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 11:49 -------------------------------------------------------------------------------- /elit/datasets/classification/sentiment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: 
hankcs 3 | # Date: 2019-12-30 21:03 4 | _ERNIE_TASK_DATA = 'https://ernie.bj.bcebos.com/task_data_zh.tgz#' 5 | 6 | CHNSENTICORP_ERNIE_TRAIN = _ERNIE_TASK_DATA + 'chnsenticorp/train.tsv' 7 | CHNSENTICORP_ERNIE_DEV = _ERNIE_TASK_DATA + 'chnsenticorp/dev.tsv' 8 | CHNSENTICORP_ERNIE_TEST = _ERNIE_TASK_DATA + 'chnsenticorp/test.tsv' 9 | -------------------------------------------------------------------------------- /elit/datasets/coref/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-04 13:39 -------------------------------------------------------------------------------- /elit/datasets/cws/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:40 -------------------------------------------------------------------------------- /elit/datasets/cws/chunking_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 18:50 4 | from typing import Union, List, Callable 5 | 6 | from elit.common.dataset import TransformableDataset 7 | from elit.utils.io_util import get_resource 8 | from elit.utils.span_util import bmes_of 9 | from elit.utils.string_util import ispunct 10 | 11 | 12 | class ChunkingDataset(TransformableDataset): 13 | 14 | def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None, 15 | generate_idx=None, max_seq_len=None, sent_delimiter=None) -> None: 16 | if not sent_delimiter: 17 | sent_delimiter = lambda x: ispunct(x) 18 | elif isinstance(sent_delimiter, str): 19 | sent_delimiter = set(list(sent_delimiter)) 20 | sent_delimiter = lambda x: x in sent_delimiter 21 | self.sent_delimiter = sent_delimiter 22 | self.max_seq_len = max_seq_len 23 | super().__init__(data, transform, cache, generate_idx) 24 | 25 | def load_file(self, filepath): 26 | max_seq_len = self.max_seq_len 27 | delimiter = self.sent_delimiter 28 | for chars, tags in self._generate_chars_tags(filepath, delimiter, max_seq_len): 29 | yield {'char': chars, 'tag': tags} 30 | 31 | @staticmethod 32 | def _generate_chars_tags(filepath, delimiter, max_seq_len): 33 | filepath = get_resource(filepath) 34 | with open(filepath, encoding='utf8') as src: 35 | for text in src: 36 | chars, tags = bmes_of(text, True) 37 | if max_seq_len and delimiter and len(chars) > max_seq_len: 38 | short_chars, short_tags = [], [] 39 | for idx, (char, tag) in enumerate(zip(chars, tags)): 40 | short_chars.append(char) 41 | short_tags.append(tag) 42 | if len(short_chars) >= max_seq_len and delimiter(char): 43 | yield short_chars, short_tags 44 | short_chars, short_tags = [], [] 45 | if short_chars: 46 | yield short_chars, short_tags 47 | else: 48 | yield chars, tags 49 | -------------------------------------------------------------------------------- /elit/datasets/cws/ctb6.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:19 4 | 5 | _CTB6_CWS_HOME = 'http://file.hankcs.com/corpus/ctb6_cws.zip' 6 | 7 | CTB6_CWS_TRAIN = _CTB6_CWS_HOME + '#train.txt' 8 | '''CTB6 training set.''' 9 | CTB6_CWS_DEV = _CTB6_CWS_HOME + '#dev.txt' 10 | '''CTB6 dev set.''' 11 | CTB6_CWS_TEST = _CTB6_CWS_HOME + '#test.txt' 12 | '''CTB6 test set.''' 13 | 
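The splitting rule in ChunkingDataset._generate_chars_tags above is easy to miss: once a BMES-tagged character sequence reaches max_seq_len, it is only cut at the next delimiter character (punctuation by default), never at an arbitrary position. Below is a self-contained sketch of that loop with toy data; split_tagged, the sample sentence and the delimiter set are illustrative only and do not come from the repository.

def split_tagged(chars, tags, max_seq_len, is_delimiter):
    # Mirror of the accumulation loop in _generate_chars_tags: collect characters and their
    # BMES tags, and emit a chunk only after max_seq_len is reached *and* the current
    # character is a delimiter, so a word is never split across two chunks.
    short_chars, short_tags = [], []
    for char, tag in zip(chars, tags):
        short_chars.append(char)
        short_tags.append(tag)
        if len(short_chars) >= max_seq_len and is_delimiter(char):
            yield short_chars, short_tags
            short_chars, short_tags = [], []
    if short_chars:  # flush the trailing piece
        yield short_chars, short_tags


chars = list('商品和服务,价格很便宜。')
tags = ['B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'S']
for piece in split_tagged(chars, tags, max_seq_len=6, is_delimiter=lambda c: c in ',。'):
    print(piece)
# -> two pieces, each ending at a punctuation mark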
-------------------------------------------------------------------------------- /elit/datasets/cws/multi_criteria_cws/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-11 20:35 4 | 5 | _HOME = 'https://github.com/hankcs/multi-criteria-cws/archive/naive-mix.zip#data/raw/' 6 | 7 | CNC_TRAIN_ALL = _HOME + 'cnc/train-all.txt' 8 | CNC_TRAIN = _HOME + 'cnc/train.txt' 9 | CNC_DEV = _HOME + 'cnc/dev.txt' 10 | CNC_TEST = _HOME + 'cnc/test.txt' 11 | 12 | CTB_TRAIN_ALL = _HOME + 'ctb/train-all.txt' 13 | CTB_TRAIN = _HOME + 'ctb/train.txt' 14 | CTB_DEV = _HOME + 'ctb/dev.txt' 15 | CTB_TEST = _HOME + 'ctb/test.txt' 16 | 17 | SXU_TRAIN_ALL = _HOME + 'sxu/train-all.txt' 18 | SXU_TRAIN = _HOME + 'sxu/train.txt' 19 | SXU_DEV = _HOME + 'sxu/dev.txt' 20 | SXU_TEST = _HOME + 'sxu/test.txt' 21 | 22 | UDC_TRAIN_ALL = _HOME + 'udc/train-all.txt' 23 | UDC_TRAIN = _HOME + 'udc/train.txt' 24 | UDC_DEV = _HOME + 'udc/dev.txt' 25 | UDC_TEST = _HOME + 'udc/test.txt' 26 | 27 | WTB_TRAIN_ALL = _HOME + 'wtb/train-all.txt' 28 | WTB_TRAIN = _HOME + 'wtb/train.txt' 29 | WTB_DEV = _HOME + 'wtb/dev.txt' 30 | WTB_TEST = _HOME + 'wtb/test.txt' 31 | 32 | ZX_TRAIN_ALL = _HOME + 'zx/train-all.txt' 33 | ZX_TRAIN = _HOME + 'zx/train.txt' 34 | ZX_DEV = _HOME + 'zx/dev.txt' 35 | ZX_TEST = _HOME + 'zx/test.txt' 36 | -------------------------------------------------------------------------------- /elit/datasets/cws/sighan2005/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | import os 5 | 6 | from elit.utils.io_util import get_resource, split_file 7 | from elit.utils.log_util import logger 8 | 9 | SIGHAN2005 = 'http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip' 10 | 11 | 12 | def make(train): 13 | root = get_resource(SIGHAN2005) 14 | train = os.path.join(root, train.split('#')[-1]) 15 | if not os.path.isfile(train): 16 | full = train.replace('_90.txt', '.utf8') 17 | logger.info(f'Splitting {full} into training set and valid set with 9:1 proportion') 18 | valid = train.replace('90.txt', '10.txt') 19 | split_file(full, train=0.9, dev=0.1, test=0, names={'train': train, 'dev': valid}) 20 | assert os.path.isfile(train), f'Failed to make {train}' 21 | assert os.path.isfile(valid), f'Failed to make {valid}' 22 | logger.info(f'Successfully made {train} {valid}') 23 | -------------------------------------------------------------------------------- /elit/datasets/cws/sighan2005/as_.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.cws.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_AS_DICT = SIGHAN2005 + "#" + "gold/as_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_AS_TRAIN_ALL = SIGHAN2005 + "#" + "training/as_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_AS_TRAIN = SIGHAN2005 + "#" + "training/as_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_AS_DEV = SIGHAN2005 + "#" + "training/as_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_AS_TEST_INPUT = SIGHAN2005 + "#" + "testing/as_testing.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_AS_TEST = SIGHAN2005 + "#" + "gold/as_testing_gold.utf8" 17 | '''Test set.''' 18 | 19 
| make(SIGHAN2005_AS_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/cws/sighan2005/cityu.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.cws.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_CITYU_DICT = SIGHAN2005 + "#" + "gold/cityu_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_CITYU_TRAIN_ALL = SIGHAN2005 + "#" + "training/cityu_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_CITYU_TRAIN = SIGHAN2005 + "#" + "training/cityu_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_CITYU_DEV = SIGHAN2005 + "#" + "training/cityu_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_CITYU_TEST_INPUT = SIGHAN2005 + "#" + "testing/cityu_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_CITYU_TEST = SIGHAN2005 + "#" + "gold/cityu_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_CITYU_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/cws/sighan2005/msr.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.cws.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_MSR_DICT = SIGHAN2005 + "#" + "gold/msr_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_MSR_TRAIN_ALL = SIGHAN2005 + "#" + "training/msr_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_MSR_TRAIN = SIGHAN2005 + "#" + "training/msr_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_MSR_DEV = SIGHAN2005 + "#" + "training/msr_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_MSR_TEST_INPUT = SIGHAN2005 + "#" + "testing/msr_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_MSR_TEST = SIGHAN2005 + "#" + "gold/msr_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_MSR_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/cws/sighan2005/pku.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.cws.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_PKU_DICT = SIGHAN2005 + "#" + "gold/pku_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_PKU_TRAIN_ALL = SIGHAN2005 + "#" + "training/pku_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_PKU_TRAIN = SIGHAN2005 + "#" + "training/pku_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_PKU_DEV = SIGHAN2005 + "#" + "training/pku_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_PKU_TEST_INPUT = SIGHAN2005 + "#" + "testing/pku_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_PKU_TEST = SIGHAN2005 + "#" + "gold/pku_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_PKU_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/eos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 
-*- 2 | # Author: hankcs 3 | # Date: 2020-07-26 18:11 -------------------------------------------------------------------------------- /elit/datasets/eos/nn_eos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-24 22:51 4 | _SETIMES2_EN_HR_SENTENCES_HOME = 'https://schweter.eu/cloud/nn_eos/SETIMES2.en-hr.sentences.tar.xz' 5 | SETIMES2_EN_HR_HR_SENTENCES_TRAIN = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.train' 6 | '''Training set of SETimes corpus.''' 7 | SETIMES2_EN_HR_HR_SENTENCES_DEV = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.dev' 8 | '''Dev set of SETimes corpus.''' 9 | SETIMES2_EN_HR_HR_SENTENCES_TEST = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.test' 10 | '''Test set of SETimes corpus.''' 11 | _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME = 'http://schweter.eu/cloud/nn_eos/europarl-v7.de-en.en.sentences.tar.xz' 12 | EUROPARL_V7_DE_EN_EN_SENTENCES_TRAIN = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.train' 13 | '''Training set of Europarl corpus (:cite:`koehn2005europarl`).''' 14 | EUROPARL_V7_DE_EN_EN_SENTENCES_DEV = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.dev' 15 | '''Dev set of Europarl corpus (:cite:`koehn2005europarl`).''' 16 | EUROPARL_V7_DE_EN_EN_SENTENCES_TEST = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.test' 17 | '''Test set of Europarl corpus (:cite:`koehn2005europarl`).''' 18 | -------------------------------------------------------------------------------- /elit/datasets/glue.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 11:47 4 | from elit.common.dataset import TableDataset 5 | 6 | STANFORD_SENTIMENT_TREEBANK_2_TRAIN = 'http://file.hankcs.com/corpus/SST2.zip#train.tsv' 7 | STANFORD_SENTIMENT_TREEBANK_2_DEV = 'http://file.hankcs.com/corpus/SST2.zip#dev.tsv' 8 | STANFORD_SENTIMENT_TREEBANK_2_TEST = 'http://file.hankcs.com/corpus/SST2.zip#test.tsv' 9 | 10 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TRAIN = 'http://file.hankcs.com/corpus/mrpc.zip#train.tsv' 11 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV = 'http://file.hankcs.com/corpus/mrpc.zip#dev.tsv' 12 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TEST = 'http://file.hankcs.com/corpus/mrpc.zip#test.tsv' 13 | 14 | 15 | class SST2Dataset(TableDataset): 16 | pass 17 | 18 | 19 | def main(): 20 | dataset = SST2Dataset(STANFORD_SENTIMENT_TREEBANK_2_TEST) 21 | print(dataset) 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /elit/datasets/lm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-05 21:41 4 | 5 | _PTB_HOME = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz#' 6 | PTB_TOKEN_TRAIN = _PTB_HOME + 'data/ptb.train.txt' 7 | PTB_TOKEN_DEV = _PTB_HOME + 'data/ptb.valid.txt' 8 | PTB_TOKEN_TEST = _PTB_HOME + 'data/ptb.test.txt' 9 | 10 | PTB_CHAR_TRAIN = _PTB_HOME + 'data/ptb.char.train.txt' 11 | PTB_CHAR_DEV = _PTB_HOME + 'data/ptb.char.valid.txt' 12 | PTB_CHAR_TEST = _PTB_HOME + 'data/ptb.char.test.txt' 13 | -------------------------------------------------------------------------------- /elit/datasets/ner/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-06 15:32 -------------------------------------------------------------------------------- /elit/datasets/ner/conll03.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-06 15:31 4 | 5 | 6 | CONLL03_EN_TRAIN = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.train.tsv' 7 | '''Training set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 8 | CONLL03_EN_DEV = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.dev.tsv' 9 | '''Dev set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 10 | CONLL03_EN_TEST = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.test.tsv' 11 | '''Test set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 12 | -------------------------------------------------------------------------------- /elit/datasets/ner/msra.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 23:13 4 | 5 | _MSRA_NER_HOME = 'http://file.hankcs.com/corpus/msra_ner.zip' 6 | _MSRA_NER_TOKEN_LEVEL_HOME = 'http://file.hankcs.com/corpus/msra_ner_token_level.zip' 7 | 8 | MSRA_NER_CHAR_LEVEL_TRAIN = f'{_MSRA_NER_HOME}#train.tsv' 9 | '''Training set of MSRA (:cite:`levow-2006-third`) in character level.''' 10 | MSRA_NER_CHAR_LEVEL_DEV = f'{_MSRA_NER_HOME}#dev.tsv' 11 | '''Dev set of MSRA (:cite:`levow-2006-third`) in character level.''' 12 | MSRA_NER_CHAR_LEVEL_TEST = f'{_MSRA_NER_HOME}#test.tsv' 13 | '''Test set of MSRA (:cite:`levow-2006-third`) in character level.''' 14 | 15 | MSRA_NER_TOKEN_LEVEL_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.tsv' 16 | '''Training set of MSRA (:cite:`levow-2006-third`) in token level.''' 17 | MSRA_NER_TOKEN_LEVEL_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.tsv' 18 | '''Dev set of MSRA (:cite:`levow-2006-third`) in token level.''' 19 | MSRA_NER_TOKEN_LEVEL_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.tsv' 20 | '''Test set of MSRA (:cite:`levow-2006-third`) in token level.''' 21 | 22 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.tsv' 23 | '''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 24 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.tsv' 25 | '''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 26 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.tsv' 27 | '''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 28 | 29 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.jsonlines' 30 | '''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.''' 31 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.jsonlines' 32 | '''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.''' 33 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.jsonlines' 34 | '''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token 
level and jsonlines format.''' 35 | -------------------------------------------------------------------------------- /elit/datasets/ner/resume.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-08 12:10 4 | from elit.common.dataset import TransformableDataset 5 | 6 | from elit.utils.io_util import get_resource, generate_words_tags_from_tsv 7 | 8 | _RESUME_NER_HOME = 'https://github.com/jiesutd/LatticeLSTM/archive/master.zip#' 9 | 10 | RESUME_NER_TRAIN = _RESUME_NER_HOME + 'ResumeNER/train.char.bmes' 11 | '''Training set of Resume in char level.''' 12 | RESUME_NER_DEV = _RESUME_NER_HOME + 'ResumeNER/dev.char.bmes' 13 | '''Dev set of Resume in char level.''' 14 | RESUME_NER_TEST = _RESUME_NER_HOME + 'ResumeNER/test.char.bmes' 15 | '''Test set of Resume in char level.''' 16 | 17 | -------------------------------------------------------------------------------- /elit/datasets/ner/weibo.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 23:33 4 | from elit.common.dataset import TransformableDataset 5 | 6 | from elit.utils.io_util import get_resource, generate_words_tags_from_tsv 7 | 8 | _WEIBO_NER_HOME = 'https://github.com/hltcoe/golden-horse/archive/master.zip#data/' 9 | 10 | WEIBO_NER_TRAIN = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.train' 11 | '''Training set of Weibo in char level.''' 12 | WEIBO_NER_DEV = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.dev' 13 | '''Dev set of Weibo in char level.''' 14 | WEIBO_NER_TEST = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.test' 15 | '''Test set of Weibo in char level.''' 16 | -------------------------------------------------------------------------------- /elit/datasets/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 00:51 4 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb5.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 18:44 4 | from elit.common.constant import HANLP_URL 5 | 6 | _CTB_HOME = HANLP_URL + 'embeddings/SUDA-LA-CIP_20200109_021624.zip#' 7 | 8 | _CTB5_DEP_HOME = _CTB_HOME + 'BPNN/data/ctb5/' 9 | 10 | CTB5_DEP_TRAIN = _CTB5_DEP_HOME + 'train.conll' 11 | '''Training set for ctb5 dependency parsing.''' 12 | CTB5_DEP_DEV = _CTB5_DEP_HOME + 'dev.conll' 13 | '''Dev set for ctb5 dependency parsing.''' 14 | CTB5_DEP_TEST = _CTB5_DEP_HOME + 'test.conll' 15 | '''Test set for ctb5 dependency parsing.''' 16 | 17 | CIP_W2V_100_CN = _CTB_HOME + 'BPNN/data/embed.txt' 18 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb7.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 18:44 4 | from elit.datasets.parsing.ctb5 import _CTB_HOME 5 | 6 | _CTB7_HOME = _CTB_HOME + 'BPNN/data/ctb7/' 7 | 8 | CTB7_DEP_TRAIN = _CTB7_HOME + 'train.conll' 9 | '''Training set for ctb7 dependency parsing.''' 10 | CTB7_DEP_DEV = _CTB7_HOME + 'dev.conll' 11 | '''Dev set for ctb7 dependency parsing.''' 12 | CTB7_DEP_TEST = _CTB7_HOME + 'test.conll' 13 | '''Test set for ctb7 dependency parsing.''' 14 | 
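Usage sketch, a minimal illustration rather than repository code: the split constants defined above are plain URL#member strings. The snippet below assumes only get_resource from elit.utils.io_util, which the sighan2005 module above already imports, to materialize such an identifier into a local file.

from elit.datasets.parsing.ctb5 import CTB5_DEP_TRAIN
from elit.utils.io_util import get_resource

# get_resource() downloads and caches the archive, then returns the local path
# of the member named after '#'.
train_conll = get_resource(CTB5_DEP_TRAIN)
with open(train_conll, encoding='utf-8') as src:
    sentence = []
    for line in src:
        line = line.rstrip('\n')
        if not line:
            break  # a blank line terminates the first CoNLL sentence
        sentence.append(line.split('\t'))
print(f'First sentence has {len(sentence)} tokens')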
-------------------------------------------------------------------------------- /elit/datasets/parsing/ctb8.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-10-14 20:54 4 | 5 | from elit.datasets.parsing._ctb_utils import make_ctb 6 | 7 | _CTB8_HOME = 'https://wakespace.lib.wfu.edu/bitstream/handle/10339/39379/LDC2013T21.tgz#data/' 8 | 9 | CTB8_CWS_TRAIN = _CTB8_HOME + 'tasks/cws/train.txt' 10 | '''Training set for ctb8 Chinese word segmentation.''' 11 | CTB8_CWS_DEV = _CTB8_HOME + 'tasks/cws/dev.txt' 12 | '''Dev set for ctb8 Chinese word segmentation.''' 13 | CTB8_CWS_TEST = _CTB8_HOME + 'tasks/cws/test.txt' 14 | '''Test set for ctb8 Chinese word segmentation.''' 15 | 16 | CTB8_POS_TRAIN = _CTB8_HOME + 'tasks/pos/train.tsv' 17 | '''Training set for ctb8 PoS tagging.''' 18 | CTB8_POS_DEV = _CTB8_HOME + 'tasks/pos/dev.tsv' 19 | '''Dev set for ctb8 PoS tagging.''' 20 | CTB8_POS_TEST = _CTB8_HOME + 'tasks/pos/test.tsv' 21 | '''Test set for ctb8 PoS tagging.''' 22 | 23 | CTB8_BRACKET_LINE_TRAIN = _CTB8_HOME + 'tasks/par/train.txt' 24 | '''Training set for ctb8 constituency parsing with empty categories.''' 25 | CTB8_BRACKET_LINE_DEV = _CTB8_HOME + 'tasks/par/dev.txt' 26 | '''Dev set for ctb8 constituency parsing with empty categories.''' 27 | CTB8_BRACKET_LINE_TEST = _CTB8_HOME + 'tasks/par/test.txt' 28 | '''Test set for ctb8 constituency parsing with empty categories.''' 29 | 30 | CTB8_BRACKET_LINE_NOEC_TRAIN = _CTB8_HOME + 'tasks/par/train.noempty.txt' 31 | '''Training set for ctb8 constituency parsing without empty categories.''' 32 | CTB8_BRACKET_LINE_NOEC_DEV = _CTB8_HOME + 'tasks/par/dev.noempty.txt' 33 | '''Dev set for ctb8 constituency parsing without empty categories.''' 34 | CTB8_BRACKET_LINE_NOEC_TEST = _CTB8_HOME + 'tasks/par/test.noempty.txt' 35 | '''Test set for ctb8 constituency parsing without empty categories.''' 36 | 37 | CTB8_SD330_TRAIN = _CTB8_HOME + 'tasks/dep/train.conllx' 38 | '''Training set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 39 | CTB8_SD330_DEV = _CTB8_HOME + 'tasks/dep/dev.conllx' 40 | '''Dev set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 41 | CTB8_SD330_TEST = _CTB8_HOME + 'tasks/dep/test.conllx' 42 | '''Test set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 43 | 44 | make_ctb(_CTB8_HOME) 45 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ptb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-02-17 15:46 4 | 5 | _PTB_HOME = 'https://github.com/KhalilMrini/LAL-Parser/archive/master.zip#data/' 6 | 7 | PTB_TRAIN = _PTB_HOME + '02-21.10way.clean' 8 | '''Training set of PTB without empty categories. PoS tags are automatically predicted using 10-fold 9 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 10 | PTB_DEV = _PTB_HOME + '22.auto.clean' 11 | '''Dev set of PTB without empty categories. PoS tags are automatically predicted using 10-fold 12 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 13 | PTB_TEST = _PTB_HOME + '23.auto.clean' 14 | '''Test set of PTB without empty categories. PoS tags are automatically predicted using 10-fold 15 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 16 | 17 | PTB_SD330_TRAIN = _PTB_HOME + 'ptb_train_3.3.0.sd.clean' 18 | '''Training set of PTB in Stanford Dependencies 3.3.0 format. 
PoS tags are automatically predicted using 10-fold 19 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 20 | PTB_SD330_DEV = _PTB_HOME + 'ptb_dev_3.3.0.sd.clean' 21 | '''Dev set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold 22 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 23 | PTB_SD330_TEST = _PTB_HOME + 'ptb_test_3.3.0.sd.clean' 24 | '''Test set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold 25 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 26 | 27 | PTB_TOKEN_MAPPING = { 28 | "-LRB-": "(", 29 | "-RRB-": ")", 30 | "-LCB-": "{", 31 | "-RCB-": "}", 32 | "-LSB-": "[", 33 | "-RSB-": "]", 34 | "``": '"', 35 | "''": '"', 36 | "`": "'", 37 | '«': '"', 38 | '»': '"', 39 | '‘': "'", 40 | '’': "'", 41 | '“': '"', 42 | '”': '"', 43 | '„': '"', 44 | '‹': "'", 45 | '›': "'", 46 | "\u2013": "--", # en dash 47 | "\u2014": "--", # em dash 48 | } 49 | -------------------------------------------------------------------------------- /elit/datasets/parsing/semeval15.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-28 14:40 4 | # from elit.datasets.parsing.conll_dataset import CoNLLParsingDataset 5 | # 6 | # 7 | # class SemEval15Dataset(CoNLLParsingDataset): 8 | # def load_file(self, filepath: str): 9 | # pass 10 | import warnings 11 | 12 | from elit.common.constant import ROOT, PAD 13 | from elit.common.conll import CoNLLSentence 14 | 15 | 16 | def unpack_deps_to_head_deprel(sample: dict, pad_rel=None, arc_key='arc', rel_key='rel'): 17 | if 'DEPS' in sample: 18 | deps = ['_'] + sample['DEPS'] 19 | sample[arc_key] = arc = [] 20 | sample[rel_key] = rel = [] 21 | for each in deps: 22 | arc_per_token = [False] * len(deps) 23 | rel_per_token = [None] * len(deps) 24 | if each != '_': 25 | for ar in each.split('|'): 26 | a, r = ar.split(':') 27 | a = int(a) 28 | arc_per_token[a] = True 29 | rel_per_token[a] = r 30 | if not pad_rel: 31 | pad_rel = r 32 | arc.append(arc_per_token) 33 | rel.append(rel_per_token) 34 | if not pad_rel: 35 | pad_rel = PAD 36 | for i in range(len(rel)): 37 | rel[i] = [r if r else pad_rel for r in rel[i]] 38 | return sample 39 | 40 | 41 | def append_bos_to_form_pos(sample, pos_key='CPOS'): 42 | sample['token'] = [ROOT] + sample['FORM'] 43 | if pos_key in sample: 44 | sample['pos'] = [ROOT] + sample[pos_key] 45 | return sample 46 | 47 | 48 | def merge_head_deprel_with_2nd(sample: dict): 49 | if 'arc' in sample: 50 | arc_2nd = sample['arc_2nd'] 51 | rel_2nd = sample['rel_2nd'] 52 | for i, (arc, rel) in enumerate(zip(sample['arc'], sample['rel'])): 53 | if i: 54 | if arc_2nd[i][arc] and rel_2nd[i][arc] != rel: 55 | sample_str = CoNLLSentence.from_dict(sample, conllu=True).to_markdown() 56 | warnings.warn(f'The main dependency conflicts with 2nd dependency at ID={i}, ' \ 57 | 'which means joint mode might not be suitable. 
' \ 58 | f'The sample is\n{sample_str}') 59 | arc_2nd[i][arc] = True 60 | rel_2nd[i][arc] = rel 61 | return sample 62 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-07 21:45 4 | import os 5 | import shutil 6 | 7 | from elit.components.parsers.ud.udify_util import get_ud_treebank_files 8 | from elit.utils.io_util import get_resource 9 | from elit.utils.log_util import flash 10 | 11 | 12 | def concat_treebanks(home, version): 13 | ud_home = get_resource(home) 14 | treebanks = get_ud_treebank_files(ud_home) 15 | output_dir = os.path.abspath(os.path.join(ud_home, os.path.pardir, os.path.pardir, f'ud-multilingual-v{version}')) 16 | if os.path.isdir(output_dir): 17 | return output_dir 18 | os.makedirs(output_dir) 19 | train, dev, test = list(zip(*[treebanks[k] for k in treebanks])) 20 | 21 | for treebank, name in zip([train, dev, test], ["train.conllu", "dev.conllu", "test.conllu"]): 22 | flash(f'Concatenating {len(train)} treebanks into {name} [blink][yellow]...[/yellow][/blink]') 23 | with open(os.path.join(output_dir, name), 'w') as write: 24 | for t in treebank: 25 | if not t: 26 | continue 27 | with open(t, 'r') as read: 28 | shutil.copyfileobj(read, write) 29 | flash('') 30 | return output_dir 31 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud23m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from .ud23 import _UD_23_HOME 8 | 9 | _UD_23_MULTILINGUAL_HOME = concat_treebanks(_UD_23_HOME, '2.3') 10 | UD_23_MULTILINGUAL_TRAIN = os.path.join(_UD_23_MULTILINGUAL_HOME, 'train.conllu') 11 | UD_23_MULTILINGUAL_DEV = os.path.join(_UD_23_MULTILINGUAL_HOME, 'dev.conllu') 12 | UD_23_MULTILINGUAL_TEST = os.path.join(_UD_23_MULTILINGUAL_HOME, 'test.conllu') 13 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud27m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from elit.datasets.parsing.ud.ud27 import _UD_27_HOME 8 | 9 | _UD_27_MULTILINGUAL_HOME = concat_treebanks(_UD_27_HOME, '2.7') 10 | UD_27_MULTILINGUAL_TRAIN = os.path.join(_UD_27_MULTILINGUAL_HOME, 'train.conllu') 11 | "Training set of multilingual UD_27 obtained by concatenating all training sets." 12 | UD_27_MULTILINGUAL_DEV = os.path.join(_UD_27_MULTILINGUAL_HOME, 'dev.conllu') 13 | "Dev set of multilingual UD_27 obtained by concatenating all dev sets." 14 | UD_27_MULTILINGUAL_TEST = os.path.join(_UD_27_MULTILINGUAL_HOME, 'test.conllu') 15 | "Test set of multilingual UD_27 obtained by concatenating all test sets." 
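Usage sketch, a minimal illustration rather than repository code: unlike the URL#member constants elsewhere in this tree, the multilingual UD splits are concrete local paths, because concat_treebanks above runs at import time and writes train.conllu, dev.conllu and test.conllu into a ud-multilingual-v2.7 directory derived from the extracted treebank location. The snippet assumes the UD 2.7 archive has already been fetched.

from elit.datasets.parsing.ud.ud27m import UD_27_MULTILINGUAL_TRAIN

# Each non-comment, non-blank line of a CoNLL-U file is one token row:
# ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC.
with open(UD_27_MULTILINGUAL_TRAIN, encoding='utf-8') as src:
    for line in src:
        line = line.rstrip('\n')
        if not line or line.startswith('#'):
            continue
        columns = line.split('\t')
        print(columns[1], columns[3])  # FORM and UPOS of the first token
        break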
16 | -------------------------------------------------------------------------------- /elit/datasets/pos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:50 -------------------------------------------------------------------------------- /elit/datasets/pos/ctb5.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:51 4 | 5 | _CTB5_POS_HOME = 'http://file.hankcs.com/corpus/ctb5.1-pos.zip' 6 | 7 | CTB5_POS_TRAIN = f'{_CTB5_POS_HOME}#train.tsv' 8 | '''PoS training set for CTB5.''' 9 | CTB5_POS_DEV = f'{_CTB5_POS_HOME}#dev.tsv' 10 | '''PoS dev set for CTB5.''' 11 | CTB5_POS_TEST = f'{_CTB5_POS_HOME}#test.tsv' 12 | '''PoS test set for CTB5.''' 13 | -------------------------------------------------------------------------------- /elit/datasets/qa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-03-20 19:17 -------------------------------------------------------------------------------- /elit/datasets/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-22 19:15 4 | 5 | 6 | -------------------------------------------------------------------------------- /elit/datasets/srl/ontonotes5/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-26 16:07 4 | ONTONOTES5_HOME = 'https://catalog.ldc.upenn.edu/LDC2013T19/ontonotes-release-5.0.tgz#data/' 5 | CONLL12_HOME = ONTONOTES5_HOME + '../conll-2012/' 6 | -------------------------------------------------------------------------------- /elit/datasets/srl/ontonotes5/english.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-25 18:48 4 | import glob 5 | import os 6 | 7 | from elit.utils.io_util import get_resource, merge_files 8 | 9 | _CONLL2012_EN_HOME = 'https://github.com/yuchenlin/OntoNotes-5.0-NER-BIO/archive/master.zip#conll-formatted-ontonotes-5.0/data' 10 | # These are v4 of OntoNotes, in .conll format 11 | CONLL2012_EN_TRAIN = _CONLL2012_EN_HOME + '/train/data/english/annotations' 12 | CONLL2012_EN_DEV = _CONLL2012_EN_HOME + '/development/data/english/annotations' 13 | CONLL2012_EN_TEST = _CONLL2012_EN_HOME + '/conll-2012-test/data/english/annotations' 14 | 15 | 16 | def conll_2012_en_combined(): 17 | home = get_resource(_CONLL2012_EN_HOME) 18 | outputs = ['train', 'dev', 'test'] 19 | for i in range(len(outputs)): 20 | outputs[i] = f'{home}/conll12_en/{outputs[i]}.conll' 21 | if all(os.path.isfile(x) for x in outputs): 22 | return outputs 23 | os.makedirs(os.path.dirname(outputs[0]), exist_ok=True) 24 | for in_path, out_path in zip([CONLL2012_EN_TRAIN, CONLL2012_EN_DEV, CONLL2012_EN_TEST], outputs): 25 | in_path = get_resource(in_path) 26 | files = sorted(glob.glob(f'{in_path}/**/*gold_conll', recursive=True)) 27 | merge_files(files, out_path) 28 | return outputs 29 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # 
Author: hankcs 3 | # Date: 2020-08-01 12:33 -------------------------------------------------------------------------------- /elit/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-10-26 00:50 -------------------------------------------------------------------------------- /elit/layers/context_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-05 19:34 4 | from alnlp.modules.pytorch_seq2seq_wrapper import LstmSeq2SeqEncoder 5 | from torch import nn 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | from elit.common.structure import ConfigTracker 9 | 10 | 11 | class _LSTMSeq2Seq(nn.Module): 12 | def __init__( 13 | self, 14 | input_size: int, 15 | hidden_size: int, 16 | num_layers: int = 1, 17 | bias: bool = True, 18 | dropout: float = 0.0, 19 | bidirectional: bool = False, 20 | ): 21 | """ 22 | Under construction, not ready for production 23 | :param input_size: 24 | :param hidden_size: 25 | :param num_layers: 26 | :param bias: 27 | :param dropout: 28 | :param bidirectional: 29 | """ 30 | self.rnn = nn.LSTM( 31 | input_size=input_size, 32 | hidden_size=hidden_size, 33 | num_layers=num_layers, 34 | bias=bias, 35 | batch_first=True, 36 | dropout=dropout, 37 | bidirectional=bidirectional, 38 | ) 39 | 40 | def forward(self, embed, lens, max_len): 41 | x = pack_padded_sequence(embed, lens, True, False) 42 | x, _ = self.rnn(x) 43 | x, _ = pad_packed_sequence(x, True, total_length=max_len) 44 | return x 45 | 46 | 47 | # We might update this to support yaml based configuration 48 | class LSTMContextualEncoder(LstmSeq2SeqEncoder, ConfigTracker): 49 | 50 | def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, dropout: float = 0.0, 51 | bidirectional: bool = False, stateful: bool = False): 52 | super().__init__(input_size, hidden_size, num_layers, bias, dropout, bidirectional, stateful) 53 | ConfigTracker.__init__(self, locals()) 54 | -------------------------------------------------------------------------------- /elit/layers/crf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-18 22:55 -------------------------------------------------------------------------------- /elit/layers/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-24 21:48 4 | -------------------------------------------------------------------------------- /elit/layers/embeddings/concat_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 17:08 4 | import tensorflow as tf 5 | 6 | from elit.utils.tf_util import hanlp_register, copy_mask 7 | 8 | 9 | @hanlp_register 10 | class ConcatEmbedding(tf.keras.layers.Layer): 11 | def __init__(self, *embeddings, trainable=True, name=None, dtype=None, dynamic=False, **kwargs): 12 | self.embeddings = [] 13 | for embed in embeddings: 14 | embed: tf.keras.layers.Layer = tf.keras.utils.deserialize_keras_object(embed) if isinstance(embed, 15 | dict) else embed 16 | self.embeddings.append(embed) 17 | if embed.trainable: 18 | trainable = True 19 | if embed.dynamic: 20 | 
dynamic = True 21 | if embed.supports_masking: 22 | self.supports_masking = True 23 | 24 | super().__init__(trainable, name, dtype, dynamic, **kwargs) 25 | 26 | def build(self, input_shape): 27 | for embed in self.embeddings: 28 | embed.build(input_shape) 29 | super().build(input_shape) 30 | 31 | def compute_mask(self, inputs, mask=None): 32 | for embed in self.embeddings: 33 | mask = embed.compute_mask(inputs, mask) 34 | if mask is not None: 35 | return mask 36 | return mask 37 | 38 | def call(self, inputs, **kwargs): 39 | embeds = [embed.call(inputs) for embed in self.embeddings] 40 | feature = tf.concat(embeds, axis=-1) 41 | 42 | for embed in embeds: 43 | mask = copy_mask(embed, feature) 44 | if mask is not None: 45 | break 46 | return feature 47 | 48 | def get_config(self): 49 | config = { 50 | 'embeddings': [embed.get_config() for embed in self.embeddings], 51 | } 52 | base_config = super(ConcatEmbedding, self).get_config() 53 | return dict(list(base_config.items()) + list(config.items())) 54 | 55 | def compute_output_shape(self, input_shape): 56 | dim = 0 57 | for embed in self.embeddings: 58 | dim += embed.compute_output_shape(input_shape)[-1] 59 | 60 | return input_shape + dim 61 | -------------------------------------------------------------------------------- /elit/layers/feed_forward.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-06 14:37 4 | from typing import Union, List 5 | 6 | from alnlp.modules import feedforward 7 | 8 | from elit.common.structure import ConfigTracker 9 | 10 | 11 | class FeedForward(feedforward.FeedForward, ConfigTracker): 12 | def __init__(self, input_dim: int, num_layers: int, hidden_dims: Union[int, List[int]], 13 | activations: Union[str, List[str]], dropout: Union[float, List[float]] = 0.0) -> None: 14 | super().__init__(input_dim, num_layers, hidden_dims, activations, dropout) 15 | ConfigTracker.__init__(self, locals()) 16 | -------------------------------------------------------------------------------- /elit/layers/pass_through_encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-08 17:56 4 | from alnlp.modules.pass_through_encoder import PassThroughEncoder as _PassThroughEncoder 5 | 6 | from elit.common.structure import ConfigTracker 7 | 8 | 9 | class PassThroughEncoder(_PassThroughEncoder, ConfigTracker): 10 | def __init__(self, input_dim: int) -> None: 11 | super().__init__(input_dim) 12 | ConfigTracker.__init__(self, locals()) 13 | -------------------------------------------------------------------------------- /elit/layers/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 15:17 4 | # mute transformers 5 | import logging 6 | 7 | logging.getLogger('transformers.file_utils').setLevel(logging.ERROR) 8 | logging.getLogger('transformers.filelock').setLevel(logging.ERROR) 9 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 10 | logging.getLogger('transformers.configuration_utils').setLevel(logging.ERROR) 11 | logging.getLogger('transformers.modeling_tf_utils').setLevel(logging.ERROR) 12 | logging.getLogger('transformers.modeling_utils').setLevel(logging.ERROR) 13 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 14 | 
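Usage sketch, a minimal illustration rather than repository code: FeedForward, PassThroughEncoder and LSTMContextualEncoder above share one pattern: delegate to the upstream alnlp module, then call ConfigTracker.__init__(self, locals()) so the constructor arguments are recorded for later re-instantiation. The config attribute named in the final comment is an assumption; it is not shown in this excerpt.

import torch

from elit.layers.feed_forward import FeedForward

ffn = FeedForward(input_dim=256, num_layers=2, hidden_dims=128,
                  activations='relu', dropout=0.1)
hidden = ffn(torch.randn(4, 10, 256))  # maps the last dimension 256 -> 128
# Assumption: ConfigTracker keeps the captured arguments on the instance, so the
# same layer could be rebuilt from a saved config, roughly FeedForward(**ffn.config).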
-------------------------------------------------------------------------------- /elit/layers/transformers/pt_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-09 11:25 4 | import os 5 | 6 | if os.environ.get('USE_TF', None) is None: 7 | os.environ["USE_TF"] = 'NO' # saves time loading transformers 8 | if os.environ.get('TOKENIZERS_PARALLELISM', None) is None: 9 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 10 | from transformers import BertTokenizer, BertConfig, PretrainedConfig, \ 11 | AutoConfig, AutoTokenizer, PreTrainedTokenizer, BertTokenizerFast, AlbertConfig, BertModel, AutoModel, \ 12 | PreTrainedModel, get_linear_schedule_with_warmup, AdamW, AutoModelForSequenceClassification, \ 13 | AutoModelForTokenClassification, optimization, BartModel 14 | 15 | 16 | class AutoModel_(AutoModel): 17 | @classmethod 18 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, training=True, **kwargs): 19 | if training: 20 | return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 21 | else: 22 | if isinstance(pretrained_model_name_or_path, str): 23 | return super().from_config(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)) 24 | else: 25 | assert not kwargs 26 | return super().from_config(pretrained_model_name_or_path) 27 | -------------------------------------------------------------------------------- /elit/layers/transformers/tf_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 21:57 4 | from bert import bert_models_google 5 | from transformers import BertTokenizer, BertConfig, PretrainedConfig, TFAutoModel, \ 6 | AutoConfig, AutoTokenizer, PreTrainedTokenizer, TFPreTrainedModel, TFAlbertModel, TFAutoModelWithLMHead, BertTokenizerFast, TFAlbertForMaskedLM, AlbertConfig, TFBertModel 7 | 8 | from elit.common.constant import HANLP_URL 9 | 10 | zh_albert_models_google = { 11 | 'albert_base_zh': HANLP_URL + 'embeddings/albert_base_zh.tar.gz', # Provide mirroring 12 | 'albert_large_zh': 'https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz', 13 | 'albert_xlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz', 14 | 'albert_xxlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz', 15 | } 16 | bert_models_google['chinese_L-12_H-768_A-12'] = HANLP_URL + 'embeddings/chinese_L-12_H-768_A-12.zip' -------------------------------------------------------------------------------- /elit/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 01:28 -------------------------------------------------------------------------------- /elit/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 -------------------------------------------------------------------------------- /elit/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-12 17:56 4 | from alnlp import metrics 5 | from elit.metrics.metric import Metric 6 | 7 | 8 | class CategoricalAccuracy(metrics.CategoricalAccuracy, Metric): 9 | @property 10 | def 
score(self): 11 | return self.get_metric() 12 | 13 | def __repr__(self) -> str: 14 | return f'Accuracy:{self.score:.2%}' 15 | 16 | 17 | class BooleanAccuracy(metrics.BooleanAccuracy, CategoricalAccuracy): 18 | pass 19 | -------------------------------------------------------------------------------- /elit/metrics/amr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-24 12:47 -------------------------------------------------------------------------------- /elit/metrics/chunking/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 03:49 -------------------------------------------------------------------------------- /elit/metrics/chunking/binary_chunking_f1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-02 14:27 4 | from collections import defaultdict 5 | from typing import List, Union 6 | 7 | import torch 8 | 9 | from elit.metrics.f1 import F1 10 | 11 | 12 | class BinaryChunkingF1(F1): 13 | def __call__(self, pred_tags: torch.LongTensor, gold_tags: torch.LongTensor, lens: List[int] = None): 14 | if lens is None: 15 | lens = [gold_tags.size(1)] * gold_tags.size(0) 16 | self.update(self.decode_spans(pred_tags, lens), self.decode_spans(gold_tags, lens)) 17 | 18 | def update(self, pred_tags, gold_tags): 19 | for pred, gold in zip(pred_tags, gold_tags): 20 | super().__call__(set(pred), set(gold)) 21 | 22 | @staticmethod 23 | def decode_spans(pred_tags: torch.LongTensor, lens: Union[List[int], torch.LongTensor]): 24 | if isinstance(lens, torch.Tensor): 25 | lens = lens.tolist() 26 | batch_pred = defaultdict(list) 27 | for batch, offset in pred_tags.nonzero(as_tuple=False).tolist(): 28 | batch_pred[batch].append(offset) 29 | batch_pred_spans = [[(0, l)] for l in lens] 30 | for batch, offsets in batch_pred.items(): 31 | l = lens[batch] 32 | batch_pred_spans[batch] = list(zip(offsets, offsets[1:] + [l])) 33 | return batch_pred_spans 34 | -------------------------------------------------------------------------------- /elit/metrics/chunking/bmes.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 4 | 5 | from elit.common.vocab_tf import VocabTF 6 | from elit.metrics.chunking.chunking_f1_tf import ChunkingF1_TF 7 | from elit.metrics.chunking.sequence_labeling import get_entities 8 | 9 | 10 | class BMES_F1_TF(ChunkingF1_TF): 11 | 12 | def __init__(self, tag_vocab: VocabTF, from_logits=True, suffix=False, name='f1', dtype=None, **kwargs): 13 | super().__init__(tag_vocab, from_logits, name, dtype, **kwargs) 14 | self.nb_correct = 0 15 | self.nb_pred = 0 16 | self.nb_true = 0 17 | self.suffix = suffix 18 | 19 | def update_tags(self, true_tags, pred_tags): 20 | for t, p in zip(true_tags, pred_tags): 21 | self.update_entities(get_entities(t, self.suffix), get_entities(p, self.suffix)) 22 | return self.result() 23 | 24 | def update_entities(self, true_entities, pred_entities): 25 | true_entities = set(true_entities) 26 | pred_entities = set(pred_entities) 27 | nb_correct = len(true_entities & pred_entities) 28 | nb_pred = len(pred_entities) 29 | nb_true = len(true_entities) 30 | self.nb_correct += nb_correct 31 | self.nb_pred += nb_pred 32 | self.nb_true += nb_true 33 | 34 | 
def result(self): 35 | nb_correct = self.nb_correct 36 | nb_pred = self.nb_pred 37 | nb_true = self.nb_true 38 | p = nb_correct / nb_pred if nb_pred > 0 else 0 39 | r = nb_correct / nb_true if nb_true > 0 else 0 40 | score = 2 * p * r / (p + r) if p + r > 0 else 0 41 | 42 | return score 43 | 44 | def reset_states(self): 45 | self.nb_correct = 0 46 | self.nb_pred = 0 47 | self.nb_true = 0 48 | -------------------------------------------------------------------------------- /elit/metrics/chunking/chunking_f1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-11 22:14 4 | from typing import List 5 | 6 | from elit.metrics.chunking.sequence_labeling import get_entities 7 | from elit.metrics.f1 import F1 8 | from elit.metrics.metric import Metric 9 | 10 | 11 | class ChunkingF1(F1): 12 | 13 | def __call__(self, pred_tags: List[List[str]], gold_tags: List[List[str]]): 14 | for p, g in zip(pred_tags, gold_tags): 15 | pred = set(get_entities(p)) 16 | gold = set(get_entities(g)) 17 | self.nb_pred += len(pred) 18 | self.nb_true += len(gold) 19 | self.nb_correct += len(pred & gold) 20 | -------------------------------------------------------------------------------- /elit/metrics/chunking/chunking_f1_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 23:09 4 | from abc import ABC, abstractmethod 5 | 6 | import tensorflow as tf 7 | 8 | from elit.common.vocab_tf import VocabTF 9 | 10 | 11 | class ChunkingF1_TF(tf.keras.metrics.Metric, ABC): 12 | 13 | def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs): 14 | super().__init__(name, dtype, dynamic=True, **kwargs) 15 | self.tag_vocab = tag_vocab 16 | self.from_logits = from_logits 17 | 18 | def update_the_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 19 | mask = y_pred._keras_mask if sample_weight is None else sample_weight 20 | if self.tag_vocab.pad_idx is not None and sample_weight is None: 21 | # in this case, the model doesn't compute mask but provide a masking index, it's ok to 22 | mask = y_true != self.tag_vocab.pad_idx 23 | assert mask is not None, 'ChunkingF1 requires masking, check your _keras_mask or compute_mask' 24 | if self.from_logits: 25 | y_pred = tf.argmax(y_pred, axis=-1) 26 | y_true = self.to_tags(y_true, mask) 27 | y_pred = self.to_tags(y_pred, mask) 28 | return self.update_tags(y_true, y_pred) 29 | 30 | def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 31 | return self.update_the_state(y_true, y_pred, sample_weight) 32 | 33 | def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 34 | return self.update_the_state(y_true, y_pred, sample_weight) 35 | 36 | def to_tags(self, y: tf.Tensor, sample_weight: tf.Tensor): 37 | batch = [] 38 | y = y.numpy() 39 | sample_weight = sample_weight.numpy() 40 | for sent, mask in zip(y, sample_weight): 41 | tags = [] 42 | for tag, m in zip(sent, mask): 43 | if not m: 44 | continue 45 | tag = int(tag) 46 | if self.tag_vocab.pad_idx is not None and tag == self.tag_vocab.pad_idx: 47 | # If model predicts , it will fail most metrics. 
So replace it with a valid one 48 | tag = 1 49 | tags.append(self.tag_vocab.get_token(tag)) 50 | batch.append(tags) 51 | return batch 52 | 53 | @abstractmethod 54 | def update_tags(self, true_tags, pred_tags): 55 | pass 56 | 57 | @abstractmethod 58 | def result(self): 59 | pass 60 | -------------------------------------------------------------------------------- /elit/metrics/chunking/cws_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author:hankcs 3 | # Date: 2018-06-02 22:53 4 | # "Introduction to Natural Language Processing", Section 2.9: accuracy evaluation 5 | # Companion book: http://nlp.hankcs.com/book.php 6 | # Discussion and Q&A: https://bbs.hankcs.com/ 7 | import re 8 | 9 | 10 | def to_region(segmentation: str) -> list: 11 | """Convert a space-delimited segmentation into a list of character regions 12 | 13 | Args: 14 | segmentation: a segmented sentence, e.g. 商品 和 服务 15 | segmentation: str: 16 | 17 | Returns: 18 | [(0, 2), (2, 3), (3, 5)] 19 | 20 | """ 21 | region = [] 22 | start = 0 23 | for word in re.compile("\\s+").split(segmentation.strip()): 24 | end = start + len(word) 25 | region.append((start, end)) 26 | start = end 27 | return region 28 | 29 | 30 | def evaluate(gold: str, pred: str, dic: dict = None) -> tuple: 31 | """Compute P, R and F1 for word segmentation 32 | 33 | Args: 34 | gold: path to the gold standard file, e.g. “商品 和 服务” 35 | pred: path to the predicted segmentation file, e.g. “商品 和服 务” 36 | dic: dictionary used for OOV/IV recall 37 | gold: str: 38 | pred: str: 39 | dic: dict: (Default value = None) 40 | 41 | Returns: 42 | (P, R, F1, OOV_R, IV_R) 43 | 44 | """ 45 | A_size, B_size, A_cap_B_size, OOV, IV, OOV_R, IV_R = 0, 0, 0, 0, 0, 0, 0 46 | with open(gold, encoding='utf-8') as gd, open(pred, encoding='utf-8') as pd: 47 | for g, p in zip(gd, pd): 48 | A, B = set(to_region(g)), set(to_region(p)) 49 | A_size += len(A) 50 | B_size += len(B) 51 | A_cap_B_size += len(A & B) 52 | text = re.sub("\\s+", "", g) 53 | if dic: 54 | for (start, end) in A: 55 | word = text[start: end] 56 | if word in dic: 57 | IV += 1 58 | else: 59 | OOV += 1 60 | 61 | for (start, end) in A & B: 62 | word = text[start: end] 63 | if word in dic: 64 | IV_R += 1 65 | else: 66 | OOV_R += 1 67 | p, r = safe_division(A_cap_B_size, B_size), safe_division(A_cap_B_size, A_size) 68 | return p, r, safe_division(2 * p * r, (p + r)), safe_division(OOV_R, OOV), safe_division(IV_R, IV) 69 | 70 | 71 | def build_dic_from_file(path): 72 | dic = set() 73 | with open(path, encoding='utf-8') as gd: 74 | for g in gd: 75 | for word in re.compile("\\s+").split(g.strip()): 76 | dic.add(word) 77 | return dic 78 | 79 | 80 | def safe_division(n, d): 81 | return n / d if d else float('nan') if n else 0. 
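Usage sketch, a minimal illustration rather than repository code; the gold, pred and train file names below are hypothetical placeholders for whitespace-segmented text files with one sentence per line.

from elit.metrics.chunking.cws_eval import build_dic_from_file, evaluate, to_region

print(to_region('商品 和 服务'))  # [(0, 2), (2, 3), (3, 5)]
# p, r, f1, oov_r, iv_r = evaluate('gold.txt', 'pred.txt', dic=build_dic_from_file('train.txt'))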
82 | -------------------------------------------------------------------------------- /elit/metrics/chunking/iobes_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 4 | 5 | from elit.common.vocab_tf import VocabTF 6 | from elit.metrics.chunking.conlleval import SpanF1 7 | from elit.metrics.chunking.chunking_f1_tf import ChunkingF1_TF 8 | 9 | 10 | class IOBES_F1_TF(ChunkingF1_TF): 11 | 12 | def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs): 13 | super().__init__(tag_vocab, from_logits, name, dtype, **kwargs) 14 | self.state = SpanF1() 15 | 16 | def update_tags(self, true_tags, pred_tags): 17 | # true_tags = list(itertools.chain.from_iterable(true_tags)) 18 | # pred_tags = list(itertools.chain.from_iterable(pred_tags)) 19 | # self.state.update_state(true_tags, pred_tags) 20 | for gold, pred in zip(true_tags, pred_tags): 21 | self.state.update_state(gold, pred) 22 | return self.result() 23 | 24 | def result(self): 25 | return self.state.result(full=False, verbose=False).fscore 26 | 27 | def reset_states(self): 28 | self.state.reset_state() 29 | -------------------------------------------------------------------------------- /elit/metrics/f1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-10 14:55 4 | from abc import ABC 5 | 6 | from elit.metrics.metric import Metric 7 | 8 | 9 | class F1(Metric, ABC): 10 | def __init__(self, nb_pred=0, nb_true=0, nb_correct=0) -> None: 11 | super().__init__() 12 | self.nb_correct = nb_correct 13 | self.nb_pred = nb_pred 14 | self.nb_true = nb_true 15 | 16 | def __repr__(self) -> str: 17 | p, r, f = self.prf 18 | return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}" 19 | 20 | @property 21 | def prf(self): 22 | nb_correct = self.nb_correct 23 | nb_pred = self.nb_pred 24 | nb_true = self.nb_true 25 | p = nb_correct / nb_pred if nb_pred > 0 else .0 26 | r = nb_correct / nb_true if nb_true > 0 else .0 27 | f = 2 * p * r / (p + r) if p + r > 0 else .0 28 | return p, r, f 29 | 30 | @property 31 | def score(self): 32 | return self.prf[-1] 33 | 34 | def reset(self): 35 | self.nb_correct = 0 36 | self.nb_pred = 0 37 | self.nb_true = 0 38 | 39 | def __call__(self, pred: set, gold: set): 40 | self.nb_correct += len(pred & gold) 41 | self.nb_pred += len(pred) 42 | self.nb_true += len(gold) 43 | 44 | 45 | class F1_(Metric): 46 | def __init__(self, p, r, f) -> None: 47 | super().__init__() 48 | self.f = f 49 | self.r = r 50 | self.p = p 51 | 52 | @property 53 | def score(self): 54 | return self.f 55 | 56 | def __call__(self, pred, gold): 57 | raise NotImplementedError() 58 | 59 | def reset(self): 60 | self.f = self.r = self.p = 0 61 | 62 | def __repr__(self) -> str: 63 | p, r, f = self.p, self.r, self.f 64 | return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}" 65 | -------------------------------------------------------------------------------- /elit/metrics/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 11:35 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class Metric(ABC): 8 | 9 | def __lt__(self, other): 10 | return self.score < other 11 | 12 | def __le__(self, other): 13 | return self.score <= other 14 | 15 | def __eq__(self, other): 16 | return self.score == other 17 | 18 | def __ge__(self, other): 19 | return 
self.score >= other 20 | 21 | def __gt__(self, other): 22 | return self.score > other 23 | 24 | def __ne__(self, other): 25 | return self.score != other 26 | 27 | @property 28 | @abstractmethod 29 | def score(self): 30 | pass 31 | 32 | @abstractmethod 33 | def __call__(self, pred, gold, mask=None): 34 | pass 35 | 36 | def __repr__(self) -> str: 37 | return f'{self.score}:.4f' 38 | 39 | def __float__(self): 40 | return self.score 41 | 42 | @abstractmethod 43 | def reset(self): 44 | pass 45 | -------------------------------------------------------------------------------- /elit/metrics/mtl.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-03 00:16 4 | from elit.metrics.metric import Metric 5 | 6 | 7 | class MetricDict(Metric, dict): 8 | _COLORS = ["magenta", "cyan", "green", "yellow"] 9 | 10 | @property 11 | def score(self): 12 | return sum(float(x) for x in self.values()) / len(self) 13 | 14 | def __call__(self, pred, gold): 15 | for metric in self.values(): 16 | metric(pred, gold) 17 | 18 | def reset(self): 19 | for metric in self.values(): 20 | metric.reset() 21 | 22 | def __repr__(self) -> str: 23 | return ' '.join(f'({k} {v})' for k, v in self.items()) 24 | 25 | def cstr(self, idx=None, level=0) -> str: 26 | if idx is None: 27 | idx = [0] 28 | prefix = '' 29 | for _, (k, v) in enumerate(self.items()): 30 | color = self._COLORS[idx[0] % len(self._COLORS)] 31 | idx[0] += 1 32 | child_is_dict = isinstance(v, MetricDict) 33 | _level = min(level, 2) 34 | # if level != 0 and not child_is_dict: 35 | # _level = 2 36 | lb = '{[(' 37 | rb = '}])' 38 | k = f'[bold][underline]{k}[/underline][/bold]' 39 | prefix += f'[{color}]{lb[_level]}{k} [/{color}]' 40 | if child_is_dict: 41 | prefix += v.cstr(idx, level + 1) 42 | else: 43 | prefix += f'[{color}]{v}[/{color}]' 44 | prefix += f'[{color}]{rb[_level]}[/{color}]' 45 | return prefix 46 | -------------------------------------------------------------------------------- /elit/metrics/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-27 00:48 -------------------------------------------------------------------------------- /elit/metrics/parsing/labeled_score.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-27 00:49 4 | 5 | import tensorflow as tf 6 | 7 | 8 | class LabeledScore(object): 9 | 10 | def __init__(self, eps=1e-5): 11 | super(LabeledScore, self).__init__() 12 | 13 | self.eps = eps 14 | self.total = 0.0 15 | self.correct_arcs = 0.0 16 | self.correct_rels = 0.0 17 | 18 | def __repr__(self): 19 | return f"UAS: {self.uas:6.2%} LAS: {self.las:6.2%}" 20 | 21 | def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask): 22 | arc_mask = (arc_preds == arc_golds)[mask] 23 | rel_mask = (rel_preds == rel_golds)[mask] & arc_mask 24 | 25 | self.total += len(arc_mask) 26 | self.correct_arcs += int(tf.math.count_nonzero(arc_mask)) 27 | self.correct_rels += int(tf.math.count_nonzero(rel_mask)) 28 | 29 | def __lt__(self, other): 30 | return self.score < other 31 | 32 | def __le__(self, other): 33 | return self.score <= other 34 | 35 | def __ge__(self, other): 36 | return self.score >= other 37 | 38 | def __gt__(self, other): 39 | return self.score > other 40 | 41 | @property 42 | def score(self): 43 | return self.las 44 | 45 | 
@property 46 | def uas(self): 47 | return self.correct_arcs / (self.total + self.eps) 48 | 49 | @property 50 | def las(self): 51 | return self.correct_rels / (self.total + self.eps) 52 | 53 | def reset_states(self): 54 | self.total = 0.0 55 | self.correct_arcs = 0.0 56 | self.correct_rels = 0.0 57 | 58 | def to_dict(self) -> dict: 59 | return {'UAS': self.uas, 'LAS': self.las} 60 | -------------------------------------------------------------------------------- /elit/metrics/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-16 18:44 -------------------------------------------------------------------------------- /elit/metrics/srl/srlconll.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-16 18:44 4 | import os 5 | 6 | from elit.utils.io_util import get_resource, get_exitcode_stdout_stderr, run_cmd 7 | 8 | 9 | def official_conll_05_evaluate(pred_path, gold_path): 10 | script_root = get_resource('http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz') 11 | lib_path = f'{script_root}/lib' 12 | if lib_path not in os.environ.get("PERL5LIB", ""): 13 | os.environ['PERL5LIB'] = f'{lib_path}:{os.environ.get("PERL5LIB", "")}' 14 | bin_path = f'{script_root}/bin' 15 | if bin_path not in os.environ.get('PATH', ''): 16 | os.environ['PATH'] = f'{bin_path}:{os.environ.get("PATH", "")}' 17 | eval_info_gold_pred = run_cmd(f'perl {script_root}/bin/srl-eval.pl {gold_path} {pred_path}') 18 | eval_info_pred_gold = run_cmd(f'perl {script_root}/bin/srl-eval.pl {pred_path} {gold_path}') 19 | conll_recall = float(eval_info_gold_pred.strip().split("\n")[6].strip().split()[5]) / 100 20 | conll_precision = float(eval_info_pred_gold.strip().split("\n")[6].strip().split()[5]) / 100 21 | if conll_recall + conll_precision > 0: 22 | conll_f1 = 2 * conll_recall * conll_precision / (conll_recall + conll_precision) 23 | else: 24 | conll_f1 = 0 25 | return conll_precision, conll_recall, conll_f1 26 | 27 | 28 | def run_perl(script, src, dst=None): 29 | os.environ['PERL5LIB'] = f'' 30 | exitcode, out, err = get_exitcode_stdout_stderr( 31 | f'perl -I{os.path.expanduser("~/.local/lib/perl5")} {script} {src}') 32 | if exitcode: 33 | # cpanm -l ~/.local namespace::autoclean 34 | # cpanm -l ~/.local Moose 35 | # cpanm -l ~/.local MooseX::SemiAffordanceAccessor module 36 | raise RuntimeError(err) 37 | with open(dst, 'w') as ofile: 38 | ofile.write(out) 39 | return dst 40 | -------------------------------------------------------------------------------- /elit/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-11 18:44 -------------------------------------------------------------------------------- /elit/pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 19:10 4 | from elit.pretrained import tok 5 | from elit.pretrained import dep 6 | from elit.pretrained import sdp 7 | from elit.pretrained import glove 8 | from elit.pretrained import pos 9 | from elit.pretrained import rnnlm 10 | from elit.pretrained import word2vec 11 | from elit.pretrained import ner 12 | from elit.pretrained import classifiers 13 | from elit.pretrained import fasttext 14 | from elit.pretrained import mtl 15 | 
from elit.pretrained import eos 16 | 17 | # Will be filled up during runtime 18 | ALL = {} 19 | -------------------------------------------------------------------------------- /elit/pretrained/classifiers.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-01 03:51 4 | from elit.common.constant import HANLP_URL 5 | 6 | CHNSENTICORP_BERT_BASE_ZH = HANLP_URL + 'classification/chnsenticorp_bert_base_20200104_164655.zip' 7 | SST2_BERT_BASE_EN = HANLP_URL + 'classification/sst2_bert_base_uncased_en_20200210_090240.zip' 8 | SST2_ALBERT_BASE_EN = HANLP_URL + 'classification/sst2_albert_base_20200122_205915.zip' 9 | EMPATHETIC_DIALOGUES_SITUATION_ALBERT_BASE_EN = HANLP_URL + 'classification/empathetic_dialogues_situation_albert_base_20200122_212250.zip' 10 | EMPATHETIC_DIALOGUES_SITUATION_ALBERT_LARGE_EN = HANLP_URL + 'classification/empathetic_dialogues_situation_albert_large_20200123_142724.zip' 11 | 12 | ALL = {} 13 | -------------------------------------------------------------------------------- /elit/pretrained/dep.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 02:55 4 | from elit.common.constant import HANLP_URL 5 | 6 | CTB5_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb5_20191229_025833.zip' 7 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB5.' 8 | CTB7_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb7_20200109_022431.zip' 9 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB7.' 10 | 11 | PTB_BIAFFINE_DEP_EN = HANLP_URL + 'dep/ptb_dep_biaffine_20200101_174624.zip' 12 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on PTB.' 13 | 14 | ALL = {} 15 | -------------------------------------------------------------------------------- /elit/pretrained/eos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-22 13:22 4 | from elit.common.constant import HANLP_URL 5 | 6 | UD_CTB_EOS_MUL = HANLP_URL + 'eos/eos_ud_ctb_mul_20201222_133543.zip' 7 | 'EOS model (:cite:`Schweter:Ahmed:2019`) trained on concatenated UD2.3 and CTB9.' 8 | 9 | # Will be filled up during runtime 10 | ALL = {} 11 | -------------------------------------------------------------------------------- /elit/pretrained/fasttext.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 18:57 4 | FASTTEXT_DEBUG_EMBEDDING_EN = 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext.debug.bin.zip' 5 | FASTTEXT_CC_300_EN = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz' 6 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Common Crawl.' 7 | FASTTEXT_WIKI_NYT_AMAZON_FRIENDS_200_EN \ 8 | = 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext-200-wikipedia-nytimes-amazon-friends-20191107.bin' 9 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on wikipedia, nytimes and friends.' 10 | 11 | FASTTEXT_WIKI_300_ZH = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh.zip#wiki.zh.bin' 12 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Chinese Wikipedia.' 
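# Note: the '#<member>' fragment appended to an identifier selects a file inside the downloaded
# archive. A minimal resolution sketch, reusing get_resource the same way other elit modules do
# (illustrative only, not part of this module):
#   from elit.utils.io_util import get_resource
#   local_bin = get_resource(FASTTEXT_WIKI_300_ZH)   # download + unpack, then the path to wiki.zh.bin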
13 | FASTTEXT_WIKI_300_ZH_CLASSICAL = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh_classical.zip#wiki.zh_classical.bin' 14 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on traditional Chinese wikipedia.' 15 | 16 | ALL = {} 17 | -------------------------------------------------------------------------------- /elit/pretrained/glove.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-27 20:42 4 | 5 | _GLOVE_6B_ROOT = 'http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip' 6 | 7 | GLOVE_6B_50D = _GLOVE_6B_ROOT + '#' + 'glove.6B.50d.txt' 8 | 'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 50d trained on 6B tokens.' 9 | GLOVE_6B_100D = _GLOVE_6B_ROOT + '#' + 'glove.6B.100d.txt' 10 | 'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 100d trained on 6B tokens.' 11 | GLOVE_6B_200D = _GLOVE_6B_ROOT + '#' + 'glove.6B.200d.txt' 12 | 'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 200d trained on 6B tokens.' 13 | GLOVE_6B_300D = _GLOVE_6B_ROOT + '#' + 'glove.6B.300d.txt' 14 | 'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 300d trained on 6B tokens.' 15 | 16 | GLOVE_840B_300D = 'http://nlp.stanford.edu/data/glove.840B.300d.zip' 17 | 'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 300d trained on 840B tokens.' 18 | 19 | ALL = {} 20 | -------------------------------------------------------------------------------- /elit/pretrained/mtl.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-22 13:16 4 | from elit.common.constant import HANLP_URL 5 | 6 | OPEN_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH = HANLP_URL + 'mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_small_20201223_203434.zip' 7 | "Electra small version of joint tok, pos, ner, srl, dep, sdp and con model trained on open-source Chinese corpus." 8 | OPEN_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH = HANLP_URL + 'mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_base_20201223_201906.zip' 9 | "Electra base version of joint tok, pos, ner, srl, dep, sdp and con model trained on open-source Chinese corpus." 10 | CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH = HANLP_URL + 'mtl/close_tok_pos_ner_srl_dep_sdp_con_electra_small_zh_20201222_130611.zip' 11 | "Electra small version of joint tok, pos, ner, srl, dep, sdp and con model trained on private Chinese corpus." 12 | CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH = HANLP_URL + 'mtl/close_tok_pos_ner_srl_dep_sdp_con_electra_base_20201226_221208.zip' 13 | "Electra base version of joint tok, pos, ner, srl, dep, sdp and con model trained on private Chinese corpus." 14 | # Will be filled up during runtime 15 | ALL = {} 16 | -------------------------------------------------------------------------------- /elit/pretrained/ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 20:07 4 | from elit.common.constant import HANLP_URL 5 | 6 | MSRA_NER_BERT_BASE_ZH = HANLP_URL + 'ner/ner_bert_base_msra_20200104_185735.zip' 7 | 'BERT model (:cite:`devlin-etal-2019-bert`) trained on MSRA with 3 entity types.' 
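# Note: the "3 entity types" of MSRA referred to above are NR (person), NS (location) and
# NT (organization); see the label glossary in elit/utils/lang/zh/localization.py further down.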
8 | MSRA_NER_ALBERT_BASE_ZH = HANLP_URL + 'ner/ner_albert_base_zh_msra_20200111_202919.zip' 9 | 'ALBERT model (:cite:`Lan2020ALBERT:`) trained on MSRA with 3 entity types.' 10 | CONLL03_NER_BERT_BASE_UNCASED_EN = HANLP_URL + 'ner/ner_conll03_bert_base_uncased_en_20200104_194352.zip' 11 | 'BERT model (:cite:`devlin-etal-2019-bert`) trained on CoNLL03.' 12 | 13 | ALL = {} 14 | -------------------------------------------------------------------------------- /elit/pretrained/pos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 01:57 4 | from elit.common.constant import HANLP_URL 5 | 6 | CTB5_POS_RNN = HANLP_URL + 'pos/ctb5_pos_rnn_20200113_235925.zip' 7 | 'An old school BiLSTM tagging model trained on CTB5.' 8 | CTB5_POS_RNN_FASTTEXT_ZH = HANLP_URL + 'pos/ctb5_pos_rnn_fasttext_20191230_202639.zip' 9 | 'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on CTB5.' 10 | CTB9_POS_ALBERT_BASE = HANLP_URL + 'pos/ctb9_albert_base_zh_epoch_20_20201011_090522.zip' 11 | 'ALBERT model (:cite:`Lan2020ALBERT:`) trained on CTB9.' 12 | 13 | PTB_POS_RNN_FASTTEXT_EN = HANLP_URL + 'pos/ptb_pos_rnn_fasttext_20200103_145337.zip' 14 | 'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on PTB.' 15 | 16 | ALL = {} -------------------------------------------------------------------------------- /elit/pretrained/rnnlm.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-19 03:47 4 | from elit.common.constant import HANLP_URL 5 | 6 | FLAIR_LM_FW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_fw_wmt11_en' 7 | FLAIR_LM_BW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_bw_wmt11_en' 8 | FLAIR_LM_WMT11_EN = HANLP_URL + 'lm/flair_lm_wmt11_en_20200601_205350.zip' 9 | 10 | ALL = {} 11 | -------------------------------------------------------------------------------- /elit/pretrained/sdp.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-31 23:54 4 | from elit.common.constant import HANLP_URL 5 | 6 | SEMEVAL16_NEWS_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-news-biaffine_20191231_235407.zip' 7 | 'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval16 news data.' 8 | SEMEVAL16_TEXT_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-text-biaffine_20200101_002257.zip' 9 | 'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval16 text data.' 10 | 11 | SEMEVAL15_PAS_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_pas_20200103_152405.zip' 12 | 'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 PAS data.' 13 | SEMEVAL15_PSD_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_psd_20200106_123009.zip' 14 | 'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 PSD data.' 15 | SEMEVAL15_DM_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_dm_20200106_122808.zip' 16 | 'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 DM data.' 
17 | 18 | ALL = {} 19 | -------------------------------------------------------------------------------- /elit/pretrained/tok.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 21:12 4 | from elit.common.constant import HANLP_URL 5 | 6 | SIGHAN2005_PKU_CONVSEG = HANLP_URL + 'tok/sighan2005-pku-convseg_20200110_153722.zip' 7 | 'Conv model (:cite:`wang-xu-2017-convolutional`) trained on sighan2005 pku dataset.' 8 | SIGHAN2005_MSR_CONVSEG = HANLP_URL + 'tok/convseg-msr-nocrf-noembed_20200110_153524.zip' 9 | 'Conv model (:cite:`wang-xu-2017-convolutional`) trained on sighan2005 msr dataset.' 10 | # SIGHAN2005_MSR_BERT_BASE = HANLP_URL + 'tok/cws_bert_base_msra_20191230_194627.zip' 11 | CTB6_CONVSEG = HANLP_URL + 'tok/ctb6_convseg_nowe_nocrf_20200110_004046.zip' 12 | 'Conv model (:cite:`wang-xu-2017-convolutional`) trained on CTB6 dataset.' 13 | # CTB6_BERT_BASE = HANLP_URL + 'tok/cws_bert_base_ctb6_20191230_185536.zip' 14 | PKU_NAME_MERGED_SIX_MONTHS_CONVSEG = HANLP_URL + 'tok/pku98_6m_conv_ngram_20200110_134736.zip' 15 | 'Conv model (:cite:`wang-xu-2017-convolutional`) trained on pku98 six months dataset with name merged into one unit.' 16 | LARGE_ALBERT_BASE = HANLP_URL + 'tok/large_cws_albert_base_20200828_011451.zip' 17 | 'ALBERT model (:cite:`Lan2020ALBERT:`) trained on the largest CWS dataset in the world.' 18 | SIGHAN2005_PKU_BERT_BASE_ZH = HANLP_URL + 'tok/sighan2005_pku_bert_base_zh_20201231_141130.zip' 19 | 'BERT model (:cite:`devlin-etal-2019-bert`) trained on sighan2005 pku dataset.' 20 | 21 | # Will be filled up during runtime 22 | ALL = {} 23 | -------------------------------------------------------------------------------- /elit/pretrained/word2vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 18:25 4 | from elit.common.constant import HANLP_URL 5 | 6 | CONVSEG_W2V_NEWS_TENSITE = HANLP_URL + 'embeddings/convseg_embeddings.zip' 7 | CONVSEG_W2V_NEWS_TENSITE_WORD_PKU = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.pku.words.w2v50' 8 | CONVSEG_W2V_NEWS_TENSITE_WORD_MSR = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.msr.words.w2v50' 9 | CONVSEG_W2V_NEWS_TENSITE_CHAR = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.w2v200' 10 | 11 | SEMEVAL16_EMBEDDINGS_CN = HANLP_URL + 'embeddings/semeval16_embeddings.zip' 12 | SEMEVAL16_EMBEDDINGS_300_NEWS_CN = SEMEVAL16_EMBEDDINGS_CN + '#news.fasttext.300.txt' 13 | SEMEVAL16_EMBEDDINGS_300_TEXT_CN = SEMEVAL16_EMBEDDINGS_CN + '#text.fasttext.300.txt' 14 | 15 | CTB5_FASTTEXT_300_CN = HANLP_URL + 'embeddings/ctb.fasttext.300.txt.zip' 16 | 17 | TENCENT_AI_LAB_EMBEDDING = 'https://ai.tencent.com/ailab/nlp/data/Tencent_AILab_ChineseEmbedding.tar.gz#Tencent_AILab_ChineseEmbedding.txt' 18 | 19 | RADICAL_CHAR_EMBEDDING_100 = HANLP_URL + 'embeddings/radical_char_vec_20191229_013849.zip#character.vec.txt' 20 | 'Chinese character embedding enhanced with rich radical information (:cite:`he2018dual`).' 
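# Note: the ``ALL = {}`` dicts in these pretrained modules are presumably populated at runtime by
# ls_resource_in_module (defined later in elit/utils/__init__.py), which scans a module for HTTP
# identifiers. A minimal sketch (illustrative):
#   from elit.utils import ls_resource_in_module
#   from elit.pretrained import word2vec
#   ls_resource_in_module(word2vec)   # returns {name: url, ...} and also updates word2vec.ALL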
21 | 22 | _SUBWORD_ENCODING_CWS = HANLP_URL + 'embeddings/subword_encoding_cws_20200524_190636.zip' 23 | SUBWORD_ENCODING_CWS_ZH_WIKI_BPE_50 = _SUBWORD_ENCODING_CWS + '#zh.wiki.bpe.vs200000.d50.w2v.txt' 24 | SUBWORD_ENCODING_CWS_GIGAWORD_UNI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.uni.ite50.vec' 25 | SUBWORD_ENCODING_CWS_GIGAWORD_BI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.bi.ite50.vec' 26 | SUBWORD_ENCODING_CWS_CTB_GAZETTEER_50 = _SUBWORD_ENCODING_CWS + '#ctb.50d.vec' 27 | 28 | ALL = {} 29 | -------------------------------------------------------------------------------- /elit/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 22:24 -------------------------------------------------------------------------------- /elit/transform/glue_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 16:34 4 | from elit.common.structure import SerializableDict 5 | from elit.datasets.glue import STANFORD_SENTIMENT_TREEBANK_2_TRAIN, MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV 6 | from elit.transform.table import TableTransform 7 | 8 | 9 | class StanfordSentimentTreebank2Transorm(TableTransform): 10 | pass 11 | 12 | 13 | class MicrosoftResearchParaphraseCorpus(TableTransform): 14 | 15 | def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=(3, 4), 16 | y_column=0, skip_header=True, delimiter='auto', **kwargs) -> None: 17 | super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs) 18 | 19 | 20 | def main(): 21 | # _test_sst2() 22 | _test_mrpc() 23 | 24 | 25 | def _test_sst2(): 26 | transform = StanfordSentimentTreebank2Transorm() 27 | transform.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN) 28 | transform.lock_vocabs() 29 | transform.label_vocab.summary() 30 | transform.build_config() 31 | dataset = transform.file_to_dataset(STANFORD_SENTIMENT_TREEBANK_2_TRAIN) 32 | for batch in dataset.take(1): 33 | print(batch) 34 | 35 | 36 | def _test_mrpc(): 37 | transform = MicrosoftResearchParaphraseCorpus() 38 | transform.fit(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV) 39 | transform.lock_vocabs() 40 | transform.label_vocab.summary() 41 | transform.build_config() 42 | dataset = transform.file_to_dataset(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV) 43 | for batch in dataset.take(1): 44 | print(batch) -------------------------------------------------------------------------------- /elit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-24 22:12 4 | from . 
import rules 5 | 6 | 7 | def ls_resource_in_module(root) -> dict: 8 | res = dict() 9 | for k, v in root.__dict__.items(): 10 | if k.startswith('_') or v == root: 11 | continue 12 | if isinstance(v, str): 13 | if v.startswith('http') and not v.endswith('/') and not v.endswith('#') and not v.startswith('_'): 14 | res[k] = v 15 | elif type(v).__name__ == 'module': 16 | res.update(ls_resource_in_module(v)) 17 | if 'ALL' in root.__dict__ and isinstance(root.__dict__['ALL'], dict): 18 | root.__dict__['ALL'].update(res) 19 | return res 20 | -------------------------------------------------------------------------------- /elit/utils/file_read_backwards/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .file_read_backwards import FileReadBackwards # noqa: F401 4 | 5 | __author__ = """Robin Robin""" 6 | __email__ = 'robinsquare42@gmail.com' 7 | __version__ = '2.0.0' 8 | -------------------------------------------------------------------------------- /elit/utils/init_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-27 13:25 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | import functools 9 | 10 | 11 | def embedding_uniform(tensor:torch.Tensor, seed=233): 12 | gen = torch.Generator().manual_seed(seed) 13 | with torch.no_grad(): 14 | fan_out = tensor.size(-1) 15 | bound = math.sqrt(3.0 / fan_out) 16 | return tensor.uniform_(-bound, bound, generator=gen) 17 | -------------------------------------------------------------------------------- /elit/utils/lang/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 18:46 4 | 5 | __doc__ = ''' 6 | This package holds misc utils for specific languages. 
7 | ''' 8 | -------------------------------------------------------------------------------- /elit/utils/lang/zh/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 18:47 -------------------------------------------------------------------------------- /elit/utils/lang/zh/char_table.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 19:07 4 | from typing import List 5 | 6 | from elit.utils.io_util import get_resource 7 | from elit.common.io import load_json 8 | 9 | HANLP_CHAR_TABLE_TXT = 'https://file.hankcs.com/corpus/char_table.zip#CharTable.txt' 10 | HANLP_CHAR_TABLE_JSON = 'https://file.hankcs.com/corpus/char_table.json.zip' 11 | 12 | 13 | class CharTable: 14 | convert = {} 15 | 16 | @staticmethod 17 | def convert_char(c): 18 | if not CharTable.convert: 19 | CharTable._init() 20 | return CharTable.convert.get(c, c) 21 | 22 | @staticmethod 23 | def normalize_text(text: str) -> str: 24 | return ''.join(CharTable.convert_char(c) for c in text) 25 | 26 | @staticmethod 27 | def normalize_chars(chars: List[str]) -> List[str]: 28 | return [CharTable.convert_char(c) for c in chars] 29 | 30 | @staticmethod 31 | def _init(): 32 | CharTable.convert = CharTable.load() 33 | 34 | @staticmethod 35 | def load(): 36 | mapper = {} 37 | with open(get_resource(HANLP_CHAR_TABLE_TXT), encoding='utf-8') as src: 38 | for line in src: 39 | cells = line.rstrip('\n') 40 | if len(cells) != 3: 41 | continue 42 | a, _, b = cells 43 | mapper[a] = b 44 | return mapper 45 | 46 | 47 | class JsonCharTable(CharTable): 48 | 49 | @staticmethod 50 | def load(): 51 | return load_json(get_resource(HANLP_CHAR_TABLE_JSON)) 52 | 53 | 54 | -------------------------------------------------------------------------------- /elit/utils/lang/zh/localization.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-05 02:09 4 | 5 | task = { 6 | 'dep': '依存句法树', 7 | 'token': '单词', 8 | 'pos': '词性', 9 | 'ner': '命名实体', 10 | 'srl': '语义角色' 11 | } 12 | 13 | pos = { 14 | 'VA': '表语形容词', 'VC': '系动词', 'VE': '动词有无', 'VV': '其他动词', 'NR': '专有名词', 'NT': '时间名词', 'NN': '其他名词', 15 | 'LC': '方位词', 'PN': '代词', 'DT': '限定词', 'CD': '概数词', 'OD': '序数词', 'M': '量词', 'AD': '副词', 'P': '介词', 16 | 'CC': '并列连接词', 'CS': '从属连词', 'DEC': '补语成分“的”', 'DEG': '属格“的”', 'DER': '表结果的“得”', 'DEV': '表方式的“地”', 17 | 'AS': '动态助词', 'SP': '句末助词', 'ETC': '表示省略', 'MSP': '其他小品词', 'IJ': '句首感叹词', 'ON': '象声词', 18 | 'LB': '长句式表被动', 'SB': '短句式表被动', 'BA': '把字句', 'JJ': '其他名词修饰语', 'FW': '外来语', 'PU': '标点符号', 19 | 'NOI': '噪声', 'URL': '网址' 20 | } 21 | 22 | ner = { 23 | 'NT': '机构团体', 'NS': '地名', 'NR': '人名' 24 | } 25 | 26 | dep = { 27 | 'nn': '复合名词修饰', 'punct': '标点符号', 'nsubj': '名词性主语', 'conj': '连接性状语', 'dobj': '直接宾语', 'advmod': '名词性状语', 28 | 'prep': '介词性修饰语', 'nummod': '数词修饰语', 'amod': '形容词修饰语', 'pobj': '介词性宾语', 'rcmod': '相关关系', 'cpm': '补语', 29 | 'assm': '关联标记', 'assmod': '关联修饰', 'cc': '并列关系', 'elf': '类别修饰', 'ccomp': '从句补充', 'det': '限定语', 'lobj': '时间介词', 30 | 'range': '数量词间接宾语', 'asp': '时态标记', 'tmod': '时间修饰语', 'plmod': '介词性地点修饰', 'attr': '属性', 'mmod': '情态动词', 31 | 'loc': '位置补语', 'top': '主题', 'pccomp': '介词补语', 'etc': '省略关系', 'lccomp': '位置补语', 'ordmod': '量词修饰', 32 | 'xsubj': '控制主语', 'neg': '否定修饰', 'rcomp': '结果补语', 'comod': '并列联合动词', 'vmod': '动词修饰', 'prtmod': '小品词', 33 | 'ba': '把字关系', 'dvpm': '地字修饰', 
'dvpmod': '地字动词短语', 'prnmod': '插入词修饰', 'cop': '系动词', 'pass': '被动标记', 34 | 'nsubjpass': '被动名词主语', 'clf': '类别修饰', 'dep': '依赖关系', 'root': '核心关系' 35 | } 36 | -------------------------------------------------------------------------------- /elit/utils/rules.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | SEPARATOR = r'@' 4 | RE_SENTENCE = re.compile(r'(\S.+?[.!?])(?=\s+|$)|(\S.+?)(?=[\n]|$)', re.UNICODE) 5 | AB_SENIOR = re.compile(r'([A-Z][a-z]{1,2}\.)\s(\w)', re.UNICODE) 6 | AB_ACRONYM = re.compile(r'(\.[a-zA-Z]\.)\s(\w)', re.UNICODE) 7 | UNDO_AB_SENIOR = re.compile(r'([A-Z][a-z]{1,2}\.)' + SEPARATOR + r'(\w)', re.UNICODE) 8 | UNDO_AB_ACRONYM = re.compile(r'(\.[a-zA-Z]\.)' + SEPARATOR + r'(\w)', re.UNICODE) 9 | 10 | 11 | def replace_with_separator(text, separator, regexs): 12 | replacement = r"\1" + separator + r"\2" 13 | result = text 14 | for regex in regexs: 15 | result = regex.sub(replacement, result) 16 | return result 17 | 18 | 19 | def split_sentence(text, best=True): 20 | text = re.sub('([。!?\?])([^”’])', r"\1\n\2", text) 21 | text = re.sub('(\.{6})([^”’])', r"\1\n\2", text) 22 | text = re.sub('(\…{2})([^”’])', r"\1\n\2", text) 23 | text = re.sub('([。!?\?][”’])([^,。!?\?])', r'\1\n\2', text) 24 | for chunk in text.split("\n"): 25 | chunk = chunk.strip() 26 | if not chunk: 27 | continue 28 | if not best: 29 | yield chunk 30 | continue 31 | processed = replace_with_separator(chunk, SEPARATOR, [AB_SENIOR, AB_ACRONYM]) 32 | for sentence in RE_SENTENCE.finditer(processed): 33 | sentence = replace_with_separator(sentence.group(), r" ", [UNDO_AB_SENIOR, UNDO_AB_ACRONYM]) 34 | yield sentence 35 | 36 | 37 | -------------------------------------------------------------------------------- /elit/utils/span_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-12 20:34 4 | 5 | 6 | def generate_words_per_line(file_path): 7 | with open(file_path, encoding='utf-8') as src: 8 | for line in src: 9 | cells = line.strip().split() 10 | if not cells: 11 | continue 12 | yield cells 13 | 14 | 15 | def words_to_bmes(words): 16 | tags = [] 17 | for w in words: 18 | if not w: 19 | raise ValueError('{} contains None or zero-length word {}'.format(str(words), w)) 20 | if len(w) == 1: 21 | tags.append('S') 22 | else: 23 | tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E']) 24 | return tags 25 | 26 | 27 | def words_to_bi(words): 28 | tags = [] 29 | for w in words: 30 | if not w: 31 | raise ValueError('{} contains None or zero-length word {}'.format(str(words), w)) 32 | tags.extend(['B'] + ['I'] * (len(w) - 1)) 33 | return tags 34 | 35 | 36 | def bmes_to_words(chars, tags): 37 | result = [] 38 | if len(chars) == 0: 39 | return result 40 | word = chars[0] 41 | 42 | for c, t in zip(chars[1:], tags[1:]): 43 | if t == 'B' or t == 'S': 44 | result.append(word) 45 | word = '' 46 | word += c 47 | if len(word) != 0: 48 | result.append(word) 49 | 50 | return result 51 | 52 | 53 | def bmes_to_spans(tags): 54 | result = [] 55 | offset = 0 56 | pre_offset = 0 57 | for t in tags[1:]: 58 | offset += 1 59 | if t == 'B' or t == 'S': 60 | result.append((pre_offset, offset)) 61 | pre_offset = offset 62 | if offset != len(tags): 63 | result.append((pre_offset, len(tags))) 64 | 65 | return result 66 | 67 | 68 | def bmes_of(sentence, segmented): 69 | if segmented: 70 | chars = [] 71 | tags = [] 72 | words = sentence.split() 73 | for w in words: 74 | chars.extend(list(w)) 75 | if 
len(w) == 1: 76 | tags.append('S') 77 | else: 78 | tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E']) 79 | else: 80 | chars = list(sentence) 81 | tags = ['S'] * len(chars) 82 | return chars, tags 83 | -------------------------------------------------------------------------------- /elit/version.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 19:26 4 | 5 | __version__ = '2.1.0-alpha.0' 6 | """ELIT version""" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | torch==1.7.1 3 | penman==0.6.2 4 | networkx 5 | # Data preprocessing 6 | overrides 7 | h5py 8 | boto3 9 | spacy 10 | ftfy 11 | nltk 12 | conllu 13 | pyyaml 14 | editdistance 15 | word2number 16 | bs4 17 | lxml 18 | scipy 19 | termcolor 20 | phrasetree 21 | pynvml 22 | toposort 23 | transformers 24 | hanlp_trie 25 | networkx 26 | pytorch_pretrained_bert -------------------------------------------------------------------------------- /scripts/annotate_features.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Start a Stanford CoreNLP server before running this script. 6 | # https://stanfordnlp.github.io/CoreNLP/corenlp-server.html 7 | 8 | # The compound file is downloaded from 9 | # https://github.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/blob/master/data/joints.txt 10 | compound_file=data/AMR/amr_2.0_utils/joints.txt 11 | amr_dir=$1 12 | 13 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.feature_annotator \ 14 | ${amr_dir}/test.txt ${amr_dir}/train.txt ${amr_dir}/dev.txt \ 15 | --compound_file ${compound_file} 16 | -------------------------------------------------------------------------------- /scripts/download_artifacts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # echo "Downloading artifacts." 
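# Usage sketch (assumed to be run from the repository root so the relative data/ paths resolve):
#   bash scripts/download_artifacts.sh
# As shipped, the BERT/GloVe blocks below stay commented out and only the AMR 1.0/2.0/3.0 utils
# are downloaded.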
6 | # mkdir -p data/bert-base-cased 7 | # curl -O https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz 8 | # tar -xzvf bert-base-cased.tar.gz -C data/bert-base-cased 9 | # curl -o data/bert-base-cased/bert-base-cased-vocab.txt \ 10 | # https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt 11 | # rm bert-base-cased.tar.gz 12 | 13 | # mkdir -p data/glove 14 | # curl -L -o data/glove/glove.840B.300d.zip \ 15 | # http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip 16 | 17 | #mkdir -p tools 18 | #git clone https://github.com/ChunchuanLv/amr-evaluation-tool-enhanced.git tools/amr-evaluation-tool-enhanced 19 | 20 | mkdir -p data/AMR 21 | curl -o data/AMR/amr_2.0_utils.tar.gz https://www.cs.jhu.edu/~s.zhang/data/AMR/amr_2.0_utils.tar.gz 22 | curl -o data/AMR/amr_1.0_utils.tar.gz https://www.cs.jhu.edu/~s.zhang/data/AMR/amr_1.0_utils.tar.gz 23 | pushd data/AMR 24 | wget https://od.hankcs.com/research/amr2020/amr_3.0_utils.tgz 25 | tar xzvf amr_3.0_utils.tgz 26 | tar -xzvf amr_2.0_utils.tar.gz 27 | tar -xzvf amr_1.0_utils.tar.gz 28 | rm amr_3.0_utils.tgz amr_2.0_utils.tar.gz amr_1.0_utils.tar.gz 29 | popd 30 | 31 | -------------------------------------------------------------------------------- /scripts/env.sh: -------------------------------------------------------------------------------- 1 | python3 -m venv env 2 | source env/bin/activate 3 | pip install --upgrade pip 4 | # Feel free to modify the cuda version 5 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 6 | pip install -r requirements.txt -------------------------------------------------------------------------------- /scripts/postprocess_2.0.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Directory where intermediate utils will be saved to speed up processing. 6 | util_dir=data/AMR/amr_2.0_utils 7 | 8 | # AMR data with **features** 9 | test_data=$1 10 | 11 | 12 | python3 -u -m stog.data.dataset_readers.amr_parsing.postprocess.postprocess \ 13 | --amr_path ${test_data} \ 14 | --util_dir ${util_dir} \ 15 | --v 2 16 | printf "Done.`date`\n\n" 17 | -------------------------------------------------------------------------------- /scripts/postprocess_3.0.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Directory where intermediate utils will be saved to speed up processing. 6 | util_dir=data/AMR/amr_3.0_utils 7 | 8 | # AMR data with **features** 9 | test_data=$1 10 | 11 | 12 | python3 -u -m stog.data.dataset_readers.amr_parsing.postprocess.postprocess \ 13 | --amr_path ${test_data} \ 14 | --util_dir ${util_dir} \ 15 | --v 2 16 | printf "Done.`date`\n\n" 17 | -------------------------------------------------------------------------------- /scripts/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | usage() { 4 | echo "Usage: $0 -v -p " 5 | echo " Make sure your AMR corpus is untouched." 6 | echo " It should organized like below:" 7 | echo " " 8 | echo " data/" 9 | echo " docs/" 10 | echo " index.html" 11 | exit 1; 12 | } 13 | 14 | while getopts ":h:v:p:" o; do 15 | case "${o}" in 16 | h) 17 | usage 18 | ;; 19 | v) 20 | v=${OPTARG} 21 | ((v == 1 || v == 2 || v == 3)) || usage 22 | ;; 23 | p) 24 | p=${OPTARG} 25 | ;; 26 | \? 
) 27 | usage 28 | ;; 29 | esac 30 | done 31 | shift $((OPTIND-1)) 32 | 33 | if [ -z $v ]; then 34 | usage 35 | fi 36 | 37 | if [ -z $p ]; then 38 | usage 39 | fi 40 | 41 | 42 | if [[ "$v" == "2" ]]; then 43 | DATA_DIR=data/AMR/amr_2.0 44 | SPLIT_DIR=$p/data/amrs/split 45 | TRAIN=${SPLIT_DIR}/training 46 | DEV=${SPLIT_DIR}/dev 47 | TEST=${SPLIT_DIR}/test 48 | elif [[ "$v" == "3" ]]; then 49 | DATA_DIR=data/AMR/amr_3.0 50 | SPLIT_DIR=$p/data/amrs/split 51 | TRAIN=${SPLIT_DIR}/training 52 | DEV=${SPLIT_DIR}/dev 53 | TEST=${SPLIT_DIR}/test 54 | else 55 | DATA_DIR=data/AMR/amr_1.0 56 | SPLIT_DIR=$p/data/amrs/split 57 | TRAIN=${SPLIT_DIR}/training 58 | DEV=${SPLIT_DIR}/dev 59 | TEST=${SPLIT_DIR}/test 60 | fi 61 | 62 | echo "Preparing data in ${DATA_DIR}...`date`" 63 | mkdir -p ${DATA_DIR} 64 | awk FNR!=1 ${TRAIN}/* > ${DATA_DIR}/train.txt 65 | awk FNR!=1 ${DEV}/* > ${DATA_DIR}/dev.txt 66 | awk FNR!=1 ${TEST}/* > ${DATA_DIR}/test.txt 67 | echo "Done..`date`" 68 | 69 | -------------------------------------------------------------------------------- /scripts/prepare_vocab.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | python3 -u -m amr_parser.extract --train_data ${dataset}/train.txt.features.preproc --levi_graph $2 3 | rm -f ${dataset}/*_vocab 4 | mv *_vocab ${dataset}/ 5 | # python3 encoder.py 6 | # cat ${dataset}/*embed | sort | uniq > ${dataset}/glove.embed.txt 7 | # rm ${dataset}/*embed 8 | -------------------------------------------------------------------------------- /scripts/preprocess_2.0.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # ############### AMR v2.0 ################ 6 | # # Directory where intermediate utils will be saved to speed up processing. 
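# Typical invocation order, inferred from the input/output suffixes used by these scripts
# (illustrative; <AMR_ROOT> is the untouched LDC AMR 2.0 release):
#   bash scripts/prepare_data.sh -v 2 -p <AMR_ROOT>        # -> data/AMR/amr_2.0/{train,dev,test}.txt
#   bash scripts/annotate_features.sh data/AMR/amr_2.0     # requires a running Stanford CoreNLP server
#   bash scripts/preprocess_2.0.sh                         # this script, after pointing the *_data variables below at your files
#   bash scripts/prepare_vocab.sh data/AMR/amr_2.0 <levi_graph>   # second argument is the levi_graph flag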
7 | util_dir=data/AMR/amr_2.0_utils 8 | 9 | # AMR data with **features** 10 | data_dir=/home/hhe43/hanlp/data/amr/amr_2.0 11 | train_data=${data_dir}/single.txt.features 12 | dev_data=${data_dir}/single.txt.features 13 | test_data=${data_dir}/single.txt.features 14 | 15 | # ========== Set the above variables correctly ========== 16 | 17 | printf "Cleaning inputs...`date`\n" 18 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.input_cleaner \ 19 | --amr_files ${train_data} ${dev_data} ${test_data} 20 | printf "Done.`date`\n\n" 21 | 22 | printf "Recategorizing subgraphs...`date`\n" 23 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.recategorizer \ 24 | --dump_dir ${util_dir} \ 25 | --amr_files ${train_data}.input_clean ${dev_data}.input_clean 26 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.text_anonymizor \ 27 | --amr_file ${test_data}.input_clean \ 28 | --util_dir ${util_dir} 29 | printf "Done.`date`\n\n" 30 | 31 | printf "Removing senses...`date`\n" 32 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.sense_remover \ 33 | --util_dir ${util_dir} \ 34 | --amr_files ${train_data}.input_clean.recategorize \ 35 | ${dev_data}.input_clean.recategorize \ 36 | ${test_data}.input_clean.recategorize 37 | printf "Done.`date`\n\n" 38 | 39 | printf "Renaming preprocessed files...`date`\n" 40 | mv ${test_data}.input_clean.recategorize.nosense ${test_data}.preproc 41 | mv ${train_data}.input_clean.recategorize.nosense ${train_data}.preproc 42 | mv ${dev_data}.input_clean.recategorize.nosense ${dev_data}.preproc 43 | rm ${data_dir}/*.input_clean* 44 | -------------------------------------------------------------------------------- /scripts/preprocess_3.0.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # ############### AMR v3.0 ################ 6 | # # Directory where intermediate utils will be saved to speed up processing. 
7 | util_dir=data/AMR/amr_3.0_utils 8 | 9 | # AMR data with **features** 10 | data_dir=data/AMR/amr_3.0 11 | train_data=${data_dir}/train.txt.features 12 | dev_data=${data_dir}/dev.txt.features 13 | test_data=${data_dir}/test.txt.features 14 | 15 | # ========== Set the above variables correctly ========== 16 | 17 | printf "Cleaning inputs...`date`\n" 18 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.input_cleaner \ 19 | --amr_files ${train_data} ${dev_data} ${test_data} 20 | printf "Done.`date`\n\n" 21 | 22 | printf "Recategorizing subgraphs...`date`\n" 23 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.recategorizer \ 24 | --dump_dir ${util_dir} \ 25 | --amr_files ${train_data}.input_clean ${dev_data}.input_clean 26 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.text_anonymizor \ 27 | --amr_file ${test_data}.input_clean \ 28 | --util_dir ${util_dir} 29 | printf "Done.`date`\n\n" 30 | 31 | printf "Removing senses...`date`\n" 32 | python3 -u -m stog.data.dataset_readers.amr_parsing.preprocess.sense_remover \ 33 | --util_dir ${util_dir} \ 34 | --amr_files ${train_data}.input_clean.recategorize \ 35 | ${dev_data}.input_clean.recategorize \ 36 | ${test_data}.input_clean.recategorize 37 | printf "Done.`date`\n\n" 38 | 39 | printf "Renaming preprocessed files...`date`\n" 40 | mv ${test_data}.input_clean.recategorize.nosense ${test_data}.preproc 41 | mv ${train_data}.input_clean.recategorize.nosense ${train_data}.preproc 42 | mv ${dev_data}.input_clean.recategorize.nosense ${dev_data}.preproc 43 | rm ${data_dir}/*.input_clean* 44 | -------------------------------------------------------------------------------- /scripts/run_spotlight.sh: -------------------------------------------------------------------------------- 1 | java -XX:+IgnoreUnrecognizedVMOptions --add-modules java.xml.bi -cp jaxb-ri/mod/jakarta.xml.bind-api.jar -jar dbpedia-spotlight-1.0.0.jar en http://localhost:2222/rest -------------------------------------------------------------------------------- /scripts/run_standford_corenlp_server.sh: -------------------------------------------------------------------------------- 1 | java -classpath "stanford-corenlp-full-2018-10-05/*" -mx4g edu.stanford.nlp.pipeline.StanfordCoreNLPServer 1337 -port 8888 2 | -------------------------------------------------------------------------------- /scripts/train_joint.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | python3 amr_parser/train.py --tok_vocab ${dataset}/tok_vocab\ 3 | --lem_vocab ${dataset}/lem_vocab\ 4 | --pos_vocab ${dataset}/pos_vocab\ 5 | --ner_vocab ${dataset}/ner_vocab\ 6 | --concept_vocab ${dataset}/concept_vocab\ 7 | --predictable_concept_vocab ${dataset}/predictable_concept_vocab\ 8 | --rel_vocab ${dataset}/rel_vocab\ 9 | --word_char_vocab ${dataset}/word_char_vocab\ 10 | --concept_char_vocab ${dataset}/concept_char_vocab\ 11 | --train_data ${dataset}/train.txt.features.preproc \ 12 | --dev_data ${dataset}/dev.txt.features.preproc \ 13 | --with_bert \ 14 | --joint_arc_concept \ 15 | --bert_path bert-base-cased \ 16 | --word_char_dim 32\ 17 | --word_dim 300\ 18 | --pos_dim 32\ 19 | --ner_dim 16\ 20 | --concept_char_dim 32\ 21 | --concept_dim 300 \ 22 | --rel_dim 100 \ 23 | --cnn_filter 3 256\ 24 | --char2word_dim 128\ 25 | --char2concept_dim 128\ 26 | --embed_dim 512\ 27 | --ff_embed_dim 1024\ 28 | --num_heads 8\ 29 | --snt_layers 4\ 30 | --graph_layers 2\ 31 | --inference_layers 4\ 32 | --dropout 0.2\ 33 | --unk_rate 0.33\ 
34 | --epochs 100000\ 35 | --max_batches_acm 60000\ 36 | --train_batch_size 4444\ 37 | --dev_batch_size 4444 \ 38 | --lr_scale 1. \ 39 | --warmup_steps 2000\ 40 | --print_every 100 \ 41 | --eval_every 1000 \ 42 | --batches_per_update 4 \ 43 | --ckpt model/amr2joint1\ 44 | --world_size 1\ 45 | --gpus 1\ 46 | --MASTER_ADDR localhost\ 47 | --MASTER_PORT 29505\ 48 | --start_rank 0 49 | -------------------------------------------------------------------------------- /scripts/train_levi.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | python3 amr_parser/train.py --tok_vocab ${dataset}/tok_vocab\ 3 | --lem_vocab ${dataset}/lem_vocab\ 4 | --pos_vocab ${dataset}/pos_vocab\ 5 | --ner_vocab ${dataset}/ner_vocab\ 6 | --concept_vocab ${dataset}/concept_vocab\ 7 | --predictable_concept_vocab ${dataset}/predictable_concept_vocab\ 8 | --rel_vocab ${dataset}/rel_vocab\ 9 | --word_char_vocab ${dataset}/word_char_vocab\ 10 | --concept_char_vocab ${dataset}/concept_char_vocab\ 11 | --train_data ${dataset}/train.txt.features.preproc \ 12 | --dev_data ${dataset}/dev.txt.features.preproc \ 13 | --with_bert \ 14 | --joint_arc_concept \ 15 | --levi_graph 1\ 16 | --bert_path bert-base-cased \ 17 | --word_char_dim 32\ 18 | --word_dim 300\ 19 | --pos_dim 32\ 20 | --ner_dim 16\ 21 | --concept_char_dim 32\ 22 | --concept_dim 300 \ 23 | --rel_dim 100 \ 24 | --cnn_filter 3 256\ 25 | --char2word_dim 128\ 26 | --char2concept_dim 128\ 27 | --embed_dim 512\ 28 | --ff_embed_dim 1024\ 29 | --num_heads 8\ 30 | --snt_layers 4\ 31 | --graph_layers 2\ 32 | --inference_layers 4\ 33 | --dropout 0.2\ 34 | --unk_rate 0.33\ 35 | --epochs 100000\ 36 | --max_batches_acm 60000\ 37 | --train_batch_size 17776\ 38 | --dev_batch_size 4444 \ 39 | --lr_scale 1. \ 40 | --warmup_steps 2000\ 41 | --print_every 100 \ 42 | --eval_every 1000 \ 43 | --batches_per_update 1 \ 44 | --ckpt levi_model \ 45 | --world_size 1\ 46 | --gpus 1\ 47 | --MASTER_ADDR localhost\ 48 | --MASTER_PORT 29505\ 49 | --start_rank 0 50 | -------------------------------------------------------------------------------- /stog/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /stog/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import exception_hook 2 | -------------------------------------------------------------------------------- /stog/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/algorithms/__init__.py -------------------------------------------------------------------------------- /stog/algorithms/dict_merge.py: -------------------------------------------------------------------------------- 1 | # Recursive dictionary merge 2 | # Copyright (C) 2016 Paul Durivage 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | # Found here https://gist.github.com/angstwad/bf22d1822c38a92ec0a9 18 | # Using jpopelka's modified solution. 19 | 20 | import collections 21 | 22 | def dict_merge(dct, merge_dct): 23 | """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of 24 | updating only top-level keys, dict_merge recurses down into dicts nested 25 | to an arbitrary depth, updating keys. The ``merge_dct`` is merged into 26 | ``dct``. 27 | 28 | :param dct: dict onto which the merge is executed 29 | :param merge_dct: dct merged into dct 30 | :return: None 31 | """ 32 | for k, v in merge_dct.items(): 33 | if isinstance(dct.get(k), dict) and isinstance(v, collections.Mapping): 34 | dict_merge(dct[k], v) 35 | else: 36 | dct[k] = v 37 | -------------------------------------------------------------------------------- /stog/commands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/commands/__init__.py -------------------------------------------------------------------------------- /stog/commands/subcommand.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for subcommands under ``allennlp.run``. 3 | """ 4 | import argparse 5 | 6 | class Subcommand: 7 | """ 8 | An abstract class representing subcommands for allennlp.run. 9 | If you wanted to (for example) create your own custom `special-evaluate` command to use like 10 | 11 | ``allennlp special-evaluate ...`` 12 | 13 | you would create a ``Subcommand`` subclass and then pass it as an override to 14 | :func:`~allennlp.commands.main` . 15 | """ 16 | def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser: 17 | # pylint: disable=protected-access 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /stog/data/__init__.py: -------------------------------------------------------------------------------- 1 | from stog.data.dataset_readers.dataset_reader import DatasetReader 2 | from stog.data.fields.field import DataArray, Field 3 | from stog.data.instance import Instance 4 | from stog.data.iterators.data_iterator import DataIterator 5 | from stog.data.token_indexers.token_indexer import TokenIndexer, TokenType 6 | from stog.data.tokenizers.token import Token 7 | from stog.data.tokenizers.tokenizer import Tokenizer 8 | from stog.data.vocabulary import Vocabulary 9 | -------------------------------------------------------------------------------- /stog/data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A :class:`~stog.data.dataset_readers.dataset_reader.DatasetReader` 3 | reads a file and converts it to a collection of 4 | :class:`~stog.data.instance.Instance` s. 5 | The various subclasses know how to read specific filetypes 6 | and produce datasets in the formats required by specific models. 
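A minimal usage sketch (illustrative; the ``read`` method and its signature are assumed from the
AllenNLP code this package is adapted from, and constructor arguments, if any, are omitted)::

    reader = AbstractMeaningRepresentationDatasetReader()
    instances = reader.read('data/AMR/amr_2.0/dev.txt.features.preproc')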
7 | """ 8 | 9 | # pylint: disable=line-too-long 10 | from stog.data.dataset_readers.dataset_reader import DatasetReader 11 | from stog.data.dataset_readers.abstract_meaning_representation import AbstractMeaningRepresentationDatasetReader 12 | -------------------------------------------------------------------------------- /stog/data/dataset_readers/amr_parsing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/data/dataset_readers/amr_parsing/__init__.py -------------------------------------------------------------------------------- /stog/data/dataset_readers/amr_parsing/amr_concepts/__init__.py: -------------------------------------------------------------------------------- 1 | from .entity import Entity 2 | from .date import Date 3 | from .score import Score 4 | from .ordinal import Ordinal 5 | from .polarity import Polarity 6 | from .polite import Polite 7 | from .quantity import Quantity 8 | from .url import URL 9 | -------------------------------------------------------------------------------- /stog/data/dataset_readers/amr_parsing/amr_concepts/polite.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class Polite: 5 | 6 | lemma_map = { 7 | 'can': 'possible' 8 | } 9 | 10 | def __init__(self, amr, dry=False): 11 | self.amr = amr 12 | self.dry = dry 13 | self.heads = [] 14 | self.true_positive = 0 15 | self.false_positive = 0 16 | 17 | def remove_polite(self): 18 | count = 0 19 | for node in self.amr.graph.get_nodes(): 20 | for attr, value in node.attributes: 21 | if attr == 'polite': 22 | if not self.dry: 23 | self.amr.graph.remove_node_attribute(node, attr, value) 24 | count += 1 25 | return count 26 | 27 | def predict_polite(self): 28 | for i in range(len(self.amr.tokens)): 29 | if self.amr.lemmas[i] == 'please': 30 | if self.amr.lemmas[i + 1: i + 3] == ['take', 'a']: 31 | self.heads.append((i, i + 3)) 32 | elif i - 2 >= 0 and self.amr.lemmas[i - 2] == 'can': 33 | self.heads.append((i, i - 2)) 34 | elif i+1 None: 17 | self.array = array 18 | self.padding_value = padding_value 19 | 20 | 21 | def get_padding_lengths(self) -> Dict[str, int]: 22 | return {"dimension_" + str(i): shape 23 | for i, shape in enumerate(self.array.shape)} 24 | 25 | 26 | def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: 27 | max_shape = [padding_lengths["dimension_{}".format(i)] 28 | for i in range(len(padding_lengths))] 29 | 30 | return_array = numpy.ones(max_shape, "float32") * self.padding_value 31 | 32 | # If the tensor has a different shape from the largest tensor, pad dimensions with zeros to 33 | # form the right shaped list of slices for insertion into the final tensor. 
34 | slicing_shape = list(self.array.shape) 35 | if len(self.array.shape) < len(max_shape): 36 | slicing_shape = slicing_shape + [0 for _ in range(len(max_shape) - len(self.array.shape))] 37 | slices = tuple([slice(0, x) for x in slicing_shape]) 38 | return_array[slices] = self.array 39 | tensor = torch.from_numpy(return_array) 40 | return tensor 41 | 42 | 43 | def empty_field(self): # pylint: disable=no-self-use 44 | # Pass the padding_value, so that any outer field, e.g., `ListField[ArrayField]` uses the 45 | # same padding_value in the padded ArrayFields 46 | return ArrayField(numpy.array([], dtype="float32"), padding_value=self.padding_value) 47 | 48 | 49 | def __str__(self) -> str: 50 | return f"ArrayField with shape: {self.array.shape}." 51 | -------------------------------------------------------------------------------- /stog/data/fields/index_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from overrides import overrides 4 | import torch 5 | 6 | from allennlp.data.fields.field import Field 7 | from allennlp.data.fields.sequence_field import SequenceField 8 | from allennlp.common.checks import ConfigurationError 9 | 10 | 11 | class IndexField(Field[torch.Tensor]): 12 | """ 13 | An ``IndexField`` is an index into a 14 | :class:`~allennlp.data.fields.sequence_field.SequenceField`, as might be used for representing 15 | a correct answer option in a list, or a span begin and span end position in a passage, for 16 | example. Because it's an index into a :class:`SequenceField`, we take one of those as input 17 | and use it to compute padding lengths. 18 | 19 | Parameters 20 | ---------- 21 | index : ``int`` 22 | The index of the answer in the :class:`SequenceField`. This is typically the "correct 23 | answer" in some classification decision over the sequence, like where an answer span starts 24 | in SQuAD, or which answer option is correct in a multiple choice question. A value of 25 | ``-1`` means there is no label, which can be used for padding or other purposes. 26 | sequence_field : ``SequenceField`` 27 | A field containing the sequence that this ``IndexField`` is a pointer into. 28 | """ 29 | def __init__(self, index: int, sequence_field: SequenceField) -> None: 30 | self.sequence_index = index 31 | self.sequence_field = sequence_field 32 | 33 | if not isinstance(index, int): 34 | raise ConfigurationError("IndexFields must be passed integer indices. " 35 | "Found index: {} with type: {}.".format(index, type(index))) 36 | 37 | 38 | def get_padding_lengths(self) -> Dict[str, int]: 39 | # pylint: disable=no-self-use 40 | return {} 41 | 42 | 43 | def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: 44 | # pylint: disable=unused-argument 45 | tensor = torch.LongTensor([self.sequence_index]) 46 | return tensor 47 | 48 | 49 | def empty_field(self): 50 | return IndexField(-1, self.sequence_field.empty_field()) 51 | 52 | def __str__(self) -> str: 53 | return f"IndexField with index: {self.sequence_index}." 
54 | -------------------------------------------------------------------------------- /stog/data/fields/metadata_field.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use 2 | from typing import Any, Dict, List 3 | 4 | from overrides import overrides 5 | 6 | from stog.data.fields.field import DataArray, Field 7 | 8 | 9 | class MetadataField(Field[DataArray]): 10 | """ 11 | A ``MetadataField`` is a ``Field`` that does not get converted into tensors. It just carries 12 | side information that might be needed later on, for computing some third-party metric, or 13 | outputting debugging information, or whatever else you need. We use this in the BiDAF model, 14 | for instance, to keep track of question IDs and passage token offsets, so we can more easily 15 | use the official evaluation script to compute metrics. 16 | 17 | We don't try to do any kind of smart combination of this field for batched input - when you use 18 | this ``Field`` in a model, you'll get a list of metadata objects, one for each instance in the 19 | batch. 20 | 21 | Parameters 22 | ---------- 23 | metadata : ``Any`` 24 | Some object containing the metadata that you want to store. It's likely that you'll want 25 | this to be a dictionary, but it could be anything you want. 26 | """ 27 | def __init__(self, metadata: Any) -> None: 28 | self.metadata = metadata 29 | 30 | 31 | def get_padding_lengths(self) -> Dict[str, int]: 32 | return {} 33 | 34 | 35 | def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray: 36 | # pylint: disable=unused-argument 37 | return self.metadata # type: ignore 38 | 39 | 40 | def empty_field(self) -> 'MetadataField': 41 | return MetadataField(None) 42 | 43 | @classmethod 44 | 45 | def batch_tensors(cls, tensor_list: List[DataArray]) -> DataArray: # type: ignore 46 | return tensor_list # type: ignore 47 | 48 | 49 | def __str__(self) -> str: 50 | return f"MetadataField (print field.metadata to see specific information)." 51 | -------------------------------------------------------------------------------- /stog/data/fields/sequence_field.py: -------------------------------------------------------------------------------- 1 | from stog.data.fields.field import DataArray, Field 2 | 3 | 4 | class SequenceField(Field[DataArray]): 5 | """ 6 | A ``SequenceField`` represents a sequence of things. This class just adds a method onto 7 | ``Field``: :func:`sequence_length`. It exists so that ``SequenceLabelField``, ``IndexField`` and other 8 | similar ``Fields`` can have a single type to require, with a consistent API, whether they are 9 | pointing to words in a ``TextField``, items in a ``ListField``, or something else. 10 | """ 11 | def sequence_length(self) -> int: 12 | """ 13 | How many elements are there in this sequence? 14 | """ 15 | raise NotImplementedError 16 | 17 | def empty_field(self) -> 'SequenceField': 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /stog/data/iterators/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The various :class:`~stog.data.iterators.data_iterator.DataIterator` subclasses 3 | can be used to iterate over datasets with different batching and padding schemes. 
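A minimal sketch of the intended usage (illustrative; ``index_with`` and the call signature are
assumed from the AllenNLP lineage of these classes, and ``src_tokens`` is a placeholder field name)::

    iterator = BucketIterator(sorting_keys=[('src_tokens', 'num_tokens')], batch_size=32)
    iterator.index_with(vocab)
    for batch in iterator(instances, num_epochs=1, shuffle=True):
        ...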
4 | """ 5 | 6 | from stog.data.iterators.data_iterator import DataIterator 7 | from stog.data.iterators.basic_iterator import BasicIterator 8 | from stog.data.iterators.bucket_iterator import BucketIterator 9 | from stog.data.iterators.epoch_tracking_bucket_iterator import EpochTrackingBucketIterator 10 | from stog.data.iterators.multiprocess_iterator import MultiprocessIterator 11 | -------------------------------------------------------------------------------- /stog/data/iterators/basic_iterator.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import logging 3 | import random 4 | 5 | from stog.utils import lazy_groups_of 6 | from stog.data.instance import Instance 7 | from stog.data.iterators.data_iterator import DataIterator 8 | from stog.data.dataset import Batch 9 | 10 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 11 | 12 | 13 | @DataIterator.register("basic") 14 | class BasicIterator(DataIterator): 15 | """ 16 | A very basic iterator that takes a dataset, possibly shuffles it, and creates fixed sized batches. 17 | 18 | It takes the same parameters as :class:`stog.data.iterators.DataIterator` 19 | """ 20 | def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 21 | # First break the dataset into memory-sized lists: 22 | for instance_list in self._memory_sized_lists(instances): 23 | if shuffle: 24 | random.shuffle(instance_list) 25 | iterator = iter(instance_list) 26 | # Then break each memory-sized list into batches. 27 | for batch_instances in lazy_groups_of(iterator, self._batch_size): 28 | for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances): 29 | batch = Batch(possibly_smaller_batches) 30 | yield batch 31 | -------------------------------------------------------------------------------- /stog/data/iterators/epoch_tracking_bucket_iterator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Tuple 3 | import warnings 4 | 5 | from stog.data.iterators.data_iterator import DataIterator 6 | from stog.data.iterators.bucket_iterator import BucketIterator 7 | 8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 9 | 10 | 11 | @DataIterator.register("epoch_tracking_bucket") 12 | class EpochTrackingBucketIterator(BucketIterator): 13 | """ 14 | This is essentially a :class:`stog.data.iterators.BucketIterator` with just one difference. 15 | It keeps track of the epoch number, and adds that as an additional meta field to each instance. 16 | That way, ``Model.forward`` will have access to this information. We do this by keeping track of 17 | epochs globally, and incrementing them whenever the iterator is called. However, the iterator is 18 | called both for training and validation sets. So, we keep a dict of epoch numbers, one key per 19 | dataset. 20 | 21 | Parameters 22 | ---------- 23 | See :class:`BucketIterator`. 
24 | """ 25 | def __init__(self, 26 | sorting_keys: List[Tuple[str, str]], 27 | padding_noise: float = 0.1, 28 | biggest_batch_first: bool = False, 29 | batch_size: int = 32, 30 | instances_per_epoch: int = None, 31 | max_instances_in_memory: int = None, 32 | cache_instances: bool = False) -> None: 33 | super().__init__(sorting_keys=sorting_keys, 34 | padding_noise=padding_noise, 35 | biggest_batch_first=biggest_batch_first, 36 | batch_size=batch_size, 37 | instances_per_epoch=instances_per_epoch, 38 | max_instances_in_memory=max_instances_in_memory, 39 | track_epoch=True, 40 | cache_instances=cache_instances) 41 | warnings.warn("EpochTrackingBucketIterator is deprecated, " 42 | "please just use BucketIterator with track_epoch=True", 43 | DeprecationWarning) 44 | -------------------------------------------------------------------------------- /stog/data/token_indexers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A ``TokenIndexer`` determines how string tokens get represented as arrays of indices in a model. 3 | """ 4 | 5 | from stog.data.token_indexers.dep_label_indexer import DepLabelIndexer 6 | from stog.data.token_indexers.ner_tag_indexer import NerTagIndexer 7 | from stog.data.token_indexers.pos_tag_indexer import PosTagIndexer 8 | from stog.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer 9 | from stog.data.token_indexers.token_characters_indexer import TokenCharactersIndexer 10 | from stog.data.token_indexers.token_indexer import TokenIndexer 11 | from stog.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer 12 | from stog.data.token_indexers.openai_transformer_byte_pair_indexer import OpenaiTransformerBytePairIndexer 13 | -------------------------------------------------------------------------------- /stog/data/token_indexers/dep_label_indexer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Set 3 | 4 | from overrides import overrides 5 | 6 | from stog.utils.string import pad_sequence_to_length 7 | from stog.data.vocabulary import Vocabulary 8 | from stog.data.tokenizers.token import Token 9 | from stog.data.token_indexers.token_indexer import TokenIndexer 10 | 11 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | @TokenIndexer.register("dependency_label") 15 | class DepLabelIndexer(TokenIndexer[int]): 16 | """ 17 | This :class:`TokenIndexer` represents tokens by their syntactic dependency label, as determined 18 | by the ``dep_`` field on ``Token``. 19 | 20 | Parameters 21 | ---------- 22 | namespace : ``str``, optional (default=``dep_labels``) 23 | We will use this namespace in the :class:`Vocabulary` to map strings to indices. 
24 | """ 25 | # pylint: disable=no-self-use 26 | def __init__(self, namespace: str = 'dep_labels') -> None: 27 | self.namespace = namespace 28 | self._logged_errors: Set[str] = set() 29 | 30 | 31 | def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): 32 | dep_label = token.dep_ 33 | if not dep_label: 34 | if token.text not in self._logged_errors: 35 | logger.warning("Token had no dependency label: %s", token.text) 36 | self._logged_errors.add(token.text) 37 | dep_label = 'NONE' 38 | counter[self.namespace][dep_label] += 1 39 | 40 | 41 | def tokens_to_indices(self, 42 | tokens: List[Token], 43 | vocabulary: Vocabulary, 44 | index_name: str) -> Dict[str, List[int]]: 45 | dep_labels = [token.dep_ or 'NONE' for token in tokens] 46 | 47 | return {index_name: [vocabulary.get_token_index(dep_label, self.namespace) for dep_label in dep_labels]} 48 | 49 | 50 | def get_padding_token(self) -> int: 51 | return 0 52 | 53 | 54 | def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument 55 | return {} 56 | 57 | 58 | def pad_token_sequence(self, 59 | tokens: Dict[str, List[int]], 60 | desired_num_tokens: Dict[str, int], 61 | padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument 62 | return {key: pad_sequence_to_length(val, desired_num_tokens[key]) 63 | for key, val in tokens.items()} 64 | -------------------------------------------------------------------------------- /stog/data/token_indexers/ner_tag_indexer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List 3 | 4 | from overrides import overrides 5 | 6 | from stog.utils.string import pad_sequence_to_length 7 | from stog.data.vocabulary import Vocabulary 8 | from stog.data.tokenizers.token import Token 9 | from stog.data.token_indexers.token_indexer import TokenIndexer 10 | 11 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | @TokenIndexer.register("ner_tag") 15 | class NerTagIndexer(TokenIndexer[int]): 16 | """ 17 | This :class:`TokenIndexer` represents tokens by their entity type (i.e., their NER tag), as 18 | determined by the ``ent_type_`` field on ``Token``. 19 | 20 | Parameters 21 | ---------- 22 | namespace : ``str``, optional (default=``ner_tags``) 23 | We will use this namespace in the :class:`Vocabulary` to map strings to indices. 
24 | """ 25 | # pylint: disable=no-self-use 26 | def __init__(self, namespace: str = 'ner_tags') -> None: 27 | self._namespace = namespace 28 | 29 | 30 | def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): 31 | tag = token.ent_type_ 32 | if not tag: 33 | tag = 'NONE' 34 | counter[self._namespace][tag] += 1 35 | 36 | 37 | def tokens_to_indices(self, 38 | tokens: List[Token], 39 | vocabulary: Vocabulary, 40 | index_name: str) -> Dict[str, List[int]]: 41 | tags = ['NONE' if token.ent_type_ is None else token.ent_type_ for token in tokens] 42 | 43 | return {index_name: [vocabulary.get_token_index(tag, self._namespace) for tag in tags]} 44 | 45 | 46 | def get_padding_token(self) -> int: 47 | return 0 48 | 49 | 50 | def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument 51 | return {} 52 | 53 | 54 | def pad_token_sequence(self, 55 | tokens: Dict[str, List[int]], 56 | desired_num_tokens: Dict[str, int], 57 | padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument 58 | return {key: pad_sequence_to_length(val, desired_num_tokens[key]) 59 | for key, val in tokens.items()} 60 | -------------------------------------------------------------------------------- /stog/data/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains various classes for performing 3 | tokenization, stemming, and filtering. 4 | """ 5 | 6 | from stog.data.tokenizers.tokenizer import Token, Tokenizer 7 | from stog.data.tokenizers.word_tokenizer import WordTokenizer 8 | from stog.data.tokenizers.character_tokenizer import CharacterTokenizer 9 | -------------------------------------------------------------------------------- /stog/data/tokenizers/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | import numpy as np 4 | from pytorch_pretrained_bert.tokenization import BertTokenizer 5 | 6 | from stog.data.vocabulary import DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN 7 | 8 | 9 | class AMRBertTokenizer(BertTokenizer): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(AMRBertTokenizer, self).__init__(*args, **kwargs) 13 | 14 | 15 | def tokenize(self, tokens, split=False): 16 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 17 | if not split: 18 | split_tokens = [t if t in self.vocab else '[UNK]' for t in tokens] 19 | gather_indexes = None 20 | else: 21 | split_tokens, _gather_indexes = [], [] 22 | for token in tokens: 23 | indexes = [] 24 | for i, sub_token in enumerate(self.wordpiece_tokenizer.tokenize(token)): 25 | indexes.append(len(split_tokens)) 26 | split_tokens.append(sub_token) 27 | _gather_indexes.append(indexes) 28 | 29 | _gather_indexes = _gather_indexes[1:-1] 30 | max_index_list_len = max(len(indexes) for indexes in _gather_indexes) 31 | gather_indexes = np.zeros((len(_gather_indexes), max_index_list_len)) 32 | for i, indexes in enumerate(_gather_indexes): 33 | for j, index in enumerate(indexes): 34 | gather_indexes[i, j] = index 35 | 36 | token_ids = np.array(self.convert_tokens_to_ids(split_tokens)) 37 | return token_ids, gather_indexes 38 | -------------------------------------------------------------------------------- /stog/data/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from stog.utils.registrable import Registrable 4 | from stog.data.tokenizers.token 
import Token 5 | 6 | 7 | class Tokenizer(Registrable): 8 | """ 9 | A ``Tokenizer`` splits strings of text into tokens. Typically, this either splits text into 10 | word tokens or character tokens, and those are the two tokenizer subclasses we have implemented 11 | here, though you could imagine wanting to do other kinds of tokenization for structured or 12 | other inputs. 13 | 14 | As part of tokenization, concrete implementations of this API will also handle stemming, 15 | stopword filtering, adding start and end tokens, or other kinds of things you might want to do 16 | to your tokens. See the parameters to, e.g., :class:`~.WordTokenizer`, or whichever tokenizer 17 | you want to use. 18 | 19 | If the base input to your model is words, you should use a :class:`~.WordTokenizer`, even if 20 | you also want to have a character-level encoder to get an additional vector for each word 21 | token. Splitting word tokens into character arrays is handled separately, in the 22 | :class:`..token_representations.TokenRepresentation` class. 23 | """ 24 | default_implementation = 'word' 25 | 26 | def batch_tokenize(self, texts: List[str]) -> List[List[Token]]: 27 | """ 28 | Batches together tokenization of several texts, in case that is faster for particular 29 | tokenizers. 30 | """ 31 | raise NotImplementedError 32 | 33 | def tokenize(self, text: str) -> List[Token]: 34 | """ 35 | Actually implements splitting words into tokens. 36 | 37 | Returns 38 | ------- 39 | tokens : ``List[Token]`` 40 | """ 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /stog/data/tokenizers/word_stemmer.py: -------------------------------------------------------------------------------- 1 | from nltk.stem import PorterStemmer as NltkPorterStemmer 2 | from overrides import overrides 3 | 4 | from stog.utils.registrable import Registrable 5 | from stog.data.tokenizers.token import Token 6 | 7 | 8 | class WordStemmer(Registrable): 9 | """ 10 | A ``WordStemmer`` lemmatizes words. This means that we map words to their root form, so that, 11 | e.g., "have", "has", and "had" all have the same internal representation. 12 | 13 | You should think carefully about whether and how much stemming you want in your model. Kind of 14 | the whole point of using word embeddings is so that you don't have to do this, but in a highly 15 | inflected language, or in a low-data setting, you might need it anyway. The default 16 | ``WordStemmer`` does nothing, just returning the word token as-is. 17 | """ 18 | default_implementation = 'pass_through' 19 | 20 | def stem_word(self, word: Token) -> Token: 21 | """ 22 | Returns a new ``Token`` with ``word.text`` replaced by a stemmed word. 23 | """ 24 | raise NotImplementedError 25 | 26 | 27 | @WordStemmer.register('pass_through') 28 | class PassThroughWordStemmer(WordStemmer): 29 | """ 30 | Does not stem words; it's a no-op. This is the default word stemmer. 31 | """ 32 | 33 | def stem_word(self, word: Token) -> Token: 34 | return word 35 | 36 | 37 | @WordStemmer.register('porter') 38 | class PorterStemmer(WordStemmer): 39 | """ 40 | Uses NLTK's PorterStemmer to stem words.
41 | """ 42 | def __init__(self): 43 | self.stemmer = NltkPorterStemmer() 44 | 45 | 46 | def stem_word(self, word: Token) -> Token: 47 | new_text = self.stemmer.stem(word.text) 48 | return Token(text=new_text, 49 | idx=word.idx, 50 | lemma=word.lemma_, 51 | pos=word.pos_, 52 | tag=word.tag_, 53 | dep=word.dep_, 54 | ent_type=word.ent_type_, 55 | text_id=getattr(word, 'text_id', None)) 56 | -------------------------------------------------------------------------------- /stog/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .attachment_score import AttachmentScores 4 | 5 | from stog.utils import logging 6 | 7 | 8 | logger = logging.init_logger() 9 | 10 | 11 | def dump_metrics(file_path: str, metrics, log: bool = False) -> None: 12 | metrics_json = json.dumps(metrics, indent=2) 13 | with open(file_path, "w") as metrics_file: 14 | metrics_file.write(metrics_json) 15 | if log: 16 | logger.info("Metrics: %s", metrics_json) 17 | -------------------------------------------------------------------------------- /stog/metrics/metric.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | import torch 3 | 4 | 5 | 6 | class Metric: 7 | """ 8 | A very general abstract class representing a metric which can be 9 | accumulated. 10 | """ 11 | def __call__(self, 12 | predictions: torch.Tensor, 13 | gold_labels: torch.Tensor, 14 | mask: Optional[torch.Tensor]): 15 | """ 16 | Parameters 17 | ---------- 18 | predictions : ``torch.Tensor``, required. 19 | A tensor of predictions. 20 | gold_labels : ``torch.Tensor``, required. 21 | A tensor corresponding to some gold label to evaluate against. 22 | mask: ``torch.Tensor``, optional (default = None). 23 | A mask can be passed, in order to deal with metrics which are 24 | computed over potentially padded elements, such as sequence labels. 25 | """ 26 | raise NotImplementedError 27 | 28 | def get_metric(self, reset: bool) -> Union[float, Tuple[float, ...], Dict[str, float]]: 29 | """ 30 | Compute and return the metric. Optionally also call :func:`self.reset`. 31 | """ 32 | raise NotImplementedError 33 | 34 | def reset(self) -> None: 35 | """ 36 | Reset any accumulators or internal state. 37 | """ 38 | raise NotImplementedError 39 | 40 | @staticmethod 41 | def unwrap_to_tensors(*tensors: torch.Tensor): 42 | """ 43 | If you actually passed gradient-tracking Tensors to a Metric, there will be 44 | a huge memory leak, because it will prevent garbage collection for the computation 45 | graph. This method ensures that you're using tensors directly and that they are on 46 | the CPU. 
47 | """ 48 | return (x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in tensors) 49 | -------------------------------------------------------------------------------- /stog/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .stog import STOG 2 | from .model import Model 3 | -------------------------------------------------------------------------------- /stog/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmented_lstm import AugmentedLstm 2 | from .encoder_base import _EncoderBase 3 | from .input_variational_dropout import InputVariationalDropout 4 | from .optimizer import MultipleOptimizer, Optimizer 5 | from .stacked_bilstm import StackedBidirectionalLstm 6 | from .stacked_lstm import StackedLstm 7 | from .time_distributed import TimeDistributed 8 | -------------------------------------------------------------------------------- /stog/modules/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .dot_production_attention import DotProductAttention 2 | from .biaffine_attention import BiaffineAttention 3 | from .mlp_attention import MLPAttention 4 | -------------------------------------------------------------------------------- /stog/modules/attention/dot_production_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DotProductAttention(torch.nn.Module): 5 | 6 | def __init__(self, decoder_hidden_size, encoder_hidden_size, share_linear=True): 7 | super(DotProductAttention, self).__init__() 8 | self.decoder_hidden_size = decoder_hidden_size 9 | self.encoder_hidden_size = encoder_hidden_size 10 | self.linear_layer = torch.nn.Linear(decoder_hidden_size, encoder_hidden_size, bias=False) 11 | self.share_linear = share_linear 12 | 13 | def forward(self, decoder_input, encoder_input): 14 | """ 15 | :param decoder_input: [batch, decoder_seq_length, decoder_hidden_size] 16 | :param encoder_input: [batch, encoder_seq_length, encoder_hidden_size] 17 | :return: [batch, decoder_seq_length, encoder_seq_length] 18 | """ 19 | decoder_input = self.linear_layer(decoder_input) 20 | if self.share_linear: 21 | encoder_input = self.linear_layer(encoder_input) 22 | 23 | encoder_input = encoder_input.transpose(1, 2) 24 | return torch.bmm(decoder_input, encoder_input) 25 | -------------------------------------------------------------------------------- /stog/modules/attention_layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/modules/attention_layers/__init__.py -------------------------------------------------------------------------------- /stog/modules/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/modules/decoders/__init__.py -------------------------------------------------------------------------------- /stog/modules/decoders/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from stog.metrics.seq2seq_metrics import Seq2SeqMetrics 4 | 5 | class Generator(torch.nn.Module): 6 | 7 | def __init__(self, input_size, vocab_size, pad_idx): 8 | 
super(Generator, self).__init__() 9 | self._generator = torch.nn.Sequential( 10 | torch.nn.Linear(input_size, vocab_size), 11 | torch.nn.LogSoftmax(dim=-1) 12 | ) 13 | self.criterion = torch.nn.NLLLoss( 14 | ignore_index=pad_idx, reduction='sum' 15 | ) 16 | self.metrics = Seq2SeqMetrics() 17 | self.pad_idx = pad_idx 18 | 19 | def forward(self, inputs): 20 | """Transform inputs to vocab-size space and compute logits. 21 | 22 | :param inputs: [batch, seq_length, input_size] 23 | :return: [batch, seq_length, vocab_size] 24 | """ 25 | batch_size, seq_length, _ = inputs.size() 26 | inputs = inputs.view(batch_size * seq_length, -1) 27 | scores = self._generator(inputs) 28 | scores = scores.view(batch_size, seq_length, -1) 29 | _, predictions = scores.max(2) 30 | return dict( 31 | scores=scores, 32 | predictions=predictions 33 | ) 34 | 35 | def compute_loss(self, inputs, targets): 36 | batch_size, seq_length, _ = inputs.size() 37 | output = self(inputs) 38 | scores = output['scores'].view(batch_size * seq_length, -1) 39 | predictions = output['predictions'].view(-1) 40 | targets = targets.view(-1) 41 | 42 | loss = self.criterion(scores, targets) 43 | 44 | non_pad = targets.ne(self.pad_idx) 45 | num_correct = predictions.eq(targets).masked_select(non_pad).sum().item() 46 | num_non_pad = non_pad.sum().item() 47 | self.metrics(loss.item(), num_non_pad, num_correct) 48 | 49 | return dict( 50 | loss=loss.div(float(num_non_pad)), 51 | predictions=output['predictions'] 52 | ) 53 | 54 | @classmethod 55 | def from_params(cls, params): 56 | return cls( 57 | input_size=params['input_size'], 58 | vocab_size=params['vocab_size'], 59 | pad_idx=params['pad_idx'] 60 | ) 61 | -------------------------------------------------------------------------------- /stog/modules/input_variational_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class InputVariationalDropout(torch.nn.Dropout): 4 | """ 5 | Apply the dropout technique in Gal and Ghahramani, "Dropout as a Bayesian Approximation: 6 | Representing Model Uncertainty in Deep Learning" (https://arxiv.org/abs/1506.02142) to a 7 | 3D tensor. 8 | This module accepts a 3D tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` 9 | and samples a single dropout mask of shape ``(batch_size, embedding_dim)`` and applies 10 | it to every time step. 11 | """ 12 | def forward(self, input_tensor): 13 | # pylint: disable=arguments-differ 14 | """ 15 | Apply dropout to input tensor. 16 | Parameters 17 | ---------- 18 | input_tensor: ``torch.FloatTensor`` 19 | A tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` 20 | Returns 21 | ------- 22 | output: ``torch.FloatTensor`` 23 | A tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` with dropout applied. 
24 | """ 25 | ones = input_tensor.data.new_ones(input_tensor.shape[0], input_tensor.shape[-1]) 26 | dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False) 27 | if self.inplace: 28 | input_tensor *= dropout_mask.unsqueeze(1) 29 | return None 30 | else: 31 | return dropout_mask.unsqueeze(1) * input_tensor 32 | -------------------------------------------------------------------------------- /stog/modules/linear/__init__.py: -------------------------------------------------------------------------------- 1 | from stog.modules.linear.bilinear import BiLinear 2 | -------------------------------------------------------------------------------- /stog/modules/seq2seq_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from stog.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import PytorchSeq2SeqWrapper 2 | from stog.modules.seq2seq_encoders.seq2seq_bert_encoder import Seq2SeqBertEncoder 3 | -------------------------------------------------------------------------------- /stog/modules/seq2seq_encoders/seq2seq_encoder.py: -------------------------------------------------------------------------------- 1 | from stog.modules.encoder_base import _EncoderBase 2 | 3 | 4 | class Seq2SeqEncoder(_EncoderBase): 5 | """ 6 | Adopted from AllenNLP: 7 | https://github.com/allenai/allennlp/blob/v0.6.1/allennlp/modules/seq2seq_encoders/seq2seq_encoder.py 8 | 9 | A ``Seq2SeqEncoder`` is a ``Module`` that takes as input a sequence of vectors and returns a 10 | modified sequence of vectors. Input shape: ``(batch_size, sequence_length, input_dim)``; output 11 | shape: ``(batch_size, sequence_length, output_dim)``. 12 | We add two methods to the basic ``Module`` API: :func:`get_input_dim()` and :func:`get_output_dim()`. 13 | You might need this if you want to construct a ``Linear`` layer using the output of this encoder, 14 | or to raise sensible errors for mis-matching input dimensions. 15 | """ 16 | def get_input_dim(self) -> int: 17 | """ 18 | Returns the dimension of the vector input for each element in the sequence input 19 | to a ``Seq2SeqEncoder``. This is `not` the shape of the input tensor, but the 20 | last element of that shape. 21 | """ 22 | raise NotImplementedError 23 | 24 | def get_output_dim(self) -> int: 25 | """ 26 | Returns the dimension of each vector in the sequence output by this ``Seq2SeqEncoder``. 27 | This is `not` the shape of the returned tensor, but the last element of that shape. 28 | """ 29 | raise NotImplementedError 30 | 31 | def is_bidirectional(self) -> bool: 32 | """ 33 | Returns ``True`` if this encoder is bidirectional. If so, we assume the forward direction 34 | of the encoder is the first half of the final dimension, and the backward direction is the 35 | second half. 
36 | """ 37 | raise NotImplementedError 38 | -------------------------------------------------------------------------------- /stog/modules/seq2vec_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /stog/modules/seq2vec_encoders/boe_encoder.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | import torch 4 | 5 | from stog.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder 6 | from stog.utils.nn import get_lengths_from_binary_sequence_mask 7 | 8 | class BagOfEmbeddingsEncoder(Seq2VecEncoder): 9 | """ 10 | A ``BagOfEmbeddingsEncoder`` is a simple :class:`Seq2VecEncoder` which simply sums the embeddings 11 | of a sequence across the time dimension. The input to this module is of shape ``(batch_size, num_tokens, 12 | embedding_dim)``, and the output is of shape ``(batch_size, embedding_dim)``. 13 | 14 | Parameters 15 | ---------- 16 | embedding_dim: ``int`` 17 | This is the input dimension to the encoder. 18 | averaged: ``bool``, optional (default=``False``) 19 | If ``True``, this module will average the embeddings across time, rather than simply summing 20 | (ie. we will divide the summed embeddings by the length of the sentence). 21 | """ 22 | def __init__(self, 23 | embedding_dim: int, 24 | averaged: bool = False) -> None: 25 | super(BagOfEmbeddingsEncoder, self).__init__() 26 | self._embedding_dim = embedding_dim 27 | self._averaged = averaged 28 | 29 | 30 | def get_input_dim(self) -> int: 31 | return self._embedding_dim 32 | 33 | 34 | def get_output_dim(self) -> int: 35 | return self._embedding_dim 36 | 37 | def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None): #pylint: disable=arguments-differ 38 | if mask is not None: 39 | tokens = tokens * mask.unsqueeze(-1).float() 40 | 41 | # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens` 42 | # dimension. 43 | summed = tokens.sum(1) 44 | 45 | if self._averaged: 46 | if mask is not None: 47 | lengths = get_lengths_from_binary_sequence_mask(mask) 48 | length_mask = (lengths > 0) 49 | 50 | # Set any length 0 to 1, to avoid dividing by zero. 51 | lengths = torch.max(lengths, lengths.new_ones(1)) 52 | else: 53 | lengths = tokens.new_full((1,), fill_value=tokens.size(1)) 54 | length_mask = None 55 | 56 | summed = summed / lengths.unsqueeze(-1).float() 57 | 58 | if length_mask is not None: 59 | summed = summed * (length_mask > 0).float().unsqueeze(-1) 60 | 61 | return summed 62 | -------------------------------------------------------------------------------- /stog/modules/seq2vec_encoders/seq2vec_encoder.py: -------------------------------------------------------------------------------- 1 | from stog.modules.encoder_base import _EncoderBase 2 | 3 | 4 | class Seq2VecEncoder(_EncoderBase): 5 | """ 6 | A ``Seq2VecEncoder`` is a ``Module`` that takes as input a sequence of vectors and returns a 7 | single vector. Input shape: ``(batch_size, sequence_length, input_dim)``; output shape: 8 | ``(batch_size, output_dim)``. 9 | 10 | We add two methods to the basic ``Module`` API: :func:`get_input_dim()` and :func:`get_output_dim()`. 11 | You might need this if you want to construct a ``Linear`` layer using the output of this encoder, 12 | or to raise sensible errors for mis-matching input dimensions. 
13 | """ 14 | def get_input_dim(self) -> int: 15 | """ 16 | Returns the dimension of the vector input for each element in the sequence input 17 | to a ``Seq2VecEncoder``. This is `not` the shape of the input tensor, but the 18 | last element of that shape. 19 | """ 20 | raise NotImplementedError 21 | 22 | def get_output_dim(self) -> int: 23 | """ 24 | Returns the dimension of the final vector output by this ``Seq2VecEncoder``. This is `not` 25 | the shape of the returned tensor, but the last element of that shape. 26 | """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /stog/modules/text_field_embedders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A :class:`~allennlp.modules.text_field_embedders.text_field_embedder.TextFieldEmbedder` 3 | is a ``Module`` that takes as input the ``dict`` of NumPy arrays 4 | produced by a :class:`~allennlp.data.fields.text_field.TextField` and 5 | returns as output an embedded representation of the tokens in that field. 6 | """ 7 | 8 | from .text_field_embedder import TextFieldEmbedder 9 | from .basic_text_field_embedder import BasicTextFieldEmbedder 10 | -------------------------------------------------------------------------------- /stog/modules/time_distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from AllenNLP: 3 | https://github.com/allenai/allennlp/blob/v0.6.1/allennlp/modules/time_distributed.py 4 | 5 | A wrapper that unrolls the second (time) dimension of a tensor 6 | into the first (batch) dimension, applies some other ``Module``, 7 | and then rolls the time dimension back up. 8 | """ 9 | 10 | import torch 11 | 12 | 13 | class TimeDistributed(torch.nn.Module): 14 | """ 15 | Given an input shaped like ``(batch_size, time_steps, [rest])`` and a ``Module`` that takes 16 | inputs like ``(batch_size, [rest])``, ``TimeDistributed`` reshapes the input to be 17 | ``(batch_size * time_steps, [rest])``, applies the contained ``Module``, then reshapes it back. 18 | Note that while the above gives shapes with ``batch_size`` first, this ``Module`` also works if 19 | ``batch_size`` is second - we always just combine the first two dimensions, then split them. 20 | """ 21 | def __init__(self, module): 22 | super(TimeDistributed, self).__init__() 23 | self._module = module 24 | 25 | def forward(self, *inputs): # pylint: disable=arguments-differ 26 | reshaped_inputs = [] 27 | for input_tensor in inputs: 28 | input_size = input_tensor.size() 29 | if len(input_size) <= 2: 30 | raise RuntimeError("No dimension to distribute: " + str(input_size)) 31 | 32 | # Squash batch_size and time_steps into a single axis; result has shape 33 | # (batch_size * time_steps, input_size). 34 | squashed_shape = [-1] + [x for x in input_size[2:]] 35 | reshaped_inputs.append(input_tensor.contiguous().view(*squashed_shape)) 36 | 37 | reshaped_outputs = self._module(*reshaped_inputs) 38 | 39 | # Now get the output back into the right shape. 
40 | # (batch_size, time_steps, [hidden_size]) 41 | new_shape = [input_size[0], input_size[1]] + [x for x in reshaped_outputs.size()[1:]] 42 | outputs = reshaped_outputs.contiguous().view(*new_shape) 43 | 44 | return outputs 45 | -------------------------------------------------------------------------------- /stog/modules/token_embedders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A :class:`~stog.modules.embedding.token_embedders.token_embedder.TokenEmbedder` is a ``Module`` that 3 | embeds one-hot-encoded tokens as vectors. 4 | """ 5 | 6 | from stog.modules.token_embedders.embedding import Embedding 7 | #from stog.modules.token_embedders.elmo_token_embedder import ElmoTokenEmbedder 8 | -------------------------------------------------------------------------------- /stog/modules/token_embedders/token_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class TokenEmbedder(torch.nn.Module): 4 | """ 5 | A ``TokenEmbedder`` is a ``Module`` that takes as input a tensor with integer ids that have 6 | been output from a :class:`~allennlp.data.TokenIndexer` and outputs a vector per token in the 7 | input. The input typically has shape ``(batch_size, num_tokens)`` or ``(batch_size, 8 | num_tokens, num_characters)``, and the output is of shape ``(batch_size, num_tokens, 9 | output_dim)``. The simplest ``TokenEmbedder`` is just an embedding layer, but for 10 | character-level input, it could also be some kind of character encoder. 11 | 12 | We add a single method to the basic ``Module`` API: :func:`get_output_dim()`. This lets us 13 | more easily compute output dimensions for the :class:`~allennlp.modules.TextFieldEmbedder`, 14 | which we might need when defining model parameters such as LSTMs or linear layers, which need 15 | to know their input dimension before the layers are called. 16 | """ 17 | default_implementation = "embedding" 18 | 19 | def get_output_dim(self) -> int: 20 | """ 21 | Returns the final output dimension that this ``TokenEmbedder`` uses to represent each 22 | token. This is `not` the shape of the returned tensor, but the last element of that shape. 23 | """ 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /stog/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from stog.predictors.predictor import Predictor 2 | from stog.predictors.stog import STOGPredictor 3 | -------------------------------------------------------------------------------- /stog/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/stog/training/__init__.py -------------------------------------------------------------------------------- /stog/training/tensorboard.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | 7 | class TensorboardWriter: 8 | """ 9 | Wraps a pair of ``SummaryWriter`` instances but is a no-op if they're ``None``. 10 | Allows Tensorboard logging without always checking for Nones first. 
11 | """ 12 | def __init__(self, train_log=None, dev_log=None) -> None: 13 | self._train_log = SummaryWriter(train_log) if train_log is not None else None 14 | self._dev_log = SummaryWriter(dev_log) if dev_log is not None else None 15 | 16 | @staticmethod 17 | def _item(value: Any): 18 | if hasattr(value, 'item'): 19 | val = value.item() 20 | else: 21 | val = value 22 | return val 23 | 24 | def add_train_scalar(self, name: str, value: float, global_step: int) -> None: 25 | # get the scalar 26 | if self._train_log is not None: 27 | self._train_log.add_scalar(name, self._item(value), global_step) 28 | 29 | def add_train_histogram(self, name: str, values: torch.Tensor, global_step: int) -> None: 30 | if self._train_log is not None: 31 | if isinstance(values, torch.Tensor): 32 | values_to_write = values.cpu().data.numpy().flatten() 33 | self._train_log.add_histog.am(name, values_to_write, global_step) 34 | 35 | def add_dev_scalar(self, name: str, value: float, global_step: int) -> None: 36 | 37 | if self._dev_log is not None: 38 | self._dev_log.add_scalar(name, self._item(value), global_step) 39 | 40 | -------------------------------------------------------------------------------- /stog/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .exception_hook import ExceptionHook 2 | from typing import Iterator, List, TypeVar, Iterable, Dict 3 | import random 4 | from itertools import zip_longest, islice 5 | 6 | A = TypeVar('A') 7 | def lazy_groups_of(iterator: Iterator[A], group_size: int) -> Iterator[List[A]]: 8 | """ 9 | Takes an iterator and batches the invididual instances into lists of the 10 | specified size. The last list may be smaller if there are instances left over. 11 | """ 12 | return iter(lambda: list(islice(iterator, 0, group_size)), []) 13 | 14 | def ensure_list(iterable: Iterable[A]) -> List[A]: 15 | """ 16 | An Iterable may be a list or a generator. 17 | This ensures we get a list without making an unnecessary copy. 18 | """ 19 | if isinstance(iterable, list): 20 | return iterable 21 | else: 22 | return list(iterable) 23 | 24 | def is_lazy(iterable: Iterable[A]) -> bool: 25 | """ 26 | Checks if the given iterable is lazy, 27 | which here just means it's not a list. 28 | """ 29 | return not isinstance(iterable, list) 30 | 31 | def add_noise_to_dict_values(dictionary: Dict[A, float], noise_param: float) -> Dict[A, float]: 32 | """ 33 | Returns a new dictionary with noise added to every key in ``dictionary``. The noise is 34 | uniformly distributed within ``noise_param`` percent of the value for every value in the 35 | dictionary. 36 | """ 37 | new_dict = {} 38 | for key, value in dictionary.items(): 39 | noise_value = value * noise_param 40 | noise = random.uniform(-noise_value, noise_value) 41 | new_dict[key] = value + noise 42 | return new_dict 43 | -------------------------------------------------------------------------------- /stog/utils/checks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from AllenNLP: 3 | https://github.com/allenai/allennlp/tree/v0.6.1/allennlp/common 4 | 5 | Functions and exceptions for checking that 6 | AllenNLP and its models are configured correctly. 
7 | """ 8 | 9 | from torch import cuda 10 | 11 | from stog.utils import logging 12 | 13 | logger = logging.init_logger() # pylint: disable=invalid-name 14 | 15 | 16 | class ConfigurationError(Exception): 17 | """ 18 | The exception raised by any AllenNLP object when it's misconfigured 19 | (e.g. missing properties, invalid properties, unknown properties). 20 | """ 21 | 22 | def __init__(self, message): 23 | super(ConfigurationError, self).__init__() 24 | self.message = message 25 | 26 | def __str__(self): 27 | return repr(self.message) 28 | 29 | 30 | def log_pytorch_version_info(): 31 | import torch 32 | logger.info("Pytorch version: %s", torch.__version__) 33 | 34 | 35 | def check_dimensions_match(dimension_1: int, 36 | dimension_2: int, 37 | dim_1_name: str, 38 | dim_2_name: str) -> None: 39 | if dimension_1 != dimension_2: 40 | raise ConfigurationError(f"{dim_1_name} must match {dim_2_name}, but got {dimension_1} " 41 | f"and {dimension_2} instead") 42 | 43 | 44 | def check_for_gpu(device_id: int): 45 | if device_id is not None and device_id >= cuda.device_count(): 46 | raise ConfigurationError("Experiment specified a GPU but none is available;" 47 | " if you want to run on CPU use the override" 48 | " 'trainer.cuda_device=-1' in the json config file.") 49 | -------------------------------------------------------------------------------- /stog/utils/exception_hook.py: -------------------------------------------------------------------------------- 1 | class ExceptionHook: 2 | instance = None 3 | def __call__(self, *args, **kwargs): 4 | if self.instance is None: 5 | from IPython.core import ultratb 6 | self.instance = ultratb.FormattedTB(mode="Plain", color_scheme="Linux", call_pdb=1) 7 | return self.instance(*args, **kwargs) 8 | 9 | -------------------------------------------------------------------------------- /stog/utils/extract_tokens_from_amr.py: -------------------------------------------------------------------------------- 1 | from stog.data.dataset_readers import AbstractMeaningRepresentationDatasetReader 2 | import sys 3 | from stog.utils import logging 4 | 5 | logger = logging.init_logger() 6 | def extract_amr_token(file_path): 7 | dataset_reader = AbstractMeaningRepresentationDatasetReader() 8 | for instance in dataset_reader.read(file_path): 9 | amr_tokens = instance.fields["amr_tokens"]["decoder_tokens"] 10 | yield " ".join(amr_tokens) 11 | 12 | 13 | if __name__ == "__main__": 14 | if len(sys.argv) < 2: 15 | print("""Usage: 16 | python {} [amr_file] 17 | 18 | The output will in stdout. 19 | """) 20 | for filename in sys.argv[1:]: 21 | for line in extract_amr_token(filename): 22 | sys.stdout.write(line + "\n") 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /stog/utils/time.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | def time_to_str(timestamp: int) -> str: 5 | """ 6 | Convert seconds past Epoch to human readable string. 7 | """ 8 | datetimestamp = datetime.datetime.fromtimestamp(timestamp) 9 | return '{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format( 10 | datetimestamp.year, datetimestamp.month, datetimestamp.day, 11 | datetimestamp.hour, datetimestamp.minute, datetimestamp.second 12 | ) 13 | 14 | 15 | def str_to_time(time_str: str) -> datetime.datetime: 16 | """ 17 | Convert human readable string to datetime.datetime. 
18 | """ 19 | pieces = [int(piece) for piece in time_str.split('-')] 20 | return datetime.datetime(*pieces) 21 | -------------------------------------------------------------------------------- /stog/utils/tqdm.py: -------------------------------------------------------------------------------- 1 | """ 2 | :class:`~allennlp.common.tqdm.Tqdm` wraps tqdm so we can add configurable 3 | global defaults for certain tqdm parameters. 4 | 5 | Adopted from AllenNLP: 6 | https://github.com/allenai/allennlp/blob/v0.6.1/allennlp/common/tqdm.py 7 | """ 8 | 9 | from tqdm import tqdm as _tqdm 10 | # This is neccesary to stop tqdm from hanging 11 | # when exceptions are raised inside iterators. 12 | # It should have been fixed in 4.2.1, but it still 13 | # occurs. 14 | # TODO(Mark): Remove this once tqdm cleans up after itself properly. 15 | # https://github.com/tqdm/tqdm/issues/469 16 | _tqdm.monitor_interval = 0 17 | 18 | class Tqdm: 19 | # These defaults are the same as the argument defaults in tqdm. 20 | default_mininterval: float = 0.1 21 | 22 | @staticmethod 23 | def set_default_mininterval(value: float) -> None: 24 | Tqdm.default_mininterval = value 25 | 26 | @staticmethod 27 | def set_slower_interval(use_slower_interval: bool) -> None: 28 | """ 29 | If ``use_slower_interval`` is ``True``, we will dramatically slow down ``tqdm's`` default 30 | output rate. ``tqdm's`` default output rate is great for interactively watching progress, 31 | but it is not great for log files. You might want to set this if you are primarily going 32 | to be looking at output through log files, not the terminal. 33 | """ 34 | if use_slower_interval: 35 | Tqdm.default_mininterval = 10.0 36 | else: 37 | Tqdm.default_mininterval = 0.1 38 | 39 | @staticmethod 40 | def tqdm(*args, **kwargs): 41 | new_kwargs = { 42 | 'mininterval': Tqdm.default_mininterval, 43 | **kwargs 44 | } 45 | 46 | return _tqdm(*args, **new_kwargs) 47 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/.Rhistory: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/tools/amr-evaluation-tool-enhanced/.Rhistory -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/README.md: -------------------------------------------------------------------------------- 1 | # amr-evaluation-enhanced (this is a variant of https://github.com/mdtux89/amr-evaluation) 2 | 3 | Evaluation metrics to compare AMR graphs based on Smatch (http://amr.isi.edu/evaluation.html). The script computes a set of metrics between AMR graphs in addition to the traditional Smatch code: 4 | 5 | * **Unlabeled**(differ): Smatch score computed on the predicted graphs after (canonicalizing direction and) removing all edge labels 6 | * No WSD. Smatch score while ignoring Propbank senses (e.g., duck-01 vs duck-02) 7 | * Named Ent. F-score on the named entity recognition (:name roles) 8 | * **Non_sense_frames**(new). F-score on Propbank frame identification without sense (e.g. duck-00) 9 | * **Frames**(new). F-score on Propbank frame identification without sense (e.g. duck-01) 10 | * Wikification. F-score on the wikification (:wiki roles) 11 | * Negations. F-score on the negation detection (:polarity roles) 12 | * Concepts. F-score on the concept identification task 13 | * Reentrancy. Smatch computed on reentrant edges only 14 | * SRL. 
Smatch computed on :ARG-i roles only 15 | 16 | The different metrics were introduced in the paper below, which also uses them to evaluate several AMR parsers: 17 | 18 | "An Incremental Parser for Abstract Meaning Representation", Marco Damonte, Shay B. Cohen and Giorgio Satta. Proceedings of EACL (2017). URL: https://arxiv.org/abs/1608.06111 19 | 20 | **(Some of the metrics were recently fixed and updated)** 21 | 22 | **Usage:** ```./evaluation.sh <parsed data> <gold data>```, 23 | where <parsed data> and <gold data> are two files which contain multiple AMRs. A blank line is used to separate two AMRs (same format required by Smatch). 24 | 25 | In the paper we also discuss a metric for noun phrase analysis. To compute this metric: 26 | 27 | - ```./preprocessing.sh <gold dataset>``` and ```python extract_np.py <gold dataset>``` to extract the noun phrases from your gold dataset. This will create two files: ```np_sents.txt``` and ```np_graphs.txt```. 28 | - Parse ```np_sents.txt``` with the AMR parser and evaluate with Smatch ```python smatch/smatch.py --pr -f <parsed np file> np_graphs.txt``` 29 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/alignments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | ''' 5 | Definition of Alignments class. For each sentence, computes the list of node variables that are aligned 6 | to each index in the sentence, assuming alignments in the format returned by JAMR 7 | 8 | @author: Marco Damonte (m.damonte@sms.ed.ac.uk) 9 | @since: 03-10-16 10 | ''' 11 | 12 | import smatch.amr_edited as amr_annot 13 | from collections import defaultdict 14 | 15 | class Alignments: 16 | 17 | def _traverse(self, parsed_amr, amr): 18 | triples = parsed_amr.get_triples3() 19 | triples2 = [] 20 | root = None 21 | for i in range (0, len(triples)): 22 | rel = triples[i] 23 | if rel[1] == "TOP": 24 | triples2.append(("TOP",":top",rel[0])) 25 | root = rel[0] 26 | elif rel not in [r for r in parsed_amr.reent if r[2] in parsed_amr.nodes]: 27 | triples2.append((rel[0],":" + rel[1],rel[2])) 28 | indexes = {} 29 | queue = [] 30 | visited = [] 31 | queue.append((root, "0")) 32 | while len(queue) > 0: 33 | (node, prefix) = queue.pop(0) 34 | if node in visited: 35 | continue 36 | indexes[prefix] = node 37 | if node in parsed_amr.nodes: 38 | visited.append(node) 39 | children = [t for t in triples2 if str(t[0]) == node] 40 | i = 0 41 | for c in children: 42 | v = str(c[2]) 43 | queue.append((v, prefix + "."
+ str(i))) 44 | i += 1 45 | return indexes 46 | 47 | 48 | def __init__(self, alignments_filename, graphs): 49 | self.alignments = [] 50 | for g, line in zip(graphs,open(alignments_filename)): 51 | amr = g.strip() 52 | parsed_amr = amr_annot.AMR.parse_AMR_line(amr.replace("\n",""), False) 53 | line = line.strip() 54 | indexes = self._traverse(parsed_amr, amr) 55 | al = defaultdict(list) 56 | if line != "": 57 | for a in line.split(" "): 58 | if a.strip() == "": 59 | continue 60 | start = a.split("|")[0].split("-")[0] 61 | if start[0] == "*": 62 | start = start[1:] 63 | end = a.split("|")[0].split("-")[1] 64 | for i in range(int(start),int(end)): 65 | for segment in a.split("|")[1].split("+"): 66 | al[i].append(indexes[segment]) 67 | self.alignments.append(al) 68 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Evaluation script. Run as: ./evaluation.sh 4 | out=`/usr/bin/python2 smatch/smatch.py --pr -f "$1" "$2"` 5 | out=($out) 6 | echo 'Smatch -> P: '${out[1]}', R: '${out[3]}', F: '${out[6]} | sed 's/.$//' 7 | p=$$_pred.tmp 8 | g=$$_gold.tmp 9 | python unlabel.py "$1" > $p 10 | python unlabel.py "$2" > $g 11 | 12 | out=`/usr/bin/python2 smatch/smatch.py --pr -f $p $g` 13 | out=($out) 14 | echo 'Unlabeled -> P: '${out[1]}', R: '${out[3]}', F: '${out[6]} | sed 's/.$//' 15 | 16 | cat "$1" | perl -ne 's/(\/ [a-zA-Z0-9\-][a-zA-Z0-9\-]*)-[0-9][0-9]*/\1-01/g; print;' > $p 17 | cat "$2" | perl -ne 's/(\/ [a-zA-Z0-9\-][a-zA-Z0-9\-]*)-[0-9][0-9]*/\1-01/g; print;' > $g 18 | out=`/usr/bin/python2 smatch/smatch.py --pr -f $p $g` 19 | out=($out) 20 | echo 'No WSD -> P: '${out[1]}', R: '${out[3]}', F: '${out[6]} | sed 's/.$//' 21 | 22 | cat "$1" | perl -ne 's/^#.*\n//g; print;' | tr '\t' ' ' | tr -s ' ' > $p 23 | cat "$2" | perl -ne 's/^#.*\n//g; print;' | tr '\t' ' ' | tr -s ' ' > $g 24 | /usr/bin/python2 scores.py "$p" "$g" 25 | 26 | rm $p 27 | rm $g 28 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/evaluation_label.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Evaluation script. 
Run as: ./evaluation.sh 4 | p=$$_pred.tmp 5 | g=$$_gold.tmp 6 | python unlabel.py "$1" >$p 7 | python unlabel.py "$2" >$g 8 | 9 | out=$(/usr/bin/python2 smatch/smatch.py --pr -f $p $g) 10 | out=($out) 11 | echo 'Unlabeled -> P: '${out[1]}', R: '${out[3]}', F: '${out[6]} | sed 's/.$//' 12 | 13 | rm $p 14 | rm $g 15 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/onelabel.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | match_of = re.compile(":[0-9a-zA-Z]*-of") 4 | match_no_of = re.compile(":[0-9a-zA-Z]*(?!-of)") 5 | 6 | placeholder = '<`_placeholder_`>' 7 | 8 | 9 | def readFile(filepath, keeplabel): 10 | with open(filepath, 'r') as content_file: 11 | content = content_file.read() 12 | amr_t = content 13 | assert placeholder not in amr_t, 'conflicting placeholder' 14 | amr_t = amr_t.replace(keeplabel + ' ', placeholder) 15 | amr_t = re.sub(match_no_of, ":label", amr_t) 16 | amr_t = re.sub(match_of, ":label-of", amr_t) 17 | amr_t = amr_t.replace(placeholder, keeplabel + ' ') 18 | print(amr_t) 19 | 20 | 21 | import sys 22 | 23 | readFile(sys.argv[1], sys.argv[2]) 24 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/.filt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/tools/amr-evaluation-tool-enhanced/smatch/.filt -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/.output_jamr.txt.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/tools/amr-evaluation-tool-enhanced/smatch/.output_jamr.txt.swp -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2015 Shu Cai and Kevin Knight 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/sample_file_list: -------------------------------------------------------------------------------- 1 | nw_wsj_0001_1 nw_wsj_0001_2 nw_wsj_0002_1 nw_wsj_0003_1 2 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/smatch/update_log: -------------------------------------------------------------------------------- 1 | Update: 08/22/2012 2 | Person involved: Shu Cai 3 | 4 | Minor bug fix of smatch.py. smatch-v2.py was created. 5 | 6 | smatch.py-> smatch-v1.py 7 | smatch-v2.py-> smatch.py 8 | 9 | No change of interface 10 | 11 | Update: 09/14/2012 12 | Person involved: Shu Cai 13 | 14 | Bug fix of smatch.py and smatch-table.py. smatch-v0.1.py smatch-v0.2.py smatch-v0.3.py smatch-v0.4.py smatch-table-v0.1.py smatch-table-v0.2.py was created. 15 | 16 | smatch.py now equals to smatch-v0.4.py 17 | smatch-table.py now equals to smatch-table-v0.2.py 18 | 19 | smatch.py runs with a smart initialization, which matches words with the same value first, then randomly select other variable mappings. 4 restarts is applied. 20 | 21 | Update: 03/17/2013 22 | Person involved: Shu Cai 23 | 24 | Interface change of smatch.py and smatch-table.py. Using this version does not require esem-format-check.pl. (All versions before v0.5 require esem-format-check.pl to check the format of AMR) Instead it needs amr.py. 25 | 26 | It now accepts one-AMR-per-line format as well as other formats of AMR. 27 | 28 | smatch.py now equals to smatch-v0.5.py 29 | smatch-table.py now equals to smatch-table-v0.3.py 30 | 31 | Update:03/19/2013 32 | Person involved: Shu Cai 33 | 34 | Document update. The latest documents are smatch_guide.txt and smatch_guide.pdf (same content) 35 | Add some sample files to the directory: sample_file_list, test_input1, test_input2 36 | 37 | Update: 03/20/2013 38 | Person involved: Shu Cai 39 | 40 | Minor changes to the documents: smatch_guide.txt and smatch_guide.pdf 41 | 42 | Update: 04/04/2013 43 | Person involved: Shu Cai 44 | 45 | Add Software_architecture.pdf. Minor changes to the smatch.py and smatch-table.py (comments and add --pr option) 46 | Minor changes to the README.txt and smatch_guide.pdf 47 | 48 | Update: 01/18/2015 49 | Person involved: Shu Cai 50 | Code cleanup and bug fix. Add detailed comment to the code. 51 | Thanks Yoav Artzi (yoav@cs.washington.edu) for finding a bug and fixing it. 52 | 53 | Update: 12/21/2015 54 | Person involved: Jon May 55 | Fixed treatment of quoted strings to allow special characters to be actually part of the string. 
56 | Empty double quoted strings also allowed 57 | 58 | Update: 1/9/2016 59 | Person involved: Guntis Barzdins and Didzis Gosko 60 | Fixed small crash bug 61 | 62 | -------------------------------------------------------------------------------- /tools/amr-evaluation-tool-enhanced/unlabel.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | match_of = re.compile(":[0-9a-zA-Z]*-of") 5 | match_no_of = re.compile(":[0-9a-zA-Z]*(?!-of)") 6 | 7 | def readFile(filepath): 8 | with open(filepath, 'r') as content_file: 9 | content = content_file.read() 10 | amr_t = content 11 | amr_t = re.sub(match_no_of,":label",amr_t) 12 | amr_t = re.sub(match_of,":label-of",amr_t) 13 | print (amr_t) 14 | 15 | import sys 16 | readFile(sys.argv[1]) 17 | -------------------------------------------------------------------------------- /tools/fast_smatch/README.md: -------------------------------------------------------------------------------- 1 | Borrowed from [Oneplus/tamr](https://github.com/Oneplus/tamr/tree/master/amr_aligner/smatch) with slight modifications. 2 | 3 | This uses Cython to re-implement the smatch script, which is super faster compared to [the ordinary version](https://github.com/snowblink14/smatch) 4 | 5 | 6 | Usage: `sh compute_smatch.sh test.txt ref.txt` -------------------------------------------------------------------------------- /tools/fast_smatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/levi-graph-amr-parser/f71f1056c13181b8db31d6136451fb8d57114819/tools/fast_smatch/__init__.py -------------------------------------------------------------------------------- /tools/fast_smatch/_gain.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define _HASH_PAIR(x, y) (((x) << 14) | (y)) 6 | #define _GET_0(x) ((x) >> 14) 7 | #define _GET_1(x) ((x) & ((1 << 14) - 1)) 8 | 9 | typedef std::unordered_map > WeightDictType; 10 | typedef std::vector MappingType; 11 | 12 | int _hash_pair(int x, int y); 13 | 14 | int _get_0(int x); 15 | 16 | int _get_1(int x); 17 | 18 | int move_gain(MappingType & mapping, 19 | int node_id, 20 | int old_id, 21 | int new_id, 22 | WeightDictType & weight_dict, 23 | int match_num); 24 | 25 | int swap_gain(MappingType & mapping, 26 | int node_id1, 27 | int mapping_id1, 28 | int node_id2, 29 | int mapping_id2, 30 | WeightDictType & weight_dict, 31 | int match_num); 32 | 33 | -------------------------------------------------------------------------------- /tools/fast_smatch/compute_smatch.sh: -------------------------------------------------------------------------------- 1 | if [ ! -f "_smatch.so" ]; then 2 | echo "compiling fast smatch" 3 | python2 setup.py build 4 | mv build/*/_smatch.so . 5 | rm -rf build 6 | else 7 | echo "using fast smatch" 8 | fi 9 | PYTHONPATH=. python2 fast_smatch.py --pr -f $1 $2 10 | -------------------------------------------------------------------------------- /tools/fast_smatch/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from distutils.core import setup, Extension 3 | from Cython.Build import cythonize 4 | 5 | setup(ext_modules=cythonize(Extension("_smatch", sources=["_smatch.pyx", "_gain.cc"], language="c++",extra_compile_args=["-std=c++11"]))) 6 | --------------------------------------------------------------------------------
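For reference, the `_HASH_PAIR`, `_GET_0` and `_GET_1` macros in `tools/fast_smatch/_gain.h` above pack a (node id, mapping id) pair into a single integer with a 14-bit shift; that packed key is what `WeightDictType` is indexed by. Below is a minimal Python sketch of the same packing, added for illustration only; the names `SHIFT`, `hash_pair`, `get_0` and `get_1` are assumptions and do not exist in the repository.

```python
SHIFT = 14  # same constant as in _gain.h

def hash_pair(x: int, y: int) -> int:
    # Mirrors _HASH_PAIR(x, y) = ((x) << 14) | (y)
    return (x << SHIFT) | y

def get_0(packed: int) -> int:
    # Mirrors _GET_0(x) = (x) >> 14
    return packed >> SHIFT

def get_1(packed: int) -> int:
    # Mirrors _GET_1(x) = (x) & ((1 << 14) - 1)
    return packed & ((1 << SHIFT) - 1)

if __name__ == "__main__":
    packed = hash_pair(3, 17)  # 3 * 16384 + 17 = 49169
    assert (get_0(packed), get_1(packed)) == (3, 17)
```

Because `_GET_1` masks with `(1 << 14) - 1`, the second id must stay below 16384 (2**14) for the round trip to be exact; the same bound applies to the C++ macros.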