├── SMART-KPE
│   ├── __init__.py
│   ├── train_model.sh
│   ├── eval_model.sh
│   ├── evaluation.py
│   ├── model.py
│   └── parser.py
├── BERT-KPE-based
│   ├── preprocess
│   │   ├── __init__.py
│   │   └── preprocess.sh
│   ├── bertkpe
│   │   ├── transformers
│   │   │   ├── tests
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fixtures
│   │   │   │   │   ├── input.txt
│   │   │   │   │   ├── test_sentencepiece.model
│   │   │   │   │   └── sample_text.txt
│   │   │   │   ├── conftest.py
│   │   │   │   ├── tokenization_auto_test.py
│   │   │   │   ├── tokenization_utils_test.py
│   │   │   │   ├── tokenization_distilbert_test.py
│   │   │   │   ├── modeling_encoder_decoder_test.py
│   │   │   │   ├── configuration_common_test.py
│   │   │   │   ├── tokenization_openai_test.py
│   │   │   │   ├── tokenization_ctrl_test.py
│   │   │   │   ├── tokenization_gpt2_test.py
│   │   │   │   ├── tokenization_transfo_xl_test.py
│   │   │   │   ├── tokenization_xlm_test.py
│   │   │   │   ├── modeling_tf_auto_test.py
│   │   │   │   ├── modeling_auto_test.py
│   │   │   │   ├── tokenization_roberta_test.py
│   │   │   │   ├── tokenization_xlnet_test.py
│   │   │   │   ├── tokenization_bert_test.py
│   │   │   │   └── optimization_test.py
│   │   │   ├── data
│   │   │   │   ├── processors
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── utils.py
│   │   │   │   ├── metrics
│   │   │   │   │   └── __init__.py
│   │   │   │   └── __init__.py
│   │   │   ├── configuration_roberta.py
│   │   │   ├── tokenization_distilbert.py
│   │   │   ├── convert_bert_original_tf_checkpoint_to_pytorch.py
│   │   │   ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py
│   │   │   ├── convert_openai_original_tf_checkpoint_to_pytorch.py
│   │   │   ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py
│   │   │   ├── configuration_distilbert.py
│   │   │   ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py
│   │   │   ├── convert_bert_pytorch_checkpoint_to_original_tf.py
│   │   │   ├── configuration_openai.py
│   │   │   ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
│   │   │   ├── configuration_ctrl.py
│   │   │   └── configuration_gpt2.py
│   │   ├── constant
│   │   │   ├── __init__.py
│   │   │   └── Constant.py
│   │   ├── evaluator
│   │   │   ├── __init__.py
│   │   │   ├── kp20k_evaluator.py
│   │   │   └── openkp_evaluator.py
│   │   ├── networks
│   │   │   ├── __init__.py
│   │   │   ├── Bert2Tag.py
│   │   │   ├── Roberta2Tag.py
│   │   │   └── Bert2Span.py
│   │   ├── dataloader
│   │   │   └── __init__.py
│   │   ├── generator
│   │   │   ├── __init__.py
│   │   │   ├── Rank2Phrase.py
│   │   │   ├── generator_utils.py
│   │   │   ├── Chunk2Phrase.py
│   │   │   ├── Span2Phrase.py
│   │   │   └── Tag2Phrase.py
│   │   └── __init__.py
│   └── scripts
│       ├── test.sh
│       ├── train.sh
│       └── train_dist.sh
├── .gitignore
└── README.md

/SMART-KPE/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/BERT-KPE-based/preprocess/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*__pycache__
*screenlog.0
data
*/results
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/constant/__init__.py:
--------------------------------------------------------------------------------
from .Constant import *
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/fixtures/input.txt:
--------------------------------------------------------------------------------
Who was Jim Henson ? ||| Jim Henson was a puppeteer
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/evaluator/__init__.py:
--------------------------------------------------------------------------------
from .openkp_evaluator import evaluate_openkp
from .kp20k_evaluator import evaluate_kp20k
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/fixtures/test_sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/victorywys/SMART-KPE/HEAD/BERT-KPE-based/bertkpe/transformers/tests/fixtures/test_sentencepiece.model
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/generator/__init__.py:
--------------------------------------------------------------------------------
from . import generator_utils

from .Span2Phrase import span2phrase
from .Tag2Phrase import tag2phrase
from .Chunk2Phrase import chunk2phrase
from .Rank2Phrase import rank2phrase
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/data/processors/__init__.py:
--------------------------------------------------------------------------------
from .utils import InputExample, InputFeatures, DataProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
--------------------------------------------------------------------------------
/BERT-KPE-based/preprocess/preprocess.sh:
--------------------------------------------------------------------------------
export DATA_PATH=../../data

# preprocess openkp or kp20k
python preprocess.py --dataset_class openkp \
    --source_dataset_dir $DATA_PATH/dataset \
    --output_path $DATA_PATH/prepro_dataset
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/data/__init__.py:
--------------------------------------------------------------------------------
from .processors import InputExample, InputFeatures, DataProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features

from .metrics import is_sklearn_available
if is_sklearn_available():
    from .metrics import glue_compute_metrics
--------------------------------------------------------------------------------
/BERT-KPE-based/scripts/test.sh:
--------------------------------------------------------------------------------
export DATA_PATH=../../data

CUDA_VISIBLE_DEVICES=2,3 python test.py --run_mode test \
    --model_class bert2span \
    --pretrain_model_type bert-base-cased \
    --dataset_class openkp \
    --per_gpu_test_batch_size 64 \
    --preprocess_folder $DATA_PATH/prepro_dataset \
    --pretrain_model_path $DATA_PATH/pretrain_model \
    --cached_features_dir $DATA_PATH/cached_features \
    --eval_checkpoint /usr0/home/yansenwa/courses/11747/project/BERT-KPE/checkpoints/bert2span/bert2span.openkp.bert.checkpoint \
    --local_rank -1
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/conftest.py:
--------------------------------------------------------------------------------
# content of conftest.py

import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--runslow", action="store_true", default=False, help="run slow tests"
    )
    parser.addoption(
        "--use_cuda", action="store_true", default=False, help="run tests on gpu"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)


@pytest.fixture
def use_cuda(request):
    """ Run test on gpu """
    return request.config.getoption("--use_cuda")
--------------------------------------------------------------------------------
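The --runslow and --use_cuda options registered in conftest.py above gate the slow tests and the GPU fixture. As a minimal sketch (assuming pytest is invoked from the bertkpe/transformers directory, where the tests folder lives), they can be enabled either from the shell as `pytest tests --runslow --use_cuda` or programmatically:

    import pytest

    # collect the tests/ directory, include @pytest.mark.slow tests, enable the GPU fixture
    pytest.main(["tests", "--runslow", "--use_cuda"])
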
--------------------------------------------------------------------------------
export TRAIN_DATA_DIR=../../data_title/
export CACHE_DATA_DIR=../../data_title/title_snap/
export OUTPUT_DIR=../output_new/title_snap_4/
export PRINT_DIR=../output_new/title_snap_4/
export META_DIR=../../metadata/

CUDA_VISIBLE_DEVICES=2,3,4,5 python3 -u run_model.py \
    --cached_features_dir $CACHE_DATA_DIR/ \
    --data_dir $TRAIN_DATA_DIR/ \
    --output_dir $OUTPUT_DIR/ \
    --print_dir $PRINT_DIR \
    --meta_dir $META_DIR \
    --use_snapshot \
    --train \
    --dev \
    --test \
    --num_trans 4 \
    --learning_rate 1e-5 \
    --num_train_epochs 3 \
    --batch_size 32 \
    --tag_num 5 \
    --main_metric F@3 \
    --gradient_accumulation_steps 4 \
    --max_text_length 512 \
    --logging_steps 200 \
    --save_steps 2000 \
    --evaluate_during_training \
    --save_best \
    --include_title \
    --read_from_cached_features \
    # --from_checkpoint $OUTPUT_DIR/checkpoint-best
--------------------------------------------------------------------------------
/SMART-KPE/eval_model.sh:
--------------------------------------------------------------------------------
#!/bin/sh
#SBATCH -N 1                      # nodes requested
#SBATCH -n 6                      # tasks requested
#SBATCH --gres=gpu:1              # GPU
#SBATCH -e ./tmp/err_evaltitle_4  # send stderr to errfile
#SBATCH -o ./tmp/out_evaltitle_4  # send stdout to outfile
#SBATCH --mem=48000               # memory in Mb
#SBATCH -t 1-00:00                # time
#SBATCH -p gpu

export TRAIN_DATA_DIR=../../data_title/
export CACHE_DATA_DIR=../../data_title/title_snap/
export OUTPUT_DIR=../output_new/title_snap_4/
export PRINT_DIR=../output_new/title_snap_4/
export META_DIR=../../metadata/

CUDA_VISIBLE_DEVICES=0 python3 -u run_model.py \
    --cached_features_dir $CACHE_DATA_DIR/ \
    --data_dir $TRAIN_DATA_DIR/ \
    --output_dir $OUTPUT_DIR/ \
    --print_dir $PRINT_DIR \
    --meta_dir $META_DIR \
    --use_snapshot \
    --dev \
    --test \
    --num_trans 4 \
    --batch_size 8 \
    --tag_num 5 \
    --max_text_length 512 \
    --from_checkpoint $OUTPUT_DIR/checkpoint-best/ \
    --read_from_cached_features \
    --include_title
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/__init__.py:
--------------------------------------------------------------------------------
from . import transformers
from . import dataloader
from .transformers import BertTokenizer, RobertaTokenizer, BertConfig, RobertaConfig

from . import networks
from . import generator
from . import evaluator
from .evaluator import evaluate_openkp, evaluate_kp20k

from .constant import (PAD,
                       UNK,
                       BOS,
                       EOS,
                       DIGIT,
                       PAD_WORD,
                       UNK_WORD,
                       BOS_WORD,
                       EOS_WORD,
                       DIGIT_WORD,
                       Idx2Tag,
                       Tag2Idx,
                       IdxTag_Converter,
                       Decode_Candidate_Number)


tokenizer_class = {"bert-base-cased": BertTokenizer,
                   "spanbert-base-cased": BertTokenizer,
                   "roberta-base": RobertaTokenizer}

config_class = {"bert-base-cased": BertConfig,
                "spanbert-base-cased": BertConfig,
                "roberta-base": RobertaConfig}
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/constant/Constant.py:
--------------------------------------------------------------------------------
PAD = 0
UNK = 100
BOS = 101
EOS = 102
DIGIT = 1

PAD_WORD = '[PAD]'
UNK_WORD = '[UNK]'
BOS_WORD = '[CLS]'
EOS_WORD = '[SEP]'
DIGIT_WORD = 'DIGIT'

# 'O' : non-keyphrase
# 'B' : begin word of the keyphrase
# 'I' : middle word of the keyphrase
# 'E' : end word of the keyphrase
# 'U' : single word keyphrase
Idx2Tag = ['O', 'B', 'I', 'E', 'U']
Tag2Idx = {'O': 0, 'B': 1, 'I': 2, 'E': 3, 'U': 4}

Decode_Candidate_Number = {'openkp': 5, 'kp20k': 10}

class IdxTag_Converter(object):
    ''' idx2tag : a tag list like ['O','B','I','E','U']
        tag2idx : {'O': 0, 'B': 1, ..., 'U': 4}
    '''
    def __init__(self, idx2tag):
        self.idx2tag = idx2tag
        tag2idx = {}
        for idx, tag in enumerate(idx2tag):
            tag2idx[tag] = idx
        self.tag2idx = tag2idx

    def convert_idx2tag(self, index_list):
        tag_list = [self.idx2tag[index] for index in index_list]
        return tag_list

    def convert_tag2idx(self, tag_list):
        index_list = [self.tag2idx[tag] for tag in tag_list]
        return index_list
--------------------------------------------------------------------------------
/SMART-KPE/evaluation.py:
--------------------------------------------------------------------------------
import os, sys

import numpy as np

def evaluate_kps(true_ls, pred_ls):
    """Compute corpus-level precision/recall/F1 at cutoffs 1, 3 and 5.

    true_ls : list of gold keyphrase lists, one per document
    pred_ls : list of ranked predicted keyphrase lists, one per document
    """
    result = {}
    for length in [1, 3, 5]:
        # small epsilon keeps the final harmonic mean from dividing by zero
        result[f"P@{length}"] = 1e-8
        result[f"R@{length}"] = 1e-8
        result[f"F@{length}"] = 1e-8
    for i in range(len(true_ls)):
        cur_true = set(true_ls[i])
        cur_pred = pred_ls[i]
        # match_ls[j] = number of correct phrases among the top j+1 predictions
        match_ls = []
        match_cnt = 0
        if len(cur_pred) <= 0:
            match_ls = [0]
        for kp in cur_pred[:5]:
            if kp in cur_true:
                match_cnt += 1
            match_ls.append(match_cnt)
        for length in [1, 3, 5]:
            if length > len(match_ls):
                # fewer than `length` predictions: score over all that are available
                result[f"P@{length}"] += match_ls[-1] / len(match_ls)
                result[f"R@{length}"] += match_ls[-1] / len(cur_true)
            else:
                result[f"P@{length}"] += match_ls[length-1] / length
                result[f"R@{length}"] += match_ls[length-1] / len(cur_true)
    for length in [1, 3, 5]:
        result[f"P@{length}"] /= len(true_ls)
        result[f"R@{length}"] /= len(true_ls)
        result[f"F@{length}"] = 2.0 / (1.0/result[f"P@{length}"] + 1.0/result[f"R@{length}"])
    return result
--------------------------------------------------------------------------------
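A small worked example of evaluate_kps (values are approximate because of the 1e-8 smoothing; per-document precision and recall are averaged over the corpus before the harmonic mean is taken):

    gold = [["machine learning", "keyphrase extraction"]]
    pred = [["keyphrase extraction", "deep learning", "neural networks"]]
    metrics = evaluate_kps(gold, pred)
    # one correct phrase among the predictions, two gold phrases:
    #   P@1 = 1.0,   R@1 = 0.5, F@1 = 0.667
    #   P@3 = 0.333, R@3 = 0.5, F@3 = 0.4
    #   P@5 falls back to the 3 available predictions, so P@5 = 0.333, F@5 = 0.4
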
/BERT-KPE-based/bertkpe/dataloader/__init__.py:
--------------------------------------------------------------------------------
def get_class(name):
    if name == 'bert2span':
        return batchify_bert2span_features_for_train, batchify_bert2span_features_for_test
    elif name == 'bert2tag':
        return batchify_bert2tag_features_for_train, batchify_bert2tag_features_for_test
    elif name == 'bert2chunk':
        return batchify_bert2chunk_features_for_train, batchify_bert2chunk_features_for_test
    elif name == 'bert2rank':
        return batchify_bert2rank_features_for_train, batchify_bert2rank_features_for_test
    elif name == 'bert2joint':
        return batchify_bert2joint_features_for_train, batchify_bert2joint_features_for_test
    raise RuntimeError('Invalid retriever class: %s' % name)


from .loader_utils import build_dataset
from .bert2span_dataloader import batchify_bert2span_features_for_train, batchify_bert2span_features_for_test
from .bert2tag_dataloader import batchify_bert2tag_features_for_train, batchify_bert2tag_features_for_test
from .bert2chunk_dataloader import batchify_bert2chunk_features_for_train, batchify_bert2chunk_features_for_test
from .bert2rank_dataloader import batchify_bert2rank_features_for_train, batchify_bert2rank_features_for_test
from .bert2joint_dataloader import batchify_bert2joint_features_for_train, batchify_bert2joint_features_for_test
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/generator/Rank2Phrase.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
from .generator_utils import remove_empty, remove_empty_phase, del_stemming_duplicate_phrase

logger = logging.getLogger()


def rank2phrase(examples, logit_lists, indices, stem_flag=False, return_num=None):
    batch_predictions = []
    for batch_id, logit_list in enumerate(logit_lists):
        example = examples[indices[batch_id]]

        params = {'gram_list': example['phrase_list'],
                  'score_logits': logit_list}

        n_best_phrases_scores = decode_n_best_candidates(**params)
        candidate_KP, score_KP = remove_empty_phase(n_best_phrases_scores)

        if return_num:
            if stem_flag:
                candidate_KP, score_KP = del_stemming_duplicate_phrase(candidate_KP, score_KP, return_num)
            else:
                candidate_KP = candidate_KP[:return_num]
                score_KP = score_KP[:return_num]
            assert len(candidate_KP) == return_num

        assert len(candidate_KP) == len(score_KP)
        batch_predictions.append((example['url'], candidate_KP, score_KP))

    return batch_predictions


def decode_n_best_candidates(gram_list, score_logits):

    assert len(gram_list) == len(score_logits)
    ngrams = [(gram.split(), score) for gram, score in zip(gram_list, score_logits)]
    sorted_ngrams = sorted(ngrams, key=lambda x: x[1], reverse=True)

    return sorted_ngrams
--------------------------------------------------------------------------------
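For illustration, decode_n_best_candidates above simply splits each candidate n-gram into tokens and sorts the candidates by score:

    ngrams = decode_n_best_candidates(
        gram_list=["keyphrase extraction", "open domain", "web"],
        score_logits=[0.9, 0.4, 0.7])
    # -> [(['keyphrase', 'extraction'], 0.9), (['web'], 0.7), (['open', 'domain'], 0.4)]
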
/README.md:
--------------------------------------------------------------------------------
# SMART-KPE
Code for the paper "[Incorporating Multimodal Information in Open-Domain Web Keyphrase Extraction](https://www.aclweb.org/anthology/2020.emnlp-main.140/)".
You can download the data [here](https://victorywys-datasets.s3.us-east-2.amazonaws.com/OpenKP_title_and_snapshot.zip).

Update:

Since we released the code, we have been working on documentation and comments. To make it easy to replicate the results and compare to previous work, we are generating checkpoints for all 15 variants described in [BERT-KPE](https://github.com/thunlp/BERT-KPE), built on top of their version of the code (included in the BERT-KPE-based folder), and we will release them as soon as possible.

**We provide final checkpoints from BERT-KPE-based [here](https://victorywys-datasets.s3.us-east-2.amazonaws.com/final_checkpoints.zip). Currently we have only uploaded the best model (Roberta2Joint-based SMART-KPE, F@3: 0.405); we will add more variants soon.** You can use `test.sh` in the scripts folder to check the results.

To run the code, make sure you are using PyTorch 1.4.0; otherwise the data parallel part/transformer may not work properly.

If you would like to replicate the best result before we release the other checkpoints, you can try the following steps:
1. Download the image data and title data.
2. Add all the title data to the dataset files. In the original jsonl files, each line corresponds to one example and contains three fields: `url`, `text` and `VDOM`. To use the title data, add a new field named `title` whose content is the title string. (We will add a script to help you process it in the near future; a sketch is given below.)
3. Follow the instructions of BERT-KPE to preprocess the data and run the experiments with the scripts provided in this repo.

Thanks again for your interest in our work!
--------------------------------------------------------------------------------
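A minimal sketch of the merging script promised in step 2 of the README, assuming the title data can be loaded as a {url: title} mapping (the exact layout of the released title archive may differ):

    import json

    def add_titles(dataset_path, title_map, output_path):
        """Add a `title` field to every record of an OpenKP-style jsonl file."""
        with open(dataset_path, encoding="utf-8") as fin, \
             open(output_path, "w", encoding="utf-8") as fout:
            for line in fin:
                record = json.loads(line)  # existing fields: url, text, VDOM
                record["title"] = title_map.get(record["url"], "")
                fout.write(json.dumps(record) + "\n")
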
/BERT-KPE-based/bertkpe/transformers/configuration_roberta.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RoBERTa configuration """

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging

from .configuration_bert import BertConfig

logger = logging.getLogger(__name__)

ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
    'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
    'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
    'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
    'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
    'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
}


class RobertaConfig(BertConfig):
    pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
--------------------------------------------------------------------------------
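RobertaConfig only overrides the pretrained archive map; loading goes through the machinery inherited from BertConfig. As a sketch, using the package-level re-export shown in bertkpe/__init__.py:

    from bertkpe.transformers import RobertaConfig

    # resolves "roberta-base" through ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP above
    config = RobertaConfig.from_pretrained("roberta-base")
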
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer 25 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 26 | 27 | 28 | class AutoTokenizerTest(unittest.TestCase): 29 | @pytest.mark.slow 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /BERT-KPE-based/scripts/train.sh: -------------------------------------------------------------------------------- 1 | export DATA_PATH=../../data 2 | export SNAPSHOT_PATH=/usr0/home/yansenwa/courses/11747/project/metadata/snapshot 3 | 4 | export dataset_class=openkp # openkp , kp20k 5 | export max_train_steps=20810 # 20810 (openkp) , 73430 (kp20k) 6 | 7 | export model_class=bert2joint # bert2span, bert2tag, bert2chunk, bert2rank, bert2joint 8 | export pretrain_model=bert-base-cased # bert-base-cased , spanbert-base-cased , roberta-base 9 | 10 | ## -------------------------------------------------------------------------------- 11 | ## DataParallel (Multi-GPUs) 12 | 13 | CUDA_VISIBLE_DEVICES=0 python3 train.py --run_mode train \ 14 | --local_rank -1 \ 15 | --max_train_steps $max_train_steps \ 16 | --model_class $model_class \ 17 | --dataset_class $dataset_class \ 18 | --pretrain_model_type $pretrain_model \ 19 | --per_gpu_train_batch_size 4 \ 20 | --gradient_accumulation_steps 8 \ 21 | --per_gpu_test_batch_size 16 \ 22 | --preprocess_folder $DATA_PATH/prepro_dataset \ 23 | --pretrain_model_path $DATA_PATH/pretrain_model \ 24 | --cached_features_dir $DATA_PATH/cached_features \ 25 | --snapshot_path $SNAPSHOT_PATH \ 26 | --display_iter 10000 \ 27 | --save_checkpoint \ 28 | --use_viso \ 29 | 30 | 31 | # ## -------------------------------------------------------------------------------- 32 | # ## Distributed-DataParallel (Multi-GPUs) 33 | #CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 train.py --run_mode train \ 34 | # --max_train_steps $max_train_steps \ 35 | # --model_class $model_class \ 36 | # --dataset_class $dataset_class \ 37 | # --pretrain_model_type $pretrain_model \ 38 | # --per_gpu_train_batch_size 4 \ 39 | # --gradient_accumulation_steps 8 \ 40 | # --per_gpu_test_batch_size 16 \ 41 | # --preprocess_folder $DATA_PATH/prepro_dataset \ 42 | # --pretrain_model_path $DATA_PATH/pretrain_model \ 43 | # --cached_features_dir $DATA_PATH/cached_features \ 44 | # --display_iter 1000 \ 45 | # --save_checkpoint \ 46 | # --use_viso \ 47 | -------------------------------------------------------------------------------- /BERT-KPE-based/scripts/train_dist.sh:
-------------------------------------------------------------------------------- 1 | export DATA_PATH=../../data 2 | export SNAPSHOT_PATH=/usr0/home/yansenwa/courses/11747/project/metadata/snapshot 3 | 4 | export dataset_class=openkp # openkp , kp20k 5 | export max_train_steps=20810 # 20810 (openkp) , 73430 (kp20k) 6 | 7 | export model_class=bert2joint # bert2span, bert2tag, bert2chunk, bert2rank, bert2joint 8 | export pretrain_model=roberta-base # bert-base-cased , spanbert-base-cased , roberta-base 9 | 10 | ## -------------------------------------------------------------------------------- 11 | ## DataParallel (Multi-GPUs) 12 | 13 | #CUDA_VISIBLE_DEVICES=1 python3 train.py --run_mode train \ 14 | # --local_rank -1 \ 15 | # --max_train_steps $max_train_steps \ 16 | # --model_class $model_class \ 17 | # --dataset_class $dataset_class \ 18 | # --pretrain_model_type $pretrain_model \ 19 | # --per_gpu_train_batch_size 4 \ 20 | # --gradient_accumulation_steps 16 \ 21 | # --per_gpu_test_batch_size 16 \ 22 | # --preprocess_folder $DATA_PATH/prepro_dataset \ 23 | # --pretrain_model_path $DATA_PATH/pretrain_model \ 24 | # --cached_features_dir $DATA_PATH/cached_features \ 25 | # --snapshot_path $SNAPSHOT_PATH \ 26 | # --display_iter 500 \ 27 | # --save_checkpoint \ 28 | # --use_viso \ 29 | 30 | 31 | # ## -------------------------------------------------------------------------------- 32 | # ## Distributed-DataParallel (Multi-GPUs) 33 | CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 train.py --run_mode train \ 34 | --max_train_steps $max_train_steps \ 35 | --model_class $model_class \ 36 | --dataset_class $dataset_class \ 37 | --pretrain_model_type $pretrain_model \ 38 | --per_gpu_train_batch_size 4 \ 39 | --gradient_accumulation_steps 8 \ 40 | --per_gpu_test_batch_size 16 \ 41 | --preprocess_folder $DATA_PATH/prepro_dataset \ 42 | --pretrain_model_path $DATA_PATH/pretrain_model \ 43 | --cached_features_dir $DATA_PATH/cached_features \ 44 | --display_iter 1000 \ 45 | --save_checkpoint \ 46 | --use_viso \ 47 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | import pytest 22 | 23 | from transformers import PreTrainedTokenizer 24 | from transformers.tokenization_gpt2 import GPT2Tokenizer 25 | 26 | class TokenizerUtilsTest(unittest.TestCase): 27 | @pytest.mark.slow 28 | def check_tokenizer_from_pretrained(self, tokenizer_class): 29 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 30 | for model_name in s3_models[:1]: 31 | tokenizer = tokenizer_class.from_pretrained(model_name) 32 | self.assertIsNotNone(tokenizer) 33 | self.assertIsInstance(tokenizer, tokenizer_class) 34 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 35 | 36 | for special_tok in tokenizer.all_special_tokens: 37 | if six.PY2: 38 | self.assertIsInstance(special_tok, unicode) 39 | else: 40 | self.assertIsInstance(special_tok, str) 41 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 42 | self.assertIsInstance(special_tok_id, int) 43 | 44 | def test_pretrained_tokenizers(self): 45 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_distilbert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers.tokenization_distilbert import (DistilBertTokenizer) 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | from .tokenization_bert_test import BertTokenizationTest 26 | 27 | class DistilBertTokenizationTest(BertTokenizationTest): 28 | 29 | tokenizer_class = DistilBertTokenizer 30 | 31 | def get_tokenizer(self, **kwargs): 32 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 33 | 34 | @pytest.mark.slow 35 | def test_sequence_builders(self): 36 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 37 | 38 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 39 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 40 | 41 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 42 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 43 | 44 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 45 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ 46 | text_2 + [tokenizer.sep_token_id] 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/generator/generator_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import unicodedata 4 | 5 | from nltk.stem.porter import PorterStemmer 6 | stemmer = PorterStemmer() 7 | 8 | 9 | def normalize_answer(s): 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | def white_space_fix(text): 13 | return ' '.join(text.split()) 14 | def remove_punc(text): 15 | exclude = set(string.punctuation) 16 | return ''.join(ch for ch in text if ch not in exclude) 17 | def lower(text): 18 | return text.lower() 19 | return ' '.join([lower(x) for x in s]).rstrip() 20 | 21 | 22 | def remove_empty(a_list): 23 | new_list = [] 24 | for i in a_list: 25 | if len(i) > 0: 26 | if len(i[0]) >0: 27 | new_list.append(normalize_answer(i)) 28 | return new_list 29 | 30 | 31 | def remove_empty_phase(phrases_scores): 32 | phrase_list = [] 33 | score_list = [] 34 | for phrase, score, in phrases_scores: 35 | if len(phrase) > 0: 36 | if len(phrase[0]) > 0: 37 | phrase_list.append(normalize_answer(phrase)) 38 | score_list.append(score) 39 | return phrase_list, score_list 40 | 41 | 42 | 43 | def stem_norm_phrase(phrase): 44 | norm_chars = unicodedata.normalize('NFD', phrase) 45 | stem_chars = " ".join([stemmer.stem(w) for w in norm_chars.split(" ")]) 46 | return norm_chars, stem_chars 47 | 48 | 49 | def del_stemming_duplicate_phrase(phrase_list, score_list, return_num): 50 | tot_phrases_set = set() 51 | return_phrases, return_scores = [], [] 52 | 53 | for phrase, score in zip(phrase_list, score_list): 54 | norm_phrase, stem_phrase = stem_norm_phrase(phrase) 55 | 56 | if (norm_phrase not in tot_phrases_set) and (stem_phrase not in tot_phrases_set): 57 | return_phrases.append(phrase) 58 | return_scores.append(score) 59 | 60 | tot_phrases_set.add(norm_phrase) 61 | tot_phrases_set.add(stem_phrase) 62 | 63 | if len(return_phrases) >= return_num:break 64 | return return_phrases, return_scores -------------------------------------------------------------------------------- 
/BERT-KPE-based/bertkpe/transformers/tests/modeling_encoder_decoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Hugging Face Inc. Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | import unittest 18 | import pytest 19 | 20 | from transformers import is_torch_available 21 | 22 | if is_torch_available(): 23 | from transformers import BertModel, BertForMaskedLM, Model2Model 24 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 25 | else: 26 | pytestmark = pytest.mark.skip("Require Torch") 27 | 28 | 29 | class EncoderDecoderModelTest(unittest.TestCase): 30 | @pytest.mark.slow 31 | def test_model2model_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 34 | model = Model2Model.from_pretrained(model_name) 35 | self.assertIsInstance(model.encoder, BertModel) 36 | self.assertIsInstance(model.decoder, BertForMaskedLM) 37 | self.assertEqual(model.decoder.config.is_decoder, True) 38 | self.assertEqual(model.encoder.config.is_decoder, False) 39 | 40 | def test_model2model_from_pretrained_not_bert(self): 41 | logging.basicConfig(level=logging.INFO) 42 | with self.assertRaises(ValueError): 43 | _ = Model2Model.from_pretrained('roberta') 44 | 45 | with self.assertRaises(ValueError): 46 | _ = Model2Model.from_pretrained('distilbert') 47 | 48 | with self.assertRaises(ValueError): 49 | _ = Model2Model.from_pretrained('does-not-exist') 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/networks/__init__.py: -------------------------------------------------------------------------------- 1 | def get_class(args): 2 | 3 | # Bert2Span 4 | if args.model_class == 'bert2span' and args.pretrain_model_type in ['bert-base-cased', 'spanbert-base-cased']: 5 | return BertForAttSpanExtractor 6 | elif args.model_class == 'bert2span' and args.pretrain_model_type == 'roberta-base': 7 | return RobertaForAttSpanExtractor 8 | 9 | # Bert2Tag 10 | elif args.model_class == 'bert2tag' and args.pretrain_model_type in ['bert-base-cased', 'spanbert-base-cased']: 11 | return BertForSeqTagging 12 | elif args.model_class == 'bert2tag' and args.pretrain_model_type == 'roberta-base': 13 | return RobertaForSeqTagging 14 | 15 | # Bert2Chunk 16 | elif args.model_class == 'bert2chunk' and args.pretrain_model_type in ['bert-base-cased', 'spanbert-base-cased']: 17 | return BertForCnnGramExtractor 18 | elif args.model_class == 'bert2chunk' and args.pretrain_model_type == 'roberta-base': 19 | return RobertaForCnnGramExtractor 20 | 21 | # Bert2Rank 22 | elif args.model_class == 'bert2rank' and args.pretrain_model_type in ['bert-base-cased', 'spanbert-base-cased']: 23 | return BertForTFRanking 24 | elif args.model_class == 'bert2rank' and 
args.pretrain_model_type == 'roberta-base': 25 | return RobertaForTFRanking 26 | 27 | # Bert2Joint 28 | elif args.model_class == 'bert2joint' and args.pretrain_model_type in ['bert-base-cased', 'spanbert-base-cased']: 29 | return BertForChunkTFRanking 30 | elif args.model_class == 'bert2joint' and args.pretrain_model_type == 'roberta-base': 31 | return RobertaForChunkTFRanking 32 | 33 | raise RuntimeError('Invalid model class: %s' % args.model_class) 34 | 35 | 36 | # bert2span 37 | from .Bert2Span import BertForAttSpanExtractor 38 | from .Roberta2Span import RobertaForAttSpanExtractor 39 | 40 | # bert2tag 41 | from .Bert2Tag import BertForSeqTagging 42 | from .Roberta2Tag import RobertaForSeqTagging 43 | 44 | # bert2chunk 45 | from .Bert2Chunk import BertForCnnGramExtractor 46 | from .Roberta2Chunk import RobertaForCnnGramExtractor 47 | 48 | # bert2rank 49 | from .Bert2Rank import BertForTFRanking 50 | from .Roberta2Rank import RobertaForTFRanking 51 | 52 | # bert2joint 53 | from .Bert2Joint import BertForChunkTFRanking 54 | from .Roberta2Joint import RobertaForChunkTFRanking -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/generator/Chunk2Phrase.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from ..constant import Tag2Idx 4 | from .generator_utils import remove_empty_phase, del_stemming_duplicate_phrase 5 | logger = logging.getLogger() 6 | 7 | def chunk2phrase(examples, logit_lists, indices, max_phrase_words, return_num, stem_flag=False): 8 | batch_predictions = [] 9 | for batch_id, logit_list in enumerate(logit_lists): 10 | example = examples[indices[batch_id]] 11 | 12 | params = {'orig_tokens': example['doc_words'], 13 | 'gram_logits': logit_list, 14 | 'max_gram': max_phrase_words} 15 | 16 | n_best_phrases_scores = decode_n_best_candidates(**params) 17 | candidate_KP, score_KP = remove_empty_phase(n_best_phrases_scores) 18 | 19 | if return_num: 20 | if stem_flag: 21 | candidate_KP, score_KP = del_stemming_duplicate_phrase(candidate_KP, score_KP, return_num) 22 | else: 23 | candidate_KP = candidate_KP[:return_num] 24 | score_KP = score_KP[:return_num] 25 | assert len(candidate_KP) == return_num 26 | 27 | assert len(candidate_KP) == len(score_KP) 28 | batch_predictions.append((example['url'], candidate_KP, score_KP)) 29 | return batch_predictions 30 | 31 | 32 | 33 | def decode_n_best_candidates(orig_tokens, gram_logits, max_gram): 34 | ''' 35 | max_gram : type :int , max_phrase_words 36 | return : phrase token list & score list 37 | ''' 38 | orig_tokens = [token.lower() for token in orig_tokens] 39 | sorted_ngrams = decode_ngram(orig_tokens=orig_tokens, 40 | gram_logits=gram_logits, 41 | max_gram=max_gram) 42 | return sorted_ngrams 43 | 44 | 45 | 46 | 47 | def 
decode_ngram(orig_tokens, gram_logits, max_gram): 48 | 49 | ngram_score = [] 50 | for n in range(max_gram): 51 | for i in range(len(orig_tokens) - n): 52 | ngram_score.append((" ".join(orig_tokens[i:i+n+1]), gram_logits[n][i])) 53 | 54 | phrase_set = {} 55 | for n_gram, n_gram_score in ngram_score: 56 | if n_gram not in phrase_set or n_gram_score > phrase_set[n_gram]: 57 | phrase_set[n_gram] = n_gram_score 58 | else: 59 | continue 60 | 61 | phrase_list = [] 62 | for phrase, score in phrase_set.items(): 63 | phrase_list.append((phrase.split(), score)) 64 | 65 | sorted_phrase_list = sorted(phrase_list, key=lambda x: x[1], reverse=True) 66 | return sorted_phrase_list 67 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | 'distilbert-base-uncased': 512, 41 | 'distilbert-base-uncased-distilled-squad': 512, 42 | } 43 | 44 | 45 | class DistilBertTokenizer(BertTokenizer): 46 | r""" 47 | Constructs a DistilBertTokenizer. 48 | :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 49 | 50 | Args: 51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 54 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 55 | minimum of this value (if specified) and the underlying BERT model's sequence length. 56 | never_split: List of tokens which will never be split during tokenization. 
Only has an effect when 57 | do_wordpiece_only=False 58 | """ 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w</w>", "r</w>", "t</w>", 36 | "lo", "low", "er</w>", 37 | "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er</w>"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + ["<unk>"] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/generator/Span2Phrase.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from .generator_utils import remove_empty_phase, del_stemming_duplicate_phrase 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def span2phrase(examples, start_lists, end_lists, indices, max_phrase_words, return_num=None, stem_flag=False): 10 | 11 | batch_predictions = [] 12 | for batch_id, (start_logit, end_logit) in enumerate(zip(start_lists, end_lists)): 13 | example = examples[indices[batch_id]] 14 | 15 | assert len(start_logit) == len(end_logit) == len(end_logit[0]) == len(example['doc_words']) # word_len 16 | 17 | params = {'orig_tokens': example['doc_words'], 18 | 'start_logit': start_logit, 19 | 'end_logit': end_logit, 20 | 'max_gram':max_phrase_words} 21 | 22 | n_best_phrases_scores = decode_n_best_candidates(**params) 23 | candidate_KP, score_KP = remove_empty_phase(n_best_phrases_scores) 24 | 25 | if
return_num: 26 | if stem_flag: 27 | candidate_KP, score_KP = del_stemming_duplicate_phrase(candidate_KP, score_KP, return_num) 28 | else: 29 | candidate_KP = candidate_KP[:return_num] 30 | score_KP = score_KP[:return_num] 31 | assert len(candidate_KP) == return_num 32 | 33 | assert len(candidate_KP) == len(score_KP) 34 | batch_predictions.append((example['url'], candidate_KP, score_KP)) 35 | return batch_predictions 36 | 37 | 38 | 39 | def decode_n_best_candidates(orig_tokens, start_logit, end_logit, max_gram): 40 | ''' 41 | max_gram : type :int , max_phrase_words 42 | return : phrase token list & score list 43 | ''' 44 | assert len(orig_tokens) == len(start_logit) == len(end_logit) 45 | orig_tokens = [token.lower() for token in orig_tokens] 46 | 47 | sorted_ngrams = decode_span2phrase(**{"orig_tokens":orig_tokens, 48 | "start_logit":start_logit, 49 | "end_logit":end_logit, 50 | "max_gram":max_gram}) 51 | return sorted_ngrams 52 | 53 | 54 | def decode_span2phrase(orig_tokens, start_logit, end_logit, max_gram): 55 | phrase2score = {} 56 | for (i, s) in enumerate(start_logit): 57 | for (j, e) in enumerate(end_logit[i][i:(i+max_gram)]): 58 | phrase = " ".join(orig_tokens[i:(i+j+1)]) 59 | score = s * e 60 | if (phrase not in phrase2score) or (score > phrase2score[phrase]): 61 | phrase2score[phrase] = score 62 | else: 63 | continue 64 | 65 | phrase_list = [] 66 | for phrase, score in phrase2score.items(): 67 | phrase_list.append((phrase.split(), score)) 68 | 69 | sorted_phrase_list = sorted(phrase_list, key=lambda x: x[1], reverse=True) 70 | return sorted_phrase_list -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import, division, print_function, unicode_literals 15 | 16 | import os 17 | import unittest 18 | import json 19 | from io import open 20 | 21 | from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = CTRLTokenizer 28 | 29 | def setUp(self): 30 | super(CTRLTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>'] 34 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 35 | merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', ''] 36 | self.special_tokens_map = {"unk_token": "<unk>"} 37 | 38 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 39 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 40 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 41 | fp.write(json.dumps(vocab_tokens) + "\n") 42 | with open(self.merges_file, "w", encoding="utf-8") as fp: 43 | fp.write("\n".join(merges)) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | kwargs.update(self.special_tokens_map) 47 | return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 48 | 49 | def get_input_output_texts(self): 50 | input_text = u"adapt react readapt apt" 51 | output_text = u"adapt react readapt apt" 52 | return input_text, output_text 53 | 54 | def test_full_tokenizer(self): 55 | tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 56 | text = "adapt react readapt apt" 57 | bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split() 58 | tokens = tokenizer.tokenize(text) 59 | self.assertListEqual(tokens, bpe_tokens) 60 | 61 | input_tokens = tokens + [tokenizer.unk_token] 62 | 63 | input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6] 64 | self.assertListEqual( 65 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | try: 24 | from scipy.stats import pearsonr, spearmanr 25 | from sklearn.metrics import matthews_corrcoef, f1_score 26 | _has_sklearn = True 27 | except (AttributeError, ImportError) as e: 28 | logger.warning("To use data.metrics please install scikit-learn.
See https://scikit-learn.org/stable/index.html") 29 | _has_sklearn = False 30 | 31 | def is_sklearn_available(): 32 | return _has_sklearn 33 | 34 | if _has_sklearn: 35 | 36 | def simple_accuracy(preds, labels): 37 | return (preds == labels).mean() 38 | 39 | 40 | def acc_and_f1(preds, labels): 41 | acc = simple_accuracy(preds, labels) 42 | f1 = f1_score(y_true=labels, y_pred=preds) 43 | return { 44 | "acc": acc, 45 | "f1": f1, 46 | "acc_and_f1": (acc + f1) / 2, 47 | } 48 | 49 | 50 | def pearson_and_spearman(preds, labels): 51 | pearson_corr = pearsonr(preds, labels)[0] 52 | spearman_corr = spearmanr(preds, labels)[0] 53 | return { 54 | "pearson": pearson_corr, 55 | "spearmanr": spearman_corr, 56 | "corr": (pearson_corr + spearman_corr) / 2, 57 | } 58 | 59 | 60 | def glue_compute_metrics(task_name, preds, labels): 61 | assert len(preds) == len(labels) 62 | if task_name == "cola": 63 | return {"mcc": matthews_corrcoef(labels, preds)} 64 | elif task_name == "sst-2": 65 | return {"acc": simple_accuracy(preds, labels)} 66 | elif task_name == "mrpc": 67 | return acc_and_f1(preds, labels) 68 | elif task_name == "sts-b": 69 | return pearson_and_spearman(preds, labels) 70 | elif task_name == "qqp": 71 | return acc_and_f1(preds, labels) 72 | elif task_name == "mnli": 73 | return {"acc": simple_accuracy(preds, labels)} 74 | elif task_name == "mnli-mm": 75 | return {"acc": simple_accuracy(preds, labels)} 76 | elif task_name == "qnli": 77 | return {"acc": simple_accuracy(preds, labels)} 78 | elif task_name == "rte": 79 | return {"acc": simple_accuracy(preds, labels)} 80 | elif task_name == "wnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | else: 83 | raise KeyError(task_name) 84 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | from io import open 21 | 22 | from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = GPT2Tokenizer 29 | 30 | def setUp(self): 31 | super(GPT2TokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": "<unk>"} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers import is_torch_available 23 | 24 | if is_torch_available(): 25 | import torch 26 | from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES 27 | else: 28 | pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save 29 | 30 | from .tokenization_tests_commons import CommonTestCases 31 | 32 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): 33 | 34 | tokenizer_class = TransfoXLTokenizer if is_torch_available() else None 35 | 36 | def setUp(self): 37 | super(TransfoXLTokenizationTest, self).setUp() 38 | 39 | vocab_tokens = [ 40 | "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", 41 | "running", ",", "low", "l", 42 | ] 43 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 44 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 45 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 46 | 47 | def get_tokenizer(self, **kwargs): 48 | kwargs['lower_case'] = True 49 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"<unk> UNwanted , running" 53 | output_text = u"<unk> unwanted, running" 54 | return input_text, output_text 55 | 56 | def test_full_tokenizer(self): 57 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) 58 | 59 | tokens = tokenizer.tokenize(u"<unk> UNwanted , running") 60 | self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"]) 61 | 62 | self.assertListEqual( 63 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 64 | 65 | def test_full_tokenizer_lower(self): 66 | tokenizer = TransfoXLTokenizer(lower_case=True) 67 | 68 | self.assertListEqual( 69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 70 | ["hello", "!", "how", "are", "you", "?"]) 71 | 72 | def test_full_tokenizer_no_lower(self): 73 | tokenizer = TransfoXLTokenizer(lower_case=False) 74 | 75 | self.assertListEqual( 76 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? 
"), 77 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 78 | 79 | 80 | if __name__ == '__main__': 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/networks/Bert2Tag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss, NLLLoss 7 | from ..transformers import BertForTokenClassification 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | 13 | class BertForSeqTagging(BertForTokenClassification): 14 | def forward(self, visual_input, meta_input, input_ids, attention_mask, valid_ids, active_mask, valid_output, labels=None): 15 | 16 | # -------------------------------------------------------------------------------- 17 | # Bert Embedding Outputs 18 | outputs = self.bert(input_ids=input_ids, 19 | attention_mask=attention_mask) 20 | 21 | attention_mask = attention_mask.to(torch.bool) 22 | 23 | sequence_output = outputs[0] # batch * len * bert_size 24 | 25 | # -------------------------------------------------------------------------------- 26 | # Meta-feature Predictor 27 | bert_cls = sequence_output[:, 0, :].squeeze(1) # batch * bert_size 28 | meta_cat = torch.cat([bert_cls, meta_input], -1) 29 | pred_mask_before_softmax = self.meta_selector(meta_cat) 30 | pred_mask = F.softmax(pred_mask_before_softmax, -1).unsqueeze(-1).unsqueeze(-1) 31 | 32 | # -------------------------------------------------------------------------------- 33 | # Visual Embedding Outputs 34 | visual_t = visual_input.transpose(0, 1) 35 | visual_embedding = self.visual_trans(visual_t, 36 | src_key_padding_mask=(~attention_mask)).transpose(0, 1) 37 | 38 | embedding = torch.cat([sequence_output, visual_embedding], -1).transpose(0, 1) # len * batch * embed_size 39 | 40 | phrase_embedding = self.phrase_trans(embedding, 41 | src_key_padding_mask=(~attention_mask)).transpose(0, 1) 42 | 43 | # -------------------------------------------------------------------------------- 44 | # Valid Outputs : get first token vector 45 | batch_size = phrase_embedding.size(0) 46 | for i in range(batch_size): 47 | valid_num = sum(valid_ids[i]).item() 48 | 49 | vectors = phrase_embedding[i][valid_ids[i] == 1] 50 | valid_output[i, :valid_num].copy_(vectors) 51 | 52 | # -------------------------------------------------------------------------------- 53 | # Dropout 54 | phrase_embedding = self.dropout(valid_output) 55 | logits = torch.cat( 56 | [self.tag_prediction[i](phrase_embedding).unsqueeze(1) 57 | for i in range(4)], 58 | 1 59 | ) 60 | pred = F.softmax(logits, -1) 61 | pred = torch.log(torch.sum(pred * pred_mask, 1)) 62 | 63 | # -------------------------------------------------------------------------------- 64 | # Active Logits 65 | active_loss = active_mask.view(-1) == 1 # [False, True, ...] 66 | active_logits = pred.view(-1, self.num_labels)[active_loss] # False 67 | 68 | if labels is not None: 69 | loss_fct = NLLLoss() 70 | active_labels = labels.view(-1)[active_loss] 71 | loss = loss_fct(active_logits, active_labels) 72 | return loss 73 | else: 74 | return active_logits -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/networks/Roberta2Tag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss, NLLLoss 7 | from ..transformers import RobertaForTokenClassification 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | class RobertaForSeqTagging(RobertaForTokenClassification): 13 | 14 | def forward(self, visual_input, meta_input, input_ids, attention_mask, valid_ids, active_mask, valid_output, labels=None): 15 | 16 | # -------------------------------------------------------------------------------- 17 | # Bert Embedding Outputs 18 | 19 | outputs = self.roberta(input_ids=input_ids, 20 | attention_mask=attention_mask) 21 | 22 | attention_mask = attention_mask.to(torch.bool) 23 | 24 | sequence_output = outputs[0] # batch * len * bert_size 25 | 26 | # -------------------------------------------------------------------------------- 27 | # Meta-feature Predictor 28 | bert_cls = sequence_output[:, 0, :].squeeze(1) # batch * bert_size 29 | meta_cat = torch.cat([bert_cls, meta_input], -1) 30 | pred_mask_before_softmax = self.meta_selector(meta_cat) 31 | pred_mask = F.softmax(pred_mask_before_softmax, -1).unsqueeze(-1).unsqueeze(-1) 32 | 33 | # -------------------------------------------------------------------------------- 34 | # Visual Embedding Outputs 35 | visual_t = visual_input.transpose(0, 1) 36 | visual_embedding = self.visual_trans(visual_t, 37 | src_key_padding_mask=(~attention_mask)).transpose(0, 1) 38 | 39 | embedding = torch.cat([sequence_output, visual_embedding], -1).transpose(0, 1) # len * batch * embed_size 40 | 41 | phrase_embedding = self.phrase_trans(embedding, 42 | src_key_padding_mask=(~attention_mask)).transpose(0, 1) 43 | 44 | # -------------------------------------------------------------------------------- 45 | # Valid Outputs : get first token vector 46 | batch_size = phrase_embedding.size(0) 47 | for i in range(batch_size): 48 | valid_num = sum(valid_ids[i]).item() 49 | 50 | vectors = phrase_embedding[i][valid_ids[i] == 1] 51 | valid_output[i, :valid_num].copy_(vectors) 52 | 53 | # -------------------------------------------------------------------------------- 54 | # Dropout 55 | phrase_embedding = self.dropout(valid_output) 56 | logits = torch.cat( 57 | [self.tag_prediction[i](phrase_embedding).unsqueeze(1) 58 | for i in range(4)], 59 | 1 60 | ) 61 | pred = F.softmax(logits, -1) 62 | pred = torch.log(torch.sum(pred * pred_mask, 1)) 63 | 64 | 65 | # -------------------------------------------------------------------------------- 66 | # Active Logits 67 | active_loss = active_mask.view(-1) == 1 # [False, True, ...] 
68 | active_logits = pred.view(-1, self.num_labels)[active_loss] # False 69 | 70 | if labels is not None: 71 | loss_fct = NLLLoss() 72 | active_labels = labels.view(-1)[active_loss] 73 | loss = loss_fct(active_logits, active_labels) 74 | return loss 75 | else: 76 | return active_logits -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n"
71 |                         "This specifies the model architecture.")
72 |     args = parser.parse_args()
73 |     convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
74 |                                          args.openai_config_file,
75 |                                          args.pytorch_dump_folder_path)
76 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/generator/Tag2Phrase.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from ..constant import Tag2Idx
4 | from .generator_utils import remove_empty_phase, del_stemming_duplicate_phrase
5 | 
6 | logger = logging.getLogger()
7 | 
8 | def tag2phrase(examples, logit_lists, indices, max_phrase_words, pooling, return_num=None, stem_flag=False):
9 |     batch_predictions = []
10 |     for batch_id, logit_list in enumerate(logit_lists):
11 |         example = examples[indices[batch_id]]
12 | 
13 |         params = {'orig_tokens': example['doc_words'],
14 |                   'token_logits': logit_list,
15 |                   'max_gram': max_phrase_words,
16 |                   'pooling': pooling}
17 | 
18 |         n_best_phrases_scores = decode_n_best_candidates(**params)
19 |         candidate_KP, score_KP = remove_empty_phase(n_best_phrases_scores)
20 | 
21 |         if return_num:
22 |             if stem_flag:
23 |                 candidate_KP, score_KP = del_stemming_duplicate_phrase(candidate_KP, score_KP, return_num)
24 |             else:
25 |                 candidate_KP = candidate_KP[:return_num]
26 |                 score_KP = score_KP[:return_num]
27 |             assert len(candidate_KP) == return_num
28 | 
29 |         assert len(candidate_KP) == len(score_KP)
30 |         batch_predictions.append((example['url'], candidate_KP, score_KP))
31 | 
32 |     return batch_predictions
33 | 
34 | 
35 | def decode_n_best_candidates(orig_tokens, token_logits, max_gram, pooling):
36 |     '''
37 |     max_gram : int , the maximum phrase length (max_phrase_words)
38 |     return : phrase token lists & score list, sorted by score
39 |     '''
40 |     assert len(orig_tokens) == len(token_logits)
41 |     orig_tokens = [token.lower() for token in orig_tokens]
42 | 
43 |     ngrams = []
44 |     for n in range(1, max_gram+1):
45 |         ngrams.extend(decode_ngram(orig_tokens, token_logits, n, pooling))
46 |     # sort all n-grams by score
47 |     sorted_ngrams = sorted(ngrams, key=lambda x: x[1], reverse=True)
48 |     return sorted_ngrams
49 | 
50 | 
51 | def decode_ngram(orig_tokens, token_logits, n, pooling=None):
52 |     '''
53 |     Combine per-token tag scores into n-gram scores and sort them.
54 |     Inputs :
55 |         n : the n of the n-gram
56 |         orig_tokens : list of lower-cased document words
57 |         token_logits : five scores per token, one for each tag 'O', 'B', 'I', 'E', 'U'
58 |         pooling : pooling method over the n token scores : mean / min / log_mean (min works best);
59 |                   e.g. a trigram starting at i scores min(logits[i]['B'], logits[i+1]['I'], logits[i+2]['E'])
60 |     Outputs : sorted (phrase, score) list
61 |     '''
62 |     if n == 1:
63 |         ngram_ids = [Tag2Idx['U']]
64 |     elif n >= 2:
65 |         ngram_ids = [Tag2Idx['B']] + [Tag2Idx['I'] for _ in range(n-2)] + [Tag2Idx['E']]
66 |     else:
67 |         logger.info('invalid %d-gram !' 
% n)
68 |         return []
69 |     offsets = [i for i in range(len(ngram_ids))]
70 | 
71 |     # combine n-gram scores
72 |     phrase_set = {}
73 |     valid_length = (len(orig_tokens) - n + 1)
74 |     for i in range(valid_length):
75 | 
76 |         n_gram = ' '.join(orig_tokens[i:i+n])
77 |         # NOTE: min pooling is hard-coded here; the pooling argument is kept for interface compatibility
78 |         n_gram_score = min([token_logits[i+bias][tag] for bias, tag in zip(offsets, ngram_ids)])
79 | 
80 |         # keep the best-scoring mention of each distinct n-gram
81 |         if n_gram not in phrase_set or n_gram_score > phrase_set[n_gram]:
82 |             phrase_set[n_gram] = n_gram_score
83 |         else:
84 |             continue
85 | 
86 |     phrase_list = []
87 |     for phrase, score in phrase_set.items():
88 |         phrase_list.append((phrase.split(), score))
89 | 
90 |     sorted_phrase_list = sorted(phrase_list, key=lambda x: x[1], reverse=True)
91 |     return sorted_phrase_list
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert XLM checkpoint."""
16 | 
17 | from __future__ import absolute_import, division, print_function
18 | 
19 | import argparse
20 | import json
21 | from io import open
22 | 
23 | import torch
24 | import numpy
25 | 
26 | from transformers import CONFIG_NAME, WEIGHTS_NAME
27 | from transformers.tokenization_xlm import VOCAB_FILES_NAMES
28 | 
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 | 
32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
33 |     # Load checkpoint
34 |     chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
35 | 
36 |     state_dict = chkpt['model']
37 | 
38 |     # We have the base model one level deeper than the original XLM repository
39 |     two_levels_state_dict = {}
40 |     for k, v in state_dict.items():
41 |         if 'pred_layer' in k:
42 |             two_levels_state_dict[k] = v
43 |         else:
44 |             two_levels_state_dict['transformer.' 
+ k] = v
45 | 
46 |     config = chkpt['params']
47 |     config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
48 | 
49 |     vocab = chkpt['dico_word2id']
50 |     vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
51 | 
52 |     # Save pytorch-model
53 |     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
54 |     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
55 |     pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
56 | 
57 |     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
58 |     torch.save(two_levels_state_dict, pytorch_weights_dump_path)
59 | 
60 |     print("Save configuration file to {}".format(pytorch_config_dump_path))
61 |     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
62 |         f.write(json.dumps(config, indent=2) + "\n")
63 | 
64 |     print("Save vocab file to {}".format(pytorch_vocab_dump_path))
65 |     with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
66 |         f.write(json.dumps(vocab, indent=2) + "\n")
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     parser = argparse.ArgumentParser()
71 |     ## Required parameters
72 |     parser.add_argument("--xlm_checkpoint_path",
73 |                         default = None,
74 |                         type = str,
75 |                         required = True,
76 |                         help = "Path to the official PyTorch dump.")
77 |     parser.add_argument("--pytorch_dump_folder_path",
78 |                         default = None,
79 |                         type = str,
80 |                         required = True,
81 |                         help = "Path to the output PyTorch model.")
82 |     args = parser.parse_args()
83 |     convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
84 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/tokenization_xlm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | import pytest
21 | 
22 | from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
23 | 
24 | from .tokenization_tests_commons import CommonTestCases
25 | 
26 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 | 
28 |     tokenizer_class = XLMTokenizer
29 | 
30 |     def setUp(self):
31 |         super(XLMTokenizationTest, self).setUp()
32 | 
33 |         # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "w</w>", "r</w>", "t</w>",
36 |                  "lo", "low", "er</w>",
37 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
40 | 
41 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 |         with open(self.vocab_file, "w") as fp:
44 |             fp.write(json.dumps(vocab_tokens))
45 |         with open(self.merges_file, "w") as fp:
46 |             fp.write("\n".join(merges))
47 | 
48 |     def get_tokenizer(self, **kwargs):
49 |         return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 | 
51 |     def get_input_output_texts(self):
52 |         input_text = u"lower newer"
53 |         output_text = u"lower newer"
54 |         return input_text, output_text
55 | 
56 |     def test_full_tokenizer(self):
57 |         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
58 |         tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
59 | 
60 |         text = "lower"
61 |         bpe_tokens = ["low", "er</w>"]
62 |         tokens = tokenizer.tokenize(text)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + ["<unk>"]
66 |         input_bpe_tokens = [14, 15, 20]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 |     @pytest.mark.slow
71 |     def test_sequence_builders(self):
72 |         tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
73 | 
74 |         text = tokenizer.encode("sequence builders", add_special_tokens=False)
75 |         text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
76 | 
77 |         encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
78 |         encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
79 | 
80 |         assert encoded_sentence == [1] + text + [1]
81 |         assert encoded_pair == [1] + text + [1] + text_2 + [1]
82 | 
83 | if __name__ == '__main__':
84 |     unittest.main()
85 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/configuration_distilbert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ DistilBERT model configuration """ 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" 31 | } 32 | 33 | 34 | class DistilBertConfig(PretrainedConfig): 35 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | 37 | def __init__(self, 38 | vocab_size_or_config_json_file=30522, 39 | max_position_embeddings=512, 40 | sinusoidal_pos_embds=False, 41 | n_layers=6, 42 | n_heads=12, 43 | dim=768, 44 | hidden_dim=4*768, 45 | dropout=0.1, 46 | attention_dropout=0.1, 47 | activation='gelu', 48 | initializer_range=0.02, 49 | tie_weights_=True, 50 | qa_dropout=0.1, 51 | seq_classif_dropout=0.2, 52 | **kwargs): 53 | super(DistilBertConfig, self).__init__(**kwargs) 54 | 55 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 56 | and isinstance(vocab_size_or_config_json_file, unicode)): 57 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 58 | json_config = json.loads(reader.read()) 59 | for key, value in json_config.items(): 60 | self.__dict__[key] = value 61 | elif isinstance(vocab_size_or_config_json_file, int): 62 | self.vocab_size = vocab_size_or_config_json_file 63 | self.max_position_embeddings = max_position_embeddings 64 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 65 | self.n_layers = n_layers 66 | self.n_heads = n_heads 67 | self.dim = dim 68 | self.hidden_dim = hidden_dim 69 | self.dropout = dropout 70 | self.attention_dropout = attention_dropout 71 | self.activation = activation 72 | self.initializer_range = initializer_range 73 | self.tie_weights_ = tie_weights_ 74 | self.qa_dropout = qa_dropout 75 | self.seq_classif_dropout = seq_classif_dropout 76 | else: 77 | raise ValueError("First argument must be either a vocabulary size (int)" 78 | " or the path to a pretrained model config file (str)") 79 | @property 80 | def hidden_size(self): 81 | return self.dim 82 | 83 | @property 84 | def num_attention_heads(self): 85 | return self.n_heads 86 | 87 | @property 88 | def num_hidden_layers(self): 89 | return self.n_layers 90 | -------------------------------------------------------------------------------- /SMART-KPE/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformers import * 6 | 7 | def positional_encoding(max_len, dim): 8 | a = np.arange(max_len)[:, np.newaxis] 9 | b = np.arange(dim)[np.newaxis, :] 10 | angle = a / np.power(10000., (b // 2 * 2) / 10) 11 | sines = np.sin(angle[:, 0::2]) 12 | cosines = np.cos(angle[:, 1::2]) 13 | return np.concatenate([sines, cosines], -1) # max_len * dim 14 | 15 | class BLING_KPE(nn.Module): 16 | def __init__(self, args): 17 | super(BLING_KPE, self).__init__() 18 | # TODO: initializing all the parameters 19 | self.args = args 20 | self.BERT = BertModel.from_pretrained('bert-base-uncased') 21 | self.BERT.resize_token_embeddings(len(args.tokenizer)) 22 | 
embed_size = args.bert_size + args.visual_size
23 |         visual_trans_layer = nn.TransformerEncoderLayer(d_model=args.visual_size, nhead=3)
24 |         self.visual_trans = nn.TransformerEncoder(visual_trans_layer, num_layers=2)
25 | 
26 |         self.meta_dim = 0
27 |         self.meta_dim += args.bert_size
28 |         if args.use_snapshot:
29 |             self.meta_dim += args.snapshot_dim
30 |         assert self.meta_dim > 0, "At least one kind of meta data must be used"
31 | 
32 |         self.meta_selector = nn.Sequential(
33 |             nn.Linear(self.meta_dim, 256),
34 |             nn.ReLU(),
35 |             nn.Linear(256, args.num_trans),
36 |         )
37 | 
38 |         phrase_trans_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=6)
39 |         self.phrase_trans = nn.TransformerEncoder(phrase_trans_layer, num_layers=2)
40 | 
41 |         #self.pos_embed = torch.from_numpy(positional_encoding(args.max_text_length, args.positional_size)).to(args.device,dtype=torch.float)
42 | 
43 |         self.tag_prediction = nn.ModuleList([
44 |             nn.Sequential(
45 |                 nn.Linear(embed_size, 128),
46 |                 nn.ReLU(),
47 |                 nn.Linear(128, args.tag_num),
48 |             ) for i in range(args.num_trans)])
49 | 
50 |         self.dropout = nn.Dropout(p=0.2)
51 | 
52 |     def forward(self, text_id, visual_input, input_mask, meta, valid_id=None):
53 |         """
54 |         text_id: batch * len
55 |         input_mask: batch * len
56 |         visual_input: batch * len * visual_size
57 |         meta: batch * meta_dim (valid_id is accepted but currently unused)
58 |         """
59 |         bert_embedding, _ = self.BERT(text_id, attention_mask=input_mask)
60 |         batch, length, _ = bert_embedding.size()
61 | 
62 |         bert_cls = bert_embedding[:, 0, :].squeeze(1)  # batch * bert_size
63 | 
64 |         visual_t = visual_input.transpose(0, 1)  # len * batch * visual_size
65 |         visual_embedding = self.visual_trans(visual_t, src_key_padding_mask=(~input_mask)).transpose(0, 1)  # batch * len * visual_size
66 | 
67 |         embedding = torch.cat([bert_embedding, visual_embedding], -1).transpose(0, 1)  # len * batch * embed_size
68 | 
69 |         phrase_embedding = self.phrase_trans(embedding, src_key_padding_mask=(~input_mask)).transpose(0, 1)
70 |         '''
71 |         batch_size = phrase_embedding.size(0)
72 |         for i in range(batch_size):
73 |             valid_num = sum(valid_id[i]).item()
74 |             vectors = phrase_embedding[i][valid_id[i]==1]
75 |             phrase_embedding[i,:valid_num].copy_(vectors)
76 |             phrase_embedding[i,valid_num:] = 0
77 |         '''
78 |         pred_before_softmax = torch.cat([self.tag_prediction[i](phrase_embedding).unsqueeze(1) for i in range(self.args.num_trans)], 1)  # batch * num_trans * len * tag_num
79 | 
80 |         meta = torch.mean(meta.view(batch, 4, -1), 1)  # batch * 512
81 |         meta_cat = torch.cat([bert_cls, meta], -1)  # batch * meta_size
82 |         # meta_cat = bert_cls
83 |         pred_mask_before_softmax = self.meta_selector(meta_cat)  # batch * num_trans
84 |         pred_mask = F.softmax(pred_mask_before_softmax, -1).unsqueeze(-1).unsqueeze(-1)  # batch * num_trans * 1 * 1
85 |         pred = F.softmax(pred_before_softmax, -1)  # batch * num_trans * len * tag_num
86 |         pred = torch.log(torch.sum(pred * pred_mask, 1))  # batch * len * tag_num
87 |         return pred
88 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/fixtures/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Guttenberg. 
4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/modeling_tf_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import is_tf_available 25 | 26 | if is_tf_available(): 27 | from transformers import (AutoConfig, BertConfig, 28 | TFAutoModel, TFBertModel, 29 | TFAutoModelWithLMHead, TFBertForMaskedLM, 30 | TFAutoModelForSequenceClassification, TFBertForSequenceClassification, 31 | TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering) 32 | from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP 33 | 34 | from .modeling_common_test import (CommonTestCases, ids_tensor) 35 | from .configuration_common_test import ConfigTester 36 | else: 37 | pytestmark = pytest.mark.skip("Require TensorFlow") 38 | 39 | 40 | class TFAutoModelTest(unittest.TestCase): 41 | def test_model_from_pretrained(self): 42 | import h5py 43 | self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) 44 | 45 | logging.basicConfig(level=logging.INFO) 46 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 47 | for model_name in ['bert-base-uncased']: 48 | config = AutoConfig.from_pretrained(model_name, force_download=True) 49 | self.assertIsNotNone(config) 50 | self.assertIsInstance(config, BertConfig) 51 | 52 | model = TFAutoModel.from_pretrained(model_name, force_download=True) 53 | self.assertIsNotNone(model) 54 | self.assertIsInstance(model, TFBertModel) 55 | 56 | def test_lmhead_model_from_pretrained(self): 57 | logging.basicConfig(level=logging.INFO) 58 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 59 | for model_name in ['bert-base-uncased']: 60 | config = AutoConfig.from_pretrained(model_name, force_download=True) 61 | self.assertIsNotNone(config) 62 | 
self.assertIsInstance(config, BertConfig) 63 | 64 | model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) 65 | self.assertIsNotNone(model) 66 | self.assertIsInstance(model, TFBertForMaskedLM) 67 | 68 | def test_sequence_classification_model_from_pretrained(self): 69 | logging.basicConfig(level=logging.INFO) 70 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 71 | for model_name in ['bert-base-uncased']: 72 | config = AutoConfig.from_pretrained(model_name, force_download=True) 73 | self.assertIsNotNone(config) 74 | self.assertIsInstance(config, BertConfig) 75 | 76 | model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) 77 | self.assertIsNotNone(model) 78 | self.assertIsInstance(model, TFBertForSequenceClassification) 79 | 80 | def test_question_answering_model_from_pretrained(self): 81 | logging.basicConfig(level=logging.INFO) 82 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 83 | for model_name in ['bert-base-uncased']: 84 | config = AutoConfig.from_pretrained(model_name, force_download=True) 85 | self.assertIsNotNone(config) 86 | self.assertIsInstance(config, BertConfig) 87 | 88 | model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) 89 | self.assertIsNotNone(model) 90 | self.assertIsInstance(model, TFBertForQuestionAnswering) 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import is_torch_available 25 | 26 | if is_torch_available(): 27 | from transformers import (AutoConfig, BertConfig, 28 | AutoModel, BertModel, 29 | AutoModelWithLMHead, BertForMaskedLM, 30 | AutoModelForSequenceClassification, BertForSequenceClassification, 31 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 32 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 33 | 34 | from .modeling_common_test import (CommonTestCases, ids_tensor) 35 | from .configuration_common_test import ConfigTester 36 | else: 37 | pytestmark = pytest.mark.skip("Require Torch") 38 | 39 | 40 | class AutoModelTest(unittest.TestCase): 41 | @pytest.mark.slow 42 | def test_model_from_pretrained(self): 43 | logging.basicConfig(level=logging.INFO) 44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 45 | config = AutoConfig.from_pretrained(model_name) 46 | self.assertIsNotNone(config) 47 | self.assertIsInstance(config, BertConfig) 48 | 49 | model = AutoModel.from_pretrained(model_name) 50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 51 | self.assertIsNotNone(model) 52 | self.assertIsInstance(model, BertModel) 53 | for value in loading_info.values(): 54 | self.assertEqual(len(value), 0) 55 | 56 | @pytest.mark.slow 57 | def test_lmhead_model_from_pretrained(self): 58 | logging.basicConfig(level=logging.INFO) 59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 60 | config = AutoConfig.from_pretrained(model_name) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = AutoModelWithLMHead.from_pretrained(model_name) 65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 66 | self.assertIsNotNone(model) 67 | self.assertIsInstance(model, BertForMaskedLM) 68 | 69 | @pytest.mark.slow 70 | def test_sequence_classification_model_from_pretrained(self): 71 | logging.basicConfig(level=logging.INFO) 72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 73 | config = AutoConfig.from_pretrained(model_name) 74 | self.assertIsNotNone(config) 75 | self.assertIsInstance(config, BertConfig) 76 | 77 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 79 | self.assertIsNotNone(model) 80 | self.assertIsInstance(model, BertForSequenceClassification) 81 | 82 | @pytest.mark.slow 83 | def test_question_answering_model_from_pretrained(self): 84 | logging.basicConfig(level=logging.INFO) 85 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 86 | config = AutoConfig.from_pretrained(model_name) 87 | self.assertIsNotNone(config) 88 | self.assertIsInstance(config, BertConfig) 89 | 90 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 92 | self.assertIsNotNone(model) 93 | self.assertIsInstance(model, BertForQuestionAnswering) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- 
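The Auto-class tests above (TF and PyTorch variants) exercise the factory pattern used throughout this bundled copy of transformers: `AutoConfig`, `AutoModel`, and their relatives inspect the pretrained-model identifier and dispatch to the matching concrete classes, here `BertConfig`/`BertModel`. A minimal sketch of that pattern outside the test harness — not part of the repository, and assuming the transformers 2.x API shipped here plus network access to download `bert-base-uncased`:

    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # The identifier prefix ("bert-", "roberta-", ...) selects the concrete class.
    config = AutoConfig.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    input_ids = torch.tensor([tokenizer.encode("keyphrase extraction", add_special_tokens=True)])
    sequence_output = model(input_ids)[0]  # batch * len * hidden_size, the tensor the KPE networks build on

--------------------------------------------------------------------------------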
/BERT-KPE-based/bertkpe/transformers/tests/tokenization_roberta_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import json
19 | import unittest
20 | import pytest
21 | from io import open
22 | 
23 | from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
24 | from .tokenization_tests_commons import CommonTestCases
25 | 
26 | 
27 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
28 |     tokenizer_class = RobertaTokenizer
29 | 
30 |     def setUp(self):
31 |         super(RobertaTokenizationTest, self).setUp()
32 | 
33 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "\u0120", "\u0120l", "\u0120n",
36 |                  "\u0120lo", "\u0120low", "er",
37 |                  "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
40 |         self.special_tokens_map = {"unk_token": "<unk>"}
41 | 
42 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
43 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
44 |         with open(self.vocab_file, "w", encoding="utf-8") as fp:
45 |             fp.write(json.dumps(vocab_tokens) + "\n")
46 |         with open(self.merges_file, "w", encoding="utf-8") as fp:
47 |             fp.write("\n".join(merges))
48 | 
49 |     def get_tokenizer(self, **kwargs):
50 |         kwargs.update(self.special_tokens_map)
51 |         return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
52 | 
53 |     def get_input_output_texts(self):
54 |         input_text = u"lower newer"
55 |         output_text = u"lower newer"
56 |         return input_text, output_text
57 | 
58 |     def test_full_tokenizer(self):
59 |         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
60 |         text = "lower newer"
61 |         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
62 |         tokens = tokenizer.tokenize(text, add_prefix_space=True)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + [tokenizer.unk_token]
66 |         input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 |     def roberta_dict_integration_testing(self):
71 |         tokenizer = self.get_tokenizer()
72 | 
73 |         self.assertListEqual(
74 |             tokenizer.encode('Hello world!', add_special_tokens=False),
75 |             [0, 31414, 232, 328, 2]
76 |         )
77 |         self.assertListEqual(
78 |             tokenizer.encode('Hello world! 
cécé herlolip 418', add_special_tokens=False), 79 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 80 | ) 81 | 82 | @pytest.mark.slow 83 | def test_sequence_builders(self): 84 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 85 | 86 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 87 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 88 | 89 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 90 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 91 | 92 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 93 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 94 | 95 | assert encoded_sentence == encoded_text_from_decode 96 | assert encoded_pair == encoded_pair_from_decode 97 | 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert XLNet checkpoint."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import os
22 | import argparse
23 | import torch
24 | 
25 | from transformers import (CONFIG_NAME, WEIGHTS_NAME,
26 |                           XLNetConfig,
27 |                           XLNetLMHeadModel, XLNetForQuestionAnswering,
28 |                           XLNetForSequenceClassification,
29 |                           load_tf_weights_in_xlnet)
30 | 
31 | GLUE_TASKS_NUM_LABELS = {
32 |     "cola": 2,
33 |     "mnli": 3,
34 |     "mrpc": 2,
35 |     "sst-2": 2,
36 |     "sts-b": 1,
37 |     "qqp": 2,
38 |     "qnli": 2,
39 |     "rte": 2,
40 |     "wnli": 2,
41 | }
42 | 
43 | import logging
44 | logging.basicConfig(level=logging.INFO)
45 | 
46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, xlnet_config_file, pytorch_dump_folder_path, finetuning_task=None):
47 |     # Initialise PyTorch model
48 |     config = XLNetConfig.from_json_file(xlnet_config_file)
49 | 
50 |     finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
51 |     if finetuning_task in GLUE_TASKS_NUM_LABELS:
52 |         print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
53 |         config.finetuning_task = finetuning_task
54 |         config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
55 |         model = XLNetForSequenceClassification(config)
56 |     elif 'squad' in finetuning_task:
57 |         config.finetuning_task = finetuning_task
58 |         model = XLNetForQuestionAnswering(config)
59 |     else:
60 |         model = XLNetLMHeadModel(config)
61 | 
62 |     # Load weights from tf checkpoint
63 |     load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
64 | 
65 |     # Save pytorch-model
66 |     pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
67 |     pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
68 |     print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
69 |     torch.save(model.state_dict(), pytorch_weights_dump_path)
70 |     print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
71 |     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
72 |         f.write(config.to_json_string())
73 | 
74 | 
75 | if __name__ == "__main__":
76 |     parser = argparse.ArgumentParser()
77 |     ## Required parameters
78 |     parser.add_argument("--tf_checkpoint_path",
79 |                         default = None,
80 |                         type = str,
81 |                         required = True,
82 |                         help = "Path to the TensorFlow checkpoint.")
83 |     parser.add_argument("--xlnet_config_file",
84 |                         default = None,
85 |                         type = str,
86 |                         required = True,
87 |                         help = "The config json file corresponding to the pre-trained XLNet model. 
\n"
88 |                         "This specifies the model architecture.")
89 |     parser.add_argument("--pytorch_dump_folder_path",
90 |                         default = None,
91 |                         type = str,
92 |                         required = True,
93 |                         help = "Path to the folder to store the PyTorch model or dataset/vocab.")
94 |     parser.add_argument("--finetuning_task",
95 |                         default = None,
96 |                         type = str,
97 |                         help = "Name of a task on which the XLNet TensorFlow model was fine-tuned")
98 |     args = parser.parse_args()
99 |     print(args)
100 | 
101 |     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
102 |                                         args.xlnet_config_file,
103 |                                         args.pytorch_dump_folder_path,
104 |                                         args.finetuning_task)
105 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/convert_bert_pytorch_checkpoint_to_original_tf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
17 | 
18 | import os
19 | import argparse
20 | import torch
21 | import numpy as np
22 | import tensorflow as tf
23 | from transformers import BertModel
24 | 
25 | 
26 | def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
27 | 
28 |     """
29 |     :param model: BertModel Pytorch model instance to be converted
30 |     :param ckpt_dir: Tensorflow model directory
31 |     :param model_name: model name
32 |     :return:
33 | 
34 |     Currently supported HF models:
35 |         Y BertModel
36 |         N BertForMaskedLM
37 |         N BertForPreTraining
38 |         N BertForMultipleChoice
39 |         N BertForNextSentencePrediction
40 |         N BertForSequenceClassification
41 |         N BertForQuestionAnswering
42 |     """
43 | 
44 |     tensors_to_transpose = (
45 |         "dense.weight",
46 |         "attention.self.query",
47 |         "attention.self.key",
48 |         "attention.self.value"
49 |     )
50 | 
51 |     var_map = (
52 |         ('layer.', 'layer_'),
53 |         ('word_embeddings.weight', 'word_embeddings'),
54 |         ('position_embeddings.weight', 'position_embeddings'),
55 |         ('token_type_embeddings.weight', 'token_type_embeddings'),
56 |         ('.', '/'),
57 |         ('LayerNorm/weight', 'LayerNorm/gamma'),
58 |         ('LayerNorm/bias', 'LayerNorm/beta'),
59 |         ('weight', 'kernel')
60 |     )
61 | 
62 |     if not os.path.isdir(ckpt_dir):
63 |         os.makedirs(ckpt_dir)
64 | 
65 |     state_dict = model.state_dict()
66 | 
67 |     def to_tf_var_name(name: str):
68 |         for patt, repl in iter(var_map):
69 |             name = name.replace(patt, repl)
70 |         return 'bert/{}'.format(name)
71 | 
72 |     def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
73 |         tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
74 |         tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
75 |         session.run(tf.variables_initializer([tf_var]))
76 |         session.run(tf_var)
77 |         return tf_var
78 | 
79 |     tf.reset_default_graph()
80 |     with tf.Session() as session:
81 |         for var_name in state_dict:
82 |             tf_name = 
to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/evaluator/kp20k_evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import re 4 | import string 5 | import logging 6 | import numpy as np 7 | import unicodedata 8 | 9 | from nltk.stem.porter import PorterStemmer 10 | stemmer = PorterStemmer() 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | def normalize_answer(s): 16 | def remove_articles(text): 17 | return re.sub(r'\b(a|an|the)\b', ' ', text) 18 | def white_space_fix(text): 19 | return ' '.join(text.split()) 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return ''.join(ch for ch in text if ch not in exclude) 23 | def lower(text): 24 | return text.lower() 25 | return ' '.join([lower(x) for x in s]).rstrip() 26 | 27 | 28 | def remove_empty(a_list): 29 | new_list = [] 30 | for i in a_list: 31 | if len(i) > 0: 32 | if len(i[0]) >0: 33 | new_list.append(normalize_answer(i)) 34 | return new_list 35 | 36 | 37 | # ---------------------------------------------------------------------------- 38 | # stem phrase 39 | def norm_and_stem(phrase_list, merge=False): 40 | 41 | norm_stem_phrases = [] 42 | for phrase in phrase_list: 43 | norm_chars = unicodedata.normalize('NFD', phrase) 44 | stem_chars = " ".join([stemmer.stem(w) for w in norm_chars.split(" ")]) 45 | if merge: 46 | norm_stem_phrases.append(norm_chars) 47 | norm_stem_phrases.append(stem_chars) 48 | else: 49 | norm_stem_phrases.append((norm_chars, stem_chars)) 50 | return norm_stem_phrases 51 | 52 | 53 | def get_match_scores(pred_list, truth_list): 54 | match_score = np.asarray([0.0] * len(pred_list), dtype='float32') 55 | 56 | norm_stem_preds = norm_and_stem(pred_list) 57 | norm_stem_truths = norm_and_stem(truth_list, merge=True) 58 | 59 | for pred_id, pred_seq in 
enumerate(norm_stem_preds):
60 |         if pred_seq[0] in norm_stem_truths or pred_seq[1] in norm_stem_truths:
61 |             match_score[pred_id] = 1
62 |     return match_score
63 | 
64 | # ----------------------------------------------------------------------------
65 | 
66 | def evaluate(candidates, references, urls):
67 |     precision_scores, recall_scores, f1_scores = {1:[], 3:[], 5:[], 10:[]}, {1:[], 3:[], 5:[], 10:[]}, {1:[], 3:[], 5:[], 10:[]}
68 |     for url in urls:
69 |         candidate = remove_empty(candidates[url]['KeyPhrases'])  # convert word lists to normalized strings
70 |         reference = remove_empty(references[url]['KeyPhrases'])  # empty phrases already removed
71 | 
72 |         # stem match scores
73 |         match_list = get_match_scores(candidate, reference)
74 | 
75 |         # Padding
76 |         if len(match_list) < 10:
77 |             for _ in range(10-len(match_list)):
78 |                 candidate.append('')
79 |             assert len(candidate) == 10
80 | 
81 |         for topk in [1, 3, 5, 10]:
82 | 
83 |             # Micro-Averaged Method
84 |             micropk = float(sum(match_list[:topk])) / float(len(candidate[:topk])) if len(candidate[:topk]) > 0 else 0.0
85 |             micrork = float(sum(match_list[:topk])) / float(len(reference)) if len(reference) > 0 else 0.0
86 | 
87 |             if micropk + micrork > 0:
88 |                 microf1 = float(2 * (micropk * micrork)) / (micropk + micrork)
89 |             else:
90 |                 microf1 = 0.0
91 | 
92 |             precision_scores[topk].append(micropk)
93 |             recall_scores[topk].append(micrork)
94 |             f1_scores[topk].append(microf1)
95 | 
96 |     return f1_scores, precision_scores, recall_scores
97 | 
98 | def files_are_good(candidate, reference):
99 |     referenceURLs = set(reference.keys())
100 |     candidateURLs = set(candidate.keys())
101 |     if len((referenceURLs - candidateURLs)) > 0:
102 |         logger.info("ERROR: Candidate file is missing URLs present in reference file\nMissing urls:{}".format(referenceURLs-candidateURLs))
103 |         return False
104 |     if len((candidateURLs - referenceURLs)) > 0:
105 |         logger.info("ERROR: Candidate file includes URLs not present in reference file\nUnexpected urls:{}".format(candidateURLs-referenceURLs))
106 |         return False
107 |     return True
108 | 
109 | def load_file(filename):
110 |     data = {}
111 |     with open(filename, 'r') as f:
112 |         for l in f:
113 |             item = json.loads(l)
114 |             data[item['url']] = item
115 |     return data
116 | 
117 | 
118 | def evaluate_kp20k(candidate, reference_filename):
119 |     reference = load_file(reference_filename)
120 |     if files_are_good(candidate, reference):
121 |         candidate_urls = set(candidate.keys())
122 |         reference_urls = set(reference.keys())
123 |         urls = reference_urls.intersection(candidate_urls)
124 |         f1_scores, precision_scores, recall_scores = evaluate(candidate, reference, urls)
125 |         return f1_scores, precision_scores, recall_scores
126 |     else:
127 |         logger.info("Candidate file and reference are not comparable. Please verify your candidate file.")
128 | 
129 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/data/processors/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import copy 20 | import json 21 | 22 | class InputExample(object): 23 | """ 24 | A single training/test example for simple sequence classification. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 34 | """ 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | self.guid = guid 37 | self.text_a = text_a 38 | self.text_b = text_b 39 | self.label = label 40 | 41 | def __repr__(self): 42 | return str(self.to_json_string()) 43 | 44 | def to_dict(self): 45 | """Serializes this instance to a Python dictionary.""" 46 | output = copy.deepcopy(self.__dict__) 47 | return output 48 | 49 | def to_json_string(self): 50 | """Serializes this instance to a JSON string.""" 51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 52 | 53 | 54 | class InputFeatures(object): 55 | """ 56 | A single set of features of data. 57 | 58 | Args: 59 | input_ids: Indices of input sequence tokens in the vocabulary. 60 | attention_mask: Mask to avoid performing attention on padding token indices. 61 | Mask values selected in ``[0, 1]``: 62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 64 | label: Label corresponding to the input 65 | """ 66 | 67 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 68 | self.input_ids = input_ids 69 | self.attention_mask = attention_mask 70 | self.token_type_ids = token_type_ids 71 | self.label = label 72 | 73 | def __repr__(self): 74 | return str(self.to_json_string()) 75 | 76 | def to_dict(self): 77 | """Serializes this instance to a Python dictionary.""" 78 | output = copy.deepcopy(self.__dict__) 79 | return output 80 | 81 | def to_json_string(self): 82 | """Serializes this instance to a JSON string.""" 83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 84 | 85 | 86 | class DataProcessor(object): 87 | """Base class for data converters for sequence classification data sets.""" 88 | 89 | def get_example_from_tensor_dict(self, tensor_dict): 90 | """Gets an example from a dict with tensorflow tensors 91 | 92 | Args: 93 | tensor_dict: Keys and values should match the corresponding Glue 94 | tensorflow_dataset examples. 
95 | """ 96 | raise NotImplementedError() 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_labels(self): 107 | """Gets the list of labels for this data set.""" 108 | raise NotImplementedError() 109 | 110 | def tfds_map(self, example): 111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 112 | This method converts examples to the correct format.""" 113 | if len(self.get_labels()) > 1: 114 | example.label = self.get_labels()[int(example.label)] 115 | return example 116 | 117 | @classmethod 118 | def _read_tsv(cls, input_file, quotechar=None): 119 | """Reads a tab separated value file.""" 120 | with open(input_file, "r", encoding="utf-8-sig") as f: 121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 122 | lines = [] 123 | for line in reader: 124 | if sys.version_info[0] == 2: 125 | line = list(unicode(cell, 'utf-8') for cell in line) 126 | lines.append(line) 127 | return lines 128 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 31 | } 32 | 33 | class OpenAIGPTConfig(PretrainedConfig): 34 | """ 35 | Configuration class to store the configuration of a `OpenAIGPTModel`. 36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. 39 | n_positions: Number of positional embeddings. 40 | n_ctx: Size of the causal mask (usually same as n_positions). 41 | n_embd: Dimensionality of the embeddings and hidden states. 42 | n_layer: Number of hidden layers in the Transformer encoder. 43 | n_head: Number of attention heads for each attention layer in 44 | the Transformer encoder. 45 | afn: The non-linear activation function (function or string) in the 46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 
47 |         resid_pdrop: The dropout probability for all fully connected
48 |             layers in the embeddings, encoder, and pooler.
49 |         attn_pdrop: The dropout ratio for the attention
50 |             probabilities.
51 |         embd_pdrop: The dropout ratio for the embeddings.
52 |         layer_norm_epsilon: epsilon to use in the layer norm layers
53 |         initializer_range: The stddev of the truncated_normal_initializer for
54 |             initializing all weight matrices.
55 |         predict_special_tokens: whether to predict special tokens (when the model has a LM head)
56 |     """
57 |     pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
58 | 
59 |     def __init__(
60 |         self,
61 |         vocab_size_or_config_json_file=40478,
62 |         n_positions=512,
63 |         n_ctx=512,
64 |         n_embd=768,
65 |         n_layer=12,
66 |         n_head=12,
67 |         afn="gelu",
68 |         resid_pdrop=0.1,
69 |         embd_pdrop=0.1,
70 |         attn_pdrop=0.1,
71 |         layer_norm_epsilon=1e-5,
72 |         initializer_range=0.02,
73 |         predict_special_tokens=True,
74 | 
75 |         num_labels=1,
76 |         summary_type='cls_index',
77 |         summary_use_proj=True,
78 |         summary_activation=None,
79 |         summary_proj_to_labels=True,
80 |         summary_first_dropout=0.1,
81 |         **kwargs
82 |     ):
83 |         """Constructs OpenAIGPTConfig.
84 |         """
85 |         super(OpenAIGPTConfig, self).__init__(**kwargs)
86 | 
87 |         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
88 |                                                                and isinstance(vocab_size_or_config_json_file, unicode)):
89 |             with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
90 |                 json_config = json.loads(reader.read())
91 |             for key, value in json_config.items():
92 |                 self.__dict__[key] = value
93 |         elif isinstance(vocab_size_or_config_json_file, int):
94 |             self.vocab_size = vocab_size_or_config_json_file
95 |             self.n_ctx = n_ctx
96 |             self.n_positions = n_positions
97 |             self.n_embd = n_embd
98 |             self.n_layer = n_layer
99 |             self.n_head = n_head
100 |             self.afn = afn
101 |             self.resid_pdrop = resid_pdrop
102 |             self.embd_pdrop = embd_pdrop
103 |             self.attn_pdrop = attn_pdrop
104 |             self.layer_norm_epsilon = layer_norm_epsilon
105 |             self.initializer_range = initializer_range
106 |             self.predict_special_tokens = predict_special_tokens
107 | 
108 |             self.num_labels = num_labels
109 |             self.summary_type = summary_type
110 |             self.summary_use_proj = summary_use_proj
111 |             self.summary_activation = summary_activation
112 |             self.summary_first_dropout = summary_first_dropout
113 |             self.summary_proj_to_labels = summary_proj_to_labels
114 |         else:
115 |             raise ValueError(
116 |                 "First argument must be either a vocabulary size (int)"
117 |                 " or the path to a pretrained model config file (str)"
118 |             )
119 | 
120 |     @property
121 |     def max_position_embeddings(self):
122 |         return self.n_positions
123 | 
124 |     @property
125 |     def hidden_size(self):
126 |         return self.n_embd
127 | 
128 |     @property
129 |     def num_attention_heads(self):
130 |         return self.n_head
131 | 
132 |     @property
133 |     def num_hidden_layers(self):
134 |         return self.n_layer
135 | 
--------------------------------------------------------------------------------
/BERT-KPE-based/bertkpe/transformers/tests/tokenization_xlnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | 21 | from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 26 | 'fixtures/test_sentencepiece.model') 27 | 28 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 29 | 30 | tokenizer_class = XLNetTokenizer 31 | 32 | def setUp(self): 33 | super(XLNetTokenizationTest, self).setUp() 34 | 35 | # We have a SentencePiece fixture for testing 36 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 37 | tokenizer.save_pretrained(self.tmpdirname) 38 | 39 | def get_tokenizer(self, **kwargs): 40 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 41 | 42 | def get_input_output_texts(self): 43 | input_text = u"This is a test" 44 | output_text = u"This is a test" 45 | return input_text, output_text 46 | 47 | 48 | def test_full_tokenizer(self): 49 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 50 | 51 | tokens = tokenizer.tokenize(u'This is a test') 52 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 53 | 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 56 | 57 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 58 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 59 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 60 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 61 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 62 | ids = tokenizer.convert_tokens_to_ids(tokens) 63 | self.assertListEqual( 64 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 65 | 602, 347, 347, 347, 3, 12, 66, 66 | 46, 72, 80, 6, 0, 4]) 67 | 68 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 69 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 70 | u'or', u'n', SPIECE_UNDERLINE + u'in', 71 | SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',', 72 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 73 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 74 | u'<unk>', u'.']) 75 | 76 | def test_tokenizer_lower(self): 77 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 78 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 79 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 80 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 81 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 82 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 83 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 84 | 85 | def
test_tokenizer_no_lower(self): 86 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 87 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 88 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 89 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 90 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 91 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 92 | 93 | @pytest.mark.slow 94 | def test_sequence_builders(self): 95 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 96 | 97 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 98 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 99 | 100 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 101 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 102 | 103 | assert encoded_sentence == text + [4, 3] 104 | assert encoded_pair == text + [4] + text_2 + [4, 3] 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import transformers.tokenization_transfo_xl as data_utils 27 | 28 | from transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from transformers import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained Transformer-XL model. 
\n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers.tokenization_bert import (BasicTokenizer, 23 | BertTokenizer, 24 | WordpieceTokenizer, 25 | _is_control, _is_punctuation, 26 | _is_whitespace, VOCAB_FILES_NAMES) 27 | 28 | from .tokenization_tests_commons import CommonTestCases 29 | 30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 31 | 32 | tokenizer_class = BertTokenizer 33 | 34 | def setUp(self): 35 | super(BertTokenizationTest, self).setUp() 36 | 37 | vocab_tokens = [ 38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 39 | "##ing", ",", "low", "lowest", 40 | ] 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 47 | 48 | def get_input_output_texts(self): 49 | input_text = u"UNwant\u00E9d,running" 50 | output_text = u"unwanted, running" 51 | return input_text, output_text 52 | 53 | def test_full_tokenizer(self): 54 | tokenizer = self.tokenizer_class(self.vocab_file) 55 | 56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 59 | 60 | def test_chinese(self): 61 | tokenizer = BasicTokenizer() 62 | 63 | self.assertListEqual( 64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 65 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 66 | 67 | def test_basic_tokenizer_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=True) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 72 | ["hello", "!", "how", "are", "you", "?"]) 73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 74 | 75 | def test_basic_tokenizer_no_lower(self): 76 | tokenizer = BasicTokenizer(do_lower_case=False) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 80 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 81 | 82 | def test_wordpiece_tokenizer(self): 83 | vocab_tokens = [ 84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 85 | "##ing" 86 | ] 87 | 88 | vocab = {} 89 | for (i, token) in enumerate(vocab_tokens): 90 | vocab[token] = i 91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 92 | 93 | self.assertListEqual(tokenizer.tokenize(""), []) 94 | 95 | self.assertListEqual( 96 | tokenizer.tokenize("unwanted running"), 97 | ["un", "##want", "##ed", "runn", "##ing"]) 98 | 99 | self.assertListEqual( 100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 101 | 102 | def test_is_whitespace(self): 103 | self.assertTrue(_is_whitespace(u" ")) 104 | self.assertTrue(_is_whitespace(u"\t")) 105 | self.assertTrue(_is_whitespace(u"\r")) 106 | self.assertTrue(_is_whitespace(u"\n")) 107 | self.assertTrue(_is_whitespace(u"\u00A0")) 108 | 109 | self.assertFalse(_is_whitespace(u"A")) 110 | self.assertFalse(_is_whitespace(u"-")) 111 | 112 | def test_is_control(self): 113 | self.assertTrue(_is_control(u"\u0005")) 114 | 115 | self.assertFalse(_is_control(u"A")) 116 | self.assertFalse(_is_control(u" ")) 117 | self.assertFalse(_is_control(u"\t")) 118 | self.assertFalse(_is_control(u"\r")) 119 | 120 | def test_is_punctuation(self): 121 | self.assertTrue(_is_punctuation(u"-")) 122 | self.assertTrue(_is_punctuation(u"$")) 123 | self.assertTrue(_is_punctuation(u"`")) 124 | self.assertTrue(_is_punctuation(u".")) 125 | 126 | self.assertFalse(_is_punctuation(u"A")) 127 | self.assertFalse(_is_punctuation(u" ")) 128 | 129 | @pytest.mark.slow 130 | def test_sequence_builders(self): 131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 132 | 133 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 135 | 136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 138 | 139 | assert encoded_sentence == [101] + text + [102] 140 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 141 | 142 | if __name__ == '__main__': 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/evaluator/openkp_evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import re 4 | import string 5 | import numpy as np 6 | import logging 7 | 8 | logger = logging.getLogger() 9 | 10 | 11 | def normalize_answer(s): 12 | def remove_articles(text): 13 | return re.sub(r'\b(a|an|the)\b', ' ', text) 14 | def white_space_fix(text): 15 | return ' '.join(text.split()) 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | def lower(text): 20 | return text.lower() 21 | return ' '.join([lower(x) for x in s]).rstrip() 22 | 23 | def remove_empty(a_list): 24 | new_list = [] 25 | for i in a_list: 26 | if len(i) > 0: 27 | if len(i[0]) >0: 28 | new_list.append(normalize_answer(i)) 29 | return new_list 30 | 31 | 
def dedup(kp_list): 32 | dedupset = set() 33 | kp_list_dedup = [] 34 | for kp in kp_list: 35 | if kp in dedupset: 36 | continue 37 | kp_list_dedup.append(kp) 38 | dedupset.add(kp) 39 | return kp_list_dedup 40 | 41 | def get_score_full(candidates, references, maxDepth = 5): 42 | precision = [] 43 | recall = [] 44 | reference_set = set(dedup(references)) 45 | candidates = dedup(candidates) 46 | referencelen = len(reference_set) 47 | true_positive = 0 48 | for i in range(maxDepth): 49 | if len(candidates) > i: 50 | kp_pred = candidates[i] 51 | if kp_pred in reference_set: 52 | true_positive += 1 53 | precision.append(true_positive/float(i + 1)) 54 | recall.append(true_positive/float(referencelen)) 55 | return precision, recall 56 | 57 | def online_evaluate(candidates, references, urls): 58 | precision_scores, recall_scores, f1_scores = {1:[], 3:[], 5:[]}, {1:[], 3:[], 5:[]}, {1:[], 3:[], 5:[]} 59 | for url in urls: 60 | candidate = remove_empty(candidates[url]['KeyPhrases']) 61 | reference = remove_empty(references[url]['KeyPhrases']) 62 | p, r = get_score_full(candidate, reference) 63 | for i in [1,3,5]: 64 | precision = p[i-1] 65 | recall = r[i-1] 66 | if precision + recall > 0: 67 | f1_scores[i].append((2 * (precision * recall)) / (precision + recall)) 68 | else: 69 | f1_scores[i].append(0) 70 | precision_scores[i].append(precision) 71 | recall_scores[i].append(recall) 72 | return f1_scores, precision_scores, recall_scores 73 | 74 | 75 | def evaluate(candidates, references, urls): 76 | precision_scores, recall_scores, f1_scores = {1:[], 3:[], 5:[]}, {1:[], 3:[], 5:[]}, {1:[], 3:[], 5:[]} 77 | for url in urls: 78 | candidate = remove_empty(candidates[url]['KeyPhrases']) 79 | reference = remove_empty(references[url]['KeyPhrases']) 80 | p, r = get_score_full(candidate, reference) 81 | for i in [1,3,5]: 82 | precision = p[i-1] 83 | recall = r[i-1] 84 | if precision + recall > 0: 85 | f1_scores[i].append((2 * (precision * recall)) / (precision + recall)) 86 | else: 87 | f1_scores[i].append(0) 88 | precision_scores[i].append(precision) 89 | recall_scores[i].append(recall) 90 | print("########################\nMetrics") 91 | for i in precision_scores: 92 | print("@{}".format(i)) 93 | print("F1:{}".format(np.mean(f1_scores[i]))) 94 | print("P:{}".format(np.mean(precision_scores[i]))) 95 | print("R:{}".format(np.mean(recall_scores[i]))) 96 | print("#########################") 97 | 98 | 99 | def files_are_good(candidate, reference): 100 | referenceURLs = set(reference.keys()) 101 | candidateURLs = set(candidate.keys()) 102 | if len((referenceURLs - candidateURLs)) > 0: 103 | logger.info("ERROR:Candidate File is missing URLS present in reference file\nMissing urls:{}".format(referenceURLs-candidateURLs)) 104 | return False 105 | if len((candidateURLs - referenceURLs)) > 0: 106 | logger.info("ERROR:Candidate File includes URLS not present in reference file\nUnexpected urls:{}".format(candidateURLs-referenceURLs)) 107 | return False 108 | return True 109 | 110 | def load_file(filename): 111 | data = {} 112 | with open(filename,'r') as f: 113 | for l in f: 114 | item = json.loads(l) 115 | data[item['url']] = item 116 | return data 117 | 118 | 119 | def evaluate_openkp(candidate, reference_filename): 120 | reference = load_file(reference_filename) 121 | if files_are_good(candidate, reference) == True: 122 | candidate_urls = set(candidate.keys()) 123 | reference_urls = set(reference.keys()) 124 | urls = reference_urls.intersection(candidate_urls) 125 | f1_scores, precision_scores, recall_scores = 
online_evaluate(candidate, reference, urls) 126 | return f1_scores, precision_scores, recall_scores 127 | else: 128 | logger.info("Candidate file and Reference are not comparable. Please verify your candidate file.") 129 | 130 | 131 | def main(candidate_filename, reference_filename): 132 | candidate = load_file(candidate_filename) 133 | reference = load_file(reference_filename) 134 | if files_are_good(candidate, reference) == True: 135 | candidate_urls = set(candidate.keys()) 136 | reference_urls = set(reference.keys()) 137 | urls = reference_urls.intersection(candidate_urls) 138 | evaluate(candidate, reference, urls) 139 | exit(0) 140 | else: 141 | print("Candidate file and Reference are not comparable. Please verify your candidate file.") 142 | exit(-1) 143 | 144 | if __name__ == "__main__": 145 | if len(sys.argv) != 3: 146 | print("Usage:evaluate.py <candidate_file> <reference_file>") 147 | exit(-1) 148 | else: 149 | main(sys.argv[1], sys.argv[2]) -------------------------------------------------------------------------------- /SMART-KPE/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_parser(): 4 | parser = argparse.ArgumentParser() 5 | # IO 6 | parser.add_argument( 7 | "--read_from_cached_features", 8 | action="store_true", 9 | help="Whether to read saved features") 10 | 11 | parser.add_argument( 12 | "--cached_features_dir", 13 | default="./features", 14 | help="The input data dir (cached)") 15 | 16 | parser.add_argument( 17 | "--data_dir", 18 | default="./data", 19 | help="The input data dir (.jsonl)") 20 | 21 | parser.add_argument( 22 | "--from_checkpoint", 23 | default=None, 24 | help="The directory to load the model checkpoint") 25 | 26 | parser.add_argument( 27 | "--output_dir", 28 | default="./output", 29 | help="The output dir") 30 | 31 | parser.add_argument( 32 | "--print_dir", 33 | default="./output", 34 | help="The printing dir") 35 | 36 | parser.add_argument( 37 | "--meta_dir", 38 | default="./meta_data", 39 | help="The metadata dir.") 40 | 41 | parser.add_argument( 42 | "--use_snapshot", 43 | action="store_true", 44 | help="Whether to use the snapshot feature") 45 | 46 | parser.add_argument( 47 | "--include_title", 48 | action="store_true", 49 | help="Whether to use titles") 50 | 51 | parser.add_argument( 52 | "--snapshot_dim", 53 | type=int, default=512, 54 | help="The dimension of snapshot vectors") 55 | 56 | ''' 57 | parser.add_argument( 58 | "--elmo_option_file", 59 | default="None", 60 | help="ELMO option file") 61 | 62 | parser.add_argument( 63 | "--elmo_weight_file", 64 | default="None", 65 | help="ELMO weight file") 66 | 67 | parser.add_argument( 68 | "--elmo_finetune", 69 | action="store_true", 70 | help="ELMO weights") 71 | 72 | parser.add_argument( 73 | "--elmo_layernorm", 74 | action="store_true", 75 | help="ELMO") 76 | ''' 77 | parser.add_argument( 78 | "--train", 79 | action='store_true', 80 | help="whether to train") 81 | parser.set_defaults(train=False) 82 | 83 | parser.add_argument( 84 | "--dev", 85 | action='store_true', 86 | help="whether to evaluate") 87 | parser.set_defaults(dev=False) 88 | 89 | parser.add_argument( 90 | "--test", 91 | action='store_true', 92 | help="whether to test") 93 | parser.set_defaults(test=False) 94 | 95 | parser.add_argument( 96 | "--only-predictions", 97 | dest='only_pred', 98 | action='store_true', 99 | help="only present the prediction results; do not evaluate precision and recall against the gold labels.") 100 | parser.set_defaults(only_pred=False) 101 | 102 | # Model
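# A minimal illustrative invocation combining the IO flags above with the
# model flags below (the driver script name is hypothetical; flag names and
# defaults follow get_parser):
#
#     python train_driver.py --train --data_dir ./data --output_dir ./output \
#         --batch_size 32 --learning_rate 1e-3 --save_best --main_metric P@3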
103 | parser.add_argument( 104 | "--device", 105 | default=None, 106 | help="Device to run on, e.g. cuda or cpu") 107 | 108 | parser.add_argument( 109 | "--learning_rate", 110 | type=float, default=1e-3, 111 | help="Learning rate") 112 | 113 | parser.add_argument( 114 | "--max_grad_norm", 115 | type=float, default=1., 116 | help="max gradient norm used for clipping") 117 | 118 | parser.add_argument( 119 | "--num_train_epochs", 120 | type=int, default=1, 121 | help="Number of epochs") 122 | 123 | parser.add_argument( 124 | "--batch_size", 125 | type=int, default=32, 126 | help="Batch size") 127 | 128 | parser.add_argument( 129 | "--max_text_length", 130 | type=int, default=256, 131 | help="Maximum text sequence length") 132 | 133 | parser.add_argument( 134 | "--logging_steps", 135 | type=int, default=2000, 136 | help="Steps between logging") 137 | 138 | parser.add_argument( 139 | "--evaluate_during_training", 140 | action="store_true", 141 | help="whether to evaluate the model during training") 142 | parser.set_defaults(evaluate_during_training=False) 143 | 144 | parser.add_argument( 145 | "--filter_predicted_kp", 146 | action="store_true", 147 | help="whether to filter predicted keyphrases") 148 | 149 | parser.add_argument( 150 | "--save_steps", 151 | type=int, default=2000, 152 | help="Steps between model saves") 153 | 154 | parser.add_argument( 155 | "--save_best", 156 | action='store_true', 157 | help="store the best model checkpoint according to main_metric") 158 | parser.set_defaults(save_best=False) 159 | 160 | parser.add_argument( 161 | "--main_metric", 162 | type=str, default="P@3", 163 | help="the main metric to compare for best model saving") 164 | 165 | parser.add_argument( 166 | "--max_steps", 167 | type=int, default=-1, 168 | help="Max steps per epoch") 169 | 170 | parser.add_argument( 171 | "--evaluate_num", 172 | type=int, default=-1, 173 | help="when set to a positive integer, use a limited number of samples to evaluate during training; this saves much time for debugging and logging.") 174 | 175 | parser.add_argument("--positional_size",type=int, default=256,help="Positional encoding size") 176 | #parser.add_argument("--elmo_size", type=int, default=1024, help="ELMo encoding") 177 | parser.add_argument("--tag_num", type=int, default=3, help="Number of tags.") 178 | parser.add_argument("--bert_size", type=int, default=768, help="BERT encoding size") 179 | parser.add_argument("--visual_size",type=int, default=18,help="Visual feature size") 180 | parser.add_argument("--gradient_accumulation_steps",type=int, default=1,help="Steps for gradient accumulation") 181 | parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay if we apply some.") 182 | parser.add_argument("--num_trans", type=int, default=2, help="number of parallel transformers") 183 | 184 | args = parser.parse_args() 185 | return args 186 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/configuration_ctrl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Salesforce CTRL configuration """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"} 29 | 30 | class CTRLConfig(PretrainedConfig): 31 | """Configuration class to store the configuration of a `CTRLModel`. 32 | 33 | Args: 34 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 35 | n_positions: Number of positional embeddings. 36 | n_ctx: Size of the causal mask (usually same as n_positions). 37 | dff: Size of the inner dimension of the FFN. 38 | n_embd: Dimensionality of the embeddings and hidden states. 39 | n_layer: Number of hidden layers in the Transformer encoder. 40 | n_head: Number of attention heads for each attention layer in 41 | the Transformer encoder. 42 | layer_norm_epsilon: epsilon to use in the layer norm layers 43 | resid_pdrop: The dropout probability for all fully connected 44 | layers in the embeddings, encoder, and pooler. 45 | attn_pdrop: The dropout ratio for the attention 46 | probabilities. 47 | embd_pdrop: The dropout ratio for the embeddings. 48 | initializer_range: The stddev of the truncated_normal_initializer for 49 | initializing all weight matrices. 50 | """ 51 | pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP 52 | 53 | def __init__( 54 | self, 55 | vocab_size_or_config_json_file=246534, 56 | n_positions=256, 57 | n_ctx=256, 58 | n_embd=1280, 59 | dff=8192, 60 | n_layer=48, 61 | n_head=16, 62 | resid_pdrop=0.1, 63 | embd_pdrop=0.1, 64 | attn_pdrop=0.1, 65 | layer_norm_epsilon=1e-6, 66 | initializer_range=0.02, 67 | 68 | num_labels=1, 69 | summary_type='cls_index', 70 | summary_use_proj=True, 71 | summary_activation=None, 72 | summary_proj_to_labels=True, 73 | summary_first_dropout=0.1, 74 | **kwargs 75 | ): 76 | """Constructs CTRLConfig. 77 | 78 | Args: 79 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 80 | n_positions: Number of positional embeddings. 81 | n_ctx: Size of the causal mask (usually same as n_positions). 82 | dff: Size of the inner dimension of the FFN. 83 | n_embd: Dimensionality of the embeddings and hidden states. 84 | n_layer: Number of hidden layers in the Transformer encoder. 85 | n_head: Number of attention heads for each attention layer in 86 | the Transformer encoder. 87 | layer_norm_epsilon: epsilon to use in the layer norm layers 88 | resid_pdrop: The dropout probability for all fully connected 89 | layers in the embeddings, encoder, and pooler. 90 | attn_pdrop: The dropout ratio for the attention 91 | probabilities. 92 | embd_pdrop: The dropout ratio for the embeddings. 93 | initializer_range: The stddev of the truncated_normal_initializer for 94 | initializing all weight matrices.
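        A minimal usage sketch (illustrative; the values are the constructor
        defaults above, exposed through the alias properties defined below)::

            config = CTRLConfig()
            assert config.hidden_size == 1280             # alias for n_embd
            assert config.max_position_embeddings == 256  # alias for n_positions
            assert config.num_hidden_layers == 48         # alias for n_layer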
95 | """ 96 | super(CTRLConfig, self).__init__(**kwargs) 97 | 98 | self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 99 | self.n_ctx = n_ctx 100 | self.n_positions = n_positions 101 | self.n_embd = n_embd 102 | self.n_layer = n_layer 103 | self.n_head = n_head 104 | self.dff = dff 105 | self.resid_pdrop = resid_pdrop 106 | self.embd_pdrop = embd_pdrop 107 | self.attn_pdrop = attn_pdrop 108 | self.layer_norm_epsilon = layer_norm_epsilon 109 | self.initializer_range = initializer_range 110 | 111 | self.num_labels = num_labels 112 | self.summary_type = summary_type 113 | self.summary_use_proj = summary_use_proj 114 | self.summary_activation = summary_activation 115 | self.summary_first_dropout = summary_first_dropout 116 | self.summary_proj_to_labels = summary_proj_to_labels 117 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 118 | and isinstance(vocab_size_or_config_json_file, unicode)): 119 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 120 | json_config = json.loads(reader.read()) 121 | for key, value in json_config.items(): 122 | self.__dict__[key] = value 123 | elif not isinstance(vocab_size_or_config_json_file, int): 124 | raise ValueError( 125 | "First argument must be either a vocabulary size (int)" 126 | "or the path to a pretrained model config file (str)" 127 | ) 128 | 129 | @property 130 | def max_position_embeddings(self): 131 | return self.n_positions 132 | 133 | @property 134 | def hidden_size(self): 135 | return self.n_embd 136 | 137 | @property 138 | def num_attention_heads(self): 139 | return self.n_head 140 | 141 | @property 142 | def num_hidden_layers(self): 143 | return self.n_layer 144 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ OpenAI GPT-2 configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", 32 | "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", 33 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",} 34 | 35 | class GPT2Config(PretrainedConfig): 36 | """Configuration class to store the configuration of a `GPT2Model`. 37 | 38 | Args: 39 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 40 | n_positions: Number of positional embeddings. 41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | layer_norm_epsilon: epsilon to use in the layer norm layers 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 52 | initializer_range: The sttdev of the truncated_normal_initializer for 53 | initializing all weight matrices. 54 | """ 55 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 56 | 57 | def __init__( 58 | self, 59 | vocab_size_or_config_json_file=50257, 60 | n_positions=1024, 61 | n_ctx=1024, 62 | n_embd=768, 63 | n_layer=12, 64 | n_head=12, 65 | resid_pdrop=0.1, 66 | embd_pdrop=0.1, 67 | attn_pdrop=0.1, 68 | layer_norm_epsilon=1e-5, 69 | initializer_range=0.02, 70 | 71 | num_labels=1, 72 | summary_type='cls_index', 73 | summary_use_proj=True, 74 | summary_activation=None, 75 | summary_proj_to_labels=True, 76 | summary_first_dropout=0.1, 77 | **kwargs 78 | ): 79 | """Constructs GPT2Config. 80 | 81 | Args: 82 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 83 | n_positions: Number of positional embeddings. 84 | n_ctx: Size of the causal mask (usually same as n_positions). 85 | n_embd: Dimensionality of the embeddings and hidden states. 86 | n_layer: Number of hidden layers in the Transformer encoder. 87 | n_head: Number of attention heads for each attention layer in 88 | the Transformer encoder. 89 | layer_norm_epsilon: epsilon to use in the layer norm layers 90 | resid_pdrop: The dropout probabilitiy for all fully connected 91 | layers in the embeddings, encoder, and pooler. 92 | attn_pdrop: The dropout ratio for the attention 93 | probabilities. 94 | embd_pdrop: The dropout ratio for the embeddings. 95 | initializer_range: The sttdev of the truncated_normal_initializer for 96 | initializing all weight matrices. 
97 | """ 98 | super(GPT2Config, self).__init__(**kwargs) 99 | 100 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 101 | and isinstance(vocab_size_or_config_json_file, unicode)): 102 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 103 | json_config = json.loads(reader.read()) 104 | for key, value in json_config.items(): 105 | self.__dict__[key] = value 106 | elif isinstance(vocab_size_or_config_json_file, int): 107 | self.vocab_size = vocab_size_or_config_json_file 108 | self.n_ctx = n_ctx 109 | self.n_positions = n_positions 110 | self.n_embd = n_embd 111 | self.n_layer = n_layer 112 | self.n_head = n_head 113 | self.resid_pdrop = resid_pdrop 114 | self.embd_pdrop = embd_pdrop 115 | self.attn_pdrop = attn_pdrop 116 | self.layer_norm_epsilon = layer_norm_epsilon 117 | self.initializer_range = initializer_range 118 | 119 | self.num_labels = num_labels 120 | self.summary_type = summary_type 121 | self.summary_use_proj = summary_use_proj 122 | self.summary_activation = summary_activation 123 | self.summary_first_dropout = summary_first_dropout 124 | self.summary_proj_to_labels = summary_proj_to_labels 125 | else: 126 | raise ValueError( 127 | "First argument must be either a vocabulary size (int)" 128 | "or the path to a pretrained model config file (str)" 129 | ) 130 | 131 | @property 132 | def max_position_embeddings(self): 133 | return self.n_positions 134 | 135 | @property 136 | def hidden_size(self): 137 | return self.n_embd 138 | 139 | @property 140 | def num_attention_heads(self): 141 | return self.n_head 142 | 143 | @property 144 | def num_hidden_layers(self): 145 | return self.n_layer 146 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/networks/Bert2Span.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import logging 4 | import numpy as np 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn import NLLLoss 8 | from ..transformers import BertPreTrainedModel, BertModel 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | # ------------------------------------------------------------------------------------------- 15 | # SelfAtttion Extractor 16 | # ------------------------------------------------------------------------------------------- 17 | class SpanAttention(nn.Module): 18 | def __init__(self, hidden_size): 19 | super(SpanAttention, self).__init__() 20 | 21 | self.hidden_size = hidden_size 22 | self.query_layer = nn.Linear(hidden_size, hidden_size) 23 | self.key_layer = nn.Linear(hidden_size, hidden_size) 24 | 25 | def forward(self, hidden_states, active_mask): 26 | '''hidden_states and active_mask for word_level''' 27 | attention_mask, tril_mask = self.create_mask(active_mask, hidden_states.size(1)) 28 | 29 | query = self.query_layer(hidden_states) 30 | key = self.key_layer(hidden_states) 31 | 32 | # Take the dot product between "query" and "key" to get the raw attention scores. 33 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) 34 | attention_scores = attention_scores / math.sqrt(self.hidden_size) 35 | 36 | attention_scores = attention_scores + attention_mask + tril_mask 37 | 38 | # Normalize the attention scores to probabilities. 
39 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 40 | return attention_probs 41 | 42 | def create_mask(self, active_mask, max_len): 43 | 44 | # extended_mask for padding 45 | extended_active_mask = active_mask[:, None, :] 46 | extended_active_mask = extended_active_mask.to(dtype=next(self.parameters()).dtype) 47 | extended_active_mask = (1.0 - extended_active_mask) * -10000.0 48 | 49 | full_mask = torch.full([max_len, max_len], -1000.0) 50 | tril_mask = full_mask.tril_(-1) 51 | tril_mask = tril_mask.to(next(self.parameters())) 52 | tril_mask = tril_mask[None,:,:] 53 | return extended_active_mask, tril_mask 54 | 55 | 56 | # ------------------------------------------------------------------------------------------- 57 | # Inherit BertPreTrainedModel 58 | # ------------------------------------------------------------------------------------------- 59 | class BertForAttSpanClassification(BertPreTrainedModel): 60 | 61 | def __init__(self, config): 62 | super(BertForAttSpanClassification, self).__init__(config) 63 | self.num_labels = config.num_labels 64 | 65 | self.bert = BertModel(config) 66 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 67 | 68 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 69 | self.self_att_classifier = SpanAttention(config.hidden_size) 70 | 71 | self.init_weights() 72 | 73 | 74 | 75 | # ------------------------------------------------------------------------------------------- 76 | # Bert2AttSpanExtractor 77 | # ------------------------------------------------------------------------------------------- 78 | class BertForAttSpanExtractor(BertForAttSpanClassification): 79 | 80 | def forward(self, input_ids, attention_mask, valid_ids, valid_output, 81 | active_mask, s_label=None, e_label=None, end_mask=None): 82 | 83 | # -------------------------------------------------------------------------------- 84 | # Bert Embedding Outputs 85 | outputs = self.bert(input_ids=input_ids, 86 | attention_mask=attention_mask) 87 | 88 | sequence_output = outputs[0] 89 | 90 | # -------------------------------------------------------------------------------- 91 | # Valid Outputs : get first token vector 92 | batch_size = sequence_output.size(0) 93 | for i in range(batch_size): 94 | valid_num = sum(valid_ids[i]).item() 95 | 96 | vectors = sequence_output[i][valid_ids[i] == 1] 97 | valid_output[i, :valid_num].copy_(vectors) 98 | 99 | # -------------------------------------------------------------------------------- 100 | # Dropout 101 | sequence_output = self.dropout(valid_output) # shape = (batch_size * max_word_length * 768) 102 | 103 | # start softmax logit 104 | s_logits = self.classifier(sequence_output) 105 | s_logits = F.softmax(s_logits, dim=-1) 106 | 107 | # end softmax logit 108 | e_logits = self.self_att_classifier(hidden_states=sequence_output, 109 | active_mask=active_mask) # shape = (batch_size, max_word_len, max_word_len) 110 | 111 | s_active_loss = active_mask.view(-1) == 1 # [False, True, ...] 
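        # active_mask is 1 at real word positions and 0 at padding, so this
        # boolean view selects only the word slots that should contribute to
        # the start-position loss (and to the returned logits at inference).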
112 | s_active_logits = s_logits.view(-1, self.num_labels)[s_active_loss] # keep only valid positions 113 | 114 | if (s_label is not None) and (e_label is not None): 115 | 116 | loss_fct = NLLLoss() 117 | 118 | # -------------------------------------------------------- 119 | # Start Loss : log 120 | s_active_logits = torch.log(s_active_logits + 1e-16) 121 | # Start Loss : activate label 122 | s_active_labels = s_label.view(-1)[s_active_loss] 123 | # Start Loss : final start loss 124 | start_loss = loss_fct(s_active_logits, s_active_labels) 125 | 126 | # -------------------------------------------------------- 127 | # End Loss : log (+ 1e-16 prevents -inf) 128 | e_logits = torch.log(e_logits + 1e-16) 129 | # End Loss : activate end loss from s_label 130 | e_active_loss = s_label.view(-1) == 1 131 | e_active_logits = e_logits.view(-1, e_logits.shape[1])[e_active_loss] 132 | 133 | # -------------------------------------------------------- 134 | # End Loss : activate end label 135 | e_label_valid_ids = end_mask.view(-1) == 1 136 | e_activate_labels = e_label.view(-1)[e_label_valid_ids] 137 | end_loss = loss_fct(e_active_logits, e_activate_labels) 138 | 139 | # -------------------------------------------------------- 140 | # total loss 141 | total_loss = start_loss + end_loss # (start_loss + end_loss) / 2 142 | return total_loss 143 | else: 144 | e_active_logits = e_logits.view(-1, e_logits.shape[1])[s_active_loss] 145 | s_active_logits = s_active_logits[:,1] 146 | return s_active_logits, e_active_logits 147 | -------------------------------------------------------------------------------- /BERT-KPE-based/bertkpe/transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | import pytest 22 | 23 | from transformers import is_torch_available 24 | 25 | if is_torch_available(): 26 | import torch 27 | 28 | from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 29 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 30 | else: 31 | pytestmark = pytest.mark.skip("Require Torch") 32 | 33 | from .tokenization_tests_commons import TemporaryDirectory 34 | 35 | 36 | def unwrap_schedule(scheduler, num_steps=10): 37 | lrs = [] 38 | for _ in range(num_steps): 39 | scheduler.step() 40 | lrs.append(scheduler.get_lr()) 41 | return lrs 42 | 43 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 44 | lrs = [] 45 | for step in range(num_steps): 46 | scheduler.step() 47 | lrs.append(scheduler.get_lr()) 48 | if step == num_steps // 2: 49 | with TemporaryDirectory() as tmpdirname: 50 | file_name = os.path.join(tmpdirname, 'schedule.bin') 51 | torch.save(scheduler.state_dict(), file_name) 52 | 53 | state_dict = torch.load(file_name) 54 | scheduler.load_state_dict(state_dict) 55 | return lrs 56 | 57 | class OptimizationTest(unittest.TestCase): 58 | 59 | def assertListAlmostEqual(self, list1, list2, tol): 60 | self.assertEqual(len(list1), len(list2)) 61 | for a, b in zip(list1, list2): 62 | self.assertAlmostEqual(a, b, delta=tol) 63 | 64 | def test_adam_w(self): 65 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 66 | target = torch.tensor([0.4, 0.2, -0.5]) 67 | criterion = torch.nn.MSELoss() 68 | # No warmup, constant schedule, no gradient clipping 69 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 70 | for _ in range(100): 71 | loss = criterion(w, target) 72 | loss.backward() 73 | optimizer.step() 74 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 75 | w.grad.zero_() 76 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 77 | 78 | 79 | class ScheduleInitTest(unittest.TestCase): 80 | m = torch.nn.Linear(50, 50) if is_torch_available() else None 81 | optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None 82 | num_steps = 10 83 | 84 | def assertListAlmostEqual(self, list1, list2, tol): 85 | self.assertEqual(len(list1), len(list2)) 86 | for a, b in zip(list1, list2): 87 | self.assertAlmostEqual(a, b, delta=tol) 88 | 89 | def test_constant_scheduler(self): 90 | scheduler = ConstantLRSchedule(self.optimizer) 91 | lrs = unwrap_schedule(scheduler, self.num_steps) 92 | expected_learning_rates = [10.] 
* self.num_steps 93 | self.assertEqual(len(lrs[0]), 1) 94 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 95 | 96 | scheduler = ConstantLRSchedule(self.optimizer) 97 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 98 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 99 | 100 | def test_warmup_constant_scheduler(self): 101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 102 | lrs = unwrap_schedule(scheduler, self.num_steps) 103 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 104 | self.assertEqual(len(lrs[0]), 1) 105 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 106 | 107 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 108 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 109 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 110 | 111 | def test_warmup_linear_scheduler(self): 112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 113 | lrs = unwrap_schedule(scheduler, self.num_steps) 114 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 115 | self.assertEqual(len(lrs[0]), 1) 116 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 117 | 118 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 119 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 120 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 121 | 122 | def test_warmup_cosine_scheduler(self): 123 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 124 | lrs = unwrap_schedule(scheduler, self.num_steps) 125 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 126 | self.assertEqual(len(lrs[0]), 1) 127 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 128 | 129 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 130 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 131 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 132 | 133 | def test_warmup_cosine_hard_restart_scheduler(self): 134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 135 | lrs = unwrap_schedule(scheduler, self.num_steps) 136 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 137 | self.assertEqual(len(lrs[0]), 1) 138 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 139 | 140 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 141 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 142 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 143 | 144 | if __name__ == "__main__": 145 | unittest.main() 146 | --------------------------------------------------------------------------------
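A minimal, dependency-free sketch of the linear-warmup schedule that the expected values in test_warmup_linear_scheduler above encode (warmup_linear_lr is an illustrative helper, not part of this repository): the learning rate ramps up linearly over warmup_steps, then decays linearly to zero at t_total.

    def warmup_linear_lr(step, base_lr=10.0, warmup_steps=2, t_total=10):
        # Linear ramp-up during warmup, then linear decay to zero.
        if step < warmup_steps:
            return base_lr * step / max(1, warmup_steps)
        return base_lr * max(0.0, (t_total - step) / max(1.0, t_total - warmup_steps))

    # Reproduces the expected_learning_rates asserted in the test above.
    assert [warmup_linear_lr(s) for s in range(1, 11)] == \
        [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]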