├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── conda-env.txt ├── install_tools.sh ├── scripts ├── download_data.sh ├── gxlt_classify.sh ├── mtop.sh ├── panx.sh └── xlt_classify.sh ├── third_party ├── __init__.py ├── classify.py ├── graph_utils.py ├── gxlt_classify.py ├── modeling_bert.py ├── ner.py ├── processors │ ├── __init__.py │ ├── constants.py │ ├── tree.py │ ├── utils_classify.py │ ├── utils_tag.py │ └── utils_top.py ├── top.py └── ud-conversion-tools │ ├── conllu_to_conll.py │ └── lib │ ├── __init__.py │ └── conll.py ├── udpipe ├── conllify.py ├── mtop.sh ├── panx.sh ├── pawsx.sh ├── process.py └── xnli.sh └── utils_preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | #lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
132 | .idea/*
133 | *.pyc
134 | */__pycache__/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Syntax-augmented Multilingual BERT
2 | Official code release of our ACL 2021 work, [Syntax-augmented Multilingual BERT for Cross-lingual Transfer](https://aclanthology.org/2021.acl-long.350).
3 |
4 | **[Notes]**
5 |
6 | - This repository provides implementations for three NLP applications.
7 |   - Text classification, named entity recognition, and task-oriented semantic parsing.
8 | - We will release the question answering (QA) model implementation soon.
9 |
10 |
11 | ### Setup
12 | We set up a conda environment to run the experiments. We assume [anaconda](https://www.anaconda.com/)
13 | and Python 3.6 are installed. The additional requirements (as listed in `conda-env.txt`) can be installed by running
14 | the following script:
15 |
16 | ```bash
17 | bash install_tools.sh
18 | ```
19 |
20 |
21 | ### Data Preparation
22 |
23 | The next step is to download the data. To this end, first create a `download` folder with `mkdir -p download` in the root
24 | of this project. You then need to manually download `panx_dataset` (for NER) from [here](https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN)
25 | (note that it will download as `AmazonPhotos.zip`) to the download directory. Finally, run the following command to
26 | download the remaining datasets:
27 |
28 | ```bash
29 | bash scripts/download_data.sh
30 | ```
31 |
32 | To get the POS tags and dependency parses of input sentences, we use UDPipe. Go to the
33 | [udpipe](https://github.com/wasiahmad/Syntax-MBERT/tree/main/udpipe) directory and run the task-specific scripts:
34 | `[xnli.sh|pawsx.sh|panx.sh|mtop.sh]`.
35 |
36 |
37 | ### Training and Evaluation
38 |
39 | The evaluation results (on the test set) are saved in the `${SAVE_DIR}` directory (check the bash scripts).
40 |
41 | #### Text Classification
42 |
43 | ```bash
44 | cd scripts
45 | bash xlt_classify.sh GPU TASK USE_SYNTAX SEED
46 | ```
47 |
48 | For **cross-lingual** text classification, do the following.
49 |
50 | ```bash
51 | # for XNLI
52 | bash xlt_classify.sh 0 xnli false 1111
53 |
54 | # for PAWS-X
55 | bash xlt_classify.sh 0 pawsx false 1111
56 | ```
57 |
58 | - For syntax-augmented MBERT experiments, set `USE_SYNTAX=true`.
59 | - For **generalized cross-lingual** text classification evaluation, use the
60 | [gxlt_classify.sh](https://github.com/wasiahmad/Syntax-MBERT/blob/main/scripts/gxlt_classify.sh) script.
61 |
62 |
63 | #### Named Entity Recognition
64 |
65 | ```bash
66 | cd scripts
67 | bash panx.sh GPU USE_SYNTAX SEED
68 | ```
69 |
70 | - For syntax-augmented MBERT experiments, set `USE_SYNTAX=true` (see the example below).
71 | - For the CoNLL NER datasets, the same set of scripts can be used (with minor revisions).
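
For example, a syntax-augmented NER run on GPU 0 with seed 1111 (same argument pattern as the classification examples above):

```bash
bash panx.sh 0 true 1111
```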
72 |
73 |
74 | #### Task-oriented Semantic Parsing
75 |
76 | ```bash
77 | cd scripts
78 | bash mtop.sh GPU USE_SYNTAX SEED
79 | ```
80 |
81 | - For syntax-augmented MBERT experiments, set `USE_SYNTAX=true`.
82 | - Since the mATIS++ dataset is not publicly available, we do not release its scripts.
83 |
84 |
85 | ### Acknowledgement
86 | We acknowledge the efforts of the authors of the following repositories.
87 |
88 | - https://github.com/google-research/xtreme
89 | - https://github.com/huggingface/transformers
90 |
91 |
92 | ### Citation
93 |
94 | ```
95 | @inproceedings{ahmad-etal-2021-syntax,
96 |     title = "Syntax-augmented Multilingual {BERT} for Cross-lingual Transfer",
97 |     author = "Ahmad, Wasi and
98 |       Li, Haoran and
99 |       Chang, Kai-Wei and
100 |       Mehdad, Yashar",
101 |     booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
102 |     month = aug,
103 |     year = "2021",
104 |     address = "Online",
105 |     publisher = "Association for Computational Linguistics",
106 |     url = "https://aclanthology.org/2021.acl-long.350",
107 |     pages = "4538--4554",
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Google and DeepMind.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | -------------------------------------------------------------------------------- /conda-env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 7 | https://conda.anaconda.org/anaconda/linux-64/ca-certificates-2020.1.1-0.tar.bz2 8 | https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-10.0.130-0.tar.bz2 9 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.4-243.tar.bz2 10 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.tar.bz2 11 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.1.0-hdf63c60_0.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.1.0-hdf63c60_0.tar.bz2 13 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.4-243.tar.bz2 14 | https://conda.anaconda.org/anaconda/linux-64/expat-2.2.6-he6710b0_0.tar.bz2 15 | https://conda.anaconda.org/anaconda/linux-64/gmp-6.1.2-hb3b607b_0.tar.bz2 16 | https://conda.anaconda.org/anaconda/linux-64/icu-58.2-h211956c_0.tar.bz2 17 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 18 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 19 | https://conda.anaconda.org/anaconda/linux-64/libsodium-1.0.16-h1bed415_0.tar.bz2 20 | https://conda.anaconda.org/anaconda/linux-64/libuuid-1.0.3-h1bed415_2.tar.bz2 21 | https://conda.anaconda.org/anaconda/linux-64/libxcb-1.13-h1bed415_1.tar.bz2 22 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.1-he6710b0_1.tar.bz2 23 | https://conda.anaconda.org/anaconda/linux-64/openssl-1.1.1g-h7b6447c_0.tar.bz2 24 | https://conda.anaconda.org/anaconda/linux-64/pcre-8.43-he6710b0_0.tar.bz2 25 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.4-h14c3975_4.tar.bz2 26 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.tar.bz2 27 | https://conda.anaconda.org/anaconda/linux-64/glib-2.56.2-hd408876_0.tar.bz2 28 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20181209-hc058e9b_0.tar.bz2 29 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.tar.bz2 30 | https://conda.anaconda.org/anaconda/linux-64/libxml2-2.9.9-hea5a465_1.tar.bz2 31 | https://conda.anaconda.org/anaconda/linux-64/pandoc-2.2.3.2-0.tar.bz2 32 | https://repo.anaconda.com/pkgs/main/linux-64/readline-7.0-h7b6447c_5.tar.bz2 33 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.8-hbc83047_0.tar.bz2 34 | https://conda.anaconda.org/anaconda/linux-64/zeromq-4.3.1-he6710b0_3.tar.bz2 35 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.3.7-h0b5b093_0.tar.bz2 36 | https://conda.anaconda.org/anaconda/linux-64/dbus-1.13.12-h746ee38_0.tar.bz2 37 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.9.1-h8a8886c_1.tar.bz2 38 | https://conda.anaconda.org/anaconda/linux-64/gstreamer-1.14.0-hb453b48_1.tar.bz2 39 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.1.0-h2733197_0.tar.bz2 40 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.30.1-h7b6447c_0.tar.bz2 41 | https://conda.anaconda.org/anaconda/linux-64/fontconfig-2.13.0-h9420a91_0.tar.bz2 42 | https://conda.anaconda.org/anaconda/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.tar.bz2 43 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.5-h0371630_0.tar.bz2 44 | 
https://conda.anaconda.org/anaconda/noarch/attrs-19.3.0-py_0.tar.bz2 45 | https://conda.anaconda.org/anaconda/linux-64/backcall-0.1.0-py37_0.tar.bz2 46 | https://conda.anaconda.org/anaconda/linux-64/certifi-2019.11.28-py37_0.tar.bz2 47 | https://conda.anaconda.org/anaconda/noarch/decorator-4.4.1-py_0.tar.bz2 48 | https://conda.anaconda.org/anaconda/noarch/defusedxml-0.6.0-py_0.tar.bz2 49 | https://conda.anaconda.org/anaconda/linux-64/entrypoints-0.3-py37_0.tar.bz2 50 | https://conda.anaconda.org/anaconda/linux-64/ipython_genutils-0.2.0-py37_0.tar.bz2 51 | https://conda.anaconda.org/anaconda/linux-64/markupsafe-1.1.1-py37h7b6447c_0.tar.bz2 52 | https://conda.anaconda.org/anaconda/linux-64/mistune-0.8.4-py37h7b6447c_0.tar.bz2 53 | https://conda.anaconda.org/anaconda/noarch/more-itertools-8.0.2-py_0.tar.bz2 54 | https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.9.0-py37hfd86e86_0.tar.bz2 55 | https://repo.anaconda.com/pkgs/main/noarch/olefile-0.46-py_0.tar.bz2 56 | https://conda.anaconda.org/anaconda/linux-64/pandocfilters-1.4.2-py37_1.tar.bz2 57 | https://conda.anaconda.org/anaconda/noarch/parso-0.5.2-py_0.tar.bz2 58 | https://conda.anaconda.org/anaconda/linux-64/pickleshare-0.7.5-py37_0.tar.bz2 59 | https://conda.anaconda.org/anaconda/noarch/prometheus_client-0.7.1-py_0.tar.bz2 60 | https://conda.anaconda.org/anaconda/linux-64/ptyprocess-0.6.0-py37_0.tar.bz2 61 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.19-py_0.tar.bz2 62 | https://repo.anaconda.com/pkgs/main/noarch/pytz-2019.3-py_0.tar.bz2 63 | https://conda.anaconda.org/anaconda/linux-64/pyzmq-18.1.0-py37he6710b0_0.tar.bz2 64 | https://conda.anaconda.org/anaconda/linux-64/qt-5.9.7-h5867ecd_1.tar.bz2 65 | https://conda.anaconda.org/anaconda/linux-64/send2trash-1.5.0-py37_0.tar.bz2 66 | https://conda.anaconda.org/anaconda/linux-64/sip-4.19.13-py37he6710b0_0.tar.bz2 67 | https://repo.anaconda.com/pkgs/main/linux-64/six-1.13.0-py37_0.tar.bz2 68 | https://conda.anaconda.org/anaconda/noarch/testpath-0.4.4-py_0.tar.bz2 69 | https://conda.anaconda.org/anaconda/linux-64/tornado-6.0.3-py37h7b6447c_0.tar.bz2 70 | https://conda.anaconda.org/anaconda/linux-64/wcwidth-0.1.7-py37_0.tar.bz2 71 | https://conda.anaconda.org/anaconda/linux-64/webencodings-0.5.1-py37_1.tar.bz2 72 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.13.2-py37h2e261b9_0.tar.bz2 73 | https://conda.anaconda.org/anaconda/linux-64/jedi-0.15.1-py37_0.tar.bz2 74 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.3.0-py37he904b0f_0.tar.bz2 75 | https://conda.anaconda.org/anaconda/linux-64/networkx-1.11-py37_1.tar.bz2 76 | https://conda.anaconda.org/anaconda/linux-64/pexpect-4.7.0-py37_0.tar.bz2 77 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-6.2.1-py37h34e0f95_0.tar.bz2 78 | https://conda.anaconda.org/anaconda/linux-64/pyqt-5.9.2-py37h22d08a2_1.tar.bz2 79 | https://conda.anaconda.org/anaconda/linux-64/pyrsistent-0.15.6-py37h7b6447c_0.tar.bz2 80 | https://conda.anaconda.org/anaconda/noarch/python-dateutil-2.8.1-py_0.tar.bz2 81 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-42.0.2-py37_0.tar.bz2 82 | https://conda.anaconda.org/anaconda/linux-64/terminado-0.8.3-py37_0.tar.bz2 83 | https://conda.anaconda.org/anaconda/linux-64/traitlets-4.3.3-py37_0.tar.bz2 84 | https://conda.anaconda.org/anaconda/noarch/zipp-0.6.0-py_0.tar.bz2 85 | https://conda.anaconda.org/anaconda/noarch/bleach-3.1.0-py_0.tar.bz2 86 | https://conda.anaconda.org/anaconda/linux-64/importlib_metadata-1.3.0-py37_0.tar.bz2 87 | 
https://conda.anaconda.org/anaconda/noarch/jinja2-2.10.3-py_0.tar.bz2 88 | https://conda.anaconda.org/anaconda/linux-64/jupyter_core-4.6.1-py37_0.tar.bz2 89 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.17.4-py37hde5b4d6_0.tar.bz2 90 | https://conda.anaconda.org/anaconda/noarch/pygments-2.5.2-py_0.tar.bz2 91 | https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.33.6-py37_0.tar.bz2 92 | https://conda.anaconda.org/anaconda/linux-64/jsonschema-3.2.0-py37_0.tar.bz2 93 | https://conda.anaconda.org/anaconda/linux-64/jupyter_client-5.3.4-py37_0.tar.bz2 94 | https://repo.anaconda.com/pkgs/main/linux-64/pip-19.3.1-py37_0.tar.bz2 95 | https://conda.anaconda.org/anaconda/noarch/prompt_toolkit-3.0.2-py_0.tar.bz2 96 | https://conda.anaconda.org/anaconda/linux-64/ipython-7.10.2-py37h39e3cac_0.tar.bz2 97 | https://conda.anaconda.org/anaconda/linux-64/nbformat-4.4.0-py37_0.tar.bz2 98 | https://conda.anaconda.org/anaconda/linux-64/ipykernel-5.1.3-py37h39e3cac_0.tar.bz2 99 | https://conda.anaconda.org/anaconda/linux-64/nbconvert-5.6.1-py37_0.tar.bz2 100 | https://conda.anaconda.org/anaconda/linux-64/jupyter_console-5.2.0-py37_1.tar.bz2 101 | https://conda.anaconda.org/anaconda/linux-64/notebook-6.0.2-py37_0.tar.bz2 102 | https://conda.anaconda.org/anaconda/noarch/qtconsole-4.6.0-py_0.tar.bz2 103 | https://conda.anaconda.org/anaconda/linux-64/widgetsnbextension-3.5.1-py37_0.tar.bz2 104 | https://conda.anaconda.org/anaconda/noarch/ipywidgets-7.5.1-py_0.tar.bz2 105 | https://conda.anaconda.org/anaconda/linux-64/jupyter-1.0.0-py37_7.tar.bz2 106 | https://conda.anaconda.org/pytorch/linux-64/faiss-gpu-1.6.0-py37h1a5d453_0.tar.bz2 107 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.0.15-py37ha843d7b_0.tar.bz2 108 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.1.0-py37hd6b4f25_0.tar.bz2 109 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.17.4-py37hc1035e2_0.tar.bz2 110 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-0.25.3-py37he6710b0_0.tar.bz2 111 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.3.1-py3.7_cuda10.0.130_cudnn7.6.3_0.tar.bz2 112 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.4.2-py37_cu100.tar.bz2 113 | -------------------------------------------------------------------------------- /install_tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | REPO=$PWD 17 | LIB=$REPO/third_party 18 | mkdir -p $LIB 19 | 20 | # install conda env 21 | conda create --name xtreme --file conda-env.txt 22 | conda init bash 23 | source activate xtreme 24 | 25 | # install latest transformer 26 | cd $LIB 27 | git clone https://github.com/huggingface/transformers 28 | cd transformers 29 | git checkout cefd51c50cc08be8146c1151544495968ce8f2ad 30 | pip install . 
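# NOTE: the third_party/ code targets the pinned transformers commit above;
# a newer transformers release may not be API-compatible with it.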
31 | cd $LIB 32 | 33 | pip install seqeval 34 | pip install tensorboardx 35 | 36 | # install XLM tokenizer 37 | pip install sacremoses 38 | pip install pythainlp 39 | pip install jieba 40 | 41 | git clone https://github.com/neubig/kytea.git && cd kytea 42 | autoreconf -i 43 | ./configure --prefix=$HOME/local 44 | make && make install 45 | pip install kytea 46 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | REPO=$PWD 17 | DIR=$REPO/download 18 | mkdir -p $DIR 19 | 20 | # download XNLI dataset 21 | function download_xnli { 22 | OUTPATH=$DIR/xnli-tmp 23 | if [[ ! -d $OUTPATH/XNLI-MT-1.0 ]]; then 24 | if [[ ! -f $OUTPATH/XNLI-MT-1.0.zip ]]; then 25 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-MT-1.0.zip -P $OUTPATH -q --show-progress 26 | fi 27 | unzip -qq $OUTPATH/XNLI-MT-1.0.zip -d $OUTPATH 28 | fi 29 | if [[ ! -d $OUTPATH/XNLI-1.0 ]]; then 30 | if [[ ! -f $OUTPATH/XNLI-1.0.zip ]]; then 31 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip -P $OUTPATH -q --show-progress 32 | fi 33 | unzip -qq $OUTPATH/XNLI-1.0.zip -d $OUTPATH 34 | fi 35 | python $REPO/utils_preprocess.py \ 36 | --data_dir $OUTPATH \ 37 | --output_dir $DIR/xnli \ 38 | --task xnli; 39 | rm -rf $OUTPATH 40 | echo "Successfully downloaded data at $DIR/xnli" >> $DIR/download.log 41 | } 42 | 43 | # download PAWS-X dataset 44 | function download_pawsx { 45 | cd $DIR 46 | wget https://storage.googleapis.com/paws/pawsx/x-final.tar.gz -q --show-progress 47 | tar xzf x-final.tar.gz -C $DIR 48 | python $REPO/utils_preprocess.py \ 49 | --data_dir $DIR/x-final \ 50 | --output_dir $DIR/pawsx \ 51 | --task pawsx; 52 | rm -rf x-final x-final.tar.gz 53 | echo "Successfully downloaded data at $DIR/pawsx" >> $DIR/download.log 54 | } 55 | 56 | # download UD-POS dataset 57 | function download_udpos { 58 | base_dir=$DIR/udpos-tmp 59 | out_dir=$base_dir/conll 60 | mkdir -p $out_dir 61 | cd $base_dir 62 | curl -s --remote-name-all https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3105/ud-treebanks-v2.5.tgz 63 | tar -xzf $base_dir/ud-treebanks-v2.5.tgz 64 | 65 | langs=(af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh) 66 | for x in $base_dir/ud-treebanks-v2.5/*/*.conllu; do 67 | file="$(basename $x)" 68 | IFS='_' read -r -a array <<< "$file" 69 | lang=${array[0]} 70 | if [[ " ${langs[@]} " =~ " ${lang} " ]]; then 71 | lang_dir=$out_dir/$lang 72 | mkdir -p $lang_dir 73 | y=$lang_dir/${file/conllu/conll} 74 | if [[ ! 
-f "$y" ]]; then 75 | echo "python $REPO/third_party/ud-conversion-tools/conllu_to_conll.py $x $y \ 76 | --lang $lang --replace_subtokens_with_fused_forms --print_fused_forms" 77 | python $REPO/third_party/ud-conversion-tools/conllu_to_conll.py $x $y \ 78 | --lang $lang --replace_subtokens_with_fused_forms --print_fused_forms; 79 | else 80 | echo "${y} exists" 81 | fi 82 | fi 83 | done 84 | 85 | python $REPO/utils_preprocess.py --data_dir $out_dir --output_dir $DIR/udpos --task udpos 86 | rm -rf $out_dir ud-treebanks-v2.tgz $DIR/udpos-tmp 87 | echo "Successfully downloaded data at $DIR/udpos" >> $DIR/download.log 88 | } 89 | 90 | function download_panx { 91 | echo "Download panx NER dataset" 92 | if [[ ! -f $DIR/AmazonPhotos.zip ]]; then 93 | echo "Please download the AmazonPhotos.zip file on Amazon Cloud Drive mannually and save it to $DIR/AmazonPhotos.zip" 94 | echo "https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN" 95 | else 96 | base_dir=$DIR/panx_dataset 97 | unzip -qq -j $DIR/AmazonPhotos.zip -d $base_dir 98 | cd $base_dir 99 | langs=(ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu) 100 | for lg in ${langs[@]}; do 101 | tar xzf $base_dir/${lg}.tar.gz 102 | for f in dev test train; do 103 | mv $base_dir/$f $base_dir/${lg}-${f}; 104 | done 105 | done 106 | cd .. 107 | python $REPO/utils_preprocess.py \ 108 | --data_dir $base_dir \ 109 | --output_dir $DIR/panx \ 110 | --task panx; 111 | rm -rf $base_dir 112 | echo "Successfully downloaded data at $DIR/panx" >> $DIR/download.log 113 | fi 114 | } 115 | 116 | function download_mtop () { 117 | cd $DIR 118 | wget https://dl.fbaipublicfiles.com/mtop/mtop.zip -q --show-progress 119 | unzip mtop.zip -d . 120 | python $REPO/utils_preprocess.py \ 121 | --data_dir $DIR/mtop \ 122 | --output_dir $DIR/mtop \ 123 | --task mtop; 124 | rm -rf mtop.zip 125 | echo "Successfully downloaded data at $DIR/mtop" >> $DIR/download.log 126 | } 127 | 128 | 129 | download_xnli 130 | download_pawsx 131 | download_udpos 132 | download_panx 133 | #download_mtop 134 | -------------------------------------------------------------------------------- /scripts/gxlt_classify.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 |
16 | CURRENT_DIR=`pwd`
17 | HOME_DIR=`realpath ..`
18 |
19 | GPU=${1:-0}
20 | TASK=${2:-"xnli"}
21 | USE_SYNTAX=${3:-"false"}
22 | SEED=${4:-1111}
23 |
24 | DATA_DIR=${HOME_DIR}/download
25 | OUT_DIR=${HOME_DIR}/outputs
26 |
27 | export CUDA_VISIBLE_DEVICES=$GPU
28 |
29 | if [[ "$TASK" == 'xnli' ]]; then
30 | LANGS="ar,bg,de,el,en,es,fr,hi,ru,tr,ur,vi,zh"
31 | else
32 | LANGS="de,en,es,fr,ja,ko,zh"
33 | fi
34 |
35 | if [[ "$USE_SYNTAX" == 'true' ]]; then
36 | SAVE_DIR="${OUT_DIR}/${TASK}/syntax-seed${SEED}"
37 | else
38 | SAVE_DIR="${OUT_DIR}/${TASK}/seed${SEED}"
39 | fi
40 | mkdir -p $SAVE_DIR;
41 |
42 | export PYTHONPATH=$HOME_DIR;
43 | python $HOME_DIR/third_party/gxlt_classify.py \
44 | --seed $SEED \
45 | --model_type bert \
46 | --model_name_or_path bert-base-multilingual-cased \
47 | --task_name $TASK \
48 | --do_predict \
49 | --data_dir $DATA_DIR/${TASK}_udpipe_processed \
50 | --per_gpu_eval_batch_size 32 \
51 | --max_seq_length 128 \
52 | --output_dir $SAVE_DIR \
53 | --log_file 'gxl.log' \
54 | --use_syntax $USE_SYNTAX \
55 | --use_pos_tag $USE_SYNTAX \
56 | --predict_languages $LANGS \
57 | --overwrite_output_dir \
58 | 2>&1 | tee $SAVE_DIR/output.log;
59 |
--------------------------------------------------------------------------------
/scripts/mtop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2020 Google and DeepMind.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
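#
# Usage: bash mtop.sh GPU USE_SYNTAX SEED
#   e.g., bash mtop.sh 0 true 1111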
15 | 16 | CURRENT_DIR=`pwd` 17 | HOME_DIR=`realpath ..` 18 | 19 | GPU=${1:-0} 20 | USE_SYNTAX=${2:-"false"} 21 | SEED=${3:-1111} 22 | 23 | DATA_DIR=${HOME_DIR}/download 24 | OUT_DIR=${HOME_DIR}/outputs 25 | DATA_DIR=$DATA_DIR/mtop_udpipe_processed 26 | 27 | export CUDA_VISIBLE_DEVICES=$GPU 28 | LANGS="en,es,fr,de,hi" 29 | 30 | if [[ "$USE_SYNTAX" == 'true' ]]; then 31 | SAVE_DIR="${OUT_DIR}/mtop/syntax-seed${SEED}" 32 | else 33 | SAVE_DIR="${OUT_DIR}/mtop/seed${SEED}" 34 | fi 35 | mkdir -p $SAVE_DIR; 36 | 37 | export PYTHONPATH=$HOME_DIR; 38 | python $HOME_DIR/third_party/top.py \ 39 | --seed $SEED \ 40 | --data_dir $DATA_DIR \ 41 | --intent_labels $DATA_DIR/intent_label.txt \ 42 | --slot_labels $DATA_DIR/slot_label.txt \ 43 | --model_type bert \ 44 | --model_name_or_path bert-base-multilingual-cased \ 45 | --task_name mtop \ 46 | --output_dir $SAVE_DIR \ 47 | --max_seq_length 96 \ 48 | --num_train_epochs 10 \ 49 | --gradient_accumulation_steps 1 \ 50 | --per_gpu_train_batch_size 32 \ 51 | --per_gpu_eval_batch_size 32 \ 52 | --save_steps 200 \ 53 | --learning_rate 2e-5 \ 54 | --do_train \ 55 | --do_predict \ 56 | --predict_langs $LANGS \ 57 | --train_langs en \ 58 | --log_file $SAVE_DIR/train.log \ 59 | --eval_all_checkpoints \ 60 | --eval_patience -1 \ 61 | --overwrite_output_dir \ 62 | --save_only_best_checkpoint \ 63 | --use_syntax $USE_SYNTAX \ 64 | --use_pos_tag $USE_SYNTAX \ 65 | --use_structural_loss $USE_SYNTAX \ 66 | --struct_loss_coeff 1.0 \ 67 | --num_gat_layer 4 \ 68 | --num_gat_head 4 \ 69 | --max_syntactic_distance 1 \ 70 | --num_syntactic_heads 2 \ 71 | --syntactic_layers 0,1,2,3,4,5,6,7,8,9,10,11 \ 72 | 2>&1 | tee $SAVE_DIR/output.log; 73 | -------------------------------------------------------------------------------- /scripts/panx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
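#
# Usage: bash panx.sh GPU USE_SYNTAX SEED
#   e.g., bash panx.sh 0 false 1111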
15 | 16 | CURRENT_DIR=`pwd` 17 | HOME_DIR=`realpath ..` 18 | 19 | GPU=${1:-0} 20 | USE_SYNTAX=${2:-"false"} 21 | SEED=${3:-1111} 22 | 23 | DATA_DIR=${HOME_DIR}/download 24 | OUT_DIR=${HOME_DIR}/outputs 25 | DATA_DIR=$DATA_DIR/panx_udpipe_processed 26 | 27 | export CUDA_VISIBLE_DEVICES=$GPU 28 | LANGS="en,ar,bg,de,el,es,fr,hi,ru,tr,ur,vi,ko,nl,pt" 29 | 30 | if [[ "$USE_SYNTAX" == 'true' ]]; then 31 | SAVE_DIR="${OUT_DIR}/panx/syntax-seed${SEED}" 32 | else 33 | SAVE_DIR="${OUT_DIR}/panx/seed${SEED}" 34 | fi 35 | mkdir -p $SAVE_DIR; 36 | 37 | export PYTHONPATH=$HOME_DIR; 38 | python $HOME_DIR/third_party/ner.py \ 39 | --seed $SEED \ 40 | --data_dir $DATA_DIR \ 41 | --labels $DATA_DIR/labels.txt \ 42 | --model_type bert \ 43 | --model_name_or_path bert-base-multilingual-cased \ 44 | --task_name panx \ 45 | --output_dir $SAVE_DIR \ 46 | --max_seq_length 128 \ 47 | --num_train_epochs 3 \ 48 | --gradient_accumulation_steps 1 \ 49 | --per_gpu_train_batch_size 32 \ 50 | --per_gpu_eval_batch_size 32 \ 51 | --save_steps 5000 \ 52 | --learning_rate 2e-5 \ 53 | --do_train \ 54 | --do_predict \ 55 | --predict_langs $LANGS \ 56 | --train_langs en \ 57 | --log_file $SAVE_DIR/train.log \ 58 | --eval_all_checkpoints \ 59 | --eval_patience -1 \ 60 | --overwrite_output_dir \ 61 | --save_only_best_checkpoint \ 62 | --use_syntax $USE_SYNTAX \ 63 | --use_pos_tag $USE_SYNTAX \ 64 | --use_structural_loss $USE_SYNTAX \ 65 | --struct_loss_coeff 0.5 \ 66 | --num_gat_layer 4 \ 67 | --num_gat_head 4 \ 68 | --max_syntactic_distance 1 \ 69 | --num_syntactic_heads 1 \ 70 | --syntactic_layers 0,1,2,3,4,5,6,7,8,9,10,11 \ 71 | 2>&1 | tee $SAVE_DIR/output.log; 72 | -------------------------------------------------------------------------------- /scripts/xlt_classify.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
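#
# Usage: bash xlt_classify.sh GPU TASK USE_SYNTAX SEED   (TASK is xnli or pawsx)
#   e.g., bash xlt_classify.sh 0 pawsx true 1111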
15 | 16 | CURRENT_DIR=`pwd` 17 | HOME_DIR=`realpath ..` 18 | 19 | GPU=${1:-0} 20 | TASK=${2:-"xnli"} 21 | USE_SYNTAX=${3:-"false"} 22 | SEED=${4:-1111} 23 | 24 | DATA_DIR=${HOME_DIR}/download 25 | OUT_DIR=${HOME_DIR}/outputs 26 | 27 | export CUDA_VISIBLE_DEVICES=$GPU 28 | 29 | if [[ "$TASK" == 'xnli' ]]; then 30 | LANGS="ar,bg,de,el,en,es,fr,hi,ru,tr,ur,vi,zh" 31 | else 32 | LANGS="de,en,es,fr,ja,ko,zh" 33 | fi 34 | 35 | if [[ "$USE_SYNTAX" == 'true' ]]; then 36 | SAVE_DIR="${OUT_DIR}/${TASK}/syntax-seed${SEED}" 37 | else 38 | SAVE_DIR="${OUT_DIR}/${TASK}/seed${SEED}" 39 | fi 40 | mkdir -p $SAVE_DIR; 41 | 42 | export PYTHONPATH=$HOME_DIR; 43 | python $HOME_DIR/third_party/classify.py \ 44 | --seed $SEED \ 45 | --model_type bert \ 46 | --model_name_or_path bert-base-multilingual-cased \ 47 | --train_language en \ 48 | --task_name $TASK \ 49 | --do_train \ 50 | --do_predict \ 51 | --data_dir $DATA_DIR/${TASK}_udpipe_processed \ 52 | --gradient_accumulation_steps 1 \ 53 | --per_gpu_train_batch_size 32 \ 54 | --per_gpu_eval_batch_size 32 \ 55 | --learning_rate 2e-5 \ 56 | --num_train_epochs 5 \ 57 | --max_seq_length 128 \ 58 | --output_dir $SAVE_DIR/ \ 59 | --save_steps 2000 \ 60 | --eval_all_checkpoints \ 61 | --log_file 'train.log' \ 62 | --predict_languages $LANGS \ 63 | --save_only_best_checkpoint \ 64 | --overwrite_output_dir \ 65 | --use_syntax $USE_SYNTAX \ 66 | --use_pos_tag $USE_SYNTAX \ 67 | --use_structural_loss $USE_SYNTAX \ 68 | --struct_loss_coeff 1.0 \ 69 | --num_gat_layer 4 \ 70 | --num_gat_head 4 \ 71 | --max_syntactic_distance 4 \ 72 | --num_syntactic_heads 1 \ 73 | --syntactic_layers 0,1,2,3,4,5,6,7,8,9,10,11 \ 74 | 2>&1 | tee $SAVE_DIR/output.log; 75 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /third_party/graph_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GraphAttentionNetwork(nn.Module): 7 | def __init__(self, nfeat, nhid, dropout, alpha, nheads): 8 | """Dense version of GAT.""" 9 | super(GraphAttentionNetwork, self).__init__() 10 | self.dropout = dropout 11 | 12 | self.attention_heads = nn.ModuleList([ 13 | GraphAttentionHead(nfeat, nhid, dropout=dropout, alpha=alpha, concat=False) 14 | for _ in range(nheads) 15 | ]) 16 | 17 | def forward(self, x, adj): 18 | # x = (N, l, in_features), adj = (N, l, l) 19 | y = torch.cat([att(x, adj) for att in self.attention_heads], dim=2) 20 | y = F.dropout(y, self.dropout, training=self.training) 21 | return y 22 | 23 | 24 | class GraphAttentionHead(nn.Module): 25 | """ 26 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 27 | """ 28 | 29 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 30 | super(GraphAttentionHead, self).__init__() 31 | self.dropout = dropout 32 | self.in_features = in_features 33 | self.out_features = out_features 34 | self.alpha = alpha 35 | self.concat = concat 36 | 37 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) 38 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 39 | # self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1))) 40 | # nn.init.xavier_uniform_(self.a.data, gain=1.414) 41 | 42 | self.a1 = nn.Parameter(torch.empty(size=(out_features, 1))) 43 | self.a2 = 
nn.Parameter(torch.empty(size=(out_features, 1))) 44 | nn.init.xavier_uniform_(self.a1.data, gain=1.414) 45 | nn.init.xavier_uniform_(self.a2.data, gain=1.414) 46 | 47 | self.leakyrelu = nn.LeakyReLU(self.alpha) 48 | 49 | def forward(self, h, adj): 50 | # h.shape: (bsz, N, in_features), Wh.shape: (bsz, N, out_features) 51 | Wh = torch.matmul(h, self.W) 52 | # a_input.shape : (bsz, N, N, 2 * out_features) 53 | # a_input = self._prepare_attentional_mechanism_input(Wh) 54 | 55 | # e.shape : (bsz, N, N) 56 | # e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3)) 57 | 58 | f_1 = torch.matmul(Wh, self.a1) # (bsz, N, 1) 59 | f_2 = torch.matmul(Wh, self.a2) # (bsz, N, 1) 60 | e = self.leakyrelu(f_1 + f_2.transpose(1, 2)) # (bsz, N, N) 61 | 62 | zero_vec = -9e15 * torch.ones_like(e) 63 | # attention.shape : (bsz, N, N) 64 | attention = torch.where(adj > 0, e, zero_vec) 65 | # attention.shape : (bsz, N, N) 66 | attention = F.softmax(attention, dim=2) 67 | attention = F.dropout(attention, self.dropout, training=self.training) 68 | # h_prime.shape: (bsz, N, in_features) 69 | h_prime = torch.matmul(attention, Wh) 70 | 71 | if self.concat: 72 | return F.elu(h_prime) 73 | else: 74 | return h_prime 75 | 76 | def _prepare_attentional_mechanism_input(self, Wh): 77 | bsz = Wh.size(0) 78 | N = Wh.size(1) # number of nodes 79 | 80 | # Below, two matrices are created that contain embeddings in their rows in different orders. 81 | # (e stands for embedding) 82 | # These are the rows of the first matrix (Wh_repeated_in_chunks): 83 | # e1, e1, ..., e1, e2, e2, ..., e2, ..., eN, eN, ..., eN 84 | # '-------------' -> N times '-------------' -> N times '-------------' -> N times 85 | # 86 | # These are the rows of the second matrix (Wh_repeated_alternating): 87 | # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN 88 | # '----------------------------------------------------' -> N times 89 | # 90 | 91 | Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1) 92 | Wh_repeated_alternating = Wh.repeat(1, N, 1) 93 | # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (bsz, N * N, out_features) 94 | 95 | # The all_combination_matrix, created below, will look like this (|| denotes concatenation): 96 | # e1 || e1 97 | # e1 || e2 98 | # e1 || e3 99 | # ... 100 | # e1 || eN 101 | # e2 || e1 102 | # e2 || e2 103 | # e2 || e3 104 | # ... 105 | # e2 || eN 106 | # ... 107 | # eN || e1 108 | # eN || e2 109 | # eN || e3 110 | # ... 
111 | # eN || eN 112 | 113 | all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2) 114 | # all_combinations_matrix.shape == (bsz, N * N, 2 * out_features) 115 | 116 | return all_combinations_matrix.view(bsz, N, N, 2 * self.out_features) 117 | 118 | def __repr__(self): 119 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 120 | 121 | 122 | class GATModel(nn.Module): 123 | 124 | def __init__(self, config): 125 | super(GATModel, self).__init__() 126 | 127 | head_dim = config.hidden_size // config.num_attention_heads 128 | self.max_syntactic_distance = config.max_syntactic_distance 129 | 130 | self.dep_tag_embed = None 131 | if config.use_dependency_tag: 132 | self.dep_tag_embed = nn.Embedding( 133 | config.dep_tag_vocab_size, 134 | config.hidden_size, 135 | padding_idx=0 136 | ) 137 | self.pos_tag_embed = None 138 | if config.use_pos_tag: 139 | self.pos_tag_embed = nn.Embedding( 140 | config.pos_tag_vocab_size, 141 | config.hidden_size, 142 | padding_idx=0 143 | ) 144 | 145 | self.graph_attention = [] 146 | gat_out_size = head_dim * config.num_gat_head 147 | for i in range(config.num_gat_layer): 148 | input_size = config.hidden_size if i == 0 else gat_out_size 149 | self.graph_attention.append( 150 | GraphAttentionNetwork( 151 | nfeat=input_size, 152 | nhid=head_dim, 153 | dropout=config.hidden_dropout_prob, # 0.1 154 | alpha=0.2, 155 | nheads=config.num_gat_head 156 | ) 157 | ) 158 | self.graph_attention = nn.ModuleList(self.graph_attention) 159 | self.graph_attention_layer_norm = nn.LayerNorm(gat_out_size, eps=config.layer_norm_eps) 160 | 161 | def forward( 162 | self, inputs_embeds, dist_mat, deptag_ids=None, postag_ids=None 163 | ): 164 | gat_rep = inputs_embeds 165 | if self.dep_tag_embed is not None: 166 | assert deptag_ids is not None 167 | dep_embeddings = self.dep_tag_embed(deptag_ids) # B x T x (L*head_dim) 168 | gat_rep += dep_embeddings 169 | if self.pos_tag_embed is not None: 170 | assert postag_ids is not None 171 | pos_embeddings = self.pos_tag_embed(postag_ids) # B x T x (L*head_dim) 172 | gat_rep += pos_embeddings 173 | 174 | adj_mat = dist_mat.clone() 175 | adj_mat[dist_mat <= self.max_syntactic_distance] = 1 176 | adj_mat[dist_mat > self.max_syntactic_distance] = 0 177 | for _, gatlayer in enumerate(self.graph_attention): 178 | gat_rep = gatlayer(gat_rep, adj_mat) 179 | gat_rep = self.graph_attention_layer_norm(gat_rep) 180 | 181 | return gat_rep 182 | 183 | 184 | class L1DistanceLoss(nn.Module): 185 | """Custom L1 loss for distance matrices.""" 186 | 187 | def __init__(self, word_pair_dims=(1, 2)): 188 | super(L1DistanceLoss, self).__init__() 189 | self.word_pair_dims = word_pair_dims 190 | 191 | def forward(self, predictions, label_batch, length_batch): 192 | """ Computes L1 loss on distance matrices. 193 | Ignores all entries where label_batch=-1 194 | Normalizes first within sentences (by dividing by the square of the sentence length) 195 | and then across the batch. 
196 | Args: 197 | predictions: A pytorch batch of predicted distances 198 | label_batch: A pytorch batch of true distances 199 | length_batch: A pytorch batch of sentence lengths 200 | Returns: 201 | A tuple of: 202 | batch_loss: average loss in the batch 203 | total_sents: number of sentences in the batch 204 | """ 205 | labels_1s = (label_batch != 99999).float() 206 | predictions_masked = predictions * labels_1s 207 | labels_masked = label_batch * labels_1s 208 | total_sents = torch.sum((length_batch != 0)).float() 209 | squared_lengths = length_batch.pow(2).float() 210 | loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=self.word_pair_dims) 211 | normalized_loss_per_sent = loss_per_sent / squared_lengths 212 | batch_loss = torch.sum(normalized_loss_per_sent) / total_sents 213 | return batch_loss, total_sents 214 | 215 | 216 | class L1DepthLoss(nn.Module): 217 | """Custom L1 loss for depth sequences.""" 218 | 219 | def __init__(self, word_dim=1): 220 | super(L1DepthLoss, self).__init__() 221 | self.word_dim = word_dim 222 | 223 | def forward(self, predictions, label_batch, length_batch): 224 | """ Computes L1 loss on depth sequences. 225 | Ignores all entries where label_batch=-1 226 | Normalizes first within sentences (by dividing by the sentence length) 227 | and then across the batch. 228 | Args: 229 | predictions: A pytorch batch of predicted depths 230 | label_batch: A pytorch batch of true depths 231 | length_batch: A pytorch batch of sentence lengths 232 | Returns: 233 | A tuple of: 234 | batch_loss: average loss in the batch 235 | total_sents: number of sentences in the batch 236 | """ 237 | total_sents = torch.sum(length_batch != 0).float() 238 | labels_1s = (label_batch != 99999).float() 239 | predictions_masked = predictions * labels_1s 240 | labels_masked = label_batch * labels_1s 241 | loss_per_sent = torch.sum(torch.abs(predictions_masked - labels_masked), dim=self.word_dim) 242 | normalized_loss_per_sent = loss_per_sent / length_batch.float() 243 | batch_loss = torch.sum(normalized_loss_per_sent) / total_sents 244 | return batch_loss, total_sents 245 | 246 | 247 | class Probe(nn.Module): 248 | pass 249 | 250 | 251 | class TwoWordPSDProbe(Probe): 252 | """ Computes squared L2 distance after projection by a matrix. 253 | For a batch of sentences, computes all n^2 pairs of distances 254 | for each sentence in the batch. 255 | """ 256 | 257 | def __init__(self, model_dim, probe_rank): 258 | super(TwoWordPSDProbe, self).__init__() 259 | self.proj = nn.Parameter(data=torch.zeros(model_dim, probe_rank)) 260 | nn.init.uniform_(self.proj, -0.05, 0.05) 261 | 262 | def forward(self, batch): 263 | """ Computes all n^2 pairs of distances after projection 264 | for each sentence in a batch. 265 | Note that due to padding, some distances will be non-zero for pads. 
266 | Computes (B(h_i-h_j))^T(B(h_i-h_j)) for all i,j 267 | Args: 268 | batch: a batch of word representations of the shape 269 | (batch_size, max_seq_len, representation_dim) 270 | Returns: 271 | A tensor of distances of shape (batch_size, max_seq_len, max_seq_len) 272 | """ 273 | transformed = torch.matmul(batch, self.proj) 274 | batchlen, seqlen, rank = transformed.size() 275 | transformed = transformed.unsqueeze(2) 276 | transformed = transformed.expand(-1, -1, seqlen, -1) 277 | transposed = transformed.transpose(1, 2) 278 | diffs = transformed - transposed 279 | squared_diffs = diffs.pow(2) 280 | squared_distances = torch.sum(squared_diffs, -1) 281 | return squared_distances 282 | 283 | 284 | class OneWordPSDProbe(Probe): 285 | """ Computes squared L2 norm of words after projection by a matrix.""" 286 | 287 | def __init__(self, model_dim, probe_rank): 288 | super(OneWordPSDProbe, self).__init__() 289 | self.proj = nn.Parameter(data=torch.zeros(model_dim, probe_rank)) 290 | nn.init.uniform_(self.proj, -0.05, 0.05) 291 | 292 | def forward(self, batch): 293 | """ Computes all n depths after projection 294 | for each sentence in a batch. 295 | Computes (Bh_i)^T(Bh_i) for all i 296 | Args: 297 | batch: a batch of word representations of the shape 298 | (batch_size, max_seq_len, representation_dim) 299 | Returns: 300 | A tensor of depths of shape (batch_size, max_seq_len) 301 | """ 302 | transformed = torch.matmul(batch, self.proj) 303 | batchlen, seqlen, rank = transformed.size() 304 | norms = torch.bmm(transformed.view(batchlen * seqlen, 1, rank), 305 | transformed.view(batchlen * seqlen, rank, 1)) 306 | norms = norms.view(batchlen, seqlen) 307 | return norms 308 | -------------------------------------------------------------------------------- /third_party/gxlt_classify.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, 3 | # The HuggingFace Inc. team, and The XTREME Benchmark Authors. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
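# Typically launched via scripts/gxlt_classify.sh. This script is prediction-only:
# it loads a fine-tuned checkpoint from --output_dir and evaluates every ordered
# pair of distinct languages in --predict_languages (the two sentences of a pair
# are drawn from different languages).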
17 | """ Finetuning multi-lingual models on XNLI/PAWSX (Bert, XLM, XLMRoberta).""" 18 | 19 | import argparse 20 | import glob 21 | import logging 22 | import os 23 | import h5py 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import DataLoader 29 | from torch.utils.data import RandomSampler, SequentialSampler 30 | from tqdm import tqdm 31 | 32 | from transformers import ( 33 | WEIGHTS_NAME, 34 | AdamW, 35 | BertConfig, 36 | BertTokenizer, 37 | XLMRobertaTokenizer, 38 | get_linear_schedule_with_warmup, 39 | ) 40 | 41 | from third_party.modeling_bert import BertForSequenceClassification 42 | from third_party.processors.constants import * 43 | 44 | from third_party.processors.utils_classify import ( 45 | convert_examples_to_features, 46 | SequencePairDataset, 47 | batchify, 48 | XnliProcessor, 49 | PawsxProcessor 50 | ) 51 | 52 | try: 53 | from torch.utils.tensorboard import SummaryWriter 54 | except ImportError: 55 | from tensorboardX import SummaryWriter 56 | 57 | logger = logging.getLogger(__name__) 58 | 59 | ALL_MODELS = sum( 60 | ( 61 | tuple(conf.pretrained_config_archive_map.keys()) 62 | for conf in (BertConfig,) 63 | ), 64 | (), 65 | ) 66 | 67 | MODEL_CLASSES = { 68 | "bert": (BertConfig, BertForSequenceClassification, BertTokenizer), 69 | } 70 | 71 | PROCESSORS = { 72 | "xnli": XnliProcessor, 73 | "pawsx": PawsxProcessor, 74 | } 75 | 76 | 77 | def compute_metrics(preds, labels): 78 | scores = { 79 | "acc": (preds == labels).mean(), 80 | "num": len(preds), 81 | "correct": (preds == labels).sum(), 82 | } 83 | return scores 84 | 85 | 86 | def set_seed(args): 87 | random.seed(args.seed) 88 | np.random.seed(args.seed) 89 | torch.manual_seed(args.seed) 90 | if args.n_gpu > 0: 91 | torch.cuda.manual_seed_all(args.seed) 92 | 93 | 94 | def evaluate( 95 | args, 96 | model, 97 | tokenizer, 98 | split="train", 99 | language="en", 100 | lang2id=None, 101 | prefix="", 102 | output_file=None, 103 | label_list=None, 104 | output_only_prediction=True, 105 | ): 106 | """Evalute the model.""" 107 | eval_task_names = (args.task_name,) 108 | eval_outputs_dirs = (args.output_dir,) 109 | 110 | results = {} 111 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 112 | eval_dataset = load_and_cache_examples( 113 | args, 114 | eval_task, 115 | tokenizer, 116 | split=split, 117 | language=language, 118 | lang2id=lang2id, 119 | evaluate=True, 120 | ) 121 | 122 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 123 | os.makedirs(eval_output_dir) 124 | 125 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 126 | # Note that DistributedSampler samples randomly 127 | eval_sampler = SequentialSampler(eval_dataset) 128 | eval_dataloader = DataLoader( 129 | eval_dataset, 130 | sampler=eval_sampler, 131 | batch_size=args.eval_batch_size, 132 | num_workers=4, 133 | pin_memory=True, 134 | collate_fn=batchify, 135 | ) 136 | 137 | # multi-gpu eval 138 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 139 | model = torch.nn.DataParallel(model) 140 | 141 | # Eval! 
142 | logger.info("***** Running evaluation {} {} *****".format(prefix, language)) 143 | logger.info(" Num examples = %d", len(eval_dataset)) 144 | logger.info(" Batch size = %d", args.eval_batch_size) 145 | eval_loss = 0.0 146 | nb_eval_steps = 0 147 | preds = None 148 | out_label_ids = None 149 | sentences = None 150 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 151 | model.eval() 152 | batch = tuple(t.to(args.device) if t is not None else None for t in batch) 153 | 154 | with torch.no_grad(): 155 | inputs = dict() 156 | inputs['input_ids'] = batch[0] 157 | inputs['attention_mask'] = batch[1] 158 | inputs["token_type_ids"] = batch[2] if args.model_type in ["bert"] else None 159 | inputs['labels'] = batch[3] 160 | 161 | if args.use_syntax: 162 | inputs["dep_tag_ids"] = batch[4] 163 | inputs["pos_tag_ids"] = batch[5] 164 | inputs["dist_mat"] = batch[6] 165 | inputs["tree_depths"] = batch[7] 166 | 167 | outputs = model(**inputs) 168 | tmp_eval_loss, logits = outputs[:2] 169 | 170 | eval_loss += tmp_eval_loss.mean().item() 171 | nb_eval_steps += 1 172 | 173 | if preds is None: 174 | preds = logits.detach().cpu().numpy() 175 | out_label_ids = inputs["labels"].detach().cpu().numpy() 176 | sentences = inputs["input_ids"].detach().cpu().numpy() 177 | else: 178 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 179 | out_label_ids = np.append( 180 | out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0 181 | ) 182 | sentences = np.append( 183 | sentences, inputs["input_ids"].detach().cpu().numpy(), axis=0 184 | ) 185 | 186 | eval_loss = eval_loss / nb_eval_steps 187 | if args.output_mode == "classification": 188 | preds = np.argmax(preds, axis=1) 189 | else: 190 | raise ValueError("No other `output_mode` for XNLI.") 191 | result = compute_metrics(preds, out_label_ids) 192 | results.update(result) 193 | 194 | if output_file: 195 | logger.info("***** Save prediction ******") 196 | with open(output_file, "w") as fout: 197 | pad_token_id = tokenizer.pad_token_id 198 | sentences = sentences.astype(int).tolist() 199 | sentences = [[w for w in s if w != pad_token_id] for s in sentences] 200 | sentences = [tokenizer.convert_ids_to_tokens(s) for s in sentences] 201 | # fout.write('Prediction\tLabel\tSentences\n') 202 | for p, l, s in zip(list(preds), list(out_label_ids), sentences): 203 | s = " ".join(s) 204 | if label_list: 205 | p = label_list[p] 206 | l = label_list[l] 207 | if output_only_prediction: 208 | fout.write(str(p) + "\n") 209 | else: 210 | fout.write("{}\t{}\t{}\n".format(p, l, s)) 211 | logger.info("***** Eval results {} {} *****".format(prefix, language)) 212 | for key in sorted(result.keys()): 213 | logger.info(" %s = %s", key, str(result[key])) 214 | 215 | return results 216 | 217 | 218 | def load_and_cache_examples( 219 | args, task, tokenizer, split="train", language="en", lang2id=None, evaluate=False 220 | ): 221 | # Make sure only the first process in distributed training process the 222 | # dataset, and the others will use the cache 223 | if args.local_rank not in [-1, 0] and not evaluate: 224 | torch.distributed.barrier() 225 | 226 | processor = PROCESSORS[task]() 227 | output_mode = "classification" 228 | # Load data features from cache or dataset file 229 | lc = "_lc" if args.do_lower_case else "" 230 | cached_features_file = os.path.join( 231 | args.output_dir, 232 | "cached_{}_{}_{}_{}_{}{}".format( 233 | split, 234 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 235 | str(args.max_seq_length), 236 | str(task), 237 | language, 
238 | lc, 239 | ), 240 | ) 241 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 242 | logger.info("Loading features from cached file %s", cached_features_file) 243 | features = torch.load(cached_features_file) 244 | else: 245 | logger.info("Creating features from dataset file at %s", args.data_dir) 246 | label_list = processor.get_labels() 247 | examples = processor.get_test_examples(args.data_dir, language, args.swap_pairs) 248 | 249 | features = convert_examples_to_features( 250 | examples, 251 | tokenizer, 252 | label_list=label_list, 253 | max_length=args.max_seq_length, 254 | output_mode=output_mode, 255 | sep_token_extra=bool(args.model_type in ["roberta", "xlmr"]), 256 | pad_on_left=False, 257 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 258 | pad_token_segment_id=0, 259 | lang2id=lang2id, 260 | use_syntax=args.use_syntax, 261 | ) 262 | 263 | # NOTE. WE do not cache the features as we will do this experiment less often 264 | # if args.local_rank in [-1, 0]: 265 | # logger.info("Saving features into cached file %s", cached_features_file) 266 | # torch.save(features, cached_features_file) 267 | 268 | # Make sure only the first process in distributed training process the 269 | # dataset, and the others will use the cache 270 | if args.local_rank == 0 and not evaluate: 271 | torch.distributed.barrier() 272 | 273 | if args.model_type == "xlm": 274 | raise NotImplementedError 275 | else: 276 | dataset = SequencePairDataset(features) 277 | 278 | return dataset 279 | 280 | 281 | def main(): 282 | parser = argparse.ArgumentParser() 283 | 284 | # Required parameters 285 | parser.add_argument( 286 | "--data_dir", 287 | default=None, 288 | type=str, 289 | required=True, 290 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.", 291 | ) 292 | parser.add_argument( 293 | "--model_type", 294 | default=None, 295 | type=str, 296 | required=True, 297 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 298 | ) 299 | parser.add_argument( 300 | "--model_name_or_path", 301 | default=None, 302 | type=str, 303 | required=True, 304 | help="Path to pre-trained model or shortcut name selected in the list: " 305 | + ", ".join(ALL_MODELS), 306 | ) 307 | parser.add_argument( 308 | "--train_language", 309 | default="en", 310 | type=str, 311 | help="Train language if is different of the evaluation language.", 312 | ) 313 | parser.add_argument( 314 | "--predict_languages", 315 | type=str, 316 | default="en", 317 | help="prediction languages separated by ','.", 318 | ) 319 | parser.add_argument( 320 | "--output_dir", 321 | default=None, 322 | type=str, 323 | required=True, 324 | help="The output directory where the model predictions and checkpoints will be written.", 325 | ) 326 | parser.add_argument( 327 | "--task_name", 328 | default="xnli", 329 | type=str, 330 | required=True, 331 | help="The task name", 332 | ) 333 | 334 | # Other parameters 335 | parser.add_argument( 336 | "--config_name", 337 | default="", 338 | type=str, 339 | help="Pretrained config name or path if not the same as model_name", 340 | ) 341 | parser.add_argument( 342 | "--tokenizer_name", 343 | default="", 344 | type=str, 345 | help="Pretrained tokenizer name or path if not the same as model_name", 346 | ) 347 | parser.add_argument( 348 | "--cache_dir", 349 | default="", 350 | type=str, 351 | help="Where do you want to store the pre-trained models downloaded from s3", 352 | ) 353 | parser.add_argument( 354 | "--max_seq_length", 
355 | default=128, 356 | type=int, 357 | help="The maximum total input sequence length after tokenization. Sequences longer " 358 | "than this will be truncated, sequences shorter will be padded.", 359 | ) 360 | parser.add_argument( 361 | "--do_predict", action="store_true", help="Whether to run prediction." 362 | ) 363 | parser.add_argument( 364 | "--do_lower_case", 365 | action="store_true", 366 | help="Set this flag if you are using an uncased model.", 367 | ) 368 | parser.add_argument( 369 | "--test_split", type=str, default="test", help="split of training set" 370 | ) 371 | parser.add_argument( 372 | "--per_gpu_eval_batch_size", 373 | default=8, 374 | type=int, 375 | help="Batch size per GPU/CPU for evaluation.", 376 | ) 377 | parser.add_argument("--log_file", default="train", type=str, help="log file") 378 | parser.add_argument( 379 | "--no_cuda", action="store_true", help="Avoid using CUDA when available" 380 | ) 381 | parser.add_argument( 382 | "--overwrite_output_dir", 383 | action="store_true", 384 | help="Overwrite the content of the output directory", 385 | ) 386 | parser.add_argument( 387 | "--overwrite_cache", 388 | action="store_true", 389 | help="Overwrite the cached training and evaluation sets", 390 | ) 391 | parser.add_argument( 392 | "--seed", type=int, default=42, help="random seed for initialization" 393 | ) 394 | parser.add_argument( 395 | "--fp16", 396 | action="store_true", 397 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 398 | ) 399 | parser.add_argument( 400 | "--fp16_opt_level", 401 | type=str, 402 | default="O1", 403 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 404 | "See details at https://nvidia.github.io/apex/amp.html", 405 | ) 406 | parser.add_argument( 407 | "--local_rank", 408 | type=int, 409 | default=-1, 410 | help="For distributed training: local_rank", 411 | ) 412 | parser.add_argument( 413 | "--server_ip", type=str, default="", help="For distant debugging." 414 | ) 415 | parser.add_argument( 416 | "--server_port", type=str, default="", help="For distant debugging." 
417 | ) 418 | # 419 | # 420 | parser.add_argument( 421 | "--syntactic_layers", 422 | type=str, default='0,1,2', 423 | help="comma separated layer indices for syntax fusion", 424 | ) 425 | parser.add_argument( 426 | "--num_syntactic_heads", 427 | default=2, type=int, 428 | help="Number of syntactic heads", 429 | ) 430 | parser.add_argument( 431 | "--use_syntax", 432 | type='bool', 433 | default=False, 434 | help="Whether to use syntax-based modeling", 435 | ) 436 | parser.add_argument( 437 | "--use_dependency_tag", 438 | type='bool', 439 | default=False, 440 | help="Whether to use dependency tag in structure modeling", 441 | ) 442 | parser.add_argument( 443 | "--use_pos_tag", 444 | type='bool', 445 | default=False, 446 | help="Whether to use pos tags in structure modeling", 447 | ) 448 | parser.add_argument( 449 | "--use_structural_loss", 450 | type='bool', 451 | default=False, 452 | help="Whether to use structural loss along with task loss", 453 | ) 454 | parser.add_argument( 455 | "--struct_loss_coeff", 456 | default=1.0, type=float, 457 | help="Multiplying factor for the structural loss", 458 | ) 459 | parser.add_argument( 460 | "--max_syntactic_distance", 461 | default=1, type=int, 462 | help="Max distance to consider during graph attention", 463 | ) 464 | parser.add_argument( 465 | "--num_gat_layer", 466 | default=4, type=int, 467 | help="Number of layers in Graph Attention Networks (GAT)", 468 | ) 469 | parser.add_argument( 470 | "--num_gat_head", 471 | default=4, type=int, 472 | help="Number of attention heads in Graph Attention Networks (GAT)", 473 | ) 474 | parser.add_argument( 475 | "--batch_normalize", 476 | action="store_true", 477 | help="Apply batch normalization to representation", 478 | ) 479 | parser.add_argument( 480 | "--swap_pairs", 481 | action="store_true", 482 | help="Swap the input sentence pairs.", 483 | ) 484 | 485 | args = parser.parse_args() 486 | 487 | logging.basicConfig( 488 | handlers=[ 489 | logging.FileHandler(os.path.join(args.output_dir, args.log_file)), 490 | logging.StreamHandler(), 491 | ], 492 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 493 | datefmt="%m/%d/%Y %H:%M:%S", 494 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 495 | ) 496 | logging.info("Input args: %r" % args) 497 | 498 | # Setup distant debugging if needed 499 | if args.server_ip and args.server_port: 500 | import ptvsd 501 | 502 | print("Waiting for debugger attach") 503 | ptvsd.enable_attach( 504 | address=(args.server_ip, args.server_port), redirect_output=True 505 | ) 506 | ptvsd.wait_for_attach() 507 | 508 | # Setup CUDA, GPU & distributed training 509 | if args.local_rank == -1 or args.no_cuda: 510 | device = torch.device( 511 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 512 | ) 513 | args.n_gpu = torch.cuda.device_count() 514 | else: # Initializes the distributed backend which sychronizes nodes/GPUs 515 | torch.cuda.set_device(args.local_rank) 516 | device = torch.device("cuda", args.local_rank) 517 | torch.distributed.init_process_group(backend="nccl") 518 | args.n_gpu = 1 519 | args.device = device 520 | 521 | # Setup logging 522 | logger.warning( 523 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 524 | args.local_rank, 525 | device, 526 | args.n_gpu, 527 | bool(args.local_rank != -1), 528 | args.fp16, 529 | ) 530 | 531 | # Set seed 532 | set_seed(args) 533 | 534 | # Prepare dataset 535 | if args.task_name not in PROCESSORS: 536 | raise ValueError("Task not found: 
%s" % (args.task_name)) 537 | processor = PROCESSORS[args.task_name]() 538 | args.output_mode = "classification" 539 | label_list = processor.get_labels() 540 | num_labels = len(label_list) 541 | 542 | # Load pretrained model and tokenizer 543 | # Make sure only the first process in distributed training loads model & vocab 544 | if args.local_rank not in [-1, 0]: 545 | torch.distributed.barrier() 546 | 547 | args.model_type = args.model_type.lower() 548 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 549 | config = config_class.from_pretrained( 550 | args.config_name if args.config_name else args.model_name_or_path, 551 | num_labels=num_labels, 552 | finetuning_task=args.task_name, 553 | cache_dir=args.cache_dir if args.cache_dir else None, 554 | ) 555 | #################################### 556 | config.dep_tag_vocab_size = len(DEPTAG_SYMBOLS) + NUM_SPECIAL_TOKENS 557 | config.pos_tag_vocab_size = len(POS_SYMBOLS) + NUM_SPECIAL_TOKENS 558 | config.use_dependency_tag = args.use_dependency_tag 559 | config.use_pos_tag = args.use_pos_tag 560 | config.use_structural_loss = args.use_structural_loss 561 | config.struct_loss_coeff = args.struct_loss_coeff 562 | config.num_syntactic_heads = args.num_syntactic_heads 563 | config.syntactic_layers = args.syntactic_layers 564 | config.max_syntactic_distance = args.max_syntactic_distance 565 | config.use_syntax = args.use_syntax 566 | config.batch_normalize = args.batch_normalize 567 | config.num_gat_layer = args.num_gat_layer 568 | config.num_gat_head = args.num_gat_head 569 | #################################### 570 | 571 | logger.info("config = {}".format(config)) 572 | 573 | lang2id = config.lang2id if args.model_type == "xlm" else None 574 | logger.info("lang2id = {}".format(lang2id)) 575 | 576 | # Make sure only the first process in distributed training loads model & vocab 577 | if args.local_rank == 0: 578 | torch.distributed.barrier() 579 | logger.info("Training/evaluation parameters %s", args) 580 | 581 | if os.path.exists(os.path.join(args.output_dir, "checkpoint-best")): 582 | best_checkpoint = os.path.join(args.output_dir, "checkpoint-best") 583 | else: 584 | best_checkpoint = args.output_dir 585 | 586 | # Prediction 587 | if args.do_predict and args.local_rank in [-1, 0]: 588 | tokenizer = tokenizer_class.from_pretrained( 589 | args.model_name_or_path if args.model_name_or_path else best_checkpoint, 590 | do_lower_case=args.do_lower_case, 591 | ) 592 | model = model_class.from_pretrained(best_checkpoint) 593 | model.to(args.device) 594 | output_predict_file = os.path.join( 595 | args.output_dir, "xling_" + args.test_split + "_results.txt" 596 | ) 597 | total = total_correct = 0.0 598 | with open(output_predict_file, "a") as writer: 599 | writer.write( 600 | "======= Predict using the model from {} for {}:\n".format( 601 | best_checkpoint, args.test_split 602 | ) 603 | ) 604 | for pre_lang in args.predict_languages.split(","): 605 | for hyp_lang in args.predict_languages.split(","): 606 | if pre_lang == hyp_lang: 607 | continue 608 | # output_file = os.path.join( 609 | # args.output_dir, "test-{}.tsv".format(language) 610 | # ) 611 | language = pre_lang + '_' + hyp_lang 612 | result = evaluate( 613 | args, 614 | model, 615 | tokenizer, 616 | split=args.test_split, 617 | language=language, 618 | lang2id=lang2id, 619 | prefix="best_checkpoint", 620 | # output_file=output_file, 621 | label_list=label_list, 622 | ) 623 | logger.info("{}={}".format(language, result["acc"])) 624 | 
writer.write("=====================\nlanguage={}\n".format(language)) 625 | for key in sorted(result.keys()): 626 | writer.write("{} = {}\n".format(key, result[key])) 627 | total += result["num"] 628 | total_correct += result["correct"] 629 | 630 | writer.write("=====================\n") 631 | writer.write("total={}\n".format(total_correct / total)) 632 | 633 | return result 634 | 635 | 636 | if __name__ == "__main__": 637 | main() 638 | -------------------------------------------------------------------------------- /third_party/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wasiahmad/Syntax-MBERT/1af0ef5ff9cc7e3e7de6e662e3d677ee157630ca/third_party/processors/__init__.py -------------------------------------------------------------------------------- /third_party/processors/constants.py: -------------------------------------------------------------------------------- 1 | NUM_SPECIAL_TOKENS = 4 2 | BERT_SPECIAL_TOKENS = ['[PAD]', '[CLS]', '[SEP]', '[UNK]'] 3 | XLMR_SPECIAL_TOKENS = ['', '', '', ''] 4 | 5 | POS_SYMBOLS = [ 6 | "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT", 7 | "SCONJ", "SYM", "VERB", "X" 8 | ] 9 | 10 | DEPTAG_SYMBOLS = [ 11 | "acl", "advcl", "advmod", "amod", "appos", "aux", "case", "cc", "ccomp", "clf", "compound", "conj", 12 | "cop", "csubj", "dep", "det", "discourse", "dislocated", "expl", "fixed", "flat", "goeswith", "iobj", 13 | "list", "mark", "nmod", "nsubj", "nummod", "obj", "obl", "orphan", "parataxis", "punct", "reparandum", 14 | "root", "vocative", "xcomp" 15 | ] 16 | 17 | 18 | def UPOS_MAP(tokenizer): 19 | sp_tokens = XLMR_SPECIAL_TOKENS if 'roberta' in tokenizer else BERT_SPECIAL_TOKENS 20 | symbols = sp_tokens[:1] + POS_SYMBOLS + sp_tokens[1:] 21 | return {s: idx for idx, s in enumerate(symbols)} 22 | 23 | 24 | def DEP_TAG_MAP(tokenizer): 25 | sp_tokens = XLMR_SPECIAL_TOKENS if 'roberta' in tokenizer else BERT_SPECIAL_TOKENS 26 | symbols = sp_tokens[:1] + DEPTAG_SYMBOLS + sp_tokens[1:] + [''] 27 | return {s: idx for idx, s in enumerate(symbols)} 28 | 29 | 30 | def upos_to_id(param, tokenizer): 31 | use_map = UPOS_MAP(tokenizer) 32 | if isinstance(param, str): 33 | return use_map[param] 34 | elif isinstance(param, list): 35 | return [use_map[i] for i in param] 36 | else: 37 | raise ValueError() 38 | 39 | 40 | def deptag_to_id(param, tokenizer): 41 | use_map = DEP_TAG_MAP(tokenizer) 42 | unk_sym = '' if 'roberta' in tokenizer else '[UNK]' 43 | if isinstance(param, str): 44 | return use_map.get(param, use_map[unk_sym]) 45 | elif isinstance(param, list): 46 | return [use_map.get(i, use_map[unk_sym]) for i in param] 47 | else: 48 | raise ValueError() 49 | -------------------------------------------------------------------------------- /third_party/processors/tree.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/qipeng/gcn-over-pruned-trees/blob/master/model/tree.py 2 | """ 3 | Basic operations on trees. 4 | """ 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix 8 | from scipy.sparse.csgraph import shortest_path 9 | from third_party.processors.constants import DEP_TAG_MAP 10 | 11 | 12 | class Tree(object): 13 | """ 14 | Reused tree object from stanfordnlp/treelstm. 
15 | """ 16 | 17 | def __init__(self): 18 | self.parent = None 19 | self.idx = None 20 | self.token = None 21 | self.num_children = 0 22 | self.children = list() 23 | 24 | def add_child(self, child): 25 | child.parent = self 26 | self.num_children += 1 27 | self.children.append(child) 28 | 29 | def print(self, level): 30 | for i in range(1, level): 31 | print('|----', end='') 32 | print(self.token) 33 | for i in range(self.num_children): 34 | self.children[i].print(level + 1) 35 | 36 | def size(self): 37 | if getattr(self, '_size', False): 38 | return self._size 39 | count = 1 40 | for i in range(self.num_children): 41 | count += self.children[i].size() 42 | self._size = count 43 | return self._size 44 | 45 | def height(self): 46 | if getattr(self, '_height', False): 47 | return self._height 48 | count = 0 49 | if self.num_children > 0: 50 | for i in range(self.num_children): 51 | child_height = self.children[i].height() 52 | if child_height > count: 53 | count = child_height 54 | count += 1 55 | self._height = count 56 | return self._height 57 | 58 | def depth(self): 59 | if getattr(self, '_depth', False): 60 | return self._depth 61 | count = 0 62 | if self.parent: 63 | count += self.parent.depth() 64 | count += 1 65 | self._depth = count 66 | return self._depth 67 | 68 | def delete(self): 69 | for i in range(self.num_children): 70 | self.parent.add_child(self.children[i]) 71 | self.children[i].parent = self.parent 72 | index = None 73 | for i in range(self.parent.num_children): 74 | if self.parent.children[i].idx == self.idx: 75 | index = i 76 | break 77 | self.parent.children.pop(index) 78 | self.parent.num_children -= 1 79 | 80 | def __iter__(self): 81 | yield self 82 | for c in self.children: 83 | for x in c: 84 | yield x 85 | 86 | 87 | def tree_to_adj(sent_len, tree, directed=False, self_loop=False): 88 | """ 89 | Convert a tree object to an (numpy) adjacency matrix. 90 | """ 91 | ret = np.zeros((sent_len, sent_len), dtype=np.float32) 92 | 93 | queue = [tree] 94 | idx = [] 95 | while len(queue) > 0: 96 | t, queue = queue[0], queue[1:] 97 | 98 | idx += [t.idx] 99 | 100 | for c in t.children: 101 | ret[t.idx, c.idx] = 1 102 | queue += t.children 103 | 104 | if not directed: 105 | ret = ret + ret.T 106 | 107 | if self_loop: 108 | for i in idx: 109 | ret[i, i] = 1 110 | 111 | return ret 112 | 113 | 114 | def head_to_tree(head, tokens=None): 115 | """ 116 | Convert a sequence of head indexes into a tree object. 
117 | """ 118 | root = None 119 | nodes = [Tree() for _ in head] 120 | for i in range(len(nodes)): 121 | h = head[i] 122 | nodes[i].idx = i 123 | if tokens is not None: 124 | nodes[i].token = tokens[i] 125 | if h == 0: 126 | root = nodes[i] 127 | else: 128 | nodes[h - 1].add_child(nodes[i]) 129 | 130 | assert root is not None 131 | return root, nodes 132 | 133 | 134 | def heads_to_dist_mat(head, tokens, directed=False): 135 | root, _ = head_to_tree(head, tokens) 136 | adj_mat = tree_to_adj(root.size(), root, directed=directed, self_loop=False) 137 | dist_matrix = shortest_path(csgraph=csr_matrix(adj_mat), directed=directed) 138 | return dist_matrix 139 | 140 | 141 | def root_to_dist_mat(root, directed=False): 142 | adj_mat = tree_to_adj(root.size(), root, directed=directed, self_loop=False) 143 | dist_matrix = shortest_path(csgraph=csr_matrix(adj_mat), directed=directed) 144 | return dist_matrix 145 | 146 | 147 | def adj_mat_to_dist_mat(adj_mat, directed=False): 148 | dist_matrix = shortest_path(csgraph=csr_matrix(adj_mat), directed=directed) 149 | return dist_matrix 150 | 151 | 152 | def dist_to_root(head, tokens): 153 | root, nodes = head_to_tree(head, tokens) 154 | distances = [] 155 | for i in range(len(nodes)): 156 | distances.append(nodes[i].depth()) 157 | 158 | return distances 159 | 160 | 161 | def ancestorMatrixRec(root, anc, mat): 162 | # base case 163 | if root == None: 164 | return 0 165 | 166 | # Update all ancestors of current node 167 | data = root.idx 168 | for i in range(len(anc)): 169 | mat[anc[i]][data] = 1 170 | 171 | # Push data to list of ancestors 172 | anc.append(data) 173 | 174 | # Traverse all the subtrees 175 | for c in root.children: 176 | ancestorMatrixRec(c, anc, mat) 177 | 178 | # Remove data from list the list of ancestors 179 | # as all descendants of it are processed now. 
180 | anc.pop(-1) 181 | 182 | 183 | def heads_to_ancestor_matrix(heads, tokens): 184 | mat = np.zeros((len(tokens), len(tokens)), dtype=np.int32) 185 | root, _ = head_to_tree(heads, tokens) 186 | ancestorMatrixRec(root, [], mat) 187 | np.fill_diagonal(mat, 1) 188 | return mat 189 | 190 | 191 | def dep_path_matrix(heads, tokens, dep_labels, root=None): 192 | assert len(heads) == len(tokens) == len(dep_labels) 193 | if root is None: 194 | root, _ = head_to_tree(heads, tokens) 195 | 196 | def find_path(i, j, preds): 197 | if predecessors[i, j] == i: 198 | return preds 199 | preds.append(predecessors[i, j]) 200 | find_path(i, predecessors[i, j], preds) 201 | 202 | adj_mat = tree_to_adj(root.size(), root, directed=False, self_loop=False) 203 | _, predecessors = shortest_path( 204 | csgraph=csr_matrix(adj_mat), directed=False, return_predecessors=True 205 | ) 206 | 207 | max_path_length = 0 208 | token_to_token_paths = {} 209 | for i in range(len(tokens)): 210 | for j in range(i + 1, len(tokens)): 211 | preds = [] 212 | find_path(i, j, preds) 213 | token_ids = [i] + preds[::-1] + [j] 214 | if len(token_ids) - 1 > max_path_length: 215 | max_path_length = len(token_ids) - 1 216 | 217 | path_labels = [] 218 | for k in range(len(token_ids) - 1): 219 | # token[k+1] is the head of token[k] 220 | if heads[token_ids[k]] - 1 == token_ids[k + 1]: 221 | path_labels.append(dep_labels[token_ids[k]]) 222 | # token[k] is the head of token[k+1] 223 | elif heads[token_ids[k + 1]] - 1 == token_ids[k]: 224 | path_labels.append(dep_labels[token_ids[k + 1]]) 225 | else: 226 | raise ValueError() 227 | token_to_token_paths['{}.{}'.format(i, j)] = path_labels 228 | 229 | path_matrix = np.empty((len(tokens), len(tokens), max_path_length), dtype=np.int32) 230 | for i in range(len(tokens)): 231 | for j in range(i, len(tokens)): 232 | if i == j: 233 | labels = [''] + [''] * (max_path_length - 1) 234 | path_matrix[i, j] = [DEP_TAG_MAP[l] for l in labels] 235 | else: 236 | labels = token_to_token_paths['{}.{}'.format(i, j)] 237 | pad_length = max_path_length - len(labels) 238 | labels = labels + [''] * pad_length 239 | labels = [DEP_TAG_MAP[l] for l in labels] 240 | path_matrix[i, j] = labels 241 | path_matrix[j, i] = labels 242 | 243 | return path_matrix 244 | 245 | 246 | def get_dep_path(heads, tokens, dep_labels, root=None): 247 | assert len(heads) == len(tokens) == len(dep_labels) 248 | if root is None: 249 | root, _ = head_to_tree(heads, tokens) 250 | 251 | def find_path(i, j, preds): 252 | if predecessors[i, j] == i: 253 | return preds 254 | preds.append(predecessors[i, j]) 255 | find_path(i, predecessors[i, j], preds) 256 | 257 | adj_mat = tree_to_adj(root.size(), root, directed=False, self_loop=False) 258 | _, predecessors = shortest_path( 259 | csgraph=csr_matrix(adj_mat), directed=False, return_predecessors=True 260 | ) 261 | 262 | max_path_length = 0 263 | token_to_token_paths = {} 264 | for i in range(len(tokens)): 265 | for j in range(i + 1, len(tokens)): 266 | preds = [] 267 | find_path(i, j, preds) 268 | token_ids = [i] + preds[::-1] + [j] 269 | if len(token_ids) - 1 > max_path_length: 270 | max_path_length = len(token_ids) - 1 271 | 272 | path_labels = [] 273 | for k in range(len(token_ids) - 1): 274 | # token[k+1] is the head of token[k] 275 | if heads[token_ids[k]] - 1 == token_ids[k + 1]: 276 | path_labels.append(dep_labels[token_ids[k]]) 277 | # token[k] is the head of token[k+1] 278 | elif heads[token_ids[k + 1]] - 1 == token_ids[k]: 279 | path_labels.append(dep_labels[token_ids[k + 1]]) 280 | else: 281 
| raise ValueError() 282 | token_to_token_paths['{}.{}'.format(i, j)] = path_labels 283 | 284 | return token_to_token_paths 285 | 286 | 287 | def get_path_matrix(heads, tokens, dep_labels, token_to_token_paths): 288 | assert len(heads) == len(tokens) == len(dep_labels) 289 | 290 | max_path_length = max([len(v) for k, v in token_to_token_paths.items()]) 291 | path_matrix = np.empty((len(tokens), len(tokens), max_path_length), dtype=np.int32) 292 | for i in range(len(tokens)): 293 | for j in range(i, len(tokens)): 294 | if i == j: 295 | labels = [''] + [''] * (max_path_length - 1) 296 | path_matrix[i, j] = [DEP_TAG_MAP[l] for l in labels] 297 | else: 298 | labels = token_to_token_paths['{}.{}'.format(i, j)] 299 | pad_length = max_path_length - len(labels) 300 | labels = labels + [''] * pad_length 301 | labels = [DEP_TAG_MAP[l] for l in labels] 302 | path_matrix[i, j] = labels 303 | path_matrix[j, i] = labels 304 | 305 | return path_matrix 306 | 307 | 308 | if __name__ == '__main__': 309 | tokens = ['The', 'increase', 'reflects', 'lower', 'credit', 'losses'] 310 | heads = [2, 3, 0, 6, 6, 3] 311 | anc_mat = heads_to_ancestor_matrix(heads, tokens) 312 | # print(anc_mat) 313 | # [[1 1 1 0 0 0] 314 | # [0 1 1 0 0 0] 315 | # [0 0 1 0 0 0] 316 | # [0 0 1 1 0 1] 317 | # [0 0 1 0 1 1] 318 | # [0 0 1 0 0 1]] 319 | -------------------------------------------------------------------------------- /third_party/processors/utils_tag.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, 3 | # The HuggingFace Inc. team, and The XTREME Benchmark Authors. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Utility functions for NER/POS tagging tasks.""" 18 | 19 | from __future__ import absolute_import, division, print_function 20 | 21 | import logging 22 | import os 23 | import torch 24 | import json 25 | 26 | from io import open 27 | from torch.utils.data import Dataset 28 | from third_party.processors.tree import * 29 | from third_party.processors.constants import * 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for token classification.""" 36 | 37 | def __init__(self, guid, words, labels, langs=None, 38 | heads=None, dep_tags=None, pos_tags=None): 39 | """Constructs a InputExample. 40 | 41 | Args: 42 | guid: Unique id for the example. 43 | words: list. The words of the sequence. 44 | labels: (Optional) list. The labels for each word of the sequence. This should be 45 | specified for train and dev examples, but not for test examples. 
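            langs: (Optional) list. Language id for each word (used by XLM-style models).
            heads: (Optional) list. 1-based dependency head index for each word; 0 marks the root.
            dep_tags: (Optional) list. Universal dependency relation label for each word.
            pos_tags: (Optional) list. Universal POS tag for each word.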
46 | """ 47 | self.guid = guid 48 | self.words = words 49 | self.labels = labels 50 | self.langs = langs 51 | self.heads = heads 52 | self.dep_tags = dep_tags 53 | self.pos_tags = pos_tags 54 | 55 | 56 | class InputFeatures(object): 57 | """A single set of features of data.""" 58 | 59 | def __init__( 60 | self, 61 | input_ids, 62 | input_mask, 63 | segment_ids, 64 | label_ids, 65 | dep_tag_ids=None, 66 | pos_tag_ids=None, 67 | langs=None, 68 | root=None, 69 | heads=None, 70 | depths=None, 71 | trunc_token_ids=None, 72 | sep_token_indices=None, 73 | ): 74 | self.input_ids = input_ids 75 | self.input_mask = input_mask 76 | self.segment_ids = segment_ids 77 | self.label_ids = label_ids 78 | self.dep_tag_ids = dep_tag_ids 79 | self.pos_tag_ids = pos_tag_ids 80 | self.langs = langs 81 | self.root = root 82 | self.heads = heads 83 | self.trunc_token_ids = trunc_token_ids 84 | self.sep_token_indices = sep_token_indices 85 | self.depths = depths 86 | 87 | 88 | def read_examples_from_file(file_path, lang, lang2id=None): 89 | guid_index = 1 90 | examples = [] 91 | lang_id = lang2id.get(lang, lang2id["en"]) if lang2id else 0 92 | logger.info("lang_id={}, lang={}, lang2id={}".format(lang_id, lang, lang2id)) 93 | 94 | if os.path.exists('{}.tsv'.format(file_path)): 95 | with open('{}.tsv'.format(file_path), encoding="utf-8") as f: 96 | words = [] 97 | labels = [] 98 | heads = [] 99 | langs = [] 100 | dep_tags = [] 101 | pos_tags = [] 102 | for line in f: 103 | line = line.strip() 104 | if not line: 105 | examples.append( 106 | InputExample( 107 | guid="{}-{}".format(lang, guid_index), 108 | words=words, 109 | labels=labels, 110 | langs=langs, 111 | heads=heads, 112 | dep_tags=dep_tags, 113 | pos_tags=pos_tags 114 | ) 115 | ) 116 | guid_index += 1 117 | words = [] 118 | labels = [] 119 | langs = [] 120 | heads = [] 121 | dep_tags = [] 122 | pos_tags = [] 123 | continue 124 | 125 | splits = line.split("\t") 126 | words.append(splits[0]) 127 | langs.append(lang_id) 128 | labels.append(splits[1]) 129 | heads.append(int(splits[2])) 130 | dep_tags.append(splits[3].split(':')[0] 131 | if ':' in splits[3] else splits[3]) 132 | dep_tags.append(splits[4]) 133 | 134 | if words: 135 | examples.append( 136 | InputExample( 137 | guid="%s-%d".format(lang, guid_index), 138 | words=words, 139 | labels=labels, 140 | langs=langs, 141 | heads=heads, 142 | dep_tags=dep_tags, 143 | pos_tags=pos_tags 144 | ) 145 | ) 146 | 147 | elif os.path.exists('{}.jsonl'.format(file_path)): 148 | with open('{}.jsonl'.format(file_path), encoding="utf-8") as f: 149 | for line in f: 150 | line = line.strip() 151 | ex = json.loads(line) 152 | examples.append( 153 | InputExample( 154 | guid="%s-%d".format(lang, guid_index), 155 | words=ex['tokens'], 156 | labels=ex['label'], 157 | langs=[lang_id] * len(ex['tokens']), 158 | heads=ex['head'], 159 | dep_tags=[tag.split(':')[0] if ':' in tag else tag \ 160 | for tag in ex['deptag']], 161 | pos_tags=ex['postag'], 162 | ) 163 | ) 164 | 165 | else: 166 | logger.info("[Warning] file {} with neither .tsv or .jsonl exists".format(file_path)) 167 | return [] 168 | 169 | return examples 170 | 171 | 172 | def process_sentence( 173 | token_list, head_list, label_list, dep_tag_list, 174 | pos_tag_list, tokenizer, label_map, pad_token_label_id 175 | ): 176 | """ 177 | When a token gets split into multiple word pieces, 178 | we make all the pieces (except the first) children of the first piece. 179 | However, only the first piece acts as the node that contains 180 | the dependent tokens as the children. 
181 | """ 182 | assert len(token_list) == len(head_list) == len(label_list) == \ 183 | len(dep_tag_list) == len(pos_tag_list) 184 | 185 | text_tokens = [] 186 | text_deptags = [] 187 | text_postags = [] 188 | # My name is Wa ##si Ah ##mad 189 | # 0 1 2 3 3 4 4 190 | sub_tok_to_orig_index = [] 191 | # My name is Wa ##si Ah ##mad 192 | # 0 1 2 3 5 193 | old_index_to_new_index = [] 194 | # My name is Wa ##si Ah ##mad 195 | # 1 1 1 1 0 1 0 196 | first_wpiece_indicator = [] 197 | offset = 0 198 | labels = [] 199 | for i, (token, label) in enumerate(zip(token_list, label_list)): 200 | word_tokens = tokenizer.tokenize(token) 201 | if len(token) != 0 and len(word_tokens) == 0: 202 | word_tokens = [tokenizer.unk_token] 203 | old_index_to_new_index.append(offset) # word piece index 204 | offset += len(word_tokens) 205 | for j, word_token in enumerate(word_tokens): 206 | first_wpiece_indicator += [1] if j == 0 else [0] 207 | labels += [label_map[label]] if j == 0 else [pad_token_label_id] 208 | text_tokens.append(word_token) 209 | sub_tok_to_orig_index.append(i) 210 | text_deptags.append(dep_tag_list[i]) 211 | text_postags.append(pos_tag_list[i]) 212 | 213 | assert len(text_tokens) == len(sub_tok_to_orig_index), \ 214 | "{} != {}".format(len(text_tokens), len(sub_tok_to_orig_index)) 215 | assert len(text_tokens) == len(first_wpiece_indicator) 216 | 217 | text_heads = [] 218 | head_idx = -1 219 | assert max(head_list) <= len(head_list), (max(head_list), len(head_list)) 220 | # iterating over the word pieces to adjust heads 221 | for i, orig_idx in enumerate(sub_tok_to_orig_index): 222 | # orig_idx: index of the original word (the word-piece belong to) 223 | head = head_list[orig_idx] 224 | if head == 0: # root 225 | # if root word is split into multiple pieces, 226 | # we make the first piece as the root node 227 | # and all the other word pieces as the child of the root node 228 | if head_idx == -1: 229 | head_idx = i + 1 230 | text_heads.append(0) 231 | else: 232 | text_heads.append(head_idx) 233 | else: 234 | if first_wpiece_indicator[i] == 1: 235 | # head indices start from 1, so subtracting 1 236 | head = old_index_to_new_index[head - 1] 237 | text_heads.append(head + 1) 238 | else: 239 | # word-piece of a token (except the first) 240 | # so, we make the first piece the parent of all other word pieces 241 | head = old_index_to_new_index[orig_idx] 242 | text_heads.append(head + 1) 243 | 244 | assert len(text_tokens) == len(text_heads), \ 245 | "{} != {}".format(len(text_tokens), len(text_heads)) 246 | 247 | return text_tokens, text_heads, labels, text_deptags, text_postags 248 | 249 | 250 | def convert_examples_to_features( 251 | examples, 252 | label_list, 253 | max_seq_length, 254 | tokenizer, 255 | cls_token_segment_id=0, 256 | sep_token_extra=False, 257 | pad_on_left=False, 258 | pad_token=0, 259 | pad_token_segment_id=0, 260 | pad_token_label_id=-1, 261 | sequence_a_segment_id=0, 262 | mask_padding_with_zero=True, 263 | lang="en", 264 | use_syntax=False, 265 | ): 266 | """Loads a data file into a list of `InputBatch`s 267 | `cls_token_at_end` define the location of the CLS token: 268 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 269 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 270 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 271 | """ 272 | 273 | label_map = {label: i for i, label in enumerate(label_list)} 274 | special_tokens_count = 3 if sep_token_extra else 2 275 | 276 | features = [] 277 | 
over_length_examples = 0 278 | wrong_examples = 0 279 | for (ex_index, example) in enumerate(examples): 280 | if ex_index % 10000 == 0: 281 | logger.info("Writing example %d of %d", ex_index, len(examples)) 282 | 283 | if 0 not in example.heads: 284 | wrong_examples += 1 285 | continue 286 | 287 | tokens, heads, label_ids, dep_tags, pos_tags = process_sentence( 288 | example.words, 289 | example.heads, 290 | example.labels, 291 | example.dep_tags, 292 | example.pos_tags, 293 | tokenizer, 294 | label_map, 295 | pad_token_label_id 296 | ) 297 | 298 | orig_text_len = len(tokens) 299 | root_idx = heads.index(0) 300 | text_offset = 1 # text_a follows 301 | # So, we add 1 to head indices 302 | heads = np.add(heads, text_offset).tolist() 303 | # HEAD( root) = index of (1-based) 304 | heads[root_idx] = 1 305 | 306 | if len(tokens) > max_seq_length - special_tokens_count: 307 | # assert False # we already truncated sequence 308 | # print("truncate token", len(tokens), max_seq_length, special_tokens_count) 309 | # tokens = tokens[: (max_seq_length - special_tokens_count)] 310 | # label_ids = label_ids[: (max_seq_length - special_tokens_count)] 311 | over_length_examples += 1 312 | continue 313 | 314 | tokens += [tokenizer.sep_token] 315 | dep_tags += [tokenizer.sep_token] 316 | pos_tags += [tokenizer.sep_token] 317 | label_ids += [pad_token_label_id] 318 | if sep_token_extra: 319 | # roberta uses an extra separator b/w pairs of sentences 320 | tokens += [tokenizer.sep_token] 321 | dep_tags += [tokenizer.sep_token] 322 | pos_tags += [tokenizer.sep_token] 323 | label_ids += [pad_token_label_id] 324 | segment_ids = [sequence_a_segment_id] * len(tokens) 325 | 326 | # cls_token_at_begining 327 | tokens = [tokenizer.cls_token] + tokens 328 | dep_tags = [tokenizer.cls_token] + dep_tags 329 | pos_tags = [tokenizer.cls_token] + pos_tags 330 | label_ids = [pad_token_label_id] + label_ids 331 | segment_ids = [cls_token_segment_id] + segment_ids 332 | 333 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 334 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 335 | # tokens are attended to. 336 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 337 | 338 | # Zero-pad up to the sequence length. 
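        # (only input ids, mask, segment ids and label ids are padded here; when use_syntax
        # is enabled, dep/POS tag ids are padded further below with id 0, the pad symbol)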
339 | padding_length = max_seq_length - len(input_ids) 340 | if pad_on_left: 341 | input_ids = ([pad_token] * padding_length) + input_ids 342 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 343 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 344 | label_ids = ([pad_token_label_id] * padding_length) + label_ids 345 | else: 346 | input_ids += [pad_token] * padding_length 347 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 348 | segment_ids += [pad_token_segment_id] * padding_length 349 | label_ids += [pad_token_label_id] * padding_length 350 | 351 | if example.langs and len(example.langs) > 0: 352 | langs = [example.langs[0]] * max_seq_length 353 | else: 354 | print("example.langs", example.langs, example.words, len(example.langs)) 355 | print("ex_index", ex_index, len(examples)) 356 | langs = None 357 | 358 | assert len(input_ids) == max_seq_length 359 | assert len(input_mask) == max_seq_length 360 | assert len(segment_ids) == max_seq_length 361 | assert len(label_ids) == max_seq_length 362 | assert len(langs) == max_seq_length 363 | 364 | # if ex_index < 5: 365 | # logger.info("*** Example ***") 366 | # logger.info("guid: %s", example.guid) 367 | # logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 368 | # logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 369 | # logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 370 | # logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 371 | # logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 372 | # logger.info("langs: {}".format(langs)) 373 | 374 | one_ex_features = InputFeatures( 375 | input_ids=input_ids, 376 | input_mask=input_mask, 377 | segment_ids=segment_ids, 378 | label_ids=label_ids, 379 | langs=langs, 380 | ) 381 | 382 | if use_syntax: 383 | ##################################################### 384 | # prepare the UPOS and DEPENDENCY tag tensors 385 | ##################################################### 386 | dep_tag_ids = deptag_to_id(dep_tags, tokenizer=str(type(tokenizer))) 387 | pos_tag_ids = upos_to_id(pos_tags, tokenizer=str(type(tokenizer))) 388 | 389 | if pad_on_left: 390 | dep_tag_ids = ([0] * padding_length) + dep_tag_ids 391 | pos_tag_ids = ([0] * padding_length) + pos_tag_ids 392 | else: 393 | dep_tag_ids += [0] * padding_length 394 | pos_tag_ids += [0] * padding_length 395 | 396 | assert len(input_ids) == len(dep_tag_ids) 397 | assert len(input_ids) == len(pos_tag_ids) 398 | assert len(dep_tag_ids) == max_seq_length 399 | assert len(pos_tag_ids) == max_seq_length 400 | 401 | one_ex_features.tag_ids = pos_tag_ids 402 | one_ex_features.dep_tag_ids = dep_tag_ids 403 | 404 | ##################################################### 405 | # form the tree structure using head information 406 | ##################################################### 407 | heads = [0] + heads + [1, 1] if sep_token_extra else [0] + heads + [1] 408 | assert len(tokens) == len(heads) 409 | root, nodes = head_to_tree(heads, tokens) 410 | assert len(heads) == root.size() 411 | sep_token_indices = [i for i, x in enumerate(tokens) if x == tokenizer.sep_token] 412 | depths = [nodes[i].depth() for i in range(len(nodes))] 413 | depths = np.asarray(depths, dtype=np.int32) 414 | 415 | one_ex_features.root = root 416 | one_ex_features.depths = depths 417 | one_ex_features.sep_token_indices = sep_token_indices 418 | 419 | features.append(one_ex_features) 420 | 421 | if over_length_examples > 0: 422 | 
logger.info('{} examples are discarded due to exceeding maximum length'.format(over_length_examples)) 423 | if wrong_examples > 0: 424 | logger.info('{} wrong examples are discarded'.format(wrong_examples)) 425 | return features 426 | 427 | 428 | def get_labels(path): 429 | with open(path, "r") as f: 430 | labels = f.read().splitlines() 431 | if "O" not in labels: 432 | labels = ["O"] + labels 433 | return labels 434 | 435 | 436 | class SequenceDataset(Dataset): 437 | def __init__(self, features): 438 | self.features = features 439 | 440 | def __len__(self): 441 | return len(self.features) 442 | 443 | def __getitem__(self, index): 444 | """Generates one sample of data""" 445 | feature = self.features[index] 446 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long) 447 | labels = torch.tensor(feature.label_ids, dtype=torch.long) 448 | attention_mask = torch.tensor(feature.input_mask, dtype=torch.long) 449 | token_type_ids = torch.tensor(feature.segment_ids, dtype=torch.long) 450 | 451 | dist_matrix = None 452 | depths = None 453 | dep_tag_ids = None 454 | pos_tag_ids = None 455 | if feature.root is not None: 456 | dep_tag_ids = torch.tensor(feature.dep_tag_ids, dtype=torch.long) 457 | pos_tag_ids = torch.tensor(feature.pos_tag_ids, dtype=torch.long) 458 | dist_matrix = root_to_dist_mat(feature.root) 459 | if feature.trunc_token_ids is not None: 460 | dist_matrix = np.delete(dist_matrix, feature.trunc_token_ids, 0) # delete rows 461 | dist_matrix = np.delete(dist_matrix, feature.trunc_token_ids, 1) # delete columns 462 | 463 | dist_matrix = torch.tensor(dist_matrix, dtype=torch.long) # seq_len x seq_len x max-path-len 464 | 465 | if feature.depths is not None: 466 | depths = feature.depths 467 | if feature.trunc_token_ids is not None: 468 | depths = np.delete(depths, feature.trunc_token_ids, 0) 469 | depths = torch.tensor(depths, dtype=torch.long) # seq_len 470 | 471 | return [ 472 | input_ids, 473 | attention_mask, 474 | token_type_ids, 475 | labels, 476 | dep_tag_ids, 477 | pos_tag_ids, 478 | dist_matrix, 479 | depths, 480 | ] 481 | 482 | 483 | def batchify(batch): 484 | """Receives a batch of SequencePairDataset examples""" 485 | input_ids = torch.stack([data[0] for data in batch], dim=0) 486 | attention_mask = torch.stack([data[1] for data in batch], dim=0) 487 | token_type_ids = torch.stack([data[2] for data in batch], dim=0) 488 | labels = torch.stack([data[3] for data in batch], dim=0) 489 | 490 | dist_matrix = None 491 | depths = None 492 | dep_tag_ids = None 493 | pos_tag_ids = None 494 | 495 | if batch[0][4] is not None: 496 | dep_tag_ids = torch.stack([data[4] for data in batch], dim=0) 497 | 498 | if batch[0][5] is not None: 499 | pos_tag_ids = torch.stack([data[5] for data in batch], dim=0) 500 | 501 | if batch[0][6] is not None: 502 | dist_matrix = torch.full( 503 | (len(batch), input_ids.size(1), input_ids.size(1)), 99999, dtype=torch.long 504 | ) 505 | for i, data in enumerate(batch): 506 | slen, slen = data[6].size() 507 | dist_matrix[i, :slen, :slen] = data[6] 508 | 509 | if batch[0][7] is not None: 510 | depths = torch.full( 511 | (len(batch), input_ids.size(1)), 99999, dtype=torch.long 512 | ) 513 | for i, data in enumerate(batch): 514 | slen = data[7].size(0) 515 | depths[i, :slen] = data[7] 516 | 517 | return [ 518 | input_ids, 519 | attention_mask, 520 | token_type_ids, 521 | labels, 522 | dep_tag_ids, 523 | pos_tag_ids, 524 | dist_matrix, 525 | depths 526 | ] 527 | -------------------------------------------------------------------------------- 
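The `SequenceDataset` and `batchify` helpers above are meant to work as a dataset/collate pair. Below is a minimal sketch of how they could be wired into a PyTorch `DataLoader`; the file path, tokenizer, label list, and batch size are illustrative placeholders, not the repository's actual training configuration (see the task scripts for that).

```python
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from third_party.processors.utils_tag import (
    SequenceDataset,
    batchify,
    convert_examples_to_features,
    read_examples_from_file,
)

# Placeholder tokenizer, data path, and label list; adjust to the checkpoint/dataset actually used.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
examples = read_examples_from_file("download/panx_processed/en/train", lang="en")
features = convert_examples_to_features(
    examples,
    label_list=["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"],
    max_seq_length=128,
    tokenizer=tokenizer,
    use_syntax=True,
)

loader = DataLoader(SequenceDataset(features), batch_size=8, collate_fn=batchify)
for input_ids, attention_mask, token_type_ids, labels, \
        dep_tag_ids, pos_tag_ids, dist_matrix, depths in loader:
    # dist_matrix and depths are padded with 99999 beyond each sentence's parse;
    # dep/POS tag ids use 0 (the pad symbol) for padding positions.
    pass
```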
/third_party/processors/utils_top.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, 3 | # The HuggingFace Inc. team, and The XTREME Benchmark Authors. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Utility functions for NER/POS tagging tasks.""" 18 | 19 | from __future__ import absolute_import, division, print_function 20 | 21 | import logging 22 | import os 23 | import torch 24 | import json 25 | 26 | from io import open 27 | from torch.utils.data import Dataset 28 | from third_party.processors.tree import * 29 | from third_party.processors.constants import * 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for token classification.""" 36 | 37 | def __init__( 38 | self, 39 | guid, 40 | words, 41 | intent_label, 42 | slot_labels, 43 | heads=None, 44 | dep_tags=None, 45 | pos_tags=None 46 | ): 47 | """Constructs a InputExample. 48 | 49 | Args: 50 | guid: Unique id for the example. 51 | words: list. The words of the sequence. 52 | labels: (Optional) list. The labels for each word of the sequence. This should be 53 | specified for train and dev examples, but not for test examples. 
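            intent_label: str. The intent label of the utterance.
            slot_labels: list. The slot label for each word of the utterance.
            heads: (Optional) list. 1-based dependency head index for each word; 0 marks the root.
            dep_tags: (Optional) list. Universal dependency relation label for each word.
            pos_tags: (Optional) list. Universal POS tag for each word.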
54 | """ 55 | self.guid = guid 56 | self.words = words 57 | self.intent_label = intent_label 58 | self.slot_labels = slot_labels 59 | self.heads = heads 60 | self.dep_tags = dep_tags 61 | self.pos_tags = pos_tags 62 | 63 | 64 | class InputFeatures(object): 65 | """A single set of features of data.""" 66 | 67 | def __init__( 68 | self, 69 | input_ids, 70 | input_mask, 71 | segment_ids, 72 | intent_label_id, 73 | slot_label_ids, 74 | dep_tag_ids=None, 75 | pos_tag_ids=None, 76 | root=None, 77 | trunc_token_ids=None, 78 | sep_token_indices=None, 79 | depths=None, 80 | ): 81 | self.input_ids = input_ids 82 | self.input_mask = input_mask 83 | self.segment_ids = segment_ids 84 | self.intent_label_id = intent_label_id 85 | self.slot_label_ids = slot_label_ids 86 | self.root = root 87 | self.trunc_token_ids = trunc_token_ids 88 | self.sep_token_indices = sep_token_indices 89 | self.dep_tag_ids = dep_tag_ids 90 | self.pos_tag_ids = pos_tag_ids 91 | self.depths = depths 92 | 93 | 94 | def read_examples_from_file(file_path, lang, lang2id=None): 95 | if not os.path.exists(file_path): 96 | logger.info("[Warning] file {} not exists".format(file_path)) 97 | return [] 98 | 99 | guid_index = 1 100 | examples = [] 101 | 102 | # { 103 | # "tokens": ["Has", "Angelika", "Kratzer", "video", "messaged", "me", "?"], 104 | # "head": [5, 4, 4, 5, 0, 5, 5], 105 | # "slot_labels": ["O", "B-CONTACT", "I-CONTACT", "B-TYPE_CONTENT", "O", "B-RECIPIENT", "O"], 106 | # "intent_label": "GET_MESSAGE" 107 | # } 108 | with open(file_path, encoding="utf-8") as f: 109 | for line in f: 110 | ex = json.loads(line.strip()) 111 | examples.append( 112 | InputExample( 113 | guid="%s-%d".format(lang, guid_index), 114 | words=ex['tokens'], 115 | intent_label=ex['intent_label'], 116 | slot_labels=ex['slot_labels'], 117 | heads=ex['head'], 118 | dep_tags=[tag.split(':')[0] if ':' in tag else tag \ 119 | for tag in ex['deptag']], 120 | pos_tags=ex['postag'] 121 | ) 122 | ) 123 | 124 | return examples 125 | 126 | 127 | def process_sentence( 128 | token_list, 129 | head_list, 130 | label_list, 131 | dep_tag_list, 132 | pos_tag_list, 133 | tokenizer, 134 | label_map, 135 | pad_token_label_id 136 | ): 137 | """ 138 | When a token gets split into multiple word pieces, 139 | we make all the pieces (except the first) children of the first piece. 140 | However, only the first piece acts as the node that contains 141 | the dependent tokens as the children. 
142 | """ 143 | assert len(token_list) == len(head_list) == len(label_list) \ 144 | == len(dep_tag_list) == len(pos_tag_list) 145 | 146 | text_tokens = [] 147 | text_deptags = [] 148 | text_postags = [] 149 | # My name is Wa ##si Ah ##mad 150 | # 0 1 2 3 3 4 4 151 | sub_tok_to_orig_index = [] 152 | # My name is Wa ##si Ah ##mad 153 | # 0 1 2 3 5 154 | old_index_to_new_index = [] 155 | # My name is Wa ##si Ah ##mad 156 | # 1 1 1 1 0 1 0 157 | first_wpiece_indicator = [] 158 | offset = 0 159 | labels = [] 160 | for i, (token, label) in enumerate(zip(token_list, label_list)): 161 | word_tokens = tokenizer.tokenize(token) 162 | if len(token) != 0 and len(word_tokens) == 0: 163 | word_tokens = [tokenizer.unk_token] 164 | old_index_to_new_index.append(offset) # word piece index 165 | offset += len(word_tokens) 166 | for j, word_token in enumerate(word_tokens): 167 | first_wpiece_indicator += [1] if j == 0 else [0] 168 | labels += [label_map[label]] if j == 0 else [pad_token_label_id] 169 | text_tokens.append(word_token) 170 | sub_tok_to_orig_index.append(i) 171 | text_deptags.append(dep_tag_list[i]) 172 | text_postags.append(pos_tag_list[i]) 173 | 174 | assert len(text_tokens) == len(sub_tok_to_orig_index), \ 175 | "{} != {}".format(len(text_tokens), len(sub_tok_to_orig_index)) 176 | assert len(text_tokens) == len(first_wpiece_indicator) 177 | 178 | text_heads = [] 179 | head_idx = -1 180 | assert max(head_list) <= len(head_list), (max(head_list), len(head_list)) 181 | # iterating over the word pieces to adjust heads 182 | for i, orig_idx in enumerate(sub_tok_to_orig_index): 183 | # orig_idx: index of the original word (the word-piece belong to) 184 | head = head_list[orig_idx] 185 | if head == 0: # root 186 | # if root word is split into multiple pieces, 187 | # we make the first piece as the root node 188 | # and all the other word pieces as the child of the root node 189 | if head_idx == -1: 190 | head_idx = i + 1 191 | text_heads.append(0) 192 | else: 193 | text_heads.append(head_idx) 194 | else: 195 | if first_wpiece_indicator[i] == 1: 196 | # head indices start from 1, so subtracting 1 197 | head = old_index_to_new_index[head - 1] 198 | text_heads.append(head + 1) 199 | else: 200 | # word-piece of a token (except the first) 201 | # so, we make the first piece the parent of all other word pieces 202 | head = old_index_to_new_index[orig_idx] 203 | text_heads.append(head + 1) 204 | 205 | assert len(text_tokens) == len(text_heads), \ 206 | "{} != {}".format(len(text_tokens), len(text_heads)) 207 | 208 | return text_tokens, text_heads, labels, text_deptags, text_postags 209 | 210 | 211 | def convert_examples_to_features( 212 | examples, 213 | label_list, 214 | max_seq_length, 215 | tokenizer, 216 | cls_token_segment_id=0, 217 | sep_token_extra=False, 218 | pad_on_left=False, 219 | pad_token=0, 220 | pad_token_segment_id=0, 221 | pad_token_label_id=-1, 222 | sequence_a_segment_id=0, 223 | mask_padding_with_zero=True, 224 | lang="en", 225 | use_syntax=False, 226 | ): 227 | """Loads a data file into a list of `InputBatch`s 228 | `cls_token_at_end` define the location of the CLS token: 229 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 230 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 231 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 232 | """ 233 | 234 | intent_label_list, slot_label_list = label_list 235 | intent_label_map = {label: i for i, label in enumerate(intent_label_list)} 236 | slot_label_map = 
{label: i for i, label in enumerate(slot_label_list)} 237 | special_tokens_count = 3 if sep_token_extra else 2 238 | 239 | features = [] 240 | over_length_examples = 0 241 | for (ex_index, example) in enumerate(examples): 242 | if ex_index % 10000 == 0: 243 | logger.info("Writing example %d of %d", ex_index, len(examples)) 244 | 245 | tokens, heads, slot_label_ids, dep_tags, pos_tags = process_sentence( 246 | example.words, 247 | example.heads, 248 | example.slot_labels, 249 | example.dep_tags, 250 | example.pos_tags, 251 | tokenizer, 252 | slot_label_map, 253 | pad_token_label_id 254 | ) 255 | 256 | orig_text_len = len(tokens) 257 | root_idx = heads.index(0) 258 | text_offset = 1 # text_a follows 259 | # So, we add 1 to head indices 260 | heads = np.add(heads, text_offset).tolist() 261 | # HEAD( root) = index of (1-based) 262 | heads[root_idx] = 1 263 | 264 | if len(tokens) > max_seq_length - special_tokens_count: 265 | # assert False # we already truncated sequence 266 | # print("truncate token", len(tokens), max_seq_length, special_tokens_count) 267 | # tokens = tokens[: (max_seq_length - special_tokens_count)] 268 | # label_ids = label_ids[: (max_seq_length - special_tokens_count)] 269 | over_length_examples += 1 270 | continue 271 | 272 | tokens += [tokenizer.sep_token] 273 | dep_tags += [tokenizer.sep_token] 274 | pos_tags += [tokenizer.sep_token] 275 | slot_label_ids += [pad_token_label_id] 276 | if sep_token_extra: 277 | # roberta uses an extra separator b/w pairs of sentences 278 | tokens += [tokenizer.sep_token] 279 | dep_tags += [tokenizer.sep_token] 280 | pos_tags += [tokenizer.sep_token] 281 | slot_label_ids += [pad_token_label_id] 282 | segment_ids = [sequence_a_segment_id] * len(tokens) 283 | 284 | # cls_token_at_begining 285 | tokens = [tokenizer.cls_token] + tokens 286 | dep_tags = [tokenizer.cls_token] + dep_tags 287 | pos_tags = [tokenizer.cls_token] + pos_tags 288 | slot_label_ids = [pad_token_label_id] + slot_label_ids 289 | segment_ids = [cls_token_segment_id] + segment_ids 290 | 291 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 292 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 293 | # tokens are attended to. 294 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 295 | 296 | # Zero-pad up to the sequence length. 
297 | padding_length = max_seq_length - len(input_ids) 298 | if pad_on_left: 299 | input_ids = ([pad_token] * padding_length) + input_ids 300 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 301 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 302 | slot_label_ids = ([pad_token_label_id] * padding_length) + slot_label_ids 303 | else: 304 | input_ids += [pad_token] * padding_length 305 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 306 | segment_ids += [pad_token_segment_id] * padding_length 307 | slot_label_ids += [pad_token_label_id] * padding_length 308 | 309 | assert len(input_ids) == max_seq_length 310 | assert len(input_mask) == max_seq_length 311 | assert len(segment_ids) == max_seq_length 312 | assert len(slot_label_ids) == max_seq_length 313 | 314 | # if ex_index < 5: 315 | # logger.info("*** Example ***") 316 | # logger.info("guid: %s", example.guid) 317 | # logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 318 | # logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 319 | # logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 320 | # logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 321 | # logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 322 | # logger.info("langs: {}".format(langs)) 323 | 324 | intent_label_id = intent_label_map[example.intent_label] 325 | 326 | one_ex_features = InputFeatures( 327 | input_ids=input_ids, 328 | input_mask=input_mask, 329 | segment_ids=segment_ids, 330 | intent_label_id=intent_label_id, 331 | slot_label_ids=slot_label_ids, 332 | ) 333 | 334 | if use_syntax: 335 | ##################################################### 336 | # prepare the UPOS and DEPENDENCY tag tensors 337 | ##################################################### 338 | dep_tag_ids = deptag_to_id(dep_tags, tokenizer=str(type(tokenizer))) 339 | pos_tag_ids = upos_to_id(pos_tags, tokenizer=str(type(tokenizer))) 340 | 341 | if pad_on_left: 342 | dep_tag_ids = ([0] * padding_length) + dep_tag_ids 343 | pos_tag_ids = ([0] * padding_length) + pos_tag_ids 344 | else: 345 | dep_tag_ids += [0] * padding_length 346 | pos_tag_ids += [0] * padding_length 347 | 348 | assert len(input_ids) == len(dep_tag_ids) 349 | assert len(input_ids) == len(pos_tag_ids) 350 | assert len(dep_tag_ids) == max_seq_length 351 | assert len(pos_tag_ids) == max_seq_length 352 | 353 | one_ex_features.tag_ids = pos_tag_ids 354 | one_ex_features.dep_tag_ids = dep_tag_ids 355 | 356 | ##################################################### 357 | # form the tree structure using head information 358 | ##################################################### 359 | heads = [0] + heads + [1, 1] if sep_token_extra else [0] + heads + [1] 360 | assert len(tokens) == len(heads) 361 | root, nodes = head_to_tree(heads, tokens) 362 | assert len(heads) == root.size() 363 | sep_token_indices = [i for i, x in enumerate(tokens) if x == tokenizer.sep_token] 364 | depths = [nodes[i].depth() for i in range(len(nodes))] 365 | depths = np.asarray(depths, dtype=np.int32) 366 | 367 | one_ex_features.root = root 368 | one_ex_features.depths = depths 369 | one_ex_features.sep_token_indices = sep_token_indices 370 | 371 | features.append(one_ex_features) 372 | 373 | if over_length_examples > 0: 374 | logger.info('{} examples are discarded due to exceeding maximum length'.format(over_length_examples)) 375 | return features 376 | 377 | 378 | def get_intent_labels(path): 379 | with open(path, 
"r") as f: 380 | labels = f.read().splitlines() 381 | return labels 382 | 383 | 384 | def get_slot_labels(path): 385 | with open(path, "r") as f: 386 | labels = f.read().splitlines() 387 | if "O" not in labels: 388 | labels = ["O"] + labels 389 | return labels 390 | 391 | 392 | class SequenceDataset(Dataset): 393 | def __init__(self, features): 394 | self.features = features 395 | 396 | def __len__(self): 397 | return len(self.features) 398 | 399 | def __getitem__(self, index): 400 | """Generates one sample of data""" 401 | feature = self.features[index] 402 | input_ids = torch.tensor(feature.input_ids, dtype=torch.long) 403 | intent_label_id = torch.tensor([feature.intent_label_id], dtype=torch.long) 404 | slot_label_ids = torch.tensor(feature.slot_label_ids, dtype=torch.long) 405 | attention_mask = torch.tensor(feature.input_mask, dtype=torch.long) 406 | token_type_ids = torch.tensor(feature.segment_ids, dtype=torch.long) 407 | 408 | dist_matrix = None 409 | depths = None 410 | dep_tag_ids = None 411 | pos_tag_ids = None 412 | if feature.root is not None: 413 | dep_tag_ids = torch.tensor(feature.dep_tag_ids, dtype=torch.long) 414 | pos_tag_ids = torch.tensor(feature.pos_tag_ids, dtype=torch.long) 415 | dist_matrix = root_to_dist_mat(feature.root) 416 | if feature.trunc_token_ids is not None: 417 | dist_matrix = np.delete(dist_matrix, feature.trunc_token_ids, 0) # delete rows 418 | dist_matrix = np.delete(dist_matrix, feature.trunc_token_ids, 1) # delete columns 419 | dist_matrix = torch.tensor(dist_matrix, dtype=torch.long) # seq_len x seq_len x max-path-len 420 | 421 | if feature.depths is not None: 422 | depths = feature.depths 423 | if feature.trunc_token_ids is not None: 424 | depths = np.delete(depths, feature.trunc_token_ids, 0) 425 | depths = torch.tensor(depths, dtype=torch.long) # seq_len 426 | 427 | return [ 428 | input_ids, 429 | attention_mask, 430 | token_type_ids, 431 | intent_label_id, 432 | slot_label_ids, 433 | dep_tag_ids, 434 | pos_tag_ids, 435 | dist_matrix, 436 | depths, 437 | ] 438 | 439 | 440 | def batchify(batch): 441 | """Receives a batch of SequencePairDataset examples""" 442 | input_ids = torch.stack([data[0] for data in batch], dim=0) 443 | attention_mask = torch.stack([data[1] for data in batch], dim=0) 444 | token_type_ids = torch.stack([data[2] for data in batch], dim=0) 445 | intent_labels = torch.stack([data[3] for data in batch], dim=0) 446 | slot_labels = torch.stack([data[4] for data in batch], dim=0) 447 | 448 | dist_matrix = None 449 | depths = None 450 | dep_tag_ids = None 451 | pos_tag_ids = None 452 | 453 | if batch[0][5] is not None: 454 | dep_tag_ids = torch.stack([data[5] for data in batch], dim=0) 455 | 456 | if batch[0][6] is not None: 457 | pos_tag_ids = torch.stack([data[6] for data in batch], dim=0) 458 | 459 | if batch[0][7] is not None: 460 | dist_matrix = torch.full( 461 | (len(batch), input_ids.size(1), input_ids.size(1)), 99999, dtype=torch.long 462 | ) 463 | for i, data in enumerate(batch): 464 | slen, slen = data[7].size() 465 | dist_matrix[i, :slen, :slen] = data[7] 466 | 467 | if batch[0][8] is not None: 468 | depths = torch.full( 469 | (len(batch), input_ids.size(1)), 99999, dtype=torch.long 470 | ) 471 | for i, data in enumerate(batch): 472 | slen = data[8].size(0) 473 | depths[i, :slen] = data[8] 474 | 475 | return [ 476 | input_ids, 477 | attention_mask, 478 | token_type_ids, 479 | intent_labels, 480 | slot_labels, 481 | dep_tag_ids, 482 | pos_tag_ids, 483 | dist_matrix, 484 | depths 485 | ] 486 | 487 | 488 | def 
get_exact_match(intent_preds, intent_labels, slot_preds, slot_labels): 489 | """For the cases that intent and all the slots are correct (in one sentence)""" 490 | # Get the intent comparison result 491 | intent_result = (intent_preds == intent_labels) 492 | 493 | # Get the slot comparision result 494 | slot_result = [] 495 | for preds, labels in zip(slot_preds, slot_labels): 496 | assert len(preds) == len(labels) 497 | one_sent_result = True 498 | for p, l in zip(preds, labels): 499 | if p != l: 500 | one_sent_result = False 501 | break 502 | slot_result.append(one_sent_result) 503 | slot_result = np.array(slot_result) 504 | 505 | exact_match_acc = np.multiply(intent_result, slot_result).mean() 506 | return exact_match_acc 507 | -------------------------------------------------------------------------------- /third_party/ud-conversion-tools/conllu_to_conll.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import islice 3 | from pathlib import Path 4 | import argparse 5 | import sys, copy 6 | 7 | from lib.conll import CoNLLReader 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="""Convert conllu to conll format""") 11 | parser.add_argument('input', help="conllu file") 12 | parser.add_argument('output', help="target file", type=Path) 13 | parser.add_argument('--replace_subtokens_with_fused_forms', help="By default removes fused tokens", default=False, action="store_true") 14 | parser.add_argument('--remove_deprel_suffixes', help="Restrict deprels to the common universal subset, e.g. nmod:tmod becomes nmod", default=False, action="store_true") 15 | parser.add_argument('--remove_node_properties', help="space-separated list of node properties to remove: form, lemma, cpostag, postag, feats", choices=['form', 'lemma', 'cpostag','postag','feats'], metavar='prop', type=str, nargs='+') 16 | parser.add_argument('--lang', help="specify a language 2-letter code", default="default") 17 | parser.add_argument('--output_format', choices=['conll2006', 'conll2009', 'conllu'], default="conll2006") 18 | parser.add_argument('--remove_arabic_diacritics', help="remove Arabic short vowels", default=False, action="store_true") 19 | parser.add_argument('--print_comments',default=False,action="store_true") 20 | parser.add_argument('--print_fused_forms',default=False,action="store_true") 21 | 22 | args = parser.parse_args() 23 | 24 | if sys.version_info < (3,0): 25 | print("Sorry, requires Python 3.x.") #suggestion: install anaconda python 26 | sys.exit(1) 27 | 28 | POSRANKPRECEDENCEDICT = defaultdict(list) 29 | POSRANKPRECEDENCEDICT["default"] = "VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART CCONJ SCONJ X PUNCT ".split(" ") 30 | # POSRANKPRECEDENCEDICT["de"] = "PROPN ADP DET ".split(" ") 31 | POSRANKPRECEDENCEDICT["es"] = "VERB AUX PRON ADP DET".split(" ") 32 | POSRANKPRECEDENCEDICT["fr"] = "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" ") 33 | POSRANKPRECEDENCEDICT["it"] = "VERB AUX ADV PRON ADP DET INTJ".split(" ") 34 | 35 | if args.lang in POSRANKPRECEDENCEDICT: 36 | current_pos_precedence_list = POSRANKPRECEDENCEDICT[args.lang] 37 | else: 38 | current_pos_precedence_list = POSRANKPRECEDENCEDICT["default"] 39 | 40 | cio = CoNLLReader() 41 | orig_treebank = cio.read_conll_u(args.input)#, args.keep_fused_forms, args.lang, POSRANKPRECEDENCEDICT) 42 | modif_treebank = copy.copy(orig_treebank) 43 | 44 | # As per Dec 2015 the args.lang variable is redundant once you have 
current_pos_precedence_list 45 | # We keep it for future modifications, i.e. any language-specific modules 46 | for s in modif_treebank: 47 | # print('sentence', s.get_sentence_as_string(printid=True)) 48 | s.filter_sentence_content(args.replace_subtokens_with_fused_forms, args.lang, current_pos_precedence_list,args.remove_node_properties,args.remove_deprel_suffixes,args.remove_arabic_diacritics) 49 | 50 | cio.write_conll(modif_treebank,args.output, args.output_format,print_fused_forms=args.print_fused_forms, print_comments=args.print_comments) 51 | 52 | if __name__ == "__main__": 53 | main() -------------------------------------------------------------------------------- /third_party/ud-conversion-tools/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wasiahmad/Syntax-MBERT/1af0ef5ff9cc7e3e7de6e662e3d677ee157630ca/third_party/ud-conversion-tools/lib/__init__.py -------------------------------------------------------------------------------- /third_party/ud-conversion-tools/lib/conll.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from collections import Counter 3 | import re 4 | 5 | 6 | #TODO make these parse functions static methods of ConllReder 7 | def parse_id(id_str): 8 | if id_str == '_': 9 | return None 10 | if "." in id_str: 11 | return None 12 | ids = tuple(map(int, id_str.split("-"))) 13 | if len(ids) == 1: 14 | return ids[0] 15 | else: 16 | return ids 17 | 18 | def parse_feats(feats_str): 19 | if feats_str == '_': 20 | return {} 21 | feat_pairs = [pair.split("=") for pair in feats_str.split("|")] 22 | return {k: v for k, v in feat_pairs} 23 | 24 | def parse_deps(dep_str): 25 | if dep_str == '_': 26 | return [] 27 | dep_pairs = [pair.split(":") for pair in dep_str.split("|")] 28 | return [(int(pair[0]), pair[1]) for pair in dep_pairs if pair[0].isdigit()] 29 | 30 | 31 | 32 | 33 | class DependencyTree(nx.DiGraph): 34 | """ 35 | A DependencyTree as networkx graph: 36 | nodes store information about tokens 37 | edges store edge related info, e.g. dependency relations 38 | """ 39 | 40 | def __init__(self): 41 | nx.DiGraph.__init__(self) 42 | 43 | def pathtoroot(self, child): 44 | path = [] 45 | newhead = self.head_of(self, child) 46 | while newhead: 47 | path.append(newhead) 48 | newhead = self.head_of(self, newhead) 49 | return path 50 | 51 | def head_of(self, n): 52 | for u, v in self.edges(): 53 | if v == n: 54 | return u 55 | return None 56 | 57 | def get_sentence_as_string(self,printid=False): 58 | out = [] 59 | for token_i in range(1, max(self.nodes()) + 1): 60 | if printid: 61 | out.append(str(token_i)+":"+self.node[token_i]['form']) 62 | else: 63 | out.append(self.node[token_i]['form']) 64 | return u" ".join(out) 65 | 66 | def subsumes(self, head, child): 67 | if head in self.pathtoroot(self, child): 68 | return True 69 | 70 | def remove_arabic_diacritics(self): 71 | # The following code is based on nltk.stem.isri 72 | # It is equivalent to an interative application of isri.norm(word,num=1) 73 | # i.e. 
we do not remove any hamza characters 74 | 75 | re_short_vowels = re.compile(r'[\u064B-\u0652]') 76 | for n in self.nodes(): 77 | self.node[n]["form"] = re_short_vowels.sub('', self.node[n]["form"]) 78 | 79 | 80 | def get_highest_index_of_span(self, span): # retrieves the node index that is closest to root 81 | #TODO: CANDIDATE FOR DEPRECATION 82 | distancestoroot = [len(self.pathtoroot(self, x)) for x in span] 83 | shortestdistancetoroot = min(distancestoroot) 84 | spanhead = span[distancestoroot.index(shortestdistancetoroot)] 85 | return spanhead 86 | 87 | def get_deepest_index_of_span(self, span): # retrieves the node index that is farthest from root 88 | #TODO: CANDIDATE FOR DEPRECATION 89 | distancestoroot = [len(self.pathtoroot(self, x)) for x in span] 90 | longestdistancetoroot = max(distancestoroot) 91 | lownode = span[distancestoroot.index(longestdistancetoroot)] 92 | return lownode 93 | 94 | def span_makes_subtree(self, initidx, endidx): 95 | G = nx.DiGraph() 96 | span_nodes = list(range(initidx,endidx+1)) 97 | span_words = [self.node[x]["form"] for x in span_nodes] 98 | G.add_nodes_from(span_nodes) 99 | for h,d in self.edges(): 100 | if h in span_nodes and d in span_nodes: 101 | G.add_edge(h,d) 102 | return nx.is_tree(G) 103 | 104 | def _choose_spanhead_from_heuristics(self,span_nodes,pos_precedence_list): 105 | distancestoroot = [len(nx.ancestors(self,x)) for x in span_nodes] 106 | shortestdistancetoroot = min(distancestoroot) 107 | distance_counter = Counter(distancestoroot) 108 | 109 | highest_nodes_in_span = [] 110 | # Heuristic Nr 1: If there is one single highest node in the span, it becomes the head 111 | # N.B. no need for the subspan to be a tree if there is one single highest element 112 | if distance_counter[shortestdistancetoroot] == 1: 113 | spanhead = span_nodes[distancestoroot.index(shortestdistancetoroot)] 114 | return spanhead 115 | 116 | # Heuristic Nr 2: Choose by POS ranking the best head out of the highest nodes 117 | for x in span_nodes: 118 | if len(nx.ancestors(self,x)) == shortestdistancetoroot: 119 | highest_nodes_in_span.append(x) 120 | 121 | best_rank = len(pos_precedence_list) + 1 122 | candidate_head = - 1 123 | span_upos = [self.node[x]["cpostag"]for x in highest_nodes_in_span] 124 | for upos, idx in zip(span_upos,highest_nodes_in_span): 125 | if pos_precedence_list.index(upos) < best_rank: 126 | best_rank = pos_precedence_list.index(upos) 127 | candidate_head = idx 128 | return candidate_head 129 | 130 | def _remove_node_properties(self,fields): 131 | for n in sorted(self.nodes()): 132 | for fieldname in self.node[n].keys(): 133 | if fieldname in fields: 134 | self.node[n][fieldname]="_" 135 | 136 | def _remove_deprel_suffixes(self): 137 | for h,d in self.edges(): 138 | if ":" in self[h][d]["deprel"]: 139 | self[h][d]["deprel"]=self[h][d]["deprel"].split(":")[0] 140 | 141 | def _keep_fused_form(self,posPreferenceDicts): 142 | # For a span A,B and external tokens C, such as A > B > C, we have to 143 | # Make A the head of the span 144 | # Attach C-level tokens to A 145 | #Remove B-level tokens, which are the subtokens of the fused form della: de la 146 | 147 | if self.graph["multi_tokens"] == {}: 148 | return 149 | 150 | spanheads = [] 151 | spanhead_fused_token_dict = {} 152 | # This double iteration is overkill, one could skip the spanhead identification 153 | # but in this way we avoid modifying the tree as we read it 154 | for fusedform_idx in sorted(self.graph["multi_tokens"]): 155 | fusedform_start, fusedform_end = 
self.graph["multi_tokens"][fusedform_idx]["id"] 156 | fuseform_span = list(range(fusedform_start,fusedform_end+1)) 157 | spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) 158 | #if not spanhead: 159 | # spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) 160 | spanheads.append(spanhead) 161 | spanhead_fused_token_dict[spanhead] = fusedform_idx 162 | 163 | # try: 164 | # order = list(nx.topological_sort(self)) 165 | # except nx.NetworkXUnfeasible: 166 | # msg = 'Circular dependency detected between hooks' 167 | # problem_graph = ', '.join(f'{a} -> {b}' 168 | # for a, b in nx.find_cycle(self)) 169 | # print('nx.simple_cycles', list(nx.simple_cycles(self))) 170 | # print(problem_graph) 171 | # exit(0) 172 | # for edge in list(nx.simple_cycles(self)): 173 | # self.remove_edge(edge[0], edge[1]) 174 | self = remove_all_cycle(self) 175 | bottom_up_order = [x for x in nx.topological_sort(self) if x in spanheads] 176 | for spanhead in bottom_up_order: 177 | fusedform_idx = spanhead_fused_token_dict[spanhead] 178 | fusedform = self.graph["multi_tokens"][fusedform_idx]["form"] 179 | fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"] 180 | fuseform_span = list(range(fusedform_start,fusedform_end+1)) 181 | 182 | if spanhead: 183 | #Step 1: Replace form of head span (A) with fusedtoken form -- in this way we keep the lemma and features if any 184 | self.node[spanhead]["form"] = fusedform 185 | # 2- Reattach C-level (external dependents) to A 186 | #print(fuseform_span,spanhead) 187 | 188 | internal_dependents = set(fuseform_span) - set([spanhead]) 189 | external_dependents = [nx.bfs_successors(self,x) for x in internal_dependents] 190 | for depdict in external_dependents: 191 | for localhead in depdict: 192 | for ext_dep in depdict[localhead]: 193 | if ext_dep in self[localhead]: 194 | deprel = self[localhead][ext_dep]["deprel"] 195 | self.remove_edge(localhead,ext_dep) 196 | self.add_edge(spanhead,ext_dep,deprel=deprel) 197 | 198 | #3- Remove B-level tokens 199 | for int_dep in internal_dependents: 200 | self.remove_edge(self.head_of(int_dep),int_dep) 201 | self.remove_node(int_dep) 202 | 203 | #4 reconstruct tree at the very end 204 | new_index_dict = {} 205 | for new_node_index, old_node_idex in enumerate(sorted(self.nodes())): 206 | new_index_dict[old_node_idex] = new_node_index 207 | 208 | T = DependencyTree() # Transfer DiGraph, to replace self 209 | 210 | for n in sorted(self.nodes()): 211 | T.add_node(new_index_dict[n],self.node[n]) 212 | 213 | for h, d in self.edges(): 214 | T.add_edge(new_index_dict[h],new_index_dict[d],deprel=self[h][d]["deprel"]) 215 | #4A Quick removal of edges and nodes 216 | self.__init__() 217 | 218 | #4B Rewriting the Deptree in Self 219 | # TODO There must a more elegant way to rewrite self -- self= T for instance? 220 | for n in sorted(T.nodes()): 221 | self.add_node(n,T.node[n]) 222 | 223 | for h,d in T.edges(): 224 | self.add_edge(h,d,T[h][d]) 225 | 226 | # 5. 
remove all fused forms form the multi_tokens field 227 | self.graph["multi_tokens"] = {} 228 | 229 | # if not nx.is_tree(self): 230 | # print("Not a tree after fused-form heuristics:",self.get_sentence_as_string()) 231 | 232 | def filter_sentence_content(self,replace_subtokens_with_fused_forms=False, lang=None, posPreferenceDict=None,node_properties_to_remove=None,remove_deprel_suffixes=False,remove_arabic_diacritics=False): 233 | if replace_subtokens_with_fused_forms: 234 | self._keep_fused_form(posPreferenceDict) 235 | if remove_deprel_suffixes: 236 | self._remove_deprel_suffixes() 237 | if node_properties_to_remove: 238 | self._remove_node_properties(node_properties_to_remove) 239 | if remove_arabic_diacritics: 240 | self.remove_arabic_diacritics() 241 | 242 | def remove_all_cycle(G): 243 | GC = nx.DiGraph(G.edges()) 244 | edges = list(nx.simple_cycles(GC)) 245 | for edge in edges: 246 | for i in range(len(edge)-1): 247 | for j in range(i+1, len(edge)): 248 | a, b = edge[i], edge[j] 249 | if G.has_edge(a, b): 250 | # print('remove {} - {}'.format(a, b)) 251 | G.remove_edge(a, b) 252 | return G 253 | 254 | 255 | class CoNLLReader(object): 256 | """ 257 | conll input/output 258 | """ 259 | 260 | "" "Static properties""" 261 | CONLL06_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('phead', str), ('pdeprel',str)] 262 | #CONLL06_COLUMNS = ['id', 'form', 'lemma', 'cpostag', 'postag', 'feats', 'head', 'deprel', 'phead', 'pdeprel'] 263 | CONLL06DENSE_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('edgew',str)] 264 | CONLL_U_COLUMNS = [('id', parse_id), ('form', str), ('lemma', str), ('cpostag', str), 265 | ('postag', str), ('feats', str), ('head', parse_id), ('deprel', str), 266 | ('deps', parse_deps), ('misc', str)] 267 | #CONLL09_COLUMNS = ['id','form','lemma','plemma','cpostag','pcpostag','feats','pfeats','head','phead','deprel','pdeprel'] 268 | 269 | 270 | 271 | def __init__(self): 272 | pass 273 | 274 | def read_conll_2006(self, filename): 275 | sentences = [] 276 | sent = DependencyTree() 277 | for line_num, conll_line in enumerate(open(filename)): 278 | parts = conll_line.strip().split("\t") 279 | if len(parts) in (8, 10): 280 | token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06_COLUMNS, parts)} 281 | 282 | sent.add_node(token_dict['id'], token_dict) 283 | sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) 284 | elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): 285 | sentences.append(sent) 286 | sent = DependencyTree() 287 | else: 288 | raise Exception("Invalid input format in line nr: ", line_num, conll_line, filename) 289 | 290 | return sentences 291 | 292 | def read_conll_2006_dense(self, filename): 293 | sentences = [] 294 | sent = DependencyTree() 295 | for conll_line in open(filename): 296 | parts = conll_line.strip().split("\t") 297 | if len(parts) == 9: 298 | token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06DENSE_COLUMNS, parts)} 299 | 300 | sent.add_node(token_dict['id'], token_dict) 301 | sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) 302 | elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): 303 | sentences.append(sent) 304 | sent = DependencyTree() 305 | else: 306 | raise Exception("Invalid input format in line: ", conll_line, filename) 307 | 308 | return sentences 309 | 310 | 311 | 
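# Usage sketch (illustrative only; the file names "sample.conllu" and "sample.conll"
# are placeholders, not part of this repository): this mirrors how conllu_to_conll.py
# above drives CoNLLReader -- parse a CoNLL-U treebank with read_conll_u(), optionally
# apply filter_sentence_content() to each sentence, then serialize the result with
# write_conll() defined below.
#
#   from pathlib import Path
#   from lib.conll import CoNLLReader
#
#   reader = CoNLLReader()
#   treebank = reader.read_conll_u("sample.conllu")
#   reader.write_conll(treebank, Path("sample.conll"), "conll2006",
#                      print_fused_forms=False, print_comments=False)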
312 | def write_conll(self, list_of_graphs, conll_path,conllformat, print_fused_forms=False,print_comments=False): 313 | # TODO add comment writing 314 | if conllformat == "conllu": 315 | columns = [colname for colname, fname in self.CONLL_U_COLUMNS] 316 | else: 317 | columns = [colname for colname, fname in self.CONLL06_COLUMNS] 318 | 319 | with conll_path.open('w') as out: 320 | for sent_i, sent in enumerate(list_of_graphs): 321 | if sent_i > 0: 322 | print("", file=out) 323 | if print_comments: 324 | for c in sent.graph["comment"]: 325 | print(c, file=out) 326 | for token_i in range(1, max(sent.nodes()) + 1): 327 | token_dict = dict(sent.node[token_i]) 328 | head_i = sent.head_of(token_i) 329 | if head_i is None: 330 | token_dict['head'] = 0 331 | token_dict['deprel'] = '' 332 | else: 333 | token_dict['head'] = head_i 334 | token_dict['deprel'] = sent[head_i][token_i]['deprel'] 335 | token_dict['id'] = token_i 336 | row = [str(token_dict.get(col, '_')) for col in columns] 337 | if print_fused_forms and token_i in sent.graph["multi_tokens"]: 338 | currentmulti = sent.graph["multi_tokens"][token_i] 339 | currentmulti["id"]=str(currentmulti["id"][0])+"-"+str(currentmulti["id"][1]) 340 | currentmulti["feats"]="_" 341 | currentmulti["head"]="_" 342 | rowmulti = [str(currentmulti.get(col, '_')) for col in columns] 343 | print(u"\t".join(rowmulti),file=out) 344 | print(u"\t".join(row), file=out) 345 | 346 | # emtpy line afterwards 347 | print(u"", file=out) 348 | 349 | 350 | def read_conll_u(self,filename,keepFusedForm=False, lang=None, posPreferenceDict=None): 351 | sentences = [] 352 | sent = DependencyTree() 353 | multi_tokens = {} 354 | 355 | for line_no, line in enumerate(open(filename).readlines()): 356 | line = line.strip("\n") 357 | if not line: 358 | # Add extra properties to ROOT node if exists 359 | if 0 in sent: 360 | for key in ('form', 'lemma', 'cpostag', 'postag'): 361 | sent.node[0][key] = 'ROOT' 362 | 363 | # Handle multi-tokens 364 | sent.graph['multi_tokens'] = multi_tokens 365 | multi_tokens = {} 366 | sentences.append(sent) 367 | sent = DependencyTree() 368 | elif line.startswith("#"): 369 | if 'comment' not in sent.graph: 370 | sent.graph['comment'] = [line] 371 | else: 372 | sent.graph['comment'].append(line) 373 | else: 374 | parts = line.split("\t") 375 | if len(parts) != len(self.CONLL_U_COLUMNS): 376 | error_msg = 'Invalid number of columns in line {} (found {}, expected {})'.format(line_no, len(parts), len(CONLL_U_COLUMNS)) 377 | raise Exception(error_msg) 378 | 379 | token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL_U_COLUMNS, parts)} 380 | if isinstance(token_dict['id'], int): 381 | sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) 382 | sent.node[token_dict['id']].update({k: v for (k, v) in token_dict.items() 383 | if k not in ('head', 'id', 'deprel', 'deps')}) 384 | for head, deprel in token_dict['deps']: 385 | sent.add_edge(head, token_dict['id'], deprel=deprel, secondary=True) 386 | elif token_dict['id'] is not None: 387 | #print(token_dict['id']) 388 | first_token_id = int(token_dict['id'][0]) 389 | multi_tokens[first_token_id] = token_dict 390 | return sentences 391 | -------------------------------------------------------------------------------- /udpipe/conllify.py: -------------------------------------------------------------------------------- 1 | import ufal.udpipe 2 | 3 | # UDPipe supports all MLQA and PAWS-X languages but only 13/15 languages for XNLI 4 | LANG_MAP = { 5 | 'ar': 
'models/arabic-padt-ud-2.5-191206.udpipe', # MLQA, XNLI 6 | 'bg': 'models/bulgarian-btb-ud-2.5-191206.udpipe', # XNLI 7 | 'de': 'models/german-gsd-ud-2.5-191206.udpipe', # MLQA, PAWS-X, XNLI 8 | 'el': 'models/greek-gdt-ud-2.5-191206.udpipe', # XNLI 9 | 'en': 'models/english-ewt-ud-2.5-191206.udpipe', # MLQA, XNLI 10 | 'es': 'models/spanish-gsd-ud-2.5-191206.udpipe', # MLQA, PAWS-X, XNLI 11 | 'fr': 'models/french-gsd-ud-2.5-191206.udpipe', # PAWS-X, XNLI 12 | 'hi': 'models/hindi-hdtb-ud-2.5-191206.udpipe', # MLQA, XNLI 13 | 'ja': 'models/japanese-gsd-ud-2.5-191206.udpipe', # PAWS-X, 14 | 'ko': 'models/korean-gsd-ud-2.5-191206.udpipe', # PAWS-X, 15 | 'ru': 'models/russian-gsd-ud-2.5-191206.udpipe', # XNLI 16 | 'tr': 'models/turkish-imst-ud-2.5-191206.udpipe', # XNLI 17 | 'ur': 'models/urdu-udtb-ud-2.5-191206.udpipe', # XNLI 18 | 'vi': 'models/vietnamese-vtb-ud-2.5-191206.udpipe', # MLQA, XNLI 19 | 'zh': 'models/chinese-gsd-ud-2.5-191206.udpipe', # MLQA, PAWS-X, XNLI 20 | 'pt': 'models/portuguese-gsd-ud-2.5-191206.udpipe', 21 | } 22 | 23 | 24 | class Model: 25 | def __init__(self, lang, model_file=None): 26 | """Load given model.""" 27 | self.lang = lang 28 | if model_file: 29 | self.model = ufal.udpipe.Model.load(model_file) 30 | if not self.model: 31 | raise Exception("Cannot load UDPipe model from file '%s'" % model_file) 32 | else: 33 | self.model = ufal.udpipe.Model.load(LANG_MAP[lang]) 34 | if not self.model: 35 | raise Exception("Cannot load UDPipe model from file '%s'" % LANG_MAP[lang]) 36 | 37 | def tokenize(self, text, *args): 38 | """Tokenize the text and return list of ufal.udpipe.Sentence-s.""" 39 | tokenizer = self.model.newTokenizer(*args) 40 | if not tokenizer: 41 | raise Exception("The model does not have a tokenizer") 42 | return self._read(text, tokenizer) 43 | 44 | def read(self, text, in_format): 45 | """Load text in the given format (conllu|horizontal|vertical) and return list of ufal.udpipe.Sentence-s.""" 46 | input_format = ufal.udpipe.InputFormat.newInputFormat(in_format) 47 | if not input_format: 48 | raise Exception("Cannot create input format '%s'" % in_format) 49 | return self._read(text, input_format) 50 | 51 | def _read(self, text, input_format): 52 | input_format.setText(text) 53 | error = ufal.udpipe.ProcessingError() 54 | sentences = [] 55 | 56 | sentence = ufal.udpipe.Sentence() 57 | while input_format.nextSentence(sentence, error): 58 | sentences.append(sentence) 59 | sentence = ufal.udpipe.Sentence() 60 | if error.occurred(): 61 | raise Exception(error.message) 62 | 63 | return sentences 64 | 65 | def tag(self, sentence): 66 | """Tag the given ufal.udpipe.Sentence (inplace).""" 67 | self.model.tag(sentence, self.model.DEFAULT) 68 | 69 | def parse(self, sentence): 70 | """Parse the given ufal.udpipe.Sentence (inplace).""" 71 | self.model.parse(sentence, self.model.DEFAULT) 72 | 73 | def write(self, sentences, out_format): 74 | """Write given ufal.udpipe.Sentence-s in the required format (conllu|horizontal|vertical).""" 75 | 76 | output_format = ufal.udpipe.OutputFormat.newOutputFormat(out_format) 77 | output = '' 78 | for sentence in sentences: 79 | output += output_format.writeSentence(sentence) 80 | output += output_format.finishDocument() 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /udpipe/mtop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CURRENT_DIR=`pwd` 4 | HOME_DIR=`realpath ..` 5 | LANG=(en es fr de hi) 6 | 7 | 
############################# Downloading UDPipe ############################# 8 | 9 | URL_PREFIX='https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131' 10 | declare -A LANG_MAP 11 | LANG_MAP['en']='english-ewt-ud-2.5-191206.udpipe' 12 | LANG_MAP['fr']='french-gsd-ud-2.5-191206.udpipe' 13 | LANG_MAP['es']='spanish-gsd-ud-2.5-191206.udpipe' 14 | LANG_MAP['de']='german-gsd-ud-2.5-191206.udpipe' 15 | LANG_MAP['hi']='hindi-hdtb-ud-2.5-191206.udpipe' 16 | 17 | OUT_DIR=${CURRENT_DIR}/models 18 | mkdir -p $OUT_DIR 19 | 20 | for lang in ${LANG[@]}; do 21 | if [[ ! -f ${OUT_DIR}/${LANG_MAP[${lang}]} ]]; then 22 | curl -Lo ${OUT_DIR}/${LANG_MAP[${lang}]} ${URL_PREFIX}/${LANG_MAP[${lang}]} 23 | fi 24 | done 25 | 26 | ############################# 27 | 28 | DATA_DIR=${HOME_DIR}/download/mtop 29 | OUT_DIR=${HOME_DIR}/download/mtop_udpipe_processed 30 | mkdir -p $OUT_DIR 31 | 32 | # train data processing 33 | if [[ ! -f ${OUT_DIR}/train-en.jsonl ]]; then 34 | python process.py \ 35 | --task mtop \ 36 | --input_file ${DATA_DIR}/train-en.jsonl \ 37 | --output_file ${OUT_DIR}/train-en.jsonl \ 38 | --pre_lang en \ 39 | --workers 60; 40 | fi 41 | 42 | for lang in "${LANG[@]}"; do 43 | for split in dev test; do 44 | outfile=${OUT_DIR}/${split}-${lang}.jsonl 45 | if [[ ! -f $outfile ]]; then 46 | python process.py \ 47 | --task mtop \ 48 | --input_file ${DATA_DIR}/${split}-${lang}.jsonl \ 49 | --output_file $outfile \ 50 | --pre_lang $lang \ 51 | --workers 60; 52 | fi 53 | done 54 | done 55 | -------------------------------------------------------------------------------- /udpipe/panx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CURRENT_DIR=`pwd` 4 | HOME_DIR=`realpath ..` 5 | LANG=(af ar bg de el en es et fi fr he hi hu id it ja ko mr nl pt ru ta te tr ur vi zh) 6 | 7 | ############################# Downloading UDPipe ############################# 8 | 9 | URL_PREFIX='https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131' 10 | declare -A LANG_MAP 11 | LANG_MAP['af']='afrikaans-afribooms-ud-2.5-191206.udpipe' 12 | LANG_MAP['ar']='arabic-padt-ud-2.5-191206.udpipe' 13 | LANG_MAP['bg']='bulgarian-btb-ud-2.5-191206.udpipe' 14 | LANG_MAP['de']='german-gsd-ud-2.5-191206.udpipe' 15 | LANG_MAP['el']='greek-gdt-ud-2.5-191206.udpipe' 16 | LANG_MAP['en']='english-ewt-ud-2.5-191206.udpipe' 17 | LANG_MAP['es']='spanish-gsd-ud-2.5-191206.udpipe' 18 | LANG_MAP['et']='estonian-edt-ud-2.5-191206.udpipe' 19 | LANG_MAP['fi']='finnish-tdt-ud-2.5-191206.udpipe' 20 | LANG_MAP['fr']='french-gsd-ud-2.5-191206.udpipe' 21 | LANG_MAP['he']='hebrew-htb-ud-2.5-191206.udpipe' 22 | LANG_MAP['hi']='hindi-hdtb-ud-2.5-191206.udpipe' 23 | LANG_MAP['hu']='hungarian-szeged-ud-2.5-191206.udpipe' 24 | LANG_MAP['id']='indonesian-gsd-ud-2.5-191206.udpipe' 25 | LANG_MAP['it']='italian-isdt-ud-2.5-191206.udpipe' 26 | LANG_MAP['ja']='japanese-gsd-ud-2.5-191206.udpipe' 27 | LANG_MAP['ko']='korean-kaist-ud-2.5-191206.udpipe' # 'korean-gsd-ud-2.5-191206.udpipe' 28 | LANG_MAP['mr']='marathi-ufal-ud-2.5-191206.udpipe' 29 | LANG_MAP['nl']='dutch-alpino-ud-2.5-191206.udpipe' 30 | LANG_MAP['pt']='portuguese-bosque-ud-2.5-191206.udpipe' # 'portuguese-gsd-ud-2.5-191206.udpipe' 31 | LANG_MAP['ru']='russian-gsd-ud-2.5-191206.udpipe' 32 | LANG_MAP['ta']='tamil-ttb-ud-2.5-191206.udpipe' 33 | LANG_MAP['te']='telugu-mtg-ud-2.5-191206.udpipe' 34 | LANG_MAP['tr']='turkish-imst-ud-2.5-191206.udpipe' 35 | LANG_MAP['ur']='urdu-udtb-ud-2.5-191206.udpipe' 36 | 
LANG_MAP['vi']='vietnamese-vtb-ud-2.5-191206.udpipe' 37 | LANG_MAP['zh']='chinese-gsd-ud-2.5-191206.udpipe' 38 | 39 | UDPIPE_DIR=${CURRENT_DIR}/models 40 | mkdir -p $UDPIPE_DIR 41 | 42 | for lang in ${LANG[@]}; do 43 | if [[ ! -f ${UDPIPE_DIR}/${LANG_MAP[${lang}]} ]]; then 44 | curl -Lo ${UDPIPE_DIR}/${LANG_MAP[${lang}]} ${URL_PREFIX}/${LANG_MAP[${lang}]} 45 | fi 46 | done 47 | 48 | ############################# 49 | 50 | DATA_DIR=${HOME_DIR}/download/panx 51 | OUT_DIR=${HOME_DIR}/download/panx_udpipe_processed 52 | mkdir -p $OUT_DIR 53 | 54 | # train data processing 55 | if [[ ! -f ${OUT_DIR}/train-en.jsonl ]]; then 56 | python process.py \ 57 | --task panx \ 58 | --input_file ${DATA_DIR}/train-en.tsv \ 59 | --output_file ${OUT_DIR}/train-en.jsonl \ 60 | --pre_lang en \ 61 | --udpipe_model ${UDPIPE_DIR}/${LANG_MAP['en']} \ 62 | --workers 60; 63 | fi 64 | 65 | for lang in "${LANG[@]}"; do 66 | for split in dev test; do 67 | outfile=${OUT_DIR}/${split}-${lang}.jsonl 68 | if [[ ! -f $outfile ]]; then 69 | python process.py \ 70 | --task panx \ 71 | --input_file ${DATA_DIR}/${split}-${lang}.tsv \ 72 | --output_file $outfile \ 73 | --udpipe_model ${UDPIPE_DIR}/${LANG_MAP[${lang}]} \ 74 | --pre_lang $lang \ 75 | --workers 60; 76 | fi 77 | done 78 | done 79 | -------------------------------------------------------------------------------- /udpipe/pawsx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CURRENT_DIR=`pwd` 4 | HOME_DIR=`realpath ..` 5 | LANG=(en fr es de zh ja ko) 6 | 7 | ############################# Downloading UDPipe ############################# 8 | 9 | URL_PREFIX='https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131' 10 | declare -A LANG_MAP 11 | LANG_MAP['en']='english-ewt-ud-2.5-191206.udpipe' 12 | LANG_MAP['fr']='french-gsd-ud-2.5-191206.udpipe' 13 | LANG_MAP['es']='spanish-gsd-ud-2.5-191206.udpipe' 14 | LANG_MAP['de']='german-gsd-ud-2.5-191206.udpipe' 15 | LANG_MAP['zh']='chinese-gsd-ud-2.5-191206.udpipe' 16 | LANG_MAP['ja']='japanese-gsd-ud-2.5-191206.udpipe' 17 | LANG_MAP['ko']='korean-gsd-ud-2.5-191206.udpipe' 18 | 19 | OUT_DIR=${CURRENT_DIR}/models 20 | mkdir -p $OUT_DIR 21 | 22 | for lang in ${LANG[@]}; do 23 | if [[ ! 
-f ${OUT_DIR}/${LANG_MAP[${lang}]} ]]; then 24 | curl -Lo ${OUT_DIR}/${LANG_MAP[${lang}]} ${URL_PREFIX}/${LANG_MAP[${lang}]} 25 | fi 26 | done 27 | 28 | ############################# 29 | 30 | DATA_DIR=${HOME_DIR}/download/pawsx 31 | OUT_DIR=${HOME_DIR}/download/pawsx_udpipe_processed 32 | mkdir -p $OUT_DIR 33 | 34 | # train data processing 35 | python process.py \ 36 | --task pawsx \ 37 | --input_file ${DATA_DIR}/train-en.tsv \ 38 | --output_file ${OUT_DIR}/train-en.jsonl \ 39 | --pre_lang en \ 40 | --hyp_lang en \ 41 | --workers 60; 42 | 43 | for split in dev test; do 44 | for lang in "${LANG[@]}"; do 45 | python process.py \ 46 | --task pawsx \ 47 | --input_file ${DATA_DIR}/${split}-${lang}.tsv \ 48 | --output_file ${OUT_DIR}/${split}-${lang}.jsonl \ 49 | --pre_lang $lang \ 50 | --hyp_lang $lang \ 51 | --workers 60; 52 | done 53 | done 54 | -------------------------------------------------------------------------------- /udpipe/process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | sys.path.append("..") 5 | 6 | import json 7 | import argparse 8 | from tqdm import tqdm 9 | from conllu import parse 10 | from conllify import Model 11 | from multiprocessing import Pool 12 | from collections import OrderedDict 13 | from multiprocessing.util import Finalize 14 | from third_party.processors.tree import head_to_tree 15 | 16 | TOK = None 17 | LANGUAGES = None 18 | UDPIPE_MODELS = None 19 | 20 | 21 | def init(): 22 | global TOK, LANGUAGES 23 | TOK = UDPipeTokenizer(LANGUAGES, UDPIPE_MODELS) 24 | Finalize(TOK, TOK.shutdown, exitpriority=100) 25 | 26 | 27 | class UDPipeTokenizer(object): 28 | 29 | def __init__(self, langs, udpipe_models=None): 30 | self.models = {} 31 | if udpipe_models is not None: 32 | assert len(langs) == len(udpipe_models) 33 | for i, l in enumerate(langs): 34 | if udpipe_models: 35 | self.models[l] = Model(l, model_file=udpipe_models[i]) 36 | else: 37 | self.models[l] = Model(l) 38 | 39 | def shutdown(self): 40 | pass 41 | 42 | def tokenize(self, text, lang): 43 | assert lang in self.models 44 | sentences = self.models[lang].tokenize(text, 'presegmented') 45 | for s in sentences: 46 | self.models[lang].tag(s) 47 | self.models[lang].parse(s) 48 | conllu = self.models[lang].write(sentences, "conllu") 49 | sentences = parse(conllu) 50 | outObj = OrderedDict([ 51 | ('tokens', []), 52 | ('upos', []), 53 | ('head', []), 54 | ('deprel', []) 55 | ]) 56 | 57 | # NOTE: num_tokens != num_words in a sentence, because tokens can be multi-word 58 | # we only use the first word's information for a multi-word token 59 | # IDEA CREDIT: https://github.com/ufal/udpipe/issues/123 60 | for idx, sentence in enumerate(sentences): 61 | tokens, upos, head, deprel = [], [], [], [] 62 | word_to_token_map = {} 63 | for widx, word in enumerate(sentence): 64 | if isinstance(word['id'], tuple): 65 | # multi-word token, e.g., word['id'] = (4, '-', 5) 66 | assert len(word['id']) == 3 67 | start, end = int(word['id'][0]), int(word['id'][2]) 68 | for word_id in list(range(start, end + 1)): 69 | assert word_id not in word_to_token_map 70 | word_to_token_map[word_id] = start 71 | else: 72 | if word['misc'] is not None: 73 | # single-word token 74 | assert word['id'] not in word_to_token_map 75 | word_to_token_map[word['id']] = word['id'] 76 | tokens.append(word['form']) 77 | upos.append(word['upostag']) 78 | deprel.append(word['deprel']) 79 | assert isinstance(word['head'], int) 80 | head.append(word['head']) 81 | 82 | assert 
len(tokens) == len(upos) == len(head) == len(deprel) 83 | outObj['tokens'].append(tokens) 84 | outObj['upos'].append(upos) 85 | outObj['head'].append(head) 86 | outObj['deprel'].append(deprel) 87 | 88 | return outObj 89 | 90 | def tokenize_pretokenized_sentence(self, tokens, lang): 91 | assert lang in self.models 92 | # My name is Wasi Ahmad 93 | # token_ranges = [[0, 2], [3, 7], [8, 10], [11, 15], [16, 21]] 94 | offset, token_ranges = 0, [] 95 | for t in tokens: 96 | token_ranges.append([offset, offset + len(t)]) 97 | offset += len(t) + 1 98 | 99 | sentences = self.models[lang].tokenize(' '.join(tokens), 'ranges;presegmented') 100 | for s in sentences: 101 | self.models[lang].tag(s) 102 | self.models[lang].parse(s) 103 | conllu = self.models[lang].write(sentences, "conllu") 104 | sentences = parse(conllu) 105 | assert len(sentences) == 1 106 | 107 | words, deptags, upos, heads, word_to_token = [], [], [], [], [] 108 | _token_range = None 109 | for widx, word in enumerate(sentences[0]): 110 | word = sentences[0][widx] 111 | if word['misc'] is not None: 112 | _token_range = word['misc']['TokenRange'].split(':') 113 | start, end = int(_token_range[0]), int(_token_range[1]) 114 | if isinstance(word['id'], tuple): 115 | # multi-word token, e.g., word['id'] = (4, '-', 5) 116 | pass 117 | else: 118 | words.append(word['form']) 119 | deptags.append(word['deprel']) 120 | upos.append(word['upostag']) 121 | assert isinstance(word['head'], int) 122 | heads.append(word['head']) 123 | match_indices = [] 124 | # sometimes, during tokenization multiple tokens get merged 125 | # rect 230 550 300 620 Karl-Heinz Schnellinger 126 | # after tokenization 127 | # ['rect', '230 550 300 620', 'Karl-Heinz', 'Schnellinger'] 128 | for j, o in enumerate(token_ranges): 129 | if start >= o[0] and end <= o[1]: 130 | match_indices.append(j) 131 | break 132 | elif start == o[0]: 133 | match_indices.append(j) 134 | elif end == o[1]: 135 | match_indices.append(j) 136 | 137 | if len(match_indices) == 0: 138 | return None 139 | word_to_token.append(match_indices[0]) 140 | 141 | if len(words) != len(word_to_token): 142 | print(lang, tokens, words) 143 | assert False 144 | 145 | assert max(heads) <= len(heads) 146 | root, _ = head_to_tree(heads, words) 147 | # verifying if we can construct the tree from heads 148 | assert len(heads) == root.size() 149 | outObj = OrderedDict([ 150 | ('tokens', words), 151 | ('deptag', deptags), 152 | ('upostag', upos), 153 | ('head', heads), 154 | ('word_to_token', word_to_token) 155 | ]) 156 | 157 | return outObj 158 | 159 | 160 | def xnli_pawsx_process(example): 161 | premise = TOK.tokenize(example['premise'], lang=LANGUAGES[0]) 162 | hypothesis = TOK.tokenize(example['hypothesis'], lang=LANGUAGES[1]) 163 | 164 | if len(premise['tokens']) > 0 and len(hypothesis['tokens']) > 0: 165 | return { 166 | 'premise': { 167 | 'text': example['premise'], 168 | 'tokens': premise['tokens'][0], 169 | 'upos': premise['upos'][0], 170 | 'head': premise['head'][0], 171 | 'deprel': premise['deprel'][0], 172 | }, 173 | 'hypothesis': { 174 | 'text': example['hypothesis'], 175 | 'tokens': hypothesis['tokens'][0], 176 | 'upos': hypothesis['upos'][0], 177 | 'head': hypothesis['head'][0], 178 | 'deprel': hypothesis['deprel'][0], 179 | }, 180 | 'label': example['label'] 181 | } 182 | else: 183 | return None 184 | 185 | 186 | def xnli_pawsx_tokenization(infile, outfile, pre_lang, hyp_lang, workers=5): 187 | def load_dataset(path): 188 | """Load json file and store fields separately.""" 189 | output = [] 190 | with 
open(path) as f: 191 | for line in f: 192 | splits = line.strip().split('\t') 193 | if len(splits) != 3: 194 | continue 195 | output.append({ 196 | 'premise': splits[0], 197 | 'hypothesis': splits[1], 198 | 'label': splits[2] 199 | }) 200 | return output 201 | 202 | global LANGUAGES 203 | LANGUAGES = [pre_lang, hyp_lang] 204 | pool = Pool(workers, initializer=init) 205 | 206 | processed_dataset = [] 207 | dataset = load_dataset(infile) 208 | with tqdm(total=len(dataset), desc='Processing') as pbar: 209 | for i, ex in enumerate(pool.imap(xnli_pawsx_process, dataset, 100)): 210 | pbar.update() 211 | if ex is not None: 212 | processed_dataset.append(ex) 213 | 214 | with open(outfile, 'w', encoding='utf-8') as fw: 215 | data_to_write = [json.dumps(ex, ensure_ascii=False) for ex in processed_dataset] 216 | fw.write('\n'.join(data_to_write)) 217 | 218 | 219 | def panx_process(example): 220 | sentence = TOK.tokenize_pretokenized_sentence(example['sentence'], LANGUAGES[0]) 221 | if sentence is None: 222 | return None 223 | 224 | labels = [] 225 | for i, tidx in enumerate(sentence['word_to_token']): 226 | labels.append(example['label'][tidx]) 227 | 228 | assert len(sentence['tokens']) == len(labels) 229 | assert len(sentence['head']) == len(labels) 230 | 231 | return { 232 | 'tokens': sentence['tokens'], 233 | 'head': sentence['head'], 234 | 'deptag': sentence['deptag'], 235 | 'postag': sentence['upostag'], 236 | 'label': labels 237 | } 238 | 239 | 240 | def panx_tokenization(infile, outfile, pre_lang, workers=5, udpipe_model=None, separator='\t'): 241 | def load_dataset(path): 242 | """Load json file and store fields separately.""" 243 | output = [] 244 | with open(path, encoding='utf-8') as f: 245 | tokens, labels = [], [] 246 | for line in f: 247 | splits = line.strip().split(separator) 248 | if len(splits) == 2: 249 | tokens.append(splits[0]) 250 | labels.append(splits[1]) 251 | else: 252 | if tokens: 253 | output.append({ 254 | 'sentence': tokens, 255 | 'label': labels 256 | }) 257 | tokens, labels = [], [] 258 | 259 | if tokens: 260 | output.append({ 261 | 'sentence': tokens, 262 | 'label': labels 263 | }) 264 | 265 | return output 266 | 267 | processed_dataset = [] 268 | dataset = load_dataset(infile) 269 | 270 | global LANGUAGES, UDPIPE_MODELS 271 | LANGUAGES = [pre_lang] 272 | UDPIPE_MODELS = [udpipe_model] 273 | pool = Pool(workers, initializer=init) 274 | 275 | desc_msg = '[{}] Processing'.format(pre_lang) 276 | with tqdm(total=len(dataset), desc=desc_msg) as pbar: 277 | for i, ex in enumerate(pool.imap(panx_process, dataset, 100)): 278 | pbar.update() 279 | if ex is not None: 280 | processed_dataset.append(ex) 281 | 282 | assert len(processed_dataset) <= len(dataset) 283 | if len(processed_dataset) < len(dataset): 284 | print('{} out of {} examples are discarded'.format( 285 | len(dataset) - len(processed_dataset), len(dataset) 286 | )) 287 | 288 | with open(outfile, 'w', encoding='utf-8') as fw: 289 | for ex in processed_dataset: 290 | assert len(ex['tokens']) == len(ex['label']) == len(ex['head']) 291 | fw.write(json.dumps(ex) + '\n') 292 | 293 | 294 | def mtop_process(example): 295 | # { 296 | # 'tokens': words, 297 | # 'slot_labels': bio_tags, 298 | # 'intent_label': intent 299 | # } 300 | sentence = TOK.tokenize_pretokenized_sentence(example['tokens'], LANGUAGES[0]) 301 | if sentence is None: 302 | return None 303 | 304 | labels = [] 305 | for i, tidx in enumerate(sentence['word_to_token']): 306 | labels.append(example['slot_labels'][tidx]) 307 | 308 | assert 
len(sentence['tokens']) == len(labels) 309 | assert len(sentence['head']) == len(labels) 310 | 311 | return { 312 | 'tokens': sentence['tokens'], 313 | 'deptag': sentence['deptag'], 314 | 'postag': sentence['upostag'], 315 | 'head': sentence['head'], 316 | 'slot_labels': labels, 317 | 'intent_label': example['intent_label'] 318 | } 319 | 320 | 321 | def mtop_tokenization(infile, outfile, pre_lang, workers=5): 322 | processed_dataset = [] 323 | with open(infile) as f: 324 | dataset = [json.loads(line.strip()) for line in f] 325 | 326 | global LANGUAGES 327 | LANGUAGES = [pre_lang] 328 | pool = Pool(workers, initializer=init) 329 | 330 | desc_msg = '[{}] Processing'.format(pre_lang) 331 | with tqdm(total=len(dataset), desc=desc_msg) as pbar: 332 | for i, ex in enumerate(pool.imap(mtop_process, dataset, 100)): 333 | pbar.update() 334 | if ex is not None: 335 | processed_dataset.append(ex) 336 | 337 | assert len(processed_dataset) <= len(dataset) 338 | if len(processed_dataset) < len(dataset): 339 | print('{} out of {} examples are discarded'.format( 340 | len(dataset) - len(processed_dataset), len(dataset) 341 | )) 342 | 343 | with open(outfile, 'w', encoding='utf-8') as fw: 344 | for ex in processed_dataset: 345 | fw.write(json.dumps(ex) + '\n') 346 | 347 | 348 | def matis_tokenization(infile, outfile, pre_lang, workers=5): 349 | processed_dataset = [] 350 | dataset = [] 351 | mismatch = 0 352 | with open(infile) as f: 353 | next(f) 354 | for line in f: 355 | split = line.strip().split('\t') 356 | tokens = split[1].split() 357 | slot_labels = split[2].split() 358 | if len(tokens) != len(slot_labels): 359 | if len(tokens) != len(slot_labels): 360 | mismatch += 1 361 | # print(split[0], tokens, slot_labels, len(tokens), len(slot_labels)) 362 | continue 363 | dataset.append({ 364 | 'tokens': tokens, 365 | 'slot_labels': slot_labels, 366 | 'intent_label': split[3] 367 | }) 368 | 369 | print('{} examples are discarded due to mismatch in #tokens and #slot_labels'.format(mismatch)) 370 | 371 | global LANGUAGES 372 | LANGUAGES = [pre_lang] 373 | pool = Pool(workers, initializer=init) 374 | 375 | desc_msg = '[{}] Processing'.format(pre_lang) 376 | with tqdm(total=len(dataset), desc=desc_msg) as pbar: 377 | for i, ex in enumerate(pool.imap(mtop_process, dataset, 100)): 378 | pbar.update() 379 | if ex is not None: 380 | processed_dataset.append(ex) 381 | 382 | assert len(processed_dataset) <= len(dataset) 383 | if len(processed_dataset) < len(dataset): 384 | print('{} out of {} examples are discarded'.format( 385 | len(dataset) - len(processed_dataset), len(dataset) 386 | )) 387 | 388 | with open(outfile, 'w', encoding='utf-8') as fw: 389 | for ex in processed_dataset: 390 | fw.write(json.dumps(ex) + '\n') 391 | 392 | 393 | if __name__ == '__main__': 394 | languages = ['af', 'ar', 'bg', 'de', 'el', 'en', 'es', 'et', 'fi', 'fr', 395 | 'he', 'hi', 'hu', 'id', 'it', 'ja', 'ko', 'mr', 'nl', 'pt', 396 | 'ru', 'ta', 'te', 'tr', 'ur', 'vi', 'zh'] 397 | parser = argparse.ArgumentParser() 398 | parser.add_argument('--input_file', type=str, required=True, help="Path of the source data file") 399 | parser.add_argument('--output_file', type=str, required=True, help="Path of the processed data file") 400 | parser.add_argument('--pre_lang', type=str, help="Premise language", default='en', choices=languages) 401 | parser.add_argument('--hyp_lang', type=str, help="Hypothesis language", default='en', choices=languages) 402 | parser.add_argument('--task', type=str, default='pawsx', help="Task name", 403 | 
choices=['pawsx', 'xnli', 'panx', 'mtop', 'matis', 'ner']) 404 | parser.add_argument('--tokenizer', type=str, default='udpipe', choices=['udpipe'], 405 | help="How to perform tokenization") 406 | parser.add_argument('--udpipe_model', type=str, default=None, 407 | help="Path of the UDPipe model") 408 | parser.add_argument('--workers', type=int, default=60) 409 | args = parser.parse_args() 410 | 411 | if args.tokenizer == 'udpipe': 412 | if args.task in ['pawsx', 'xnli']: 413 | xnli_pawsx_tokenization( 414 | args.input_file, args.output_file, args.pre_lang, args.hyp_lang, args.workers 415 | ) 416 | elif args.task == 'panx': 417 | panx_tokenization( 418 | args.input_file, args.output_file, args.pre_lang, args.workers, args.udpipe_model 419 | ) 420 | elif args.task == 'ner': 421 | panx_tokenization( 422 | args.input_file, args.output_file, args.pre_lang, args.workers, 423 | args.udpipe_model, ' ' 424 | ) 425 | elif args.task == 'mtop': 426 | mtop_tokenization( 427 | args.input_file, args.output_file, args.pre_lang, args.workers 428 | ) 429 | elif args.task == 'matis': 430 | matis_tokenization( 431 | args.input_file, args.output_file, args.pre_lang, args.workers 432 | ) 433 | -------------------------------------------------------------------------------- /udpipe/xnli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CURRENT_DIR=`pwd` 4 | HOME_DIR=`realpath ..` 5 | LANG=(en fr es de el bg ru tr ar vi zh hi ur) 6 | 7 | ############################# Downloading UDPipe ############################# 8 | 9 | URL_PREFIX='https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131' 10 | declare -A LANG_MAP 11 | LANG_MAP['en']='english-ewt-ud-2.5-191206.udpipe' 12 | LANG_MAP['fr']='french-gsd-ud-2.5-191206.udpipe' 13 | LANG_MAP['es']='spanish-gsd-ud-2.5-191206.udpipe' 14 | LANG_MAP['de']='german-gsd-ud-2.5-191206.udpipe' 15 | LANG_MAP['el']='greek-gdt-ud-2.5-191206.udpipe' 16 | LANG_MAP['bg']='bulgarian-btb-ud-2.5-191206.udpipe' 17 | LANG_MAP['ru']='russian-gsd-ud-2.5-191206.udpipe' 18 | LANG_MAP['tr']='turkish-imst-ud-2.5-191206.udpipe' 19 | LANG_MAP['ar']='arabic-padt-ud-2.5-191206.udpipe' 20 | LANG_MAP['vi']='vietnamese-vtb-ud-2.5-191206.udpipe' 21 | LANG_MAP['zh']='chinese-gsd-ud-2.5-191206.udpipe' 22 | LANG_MAP['hi']='hindi-hdtb-ud-2.5-191206.udpipe' 23 | LANG_MAP['ur']='urdu-udtb-ud-2.5-191206.udpipe' 24 | 25 | OUT_DIR=${CURRENT_DIR}/models 26 | mkdir -p $OUT_DIR 27 | 28 | for lang in ${LANG[@]}; do 29 | if [[ ! 
-f ${OUT_DIR}/${LANG_MAP[${lang}]} ]]; then 30 | curl -Lo ${OUT_DIR}/${LANG_MAP[${lang}]} ${URL_PREFIX}/${LANG_MAP[${lang}]} 31 | fi 32 | done 33 | 34 | ############################# 35 | 36 | DATA_DIR=${HOME_DIR}/download/xnli 37 | OUT_DIR=${HOME_DIR}/download/xnli_udpipe_processed 38 | mkdir -p $OUT_DIR 39 | 40 | # train data processing 41 | python process.py \ 42 | --task xnli \ 43 | --input_file ${DATA_DIR}/train-en.tsv \ 44 | --output_file ${OUT_DIR}/train-en.jsonl \ 45 | --pre_lang en \ 46 | --hyp_lang en \ 47 | --workers 60; 48 | 49 | for split in dev test; do 50 | for lang in "${LANG[@]}"; do 51 | python process.py \ 52 | --task xnli \ 53 | --input_file ${DATA_DIR}/${split}-${lang}.tsv \ 54 | --output_file ${OUT_DIR}/${split}-${lang}.jsonl \ 55 | --pre_lang $lang \ 56 | --hyp_lang $lang \ 57 | --workers 60; 58 | done 59 | done 60 | -------------------------------------------------------------------------------- /utils_preprocess.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Google and DeepMind. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function 16 | 17 | import os 18 | import ast 19 | import csv 20 | import json 21 | import shutil 22 | import random 23 | import argparse 24 | 25 | from collections import defaultdict 26 | from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer 27 | 28 | TOKENIZERS = { 29 | "bert": BertTokenizer, 30 | "xlm": XLMTokenizer, 31 | "xlmr": XLMRobertaTokenizer, 32 | "syntax-xlmr": XLMRobertaTokenizer, 33 | } 34 | 35 | 36 | def panx_tokenize_preprocess(args): 37 | def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): 38 | if not os.path.exists(infile): 39 | print(f"{infile} not exists") 40 | return 0 41 | special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 42 | max_seq_len = max_len - special_tokens_count 43 | subword_len_counter = idx = 0 44 | exception_counter = 0 45 | with open(infile, "rt") as fin, open(outfile, "w") as fout, open( 46 | idxfile, "w" 47 | ) as fidx: 48 | for line in fin: 49 | line = line.strip() 50 | if not line: 51 | fout.write("\n") 52 | fidx.write("\n") 53 | idx += 1 54 | subword_len_counter = 0 55 | continue 56 | 57 | items = line.split() 58 | token = items[0].strip() 59 | label = items[1].strip() 60 | head = -1 61 | if len(items) == 3: 62 | head = items[2].strip() 63 | current_subwords_len = len(tokenizer.tokenize(token)) 64 | 65 | if ( 66 | current_subwords_len == 0 or current_subwords_len > max_seq_len 67 | ) and len(token) != 0: 68 | token = tokenizer.unk_token 69 | current_subwords_len = 1 70 | 71 | # adding current token exceeds max sequence length 72 | # so start a new sequence 73 | if (subword_len_counter + current_subwords_len) > max_seq_len: 74 | fout.write(f"\n{token}\t{label}\t{head}\n") 75 | fidx.write(f"\n{idx}\n") 76 | subword_len_counter = current_subwords_len 77 | 
exception_counter += 1 78 | else: 79 | fout.write(f"{token}\t{label}\t{head}\n") 80 | fidx.write(f"{idx}\n") 81 | subword_len_counter += current_subwords_len 82 | 83 | print(f"{exception_counter} examples are truncated.") 84 | return 1 85 | 86 | model_type = args.model_type 87 | tokenizer = TOKENIZERS[model_type].from_pretrained( 88 | args.model_name_or_path, 89 | do_lower_case=args.do_lower_case, 90 | cache_dir=args.cache_dir if args.cache_dir else None, 91 | ) 92 | for lang in args.languages.split(","): 93 | out_dir = os.path.join(args.output_dir, lang) 94 | if not os.path.exists(out_dir): 95 | os.makedirs(out_dir) 96 | if lang == "en": 97 | files = ["dev", "test", "train"] 98 | else: 99 | files = ["dev", "test"] 100 | suffix = args.model_name_or_path 101 | if os.path.isdir(suffix): 102 | suffix = list(filter(None, suffix.split("/"))).pop() 103 | for file in files: 104 | infile = os.path.join(args.data_dir, f"{file}-{lang}.tsv") 105 | outfile = os.path.join(out_dir, "{}.{}".format(file, suffix)) 106 | idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, suffix)) 107 | if os.path.exists(outfile) and os.path.exists(idxfile): 108 | print(f"{outfile} and {idxfile} exist") 109 | else: 110 | code = _preprocess_one_file( 111 | infile, outfile, idxfile, tokenizer, args.max_len 112 | ) 113 | if code > 0: 114 | print(f"finish preprocessing {outfile}") 115 | 116 | 117 | def panx_preprocess(args): 118 | def _process_one_file(infile, outfile): 119 | lines = open(infile, "r").readlines() 120 | if lines[-1].strip() == "": 121 | lines = lines[:-1] 122 | with open(outfile, "w") as fout: 123 | for l in lines: 124 | items = l.strip().split("\t") 125 | if len(items) == 2: 126 | label = items[1].strip() 127 | idx = items[0].find(":") 128 | if idx != -1: 129 | token = items[0][idx + 1:].strip() 130 | fout.write(f"{token}\t{label}\n") 131 | # if 'test' in infile: 132 | # fout.write(f'{token}\n') 133 | # else: 134 | # fout.write(f'{token}\t{label}\n') 135 | else: 136 | fout.write("\n") 137 | 138 | if not os.path.exists(args.output_dir): 139 | os.makedirs(args.output_dir) 140 | langs = "ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu".split( 141 | " " 142 | ) 143 | for lg in langs: 144 | for split in ["train", "test", "dev"]: 145 | infile = os.path.join(args.data_dir, f"{lg}-{split}") 146 | outfile = os.path.join(args.output_dir, f"{split}-{lg}.tsv") 147 | _process_one_file(infile, outfile) 148 | 149 | 150 | def udpos_tokenize_preprocess(args): 151 | def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): 152 | if not os.path.exists(infile): 153 | print(f"{infile} does not exist") 154 | return 155 | subword_len_counter = idx = 0 156 | special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 157 | max_seq_len = max_len - special_tokens_count 158 | with open(infile, "rt") as fin, open(outfile, "w") as fout, open( 159 | idxfile, "w" 160 | ) as fidx: 161 | for line in fin: 162 | line = line.strip() 163 | if len(line) == 0 or line == "": 164 | fout.write("\n") 165 | fidx.write("\n") 166 | idx += 1 167 | subword_len_counter = 0 168 | continue 169 | 170 | items = line.split() 171 | if len(items) == 2: 172 | label = items[1].strip() 173 | else: 174 | label = "X" 175 | token = items[0].strip() 176 | current_subwords_len = len(tokenizer.tokenize(token)) 177 | 178 | if ( 179 | current_subwords_len == 0 or current_subwords_len > max_seq_len 180 | ) and len(token) != 0: 181 | token = tokenizer.unk_token 182 | 
current_subwords_len = 1 183 | 184 | if (subword_len_counter + current_subwords_len) > max_seq_len: 185 | fout.write(f"\n{token}\t{label}\n") 186 | fidx.write(f"\n{idx}\n") 187 | subword_len_counter = current_subwords_len 188 | else: 189 | fout.write(f"{token}\t{label}\n") 190 | fidx.write(f"{idx}\n") 191 | subword_len_counter += current_subwords_len 192 | 193 | model_type = args.model_type 194 | tokenizer = TOKENIZERS[model_type].from_pretrained( 195 | args.model_name_or_path, 196 | do_lower_case=args.do_lower_case, 197 | cache_dir=args.cache_dir if args.cache_dir else None, 198 | ) 199 | for lang in args.languages.split(","): 200 | out_dir = os.path.join(args.output_dir, lang) 201 | if not os.path.exists(out_dir): 202 | os.makedirs(out_dir) 203 | if lang == "en": 204 | files = ["dev", "test", "train"] 205 | else: 206 | files = ["dev", "test"] 207 | suffix = args.model_name_or_path 208 | if os.path.isdir(suffix): 209 | suffix = list(filter(None, suffix.split("/"))).pop() 210 | for file in files: 211 | infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang)) 212 | outfile = os.path.join(out_dir, "{}.{}".format(file, suffix)) 213 | idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, suffix)) 214 | if os.path.exists(outfile) and os.path.exists(idxfile): 215 | print(f"{outfile} and {idxfile} exist") 216 | else: 217 | _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) 218 | print(f"finish preprocessing {outfile}") 219 | 220 | 221 | def udpos_preprocess(args): 222 | def _read_one_file(file): 223 | data = [] 224 | sent, tag, lines, head, deptag = [], [], [], [], [] 225 | for line in open(file, "r"): 226 | items = line.strip().split("\t") 227 | if len(items) != 10: 228 | empty = all(w == "_" for w in sent) 229 | if not empty: 230 | data.append((sent, tag, head, deptag, lines)) 231 | sent, tag, head, deptag, lines = [], [], [], [], [] 232 | else: 233 | sent.append(items[1].strip()) 234 | tag.append(items[3].strip()) 235 | deptag.append(items[7].strip()) 236 | head.append(int(items[6].strip())) 237 | lines.append(line.strip()) 238 | assert len(sent) == int(items[0]), \ 239 | "line={}, sent={}, tag={}".format(line, sent, tag) 240 | return data 241 | 242 | def isfloat(value): 243 | try: 244 | float(value) 245 | return True 246 | except ValueError: 247 | return False 248 | 249 | def remove_empty_space(data): 250 | new_data = {} 251 | for split in data: 252 | new_data[split] = [] 253 | for sent, tag, head, deptag, lines in data[split]: 254 | new_sent = ["".join(w.replace("\u200c", "").split(" ")) for w in sent] 255 | lines = [line.replace("\u200c", "") for line in lines] 256 | assert len(" ".join(new_sent).split(" ")) == len(tag) 257 | new_data[split].append((new_sent, tag, head, deptag, lines)) 258 | return new_data 259 | 260 | def check_file(file): 261 | for i, l in enumerate(open(file)): 262 | items = l.strip().split("\t") 263 | assert len(items[0].split(" ")) == len( 264 | items[1].split(" ") 265 | ), "idx={}, line={}".format(i, l) 266 | 267 | def _write_files(data, output_dir, lang, suffix): 268 | for split in data: 269 | if len(data[split]) > 0: 270 | prefix = os.path.join(output_dir, f"{split}-{lang}") 271 | if suffix == "mt": 272 | with open(prefix + ".mt.tsv", "w") as fout: 273 | for idx, (sent, tag, _) in enumerate(data[split]): 274 | newline = "\n" if idx != len(data[split]) - 1 else "" 275 | fout.write( 276 | "{}\t{}{}".format( 277 | " ".join(sent), " ".join(tag), newline 278 | ) 279 | ) 280 | # if split == 'test': 281 | # fout.write('{}{}'.format(' 
'.join(sent, newline))) 282 | # else: 283 | # fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline)) 284 | check_file(prefix + ".mt.tsv") 285 | print(" - finish checking " + prefix + ".mt.tsv") 286 | elif suffix == "tsv": 287 | with open(prefix + ".tsv", "w") as fout: 288 | for sidx, (sent, tag, head, deptag, _) in enumerate(data[split]): 289 | for widx, (w, t, h, d) in enumerate(zip(sent, tag, head, deptag)): 290 | newline = ( 291 | "" 292 | if (sidx == len(data[split]) - 1) 293 | and (widx == len(sent) - 1) 294 | else "\n" 295 | ) 296 | fout.write("{}\t{}\t{}\t{}{}".format(w, t, h, d, newline)) 297 | # if split == 'test': 298 | # fout.write('{}{}'.format(w, newline)) 299 | # else: 300 | # fout.write('{}\t{}{}'.format(w, t, newline)) 301 | fout.write("\n") 302 | elif suffix == "conll": 303 | with open(prefix + ".conll", "w") as fout: 304 | for _, _, lines in data[split]: 305 | for l in lines: 306 | fout.write(l.strip() + "\n") 307 | fout.write("\n") 308 | print(f"finish writing file to {prefix}.{suffix}") 309 | 310 | if not os.path.exists(args.output_dir): 311 | os.makedirs(args.output_dir) 312 | 313 | languages = \ 314 | "af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh".split(" ") 315 | for root, dirs, files in os.walk(args.data_dir): 316 | lg = root.strip().split("/")[-1] 317 | if root == args.data_dir or lg not in languages: 318 | continue 319 | 320 | data = {k: [] for k in ["train", "dev", "test"]} 321 | for f in sorted(files): 322 | if f.endswith("conll"): 323 | file = os.path.join(root, f) 324 | examples = _read_one_file(file) 325 | if "train" in f: 326 | data["train"].extend(examples) 327 | elif "dev" in f: 328 | data["dev"].extend(examples) 329 | elif "test" in f: 330 | data["test"].extend(examples) 331 | else: 332 | print("split not found: ", file) 333 | print( 334 | " - finish reading {}, {}".format( 335 | file, [(k, len(v)) for k, v in data.items()] 336 | ) 337 | ) 338 | 339 | data = remove_empty_space(data) 340 | for sub in ["tsv"]: 341 | _write_files(data, args.output_dir, lg, sub) 342 | 343 | 344 | def pawsx_preprocess(args): 345 | def _preprocess_one_file(infile, outfile, remove_label=False): 346 | data = [] 347 | for i, line in enumerate(open(infile, "r")): 348 | if i == 0: 349 | continue 350 | items = line.strip().split("\t") 351 | sent1 = " ".join(items[1].strip().split(" ")) 352 | sent2 = " ".join(items[2].strip().split(" ")) 353 | label = items[3] 354 | data.append([sent1, sent2, label]) 355 | 356 | with open(outfile, "w") as fout: 357 | writer = csv.writer(fout, delimiter="\t") 358 | for sent1, sent2, label in data: 359 | if remove_label: 360 | writer.writerow([sent1, sent2]) 361 | else: 362 | writer.writerow([sent1, sent2, label]) 363 | 364 | if not os.path.exists(args.output_dir): 365 | os.makedirs(args.output_dir) 366 | 367 | split2file = {"train": "train", "test": "test_2k", "dev": "dev_2k"} 368 | for lang in ["en", "de", "es", "fr", "ja", "ko", "zh"]: 369 | for split in ["train", "test", "dev"]: 370 | if split == "train" and lang != "en": 371 | # continue 372 | file = 'translated_train' 373 | else: 374 | file = split2file[split] 375 | infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file)) 376 | outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang)) 377 | # _preprocess_one_file(infile, outfile, remove_label=(split == 'test')) 378 | _preprocess_one_file(infile, outfile) 379 | print(f"finish preprocessing {outfile}") 380 | 381 | 382 | def xnli_preprocess(args): 383 | def 
_preprocess_file(infile, output_dir, split): 384 | all_langs = defaultdict(list) 385 | for i, line in enumerate(open(infile, "r")): 386 | if i == 0: 387 | continue 388 | 389 | items = line.strip().split("\t") 390 | lang = items[0].strip() 391 | label = ( 392 | "contradiction" 393 | if items[1].strip() == "contradictory" 394 | else items[1].strip() 395 | ) 396 | sent1 = " ".join(items[6].strip().split(" ")) 397 | sent2 = " ".join(items[7].strip().split(" ")) 398 | all_langs[lang].append((sent1, sent2, label)) 399 | print(f"# langs={len(all_langs)}") 400 | for lang, pairs in all_langs.items(): 401 | outfile = os.path.join(output_dir, "{}-{}.tsv".format(split, lang)) 402 | with open(outfile, "w") as fout: 403 | writer = csv.writer(fout, delimiter="\t") 404 | for (sent1, sent2, label) in pairs: 405 | writer.writerow([sent1, sent2, label]) 406 | # if split == 'test': 407 | # writer.writerow([sent1, sent2]) 408 | # else: 409 | # writer.writerow([sent1, sent2, label]) 410 | print(f"finish preprocess {outfile}") 411 | 412 | def _preprocess_train_file(infile, outfile): 413 | with open(outfile, "w") as fout: 414 | writer = csv.writer(fout, delimiter="\t") 415 | for i, line in enumerate(open(infile, "r")): 416 | if i == 0: 417 | continue 418 | 419 | items = line.strip().split("\t") 420 | sent1 = " ".join(items[0].strip().split(" ")) 421 | sent2 = " ".join(items[1].strip().split(" ")) 422 | label = ( 423 | "contradiction" 424 | if items[2].strip() == "contradictory" 425 | else items[2].strip() 426 | ) 427 | writer.writerow([sent1, sent2, label]) 428 | print(f"finish preprocess {outfile}") 429 | 430 | train_langs = 'ar,bg,de,el,en,es,fr,hi,ru,tr,ur,vi,zh'.split(',') 431 | if not os.path.exists(args.output_dir): 432 | os.makedirs(args.output_dir) 433 | for l in train_langs: 434 | infile = os.path.join(args.data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(l)) 435 | outfile = os.path.join(args.output_dir, "train-{}.tsv".format(l)) 436 | _preprocess_train_file(infile, outfile) 437 | 438 | for split in ["test", "dev"]: 439 | infile = os.path.join(args.data_dir, "XNLI-1.0/xnli.{}.tsv".format(split)) 440 | print(f"reading file {infile}") 441 | _preprocess_file(infile, args.output_dir, split) 442 | 443 | 444 | def tatoeba_preprocess(args): 445 | lang3_dict = { 446 | "afr": "af", 447 | "ara": "ar", 448 | "bul": "bg", 449 | "ben": "bn", 450 | "deu": "de", 451 | "ell": "el", 452 | "spa": "es", 453 | "est": "et", 454 | "eus": "eu", 455 | "pes": "fa", 456 | "fin": "fi", 457 | "fra": "fr", 458 | "heb": "he", 459 | "hin": "hi", 460 | "hun": "hu", 461 | "ind": "id", 462 | "ita": "it", 463 | "jpn": "ja", 464 | "jav": "jv", 465 | "kat": "ka", 466 | "kaz": "kk", 467 | "kor": "ko", 468 | "mal": "ml", 469 | "mar": "mr", 470 | "nld": "nl", 471 | "por": "pt", 472 | "rus": "ru", 473 | "swh": "sw", 474 | "tam": "ta", 475 | "tel": "te", 476 | "tha": "th", 477 | "tgl": "tl", 478 | "tur": "tr", 479 | "urd": "ur", 480 | "vie": "vi", 481 | "cmn": "zh", 482 | "eng": "en", 483 | } 484 | if not os.path.exists(args.output_dir): 485 | os.makedirs(args.output_dir) 486 | for sl3, sl2 in lang3_dict.items(): 487 | if sl3 != "eng": 488 | src_file = f"{args.data_dir}/tatoeba.{sl3}-eng.{sl3}" 489 | tgt_file = f"{args.data_dir}/tatoeba.{sl3}-eng.eng" 490 | src_out = f"{args.output_dir}/{sl2}-en.{sl2}" 491 | tgt_out = f"{args.output_dir}/{sl2}-en.en" 492 | gold_out = f"{args.output_dir}/{sl2}-en.en.gold" 493 | shutil.copy(src_file, src_out) 494 | tgts = [l.strip() for l in open(tgt_file)] 495 | idx = range(len(tgts)) 496 | data = 
zip(tgts, idx) 497 | with open(tgt_out, "w") as ftgt, open(gold_out, "w") as fgold: 498 | for t, i in sorted(data, key=lambda x: x[0]): 499 | ftgt.write(f"{t}\n") 500 | fgold.write(f"{i}\n") 501 | 502 | 503 | def xquad_preprocess(args): 504 | pass 505 | # Remove the test annotations to prevent accidental cheating 506 | # remove_qa_test_annotations(args.data_dir) 507 | 508 | 509 | def mlqa_preprocess(args): 510 | pass 511 | # Remove the test annotations to prevent accidental cheating 512 | # remove_qa_test_annotations(args.data_dir) 513 | 514 | 515 | def tydiqa_preprocess(args): 516 | LANG2ISO = { 517 | "arabic": "ar", 518 | "bengali": "bn", 519 | "english": "en", 520 | "finnish": "fi", 521 | "indonesian": "id", 522 | "korean": "ko", 523 | "russian": "ru", 524 | "swahili": "sw", 525 | "telugu": "te", 526 | } 527 | assert os.path.exists(args.data_dir) 528 | train_file = os.path.join(args.data_dir, "tydiqa-goldp-v1.1-train.json") 529 | os.makedirs(args.output_dir, exist_ok=True) 530 | 531 | # Split the training file into language-specific files 532 | lang2data = defaultdict(list) 533 | with open(train_file, "r") as f_in: 534 | data = json.load(f_in) 535 | version = data["version"] 536 | for doc in data["data"]: 537 | for par in doc["paragraphs"]: 538 | context = par["context"] 539 | for qa in par["qas"]: 540 | question = qa["question"] 541 | question_id = qa["id"] 542 | example_lang = question_id.split("-")[0] 543 | q_id = question_id.split("-")[-1] 544 | for answer in qa["answers"]: 545 | a_start, a_text = answer["answer_start"], answer["text"] 546 | a_end = a_start + len(a_text) 547 | assert context[a_start:a_end] == a_text 548 | lang2data[example_lang].append( 549 | { 550 | "paragraphs": [ 551 | { 552 | "context": context, 553 | "qas": [ 554 | { 555 | "answers": qa["answers"], 556 | "question": question, 557 | "id": q_id, 558 | } 559 | ], 560 | } 561 | ] 562 | } 563 | ) 564 | 565 | for lang, data in lang2data.items(): 566 | out_file = os.path.join( 567 | args.output_dir, "tydiqa.%s.train.json" % LANG2ISO[lang] 568 | ) 569 | with open(out_file, "w") as f: 570 | json.dump({"data": data, "version": version}, f) 571 | 572 | # Rename the dev files 573 | dev_dir = os.path.join(args.data_dir, "tydiqa-goldp-v1.1-dev") 574 | assert os.path.exists(dev_dir) 575 | for lang, iso in LANG2ISO.items(): 576 | src_file = os.path.join(dev_dir, "tydiqa-goldp-dev-%s.json" % lang) 577 | dst_file = os.path.join(dev_dir, "tydiqa.%s.dev.json" % iso) 578 | os.rename(src_file, dst_file) 579 | 580 | # Remove the test annotations to prevent accidental cheating 581 | # remove_qa_test_annotations(dev_dir) 582 | 583 | 584 | def remove_qa_test_annotations(test_dir): 585 | assert os.path.exists(test_dir) 586 | for file_name in os.listdir(test_dir): 587 | new_data = [] 588 | test_file = os.path.join(test_dir, file_name) 589 | with open(test_file, "r") as f: 590 | data = json.load(f) 591 | version = data["version"] 592 | for doc in data["data"]: 593 | for par in doc["paragraphs"]: 594 | context = par["context"] 595 | for qa in par["qas"]: 596 | question = qa["question"] 597 | question_id = qa["id"] 598 | for answer in qa["answers"]: 599 | a_start, a_text = answer["answer_start"], answer["text"] 600 | a_end = a_start + len(a_text) 601 | assert context[a_start:a_end] == a_text 602 | new_data.append( 603 | { 604 | "paragraphs": [ 605 | { 606 | "context": context, 607 | "qas": [ 608 | { 609 | "answers": [ 610 | {"answer_start": 0, "text": ""} 611 | ], 612 | "question": question, 613 | "id": question_id, 614 | } 615 | ], 616 | } 
617 | ] 618 | } 619 | ) 620 | with open(test_file, "w") as f: 621 | json.dump({"data": new_data, "version": version}, f) 622 | 623 | 624 | def mtop_preprocess(args): 625 | DOMAINS = [ 626 | 'alarm', 'calling', 'event', 'messaging', 'music', 'news', 627 | 'people', 'recipes', 'reminder', 'timer', 'watcher' 628 | ] 629 | FIELD_NAMES = ['id', 'intent', 'slots', 'utterance', 'domain', 'language', 'compositional_rep', 'tokens'] 630 | 631 | def convert(infile, outfile): 632 | with open(outfile, 'w', encoding='utf-8') as fw: 633 | with open(infile, 'r') as f: 634 | for line in f: 635 | line = line.strip() 636 | if not line: 637 | continue 638 | fields = line.split('\t') 639 | intent = fields[1].split(':')[1] 640 | slots = [] 641 | # slots field can be empty string 642 | if len(fields[2]) > 0: 643 | # slots are comma separated 644 | # 4:20:SL:CONTACT,21:26:SL:TYPE_CONTENT,36:38:SL:RECIPIENT 645 | _slots = fields[2].split(',') 646 | for s in _slots: 647 | if s: 648 | # each slot has 4 fields, separated by ':' 649 | # start_offset:end_offset:SL:slot_tag 650 | _s = s.split(':') 651 | assert len(_s) == 4, (s, _s) 652 | slots.append([int(_s[0]), int(_s[1]), _s[3]]) 653 | 654 | assert len(fields) == len(FIELD_NAMES) 655 | try: 656 | # obj = json.loads(fields[7]) 657 | obj = ast.literal_eval(fields[7]) 658 | except ValueError as e: 659 | print(fields[7]) 660 | print(e) 661 | 662 | tokens = [] 663 | for j, (w, span) in enumerate(zip(obj['tokens'], obj['tokenSpans'])): 664 | tokens.append({ 665 | 'token': w, 666 | 'token_range': (span['start'], span['start'] + span['length']) 667 | }) 668 | # print(tokens) 669 | 670 | words = [] 671 | bio_tags = ['O'] * len(tokens) 672 | total_slots = 0 673 | for j, tok in enumerate(tokens): 674 | words.append(tok['token']) 675 | tok_start, tok_end = tok['token_range'] 676 | # we match word offset with slot offsets to find proper tag 677 | for slot_start, slot_end, slot_tag in slots: 678 | # if a token starts at the slot start offset 679 | if tok_start == slot_start: 680 | bio_tags[j] = 'B-{}'.format(slot_tag) 681 | total_slots += 1 682 | break 683 | # if a token starts after a slot's start offset 684 | # but the token ends at or before the slot ends 685 | elif tok_start > slot_start and tok_end <= slot_end: 686 | # an I-{} tag must be preceded by a B-{} or I-{} tag 687 | if bio_tags[j - 1] in ['B-{}'.format(slot_tag), 'I-{}'.format(slot_tag)]: 688 | bio_tags[j] = 'I-{}'.format(slot_tag) 689 | break 690 | 691 | assert len(words) == len(bio_tags) 692 | fw.write(json.dumps({ 693 | 'tokens': words, 694 | 'slot_labels': bio_tags, 695 | 'intent_label': intent 696 | }) + '\n') 697 | 698 | # infile contains one example per line (tab separated fields) 699 | # there must be len(FIELD_NAMES) fields in every line 700 | def vocab_process(data_dir): 701 | slot_label_vocab = 'slot_label.txt' 702 | intent_label_vocab = 'intent_label.txt' 703 | 704 | intent_vocab = set() 705 | slot_vocab = set() 706 | for lang in ['en', 'de', 'es', 'fr', 'hi', 'th']: 707 | for split in ['train', 'dev', 'test']: 708 | with open(os.path.join(data_dir, '{}-{}.jsonl'.format(split, lang)), 709 | 'r', encoding='utf-8') as f_r: 710 | for line in f_r: 711 | line = line.strip() 712 | ex = json.loads(line) 713 | intent_vocab.add(ex['intent_label']) 714 | for slot in ex['slot_labels']: 715 | slot_vocab.add(slot) 716 | 717 | with open(os.path.join(data_dir, intent_label_vocab), 'w', encoding='utf-8') as f_w1, \ 718 | open(os.path.join(data_dir, slot_label_vocab), 'w', encoding='utf-8') as f_w2: 719 | intent_vocab = 
sorted(list(intent_vocab)) 720 | for intent in intent_vocab: 721 | f_w1.write(intent + '\n') 722 | 723 | slot_vocab = sorted(list(slot_vocab), key=lambda x: (x[2:], x[:2])) 724 | for slot in slot_vocab: 725 | f_w2.write(slot + '\n') 726 | 727 | for lang in ['en', 'de', 'es', 'fr', 'hi', 'th']: 728 | for split in ['train', 'eval', 'test']: 729 | convert( 730 | os.path.join(args.data_dir, lang, '{}.txt'.format(split)), 731 | os.path.join( 732 | args.output_dir, 733 | '{}-{}.jsonl'.format('dev' if split == 'eval' else split, lang) 734 | ) 735 | ) 736 | vocab_process(args.output_dir) 737 | 738 | 739 | if __name__ == "__main__": 740 | parser = argparse.ArgumentParser() 741 | 742 | ## Required parameters 743 | parser.add_argument( 744 | "--data_dir", 745 | default=None, 746 | type=str, 747 | required=True, 748 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.", 749 | ) 750 | parser.add_argument( 751 | "--output_dir", 752 | default=None, 753 | type=str, 754 | required=True, 755 | help="The output data dir where any processed files will be written to.", 756 | ) 757 | parser.add_argument( 758 | "--task", default="panx", type=str, required=True, help="The task name" 759 | ) 760 | parser.add_argument( 761 | "--model_name_or_path", 762 | default="bert-base-multilingual-cased", 763 | type=str, 764 | help="The pre-trained model", 765 | ) 766 | parser.add_argument("--model_type", default="bert", type=str, help="model type") 767 | parser.add_argument( 768 | "--max_len", default=512, type=int, help="the maximum length of sentences" 769 | ) 770 | parser.add_argument( 771 | "--do_lower_case", action="store_true", help="whether to do lower case" 772 | ) 773 | parser.add_argument("--cache_dir", default=None, type=str, help="cache directory") 774 | parser.add_argument("--languages", default="en", type=str, help="process language") 775 | parser.add_argument( 776 | "--remove_last_token", 777 | action="store_true", 778 | help="whether to remove the last token", 779 | ) 780 | parser.add_argument( 781 | "--remove_test_label", 782 | action="store_true", 783 | help="whether to remove test set label", 784 | ) 785 | args = parser.parse_args() 786 | 787 | if args.task == "panx_tokenize": 788 | panx_tokenize_preprocess(args) 789 | if args.task == "panx": 790 | panx_preprocess(args) 791 | if args.task == "udpos_tokenize": 792 | udpos_tokenize_preprocess(args) 793 | if args.task == "udpos": 794 | udpos_preprocess(args) 795 | if args.task == "pawsx": 796 | pawsx_preprocess(args) 797 | if args.task == "xnli": 798 | xnli_preprocess(args) 799 | if args.task == "tatoeba": 800 | tatoeba_preprocess(args) 801 | if args.task == "xquad": 802 | xquad_preprocess(args) 803 | if args.task == "mlqa": 804 | mlqa_preprocess(args) 805 | if args.task == "tydiqa": 806 | tydiqa_preprocess(args) 807 | if args.task == "mtop": 808 | mtop_preprocess(args) 809 | --------------------------------------------------------------------------------
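For reference, `utils_preprocess.py` is driven entirely by the `--task` flag dispatched in the `__main__` block above. Below is a minimal sketch of how a few of the tasks might be invoked; the `download/` and `processed/` directory names are placeholders rather than the project's canonical paths (the task-specific scripts under `scripts/` and `udpipe/` wire up the real locations), and only flags defined in the argparse block are used.

```bash
# Minimal sketch -- directory names are placeholders; the flags come from the
# argparse block in utils_preprocess.py.

# PAN-X (NER): re-chunk sentences so no example exceeds the --max_len subword budget.
python utils_preprocess.py --task panx_tokenize \
  --data_dir download/panx --output_dir processed/panx \
  --model_name_or_path bert-base-multilingual-cased --model_type bert \
  --max_len 128 --languages en,de,nl,es

# PAWS-X / XNLI (classification): rewrite the raw TSVs into {split}-{lang}.tsv files.
python utils_preprocess.py --task pawsx --data_dir download/pawsx --output_dir processed/pawsx
python utils_preprocess.py --task xnli  --data_dir download/xnli  --output_dir processed/xnli

# MTOP (semantic parsing): convert slot character offsets into token-level BIO tags
# and emit the intent_label.txt / slot_label.txt vocabularies.
python utils_preprocess.py --task mtop --data_dir download/mtop --output_dir processed/mtop
```

Note that `panx_tokenize` needs the tokenizer-related flags (`--model_name_or_path`, `--max_len`, optionally `--do_lower_case`) because it re-segments sentences against the subword budget, whereas `pawsx`, `xnli`, and `mtop` only restructure the raw files and ignore the tokenizer settings.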