├── pics ├── .gitignore ├── arch.png ├── adaptor.png ├── banner.png ├── hfl_qrcode.jpg ├── nav_banner.png ├── trainingloop.png ├── distillation_workflow.png ├── distillation_workflow2.png └── distillation_workflow_en.png ├── MANIFEST.in ├── examples ├── .gitignore ├── mnli_example │ ├── jsons │ │ ├── bert_config_L4t.json │ │ ├── TrainBertTeacher.json │ │ ├── DistillBertToTiny.json │ │ └── DistillMultiBertToTiny.json │ ├── run_mnli_train.sh │ ├── run_mnli_distill_T4tiny_emd.sh │ ├── run_mnli_distill_multiteacher.sh │ ├── run_mnli_distill_T4tiny.sh │ ├── README_ZH.md │ ├── parse.py │ ├── README.md │ ├── modeling.py │ ├── config.py │ ├── utils.py │ └── predict_function.py ├── random_tokens_example │ ├── bert_config │ │ ├── bert_config.json │ │ └── bert_config_T3.json │ ├── README.md │ └── distill.py ├── student_config │ ├── bert_base_cased_config │ │ ├── bert_config.json │ │ ├── bert_config_L3.json │ │ ├── bert_config_L3n.json │ │ ├── bert_config_L4t.json │ │ └── bert_config_L6.json │ └── roberta_wwm_config │ │ ├── bert_config_L4t.json │ │ ├── bert_config.json │ │ ├── bert_config_L3.json │ │ └── bert_config_L3n.json ├── conll2003_example │ ├── run_conll2003_train.sh │ ├── run_conll2003_distill_T3.sh │ ├── README_ZH.md │ └── README.md ├── cmrc2018_example │ ├── README_ZH.md │ ├── README.md │ ├── pytorch_pretrained_bert │ │ ├── __init__.py │ │ ├── convert_tf_checkpoint_to_pytorch.py │ │ ├── convert_gpt2_checkpoint_to_pytorch.py │ │ ├── convert_openai_checkpoint_to_pytorch.py │ │ ├── __main__.py │ │ ├── optimization_openai.py │ │ └── convert_transfo_xl_checkpoint_to_pytorch.py │ ├── run_cmrc2018_train.sh │ ├── run_cmrc2018_distill_T4tiny.sh │ ├── run_cmrc2018_distill_T3.sh │ ├── utils.py │ ├── cmrc2018_evaluate.py │ └── train_eval.py └── msra_ner_example │ ├── README_ZH.md │ ├── ner_ElectraTrain_dist.sh │ ├── README.md │ ├── ner_ElectraDistill_dist.sh │ ├── utils.py │ ├── modeling.py │ ├── utils_ner.py │ └── config.py ├── docs ├── source │ ├── _build │ │ └── html │ │ │ ├── _static │ │ │ ├── custom.css │ │ │ ├── fonts │ │ │ │ ├── Lato-Bold.ttf │ │ │ │ ├── Inconsolata.ttf │ │ │ │ ├── Lato-Regular.ttf │ │ │ │ ├── Lato │ │ │ │ │ ├── lato-bold.eot │ │ │ │ │ ├── lato-bold.ttf │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-italic.eot │ │ │ │ │ ├── lato-italic.ttf │ │ │ │ │ ├── lato-italic.woff │ │ │ │ │ ├── lato-regular.eot │ │ │ │ │ ├── lato-regular.ttf │ │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ │ ├── lato-italic.woff2 │ │ │ │ │ ├── lato-regular.woff │ │ │ │ │ ├── lato-regular.woff2 │ │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ │ └── lato-bolditalic.woff2 │ │ │ │ ├── RobotoSlab-Bold.ttf │ │ │ │ ├── Inconsolata-Bold.ttf │ │ │ │ ├── Inconsolata-Regular.ttf │ │ │ │ ├── RobotoSlab-Regular.ttf │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ └── RobotoSlab │ │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ │ └── roboto-slab-v7-regular.woff2 │ │ │ ├── documentation_options.js │ │ │ ├── css │ │ │ │ └── badge_only.css │ │ │ ├── js │ │ │ │ └── theme.js │ │ │ └── pygments.css │ │ │ ├── objects.inv │ │ │ ├── .doctrees │ │ │ ├── index.doctree │ │ │ ├── environment.pickle │ │ │ ├── GettingStarted.doctree │ │ │ └── 
Quickstart copy.doctree │ │ │ ├── .buildinfo │ │ │ ├── _sources │ │ │ └── index.rst.txt │ │ │ ├── genindex.html │ │ │ ├── search.html │ │ │ └── searchindex.js │ ├── _static │ │ └── css │ │ │ └── custom.css │ ├── Configurations.rst │ ├── Utils.rst │ ├── Distillers.rst │ ├── Losses.rst │ ├── Presets.rst │ ├── index.rst │ ├── conf.py │ └── ExperimentResults.md ├── requirements.txt └── Makefile ├── .gitignore ├── src └── textbrewer │ ├── distillers.py │ ├── compatibility.py │ ├── __init__.py │ ├── projections.py │ ├── schedulers.py │ ├── snippet.py │ ├── data_utils.py │ ├── utils.py │ ├── presets.py │ └── distiller_multiteacher.py ├── .github └── stale.yml ├── CONTRIBUTING.md ├── setup.py └── CODE_OF_CONDUCT.md /pics/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints 3 | -------------------------------------------------------------------------------- /pics/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/arch.png -------------------------------------------------------------------------------- /docs/source/_build/html/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* This file intentionally left blank. */ 2 | -------------------------------------------------------------------------------- /pics/adaptor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/adaptor.png -------------------------------------------------------------------------------- /pics/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/banner.png -------------------------------------------------------------------------------- /pics/hfl_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/hfl_qrcode.jpg -------------------------------------------------------------------------------- /pics/nav_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/nav_banner.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==2.4.4 2 | sphinx_markdown_tables 3 | sphinxcontrib.napoleon 4 | 5 | -------------------------------------------------------------------------------- /pics/trainingloop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/trainingloop.png -------------------------------------------------------------------------------- /pics/distillation_workflow.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/distillation_workflow.png -------------------------------------------------------------------------------- /pics/distillation_workflow2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/distillation_workflow2.png -------------------------------------------------------------------------------- /pics/distillation_workflow_en.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/pics/distillation_workflow_en.png -------------------------------------------------------------------------------- /docs/source/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/.doctrees/index.doctree -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | docs/build 3 | .vscode 4 | dist/ 5 | *.egg-info/ 6 | .ipynb_checkpoints 7 | .DS_Store 8 | .idea 9 | *.pdf 10 | *.png 11 | -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/.doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Inconsolata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Inconsolata.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/GettingStarted.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/.doctrees/GettingStarted.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/.doctrees/Quickstart copy.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/.doctrees/Quickstart copy.doctree -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab-Bold.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Inconsolata-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-regular.eot 
-------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Inconsolata-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- 
/docs/source/_build/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | .wy-side-nav-search{ background-color: rgba(47, 85, 153, 0.69) } 2 | 3 | .wy-side-nav-search > div.version { 4 | color:rgba(255, 255, 255, 0.5) 5 | } 6 | 7 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot 
-------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airaria/TextBrewer/HEAD/docs/source/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/source/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: b8d40d925d79662c88241bdbb65bd991 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /src/textbrewer/distillers.py: -------------------------------------------------------------------------------- 1 | from .distiller_train import BasicTrainer 2 | from .distiller_basic import BasicDistiller 3 | from .distiller_general import GeneralDistiller 4 | from .distiller_multitask import MultiTaskDistiller 5 | from .distiller_multiteacher import MultiTeacherDistiller 6 | -------------------------------------------------------------------------------- /docs/source/Configurations.rst: -------------------------------------------------------------------------------- 1 | Configurations 2 | ============== 3 | 4 | TrainingConfig 5 | -------------- 6 | 7 | .. autoclass:: textbrewer.TrainingConfig 8 | :members: from_json_file, from_dict 9 | 10 | DistillationConfig 11 | ------------------ 12 | 13 | .. 
autoclass:: textbrewer.DistillationConfig 14 | :members: from_json_file, from_dict -------------------------------------------------------------------------------- /src/textbrewer/compatibility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if torch.__version__ < '1.2': 4 | mask_dtype = torch.uint8 5 | else: 6 | mask_dtype = torch.bool 7 | 8 | def is_apex_available(): 9 | try: 10 | from apex import amp 11 | _has_apex = True 12 | except ImportError: 13 | _has_apex = False 14 | return _has_apex -------------------------------------------------------------------------------- /examples/mnli_example/jsons/bert_config_L4t.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 312, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1200, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '0.1.8', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | HAS_SOURCE: true, 9 | SOURCELINK_SUFFIX: '.txt', 10 | NAVIGATION_WITH_KEYS: false 11 | }; -------------------------------------------------------------------------------- /examples/random_tokens_example/bert_config/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /examples/random_tokens_example/bert_config/bert_config_T3.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 3, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /examples/student_config/bert_base_cased_config/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- 
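The BERT config JSONs in `examples/` (such as the one above and those that follow) define the teacher and student architectures used in the examples. A minimal sketch, assuming the `transformers` library is installed and using an illustrative path, of turning one of these files into a randomly initialized student model:

```python
from transformers import BertConfig, BertForSequenceClassification

# Load a student architecture definition (path is illustrative)
student_config = BertConfig.from_json_file(
    "examples/student_config/bert_base_cased_config/bert_config_L3.json")

# Randomly initialized student; its weights are learned from the teacher
# during distillation (or loaded from a pretrained checkpoint if available)
student_model = BertForSequenceClassification(student_config)
print(student_model.config.num_hidden_layers)  # 3
```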
/examples/student_config/bert_base_cased_config/bert_config_L3.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 3, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /examples/student_config/bert_base_cased_config/bert_config_L3n.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 384, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1536, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 3, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /examples/student_config/bert_base_cased_config/bert_config_L4t.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 312, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1200, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /examples/student_config/bert_base_cased_config/bert_config_L6.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 6, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /examples/student_config/roberta_wwm_config/bert_config_L4t.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 312, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1200, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 21128 13 | } 14 | -------------------------------------------------------------------------------- /examples/mnli_example/jsons/TrainBertTeacher.json: -------------------------------------------------------------------------------- 1 | { 2 | "teachers":[], 3 | "student":{ 4 | "model_type":"bert", 5 | "prefix":"bertuncased", 6 | "vocab_file":"/path/to/vocab.txt", 7 | "config_file":"/path/to/config.json", 8 | "checkpoint":"/path/to/bert/original/weight(pytorch_model.bin)", 9 | "tokenizer_kwargs":{"do_lower_case": true}, 10 | "disable":false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /examples/conll2003_example/run_conll2003_train.sh: 
-------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=outputs-model-base 2 | export BATCH_SIZE=32 3 | export NUM_EPOCHS=3 4 | export SAVE_STEPS=750 5 | export SEED=42 6 | MAX_LENGTH=128 7 | export BERT_MODEL=/path/to/bert/model 8 | python run_ner.py \ 9 | --data_dir data \ 10 | --model_type bert \ 11 | --model_name_or_path $BERT_MODEL \ 12 | --output_dir $OUTPUT_DIR \ 13 | --max_seq_length $MAX_LENGTH \ 14 | --num_train_epochs $NUM_EPOCHS \ 15 | --per_gpu_train_batch_size $BATCH_SIZE \ 16 | --save_steps $SAVE_STEPS \ 17 | --seed $SEED \ 18 | --do_train \ 19 | --do_eval \ 20 | --do_predict 21 | -------------------------------------------------------------------------------- /examples/student_config/roberta_wwm_config/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /examples/student_config/roberta_wwm_config/bert_config_L3.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 3, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /examples/student_config/roberta_wwm_config/bert_config_L3n.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 384, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1536, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 3, 12 | "pooler_fc_size": 384, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /examples/conll2003_example/run_conll2003_distill_T3.sh: -------------------------------------------------------------------------------- 1 | export OUTPUT_DIR=output-model-T3 2 | export BATCH_SIZE=32 3 | export NUM_EPOCHS=2 4 | export SAVE_STEPS=750 5 | export SEED=42 6 | export MAX_LENGTH=128 7 | export BERT_MODEL_TEACHER=path/to/teacher/dir 8 | python run_ner_distill.py \ 9 | --data_dir data \ 10 | --model_type bert \ 11 | --model_name_or_path $BERT_MODEL_TEACHER \ 12 | --output_dir $OUTPUT_DIR \ 13 | --max_seq_length $MAX_LENGTH \ 14 | --num_train_epochs $NUM_EPOCHS \ 15 | --per_gpu_train_batch_size $BATCH_SIZE \ 16 | --num_hidden_layers 3 \ 17 | --save_steps $SAVE_STEPS \ 18 | --learning_rate 1e-4 \ 19 | --warmup_steps 0.1 \ 20 | --seed $SEED \ 21 | --do_distill \ 22 | --do_train \ 23 | --do_eval \ 24 | --do_predict 25 | -------------------------------------------------------------------------------- /examples/random_tokens_example/README.md: -------------------------------------------------------------------------------- 1 | # Simple Example 2 | 3 | This runnable example demonstrates the usage of TextBrewer. 4 | 5 | The teacher is BERT-base. The student is a 3-layer BERT. 6 | 7 | The task is text classification. We generate random token ids and labels as inputs. 8 | 9 | This simple example is for pedagogical purposes only. 10 | 11 | We also print a summary of the model parameters using the utility provided by the toolkit, as sketched below.
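A minimal sketch of printing that summary, assuming `textbrewer` and `transformers` are installed (the exact return value of `display_parameters` may differ across versions):

```python
import textbrewer
from transformers import BertConfig, BertForSequenceClassification

# Build the randomly initialized 3-layer student used in this example
# (the config file is the one shipped under bert_config/)
config = BertConfig.from_json_file("bert_config/bert_config_T3.json")
student_model = BertForSequenceClassification(config)

# Print a per-module parameter summary; max_level limits how deep
# the module hierarchy is expanded.
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)
```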
12 | 13 | 14 | ## Requirements 15 | 16 | PyTorch >= 1.0 17 | 18 | transformers >= 2.0 19 | 20 | tensorboard 21 | 22 | ## Run 23 | ```python 24 | python distill.py 25 | ``` 26 | 27 | ## Screenshots 28 | 29 | ![screenshot1](screenshots/screenshot1.png) 30 | 31 | ![screenshot2](screenshots/screenshot2.png) -------------------------------------------------------------------------------- /examples/conll2003_example/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | 这个例子展示CoNLL-2003英文NER任务上的蒸馏。 4 | 5 | * run_conll2003_train.sh : 在CoNLL-2003英文NER数据集上训练教师模型(BERT-base-cased) 6 | * run_conll2003_distill_T3.sh : 在CoNLL-2003英文NER数据集上蒸馏教师模型到T3(三层BERT) 7 | 8 | 运行要求 9 | 10 | * Transformers 11 | * seqeval 12 | 13 | 运行脚本前,请根据自己的环境设置相应变量: 14 | 15 | * BERT_MODEL : 存放BERT-base模型的目录,包含vocab.txt, pytorch_model.bin, config.json 16 | * OUTPUT_DIR : 存放训练好的模型权重文件 17 | * BERT_MODEL_TEACHER : 存放训练好的teacher模型的目录, 包含vocab.txt, pytorch_model.bin, config.json 18 | * data : 包含CoNLL-2003数据集(包含train.txt,dev.txt,test.txt 三个文件) 19 | 20 | 示例包含两个部分: 21 | 22 | * 训练教师模型:./run_conll2003_train.sh 23 | * 蒸馏三层BERT学生模型:./run_conll2003_distill_T3.sh 24 | -------------------------------------------------------------------------------- /examples/mnli_example/jsons/DistillBertToTiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "teachers":[ 3 | { 4 | "model_type":"bert", 5 | "prefix":"bertuncased", 6 | "vocab_file":"/path/to/vocab.txt", 7 | "config_file":"/path/to/config.json", 8 | "checkpoint":"/path/to/fine-tuned/checkpoint", 9 | "tokenizer_kwargs":{"do_lower_case": true}, 10 | "disable":false 11 | } 12 | ], 13 | "student":{ 14 | "model_type":"bert", 15 | "prefix":"bertuncased", 16 | "vocab_file":"/path/to/vocab.txt", 17 | "config_file":"./jsons/bert_config_L4t.json", 18 | "checkpoint": null, 19 | "tokenizer_kwargs":{"do_lower_case": true}, 20 | "disable":false 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /docs/source/Utils.rst: -------------------------------------------------------------------------------- 1 | Model Utils 2 | ============== 3 | 4 | display_parameters 5 | ------------------ 6 | 7 | .. autofunction:: textbrewer.utils.display_parameters 8 | 9 | 10 | Data Utils 11 | ========== 12 | 13 | This module provides the following data augmentation methods. 14 | 15 | masking 16 | ------- 17 | 18 | .. autofunction:: textbrewer.data_utils.masking 19 | 20 | deleting 21 | -------- 22 | 23 | .. autofunction:: textbrewer.data_utils.deleting 24 | 25 | n_gram_sampling 26 | --------------- 27 | 28 | .. autofunction:: textbrewer.data_utils.n_gram_sampling 29 | 30 | short_disorder 31 | -------------- 32 | 33 | .. autofunction:: textbrewer.data_utils.short_disorder 34 | 35 | long_disorder 36 | ------------- 37 | 38 | .. 
autofunction:: textbrewer.data_utils.long_disorder -------------------------------------------------------------------------------- /src/textbrewer/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.1.post1" 2 | 3 | from .distillers import BasicTrainer 4 | from .distillers import BasicDistiller 5 | from .distillers import GeneralDistiller 6 | from .distillers import MultiTaskDistiller 7 | from .distillers import MultiTeacherDistiller 8 | 9 | 10 | from .configurations import TrainingConfig, DistillationConfig 11 | 12 | from .presets import FEATURES 13 | from .presets import ADAPTOR_KEYS 14 | from .presets import KD_LOSS_MAP, MATCH_LOSS_MAP, PROJ_MAP 15 | from .presets import WEIGHT_SCHEDULER, TEMPERATURE_SCHEDULER 16 | from .presets import register_new 17 | 18 | Distillers = { 19 | 'Basic': BasicDistiller, 20 | 'General': GeneralDistiller, 21 | 'MultiTeacher': MultiTeacherDistiller, 22 | 'MultiTask': MultiTaskDistiller, 23 | 'Train': BasicTrainer 24 | } 25 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 5 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: > 18 | Closing the issue, since no updates observed. 19 | Feel free to re-open if you need any further assistance. 20 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | 这个例子展示CMRC 2018阅读理解任务上的蒸馏,并使用DRCD数据集作为数据增强。 4 | 5 | * run_cmrc2018_train.sh : 在cmrc2018数据集上训练教师模型(roberta-wwm-base) 6 | * run_cmrc2018_distill_T3.sh : 在cmrc2018和drcd数据集上蒸馏教师模型到T3 7 | * run_cmrc2018_distill_T4tiny.sh : 在cmrc2018和drcd数据集上蒸馏教师模型到T4-tiny 8 | 9 | 运行脚本前,请根据自己的环境设置相应变量: 10 | 11 | * BERT_DIR : 存放RoBERTa-wwm-base模型的目录,包含vocab.txt, pytorch_model.bin, bert_config.json 12 | * OUTPUT_ROOT_DIR : 存放训练好的模型权重文件和日志 13 | * DATA_ROOT_DIR : 包含cmrc2018数据集和drcd数据集: 14 | * \$\{DATA_ROOT_DIR\}/cmrc2018/squad-style-data/cmrc2018_train.json 15 | * \$\{DATA_ROOT_DIR\}/cmrc2018/squad-style-data/cmrc2018_dev.json 16 | * \$\{DATA_ROOT_DIR\}/drcd/DRCD_training.json 17 | * 如果是运行 run_cmrc2018_distill_T3.sh 和 run_cmrc2018_distill_T4tiny.sh, 还需要指定训练好的教师模型权重文件 trained_teacher_model 18 | -------------------------------------------------------------------------------- /examples/conll2003_example/README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | This example demonstrates distillation on the CoNLL-2003 English NER task. 4 | 5 | * run_conll2003_train.sh : trains a teacher model (BERT-base-cased) on CoNLL-2003.
6 | * run_conll2003_distill_T3.sh : distills the teacher to T3. 7 | 8 | Requirements 9 | 10 | * Transformers 11 | * seqeval 12 | 13 | Set the following variables in the shell scripts before running: 14 | 15 | * BERT_MODEL : this is where BERT-base-cased stores, including vocab.txt, pytorch_model.bin, config.json 16 | * OUTPUT_DIR : this directory stores model weights 17 | * BERT_MODEL_TEACHER : this directory stores the trained teacher model weights (for distillation). 18 | * data : this directory includes CoNLL-2003 dataset (contains train.txt, dev.txt and test.txt) 19 | 20 | This example contains: 21 | 22 | * Teacher Model training : ./run_conll2003_train.sh 23 | * Distillation to student model T3 : ./run_conll2003_distill_T3.sh 24 | -------------------------------------------------------------------------------- /examples/msra_ner_example/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | 这个例子展示MSRA NER(中文命名实体识别)任务上,在**分布式数据并行训练**(Distributed Data-Parallel, DDP)模式(single node, muliti-GPU)下的[Chinese-ELECTRA-base](https://github.com/ymcui/Chinese-ELECTRA)模型蒸馏。 4 | 5 | 6 | * ner_ElectraTrain_dist.sh : 训练教师模型(ELECTRA-base)。 7 | * ner_ElectraDistill_dist.sh : 将教师模型蒸馏到学生模型(ELECTRA-small)。 8 | 9 | 10 | 运行脚本前,请根据自己的环境设置相应变量: 11 | 12 | * ELECTRA_DIR_BASE : 存放Chinese-ELECTRA-base模型的目录,包含vocab.txt,pytorch_model.bin和config.json。 13 | 14 | * OUTPUT_DIR : 存放训练好的模型权重文件和日志。 15 | * DATA_DIR : MSRA NER数据集目录,包含 16 | * msra_train_bio.txt 17 | * msra_test_bio.txt 18 | 19 | 对于蒸馏,需要设置: 20 | 21 | * ELECTRA_DIR_SMALL : Chinese-ELECTRA-small预训练权重所在目录。应包含pytorch_model.bin。 也可不提供预训练权重,则学生模型将随机初始化。 22 | * student_config_file : 学生模型配置文件,一般文件名为config.json,也位于 $\{ELECTRA_DIR_SMALL\}。 23 | * trained_teacher_model_file : 在MSRA NER任务上训练好的ELECTRA-base教师模型。 24 | 25 | 该脚本在 **PyTorch==1.2, Transformers==2.8** 下测试通过。 -------------------------------------------------------------------------------- /docs/source/Distillers.rst: -------------------------------------------------------------------------------- 1 | .. _distillers: 2 | 3 | Distillers 4 | =========== 5 | 6 | Distillers perform the actual experiments. 7 | 8 | Initialize a distiller object, call its `train` method to start training/distillation. 9 | 10 | BasicDistiller 11 | --------------- 12 | 13 | .. autoclass:: textbrewer.BasicDistiller 14 | :members: train 15 | 16 | GeneralDistiller 17 | ----------------- 18 | 19 | .. autoclass:: textbrewer.GeneralDistiller 20 | :members: train 21 | 22 | MultiTeacherDistiller 23 | --------------------- 24 | 25 | .. autoclass:: textbrewer.MultiTeacherDistiller 26 | 27 | .. method:: train(self, optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args) 28 | 29 | trains the student model. See :meth:`BasicDistiller.train`. 30 | 31 | MultiTaskDistiller 32 | ------------------ 33 | .. autoclass:: textbrewer.MultiTaskDistiller 34 | :members: train 35 | 36 | BasicTrainer 37 | ------------ 38 | .. autoclass:: textbrewer.BasicTrainer 39 | :members: train -------------------------------------------------------------------------------- /docs/source/Losses.rst: -------------------------------------------------------------------------------- 1 | .. _intermediate_losses: 2 | 3 | Intermediate Losses 4 | =================== 5 | Here we list the definitions of pre-defined intermediate losses. 
6 | Usually, users don't need to refer to these functions directly, but refer to them by the names in :obj:`MATCH_LOSS_MAP`. 7 | 8 | attention_mse 9 | ------------- 10 | .. autofunction:: textbrewer.losses.att_mse_loss 11 | 12 | attention_mse_sum 13 | ----------------- 14 | .. autofunction:: textbrewer.losses.att_mse_sum_loss 15 | 16 | attention_ce 17 | ----------------- 18 | .. autofunction:: textbrewer.losses.att_ce_loss 19 | 20 | attention_ce_mean 21 | ----------------- 22 | .. autofunction:: textbrewer.losses.att_ce_mean_loss 23 | 24 | hidden_mse 25 | ---------- 26 | .. autofunction:: textbrewer.losses.hid_mse_loss 27 | 28 | cos 29 | --- 30 | .. autofunction:: textbrewer.losses.cos_loss 31 | 32 | pkd 33 | --- 34 | .. autofunction:: textbrewer.losses.pkd_loss 35 | 36 | nst (mmd) 37 | --------- 38 | .. autofunction:: textbrewer.losses.mmd_loss 39 | 40 | 41 | fsp (gram) 42 | ---------- 43 | .. autofunction:: textbrewer.losses.fsp_loss -------------------------------------------------------------------------------- /examples/mnli_example/run_mnli_train.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 3 | DATA_ROOT_DIR=/path/to/data_root_dir 4 | 5 | 6 | accu=1 7 | ep=3 8 | lr=2 9 | temperature=8 10 | batch_size=32 11 | length=128 12 | torch_seed=9580 13 | 14 | taskname='mnli' 15 | NAME=${taskname}_base_lr${lr}e${ep}_bs${batch_size}_teacher 16 | DATA_DIR=${DATA_ROOT_DIR}/MNLI 17 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 18 | 19 | mkdir -p $OUTPUT_DIR 20 | model_config_json_file=TrainBertTeacher.json 21 | cp jsons/${model_config_json_file} ${OUTPUT_DIR}/${model_config_json_file}.run 22 | 23 | 24 | python -u main.trainer.py \ 25 | --data_dir $DATA_DIR \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_predict \ 29 | --max_seq_length ${length} \ 30 | --train_batch_size ${batch_size} \ 31 | --random_seed $torch_seed \ 32 | --num_train_epochs ${ep} \ 33 | --learning_rate ${lr}e-5 \ 34 | --ckpt_frequency 2 \ 35 | --output_dir $OUTPUT_DIR \ 36 | --gradient_accumulation_steps ${accu} \ 37 | --task_name ${taskname} \ 38 | --fp16 \ 39 | --model_config_json ${OUTPUT_DIR}/${model_config_json_file}.run 40 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | This example demonstrates distillation on the CMRC 2018 task, using the DRCD dataset for data augmentation. 4 | 5 | 6 | * run_cmrc2018_train.sh : trains a teacher model (roberta-wwm-base) on CMRC 2018. 7 | * run_cmrc2018_distill_T3.sh : distills the teacher to T3 with the CMRC 2018 and DRCD datasets. 8 | * run_cmrc2018_distill_T4tiny.sh : distills the teacher to T4tiny with the CMRC 2018 and DRCD datasets.
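All three scripts follow the same TextBrewer workflow: build the teacher and the student, define a `TrainingConfig`/`DistillationConfig`, write adaptors, and run a distiller. The following is a minimal, self-contained sketch of that workflow only — not the code of these scripts — using toy classification models in place of the QA models, random data, and illustrative match settings; it assumes a recent `transformers` version whose model outputs expose `.logits` and `.hidden_states`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertConfig, BertForSequenceClassification
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig

# Toy teacher/student pair standing in for the real models of this example
teacher = BertForSequenceClassification(
    BertConfig(num_hidden_layers=4, output_hidden_states=True))
student = BertForSequenceClassification(
    BertConfig(num_hidden_layers=2, output_hidden_states=True))

# Random token ids and labels standing in for the (augmented) training data
dataset = TensorDataset(torch.randint(100, 1000, (16, 32)),
                        torch.randint(0, 2, (16,)))
dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

def batch_postprocessor(batch):
    # Turn the (input_ids, labels) tuple into the kwargs the models expect
    return {'input_ids': batch[0], 'labels': batch[1]}

def adaptor(batch, model_outputs):
    # Map raw model outputs to the keys TextBrewer understands
    return {'logits': model_outputs.logits,
            'hidden': model_outputs.hidden_states}

distiller = GeneralDistiller(
    train_config=TrainingConfig(device='cpu'),
    distill_config=DistillationConfig(
        temperature=8,
        intermediate_matches=[{'layer_T': 4, 'layer_S': 2, 'feature': 'hidden',
                               'loss': 'hidden_mse', 'weight': 1}]),
    model_T=teacher, model_S=student,
    adaptor_T=adaptor, adaptor_S=adaptor)

with distiller:
    distiller.train(optimizer=optimizer, scheduler=None, dataloader=dataloader,
                    num_epochs=1, batch_postprocessor=batch_postprocessor)
```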
9 | 10 | Set the following variables in the shell scripts before running: 11 | 12 | * BERT_DIR : where RoBERTa-wwm-base is stored, including vocab.txt, pytorch_model.bin, bert_config.json 13 | * OUTPUT_ROOT_DIR : this directory stores logs and trained model weights 14 | * DATA_ROOT_DIR : it includes the CMRC 2018 and DRCD datasets: 15 | * \$\{DATA_ROOT_DIR\}/cmrc2018/squad-style-data/cmrc2018_train.json 16 | * \$\{DATA_ROOT_DIR\}/cmrc2018/squad-style-data/cmrc2018_dev.json 17 | * \$\{DATA_ROOT_DIR\}/drcd/DRCD_training.json 18 | * The trained teacher weights file *trained_teacher_model* has to be specified if running run_cmrc2018_distill_T3.sh or run_cmrc2018_distill_T4tiny.sh. 19 | -------------------------------------------------------------------------------- /src/textbrewer/projections.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .utils import initializer_builder 3 | from typing import List 4 | 5 | act_dict = {} 6 | for k,v in torch.nn.modules.activation.__dict__.items(): 7 | if not k.startswith('__'): 8 | act_dict[k] = v 9 | 10 | def linear_projection(dim_in, dim_out): 11 | model = torch.nn.Linear(in_features=dim_in, out_features=dim_out, bias=True) 12 | initializer = initializer_builder(0.02) 13 | model.apply(initializer) 14 | return model 15 | 16 | def projection_with_activation(act_fn): 17 | if type(act_fn) is str: 18 | assert act_fn in act_dict, f"invalid activations, please choose from {list(act_dict.keys())}" 19 | act_fn = act_dict[act_fn]() 20 | else: 21 | assert isinstance(act_fn,torch.nn.Module), "act_fn must be a string or module" 22 | act_fn = act_fn() 23 | def projection(dim_in, dim_out): 24 | model = torch.nn.Sequential( 25 | torch.nn.Linear(in_features=dim_in, out_features=dim_out, bias=True), 26 | act_fn) 27 | initializer = initializer_builder(0.02) 28 | model.apply(initializer) 29 | return model 30 | return projection -------------------------------------------------------------------------------- /docs/source/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. TextBrewer documentation master file, created by 2 | sphinx-quickstart on Tue Mar 17 17:23:13 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | TextBrewer 7 | ====================================== 8 | 9 | .. .. image:: ../../pics/banner.png 10 | .. :width: 500 11 | 12 | **TextBrewer** is a PyTorch-based toolkit for **distillation of NLP models**. 13 | 14 | It includes various distillation techniques from both NLP and CV, and provides an easy-to-use distillation framework, which allows users to quickly experiment with state-of-the-art distillation methods to compress the model with a relatively small sacrifice in performance, increase the inference speed and reduce the memory usage. 15 | 16 | Paper: `TextBrewer: An Open-Source Knowledge Distillation Toolkit for Natural Language Processing `_ 17 | 18 | 19 | 20 | .. 
toctree:: 21 | :maxdepth: 2 22 | :caption: Getting Started 23 | 24 | GettingStarted 25 | 26 | Indices and tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /examples/msra_ner_example/ner_ElectraTrain_dist.sh: -------------------------------------------------------------------------------- 1 | ELECTRA_DIR_BASE=/path/to/Chinese-electra-base 2 | OUTPUT_DIR=/path/to/output_dir 3 | DATA_DIR=/path/to/MSRA_NER_data 4 | 5 | mkdir -p $OUTPUT_DIR 6 | 7 | ngpu=2 8 | lr=10 9 | batch_size=12 10 | length=160 11 | ep=2 12 | lr=30 13 | 14 | python -m torch.distributed.launch --nproc_per_node=${ngpu} main.train.dist.py \ 15 | --vocab_file $ELECTRA_DIR_BASE/vocab.txt \ 16 | --do_lower_case \ 17 | --bert_config_file_T none \ 18 | --bert_config_file_S $ELECTRA_DIR_BASE/config.json \ 19 | --init_checkpoint_S $ELECTRA_DIR_BASE/pytorch_model.bin \ 20 | --do_train \ 21 | --do_eval \ 22 | --do_predict \ 23 | --max_seq_length ${length} \ 24 | --train_batch_size ${batch_size} \ 25 | --random_seed 1337 \ 26 | --train_file $DATA_DIR/msra_train_bio.txt \ 27 | --predict_file $DATA_DIR/msra_test_bio.txt \ 28 | --num_train_epochs ${ep} \ 29 | --learning_rate ${lr}e-5 \ 30 | --ckpt_frequency 2 \ 31 | --official_schedule linear \ 32 | --output_dir $OUTPUT_DIR \ 33 | --gradient_accumulation_steps 1 \ 34 | --output_encoded_layers true \ 35 | --output_attention_layers false \ 36 | --lr_decay 0.8 37 | #--fp16 38 | -------------------------------------------------------------------------------- /examples/mnli_example/run_mnli_distill_T4tiny_emd.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 3 | DATA_ROOT_DIR=/path/to/data_root_dir 4 | 5 | 6 | accu=1 7 | ep=40 8 | lr=10 9 | temperature=8 10 | batch_size=32 11 | length=128 12 | torch_seed=9580 13 | 14 | taskname='mnli' 15 | NAME=${taskname}_t${temperature}_TbaseST4tiny_EMDHiddenMSE_lr${lr}e${ep}_bs${batch_size} 16 | DATA_DIR=${DATA_ROOT_DIR}/MNLI 17 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 18 | 19 | mkdir -p $OUTPUT_DIR 20 | model_config_json_file=DistillBertToTiny.json 21 | cp jsons/${model_config_json_file} ${OUTPUT_DIR}/${model_config_json_file}.run 22 | 23 | 24 | python -u main.emd.py \ 25 | --data_dir $DATA_DIR \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_predict \ 29 | --max_seq_length ${length} \ 30 | --train_batch_size ${batch_size} \ 31 | --random_seed $torch_seed \ 32 | --num_train_epochs ${ep} \ 33 | --learning_rate ${lr}e-5 \ 34 | --ckpt_frequency 1 \ 35 | --output_dir $OUTPUT_DIR \ 36 | --gradient_accumulation_steps ${accu} \ 37 | --temperature ${temperature} \ 38 | --task_name ${taskname} \ 39 | --model_config_json ${OUTPUT_DIR}/${model_config_json_file}.run \ 40 | --fp16 41 | -------------------------------------------------------------------------------- /examples/mnli_example/run_mnli_distill_multiteacher.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 3 | DATA_ROOT_DIR=/path/to/data_root_dir 4 | 5 | 6 | accu=1 7 | ep=10 8 | lr=2 9 | temperature=8 10 | batch_size=32 11 | length=128 12 | torch_seed=9580 13 | 14 | taskname='mnli' 15 | NAME=${taskname}_t${temperature}_MTbaseST4tiny_lr${lr}e${ep}_bs${batch_size} 16 | DATA_DIR=${DATA_ROOT_DIR}/MNLI 17 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 18 | 19 | mkdir -p $OUTPUT_DIR 
20 | model_config_json_file=DistillMultiBertToTiny.json 21 | cp jsons/${model_config_json_file} ${OUTPUT_DIR}/${model_config_json_file}.run 22 | 23 | 24 | python -u main.multiteacher.py \ 25 | --data_dir $DATA_DIR \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_predict \ 29 | --max_seq_length ${length} \ 30 | --train_batch_size ${batch_size} \ 31 | --random_seed $torch_seed \ 32 | --num_train_epochs ${ep} \ 33 | --learning_rate ${lr}e-5 \ 34 | --ckpt_frequency 1 \ 35 | --output_dir $OUTPUT_DIR \ 36 | --gradient_accumulation_steps ${accu} \ 37 | --temperature ${temperature} \ 38 | --task_name ${taskname} \ 39 | --model_config_json ${OUTPUT_DIR}/${model_config_json_file}.run \ 40 | --fp16 41 | -------------------------------------------------------------------------------- /examples/mnli_example/run_mnli_distill_T4tiny.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 3 | DATA_ROOT_DIR=/path/to/data_root_dir 4 | 5 | 6 | accu=1 7 | ep=40 8 | lr=10 9 | temperature=8 10 | batch_size=32 11 | length=128 12 | torch_seed=9580 13 | 14 | taskname='mnli' 15 | NAME=${taskname}_t${temperature}_TbaseST4tiny_L4SmmdMSE_lr${lr}e${ep}_bs${batch_size} 16 | DATA_DIR=${DATA_ROOT_DIR}/MNLI 17 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 18 | 19 | mkdir -p $OUTPUT_DIR 20 | model_config_json_file=DistillBertToTiny.json 21 | cp jsons/${model_config_json_file} ${OUTPUT_DIR}/${model_config_json_file}.run 22 | 23 | 24 | python -u main.distill.py \ 25 | --data_dir $DATA_DIR \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_predict \ 29 | --max_seq_length ${length} \ 30 | --train_batch_size ${batch_size} \ 31 | --random_seed $torch_seed \ 32 | --num_train_epochs ${ep} \ 33 | --learning_rate ${lr}e-5 \ 34 | --ckpt_frequency 1 \ 35 | --output_dir $OUTPUT_DIR \ 36 | --gradient_accumulation_steps ${accu} \ 37 | --temperature ${temperature} \ 38 | --task_name ${taskname} \ 39 | --model_config_json ${OUTPUT_DIR}/${model_config_json_file}.run \ 40 | --fp16 \ 41 | --matches L4t_hidden_mse L4_hidden_smmd 42 | -------------------------------------------------------------------------------- /examples/msra_ner_example/README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | This example demonstrates distilling a [Chinese-ELECTRA-base](https://github.com/ymcui/Chinese-ELECTRA) model on the MSRA NER task with **distributed data-parallel training** (single node, multi-GPU). 4 | 5 | 6 | * ner_ElectraTrain_dist.sh : trains a teacher model (Chinese-ELECTRA-base) on MSRA NER. 7 | * ner_ElectraDistill_dist.sh : distills the teacher to an ELECTRA-small model. 8 | 9 | 10 | Set the following variables in the shell scripts before running: 11 | 12 | * ELECTRA_DIR_BASE : where Chinese-ELECTRA-base is located; it should include vocab.txt, pytorch_model.bin and config.json. 13 | 14 | * OUTPUT_DIR : this directory stores the logs and the trained model weights. 15 | * DATA_DIR : it includes the MSRA NER dataset: 16 | * msra_train_bio.txt 17 | * msra_test_bio.txt 18 | 19 | For distillation: 20 | 21 | * ELECTRA_DIR_SMALL : where the pretrained Chinese-ELECTRA-small weights are located; should include pytorch_model.bin. This is optional. If you don't provide the ELECTRA-small weights, the student model will be initialized randomly. 22 | * student_config_file : the model config file (i.e., config.json) for the student. Usually it should be in $\{ELECTRA_DIR_SMALL\}.
23 | * trained_teacher_model_file : the ELECTRA-base teacher model that has been fine-tuned. 24 | 25 | The scripts have been tested under **PyTorch==1.2, Transformers==2.8**. -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | ## Introduction 3 | 4 | First off, thank you for considering contributing to TextBrewer. 5 | 6 | TextBrewer is an open source project and we love to receive contributions from our community. There are many ways to contribute, including but not limited to: 7 | 8 | * Writing tutorials and adding examples 9 | * Improving the documentation 10 | * Submitting new features, such as new distillation methods 11 | * Reporting bugs 12 | 13 | The following are the guidelines for contributing to TextBrewer. Also, feel free to propose changes to this document. 14 | 15 | Everyone participating in the project should follow the [Code of Conduct](CODE_OF_CONDUCT.md). 16 | 17 | ## Before Making a Change 18 | 19 | * If you have found any bugs or have any features that you wish to add, please first discuss them via an issue. 20 | 21 | ## Pull Request Process 22 | 23 | 1. Create your own fork of the code and clone it to your local computer 24 | 2. In your local repository, checkout a new branch and commit your changes 25 | 3. Push your branch back to the fork on GitHub 26 | 4. 
Submit a pull request from your fork, and describe your changes 27 | 28 | 29 | ## Coding conventions 30 | 31 | * Indent with four spaces 32 | * Naming conventions: `lower_case_with_underscores` for variables, methods and functions; `CapWords` for class names 33 | * Remember to add or update the docstrings and/or README.md after making any changes to the classes, methods or functions 34 | 35 | -------------------------------------------------------------------------------- /examples/msra_ner_example/ner_ElectraDistill_dist.sh: -------------------------------------------------------------------------------- 1 | ELECTRA_DIR_BASE=/path/to/Chinese-electra-base 2 | ELECTRA_DIR_SMALL=/path/to/Chinese-electra-small 3 | OUTPUT_DIR=/path/to/output_dir 4 | DATA_DIR=/path/to/MSRA_NER_data 5 | 6 | student_config_file=/path/to/student_config_file 7 | trained_teacher_model_file=/path/to/trained_teacher_model_file 8 | mkdir -p $OUTPUT_DIR 9 | 10 | ngpu=2 11 | lr=10 12 | temperature=8 13 | batch_size=12 14 | length=160 15 | ep=30 16 | lr=5 17 | 18 | python -m torch.distributed.launch --nproc_per_node=${ngpu} main.distill.dist.py \ 19 | --vocab_file $ELECTRA_DIR_BASE/vocab.txt \ 20 | --do_lower_case \ 21 | --bert_config_file_T $ELECTRA_DIR_BASE/config.json \ 22 | --tuned_checkpoint_T ${trained_teacher_model_file} \ 23 | --bert_config_file_S ${student_config_file} \ 24 | --init_checkpoint_S $ELECTRA_DIR_SMALL/pytorch_model.bin \ 25 | --do_train \ 26 | --do_eval \ 27 | --do_predict \ 28 | --max_seq_length ${length} \ 29 | --train_batch_size ${batch_size} \ 30 | --random_seed 1337 \ 31 | --train_file $DATA_DIR/msra_train_bio.txt \ 32 | --predict_file $DATA_DIR/msra_test_bio.txt \ 33 | --num_train_epochs ${ep} \ 34 | --learning_rate ${lr}e-5 \ 35 | --ckpt_frequency 1 \ 36 | --official_schedule linear \ 37 | --output_dir $OUTPUT_DIR \ 38 | --gradient_accumulation_steps 1 \ 39 | --temperature ${temperature} \ 40 | --output_encoded_layers true \ 41 | --output_attention_layers false 42 | # --fp16 43 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/run_cmrc2018_train.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | BERT_DIR=/path/to/roberta-wwm-base 3 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 4 | DATA_ROOT_DIR=/path/to/data_root_dir 5 | 6 | STUDENT_CONF_DIR=../student_config/roberta_wwm_config 7 | cmrc_train_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_train.json 8 | cmrc_dev_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_dev.json 9 | 10 | accu=1 11 | ep=2 12 | lr=3 13 | batch_size=24 14 | length=512 15 | torch_seed=9580 16 | 17 | NAME=cmrc2018_base_lr${lr}e${ep}_teacher 18 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 19 | 20 | 21 | 22 | mkdir -p $OUTPUT_DIR 23 | 24 | python -u main.trainer.py \ 25 | --vocab_file $BERT_DIR/vocab.txt \ 26 | --do_lower_case \ 27 | --bert_config_file_T none \ 28 | --bert_config_file_S $STUDENT_CONF_DIR/bert_config.json \ 29 | --init_checkpoint_S $BERT_DIR/pytorch_model.bin \ 30 | --do_train \ 31 | --do_eval \ 32 | --do_predict \ 33 | --doc_stride 320 \ 34 | --max_seq_length ${length} \ 35 | --train_batch_size ${batch_size} \ 36 | --random_seed $torch_seed \ 37 | --train_file $cmrc_train_file \ 38 | --predict_file $cmrc_dev_file \ 39 | --num_train_epochs ${ep} \ 40 | --learning_rate ${lr}e-5 \ 41 | --ckpt_frequency 1 \ 42 | --schedule slanted_triangular \ 43 | --s_opt1 30 \ 44 | --output_dir $OUTPUT_DIR \ 45 | --gradient_accumulation_steps ${accu} \ 
46 | --output_encoded_layers false \ 47 | --output_attention_layers false 48 | -------------------------------------------------------------------------------- /examples/mnli_example/jsons/DistillMultiBertToTiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "teachers":[ 3 | { 4 | "model_type":"bert", 5 | "prefix":"bertuncased", 6 | "vocab_file":"/path/to/vocab.txt", 7 | "config_file":"/path/to/config.json", 8 | "checkpoint":"/path/to/teacher1/checkpoint", 9 | "tokenizer_kwargs":{"do_lower_case": true}, 10 | "disable":false 11 | }, 12 | { 13 | "model_type":"bert", 14 | "prefix":"bertuncased", 15 | "vocab_file":"/path/to/vocab.txt", 16 | "config_file":"/path/to/config.json", 17 | "checkpoint":"/path/to/teacher2/checkpoint", 18 | "tokenizer_kwargs":{"do_lower_case": true}, 19 | "disable":false 20 | }, 21 | { 22 | "model_type":"bert", 23 | "prefix":"bertuncased", 24 | "vocab_file":"/path/to/vocab.txt", 25 | "config_file":"/path/to/config.json", 26 | "checkpoint":"/path/to/teacher3/checkpoint", 27 | "tokenizer_kwargs":{"do_lower_case": true}, 28 | "disable":false 29 | } 30 | ], 31 | "student":{ 32 | "model_type":"bert", 33 | "prefix":"bertuncased", 34 | "vocab_file":"/path/to/vocab.txt", 35 | "config_file":"/path/to/config.json", 36 | "checkpoint":"/path/to/bert/original/weight(pytorch_model.bin)", 37 | "tokenizer_kwargs":{"do_lower_case": true}, 38 | "disable":false 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /examples/msra_ner_example/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import config 4 | import logging 5 | logger = logging.getLogger("utils") 6 | logger.setLevel(logging.INFO) 7 | 8 | def divide_parameters(named_parameters,lr=None): 9 | no_decay = ['bias', 'LayerNorm.bias','LayerNorm.weight'] 10 | decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if not any((di in n) for di in no_decay)])) 11 | no_decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if any((di in n) for di in no_decay)])) 12 | param_group = [] 13 | if len(decay_parameters_names)>0: 14 | decay_parameters, decay_names = decay_parameters_names 15 | #print ("decay:",decay_names) 16 | if lr is not None: 17 | decay_group = {'params':decay_parameters, 'weight_decay': config.args.weight_decay_rate, 'lr':lr} 18 | else: 19 | decay_group = {'params': decay_parameters, 'weight_decay': config.args.weight_decay_rate} 20 | param_group.append(decay_group) 21 | 22 | if len(no_decay_parameters_names)>0: 23 | no_decay_parameters, no_decay_names = no_decay_parameters_names 24 | #print ("no decay:", no_decay_names) 25 | if lr is not None: 26 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0, 'lr': lr} 27 | else: 28 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0} 29 | param_group.append(no_decay_group) 30 | 31 | assert len(param_group)>0 32 | return param_group 33 | -------------------------------------------------------------------------------- /src/textbrewer/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | # x is between 0 and 1 5 | def linear_growth_weight_scheduler(x): 6 | return x 7 | 8 | def linear_decay_weight_scheduler(x): 9 | return 1-x 10 | 11 | def constant_temperature_scheduler(logits_S, logits_T, base_temperature): 12 | ''' 13 | Remember to detach 
logits_S 14 | ''' 15 | return base_temperature 16 | 17 | 18 | def flsw_temperature_scheduler_builder(beta,gamma,eps=1e-4, *args): 19 | ''' 20 | adapted from arXiv:1911.07471 21 | ''' 22 | def flsw_temperature_scheduler(logits_S, logits_T, base_temperature): 23 | v = logits_S.detach() 24 | t = logits_T.detach() 25 | with torch.no_grad(): 26 | v = v/(torch.norm(v,dim=-1,keepdim=True)+eps) 27 | t = t/(torch.norm(t,dim=-1,keepdim=True)+eps) 28 | w = torch.pow((1 - (v*t).sum(dim=-1)),gamma) 29 | tau = base_temperature + (w.mean()-w)*beta 30 | return tau 31 | return flsw_temperature_scheduler 32 | 33 | 34 | def cwsm_temperature_scheduler_builder(beta,*args): 35 | ''' 36 | adapted from arXiv:1911.07471 37 | ''' 38 | def cwsm_temperature_scheduler(logits_S, logits_T, base_temperature): 39 | v = logits_S.detach() 40 | with torch.no_grad(): 41 | v = torch.softmax(v,dim=-1) 42 | v_max = v.max(dim=-1)[0] 43 | w = 1 / (v_max + 1e-3) 44 | tau = base_temperature + (w.mean()-w)*beta 45 | return tau 46 | return cwsm_temperature_scheduler 47 | -------------------------------------------------------------------------------- /src/textbrewer/snippet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from textbrewer import GeneralDistiller 4 | from textbrewer import TrainingConfig, DistillationConfig 5 | 6 | # We omit the initialization of models, optimizer, and dataloader. 7 | teacher_model : torch.nn.Module = ... 8 | student_model : torch.nn.Module = ... 9 | dataloader : torch.utils.data.DataLoader = ... 10 | optimizer : torch.optim.Optimizer = ... 11 | scheduler : torch.optim.lr_scheduler = ... 12 | 13 | def simple_adaptor(batch, model_outputs): 14 | # We assume that the first element of model_outputs 15 | # is the logits before softmax 16 | return {'logits': model_outputs[0]} 17 | 18 | train_config = TrainingConfig() 19 | distill_config = DistillationConfig() 20 | distiller = GeneralDistiller( 21 | train_config=train_config, distill_config = distill_config, 22 | model_T = teacher_model, model_S = student_model, 23 | adaptor_T = simple_adaptor, adaptor_S = simple_adaptor) 24 | 25 | distiller.train(optimizer, scheduler, 26 | dataloader, num_epochs, callback=None) 27 | 28 | 29 | 30 | 31 | 32 | def predict(model, eval_dataset, step, args): 33 | raise NotImplementedError 34 | # fill other arguments 35 | my_callback = partial(predict, eval_dataset=my_eval_dataset, args=args) 36 | train_config = TrainingConfig() 37 | 38 | # A user-defined prediction and evaluation function 39 | def predict(model, eval_dataset, step, args): 40 | ''' 41 | eval_dataset: the evaluation dataset 42 | args: other arguments needed for evaluation 43 | ''' 44 | raise NotImplementedError 45 | 46 | # Fill in the extra arguments 47 | my_callback = partial(predict, eval_dataset=my_eval_dataset, args=args) 48 | distiller.train(..., callback = my_callback) 49 | -------------------------------------------------------------------------------- /docs/source/Presets.rst: -------------------------------------------------------------------------------- 1 | Presets 2 | ======= 3 | 4 | 5 | Presets include module variables that define pre-defined loss functions and strategies. 6 | 7 | Module variables 8 | ---------------- 9 | 10 | ADAPTOR_KEYS 11 | ^^^^^^^^^^^^ 12 | .. autodata:: textbrewer.presets.ADAPTOR_KEYS 13 | :annotation: 14 | 15 | KD_LOSS_MAP 16 | ^^^^^^^^^^^^ 17 | .. autodata:: textbrewer.presets.KD_LOSS_MAP 18 | :annotation: 19 | 20 | PROJ_MAP 21 | ^^^^^^^^ 22 | .. 
autodata:: textbrewer.presets.PROJ_MAP 23 | :annotation: 24 | 25 | MATCH_LOSS_MAP 26 | ^^^^^^^^^^^^^^ 27 | .. autodata:: textbrewer.presets.MATCH_LOSS_MAP 28 | :annotation: 29 | 30 | WEIGHT_SCHEDULER 31 | ^^^^^^^^^^^^^^^^ 32 | .. autodata:: textbrewer.presets.WEIGHT_SCHEDULER 33 | :annotation: 34 | 35 | TEMPERATURE_SCHEDULER 36 | ^^^^^^^^^^^^^^^^^^^^^ 37 | .. autodata:: textbrewer.presets.TEMPERATURE_SCHEDULER 38 | :annotation: 39 | 40 | Customization 41 | ------------- 42 | 43 | If the pre-defined modules do not satisfy your requirements, you can add your own defined modules to the above dict. 44 | 45 | For example:: 46 | 47 | MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss 48 | WEIGHT_SCHEDULER['my_weight_scheduler'] = my_weight_scheduler 49 | 50 | then used in :class:`~textbrewer.DistillationConfig`:: 51 | 52 | distill_config = DistillationConfig( 53 | kd_loss_weight_scheduler = 'my_weight_scheduler' 54 | intermediate_matches = [{'layer_T':0, 'layer_S':0, 'feature':'hidden','loss': 'my_L1_loss', 'weight' : 1}] 55 | ...) 56 | 57 | Refer to the source code for more details on inputs and outputs conventions (will be explained in detail in a later version of the documentation). -------------------------------------------------------------------------------- /examples/mnli_example/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | 这个例子展示MNLI句对分类任务上的蒸馏,同时提供了一个**自定义distiller**的例子。 4 | 5 | * run_mnli_train.sh : 在MNLI数据上训练教师模型(bert-base)。 6 | * run_mnli_distill_T4tiny.sh : 在MNLI上蒸馏教师模型到T4Tiny。 7 | * run_mnli_distill_T4tiny_emd.sh:使用EMD方法自动计算隐层与隐层的匹配,而无需人工指定。该例子同时展示了如何自定义distiller(见下文详解)。 8 | * run_mnli_distill_multiteacher.sh : 多教师蒸馏,将多个教师模型压缩到一个学生模型。 9 | 10 | **PyTorch==1.2.0,transformers==3.0.2** 上测试通过。 11 | 12 | ## 运行 13 | 14 | 1. 运行以上任一个脚本前,请根据自己的环境设置sh文件中相应变量: 15 | 16 | 17 | * OUTPUT_ROOT_DIR : 存放训练好的模型和日志 18 | * DATA_ROOT_DIR : 包含MNLI数据集: 19 | * \$\{DATA_ROOT_DIR\}/MNLI/train.tsv 20 | * \$\{DATA_ROOT_DIR\}/MNLI/dev_matched.tsv 21 | * \$\{DATA_ROOT_DIR\}/MNLI/dev_mismatched.tsv 22 | 23 | 2. 设置BERT模型路径: 24 | * 如果运行run_mnli_train.sh,修改jsons/TrainBertTeacher.json中"student"键下的"vocab_file","config_file"和"checkpoint"路径 25 | * 如果运行 run_mnli_distill_T4tiny.sh 或 run_mnli_distill_T4tiny_emd.sh,修改jsons/DistillBertToTiny.json中"teachers"键下的"vocab_file","config_file"和"checkpoint"路径 26 | * 如果运行 run_mnli_distill_multiteacher.sh, 修改jsons/DistillMultiBert.json中"teachers"键下的所有"vocab_file","config_file"和"checkpoint"路径。可以自行添加更多teacher。 27 | 28 | 3. 设置完成,执行sh文件开始训练。 29 | 30 | ## BERT-EMD与自定义distiller 31 | [BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) 通过优化中间层之间的Earth Mvoer's Distance以自适应地调整教师与学生之间中间层匹配。 32 | 33 | 我们参照了其[原始实现](https://github.com/lxk00/BERT-EMD),并以distiller的形式实现了其一个简化版本EMDDistiller(忽略了attention间的mapping)。 34 | BERT-EMD相关代码位于distiller_emd.py。EMDDistiller使用方法与其他distiller无太大差异: 35 | ```python 36 | from distiller_emd import EMDDistiller 37 | distiller = EMDDistiller(...) 38 | with distiller: 39 | distiller.train(...) 40 | ``` 41 | 使用方式详见 main.emd.py。 42 | 43 | EMDDistiller要求pyemd包: 44 | ```bash 45 | pip install pyemd 46 | ``` 47 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. TextBrewer documentation master file, created by 2 | sphinx-quickstart on Tue Mar 17 17:23:13 2020. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | .. image:: ../../pics/banner.png 8 | :width: 500 9 | :align: center 10 | 11 | | 12 | 13 | **TextBrewer** is a PyTorch-based model distillation toolkit for natural language processing. 14 | 15 | It includes various distillation techniques from both NLP and CV field and provides an easy-to-use distillation framework, which allows users to quickly experiment with the state-of-the-art distillation methods to compress the model with a relatively small sacrifice in the performance, increasing the inference speed and reducing the memory usage. 16 | 17 | Main features 18 | ------------- 19 | 20 | * **Wide-support** : it supports various model architectures (especially **transformer**-based models). 21 | * **Flexibility** : design your own distillation scheme by combining different techniques. 22 | * **Easy-to-use** : users don't need to modify the model architectures. 23 | * **Built for NLP** : it is suitable for a wide variety of NLP tasks: text classification, machine reading comprehension, sequence labeling, ... 24 | 25 | 26 | Paper: `TextBrewer: An Open-Source Knowledge Distillation Toolkit for Natural Language Processing `_ 27 | 28 | 29 | .. toctree:: 30 | :maxdepth: 2 31 | :caption: Getting Started 32 | 33 | Tutorial 34 | Concepts 35 | 36 | .. toctree:: 37 | :maxdepth: 2 38 | :caption: Experiments 39 | 40 | Experiments 41 | 42 | .. toctree:: 43 | :maxdepth: 2 44 | :caption: API Reference 45 | 46 | Configurations 47 | Distillers 48 | Presets 49 | Losses 50 | Utils 51 | 52 | .. toctree:: 53 | :maxdepth: 2 54 | :caption: Appendices 55 | 56 | ExperimentResults 57 | 58 | Indices and tables 59 | ================== 60 | 61 | * :ref:`genindex` 62 | * :ref:`modindex` 63 | * :ref:`search` -------------------------------------------------------------------------------- /examples/cmrc2018_example/run_cmrc2018_distill_T4tiny.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | BERT_DIR=/path/to/roberta-wwm-base 3 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 4 | DATA_ROOT_DIR=/path/to/data_root_dir 5 | trained_teacher_model=/path/to/trained_teacher_model_file 6 | 7 | STUDENT_CONF_DIR=../student_config/roberta_wwm_config 8 | cmrc_train_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_train.json 9 | cmrc_dev_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_dev.json 10 | DA_file=$DATA_ROOT_DIR/drcd/DRCD_training.json # used for data augmentation 11 | 12 | accu=1 13 | ep=50 14 | lr=10 15 | temperature=8 16 | batch_size=24 17 | length=512 18 | sopt1=1 # The final learning rate is 1/sopt1 of the initial learning rate; 30 is used in most cases 19 | torch_seed=9580 20 | 21 | NAME=cmrc2018_t${temperature}_TbaseST4tiny_AllSmmdH1_lr${lr}e${ep}_opt${sopt1} 22 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 23 | 24 | 25 | 26 | mkdir -p $OUTPUT_DIR 27 | 28 | 29 | python -u main.distill.py \ 30 | --vocab_file $BERT_DIR/vocab.txt \ 31 | --do_lower_case \ 32 | --bert_config_file_T $BERT_DIR/bert_config.json \ 33 | --bert_config_file_S $STUDENT_CONF_DIR/bert_config_L4t.json \ 34 | --tuned_checkpoint_T $trained_teacher_model \ 35 | --load_model_type none \ 36 | --do_train \ 37 | --do_eval \ 38 | --do_predict \ 39 | --doc_stride 128 \ 40 | --max_seq_length ${length} \ 41 | --train_batch_size ${batch_size} \ 42 | --random_seed $torch_seed \ 43 | --train_file $cmrc_train_file \ 44 | --fake_file_1 $DA_file \ 45 | --predict_file 
$cmrc_dev_file \ 46 | --num_train_epochs ${ep} \ 47 | --learning_rate ${lr}e-5 \ 48 | --ckpt_frequency 1 \ 49 | --schedule slanted_triangular \ 50 | --s_opt1 ${sopt1} \ 51 | --output_dir $OUTPUT_DIR \ 52 | --gradient_accumulation_steps ${accu} \ 53 | --temperature ${temperature} \ 54 | --output_att_score true \ 55 | --output_att_sum false \ 56 | --output_encoded_layers true \ 57 | --output_attention_layers true \ 58 | --matches L4t_hidden_mse \ 59 | L4_hidden_smmd \ 60 | --tag RB \ 61 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/run_cmrc2018_distill_T3.sh: -------------------------------------------------------------------------------- 1 | #set hyperparameters 2 | BERT_DIR=/path/to/roberta-wwm-base 3 | OUTPUT_ROOT_DIR=/path/to/output_root_dir 4 | DATA_ROOT_DIR=/path/to/data_root_dir 5 | trained_teacher_model=/path/to/trained_teacher_model_file 6 | 7 | STUDENT_CONF_DIR=../student_config/roberta_wwm_config 8 | cmrc_train_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_train.json 9 | cmrc_dev_file=$DATA_ROOT_DIR/cmrc2018/squad-style-data/cmrc2018_dev.json 10 | DA_file=$DATA_ROOT_DIR/drcd/DRCD_training.json # used for data augmentation 11 | 12 | accu=1 13 | ep=50 14 | lr=15 15 | temperature=8 16 | batch_size=24 17 | length=512 18 | sopt1=1 # The final learning rate is 1/sopt1 of the initial learning rate; 30 is used in most cases 19 | torch_seed=9580 20 | 21 | NAME=cmrc2018_t${temperature}_TbaseST3_AllSmmdH1_lr${lr}e${ep}_opt${sopt1} 22 | OUTPUT_DIR=${OUTPUT_ROOT_DIR}/${NAME} 23 | 24 | 25 | 26 | mkdir -p $OUTPUT_DIR 27 | 28 | 29 | python -u main.distill.py \ 30 | --vocab_file $BERT_DIR/vocab.txt \ 31 | --do_lower_case \ 32 | --bert_config_file_T $BERT_DIR/bert_config.json \ 33 | --bert_config_file_S $STUDENT_CONF_DIR/bert_config_L3.json \ 34 | --tuned_checkpoint_T $trained_teacher_model \ 35 | --init_checkpoint_S $BERT_DIR/pytorch_model.bin \ 36 | --do_train \ 37 | --do_eval \ 38 | --do_predict \ 39 | --doc_stride 128 \ 40 | --max_seq_length ${length} \ 41 | --train_batch_size ${batch_size} \ 42 | --random_seed $torch_seed \ 43 | --train_file $cmrc_train_file \ 44 | --fake_file_1 $DA_file \ 45 | --predict_file $cmrc_dev_file \ 46 | --num_train_epochs ${ep} \ 47 | --learning_rate ${lr}e-5 \ 48 | --ckpt_frequency 1 \ 49 | --schedule slanted_triangular \ 50 | --s_opt1 ${sopt1} \ 51 | --output_dir $OUTPUT_DIR \ 52 | --gradient_accumulation_steps ${accu} \ 53 | --temperature ${temperature} \ 54 | --output_att_score true \ 55 | --output_att_sum false \ 56 | --output_encoded_layers true \ 57 | --output_attention_layers true \ 58 | --matches L3_hidden_mse \ 59 | L3_hidden_smmd \ 60 | --tag RB \ 61 | -------------------------------------------------------------------------------- /examples/mnli_example/parse.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | BertConfig, BertTokenizer, BertForSequenceClassification, 3 | #CamembertConfig, CamembertTokenizer, CamembertForSequenceClassification, 4 | #XLMRobertaConfig,XLMRobertaTokenizer, XLMRobertaForSequenceClassification, 5 | #RobertaConfig, RobertaTokenizer, RobertaForSequenceClassification, 6 | ) 7 | import json 8 | from typing import Dict,List 9 | from modeling import BertForGLUESimple 10 | 11 | MODEL_CLASSES = { 12 | 'bert': (BertConfig, BertTokenizer, BertForGLUESimple), 13 | } 14 | 15 | def parse_model_config(config) -> Dict : 16 | 17 | results = {"teachers":[]} 18 | 19 | if isinstance(config,str): 20 | with 
open(config,'r') as f: 21 | config = json.load(f) 22 | else: 23 | assert isinstance(config,dict) 24 | teachers = config['teachers'] 25 | for teacher in teachers: 26 | if teacher['disable'] is False: 27 | model_config, model_tokenizer, _ = MODEL_CLASSES[teacher['model_type']] 28 | if teacher['vocab_file'] is not None: 29 | kwargs = teacher.get('tokenizer_kwargs',{}) 30 | teacher['tokenizer'] = model_tokenizer(vocab_file=teacher['vocab_file'],**kwargs) 31 | if teacher['config_file'] is not None: 32 | teacher['config'] = model_config.from_json_file(teacher['config_file']) 33 | results['teachers'].append(teacher) 34 | 35 | student = config['student'] 36 | if student['disable'] is False: 37 | model_config, model_tokenizer, _ = MODEL_CLASSES[student['model_type']] 38 | if student['vocab_file'] is not None: 39 | kwargs = student.get('tokenizer_kwargs',{}) 40 | student['tokenizer'] = model_tokenizer(vocab_file=student['vocab_file'],**kwargs) 41 | if student['config_file'] is not None: 42 | student['config'] = model_config.from_json_file(student['config_file']) 43 | if 'num_hidden_layers' in student: 44 | student['config'].num_hidden_layers = student['num_hidden_layers'] 45 | results['student'] = student 46 | 47 | return results 48 | -------------------------------------------------------------------------------- /examples/msra_ner_example/modeling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from torch import nn 4 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 5 | import torch.nn.functional as F 6 | import config as Conf 7 | 8 | from transformers import ElectraModel 9 | from transformers.modeling_electra import ElectraPreTrainedModel 10 | 11 | class ElectraForTokenClassification(ElectraPreTrainedModel): 12 | def __init__(self, config): 13 | super().__init__(config) 14 | 15 | self.electra = ElectraModel(config) 16 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 17 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 18 | self.init_weights() 19 | 20 | def forward(self, input_ids, attention_mask=None,labels=None, token_type_ids=None, 21 | position_ids=None, head_mask=None, inputs_embeds=None): 22 | discriminator_hidden_states = self.electra( 23 | input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds 24 | ) 25 | discriminator_sequence_output = discriminator_hidden_states[0] 26 | 27 | discriminator_sequence_output = self.dropout(discriminator_sequence_output) 28 | logits = self.classifier(discriminator_sequence_output) 29 | 30 | output = (logits,) 31 | 32 | if labels is not None: 33 | loss_fct = nn.CrossEntropyLoss() 34 | # Only keep active parts of the loss 35 | if attention_mask is not None: 36 | active_loss = attention_mask.view(-1) == 1 37 | active_logits = logits.view(-1, self.config.num_labels)[active_loss] 38 | active_labels = labels.view(-1)[active_loss] 39 | loss = loss_fct(active_logits, active_labels) 40 | else: 41 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 42 | 43 | output = (loss,) + output 44 | 45 | output += discriminator_hidden_states[1:] 46 | 47 | return output # (loss), scores, (hidden_states), (attentions) 48 | 49 | def ElectraForTokenClassificationAdaptorTraining(batch, model_outputs): 50 | return {'losses':(model_outputs[0],)} 51 | 52 | def ElectraForTokenClassificationAdaptor(batch, model_outputs): 53 | return {'logits':(model_outputs[1],), 54 | 'hidden':model_outputs[2], 55 | 'input_mask':batch[1], 56 | 
'logits_mask':batch[1]} 57 | -------------------------------------------------------------------------------- /examples/mnli_example/README.md: -------------------------------------------------------------------------------- 1 | [**中文说明**](README_ZH.md) | [**English**](README.md) 2 | 3 | This example demonstrates distillation on the MNLI task and **how to write a new distiller**. 4 | 5 | * run_mnli_train.sh : trains a teacher model (bert-base) on MNLI. 6 | * run_mnli_distill_T4tiny.sh : distills the teacher to T4tiny. 7 | * run_mnli_distill_T4tiny_emd.sh : distills the teacher to T4tiny with many-to-many intermediate matches using EMD, so there is no need to specify the matching scheme. This example also demonstrates how to write a custom distiller (see below for details). 8 | * run_mnli_distill_multiteacher.sh : runs multi-teacher distillation, distilling several teacher models into a student model. 9 | 10 | Examples have been tested on **PyTorch==1.2.0, transformers==3.0.2**. 11 | 12 | ## Run 13 | 14 | 1. Set the following variables in the bash scripts before running: 15 | 16 | * OUTPUT_ROOT_DIR : this directory stores logs and trained model weights 17 | * DATA_ROOT_DIR : it includes the MNLI dataset: 18 | * \$\{DATA_ROOT_DIR\}/MNLI/train.tsv 19 | * \$\{DATA_ROOT_DIR\}/MNLI/dev_matched.tsv 20 | * \$\{DATA_ROOT_DIR\}/MNLI/dev_mismatched.tsv 21 | 22 | 2. Set the path to BERT: 23 | * If you are running run_mnli_train.sh: open jsons/TrainBertTeacher.json and set "vocab_file", "config_file" and "checkpoint" which are under the key "student". 24 | * If you are running run_mnli_distill_T4tiny.sh or run_mnli_distill_T4tiny_emd.sh: open jsons/DistillBertToTiny.json and set "vocab_file", "config_file" and "checkpoint" which are under the key "teachers". 25 | * If you are running run_mnli_distill_multiteacher.sh: open jsons/DistillMultiBertToTiny.json and set all the "vocab_file", "config_file" and "checkpoint" under the key "teachers". You can also add more teachers to the json. 26 | 27 | 3. Run the bash script and have fun. 28 | 29 | ## BERT-EMD and custom distiller 30 | [BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) allows each intermediate student layer to learn from any intermediate teacher layer adaptively, based on optimizing the Earth Mover's Distance. So there is no need to specify the matching scheme. 31 | 32 | Based on the [original implementation](https://github.com/lxk00/BERT-EMD), we have written a new distiller (EMDDistiller) to implement a simplified version of BERT-EMD (which ignores mappings between attentions). The code of the algorithm is in distiller_emd.py. The EMDDistiller is used much like the other distillers: 33 | ```python 34 | from distiller_emd import EMDDistiller 35 | distiller = EMDDistiller(...) 36 | with distiller: 37 | distiller.train(...) 38 | ``` 39 | See main.emd.py for detailed usage. 40 | 41 | EMDDistiller requires the pyemd package: 42 | ```bash 43 | pip install pyemd 44 | ``` 45 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../src')) 16 | 17 | from textbrewer import __version__ 18 | 19 | # sys.path.insert(0, os.path.abspath('.')) 20 | 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'TextBrewer' 25 | author = 'Joint Laboratory of HIT and iFLYTEK Research (HFL)' 26 | copyright = '2020, '+author 27 | 28 | # The full version, including alpha/beta/rc tags 29 | version = __version__ 30 | release = __version__ 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be 36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 37 | # ones. 38 | extensions = ['recommonmark', 39 | 'sphinx_markdown_tables', 40 | 'sphinx.ext.autodoc', 41 | 'sphinxcontrib.napoleon', 42 | 'sphinx.ext.viewcode', 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # List of patterns, relative to source directory, that match files and 49 | # directories to ignore when looking for source files. 50 | # This pattern also affects html_static_path and html_extra_path. 51 | exclude_patterns = [] 52 | 53 | 54 | # -- Options for HTML output ------------------------------------------------- 55 | 56 | # The theme to use for HTML and HTML Help pages. See the documentation for 57 | # a list of builtin themes. 58 | # 59 | #html_theme = 'alabaster' 60 | html_theme = 'sphinx_rtd_theme' 61 | 62 | html_theme_options = { 63 | 'logo_only': True, 64 | #'style_nav_header_background' :'#EEEEEE' 65 | } 66 | html_logo = '../../pics/nav_banner.png' 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 
71 | html_static_path = ['_static'] 72 | 73 | def setup(app): 74 | app.add_stylesheet('css/custom.css') 75 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import config 4 | import logging 5 | logger = logging.getLogger("utils") 6 | 7 | def read_and_convert(fn,is_training,read_fn,convert_fn,do_lower_case): 8 | data_dirname, data_basename = os.path.split(fn) 9 | cased = '' if do_lower_case else 'cased' 10 | if config.args.max_seq_length != 416: 11 | data_pklname = data_basename + '%s%d_l%d_cHA.t%s.pkl' % (cased,config.args.doc_stride,config.args.max_seq_length,config.args.tag) 12 | else: 13 | data_pklname = data_basename + '%s%d_cHA.t%s.pkl' % (cased,config.args.doc_stride,config.args.tag) 14 | full_pklname = os.path.join(data_dirname,data_pklname) 15 | if os.path.exists(full_pklname): 16 | logger.info("Loading dataset %s " % data_pklname) 17 | with open(full_pklname,'rb') as f: 18 | examples,features = pickle.load(f) 19 | else: 20 | logger.info("Building dataset %s " % data_pklname) 21 | examples = read_fn(input_file=fn,is_training=is_training,do_lower_case=do_lower_case) 22 | logger.info(f"Size: {len(examples)}") 23 | features = convert_fn(examples=examples,is_training=is_training) 24 | try: 25 | with open(full_pklname,'wb') as f: 26 | pickle.dump((examples,features),f) 27 | except: 28 | logger.info("Can't save train data file.") 29 | return examples,features 30 | 31 | 32 | def divide_parameters(named_parameters,lr=None): 33 | no_decay = ['bias', 'LayerNorm.bias','LayerNorm.weight'] 34 | decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if not any((di in n) for di in no_decay)])) 35 | no_decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if any((di in n) for di in no_decay)])) 36 | param_group = [] 37 | if len(decay_parameters_names)>0: 38 | decay_parameters, decay_names = decay_parameters_names 39 | #print ("decay:",decay_names) 40 | if lr is not None: 41 | decay_group = {'params':decay_parameters, 'weight_decay_rate': config.args.weight_decay_rate, 'lr':lr} 42 | else: 43 | decay_group = {'params': decay_parameters, 'weight_decay_rate': config.args.weight_decay_rate} 44 | param_group.append(decay_group) 45 | 46 | if len(no_decay_parameters_names)>0: 47 | no_decay_parameters, no_decay_names = no_decay_parameters_names 48 | #print ("no decay:", no_decay_names) 49 | if lr is not None: 50 | no_decay_group = {'params': no_decay_parameters, 'weight_decay_rate': 0.0, 'lr': lr} 51 | else: 52 | no_decay_group = {'params': no_decay_parameters, 'weight_decay_rate': 0.0} 53 | param_group.append(no_decay_group) 54 | 55 | assert len(param_group)>0 56 | return param_group 57 | -------------------------------------------------------------------------------- /examples/mnli_example/modeling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from torch import nn 4 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 5 | import torch.nn.functional as F 6 | import config as Conf 7 | BertLayerNorm = torch.nn.LayerNorm 8 | from transformers import BertModel, RobertaModel 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def initializer_builder(std): 13 | _std = std 14 | def init_bert_weights(module): 15 | if isinstance(module, (nn.Linear, nn.Embedding)): 16 | module.weight.data.normal_(mean=0.0, 
std=_std) 17 | elif isinstance(module, BertLayerNorm): 18 | module.bias.data.zero_() 19 | module.weight.data.fill_(1.0) 20 | if isinstance(module, nn.Linear) and module.bias is not None: 21 | module.bias.data.zero_() 22 | return init_bert_weights 23 | 24 | class BertForGLUESimple(nn.Module): 25 | def __init__(self, config, num_labels): 26 | super(BertForGLUESimple, self).__init__() 27 | 28 | config.num_labels = num_labels 29 | self.num_labels = num_labels 30 | config.output_hidden_states = True 31 | config.output_attentions = False 32 | self.bert = BertModel(config) 33 | self.classifier = nn.Linear(config.hidden_size, num_labels) 34 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 35 | initializer = initializer_builder(config.initializer_range) 36 | self.apply(initializer) 37 | 38 | def forward(self, input_ids, attention_mask, token_type_ids, labels=None): 39 | last_hidden_state, pooled_output, hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 40 | output_for_cls = self.dropout(pooled_output) 41 | logits = self.classifier(output_for_cls) # output size: batch_size,num_labels 42 | #assert len(sequence_output)==self.bert.config.num_hidden_layers + 1 # embeddings + 12 hiddens 43 | #assert len(attention_output)==self.bert.config.num_hidden_layers + 1 # None + 12 attentions 44 | if labels is not None: 45 | if self.num_labels == 1: 46 | loss = F.mse_loss(logits.view(-1), labels.view(-1)) 47 | else: 48 | loss = F.cross_entropy(logits,labels) 49 | return logits, hidden_states, loss 50 | else: 51 | return logits 52 | 53 | 54 | 55 | def BertForGLUESimpleAdaptor(batch, model_outputs, with_logits=True, with_mask=False): 56 | dict_obj = {'hidden': model_outputs[1]} 57 | if with_mask: 58 | dict_obj['inputs_mask'] = batch[1] 59 | if with_logits: 60 | dict_obj['logits'] = (model_outputs[0],) 61 | return dict_obj 62 | 63 | def BertForGLUESimpleAdaptorTrain(batch, model_outputs): 64 | return {'losses':(model_outputs[2],)} 65 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
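# The conversion below is driven by the three command-line flags defined in the
# argparse block at the bottom of this file. A hypothetical invocation (the paths
# are placeholders, not files shipped with this repository) might look like:
#
#   python convert_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path /path/to/bert_model.ckpt \
#       --bert_config_file   /path/to/bert_config.json \
#       --pytorch_dump_path  /path/to/pytorch_model.bin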
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | To create the package for pypi. 4 | 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py. 5 | 2. Commit these changes with the message: "Release: VERSION" 6 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 7 | Push the tag to git: git push --tags origin master 8 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 9 | creating the wheel and the source distribution (obviously). 10 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 11 | (this will build a wheel for the python version you use to build it). 12 | For the sources, run: "python setup.py sdist" 13 | You should now have a /dist directory with both .whl and .tar.gz source versions. 14 | 5. Check that everything looks correct by uploading the package to the pypi test server: 15 | twine upload dist/* -r pypitest 16 | (pypi suggest using twine as other methods upload files via plaintext.) 17 | Check that you can install it in a virtualenv by running: 18 | pip install -i https://testpypi.python.org/pypi transformers 19 | 6. Upload the final version to actual pypi: 20 | twine upload dist/* -r pypi 21 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 
22 | """ 23 | 24 | import shutil 25 | from pathlib import Path 26 | 27 | from setuptools import find_packages, setup 28 | 29 | 30 | setup( 31 | name="textbrewer", 32 | version="0.2.1.post1", 33 | author="ziqingyang", 34 | author_email="zqyang5@iflytek.com", 35 | description="PyTorch-based knowledge distillation toolkit for natural language processing", 36 | long_description="PyTorch-based knowledge distillation toolkit for natural language processing.", 37 | #long_description=open("READMEshort.md", "r", encoding="utf-8").read(), 38 | long_description_content_type="text/markdown", 39 | keywords="NLP deep learning knowledge distillation pytorch", 40 | #license="", 41 | url="http://textbrewer.hfl-rc.com", 42 | #package_dir={"": "src"}, 43 | packages=['textbrewer'], 44 | package_dir={'':'src'}, 45 | install_requires=[ 46 | "numpy", 47 | "torch >= 1.1", 48 | "tensorboard", 49 | "tqdm" 50 | ], 51 | 52 | python_requires=">=3.6", 53 | classifiers=[ 54 | #"Development Status :: 5 - Production/Stable", 55 | "Intended Audience :: Developers", 56 | "Intended Audience :: Education", 57 | "Intended Audience :: Science/Research", 58 | "License :: OSI Approved :: Apache Software License", 59 | "Operating System :: OS Independent", 60 | "Programming Language :: Python :: 3.6", 61 | "Programming Language :: Python :: 3.7", 62 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 63 | ], 64 | ) 65 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version 
.fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /examples/mnli_example/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils_glue import processors 3 | args = None 4 | 5 | def parse(opt=None): 6 | parser = argparse.ArgumentParser() 7 | 8 | ## Required parameters 9 | 10 | parser.add_argument("--output_dir", default=None, type=str, required=True, 11 | help="The output directory where the model checkpoints will be written.") 12 | 13 | ## Other parameters 14 | parser.add_argument("--data_dir", default=None, type=str) 15 | parser.add_argument("--max_seq_length", default=128, type=int) 16 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") 17 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") 18 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 19 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 20 | parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") 21 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 22 | help="Total number of training epochs to perform.") 23 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 24 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " 25 | "of training.") 26 | parser.add_argument("--no_cuda", 27 | default=False, 28 | action='store_true', 29 | help="Whether not to use CUDA when available") 30 | parser.add_argument('--gradient_accumulation_steps', 31 | type=int, 32 | default=1, 33 | help="Number of updates steps to accumualte before performing a backward/update pass.") 34 | parser.add_argument("--local_rank", 35 | type=int, 36 | default=-1, 37 | help="local_rank for distributed training on gpus") 38 | parser.add_argument('--fp16', 39 | default=False, 40 | action='store_true', 41 | help="Whether to use 16-bit float precisoin instead of 32-bit") 42 | 43 | parser.add_argument('--random_seed',type=int,default=10236797) 44 | parser.add_argument('--weight_decay_rate',type=float,default=0.01) 45 | parser.add_argument('--do_eval',action='store_true') 46 | parser.add_argument('--PRINT_EVERY',type=int,default=200) 47 | parser.add_argument('--ckpt_frequency',type=int,default=2) 48 | 49 | parser.add_argument("--temperature", default=1, type=float) 50 | 51 | parser.add_argument("--teacher_cached",action='store_true') 52 | parser.add_argument('--task_name',type=str,choices=list(processors.keys())) 53 | parser.add_argument('--aux_task_name',type=str,choices=list(processors.keys()),default=None) 54 | parser.add_argument('--aux_data_dir', type=str) 55 | 56 | parser.add_argument('--matches',nargs='*',type=str) 57 | parser.add_argument('--model_config_json',type=str) 58 | parser.add_argument('--do_test',action='store_true') 59 | 60 | 61 | global args 62 | if opt is None: 63 | args = parser.parse_args() 64 | else: 65 | args = parser.parse_args(opt) 66 | 67 | if __name__ == '__main__': 68 | print (args) 69 | parse(['--SAVE_DIR','test']) 70 | print(args) 71 | 
-------------------------------------------------------------------------------- /examples/mnli_example/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import config 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | from utils_glue import processors, output_modes, convert_examples_to_features, trans3_convert_examples_to_features 9 | 10 | def load_and_cache_examples(args, task, tokenizer, prefix='bert', evaluate=False, return_dataset=True): 11 | data_dir = args.data_dir 12 | processor = processors[task]() 13 | output_mode = output_modes[task] 14 | # Load data features from cache or dataset file 15 | #seg_ids = 1 if 'bert' in prefix else 0 16 | if False: #args.do_test: 17 | pass #stage = 'test' 18 | else: 19 | if evaluate: 20 | stage = 'dev' if args.do_test is False else 'test' 21 | else: 22 | stage = 'train' 23 | cached_features_file = os.path.join(data_dir, '{}_{}_{}_{}'.format( 24 | prefix, 25 | stage, #'dev' if evaluate else 'train', 26 | str(args.max_seq_length), 27 | str(task))) 28 | if os.path.exists(cached_features_file): 29 | logger.info("Loading features from cached file %s", cached_features_file) 30 | features = torch.load(cached_features_file) 31 | else: 32 | logger.info("Creating features from dataset file at %s", data_dir) 33 | label_list = processor.get_labels() 34 | if evaluate: 35 | if args.do_test: 36 | examples = processor.get_test_examples(data_dir) 37 | else: 38 | examples = processor.get_dev_examples(data_dir) 39 | else: 40 | examples = processor.get_train_examples(data_dir) 41 | features = trans3_convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode) 42 | if args.local_rank in [-1, 0]: 43 | logger.info("Saving features into cached file %s", cached_features_file) 44 | torch.save(features, cached_features_file) 45 | # Convert to Tensors and build dataset 46 | if True: 47 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 48 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 49 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 50 | if output_mode == "classification": 51 | all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long) 52 | elif output_mode == "regression": 53 | all_label_ids = torch.tensor([f.label for f in features], dtype=torch.float) 54 | if return_dataset: 55 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label_ids) 56 | return dataset 57 | else: 58 | return all_input_ids, all_attention_mask, all_token_type_ids, all_label_ids 59 | 60 | 61 | def divide_parameters(named_parameters,lr=None): 62 | no_decay = ['bias', 'LayerNorm.bias','LayerNorm.weight'] 63 | decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if not any((di in n) for di in no_decay)])) 64 | no_decay_parameters_names = list(zip(*[(p,n) for n,p in named_parameters if any((di in n) for di in no_decay)])) 65 | param_group = [] 66 | if len(decay_parameters_names)>0: 67 | decay_parameters, decay_names = decay_parameters_names 68 | #print ("decay:",decay_names) 69 | if lr is not None: 70 | decay_group = {'params':decay_parameters, 'weight_decay': config.args.weight_decay_rate, 'lr':lr} 71 | else: 72 | decay_group = {'params': decay_parameters, 'weight_decay': config.args.weight_decay_rate} 73 | 
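# Editorial note (illustrative comment, not in the original file): parameters whose names
# contain 'bias', 'LayerNorm.bias' or 'LayerNorm.weight' are routed below to a separate group
# with weight_decay=0.0; all remaining parameters keep config.args.weight_decay_rate, and an
# explicit per-group learning rate is attached only when `lr` is given.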
param_group.append(decay_group) 74 | 75 | if len(no_decay_parameters_names)>0: 76 | no_decay_parameters, no_decay_names = no_decay_parameters_names 77 | #print ("no decay:", no_decay_names) 78 | if lr is not None: 79 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0, 'lr': lr} 80 | else: 81 | no_decay_group = {'params': no_decay_parameters, 'weight_decay': 0.0} 82 | param_group.append(no_decay_group) 83 | 84 | assert len(param_group)>0 85 | return param_group 86 | -------------------------------------------------------------------------------- /docs/source/_build/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | /* sphinx_rtd_theme version 0.4.3 | MIT license */ 2 | /* Built 20190212 16:02 */ 3 | require=function r(s,a,l){function c(e,n){if(!a[e]){if(!s[e]){var i="function"==typeof require&&require;if(!n&&i)return i(e,!0);if(u)return u(e,!0);var t=new Error("Cannot find module '"+e+"'");throw t.code="MODULE_NOT_FOUND",t}var o=a[e]={exports:{}};s[e][0].call(o.exports,function(n){return c(s[e][1][n]||n)},o,o.exports,r,s,a,l)}return a[e].exports}for(var u="function"==typeof require&&require,n=0;n"),i("table.docutils.footnote").wrap("
"),i("table.docutils.citation").wrap("
"),i(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var e=i(this);expand=i(''),expand.on("click",function(n){return t.toggleCurrent(e),n.stopPropagation(),!1}),e.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}0this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var r=0,n=["ms","moz","webkit","o"],e=0;e> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /src/textbrewer/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | def masking(tokens, p = 0.1, mask='[MASK]'): 5 | """ 6 | Returns a new list by replacing elements in `tokens` by `mask` with probability `p`. 7 | 8 | Args: 9 | tokens (list): list of tokens or token ids. 10 | p (float): probability to mask each element in `tokens`. 11 | Returns: 12 | A new list by replacing elements in `tokens` by `mask` with probability `p`. 13 | """ 14 | outputs = tokens[:] 15 | for i in range(len(tokens)): 16 | if np.random.rand() < p: 17 | outputs[i] = mask 18 | return outputs 19 | 20 | def deleting(tokens, p = 0.1): 21 | """ 22 | Returns a new list by deleting elements in `tokens` with probability `p`. 23 | 24 | Args: 25 | tokens (list): list of tokens or token ids. 26 | p (float): probability to delete each element in `tokens`. 27 | Retunrns: 28 | a new list by deleting elements in :`tokens` with probability `p`. 29 | """ 30 | choice = np.random.binomial(1,1-p,len(tokens)) 31 | outputs = [tokens[i] for i in range(len(tokens)) if choice[i]==1] 32 | return outputs 33 | 34 | 35 | def n_gram_sampling(tokens, 36 | p_ng = [0.2,0.2,0.2,0.2,0.2], 37 | l_ng = [1,2,3,4,5]): 38 | """ 39 | Samples a length `l` from `l_ng` with probability distribution `p_ng`, then returns a random span of length `l` from `tokens`. 40 | 41 | Args: 42 | tokens (list): list of tokens or token ids. 43 | p_ng (list): probability distribution of the n-grams, should sum to 1. 44 | l_ng (list): specify the n-grams. 45 | Returns: 46 | a n-gram random span from `tokens`. 47 | """ 48 | span_length = np.random.choice(l_ng,p= p_ng) 49 | start_position = max(0,np.random.randint(0,len(tokens)-span_length+1)) 50 | n_gram_span = tokens[start_position:start_position+span_length] 51 | return n_gram_span 52 | 53 | 54 | def short_disorder(tokens, p = [0.9,0.1,0,0,0]): # untouched + four cases abc, bac, cba, cab, bca 55 | """ 56 | Returns a new list by disordering `tokens` with probability distribution `p` at every possible position. 
Let `abc` be a 3-gram in `tokens`, 57 | there are five ways to disorder, corresponding to five probability values: 58 | 59 | | abc -> abc 60 | | abc -> bac 61 | | abc -> cba 62 | | abc -> cab 63 | | abc -> bca 64 | 65 | Args: 66 | tokens (list): list of tokens or token ids. 67 | p (list): probability distribution of 5 disorder types, should sum to 1. 68 | Returns: 69 | a new disordered list 70 | """ 71 | i = 0 72 | outputs = tokens[:] 73 | l = len(tokens) 74 | while i < l-1: 75 | permutation = np.random.choice([0,1,2,3,4],p=p) 76 | if permutation!=0 and i==l-2: 77 | outputs[i], outputs[i+1] = outputs[i+1], outputs[i] 78 | i += 2 79 | elif permutation==1: 80 | outputs[i], outputs[i+1] = outputs[i+1], outputs[i] 81 | i += 2 82 | elif permutation==2: 83 | outputs[i], outputs[i+2] = outputs[i+2], outputs[i] 84 | i +=3 85 | elif permutation==3: 86 | outputs[i],outputs[i+1],outputs[i+2] = outputs[i+2],outputs[i],outputs[i+1] 87 | i += 3 88 | elif permutation==4: 89 | outputs[i],outputs[i+1],outputs[i+2] = outputs[i+1],outputs[i+2],outputs[i] 90 | i += 3 91 | else: 92 | i += 1 93 | return outputs 94 | 95 | def long_disorder(tokens,p = 0.1, length=20): 96 | """ 97 | Performs a long-range disordering. If ``length>1``, then swaps the two halves of each span of length `length` in `tokens`; 98 | if ``length<=1``, treats `length` as the relative length. For example:: 99 | 100 | >>>long_disorder([0,1,2,3,4,5,6,7,8,9,10], p=1, length=0.4) 101 | [2, 3, 0, 1, 6, 7, 4, 5, 8, 9] 102 | 103 | Args: 104 | tokens (list): list of tokens or token ids. 105 | p (list): probability to swaps the two halves of a spans at possible positions. 106 | length (int or float): length of the disordered span. 107 | Returns: 108 | a new disordered list 109 | """ 110 | outputs = tokens[:] 111 | if int(length) <= 1: 112 | length = len(tokens)*length 113 | length = (int(length)+1) //2 * 2 114 | i = 0 115 | while i<=len(outputs)-length: 116 | if np.random.rand() < p: 117 | outputs[i:i+length//2], outputs[i+length//2:i+length] = outputs[i+length//2:i+length], outputs[i:i+length//2] 118 | i += length 119 | else: 120 | i += 1 121 | return outputs -------------------------------------------------------------------------------- /examples/cmrc2018_example/cmrc2018_evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 - special 5 | Note: 6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import Counter, OrderedDict 12 | import string 13 | import re 14 | import argparse 15 | import json 16 | import sys 17 | #sys.setdefaultencoding('utf8') 18 | import nltk 19 | import pdb 20 | 21 | # split Chinese with English 22 | def mixed_segmentation(in_str, rm_punc=False): 23 | in_str = (in_str).lower().strip() 24 | segs_out = [] 25 | temp_str = "" 26 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 27 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 28 | '「','」','(',')','-','~','『','』'] 29 | for char in in_str: 30 | if rm_punc and char in sp_char: 31 | continue 32 | if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char: 33 | if temp_str != "": 34 | ss = nltk.word_tokenize(temp_str) 35 | segs_out.extend(ss) 36 | temp_str = "" 37 | segs_out.append(char) 38 | else: 39 | temp_str += char 40 | 41 | #handling last part 42 | if temp_str != "": 43 | 
ss = nltk.word_tokenize(temp_str) 44 | segs_out.extend(ss) 45 | 46 | return segs_out 47 | 48 | 49 | # remove punctuation 50 | def remove_punctuation(in_str): 51 | in_str = str(in_str).lower().strip() 52 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 53 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 54 | '「','」','(',')','-','~','『','』'] 55 | out_segs = [] 56 | for char in in_str: 57 | if char in sp_char: 58 | continue 59 | else: 60 | out_segs.append(char) 61 | return ''.join(out_segs) 62 | 63 | 64 | # find longest common string 65 | def find_lcs(s1, s2): 66 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 67 | mmax = 0 68 | p = 0 69 | for i in range(len(s1)): 70 | for j in range(len(s2)): 71 | if s1[i] == s2[j]: 72 | m[i+1][j+1] = m[i][j]+1 73 | if m[i+1][j+1] > mmax: 74 | mmax=m[i+1][j+1] 75 | p=i+1 76 | return s1[p-mmax:p], mmax 77 | 78 | # 79 | def evaluate(ground_truth_file, prediction_file): 80 | f1 = 0 81 | em = 0 82 | total_count = 0 83 | skip_count = 0 84 | for instance in ground_truth_file["data"]: 85 | #context_id = instance['context_id'].strip() 86 | #context_text = instance['context_text'].strip() 87 | for para in instance["paragraphs"]: 88 | for qas in para['qas']: 89 | total_count += 1 90 | query_id = qas['id'].strip() 91 | query_text = qas['question'].strip() 92 | answers = [x["text"] for x in qas['answers']] 93 | 94 | if query_id not in prediction_file: 95 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 96 | skip_count += 1 97 | continue 98 | 99 | prediction = (prediction_file[query_id]) 100 | f1 += calc_f1_score(answers, prediction) 101 | em += calc_em_score(answers, prediction) 102 | 103 | f1_score = 100.0 * f1 / total_count 104 | em_score = 100.0 * em / total_count 105 | return f1_score, em_score, total_count, skip_count 106 | 107 | 108 | def calc_f1_score(answers, prediction): 109 | f1_scores = [] 110 | for ans in answers: 111 | ans_segs = mixed_segmentation(ans, rm_punc=True) 112 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 113 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 114 | if lcs_len == 0: 115 | f1_scores.append(0) 116 | continue 117 | precision = 1.0*lcs_len/len(prediction_segs) 118 | recall = 1.0*lcs_len/len(ans_segs) 119 | f1 = (2*precision*recall)/(precision+recall) 120 | f1_scores.append(f1) 121 | return max(f1_scores) 122 | 123 | 124 | def calc_em_score(answers, prediction): 125 | em = 0 126 | for ans in answers: 127 | ans_ = remove_punctuation(ans) 128 | prediction_ = remove_punctuation(prediction) 129 | if ans_ == prediction_: 130 | em = 1 131 | break 132 | return em 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 136 | parser.add_argument('dataset_file', help='Official dataset file') 137 | parser.add_argument('prediction_file', help='Your prediction File') 138 | args = parser.parse_args() 139 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 140 | prediction_file = json.load(open(args.prediction_file, 'rb')) 141 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 142 | AVG = (EM+F1)*0.5 143 | output_result = OrderedDict() 144 | output_result['AVERAGE'] = '%.3f' % AVG 145 | output_result['F1'] = '%.3f' % F1 146 | output_result['EM'] = '%.3f' % EM 147 | output_result['TOTAL'] = TOTAL 148 | output_result['SKIP'] = SKIP 149 | output_result['FILE'] = args.prediction_file 150 | print(json.dumps(output_result)) 151 | 152 | 
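Editorial usage sketch (not part of the repository; file names are placeholders and nltk must be installed). The script can be run as `python cmrc2018_evaluate.py cmrc2018_dev.json predictions.json`, or `evaluate()` can be called directly:

    import json
    from cmrc2018_evaluate import evaluate

    # predictions.json is assumed to map each question id to its predicted answer string
    ground_truth = json.load(open('cmrc2018_dev.json', 'rb'))
    predictions = json.load(open('predictions.json', 'rb'))
    F1, EM, TOTAL, SKIP = evaluate(ground_truth, predictions)
    print('F1=%.3f EM=%.3f (total=%d, skipped=%d)' % (F1, EM, TOTAL, SKIP))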
-------------------------------------------------------------------------------- /src/textbrewer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | 4 | def cycle(iterable): 5 | while True: 6 | for x in iterable: 7 | yield x 8 | 9 | def initializer_builder(std): 10 | _std = std 11 | def init_weights(module): 12 | if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): 13 | module.weight.data.normal_(mean=0.0, std=_std) 14 | if isinstance(module, torch.nn.Linear) and module.bias is not None: 15 | module.bias.data.zero_() 16 | return init_weights 17 | 18 | 19 | class LayerNode: 20 | def __init__(self,name,parent=None,value=None,fullname=None): 21 | self.name = name 22 | self.fullname = fullname 23 | self.value = None 24 | self.children_name = {} 25 | self.parent = parent 26 | def __contains__(self, key): 27 | return key in self.children_name 28 | def __getitem__(self,key): 29 | return self.children_name[key] 30 | def __setitem__(self,key,value): 31 | self.children_name[key]=value 32 | def update(self,value): 33 | if self.parent: 34 | if self.parent.value is None: 35 | self.parent.value = value 36 | else: 37 | if isinstance(value,(tuple,list)): 38 | old_value = self.parent.value 39 | new_value = [old_value[i]+value[i] for i in range(len(value))] 40 | self.parent.value = new_value 41 | else: 42 | self.parent.value += value 43 | 44 | self.parent.update(value) 45 | 46 | def format(self, level=0, total=None ,indent='--',max_level=None,max_length=None): 47 | string ='' 48 | if total is None: 49 | total = self.value[0] 50 | if level ==0: 51 | max_length = self._max_name_length(indent,' ',max_level=max_level) + 1 52 | string += '\n' 53 | string +=f"{'LAYER NAME':<{max_length}}\t{'#PARAMS':>15}\t{'RATIO':>10}\t{'MEM(MB)':>8}\n" 54 | 55 | if max_level is not None and level==max_level: 56 | string += f"{indent+self.name+':':<{max_length}}\t{self.value[0]:15,d}\t{self.value[0]/total:>10.2%}\t{self.value[1]:>8.2f}\n" 57 | else: 58 | if len(self.children_name)==1: 59 | string += f"{indent+self.name:{max_length}}\n" 60 | else: 61 | string += f"{indent+self.name+':':<{max_length}}\t{self.value[0]:15,d}\t{self.value[0]/total:>10.2%}\t{self.value[1]:>8.2f}\n" 62 | for child_name, child in self.children_name.items(): 63 | string += child.format(level+1, total, 64 | indent=' '+indent, max_level=max_level,max_length=max_length) 65 | return string 66 | 67 | def _max_name_length(self,indent1='--', indent2=' ',level=0,max_level=None): 68 | length = len(self.name) + len(indent1) + level *len(indent2) 69 | if max_level is not None and level >= max_level: 70 | child_lengths = [] 71 | else: 72 | child_lengths = [child._max_name_length(indent1,indent2,level=level+1,max_level=max_level) 73 | for child in self.children_name.values()] 74 | max_length = max(child_lengths+[length]) 75 | return max_length 76 | 77 | 78 | def display_parameters(model,max_level=None): 79 | """ 80 | Display the numbers and memory usage of module parameters. 81 | 82 | Args: 83 | model (torch.nn.Module or dict): the model to be inspected. 84 | max_level (int or None): The max level to display. If ``max_level==None``, show all the levels. 85 | Returns: 86 | A formatted string and a :class:`~textbrewer.utils.LayerNode` object representing the model. 
87 | """ 88 | if isinstance(model,torch.nn.Module): 89 | state_dict = model.state_dict() 90 | elif isinstance(model,dict): 91 | state_dict = model 92 | else: 93 | raise TypeError("model should be either torch.nn.Module or a dict") 94 | hash_set = set() 95 | model_node = LayerNode('model',fullname='model') 96 | current = model_node 97 | for key,value in state_dict.items(): 98 | names = key.split('.') 99 | for i,name in enumerate(names): 100 | if name not in current: 101 | current[name] = LayerNode(name,parent=current,fullname='.'.join(names[:i+1])) 102 | current = current[name] 103 | 104 | if (value.data_ptr()) in hash_set: 105 | current.value = [0,0] 106 | current.name += "(shared)" 107 | current.fullname += "(shared)" 108 | current.update(current.value) 109 | else: 110 | hash_set.add(value.data_ptr()) 111 | current.value = [value.numel(),value.numel() * value.element_size() / 1024/1024] 112 | current.update(current.value) 113 | 114 | current = model_node 115 | 116 | result = model_node.format(max_level=max_level) 117 | #print (result) 118 | return result, model_node 119 | 120 | -------------------------------------------------------------------------------- /examples/msra_ner_example/utils_ner.py: -------------------------------------------------------------------------------- 1 | import os, pickle 2 | import torch 3 | from torch.utils.data import TensorDataset 4 | 5 | label2id_dict = { 6 | 'O': 0, 7 | 'B-LOC': 1, 8 | 'I-LOC': 2, 9 | 'B-ORG': 3, 10 | 'I-ORG': 4, 11 | 'B-PER': 5, 12 | 'I-PER': 6 13 | } 14 | 15 | id2label_dict = { 16 | 0: 'O', 17 | 1: 'B-LOC', 18 | 2: 'I-LOC', 19 | 3: 'B-ORG', 20 | 4: 'I-ORG', 21 | 5: 'B-PER', 22 | 6: 'I-PER' 23 | } 24 | 25 | class Examples: 26 | def __init__(self, tokens, label_ids): 27 | self.tokens = tokens 28 | self.label_ids = label_ids 29 | 30 | def __str__(self): 31 | return self.__repr__() 32 | 33 | def __repr__(self): 34 | s = "" 35 | s += f"tokens: {''.join(self.tokens)}\n" 36 | s += f"labels: {' '.join(str(i) for i in self.label_ids)}\n" 37 | return s 38 | 39 | class Featues: 40 | def __init__(self, token_ids, input_mask, label_ids): 41 | self.token_ids = token_ids 42 | self.input_mask = input_mask 43 | self.label_ids = label_ids 44 | 45 | 46 | def __str__(self): 47 | return self.__repr__() 48 | 49 | def __repr__(self): 50 | s = "" 51 | s += f"token_ids: {' '.join(str(i) for i in self.token_ids)}\n" 52 | s += f"label_ids: {' '.join(str(i) for i in self.label_ids)}\n" 53 | s += f"input_mask:{' '.join(str(i) for i in self.input_mask)}\n" 54 | return s 55 | 56 | def read_examples(input_file): 57 | examples = [] 58 | tokens = [] 59 | label_ids = [] 60 | errors = 0 61 | with open(input_file) as f: 62 | for idx,line in enumerate(f): 63 | if len(line.strip())==0: 64 | if len(tokens)>0: 65 | examples.append(Examples(tokens,label_ids)) 66 | tokens = [] 67 | label_ids = [] 68 | continue 69 | try: 70 | token, label = line.strip().split('\t') 71 | except ValueError: 72 | errors +=1 73 | continue 74 | tokens.append(token) 75 | label_ids.append(label2id_dict[label]) 76 | if len(tokens) > 0: 77 | examples.append(Examples(tokens, label_ids)) 78 | print ("Num errors: ", errors) 79 | return examples 80 | 81 | def convert_example_to_features(input_file, tokenizer, max_seq_length, 82 | cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0): 83 | features = [] 84 | 85 | examples = read_examples(input_file) 86 | 87 | #convert token to ids 88 | pad_label = [label2id_dict['O']] 89 | for example in examples: 90 | tokens = [cls_token] + example.tokens[:max_seq_length-2] + 
[sep_token] 91 | label_ids = pad_label + example.label_ids[:max_seq_length-2] + pad_label 92 | 93 | 94 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 95 | input_mask = [1] * len(token_ids) 96 | 97 | padding_length = max_seq_length - len(token_ids) 98 | token_ids = token_ids + [pad_token_id] * padding_length 99 | input_mask = input_mask + [0] * padding_length 100 | label_ids = label_ids + pad_label * padding_length 101 | 102 | assert len(token_ids) == len(input_mask) == len(label_ids) 103 | 104 | features.append(Featues(token_ids=token_ids,input_mask=input_mask,label_ids=label_ids)) 105 | 106 | return examples, features 107 | 108 | def read_features(input_file, max_seq_length=160, tokenizer=None, cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0): 109 | cached_features_file = input_file +f'.cached_feat_{max_seq_length}' 110 | if os.path.exists(cached_features_file): 111 | with open(cached_features_file,'rb') as f: 112 | examples, features = pickle.load(f) 113 | else: 114 | examples, features = convert_example_to_features(input_file,tokenizer,max_seq_length,cls_token,sep_token,pad_token_id) 115 | with open(cached_features_file, 'wb') as f: 116 | pickle.dump([examples, features],f) 117 | 118 | all_token_ids = torch.tensor([f.token_ids for f in features],dtype=torch.long) 119 | all_input_mask = torch.tensor([f.input_mask for f in features],dtype=torch.long) 120 | all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 121 | 122 | dataset = TensorDataset(all_token_ids,all_input_mask,all_label_ids) 123 | 124 | return examples, dataset 125 | 126 | if __name__ == '__main__': 127 | #from transformers import BertTokenizer 128 | #tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") 129 | 130 | input_file = 'msra_train_bio.txt' 131 | max_seq_length = 128 132 | 133 | #dataset = read_features(input_file, 128,tokenizer) 134 | #print (f"length of dataset: {len(dataset)}") 135 | #print (dataset[0]) 136 | #print (dataset[-1]) 137 | 138 | examples = read_examples(input_file) 139 | length = [len(example.tokens) for example in examples] 140 | import numpy as np 141 | print (np.max(length),np.mean(length),np.percentile(length,99)) 142 | print (sum(i>160 for i in length)/len(length)) -------------------------------------------------------------------------------- /examples/msra_ner_example/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | args = None 4 | 5 | def parse(opt=None): 6 | parser = argparse.ArgumentParser() 7 | 8 | ## Required parameters 9 | 10 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 11 | help="The vocabulary file that the BERT model was trained on.") 12 | parser.add_argument("--output_dir", default=None, type=str, required=True, 13 | help="The output directory where the model checkpoints will be written.") 14 | 15 | ## Other parameters 16 | parser.add_argument("--train_file", default=None, type=str) 17 | parser.add_argument("--predict_file", default=None, type=str) 18 | parser.add_argument("--do_lower_case", action='store_true', 19 | help="Whether to lower case the input text. Should be True for uncased " 20 | "models and False for cased models.") 21 | parser.add_argument("--max_seq_length", default=416, type=int, 22 | help="The maximum total input sequence length after WordPiece tokenization. 
Sequences " 23 | "longer than this will be truncated, and sequences shorter than this will be padded.") 24 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") 25 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") 26 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 27 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 28 | parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") 29 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 30 | help="Total number of training epochs to perform.") 31 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 32 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " 33 | "of training.") 34 | parser.add_argument("--verbose_logging", default=False, action='store_true', 35 | help="If true, all of the warnings related to data processing will be printed. " 36 | "A number of warnings are expected for a normal SQuAD evaluation.") 37 | parser.add_argument("--no_cuda", 38 | default=False, 39 | action='store_true', 40 | help="Whether not to use CUDA when available") 41 | parser.add_argument('--gradient_accumulation_steps', 42 | type=int, 43 | default=1, 44 | help="Number of updates steps to accumualte before performing a backward/update pass.") 45 | parser.add_argument("--local_rank", 46 | type=int, 47 | default=-1, 48 | help="local_rank for distributed training on gpus") 49 | parser.add_argument('--fp16', 50 | default=False, 51 | action='store_true', 52 | help="Whether to use 16-bit float precisoin instead of 32-bit") 53 | 54 | parser.add_argument('--random_seed',type=int,default=10236797) 55 | parser.add_argument('--load_model_type',type=str,default='bert',choices=['bert','all','none']) 56 | parser.add_argument('--weight_decay_rate',type=float,default=0.01) 57 | parser.add_argument('--do_eval',action='store_true') 58 | parser.add_argument('--PRINT_EVERY',type=int,default=200) 59 | parser.add_argument('--weight',type=float,default=1.0) 60 | parser.add_argument('--ckpt_frequency',type=int,default=2) 61 | 62 | parser.add_argument('--tuned_checkpoint_T',type=str,default=None) 63 | parser.add_argument('--tuned_checkpoint_S',type=str,default=None) 64 | parser.add_argument("--init_checkpoint_S", default=None, type=str) 65 | parser.add_argument("--bert_config_file_T", default=None, type=str, required=True) 66 | parser.add_argument("--bert_config_file_S", default=None, type=str, required=True) 67 | parser.add_argument("--temperature", default=1, type=float, required=False) 68 | parser.add_argument("--teacher_cached",action='store_true') 69 | 70 | parser.add_argument('--schedule',type=str,default='warmup_linear_release') 71 | 72 | parser.add_argument('--no_inputs_mask',action='store_true') 73 | parser.add_argument('--no_logits', action='store_true') 74 | parser.add_argument('--output_encoded_layers' ,default='true',choices=['true','false']) 75 | parser.add_argument('--output_attention_layers',default='true',choices=['true','false']) 76 | parser.add_argument('--matches',nargs='*',type=str) 77 | 78 | parser.add_argument('--lr_decay',default=None,type=float) 79 | parser.add_argument('--official_schedule',default='linear',type=str) 80 | global args 81 | if opt is None: 82 | args = parser.parse_args() 83 | else: 84 | args = 
parser.parse_args(opt) 85 | 86 | 87 | if __name__ == '__main__': 88 | print (args) 89 | parse(['--SAVE_DIR','test']) 90 | print(args) 91 | -------------------------------------------------------------------------------- /docs/source/_build/html/genindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | Index — TextBrewer 0.1.8 documentation 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
[remainder of the tag-stripped Sphinx build page: Read the Docs theme navigation ("Docs » Index"), an empty general-index body, and the footer "© Copyright 2020, Ziqing Yang. Built with Sphinx using a theme provided by Read the Docs."]
186 | 187 | 188 | 189 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /examples/mnli_example/predict_function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.utils.data import SequentialSampler,DistributedSampler,DataLoader 5 | from utils_glue import compute_metrics 6 | from tqdm import tqdm 7 | import logging 8 | from collections import defaultdict 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | 13 | def predict(model,eval_datasets,step,args): 14 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 15 | eval_output_dir = args.output_dir 16 | task_results = {} 17 | for eval_task,eval_dataset in zip(eval_task_names, eval_datasets): 18 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 19 | os.makedirs(eval_output_dir) 20 | logger.info("Predicting...") 21 | logger.info("***** Running predictions *****") 22 | logger.info(" task name = %s", eval_task) 23 | logger.info(" Num examples = %d", len(eval_dataset)) 24 | logger.info(" Batch size = %d", args.predict_batch_size) 25 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 26 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.predict_batch_size) 27 | model.eval() 28 | 29 | pred_logits = [] 30 | label_ids = [] 31 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=None): 32 | input_ids, input_mask, segment_ids, labels = batch 33 | input_ids = input_ids.to(args.device) 34 | input_mask = input_mask.to(args.device) 35 | segment_ids = segment_ids.to(args.device) 36 | with torch.no_grad(): 37 | logits = model(input_ids, input_mask, segment_ids) 38 | pred_logits.append(logits.detach().cpu()) 39 | label_ids.append(labels) 40 | pred_logits = np.array(torch.cat(pred_logits),dtype=np.float32) 41 | label_ids = np.array(torch.cat(label_ids),dtype=np.int64) 42 | 43 | preds = np.argmax(pred_logits, axis=1) 44 | results = compute_metrics(eval_task, preds, label_ids) 45 | 46 | logger.info("***** Eval results {} task {} *****".format(step, eval_task)) 47 | for key in sorted(results.keys()): 48 | logger.info(f"{eval_task} {key} = {results[key]:.5f}") 49 | task_results[eval_task] = results 50 | 51 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 52 | 53 | write_results(output_eval_file,step,task_results,eval_task_names) 54 | model.train() 55 | return task_results 56 | 57 | 58 | def write_results(output_eval_file,step,task_results,eval_task_names): 59 | with open(output_eval_file, "a") as writer: 60 | all_acc = 0 61 | writer.write(f"step: {step:<8d} ") 62 | line = "Acc:" 63 | 64 | for eval_task in eval_task_names: 65 | acc = task_results[eval_task]['acc'] 66 | all_acc += acc 67 | line += f"{eval_task}={acc:.5f} " 68 | all_acc /= len(eval_task_names) 69 | line += f"All={all_acc:.5f}\n" 70 | writer.write(line) 71 | 72 | def predict_ens(models,eval_datasets,step,args): 73 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 74 | eval_output_dir = args.output_dir 75 | task_results = {} 76 | for eval_task,eval_dataset in zip(eval_task_names, eval_datasets): 77 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 78 | os.makedirs(eval_output_dir) 79 | logger.info("Predicting...") 80 | logger.info("***** Running predictions *****") 81 | 
logger.info(" task name = %s", eval_task) 82 | logger.info(" Num examples = %d", len(eval_dataset)) 83 | logger.info(" Batch size = %d", args.predict_batch_size) 84 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 85 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.predict_batch_size) 86 | for model in models: 87 | model.eval() 88 | 89 | pred_logits = [] 90 | label_ids = [] 91 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=None): 92 | input_ids, input_mask, segment_ids, labels = batch 93 | input_ids = input_ids.to(args.device) 94 | input_mask = input_mask.to(args.device) 95 | segment_ids = segment_ids.to(args.device) 96 | 97 | with torch.no_grad(): 98 | logits_list = [model(input_ids, input_mask, segment_ids) for model in models] 99 | logits = sum(logits_list)/len(logits_list) 100 | pred_logits.append(logits.detach().cpu()) 101 | label_ids.append(labels) 102 | pred_logits = np.array(torch.cat(pred_logits),dtype=np.float32) 103 | label_ids = np.array(torch.cat(label_ids),dtype=np.int64) 104 | 105 | preds = np.argmax(pred_logits, axis=1) 106 | results = compute_metrics(eval_task, preds, label_ids) 107 | 108 | logger.info("***** Eval results {} task {} *****".format(step, eval_task)) 109 | for key in sorted(results.keys()): 110 | logger.info(f"{eval_task} {key} = {results[key]:.5f}") 111 | task_results[eval_task] = results 112 | 113 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 114 | 115 | write_results(output_eval_file,step,task_results,eval_task_names) 116 | for model in models: 117 | model.train() 118 | return task_results 119 | -------------------------------------------------------------------------------- /src/textbrewer/presets.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | from .losses import * 4 | from .schedulers import * 5 | from .utils import cycle 6 | from .projections import linear_projection, projection_with_activation 7 | 8 | class DynamicKeyDict: 9 | def __init__(self, kv_dict): 10 | self.store = kv_dict 11 | def __getitem__(self, key): 12 | if not isinstance(key,(list,tuple)): 13 | return self.store[key] 14 | else: 15 | name = key[0] 16 | args = key[1:] 17 | if len(args)==1 and isinstance(args[0],dict): 18 | return self.store[name](**(args[0])) 19 | else: 20 | return self.store[name](*args) 21 | def __setitem__(self, key, value): 22 | self.store[key] = value 23 | def __contains__(self, key): 24 | if isinstance(key, (list,tuple)): 25 | return key[0] in self.store 26 | else: 27 | return key in self.store 28 | 29 | TEMPERATURE_SCHEDULER=DynamicKeyDict( 30 | {'constant': constant_temperature_scheduler, 31 | 'flsw': flsw_temperature_scheduler_builder, 32 | 'cwsm':cwsm_temperature_scheduler_builder}) 33 | """ 34 | (*custom dict*) used to dynamically adjust distillation temperature. 35 | 36 | * '**constant**' : Constant temperature. 37 | * '**flsw**' : See `Preparing Lessons: Improve Knowledge Distillation with Better Supervision `_. Needs parameters ``beta`` and ``gamma``. 38 | * '**cwsm**': See `Preparing Lessons: Improve Knowledge Distillation with Better Supervision `_. Needs parameter ``beta``. 
39 | 40 | Different from other options, when using ``'flsw'`` and ``'cwsm'``, you need to provide extra parameters, for example:: 41 | 42 | #flsw 43 | distill_config = DistillationConfig( 44 | temperature_scheduler = ['flsw', 1, 2] # beta=1, gamma=2 45 | ) 46 | 47 | #cwsm 48 | distill_config = DistillationConfig( 49 | temperature_scheduler = ['cwsm', 1] # beta = 1 50 | ) 51 | 52 | """ 53 | 54 | 55 | 56 | FEATURES = ['hidden','attention'] 57 | 58 | 59 | ADAPTOR_KEYS = ['logits','logits_mask','losses','inputs_mask','labels'] + FEATURES 60 | """ 61 | (*list*) valid keys of the dict returned by the adaptor, includes: 62 | 63 | * '**logits**' 64 | * '**logits_mask**' 65 | * '**losses**' 66 | * '**inputs_mask**' 67 | * '**labels**' 68 | * '**hidden**' 69 | * '**attention**' 70 | """ 71 | 72 | 73 | KD_LOSS_MAP = {'mse': kd_mse_loss, 74 | 'ce': kd_ce_loss} 75 | """ 76 | (*dict*) available KD losses 77 | 78 | * '**mse**' : mean squared error 79 | * '**ce**': cross-entropy loss 80 | """ 81 | 82 | MATCH_LOSS_MAP = {'attention_mse_sum': att_mse_sum_loss, 83 | 'attention_mse': att_mse_loss, 84 | 'attention_ce_mean': att_ce_mean_loss, 85 | 'attention_ce': att_ce_loss, 86 | 'hidden_mse' : hid_mse_loss, 87 | 'cos' : cos_loss, 88 | 'pkd' : pkd_loss, 89 | 'gram' : fsp_loss, 90 | 'fsp' : fsp_loss, 91 | 'mmd' : mmd_loss, 92 | 'nst' : mmd_loss} 93 | """ 94 | (*dict*) intermediate feature matching loss functions, includes: 95 | 96 | * :func:`attention_mse_sum ` 97 | * :func:`attention_mse ` 98 | * :func:`attention_ce_mean ` 99 | * :func:`attention_ce ` 100 | * :func:`hidden_mse ` 101 | * :func:`cos ` 102 | * :func:`pkd ` 103 | * :func:`fsp `, :func:`gram ` 104 | * :func:`nst `, :func:`mmd ` 105 | 106 | See :ref:`intermediate_losses` for details. 107 | """ 108 | 109 | PROJ_MAP = {'linear': linear_projection, 110 | 'relu' : projection_with_activation('ReLU'), 111 | 'tanh' : projection_with_activation('Tanh') 112 | } 113 | """ 114 | (*dict*) layers used to match the different dimensions of intermediate features 115 | 116 | * '**linear**' : linear layer, no activation 117 | * '**relu**' : ReLU activation 118 | * '**tanh**': Tanh activation 119 | """ 120 | 121 | WEIGHT_SCHEDULER = {'linear_decay': linear_decay_weight_scheduler, 122 | 'linear_growth' : linear_growth_weight_scheduler} 123 | """ 124 | (dict) Scheduler used to dynamically adjust KD loss weight and hard_label_loss weight. 125 | 126 | * ‘**linear_decay**' : decay from 1 to 0 during the whole training process. 127 | * '**linear_growth**' : grow from 0 to 1 during the whole training process. 
128 | """ 129 | 130 | #TEMPERATURE_SCHEDULER = {'constant': constant_temperature_scheduler, 131 | # 'flsw_scheduler': flsw_temperature_scheduler_builder(1,1)} 132 | 133 | 134 | MAPS = {'kd_loss': KD_LOSS_MAP, 135 | 'match_Loss': MATCH_LOSS_MAP, 136 | 'projection': PROJ_MAP, 137 | 'weight_scheduler': WEIGHT_SCHEDULER, 138 | 'temperature_scheduler': TEMPERATURE_SCHEDULER} 139 | 140 | 141 | def register_new(map_name, name, func): 142 | assert map_name in MAPS 143 | assert callable(func), "Functions to be registered is not callable" 144 | MAPS[map_name][name] = func 145 | 146 | 147 | ''' 148 | Add new loss: 149 | def my_L1_loss(feature_S, feature_T, mask=None): 150 | return (feature_S-feature_T).abs().mean() 151 | 152 | MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss 153 | ''' 154 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/train_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | import torch.nn.functional as F 4 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 5 | from torch.utils.data.distributed import DistributedSampler 6 | import os, logging 7 | from tqdm import tqdm, trange 8 | from processing import RawResult, write_predictions_google 9 | from cmrc2018_evaluate import evaluate 10 | import json 11 | import numpy as np 12 | 13 | logger = logging.getLogger("train_eval") 14 | 15 | 16 | def predict(model, eval_examples, eval_features, step, args): 17 | device = args.device 18 | logger.info("Predicting...") 19 | logger.info("***** Running predictions *****") 20 | logger.info(" Num orig examples = %d", len(eval_examples)) 21 | logger.info(" Num split examples = %d", len(eval_features)) 22 | logger.info(" Batch size = %d", args.predict_batch_size) 23 | 24 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 25 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 26 | all_doc_mask = torch.tensor([f.doc_mask for f in eval_features], dtype=torch.float) 27 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 28 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 29 | 30 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_doc_mask, all_segment_ids, all_example_index) 31 | if args.local_rank == -1: 32 | eval_sampler = SequentialSampler(eval_data) 33 | else: 34 | eval_sampler = DistributedSampler(eval_data) 35 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) 36 | 37 | model.eval() 38 | all_results = [] 39 | logger.info("Start evaluating") 40 | 41 | if os.path.exists('all_results.tmp') and not args.do_train: 42 | pass # all_results = pickle.load(open('all_results.tmp', 'rb')) 43 | else: 44 | for input_ids, input_mask, doc_mask, segment_ids, example_indices \ 45 | in tqdm(eval_dataloader, desc="Evaluating", disable=None): 46 | if len(all_results) % 1000 == 0: 47 | logger.info("Processing example: %d" % (len(all_results))) 48 | input_ids = input_ids.to(device) 49 | input_mask = input_mask.to(device) 50 | doc_mask = doc_mask.to(device) 51 | segment_ids = segment_ids.to(device) 52 | with torch.no_grad(): 53 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask, 54 | doc_mask) 55 | for i, example_index in enumerate(example_indices): 56 | start_logits = 
batch_start_logits[i].detach().cpu().tolist() 57 | end_logits = batch_end_logits[i].detach().cpu().tolist() 58 | cls_logits = 0 # Not used batch_cls_logits[i].detach().cpu().tolist() 59 | eval_feature = eval_features[example_index.item()] 60 | unique_id = int(eval_feature.unique_id) 61 | all_results.append(RawResult(unique_id=unique_id, 62 | start_logits=start_logits, 63 | end_logits=end_logits, 64 | cls_logits=cls_logits)) 65 | if not args.do_train: 66 | pass 67 | # try: 68 | # pickle.dump(all_results, open('all_results.tmp', 'wb')) 69 | # except: 70 | # print("can't save all_results.tmp") 71 | 72 | logger.info("Write predictions...") 73 | output_prediction_file = os.path.join(args.output_dir, "predictions_%d.json" % step) 74 | 75 | all_predictions, scores_diff_json = \ 76 | write_predictions_google(eval_examples, eval_features, all_results, 77 | args.n_best_size, args.max_answer_length, 78 | args.do_lower_case, output_prediction_file, 79 | output_nbest_file=None, output_null_log_odds_file=None) 80 | model.train() 81 | if args.do_eval: 82 | eval_data = json.load(open(args.predict_file, 'r', encoding='utf-8')) 83 | F1, EM, TOTAL, SKIP = evaluate(eval_data, all_predictions) # ,scores_diff_json, na_prob_thresh=0) 84 | AVG = (EM+F1)*0.5 85 | output_result = OrderedDict() 86 | output_result['AVERAGE'] = '%.3f' % AVG 87 | output_result['F1'] = '%.3f' % F1 88 | output_result['EM'] = '%.3f' % EM 89 | output_result['TOTAL'] = TOTAL 90 | output_result['SKIP'] = SKIP 91 | logger.info("***** Eval results {} *****".format(step)) 92 | logger.info(json.dumps(output_result)+'\n') 93 | 94 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 95 | with open(output_eval_file, "a") as writer: 96 | writer.write(f"Step: {step} {json.dumps(output_result)}\n") 97 | 98 | #output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_%d.json" % step) 99 | #output_null_odds_file = os.path.join(args.output_dir, "null_odds_%d.json" % (step)) 100 | 101 | # torch.save(state_dict, os.path.join(args.output_dir,"EM{:.4f}_F{:.4f}_gs{}.pkl".format(em,f1,global_step))) 102 | # print ("saving at finish") 103 | # coreModel = model.module if 'DataParallel' in model.__class__.__name__ else model 104 | # torch.save(coreModel.state_dict(),os.path.join(args.output_dir,"%d.pkl" % (global_step))) 105 | # predict(global_step) 106 | -------------------------------------------------------------------------------- /docs/source/_build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Search — TextBrewer 0.1.8 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
[remainder of the tag-stripped Sphinx build page: Read the Docs theme navigation ("Docs » Search"), the client-side search form and results area, and the footer "© Copyright 2020, Ziqing Yang. Built with Sphinx using a theme provided by Read the Docs."]
192 | 193 | 194 | 195 | 200 | 201 | 202 | 203 | 204 | 205 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /examples/random_tokens_example/distill.py: -------------------------------------------------------------------------------- 1 | import textbrewer 2 | from textbrewer import GeneralDistiller 3 | from textbrewer import TrainingConfig, DistillationConfig 4 | from transformers import BertForSequenceClassification, BertConfig, AdamW 5 | from transformers import get_linear_schedule_with_warmup 6 | import torch 7 | from torch.utils.data import Dataset,DataLoader 8 | import numpy as np 9 | 10 | #device 11 | device = torch.device('cpu') 12 | 13 | # Define models 14 | bert_config = BertConfig.from_json_file('bert_config/bert_config.json') 15 | bert_config_T3 = BertConfig.from_json_file('bert_config/bert_config_T3.json') 16 | 17 | bert_config.output_hidden_states = True 18 | bert_config_T3.output_hidden_states = True 19 | 20 | 21 | teacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2 22 | # Teacher should be initialized with pre-trained weights and fine-tuned on the downstream task. 23 | # For the demonstration purpose, we omit these steps here 24 | 25 | student_model = BertForSequenceClassification(bert_config_T3) #, num_labels = 2 26 | 27 | teacher_model.to(device=device) 28 | student_model.to(device=device) 29 | 30 | # Define Dict Dataset 31 | class DictDataset(Dataset): 32 | def __init__(self, all_input_ids, all_attention_mask, all_labels): 33 | assert len(all_input_ids)==len(all_attention_mask)==len(all_labels) 34 | self.all_input_ids = all_input_ids 35 | self.all_attention_mask = all_attention_mask 36 | self.all_labels = all_labels 37 | 38 | def __getitem__(self, index): 39 | return {'input_ids': self.all_input_ids[index], 40 | 'attention_mask': self.all_attention_mask[index], 41 | 'labels': self.all_labels[index]} 42 | 43 | def __len__(self): 44 | return self.all_input_ids.size(0) 45 | 46 | # Prepare random data 47 | all_input_ids = torch.randint(low=0,high=100,size=(100,128)) # 100 examples of length 128 48 | all_attention_mask = torch.ones_like(all_input_ids) 49 | all_labels = torch.randint(low=0,high=2,size=(100,)) 50 | dataset = DictDataset(all_input_ids, all_attention_mask, all_labels) 51 | eval_dataset = DictDataset(all_input_ids, all_attention_mask, all_labels) 52 | dataloader = DataLoader(dataset,batch_size=32) 53 | num_epochs = 10 54 | num_training_steps = len(dataloader) * num_epochs 55 | # Optimizer and learning rate scheduler 56 | optimizer = AdamW(student_model.parameters(), lr=1e-4) 57 | 58 | scheduler_class = get_linear_schedule_with_warmup 59 | # arguments dict except 'optimizer' 60 | scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps} 61 | 62 | 63 | # display model parameters statistics 64 | print("\nteacher_model's parametrers:") 65 | result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3) 66 | print (result) 67 | 68 | print("student_model's parametrers:") 69 | result, _ = textbrewer.utils.display_parameters(student_model,max_level=3) 70 | print (result) 71 | 72 | def simple_adaptor(batch, model_outputs): 73 | # The second element of model_outputs is the logits before softmax 74 | # The third element of model_outputs is hidden states 75 | return {'logits': model_outputs[1], 76 | 'hidden': model_outputs[2], 77 | 'inputs_mask': batch['attention_mask']} 78 | 79 | 80 | #Define callback function 81 | def 
predict(model, eval_dataset, step, device): 82 | ''' 83 | eval_dataset: 验证数据集 84 | ''' 85 | model.eval() 86 | pred_logits = [] 87 | label_ids =[] 88 | dataloader = DataLoader(eval_dataset,batch_size=32) 89 | for batch in dataloader: 90 | input_ids = batch['input_ids'].to(device) 91 | attention_mask = batch['attention_mask'].to(device) 92 | labels = batch['labels'] 93 | with torch.no_grad(): 94 | logits, _ = model(input_ids=input_ids, attention_mask=attention_mask) 95 | cpu_logits = logits.detach().cpu() 96 | for i in range(len(cpu_logits)): 97 | pred_logits.append(cpu_logits[i].numpy()) 98 | label_ids.append(labels[i]) 99 | model.train() 100 | pred_logits = np.array(pred_logits) 101 | label_ids = np.array(label_ids) 102 | y_p = pred_logits.argmax(axis=-1) 103 | accuracy = (y_p==label_ids).sum()/len(label_ids) 104 | print ("Number of examples: ",len(y_p)) 105 | print ("Acc: ", accuracy) 106 | from functools import partial 107 | callback_fun = partial(predict, eval_dataset=eval_dataset, device=device) # fill other arguments 108 | 109 | 110 | 111 | # Initialize configurations and distiller 112 | train_config = TrainingConfig(device=device) 113 | distill_config = DistillationConfig( 114 | temperature=8, 115 | hard_label_weight=0, 116 | kd_loss_type='ce', 117 | probability_shift=False, 118 | intermediate_matches=[ 119 | {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}, 120 | {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse', 'weight' : 1}, 121 | {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden', 'loss': 'nst', 'weight': 1}, 122 | {'layer_T':[8,8], 'layer_S':[2,2], 'feature':'hidden', 'loss': 'nst', 'weight': 1}] 123 | ) 124 | 125 | print ("train_config:") 126 | print (train_config) 127 | 128 | print ("distill_config:") 129 | print (distill_config) 130 | 131 | distiller = GeneralDistiller( 132 | train_config=train_config, distill_config = distill_config, 133 | model_T = teacher_model, model_S = student_model, 134 | adaptor_T = simple_adaptor, adaptor_S = simple_adaptor) 135 | 136 | # Start distilling 137 | with distiller: 138 | distiller.train(optimizer,dataloader, num_epochs=num_epochs, 139 | scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=callback_fun) 140 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual identity 11 | and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 
15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the 27 | overall community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or 32 | advances of any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email 36 | address, without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | ziqingyang@gmail.com. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series 87 | of actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. 
Violating these terms may lead to a temporary or 94 | permanent ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within 114 | the community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.0, available at 120 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 127 | at [https://www.contributor-covenant.org/translations][translations]. 128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | 135 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 
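    A minimal usage sketch (illustrative hyperparameter values; model and num_train_steps are assumed to be defined by the calling training script, not by this module):
        optimizer = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule='warmup_linear',
                               warmup=0.002, t_total=num_train_steps, max_grad_norm=1.0)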
31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 
76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 - beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /examples/cmrc2018_example/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /docs/source/ExperimentResults.md: -------------------------------------------------------------------------------- 1 | # Experimental Results 2 | 3 | 4 | ## English Datasets 5 | 6 | ### MNLI 7 | 8 | * Training without Distillation: 9 | 10 | | Model(ours) | MNLI | 11 | | ------------- | ------------- | 12 | | **BERT-base-cased** | 83.7 / 84.0 | 13 | | T3 | 76.1 / 76.5 | 14 | 15 | * Single-teacher distillation with `GeneralDistiller`: 16 | 17 | | Model (ours) | MNLI | 18 | | :------------- | -------------- | 19 | | **BERT-base-cased** (teacher) | 83.7 / 84.0 | 20 | | T6 (student) | 83.5 / 84.0 | 21 | | T3 (student) | 81.8 / 82.7 | 22 | | T3-small (student) | 81.3 / 81.7 | 23 | | T4-tiny (student) | 82.0 / 82.6 | 24 | | T12-nano (student) | 83.2 / 83.9 | 25 | 26 | * Multi-teacher distillation with `MultiTeacherDistiller`: 27 | 28 | | Model (ours) | MNLI | 29 | | :------------- | -------------- | 30 | | **BERT-base-cased** (teacher #1) | 83.7 / 84.0 | 31 | | **BERT-base-cased** (teacher #2) | 83.6 / 84.2 | 32 | | **BERT-base-cased** (teacher #3) | 83.7 / 83.8 | 33 | | ensemble (average of #1, #2 and #3) | 84.3 / 84.7 | 34 | | BERT-base-cased (student) | **84.8 / 85.3**| 35 | 36 | ### SQuAD 37 | 38 | * Training without Distillation: 39 | 40 | | Model(ours) | SQuAD | 41 | | ------------- | ------------- | 42 | | **BERT-base-cased** | 81.5 / 88.6 | 43 | | T6 | 75.0 / 83.3 | 44 | | T3 | 63.0 / 74.3 | 45 | 46 | * Single-teacher distillation with `GeneralDistiller`: 47 | 48 | | Model(ours) | SQuAD | 49 | | ------------- | ------------- | 50 | | **BERT-base-cased** (teacher) | 81.5 / 88.6 | 51 | | T6 (student) | 80.8 / 88.1 | 52 | | T3 (student) | 76.4 / 84.9 | 53 | | T3-small (student) | 72.3 / 81.4 | 54 | | T4-tiny (student) | 73.7 / 82.5 | 55 | |   + DA | 75.2 / 84.0 | 56 | | T12-nano (student) | 79.0 / 86.6 | 57 | 58 | **Note**: When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD. 
59 | 60 | * Multi-teacher distillation with `MultiTeacherDistiller`: 61 | 62 | | Model (ours) | SQuAD | 63 | | :------------- | -------------- | 64 | | **BERT-base-cased** (teacher #1) | 81.1 / 88.6 | 65 | | **BERT-base-cased** (teacher #2) | 81.2 / 88.5 | 66 | | **BERT-base-cased** (teacher #3) | 81.2 / 88.7 | 67 | | ensemble (average of #1, #2 and #3) | 82.3 / 89.4 | 68 | | BERT-base-cased (student) | **83.5 / 90.0**| 69 | 70 | ### CoNLL-2003 English NER 71 | 72 | * Training without Distillation: 73 | 74 | | Model(ours) | CoNLL-2003 | 75 | | ------------- | ----------- | 76 | | **BERT-base-cased** | 91.1 | 77 | | BiGRU | 81.1 | 78 | | T3 | 85.3 | 79 | 80 | * Single-teacher distillation with `GeneralDistiller`: 81 | 82 | | Model(ours) | CoNLL-2003 | 83 | | ------------- | ------------- | 84 | | **BERT-base-cased** (teacher) | 91.1 | 85 | | BiGRU | 85.3 | 86 | | T6 (student) | 90.7 | 87 | | T3 (student) | 87.5 | 88 | |   + DA | 90.0 | 89 | | T3-small (student) | 78.6 | 90 | |   + DA | - | 91 | | T4-tiny (student) | 77.5 | 92 | |   + DA | 89.1 | 93 | | T12-nano (student) | 78.8 | 94 | |   + DA | 89.6 | 95 | 96 | **Note**: HotpotQA is used for data augmentation on CoNLL-2003.
 97 | 98 | ## Chinese Datasets (RoBERTa-wwm-ext as the teacher) 99 | 100 | ### XNLI 101 | 102 | | Model | XNLI | 103 | | :--------------- | ----------------- | 104 | | **RoBERTa-wwm-ext** (teacher) | 79.9 | 105 | | T3 (student) | 78.4 | 106 | | T3-small (student) | 76.0 | 107 | | T4-tiny (student) | 76.2 | 108 | 109 | ### LCQMC 110 | 111 | | Model | LCQMC | 112 | | :--------------- | ----------- | 113 | | **RoBERTa-wwm-ext** (teacher) | 89.4 | 114 | | T3 (student) | 89.0 | 115 | | T3-small (student) | 88.1 | 116 | | T4-tiny (student) | 88.4 | 117 | 118 | ### CMRC 2018 and DRCD 119 | 120 | | Model | CMRC 2018 | DRCD | 121 | | --------------- | ---------------- | ------------ | 122 | | **RoBERTa-wwm-ext** (teacher) | 68.8 / 86.4 | 86.5 / 92.5 | 123 | | T3 (student) | 63.4 / 82.4 | 76.7 / 85.2 | 124 | |   + DA | 66.4 / 84.2 | 78.2 / 86.4 | 125 | | T3-small (student) | 46.1 / 71.0 | 71.4 / 82.2 | 126 | |   + DA | 58.0 / 79.3 | 75.8 / 84.8 | 127 | | T4-tiny (student) | 54.3 / 76.8 | 75.5 / 84.9 | 128 | |   + DA | 61.8 / 81.8 | 77.3 / 86.1 | 129 | 130 | **Note**: In these experiments, CMRC 2018 and DRCD are used as each other's augmentation dataset.
 131 | 132 | ## Chinese Datasets (Electra-base as the teacher) 133 | 134 | * Training without Distillation: 135 | 136 | | Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER| 137 | |:---------------------------|------------|--------| --------------| -------------|---------| 138 | | **Electra-base** (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 | 139 | | Electra-small (pretrained) | 72.5 | 86.3 | 62.9 / 80.2 | 79.4 / 86.4 | | 140 | 141 | * Single-teacher distillation with `GeneralDistiller`: 142 | 143 | | Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER | 144 | | :---------------------------|------------|-------------|-----------------| -------------|----------| 145 | | **Electra-base** (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 | 146 | | Electra-small (random) | 77.2 | 89.0 | 66.5 / 84.9 | 84.8 / 91.0 | | 147 | | Electra-small (pretrained) | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 | 148 | 149 | **Note**: 150 | 151 | 1. Random: randomly initialized 152 | 2. Pretrained: initialized with pretrained weights 153 | 154 | A good initialization of the student (Electra-small) improves performance. 
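The multi-teacher rows above are produced with `MultiTeacherDistiller` (implemented in this repository under `src/textbrewer/distiller_multiteacher.py`), which follows the same driver pattern as `GeneralDistiller`. The sketch below is only illustrative: the teachers, student, adaptor, optimizer, and dataloader are placeholders that a real training script would define.

```python
from textbrewer import MultiTeacherDistiller, TrainingConfig, DistillationConfig

# Teachers fine-tuned on the same task are averaged at the logit level;
# this distiller does not support intermediate feature matching.
distiller = MultiTeacherDistiller(
    train_config=TrainingConfig(device='cuda'),
    distill_config=DistillationConfig(temperature=8),
    model_T=[teacher_1, teacher_2, teacher_3],  # placeholder fine-tuned teachers
    model_S=student_model,                      # placeholder student model
    adaptor_T=simple_adaptor,                   # same adaptor contract as GeneralDistiller
    adaptor_S=simple_adaptor)

with distiller:
    distiller.train(optimizer, dataloader, num_epochs=num_epochs, callback=None)
```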
-------------------------------------------------------------------------------- /src/textbrewer/distiller_multiteacher.py: -------------------------------------------------------------------------------- 1 | from .distiller_utils import * 2 | from .distiller_basic import BasicDistiller 3 | 4 | class MultiTeacherDistiller(BasicDistiller): 5 | """ 6 | Distills multiple teacher models (of the same tasks) into a student model. **It doesn't support intermediate feature matching**. 7 | 8 | Args: 9 | train_config (:class:`TrainingConfig`): training configuration. 10 | distill_config (:class:`DistillationConfig`): distillation configuration. 11 | model_T (List[torch.nn.Module]): list of teacher models. 12 | model_S (torch.nn.Module): student model. 13 | adaptor_T (Callable): teacher model's adaptor. 14 | adaptor_S (Callable): student model's adaptor. 15 | 16 | The roles of `adaptor_T` and `adaptor_S` are explained in :py:func:`adaptor`. 17 | """ 18 | 19 | def __init__(self, train_config, 20 | distill_config, 21 | model_T, 22 | model_S, 23 | adaptor_T, 24 | adaptor_S): 25 | super(MultiTeacherDistiller, self).__init__( 26 | train_config, distill_config, 27 | model_T, model_S, 28 | adaptor_T, adaptor_S) 29 | if hasattr(self.adaptor_T,'__iter__'): 30 | assert len(self.adaptor_T)==len(self.model_T) 31 | 32 | def train_on_batch(self, batch, args): 33 | if self.d_config.is_caching_logits is False: 34 | (teacher_batch, results_T), (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args) 35 | 36 | if hasattr(self.adaptor_T,'__iter__'): 37 | results_T = [post_adaptor(adpt_t(teacher_batch,results_t)) for results_t,adpt_t in zip(results_T,self.adaptor_T)] 38 | else: 39 | results_T = [post_adaptor(self.adaptor_T(teacher_batch,results_t)) for results_t in results_T] 40 | results_S = post_adaptor(self.adaptor_S(student_batch,results_S)) 41 | else: 42 | batch, cached_logits = batch 43 | _, (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args, no_teacher_forward=True) 44 | results_S = post_adaptor(self.adaptor_S(student_batch,results_S)) 45 | results_T = [{'logits': [lo.to(self.t_config.device) for lo in logits]} for logits in cached_logits] 46 | if 'logits_mask' in results_S: 47 | results_T[0]['logits_mask'] = results_S['logits_mask'] 48 | 49 | 50 | logits_list_T = [results_t['logits'] for results_t in results_T] # list of tensor 51 | logits_list_S = results_S['logits'] # list of tensor 52 | total_loss = 0 53 | losses_dict = dict() 54 | total_kd_loss = 0 55 | 56 | if 'logits_mask' in results_S: 57 | masks_list_S = results_S['logits_mask'] 58 | logits_list_S = select_logits_with_mask(logits_list_S,masks_list_S) #(mask_sum, num_of_class) 59 | if 'logits_mask' in results_T[0]: 60 | masks_list_T = results_T[0]['logits_mask'] 61 | logits_list_T = [select_logits_with_mask(logits_list_t,masks_list_T) 62 | for logits_list_t in logits_list_T] #(mask_sum, num_of_class) 63 | 64 | if self.d_config.probability_shift is True: 65 | labels_list = results_S['labels'] 66 | for l_T, l_S, labels in zip(zip(*logits_list_T),logits_list_S,labels_list): 67 | mean_l_T = sum(l_T)/len(l_T) 68 | mean_l_T = probability_shift_(mean_l_T, labels) 69 | if self.d_config.temperature_scheduler is not None: 70 | temperature = self.d_config.temperature_scheduler(l_S, mean_l_T, self.d_config.temperature) 71 | else: 72 | temperature = self.d_config.temperature 73 | total_kd_loss += self.kd_loss(l_S, mean_l_T, temperature) 74 | else: 
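# no probability shift: average the teachers' logits and distill the student from the mean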
75 | for l_T, l_S in zip(zip(*logits_list_T),logits_list_S): 76 | mean_l_T = sum(l_T)/len(l_T) 77 | if self.d_config.temperature_scheduler is not None: 78 | temperature = self.d_config.temperature_scheduler(l_S, mean_l_T, self.d_config.temperature) 79 | else: 80 | temperature = self.d_config.temperature 81 | total_kd_loss += self.kd_loss(l_S, mean_l_T, temperature) 82 | total_loss += total_kd_loss * self.d_config.kd_loss_weight 83 | losses_dict['unweighted_kd_loss'] = total_kd_loss 84 | 85 | if 'losses' in results_S: 86 | total_hl_loss = 0 87 | for loss in results_S['losses']: 88 | # in case of multi-GPU 89 | total_hl_loss += loss.mean() 90 | total_loss += total_hl_loss * self.d_config.hard_label_weight 91 | losses_dict['unweighted_hard_label_loss'] = total_hl_loss 92 | 93 | return total_loss, losses_dict 94 | 95 | def cache_logits(self, batch, args, batch_postprocessor): 96 | if batch_postprocessor is not None: 97 | batch = batch_postprocessor(batch) 98 | 99 | if type(batch) is dict: 100 | new_batch = {} 101 | for k,v in batch.items(): 102 | if type(v) is torch.Tensor: 103 | new_batch[k] = v.to(self.t_config.device) 104 | else: 105 | new_batch[k] = v 106 | with torch.no_grad(): 107 | results_T = [model_t(**new_batch, **args) for model_t in self.model_T] 108 | else: 109 | new_batch = tuple(item.to(self.t_config.device) if type(item) is torch.Tensor else item for item in batch) 110 | with torch.no_grad(): 111 | results_T = [model_t(*new_batch, **args) for model_t in self.model_T] 112 | 113 | if hasattr(self.adaptor_T,'__iter__'): 114 | results_T = [post_adaptor(adpt_t(batch,results_t)) for results_t,adpt_t in zip(results_T,self.adaptor_T)] 115 | else: 116 | results_T = [post_adaptor(self.adaptor_T(batch,results_t)) for results_t in results_T] 117 | 118 | self.logits_cache.append([batch, [[logits.to('cpu') for logits in results_t['logits']] for results_t in results_T]]) -------------------------------------------------------------------------------- /docs/source/_build/html/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({docnames:["GettingStarted","Quickstart copy","index"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":1,"sphinx.domains.index":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,sphinx:56},filenames:["GettingStarted.md","Quickstart 
copy.md","index.rst"],objects:{},objnames:{},objtypes:{},terms:{"108m":[0,1],"10k":[0,1],"11k":[0,1],"14m":[0,1],"17m":[0,1],"20k":[0,1],"239k":[0,1],"23k":[0,1],"27k":[0,1],"2ab":[],"31m":[0,1],"393k":[0,1],"44m":[0,1],"65m":[0,1],"88k":[0,1],"case":[0,1],"class":[0,1],"default":[0,1],"import":[0,1],"public":[0,1],"return":[0,1],"try":[0,1],For:[0,1],The:[0,1],There:[0,1],acc:[0,1],account:[0,1],achiev:[0,1],actual:[0,1],adapt:[0,1],adaptor_:[0,1],adaptor_t:[0,1],adding:[0,1],adjust:[0,1],advantag:[0,1],advis:[0,1],after:[0,1],airaria:[0,1],all:[0,1],allow:2,also:[0,1],analysi:[0,1],api:[0,1],architectur:[0,1],architecur:[0,1],archtectur:[0,1],argument:[0,1],art:2,articl:[0,1],arxiv:[0,1],assum:[0,1],attent:[0,1],augment:[0,1],author:[0,1],auxiliari:[0,1],avail:[0,1],avoid:[0,1],base:[0,1,2],basic:[0,1],basicdistil:[0,1],basictrain:[0,1],batch:[0,1],been:[0,1],befor:[0,1],below:[0,1],bert:[0,1],bidirect:[0,1],bigru:[0,1],binari:[0,1],block:[0,1],both:2,bracket:[0,1],build:[0,1],built:[0,1],call:[0,1],can:[0,1],charg:[0,1],che:[0,1],checkpoint:[0,1],chen:[0,1],citat:2,cite:[0,1],classif:[0,1],clone:[0,1],cmrc2018:[0,1],cmrc:[0,1],cold:[0,1],com:[0,1],combin:[0,1],compar:[0,1],comparison:[0,1],compat:[0,1],comprehens:[0,1],compress:2,comput:[0,1],concept:2,conclus:[0,1],conduct:[0,1],conll2003:[0,1],conll:[0,1],construct:[0,1],convert:[0,1],core:2,could:[0,1],cui:[0,1],current:[0,1],data:[0,1],dataload:[0,1],dataparallel:[0,1],decai:[0,1],deep:[0,1],def:[0,1],definit:[0,1],demonstr:[0,1],design:[0,1],detail:[0,1],dev:[0,1],dictionari:[0,1],differ:[0,1],differnt:[0,1],directori:[0,1],displai:[],display_paramet:[0,1],distil:2,distilbert:[0,1],distill_config:[0,1],distillationconfig:[0,1],distillt:2,document:[0,1],doe:[0,1],doesn:[0,1],don:[0,1],drcd:[0,1],dure:[0,1],dynam:[0,1],each:[0,1],easi:[0,1,2],element:[0,1],embed:[0,1],enlgish:[0,1],entiti:[0,1],epoch:[0,1],equat:[],equivl:[0,1],especi:[0,1],etc:[0,1],evalu:[0,1],exampl:[0,1],except:[0,1],exmapl:[0,1],experi:2,explan:[0,1],ext:[0,1],extract:[0,1],faq:2,featur:[0,1],feed:[0,1],find:[0,1],first:[0,1],flexibl:[0,1],follow:2,fore:[0,1],form:[0,1],format:[0,1],forward:[0,1],found:[0,1],fp16:[0,1],frac:[],framework:[0,1,2],freeli:[0,1],from:[0,1,2],fulli:[0,1],gener:[0,1],generaldistil:[0,1],get:2,git:[0,1],github:[0,1],gpu:[0,1],gru:[0,1],guop:[0,1],hard:[0,1],has:[0,1],hat:[],have:[0,1],hbar:[],help:[0,1],here:[0,1],hidden:[0,1],hidden_ms:[0,1],hit:[0,1],hotpotqa:[0,1],how:[0,1],howev:[0,1],http:[0,1],hyperparamt:[0,1],iflytek:[0,1],implement:[0,1],impress:[0,1],includ:[0,1,2],increas:2,index:2,infer:2,initi:[0,1],input:[0,1],instal:2,intermedi:[0,1],intermediate_match:[0,1],intial:[0,1],introduct:2,issu:2,joint:[0,1],journal:[0,1],keep:[0,1],knowledg:[0,1,2],known:2,knwledg:[0,1],l3n:[0,1],l4t:[0,1],label:[0,1],laboratori:[0,1],languag:[0,1,2],larg:[0,1],larger:[0,1],latest:[0,1],layer:[0,1],layer_:[0,1],layer_t:[0,1],lcqmc:[0,1],learn:[0,1],list:[0,1],liu:[0,1],logit:[0,1],loss:[0,1],lvert:[],lvertpsirangl:[],machin:[0,1],main:[0,1],match:[0,1],matrix:[0,1],max_level:[0,1],memori:2,method:[0,1,2],metric:[0,1],mix:[0,1],mnli:[0,1],mode:[0,1],model:2,model_:[0,1],model_output:[0,1],model_t:[0,1],modifi:[0,1],modul:[0,1,2],more:[0,1],most:[0,1],mrc:[0,1],mse:[0,1],multi:[0,1],multipl:[0,1],multitaskdistil:[0,1],multiteacherdistil:[0,1],name:[0,1],natur:[0,1,2],need:[0,1],ner:[0,1],neuron:[0,1],newsqa:[0,1],nlp:[0,1,2],none:[0,1],note:[0,1],nteacher_model:[0,1],num_epoch:[0,1],number:[0,1],numpi:[0,1],offer:[0,1],offici:[0,1],onli:[0,1],open
:[0,1,2],optim:[0,1],option:[0,1],organ:[0,1],other:[0,1],otherwis:[0,1],our:[0,1],output:[0,1],own:[0,1],page:2,pair:[0,1],paper:[0,1,2],param:[0,1],paramet:[0,1],parametr:[0,1],partial:[],pass:[0,1],perform:[0,1,2],pip:[0,1],pkd:[0,1],pleas:[0,1],possibl:[0,1],pre:[0,1],predefin:[0,1],prepar:[0,1],preprint:[0,1],preset:[0,1],print:[0,1],prior:[0,1],problem:[0,1],process:[0,1,2],provid:[0,1,2],psi:[],pypi:[0,1],python:[0,1],pytorch:[0,1,2],quickli:[0,1,2],quickstart:2,random:[0,1],randomli:[0,1],rangl:[],rate:[0,1],rbt3:[0,1],read:[0,1],recogn:[0,1],recognit:[0,1],recommend:[0,1],reduc:2,rel:[0,1,2],relat:[0,1],releas:[0,1],requir:[0,1],research:[0,1],respect:[0,1],roberta:[0,1],runabl:[0,1],sacrific:2,same:[0,1],save:[0,1],scale:[0,1],schedul:[0,1],scheme:[0,1],schroding:[],search:2,second:[0,1],see:[0,1],select:[0,1],sentenc:[0,1],sequenc:[0,1],set:[0,1],setup:[0,1],sever:[0,1],shijin:[0,1],ship:[0,1],should:[0,1],show:[0,1],shown:[0,1],simpl:[0,1],simple_adaptor:[0,1],singl:[0,1],size:[0,1],small:[0,1,2],smmd:[0,1],soft:[0,1],some:[0,1],sourc:[0,1,2],span:[0,1],specif:[0,1],specifi:[0,1],speed:2,squad:[0,1],stage:[0,1],standard:[0,1],start:[0,1,2],state:[0,1,2],statist:[0,1],step:[0,1],strategi:[0,1],student:[0,1],student_model:[0,1],suitabl:[0,1],supervis:[0,1],support:[0,1],take:[0,1],task:[0,1],teacher:[0,1],teacher_model:[0,1],techniqu:[0,1,2],technolog:[0,1],temperatur:[0,1],tensorboard:[0,1],tensorboardx:[0,1],test:[0,1],text:[0,1],textbrew:[0,1],than:[0,1],thei:[0,1],theseu:[0,1],thi:[0,1],third:[0,1],three:[0,1],through:[0,1],time:[0,1],ting:[0,1],tini:[0,1],tinybert:[0,1],titl:[0,1],toi:[0,1],token:[0,1],tool:[0,1],toolkit:[0,1,2],tqdm:[0,1],tradit:[0,1],train_config:[0,1],trainingconfig:[0,1],traningconfig:[0,1],transfer:[0,1],transform:[0,1],translat:[0,1],two:[0,1],type:[0,1],typic:[0,1],unless:[0,1],updat:[0,1],usag:[0,1,2],use:[0,1,2],used:[0,1],user:2,using:[0,1],usual:[0,1],util:[0,1],valu:[0,1],varieti:[0,1],variou:[0,1,2],version:[0,1],wang:[0,1],wanxiang:[0,1],wechat:[0,1],weight:[0,1],when:[0,1],whenev:[0,1],which:[0,1,2],wide:[0,1],workflow:2,wwm:[0,1],xnli:[0,1],yang:[0,1],year:[0,1],yime:[0,1],you:[0,1],your:[0,1],zhipeng:[0,1],ziqe:[0,1]},titles:["Introduction","Quickstart","TextBrewer"],titleterms:{"function":[0,1],adaptor:[0,1],callback:[0,1],chines:[0,1],citat:[0,1],concept:[0,1],configur:[0,1],core:[0,1],dataset:[0,1],defin:[0,1],distil:[0,1],document:[],english:[0,1],experi:[0,1],faq:[0,1],follow:[0,1],indic:2,instal:[0,1],introduct:[0,1],issu:[0,1],known:[0,1],model:[0,1],quickstart:[0,1],result:[0,1],tabl:2,textbrew:2,train:[0,1],user:[0,1],welcom:[],workflow:[0,1]}}) --------------------------------------------------------------------------------