├── .gitignore ├── LICENSE ├── README.md ├── docs └── fig │ └── dep.png ├── elit ├── __init__.py ├── callbacks │ ├── __init__.py │ └── fine_csv_logger.py ├── common │ ├── __init__.py │ ├── cache.py │ ├── component.py │ ├── dataset.py │ ├── keras_component.py │ ├── structure.py │ ├── torch_component.py │ ├── transform.py │ ├── transform_tf.py │ ├── vocab.py │ └── vocab_tf.py ├── components │ ├── __init__.py │ ├── amr │ │ ├── __init__.py │ │ └── seq2seq │ │ │ ├── __init__.py │ │ │ ├── dataset │ │ │ ├── IO.py │ │ │ ├── __init__.py │ │ │ ├── dataset.py │ │ │ ├── linearization.py │ │ │ ├── penman.py │ │ │ ├── postprocessing.py │ │ │ ├── tokenization_bart.py │ │ │ └── tokenization_t5.py │ │ │ ├── evaluation.py │ │ │ ├── graph │ │ │ ├── __init__.py │ │ │ ├── graph_bart.py │ │ │ ├── prompt.py │ │ │ ├── seq2seq_rgcn_amr_parser.py │ │ │ └── soft_rgcn.py │ │ │ ├── linguistic_bart.py │ │ │ ├── optim.py │ │ │ ├── seq2seq_amr_parser.py │ │ │ ├── seq2seq_cfg_amr_parser.py │ │ │ ├── seq2seq_embed_amr_parser.py │ │ │ └── seq2seq_levi_amr_parser.py │ ├── classifiers │ │ ├── __init__.py │ │ ├── fasttext_classifier.py │ │ ├── transformer_classifier.py │ │ └── transformer_classifier_tf.py │ ├── detok.py │ ├── lambda_wrapper.py │ ├── lemmatizer.py │ ├── parsers │ │ ├── __init__.py │ │ ├── alg.py │ │ ├── alg_tf.py │ │ ├── biaffine │ │ │ ├── __init__.py │ │ │ ├── biaffine.py │ │ │ ├── biaffine_2nd_dep.py │ │ │ ├── biaffine_dep.py │ │ │ ├── biaffine_model.py │ │ │ ├── biaffine_sdp.py │ │ │ ├── mlp.py │ │ │ ├── structual_attention.py │ │ │ └── variationalbilstm.py │ │ ├── biaffine_parser_tf.py │ │ ├── biaffine_tf │ │ │ ├── __init__.py │ │ │ ├── alg.py │ │ │ ├── layers.py │ │ │ └── model.py │ │ ├── chu_liu_edmonds.py │ │ ├── conll.py │ │ ├── constituency │ │ │ ├── __init__.py │ │ │ ├── crf_constituency_model.py │ │ │ ├── crf_constituency_parser.py │ │ │ └── treecrf.py │ │ ├── parse_alg.py │ │ └── ud │ │ │ ├── __init__.py │ │ │ ├── lemma_edit.py │ │ │ ├── tag_decoder.py │ │ │ ├── ud_model.py │ │ │ ├── ud_parser.py │ │ │ ├── udify_util.py │ │ │ └── util.py │ ├── pipeline.py │ ├── rnn_language_model_tf.py │ └── seq2seq │ │ ├── __init__.py │ │ ├── con │ │ ├── __init__.py │ │ ├── constrained_decoding.py │ │ ├── seq2seq_con.py │ │ ├── transformers_ext.py │ │ ├── utility.py │ │ └── verbalizer.py │ │ ├── dep │ │ ├── __init__.py │ │ ├── arc_eager.py │ │ ├── arc_standard.py │ │ ├── constrained_decoding.py │ │ ├── dep_utility.py │ │ ├── seq2seq_dep.py │ │ ├── transformers_ext.py │ │ └── verbalizer.py │ │ ├── ner │ │ ├── __init__.py │ │ ├── conditional_seq2seq.py │ │ ├── conditional_seq2seq_constrained.py │ │ ├── constrained_decoding.py │ │ ├── dataset.py │ │ ├── dynamic_oracle_bart.py │ │ ├── dynamic_oracle_seq2seq_ner.py │ │ ├── dynamic_seq2seq.py │ │ ├── prompt_ner.py │ │ ├── seq2seq_ner.py │ │ └── transformers_ext.py │ │ └── pos │ │ ├── __init__.py │ │ ├── constrained_decoding.py │ │ ├── seq2seq_pos.py │ │ ├── transformers_ext.py │ │ └── verbalizer.py ├── datasets │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ └── sentiment.py │ ├── coref │ │ ├── __init__.py │ │ └── loaders │ │ │ ├── __init__.py │ │ │ └── conll12coref.py │ ├── detokenization │ │ ├── __init__.py │ │ └── detok.py │ ├── eos │ │ ├── __init__.py │ │ ├── eos.py │ │ └── loaders │ │ │ ├── __init__.py │ │ │ └── nn_eos.py │ ├── lm │ │ ├── __init__.py │ │ ├── doc_dataset.py │ │ ├── loaders │ │ │ ├── __init__.py │ │ │ └── lm_dataset.py │ │ └── sent_dataset.py │ ├── lu │ │ ├── __init__.py │ │ └── glue.py │ ├── ner │ │ ├── __init__.py │ │ ├── conll03.py │ 
│ ├── conll03_json.py │ │ ├── gazetters.py │ │ ├── loaders │ │ │ ├── __init__.py │ │ │ ├── json_ner.py │ │ │ └── tsv.py │ │ ├── msra.py │ │ ├── resume.py │ │ └── weibo.py │ ├── parsing │ │ ├── __init__.py │ │ ├── amr.py │ │ ├── ctb5.py │ │ ├── ctb7.py │ │ ├── ctb8.py │ │ ├── ctb9.py │ │ ├── loaders │ │ │ ├── __init__.py │ │ │ ├── _ctb_utils.py │ │ │ ├── conll_dataset.py │ │ │ └── constituency_dataset.py │ │ ├── pmt1.py │ │ ├── ptb.py │ │ ├── semeval15.py │ │ ├── semeval16.py │ │ └── ud │ │ │ ├── __init__.py │ │ │ ├── ud210.py │ │ │ ├── ud210m.py │ │ │ ├── ud23.py │ │ │ ├── ud23m.py │ │ │ ├── ud27.py │ │ │ ├── ud27m.py │ │ │ ├── ud28.py │ │ │ └── ud28m.py │ ├── pos │ │ ├── __init__.py │ │ ├── ctb5.py │ │ └── ptb.py │ ├── qa │ │ ├── __init__.py │ │ └── hotpotqa.py │ ├── srl │ │ ├── __init__.py │ │ ├── loaders │ │ │ ├── __init__.py │ │ │ ├── conll2012.py │ │ │ └── ontonotes_loader.py │ │ ├── ontonotes4 │ │ │ ├── __init__.py │ │ │ └── chinese.py │ │ └── ontonotes5 │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── chinese.py │ │ │ └── english.py │ ├── sts │ │ ├── __init__.py │ │ └── stsb.py │ └── tokenization │ │ ├── __init__.py │ │ ├── ctb6.py │ │ ├── loaders │ │ ├── __init__.py │ │ ├── chunking_dataset.py │ │ ├── multi_criteria_cws │ │ │ ├── __init__.py │ │ │ └── mcws_dataset.py │ │ └── txt.py │ │ └── sighan2005 │ │ ├── __init__.py │ │ ├── as_.py │ │ ├── cityu.py │ │ ├── msr.py │ │ └── pku.py ├── layers │ ├── __init__.py │ ├── cnn_encoder.py │ ├── crf │ │ ├── __init__.py │ │ ├── crf.py │ │ ├── crf_layer_tf.py │ │ └── crf_tf.py │ ├── dropout.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── char_cnn.py │ │ ├── char_cnn_tf.py │ │ ├── char_rnn.py │ │ ├── char_rnn_tf.py │ │ ├── concat_embedding.py │ │ ├── contextual_string_embedding.py │ │ ├── contextual_string_embedding_tf.py │ │ ├── contextual_word_embedding.py │ │ ├── embedding.py │ │ ├── fast_text.py │ │ ├── fast_text_tf.py │ │ ├── util.py │ │ ├── util_tf.py │ │ ├── word2vec.py │ │ └── word2vec_tf.py │ ├── feed_forward.py │ ├── feedforward.py │ ├── gates │ │ ├── __init__.py │ │ ├── concrete_gate.py │ │ └── concrete_gate_tf.py │ ├── scalar_mix.py │ ├── time_distributed.py │ ├── transformers │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── loader_tf.py │ │ ├── longformer │ │ │ ├── __init__.py │ │ │ └── long_models.py │ │ ├── pt_imports.py │ │ ├── relative_transformer.py │ │ ├── resource.py │ │ ├── tf_imports.py │ │ ├── utils.py │ │ └── utils_tf.py │ └── weight_normalization.py ├── losses │ ├── __init__.py │ ├── homoscedastic_loss_weighted_sum.py │ └── sparse_categorical_crossentropy.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── amr │ │ ├── __init__.py │ │ └── smatch_eval.py │ ├── chunking │ │ ├── __init__.py │ │ ├── binary_chunking_f1.py │ │ ├── bmes_tf.py │ │ ├── chunking_f1.py │ │ ├── chunking_f1_tf.py │ │ ├── conlleval.py │ │ ├── iobes_tf.py │ │ └── sequence_labeling.py │ ├── f1.py │ ├── metric.py │ ├── mtl.py │ ├── parsing │ │ ├── __init__.py │ │ ├── attachmentscore.py │ │ ├── conllx_eval.py │ │ ├── evalb_bracketing_scorer.py │ │ ├── labeled_f1.py │ │ ├── labeled_f1_tf.py │ │ ├── labeled_score.py │ │ ├── semdep_eval.py │ │ └── span.py │ ├── spearman_correlation.py │ └── srl │ │ ├── __init__.py │ │ ├── e2e_srl.py │ │ └── srlconll.py ├── optimizers │ ├── __init__.py │ └── adamw │ │ ├── __init__.py │ │ └── optimization.py ├── pretrained │ ├── __init__.py │ ├── amr.py │ ├── amr2text.py │ ├── classifiers.py │ ├── constituency.py │ ├── dep.py │ ├── eos.py │ ├── fasttext.py │ ├── glove.py │ ├── mtl.py │ ├── ner.py │ ├── pos.py │ ├── 
rnnlm.py │ ├── sdp.py │ ├── srl.py │ ├── sts.py │ ├── tok.py │ └── word2vec.py ├── transform │ ├── __init__.py │ ├── conll_tf.py │ ├── glue_tf.py │ ├── table_tf.py │ ├── tacred_tf.py │ ├── text_tf.py │ ├── transformer_tokenizer.py │ ├── tsv_tf.py │ └── txt_tf.py ├── utils │ ├── __init__.py │ ├── component_util.py │ ├── file_read_backwards │ │ ├── __init__.py │ │ ├── buffer_work_space.py │ │ └── file_read_backwards.py │ ├── init_util.py │ ├── io_util.py │ ├── lang │ │ ├── __init__.py │ │ ├── en │ │ │ ├── __init__.py │ │ │ └── english_tokenizer.py │ │ ├── ja │ │ │ ├── __init__.py │ │ │ └── bert_tok.py │ │ └── zh │ │ │ ├── __init__.py │ │ │ ├── char_table.py │ │ │ └── localization.py │ ├── log_util.py │ ├── rules.py │ ├── span_util.py │ ├── statistics │ │ ├── __init__.py │ │ └── moving_avg.py │ ├── string_util.py │ ├── tf_util.py │ ├── time_util.py │ └── torch_util.py └── version.py ├── setup.py └── tests ├── __init__.py ├── con ├── __init__.py ├── ontonotes │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ ├── pt.py │ └── pt_inc_vrb.py └── ptb │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ └── pt.py ├── dep ├── __init__.py ├── ontonotes │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ ├── pt.py │ └── pt_dec_lex.py └── ptb │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ ├── pt.py │ └── pt_dec_lex.py ├── ner ├── __init__.py ├── conll │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ ├── pt.py │ └── pt_inc_vrb.py └── ontonotes │ ├── __init__.py │ ├── ls.py │ ├── lt.py │ ├── pt.py │ └── pt_inc_vrb.py └── pos ├── __init__.py ├── ontonotes ├── __init__.py ├── ls.py ├── lt.py ├── pt.py └── pt_dec_lex.py └── ptb ├── __init__.py ├── ls.py ├── lt.py ├── pt.py └── pt_dec_lex.py /README.md: -------------------------------------------------------------------------------- 1 | # Sequence-to-Sequence CoreNLP 2 | 3 | Code for our paper *[Unleashing the True Potential of Sequence-to-Sequence Models for Sequence Tagging and Structure Parsing](https://arxiv.org/abs/2302.02275)* published in [TACL 2023](https://transacl.org/). 4 | 5 | ![dep](docs/fig/dep.png) 6 | 7 | ## Installation 8 | 9 | Run the following setup script. Feel free to install [a GPU-enabled PyTorch](https://pytorch.org/get-started/locally/) (`torch>=1.6.0`). 10 | 11 | ```bash 12 | python3 -m venv env 13 | source env/bin/activate 14 | ln -sf "$(which python3)" env/bin/python 15 | pip install -e . 16 | export PYTHONPATH=.:$PYTHONPATH 17 | ``` 18 | 19 | ## Data Pre-processing 20 | 21 | Download OntoNotes 5 ([`LDC2013T19.tgz`](https://catalog.ldc.upenn.edu/LDC2013T19)) and put it into the following directory: 22 | 23 | ```bash 24 | mkdir -p ~/.elit/thirdparty/catalog.ldc.upenn.edu/LDC2013T19/ 25 | cp LDC2013T19.tgz ~/.elit/thirdparty/catalog.ldc.upenn.edu/LDC2013T19/LDC2013T19.tgz 26 | ``` 27 | 28 | That's all. ELIT will automatically do the rest for you the first time you run a training script. 29 | 30 | ## Experiments 31 | 32 | Training and evaluation scripts are grouped in `tests` following the pattern: `tests/{pos|ner|con|dep}/{ptb|conll|ontonotes}/{ls|lt|pt}.py`.
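Here `pos|ner|con|dep` selects the task, `ptb|conll|ontonotes` the dataset, and `ls|lt|pt` the linearization schema (e.g., `POS-LS` on PTB maps to `tests/pos/ptb/ls.py`, as shown in the example below). As a minimal sketch assuming only this layout (the loop below is illustrative, not an official script of this repository), all three schemas for CoNLL NER could be run back to back:

```bash
# Illustrative sketch: run the three schema scripts under tests/ner/conll/ in turn,
# assuming the repo root is the working directory and PYTHONPATH is set as above.
for schema in ls lt pt; do
    python3 tests/ner/conll/${schema}.py
done
```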
33 | 34 | For example, the script for `POS-LS` on PTB can be executed via: 35 | 36 | ```bash 37 | python3 tests/pos/ptb/ls.py 38 | ``` 39 | 40 | 41 | ## Citation 42 | 43 | If you use this repository in your research, please kindly cite our TACL 2023 paper: 44 | 45 | ```bibtex 46 | @article{he-choi-2023-seq2seq, 47 | title = "Unleashing the True Potential of Sequence-to-Sequence Models for Sequence Tagging and Structure Parsing", 48 | author = "He, Han and Choi, Jinho D.", 49 | journal = "Transactions of the Association for Computational Linguistics", 50 | year = "2023", 51 | address = "Cambridge, MA", 52 | publisher = "MIT Press", 53 | abstract = "Sequence-to-Sequence (S2S) models have achieved remarkable success on various text generation tasks. However, learning complex structures with S2S models remains challenging as external neural modules and additional lexicons are often supplemented to predict non-textual outputs. We present a systematic study of S2S modeling using constrained decoding on four core tasks: part-of-speech tagging, named entity recognition, constituency and dependency parsing, to develop efficient exploitation methods costing zero extra parameters. In particular, 3 lexically diverse linearization schemas and corresponding constrained decoding methods are designed and evaluated. Experiments show that although more lexicalized schemas yield longer output sequences that require heavier training, their sequences being closer to natural language makes them easier to learn. Moreover, S2S models using our constrained decoding outperform other S2S approaches using external resources. Our best models perform better than or comparably to the state-of-the-art for all 4 tasks, lighting a promise for S2S models to generate non-sequential structures. ", 54 | } 55 | ``` -------------------------------------------------------------------------------- /docs/fig/dep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/seq2seq-corenlp/7155b117630b79ba1a640e76dfe5ba93e1166fff/docs/fig/dep.png -------------------------------------------------------------------------------- /elit/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-06-13 18:05 4 | import elit.common 5 | import elit.components 6 | import elit.pretrained 7 | import elit.utils 8 | from elit.version import __version__ 9 | 10 | elit.utils.ls_resource_in_module(elit.pretrained) 11 | 12 | 13 | def load(save_dir: str, verbose=None, **kwargs) -> elit.common.component.Component: 14 | """Load a pretrained component from an identifier. 15 | 16 | Args: 17 | save_dir (str): The identifier to the saved component. It could be a remote URL or a local path. 18 | verbose: ``True`` to print loading progress. 19 | **kwargs: Arguments passed to :func:`elit.common.torch_component.TorchComponent.load`, e.g., 20 | ``devices`` is a useful argument to specify which GPU devices a PyTorch component will use. 21 | 22 | Examples:: 23 | 24 | import elit 25 | # Load component onto the 0-th GPU. 26 | elit.load(..., devices=0) 27 | # Load component onto the 0-th and 1-st GPUs using data parallelization. 28 | elit.load(..., devices=[0, 1]) 29 | 30 | .. Note:: 31 | A component can have dependencies on other components or resources, which will be recursively loaded. So it's 32 | common to see multiple downloading messages per single load.
33 | 34 | Returns: 35 | elit.common.component.Component: A pretrained component. 36 | 37 | """ 38 | save_dir = elit.pretrained.ALL.get(save_dir, save_dir) 39 | from elit.utils.component_util import load_from_meta_file 40 | if verbose is None: 41 | from hanlp_common.constant import HANLP_VERBOSE 42 | verbose = HANLP_VERBOSE 43 | return load_from_meta_file(save_dir, 'meta.json', verbose=verbose, **kwargs) 44 | 45 | 46 | def pipeline(*pipes) -> elit.components.pipeline.Pipeline: 47 | """Creates a pipeline of components. It's made for bundling `KerasComponents`. For `TorchComponent`, use 48 | :class:`~elit.components.mtl.multi_task_learning.MultiTaskLearning` instead. 49 | 50 | Args: 51 | *pipes: Components if pre-defined any. 52 | 53 | Returns: 54 | hanlp.components.pipeline.Pipeline: A pipeline, which is a list of components in order. 55 | 56 | """ 57 | return elit.components.pipeline.Pipeline(*pipes) 58 | -------------------------------------------------------------------------------- /elit/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-05 02:10 -------------------------------------------------------------------------------- /elit/callbacks/fine_csv_logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-05 02:12 4 | import copy 5 | from io import TextIOWrapper 6 | from typing import List 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | 12 | class StreamTableFormatter(object): 13 | 14 | def __init__(self) -> None: 15 | super().__init__() 16 | self.col_widths = None 17 | 18 | def format_row(self, cells) -> List[str]: 19 | if not isinstance(cells, list): 20 | cells = list(cells) 21 | if not self.col_widths: 22 | self.col_widths = [0] * len([_ for _ in cells]) 23 | for i, c in enumerate(cells): 24 | self.col_widths[i] = max(self.col_widths[i], len(self.format_cell(c, self.col_widths[i]))) 25 | return list(self.format_cell(cell, width) for cell, width in zip(cells, self.col_widths)) 26 | 27 | def format_cell(self, cell: str, min_width) -> str: 28 | if isinstance(cell, (np.float32, np.float)): 29 | return '{:>{}.4f}'.format(cell, min_width) 30 | return '{:>{}}'.format(cell, min_width) 31 | 32 | 33 | class FineCSVLogger(tf.keras.callbacks.History): 34 | 35 | def __init__(self, filename, separator=',', append=False): 36 | super().__init__() 37 | self.append = append 38 | self.separator = separator 39 | self.filename = filename 40 | self.out: TextIOWrapper = None 41 | self.keys = [] 42 | self.formatter = StreamTableFormatter() 43 | 44 | def on_train_begin(self, logs=None): 45 | super().on_train_begin(logs) 46 | self.out = open(self.filename, 'a' if self.append else 'w') 47 | 48 | def on_train_end(self, logs=None): 49 | self.out.close() 50 | 51 | def on_epoch_end(self, epoch, logs=None): 52 | super().on_epoch_end(epoch, logs) 53 | if not self.keys: 54 | self.keys = sorted(logs.keys()) 55 | 56 | if getattr(self.model, 'stop_training', None): 57 | # We set NA so that csv parsers do not fail for this last epoch. 
58 | logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys]) 59 | 60 | # feed them twice to decide the actual width 61 | values = self.formatter.format_row([epoch + 1] + [logs.get(k, 'NA') for k in self.keys]) 62 | headers = self.formatter.format_row(['epoch'] + self.keys) 63 | # print headers and bars 64 | self.out.write(self.separator.join(headers) + '\n') 65 | # bars for markdown style 66 | bars = [''.join(['-'] * width) for width in self.formatter.col_widths] 67 | self.out.write(self.separator.join(bars) + '\n') 68 | 69 | values = self.formatter.format_row([epoch + 1] + [logs.get(k, 'NA') for k in self.keys]) 70 | self.out.write(self.separator.join(values) + '\n') 71 | self.out.flush() 72 | -------------------------------------------------------------------------------- /elit/common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 14:45 4 | -------------------------------------------------------------------------------- /elit/common/component.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 14:45 4 | import inspect 5 | from abc import ABC, abstractmethod 6 | from typing import Any 7 | 8 | from hanlp_common.configurable import Configurable 9 | 10 | 11 | class Component(Configurable, ABC): 12 | @abstractmethod 13 | def predict(self, *args, **kwargs): 14 | """Predict on data. This is the base class for all components, including rule based and statistical ones. 15 | 16 | Args: 17 | *args: Any type of data subject to sub-classes 18 | **kwargs: Additional arguments 19 | 20 | Returns: Any predicted annotations. 21 | 22 | """ 23 | raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3])) 24 | 25 | def __call__(self, *args, **kwargs): 26 | """ 27 | A shortcut for :func:`~elit.common.component.predict`. 28 | 29 | Args: 30 | *args: Any type of data subject to sub-classes 31 | **kwargs: Additional arguments 32 | 33 | Returns: Any predicted annotations. 34 | 35 | """ 36 | return self.predict(*args, **kwargs) 37 | -------------------------------------------------------------------------------- /elit/common/structure.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 14:58 4 | from typing import Dict 5 | 6 | from hanlp_common.configurable import Configurable 7 | from hanlp_common.reflection import classpath_of 8 | from hanlp_common.structure import SerializableDict 9 | 10 | 11 | class ConfigTracker(Configurable): 12 | 13 | def __init__(self, locals_: Dict, exclude=('kwargs', 'self', '__class__', 'locals_')) -> None: 14 | """This base class helps sub-classes to capture their arguments passed to ``__init__``, and also their types so 15 | that they can be deserialized from a config in dict form. 16 | 17 | Args: 18 | locals_: Obtained by :meth:`locals`. 19 | exclude: Arguments to be excluded. 
20 | 21 | Examples: 22 | >>> class MyClass(ConfigTracker): 23 | >>> def __init__(self, i_need_this='yes') -> None: 24 | >>> super().__init__(locals()) 25 | >>> obj = MyClass() 26 | >>> print(obj.config) 27 | {'i_need_this': 'yes', 'classpath': 'test_config_tracker.MyClass'} 28 | 29 | """ 30 | if 'kwargs' in locals_: 31 | locals_.update(locals_['kwargs']) 32 | self.config = SerializableDict( 33 | (k, v.config if hasattr(v, 'config') else v) for k, v in locals_.items() if k not in exclude) 34 | self.config['classpath'] = classpath_of(self) 35 | 36 | 37 | class History(object): 38 | def __init__(self): 39 | """ A history of training context. It records how many steps have passed and provides methods to decide whether 40 | an update should be performed, and to caculate number of training steps given dataloader size and 41 | ``gradient_accumulation``. 42 | """ 43 | self.num_mini_batches = 0 44 | 45 | def step(self, gradient_accumulation): 46 | """ Whether the training procedure should perform an update. 47 | 48 | Args: 49 | gradient_accumulation: Number of batches per update. 50 | 51 | Returns: 52 | bool: ``True`` to update. 53 | """ 54 | self.num_mini_batches += 1 55 | return self.num_mini_batches % gradient_accumulation == 0 56 | 57 | def num_training_steps(self, num_batches, gradient_accumulation): 58 | """ Caculate number of training steps. 59 | 60 | Args: 61 | num_batches: Size of dataloader. 62 | gradient_accumulation: Number of batches per update. 63 | 64 | Returns: 65 | 66 | """ 67 | return len( 68 | [i for i in range(self.num_mini_batches + 1, self.num_mini_batches + num_batches + 1) if 69 | i % gradient_accumulation == 0]) 70 | -------------------------------------------------------------------------------- /elit/components/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-26 16:10 4 | from .pipeline import Pipeline -------------------------------------------------------------------------------- /elit/components/amr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-20 17:35 4 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-04-27 19:24 4 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/dataset/IO.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/seq2seq-corenlp/7155b117630b79ba1a640e76dfe5ba93e1166fff/elit/components/amr/seq2seq/dataset/IO.py -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-04-27 19:29 4 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/dataset/penman.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from penman import load as load_, Graph, Triple 4 | from penman import loads as loads_ 5 | from penman import encode as encode_ 6 | from 
penman.model import Model 7 | from penman.models.noop import NoOpModel 8 | from penman.models import amr 9 | import penman 10 | import logging 11 | 12 | op_model = Model() 13 | noop_model = NoOpModel() 14 | amr_model = amr.model 15 | DEFAULT = op_model 16 | 17 | # Mute loggers 18 | penman.layout.logger.setLevel(logging.CRITICAL) 19 | penman._parse.logger.setLevel(logging.CRITICAL) 20 | 21 | 22 | def _get_model(dereify): 23 | if dereify is None: 24 | return DEFAULT 25 | elif dereify: 26 | return op_model 27 | else: 28 | return noop_model 29 | 30 | 31 | def _remove_wiki(graph): 32 | metadata = graph.metadata 33 | triples = [] 34 | for t in graph.triples: 35 | v1, rel, v2 = t 36 | if rel == ':wiki': 37 | t = Triple(v1, rel, '+') 38 | triples.append(t) 39 | graph = Graph(triples) 40 | graph.metadata = metadata 41 | return graph 42 | 43 | 44 | def pm_load(source, dereify=None, remove_wiki=False) -> List[penman.Graph]: 45 | """ 46 | 47 | Args: 48 | source: 49 | dereify: Restore reverted relations 50 | remove_wiki: 51 | 52 | Returns: 53 | 54 | """ 55 | model = _get_model(dereify) 56 | out = load_(source=source, model=model) 57 | if remove_wiki: 58 | for i in range(len(out)): 59 | out[i] = _remove_wiki(out[i]) 60 | return out 61 | 62 | 63 | def loads(string, dereify=None, remove_wiki=False): 64 | model = _get_model(dereify) 65 | out = loads_(string=string, model=model) 66 | if remove_wiki: 67 | for i in range(len(out)): 68 | out[i] = _remove_wiki(out[i]) 69 | return out 70 | 71 | 72 | def pm_encode(g, top=None, indent=-1, compact=False): 73 | model = amr_model 74 | return encode_(g=g, top=top, indent=indent, compact=compact, model=model) 75 | 76 | 77 | def role_is_reverted(role: str): 78 | if role.endswith('consist-of'): 79 | return False 80 | return role.endswith('-of') 81 | 82 | 83 | class AMRGraph(penman.Graph): 84 | def __str__(self): 85 | return penman.encode(self) 86 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/seq2seq-corenlp/7155b117630b79ba1a640e76dfe5ba93e1166fff/elit/components/amr/seq2seq/evaluation.py -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-09 17:39 4 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/graph/soft_rgcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-05 17:47 4 | import torch 5 | from torch import nn, FloatTensor 6 | from torch.nn import Parameter 7 | 8 | from elit.utils.torch_util import set_seed 9 | 10 | 11 | class SoftRGCNConv(nn.Module): 12 | def __init__(self, in_channels: int, out_channels: int, num_relations: int): 13 | super().__init__() 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | self.num_relations = num_relations 17 | self.weight = Parameter(torch.Tensor(num_relations, in_channels, out_channels)) 18 | torch.nn.init.normal_(self.weight) 19 | 20 | def forward(self, x: FloatTensor, adj: FloatTensor): 21 | batch_size, num_nodes, in_channels = x.size() 22 | f = torch.einsum('bni,rio->bnro', x, self.weight) 23 | f = f[:, :, 
None].expand(batch_size, num_nodes, num_nodes, self.num_relations, self.out_channels) 24 | f = f * adj.unsqueeze(-1) 25 | f = f.sum(1).sum(2) 26 | f /= adj.sum(1).sum(2).clamp_min(1e-16).unsqueeze(-1) # avoid div by zero 27 | return f 28 | 29 | 30 | def main(): 31 | set_seed(1) 32 | batch_size = 2 33 | in_channels = 4 34 | num_nodes = 3 35 | x = torch.rand((batch_size, num_nodes, in_channels)) 36 | mask = torch.tensor([[True, True, True], [True, True, False]]) 37 | x[~mask] = 0 38 | num_relations = 2 39 | adj = torch.rand((batch_size, num_nodes, num_nodes, num_relations)).softmax(dim=-1) 40 | mask3d = mask.unsqueeze(-1).expand(batch_size, num_nodes, num_nodes) & mask.unsqueeze(1).expand(batch_size, 41 | num_nodes, 42 | num_nodes) 43 | adj[~mask3d] = 0 44 | out_channels = 5 45 | conv = SoftRGCNConv(in_channels, out_channels, num_relations) 46 | y = conv(x, adj) 47 | print(y) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/seq2seq_cfg_amr_parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-04-28 17:33 4 | import json 5 | 6 | import torch 7 | from phrasetree.tree import Tree 8 | from transformers.models.bart.modeling_bart import BartForConditionalGeneration, BartLearnedPositionalEmbedding 9 | 10 | from elit.common.vocab import Vocab 11 | from elit.components.amr.seq2seq.dataset.dataset import dfs_linearize_constituency 12 | from elit.components.amr.seq2seq.seq2seq_amr_parser import Seq2seq_AMR_Parser 13 | 14 | 15 | class Seq2seq_CFG_AMR_Parser(Seq2seq_AMR_Parser): 16 | def collect_additional_tokens(self, additional_tokens, dataset): 17 | super().collect_additional_tokens(additional_tokens, dataset) 18 | for sample in dataset: 19 | amr = sample['amr'] 20 | tree = Tree.from_list(json.loads(amr.metadata['con_list'])) 21 | for s in tree.subtrees(): 22 | additional_tokens.add(s.label()) 23 | 24 | def finalize_dataset(self, dataset): 25 | dataset.append_transform(lambda x: dfs_linearize_constituency(x, tokenizer=self._tokenizer)) 26 | 27 | def build_model(self, training=True, **kwargs) -> torch.nn.Module: 28 | # noinspection PyTypeChecker 29 | model: BartForConditionalGeneration = super().build_model(training, **kwargs) 30 | config = model.config 31 | config.max_position_embeddings = 2048 32 | pos_embed = BartLearnedPositionalEmbedding( 33 | config.max_position_embeddings, 34 | config.d_model, 35 | ) 36 | if training: 37 | with torch.no_grad(): 38 | pos_embed.weight[:model.base_model.encoder.embed_positions.weight.size(0), :] \ 39 | = model.base_model.encoder.embed_positions.weight 40 | model.base_model.encoder.embed_positions = pos_embed 41 | return model 42 | -------------------------------------------------------------------------------- /elit/components/amr/seq2seq/seq2seq_levi_amr_parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-04-28 17:33 4 | import json 5 | 6 | import torch 7 | from transformers.models.bart.modeling_bart import BartForConditionalGeneration, BartLearnedPositionalEmbedding 8 | 9 | from elit.components.amr.seq2seq.dataset.dataset import dfs_linearize_levi 10 | from elit.components.amr.seq2seq.seq2seq_amr_parser import Seq2seq_AMR_Parser 11 | 12 | 13 | class Seq2seq_Levi_AMR_Parser(Seq2seq_AMR_Parser): 14 | def collect_additional_tokens(self, 
additional_tokens, dataset): 15 | super().collect_additional_tokens(additional_tokens, dataset) 16 | for sample in dataset: 17 | amr = sample['amr'] 18 | tree = json.loads(amr.metadata['dep']) 19 | for arc, rel in tree: 20 | additional_tokens.add(rel) 21 | 22 | def finalize_dataset(self, dataset): 23 | dataset.append_transform(lambda x: dfs_linearize_levi(x, tokenizer=self._tokenizer)) 24 | 25 | # def build_model(self, training=True, **kwargs) -> torch.nn.Module: 26 | # # noinspection PyTypeChecker 27 | # model: BartForConditionalGeneration = super().build_model(training, **kwargs) 28 | # config = model.config 29 | # config.max_position_embeddings = 2048 30 | # pos_embed = BartLearnedPositionalEmbedding( 31 | # config.max_position_embeddings, 32 | # config.d_model, 33 | # ) 34 | # if training: 35 | # with torch.no_grad(): 36 | # pos_embed.weight[:model.base_model.encoder.embed_positions.weight.size(0), :] \ 37 | # = model.base_model.encoder.embed_positions.weight 38 | # model.base_model.encoder.embed_positions = pos_embed 39 | # return model 40 | -------------------------------------------------------------------------------- /elit/components/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 13:18 -------------------------------------------------------------------------------- /elit/components/detok.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-11-26 20:32 4 | from typing import Any 5 | 6 | from elit.components.taggers.transformers.transformer_tagger import TransformerTagger 7 | from elit.datasets.detokenization.detok import DetokenizationDataset 8 | 9 | 10 | class TransformerDetokenizer(TransformerTagger): 11 | def build_dataset(self, data, transform=None, **kwargs): 12 | return DetokenizationDataset(data, transform=transform, **kwargs) 13 | 14 | def predict(self, tokens: Any, batch_size: int = None, ret_scores=False, ret_tags=False, **kwargs): 15 | tags = super().predict(tokens, batch_size, ret_scores, **kwargs) 16 | if not ret_tags: 17 | flat = self.input_is_flat(tokens) 18 | if flat: 19 | tags = [tags] 20 | tokens = [tokens] 21 | sents = [''.join(sum(list(zip(token, tag)), ())) for token, tag in zip(tokens, tags)] 22 | if flat: 23 | return sents[0] 24 | return sents 25 | return tags 26 | -------------------------------------------------------------------------------- /elit/components/lambda_wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-31 18:36 4 | from typing import Callable, Any 5 | 6 | from elit.common.component import Component 7 | from hanlp_common.reflection import classpath_of, object_from_classpath, str_to_type 8 | 9 | 10 | class LambdaComponent(Component): 11 | def __init__(self, function: Callable) -> None: 12 | super().__init__() 13 | self.config = {} 14 | self.function = function 15 | self.config['function'] = classpath_of(function) 16 | self.config['classpath'] = classpath_of(self) 17 | 18 | def predict(self, data: Any, **kwargs): 19 | unpack = kwargs.pop('_hanlp_unpack', None) 20 | if unpack: 21 | return self.function(*data, **kwargs) 22 | return self.function(data, **kwargs) 23 | 24 | @staticmethod 25 | def from_config(meta: dict, **kwargs): 26 | cls = str_to_type(meta['classpath']) 27 | function = meta['function'] 28 | function = 
object_from_classpath(function) 29 | return cls(function) 30 | -------------------------------------------------------------------------------- /elit/components/lemmatizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-08 18:35 4 | from typing import List 5 | 6 | from elit.common.transform import TransformList 7 | from elit.components.parsers.ud.lemma_edit import gen_lemma_rule, apply_lemma_rule 8 | from elit.components.taggers.transformers.transformer_tagger import TransformerTagger 9 | 10 | 11 | def add_lemma_rules_to_sample(sample: dict): 12 | if 'tag' in sample and 'lemma' not in sample: 13 | lemma_rules = [gen_lemma_rule(word, lemma) 14 | if lemma != "_" else "_" 15 | for word, lemma in zip(sample['token'], sample['tag'])] 16 | sample['lemma'] = sample['tag'] = lemma_rules 17 | return sample 18 | 19 | 20 | class TransformerLemmatizer(TransformerTagger): 21 | 22 | def __init__(self, **kwargs) -> None: 23 | """A transition based lemmatizer using transformer as encoder. 24 | 25 | Args: 26 | **kwargs: Predefined config. 27 | """ 28 | super().__init__(**kwargs) 29 | 30 | def build_dataset(self, data, transform=None, **kwargs): 31 | if not isinstance(transform, list): 32 | transform = TransformList() 33 | transform.append(add_lemma_rules_to_sample) 34 | return super().build_dataset(data, transform, **kwargs) 35 | 36 | def prediction_to_human(self, pred, vocab: List[str], batch, token=None): 37 | if token is None: 38 | token = batch['token'] 39 | rules = super().prediction_to_human(pred, vocab, batch) 40 | for token_per_sent, rule_per_sent in zip(token, rules): 41 | lemma_per_sent = [apply_lemma_rule(t, r) for t, r in zip(token_per_sent, rule_per_sent)] 42 | yield lemma_per_sent 43 | -------------------------------------------------------------------------------- /elit/components/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-22 12:46 -------------------------------------------------------------------------------- /elit/components/parsers/biaffine/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 20:43 4 | -------------------------------------------------------------------------------- /elit/components/parsers/biaffine/mlp.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Yu Zhang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | 25 | import torch.nn as nn 26 | 27 | from elit.layers.dropout import SharedDropout 28 | 29 | 30 | class MLP(nn.Module): 31 | r""" 32 | Applies a linear transformation together with a non-linear activation to the incoming tensor: 33 | :math:`y = \mathrm{Activation}(x A^T + b)` 34 | 35 | Args: 36 | n_in (~torch.Tensor): 37 | The size of each input feature. 38 | n_out (~torch.Tensor): 39 | The size of each output feature. 40 | dropout (float): 41 | If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0. 42 | activation (bool): 43 | Whether to use activations. Default: True. 44 | """ 45 | 46 | def __init__(self, n_in, n_out, dropout=0, activation=True): 47 | super().__init__() 48 | 49 | self.n_in = n_in 50 | self.n_out = n_out 51 | self.linear = nn.Linear(n_in, n_out) 52 | self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity() 53 | self.dropout = SharedDropout(p=dropout) 54 | 55 | self.reset_parameters() 56 | 57 | def __repr__(self): 58 | s = f"n_in={self.n_in}, n_out={self.n_out}" 59 | if self.dropout.p > 0: 60 | s += f", dropout={self.dropout.p}" 61 | 62 | return f"{self.__class__.__name__}({s})" 63 | 64 | def reset_parameters(self): 65 | nn.init.orthogonal_(self.linear.weight) 66 | nn.init.zeros_(self.linear.bias) 67 | 68 | def forward(self, x): 69 | r""" 70 | Args: 71 | x (~torch.Tensor): 72 | The size of each input feature is `n_in`. 73 | 74 | Returns: 75 | A tensor with the size of each output feature `n_out`. 
76 | """ 77 | 78 | x = self.linear(x) 79 | x = self.activation(x) 80 | x = self.dropout(x) 81 | 82 | return x 83 | 84 | -------------------------------------------------------------------------------- /elit/components/parsers/biaffine_tf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-26 23:03 -------------------------------------------------------------------------------- /elit/components/parsers/conll.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-26 15:37 4 | from typing import Union 5 | 6 | from elit.utils.io_util import get_resource, TimingFileIterator 7 | from elit.utils.log_util import logger 8 | 9 | 10 | def collapse_enhanced_empty_nodes(sent: list): 11 | collapsed = [] 12 | for cells in sent: 13 | if isinstance(cells[0], float): 14 | id = cells[0] 15 | head, deprel = cells[8].split(':', 1) 16 | for x in sent: 17 | arrows = [s.split(':', 1) for s in x[8].split('|')] 18 | arrows = [(head, f'{head}:{deprel}>{r}') if h == str(id) else (h, r) for h, r in arrows] 19 | arrows = sorted(arrows) 20 | x[8] = '|'.join(f'{h}:{r}' for h, r in arrows) 21 | sent[head][7] += f'>{cells[7]}' 22 | else: 23 | collapsed.append(cells) 24 | return collapsed 25 | 26 | 27 | def read_conll(filepath: Union[str, TimingFileIterator], underline_to_none=False, enhanced_collapse_empty_nodes=False): 28 | sent = [] 29 | if isinstance(filepath, str): 30 | filepath: str = get_resource(filepath) 31 | if filepath.endswith('.conllu') and enhanced_collapse_empty_nodes is None: 32 | enhanced_collapse_empty_nodes = True 33 | src = open(filepath, encoding='utf-8') 34 | else: 35 | src = filepath 36 | for idx, line in enumerate(src): 37 | if line.startswith('#'): 38 | continue 39 | line = line.strip() 40 | cells = line.split('\t') 41 | if line and cells: 42 | if enhanced_collapse_empty_nodes and '.' in cells[0]: 43 | cells[0] = float(cells[0]) 44 | cells[6] = None 45 | else: 46 | if '-' in cells[0] or '.' 
in cells[0]: 47 | # sent[-1][1] += cells[1] 48 | continue 49 | cells[0] = int(cells[0]) 50 | if cells[6] != '_': 51 | try: 52 | cells[6] = int(cells[6]) 53 | except ValueError: 54 | cells[6] = 0 55 | logger.exception(f'Wrong CoNLL format {filepath}:{idx + 1}\n{line}') 56 | if underline_to_none: 57 | for i, x in enumerate(cells): 58 | if x == '_': 59 | cells[i] = None 60 | sent.append(cells) 61 | else: 62 | if enhanced_collapse_empty_nodes: 63 | sent = collapse_enhanced_empty_nodes(sent) 64 | yield sent 65 | sent = [] 66 | 67 | if sent: 68 | if enhanced_collapse_empty_nodes: 69 | sent = collapse_enhanced_empty_nodes(sent) 70 | yield sent 71 | 72 | src.close() 73 | 74 | -------------------------------------------------------------------------------- /elit/components/parsers/constituency/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-28 19:26 4 | -------------------------------------------------------------------------------- /elit/components/parsers/ud/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-14 20:34 4 | -------------------------------------------------------------------------------- /elit/components/parsers/ud/lemma_edit.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/seq2seq-corenlp/7155b117630b79ba1a640e76dfe5ba93e1166fff/elit/components/parsers/ud/lemma_edit.py -------------------------------------------------------------------------------- /elit/components/parsers/ud/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-14 20:44 4 | from hanlp_common.constant import ROOT 5 | from elit.components.parsers.ud.lemma_edit import gen_lemma_rule 6 | 7 | 8 | def generate_lemma_rule(sample: dict): 9 | if 'LEMMA' in sample: 10 | sample['lemma'] = [gen_lemma_rule(word, lemma) if lemma != "_" else "_" for word, lemma in 11 | zip(sample['FORM'], sample['LEMMA'])] 12 | return sample 13 | 14 | 15 | def append_bos(sample: dict): 16 | if 'FORM' in sample: 17 | sample['token'] = [ROOT] + sample['FORM'] 18 | if 'UPOS' in sample: 19 | sample['pos'] = sample['UPOS'][:1] + sample['UPOS'] 20 | sample['arc'] = [0] + sample['HEAD'] 21 | sample['rel'] = sample['DEPREL'][:1] + sample['DEPREL'] 22 | sample['lemma'] = sample['lemma'][:1] + sample['lemma'] 23 | sample['feat'] = sample['FEATS'][:1] + sample['FEATS'] 24 | return sample 25 | 26 | 27 | def sample_form_missing(sample: dict): 28 | return all(t == '_' for t in sample['FORM']) 29 | -------------------------------------------------------------------------------- /elit/components/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 10:57 4 | -------------------------------------------------------------------------------- /elit/components/seq2seq/con/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-04-17 22:14 4 | -------------------------------------------------------------------------------- /elit/components/seq2seq/con/constrained_decoding.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-11-10 16:05 4 | import math 5 | import torch 6 | from transformers import BartTokenizer 7 | from transformers.generation_logits_process import LogitsProcessor 8 | 9 | from elit.utils.log_util import cprint 10 | 11 | 12 | class ShiftReduceProcessor(LogitsProcessor): 13 | def __init__(self, batch, ls, sh, rs, tokenizer: BartTokenizer): 14 | self.rs = rs 15 | self.sh = sh 16 | self.ls = ls 17 | self.tokenizer = tokenizer 18 | self.batch = batch 19 | self.eos = tokenizer.eos_token_id 20 | self.bos = tokenizer.bos_token_id 21 | tokens = batch['token'] 22 | self.offsets = [0] * len(tokens) 23 | self.depth = [0] * len(tokens) 24 | self.batch['_predictions'] = [[] for _ in tokens] 25 | 26 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 27 | mask = torch.full_like(scores, -math.inf) 28 | batch = self.batch 29 | for batch_id, beam_sent in enumerate(input_ids.view(-1, 1, input_ids.shape[-1])): 30 | for beam_id, sent in enumerate(beam_sent): 31 | allowed_tokens = set() 32 | index = batch_id * 1 + beam_id 33 | prefix_ids: list = input_ids[index][1:].tolist() 34 | if self.eos in prefix_ids: 35 | prefix_ids = prefix_ids[:prefix_ids.index(self.eos)] 36 | allowed_tokens.add(self.eos) 37 | prefix_str = self.tokenizer.convert_ids_to_tokens(prefix_ids) 38 | tokens = self.batch['token'][index] 39 | if prefix_ids: 40 | if prefix_ids[-1] == self.rs: 41 | self.depth[index] -= 1 42 | elif prefix_ids[-1] == self.sh: 43 | self.offsets[index] += 1 44 | elif prefix_ids[-1] in self.ls: 45 | self.depth[index] += 1 46 | if self.depth[index]: 47 | allowed_tokens.add(self.sh) 48 | allowed_tokens.update(self.ls) 49 | if self.depth[index]: 50 | allowed_tokens.add(self.rs) 51 | elif prefix_ids: 52 | allowed_tokens = {self.eos} 53 | allowed_tokens = sorted(list(allowed_tokens)) 54 | mask[index, allowed_tokens] = 0 55 | # cprint(f'{len(prefix_ids)} {prefix_str} [yellow]{self.tokenizer.convert_ids_to_tokens(allowed_tokens)}[/yellow]') 56 | 57 | return scores + mask 58 | -------------------------------------------------------------------------------- /elit/components/seq2seq/con/utility.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-01-10 18:08 4 | from typing import List 5 | 6 | from phrasetree.tree import Tree 7 | 8 | from elit.components.seq2seq.dep.dep_utility import LB, RB 9 | 10 | 11 | def bracket_linearize(tree: Tree, buffer: List): 12 | buffer.append(LB) 13 | buffer.append(tree.label()) 14 | for t in tree: 15 | if isinstance(t, Tree): 16 | bracket_linearize(t, buffer) 17 | else: 18 | buffer.append(t) 19 | buffer.append(RB) 20 | 21 | 22 | def flatten_terminals(tree: Tree, anonymize_token, placeholder='XX'): 23 | for i, child in enumerate(tree): 24 | if isinstance(child, str): 25 | if anonymize_token: 26 | tree[i] = placeholder 27 | elif child.label() == placeholder: 28 | tree[i] = placeholder if anonymize_token else child[0] 29 | else: 30 | flatten_terminals(child, anonymize_token, placeholder) 31 | 32 | 33 | def unflatten_terminals(tree: Tree, placeholder='XX'): 34 | for i, child in enumerate(tree): 35 | if isinstance(child, str): 36 | tree[i] = Tree(placeholder, [child]) 37 | else: 38 | unflatten_terminals(child, placeholder) 39 | 40 | 41 | def find_first(stack, label: str, start=0): 42 | for i in range(start, len(stack)): 43 
| if stack[i] == label: 44 | return i 45 | return -1 -------------------------------------------------------------------------------- /elit/components/seq2seq/dep/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 21:59 4 | -------------------------------------------------------------------------------- /elit/components/seq2seq/dep/arc_eager.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | SH = 0 4 | RE = 1 5 | RA = 2 6 | LA = 3 7 | 8 | 9 | def transition(trans, stack: List[int], buffer: List[int], arcs: List[Tuple[int, int, str]]): 10 | if trans == SH: 11 | stack.insert(0, buffer.pop(0)) 12 | elif trans == RE: 13 | stack.pop(0) 14 | elif trans[0] == RA: 15 | top_w = stack[0] 16 | next_w = buffer[0] 17 | arcs.append((top_w, next_w, trans[1])) 18 | stack.insert(0, buffer.pop(0)) 19 | elif trans[0] == LA: 20 | top_w = stack.pop(0) 21 | next_w = buffer[0] 22 | arcs.append((next_w, top_w, trans[1])) 23 | 24 | 25 | def oracle(stack: List[int], buffer: List[int], heads: List[int], labels: List[str]): 26 | '''In accordance with algorithm 1 (Goldberg & Nivre, 2012)''' 27 | if heads[stack[0]] == buffer[0]: 28 | trans = (LA, labels[stack[0]]) 29 | elif stack[0] == heads[buffer[0]]: 30 | trans = (RA, labels[buffer[0]]) 31 | else: 32 | for i in range(stack[0]): 33 | if heads[i] == buffer[0] or heads[buffer[0]] == i: 34 | trans = RE 35 | return trans 36 | trans = SH 37 | return trans 38 | 39 | 40 | def encode(trans) -> str: 41 | if trans == SH: 42 | return 'SH' 43 | elif trans == RE: 44 | return 'RE' 45 | elif trans[0] == RA: 46 | return 'RA-' + trans[1] 47 | elif trans[0] == LA: 48 | return 'LA-' + trans[1] 49 | 50 | 51 | def decode(trans: str): 52 | if trans == 'SH': 53 | return SH 54 | elif trans == 'RE': 55 | return RE 56 | else: 57 | a, l = trans.split('-', 1) 58 | if a == 'RA': 59 | return RA, l 60 | else: 61 | return LA, l 62 | 63 | 64 | def restore_from_arcs(arcs: List[Tuple[int, int, str]]): 65 | heads, labels = [None] * len(arcs), [None] * len(arcs) 66 | for h, d, l in arcs: 67 | d -= 1 68 | if d < len(arcs): 69 | heads[d] = h 70 | labels[d] = l 71 | return heads, labels 72 | -------------------------------------------------------------------------------- /elit/components/seq2seq/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:18 4 | -------------------------------------------------------------------------------- /elit/components/seq2seq/ner/conditional_seq2seq_constrained.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-01-19 10:01 4 | import string 5 | from typing import Callable, List, Optional 6 | 7 | import torch 8 | from transformers import LogitsProcessorList 9 | from transformers.generation_logits_process import LogitsProcessor 10 | from transformers.models.bart.modeling_bart import BartForConditionalGeneration 11 | from transformers.tokenization_utils import PreTrainedTokenizer 12 | 13 | from elit.components.seq2seq.ner.conditional_seq2seq import ConditionalSeq2seq 14 | 15 | 16 | class ContentLogitsProcessor(LogitsProcessor): 17 | 18 | def __init__(self, tokenizer: PreTrainedTokenizer) -> None: 19 | super().__init__() 20 | self.bad_ids = sorted( 21 | 
sum(tokenizer(list(string.punctuation), add_special_tokens=False).input_ids, [tokenizer.eos_token_id])) 22 | self.pred_length = 0 23 | 24 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 25 | if self.pred_length: 26 | return scores 27 | self.pred_length += 1 28 | mask = torch.full_like(scores, 0) 29 | mask[:, self.bad_ids] = float('-inf') 30 | return scores + mask 31 | 32 | 33 | class ConstrainedBartForConditionalGeneration(BartForConditionalGeneration): 34 | def _get_logits_processor(self, repetition_penalty: float, no_repeat_ngram_size: int, 35 | encoder_no_repeat_ngram_size: int, encoder_input_ids: torch.LongTensor, 36 | bad_words_ids: List[List[int]], min_length: int, max_length: int, eos_token_id: int, 37 | forced_bos_token_id: int, forced_eos_token_id: int, 38 | prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, 39 | num_beam_groups: int, diversity_penalty: float, 40 | remove_invalid_values: bool, 41 | logits_processor: Optional[LogitsProcessorList]) -> LogitsProcessorList: 42 | logits_processor_list = super()._get_logits_processor(repetition_penalty, no_repeat_ngram_size, 43 | encoder_no_repeat_ngram_size, encoder_input_ids, 44 | bad_words_ids, min_length, max_length, eos_token_id, 45 | forced_bos_token_id, forced_eos_token_id, 46 | prefix_allowed_tokens_fn, num_beams, 47 | num_beam_groups, diversity_penalty, remove_invalid_values, 48 | logits_processor) 49 | processor = ContentLogitsProcessor(tokenizer=self.config.tokenizer) 50 | logits_processor_list.append(processor) 51 | return logits_processor_list 52 | 53 | 54 | class ConstrainedConditionalSeq2seq(ConditionalSeq2seq): 55 | def model_cls(self, **kwargs): 56 | return ConstrainedBartForConditionalGeneration 57 | -------------------------------------------------------------------------------- /elit/components/seq2seq/pos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 10:30 4 | -------------------------------------------------------------------------------- /elit/components/seq2seq/pos/verbalizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-24 13:16 4 | from abc import ABC, abstractmethod 5 | 6 | from hanlp_common.configurable import AutoConfigurable 7 | from transformers import BartTokenizer 8 | 9 | from elit.components.seq2seq.ner.seq2seq_ner import tokenize 10 | 11 | 12 | class Verbalizer(ABC, AutoConfigurable): 13 | def __call__(self, sample: dict): 14 | if 'tag' in sample: 15 | sample['prompt'] = self.to_prompt(sample['token'], sample['tag']) 16 | return sample 17 | 18 | def tokenize_prompt(self, prompt, tokenizer): 19 | return tokenize(prompt, tokenizer, '')[-1] 20 | 21 | @abstractmethod 22 | def to_prompt(self, tokens, tags): 23 | pass 24 | 25 | def recover_no_constraints(self, tokens, ids, tokenizer): 26 | raise NotImplementedError() 27 | 28 | 29 | class TagVerbalizer(Verbalizer): 30 | def to_prompt(self, tokens, tags): 31 | return tags 32 | 33 | def recover_no_constraints(self, tokens, ids, tokenizer): 34 | tags = tokenizer.convert_ids_to_tokens(ids) 35 | if len(tags) < len(tokens): 36 | tags += [None] * (len(tokens) - len(tags)) 37 | elif len(tags) > len(tokens): 38 | tags = tags[:len(tokens)] 39 | return tags 40 | 41 | 42 | class TokenTagVerbalizer(Verbalizer): 43 | def to_prompt(self, tokens, tags): 44 | return 
list(sum(zip(tokens, tags), ())) 45 | 46 | def tokenize_prompt(self, prompt, tokenizer: BartTokenizer): 47 | ids = [tokenizer.bos_token_id] 48 | for i, token_or_tag in enumerate(prompt): 49 | if i % 2: 50 | ids.append(tokenizer.convert_tokens_to_ids(token_or_tag)) 51 | else: 52 | ids.extend(tokenizer(' ' + token_or_tag, add_special_tokens=False).input_ids) 53 | ids.append(tokenizer.eos_token_id) 54 | return ids 55 | 56 | def recover_no_constraints(self, tokens, ids, tokenizer): 57 | generated_tokens = tokenizer.convert_ids_to_tokens(ids) 58 | print() 59 | 60 | 61 | class IsAVerbalizer(Verbalizer): 62 | def __init__(self, tag_to_phrase: dict, quotation=False, is_a_tag=False) -> None: 63 | super().__init__() 64 | self.is_a_tag = is_a_tag 65 | self.quotation = quotation 66 | self.tag_to_phrase = tag_to_phrase 67 | 68 | def to_prompt(self, tokens, tags): 69 | phrases = [] 70 | for token, tag in zip(tokens, tags): 71 | p = f'" {token} " is {self.tag_to_phrase[tag]};' if self.quotation else \ 72 | f'{token} is {self.tag_to_phrase[tag]};' 73 | phrases.append(p) 74 | return ' '.join(phrases) 75 | 76 | def tokenize_prompt(self, prompt, tokenizer): 77 | return tokenizer(prompt).input_ids 78 | -------------------------------------------------------------------------------- /elit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-06-13 18:15 4 | -------------------------------------------------------------------------------- /elit/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 11:49 -------------------------------------------------------------------------------- /elit/datasets/classification/sentiment.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 21:03 4 | _ERNIE_TASK_DATA = 'https://ernie.bj.bcebos.com/task_data_zh.tgz#' 5 | 6 | CHNSENTICORP_ERNIE_TRAIN = _ERNIE_TASK_DATA + 'chnsenticorp/train.tsv' 7 | CHNSENTICORP_ERNIE_DEV = _ERNIE_TASK_DATA + 'chnsenticorp/dev.tsv' 8 | CHNSENTICORP_ERNIE_TEST = _ERNIE_TASK_DATA + 'chnsenticorp/test.tsv' 9 | -------------------------------------------------------------------------------- /elit/datasets/coref/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-04 13:39 -------------------------------------------------------------------------------- /elit/datasets/coref/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:03 4 | -------------------------------------------------------------------------------- /elit/datasets/detokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-11-24 20:31 4 | -------------------------------------------------------------------------------- /elit/datasets/detokenization/detok.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-11-24 20:31 4 | from typing import Union, List, Callable 5 | 6 | from elit.common.dataset import 
TransformableDataset 7 | from elit.utils.io_util import load_jsonl 8 | 9 | 10 | class DetokenizationDataset(TransformableDataset): 11 | 12 | def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None, 13 | generate_idx=None, **kwargs) -> None: 14 | super().__init__(data, transform, cache, generate_idx) 15 | 16 | def load_file(self, filepath: str): 17 | for sample in load_jsonl(filepath): 18 | text = sample['text'] 19 | offsets = sample['offsets'] 20 | tokens = [text[x[0]:x[1]] for x in offsets] 21 | spaces = [' ' if x.isspace() else '' for x in text] + [''] 22 | tags = [spaces[x[1]] for x in offsets] 23 | yield {'token': tokens, 'tag': tags} 24 | -------------------------------------------------------------------------------- /elit/datasets/eos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-26 18:11 -------------------------------------------------------------------------------- /elit/datasets/eos/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:03 4 | -------------------------------------------------------------------------------- /elit/datasets/eos/loaders/nn_eos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-24 22:51 4 | _SETIMES2_EN_HR_SENTENCES_HOME = 'https://schweter.eu/cloud/nn_eos/SETIMES2.en-hr.sentences.tar.xz' 5 | SETIMES2_EN_HR_HR_SENTENCES_TRAIN = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.train' 6 | '''Training set of SETimes corpus.''' 7 | SETIMES2_EN_HR_HR_SENTENCES_DEV = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.dev' 8 | '''Dev set of SETimes corpus.''' 9 | SETIMES2_EN_HR_HR_SENTENCES_TEST = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.test' 10 | '''Test set of SETimes corpus.''' 11 | _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME = 'http://schweter.eu/cloud/nn_eos/europarl-v7.de-en.en.sentences.tar.xz' 12 | EUROPARL_V7_DE_EN_EN_SENTENCES_TRAIN = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.train' 13 | '''Training set of Europarl corpus (:cite:`koehn2005europarl`).''' 14 | EUROPARL_V7_DE_EN_EN_SENTENCES_DEV = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.dev' 15 | '''Dev set of Europarl corpus (:cite:`koehn2005europarl`).''' 16 | EUROPARL_V7_DE_EN_EN_SENTENCES_TEST = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.test' 17 | '''Test set of Europarl corpus (:cite:`koehn2005europarl`).''' 18 | -------------------------------------------------------------------------------- /elit/datasets/lm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-05 21:41 4 | 5 | _PTB_HOME = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz#' 6 | PTB_TOKEN_TRAIN = _PTB_HOME + 'data/ptb.train.txt' 7 | PTB_TOKEN_DEV = _PTB_HOME + 'data/ptb.valid.txt' 8 | PTB_TOKEN_TEST = _PTB_HOME + 'data/ptb.test.txt' 9 | 10 | PTB_CHAR_TRAIN = _PTB_HOME + 'data/ptb.char.train.txt' 11 | PTB_CHAR_DEV = _PTB_HOME + 'data/ptb.char.valid.txt' 12 | PTB_CHAR_TEST = _PTB_HOME + 'data/ptb.char.test.txt' 13 | -------------------------------------------------------------------------------- 
/elit/datasets/lm/doc_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-21 22:16 4 | from typing import Iterator, Any, Dict, Union, Callable, List 5 | 6 | from elit.common.dataset import TransformSequentialDataset 7 | 8 | 9 | class DocumentDataset(TransformSequentialDataset): 10 | 11 | def __init__(self, transform: Union[Callable, List] = None) -> None: 12 | """ 13 | Datasets where documents are segmented by two newlines and sentences are separated by one newline. 14 | 15 | Args: 16 | transform: 17 | """ 18 | super().__init__(transform) 19 | 20 | def __iter__(self) -> Iterator[Dict[str, Any]]: 21 | pass 22 | -------------------------------------------------------------------------------- /elit/datasets/lm/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:04 4 | -------------------------------------------------------------------------------- /elit/datasets/lm/sent_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-21 22:22 4 | from typing import Iterator, Any, Dict, Union, Callable, List 5 | 6 | from glob import glob 7 | 8 | from elit.common.dataset import TransformSequentialDataset 9 | 10 | 11 | class SentenceDataset(TransformSequentialDataset): 12 | def __init__(self, data, transform: Union[Callable, List] = None) -> None: 13 | """ 14 | Datasets where documents are segmented by two newlines and sentences are separated by one newline. 15 | 16 | Args: 17 | transform: 18 | """ 19 | super().__init__(transform) 20 | if isinstance(data, str): 21 | self.files = glob(data, recursive=True) 22 | assert self.files, f'No such file(s): {data}' 23 | else: 24 | self.files = None 25 | self.data = data 26 | 27 | def __iter__(self) -> Iterator[Dict[str, Any]]: 28 | for f in self.files: 29 | with open(f) as src: 30 | for line in src: 31 | line = line.strip() 32 | if not line: 33 | continue 34 | yield self.transform_sample({'sent': line}, inplace=True) 35 | -------------------------------------------------------------------------------- /elit/datasets/lu/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:08 4 | -------------------------------------------------------------------------------- /elit/datasets/lu/glue.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-10 11:47 4 | from elit.common.dataset import TableDataset 5 | 6 | STANFORD_SENTIMENT_TREEBANK_2_TRAIN = 'http://file.hankcs.com/corpus/SST2.zip#train.tsv' 7 | STANFORD_SENTIMENT_TREEBANK_2_DEV = 'http://file.hankcs.com/corpus/SST2.zip#dev.tsv' 8 | STANFORD_SENTIMENT_TREEBANK_2_TEST = 'http://file.hankcs.com/corpus/SST2.zip#test.tsv' 9 | 10 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TRAIN = 'http://file.hankcs.com/corpus/mrpc.zip#train.tsv' 11 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV = 'http://file.hankcs.com/corpus/mrpc.zip#dev.tsv' 12 | MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TEST = 'http://file.hankcs.com/corpus/mrpc.zip#test.tsv' 13 | 14 | 15 | class SST2Dataset(TableDataset): 16 | pass 17 | 18 | 19 | def main(): 20 | dataset = SST2Dataset(STANFORD_SENTIMENT_TREEBANK_2_TEST) 21 | print(dataset) 
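    # Smoke test only: constructing SST2Dataset with an "archive.zip#member" URL is
    # expected to download and cache the archive on first use and read test.tsv;
    # printing the dataset just confirms that loading succeeded.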
22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /elit/datasets/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-06 15:32 -------------------------------------------------------------------------------- /elit/datasets/ner/conll03.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-06 15:31 4 | 5 | 6 | CONLL03_EN_TRAIN = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.train.tsv' 7 | '''Training set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 8 | CONLL03_EN_DEV = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.dev.tsv' 9 | '''Dev set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 10 | CONLL03_EN_TEST = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.test.tsv' 11 | '''Test set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 12 | -------------------------------------------------------------------------------- /elit/datasets/ner/conll03_json.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-09 00:32 4 | import os 5 | 6 | from elit.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_DEV, CONLL03_EN_TEST 7 | from elit.utils.io_util import get_resource, replace_ext 8 | from elit.utils.span_util import ner_tsv_to_jsonlines 9 | 10 | CONLL03_EN_JSON_TRAIN = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.train.jsonlines' 11 | '''Training set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 12 | CONLL03_EN_JSON_DEV = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.dev.jsonlines' 13 | '''Dev set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 14 | CONLL03_EN_JSON_TEST = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.test.jsonlines' 15 | '''Test set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)''' 16 | 17 | 18 | def make_jsonlines_if_needed(): 19 | for tsv in [CONLL03_EN_TRAIN, CONLL03_EN_DEV, CONLL03_EN_TEST]: 20 | tsv = get_resource(tsv) 21 | if not os.path.isfile(replace_ext(tsv, '.jsonlines')): 22 | ner_tsv_to_jsonlines(tsv) 23 | 24 | 25 | make_jsonlines_if_needed() 26 | -------------------------------------------------------------------------------- /elit/datasets/ner/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:04 4 | -------------------------------------------------------------------------------- /elit/datasets/ner/msra.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 23:13 4 | 5 | _MSRA_NER_HOME = 'http://file.hankcs.com/corpus/msra_ner.zip' 6 | _MSRA_NER_TOKEN_LEVEL_HOME = 'http://file.hankcs.com/corpus/msra_ner_token_level.zip' 7 | 8 | MSRA_NER_CHAR_LEVEL_TRAIN = f'{_MSRA_NER_HOME}#train.tsv' 9 | '''Training set of MSRA (:cite:`levow-2006-third`) in character level.''' 10 | MSRA_NER_CHAR_LEVEL_DEV = f'{_MSRA_NER_HOME}#dev.tsv' 11 | '''Dev set of MSRA (:cite:`levow-2006-third`) in character level.''' 12 | MSRA_NER_CHAR_LEVEL_TEST = f'{_MSRA_NER_HOME}#test.tsv' 13 | '''Test set 
of MSRA (:cite:`levow-2006-third`) in character level.''' 14 | 15 | MSRA_NER_TOKEN_LEVEL_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.tsv' 16 | '''Training set of MSRA (:cite:`levow-2006-third`) in token level.''' 17 | MSRA_NER_TOKEN_LEVEL_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.tsv' 18 | '''Dev set of MSRA (:cite:`levow-2006-third`) in token level.''' 19 | MSRA_NER_TOKEN_LEVEL_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.tsv' 20 | '''Test set of MSRA (:cite:`levow-2006-third`) in token level.''' 21 | 22 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.tsv' 23 | '''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 24 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.tsv' 25 | '''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 26 | MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.tsv' 27 | '''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.''' 28 | 29 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.jsonlines' 30 | '''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.''' 31 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.jsonlines' 32 | '''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.''' 33 | MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.jsonlines' 34 | '''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.''' 35 | -------------------------------------------------------------------------------- /elit/datasets/ner/resume.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-08 12:10 4 | from elit.common.dataset import TransformableDataset 5 | 6 | from elit.utils.io_util import get_resource, generate_words_tags_from_tsv 7 | 8 | _RESUME_NER_HOME = 'https://github.com/jiesutd/LatticeLSTM/archive/master.zip#' 9 | 10 | RESUME_NER_TRAIN = _RESUME_NER_HOME + 'ResumeNER/train.char.bmes' 11 | '''Training set of Resume in char level.''' 12 | RESUME_NER_DEV = _RESUME_NER_HOME + 'ResumeNER/dev.char.bmes' 13 | '''Dev set of Resume in char level.''' 14 | RESUME_NER_TEST = _RESUME_NER_HOME + 'ResumeNER/test.char.bmes' 15 | '''Test set of Resume in char level.''' 16 | 17 | -------------------------------------------------------------------------------- /elit/datasets/ner/weibo.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 23:33 4 | from elit.common.dataset import TransformableDataset 5 | 6 | from elit.utils.io_util import get_resource, generate_words_tags_from_tsv 7 | 8 | _WEIBO_NER_HOME = 'https://github.com/hltcoe/golden-horse/archive/master.zip#data/' 9 | 10 | WEIBO_NER_TRAIN = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.train' 11 | '''Training set of Weibo in char level.''' 12 | WEIBO_NER_DEV = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.dev' 13 | '''Dev set of Weibo in char level.''' 14 | WEIBO_NER_TEST = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.test' 15 | '''Test set of Weibo in char level.''' 16 | 
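The NER corpus modules above only declare "archive#member" URL constants. As a minimal, hypothetical usage sketch (not part of the repository), such a constant can be resolved to a cached local file with get_resource, which other modules in this tree call with a single URL argument:

# Hypothetical sketch, assuming the repo-wide get_resource convention shown elsewhere in this tree.
from elit.datasets.ner.weibo import WEIBO_NER_TRAIN
from elit.utils.io_util import get_resource

train_path = get_resource(WEIBO_NER_TRAIN)  # downloads and extracts the archive once, returns the member's local path
print(train_path)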
-------------------------------------------------------------------------------- /elit/datasets/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 00:51 4 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb5.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 18:44 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | _CTB_HOME = HANLP_URL + 'embeddings/SUDA-LA-CIP_20200109_021624.zip#' 7 | 8 | _CTB5_DEP_HOME = _CTB_HOME + 'BPNN/data/ctb5/' 9 | 10 | CTB5_DEP_TRAIN = _CTB5_DEP_HOME + 'train.conll' 11 | '''Training set for ctb5 dependency parsing.''' 12 | CTB5_DEP_DEV = _CTB5_DEP_HOME + 'dev.conll' 13 | '''Dev set for ctb5 dependency parsing.''' 14 | CTB5_DEP_TEST = _CTB5_DEP_HOME + 'test.conll' 15 | '''Test set for ctb5 dependency parsing.''' 16 | 17 | CIP_W2V_100_CN = _CTB_HOME + 'BPNN/data/embed.txt' 18 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb7.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 18:44 4 | from elit.datasets.parsing.ctb5 import _CTB_HOME 5 | 6 | _CTB7_HOME = _CTB_HOME + 'BPNN/data/ctb7/' 7 | 8 | CTB7_DEP_TRAIN = _CTB7_HOME + 'train.conll' 9 | '''Training set for ctb7 dependency parsing.''' 10 | CTB7_DEP_DEV = _CTB7_HOME + 'dev.conll' 11 | '''Dev set for ctb7 dependency parsing.''' 12 | CTB7_DEP_TEST = _CTB7_HOME + 'test.conll' 13 | '''Test set for ctb7 dependency parsing.''' 14 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb8.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-10-14 20:54 4 | 5 | from elit.datasets.parsing.loaders._ctb_utils import make_ctb 6 | 7 | _CTB8_HOME = 'https://wakespace.lib.wfu.edu/bitstream/handle/10339/39379/LDC2013T21.tgz#data/' 8 | 9 | CTB8_CWS_TRAIN = _CTB8_HOME + 'tasks/cws/train.txt' 10 | '''Training set for ctb8 Chinese word segmentation.''' 11 | CTB8_CWS_DEV = _CTB8_HOME + 'tasks/cws/dev.txt' 12 | '''Dev set for ctb8 Chinese word segmentation.''' 13 | CTB8_CWS_TEST = _CTB8_HOME + 'tasks/cws/test.txt' 14 | '''Test set for ctb8 Chinese word segmentation.''' 15 | 16 | CTB8_POS_TRAIN = _CTB8_HOME + 'tasks/pos/train.tsv' 17 | '''Training set for ctb8 PoS tagging.''' 18 | CTB8_POS_DEV = _CTB8_HOME + 'tasks/pos/dev.tsv' 19 | '''Dev set for ctb8 PoS tagging.''' 20 | CTB8_POS_TEST = _CTB8_HOME + 'tasks/pos/test.tsv' 21 | '''Test set for ctb8 PoS tagging.''' 22 | 23 | CTB8_BRACKET_LINE_TRAIN = _CTB8_HOME + 'tasks/par/train.txt' 24 | '''Training set for ctb8 constituency parsing with empty categories.''' 25 | CTB8_BRACKET_LINE_DEV = _CTB8_HOME + 'tasks/par/dev.txt' 26 | '''Dev set for ctb8 constituency parsing with empty categories.''' 27 | CTB8_BRACKET_LINE_TEST = _CTB8_HOME + 'tasks/par/test.txt' 28 | '''Test set for ctb8 constituency parsing with empty categories.''' 29 | 30 | CTB8_BRACKET_LINE_NOEC_TRAIN = _CTB8_HOME + 'tasks/par/train.noempty.txt' 31 | '''Training set for ctb8 constituency parsing without empty categories.''' 32 | CTB8_BRACKET_LINE_NOEC_DEV = _CTB8_HOME + 'tasks/par/dev.noempty.txt' 33 | '''Dev set for ctb8 constituency 
parsing without empty categories.''' 34 | CTB8_BRACKET_LINE_NOEC_TEST = _CTB8_HOME + 'tasks/par/test.noempty.txt' 35 | '''Test set for ctb8 constituency parsing without empty categories.''' 36 | 37 | CTB8_SD330_TRAIN = _CTB8_HOME + 'tasks/dep/train.conllx' 38 | '''Training set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 39 | CTB8_SD330_DEV = _CTB8_HOME + 'tasks/dep/dev.conllx' 40 | '''Dev set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 41 | CTB8_SD330_TEST = _CTB8_HOME + 'tasks/dep/test.conllx' 42 | '''Test set for ctb8 in Stanford Dependencies 3.3.0 standard.''' 43 | 44 | make_ctb(_CTB8_HOME) 45 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ctb9.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-10-14 20:54 4 | from urllib.error import HTTPError 5 | 6 | from elit.datasets.parsing.loaders._ctb_utils import make_ctb 7 | from elit.utils.io_util import get_resource, path_from_url 8 | 9 | _CTB9_HOME = 'https://catalog.ldc.upenn.edu/LDC2016T13/ctb9.0_LDC2016T13.tgz#data/' 10 | 11 | CTB9_CWS_TRAIN = _CTB9_HOME + 'tasks/cws/train.txt' 12 | '''Training set for ctb9 Chinese word segmentation.''' 13 | CTB9_CWS_DEV = _CTB9_HOME + 'tasks/cws/dev.txt' 14 | '''Dev set for ctb9 Chinese word segmentation.''' 15 | CTB9_CWS_TEST = _CTB9_HOME + 'tasks/cws/test.txt' 16 | '''Test set for ctb9 Chinese word segmentation.''' 17 | 18 | CTB9_POS_TRAIN = _CTB9_HOME + 'tasks/pos/train.tsv' 19 | '''Training set for ctb9 PoS tagging.''' 20 | CTB9_POS_DEV = _CTB9_HOME + 'tasks/pos/dev.tsv' 21 | '''Dev set for ctb9 PoS tagging.''' 22 | CTB9_POS_TEST = _CTB9_HOME + 'tasks/pos/test.tsv' 23 | '''Test set for ctb9 PoS tagging.''' 24 | 25 | CTB9_BRACKET_LINE_TRAIN = _CTB9_HOME + 'tasks/par/train.txt' 26 | '''Training set for ctb9 constituency parsing with empty categories.''' 27 | CTB9_BRACKET_LINE_DEV = _CTB9_HOME + 'tasks/par/dev.txt' 28 | '''Dev set for ctb9 constituency parsing with empty categories.''' 29 | CTB9_BRACKET_LINE_TEST = _CTB9_HOME + 'tasks/par/test.txt' 30 | '''Test set for ctb9 constituency parsing with empty categories.''' 31 | 32 | CTB9_BRACKET_LINE_NOEC_TRAIN = _CTB9_HOME + 'tasks/par/train.noempty.txt' 33 | '''Training set for ctb9 constituency parsing without empty categories.''' 34 | CTB9_BRACKET_LINE_NOEC_DEV = _CTB9_HOME + 'tasks/par/dev.noempty.txt' 35 | '''Dev set for ctb9 constituency parsing without empty categories.''' 36 | CTB9_BRACKET_LINE_NOEC_TEST = _CTB9_HOME + 'tasks/par/test.noempty.txt' 37 | '''Test set for ctb9 constituency parsing without empty categories.''' 38 | 39 | CTB9_SD330_TRAIN = _CTB9_HOME + 'tasks/dep/train.conllx' 40 | '''Training set for ctb9 in Stanford Dependencies 3.3.0 standard.''' 41 | CTB9_SD330_DEV = _CTB9_HOME + 'tasks/dep/dev.conllx' 42 | '''Dev set for ctb9 in Stanford Dependencies 3.3.0 standard.''' 43 | CTB9_SD330_TEST = _CTB9_HOME + 'tasks/dep/test.conllx' 44 | '''Test set for ctb9 in Stanford Dependencies 3.3.0 standard.''' 45 | 46 | try: 47 | get_resource(_CTB9_HOME) 48 | except HTTPError: 49 | raise FileNotFoundError( 50 | 'Chinese Treebank 9.0 is a copyright dataset owned by LDC which we cannot re-distribute. 
' 51 | f'Please apply for a licence from LDC (https://catalog.ldc.upenn.edu/LDC2016T13) ' 52 | f'then download it to {path_from_url(_CTB9_HOME)}' 53 | ) 54 | 55 | make_ctb(_CTB9_HOME) 56 | -------------------------------------------------------------------------------- /elit/datasets/parsing/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:04 4 | -------------------------------------------------------------------------------- /elit/datasets/parsing/pmt1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-02-15 04:14 4 | import os.path 5 | 6 | from elit.utils.io_util import get_resource 7 | from elit.utils.log_util import cprint 8 | from hanlp_common.conll import CoNLLSentence, CoNLLWord 9 | 10 | _HOME = 'https://github.com/qiulikun/PKUMultiviewTreebank/archive/refs/heads/master.zip' 11 | PTM_V1_RAW = _HOME + '#199801_dependency_treebank_2014pos.txt' 12 | PTM_V1_TRAIN = _HOME + '#train.conllx' 13 | 'The training set of PKU Multi-view Chinese Treebank (PMT) 1.0 (:cite:`qiu-etal-2014-multi`).' 14 | PTM_V1_DEV = _HOME + '#dev.conllx' 15 | 'The dev set of PKU Multi-view Chinese Treebank (PMT) 1.0 (:cite:`qiu-etal-2014-multi`).' 16 | PTM_V1_TEST = _HOME + '#test.conllx' 17 | 'The test set of PKU Multi-view Chinese Treebank (PMT) 1.0 (:cite:`qiu-etal-2014-multi`).' 18 | 19 | 20 | def _make_ptm(): 21 | raw = get_resource(PTM_V1_RAW) 22 | home = os.path.dirname(raw) 23 | done = True 24 | for part in ['train', 'dev', 'test']: 25 | if not os.path.isfile(os.path.join(home, f'{part}.conllx')): 26 | done = False 27 | break 28 | if done: 29 | return 30 | sents = [] 31 | with open(raw) as src: 32 | buffer = [] 33 | for line in src: 34 | line = line.strip() 35 | if line: 36 | buffer.append(line) 37 | else: 38 | if buffer: 39 | tok, pos, rel, arc = [x.split() for x in buffer] 40 | sent = CoNLLSentence() 41 | for i, (t, p, r, a) in enumerate(zip(tok, pos, rel, arc)): 42 | sent.append(CoNLLWord(i + 1, form=t, cpos=p, head=a, deprel=r)) 43 | sents.append(sent) 44 | buffer.clear() 45 | 46 | prev_offset = 0 47 | # Sentences 12001-13000 and 13001-14463 are used as the development and test set, respectively. The remaining 48 | # sentences are used as training data. 49 | for part, offset in zip(['train', 'dev', 'test'], [12000, 13000, 14463]): 50 | with open(os.path.join(home, f'{part}.conllx'), 'w') as out: 51 | portion = sents[prev_offset:offset] 52 | cprint(f'[yellow]{len(portion)}[/yellow] sentences [cyan][{prev_offset + 1}:{offset})[/cyan] in {part}') 53 | for sent in portion: 54 | out.write(str(sent) + '\n\n') 55 | prev_offset = offset 56 | 57 | 58 | _make_ptm() 59 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ptb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-02-17 15:46 4 | 5 | _PTB_HOME = 'https://github.com/KhalilMrini/LAL-Parser/archive/master.zip#data/' 6 | 7 | PTB_TRAIN = _PTB_HOME + '02-21.10way.clean' 8 | '''Training set of PTB without empty categories. PoS tags are automatically predicted using 10-fold 9 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 10 | PTB_DEV = _PTB_HOME + '22.auto.clean' 11 | '''Dev set of PTB without empty categories. 
PoS tags are automatically predicted using 10-fold 12 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 13 | PTB_TEST = _PTB_HOME + '23.auto.clean' 14 | '''Test set of PTB without empty categories. PoS tags are automatically predicted using 10-fold 15 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 16 | 17 | PTB_SD330_TRAIN = _PTB_HOME + 'ptb_train_3.3.0.sd.clean' 18 | '''Training set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold 19 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 20 | PTB_SD330_DEV = _PTB_HOME + 'ptb_dev_3.3.0.sd.clean' 21 | '''Dev set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold 22 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 23 | PTB_SD330_TEST = _PTB_HOME + 'ptb_test_3.3.0.sd.clean' 24 | '''Test set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold 25 | jackknifing (:cite:`collins-koo-2005-discriminative`).''' 26 | 27 | PTB_TOKEN_MAPPING = { 28 | "-LRB-": "(", 29 | "-RRB-": ")", 30 | "-LCB-": "{", 31 | "-RCB-": "}", 32 | "-LSB-": "[", 33 | "-RSB-": "]", 34 | "``": '"', 35 | "''": '"', 36 | "`": "'", 37 | '«': '"', 38 | '»': '"', 39 | '‘': "'", 40 | '’': "'", 41 | '“': '"', 42 | '”': '"', 43 | '„': '"', 44 | '‹': "'", 45 | '›': "'", 46 | "\u2013": "--", # en dash 47 | "\u2014": "--", # em dash 48 | } 49 | -------------------------------------------------------------------------------- /elit/datasets/parsing/semeval15.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-28 14:40 4 | # from elit.datasets.parsing.conll_dataset import CoNLLParsingDataset 5 | # 6 | # 7 | # class SemEval15Dataset(CoNLLParsingDataset): 8 | # def load_file(self, filepath: str): 9 | # pass 10 | import warnings 11 | 12 | from hanlp_common.constant import ROOT, PAD 13 | from hanlp_common.conll import CoNLLSentence 14 | 15 | 16 | def unpack_deps_to_head_deprel(sample: dict, pad_rel=None, arc_key='arc', rel_key='rel'): 17 | if 'DEPS' in sample: 18 | deps = ['_'] + sample['DEPS'] 19 | sample[arc_key] = arc = [] 20 | sample[rel_key] = rel = [] 21 | for each in deps: 22 | arc_per_token = [False] * len(deps) 23 | rel_per_token = [None] * len(deps) 24 | if each != '_': 25 | for ar in each.split('|'): 26 | a, r = ar.split(':') 27 | a = int(a) 28 | arc_per_token[a] = True 29 | rel_per_token[a] = r 30 | if not pad_rel: 31 | pad_rel = r 32 | arc.append(arc_per_token) 33 | rel.append(rel_per_token) 34 | if not pad_rel: 35 | pad_rel = PAD 36 | for i in range(len(rel)): 37 | rel[i] = [r if r else pad_rel for r in rel[i]] 38 | return sample 39 | 40 | 41 | def append_bos_to_form_pos(sample, pos_key='CPOS'): 42 | sample['token'] = [ROOT] + sample['FORM'] 43 | if pos_key in sample: 44 | sample['pos'] = [ROOT] + sample[pos_key] 45 | return sample 46 | 47 | 48 | def merge_head_deprel_with_2nd(sample: dict): 49 | if 'arc' in sample: 50 | arc_2nd = sample['arc_2nd'] 51 | rel_2nd = sample['rel_2nd'] 52 | for i, (arc, rel) in enumerate(zip(sample['arc'], sample['rel'])): 53 | if i: 54 | if arc_2nd[i][arc] and rel_2nd[i][arc] != rel: 55 | sample_str = CoNLLSentence.from_dict(sample, conllu=True).to_markdown() 56 | warnings.warn(f'The main dependency conflicts with 2nd dependency at ID={i}, ' \ 57 | 'which means joint mode might not be suitable. 
' \ 58 | f'The sample is\n{sample_str}') 59 | arc_2nd[i][arc] = True 60 | rel_2nd[i][arc] = rel 61 | return sample 62 | -------------------------------------------------------------------------------- /elit/datasets/parsing/semeval16.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 00:51 4 | from hanlp_common.conll import CoNLLSentence 5 | import os 6 | 7 | from elit.utils.io_util import get_resource, merge_files 8 | from hanlp_common.io import eprint 9 | 10 | _SEMEVAL2016_HOME = 'https://github.com/HIT-SCIR/SemEval-2016/archive/master.zip' 11 | 12 | SEMEVAL2016_NEWS_TRAIN = _SEMEVAL2016_HOME + '#train/news.train.conll' 13 | SEMEVAL2016_NEWS_DEV = _SEMEVAL2016_HOME + '#validation/news.valid.conll' 14 | SEMEVAL2016_NEWS_TEST = _SEMEVAL2016_HOME + '#test/news.test.conll' 15 | 16 | SEMEVAL2016_NEWS_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/news.train.conllu' 17 | SEMEVAL2016_NEWS_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/news.valid.conllu' 18 | SEMEVAL2016_NEWS_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/news.test.conllu' 19 | 20 | SEMEVAL2016_TEXT_TRAIN = _SEMEVAL2016_HOME + '#train/text.train.conll' 21 | SEMEVAL2016_TEXT_DEV = _SEMEVAL2016_HOME + '#validation/text.valid.conll' 22 | SEMEVAL2016_TEXT_TEST = _SEMEVAL2016_HOME + '#test/text.test.conll' 23 | 24 | SEMEVAL2016_TEXT_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/text.train.conllu' 25 | SEMEVAL2016_TEXT_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/text.valid.conllu' 26 | SEMEVAL2016_TEXT_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/text.test.conllu' 27 | 28 | SEMEVAL2016_FULL_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/full.train.conllu' 29 | SEMEVAL2016_FULL_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/full.valid.conllu' 30 | SEMEVAL2016_FULL_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/full.test.conllu' 31 | 32 | 33 | def convert_conll_to_conllu(path): 34 | sents = CoNLLSentence.from_file(path, conllu=True) 35 | with open(os.path.splitext(path)[0] + '.conllu', 'w') as out: 36 | for sent in sents: 37 | for word in sent: 38 | if not word.deps: 39 | word.deps = [(word.head, word.deprel)] 40 | word.head = None 41 | word.deprel = None 42 | out.write(str(sent)) 43 | out.write('\n\n') 44 | 45 | 46 | for file in [SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_DEV, SEMEVAL2016_NEWS_TEST, 47 | SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_DEV, SEMEVAL2016_TEXT_TEST]: 48 | file = get_resource(file) 49 | conllu = os.path.splitext(file)[0] + '.conllu' 50 | if not os.path.isfile(conllu): 51 | eprint(f'Converting {os.path.basename(file)} to {os.path.basename(conllu)} ...') 52 | convert_conll_to_conllu(file) 53 | 54 | for group, part in zip([[SEMEVAL2016_NEWS_TRAIN_CONLLU, SEMEVAL2016_TEXT_TRAIN_CONLLU], 55 | [SEMEVAL2016_NEWS_DEV_CONLLU, SEMEVAL2016_TEXT_DEV_CONLLU], 56 | [SEMEVAL2016_NEWS_TEST_CONLLU, SEMEVAL2016_TEXT_TEST_CONLLU]], 57 | ['train', 'valid', 'test']): 58 | root = get_resource(_SEMEVAL2016_HOME) 59 | dst = f'{root}/train/full.{part}.conllu' 60 | if not os.path.isfile(dst): 61 | group = [get_resource(x) for x in group] 62 | eprint(f'Concatenating {os.path.basename(group[0])} and {os.path.basename(group[1])} ' 63 | f'into full dataset {os.path.basename(dst)} ...') 64 | merge_files(group, dst) 65 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # 
Date: 2020-12-07 21:45 4 | import os 5 | import shutil 6 | 7 | from elit.components.parsers.ud.udify_util import get_ud_treebank_files 8 | from elit.utils.io_util import get_resource 9 | from elit.utils.log_util import flash 10 | 11 | 12 | def concat_treebanks(home, version): 13 | ud_home = get_resource(home) 14 | treebanks = get_ud_treebank_files(ud_home) 15 | output_dir = os.path.abspath(os.path.join(ud_home, os.path.pardir, os.path.pardir, f'ud-multilingual-v{version}')) 16 | if os.path.isdir(output_dir): 17 | return output_dir 18 | os.makedirs(output_dir) 19 | train, dev, test = list(zip(*[treebanks[k] for k in treebanks])) 20 | 21 | for treebank, name in zip([train, dev, test], ["train.conllu", "dev.conllu", "test.conllu"]): 22 | flash(f'Concatenating {len(train)} treebanks into {name} [blink][yellow]...[/yellow][/blink]') 23 | with open(os.path.join(output_dir, name), 'w') as write: 24 | for t in treebank: 25 | if not t: 26 | continue 27 | with open(t, 'r') as read: 28 | shutil.copyfileobj(read, write) 29 | flash('') 30 | return output_dir 31 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud210m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from elit.datasets.parsing.ud.ud210 import _UD_210_HOME 8 | 9 | _UD_210_MULTILINGUAL_HOME = concat_treebanks(_UD_210_HOME, '2.10') 10 | UD_210_MULTILINGUAL_TRAIN = os.path.join(_UD_210_MULTILINGUAL_HOME, 'train.conllu') 11 | "Training set of multilingual UD_210 obtained by concatenating all training sets." 12 | UD_210_MULTILINGUAL_DEV = os.path.join(_UD_210_MULTILINGUAL_HOME, 'dev.conllu') 13 | "Dev set of multilingual UD_210 obtained by concatenating all dev sets." 14 | UD_210_MULTILINGUAL_TEST = os.path.join(_UD_210_MULTILINGUAL_HOME, 'test.conllu') 15 | "Test set of multilingual UD_210 obtained by concatenating all test sets." 16 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud23m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from .ud23 import _UD_23_HOME 8 | 9 | _UD_23_MULTILINGUAL_HOME = concat_treebanks(_UD_23_HOME, '2.3') 10 | UD_23_MULTILINGUAL_TRAIN = os.path.join(_UD_23_MULTILINGUAL_HOME, 'train.conllu') 11 | UD_23_MULTILINGUAL_DEV = os.path.join(_UD_23_MULTILINGUAL_HOME, 'dev.conllu') 12 | UD_23_MULTILINGUAL_TEST = os.path.join(_UD_23_MULTILINGUAL_HOME, 'test.conllu') 13 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud27m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from elit.datasets.parsing.ud.ud27 import _UD_27_HOME 8 | 9 | _UD_27_MULTILINGUAL_HOME = concat_treebanks(_UD_27_HOME, '2.7') 10 | UD_27_MULTILINGUAL_TRAIN = os.path.join(_UD_27_MULTILINGUAL_HOME, 'train.conllu') 11 | "Training set of multilingual UD_27 obtained by concatenating all training sets." 
12 | UD_27_MULTILINGUAL_DEV = os.path.join(_UD_27_MULTILINGUAL_HOME, 'dev.conllu') 13 | "Dev set of multilingual UD_27 obtained by concatenating all dev sets." 14 | UD_27_MULTILINGUAL_TEST = os.path.join(_UD_27_MULTILINGUAL_HOME, 'test.conllu') 15 | "Test set of multilingual UD_27 obtained by concatenating all test sets." 16 | -------------------------------------------------------------------------------- /elit/datasets/parsing/ud/ud28m.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-21 20:39 4 | import os 5 | 6 | from elit.datasets.parsing.ud import concat_treebanks 7 | from elit.datasets.parsing.ud.ud28 import _UD_28_HOME 8 | 9 | _UD_28_MULTILINGUAL_HOME = concat_treebanks(_UD_28_HOME, '2.8') 10 | UD_28_MULTILINGUAL_TRAIN = os.path.join(_UD_28_MULTILINGUAL_HOME, 'train.conllu') 11 | "Training set of multilingual UD_28 obtained by concatenating all training sets." 12 | UD_28_MULTILINGUAL_DEV = os.path.join(_UD_28_MULTILINGUAL_HOME, 'dev.conllu') 13 | "Dev set of multilingual UD_28 obtained by concatenating all dev sets." 14 | UD_28_MULTILINGUAL_TEST = os.path.join(_UD_28_MULTILINGUAL_HOME, 'test.conllu') 15 | "Test set of multilingual UD_28 obtained by concatenating all test sets." 16 | -------------------------------------------------------------------------------- /elit/datasets/pos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:50 -------------------------------------------------------------------------------- /elit/datasets/pos/ctb5.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:51 4 | 5 | _CTB5_POS_HOME = 'http://file.hankcs.com/corpus/ctb5.1-pos.zip' 6 | 7 | CTB5_POS_TRAIN = f'{_CTB5_POS_HOME}#train.tsv' 8 | '''PoS training set for CTB5.''' 9 | CTB5_POS_DEV = f'{_CTB5_POS_HOME}#dev.tsv' 10 | '''PoS dev set for CTB5.''' 11 | CTB5_POS_TEST = f'{_CTB5_POS_HOME}#test.tsv' 12 | '''PoS test set for CTB5.''' 13 | -------------------------------------------------------------------------------- /elit/datasets/pos/ptb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-26 19:11 4 | import os 5 | import re 6 | from glob import glob 7 | from urllib.error import HTTPError 8 | from elit.utils.io_util import get_resource, path_from_url 9 | 10 | _PTB_POS_HOME = 'https://catalog.ldc.upenn.edu/LDC99T42/LDC99T42.tgz#treebank_3/tagged/pos/wsj/' 11 | 12 | PTB_POS_TRAIN = _PTB_POS_HOME + 'train.tsv' 13 | '''Training set for PTB PoS tagging.''' 14 | PTB_POS_DEV = _PTB_POS_HOME + 'dev.tsv' 15 | '''Dev set for PTB PoS tagging.''' 16 | PTB_POS_TEST = _PTB_POS_HOME + 'test.tsv' 17 | '''Test set for PTB PoS tagging.''' 18 | 19 | try: 20 | get_resource(_PTB_POS_HOME, verbose=False) 21 | except HTTPError: 22 | raise FileNotFoundError( 23 | 'The Penn Treebank is a copyright dataset owned by LDC which we cannot re-distribute. 
' 24 | f'Please apply for a licence from LDC (https://catalog.ldc.upenn.edu/LDC99T42) ' 25 | f'then download it to {path_from_url(_PTB_POS_HOME)}' 26 | ) from None 27 | 28 | _TOKEN_TAG = re.compile(r'\S+/\S+') 29 | 30 | 31 | def _make_ptb_pos(): 32 | home = get_resource(_PTB_POS_HOME) 33 | training = list(range(0, 18 + 1)) 34 | development = list(range(19, 21 + 1)) 35 | test = list(range(22, 24 + 1)) 36 | for part, ids in zip(['train', 'dev', 'test'], [training, development, test]): 37 | out = f'{home}{part}.tsv' 38 | if os.path.isfile(out): 39 | continue 40 | with open(out, 'w') as out: 41 | dataset = [] 42 | for fid in ids: 43 | for file in sorted(glob(f'{home}{fid:02d}/*.pos')): 44 | with open(file) as src: 45 | sent = [] 46 | for line in src: 47 | line = line.strip() 48 | if not line: 49 | if sent: 50 | dataset.append(sent) 51 | sent = [] 52 | elif line.startswith('=========='): 53 | continue 54 | else: 55 | for pair in _TOKEN_TAG.findall(line): 56 | pair = pair.rsplit('/', 1) 57 | sent.append(pair) 58 | if sent: 59 | dataset.append(sent) 60 | 61 | for sent in dataset: 62 | for token, pos in sent: 63 | out.write(f'{token}\t{pos}\n') 64 | out.write('\n') 65 | 66 | 67 | _make_ptb_pos() 68 | -------------------------------------------------------------------------------- /elit/datasets/qa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-03-20 19:17 -------------------------------------------------------------------------------- /elit/datasets/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-22 19:15 4 | 5 | 6 | -------------------------------------------------------------------------------- /elit/datasets/srl/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:05 4 | -------------------------------------------------------------------------------- /elit/datasets/srl/ontonotes4/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-26 16:07 4 | ONTONOTES4_HOME = 'https://catalog.ldc.upenn.edu/LDC2011T03/ontonotes-release-4.0_LDC2011T03.tgz#/ontonotes-release-4.0/data/' 5 | ONTONOTES4_TASKS_HOME = ONTONOTES4_HOME + '../tasks/' 6 | -------------------------------------------------------------------------------- /elit/datasets/srl/ontonotes5/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-11-26 16:07 4 | ONTONOTES5_HOME = 'https://catalog.ldc.upenn.edu/LDC2013T19/LDC2013T19.tgz#/ontonotes-release-5.0/data/' 5 | CONLL12_HOME = ONTONOTES5_HOME + '../conll-2012/' 6 | -------------------------------------------------------------------------------- /elit/datasets/sts/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-20 16:25 4 | -------------------------------------------------------------------------------- /elit/datasets/sts/stsb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-20 16:25 4 | from typing import Union, List, 
Callable 5 | 6 | from elit.common.dataset import TransformableDataset 7 | from elit.utils.io_util import read_cells 8 | 9 | STS_B_TRAIN = 'http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz#sts-train.csv' 10 | STS_B_DEV = 'http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz#sts-dev.csv' 11 | STS_B_TEST = 'http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz#sts-test.csv' 12 | 13 | 14 | class SemanticTextualSimilarityDataset(TransformableDataset): 15 | def __init__(self, 16 | data: Union[str, List], 17 | sent_a_col, 18 | sent_b_col, 19 | similarity_col, 20 | delimiter='auto', 21 | transform: Union[Callable, List] = None, 22 | cache=None, 23 | generate_idx=None) -> None: 24 | self.delimiter = delimiter 25 | self.similarity_col = similarity_col 26 | self.sent_b_col = sent_b_col 27 | self.sent_a_col = sent_a_col 28 | super().__init__(data, transform, cache, generate_idx) 29 | 30 | def load_file(self, filepath: str): 31 | for i, cells in enumerate(read_cells(filepath, strip=True, delimiter=self.delimiter)): 32 | yield { 33 | 'sent_a': cells[self.sent_a_col], 34 | 'sent_b': cells[self.sent_b_col], 35 | 'similarity': float(cells[self.similarity_col]) 36 | } 37 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-01 12:33 -------------------------------------------------------------------------------- /elit/datasets/tokenization/ctb6.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 22:19 4 | 5 | _CTB6_CWS_HOME = 'http://file.hankcs.com/corpus/ctb6_cws.zip' 6 | 7 | CTB6_CWS_TRAIN = _CTB6_CWS_HOME + '#train.txt' 8 | '''CTB6 training set.''' 9 | CTB6_CWS_DEV = _CTB6_CWS_HOME + '#dev.txt' 10 | '''CTB6 dev set.''' 11 | CTB6_CWS_TEST = _CTB6_CWS_HOME + '#test.txt' 12 | '''CTB6 test set.''' 13 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:06 4 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/loaders/chunking_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 18:50 4 | from typing import Union, List, Callable 5 | 6 | from elit.common.dataset import TransformableDataset 7 | from elit.utils.io_util import get_resource 8 | from elit.utils.span_util import bmes_of 9 | from elit.utils.string_util import ispunct 10 | 11 | 12 | class ChunkingDataset(TransformableDataset): 13 | 14 | def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None, 15 | generate_idx=None, max_seq_len=None, sent_delimiter=None) -> None: 16 | if not sent_delimiter: 17 | sent_delimiter = lambda x: ispunct(x) 18 | elif isinstance(sent_delimiter, str): 19 | sent_delimiter = set(list(sent_delimiter)) 20 | sent_delimiter = lambda x: x in sent_delimiter 21 | self.sent_delimiter = sent_delimiter 22 | self.max_seq_len = max_seq_len 23 | super().__init__(data, transform, cache, generate_idx) 24 | 25 | def load_file(self, filepath): 26 | max_seq_len = 
self.max_seq_len 27 | delimiter = self.sent_delimiter 28 | for chars, tags in self._generate_chars_tags(filepath, delimiter, max_seq_len): 29 | yield {'char': chars, 'tag': tags} 30 | 31 | @staticmethod 32 | def _generate_chars_tags(filepath, delimiter, max_seq_len): 33 | filepath = get_resource(filepath) 34 | with open(filepath, encoding='utf8') as src: 35 | for text in src: 36 | chars, tags = bmes_of(text, True) 37 | if max_seq_len and delimiter and len(chars) > max_seq_len: 38 | short_chars, short_tags = [], [] 39 | for idx, (char, tag) in enumerate(zip(chars, tags)): 40 | short_chars.append(char) 41 | short_tags.append(tag) 42 | if len(short_chars) >= max_seq_len and delimiter(char): 43 | yield short_chars, short_tags 44 | short_chars, short_tags = [], [] 45 | if short_chars: 46 | yield short_chars, short_tags 47 | else: 48 | yield chars, tags 49 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/loaders/multi_criteria_cws/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-11 20:35 4 | 5 | _HOME = 'https://github.com/hankcs/multi-criteria-cws/archive/naive-mix.zip#data/raw/' 6 | 7 | CNC_TRAIN_ALL = _HOME + 'cnc/train-all.txt' 8 | CNC_TRAIN = _HOME + 'cnc/train.txt' 9 | CNC_DEV = _HOME + 'cnc/dev.txt' 10 | CNC_TEST = _HOME + 'cnc/test.txt' 11 | 12 | CTB_TRAIN_ALL = _HOME + 'ctb/train-all.txt' 13 | CTB_TRAIN = _HOME + 'ctb/train.txt' 14 | CTB_DEV = _HOME + 'ctb/dev.txt' 15 | CTB_TEST = _HOME + 'ctb/test.txt' 16 | 17 | SXU_TRAIN_ALL = _HOME + 'sxu/train-all.txt' 18 | SXU_TRAIN = _HOME + 'sxu/train.txt' 19 | SXU_DEV = _HOME + 'sxu/dev.txt' 20 | SXU_TEST = _HOME + 'sxu/test.txt' 21 | 22 | UDC_TRAIN_ALL = _HOME + 'udc/train-all.txt' 23 | UDC_TRAIN = _HOME + 'udc/train.txt' 24 | UDC_DEV = _HOME + 'udc/dev.txt' 25 | UDC_TEST = _HOME + 'udc/test.txt' 26 | 27 | WTB_TRAIN_ALL = _HOME + 'wtb/train-all.txt' 28 | WTB_TRAIN = _HOME + 'wtb/train.txt' 29 | WTB_DEV = _HOME + 'wtb/dev.txt' 30 | WTB_TEST = _HOME + 'wtb/test.txt' 31 | 32 | ZX_TRAIN_ALL = _HOME + 'zx/train-all.txt' 33 | ZX_TRAIN = _HOME + 'zx/train.txt' 34 | ZX_DEV = _HOME + 'zx/dev.txt' 35 | ZX_TEST = _HOME + 'zx/test.txt' 36 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/sighan2005/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | import os 5 | 6 | from elit.utils.io_util import get_resource, split_file 7 | from elit.utils.log_util import logger 8 | 9 | SIGHAN2005 = 'http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip' 10 | 11 | 12 | def make(train): 13 | root = get_resource(SIGHAN2005) 14 | train = os.path.join(root, train.split('#')[-1]) 15 | if not os.path.isfile(train): 16 | full = train.replace('_90.txt', '.utf8') 17 | logger.info(f'Splitting {full} into training set and valid set with 9:1 proportion') 18 | valid = train.replace('90.txt', '10.txt') 19 | split_file(full, train=0.9, dev=0.1, test=0, names={'train': train, 'dev': valid}) 20 | assert os.path.isfile(train), f'Failed to make {train}' 21 | assert os.path.isfile(valid), f'Failed to make {valid}' 22 | logger.info(f'Successfully made {train} {valid}') 23 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/sighan2005/as_.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.tokenization.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_AS_DICT = SIGHAN2005 + "#" + "gold/as_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_AS_TRAIN_ALL = SIGHAN2005 + "#" + "training/as_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_AS_TRAIN = SIGHAN2005 + "#" + "training/as_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_AS_DEV = SIGHAN2005 + "#" + "training/as_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_AS_TEST_INPUT = SIGHAN2005 + "#" + "testing/as_testing.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_AS_TEST = SIGHAN2005 + "#" + "gold/as_testing_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_AS_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/sighan2005/cityu.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.tokenization.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_CITYU_DICT = SIGHAN2005 + "#" + "gold/cityu_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_CITYU_TRAIN_ALL = SIGHAN2005 + "#" + "training/cityu_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_CITYU_TRAIN = SIGHAN2005 + "#" + "training/cityu_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_CITYU_DEV = SIGHAN2005 + "#" + "training/cityu_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_CITYU_TEST_INPUT = SIGHAN2005 + "#" + "testing/cityu_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_CITYU_TEST = SIGHAN2005 + "#" + "gold/cityu_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_CITYU_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/sighan2005/msr.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from elit.datasets.tokenization.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_MSR_DICT = SIGHAN2005 + "#" + "gold/msr_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_MSR_TRAIN_ALL = SIGHAN2005 + "#" + "training/msr_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_MSR_TRAIN = SIGHAN2005 + "#" + "training/msr_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_MSR_DEV = SIGHAN2005 + "#" + "training/msr_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_MSR_TEST_INPUT = SIGHAN2005 + "#" + "testing/msr_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_MSR_TEST = SIGHAN2005 + "#" + "gold/msr_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_MSR_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/datasets/tokenization/sighan2005/pku.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 15:42 4 | from 
elit.datasets.tokenization.sighan2005 import SIGHAN2005, make 5 | 6 | SIGHAN2005_PKU_DICT = SIGHAN2005 + "#" + "gold/pku_training_words.utf8" 7 | '''Dictionary built on trainings set.''' 8 | SIGHAN2005_PKU_TRAIN_ALL = SIGHAN2005 + "#" + "training/pku_training.utf8" 9 | '''Full training set.''' 10 | SIGHAN2005_PKU_TRAIN = SIGHAN2005 + "#" + "training/pku_training_90.txt" 11 | '''Training set (first 90% of the full official training set).''' 12 | SIGHAN2005_PKU_DEV = SIGHAN2005 + "#" + "training/pku_training_10.txt" 13 | '''Dev set (last 10% of full official training set).''' 14 | SIGHAN2005_PKU_TEST_INPUT = SIGHAN2005 + "#" + "testing/pku_test.utf8" 15 | '''Test input.''' 16 | SIGHAN2005_PKU_TEST = SIGHAN2005 + "#" + "gold/pku_test_gold.utf8" 17 | '''Test set.''' 18 | 19 | make(SIGHAN2005_PKU_TRAIN) 20 | -------------------------------------------------------------------------------- /elit/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-10-26 00:50 -------------------------------------------------------------------------------- /elit/layers/crf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-18 22:55 -------------------------------------------------------------------------------- /elit/layers/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-24 21:48 4 | -------------------------------------------------------------------------------- /elit/layers/embeddings/char_rnn_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 17:02 4 | import tensorflow as tf 5 | 6 | from elit.common.vocab_tf import VocabTF 7 | from elit.utils.tf_util import hanlp_register 8 | 9 | 10 | @hanlp_register 11 | class CharRNNEmbeddingTF(tf.keras.layers.Layer): 12 | def __init__(self, word_vocab: VocabTF, char_vocab: VocabTF, 13 | char_embedding=100, 14 | char_rnn_units=25, 15 | dropout=0.5, 16 | trainable=True, name=None, dtype=None, dynamic=False, 17 | **kwargs): 18 | super().__init__(trainable, name, dtype, dynamic, **kwargs) 19 | self.char_embedding = char_embedding 20 | self.char_rnn_units = char_rnn_units 21 | self.char_vocab = char_vocab 22 | self.word_vocab = word_vocab 23 | self.embedding = tf.keras.layers.Embedding(input_dim=len(self.char_vocab), output_dim=char_embedding, 24 | trainable=True, mask_zero=True) 25 | self.dropout = tf.keras.layers.Dropout(dropout) 26 | self.rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=char_rnn_units, 27 | return_state=True), name='bilstm') 28 | 29 | def call(self, inputs: tf.Tensor, **kwargs): 30 | mask = tf.not_equal(inputs, self.word_vocab.pad_token) 31 | inputs = tf.ragged.boolean_mask(inputs, mask) 32 | chars = tf.strings.unicode_split(inputs, input_encoding='UTF-8') 33 | chars = chars.to_tensor(default_value=self.char_vocab.pad_token) 34 | chars = self.char_vocab.lookup(chars) 35 | embed = self.embedding(chars) 36 | char_mask = embed._keras_mask 37 | embed = self.dropout(embed) 38 | embed_shape = tf.shape(embed) 39 | embed = tf.reshape(embed, [-1, embed_shape[2], embed_shape[3]]) 40 | char_mask = tf.reshape(char_mask, [-1, embed_shape[2]]) 41 | all_zeros = tf.reduce_sum(tf.cast(char_mask, tf.int32), axis=1) 
== 0 42 | char_mask_shape = tf.shape(char_mask) 43 | hole = tf.zeros(shape=(char_mask_shape[0], char_mask_shape[1] - 1), dtype=tf.bool) 44 | all_zeros = tf.expand_dims(all_zeros, -1) 45 | non_all_zeros = tf.concat([all_zeros, hole], axis=1) 46 | char_mask = tf.logical_or(char_mask, non_all_zeros) 47 | output, h_fw, c_fw, h_bw, c_bw = self.rnn(embed, mask=char_mask) 48 | hidden = tf.concat([h_fw, h_bw], axis=-1) 49 | # hidden = output 50 | hidden = tf.reshape(hidden, [embed_shape[0], embed_shape[1], -1]) 51 | hidden._keras_mask = mask 52 | return hidden 53 | 54 | def get_config(self): 55 | config = { 56 | 'char_embedding': self.char_embedding, 57 | 'char_rnn_units': self.char_rnn_units, 58 | 'dropout': self.dropout.rate, 59 | } 60 | base_config = super(CharRNNEmbeddingTF, self).get_config() 61 | return dict(list(base_config.items()) + list(config.items())) 62 | -------------------------------------------------------------------------------- /elit/layers/embeddings/concat_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 17:08 4 | import tensorflow as tf 5 | 6 | from elit.utils.tf_util import hanlp_register, copy_mask 7 | 8 | 9 | @hanlp_register 10 | class ConcatEmbedding(tf.keras.layers.Layer): 11 | def __init__(self, *embeddings, trainable=True, name=None, dtype=None, dynamic=False, **kwargs): 12 | self.embeddings = [] 13 | for embed in embeddings: 14 | embed: tf.keras.layers.Layer = tf.keras.utils.deserialize_keras_object(embed) if isinstance(embed, 15 | dict) else embed 16 | self.embeddings.append(embed) 17 | if embed.trainable: 18 | trainable = True 19 | if embed.dynamic: 20 | dynamic = True 21 | if embed.supports_masking: 22 | self.supports_masking = True 23 | 24 | super().__init__(trainable, name, dtype, dynamic, **kwargs) 25 | 26 | def build(self, input_shape): 27 | for embed in self.embeddings: 28 | embed.build(input_shape) 29 | super().build(input_shape) 30 | 31 | def compute_mask(self, inputs, mask=None): 32 | for embed in self.embeddings: 33 | mask = embed.compute_mask(inputs, mask) 34 | if mask is not None: 35 | return mask 36 | return mask 37 | 38 | def call(self, inputs, **kwargs): 39 | embeds = [embed.call(inputs) for embed in self.embeddings] 40 | feature = tf.concat(embeds, axis=-1) 41 | 42 | for embed in embeds: 43 | mask = copy_mask(embed, feature) 44 | if mask is not None: 45 | break 46 | return feature 47 | 48 | def get_config(self): 49 | config = { 50 | 'embeddings': [embed.get_config() for embed in self.embeddings], 51 | } 52 | base_config = super(ConcatEmbedding, self).get_config() 53 | return dict(list(base_config.items()) + list(config.items())) 54 | 55 | def compute_output_shape(self, input_shape): 56 | dim = 0 57 | for embed in self.embeddings: 58 | dim += embed.compute_output_shape(input_shape)[-1] 59 | 60 | return input_shape + dim 61 | -------------------------------------------------------------------------------- /elit/layers/feed_forward.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-06 14:37 4 | from typing import Union, List 5 | 6 | from elit.layers import feedforward 7 | 8 | from elit.common.structure import ConfigTracker 9 | 10 | 11 | class FeedForward(feedforward.FeedForward, ConfigTracker): 12 | def __init__(self, input_dim: int, num_layers: int, hidden_dims: Union[int, List[int]], 13 | activations: Union[str, List[str]], dropout: 
Union[float, List[float]] = 0.0) -> None: 14 | super().__init__(input_dim, num_layers, hidden_dims, activations, dropout) 15 | ConfigTracker.__init__(self, locals()) 16 | -------------------------------------------------------------------------------- /elit/layers/gates/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-02-14 13:31 4 | -------------------------------------------------------------------------------- /elit/layers/time_distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | A wrapper that unrolls the second (time) dimension of a tensor 3 | into the first (batch) dimension, applies some other `Module`, 4 | and then rolls the time dimension back up. 5 | """ 6 | 7 | from typing import List 8 | 9 | 10 | import torch 11 | 12 | 13 | class TimeDistributed(torch.nn.Module): 14 | """ 15 | Given an input shaped like `(batch_size, time_steps, [rest])` and a `Module` that takes 16 | inputs like `(batch_size, [rest])`, `TimeDistributed` reshapes the input to be 17 | `(batch_size * time_steps, [rest])`, applies the contained `Module`, then reshapes it back. 18 | 19 | Note that while the above gives shapes with `batch_size` first, this `Module` also works if 20 | `batch_size` is second - we always just combine the first two dimensions, then split them. 21 | 22 | It also reshapes keyword arguments unless they are not tensors or their name is specified in 23 | the optional `pass_through` iterable. 24 | """ 25 | 26 | def __init__(self, module): 27 | super().__init__() 28 | self._module = module 29 | 30 | 31 | def forward(self, *inputs, pass_through: List[str] = None, **kwargs): 32 | 33 | pass_through = pass_through or [] 34 | 35 | reshaped_inputs = [self._reshape_tensor(input_tensor) for input_tensor in inputs] 36 | 37 | # Need some input to then get the batch_size and time_steps. 38 | some_input = None 39 | if inputs: 40 | some_input = inputs[-1] 41 | 42 | reshaped_kwargs = {} 43 | for key, value in kwargs.items(): 44 | if isinstance(value, torch.Tensor) and key not in pass_through: 45 | if some_input is None: 46 | some_input = value 47 | 48 | value = self._reshape_tensor(value) 49 | 50 | reshaped_kwargs[key] = value 51 | 52 | reshaped_outputs = self._module(*reshaped_inputs, **reshaped_kwargs) 53 | 54 | if some_input is None: 55 | raise RuntimeError("No input tensor to time-distribute") 56 | 57 | # Now get the output back into the right shape. 58 | # (batch_size, time_steps, **output_size) 59 | new_size = some_input.size()[:2] + reshaped_outputs.size()[1:] 60 | outputs = reshaped_outputs.contiguous().view(new_size) 61 | 62 | return outputs 63 | 64 | @staticmethod 65 | def _reshape_tensor(input_tensor): 66 | input_size = input_tensor.size() 67 | if len(input_size) <= 2: 68 | raise RuntimeError(f"No dimension to distribute: {input_size}") 69 | # Squash batch_size and time_steps into a single axis; result has shape 70 | # (batch_size * time_steps, **input_size). 
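# For example, a (batch_size=4, time_steps=7, 100) input is viewed as (28, 100) before being handed to the wrapped module.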
71 | squashed_shape = [-1] + list(input_size[2:]) 72 | return input_tensor.contiguous().view(*squashed_shape) 73 | -------------------------------------------------------------------------------- /elit/layers/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 15:17 4 | # mute transformers 5 | import logging 6 | 7 | logging.getLogger('transformers.file_utils').setLevel(logging.ERROR) 8 | logging.getLogger('transformers.filelock').setLevel(logging.ERROR) 9 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 10 | logging.getLogger('transformers.configuration_utils').setLevel(logging.ERROR) 11 | logging.getLogger('transformers.modeling_tf_utils').setLevel(logging.ERROR) 12 | logging.getLogger('transformers.modeling_utils').setLevel(logging.ERROR) 13 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 14 | -------------------------------------------------------------------------------- /elit/layers/transformers/loader_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-04 06:05 4 | import tensorflow as tf 5 | from transformers import TFAutoModel 6 | 7 | from elit.layers.transformers.pt_imports import AutoTokenizer_, AutoModel_ 8 | 9 | 10 | def build_transformer(transformer, max_seq_length, num_labels, tagging=True, tokenizer_only=False): 11 | tokenizer = AutoTokenizer_.from_pretrained(transformer) 12 | if tokenizer_only: 13 | return tokenizer 14 | l_bert = TFAutoModel.from_pretrained(transformer) 15 | l_input_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="input_ids") 16 | l_mask_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="mask_ids") 17 | l_token_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="token_type_ids") 18 | output = l_bert(input_ids=l_input_ids, token_type_ids=l_token_type_ids, attention_mask=l_mask_ids).last_hidden_state 19 | if not tagging: 20 | output = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(output) 21 | logits = tf.keras.layers.Dense(num_labels)(output) 22 | model = tf.keras.Model(inputs=[l_input_ids, l_mask_ids, l_token_type_ids], outputs=logits) 23 | model.build(input_shape=(None, max_seq_length)) 24 | return model, tokenizer 25 | -------------------------------------------------------------------------------- /elit/layers/transformers/longformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-13 19:44 4 | -------------------------------------------------------------------------------- /elit/layers/transformers/resource.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-20 12:43 4 | from elit.utils.io_util import get_resource 5 | from hanlp_common.constant import HANLP_URL 6 | 7 | tokenizer_mirrors = { 8 | 'hfl/chinese-electra-180g-base-discriminator': HANLP_URL + 'transformers/electra_zh_base_20210706_125233.zip', 9 | 'hfl/chinese-electra-180g-small-discriminator': HANLP_URL + 'transformers/electra_zh_small_20210706_125427.zip', 10 | 'xlm-roberta-base': HANLP_URL + 'transformers/xlm-roberta-base_20210706_125502.zip', 11 | 'cl-tohoku/bert-base-japanese-char': HANLP_URL + 
'transformers/bert-base-japanese-char_20210602_215445.zip', 12 | 'bart5-chinese-small': HANLP_URL + 'transformers/bart5-chinese-small_tok_20210723_180743.zip', 13 | 'ernie-gram': HANLP_URL + 'transformers/ernie-gram_20220207_103518.zip', 14 | 'xlm-roberta-base-no-space': HANLP_URL + 'transformers/xlm-roberta-base-no-space-tokenizer_20220610_204241.zip', 15 | 'mMiniLMv2L6-no-space': HANLP_URL + 'transformers/mMiniLMv2L6-no-space-tokenizer_20220616_094859.zip', 16 | 'mMiniLMv2L12-no-space': HANLP_URL + 'transformers/mMiniLMv2L12-no-space-tokenizer_20220616_095900.zip', 17 | } 18 | 19 | model_mirrors = { 20 | 'bart5-chinese-small': HANLP_URL + 'transformers/bart5-chinese-small_20210723_203923.zip', 21 | 'xlm-roberta-base-no-space': HANLP_URL + 'transformers/xlm-roberta-base-no-space_20220610_203944.zip', 22 | 'mMiniLMv2L6-no-space': HANLP_URL + 'transformers/mMiniLMv2L6-no-space_20220616_094949.zip', 23 | 'mMiniLMv2L12-no-space': HANLP_URL + 'transformers/mMiniLMv2L12-no-space_20220616_095924.zip', 24 | } 25 | 26 | 27 | def get_tokenizer_mirror(transformer: str) -> str: 28 | m = tokenizer_mirrors.get(transformer, None) 29 | if m: 30 | return get_resource(m) 31 | return transformer 32 | 33 | 34 | def get_model_mirror(transformer: str) -> str: 35 | m = model_mirrors.get(transformer, None) 36 | if m: 37 | return get_resource(m) 38 | return transformer 39 | -------------------------------------------------------------------------------- /elit/layers/transformers/tf_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 21:57 4 | from transformers import BertTokenizer, BertConfig, PretrainedConfig, TFAutoModel, \ 5 | AutoConfig, AutoTokenizer, PreTrainedTokenizer, TFPreTrainedModel, TFAlbertModel, TFAutoModelWithLMHead, \ 6 | BertTokenizerFast, TFAlbertForMaskedLM, AlbertConfig, TFBertModel 7 | -------------------------------------------------------------------------------- /elit/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 01:28 -------------------------------------------------------------------------------- /elit/losses/homoscedastic_loss_weighted_sum.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-01-22 21:14 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HomoscedasticLossWeightedSum(nn.Module): 9 | 10 | def __init__(self, num_losses): 11 | """Automatically weighted sum of multi-task losses described in :cite:`Kendall_2018_CVPR`. 12 | 13 | Args: 14 | num_losses: The number of losses. 
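Each loss ``L_i`` contributes ``0.5 / params_i ** 2 * L_i + log(1 + params_i ** 2)`` to the returned sum, where ``params`` are learnable scalars (one per task) initialized to one.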
15 | """ 16 | super(HomoscedasticLossWeightedSum, self).__init__() 17 | params = torch.ones(num_losses, requires_grad=True) 18 | self.params = torch.nn.Parameter(params) 19 | 20 | def forward(self, *losses): 21 | losses = torch.stack(losses) 22 | norm = self.params ** 2 23 | return torch.sum(0.5 / norm * losses + torch.log(1 + norm)) 24 | -------------------------------------------------------------------------------- /elit/losses/sparse_categorical_crossentropy.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-20 01:29 4 | 5 | import tensorflow as tf 6 | 7 | from elit.utils.tf_util import hanlp_register 8 | 9 | 10 | @hanlp_register 11 | class SparseCategoricalCrossentropyOverNonzeroWeights(object): 12 | def __init__(self) -> None: 13 | super().__init__() 14 | self.__name__ = type(self).__name__ 15 | 16 | def __call__(self, y_true, y_pred, sample_weight=None, **kwargs): 17 | loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True) 18 | if sample_weight is not None: 19 | loss = loss * sample_weight 20 | loss = tf.reduce_sum(loss) 21 | if sample_weight is not None: 22 | # This is equivalent to SUM_OVER_BATCH_SIZE 23 | # loss /= tf.reduce_sum(tf.ones_like(sample_weight, dtype=loss.dtype)) 24 | # This one is SUM_BY_NONZERO_WEIGHTS 25 | loss /= tf.reduce_sum(sample_weight) 26 | return loss 27 | 28 | 29 | @hanlp_register 30 | class SparseCategoricalCrossentropyOverBatchFirstDim(object): 31 | 32 | def __init__(self) -> None: 33 | super().__init__() 34 | self.__name__ = type(self).__name__ 35 | 36 | def __call__(self, y_true, y_pred, sample_weight=None, **kwargs): 37 | loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True) 38 | if sample_weight is not None: 39 | loss = loss * sample_weight 40 | # could use sum of sample_weight[:,0] too 41 | loss = tf.reduce_sum(loss) / tf.cast(tf.shape(y_true)[0], tf.float32) 42 | return loss 43 | 44 | def get_config(self): 45 | return {} 46 | 47 | 48 | @hanlp_register 49 | class MaskedSparseCategoricalCrossentropyOverBatchFirstDim(object): 50 | def __init__(self, mask_value=0) -> None: 51 | super().__init__() 52 | self.mask_value = mask_value 53 | self.__name__ = type(self).__name__ 54 | 55 | def __call__(self, y_true, y_pred, sample_weight=None, **kwargs): 56 | assert sample_weight is None, 'the mask will be computed via y_true != mask_value, ' \ 57 | 'it might conflict with sample_weight' 58 | active_loss = tf.not_equal(y_true, self.mask_value) 59 | active_labels = tf.boolean_mask(y_true, active_loss) 60 | active_logits = tf.boolean_mask(y_pred, active_loss) 61 | loss = tf.keras.losses.sparse_categorical_crossentropy(active_labels, active_logits, from_logits=True) 62 | loss = tf.reduce_sum(loss) / tf.cast(tf.shape(y_true)[0], tf.float32) 63 | return loss 64 | -------------------------------------------------------------------------------- /elit/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 -------------------------------------------------------------------------------- /elit/metrics/amr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-24 12:47 -------------------------------------------------------------------------------- 
/elit/metrics/chunking/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 03:49 -------------------------------------------------------------------------------- /elit/metrics/chunking/binary_chunking_f1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-02 14:27 4 | from collections import defaultdict 5 | from typing import List, Union 6 | 7 | import torch 8 | 9 | from elit.metrics.f1 import F1 10 | 11 | 12 | class BinaryChunkingF1(F1): 13 | def __call__(self, pred_tags: torch.LongTensor, gold_tags: torch.LongTensor, lens: List[int] = None): 14 | if lens is None: 15 | lens = [gold_tags.size(1)] * gold_tags.size(0) 16 | self.update(self.decode_spans(pred_tags, lens), self.decode_spans(gold_tags, lens)) 17 | 18 | def update(self, pred_tags, gold_tags): 19 | for pred, gold in zip(pred_tags, gold_tags): 20 | super().__call__(set(pred), set(gold)) 21 | 22 | @staticmethod 23 | def decode_spans(pred_tags: torch.LongTensor, lens: Union[List[int], torch.LongTensor]): 24 | if isinstance(lens, torch.Tensor): 25 | lens = lens.tolist() 26 | batch_pred = defaultdict(list) 27 | for batch, offset in pred_tags.nonzero(as_tuple=False).tolist(): 28 | batch_pred[batch].append(offset) 29 | batch_pred_spans = [[(0, l)] for l in lens] 30 | for batch, offsets in batch_pred.items(): 31 | l = lens[batch] 32 | batch_pred_spans[batch] = list(zip(offsets, offsets[1:] + [l])) 33 | return batch_pred_spans 34 | -------------------------------------------------------------------------------- /elit/metrics/chunking/bmes_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 4 | 5 | from elit.common.vocab_tf import VocabTF 6 | from elit.metrics.chunking.chunking_f1_tf import ChunkingF1_TF 7 | from elit.metrics.chunking.sequence_labeling import get_entities 8 | 9 | 10 | class BMES_F1_TF(ChunkingF1_TF): 11 | 12 | def __init__(self, tag_vocab: VocabTF, from_logits=True, suffix=False, name='f1', dtype=None, **kwargs): 13 | super().__init__(tag_vocab, from_logits, name, dtype, **kwargs) 14 | self.nb_correct = 0 15 | self.nb_pred = 0 16 | self.nb_true = 0 17 | self.suffix = suffix 18 | 19 | def update_tags(self, true_tags, pred_tags): 20 | for t, p in zip(true_tags, pred_tags): 21 | self.update_entities(get_entities(t, self.suffix), get_entities(p, self.suffix)) 22 | return self.result() 23 | 24 | def update_entities(self, true_entities, pred_entities): 25 | true_entities = set(true_entities) 26 | pred_entities = set(pred_entities) 27 | nb_correct = len(true_entities & pred_entities) 28 | nb_pred = len(pred_entities) 29 | nb_true = len(true_entities) 30 | self.nb_correct += nb_correct 31 | self.nb_pred += nb_pred 32 | self.nb_true += nb_true 33 | 34 | def result(self): 35 | nb_correct = self.nb_correct 36 | nb_pred = self.nb_pred 37 | nb_true = self.nb_true 38 | p = nb_correct / nb_pred if nb_pred > 0 else 0 39 | r = nb_correct / nb_true if nb_true > 0 else 0 40 | score = 2 * p * r / (p + r) if p + r > 0 else 0 41 | 42 | return score 43 | 44 | def reset_states(self): 45 | self.nb_correct = 0 46 | self.nb_pred = 0 47 | self.nb_true = 0 48 | -------------------------------------------------------------------------------- /elit/metrics/chunking/chunking_f1_tf.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 23:09 4 | from abc import ABC, abstractmethod 5 | 6 | import tensorflow as tf 7 | 8 | from elit.common.vocab_tf import VocabTF 9 | 10 | 11 | class ChunkingF1_TF(tf.keras.metrics.Metric, ABC): 12 | 13 | def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs): 14 | super().__init__(name, dtype, dynamic=True, **kwargs) 15 | self.tag_vocab = tag_vocab 16 | self.from_logits = from_logits 17 | 18 | def update_the_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 19 | if sample_weight is None: 20 | if hasattr(y_pred, '_keras_mask'): 21 | mask = y_pred._keras_mask 22 | else: 23 | mask = None 24 | else: 25 | mask = sample_weight 26 | if self.tag_vocab.pad_idx is not None and mask is None: 27 | # in this case, the model doesn't compute mask but provide a masking index, it's ok to 28 | mask = y_true != self.tag_vocab.pad_idx 29 | assert mask is not None, 'ChunkingF1 requires masking, check your _keras_mask or compute_mask' 30 | if self.from_logits: 31 | y_pred = tf.argmax(y_pred, axis=-1) 32 | y_true = self.to_tags(y_true, mask) 33 | y_pred = self.to_tags(y_pred, mask) 34 | return self.update_tags(y_true, y_pred) 35 | 36 | def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 37 | return self.update_the_state(y_true, y_pred, sample_weight) 38 | 39 | def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None, **kwargs): 40 | return self.update_the_state(y_true, y_pred, sample_weight) 41 | 42 | def to_tags(self, y: tf.Tensor, sample_weight: tf.Tensor): 43 | batch = [] 44 | y = y.numpy() 45 | sample_weight = sample_weight.numpy() 46 | for sent, mask in zip(y, sample_weight): 47 | tags = [] 48 | for tag, m in zip(sent, mask): 49 | if not m: 50 | continue 51 | tag = int(tag) 52 | if self.tag_vocab.pad_idx is not None and tag == self.tag_vocab.pad_idx: 53 | # If model predicts , it will fail most metrics. 
So replace it with a valid one 54 | tag = 1 55 | tags.append(self.tag_vocab.get_token(tag)) 56 | batch.append(tags) 57 | return batch 58 | 59 | @abstractmethod 60 | def update_tags(self, true_tags, pred_tags): 61 | pass 62 | 63 | @abstractmethod 64 | def result(self): 65 | pass 66 | -------------------------------------------------------------------------------- /elit/metrics/chunking/iobes_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-09-14 21:55 4 | 5 | from elit.common.vocab_tf import VocabTF 6 | from elit.metrics.chunking.conlleval import SpanF1 7 | from elit.metrics.chunking.chunking_f1_tf import ChunkingF1_TF 8 | 9 | 10 | class IOBES_F1_TF(ChunkingF1_TF): 11 | 12 | def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs): 13 | super().__init__(tag_vocab, from_logits, name, dtype, **kwargs) 14 | self.state = SpanF1() 15 | 16 | def update_tags(self, true_tags, pred_tags): 17 | # true_tags = list(itertools.chain.from_iterable(true_tags)) 18 | # pred_tags = list(itertools.chain.from_iterable(pred_tags)) 19 | # self.state.update_state(true_tags, pred_tags) 20 | for gold, pred in zip(true_tags, pred_tags): 21 | self.state.update_state(gold, pred) 22 | return self.result() 23 | 24 | def result(self): 25 | return self.state.result(full=False, verbose=False).fscore 26 | 27 | def reset_states(self): 28 | self.state.reset_state() 29 | -------------------------------------------------------------------------------- /elit/metrics/f1.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-10 14:55 4 | from abc import ABC 5 | 6 | from elit.metrics.metric import Metric 7 | 8 | 9 | class F1(Metric, ABC): 10 | def __init__(self, nb_pred=0, nb_true=0, nb_correct=0) -> None: 11 | super().__init__() 12 | self.nb_correct = nb_correct 13 | self.nb_pred = nb_pred 14 | self.nb_true = nb_true 15 | 16 | def __repr__(self) -> str: 17 | p, r, f = self.prf 18 | return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}" 19 | 20 | @property 21 | def prf(self): 22 | nb_correct = self.nb_correct 23 | nb_pred = self.nb_pred 24 | nb_true = self.nb_true 25 | p = nb_correct / nb_pred if nb_pred > 0 else .0 26 | r = nb_correct / nb_true if nb_true > 0 else .0 27 | f = 2 * p * r / (p + r) if p + r > 0 else .0 28 | return p, r, f 29 | 30 | @property 31 | def score(self): 32 | return self.prf[-1] 33 | 34 | def reset(self): 35 | self.nb_correct = 0 36 | self.nb_pred = 0 37 | self.nb_true = 0 38 | 39 | def __call__(self, pred: set, gold: set): 40 | self.nb_correct += len(pred & gold) 41 | self.nb_pred += len(pred) 42 | self.nb_true += len(gold) 43 | 44 | 45 | class F1_(Metric): 46 | def __init__(self, p, r, f) -> None: 47 | super().__init__() 48 | self.f = f 49 | self.r = r 50 | self.p = p 51 | 52 | @property 53 | def score(self): 54 | return self.f 55 | 56 | def __call__(self, pred, gold): 57 | raise NotImplementedError() 58 | 59 | def reset(self): 60 | self.f = self.r = self.p = 0 61 | 62 | def __repr__(self) -> str: 63 | p, r, f = self.p, self.r, self.f 64 | return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}" 65 | -------------------------------------------------------------------------------- /elit/metrics/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-06-03 11:35 4 | from abc import ABC, abstractmethod 5 
| 6 | 7 | class Metric(ABC): 8 | 9 | def __lt__(self, other): 10 | return self.score < other 11 | 12 | def __le__(self, other): 13 | return self.score <= other 14 | 15 | def __eq__(self, other): 16 | return self.score == other 17 | 18 | def __ge__(self, other): 19 | return self.score >= other 20 | 21 | def __gt__(self, other): 22 | return self.score > other 23 | 24 | def __ne__(self, other): 25 | return self.score != other 26 | 27 | @property 28 | @abstractmethod 29 | def score(self): 30 | pass 31 | 32 | @abstractmethod 33 | def __call__(self, pred, gold, mask=None): 34 | pass 35 | 36 | def __repr__(self) -> str: 37 | return f'{self.score:.4f}' 38 | 39 | def __float__(self): 40 | return self.score 41 | 42 | @abstractmethod 43 | def reset(self): 44 | pass 45 | 46 | 47 | class ScalarMetric(Metric): 48 | 49 | def __init__(self, score) -> None: 50 | super().__init__() 51 | self._score = score 52 | 53 | @property 54 | def score(self): 55 | return self._score 56 | 57 | def __call__(self, pred, gold, mask=None): 58 | pass 59 | 60 | def reset(self): 61 | self._score = 0 62 | 63 | def __repr__(self) -> str: 64 | return f'{self.score:.2%}' 65 | -------------------------------------------------------------------------------- /elit/metrics/mtl.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-08-03 00:16 4 | from collections.abc import MutableMapping 5 | 6 | from elit.metrics.metric import Metric 7 | 8 | 9 | class MetricDict(Metric, MutableMapping): 10 | _COLORS = ["magenta", "cyan", "green", "yellow"] 11 | 12 | def __init__(self, *args, primary_key=None, **kwargs) -> None: 13 | self.store = dict(*args, **kwargs) 14 | self.primary_key = primary_key 15 | 16 | @property 17 | def score(self): 18 | return float(self[self.primary_key]) if self.primary_key else sum(float(x) for x in self.values()) / len(self) 19 | 20 | def __call__(self, pred, gold): 21 | for metric in self.values(): 22 | metric(pred, gold) 23 | 24 | def reset(self): 25 | for metric in self.values(): 26 | metric.reset() 27 | 28 | def __repr__(self) -> str: 29 | return ' '.join(f'({k} {v})' for k, v in self.items()) 30 | 31 | def cstr(self, idx=None, level=0) -> str: 32 | if idx is None: 33 | idx = [0] 34 | prefix = '' 35 | for _, (k, v) in enumerate(self.items()): 36 | color = self._COLORS[idx[0] % len(self._COLORS)] 37 | idx[0] += 1 38 | child_is_dict = isinstance(v, MetricDict) 39 | _level = min(level, 2) 40 | # if level != 0 and not child_is_dict: 41 | # _level = 2 42 | lb = '{[(' 43 | rb = '}])' 44 | k = f'[bold][underline]{k}[/underline][/bold]' 45 | prefix += f'[{color}]{lb[_level]}{k} [/{color}]' 46 | if child_is_dict: 47 | prefix += v.cstr(idx, level + 1) 48 | else: 49 | prefix += f'[{color}]{v}[/{color}]' 50 | prefix += f'[{color}]{rb[_level]}[/{color}]' 51 | return prefix 52 | 53 | def __getitem__(self, key): 54 | return self.store[key] 55 | 56 | def __setitem__(self, key, value): 57 | self.store[key] = value 58 | 59 | def __delitem__(self, key): 60 | del self.store[key] 61 | 62 | def __iter__(self): 63 | return iter(self.store) 64 | 65 | def __len__(self): 66 | return len(self.store) 67 | -------------------------------------------------------------------------------- /elit/metrics/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-27 00:48 
-------------------------------------------------------------------------------- /elit/metrics/parsing/attachmentscore.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Yu Zhang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | from typing import Tuple, List 23 | from elit.metrics.metric import Metric 24 | 25 | 26 | class AttachmentScore(Metric): 27 | 28 | def __init__(self, eps=1e-12): 29 | super(AttachmentScore, self).__init__() 30 | 31 | self.eps = eps 32 | self.total = 0.0 33 | self.correct_arcs = 0.0 34 | self.correct_rels = 0.0 35 | 36 | def __repr__(self): 37 | return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}" 38 | 39 | # noinspection PyMethodOverriding 40 | def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask): 41 | arc_mask = arc_preds.eq(arc_golds)[mask] 42 | rel_mask = rel_preds.eq(rel_golds)[mask] & arc_mask 43 | 44 | self.total += len(arc_mask) 45 | self.correct_arcs += arc_mask.sum().item() 46 | self.correct_rels += rel_mask.sum().item() 47 | 48 | def __lt__(self, other): 49 | return self.score < other 50 | 51 | def __le__(self, other): 52 | return self.score <= other 53 | 54 | def __ge__(self, other): 55 | return self.score >= other 56 | 57 | def __gt__(self, other): 58 | return self.score > other 59 | 60 | @property 61 | def score(self): 62 | return self.las 63 | 64 | @property 65 | def uas(self): 66 | return self.correct_arcs / (self.total + self.eps) 67 | 68 | @property 69 | def las(self): 70 | return self.correct_rels / (self.total + self.eps) 71 | 72 | def reset(self): 73 | self.total = 0.0 74 | self.correct_arcs = 0.0 75 | self.correct_rels = 0.0 76 | 77 | def update_lists(self, preds: List[Tuple[int, str]], golds: List[Tuple[int, str]]): 78 | self.total += len(golds) 79 | self.correct_arcs += sum([p[0] == g[0] for p, g in zip(preds, golds)]) 80 | self.correct_rels += sum([p == g for p, g in zip(preds, golds)]) 81 | -------------------------------------------------------------------------------- /elit/metrics/parsing/conllx_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-03-08 22:35 4 | import tempfile 5 | 6 | from elit.utils.io_util import get_resource, get_exitcode_stdout_stderr 7 | 8 | CONLLX_EVAL = get_resource( 9 | 'https://github.com/elikip/bist-parser/archive/master.zip' + 
'#bmstparser/src/utils/eval.pl') 10 | 11 | 12 | def evaluate(gold_file, pred_file): 13 | """Evaluate using official CoNLL-X evaluation script (Yuval Krymolowski) 14 | 15 | Args: 16 | gold_file(str): The gold conllx file 17 | pred_file(str): The pred conllx file 18 | 19 | Returns: 20 | 21 | 22 | """ 23 | gold_file = get_resource(gold_file) 24 | fixed_pred_file = tempfile.NamedTemporaryFile().name 25 | copy_cols(gold_file, pred_file, fixed_pred_file, keep_comments=False) 26 | if gold_file.endswith('.conllu'): 27 | fixed_gold_file = tempfile.NamedTemporaryFile().name 28 | copy_cols(gold_file, gold_file, fixed_gold_file, keep_comments=False) 29 | gold_file = fixed_gold_file 30 | 31 | exitcode, out, err = get_exitcode_stdout_stderr(f'perl {CONLLX_EVAL} -q -b -g {gold_file} -s {fixed_pred_file}') 32 | if exitcode: 33 | raise RuntimeError(f'eval.pl exited with error code {exitcode} and error message {err} and output {out}.') 34 | lines = out.split('\n')[-4:] 35 | las = int(lines[0].split()[3]) / int(lines[0].split()[5]) 36 | uas = int(lines[1].split()[3]) / int(lines[1].split()[5]) 37 | return uas, las 38 | 39 | 40 | def copy_cols(gold_file, pred_file, copied_pred_file, keep_comments=True): 41 | """Copy the first 6 columns from gold file to pred file 42 | 43 | Args: 44 | gold_file: 45 | pred_file: 46 | copied_pred_file: 47 | keep_comments: (Default value = True) 48 | 49 | Returns: 50 | 51 | 52 | """ 53 | with open(copied_pred_file, 'w') as to_out, open(pred_file) as pred_file, open(gold_file) as gold_file: 54 | for idx, (p, g) in enumerate(zip(pred_file, gold_file)): 55 | while p.startswith('#'): 56 | p = next(pred_file) 57 | if not g.strip(): 58 | if p.strip(): 59 | raise ValueError( 60 | f'Prediction file {pred_file.name} does not end a sentence at line {idx + 1}\n{p.strip()}') 61 | to_out.write('\n') 62 | continue 63 | while g.startswith('#') or '-' in g.split('\t')[0]: 64 | if keep_comments or g.startswith('-'): 65 | to_out.write(g) 66 | g = next(gold_file) 67 | to_out.write('\t'.join(str(x) for x in g.split('\t')[:6] + p.split('\t')[6:])) 68 | -------------------------------------------------------------------------------- /elit/metrics/parsing/labeled_score.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-27 00:49 4 | 5 | import tensorflow as tf 6 | 7 | 8 | class LabeledScore(object): 9 | 10 | def __init__(self, eps=1e-5): 11 | super(LabeledScore, self).__init__() 12 | 13 | self.eps = eps 14 | self.total = 0.0 15 | self.correct_arcs = 0.0 16 | self.correct_rels = 0.0 17 | 18 | def __repr__(self): 19 | return f"UAS: {self.uas:6.2%} LAS: {self.las:6.2%}" 20 | 21 | def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask): 22 | arc_mask = (arc_preds == arc_golds)[mask] 23 | rel_mask = (rel_preds == rel_golds)[mask] & arc_mask 24 | 25 | self.total += len(arc_mask) 26 | self.correct_arcs += int(tf.math.count_nonzero(arc_mask)) 27 | self.correct_rels += int(tf.math.count_nonzero(rel_mask)) 28 | 29 | def __lt__(self, other): 30 | return self.score < other 31 | 32 | def __le__(self, other): 33 | return self.score <= other 34 | 35 | def __ge__(self, other): 36 | return self.score >= other 37 | 38 | def __gt__(self, other): 39 | return self.score > other 40 | 41 | @property 42 | def score(self): 43 | return self.las 44 | 45 | @property 46 | def uas(self): 47 | return self.correct_arcs / (self.total + self.eps) 48 | 49 | @property 50 | def las(self): 51 | return self.correct_rels / 
(self.total + self.eps) 52 | 53 | def reset_states(self): 54 | self.total = 0.0 55 | self.correct_arcs = 0.0 56 | self.correct_rels = 0.0 57 | 58 | def to_dict(self) -> dict: 59 | return {'UAS': self.uas, 'LAS': self.las} 60 | -------------------------------------------------------------------------------- /elit/metrics/srl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-16 18:44 -------------------------------------------------------------------------------- /elit/metrics/srl/e2e_srl.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-01-23 12:43 4 | from elit.metrics.mtl import MetricDict 5 | 6 | 7 | class SemanticRoleLabelingMetrics(MetricDict): 8 | @property 9 | def score(self): 10 | """Obtain the end-to-end score, which is the major metric for SRL. 11 | 12 | Returns: 13 | The end-to-end score. 14 | """ 15 | return self['e2e'].score 16 | -------------------------------------------------------------------------------- /elit/metrics/srl/srlconll.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-07-16 18:44 4 | import os 5 | 6 | from elit.utils.io_util import get_resource, get_exitcode_stdout_stderr, run_cmd 7 | 8 | 9 | def official_conll_05_evaluate(pred_path, gold_path): 10 | script_root = get_resource('http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz') 11 | lib_path = f'{script_root}/lib' 12 | if lib_path not in os.environ.get("PERL5LIB", ""): 13 | os.environ['PERL5LIB'] = f'{lib_path}:{os.environ.get("PERL5LIB", "")}' 14 | bin_path = f'{script_root}/bin' 15 | if bin_path not in os.environ.get('PATH', ''): 16 | os.environ['PATH'] = f'{bin_path}:{os.environ.get("PATH", "")}' 17 | eval_info_gold_pred = run_cmd(f'perl {script_root}/bin/srl-eval.pl {gold_path} {pred_path}') 18 | eval_info_pred_gold = run_cmd(f'perl {script_root}/bin/srl-eval.pl {pred_path} {gold_path}') 19 | conll_recall = float(eval_info_gold_pred.strip().split("\n")[6].strip().split()[5]) / 100 20 | conll_precision = float(eval_info_pred_gold.strip().split("\n")[6].strip().split()[5]) / 100 21 | if conll_recall + conll_precision > 0: 22 | conll_f1 = 2 * conll_recall * conll_precision / (conll_recall + conll_precision) 23 | else: 24 | conll_f1 = 0 25 | return conll_precision, conll_recall, conll_f1 26 | 27 | 28 | def run_perl(script, src, dst=None): 29 | os.environ['PERL5LIB'] = f'' 30 | exitcode, out, err = get_exitcode_stdout_stderr( 31 | f'perl -I{os.path.expanduser("~/.local/lib/perl5")} {script} {src}') 32 | if exitcode: 33 | # cpanm -l ~/.local namespace::autoclean 34 | # cpanm -l ~/.local Moose 35 | # cpanm -l ~/.local MooseX::SemiAffordanceAccessor module 36 | raise RuntimeError(err) 37 | with open(dst, 'w') as ofile: 38 | ofile.write(out) 39 | return dst 40 | -------------------------------------------------------------------------------- /elit/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-11 18:44 -------------------------------------------------------------------------------- /elit/optimizers/adamw/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-11-11 18:44 4 | import 
tensorflow as tf 5 | from elit.optimizers.adamw.optimization import WarmUp, AdamWeightDecay 6 | 7 | 8 | # from elit.optimization.adamw.optimizers_v2 import AdamW 9 | # from elit.optimization.adamw.utils import get_weight_decays 10 | 11 | 12 | # def create_optimizer(model, init_lr, num_train_steps, num_warmup_steps): 13 | # """Creates an optimizer with learning rate schedule.""" 14 | # wd_dict = get_weight_decays(model) 15 | # 16 | # # Implements linear decay of the learning rate. 17 | # learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( 18 | # initial_learning_rate=init_lr, 19 | # decay_steps=num_train_steps, 20 | # end_learning_rate=0.0) 21 | # if num_warmup_steps: 22 | # learning_rate_fn = WarmUp(initial_learning_rate=init_lr, 23 | # decay_schedule_fn=learning_rate_fn, 24 | # warmup_steps=num_warmup_steps) 25 | # optimizer = AdamW( 26 | # learning_rate=learning_rate_fn, 27 | # weight_decay_rate=0.01, 28 | # beta_1=0.9, 29 | # beta_2=0.999, 30 | # epsilon=1e-6, 31 | # exclude_from_weight_decay=['layer_norm', 'bias']) 32 | # return optimizer 33 | 34 | 35 | def create_optimizer(init_lr, num_train_steps, num_warmup_steps, weight_decay_rate=0.01, epsilon=1e-6, clipnorm=None): 36 | """Creates an optimizer with learning rate schedule. 37 | 38 | Args: 39 | init_lr: 40 | num_train_steps: 41 | num_warmup_steps: 42 | weight_decay_rate: (Default value = 0.01) 43 | epsilon: (Default value = 1e-6) 44 | clipnorm: (Default value = None) 45 | 46 | Returns: 47 | 48 | """ 49 | # Implements linear decay of the learning rate. 50 | learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( 51 | initial_learning_rate=init_lr, 52 | decay_steps=num_train_steps, 53 | end_learning_rate=0.0) 54 | if num_warmup_steps: 55 | learning_rate_fn = WarmUp(initial_learning_rate=init_lr, 56 | decay_schedule_fn=learning_rate_fn, 57 | warmup_steps=num_warmup_steps) 58 | additional_args = {} 59 | if clipnorm: 60 | additional_args['clipnorm'] = clipnorm 61 | optimizer = AdamWeightDecay( 62 | learning_rate=learning_rate_fn, 63 | weight_decay_rate=weight_decay_rate, 64 | beta_1=0.9, 65 | beta_2=0.999, 66 | epsilon=epsilon, 67 | exclude_from_weight_decay=['LayerNorm', 'bias'], 68 | **additional_args 69 | ) 70 | # {'LayerNorm/gamma:0', 'LayerNorm/beta:0'} 71 | return optimizer 72 | -------------------------------------------------------------------------------- /elit/pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 19:10 4 | from elit.pretrained import tok 5 | from elit.pretrained import dep 6 | from elit.pretrained import sdp 7 | from elit.pretrained import glove 8 | from elit.pretrained import pos 9 | from elit.pretrained import rnnlm 10 | from elit.pretrained import word2vec 11 | from elit.pretrained import ner 12 | from elit.pretrained import classifiers 13 | from elit.pretrained import fasttext 14 | from elit.pretrained import mtl 15 | from elit.pretrained import eos 16 | from elit.pretrained import sts 17 | from elit.pretrained import constituency 18 | from elit.pretrained import amr 19 | from elit.pretrained import amr2text 20 | from elit.pretrained import srl 21 | 22 | # Will be filled up during runtime 23 | ALL = {} 24 | -------------------------------------------------------------------------------- /elit/pretrained/amr2text.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 
2022-12-07 15:19 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | AMR3_GRAPH_PRETRAIN_GENERATION = HANLP_URL + 'amr2text/amr3_graph_pretrain_generation_20221207_153535.zip' 7 | '''A seq2seq (:cite:`bevilacqua-etal-2021-one`) BART (:cite:`lewis-etal-2020-bart`) large AMR2Text generator trained on 8 | Abstract Meaning Representation 3.0 (:cite:`knight2014abstract`) with graph pre-training (:cite:`bai-etal-2022-graph`). 9 | Its Sacre-BLEU is ``50.38`` according to their official repository. 10 | ''' 11 | 12 | # Will be filled up during runtime 13 | ALL = {} 14 | -------------------------------------------------------------------------------- /elit/pretrained/classifiers.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-01 03:51 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CHNSENTICORP_BERT_BASE_ZH = HANLP_URL + 'classification/chnsenticorp_bert_base_20211228_163210.zip' 7 | SST2_ALBERT_BASE_EN = HANLP_URL + 'classification/sst2_albert_base_20211228_164917.zip' 8 | 9 | LID_176_FASTTEXT_BASE = 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin' 10 | ''' 11 | 126MB FastText model for language identification trained on data from Wikipedia, Tatoeba and SETimes. 12 | ''' 13 | LID_176_FASTTEXT_SMALL = 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz' 14 | ''' 15 | 917kB FastText model for language identification trained on data from Wikipedia, Tatoeba and SETimes. 16 | ''' 17 | 18 | ALL = {} 19 | -------------------------------------------------------------------------------- /elit/pretrained/constituency.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author=hankcs 3 | # Date=2022-01-18 10:34 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CTB9_CON_ELECTRA_SMALL = HANLP_URL + 'constituency/ctb9_con_electra_small_20220215_230116.zip' 7 | 'Electra (:cite:`clark2020electra`) small tree CRF model (:cite:`ijcai2020-560`) trained on CTB9 with major categories. ' \ 8 | 'Its performance is UCM=39.06% LCM=34.99% UP=90.05% UR=90.01% UF=90.03% LP=87.02% LR=86.98% LF=87.00%.' 9 | 10 | CTB9_CON_FULL_TAG_ELECTRA_SMALL = HANLP_URL + 'constituency/ctb9_full_tag_con_electra_small_20220118_103119.zip' 11 | 'Electra (:cite:`clark2020electra`) small tree CRF model (:cite:`ijcai2020-560`) trained on CTB9 with full subcategories. ' \ 12 | 'Its performance is UCM=38.29% LCM=28.95% UP=90.16% UR=90.13% UF=90.15% LP=83.46% LR=83.43% LF=83.45%.' 13 | 14 | CTB9_CON_FULL_TAG_ERNIE_GRAM = 'http://download.hanlp.com/constituency/extra/ctb9_full_tag_con_ernie_20220331_121430.zip' 15 | 'ERNIE-GRAM (:cite:`xiao-etal-2021-ernie`) base tree CRF model (:cite:`ijcai2020-560`) trained on CTB9 with full subcategories. ' \ 16 | 'Its performance is UCM=42.04% LCM=31.72% UP=91.33% UR=91.53% UF=91.43% LP=85.31% LR=85.49% LF=85.40%.' 17 | 18 | # Will be filled up during runtime 19 | ALL = {} 20 | -------------------------------------------------------------------------------- /elit/pretrained/dep.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 02:55 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CTB5_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb5_20191229_025833.zip' 7 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB5.' 
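# Hypothetical usage sketch, assuming this fork keeps the HanLP-style top-level loader (not verified here):
#   import elit
#   parser = elit.load(CTB9_DEP_ELECTRA_SMALL)
#   parser([['他', '喜欢', '音乐']])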
8 | CTB7_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb7_20200109_022431.zip' 9 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB7.' 10 | CTB9_DEP_ELECTRA_SMALL = HANLP_URL + 'dep/ctb9_dep_electra_small_20220216_100306.zip' 11 | 'Electra small encoder (:cite:`clark2020electra`) with Biaffine decoder (:cite:`dozat:17a`) trained on CTB9-SD330. ' \ 12 | 'Performance is UAS=87.68% LAS=83.54%.' 13 | PMT1_DEP_ELECTRA_SMALL = HANLP_URL + 'dep/pmt_dep_electra_small_20220218_134518.zip' 14 | 'Electra small encoder (:cite:`clark2020electra`) with Biaffine decoder (:cite:`dozat:17a`) trained on PKU ' \ 15 | 'Multi-view Chinese Treebank (PMT) 1.0 (:cite:`qiu-etal-2014-multi`). Performance is UAS=91.21% LAS=88.65%.' 16 | CTB9_UDC_ELECTRA_SMALL = HANLP_URL + 'dep/udc_dep_electra_small_20220218_095452.zip' 17 | 'Electra small encoder (:cite:`clark2020electra`) with Biaffine decoder (:cite:`dozat:17a`) trained on CTB9-UD420. ' \ 18 | 'Performance is UAS=85.92% LAS=81.13% .' 19 | 20 | PTB_BIAFFINE_DEP_EN = HANLP_URL + 'dep/ptb_dep_biaffine_20200101_174624.zip' 21 | 'Biaffine LSTM model (:cite:`dozat:17a`) trained on PTB.' 22 | 23 | ALL = {} 24 | -------------------------------------------------------------------------------- /elit/pretrained/eos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-22 13:22 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | UD_CTB_EOS_MUL = HANLP_URL + 'eos/eos_ud_ctb_mul_20201222_133543.zip' 7 | 'EOS model (:cite:`Schweter:Ahmed:2019`) trained on concatenated UD2.3 and CTB9.' 8 | 9 | # Will be filled up during runtime 10 | ALL = {} 11 | -------------------------------------------------------------------------------- /elit/pretrained/fasttext.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 18:57 4 | FASTTEXT_DEBUG_EMBEDDING_EN = 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext.debug.bin.zip' 5 | FASTTEXT_CC_300_EN = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz' 6 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Common Crawl.' 7 | FASTTEXT_WIKI_NYT_AMAZON_FRIENDS_200_EN \ 8 | = 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext-200-wikipedia-nytimes-amazon-friends-20191107.bin' 9 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on wikipedia, nytimes and friends.' 10 | 11 | FASTTEXT_WIKI_300_ZH = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh.zip#wiki.zh.bin' 12 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Chinese Wikipedia.' 13 | FASTTEXT_WIKI_300_ZH_CLASSICAL = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh_classical.zip#wiki.zh_classical.bin' 14 | 'FastText (:cite:`bojanowski2017enriching`) embeddings trained on traditional Chinese wikipedia.' 
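# These identifiers point at standard fastText binary models; once downloaded, such a file can typically be
# loaded with fasttext.load_model(path) from the fasttext package (illustrative hint only; nothing is loaded here).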
15 | 16 | ALL = {} 17 | -------------------------------------------------------------------------------- /elit/pretrained/glove.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emorynlp/seq2seq-corenlp/7155b117630b79ba1a640e76dfe5ba93e1166fff/elit/pretrained/glove.py -------------------------------------------------------------------------------- /elit/pretrained/ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-30 20:07 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | MSRA_NER_BERT_BASE_ZH = HANLP_URL + 'ner/ner_bert_base_msra_20211227_114712.zip' 7 | 'BERT model (:cite:`devlin-etal-2019-bert`) trained on MSRA with 3 entity types.' 8 | MSRA_NER_ALBERT_BASE_ZH = HANLP_URL + 'ner/msra_ner_albert_base_20211228_173323.zip' 9 | 'ALBERT model (:cite:`Lan2020ALBERT:`) trained on MSRA with 3 entity types.' 10 | MSRA_NER_ELECTRA_SMALL_ZH = HANLP_URL + 'ner/msra_ner_electra_small_20220215_205503.zip' 11 | 'Electra small model (:cite:`clark2020electra`) trained on MSRA with 26 entity types. F1 = `95.16`' 12 | CONLL03_NER_BERT_BASE_CASED_EN = HANLP_URL + 'ner/ner_conll03_bert_base_cased_en_20211227_121443.zip' 13 | 'BERT model (:cite:`devlin-etal-2019-bert`) trained on CoNLL03.' 14 | 15 | ALL = {} 16 | -------------------------------------------------------------------------------- /elit/pretrained/pos.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 01:57 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CTB5_POS_RNN = HANLP_URL + 'pos/ctb5_pos_rnn_20200113_235925.zip' 7 | 'An old school BiLSTM tagging model trained on CTB5.' 8 | CTB5_POS_RNN_FASTTEXT_ZH = HANLP_URL + 'pos/ctb5_pos_rnn_fasttext_20191230_202639.zip' 9 | 'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on CTB5.' 10 | CTB9_POS_ALBERT_BASE = HANLP_URL + 'pos/ctb9_albert_base_20211228_163935.zip' 11 | 'ALBERT model (:cite:`Lan2020ALBERT:`) trained on CTB9 (:cite:`https://doi.org/10.35111/gvd0-xk91`). This is a TF component.' 12 | CTB9_POS_ELECTRA_SMALL_TF = HANLP_URL + 'pos/pos_ctb_electra_small_20211227_121341.zip' 13 | 'Electra small model (:cite:`clark2020electra`) trained on CTB9 (:cite:`https://doi.org/10.35111/gvd0-xk91`). Accuracy = `96.75`. This is a TF component.' 14 | CTB9_POS_ELECTRA_SMALL = HANLP_URL + 'pos/pos_ctb_electra_small_20220215_111944.zip' 15 | 'Electra small model (:cite:`clark2020electra`) trained on CTB9 (:cite:`https://doi.org/10.35111/gvd0-xk91`). Accuracy = `96.26`.' 16 | CTB9_POS_RADICAL_ELECTRA_SMALL = HANLP_URL + 'pos/pos_ctb_radical_electra_small_20220215_111932.zip' 17 | 'Electra small model (:cite:`clark2020electra`) with radical embeddings (:cite:`he2018dual`) trained on CTB9 (:cite:`https://doi.org/10.35111/gvd0-xk91`). Accuracy = `96.14`.' 18 | C863_POS_ELECTRA_SMALL = HANLP_URL + 'pos/pos_863_electra_small_20220217_101958.zip' 19 | 'Electra small model (:cite:`clark2020electra`) trained on Chinese 863 corpus. Accuracy = `95.19`.' 20 | PKU_POS_ELECTRA_SMALL = HANLP_URL + 'pos/pos_pku_electra_small_20220217_142436.zip' 21 | 'Electra small model (:cite:`clark2020electra`) trained on Chinese PKU corpus. Accuracy = `97.55`.' 
22 | PTB_POS_RNN_FASTTEXT_EN = HANLP_URL + 'pos/ptb_pos_rnn_fasttext_20220418_101708.zip' 23 | 'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on PTB.' 24 | 25 | ALL = {} 26 | -------------------------------------------------------------------------------- /elit/pretrained/rnnlm.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-19 03:47 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | FLAIR_LM_FW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_fw_wmt11_en' 7 | 'The forward LSTM of Contextual String Embedding (:cite:`akbik-etal-2018-contextual`).' 8 | FLAIR_LM_BW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_bw_wmt11_en' 9 | 'The backward LSTM of Contextual String Embedding (:cite:`akbik-etal-2018-contextual`).' 10 | FLAIR_LM_WMT11_EN = HANLP_URL + 'lm/flair_lm_wmt11_en_20200601_205350.zip' 11 | 'The BiLSTM of Contextual String Embedding (:cite:`akbik-etal-2018-contextual`).' 12 | 13 | ALL = {} 14 | -------------------------------------------------------------------------------- /elit/pretrained/sdp.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-31 23:54 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | SEMEVAL16_NEWS_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-news-biaffine_20191231_235407.zip' 7 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval16 news data.' 8 | SEMEVAL16_TEXT_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-text-biaffine_20200101_002257.zip' 9 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval16 text data.' 10 | 11 | SEMEVAL16_ALL_ELECTRA_SMALL_ZH = HANLP_URL + 'sdp/semeval16_sdp_electra_small_20220719_171433.zip' 12 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval16 text and news data. Performance: ``UF: 83.03% LF: 72.58%``' 13 | 14 | SEMEVAL15_PAS_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_pas_20200103_152405.zip' 15 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval15 PAS data.' 16 | SEMEVAL15_PSD_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_psd_20200106_123009.zip' 17 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval15 PSD data.' 18 | SEMEVAL15_DM_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_dm_20200106_122808.zip' 19 | 'Biaffine SDP (:cite:`he-choi-2019`) trained on SemEval15 DM data.' 20 | 21 | ALL = {} 22 | -------------------------------------------------------------------------------- /elit/pretrained/srl.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-08-07 19:07 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CPB3_SRL_ELECTRA_SMALL = HANLP_URL + 'srl/cpb3_electra_small_crf_has_transform_20220218_135910.zip' 7 | 'Electra small model (:cite:`clark2020electra`) trained on CPB3. P=75.87% R=76.24% F1=76.05%.' 
8 | 9 | ALL = {} 10 | -------------------------------------------------------------------------------- /elit/pretrained/sts.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-24 12:51 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | STS_ELECTRA_BASE_ZH = HANLP_URL + 'sts/sts_electra_base_zh_20210530_200109.zip' 7 | 'A naive regression model trained on concatenated STS corpora.' 8 | 9 | # Will be filled up during runtime 10 | ALL = {} 11 | -------------------------------------------------------------------------------- /elit/pretrained/word2vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-21 18:25 4 | from hanlp_common.constant import HANLP_URL 5 | 6 | CONVSEG_W2V_NEWS_TENSITE = HANLP_URL + 'embeddings/convseg_embeddings.zip' 7 | CONVSEG_W2V_NEWS_TENSITE_WORD_PKU = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.pku.words.w2v50' 8 | CONVSEG_W2V_NEWS_TENSITE_WORD_MSR = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.msr.words.w2v50' 9 | CONVSEG_W2V_NEWS_TENSITE_CHAR = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.w2v200' 10 | 11 | SEMEVAL16_EMBEDDINGS_CN = HANLP_URL + 'embeddings/semeval16_embeddings.zip' 12 | SEMEVAL16_EMBEDDINGS_300_NEWS_CN = SEMEVAL16_EMBEDDINGS_CN + '#news.fasttext.300.txt' 13 | SEMEVAL16_EMBEDDINGS_300_TEXT_CN = SEMEVAL16_EMBEDDINGS_CN + '#text.fasttext.300.txt' 14 | 15 | CTB5_FASTTEXT_300_CN = HANLP_URL + 'embeddings/ctb.fasttext.300.txt.zip' 16 | 17 | TENCENT_AILAB_EMBEDDING_SMALL_200 = 'https://ai.tencent.com/ailab/nlp/en/data/tencent-ailab-embedding-zh-d200-v0.2.0-s.tar.gz#tencent-ailab-embedding-zh-d200-v0.2.0-s.txt' 18 | 'Chinese word embeddings (:cite:`NIPS2013_9aa42b31`) with small vocabulary size and 200 dimension provided by Tencent AI lab.' 19 | TENCENT_AILAB_EMBEDDING_LARGE_200 = 'https://ai.tencent.com/ailab/nlp/en/data/tencent-ailab-embedding-zh-d200-v0.2.0.tar.gz#tencent-ailab-embedding-zh-d200-v0.2.0.txt' 20 | 'Chinese word embeddings (:cite:`NIPS2013_9aa42b31`) with large vocabulary size and 200 dimension provided by Tencent AI lab.' 21 | TENCENT_AILAB_EMBEDDING_SMALL_100 = 'https://ai.tencent.com/ailab/nlp/en/data/tencent-ailab-embedding-zh-d100-v0.2.0-s.tar.gz#tencent-ailab-embedding-zh-d100-v0.2.0-s.txt' 22 | 'Chinese word embeddings (:cite:`NIPS2013_9aa42b31`) with small vocabulary size and 100 dimension provided by Tencent AI lab.' 23 | TENCENT_AILAB_EMBEDDING_LARGE_100 = 'https://ai.tencent.com/ailab/nlp/en/data/tencent-ailab-embedding-zh-d100-v0.2.0.tar.gz#tencent-ailab-embedding-zh-d100-v0.2.0.txt' 24 | 'Chinese word embeddings (:cite:`NIPS2013_9aa42b31`) with large vocabulary size and 100 dimension provided by Tencent AI lab.' 25 | 26 | MERGE_SGNS_BIGRAM_CHAR_300_ZH = 'http://download.hanlp.com/embeddings/extra/merge_sgns_bigram_char300_20220130_214613.txt.zip' 27 | 'Chinese word embeddings trained with context features (word, ngram, character, and more) using Skip-Gram with Negative Sampling (SGNS) (:cite:`li-etal-2018-analogical`).' 28 | 29 | RADICAL_CHAR_EMBEDDING_100 = HANLP_URL + 'embeddings/radical_char_vec_20191229_013849.zip#character.vec.txt' 30 | 'Chinese character embedding enhanced with rich radical information (:cite:`he2018dual`).' 
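# Note: the '#member' fragment appended to an archive URL above selects a single file inside the archive;
# get_resource resolves it to the local path of that member after downloading and extracting the archive.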
31 | 32 | _SUBWORD_ENCODING_CWS = 'http://download.hanlp.com/embeddings/extra/subword_encoding_cws_20200524_190636.zip' 33 | SUBWORD_ENCODING_CWS_ZH_WIKI_BPE_50 = _SUBWORD_ENCODING_CWS + '#zh.wiki.bpe.vs200000.d50.w2v.txt' 34 | SUBWORD_ENCODING_CWS_GIGAWORD_UNI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.uni.ite50.vec' 35 | SUBWORD_ENCODING_CWS_GIGAWORD_BI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.bi.ite50.vec' 36 | SUBWORD_ENCODING_CWS_CTB_GAZETTEER_50 = _SUBWORD_ENCODING_CWS + '#ctb.50d.vec' 37 | 38 | ALL = {} 39 | -------------------------------------------------------------------------------- /elit/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-29 22:24 -------------------------------------------------------------------------------- /elit/transform/glue_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-08 16:34 4 | from hanlp_common.structure import SerializableDict 5 | from elit.datasets.glu.glue import STANFORD_SENTIMENT_TREEBANK_2_TRAIN, MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV 6 | from elit.transform.table_tf import TableTransform 7 | 8 | 9 | class StanfordSentimentTreebank2Transorm(TableTransform): 10 | pass 11 | 12 | 13 | class MicrosoftResearchParaphraseCorpus(TableTransform): 14 | 15 | def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=(3, 4), 16 | y_column=0, skip_header=True, delimiter='auto', **kwargs) -> None: 17 | super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs) 18 | 19 | 20 | def main(): 21 | # _test_sst2() 22 | _test_mrpc() 23 | 24 | 25 | def _test_sst2(): 26 | transform = StanfordSentimentTreebank2Transorm() 27 | transform.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN) 28 | transform.lock_vocabs() 29 | transform.label_vocab.summary() 30 | transform.build_config() 31 | dataset = transform.file_to_dataset(STANFORD_SENTIMENT_TREEBANK_2_TRAIN) 32 | for batch in dataset.take(1): 33 | print(batch) 34 | 35 | 36 | def _test_mrpc(): 37 | transform = MicrosoftResearchParaphraseCorpus() 38 | transform.fit(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV) 39 | transform.lock_vocabs() 40 | transform.label_vocab.summary() 41 | transform.build_config() 42 | dataset = transform.file_to_dataset(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV) 43 | for batch in dataset.take(1): 44 | print(batch) -------------------------------------------------------------------------------- /elit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-08-24 22:12 4 | from . 
import rules 5 | 6 | 7 | def ls_resource_in_module(root) -> dict: 8 | res = dict() 9 | for k, v in root.__dict__.items(): 10 | if k.startswith('_') or v == root: 11 | continue 12 | if isinstance(v, str): 13 | if v.startswith('http') and not v.endswith('/') and not v.endswith('#') and not v.startswith('_'): 14 | res[k] = v 15 | elif type(v).__name__ == 'module': 16 | res.update(ls_resource_in_module(v)) 17 | if 'ALL' in root.__dict__ and isinstance(root.__dict__['ALL'], dict): 18 | root.__dict__['ALL'].update(res) 19 | return res 20 | -------------------------------------------------------------------------------- /elit/utils/file_read_backwards/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .file_read_backwards import FileReadBackwards # noqa: F401 4 | 5 | __author__ = """Robin Robin""" 6 | __email__ = 'robinsquare42@gmail.com' 7 | __version__ = '2.0.0' 8 | -------------------------------------------------------------------------------- /elit/utils/init_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-05-27 13:25 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | import functools 9 | 10 | 11 | def embedding_uniform(tensor:torch.Tensor, seed=233): 12 | gen = torch.Generator().manual_seed(seed) 13 | with torch.no_grad(): 14 | fan_out = tensor.size(-1) 15 | bound = math.sqrt(3.0 / fan_out) 16 | return tensor.uniform_(-bound, bound, generator=gen) 17 | -------------------------------------------------------------------------------- /elit/utils/lang/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 18:46 4 | 5 | __doc__ = ''' 6 | This package holds misc utils for specific languages. 
7 | ''' 8 | -------------------------------------------------------------------------------- /elit/utils/lang/en/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-12-28 19:28 4 | -------------------------------------------------------------------------------- /elit/utils/lang/ja/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-05-13 13:24 4 | -------------------------------------------------------------------------------- /elit/utils/lang/zh/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 18:47 -------------------------------------------------------------------------------- /elit/utils/lang/zh/char_table.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-01-09 19:07 4 | from typing import List 5 | 6 | from elit.utils.io_util import get_resource 7 | from hanlp_common.io import load_json 8 | 9 | HANLP_CHAR_TABLE_TXT = 'https://file.hankcs.com/corpus/char_table.zip#CharTable.txt' 10 | HANLP_CHAR_TABLE_JSON = 'https://file.hankcs.com/corpus/char_table.json.zip' 11 | 12 | 13 | class CharTable: 14 | convert = {} 15 | 16 | @staticmethod 17 | def convert_char(c): 18 | if not CharTable.convert: 19 | CharTable._init() 20 | return CharTable.convert.get(c, c) 21 | 22 | @staticmethod 23 | def normalize_text(text: str) -> str: 24 | return ''.join(CharTable.convert_char(c) for c in text) 25 | 26 | @staticmethod 27 | def normalize_chars(chars: List[str]) -> List[str]: 28 | return [CharTable.convert_char(c) for c in chars] 29 | 30 | @staticmethod 31 | def _init(): 32 | CharTable.convert = CharTable.load() 33 | 34 | @staticmethod 35 | def load(path=HANLP_CHAR_TABLE_TXT): 36 | mapper = {} 37 | with open(get_resource(path), encoding='utf-8') as src: 38 | for line in src: 39 | cells = line.rstrip('\n') 40 | if len(cells) != 3: 41 | continue 42 | a, _, b = cells 43 | mapper[a] = b 44 | return mapper 45 | 46 | 47 | class JsonCharTable(CharTable): 48 | 49 | @staticmethod 50 | def load(path=HANLP_CHAR_TABLE_JSON): 51 | return load_json(get_resource(path)) 52 | -------------------------------------------------------------------------------- /elit/utils/lang/zh/localization.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2020-12-05 02:09 4 | 5 | task = { 6 | 'dep': '依存句法树', 7 | 'token': '单词', 8 | 'pos': '词性', 9 | 'ner': '命名实体', 10 | 'srl': '语义角色' 11 | } 12 | 13 | pos = { 14 | 'VA': '表语形容词', 'VC': '系动词', 'VE': '动词有无', 'VV': '其他动词', 'NR': '专有名词', 'NT': '时间名词', 'NN': '其他名词', 15 | 'LC': '方位词', 'PN': '代词', 'DT': '限定词', 'CD': '概数词', 'OD': '序数词', 'M': '量词', 'AD': '副词', 'P': '介词', 16 | 'CC': '并列连接词', 'CS': '从属连词', 'DEC': '补语成分“的”', 'DEG': '属格“的”', 'DER': '表结果的“得”', 'DEV': '表方式的“地”', 17 | 'AS': '动态助词', 'SP': '句末助词', 'ETC': '表示省略', 'MSP': '其他小品词', 'IJ': '句首感叹词', 'ON': '象声词', 18 | 'LB': '长句式表被动', 'SB': '短句式表被动', 'BA': '把字句', 'JJ': '其他名词修饰语', 'FW': '外来语', 'PU': '标点符号', 19 | 'NOI': '噪声', 'URL': '网址' 20 | } 21 | 22 | ner = { 23 | 'NT': '机构团体', 'NS': '地名', 'NR': '人名' 24 | } 25 | 26 | dep = { 27 | 'nn': '复合名词修饰', 'punct': '标点符号', 'nsubj': '名词性主语', 'conj': '连接性状语', 'dobj': '直接宾语', 'advmod': '名词性状语', 28 | 
'prep': '介词性修饰语', 'nummod': '数词修饰语', 'amod': '形容词修饰语', 'pobj': '介词性宾语', 'rcmod': '相关关系', 'cpm': '补语', 29 | 'assm': '关联标记', 'assmod': '关联修饰', 'cc': '并列关系', 'elf': '类别修饰', 'ccomp': '从句补充', 'det': '限定语', 'lobj': '时间介词', 30 | 'range': '数量词间接宾语', 'asp': '时态标记', 'tmod': '时间修饰语', 'plmod': '介词性地点修饰', 'attr': '属性', 'mmod': '情态动词', 31 | 'loc': '位置补语', 'top': '主题', 'pccomp': '介词补语', 'etc': '省略关系', 'lccomp': '位置补语', 'ordmod': '量词修饰', 32 | 'xsubj': '控制主语', 'neg': '否定修饰', 'rcomp': '结果补语', 'comod': '并列联合动词', 'vmod': '动词修饰', 'prtmod': '小品词', 33 | 'ba': '把字关系', 'dvpm': '地字修饰', 'dvpmod': '地字动词短语', 'prnmod': '插入词修饰', 'cop': '系动词', 'pass': '被动标记', 34 | 'nsubjpass': '被动名词主语', 'clf': '类别修饰', 'dep': '依赖关系', 'root': '核心关系' 35 | } 36 | -------------------------------------------------------------------------------- /elit/utils/rules.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | _SEPARATOR = r'@' 4 | _RE_SENTENCE = re.compile(r'(\S.+?[.!?])(?=\s+|$)|(\S.+?)(?=[\n]|$)', re.UNICODE) 5 | _AB_SENIOR = re.compile(r'([A-Z][a-z]{1,2}\.)\s(\w)', re.UNICODE) 6 | _AB_ACRONYM = re.compile(r'(\.[a-zA-Z]\.)\s(\w)', re.UNICODE) 7 | _UNDO_AB_SENIOR = re.compile(r'([A-Z][a-z]{1,2}\.)' + _SEPARATOR + r'(\w)', re.UNICODE) 8 | _UNDO_AB_ACRONYM = re.compile(r'(\.[a-zA-Z]\.)' + _SEPARATOR + r'(\w)', re.UNICODE) 9 | 10 | 11 | def _replace_with_separator(text, separator, regexs): 12 | replacement = r"\1" + separator + r"\2" 13 | result = text 14 | for regex in regexs: 15 | result = regex.sub(replacement, result) 16 | return result 17 | 18 | 19 | def split_sentence(text, best=True): 20 | text = re.sub(r'([。!??])([^”’])', r"\1\n\2", text) 21 | text = re.sub(r'(\.{6})([^”’])', r"\1\n\2", text) 22 | text = re.sub(r'(…{2})([^”’])', r"\1\n\2", text) 23 | text = re.sub(r'([。!??][”’])([^,。!??])', r'\1\n\2', text) 24 | for chunk in text.split("\n"): 25 | chunk = chunk.strip() 26 | if not chunk: 27 | continue 28 | if not best: 29 | yield chunk 30 | continue 31 | processed = _replace_with_separator(chunk, _SEPARATOR, [_AB_SENIOR, _AB_ACRONYM]) 32 | sents = list(_RE_SENTENCE.finditer(processed)) 33 | if not sents: 34 | yield chunk 35 | continue 36 | for sentence in sents: 37 | sentence = _replace_with_separator(sentence.group(), r" ", [_UNDO_AB_SENIOR, _UNDO_AB_ACRONYM]) 38 | yield sentence 39 | 40 | 41 | -------------------------------------------------------------------------------- /elit/utils/statistics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-01-14 12:46 4 | -------------------------------------------------------------------------------- /elit/utils/statistics/moving_avg.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-01-14 12:46 4 | from collections import defaultdict, deque 5 | 6 | 7 | class MovingAverage(object): 8 | def __init__(self, maxlen=5) -> None: 9 | self._queue = defaultdict(lambda: deque(maxlen=maxlen)) 10 | 11 | def append(self, key, value: float): 12 | self._queue[key].append(value) 13 | 14 | def average(self, key) -> float: 15 | queue = self._queue[key] 16 | return sum(queue) / len(queue) 17 | -------------------------------------------------------------------------------- /elit/version.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 19:26 4 | 5 | 
__version__ = '2.1.0-beta.45' 6 | """ELIT version""" 7 | 8 | 9 | class NotCompatible(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2019-12-28 19:26 4 | from os.path import abspath, join, dirname 5 | from setuptools import find_packages, setup 6 | 7 | this_dir = abspath(dirname(__file__)) 8 | with open(join(this_dir, 'README.md'), encoding='utf-8') as file: 9 | long_description = file.read() 10 | version = {} 11 | with open(join(this_dir, "elit", "version.py")) as fp: 12 | exec(fp.read(), version) 13 | 14 | setup( 15 | name='seq2seq-corenlp', 16 | version=version['__version__'], 17 | description='Unleashing the True Potential of Sequence-to-Sequence Models for Sequence Tagging and Structure Parsing', 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | url='https://github.com/emorynlp/seq2seq-corenlp', 21 | author='Han He', 22 | author_email='han.he@emory.edu', 23 | license='Apache License 2.0', 24 | classifiers=[ 25 | 'Intended Audience :: Science/Research', 26 | 'Intended Audience :: Developers', 27 | "Development Status :: 3 - Alpha", 28 | 'Operating System :: OS Independent', 29 | "License :: OSI Approved :: Apache Software License", 30 | 'Programming Language :: Python :: 3 :: Only', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | "Topic :: Text Processing :: Linguistic" 33 | ], 34 | keywords='corpus,machine-learning,NLU,NLP', 35 | packages=find_packages(exclude=['docs', 'tests*']), 36 | include_package_data=True, 37 | install_requires=[ 38 | 'termcolor', 39 | 'pynvml', 40 | 'alnlp', 41 | 'toposort==1.5', 42 | 'transformers==4.9.2', 43 | 'sentencepiece>=0.1.91', 44 | 'torch>=1.6.0', 45 | 'hanlp-common==0.0.11', 46 | 'hanlp-trie==0.0.4', 47 | 'hanlp-downloader', 48 | 'tensorboardX==2.1', 49 | 'penman==1.2.2', 50 | 'networkx==2.8.8', 51 | ], 52 | python_requires='>=3.6', 53 | ) 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:26 4 | import os 5 | 6 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) 7 | 8 | 9 | def cdroot(): 10 | """ 11 | cd to project root, so models are saved in the root folder 12 | """ 13 | os.chdir(root) 14 | -------------------------------------------------------------------------------- /tests/con/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-04-17 22:28 4 | -------------------------------------------------------------------------------- /tests/con/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:49 4 | -------------------------------------------------------------------------------- /tests/con/ontonotes/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.components.seq2seq import Seq2seqConstituencyParser 6 | from 
elit.components.seq2seq.con.verbalizer import BracketedVerbalizer 7 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_CON_ENGLISH_TEST, ONTONOTES5_CON_ENGLISH_DEV, \ 8 | ONTONOTES5_CON_ENGLISH_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/con/ontonotes/ls/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | con = Seq2seqConstituencyParser() 18 | con.fit( 19 | ONTONOTES5_CON_ENGLISH_TRAIN, 20 | ONTONOTES5_CON_ENGLISH_DEV, 21 | save_dir, 22 | BracketedVerbalizer(flatten_pos=True, anonymize_token=True), 23 | epochs=30, 24 | eval_after=25, 25 | gradient_accumulation=4, 26 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 27 | ) 28 | con.load(save_dir) 29 | test_score = con.evaluate(ONTONOTES5_CON_ENGLISH_TEST, save_dir, official=True)[-1] 30 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 31 | scores.append(test_score) 32 | 33 | print(f'Scores on {len(scores)} runs:') 34 | for metric in scores: 35 | print(metric) 36 | -------------------------------------------------------------------------------- /tests/con/ontonotes/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.components.seq2seq import Seq2seqConstituencyParser 6 | from elit.components.seq2seq.con.verbalizer import BracketedVerbalizer 7 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_CON_ENGLISH_TRAIN, ONTONOTES5_CON_ENGLISH_DEV, \ 8 | ONTONOTES5_CON_ENGLISH_TEST 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/con/ontonotes/lt/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | con = Seq2seqConstituencyParser() 18 | con.fit( 19 | ONTONOTES5_CON_ENGLISH_TRAIN, 20 | ONTONOTES5_CON_ENGLISH_DEV, 21 | save_dir, 22 | BracketedVerbalizer(flatten_pos=True), 23 | epochs=30, 24 | eval_after=25, 25 | gradient_accumulation=8, 26 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 27 | ) 28 | con.load(save_dir) 29 | test_score = con.evaluate(ONTONOTES5_CON_ENGLISH_TEST, save_dir, official=True)[-1] 30 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 31 | scores.append(test_score) 32 | 33 | print(f'Scores on {len(scores)} runs:') 34 | for metric in scores: 35 | print(metric) 36 | -------------------------------------------------------------------------------- /tests/con/ontonotes/pt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-05-22 14:36 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.con.seq2seq_con import Seq2seqConstituencyParser 7 | from elit.components.seq2seq.con.verbalizer import IsAPhraseVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_CON_ENGLISH_TRAIN, ONTONOTES5_CON_ENGLISH_DEV, \ 10 | ONTONOTES5_CON_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/con/ontonotes/pt/{i}' 18 | 
cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | con = Seq2seqConstituencyParser() 20 | con.fit( 21 | ONTONOTES5_CON_ENGLISH_TRAIN, 22 | ONTONOTES5_CON_ENGLISH_DEV, 23 | save_dir, 24 | verbalizer=IsAPhraseVerbalizer( 25 | label_to_phrase={ 26 | 'ADJP': 'an adjective phrase', 'ADVP': 'an adverb phrase', 'CONJP': 'a conjunction phrase', 27 | 'FRAG': 'a fragment phrase', 'INTJ': 'an interjection', 'LST': 'a list marker', 28 | 'NAC': 'a non-constituent', 'NP': 'a noun phrase', "NX": "a head noun phrase", 29 | 'PP': 'a prepositional phrase', 'PRN': 'a parenthetical.', 'PRT': 'a particle', 30 | 'QP': 'a quantifier phrase', 'RRC': 'a reduced relative clause', 31 | 'UCP': 'an unlike coordinated phrase', 'VP': 'a verb phrase', 32 | 'WHADJP': 'a wh-adjective phrase', 'WHADVP': 'a wh-adverb phrase', 33 | 'WHNP': 'a wh-noun phrase', 34 | 'WHPP': 'a wh-prepositional phrase', 'X': 'an unknown phrase', 35 | 'S': 'a simple clause', 36 | "SBAR": "a subordinating clause", 37 | "SBARQ": "a wh-subordinating clause", 38 | "SINV": "an inverted clause", 39 | "SQ": "an interrogative clause", 40 | "NML": "a nominal", 41 | 'META': 'a meta tag', 42 | 'TOP': 'a sentence', 43 | 'EMBED': 'an embedding' 44 | }, 45 | top=True 46 | ), 47 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 48 | epochs=30, 49 | eval_after=28, 50 | gradient_accumulation=8, 51 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 52 | max_seq_len=1024, 53 | max_prompt_len=1024, 54 | ) 55 | con.load(save_dir, max_prompt_len=None) 56 | test_score = con.evaluate(ONTONOTES5_CON_ENGLISH_TEST, save_dir, official=True)[-1] 57 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 58 | scores.append(test_score) 59 | 60 | print(f'Scores on {len(scores)} runs:') 61 | for metric in scores: 62 | print(metric) 63 | -------------------------------------------------------------------------------- /tests/con/ontonotes/pt_inc_vrb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-05-22 14:36 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq import Seq2seqConstituencyParser 7 | from elit.components.seq2seq.con.verbalizer import IsAPhraseVerbalizerVerbose 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_CON_ENGLISH_TRAIN, ONTONOTES5_CON_ENGLISH_DEV, \ 10 | ONTONOTES5_CON_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/con/ontonotes/pt_inc_vrb/{i}' 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | con = Seq2seqConstituencyParser() 20 | con.fit( 21 | ONTONOTES5_CON_ENGLISH_TRAIN, 22 | ONTONOTES5_CON_ENGLISH_DEV, 23 | save_dir, 24 | verbalizer=IsAPhraseVerbalizerVerbose( 25 | label_to_phrase={ 26 | 'ADJP': 'an adjective phrase', 'ADVP': 'an adverb phrase', 'CONJP': 'a conjunction phrase', 27 | 'FRAG': 'a fragment phrase', 'INTJ': 'an interjection', 'LST': 'a list marker', 28 | 'NAC': 'a non-constituent', 'NP': 'a noun phrase', "NX": "a head noun phrase", 29 | 'PP': 'a prepositional phrase', 'PRN': 'a parenthetical.', 'PRT': 'a particle', 30 | 'QP': 'a quantifier phrase', 'RRC': 'a reduced relative clause', 31 | 'UCP': 'an unlike coordinated phrase', 'VP': 'a verb phrase', 32 | 'WHADJP': 'a wh-adjective phrase', 
'WHADVP': 'a wh-adverb phrase', 33 | 'WHNP': 'a wh-noun phrase', 34 | 'WHPP': 'a wh-prepositional phrase', 'X': 'an unknown phrase', 35 | 'S': 'a simple clause', 36 | "SBAR": "a subordinating clause", 37 | "SBARQ": "a wh-subordinating clause", 38 | "SINV": "an inverted clause", 39 | "SQ": "an interrogative clause", 40 | "NML": "a nominal", 41 | 'META': 'a meta tag', 42 | 'TOP': 'a sentence', 43 | 'EMBED': 'an embedding' 44 | }, 45 | top=True 46 | ), 47 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 48 | epochs=30, 49 | eval_after=28, 50 | gradient_accumulation=4, 51 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 52 | max_seq_len=1024, 53 | max_prompt_len=1024, 54 | ) 55 | con.load(save_dir, max_prompt_len=None) 56 | test_score = con.evaluate(ONTONOTES5_CON_ENGLISH_TEST, save_dir, official=True)[-1] 57 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 58 | scores.append(test_score) 59 | 60 | print(f'Scores on {len(scores)} runs:') 61 | for metric in scores: 62 | print(metric) 63 | -------------------------------------------------------------------------------- /tests/con/ptb/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:45 4 | -------------------------------------------------------------------------------- /tests/con/ptb/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq import Seq2seqConstituencyParser 7 | from elit.components.seq2seq.con.verbalizer import ShiftReduceVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING, PTB_DEV, PTB_TEST, PTB_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/con/ptb/ls/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | con = Seq2seqConstituencyParser() 18 | con.fit( 19 | PTB_TRAIN, 20 | PTB_DEV, 21 | save_dir, 22 | ShiftReduceVerbalizer(flatten_pos=True, anonymize_token=True), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 24 | epochs=30, 25 | eval_after=28, 26 | save_every_epoch=False, 27 | gradient_accumulation=2, 28 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 29 | ) 30 | con.load(save_dir) 31 | test_score = con.evaluate(PTB_TEST, save_dir, official=True)[-1] 32 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 33 | scores.append(test_score) 34 | 35 | print(f'Scores on {len(scores)} runs:') 36 | for metric in scores: 37 | print(metric) 38 | -------------------------------------------------------------------------------- /tests/con/ptb/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq import Seq2seqConstituencyParser 7 | from elit.components.seq2seq.con.verbalizer import BracketedVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING, PTB_DEV, PTB_TEST, PTB_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | 
scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/con/ptb/lt/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | con = Seq2seqConstituencyParser() 18 | con.fit( 19 | PTB_TRAIN, 20 | PTB_DEV, 21 | save_dir, 22 | BracketedVerbalizer(flatten_pos=True), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 24 | epochs=30, 25 | eval_after=28, 26 | save_every_epoch=False, 27 | gradient_accumulation=4, 28 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 29 | ) 30 | con.load(save_dir) 31 | test_score = con.evaluate(PTB_TEST, save_dir, official=True)[-1] 32 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 33 | scores.append(test_score) 34 | 35 | print(f'Scores on {len(scores)} runs:') 36 | for metric in scores: 37 | print(metric) 38 | -------------------------------------------------------------------------------- /tests/con/ptb/pt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-05-22 14:36 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.con.seq2seq_con import Seq2seqConstituencyParser 7 | from elit.components.seq2seq.con.verbalizer import IsAPhraseVerbalizerVerbose 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING, PTB_DEV, PTB_TEST, PTB_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/con/ptb/pt/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | con = Seq2seqConstituencyParser() 18 | con.fit( 19 | PTB_TRAIN, 20 | PTB_DEV, 21 | save_dir, 22 | verbalizer=IsAPhraseVerbalizerVerbose( 23 | label_to_phrase={ 24 | 'ADJP': 'an adjective phrase', 'ADVP': 'an adverb phrase', 'CONJP': 'a conjunction phrase', 25 | 'FRAG': 'a fragment phrase', 'INTJ': 'an interjection', 'LST': 'a list marker', 26 | 'NAC': 'a non-constituent', 'NP': 'a noun phrase', "NX": "a head noun phrase", 27 | 'PP': 'a prepositional phrase', 'PRN': 'a parenthetical.', 'PRT': 'a particle', 28 | 'QP': 'a quantifier phrase', 'RRC': 'a reduced relative clause', 29 | 'UCP': 'an unlike coordinated phrase', 'VP': 'a verb phrase', 30 | 'WHADJP': 'a wh-adjective phrase', 'WHADVP': 'a wh-adverb phrase', 31 | 'WHNP': 'a wh-noun phrase', 32 | 'WHPP': 'a wh-prepositional phrase', 'X': 'an unknown phrase', 33 | 'S': 'a simple clause', 34 | "SBAR": "a subordinating clause", 35 | "SBARQ": "a wh-subordinating clause", 36 | "SINV": "an inverted clause", "SQ": "an interrogative clause", 37 | }), 38 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 39 | epochs=30, 40 | eval_after=28, 41 | gradient_accumulation=4, 42 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 43 | max_seq_len=1024, 44 | max_prompt_len=1024, 45 | ) 46 | con.load(save_dir) 47 | test_score = con.evaluate(PTB_TEST, save_dir, official=True)[-1] 48 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 49 | scores.append(test_score) 50 | 51 | print(f'Scores on {len(scores)} runs:') 52 | for metric in scores: 53 | print(metric) 54 | -------------------------------------------------------------------------------- /tests/dep/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | 
-------------------------------------------------------------------------------- /tests/dep/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:57 4 | -------------------------------------------------------------------------------- /tests/dep/ontonotes/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.dep.seq2seq_dep import Seq2seqDependencyParser 7 | from elit.components.seq2seq.dep.verbalizer import ArcStandardVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_DEP_ENGLISH_TRAIN, ONTONOTES5_DEP_ENGLISH_DEV, \ 10 | ONTONOTES5_DEP_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/dep/ontonotes/ls/{i}' 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | dep = Seq2seqDependencyParser() 20 | dep.fit( 21 | ONTONOTES5_DEP_ENGLISH_TRAIN, 22 | ONTONOTES5_DEP_ENGLISH_DEV, 23 | save_dir, 24 | ArcStandardVerbalizer(), 25 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'FORM'), 26 | epochs=30, 27 | eval_after=25, 28 | gradient_accumulation=4, 29 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 30 | ) 31 | dep.load(save_dir) 32 | test_score = dep.evaluate(ONTONOTES5_DEP_ENGLISH_TEST, save_dir)[-1] 33 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 34 | scores.append(test_score) 35 | 36 | print(f'Scores on {len(scores)} runs:') 37 | for metric in scores: 38 | print(metric) 39 | -------------------------------------------------------------------------------- /tests/dep/ontonotes/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.dep.seq2seq_dep import Seq2seqDependencyParser 7 | from elit.components.seq2seq.dep.verbalizer import ArcStandardVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_DEP_ENGLISH_TEST, ONTONOTES5_DEP_ENGLISH_DEV, \ 10 | ONTONOTES5_DEP_ENGLISH_TRAIN 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/dep/ontonotes/lt/{i}' 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | dep = Seq2seqDependencyParser() 20 | dep.fit( 21 | ONTONOTES5_DEP_ENGLISH_TRAIN, 22 | ONTONOTES5_DEP_ENGLISH_DEV, 23 | save_dir, 24 | ArcStandardVerbalizer(lexical=True), 25 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'FORM'), 26 | epochs=30, 27 | eval_after=25, 28 | save_every_epoch=False, 29 | gradient_accumulation=4, 30 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 31 | ) 32 | dep.load(save_dir) 33 | test_score = dep.evaluate(ONTONOTES5_DEP_ENGLISH_TEST, save_dir)[-1] 34 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 35 | scores.append(test_score) 36 | 37 | print(f'Scores on 
{len(scores)} runs:') 38 | for metric in scores: 39 | print(metric) 40 | -------------------------------------------------------------------------------- /tests/dep/ptb/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:53 4 | -------------------------------------------------------------------------------- /tests/dep/ptb/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.dep.seq2seq_dep import Seq2seqDependencyParser 7 | from elit.components.seq2seq.dep.verbalizer import ArcEagerVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING, PTB_SD330_DEV, PTB_SD330_TEST, PTB_SD330_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/dep/ptb/ls/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | dep = Seq2seqDependencyParser() 18 | dep.fit( 19 | PTB_SD330_TRAIN, 20 | PTB_SD330_DEV, 21 | save_dir, 22 | ArcEagerVerbalizer(), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'FORM'), 24 | epochs=30, 25 | eval_after=25, 26 | gradient_accumulation=2, 27 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 28 | ) 29 | dep.load(save_dir) 30 | test_score = dep.evaluate(PTB_SD330_TEST, save_dir)[-1] 31 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 32 | scores.append(test_score) 33 | 34 | print(f'Scores on {len(scores)} runs:') 35 | for metric in scores: 36 | print(metric) 37 | -------------------------------------------------------------------------------- /tests/dep/ptb/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-29 22:10 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.dep.seq2seq_dep import Seq2seqDependencyParser 7 | from elit.components.seq2seq.dep.verbalizer import ArcEagerVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING, PTB_SD330_DEV, PTB_SD330_TEST, PTB_SD330_TRAIN 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/dep/ptb/lt/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | dep = Seq2seqDependencyParser() 18 | dep.fit( 19 | PTB_SD330_TRAIN, 20 | PTB_SD330_DEV, 21 | save_dir, 22 | ArcEagerVerbalizer(lexical=True), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'FORM'), 24 | epochs=30, 25 | eval_after=25, 26 | gradient_accumulation=2, 27 | sampler_builder=SortingSamplerBuilder(batch_size=32, use_effective_tokens=True), 28 | ) 29 | dep.load(save_dir) 30 | test_score = dep.evaluate(PTB_SD330_TEST, save_dir)[-1] 31 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 32 | scores.append(test_score) 33 | 34 | print(f'Scores on {len(scores)} runs:') 35 | for metric in scores: 36 | print(metric) 37 | -------------------------------------------------------------------------------- /tests/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | -------------------------------------------------------------------------------- /tests/ner/conll/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:35 4 | -------------------------------------------------------------------------------- /tests/ner/conll/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.ner.prompt_ner import TagVerbalizer 7 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 8 | from elit.datasets.ner.conll03_json import CONLL03_EN_JSON_TRAIN, CONLL03_EN_JSON_TEST, CONLL03_EN_JSON_DEV 9 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 10 | from elit.utils.log_util import cprint 11 | from tests import cdroot 12 | 13 | cdroot() 14 | 15 | save_dir = 'data/model/ner/conll/ls/0' 16 | ner = Seq2seqNamedEntityRecognizer() 17 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 18 | ner.fit( 19 | CONLL03_EN_JSON_TRAIN, 20 | CONLL03_EN_JSON_DEV, 21 | save_dir, 22 | epochs=30, 23 | eval_after=25, 24 | transformer='facebook/bart-large', 25 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 26 | gradient_accumulation=1, 27 | fp16=False, 28 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 29 | verbalizer=TagVerbalizer(['LOC', 'PER', 'ORG', 'MISC']), 30 | _device_placeholder=True, 31 | save_every_epoch=False 32 | ) 33 | ner.load(save_dir, constrained_decoding=True) 34 | test_score = ner.evaluate(CONLL03_EN_JSON_TEST, save_dir)[-1] 35 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 36 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 37 | -------------------------------------------------------------------------------- /tests/ner/conll/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | 5 | from elit.common.dataset import SortingSamplerBuilder 6 | from elit.common.transform import NormalizeToken 7 | from elit.components.seq2seq.ner.prompt_ner import PairOfTagsVerbalizer 8 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 9 | from elit.datasets.ner.conll03_json import CONLL03_EN_JSON_TRAIN, CONLL03_EN_JSON_DEV, CONLL03_EN_JSON_TEST 10 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | save_dir = 'data/model/ner/conll/lt/0' 16 | ner = Seq2seqNamedEntityRecognizer() 17 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 18 | ner.fit( 19 | CONLL03_EN_JSON_TRAIN, 20 | CONLL03_EN_JSON_DEV, 21 | save_dir, 22 | epochs=30, 23 | eval_after=25, 24 | transformer='facebook/bart-large', 25 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 26 | gradient_accumulation=8, 27 | fp16=False, 28 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 29 | verbalizer=PairOfTagsVerbalizer( 30 | ['PER', 'ORG', 'LOC', 'MISC']), 31 | _device_placeholder=True, 32 | ) 33 | ner.load(save_dir) 34 | test_score = ner.evaluate(CONLL03_EN_JSON_TEST, save_dir)[-1] 35 | 
cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 36 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 37 | -------------------------------------------------------------------------------- /tests/ner/conll/pt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | 5 | from elit.common.dataset import SortingSamplerBuilder 6 | from elit.common.transform import NormalizeToken 7 | from elit.components.seq2seq.ner.prompt_ner import PairOfTagsVerbalizer 8 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 9 | from elit.datasets.ner.conll03_json import CONLL03_EN_JSON_TRAIN, CONLL03_EN_JSON_DEV, CONLL03_EN_JSON_TEST 10 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | 16 | save_dir = 'data/model/ner/conll/pt/0' 17 | ner = Seq2seqNamedEntityRecognizer() 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | ner.fit( 20 | CONLL03_EN_JSON_TRAIN, 21 | CONLL03_EN_JSON_DEV, 22 | save_dir, 23 | epochs=30, 24 | eval_after=25, 25 | transformer='facebook/bart-large', 26 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 27 | gradient_accumulation=8, 28 | fp16=False, 29 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 30 | verbalizer=PairOfTagsVerbalizer( 31 | ['PER', 'ORG', 'LOC', 'MISC']), 32 | _device_placeholder=True, 33 | ) 34 | ner.load(save_dir) 35 | test_score = ner.evaluate(CONLL03_EN_JSON_TEST, save_dir)[-1] 36 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 37 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 38 | -------------------------------------------------------------------------------- /tests/ner/conll/pt_inc_vrb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | 5 | from elit.common.dataset import SortingSamplerBuilder 6 | from elit.common.transform import NormalizeToken 7 | from elit.components.seq2seq.ner.prompt_ner import VerboseVerbalizer 8 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 9 | from elit.datasets.ner.conll03_json import CONLL03_EN_JSON_TRAIN, CONLL03_EN_JSON_TEST, CONLL03_EN_JSON_DEV 10 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | 16 | for i in range(3): 17 | save_dir = f'data/model/ner/conll03/pt_inc_vrb/{i}' 18 | ner = Seq2seqNamedEntityRecognizer() 19 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 20 | ner.fit( 21 | CONLL03_EN_JSON_TRAIN, 22 | CONLL03_EN_JSON_DEV, 23 | save_dir, 24 | epochs=30, 25 | eval_after=25, 26 | transformer='facebook/bart-large', 27 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 28 | gradient_accumulation=5, 29 | fp16=False, 30 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 31 | verbalizer=VerboseVerbalizer( 32 | label_to_phrase={ 33 | 'PER': 'a person', 34 | 'ORG': 'an organization', 35 | 'LOC': 'a location', 36 | 'MISC': 'a nationality or an event or an entity', 37 | }, 38 | ), 39 | _device_placeholder=True, 40 | dropout=0.1, 41 | attention_dropout=0.1, 42 | optimizer_name='adam', 43 | constrained_decoding=False, 44 | ) 45 | ner.load(save_dir) 46 | test_score = 
ner.evaluate(CONLL03_EN_JSON_TEST, save_dir)[-1] 47 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 48 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 49 | -------------------------------------------------------------------------------- /tests/ner/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:40 4 | -------------------------------------------------------------------------------- /tests/ner/ontonotes/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.ner.prompt_ner import TagVerbalizer 7 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_NER_ENGLISH_DEV, ONTONOTES5_NER_ENGLISH_TRAIN, \ 10 | ONTONOTES5_NER_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | 16 | save_dir = 'data/model/ner/ontonotes/ls/0' 17 | ner = Seq2seqNamedEntityRecognizer() 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | ner.fit( 20 | ONTONOTES5_NER_ENGLISH_TRAIN, 21 | ONTONOTES5_NER_ENGLISH_DEV, 22 | save_dir, 23 | # lr=3e-5, 24 | epochs=30, 25 | eval_after=25, 26 | transformer='facebook/bart-large', 27 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 28 | gradient_accumulation=10, 29 | fp16=False, 30 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 31 | verbalizer=TagVerbalizer( 32 | ['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 33 | 'PERCENT', 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART']), 34 | _device_placeholder=True, 35 | save_every_epoch=False 36 | # max_seq_len=600 37 | ) 38 | ner.load(save_dir, constrained_decoding=True) 39 | test_score = ner.evaluate(ONTONOTES5_NER_ENGLISH_TEST, save_dir, output='.jsonlines')[-1] 40 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 41 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 42 | -------------------------------------------------------------------------------- /tests/ner/ontonotes/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.ner.prompt_ner import PairOfTagsVerbalizer 7 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_NER_ENGLISH_DEV, ONTONOTES5_NER_ENGLISH_TRAIN, \ 10 | ONTONOTES5_NER_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | for run in range(3): 16 | save_dir = f'data/model/ner/ontonotes/lt/{run}' 17 | ner = Seq2seqNamedEntityRecognizer() 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | ner.fit( 20 | ONTONOTES5_NER_ENGLISH_TRAIN, 21 | ONTONOTES5_NER_ENGLISH_DEV, 22 | save_dir, 23 | # 
lr=3e-5, 24 | epochs=30, 25 | eval_after=0, 26 | transformer='facebook/bart-large', 27 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 28 | gradient_accumulation=6, 29 | fp16=False, 30 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 31 | verbalizer=PairOfTagsVerbalizer( 32 | ['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 33 | 'PERCENT', 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART']), 34 | _device_placeholder=True, 35 | ) 36 | ner.load(save_dir) 37 | test_score = ner.evaluate(ONTONOTES5_NER_ENGLISH_TEST, save_dir)[-1] 38 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 39 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 40 | -------------------------------------------------------------------------------- /tests/ner/ontonotes/pt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.ner.prompt_ner import Verbalizer 7 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_NER_ENGLISH_DEV, ONTONOTES5_NER_ENGLISH_TRAIN, \ 10 | ONTONOTES5_NER_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | save_dir = 'data/model/ner/ontonotes/pt/0' 16 | ner = Seq2seqNamedEntityRecognizer() 17 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 18 | ner.fit( 19 | ONTONOTES5_NER_ENGLISH_TRAIN, 20 | ONTONOTES5_NER_ENGLISH_DEV, 21 | save_dir, 22 | epochs=30, 23 | eval_after=25, 24 | transformer='facebook/bart-large', 25 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 26 | gradient_accumulation=6, 27 | fp16=False, 28 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 29 | verbalizer=Verbalizer( 30 | label_to_phrase={ 31 | 'PERSON': 'a person', 32 | 'NORP': 'an ethnicity', 33 | 'FAC': 'a facility', 34 | 'ORG': 'an organization', 35 | 'GPE': 'a geopolitical entity', 36 | 'LOC': 'a location', 37 | 'PRODUCT': 'a product', 38 | 'EVENT': 'an event', 39 | 'WORK_OF_ART': 'an art work', 40 | 'LAW': 'a law', 41 | 'DATE': 'a date', 42 | 'TIME': 'a time', 43 | 'PERCENT': 'a percentage', 44 | 'MONEY': 'a monetary value', 45 | 'QUANTITY': 'a quantity', 46 | 'ORDINAL': 'an ordinal', 47 | 'CARDINAL': 'a cardinal', 48 | 'LANGUAGE': 'a language', 49 | } 50 | ), 51 | _device_placeholder=True, 52 | ) 53 | ner.load(save_dir) 54 | test_score = ner.evaluate(ONTONOTES5_NER_ENGLISH_TEST, save_dir)[-1] 55 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 56 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 57 | -------------------------------------------------------------------------------- /tests/ner/ontonotes/pt_inc_vrb.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2021-10-22 17:19 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.ner.prompt_ner import VerboseVerbalizer 7 | from elit.components.seq2seq.ner.seq2seq_ner import Seq2seqNamedEntityRecognizer 8 | from 
elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_NER_ENGLISH_TRAIN, ONTONOTES5_NER_ENGLISH_DEV, \ 10 | ONTONOTES5_NER_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | cdroot() 14 | for i in range(3): 15 | save_dir = f'data/model/ner/ontonotes/pt_inc_vrb/{i}' 16 | ner = Seq2seqNamedEntityRecognizer() 17 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 18 | ner.fit( 19 | ONTONOTES5_NER_ENGLISH_TRAIN, 20 | ONTONOTES5_NER_ENGLISH_DEV, 21 | save_dir, 22 | epochs=30, 23 | eval_after=25, 24 | transformer='facebook/bart-large', 25 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 26 | gradient_accumulation=6, 27 | fp16=False, 28 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 29 | verbalizer=VerboseVerbalizer( 30 | label_to_phrase={ 31 | 'PERSON': 'a person', 32 | 'NORP': 'an ethnicity', 33 | 'FAC': 'a facility', 34 | 'ORG': 'an organization', 35 | 'GPE': 'a geopolitical entity', 36 | 'LOC': 'a location', 37 | 'PRODUCT': 'a product', 38 | 'EVENT': 'an event', 39 | 'WORK_OF_ART': 'an art work', 40 | 'LAW': 'a law', 41 | 'DATE': 'a date', 42 | 'TIME': 'a time', 43 | 'PERCENT': 'a percentage', 44 | 'MONEY': 'a monetary value', 45 | 'QUANTITY': 'a quantity', 46 | 'ORDINAL': 'an ordinal', 47 | 'CARDINAL': 'a cardinal', 48 | 'LANGUAGE': 'a language', 49 | } 50 | ), 51 | _device_placeholder=True, 52 | constrained_decoding=False, 53 | ) 54 | ner.load(save_dir) 55 | test_score = ner.evaluate(ONTONOTES5_NER_ENGLISH_TEST, save_dir)[-1] 56 | cprint(f'Official score on testset: [red]{test_score.score:.2%}[/red]') 57 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 58 | -------------------------------------------------------------------------------- /tests/pos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 16:30 4 | -------------------------------------------------------------------------------- /tests/pos/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-30 16:54 4 | -------------------------------------------------------------------------------- /tests/pos/ontonotes/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 16:30 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.pos.seq2seq_pos import Seq2seqTagger 7 | from elit.components.seq2seq.pos.verbalizer import TagVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_POS_ENGLISH_TRAIN, ONTONOTES5_POS_ENGLISH_DEV, \ 10 | ONTONOTES5_POS_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/pos/ontonotes/ls/{i}' 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | pos = Seq2seqTagger() 20 | pos.fit( 21 | ONTONOTES5_POS_ENGLISH_TRAIN, 22 | ONTONOTES5_POS_ENGLISH_DEV, 23 | save_dir, 24 | verbalizer=TagVerbalizer(), 25 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 26 | epochs=30, 27 | eval_after=25, 28 | gradient_accumulation=1, 29 | 
sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 30 | ) 31 | pos.load(save_dir, constrained_decoding=True) 32 | test_score = pos.evaluate(ONTONOTES5_POS_ENGLISH_TEST, save_dir, output=True)[-1] 33 | scores.append(test_score) 34 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 35 | 36 | print(f'Scores on {len(scores)} runs:') 37 | for metric in scores: 38 | print(metric) 39 | -------------------------------------------------------------------------------- /tests/pos/ontonotes/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 16:30 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.pos.seq2seq_pos import Seq2seqTagger 7 | from elit.components.seq2seq.pos.verbalizer import TokenTagVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.datasets.srl.ontonotes5.english import ONTONOTES5_POS_ENGLISH_TRAIN, ONTONOTES5_POS_ENGLISH_DEV, \ 10 | ONTONOTES5_POS_ENGLISH_TEST 11 | from elit.utils.log_util import cprint 12 | from tests import cdroot 13 | 14 | cdroot() 15 | scores = [] 16 | for i in range(3): 17 | save_dir = f'data/model/pos/ontonotes/{i}' 18 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 19 | pos = Seq2seqTagger() 20 | pos.fit( 21 | ONTONOTES5_POS_ENGLISH_TRAIN, 22 | ONTONOTES5_POS_ENGLISH_DEV, 23 | save_dir, 24 | verbalizer=TokenTagVerbalizer(), 25 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 26 | epochs=30, 27 | eval_after=25, 28 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 29 | gradient_accumulation=2, 30 | ) 31 | pos.load(save_dir, constrained_decoding=True) 32 | score = pos.evaluate(ONTONOTES5_POS_ENGLISH_TEST, save_dir, output=True)[-1] 33 | scores.append(score) 34 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 35 | 36 | print(f'Scores on {len(scores)} runs:') 37 | for metric in scores: 38 | print(metric) 39 | -------------------------------------------------------------------------------- /tests/pos/ptb/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2023-02-06 18:28 4 | -------------------------------------------------------------------------------- /tests/pos/ptb/ls.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 16:30 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.pos.seq2seq_pos import Seq2seqTagger 7 | from elit.components.seq2seq.pos.verbalizer import TagVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/pos/ptb/ls/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | pos = Seq2seqTagger() 18 | pos.fit( 19 | 'data/pos/wsj-pos/train.tsv', 20 | 'data/pos/wsj-pos/dev.tsv', 21 | save_dir, 22 | verbalizer=TagVerbalizer(), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 24 | epochs=30, 25 | eval_after=25, 26 | gradient_accumulation=1, 27 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, 
use_effective_tokens=True), 28 | ) 29 | pos.load(save_dir, constrained_decoding=True) 30 | test_score = pos.evaluate('data/pos/wsj-pos/test.tsv', save_dir)[-1] 31 | scores.append(test_score) 32 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 33 | 34 | print(f'Scores on {len(scores)} runs:') 35 | for metric in scores: 36 | print(metric) 37 | -------------------------------------------------------------------------------- /tests/pos/ptb/lt.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: hankcs 3 | # Date: 2022-03-23 16:30 4 | from elit.common.dataset import SortingSamplerBuilder 5 | from elit.common.transform import NormalizeToken 6 | from elit.components.seq2seq.pos.seq2seq_pos import Seq2seqTagger 7 | from elit.components.seq2seq.pos.verbalizer import TokenTagVerbalizer 8 | from elit.datasets.parsing.ptb import PTB_TOKEN_MAPPING 9 | from elit.utils.log_util import cprint 10 | from tests import cdroot 11 | 12 | cdroot() 13 | scores = [] 14 | for i in range(3): 15 | save_dir = f'data/model/pos/ptb/lt/{i}' 16 | cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]') 17 | pos = Seq2seqTagger() 18 | pos.fit( 19 | 'data/pos/wsj-pos/train.tsv', 20 | 'data/pos/wsj-pos/dev.tsv', 21 | save_dir, 22 | verbalizer=TokenTagVerbalizer(), 23 | transform=NormalizeToken(PTB_TOKEN_MAPPING, 'token'), 24 | epochs=30, 25 | eval_after=25, 26 | sampler_builder=SortingSamplerBuilder(batch_max_tokens=6000, use_effective_tokens=True), 27 | gradient_accumulation=2, 28 | ) 29 | pos.load(save_dir, constrained_decoding=False) 30 | score = pos.evaluate('data/pos/wsj-pos/test.tsv', save_dir)[-1] 31 | scores.append(score) 32 | cprint(f'Model saved in [cyan]{save_dir}[/cyan]') 33 | 34 | print(f'Scores on {len(scores)} runs:') 35 | for metric in scores: 36 | print(metric) 37 | --------------------------------------------------------------------------------
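A small optional follow-up to the multi-run scripts above (hypothetical, not part of the repository): each script collects one metric per run in `scores`, and, assuming every metric object exposes a float `score` attribute as used in the NER scripts (`test_score.score:.2%`), a mean over the runs could be reported like this:

    mean = sum(m.score for m in scores) / len(scores)  # average of the per-run test metrics
    print(f'Mean score over {len(scores)} runs: {mean:.2%}')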