├── utils
│   ├── save_ckp.py
│   ├── download_t5.sh
│   └── parse_c4.py
├── config
│   ├── default.config
│   ├── default_local.config
│   ├── Downstream
│   │   ├── FineTune
│   │   │   ├── ELI5FT.config
│   │   │   ├── NQFT.config
│   │   │   ├── HQAFT.config
│   │   │   ├── TQAFT.config
│   │   │   ├── FEVERFT.config
│   │   │   ├── WoWFT.config
│   │   │   ├── zsREFT.config
│   │   │   └── TRexFT.config
│   │   ├── Adapter
│   │   │   ├── NQAdapter.config
│   │   │   ├── ELI5Adapter.config
│   │   │   ├── HQAAdapter.config
│   │   │   ├── TQAAdapter.config
│   │   │   ├── FEVERAdapter.config
│   │   │   ├── WoWAdapter.config
│   │   │   ├── zsREAdapter.config
│   │   │   └── TRexAdapter.config
│   │   ├── PlugD
│   │   │   ├── NQPlugD.config
│   │   │   ├── ELI5PlugD.config
│   │   │   ├── HQAPlugD.config
│   │   │   ├── TQAPlugD.config
│   │   │   ├── WoWPlugD.config
│   │   │   ├── FEVERPlugD.config
│   │   │   ├── zsREPlugD.config
│   │   │   └── TRexPlugD.config
│   │   ├── HyperPlugD
│   │   │   ├── NQHyperPlugD.config
│   │   │   ├── ELI5HyperPlugD.config
│   │   │   ├── HQAHyperPlugD.config
│   │   │   ├── TQAHyperPlugD.config
│   │   │   ├── FEVERHyperPlugD.config
│   │   │   ├── WoWHyperPlugD.config
│   │   │   ├── zsREHyperPlugD.config
│   │   │   └── TRexHyperPlugD.config
│   │   ├── PostAdapter
│   │   │   ├── NQAdapter.config
│   │   │   ├── ELI5Adapter.config
│   │   │   ├── TQAAdapter.config
│   │   │   ├── FEVERAdapter.config
│   │   │   ├── HQAAdapter.config
│   │   │   ├── WoWAdapter.config
│   │   │   ├── zsREAdapter.config
│   │   │   └── TRexAdapter.config
│   │   ├── PostHyperPlugD
│   │   │   ├── NQ2PlugD.config
│   │   │   ├── SQuADPlugD.config
│   │   │   ├── NQPlugD.config
│   │   │   ├── WoWPlugD.config
│   │   │   ├── ELI5PlugD.config
│   │   │   ├── TQAPlugD.config
│   │   │   ├── HQAPlugD.config
│   │   │   ├── FEVERPlugD.config
│   │   │   ├── zsREPlugD.config
│   │   │   └── TRexPlugD.config
│   │   └── PostPlugD
│   │       ├── NQPlugD.config
│   │       ├── ELI5PlugD.config
│   │       ├── WoWPlugD.config
│   │       ├── TQAPlugD.config
│   │       ├── HQAPlugD.config
│   │       ├── FEVERPlugD.config
│   │       ├── zsREPlugD.config
│   │       └── TRexPlugD.config
│   └── PL
│       └── PlugD.config
├── config_parser
│   ├── __init__.py
│   └── parser.py
├── reader
│   ├── __init__.py
│   └── reader.py
├── figs
│   └── overview.png
├── model
│   ├── Basic
│   │   └── layers.py
│   ├── MLM
│   │   └── PlugDPL.py
│   ├── Seq2Seq
│   │   └── Seq2Seq.py
│   ├── T5Adapter
│   │   └── T5Adapter.py
│   ├── TextClassification
│   │   └── TextClassifier.py
│   ├── __init__.py
│   ├── scheduler.py
│   ├── loss.py
│   ├── optimizer.py
│   └── metric.py
├── tools
│   ├── output_init.py
│   ├── output_tool.py
│   ├── test_tool.py
│   ├── __init__.py
│   ├── eval_tool.py
│   └── init_tool.py
├── dataset
│   ├── NQ
│   │   └── NQDataset.py
│   ├── Json
│   │   ├── JsonDataset.py
│   │   └── JsonlineDataset.py
│   ├── OpenQA
│   │   ├── OpenQADataset2.py
│   │   ├── SFDataset.py
│   │   ├── OpenQADataset.py
│   │   └── SQuADDataset.py
│   ├── kara
│   │   └── KaraDataset.py
│   ├── FEVER
│   │   └── FEVERDataset.py
│   ├── KGDataset
│   │   └── KGDataset.py
│   ├── Pretrain
│   │   ├── JsonDataset.py
│   │   └── RawDataset.py
│   └── __init__.py
├── formatter
│   ├── FEVER
│   │   ├── FEVERFormatter.py
│   │   └── FEVERCtxFormatter.py
│   ├── OpenQA
│   │   └── OpenQAFormatter.py
│   ├── TextClassification
│   │   ├── TextClassificationPlugDFormatter.py
│   │   └── TextClassificationAdapterFormatter.py
│   └── __init__.py
├── run_script
│   ├── PluginLearning
│   │   └── run_plugd_pl.sh
│   └── Downstream
│       ├── finetune.sh
│       ├── adapter.sh
│       ├── plugd.sh
│       ├── hyperplugd-large.sh
│       └── postplugd-large.sh
├── train.py
└── README.md

/utils/save_ckp.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/config/default.config:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/config_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import create_config
--------------------------------------------------------------------------------
/reader/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader import init_dataset, init_test_dataset
2 |
--------------------------------------------------------------------------------
/figs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Document-Plugin/HEAD/figs/overview.png
--------------------------------------------------------------------------------
/run_script/PluginLearning/run_plugd_pl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 -m torch.distributed.launch --nproc_per_node=8 --master_port=20086 \
4 | train.py \
5 | -c config/PL/PlugD.config \
6 | -g 0,1,2,3,4,5,6,7 \
7 | 2>&1 | tee log/PL/PlugD.log
8 |
--------------------------------------------------------------------------------
/run_script/Downstream/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=20089 \
4 | train.py -c config/Downstream/FineTune/${DATASET}FT.config \
5 | -g 0,1 \
6 | 2>&1 | tee log/OpenQA/FineTune/${DATASET}-FT-large.log
7 |
--------------------------------------------------------------------------------
/run_script/Downstream/adapter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=20086 \
4 | train.py -c config/Downstream/Adapter/${DATASET}Adapter.config \
5 | -g 0,1 \
6 | 2>&1 | tee log/OpenQA/Adapter/${DATASET}-adapter-large.log
7 |
--------------------------------------------------------------------------------
/run_script/Downstream/plugd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 | PlugD=$2
4 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=20089 \
5 | train.py -c config/Downstream/PlugD/${DATASET}PlugD.config \
6 | -g 0,1 \
7 | --checkpoint ${PlugD} \
8 | 2>&1 | tee log/Downstream/${DATASET}-plugd-large.log
9 |
--------------------------------------------------------------------------------
/run_script/Downstream/hyperplugd-large.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 | CKP=$2
4 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=20089 \
5 | train.py -c config/Downstream/HyperPlugD/${DATASET}HyperPlugD.config \
6 | -g 0,1 \
7 | --checkpoint ${CKP} #\
8 | # 2>&1 | tee log/Downstream/plugd/${DATASET}-hyperplugd-large-${SFX}.log
9 |
--------------------------------------------------------------------------------
/run_script/Downstream/postplugd-large.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 | PlugD=$2
4 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=20089 \
5 | train.py -c
config/Downstream/PostPlugD/${DATASET}PlugD.config \ 6 | -g 0,1 \ 7 | --checkpoint ${PlugD} \ 8 | --do_test \ 9 | 2>&1 | tee log/OpenQA/postplugd/${DATASET}-postplugd-large.log 10 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .MLM.PlugDPL import PlugDPlugLearning 2 | from .TextClassification.TextClassifier import TextClassifier 3 | from .Seq2Seq.Seq2Seq import Seq2Seq 4 | 5 | 6 | model_list = { 7 | "TextClassification": TextClassifier, 8 | "Seq2Seq": Seq2Seq, 9 | "PlugD": PlugDPlugLearning, 10 | } 11 | 12 | def get_model(model_name): 13 | if model_name in model_list.keys(): 14 | return model_list[model_name] 15 | else: 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /dataset/Pretrain/JsonDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class JsonDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | self.data = [json.loads(line) for line in open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8")] 11 | 12 | def __getitem__(self, idx): 13 | return self.data[idx] 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | -------------------------------------------------------------------------------- /dataset/Json/JsonDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class JsonDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | data_path = config.get("data", "%s_data_path" % mode) 11 | self.data = json.load(open(data_path, "r")) 12 | print("the number of data in %s: %s" % (mode, len(self.data))) 13 | 14 | def __getitem__(self, idx): 15 | return self.data[idx] 16 | 17 | def __len__(self): 18 | return len(self.data) 19 | -------------------------------------------------------------------------------- /tools/output_init.py: -------------------------------------------------------------------------------- 1 | from .output_tool import null_output_function, binary_output_function 2 | from .output_tool import squad_output_function, mlm_output_function 3 | 4 | output_function_dic = { 5 | "Null": null_output_function, 6 | "binary": binary_output_function, 7 | "squad": squad_output_function, 8 | "mlm": mlm_output_function, 9 | } 10 | 11 | 12 | def init_output_function(config, *args, **params): 13 | name = config.get("output", "output_function") 14 | 15 | if name in output_function_dic: 16 | return output_function_dic[name] 17 | else: 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /dataset/Pretrain/RawDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class RawDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | data = json.load(open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8")) 11 | self.data = [{"doc": d["document"], "id": did, "question": d["questions"]} 
for did, d in enumerate(data) if len(d["questions"]) > 0] 12 | 13 | def __getitem__(self, idx): 14 | return self.data[idx] 15 | 16 | def __len__(self): 17 | return len(self.data) 18 | -------------------------------------------------------------------------------- /model/scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from bmtrain.lr_scheduler.warmup import WarmupLRScheduler 4 | 5 | class T5Scheduler(WarmupLRScheduler): 6 | r""" 7 | After a warmup period during which performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`, 8 | The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{1}}{\sqrt{\text{num_iter}}}` 9 | """ 10 | 11 | def get_lr_warmup(self, num_iter) -> float: 12 | return self.start_lr 13 | 14 | def get_lr_decay(self, num_iter) -> float: 15 | return self.start_lr * math.sqrt(self.warmup_iter) / math.sqrt(num_iter) 16 | -------------------------------------------------------------------------------- /dataset/Json/JsonlineDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class JsonlineDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | data_path = config.get("data", "%s_data_path" % mode) 11 | fin = open(data_path, "r") 12 | self.data = [json.loads(line) for line in fin] 13 | print("the number of data in %s: %s" % (mode, len(self.data))) 14 | 15 | def __getitem__(self, idx): 16 | return self.data[idx] 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | -------------------------------------------------------------------------------- /dataset/NQ/NQDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class NQDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | self.data = [json.loads(line) for line in open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8")] 10 | 11 | def __getitem__(self, idx): 12 | ret = self.data[idx] 13 | return { 14 | "context": ret["passage"], 15 | "answers": [{"text": a} for a in ret["answers"]], 16 | "question": ret["question"], 17 | } 18 | 19 | def __len__(self): 20 | return len(self.data) 21 | -------------------------------------------------------------------------------- /utils/download_t5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoint/PLMs/t5-large 3 | mkdir -p checkpoint/PLMs/t5-large/tokenizer 4 | wget https://openbmb.oss-cn-hongkong.aliyuncs.com/model_center/t5-large/config.json -cP checkpoint/PLMs/t5-large 5 | wget https://openbmb.oss-cn-hongkong.aliyuncs.com/model_center/t5-large/pytorch_model.pt -cP checkpoint/PLMs/t5-large 6 | wget https://huggingface.co/t5-large/resolve/main/config.json -cP checkpoint/PLMs/t5-large/tokenizer 7 | wget https://huggingface.co/t5-large/resolve/main/generation_config.json -cP checkpoint/PLMs/t5-large/tokenizer 8 | wget https://huggingface.co/t5-large/resolve/main/tokenizer.json -cP checkpoint/PLMs/t5-large/tokenizer 9 | wget https://huggingface.co/t5-large/resolve/main/spiece.model -cP checkpoint/PLMs/t5-large/tokenizer 10 | 
-------------------------------------------------------------------------------- /config/default_local.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 16 4 | 5 | shuffle = True 6 | 7 | reader_num = 1 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-6 11 | weight_decay = 1e-5 12 | 13 | multilabel = False 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=1.0 17 | 18 | ignore_no_grad=False 19 | 20 | save_step=-1 21 | 22 | [eval] #eval parameters 23 | 24 | shuffle = False 25 | reader_num = 1 26 | 27 | [distributed] 28 | use = True 29 | backend = nccl 30 | 31 | [data] #data parameters 32 | 33 | 34 | [model] #model parameters 35 | pretrained_model_path = checkpoint/PLMs/t5-large 36 | 37 | adapter_path = None 38 | 39 | [output] #output parameters 40 | output_time = 10 41 | test_time = 1 42 | output_grad_step=200 43 | 44 | model_path = checkpoint 45 | 46 | output_function = squad -------------------------------------------------------------------------------- /dataset/FEVER/FEVERDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | from tools import print_rank 5 | 6 | class FEVERDataset(Dataset): 7 | def __init__(self, config, mode, encoding="utf8", *args, **params): 8 | self.config = config 9 | self.mode = mode 10 | 11 | self.data = [] 12 | for line in open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8"): 13 | line = json.loads(line) 14 | line["output"][0]["provenance"] = line["output"][0]["provenance"][:3] 15 | self.data.append(line) 16 | print_rank("Load %s data from %s, size: %d" % (mode, config.get("data", "%s_data_path" % mode), len(self.data))) 17 | 18 | def __getitem__(self, idx): 19 | return self.data[idx] 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | -------------------------------------------------------------------------------- /dataset/OpenQA/OpenQADataset2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class OpenQADataset2(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | self.data = [] 11 | fin = open(config.get("data", "%s_data_path" % mode), "r") 12 | fin.readline() 13 | for line in fin: 14 | line = json.loads(line) 15 | for qa in line["qas"]: 16 | self.data.append({ 17 | "context": [{"text": line["context"]}], 18 | "question": qa["question"], 19 | "answers": qa["answers"] 20 | }) 21 | 22 | def __getitem__(self, idx): 23 | return self.data[idx] 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .kara.KaraDataset import make_kara_dataset 2 | 3 | 4 | from .NQ.NQDataset import NQDataset 5 | from .Json.JsonDataset import JsonDataset 6 | from .Json.JsonlineDataset import JsonlineDataset 7 | 8 | from .FEVER.FEVERDataset import FEVERDataset 9 | 10 | from .OpenQA.OpenQADataset import OpenQADataset,FewOpenQADataset 11 | from .OpenQA.SQuADDataset import SQuADDataset,FewSQuADDataset 12 | from .OpenQA.OpenQADataset2 import OpenQADataset2 13 | from .OpenQA.SFDataset import SFDataset 14 | 15 | 16 | dataset_list = { 17 | "NQT5": 
NQDataset, 18 | 19 | "json": JsonDataset, 20 | "json-line": JsonlineDataset, 21 | 22 | "kara": make_kara_dataset, 23 | 24 | "FEVER": FEVERDataset, 25 | "OpenQA": OpenQADataset, 26 | "SQuAD": SQuADDataset, 27 | "OpenQA2": OpenQADataset2, 28 | "SlotFilling": SFDataset, 29 | "FewSQuAD": FewSQuADDataset, 30 | "FewOpenQA": FewOpenQADataset, 31 | } 32 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/ELI5FT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 384 24 | 25 | finetune = True 26 | 27 | 28 | [eval] #eval parameters 29 | batch_size = 32 30 | reader_num = 4 31 | 32 | [data] #data parameters 33 | train_dataset_type = OpenQA 34 | train_formatter_type = OpenQA 35 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 36 | 37 | valid_dataset_type = OpenQA 38 | valid_formatter_type = OpenQA 39 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 40 | 41 | 42 | [model] #model parameters 43 | model_name = Seq2Seq 44 | pretrained_model = t5-large 45 | 46 | model_type = t5 47 | 48 | [output] #output parameters 49 | output_time = 20 50 | test_time = 1 51 | output_grad_step = 200 52 | 53 | model_name = ELI5Adapter 54 | output_function = squad 55 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/NQFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | finetune = True 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = NQAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/NQAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | bottleneck_dim=16 25 | 26 | [eval] #eval parameters 27 | batch_size = 32 28 | reader_num = 4 
29 | 30 | [data] #data parameters 31 | train_dataset_type = OpenQA 32 | train_formatter_type = OpenQA 33 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 34 | 35 | valid_dataset_type = OpenQA 36 | valid_formatter_type = OpenQA 37 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 38 | 39 | 40 | [model] #model parameters 41 | model_name = Seq2Seq 42 | pretrained_model = t5-large 43 | 44 | model_type = t5 45 | 46 | [output] #output parameters 47 | output_time = 20 48 | test_time = 1 49 | output_grad_step = 200 50 | 51 | model_name = NQAdapter 52 | output_function = squad 53 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/HQAFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | finetune = True 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = HQAAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/TQAFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | finetune = True 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = TQAAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /utils/parse_c4.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import kara_storage 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--in_path', '-in', default="data/c4-json", required=True) 9 | 
parser.add_argument('--out_path', '-out', default="data/c4-kara", required=True) 10 | args = parser.parse_args() 11 | 12 | 13 | 14 | out_path = args.out_path 15 | if os.path.exists(out_path): 16 | os.system("rm -rf %s" % out_path) 17 | os.makedirs(out_path, exist_ok=True) 18 | 19 | storage = kara_storage.KaraStorage("file://%s" % out_path) 20 | dataset = storage.open("C4", "train", "w", version="1st") 21 | 22 | valid_text = [] 23 | c4_path = args.in_path 24 | for fname in tqdm(os.listdir(c4_path)): 25 | fin = open(os.path.join(c4_path, fname), "r") 26 | try: 27 | lines = fin.readlines() 28 | except Exception as err: 29 | print(err) 30 | print(fname) 31 | continue 32 | for line in tqdm(lines): 33 | line = json.loads(line) 34 | if len(line["text"].split()) < 50: 35 | continue 36 | dataset.write(line) 37 | dataset.close() 38 | 39 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/ELI5Adapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 384 24 | bottleneck_dim=16 25 | 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = ELI5Adapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/FEVERFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | 24 | finetune = True 25 | 26 | [eval] #eval parameters 27 | batch_size = 64 28 | reader_num = 8 29 | 30 | [data] #data parameters 31 | train_dataset_type = FEVER 32 | train_formatter_type = FEVERCtx 33 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 34 | 35 | valid_dataset_type = FEVER 36 | valid_formatter_type = FEVERCtx 37 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 38 | 39 | labelids = [4273,150] 40 | 41 | [model] #model parameters 42 | model_name = TextClassification 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 400 51 | 52 | model_name = FEVERCtx 53 | output_function = binary 54 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/WoWFT.config: 
-------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | finetune = True 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/wow-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | 44 | pretrained_model = t5-large 45 | 46 | 47 | model_type = t5 48 | 49 | [output] #output parameters 50 | output_time = 20 51 | test_time = 1 52 | output_grad_step = 200 53 | 54 | model_name = WowAdapter 55 | 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/HQAAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = HQAAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/TQAAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 
| [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = TQAAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/FEVERAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | 24 | bottleneck_dim=16 25 | 26 | [eval] #eval parameters 27 | batch_size = 64 28 | reader_num = 8 29 | 30 | [data] #data parameters 31 | train_dataset_type = FEVER 32 | train_formatter_type = FEVERCtx 33 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 34 | 35 | valid_dataset_type = FEVER 36 | valid_formatter_type = FEVERCtx 37 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 38 | 39 | labelids = [4273,150] 40 | 41 | [model] #model parameters 42 | model_name = TextClassification 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 400 51 | 52 | model_name = FEVERCtx 53 | output_function = binary 54 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/WoWAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/wow-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | 44 | pretrained_model = t5-large 45 | 46 | 47 | model_type = t5 48 | 49 | [output] #output parameters 50 | output_time = 20 51 | test_time = 1 52 | output_grad_step = 200 53 | 54 | model_name = WowAdapter 55 | 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/zsREFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | finetune = True 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = SlotFilling 33 | train_formatter_type = OpenQA 34 | train_data_path = 
data/dpr-top5/structured_zeroshot-train-kilt.jsonl 35 | 36 | valid_dataset_type = SlotFilling 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = zsREAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/zsREAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = SlotFilling 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/structured_zeroshot-train-kilt.jsonl 35 | 36 | valid_dataset_type = SlotFilling 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 39 | 40 | 41 | [model] #model parameters 42 | model_name = Seq2Seq 43 | pretrained_model = t5-large 44 | 45 | model_type = t5 46 | 47 | [output] #output parameters 48 | output_time = 20 49 | test_time = 1 50 | output_grad_step = 200 51 | 52 | model_name = zsREAdapter 53 | output_function = squad 54 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/NQPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = NQPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/FineTune/TRexFT.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 2 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | 
warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | valid_mode=step 26 | step_epoch = 2000 27 | 28 | finetune = True 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = SlotFilling 36 | train_formatter_type = OpenQA 37 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 38 | 39 | valid_dataset_type = SlotFilling 40 | valid_formatter_type = OpenQA 41 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = t5 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = TRexAdapter 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/ELI5PlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 384 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = ELI5PlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/NQHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = HyperPlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = NQHyperPlugD 56 | 
output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/HQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = HQAPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/TQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = TQAPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/WoWPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | 23 | max_len = 200 24 | ctx_len = 512 25 | ans_max_len = 64 26 | 27 | layerth = 12 28 | bottleneck_dim = 16 29 | 30 | 31 | [eval] #eval parameters 32 | batch_size = 32 33 | reader_num = 4 34 | 35 | [data] #data parameters 36 | train_dataset_type = OpenQA 37 | train_formatter_type = OpenQAPlugD 38 | train_data_path = 
data/dpr-top5/wow-train-kilt.jsonl 39 | 40 | valid_dataset_type = OpenQA 41 | valid_formatter_type = OpenQAPlugD 42 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = Seq2Seq 47 | pretrained_model = t5-large 48 | 49 | model_type = PlugD 50 | 51 | [output] #output parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 200 55 | 56 | model_name = wowPlugD 57 | output_function = squad 58 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/ELI5HyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 384 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = HyperPlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = ELI5HyperPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/Adapter/TRexAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 2 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | bottleneck_dim=16 26 | 27 | valid_mode=step 28 | step_epoch = 2000 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = SlotFilling 36 | train_formatter_type = OpenQA 37 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 38 | 39 | valid_dataset_type = SlotFilling 40 | valid_formatter_type = OpenQA 41 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = t5 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = TRexAdapter 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/HQAHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 
| 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = OpenQA 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 38 | 39 | valid_dataset_type = OpenQA 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = HyperPlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = HQAHyperPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/TQAHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | [eval] #eval parameters 30 | batch_size = 32 31 | reader_num = 4 32 | 33 | [data] #data parameters 34 | train_dataset_type = OpenQA 35 | train_formatter_type = OpenQAPlugD 36 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 37 | 38 | valid_dataset_type = OpenQA 39 | valid_formatter_type = OpenQAPlugD 40 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 41 | 42 | 43 | [model] #model parameters 44 | model_name = Seq2Seq 45 | pretrained_model = t5-large 46 | 47 | model_type = HyperPlugD 48 | 49 | [output] #output parameters 50 | output_time = 20 51 | test_time = 1 52 | output_grad_step = 200 53 | 54 | model_name = TQAHyperPlugD 55 | output_function = squad 56 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/FEVERPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 128 23 | ctx_len = 512 24 | 25 | layerth=12 26 | bottleneck_dim=16 27 | 28 | 29 | [eval] #eval parameters 30 | batch_size = 64 31 | reader_num = 8 32 | 33 | [data] #data parameters 34 | train_dataset_type = FEVER 35 | train_formatter_type = FEVERPlugD 36 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 37 | 38 | valid_dataset_type = FEVER 39 | valid_formatter_type = FEVERPlugD 40 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 41 | 42 | labelids = [4273,150] 43 | 44 | [model] #model parameters 45 | model_name = TextClassification 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 400 
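# --- illustrative note, not part of the original config ---
# Compared with the FineTune/Adapter FEVER configs, this PlugD variant splits the
# input budget into max_len = 128 for the claim and ctx_len = 512 for the plugged-in
# document, trains only parameters matching inspector_para=*plug*, and uses
# layerth = 12 (presumably the layer at which the document representation is
# attached). labelids = [4273,150] are presumably the T5 vocabulary ids of the two
# verification answer tokens; the exact label words are not shown in this file.
# A hedged sketch of how such ids could be looked up (assumption, not code from
# this repository; "yes"/"no" are placeholder words):
#   from transformers import T5Tokenizer
#   tok = T5Tokenizer.from_pretrained("t5-large")
#   print(tok("yes").input_ids, tok("no").input_ids)  # first id of each candidate word
# -----------------------------------------------------------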
54 | 55 | model_name = FEVERPlugD 56 | output_function = binary 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/FEVERHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 128 23 | ctx_len = 512 24 | 25 | layerth=12 26 | bottleneck_dim=16 27 | 28 | 29 | [eval] #eval parameters 30 | batch_size = 64 31 | reader_num = 8 32 | 33 | [data] #data parameters 34 | train_dataset_type = FEVER 35 | train_formatter_type = FEVERPlugD 36 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 37 | 38 | valid_dataset_type = FEVER 39 | valid_formatter_type = FEVERPlugD 40 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 41 | 42 | labelids = [4273,150] 43 | 44 | [model] #model parameters 45 | model_name = TextClassification 46 | pretrained_model = t5-large 47 | 48 | model_type = HyperPlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 400 54 | 55 | model_name = FEVERHyperPlugD 56 | output_function = binary 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/WoWHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | 23 | max_len = 200 24 | ctx_len = 512 25 | ans_max_len = 64 26 | 27 | layerth = 12 28 | bottleneck_dim = 16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = data/dpr-top5/wow-train-kilt.jsonl 40 | 41 | valid_dataset_type = OpenQA 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 44 | 45 | 46 | [model] #model parameters 47 | model_name = Seq2Seq 48 | pretrained_model = t5-large 49 | 50 | model_type = HyperPlugD 51 | 52 | [output] #output parameters 53 | output_time = 20 54 | test_time = 1 55 | output_grad_step = 200 56 | 57 | model_name = wowHyperPlugD 58 | output_function = squad 59 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/zsREPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | 
train_dataset_type = SlotFilling 36 | train_formatter_type = OpenQAPlugD 37 | train_data_path = data/dpr-top5/structured_zeroshot-train-kilt.jsonl 38 | 39 | valid_dataset_type = SlotFilling 40 | valid_formatter_type = OpenQAPlugD 41 | valid_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = Seq2Seq 46 | pretrained_model = t5-large 47 | 48 | model_type = PlugD 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = zsREPlugD 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/zsREHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 6 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | [eval] #eval parameters 30 | batch_size = 32 31 | reader_num = 4 32 | 33 | [data] #data parameters 34 | train_dataset_type = SlotFilling 35 | train_formatter_type = OpenQAPlugD 36 | train_data_path = data/dpr-top5/structured_zeroshot-train-kilt.jsonl 37 | 38 | valid_dataset_type = SlotFilling 39 | valid_formatter_type = OpenQAPlugD 40 | valid_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 41 | 42 | 43 | [model] #model parameters 44 | model_name = Seq2Seq 45 | pretrained_model = t5-large 46 | 47 | model_type = HyperPlugD 48 | 49 | [output] #output parameters 50 | output_time = 20 51 | test_time = 1 52 | output_grad_step = 200 53 | 54 | model_name = zsREHyperPlugD 55 | output_function = squad 56 | -------------------------------------------------------------------------------- /formatter/FEVER/FEVERFormatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | import random 7 | from transformers import T5Tokenizer,T5Config 8 | 9 | class FEVERFormatter: 10 | def __init__(self, config, mode, *args, **params): 11 | self.max_len = config.getint("train", "max_len") 12 | self.mode = mode 13 | self.config = config 14 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 15 | self.max_len = config.getint("train", "max_len") 16 | 17 | self.label2id = { 18 | "SUPPORTS": 0, 19 | "REFUTES": 1, 20 | # "NOT ENOUGH INFO": 2 21 | } 22 | 23 | def process(self, data): 24 | 25 | # claims = [d["claim"] for d in data] 26 | claims = [d["input"] for d in data] 27 | 28 | ret = self.tokenizer(claims, max_length=self.max_len, padding="max_length", truncation=True) 29 | 30 | labels = [self.label2id[d["output"][1]["answer"]] for d in data] 31 | ret["labels"] = labels 32 | for key in ret: 33 | ret[key] = torch.LongTensor(ret[key]) 34 | return ret 35 | -------------------------------------------------------------------------------- /config/Downstream/PlugD/TRexPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 2 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num 
= 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | bottleneck_dim=16 28 | 29 | 30 | valid_mode=step 31 | step_epoch = 2000 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = SlotFilling 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 41 | 42 | valid_dataset_type = SlotFilling 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 45 | 46 | 47 | [model] #model parameters 48 | model_name = Seq2Seq 49 | pretrained_model = t5-large 50 | 51 | model_type = PlugD 52 | 53 | [output] #output parameters 54 | output_time = 20 55 | test_time = 1 56 | output_grad_step = 200 57 | 58 | model_name = TRexPlugD 59 | output_function = squad 60 | -------------------------------------------------------------------------------- /config/Downstream/HyperPlugD/TRexHyperPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 2 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | bottleneck_dim = 16 28 | 29 | valid_mode=step 30 | step_epoch = 2000 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = SlotFilling 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 40 | 41 | valid_dataset_type = SlotFilling 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 44 | 45 | 46 | [model] #model parameters 47 | model_name = Seq2Seq 48 | pretrained_model = t5-large 49 | 50 | model_type = HyperPlugD 51 | 52 | [output] #output parameters 53 | output_time = 20 54 | test_time = 1 55 | output_grad_step = 200 56 | 57 | model_name = TRexHyperPlugD 58 | output_function = squad 59 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/NQAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 200 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 0 12 | 13 | 14 | warmup_steps=20000 15 | training_steps=50000 16 | max_grad_norm=0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | bottleneck_dim=16 25 | 26 | [eval] #eval parameters 27 | batch_size = 32 28 | reader_num = 4 29 | 30 | [data] #data parameters 31 | train_dataset_type = OpenQA 32 | train_formatter_type = OpenQA 33 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 34 | 35 | valid_dataset_type = OpenQA 36 | valid_formatter_type = OpenQA 37 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 38 | 39 | test_dataset_type = OpenQA 40 | test_formatter_type = OpenQA 41 | test_data_path = 
data/dpr-top5/nq-dev-kilt.jsonl 42 | 43 | 44 | [model] #model parameters 45 | model_name = OpenQAAdapter 46 | pretrained_model = t5-base 47 | 48 | model_type = PostT5 49 | 50 | [output] #output parameters 51 | output_time = 20 52 | test_time = 1 53 | output_grad_step = 200 54 | 55 | model_name = NQAdapter 56 | output_function = squad 57 | -------------------------------------------------------------------------------- /dataset/OpenQA/SFDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class SFDataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | self.data = [] 11 | self.ctxnum = 3 12 | fin = open(config.get("data", "%s_data_path" % mode), "r") 13 | self.model_type = config.get("model", "model_type") 14 | for line in fin: 15 | line = json.loads(line) 16 | question = line["input"] 17 | ctxs = line["output"][0]["provenance"][:self.ctxnum] 18 | if mode == "train" and "Post" in self.model_type: 19 | answer = [line["output"][1]["answer"]] 20 | else: 21 | answer = [l["answer"] for l in line["output"][1:]] 22 | ent0, ent1 = question.split("[SEP]") 23 | self.data.append({ 24 | "context": ctxs, 25 | "question": f"the {ent1} of {ent0} is ", # question.replace("[SEP]", " * "), 26 | "answers": answer 27 | }) 28 | 29 | def __getitem__(self, idx): 30 | return self.data[idx] 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/ELI5Adapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 384 24 | bottleneck_dim=16 25 | 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 39 | 40 | test_dataset_type = OpenQA 41 | test_formatter_type = OpenQA 42 | test_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = Seq2Seq 47 | pretrained_model = t5-large 48 | 49 | model_type = PostT5 50 | 51 | [output] #output parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 200 55 | 56 | model_name = ELI5Adapter 57 | output_function = squad 58 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/TQAAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 200 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 0 12 | 13 | 14 | warmup_steps=20000 15 | training_steps=50000 16 | max_grad_norm=0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 
| 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 39 | 40 | test_dataset_type = OpenQA 41 | test_formatter_type = OpenQA 42 | test_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = OpenQAAdapter 47 | pretrained_model = t5-base 48 | 49 | model_type = PostT5 50 | 51 | [output] #output parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 200 55 | 56 | model_name = TQAAdapter 57 | output_function = squad 58 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/FEVERAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 512 23 | 24 | bottleneck_dim=16 25 | 26 | [eval] #eval parameters 27 | batch_size = 64 28 | reader_num = 8 29 | 30 | [data] #data parameters 31 | train_dataset_type = FEVER 32 | train_formatter_type = FEVERCtx 33 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 34 | 35 | valid_dataset_type = FEVER 36 | valid_formatter_type = FEVERCtx 37 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 38 | 39 | test_dataset_type = FEVER 40 | test_formatter_type = FEVERCtx 41 | test_data_path = data/dpr-top5/fever-dev-kilt.jsonl 42 | 43 | labelids = [4273,150] 44 | 45 | [model] #model parameters 46 | model_name = TextClassification 47 | pretrained_model = t5-large 48 | 49 | model_type = PostT5 50 | 51 | [output] #output parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 400 55 | 56 | model_name = FEVERCtx 57 | output_function = binary 58 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/HQAAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 39 | 40 | test_dataset_type = OpenQA 41 | test_formatter_type = OpenQA 42 | test_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = OpenQAAdapter 47 | pretrained_model = t5-large 48 | 49 | model_type = PostT5 50 | 51 | [output] #output 
parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 200 55 | 56 | model_name = HQAAdapter 57 | output_function = squad 58 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/WoWAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 64 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = OpenQA 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/wow-train-kilt.jsonl 35 | 36 | valid_dataset_type = OpenQA 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 39 | 40 | test_dataset_type = OpenQA 41 | test_formatter_type = OpenQA 42 | test_data_path = data/dpr-top5/wow-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = OpenQAAdapter 47 | 48 | pretrained_model = t5-base 49 | 50 | 51 | model_type = PostT5 52 | 53 | [output] #output parameters 54 | output_time = 20 55 | test_time = 1 56 | output_grad_step = 200 57 | 58 | model_name = WowAdapter 59 | 60 | output_function = squad 61 | -------------------------------------------------------------------------------- /config/PL/PlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | ctx_len = 512 14 | ans_len = 128 15 | que_len = 196 16 | 17 | mlm_ratio=0.5 18 | mlm_mean_len=3 19 | 20 | warmup_steps=2000 21 | training_steps=50000 22 | max_grad_norm=1.5 23 | 24 | 25 | valid_mode=step 26 | step_epoch=1000 27 | 28 | scheduler=t5 29 | 30 | layerth=12 31 | 32 | no_valid = True 33 | 34 | 35 | [eval] #eval parameters 36 | batch_size = 16 37 | 38 | [data] #data parameters 39 | train_dataset_type = kara 40 | train_formatter_type = PlugDPL 41 | train_data_path = data/c4-kara 42 | train_kara_namespace = C4 43 | train_kara_dataset = train 44 | train_kara_version = 1st 45 | 46 | valid_dataset_type = kara 47 | valid_formatter_type = PlugDPL 48 | valid_data_path = data/c4-kara 49 | valid_kara_namespace = C4 50 | valid_kara_dataset = train 51 | valid_kara_version = 1st 52 | 53 | 54 | [model] #model parameters 55 | model_name = PlugD 56 | pretrained_model = t5-large 57 | 58 | model_type = PlugD 59 | 60 | [output] #output parameters 61 | output_time = 20 62 | test_time = 1 63 | 64 | model_name = PlugD-large 65 | output_function = mlm 66 | 67 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/NQ2PlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | 
inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA2 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = ../data/NQ/train.jsonl 40 | 41 | valid_dataset_type = OpenQA2 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = ../data/NQ/dev.jsonl 44 | 45 | test_dataset_type = OpenQA2 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = ../data/NQ/dev.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = NQPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/zsREAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 100 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | bottleneck_dim=16 26 | 27 | [eval] #eval parameters 28 | batch_size = 32 29 | reader_num = 4 30 | 31 | [data] #data parameters 32 | train_dataset_type = SlotFilling 33 | train_formatter_type = OpenQA 34 | train_data_path = data/dpr-top5/structured_zeroshot-train-kilt.jsonl 35 | 36 | valid_dataset_type = SlotFilling 37 | valid_formatter_type = OpenQA 38 | valid_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 39 | 40 | test_dataset_type = SlotFilling 41 | test_formatter_type = OpenQA 42 | test_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 43 | 44 | 45 | [model] #model parameters 46 | model_name = OpenQAAdapter 47 | pretrained_model = t5-large 48 | 49 | model_type = PostT5 50 | 51 | [output] #output parameters 52 | output_time = 20 53 | test_time = 1 54 | output_grad_step = 200 55 | 56 | model_name = zsREAdapter 57 | output_function = squad 58 | -------------------------------------------------------------------------------- /config/Downstream/PostAdapter/TRexAdapter.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 2 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*adapter* 21 | 22 | max_len = 512 23 | ans_max_len = 16 24 | 25 | bottleneck_dim=16 26 | 27 | valid_mode=step 28 | step_epoch = 2000 29 | 30 | [eval] #eval parameters 31 | batch_size = 32 32 | reader_num = 4 33 | 34 | [data] #data parameters 35 | train_dataset_type = SlotFilling 36 | train_formatter_type = OpenQA 37 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 38 | 39 | valid_dataset_type = SlotFilling 40 | valid_formatter_type = OpenQA 41 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 42 | 43 | test_dataset_type = SlotFilling 44 | test_formatter_type = OpenQA 
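# --- illustrative note, not part of the original config ---
# The Post* configs additionally declare an explicit test_* split; here, as in the
# other Post* files, it points at the same dev file as valid_*. For the T-Rex runs,
# valid_mode=step together with step_epoch = 2000 presumably switches validation
# from once per epoch to once every 2000 training steps.
# -----------------------------------------------------------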
45 | test_data_path = data/dpr-top5/trex-dev-kilt.jsonl 46 | 47 | 48 | [model] #model parameters 49 | model_name = OpenQAAdapter 50 | pretrained_model = t5-large 51 | 52 | model_type = PostT5 53 | 54 | [output] #output parameters 55 | output_time = 20 56 | test_time = 1 57 | output_grad_step = 200 58 | 59 | model_name = TRexPostAdapter 60 | output_function = squad 61 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/SQuADPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = SQuAD 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = ../data/SQuAD/train-v2.0.json 40 | 41 | valid_dataset_type = SQuAD 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = ../data/SQuAD/dev-v2.0.json 44 | 45 | test_dataset_type = SQuAD 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = ../data/SQuAD/dev-v2.0.json 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = SQuADPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/NQPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 100 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = data/dpr-top5/nq-train-kilt.jsonl 40 | 41 | valid_dataset_type = OpenQA 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = data/dpr-top5/nq-dev-kilt.jsonl 44 | 45 | test_dataset_type = OpenQA 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = data/dpr-top5/nq-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = NQPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/ELI5PlugD.config: 
-------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 384 25 | 26 | layerth = 12 27 | 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = data/dpr-top5/eli5-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = data/dpr-top5/eli5-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = Seq2Seq 53 | pretrained_model = t5-large 54 | 55 | model_type = PostPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = ELI5PlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/WoWPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-3 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | 23 | max_len = 200 24 | ctx_len = 512 25 | ans_max_len = 64 26 | 27 | layerth = 12 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = data/dpr-top5/wow-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = data/dpr-top5/wow-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = data/dpr-top5/wow-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = OpenQAAdapter 53 | pretrained_model = t5-large 54 | 55 | model_type = PostPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = wowPlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/TQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | 
doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = data/dpr-top5/triviaqa-train-kilt.jsonl 40 | 41 | valid_dataset_type = OpenQA 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 44 | 45 | test_dataset_type = OpenQA 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = data/dpr-top5/triviaqa-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = TQAPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/HQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = data/dpr-top5/hotpotqa-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = data/dpr-top5/hotpotqa-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = OpenQAAdapter 53 | pretrained_model = t5-large 54 | 55 | model_type = PostPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = HQAPlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/NQPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = ../data/KILT-DPR/dpr/nq-train-kilt.jsonl 40 | 41 | valid_dataset_type = OpenQA 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = ../data/KILT-DPR/dpr/nq-dev-kilt.jsonl 44 | 45 | test_dataset_type = OpenQA 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = 
../data/KILT-DPR/dpr/nq-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostHyperPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = NQPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/FEVERPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-4 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 128 23 | ctx_len = 512 24 | 25 | layerth=12 26 | 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 64 34 | reader_num = 8 35 | 36 | [data] #data parameters 37 | train_dataset_type = FEVER 38 | train_formatter_type = FEVERPlugD 39 | train_data_path = data/dpr-top5/fever-train-kilt.jsonl 40 | 41 | valid_dataset_type = FEVER 42 | valid_formatter_type = FEVERPlugD 43 | valid_data_path = data/dpr-top5/fever-dev-kilt.jsonl 44 | 45 | test_dataset_type = FEVER 46 | test_formatter_type = FEVERPlugD 47 | test_data_path = data/dpr-top5/fever-dev-kilt.jsonl 48 | 49 | labelids = [4273,150] 50 | 51 | [model] #model parameters 52 | model_name = TextClassificationPlugD 53 | pretrained_model = t5-large 54 | 55 | model_type = PostPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 400 61 | 62 | model_name = FEVERPlugD 63 | output_function = binary 64 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/WoWPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | 23 | max_len = 200 24 | ctx_len = 512 25 | ans_max_len = 64 26 | 27 | layerth = 12 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = ../data/KILT-DPR/dpr/wow-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = ../data/KILT-DPR/dpr/wow-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = ../data/KILT-DPR/dpr/wow-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = OpenQAAdapter 53 | pretrained_model = t5-large 54 | 55 | model_type = PostHyperPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = wowPlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/ELI5PlugD.config: 
-------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 10 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 384 25 | 26 | layerth = 12 27 | 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = ../data/KILT-DPR/dpr/eli5-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = ../data/KILT-DPR/dpr/eli5-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = ../data/KILT-DPR/dpr/eli5-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = OpenQAAdapter 53 | pretrained_model = t5-large 54 | 55 | model_type = PostHyperPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = ELI5PlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/TQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = OpenQA 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = ../data/KILT-DPR/dpr/triviaqa-train-kilt.jsonl 40 | 41 | valid_dataset_type = OpenQA 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = ../data/KILT-DPR/dpr/triviaqa-dev-kilt.jsonl 44 | 45 | test_dataset_type = OpenQA 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = ../data/KILT-DPR/dpr/triviaqa-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostHyperPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = TQAPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/HQAPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 64 23 | ctx_len = 512 24 | ans_max_len = 64 
25 | 26 | layerth = 12 27 | 28 | mid_dim=2 29 | bottleneck_dim=16 30 | 31 | doc_pos=True 32 | 33 | [eval] #eval parameters 34 | batch_size = 32 35 | reader_num = 4 36 | 37 | [data] #data parameters 38 | train_dataset_type = OpenQA 39 | train_formatter_type = OpenQAPlugD 40 | train_data_path = ../data/KILT-DPR/dpr/hotpotqa-train-kilt.jsonl 41 | 42 | valid_dataset_type = OpenQA 43 | valid_formatter_type = OpenQAPlugD 44 | valid_data_path = ../data/KILT-DPR/dpr/hotpotqa-dev-kilt.jsonl 45 | 46 | test_dataset_type = OpenQA 47 | test_formatter_type = OpenQAPlugD 48 | test_data_path = ../data/KILT-DPR/dpr/hotpotqa-dev-kilt.jsonl 49 | 50 | 51 | [model] #model parameters 52 | model_name = OpenQAAdapter 53 | pretrained_model = t5-large 54 | 55 | model_type = PostHyperPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 200 61 | 62 | model_name = HQAPlugD 63 | output_function = squad 64 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/FEVERPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 64 4 | 5 | shuffle = True 6 | 7 | reader_num = 8 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 128 23 | ctx_len = 512 24 | 25 | layerth=12 26 | 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 64 34 | reader_num = 8 35 | 36 | [data] #data parameters 37 | train_dataset_type = FEVER 38 | train_formatter_type = FEVERPlugD 39 | train_data_path = ../data/KILT-DPR/dpr/fever-train-kilt.jsonl 40 | 41 | valid_dataset_type = FEVER 42 | valid_formatter_type = FEVERPlugD 43 | valid_data_path = ../data/KILT-DPR/dpr/fever-dev-kilt.jsonl 44 | 45 | test_dataset_type = FEVER 46 | test_formatter_type = FEVERPlugD 47 | test_data_path = ../data/KILT-DPR/dpr/fever-dev-kilt.jsonl 48 | 49 | labelids = [4273,150] 50 | 51 | [model] #model parameters 52 | model_name = TextClassificationPlugD 53 | pretrained_model = t5-large 54 | 55 | model_type = PostHyperPlugD 56 | 57 | [output] #output parameters 58 | output_time = 20 59 | test_time = 1 60 | output_grad_step = 400 61 | 62 | model_name = FEVERPlugD 63 | output_function = binary 64 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/zsREPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 0 12 | 13 | 14 | warmup_steps=10000 15 | training_steps=50000 16 | max_grad_norm=0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 24 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = SlotFilling 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = data/dpr-top5/structured_zeroshot-train-kilt.jsonl 40 | 41 | valid_dataset_type = SlotFilling 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = 
data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 44 | 45 | test_dataset_type = SlotFilling 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = data/dpr-top5/structured_zeroshot-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostPlugD 55 | 56 | [output] #output parameters 57 | output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = zsREPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostPlugD/TRexPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 1e-3 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=0.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 24 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | valid_mode=step 33 | step_epoch = 1000 34 | 35 | [eval] #eval parameters 36 | batch_size = 32 37 | reader_num = 4 38 | 39 | [data] #data parameters 40 | train_dataset_type = SlotFilling 41 | train_formatter_type = OpenQAPlugD 42 | train_data_path = data/dpr-top5/trex-train-kilt-top10.jsonl 43 | 44 | valid_dataset_type = SlotFilling 45 | valid_formatter_type = OpenQAPlugD 46 | valid_data_path = data/dpr-top5/trex-dev-kilt.jsonl 47 | 48 | test_dataset_type = SlotFilling 49 | test_formatter_type = OpenQAPlugD 50 | test_data_path = data/dpr-top5/trex-dev-kilt.jsonl 51 | 52 | 53 | [model] #model parameters 54 | model_name = OpenQAAdapter 55 | pretrained_model = t5-large 56 | 57 | model_type = PostPlugD 58 | 59 | [output] #output parameters 60 | output_time = 20 61 | test_time = 1 62 | output_grad_step = 200 63 | 64 | model_name = TRexPostPlugD 65 | output_function = squad 66 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/zsREPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 0 12 | 13 | 14 | warmup_steps=10000 15 | training_steps=50000 16 | max_grad_norm=0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | [eval] #eval parameters 33 | batch_size = 32 34 | reader_num = 4 35 | 36 | [data] #data parameters 37 | train_dataset_type = SlotFilling 38 | train_formatter_type = OpenQAPlugD 39 | train_data_path = ../data/KILT-DPR/dpr/structured_zeroshot-train-kilt.jsonl 40 | 41 | valid_dataset_type = SlotFilling 42 | valid_formatter_type = OpenQAPlugD 43 | valid_data_path = ../data/KILT-DPR/dpr/structured_zeroshot-dev-kilt.jsonl 44 | 45 | test_dataset_type = SlotFilling 46 | test_formatter_type = OpenQAPlugD 47 | test_data_path = ../data/KILT-DPR/dpr/structured_zeroshot-dev-kilt.jsonl 48 | 49 | 50 | [model] #model parameters 51 | model_name = OpenQAAdapter 52 | pretrained_model = t5-large 53 | 54 | model_type = PostHyperPlugD 55 | 56 | [output] #output parameters 57 | 
output_time = 20 58 | test_time = 1 59 | output_grad_step = 200 60 | 61 | model_name = zsREPlugD 62 | output_function = squad 63 | -------------------------------------------------------------------------------- /config/Downstream/PostHyperPlugD/TRexPlugD.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 50 3 | batch_size = 32 4 | 5 | shuffle = True 6 | 7 | reader_num = 4 8 | 9 | optimizer = AdamW 10 | learning_rate = 2e-5 11 | weight_decay = 1e-5 12 | 13 | 14 | warmup_steps=2000 15 | training_steps=50000 16 | max_grad_norm=2.0 17 | 18 | scheduler=t5 19 | 20 | inspector_para=*plug* 21 | 22 | max_len = 16 23 | ctx_len = 512 24 | ans_max_len = 16 25 | 26 | layerth = 12 27 | mid_dim=2 28 | bottleneck_dim=16 29 | 30 | doc_pos=True 31 | 32 | valid_mode=step 33 | step_epoch = 2000 34 | 35 | [eval] #eval parameters 36 | batch_size = 32 37 | reader_num = 4 38 | 39 | [data] #data parameters 40 | train_dataset_type = SlotFilling 41 | train_formatter_type = OpenQAPlugD 42 | train_data_path = ../data/KILT-DPR/dpr/trex-train-kilt-top10.jsonl 43 | 44 | valid_dataset_type = SlotFilling 45 | valid_formatter_type = OpenQAPlugD 46 | valid_data_path = ../data/KILT-DPR/dpr/trex-dev-kilt.jsonl 47 | 48 | test_dataset_type = SlotFilling 49 | test_formatter_type = OpenQAPlugD 50 | test_data_path = ../data/KILT-DPR/dpr/trex-dev-kilt.jsonl 51 | 52 | 53 | [model] #model parameters 54 | model_name = OpenQAAdapter 55 | pretrained_model = t5-large 56 | 57 | model_type = PostHyperPlugD 58 | 59 | [output] #output parameters 60 | output_time = 20 61 | test_time = 1 62 | output_grad_step = 200 63 | 64 | model_name = TRexPostPlugD 65 | output_function = squad 66 | -------------------------------------------------------------------------------- /dataset/kara/KaraDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | import kara_storage 5 | from kara_storage.pytorch.base import KaraPytorchDatasetBase 6 | from kara_storage.row import RowDataset 7 | import bmtrain as bmt 8 | 9 | def make_torch_dataset(ds : RowDataset, shuffle=False, auto_distributed=True, **kwargs) -> 'KaraPytorchDatasetBase': 10 | 11 | import torch 12 | import torch.distributed 13 | from kara_storage.pytorch.base import KaraPytorchDatasetBase 14 | from kara_storage.pytorch.iter import SequentialIterator 15 | from kara_storage.pytorch.shuffle import ShuffleIterator 16 | 17 | if auto_distributed: 18 | rank = bmt.rank() 19 | size = bmt.world_size() 20 | 21 | total_length = ds.size() 22 | 23 | ds.slice_(total_length * rank // size, total_length // size) 24 | if shuffle: 25 | ret = KaraPytorchDatasetBase(ds, ShuffleIterator, seed=2333, **kwargs) 26 | else: 27 | ret = KaraPytorchDatasetBase(ds, SequentialIterator, seed=2333, **kwargs) 28 | return ret 29 | 30 | def make_kara_dataset(config, mode, encoding="utf8", *args, **params): 31 | storage = kara_storage.KaraStorage("file://%s" % config.get("data", "%s_data_path" % mode)) 32 | 33 | dataset = storage.open_dataset(config.get("data", "%s_kara_namespace" % mode), config.get("data", "%s_kara_dataset" % mode), "r", version=config.get("data", "%s_kara_version" % mode)) 34 | ret = make_torch_dataset(dataset, shuffle=True) 35 | ret.length = len(dataset) 36 | return ret 37 | 38 | -------------------------------------------------------------------------------- /formatter/TextClassification/TextClassificationPlugDFormatter.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | import random 7 | from transformers import T5Tokenizer, T5Config 8 | from transformers.file_utils import is_torch_fx_proxy 9 | 10 | class TextClassificationPlugDFormatter: 11 | def __init__(self, config, mode, *args, **params): 12 | self.ctx_len = config.getint("train", "ctx_len") 13 | self.mode = mode 14 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 15 | self.label2id = json.load(open(config.get("data", "label2id"), "r")) 16 | self.query = (" or ".join(list(self.label2id.keys())) + "?").lower() 17 | 18 | def process(self, data): 19 | ques = [self.query] * len(data) 20 | ctxs = [d["text"] for d in data] 21 | labels = [self.label2id[d["label"]] for d in data] 22 | 23 | question = self.tokenizer(ques) 24 | context = self.tokenizer(ctxs, max_length=self.ctx_len, padding="max_length", truncation=True) 25 | 26 | model_inputs = { 27 | "que_input_ids": question["input_ids"], 28 | "que_attention_mask": question["attention_mask"], 29 | "ctx_input_ids": context["input_ids"], 30 | "ctx_attention_mask": context["attention_mask"], 31 | "labels": labels, 32 | } 33 | 34 | model_inputs["decoder_input_ids"] = [[0]] * len(data) 35 | model_inputs["decoder_length"] = [1] * len(data) 36 | 37 | for key in model_inputs: 38 | model_inputs[key] = torch.LongTensor(model_inputs[key]) 39 | 40 | return model_inputs 41 | -------------------------------------------------------------------------------- /tools/output_tool.py: -------------------------------------------------------------------------------- 1 | from genericpath import exists 2 | import json 3 | import os 4 | 5 | 6 | def null_output_function(data, config, *args, **params): 7 | return str(data) 8 | 9 | 10 | def binary_output_function(data, config, *args, **params): 11 | if data['total'] == 0: 12 | metric = {'acc': 0} 13 | else: 14 | metric = {'acc': round(data['right'] / data['total'], 4)} 15 | return json.dumps(metric) 16 | 17 | 18 | 19 | def squad_output_function(data, config, *args, **params): 20 | if data["train"]: 21 | acc = round(data["right"] / data["total"], 4) 22 | return json.dumps({"tok_acc": acc}) 23 | else: 24 | if data['NA_tp'] != 0 or data['NA_fp'] != 0: 25 | pre = float(data['NA_tp']) / (data['NA_tp'] + data["NA_fp"]) 26 | recall = float(data['NA_tp']) / (data['NA_tp'] + data["NA_fn"]) 27 | if pre + recall == 0: 28 | naf1 = 0 29 | else: 30 | naf1 = 2 * pre * recall / (pre + recall) 31 | else: 32 | naf1 = 0 33 | 34 | return json.dumps({ 35 | "EM": round(data["em_sum"] / data["total"], 4), 36 | "F1": round(data["f1_sum"] / data["total"], 4), 37 | "NA_F1": round(naf1, 4), 38 | "ROUGE-L-F": round(data["ROUGE-L-F"] / data["total"], 4) if "ROUGE-L-F" in data else 0, 39 | "ROUGE-L-P": round(data["ROUGE-L-P"] / data["total"], 4) if "ROUGE-L-P" in data else 0, 40 | "ROUGE-L-R": round(data["ROUGE-L-R"] / data["total"], 4) if "ROUGE-L-R" in data else 0, 41 | } 42 | ) 43 | 44 | def mlm_output_function(data, config, *args, **params): 45 | acc = round(data["right"] / data["total"], 4) 46 | return json.dumps({"tok_acc": acc, "avg_loss": sum(data["loss"]) / len(data["loss"])}) 47 | -------------------------------------------------------------------------------- /formatter/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 
from .MLM.PlugDPLFormatter import PlugDPLFormatter 4 | 5 | 6 | from .TextClassification.TextClassificationPlugDFormatter import TextClassificationPlugDFormatter 7 | from .TextClassification.TextClassificationAdapterFormatter import TextClassificationAdapterFormatter 8 | 9 | from .FEVER.FEVERFormatter import FEVERFormatter 10 | from .FEVER.FEVERCtxFormatter import FEVERCtxFormatter,FEVERCtxPlugDFormatter,FEVERCtxED2LMFormatter 11 | 12 | from .OpenQA.OpenQAFormatter import OpenQAFormatter,OpenQAPlugDFormatter 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | formatter_list = { 18 | "None": lambda x: None, 19 | 20 | "TextClassificationPlugD": TextClassificationPlugDFormatter, 21 | "TextClassificationAdapter": TextClassificationAdapterFormatter, 22 | 23 | "FEVER": FEVERFormatter, 24 | "FEVERCtx": FEVERCtxFormatter, 25 | "FEVERPlugD": FEVERCtxPlugDFormatter, 26 | 27 | "OpenQA": OpenQAFormatter, 28 | "OpenQAPlugD": OpenQAPlugDFormatter, 29 | 30 | "PlugDPL": PlugDPLFormatter, 31 | } 32 | 33 | def init_formatter(config, mode, *args, **params): 34 | temp_mode = mode 35 | if mode != "train": 36 | try: 37 | config.get("data", "%s_formatter_type" % temp_mode) 38 | except Exception as e: 39 | logger.warning( 40 | "[reader] %s_formatter_type has not been defined in config file, use [dataset] train_formatter_type instead." % temp_mode) 41 | temp_mode = "train" 42 | which = config.get("data", "%s_formatter_type" % temp_mode) 43 | print("formatter_type", which) 44 | if which in formatter_list: 45 | formatter = formatter_list[which](config, mode, *args, **params) 46 | 47 | return formatter 48 | else: 49 | logger.error("There is no formatter called %s, check your config." % which) 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /formatter/TextClassification/TextClassificationAdapterFormatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | import random 7 | from transformers import T5Tokenizer 8 | 9 | class TextClassificationAdapterFormatter: 10 | def __init__(self, config, mode, *args, **params): 11 | self.ctx_len = config.getint("train", "ctx_len") 12 | self.mode = mode 13 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 14 | self.label2id = json.load(open(config.get("data", "label2id"), "r")) 15 | self.query = self.tokenizer.encode((" or ".join(list(self.label2id.keys())) + "?").lower(), add_special_tokens=False) 16 | 17 | def tokenize(self, doc): 18 | doctoken = self.tokenizer.encode(doc, add_special_tokens=False) 19 | alltokens = doctoken[:self.ctx_len - len(self.query) - 1] + self.query + [self.tokenizer.eos_token_id] 20 | mask = [1] * len(alltokens) + [0] * (self.ctx_len - len(alltokens)) 21 | if len(alltokens) < self.ctx_len: 22 | alltokens += [self.tokenizer.pad_token_id] * (self.ctx_len - len(alltokens)) 23 | return alltokens, mask 24 | 25 | def process(self, data): 26 | inputids, attmask = [], [] 27 | for doc in data: 28 | inp, ma = self.tokenize(doc["text"]) 29 | inputids.append(inp), attmask.append(ma) 30 | 31 | labels = [self.label2id[d["label"]] for d in data] 32 | 33 | model_inputs = { 34 | "input_ids": inputids, 35 | "attention_mask": attmask, 36 | "labels": labels, 37 | } 38 | 39 | model_inputs["decoder_input_ids"] = [[0]] * len(data) 40 | model_inputs["decoder_length"] = [1] * len(data) 
41 | 42 | for key in model_inputs: 43 | model_inputs[key] = torch.LongTensor(model_inputs[key]) 44 | 45 | return model_inputs 46 | -------------------------------------------------------------------------------- /tools/test_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from timeit import default_timer as timer 6 | 7 | from tools.eval_tool import gen_time_str, output_value 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def test(parameters, config, gpu_list): 13 | model = parameters["model"] 14 | dataset = parameters["test_dataset"] 15 | model.eval() 16 | 17 | acc_result = None 18 | total_loss = 0 19 | cnt = 0 20 | total_len = len(dataset) 21 | start_time = timer() 22 | output_info = "testing" 23 | 24 | output_time = config.getint("output", "output_time") 25 | step = -1 26 | result = [] 27 | 28 | for step, data in enumerate(dataset): 29 | for key in data.keys(): 30 | if isinstance(data[key], torch.Tensor): 31 | if len(gpu_list) > 0: 32 | data[key] = Variable(data[key].cuda()) 33 | else: 34 | data[key] = Variable(data[key]) 35 | 36 | results = model(data, config, gpu_list, acc_result, "test") 37 | result = result + results["output"] 38 | cnt += 1 39 | 40 | if step % output_time == 0: 41 | delta_t = timer() - start_time 42 | 43 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 44 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 45 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 46 | 47 | if step == -1: 48 | logger.error("There is no data given to the model in this epoch, check your data.") 49 | raise NotImplementedError 50 | 51 | delta_t = timer() - start_time 52 | output_info = "testing" 53 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 54 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 55 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 56 | 57 | return result 58 | -------------------------------------------------------------------------------- /config_parser/parser.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import functools 4 | 5 | 6 | class ConfigParser: 7 | def __init__(self, *args, **params): 8 | self.default_config = configparser.RawConfigParser(*args, **params) 9 | self.local_config = configparser.RawConfigParser(*args, **params) 10 | self.config = configparser.RawConfigParser(*args, **params) 11 | 12 | def read(self, filenames, encoding=None): 13 | if os.path.exists("config/default_local.config"): 14 | self.local_config.read("config/default_local.config", encoding=encoding) 15 | else: 16 | self.local_config.read("config/default.config", encoding=encoding) 17 | self.default_config.read("config/default.config", encoding=encoding) 18 | 19 | # if os.path.exists("/mnt/sfs_turbo/xcj/ReadOnce/docaspara/config/default_local.config"): 20 | # self.local_config.read("/mnt/sfs_turbo/xcj/ReadOnce/docaspara/config/default_local.config", encoding=encoding) 21 | # else: 22 | # self.local_config.read("/mnt/sfs_turbo/xcj/ReadOnce/docaspara/config/default.config", encoding=encoding) 23 | # self.default_config.read("/mnt/sfs_turbo/xcj/ReadOnce/docaspara/config/default.config", encoding=encoding) 24 | 25 | self.config.read(filenames, encoding=encoding) 26 | 27 | 28 | def _build_func(func_name): 29 | 
@functools.wraps(getattr(configparser.RawConfigParser, func_name)) 30 | def func(self, *args, **kwargs): 31 | try: 32 | return getattr(self.config, func_name)(*args, **kwargs) 33 | except Exception as e: 34 | try: 35 | return getattr(self.local_config, func_name)(*args, **kwargs) 36 | except Exception as e: 37 | return getattr(self.default_config, func_name)(*args, **kwargs) 38 | 39 | return func 40 | 41 | 42 | def create_config(path): 43 | for func_name in dir(configparser.RawConfigParser): 44 | if not func_name.startswith('_') and func_name != "read": 45 | setattr(ConfigParser, func_name, _build_func(func_name)) 46 | 47 | config = ConfigParser() 48 | config.read(path) 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /dataset/OpenQA/OpenQADataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | 5 | class OpenQADataset(Dataset): 6 | def __init__(self, config, mode, encoding="utf8", *args, **params): 7 | self.config = config 8 | self.mode = mode 9 | 10 | self.data = [] 11 | self.ctxnum = 3 12 | fin = open(config.get("data", "%s_data_path" % mode), "r") 13 | for line in fin: 14 | line = json.loads(line) 15 | question = line["input"] 16 | ctxs = line["output"][0]["provenance"][:self.ctxnum] 17 | answer = [l["answer"] for l in line["output"][1:]] 18 | self.data.append({ 19 | "context": ctxs, 20 | "question": question, 21 | "answers": answer 22 | }) 23 | 24 | def __getitem__(self, idx): 25 | return self.data[idx] 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | 31 | import random 32 | class FewOpenQADataset(Dataset): 33 | def __init__(self, config, mode, encoding="utf8", *args, **params): 34 | self.config = config 35 | self.mode = mode 36 | 37 | data = [] 38 | self.ctxnum = 3 39 | fin = open(config.get("data", "%s_data_path" % mode), "r") 40 | for line in fin: 41 | line = json.loads(line) 42 | question = line["input"] 43 | ctxs = line["output"][0]["provenance"][:self.ctxnum] 44 | answer = [l["answer"] for l in line["output"][1:]] 45 | data.append({ 46 | "context": ctxs, 47 | "question": question, 48 | "answers": answer 49 | }) 50 | 51 | self.few_num = config.getint("fewshot", "few_num") 52 | self.seed = config.getint("fewshot", "dataset_seed") 53 | if mode == "train": 54 | random.seed(self.seed) 55 | self.data = random.sample(data, self.few_num) 56 | else: 57 | self.data = data 58 | 59 | def __getitem__(self, idx): 60 | return self.data[idx % len(self.data)] 61 | 62 | def __len__(self): 63 | return max(200, len(self.data)) 64 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import logging 3 | from bmtrain.global_var import config as bmtconfig 4 | from bmtrain import nccl 5 | import torch 6 | from transformers.file_utils import is_torch_fx_proxy 7 | 8 | def output_log(logger: logging.Logger, info: str, level: int = logging.INFO, *args): 9 | if not (dist.is_initialized() and dist.get_rank() != 0): 10 | logger._log(level, info, args) 11 | 12 | def print_rank(*arg): 13 | if not (dist.is_initialized() and dist.get_rank() != 0): 14 | print(*arg) 15 | 16 | def reduce(var : torch.Tensor, op: str = "avg"): 17 | ret = torch.empty_like(var) 18 | nccl.allReduce( 19 | var.storage(), 20 | ret.storage(), 21 | op, 22 | bmtconfig['comm'] 23 | ) 24 | 
return ret 25 | 26 | def shift_tokens_right(input_ids, pad_token_id: int, decoder_start_token_id: int): 27 | 28 | assert ( 29 | decoder_start_token_id is not None 30 | ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" 31 | 32 | # shift inputs to the right 33 | if is_torch_fx_proxy(input_ids): 34 | # Item assignment is not supported natively for proxies. 35 | shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) 36 | shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) 37 | else: 38 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 39 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 40 | shifted_input_ids[..., 0] = decoder_start_token_id 41 | 42 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 43 | # replace possible -100 values in labels by `pad_token_id` 44 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 45 | 46 | assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" 47 | 48 | return shifted_input_ids 49 | 50 | def freeze_module(module): 51 | for param in module.parameters(): 52 | param.requires_grad = False 53 | -------------------------------------------------------------------------------- /model/TextClassification/TextClassifier.py: -------------------------------------------------------------------------------- 1 | from model.PlugD.PlugD import PlugD,HyperPlugD 2 | 3 | from model.T5Adapter.T5Adapter import T5Adapter 4 | from torch import nn 5 | from model.metric import softmax_acc,microf1 6 | import bmtrain as bmt 7 | import json 8 | from model_center.model import T5Config 9 | import os 10 | import torch 11 | 12 | class TextClassifier(nn.Module): 13 | def __init__(self, config, gpu_list, *args, **params): 14 | super(TextClassifier, self).__init__() 15 | self.plmpath = os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model")) 16 | self.t5config = T5Config.from_pretrained(self.plmpath) 17 | self.model_type = config.get("model", "model_type") 18 | print("model type:", self.model_type) 19 | if self.model_type == "t5" or self.model_type == "PostT5": 20 | self.model = T5Adapter(config) 21 | elif self.model_type == "PlugD" or self.model_type == "PostPlugD": 22 | self.model = PlugD(config) 23 | elif self.model_type == "HyperPlugD": 24 | self.model = HyperPlugD(config) 25 | else: 26 | raise ValueError("model_type has not been defined") 27 | 28 | self.labelids = json.loads(config.get("data", "labelids")) 29 | 30 | self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) 31 | 32 | def forward(self, data, config, gpu_list, acc_result, mode): 33 | batch = data["input_ids"].size(0) if "input_ids" in data else data["ctx_input_ids"].size(0) 34 | device = data["input_ids"].device if "input_ids" in data else data["ctx_input_ids"].device 35 | 36 | data["decoder_input_ids"] = torch.zeros(batch, 2, dtype=torch.long, device=device) 37 | data["decoder_length"] = torch.ones(batch, dtype=torch.long, device=device) + 1 38 | data["decoder_input_ids"][:,1] = 32099 39 | if self.model_type == "PostPlugD" and mode != "test": 40 | logits = self.model(data, no_ctx=True) 41 | else: 42 | logits = self.model(data) 43 | scores = logits[:,1,self.labelids] *(100*self.t5config.dim_model**-0.5) 44 | 45 | 46 | loss = self.loss_func(scores, data["labels"]) 47 | 48 | acc_result = 
softmax_acc(scores, data["labels"], acc_result) 49 | 50 | return {"loss": loss, "acc_result": acc_result} 51 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | class MultiLabelSoftmaxLoss(nn.Module): 9 | def __init__(self, config, dim): 10 | super(MultiLabelSoftmaxLoss, self).__init__() 11 | self.task_num = dim 12 | self.criterion = [] 13 | for a in range(0, self.task_num): 14 | self.criterion.append(nn.CrossEntropyLoss()) 15 | 16 | def forward(self, outputs, labels): 17 | loss = 0 18 | for a in range(0, len(outputs[0])): 19 | o = outputs[:, a, :].view(outputs.size()[0], -1) 20 | loss += self.criterion[a](o, labels[:, a]) 21 | 22 | return loss 23 | 24 | 25 | def multi_label_cross_entropy_loss(outputs, labels): 26 | labels = labels.float() 27 | temp = outputs 28 | res = - labels * torch.log(temp) - (1 - labels) * torch.log(1 - temp) 29 | res = torch.mean(torch.sum(res, dim=1)) 30 | 31 | return res 32 | 33 | 34 | def cross_entropy_loss(outputs, labels): 35 | criterion = nn.CrossEntropyLoss() 36 | return criterion(outputs, labels) 37 | 38 | 39 | class FocalLoss(nn.Module): 40 | def __init__(self, gamma=0, alpha=None, size_average=True): 41 | super(FocalLoss, self).__init__() 42 | self.gamma = gamma 43 | self.alpha = alpha 44 | self.size_average = size_average 45 | 46 | def forward(self, input, target): 47 | if input.dim() > 2: 48 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 49 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 50 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 51 | target = target.view(-1, 1) 52 | 53 | logpt = F.log_softmax(input) 54 | logpt = logpt.gather(1, target) 55 | logpt = logpt.view(-1) 56 | pt = Variable(logpt.data.exp()) 57 | 58 | if self.alpha is not None: 59 | if self.alpha.type() != input.data.type(): 60 | self.alpha = self.alpha.type_as(input.data) 61 | at = self.alpha.gather(0, target.data.view(-1)) 62 | logpt = logpt * Variable(at) 63 | 64 | loss = -1 * (1 - pt) ** self.gamma * logpt 65 | if self.size_average: 66 | return loss.mean() 67 | else: 68 | return loss.sum() 69 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | import torch 3 | from torch.optim import AdamW 4 | import bmtrain as bmt 5 | 6 | def get_params_for_prompt_optimization(module: torch.nn.Module): 7 | # params = [{"params": [], "lr": 5e-4}, {"params": [], "lr": 1e-5}] 8 | params = [] 9 | names = [] 10 | for t in module.named_modules(): 11 | if "nograd" in t[0]: # or "doc2para2B" in t[0] or "doc2para1B" in t[0]: 12 | continue 13 | # if "mapper" in t[0]: 14 | # params[0]["params"].extend([p for p in list(t[1]._parameters.values()) if p is not None]) 15 | # else: 16 | # params[1]["params"].extend([p for p in list(t[1]._parameters.values()) if p is not None]) 17 | # params.append({'params': [p for p in list(t[1]._parameters.values()) if p is not None]}) 18 | params.extend([p for p in list(t[1]._parameters.values()) if p is not None]) 19 | names.append(t[0]) 20 | 21 | return params, names 22 | 23 | def init_optimizer(model, config, *args, **params): 24 | optimizer_type = config.get("train", 
"optimizer") 25 | learning_rate = config.getfloat("train", "learning_rate") 26 | 27 | if config.getboolean("train", "ignore_no_grad"): 28 | param_group, param_names = get_params_for_prompt_optimization(model) 29 | print("ignore parameters with nograd in name, and only %s parameters are turned" % len(param_group)) 30 | # print(param_names) 31 | else: 32 | param_group = model.parameters() 33 | # param_group = [{"params": model.ctx_encoder.parameters(), "lr": 1e-3}, {"params": model.que_model.parameters(), "lr": 1e-4}] 34 | print("all parameters are turned") 35 | 36 | optimizer = bmt.optim.AdamOffloadOptimizer(param_group, lr=learning_rate, 37 | weight_decay=config.getfloat("train", "weight_decay")) 38 | # if optimizer_type == "adam": 39 | # optimizer = optim.Adam(param_group, lr=learning_rate, 40 | # weight_decay=config.getfloat("train", "weight_decay")) 41 | # elif optimizer_type == "sgd": 42 | # optimizer = optim.SGD(param_group, lr=learning_rate, 43 | # weight_decay=config.getfloat("train", "weight_decay")) 44 | # elif optimizer_type == "AdamW": 45 | # optimizer = AdamW(param_group, lr=learning_rate, 46 | # weight_decay=config.getfloat("train", "weight_decay")) 47 | # else: 48 | # raise NotImplementedError 49 | 50 | return optimizer 51 | -------------------------------------------------------------------------------- /model/MLM/PlugDPL.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | from transformers import T5Tokenizer 3 | from ..PlugD.PlugD import PlugD,HyperPlugD 4 | from ..Basic.DeltaT5 import DeltaT5 5 | 6 | import torch 7 | from torch import nn 8 | from model.metric import mlm_acc_loss 9 | import bmtrain as bmt 10 | from bmtrain import print_rank 11 | 12 | class PlugDPlugLearning(nn.Module): 13 | def __init__(self, config, gpu_list, *args, **params): 14 | super(PlugDPlugLearning, self).__init__() 15 | 16 | self.model = PlugD(config, pretrain=True) 17 | self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none") 18 | 19 | self.layerth = config.getint("train", "layerth") 20 | 21 | def cal_dec(self, hiddens, mask, dec_inp, dec_mask, labels): 22 | logits = self.model.cal_dec(hiddens, mask, dec_inp, dec_mask) 23 | 24 | return self.cal_loss(logits * (100*self.model.plm_config.dim_model**-0.5), labels) 25 | 26 | def cal_loss(self, logits, labels): 27 | batch, seq_len, vocab_size = logits.size() 28 | 29 | loss_shape = self.loss_func(logits.view(-1, vocab_size), labels.view(-1)) 30 | loss_mask = (labels != -100).sum(dim=1).float() 31 | loss_mask[loss_mask == 0] = torch.inf 32 | loss = (loss_shape.view(batch, seq_len).sum(dim=1) / loss_mask).mean() 33 | return loss 34 | 35 | def forward(self, data, config, gpu_list, acc_result, mode): 36 | 37 | parameters, ctx_last_hidden = self.model.generate_doc_plug_train( 38 | input_ids = data["ctx_input_ids"], 39 | attention_mask = data["ctx_attention_mask"], 40 | ) 41 | 42 | 43 | deltas = {"type": "prefix", "prefix_num": ctx_last_hidden.size(1), "layerth": self.layerth} 44 | ctx_attention_mask = data["ctx_attention_mask"] 45 | 46 | output, logits = self.model.backbone( 47 | input_ids=data["que_input_ids"], 48 | attention_mask=data["que_attention_mask"], 49 | decoder_input_ids=data["decoder_input_ids"], 50 | decoder_attention_mask=data["decoder_attention_mask"], 51 | deltas = deltas, 52 | pfxatt_mask=ctx_attention_mask, 53 | parameters = parameters 54 | ) 55 | loss = self.cal_loss(logits * (100*self.model.plm_config.dim_model**-0.5), data["labels"]) 56 | 57 | 
predict = torch.argmax(logits, dim = -1) 58 | 59 | acc_result = mlm_acc_loss(predict, data["labels"], acc_result, loss) 60 | return {"loss": loss, "acc_result": acc_result} 61 | -------------------------------------------------------------------------------- /model/Seq2Seq/Seq2Seq.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | import torch 3 | from torch import nn 4 | from model.metric import squad_metric, squad_train_metric 5 | import bmtrain as bmt 6 | import os 7 | from ..T5Adapter.T5Adapter import T5Adapter 8 | from ..PlugD.PlugD import PlugD,HyperPlugD 9 | from model_center.model import T5Config 10 | 11 | class Seq2Seq(nn.Module): 12 | def __init__(self, config, gpu_list, *args, **params): 13 | super(Seq2Seq, self).__init__() 14 | self.plmpath = os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model")) 15 | self.t5config = T5Config.from_pretrained(self.plmpath) 16 | 17 | self.model_type = config.get("model", "model_type") 18 | if self.model_type == "t5" or self.model_type == "PostT5": 19 | self.model = T5Adapter(config) 20 | elif self.model_type == "PlugD" or self.model_type == "PostPlugD": 21 | self.model = PlugD(config) 22 | elif self.model_type == "HyperPlugD" or self.model_type == "PostHyperPlugD": 23 | self.model = HyperPlugD(config) 24 | else: 25 | raise ValueError("model_type has not been defined") 26 | 27 | self.ans_len = config.getint("train", "ans_max_len") 28 | self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) 29 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(self.plmpath, "tokenizer")) 30 | 31 | self.RL = ("ELI5" in config.get("output", "model_name")) 32 | 33 | def forward(self, data, config, gpu_list, acc_result, mode): 34 | device = data["input_ids"].device if "input_ids" in data else data["ctx_input_ids"].device 35 | if mode == "train": 36 | if self.model_type in ["PostPlugD", "PostHyperPlugD"]: 37 | logits = self.model(data, no_ctx=True) * (100*self.t5config.dim_model**-0.5) 38 | else: 39 | logits = self.model(data) * (100*self.t5config.dim_model**-0.5) 40 | vocab_size = logits.shape[-1] 41 | 42 | loss = self.loss_func(logits.view(-1, vocab_size), data["labels"].view(-1)) 43 | predict = torch.argmax(logits, dim = 2) 44 | acc_result = squad_train_metric(predict, data["labels"], acc_result) 45 | else: 46 | if self.model_type in ["PostPlugD", "PostHyperPlugD"] and mode == "valid": 47 | answer = self.model.generate_greedy(data, gen_length=self.ans_len, no_ctx=True) 48 | else: 49 | answer = self.model.generate_greedy(data, gen_length=self.ans_len) 50 | loss = torch.tensor(0.0).to(device) 51 | 52 | acc_result = squad_metric(answer, data["answers"], acc_result, self.tokenizer, RL=self.RL) 53 | return {"loss": loss, "acc_result": acc_result} 54 | 55 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import logging 5 | 6 | from tools.init_tool import init_all 7 | from config_parser import create_config 8 | from tools.train_tool import train 9 | import bmtrain as bmt 10 | from bmtrain import print_rank 11 | import re 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt='%m/%d/%Y %H:%M:%S', 15 | level=logging.INFO) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def 
parse_hyper_para(setting, config): 21 | if setting is None: 22 | return None 23 | pat = re.compile("\+(.*?)=(.*?)=(.*?)\+") 24 | paras = pat.findall(setting) 25 | for para in paras: 26 | print_rank("add", para) 27 | config.set(para[0], para[1], para[2]) 28 | 29 | def print_config(config): 30 | for sec in config.sections(): 31 | print_rank("[%s]" % sec) 32 | for op in config.options(sec): 33 | print_rank("%s: %s" % (op, config.get(sec, op))) 34 | print_rank("========") 35 | 36 | from transformers import T5ForConditionalGeneration 37 | from transformers import T5Config 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--config', '-c', help="specific config file", required=True) 42 | parser.add_argument('--gpu', '-g', help="gpu id list") 43 | parser.add_argument('--checkpoint', help="checkpoint file path") 44 | parser.add_argument('--only_eval', action="store_true") 45 | parser.add_argument('--do_test', help="do test while training or not", action="store_true") 46 | parser.add_argument('--hyper_para', "-hp", default=None) 47 | parser.add_argument('--local_rank', type=int, help='local rank', default=-1) 48 | parser.add_argument('--skip_step', type=int, default=0) 49 | 50 | args = parser.parse_args() 51 | 52 | configFilePath = args.config 53 | 54 | config = create_config(configFilePath) 55 | 56 | gpu_list = [] 57 | 58 | use_gpu = True 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 60 | 61 | device_list = args.gpu.split(",") 62 | for a in range(0, len(device_list)): 63 | gpu_list.append(int(a)) 64 | 65 | bmt.init_distributed(seed=1003) 66 | print_rank(args.hyper_para) 67 | parse_hyper_para(args.hyper_para, config) 68 | 69 | cuda = torch.cuda.is_available() 70 | logger.info("CUDA available: %s" % str(cuda)) 71 | if not cuda and len(gpu_list) > 0: 72 | logger.error("CUDA is not available but specific gpu id") 73 | raise NotImplementedError 74 | 75 | import sys 76 | sys.setrecursionlimit(512 * 400 + 10) 77 | 78 | parameters = init_all(config, gpu_list, args.checkpoint, "train", skip_step=args.skip_step) 79 | parameters["skip-step"] = args.skip_step 80 | do_test = False 81 | if args.do_test: 82 | do_test = True 83 | 84 | print_config(config) 85 | 86 | train(parameters, config, gpu_list, do_test, args.only_eval) 87 | -------------------------------------------------------------------------------- /dataset/OpenQA/SQuADDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | import random 5 | 6 | class SQuADDataset(Dataset): 7 | def __init__(self, config, mode, encoding="utf8", *args, **params): 8 | self.config = config 9 | self.mode = mode 10 | 11 | data = json.load(open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8")) 12 | self.qas = [] 13 | self.context = [] 14 | for doc in data["data"]: 15 | title = doc["title"] 16 | for para in doc["paragraphs"]: 17 | context = para["context"] 18 | self.context.append({"text": context}) 19 | qas = [] 20 | for qa in para["qas"]: 21 | qa.update({"context": len(self.context) - 1}) 22 | if "is_impossible" in qa and qa["is_impossible"]: 23 | qa["answers"] = [{"text": "no answer"}] 24 | qa["answers"] = [a["text"] for a in qa["answers"]] 25 | qa["title"] = title 26 | qas.append(qa) 27 | self.qas.extend(qas) 28 | 29 | def __getitem__(self, idx): 30 | qa = self.qas[idx] 31 | ret = qa.copy() 32 | ret["context"] = [self.context[qa["context"]]] 33 | # print(ret) 34 | return ret 35 | 36 | def 
__len__(self): 37 | # return 320 38 | return len(self.qas) 39 | 40 | 41 | class FewSQuADDataset(Dataset): 42 | def __init__(self, config, mode, encoding="utf8", *args, **params): 43 | self.config = config 44 | self.mode = mode 45 | 46 | self.ratio = config.getfloat("train", "few_ratio") 47 | data = json.load(open(config.get("data", "%s_data_path" % mode), "r", encoding="utf8")) 48 | self.qas = [] 49 | self.context = [] 50 | if mode == "train": 51 | random.seed(10086) 52 | doces = random.sample(data["data"], int(self.ratio * len(data["data"]))) 53 | else: 54 | doces = data["data"] 55 | # for doc in data["data"]: 56 | for doc in doces: 57 | title = doc["title"] 58 | for para in doc["paragraphs"]: 59 | context = para["context"] 60 | self.context.append({"text": context}) 61 | qas = [] 62 | for qa in para["qas"]: 63 | qa.update({"context": len(self.context) - 1}) 64 | if "is_impossible" in qa and qa["is_impossible"]: 65 | qa["answers"] = [{"text": "no answer"}] 66 | qa["answers"] = [a["text"] for a in qa["answers"]] 67 | qa["title"] = title 68 | qas.append(qa) 69 | self.qas.extend(qas) 70 | 71 | def __getitem__(self, idx): 72 | qa = self.qas[idx % len(self.qas)] 73 | ret = qa.copy() 74 | ret["context"] = [self.context[qa["context"]]] 75 | # print(ret) 76 | return ret 77 | 78 | def __len__(self): 79 | # return 320 80 | if self.mode == "train": 81 | return int(len(self.qas) * 0.1 / self.ratio) 82 | else: 83 | return len(self.qas) 84 | 85 | -------------------------------------------------------------------------------- /dataset/KGDataset/KGDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from tqdm import tqdm 5 | 6 | class KGDataset(): 7 | def __init__(self): 8 | fin = open("/liuzyai04/thunlp/xcj/docaspara/data/wikidata5m/wikidata5m_transductive_train.txt", "r") 9 | self.triples = [] 10 | self.entity2ids = {} 11 | 12 | allents = set(json.load(open("/liuzyai04/thunlp/xcj/docaspara/data/knowledge-task/knowledge-task/Wiki80/allent.json", "r"))) 13 | self.rel2id = json.load(open("/liuzyai04/thunlp/xcj/docaspara/data/knowledge-task/knowledge-task/Wiki80/pid2name.json", "r")) 14 | word2qid = json.load(open("/liuzyai04/thunlp/xcj/docaspara/data/wikidata5m/word2qid.json", "r")) 15 | self.qid2word = {word2qid[name]: name.replace("_", " ") for name in word2qid} 16 | 17 | for line in tqdm(fin.readlines()): 18 | line = line.strip().split("\t") 19 | if line[1] not in self.rel2id: 20 | continue 21 | if line[0] not in self.qid2word or line[2] not in self.qid2word: 22 | continue 23 | self.triples.append(line) 24 | 25 | # if line[0] in allents: 26 | if line[0] not in self.entity2ids: 27 | self.entity2ids[line[0]] = [] 28 | self.entity2ids[line[0]].append(len(self.triples) - 1) 29 | # if line[2] in allents: 30 | if line[2] not in self.entity2ids: 31 | self.entity2ids[line[2]] = [] 32 | self.entity2ids[line[2]].append(len(self.triples) - 1) 33 | 34 | # def search_path_recursive(self, qid, tid, nowid=None, nowpath=[]): 35 | # if len(nowpath) == 3: 36 | # return [] 37 | # depth = len(nowpath) 38 | # if nowid is None and depth == 0: 39 | # nowid = qid 40 | # nextup = [self.triples[i] for i in self.entity2ids[nowid]] 41 | # nextent = set([tup[0] for tup in nextup] + [tup[2] for tup in nextup]) - set([nowid]) 42 | # if tid in nextent: 43 | 44 | 45 | def search_path(self, qid, tid): 46 | if qid not in self.entity2ids or tid not in self.entity2ids: 47 | return "" 48 | headtup = [self.triples[i] for i in self.entity2ids[qid]] 
49 | tailtup = [self.triples[i] for i in self.entity2ids[tid]] 50 | head_relate_ent = set([tup[0] for tup in headtup] + [tup[2] for tup in headtup]) - set([qid, tid]) 51 | tail_relate_ent = set([tup[0] for tup in tailtup] + [tup[2] for tup in tailtup]) - set([qid, tid]) 52 | hop2end = head_relate_ent & tail_relate_ent 53 | if len(hop2end) == 0: 54 | return "" 55 | relate_tup = [self.triples[i] for i in self.entity2ids[qid] if self.triples[i][0] in hop2end or self.triples[i][2] in hop2end] + \ 56 | [self.triples[i] for i in self.entity2ids[tid] if self.triples[i][0] in hop2end or self.triples[i][2] in hop2end] 57 | sents = [] 58 | for tup in relate_tup: 59 | if tup[0] != tid and tup[2] != tid: 60 | sents.append("%s is %s of %s." % (self.qid2word[tup[0]], self.rel2id[tup[1]][0], self.qid2word[tup[2]])) 61 | return " ".join(sents) 62 | 63 | def get_relations(self, qid, tid=None): 64 | if not qid in self.entity2ids: 65 | return "" 66 | rels = self.entity2ids[qid][:10] 67 | sents = [] 68 | for tupid in rels: 69 | tup = self.triples[tupid] 70 | if tup[0] != tid and tup[2] != tid: 71 | sents.append("%s is %s of %s." % (self.qid2word[tup[0]], self.rel2id[tup[1]][0], self.qid2word[tup[2]])) 72 | return " ".join(sents) 73 | 74 | -------------------------------------------------------------------------------- /tools/eval_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from threading import local 4 | from typing import List 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.optim import lr_scheduler 8 | from timeit import default_timer as timer 9 | from bmtrain import print_rank 10 | import bmtrain as bmt 11 | from tools import reduce 12 | from kara_storage.pytorch.base import KaraPytorchDatasetBase 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def gen_time_str(t): 17 | t = int(t) 18 | minute = t // 60 19 | second = t % 60 20 | return '%2d:%02d' % (minute, second) 21 | 22 | 23 | def output_value(epoch, mode, step, time, loss, info, end, config, lr="", otherinfo=""): 24 | try: 25 | delimiter = config.get("output", "delimiter") 26 | except Exception as e: 27 | delimiter = " " 28 | s = "" 29 | s = s + str(epoch) + " " 30 | while len(s) < 10: 31 | s += " " 32 | s = s + str(mode) + " " 33 | while len(s) < 18: 34 | s += " " 35 | s = s + str(step) + " " 36 | while len(s) < 30: 37 | s += " " 38 | s += str(time) 39 | while len(s) < 50: 40 | s += " " 41 | s += str(loss) 42 | while len(s) < 58: 43 | s += " " 44 | s += str(info) 45 | s = s.replace(" ", delimiter) 46 | s += "\t%s" % lr 47 | s += "\t%s" % otherinfo 48 | if not (end is None): 49 | print_rank(s, end=end) 50 | else: 51 | print_rank(s) 52 | 53 | 54 | def valid(model, dataset, epoch, config, gpu_list, output_function, mode="valid"): 55 | model.eval() 56 | local_rank = bmt.rank() #config.getint('distributed', 'local_rank') 57 | acc_result = None 58 | total_loss = 0 59 | cnt = 0 60 | total_len = len(dataset) 61 | start_time = timer() 62 | output_info = "" 63 | 64 | output_time = config.getint("output", "output_time") 65 | step = -1 66 | 67 | if hasattr(dataset, "dataset") and isinstance(dataset.dataset, KaraPytorchDatasetBase): 68 | dataset.dataset.set_epoch(0) 69 | 70 | for step, data in enumerate(dataset): 71 | for key in data.keys(): 72 | if isinstance(data[key], torch.Tensor): 73 | if len(gpu_list) > 0: 74 | data[key] = Variable(data[key].cuda()) 75 | else: 76 | data[key] = Variable(data[key]) 77 | results = model(data, config, gpu_list, 
acc_result, mode=mode) 78 | 79 | loss, acc_result = results["loss"], results["acc_result"] 80 | total_loss += bmt.sum_loss(loss).item() 81 | cnt += 1 82 | if step % output_time == 0: 83 | delta_t = timer() - start_time 84 | 85 | output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % ( 86 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 87 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 88 | if step == -1: 89 | logger.error("There is no data given to the model in this epoch, check your data.") 90 | raise NotImplementedError 91 | 92 | if acc_result is not None and config.getboolean("distributed", "use") and type(acc_result) != list: 93 | if "train" in acc_result: 94 | acc_result.pop("train") 95 | total_loss = bmt.sum_loss(torch.tensor(total_loss).cuda()).item() 96 | for key in acc_result: 97 | if type(acc_result[key]) == list: 98 | continue 99 | acc_result[key] = reduce(torch.tensor(acc_result[key]).cuda(), "sum").item() 100 | acc_result["train"] = False 101 | else: 102 | total_loss = bmt.sum_loss(torch.tensor(total_loss).cuda()).item() 103 | 104 | delta_t = timer() - start_time 105 | output_info = output_function(acc_result, config) 106 | output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % ( 107 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 108 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 109 | 110 | model.train() 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Plug-and-Play Document Modules for Pre-trained Models 2 | 3 | This repository contains the code and checkpoints of our ACL paper "**Plug-and-Play Document Modules for Pre-trained Models**". 4 | 5 | 6 | 7 | If you use the code, please cite the following paper: 8 | 9 | ``` 10 | @inproceedings{xiao2023plug, 11 | title={Plug-and-Play Document Modules for Pre-trained Models}, 12 | author={Xiao, Chaojun and Zhang, Zhengyan and Han, Xu and Chan, Chi-Min and Lin, Yankai and Liu, Zhiyuan and Li, Xiangyang and Li, Zhonghua and Cao, Zhao and Sun, Maosong}, 13 | booktitle={Proceedings of ACL}, 14 | year={2023} 15 | } 16 | ``` 17 | 18 | 19 | 20 | ## Quick Links 21 | 22 | * [Overview](#overview) 23 | * [Requirements](#requirements) 24 | * [Folder Structure](#folder-structure) 25 | * [Plugin Learning](#plugin-learning) 26 | * [Downstream Tuning](#downstream-tuning) 27 | 28 | 29 | 30 | ## Overview 31 | 32 | ![overview](figs/overview.png) 33 | 34 | We propose to represent documents as plug-and-play modules for pre-trained language models. In this way, we can decouple document encoding from concrete tasks and encode documents only once for multiple different tasks. 35 | 36 | 37 | 38 | ## Requirements 39 | 40 | ``` 41 | kara-storage==2.1.5 42 | transformers==4.26.0.dev0 43 | bmtrain==0.2.2 44 | torch==1.12.1 45 | rouge==1.0.1 46 | ``` 47 | 48 | 49 | 50 | ## Folder Structure 51 | 52 | * **`train.py`**: The entry point of all training and evaluation scripts. The arguments for `train.py` are as follows: 53 | * `--config/-c`: the configuration file path. Most parameters, including the data paths, model hyper-parameters, and so on, are set in the configuration files. 54 | * `--gpu/-g`: the GPU devices used to run the program. This argument is used to set the environment variable `CUDA_VISIBLE_DEVICES`. 
55 | * `--checkpoint`: the path of a specific checkpoint, which will be loaded for continual training. 56 | * **`dataset`**: code for reading data into memory 57 | * **`formatter`**: code for processing raw data into tensors, which will be fed into models 58 | * **`model`**: code for our models 59 | * **`config`**: configuration files for training and evaluation. 60 | * **`run_script`**: the training scripts. 61 | * **`utils`**: code for pre-processing data and downloading checkpoints. 62 | 63 | 64 | 65 | ## Plugin Learning 66 | 67 | In this section, we present how to conduct plugin learning with our code. 68 | 69 | **Data Preparation** 70 | 71 | First, download the C4 dataset and put it in `data/c4-json`. Note that C4 is a large-scale pre-training dataset, and in this paper we only need a small portion of it. 72 | 73 | Then, run the following script to store the large-scale dataset as a streaming dataset with the `kara-storage` package: 74 | 75 | ```bash 76 | python3 utils/parse_c4.py 77 | ``` 78 | 79 | **Model Initialization** 80 | 81 | First, download the T5-large checkpoint to initialize the model by running the following script: 82 | 83 | ```bash 84 | bash utils/download_t5.sh 85 | ``` 86 | 87 | **Training Scripts** 88 | 89 | ```bash 90 | bash run_script/PluginLearning/run_plugd_pl.sh 91 | ``` 92 | 93 | The trained checkpoint can be found in `checkpoint/PlugD-large`. 94 | 95 | We also provide the trained PlugD checkpoint on [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/7573895bf1ed4881a8d5/). 96 | 97 | 98 | 99 | ## Downstream Tuning 100 | 101 | **Data Preparation** 102 | 103 | Please refer to [KILT](https://github.com/facebookresearch/KILT) for the code to conduct retrieval. You can also download the data from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/4ddd19c7758b462f971b/). 104 | 105 | Put the data in `data/dpr-top5`. 106 | 107 | **Training Scripts** 108 | 109 | Then you can run downstream task tuning with the following script: 110 | 111 | ```bash 112 | bash run_script/Downstream/plugd.sh TASK PlugDPATH 113 | ``` 114 | 115 | Here, `TASK` refers to the task to run and must be one of `FEVER, NQ, TQA, HQA, ELI5, WoW, zsRE, TRex`; `PlugDPATH` refers to the checkpoint trained with plugin learning. Notably, PlugD decouples document encoding from concrete tasks, so inference time can be saved by pre-encoding the documents. Here, we do not perform document pre-encoding due to storage limitations. 
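As a concrete sketch, assuming the plugin-learning run left its checkpoint under `checkpoint/PlugD-large` (the exact file name depends on your run, so the path below is a placeholder), tuning on Natural Questions could be launched as follows:

```bash
# Illustrative invocation only: NQ is the task name, and the second argument
# should point to the checkpoint file produced by plugin learning (or the
# downloaded PlugD checkpoint); replace the placeholder with your actual path.
bash run_script/Downstream/plugd.sh NQ checkpoint/PlugD-large/<your-plugd-checkpoint>
```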
116 | 117 | -------------------------------------------------------------------------------- /tools/init_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from reader.reader import init_dataset, init_formatter, init_test_dataset 5 | from model import get_model 6 | from model.optimizer import init_optimizer 7 | from .output_init import init_output_function 8 | from torch import nn 9 | from tools import output_log 10 | import bmtrain as bmt 11 | from bmtrain import print_rank 12 | from transformers import get_linear_schedule_with_warmup 13 | from bmtrain.store import DistributedStateDictWrapper 14 | from model.scheduler import T5Scheduler 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def init_all(config, gpu_list, checkpoint, mode, *args, **params): 19 | result = {} 20 | 21 | output_log(logger, "Begin to initialize dataset and formatter...") 22 | if mode == "train": 23 | # init_formatter(config, ["train", "valid"], *args, **params) 24 | result["train_dataset"], result["valid_dataset"] = init_dataset(config, *args, **params) 25 | else: 26 | # init_formatter(config, ["test"], *args, **params) 27 | result["test_dataset"] = init_test_dataset(config, *args, **params) 28 | 29 | output_log(logger, "Begin to initialize models...") 30 | 31 | model = get_model(config.get("model", "model_name"))(config, gpu_list, *args, **params) 32 | optimizer = init_optimizer(model, config, *args, **params) 33 | 34 | lrsche_type = config.get("train", "scheduler") 35 | if lrsche_type == "linear": 36 | lr_scheduler = bmt.lr_scheduler.Linear(optimizer, start_lr=config.getfloat('train', 'learning_rate'), warmup_iter=config.getint('train', 'warmup_steps'), end_iter=config.getint('train', 'training_steps')) 37 | elif lrsche_type == "t5": 38 | lr_scheduler = T5Scheduler(optimizer, start_lr=config.getfloat('train', 'learning_rate'), warmup_iter=config.getint('train', 'warmup_steps'), end_iter=config.getint('train', 'training_steps'), num_iter=params["skip_step"]) 39 | # lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.getint('train', 'warmup_steps'), num_training_steps=config.getint('train', 'training_steps')) 40 | trained_epoch = 0 41 | global_step = 0 42 | 43 | 44 | # model = model.to(gpu_list[bmt.rank()]) 45 | try: 46 | # print("read checkpoint", bmt.rank()) 47 | # for i in range(bmt.world_size()): 48 | # if bmt.rank() == i: 49 | # parameters = torch.load(checkpoint, map_location=lambda storage, loc: storage) 50 | # bmt.synchronize() 51 | # print("load checkpoint", parameters.keys(), bmt.rank()) 52 | # if bmt.rank() == 0: 53 | # model.load_state_dict(DistributedStateDictWrapper(parameters["model"])) 54 | # else: 55 | # print_rank(bmt.rank(), "load_state_dict") 56 | # model.load_state_dict({}) 57 | output_log(logger, "try load checkpoint from %s" % checkpoint, logging.INFO) 58 | if checkpoint is not None and checkpoint != "None": 59 | parameters = {"trained_epoch": 0} 60 | print_rank(bmt.load(model, checkpoint, strict=False)) 61 | if mode == "train": 62 | trained_epoch = parameters["trained_epoch"] 63 | if "optimizer" in parameters and config.get("train", "optimizer") == parameters["optimizer_name"]: 64 | optimizer.load_state_dict(parameters["optimizer"]) 65 | else: 66 | output_log(logger, "Optimizer changed, do not load parameters of optimizer.", logging.WARNING) 67 | if "global_step" in parameters: 68 | global_step = parameters["global_step"] 69 | if "lr_scheduler" in parameters: 70 | 
lr_scheduler.load_state_dict(parameters["lr_scheduler"]) 71 | if config.get("model", "adapter_path") != "None": 72 | print_rank(bmt.load(model, config.get("model", "adapter_path"), strict=False)) 73 | except Exception as e: 74 | information = "Cannot load checkpoint file with error %s" % str(e) 75 | if mode == "test": 76 | output_log(logger, information, logging.ERROR) 77 | raise e 78 | else: 79 | output_log(logger, information, logging.WARNING) 80 | 81 | # model = bmt.BMTrainModelWrapper(model) 82 | result["model"] = model 83 | if mode == "train": 84 | result["optimizer"] = optimizer 85 | result["lr_scheduler"] = lr_scheduler 86 | result["trained_epoch"] = trained_epoch 87 | result["output_function"] = init_output_function(config) 88 | result["global_step"] = global_step 89 | 90 | output_log(logger, "Initialize done.", logging.INFO) 91 | 92 | return result 93 | -------------------------------------------------------------------------------- /reader/reader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import logging 3 | 4 | import formatter as form 5 | from dataset import dataset_list 6 | # from torch.utils.data.distributed import DistributedSampler 7 | from torch.utils.data import RandomSampler 8 | from kara_storage.pytorch.base import KaraPytorchDatasetBase 9 | from model_center.dataset import DistributedDataLoader 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | collate_fn = {} 14 | formatter = {} 15 | 16 | def init_formatter(config, task_list, *args, **params): 17 | for task in task_list: 18 | formatter[task] = form.init_formatter(config, task, *args, **params) 19 | 20 | def train_collate_fn(data): 21 | return formatter["train"].process(data, config, "train") 22 | 23 | def valid_collate_fn(data): 24 | return formatter["valid"].process(data, config, "valid") 25 | 26 | def test_collate_fn(data): 27 | return formatter["test"].process(data, config, "test") 28 | 29 | if task == "train": 30 | collate_fn[task] = formatter[task].process # train_collate_fn 31 | elif task == "valid": 32 | collate_fn[task] = formatter[task].process # valid_collate_fn 33 | else: 34 | collate_fn[task] = formatter[task].process # test_collate_fn 35 | 36 | 37 | def init_one_dataset(config, mode, *args, **params): 38 | temp_mode = mode 39 | if mode != "train": 40 | try: 41 | config.get("data", "%s_dataset_type" % temp_mode) 42 | except Exception as e: 43 | logger.warning( 44 | "[reader] %s_dataset_type has not been defined in config file, use [dataset] train_dataset_type instead." 
% temp_mode) 45 | temp_mode = "train" 46 | which = config.get("data", "%s_dataset_type" % temp_mode) 47 | 48 | if which in dataset_list: 49 | dataset = dataset_list[which](config, mode, *args, **params) 50 | batch_size = config.getint("train", "batch_size") 51 | shuffle = config.getboolean("train", "shuffle") 52 | reader_num = config.getint("train", "reader_num") 53 | drop_last = True 54 | if mode in ["valid", "test"]: 55 | if mode == "test": 56 | drop_last = False 57 | 58 | try: 59 | batch_size = config.getint("eval", "batch_size") 60 | except Exception as e: 61 | logger.warning("[eval] batch size has not been defined in config file, use [train] batch_size instead.") 62 | 63 | try: 64 | shuffle = config.getboolean("eval", "shuffle") 65 | except Exception as e: 66 | shuffle = False 67 | logger.warning("[eval] shuffle has not been defined in config file, use false as default.") 68 | try: 69 | reader_num = config.getint("eval", "reader_num") 70 | except Exception as e: 71 | logger.warning("[eval] reader num has not been defined in config file, use [train] reader num instead.") 72 | 73 | 74 | collate_fn[mode] = formatter[mode].process 75 | if isinstance(dataset, KaraPytorchDatasetBase): 76 | sampler = None 77 | dataloader = DataLoader(dataset=dataset, 78 | batch_size=batch_size, 79 | num_workers=reader_num, 80 | collate_fn=collate_fn[mode], 81 | sampler=sampler) 82 | else: 83 | dataloader = DistributedDataLoader(dataset=dataset, 84 | batch_size=batch_size, 85 | shuffle=shuffle, 86 | num_workers=reader_num, 87 | collate_fn=collate_fn[mode], 88 | # drop_last=True 89 | ) 90 | 91 | return dataloader 92 | else: 93 | logger.error("There is no dataset called %s, check your config." % which) 94 | raise NotImplementedError 95 | 96 | 97 | def init_test_dataset(config, *args, **params): 98 | init_formatter(config, ["test"], *args, **params) 99 | test_dataset = init_one_dataset(config, "test", *args, **params) 100 | 101 | return test_dataset 102 | 103 | 104 | def init_dataset(config, *args, **params): 105 | init_formatter(config, ["train", "valid"], *args, **params) 106 | train_dataset = init_one_dataset(config, "train", *args, **params) 107 | if config.getboolean("train", "no_valid"): 108 | valid_dataset = None 109 | else: 110 | valid_dataset = init_one_dataset(config, "valid", *args, **params) 111 | 112 | return train_dataset, valid_dataset 113 | 114 | 115 | if __name__ == "__main__": 116 | pass 117 | -------------------------------------------------------------------------------- /model/T5Adapter/T5Adapter.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer 2 | from ..Basic.DeltaT5 import DeltaT5 3 | import torch 4 | from torch import nn 5 | from model.metric import softmax_acc 6 | import bmtrain as bmt 7 | import os 8 | from model_center.model import T5Config 9 | from ..Basic.layers import Linear 10 | from tools import reduce 11 | import time 12 | 13 | class T5Adapter(nn.Module): 14 | def __init__(self, config): 15 | super(T5Adapter, self).__init__() 16 | self.plmpath = os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model")) 17 | print("load pre-trained model from", self.plmpath) 18 | self.t5config = T5Config.from_pretrained(self.plmpath) 19 | 20 | self.backbone = DeltaT5.from_pretrained(self.plmpath) 21 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(self.plmpath, "tokenizer")) 22 | 23 | self.finetune = config.getboolean("train", "finetune") 24 | # self.enc_adapter = 
Linear(1, 4 * self.t5config.num_encoder_layers * self.t5config.dim_model * 32, init_std=0.01) 25 | # self.dec_adapter = Linear(1, 4 * self.t5config.num_encoder_layers * self.t5config.dim_model * 32, init_std=0.01) 26 | if self.finetune: 27 | self.enc_adapter = None 28 | self.dec_adapter = None 29 | else: 30 | self.bottleneck_dim = config.getint("train", "bottleneck_dim") 31 | self.enc_adapter = Linear(1, 2 * self.t5config.num_encoder_layers * self.t5config.dim_model * self.bottleneck_dim, init_std=0.01) 32 | self.dec_adapter = Linear(1, 2 * self.t5config.num_encoder_layers * self.t5config.dim_model * self.bottleneck_dim, init_std=0.01) 33 | 34 | bmt.init_parameters(self.enc_adapter) 35 | bmt.init_parameters(self.dec_adapter) 36 | 37 | self.freeze_module(self.backbone) 38 | 39 | def freeze_module(self, module): 40 | for param in module.parameters(): 41 | param.requires_grad = False 42 | 43 | def get_adapter(self, device): 44 | if self.finetune: 45 | return None, None 46 | enc_adapters = self.enc_adapter(torch.ones(1, 1, dtype=torch.half, device=device)).view(self.t5config.num_encoder_layers, 2, 1, self.t5config.dim_model, self.bottleneck_dim) 47 | dec_adapters = self.dec_adapter(torch.ones(1, 1, dtype=torch.half, device=device)).view(self.t5config.num_decoder_layers, 2, 1, self.t5config.dim_model, self.bottleneck_dim) 48 | return enc_adapters, dec_adapters 49 | 50 | def forward(self, data): 51 | enc_adapters, dec_adapters = self.get_adapter(data["input_ids"].device) 52 | 53 | logits = self.backbone( 54 | input_ids = data["input_ids"], 55 | attention_mask = data["attention_mask"], 56 | decoder_input_ids = data["decoder_input_ids"], 57 | decoder_length = data["decoder_length"] if "decoder_length" in data else None, 58 | decoder_attention_mask=data["decoder_attention_mask"] if "decoder_attention_mask" in data else None, 59 | return_logits = True, 60 | enc_adapters = enc_adapters, 61 | dec_adapters = dec_adapters, 62 | ) 63 | return logits 64 | 65 | def generate_greedy(self, data, gen_length=20): 66 | enc_adapters, dec_adapters = self.get_adapter(data["input_ids"].device) 67 | 68 | batch, device = data["input_ids"].size(0), data["input_ids"].device 69 | 70 | dec_input_ids = torch.zeros(batch, gen_length + 4, dtype=torch.long).to(device) 71 | dec_input_ids[:,1] = 32099 72 | length = torch.LongTensor([gen_length + 4] * batch).to(device) 73 | position = 1 # batch 74 | 75 | # print(self.tokenizer.decode(data["input_ids"][0])) 76 | # print("==" * 15) 77 | predict, logits = self.backbone( 78 | input_ids = data["input_ids"], 79 | attention_mask = data["attention_mask"], 80 | decoder_input_ids = dec_input_ids, 81 | decoder_length=length, 82 | enc_adapters = enc_adapters, 83 | dec_adapters = dec_adapters, 84 | ) 85 | 86 | encoder_outputs = predict.encoder_last_hidden_state 87 | answer = [torch.argmax(logits[:, position], dim=-1)] 88 | end = (answer[-1] == 1) 89 | for i in range(gen_length): 90 | if not end.all(): 91 | position += 1 92 | dec_input_ids[:, position] = answer[-1] 93 | predict, logits = self.backbone( 94 | encoder_outputs=encoder_outputs, 95 | attention_mask = data["attention_mask"], 96 | decoder_input_ids=dec_input_ids, 97 | decoder_length=length, 98 | enc_adapters = enc_adapters, 99 | dec_adapters = dec_adapters, 100 | ) 101 | if not end.all(): 102 | answer.append(torch.argmax(logits[:, position], dim=-1)) 103 | answer[-1][end] = 1 104 | end = end | (answer[-1] == 1) 105 | all_end = reduce(end.all().int(), "sum") 106 | if all_end == bmt.world_size(): 107 | break 108 | answer = 
torch.cat([a.unsqueeze(1) for a in answer], dim=1).contiguous() 109 | return answer 110 | 111 | -------------------------------------------------------------------------------- /formatter/OpenQA/OpenQAFormatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | import random 7 | from transformers import T5Tokenizer, T5Config 8 | from tools import shift_tokens_right 9 | 10 | class OpenQAFormatter: 11 | def __init__(self, config, mode, *args, **params): 12 | self.max_len = config.getint("train", "max_len") 13 | self.ans_max_len = config.getint("train", "ans_max_len") 14 | self.model_type = config.get("model", "model_type") 15 | self.mode = mode 16 | if self.model_type == "PostT5" and self.mode != "test": 17 | self.max_len = 256 18 | self.plmpath = os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model")) 19 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(self.plmpath, "tokenizer")) 20 | 21 | def generate_input(self, question, context): 22 | if self.model_type == "PostT5" and self.mode != "test": 23 | return " ".join(["question:", question.lstrip()]) 24 | else: 25 | return " ".join(["question:", question.lstrip(), "context:", "\n".join([c["text"].lstrip() for c in context])]) 26 | 27 | def preprocess_squad_batch(self, examples): 28 | inputs = [self.generate_input(qa["question"] + "" if "" not in qa["question"] else qa["question"], qa["context"]) for qa in examples] 29 | targets = [] 30 | for qa in examples: 31 | targets.append("" + random.choice(qa["answers"])) 32 | 33 | return inputs, targets 34 | 35 | def process(self, data): 36 | inputs, targets = self.preprocess_squad_batch(data) 37 | model_inputs = self.tokenizer(inputs, max_length=self.max_len, padding="max_length", truncation=True) 38 | 39 | labels = self.tokenizer(text_target=targets, max_length=self.ans_max_len, padding="max_length", truncation=True) 40 | 41 | if self.mode == "train": 42 | labels["input_ids"] = [ 43 | [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 44 | ] 45 | 46 | model_inputs["decoder_input_ids"] = shift_tokens_right(torch.LongTensor(labels["input_ids"]), 0, 0) 47 | model_inputs["labels"] = labels["input_ids"] 48 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 49 | 50 | for key in model_inputs: 51 | model_inputs[key] = torch.LongTensor(model_inputs[key]) 52 | # print(self.tokenizer.decode(model_inputs["input_ids"][0])) 53 | # print(self.tokenizer.decode(model_inputs["decoder_input_ids"][0])) 54 | # print("===" * 10) 55 | if "labels" in model_inputs: 56 | model_inputs["labels"][:,0] = -100 57 | 58 | model_inputs["answers"] = [{" ".join(ans.split()[:512]) for ans in doc["answers"]} for doc in data] 59 | 60 | return model_inputs 61 | 62 | 63 | class OpenQAPlugDFormatter: 64 | def __init__(self, config, mode, *args, **params): 65 | self.max_len = config.getint("train", "max_len") 66 | self.ctx_len = config.getint("train", "ctx_len") 67 | self.ans_max_len = config.getint("train", "ans_max_len") 68 | self.mode = mode 69 | self.plmpath = os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model")) 70 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(self.plmpath, "tokenizer")) 71 | 72 | def process(self, data): 73 | query = [d["question"] + "" for d in data] 74 | # ctxs = ["\n".join(["%s\t%s" % (c["wikipedia_title"], c["text"].lstrip()) for 
c in d["context"]]) for d in data] 75 | ctxs = ["\n".join([c["text"].lstrip() for c in d["context"]]) for d in data] 76 | targets = ["" + random.choice(d["answers"]) for d in data] 77 | 78 | query_info = self.tokenizer(query, max_length=self.max_len, padding="max_length", truncation=True) 79 | ctx_info = self.tokenizer(ctxs, max_length=self.ctx_len, padding="max_length", truncation=True) 80 | 81 | labels = self.tokenizer(text_target=targets, max_length=self.ans_max_len, padding="max_length", truncation=True) 82 | 83 | model_inputs = { 84 | "que_input_ids": query_info["input_ids"], 85 | "que_attention_mask": query_info["attention_mask"], 86 | "ctx_input_ids": ctx_info["input_ids"], 87 | "ctx_attention_mask": ctx_info["attention_mask"] 88 | } 89 | if self.mode == "train": 90 | labels["input_ids"] = [ 91 | [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 92 | ] 93 | 94 | model_inputs["decoder_input_ids"] = shift_tokens_right(torch.LongTensor(labels["input_ids"]), 0, 0) 95 | model_inputs["labels"] = labels["input_ids"] 96 | model_inputs["decoder_attention_mask"] = labels["attention_mask"] 97 | 98 | for key in model_inputs: 99 | model_inputs[key] = torch.LongTensor(model_inputs[key]) 100 | if "labels" in model_inputs: 101 | model_inputs["labels"][:,0] = -100 102 | 103 | model_inputs["answers"] = [{ans for ans in doc["answers"]} for doc in data] 104 | 105 | return model_inputs 106 | 107 | -------------------------------------------------------------------------------- /formatter/FEVER/FEVERCtxFormatter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | import random 7 | from transformers import T5Tokenizer,T5Config 8 | 9 | class FEVERCtxFormatter: 10 | def __init__(self, config, mode, *args, **params): 11 | self.max_len = config.getint("train", "max_len") 12 | self.mode = mode 13 | self.config = config 14 | self.model_type = config.get("model", "model_type") 15 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 16 | self.max_len = config.getint("train", "max_len") 17 | 18 | if self.model_type == "PostT5" and self.mode != "test": 19 | self.max_len = 128 20 | self.label2id = { 21 | "SUPPORTS": 0, 22 | "REFUTES": 1, 23 | # "NOT ENOUGH INFO": 2 24 | } 25 | self.top_ctx = 3 26 | 27 | def process(self, data): 28 | # claims = [d["claim"] for d in data] 29 | claims = [d["input"] for d in data] 30 | ctxs = ["\n".join([text["text"] for text in d["output"][0]["provenance"][:self.top_ctx]]) for d in data] 31 | 32 | if self.model_type == "PostT5" and self.mode != "test": 33 | text = ["claim: %s" % c for c in claims] 34 | else: 35 | text = ["claim: %s \n Context: %s" % (c, ctx) for c, ctx in zip(claims, ctxs)] 36 | 37 | ret = self.tokenizer(text, max_length=self.max_len, padding="max_length", truncation=True) 38 | 39 | labels = [self.label2id[d["output"][1]["answer"]] for d in data] 40 | ret["labels"] = labels 41 | for key in ret: 42 | ret[key] = torch.LongTensor(ret[key]) 43 | return ret 44 | 45 | 46 | 47 | class FEVERCtxPlugDFormatter: 48 | def __init__(self, config, mode, *args, **params): 49 | self.max_len = config.getint("train", "max_len") 50 | self.ctx_len = config.getint("train", "ctx_len") 51 | self.mode = mode 52 | self.config = config 53 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", 
"pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 54 | 55 | self.label2id = { 56 | "SUPPORTS": 0, 57 | "REFUTES": 1, 58 | # "NOT ENOUGH INFO": 2 59 | } 60 | self.top_ctx = 3 61 | 62 | def process(self, data): 63 | # claims = [d["claim"] for d in data] 64 | claims = [d["input"] + "yes or no? " for d in data] 65 | ctxs = ["\n".join([text["text"] for text in d["output"][0]["provenance"][:self.top_ctx]]) for d in data] 66 | 67 | ctx_info = self.tokenizer(ctxs, max_length=self.ctx_len, padding="max_length", truncation=True) 68 | query_info = self.tokenizer(claims, max_length=self.max_len, padding="max_length", truncation=True) 69 | 70 | ret = { 71 | "que_input_ids": query_info["input_ids"], 72 | "que_attention_mask": query_info["attention_mask"], 73 | "ctx_input_ids": ctx_info["input_ids"], 74 | "ctx_attention_mask": ctx_info["attention_mask"], 75 | "labels": [self.label2id[d["output"][1]["answer"]] for d in data], 76 | "decoder_input_ids": [[0]] * len(data), 77 | "decoder_length": [1] * len(data) 78 | } 79 | 80 | for key in ret: 81 | ret[key] = torch.LongTensor(ret[key]) 82 | return ret 83 | 84 | 85 | 86 | class FEVERCtxED2LMFormatter: 87 | def __init__(self, config, mode, *args, **params): 88 | self.max_len = config.getint("train", "max_len") 89 | self.ctx_len = config.getint("train", "ctx_len") 90 | self.mode = mode 91 | self.config = config 92 | self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.get("model", "pretrained_model_path"), config.get("model", "pretrained_model"), "tokenizer")) 93 | 94 | self.label2id = { 95 | "SUPPORTS": 0, 96 | "REFUTES": 1, 97 | # "NOT ENOUGH INFO": 2 98 | } 99 | self.top_ctx = 3 100 | 101 | def process(self, data): 102 | # claims = [d["claim"] for d in data] 103 | claims = [d["input"] for d in data] 104 | ctxs = ["\n".join([text["text"] for text in d["output"][0]["provenance"][:self.top_ctx]]) for d in data] 105 | 106 | ctx_info = self.tokenizer(ctxs, max_length=self.ctx_len, padding="max_length", truncation=True) 107 | decoder_inp, position = [], [] 108 | for query in claims: 109 | qtoken = self.tokenizer.encode(query + "Answer:", add_special_tokens=False) 110 | p = len(qtoken) 111 | qtoken = qtoken + [0] * (self.max_len - len(qtoken)) 112 | decoder_inp.append(qtoken[:self.max_len]) 113 | position.append(min(p, self.max_len - 1)) 114 | 115 | ret = { 116 | "decoder_input_ids": decoder_inp, 117 | "decoder_attention_mask": [[1] * self.max_len] * len(data), 118 | "ctx_input_ids": ctx_info["input_ids"], 119 | "ctx_attention_mask": ctx_info["attention_mask"], 120 | "labels": [self.label2id[d["output"][1]["answer"]] for d in data], 121 | "position": position, 122 | } 123 | 124 | for key in ret: 125 | ret[key] = torch.LongTensor(ret[key]) 126 | return ret 127 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import json 5 | import re 6 | import string 7 | from collections import Counter 8 | import bmtrain as bmt 9 | from rouge import Rouge 10 | 11 | def softmax_acc(score, label, acc_result): 12 | if acc_result is None or acc_result["total"] > 25600: 13 | acc_result = {'total': 0, 'right': 0} 14 | predict = torch.max(score, dim = 1)[1] 15 | acc_result['total'] += int(label.shape[0]) 16 | acc_result['right'] += int((predict == label).int().sum()) 17 | return acc_result 18 | 19 | def mlm_acc_loss(predict, labels, 
acc_result, loss): 20 | if acc_result is None: 21 | acc_result = {'total': 0, 'right': 0, "loss": []} 22 | acc_result["right"] += int((predict[labels > 0] == labels[labels > 0]).sum()) 23 | acc_result["total"] += int((labels > 0).sum()) 24 | # if loss == 0 and len(acc_result["loss"]) > 0: 25 | # loss = torch.tensor(sum(acc_result["loss"]) / len(acc_result["loss"]), device=predict.device) 26 | if loss != 0: 27 | acc_result["loss"].append(bmt.sum_loss(loss).item()) 28 | acc_result["loss"] = acc_result["loss"][-100:] 29 | return acc_result 30 | 31 | def microf1(scores, labels, acc_result): 32 | if acc_result is None: 33 | acc_result = {"TP": 0, "FP": 0, "FN": 0} 34 | # scores: batch, label_num 35 | # labels: batch, label_num 36 | predict = scores > 0.5 37 | acc_result["TP"] += int(labels[predict].sum()) 38 | acc_result["FP"] += int((labels[predict] == 0).sum()) 39 | acc_result["FN"] += int(labels[scores <= 0.5].sum()) 40 | return acc_result 41 | 42 | 43 | 44 | def normalize_answer(s): 45 | # return s 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | 48 | def remove_articles(text): 49 | return re.sub(r"\b(a|an|the)\b", " ", text) 50 | 51 | def white_space_fix(text): 52 | return " ".join(text.split()) 53 | 54 | def remove_punc(text): 55 | exclude = set(string.punctuation) 56 | return "".join(ch for ch in text if ch not in exclude) 57 | 58 | def lower(text): 59 | return text.lower() 60 | 61 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 62 | 63 | def ROUGE_normalize_answer(s): 64 | # return s 65 | """Lower text and remove punctuation, articles and extra whitespace.""" 66 | 67 | def white_space_fix(text): 68 | return " ".join(text.split()) 69 | 70 | def remove_punc(text): 71 | exclude = set(string.punctuation) 72 | return "".join(ch for ch in text if ch not in exclude) 73 | 74 | def lower(text): 75 | return text.lower() 76 | 77 | return white_space_fix(remove_punc(lower(s))) 78 | 79 | def squad_em(predict, answers): 80 | em = 0 81 | for pre, ans in zip(predict, answers): 82 | if pre in ans: 83 | em += 1 84 | # else: 85 | # print("predict: %s\t answer: %s" % (pre, ans)) 86 | return em 87 | 88 | def squad_f1(predict, answers): 89 | ret = 0 90 | for pred, ans in zip(predict, answers): 91 | # if pred == "no answer": 92 | # continue 93 | prediction_tokens = pred.split() 94 | cpred_token = Counter(prediction_tokens) 95 | curf1 = [] 96 | for a in ans: 97 | ground_truth_tokens = a.split() 98 | common = cpred_token & Counter(ground_truth_tokens) 99 | num_same = sum(common.values()) 100 | if num_same == 0: 101 | curf1.append(0) 102 | else: 103 | precision = 1.0 * num_same / len(prediction_tokens) 104 | recall = 1.0 * num_same / len(ground_truth_tokens) 105 | f1 = (2 * precision * recall) / (precision + recall) 106 | curf1.append(f1) 107 | ret += max(curf1) 108 | return ret 109 | 110 | def squad_NAF1(predict, answers, acc_result): 111 | for p, ans in zip(predict, answers): 112 | if p == "no answer": 113 | if "no answer" in ans: 114 | acc_result["NA_tp"] += 1 115 | else: 116 | acc_result["NA_fp"] += 1 117 | else: 118 | if "no answer" in ans: 119 | acc_result["NA_tn"] += 1 120 | else: 121 | acc_result["NA_fn"] += 1 122 | return acc_result 123 | 124 | def squad_metric(predict, answers, acc_result, tokenizer, RL=False): 125 | if acc_result is None: 126 | acc_result = {"train": False, "total": 0, "em_sum": 0, "f1_sum": 0., "NA_tp": 0, "NA_fp": 0, "NA_tn": 0, "NA_fn": 0, "ROUGE-L-R": 0, "ROUGE-L-P": 0, "ROUGE-L-F": 0} 127 | pred = [] 128 | rouge_pred = [] 129 | 
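# Decode each predicted id sequence up to the first EOS token (id 1); keep a ROUGE-normalized copy for ROUGE-L and an answer-normalized copy for EM/F1.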
for p in predict: 130 | tmp = [] 131 | # print("token ID: %s" % p) 132 | for n in p: 133 | if n == 1: 134 | break 135 | tmp.append(int(n)) 136 | s = tokenizer.decode(tmp, skip_special_tokens=True) 137 | rouge_pred.append(ROUGE_normalize_answer(s)) 138 | pred.append(normalize_answer(s)) 139 | # pred = [normalize_answer([int(n) for n in p if n == 1 break], skip_special_tokens=True)) for p in predict] 140 | ground = [{normalize_answer(a) for a in ans} for ans in answers] 141 | 142 | if RL: 143 | ROUGE_ground = [ROUGE_normalize_answer(list(ans)[0]) for ans in answers] 144 | scorer = Rouge() 145 | score = scorer.get_scores([p if p != "" else " " for p in rouge_pred], ROUGE_ground, avg=True) 146 | acc_result["ROUGE-L-P"] += score["rouge-l"]["p"] * len(rouge_pred) 147 | acc_result["ROUGE-L-R"] += score["rouge-l"]["r"] * len(rouge_pred) 148 | acc_result["ROUGE-L-F"] += score["rouge-l"]["f"] * len(rouge_pred) 149 | 150 | acc_result["em_sum"] += squad_em(pred, ground) 151 | acc_result["f1_sum"] += squad_f1(pred, ground) 152 | acc_result["total"] += len(pred) 153 | acc_result = squad_NAF1(pred, ground, acc_result) 154 | # print(acc_result) 155 | return acc_result 156 | 157 | def squad_train_metric(predict, labels, acc_result): 158 | # predict: batch, len 159 | # labels: batch, len 160 | if acc_result is None: 161 | acc_result = {"train": True, "total": 0, "right": 0} 162 | acc_result["right"] += int((predict[labels > 0] == labels[labels > 0]).sum()) 163 | acc_result["total"] += int((labels > 0).sum()) 164 | return acc_result 165 | 166 | def sum_normalize_answer(s): 167 | """Lower text and remove punctuation, articles and extra whitespace.""" 168 | def white_space_fix(text): 169 | return " ".join(text.split()) 170 | 171 | def lower(text): 172 | return text.lower() 173 | ret = white_space_fix(lower(s)) 174 | if ret == "": 175 | ret = " " 176 | return ret 177 | 178 | def summarization_metric(predict, answers, acc_result, tokenizer): 179 | if acc_result is None: 180 | acc_result = {"train": False, "total": 0, "rouge-1": 0.0, "rouge-2": 0.0, "rouge-3": 0.0} 181 | pred = [] 182 | for p in predict: 183 | tmp = [] 184 | for n in p: 185 | if n == 1: 186 | break 187 | tmp.append(int(n)) 188 | pred.append(sum_normalize_answer(tokenizer.decode(tmp, skip_special_tokens=True))) 189 | ground = [sum_normalize_answer(ans) for ans in answers] 190 | # print(pred) 191 | scorer = Rouge() 192 | score = scorer.get_scores(pred, ground) 193 | acc_result["rouge-1"] = score[0]["rouge-1"]["r"] * len(pred) 194 | acc_result["rouge-2"] = score[0]["rouge-2"]["r"] * len(pred) 195 | acc_result["rouge-l"] = score[0]["rouge-l"]["r"] * len(pred) 196 | acc_result["total"] += len(pred) 197 | # print(score) 198 | # print(acc_result) 199 | return acc_result 200 | 201 | -------------------------------------------------------------------------------- /model/Basic/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import bmtrain as bmt 5 | import math 6 | import torch.nn.functional as F 7 | 8 | class MLP(torch.nn.Module): 9 | def __init__(self, dim_in, dim_out, length_scale=True, layer_norm=False, init_std=0.02, dim_mid=None, bias=False): 10 | super().__init__() 11 | if layer_norm: 12 | self.layer_norm = LayerNorm(dim_in, bias=False, eps=1e-6, group="mapper") 13 | else: 14 | self.layer_norm = None 15 | self.dim_mid = dim_in if dim_mid is None else dim_mid 16 | self.layer1 = Linear(dim_in=dim_in, dim_out=self.dim_mid, init_std=init_std, 
length_scale=length_scale, bias=bias, group="mapper") 17 | self.layer2 = Linear(dim_in=self.dim_mid, dim_out=dim_out, init_std=init_std, length_scale=length_scale, bias=bias, group="mapper") 18 | self.act = torch.nn.ReLU() 19 | 20 | def forward(self, rep: torch.Tensor): 21 | if self.layer_norm is not None: 22 | return self.layer2(self.act(self.layer1(self.layer_norm(rep)))) 23 | else: 24 | return self.layer2(self.act(self.layer1(rep))) 25 | 26 | 27 | class Linear(bmt.DistributedModule): 28 | r"""A fully connected layer, which performs :math:`\pmb{y} = \mathbf{W} \pmb{x} + \pmb{b}` 29 | 30 | Args: 31 | dim_in (int): input dimension of :math:`\pmb{x}` 32 | dim_out (int): output dimension of :math:`\pmb{y}` 33 | dtype (optional): Defaults to torch.half. 34 | init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)`. Defaults to 0. 35 | init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)`. Defaults to 1. 36 | bias (bool, optional): whether to add bias term :math:`\pmb{b}`. Defaults to False. 37 | """ 38 | def __init__(self, 39 | dim_in : int, 40 | dim_out : int, 41 | length_scale : bool = False, 42 | length_scale_before : bool = False, 43 | dtype = torch.half, 44 | # dtype = torch.float32, 45 | int8 : bool = False, 46 | init_mean : float = 0.0, 47 | init_std : float = 1, 48 | bias : bool = False, 49 | group : Optional[str] = None, 50 | ): 51 | super().__init__() 52 | self.dim_in = dim_in 53 | self.weight = bmt.DistributedParameter( 54 | torch.empty((dim_out, dim_in), dtype=dtype), 55 | init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), 56 | # init_method=bmt.ParameterInitializer(torch.nn.init.kaiming_uniform_, a=init_std), 57 | group = group, 58 | ) 59 | self.bias = bmt.DistributedParameter( 60 | torch.empty((dim_out,), dtype=dtype), 61 | init_method=bmt.ParameterInitializer(torch.nn.init.zeros_), 62 | group = group, 63 | ) if bias else None 64 | self.length_scale = length_scale 65 | self.length_scale_before = length_scale_before 66 | self.int8 = int8 67 | 68 | def forward(self, x : torch.Tensor): 69 | """ 70 | Args: 71 | x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer 72 | 73 | Returns: 74 | :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y. 75 | 76 | """ 77 | if self.length_scale and self.length_scale_before: 78 | x = x / math.sqrt(self.dim_in) 79 | x = F.linear(x, self.weight) 80 | if self.length_scale and not self.length_scale_before: 81 | x = x / math.sqrt(self.dim_in) 82 | if self.bias is not None: 83 | x = x + self.bias 84 | return x 85 | 86 | 87 | 88 | class KaimingLinear(bmt.DistributedModule): 89 | r"""A fully connected layer, which performs :math:`\pmb{y} = \mathbf{W} \pmb{x} + \pmb{b}` 90 | 91 | Args: 92 | dim_in (int): input dimension of :math:`\pmb{x}` 93 | dim_out (int): output dimension of :math:`\pmb{y}` 94 | dtype (optional): Defaults to torch.half. 95 | init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)`. Defaults to 0. 96 | init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)`. Defaults to 1. 97 | bias (bool, optional): whether to add bias term :math:`\pmb{b}`. Defaults to False. 
98 | """ 99 | def __init__(self, 100 | dim_in : int, 101 | dim_out : int, 102 | length_scale : bool = False, 103 | length_scale_before : bool = False, 104 | dtype = torch.half, 105 | int8 : bool = False, 106 | init_a : float = 1, 107 | bias : bool = False, 108 | group : Optional[str] = None, 109 | zeros : bool = False 110 | ): 111 | super().__init__() 112 | self.dim_in = dim_in 113 | self.weight = bmt.DistributedParameter( 114 | torch.empty((dim_out, dim_in), dtype=dtype), 115 | # init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), 116 | init_method=bmt.ParameterInitializer(torch.nn.init.kaiming_uniform_, a=init_a) if not zeros else bmt.ParameterInitializer(torch.nn.init.zeros_), 117 | group = group, 118 | ) 119 | self.bias = bmt.DistributedParameter( 120 | torch.empty((dim_out,), dtype=dtype), 121 | init_method=bmt.ParameterInitializer(torch.nn.init.zeros_), 122 | group = group, 123 | ) if bias else None 124 | self.length_scale = length_scale 125 | self.length_scale_before = length_scale_before 126 | self.int8 = int8 127 | 128 | def forward(self, x : torch.Tensor): 129 | """ 130 | Args: 131 | x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer 132 | 133 | Returns: 134 | :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y. 135 | 136 | """ 137 | if self.length_scale and self.length_scale_before: 138 | x = x / math.sqrt(self.dim_in) 139 | x = F.linear(x, self.weight) 140 | if self.length_scale and not self.length_scale_before: 141 | x = x / math.sqrt(self.dim_in) 142 | if self.bias is not None: 143 | x = x + self.bias 144 | return x 145 | 146 | 147 | 148 | class LayerNorm(bmt.DistributedModule): 149 | r""" 150 | `LayerNorm `_ if bias = True: :math:`y = {x-\text{E}[x]\over \text{Var}[x]+\text{eps}} * w + \text{bias}` 151 | 152 | `RMS LayerNorm `_ if bias = False: :math:`y = {x\over \text{Var}[x]+\text{eps}} * w` 153 | 154 | Args: 155 | dim_norm (int): norm dimesion 156 | dtype (optional): Defaults to torch.half. 157 | bias (bool, optional): whether to add the :math:`\text{bias}` term. Defaults to True. 158 | eps (float, optional): :math:`\text{eps}` term. Defaults to 1e-5. 159 | init_var (float, optional): weight will be all initialized to init_var. Defaults to 1.0. 160 | """ 161 | def __init__(self, dim_norm : int, 162 | dtype=torch.half, 163 | bias=True, 164 | eps : float = 1e-5, 165 | init_var = 1.0, 166 | group : Optional[str] = None, 167 | ): 168 | 169 | super().__init__() 170 | 171 | self.eps = eps 172 | self.dim_norm = dim_norm 173 | self.weight = bmt.DistributedParameter( 174 | torch.ones(dim_norm, dtype=dtype) * init_var, group=group) 175 | self.bias = bmt.DistributedParameter( 176 | torch.zeros(dim_norm, dtype=dtype), group=group) if bias else None 177 | 178 | def forward(self, x : torch.Tensor): 179 | """ 180 | Args: 181 | x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized. 182 | 183 | Return: 184 | :obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output. 
185 | 186 | """ 187 | assert x.size(-1) == self.dim_norm 188 | 189 | if self.bias is not None: 190 | return F.layer_norm(x, (self.dim_norm,), self.weight, self.bias, self.eps) 191 | else: 192 | return rms_layernorm(x, self.weight, self.eps) 193 | 194 | @torch.jit.script 195 | def rms_layernorm(hidden : torch.Tensor, weight : torch.Tensor, eps :float): 196 | old_dtype = hidden.dtype 197 | variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) 198 | hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype) 199 | return hidden * weight --------------------------------------------------------------------------------