├── .gitignore ├── LICENSE ├── README.md ├── examples ├── data4generation.jsonl ├── data4parsing.jsonl ├── test.jsonl ├── train.jsonl └── val.jsonl ├── fine-tune ├── Eval-AMRBART-large-AMR2Text.sh ├── Eval-AMRBART-large-AMRParsing.sh ├── base_trainer.py ├── common │ ├── additional-tokens.json │ ├── callbacks.py │ ├── constant.py │ ├── options.py │ ├── penman_interface.py │ ├── postprocessing.py │ ├── training_args.py │ └── utils.py ├── data_interface │ ├── data.py │ └── dataset.py ├── evaluation │ ├── cdec-corpus │ │ ├── README.md │ │ ├── add-self-translations.pl │ │ ├── add-sos-eos.pl │ │ ├── conll2cdec.pl │ │ ├── corpus-stats.pl │ │ ├── cut-corpus.pl │ │ ├── filter-length.pl │ │ ├── lowercase.pl │ │ ├── moses-scfg-to-cdec.pl │ │ ├── moses-xml.pl │ │ ├── paste-files.pl │ │ ├── sample-dev-sets.py │ │ ├── support │ │ │ ├── README │ │ │ ├── fix-contract.pl │ │ │ ├── fix-eos.pl │ │ │ ├── quote-norm.pl │ │ │ ├── token_list │ │ │ ├── token_patterns │ │ │ ├── tokenizer.pl │ │ │ ├── utf8-normalize-batch.pl │ │ │ └── utf8-normalize.sh │ │ ├── tokenize-anything.sh │ │ ├── tokenize-parallel.py │ │ ├── untok.pl │ │ ├── utf8-normalize.sh │ │ └── xml-tok.py │ ├── eval_gen.py │ ├── eval_gen.sh │ ├── eval_smatch.py │ └── eval_smatch.sh ├── inference-amr.sh ├── inference-text.sh ├── main.py ├── metric │ └── sacrebleu.py ├── model_interface │ ├── modeling_bart.py │ ├── modeling_outputs.py │ └── tokenization_bart.py ├── seq2seq_trainer.py ├── train-AMRBART-large-AMR2Text.sh └── train-AMRBART-large-AMRParsing.sh ├── pre-train ├── common │ ├── additional-tokens.json │ ├── constant.py │ ├── penman_interface.py │ ├── postprocessing.py │ └── utils.py ├── data_interface │ ├── amrdata.py │ └── dataset.py ├── model_interface │ ├── modeling_bart.py │ ├── modeling_outputs.py │ └── tokenization_bart.py ├── run-posttrain-bart-textinf-joint-denoising-6task-large-unified-A100.sh ├── run-posttrain-bart-textinf-joint-denoising-6task-large-unified-V100.sh └── run_multitask_unified_pretraining.py └── requirements.yml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 xfbai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMRBART 2 | The refactored implementation for ACL2022 paper "Graph Pre-training for AMR Parsing and Generation". You may find our paper [here](https://arxiv.org/pdf/2203.07836.pdf) (Arxiv). 
The original implementation is available [here](https://github.com/goodbai-nlp/AMRBART/tree/acl2022) 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-pre-training-for-amr-parsing-and-1/amr-to-text-generation-on-ldc2017t10)](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2017t10?p=graph-pre-training-for-amr-parsing-and-1) 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-pre-training-for-amr-parsing-and-1/amr-to-text-generation-on-ldc2020t02)](https://paperswithcode.com/sota/amr-to-text-generation-on-ldc2020t02?p=graph-pre-training-for-amr-parsing-and-1) 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-pre-training-for-amr-parsing-and-1/amr-parsing-on-ldc2017t10)](https://paperswithcode.com/sota/amr-parsing-on-ldc2017t10?p=graph-pre-training-for-amr-parsing-and-1) 9 | 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-pre-training-for-amr-parsing-and-1/amr-parsing-on-ldc2020t02)](https://paperswithcode.com/sota/amr-parsing-on-ldc2020t02?p=graph-pre-training-for-amr-parsing-and-1) 11 | 12 | **News**🎈 13 | 14 | - (2022/12/10) Fixed max_length bugs in AMR parsing and updated results. 15 | - (2022/10/16) Released the AMRBART-v2 model, which is simpler, faster, and stronger. 16 | 17 | # Requirements 18 | + python 3.8 19 | + pytorch 1.8 20 | + transformers 4.21.3 21 | + datasets 2.4.0 22 | + Tesla V100 or A100 23 | 24 | We recommend using conda to manage virtual environments: 25 | ``` 26 | conda env update --name <env> --file requirements.yml 27 | ``` 28 | 29 | # Data Processing 30 | 31 | 32 | You may download the AMR corpora from [LDC](https://www.ldc.upenn.edu). 33 | 34 | Please follow [this repository](https://github.com/goodbai-nlp/AMR-Process) to preprocess AMR graphs: 35 | ``` 36 | bash run-process-acl2022.sh 37 | ``` 38 | 39 | # Usage 40 | 41 | Our model is available at [huggingface](https://huggingface.co/xfbai). Here is how to initialize an AMR parsing model in PyTorch: 42 | 43 | ``` 44 | from transformers import BartForConditionalGeneration 45 | from model_interface.tokenization_bart import AMRBartTokenizer # We use our own tokenizer to process AMRs 46 | 47 | model = BartForConditionalGeneration.from_pretrained("xfbai/AMRBART-large-finetuned-AMR3.0-AMRParsing-v2") 48 | tokenizer = AMRBartTokenizer.from_pretrained("xfbai/AMRBART-large-finetuned-AMR3.0-AMRParsing-v2") 49 | ``` 50 | 51 | 52 | ## Pre-training 53 | ``` 54 | bash run-posttrain-bart-textinf-joint-denoising-6task-large-unified-V100.sh "facebook/bart-large" 55 | ``` 56 | 57 | ## Fine-tuning 58 | 59 | For **AMR Parsing**, run 60 | ``` 61 | bash train-AMRBART-large-AMRParsing.sh "xfbai/AMRBART-large-v2" 62 | ``` 63 | 64 | For **AMR-to-text Generation**, run 65 | ``` 66 | bash train-AMRBART-large-AMR2Text.sh "xfbai/AMRBART-large-v2" 67 | ``` 68 | 69 | 70 | ## Evaluation 71 | ``` 72 | cd evaluation 73 | ``` 74 | 75 | For **AMR Parsing**, run 76 | ``` 77 | bash eval_smatch.sh /path/to/gold-amr /path/to/predicted-amr 78 | ``` 79 | For better results, you can postprocess the predicted AMRs using the [BLINK](https://github.com/facebookresearch/BLINK) tool, following [SPRING](https://github.com/SapienzaNLP/spring).
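If you only need a quick sanity check on a couple of predictions, the sketch below shows how a single-pair Smatch score can be computed with the `smatch` pip package. This is an illustrative assumption (the package, the toy graphs, and the direct calls to `get_amr_match`/`compute_f` are not part of this repository's pipeline); `eval_smatch.sh` remains the supported way to reproduce the reported numbers.

```
# Rough single-pair Smatch check (assumes `pip install smatch`);
# use eval_smatch.sh for the official corpus-level evaluation.
import smatch

gold = "(w / want-01 :ARG0 (b / boy) :ARG1 (g / go-02 :ARG0 b))"
pred = "(w / want-01 :ARG0 (b / boy) :ARG1 (g / go-02))"

# Count matching triples between the predicted and gold graphs, then turn
# the counts into precision / recall / F1 (the Smatch score).
best, n_pred, n_gold = smatch.get_amr_match(pred, gold)
precision, recall, f1 = smatch.compute_f(best, n_pred, n_gold)
print(f"Smatch P={precision:.3f} R={recall:.3f} F1={f1:.3f}")
```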
80 | 81 | For **AMR-to-text Generation**, run 82 | ``` 83 | bash eval_gen.sh /path/to/gold-text /path/to/predicted-text 84 | ``` 85 | 86 | ## Inference on your own data 87 | 88 | If you want to run our code on your own data, first convert it into the format shown [here](https://github.com/goodbai-nlp/AMRBART/blob/main/examples/data4parsing.jsonl) (a minimal conversion sketch is given just before the References section below), then run one of the following. 89 | 90 | For **AMR Parsing**, run 91 | ``` 92 | bash inference-amr.sh "xfbai/AMRBART-large-finetuned-AMR3.0-AMRParsing-v2" 93 | ``` 94 | 95 | For **AMR-to-text Generation**, run 96 | ``` 97 | bash inference-text.sh "xfbai/AMRBART-large-finetuned-AMR3.0-AMR2Text-v2" 98 | ``` 99 | 100 | # Pre-trained Models 101 | 102 | ## Pre-trained AMRBART 103 | 104 | 105 | |Setting| Params | checkpoint | 106 | | :----: | :----: |:---:| 107 | | AMRBART-large | 409M | [model](https://huggingface.co/xfbai/AMRBART-large-v2) | 108 | 109 | 110 | ## Fine-tuned models on AMR-to-Text Generation 111 | 112 | |Setting| BLEU(JAMR_tok) | Sacre-BLEU | checkpoint | output | 113 | | :----: | :----: |:---:| :----: | :----: | 114 | | AMRBART-large (AMR2.0) | 50.76 | 50.44 | [model](https://huggingface.co/xfbai/AMRBART-large-finetuned-AMR2.0-AMR2Text-v2) | [output](https://1drv.ms/t/s!ArC7JSpdBblgswHoArZOm8ej0yhB?e=0jxWTK) | 115 | | AMRBART-large (AMR3.0) | 50.29 | 50.38 | [model](https://huggingface.co/xfbai/AMRBART-large-finetuned-AMR3.0-AMR2Text-v2) | [output](https://1drv.ms/t/s!ArC7JSpdBblgswB1X7XrPjlxUtnn?e=zlowU9) | 116 | 117 | To get the tokenized BLEU score, you need to use the scorer we provide [here](https://github.com/muyeby/AMRBART/blob/main/fine-tune/evaluation/eval_gen.sh). We use this script to ensure comparability with previous approaches. 118 | 119 | ## Fine-tuned models on AMR Parsing 120 | 121 | |Setting| Smatch(amrlib) | Smatch(amr-evaluation) | Smatch++(smatchpp) | checkpoint | output | 122 | | :----: | :----: |:---: |:---:| :----: | :----: | 123 | | AMRBART-large (AMR2.0) | 85.5 | 85.3 | 85.4 | [model](https://huggingface.co/xfbai/AMRBART-large-finetuned-AMR2.0-AMRParsing-v2) | [output](https://1drv.ms/t/s!ArC7JSpdBblgsywfCHhxkM6DGfbL?e=OxynaR) | 124 | | AMRBART-large (AMR3.0) | 84.4 | 84.2 | 84.3 | [model](https://huggingface.co/xfbai/AMRBART-large-finetuned-AMR3.0-AMRParsing-v2) | [output](https://1drv.ms/t/s!ArC7JSpdBblgsyuzmOH_0GMBr9m7?e=qtz2RD) | 125 | 126 | 127 | 128 | # Acknowledgements 129 | We thank the authors of [SPRING](https://github.com/SapienzaNLP/spring), [amrlib](https://github.com/bjascob/amrlib), and [BLINK](https://github.com/facebookresearch/BLINK) for sharing the open-source scripts used in this project.
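Here is the conversion sketch referenced in the inference section above: a minimal way to wrap plain sentences into the `data4parsing.jsonl` format (one JSON object per line, with a `sent` string and an empty `amr` field). The file names are placeholders.

```
# Minimal sketch (placeholder file names): one JSON object per line,
# matching the format of examples/data4parsing.jsonl.
import json

with open("my_sentences.txt", encoding="utf-8") as fin, \
     open("data4parsing.jsonl", "w", encoding="utf-8") as fout:
    for line in fin:
        sent = line.strip()
        if sent:
            # For parsing, the graph side stays empty; for AMR-to-text generation,
            # fill "amr" and leave "sent" empty instead (see data4generation.jsonl).
            fout.write(json.dumps({"sent": sent, "amr": ""}, ensure_ascii=False) + "\n")
```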
130 | # References 131 | ``` 132 | @inproceedings{bai-etal-2022-graph, 133 | title = "Graph Pre-training for {AMR} Parsing and Generation", 134 | author = "Bai, Xuefeng and 135 | Chen, Yulong and 136 | Zhang, Yue", 137 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 138 | month = may, 139 | year = "2022", 140 | address = "Dublin, Ireland", 141 | publisher = "Association for Computational Linguistics", 142 | url = "https://aclanthology.org/2022.acl-long.415", 143 | pages = "6001--6015" 144 | } 145 | ``` 146 | -------------------------------------------------------------------------------- /examples/data4generation.jsonl: -------------------------------------------------------------------------------- 1 | {"sent": "", "amr": "( multi-sentence :snt1 ( support-01 :mode imperative :ARG0 ( you ) :ARG1 ( person :ARG0-of ( start-01 :ARG1 ( thread ) ) ) :manner ( resolute ) ) :snt2 ( reply-01 :ARG0 ( i ) :ARG2 ( compose-02 :ARG0 :ARG1 ( poem ) ) ) )"} 2 | {"sent": "", "amr": "( pledge-01 :mode imperative :ARG0 ( you ) :ARG2 ( fight-01 :ARG0 :ARG2 ( defend-01 :ARG0 :ARG1 ( and :op1 ( island :wiki \"Senkaku_Islands\" :name ( name :op1 \"Diaoyu\" :op2 \"Islands\" ) ) :op2 ( island :ARG1-of ( relate-01 :ARG2 ) ) ) ) :manner ( die-01 :ARG1 ) ) )"} 3 | {"sent": "", "amr": "( multi-sentence :snt1 ( bump-01 :ARG1 ( boat :purpose ( fish-01 ) ) :ARG2 ( fleet ) ) :snt3 ( show-01 :ARG0 ( they :ARG0-of ( have-03 :ARG1 ( and :op1 ( heart :mod ( person :ARG0-of ( steal-01 ) ) ) :op2 ( form :mod ( arrogance ) ) ) ) ) :ARG1 ( appearance :poss :mod ( wolfish ) ) :mod ( again ) ) :snt2 ( and :op1 ( stir-up-04 :ARG0 ( ghost :mod ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :mod ( little ) :mod ( evil ) ) :ARG1 ( unrest ) ) :op2 ( make-trouble-06 :ARG0 ) ) )"} 4 | {"sent": "", "amr": "( multi-sentence :snt1 ( go-back-19 :mode imperative :polarity - :ARG1 ( you ) :ARG2 ( time :mod ( that ) :time-of ( appear-01 :ARG1 ( we ) :ARG0-of ( humiliate-01 ) :time ( sign-02 :ARG0 :ARG1 ( treaty ) ) ) ) :time ( ever ) :direction ( back ) ) :snt2 ( strong-02 :mode imperative :ARG1 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) :snt3 ( and :op1 ( sharp-02 :mode imperative :ARG1 ( sword ) ) :op2 ( shine-01 :mode imperative :ARG0 ( knife ) ) ) :snt4 ( bury-01 :mode imperative :ARG0 ( we ) :ARG1 ( enemy :ARG1-of ( approach-01 ) ) ) )"} 5 | {"sent": "", "amr": "( date-entity :month 9 :day 11 :year 2010 )"} 6 | {"sent": "", "amr": "( multi-sentence :snt2 ( bump-02 :ARG5 ( up ) :manner ( resolute ) ) :snt1 ( agree-01 :ARG0 ( i ) :ARG1 ( analyze-01 :ARG0 ( person :ARG0-of ( start-01 :ARG1 ( thread ) ) ) ) ) :snt3 ( like-02 :ARG0 ( i ) :ARG1 ( person :wiki \"Hua_Mulan\" :name ( name :op1 \"Hua\" :op2 \"Mulan\" ) :domain :purpose ( and :op1 ( and :op1 ( protect-01 :ARG0 :ARG1 ( home :poss ( we ) ) ) :op2 ( defend-01 :ARG0 :ARG1 ( homeland :poss ) ) ) :op2 ( spread-03 :ARG0 :ARG1 ( prestige :poss ( country :wiki \"China\" :name ( name :op1 \"China\" ) :poss ) ) ) :op5 ( pledge-01 :ARG0 :ARG2 ( fight-01 :ARG0 :ARG2 ( defend-01 :ARG1 ( dignity :mod ( nation ) :poss ) ) :degree ( die-01 :ARG1 ) ) ) ) ) :condition ( come-03 :ARG1 ( war ) :ARG1-of ( real-04 ) ) ) )"} 7 | {"sent": "", "amr": "( wish-01 :ARG0 ( i ) :ARG1 ( bear-02 :ARG1 :time ( event :wiki \"Long_March\" :name ( name :op1 \"Long\" :op2 \"March\" ) :mod ( passionate ) :ARG1-of ( upsurge-00 ) :mod ( that ) :poss ( person :wiki \"Mao_Zedong\" :name ( name :op1 \"Mao\" :op2 
\"Zedong\" ) ) ) :purpose ( and :op1 ( lose-02 :ARG0 :ARG1 ( head :part-of ) ) :op2 ( shed-01 :ARG0 :ARG1 ( blood :part-of ) ) :op3 ( extinguish-01 :ARG0 :ARG1 ( person :mod ( all ) :ARG0-of ( wrong-01 :ARG1 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) :ARG0-of ( have-rel-role-91 :ARG1 :ARG2 ( enemy ) ) ) :ARG1-of ( complete-02 ) ) :op4 ( spread-03 :ARG0 :ARG1 ( prestige :mod ( dragon ) :poss ( we ) ) ) ) ) :degree ( much :degree ( very ) ) )"} 8 | {"sent": "", "amr": "( tolerate-01 :ARG0 ( we ) :ARG1 ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :duration ( amr-unknown ) )"} 9 | {"sent": "", "amr": "( say-01 :ARG0 ( i ) :ARG2 ( citizen :mod ( fellow ) :poss ) )"} 10 | {"sent": "", "amr": "( multi-sentence :snt1 ( say-01 :ARG0 ( i ) :ARG1 ( hello ) :ARG2 ( everyone ) ) :snt2 ( have-concession-91 :ARG1 ( plan-01 :ARG0 ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :ARG1 ( incident :mod ( another ) :location ( sea ) :ARG1-of ( mean-01 :ARG2 ( incident :wiki \"Marco_Polo_Bridge_Incident\" :name ( name :op1 \"Lugou\" :op2 \"Bridge\" ) :mod ( another ) :location ( sea ) ) ) :ARG1-of ( resemble-01 :ARG2 ) ) :ARG1-of ( premeditate-01 ) :time ( now ) :mod ( again ) ) :ARG2 ( know-01 :ARG0 ( we :mod ( all ) ) :ARG1 ( incident :wiki \"Mukden_Incident\" :name ( name :op1 \"September\" :op2 \"18th\" ) ) ) ) )"} 11 | -------------------------------------------------------------------------------- /examples/data4parsing.jsonl: -------------------------------------------------------------------------------- 1 | {"sent": "Resolutely support the thread starter! I compose a poem in reply:", "amr": ""} 2 | {"sent": "Fleets bumping fishing boats. Little evil Japanese ghosts stirring up trouble and unrest. With hearts of thieves and arrogant form, they again show their wolfish appearance", "amr": ""} 3 | {"sent": "Never go back to that time, our humiliating appearance when signing the treaties. China be strong, swords be sharp and knives be shining, let's bury the approaching enemies!", "amr": ""} 4 | {"sent": "September 11th, 2010", "amr": ""} 5 | {"sent": "I agree with the analysis of the thread starter, resolutely bump up. If war really comes, I would like to be Hua Mulan, to protect our home and defend our homeland, to spread the prestige of our China, to pledge to fight to the death defending our national dignity.", "amr": ""} 6 | {"sent": "I very much wish I had been born in those passionate and upsurging times of Mao Zedong, the Long March, losing my head and shedding my blood, completely extinguishing all the enemies that wrong China, spreading our dragon prestige!", "amr": ""} 7 | {"sent": "How Long are We Going to Tolerate Japan?", "amr": ""} 8 | {"sent": "My fellow citizens:", "amr": ""} 9 | {"sent": "Hello, everyone! We all know the \"September 18th\" Incident, but now Japan again has a premeditated plan for another \"September 18th\" on the sea, another \"Lugou Bridge on the sea\".", "amr": ""} 10 | -------------------------------------------------------------------------------- /examples/test.jsonl: -------------------------------------------------------------------------------- 1 | {"sent": "Resolutely support the thread starter! 
I compose a poem in reply:", "amr": "( multi-sentence :snt1 ( support-01 :mode imperative :ARG0 ( you ) :ARG1 ( person :ARG0-of ( start-01 :ARG1 ( thread ) ) ) :manner ( resolute ) ) :snt2 ( reply-01 :ARG0 ( i ) :ARG2 ( compose-02 :ARG0 :ARG1 ( poem ) ) ) )"} 2 | {"sent": "Pledge to fight to the death defending the Diaoyu Islands and the related islands", "amr": "( pledge-01 :mode imperative :ARG0 ( you ) :ARG2 ( fight-01 :ARG0 :ARG2 ( defend-01 :ARG0 :ARG1 ( and :op1 ( island :wiki \"Senkaku_Islands\" :name ( name :op1 \"Diaoyu\" :op2 \"Islands\" ) ) :op2 ( island :ARG1-of ( relate-01 :ARG2 ) ) ) ) :manner ( die-01 :ARG1 ) ) )"} 3 | {"sent": "Fleets bumping fishing boats. Little evil Japanese ghosts stirring up trouble and unrest. With hearts of thieves and arrogant form, they again show their wolfish appearance", "amr": "( multi-sentence :snt1 ( bump-01 :ARG1 ( boat :purpose ( fish-01 ) ) :ARG2 ( fleet ) ) :snt3 ( show-01 :ARG0 ( they :ARG0-of ( have-03 :ARG1 ( and :op1 ( heart :mod ( person :ARG0-of ( steal-01 ) ) ) :op2 ( form :mod ( arrogance ) ) ) ) ) :ARG1 ( appearance :poss :mod ( wolfish ) ) :mod ( again ) ) :snt2 ( and :op1 ( stir-up-04 :ARG0 ( ghost :mod ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :mod ( little ) :mod ( evil ) ) :ARG1 ( unrest ) ) :op2 ( make-trouble-06 :ARG0 ) ) )"} 4 | {"sent": "Never go back to that time, our humiliating appearance when signing the treaties. China be strong, swords be sharp and knives be shining, let's bury the approaching enemies!", "amr": "( multi-sentence :snt1 ( go-back-19 :mode imperative :polarity - :ARG1 ( you ) :ARG2 ( time :mod ( that ) :time-of ( appear-01 :ARG1 ( we ) :ARG0-of ( humiliate-01 ) :time ( sign-02 :ARG0 :ARG1 ( treaty ) ) ) ) :time ( ever ) :direction ( back ) ) :snt2 ( strong-02 :mode imperative :ARG1 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) :snt3 ( and :op1 ( sharp-02 :mode imperative :ARG1 ( sword ) ) :op2 ( shine-01 :mode imperative :ARG0 ( knife ) ) ) :snt4 ( bury-01 :mode imperative :ARG0 ( we ) :ARG1 ( enemy :ARG1-of ( approach-01 ) ) ) )"} 5 | {"sent": "September 11th, 2010", "amr": "( date-entity :month 9 :day 11 :year 2010 )"} 6 | {"sent": "I agree with the analysis of the thread starter, resolutely bump up. 
If war really comes, I would like to be Hua Mulan, to protect our home and defend our homeland, to spread the prestige of our China, to pledge to fight to the death defending our national dignity.", "amr": "( multi-sentence :snt2 ( bump-02 :ARG5 ( up ) :manner ( resolute ) ) :snt1 ( agree-01 :ARG0 ( i ) :ARG1 ( analyze-01 :ARG0 ( person :ARG0-of ( start-01 :ARG1 ( thread ) ) ) ) ) :snt3 ( like-02 :ARG0 ( i ) :ARG1 ( person :wiki \"Hua_Mulan\" :name ( name :op1 \"Hua\" :op2 \"Mulan\" ) :domain :purpose ( and :op1 ( and :op1 ( protect-01 :ARG0 :ARG1 ( home :poss ( we ) ) ) :op2 ( defend-01 :ARG0 :ARG1 ( homeland :poss ) ) ) :op2 ( spread-03 :ARG0 :ARG1 ( prestige :poss ( country :wiki \"China\" :name ( name :op1 \"China\" ) :poss ) ) ) :op5 ( pledge-01 :ARG0 :ARG2 ( fight-01 :ARG0 :ARG2 ( defend-01 :ARG1 ( dignity :mod ( nation ) :poss ) ) :degree ( die-01 :ARG1 ) ) ) ) ) :condition ( come-03 :ARG1 ( war ) :ARG1-of ( real-04 ) ) ) )"} 7 | {"sent": "I very much wish I had been born in those passionate and upsurging times of Mao Zedong, the Long March, losing my head and shedding my blood, completely extinguishing all the enemies that wrong China, spreading our dragon prestige!", "amr": "( wish-01 :ARG0 ( i ) :ARG1 ( bear-02 :ARG1 :time ( event :wiki \"Long_March\" :name ( name :op1 \"Long\" :op2 \"March\" ) :mod ( passionate ) :ARG1-of ( upsurge-00 ) :mod ( that ) :poss ( person :wiki \"Mao_Zedong\" :name ( name :op1 \"Mao\" :op2 \"Zedong\" ) ) ) :purpose ( and :op1 ( lose-02 :ARG0 :ARG1 ( head :part-of ) ) :op2 ( shed-01 :ARG0 :ARG1 ( blood :part-of ) ) :op3 ( extinguish-01 :ARG0 :ARG1 ( person :mod ( all ) :ARG0-of ( wrong-01 :ARG1 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) :ARG0-of ( have-rel-role-91 :ARG1 :ARG2 ( enemy ) ) ) :ARG1-of ( complete-02 ) ) :op4 ( spread-03 :ARG0 :ARG1 ( prestige :mod ( dragon ) :poss ( we ) ) ) ) ) :degree ( much :degree ( very ) ) )"} 8 | {"sent": "How Long are We Going to Tolerate Japan?", "amr": "( tolerate-01 :ARG0 ( we ) :ARG1 ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :duration ( amr-unknown ) )"} 9 | {"sent": "My fellow citizens:", "amr": "( say-01 :ARG0 ( i ) :ARG2 ( citizen :mod ( fellow ) :poss ) )"} 10 | {"sent": "Hello, everyone! 
We all know the \"September 18th\" Incident, but now Japan again has a premeditated plan for another \"September 18th\" on the sea, another \"Lugou Bridge on the sea\".", "amr": "( multi-sentence :snt1 ( say-01 :ARG0 ( i ) :ARG1 ( hello ) :ARG2 ( everyone ) ) :snt2 ( have-concession-91 :ARG1 ( plan-01 :ARG0 ( country :wiki \"Japan\" :name ( name :op1 \"Japan\" ) ) :ARG1 ( incident :mod ( another ) :location ( sea ) :ARG1-of ( mean-01 :ARG2 ( incident :wiki \"Marco_Polo_Bridge_Incident\" :name ( name :op1 \"Lugou\" :op2 \"Bridge\" ) :mod ( another ) :location ( sea ) ) ) :ARG1-of ( resemble-01 :ARG2 ) ) :ARG1-of ( premeditate-01 ) :time ( now ) :mod ( again ) ) :ARG2 ( know-01 :ARG0 ( we :mod ( all ) ) :ARG1 ( incident :wiki \"Mukden_Incident\" :name ( name :op1 \"September\" :op2 \"18th\" ) ) ) ) )"} 11 | -------------------------------------------------------------------------------- /examples/train.jsonl: -------------------------------------------------------------------------------- 1 | {"sent": "Establishing Models in Industrial Innovation", "amr": "( establish-01 :ARG1 ( model :mod ( innovate-01 :ARG1 ( industry ) ) ) )"} 2 | {"sent": "After its competitor invented the front loading washing machine, the CEO of the American IM company believed that each of its employees had the ability for innovation , and formulated strategic countermeasures for innovation in the industry.", "amr": "( and :op1 ( believe-01 :ARG0 ( person :ARG0-of ( have-org-role-91 :ARG1 ( company :wiki - :name ( name :op1 \"IM\" ) :mod ( country :wiki \"United_States\" :name ( name :op1 \"United\" :op2 \"States\" ) ) ) :ARG2 ( officer :mod ( executive ) :mod ( chief ) ) ) ) :ARG1 ( capable-01 :ARG1 ( person :ARG1-of ( employ-01 :ARG0 ) :mod ( each ) ) :ARG2 ( innovate-01 :ARG0 ) ) ) :op2 ( formulate-01 :ARG0 ( officer :mod ( executive ) :mod ( chief ) ) :ARG1 ( countermeasure :mod ( strategy ) :purpose ( innovate-01 :topic ( industry ) ) ) ) :time ( after :op1 ( invent-01 :ARG0 ( company :ARG0-of ( compete-02 :ARG1 ) ) :ARG1 ( machine :ARG0-of ( wash-01 ) :ARG1-of ( load-01 :mod ( front ) ) ) ) ) )"} 3 | {"sent": "1. Establish an innovation fund with a maximum amount of 1,000 U.S. dollars.", "amr": "( establish-01 :li 1 :ARG1 ( fund :purpose ( innovate-01 ) :ARG1-of ( amount-01 :ARG2 ( at-most :op1 ( monetary-quantity :quant 1000 :unit ( dollar :mod ( country :wiki \"United_States\" :name ( name :op1 \"United\" :op2 \"States\" ) ) ) ) ) ) ) )"} 4 | {"sent": "2. Choose 100 innovative concepts to encourage employees to conduct research and development during their work time or spare time.", "amr": "( choose-01 :ARG1 ( concept :quant 100 :ARG1-of ( innovate-01 ) ) :li 2 :purpose ( encourage-01 :ARG0 :ARG1 ( person :ARG1-of ( employ-01 ) ) :ARG2 ( and :op1 ( research-01 :ARG0 ) :op2 ( develop-02 :ARG0 ) :time ( or :op1 ( work-01 :ARG0 ) :op2 ( time :poss :mod ( spare ) ) ) ) ) )"} 5 | {"sent": "3. 
From among them, pick out 50 for submission to an assessment committee to assess.", "amr": "( pick-out-03 :ARG1 ( thing :quant 50 :ARG1-of ( submit-01 :ARG2 ( committee :ARG0-of ( assess-01 ) ) :ARG3 ( assess-01 :ARG0 :ARG1 ) ) ) :ARG2 ( they ) :li 3 )"} 6 | {"sent": "Since the Tangshan Earthquake, the starting point for construction standards in the mainland is that under an earthquake of the same magnitude, buildings should preserve their basic frame without collapsing.", "amr": "( point :mod ( start-01 :ARG1 ( standard :mod ( construct-01 ) ) ) :location ( mainland ) :domain ( recommend-01 :ARG1 ( preserve-01 :ARG0 ( building ) :ARG1 ( frame :part-of ( thing ) :mod ( basic ) :poss ) :manner ( collapse-01 :polarity - :ARG1 ) ) :time ( since :op1 ( earthquake :wiki \"1976_Tangshan_earthquake\" :name ( name :op1 \"Tangshan\" ) ) ) :condition ( earthquake :mod ( magnitude :ARG1-of ( same-01 :ARG2 ( magnitude :poss ) ) ) ) ) )"} 7 | {"sent": "However, most of the buildings in this hard-hit area did not meet these requirements, with the widespread collapse of school buildings in particular arousing intense public disgust.", "amr": "( contrast-01 :ARG2 ( and :op1 ( meet-01 :polarity - :ARG0 ( building :quant ( most ) :ARG1-of ( include-91 :ARG2 ( building :location ( area :ARG1-of ( hit-01 :ARG1-of ( hard-04 ) ) :mod ( this ) ) ) ) ) :ARG1 ( thing :ARG1-of ( require-01 ) :mod ( this ) ) ) :op2 ( arouse-01 :ARG0 ( collapse-01 :ARG1 ( building :mod ( school ) ) :ARG1-of ( spread-02 :ARG1-of ( wide-02 ) ) ) :ARG1 ( disgust-01 :ARG1 ( public ) :ARG1-of ( intense-02 ) ) :mod ( particular ) ) ) )"} 8 | {"sent": "Raising standards to in excess of Tangshan's 8.0 magnitude could leave authorities with some breathing space for explanation, and alleviate public anger.", "amr": "( possible-01 :ARG1 ( and :op1 ( leave-13 :ARG0 ( raise-01 :ARG1 ( standard ) :ARG4 ( in-excess-of :op1 ( seismic-quantity :quant 8.0 :poss ( earthquake :wiki \"1976_Tangshan_earthquake\" :name ( name :op1 \"Tangshan\" ) ) ) ) ) :ARG1 ( have-03 :ARG0 ( authority ) :ARG1 ( space :mod ( breathe-01 ) :quant ( some ) :purpose ( explain-01 :ARG0 ) ) ) ) :op2 ( alleviate-01 :ARG0 :ARG1 ( anger-01 :ARG1 ( public ) ) ) ) )"} 9 | {"sent": "According to information leaked from numerous channels, we can say for certain that before the earthquake struck, the serious earthquake risk in the Ngawa region was already well known to the CCP.", "amr": "( say-01 :ARG0 ( information :ARG1-of ( leak-01 :ARG0 ( channel :quant ( numerous ) ) ) ) :ARG1 ( possible-01 :ARG1 ( say-01 :ARG0 ( we ) :ARG1 ( know-02 :ARG0 ( political-party :wiki \"Communist_Party_of_China\" :name ( name :op1 \"CCP\" ) ) :ARG1 ( risk-01 :ARG2 ( earthquake ) :ARG1-of ( serious-02 ) :location ( local-region :wiki \"Ngawa_Tibetan_and_Qiang_Autonomous_Prefecture\" :name ( name :op1 \"Ngawa\" ) ) ) :time ( before :op1 ( strike-01 :ARG2 ( earthquake ) ) ) :time ( already ) :ARG1-of ( well-09 ) ) :manner ( certain ) ) ) )"} 10 | {"sent": "Although current forecasting standards cannot give us an accurate prediction of the exact time, place, and strength of an earthquake, there is considerable experience accumulated both in China and overseas in predicting to within the timeframe of a month, and the area of a province.", "amr": "( have-concession-91 :ARG1 ( accumulate-01 :ARG1 ( experience :quant ( considerable ) :topic ( and :op1 ( predict-01 :extent ( temporal-quantity :quant 1 :unit ( month ) ) ) :op2 ( predict-01 :ARG1 ( area :part-of ( province ) ) ) ) ) :location ( and :op1 ( 
country :wiki \"China\" :name ( name :op1 \"China\" ) ) :op2 ( overseas ) ) ) :ARG2 ( possible-01 :polarity - :ARG1 ( predict-01 :ARG0 ( standard :mod ( forecast-01 ) :time ( current ) ) :ARG1 ( and :op1 ( time :time-of ( earthquake ) ) :op2 ( place :location-of ) :op3 ( strong-02 :ARG1 ) :mod ( exact ) ) :mod ( accurate ) :beneficiary ( we ) ) ) )"} 11 | -------------------------------------------------------------------------------- /examples/val.jsonl: -------------------------------------------------------------------------------- 1 | {"sent": "There are many who have a sense of urgency, quietly watching how things develop,you are dragons coiling, you are tigers crouching, I admire noble-minded patriots.", "amr": "( multi-sentence :snt1 ( many :ARG0-of ( sense-01 :ARG1 ( urgency ) :time ( watch-01 :ARG0 :ARG1 ( thing :manner-of ( develop-02 :ARG0 ( thing ) ) ) :manner ( quiet-04 :ARG1 ) ) ) ) :snt2 ( dragon :domain ( you ) :ARG0-of ( coil-01 ) ) :snt3 ( tiger :domain ( you ) :ARG0-of ( crouch-01 ) ) :snt4 ( admire-01 :ARG0 ( i ) :ARG1 ( patriot :poss-of ( mind :mod ( noble ) ) ) ) )"} 2 | {"sent": "Has history given us too many lessons?, 530, 412, 64", "amr": "( multi-sentence :snt1 ( give-01 :ARG0 ( history ) :ARG1 ( lesson :ARG1-of ( have-quant-91 :ARG2 ( many ) :ARG3 ( too ) ) ) :ARG2 ( we ) :polarity ( amr-unknown ) ) :snt2 ( and :op1 530 :op2 412 :op3 64 ) )"} 3 | {"sent": "taking a look", "amr": "( look-01 )"} 4 | {"sent": "the ones who are suffering are the ordinary people: even if the body of a salted fish is turned over, it is still a salted fish ...", "amr": "( multi-sentence :snt1 ( suffer-01 :ARG0 ( person :mod ( ordinary ) ) ) :snt2 ( fish :ARG1-of ( salt-01 ) :mod ( still ) :domain :concession ( even-if :op1 ( turn-01 :ARG1 ( body :poss ( fish :ARG1-of ( salt-01 ) ) ) :direction ( over ) ) ) ) )"} 5 | {"sent": "Freedom of speech\\thought, if people express a view somewhat different than the traditional view, and put forward slightly different criticism, then they are called slaves of foreigners, or are accused of reverence for and fascination by foreign things,", "amr": "( or :op1 ( call-01 :ARG1 ( person ) :ARG2 ( slave :poss ( foreign ) ) ) :op2 ( accuse-01 :ARG1 :ARG2 ( and :op1 ( revere-01 :ARG0 :ARG1 ( thing :mod ( foreign ) ) ) :op2 ( fascinate-01 :ARG0 :ARG1 ) ) ) :condition ( and :op1 ( express-01 :ARG0 :ARG1 ( view-02 :ARG0 :ARG1-of ( differ-02 :ARG2 ( view-02 :mod ( tradition ) ) :degree ( somewhat ) ) ) ) :op2 ( criticize-01 :ARG0 :ARG1-of ( differ-02 :degree ( slight ) ) ) ) :topic ( free-04 :ARG3 ( slash :op1 ( speak-01 ) :op2 ( think-01 ) ) ) )"} 6 | {"sent": "What is more they are considered traitors of China, which is a fact of cultural tyranny in the cloak of nationalism and patriotism.", "amr": "( consider-01 :ARG1 ( person :domain ( they ) :ARG0-of ( betray-01 :ARG1 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) ) :mod ( more ) :mod ( tyrannize-01 :ARG2 ( culture ) :ARG1-of ( cloak-01 :ARG2 ( and :op1 ( nationalism ) :op2 ( patriotism ) ) ) ) )"} 7 | {"sent": "In fact, the US no longer needs to use force to deal with China, they have achieved the result of \"defeating enemy soldiers without fighting\".", "amr": "( cause-01 :ARG0 ( achieve-01 :ARG0 :ARG1 ( result-01 :ARG2 ( defeat-01 :ARG0 :ARG1 ( soldier :ARG0-of ( have-rel-role-91 :ARG1 :ARG2 ( enemy ) ) ) :manner ( fight-01 :polarity - :ARG0 ) ) ) ) :ARG1 ( need-01 :ARG0 ( country :wiki \"United_States\" :name ( name :op1 \"US\" ) ) :ARG1 ( use-01 :ARG0 :ARG1 ( force-04 ) :ARG2 ( deal-01 
:ARG0 :ARG2 ( country :wiki \"China\" :name ( name :op1 \"China\" ) ) ) ) :time ( no-longer ) :mod ( in-fact ) ) )"} 8 | {"sent": "Is the article too intense, is the United States so good?", "amr": "( multi-sentence :snt1 ( intense-02 :ARG1 ( article ) :polarity ( amr-unknown ) :ARG2-of ( have-degree-91 :ARG1 :ARG3 ( too ) ) ) :snt2 ( good-02 :ARG1 ( country :wiki \"United_States\" :name ( name :op1 \"United\" :op2 \"States\" ) ) :polarity ( amr-unknown ) :ARG2-of ( have-degree-91 :ARG1 :ARG3 ( so ) ) ) )"} 9 | {"sent": "If things are not seen by eyes and heard by ears yourself, do not assume their existence, three people create a tiger.", "amr": "( multi-sentence :snt1 ( assume-02 :polarity - :mode imperative :ARG0 :ARG1 ( exist-01 :ARG1 ) :condition ( and :op1 ( see-01 :polarity - :ARG0 ( eye :part-of ( you ) ) :ARG1 ( thing ) ) :op2 ( hear-01 :polarity - :ARG0 ( ear :part-of ) :ARG1 ) ) ) :snt2 ( create-01 :ARG0 ( person :quant 3 ) :ARG1 ( tiger ) ) )"} 10 | {"sent": "Just passing by and taking a look. Won't express my opinion", "amr": "( multi-sentence :snt1 ( and :op1 ( pass-by-17 :mod ( just ) ) :op2 ( look-01 ) ) :snt2 ( express-01 :polarity - :ARG0 ( i ) :ARG1 ( thing :ARG1-of ( opine-01 :ARG0 ) ) ) )"} 11 | -------------------------------------------------------------------------------- /fine-tune/Eval-AMRBART-large-AMR2Text.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=LDC2020 5 | Dataset=LDC2017 6 | 7 | BasePath=/mnt/nfs-storage/data # change dir here 8 | DataPath=$RootDir/data/$Dataset 9 | 10 | ModelCate=AMRBART-large 11 | 12 | MODEL=$1 13 | ModelCache=$BasePath/.cache 14 | DataCache=$DataPath/.cache/dump-amr2text 15 | 16 | lr=2e-6 17 | 18 | OutputDir=${RootDir}/outputs/Eval-$Dataset-$ModelCate-AMR2Text-bsz16-lr-${lr}-UnifiedInp 19 | 20 | if [ ! -d ${OutputDir} ];then 21 | mkdir -p ${OutputDir} 22 | else 23 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 24 | case $yn in 25 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 26 | [Nn]* ) echo "exiting..."; exit;; 27 | * ) echo "Please answer yes or no.";; 28 | esac 29 | fi 30 | 31 | export HF_DATASETS_CACHE=$DataCache 32 | 33 | if [ ! 
-d ${DataCache} ];then 34 | mkdir -p ${DataCache} 35 | fi 36 | 37 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 38 | python -u main.py \ 39 | --data_dir $DataPath \ 40 | --task "amr2text" \ 41 | --validation_file $DataPath/val.jsonl \ 42 | --test_file $DataPath/test.jsonl \ 43 | --output_dir $OutputDir \ 44 | --cache_dir $ModelCache \ 45 | --data_cache_dir $DataCache \ 46 | --model_name_or_path $MODEL \ 47 | --overwrite_output_dir \ 48 | --unified_input True \ 49 | --per_device_eval_batch_size 16 \ 50 | --max_source_length 1024 \ 51 | --max_target_length 400 \ 52 | --val_max_target_length 400 \ 53 | --generation_max_length 400 \ 54 | --generation_num_beams 5 \ 55 | --predict_with_generate \ 56 | --smart_init False \ 57 | --use_fast_tokenizer False \ 58 | --logging_dir $OutputDir/logs \ 59 | --seed 42 \ 60 | --fp16 \ 61 | --fp16_backend "auto" \ 62 | --dataloader_num_workers 8 \ 63 | --eval_dataloader_num_workers 2 \ 64 | --metric_for_best_model "eval_bleu" \ 65 | --include_inputs_for_metrics \ 66 | --do_eval \ 67 | --do_predict \ 68 | --ddp_find_unused_parameters False \ 69 | --report_to "tensorboard" \ 70 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log 71 | -------------------------------------------------------------------------------- /fine-tune/Eval-AMRBART-large-AMRParsing.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=LDC2020 5 | Dataset=LDC2017 6 | 7 | BasePath=/mnt/nfs-storage/data # change dir here 8 | DataPath=$RootDir/data/$Dataset 9 | 10 | ModelCate=AMRBART-large 11 | 12 | MODEL=$1 13 | ModelCache=$BasePath/.cache 14 | DataCache=$DataPath/.cache/dump-amrparsing 15 | 16 | lr=1e-5 17 | 18 | OutputDir=${RootDir}/outputs/Eval-$Dataset-${ModelCate}-AMRParing-bsz8-lr-${lr}-UnifiedInp 19 | 20 | if [ ! -d ${OutputDir} ];then 21 | mkdir -p ${OutputDir} 22 | else 23 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 24 | case $yn in 25 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 26 | [Nn]* ) echo "exiting..."; exit;; 27 | * ) echo "Please answer yes or no.";; 28 | esac 29 | fi 30 | 31 | export HF_DATASETS_CACHE=$DataCache 32 | 33 | if [ ! 
-d ${DataCache} ];then 34 | mkdir -p ${DataCache} 35 | fi 36 | 37 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 38 | python -u main.py \ 39 | --data_dir $DataPath \ 40 | --task "text2amr" \ 41 | --validation_file $DataPath/val.jsonl \ 42 | --test_file $DataPath/test.jsonl \ 43 | --output_dir $OutputDir \ 44 | --cache_dir $ModelCache \ 45 | --data_cache_dir $DataCache \ 46 | --overwrite_cache True \ 47 | --model_name_or_path $MODEL \ 48 | --overwrite_output_dir \ 49 | --unified_input True \ 50 | --per_device_eval_batch_size 8 \ 51 | --max_source_length 400 \ 52 | --max_target_length 1024 \ 53 | --val_max_target_length 1024 \ 54 | --generation_max_length 1024 \ 55 | --generation_num_beams 5 \ 56 | --predict_with_generate \ 57 | --smart_init False \ 58 | --use_fast_tokenizer False \ 59 | --logging_dir $OutputDir/logs \ 60 | --seed 42 \ 61 | --fp16 \ 62 | --fp16_backend "auto" \ 63 | --dataloader_num_workers 8 \ 64 | --eval_dataloader_num_workers 2 \ 65 | --include_inputs_for_metrics \ 66 | --metric_for_best_model "eval_smatch" \ 67 | --do_eval \ 68 | --do_predict \ 69 | --ddp_find_unused_parameters False \ 70 | --report_to "tensorboard" \ 71 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log -------------------------------------------------------------------------------- /fine-tune/common/callbacks.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | import logging 4 | import os 5 | import torch 6 | import numpy as np 7 | from pathlib import Path 8 | from transformers import TrainerCallback 9 | from transformers.training_args import TrainingArguments 10 | from transformers.trainer_callback import TrainerControl, TrainerState 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class Seq2seqCallback(TrainerCallback): 17 | 18 | def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs): 19 | pass 20 | 21 | def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 22 | """ 23 | Event called at the beginning of an epoch. 24 | """ 25 | pass 26 | 27 | def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 28 | """ 29 | Event called at the end of an epoch. 30 | """ 31 | pass 32 | 33 | def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 34 | """ 35 | Event called at the beginning of a training step. If using gradient accumulation, one training step might take 36 | several inputs. 37 | """ 38 | pass 39 | 40 | def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 41 | """ 42 | Event called at the end of an substep during gradient accumulation. 43 | """ 44 | pass 45 | 46 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 47 | """ 48 | Event called at the end of a training step. If using gradient accumulation, one training step might take 49 | several inputs. 50 | """ 51 | pass 52 | 53 | def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 54 | """ 55 | Event called after an evaluation phase. 56 | """ 57 | pass 58 | 59 | def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): 60 | """ 61 | Event called after a successful prediction. 
62 | """ 63 | pass 64 | 65 | def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 66 | """ 67 | Event called after a checkpoint save. 68 | """ 69 | pass 70 | 71 | def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 72 | """ 73 | Event called after logging the last logs. 74 | """ 75 | pass 76 | 77 | def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 78 | """ 79 | Event called after a prediction step. 80 | """ 81 | pass -------------------------------------------------------------------------------- /fine-tune/common/constant.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import json 4 | 5 | from transformers.optimization import ( 6 | get_cosine_schedule_with_warmup, 7 | get_cosine_with_hard_restarts_schedule_with_warmup, 8 | get_linear_schedule_with_warmup, 9 | get_polynomial_decay_schedule_with_warmup, 10 | get_constant_schedule_with_warmup, 11 | ) 12 | 13 | from transformers import ( 14 | WEIGHTS_NAME, 15 | AdamW, 16 | Adafactor, 17 | AutoConfig, 18 | AutoTokenizer, 19 | AutoModelForSeq2SeqLM, 20 | BartTokenizer, 21 | BartForConditionalGeneration, 22 | T5Tokenizer, 23 | T5Model, 24 | T5ForConditionalGeneration, 25 | ) 26 | 27 | raw_special_tokens = json.load( 28 | open(f"{os.path.dirname(__file__)}/additional-tokens.json", "r", encoding="utf-8") 29 | ) 30 | special_tokens = [itm.lstrip("Ġ") for itm in raw_special_tokens] 31 | 32 | recategorizations = [ 33 | "\u0120COUNTRY", 34 | "\u0120QUANTITY", 35 | "\u0120ORGANIZATION", 36 | "\u0120DATE_ATTRS", 37 | "\u0120NATIONALITY", 38 | "\u0120LOCATION", 39 | "\u0120ENTITY", 40 | "\u0120MISC", 41 | "\u0120ORDINAL_ENTITY", 42 | "\u0120IDEOLOGY", 43 | "\u0120RELIGION", 44 | "\u0120STATE_OR_PROVINCE", 45 | "\u0120CAUSE_OF_DEATH", 46 | "\u0120TITLE", 47 | "\u0120DATE", 48 | "\u0120NUMBER", 49 | "\u0120HANDLE", 50 | "\u0120SCORE_ENTITY", 51 | "\u0120DURATION", 52 | "\u0120ORDINAL", 53 | "\u0120MONEY", 54 | "\u0120CRIMINAL_CHARGE", 55 | ] 56 | 57 | # special_tokens = ["", ""] 58 | 59 | arg_to_scheduler = { 60 | "linear": get_linear_schedule_with_warmup, 61 | "cosine": get_cosine_schedule_with_warmup, 62 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 63 | "polynomial": get_polynomial_decay_schedule_with_warmup, 64 | "constant": get_constant_schedule_with_warmup, 65 | } 66 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) 67 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" 68 | 69 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 70 | 71 | arg_to_tokenizer = { 72 | "AutoTokenizer": AutoTokenizer, 73 | "BartTokenizer": BartTokenizer, 74 | "T5Tokenizer": T5Tokenizer, 75 | } 76 | arg_to_plm_model = { 77 | "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, 78 | "BartForConditionalGeneration": BartForConditionalGeneration, 79 | "T5Model": T5Model, 80 | "T5ForConditionalGeneration": T5ForConditionalGeneration, 81 | } 82 | -------------------------------------------------------------------------------- /fine-tune/common/options.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from typing import Optional 3 | from dataclasses import dataclass, field 4 | from common.training_args import TrainingArguments 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | """ 10 | Arguments pertaining to which 
model/config/tokenizer we are going to fine-tune from. 11 | """ 12 | 13 | model_name_or_path: str = field( 14 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 15 | ) 16 | config_name: Optional[str] = field( 17 | default=None, 18 | metadata={"help": "Pretrained config name or path if not the same as model_name"}, 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, 22 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, 23 | ) 24 | cache_dir: Optional[str] = field( 25 | default=None, 26 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 27 | ) 28 | use_fast_tokenizer: bool = field( 29 | default=True, 30 | metadata={ 31 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 32 | }, 33 | ) 34 | model_revision: str = field( 35 | default="main", 36 | metadata={ 37 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 38 | }, 39 | ) 40 | use_auth_token: bool = field( 41 | default=False, 42 | metadata={ 43 | "help": ( 44 | "Will use the token generated when running `transformers-cli login` (necessary to use this script " 45 | "with private models)." 46 | ) 47 | }, 48 | ) 49 | resize_position_embeddings: Optional[bool] = field( 50 | default=None, 51 | metadata={ 52 | "help": ( 53 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 54 | "the model's position embeddings." 55 | ) 56 | }, 57 | ) 58 | 59 | 60 | @dataclass 61 | class DataTrainingArguments: 62 | """ 63 | Arguments pertaining to what data we are going to input our model for training and eval. 64 | """ 65 | 66 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 67 | 68 | dataset_name: Optional[str] = field( 69 | default=None, 70 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 71 | ) 72 | dataset_config_name: Optional[str] = field( 73 | default=None, 74 | metadata={ 75 | "help": "The configuration name of the dataset to use (via the datasets library)." 76 | }, 77 | ) 78 | text_column: Optional[str] = field( 79 | default=None, 80 | metadata={ 81 | "help": "The name of the column in the datasets containing the full texts (for summarization)." 82 | }, 83 | ) 84 | summary_column: Optional[str] = field( 85 | default=None, 86 | metadata={ 87 | "help": "The name of the column in the datasets containing the summaries (for summarization)." 88 | }, 89 | ) 90 | data_dir: Optional[str] = field( 91 | default=None, metadata={"help": "The directory which stores gold AMRs."} 92 | ) 93 | unified_input: Optional[bool] = field( 94 | default=False, metadata={"help": "Whether to use unified input format for finetuning."} 95 | ) 96 | train_file: Optional[str] = field( 97 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 98 | ) 99 | validation_file: Optional[str] = field( 100 | default=None, 101 | metadata={ 102 | "help": ( 103 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 104 | ) 105 | }, 106 | ) 107 | test_file: Optional[str] = field( 108 | default=None, 109 | metadata={ 110 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 
111 | }, 112 | ) 113 | data_cache_dir: Optional[str] = field( 114 | default=None, 115 | metadata={"help": "Where to store the cached dataset"}, 116 | ) 117 | overwrite_cache: bool = field( 118 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 119 | ) 120 | preprocessing_num_workers: Optional[int] = field( 121 | default=None, 122 | metadata={"help": "The number of processes to use for the preprocessing."}, 123 | ) 124 | max_source_length: Optional[int] = field( 125 | default=1024, 126 | metadata={ 127 | "help": ( 128 | "The maximum total input sequence length after tokenization. Sequences longer " 129 | "than this will be truncated, sequences shorter will be padded." 130 | ) 131 | }, 132 | ) 133 | max_source_amr_length: Optional[int] = field( 134 | default=1024, 135 | metadata={ 136 | "help": ( 137 | "The maximum total input sequence length after tokenization. Sequences longer " 138 | "than this will be truncated, sequences shorter will be padded." 139 | ) 140 | }, 141 | ) 142 | max_target_length: Optional[int] = field( 143 | default=128, 144 | metadata={ 145 | "help": ( 146 | "The maximum total sequence length for target text after tokenization. Sequences longer " 147 | "than this will be truncated, sequences shorter will be padded." 148 | ) 149 | }, 150 | ) 151 | val_max_target_length: Optional[int] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 156 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 157 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 158 | "during ``evaluate`` and ``predict``." 159 | ) 160 | }, 161 | ) 162 | pad_to_max_length: bool = field( 163 | default=False, 164 | metadata={ 165 | "help": ( 166 | "Whether to pad all samples to model maximum sentence length. " 167 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 168 | "efficient on GPU but very bad for TPU." 169 | ) 170 | }, 171 | ) 172 | max_train_samples: Optional[int] = field( 173 | default=None, 174 | metadata={ 175 | "help": ( 176 | "For debugging purposes or quicker training, truncate the number of training examples to this " 177 | "value if set." 178 | ) 179 | }, 180 | ) 181 | max_eval_samples: Optional[int] = field( 182 | default=None, 183 | metadata={ 184 | "help": ( 185 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 186 | "value if set." 187 | ) 188 | }, 189 | ) 190 | max_predict_samples: Optional[int] = field( 191 | default=None, 192 | metadata={ 193 | "help": ( 194 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 195 | "value if set." 196 | ) 197 | }, 198 | ) 199 | num_beams: Optional[int] = field( 200 | default=None, 201 | metadata={ 202 | "help": ( 203 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 204 | "which is used during ``evaluate`` and ``predict``." 205 | ) 206 | }, 207 | ) 208 | ignore_pad_token_for_loss: bool = field( 209 | default=True, 210 | metadata={ 211 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 
212 | }, 213 | ) 214 | use_speaker_prefix: bool = field( 215 | default=True, 216 | metadata={ 217 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 218 | }, 219 | ) 220 | source_prefix: Optional[str] = field( 221 | default="", 222 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."}, 223 | ) 224 | target_prefix: Optional[str] = field( 225 | default="", 226 | metadata={"help": "A prefix to add before every target text (useful for T5 models)."}, 227 | ) 228 | forced_bos_token: Optional[str] = field( 229 | default=None, 230 | metadata={ 231 | "help": ( 232 | "The token to force as the first generated token after the decoder_start_token_id." 233 | "Useful for multilingual models like mBART where the first generated token" 234 | "needs to be the target language token (Usually it is the target language token)" 235 | ) 236 | }, 237 | ) 238 | 239 | def __post_init__(self): 240 | # if self.do_train and self.dataset_name is None and self.train_file is None and self.validation_file is None: 241 | # raise ValueError("Need either a dataset name or a training/validation file.") 242 | # else: 243 | # if self.train_file is not None: 244 | # extension = self.train_file.split(".")[-1] 245 | # assert extension in [ 246 | # "csv", 247 | # "json", 248 | # "jsonl", 249 | # ], "`train_file` should be a csv or a json file." 250 | # if self.validation_file is not None: 251 | # extension = self.validation_file.split(".")[-1] 252 | # assert extension in [ 253 | # "csv", 254 | # "json", 255 | # "jsonl", 256 | # ], "`validation_file` should be a csv or a json file." 257 | if self.val_max_target_length is None: 258 | self.val_max_target_length = self.max_target_length 259 | 260 | 261 | @dataclass 262 | class Seq2SeqTrainingArguments(TrainingArguments): 263 | """ 264 | Args: 265 | sortish_sampler (`bool`, *optional*, defaults to `False`): 266 | Whether to use a *sortish sampler* or not. Only possible if the underlying datasets are *Seq2SeqDataset* 267 | for now but will become generally available in the near future. 268 | It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness 269 | for the training set. 270 | predict_with_generate (`bool`, *optional*, defaults to `False`): 271 | Whether to use generate to calculate generative metrics (ROUGE, BLEU). 272 | generation_max_length (`int`, *optional*): 273 | The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the 274 | `max_length` value of the model configuration. 275 | generation_num_beams (`int`, *optional*): 276 | The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the 277 | `num_beams` value of the model configuration. 278 | """ 279 | eval_dataloader_num_workers: int = field( 280 | default=0, 281 | metadata={ 282 | "help": ( 283 | "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded" 284 | " in the main process." 
285 | ) 286 | }, 287 | ) 288 | sortish_sampler: bool = field( 289 | default=False, metadata={"help": "Whether to use SortishSampler or not."} 290 | ) 291 | smart_init: bool = field( 292 | default=False, 293 | metadata={"help": "Whether to use initialize AMR embeddings with their sub-word embeddings."}, 294 | ) 295 | predict_with_generate: bool = field( 296 | default=False, 297 | metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}, 298 | ) 299 | task: str = field( 300 | default="amr2text", 301 | metadata={"help": "The name of the task, (amr2text or text2amr)."}, 302 | ) 303 | generation_max_length: Optional[int] = field( 304 | default=None, 305 | metadata={ 306 | "help": ( 307 | "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default " 308 | "to the `max_length` value of the model configuration." 309 | ) 310 | }, 311 | ) 312 | generation_num_beams: Optional[int] = field( 313 | default=None, 314 | metadata={ 315 | "help": ( 316 | "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default " 317 | "to the `num_beams` value of the model configuration." 318 | ) 319 | }, 320 | ) 321 | early_stopping: Optional[int] = field( 322 | default=5, metadata={"help": "Early stopping patience for training"} 323 | ) 324 | eval_lenpen: Optional[float] = field( 325 | default=1.0, metadata={"help": "lenpen for generation"} 326 | ) 327 | -------------------------------------------------------------------------------- /fine-tune/common/penman_interface.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from penman import load as load_, Graph, Triple 3 | from penman import loads as loads_ 4 | from penman import encode as encode_ 5 | from penman.model import Model 6 | from penman.models.noop import NoOpModel 7 | from penman.models import amr 8 | 9 | op_model = Model() 10 | noop_model = NoOpModel() 11 | amr_model = amr.model 12 | DEFAULT = op_model 13 | 14 | 15 | def _get_model(dereify): 16 | if dereify is None: 17 | return DEFAULT 18 | 19 | elif dereify: 20 | return op_model 21 | 22 | else: 23 | return noop_model 24 | 25 | 26 | def _remove_wiki(graph): 27 | metadata = graph.metadata 28 | triples = [] 29 | for t in graph.triples: 30 | v1, rel, v2 = t 31 | if rel == ":wiki": 32 | t = Triple(v1, rel, "+") 33 | triples.append(t) 34 | graph = Graph(triples) 35 | graph.metadata = metadata 36 | return graph 37 | 38 | 39 | def load(source, dereify=None, remove_wiki=False): 40 | model = _get_model(dereify) 41 | out = load_(source=source, model=model) 42 | if remove_wiki: 43 | for i in range(len(out)): 44 | out[i] = _remove_wiki(out[i]) 45 | return out 46 | 47 | 48 | def loads(string, dereify=None, remove_wiki=False): 49 | model = _get_model(dereify) 50 | out = loads_(string=string, model=model) 51 | if remove_wiki: 52 | for i in range(len(out)): 53 | out[i] = _remove_wiki(out[i]) 54 | return out 55 | 56 | 57 | def encode(g, top=None, indent=-1, compact=False): 58 | model = amr_model 59 | return encode_(g=g, top=top, indent=indent, compact=compact, model=model) 60 | -------------------------------------------------------------------------------- /fine-tune/common/postprocessing.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import re 3 | import enum 4 | import penman 5 | import networkx as nx 6 | from common.penman_interface import encode 7 | from collections import defaultdict, Counter 8 | 9 
| BACKOFF = penman.Graph( 10 | [ 11 | penman.Triple("d2", ":instance", "dog"), 12 | penman.Triple("b1", ":instance", "bark-01"), 13 | penman.Triple("b1", ":ARG0", "d2"), 14 | ] 15 | ) 16 | 17 | 18 | def token_processing(tok): 19 | if tok is None: 20 | return None 21 | elif tok.isdigit(): 22 | try: 23 | return eval(tok) 24 | except: 25 | return tok 26 | elif tok.startswith('"') and (not tok.endswith('"')): 27 | return tok + '"' 28 | elif tok.endswith('"') and (not tok.startswith('"')): 29 | return '"' + tok 30 | else: 31 | return tok 32 | 33 | 34 | def decode_into_node_and_backreferences(subtoken_ids, tokenizer): 35 | rex_arg = re.compile(f"^{tokenizer.INIT}(op|snt|conj|prep)") 36 | rex_spc = re.compile(r"<(s|/s|lit|/lit|stop|unk|pad|mask)>") 37 | 38 | # subtoken_ids.insert(1,36) # add "(" id 39 | # subtoken_ids.insert(-1, 4839) # add ")" id 40 | 41 | # get strings 42 | subtokens = [tokenizer.decoder.get(t) for t in subtoken_ids] 43 | # print("subtokens:", subtokens) 44 | # fix backreferences 45 | 46 | subtoken_backreferences = [max(t - len(tokenizer.encoder), -1) for t in subtoken_ids] 47 | # strip padding 48 | subtokens, subtoken_backreferences = zip( 49 | *[ 50 | (s, b) 51 | for s, b in zip(subtokens, subtoken_backreferences) 52 | if s != ("") 53 | ] 54 | ) 55 | 56 | # subword collapse 57 | tokens = [] 58 | backreferences = [] 59 | subword_to_token_map = {} 60 | current_token_i = 0 61 | for subw_i, (subw_backr, subtok) in enumerate(zip(subtoken_backreferences, subtokens)): 62 | subword_to_token_map[subw_i] = current_token_i 63 | 64 | # if empty you cannot do anything but add a new word 65 | if not tokens: 66 | tokens.append(subtok.lstrip(tokenizer.INIT)) 67 | backreferences.append(-1) 68 | current_token_i += 1 69 | 70 | # backref can't be splitted 71 | elif subw_backr > -1: 72 | tokens.append(None) 73 | backreferences.append(subword_to_token_map[subw_backr]) 74 | current_token_i += 1 75 | 76 | # after a special token release 77 | elif isinstance(tokens[-1], str) and rex_spc.match(tokens[-1]): 78 | tokens.append(subtok.lstrip(tokenizer.INIT)) 79 | backreferences.append(-1) 80 | current_token_i += 1 81 | 82 | # after a subtoken ':' (which should be followed by the rest of the edge) ignore tokenizer.INIT 83 | # TODO: this is an ugly patch due to the fact that BART tokenizer splits after ':' 84 | elif (tokens[-1] == ":") and rex_arg.match(subtok): 85 | tokens[-1] = tokens[-1] + subtok[1:] 86 | 87 | # leading tokenizer.INIT 88 | elif subtok.startswith(tokenizer.INIT): 89 | tokens.append(subtok.lstrip(tokenizer.INIT)) 90 | backreferences.append(-1) 91 | current_token_i += 1 92 | 93 | # very ugly patch for some cases in which tokenizer.INIT is not in the following token to the edge 94 | elif ( 95 | isinstance(tokens[-1], str) 96 | and tokens[-1].startswith(":") 97 | and tokens[-1][-1].isdigit() 98 | and (subtok != "-of") 99 | ): 100 | tokens.append(subtok.lstrip(tokenizer.INIT)) 101 | backreferences.append(-1) 102 | current_token_i += 1 103 | 104 | # in any other case attach to the previous 105 | else: 106 | tokens[-1] = tokens[-1] + subtok 107 | 108 | # strip INIT and fix byte-level 109 | tokens = [ 110 | tokenizer.convert_tokens_to_string(list(t)).lstrip() if isinstance(t, str) else t 111 | for t in tokens 112 | ] 113 | # tokens = [t.replace(tokenizer.INIT, '') if isinstance(t, str) else t for t in tokens] 114 | 115 | # unks are substituted with thing 116 | tokens = [t if t != "" else "thing" for t in tokens] 117 | 118 | old_tokens = tokens 119 | old_backreferences = backreferences 120 | 121 
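# --- Editor's note -----------------------------------------------------------
# Illustrative walk-through, not part of the original file: at this point
# `tokens` holds word-level pieces recovered from the BPE subtokens, and the
# block below merges everything between the special literal markers
# (<lit> ... </lit>, cf. the special tokens matched by rex_spc above) into a
# single quoted constant. Schematically:
#   [..., ':name', 'name', ':op1', '<lit>', 'Barack', 'Obama', '</lit>', ...]
# becomes
#   [..., ':name', 'name', ':op1', '"Barack_Obama"', ...]
# `backreferences` is re-indexed so it stays aligned with the merged `tokens`.
# -----------------------------------------------------------------------------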
| # Barack Obama -> "Barack Obama" 122 | tokens = [] 123 | backreferences = [] 124 | token_to_token_map = {} 125 | start_search = 0 126 | removed = 0 127 | while True: 128 | try: 129 | 130 | lit_start = old_tokens.index("", start_search) 131 | token_addition = old_tokens[start_search:lit_start] 132 | for i, t in enumerate(token_addition, start=start_search): 133 | token_to_token_map[i] = i - removed 134 | tokens += token_addition 135 | 136 | backreferences_addition = [ 137 | token_to_token_map[b] if b > -1 else -1 138 | for b in old_backreferences[start_search:lit_start] 139 | ] 140 | backreferences += backreferences_addition 141 | 142 | lit_end = min(lit_start + 2, len(old_tokens) - 1) 143 | 144 | while lit_end < len(old_tokens): 145 | old_tok = old_tokens[lit_end] 146 | 147 | if isinstance(old_tok, str) and ( 148 | (old_tok.startswith(":") and len(old_tok) > 3) or (old_tok == "") 149 | ): 150 | res_tok = old_tokens[lit_start + 1 : lit_end] 151 | for i in range(lit_start, lit_end): 152 | token_to_token_map[i] = len(tokens) 153 | 154 | # Remove possible wrong None 155 | res = old_tokens[lit_start + 1 : lit_end] 156 | res = [str(r) for r in res if r is not None] 157 | res = '"' + "_".join(res) + '"' 158 | 159 | removed += len(res_tok) 160 | start_search = lit_end 161 | tokens += [res, old_tok] 162 | backreferences += [-1, -1] 163 | break 164 | 165 | elif old_tok == "": 166 | res_tok = old_tokens[lit_start + 1 : lit_end] 167 | for i in range(lit_start, lit_end + 1): 168 | token_to_token_map[i] = len(tokens) 169 | 170 | # Remove possible wrong None 171 | res = old_tokens[lit_start + 1 : lit_end] 172 | res = [str(r) for r in res if r is not None] 173 | res = '"' + "_".join(res) + '"' 174 | 175 | removed += len(res_tok) + 1 176 | start_search = lit_end + 1 177 | tokens.append(res) 178 | backreferences.append(-1) 179 | break 180 | 181 | else: 182 | lit_end += 1 183 | start_search = lit_end 184 | 185 | except ValueError: 186 | token_addition = old_tokens[start_search:] 187 | for i, t in enumerate(token_addition, start=start_search): 188 | token_to_token_map[i] = i - removed 189 | backreferences_addition = [ 190 | token_to_token_map[b] if b > -1 else b for b in old_backreferences[start_search:] 191 | ] 192 | tokens += token_addition 193 | backreferences += backreferences_addition 194 | break 195 | 196 | tokens = [token_processing(t) for t in tokens] 197 | 198 | shift = 1 199 | if tokens[1] == "": 200 | shift = 2 201 | 202 | tokens = tokens[shift:] 203 | backreferences = [b if b == -1 else b - shift for b in backreferences[shift:]] 204 | 205 | if tokens[-1] == "": 206 | tokens.pop() 207 | backreferences.pop() 208 | 209 | return tokens, backreferences 210 | 211 | 212 | def index_of(element, iterable, default=None, start=None, end=None): 213 | if not callable(element): 214 | 215 | def check(x): 216 | return element == x 217 | 218 | else: 219 | check = element 220 | if start is None: 221 | start = 0 222 | if end is None: 223 | end = len(iterable) 224 | item = start 225 | while item < end: 226 | if check(iterable[item]): 227 | return item 228 | item += 1 229 | return default 230 | 231 | 232 | def separate_edges_nodes(edges_nodes_slice, *other): 233 | is_arg = lambda x: isinstance(x, str) and x.startswith(":") 234 | start = 0 235 | edges = [] 236 | nodes = [] 237 | l = len(edges_nodes_slice) 238 | while start < l: 239 | edge_index = index_of(is_arg, edges_nodes_slice, start=start) 240 | if edge_index is None or edge_index == (l - 1): 241 | break 242 | if is_arg(edges_nodes_slice[edge_index + 1]): 243 | 
start = edge_index + 1 244 | continue 245 | edges.append(edge_index) 246 | nodes.append(edge_index + 1) 247 | start = edge_index + 2 248 | ret = [] 249 | for oth in other: 250 | edges_oth = [oth[i] for i in edges] 251 | nodes_oth = [oth[i] for i in nodes] 252 | ret.append((edges_oth, nodes_oth)) 253 | return ret 254 | 255 | 256 | def _split_name_ops(graph): 257 | # identify name triples 258 | name_vars = {} 259 | for i, (v1, rel, v2) in enumerate(graph.triples): 260 | if rel == ":instance" and v2 == "name": 261 | name_vars[v1] = 1 262 | 263 | # check if they have ops 264 | name_vars_to_ops = defaultdict(list) 265 | for i, (v1, rel, v2) in enumerate(graph.triples): 266 | if v1 in name_vars and rel.startswith(":op"): 267 | name_vars_to_ops[v1].append((i, rel, v2.strip('"'))) 268 | 269 | triples = graph.triples.copy() 270 | for nv, ops in name_vars_to_ops.items(): 271 | ops = sorted(ops, key=lambda x: int(x[1][3:])) 272 | idx, _, lits = zip(*ops) 273 | for i in idx: 274 | triples[i] = None 275 | 276 | lits = ['"' + l + '"' for lit in lits for l in lit.split("_")] 277 | 278 | tt = [] 279 | for i, l in enumerate(lits, start=1): 280 | rel = ":op" + str(i) 281 | tt.append(penman.Triple(nv, rel, l)) 282 | 283 | triples[min(idx)] = tt 284 | 285 | triples = [t if isinstance(t, list) else [t] for t in triples if t is not None] 286 | triples = [t for tt in triples for t in tt] 287 | 288 | graph_ = penman.Graph(triples) 289 | graph_.metadata = graph.metadata 290 | return graph_ 291 | 292 | 293 | def _reconstruct_graph_from_nodes(nodes, backreferences): 294 | triples = [] 295 | triples_added = set() 296 | 297 | variable2index = {} 298 | index2variable = {} 299 | start_index = 0 300 | 301 | cnt = defaultdict(Counter) 302 | 303 | while start_index < len(nodes): 304 | stop_index = index_of("", nodes, default=len(nodes) + 1, start=start_index) 305 | old_start_index = start_index 306 | start_index = stop_index + 1 307 | 308 | src_node, src_backr = nodes[old_start_index], backreferences[old_start_index] 309 | 310 | if src_node == "": 311 | continue 312 | 313 | trg_nodes_edges = nodes[old_start_index:stop_index] 314 | trg_nodes_edges_backr = backreferences[old_start_index:stop_index] 315 | trg_nodes_edges_indices = list(range(old_start_index, stop_index)) 316 | 317 | if isinstance(src_node, str): 318 | if src_node in ("", "", ""): 319 | continue 320 | elif ("/" in src_node) or (":" in src_node) or ("(" in src_node) or (")" in src_node): 321 | src_node = "thing" 322 | 323 | if src_node is not None: 324 | src_node = str(src_node) 325 | src_var = src_node[0].lower() 326 | if not src_var not in "abcdefghijklmnopqrstuvwxyz": 327 | src_var = "x" 328 | # src_var = f'{src_var}_{len(variable2index)}' 329 | src_var = f"{src_var}{len(variable2index)}" 330 | src_var_i = old_start_index 331 | variable2index[src_var] = src_var_i 332 | index2variable[src_var_i] = src_var 333 | triple = penman.Triple(src_var, ":instance", src_node) 334 | if triple not in triples_added: 335 | triples.append(triple) 336 | triples_added.add(triple) 337 | else: 338 | if src_backr in index2variable: 339 | src_var = index2variable[src_backr] 340 | # more resilient logic here 341 | (trg_edges, trg_nodes), (_, trg_nodes_backr), (_, trg_nodes_indices) = separate_edges_nodes( 342 | trg_nodes_edges, trg_nodes_edges, trg_nodes_edges_backr, trg_nodes_edges_indices 343 | ) 344 | 345 | for n, e, nb, ni in zip(trg_nodes, trg_edges, trg_nodes_backr, trg_nodes_indices): 346 | 347 | if isinstance(n, str) and n.startswith(":"): 348 | continue 349 | if 
isinstance(n, str) and n.startswith("<") and n.endswith(">"): 350 | continue 351 | if e == ":li": 352 | pass 353 | elif len(e) < 4 or (not e.startswith(":")): 354 | continue 355 | 356 | # same edge more than once 357 | num = cnt[src_var][e] 358 | # num = 0 359 | if num: 360 | 361 | if e.startswith(":op") or e.startswith(":snt"): 362 | continue 363 | # elif e.startswith(':ARG'): 364 | # continue 365 | elif num > 3: 366 | continue 367 | 368 | if n is None: 369 | if nb not in index2variable: 370 | continue 371 | trg_var = index2variable[nb] 372 | trg = trg_var 373 | elif e == ":mode": 374 | trg = n 375 | elif ( 376 | (not isinstance(n, str)) 377 | or re.match(r"^[+-]?\d+\.?\d*$", n) 378 | or (n == "-") 379 | or (n == "+") 380 | ): 381 | trg = str(n) 382 | elif n.startswith('"') and n.endswith('"') and len(n) > 2: 383 | trg = '"' + n.replace('"', "") + '"' 384 | elif ("/" in n) or (":" in n) or ("(" in n) or (")" in n) or ("=" in n): 385 | trg = f'"{n}"' 386 | elif n == '"': 387 | continue 388 | elif ( 389 | (n.startswith('"') and (not n.endswith('"'))) 390 | or (not n.startswith('"') and (n.endswith('"'))) 391 | or ('"' in n) 392 | ): 393 | trg = '"' + n.replace('"', "") + '"' 394 | else: 395 | trg_var = n[0].lower() 396 | if trg_var not in "abcdefghijklmnopqrstuvwxyz": 397 | trg_var = "x" 398 | # trg_var = f'{trg_var}_{len(variable2index)}' 399 | trg_var = f"{trg_var}{len(variable2index)}" 400 | trg_var_i = ni 401 | variable2index[trg_var] = trg_var_i 402 | index2variable[trg_var_i] = trg_var 403 | triple = penman.Triple(trg_var, ":instance", n) 404 | if triple not in triples_added: 405 | triples.append(triple) 406 | triples_added.add(triple) 407 | trg = trg_var 408 | 409 | triple = penman.Triple(src_var, e, trg) 410 | if triple not in triples_added: 411 | triples.append(triple) 412 | triples_added.add(triple) 413 | 414 | cnt[src_var][e] += 1 415 | 416 | return penman.Graph(triples) 417 | 418 | 419 | def build_graph(nodes, backreferences, restore_name_ops=False): 420 | graph = _reconstruct_graph_from_nodes(nodes, backreferences) 421 | if restore_name_ops: 422 | graph = _split_name_ops(graph) 423 | return graph 424 | 425 | 426 | class ParsedStatus(enum.Enum): 427 | OK = 0 428 | FIXED = 1 429 | BACKOFF = 2 430 | 431 | 432 | def connect_graph_if_not_connected(graph): 433 | 434 | try: 435 | encoded = encode(graph) 436 | return graph, ParsedStatus.OK 437 | except: 438 | pass 439 | 440 | nxgraph = nx.MultiGraph() 441 | variables = graph.variables() 442 | for v1, _, v2 in graph.triples: 443 | if v1 in variables and v2 in variables: 444 | nxgraph.add_edge(v1, v2) 445 | elif v1 in variables: 446 | nxgraph.add_edge(v1, v1) 447 | 448 | triples = graph.triples.copy() 449 | new_triples = [] 450 | addition = f"a{len(variables) + 1}" 451 | triples.append(penman.Triple(addition, ":instance", "and")) 452 | for i, conn_set in enumerate(nx.connected_components(nxgraph), start=1): 453 | edge = f":op{i}" 454 | conn_set = sorted(conn_set, key=lambda x: int(x[1:])) 455 | conn_set = [c for c in conn_set if c in variables] 456 | node = conn_set[0] 457 | new_triples.append(penman.Triple(addition, edge, node)) 458 | triples = new_triples + triples 459 | metadata = graph.metadata 460 | graph = penman.Graph(triples) 461 | graph.metadata.update(metadata) 462 | encode(graph) 463 | 464 | return graph, ParsedStatus.FIXED 465 | 466 | 467 | def restore_backreferences_from_pointers(nodes): 468 | new_nodes, new_backreferences = [], [] 469 | prev_pointer = None 470 | pointer2i = {} 471 | for n in nodes: 472 | is_pointer = 
isinstance(n, str) and n.startswith("") 473 | 474 | if not is_pointer: 475 | if prev_pointer is not None: 476 | if prev_pointer in pointer2i: 477 | new_nodes.append(None) 478 | new_backreferences.append(pointer2i[prev_pointer]) 479 | new_nodes.append(n) 480 | new_backreferences.append(-1) 481 | 482 | else: 483 | pointer2i[prev_pointer] = len(new_nodes) 484 | new_nodes.append(n) 485 | new_backreferences.append(-1) 486 | else: 487 | new_nodes.append(n) 488 | new_backreferences.append(-1) 489 | 490 | prev_pointer = None 491 | else: 492 | prev_pointer = n 493 | return new_nodes, new_backreferences 494 | -------------------------------------------------------------------------------- /fine-tune/data_interface/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """AMR dataset.""" 18 | 19 | 20 | from inspect import EndOfBlock 21 | import json 22 | import os 23 | 24 | import datasets 25 | 26 | logger = datasets.logging.get_logger(__name__) 27 | 28 | 29 | _DESCRIPTION = """ 30 | 31 | There are three features: 32 | - src: text. 33 | - tgt: Linearized AMR. 34 | """ 35 | 36 | _SRC = "src" 37 | _TGT = "tgt" 38 | 39 | 40 | class AMRData(datasets.GeneratorBasedBuilder): 41 | """AMR Dataset.""" 42 | 43 | # Version 1.0.0 expands coverage, includes ids, and removes web contents. 
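    # --- Editor's note -------------------------------------------------------
    # Illustrative, not part of the original file: each data file is expected
    # to be JSON-lines, one example per line with a "sent" field (the natural
    # language sentence) and an "amr" field (the linearized AMR graph), e.g.
    # schematically:
    #   {"sent": "The boy wants to go.", "amr": "( <pointer:0> want-01 ... )"}
    # _generate_examples below maps "amr" to the `src` feature and "sent" to
    # the `tgt` feature of this dataset.
    # ---------------------------------------------------------------------------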
44 | VERSION = datasets.Version("1.0.0") 45 | 46 | def _info(self): 47 | return datasets.DatasetInfo( 48 | description=_DESCRIPTION, 49 | features=datasets.Features( 50 | {_SRC: datasets.Value("string"), _TGT: datasets.Value("string"),} 51 | ), 52 | supervised_keys=None, 53 | ) 54 | 55 | def _split_generators(self, dl_manager): 56 | """Returns SplitGenerators.""" 57 | 58 | train_path = self.config.data_files["train"] if "train" in self.config.data_files else None 59 | dev_path = self.config.data_files["validation"] if "validation" in self.config.data_files else None 60 | test_path = self.config.data_files["test"] if "test" in self.config.data_files else None 61 | 62 | train_generator = datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}) 63 | dev_generator = datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dev_path}) 64 | test_generator = datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}) 65 | return [ 66 | train_generator, dev_generator, test_generator 67 | ] 68 | 69 | def _generate_examples(self, filepath): 70 | """Yields examples.""" 71 | if filepath: 72 | logger.info("generating examples from = %s", filepath[0]) 73 | with open(filepath[0], "r", encoding="utf-8") as f: 74 | lines = f.readlines() 75 | for idx, line in enumerate(lines): 76 | json_dict = json.loads(line.strip()) 77 | src = json_dict["amr"] 78 | tgt = json_dict["sent"] 79 | yield idx, {_SRC: src, _TGT: tgt} 80 | -------------------------------------------------------------------------------- /fine-tune/data_interface/dataset.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | from torch.utils.data import Dataset 4 | from datasets import load_dataset 5 | from dataclasses import dataclass 6 | from transformers.file_utils import PaddingStrategy 7 | from transformers.modeling_utils import PreTrainedModel 8 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 9 | from typing import Optional, Union 10 | from common.utils import shift_tokens_right 11 | 12 | 13 | def padding_func(features, padding_side="right", pad_token_id=1, key="label", pad_to_multiple_of=1, max_length=None): 14 | assert key in features[0].keys(), f"{key} not in {features[0].keys()}" 15 | max_label_length = max(len(feature[key]) for feature in features) 16 | if pad_to_multiple_of > 1: 17 | if max_length is not None: 18 | max_label_length = min(max_length, 19 | (max_label_length + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of 20 | ) 21 | else: 22 | max_label_length = (max_label_length + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of 23 | 24 | for feature in features: 25 | remainder = [pad_token_id] * (max_label_length - len(feature[key])) 26 | feature[key] = ( 27 | feature[key] + remainder if padding_side == "right" else remainder + feature[key] 28 | ) 29 | return 30 | 31 | 32 | class AMRParsingDataSet(Dataset): 33 | def __init__( 34 | self, tokenizer, args, model_args 35 | ): 36 | super().__init__() 37 | self.train_file = args.train_file 38 | self.validation_file = args.validation_file 39 | self.test_file = args.test_file 40 | self.src_prefix = args.source_prefix 41 | self.tgt_prefix = args.target_prefix 42 | self.cache_dir = model_args.cache_dir 43 | self.use_speaker_prefix = args.use_speaker_prefix 44 | self.tokenizer = tokenizer 45 | self.unified_input = args.unified_input 46 | 47 | self.max_src_length = 
min(args.max_source_length, self.tokenizer.model_max_length) 48 | self.max_tgt_length = min(args.max_target_length, self.tokenizer.model_max_length) 49 | 50 | data_files = {} 51 | if self.train_file is not None: 52 | data_files["train"] = self.train_file 53 | 54 | if self.validation_file is not None: 55 | data_files["validation"] = self.validation_file 56 | 57 | if self.test_file is not None: 58 | data_files["test"] = self.test_file 59 | 60 | # print("datafiles:", data_files) 61 | print("Dataset cache dir:", self.cache_dir) 62 | # exit() 63 | self.datasets = load_dataset( 64 | f"{os.path.dirname(__file__)}/data.py", 65 | data_files=data_files, 66 | keep_in_memory=False, 67 | ) 68 | column_names = self.datasets["train"].column_names 69 | print("datasets:", self.datasets) 70 | print("colums:", column_names) 71 | 72 | def tokenize_function(self, examples): 73 | amr = examples["src"] # AMR tokens 74 | txt = examples["tgt"] # Text tokens 75 | 76 | amr_ids = [self.tokenizer.tokenize_amr(itm.split())[:self.max_tgt_length-2] + [self.tokenizer.amr_eos_token_id] for itm in amr] 77 | 78 | raw_txt_ids = self.tokenizer( 79 | txt, max_length=self.max_src_length, padding=False, truncation=True 80 | )["input_ids"] 81 | if self.unified_input: 82 | txt_ids = [itm[:self.max_src_length-3] + [self.tokenizer.amr_bos_token_id, self.tokenizer.mask_token_id, self.tokenizer.amr_eos_token_id] for itm in raw_txt_ids] 83 | else: 84 | txt_ids = raw_txt_ids 85 | return { 86 | "input_ids": txt_ids, 87 | "labels": amr_ids 88 | } 89 | 90 | 91 | @dataclass 92 | class DataCollatorForAMRParsing: 93 | """ 94 | Data collator that will dynamically pad the inputs received, as well as the labels. 95 | 96 | Args: 97 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 98 | The tokenizer used for encoding the data. 99 | model (:class:`~transformers.PreTrainedModel`): 100 | The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to 101 | prepare the `decoder_input_ids` 102 | 103 | This is useful when using `label_smoothing` to avoid calculating loss twice. 104 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 105 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 106 | among: 107 | 108 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 109 | sequence is provided). 110 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 111 | maximum acceptable input length for the model if that argument is not provided. 112 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 113 | different lengths). 114 | max_length (:obj:`int`, `optional`): 115 | Maximum length of the returned list and optionally padding length (see above). 116 | pad_to_multiple_of (:obj:`int`, `optional`): 117 | If set will pad the sequence to a multiple of the provided value. 118 | 119 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 120 | 7.5 (Volta). 121 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 122 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 
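    Example (editor's illustrative sketch, not part of the original docstring;
    ``tokenizer``, ``model`` and ``tokenized_dataset`` are assumed to come from the
    tokenizer, model and dataset classes used elsewhere in this repository)::

        from torch.utils.data import DataLoader

        collator = DataCollatorForAMRParsing(tokenizer, model=model, pad_to_multiple_of=8)
        loader = DataLoader(tokenized_dataset["train"], batch_size=4, collate_fn=collator)
        batch = next(iter(loader))   # dict with "input_ids", "labels", "decoder_input_ids"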
123 | """ 124 | 125 | tokenizer: PreTrainedTokenizerBase 126 | model: Optional[PreTrainedModel] = None 127 | padding: Union[bool, str, PaddingStrategy] = True 128 | max_length: Optional[int] = None 129 | pad_to_multiple_of: Optional[int] = None 130 | label_pad_token_id: int = -100 131 | 132 | def __call__(self, features): 133 | 134 | padding_func( 135 | features, 136 | padding_side=self.tokenizer.padding_side, 137 | pad_token_id=self.label_pad_token_id, 138 | key="labels", 139 | pad_to_multiple_of=self.pad_to_multiple_of, 140 | ) 141 | 142 | features = self.tokenizer.pad( 143 | features, 144 | padding=self.padding, 145 | max_length=self.max_length, 146 | pad_to_multiple_of=self.pad_to_multiple_of, 147 | return_tensors="pt", 148 | ) 149 | 150 | # prepare decoder_input_ids 151 | features["decoder_input_ids"] = shift_tokens_right( 152 | features["labels"], 153 | pad_token_id=self.tokenizer.pad_token_id, 154 | decoder_start_token_id=self.tokenizer.amr_bos_token_id, 155 | ) 156 | 157 | return { 158 | "input_ids": features["input_ids"], 159 | "labels": features["labels"], 160 | "decoder_input_ids": features["decoder_input_ids"], 161 | } 162 | 163 | 164 | class AMR2TextDataSet(Dataset): 165 | def __init__( 166 | self, tokenizer, args, model_args 167 | ): 168 | super().__init__() 169 | self.train_file = args.train_file 170 | self.validation_file = args.validation_file 171 | self.test_file = args.test_file 172 | self.src_prefix = args.source_prefix 173 | self.tgt_prefix = args.target_prefix 174 | self.cache_dir = model_args.cache_dir 175 | self.use_speaker_prefix = args.use_speaker_prefix 176 | self.tokenizer = tokenizer 177 | self.unified_input = args.unified_input 178 | 179 | self.max_src_length = min(args.max_source_length, self.tokenizer.model_max_length) 180 | self.max_tgt_length = min(args.max_target_length, self.tokenizer.model_max_length) 181 | 182 | data_files = {} 183 | if self.train_file is not None: 184 | data_files["train"] = self.train_file 185 | 186 | if self.validation_file is not None: 187 | data_files["validation"] = self.validation_file 188 | 189 | if self.test_file is not None: 190 | data_files["test"] = self.test_file 191 | # print("datafiles:", data_files) 192 | print("Dataset cache dir:", self.cache_dir) 193 | # exit() 194 | self.datasets = load_dataset( 195 | f"{os.path.dirname(__file__)}/data.py", 196 | data_files=data_files, 197 | keep_in_memory=False, 198 | ) 199 | column_names = self.datasets["train"].column_names 200 | print("datasets:", self.datasets) 201 | print("colums:", column_names) 202 | 203 | def tokenize_function(self, examples): 204 | src = examples["src"] # AMR tokens 205 | tgt = examples["tgt"] # Text tokens 206 | if not self.unified_input: 207 | src_ids = [[self.tokenizer.amr_bos_token_id] + self.tokenizer.tokenize_amr(itm.split())[:self.max_src_length - 2] + [self.tokenizer.amr_eos_token_id] for itm in src] 208 | else: 209 | # [[mask]xxx] 210 | src_ids = [[self.tokenizer.bos_token_id, self.tokenizer.mask_token_id, self.tokenizer.eos_token_id] + [self.tokenizer.amr_bos_token_id] + self.tokenizer.tokenize_amr(itm.split())[:self.max_src_length -5] + [self.tokenizer.amr_eos_token_id] for itm in src] 211 | 212 | with self.tokenizer.as_target_tokenizer(): 213 | tgt_ids = self.tokenizer( 214 | tgt, max_length=self.max_tgt_length, padding=False, truncation=True 215 | ) 216 | tgt_ids["input_ids"] = [ 217 | label[1:] for label in tgt_ids["input_ids"] 218 | ] 219 | model_inputs = {} 220 | model_inputs["input_ids"] = src_ids 221 | model_inputs["labels"] = 
tgt_ids["input_ids"] 222 | return model_inputs 223 | 224 | 225 | @dataclass 226 | class DataCollatorForAMR2Text: 227 | """ 228 | Data collator that will dynamically pad the inputs received, as well as the labels. 229 | 230 | Args: 231 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 232 | The tokenizer used for encoding the data. 233 | model (:class:`~transformers.PreTrainedModel`): 234 | The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to 235 | prepare the `decoder_input_ids` 236 | 237 | This is useful when using `label_smoothing` to avoid calculating loss twice. 238 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 239 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 240 | among: 241 | 242 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 243 | sequence is provided). 244 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 245 | maximum acceptable input length for the model if that argument is not provided. 246 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 247 | different lengths). 248 | max_length (:obj:`int`, `optional`): 249 | Maximum length of the returned list and optionally padding length (see above). 250 | pad_to_multiple_of (:obj:`int`, `optional`): 251 | If set will pad the sequence to a multiple of the provided value. 252 | 253 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 254 | 7.5 (Volta). 255 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 256 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 257 | """ 258 | 259 | tokenizer: PreTrainedTokenizerBase 260 | model: Optional[PreTrainedModel] = None 261 | padding: Union[bool, str, PaddingStrategy] = True 262 | max_length: Optional[int] = None 263 | pad_to_multiple_of: Optional[int] = None 264 | label_pad_token_id: int = -100 265 | 266 | def __call__(self, features): 267 | 268 | padding_func( 269 | features, 270 | padding_side=self.tokenizer.padding_side, 271 | pad_token_id=self.label_pad_token_id, 272 | key="labels", 273 | pad_to_multiple_of=self.pad_to_multiple_of, 274 | ) 275 | 276 | features = self.tokenizer.pad( 277 | features, 278 | padding=self.padding, 279 | max_length=self.max_length, 280 | pad_to_multiple_of=self.pad_to_multiple_of, 281 | return_tensors="pt", 282 | ) 283 | 284 | # prepare decoder_input_ids 285 | features["decoder_input_ids"] = shift_tokens_right( 286 | features["labels"], 287 | pad_token_id=self.tokenizer.pad_token_id, 288 | decoder_start_token_id=self.tokenizer.eos_token_id, 289 | ) 290 | 291 | return { 292 | "input_ids": features["input_ids"], 293 | "labels": features["labels"], 294 | "decoder_input_ids": features["decoder_input_ids"], 295 | } -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/README.md: -------------------------------------------------------------------------------- 1 | This directory contains a number of useful scripts that are helpful for preprocessing parallel and monolingual corpora. 
They are provided for convenience and may be very useful, but their functionality will often be supplanted by other, more specialized tools. 2 | 3 | Many of these scripts assume that the input is [UTF-8 encoded](http://en.wikipedia.org/wiki/UTF-8). 4 | 5 | ## Paste parallel files together 6 | 7 | This script reads one line at a time from a set of files and concatenates them with a triple pipe separator (`|||`) in the output. This is useful for generating parallel corpora files for training or evaluation: 8 | 9 | ./paste-files.pl file.a file.b file.c [...] 10 | 11 | ## Punctuation Normalization and Tokenization 12 | 13 | This script tokenizes text in any language (well, it does a good job in most languages, and in some it will completely go crazy): 14 | 15 | ./tokenize-anything.sh < input.txt > output.txt 16 | 17 | It also normalizes a lot of unicode symbols and even corrects some common encoding errors. It can be applied to monolingual and parallel corpora directly. 18 | 19 | ## Text lowercasing 20 | 21 | This script also does what it says, provided your input is in UTF8: 22 | 23 | ./lowercase.pl < input.txt > output.txt 24 | 25 | ## Length ratio filtering (for parallel corpora) 26 | 27 | This script computes statistics about sentence length ratios in a parallel corpus and removes sentences that are statistical outliers. This tends to remove extremely poorly aligned sentence pairs or sentence pairs that would otherwise be difficult to align: 28 | 29 | ./filter-length.pl input.src-trg > output.src-trg 30 | 31 | ## Add infrequent self-translations to a parallel corpus 32 | 33 | This script identifies rare words (those that occur less than 2 times in the corpus) which have the same orthographic form in both the source and target language. Several copies of these words are then inserted at the end of the corpus that is written, which improves alignment quality. 34 | 35 | ./add-self-translations.pl input.src-trg > output.src-trg 36 | 37 | 38 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/add-self-translations.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | # ADDS SELF-TRANSLATIONS OF POORLY ATTESTED WORDS TO THE PARALLEL DATA 5 | 6 | my %df; 7 | my %def; 8 | while(<>) { 9 | # print; 10 | chomp; 11 | my ($sf, $se) = split / \|\|\| /; 12 | die "Format error: $_\n" unless defined $sf && defined $se; 13 | my @fs = split /\s+/, $sf; 14 | my @es = split /\s+/, $se; 15 | for my $f (@fs) { 16 | $df{$f}++; 17 | for my $e (@es) { 18 | if ($f eq $e) { $def{$f}++; } 19 | } 20 | } 21 | } 22 | 23 | for my $k (sort keys %def) { 24 | next if $df{$k} > 4; 25 | print "$k ||| $k\n"; 26 | print "$k ||| $k\n"; 27 | print "$k ||| $k\n"; 28 | } 29 | 30 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/add-sos-eos.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | die "Usage: $0 corpus.fr[-en1-en2-...] 
[corpus.al out-corpus.al]\n" unless (scalar @ARGV == 1 || scalar @ARGV == 3); 5 | my $filec = shift @ARGV; 6 | my $filea = shift @ARGV; 7 | my $ofilea = shift @ARGV; 8 | open C, "<$filec" or die "Can't read $filec: $!"; 9 | if ($filea) { 10 | open A, "<$filea" or die "Can't read $filea: $!"; 11 | open OA, ">$ofilea" or die "Can't write $ofilea: $!"; 12 | } 13 | binmode(C, ":utf8"); 14 | binmode(STDOUT, ":utf8"); 15 | print STDERR "Adding and markers to input...\n"; 16 | print STDERR " Reading corpus: $filec\n"; 17 | print STDERR " Writing corpus: STDOUT\n"; 18 | print STDERR "Reading alignments: $filea\n" if $filea; 19 | print STDERR "Writing alignments: $ofilea\n" if $filea; 20 | 21 | my $lines = 0; 22 | while() { 23 | $lines++; 24 | die "ERROR. Input line $filec:$lines should not contain SGML markup" if /; 36 | die "ERROR. Mismatched number of lines between $filec and $filea\n" unless $aa; 37 | chomp $aa; 38 | my ($ff, $ee) = @fields; 39 | die "ERROR in $filec:$lines: expected 'source ||| target'" unless defined $ee; 40 | my @fs = split /\s+/, $ff; 41 | my @es = split /\s+/, $ee; 42 | my @as = split /\s+/, $aa; 43 | my @oas = (); 44 | push @oas, '0-0'; 45 | my $flen = scalar @fs; 46 | my $elen = scalar @es; 47 | for my $ap (@as) { 48 | my ($a, $b) = split /-/, $ap; 49 | die "ERROR. Bad format in: @as" unless defined $a && defined $b; 50 | push @oas, ($a + 1) . '-' . ($b + 1); 51 | } 52 | push @oas, ($flen + 1) . '-' . ($elen + 1); 53 | print OA "@oas\n"; 54 | } 55 | print "$o\n"; 56 | } 57 | if ($filea) { 58 | close OA; 59 | my $aa = ; 60 | die "ERROR. Alignment input file $filea contains more lines than corpus file!\n" if $aa; 61 | } 62 | print STDERR "\nSUCCESS. Processed $lines lines.\n"; 63 | 64 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/conll2cdec.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | die "Usage: $0 file.conll\n\n Converts a CoNLL formatted labeled sequence into cdec's format.\n\n" unless scalar @ARGV == 1; 5 | open F, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!\n"; 6 | 7 | my @xx; 8 | my @yy; 9 | my @os; 10 | my $sec = undef; 11 | my $i = 0; 12 | while() { 13 | chomp; 14 | if (/^\s*$/) { 15 | print "[$j]; 21 | $sym =~ s/"/'/g; 22 | push @oo, $sym; 23 | } 24 | my $zz = $j + 1; 25 | print " feat$zz=\"@oo\""; 26 | } 27 | 28 | print "> @xx ||| @yy \n"; 29 | @xx = (); 30 | @yy = (); 31 | @os = (); 32 | } else { 33 | my ($x, @fs) = split /\s+/; 34 | my $y = pop @fs; 35 | if (!defined $sec) { $sec = scalar @fs; } 36 | die unless $sec == scalar @fs; 37 | push @xx, $x; 38 | push @yy, $y; 39 | push @os, \@fs; 40 | } 41 | } 42 | 43 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/corpus-stats.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | my $f = <>; 5 | my $IS_PARALLEL = ($f =~ / \|\|\| /); 6 | if ($IS_PARALLEL) { 7 | die "This script is only valid for monolingual corpora, but file contains |||\n"; 8 | } 9 | 10 | my %d; 11 | my $tc = 0; 12 | my $lc = 0; 13 | while($f) { 14 | $lc++; 15 | chomp $f; 16 | my @toks = split /\s+/, $f; 17 | for my $t (@toks) { 18 | $d{$t}++; 19 | $tc++; 20 | } 21 | $f=<>; 22 | } 23 | 24 | my $types = scalar keys %d; 25 | my $ttr = $tc / $types; 26 | my @mfts; 27 | for my $k (sort {$d{$b} <=> $d{$a}} keys %d) { 28 | push @mfts, $k; 29 | last if scalar 
@mfts > 24; 30 | } 31 | my $sing = 0; 32 | for my $k (keys %d) { 33 | if ($d{$k} == 1) { $sing++; } 34 | } 35 | my $stypes = sqrt($types); 36 | 37 | print < 0; 4 | 5 | my $x = shift @ARGV; 6 | my @ind = split /,/, $x; 7 | my @o = (); 8 | for my $ff (@ind) { 9 | if ($ff =~ /^\d+$/) { 10 | push @o, $ff - 1; 11 | } elsif ($ff =~ /^(\d+)-(\d+)$/) { 12 | my $a = $1; 13 | my $b = $2; 14 | die "$a-$b is a bad range in input: $x\n" unless $b > $a; 15 | for (my $i=$a; $i <= $b; $i++) { 16 | push @o, $i - 1; 17 | } 18 | } else { 19 | die "Bad input: $x\n"; 20 | } 21 | } 22 | 23 | while(<>) { 24 | chomp; 25 | my @fields = split /\s*\|\|\|\s*/; 26 | my @sf; 27 | for my $i (@o) { 28 | my $y = $fields[$i]; 29 | if (!defined $y) { $y= ''; } 30 | push @sf, $y; 31 | } 32 | print join(' ||| ', @sf) . "\n"; 33 | } 34 | 35 | 36 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/filter-length.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | use utf8; 4 | 5 | ##### EDIT THESE SETTINGS #################################################### 6 | my $AUTOMATIC_INCLUDE_IF_SHORTER_THAN = 7; # if both are shorter, include 7 | my $MAX_ZSCORE = 2.0; # how far from the mean can the (log)ratio be? 8 | ############################################################################## 9 | 10 | die "Usage: $0 [-NNN] corpus.fr-en\n\n Filter sentence pairs containing sentences longer than NNN words (where NNN\n is 150 by default) or whose log length ratios are $MAX_ZSCORE stddevs away from the\n mean log ratio.\n\n" unless scalar @ARGV == 1 || scalar @ARGV == 2; 11 | binmode(STDOUT,":utf8"); 12 | binmode(STDERR,":utf8"); 13 | 14 | my $MAX_LENGTH = 150; # discard a sentence if it is longer than this 15 | if (scalar @ARGV == 2) { 16 | my $fp = shift @ARGV; 17 | die "Expected -NNN for first parameter, but got $fp\n" unless $fp =~ /^-(\d+)$/; 18 | $MAX_LENGTH=$1; 19 | } 20 | 21 | my $corpus = shift @ARGV; 22 | 23 | die "Cannot read from STDIN\n" if $corpus eq '-'; 24 | my $ff = "<$corpus"; 25 | $ff = "gunzip -c $corpus|" if $ff =~ /\.gz$/; 26 | 27 | print STDERR "Max line length (monolingual): $MAX_LENGTH\n"; 28 | print STDERR " Parallel corpus: $corpus\n"; 29 | 30 | open F,$ff or die "Can't read $corpus: $!"; 31 | binmode(F,":utf8"); 32 | 33 | my $rat_max = log(9); 34 | my $lrm = 0; 35 | my $zerof = 0; 36 | my $zeroe = 0; 37 | my $bad_format = 0; 38 | my $absbadrat = 0; 39 | my $overlene = 0; 40 | my $overlenf = 0; 41 | my $lines = 0; 42 | my @lograts = (); 43 | while() { 44 | $lines++; 45 | if ($lines % 100000 == 0) { print STDERR " [$lines]\n"; } 46 | elsif ($lines % 2500 == 0) { print STDERR "."; } 47 | my ($sf, $se, @d) = split /\s*\|\|\|\s*| *\t */; 48 | if (scalar @d != 0 or !defined $se) { 49 | $bad_format++; 50 | if ($bad_format > 100 && ($bad_format / $lines) > 0.02) { 51 | die "$bad_format / $lines : Corpus appears to be incorretly formatted, example: $_"; 52 | } 53 | next; 54 | } 55 | my @fs = (); 56 | my @es = (); 57 | if (defined $sf && length($sf) > 0) { @fs = split /\s+/, $sf; } 58 | if (defined $se && length($se) > 0) { @es = split /\s+/, $se; } 59 | my $flen = scalar @fs; 60 | my $elen = scalar @es; 61 | if ($flen == 0) { 62 | $zerof++; 63 | next; 64 | } 65 | if ($elen == 0) { 66 | $zeroe++; 67 | next; 68 | } 69 | if ($flen > $MAX_LENGTH) { 70 | $overlenf++; 71 | next; 72 | } 73 | if ($elen > $MAX_LENGTH) { 74 | $overlene++; 75 | next; 76 | } 77 | if ($elen >= 
$AUTOMATIC_INCLUDE_IF_SHORTER_THAN || 78 | $flen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN) { 79 | my $lograt = log($flen) - log($elen); 80 | if (abs($lograt) > $rat_max) { 81 | $absbadrat++; 82 | next; 83 | } 84 | $lrm += $lograt; 85 | push @lograts, $lograt; 86 | } 87 | } 88 | close F; 89 | 90 | print STDERR "\nComputing statistics...\n"; 91 | my $lmean = $lrm / scalar @lograts; 92 | 93 | my $lsd = 0; 94 | for my $lr (@lograts) { 95 | $lsd += ($lr - $lmean)**2; 96 | } 97 | $lsd = sqrt($lsd / scalar @lograts); 98 | @lograts = (); 99 | 100 | my $pass1_discard = $zerof + $zeroe + $absbadrat + $overlene + $overlenf + $bad_format; 101 | my $discard_rate = int(10000 * $pass1_discard / $lines) / 100; 102 | print STDERR " Total lines: $lines\n"; 103 | print STDERR " Already discared: $pass1_discard\t(discard rate = $discard_rate%)\n"; 104 | print STDERR " Mean F:E ratio: " . exp($lmean) . "\n"; 105 | print STDERR " StdDev F:E ratio: " . exp($lsd) . "\n"; 106 | print STDERR "Writing...\n"; 107 | open F,$ff or die "Can't reread $corpus: $!"; 108 | binmode(F,":utf8"); 109 | my $to = 0; 110 | my $zviol = 0; 111 | my $worstz = -1; 112 | my $worst = "\n"; 113 | $lines = 0; 114 | while() { 115 | $lines++; 116 | if ($lines % 100000 == 0) { print STDERR " [$lines]\n"; } 117 | elsif ($lines % 2500 == 0) { print STDERR "."; } 118 | my ($sf, $se, @d) = split /\s*\|\|\|\s*| *\t */; 119 | if (!defined $se) { next; } 120 | my @fs = split /\s+/, $sf; 121 | my @es = split /\s+/, $se; 122 | my $flen = scalar @fs; 123 | my $elen = scalar @es; 124 | next if ($flen == 0); 125 | next if ($elen == 0); 126 | next if ($flen > $MAX_LENGTH); 127 | next if ($elen > $MAX_LENGTH); 128 | if ($elen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN || 129 | $flen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN) { 130 | my $lograt = log($flen) - log($elen); 131 | if (abs($lograt) > $rat_max) { 132 | $absbadrat++; 133 | next; 134 | } 135 | my $zscore = abs($lograt - $lmean) / $lsd; 136 | if ($elen > $AUTOMATIC_INCLUDE_IF_SHORTER_THAN && 137 | $flen > $AUTOMATIC_INCLUDE_IF_SHORTER_THAN && $zscore > $worstz) { $worstz = $zscore; $worst = $_; } 138 | if ($zscore > $MAX_ZSCORE) { 139 | $zviol++; 140 | next; 141 | } 142 | print; 143 | } else { 144 | print; 145 | } 146 | $to++; 147 | } 148 | my $discard_rate2 = int(10000 * $zviol / ($lines - $pass1_discard)) / 100; 149 | print STDERR "\n Lines printed: $to\n Ratio violations: $zviol\t(discard rate = $discard_rate2%)\n"; 150 | print STDERR " Worst z-score: $worstz\n sentence: $worst"; 151 | exit 0; 152 | 153 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/lowercase.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | binmode(STDIN,":utf8"); 4 | binmode(STDOUT,":utf8"); 5 | while() { 6 | $_ = lc $_; 7 | print; 8 | } 9 | 10 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/moses-scfg-to-cdec.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | while(<>) { 5 | my ($src, $trg, $feats, $al) = split / \|\|\| /; 6 | # [X][NP] von [X][NP] [X] ||| [X][NP] 's [X][NP] [S] ||| 0.00110169 0.0073223 2.84566e-06 0.0027702 0.0121867 2.718 0.606531 ||| 0-0 1-1 2-2 ||| 635 245838 2 7 | 8 | my @srcs = split /\s+/, $src; 9 | my @trgs = split /\s+/, $trg; 10 | my $lhs = pop @trgs; 11 | $lhs =~ s/&apos;/'/g; 12 | $lhs =~ s/'/'/g; 13 | $lhs =~ 
s/,/COMMA/g; 14 | my $ntc = 0; 15 | my $sc = 0; 16 | my @of = (); 17 | my $x = pop @srcs; 18 | my %d = (); # src index to nonterminal count 19 | die "Expected [X]" unless $x eq '[X]'; 20 | my %amap = (); 21 | my @als = split / /, $al; 22 | for my $st (@als) { 23 | my ($s, $t) = split /-/, $st; 24 | $amap{$t} = $s; 25 | } 26 | for my $f (@srcs) { 27 | if ($f =~ /^\[X\]\[([^]]+)\]$/) { 28 | $ntc++; 29 | my $nt = $1; 30 | $nt =~ s/&apos;/'/g; 31 | $nt =~ s/'/'/g; 32 | $nt =~ s/,/COMMA/g; 33 | push @of, "[$nt]"; 34 | $d{$sc} = $ntc; 35 | } elsif ($f =~ /^\[[^]]+\]$/) { 36 | die "Unexpected $f"; 37 | } else { 38 | push @of, $f; 39 | } 40 | $sc++; 41 | } 42 | my @oe = (); 43 | my $ind = 0; 44 | for my $e (@trgs) { 45 | if ($e =~ /^\[X\]\[([^]]+)\]$/) { 46 | my $imap = $d{$amap{$ind}}; 47 | push @oe, "[$imap]"; 48 | } else { 49 | push @oe, $e; 50 | } 51 | $ind++; 52 | } 53 | my ($fe, $ef, $j, $lfe, $lef, $dummy, $of) = split / /, $feats; 54 | next if $lef eq '0'; 55 | next if $lfe eq '0'; 56 | next if $ef eq '0'; 57 | next if $fe eq '0'; 58 | next if $j eq '0'; 59 | next if $of eq '0'; 60 | $ef = sprintf('%.6g', log($ef)); 61 | $fe = sprintf('%.6g',log($fe)); 62 | $j = sprintf('%.6g',log($j)); 63 | $lef = sprintf('%.6g',log($lef)); 64 | $lfe = sprintf('%.6g',log($lfe)); 65 | $of = sprintf('%.6g',log($of)); 66 | print "$lhs ||| @of ||| @oe ||| RuleCount=1 FgivenE=$fe EgivenF=$ef Joint=$j LexEgivenF=$lef LexFgivenE=$lfe Other=$of\n"; 67 | } 68 | 69 | # [X][ADVP] angestiegen [X] ||| rose [X][ADVP] [VP] ||| 0.0538131 0.0097508 0.00744224 0.0249653 0.000698602 2.718 0.606531 ||| 0-1 1-0 ||| 13 94 2 70 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/moses-xml.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | use strict; 4 | $|++; 5 | 6 | my $msg = "Usage: $0 (escape|unescape)\n\n Escapes XMl entities and other special characters for use with Moses.\n\n"; 7 | 8 | die $msg unless scalar @ARGV == 1; 9 | 10 | if ($ARGV[0] eq "escape") { 11 | while () { 12 | $_ =~ s/\&/\&/g; # escape escape 13 | $_ =~ s/\|/\|/g; # factor separator 14 | $_ =~ s/\/\>/g; # xml 16 | $_ =~ s/\'/\'/g; # xml 17 | $_ =~ s/\"/\"/g; # xml 18 | $_ =~ s/\[/\[/g; # syntax non-terminal 19 | $_ =~ s/\]/\]/g; # syntax non-terminal 20 | print; 21 | } 22 | } elsif ($ARGV[0] eq "unescape") { 23 | while () { 24 | $_ =~ s/\|/\|/g; # factor separator 25 | $_ =~ s/\</\/g; # xml 27 | $_ =~ s/\'/\'/g; # xml 28 | $_ =~ s/\"/\"/g; # xml 29 | $_ =~ s/\[/\[/g; # syntax non-terminal 30 | $_ =~ s/\]/\]/g; # syntax non-terminal 31 | $_ =~ s/\&/\&/g; # escape escape 32 | print; 33 | } 34 | } else { 35 | die $msg; 36 | } 37 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/paste-files.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | die "Usage: $0 file1.txt file2.txt [file3.txt ...]\n\n Performs a per-line concatenation of all files using the ||| seperator.\n\n" unless scalar @ARGV > 1; 5 | 6 | my @fhs = (); 7 | for my $file (@ARGV) { 8 | my $fh; 9 | if ($file =~ /\.gz$/) { 10 | open $fh, "gunzip -c $file|" or die "Can't fork gunzip -c $file: $!"; 11 | } else { 12 | open $fh, "<$file" or die "Can't read $file: $!"; 13 | } 14 | binmode($fh,":utf8"); 15 | push @fhs, $fh; 16 | } 17 | binmode(STDOUT,":utf8"); 18 | binmode(STDERR,":utf8"); 19 | 20 | my $bad = 0; 21 | my $lc = 0; 22 
| my $done = 0; 23 | my $fl = 0; 24 | while(1) { 25 | my @line; 26 | $lc++; 27 | if ($lc % 100000 == 0) { print STDERR " [$lc]\n"; $fl = 0; } 28 | elsif ($lc % 2500 == 0) { print STDERR "."; $fl = 1; } 29 | my $anum = 0; 30 | for my $fh (@fhs) { 31 | my $r = <$fh>; 32 | if (!defined $r) { 33 | die "Mismatched number of lines.\n" if scalar @line > 0; 34 | $done = 1; 35 | last; 36 | } 37 | $r =~ s/\r//g; 38 | chomp $r; 39 | if ($r =~ /\|\|\|/) { 40 | $r = ''; 41 | $bad++; 42 | } 43 | warn "$ARGV[$anum]:$lc contains a ||| symbol - please remove.\n" if $r =~ /\|\|\|/; 44 | $r =~ s/\|\|\|/ /g; 45 | $r =~ s/\s+/ /g; 46 | $r =~ s/^ +//; 47 | $r =~ s/ +$//; 48 | $anum++; 49 | push @line, $r; 50 | } 51 | last if $done; 52 | print STDOUT join(' ||| ', @line) . "\n"; 53 | } 54 | print STDERR "\n" if $fl; 55 | for (my $i = 1; $i < scalar @fhs; $i++) { 56 | my $fh = $fhs[$i]; 57 | my $r = <$fh>; 58 | die "Mismatched number of lines.\n" if defined $r; 59 | } 60 | print STDERR "Number of lines containing ||| was: $bad\n" if $bad > 0; 61 | 62 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/sample-dev-sets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gzip 4 | import os 5 | import sys 6 | 7 | HELP = '''Process an input corpus by dividing it into pseudo-documents and uniformly 8 | sampling train and dev sets (simulate uniform sampling at the document level 9 | when document boundaries are unknown) 10 | 11 | usage: {} in_file out_prefix doc_size docs_per_dev_set dev_sets [-lc] 12 | recommended: doc_size=20, docs_per_dev_set=100, dev_sets=2 (dev and test) 13 | ''' 14 | 15 | def gzopen(f): 16 | return gzip.open(f, 'rb') if f.endswith('.gz') else open(f, 'r') 17 | 18 | def wc(f): 19 | return sum(1 for _ in gzopen(f)) 20 | 21 | def main(argv): 22 | 23 | if len(argv[1:]) < 5: 24 | sys.stderr.write(HELP.format(os.path.basename(argv[0]))) 25 | sys.exit(2) 26 | 27 | # Args 28 | in_file = os.path.abspath(argv[1]) 29 | out_prefix = os.path.abspath(argv[2]) 30 | doc_size = int(argv[3]) 31 | docs_per_dev_set = int(argv[4]) 32 | dev_sets = int(argv[5]) 33 | lc = (len(argv[1:]) == 6 and argv[6] == '-lc') 34 | 35 | # Compute sizes 36 | corpus_size = wc(in_file) 37 | total_docs = corpus_size / doc_size 38 | leftover = corpus_size % doc_size 39 | train_docs = total_docs - (dev_sets * docs_per_dev_set) 40 | train_batch_size = (train_docs / docs_per_dev_set) 41 | 42 | # Report 43 | sys.stderr.write('Splitting {} lines ({} documents)\n'.format(corpus_size, total_docs + (1 if leftover else 0))) 44 | sys.stderr.write('Train: {} ({})\n'.format((train_docs * doc_size) + leftover, train_docs + (1 if leftover else 0))) 45 | sys.stderr.write('Dev: {} x {} ({})\n'.format(dev_sets, docs_per_dev_set * doc_size, docs_per_dev_set)) 46 | 47 | inp = gzopen(in_file) 48 | train_out = open('{}.train'.format(out_prefix), 'w') 49 | dev_out = [open('{}.dev.{}'.format(out_prefix, i + 1), 'w') for i in range(dev_sets)] 50 | i = 0 51 | 52 | # For each set of documents 53 | for _ in range(docs_per_dev_set): 54 | # Write several documents to train 55 | for _ in range(train_batch_size): 56 | for _ in range(doc_size): 57 | i += 1 58 | train_out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline()) 59 | # Write a document to each dev 60 | for out in dev_out: 61 | for _ in range(doc_size): 62 | i += 1 63 | out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline()) 64 | # Write 
leftover lines to train 65 | for line in inp: 66 | i += 1 67 | train_out.write('{} ||| {}'.format(i, line) if lc else line) 68 | 69 | train_out.close() 70 | for out in dev_out: 71 | out.close() 72 | 73 | if __name__ == '__main__': 74 | main(sys.argv) 75 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/README: -------------------------------------------------------------------------------- 1 | Run ./tokenize.sh to tokenize text 2 | Edit eng_token_patterns and eng_token_list to add rules for things not to segment 3 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/fix-contract.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | $|++; 3 | 4 | use strict; 5 | while(<>) { 6 | #s/ (pre|anti|re|pro|inter|intra|multi|e|x|neo) - / $1- /ig; 7 | #s/ - (year) - (old)/ -$1-$2/ig; 8 | s/ ' (s|m|ll|re|d|ve) / '$1 /ig; 9 | s/n ' t / n't /ig; 10 | s/( |^)(\d+)-(\d+)( |$)/$1$2 - $3$4/g; 11 | s/ "$/ ”/; 12 | print; 13 | } 14 | 15 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/fix-eos.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | $|++; 3 | 4 | use strict; 5 | use utf8; 6 | 7 | binmode(STDIN, ":utf8"); 8 | binmode(STDOUT, ":utf8"); 9 | while() { 10 | s/(\p{Devanagari}{2}[A-Za-z0-9! ,.\@\p{Devanagari}]+?)\s+(\.)(\s*$|\s+\|\|\|)/$1 \x{0964}$3/s; 11 | print; 12 | } 13 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/quote-norm.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | $|++; 3 | use strict; 4 | use utf8; 5 | binmode(STDIN,"utf8"); 6 | binmode(STDOUT,"utf8"); 7 | while() { 8 | chomp; 9 | $_ = " $_ "; 10 | 11 | # Delete control characters: 12 | s/[\x{00}-\x{1f}]//g; 13 | 14 | # PTB --> normal 15 | s/-LRB-/(/g; 16 | s/-RRB-/)/g; 17 | s/-LSB-/[/g; 18 | s/-RSB-/]/g; 19 | s/-LCB-/{/g; 20 | s/-RCB-/}/g; 21 | s/ gon na / gonna /g; 22 | 23 | # Regularize named HTML/XML escapes: 24 | s/&\s*lt\s*;//gi; # HTML closing angle bracket 26 | s/&\s*squot\s*;/'/gi; # HTML single quote 27 | s/&\s*quot\s*;/"/gi; # HTML double quote 28 | s/&\s*nbsp\s*;/ /gi; # HTML non-breaking space 29 | s/'/\'/g; # HTML apostrophe 30 | s/&\s*amp\s*;/&/gi; # HTML ampersand (last) 31 | 32 | # Regularize known HTML numeric codes: 33 | s/&\s*#\s*160\s*;/ /gi; # no-break space 34 | s/&\s*#45\s*;\s*&\s*#45\s*;/--/g; # hyphen-minus hyphen-minus 35 | s/&\s*#45\s*;/--/g; # hyphen-minus 36 | 37 | # Convert arbitrary hex or decimal HTML entities to actual characters: 38 | s/&\#x([0-9A-Fa-f]+);/pack("U", hex($1))/ge; 39 | s/&\#([0-9]+);/pack("U", $1)/ge; 40 | 41 | # Regularlize spaces: 42 | s/\x{ad}//g; # soft hyphen 43 | s/\x{200C}//g; # zero-width non-joiner 44 | s/\x{a0}/ /g; # non-breaking space 45 | s/\x{2009}/ /g; # thin space 46 | s/\x{2028}/ /g; # "line separator" 47 | s/\x{2029}/ /g; # "paragraph separator" 48 | s/\x{202a}/ /g; # "left-to-right embedding" 49 | s/\x{202b}/ /g; # "right-to-left embedding" 50 | s/\x{202c}/ /g; # "pop directional formatting" 51 | s/\x{202d}/ /g; # "left-to-right override" 52 | s/\x{202e}/ /g; # "right-to-left override" 53 | s/\x{85}/ /g; # "next line" 54 | s/\x{fffd}/ /g; # "replacement character" 55 | s/\x{feff}/ /g; # byte-order 
mark 56 | s/\x{fdd3}/ /g; # "unicode non-character" 57 | 58 | # Convert other Windows 1252 characters to UTF-8 59 | s/\x{80}/\x{20ac}/g; # euro sign 60 | s/\x{95}/\x{2022}/g; # bullet 61 | s/\x{99}/\x{2122}/g; # trademark sign 62 | 63 | # Currency and measure conversions: 64 | s/ (\d\d): (\d\d)/ $1:$2/g; 65 | s/[\x{20a0}]\x{20ac}]/ EUR /g; 66 | s/[\x{00A3}]/ GBP /g; 67 | s/(\W)([A-Z]+\$?)(\d*\.\d+|\d+)/$1$2 $3/g; 68 | s/(\W)(euro?)(\d*\.\d+|\d+)/$1EUR $3/gi; 69 | 70 | # Ridiculous double conversions, UTF8 -> Windows 1252 -> UTF8: 71 | s/�c/--/g; # long dash 72 | s/\x{e2}\x{20ac}oe/\"/g; # opening double quote 73 | s/\x{e2}\x{20ac}\x{9c}/\"/g; # opening double quote 74 | s/\x{e2}\x{20ac}\x{9d}/\"/g; # closing double quote 75 | s/\x{e2}\x{20ac}\x{2122}/\'/g; # apostrophe 76 | s/\x{e2}\x{20ac}\x{201c}/ -- /g; # en dash? 77 | s/\x{e2}\x{20ac}\x{201d}/ -- /g; # em dash? 78 | s/â(\x{80}\x{99}|\x{80}\x{98})/'/g; # single quote? 79 | s/â(\x{80}\x{9c}|\x{80}\x{9d})/"/g; # double quote? 80 | s/\x{c3}\x{9f}/\x{df}/g; # esset 81 | s/\x{c3}\x{0178}/\x{df}/g; # esset 82 | s/\x{c3}\x{a4}/\x{e4}/g; # a umlaut 83 | s/\x{c3}\x{b6}/\x{f6}/g; # o umlaut 84 | s/\x{c3}\x{bc}/\x{fc}/g; # u umlaut 85 | s/\x{c3}\x{84}/\x{c4}/g; # A umlaut: create no C4s after this 86 | s/\x{c3}\x{201e}/\x{c4}/g; # A umlaut: create no C4s after this 87 | s/\x{c3}\x{96}/\x{d6}/g; # O umlaut 88 | s/\x{c3}\x{2013}/\x{d6}/g; # O umlaut 89 | s/\x{c3}\x{bc}/\x{dc}/g; # U umlaut 90 | s/\x{80}/\x{20ac}/g; # euro sign 91 | s/\x{95}/\x{2022}/g; # bullet 92 | s/\x{99}/\x{2122}/g; # trademark sign 93 | 94 | # Regularize quotes: 95 | s/ˇ/'/g; # caron 96 | s/´/'/g; # acute accent 97 | s/`/'/g; # grave accent 98 | s/ˉ/'/g; # modified letter macron 99 | s/ ,,/ "/g; # ghetto low-99 quote 100 | s/``/"/g; # latex-style left quote 101 | s/''/"/g; # latex-style right quote 102 | s/\x{300c}/"/g; # left corner bracket 103 | s/\x{300d}/"/g; # right corner bracket 104 | s/\x{3003}/"/g; # ditto mark 105 | s/\x{00a8}/"/g; # diaeresis 106 | s/\x{92}/\'/g; # curly apostrophe 107 | s/\x{2019}/\'/g; # curly apostrophe 108 | s/\x{f03d}/\'/g; # curly apostrophe 109 | s/\x{b4}/\'/g; # curly apostrophe 110 | s/\x{2018}/\'/g; # curly single open quote 111 | s/\x{201a}/\'/g; # low-9 quote 112 | s/\x{93}/\"/g; # curly left quote 113 | s/\x{201c}/\"/g; # curly left quote 114 | s/\x{94}/\"/g; # curly right quote 115 | s/\x{201d}/\"/g; # curly right quote 116 | s/\x{2033}/\"/g; # curly right quote 117 | s/\x{201e}/\"/g; # low-99 quote 118 | s/\x{84}/\"/g; # low-99 quote (bad enc) 119 | s/\x{201f}/\"/g; # high-rev-99 quote 120 | s/\x{ab}/\"/g; # opening guillemet 121 | s/\x{bb}/\"/g; # closing guillemet 122 | s/\x{0301}/'/g; # combining acute accent 123 | s/\x{203a}/\"/g; # angle quotation mark 124 | s/\x{2039}/\"/g; # angle quotation mark 125 | 126 | # Space inverted punctuation: 127 | s/¡/ ¡ /g; 128 | s/¿/ ¿ /g; 129 | 130 | # Russian abbreviations: 131 | s/ п. п. / п.п. /g; 132 | s/ ст. л. / ст.л. /g; 133 | s/ т. е. / т.е. /g; 134 | s/ т. к. / т.к. /g; 135 | s/ т. ч. / т.ч. /g; 136 | s/ т. д. / т.д. /g; 137 | s/ т. п. / т.п. /g; 138 | s/ и. о. / и.о. /g; 139 | s/ с. г. / с.г. /g; 140 | s/ г. р. / г.р. /g; 141 | s/ т. н. / т.н. /g; 142 | s/ т. ч. / т.ч. /g; 143 | s/ н. э. / н.э. 
/g; 144 | 145 | # Convert foreign numerals into Arabic numerals 146 | tr/०-९/0-9/; # devangari 147 | tr/౦-౯/0-9/; # telugu 148 | tr/೦-೯/0-9/; # kannada 149 | #tr/೦-௯/0-9/; # tamil 150 | tr/൦-൯/0-9/; # malayalam 151 | 152 | # Random punctuation: 153 | tr/!-~/!-~/; 154 | s/、/,/g; 155 | # s/。/./g; 156 | s/\x{85}/.../g; 157 | s/…/.../g; 158 | s/―/--/g; 159 | s/–/--/g; 160 | s/─/--/g; 161 | s/—/--/g; 162 | s/\x{97}/--/g; 163 | s/•/ * /g; 164 | s/\*/ * /g; 165 | s/،/,/g; 166 | s/؟/?/g; 167 | s/ـ/ /g; 168 | s/à ̄/i/g; 169 | s/’/'/g; 170 | s/â€"/"/g; 171 | s/؛/;/g; 172 | 173 | # Regularize ligatures: 174 | s/\x{9c}/oe/g; # "oe" ligature 175 | s/\x{0153}/oe/g; # "oe" ligature 176 | s/\x{8c}/Oe/g; # "OE" ligature 177 | s/\x{0152}/Oe/g; # "OE" ligature 178 | s/\x{fb00}/ff/g; # "ff" ligature 179 | s/\x{fb01}/fi/g; # "fi" ligature 180 | s/\x{fb02}/fl/g; # "fl" ligature 181 | s/\x{fb03}/ffi/g; # "ffi" ligature 182 | s/\x{fb04}/ffi/g; # "ffl" ligature 183 | 184 | s/β/ß/g; # WMT 2010 error 185 | 186 | # Strip extra spaces: 187 | s/\s+/ /g; 188 | s/^\s+//; 189 | s/\s+$//; 190 | 191 | print "$_\n"; 192 | } 193 | 194 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/token_list: -------------------------------------------------------------------------------- 1 | #words 2 | 3 | vis-à-vis 4 | vis-a-vis 5 | so-called 6 | anti- 7 | 8 | # Finnish 9 | eaa. 10 | ap. 11 | arv. 12 | ay. 13 | eKr. 14 | em. 15 | engl. 16 | esim. 17 | fil. 18 | lis. 19 | fil. 20 | maist. 21 | fil.toht. 22 | harv. 23 | ilt. 24 | jatk. 25 | jKr. 26 | jms. 27 | jne. 28 | joht. 29 | klo 30 | ko. 31 | ks. 32 | leht. 33 | lv. 34 | lyh. 35 | mm. 36 | mon. 37 | nim. 38 | nro. 39 | ns. 40 | nti. 41 | os. 42 | oy. 43 | pj. 44 | pnä. 45 | puh. 46 | pvm. 47 | rva. 48 | tms. 49 | ts. 50 | vars. 51 | vrt. 52 | ym. 53 | yms. 54 | yo. 55 | >>>>>>> 8646b68e5b124f612fd65b51ea40624f65a2f3d6 56 | 57 | # hindi abbreviation patterns 58 | जन. 59 | फर. 60 | अग. 61 | सित. 62 | अक्टू. 63 | अक्तू. 64 | नव. 65 | दिस. 66 | डी.एल. 67 | डी.टी.ओ. 68 | डी.ए. 69 | ए.एस.आई. 70 | डी.टी.ओ. 71 | एम.एस.आर.टी.सी. 72 | बी.बी.एम.बी. 73 | डी.एस.पी. 74 | सी.आर.पी. 75 | एस.डी.एम. 76 | सी.डी.पी.ओ. 77 | बी.डी.ओ. 78 | एस.डी.ओ. 79 | एम.पी.पी. 80 | पी.एच.ई. 81 | एस.एच.ओ. 82 | ए.सी.पी. 83 | यू.पी. 84 | पी.एम. 85 | आर.बी.डी. 86 | वी.पी. 87 | सी.ए.डी.पी. 88 | ए. 89 | बी. 90 | सी. 91 | डी. 92 | ई. 93 | एफ. 94 | जी. 95 | एच. 96 | आई. 97 | जे. 98 | के. 99 | एल. 100 | एम. 101 | एन. 102 | ओ. 103 | पी. 104 | क़यू. 105 | आर. 106 | एस. 107 | टी. 108 | यू. 109 | वी. 110 | डबल्यू. 111 | एक्स. 112 | वाई. 113 | ज़ेड. 114 | ज़ी. 115 | 116 | ##################### words made of punct only 117 | :- 118 | :-) 119 | :-( 120 | += 121 | -= 122 | .= 123 | *= 124 | >= 125 | <= 126 | == 127 | && 128 | || 129 | => 130 | -> 131 | <- 132 | :) 133 | :( 134 | ;) 135 | 136 | #################### abbr added by Fei 137 | oz. 138 | fl. 139 | tel. 140 | 1. 141 | 2. 142 | 3. 143 | 4. 144 | 5. 145 | 6. 146 | 7. 147 | 8. 148 | 9. 149 | 10. 150 | 151 | ##################### abbreviation: words that contain period. 152 | EE.UU. 153 | ee.uu. 154 | U.A.E 155 | Ala. 156 | Ph.D. 157 | min. 158 | max. 159 | z.B. 160 | d.h. 161 | ggf. 162 | ca. 163 | bzw. 164 | bzgl. 165 | Eng. 166 | i.e. 167 | a.m. 168 | am. 169 | A.M. 170 | Apr. 171 | Ariz. 172 | Ark. 173 | Aug. 174 | B.A.T. 175 | B.A.T 176 | Calif. 177 | Co. 178 | Conn. 179 | Corp. 180 | Cos. 181 | D.C. 182 | Dec. 183 | Dept. 184 | Dr. 185 | Drs. 186 | Feb. 187 | Fla. 188 | Fri. 189 | Ga. 190 | Gen. 
191 | gen. 192 | GEN. 193 | Gov. 194 | Govt. 195 | Ill. 196 | Inc. 197 | Jan. 198 | Jr. 199 | Jul. 200 | Jun. 201 | Kan. 202 | L.A. 203 | Lieut. 204 | Lt. 205 | Ltd. 206 | Ma. 207 | Mar. 208 | Mass. 209 | Md. 210 | Mfg. 211 | Mgr. 212 | Mio. 213 | Mrd. 214 | Bio. 215 | Minn. 216 | Mo. 217 | Mon. 218 | Mr. 219 | Mrs. 220 | Ms. 221 | Mt. 222 | N.D. 223 | Neb. 224 | Nev. 225 | No. 226 | Nos. 227 | Nov. 228 | Oct. 229 | Okla. 230 | Op. 231 | Ore. 232 | Pa. 233 | p.m 234 | p.m. 235 | I.B.C. 236 | N.T.V 237 | Pres. 238 | Prof. 239 | Prop. 240 | Rd. 241 | Rev. 242 | R.J. 243 | C.L 244 | Rs. 245 | Rte. 246 | Sat. 247 | W.T 248 | Sen. 249 | Sep. 250 | Sept. 251 | Sgt. 252 | Sr. 253 | SR. 254 | St. 255 | Ste. 256 | Sun. 257 | Tenn. 258 | Tex. 259 | Thu. 260 | Tue. 261 | Univ. 262 | Va. 263 | Vt. 264 | Wed. 265 | approx. 266 | dept. 267 | e.g. 268 | E.G. 269 | eg. 270 | est. 271 | etc. 272 | ex. 273 | ext. 274 | ft. 275 | hon. 276 | hr. 277 | hrs. 278 | lab. 279 | lb. 280 | lbs. 281 | mass. 282 | misc. 283 | no. 284 | nos. 285 | nt. 286 | para. 287 | paras. 288 | pct. 289 | prod. 290 | rec. 291 | ref. 292 | rel. 293 | rep. 294 | sq. 295 | st. 296 | stg. 297 | vol. 298 | vs. 299 | U.S. 300 | J.S. 301 | U.N. 302 | u.n. 303 | A. 304 | B. 305 | C. 306 | D. 307 | E. 308 | F. 309 | G. 310 | H. 311 | I. 312 | J. 313 | K. 314 | L. 315 | M. 316 | N. 317 | O. 318 | P. 319 | Q. 320 | R. 321 | S. 322 | T. 323 | U. 324 | V. 325 | W. 326 | X. 327 | Y. 328 | Z. 329 | А. 330 | Б. 331 | В. 332 | Г. 333 | Д. 334 | Е. 335 | Ё. 336 | Ж. 337 | З. 338 | И. 339 | Й. 340 | К. 341 | Л. 342 | М. 343 | Н. 344 | О. 345 | П. 346 | Р. 347 | С. 348 | Т. 349 | У. 350 | Ф. 351 | Х. 352 | Ц. 353 | Ч. 354 | Ш. 355 | Щ. 356 | Ъ. 357 | Ы. 358 | Ь. 359 | Э. 360 | Ю. 361 | Я. 362 | л. 363 | г. 364 | обл. 365 | гг. 366 | в. 367 | вв. 368 | мин. 369 | ч. 370 | тыс. 371 | млн. 372 | млрд. 373 | трлн. 374 | кв. 375 | куб. 376 | руб. 377 | коп. 378 | долл. 379 | Прим. 380 | прим. 381 | чел. 382 | грн. 383 | мин. 384 | им. 385 | проф. 386 | акад. 387 | ред. 388 | авт. 389 | корр. 390 | соб. 391 | спец. 392 | см. 393 | тж. 394 | др. 395 | пр. 396 | букв. 397 | # Two-letter abbreviations - can be written with space 398 | п.п. 399 | ст.л. 400 | т.е. 401 | т.к. 402 | т.ч. 403 | т.д. 404 | т.п. 405 | и.о. 406 | с.г. 407 | г.р. 408 | т.н. 409 | т.ч. 410 | н.э. 411 | # Swahili 412 | A.D. 413 | Afr. 414 | A.G. 415 | agh. 416 | A.H. 417 | A.M. 418 | a.s. 419 | B.A. 420 | B.C. 421 | Bi. 422 | B.J. 423 | B.K. 424 | B.O.M. 425 | Brig. 426 | Bro. 427 | bt. 428 | bw. 429 | Bw. 430 | Cap. 431 | C.C. 432 | cCM. 433 | C.I.A. 434 | cit. 435 | C.M.S. 436 | Co. 437 | Corp. 438 | C.S.Sp. 439 | C.W. 440 | D.C. 441 | Dk. 442 | Dkt. 443 | Dk.B. 444 | Dr. 445 | E.C. 446 | e.g. 447 | E.M. 448 | E.n. 449 | etc. 450 | Feb. 451 | F.F.U. 452 | F.M. 453 | Fr. 454 | F.W. 455 | I.C.O. 456 | i.e. 457 | I.L.C. 458 | Inc. 459 | Jan. 460 | J.F. 461 | Jr. 462 | J.S. 463 | J.V.W.A. 464 | K.A.R. 465 | K.A.U. 466 | K.C.M.C. 467 | K.k. 468 | K.K. 469 | k.m. 470 | km. 471 | K.m. 472 | K.N.C.U. 473 | K.O. 474 | K.S. 475 | Ksh. 476 | kt. 477 | kumb. 478 | k.v. 479 | kv. 480 | L.G. 481 | ltd. 482 | Ltd. 483 | M.A. 484 | M.D. 485 | mf. 486 | Mh. 487 | Mhe. 488 | mil. 489 | m.m. 490 | M.m. 491 | Mm. 492 | M.M. 493 | Mr. 494 | Mrs. 495 | M.S. 496 | Mt. 497 | Mw. 498 | M.W. 499 | Mwl. 500 | na. 501 | Na. 502 | N.F. 503 | N.J. 504 | n.k. 505 | nk. 506 | n.k.w. 507 | N.N. 508 | Nov. 509 | O.C.D. 510 | op. 511 | P.C. 512 | Phd. 513 | Ph.D. 514 | P.J. 515 | P.o. 516 | P.O. 517 | P.O.P. 
518 | P.P.F. 519 | Prof. 520 | P.s. 521 | P.S. 522 | Q.C. 523 | Rd. 524 | s.a.w. 525 | S.A.W. 526 | S.D. 527 | Sept. 528 | sh. 529 | Sh. 530 | SH. 531 | shs. 532 | Shs. 533 | S.J. 534 | S.L. 535 | S.L.P. 536 | S.s. 537 | S.S. 538 | St. 539 | s.w. 540 | s.w.T. 541 | taz. 542 | Taz. 543 | T.C. 544 | T.E.C. 545 | T.L.P. 546 | T.O.H.S. 547 | Tsh. 548 | T.V. 549 | tz. 550 | uk. 551 | Uk. 552 | U.M.C.A. 553 | U.N. 554 | U.S. 555 | Ush. 556 | U.W.T. 557 | Viii. 558 | Vol. 559 | V.T.C. 560 | W.H. 561 | yamb. 562 | Y.M.C.A. 563 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/token_patterns: -------------------------------------------------------------------------------- 1 | /^(al|el|ul|e)\-[a-z]+$/ 2 | /\.(fi|fr|es|co\.uk|de)$/ 3 | /:[a-zä]+$/ 4 | /^((а|А)(ль|ш)|уль)-\p{Cyrillic}+$/ 5 | /^\p{Cyrillic}\.\p{Cyrillic}\.$/ 6 | /^(\d|\d\d|\d\d\d)\.$/ 7 | 8 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/utf8-normalize-batch.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | 3 | use IPC::Open2; 4 | 5 | $|++; 6 | 7 | if (scalar(@ARGV) != 1) { 8 | print STDERR "usage: $0 \"CMD\"\n"; 9 | exit(2); 10 | } 11 | 12 | $CMD = $ARGV[0]; 13 | 14 | while () { 15 | s/\r\n*/\n/g; 16 | $PID = open2(*SOUT, *SIN, $CMD); 17 | print SIN "$_\n"; 18 | close(SIN); 19 | $_ = ; 20 | close(SOUT); 21 | waitpid($PID, 0); 22 | chomp; 23 | s/[\x00-\x1F]+/ /g; 24 | s/ +/ /g; 25 | s/^ //; 26 | s/ $//; 27 | print "$_\n"; 28 | } 29 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/support/utf8-normalize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this is the location on malbec, if you want to run on another machine 4 | # ICU may be installed in /usr or /usr/local 5 | ICU_DIR=/usr0/tools/icu 6 | UCONV_BIN=$ICU_DIR/bin/uconv 7 | UCONV_LIB=$ICU_DIR/lib 8 | 9 | if [ -e $UCONV_BIN ] && [ -d $UCONV_LIB ] 10 | then 11 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$UCONV_LIB 12 | if [ ! -x $UCONV_BIN ] 13 | then 14 | echo "$0: Cannot execute $UCONV_BIN! Please fix." 1>&2 15 | exit 16 | fi 17 | CMD="$UCONV_BIN -f utf8 -t utf8 -x Any-NFKC --callback skip" 18 | else 19 | if which uconv > /dev/null 20 | then 21 | CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip" 22 | else 23 | echo "$0: Cannot find ICU uconv (http://site.icu-project.org/) ... falling back to iconv. Quality may suffer." 
1>&2 24 | CMD="iconv -f utf8 -t utf8 -c" 25 | fi 26 | fi 27 | 28 | if [[ $# == 1 && $1 == "--batchline" ]]; then 29 | perl $(dirname $0)/utf8-normalize-batch.pl "$CMD" 30 | else 31 | perl -e '$|++; while(<>){s/\r\n*/\n/g; print;}' \ 32 | |$CMD \ 33 | |/usr/bin/perl -w -e ' 34 | $|++; 35 | while (<>) { 36 | chomp; 37 | s/[\x00-\x1F]+/ /g; 38 | s/ +/ /g; 39 | s/^ //; 40 | s/ $//; 41 | print "$_\n"; 42 | }' 43 | fi 44 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/tokenize-anything.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ROOTDIR=`dirname $0` 4 | SUPPORT=$ROOTDIR/support 5 | 6 | if [[ $# == 1 && $1 == '-u' ]] ; then 7 | NORMARGS="--batchline" 8 | SEDFLAGS="-u" 9 | else 10 | if [[ $# != 0 ]] ; then 11 | echo Usage: `basename $0` [-u] \< file.in \> file.out 1>&2 12 | echo 1>&2 13 | echo Tokenizes text in a reasonable way in most languages. 1>&2 14 | echo 1>&2 15 | exit 1 16 | fi 17 | NORMARGS="" 18 | SEDFLAGS="" 19 | fi 20 | 21 | $SUPPORT/utf8-normalize.sh $NORMARGS | 22 | $SUPPORT/quote-norm.pl | 23 | $SUPPORT/tokenizer.pl | 24 | $SUPPORT/fix-eos.pl | 25 | sed $SEDFLAGS -e 's/ al - / al-/g' | 26 | $SUPPORT/fix-contract.pl | 27 | sed $SEDFLAGS -e 's/^ //' | sed $SEDFLAGS -e 's/ $//' | 28 | perl -e '$|++; while(<>){s/(\d+)(\.+)$/$1 ./; s/(\d+)(\.+) \|\|\|/$1 . |||/; print;}' 29 | 30 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/tokenize-parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gzip 4 | import math 5 | import os 6 | import shutil 7 | import subprocess 8 | import sys 9 | import tempfile 10 | 11 | DEFAULT_JOBS = 8 12 | DEFAULT_TMP = '/tmp' 13 | 14 | TOKENIZER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tokenize-anything.sh') 15 | 16 | def gzopen(f): 17 | return gzip.open(f) if f.endswith('.gz') else open(f) 18 | 19 | def wc(f): 20 | return sum(1 for line in gzopen(f)) 21 | 22 | def main(argv): 23 | 24 | if len(argv[1:]) < 1: 25 | sys.stderr.write('Parallelize text normalization with multiple instances of tokenize-anything.sh\n\n') 26 | sys.stderr.write('Usage: {} in-file [jobs [temp-dir]] >out-file\n'.format(argv[0])) 27 | sys.exit(2) 28 | 29 | in_file = argv[1] 30 | jobs = int(argv[2]) if len(argv[1:]) > 1 else DEFAULT_JOBS 31 | tmp = argv[3] if len(argv[1:]) > 2 else DEFAULT_TMP 32 | 33 | work = tempfile.mkdtemp(prefix='tok.', dir=tmp) 34 | in_wc = wc(in_file) 35 | # Don't start more jobs than we have lines 36 | jobs = min(jobs, in_wc) 37 | lines_per = int(math.ceil(float(in_wc)/jobs)) 38 | 39 | inp = gzopen(in_file) 40 | procs = [] 41 | files = [] 42 | outs = [] 43 | for i in range(jobs): 44 | raw = os.path.join(work, 'in.{}'.format(i)) 45 | tok = os.path.join(work, 'out.{}'.format(i)) 46 | files.append(tok) 47 | # Write raw batch 48 | raw_out = open(raw, 'w') 49 | for _ in range(lines_per): 50 | line = inp.readline() 51 | if not line: 52 | break 53 | raw_out.write(line) 54 | raw_out.close() 55 | # Start tokenizer 56 | raw_in = open(raw) 57 | tok_out = open(tok, 'w') 58 | outs.append(tok_out) 59 | p = subprocess.Popen(TOKENIZER, stdin=raw_in, stdout=tok_out) 60 | procs.append(p) 61 | 62 | # Cat output of each tokenizer as it finishes 63 | for (p, f, o) in zip(procs, files, outs): 64 | p.wait() 65 | o.close() 66 | for line in open(f): 67 | sys.stdout.write(line) 68 | 69 | # 
Cleanup 70 | shutil.rmtree(work) 71 | 72 | if __name__ == '__main__': 73 | main(sys.argv) 74 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/untok.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | use IO::Handle; 4 | STDOUT->autoflush(1); 5 | 6 | while (<>) { 7 | $output = ""; 8 | @tokens = split; 9 | $lspace = 0; 10 | $qflag = 0; 11 | for ($i=0; $i<=$#tokens; $i++) { 12 | $token = $tokens[$i]; 13 | $prev = $next = ""; 14 | $rspace = 1; 15 | if ($i > 0) { 16 | $prev = $tokens[$i-1]; 17 | } 18 | if ($i < $#tokens) { 19 | $next = $tokens[$i+1]; 20 | } 21 | 22 | # possessives join to the left 23 | if ($token =~ /^(n't|'(s|m|re|ll|ve|d))$/) { 24 | $lspace = 0; 25 | } elsif ($token eq "'" && $prev =~ /s$/) { 26 | $lspace = 0; 27 | 28 | # hyphen only when a hyphen, not a dash 29 | } elsif ($token eq "-" && $prev =~ /[A-Za-z0-9]$/ && $next =~ /^[A-Za-z0-9]/) { 30 | $lspace = $rspace = 0; 31 | 32 | # quote marks alternate 33 | } elsif ($token eq '"') { 34 | if ($qflag) { 35 | $lspace = 0; 36 | } else { 37 | $rspace = 0; 38 | } 39 | $qflag = !$qflag; 40 | 41 | # period joins on both sides when a decimal point 42 | } elsif ($token eq "." && $prev =~ /\d$/ && $next =~ /\d$/) { 43 | $lspace = $rspace = 0; 44 | 45 | # Left joiners 46 | } elsif ($token =~ /^[.,:;?!%)\]]$/) { 47 | $lspace = 0; 48 | # Right joiners 49 | } elsif ($token =~ /^[$(\[]$/) { 50 | $rspace = 0; 51 | # Joiners on both sides 52 | } elsif ($token =~ /^[\/]$/) { 53 | $lspace = $rspace = 0; 54 | } 55 | 56 | if ($lspace) { 57 | $output .= " "; 58 | } 59 | $output .= $token; 60 | $lspace = $rspace; 61 | } 62 | print "$output\n"; 63 | } 64 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/utf8-normalize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script uses ICU uconv (http://site.icu-project.org/), if it's available 4 | # to normalize UTF8 text into a standard form. For information about this 5 | # process, refer to http://en.wikipedia.org/wiki/Unicode_equivalence#Normalization 6 | # Escape characters between 0x00-0x1F are removed 7 | 8 | if which uconv > /dev/null 9 | then 10 | CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip --remove-signature" 11 | else 12 | echo "Cannot find ICU uconv (http://site.icu-project.org/) ... falling back to iconv. Normalization NOT taking place." 1>&2 13 | CMD="iconv -f utf8 -t utf8 -c" 14 | fi 15 | 16 | $CMD | /usr/bin/perl -w -e ' 17 | while (<>) { 18 | chomp; 19 | s/[\x00-\x1F]+/ /g; 20 | s/ +/ /g; 21 | s/^ //; 22 | s/ $//; 23 | print "$_\n"; 24 | }' 25 | 26 | -------------------------------------------------------------------------------- /fine-tune/evaluation/cdec-corpus/xml-tok.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import re 5 | import subprocess 6 | import sys 7 | 8 | # Tokenize XML files with tokenize-anything.sh 9 | # in: The earnings on its 10-year bonds are 28.45%. 10 | # out: The earnings on its 10 - year bonds are 28.45 % . 
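For reference, a minimal, hypothetical driver (not part of the repo) for the pipeline described in the header comment above: it pipes a single sentence through `tokenize-anything.sh` and reads the result back. The function name `tokenize_once` and the relative script path are assumptions; point the path at the copy under `fine-tune/evaluation/cdec-corpus/`. The `main()` routine that follows does the same job through a persistent `Popen` pipe so that XML markup can be passed through untouched.

```python
# Hypothetical usage sketch (not in the repo): run tokenize-anything.sh on one sentence.
import os
import subprocess

def tokenize_once(sentence, script_dir="fine-tune/evaluation/cdec-corpus"):
    # Assumes tokenize-anything.sh is executable and bash/perl are available.
    script = os.path.join(script_dir, "tokenize-anything.sh")
    proc = subprocess.run([script], input=sentence + "\n",
                          capture_output=True, text=True, check=True)
    return proc.stdout.strip()

if __name__ == "__main__":
    # Expected, per the comment above:
    # "The earnings on its 10 - year bonds are 28.45 % ."
    print(tokenize_once("The earnings on its 10-year bonds are 28.45%."))
```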
11 | 12 | def escape(s): 13 | return s.replace('&', '&amp;').replace('>', '&gt;').replace('<', '&lt;').replace('"', '&quot;').replace('\'', '&apos;') 14 | 15 | def unescape(s): 16 | return s.replace('&gt;', '>').replace('&lt;', '<').replace('&quot;', '"').replace('&apos;', '\'').replace('&amp;', '&') 17 | 18 | def main(): 19 | tok = subprocess.Popen([os.path.join(os.path.dirname(__file__), 'tokenize-anything.sh'), '-u'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) 20 | while True: 21 | line = sys.stdin.readline() 22 | if not line: 23 | break 24 | line = line.strip() 25 | pieces = [] 26 | eol = len(line) 27 | pos = 0 28 | while pos < eol: 29 | next = line.find('<', pos) 30 | if next == -1: 31 | next = eol 32 | tok.stdin.write('{}\n'.format(unescape(line[pos:next]))) 33 | pieces.append(escape(tok.stdout.readline().strip())) 34 | if next == eol: 35 | break 36 | pos = line.find('>', next + 1) 37 | if pos == -1: 38 | pos = eol 39 | else: 40 | pos += 1 41 | pieces.append(line[next:pos]) 42 | sys.stdout.write('{}\n'.format(' '.join(pieces).strip())) 43 | tok.stdin.close() 44 | tok.wait() 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /fine-tune/evaluation/eval_gen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | from typing import Iterable, Optional 5 | import datasets 6 | import re 7 | 8 | 9 | def argument_parser(): 10 | 11 | parser = argparse.ArgumentParser(description="Preprocess AMR data") 12 | # Multiple input parameters 13 | parser.add_argument("--in-tokens", help="input tokens", required=True, type=str) 14 | parser.add_argument("--in-reference-tokens", help="reference tokens to compute metric", type=str) 15 | args = parser.parse_args() 16 | 17 | return args 18 | 19 | 20 | def tokenize_sentence(text, debug=False): 21 | text = re.sub(r"('ll|n't|'m|'s|'d|'re)", r" \1", text) 22 | text = re.sub(r"(\s+)", r" ", text) 23 | return text 24 | 25 | 26 | def raw_corpus_bleu( 27 | hypothesis: Iterable[str], reference: Iterable[str], offset: Optional[float] = 0.01 28 | ) -> float: 29 | bleu = datasets.load_metric("bleu") 30 | hypothesis = [itm.strip().split() for itm in hypothesis] 31 | reference = [[itm.strip().split()] for itm in reference] 32 | res = bleu.compute(predictions=hypothesis, references=reference) 33 | return res 34 | 35 | 36 | def raw_corpus_chrf(hypotheses: Iterable[str], references: Iterable[str]) -> float: 37 | chrf = datasets.load_metric("chrf") 38 | hypotheses = [itm.strip() for itm in hypotheses] 39 | references = [[itm.strip()] for itm in references] 40 | res = chrf.compute(predictions=hypotheses, references=references) 41 | return res 42 | 43 | 44 | def raw_corpus_meteor(hypotheses: Iterable[str], references: Iterable[str]): 45 | hypotheses = [itm.strip() for itm in hypotheses] 46 | references = [[itm.strip()] for itm in references] 47 | meteor = datasets.load_metric("meteor") 48 | res = meteor.compute(predictions=hypotheses, references=references) 49 | return res 50 | 51 | 52 | def raw_corpus_bleurt(hypotheses: Iterable[str], references: Iterable[str]): 53 | hypotheses = [itm.strip() for itm in hypotheses] 54 | references = [itm.strip() for itm in references] 55 | bleurt = datasets.load_metric("bleurt", 'bleurt-base-512') 56 | res = bleurt.compute(predictions=hypotheses, references=references) 57 | return res 58 | 59 | 60 | def read_tokens(in_tokens_file): 61 | with open(in_tokens_file) as fid: 62 | lines = fid.readlines() 63 | return lines
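A minimal usage sketch for the metric helpers above, with made-up hypothesis/reference lists; in practice the lines come from the `--in-tokens` and `--in-reference-tokens` files. Note that `datasets.load_metric` is deprecated in recent `datasets` releases, so this assumes an older `datasets` version (or porting the calls to the standalone `evaluate` package).

```python
# Hypothetical inputs; real runs read them from the tokenized prediction/reference files.
hyp = ["the boy wants to go", "the girl sings a song"]
ref = ["the boy wants to go", "a girl is singing a song"]

print("BLEU  ", raw_corpus_bleu(hyp, ref))   # dict with a 'bleu' field
print("chrF++", raw_corpus_chrf(hyp, ref))   # dict with a 'score' field
```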
64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | # Argument handlig 69 | args = argument_parser() 70 | 71 | # read files 72 | ref = read_tokens(args.in_reference_tokens) 73 | hyp = read_tokens(args.in_tokens) 74 | 75 | # Lower evaluation 76 | for i in range(len(ref)): 77 | ref[i] = ref[i].lower() 78 | 79 | # Lower case output 80 | for i in range(len(hyp)): 81 | if "" in hyp[i]: 82 | hyp[i] = hyp[i].split("")[-1] 83 | hyp[i] = tokenize_sentence(hyp[i].lower()) 84 | 85 | # results 86 | 87 | bleu = raw_corpus_bleu(hyp, ref) 88 | print("BLEU {}".format(bleu)) 89 | 90 | chrFpp = raw_corpus_chrf(hyp, ref) 91 | print("chrF++ {}".format(chrFpp)) 92 | 93 | #meteor = raw_corpus_meteor(hyp, ref) 94 | #print("meteor {}".format(meteor)) 95 | 96 | #bleurt = raw_corpus_bleurt(hyp, ref) 97 | #b_res = sum(bleurt["scores"]) / len(bleurt["scores"]) 98 | #print("bleurt {}".format(b_res)) 99 | -------------------------------------------------------------------------------- /fine-tune/evaluation/eval_gen.sh: -------------------------------------------------------------------------------- 1 | #:< $gold_tok 10 | bash $tokenizer -u < $pred > $pred_tok 11 | #! 12 | 13 | echo "Evaluating BLEU score ..." 14 | python eval_gen.py --in-tokens $pred_tok --in-reference-tokens $gold_tok 15 | 16 | echo "Evaluating Meteor score ..." 17 | java -jar meteor-1.5/meteor-1.5.jar $pred_tok $gold_tok > $pred.meteor 18 | tail -n 10 $pred.meteor 19 | 20 | -------------------------------------------------------------------------------- /fine-tune/evaluation/eval_smatch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from amrlib.evaluate.smatch_enhanced import compute_scores 4 | GOLD=sys.argv[1] 5 | PRED=sys.argv[2] 6 | compute_scores(PRED, GOLD) 7 | -------------------------------------------------------------------------------- /fine-tune/evaluation/eval_smatch.sh: -------------------------------------------------------------------------------- 1 | gold_amr=$1 2 | hyp_amr=$2 3 | 4 | python eval_smatch.py $gold_amr $hyp_amr -------------------------------------------------------------------------------- /fine-tune/inference-amr.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=examples 5 | 6 | BasePath=/mnt/nfs-storage/data # change dir here 7 | DataPath=$RootDir/../$Dataset 8 | 9 | ModelCate=AMRBART-large 10 | 11 | MODEL=$1 12 | ModelCache=$BasePath/.cache 13 | DataCache=$DataPath/.cache/dump-amrparsing 14 | 15 | lr=1e-5 16 | 17 | OutputDir=${RootDir}/outputs/Infer-$Dataset-${ModelCate}-AMRParing-bsz16-lr-${lr}-UnifiedInp 18 | 19 | if [ ! -d ${OutputDir} ];then 20 | mkdir -p ${OutputDir} 21 | else 22 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 23 | case $yn in 24 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 25 | [Nn]* ) echo "exiting..."; exit;; 26 | * ) echo "Please answer yes or no.";; 27 | esac 28 | fi 29 | 30 | export HF_DATASETS_CACHE=$DataCache 31 | 32 | if [ ! 
-d ${DataCache} ];then 33 | mkdir -p ${DataCache} 34 | fi 35 | 36 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 37 | python -u main.py \ 38 | --data_dir $DataPath \ 39 | --task "text2amr" \ 40 | --test_file $DataPath/data4parsing.jsonl \ 41 | --output_dir $OutputDir \ 42 | --cache_dir $ModelCache \ 43 | --data_cache_dir $DataCache \ 44 | --overwrite_cache True \ 45 | --model_name_or_path $MODEL \ 46 | --overwrite_output_dir \ 47 | --unified_input True \ 48 | --per_device_eval_batch_size 16 \ 49 | --max_source_length 400 \ 50 | --max_target_length 1024 \ 51 | --val_max_target_length 1024 \ 52 | --generation_max_length 1024 \ 53 | --generation_num_beams 5 \ 54 | --predict_with_generate \ 55 | --smart_init False \ 56 | --use_fast_tokenizer False \ 57 | --logging_dir $OutputDir/logs \ 58 | --seed 42 \ 59 | --fp16 \ 60 | --fp16_backend "auto" \ 61 | --dataloader_num_workers 8 \ 62 | --eval_dataloader_num_workers 2 \ 63 | --include_inputs_for_metrics \ 64 | --do_predict \ 65 | --ddp_find_unused_parameters False \ 66 | --report_to "tensorboard" \ 67 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log 68 | -------------------------------------------------------------------------------- /fine-tune/inference-text.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=examples 5 | 6 | BasePath=/mnt/nfs-storage/data # change dir here 7 | DataPath=$RootDir/../$Dataset 8 | 9 | ModelCate=AMRBART-large 10 | 11 | MODEL=$1 12 | ModelCache=$BasePath/.cache 13 | DataCache=$DataPath/.cache/dump-amr2text 14 | 15 | lr=2e-6 16 | 17 | OutputDir=${RootDir}/outputs/Infer-$Dataset-$ModelCate-AMR2Text-bsz8-lr-${lr}-UnifiedInp 18 | 19 | if [ ! -d ${OutputDir} ];then 20 | mkdir -p ${OutputDir} 21 | else 22 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 23 | case $yn in 24 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 25 | [Nn]* ) echo "exiting..."; exit;; 26 | * ) echo "Please answer yes or no.";; 27 | esac 28 | fi 29 | 30 | export HF_DATASETS_CACHE=$DataCache 31 | 32 | if [ ! 
-d ${DataCache} ];then 33 | mkdir -p ${DataCache} 34 | fi 35 | 36 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 37 | python -u main.py \ 38 | --data_dir $DataPath \ 39 | --task "amr2text" \ 40 | --test_file $DataPath/data4generation.jsonl \ 41 | --output_dir $OutputDir \ 42 | --cache_dir $ModelCache \ 43 | --data_cache_dir $DataCache \ 44 | --overwrite_cache True \ 45 | --model_name_or_path $MODEL \ 46 | --overwrite_output_dir \ 47 | --unified_input True \ 48 | --per_device_eval_batch_size 8 \ 49 | --max_source_length 1024 \ 50 | --max_target_length 400 \ 51 | --val_max_target_length 400 \ 52 | --generation_max_length 400 \ 53 | --generation_num_beams 5 \ 54 | --predict_with_generate \ 55 | --smart_init False \ 56 | --use_fast_tokenizer False \ 57 | --logging_dir $OutputDir/logs \ 58 | --seed 42 \ 59 | --fp16 \ 60 | --fp16_backend "auto" \ 61 | --dataloader_num_workers 8 \ 62 | --eval_dataloader_num_workers 2 \ 63 | --include_inputs_for_metrics \ 64 | --do_predict \ 65 | --ddp_find_unused_parameters False \ 66 | --report_to "tensorboard" \ 67 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log 68 | -------------------------------------------------------------------------------- /fine-tune/metric/sacrebleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ SACREBLEU metric. """ 15 | 16 | import sacrebleu as scb 17 | from packaging import version 18 | 19 | import datasets 20 | 21 | 22 | _CITATION = """\ 23 | @inproceedings{post-2018-call, 24 | title = "A Call for Clarity in Reporting {BLEU} Scores", 25 | author = "Post, Matt", 26 | booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers", 27 | month = oct, 28 | year = "2018", 29 | address = "Belgium, Brussels", 30 | publisher = "Association for Computational Linguistics", 31 | url = "https://www.aclweb.org/anthology/W18-6319", 32 | pages = "186--191", 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """\ 37 | SacreBLEU provides hassle-free computation of shareable, comparable, and reproducible BLEU scores. 38 | Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text. 39 | It also knows all the standard test sets and handles downloading, processing, and tokenization for you. 40 | 41 | See the [README.md] file at https://github.com/mjpost/sacreBLEU for more information. 42 | """ 43 | 44 | _KWARGS_DESCRIPTION = """ 45 | Produces BLEU scores along with its sufficient statistics 46 | from a source against one or more references. 47 | 48 | Args: 49 | predictions (`list` of `str`): list of translations to score. Each translation should be tokenized into a list of tokens. 50 | references (`list` of `list` of `str`): A list of lists of references. 
The contents of the first sub-list are the references for the first prediction, the contents of the second sub-list are for the second prediction, etc. Note that there must be the same number of references for each prediction (i.e. all sub-lists must be of the same length). 51 | smooth_method (`str`): The smoothing method to use, defaults to `'exp'`. Possible values are: 52 | - `'none'`: no smoothing 53 | - `'floor'`: increment zero counts 54 | - `'add-k'`: increment num/denom by k for n>1 55 | - `'exp'`: exponential decay 56 | smooth_value (`float`): The smoothing value. Only valid when `smooth_method='floor'` (in which case `smooth_value` defaults to `0.1`) or `smooth_method='add-k'` (in which case `smooth_value` defaults to `1`). 57 | tokenize (`str`): Tokenization method to use for BLEU. If not provided, defaults to `'zh'` for Chinese, `'ja-mecab'` for Japanese and `'13a'` (mteval) otherwise. Possible values are: 58 | - `'none'`: No tokenization. 59 | - `'zh'`: Chinese tokenization. 60 | - `'13a'`: mimics the `mteval-v13a` script from Moses. 61 | - `'intl'`: International tokenization, mimics the `mteval-v14` script from Moses 62 | - `'char'`: Language-agnostic character-level tokenization. 63 | - `'ja-mecab'`: Japanese tokenization. Uses the [MeCab tokenizer](https://pypi.org/project/mecab-python3). 64 | lowercase (`bool`): If `True`, lowercases the input, enabling case-insensitivity. Defaults to `False`. 65 | force (`bool`): If `True`, insists that your tokenized input is actually detokenized. Defaults to `False`. 66 | use_effective_order (`bool`): If `True`, stops including n-gram orders for which precision is 0. This should be `True`, if sentence-level BLEU will be computed. Defaults to `False`. 67 | 68 | Returns: 69 | 'score': BLEU score, 70 | 'counts': Counts, 71 | 'totals': Totals, 72 | 'precisions': Precisions, 73 | 'bp': Brevity penalty, 74 | 'sys_len': predictions length, 75 | 'ref_len': reference length, 76 | 77 | Examples: 78 | 79 | Example 1: 80 | >>> predictions = ["hello there general kenobi", "foo bar foobar"] 81 | >>> references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]] 82 | >>> sacrebleu = datasets.load_metric("sacrebleu") 83 | >>> results = sacrebleu.compute(predictions=predictions, references=references) 84 | >>> print(list(results.keys())) 85 | ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] 86 | >>> print(round(results["score"], 1)) 87 | 100.0 88 | 89 | Example 2: 90 | >>> predictions = ["hello there general kenobi", 91 | ... "on our way to ankh morpork"] 92 | >>> references = [["hello there general kenobi", "hello there !"], 93 | ... ["goodbye ankh morpork", "ankh morpork"]] 94 | >>> sacrebleu = datasets.load_metric("sacrebleu") 95 | >>> results = sacrebleu.compute(predictions=predictions, 96 | ... references=references) 97 | >>> print(list(results.keys())) 98 | ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] 99 | >>> print(round(results["score"], 1)) 100 | 39.8 101 | """ 102 | 103 | 104 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 105 | class Sacrebleu(datasets.Metric): 106 | def _info(self): 107 | if version.parse(scb.__version__) < version.parse("1.4.12"): 108 | raise ImportWarning( 109 | "To use `sacrebleu`, the module `sacrebleu>=1.4.12` is required, and the current version of `sacrebleu` doesn't match this condition.\n" 110 | 'You can install it with `pip install "sacrebleu>=1.4.12"`.' 
111 | ) 112 | return datasets.MetricInfo( 113 | description=_DESCRIPTION, 114 | citation=_CITATION, 115 | homepage="https://github.com/mjpost/sacreBLEU", 116 | inputs_description=_KWARGS_DESCRIPTION, 117 | features=datasets.Features( 118 | { 119 | "predictions": datasets.Value("string", id="sequence"), 120 | "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"), 121 | } 122 | ), 123 | codebase_urls=["https://github.com/mjpost/sacreBLEU"], 124 | reference_urls=[ 125 | "https://github.com/mjpost/sacreBLEU", 126 | "https://en.wikipedia.org/wiki/BLEU", 127 | "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", 128 | ], 129 | ) 130 | 131 | def _compute( 132 | self, 133 | predictions, 134 | references, 135 | smooth_method="exp", 136 | smooth_value=None, 137 | force=False, 138 | lowercase=False, 139 | tokenize=None, 140 | use_effective_order=False, 141 | ): 142 | references_per_prediction = len(references[0]) 143 | if any(len(refs) != references_per_prediction for refs in references): 144 | raise ValueError("Sacrebleu requires the same number of references for each prediction") 145 | transformed_references = [[refs[i] for refs in references] for i in range(references_per_prediction)] 146 | output = scb.corpus_bleu( 147 | predictions, 148 | transformed_references, 149 | smooth_method=smooth_method, 150 | smooth_value=smooth_value, 151 | force=force, 152 | lowercase=lowercase, 153 | use_effective_order=use_effective_order, 154 | **(dict(tokenize=tokenize) if tokenize else {}), 155 | ) 156 | output_dict = { 157 | "score": output.score, 158 | "counts": output.counts, 159 | "totals": output.totals, 160 | "precisions": output.precisions, 161 | "bp": output.bp, 162 | "sys_len": output.sys_len, 163 | "ref_len": output.ref_len, 164 | } 165 | return output_dict -------------------------------------------------------------------------------- /fine-tune/model_interface/tokenization_bart.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # this is a simplified version of "https://github.com/SapienzaNLP/spring/blob/main/spring_amr/tokenization_bart.py" 3 | import sys 4 | import penman 5 | import regex as re 6 | from transformers import BartTokenizer 7 | from common import postprocessing 8 | from common.penman_interface import encode 9 | from common.constant import raw_special_tokens, recategorizations 10 | 11 | 12 | class AMRBartTokenizer(BartTokenizer): 13 | INIT = 'Ġ' 14 | 15 | def __init__(self, vocab_file, merges_file, errors="replace", bos_token="", eos_token="", sep_token="", cls_token="", unk_token="", pad_token="", mask_token="", add_prefix_space=False, **kwargs): 16 | super().__init__(vocab_file, merges_file, errors, bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token, add_prefix_space, **kwargs) 17 | self.modified = 0 18 | self.recategorizations = set(recategorizations) 19 | self.patterns = re.compile(r""" ?<[a-z]+:?\d*>| ?:[^\s]+|'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 20 | self.remove_pars = False 21 | 22 | @classmethod 23 | def from_pretrained(cls, pretrained_model_path, *args, **kwargs): 24 | inst = super().from_pretrained(pretrained_model_path, *args, **kwargs) 25 | inst.init_amr_vocabulary() 26 | return inst 27 | 28 | def init_amr_vocabulary(self): 29 | self.old_enc_size = old_enc_size = len(self.encoder) 30 | tokens = [t for t in raw_special_tokens if t not in self.encoder] 31 | 32 | for i, t 
in enumerate(tokens, start=old_enc_size): 33 | self.encoder[t] = i 34 | 35 | self.encoder = {k: i for i, (k,v) in enumerate(sorted(self.encoder.items(), key=lambda x: x[1]))} 36 | self.decoder = {v: k for k, v in sorted(self.encoder.items(), key=lambda x: x[1])} 37 | self.modified = len(tokens) 38 | 39 | self.amr_bos_token = "" 40 | self.amr_bos_token_id = self.encoder[self.amr_bos_token] 41 | self.amr_eos_token = "" 42 | self.amr_eos_token_id = self.encoder[self.amr_eos_token] 43 | print(f"Added {self.modified} AMR tokens") 44 | 45 | def _tokenize(self, text): 46 | """ Tokenize a string. Modified in order to handle sentences with recategorization pointers""" 47 | bpe_tokens = [] 48 | for tok_span in text.lstrip().split(' '): 49 | tok_span = tok_span.strip() 50 | recats = tok_span.rsplit('_', 1) 51 | if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder: 52 | bpe_tokens.extend([self.INIT + recats[0], '_' + recats[1]]) 53 | else: 54 | for token in re.findall(self.pat, ' ' + tok_span): 55 | token = "".join( 56 | self.byte_encoder[b] for b in token.encode("utf-8") 57 | ) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case) 58 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 59 | 60 | return bpe_tokens 61 | 62 | def _tok_bpe(self, token): 63 | tokk = [] 64 | tok = token.strip() 65 | recats = tok.rsplit('_', 1) 66 | if len(recats) == 2 and recats[0] in self.recategorizations and ('_' + recats[1]) in self.encoder: 67 | tokk.extend([self.INIT + recats[0], '_' + recats[1]]) 68 | else: 69 | for tok in self.patterns.findall(' ' + token): 70 | tok = "".join( 71 | self.byte_encoder[b] for b in tok.encode("utf-8")) 72 | toks = self.bpe(tok).split(' ') 73 | tokk.extend(toks) 74 | return tokk 75 | 76 | def tokenize_amr(self, amr_tokens): 77 | bpe_tokens = [] 78 | for i, tokk in enumerate(amr_tokens): 79 | is_in_enc = self.INIT + tokk in self.encoder 80 | is_rel = tokk.startswith(':') and len(tokk) > 1 81 | is_spc = tokk.startswith('<') and tokk.endswith('>') 82 | is_of = tokk.startswith(':') and tokk.endswith('-of') 83 | is_frame = re.match(r'.+-\d\d', tokk) is not None 84 | 85 | if tokk.startswith('"') and tokk.endswith('"'): # dealing with examples like "The_United_Kingdom_of_xxx" 86 | tokk = tokk[1:-1].replace('_', ' ') 87 | bpe_toks = [self.INIT + ""] 88 | bpe_toks += self._tok_bpe(tokk) 89 | bpe_toks.append(self.INIT + "") 90 | 91 | elif (is_rel or is_spc or is_frame or is_of): 92 | if is_in_enc: 93 | bpe_toks = [self.INIT + tokk] 94 | elif is_frame: 95 | bpe_toks = self._tok_bpe(tokk[:-3]) + [tokk[-3:]] 96 | elif is_of: 97 | rel = tokk[:-3] 98 | if self.INIT + rel in self.encoder: 99 | bpe_toks = [self.INIT + rel, '-of'] 100 | else: 101 | bpe_toks = [self.INIT + ':'] + self._tok_bpe(rel[1:]) + ['-of'] 102 | elif is_rel: 103 | bpe_toks = [self.INIT + ':'] + self._tok_bpe(tokk[1:]) 104 | else: 105 | print("tok:", tokk) 106 | print(f"is_rel:{is_rel}, is_spc:{is_spc}, is_frame:{is_frame}, is_of:{is_of}") 107 | exit() 108 | raise 109 | else: 110 | if is_in_enc: 111 | bpe_toks = [self.INIT + tokk] 112 | else: 113 | bpe_toks = self._tok_bpe(tokk) 114 | 115 | bpe_tokens.append(bpe_toks) 116 | bpe_tokens = [b for bb in bpe_tokens for b in bb] 117 | bpe_token_ids = [self.encoder.get(b, self.unk_token_id) for b in bpe_tokens] 118 | return bpe_token_ids 119 | 120 | def decode_amr(self, tokens, restore_name_ops=None): 121 | try: 122 | nodes, backreferences = 
postprocessing.decode_into_node_and_backreferences(tokens, self) 123 | except Exception as e: 124 | print('Decoding failure:', file=sys.stderr) 125 | print(e, file=sys.stderr) 126 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None) 127 | try: 128 | graph_ = graph = self._fix_and_make_graph(nodes) 129 | # if collapse_name_ops: 130 | # graph_ = graph = postprocessing._split_name_ops(graph) 131 | except Exception as e: 132 | print('Building failure:', file=sys.stderr) 133 | print(nodes, file=sys.stderr) 134 | print(backreferences, file=sys.stderr) 135 | print(e, file=sys.stderr) 136 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (None, None) 137 | try: 138 | graph, status = postprocessing.connect_graph_if_not_connected(graph) 139 | if status == postprocessing.ParsedStatus.BACKOFF: 140 | print('Reconnection 1 failure:') 141 | print(nodes, file=sys.stderr) 142 | print(backreferences, file=sys.stderr) 143 | print(graph_, file=sys.stderr) 144 | return graph, status, (nodes, backreferences) 145 | except Exception as e: 146 | print('Reconnction 2 failure:', file=sys.stderr) 147 | print(e, file=sys.stderr) 148 | print(nodes, file=sys.stderr) 149 | print(backreferences, file=sys.stderr) 150 | print(graph_, file=sys.stderr) 151 | return postprocessing.BACKOFF, postprocessing.ParsedStatus.BACKOFF, (nodes, backreferences) 152 | 153 | def _fix_and_make_graph(self, nodes): 154 | 155 | nodes_ = [] 156 | for n in nodes: 157 | if isinstance(n, str): 158 | if n.startswith('<') and n.endswith('>') and (not n.startswith('') 174 | if e != len(nxt) -1: 175 | pst = nxt[e+1:] 176 | nxt = nxt[:e+1] 177 | nodes_.append(nxt) 178 | if pst is not None: 179 | nodes_.append(pst) 180 | else: 181 | nodes_.append(nxt) 182 | i += 1 183 | nodes = nodes_ 184 | 185 | i = 1 186 | nodes_ = [nodes[0]] 187 | while i < len(nodes): 188 | nxt = nodes[i] 189 | if isinstance(nxt, str) and nxt.startswith(' 0: 365 | line = line[:i].strip() 366 | break 367 | old_line = line 368 | while True: 369 | open_count = len(re.findall(r'\(', line)) 370 | close_count = len(re.findall(r'\)', line)) 371 | if open_count > close_count: 372 | line += ')' * (open_count - close_count) 373 | elif close_count > open_count: 374 | for i in range(close_count - open_count): 375 | line = line.rstrip(')') 376 | line = line.rstrip(' ') 377 | if old_line == line: 378 | break 379 | old_line = line 380 | """ 381 | 382 | graph = penman.decode(linearized + ' ') 383 | triples = [] 384 | newvars = 2000 385 | for triple in graph.triples: 386 | x, rel, y = triple 387 | if x is None: 388 | pass 389 | elif rel == ':instance' and y is None: 390 | triples.append(penman.Triple(x, rel, 'thing')) 391 | elif y is None: 392 | var = f'z{newvars}' 393 | newvars += 1 394 | triples.append(penman.Triple(x, rel, var)) 395 | triples.append(penman.Triple(var, ':instance', 'thing')) 396 | else: 397 | triples.append(triple) 398 | graph = penman.Graph(triples) 399 | linearized = encode(graph) 400 | 401 | def fix_text(linearized=linearized): 402 | n = 0 403 | def _repl1(match): 404 | nonlocal n 405 | out = match.group(1) + match.group(2) + str(3000 + n) + ' / ' + match.group(2) + match.group(3) 406 | n += 1 407 | return out 408 | linearized = re.sub(r'(\(\s?)([a-z])([^\/:\)]+[:\)])', _repl1, linearized, 409 | flags=re.IGNORECASE | re.MULTILINE) 410 | 411 | def _repl2(match): 412 | return match.group(1) 413 | linearized = re.sub(r'(\(\s*[a-z][\d+]\s*\/\s*[^\s\)\(:\/]+\s*)((?:/\s*[^\s\)\(:\/]+\s*)+)', _repl2, 414 | linearized, 415 | 
flags=re.IGNORECASE | re.MULTILINE) 416 | 417 | # adds a ':' to args w/o it 418 | linearized = re.sub(r'([^:])(ARG)', r'\1 :\2', linearized) 419 | 420 | # removes edges with no node 421 | # linearized = re.sub(r':[^\s\)\(:\/]+?\s*\)', ')', linearized, flags=re.MULTILINE) 422 | 423 | return linearized 424 | 425 | linearized = fix_text(linearized) 426 | g = penman.decode(linearized) 427 | return g 428 | 429 | def _classify(self, node): 430 | if not isinstance(node, str): 431 | return "CONST" 432 | elif node == 'i': 433 | return "I" 434 | elif re.match(r'^[a-z]\d*$', node) is not None: 435 | return "VAR" 436 | elif node[0].isdigit(): 437 | return "CONST" 438 | elif node.startswith('"') and node.endswith('"'): 439 | return "CONST" 440 | elif node in ('+', '-'): 441 | return "CONST" 442 | elif node == ':mode': 443 | return 'MODE' 444 | elif node.startswith(':'): 445 | return "EDGE" 446 | elif node in ['/', '(', ')']: 447 | return node 448 | elif node[0].isalpha(): 449 | for char in (',', ':', '/', '(', ')', '.', '!', '?', '\\'): 450 | if char in node: 451 | return "CONST" 452 | return "INST" 453 | else: 454 | return 'CONST' -------------------------------------------------------------------------------- /fine-tune/train-AMRBART-large-AMR2Text.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=LDC2020 5 | Dataset=LDC2017 6 | 7 | BasePath=/mnt/nfs-storage/data # change dir here 8 | DataPath=$RootDir/data/$Dataset 9 | 10 | ModelCate=AMRBART-large 11 | 12 | MODEL=$1 13 | ModelCache=$BasePath/.cache 14 | DataCache=$DataPath/.cache/dump-amr2text 15 | 16 | lr=2e-6 17 | 18 | OutputDir=${RootDir}/outputs/$Dataset-$ModelCate-AMR2Text-bsz8-lr-${lr}-UnifiedInp 19 | 20 | if [ ! -d ${OutputDir} ];then 21 | mkdir -p ${OutputDir} 22 | else 23 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 24 | case $yn in 25 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 26 | [Nn]* ) echo "exiting..."; exit;; 27 | * ) echo "Please answer yes or no.";; 28 | esac 29 | fi 30 | 31 | export HF_DATASETS_CACHE=$DataCache 32 | 33 | if [ ! 
-d ${DataCache} ];then 34 | mkdir -p ${DataCache} 35 | fi 36 | 37 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 38 | python -u main.py \ 39 | --data_dir $DataPath \ 40 | --task "amr2text" \ 41 | --train_file $DataPath/train.jsonl \ 42 | --validation_file $DataPath/val.jsonl \ 43 | --test_file $DataPath/test.jsonl \ 44 | --output_dir $OutputDir \ 45 | --cache_dir $ModelCache \ 46 | --data_cache_dir $DataCache \ 47 | --model_name_or_path $MODEL \ 48 | --overwrite_output_dir \ 49 | --unified_input True \ 50 | --per_device_train_batch_size 8 \ 51 | --per_device_eval_batch_size 4 \ 52 | --gradient_accumulation_steps 1 \ 53 | --learning_rate $lr \ 54 | --optim "adamw_hf" \ 55 | --lr_scheduler_type "polynomial" \ 56 | --warmup_steps 200 \ 57 | --num_train_epochs 30 \ 58 | --early_stopping 10 \ 59 | --max_source_length 1024 \ 60 | --max_target_length 384 \ 61 | --val_max_target_length 384 \ 62 | --generation_max_length 380 \ 63 | --generation_num_beams 5 \ 64 | --label_smoothing_factor 0.1 \ 65 | --evaluation_strategy "epoch" \ 66 | --weight_decay 0.01 \ 67 | --max_grad_norm 0 \ 68 | --max_steps -1 \ 69 | --predict_with_generate \ 70 | --smart_init False \ 71 | --use_fast_tokenizer False \ 72 | --logging_dir $OutputDir/logs \ 73 | --logging_first_step True \ 74 | --logging_steps 10 \ 75 | --save_strategy "epoch" \ 76 | --save_total_limit 1 \ 77 | --seed 42 \ 78 | --fp16 \ 79 | --fp16_backend "auto" \ 80 | --dataloader_num_workers 8 \ 81 | --eval_dataloader_num_workers 2 \ 82 | --load_best_model_at_end True \ 83 | --metric_for_best_model "eval_bleu" \ 84 | --include_inputs_for_metrics \ 85 | --greater_is_better True \ 86 | --do_train \ 87 | --do_eval \ 88 | --do_predict \ 89 | --ddp_find_unused_parameters False \ 90 | --report_to "tensorboard" \ 91 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log 92 | -------------------------------------------------------------------------------- /fine-tune/train-AMRBART-large-AMRParsing.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 3 | 4 | Dataset=LDC2020 5 | #Dataset=LDC2017 6 | 7 | BasePath=/mnt/nfs-storage/data # change dir here 8 | DataPath=$RootDir/data/$Dataset 9 | 10 | ModelCate=AMRBART-large 11 | 12 | MODEL=$1 13 | ModelCache=$BasePath/.cache 14 | DataCache=$DataPath/.cache/dump-amrparsing 15 | 16 | lr=1e-5 17 | 18 | OutputDir=${RootDir}/outputs/$Dataset-${ModelCate}-AMRParing-bsz16-lr-${lr}-UnifiedInp 19 | 20 | if [ ! -d ${OutputDir} ];then 21 | mkdir -p ${OutputDir} 22 | else 23 | read -p "${OutputDir} already exists, delete origin one [y/n]?" yn 24 | case $yn in 25 | [Yy]* ) rm -rf ${OutputDir}; mkdir -p ${OutputDir};; 26 | [Nn]* ) echo "exiting..."; exit;; 27 | * ) echo "Please answer yes or no.";; 28 | esac 29 | fi 30 | 31 | export HF_DATASETS_CACHE=$DataCache 32 | 33 | if [ ! 
-d ${DataCache} ];then 34 | mkdir -p ${DataCache} 35 | fi 36 | 37 | # torchrun --nnodes=1 --nproc_per_node=1 --max_restarts=0 --rdzv_id=1 --rdzv_backend=c10d main.py \ 38 | python -u main.py \ 39 | --data_dir $DataPath \ 40 | --task "text2amr" \ 41 | --train_file $DataPath/train.jsonl \ 42 | --validation_file $DataPath/val.jsonl \ 43 | --test_file $DataPath/test.jsonl \ 44 | --output_dir $OutputDir \ 45 | --cache_dir $ModelCache \ 46 | --data_cache_dir $DataCache \ 47 | --tokenizer_name "facebook/bart-large" \ 48 | --model_name_or_path $MODEL \ 49 | --overwrite_output_dir \ 50 | --unified_input True \ 51 | --per_device_train_batch_size 16 \ 52 | --per_device_eval_batch_size 8 \ 53 | --gradient_accumulation_steps 1 \ 54 | --learning_rate $lr \ 55 | --optim "adamw_hf" \ 56 | --lr_scheduler_type "polynomial" \ 57 | --warmup_steps 200 \ 58 | --num_train_epochs 30 \ 59 | --early_stopping 10 \ 60 | --max_source_length 400 \ 61 | --max_target_length 1024 \ 62 | --val_max_target_length 1024 \ 63 | --generation_max_length 1024 \ 64 | --generation_num_beams 5 \ 65 | --label_smoothing_factor 0.1 \ 66 | --evaluation_strategy "epoch" \ 67 | --weight_decay 0.01 \ 68 | --max_grad_norm 0 \ 69 | --max_steps -1 \ 70 | --predict_with_generate \ 71 | --smart_init False \ 72 | --use_fast_tokenizer False \ 73 | --logging_dir $OutputDir/logs \ 74 | --logging_first_step True \ 75 | --logging_steps 20 \ 76 | --save_strategy "epoch" \ 77 | --save_total_limit 1 \ 78 | --seed 42 \ 79 | --fp16 \ 80 | --fp16_backend "auto" \ 81 | --dataloader_num_workers 8 \ 82 | --eval_dataloader_num_workers 2 \ 83 | --load_best_model_at_end True \ 84 | --metric_for_best_model "eval_smatch" \ 85 | --include_inputs_for_metrics \ 86 | --greater_is_better True \ 87 | --do_train \ 88 | --do_eval \ 89 | --do_predict \ 90 | --ddp_find_unused_parameters False \ 91 | --report_to "tensorboard" \ 92 | --dataloader_pin_memory True 2>&1 | tee $OutputDir/run.log 93 | -------------------------------------------------------------------------------- /pre-train/common/constant.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import json 4 | 5 | from transformers.optimization import ( 6 | get_cosine_schedule_with_warmup, 7 | get_cosine_with_hard_restarts_schedule_with_warmup, 8 | get_linear_schedule_with_warmup, 9 | get_polynomial_decay_schedule_with_warmup, 10 | get_constant_schedule_with_warmup, 11 | ) 12 | 13 | from transformers import ( 14 | WEIGHTS_NAME, 15 | AdamW, 16 | Adafactor, 17 | AutoConfig, 18 | AutoTokenizer, 19 | AutoModelForSeq2SeqLM, 20 | BartTokenizer, 21 | BartForConditionalGeneration, 22 | T5Tokenizer, 23 | T5Model, 24 | T5ForConditionalGeneration, 25 | ) 26 | 27 | raw_special_tokens = json.load( 28 | open(f"{os.path.dirname(__file__)}/additional-tokens.json", "r", encoding="utf-8") 29 | ) 30 | special_tokens = [itm.lstrip("Ġ") for itm in raw_special_tokens] 31 | 32 | recategorizations = [ 33 | "\u0120COUNTRY", 34 | "\u0120QUANTITY", 35 | "\u0120ORGANIZATION", 36 | "\u0120DATE_ATTRS", 37 | "\u0120NATIONALITY", 38 | "\u0120LOCATION", 39 | "\u0120ENTITY", 40 | "\u0120MISC", 41 | "\u0120ORDINAL_ENTITY", 42 | "\u0120IDEOLOGY", 43 | "\u0120RELIGION", 44 | "\u0120STATE_OR_PROVINCE", 45 | "\u0120CAUSE_OF_DEATH", 46 | "\u0120TITLE", 47 | "\u0120DATE", 48 | "\u0120NUMBER", 49 | "\u0120HANDLE", 50 | "\u0120SCORE_ENTITY", 51 | "\u0120DURATION", 52 | "\u0120ORDINAL", 53 | "\u0120MONEY", 54 | "\u0120CRIMINAL_CHARGE", 55 | ] 56 | 57 | # special_tokens = ["", ""] 58 | 59 | 
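As a generic illustration only (not this repo's wiring), the sketch below shows how a special-token list such as `special_tokens` above could be registered with a stock BART tokenizer and model; the token strings are stand-ins for the full list loaded from `additional-tokens.json`. AMRBART itself extends the vocabulary through its own `AMRBartTokenizer.init_amr_vocabulary`, shown earlier under `fine-tune/model_interface/tokenization_bart.py`.

```python
# Generic illustration -- AMRBART registers its AMR tokens via AMRBartTokenizer instead.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# A few AMR-style tokens as stand-ins for the full additional-tokens.json list.
extra_tokens = [":ARG0", ":ARG1", ":op1", ":quant"]
num_added = tokenizer.add_tokens(extra_tokens)
model.resize_token_embeddings(len(tokenizer))
print(f"added {num_added} tokens; new vocab size: {len(tokenizer)}")
```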
arg_to_scheduler = { 60 | "linear": get_linear_schedule_with_warmup, 61 | "cosine": get_cosine_schedule_with_warmup, 62 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 63 | "polynomial": get_polynomial_decay_schedule_with_warmup, 64 | "constant": get_constant_schedule_with_warmup, 65 | } 66 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) 67 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" 68 | 69 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 70 | 71 | arg_to_tokenizer = { 72 | "AutoTokenizer": AutoTokenizer, 73 | "BartTokenizer": BartTokenizer, 74 | "T5Tokenizer": T5Tokenizer, 75 | } 76 | arg_to_plm_model = { 77 | "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM, 78 | "BartForConditionalGeneration": BartForConditionalGeneration, 79 | "T5Model": T5Model, 80 | "T5ForConditionalGeneration": T5ForConditionalGeneration, 81 | } 82 | -------------------------------------------------------------------------------- /pre-train/common/penman_interface.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | from penman import load as load_, Graph, Triple 3 | from penman import loads as loads_ 4 | from penman import encode as encode_ 5 | from penman.model import Model 6 | from penman.models.noop import NoOpModel 7 | from penman.models import amr 8 | 9 | op_model = Model() 10 | noop_model = NoOpModel() 11 | amr_model = amr.model 12 | DEFAULT = op_model 13 | 14 | 15 | def _get_model(dereify): 16 | if dereify is None: 17 | return DEFAULT 18 | 19 | elif dereify: 20 | return op_model 21 | 22 | else: 23 | return noop_model 24 | 25 | 26 | def _remove_wiki(graph): 27 | metadata = graph.metadata 28 | triples = [] 29 | for t in graph.triples: 30 | v1, rel, v2 = t 31 | if rel == ":wiki": 32 | t = Triple(v1, rel, "+") 33 | triples.append(t) 34 | graph = Graph(triples) 35 | graph.metadata = metadata 36 | return graph 37 | 38 | 39 | def load(source, dereify=None, remove_wiki=False): 40 | model = _get_model(dereify) 41 | out = load_(source=source, model=model) 42 | if remove_wiki: 43 | for i in range(len(out)): 44 | out[i] = _remove_wiki(out[i]) 45 | return out 46 | 47 | 48 | def loads(string, dereify=None, remove_wiki=False): 49 | model = _get_model(dereify) 50 | out = loads_(string=string, model=model) 51 | if remove_wiki: 52 | for i in range(len(out)): 53 | out[i] = _remove_wiki(out[i]) 54 | return out 55 | 56 | 57 | def encode(g, top=None, indent=-1, compact=False): 58 | model = amr_model 59 | return encode_(g=g, top=top, indent=indent, compact=compact, model=model) 60 | -------------------------------------------------------------------------------- /pre-train/common/postprocessing.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import re 3 | import enum 4 | import penman 5 | import networkx as nx 6 | from common.penman_interface import encode 7 | from collections import defaultdict, Counter 8 | 9 | BACKOFF = penman.Graph( 10 | [ 11 | penman.Triple("d2", ":instance", "dog"), 12 | penman.Triple("b1", ":instance", "bark-01"), 13 | penman.Triple("b1", ":ARG0", "d2"), 14 | ] 15 | ) 16 | 17 | 18 | def token_processing(tok): 19 | if tok is None: 20 | return None 21 | elif tok.isdigit(): 22 | try: 23 | return eval(tok) 24 | except: 25 | return tok 26 | elif tok.startswith('"') and (not tok.endswith('"')): 27 | return tok + '"' 28 | elif tok.endswith('"') and (not tok.startswith('"')): 29 | return '"' + tok 30 | else: 31 | return 
tok 32 | 33 | 34 | def decode_into_node_and_backreferences(subtoken_ids, tokenizer): 35 | rex_arg = re.compile(f"^{tokenizer.INIT}(op|snt|conj|prep)") 36 | rex_spc = re.compile(r"<(s|/s|lit|/lit|stop|unk|pad|mask)>") 37 | 38 | subtoken_ids.insert(1,36) # add "(" id 39 | subtoken_ids.insert(-1, 4839) # add ")" id 40 | 41 | # get strings 42 | subtokens = [tokenizer.decoder.get(t) for t in subtoken_ids] 43 | # print("subtokens:", subtokens) 44 | # fix backreferences 45 | 46 | subtoken_backreferences = [max(t - len(tokenizer.encoder), -1) for t in subtoken_ids] 47 | # strip padding 48 | subtokens, subtoken_backreferences = zip( 49 | *[ 50 | (s, b) 51 | for s, b in zip(subtokens, subtoken_backreferences) 52 | if s != ("") 53 | ] 54 | ) 55 | 56 | # subword collapse 57 | tokens = [] 58 | backreferences = [] 59 | subword_to_token_map = {} 60 | current_token_i = 0 61 | for subw_i, (subw_backr, subtok) in enumerate(zip(subtoken_backreferences, subtokens)): 62 | subword_to_token_map[subw_i] = current_token_i 63 | 64 | # if empty you cannot do anything but add a new word 65 | if not tokens: 66 | tokens.append(subtok.lstrip(tokenizer.INIT)) 67 | backreferences.append(-1) 68 | current_token_i += 1 69 | 70 | # backref can't be splitted 71 | elif subw_backr > -1: 72 | tokens.append(None) 73 | backreferences.append(subword_to_token_map[subw_backr]) 74 | current_token_i += 1 75 | 76 | # after a special token release 77 | elif isinstance(tokens[-1], str) and rex_spc.match(tokens[-1]): 78 | tokens.append(subtok.lstrip(tokenizer.INIT)) 79 | backreferences.append(-1) 80 | current_token_i += 1 81 | 82 | # after a subtoken ':' (which should be followed by the rest of the edge) ignore tokenizer.INIT 83 | # TODO: this is an ugly patch due to the fact that BART tokenizer splits after ':' 84 | elif (tokens[-1] == ":") and rex_arg.match(subtok): 85 | tokens[-1] = tokens[-1] + subtok[1:] 86 | 87 | # leading tokenizer.INIT 88 | elif subtok.startswith(tokenizer.INIT): 89 | tokens.append(subtok.lstrip(tokenizer.INIT)) 90 | backreferences.append(-1) 91 | current_token_i += 1 92 | 93 | # very ugly patch for some cases in which tokenizer.INIT is not in the following token to the edge 94 | elif ( 95 | isinstance(tokens[-1], str) 96 | and tokens[-1].startswith(":") 97 | and tokens[-1][-1].isdigit() 98 | and (subtok != "-of") 99 | ): 100 | tokens.append(subtok.lstrip(tokenizer.INIT)) 101 | backreferences.append(-1) 102 | current_token_i += 1 103 | 104 | # in any other case attach to the previous 105 | else: 106 | tokens[-1] = tokens[-1] + subtok 107 | 108 | # strip INIT and fix byte-level 109 | tokens = [ 110 | tokenizer.convert_tokens_to_string(list(t)).lstrip() if isinstance(t, str) else t 111 | for t in tokens 112 | ] 113 | # tokens = [t.replace(tokenizer.INIT, '') if isinstance(t, str) else t for t in tokens] 114 | 115 | # unks are substituted with thing 116 | tokens = [t if t != "" else "thing" for t in tokens] 117 | 118 | old_tokens = tokens 119 | old_backreferences = backreferences 120 | 121 | # Barack Obama -> "Barack Obama" 122 | tokens = [] 123 | backreferences = [] 124 | token_to_token_map = {} 125 | start_search = 0 126 | removed = 0 127 | while True: 128 | try: 129 | 130 | lit_start = old_tokens.index("", start_search) 131 | token_addition = old_tokens[start_search:lit_start] 132 | for i, t in enumerate(token_addition, start=start_search): 133 | token_to_token_map[i] = i - removed 134 | tokens += token_addition 135 | 136 | backreferences_addition = [ 137 | token_to_token_map[b] if b > -1 else -1 138 | for b in 
old_backreferences[start_search:lit_start] 139 | ] 140 | backreferences += backreferences_addition 141 | 142 | lit_end = min(lit_start + 2, len(old_tokens) - 1) 143 | 144 | while lit_end < len(old_tokens): 145 | old_tok = old_tokens[lit_end] 146 | 147 | if isinstance(old_tok, str) and ( 148 | (old_tok.startswith(":") and len(old_tok) > 3) or (old_tok == "") 149 | ): 150 | res_tok = old_tokens[lit_start + 1 : lit_end] 151 | for i in range(lit_start, lit_end): 152 | token_to_token_map[i] = len(tokens) 153 | 154 | # Remove possible wrong None 155 | res = old_tokens[lit_start + 1 : lit_end] 156 | res = [str(r) for r in res if r is not None] 157 | res = '"' + "_".join(res) + '"' 158 | 159 | removed += len(res_tok) 160 | start_search = lit_end 161 | tokens += [res, old_tok] 162 | backreferences += [-1, -1] 163 | break 164 | 165 | elif old_tok == "": 166 | res_tok = old_tokens[lit_start + 1 : lit_end] 167 | for i in range(lit_start, lit_end + 1): 168 | token_to_token_map[i] = len(tokens) 169 | 170 | # Remove possible wrong None 171 | res = old_tokens[lit_start + 1 : lit_end] 172 | res = [str(r) for r in res if r is not None] 173 | res = '"' + "_".join(res) + '"' 174 | 175 | removed += len(res_tok) + 1 176 | start_search = lit_end + 1 177 | tokens.append(res) 178 | backreferences.append(-1) 179 | break 180 | 181 | else: 182 | lit_end += 1 183 | start_search = lit_end 184 | 185 | except ValueError: 186 | token_addition = old_tokens[start_search:] 187 | for i, t in enumerate(token_addition, start=start_search): 188 | token_to_token_map[i] = i - removed 189 | backreferences_addition = [ 190 | token_to_token_map[b] if b > -1 else b for b in old_backreferences[start_search:] 191 | ] 192 | tokens += token_addition 193 | backreferences += backreferences_addition 194 | break 195 | 196 | tokens = [token_processing(t) for t in tokens] 197 | 198 | shift = 1 199 | if tokens[1] == "": 200 | shift = 2 201 | 202 | tokens = tokens[shift:] 203 | backreferences = [b if b == -1 else b - shift for b in backreferences[shift:]] 204 | 205 | if tokens[-1] == "": 206 | tokens.pop() 207 | backreferences.pop() 208 | 209 | return tokens, backreferences 210 | 211 | 212 | def index_of(element, iterable, default=None, start=None, end=None): 213 | if not callable(element): 214 | 215 | def check(x): 216 | return element == x 217 | 218 | else: 219 | check = element 220 | if start is None: 221 | start = 0 222 | if end is None: 223 | end = len(iterable) 224 | item = start 225 | while item < end: 226 | if check(iterable[item]): 227 | return item 228 | item += 1 229 | return default 230 | 231 | 232 | def separate_edges_nodes(edges_nodes_slice, *other): 233 | is_arg = lambda x: isinstance(x, str) and x.startswith(":") 234 | start = 0 235 | edges = [] 236 | nodes = [] 237 | l = len(edges_nodes_slice) 238 | while start < l: 239 | edge_index = index_of(is_arg, edges_nodes_slice, start=start) 240 | if edge_index is None or edge_index == (l - 1): 241 | break 242 | if is_arg(edges_nodes_slice[edge_index + 1]): 243 | start = edge_index + 1 244 | continue 245 | edges.append(edge_index) 246 | nodes.append(edge_index + 1) 247 | start = edge_index + 2 248 | ret = [] 249 | for oth in other: 250 | edges_oth = [oth[i] for i in edges] 251 | nodes_oth = [oth[i] for i in nodes] 252 | ret.append((edges_oth, nodes_oth)) 253 | return ret 254 | 255 | 256 | def _split_name_ops(graph): 257 | # identify name triples 258 | name_vars = {} 259 | for i, (v1, rel, v2) in enumerate(graph.triples): 260 | if rel == ":instance" and v2 == "name": 261 | name_vars[v1] = 
1 262 | 263 | # check if they have ops 264 | name_vars_to_ops = defaultdict(list) 265 | for i, (v1, rel, v2) in enumerate(graph.triples): 266 | if v1 in name_vars and rel.startswith(":op"): 267 | name_vars_to_ops[v1].append((i, rel, v2.strip('"'))) 268 | 269 | triples = graph.triples.copy() 270 | for nv, ops in name_vars_to_ops.items(): 271 | ops = sorted(ops, key=lambda x: int(x[1][3:])) 272 | idx, _, lits = zip(*ops) 273 | for i in idx: 274 | triples[i] = None 275 | 276 | lits = ['"' + l + '"' for lit in lits for l in lit.split("_")] 277 | 278 | tt = [] 279 | for i, l in enumerate(lits, start=1): 280 | rel = ":op" + str(i) 281 | tt.append(penman.Triple(nv, rel, l)) 282 | 283 | triples[min(idx)] = tt 284 | 285 | triples = [t if isinstance(t, list) else [t] for t in triples if t is not None] 286 | triples = [t for tt in triples for t in tt] 287 | 288 | graph_ = penman.Graph(triples) 289 | graph_.metadata = graph.metadata 290 | return graph_ 291 | 292 | 293 | def _reconstruct_graph_from_nodes(nodes, backreferences): 294 | triples = [] 295 | triples_added = set() 296 | 297 | variable2index = {} 298 | index2variable = {} 299 | start_index = 0 300 | 301 | cnt = defaultdict(Counter) 302 | 303 | while start_index < len(nodes): 304 | stop_index = index_of("", nodes, default=len(nodes) + 1, start=start_index) 305 | old_start_index = start_index 306 | start_index = stop_index + 1 307 | 308 | src_node, src_backr = nodes[old_start_index], backreferences[old_start_index] 309 | 310 | if src_node == "": 311 | continue 312 | 313 | trg_nodes_edges = nodes[old_start_index:stop_index] 314 | trg_nodes_edges_backr = backreferences[old_start_index:stop_index] 315 | trg_nodes_edges_indices = list(range(old_start_index, stop_index)) 316 | 317 | if isinstance(src_node, str): 318 | if src_node in ("", "", ""): 319 | continue 320 | elif ("/" in src_node) or (":" in src_node) or ("(" in src_node) or (")" in src_node): 321 | src_node = "thing" 322 | 323 | if src_node is not None: 324 | src_node = str(src_node) 325 | src_var = src_node[0].lower() 326 | if not src_var not in "abcdefghijklmnopqrstuvwxyz": 327 | src_var = "x" 328 | # src_var = f'{src_var}_{len(variable2index)}' 329 | src_var = f"{src_var}{len(variable2index)}" 330 | src_var_i = old_start_index 331 | variable2index[src_var] = src_var_i 332 | index2variable[src_var_i] = src_var 333 | triple = penman.Triple(src_var, ":instance", src_node) 334 | if triple not in triples_added: 335 | triples.append(triple) 336 | triples_added.add(triple) 337 | else: 338 | if src_backr in index2variable: 339 | src_var = index2variable[src_backr] 340 | # more resilient logic here 341 | (trg_edges, trg_nodes), (_, trg_nodes_backr), (_, trg_nodes_indices) = separate_edges_nodes( 342 | trg_nodes_edges, trg_nodes_edges, trg_nodes_edges_backr, trg_nodes_edges_indices 343 | ) 344 | 345 | for n, e, nb, ni in zip(trg_nodes, trg_edges, trg_nodes_backr, trg_nodes_indices): 346 | 347 | if isinstance(n, str) and n.startswith(":"): 348 | continue 349 | if isinstance(n, str) and n.startswith("<") and n.endswith(">"): 350 | continue 351 | if e == ":li": 352 | pass 353 | elif len(e) < 4 or (not e.startswith(":")): 354 | continue 355 | 356 | # same edge more than once 357 | num = cnt[src_var][e] 358 | # num = 0 359 | if num: 360 | 361 | if e.startswith(":op") or e.startswith(":snt"): 362 | continue 363 | # elif e.startswith(':ARG'): 364 | # continue 365 | elif num > 3: 366 | continue 367 | 368 | if n is None: 369 | if nb not in index2variable: 370 | continue 371 | trg_var = index2variable[nb] 
372 | trg = trg_var 373 | elif e == ":mode": 374 | trg = n 375 | elif ( 376 | (not isinstance(n, str)) 377 | or re.match(r"^[+-]?\d+\.?\d*$", n) 378 | or (n == "-") 379 | or (n == "+") 380 | ): 381 | trg = str(n) 382 | elif n.startswith('"') and n.endswith('"') and len(n) > 2: 383 | trg = '"' + n.replace('"', "") + '"' 384 | elif ("/" in n) or (":" in n) or ("(" in n) or (")" in n) or ("=" in n): 385 | trg = f'"{n}"' 386 | elif n == '"': 387 | continue 388 | elif ( 389 | (n.startswith('"') and (not n.endswith('"'))) 390 | or (not n.startswith('"') and (n.endswith('"'))) 391 | or ('"' in n) 392 | ): 393 | trg = '"' + n.replace('"', "") + '"' 394 | else: 395 | trg_var = n[0].lower() 396 | if trg_var not in "abcdefghijklmnopqrstuvwxyz": 397 | trg_var = "x" 398 | # trg_var = f'{trg_var}_{len(variable2index)}' 399 | trg_var = f"{trg_var}{len(variable2index)}" 400 | trg_var_i = ni 401 | variable2index[trg_var] = trg_var_i 402 | index2variable[trg_var_i] = trg_var 403 | triple = penman.Triple(trg_var, ":instance", n) 404 | if triple not in triples_added: 405 | triples.append(triple) 406 | triples_added.add(triple) 407 | trg = trg_var 408 | 409 | triple = penman.Triple(src_var, e, trg) 410 | if triple not in triples_added: 411 | triples.append(triple) 412 | triples_added.add(triple) 413 | 414 | cnt[src_var][e] += 1 415 | 416 | return penman.Graph(triples) 417 | 418 | 419 | def build_graph(nodes, backreferences, restore_name_ops=False): 420 | graph = _reconstruct_graph_from_nodes(nodes, backreferences) 421 | if restore_name_ops: 422 | graph = _split_name_ops(graph) 423 | return graph 424 | 425 | 426 | class ParsedStatus(enum.Enum): 427 | OK = 0 428 | FIXED = 1 429 | BACKOFF = 2 430 | 431 | 432 | def connect_graph_if_not_connected(graph): 433 | 434 | try: 435 | encoded = encode(graph) 436 | return graph, ParsedStatus.OK 437 | except: 438 | pass 439 | 440 | nxgraph = nx.MultiGraph() 441 | variables = graph.variables() 442 | for v1, _, v2 in graph.triples: 443 | if v1 in variables and v2 in variables: 444 | nxgraph.add_edge(v1, v2) 445 | elif v1 in variables: 446 | nxgraph.add_edge(v1, v1) 447 | 448 | triples = graph.triples.copy() 449 | new_triples = [] 450 | addition = f"a{len(variables) + 1}" 451 | triples.append(penman.Triple(addition, ":instance", "and")) 452 | for i, conn_set in enumerate(nx.connected_components(nxgraph), start=1): 453 | edge = f":op{i}" 454 | conn_set = sorted(conn_set, key=lambda x: int(x[1:])) 455 | conn_set = [c for c in conn_set if c in variables] 456 | node = conn_set[0] 457 | new_triples.append(penman.Triple(addition, edge, node)) 458 | triples = new_triples + triples 459 | metadata = graph.metadata 460 | graph = penman.Graph(triples) 461 | graph.metadata.update(metadata) 462 | encode(graph) 463 | 464 | return graph, ParsedStatus.FIXED 465 | 466 | 467 | def restore_backreferences_from_pointers(nodes): 468 | new_nodes, new_backreferences = [], [] 469 | prev_pointer = None 470 | pointer2i = {} 471 | for n in nodes: 472 | is_pointer = isinstance(n, str) and n.startswith("") 473 | 474 | if not is_pointer: 475 | if prev_pointer is not None: 476 | if prev_pointer in pointer2i: 477 | new_nodes.append(None) 478 | new_backreferences.append(pointer2i[prev_pointer]) 479 | new_nodes.append(n) 480 | new_backreferences.append(-1) 481 | 482 | else: 483 | pointer2i[prev_pointer] = len(new_nodes) 484 | new_nodes.append(n) 485 | new_backreferences.append(-1) 486 | else: 487 | new_nodes.append(n) 488 | new_backreferences.append(-1) 489 | 490 | prev_pointer = None 491 | else: 492 | 
prev_pointer = n 493 | return new_nodes, new_backreferences 494 | -------------------------------------------------------------------------------- /pre-train/data_interface/amrdata.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """AMR dataset.""" 18 | 19 | 20 | from inspect import EndOfBlock 21 | import json 22 | import os 23 | 24 | import datasets 25 | 26 | logger = datasets.logging.get_logger(__name__) 27 | 28 | 29 | _DESCRIPTION = """ 30 | PubMed articles. 31 | 32 | There are three features: 33 | - src: source text. 34 | - tgt: AMR Graph. 35 | """ 36 | 37 | 38 | _SRC = "amr" 39 | _TGT = "text" 40 | 41 | 42 | class AMRData(datasets.GeneratorBasedBuilder): 43 | 44 | # Version 1.2.0 expands coverage, includes ids, and removes web contents. 45 | VERSION = datasets.Version("1.2.0") 46 | 47 | def _info(self): 48 | return datasets.DatasetInfo( 49 | description=_DESCRIPTION, 50 | features=datasets.Features( 51 | { 52 | _SRC: datasets.Value("string"), 53 | _TGT: datasets.Value("string"), 54 | } 55 | ), 56 | supervised_keys=None, 57 | ) 58 | 59 | def _split_generators(self, dl_manager): 60 | """Returns SplitGenerators.""" 61 | 62 | train_path = self.config.data_files["train"] 63 | dev_path = self.config.data_files["validation"] 64 | test_path = self.config.data_files["test"] 65 | return [ 66 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}), 67 | datasets.SplitGenerator( 68 | name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dev_path} 69 | ), 70 | datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}), 71 | ] 72 | 73 | def _generate_examples(self, filepath): 74 | """Yields examples.""" 75 | logger.info("generating examples from = %s", filepath[0]) 76 | with open(filepath[0], "r", encoding="utf-8") as f: 77 | lines = f.readlines() 78 | for idx, line in enumerate(lines): 79 | json_dict = json.loads(line) 80 | src = json_dict["amr"] 81 | tgt = json_dict["sent"] 82 | yield idx, {_SRC: src, _TGT: tgt} -------------------------------------------------------------------------------- /pre-train/data_interface/dataset.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import os 3 | import torch 4 | from datasets import load_dataset 5 | from dataclasses import dataclass 6 | from transformers.file_utils import PaddingStrategy 7 | from transformers.modeling_utils import PreTrainedModel 8 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 9 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 10 | 11 | 12 | class AMRDataSet(torch.nn.Module): 13 | def __init__( 14 | self, 15 | tokenizer, 16 | train_file, 17 | validation_file, 18 | test_file, 19 | 
prefix="", 20 | pad_to_max_length=True, 21 | max_src_length=512, 22 | max_tgt_length=512, 23 | ignore_pad_token_for_loss=True, 24 | ): 25 | super().__init__() 26 | self.train_file = train_file 27 | self.validation_file = validation_file 28 | self.test_file = test_file 29 | self.tokenizer = tokenizer 30 | self.prefix = prefix 31 | self.pad_to_max_length = pad_to_max_length 32 | self.ignore_pad_token_for_loss = ignore_pad_token_for_loss 33 | self.max_src_length = max_src_length 34 | self.max_tgt_length = max_tgt_length 35 | 36 | def setup(self, stage="fit"): 37 | data_files = {} 38 | data_files["train"] = self.train_file 39 | data_files["validation"] = self.validation_file 40 | data_files["test"] = self.test_file 41 | 42 | datasets = load_dataset(f"{os.path.dirname(__file__)}/amrdata.py", data_files=data_files, keep_in_memory=True) 43 | print("datasets:", datasets) 44 | column_names = datasets["train"].column_names 45 | print("colums:", column_names) 46 | padding = "max_length" if self.pad_to_max_length else False 47 | 48 | def tokenize_function(examples): 49 | # Remove empty lines 50 | amrs = examples["amr"] # AMR tokens 51 | sents = examples["text"] # text tokens 52 | sents = [self.prefix + inp for inp in sents] 53 | 54 | model_inputs = self.tokenizer( 55 | sents, max_length=self.max_src_length, padding=False, truncation=True 56 | ) 57 | amr_ids = [self.tokenizer.tokenize_amr(itm.split())[:self.max_src_length - 1] + [self.tokenizer.amr_eos_token_id] for itm in amrs] 58 | model_inputs["labels"] = amr_ids 59 | 60 | joint_ids = [ 61 | srci + [self.tokenizer.amr_bos_token_id] + tgti 62 | for srci, tgti in zip(model_inputs["input_ids"], model_inputs["labels"]) 63 | ] # [ x1,x2...,xn y1,y2,...ym ] 64 | 65 | max_src_length = min(self.max_src_length * 2, 512) 66 | joint_ids = [ 67 | itm[:max_src_length - 1] + [self.tokenizer.amr_eos_token_id] 68 | if len(itm) > max_src_length 69 | else itm 70 | for itm in joint_ids 71 | ] 72 | seg_ids = [ 73 | [0 for _ in range(len(srci))] + [1 for _ in range(len(tgti) + 1)] 74 | for srci, tgti in zip(model_inputs["input_ids"], model_inputs["labels"]) 75 | ] # [0,0,...,0,1,1,...1] 76 | seg_ids = [itm[:max_src_length] for itm in seg_ids] 77 | model_inputs["joint_ids"] = joint_ids 78 | model_inputs["seg_ids"] = seg_ids 79 | srcEtgt_ids = [ 80 | srci[: self.max_src_length - 4] 81 | + [ 82 | self.tokenizer.eos_token_id, 83 | self.tokenizer.amr_bos_token_id, 84 | self.tokenizer.mask_token_id, 85 | self.tokenizer.amr_eos_token_id, 86 | ] 87 | if len(srci) > self.max_src_length - 3 88 | else srci 89 | + [ 90 | self.tokenizer.amr_bos_token_id, 91 | self.tokenizer.mask_token_id, 92 | self.tokenizer.amr_eos_token_id, 93 | ] 94 | for srci in model_inputs["input_ids"] 95 | ] # [ x1,x2...,xn <\s> [mask] ] 96 | Esrctgt_ids = [ 97 | [ 98 | self.tokenizer.bos_token_id, 99 | self.tokenizer.mask_token_id, 100 | self.tokenizer.eos_token_id, 101 | self.tokenizer.amr_bos_token_id 102 | ] 103 | + tgti 104 | if len(tgti) <= self.max_src_length - 4 105 | else 106 | [ 107 | self.tokenizer.bos_token_id, 108 | self.tokenizer.mask_token_id, 109 | self.tokenizer.eos_token_id, 110 | self.tokenizer.amr_bos_token_id 111 | ] 112 | + tgti[: self.max_src_length - 5] 113 | + [self.tokenizer.amr_eos_token_id] 114 | for tgti in model_inputs["labels"] 115 | ] # [ [mask] <\s> y1,y2...,yn ] 116 | 117 | Esrctgt_segids = [ 118 | [0 for _ in range(3)] + [1 for _ in range(len(itm) - 3)] 119 | for itm in Esrctgt_ids 120 | ] 121 | srcEtgt_segids = [ 122 | [0 for _ in range(len(itm) - 3)] + [1 for _ in 
range(3)] 123 | for itm in srcEtgt_ids 124 | ] 125 | model_inputs["srcEtgt_ids"] = srcEtgt_ids 126 | model_inputs["srcEtgt_segids"] = srcEtgt_segids 127 | model_inputs["Esrctgt_ids"] = Esrctgt_ids 128 | model_inputs["Esrctgt_segids"] = Esrctgt_segids 129 | return model_inputs 130 | 131 | self.train_dataset = datasets["train"].map( 132 | tokenize_function, batched=True, remove_columns=["amr", "text"], num_proc=8 133 | ) 134 | print(f"ALL {len(self.train_dataset)} training instances") 135 | self.valid_dataset = datasets["validation"].map( 136 | tokenize_function, batched=True, remove_columns=["amr", "text"], num_proc=8 137 | ) 138 | print(f"ALL {len(self.valid_dataset)} validation instances") 139 | 140 | self.test_dataset = datasets["test"].map( 141 | tokenize_function, batched=True, remove_columns=["amr", "text"], num_proc=8 142 | ) 143 | print(f"ALL {len(self.test_dataset)} test instances") 144 | 145 | print("Dataset Instance Example:", self.train_dataset[0]) 146 | 147 | 148 | def padding_func(features, padding_side="right", pad_token_id=1, key="label"): 149 | assert key in features[0].keys(), f"{key} not in {features[0].keys()}" 150 | max_label_length = max(len(feature[key]) for feature in features) 151 | for feature in features: 152 | remainder = [pad_token_id] * (max_label_length - len(feature[key])) 153 | feature[key] = ( 154 | feature[key] + remainder if padding_side == "right" else remainder + feature[key] 155 | ) 156 | return 157 | 158 | 159 | @dataclass 160 | class DataCollatorForSeq2Seq: 161 | """ 162 | Data collator that will dynamically pad the inputs received, as well as the labels. 163 | 164 | Args: 165 | tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): 166 | The tokenizer used for encoding the data. 167 | model (:class:`~transformers.PreTrainedModel`): 168 | The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to 169 | prepare the `decoder_input_ids` 170 | 171 | This is useful when using `label_smoothing` to avoid calculating loss twice. 172 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): 173 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 174 | among: 175 | 176 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 177 | sequence is provided). 178 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 179 | maximum acceptable input length for the model if that argument is not provided. 180 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 181 | different lengths). 182 | max_length (:obj:`int`, `optional`): 183 | Maximum length of the returned list and optionally padding length (see above). 184 | pad_to_multiple_of (:obj:`int`, `optional`): 185 | If set will pad the sequence to a multiple of the provided value. 186 | 187 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 188 | 7.5 (Volta). 189 | label_pad_token_id (:obj:`int`, `optional`, defaults to -100): 190 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 
191 | """ 192 | 193 | tokenizer: PreTrainedTokenizerBase 194 | model: Optional[PreTrainedModel] = None 195 | padding: Union[bool, str, PaddingStrategy] = True 196 | max_length: Optional[int] = None 197 | pad_to_multiple_of: Optional[int] = None 198 | label_pad_token_id: int = -100 199 | 200 | def __call__(self, features): 201 | padding_func( 202 | features, 203 | padding_side=self.tokenizer.padding_side, 204 | pad_token_id=self.label_pad_token_id, 205 | key="labels", 206 | ) 207 | padding_func( 208 | features, 209 | padding_side=self.tokenizer.padding_side, 210 | pad_token_id=self.tokenizer.pad_token_id, 211 | key="joint_ids", 212 | ) 213 | padding_func( 214 | features, 215 | padding_side=self.tokenizer.padding_side, 216 | pad_token_id=self.tokenizer.pad_token_id, 217 | key="seg_ids", 218 | ) 219 | padding_func( 220 | features, 221 | padding_side=self.tokenizer.padding_side, 222 | pad_token_id=self.tokenizer.pad_token_id, 223 | key="srcEtgt_ids", 224 | ) 225 | padding_func( 226 | features, 227 | padding_side=self.tokenizer.padding_side, 228 | pad_token_id=self.tokenizer.pad_token_id, 229 | key="srcEtgt_segids", 230 | ) 231 | padding_func( 232 | features, 233 | padding_side=self.tokenizer.padding_side, 234 | pad_token_id=self.tokenizer.pad_token_id, 235 | key="Esrctgt_ids", 236 | ) 237 | padding_func( 238 | features, 239 | padding_side=self.tokenizer.padding_side, 240 | pad_token_id=self.tokenizer.pad_token_id, 241 | key="Esrctgt_segids", 242 | ) 243 | features = self.tokenizer.pad( 244 | features, 245 | padding=self.padding, 246 | max_length=self.max_length, 247 | pad_to_multiple_of=self.pad_to_multiple_of, 248 | return_tensors="pt", 249 | ) 250 | 251 | return { 252 | "input_ids": features["input_ids"], 253 | "labels": features["labels"], 254 | "joint_ids": features["joint_ids"], 255 | "seg_ids": features["seg_ids"], 256 | "srcEtgt_ids": features["srcEtgt_ids"], 257 | "srcEtgt_segids": features["srcEtgt_segids"], 258 | "Esrctgt_ids": features["Esrctgt_ids"], 259 | "Esrctgt_segids": features["Esrctgt_segids"], 260 | } 261 | -------------------------------------------------------------------------------- /pre-train/run-posttrain-bart-textinf-joint-denoising-6task-large-unified-A100.sh: -------------------------------------------------------------------------------- 1 | RootDir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 2 | 3 | dataset=Giga 4 | DataPath=$RootDir/data/$dataset 5 | 6 | MODEL=$1 7 | interval=1 8 | 9 | lr=5e-5 10 | 11 | model_size="large" 12 | 13 | outpath=output/${dataset}-bart-${model_size}-Unifiedtextinf-JointDenoise-6task-${lr}-AMREOS 14 | DataCache=$DataPath/.cache 15 | 16 | mkdir -p $outpath 17 | echo "OutputDir: $outpath" 18 | 19 | if [ ! 
-d ${DataCache} ];then 20 | mkdir -p ${DataCache} 21 | fi 22 | 23 | export HF_DATASETS_CACHE=$DataCache 24 | 25 | CUDA_VISIBLE_DEVICES=0,1 python -u -m torch.distributed.launch --nproc_per_node=2 run_multitask_unified_pretraining.py \ 26 | --train_file $DataPath/train.jsonl \ 27 | --val_file $DataPath/val.jsonl \ 28 | --test_file $DataPath/test.jsonl \ 29 | --output_dir $outpath \ 30 | --mlm \ 31 | --mlm_amr \ 32 | --mlm_text \ 33 | --mlm_amr_plus_text \ 34 | --mlm_text_plus_amr \ 35 | --mlm_joint_to_amr \ 36 | --mlm_joint_to_text \ 37 | --block_size 512 \ 38 | --per_gpu_train_batch_size 4 \ 39 | --gradient_accumulation_steps 4 \ 40 | --model_type "facebook/bart-${model_size}" \ 41 | --model_name_or_path $MODEL \ 42 | --save_total_limit 2 \ 43 | --do_train \ 44 | --do_eval \ 45 | --evaluate_during_training \ 46 | --num_train_epochs 100 \ 47 | --learning_rate $lr \ 48 | --joint_train_interval $interval \ 49 | --warmup_steps 2500 \ 50 | --max_steps 100000 \ 51 | --logging_steps 1000 \ 52 | --fp16 \ 53 | --overwrite_output_dir 2>&1 | tee $outpath/run.log -------------------------------------------------------------------------------- /pre-train/run-posttrain-bart-textinf-joint-denoising-6task-large-unified-V100.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | dataset=Giga 3 | datapath=../data/$dataset 4 | MODEL=$1 5 | interval=1 6 | 7 | lr=5e-5 8 | 9 | model_size="large" 10 | 11 | outpath=output/${dataset}-bart-${model_size}-Unifiedtextinf-JointDenoise-6task-${lr}-AMREOS 12 | 13 | mkdir -p $outpath 14 | echo "OutputDir: $outpath" 15 | 16 | python -u -m torch.distributed.launch --nproc_per_node=8 run_multitask_unified_pretraining.py \ 17 | --train_file $datapath/train.jsonl \ 18 | --val_file $datapath/val.jsonl \ 19 | --test_file $datapath/test.jsonl \ 20 | --output_dir $outpath \ 21 | --mlm \ 22 | --mlm_amr \ 23 | --mlm_text \ 24 | --mlm_amr_plus_text \ 25 | --mlm_text_plus_amr \ 26 | --mlm_joint_to_amr \ 27 | --mlm_joint_to_text \ 28 | --block_size 512 \ 29 | --per_gpu_train_batch_size 2 \ 30 | --gradient_accumulation_steps 2 \ 31 | --model_type "facebook/bart-${model_size}" \ 32 | --model_name_or_path $MODEL \ 33 | --save_total_limit 2 \ 34 | --do_train \ 35 | --do_eval \ 36 | --evaluate_during_training \ 37 | --num_train_epochs 100 \ 38 | --learning_rate $lr \ 39 | --joint_train_interval $interval \ 40 | --warmup_steps 2500 \ 41 | --max_steps 100000 \ 42 | --logging_steps 1000 \ 43 | --fp16 \ 44 | --overwrite_output_dir 2>&1 | tee $outpath/run.log 45 | -------------------------------------------------------------------------------- /requirements.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | dependencies: 3 | - nvidia::cudatoolkit=11.1.1 4 | - numpy 5 | - pillow 6 | - pip 7 | - python=3.8 8 | - pytorch::pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0 9 | - scipy 10 | - tqdm 11 | - scikit-learn 12 | - gensim 13 | - pandas 14 | - tensorboard 15 | - tensorboardX 16 | - pip: 17 | - pillow-simd 18 | - h5py-cache 19 | - configargparse 20 | - sacrebleu 21 | - rouge-score 22 | - datasets==2.4.0 23 | - transformers==4.21.3 24 | - nltk 25 | - cached_property 26 | - networkx 27 | - penman>=1.1.0 28 | - pytorch-ignite 29 | - regex 30 | - smatch 31 | - wandb 32 | - amrlib 33 | - PyYAML>=5.1 34 | --------------------------------------------------------------------------------
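Usage note (illustrative, not part of the repository): the helpers in `pre-train/common/penman_interface.py` are thin wrappers over the `penman` library that optionally strip `:wiki` links and always encode with the AMR model. A minimal sketch of how they might be called, assuming the working directory is `pre-train/` (so `common` is importable) and `penman>=1.1.0` from `requirements.yml` is installed; the AMR string below is a made-up example:

```python
from common.penman_interface import loads, encode

# Hypothetical PENMAN-serialized AMR; any valid graph works here.
amr = """
(p / person
   :name (n / name :op1 "Obama")
   :wiki "Q76")
"""

graphs = loads(amr, remove_wiki=True)  # remove_wiki=True replaces every :wiki value with "+"
print(encode(graphs[0]))               # re-serialized with the AMR model (the wrapper passes indent=-1)
```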
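Similarly, `padding_func` in `pre-train/data_interface/dataset.py` pads one key of every feature dict in place to the longest sequence in the batch; `DataCollatorForSeq2Seq.__call__` applies it to `labels`, `joint_ids`, `seg_ids`, and the `srcEtgt`/`Esrctgt` variants before handing the batch to `tokenizer.pad`. A toy trace, again assuming `pre-train/` is on the Python path and the dependencies from `requirements.yml` are installed:

```python
from data_interface.dataset import padding_func

# Two toy examples with label sequences of different lengths.
features = [
    {"labels": [5, 6, 7]},
    {"labels": [8]},
]

# Right-pad "labels" to the batch maximum with -100 (ignored by PyTorch loss functions).
padding_func(features, padding_side="right", pad_token_id=-100, key="labels")
print(features)
# [{'labels': [5, 6, 7]}, {'labels': [8, -100, -100]}]
```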